silly boi
parent
8392855bc5
commit
72f3019700
|
@ -1,5 +1,5 @@
|
||||||
cmake_minimum_required(VERSION 3.25)
|
cmake_minimum_required(VERSION 3.25)
|
||||||
project(COSC-4P80-Final-Project VERSION 0.0.6)
|
project(COSC-4P80-Final-Project VERSION 0.0.7)
|
||||||
|
|
||||||
option(ENABLE_ADDRSAN "Enable the address sanitizer" OFF)
|
option(ENABLE_ADDRSAN "Enable the address sanitizer" OFF)
|
||||||
option(ENABLE_UBSAN "Enable the ub sanitizer" OFF)
|
option(ENABLE_UBSAN "Enable the ub sanitizer" OFF)
|
||||||
|
|
167
src/MNIST.cpp
167
src/MNIST.cpp
|
@ -20,7 +20,10 @@
|
||||||
#include <blt/std/memory.h>
|
#include <blt/std/memory.h>
|
||||||
#include <blt/std/memory_util.h>
|
#include <blt/std/memory_util.h>
|
||||||
#include <variant>
|
#include <variant>
|
||||||
|
#include <bits/fs_ops.h>
|
||||||
#include <blt/iterator/iterator.h>
|
#include <blt/iterator/iterator.h>
|
||||||
|
#include <blt/parse/argparse.h>
|
||||||
|
#include <blt/std/time.h>
|
||||||
#include <dlib/dnn.h>
|
#include <dlib/dnn.h>
|
||||||
#include <dlib/data_io.h>
|
#include <dlib/data_io.h>
|
||||||
|
|
||||||
|
@ -211,8 +214,8 @@ namespace fp
|
||||||
|
|
||||||
struct batch_stats_t
|
struct batch_stats_t
|
||||||
{
|
{
|
||||||
blt::u64 hits;
|
blt::u64 hits = 0;
|
||||||
blt::u64 misses;
|
blt::u64 misses = 0;
|
||||||
|
|
||||||
friend std::ofstream& operator<<(std::ofstream& file, const batch_stats_t& stats)
|
friend std::ofstream& operator<<(std::ofstream& file, const batch_stats_t& stats)
|
||||||
{
|
{
|
||||||
|
@ -245,9 +248,9 @@ namespace fp
|
||||||
|
|
||||||
struct epoch_stats_t
|
struct epoch_stats_t
|
||||||
{
|
{
|
||||||
batch_stats_t test_results;
|
batch_stats_t test_results {};
|
||||||
double average_loss;
|
double average_loss = 0;
|
||||||
double learn_rate;
|
double learn_rate = 0;
|
||||||
|
|
||||||
friend std::ofstream& operator<<(std::ofstream& file, const epoch_stats_t& stats)
|
friend std::ofstream& operator<<(std::ofstream& file, const epoch_stats_t& stats)
|
||||||
{
|
{
|
||||||
|
@ -306,8 +309,42 @@ namespace fp
|
||||||
}
|
}
|
||||||
return file;
|
return file;
|
||||||
}
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct network_average_stats_t
|
||||||
|
{
|
||||||
|
std::vector<network_stats_t> run_stats;
|
||||||
|
|
||||||
|
network_average_stats_t& operator+=(const network_stats_t& stats)
|
||||||
|
{
|
||||||
|
run_stats.push_back(stats);
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
blt::size_t average_size() const
|
||||||
|
{
|
||||||
|
blt::size_t acc = 0;
|
||||||
|
for (const auto& [epoch_stats] : run_stats)
|
||||||
|
acc += epoch_stats.size();
|
||||||
|
return acc;
|
||||||
|
}
|
||||||
|
|
||||||
|
network_stats_t average_stats() const
|
||||||
|
{
|
||||||
|
network_stats_t stats;
|
||||||
|
for (const auto& [epoch_stats] : run_stats)
|
||||||
|
{
|
||||||
|
if (stats.epoch_stats.size() < epoch_stats.size())
|
||||||
|
stats.epoch_stats.resize(epoch_stats.size());
|
||||||
|
for (const auto& [i, v] : blt::enumerate(epoch_stats))
|
||||||
|
{
|
||||||
|
stats.epoch_stats[i] += v;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for (auto& v : stats.epoch_stats)
|
||||||
|
v /= run_stats.size();
|
||||||
|
return stats;
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
template <blt::i64 batch_size = 128, typename NetworkType>
|
template <blt::i64 batch_size = 128, typename NetworkType>
|
||||||
|
@ -341,7 +378,7 @@ namespace fp
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename NetworkType>
|
template <typename NetworkType>
|
||||||
void test_network(NetworkType& network)
|
batch_stats_t test_network(NetworkType& network)
|
||||||
{
|
{
|
||||||
const idx_file_t test_images{"../problems/mnist/t10k-images.idx3-ubyte"};
|
const idx_file_t test_images{"../problems/mnist/t10k-images.idx3-ubyte"};
|
||||||
const idx_file_t test_labels{"../problems/mnist/t10k-labels.idx1-ubyte"};
|
const idx_file_t test_labels{"../problems/mnist/t10k-labels.idx1-ubyte"};
|
||||||
|
@ -351,9 +388,11 @@ namespace fp
|
||||||
auto test_results = test_batch(network, test_image.get_image_data().begin(), test_image.get_image_data().end(),
|
auto test_results = test_batch(network, test_image.get_image_data().begin(), test_image.get_image_data().end(),
|
||||||
test_image.get_image_labels().begin());
|
test_image.get_image_labels().begin());
|
||||||
|
|
||||||
BLT_INFO("Testing hits: %lu", test_results.hits);
|
BLT_DEBUG("Testing hits: %lu", test_results.hits);
|
||||||
BLT_INFO("Testing misses: %lu", test_results.misses);
|
BLT_DEBUG("Testing misses: %lu", test_results.misses);
|
||||||
BLT_INFO("Testing accuracy: %lf", test_results.hits / static_cast<double>(test_results.hits + test_results.misses));
|
BLT_DEBUG("Testing accuracy: %lf", test_results.hits / static_cast<double>(test_results.hits + test_results.misses));
|
||||||
|
|
||||||
|
return test_results;
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename NetworkType>
|
template <typename NetworkType>
|
||||||
|
@ -393,11 +432,12 @@ namespace fp
|
||||||
data.begin() + end, labels.begin() + begin);
|
data.begin() + end, labels.begin() + begin);
|
||||||
}
|
}
|
||||||
epoch_pos = 0;
|
epoch_pos = 0;
|
||||||
BLT_DEBUG("Trained an epoch (%ld/%ld) learn rate %lf average loss %lf", epochs, trainer.get_max_num_epochs(),
|
BLT_TRACE("Trained an epoch (%ld/%ld) learn rate %lf average loss %lf", epochs, trainer.get_max_num_epochs(),
|
||||||
trainer.get_learning_rate(), trainer.get_average_loss());
|
trainer.get_learning_rate(), trainer.get_average_loss());
|
||||||
|
|
||||||
// sync and test
|
// sync and test
|
||||||
trainer.get_net(dlib::force_flush_to_disk::no);
|
trainer.get_net(dlib::force_flush_to_disk::no);
|
||||||
|
network.clean();
|
||||||
|
|
||||||
epoch_stats_t epoch_stats{};
|
epoch_stats_t epoch_stats{};
|
||||||
epoch_stats.test_results = test_batch(network, train_image.get_image_data().begin(), train_image.get_image_data().end(),
|
epoch_stats.test_results = test_batch(network, train_image.get_image_data().begin(), train_image.get_image_data().end(),
|
||||||
|
@ -405,10 +445,12 @@ namespace fp
|
||||||
epoch_stats.average_loss = trainer.get_average_loss();
|
epoch_stats.average_loss = trainer.get_average_loss();
|
||||||
epoch_stats.learn_rate = trainer.get_learning_rate();
|
epoch_stats.learn_rate = trainer.get_learning_rate();
|
||||||
|
|
||||||
BLT_DEBUG("\t\tHits: %lu\tMisses: %lu\tAccuracy: %lf", epoch_stats.test_results.hits, epoch_stats.test_results.misses,
|
BLT_TRACE("\t\tHits: %lu\tMisses: %lu\tAccuracy: %lf", epoch_stats.test_results.hits, epoch_stats.test_results.misses,
|
||||||
epoch_stats.test_results.hits / static_cast<double>(epoch_stats.test_results.hits + epoch_stats.test_results.misses));
|
epoch_stats.test_results.hits / static_cast<double>(epoch_stats.test_results.hits + epoch_stats.test_results.misses));
|
||||||
|
|
||||||
stats.epoch_stats.push_back(epoch_stats);
|
stats.epoch_stats.push_back(epoch_stats);
|
||||||
|
network.clean();
|
||||||
|
// dlib::serialize("mnist_network_" + ident + ".dat") << network;
|
||||||
}
|
}
|
||||||
|
|
||||||
BLT_INFO("Finished Training");
|
BLT_INFO("Finished Training");
|
||||||
|
@ -423,9 +465,9 @@ namespace fp
|
||||||
auto test_results = test_batch(network, train_image.get_image_data().begin(), train_image.get_image_data().end(),
|
auto test_results = test_batch(network, train_image.get_image_data().begin(), train_image.get_image_data().end(),
|
||||||
train_image.get_image_labels().begin());
|
train_image.get_image_labels().begin());
|
||||||
|
|
||||||
BLT_INFO("Training hits: %lu", test_results.hits);
|
BLT_DEBUG("Training hits: %lu", test_results.hits);
|
||||||
BLT_INFO("Training misses: %lu", test_results.misses);
|
BLT_DEBUG("Training misses: %lu", test_results.misses);
|
||||||
BLT_INFO("Training accuracy: %lf", test_results.hits / static_cast<double>(test_results.hits + test_results.misses));
|
BLT_DEBUG("Training accuracy: %lf", test_results.hits / static_cast<double>(test_results.hits + test_results.misses));
|
||||||
|
|
||||||
return stats;
|
return stats;
|
||||||
}
|
}
|
||||||
|
@ -438,29 +480,92 @@ namespace fp
|
||||||
return network;
|
return network;
|
||||||
}
|
}
|
||||||
|
|
||||||
void run_mnist(int argc, const char** argv)
|
template<typename NetworkType>
|
||||||
|
void run_network_tests(std::string path, const std::string& ident, const blt::i32 runs, const bool restore)
|
||||||
|
{
|
||||||
|
path += ("/" + ident + "/");
|
||||||
|
std::filesystem::create_directories(path);
|
||||||
|
std::filesystem::current_path(path);
|
||||||
|
|
||||||
|
network_average_stats_t stats{};
|
||||||
|
std::vector<batch_stats_t> test_stats;
|
||||||
|
|
||||||
|
for (blt::i32 i = 0; i < runs; i++)
|
||||||
|
{
|
||||||
|
auto local_ident = ident + std::to_string(i);
|
||||||
|
NetworkType network{};
|
||||||
|
if (restore)
|
||||||
|
network = load_network<NetworkType>(local_ident);
|
||||||
|
else
|
||||||
|
stats += train_network(local_ident, network);
|
||||||
|
test_stats.push_back(test_network(network));
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
void run_deep_learning_tests(const std::string& path, const blt::i32 runs, const bool restore)
|
||||||
|
{
|
||||||
|
using namespace dlib;
|
||||||
|
using net_type_dl = loss_multiclass_log<
|
||||||
|
fc<10,
|
||||||
|
relu<fc<84,
|
||||||
|
relu<fc<120,
|
||||||
|
max_pool<2, 2, 2, 2, relu<con<16, 5, 5, 1, 1,
|
||||||
|
max_pool<2, 2, 2, 2, relu<con<6, 5, 5, 1, 1,
|
||||||
|
input<matrix<blt::u8>>>>>>>>>>>>>>;
|
||||||
|
run_network_tests<net_type_dl>(path, "deep_learning", runs, restore);
|
||||||
|
}
|
||||||
|
|
||||||
|
void run_feed_forward_tests(const std::string& path, const blt::i32 runs, const bool restore)
|
||||||
{
|
{
|
||||||
using namespace dlib;
|
using namespace dlib;
|
||||||
|
|
||||||
// using net_type = loss_multiclass_log<
|
using net_type_ff = loss_multiclass_log<
|
||||||
// fc<10,
|
|
||||||
// relu<fc<84,
|
|
||||||
// relu<fc<120,
|
|
||||||
// max_pool<2,2,2,2,relu<con<16,5,5,1,1,
|
|
||||||
// max_pool<2,2,2,2,relu<con<6,5,5,1,1,
|
|
||||||
// input<matrix<blt::u8>>>>>>>>>>>>>>;
|
|
||||||
|
|
||||||
using net_type = loss_multiclass_log<
|
|
||||||
fc<10,
|
fc<10,
|
||||||
sig<fc<84,
|
relu<fc<84,
|
||||||
sig<fc<120,
|
relu<fc<120,
|
||||||
input<matrix<blt::u8>>>>>>>>;
|
input<matrix<blt::u8>>>>>>>>;
|
||||||
|
|
||||||
net_type test_net;
|
run_network_tests<net_type_ff>(path, "feed_forward", runs, restore);
|
||||||
const auto stats = train_network("fc_nn", test_net);
|
}
|
||||||
std::ofstream out_file{"fc_nn.csv"};
|
|
||||||
out_file << stats;
|
|
||||||
|
|
||||||
test_network(test_net);
|
void run_mnist(const int argc, const char** argv)
|
||||||
|
{
|
||||||
|
using namespace dlib;
|
||||||
|
|
||||||
|
blt::arg_parse parser{};
|
||||||
|
parser.addArgument(
|
||||||
|
blt::arg_builder{"-r", "--restore"}.setAction(blt::arg_action_t::STORE_TRUE).setDefault(false).setHelp("Restores from last save").build());
|
||||||
|
parser.addArgument(blt::arg_builder{"-t", "--runs"}.setHelp("Number of runs to perform [default: 10]").setDefault("10").build());
|
||||||
|
parser.addArgument(blt::arg_builder{"type"}.setDefault("all").setHelp("Type of network to run [ff, dl, default: all]").build());
|
||||||
|
|
||||||
|
auto args = parser.parse_args(argc, argv);
|
||||||
|
|
||||||
|
const auto type = blt::string::toLowerCase(args.get<std::string>("type"));
|
||||||
|
const auto runs = std::stoi(args.get<std::string>("runs"));
|
||||||
|
const auto restore = args.get<bool>("restore");
|
||||||
|
const auto path = "./" + std::to_string(blt::system::getCurrentTimeMilliseconds());
|
||||||
|
|
||||||
|
if (type == "all")
|
||||||
|
{
|
||||||
|
run_deep_learning_tests(path, runs, restore);
|
||||||
|
run_feed_forward_tests(path, runs, restore);
|
||||||
|
} else if (type == "ff")
|
||||||
|
{
|
||||||
|
run_feed_forward_tests(path, runs, restore);
|
||||||
|
} else if (type == "df")
|
||||||
|
{
|
||||||
|
run_deep_learning_tests(path, runs, restore);
|
||||||
|
}
|
||||||
|
|
||||||
|
// net_type_dl test_net;
|
||||||
|
// const auto stats = train_network("dl_nn", test_net);
|
||||||
|
// std::ofstream out_file{"dl_nn.csv"};
|
||||||
|
// out_file << stats;
|
||||||
|
|
||||||
|
// test_net = load_network<net_type_dl>("dl_nn");
|
||||||
|
|
||||||
|
// test_network(test_net);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue