silly boi

main
Brett 2025-01-07 20:52:23 -05:00
parent 8392855bc5
commit 72f3019700
2 changed files with 137 additions and 32 deletions

View File

@ -1,5 +1,5 @@
cmake_minimum_required(VERSION 3.25) cmake_minimum_required(VERSION 3.25)
project(COSC-4P80-Final-Project VERSION 0.0.6) project(COSC-4P80-Final-Project VERSION 0.0.7)
option(ENABLE_ADDRSAN "Enable the address sanitizer" OFF) option(ENABLE_ADDRSAN "Enable the address sanitizer" OFF)
option(ENABLE_UBSAN "Enable the ub sanitizer" OFF) option(ENABLE_UBSAN "Enable the ub sanitizer" OFF)

View File

@ -20,7 +20,10 @@
#include <blt/std/memory.h> #include <blt/std/memory.h>
#include <blt/std/memory_util.h> #include <blt/std/memory_util.h>
#include <variant> #include <variant>
#include <bits/fs_ops.h>
#include <blt/iterator/iterator.h> #include <blt/iterator/iterator.h>
#include <blt/parse/argparse.h>
#include <blt/std/time.h>
#include <dlib/dnn.h> #include <dlib/dnn.h>
#include <dlib/data_io.h> #include <dlib/data_io.h>
@ -211,8 +214,8 @@ namespace fp
struct batch_stats_t struct batch_stats_t
{ {
blt::u64 hits; blt::u64 hits = 0;
blt::u64 misses; blt::u64 misses = 0;
friend std::ofstream& operator<<(std::ofstream& file, const batch_stats_t& stats) friend std::ofstream& operator<<(std::ofstream& file, const batch_stats_t& stats)
{ {
@ -245,9 +248,9 @@ namespace fp
struct epoch_stats_t struct epoch_stats_t
{ {
batch_stats_t test_results; batch_stats_t test_results {};
double average_loss; double average_loss = 0;
double learn_rate; double learn_rate = 0;
friend std::ofstream& operator<<(std::ofstream& file, const epoch_stats_t& stats) friend std::ofstream& operator<<(std::ofstream& file, const epoch_stats_t& stats)
{ {
@ -306,8 +309,42 @@ namespace fp
} }
return file; return file;
} }
};
struct network_average_stats_t
{
std::vector<network_stats_t> run_stats;
// Records the statistics of one finished run.
// A copy of `stats` is appended to run_stats; the object itself is returned so
// that several runs can be accumulated in a single expression.
network_average_stats_t& operator+=(const network_stats_t& stats)
{
    run_stats.emplace_back(stats);
    return *this;
}
// Returns the mean number of recorded epochs per run, truncated towards zero
// (integer division). Returns 0 when no run has been recorded yet.
//
// The previous implementation returned the *sum* of the epoch counts of all
// runs, which does not match the name of the function.
blt::size_t average_size() const
{
    if (run_stats.empty())
        return 0; // nothing recorded: avoid dividing by zero
    blt::size_t total_epochs = 0;
    for (const auto& [epoch_stats] : run_stats)
        total_epochs += epoch_stats.size();
    return total_epochs / run_stats.size();
}
// Builds the per-epoch average over all recorded runs.
//
// Entry i of the result is the sum of epoch i over every run that recorded an
// epoch i, divided by the number of such runs. Runs may record different
// numbers of epochs; the result is as long as the longest run. An empty
// run_stats yields an empty result.
//
// The previous implementation divided every epoch by run_stats.size(), which
// pulled epochs that only some of the runs reached towards zero.
//
// NOTE(review): relies on a value-initialized epoch_stats_t being an all-zero
// accumulator and on epoch_stats_t providing += and /= — confirm against the
// definitions of those operators.
network_stats_t average_stats() const
{
    network_stats_t stats;
    // contributors[i] == number of runs that recorded an epoch with index i
    std::vector<blt::size_t> contributors;
    for (const auto& [epoch_stats] : run_stats)
    {
        if (stats.epoch_stats.size() < epoch_stats.size())
        {
            stats.epoch_stats.resize(epoch_stats.size());
            contributors.resize(epoch_stats.size(), 0);
        }
        for (const auto& [i, v] : blt::enumerate(epoch_stats))
        {
            stats.epoch_stats[i] += v;
            ++contributors[i];
        }
    }
    // Every slot was created by a run that then added to it, so each count is >= 1.
    for (blt::size_t i = 0; i < stats.epoch_stats.size(); ++i)
        stats.epoch_stats[i] /= contributors[i];
    return stats;
}
}; };
template <blt::i64 batch_size = 128, typename NetworkType> template <blt::i64 batch_size = 128, typename NetworkType>
@ -341,7 +378,7 @@ namespace fp
} }
template <typename NetworkType> template <typename NetworkType>
void test_network(NetworkType& network) batch_stats_t test_network(NetworkType& network)
{ {
const idx_file_t test_images{"../problems/mnist/t10k-images.idx3-ubyte"}; const idx_file_t test_images{"../problems/mnist/t10k-images.idx3-ubyte"};
const idx_file_t test_labels{"../problems/mnist/t10k-labels.idx1-ubyte"}; const idx_file_t test_labels{"../problems/mnist/t10k-labels.idx1-ubyte"};
@ -351,9 +388,11 @@ namespace fp
auto test_results = test_batch(network, test_image.get_image_data().begin(), test_image.get_image_data().end(), auto test_results = test_batch(network, test_image.get_image_data().begin(), test_image.get_image_data().end(),
test_image.get_image_labels().begin()); test_image.get_image_labels().begin());
BLT_INFO("Testing hits: %lu", test_results.hits); BLT_DEBUG("Testing hits: %lu", test_results.hits);
BLT_INFO("Testing misses: %lu", test_results.misses); BLT_DEBUG("Testing misses: %lu", test_results.misses);
BLT_INFO("Testing accuracy: %lf", test_results.hits / static_cast<double>(test_results.hits + test_results.misses)); BLT_DEBUG("Testing accuracy: %lf", test_results.hits / static_cast<double>(test_results.hits + test_results.misses));
return test_results;
} }
template <typename NetworkType> template <typename NetworkType>
@ -393,11 +432,12 @@ namespace fp
data.begin() + end, labels.begin() + begin); data.begin() + end, labels.begin() + begin);
} }
epoch_pos = 0; epoch_pos = 0;
BLT_DEBUG("Trained an epoch (%ld/%ld) learn rate %lf average loss %lf", epochs, trainer.get_max_num_epochs(), BLT_TRACE("Trained an epoch (%ld/%ld) learn rate %lf average loss %lf", epochs, trainer.get_max_num_epochs(),
trainer.get_learning_rate(), trainer.get_average_loss()); trainer.get_learning_rate(), trainer.get_average_loss());
// sync and test // sync and test
trainer.get_net(dlib::force_flush_to_disk::no); trainer.get_net(dlib::force_flush_to_disk::no);
network.clean();
epoch_stats_t epoch_stats{}; epoch_stats_t epoch_stats{};
epoch_stats.test_results = test_batch(network, train_image.get_image_data().begin(), train_image.get_image_data().end(), epoch_stats.test_results = test_batch(network, train_image.get_image_data().begin(), train_image.get_image_data().end(),
@ -405,10 +445,12 @@ namespace fp
epoch_stats.average_loss = trainer.get_average_loss(); epoch_stats.average_loss = trainer.get_average_loss();
epoch_stats.learn_rate = trainer.get_learning_rate(); epoch_stats.learn_rate = trainer.get_learning_rate();
BLT_DEBUG("\t\tHits: %lu\tMisses: %lu\tAccuracy: %lf", epoch_stats.test_results.hits, epoch_stats.test_results.misses, BLT_TRACE("\t\tHits: %lu\tMisses: %lu\tAccuracy: %lf", epoch_stats.test_results.hits, epoch_stats.test_results.misses,
epoch_stats.test_results.hits / static_cast<double>(epoch_stats.test_results.hits + epoch_stats.test_results.misses)); epoch_stats.test_results.hits / static_cast<double>(epoch_stats.test_results.hits + epoch_stats.test_results.misses));
stats.epoch_stats.push_back(epoch_stats); stats.epoch_stats.push_back(epoch_stats);
network.clean();
// dlib::serialize("mnist_network_" + ident + ".dat") << network;
} }
BLT_INFO("Finished Training"); BLT_INFO("Finished Training");
@ -423,9 +465,9 @@ namespace fp
auto test_results = test_batch(network, train_image.get_image_data().begin(), train_image.get_image_data().end(), auto test_results = test_batch(network, train_image.get_image_data().begin(), train_image.get_image_data().end(),
train_image.get_image_labels().begin()); train_image.get_image_labels().begin());
BLT_INFO("Training hits: %lu", test_results.hits); BLT_DEBUG("Training hits: %lu", test_results.hits);
BLT_INFO("Training misses: %lu", test_results.misses); BLT_DEBUG("Training misses: %lu", test_results.misses);
BLT_INFO("Training accuracy: %lf", test_results.hits / static_cast<double>(test_results.hits + test_results.misses)); BLT_DEBUG("Training accuracy: %lf", test_results.hits / static_cast<double>(test_results.hits + test_results.misses));
return stats; return stats;
} }
@ -438,29 +480,92 @@ namespace fp
return network; return network;
} }
void run_mnist(int argc, const char** argv) template<typename NetworkType>
// Trains (or restores) `runs` networks of type NetworkType and tests each of them.
//
// path    - base output directory; "/<ident>/" is appended and the resulting
//           directory is created if it does not exist.
// ident   - name of the network family; run i uses the identifier ident + i.
// runs    - number of independent runs; values <= 0 perform no run.
// restore - when true each network is loaded with load_network() instead of
//           being trained with train_network().
//
// The process working directory is switched to the output directory while the
// runs execute and is put back afterwards. The original never restored it, so a
// second call with a relative `path` (as run_mnist does for type "all") created
// its output inside the first call's directory.
//
// NOTE(review): test_network() opens "../problems/mnist/..." relative to the
// working directory, i.e. relative to <path>/<ident>/ while this function
// runs — confirm the data set is reachable from there.
// NOTE(review): `stats` and `test_stats` are collected but neither written nor
// returned yet.
void run_network_tests(std::string path, const std::string& ident, const blt::i32 runs, const bool restore)
{
    // Restores the previous working directory on every exit path, including exceptions.
    struct working_directory_guard_t
    {
        std::filesystem::path previous;

        ~working_directory_guard_t()
        {
            // error_code overload: a destructor must not throw
            std::error_code ec;
            std::filesystem::current_path(previous, ec);
        }
    };

    path += ("/" + ident + "/");
    std::filesystem::create_directories(path);

    const working_directory_guard_t restore_cwd{std::filesystem::current_path()};
    std::filesystem::current_path(path);

    network_average_stats_t stats{};
    std::vector<batch_stats_t> test_stats;

    for (blt::i32 i = 0; i < runs; i++)
    {
        auto local_ident = ident + std::to_string(i);
        NetworkType network{};
        if (restore)
            network = load_network<NetworkType>(local_ident);
        else
            stats += train_network(local_ident, network);
        test_stats.push_back(test_network(network));
    }
}
// Runs the test series for the convolutional network under the identifier
// "deep_learning"; all arguments are forwarded unchanged to run_network_tests().
//
// The network type is assembled stage by stage, from the input towards the loss:
//   input matrix of blt::u8
//   -> 6 filters of 5x5, stride 1x1  -> relu -> 2x2 max pooling, stride 2x2
//   -> 16 filters of 5x5, stride 1x1 -> relu -> 2x2 max pooling, stride 2x2
//   -> fully connected 120 -> relu
//   -> fully connected 84  -> relu
//   -> fully connected 10  -> multiclass log loss
// NOTE(review): this layout looks like the LeNet network from dlib's DNN
// introduction example — confirm if that is the intended reference.
void run_deep_learning_tests(const std::string& path, const blt::i32 runs, const bool restore)
{
    using namespace dlib;

    using input_stage = input<matrix<blt::u8>>;
    using conv_stage_6 = max_pool<2, 2, 2, 2, relu<con<6, 5, 5, 1, 1, input_stage>>>;
    using conv_stage_16 = max_pool<2, 2, 2, 2, relu<con<16, 5, 5, 1, 1, conv_stage_6>>>;
    using dense_stage_120 = relu<fc<120, conv_stage_16>>;
    using dense_stage_84 = relu<fc<84, dense_stage_120>>;
    using net_type_dl = loss_multiclass_log<fc<10, dense_stage_84>>;

    run_network_tests<net_type_dl>(path, "deep_learning", runs, restore);
}
void run_feed_forward_tests(const std::string& path, const blt::i32 runs, const bool restore)
{ {
using namespace dlib; using namespace dlib;
// using net_type = loss_multiclass_log< using net_type_ff = loss_multiclass_log<
// fc<10,
// relu<fc<84,
// relu<fc<120,
// max_pool<2,2,2,2,relu<con<16,5,5,1,1,
// max_pool<2,2,2,2,relu<con<6,5,5,1,1,
// input<matrix<blt::u8>>>>>>>>>>>>>>;
using net_type = loss_multiclass_log<
fc<10, fc<10,
sig<fc<84, relu<fc<84,
sig<fc<120, relu<fc<120,
input<matrix<blt::u8>>>>>>>>; input<matrix<blt::u8>>>>>>>>;
net_type test_net; run_network_tests<net_type_ff>(path, "feed_forward", runs, restore);
const auto stats = train_network("fc_nn", test_net); }
std::ofstream out_file{"fc_nn.csv"};
out_file << stats;
test_network(test_net); void run_mnist(const int argc, const char** argv)
{
using namespace dlib;
blt::arg_parse parser{};
parser.addArgument(
blt::arg_builder{"-r", "--restore"}.setAction(blt::arg_action_t::STORE_TRUE).setDefault(false).setHelp("Restores from last save").build());
parser.addArgument(blt::arg_builder{"-t", "--runs"}.setHelp("Number of runs to perform [default: 10]").setDefault("10").build());
parser.addArgument(blt::arg_builder{"type"}.setDefault("all").setHelp("Type of network to run [ff, dl, default: all]").build());
auto args = parser.parse_args(argc, argv);
const auto type = blt::string::toLowerCase(args.get<std::string>("type"));
const auto runs = std::stoi(args.get<std::string>("runs"));
const auto restore = args.get<bool>("restore");
const auto path = "./" + std::to_string(blt::system::getCurrentTimeMilliseconds());
if (type == "all")
{
run_deep_learning_tests(path, runs, restore);
run_feed_forward_tests(path, runs, restore);
} else if (type == "ff")
{
run_feed_forward_tests(path, runs, restore);
} else if (type == "df")
{
run_deep_learning_tests(path, runs, restore);
}
// net_type_dl test_net;
// const auto stats = train_network("dl_nn", test_net);
// std::ofstream out_file{"dl_nn.csv"};
// out_file << stats;
// test_net = load_network<net_type_dl>("dl_nn");
// test_network(test_net);
} }
} }