Compare commits
8 Commits
cc76244156 ... 10dc129e18

Author | SHA1 | Date |
---|---|---|
 | 10dc129e18 | |
 | 961fcfd714 | |
 | 22befde2f5 | |
 | 830883930b | |
 | f9d57b3579 | |
 | 56b54ed78e | |
 | 377372842c | |
 | cf64816d2a | |
CMakeLists.txt

@@ -1,5 +1,5 @@
 cmake_minimum_required(VERSION 3.25)
-project(COSC-4P80-Final-Project VERSION 0.0.15)
+project(COSC-4P80-Final-Project VERSION 0.0.23)
 
 option(ENABLE_ADDRSAN "Enable the address sanitizer" OFF)
 option(ENABLE_UBSAN "Enable the ub sanitizer" OFF)
graph.py (9 changes)
@@ -20,9 +20,6 @@ def plot_stacked_graph(title, output, csv_file1, csv_file2, position, position2)
 
     ax.plot(x1, y1, label=f"{csv_file1}")
     ax.plot(x2, y2, label=f"{csv_file2}")
 
-    ax.fill_between(x1, y1, alpha=0.5)
-    ax.fill_between(x2, y2, alpha=0.5)
-
     if position < 2 ** 32:
         ax.axvline(x=position, color='red', linestyle='--')
@@ -34,13 +31,13 @@ def plot_stacked_graph(title, output, csv_file1, csv_file2, position, position2)
 
     ax.set_xlabel(x1_label)
     ax.set_ylabel(y1_label)
-    ax.legend()
+    # ax.legend()
     ax.set_title(title)
 
     plt.savefig(output)
 
 if __name__ == "__main__":
-    if len(sys.argv) != 5:
+    if len(sys.argv) != 7:
         print("Usage: python script.py <title> <output_file> <csv_file1> <csv_file2> <position_feed_forward> <position_deep>")
         sys.exit(1)
 
@@ -51,4 +48,4 @@ if __name__ == "__main__":
     position = sys.argv[5]
     position2 = sys.argv[6]
 
-    plot_stacked_graph(title, output, csv_file1, csv_file2, position, position2)
+    plot_stacked_graph(title, output, csv_file1, csv_file2, int(position), int(position2))
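The argument check in graph.py goes from 5 to 7 because the script now also receives the two vertical-marker positions, and len(sys.argv) includes the script name. The C++ side calls run_python_line_graph with exactly those six values (see the last hunk below); its body is not part of this diff, so the following is only a sketch of how such a call could be assembled, with the command format and std::system usage being assumptions:

#include <cstdlib>
#include <initializer_list>
#include <string>

// Hypothetical helper body (the real run_python_line_graph is not shown in this diff):
// passes six arguments after the script name, so graph.py sees len(sys.argv) == 7,
// and forwards the two marker positions as integers to match the new int() casts.
void run_python_line_graph(const std::string& title, const std::string& output,
                           const std::string& csv1, const std::string& csv2,
                           long position_forward, long position_deep)
{
    std::string cmd = "python3 graph.py";
    for (const std::string& arg : {title, output, csv1, csv2})
        cmd += " \"" + arg + "\"";
    cmd += " " + std::to_string(position_forward);
    cmd += " " + std::to_string(position_deep);
    std::system(cmd.c_str());
}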
@@ -312,6 +312,7 @@ namespace fp
         friend std::ofstream& operator<<(std::ofstream& file, const network_stats_t& stats)
         {
             file << stats.epoch_stats.size();
+            file << '\n';
             for (const auto& v : stats.epoch_stats)
                 file << v << "\n";
             return file;
@@ -321,6 +322,7 @@ namespace fp
         {
             blt::size_t size;
             file >> size;
+            file.ignore();
             for (blt::size_t i = 0; i < size; i++)
             {
                 stats.epoch_stats.emplace_back();
@@ -369,6 +371,7 @@ namespace fp
         friend std::ofstream& operator<<(std::ofstream& file, const network_average_stats_t& stats)
         {
             file << stats.run_stats.size();
+            file << '\n';
             for (const auto& v : stats.run_stats)
                 file << v << "---\n";
             return file;
@@ -378,6 +381,7 @@ namespace fp
         {
             blt::size_t size;
             file >> size;
+            file.ignore();
             for (blt::size_t i = 0; i < size; i++)
             {
                 stats.run_stats.emplace_back();
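All four stat stream operators gain the same pairing: a '\n' after the element count on the write side and a file.ignore() after the count on the read side. Formatted extraction with operator>> stops at the last digit and leaves the newline in the stream, which would corrupt the first line-delimited record read afterwards. A self-contained sketch of that pattern, using plain strings in place of the project's stat types (an assumption about why ignore() is needed here):

#include <cstddef>
#include <fstream>
#include <string>
#include <vector>

int main()
{
    std::vector<std::string> records{"epoch 0, loss 0.91", "epoch 1, loss 0.52"};

    // write side: the count goes on its own line, then one record per line
    {
        std::ofstream file("stats.txt");
        file << records.size();
        file << '\n';
        for (const auto& v : records)
            file << v << "\n";
    }

    // read side: operator>> stops at the digits and leaves the '\n' in the
    // stream; without ignore() the first getline would return an empty string
    std::vector<std::string> loaded;
    {
        std::ifstream file("stats.txt");
        std::size_t size;
        file >> size;
        file.ignore();
        for (std::size_t i = 0; i < size; i++)
        {
            loaded.emplace_back();
            std::getline(file, loaded.back());
        }
    }
}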
@@ -501,7 +505,7 @@ namespace fp
                 if (break_flag)
                 {
                     break_flag = false;
-                    last_epoch = epochs;
+                    last_epoch = epochs + 1;
                     break;
                 }
                 // dlib::serialize("mnist_network_" + ident + ".dat") << network;
@@ -556,20 +560,29 @@ namespace fp
            }
 
            state >> i;
+           state.ignore();
            blt::i64 load_epoch = 0;
            state >> load_epoch;
+           state.ignore();
            last_epoch = load_epoch;
            state >> stats;
+           state.ignore();
            blt::size_t test_stats_size = 0;
            state >> test_stats_size;
+           state.ignore();
            for (blt::size_t _ = 0; _ < test_stats_size; _++)
            {
                test_stats.emplace_back();
                state >> test_stats.back();
+               state.ignore();
            }
+
+           BLT_TRACE("Restoring at run %d with epoch %ld", i, load_epoch);
+           BLT_TRACE("\tRestored state size %lu", stats.run_stats.size());
+           BLT_TRACE("\tRestored test size %lu", test_stats_size);
        }
 
-       blt::i64 last_epoch_save = -1;
+       blt::i64 last_epoch_save = last_epoch;
        for (; i < runs; i++)
        {
            if (stop_flag)
@@ -587,10 +600,20 @@ namespace fp
            }
            catch (dlib::serialization_error&)
            {
-               stats += train_network(local_ident, network);
+               goto train_label;
            }
            else
-               stats += train_network(local_ident, network);
+           {
+           train_label:
+               auto stat = train_network(local_ident, network);
+               if (last_epoch_save > 0)
+               {
+                   // add in all the new epochs
+                   auto& vec = stats.run_stats.back();
+                   vec.epoch_stats.insert(vec.epoch_stats.end(), stat.epoch_stats.begin(), stat.epoch_stats.end());
+               } else
+                   stats += stat;
+           }
            last_epoch_save = last_epoch;
            last_epoch = -1;
            test_stats.push_back(test_network(network));
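When a checkpoint was restored (last_epoch_save > 0), the epochs trained after the restore are appended to the last recorded run instead of being accumulated as a new run, so the per-run epoch counts stay intact. A stripped-down sketch of that decision, with simplified stand-ins for network_stats_t and network_average_stats_t (the real operator+= is not shown in this diff and is replaced by push_back here):

#include <vector>

// simplified stand-ins for the project's stat types
struct network_stats_t { std::vector<double> epoch_stats; };
struct network_average_stats_t { std::vector<network_stats_t> run_stats; };

// Append a resumed partial run onto the checkpointed one, or record a fresh run.
void accumulate(network_average_stats_t& stats, const network_stats_t& stat, long last_epoch_save)
{
    if (last_epoch_save > 0)
    {
        // resuming: these epochs belong to the run already (partially) saved
        auto& vec = stats.run_stats.back();
        vec.epoch_stats.insert(vec.epoch_stats.end(), stat.epoch_stats.begin(), stat.epoch_stats.end());
    }
    else
    {
        // fresh run: record it as its own entry
        stats.run_stats.push_back(stat);
    }
}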
@@ -608,12 +631,27 @@ namespace fp
                std::exit(-1);
            }
 
-           state << i;
+           // user can skip this if required.
+           state << std::max(i - 1, 0);
+           state << '\n';
            state << last_epoch_save;
+           state << '\n';
            state << stats;
-           state << test_stats.size();
-           for (const auto& v : test_stats)
-               state << v;
+           state << '\n';
+           blt::i32 remove = 0;
+           if (stop_flag)
+               remove = 1;
+           state << static_cast<blt::size_t>(std::max(static_cast<blt::i64>(test_stats.size()) - remove, 0l));
+           state << '\n';
+           if (!test_stats.empty())
+           {
+               // the last test stat will be recalculated on restore. keeping it is an error.
+               for (const auto& v : blt::iterate(test_stats).take(test_stats.size() - remove))
+               {
+                   state << v;
+                   state << '\n';
+               }
+           }
 
            return {stats, average};
        }
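The checkpoint now writes newline-delimited fields and, when training is interrupted (stop_flag), drops both the in-progress run index and its test stat, since that run is re-trained and re-tested on restore. A standalone sketch of the resulting layout, with stand-in values, the averaged run stats omitted, and the file name and field meanings assumed from this hunk alone:

#include <algorithm>
#include <cstdint>
#include <fstream>
#include <vector>

int main()
{
    // stand-in checkpoint values
    int i = 3;                         // run currently in progress
    std::int64_t last_epoch_save = 12; // last completed epoch of that run
    std::vector<double> test_stats{0.95, 0.96, 0.97};
    bool stop_flag = true;             // set when training is interrupted

    std::ofstream state("state.txt");

    // the interrupted run is re-entered on restore, so save i - 1
    state << std::max(i - 1, 0) << '\n';
    state << last_epoch_save << '\n';
    // (the accumulated run stats would be written here via their operator<<)

    // the last test stat belongs to the interrupted run and is recalculated
    // on restore, so it is not written out
    std::int64_t remove = stop_flag ? 1 : 0;
    const auto count = static_cast<std::size_t>(
        std::max<std::int64_t>(static_cast<std::int64_t>(test_stats.size()) - remove, 0));
    state << count << '\n';
    for (std::size_t k = 0; k < count; k++)
        state << test_stats[k] << '\n';
}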
@@ -756,6 +794,7 @@ namespace fp
            average_epochs << average_forward_size << "," << average_deep_size << std::endl;
        }
 
+       BLT_INFO("Running python!");
        run_python_line_graph("Feed-Forward vs Deep Learning, Average Loss over Epochs", "epochs.png", path + "/forward_train_results.csv",
                              path + "/deep_train_results.csv", average_forward_size, average_deep_size);
 