Merge remote-tracking branch 'refs/remotes/origin/main'

main
Brett 2024-02-27 23:58:11 -05:00
commit 853c58a853
3 changed files with 321 additions and 221 deletions

208
include/data_structs.h Normal file
View File

@ -0,0 +1,208 @@
/*
* <Short Description>
* Copyright (C) 2024 Brett Terpstra
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
#ifndef DISCORD_BOT_DATA_STRUCTS_H
#define DISCORD_BOT_DATA_STRUCTS_H
#include <dpp/dpp.h>
#include <sqlite_orm/sqlite_orm.h>
namespace db
{
struct server_info_t
{
blt::u32 member_count;
std::string name;
std::string description;
std::string icon;
std::string splash;
std::string discovery_splash;
std::string banner;
};
struct user_info_t
{
blt::u64 userID;
std::string username;
std::string global_nickname;
std::string server_name;
};
struct user_history_t
{
blt::u64 userID;
blt::u64 time_changed;
std::string old_username;
std::string old_global_nickname;
std::string old_server_name;
};
struct channel_info_t
{
blt::u64 channelID;
std::string channel_name;
};
struct channel_history_t
{
blt::u64 channelID;
blt::u64 time_changed;
std::string old_channel_name;
};
struct message_t
{
blt::u64 messageID;
blt::u64 channelID;
blt::u64 userID;
std::string content;
};
struct attachment_t
{
blt::u64 messageID;
std::string url;
};
struct message_edits_t
{
blt::u64 messageID;
std::string old_content;
std::string new_content;
};
struct message_deletes_t
{
blt::u64 messageID;
blt::u64 channelID;
std::string content;
};
auto make_user_table()
{
using namespace sqlite_orm;
return make_table("users",
make_column("userID", &user_info_t::userID, primary_key()),
make_column("username", &user_info_t::username),
make_column("global_nickname", &user_info_t::global_nickname),
make_column("server_name", &user_info_t::server_name));
}
using user_table_t = decltype(make_user_table());
auto make_user_history_table()
{
using namespace sqlite_orm;
return make_table("user_history",
make_column("userID", &user_history_t::userID),
make_column("time_changed", &user_history_t::time_changed),
make_column("old_username", &user_history_t::old_username),
make_column("old_global_nickname", &user_history_t::old_global_nickname),
make_column("old_server_name", &user_history_t::old_server_name),
foreign_key(&user_history_t::userID).references(&user_info_t::userID),
primary_key(&user_history_t::userID, &user_history_t::time_changed));
}
using user_history_table_t = decltype(make_user_history_table());
auto make_channel_table()
{
using namespace sqlite_orm;
return make_table("channels",
make_column("channelID", &channel_info_t::channelID, primary_key()),
make_column("channel_name", &channel_info_t::channel_name));
}
using channel_table_t = decltype(make_channel_table());
auto make_channel_history_table()
{
using namespace sqlite_orm;
return make_table("channel_history",
make_column("channelID", &channel_history_t::channelID),
make_column("time_changed", &channel_history_t::time_changed),
make_column("old_channel_name", &channel_history_t::old_channel_name),
foreign_key(&channel_history_t::channelID).references(&channel_info_t::channelID),
primary_key(&channel_history_t::channelID, &channel_history_t::time_changed));
}
using channel_history_table_t = decltype(make_channel_history_table());
auto make_message_table()
{
return make_table("messages",
make_column("messageID", &message_t::messageID, primary_key()),
make_column("channelID", &message_t::channelID),
make_column("userID", &message_t::userID),
make_column("content", &message_t::content),
foreign_key(&message_t::channelID).references(&channel_info_t::channelID),
foreign_key(&message_t::userID).references(&user_info_t::userID));
}
using message_table_t = decltype(make_message_table());
auto make_attachment_table()
{
using namespace sqlite_orm;
return make_table("attachments",
make_column("messageID", &attachment_t::messageID),
make_column("url", &attachment_t::url),
foreign_key(&attachment_t::messageID).references(&message_t::messageID),
primary_key(&attachment_t::messageID, &attachment_t::url));
}
using attachment_table_t = decltype(make_attachment_table());
auto make_message_edits_table()
{
using namespace sqlite_orm;
return make_table("message_edits",
make_column("messageID", &message_edits_t::messageID),
make_column("old_content", &message_edits_t::old_content),
make_column("new_content", &message_edits_t::new_content),
foreign_key(&message_edits_t::messageID).references(&message_t::messageID),
primary_key(&message_edits_t::messageID, &message_edits_t::old_content, &message_edits_t::new_content));
}
using message_edits_table_t = decltype(make_message_edits_table());
auto make_message_deletes_table()
{
using namespace sqlite_orm;
return make_table("message_deletes",
make_column("messageID", &message_deletes_t::messageID),
make_column("channelID", &message_deletes_t::channelID),
make_column("content", &message_deletes_t::content),
foreign_key(&message_deletes_t::messageID).references(&message_t::messageID),
foreign_key(&message_deletes_t::channelID).references(&channel_info_t::channelID),
primary_key(&message_deletes_t::messageID, &message_deletes_t::channelID));
}
using message_deletes_table_t = decltype(make_message_deletes_table());
auto make_database(std::string path)
{
using namespace sqlite_orm;
return make_storage(std::move(path), make_user_table(), make_user_history_table(), make_channel_table(), make_channel_history_table(),
make_message_table(), make_attachment_table(), make_message_edits_table(), make_message_deletes_table());
}
using database_type = decltype(make_database(""));
}
#endif //DISCORD_BOT_DATA_STRUCTS_H

@ -1 +1 @@
Subproject commit 9ad652195b0a69f9977d313eff4dd01a7890f1df Subproject commit 9b4d0cc9a8493c608ab0075ab2c6a2b66061f3be

View File

@ -7,206 +7,27 @@
#include <sqlite_orm/sqlite_orm.h> #include <sqlite_orm/sqlite_orm.h>
#include "blt/std/types.h" #include "blt/std/types.h"
#include "blt/std/utility.h" #include "blt/std/utility.h"
#include <data_structs.h>
#include <curl/curl.h> #include <curl/curl.h>
#include <atomic> #include <atomic>
#include <condition_variable> #include <condition_variable>
#include <mutex> #include <mutex>
namespace sql = sqlite_orm; namespace sql = sqlite_orm;
using namespace db;
struct server_info_t
{
blt::u32 member_count;
std::string name;
std::string description;
std::string icon;
std::string splash;
std::string discovery_splash;
std::string banner;
};
struct user_info_t
{
blt::u64 userID;
std::string username;
std::string global_nickname;
std::string server_name;
};
auto make_user_table()
{
return sql::make_table("users",
sql::make_column("userID", &user_info_t::userID, sql::primary_key()),
sql::make_column("username", &user_info_t::username),
sql::make_column("global_nickname", &user_info_t::global_nickname),
sql::make_column("server_name", &user_info_t::server_name));
}
using user_table_t = decltype(make_user_table());
struct user_history_t
{
blt::u64 userID;
blt::u64 time_changed;
std::string old_username;
std::string old_global_nickname;
std::string old_server_name;
};
auto make_user_history_table()
{
return sql::make_table("user_history",
sql::make_column("userID", &user_history_t::userID),
sql::make_column("time_changed", &user_history_t::time_changed),
sql::make_column("old_username", &user_history_t::old_username),
sql::make_column("old_global_nickname", &user_history_t::old_global_nickname),
sql::make_column("old_server_name", &user_history_t::old_server_name),
sql::foreign_key(&user_history_t::userID).references(&user_info_t::userID),
sql::primary_key(&user_history_t::userID, &user_history_t::time_changed));
}
using user_history_table_t = decltype(make_user_history_table());
struct channel_info_t
{
blt::u64 channelID;
std::string channel_name;
};
auto make_channel_table()
{
return sql::make_table("channels",
sql::make_column("channelID", &channel_info_t::channelID, sql::primary_key()),
sql::make_column("channel_name", &channel_info_t::channel_name));
}
using channel_table_t = decltype(make_channel_table());
struct channel_history_t
{
blt::u64 channelID;
blt::u64 time_changed;
std::string old_channel_name;
};
auto make_channel_history_table()
{
return sql::make_table("channel_history",
sql::make_column("channelID", &channel_history_t::channelID),
sql::make_column("time_changed", &channel_history_t::time_changed),
sql::make_column("old_channel_name", &channel_history_t::old_channel_name),
sql::foreign_key(&channel_history_t::channelID).references(&channel_info_t::channelID),
sql::primary_key(&channel_history_t::channelID, &channel_history_t::time_changed));
}
using channel_history_table_t = decltype(make_channel_history_table());
struct message_t
{
blt::u64 messageID;
blt::u64 channelID;
blt::u64 userID;
std::string content;
};
auto make_message_table()
{
return sql::make_table("messages",
sql::make_column("messageID", &message_t::messageID, sql::primary_key()),
sql::make_column("channelID", &message_t::channelID),
sql::make_column("userID", &message_t::userID),
sql::make_column("content", &message_t::content),
sql::foreign_key(&message_t::channelID).references(&channel_info_t::channelID),
sql::foreign_key(&message_t::userID).references(&user_info_t::userID));
}
using message_table_t = decltype(make_message_table());
struct attachment_t
{
blt::u64 messageID;
std::string url;
};
auto make_attachment_table()
{
return sql::make_table("attachments",
sql::make_column("messageID", &attachment_t::messageID),
sql::make_column("url", &attachment_t::url),
sql::foreign_key(&attachment_t::messageID).references(&message_t::messageID),
sql::primary_key(&attachment_t::messageID, &attachment_t::url));
}
using attachment_table_t = decltype(make_attachment_table());
struct message_edits_t
{
blt::u64 messageID;
std::string old_content;
std::string new_content;
};
auto make_message_edits_table()
{
return sql::make_table("message_edits",
sql::make_column("messageID", &message_edits_t::messageID),
sql::make_column("old_content", &message_edits_t::old_content),
sql::make_column("new_content", &message_edits_t::new_content),
sql::foreign_key(&message_edits_t::messageID).references(&message_t::messageID),
sql::primary_key(&message_edits_t::messageID, &message_edits_t::old_content, &message_edits_t::new_content));
}
using message_edits_table_t = decltype(make_message_edits_table());
struct message_deletes_t
{
blt::u64 messageID;
blt::u64 channelID;
std::string content;
};
auto make_message_deletes_table()
{
return sql::make_table("message_deletes",
sql::make_column("messageID", &message_deletes_t::messageID),
sql::make_column("channelID", &message_deletes_t::channelID),
sql::make_column("content", &message_deletes_t::content),
sql::foreign_key(&message_deletes_t::messageID).references(&message_t::messageID),
sql::foreign_key(&message_deletes_t::channelID).references(&channel_info_t::channelID),
sql::primary_key(&message_deletes_t::messageID, &message_deletes_t::channelID));
}
using message_deletes_table_t = decltype(make_message_deletes_table());
auto make_database(std::string path)
{
return sql::make_storage(std::move(path), make_user_table(), make_user_history_table(), make_channel_table(), make_channel_history_table(),
make_message_table(), make_attachment_table(), make_message_edits_table(), make_message_deletes_table());
}
using database_type = decltype(make_database(""));
struct db_obj struct db_obj
{ {
private: private:
blt::u64 guildID; blt::u64 guildID;
std::atomic_bool loaded_channels = false; std::atomic_bool loaded_channels = false;
std::atomic_bool loaded_members = false;
blt::u64 user_count = -1;
std::atomic_uint64_t loaded_users = 0;
std::queue<blt::u64> user_load_queue;
std::mutex user_load_queue_mutex;
database_type db; database_type db;
std::thread* thread;
void ensure_channel_exists()
{
}
void ensure_user_exists()
{
}
bool loading_complete()
{
return loaded_channels.load();
}
public: public:
explicit db_obj(blt::u64 guildID, const std::string& path): guildID(guildID), db(make_database(path + "/" + std::to_string(guildID) + "/")) explicit db_obj(blt::u64 guildID, const std::string& path): guildID(guildID), db(make_database(path + "/" + std::to_string(guildID) + "/"))
@ -216,7 +37,7 @@ struct db_obj
void load(dpp::cluster& bot, const dpp::guild& guild) void load(dpp::cluster& bot, const dpp::guild& guild)
{ {
bot.channels_get(guild.id, [&bot, this, &guild](const dpp::confirmation_callback_t& event) { bot.channels_get(guild.id, [this, guild, &bot](const dpp::confirmation_callback_t& event) {
if (event.is_error()) if (event.is_error())
{ {
BLT_WARN("Failed to fetch channels for guild %ld", guildID); BLT_WARN("Failed to fetch channels for guild %ld", guildID);
@ -229,44 +50,39 @@ struct db_obj
for (const auto& channel : channel_map) for (const auto& channel : channel_map)
{ {
BLT_DEBUG("\tfetched channel id %ld with name '%s'", channel.first, channel.second.name.c_str()); if (channel.second.is_category())
continue;
BLT_DEBUG("\tFetched channel id %ld with name '%s'", channel.first, channel.second.name.c_str());
} }
loaded_channels = true; loaded_channels = true;
BLT_INFO("Finished loading channels for guild '%s'", guild.name.c_str());
}); });
while (!loaded_channels.load()) bot.guild_get_members(guildID, 1000, 0, [this, &bot, guild](const dpp::confirmation_callback_t& event) {
{} if (event.is_error())
{
for (const auto& member : guild.members) BLT_WARN("Failed to fetch members for guild %ld", guildID);
{ BLT_WARN("Cause: %s", event.get_error().human_readable.c_str());
BLT_TRACE("\tfetching user %ld -> %ld", member.first, member.second.user_id); loaded_members = true;
bot.user_get(member.first, [member, this](const dpp::confirmation_callback_t& event) { return;
if (event.is_error()) }
auto member_map = event.get<dpp::guild_member_map>();
BLT_INFO("Guild '%s' member count: %ld", guild.name.c_str(), member_map.size());
{
user_count = member_map.size();
std::scoped_lock lock(user_load_queue_mutex);
for (const auto& member : member_map)
{ {
BLT_WARN("Failed to fetch user $ld for guild '%s'", member.first, guildID); BLT_DEBUG("\tFetched member '%s'", member.second.get_nickname().c_str());
BLT_WARN("Cause: %s", event.get_error().human_readable.c_str()); user_load_queue.push(member.first);
return;
} }
auto user = event.get<dpp::user_identified>(); }
BLT_INFO("Finished loading members for guild '%s'", guild.name.c_str());
loaded_members = true;
});
BLT_DEBUG("We got user '%s' with username '%s' and global name '%s'", user.username.c_str(), user.global_name.c_str()); BLT_DEBUG("Finished requesting info for guild '%s'", guild.name.c_str());
}); process_queue(bot);
bot.guild_get_member(guildID, member.first, [&guild, member](const dpp::confirmation_callback_t& event) {
if (event.is_error())
{
BLT_WARN("Failed to fetch member %ld for guild %ld", member.first, guild.id);
BLT_WARN("Cause: %s", event.get_error().human_readable.c_str());
return;
}
auto user = event.get<dpp::guild_member>();
BLT_DEBUG("Member of guild '%s' with nickname '%s'", guild.name.c_str(), user.get_nickname().c_str());
});
}
while (!loading_complete())
{}
BLT_DEBUG("Finished loading guild '%s'", guild.name.c_str());
} }
void commit(const user_info_t& edited) void commit(const user_info_t& edited)
@ -309,16 +125,87 @@ struct db_obj
} }
void process_queue(dpp::cluster& bot)
{
thread = new std::thread([this, &bot]() {
while (user_count != loaded_users.load())
{
std::this_thread::sleep_for(std::chrono::milliseconds(50));
blt::u64 member = 0;
{
std::scoped_lock lock(user_load_queue_mutex);
if (user_load_queue.empty())
continue;
member = user_load_queue.front();
user_load_queue.pop();
}
bot.user_get(member, [member, this](const dpp::confirmation_callback_t& event) {
if (event.is_error())
{
BLT_WARN("Failed to fetch user %ld for guild '%ld'", member, guildID);
BLT_WARN("Cause: %s", event.get_error().human_readable.c_str());
BLT_INFO("Error code %d with error message '%s'", event.get_error().code, event.get_error().message.c_str());
// requeue on rate limit
if (event.get_error().code == 0)
{
for (const auto& v : event.get_error().errors)
{
BLT_TRACE0_STREAM << "\t" << v.code << '\n';
BLT_TRACE0_STREAM << "\t" << v.field << '\n';
BLT_TRACE0_STREAM << "\t" << v.index << '\n';
BLT_TRACE0_STREAM << "\t" << v.object << '\n';
BLT_TRACE0_STREAM << "\t" << v.reason << '\n';
}
std::scoped_lock lock(user_load_queue_mutex);
user_load_queue.push(member);
} else
loaded_users++;
BLT_INFO("%ld vs %ld", user_count, loaded_users.load());
return;
}
auto user = event.get<dpp::user_identified>();
BLT_DEBUG("We got user '%s' with global name '%s'", user.username.c_str(), user.global_name.c_str());
loaded_users++;
});
}
});
}
bool loading_complete()
{
return loaded_channels.load() && loaded_members.load() && user_count != -1ul && user_count == loaded_users.load();
}
~db_obj()
{
thread->join();
delete thread;
}
}; };
blt::hashmap_t<blt::u64, std::unique_ptr<db_obj>> databases; blt::hashmap_t<blt::u64, std::unique_ptr<db_obj>> databases;
std::string path; std::string path;
blt::u64 total_guilds = 0; blt::u64 total_guilds = 0;
std::atomic_uint64_t completed_guilds = 0; std::atomic_uint64_t completed_guilds = 0;
std::atomic_bool finished_loading = false;
bool loading_complete() bool loading_complete()
{ {
return total_guilds != 0 && total_guilds == completed_guilds.load(); bool finished = true;
if (!finished_loading.load())
{
for (const auto& v : databases)
{
if (!v.second->loading_complete())
finished = false;
}
finished_loading.store(finished);
if (finished)
BLT_INFO("Loading complete!");
}
return finished && total_guilds != 0 && total_guilds == completed_guilds.load();
} }
db_obj& get(blt::u64 id) db_obj& get(blt::u64 id)
@ -370,12 +257,17 @@ int main(int argc, const char** argv)
completed_guilds++; completed_guilds++;
}); });
std::this_thread::sleep_for(std::chrono::milliseconds(100));
} }
} }
}); });
bot.on_user_update(wait_wrapper<dpp::user_update_t>([&bot](const dpp::user_update_t& event) { bot.on_user_update(wait_wrapper<dpp::user_update_t>([&bot](const dpp::user_update_t& event) {
BLT_INFO("User '%s' updated in some way; global name: '%s'", event.updated.username.c_str(), event.updated.global_name.c_str()); BLT_INFO("User '%s' updated in some way; global name: '%s'", event.updated.username.c_str(), event.updated.global_name.c_str());
for (const auto& guild : databases)
{
}
})); }));
bot.on_guild_member_update(wait_wrapper<dpp::guild_member_update_t>([&bot](const dpp::guild_member_update_t& event) { bot.on_guild_member_update(wait_wrapper<dpp::guild_member_update_t>([&bot](const dpp::guild_member_update_t& event) {