diff --git a/CMakeLists.txt b/CMakeLists.txt index 937e1e4..2f4467b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,5 +1,5 @@ cmake_minimum_required(VERSION 3.25) -project(COSC-4P80-Final-Project VERSION 0.0.13) +project(COSC-4P80-Final-Project VERSION 0.0.14) option(ENABLE_ADDRSAN "Enable the address sanitizer" OFF) option(ENABLE_UBSAN "Enable the ub sanitizer" OFF) diff --git a/src/MNIST.cpp b/src/MNIST.cpp index 65d3b69..09d5520 100644 --- a/src/MNIST.cpp +++ b/src/MNIST.cpp @@ -569,10 +569,14 @@ namespace fp } } + blt::i64 last_epoch_save = -1; for (; i < runs; i++) { if (stop_flag) + { + BLT_TRACE("Stopping!"); break; + } BLT_TRACE("Starting run %d", i); auto local_ident = ident + std::to_string(i); NetworkType network{}; @@ -587,6 +591,8 @@ namespace fp } else stats += train_network(local_ident, network); + last_epoch_save = last_epoch; + last_epoch = -1; test_stats.push_back(test_network(network)); } @@ -603,7 +609,7 @@ namespace fp } state << i; - state << last_epoch.load(std::memory_order_relaxed); + state << last_epoch_save; state << stats; state << test_stats.size(); for (const auto& v : test_stats) @@ -652,14 +658,14 @@ namespace fp else pos = binary_directory.substr(0, binary_directory.size() - 1).find_last_of('/'); python_dual_stacked_graph_program = binary_directory.substr(0, pos) + "/graph.py"; - BLT_TRACE(binary_directory); - BLT_TRACE(python_dual_stacked_graph_program); - BLT_TRACE("Running with batch size %d", batch_size); + BLT_DEBUG(binary_directory); + BLT_DEBUG(python_dual_stacked_graph_program); + BLT_DEBUG("Running with batch size %d", batch_size); - BLT_TRACE("Installing Signal Handlers"); + BLT_DEBUG("Installing Signal Handlers"); if (std::signal(SIGINT, [](int) { - BLT_TRACE("Stopping current training"); + BLT_INFO("Stopping current training"); break_flag = true; }) == SIG_ERR) { @@ -667,7 +673,7 @@ namespace fp } if (std::signal(SIGQUIT, [](int) { - BLT_TRACE("Exiting Program"); + BLT_INFO("Exiting Program"); stop_flag = true; break_flag = true; }) == SIG_ERR) @@ -676,12 +682,21 @@ namespace fp } if (std::signal(SIGUSR1, [](int) { - BLT_TRACE("Decreasing Learn Rate for current training"); + BLT_INFO("Decreasing Learn Rate for current training"); learn_flag = true; }) == SIG_ERR) { BLT_ERROR("Failed to replace SIGUSR1"); } + if (std::signal(SIGUSR2, [](int) + { + BLT_INFO("Exiting Program"); + stop_flag = true; + break_flag = true; + }) == SIG_ERR) + { + BLT_ERROR("Failed to replace SIGUSR2"); + } using namespace dlib;