485 lines
16 KiB
C++
485 lines
16 KiB
C++
/**
|
|
* Brett Terpstra 6920201
|
|
* Michael Boulos 6973523
|
|
* Adrian Pinu 6970677
|
|
*/
|
|
|
|
#include <array>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <iomanip>
#include <iostream>
#include <stdexcept>
#include <string>
#include <thread>
#include <type_traits>
#include <typeinfo>
#include <unordered_set>
#include <vector>

#include <asio.hpp>

#include <blt/std/logging.h>
#include <blt/std/memory.h>
#include <blt/compatibility.h>

#include "ip.h"
#include "insane_dns/util.h"
|
|
|
|
/*
|
|
* ----------------------------
|
|
* | CONFIG |
|
|
* ----------------------------
|
|
*/
|
|
/** Port that both the UDP and the TCP listener bind to. */
static constexpr unsigned short SERVER_PORT{5555};
|
|
|
|
/**
|
|
* How should we bind to the local machine? By default this will use IPv4 and IPv6 with the default ASIO endpoint.
|
|
* Only change this if you really need to and know what you are doing.
|
|
*/
|
|
static BLT_CPP20_CONSTEXPR bind_address address() {
|
|
return {SERVER_PORT};
|
|
}
|
|
|
|
/**
 * Domain matching mode used by handle_response():
 *   true  -> block a query only when its name exactly equals an entry of DISALLOWED_DOMAINS;
 *   false -> block any query whose name contains an entry as a substring (i.e. `*entry*`).
 */
static constexpr bool STRICT_MATCHING{false};
|
|
|
|
/** DNS server to use for forwarding to / resolving DNS requests */
|
|
static inline BLT_CPP20_CONSTEXPR std::string DNS_SERVER_IP()
|
|
{
|
|
return "8.8.8.8";
|
|
}
|
|
|
|
/** replacement IP address. Make sure this is a 4 octet string seperated by `.` */
|
|
static inline BLT_CPP20_CONSTEXPR IPAddress REPLACEMENT_IP()
|
|
{
|
|
return {"139.57.100.6"};
|
|
}
|
|
|
|
/**
 * Domains whose answers are rewritten to REPLACEMENT_IP().
 *
 * With STRICT_MATCHING=false an entry matches any queried name that contains it as a substring. Parent domains are
 * therefore NOT covered: `en.wikipedia.org` matches `*en.wikipedia.org*` (e.g. `en.wikipedia.org`, `x.en.wikipedia.org`)
 * but not a plain `wikipedia.org` query.
 */
static const std::unordered_set<std::string> DISALLOWED_DOMAINS = {
        "en.wikipedia.org",
        "zombo.com"
};
|
|
|
|
/*
|
|
* -----------------------------------
|
|
* | Do Not Change |
|
|
* -----------------------------------
|
|
*/
|
|
// These features were planned but not finished, since they were not going to earn extra marks.
// Leave these values as they are; changing them will break the code.
// New features (such as TCP) are still being added and the code is still being cleaned up.
|
|
|
|
/**
 * true  -> process_answers() rewrites only A records of a blocked query;
 * false -> other record types accepted by shouldReplace() are also converted into A records
 *          (see NON_STRICT_REPLACE_ALL).
 */
static constexpr bool STRICT_FILTERING{true};
/**
 * Consulted only when STRICT_FILTERING is false:
 *   true  -> every record is replaced;
 *   false -> besides A records, only AAAA and CNAME records are replaced.
 */
static constexpr bool NON_STRICT_REPLACE_ALL{true};
|
|
|
|
// TCP support was originally planned here and has since been implemented, for full `dig` support ("can you dig it, sucka!?").
// Ad blocking was also planned, but it is not going to be implemented.
|
|
/** list of web address to download the ad block lists from */
|
|
static BLT_CPP20_CONSTEXPR std::vector<std::string> BLOCK_LISTS{};
|
|
/** true -> block ad DNS requests ; false -> do nothing */
|
|
static constexpr bool BLOCK_ADS = false;
|
|
/** true -> send back the REPLACEMENT_IP() ; false -> send back a fail state in the DNS request. */
|
|
static constexpr bool REDIRECT_ADS = true;
|
|
|
|
// 5F826B
|
|
|
|
// DNS data contains:
|
|
// 2 bytes for transaction id
|
|
// 2 bytes for flags
|
|
// 2 bytes for number of questions
|
|
// 2 bytes for Answer RRs
|
|
// 2 bytes for Authority RRs
|
|
// 2 bytes for Additional RRs
|
|
|
|
// question format:
|
|
// 1 byte for length
|
|
// (length) bytes per section
|
|
// ... until length 0
|
|
// 2 bytes for QTYPE (useless)
|
|
// 2 bytes for QCLASS (useless)
|
|
|
|
// answer format:
|
|
// 2 bytes for the domain name (a compression pointer; RFC 1035 §4.1.4 measures its offset from the start of the DNS message)
|
|
// 2 byte for type
|
|
// 2 byte for class
|
|
// 4 byte for time to live
|
|
// 2 byte for length of data
|
|
// (length) bytes of data
|
|
|
|
// Forward declaration: question and answer below name send_buffer as a friend so it can serialize their private fields.
class send_buffer;
|
|
|
|
/**
|
|
* This data structure represents a DNS question. When constructed it will read the FULL domain as a single string, along with the QTYPE and QCLASS
|
|
* It is safe to read QDCOUNT questions by constructing this class.
|
|
* The question will be reconstructed by the send_buffer class.
|
|
*/
|
|
class question
|
|
{
|
|
friend send_buffer;
|
|
private:
|
|
std::string domain;
|
|
uint16_t QTYPE = 0;
|
|
uint16_t QCLASS = 0;
|
|
public:
|
|
explicit question(const blt::byte_reader& reader)
|
|
{
|
|
// process the full question
|
|
while (true)
|
|
{
|
|
uint8_t length = reader.next();
|
|
if (length == 0)
|
|
break;
|
|
if (!domain.empty())
|
|
domain += '.';
|
|
for (uint8_t j = 0; j < length; j++)
|
|
domain += static_cast<char>(reader.next());
|
|
}
|
|
// Skip QTYPE and QCLASS
|
|
// but keep a copy of it for future writing
|
|
reader.to(QTYPE);
|
|
reader.to(QCLASS);
|
|
}
|
|
|
|
const std::string& operator()()
|
|
{
|
|
return domain;
|
|
}
|
|
};
|
|
|
|
/**
|
|
* This data structure represents a DNS answer. When constructed it will read the FULL answer along with the associated data. It is therefore safe
|
|
* to read ANCOUNT by constructing a series of answers which read from the byte stream. This class cannot be copied but can be moved.
|
|
* The answer will be rebuilt by the send_buffer class for you.
|
|
*/
|
|
class answer
|
|
{
|
|
friend send_buffer;
|
|
private:
|
|
mutable uint16_t NAME = 0;
|
|
uint16_t TYPE = 0;
|
|
uint16_t CLASS = 0;
|
|
uint32_t TTL = 0;
|
|
uint16_t RDLENGTH = 0;
|
|
bool requires_reset = false;
|
|
blt::scoped_buffer<unsigned char> RDATA;
|
|
public:
|
|
explicit answer(const blt::byte_reader& reader)
|
|
{
|
|
reader.to(NAME);
|
|
reader.to(TYPE);
|
|
reader.to(CLASS);
|
|
reader.to(TTL);
|
|
reader.to(RDLENGTH);
|
|
RDATA = blt::scoped_buffer<unsigned char>(RDLENGTH);
|
|
reader.copy(RDATA.data(), RDLENGTH);
|
|
}
|
|
|
|
[[nodiscard]] uint16_t type() const
|
|
{
|
|
return TYPE;
|
|
}
|
|
|
|
inline void substitute(const IPAddress& addr)
|
|
{
|
|
BLT_DEBUG("Substituting with replacement address '%s'", REPLACEMENT_IP().asString.c_str());
|
|
BLT_ASSERT(RDLENGTH == 4);
|
|
std::memcpy(RDATA.data(), addr.octets, 4);
|
|
}
|
|
|
|
inline void setARecord(const IPAddress& addr)
|
|
{
|
|
BLT_DEBUG("Setting answer to A record");
|
|
NAME = 0;
|
|
NAME |= (0b11 << 14);
|
|
requires_reset = true;
|
|
BLT_TRACE(NAME);
|
|
RDATA = blt::scoped_buffer<unsigned char>(4);
|
|
RDLENGTH = 4;
|
|
TYPE = 1;
|
|
CLASS = 1;
|
|
substitute(addr);
|
|
}
|
|
|
|
inline void reset(size_t offset) const
|
|
{
|
|
if (!requires_reset)
|
|
return;
|
|
// like I said not 100 on how to construct the ptr
|
|
// seems to be causing issues. I've stopped working on this as it's not required.
|
|
auto i16 = static_cast<uint16_t>(offset) & (~(0b11 << 14));
|
|
NAME |= i16;
|
|
}
|
|
|
|
// rule of 5
|
|
// (there used to be a destructor)
|
|
answer(const answer& copy) = delete;
|
|
|
|
answer& operator=(const answer& copy) = delete;
|
|
|
|
answer(answer&& move) noexcept
|
|
{
|
|
NAME = move.NAME;
|
|
TYPE = move.TYPE;
|
|
CLASS = move.CLASS;
|
|
TTL = move.TTL;
|
|
RDLENGTH = move.RDLENGTH;
|
|
RDATA = std::move(move.RDATA);
|
|
}
|
|
|
|
answer& operator=(answer&& move) noexcept
|
|
{
|
|
NAME = 0;
|
|
NAME = move.NAME;
|
|
TYPE = move.TYPE;
|
|
CLASS = move.CLASS;
|
|
TTL = move.TTL;
|
|
RDLENGTH = move.RDLENGTH;
|
|
RDATA = std::move(move.RDATA);
|
|
return *this;
|
|
}
|
|
|
|
// there used to be a destructor
|
|
~answer() = default;
|
|
};
|
|
|
|
class send_buffer
|
|
{
|
|
private:
|
|
mutable std::array<unsigned char, 65535> internal_data{};
|
|
mutable size_t write_index = 0;
|
|
public:
|
|
send_buffer() = default;
|
|
|
|
void write(unsigned char* data, size_t size) const
|
|
{
|
|
std::memcpy(&internal_data[write_index], data, size);
|
|
write_index += size;
|
|
}
|
|
|
|
template<typename T>
|
|
void write(const T& t) const
|
|
{
|
|
if constexpr (std::is_arithmetic_v<T>)
|
|
{
|
|
blt::mem::toBytes(t, &internal_data[write_index]);
|
|
write_index += sizeof(T);
|
|
} else if constexpr (std::is_same_v<T, answer>)
|
|
{
|
|
write(t.NAME);
|
|
write(t.TYPE);
|
|
write(t.CLASS);
|
|
write(t.TTL);
|
|
write(t.RDLENGTH);
|
|
std::memcpy(&internal_data[write_index], t.RDATA.data(), t.RDLENGTH);
|
|
write_index += t.RDLENGTH;
|
|
} else if constexpr (std::is_same_v<T, question>)
|
|
{
|
|
auto labels = blt::string::split(t.domain, '.');
|
|
for (const auto& label : labels)
|
|
{
|
|
auto length = static_cast<uint8_t>(label.length());
|
|
write(length);
|
|
std::memcpy(&internal_data[write_index], &label[0], length);
|
|
write_index += length;
|
|
}
|
|
// write length of 0 to signal end of labels
|
|
write('\0');
|
|
write(t.QTYPE);
|
|
write(t.QCLASS);
|
|
} else
|
|
static_assert("Data type not supported!");
|
|
}
|
|
|
|
void reset() const
|
|
{
|
|
write_index = 0;
|
|
}
|
|
|
|
[[nodiscard]] unsigned char operator[](size_t i)
|
|
{
|
|
return internal_data[i];
|
|
}
|
|
|
|
[[nodiscard]] unsigned char* data()
|
|
{
|
|
return internal_data.data();
|
|
}
|
|
|
|
[[nodiscard]] size_t size() const
|
|
{
|
|
return write_index;
|
|
}
|
|
|
|
auto buffer() const
|
|
{
|
|
return asio::buffer(internal_data.data(), size());
|
|
}
|
|
};
|
|
|
|
inline bool shouldReplace(const answer& a)
|
|
{
|
|
// a records will be handled in either case, check for others like AAAA or CNAME
|
|
// TODO: add enums to this + a way to add custom types
|
|
return NON_STRICT_REPLACE_ALL || a.type() == 28 || a.type() == 5;
|
|
}
|
|
|
|
void process_answers(std::vector<answer>& answers)
|
|
{
|
|
for (auto& a : answers)
|
|
{
|
|
if (a.type() == 1)
|
|
{
|
|
a.substitute(REPLACEMENT_IP());
|
|
} else if (!STRICT_FILTERING && shouldReplace(a))
|
|
{
|
|
a.setARecord(REPLACEMENT_IP());
|
|
}
|
|
}
|
|
}
|
|
|
|
template<typename T, typename INFO, typename MESSENGER>
|
|
request_info handle_forward_request(MESSENGER messenger, const INFO& info, const T& input_recv_buffer, size_t bytes, T& forward_recv_buffer)
|
|
{
|
|
// get the number of questions
|
|
uint16_t questions; // yes I made this part of my library just for this :3
|
|
blt::mem::fromBytes(&input_recv_buffer[info.QUESTIONS_BEGIN], questions); // i hate little endian
|
|
|
|
BLT_INFO("(%s) Bytes received: %d with %d questions", blt::is_UDP_or_TCP<INFO>().c_str(), bytes, questions);
|
|
|
|
// forward to google.
|
|
size_t out_bytes;
|
|
const auto* data = input_recv_buffer.data();
|
|
messenger(DNS_SERVER_IP(), data, bytes, forward_recv_buffer, out_bytes);
|
|
|
|
uint16_t num_of_answers;
|
|
blt::mem::fromBytes(&forward_recv_buffer[info.ANSWERS_BEGIN], num_of_answers);
|
|
|
|
return {out_bytes, num_of_answers};
|
|
}
|
|
|
|
template<typename T, typename INFO>
|
|
void handle_response(const INFO& info, send_buffer& return_send_buffer, request_info rq_info, T& forward_recv_buffer)
|
|
{
|
|
blt::byte_reader reader(forward_recv_buffer.data(), forward_recv_buffer.size(), info.HEADER_END);
|
|
|
|
auto TYPE_STR = blt::is_UDP_or_TCP<INFO>();
|
|
BLT_INFO("(%s) Bytes answered: %d with %d answers", TYPE_STR.c_str(), rq_info.number_of_bytes, rq_info.number_of_answers);
|
|
|
|
// no one actually does multiple questions. trying to do it in dig is not easy
|
|
// and the standard isn't really designed for this (how do we handle if one question errors but the other doesn't? there is only
|
|
// one return code.)
|
|
question q(reader);
|
|
std::vector<answer> answers;
|
|
for (size_t i = 0; i < rq_info.number_of_answers; i++)
|
|
{
|
|
answer a(reader);
|
|
answers.push_back(std::move(a));
|
|
}
|
|
|
|
BLT_INFO("(%s) DOMAIN: %s", TYPE_STR.c_str(), q().c_str());
|
|
if (STRICT_MATCHING && BLT_CONTAINS(DISALLOWED_DOMAINS, q()))
|
|
process_answers(answers);
|
|
else if (!STRICT_MATCHING)
|
|
{
|
|
// linear search the domains for contains. Maybe find a better way to do this.
|
|
for (const auto& v : DISALLOWED_DOMAINS)
|
|
if (blt::string::contains(q(), v))
|
|
process_answers(answers);
|
|
}
|
|
|
|
return_send_buffer.write(forward_recv_buffer.data(), info.HEADER_END);
|
|
// need to cache this value oh wait we aren't doing this anymore
|
|
auto question_offset = return_send_buffer.size();
|
|
return_send_buffer.write(q);
|
|
for (const answer& a : answers)
|
|
{
|
|
BLT_TRACE("(%s) Writing answer with type of %d", TYPE_STR.c_str(), a.type());
|
|
a.reset(question_offset);
|
|
return_send_buffer.write(a);
|
|
}
|
|
return_send_buffer.write(reader.from(), rq_info.number_of_bytes - reader.last());
|
|
}
|
|
|
|
void run_udp_server()
|
|
{
|
|
try
|
|
{
|
|
asio::io_context io_context;
|
|
|
|
udp::socket socket(io_context, toUDPEndpoint(address()));
|
|
|
|
blt::scoped_buffer<unsigned char> input_recv_buffer{PACKET_BUFFER_SIZE};
|
|
blt::scoped_buffer<unsigned char> forward_recv_buffer{PACKET_BUFFER_SIZE};
|
|
while (true)
|
|
{
|
|
udp::endpoint remote_endpoint;
|
|
size_t bytes = socket.receive_from(asio::buffer(input_recv_buffer.data(), input_recv_buffer.size()), remote_endpoint);
|
|
|
|
auto rq_info = handle_forward_request(blt::network::sendUDPMessage, DNS_UDP_INFO, input_recv_buffer, bytes, forward_recv_buffer);
|
|
|
|
send_buffer return_send_buffer;
|
|
handle_response(DNS_UDP_INFO, return_send_buffer, rq_info, forward_recv_buffer);
|
|
|
|
asio::error_code ignored_error;
|
|
socket.send_to(return_send_buffer.buffer(), remote_endpoint, 0, ignored_error);
|
|
}
|
|
}
|
|
catch (std::exception& e)
|
|
{
|
|
BLT_ERROR(e.what());
|
|
}
|
|
}
|
|
|
|
void run_tcp_server()
|
|
{
|
|
try
|
|
{
|
|
asio::io_context io_context;
|
|
|
|
tcp::acceptor acceptor(io_context, toTCPEndpoint(address()));
|
|
|
|
blt::scoped_buffer<unsigned char> input_recv_buffer{PACKET_BUFFER_SIZE};
|
|
blt::scoped_buffer<unsigned char> forward_recv_buffer{PACKET_BUFFER_SIZE};
|
|
while (true)
|
|
{
|
|
tcp::socket socket(io_context);
|
|
acceptor.accept(socket);
|
|
|
|
asio::error_code error;
|
|
size_t bytes = socket.read_some(asio::buffer(input_recv_buffer.data(), input_recv_buffer.size()), error);
|
|
if (error == asio::error::eof)
|
|
break;
|
|
else if (error)
|
|
throw asio::system_error(error);
|
|
|
|
auto rq_info = handle_forward_request(blt::network::sendTCPMessage, DNS_TCP_INFO, input_recv_buffer, bytes, forward_recv_buffer);
|
|
|
|
send_buffer return_send_buffer;
|
|
handle_response(DNS_TCP_INFO, return_send_buffer, rq_info, forward_recv_buffer);
|
|
|
|
asio::write(socket, return_send_buffer.buffer());
|
|
}
|
|
}
|
|
catch (std::exception& e)
|
|
{
|
|
BLT_ERROR(e.what());
|
|
}
|
|
}
|
|
|
|
int main()
|
|
{
|
|
BLT_INFO("Creating UDP Server");
|
|
std::thread udp_server(run_udp_server);
|
|
BLT_INFO("Creating TCP Server");
|
|
std::thread tcp_server(run_tcp_server);
|
|
|
|
BLT_INFO("Awaiting");
|
|
udp_server.join();
|
|
tcp_server.join();
|
|
return 0;
|
|
}
|