Compare commits

..

No commits in common. "10dc129e1876711bab47bc552ffb36d149952336" and "cc7624415612166b3d0e5b6789fb032005758615" have entirely different histories.

3 changed files with 15 additions and 51 deletions

View File

@ -1,5 +1,5 @@
cmake_minimum_required(VERSION 3.25)
project(COSC-4P80-Final-Project VERSION 0.0.23)
project(COSC-4P80-Final-Project VERSION 0.0.15)
option(ENABLE_ADDRSAN "Enable the address sanitizer" OFF)
option(ENABLE_UBSAN "Enable the ub sanitizer" OFF)

View File

@ -21,6 +21,9 @@ def plot_stacked_graph(title, output, csv_file1, csv_file2, position, position2)
ax.plot(x1, y1, label=f"{csv_file1}")
ax.plot(x2, y2, label=f"{csv_file2}")
ax.fill_between(x1, y1, alpha=0.5)
ax.fill_between(x2, y2, alpha=0.5)
if position < 2 ** 32:
ax.axvline(x=position, color='red', linestyle='--')
ax.text(position, ax.get_ylim()[1] * 0.95, f"Feed-forward average # of epochs", color='red', fontsize=10, ha='right', va='top', backgroundcolor='white')
@ -31,13 +34,13 @@ def plot_stacked_graph(title, output, csv_file1, csv_file2, position, position2)
ax.set_xlabel(x1_label)
ax.set_ylabel(y1_label)
# ax.legend()
ax.legend()
ax.set_title(title)
plt.savefig(output)
if __name__ == "__main__":
if len(sys.argv) != 7:
if len(sys.argv) != 5:
print("Usage: python script.py <title> <output_file> <csv_file1> <csv_file2> <position_feed_forward> <position_deep>")
sys.exit(1)
@ -48,4 +51,4 @@ if __name__ == "__main__":
position = sys.argv[5]
position2 = sys.argv[6]
plot_stacked_graph(title, output, csv_file1, csv_file2, int(position), int(position2))
plot_stacked_graph(title, output, csv_file1, csv_file2, position, position2)

View File

@ -312,7 +312,6 @@ namespace fp
friend std::ofstream& operator<<(std::ofstream& file, const network_stats_t& stats)
{
file << stats.epoch_stats.size();
file << '\n';
for (const auto& v : stats.epoch_stats)
file << v << "\n";
return file;
@ -322,7 +321,6 @@ namespace fp
{
blt::size_t size;
file >> size;
file.ignore();
for (blt::size_t i = 0; i < size; i++)
{
stats.epoch_stats.emplace_back();
@ -371,7 +369,6 @@ namespace fp
friend std::ofstream& operator<<(std::ofstream& file, const network_average_stats_t& stats)
{
file << stats.run_stats.size();
file << '\n';
for (const auto& v : stats.run_stats)
file << v << "---\n";
return file;
@ -381,7 +378,6 @@ namespace fp
{
blt::size_t size;
file >> size;
file.ignore();
for (blt::size_t i = 0; i < size; i++)
{
stats.run_stats.emplace_back();
@ -505,7 +501,7 @@ namespace fp
if (break_flag)
{
break_flag = false;
last_epoch = epochs + 1;
last_epoch = epochs;
break;
}
// dlib::serialize("mnist_network_" + ident + ".dat") << network;
@ -560,29 +556,20 @@ namespace fp
}
state >> i;
state.ignore();
blt::i64 load_epoch = 0;
state >> load_epoch;
state.ignore();
last_epoch = load_epoch;
state >> stats;
state.ignore();
blt::size_t test_stats_size = 0;
state >> test_stats_size;
state.ignore();
for (blt::size_t _ = 0; _ < test_stats_size; _++)
{
test_stats.emplace_back();
state >> test_stats.back();
state.ignore();
}
}
BLT_TRACE("Restoring at run %d with epoch %ld", i, load_epoch);
BLT_TRACE("\tRestored state size %lu", stats.run_stats.size());
BLT_TRACE("\tRestored test size %lu", test_stats_size);
}
blt::i64 last_epoch_save = last_epoch;
blt::i64 last_epoch_save = -1;
for (; i < runs; i++)
{
if (stop_flag)
@ -600,20 +587,10 @@ namespace fp
}
catch (dlib::serialization_error&)
{
goto train_label;
stats += train_network(local_ident, network);
}
else
{
train_label:
auto stat = train_network(local_ident, network);
if (last_epoch_save > 0)
{
// add in all the new epochs
auto& vec = stats.run_stats.back();
vec.epoch_stats.insert(vec.epoch_stats.end(), stat.epoch_stats.begin(), stat.epoch_stats.end());
} else
stats += stat;
}
stats += train_network(local_ident, network);
last_epoch_save = last_epoch;
last_epoch = -1;
test_stats.push_back(test_network(network));
@ -631,27 +608,12 @@ namespace fp
std::exit(-1);
}
// user can skip this if required.
state << std::max(i - 1, 0);
state << '\n';
state << i;
state << last_epoch_save;
state << '\n';
state << stats;
state << '\n';
blt::i32 remove = 0;
if (stop_flag)
remove = 1;
state << static_cast<blt::size_t>(std::max(static_cast<blt::i64>(test_stats.size()) - remove, 0l));
state << '\n';
if (!test_stats.empty())
{
// the last test stat will be recalculated on restore. keeping it is an error.
for (const auto& v : blt::iterate(test_stats).take(test_stats.size() - remove))
{
state << test_stats.size();
for (const auto& v : test_stats)
state << v;
state << '\n';
}
}
return {stats, average};
}
@ -794,7 +756,6 @@ namespace fp
average_epochs << average_forward_size << "," << average_deep_size << std::endl;
}
BLT_INFO("Running python!");
run_python_line_graph("Feed-Forward vs Deep Learning, Average Loss over Epochs", "epochs.png", path + "/forward_train_results.csv",
path + "/deep_train_results.csv", average_forward_size, average_deep_size);