diff --git a/include/data_structs.h b/include/data_structs.h new file mode 100644 index 0000000..e8dfcfd --- /dev/null +++ b/include/data_structs.h @@ -0,0 +1,208 @@ +/* + * + * Copyright (C) 2024 Brett Terpstra + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +#ifndef DISCORD_BOT_DATA_STRUCTS_H +#define DISCORD_BOT_DATA_STRUCTS_H + +#include +#include + +namespace db +{ + struct server_info_t + { + blt::u32 member_count; + std::string name; + std::string description; + std::string icon; + std::string splash; + std::string discovery_splash; + std::string banner; + }; + + struct user_info_t + { + blt::u64 userID; + std::string username; + std::string global_nickname; + std::string server_name; + }; + + struct user_history_t + { + blt::u64 userID; + blt::u64 time_changed; + std::string old_username; + std::string old_global_nickname; + std::string old_server_name; + }; + + struct channel_info_t + { + blt::u64 channelID; + std::string channel_name; + }; + + struct channel_history_t + { + blt::u64 channelID; + blt::u64 time_changed; + std::string old_channel_name; + }; + + struct message_t + { + blt::u64 messageID; + blt::u64 channelID; + blt::u64 userID; + std::string content; + }; + + struct attachment_t + { + blt::u64 messageID; + std::string url; + }; + + struct message_edits_t + { + blt::u64 messageID; + std::string old_content; + std::string new_content; + }; + + struct message_deletes_t + { + blt::u64 messageID; + blt::u64 channelID; + std::string content; + }; + + auto make_user_table() + { + using namespace sqlite_orm; + return make_table("users", + make_column("userID", &user_info_t::userID, primary_key()), + make_column("username", &user_info_t::username), + make_column("global_nickname", &user_info_t::global_nickname), + make_column("server_name", &user_info_t::server_name)); + } + + using user_table_t = decltype(make_user_table()); + + auto make_user_history_table() + { + using namespace sqlite_orm; + return make_table("user_history", + make_column("userID", &user_history_t::userID), + make_column("time_changed", &user_history_t::time_changed), + make_column("old_username", &user_history_t::old_username), + make_column("old_global_nickname", &user_history_t::old_global_nickname), + make_column("old_server_name", &user_history_t::old_server_name), + foreign_key(&user_history_t::userID).references(&user_info_t::userID), + primary_key(&user_history_t::userID, &user_history_t::time_changed)); + } + + using user_history_table_t = decltype(make_user_history_table()); + + auto make_channel_table() + { + using namespace sqlite_orm; + return make_table("channels", + make_column("channelID", &channel_info_t::channelID, primary_key()), + make_column("channel_name", &channel_info_t::channel_name)); + } + + using channel_table_t = decltype(make_channel_table()); + + auto make_channel_history_table() + { + using namespace sqlite_orm; + return make_table("channel_history", + make_column("channelID", &channel_history_t::channelID), + make_column("time_changed", &channel_history_t::time_changed), + make_column("old_channel_name", &channel_history_t::old_channel_name), + foreign_key(&channel_history_t::channelID).references(&channel_info_t::channelID), + primary_key(&channel_history_t::channelID, &channel_history_t::time_changed)); + } + + using channel_history_table_t = decltype(make_channel_history_table()); + + auto make_message_table() + { + return make_table("messages", + make_column("messageID", &message_t::messageID, primary_key()), + make_column("channelID", &message_t::channelID), + make_column("userID", &message_t::userID), + make_column("content", &message_t::content), + foreign_key(&message_t::channelID).references(&channel_info_t::channelID), + foreign_key(&message_t::userID).references(&user_info_t::userID)); + } + + using message_table_t = decltype(make_message_table()); + + auto make_attachment_table() + { + using namespace sqlite_orm; + return make_table("attachments", + make_column("messageID", &attachment_t::messageID), + make_column("url", &attachment_t::url), + foreign_key(&attachment_t::messageID).references(&message_t::messageID), + primary_key(&attachment_t::messageID, &attachment_t::url)); + } + + using attachment_table_t = decltype(make_attachment_table()); + + auto make_message_edits_table() + { + using namespace sqlite_orm; + return make_table("message_edits", + make_column("messageID", &message_edits_t::messageID), + make_column("old_content", &message_edits_t::old_content), + make_column("new_content", &message_edits_t::new_content), + foreign_key(&message_edits_t::messageID).references(&message_t::messageID), + primary_key(&message_edits_t::messageID, &message_edits_t::old_content, &message_edits_t::new_content)); + } + + using message_edits_table_t = decltype(make_message_edits_table()); + + auto make_message_deletes_table() + { + using namespace sqlite_orm; + return make_table("message_deletes", + make_column("messageID", &message_deletes_t::messageID), + make_column("channelID", &message_deletes_t::channelID), + make_column("content", &message_deletes_t::content), + foreign_key(&message_deletes_t::messageID).references(&message_t::messageID), + foreign_key(&message_deletes_t::channelID).references(&channel_info_t::channelID), + primary_key(&message_deletes_t::messageID, &message_deletes_t::channelID)); + } + + using message_deletes_table_t = decltype(make_message_deletes_table()); + + auto make_database(std::string path) + { + using namespace sqlite_orm; + return make_storage(std::move(path), make_user_table(), make_user_history_table(), make_channel_table(), make_channel_history_table(), + make_message_table(), make_attachment_table(), make_message_edits_table(), make_message_deletes_table()); + } + + using database_type = decltype(make_database("")); +} + +#endif //DISCORD_BOT_DATA_STRUCTS_H diff --git a/libs/blt b/libs/blt index 9ad6521..9b4d0cc 160000 --- a/libs/blt +++ b/libs/blt @@ -1 +1 @@ -Subproject commit 9ad652195b0a69f9977d313eff4dd01a7890f1df +Subproject commit 9b4d0cc9a8493c608ab0075ab2c6a2b66061f3be diff --git a/src/main.cpp b/src/main.cpp index ce422d8..7a7542b 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -7,206 +7,27 @@ #include #include "blt/std/types.h" #include "blt/std/utility.h" +#include #include #include #include #include namespace sql = sqlite_orm; - -struct server_info_t -{ - blt::u32 member_count; - std::string name; - std::string description; - std::string icon; - std::string splash; - std::string discovery_splash; - std::string banner; -}; - -struct user_info_t -{ - blt::u64 userID; - std::string username; - std::string global_nickname; - std::string server_name; -}; - -auto make_user_table() -{ - return sql::make_table("users", - sql::make_column("userID", &user_info_t::userID, sql::primary_key()), - sql::make_column("username", &user_info_t::username), - sql::make_column("global_nickname", &user_info_t::global_nickname), - sql::make_column("server_name", &user_info_t::server_name)); -} - -using user_table_t = decltype(make_user_table()); - -struct user_history_t -{ - blt::u64 userID; - blt::u64 time_changed; - std::string old_username; - std::string old_global_nickname; - std::string old_server_name; -}; - -auto make_user_history_table() -{ - return sql::make_table("user_history", - sql::make_column("userID", &user_history_t::userID), - sql::make_column("time_changed", &user_history_t::time_changed), - sql::make_column("old_username", &user_history_t::old_username), - sql::make_column("old_global_nickname", &user_history_t::old_global_nickname), - sql::make_column("old_server_name", &user_history_t::old_server_name), - sql::foreign_key(&user_history_t::userID).references(&user_info_t::userID), - sql::primary_key(&user_history_t::userID, &user_history_t::time_changed)); -} - -using user_history_table_t = decltype(make_user_history_table()); - -struct channel_info_t -{ - blt::u64 channelID; - std::string channel_name; -}; - -auto make_channel_table() -{ - return sql::make_table("channels", - sql::make_column("channelID", &channel_info_t::channelID, sql::primary_key()), - sql::make_column("channel_name", &channel_info_t::channel_name)); -} - -using channel_table_t = decltype(make_channel_table()); - -struct channel_history_t -{ - blt::u64 channelID; - blt::u64 time_changed; - std::string old_channel_name; -}; - -auto make_channel_history_table() -{ - return sql::make_table("channel_history", - sql::make_column("channelID", &channel_history_t::channelID), - sql::make_column("time_changed", &channel_history_t::time_changed), - sql::make_column("old_channel_name", &channel_history_t::old_channel_name), - sql::foreign_key(&channel_history_t::channelID).references(&channel_info_t::channelID), - sql::primary_key(&channel_history_t::channelID, &channel_history_t::time_changed)); -} - -using channel_history_table_t = decltype(make_channel_history_table()); - -struct message_t -{ - blt::u64 messageID; - blt::u64 channelID; - blt::u64 userID; - std::string content; -}; - -auto make_message_table() -{ - return sql::make_table("messages", - sql::make_column("messageID", &message_t::messageID, sql::primary_key()), - sql::make_column("channelID", &message_t::channelID), - sql::make_column("userID", &message_t::userID), - sql::make_column("content", &message_t::content), - sql::foreign_key(&message_t::channelID).references(&channel_info_t::channelID), - sql::foreign_key(&message_t::userID).references(&user_info_t::userID)); -} - -using message_table_t = decltype(make_message_table()); - -struct attachment_t -{ - blt::u64 messageID; - std::string url; -}; - -auto make_attachment_table() -{ - return sql::make_table("attachments", - sql::make_column("messageID", &attachment_t::messageID), - sql::make_column("url", &attachment_t::url), - sql::foreign_key(&attachment_t::messageID).references(&message_t::messageID), - sql::primary_key(&attachment_t::messageID, &attachment_t::url)); -} - -using attachment_table_t = decltype(make_attachment_table()); - -struct message_edits_t -{ - blt::u64 messageID; - std::string old_content; - std::string new_content; -}; - -auto make_message_edits_table() -{ - return sql::make_table("message_edits", - sql::make_column("messageID", &message_edits_t::messageID), - sql::make_column("old_content", &message_edits_t::old_content), - sql::make_column("new_content", &message_edits_t::new_content), - sql::foreign_key(&message_edits_t::messageID).references(&message_t::messageID), - sql::primary_key(&message_edits_t::messageID, &message_edits_t::old_content, &message_edits_t::new_content)); -} - -using message_edits_table_t = decltype(make_message_edits_table()); - -struct message_deletes_t -{ - blt::u64 messageID; - blt::u64 channelID; - std::string content; -}; - -auto make_message_deletes_table() -{ - return sql::make_table("message_deletes", - sql::make_column("messageID", &message_deletes_t::messageID), - sql::make_column("channelID", &message_deletes_t::channelID), - sql::make_column("content", &message_deletes_t::content), - sql::foreign_key(&message_deletes_t::messageID).references(&message_t::messageID), - sql::foreign_key(&message_deletes_t::channelID).references(&channel_info_t::channelID), - sql::primary_key(&message_deletes_t::messageID, &message_deletes_t::channelID)); -} - -using message_deletes_table_t = decltype(make_message_deletes_table()); - -auto make_database(std::string path) -{ - return sql::make_storage(std::move(path), make_user_table(), make_user_history_table(), make_channel_table(), make_channel_history_table(), - make_message_table(), make_attachment_table(), make_message_edits_table(), make_message_deletes_table()); -} - -using database_type = decltype(make_database("")); +using namespace db; struct db_obj { private: blt::u64 guildID; std::atomic_bool loaded_channels = false; + std::atomic_bool loaded_members = false; + blt::u64 user_count = -1; + std::atomic_uint64_t loaded_users = 0; + std::queue user_load_queue; + std::mutex user_load_queue_mutex; database_type db; - - void ensure_channel_exists() - { - - } - - void ensure_user_exists() - { - - } - - bool loading_complete() - { - return loaded_channels.load(); - } + std::thread* thread; public: explicit db_obj(blt::u64 guildID, const std::string& path): guildID(guildID), db(make_database(path + "/" + std::to_string(guildID) + "/")) @@ -216,7 +37,7 @@ struct db_obj void load(dpp::cluster& bot, const dpp::guild& guild) { - bot.channels_get(guild.id, [&bot, this, &guild](const dpp::confirmation_callback_t& event) { + bot.channels_get(guild.id, [this, guild, &bot](const dpp::confirmation_callback_t& event) { if (event.is_error()) { BLT_WARN("Failed to fetch channels for guild %ld", guildID); @@ -229,44 +50,39 @@ struct db_obj for (const auto& channel : channel_map) { - BLT_DEBUG("\tfetched channel id %ld with name '%s'", channel.first, channel.second.name.c_str()); + if (channel.second.is_category()) + continue; + BLT_DEBUG("\tFetched channel id %ld with name '%s'", channel.first, channel.second.name.c_str()); } loaded_channels = true; + BLT_INFO("Finished loading channels for guild '%s'", guild.name.c_str()); }); - while (!loaded_channels.load()) - {} - - for (const auto& member : guild.members) - { - BLT_TRACE("\tfetching user %ld -> %ld", member.first, member.second.user_id); - bot.user_get(member.first, [member, this](const dpp::confirmation_callback_t& event) { - if (event.is_error()) + bot.guild_get_members(guildID, 1000, 0, [this, &bot, guild](const dpp::confirmation_callback_t& event) { + if (event.is_error()) + { + BLT_WARN("Failed to fetch members for guild %ld", guildID); + BLT_WARN("Cause: %s", event.get_error().human_readable.c_str()); + loaded_members = true; + return; + } + auto member_map = event.get(); + BLT_INFO("Guild '%s' member count: %ld", guild.name.c_str(), member_map.size()); + { + user_count = member_map.size(); + std::scoped_lock lock(user_load_queue_mutex); + for (const auto& member : member_map) { - BLT_WARN("Failed to fetch user $ld for guild '%s'", member.first, guildID); - BLT_WARN("Cause: %s", event.get_error().human_readable.c_str()); - return; + BLT_DEBUG("\tFetched member '%s'", member.second.get_nickname().c_str()); + user_load_queue.push(member.first); } - auto user = event.get(); - - BLT_DEBUG("We got user '%s' with username '%s' and global name '%s'", user.username.c_str(), user.global_name.c_str()); - }); - bot.guild_get_member(guildID, member.first, [&guild, member](const dpp::confirmation_callback_t& event) { - if (event.is_error()) - { - BLT_WARN("Failed to fetch member %ld for guild %ld", member.first, guild.id); - BLT_WARN("Cause: %s", event.get_error().human_readable.c_str()); - return; - } - auto user = event.get(); - - BLT_DEBUG("Member of guild '%s' with nickname '%s'", guild.name.c_str(), user.get_nickname().c_str()); - }); - } + } + BLT_INFO("Finished loading members for guild '%s'", guild.name.c_str()); + loaded_members = true; + }); - while (!loading_complete()) - {} - BLT_DEBUG("Finished loading guild '%s'", guild.name.c_str()); + BLT_DEBUG("Finished requesting info for guild '%s'", guild.name.c_str()); + process_queue(bot); } void commit(const user_info_t& edited) @@ -308,6 +124,64 @@ struct db_obj { } + + void process_queue(dpp::cluster& bot) + { + thread = new std::thread([this, &bot]() { + while (user_count != loaded_users.load()) + { + std::this_thread::sleep_for(std::chrono::milliseconds(50)); + blt::u64 member = 0; + { + std::scoped_lock lock(user_load_queue_mutex); + if (user_load_queue.empty()) + continue; + member = user_load_queue.front(); + user_load_queue.pop(); + } + bot.user_get(member, [member, this](const dpp::confirmation_callback_t& event) { + if (event.is_error()) + { + BLT_WARN("Failed to fetch user %ld for guild '%ld'", member, guildID); + BLT_WARN("Cause: %s", event.get_error().human_readable.c_str()); + BLT_INFO("Error code %d with error message '%s'", event.get_error().code, event.get_error().message.c_str()); + // requeue on rate limit + if (event.get_error().code == 0) + { + for (const auto& v : event.get_error().errors) + { + BLT_TRACE0_STREAM << "\t" << v.code << '\n'; + BLT_TRACE0_STREAM << "\t" << v.field << '\n'; + BLT_TRACE0_STREAM << "\t" << v.index << '\n'; + BLT_TRACE0_STREAM << "\t" << v.object << '\n'; + BLT_TRACE0_STREAM << "\t" << v.reason << '\n'; + } + std::scoped_lock lock(user_load_queue_mutex); + user_load_queue.push(member); + } else + loaded_users++; + BLT_INFO("%ld vs %ld", user_count, loaded_users.load()); + return; + } + auto user = event.get(); + + BLT_DEBUG("We got user '%s' with global name '%s'", user.username.c_str(), user.global_name.c_str()); + loaded_users++; + }); + } + }); + } + + bool loading_complete() + { + return loaded_channels.load() && loaded_members.load() && user_count != -1ul && user_count == loaded_users.load(); + } + + ~db_obj() + { + thread->join(); + delete thread; + } }; @@ -315,10 +189,23 @@ blt::hashmap_t> databases; std::string path; blt::u64 total_guilds = 0; std::atomic_uint64_t completed_guilds = 0; +std::atomic_bool finished_loading = false; bool loading_complete() { - return total_guilds != 0 && total_guilds == completed_guilds.load(); + bool finished = true; + if (!finished_loading.load()) + { + for (const auto& v : databases) + { + if (!v.second->loading_complete()) + finished = false; + } + finished_loading.store(finished); + if (finished) + BLT_INFO("Loading complete!"); + } + return finished && total_guilds != 0 && total_guilds == completed_guilds.load(); } db_obj& get(blt::u64 id) @@ -370,12 +257,17 @@ int main(int argc, const char** argv) completed_guilds++; }); + std::this_thread::sleep_for(std::chrono::milliseconds(100)); } } }); bot.on_user_update(wait_wrapper([&bot](const dpp::user_update_t& event) { BLT_INFO("User '%s' updated in some way; global name: '%s'", event.updated.username.c_str(), event.updated.global_name.c_str()); + for (const auto& guild : databases) + { + + } })); bot.on_guild_member_update(wait_wrapper([&bot](const dpp::guild_member_update_t& event) {