main
Brett 2025-01-08 18:44:20 -05:00
parent 9197bfdc34
commit 479693bae0
2 changed files with 24 additions and 9 deletions

View File

@ -1,5 +1,5 @@
cmake_minimum_required(VERSION 3.25) cmake_minimum_required(VERSION 3.25)
project(COSC-4P80-Final-Project VERSION 0.0.13) project(COSC-4P80-Final-Project VERSION 0.0.14)
option(ENABLE_ADDRSAN "Enable the address sanitizer" OFF) option(ENABLE_ADDRSAN "Enable the address sanitizer" OFF)
option(ENABLE_UBSAN "Enable the ub sanitizer" OFF) option(ENABLE_UBSAN "Enable the ub sanitizer" OFF)

View File

@ -569,10 +569,14 @@ namespace fp
} }
} }
blt::i64 last_epoch_save = -1;
for (; i < runs; i++) for (; i < runs; i++)
{ {
if (stop_flag) if (stop_flag)
{
BLT_TRACE("Stopping!");
break; break;
}
BLT_TRACE("Starting run %d", i); BLT_TRACE("Starting run %d", i);
auto local_ident = ident + std::to_string(i); auto local_ident = ident + std::to_string(i);
NetworkType network{}; NetworkType network{};
@ -587,6 +591,8 @@ namespace fp
} }
else else
stats += train_network(local_ident, network); stats += train_network(local_ident, network);
last_epoch_save = last_epoch;
last_epoch = -1;
test_stats.push_back(test_network(network)); test_stats.push_back(test_network(network));
} }
@ -603,7 +609,7 @@ namespace fp
} }
state << i; state << i;
state << last_epoch.load(std::memory_order_relaxed); state << last_epoch_save;
state << stats; state << stats;
state << test_stats.size(); state << test_stats.size();
for (const auto& v : test_stats) for (const auto& v : test_stats)
@ -652,14 +658,14 @@ namespace fp
else else
pos = binary_directory.substr(0, binary_directory.size() - 1).find_last_of('/'); pos = binary_directory.substr(0, binary_directory.size() - 1).find_last_of('/');
python_dual_stacked_graph_program = binary_directory.substr(0, pos) + "/graph.py"; python_dual_stacked_graph_program = binary_directory.substr(0, pos) + "/graph.py";
BLT_TRACE(binary_directory); BLT_DEBUG(binary_directory);
BLT_TRACE(python_dual_stacked_graph_program); BLT_DEBUG(python_dual_stacked_graph_program);
BLT_TRACE("Running with batch size %d", batch_size); BLT_DEBUG("Running with batch size %d", batch_size);
BLT_TRACE("Installing Signal Handlers"); BLT_DEBUG("Installing Signal Handlers");
if (std::signal(SIGINT, [](int) if (std::signal(SIGINT, [](int)
{ {
BLT_TRACE("Stopping current training"); BLT_INFO("Stopping current training");
break_flag = true; break_flag = true;
}) == SIG_ERR) }) == SIG_ERR)
{ {
@ -667,7 +673,7 @@ namespace fp
} }
if (std::signal(SIGQUIT, [](int) if (std::signal(SIGQUIT, [](int)
{ {
BLT_TRACE("Exiting Program"); BLT_INFO("Exiting Program");
stop_flag = true; stop_flag = true;
break_flag = true; break_flag = true;
}) == SIG_ERR) }) == SIG_ERR)
@ -676,12 +682,21 @@ namespace fp
} }
if (std::signal(SIGUSR1, [](int) if (std::signal(SIGUSR1, [](int)
{ {
BLT_TRACE("Decreasing Learn Rate for current training"); BLT_INFO("Decreasing Learn Rate for current training");
learn_flag = true; learn_flag = true;
}) == SIG_ERR) }) == SIG_ERR)
{ {
BLT_ERROR("Failed to replace SIGUSR1"); BLT_ERROR("Failed to replace SIGUSR1");
} }
if (std::signal(SIGUSR2, [](int)
{
BLT_INFO("Exiting Program");
stop_flag = true;
break_flag = true;
}) == SIG_ERR)
{
BLT_ERROR("Failed to replace SIGUSR2");
}
using namespace dlib; using namespace dlib;