diff --git a/CMakeLists.txt b/CMakeLists.txt index 0b5fcf8..7cc37d7 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,5 +1,5 @@ cmake_minimum_required(VERSION 3.25) -project(COSC-4P80-Final-Project VERSION 0.0.17) +project(COSC-4P80-Final-Project VERSION 0.0.18) option(ENABLE_ADDRSAN "Enable the address sanitizer" OFF) option(ENABLE_UBSAN "Enable the ub sanitizer" OFF) diff --git a/src/MNIST.cpp b/src/MNIST.cpp index 8e7047c..1e30af2 100644 --- a/src/MNIST.cpp +++ b/src/MNIST.cpp @@ -505,7 +505,7 @@ namespace fp if (break_flag) { break_flag = false; - last_epoch = epochs; + last_epoch = epochs + 1; break; } // dlib::serialize("mnist_network_" + ident + ".dat") << network; @@ -576,9 +576,13 @@ namespace fp state >> test_stats.back(); state.ignore(); } + + BLT_TRACE("Restoring at run %d with epoch %ld", i, load_epoch); + BLT_TRACE("\tRestored state size %lu", stats.run_stats.size()); + BLT_TRACE("\tRestored test size %lu", test_stats_size); } - blt::i64 last_epoch_save = -1; + blt::i64 last_epoch_save = last_epoch; for (; i < runs; i++) { if (stop_flag) @@ -596,10 +600,20 @@ namespace fp } catch (dlib::serialization_error&) { - stats += train_network(local_ident, network); + goto train_label; } else - stats += train_network(local_ident, network); + { + train_label: + auto stat = train_network(local_ident, network); + if (last_epoch_save > 0) + { + // add in all the new epochs + auto& vec = stats.run_stats.back(); + vec.epoch_stats.insert(vec.epoch_stats.begin(), stat.epoch_stats.begin(), stat.epoch_stats.end()); + } else + stats += stat; + } last_epoch_save = last_epoch; last_epoch = -1; test_stats.push_back(test_network(network)); @@ -626,7 +640,7 @@ namespace fp state << '\n'; state << test_stats.size(); state << '\n'; - for (const auto& v : test_stats) + for (const auto& v : blt::iterate(test_stats).take(test_stats.size() - 1)) { state << v; state << '\n';