diff --git a/.idea/workspace (conflicted copy 2025-01-08 173718).xml b/.idea/workspace (conflicted copy 2025-01-08 173718).xml
new file mode 100644
index 0000000..d6ba888
--- /dev/null
+++ b/.idea/workspace (conflicted copy 2025-01-08 173718).xml
@@ -0,0 +1,242 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ {
+ "useNewFormat": true
+}
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ {
+ "associatedIndex": 0
+}
+
+
+
+
+
+ {
+ "keyToString": {
+ "CMake Application.COSC-4P80-Final-Project.executor": "Run",
+ "RunOnceActivity.RadMigrateCodeStyle": "true",
+ "RunOnceActivity.ShowReadmeOnStart": "true",
+ "RunOnceActivity.cidr.known.project.marker": "true",
+ "RunOnceActivity.readMode.enableVisualFormatting": "true",
+ "RunOnceActivity.west.config.association.type.startup.service": "true",
+ "SHARE_PROJECT_CONFIGURATION_FILES": "true",
+ "cf.first.check.clang-format": "false",
+ "cidr.known.project.marker": "true",
+ "git-widget-placeholder": "main",
+ "last_opened_file_path": "/home/brett/Documents/Brock/CS 4P80/COSC-4P80-Final-Project",
+ "node.js.detected.package.eslint": "true",
+ "node.js.detected.package.tslint": "true",
+ "node.js.selected.package.eslint": "(autodetect)",
+ "node.js.selected.package.tslint": "(autodetect)",
+ "nodejs_package_manager_path": "npm",
+ "settings.editor.selected.configurable": "preferences.lookFeel",
+ "vue.rearranger.settings.migration": "true"
+ }
+}
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ PDFLATEX
+
+
+ OKULAR
+
+
+
+
+
+ {projectDir}/out
+ {projectDir}/auxil
+ false
+ PDF
+ TEXLIVE
+ false
+ []
+ []
+
+
+
+
+
+ MAKEINDEX
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ 1733702642308
+
+
+ 1733702642308
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 43b1b91..e68ab05 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -1,5 +1,5 @@
cmake_minimum_required(VERSION 3.25)
-project(COSC-4P80-Final-Project VERSION 0.0.9)
+project(COSC-4P80-Final-Project VERSION 0.0.10)
option(ENABLE_ADDRSAN "Enable the address sanitizer" OFF)
option(ENABLE_UBSAN "Enable the ub sanitizer" OFF)
diff --git a/src/MNIST (conflicted copy 2025-01-08 173718).cpp b/src/MNIST (conflicted copy 2025-01-08 173718).cpp
new file mode 100644
index 0000000..5c6a010
--- /dev/null
+++ b/src/MNIST (conflicted copy 2025-01-08 173718).cpp
@@ -0,0 +1,664 @@
+/*
+ *
+ * Copyright (C) 2024 Brett Terpstra
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program. If not, see .
+ */
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+namespace fp
+{
+ constexpr blt::i64 batch_size = 256;
+
+ std::string binary_directory;
+ std::string python_dual_stacked_graph_program;
+ std::atomic_bool break_flag = false;
+ std::atomic_bool stop_flag = false;
+
+ void run_python_line_graph(const std::string& title, const std::string& output_file, const std::string& csv1, const std::string& csv2,
+ const blt::size_t pos_forward, const blt::size_t pos_deep)
+ {
+ const auto command = "python3 " + python_dual_stacked_graph_program + " '" + title + "' '" + output_file + "' '" + csv1 + "' '" + csv2 + "' "
+ + std::to_string(pos_forward) + " " + std::to_string(pos_deep);
+ BLT_TRACE("Running %s", command.c_str());
+ std::system(command.c_str());
+ }
+
+ class idx_file_t
+ {
+ template
+ using mk_v = std::vector;
+ using vec_t = std::variant, mk_v, mk_v, mk_v, mk_v, mk_v>;
+
+ public:
+ explicit idx_file_t(const std::string& path)
+ {
+ std::ifstream file{path, std::ios::in | std::ios::binary};
+
+ using char_type = std::ifstream::char_type;
+ char_type magic_arr[4];
+ file.read(magic_arr, 4);
+ BLT_ASSERT(magic_arr[0] == 0 && magic_arr[1] == 0);
+
+ blt::u8 dims = magic_arr[3];
+ blt::size_t total_size = 1;
+
+ for (blt::i32 i = 0; i < dims; i++)
+ {
+ char_type dim_arr[4];
+ file.read(dim_arr, 4);
+ blt::u32 dim;
+ blt::mem::fromBytes(dim_arr, dim);
+ dimensions.push_back(dim);
+ total_size *= dim;
+ }
+
+ switch (magic_arr[2])
+ {
+ // unsigned char
+ case 0x08:
+ data = mk_v{};
+ read_data(file, total_size);
+ break;
+ // signed char
+ case 0x09:
+ data = mk_v{};
+ read_data(file, total_size);
+ break;
+ // short
+ case 0x0B:
+ data = mk_v{};
+ read_data(file, total_size);
+ reverse_data();
+ break;
+ // int
+ case 0x0C:
+ data = mk_v{};
+ read_data(file, total_size);
+ reverse_data();
+ break;
+ // float
+ case 0x0D:
+ data = mk_v{};
+ read_data(file, total_size);
+ reverse_data();
+ break;
+ // double
+ case 0x0E:
+ data = mk_v{};
+ read_data(file, total_size);
+ reverse_data();
+ break;
+ default:
+ BLT_ERROR("Unspported idx file type!");
+ }
+ if (file.eof())
+ {
+ BLT_ERROR("EOF reached. It's unlikely your file was read correctly!");
+ }
+ }
+
+ template
+ [[nodiscard]] const std::vector& get_data_as() const
+ {
+ return std::get>(data);
+ }
+
+ template
+ std::vector> get_as_spans() const
+ {
+ std::vector> spans;
+
+ blt::size_t total_size = data_size(1);
+
+ for (blt::size_t i = 0; i < dimensions[0]; i++)
+ {
+ auto& array = std::get>(data);
+ spans.push_back({&array[i * total_size], total_size});
+ }
+
+ return spans;
+ }
+
+ [[nodiscard]] const std::vector& get_dimensions() const
+ {
+ return dimensions;
+ }
+
+ [[nodiscard]] blt::size_t data_size(const blt::size_t starting_dimension = 0) const
+ {
+ blt::size_t total_size = 1;
+ for (const auto d : blt::iterate(dimensions).skip(starting_dimension))
+ total_size *= d;
+ return total_size;
+ }
+
+ private:
+ template
+ void read_data(std::ifstream& file, blt::size_t total_size)
+ {
+ auto& array = std::get>(data);
+ array.resize(total_size);
+ file.read(reinterpret_cast(array.data()), static_cast(total_size) * sizeof(T));
+ }
+
+ template
+ void reverse_data()
+ {
+ auto& array = std::get>(data);
+ for (auto& v : array)
+ blt::mem::reverse(v);
+ }
+
+ std::vector dimensions;
+ vec_t data;
+ };
+
+ class image_t
+ {
+ public:
+ static constexpr blt::u32 target_size = 10;
+ using data_iterator = std::vector>::const_iterator;
+ using label_iterator = std::vector::const_iterator;
+
+ image_t(const idx_file_t& image_data, const idx_file_t& label_data): samples(image_data.get_dimensions()[0]),
+ input_size(image_data.data_size(1))
+ {
+ BLT_ASSERT_MSG(samples == label_data.get_dimensions()[0],
+ ("Mismatch in data sample sizes! " + std::to_string(samples) + " vs " + std::to_string(label_data.get_dimensions()[0])).
+ c_str());
+ auto& image_array = image_data.get_data_as();
+ auto& label_array = label_data.get_data_as();
+
+ for (const auto label : label_array)
+ image_labels.push_back(label);
+
+ const auto row_length = image_data.get_dimensions()[2];
+ const auto number_of_rows = image_data.get_dimensions()[1];
+
+ for (blt::u32 i = 0; i < samples; i++)
+ {
+ dlib::matrix mat(number_of_rows, row_length);
+ for (blt::u32 y = 0; y < number_of_rows; y++)
+ {
+ for (blt::u32 x = 0; x < row_length; x++)
+ {
+ mat(x, y) = image_array[i * input_size + y * row_length + x];
+ }
+ }
+ data.push_back(mat);
+ }
+ }
+
+ [[nodiscard]] const std::vector>& get_image_data() const
+ {
+ return data;
+ }
+
+ [[nodiscard]] const std::vector& get_image_labels() const
+ {
+ return image_labels;
+ }
+
+ private:
+ blt::u32 samples;
+ blt::u32 input_size;
+ std::vector> data;
+ std::vector image_labels;
+ };
+
+ struct batch_stats_t
+ {
+ blt::u64 hits = 0;
+ blt::u64 misses = 0;
+
+ friend std::ofstream& operator<<(std::ofstream& file, const batch_stats_t& stats)
+ {
+ file << stats.hits << ',' << stats.misses;
+ return file;
+ }
+
+ friend std::ifstream& operator>>(std::ifstream& file, batch_stats_t& stats)
+ {
+ file >> stats.hits;
+ file.ignore();
+ file >> stats.misses;
+ return file;
+ }
+
+ batch_stats_t& operator+=(const batch_stats_t& stats)
+ {
+ hits += stats.hits;
+ misses += stats.misses;
+ return *this;
+ }
+
+ batch_stats_t& operator/=(const blt::u64 divisor)
+ {
+ hits /= divisor;
+ misses /= divisor;
+ return *this;
+ }
+ };
+
+ struct epoch_stats_t
+ {
+ batch_stats_t test_results{};
+ double average_loss = 0;
+ double learn_rate = 0;
+
+ friend std::ofstream& operator<<(std::ofstream& file, const epoch_stats_t& stats)
+ {
+ file << stats.test_results << ',' << stats.average_loss << ',' << stats.learn_rate;
+ return file;
+ }
+
+ friend std::ifstream& operator>>(std::ifstream& file, epoch_stats_t& stats)
+ {
+ file >> stats.test_results;
+ file.ignore();
+ file >> stats.average_loss;
+ file.ignore();
+ file >> stats.learn_rate;
+ return file;
+ }
+
+ epoch_stats_t& operator+=(const epoch_stats_t& stats)
+ {
+ test_results += stats.test_results;
+ average_loss += stats.average_loss;
+ learn_rate += stats.learn_rate;
+ return *this;
+ }
+
+ epoch_stats_t& operator/=(const blt::u64 divisor)
+ {
+ test_results /= divisor;
+ average_loss /= static_cast(divisor);
+ learn_rate /= static_cast(divisor);
+ return *this;
+ }
+ };
+
+ struct network_stats_t
+ {
+ std::vector epoch_stats;
+
+ friend std::ofstream& operator<<(std::ofstream& file, const network_stats_t& stats)
+ {
+ file << stats.epoch_stats.size();
+ for (const auto& v : stats.epoch_stats)
+ file << v << "\n";
+ return file;
+ }
+
+ friend std::ifstream& operator>>(std::ifstream& file, network_stats_t& stats)
+ {
+ blt::size_t size;
+ file >> size;
+ for (blt::size_t i = 0; i < size; i++)
+ {
+ stats.epoch_stats.emplace_back();
+ file >> stats.epoch_stats.back();
+ file.ignore();
+ }
+ return file;
+ }
+ };
+
+ struct network_average_stats_t
+ {
+ std::vector run_stats;
+
+ network_average_stats_t& operator+=(const network_stats_t& stats)
+ {
+ run_stats.push_back(stats);
+ return *this;
+ }
+
+ [[nodiscard]] blt::size_t average_size() const
+ {
+ blt::size_t acc = 0;
+ for (const auto& [epoch_stats] : run_stats)
+ acc += epoch_stats.size();
+ return acc;
+ }
+
+ [[nodiscard]] network_stats_t average_stats() const
+ {
+ network_stats_t stats;
+ for (const auto& [epoch_stats] : run_stats)
+ {
+ if (stats.epoch_stats.size() < epoch_stats.size())
+ stats.epoch_stats.resize(epoch_stats.size());
+ for (const auto& [i, v] : blt::enumerate(epoch_stats))
+ {
+ stats.epoch_stats[i] += v;
+ }
+ }
+ for (auto& v : stats.epoch_stats)
+ v /= run_stats.size();
+ return stats;
+ }
+ };
+
+ template
+ batch_stats_t test_batch(NetworkType& network, image_t::data_iterator begin, const image_t::data_iterator end, image_t::label_iterator lbegin)
+ {
+ batch_stats_t stats{};
+
+ std::array output_labels{};
+
+ auto amount_remaining = std::distance(begin, end);
+
+ while (amount_remaining != 0)
+ {
+ const auto batch = std::min(amount_remaining, batch_size);
+ network(begin, begin + batch, output_labels.begin());
+
+ for (auto [predicted, expected] : blt::iterate(output_labels.begin(), output_labels.begin() + batch).zip(lbegin, lbegin + batch))
+ {
+ if (predicted == expected)
+ ++stats.hits;
+ else
+ ++stats.misses;
+ }
+
+ begin += batch;
+ lbegin += batch;
+ amount_remaining -= batch;
+ }
+
+ return stats;
+ }
+
+ template
+ batch_stats_t test_network(NetworkType& network)
+ {
+ const idx_file_t test_images{binary_directory + "../problems/mnist/t10k-images.idx3-ubyte"};
+ const idx_file_t test_labels{binary_directory + "../problems/mnist/t10k-labels.idx1-ubyte"};
+
+ const image_t test_image{test_images, test_labels};
+
+ auto test_results = test_batch(network, test_image.get_image_data().begin(), test_image.get_image_data().end(),
+ test_image.get_image_labels().begin());
+
+ BLT_DEBUG("Testing hits: %lu", test_results.hits);
+ BLT_DEBUG("Testing misses: %lu", test_results.misses);
+ BLT_DEBUG("Testing accuracy: %lf", test_results.hits / static_cast(test_results.hits + test_results.misses));
+
+ return test_results;
+ }
+
+ template
+ network_stats_t train_network(const std::string& ident, NetworkType& network)
+ {
+ const idx_file_t train_images{binary_directory + "../problems/mnist/train-images.idx3-ubyte"};
+ const idx_file_t train_labels{binary_directory + "../problems/mnist/train-labels.idx1-ubyte"};
+
+ const image_t train_image{train_images, train_labels};
+
+ network_stats_t stats;
+
+ dlib::dnn_trainer trainer(network);
+ trainer.set_learning_rate(0.01);
+ trainer.set_min_learning_rate(0.00001);
+ trainer.set_mini_batch_size(batch_size);
+ trainer.set_max_num_epochs(100);
+ trainer.set_iterations_without_progress_threshold(300);
+ trainer.be_verbose();
+
+ trainer.set_synchronization_file("mnist_sync_" + ident, std::chrono::seconds(20));
+
+ blt::size_t epochs = 0;
+ blt::ptrdiff_t epoch_pos = 0;
+ for (; epochs < trainer.get_max_num_epochs() && trainer.get_learning_rate() >= trainer.get_min_learning_rate(); epochs++)
+ {
+ auto& data = train_image.get_image_data();
+ auto& labels = train_image.get_image_labels();
+ for (; epoch_pos < data.size() && trainer.get_learning_rate() >= trainer.get_min_learning_rate(); epoch_pos += trainer.
+ get_mini_batch_size())
+ {
+ auto begin = epoch_pos;
+ auto end = std::min(epoch_pos + trainer.get_mini_batch_size(), data.size());
+
+ if (end - begin <= 0)
+ break;
+
+ trainer.train_one_step(train_image.get_image_data().begin() + begin,
+ data.begin() + end, labels.begin() + begin);
+ }
+ epoch_pos = 0;
+ BLT_TRACE("Trained an epoch (%ld/%ld) learn rate %lf average loss %lf", epochs, trainer.get_max_num_epochs(),
+ trainer.get_learning_rate(), trainer.get_average_loss());
+
+ // sync and test
+ trainer.get_net(dlib::force_flush_to_disk::no);
+ network.clean();
+
+ epoch_stats_t epoch_stats{};
+ epoch_stats.test_results = test_batch(network, train_image.get_image_data().begin(), train_image.get_image_data().end(),
+ train_image.get_image_labels().begin());
+ epoch_stats.average_loss = trainer.get_average_loss();
+ epoch_stats.learn_rate = trainer.get_learning_rate();
+
+ BLT_TRACE("\t\tHits: %lu\tMisses: %lu\tAccuracy: %lf", epoch_stats.test_results.hits, epoch_stats.test_results.misses,
+ epoch_stats.test_results.hits / static_cast(epoch_stats.test_results.hits + epoch_stats.test_results.misses));
+
+ stats.epoch_stats.push_back(epoch_stats);
+ network.clean();
+ if (break_flag)
+ {
+ break_flag = false;
+ break;
+ }
+ // dlib::serialize("mnist_network_" + ident + ".dat") << network;
+ }
+
+ BLT_INFO("Finished Training");
+
+ // sync
+ trainer.get_net();
+ network.clean();
+
+ // trainer.train(train_image.get_image_data(), train_image.get_image_labels());
+ dlib::serialize("mnist_network_" + ident + ".dat") << network;
+
+ auto test_results = test_batch(network, train_image.get_image_data().begin(), train_image.get_image_data().end(),
+ train_image.get_image_labels().begin());
+
+ BLT_DEBUG("Training hits: %lu", test_results.hits);
+ BLT_DEBUG("Training misses: %lu", test_results.misses);
+ BLT_DEBUG("Training accuracy: %lf", test_results.hits / static_cast(test_results.hits + test_results.misses));
+
+ return stats;
+ }
+
+ template
+ NetworkType load_network(const std::string& ident)
+ {
+ NetworkType network{};
+ dlib::deserialize("mnist_network_" + ident + ".dat") >> network;
+ return network;
+ }
+
+ template
+ std::pair run_network_tests(std::string path, const std::string& ident, const blt::i32 runs,
+ const bool restore)
+ {
+ path += ("/" + ident + "/");
+ std::filesystem::create_directories(path);
+ std::filesystem::current_path(path);
+
+ network_average_stats_t stats{};
+ std::vector test_stats;
+
+ for (blt::i32 i = 0; i < runs; i++)
+ {
+ BLT_TRACE("Starting run %d", i);
+ auto local_ident = ident + std::to_string(i);
+ NetworkType network{};
+ if (restore)
+ try
+ {
+ network = load_network(local_ident);
+ }
+ catch (dlib::serialization_error&)
+ {
+ stats += train_network(local_ident, network);
+ }
+ else
+ stats += train_network(local_ident, network);
+ test_stats.push_back(test_network(network));
+ }
+
+ batch_stats_t average;
+ for (const auto& v : test_stats)
+ average += v;
+ average /= runs;
+
+ return {stats, average};
+ }
+
+ auto run_deep_learning_tests(const std::string& path, const blt::i32 runs, const bool restore)
+ {
+ using namespace dlib;
+ using net_type_dl = loss_multiclass_log<
+ fc<10,
+ relu>>>>>>>>>>>>>;
+ BLT_TRACE("Running deep learning tests");
+ return run_network_tests(path, "deep_learning", runs, restore);
+ }
+
+ auto run_feed_forward_tests(const std::string& path, const blt::i32 runs, const bool restore)
+ {
+ using namespace dlib;
+
+ using net_type_ff = loss_multiclass_log<
+ fc<10,
+ relu>>>>>>>;
+
+ BLT_TRACE("Running feed forward tests");
+ return run_network_tests(path, "feed_forward", runs, restore);
+ }
+
+ void run_mnist(const int argc, const char** argv)
+ {
+ binary_directory = std::filesystem::current_path();
+ if (!blt::string::ends_with(binary_directory, '/'))
+ binary_directory += '/';
+ python_dual_stacked_graph_program = binary_directory + "../graph.py";
+ BLT_TRACE(binary_directory);
+ BLT_TRACE(python_dual_stacked_graph_program);
+ BLT_TRACE("Running with batch size %d", batch_size);
+
+ using namespace dlib;
+
+ blt::arg_parse parser{};
+ parser.addArgument(
+ blt::arg_builder{"-r", "--restore"}.setAction(blt::arg_action_t::STORE_TRUE).setDefault(false).setHelp(
+ "Restores from last save").build());
+ parser.addArgument(blt::arg_builder{"-t", "--runs"}.setHelp("Number of runs to perform [default: 10]").setDefault("10").build());
+ parser.addArgument(
+ blt::arg_builder{"-p", "--python"}.setHelp("Only run the python scripts").setAction(blt::arg_action_t::STORE_TRUE).setDefault(false).
+ build());
+ parser.addArgument(blt::arg_builder{"type"}.setDefault("all").setHelp("Type of network to run [ff, dl, default: all]").build());
+
+ auto args = parser.parse_args(argc, argv);
+
+ const auto type = blt::string::toLowerCase(args.get("type"));
+ const auto runs = std::stoi(args.get("runs"));
+ const auto restore = args.get("restore");
+ const auto path = binary_directory + std::to_string(blt::system::getCurrentTimeMilliseconds());
+
+
+ if (type == "all")
+ {
+ auto [deep_stats, deep_tests] = run_deep_learning_tests(path, runs, restore);
+ auto [forward_stats, forward_tests] = run_feed_forward_tests(path, runs, restore);
+
+ auto average_forward_size = forward_stats.average_size();
+ auto average_deep_size = deep_stats.average_size();
+
+ {
+ std::ofstream test_results_f{path + "/test_results_table.txt"};
+ test_results_f << "\\begin{figure}" << std::endl;
+ test_results_f << "\t\\begin{tabular}{|c|c|c|c|}" << std::endl;
+ test_results_f << "\t\t\\hline" << std::endl;
+ test_results_f << "\t\tTest & Correct & Incorrect & Accuracy (\\%) \\\\" << std::endl;
+ test_results_f << "\t\t\\hline" << std::endl;
+ auto test_accuracy = forward_tests.hits / static_cast(forward_tests.hits + forward_tests.misses) * 100;
+ test_results_f << "\t\tFeed-Forward & " << forward_tests.hits << " & " << forward_tests.misses << " & " << std::setprecision(2) <<
+ test_accuracy << "\\\\" << std::endl;
+ test_accuracy = deep_tests.hits / static_cast(deep_tests.hits + deep_tests.misses) * 100;
+ test_results_f << "\t\tDeep Learning & " << deep_tests.hits << " & " << deep_tests.misses << " & " << std::setprecision(2) <<
+ test_accuracy << "\\\\" << std::endl;
+ test_results_f << "\t\\end{tabular}" << std::endl;
+ test_results_f << "\\end{figure}" << std::endl;
+
+ const auto [forward_epoch_stats] = forward_stats.average_stats();
+ std::ofstream train_forward{path + "/forward_train_results.csv"};
+ train_forward << "Epoch,Loss" << std::endl;
+ for (const auto& [i, v] : blt::enumerate(forward_epoch_stats))
+ train_forward << i << ',' << v.average_loss << std::endl;
+
+ const auto [deep_epoch_stats] = deep_stats.average_stats();
+ std::ofstream train_deep{path + "/deep_train_results.csv"};
+ train_deep << "Epoch,Loss" << std::endl;
+ for (const auto& [i, v] : blt::enumerate(deep_epoch_stats))
+ train_deep << i << ',' << v.average_loss << std::endl;
+
+ std::ofstream average_epochs{path + "/average_epochs.txt"};
+ average_epochs << average_forward_size << "," << average_deep_size << std::endl;
+ }
+
+ run_python_line_graph("Feed-Forward vs Deep Learning, Average Loss over Epochs", "epochs.png", path + "/forward_train_results.csv",
+ path + "/deep_train_results.csv", average_forward_size, average_deep_size);
+ }
+ else if (type == "ff")
+ {
+ run_feed_forward_tests(path, runs, restore);
+ }
+ else if (type == "df")
+ {
+ run_deep_learning_tests(path, runs, restore);
+ }
+
+ // net_type_dl test_net;
+ // const auto stats = train_network("dl_nn", test_net);
+ // std::ofstream out_file{"dl_nn.csv"};
+ // out_file << stats;
+
+ // test_net = load_network("dl_nn");
+
+ // test_network(test_net);
+ }
+}
diff --git a/src/MNIST.cpp b/src/MNIST.cpp
index ddff1b8..9e2b552 100644
--- a/src/MNIST.cpp
+++ b/src/MNIST.cpp
@@ -27,13 +27,17 @@
#include
#include
#include
+#include
namespace fp
{
- constexpr blt::i64 batch_size = 512;
+ constexpr blt::i64 batch_size = 256;
std::string binary_directory;
std::string python_dual_stacked_graph_program;
+ std::atomic_bool break_flag = false;
+ std::atomic_bool stop_flag = false;
+ std::atomic_bool learn_flag = false;
void run_python_line_graph(const std::string& title, const std::string& output_file, const std::string& csv1, const std::string& csv2,
const blt::size_t pos_forward, const blt::size_t pos_deep)
@@ -362,7 +366,7 @@ namespace fp
}
};
- template
+ template
batch_stats_t test_batch(NetworkType& network, image_t::data_iterator begin, const image_t::data_iterator end, image_t::label_iterator lbegin)
{
batch_stats_t stats{};
@@ -424,6 +428,8 @@ namespace fp
trainer.set_learning_rate(0.01);
trainer.set_min_learning_rate(0.00001);
trainer.set_mini_batch_size(batch_size);
+ trainer.set_max_num_epochs(100);
+ trainer.set_iterations_without_progress_threshold(2000);
trainer.be_verbose();
trainer.set_synchronization_file("mnist_sync_" + ident, std::chrono::seconds(20));
@@ -443,6 +449,9 @@ namespace fp
if (end - begin <= 0)
break;
+ if (learn_flag)
+ trainer.set_learning_rate(trainer.get_learning_rate() / 10);
+
trainer.train_one_step(train_image.get_image_data().begin() + begin,
data.begin() + end, labels.begin() + begin);
}
@@ -465,6 +474,11 @@ namespace fp
stats.epoch_stats.push_back(epoch_stats);
network.clean();
+ if (break_flag)
+ {
+ break_flag = false;
+ break;
+ }
// dlib::serialize("mnist_network_" + ident + ".dat") << network;
}
@@ -508,6 +522,8 @@ namespace fp
for (blt::i32 i = 0; i < runs; i++)
{
+ if (stop_flag)
+ break;
BLT_TRACE("Starting run %d", i);
auto local_ident = ident + std::to_string(i);
NetworkType network{};
@@ -571,6 +587,32 @@ namespace fp
BLT_TRACE(python_dual_stacked_graph_program);
BLT_TRACE("Running with batch size %d", batch_size);
+ BLT_TRACE("Installing Signal Handlers");
+ if (std::signal(SIGINT, [](int){
+ BLT_TRACE("Stopping current training");
+ break_flag = true;
+ }) == SIG_ERR)
+ {
+ BLT_ERROR("Failed to replace SIGINT");
+ }
+ if (std::signal(SIGQUIT, [](int)
+ {
+ BLT_TRACE("Exiting Program");
+ stop_flag = true;
+ break_flag = true;
+ }) == SIG_ERR)
+ {
+ BLT_ERROR("Failed to replace SIGQUIT");
+ }
+ if (std::signal(SIGUSR1, [](int)
+ {
+ BLT_TRACE("Decreasing Learn Rate for current training");
+ learn_flag = true;
+ }) == SIG_ERR)
+ {
+ BLT_ERROR("Failed to replace SIGUSR1");
+ }
+
using namespace dlib;
blt::arg_parse parser{};