Brett 2024-10-28 23:12:19 -04:00
parent 2dc79a8acc
commit 6021b7e4bb
7 changed files with 376 additions and 147 deletions

View File

@ -1,5 +1,5 @@
cmake_minimum_required(VERSION 3.25)
project(COSC-4P80-Assignment-2 VERSION 0.0.9)
project(COSC-4P80-Assignment-2 VERSION 0.0.10)
option(ENABLE_ADDRSAN "Enable the address sanitizer" OFF)
option(ENABLE_UBSAN "Enable the ub sanitizer" OFF)

View File

@ -164,7 +164,7 @@ namespace assign2
std::vector<Scalar> data;
};
std::vector<std::string> get_data_files(std::string_view path)
inline std::vector<std::string> get_data_files(std::string_view path)
{
std::vector<std::string> files;
@ -180,7 +180,7 @@ namespace assign2
return files;
}
std::vector<data_file_t> load_data_files(const std::vector<std::string>& files)
inline std::vector<data_file_t> load_data_files(const std::vector<std::string>& files)
{
std::vector<data_file_t> loaded_data;
@ -245,7 +245,31 @@ namespace assign2
return loaded_data;
}
bool is_thinks_bad(const std::vector<Scalar>& out)
inline void save_as_csv(const std::string& file, const std::vector<std::pair<std::string, std::vector<Scalar>>>& data)
{
std::ofstream stream{file};
stream << "epoch,";
for (auto [i, d] : blt::enumerate(data))
{
stream << d.first;
if (i != data.size() - 1)
stream << ',';
}
stream << '\n';
for (blt::size_t i = 0; i < data.begin()->second.size(); i++)
{
stream << i << ',';
for (auto [j, d] : blt::enumerate(data))
{
stream << d.second[i];
if (j != data.size() - 1)
stream << ',';
}
stream << '\n';
}
}
inline bool is_thinks_bad(const std::vector<Scalar>& out)
{
return out[0] < out[1];
}

View File

@ -66,6 +66,19 @@ namespace assign2
return s >= 0 ? 1 : 0;
}
};
struct bulu_function : public function_t
{
[[nodiscard]] Scalar call(const Scalar s) const final
{
return s > 0.5 ? s : -s;
}
[[nodiscard]] Scalar derivative(Scalar s) const final
{
return s >= 0 ? 1 : -1;
}
};
}
#endif //COSC_4P80_ASSIGNMENT_2_FUNCTIONS_H

View File

@ -58,7 +58,18 @@ namespace assign2
inline std::vector<Scalar> error_derivative_of_test;
inline std::vector<Scalar> correct_over_time;
inline std::vector<Scalar> correct_over_time_test;
inline std::vector<node_data> nodes;
void save_error_info(const std::string& name)
{
save_as_csv("network" + name + ".csv", {{"train_error", errors_over_time},
{"train_d_error", error_derivative_over_time},
{"test_error", error_of_test},
{"test_d_error", error_of_test_derivative},
{"correct_train", correct_over_time},
{"correct_test", correct_over_time_test}});
}
}
#endif //COSC_4P80_ASSIGNMENT_2_GLOBAL_MAGIC_H

View File

@ -34,16 +34,17 @@ namespace assign2
friend layer_t;
public:
// empty neuron for loading from a stream
explicit neuron_t(weight_view weights, weight_view dw): dw(dw), weights(weights)
{}
// explicit neuron_t(weight_view weights, weight_view dw): dw(dw), weights(weights)
// {}
// neuron with bias
explicit neuron_t(weight_view weights, weight_view dw, Scalar bias): bias(bias), dw(dw), weights(weights)
explicit neuron_t(weight_view weights, weight_view dw, weight_view momentum, Scalar bias):
bias(bias), dw(dw), weights(weights), momentum(momentum)
{}
Scalar activate(const std::vector<Scalar>& inputs, function_t* act_func)
{
BLT_ASSERT_MSG(inputs.size() == weights.size(), (std::to_string(inputs.size()) + " vs " + std::to_string(weights.size())).c_str());
BLT_ASSERT_MSG(inputs.size() == weights.size(), (std::to_string(inputs.size()) + " vs " + std::to_string(weights.size())).c_str());
z = bias;
for (auto [x, w] : blt::zip_iterator_container({inputs.begin(), inputs.end()}, {weights.begin(), weights.end()}))
@ -65,10 +66,24 @@ namespace assign2
}
}
void update()
void update(float omega, bool reset)
{
for (auto [w, d] : blt::in_pairs(weights, dw))
w += d;
// if omega is zero we are not using momentum.
if (reset || omega == 0)
{
// BLT_TRACE("Momentum Reset");
// for (auto& v : momentum)
// std::cout << v << ',';
// std::cout << std::endl;
for (auto& m : momentum)
m = 0;
} else
{
for (auto [m, d] : blt::in_pairs(momentum, dw))
m += omega * d;
}
for (auto [w, m, d] : blt::zip(weights, momentum, dw))
w += m + d;
bias += db;
}
@ -101,6 +116,7 @@ namespace assign2
float error = 0;
weight_view dw;
weight_view weights;
weight_view momentum;
};
class layer_t
@ -114,13 +130,15 @@ namespace assign2
neurons.reserve(out_size);
weights.preallocate(in_size * out_size);
weight_derivatives.preallocate(in_size * out_size);
momentum.preallocate(in_size * out_size);
for (blt::i32 i = 0; i < out_size; i++)
{
auto weight = weights.allocate_view(in_size);
auto dw = weight_derivatives.allocate_view(in_size);
auto m = momentum.allocate_view(in_size);
for (auto& v : weight)
v = w(i);
neurons.push_back(neuron_t{weight, dw, b(i)});
neurons.push_back(neuron_t{weight, dw, m, b(i)});
}
}
@ -138,7 +156,7 @@ namespace assign2
}
error_data_t back_prop(const std::vector<Scalar>& prev_layer_output,
const std::variant<blt::ref<const std::vector<Scalar>>, blt::ref<const layer_t>>& data)
const std::variant<blt::ref<const std::vector<Scalar>>, blt::ref<const layer_t>>& data)
{
Scalar total_error = 0;
Scalar total_derivative = 0;
@ -148,20 +166,23 @@ namespace assign2
for (auto [i, n] : blt::enumerate(neurons))
{
auto d = outputs[i] - expected[i];
// if (outputs[0] > 0.3 && outputs[1] > 0.3)
// d *= 10 * (outputs[0] + outputs[1]);
auto d2 = 0.5f * (d * d);
// according to the slides and the 3b1b video we sum on the squared error
// not sure why on the slides the 1/2 is moved outside the sum as the cost function is defined (1/2) * (o - y)^2
// and that the total cost for an input pattern is the sum of costs on the output
total_error += d2;
total_derivative += d;
n.back_prop(act_func, prev_layer_output, d);
}
total_error /= static_cast<Scalar>(expected.size());
total_derivative /= static_cast<Scalar>(expected.size());
},
// interior layer
[this, &prev_layer_output](const layer_t& layer) {
for (auto [i, n] : blt::enumerate(neurons))
{
Scalar w = 0;
// TODO: this is not efficient on the cache!
Scalar w = 0;
for (auto nn : layer.neurons)
w += nn.error * nn.weights[i];
n.back_prop(act_func, prev_layer_output, w);
@ -171,10 +192,10 @@ namespace assign2
return {total_error, total_derivative};
}
void update()
void update(const float* omega, bool reset)
{
for (auto& n : neurons)
n.update();
n.update(omega == nullptr ? 0 : *omega, reset);
}
template<typename OStream>
@ -247,6 +268,7 @@ namespace assign2
const blt::size_t layer_id;
weight_t weights;
weight_t weight_derivatives;
weight_t momentum;
function_t* act_func;
std::vector<neuron_t> neurons;
std::vector<Scalar> outputs;

View File

@ -79,22 +79,7 @@ namespace assign2
outputs.emplace_back(input);
for (auto [i, v] : blt::enumerate(layers))
{
// auto in = outputs.back();
// std::cout << "(" << i + 1 << "/" << layers.size() << ") Going In: ";
// print_vec(in.get()) << std::endl;
// auto& out = v->call(in);
// std::cout << "(" << i + 1 << "/" << layers.size() << ") Coming out: ";
// print_vec(out) << std::endl;
//// std::cout << "(" << i << "/" << layers.size() << ") Weights: ";
//// v->weights.debug();
//// std::cout << std::endl;
// std::cout << std::endl;
//
// outputs.emplace_back(out);
outputs.emplace_back(v->call(outputs.back()));
}
// std::cout << std::endl;
return outputs.back();
}
@ -110,25 +95,22 @@ namespace assign2
auto out = execute(d.bins);
Scalar local_total_error = 0;
Scalar local_total_d_error = 0;
BLT_ASSERT(out.size() == expected.size());
for (auto [o, e] : blt::in_pairs(out, expected))
{
auto d_error = o - e;
auto error = 0.5f * (d_error * d_error);
local_total_error += error;
local_total_d_error += d_error;
total_error += error;
total_d_error += d_error;
}
total_error += local_total_error / 2;
total_d_error += local_total_d_error / 2;
}
return {total_error / static_cast<Scalar>(data.data_points.size()), total_d_error / static_cast<Scalar>(data.data_points.size())};
}
error_data_t train(const data_t& data)
error_data_t train(const data_t& data, bool reset)
{
error_data_t error = {0, 0};
execute(data.bins);
@ -148,19 +130,34 @@ namespace assign2
}
}
for (auto& l : layers)
l->update();
l->update(m_omega, reset);
// BLT_TRACE("Error for input: %f, derr: %f", error.error, error.d_error);
return error;
}
error_data_t train_epoch(const data_file_t& example)
error_data_t train_epoch(const data_file_t& example, blt::i32 trains_per_data = 1)
{
error_data_t error {0, 0};
error_data_t error{0, 0};
for (const auto& x : example.data_points)
error += train(x);
error.d_error /= static_cast<Scalar>(example.data_points.size());
error.error /= static_cast<Scalar>(example.data_points.size());
{
for (blt::i32 i = 0; i < trains_per_data; i++)
error += train(x, reset_next);
}
// take the average cost over all the training.
error.d_error /= static_cast<Scalar>(example.data_points.size() * trains_per_data);
error.error /= static_cast<Scalar>(example.data_points.size() * trains_per_data);
// as long as we are reducing error in the same direction in overall terms, we should still build momentum.
auto last_sign = last_d_error >= 0;
auto cur_sign = error.d_error >= 0;
last_d_error = error.d_error;
reset_next = last_sign != cur_sign;
return error;
}
void with_momentum(Scalar* omega)
{
m_omega = omega;
}
#ifdef BLT_USE_GRAPHICS
@ -173,6 +170,10 @@ namespace assign2
#endif
private:
// pointer so it can be changed from the UI
Scalar* m_omega = nullptr;
Scalar last_d_error = 0;
bool reset_next = false;
std::vector<std::unique_ptr<layer_t>> layers;
};
}

View File

@ -16,26 +16,58 @@ using namespace assign2;
std::vector<data_file_t> data_files;
blt::hashmap_t<blt::i32, std::vector<data_file_t>> groups;
blt::hashmap_t<blt::i32, network_t> networks;
bool with_momentum = false;
Scalar omega = 0.001;
random_init randomizer{std::random_device{}()};
empty_init empty;
small_init small;
sigmoid_function sig;
relu_function relu;
bulu_function bulu;
tanh_function func_tanh;
network_t create_network(blt::i32 input, blt::i32 hidden)
{
auto layer1 = std::make_unique<layer_t>(input, hidden, &sig, randomizer, empty);
auto layer2 = std::make_unique<layer_t>(hidden, hidden * 0.7, &sig, randomizer, empty);
auto layer_output = std::make_unique<layer_t>(hidden * 0.7, 2, &sig, randomizer, empty);
const auto mul = 0.5;
const auto inner_mul = 0.25;
auto layer1 = std::make_unique<layer_t>(input, hidden * mul, &sig, randomizer, empty);
auto layer2 = std::make_unique<layer_t>(hidden * mul, hidden * inner_mul, &sig, randomizer, empty);
// auto layer3 = std::make_unique<layer_t>(hidden * mul, hidden * mul, &sig, randomizer, empty);
// auto layer4 = std::make_unique<layer_t>(hidden * mul, hidden * mul, &sig, randomizer, empty);
auto layer_output = std::make_unique<layer_t>(hidden * inner_mul, 2, &sig, randomizer, empty);
std::vector<std::unique_ptr<layer_t>> vec;
vec.push_back(std::move(layer1));
vec.push_back(std::move(layer2));
// vec.push_back(std::move(layer3));
// vec.push_back(std::move(layer4));
vec.push_back(std::move(layer_output));
return network_t{std::move(vec)};
network_t network{std::move(vec)};
if (with_momentum)
network.with_momentum(&omega);
return network;
}
std::pair<data_file_t, data_file_t> create_groups(blt::i32 network, blt::i32 k = 0)
{
data_file_t training;
data_file_t testing;
testing.data_points.insert(testing.data_points.begin(),
(groups[network].begin() + k)->data_points.begin(),
(groups[network].begin() + k)->data_points.end());
for (auto [i, a] : blt::enumerate(groups[network]))
{
if (i == static_cast<blt::size_t>(k))
continue;
training.data_points.insert(training.data_points.begin(), a.data_points.begin(), a.data_points.end());
}
return {training, testing};
}
#ifdef BLT_USE_GRAPHICS
@ -52,24 +84,63 @@ blt::gfx::resource_manager resources;
blt::gfx::batch_renderer_2d renderer_2d(resources, global_matrices);
blt::gfx::first_person_camera_2d camera;
blt::hashmap_t<blt::i32, network_t> networks;
data_file_t current_training;
data_file_t current_testing;
std::atomic_int32_t run_epoch = -1;
blt::i32 stop_at = -1;
blt::i32 trains_per_data = 1;
std::mutex vec_lock;
std::unique_ptr<std::thread> network_thread;
std::atomic_bool running = true;
std::atomic_bool run_exit = true;
std::atomic_uint64_t epochs = 0;
blt::i32 time_between_runs = 0;
blt::size_t correct_recall = 0;
blt::size_t wrong_recall = 0;
blt::i32 number_before_switch = 10;
bool swap_k_after = false;
blt::size_t correct_recall_train = 0;
blt::size_t correct_recall_test = 0;
blt::size_t wrong_recall_train = 0;
blt::size_t wrong_recall_test = 0;
bool run_network = false;
float init_learn = learn_rate;
float init_momentum = omega;
blt::i32 current_k = 0;
void update_current(int network)
{
if (groups[network].size() > 1)
{
std::scoped_lock lock(vec_lock);
current_testing.data_points.clear();
current_training.data_points.clear();
auto g = create_groups(network, current_k);
current_testing = g.second;
current_training = g.first;
} else
{
std::scoped_lock lock(vec_lock);
current_training = groups[network].front();
current_testing = groups[network].front();
}
}
void reset_errors(int network)
{
save_error_info(std::to_string(network));
errors_over_time.clear();
correct_over_time.clear();
correct_over_time_test.clear();
error_derivative_over_time.clear();
error_of_test.clear();
error_of_test_derivative.clear();
epochs = 0;
run_network = false;
}
void init(const blt::gfx::window_data&)
{
using namespace blt::gfx;
@ -83,51 +154,70 @@ void init(const blt::gfx::window_data&)
renderer_2d.create();
ImPlot::CreateContext();
for (auto& f : data_files)
{
int input = static_cast<int>(f.data_points.begin()->bins.size());
int hidden = input * 1;
BLT_INFO("Making network of size %d", input);
layer_id_counter = 0;
networks[input] = create_network(input, hidden);
}
errors_over_time.reserve(25000);
error_derivative_over_time.reserve(25000);
correct_over_time.reserve(25000);
error_of_test.reserve(25000);
error_of_test_derivative.reserve(25000);
update_current(networks.begin()->first);
network_thread = std::make_unique<std::thread>([]() {
while (running)
{
if (run_epoch >= 0)
{
std::scoped_lock lock(vec_lock);
auto error = networks.at(run_epoch).train_epoch(current_training);
errors_over_time.push_back(error.error);
error_derivative_over_time.push_back(error.d_error);
auto error_test = networks.at(run_epoch).error(current_testing);
error_of_test.push_back(error_test.error);
error_of_test_derivative.push_back(error_test.d_error);
blt::size_t right = 0;
blt::size_t wrong = 0;
for (auto& d : current_testing.data_points)
if (swap_k_after && epochs % number_before_switch == static_cast<blt::size_t>(number_before_switch - 1))
{
auto out = networks.at(run_epoch).execute(d.bins);
auto is_bad = is_thinks_bad(out);
if ((is_bad && d.is_bad) || (!is_bad && !d.is_bad))
right++;
else
wrong++;
current_k++;
current_k %= static_cast<blt::i32>(groups[run_epoch].size());
update_current(run_epoch);
}
correct_recall = right;
wrong_recall = wrong;
correct_over_time.push_back(static_cast<Scalar>(right) / static_cast<Scalar>(right + wrong) * 100);
blt::size_t right_t = 0;
blt::size_t wrong_t = 0;
blt::size_t right_a = 0;
blt::size_t wrong_a = 0;
{
std::scoped_lock lock(vec_lock);
auto error = networks.at(run_epoch).train_epoch(current_training, trains_per_data);
errors_over_time.push_back(error.error);
error_derivative_over_time.push_back(error.d_error);
auto error_test = networks.at(run_epoch).error(current_testing);
error_of_test.push_back(error_test.error);
error_of_test_derivative.push_back(error_test.d_error);
for (auto& d : current_testing.data_points)
{
auto out = networks.at(run_epoch).execute(d.bins);
auto is_bad = is_thinks_bad(out);
if ((is_bad && d.is_bad) || (!is_bad && !d.is_bad))
right_t++;
else
wrong_t++;
}
for (auto& d : current_training.data_points)
{
auto out = networks.at(run_epoch).execute(d.bins);
auto is_bad = is_thinks_bad(out);
if ((is_bad && d.is_bad) || (!is_bad && !d.is_bad))
right_a++;
else
wrong_a++;
}
}
correct_recall_test = right_t;
correct_recall_train = right_a;
wrong_recall_test = wrong_t;
wrong_recall_train = wrong_a;
correct_over_time
.push_back(static_cast<Scalar>(correct_recall_train) / static_cast<Scalar>(correct_recall_train + wrong_recall_train) * 100);
correct_over_time_test
.push_back(static_cast<Scalar>(correct_recall_test) / static_cast<Scalar>(correct_recall_test + wrong_recall_test) * 100);
auto error = errors_over_time.back();
// error = std::sqrt(error * error + error + 0.01f);
error = std::max(0.0f, std::min(1.0f, error));
learn_rate = error * init_learn;
omega = error * init_momentum;
epochs++;
run_epoch = -1;
@ -179,6 +269,18 @@ void plot_vector(ImPlotRect& lims, const std::vector<Scalar>& v, std::string nam
}
}
static void HelpMarker(const char* desc)
{
ImGui::TextDisabled("(?)");
if (ImGui::BeginItemTooltip())
{
ImGui::PushTextWrapPos(ImGui::GetFontSize() * 35.0f);
ImGui::TextUnformatted(desc);
ImGui::PopTextWrapPos();
ImGui::EndTooltip();
}
}
void update(const blt::gfx::window_data& data)
{
global_matrices.update_perspectives(data.width, data.height, 90, 0.1, 2000);
@ -207,55 +309,44 @@ void update(const blt::gfx::window_data& data)
lists.push_back(ptr);
}
}
static int selected = 1;
static int selected = 0;
for (int i = 0; i < selected; i++)
net++;
ImGui::Separator();
ImGui::Text("Select Network Size");
if (ImGui::ListBox("", &selected, lists.data(), static_cast<int>(lists.size()), 4))
{
errors_over_time.clear();
correct_over_time.clear();
error_derivative_over_time.clear();
error_of_test.clear();
error_of_test_derivative.clear();
run_network = false;
reset_errors(net->first);
net = networks.begin();
for (int i = 0; i < selected; i++)
net++;
update_current(net->first);
}
ImGui::Separator();
ImGui::Text("Using network %d size %d", selected, net->first);
ImGui::Checkbox("Train Network", &run_network);
ImGui::InputInt("Stop At", &stop_at);
if (static_cast<blt::i32>(epochs) >= stop_at && stop_at > 0)
run_network = false;
if (run_network)
{
if (groups[net->first].size() > 1)
{
std::scoped_lock lock(vec_lock);
current_testing.data_points.clear();
current_training.data_points.clear();
current_testing.data_points.insert(current_testing.data_points.begin(), groups[net->first].front().data_points.begin(),
groups[net->first].front().data_points.end());
for (auto a : blt::iterate(groups[net->first]).skip(1))
current_training.data_points.insert(current_training.data_points.begin(), a.data_points.begin(), a.data_points.end());
} else
{
std::scoped_lock lock(vec_lock);
current_training = groups[net->first].front();
current_testing = groups[net->first].front();
}
// update_current(net->first);
run_epoch = net->first;
}
ImGui::InputInt("Time Between Runs", &time_between_runs);
if (time_between_runs < 0)
time_between_runs = 0;
std::string str = std::to_string(correct_recall) + "/" + std::to_string(wrong_recall + correct_recall);
std::string str = std::to_string(correct_recall_test) + "/" + std::to_string(wrong_recall_test + correct_recall_test);
ImGui::ProgressBar(
(wrong_recall + correct_recall != 0) ? static_cast<float>(correct_recall) / static_cast<float>(wrong_recall + correct_recall) : 0,
(wrong_recall_test + correct_recall_test != 0) ? static_cast<float>(correct_recall_test) /
static_cast<float>(wrong_recall_test + correct_recall_test) : 0,
ImVec2(0, 0), str.c_str());
ImGui::Separator();
str = std::to_string(correct_recall_train) + "/" + std::to_string(wrong_recall_train + correct_recall_train);
ImGui::ProgressBar(
(wrong_recall_train + correct_recall_train != 0) ? static_cast<float>(correct_recall_train) /
static_cast<float>(wrong_recall_train + correct_recall_train) : 0,
ImVec2(0, 0), str.c_str());
// const float max_learn = 100000;
// static float learn = max_learn;
// ImGui::SliderFloat("Learn Rate", &learn, 1, max_learn, "", ImGuiSliderFlags_Logarithmic);
// learn_rate = learn / (max_learn * 1000);
ImGui::Text("Learn Rate %.9f", learn_rate);
if (ImGui::Button("Print Current"))
{
@ -278,6 +369,35 @@ void update(const blt::gfx::window_data& data)
}
BLT_INFO("NN got %ld right and %ld wrong (%%%lf)", right, wrong, static_cast<double>(right) / static_cast<double>(right + wrong) * 100);
}
if (ImGui::SliderInt("K For Testing", &current_k, 0, static_cast<int>(groups[net->first].size() - 1)))
update_current(net->first);
ImGui::Checkbox("Auto-swap K", &swap_k_after);
if (swap_k_after)
{
ImGui::InputInt("Number of epochs before switch", &number_before_switch);
if (number_before_switch < 1)
number_before_switch = 1;
}
ImGui::Checkbox("Momentum", &with_momentum);
ImGui::SameLine();
HelpMarker("You might want to reset the network after changing this");
if (with_momentum)
ImGui::SliderFloat("##MomentumSlider", &omega, 0, 0.1, "%.8f", ImGuiSliderFlags_Logarithmic);
ImGui::InputInt("Trains per Epoch", &trains_per_data);
ImGui::SameLine();
HelpMarker("Number of times to run back-prop on a piece of data before moving on to the next");
if (trains_per_data < 1)
trains_per_data = 1;
ImGui::Separator();
if (ImGui::Button("Reset Network"))
{
reset_errors(net->first);
layer_id_counter = 0;
networks[net->first] = create_network(net->first, net->first);
}
ImGui::Separator();
if (ImGui::Button("Save current to CSV"))
save_error_info(std::to_string(net->first) + "_" + std::to_string(current_k));
}
ImGui::End();
@ -291,12 +411,10 @@ void update(const blt::gfx::window_data& data)
x_points.push_back(i);
}
auto domain = static_cast<int>(errors_over_time.size());
blt::i32 history = std::min(100, domain);
static ImPlotRect lims(0, 100, 0, 1);
if (ImPlot::BeginAlignedPlots("AlignedGroup"))
static ImPlotRect lims(0, 500, 0, 1);
if (ImPlot::BeginSubplots("##LinkedGroup", 3, 2, ImVec2(-1, -1)))
{
plot_vector(lims, errors_over_time, "Global Error over epochs", "Epoch", "Error", [](auto v, bool b) {
plot_vector(lims, errors_over_time, "Global Error (Training)", "Epoch", "Error", [](auto v, bool b) {
float percent = 0.15;
if (b)
return v < 0 ? v * (1 + percent) : v * (1 - percent);
@ -310,27 +428,33 @@ void update(const blt::gfx::window_data& data)
else
return v < 0 ? v * (1 - percent) : v * (1 + percent);
});
plot_vector(lims, correct_over_time, "% Correct over epochs", "Epoch", "Correct%", [](auto v, bool b) {
plot_vector(lims, error_derivative_over_time, "DError/Dw (Training)", "Epoch", "DError", [](auto v, bool b) {
float percent = 0.05;
if (b)
return v < 0 ? v * (1 + percent) : v * (1 - percent);
else
return v < 0 ? v * (1 - percent) : v * (1 + percent);
});
plot_vector(lims, error_of_test_derivative, "DError/Dw (Test)", "Epoch", "DError", [](auto v, bool b) {
float percent = 0.05;
if (b)
return v < 0 ? v * (1 + percent) : v * (1 - percent);
else
return v < 0 ? v * (1 - percent) : v * (1 + percent);
});
plot_vector(lims, correct_over_time, "% Correct (Training)", "Epoch", "Correct%", [](auto v, bool b) {
if (b)
return v - 1;
else
return v + 1;
});
plot_vector(lims, error_derivative_over_time, "DError/Dw over epochs", "Epoch", "Error", [](auto v, bool b) {
float percent = 0.05;
plot_vector(lims, correct_over_time_test, "% Correct (Test)", "Epoch", "Correct%", [](auto v, bool b) {
if (b)
return v < 0 ? v * (1 + percent) : v * (1 - percent);
return v - 1;
else
return v < 0 ? v * (1 - percent) : v * (1 + percent);
return v + 1;
});
plot_vector(lims, error_of_test_derivative, "DError/Dw (Test)", "Epoch", "Error", [](auto v, bool b) {
float percent = 0.05;
if (b)
return v < 0 ? v * (1 + percent) : v * (1 - percent);
else
return v < 0 ? v * (1 - percent) : v * (1 + percent);
});
ImPlot::EndAlignedPlots();
ImPlot::EndSubplots();
}
}
ImGui::End();
@ -347,6 +471,7 @@ void update(const blt::gfx::window_data& data)
void destroy()
{
save_error_info(std::to_string(run_epoch));
running = false;
while (run_exit)
{
@ -369,12 +494,20 @@ void destroy()
int main(int argc, const char** argv)
{
blt::arg_parse parser;
parser.addArgument(blt::arg_builder("-f", "--file").setHelp("path to the data files").setDefault("../data").build());
parser.addArgument(blt::arg_builder("-f", "--file").setHelp("Path to the data files").setDefault("../data").setMetavar("FOLDER").build());
parser.addArgument(
blt::arg_builder("-k", "--kfold").setHelp("Number of groups to split into").setAction(blt::arg_action_t::STORE).setNArgs('?')
.setConst("3").build());
blt::arg_builder("-k", "--kfold").setHelp("Number of groups to split into [Defaults to 3 if no number is provided]")
.setAction(blt::arg_action_t::STORE).setNArgs('?').setConst("3").setMetavar("GROUPS").build());
parser.addArgument(blt::arg_builder("-m", "--momentum").setHelp("Use momentum in weight calculations").setAction(blt::arg_action_t::STORE_TRUE)
.setDefault(false).build());
auto args = parser.parse_args(argc, argv);
if (args.get<bool>("momentum"))
{
BLT_INFO("Using Momentum");
with_momentum = true;
}
std::string data_directory = blt::string::ensure_ends_with_path_separator(args.get<std::string>("file"));
data_files = load_data_files(get_data_files(data_directory));
@ -387,6 +520,7 @@ int main(int argc, const char** argv)
for (auto& n : data_files)
{
std::vector<data_t> goods;
// Big Airship of Doom (BAD)
std::vector<data_t> bads;
for (auto& p : n.data_points)
@ -407,6 +541,10 @@ int main(int argc, const char** argv)
groups[size].emplace_back();
// then copy proportionally into the groups, creating roughly equal groups of data.
// my previous setup randomly selected the group index
// this resulted in wildly uneven groups, if you got unlucky.
// 25 vs 13 in some groups
// not sure if this is what we want, but it felt like this would create issues
blt::size_t select = 0;
for (auto& v : goods)
{
@ -436,6 +574,24 @@ int main(int argc, const char** argv)
for (auto [i, f] : blt::enumerate(g))
BLT_INFO("\tData file %ld contains %ld elements", i + 1, f.data_points.size());
}
for (auto& f : data_files)
{
int input = static_cast<int>(f.data_points.begin()->bins.size());
int hidden = input * 1;
BLT_INFO("Making network of size %d", input);
layer_id_counter = 0;
networks[input] = create_network(input, hidden);
}
// this is to prevent threading issues due to expanding buffers.
errors_over_time.reserve(25000);
error_derivative_over_time.reserve(25000);
correct_over_time.reserve(25000);
correct_over_time_test.reserve(25000);
error_of_test.reserve(25000);
error_of_test_derivative.reserve(25000);
#ifdef BLT_USE_GRAPHICS
blt::gfx::init(blt::gfx::window_data{"Freeplay Graphics", init, update, 1440, 720}.setSyncInterval(1).setMonitor(glfwGetPrimaryMonitor())
@ -449,7 +605,7 @@ int main(int argc, const char** argv)
int input = static_cast<int>(f.data_points.begin()->bins.size());
int hidden = input;
if (input != 16)
if (input != 32)
continue;
BLT_INFO("-----------------");
@ -459,8 +615,10 @@ int main(int argc, const char** argv)
network_t network = create_network(input, hidden);
for (blt::size_t i = 0; i < 2000; i++)
network.train_epoch(f);
float o = 0.00001;
network.with_momentum(&o);
for (blt::size_t i = 0; i < 300; i++)
network.train_epoch(f, 1);
BLT_INFO("Test Cases:");
blt::size_t right = 0;