first test
parent
72f3019700
commit
8963ea41ba
|
@ -0,0 +1,54 @@
|
|||
import matplotlib.pyplot as plt
|
||||
import pandas as pd
|
||||
import sys
|
||||
|
||||
def plot_stacked_graph(title, output, csv_file1, csv_file2, position, position2):
|
||||
# Read CSV files
|
||||
data1 = pd.read_csv(csv_file1, header=0)
|
||||
data2 = pd.read_csv(csv_file2, header=0)
|
||||
|
||||
# Extract column titles
|
||||
x1_label, y1_label = data1.columns[0], data1.columns[1]
|
||||
x2_label, y2_label = data2.columns[0], data2.columns[1]
|
||||
|
||||
# Extract data
|
||||
x1, y1 = data1[x1_label], data1[y1_label]
|
||||
x2, y2 = data2[x2_label], data2[y2_label]
|
||||
|
||||
# Create the plot
|
||||
fig, ax = plt.subplots()
|
||||
|
||||
ax.plot(x1, y1, label=f"{csv_file1}")
|
||||
ax.plot(x2, y2, label=f"{csv_file2}")
|
||||
|
||||
ax.fill_between(x1, y1, alpha=0.5)
|
||||
ax.fill_between(x2, y2, alpha=0.5)
|
||||
|
||||
if position < 2 ** 32:
|
||||
ax.axvline(x=position, color='red', linestyle='--')
|
||||
ax.text(position, ax.get_ylim()[1] * 0.95, f"Feed-forward average # of epochs", color='red', fontsize=10, ha='right', va='top', backgroundcolor='white')
|
||||
|
||||
if position2 < 2 ** 32:
|
||||
ax.axvline(x=position2, color='red', linestyle='--')
|
||||
ax.text(position2, ax.get_ylim()[1] * 0.95, f"Deep learning average # of epochs", color='red', fontsize=10, ha='right', va='top', backgroundcolor='white')
|
||||
|
||||
ax.set_xlabel(x1_label)
|
||||
ax.set_ylabel(y1_label)
|
||||
ax.legend()
|
||||
ax.set_title(title)
|
||||
|
||||
plt.savefig(output)
|
||||
|
||||
if __name__ == "__main__":
|
||||
if len(sys.argv) != 5:
|
||||
print("Usage: python script.py <title> <output_file> <csv_file1> <csv_file2> <position_feed_forward> <position_deep>")
|
||||
sys.exit(1)
|
||||
|
||||
csv_file1 = sys.argv[3]
|
||||
csv_file2 = sys.argv[4]
|
||||
title = sys.argv[1]
|
||||
output = sys.argv[2]
|
||||
position = sys.argv[5]
|
||||
position2 = sys.argv[6]
|
||||
|
||||
plot_stacked_graph(title, output, csv_file1, csv_file2, position, position2)
|
124
src/MNIST.cpp
124
src/MNIST.cpp
|
@ -20,7 +20,8 @@
|
|||
#include <blt/std/memory.h>
|
||||
#include <blt/std/memory_util.h>
|
||||
#include <variant>
|
||||
#include <bits/fs_ops.h>
|
||||
#include <filesystem>
|
||||
#include <iomanip>
|
||||
#include <blt/iterator/iterator.h>
|
||||
#include <blt/parse/argparse.h>
|
||||
#include <blt/std/time.h>
|
||||
|
@ -29,6 +30,18 @@
|
|||
|
||||
namespace fp
|
||||
{
|
||||
std::string binary_directory;
|
||||
std::string python_dual_stacked_graph_program;
|
||||
|
||||
void run_python_line_graph(const std::string& title, const std::string& output_file, const std::string& csv1, const std::string& csv2,
|
||||
const blt::size_t pos_forward, const blt::size_t pos_deep)
|
||||
{
|
||||
const auto command = "python3 " + python_dual_stacked_graph_program + " '" + title + "' '" + output_file + "' '" + csv1 + "' '" + csv2 + "' "
|
||||
+ std::to_string(pos_forward) + " " + std::to_string(pos_deep);
|
||||
BLT_TRACE("Running %s", command.c_str());
|
||||
std::system(command.c_str());
|
||||
}
|
||||
|
||||
class idx_file_t
|
||||
{
|
||||
template <typename T>
|
||||
|
@ -248,7 +261,7 @@ namespace fp
|
|||
|
||||
struct epoch_stats_t
|
||||
{
|
||||
batch_stats_t test_results {};
|
||||
batch_stats_t test_results{};
|
||||
double average_loss = 0;
|
||||
double learn_rate = 0;
|
||||
|
||||
|
@ -321,7 +334,7 @@ namespace fp
|
|||
return *this;
|
||||
}
|
||||
|
||||
blt::size_t average_size() const
|
||||
[[nodiscard]] blt::size_t average_size() const
|
||||
{
|
||||
blt::size_t acc = 0;
|
||||
for (const auto& [epoch_stats] : run_stats)
|
||||
|
@ -329,7 +342,7 @@ namespace fp
|
|||
return acc;
|
||||
}
|
||||
|
||||
network_stats_t average_stats() const
|
||||
[[nodiscard]] network_stats_t average_stats() const
|
||||
{
|
||||
network_stats_t stats;
|
||||
for (const auto& [epoch_stats] : run_stats)
|
||||
|
@ -380,8 +393,8 @@ namespace fp
|
|||
template <typename NetworkType>
|
||||
batch_stats_t test_network(NetworkType& network)
|
||||
{
|
||||
const idx_file_t test_images{"../problems/mnist/t10k-images.idx3-ubyte"};
|
||||
const idx_file_t test_labels{"../problems/mnist/t10k-labels.idx1-ubyte"};
|
||||
const idx_file_t test_images{binary_directory + "../problems/mnist/t10k-images.idx3-ubyte"};
|
||||
const idx_file_t test_labels{binary_directory + "../problems/mnist/t10k-labels.idx1-ubyte"};
|
||||
|
||||
const image_t test_image{test_images, test_labels};
|
||||
|
||||
|
@ -398,8 +411,8 @@ namespace fp
|
|||
template <typename NetworkType>
|
||||
network_stats_t train_network(const std::string& ident, NetworkType& network)
|
||||
{
|
||||
const idx_file_t train_images{"../problems/mnist/train-images.idx3-ubyte"};
|
||||
const idx_file_t train_labels{"../problems/mnist/train-labels.idx1-ubyte"};
|
||||
const idx_file_t train_images{binary_directory + "../problems/mnist/train-images.idx3-ubyte"};
|
||||
const idx_file_t train_labels{binary_directory + "../problems/mnist/train-labels.idx1-ubyte"};
|
||||
|
||||
const image_t train_image{train_images, train_labels};
|
||||
|
||||
|
@ -480,8 +493,9 @@ namespace fp
|
|||
return network;
|
||||
}
|
||||
|
||||
template<typename NetworkType>
|
||||
void run_network_tests(std::string path, const std::string& ident, const blt::i32 runs, const bool restore)
|
||||
template <typename NetworkType>
|
||||
std::pair<network_average_stats_t, batch_stats_t> run_network_tests(std::string path, const std::string& ident, const blt::i32 runs,
|
||||
const bool restore)
|
||||
{
|
||||
path += ("/" + ident + "/");
|
||||
std::filesystem::create_directories(path);
|
||||
|
@ -495,16 +509,28 @@ namespace fp
|
|||
auto local_ident = ident + std::to_string(i);
|
||||
NetworkType network{};
|
||||
if (restore)
|
||||
network = load_network<NetworkType>(local_ident);
|
||||
try
|
||||
{
|
||||
network = load_network<NetworkType>(local_ident);
|
||||
}
|
||||
catch (dlib::serialization_error&)
|
||||
{
|
||||
stats += train_network(local_ident, network);
|
||||
}
|
||||
else
|
||||
stats += train_network(local_ident, network);
|
||||
test_stats.push_back(test_network(network));
|
||||
}
|
||||
|
||||
batch_stats_t average;
|
||||
for (const auto& v : test_stats)
|
||||
average += v;
|
||||
average /= runs;
|
||||
|
||||
return {stats, average};
|
||||
}
|
||||
|
||||
void run_deep_learning_tests(const std::string& path, const blt::i32 runs, const bool restore)
|
||||
auto run_deep_learning_tests(const std::string& path, const blt::i32 runs, const bool restore)
|
||||
{
|
||||
using namespace dlib;
|
||||
using net_type_dl = loss_multiclass_log<
|
||||
|
@ -514,30 +540,41 @@ namespace fp
|
|||
max_pool<2, 2, 2, 2, relu<con<16, 5, 5, 1, 1,
|
||||
max_pool<2, 2, 2, 2, relu<con<6, 5, 5, 1, 1,
|
||||
input<matrix<blt::u8>>>>>>>>>>>>>>;
|
||||
run_network_tests<net_type_dl>(path, "deep_learning", runs, restore);
|
||||
return run_network_tests<net_type_dl>(path, "deep_learning", runs, restore);
|
||||
}
|
||||
|
||||
void run_feed_forward_tests(const std::string& path, const blt::i32 runs, const bool restore)
|
||||
auto run_feed_forward_tests(const std::string& path, const blt::i32 runs, const bool restore)
|
||||
{
|
||||
using namespace dlib;
|
||||
|
||||
using net_type_ff = loss_multiclass_log<
|
||||
fc<10,
|
||||
relu<fc<84,
|
||||
relu<fc<120,
|
||||
input<matrix<blt::u8>>>>>>>>;
|
||||
relu<fc<120,
|
||||
input<matrix<blt::u8>>>>>>>>;
|
||||
|
||||
run_network_tests<net_type_ff>(path, "feed_forward", runs, restore);
|
||||
return run_network_tests<net_type_ff>(path, "feed_forward", runs, restore);
|
||||
}
|
||||
|
||||
void run_mnist(const int argc, const char** argv)
|
||||
{
|
||||
binary_directory = std::filesystem::current_path();
|
||||
if (!blt::string::ends_with(binary_directory, '/'))
|
||||
binary_directory += '/';
|
||||
python_dual_stacked_graph_program = binary_directory + "../graph.py";
|
||||
BLT_TRACE(binary_directory);
|
||||
BLT_TRACE(python_dual_stacked_graph_program);
|
||||
|
||||
using namespace dlib;
|
||||
|
||||
blt::arg_parse parser{};
|
||||
parser.addArgument(
|
||||
blt::arg_builder{"-r", "--restore"}.setAction(blt::arg_action_t::STORE_TRUE).setDefault(false).setHelp("Restores from last save").build());
|
||||
blt::arg_builder{"-r", "--restore"}.setAction(blt::arg_action_t::STORE_TRUE).setDefault(false).setHelp(
|
||||
"Restores from last save").build());
|
||||
parser.addArgument(blt::arg_builder{"-t", "--runs"}.setHelp("Number of runs to perform [default: 10]").setDefault("10").build());
|
||||
parser.addArgument(
|
||||
blt::arg_builder{"-p", "--python"}.setHelp("Only run the python scripts").setAction(blt::arg_action_t::STORE_TRUE).setDefault(false).
|
||||
build());
|
||||
parser.addArgument(blt::arg_builder{"type"}.setDefault("all").setHelp("Type of network to run [ff, dl, default: all]").build());
|
||||
|
||||
auto args = parser.parse_args(argc, argv);
|
||||
|
@ -545,16 +582,57 @@ namespace fp
|
|||
const auto type = blt::string::toLowerCase(args.get<std::string>("type"));
|
||||
const auto runs = std::stoi(args.get<std::string>("runs"));
|
||||
const auto restore = args.get<bool>("restore");
|
||||
const auto path = "./" + std::to_string(blt::system::getCurrentTimeMilliseconds());
|
||||
const auto path = binary_directory + std::to_string(blt::system::getCurrentTimeMilliseconds());
|
||||
|
||||
|
||||
if (type == "all")
|
||||
{
|
||||
run_deep_learning_tests(path, runs, restore);
|
||||
run_feed_forward_tests(path, runs, restore);
|
||||
} else if (type == "ff")
|
||||
auto [deep_stats, deep_tests] = run_deep_learning_tests(path, runs, restore);
|
||||
auto [forward_stats, forward_tests] = run_feed_forward_tests(path, runs, restore);
|
||||
|
||||
auto average_forward_size = forward_stats.average_size();
|
||||
auto average_deep_size = deep_stats.average_size();
|
||||
|
||||
{
|
||||
std::ofstream test_results_f{path + "/test_results_table.txt"};
|
||||
test_results_f << "\\begin{figure}" << std::endl;
|
||||
test_results_f << "\t\\begin{tabular}{|c|c|c|c|}" << std::endl;
|
||||
test_results_f << "\t\t\\hline" << std::endl;
|
||||
test_results_f << "\t\tTest & Correct & Incorrect & Accuracy (\\%) \\\\" << std::endl;
|
||||
test_results_f << "\t\t\\hline" << std::endl;
|
||||
auto test_accuracy = forward_tests.hits / static_cast<double>(forward_tests.hits + forward_tests.misses) * 100;
|
||||
test_results_f << "\t\tFeed-Forward & " << forward_tests.hits << " & " << forward_tests.misses << " & " << std::setprecision(2) <<
|
||||
test_accuracy << "\\\\" << std::endl;
|
||||
test_accuracy = deep_tests.hits / static_cast<double>(deep_tests.hits + deep_tests.misses) * 100;
|
||||
test_results_f << "\t\tDeep Learning & " << deep_tests.hits << " & " << deep_tests.misses << " & " << std::setprecision(2) <<
|
||||
test_accuracy << "\\\\" << std::endl;
|
||||
test_results_f << "\t\\end{tabular}" << std::endl;
|
||||
test_results_f << "\\end{figure}" << std::endl;
|
||||
|
||||
const auto [forward_epoch_stats] = forward_stats.average_stats();
|
||||
std::ofstream train_forward{path + "/forward_train_results.csv"};
|
||||
train_forward << "Epoch,Loss" << std::endl;
|
||||
for (const auto& [i, v] : blt::enumerate(forward_epoch_stats))
|
||||
train_forward << i << ',' << v.average_loss << std::endl;
|
||||
|
||||
const auto [deep_epoch_stats] = deep_stats.average_stats();
|
||||
std::ofstream train_deep{path + "/deep_train_results.csv"};
|
||||
train_deep << "Epoch,Loss" << std::endl;
|
||||
for (const auto& [i, v] : blt::enumerate(deep_epoch_stats))
|
||||
train_deep << i << ',' << v.average_loss << std::endl;
|
||||
|
||||
std::ofstream average_epochs{path + "/average_epochs.txt"};
|
||||
average_epochs << average_forward_size << "," << average_deep_size << std::endl;
|
||||
}
|
||||
|
||||
run_python_line_graph("Feed-Forward vs Deep Learning, Average Loss over Epochs", "epochs.png", path + "/forward_train_results.csv",
|
||||
path + "/deep_train_results.csv", average_forward_size, average_deep_size);
|
||||
}
|
||||
else if (type == "ff")
|
||||
{
|
||||
run_feed_forward_tests(path, runs, restore);
|
||||
} else if (type == "df")
|
||||
}
|
||||
else if (type == "df")
|
||||
{
|
||||
run_deep_learning_tests(path, runs, restore);
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue