From 72f30197001a055d98fcba6243c63cae398b82af Mon Sep 17 00:00:00 2001 From: Brett Laptop Date: Tue, 7 Jan 2025 20:52:23 -0500 Subject: [PATCH] silly boi --- CMakeLists.txt | 2 +- src/MNIST.cpp | 167 ++++++++++++++++++++++++++++++++++++++++--------- 2 files changed, 137 insertions(+), 32 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 39853bd..ab8b351 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,5 +1,5 @@ cmake_minimum_required(VERSION 3.25) -project(COSC-4P80-Final-Project VERSION 0.0.6) +project(COSC-4P80-Final-Project VERSION 0.0.7) option(ENABLE_ADDRSAN "Enable the address sanitizer" OFF) option(ENABLE_UBSAN "Enable the ub sanitizer" OFF) diff --git a/src/MNIST.cpp b/src/MNIST.cpp index 81b32dd..b981430 100644 --- a/src/MNIST.cpp +++ b/src/MNIST.cpp @@ -20,7 +20,10 @@ #include #include #include +#include #include +#include +#include #include #include @@ -211,8 +214,8 @@ namespace fp struct batch_stats_t { - blt::u64 hits; - blt::u64 misses; + blt::u64 hits = 0; + blt::u64 misses = 0; friend std::ofstream& operator<<(std::ofstream& file, const batch_stats_t& stats) { @@ -245,9 +248,9 @@ namespace fp struct epoch_stats_t { - batch_stats_t test_results; - double average_loss; - double learn_rate; + batch_stats_t test_results {}; + double average_loss = 0; + double learn_rate = 0; friend std::ofstream& operator<<(std::ofstream& file, const epoch_stats_t& stats) { @@ -306,8 +309,42 @@ namespace fp } return file; } + }; + struct network_average_stats_t + { + std::vector run_stats; + network_average_stats_t& operator+=(const network_stats_t& stats) + { + run_stats.push_back(stats); + return *this; + } + + blt::size_t average_size() const + { + blt::size_t acc = 0; + for (const auto& [epoch_stats] : run_stats) + acc += epoch_stats.size(); + return acc; + } + + network_stats_t average_stats() const + { + network_stats_t stats; + for (const auto& [epoch_stats] : run_stats) + { + if (stats.epoch_stats.size() < epoch_stats.size()) + stats.epoch_stats.resize(epoch_stats.size()); + for (const auto& [i, v] : blt::enumerate(epoch_stats)) + { + stats.epoch_stats[i] += v; + } + } + for (auto& v : stats.epoch_stats) + v /= run_stats.size(); + return stats; + } }; template @@ -341,7 +378,7 @@ namespace fp } template - void test_network(NetworkType& network) + batch_stats_t test_network(NetworkType& network) { const idx_file_t test_images{"../problems/mnist/t10k-images.idx3-ubyte"}; const idx_file_t test_labels{"../problems/mnist/t10k-labels.idx1-ubyte"}; @@ -351,9 +388,11 @@ namespace fp auto test_results = test_batch(network, test_image.get_image_data().begin(), test_image.get_image_data().end(), test_image.get_image_labels().begin()); - BLT_INFO("Testing hits: %lu", test_results.hits); - BLT_INFO("Testing misses: %lu", test_results.misses); - BLT_INFO("Testing accuracy: %lf", test_results.hits / static_cast(test_results.hits + test_results.misses)); + BLT_DEBUG("Testing hits: %lu", test_results.hits); + BLT_DEBUG("Testing misses: %lu", test_results.misses); + BLT_DEBUG("Testing accuracy: %lf", test_results.hits / static_cast(test_results.hits + test_results.misses)); + + return test_results; } template @@ -393,11 +432,12 @@ namespace fp data.begin() + end, labels.begin() + begin); } epoch_pos = 0; - BLT_DEBUG("Trained an epoch (%ld/%ld) learn rate %lf average loss %lf", epochs, trainer.get_max_num_epochs(), + BLT_TRACE("Trained an epoch (%ld/%ld) learn rate %lf average loss %lf", epochs, trainer.get_max_num_epochs(), trainer.get_learning_rate(), trainer.get_average_loss()); // sync and test trainer.get_net(dlib::force_flush_to_disk::no); + network.clean(); epoch_stats_t epoch_stats{}; epoch_stats.test_results = test_batch(network, train_image.get_image_data().begin(), train_image.get_image_data().end(), @@ -405,10 +445,12 @@ namespace fp epoch_stats.average_loss = trainer.get_average_loss(); epoch_stats.learn_rate = trainer.get_learning_rate(); - BLT_DEBUG("\t\tHits: %lu\tMisses: %lu\tAccuracy: %lf", epoch_stats.test_results.hits, epoch_stats.test_results.misses, + BLT_TRACE("\t\tHits: %lu\tMisses: %lu\tAccuracy: %lf", epoch_stats.test_results.hits, epoch_stats.test_results.misses, epoch_stats.test_results.hits / static_cast(epoch_stats.test_results.hits + epoch_stats.test_results.misses)); stats.epoch_stats.push_back(epoch_stats); + network.clean(); + // dlib::serialize("mnist_network_" + ident + ".dat") << network; } BLT_INFO("Finished Training"); @@ -423,9 +465,9 @@ namespace fp auto test_results = test_batch(network, train_image.get_image_data().begin(), train_image.get_image_data().end(), train_image.get_image_labels().begin()); - BLT_INFO("Training hits: %lu", test_results.hits); - BLT_INFO("Training misses: %lu", test_results.misses); - BLT_INFO("Training accuracy: %lf", test_results.hits / static_cast(test_results.hits + test_results.misses)); + BLT_DEBUG("Training hits: %lu", test_results.hits); + BLT_DEBUG("Training misses: %lu", test_results.misses); + BLT_DEBUG("Training accuracy: %lf", test_results.hits / static_cast(test_results.hits + test_results.misses)); return stats; } @@ -438,29 +480,92 @@ namespace fp return network; } - void run_mnist(int argc, const char** argv) + template + void run_network_tests(std::string path, const std::string& ident, const blt::i32 runs, const bool restore) + { + path += ("/" + ident + "/"); + std::filesystem::create_directories(path); + std::filesystem::current_path(path); + + network_average_stats_t stats{}; + std::vector test_stats; + + for (blt::i32 i = 0; i < runs; i++) + { + auto local_ident = ident + std::to_string(i); + NetworkType network{}; + if (restore) + network = load_network(local_ident); + else + stats += train_network(local_ident, network); + test_stats.push_back(test_network(network)); + } + + + } + + void run_deep_learning_tests(const std::string& path, const blt::i32 runs, const bool restore) + { + using namespace dlib; + using net_type_dl = loss_multiclass_log< + fc<10, + relu>>>>>>>>>>>>>; + run_network_tests(path, "deep_learning", runs, restore); + } + + void run_feed_forward_tests(const std::string& path, const blt::i32 runs, const bool restore) { using namespace dlib; - // using net_type = loss_multiclass_log< - // fc<10, - // relu>>>>>>>>>>>>>; - - using net_type = loss_multiclass_log< + using net_type_ff = loss_multiclass_log< fc<10, - sig>>>>>>>; - net_type test_net; - const auto stats = train_network("fc_nn", test_net); - std::ofstream out_file{"fc_nn.csv"}; - out_file << stats; + run_network_tests(path, "feed_forward", runs, restore); + } - test_network(test_net); + void run_mnist(const int argc, const char** argv) + { + using namespace dlib; + + blt::arg_parse parser{}; + parser.addArgument( + blt::arg_builder{"-r", "--restore"}.setAction(blt::arg_action_t::STORE_TRUE).setDefault(false).setHelp("Restores from last save").build()); + parser.addArgument(blt::arg_builder{"-t", "--runs"}.setHelp("Number of runs to perform [default: 10]").setDefault("10").build()); + parser.addArgument(blt::arg_builder{"type"}.setDefault("all").setHelp("Type of network to run [ff, dl, default: all]").build()); + + auto args = parser.parse_args(argc, argv); + + const auto type = blt::string::toLowerCase(args.get("type")); + const auto runs = std::stoi(args.get("runs")); + const auto restore = args.get("restore"); + const auto path = "./" + std::to_string(blt::system::getCurrentTimeMilliseconds()); + + if (type == "all") + { + run_deep_learning_tests(path, runs, restore); + run_feed_forward_tests(path, runs, restore); + } else if (type == "ff") + { + run_feed_forward_tests(path, runs, restore); + } else if (type == "df") + { + run_deep_learning_tests(path, runs, restore); + } + + // net_type_dl test_net; + // const auto stats = train_network("dl_nn", test_net); + // std::ofstream out_file{"dl_nn.csv"}; + // out_file << stats; + + // test_net = load_network("dl_nn"); + + // test_network(test_net); } }