diff --git a/.idea/editor.xml b/.idea/editor.xml
index b0d69ef..55d1bc1 100644
--- a/.idea/editor.xml
+++ b/.idea/editor.xml
@@ -2,482 +2,482 @@
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/vcs.xml b/.idea/vcs.xml
index 9521f67..b619dc4 100644
--- a/.idea/vcs.xml
+++ b/.idea/vcs.xml
@@ -4,24 +4,16 @@
-
-
-
-
-
-
-
-
diff --git a/CMakeLists.txt b/CMakeLists.txt
index a17eb7c..b881707 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -1,5 +1,5 @@
cmake_minimum_required(VERSION 3.25)
-project(COSC-4P80-Final-Project VERSION 0.0.2)
+project(COSC-4P80-Final-Project VERSION 0.0.5)
option(ENABLE_ADDRSAN "Enable the address sanitizer" OFF)
option(ENABLE_UBSAN "Enable the ub sanitizer" OFF)
diff --git a/commit.py b/commit.py
old mode 100644
new mode 100755
index 440e164..bcd7661
--- a/commit.py
+++ b/commit.py
@@ -36,7 +36,7 @@ else:
if len(str(XDG_CONFIG_HOME)) == 0:
XDG_CONFIG_HOME = USER_HOME
CONFIG_FILE_DIRECTORY = XDG_CONFIG_HOME / "blt"
- CONFIG_FILE_LOCATION = CONFIG_FILE_DIRECTORY / "commit_config.env"
+ CONFIG_FILE_LOCATION = CONFIG_FILE_DIRECTORY / "commit_config.json"
class Config:
def __init__(self):
diff --git a/include/MNIST.h b/include/MNIST.h
index e788f7c..c0d8ada 100644
--- a/include/MNIST.h
+++ b/include/MNIST.h
@@ -21,7 +21,7 @@
namespace fp
{
- void run_mnist();
+ void run_mnist(int argc, const char** argv);
}
#endif //MNIST_H
diff --git a/include/cat_and_dogs.h b/include/cat_and_dogs.h
new file mode 100644
index 0000000..276e0bd
--- /dev/null
+++ b/include/cat_and_dogs.h
@@ -0,0 +1,26 @@
+#pragma once
+/*
+ * Copyright (C) 2024 Brett Terpstra
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program. If not, see .
+ */
+
+#ifndef CAT_AND_DOGS_H
+#define CAT_AND_DOGS_H
+
+namespace fp {
+ void run_cat_and_dogs();
+}
+
+#endif //CAT_AND_DOGS_H
diff --git a/problems/mnist/mnist-dataset.zip b/problems/mnist/mnist-dataset.zip
deleted file mode 100644
index 188e947..0000000
Binary files a/problems/mnist/mnist-dataset.zip and /dev/null differ
diff --git a/problems/mnist/t10k-images-idx3-ubyte/t10k-images-idx3-ubyte b/problems/mnist/t10k-images-idx3-ubyte/t10k-images-idx3-ubyte
deleted file mode 100644
index 1170b2c..0000000
Binary files a/problems/mnist/t10k-images-idx3-ubyte/t10k-images-idx3-ubyte and /dev/null differ
diff --git a/problems/mnist/t10k-images.idx3-ubyte b/problems/mnist/t10k-images.idx3-ubyte
deleted file mode 100644
index 1170b2c..0000000
Binary files a/problems/mnist/t10k-images.idx3-ubyte and /dev/null differ
diff --git a/problems/mnist/t10k-labels-idx1-ubyte/t10k-labels-idx1-ubyte b/problems/mnist/t10k-labels-idx1-ubyte/t10k-labels-idx1-ubyte
deleted file mode 100644
index d1c3a97..0000000
Binary files a/problems/mnist/t10k-labels-idx1-ubyte/t10k-labels-idx1-ubyte and /dev/null differ
diff --git a/problems/mnist/t10k-labels.idx1-ubyte b/problems/mnist/t10k-labels.idx1-ubyte
deleted file mode 100644
index d1c3a97..0000000
Binary files a/problems/mnist/t10k-labels.idx1-ubyte and /dev/null differ
diff --git a/problems/mnist/train-images-idx3-ubyte/train-images-idx3-ubyte b/problems/mnist/train-images-idx3-ubyte/train-images-idx3-ubyte
deleted file mode 100644
index bbce276..0000000
Binary files a/problems/mnist/train-images-idx3-ubyte/train-images-idx3-ubyte and /dev/null differ
diff --git a/problems/mnist/train-images.idx3-ubyte b/problems/mnist/train-images.idx3-ubyte
deleted file mode 100644
index bbce276..0000000
Binary files a/problems/mnist/train-images.idx3-ubyte and /dev/null differ
diff --git a/problems/mnist/train-labels-idx1-ubyte/train-labels-idx1-ubyte b/problems/mnist/train-labels-idx1-ubyte/train-labels-idx1-ubyte
deleted file mode 100644
index d6b4c5d..0000000
Binary files a/problems/mnist/train-labels-idx1-ubyte/train-labels-idx1-ubyte and /dev/null differ
diff --git a/problems/mnist/train-labels.idx1-ubyte b/problems/mnist/train-labels.idx1-ubyte
deleted file mode 100644
index d6b4c5d..0000000
Binary files a/problems/mnist/train-labels.idx1-ubyte and /dev/null differ
diff --git a/python/.idea/.gitignore b/python/.idea/.gitignore
new file mode 100644
index 0000000..13566b8
--- /dev/null
+++ b/python/.idea/.gitignore
@@ -0,0 +1,8 @@
+# Default ignored files
+/shelf/
+/workspace.xml
+# Editor-based HTTP Client requests
+/httpRequests/
+# Datasource local storage ignored files
+/dataSources/
+/dataSources.local.xml
diff --git a/python/.idea/inspectionProfiles/profiles_settings.xml b/python/.idea/inspectionProfiles/profiles_settings.xml
new file mode 100644
index 0000000..105ce2d
--- /dev/null
+++ b/python/.idea/inspectionProfiles/profiles_settings.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/python/.idea/misc.xml b/python/.idea/misc.xml
new file mode 100644
index 0000000..a6218fe
--- /dev/null
+++ b/python/.idea/misc.xml
@@ -0,0 +1,7 @@
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/python/.idea/modules.xml b/python/.idea/modules.xml
new file mode 100644
index 0000000..614b3c1
--- /dev/null
+++ b/python/.idea/modules.xml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/python/.idea/python.iml b/python/.idea/python.iml
new file mode 100644
index 0000000..909438d
--- /dev/null
+++ b/python/.idea/python.iml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/python/.idea/vcs.xml b/python/.idea/vcs.xml
new file mode 100644
index 0000000..6c0b863
--- /dev/null
+++ b/python/.idea/vcs.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/src/.MNIST.cpp.~63b6ca65 b/src/.MNIST.cpp.~63b6ca65
new file mode 100644
index 0000000..cb6fe21
--- /dev/null
+++ b/src/.MNIST.cpp.~63b6ca65
@@ -0,0 +1,346 @@
+/*
+ *
+ * Copyright (C) 2024 Brett Terpstra
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program. If not, see .
+ */
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+namespace fp
+{
+ class idx_file_t
+ {
+ template
+ using mk_v = std::vector;
+ using vec_t = std::variant, mk_v, mk_v, mk_v, mk_v, mk_v>;
+
+ public:
+ explicit idx_file_t(const std::string& path)
+ {
+ std::ifstream file{path, std::ios::in | std::ios::binary};
+
+ using char_type = std::ifstream::char_type;
+ char_type magic_arr[4];
+ file.read(magic_arr, 4);
+ BLT_ASSERT(magic_arr[0] == 0 && magic_arr[1] == 0);
+
+ blt::u8 dims = magic_arr[3];
+ blt::size_t total_size = 1;
+
+ for (blt::i32 i = 0; i < dims; i++)
+ {
+ char_type dim_arr[4];
+ file.read(dim_arr, 4);
+ blt::u32 dim;
+ blt::mem::fromBytes(dim_arr, dim);
+ dimensions.push_back(dim);
+ total_size *= dim;
+ }
+
+ switch (magic_arr[2])
+ {
+ // unsigned char
+ case 0x08:
+ data = mk_v{};
+ read_data(file, total_size);
+ break;
+ // signed char
+ case 0x09:
+ data = mk_v{};
+ read_data(file, total_size);
+ break;
+ // short
+ case 0x0B:
+ data = mk_v{};
+ read_data(file, total_size);
+ reverse_data();
+ break;
+ // int
+ case 0x0C:
+ data = mk_v{};
+ read_data(file, total_size);
+ reverse_data();
+ break;
+ // float
+ case 0x0D:
+ data = mk_v{};
+ read_data(file, total_size);
+ reverse_data();
+ break;
+ // double
+ case 0x0E:
+ data = mk_v{};
+ read_data(file, total_size);
+ reverse_data();
+ break;
+ default:
+ BLT_ERROR("Unspported idx file type!");
+ }
+ if (file.eof())
+ {
+ BLT_ERROR("EOF reached. It's unlikely your file was read correctly!");
+ }
+ }
+
+ template
+ [[nodiscard]] const std::vector& get_data_as() const
+ {
+ return std::get>(data);
+ }
+
+ template
+ std::vector> get_as_spans() const
+ {
+ std::vector> spans;
+
+ blt::size_t total_size = data_size(1);
+
+ for (blt::size_t i = 0; i < dimensions[0]; i++)
+ {
+ auto& array = std::get>(data);
+ spans.push_back({&array[i * total_size], total_size});
+ }
+
+ return spans;
+ }
+
+ [[nodiscard]] const std::vector& get_dimensions() const
+ {
+ return dimensions;
+ }
+
+ [[nodiscard]] blt::size_t data_size(const blt::size_t starting_dimension = 0) const
+ {
+ blt::size_t total_size = 1;
+ for (const auto d : blt::iterate(dimensions).skip(starting_dimension))
+ total_size *= d;
+ return total_size;
+ }
+
+ private:
+ template
+ void read_data(std::ifstream& file, blt::size_t total_size)
+ {
+ auto& array = std::get>(data);
+ array.resize(total_size);
+ file.read(reinterpret_cast(array.data()), static_cast(total_size) * sizeof(T));
+ }
+
+ template
+ void reverse_data()
+ {
+ auto& array = std::get>(data);
+ for (auto& v : array)
+ blt::mem::reverse(v);
+ }
+
+ std::vector dimensions;
+ vec_t data;
+ };
+
+ class image_t
+ {
+ public:
+ static constexpr blt::u32 target_size = 10;
+ using data_iterator = std::vector>::const_iterator;
+ using label_iterator = std::vector::const_iterator;
+
+ image_t(const idx_file_t& image_data, const idx_file_t& label_data): samples(image_data.get_dimensions()[0]),
+ input_size(image_data.data_size(1))
+ {
+ BLT_ASSERT_MSG(samples == label_data.get_dimensions()[0],
+ ("Mismatch in data sample sizes! " + std::to_string(samples) + " vs " + std::to_string(label_data.get_dimensions()[0])).
+ c_str());
+ auto& image_array = image_data.get_data_as();
+ auto& label_array = label_data.get_data_as();
+
+ for (const auto label : label_array)
+ image_labels.push_back(label);
+
+ const auto row_length = image_data.get_dimensions()[2];
+ const auto number_of_rows = image_data.get_dimensions()[1];
+
+ for (blt::u32 i = 0; i < samples; i++)
+ {
+ dlib::matrix mat(number_of_rows, row_length);
+ for (blt::u32 y = 0; y < number_of_rows; y++)
+ {
+ for (blt::u32 x = 0; x < row_length; x++)
+ {
+ mat(x, y) = image_array[i * input_size + y * row_length + x];
+ }
+ }
+ data.push_back(mat);
+ }
+ }
+
+ [[nodiscard]] const std::vector>& get_image_data() const
+ {
+ return data;
+ }
+
+ [[nodiscard]] const std::vector& get_image_labels() const
+ {
+ return image_labels;
+ }
+
+ private:
+ blt::u32 samples;
+ blt::u32 input_size;
+ std::vector> data;
+ std::vector image_labels;
+ };
+
+ struct batch_stats_t
+ {
+ blt::u64 batch_size;
+
+ };
+
+ struct network_stats_t
+ {
+ };
+
+ template
+ batch_stats_t test_batch(NetworkType& network, image_t::data_iterator begin, image_t::data_iterator end, image_t::label_iterator lbegin)
+ {
+ batch_stats_t stats;
+
+
+
+ return stats;
+ }
+
+ template
+ void test_network(NetworkType& network)
+ {
+ const idx_file_t test_images{"../problems/mnist/t10k-images.idx3-ubyte"};
+ const idx_file_t test_labels{"../problems/mnist/t10k-labels.idx1-ubyte"};
+
+ const auto test_samples = test_images.get_dimensions()[0];
+
+ const image_t test_image{test_images, test_labels};
+
+ const auto predicted_labels = network(test_image.get_image_data());
+ int num_right = 0;
+ int num_wrong = 0;
+ for (size_t i = 0; i < test_image.get_image_data().size(); ++i)
+ {
+ if (predicted_labels[i] == test_image.get_image_labels()[i])
+ ++num_right;
+ else
+ ++num_wrong;
+ }
+ std::cout << "testing num_right: " << num_right << std::endl;
+ std::cout << "testing num_wrong: " << num_wrong << std::endl;
+ std::cout << "testing accuracy: " << num_right / static_cast(num_right + num_wrong) << std::endl;
+ }
+
+ template
+ network_stats_t train_network(const std::string& ident, NetworkType& network)
+ {
+ const idx_file_t train_images{"../problems/mnist/train-images.idx3-ubyte"};
+ const idx_file_t train_labels{"../problems/mnist/train-labels.idx1-ubyte"};
+
+ const image_t train_image{train_images, train_labels};
+
+ network_stats_t stats;
+
+ dlib::dnn_trainer trainer(network);
+ trainer.set_learning_rate(0.01);
+ trainer.set_min_learning_rate(0.00001);
+ trainer.set_mini_batch_size(128);
+ trainer.be_verbose();
+
+ trainer.set_synchronization_file("mnist_sync_" + ident, std::chrono::seconds(20));
+
+ blt::size_t epochs = 0;
+ blt::ptrdiff_t epoch_pos = 0;
+ for (; epochs < trainer.getmax_epochs() && trainer.get_learning_rate() >= trainer.get_min_learning_rate(); epochs++)
+ {
+ auto& data = train_image.get_image_data();
+ auto& labels = train_image.get_image_labels();
+ for (; epoch_pos < data.size() && trainer.get_learning_rate() >= trainer.get_min_learning_rate(); epoch_pos += trainer.
+ get_mini_batch_size())
+ {
+ auto begin = epoch_pos;
+ auto end = std::min(epoch_pos + trainer.get_mini_batch_size(), data.size());
+
+ if (end - begin <= 0)
+ break;
+
+ trainer.train_one_step(train_image.get_image_data().begin() + begin,
+ data.begin() + end, labels.begin() + begin);
+ }
+ epoch_pos = 0;
+ trainer.wait_for_thread_to_pause();
+ }
+
+ // trainer.train(train_image.get_image_data(), train_image.get_image_labels());
+
+ network.clean();
+ dlib::serialize("mnist_network_" + ident + ".dat") << network;
+
+ const std::vector predicted_labels = network(train_image.get_image_data());
+ int num_right = 0;
+ int num_wrong = 0;
+ // And then let's see if it classified them correctly.
+ for (size_t i = 0; i < train_image.get_image_data().size(); ++i)
+ {
+ if (predicted_labels[i] == train_image.get_image_labels()[i])
+ ++num_right;
+ else
+ ++num_wrong;
+ }
+ std::cout << "training num_right: " << num_right << std::endl;
+ std::cout << "training num_wrong: " << num_wrong << std::endl;
+ std::cout << "training accuracy: " << num_right / static_cast(num_right + num_wrong) << std::endl;
+
+ return stats;
+ }
+
+ template
+ NetworkType load_network(const std::string& ident)
+ {
+ NetworkType network{};
+ dlib::deserialize("mnist_network_" + ident + ".dat") >> network;
+ return network;
+ }
+
+ void run_mnist(int argc, const char** argv)
+ {
+ using namespace dlib;
+
+ // using net_type = loss_multiclass_log<
+ // fc<10,
+ // relu>>>>>>>>>>>>>;
+
+ using net_type = loss_multiclass_log<
+ fc<10,
+ sig>>>>>>>;
+ }
+}
diff --git a/src/MNIST.cpp b/src/MNIST.cpp
index f64078a..cb6fe21 100644
--- a/src/MNIST.cpp
+++ b/src/MNIST.cpp
@@ -160,6 +160,8 @@ namespace fp
{
public:
static constexpr blt::u32 target_size = 10;
+ using data_iterator = std::vector>::const_iterator;
+ using label_iterator = std::vector::const_iterator;
image_t(const idx_file_t& image_data, const idx_file_t& label_data): samples(image_data.get_dimensions()[0]),
input_size(image_data.data_size(1))
@@ -207,26 +209,126 @@ namespace fp
std::vector image_labels;
};
- void run_mnist()
+ struct batch_stats_t
+ {
+ blt::u64 batch_size;
+
+ };
+
+ struct network_stats_t
+ {
+ };
+
+ template
+ batch_stats_t test_batch(NetworkType& network, image_t::data_iterator begin, image_t::data_iterator end, image_t::label_iterator lbegin)
+ {
+ batch_stats_t stats;
+
+
+
+ return stats;
+ }
+
+ template
+ void test_network(NetworkType& network)
+ {
+ const idx_file_t test_images{"../problems/mnist/t10k-images.idx3-ubyte"};
+ const idx_file_t test_labels{"../problems/mnist/t10k-labels.idx1-ubyte"};
+
+ const auto test_samples = test_images.get_dimensions()[0];
+
+ const image_t test_image{test_images, test_labels};
+
+ const auto predicted_labels = network(test_image.get_image_data());
+ int num_right = 0;
+ int num_wrong = 0;
+ for (size_t i = 0; i < test_image.get_image_data().size(); ++i)
+ {
+ if (predicted_labels[i] == test_image.get_image_labels()[i])
+ ++num_right;
+ else
+ ++num_wrong;
+ }
+ std::cout << "testing num_right: " << num_right << std::endl;
+ std::cout << "testing num_wrong: " << num_wrong << std::endl;
+ std::cout << "testing accuracy: " << num_right / static_cast(num_right + num_wrong) << std::endl;
+ }
+
+ template
+ network_stats_t train_network(const std::string& ident, NetworkType& network)
+ {
+ const idx_file_t train_images{"../problems/mnist/train-images.idx3-ubyte"};
+ const idx_file_t train_labels{"../problems/mnist/train-labels.idx1-ubyte"};
+
+ const image_t train_image{train_images, train_labels};
+
+ network_stats_t stats;
+
+ dlib::dnn_trainer trainer(network);
+ trainer.set_learning_rate(0.01);
+ trainer.set_min_learning_rate(0.00001);
+ trainer.set_mini_batch_size(128);
+ trainer.be_verbose();
+
+ trainer.set_synchronization_file("mnist_sync_" + ident, std::chrono::seconds(20));
+
+ blt::size_t epochs = 0;
+ blt::ptrdiff_t epoch_pos = 0;
+ for (; epochs < trainer.getmax_epochs() && trainer.get_learning_rate() >= trainer.get_min_learning_rate(); epochs++)
+ {
+ auto& data = train_image.get_image_data();
+ auto& labels = train_image.get_image_labels();
+ for (; epoch_pos < data.size() && trainer.get_learning_rate() >= trainer.get_min_learning_rate(); epoch_pos += trainer.
+ get_mini_batch_size())
+ {
+ auto begin = epoch_pos;
+ auto end = std::min(epoch_pos + trainer.get_mini_batch_size(), data.size());
+
+ if (end - begin <= 0)
+ break;
+
+ trainer.train_one_step(train_image.get_image_data().begin() + begin,
+ data.begin() + end, labels.begin() + begin);
+ }
+ epoch_pos = 0;
+ trainer.wait_for_thread_to_pause();
+ }
+
+ // trainer.train(train_image.get_image_data(), train_image.get_image_labels());
+
+ network.clean();
+ dlib::serialize("mnist_network_" + ident + ".dat") << network;
+
+ const std::vector predicted_labels = network(train_image.get_image_data());
+ int num_right = 0;
+ int num_wrong = 0;
+ // And then let's see if it classified them correctly.
+ for (size_t i = 0; i < train_image.get_image_data().size(); ++i)
+ {
+ if (predicted_labels[i] == train_image.get_image_labels()[i])
+ ++num_right;
+ else
+ ++num_wrong;
+ }
+ std::cout << "training num_right: " << num_right << std::endl;
+ std::cout << "training num_wrong: " << num_wrong << std::endl;
+ std::cout << "training accuracy: " << num_right / static_cast(num_right + num_wrong) << std::endl;
+
+ return stats;
+ }
+
+ template
+ NetworkType load_network(const std::string& ident)
+ {
+ NetworkType network{};
+ dlib::deserialize("mnist_network_" + ident + ".dat") >> network;
+ return network;
+ }
+
+ void run_mnist(int argc, const char** argv)
{
using namespace dlib;
- idx_file_t test_images{"../problems/mnist/t10k-images.idx3-ubyte"};
- idx_file_t test_labels{"../problems/mnist/t10k-labels.idx1-ubyte"};
- idx_file_t train_images{"../problems/mnist/train-images.idx3-ubyte"};
- idx_file_t train_labels{"../problems/mnist/train-labels.idx1-ubyte"};
-
- auto train_samples = train_images.get_dimensions()[0];
- auto test_samples = test_images.get_dimensions()[0];
-
- auto columns = train_images.get_dimensions()[1];
- auto rows = train_images.get_dimensions()[2];
-
- auto input_size = static_cast(train_images.data_size(1));
-
- image_t train_image{train_images, train_labels};
- image_t test_image{test_images, test_labels};
-
// using net_type = loss_multiclass_log<
// fc<10,
// relu>>>>>>>>>>>>>;
using net_type = loss_multiclass_log<
- fc<10,
- sig>>>>>>>;
-
- net_type network{};
-
- dnn_trainer trainer(network);
- trainer.set_learning_rate(0.01);
- trainer.set_min_learning_rate(0.00001);
- trainer.set_mini_batch_size(128);
- trainer.be_verbose();
-
- // trainer.set_synchronization_file("mnist_sync", std::chrono::seconds(20));
-
- trainer.train(train_image.get_image_data(), train_image.get_image_labels());
-
- network.clean();
- serialize("mnist_network.dat") << network;
-
- std::vector predicted_labels = network(train_image.get_image_data());
- int num_right = 0;
- int num_wrong = 0;
- // And then let's see if it classified them correctly.
- for (size_t i = 0; i < train_image.get_image_data().size(); ++i)
- {
- if (predicted_labels[i] == train_image.get_image_labels()[i])
- ++num_right;
- else
- ++num_wrong;
-
- }
- std::cout << "training num_right: " << num_right << std::endl;
- std::cout << "training num_wrong: " << num_wrong << std::endl;
- std::cout << "training accuracy: " << num_right / static_cast(num_right + num_wrong) << std::endl;
-
- predicted_labels = network(test_image.get_image_data());
- num_right = 0;
- num_wrong = 0;
- for (size_t i = 0; i < test_image.get_image_data().size(); ++i)
- {
- if (predicted_labels[i] == test_image.get_image_labels()[i])
- ++num_right;
- else
- ++num_wrong;
-
- }
- std::cout << "testing num_right: " << num_right << std::endl;
- std::cout << "testing num_wrong: " << num_wrong << std::endl;
- std::cout << "testing accuracy: " << num_right / static_cast(num_right + num_wrong) << std::endl;
+ fc<10,
+ sig>>>>>>>;
}
}
diff --git a/src/cats_and_dogs.cpp b/src/cats_and_dogs.cpp
new file mode 100644
index 0000000..e047656
--- /dev/null
+++ b/src/cats_and_dogs.cpp
@@ -0,0 +1,33 @@
+/*
+ *
+ * Copyright (C) 2024 Brett Terpstra
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program. If not, see .
+ */
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+namespace fp {
+
+void run_cat_and_dogs() {
+
+}
+
+}
\ No newline at end of file
diff --git a/src/main.cpp b/src/main.cpp
index 18fbfbe..68a339f 100644
--- a/src/main.cpp
+++ b/src/main.cpp
@@ -1,8 +1,12 @@
+#include
#include
+#include
#include
-int main()
+int main(int argc, const char** argv)
{
- fp::run_mnist();
+ fp::run_mnist(argc, argv);
+ // fp::run_cat_and_dogs();
+ return 0;
}