Brett 2025-01-08 18:02:41 -05:00
parent 818c1151da
commit cacc94d937
2 changed files with 48 additions and 52 deletions

View File

@ -1,5 +1,5 @@
cmake_minimum_required(VERSION 3.25) cmake_minimum_required(VERSION 3.25)
project(COSC-4P80-Final-Project VERSION 0.0.11) project(COSC-4P80-Final-Project VERSION 0.0.12)
option(ENABLE_ADDRSAN "Enable the address sanitizer" OFF) option(ENABLE_ADDRSAN "Enable the address sanitizer" OFF)
option(ENABLE_UBSAN "Enable the ub sanitizer" OFF) option(ENABLE_UBSAN "Enable the ub sanitizer" OFF)

View File

@ -580,15 +580,22 @@ namespace fp
void run_mnist(const int argc, const char** argv) void run_mnist(const int argc, const char** argv)
{ {
binary_directory = std::filesystem::current_path(); binary_directory = std::filesystem::current_path();
blt::size_t pos = 0;
if (!blt::string::ends_with(binary_directory, '/')) if (!blt::string::ends_with(binary_directory, '/'))
{
pos = binary_directory.find_last_of('/') - 1;
binary_directory += '/'; binary_directory += '/';
python_dual_stacked_graph_program = binary_directory + "../graph.py"; }
else
pos = binary_directory.substr(0, binary_directory.size() - 1).find_last_of('/') - 1;
python_dual_stacked_graph_program = binary_directory.substr(0, pos) + "/graph.py";
BLT_TRACE(binary_directory); BLT_TRACE(binary_directory);
BLT_TRACE(python_dual_stacked_graph_program); BLT_TRACE(python_dual_stacked_graph_program);
BLT_TRACE("Running with batch size %d", batch_size); BLT_TRACE("Running with batch size %d", batch_size);
BLT_TRACE("Installing Signal Handlers"); BLT_TRACE("Installing Signal Handlers");
if (std::signal(SIGINT, [](int){ if (std::signal(SIGINT, [](int)
{
BLT_TRACE("Stopping current training"); BLT_TRACE("Stopping current training");
break_flag = true; break_flag = true;
}) == SIG_ERR) }) == SIG_ERR)
@ -623,18 +630,16 @@ namespace fp
parser.addArgument( parser.addArgument(
blt::arg_builder{"-p", "--python"}.setHelp("Only run the python scripts").setAction(blt::arg_action_t::STORE_TRUE).setDefault(false). blt::arg_builder{"-p", "--python"}.setHelp("Only run the python scripts").setAction(blt::arg_action_t::STORE_TRUE).setDefault(false).
build()); build());
parser.addArgument(blt::arg_builder{"type"}.setDefault("all").setHelp("Type of network to run [ff, dl, default: all]").build()); parser.addArgument(
blt::arg_builder{"network"}.setDefault(std::to_string(blt::system::getCurrentTimeMilliseconds())).setHelp("location of network files").
build());
auto args = parser.parse_args(argc, argv); auto args = parser.parse_args(argc, argv);
const auto type = blt::string::toLowerCase(args.get<std::string>("type"));
const auto runs = std::stoi(args.get<std::string>("runs")); const auto runs = std::stoi(args.get<std::string>("runs"));
const auto restore = args.get<bool>("restore"); const auto restore = args.get<bool>("restore");
const auto path = binary_directory + std::to_string(blt::system::getCurrentTimeMilliseconds()); auto path = binary_directory + args.get<std::string>("network");
if (type == "all")
{
auto [deep_stats, deep_tests] = run_deep_learning_tests(path, runs, restore); auto [deep_stats, deep_tests] = run_deep_learning_tests(path, runs, restore);
auto [forward_stats, forward_tests] = run_feed_forward_tests(path, runs, restore); auto [forward_stats, forward_tests] = run_feed_forward_tests(path, runs, restore);
@ -675,15 +680,6 @@ namespace fp
run_python_line_graph("Feed-Forward vs Deep Learning, Average Loss over Epochs", "epochs.png", path + "/forward_train_results.csv", run_python_line_graph("Feed-Forward vs Deep Learning, Average Loss over Epochs", "epochs.png", path + "/forward_train_results.csv",
path + "/deep_train_results.csv", average_forward_size, average_deep_size); path + "/deep_train_results.csv", average_forward_size, average_deep_size);
}
else if (type == "ff")
{
run_feed_forward_tests(path, runs, restore);
}
else if (type == "df")
{
run_deep_learning_tests(path, runs, restore);
}
// net_type_dl test_net; // net_type_dl test_net;
// const auto stats = train_network("dl_nn", test_net); // const auto stats = train_network("dl_nn", test_net);