/* * * Copyright (C) 2024 Brett Terpstra * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation, either version 3 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with this program. If not, see . */ #include #include #include #include #include #include #include #include namespace fp { class idx_file_t { template using mk_v = std::vector; using vec_t = std::variant, mk_v, mk_v, mk_v, mk_v, mk_v>; public: explicit idx_file_t(const std::string& path) { std::ifstream file{path, std::ios::in | std::ios::binary}; using char_type = std::ifstream::char_type; char_type magic_arr[4]; file.read(magic_arr, 4); BLT_ASSERT(magic_arr[0] == 0 && magic_arr[1] == 0); blt::u8 dims = magic_arr[3]; blt::size_t total_size = 1; for (blt::i32 i = 0; i < dims; i++) { char_type dim_arr[4]; file.read(dim_arr, 4); blt::u32 dim; blt::mem::fromBytes(dim_arr, dim); dimensions.push_back(dim); total_size *= dim; } switch (magic_arr[2]) { // unsigned char case 0x08: data = mk_v{}; read_data(file, total_size); break; // signed char case 0x09: data = mk_v{}; read_data(file, total_size); break; // short case 0x0B: data = mk_v{}; read_data(file, total_size); reverse_data(); break; // int case 0x0C: data = mk_v{}; read_data(file, total_size); reverse_data(); break; // float case 0x0D: data = mk_v{}; read_data(file, total_size); reverse_data(); break; // double case 0x0E: data = mk_v{}; read_data(file, total_size); reverse_data(); break; default: BLT_ERROR("Unspported idx file type!"); } if (file.eof()) { BLT_ERROR("EOF reached. It's unlikely your file was read correctly!"); } } template [[nodiscard]] const std::vector& get_data_as() const { return std::get>(data); } template std::vector> get_as_spans() const { std::vector> spans; blt::size_t total_size = data_size(1); for (blt::size_t i = 0; i < dimensions[0]; i++) { auto& array = std::get>(data); spans.push_back({&array[i * total_size], total_size}); } return spans; } [[nodiscard]] const std::vector& get_dimensions() const { return dimensions; } [[nodiscard]] blt::size_t data_size(const blt::size_t starting_dimension = 0) const { blt::size_t total_size = 1; for (const auto d : blt::iterate(dimensions).skip(starting_dimension)) total_size *= d; return total_size; } private: template void read_data(std::ifstream& file, blt::size_t total_size) { auto& array = std::get>(data); array.resize(total_size); file.read(reinterpret_cast(array.data()), static_cast(total_size) * sizeof(T)); } template void reverse_data() { auto& array = std::get>(data); for (auto& v : array) blt::mem::reverse(v); } std::vector dimensions; vec_t data; }; class image_t { public: static constexpr blt::u32 target_size = 10; using data_iterator = std::vector>::const_iterator; using label_iterator = std::vector::const_iterator; image_t(const idx_file_t& image_data, const idx_file_t& label_data): samples(image_data.get_dimensions()[0]), input_size(image_data.data_size(1)) { BLT_ASSERT_MSG(samples == label_data.get_dimensions()[0], ("Mismatch in data sample sizes! " + std::to_string(samples) + " vs " + std::to_string(label_data.get_dimensions()[0])). c_str()); auto& image_array = image_data.get_data_as(); auto& label_array = label_data.get_data_as(); for (const auto label : label_array) image_labels.push_back(label); const auto row_length = image_data.get_dimensions()[2]; const auto number_of_rows = image_data.get_dimensions()[1]; for (blt::u32 i = 0; i < samples; i++) { dlib::matrix mat(number_of_rows, row_length); for (blt::u32 y = 0; y < number_of_rows; y++) { for (blt::u32 x = 0; x < row_length; x++) { mat(x, y) = image_array[i * input_size + y * row_length + x]; } } data.push_back(mat); } } [[nodiscard]] const std::vector>& get_image_data() const { return data; } [[nodiscard]] const std::vector& get_image_labels() const { return image_labels; } private: blt::u32 samples; blt::u32 input_size; std::vector> data; std::vector image_labels; }; struct batch_stats_t { blt::u64 batch_size; }; struct network_stats_t { }; template batch_stats_t test_batch(NetworkType& network, image_t::data_iterator begin, image_t::data_iterator end, image_t::label_iterator lbegin) { batch_stats_t stats; return stats; } template void test_network(NetworkType& network) { const idx_file_t test_images{"../problems/mnist/t10k-images.idx3-ubyte"}; const idx_file_t test_labels{"../problems/mnist/t10k-labels.idx1-ubyte"}; const auto test_samples = test_images.get_dimensions()[0]; const image_t test_image{test_images, test_labels}; const auto predicted_labels = network(test_image.get_image_data()); int num_right = 0; int num_wrong = 0; for (size_t i = 0; i < test_image.get_image_data().size(); ++i) { if (predicted_labels[i] == test_image.get_image_labels()[i]) ++num_right; else ++num_wrong; } std::cout << "testing num_right: " << num_right << std::endl; std::cout << "testing num_wrong: " << num_wrong << std::endl; std::cout << "testing accuracy: " << num_right / static_cast(num_right + num_wrong) << std::endl; } template network_stats_t train_network(const std::string& ident, NetworkType& network) { const idx_file_t train_images{"../problems/mnist/train-images.idx3-ubyte"}; const idx_file_t train_labels{"../problems/mnist/train-labels.idx1-ubyte"}; const image_t train_image{train_images, train_labels}; network_stats_t stats; dlib::dnn_trainer trainer(network); trainer.set_learning_rate(0.01); trainer.set_min_learning_rate(0.00001); trainer.set_mini_batch_size(128); trainer.be_verbose(); trainer.set_synchronization_file("mnist_sync_" + ident, std::chrono::seconds(20)); blt::size_t epochs = 0; blt::ptrdiff_t epoch_pos = 0; for (; epochs < trainer.getmax_epochs() && trainer.get_learning_rate() >= trainer.get_min_learning_rate(); epochs++) { auto& data = train_image.get_image_data(); auto& labels = train_image.get_image_labels(); for (; epoch_pos < data.size() && trainer.get_learning_rate() >= trainer.get_min_learning_rate(); epoch_pos += trainer. get_mini_batch_size()) { auto begin = epoch_pos; auto end = std::min(epoch_pos + trainer.get_mini_batch_size(), data.size()); if (end - begin <= 0) break; trainer.train_one_step(train_image.get_image_data().begin() + begin, data.begin() + end, labels.begin() + begin); } epoch_pos = 0; trainer.wait_for_thread_to_pause(); } // trainer.train(train_image.get_image_data(), train_image.get_image_labels()); network.clean(); dlib::serialize("mnist_network_" + ident + ".dat") << network; const std::vector predicted_labels = network(train_image.get_image_data()); int num_right = 0; int num_wrong = 0; // And then let's see if it classified them correctly. for (size_t i = 0; i < train_image.get_image_data().size(); ++i) { if (predicted_labels[i] == train_image.get_image_labels()[i]) ++num_right; else ++num_wrong; } std::cout << "training num_right: " << num_right << std::endl; std::cout << "training num_wrong: " << num_wrong << std::endl; std::cout << "training accuracy: " << num_right / static_cast(num_right + num_wrong) << std::endl; return stats; } template NetworkType load_network(const std::string& ident) { NetworkType network{}; dlib::deserialize("mnist_network_" + ident + ".dat") >> network; return network; } void run_mnist(int argc, const char** argv) { using namespace dlib; // using net_type = loss_multiclass_log< // fc<10, // relu>>>>>>>>>>>>>; using net_type = loss_multiclass_log< fc<10, sig>>>>>>>; } }