more state

main
Brett 2025-01-08 19:03:28 -05:00
parent cc76244156
commit cf64816d2a
3 changed files with 18 additions and 2 deletions

View File

@ -1,5 +1,5 @@
cmake_minimum_required(VERSION 3.25) cmake_minimum_required(VERSION 3.25)
project(COSC-4P80-Final-Project VERSION 0.0.15) project(COSC-4P80-Final-Project VERSION 0.0.16)
option(ENABLE_ADDRSAN "Enable the address sanitizer" OFF) option(ENABLE_ADDRSAN "Enable the address sanitizer" OFF)
option(ENABLE_UBSAN "Enable the ub sanitizer" OFF) option(ENABLE_UBSAN "Enable the ub sanitizer" OFF)

View File

@ -40,7 +40,7 @@ def plot_stacked_graph(title, output, csv_file1, csv_file2, position, position2)
plt.savefig(output) plt.savefig(output)
if __name__ == "__main__": if __name__ == "__main__":
if len(sys.argv) != 5: if len(sys.argv) != 7:
print("Usage: python script.py <title> <output_file> <csv_file1> <csv_file2> <position_feed_forward> <position_deep>") print("Usage: python script.py <title> <output_file> <csv_file1> <csv_file2> <position_feed_forward> <position_deep>")
sys.exit(1) sys.exit(1)

View File

@ -312,6 +312,7 @@ namespace fp
friend std::ofstream& operator<<(std::ofstream& file, const network_stats_t& stats) friend std::ofstream& operator<<(std::ofstream& file, const network_stats_t& stats)
{ {
file << stats.epoch_stats.size(); file << stats.epoch_stats.size();
file << '\n';
for (const auto& v : stats.epoch_stats) for (const auto& v : stats.epoch_stats)
file << v << "\n"; file << v << "\n";
return file; return file;
@ -321,6 +322,7 @@ namespace fp
{ {
blt::size_t size; blt::size_t size;
file >> size; file >> size;
file.ignore();
for (blt::size_t i = 0; i < size; i++) for (blt::size_t i = 0; i < size; i++)
{ {
stats.epoch_stats.emplace_back(); stats.epoch_stats.emplace_back();
@ -369,6 +371,7 @@ namespace fp
friend std::ofstream& operator<<(std::ofstream& file, const network_average_stats_t& stats) friend std::ofstream& operator<<(std::ofstream& file, const network_average_stats_t& stats)
{ {
file << stats.run_stats.size(); file << stats.run_stats.size();
file << '\n';
for (const auto& v : stats.run_stats) for (const auto& v : stats.run_stats)
file << v << "---\n"; file << v << "---\n";
return file; return file;
@ -378,6 +381,7 @@ namespace fp
{ {
blt::size_t size; blt::size_t size;
file >> size; file >> size;
file.ignore();
for (blt::size_t i = 0; i < size; i++) for (blt::size_t i = 0; i < size; i++)
{ {
stats.run_stats.emplace_back(); stats.run_stats.emplace_back();
@ -556,16 +560,21 @@ namespace fp
} }
state >> i; state >> i;
state.ignore();
blt::i64 load_epoch = 0; blt::i64 load_epoch = 0;
state >> load_epoch; state >> load_epoch;
state.ignore();
last_epoch = load_epoch; last_epoch = load_epoch;
state >> stats; state >> stats;
state.ignore();
blt::size_t test_stats_size = 0; blt::size_t test_stats_size = 0;
state >> test_stats_size; state >> test_stats_size;
state.ignore();
for (blt::size_t _ = 0; _ < test_stats_size; _++) for (blt::size_t _ = 0; _ < test_stats_size; _++)
{ {
test_stats.emplace_back(); test_stats.emplace_back();
state >> test_stats.back(); state >> test_stats.back();
state.ignore();
} }
} }
@ -609,11 +618,18 @@ namespace fp
} }
state << i; state << i;
state << '\n';
state << last_epoch_save; state << last_epoch_save;
state << '\n';
state << stats; state << stats;
state << '\n';
state << test_stats.size(); state << test_stats.size();
state << '\n';
for (const auto& v : test_stats) for (const auto& v : test_stats)
{
state << v; state << v;
state << '\n';
}
return {stats, average}; return {stats, average};
} }