// Insane_DNS/src/main.cpp
//
// 485 lines
// 16 KiB
// C++

/**
* Brett Terpstra 6920201
* Michael Boulos 6973523
* Adrian Pinu 6970677
*/
#include <algorithm>
#include <array>
#include <cstdint>
#include <cstring>
#include <iomanip>
#include <iostream>
#include <stdexcept>
#include <string>
#include <thread>
#include <type_traits>
#include <typeinfo>
#include <unordered_set>
#include <utility>
#include <vector>
#include <asio.hpp>
#include <blt/std/logging.h>
#include <blt/std/memory.h>
#include <blt/compatibility.h>
#include "ip.h"
#include "insane_dns/util.h"
/*
* ----------------------------
* | CONFIG |
* ----------------------------
*/
/** Port that both the UDP and the TCP DNS server listen on. */
static constexpr unsigned short SERVER_PORT{5555};
/**
 * Endpoint description that the UDP and TCP servers bind to (via toUDPEndpoint / toTCPEndpoint).
 * By default this listens on SERVER_PORT over IPv4 and IPv6 using ASIO's default endpoint.
 * Only change this if you really need to and know what you are doing.
 */
static BLT_CPP20_CONSTEXPR bind_address address()
{
    return {SERVER_PORT};
}
/**
 * true  -> a query is blocked only when it is exactly equal to a DISALLOWED_DOMAINS entry (`wikipedia.org`);
 * false -> a query is blocked when it contains an entry anywhere in its name (`*wikipedia.org*`).
 */
static constexpr bool STRICT_MATCHING{false};
/**
 * Upstream resolver that every incoming query is forwarded to and whose response is relayed back
 * (Google Public DNS by default).
 */
static inline BLT_CPP20_CONSTEXPR std::string DNS_SERVER_IP()
{
    return std::string{"8.8.8.8"};
}
/**
 * Address returned in place of the real one for blocked domains.
 * Must be an IPv4 address written as 4 octets separated by `.`.
 */
static inline BLT_CPP20_CONSTEXPR IPAddress REPLACEMENT_IP() {
    return {"139.57.100.6"};
}
/**
 * Domains whose A records are rewritten to REPLACEMENT_IP().
 *
 * With STRICT_MATCHING=true a query must equal an entry exactly. With STRICT_MATCHING=false a query is blocked when
 * it contains an entry as a substring. This does not widen an entry to its parent domain: "en.wikipedia.org" matches
 * `*en.wikipedia.org*`, NOT `*wikipedia.org*`.
 */
static const std::unordered_set<std::string> DISALLOWED_DOMAINS = {
    "en.wikipedia.org",
    "zombo.com",
};
/*
* -----------------------------------
* | Do Not Change |
* -----------------------------------
*/
// These features were planned but not added, because I realized you guys won't care or give extra marks, which broke my obsession with it.
// So uhh, don't change them, otherwise the code will break :3
// As it is, I'm still adding new features and stuff (TCP) and messing with the code trying to get it cleaner.
/**
 * true  -> only A records of a blocked domain are rewritten (their address is replaced in place);
 * false -> other record types are also converted into A records, as selected by NON_STRICT_REPLACE_ALL.
 */
static constexpr bool STRICT_FILTERING{true};
/**
 * Only consulted when STRICT_FILTERING is false.
 * true -> convert every non-A record ; false -> convert only AAAA (28) and CNAME (5) records.
 */
static constexpr bool NON_STRICT_REPLACE_ALL{true};
// was going to add ~~TCP~~ (this is now a thing for full DIG support. "can you dig it, sucka!?") and ad blocking support
// that also isn't going to happen now.
/** Web addresses to download ad block lists from (ad blocking was planned but never implemented). */
static BLT_CPP20_CONSTEXPR std::vector<std::string> BLOCK_LISTS = {};
/** true -> block ad DNS requests ; false -> do nothing (planned feature) */
static constexpr bool BLOCK_ADS{false};
/** true -> answer ad requests with REPLACEMENT_IP() ; false -> send back a DNS failure instead (planned feature) */
static constexpr bool REDIRECT_ADS{true};
// 5F826B
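// DNS data contains (the 12 byte header, RFC 1035 §4.1.1):
// 2 bytes for transaction id
// 2 bytes for flags
// 2 bytes for number of questions (QDCOUNT)
// 2 bytes for Answer RRs (ANCOUNT)
// 2 bytes for Authority RRs (NSCOUNT)
// 2 bytes for Additional RRs (ARCOUNT)
// question format:
// 1 byte for the label length
// (length) bytes for the label
// ... repeated until a length of 0
// 2 bytes for QTYPE (not inspected, but copied back into the response)
// 2 bytes for QCLASS (not inspected, but copied back into the response)
// answer format:
// 2 bytes for the domain name (assumed to be a compression pointer: the top two bits are set and the low 14 bits
//   give the offset of the name from the start of the DNS message, RFC 1035 §4.1.4)
// 2 bytes for type
// 2 bytes for class
// 4 bytes for time to live
// 2 bytes for length of data (RDLENGTH)
// (RDLENGTH) bytes of data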
class send_buffer;
/**
 * One entry of a DNS message's question section.
 *
 * Constructing a question consumes the full QNAME label sequence plus QTYPE and QCLASS from the reader. It is
 * therefore safe to read QDCOUNT questions by constructing this class repeatedly. The labels are joined with '.' into
 * a single domain string. The send_buffer class splits that string back into labels when the question is written out.
 *
 * NOTE(review): every length byte is used verbatim as a label length. A compressed name (length byte with the top two
 * bits set, RFC 1035 §4.1.4) is therefore not handled. Presumably this never happens because the question holds the
 * first name in the message; confirm if other upstream servers are used.
 */
class question
{
        friend send_buffer;
    private:
        /** Dotted domain name without a trailing dot, e.g. "en.wikipedia.org" (empty for the root). */
        std::string domain;
        /** Not inspected; copied verbatim into the rebuilt response. */
        uint16_t QTYPE = 0;
        /** Not inspected; copied verbatim into the rebuilt response. */
        uint16_t QCLASS = 0;
    public:
        /**
         * Reads one question entry.
         * @param reader reader positioned at the first label length byte; it is advanced past QCLASS.
         */
        explicit question(const blt::byte_reader& reader)
        {
            // QNAME is a sequence of <length><bytes> labels terminated by a zero length byte
            while (true)
            {
                uint8_t length = reader.next();
                if (length == 0)
                    break;
                if (!domain.empty())
                    domain += '.';
                for (uint8_t j = 0; j < length; j++)
                    domain += static_cast<char>(reader.next());
            }
            // QTYPE and QCLASS are not inspected, but keep a copy of them for writing the response
            reader.to(QTYPE);
            reader.to(QCLASS);
        }
        
        /**
         * @return the queried domain name (labels joined with '.').
         * Const-qualified so it can also be used through a const question.
         */
        [[nodiscard]] const std::string& operator()() const
        {
            return domain;
        }
};
/**
* This data structure represents a DNS answer. When constructed it will read the FULL answer along with the associated data. It is therefore safe
* to read ANCOUNT by constructing a series of answers which read from the byte stream. This class cannot be copied but can be moved.
* The answer will be rebuilt by the send_buffer class for you.
*/
class answer
{
friend send_buffer;
private:
mutable uint16_t NAME = 0;
uint16_t TYPE = 0;
uint16_t CLASS = 0;
uint32_t TTL = 0;
uint16_t RDLENGTH = 0;
bool requires_reset = false;
blt::scoped_buffer<unsigned char> RDATA;
public:
explicit answer(const blt::byte_reader& reader)
{
reader.to(NAME);
reader.to(TYPE);
reader.to(CLASS);
reader.to(TTL);
reader.to(RDLENGTH);
RDATA = blt::scoped_buffer<unsigned char>(RDLENGTH);
reader.copy(RDATA.data(), RDLENGTH);
}
[[nodiscard]] uint16_t type() const
{
return TYPE;
}
inline void substitute(const IPAddress& addr)
{
BLT_DEBUG("Substituting with replacement address '%s'", REPLACEMENT_IP().asString.c_str());
BLT_ASSERT(RDLENGTH == 4);
std::memcpy(RDATA.data(), addr.octets, 4);
}
inline void setARecord(const IPAddress& addr)
{
BLT_DEBUG("Setting answer to A record");
NAME = 0;
NAME |= (0b11 << 14);
requires_reset = true;
BLT_TRACE(NAME);
RDATA = blt::scoped_buffer<unsigned char>(4);
RDLENGTH = 4;
TYPE = 1;
CLASS = 1;
substitute(addr);
}
inline void reset(size_t offset) const
{
if (!requires_reset)
return;
// like I said not 100 on how to construct the ptr
// seems to be causing issues. I've stopped working on this as it's not required.
auto i16 = static_cast<uint16_t>(offset) & (~(0b11 << 14));
NAME |= i16;
}
// rule of 5
// (there used to be a destructor)
answer(const answer& copy) = delete;
answer& operator=(const answer& copy) = delete;
answer(answer&& move) noexcept
{
NAME = move.NAME;
TYPE = move.TYPE;
CLASS = move.CLASS;
TTL = move.TTL;
RDLENGTH = move.RDLENGTH;
RDATA = std::move(move.RDATA);
}
answer& operator=(answer&& move) noexcept
{
NAME = 0;
NAME = move.NAME;
TYPE = move.TYPE;
CLASS = move.CLASS;
TTL = move.TTL;
RDLENGTH = move.RDLENGTH;
RDATA = std::move(move.RDATA);
return *this;
}
// there used to be a destructor
~answer() = default;
};
/**
 * Fixed-capacity (65535 byte) output buffer used to rebuild a DNS response before it is sent.
 *
 * The write methods are const and mutate `mutable` members, so a buffer reached through a const reference can still
 * be written to. Writing past the capacity throws std::length_error instead of overrunning the array.
 */
class send_buffer
{
    private:
        mutable std::array<unsigned char, 65535> internal_data{};
        mutable size_t write_index = 0;
        
        // Always false, but dependent on T, so the static_assert in write() only fires when an unsupported type is
        // actually written (a plain string literal in static_assert is always true and never fires).
        template<typename>
        static constexpr bool unsupported_type = false;
        
        /** Throws std::length_error if `size` more bytes do not fit behind write_index. */
        void ensure_capacity(size_t size) const
        {
            if (size > internal_data.size() - write_index)
                throw std::length_error("send_buffer overflow: DNS response does not fit in 65535 bytes");
        }
    public:
        send_buffer() = default;
        
        /**
         * Appends `size` raw bytes.
         * @throws std::length_error if the bytes do not fit into the buffer.
         */
        void write(const unsigned char* data, size_t size) const
        {
            ensure_capacity(size);
            if (size == 0)
                return;
            std::memcpy(internal_data.data() + write_index, data, size);
            write_index += size;
        }
        
        /**
         * Appends an arithmetic value, an answer or a question in DNS wire format.
         * @throws std::length_error if the data does not fit into the buffer.
         */
        template<typename T>
        void write(const T& t) const
        {
            if constexpr (std::is_arithmetic_v<T>)
            {
                // byte order is left to blt::mem::toBytes (presumably network order; see blt/std/memory.h)
                ensure_capacity(sizeof(T));
                blt::mem::toBytes(t, internal_data.data() + write_index);
                write_index += sizeof(T);
            } else if constexpr (std::is_same_v<T, answer>)
            {
                write(t.NAME);
                write(t.TYPE);
                write(t.CLASS);
                write(t.TTL);
                write(t.RDLENGTH);
                write(t.RDATA.data(), t.RDLENGTH);
            } else if constexpr (std::is_same_v<T, question>)
            {
                for (const auto& label : blt::string::split(t.domain, '.'))
                {
                    // an empty label would be written as a 0 length byte, i.e. a premature end-of-name marker
                    if (label.empty())
                        continue;
                    const auto length = static_cast<uint8_t>(label.length());
                    write(length);
                    write(reinterpret_cast<const unsigned char*>(label.data()), length);
                }
                // write length of 0 to signal end of labels
                write('\0');
                write(t.QTYPE);
                write(t.QCLASS);
            } else
            {
                static_assert(unsupported_type<T>, "Data type not supported!");
            }
        }
        
        /** Rewinds the buffer so it can be reused; previously written bytes are simply overwritten. */
        void reset() const
        {
            write_index = 0;
        }
        
        [[nodiscard]] unsigned char operator[](size_t i) const
        {
            return internal_data[i];
        }
        
        [[nodiscard]] unsigned char* data()
        {
            return internal_data.data();
        }
        
        /** @return the number of bytes written since construction or the last reset(). */
        [[nodiscard]] size_t size() const
        {
            return write_index;
        }
        
        /** @return an ASIO buffer covering only the bytes written so far. */
        auto buffer() const
        {
            return asio::buffer(internal_data.data(), size());
        }
};
/**
 * Decides whether a non-A answer should be converted into an A record when STRICT_FILTERING is off.
 * A records themselves are always handled by process_answers.
 * TODO: add enums to this + a way to add custom types
 */
inline bool shouldReplace(const answer& a)
{
    if (NON_STRICT_REPLACE_ALL)
        return true;
    const auto record_type = a.type();
    return record_type == 28 /* AAAA */ || record_type == 5 /* CNAME */;
}
void process_answers(std::vector<answer>& answers)
{
for (auto& a : answers)
{
if (a.type() == 1)
{
a.substitute(REPLACEMENT_IP());
} else if (!STRICT_FILTERING && shouldReplace(a))
{
a.setARecord(REPLACEMENT_IP());
}
}
}
/**
 * Forwards a client's DNS query to DNS_SERVER_IP() and collects the upstream reply.
 *
 * @param messenger           transport function (blt::network::sendUDPMessage / sendTCPMessage), called as
 *                            messenger(server_ip, query_data, query_size, response_buffer, response_size)
 * @param info                per-transport offsets (QUESTIONS_BEGIN, ANSWERS_BEGIN) into the raw message buffers
 * @param input_recv_buffer   buffer holding the client's query
 * @param bytes               number of valid bytes in input_recv_buffer
 * @param forward_recv_buffer receives the upstream response
 * @return the size of the upstream response and its answer count (ANCOUNT)
 */
template<typename T, typename INFO, typename MESSENGER>
request_info handle_forward_request(MESSENGER messenger, const INFO& info, const T& input_recv_buffer, size_t bytes, T& forward_recv_buffer)
{
    // get the number of questions (QDCOUNT); fromBytes presumably handles the network byte order
    uint16_t questions = 0;
    blt::mem::fromBytes(&input_recv_buffer[info.QUESTIONS_BEGIN], questions);
    // bytes is a size_t: cast it so the argument matches the %d conversion of the printf-style format
    BLT_INFO("(%s) Bytes received: %d with %d questions", blt::is_UDP_or_TCP<INFO>().c_str(), static_cast<int>(bytes), questions);
    // forward to the upstream resolver (Google by default)
    size_t out_bytes = 0;
    const auto* data = input_recv_buffer.data();
    messenger(DNS_SERVER_IP(), data, bytes, forward_recv_buffer, out_bytes);
    uint16_t num_of_answers = 0;
    blt::mem::fromBytes(&forward_recv_buffer[info.ANSWERS_BEGIN], num_of_answers);
    return {out_bytes, num_of_answers};
}
template<typename T, typename INFO>
void handle_response(const INFO& info, send_buffer& return_send_buffer, request_info rq_info, T& forward_recv_buffer)
{
blt::byte_reader reader(forward_recv_buffer.data(), forward_recv_buffer.size(), info.HEADER_END);
auto TYPE_STR = blt::is_UDP_or_TCP<INFO>();
BLT_INFO("(%s) Bytes answered: %d with %d answers", TYPE_STR.c_str(), rq_info.number_of_bytes, rq_info.number_of_answers);
// no one actually does multiple questions. trying to do it in dig is not easy
// and the standard isn't really designed for this (how do we handle if one question errors but the other doesn't? there is only
// one return code.)
question q(reader);
std::vector<answer> answers;
for (size_t i = 0; i < rq_info.number_of_answers; i++)
{
answer a(reader);
answers.push_back(std::move(a));
}
BLT_INFO("(%s) DOMAIN: %s", TYPE_STR.c_str(), q().c_str());
if (STRICT_MATCHING && BLT_CONTAINS(DISALLOWED_DOMAINS, q()))
process_answers(answers);
else if (!STRICT_MATCHING)
{
// linear search the domains for contains. Maybe find a better way to do this.
for (const auto& v : DISALLOWED_DOMAINS)
if (blt::string::contains(q(), v))
process_answers(answers);
}
return_send_buffer.write(forward_recv_buffer.data(), info.HEADER_END);
// need to cache this value oh wait we aren't doing this anymore
auto question_offset = return_send_buffer.size();
return_send_buffer.write(q);
for (const answer& a : answers)
{
BLT_TRACE("(%s) Writing answer with type of %d", TYPE_STR.c_str(), a.type());
a.reset(question_offset);
return_send_buffer.write(a);
}
return_send_buffer.write(reader.from(), rq_info.number_of_bytes - reader.last());
}
void run_udp_server()
{
try
{
asio::io_context io_context;
udp::socket socket(io_context, toUDPEndpoint(address()));
blt::scoped_buffer<unsigned char> input_recv_buffer{PACKET_BUFFER_SIZE};
blt::scoped_buffer<unsigned char> forward_recv_buffer{PACKET_BUFFER_SIZE};
while (true)
{
udp::endpoint remote_endpoint;
size_t bytes = socket.receive_from(asio::buffer(input_recv_buffer.data(), input_recv_buffer.size()), remote_endpoint);
auto rq_info = handle_forward_request(blt::network::sendUDPMessage, DNS_UDP_INFO, input_recv_buffer, bytes, forward_recv_buffer);
send_buffer return_send_buffer;
handle_response(DNS_UDP_INFO, return_send_buffer, rq_info, forward_recv_buffer);
asio::error_code ignored_error;
socket.send_to(return_send_buffer.buffer(), remote_endpoint, 0, ignored_error);
}
}
catch (std::exception& e)
{
BLT_ERROR(e.what());
}
}
void run_tcp_server()
{
try
{
asio::io_context io_context;
tcp::acceptor acceptor(io_context, toTCPEndpoint(address()));
blt::scoped_buffer<unsigned char> input_recv_buffer{PACKET_BUFFER_SIZE};
blt::scoped_buffer<unsigned char> forward_recv_buffer{PACKET_BUFFER_SIZE};
while (true)
{
tcp::socket socket(io_context);
acceptor.accept(socket);
asio::error_code error;
size_t bytes = socket.read_some(asio::buffer(input_recv_buffer.data(), input_recv_buffer.size()), error);
if (error == asio::error::eof)
break;
else if (error)
throw asio::system_error(error);
auto rq_info = handle_forward_request(blt::network::sendTCPMessage, DNS_TCP_INFO, input_recv_buffer, bytes, forward_recv_buffer);
send_buffer return_send_buffer;
handle_response(DNS_TCP_INFO, return_send_buffer, rq_info, forward_recv_buffer);
asio::write(socket, return_send_buffer.buffer());
}
}
catch (std::exception& e)
{
BLT_ERROR(e.what());
}
}
int main()
{
BLT_INFO("Creating UDP Server");
std::thread udp_server(run_udp_server);
BLT_INFO("Creating TCP Server");
std::thread tcp_server(run_tcp_server);
BLT_INFO("Awaiting");
udp_server.join();
tcp_server.join();
return 0;
}