diff --git a/.idea/editor.xml b/.idea/editor.xml
index 55d1bc1..b0d69ef 100644
--- a/.idea/editor.xml
+++ b/.idea/editor.xml
@@ -2,482 +2,482 @@
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/workspace (conflicted copy 2025-01-07 130058).xml b/.idea/workspace (conflicted copy 2025-01-07 130058).xml
new file mode 100644
index 0000000..df9d07f
--- /dev/null
+++ b/.idea/workspace (conflicted copy 2025-01-07 130058).xml
@@ -0,0 +1,225 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ {
+ "useNewFormat": true
+}
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ {
+ "associatedIndex": 0
+}
+
+
+
+
+
+ {
+ "keyToString": {
+ "CMake Application.COSC-4P80-Final-Project.executor": "Run",
+ "RunOnceActivity.RadMigrateCodeStyle": "true",
+ "RunOnceActivity.ShowReadmeOnStart": "true",
+ "RunOnceActivity.cidr.known.project.marker": "true",
+ "RunOnceActivity.readMode.enableVisualFormatting": "true",
+ "RunOnceActivity.west.config.association.type.startup.service": "true",
+ "SHARE_PROJECT_CONFIGURATION_FILES": "true",
+ "cf.first.check.clang-format": "false",
+ "cidr.known.project.marker": "true",
+ "git-widget-placeholder": "main",
+ "last_opened_file_path": "/home/brett/Documents/Brock/CS 4P80/COSC-4P80-Final-Project",
+ "node.js.detected.package.eslint": "true",
+ "node.js.detected.package.tslint": "true",
+ "node.js.selected.package.eslint": "(autodetect)",
+ "node.js.selected.package.tslint": "(autodetect)",
+ "nodejs_package_manager_path": "npm",
+ "settings.editor.selected.configurable": "preferences.lookFeel",
+ "vue.rearranger.settings.migration": "true"
+ }
+}
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ MAKEINDEX
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ 1733702642308
+
+
+ 1733702642308
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/CMakeLists.txt b/CMakeLists.txt
index b881707..39853bd 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -1,5 +1,5 @@
cmake_minimum_required(VERSION 3.25)
-project(COSC-4P80-Final-Project VERSION 0.0.5)
+project(COSC-4P80-Final-Project VERSION 0.0.6)
option(ENABLE_ADDRSAN "Enable the address sanitizer" OFF)
option(ENABLE_UBSAN "Enable the ub sanitizer" OFF)
diff --git a/src/.MNIST.cpp.~63b6ca65 b/src/.MNIST.cpp.~63b6ca65
deleted file mode 100644
index cb6fe21..0000000
--- a/src/.MNIST.cpp.~63b6ca65
+++ /dev/null
@@ -1,346 +0,0 @@
-/*
- *
- * Copyright (C) 2024 Brett Terpstra
- *
- * This program is free software: you can redistribute it and/or modify
- * it under the terms of the GNU General Public License as published by
- * the Free Software Foundation, either version 3 of the License, or
- * (at your option) any later version.
- *
- * This program is distributed in the hope that it will be useful,
- * but WITHOUT ANY WARRANTY; without even the implied warranty of
- * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
- * GNU General Public License for more details.
- *
- * You should have received a copy of the GNU General Public License
- * along with this program. If not, see .
- */
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-
-namespace fp
-{
- class idx_file_t
- {
- template
- using mk_v = std::vector;
- using vec_t = std::variant, mk_v, mk_v, mk_v, mk_v, mk_v>;
-
- public:
- explicit idx_file_t(const std::string& path)
- {
- std::ifstream file{path, std::ios::in | std::ios::binary};
-
- using char_type = std::ifstream::char_type;
- char_type magic_arr[4];
- file.read(magic_arr, 4);
- BLT_ASSERT(magic_arr[0] == 0 && magic_arr[1] == 0);
-
- blt::u8 dims = magic_arr[3];
- blt::size_t total_size = 1;
-
- for (blt::i32 i = 0; i < dims; i++)
- {
- char_type dim_arr[4];
- file.read(dim_arr, 4);
- blt::u32 dim;
- blt::mem::fromBytes(dim_arr, dim);
- dimensions.push_back(dim);
- total_size *= dim;
- }
-
- switch (magic_arr[2])
- {
- // unsigned char
- case 0x08:
- data = mk_v{};
- read_data(file, total_size);
- break;
- // signed char
- case 0x09:
- data = mk_v{};
- read_data(file, total_size);
- break;
- // short
- case 0x0B:
- data = mk_v{};
- read_data(file, total_size);
- reverse_data();
- break;
- // int
- case 0x0C:
- data = mk_v{};
- read_data(file, total_size);
- reverse_data();
- break;
- // float
- case 0x0D:
- data = mk_v{};
- read_data(file, total_size);
- reverse_data();
- break;
- // double
- case 0x0E:
- data = mk_v{};
- read_data(file, total_size);
- reverse_data();
- break;
- default:
- BLT_ERROR("Unspported idx file type!");
- }
- if (file.eof())
- {
- BLT_ERROR("EOF reached. It's unlikely your file was read correctly!");
- }
- }
-
- template
- [[nodiscard]] const std::vector& get_data_as() const
- {
- return std::get>(data);
- }
-
- template
- std::vector> get_as_spans() const
- {
- std::vector> spans;
-
- blt::size_t total_size = data_size(1);
-
- for (blt::size_t i = 0; i < dimensions[0]; i++)
- {
- auto& array = std::get>(data);
- spans.push_back({&array[i * total_size], total_size});
- }
-
- return spans;
- }
-
- [[nodiscard]] const std::vector& get_dimensions() const
- {
- return dimensions;
- }
-
- [[nodiscard]] blt::size_t data_size(const blt::size_t starting_dimension = 0) const
- {
- blt::size_t total_size = 1;
- for (const auto d : blt::iterate(dimensions).skip(starting_dimension))
- total_size *= d;
- return total_size;
- }
-
- private:
- template
- void read_data(std::ifstream& file, blt::size_t total_size)
- {
- auto& array = std::get>(data);
- array.resize(total_size);
- file.read(reinterpret_cast(array.data()), static_cast(total_size) * sizeof(T));
- }
-
- template
- void reverse_data()
- {
- auto& array = std::get>(data);
- for (auto& v : array)
- blt::mem::reverse(v);
- }
-
- std::vector dimensions;
- vec_t data;
- };
-
- class image_t
- {
- public:
- static constexpr blt::u32 target_size = 10;
- using data_iterator = std::vector>::const_iterator;
- using label_iterator = std::vector::const_iterator;
-
- image_t(const idx_file_t& image_data, const idx_file_t& label_data): samples(image_data.get_dimensions()[0]),
- input_size(image_data.data_size(1))
- {
- BLT_ASSERT_MSG(samples == label_data.get_dimensions()[0],
- ("Mismatch in data sample sizes! " + std::to_string(samples) + " vs " + std::to_string(label_data.get_dimensions()[0])).
- c_str());
- auto& image_array = image_data.get_data_as();
- auto& label_array = label_data.get_data_as();
-
- for (const auto label : label_array)
- image_labels.push_back(label);
-
- const auto row_length = image_data.get_dimensions()[2];
- const auto number_of_rows = image_data.get_dimensions()[1];
-
- for (blt::u32 i = 0; i < samples; i++)
- {
- dlib::matrix mat(number_of_rows, row_length);
- for (blt::u32 y = 0; y < number_of_rows; y++)
- {
- for (blt::u32 x = 0; x < row_length; x++)
- {
- mat(x, y) = image_array[i * input_size + y * row_length + x];
- }
- }
- data.push_back(mat);
- }
- }
-
- [[nodiscard]] const std::vector>& get_image_data() const
- {
- return data;
- }
-
- [[nodiscard]] const std::vector& get_image_labels() const
- {
- return image_labels;
- }
-
- private:
- blt::u32 samples;
- blt::u32 input_size;
- std::vector> data;
- std::vector image_labels;
- };
-
- struct batch_stats_t
- {
- blt::u64 batch_size;
-
- };
-
- struct network_stats_t
- {
- };
-
- template
- batch_stats_t test_batch(NetworkType& network, image_t::data_iterator begin, image_t::data_iterator end, image_t::label_iterator lbegin)
- {
- batch_stats_t stats;
-
-
-
- return stats;
- }
-
- template
- void test_network(NetworkType& network)
- {
- const idx_file_t test_images{"../problems/mnist/t10k-images.idx3-ubyte"};
- const idx_file_t test_labels{"../problems/mnist/t10k-labels.idx1-ubyte"};
-
- const auto test_samples = test_images.get_dimensions()[0];
-
- const image_t test_image{test_images, test_labels};
-
- const auto predicted_labels = network(test_image.get_image_data());
- int num_right = 0;
- int num_wrong = 0;
- for (size_t i = 0; i < test_image.get_image_data().size(); ++i)
- {
- if (predicted_labels[i] == test_image.get_image_labels()[i])
- ++num_right;
- else
- ++num_wrong;
- }
- std::cout << "testing num_right: " << num_right << std::endl;
- std::cout << "testing num_wrong: " << num_wrong << std::endl;
- std::cout << "testing accuracy: " << num_right / static_cast(num_right + num_wrong) << std::endl;
- }
-
- template
- network_stats_t train_network(const std::string& ident, NetworkType& network)
- {
- const idx_file_t train_images{"../problems/mnist/train-images.idx3-ubyte"};
- const idx_file_t train_labels{"../problems/mnist/train-labels.idx1-ubyte"};
-
- const image_t train_image{train_images, train_labels};
-
- network_stats_t stats;
-
- dlib::dnn_trainer trainer(network);
- trainer.set_learning_rate(0.01);
- trainer.set_min_learning_rate(0.00001);
- trainer.set_mini_batch_size(128);
- trainer.be_verbose();
-
- trainer.set_synchronization_file("mnist_sync_" + ident, std::chrono::seconds(20));
-
- blt::size_t epochs = 0;
- blt::ptrdiff_t epoch_pos = 0;
- for (; epochs < trainer.getmax_epochs() && trainer.get_learning_rate() >= trainer.get_min_learning_rate(); epochs++)
- {
- auto& data = train_image.get_image_data();
- auto& labels = train_image.get_image_labels();
- for (; epoch_pos < data.size() && trainer.get_learning_rate() >= trainer.get_min_learning_rate(); epoch_pos += trainer.
- get_mini_batch_size())
- {
- auto begin = epoch_pos;
- auto end = std::min(epoch_pos + trainer.get_mini_batch_size(), data.size());
-
- if (end - begin <= 0)
- break;
-
- trainer.train_one_step(train_image.get_image_data().begin() + begin,
- data.begin() + end, labels.begin() + begin);
- }
- epoch_pos = 0;
- trainer.wait_for_thread_to_pause();
- }
-
- // trainer.train(train_image.get_image_data(), train_image.get_image_labels());
-
- network.clean();
- dlib::serialize("mnist_network_" + ident + ".dat") << network;
-
- const std::vector predicted_labels = network(train_image.get_image_data());
- int num_right = 0;
- int num_wrong = 0;
- // And then let's see if it classified them correctly.
- for (size_t i = 0; i < train_image.get_image_data().size(); ++i)
- {
- if (predicted_labels[i] == train_image.get_image_labels()[i])
- ++num_right;
- else
- ++num_wrong;
- }
- std::cout << "training num_right: " << num_right << std::endl;
- std::cout << "training num_wrong: " << num_wrong << std::endl;
- std::cout << "training accuracy: " << num_right / static_cast(num_right + num_wrong) << std::endl;
-
- return stats;
- }
-
- template
- NetworkType load_network(const std::string& ident)
- {
- NetworkType network{};
- dlib::deserialize("mnist_network_" + ident + ".dat") >> network;
- return network;
- }
-
- void run_mnist(int argc, const char** argv)
- {
- using namespace dlib;
-
- // using net_type = loss_multiclass_log<
- // fc<10,
- // relu>>>>>>>>>>>>>;
-
- using net_type = loss_multiclass_log<
- fc<10,
- sig>>>>>>>;
- }
-}
diff --git a/src/MNIST.cpp b/src/MNIST.cpp
index cb6fe21..81b32dd 100644
--- a/src/MNIST.cpp
+++ b/src/MNIST.cpp
@@ -211,20 +211,131 @@ namespace fp
struct batch_stats_t
{
- blt::u64 batch_size;
+ blt::u64 hits;
+ blt::u64 misses;
+ friend std::ofstream& operator<<(std::ofstream& file, const batch_stats_t& stats)
+ {
+ file << stats.hits << ',' << stats.misses;
+ return file;
+ }
+
+ friend std::ifstream& operator>>(std::ifstream& file, batch_stats_t& stats)
+ {
+ file >> stats.hits;
+ file.ignore();
+ file >> stats.misses;
+ return file;
+ }
+
+ batch_stats_t& operator+=(const batch_stats_t& stats)
+ {
+ hits += stats.hits;
+ misses += stats.misses;
+ return *this;
+ }
+
+ batch_stats_t& operator/=(const blt::u64 divisor)
+ {
+ hits /= divisor;
+ misses /= divisor;
+ return *this;
+ }
+ };
+
+ struct epoch_stats_t
+ {
+ batch_stats_t test_results;
+ double average_loss;
+ double learn_rate;
+
+ friend std::ofstream& operator<<(std::ofstream& file, const epoch_stats_t& stats)
+ {
+ file << stats.test_results << ',' << stats.average_loss << ',' << stats.learn_rate;
+ return file;
+ }
+
+ friend std::ifstream& operator>>(std::ifstream& file, epoch_stats_t& stats)
+ {
+ file >> stats.test_results;
+ file.ignore();
+ file >> stats.average_loss;
+ file.ignore();
+ file >> stats.learn_rate;
+ return file;
+ }
+
+ epoch_stats_t& operator+=(const epoch_stats_t& stats)
+ {
+ test_results += stats.test_results;
+ average_loss += stats.average_loss;
+ learn_rate += stats.learn_rate;
+ return *this;
+ }
+
+ epoch_stats_t& operator/=(const blt::u64 divisor)
+ {
+ test_results /= divisor;
+ average_loss /= static_cast(divisor);
+ learn_rate /= static_cast(divisor);
+ return *this;
+ }
};
struct network_stats_t
{
+ std::vector epoch_stats;
+
+ friend std::ofstream& operator<<(std::ofstream& file, const network_stats_t& stats) // writes the record count, then one epoch record per line
+ {
+ file << stats.epoch_stats.size() << '\n'; // terminate the count so it cannot fuse with the first record's leading digits
+ for (const auto& v : stats.epoch_stats)
+ file << v << "\n";
+ return file;
+ }
+
+ friend std::ifstream& operator>>(std::ifstream& file, network_stats_t& stats) // reads a record count, then appends up to that many epoch records
+ {
+ blt::size_t size = 0;
+ file >> size;
+ // Stop at the first failed extraction so a truncated or corrupt file cannot append blank records.
+ for (epoch_stats_t entry{}; size > 0 && file >> entry; --size)
+ {
+ stats.epoch_stats.push_back(entry);
+ file.ignore(); // skip the newline that terminates each record
+ }
+ return file;
+ }
+
+
};
- template
- batch_stats_t test_batch(NetworkType& network, image_t::data_iterator begin, image_t::data_iterator end, image_t::label_iterator lbegin)
+ template
+ batch_stats_t test_batch(NetworkType& network, image_t::data_iterator begin, const image_t::data_iterator end, image_t::label_iterator lbegin)
{
- batch_stats_t stats;
+ batch_stats_t stats{};
+ std::array output_labels{};
+ auto amount_remaining = std::distance(begin, end);
+
+ while (amount_remaining != 0)
+ {
+ const auto batch = std::min(amount_remaining, batch_size);
+ network(begin, begin + batch, output_labels.begin());
+
+ for (auto [predicted, expected] : blt::iterate(output_labels.begin(), output_labels.begin() + batch).zip(lbegin, lbegin + batch))
+ {
+ if (predicted == expected)
+ ++stats.hits;
+ else
+ ++stats.misses;
+ }
+
+ begin += batch;
+ lbegin += batch;
+ amount_remaining -= batch;
+ }
return stats;
}
@@ -235,23 +346,14 @@ namespace fp
const idx_file_t test_images{"../problems/mnist/t10k-images.idx3-ubyte"};
const idx_file_t test_labels{"../problems/mnist/t10k-labels.idx1-ubyte"};
- const auto test_samples = test_images.get_dimensions()[0];
-
const image_t test_image{test_images, test_labels};
- const auto predicted_labels = network(test_image.get_image_data());
- int num_right = 0;
- int num_wrong = 0;
- for (size_t i = 0; i < test_image.get_image_data().size(); ++i)
- {
- if (predicted_labels[i] == test_image.get_image_labels()[i])
- ++num_right;
- else
- ++num_wrong;
- }
- std::cout << "testing num_right: " << num_right << std::endl;
- std::cout << "testing num_wrong: " << num_wrong << std::endl;
- std::cout << "testing accuracy: " << num_right / static_cast(num_right + num_wrong) << std::endl;
+ auto test_results = test_batch(network, test_image.get_image_data().begin(), test_image.get_image_data().end(),
+ test_image.get_image_labels().begin());
+
+ BLT_INFO("Testing hits: %lu", test_results.hits);
+ BLT_INFO("Testing misses: %lu", test_results.misses);
+ BLT_INFO("Testing accuracy: %lf", test_results.hits / static_cast(test_results.hits + test_results.misses));
}
template
@@ -274,7 +376,7 @@ namespace fp
blt::size_t epochs = 0;
blt::ptrdiff_t epoch_pos = 0;
- for (; epochs < trainer.getmax_epochs() && trainer.get_learning_rate() >= trainer.get_min_learning_rate(); epochs++)
+ for (; epochs < trainer.get_max_num_epochs() && trainer.get_learning_rate() >= trainer.get_min_learning_rate(); epochs++)
{
auto& data = train_image.get_image_data();
auto& labels = train_image.get_image_labels();
@@ -291,28 +393,39 @@ namespace fp
data.begin() + end, labels.begin() + begin);
}
epoch_pos = 0;
- trainer.wait_for_thread_to_pause();
+ BLT_DEBUG("Trained an epoch (%ld/%ld) learn rate %lf average loss %lf", epochs, trainer.get_max_num_epochs(),
+ trainer.get_learning_rate(), trainer.get_average_loss());
+
+ // sync and test
+ trainer.get_net(dlib::force_flush_to_disk::no);
+
+ epoch_stats_t epoch_stats{};
+ epoch_stats.test_results = test_batch(network, train_image.get_image_data().begin(), train_image.get_image_data().end(),
+ train_image.get_image_labels().begin());
+ epoch_stats.average_loss = trainer.get_average_loss();
+ epoch_stats.learn_rate = trainer.get_learning_rate();
+
+ BLT_DEBUG("\t\tHits: %lu\tMisses: %lu\tAccuracy: %lf", epoch_stats.test_results.hits, epoch_stats.test_results.misses,
+ epoch_stats.test_results.hits / static_cast(epoch_stats.test_results.hits + epoch_stats.test_results.misses));
+
+ stats.epoch_stats.push_back(epoch_stats);
}
+ BLT_INFO("Finished Training");
+
+ // sync
+ trainer.get_net();
+ network.clean();
+
// trainer.train(train_image.get_image_data(), train_image.get_image_labels());
-
- network.clean();
dlib::serialize("mnist_network_" + ident + ".dat") << network;
- const std::vector predicted_labels = network(train_image.get_image_data());
- int num_right = 0;
- int num_wrong = 0;
- // And then let's see if it classified them correctly.
- for (size_t i = 0; i < train_image.get_image_data().size(); ++i)
- {
- if (predicted_labels[i] == train_image.get_image_labels()[i])
- ++num_right;
- else
- ++num_wrong;
- }
- std::cout << "training num_right: " << num_right << std::endl;
- std::cout << "training num_wrong: " << num_wrong << std::endl;
- std::cout << "training accuracy: " << num_right / static_cast(num_right + num_wrong) << std::endl;
+ auto test_results = test_batch(network, train_image.get_image_data().begin(), train_image.get_image_data().end(),
+ train_image.get_image_labels().begin());
+
+ BLT_INFO("Training hits: %lu", test_results.hits);
+ BLT_INFO("Training misses: %lu", test_results.misses);
+ BLT_INFO("Training accuracy: %lf", test_results.hits / static_cast(test_results.hits + test_results.misses));
return stats;
}
@@ -342,5 +455,12 @@ namespace fp
sig>>>>>>>;
+
+ net_type test_net;
+ const auto stats = train_network("fc_nn", test_net);
+ std::ofstream out_file{"fc_nn.csv"};
+ out_file << stats;
+
+ test_network(test_net);
}
}