/*
 *  <Short Description>
 *  Copyright (C) 2024  Brett Terpstra
 *
 *  This program is free software: you can redistribute it and/or modify
 *  it under the terms of the GNU General Public License as published by
 *  the Free Software Foundation, either version 3 of the License, or
 *  (at your option) any later version.
 *
 *  This program is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *  GNU General Public License for more details.
 *
 *  You should have received a copy of the GNU General Public License
 *  along with this program.  If not, see <https://www.gnu.org/licenses/>.
 */
#include <MNIST.h>
#include <blt/fs/loader.h>
#include <blt/std/memory.h>
#include <blt/std/memory_util.h>
#include <variant>
#include <filesystem>
#include <iomanip>
#include <algorithm>
#include <array>
#include <atomic>
#include <cstdlib>
#include <fstream>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>
#include <blt/iterator/iterator.h>
#include <blt/parse/argparse.h>
#include <blt/std/time.h>
#include <dlib/dnn.h>
#include <dlib/data_io.h>
#include <csignal>

namespace fp
{
    /// Mini-batch size used for both training and evaluation.
    constexpr blt::i64 batch_size = 256;

    /// Working directory at startup (with a trailing '/'); set by run_mnist.
    std::string binary_directory;
    /// Path of the graph.py script invoked by run_python_line_graph; set by run_mnist.
    std::string python_dual_stacked_graph_program;

    /// Set by SIGINT / SIGQUIT / SIGUSR2: leave the current training loop after the current epoch.
    std::atomic_bool break_flag = false;
    /// Set by SIGQUIT / SIGUSR2: do not start any further runs.
    std::atomic_bool stop_flag = false;
    /// Set by SIGUSR1: divide the trainer's learning rate by 10 (consumed once by train_network).
    std::atomic_bool learn_flag = false;
    /// Epoch index at which an interrupted training stopped; -1 when there is nothing to resume.
    std::atomic_int64_t last_epoch = -1;

    /**
     * Runs the graphing script as `python3 <script> <title> <output_file> <csv1> <csv2> <pos_forward> <pos_deep>`.
     *
     * String arguments are wrapped in single quotes. An embedded single quote is emitted as '\'' so it can no
     * longer terminate the quoting; for arguments without single quotes the command is identical to the
     * unescaped form.
     *
     * @param title        first script argument
     * @param output_file  second script argument
     * @param csv1         third script argument
     * @param csv2         fourth script argument
     * @param pos_forward  fifth script argument, passed as a decimal number
     * @param pos_deep     sixth script argument, passed as a decimal number
     */
    void run_python_line_graph(const std::string& title, const std::string& output_file, const std::string& csv1, const std::string& csv2,
                               const blt::size_t pos_forward, const blt::size_t pos_deep)
    {
        const auto quote = [](const std::string& arg) {
            std::string quoted = "'";
            for (const char c : arg)
            {
                if (c == '\'')
                    quoted += "'\\''";
                else
                    quoted += c;
            }
            quoted += '\'';
            return quoted;
        };

        const auto command = "python3 " + quote(python_dual_stacked_graph_program) + " " + quote(title) + " " + quote(output_file) + " " +
                             quote(csv1) + " " + quote(csv2) + " " + std::to_string(pos_forward) + " " + std::to_string(pos_deep);
        BLT_TRACE("Running %s", command.c_str());
        if (std::system(command.c_str()) != 0)
        {
            BLT_ERROR("Graph command returned a non-zero status");
        }
    }

    /**
     * In-memory copy of an IDX file: a 4 byte header (two zero bytes, an element type code, a dimension count),
     * one 4 byte size per dimension, then the elements.
     */
    class idx_file_t
    {
        template <typename T>
        using mk_v = std::vector<T>;
        using vec_t = std::variant<mk_v<blt::u8>, mk_v<blt::i8>, mk_v<blt::u16>, mk_v<blt::u32>, mk_v<blt::f32>, mk_v<blt::f64>>;

    public:
        /**
         * Reads the whole file at `path`.
         *
         * An unknown element type code or hitting EOF while reading is reported through BLT_ERROR.
         *
         * @throws std::runtime_error if the file cannot be opened
         */
        explicit idx_file_t(const std::string& path)
        {
            std::ifstream file{path, std::ios::in | std::ios::binary};
            if (!file.is_open())
            {
                BLT_ERROR("Failed to open idx file '%s'", path.c_str());
                throw std::runtime_error("Failed to open idx file: " + path);
            }

            using char_type = std::ifstream::char_type;
            // zero-initialised so a short read never leaves indeterminate header bytes
            char_type magic_arr[4]{};
            file.read(magic_arr, 4);
            BLT_ASSERT(magic_arr[0] == 0 && magic_arr[1] == 0);

            const blt::u8 dims = magic_arr[3];
            blt::size_t total_size = 1;

            for (blt::i32 i = 0; i < dims; i++)
            {
                char_type dim_arr[4]{};
                file.read(dim_arr, 4);
                blt::u32 dim = 0;
                blt::mem::fromBytes(dim_arr, dim);
                dimensions.push_back(dim);
                total_size *= dim;
            }

            // multi-byte element types are passed through reverse_data() after being read
            switch (magic_arr[2])
            {
                // unsigned char
                case 0x08:
                    data = mk_v<blt::u8>{};
                    read_data<blt::u8>(file, total_size);
                    break;
                // signed char
                case 0x09:
                    data = mk_v<blt::i8>{};
                    read_data<blt::i8>(file, total_size);
                    break;
                // short
                case 0x0B:
                    data = mk_v<blt::u16>{};
                    read_data<blt::u16>(file, total_size);
                    reverse_data<blt::u16>();
                    break;
                // int
                case 0x0C:
                    data = mk_v<blt::u32>{};
                    read_data<blt::u32>(file, total_size);
                    reverse_data<blt::u32>();
                    break;
                // float
                case 0x0D:
                    data = mk_v<blt::f32>{};
                    read_data<blt::f32>(file, total_size);
                    reverse_data<blt::f32>();
                    break;
                // double
                case 0x0E:
                    data = mk_v<blt::f64>{};
                    read_data<blt::f64>(file, total_size);
                    reverse_data<blt::f64>();
                    break;
                default:
                    BLT_ERROR("Unspported idx file type!");
            }
            if (file.eof())
            {
                BLT_ERROR("EOF reached. It's unlikely your file was read correctly!");
            }
        }

        /// Returns the element storage as a vector of T; std::get throws if the file holds a different type.
        template <typename T>
        [[nodiscard]] const std::vector<T>& get_data_as() const
        {
            return std::get<mk_v<T>>(data);
        }

        /// Returns one span per index of the first dimension, each covering data_size(1) elements.
        template <typename T>
        std::vector<blt::span<T>> get_as_spans() const
        {
            std::vector<blt::span<T>> spans;
            blt::size_t total_size = data_size(1);
            for (blt::size_t i = 0; i < dimensions[0]; i++)
            {
                auto& array = std::get<mk_v<T>>(data);
                spans.push_back({&array[i * total_size], total_size});
            }
            return spans;
        }

        /// Dimension sizes in the order they appear in the file header.
        [[nodiscard]] const std::vector<blt::u32>& get_dimensions() const
        {
            return dimensions;
        }

        /// Product of the dimension sizes from `starting_dimension` onwards (1 when none remain).
        [[nodiscard]] blt::size_t data_size(const blt::size_t starting_dimension = 0) const
        {
            blt::size_t total_size = 1;
            for (const auto d : blt::iterate(dimensions).skip(starting_dimension))
                total_size *= d;
            return total_size;
        }

    private:
        /// Resizes the active vector (which must already hold mk_v<T>) to total_size and fills it with raw bytes from file.
        template <typename T>
        void read_data(std::ifstream& file, blt::size_t total_size)
        {
            auto& array = std::get<mk_v<T>>(data);
            array.resize(total_size);
            file.read(reinterpret_cast<char*>(array.data()), static_cast<std::streamsize>(total_size) * sizeof(T));
        }

        /// Applies blt::mem::reverse to every element of the active vector.
        template <typename T>
        void reverse_data()
        {
            auto& array = std::get<mk_v<T>>(data);
            for (auto& v : array)
                blt::mem::reverse(v);
        }

        std::vector<blt::u32> dimensions;
        vec_t data;
    };

    /**
     * Image / label pairs built from an image idx file (u8 elements, dimensions [samples, rows, columns])
     * and a label idx file (u8 elements, one per sample).
     */
    class image_t
    {
    public:
        static constexpr blt::u32 target_size = 10;
        using data_iterator = std::vector<dlib::matrix<blt::u8>>::const_iterator;
        using label_iterator = std::vector<blt::u64>::const_iterator;

        /**
         * Copies every image into a dlib matrix and widens every label to u64.
         * Asserts that both files contain the same number of samples.
         */
        image_t(const idx_file_t& image_data, const idx_file_t& label_data): samples(image_data.get_dimensions()[0]),
                                                                             input_size(image_data.data_size(1))
        {
            BLT_ASSERT_MSG(samples == label_data.get_dimensions()[0],
                           ("Mismatch in data sample sizes! " + std::to_string(samples) + " vs " + std::to_string(label_data.get_dimensions()[0])).
                           c_str());
            auto& image_array = image_data.get_data_as<blt::u8>();
            auto& label_array = label_data.get_data_as<blt::u8>();

            image_labels.assign(label_array.begin(), label_array.end());

            const auto row_length = image_data.get_dimensions()[2];
            const auto number_of_rows = image_data.get_dimensions()[1];

            data.reserve(samples);
            for (blt::u32 i = 0; i < samples; i++)
            {
                // Pixels are stored as mat(x, y), i.e. matrix row = image column. The matrix is sized to match
                // that indexing so non-square images stay in bounds; for square images the layout is unchanged,
                // which keeps previously serialized networks valid.
                dlib::matrix<blt::u8> mat(row_length, number_of_rows);
                const auto image_offset = static_cast<blt::size_t>(i) * input_size;
                for (blt::u32 y = 0; y < number_of_rows; y++)
                {
                    const auto row_offset = image_offset + static_cast<blt::size_t>(y) * row_length;
                    for (blt::u32 x = 0; x < row_length; x++)
                        mat(x, y) = image_array[row_offset + x];
                }
                data.push_back(std::move(mat));
            }
        }

        [[nodiscard]] const std::vector<dlib::matrix<blt::u8>>& get_image_data() const
        {
            return data;
        }

        [[nodiscard]] const std::vector<blt::u64>& get_image_labels() const
        {
            return image_labels;
        }

    private:
        blt::u32 samples;
        blt::u32 input_size;
        std::vector<dlib::matrix<blt::u8>> data;
        std::vector<blt::u64> image_labels;
    };

    /// Correct / incorrect prediction counts; serialized as "hits,misses".
    struct batch_stats_t
    {
        blt::u64 hits = 0;
        blt::u64 misses = 0;

        friend std::ofstream& operator<<(std::ofstream& file, const batch_stats_t& stats)
        {
            file << stats.hits << ',' << stats.misses;
            return file;
        }

        friend std::ifstream& operator>>(std::ifstream& file, batch_stats_t& stats)
        {
            file >> stats.hits;
            file.ignore(); // ','
            file >> stats.misses;
            return file;
        }

        batch_stats_t& operator+=(const batch_stats_t& stats)
        {
            hits += stats.hits;
            misses += stats.misses;
            return *this;
        }

        /// Integer division of both counters; the divisor must not be zero.
        batch_stats_t& operator/=(const blt::u64 divisor)
        {
            hits /= divisor;
            misses /= divisor;
            return *this;
        }
    };

    /// Results recorded after one training epoch; serialized as "hits,misses,average_loss,learn_rate".
    struct epoch_stats_t
    {
        batch_stats_t test_results{};
        double average_loss = 0;
        double learn_rate = 0;

        friend std::ofstream& operator<<(std::ofstream& file, const epoch_stats_t& stats)
        {
            file << stats.test_results << ',' << stats.average_loss << ',' << stats.learn_rate;
            return file;
        }

        friend std::ifstream& operator>>(std::ifstream& file, epoch_stats_t& stats)
        {
            file >> stats.test_results;
            file.ignore(); // ','
            file >> stats.average_loss;
            file.ignore(); // ','
            file >> stats.learn_rate;
            return file;
        }

        epoch_stats_t& operator+=(const epoch_stats_t& stats)
        {
            test_results += stats.test_results;
            average_loss += stats.average_loss;
            learn_rate += stats.learn_rate;
            return *this;
        }

        epoch_stats_t& operator/=(const blt::u64 divisor)
        {
            test_results /= divisor;
            average_loss /= static_cast<double>(divisor);
            learn_rate /= static_cast<double>(divisor);
            return *this;
        }
    };

    /// Per-epoch results of one training run; serialized as a count line followed by one line per epoch.
    struct network_stats_t
    {
        std::vector<epoch_stats_t> epoch_stats;

        friend std::ofstream& operator<<(std::ofstream& file, const network_stats_t& stats)
        {
            file << stats.epoch_stats.size();
            file << '\n';
            for (const auto& v : stats.epoch_stats)
                file << v << "\n";
            return file;
        }

        friend std::ifstream& operator>>(std::ifstream& file, network_stats_t& stats)
        {
            blt::size_t size = 0;
            file >> size;
            file.ignore(); // '\n'
            for (blt::size_t i = 0; i < size; i++)
            {
                stats.epoch_stats.emplace_back();
                file >> stats.epoch_stats.back();
                file.ignore(); // '\n'
            }
            return file;
        }
    };

    /// Collection of runs; serialized as a count line followed by each run and a "---" separator line.
    struct network_average_stats_t
    {
        std::vector<network_stats_t> run_stats;

        network_average_stats_t& operator+=(const network_stats_t& stats)
        {
            run_stats.push_back(stats);
            return *this;
        }

        /**
         * Average number of recorded epochs per run (integer division), or 0 when no run was recorded.
         * The sum of the epoch counts is divided by the number of runs so the value is an actual average.
         */
        [[nodiscard]] blt::size_t average_size() const
        {
            if (run_stats.empty())
                return 0;
            blt::size_t acc = 0;
            for (const auto& [epoch_stats] : run_stats)
                acc += epoch_stats.size();
            return acc / run_stats.size();
        }

        /**
         * Sums the runs epoch by epoch and divides every epoch by the total number of runs.
         * NOTE(review): runs that ended before a given epoch contribute nothing to its sum but are still
         * counted in the divisor - confirm this is the intended averaging.
         */
        [[nodiscard]] network_stats_t average_stats() const
        {
            network_stats_t stats;
            for (const auto& [epoch_stats] : run_stats)
            {
                if (stats.epoch_stats.size() < epoch_stats.size())
                    stats.epoch_stats.resize(epoch_stats.size());
                for (const auto& [i, v] : blt::enumerate(epoch_stats))
                {
                    stats.epoch_stats[i] += v;
                }
            }
            for (auto& v : stats.epoch_stats)
                v /= run_stats.size();
            return stats;
        }

        friend std::ofstream& operator<<(std::ofstream& file, const network_average_stats_t& stats)
        {
            file << stats.run_stats.size();
            file << '\n';
            for (const auto& v : stats.run_stats)
                file << v << "---\n";
            return file;
        }

        friend std::ifstream& operator>>(std::ifstream& file, network_average_stats_t& stats)
        {
            blt::size_t size = 0;
            file >> size;
            file.ignore(); // '\n'
            for (blt::size_t i = 0; i < size; i++)
            {
                stats.run_stats.emplace_back();
                file >> stats.run_stats.back();
                file.ignore(4); // "---\n"
            }
            return file;
        }
    };

    /**
     * Evaluates `network` on [begin, end) in chunks of at most BatchSize samples and compares each
     * prediction with the label at the matching position starting from `lbegin`.
     *
     * @return number of matching (hits) and non-matching (misses) predictions
     */
    template <blt::i64 BatchSize = batch_size, typename NetworkType>
    batch_stats_t test_batch(NetworkType& network, image_t::data_iterator begin, const image_t::data_iterator end, image_t::label_iterator lbegin)
    {
        batch_stats_t stats{};

        std::array<image_t::label_iterator::value_type, BatchSize> output_labels{};

        auto amount_remaining = std::distance(begin, end);

        while (amount_remaining != 0)
        {
            const auto batch = std::min<blt::i64>(amount_remaining, BatchSize);
            network(begin, begin + batch, output_labels.begin());

            for (auto [predicted, expected] : blt::iterate(output_labels.begin(), output_labels.begin() + batch).zip(lbegin, lbegin + batch))
            {
                if (predicted == expected)
                    ++stats.hits;
                else
                    ++stats.misses;
            }

            begin += batch;
            lbegin += batch;
            amount_remaining -= batch;
        }

        return stats;
    }

    /// Loads the t10k image / label files relative to binary_directory, evaluates `network` on them and logs the result.
    template <typename NetworkType>
    batch_stats_t test_network(NetworkType& network)
    {
        const idx_file_t test_images{binary_directory + "../problems/mnist/t10k-images.idx3-ubyte"};
        const idx_file_t test_labels{binary_directory + "../problems/mnist/t10k-labels.idx1-ubyte"};
        const image_t test_image{test_images, test_labels};

        auto test_results = test_batch(network, test_image.get_image_data().begin(), test_image.get_image_data().end(),
                                       test_image.get_image_labels().begin());

        BLT_DEBUG("Testing hits: %lu", test_results.hits);
        BLT_DEBUG("Testing misses: %lu", test_results.misses);
        BLT_DEBUG("Testing accuracy: %lf", test_results.hits / static_cast<double>(test_results.hits + test_results.misses));

        return test_results;
    }

    /**
     * Trains `network` on the train-images / train-labels files, one mini-batch at a time, recording the
     * accuracy on the training data, the average loss and the learning rate after every epoch.
     *
     * Signal driven behaviour: learn_flag divides the learning rate by 10 once per request; break_flag ends
     * training after the current epoch and stores that epoch index in last_epoch.
     * The final network is serialized to "mnist_network_<ident>.dat".
     *
     * @param ident   suffix used for the synchronization file and the serialized network
     * @param network network to train in place
     * @return the per-epoch statistics collected during this call
     */
    template <typename NetworkType>
    network_stats_t train_network(const std::string& ident, NetworkType& network)
    {
        const idx_file_t train_images{binary_directory + "../problems/mnist/train-images.idx3-ubyte"};
        const idx_file_t train_labels{binary_directory + "../problems/mnist/train-labels.idx1-ubyte"};
        const image_t train_image{train_images, train_labels};

        const auto& data = train_image.get_image_data();
        const auto& labels = train_image.get_image_labels();

        network_stats_t stats;

        dlib::dnn_trainer trainer(network);
        trainer.set_learning_rate(0.01);
        trainer.set_min_learning_rate(0.00001);
        trainer.set_mini_batch_size(batch_size);
        trainer.set_max_num_epochs(100);
        trainer.set_iterations_without_progress_threshold(2000);
        trainer.be_verbose();

        trainer.set_synchronization_file("mnist_sync_" + ident, std::chrono::seconds(20));

        // continue counting from the epoch at which a previous training was interrupted, if any
        blt::size_t epochs = 0;
        if (last_epoch > 0)
            epochs = static_cast<blt::size_t>(last_epoch);

        for (; epochs < trainer.get_max_num_epochs() && trainer.get_learning_rate() >= trainer.get_min_learning_rate(); epochs++)
        {
            for (blt::size_t epoch_pos = 0;
                 epoch_pos < data.size() && trainer.get_learning_rate() >= trainer.get_min_learning_rate();
                 epoch_pos += trainer.get_mini_batch_size())
            {
                const auto begin = static_cast<blt::ptrdiff_t>(epoch_pos);
                const auto end = static_cast<blt::ptrdiff_t>(std::min<blt::size_t>(epoch_pos + trainer.get_mini_batch_size(), data.size()));

                if (end - begin <= 0)
                    break;

                // exchange() consumes the request so that one signal lowers the rate exactly once instead of
                // on every following mini-batch
                if (learn_flag.exchange(false))
                    trainer.set_learning_rate(trainer.get_learning_rate() / 10);

                trainer.train_one_step(data.begin() + begin, data.begin() + end, labels.begin() + begin);
            }

            BLT_TRACE("Trained an epoch (%ld/%ld) learn rate %lf average loss %lf", epochs, trainer.get_max_num_epochs(),
                      trainer.get_learning_rate(), trainer.get_average_loss());

            // sync and test
            trainer.get_net(dlib::force_flush_to_disk::no);
            network.clean();

            epoch_stats_t epoch_stats{};
            epoch_stats.test_results = test_batch(network, data.begin(), data.end(), labels.begin());
            epoch_stats.average_loss = trainer.get_average_loss();
            epoch_stats.learn_rate = trainer.get_learning_rate();

            BLT_TRACE("\t\tHits: %lu\tMisses: %lu\tAccuracy: %lf", epoch_stats.test_results.hits, epoch_stats.test_results.misses,
                      epoch_stats.test_results.hits / static_cast<double>(epoch_stats.test_results.hits + epoch_stats.test_results.misses));

            stats.epoch_stats.push_back(epoch_stats);
            network.clean();

            if (break_flag)
            {
                break_flag = false;
                last_epoch = static_cast<blt::i64>(epochs);
                break;
            }
            // dlib::serialize("mnist_network_" + ident + ".dat") << network;
        }

        BLT_INFO("Finished Training");

        // sync
        trainer.get_net();
        network.clean();

        // trainer.train(train_image.get_image_data(), train_image.get_image_labels());

        dlib::serialize("mnist_network_" + ident + ".dat") << network;

        auto test_results = test_batch(network, data.begin(), data.end(), labels.begin());

        BLT_DEBUG("Training hits: %lu", test_results.hits);
        BLT_DEBUG("Training misses: %lu", test_results.misses);
        BLT_DEBUG("Training accuracy: %lf", test_results.hits / static_cast<double>(test_results.hits + test_results.misses));

        return stats;
    }

    /// Deserializes a network previously written to "mnist_network_<ident>.dat".
    template <typename NetworkType>
    NetworkType load_network(const std::string& ident)
    {
        NetworkType network{};
        dlib::deserialize("mnist_network_" + ident + ".dat") >> network;
        return network;
    }

    /**
     * Trains (or, with `restore`, tries to load) `runs` networks of type NetworkType inside "<path>/<ident>/"
     * and evaluates each of them with test_network.
     *
     * Progress (next run index, last interrupted epoch, training stats, test stats) is read from and written
     * to "state.bin" in that directory. The process working directory is changed to that directory.
     *
     * @param path    parent directory; "/<ident>/" is appended
     * @param ident   name of the sub directory and prefix of the per-run identifiers
     * @param runs    run index to stop at
     * @param restore when true, load "mnist_network_<ident><run>.dat" and only train if deserialization fails
     * @return the collected training statistics and the test results averaged over all evaluated runs
     */
    template <typename NetworkType>
    std::pair<network_average_stats_t, batch_stats_t> run_network_tests(std::string path, const std::string& ident, const blt::i32 runs,
                                                                        const bool restore)
    {
        path += ("/" + ident + "/");
        std::filesystem::create_directories(path);
        std::filesystem::current_path(path);

        network_average_stats_t stats{};
        std::vector<batch_stats_t> test_stats;
        blt::i32 i = 0;

        const std::string state_path = path + "/state.bin";

        if (std::filesystem::exists(state_path))
        {
            std::ifstream state{state_path, std::ios::binary | std::ios::in};
            if (!state.is_open())
            {
                BLT_ERROR("Failed to open state file!");
                std::exit(-1);
            }
            // same field order as the writer at the end of this function
            state >> i;
            state.ignore();
            blt::i64 load_epoch = 0;
            state >> load_epoch;
            state.ignore();
            last_epoch = load_epoch;
            state >> stats;
            state.ignore();
            blt::size_t test_stats_size = 0;
            state >> test_stats_size;
            state.ignore();
            for (blt::size_t entry = 0; entry < test_stats_size; entry++)
            {
                test_stats.emplace_back();
                state >> test_stats.back();
                state.ignore();
            }
        }

        blt::i64 last_epoch_save = -1;
        for (; i < runs; i++)
        {
            if (stop_flag)
            {
                BLT_TRACE("Stopping!");
                break;
            }
            BLT_TRACE("Starting run %d", i);
            auto local_ident = ident + std::to_string(i);
            NetworkType network{};
            if (restore)
            {
                try
                {
                    network = load_network<NetworkType>(local_ident);
                } catch (const dlib::serialization_error&)
                {
                    // nothing usable on disk for this run, train it instead
                    stats += train_network(local_ident, network);
                }
            } else
                stats += train_network(local_ident, network);
            last_epoch_save = last_epoch;
            last_epoch = -1;
            test_stats.push_back(test_network(network));
        }

        // Average over the runs that were actually evaluated: this can differ from `runs` after an early stop
        // or a restored state, and guarding the empty case avoids dividing by zero.
        batch_stats_t average;
        for (const auto& v : test_stats)
            average += v;
        if (!test_stats.empty())
            average /= test_stats.size();

        std::ofstream state{state_path, std::ios::binary | std::ios::out};
        if (!state.is_open())
        {
            BLT_ERROR("Failed to open state file!");
            std::exit(-1);
        }
        state << i;
        state << '\n';
        state << last_epoch_save;
        state << '\n';
        state << stats;
        state << '\n';
        state << test_stats.size();
        state << '\n';
        for (const auto& v : test_stats)
        {
            state << v;
            state << '\n';
        }

        return {stats, average};
    }

    /// Runs run_network_tests with the convolutional network, using the "deep_learning" sub directory.
    auto run_deep_learning_tests(const std::string& path, const blt::i32 runs, const bool restore)
    {
        using namespace dlib;
        using net_type_dl = loss_multiclass_log<
            fc<10,
               relu<fc<84,
                       relu<fc<120,
                               max_pool<2, 2, 2, 2,
                                        relu<con<16, 5, 5, 1, 1,
                                                 max_pool<2, 2, 2, 2,
                                                          relu<con<6, 5, 5, 1, 1,
                                                                   input<matrix<blt::u8>>>>>>>>>>>>>>;
        BLT_TRACE("Running deep learning tests");
        return run_network_tests<net_type_dl>(path, "deep_learning", runs, restore);
    }

    /// Runs run_network_tests with the fully connected network, using the "feed_forward" sub directory.
    auto run_feed_forward_tests(const std::string& path, const blt::i32 runs, const bool restore)
    {
        using namespace dlib;
        using net_type_ff = loss_multiclass_log<
            fc<10,
               relu<fc<84,
                       relu<fc<120,
                               input<matrix<blt::u8>>>>>>>>;
        BLT_TRACE("Running feed forward tests");
        return run_network_tests<net_type_ff>(path, "feed_forward", runs, restore);
    }

    /**
     * Program entry: resolves the working paths, installs the signal handlers, parses the command line,
     * runs the deep learning and feed forward tests and writes the result table, the loss CSV files and the
     * graph into "<binary_directory><network>".
     */
    void run_mnist(const int argc, const char** argv)
    {
        binary_directory = std::filesystem::current_path();
        // `pos` is the index of the '/' that starts the last path component, so substr(0, pos) is the parent
        blt::size_t pos = 0;
        if (!blt::string::ends_with(binary_directory, '/'))
        {
            pos = binary_directory.find_last_of('/');
            binary_directory += '/';
        } else
            pos = binary_directory.substr(0, binary_directory.size() - 1).find_last_of('/');
        python_dual_stacked_graph_program = binary_directory.substr(0, pos) + "/graph.py";
        BLT_DEBUG(binary_directory);
        BLT_DEBUG(python_dual_stacked_graph_program);
        BLT_DEBUG("Running with batch size %d", batch_size);

        BLT_DEBUG("Installing Signal Handlers");
        // NOTE(review): the handlers log through BLT_INFO - confirm that is safe to call from a signal handler.
        if (std::signal(SIGINT, [](int) {
            BLT_INFO("Stopping current training");
            break_flag = true;
        }) == SIG_ERR)
        {
            BLT_ERROR("Failed to replace SIGINT");
        }
        if (std::signal(SIGQUIT, [](int) {
            BLT_INFO("Exiting Program");
            stop_flag = true;
            break_flag = true;
        }) == SIG_ERR)
        {
            BLT_ERROR("Failed to replace SIGQUIT");
        }
        if (std::signal(SIGUSR1, [](int) {
            BLT_INFO("Decreasing Learn Rate for current training");
            learn_flag = true;
        }) == SIG_ERR)
        {
            BLT_ERROR("Failed to replace SIGUSR1");
        }
        if (std::signal(SIGUSR2, [](int) {
            BLT_INFO("Exiting Program");
            stop_flag = true;
            break_flag = true;
        }) == SIG_ERR)
        {
            BLT_ERROR("Failed to replace SIGUSR2");
        }

        using namespace dlib;

        blt::arg_parse parser{};
        parser.addArgument(
            blt::arg_builder{"-r", "--restore"}.setAction(blt::arg_action_t::STORE_TRUE).setDefault(false).setHelp(
                "Restores from last save").build());
        parser.addArgument(blt::arg_builder{"-t", "--runs"}.setHelp("Number of runs to perform [default: 10]").setDefault("10").build());
        // NOTE(review): the value of --python is never read in this function.
        parser.addArgument(
            blt::arg_builder{"-p", "--python"}.setHelp("Only run the python scripts").setAction(blt::arg_action_t::STORE_TRUE).setDefault(false).
                                               build());
        parser.addArgument(
            blt::arg_builder{"network"}.setDefault(std::to_string(blt::system::getCurrentTimeMilliseconds())).setHelp("location of network files").
                                        build());

        auto args = parser.parse_args(argc, argv);

        const auto runs = std::stoi(args.get<std::string>("runs"));
        const auto restore = args.get<bool>("restore");
        auto path = binary_directory + args.get<std::string>("network");

        auto [deep_stats, deep_tests] = run_deep_learning_tests(path, runs, restore);
        auto [forward_stats, forward_tests] = run_feed_forward_tests(path, runs, restore);

        auto average_forward_size = forward_stats.average_size();
        auto average_deep_size = deep_stats.average_size();

        // Scoped so that every output stream is destroyed (and therefore flushed and closed) before the
        // graph script is started below.
        {
            std::ofstream test_results_f{path + "/test_results_table.txt"};
            test_results_f << "\\begin{figure}" << '\n';
            test_results_f << "\t\\begin{tabular}{|c|c|c|c|}" << '\n';
            test_results_f << "\t\t\\hline" << '\n';
            test_results_f << "\t\tTest & Correct & Incorrect & Accuracy (\\%) \\\\" << '\n';
            test_results_f << "\t\t\\hline" << '\n';
            // std::fixed makes setprecision(2) mean two decimal places instead of two significant digits
            auto test_accuracy = forward_tests.hits / static_cast<double>(forward_tests.hits + forward_tests.misses) * 100;
            test_results_f << "\t\tFeed-Forward & " << forward_tests.hits << " & " << forward_tests.misses << " & " << std::fixed
                << std::setprecision(2) << test_accuracy << "\\\\" << '\n';
            test_accuracy = deep_tests.hits / static_cast<double>(deep_tests.hits + deep_tests.misses) * 100;
            test_results_f << "\t\tDeep Learning & " << deep_tests.hits << " & " << deep_tests.misses << " & " << std::fixed
                << std::setprecision(2) << test_accuracy << "\\\\" << '\n';
            test_results_f << "\t\\end{tabular}" << '\n';
            test_results_f << "\\end{figure}" << '\n';

            const auto [forward_epoch_stats] = forward_stats.average_stats();
            std::ofstream train_forward{path + "/forward_train_results.csv"};
            train_forward << "Epoch,Loss" << '\n';
            for (const auto& [i, v] : blt::enumerate(forward_epoch_stats))
                train_forward << i << ',' << v.average_loss << '\n';

            const auto [deep_epoch_stats] = deep_stats.average_stats();
            std::ofstream train_deep{path + "/deep_train_results.csv"};
            train_deep << "Epoch,Loss" << '\n';
            for (const auto& [i, v] : blt::enumerate(deep_epoch_stats))
                train_deep << i << ',' << v.average_loss << '\n';

            std::ofstream average_epochs{path + "/average_epochs.txt"};
            average_epochs << average_forward_size << "," << average_deep_size << '\n';
        }

        run_python_line_graph("Feed-Forward vs Deep Learning, Average Loss over Epochs", "epochs.png", path + "/forward_train_results.csv",
                              path + "/deep_train_results.csv", average_forward_size, average_deep_size);

        // net_type_dl test_net;
        // const auto stats = train_network("dl_nn", test_net);
        // std::ofstream out_file{"dl_nn.csv"};
        // out_file << stats;

        // test_net = load_network<net_type_dl>("dl_nn");
        // test_network(test_net);
    }
}