sleep
parent
28efbf94e6
commit
2dc79a8acc
|
@ -1,5 +1,5 @@
|
||||||
cmake_minimum_required(VERSION 3.25)
|
cmake_minimum_required(VERSION 3.25)
|
||||||
project(COSC-4P80-Assignment-2 VERSION 0.0.8)
|
project(COSC-4P80-Assignment-2 VERSION 0.0.9)
|
||||||
|
|
||||||
option(ENABLE_ADDRSAN "Enable the address sanitizer" OFF)
|
option(ENABLE_ADDRSAN "Enable the address sanitizer" OFF)
|
||||||
option(ENABLE_UBSAN "Enable the ub sanitizer" OFF)
|
option(ENABLE_UBSAN "Enable the ub sanitizer" OFF)
|
||||||
|
|
|
@ -60,6 +60,19 @@ namespace assign2
|
||||||
std::vector<data_t> data_points;
|
std::vector<data_t> data_points;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
struct error_data_t
|
||||||
|
{
|
||||||
|
Scalar error;
|
||||||
|
Scalar d_error;
|
||||||
|
|
||||||
|
error_data_t& operator+=(const error_data_t& e)
|
||||||
|
{
|
||||||
|
error += e.error;
|
||||||
|
d_error += e.d_error;
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
class layer_t;
|
class layer_t;
|
||||||
|
|
||||||
class network_t;
|
class network_t;
|
||||||
|
@ -197,9 +210,15 @@ namespace assign2
|
||||||
line_data.is_bad = std::stoi(*line_data_it) == 1;
|
line_data.is_bad = std::stoi(*line_data_it) == 1;
|
||||||
line_data.bins.reserve(bin_count);
|
line_data.bins.reserve(bin_count);
|
||||||
Scalar total = 0;
|
Scalar total = 0;
|
||||||
|
Scalar min = 1000;
|
||||||
|
Scalar max = 0;
|
||||||
for (++line_data_it; line_data_it != line_data_meta.end(); ++line_data_it)
|
for (++line_data_it; line_data_it != line_data_meta.end(); ++line_data_it)
|
||||||
{
|
{
|
||||||
auto v = std::stof(*line_data_it);
|
auto v = std::stof(*line_data_it);
|
||||||
|
if (v > max)
|
||||||
|
max = v;
|
||||||
|
if (v < min)
|
||||||
|
min = v;
|
||||||
total += v * v;
|
total += v * v;
|
||||||
line_data.bins.push_back(v);
|
line_data.bins.push_back(v);
|
||||||
}
|
}
|
||||||
|
@ -207,8 +226,12 @@ namespace assign2
|
||||||
// normalize vector.
|
// normalize vector.
|
||||||
total = std::sqrt(total);
|
total = std::sqrt(total);
|
||||||
//
|
//
|
||||||
for (auto& v : line_data.bins)
|
// for (auto& v : line_data.bins)
|
||||||
v /= total;
|
// {
|
||||||
|
// v /= total;
|
||||||
|
// v *= 2.71828;
|
||||||
|
// v -= 2.71828 / 2;
|
||||||
|
// }
|
||||||
//
|
//
|
||||||
// if (line_data.bins.size() == 32)
|
// if (line_data.bins.size() == 32)
|
||||||
// print_vec(line_data.bins) << std::endl;
|
// print_vec(line_data.bins) << std::endl;
|
||||||
|
|
|
@ -30,7 +30,6 @@ namespace assign2
|
||||||
{
|
{
|
||||||
|
|
||||||
inline blt::size_t layer_id_counter = 0;
|
inline blt::size_t layer_id_counter = 0;
|
||||||
inline const blt::size_t distance_between_layers = 250;
|
|
||||||
inline std::atomic_bool pause_mode = true;
|
inline std::atomic_bool pause_mode = true;
|
||||||
inline std::atomic_bool pause_flag = false;
|
inline std::atomic_bool pause_flag = false;
|
||||||
|
|
||||||
|
@ -54,6 +53,10 @@ namespace assign2
|
||||||
|
|
||||||
inline std::vector<Scalar> errors_over_time;
|
inline std::vector<Scalar> errors_over_time;
|
||||||
inline std::vector<Scalar> error_derivative_over_time;
|
inline std::vector<Scalar> error_derivative_over_time;
|
||||||
|
inline std::vector<Scalar> error_of_test;
|
||||||
|
inline std::vector<Scalar> error_of_test_derivative;
|
||||||
|
|
||||||
|
inline std::vector<Scalar> error_derivative_of_test;
|
||||||
inline std::vector<Scalar> correct_over_time;
|
inline std::vector<Scalar> correct_over_time;
|
||||||
inline std::vector<node_data> nodes;
|
inline std::vector<node_data> nodes;
|
||||||
}
|
}
|
||||||
|
|
|
@ -41,10 +41,12 @@ namespace assign2
|
||||||
explicit neuron_t(weight_view weights, weight_view dw, Scalar bias): bias(bias), dw(dw), weights(weights)
|
explicit neuron_t(weight_view weights, weight_view dw, Scalar bias): bias(bias), dw(dw), weights(weights)
|
||||||
{}
|
{}
|
||||||
|
|
||||||
Scalar activate(const Scalar* inputs, function_t* act_func)
|
Scalar activate(const std::vector<Scalar>& inputs, function_t* act_func)
|
||||||
{
|
{
|
||||||
|
BLT_ASSERT_MSG(inputs.size() == weights.size(), (std::to_string(inputs.size()) + " vs " + std::to_string(weights.size())).c_str());
|
||||||
|
|
||||||
z = bias;
|
z = bias;
|
||||||
for (auto [x, w] : blt::zip_iterator_container({inputs, inputs + weights.size()}, {weights.begin(), weights.end()}))
|
for (auto [x, w] : blt::zip_iterator_container({inputs.begin(), inputs.end()}, {weights.begin(), weights.end()}))
|
||||||
z += x * w;
|
z += x * w;
|
||||||
a = act_func->call(z);
|
a = act_func->call(z);
|
||||||
return a;
|
return a;
|
||||||
|
@ -131,11 +133,11 @@ namespace assign2
|
||||||
throw std::runtime_exception("Input vector doesn't match expected input size!");
|
throw std::runtime_exception("Input vector doesn't match expected input size!");
|
||||||
#endif
|
#endif
|
||||||
for (auto& n : neurons)
|
for (auto& n : neurons)
|
||||||
outputs.push_back(n.activate(in.data(), act_func));
|
outputs.push_back(n.activate(in, act_func));
|
||||||
return outputs;
|
return outputs;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::pair<Scalar, Scalar> back_prop(const std::vector<Scalar>& prev_layer_output,
|
error_data_t back_prop(const std::vector<Scalar>& prev_layer_output,
|
||||||
const std::variant<blt::ref<const std::vector<Scalar>>, blt::ref<const layer_t>>& data)
|
const std::variant<blt::ref<const std::vector<Scalar>>, blt::ref<const layer_t>>& data)
|
||||||
{
|
{
|
||||||
Scalar total_error = 0;
|
Scalar total_error = 0;
|
||||||
|
@ -151,6 +153,8 @@ namespace assign2
|
||||||
total_derivative += d;
|
total_derivative += d;
|
||||||
n.back_prop(act_func, prev_layer_output, d);
|
n.back_prop(act_func, prev_layer_output, d);
|
||||||
}
|
}
|
||||||
|
total_error /= static_cast<Scalar>(expected.size());
|
||||||
|
total_derivative /= static_cast<Scalar>(expected.size());
|
||||||
},
|
},
|
||||||
// interior layer
|
// interior layer
|
||||||
[this, &prev_layer_output](const layer_t& layer) {
|
[this, &prev_layer_output](const layer_t& layer) {
|
||||||
|
@ -208,9 +212,32 @@ namespace assign2
|
||||||
|
|
||||||
#ifdef BLT_USE_GRAPHICS
|
#ifdef BLT_USE_GRAPHICS
|
||||||
|
|
||||||
void render() const
|
void render(blt::gfx::batch_renderer_2d& renderer) const
|
||||||
{
|
{
|
||||||
|
const blt::size_t distance_between_layers = 30;
|
||||||
|
const float neuron_size = 30;
|
||||||
|
const float padding = -5;
|
||||||
|
for (const auto& [i, n] : blt::enumerate(neurons))
|
||||||
|
{
|
||||||
|
auto color = std::abs(n.a);
|
||||||
|
renderer.drawPointInternal(blt::make_color(0.1, 0.1, 0.1),
|
||||||
|
blt::gfx::point2d_t{static_cast<float>(i) * (neuron_size + padding) + neuron_size,
|
||||||
|
static_cast<float>(layer_id * distance_between_layers) + neuron_size,
|
||||||
|
neuron_size / 2}, 10);
|
||||||
|
auto outline_size = neuron_size + 10;
|
||||||
|
renderer.drawPointInternal(blt::make_color(color, color, color),
|
||||||
|
blt::gfx::point2d_t{static_cast<float>(i) * (neuron_size + padding) + neuron_size,
|
||||||
|
static_cast<float>(layer_id * distance_between_layers) + neuron_size,
|
||||||
|
outline_size / 2}, 0);
|
||||||
|
// const ImVec2 alignment = ImVec2(0.5f, 0.5f);
|
||||||
|
// if (i > 0)
|
||||||
|
// ImGui::SameLine();
|
||||||
|
// ImGui::PushStyleVar(ImGuiStyleVar_SelectableTextAlign, alignment);
|
||||||
|
// std::string name;
|
||||||
|
// name = std::to_string(n.a);
|
||||||
|
// ImGui::Selectable(name.c_str(), false, ImGuiSelectableFlags_None, ImVec2(80, 80));
|
||||||
|
// ImGui::PopStyleVar();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
|
|
@ -78,58 +78,96 @@ namespace assign2
|
||||||
std::vector<blt::ref<const std::vector<Scalar>>> outputs;
|
std::vector<blt::ref<const std::vector<Scalar>>> outputs;
|
||||||
outputs.emplace_back(input);
|
outputs.emplace_back(input);
|
||||||
|
|
||||||
for (auto& v : layers)
|
for (auto [i, v] : blt::enumerate(layers))
|
||||||
|
{
|
||||||
|
// auto in = outputs.back();
|
||||||
|
// std::cout << "(" << i + 1 << "/" << layers.size() << ") Going In: ";
|
||||||
|
// print_vec(in.get()) << std::endl;
|
||||||
|
// auto& out = v->call(in);
|
||||||
|
// std::cout << "(" << i + 1 << "/" << layers.size() << ") Coming out: ";
|
||||||
|
// print_vec(out) << std::endl;
|
||||||
|
//// std::cout << "(" << i << "/" << layers.size() << ") Weights: ";
|
||||||
|
//// v->weights.debug();
|
||||||
|
//// std::cout << std::endl;
|
||||||
|
// std::cout << std::endl;
|
||||||
|
//
|
||||||
|
// outputs.emplace_back(out);
|
||||||
outputs.emplace_back(v->call(outputs.back()));
|
outputs.emplace_back(v->call(outputs.back()));
|
||||||
|
}
|
||||||
|
// std::cout << std::endl;
|
||||||
|
|
||||||
return outputs.back();
|
return outputs.back();
|
||||||
}
|
}
|
||||||
|
|
||||||
std::pair<Scalar, Scalar> train_epoch(const data_file_t& example)
|
error_data_t error(const data_file_t& data)
|
||||||
{
|
{
|
||||||
Scalar total_error = 0;
|
Scalar total_error = 0;
|
||||||
Scalar total_d_error = 0;
|
Scalar total_d_error = 0;
|
||||||
for (const auto& x : example.data_points)
|
|
||||||
|
for (auto& d : data.data_points)
|
||||||
{
|
{
|
||||||
execute(x.bins);
|
std::vector<Scalar> expected{d.is_bad ? 0.0f : 1.0f, d.is_bad ? 1.0f : 0.0f};
|
||||||
std::vector<Scalar> expected{x.is_bad ? 0.0f : 1.0f, x.is_bad ? 1.0f : 0.0f};
|
|
||||||
|
auto out = execute(d.bins);
|
||||||
|
|
||||||
|
Scalar local_total_error = 0;
|
||||||
|
Scalar local_total_d_error = 0;
|
||||||
|
BLT_ASSERT(out.size() == expected.size());
|
||||||
|
for (auto [o, e] : blt::in_pairs(out, expected))
|
||||||
|
{
|
||||||
|
auto d_error = o - e;
|
||||||
|
auto error = 0.5f * (d_error * d_error);
|
||||||
|
|
||||||
|
local_total_error += error;
|
||||||
|
local_total_d_error += d_error;
|
||||||
|
}
|
||||||
|
total_error += local_total_error / 2;
|
||||||
|
total_d_error += local_total_d_error / 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
return {total_error / static_cast<Scalar>(data.data_points.size()), total_d_error / static_cast<Scalar>(data.data_points.size())};
|
||||||
|
}
|
||||||
|
|
||||||
|
error_data_t train(const data_t& data)
|
||||||
|
{
|
||||||
|
error_data_t error = {0, 0};
|
||||||
|
execute(data.bins);
|
||||||
|
std::vector<Scalar> expected{data.is_bad ? 0.0f : 1.0f, data.is_bad ? 1.0f : 0.0f};
|
||||||
|
|
||||||
for (auto [i, layer] : blt::iterate(layers).enumerate().rev())
|
for (auto [i, layer] : blt::iterate(layers).enumerate().rev())
|
||||||
{
|
{
|
||||||
if (i == layers.size() - 1)
|
if (i == layers.size() - 1)
|
||||||
{
|
{
|
||||||
auto e = layer->back_prop(layers[i - 1]->outputs, expected);
|
error += layer->back_prop(layers[i - 1]->outputs, expected);
|
||||||
// layer->update();
|
|
||||||
total_error += e.first;
|
|
||||||
total_d_error += e.second;
|
|
||||||
} else if (i == 0)
|
} else if (i == 0)
|
||||||
{
|
{
|
||||||
auto e = layer->back_prop(x.bins, *layers[i + 1]);
|
error += layer->back_prop(data.bins, *layers[i + 1]);
|
||||||
// layer->update();
|
|
||||||
total_error += e.first;
|
|
||||||
total_d_error += e.second;
|
|
||||||
} else
|
} else
|
||||||
{
|
{
|
||||||
auto e = layer->back_prop(layers[i - 1]->outputs, *layers[i + 1]);
|
error += layer->back_prop(layers[i - 1]->outputs, *layers[i + 1]);
|
||||||
// layer->update();
|
|
||||||
total_error += e.first;
|
|
||||||
total_d_error += e.second;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
for (auto& l : layers)
|
for (auto& l : layers)
|
||||||
l->update();
|
l->update();
|
||||||
|
return error;
|
||||||
}
|
}
|
||||||
// errors_over_time.push_back(total_error);
|
|
||||||
// BLT_DEBUG("Total Errors found %f, %f", total_error, total_d_error);
|
|
||||||
|
|
||||||
return {total_error, total_d_error};
|
error_data_t train_epoch(const data_file_t& example)
|
||||||
|
{
|
||||||
|
error_data_t error {0, 0};
|
||||||
|
for (const auto& x : example.data_points)
|
||||||
|
error += train(x);
|
||||||
|
error.d_error /= static_cast<Scalar>(example.data_points.size());
|
||||||
|
error.error /= static_cast<Scalar>(example.data_points.size());
|
||||||
|
return error;
|
||||||
}
|
}
|
||||||
|
|
||||||
#ifdef BLT_USE_GRAPHICS
|
#ifdef BLT_USE_GRAPHICS
|
||||||
|
|
||||||
void render() const
|
void render(blt::gfx::batch_renderer_2d& renderer) const
|
||||||
{
|
{
|
||||||
for (auto& l : layers)
|
for (auto& l : layers)
|
||||||
l->render();
|
l->render(renderer);
|
||||||
}
|
}
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
|
|
@ -1 +1 @@
|
||||||
Subproject commit 8103a3ad0ff5341c61ba07c4b1b9803b2cc740e9
|
Subproject commit 1983a6789e12cb003ae457d68836be17bc4fbeba
|
157
src/main.cpp
157
src/main.cpp
|
@ -9,11 +9,15 @@
|
||||||
#include <assign2/network.h>
|
#include <assign2/network.h>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <thread>
|
#include <thread>
|
||||||
|
#include <algorithm>
|
||||||
|
#include <mutex>
|
||||||
|
|
||||||
using namespace assign2;
|
using namespace assign2;
|
||||||
|
|
||||||
std::vector<data_file_t> data_files;
|
std::vector<data_file_t> data_files;
|
||||||
random_init randomizer{619};
|
blt::hashmap_t<blt::i32, std::vector<data_file_t>> groups;
|
||||||
|
|
||||||
|
random_init randomizer{std::random_device{}()};
|
||||||
empty_init empty;
|
empty_init empty;
|
||||||
small_init small;
|
small_init small;
|
||||||
sigmoid_function sig;
|
sigmoid_function sig;
|
||||||
|
@ -22,9 +26,9 @@ tanh_function func_tanh;
|
||||||
|
|
||||||
network_t create_network(blt::i32 input, blt::i32 hidden)
|
network_t create_network(blt::i32 input, blt::i32 hidden)
|
||||||
{
|
{
|
||||||
auto layer1 = std::make_unique<layer_t>(input, hidden * 2, &sig, randomizer, empty);
|
auto layer1 = std::make_unique<layer_t>(input, hidden, &sig, randomizer, empty);
|
||||||
auto layer2 = std::make_unique<layer_t>(hidden * 2, hidden / 2, &sig, randomizer, empty);
|
auto layer2 = std::make_unique<layer_t>(hidden, hidden * 0.7, &sig, randomizer, empty);
|
||||||
auto layer_output = std::make_unique<layer_t>(hidden / 2, 2, &sig, randomizer, empty);
|
auto layer_output = std::make_unique<layer_t>(hidden * 0.7, 2, &sig, randomizer, empty);
|
||||||
|
|
||||||
std::vector<std::unique_ptr<layer_t>> vec;
|
std::vector<std::unique_ptr<layer_t>> vec;
|
||||||
vec.push_back(std::move(layer1));
|
vec.push_back(std::move(layer1));
|
||||||
|
@ -49,18 +53,24 @@ blt::gfx::batch_renderer_2d renderer_2d(resources, global_matrices);
|
||||||
blt::gfx::first_person_camera_2d camera;
|
blt::gfx::first_person_camera_2d camera;
|
||||||
|
|
||||||
blt::hashmap_t<blt::i32, network_t> networks;
|
blt::hashmap_t<blt::i32, network_t> networks;
|
||||||
blt::hashmap_t<blt::i32, data_file_t*> file_map;
|
|
||||||
|
data_file_t current_training;
|
||||||
|
data_file_t current_testing;
|
||||||
|
std::atomic_int32_t run_epoch = -1;
|
||||||
|
std::mutex vec_lock;
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
std::unique_ptr<std::thread> network_thread;
|
std::unique_ptr<std::thread> network_thread;
|
||||||
std::atomic_bool running = true;
|
std::atomic_bool running = true;
|
||||||
std::atomic_bool run_exit = true;
|
std::atomic_bool run_exit = true;
|
||||||
std::atomic_int32_t run_epoch = -1;
|
|
||||||
std::atomic_uint64_t epochs = 0;
|
std::atomic_uint64_t epochs = 0;
|
||||||
blt::i32 time_between_runs = 0;
|
blt::i32 time_between_runs = 0;
|
||||||
blt::size_t correct_recall = 0;
|
blt::size_t correct_recall = 0;
|
||||||
blt::size_t wrong_recall = 0;
|
blt::size_t wrong_recall = 0;
|
||||||
bool run_network = false;
|
bool run_network = false;
|
||||||
|
|
||||||
void init(const blt::gfx::window_data& data)
|
void init(const blt::gfx::window_data&)
|
||||||
{
|
{
|
||||||
using namespace blt::gfx;
|
using namespace blt::gfx;
|
||||||
|
|
||||||
|
@ -79,26 +89,33 @@ void init(const blt::gfx::window_data& data)
|
||||||
int hidden = input * 1;
|
int hidden = input * 1;
|
||||||
|
|
||||||
BLT_INFO("Making network of size %d", input);
|
BLT_INFO("Making network of size %d", input);
|
||||||
|
layer_id_counter = 0;
|
||||||
networks[input] = create_network(input, hidden);
|
networks[input] = create_network(input, hidden);
|
||||||
file_map[input] = &f;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
errors_over_time.reserve(25000);
|
errors_over_time.reserve(25000);
|
||||||
error_derivative_over_time.reserve(25000);
|
error_derivative_over_time.reserve(25000);
|
||||||
correct_over_time.reserve(25000);
|
correct_over_time.reserve(25000);
|
||||||
|
error_of_test.reserve(25000);
|
||||||
|
error_of_test_derivative.reserve(25000);
|
||||||
|
|
||||||
network_thread = std::make_unique<std::thread>([]() {
|
network_thread = std::make_unique<std::thread>([]() {
|
||||||
while (running)
|
while (running)
|
||||||
{
|
{
|
||||||
if (run_epoch >= 0)
|
if (run_epoch >= 0)
|
||||||
{
|
{
|
||||||
auto error = networks.at(run_epoch).train_epoch(*file_map[run_epoch]);
|
std::scoped_lock lock(vec_lock);
|
||||||
errors_over_time.push_back(error.first);
|
auto error = networks.at(run_epoch).train_epoch(current_training);
|
||||||
error_derivative_over_time.push_back(error.second);
|
errors_over_time.push_back(error.error);
|
||||||
|
error_derivative_over_time.push_back(error.d_error);
|
||||||
|
|
||||||
|
auto error_test = networks.at(run_epoch).error(current_testing);
|
||||||
|
error_of_test.push_back(error_test.error);
|
||||||
|
error_of_test_derivative.push_back(error_test.d_error);
|
||||||
|
|
||||||
blt::size_t right = 0;
|
blt::size_t right = 0;
|
||||||
blt::size_t wrong = 0;
|
blt::size_t wrong = 0;
|
||||||
for (auto& d : file_map[run_epoch]->data_points)
|
for (auto& d : current_testing.data_points)
|
||||||
{
|
{
|
||||||
auto out = networks.at(run_epoch).execute(d.bins);
|
auto out = networks.at(run_epoch).execute(d.bins);
|
||||||
auto is_bad = is_thinks_bad(out);
|
auto is_bad = is_thinks_bad(out);
|
||||||
|
@ -200,16 +217,34 @@ void update(const blt::gfx::window_data& data)
|
||||||
errors_over_time.clear();
|
errors_over_time.clear();
|
||||||
correct_over_time.clear();
|
correct_over_time.clear();
|
||||||
error_derivative_over_time.clear();
|
error_derivative_over_time.clear();
|
||||||
|
error_of_test.clear();
|
||||||
|
error_of_test_derivative.clear();
|
||||||
run_network = false;
|
run_network = false;
|
||||||
}
|
}
|
||||||
ImGui::Separator();
|
ImGui::Separator();
|
||||||
ImGui::Text("Using network %d size %d", selected, net->first);
|
ImGui::Text("Using network %d size %d", selected, net->first);
|
||||||
static bool pause = pause_mode.load();
|
|
||||||
ImGui::Checkbox("Stepped Mode", &pause);
|
|
||||||
pause_mode = pause;
|
|
||||||
ImGui::Checkbox("Train Network", &run_network);
|
ImGui::Checkbox("Train Network", &run_network);
|
||||||
if (run_network)
|
if (run_network)
|
||||||
|
{
|
||||||
|
if (groups[net->first].size() > 1)
|
||||||
|
{
|
||||||
|
std::scoped_lock lock(vec_lock);
|
||||||
|
current_testing.data_points.clear();
|
||||||
|
current_training.data_points.clear();
|
||||||
|
|
||||||
|
current_testing.data_points.insert(current_testing.data_points.begin(), groups[net->first].front().data_points.begin(),
|
||||||
|
groups[net->first].front().data_points.end());
|
||||||
|
for (auto a : blt::iterate(groups[net->first]).skip(1))
|
||||||
|
current_training.data_points.insert(current_training.data_points.begin(), a.data_points.begin(), a.data_points.end());
|
||||||
|
} else
|
||||||
|
{
|
||||||
|
std::scoped_lock lock(vec_lock);
|
||||||
|
current_training = groups[net->first].front();
|
||||||
|
current_testing = groups[net->first].front();
|
||||||
|
}
|
||||||
|
|
||||||
run_epoch = net->first;
|
run_epoch = net->first;
|
||||||
|
}
|
||||||
ImGui::InputInt("Time Between Runs", &time_between_runs);
|
ImGui::InputInt("Time Between Runs", &time_between_runs);
|
||||||
if (time_between_runs < 0)
|
if (time_between_runs < 0)
|
||||||
time_between_runs = 0;
|
time_between_runs = 0;
|
||||||
|
@ -227,7 +262,7 @@ void update(const blt::gfx::window_data& data)
|
||||||
BLT_INFO("Test Cases:");
|
BLT_INFO("Test Cases:");
|
||||||
blt::size_t right = 0;
|
blt::size_t right = 0;
|
||||||
blt::size_t wrong = 0;
|
blt::size_t wrong = 0;
|
||||||
for (auto& d : file_map[net->first]->data_points)
|
for (auto& d : current_testing.data_points)
|
||||||
{
|
{
|
||||||
std::cout << "Good or bad? " << (d.is_bad ? "Bad" : "Good") << " :: ";
|
std::cout << "Good or bad? " << (d.is_bad ? "Bad" : "Good") << " :: ";
|
||||||
auto out = net->second.execute(d.bins);
|
auto out = net->second.execute(d.bins);
|
||||||
|
@ -261,20 +296,34 @@ void update(const blt::gfx::window_data& data)
|
||||||
static ImPlotRect lims(0, 100, 0, 1);
|
static ImPlotRect lims(0, 100, 0, 1);
|
||||||
if (ImPlot::BeginAlignedPlots("AlignedGroup"))
|
if (ImPlot::BeginAlignedPlots("AlignedGroup"))
|
||||||
{
|
{
|
||||||
plot_vector(lims, errors_over_time, "Error", "Time", "Error", [](auto v, bool b) {
|
plot_vector(lims, errors_over_time, "Global Error over epochs", "Epoch", "Error", [](auto v, bool b) {
|
||||||
float percent = 0.15;
|
float percent = 0.15;
|
||||||
if (b)
|
if (b)
|
||||||
return v < 0 ? v * (1 + percent) : v * (1 - percent);
|
return v < 0 ? v * (1 + percent) : v * (1 - percent);
|
||||||
else
|
else
|
||||||
return v < 0 ? v * (1 - percent) : v * (1 + percent);
|
return v < 0 ? v * (1 - percent) : v * (1 + percent);
|
||||||
});
|
});
|
||||||
plot_vector(lims, correct_over_time, "Correct", "Time", "Correct", [](auto v, bool b) {
|
plot_vector(lims, error_of_test, "Global Error (Tests)", "Epoch", "Error", [](auto v, bool b) {
|
||||||
|
float percent = 0.15;
|
||||||
|
if (b)
|
||||||
|
return v < 0 ? v * (1 + percent) : v * (1 - percent);
|
||||||
|
else
|
||||||
|
return v < 0 ? v * (1 - percent) : v * (1 + percent);
|
||||||
|
});
|
||||||
|
plot_vector(lims, correct_over_time, "% Correct over epochs", "Epoch", "Correct%", [](auto v, bool b) {
|
||||||
if (b)
|
if (b)
|
||||||
return v - 1;
|
return v - 1;
|
||||||
else
|
else
|
||||||
return v + 1;
|
return v + 1;
|
||||||
});
|
});
|
||||||
plot_vector(lims, error_derivative_over_time, "DError/Dw", "Time", "Error", [](auto v, bool b) {
|
plot_vector(lims, error_derivative_over_time, "DError/Dw over epochs", "Epoch", "Error", [](auto v, bool b) {
|
||||||
|
float percent = 0.05;
|
||||||
|
if (b)
|
||||||
|
return v < 0 ? v * (1 + percent) : v * (1 - percent);
|
||||||
|
else
|
||||||
|
return v < 0 ? v * (1 - percent) : v * (1 + percent);
|
||||||
|
});
|
||||||
|
plot_vector(lims, error_of_test_derivative, "DError/Dw (Test)", "Epoch", "Error", [](auto v, bool b) {
|
||||||
float percent = 0.05;
|
float percent = 0.05;
|
||||||
if (b)
|
if (b)
|
||||||
return v < 0 ? v * (1 + percent) : v * (1 - percent);
|
return v < 0 ? v * (1 + percent) : v * (1 - percent);
|
||||||
|
@ -290,7 +339,7 @@ void update(const blt::gfx::window_data& data)
|
||||||
ImGui::Begin("Hello", nullptr,
|
ImGui::Begin("Hello", nullptr,
|
||||||
ImGuiWindowFlags_AlwaysAutoResize | ImGuiWindowFlags_NoBackground | ImGuiWindowFlags_NoCollapse | ImGuiWindowFlags_NoInputs |
|
ImGuiWindowFlags_AlwaysAutoResize | ImGuiWindowFlags_NoBackground | ImGuiWindowFlags_NoCollapse | ImGuiWindowFlags_NoInputs |
|
||||||
ImGuiWindowFlags_NoTitleBar);
|
ImGuiWindowFlags_NoTitleBar);
|
||||||
net->second.render();
|
net->second.render(renderer_2d);
|
||||||
ImGui::End();
|
ImGui::End();
|
||||||
|
|
||||||
renderer_2d.render(data.width, data.height);
|
renderer_2d.render(data.width, data.height);
|
||||||
|
@ -308,7 +357,6 @@ void destroy()
|
||||||
network_thread->join();
|
network_thread->join();
|
||||||
network_thread = nullptr;
|
network_thread = nullptr;
|
||||||
networks.clear();
|
networks.clear();
|
||||||
file_map.clear();
|
|
||||||
ImPlot::DestroyContext();
|
ImPlot::DestroyContext();
|
||||||
global_matrices.cleanup();
|
global_matrices.cleanup();
|
||||||
resources.cleanup();
|
resources.cleanup();
|
||||||
|
@ -322,12 +370,73 @@ int main(int argc, const char** argv)
|
||||||
{
|
{
|
||||||
blt::arg_parse parser;
|
blt::arg_parse parser;
|
||||||
parser.addArgument(blt::arg_builder("-f", "--file").setHelp("path to the data files").setDefault("../data").build());
|
parser.addArgument(blt::arg_builder("-f", "--file").setHelp("path to the data files").setDefault("../data").build());
|
||||||
|
parser.addArgument(
|
||||||
|
blt::arg_builder("-k", "--kfold").setHelp("Number of groups to split into").setAction(blt::arg_action_t::STORE).setNArgs('?')
|
||||||
|
.setConst("3").build());
|
||||||
|
|
||||||
auto args = parser.parse_args(argc, argv);
|
auto args = parser.parse_args(argc, argv);
|
||||||
std::string data_directory = blt::string::ensure_ends_with_path_separator(args.get<std::string>("file"));
|
std::string data_directory = blt::string::ensure_ends_with_path_separator(args.get<std::string>("file"));
|
||||||
|
|
||||||
data_files = load_data_files(get_data_files(data_directory));
|
data_files = load_data_files(get_data_files(data_directory));
|
||||||
|
|
||||||
|
if (args.contains("kfold"))
|
||||||
|
{
|
||||||
|
auto kfold = std::stoul(args.get<std::string>("kfold"));
|
||||||
|
BLT_INFO("Running K-Fold-%ld", kfold);
|
||||||
|
blt::random::random_t rand(std::random_device{}());
|
||||||
|
for (auto& n : data_files)
|
||||||
|
{
|
||||||
|
std::vector<data_t> goods;
|
||||||
|
std::vector<data_t> bads;
|
||||||
|
|
||||||
|
for (auto& p : n.data_points)
|
||||||
|
{
|
||||||
|
if (p.is_bad)
|
||||||
|
bads.push_back(p);
|
||||||
|
else
|
||||||
|
goods.push_back(p);
|
||||||
|
}
|
||||||
|
|
||||||
|
// can randomize the order of good and bad inputs
|
||||||
|
std::shuffle(goods.begin(), goods.end(), rand);
|
||||||
|
std::shuffle(bads.begin(), bads.end(), rand);
|
||||||
|
|
||||||
|
auto size = static_cast<blt::i32>(n.data_points.begin()->bins.size());
|
||||||
|
groups[size] = {};
|
||||||
|
for (blt::size_t i = 0; i < kfold; i++)
|
||||||
|
groups[size].emplace_back();
|
||||||
|
|
||||||
|
// then copy proportionally into the groups, creating roughly equal groups of data.
|
||||||
|
blt::size_t select = 0;
|
||||||
|
for (auto& v : goods)
|
||||||
|
{
|
||||||
|
++select %= kfold;
|
||||||
|
groups[size][select].data_points.push_back(v);
|
||||||
|
}
|
||||||
|
|
||||||
|
// because bad motors are in a separate step they are still proportional
|
||||||
|
for (auto& v : bads)
|
||||||
|
{
|
||||||
|
++select %= kfold;
|
||||||
|
groups[size][select].data_points.push_back(v);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else
|
||||||
|
{
|
||||||
|
for (auto& n : data_files)
|
||||||
|
{
|
||||||
|
auto size = static_cast<blt::i32>(n.data_points.begin()->bins.size());
|
||||||
|
groups[size].push_back(n);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (const auto& [set, g] : groups)
|
||||||
|
{
|
||||||
|
BLT_INFO("Set %d has groups %ld", set, g.size());
|
||||||
|
for (auto [i, f] : blt::enumerate(g))
|
||||||
|
BLT_INFO("\tData file %ld contains %ld elements", i + 1, f.data_points.size());
|
||||||
|
}
|
||||||
|
|
||||||
#ifdef BLT_USE_GRAPHICS
|
#ifdef BLT_USE_GRAPHICS
|
||||||
blt::gfx::init(blt::gfx::window_data{"Freeplay Graphics", init, update, 1440, 720}.setSyncInterval(1).setMonitor(glfwGetPrimaryMonitor())
|
blt::gfx::init(blt::gfx::window_data{"Freeplay Graphics", init, update, 1440, 720}.setSyncInterval(1).setMonitor(glfwGetPrimaryMonitor())
|
||||||
.setMaximized(true));
|
.setMaximized(true));
|
||||||
|
@ -338,9 +447,9 @@ int main(int argc, const char** argv)
|
||||||
for (auto f : data_files)
|
for (auto f : data_files)
|
||||||
{
|
{
|
||||||
int input = static_cast<int>(f.data_points.begin()->bins.size());
|
int input = static_cast<int>(f.data_points.begin()->bins.size());
|
||||||
int hidden = input * 3;
|
int hidden = input;
|
||||||
|
|
||||||
if (input != 64)
|
if (input != 16)
|
||||||
continue;
|
continue;
|
||||||
|
|
||||||
BLT_INFO("-----------------");
|
BLT_INFO("-----------------");
|
||||||
|
@ -358,7 +467,6 @@ int main(int argc, const char** argv)
|
||||||
blt::size_t wrong = 0;
|
blt::size_t wrong = 0;
|
||||||
for (auto& d : f.data_points)
|
for (auto& d : f.data_points)
|
||||||
{
|
{
|
||||||
std::cout << "Good or bad? " << (d.is_bad ? "Bad" : "Good") << " :: ";
|
|
||||||
auto out = network.execute(d.bins);
|
auto out = network.execute(d.bins);
|
||||||
auto is_bad = is_thinks_bad(out);
|
auto is_bad = is_thinks_bad(out);
|
||||||
|
|
||||||
|
@ -367,6 +475,7 @@ int main(int argc, const char** argv)
|
||||||
else
|
else
|
||||||
wrong++;
|
wrong++;
|
||||||
|
|
||||||
|
std::cout << "Good or bad? " << (d.is_bad ? "Bad" : "Good") << " :: ";
|
||||||
std::cout << "NN Thinks: " << (is_bad ? "Bad" : "Good") << " || Outs: [";
|
std::cout << "NN Thinks: " << (is_bad ? "Bad" : "Good") << " || Outs: [";
|
||||||
print_vec(out) << "]" << std::endl;
|
print_vec(out) << "]" << std::endl;
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue