main
Brett 2025-01-08 19:14:47 -05:00
parent 377372842c
commit 56b54ed78e
2 changed files with 20 additions and 6 deletions

View File

@ -1,5 +1,5 @@
cmake_minimum_required(VERSION 3.25) cmake_minimum_required(VERSION 3.25)
project(COSC-4P80-Final-Project VERSION 0.0.17) project(COSC-4P80-Final-Project VERSION 0.0.18)
option(ENABLE_ADDRSAN "Enable the address sanitizer" OFF) option(ENABLE_ADDRSAN "Enable the address sanitizer" OFF)
option(ENABLE_UBSAN "Enable the ub sanitizer" OFF) option(ENABLE_UBSAN "Enable the ub sanitizer" OFF)

View File

@ -505,7 +505,7 @@ namespace fp
if (break_flag) if (break_flag)
{ {
break_flag = false; break_flag = false;
last_epoch = epochs; last_epoch = epochs + 1;
break; break;
} }
// dlib::serialize("mnist_network_" + ident + ".dat") << network; // dlib::serialize("mnist_network_" + ident + ".dat") << network;
@ -576,9 +576,13 @@ namespace fp
state >> test_stats.back(); state >> test_stats.back();
state.ignore(); state.ignore();
} }
BLT_TRACE("Restoring at run %d with epoch %ld", i, load_epoch);
BLT_TRACE("\tRestored state size %lu", stats.run_stats.size());
BLT_TRACE("\tRestored test size %lu", test_stats_size);
} }
blt::i64 last_epoch_save = -1; blt::i64 last_epoch_save = last_epoch;
for (; i < runs; i++) for (; i < runs; i++)
{ {
if (stop_flag) if (stop_flag)
@ -596,10 +600,20 @@ namespace fp
} }
catch (dlib::serialization_error&) catch (dlib::serialization_error&)
{ {
stats += train_network(local_ident, network); goto train_label;
} }
else else
stats += train_network(local_ident, network); {
train_label:
auto stat = train_network(local_ident, network);
if (last_epoch_save > 0)
{
// add in all the new epochs
auto& vec = stats.run_stats.back();
vec.epoch_stats.insert(vec.epoch_stats.begin(), stat.epoch_stats.begin(), stat.epoch_stats.end());
} else
stats += stat;
}
last_epoch_save = last_epoch; last_epoch_save = last_epoch;
last_epoch = -1; last_epoch = -1;
test_stats.push_back(test_network(network)); test_stats.push_back(test_network(network));
@ -626,7 +640,7 @@ namespace fp
state << '\n'; state << '\n';
state << test_stats.size(); state << test_stats.size();
state << '\n'; state << '\n';
for (const auto& v : test_stats) for (const auto& v : blt::iterate(test_stats).take(test_stats.size() - 1))
{ {
state << v; state << v;
state << '\n'; state << '\n';