diff --git a/CMakeLists.txt b/CMakeLists.txt index 0fa0a52..771dcc8 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,5 +1,5 @@ cmake_minimum_required(VERSION 3.25) -project(COSC-4P80-Final-Project VERSION 0.0.15) +project(COSC-4P80-Final-Project VERSION 0.0.16) option(ENABLE_ADDRSAN "Enable the address sanitizer" OFF) option(ENABLE_UBSAN "Enable the ub sanitizer" OFF) diff --git a/graph.py b/graph.py index 6421a2d..fee20a3 100644 --- a/graph.py +++ b/graph.py @@ -40,7 +40,7 @@ def plot_stacked_graph(title, output, csv_file1, csv_file2, position, position2) plt.savefig(output) if __name__ == "__main__": - if len(sys.argv) != 5: + if len(sys.argv) != 7: print("Usage: python script.py <output_file> <csv_file1> <csv_file2> <position_feed_forward> <position_deep>") sys.exit(1) diff --git a/src/MNIST.cpp b/src/MNIST.cpp index 1d61cdc..63833a0 100644 --- a/src/MNIST.cpp +++ b/src/MNIST.cpp @@ -312,6 +312,7 @@ namespace fp friend std::ofstream& operator<<(std::ofstream& file, const network_stats_t& stats) { file << stats.epoch_stats.size(); + file << '\n'; for (const auto& v : stats.epoch_stats) file << v << "\n"; return file; @@ -321,6 +322,7 @@ namespace fp { blt::size_t size; file >> size; + file.ignore(); for (blt::size_t i = 0; i < size; i++) { stats.epoch_stats.emplace_back(); @@ -369,6 +371,7 @@ namespace fp friend std::ofstream& operator<<(std::ofstream& file, const network_average_stats_t& stats) { file << stats.run_stats.size(); + file << '\n'; for (const auto& v : stats.run_stats) file << v << "---\n"; return file; @@ -378,6 +381,7 @@ namespace fp { blt::size_t size; file >> size; + file.ignore(); for (blt::size_t i = 0; i < size; i++) { stats.run_stats.emplace_back(); @@ -556,16 +560,21 @@ namespace fp } state >> i; + state.ignore(); blt::i64 load_epoch = 0; state >> load_epoch; + state.ignore(); last_epoch = load_epoch; state >> stats; + state.ignore(); blt::size_t test_stats_size = 0; state >> test_stats_size; + state.ignore(); for (blt::size_t _ = 0; _ < test_stats_size; _++) { test_stats.emplace_back(); state >> test_stats.back(); + state.ignore(); } } @@ -609,11 +618,18 @@ namespace fp } state << i; + state << '\n'; state << last_epoch_save; + state << '\n'; state << stats; + state << '\n'; state << test_stats.size(); + state << '\n'; for (const auto& v : test_stats) + { state << v; + state << '\n'; + } return {stats, average}; }