diff --git a/.idea/vcs.xml b/.idea/vcs.xml index b619dc4..309514d 100644 --- a/.idea/vcs.xml +++ b/.idea/vcs.xml @@ -2,18 +2,6 @@ - - - - - - - - - - - - diff --git a/CMakeLists.txt b/CMakeLists.txt index ab8b351..d2cd27c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,5 +1,5 @@ cmake_minimum_required(VERSION 3.25) -project(COSC-4P80-Final-Project VERSION 0.0.7) +project(COSC-4P80-Final-Project VERSION 0.0.8) option(ENABLE_ADDRSAN "Enable the address sanitizer" OFF) option(ENABLE_UBSAN "Enable the ub sanitizer" OFF) diff --git a/src/MNIST.cpp b/src/MNIST.cpp index 3813212..0d0da2f 100644 --- a/src/MNIST.cpp +++ b/src/MNIST.cpp @@ -30,6 +30,8 @@ namespace fp { + constexpr blt::i64 batch_size = 512; + std::string binary_directory; std::string python_dual_stacked_graph_program; @@ -360,7 +362,7 @@ namespace fp } }; - template + template batch_stats_t test_batch(NetworkType& network, image_t::data_iterator begin, const image_t::data_iterator end, image_t::label_iterator lbegin) { batch_stats_t stats{}; @@ -421,7 +423,7 @@ namespace fp dlib::dnn_trainer trainer(network); trainer.set_learning_rate(0.01); trainer.set_min_learning_rate(0.00001); - trainer.set_mini_batch_size(128); + trainer.set_mini_batch_size(batch_size); trainer.be_verbose(); trainer.set_synchronization_file("mnist_sync_" + ident, std::chrono::seconds(20));