/** * Brett Terpstra 6920201 * Michael Boulos 6973523 * Adrian Pinu 6970677 */ #include <iostream> #include <asio.hpp> #include <blt/std/logging.h> #include <blt/std/memory.h> #include <blt/compatibility.h> #include <string> #include <array> #include <thread> #include <iomanip> #include <vector> #include <typeinfo> #include <type_traits> #include <unordered_set> #include "ip.h" #include "insane_dns/util.h" /* * ---------------------------- * | CONFIG | * ---------------------------- */ /** What port to run the server on */ static constexpr unsigned short int SERVER_PORT = 5555; /** * How should we bind to the local machine? By default this will use IPv4 and IPv6 with the default ASIO endpoint. * Only change this if you really need to and know what you are doing. */ static BLT_CPP20_CONSTEXPR bind_address address() { return {SERVER_PORT}; } /** should we strictly match results? ie block `*wikipedia.org*` or just `wikipedia.org`? */ static constexpr bool STRICT_MATCHING = false; /** DNS server to use for forwarding to / resolving DNS requests */ static inline BLT_CPP20_CONSTEXPR std::string DNS_SERVER_IP() { return "8.8.8.8"; } /** replacement IP address. Make sure this is a 4 octet string seperated by `.` */ static inline BLT_CPP20_CONSTEXPR IPAddress REPLACEMENT_IP() { return {"139.57.100.6"}; } /** * List of disallowed domains. * Note: if you are using STRICT_MATCHING=false it will not match to the root domain. * Eg it will match `*en.wikipedia.org*` NOT `*wikipedia.org*` */ static const std::unordered_set<std::string> DISALLOWED_DOMAINS{ "en.wikipedia.org", "zombo.com" }; /* * ----------------------------------- * | Do Not Change | * ----------------------------------- */ // these features were planned but not added because I realized you guys won't care or give extra marks which broke my obsession with it // so uhh don't change em otherwise the code will break :3 // as it is im still adding new features and stuff (TCP) and messing with the code trying to get it cleaner /** true -> only match A records ; false -> match any named record (configure with NON_STRICT_REPLACE_ALL) */ static constexpr bool STRICT_FILTERING = true; /** true -> match all records ; false -> match only records we might want to replace (A, AAAA, CNAME) */ static constexpr bool NON_STRICT_REPLACE_ALL = true; // was going to add ~~TCP~~ (this is now a thing for full DIG support. "can you dig it, sucka!?") and ad blocking support // that also isn't going to happen now. /** list of web address to download the ad block lists from */ static BLT_CPP20_CONSTEXPR std::vector<std::string> BLOCK_LISTS{}; /** true -> block ad DNS requests ; false -> do nothing */ static constexpr bool BLOCK_ADS = false; /** true -> send back the REPLACEMENT_IP() ; false -> send back a fail state in the DNS request. */ static constexpr bool REDIRECT_ADS = true; // 5F826B // DNS data contains: // 2 bytes for transaction id // 2 bytes for flags // 2 bytes for number of questions // 2 bytes for Answer RRs // 2 bytes for Authority RRs // 2 bytes for Additional RRs // question format: // 1 byte for length // (length) bytes per section // ... until length 0 // 2 bytes for QTYPE (useless) // 2 bytes for QCLASS (useless) // answer format: // 2 byte for domain name (offset ptr, still not 100 on from where) // 2 byte for type // 2 byte for class // 4 byte for time to live // 2 byte for length of data // (lengthy) byte for data class send_buffer; /** * This data structure represents a DNS question. When constructed it will read the FULL domain as a single string, along with the QTYPE and QCLASS * It is safe to read QDCOUNT questions by constructing this class. * The question will be reconstructed by the send_buffer class. */ class question { friend send_buffer; private: std::string domain; uint16_t QTYPE = 0; uint16_t QCLASS = 0; public: explicit question(const blt::byte_reader& reader) { // process the full question while (true) { uint8_t length = reader.next(); if (length == 0) break; if (!domain.empty()) domain += '.'; for (uint8_t j = 0; j < length; j++) domain += static_cast<char>(reader.next()); } // Skip QTYPE and QCLASS // but keep a copy of it for future writing reader.to(QTYPE); reader.to(QCLASS); } const std::string& operator()() { return domain; } }; /** * This data structure represents a DNS answer. When constructed it will read the FULL answer along with the associated data. It is therefore safe * to read ANCOUNT by constructing a series of answers which read from the byte stream. This class cannot be copied but can be moved. * The answer will be rebuilt by the send_buffer class for you. */ class answer { friend send_buffer; private: mutable uint16_t NAME = 0; uint16_t TYPE = 0; uint16_t CLASS = 0; uint32_t TTL = 0; uint16_t RDLENGTH = 0; bool requires_reset = false; blt::scoped_buffer<unsigned char> RDATA; public: explicit answer(const blt::byte_reader& reader) { reader.to(NAME); reader.to(TYPE); reader.to(CLASS); reader.to(TTL); reader.to(RDLENGTH); RDATA = blt::scoped_buffer<unsigned char>(RDLENGTH); reader.copy(RDATA.data(), RDLENGTH); } [[nodiscard]] uint16_t type() const { return TYPE; } inline void substitute(const IPAddress& addr) { BLT_DEBUG("Substituting with replacement address '%s'", REPLACEMENT_IP().asString.c_str()); BLT_ASSERT(RDLENGTH == 4); std::memcpy(RDATA.data(), addr.octets, 4); } inline void setARecord(const IPAddress& addr) { BLT_DEBUG("Setting answer to A record"); NAME = 0; NAME |= (0b11 << 14); requires_reset = true; BLT_TRACE(NAME); RDATA = blt::scoped_buffer<unsigned char>(4); RDLENGTH = 4; TYPE = 1; CLASS = 1; substitute(addr); } inline void reset(size_t offset) const { if (!requires_reset) return; // like I said not 100 on how to construct the ptr // seems to be causing issues. I've stopped working on this as it's not required. auto i16 = static_cast<uint16_t>(offset) & (~(0b11 << 14)); NAME |= i16; } // rule of 5 // (there used to be a destructor) answer(const answer& copy) = delete; answer& operator=(const answer& copy) = delete; answer(answer&& move) noexcept { NAME = move.NAME; TYPE = move.TYPE; CLASS = move.CLASS; TTL = move.TTL; RDLENGTH = move.RDLENGTH; RDATA = std::move(move.RDATA); } answer& operator=(answer&& move) noexcept { NAME = 0; NAME = move.NAME; TYPE = move.TYPE; CLASS = move.CLASS; TTL = move.TTL; RDLENGTH = move.RDLENGTH; RDATA = std::move(move.RDATA); return *this; } // there used to be a destructor ~answer() = default; }; class send_buffer { private: mutable std::array<unsigned char, 65535> internal_data{}; mutable size_t write_index = 0; public: send_buffer() = default; void write(unsigned char* data, size_t size) const { std::memcpy(&internal_data[write_index], data, size); write_index += size; } template<typename T> void write(const T& t) const { if constexpr (std::is_arithmetic_v<T>) { blt::mem::toBytes(t, &internal_data[write_index]); write_index += sizeof(T); } else if constexpr (std::is_same_v<T, answer>) { write(t.NAME); write(t.TYPE); write(t.CLASS); write(t.TTL); write(t.RDLENGTH); std::memcpy(&internal_data[write_index], t.RDATA.data(), t.RDLENGTH); write_index += t.RDLENGTH; } else if constexpr (std::is_same_v<T, question>) { auto labels = blt::string::split(t.domain, '.'); for (const auto& label : labels) { auto length = static_cast<uint8_t>(label.length()); write(length); std::memcpy(&internal_data[write_index], &label[0], length); write_index += length; } // write length of 0 to signal end of labels write('\0'); write(t.QTYPE); write(t.QCLASS); } else static_assert("Data type not supported!"); } void reset() const { write_index = 0; } [[nodiscard]] unsigned char operator[](size_t i) { return internal_data[i]; } [[nodiscard]] unsigned char* data() { return internal_data.data(); } [[nodiscard]] size_t size() const { return write_index; } auto buffer() const { return asio::buffer(internal_data.data(), size()); } }; inline bool shouldReplace(const answer& a) { // a records will be handled in either case, check for others like AAAA or CNAME // TODO: add enums to this + a way to add custom types return NON_STRICT_REPLACE_ALL || a.type() == 28 || a.type() == 5; } void process_answers(std::vector<answer>& answers) { for (auto& a : answers) { if (a.type() == 1) { a.substitute(REPLACEMENT_IP()); } else if (!STRICT_FILTERING && shouldReplace(a)) { a.setARecord(REPLACEMENT_IP()); } } } template<typename T, typename INFO, typename MESSENGER> request_info handle_forward_request(MESSENGER messenger, const INFO& info, const T& input_recv_buffer, size_t bytes, T& forward_recv_buffer) { // get the number of questions uint16_t questions; // yes I made this part of my library just for this :3 blt::mem::fromBytes(&input_recv_buffer[info.QUESTIONS_BEGIN], questions); // i hate little endian BLT_INFO("(%s) Bytes received: %d with %d questions", blt::is_UDP_or_TCP<INFO>().c_str(), bytes, questions); // forward to google. size_t out_bytes; const auto* data = input_recv_buffer.data(); messenger(DNS_SERVER_IP(), data, bytes, forward_recv_buffer, out_bytes); uint16_t num_of_answers; blt::mem::fromBytes(&forward_recv_buffer[info.ANSWERS_BEGIN], num_of_answers); return {out_bytes, num_of_answers}; } template<typename T, typename INFO> void handle_response(const INFO& info, send_buffer& return_send_buffer, request_info rq_info, T& forward_recv_buffer) { blt::byte_reader reader(forward_recv_buffer.data(), forward_recv_buffer.size(), info.HEADER_END); auto TYPE_STR = blt::is_UDP_or_TCP<INFO>(); BLT_INFO("(%s) Bytes answered: %d with %d answers", TYPE_STR.c_str(), rq_info.number_of_bytes, rq_info.number_of_answers); // no one actually does multiple questions. trying to do it in dig is not easy // and the standard isn't really designed for this (how do we handle if one question errors but the other doesn't? there is only // one return code.) question q(reader); std::vector<answer> answers; for (size_t i = 0; i < rq_info.number_of_answers; i++) { answer a(reader); answers.push_back(std::move(a)); } BLT_INFO("(%s) DOMAIN: %s", TYPE_STR.c_str(), q().c_str()); if (STRICT_MATCHING && BLT_CONTAINS(DISALLOWED_DOMAINS, q())) process_answers(answers); else if (!STRICT_MATCHING) { // linear search the domains for contains. Maybe find a better way to do this. for (const auto& v : DISALLOWED_DOMAINS) if (blt::string::contains(q(), v)) process_answers(answers); } return_send_buffer.write(forward_recv_buffer.data(), info.HEADER_END); // need to cache this value oh wait we aren't doing this anymore auto question_offset = return_send_buffer.size(); return_send_buffer.write(q); for (const answer& a : answers) { BLT_TRACE("(%s) Writing answer with type of %d", TYPE_STR.c_str(), a.type()); a.reset(question_offset); return_send_buffer.write(a); } return_send_buffer.write(reader.from(), rq_info.number_of_bytes - reader.last()); } void run_udp_server() { try { asio::io_context io_context; udp::socket socket(io_context, toUDPEndpoint(address())); blt::scoped_buffer<unsigned char> input_recv_buffer{PACKET_BUFFER_SIZE}; blt::scoped_buffer<unsigned char> forward_recv_buffer{PACKET_BUFFER_SIZE}; while (true) { udp::endpoint remote_endpoint; size_t bytes = socket.receive_from(asio::buffer(input_recv_buffer.data(), input_recv_buffer.size()), remote_endpoint); auto rq_info = handle_forward_request(blt::network::sendUDPMessage, DNS_UDP_INFO, input_recv_buffer, bytes, forward_recv_buffer); send_buffer return_send_buffer; handle_response(DNS_UDP_INFO, return_send_buffer, rq_info, forward_recv_buffer); asio::error_code ignored_error; socket.send_to(return_send_buffer.buffer(), remote_endpoint, 0, ignored_error); } } catch (std::exception& e) { BLT_ERROR(e.what()); } } void run_tcp_server() { try { asio::io_context io_context; tcp::acceptor acceptor(io_context, toTCPEndpoint(address())); blt::scoped_buffer<unsigned char> input_recv_buffer{PACKET_BUFFER_SIZE}; blt::scoped_buffer<unsigned char> forward_recv_buffer{PACKET_BUFFER_SIZE}; while (true) { tcp::socket socket(io_context); acceptor.accept(socket); asio::error_code error; size_t bytes = socket.read_some(asio::buffer(input_recv_buffer.data(), input_recv_buffer.size()), error); if (error == asio::error::eof) break; else if (error) throw asio::system_error(error); auto rq_info = handle_forward_request(blt::network::sendTCPMessage, DNS_TCP_INFO, input_recv_buffer, bytes, forward_recv_buffer); send_buffer return_send_buffer; handle_response(DNS_TCP_INFO, return_send_buffer, rq_info, forward_recv_buffer); asio::write(socket, return_send_buffer.buffer()); } } catch (std::exception& e) { BLT_ERROR(e.what()); } } int main() { BLT_INFO("Creating UDP Server"); std::thread udp_server(run_udp_server); BLT_INFO("Creating TCP Server"); std::thread tcp_server(run_tcp_server); BLT_INFO("Awaiting"); udp_server.join(); tcp_server.join(); return 0; }