diff --git a/.idea/workspace (conflicted copy 2025-01-08 181720).xml b/.idea/workspace (conflicted copy 2025-01-08 181720).xml
new file mode 100644
index 0000000..8ae7f97
--- /dev/null
+++ b/.idea/workspace (conflicted copy 2025-01-08 181720).xml
@@ -0,0 +1,268 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ {
+ "useNewFormat": true
+}
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ {
+ "associatedIndex": 0
+}
+
+
+
+
+
+ {
+ "keyToString": {
+ "CMake Application.COSC-4P80-Final-Project.executor": "Run",
+ "RunOnceActivity.RadMigrateCodeStyle": "true",
+ "RunOnceActivity.ShowReadmeOnStart": "true",
+ "RunOnceActivity.cidr.known.project.marker": "true",
+ "RunOnceActivity.readMode.enableVisualFormatting": "true",
+ "RunOnceActivity.west.config.association.type.startup.service": "true",
+ "SHARE_PROJECT_CONFIGURATION_FILES": "true",
+ "cf.first.check.clang-format": "false",
+ "cidr.known.project.marker": "true",
+ "git-widget-placeholder": "main",
+ "last_opened_file_path": "/home/brett/Documents/Brock/CS 4P80/COSC-4P80-Final-Project",
+ "node.js.detected.package.eslint": "true",
+ "node.js.detected.package.tslint": "true",
+ "node.js.selected.package.eslint": "(autodetect)",
+ "node.js.selected.package.tslint": "(autodetect)",
+ "nodejs_package_manager_path": "npm",
+ "settings.editor.selected.configurable": "preferences.lookFeel",
+ "vue.rearranger.settings.migration": "true"
+ }
+}
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ PDFLATEX
+
+
+ OKULAR
+
+
+
+
+
+ {projectDir}/out
+ {projectDir}/auxil
+ false
+ PDF
+ TEXLIVE
+ false
+ []
+ []
+
+
+
+
+
+ MAKEINDEX
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ 1733702642308
+
+
+ 1733702642308
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 37ae7a1..937e1e4 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -1,5 +1,5 @@
cmake_minimum_required(VERSION 3.25)
-project(COSC-4P80-Final-Project VERSION 0.0.12)
+project(COSC-4P80-Final-Project VERSION 0.0.13)
option(ENABLE_ADDRSAN "Enable the address sanitizer" OFF)
option(ENABLE_UBSAN "Enable the ub sanitizer" OFF)
diff --git a/src/MNIST.cpp b/src/MNIST.cpp
index e3f7534..65d3b69 100644
--- a/src/MNIST.cpp
+++ b/src/MNIST.cpp
@@ -38,6 +38,7 @@ namespace fp
std::atomic_bool break_flag = false;
std::atomic_bool stop_flag = false;
std::atomic_bool learn_flag = false;
+ std::atomic_int64_t last_epoch = -1;
void run_python_line_graph(const std::string& title, const std::string& output_file, const std::string& csv1, const std::string& csv2,
const blt::size_t pos_forward, const blt::size_t pos_deep)
@@ -364,6 +365,27 @@ namespace fp
v /= run_stats.size();
return stats;
}
+
+ friend std::ofstream& operator<<(std::ofstream& file, const network_average_stats_t& stats)
+ {
+ file << stats.run_stats.size();
+ for (const auto& v : stats.run_stats)
+ file << v << "---\n";
+ return file;
+ }
+
+ friend std::ifstream& operator>>(std::ifstream& file, network_average_stats_t& stats)
+ {
+ blt::size_t size;
+ file >> size;
+ for (blt::size_t i = 0; i < size; i++)
+ {
+ stats.run_stats.emplace_back();
+ file >> stats.run_stats.back();
+ file.ignore(4);
+ }
+ return file;
+ }
};
template
@@ -435,6 +457,8 @@ namespace fp
trainer.set_synchronization_file("mnist_sync_" + ident, std::chrono::seconds(20));
blt::size_t epochs = 0;
+ if (last_epoch > 0)
+ epochs = static_cast(last_epoch);
blt::ptrdiff_t epoch_pos = 0;
for (; epochs < trainer.get_max_num_epochs() && trainer.get_learning_rate() >= trainer.get_min_learning_rate(); epochs++)
{
@@ -477,6 +501,7 @@ namespace fp
if (break_flag)
{
break_flag = false;
+ last_epoch = epochs;
break;
}
// dlib::serialize("mnist_network_" + ident + ".dat") << network;
@@ -520,7 +545,31 @@ namespace fp
network_average_stats_t stats{};
std::vector test_stats;
- for (blt::i32 i = 0; i < runs; i++)
+ blt::i32 i = 0;
+ if (std::filesystem::exists(path + "/state.bin"))
+ {
+ std::ifstream state{path + "/state.bin", std::ios::binary | std::ios::in};
+ if (!state.is_open())
+ {
+ BLT_ERROR("Failed to open state file!");
+ std::exit(-1);
+ }
+
+ state >> i;
+ blt::i64 load_epoch = 0;
+ state >> load_epoch;
+ last_epoch = load_epoch;
+ state >> stats;
+ blt::size_t test_stats_size = 0;
+ state >> test_stats_size;
+ for (blt::size_t _ = 0; _ < test_stats_size; _++)
+ {
+ test_stats.emplace_back();
+ state >> test_stats.back();
+ }
+ }
+
+ for (; i < runs; i++)
{
if (stop_flag)
break;
@@ -546,6 +595,20 @@ namespace fp
average += v;
average /= runs;
+ std::ofstream state{path + "/state.bin", std::ios::binary | std::ios::out};
+ if (!state.is_open())
+ {
+ BLT_ERROR("Failed to open state file!");
+ std::exit(-1);
+ }
+
+ state << i;
+ state << last_epoch.load(std::memory_order_relaxed);
+ state << stats;
+ state << test_stats.size();
+ for (const auto& v : test_stats)
+ state << v;
+
return {stats, average};
}
@@ -583,11 +646,11 @@ namespace fp
blt::size_t pos = 0;
if (!blt::string::ends_with(binary_directory, '/'))
{
- pos = binary_directory.find_last_of('/') - 1;
+ pos = binary_directory.find_last_of('/');
binary_directory += '/';
}
else
- pos = binary_directory.substr(0, binary_directory.size() - 1).find_last_of('/') - 1;
+ pos = binary_directory.substr(0, binary_directory.size() - 1).find_last_of('/');
python_dual_stacked_graph_program = binary_directory.substr(0, pos) + "/graph.py";
BLT_TRACE(binary_directory);
BLT_TRACE(python_dual_stacked_graph_program);