/**
 * Brett    Terpstra    6920201
 * Michael  Boulos      6973523
 * Adrian   Pinu        6970677
 */

#include <algorithm>
#include <array>
#include <cstring>
#include <iomanip>
#include <iostream>
#include <stdexcept>
#include <string>
#include <thread>
#include <type_traits>
#include <typeinfo>
#include <unordered_set>
#include <vector>

#include <asio.hpp>
#include <blt/std/logging.h>
#include <blt/std/memory.h>
#include <blt/compatibility.h>
#include "ip.h"
#include "insane_dns/util.h"

/*
 * ----------------------------
 *          | CONFIG |
 * ----------------------------
 */
/** Port number that both the UDP and the TCP listener bind to (see address()). */
static constexpr unsigned short int SERVER_PORT{5555};

/**
 * Local bind description shared by the UDP and TCP servers; only SERVER_PORT is supplied here.
 * How the address family / interface is chosen is up to bind_address (ip.h). According to the original note,
 * this default listens on IPv4 and IPv6 through ASIO's default endpoint.
 * Only change this if you really need to and know what you are doing.
 */
static BLT_CPP20_CONSTEXPR bind_address address() {
    return bind_address{SERVER_PORT};
}

/**
 * Domain matching mode for DISALLOWED_DOMAINS.
 * true  -> a query is blocked only if its name equals an entry exactly (`wikipedia.org`).
 * false -> a query is blocked if its name contains an entry anywhere (`*wikipedia.org*`).
 */
static constexpr bool STRICT_MATCHING{false};

/** Upstream resolver (IPv4 address as text) that every incoming query is forwarded to. */
static inline BLT_CPP20_CONSTEXPR std::string DNS_SERVER_IP()
{
    return std::string{"8.8.8.8"};
}

/**
 * Address written into the answers of blocked domains.
 * Must be a dotted-quad IPv4 string: four octets separated by `.`.
 */
static inline BLT_CPP20_CONSTEXPR IPAddress REPLACEMENT_IP()
{
    return IPAddress{"139.57.100.6"};
}

/**
 * Domains whose answers get rewritten to REPLACEMENT_IP().
 * With STRICT_MATCHING == true a query must equal an entry exactly.
 * With STRICT_MATCHING == false a query is blocked when its name contains an entry as a substring, so
 * "en.wikipedia.org" also catches e.g. "m.en.wikipedia.org". It does NOT catch the parent domain
 * "wikipedia.org"; list that explicitly if you want it blocked.
 */
static const std::unordered_set<std::string> DISALLOWED_DOMAINS = {
        "en.wikipedia.org",
        "zombo.com",
};

/*
 * -----------------------------------
 *          | Do Not Change |
 * -----------------------------------
 */
// These features were planned but never finished (I realized they wouldn't earn extra marks, which broke my obsession with them),
// so don't change these values, otherwise the code will break :3
// As it is, I'm still adding new features (TCP) and cleaning up the code.

/**
 * Which records of a blocked domain are rewritten.
 * true  -> only A records get their address replaced.
 * false -> other record types are additionally converted into A records, as selected by NON_STRICT_REPLACE_ALL.
 */
static constexpr bool STRICT_FILTERING{true};
/**
 * Only consulted when STRICT_FILTERING is false.
 * true  -> convert every non-A record.
 * false -> convert only AAAA (28) and CNAME (5) records.
 */
static constexpr bool NON_STRICT_REPLACE_ALL{true};

// Was going to add ~~TCP~~ (now implemented, for full `dig` support: "can you dig it, sucka!?") and ad-blocking support.
// Ad blocking isn't going to happen now, so the settings below are unused placeholders.
/** Planned feature (unused): URLs of ad-block lists to download. */
static BLT_CPP20_CONSTEXPR std::vector<std::string> BLOCK_LISTS{};
/** Planned feature (unused): true -> block ad DNS requests ; false -> do nothing. */
static constexpr bool BLOCK_ADS{false};
/** Planned feature (unused): true -> answer ad requests with REPLACEMENT_IP() ; false -> answer with a DNS failure status. */
static constexpr bool REDIRECT_ADS{true};

// 5F826B

// DNS data contains:
// 2 bytes for transaction id
// 2 bytes for flags
// 2 bytes for number of questions
// 2 bytes for Answer RRs
// 2 bytes for Authority RRs
// 2 bytes for Additional RRs

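// question format:
// 1 byte for the label length
// (length) bytes of label data
// ... repeated until a length of 0
// 2 bytes for QTYPE (not inspected, but copied back into the reply)
// 2 bytes for QCLASS (not inspected, but copied back into the reply)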

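// answer format:
// 2 bytes for the domain name (this code assumes a compression pointer: top two bits set, low 14 bits are an
//   offset which RFC 1035 §4.1.4 counts from the start of the DNS message, i.e. the first byte of the ID field)
// 2 bytes for type
// 2 bytes for class
// 4 bytes for time to live
// 2 bytes for length of data
// (length) bytes for data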

// Forward declaration: `friend send_buffer;` in question and answer requires the name to be declared first.
class send_buffer;

/**
 * A single DNS question. Construction consumes the whole question from the reader:
 * - the domain is stored as one string, with labels joined by '.';
 * - QTYPE and QCLASS are kept so that send_buffer can write the question back out unchanged.
 * Reading QDCOUNT questions is therefore done by constructing QDCOUNT instances in a row.
 *
 * NOTE(review): label bytes are appended verbatim. A label containing '.' would not survive the split-on-'.'
 * round trip in send_buffer. A length byte with the top two bits set (a compression pointer) is also not
 * handled; upstream replies are assumed to keep the question uncompressed.
 */
class question
{
        friend send_buffer;
    private:
        std::string domain;
        uint16_t QTYPE = 0;
        uint16_t QCLASS = 0;
    public:
        explicit question(const blt::byte_reader& reader)
        {
            // QNAME: length-prefixed labels terminated by a zero length byte
            while (true)
            {
                uint8_t length = reader.next();
                if (length == 0)
                    break;
                if (!domain.empty())
                    domain += '.';
                for (uint8_t j = 0; j < length; j++)
                    domain += static_cast<char>(reader.next());
            }
            // Not interpreted, but kept so the question can be re-serialised verbatim.
            reader.to(QTYPE);
            reader.to(QCLASS);
        }
        
        /** @return the queried domain, labels joined with '.' (empty for the root). */
        [[nodiscard]] const std::string& operator()() const
        {
            return domain;
        }
};

/**
 * This data structure represents a DNS answer. When constructed it will read the FULL answer along with the associated data. It is therefore safe
 * to read ANCOUNT by constructing a series of answers which read from the byte stream. This class cannot be copied but can be moved.
 * The answer will be rebuilt by the send_buffer class for you.
 */
class answer
{
        friend send_buffer;
    private:
        mutable uint16_t NAME = 0;
        uint16_t TYPE = 0;
        uint16_t CLASS = 0;
        uint32_t TTL = 0;
        uint16_t RDLENGTH = 0;
        bool requires_reset = false;
        blt::scoped_buffer<unsigned char> RDATA;
    public:
        explicit answer(const blt::byte_reader& reader)
        {
            reader.to(NAME);
            reader.to(TYPE);
            reader.to(CLASS);
            reader.to(TTL);
            reader.to(RDLENGTH);
            RDATA = blt::scoped_buffer<unsigned char>(RDLENGTH);
            reader.copy(RDATA.data(), RDLENGTH);
        }
        
        [[nodiscard]] uint16_t type() const
        {
            return TYPE;
        }
        
        inline void substitute(const IPAddress& addr)
        {
            BLT_DEBUG("Substituting with replacement address '%s'", REPLACEMENT_IP().asString.c_str());
            BLT_ASSERT(RDLENGTH == 4);
            std::memcpy(RDATA.data(), addr.octets, 4);
        }
        
        inline void setARecord(const IPAddress& addr)
        {
            BLT_DEBUG("Setting answer to A record");
            NAME = 0;
            NAME |= (0b11 << 14);
            requires_reset = true;
            BLT_TRACE(NAME);
            RDATA = blt::scoped_buffer<unsigned char>(4);
            RDLENGTH = 4;
            TYPE = 1;
            CLASS = 1;
            substitute(addr);
        }
        
        inline void reset(size_t offset) const
        {
            if (!requires_reset)
                return;
            // like I said not 100 on how to construct the ptr
            // seems to be causing issues. I've stopped working on this as it's not required.
            auto i16 = static_cast<uint16_t>(offset) & (~(0b11 << 14));
            NAME |= i16;
        }
        
        // rule of 5
        // (there used to be a destructor)
        answer(const answer& copy) = delete;
        
        answer& operator=(const answer& copy) = delete;
        
        answer(answer&& move) noexcept
        {
            NAME = move.NAME;
            TYPE = move.TYPE;
            CLASS = move.CLASS;
            TTL = move.TTL;
            RDLENGTH = move.RDLENGTH;
            RDATA = std::move(move.RDATA);
        }
        
        answer& operator=(answer&& move) noexcept
        {
            NAME = 0;
            NAME = move.NAME;
            TYPE = move.TYPE;
            CLASS = move.CLASS;
            TTL = move.TTL;
            RDLENGTH = move.RDLENGTH;
            RDATA = std::move(move.RDATA);
            return *this;
        }
        
        // there used to be a destructor
        ~answer() = default;
};

/**
 * Fixed 65535-byte buffer that the outgoing DNS reply is rebuilt into, written front to back.
 * The write()/reset() members are const and mutate the mutable storage and cursor.
 * Multi-byte integers are written with blt::mem::toBytes (presumably network byte order — see blt/std/memory.h).
 * A write that would exceed the buffer throws std::length_error instead of overrunning the array.
 */
class send_buffer
{
    private:
        mutable std::array<unsigned char, 65535> internal_data{};
        mutable size_t write_index = 0;
        
        // Throws if `count` more bytes would not fit behind write_index.
        void ensure_capacity(size_t count) const
        {
            if (count > internal_data.size() - write_index)
                throw std::length_error("send_buffer overflow: DNS reply larger than 65535 bytes");
        }
    public:
        send_buffer() = default;
        
        /** Appends `size` raw bytes from `data`. */
        void write(const unsigned char* data, size_t size) const
        {
            ensure_capacity(size);
            std::memcpy(internal_data.data() + write_index, data, size);
            write_index += size;
        }
        
        /**
         * Serialises `t`. Supported types:
         * - arithmetic types (sizeof(T) bytes);
         * - answer (NAME, TYPE, CLASS, TTL, RDLENGTH, RDATA);
         * - question (length-prefixed labels, zero terminator, QTYPE, QCLASS).
         * Any other type is rejected at compile time.
         */
        template<typename T>
        void write(const T& t) const
        {
            if constexpr (std::is_arithmetic_v<T>)
            {
                ensure_capacity(sizeof(T));
                blt::mem::toBytes(t, internal_data.data() + write_index);
                write_index += sizeof(T);
            } else if constexpr (std::is_same_v<T, answer>)
            {
                write(t.NAME);
                write(t.TYPE);
                write(t.CLASS);
                write(t.TTL);
                write(t.RDLENGTH);
                write(t.RDATA.data(), t.RDLENGTH);
            } else if constexpr (std::is_same_v<T, question>)
            {
                for (const auto& label : blt::string::split(t.domain, '.'))
                {
                    // An empty domain (the root) is encoded by the terminator alone; never emit a zero-length label
                    // here, or the name would end early and the terminator would corrupt QTYPE.
                    if (label.empty())
                        continue;
                    auto length = static_cast<uint8_t>(label.length());
                    write(length);
                    write(reinterpret_cast<const unsigned char*>(label.data()), length);
                }
                // write length of 0 to signal end of labels
                write('\0');
                write(t.QTYPE);
                write(t.QCLASS);
            } else
                // A plain string literal is always truthy; the condition must depend on T to ever fire.
                static_assert(!std::is_same_v<T, T>, "Data type not supported!");
        }
        
        /** Rewinds the cursor so the buffer can be reused; old bytes are simply overwritten. */
        void reset() const
        {
            write_index = 0;
        }
        
        [[nodiscard]] unsigned char operator[](size_t i) const
        {
            return internal_data[i];
        }
        
        [[nodiscard]] unsigned char* data()
        {
            return internal_data.data();
        }
        
        /** @return number of bytes written so far. */
        [[nodiscard]] size_t size() const
        {
            return write_index;
        }
        
        /** @return an asio buffer covering exactly the bytes written so far. */
        auto buffer() const
        {
            return asio::buffer(internal_data.data(), size());
        }
};

/**
 * Decides whether a non-A record of a blocked domain should be converted into an A record.
 * A records are always handled by the caller. Beyond those, everything qualifies when NON_STRICT_REPLACE_ALL
 * is set; otherwise only AAAA and CNAME records do.
 */
inline bool shouldReplace(const answer& a)
{
    // TODO: add enums to this + a way to add custom types
    if (NON_STRICT_REPLACE_ALL)
        return true;
    constexpr uint16_t AAAA_RECORD = 28;
    constexpr uint16_t CNAME_RECORD = 5;
    const uint16_t record_type = a.type();
    return record_type == AAAA_RECORD || record_type == CNAME_RECORD;
}

void process_answers(std::vector<answer>& answers)
{
    for (auto& a : answers)
    {
        if (a.type() == 1)
        {
            a.substitute(REPLACEMENT_IP());
        } else if (!STRICT_FILTERING && shouldReplace(a))
        {
            a.setARecord(REPLACEMENT_IP());
        }
    }
}

/**
 * Forwards a raw client query to the upstream resolver (DNS_SERVER_IP()) and reports what came back.
 *
 * @param messenger           transport function, called as messenger(ip, data, size, reply_buffer, reply_bytes)
 *                            (blt::network::sendUDPMessage / sendTCPMessage at the call sites)
 * @param info                transport layout: QUESTIONS_BEGIN / ANSWERS_BEGIN are the QDCOUNT / ANCOUNT positions
 * @param input_recv_buffer   the client's query bytes
 * @param bytes               number of valid bytes in input_recv_buffer
 * @param forward_recv_buffer receives the upstream reply
 * @return {reply size in bytes, ANCOUNT read from the reply}
 */
template<typename T, typename INFO, typename MESSENGER>
request_info handle_forward_request(MESSENGER messenger, const INFO& info, const T& input_recv_buffer, size_t bytes, T& forward_recv_buffer)
{
    // QDCOUNT of the client's query (only logged). fromBytes presumably handles the network byte order.
    uint16_t questions = 0;
    blt::mem::fromBytes(&input_recv_buffer[info.QUESTIONS_BEGIN], questions);
    
    // The format uses %d, so pass ints explicitly: a size_t through printf-style varargs is undefined behaviour.
    BLT_INFO("(%s) Bytes received: %d with %d questions", blt::is_UDP_or_TCP<INFO>().c_str(), static_cast<int>(bytes),
             static_cast<int>(questions));
    
    // Forward to the upstream resolver. Initialised in case the messenger returns without setting it.
    size_t out_bytes = 0;
    const auto* data = input_recv_buffer.data();
    messenger(DNS_SERVER_IP(), data, bytes, forward_recv_buffer, out_bytes);
    
    uint16_t num_of_answers = 0;
    blt::mem::fromBytes(&forward_recv_buffer[info.ANSWERS_BEGIN], num_of_answers);
    
    return {out_bytes, num_of_answers};
}

template<typename T, typename INFO>
void handle_response(const INFO& info, send_buffer& return_send_buffer, request_info rq_info, T& forward_recv_buffer)
{
    blt::byte_reader reader(forward_recv_buffer.data(), forward_recv_buffer.size(), info.HEADER_END);
    
    auto TYPE_STR = blt::is_UDP_or_TCP<INFO>();
    BLT_INFO("(%s) Bytes answered: %d with %d answers", TYPE_STR.c_str(), rq_info.number_of_bytes, rq_info.number_of_answers);
    
    // no one actually does multiple questions. trying to do it in dig is not easy
    // and the standard isn't really designed for this (how do we handle if one question errors but the other doesn't? there is only
    // one return code.)
    question q(reader);
    std::vector<answer> answers;
    for (size_t i = 0; i < rq_info.number_of_answers; i++)
    {
        answer a(reader);
        answers.push_back(std::move(a));
    }
    
    BLT_INFO("(%s) DOMAIN: %s", TYPE_STR.c_str(), q().c_str());
    if (STRICT_MATCHING && BLT_CONTAINS(DISALLOWED_DOMAINS, q()))
        process_answers(answers);
    else if (!STRICT_MATCHING)
    {
        // linear search the domains for contains. Maybe find a better way to do this.
        for (const auto& v : DISALLOWED_DOMAINS)
            if (blt::string::contains(q(), v))
                process_answers(answers);
    }
    
    return_send_buffer.write(forward_recv_buffer.data(), info.HEADER_END);
    // need to cache this value oh wait we aren't doing this anymore
    auto question_offset = return_send_buffer.size();
    return_send_buffer.write(q);
    for (const answer& a : answers)
    {
        BLT_TRACE("(%s) Writing answer with type of %d", TYPE_STR.c_str(), a.type());
        a.reset(question_offset);
        return_send_buffer.write(a);
    }
    return_send_buffer.write(reader.from(), rq_info.number_of_bytes - reader.last());
}

/**
 * UDP front end. The loop runs forever:
 *   1. receive a datagram;
 *   2. forward it upstream;
 *   3. rewrite the reply if the domain is blocked;
 *   4. send the reply back to the client.
 * A failure while handling one request is logged and that request is dropped. Only socket setup/receive
 * failures end the loop.
 */
void run_udp_server()
{
    try
    {
        asio::io_context io_context;
        
        udp::socket socket(io_context, toUDPEndpoint(address()));
        
        blt::scoped_buffer<unsigned char> input_recv_buffer{PACKET_BUFFER_SIZE};
        blt::scoped_buffer<unsigned char> forward_recv_buffer{PACKET_BUFFER_SIZE};
        // Reused for every reply: reset() rewinds it instead of zero-filling a fresh 64 KiB array per packet.
        send_buffer return_send_buffer;
        while (true)
        {
            udp::endpoint remote_endpoint;
            size_t bytes = socket.receive_from(asio::buffer(input_recv_buffer.data(), input_recv_buffer.size()), remote_endpoint);
            
            try
            {
                return_send_buffer.reset();
                auto rq_info = handle_forward_request(blt::network::sendUDPMessage, DNS_UDP_INFO, input_recv_buffer, bytes, forward_recv_buffer);
                handle_response(DNS_UDP_INFO, return_send_buffer, rq_info, forward_recv_buffer);
            }
            catch (const std::exception& e)
            {
                // One malformed query or upstream failure must not take the whole server down.
                BLT_ERROR("(UDP) Failed to handle request: %s", e.what());
                continue;
            }
            
            // Best effort: a failed send to one client is deliberately ignored.
            asio::error_code ignored_error;
            socket.send_to(return_send_buffer.buffer(), remote_endpoint, 0, ignored_error);
        }
    }
    catch (const std::exception& e)
    {
        // Pass the message as an argument, never as the format string.
        BLT_ERROR("%s", e.what());
    }
}

void run_tcp_server()
{
    try
    {
        asio::io_context io_context;
        
        tcp::acceptor acceptor(io_context, toTCPEndpoint(address()));
        
        blt::scoped_buffer<unsigned char> input_recv_buffer{PACKET_BUFFER_SIZE};
        blt::scoped_buffer<unsigned char> forward_recv_buffer{PACKET_BUFFER_SIZE};
        while (true)
        {
            tcp::socket socket(io_context);
            acceptor.accept(socket);
            
            asio::error_code error;
            size_t bytes = socket.read_some(asio::buffer(input_recv_buffer.data(), input_recv_buffer.size()), error);
            if (error == asio::error::eof)
                break;
            else if (error)
                throw asio::system_error(error);
            
            auto rq_info = handle_forward_request(blt::network::sendTCPMessage, DNS_TCP_INFO, input_recv_buffer, bytes, forward_recv_buffer);
            
            send_buffer return_send_buffer;
            handle_response(DNS_TCP_INFO, return_send_buffer, rq_info, forward_recv_buffer);
            
            asio::write(socket, return_send_buffer.buffer());
        }
    }
    catch (std::exception& e)
    {
        BLT_ERROR(e.what());
    }
}

int main()
{
    BLT_INFO("Creating UDP Server");
    std::thread udp_server(run_udp_server);
    BLT_INFO("Creating TCP Server");
    std::thread tcp_server(run_tcp_server);
    
    BLT_INFO("Awaiting");
    udp_server.join();
    tcp_server.join();
    return 0;
}