main
Brett 2024-10-28 01:55:13 -04:00
parent 28efbf94e6
commit 2dc79a8acc
7 changed files with 272 additions and 72 deletions

View File

@ -1,5 +1,5 @@
cmake_minimum_required(VERSION 3.25) cmake_minimum_required(VERSION 3.25)
project(COSC-4P80-Assignment-2 VERSION 0.0.8) project(COSC-4P80-Assignment-2 VERSION 0.0.9)
option(ENABLE_ADDRSAN "Enable the address sanitizer" OFF) option(ENABLE_ADDRSAN "Enable the address sanitizer" OFF)
option(ENABLE_UBSAN "Enable the ub sanitizer" OFF) option(ENABLE_UBSAN "Enable the ub sanitizer" OFF)

View File

@ -60,6 +60,19 @@ namespace assign2
std::vector<data_t> data_points; std::vector<data_t> data_points;
}; };
struct error_data_t
{
Scalar error;
Scalar d_error;
error_data_t& operator+=(const error_data_t& e)
{
error += e.error;
d_error += e.d_error;
return *this;
}
};
class layer_t; class layer_t;
class network_t; class network_t;
@ -197,9 +210,15 @@ namespace assign2
line_data.is_bad = std::stoi(*line_data_it) == 1; line_data.is_bad = std::stoi(*line_data_it) == 1;
line_data.bins.reserve(bin_count); line_data.bins.reserve(bin_count);
Scalar total = 0; Scalar total = 0;
Scalar min = 1000;
Scalar max = 0;
for (++line_data_it; line_data_it != line_data_meta.end(); ++line_data_it) for (++line_data_it; line_data_it != line_data_meta.end(); ++line_data_it)
{ {
auto v = std::stof(*line_data_it); auto v = std::stof(*line_data_it);
if (v > max)
max = v;
if (v < min)
min = v;
total += v * v; total += v * v;
line_data.bins.push_back(v); line_data.bins.push_back(v);
} }
@ -207,8 +226,12 @@ namespace assign2
// normalize vector. // normalize vector.
total = std::sqrt(total); total = std::sqrt(total);
// //
for (auto& v : line_data.bins) // for (auto& v : line_data.bins)
v /= total; // {
// v /= total;
// v *= 2.71828;
// v -= 2.71828 / 2;
// }
// //
// if (line_data.bins.size() == 32) // if (line_data.bins.size() == 32)
// print_vec(line_data.bins) << std::endl; // print_vec(line_data.bins) << std::endl;

View File

@ -30,7 +30,6 @@ namespace assign2
{ {
inline blt::size_t layer_id_counter = 0; inline blt::size_t layer_id_counter = 0;
inline const blt::size_t distance_between_layers = 250;
inline std::atomic_bool pause_mode = true; inline std::atomic_bool pause_mode = true;
inline std::atomic_bool pause_flag = false; inline std::atomic_bool pause_flag = false;
@ -54,6 +53,10 @@ namespace assign2
inline std::vector<Scalar> errors_over_time; inline std::vector<Scalar> errors_over_time;
inline std::vector<Scalar> error_derivative_over_time; inline std::vector<Scalar> error_derivative_over_time;
inline std::vector<Scalar> error_of_test;
inline std::vector<Scalar> error_of_test_derivative;
inline std::vector<Scalar> error_derivative_of_test;
inline std::vector<Scalar> correct_over_time; inline std::vector<Scalar> correct_over_time;
inline std::vector<node_data> nodes; inline std::vector<node_data> nodes;
} }

View File

@ -41,10 +41,12 @@ namespace assign2
explicit neuron_t(weight_view weights, weight_view dw, Scalar bias): bias(bias), dw(dw), weights(weights) explicit neuron_t(weight_view weights, weight_view dw, Scalar bias): bias(bias), dw(dw), weights(weights)
{} {}
Scalar activate(const Scalar* inputs, function_t* act_func) Scalar activate(const std::vector<Scalar>& inputs, function_t* act_func)
{ {
BLT_ASSERT_MSG(inputs.size() == weights.size(), (std::to_string(inputs.size()) + " vs " + std::to_string(weights.size())).c_str());
z = bias; z = bias;
for (auto [x, w] : blt::zip_iterator_container({inputs, inputs + weights.size()}, {weights.begin(), weights.end()})) for (auto [x, w] : blt::zip_iterator_container({inputs.begin(), inputs.end()}, {weights.begin(), weights.end()}))
z += x * w; z += x * w;
a = act_func->call(z); a = act_func->call(z);
return a; return a;
@ -131,11 +133,11 @@ namespace assign2
throw std::runtime_exception("Input vector doesn't match expected input size!"); throw std::runtime_exception("Input vector doesn't match expected input size!");
#endif #endif
for (auto& n : neurons) for (auto& n : neurons)
outputs.push_back(n.activate(in.data(), act_func)); outputs.push_back(n.activate(in, act_func));
return outputs; return outputs;
} }
std::pair<Scalar, Scalar> back_prop(const std::vector<Scalar>& prev_layer_output, error_data_t back_prop(const std::vector<Scalar>& prev_layer_output,
const std::variant<blt::ref<const std::vector<Scalar>>, blt::ref<const layer_t>>& data) const std::variant<blt::ref<const std::vector<Scalar>>, blt::ref<const layer_t>>& data)
{ {
Scalar total_error = 0; Scalar total_error = 0;
@ -151,6 +153,8 @@ namespace assign2
total_derivative += d; total_derivative += d;
n.back_prop(act_func, prev_layer_output, d); n.back_prop(act_func, prev_layer_output, d);
} }
total_error /= static_cast<Scalar>(expected.size());
total_derivative /= static_cast<Scalar>(expected.size());
}, },
// interior layer // interior layer
[this, &prev_layer_output](const layer_t& layer) { [this, &prev_layer_output](const layer_t& layer) {
@ -208,9 +212,32 @@ namespace assign2
#ifdef BLT_USE_GRAPHICS #ifdef BLT_USE_GRAPHICS
void render() const void render(blt::gfx::batch_renderer_2d& renderer) const
{ {
const blt::size_t distance_between_layers = 30;
const float neuron_size = 30;
const float padding = -5;
for (const auto& [i, n] : blt::enumerate(neurons))
{
auto color = std::abs(n.a);
renderer.drawPointInternal(blt::make_color(0.1, 0.1, 0.1),
blt::gfx::point2d_t{static_cast<float>(i) * (neuron_size + padding) + neuron_size,
static_cast<float>(layer_id * distance_between_layers) + neuron_size,
neuron_size / 2}, 10);
auto outline_size = neuron_size + 10;
renderer.drawPointInternal(blt::make_color(color, color, color),
blt::gfx::point2d_t{static_cast<float>(i) * (neuron_size + padding) + neuron_size,
static_cast<float>(layer_id * distance_between_layers) + neuron_size,
outline_size / 2}, 0);
// const ImVec2 alignment = ImVec2(0.5f, 0.5f);
// if (i > 0)
// ImGui::SameLine();
// ImGui::PushStyleVar(ImGuiStyleVar_SelectableTextAlign, alignment);
// std::string name;
// name = std::to_string(n.a);
// ImGui::Selectable(name.c_str(), false, ImGuiSelectableFlags_None, ImVec2(80, 80));
// ImGui::PopStyleVar();
}
} }
#endif #endif

View File

@ -78,58 +78,96 @@ namespace assign2
std::vector<blt::ref<const std::vector<Scalar>>> outputs; std::vector<blt::ref<const std::vector<Scalar>>> outputs;
outputs.emplace_back(input); outputs.emplace_back(input);
for (auto& v : layers) for (auto [i, v] : blt::enumerate(layers))
{
// auto in = outputs.back();
// std::cout << "(" << i + 1 << "/" << layers.size() << ") Going In: ";
// print_vec(in.get()) << std::endl;
// auto& out = v->call(in);
// std::cout << "(" << i + 1 << "/" << layers.size() << ") Coming out: ";
// print_vec(out) << std::endl;
//// std::cout << "(" << i << "/" << layers.size() << ") Weights: ";
//// v->weights.debug();
//// std::cout << std::endl;
// std::cout << std::endl;
//
// outputs.emplace_back(out);
outputs.emplace_back(v->call(outputs.back())); outputs.emplace_back(v->call(outputs.back()));
}
// std::cout << std::endl;
return outputs.back(); return outputs.back();
} }
std::pair<Scalar, Scalar> train_epoch(const data_file_t& example) error_data_t error(const data_file_t& data)
{ {
Scalar total_error = 0; Scalar total_error = 0;
Scalar total_d_error = 0; Scalar total_d_error = 0;
for (const auto& x : example.data_points)
{
execute(x.bins);
std::vector<Scalar> expected{x.is_bad ? 0.0f : 1.0f, x.is_bad ? 1.0f : 0.0f};
for (auto [i, layer] : blt::iterate(layers).enumerate().rev())
{
if (i == layers.size() - 1)
{
auto e = layer->back_prop(layers[i - 1]->outputs, expected);
// layer->update();
total_error += e.first;
total_d_error += e.second;
} else if (i == 0)
{
auto e = layer->back_prop(x.bins, *layers[i + 1]);
// layer->update();
total_error += e.first;
total_d_error += e.second;
} else
{
auto e = layer->back_prop(layers[i - 1]->outputs, *layers[i + 1]);
// layer->update();
total_error += e.first;
total_d_error += e.second;
}
}
for (auto& l : layers)
l->update();
}
// errors_over_time.push_back(total_error);
// BLT_DEBUG("Total Errors found %f, %f", total_error, total_d_error);
return {total_error, total_d_error}; for (auto& d : data.data_points)
{
std::vector<Scalar> expected{d.is_bad ? 0.0f : 1.0f, d.is_bad ? 1.0f : 0.0f};
auto out = execute(d.bins);
Scalar local_total_error = 0;
Scalar local_total_d_error = 0;
BLT_ASSERT(out.size() == expected.size());
for (auto [o, e] : blt::in_pairs(out, expected))
{
auto d_error = o - e;
auto error = 0.5f * (d_error * d_error);
local_total_error += error;
local_total_d_error += d_error;
}
total_error += local_total_error / 2;
total_d_error += local_total_d_error / 2;
}
return {total_error / static_cast<Scalar>(data.data_points.size()), total_d_error / static_cast<Scalar>(data.data_points.size())};
}
error_data_t train(const data_t& data)
{
error_data_t error = {0, 0};
execute(data.bins);
std::vector<Scalar> expected{data.is_bad ? 0.0f : 1.0f, data.is_bad ? 1.0f : 0.0f};
for (auto [i, layer] : blt::iterate(layers).enumerate().rev())
{
if (i == layers.size() - 1)
{
error += layer->back_prop(layers[i - 1]->outputs, expected);
} else if (i == 0)
{
error += layer->back_prop(data.bins, *layers[i + 1]);
} else
{
error += layer->back_prop(layers[i - 1]->outputs, *layers[i + 1]);
}
}
for (auto& l : layers)
l->update();
return error;
}
error_data_t train_epoch(const data_file_t& example)
{
error_data_t error {0, 0};
for (const auto& x : example.data_points)
error += train(x);
error.d_error /= static_cast<Scalar>(example.data_points.size());
error.error /= static_cast<Scalar>(example.data_points.size());
return error;
} }
#ifdef BLT_USE_GRAPHICS #ifdef BLT_USE_GRAPHICS
void render() const void render(blt::gfx::batch_renderer_2d& renderer) const
{ {
for (auto& l : layers) for (auto& l : layers)
l->render(); l->render(renderer);
} }
#endif #endif

@ -1 +1 @@
Subproject commit 8103a3ad0ff5341c61ba07c4b1b9803b2cc740e9 Subproject commit 1983a6789e12cb003ae457d68836be17bc4fbeba

View File

@ -9,11 +9,15 @@
#include <assign2/network.h> #include <assign2/network.h>
#include <memory> #include <memory>
#include <thread> #include <thread>
#include <algorithm>
#include <mutex>
using namespace assign2; using namespace assign2;
std::vector<data_file_t> data_files; std::vector<data_file_t> data_files;
random_init randomizer{619}; blt::hashmap_t<blt::i32, std::vector<data_file_t>> groups;
random_init randomizer{std::random_device{}()};
empty_init empty; empty_init empty;
small_init small; small_init small;
sigmoid_function sig; sigmoid_function sig;
@ -22,9 +26,9 @@ tanh_function func_tanh;
network_t create_network(blt::i32 input, blt::i32 hidden) network_t create_network(blt::i32 input, blt::i32 hidden)
{ {
auto layer1 = std::make_unique<layer_t>(input, hidden * 2, &sig, randomizer, empty); auto layer1 = std::make_unique<layer_t>(input, hidden, &sig, randomizer, empty);
auto layer2 = std::make_unique<layer_t>(hidden * 2, hidden / 2, &sig, randomizer, empty); auto layer2 = std::make_unique<layer_t>(hidden, hidden * 0.7, &sig, randomizer, empty);
auto layer_output = std::make_unique<layer_t>(hidden / 2, 2, &sig, randomizer, empty); auto layer_output = std::make_unique<layer_t>(hidden * 0.7, 2, &sig, randomizer, empty);
std::vector<std::unique_ptr<layer_t>> vec; std::vector<std::unique_ptr<layer_t>> vec;
vec.push_back(std::move(layer1)); vec.push_back(std::move(layer1));
@ -49,18 +53,24 @@ blt::gfx::batch_renderer_2d renderer_2d(resources, global_matrices);
blt::gfx::first_person_camera_2d camera; blt::gfx::first_person_camera_2d camera;
blt::hashmap_t<blt::i32, network_t> networks; blt::hashmap_t<blt::i32, network_t> networks;
blt::hashmap_t<blt::i32, data_file_t*> file_map;
data_file_t current_training;
data_file_t current_testing;
std::atomic_int32_t run_epoch = -1;
std::mutex vec_lock;
std::unique_ptr<std::thread> network_thread; std::unique_ptr<std::thread> network_thread;
std::atomic_bool running = true; std::atomic_bool running = true;
std::atomic_bool run_exit = true; std::atomic_bool run_exit = true;
std::atomic_int32_t run_epoch = -1;
std::atomic_uint64_t epochs = 0; std::atomic_uint64_t epochs = 0;
blt::i32 time_between_runs = 0; blt::i32 time_between_runs = 0;
blt::size_t correct_recall = 0; blt::size_t correct_recall = 0;
blt::size_t wrong_recall = 0; blt::size_t wrong_recall = 0;
bool run_network = false; bool run_network = false;
void init(const blt::gfx::window_data& data) void init(const blt::gfx::window_data&)
{ {
using namespace blt::gfx; using namespace blt::gfx;
@ -79,26 +89,33 @@ void init(const blt::gfx::window_data& data)
int hidden = input * 1; int hidden = input * 1;
BLT_INFO("Making network of size %d", input); BLT_INFO("Making network of size %d", input);
layer_id_counter = 0;
networks[input] = create_network(input, hidden); networks[input] = create_network(input, hidden);
file_map[input] = &f;
} }
errors_over_time.reserve(25000); errors_over_time.reserve(25000);
error_derivative_over_time.reserve(25000); error_derivative_over_time.reserve(25000);
correct_over_time.reserve(25000); correct_over_time.reserve(25000);
error_of_test.reserve(25000);
error_of_test_derivative.reserve(25000);
network_thread = std::make_unique<std::thread>([]() { network_thread = std::make_unique<std::thread>([]() {
while (running) while (running)
{ {
if (run_epoch >= 0) if (run_epoch >= 0)
{ {
auto error = networks.at(run_epoch).train_epoch(*file_map[run_epoch]); std::scoped_lock lock(vec_lock);
errors_over_time.push_back(error.first); auto error = networks.at(run_epoch).train_epoch(current_training);
error_derivative_over_time.push_back(error.second); errors_over_time.push_back(error.error);
error_derivative_over_time.push_back(error.d_error);
auto error_test = networks.at(run_epoch).error(current_testing);
error_of_test.push_back(error_test.error);
error_of_test_derivative.push_back(error_test.d_error);
blt::size_t right = 0; blt::size_t right = 0;
blt::size_t wrong = 0; blt::size_t wrong = 0;
for (auto& d : file_map[run_epoch]->data_points) for (auto& d : current_testing.data_points)
{ {
auto out = networks.at(run_epoch).execute(d.bins); auto out = networks.at(run_epoch).execute(d.bins);
auto is_bad = is_thinks_bad(out); auto is_bad = is_thinks_bad(out);
@ -200,16 +217,34 @@ void update(const blt::gfx::window_data& data)
errors_over_time.clear(); errors_over_time.clear();
correct_over_time.clear(); correct_over_time.clear();
error_derivative_over_time.clear(); error_derivative_over_time.clear();
error_of_test.clear();
error_of_test_derivative.clear();
run_network = false; run_network = false;
} }
ImGui::Separator(); ImGui::Separator();
ImGui::Text("Using network %d size %d", selected, net->first); ImGui::Text("Using network %d size %d", selected, net->first);
static bool pause = pause_mode.load();
ImGui::Checkbox("Stepped Mode", &pause);
pause_mode = pause;
ImGui::Checkbox("Train Network", &run_network); ImGui::Checkbox("Train Network", &run_network);
if (run_network) if (run_network)
{
if (groups[net->first].size() > 1)
{
std::scoped_lock lock(vec_lock);
current_testing.data_points.clear();
current_training.data_points.clear();
current_testing.data_points.insert(current_testing.data_points.begin(), groups[net->first].front().data_points.begin(),
groups[net->first].front().data_points.end());
for (auto a : blt::iterate(groups[net->first]).skip(1))
current_training.data_points.insert(current_training.data_points.begin(), a.data_points.begin(), a.data_points.end());
} else
{
std::scoped_lock lock(vec_lock);
current_training = groups[net->first].front();
current_testing = groups[net->first].front();
}
run_epoch = net->first; run_epoch = net->first;
}
ImGui::InputInt("Time Between Runs", &time_between_runs); ImGui::InputInt("Time Between Runs", &time_between_runs);
if (time_between_runs < 0) if (time_between_runs < 0)
time_between_runs = 0; time_between_runs = 0;
@ -227,7 +262,7 @@ void update(const blt::gfx::window_data& data)
BLT_INFO("Test Cases:"); BLT_INFO("Test Cases:");
blt::size_t right = 0; blt::size_t right = 0;
blt::size_t wrong = 0; blt::size_t wrong = 0;
for (auto& d : file_map[net->first]->data_points) for (auto& d : current_testing.data_points)
{ {
std::cout << "Good or bad? " << (d.is_bad ? "Bad" : "Good") << " :: "; std::cout << "Good or bad? " << (d.is_bad ? "Bad" : "Good") << " :: ";
auto out = net->second.execute(d.bins); auto out = net->second.execute(d.bins);
@ -261,20 +296,34 @@ void update(const blt::gfx::window_data& data)
static ImPlotRect lims(0, 100, 0, 1); static ImPlotRect lims(0, 100, 0, 1);
if (ImPlot::BeginAlignedPlots("AlignedGroup")) if (ImPlot::BeginAlignedPlots("AlignedGroup"))
{ {
plot_vector(lims, errors_over_time, "Error", "Time", "Error", [](auto v, bool b) { plot_vector(lims, errors_over_time, "Global Error over epochs", "Epoch", "Error", [](auto v, bool b) {
float percent = 0.15; float percent = 0.15;
if (b) if (b)
return v < 0 ? v * (1 + percent) : v * (1 - percent); return v < 0 ? v * (1 + percent) : v * (1 - percent);
else else
return v < 0 ? v * (1 - percent) : v * (1 + percent); return v < 0 ? v * (1 - percent) : v * (1 + percent);
}); });
plot_vector(lims, correct_over_time, "Correct", "Time", "Correct", [](auto v, bool b) { plot_vector(lims, error_of_test, "Global Error (Tests)", "Epoch", "Error", [](auto v, bool b) {
float percent = 0.15;
if (b)
return v < 0 ? v * (1 + percent) : v * (1 - percent);
else
return v < 0 ? v * (1 - percent) : v * (1 + percent);
});
plot_vector(lims, correct_over_time, "% Correct over epochs", "Epoch", "Correct%", [](auto v, bool b) {
if (b) if (b)
return v - 1; return v - 1;
else else
return v + 1; return v + 1;
}); });
plot_vector(lims, error_derivative_over_time, "DError/Dw", "Time", "Error", [](auto v, bool b) { plot_vector(lims, error_derivative_over_time, "DError/Dw over epochs", "Epoch", "Error", [](auto v, bool b) {
float percent = 0.05;
if (b)
return v < 0 ? v * (1 + percent) : v * (1 - percent);
else
return v < 0 ? v * (1 - percent) : v * (1 + percent);
});
plot_vector(lims, error_of_test_derivative, "DError/Dw (Test)", "Epoch", "Error", [](auto v, bool b) {
float percent = 0.05; float percent = 0.05;
if (b) if (b)
return v < 0 ? v * (1 + percent) : v * (1 - percent); return v < 0 ? v * (1 + percent) : v * (1 - percent);
@ -290,7 +339,7 @@ void update(const blt::gfx::window_data& data)
ImGui::Begin("Hello", nullptr, ImGui::Begin("Hello", nullptr,
ImGuiWindowFlags_AlwaysAutoResize | ImGuiWindowFlags_NoBackground | ImGuiWindowFlags_NoCollapse | ImGuiWindowFlags_NoInputs | ImGuiWindowFlags_AlwaysAutoResize | ImGuiWindowFlags_NoBackground | ImGuiWindowFlags_NoCollapse | ImGuiWindowFlags_NoInputs |
ImGuiWindowFlags_NoTitleBar); ImGuiWindowFlags_NoTitleBar);
net->second.render(); net->second.render(renderer_2d);
ImGui::End(); ImGui::End();
renderer_2d.render(data.width, data.height); renderer_2d.render(data.width, data.height);
@ -308,7 +357,6 @@ void destroy()
network_thread->join(); network_thread->join();
network_thread = nullptr; network_thread = nullptr;
networks.clear(); networks.clear();
file_map.clear();
ImPlot::DestroyContext(); ImPlot::DestroyContext();
global_matrices.cleanup(); global_matrices.cleanup();
resources.cleanup(); resources.cleanup();
@ -322,11 +370,72 @@ int main(int argc, const char** argv)
{ {
blt::arg_parse parser; blt::arg_parse parser;
parser.addArgument(blt::arg_builder("-f", "--file").setHelp("path to the data files").setDefault("../data").build()); parser.addArgument(blt::arg_builder("-f", "--file").setHelp("path to the data files").setDefault("../data").build());
parser.addArgument(
blt::arg_builder("-k", "--kfold").setHelp("Number of groups to split into").setAction(blt::arg_action_t::STORE).setNArgs('?')
.setConst("3").build());
auto args = parser.parse_args(argc, argv); auto args = parser.parse_args(argc, argv);
std::string data_directory = blt::string::ensure_ends_with_path_separator(args.get<std::string>("file")); std::string data_directory = blt::string::ensure_ends_with_path_separator(args.get<std::string>("file"));
data_files = load_data_files(get_data_files(data_directory)); data_files = load_data_files(get_data_files(data_directory));
if (args.contains("kfold"))
{
auto kfold = std::stoul(args.get<std::string>("kfold"));
BLT_INFO("Running K-Fold-%ld", kfold);
blt::random::random_t rand(std::random_device{}());
for (auto& n : data_files)
{
std::vector<data_t> goods;
std::vector<data_t> bads;
for (auto& p : n.data_points)
{
if (p.is_bad)
bads.push_back(p);
else
goods.push_back(p);
}
// can randomize the order of good and bad inputs
std::shuffle(goods.begin(), goods.end(), rand);
std::shuffle(bads.begin(), bads.end(), rand);
auto size = static_cast<blt::i32>(n.data_points.begin()->bins.size());
groups[size] = {};
for (blt::size_t i = 0; i < kfold; i++)
groups[size].emplace_back();
// then copy proportionally into the groups, creating roughly equal groups of data.
blt::size_t select = 0;
for (auto& v : goods)
{
++select %= kfold;
groups[size][select].data_points.push_back(v);
}
// because bad motors are in a separate step they are still proportional
for (auto& v : bads)
{
++select %= kfold;
groups[size][select].data_points.push_back(v);
}
}
} else
{
for (auto& n : data_files)
{
auto size = static_cast<blt::i32>(n.data_points.begin()->bins.size());
groups[size].push_back(n);
}
}
for (const auto& [set, g] : groups)
{
BLT_INFO("Set %d has groups %ld", set, g.size());
for (auto [i, f] : blt::enumerate(g))
BLT_INFO("\tData file %ld contains %ld elements", i + 1, f.data_points.size());
}
#ifdef BLT_USE_GRAPHICS #ifdef BLT_USE_GRAPHICS
blt::gfx::init(blt::gfx::window_data{"Freeplay Graphics", init, update, 1440, 720}.setSyncInterval(1).setMonitor(glfwGetPrimaryMonitor()) blt::gfx::init(blt::gfx::window_data{"Freeplay Graphics", init, update, 1440, 720}.setSyncInterval(1).setMonitor(glfwGetPrimaryMonitor())
@ -338,9 +447,9 @@ int main(int argc, const char** argv)
for (auto f : data_files) for (auto f : data_files)
{ {
int input = static_cast<int>(f.data_points.begin()->bins.size()); int input = static_cast<int>(f.data_points.begin()->bins.size());
int hidden = input * 3; int hidden = input;
if (input != 64) if (input != 16)
continue; continue;
BLT_INFO("-----------------"); BLT_INFO("-----------------");
@ -358,7 +467,6 @@ int main(int argc, const char** argv)
blt::size_t wrong = 0; blt::size_t wrong = 0;
for (auto& d : f.data_points) for (auto& d : f.data_points)
{ {
std::cout << "Good or bad? " << (d.is_bad ? "Bad" : "Good") << " :: ";
auto out = network.execute(d.bins); auto out = network.execute(d.bins);
auto is_bad = is_thinks_bad(out); auto is_bad = is_thinks_bad(out);
@ -367,6 +475,7 @@ int main(int argc, const char** argv)
else else
wrong++; wrong++;
std::cout << "Good or bad? " << (d.is_bad ? "Bad" : "Good") << " :: ";
std::cout << "NN Thinks: " << (is_bad ? "Bad" : "Good") << " || Outs: ["; std::cout << "NN Thinks: " << (is_bad ? "Bad" : "Good") << " || Outs: [";
print_vec(out) << "]" << std::endl; print_vec(out) << "]" << std::endl;
} }