ugh
parent
2dc79a8acc
commit
6021b7e4bb
|
@ -1,5 +1,5 @@
|
|||
cmake_minimum_required(VERSION 3.25)
|
||||
project(COSC-4P80-Assignment-2 VERSION 0.0.9)
|
||||
project(COSC-4P80-Assignment-2 VERSION 0.0.10)
|
||||
|
||||
option(ENABLE_ADDRSAN "Enable the address sanitizer" OFF)
|
||||
option(ENABLE_UBSAN "Enable the ub sanitizer" OFF)
|
||||
|
|
|
@ -164,7 +164,7 @@ namespace assign2
|
|||
std::vector<Scalar> data;
|
||||
};
|
||||
|
||||
std::vector<std::string> get_data_files(std::string_view path)
|
||||
inline std::vector<std::string> get_data_files(std::string_view path)
|
||||
{
|
||||
std::vector<std::string> files;
|
||||
|
||||
|
@ -180,7 +180,7 @@ namespace assign2
|
|||
return files;
|
||||
}
|
||||
|
||||
std::vector<data_file_t> load_data_files(const std::vector<std::string>& files)
|
||||
inline std::vector<data_file_t> load_data_files(const std::vector<std::string>& files)
|
||||
{
|
||||
std::vector<data_file_t> loaded_data;
|
||||
|
||||
|
@ -245,7 +245,31 @@ namespace assign2
|
|||
return loaded_data;
|
||||
}
|
||||
|
||||
bool is_thinks_bad(const std::vector<Scalar>& out)
|
||||
inline void save_as_csv(const std::string& file, const std::vector<std::pair<std::string, std::vector<Scalar>>>& data)
|
||||
{
|
||||
std::ofstream stream{file};
|
||||
stream << "epoch,";
|
||||
for (auto [i, d] : blt::enumerate(data))
|
||||
{
|
||||
stream << d.first;
|
||||
if (i != data.size() - 1)
|
||||
stream << ',';
|
||||
}
|
||||
stream << '\n';
|
||||
for (blt::size_t i = 0; i < data.begin()->second.size(); i++)
|
||||
{
|
||||
stream << i << ',';
|
||||
for (auto [j, d] : blt::enumerate(data))
|
||||
{
|
||||
stream << d.second[i];
|
||||
if (j != data.size() - 1)
|
||||
stream << ',';
|
||||
}
|
||||
stream << '\n';
|
||||
}
|
||||
}
|
||||
|
||||
inline bool is_thinks_bad(const std::vector<Scalar>& out)
|
||||
{
|
||||
return out[0] < out[1];
|
||||
}
|
||||
|
|
|
@ -66,6 +66,19 @@ namespace assign2
|
|||
return s >= 0 ? 1 : 0;
|
||||
}
|
||||
};
|
||||
|
||||
struct bulu_function : public function_t
|
||||
{
|
||||
[[nodiscard]] Scalar call(const Scalar s) const final
|
||||
{
|
||||
return s > 0.5 ? s : -s;
|
||||
}
|
||||
|
||||
[[nodiscard]] Scalar derivative(Scalar s) const final
|
||||
{
|
||||
return s >= 0 ? 1 : -1;
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
#endif //COSC_4P80_ASSIGNMENT_2_FUNCTIONS_H
|
||||
|
|
|
@ -58,7 +58,18 @@ namespace assign2
|
|||
|
||||
inline std::vector<Scalar> error_derivative_of_test;
|
||||
inline std::vector<Scalar> correct_over_time;
|
||||
inline std::vector<Scalar> correct_over_time_test;
|
||||
inline std::vector<node_data> nodes;
|
||||
|
||||
void save_error_info(const std::string& name)
|
||||
{
|
||||
save_as_csv("network" + name + ".csv", {{"train_error", errors_over_time},
|
||||
{"train_d_error", error_derivative_over_time},
|
||||
{"test_error", error_of_test},
|
||||
{"test_d_error", error_of_test_derivative},
|
||||
{"correct_train", correct_over_time},
|
||||
{"correct_test", correct_over_time_test}});
|
||||
}
|
||||
}
|
||||
|
||||
#endif //COSC_4P80_ASSIGNMENT_2_GLOBAL_MAGIC_H
|
||||
|
|
|
@ -34,16 +34,17 @@ namespace assign2
|
|||
friend layer_t;
|
||||
public:
|
||||
// empty neuron for loading from a stream
|
||||
explicit neuron_t(weight_view weights, weight_view dw): dw(dw), weights(weights)
|
||||
{}
|
||||
// explicit neuron_t(weight_view weights, weight_view dw): dw(dw), weights(weights)
|
||||
// {}
|
||||
|
||||
// neuron with bias
|
||||
explicit neuron_t(weight_view weights, weight_view dw, Scalar bias): bias(bias), dw(dw), weights(weights)
|
||||
explicit neuron_t(weight_view weights, weight_view dw, weight_view momentum, Scalar bias):
|
||||
bias(bias), dw(dw), weights(weights), momentum(momentum)
|
||||
{}
|
||||
|
||||
Scalar activate(const std::vector<Scalar>& inputs, function_t* act_func)
|
||||
{
|
||||
BLT_ASSERT_MSG(inputs.size() == weights.size(), (std::to_string(inputs.size()) + " vs " + std::to_string(weights.size())).c_str());
|
||||
BLT_ASSERT_MSG(inputs.size() == weights.size(), (std::to_string(inputs.size()) + " vs " + std::to_string(weights.size())).c_str());
|
||||
|
||||
z = bias;
|
||||
for (auto [x, w] : blt::zip_iterator_container({inputs.begin(), inputs.end()}, {weights.begin(), weights.end()}))
|
||||
|
@ -65,10 +66,24 @@ namespace assign2
|
|||
}
|
||||
}
|
||||
|
||||
void update()
|
||||
void update(float omega, bool reset)
|
||||
{
|
||||
for (auto [w, d] : blt::in_pairs(weights, dw))
|
||||
w += d;
|
||||
// if omega is zero we are not using momentum.
|
||||
if (reset || omega == 0)
|
||||
{
|
||||
// BLT_TRACE("Momentum Reset");
|
||||
// for (auto& v : momentum)
|
||||
// std::cout << v << ',';
|
||||
// std::cout << std::endl;
|
||||
for (auto& m : momentum)
|
||||
m = 0;
|
||||
} else
|
||||
{
|
||||
for (auto [m, d] : blt::in_pairs(momentum, dw))
|
||||
m += omega * d;
|
||||
}
|
||||
for (auto [w, m, d] : blt::zip(weights, momentum, dw))
|
||||
w += m + d;
|
||||
bias += db;
|
||||
}
|
||||
|
||||
|
@ -101,6 +116,7 @@ namespace assign2
|
|||
float error = 0;
|
||||
weight_view dw;
|
||||
weight_view weights;
|
||||
weight_view momentum;
|
||||
};
|
||||
|
||||
class layer_t
|
||||
|
@ -114,13 +130,15 @@ namespace assign2
|
|||
neurons.reserve(out_size);
|
||||
weights.preallocate(in_size * out_size);
|
||||
weight_derivatives.preallocate(in_size * out_size);
|
||||
momentum.preallocate(in_size * out_size);
|
||||
for (blt::i32 i = 0; i < out_size; i++)
|
||||
{
|
||||
auto weight = weights.allocate_view(in_size);
|
||||
auto dw = weight_derivatives.allocate_view(in_size);
|
||||
auto m = momentum.allocate_view(in_size);
|
||||
for (auto& v : weight)
|
||||
v = w(i);
|
||||
neurons.push_back(neuron_t{weight, dw, b(i)});
|
||||
neurons.push_back(neuron_t{weight, dw, m, b(i)});
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -138,7 +156,7 @@ namespace assign2
|
|||
}
|
||||
|
||||
error_data_t back_prop(const std::vector<Scalar>& prev_layer_output,
|
||||
const std::variant<blt::ref<const std::vector<Scalar>>, blt::ref<const layer_t>>& data)
|
||||
const std::variant<blt::ref<const std::vector<Scalar>>, blt::ref<const layer_t>>& data)
|
||||
{
|
||||
Scalar total_error = 0;
|
||||
Scalar total_derivative = 0;
|
||||
|
@ -148,20 +166,23 @@ namespace assign2
|
|||
for (auto [i, n] : blt::enumerate(neurons))
|
||||
{
|
||||
auto d = outputs[i] - expected[i];
|
||||
// if (outputs[0] > 0.3 && outputs[1] > 0.3)
|
||||
// d *= 10 * (outputs[0] + outputs[1]);
|
||||
auto d2 = 0.5f * (d * d);
|
||||
// according to the slides and the 3b1b video we sum on the squared error
|
||||
// not sure why on the slides the 1/2 is moved outside the sum as the cost function is defined (1/2) * (o - y)^2
|
||||
// and that the total cost for an input pattern is the sum of costs on the output
|
||||
total_error += d2;
|
||||
total_derivative += d;
|
||||
n.back_prop(act_func, prev_layer_output, d);
|
||||
}
|
||||
total_error /= static_cast<Scalar>(expected.size());
|
||||
total_derivative /= static_cast<Scalar>(expected.size());
|
||||
},
|
||||
// interior layer
|
||||
[this, &prev_layer_output](const layer_t& layer) {
|
||||
for (auto [i, n] : blt::enumerate(neurons))
|
||||
{
|
||||
Scalar w = 0;
|
||||
// TODO: this is not efficient on the cache!
|
||||
Scalar w = 0;
|
||||
for (auto nn : layer.neurons)
|
||||
w += nn.error * nn.weights[i];
|
||||
n.back_prop(act_func, prev_layer_output, w);
|
||||
|
@ -171,10 +192,10 @@ namespace assign2
|
|||
return {total_error, total_derivative};
|
||||
}
|
||||
|
||||
void update()
|
||||
void update(const float* omega, bool reset)
|
||||
{
|
||||
for (auto& n : neurons)
|
||||
n.update();
|
||||
n.update(omega == nullptr ? 0 : *omega, reset);
|
||||
}
|
||||
|
||||
template<typename OStream>
|
||||
|
@ -247,6 +268,7 @@ namespace assign2
|
|||
const blt::size_t layer_id;
|
||||
weight_t weights;
|
||||
weight_t weight_derivatives;
|
||||
weight_t momentum;
|
||||
function_t* act_func;
|
||||
std::vector<neuron_t> neurons;
|
||||
std::vector<Scalar> outputs;
|
||||
|
|
|
@ -79,22 +79,7 @@ namespace assign2
|
|||
outputs.emplace_back(input);
|
||||
|
||||
for (auto [i, v] : blt::enumerate(layers))
|
||||
{
|
||||
// auto in = outputs.back();
|
||||
// std::cout << "(" << i + 1 << "/" << layers.size() << ") Going In: ";
|
||||
// print_vec(in.get()) << std::endl;
|
||||
// auto& out = v->call(in);
|
||||
// std::cout << "(" << i + 1 << "/" << layers.size() << ") Coming out: ";
|
||||
// print_vec(out) << std::endl;
|
||||
//// std::cout << "(" << i << "/" << layers.size() << ") Weights: ";
|
||||
//// v->weights.debug();
|
||||
//// std::cout << std::endl;
|
||||
// std::cout << std::endl;
|
||||
//
|
||||
// outputs.emplace_back(out);
|
||||
outputs.emplace_back(v->call(outputs.back()));
|
||||
}
|
||||
// std::cout << std::endl;
|
||||
|
||||
return outputs.back();
|
||||
}
|
||||
|
@ -110,25 +95,22 @@ namespace assign2
|
|||
|
||||
auto out = execute(d.bins);
|
||||
|
||||
Scalar local_total_error = 0;
|
||||
Scalar local_total_d_error = 0;
|
||||
BLT_ASSERT(out.size() == expected.size());
|
||||
for (auto [o, e] : blt::in_pairs(out, expected))
|
||||
{
|
||||
auto d_error = o - e;
|
||||
|
||||
auto error = 0.5f * (d_error * d_error);
|
||||
|
||||
local_total_error += error;
|
||||
local_total_d_error += d_error;
|
||||
total_error += error;
|
||||
total_d_error += d_error;
|
||||
}
|
||||
total_error += local_total_error / 2;
|
||||
total_d_error += local_total_d_error / 2;
|
||||
}
|
||||
|
||||
return {total_error / static_cast<Scalar>(data.data_points.size()), total_d_error / static_cast<Scalar>(data.data_points.size())};
|
||||
}
|
||||
|
||||
error_data_t train(const data_t& data)
|
||||
error_data_t train(const data_t& data, bool reset)
|
||||
{
|
||||
error_data_t error = {0, 0};
|
||||
execute(data.bins);
|
||||
|
@ -148,19 +130,34 @@ namespace assign2
|
|||
}
|
||||
}
|
||||
for (auto& l : layers)
|
||||
l->update();
|
||||
l->update(m_omega, reset);
|
||||
// BLT_TRACE("Error for input: %f, derr: %f", error.error, error.d_error);
|
||||
return error;
|
||||
}
|
||||
|
||||
error_data_t train_epoch(const data_file_t& example)
|
||||
error_data_t train_epoch(const data_file_t& example, blt::i32 trains_per_data = 1)
|
||||
{
|
||||
error_data_t error {0, 0};
|
||||
error_data_t error{0, 0};
|
||||
for (const auto& x : example.data_points)
|
||||
error += train(x);
|
||||
error.d_error /= static_cast<Scalar>(example.data_points.size());
|
||||
error.error /= static_cast<Scalar>(example.data_points.size());
|
||||
{
|
||||
for (blt::i32 i = 0; i < trains_per_data; i++)
|
||||
error += train(x, reset_next);
|
||||
}
|
||||
// take the average cost over all the training.
|
||||
error.d_error /= static_cast<Scalar>(example.data_points.size() * trains_per_data);
|
||||
error.error /= static_cast<Scalar>(example.data_points.size() * trains_per_data);
|
||||
// as long as we are reducing error in the same direction in overall terms, we should still build momentum.
|
||||
auto last_sign = last_d_error >= 0;
|
||||
auto cur_sign = error.d_error >= 0;
|
||||
last_d_error = error.d_error;
|
||||
reset_next = last_sign != cur_sign;
|
||||
return error;
|
||||
}
|
||||
|
||||
void with_momentum(Scalar* omega)
|
||||
{
|
||||
m_omega = omega;
|
||||
}
|
||||
|
||||
#ifdef BLT_USE_GRAPHICS
|
||||
|
||||
|
@ -173,6 +170,10 @@ namespace assign2
|
|||
#endif
|
||||
|
||||
private:
|
||||
// pointer so it can be changed from the UI
|
||||
Scalar* m_omega = nullptr;
|
||||
Scalar last_d_error = 0;
|
||||
bool reset_next = false;
|
||||
std::vector<std::unique_ptr<layer_t>> layers;
|
||||
};
|
||||
}
|
||||
|
|
360
src/main.cpp
360
src/main.cpp
|
@ -16,26 +16,58 @@ using namespace assign2;
|
|||
|
||||
std::vector<data_file_t> data_files;
|
||||
blt::hashmap_t<blt::i32, std::vector<data_file_t>> groups;
|
||||
blt::hashmap_t<blt::i32, network_t> networks;
|
||||
bool with_momentum = false;
|
||||
Scalar omega = 0.001;
|
||||
|
||||
random_init randomizer{std::random_device{}()};
|
||||
empty_init empty;
|
||||
small_init small;
|
||||
sigmoid_function sig;
|
||||
relu_function relu;
|
||||
bulu_function bulu;
|
||||
tanh_function func_tanh;
|
||||
|
||||
network_t create_network(blt::i32 input, blt::i32 hidden)
|
||||
{
|
||||
auto layer1 = std::make_unique<layer_t>(input, hidden, &sig, randomizer, empty);
|
||||
auto layer2 = std::make_unique<layer_t>(hidden, hidden * 0.7, &sig, randomizer, empty);
|
||||
auto layer_output = std::make_unique<layer_t>(hidden * 0.7, 2, &sig, randomizer, empty);
|
||||
const auto mul = 0.5;
|
||||
const auto inner_mul = 0.25;
|
||||
auto layer1 = std::make_unique<layer_t>(input, hidden * mul, &sig, randomizer, empty);
|
||||
auto layer2 = std::make_unique<layer_t>(hidden * mul, hidden * inner_mul, &sig, randomizer, empty);
|
||||
// auto layer3 = std::make_unique<layer_t>(hidden * mul, hidden * mul, &sig, randomizer, empty);
|
||||
// auto layer4 = std::make_unique<layer_t>(hidden * mul, hidden * mul, &sig, randomizer, empty);
|
||||
auto layer_output = std::make_unique<layer_t>(hidden * inner_mul, 2, &sig, randomizer, empty);
|
||||
|
||||
std::vector<std::unique_ptr<layer_t>> vec;
|
||||
vec.push_back(std::move(layer1));
|
||||
vec.push_back(std::move(layer2));
|
||||
// vec.push_back(std::move(layer3));
|
||||
// vec.push_back(std::move(layer4));
|
||||
vec.push_back(std::move(layer_output));
|
||||
|
||||
return network_t{std::move(vec)};
|
||||
network_t network{std::move(vec)};
|
||||
if (with_momentum)
|
||||
network.with_momentum(&omega);
|
||||
return network;
|
||||
}
|
||||
|
||||
std::pair<data_file_t, data_file_t> create_groups(blt::i32 network, blt::i32 k = 0)
|
||||
{
|
||||
data_file_t training;
|
||||
data_file_t testing;
|
||||
|
||||
testing.data_points.insert(testing.data_points.begin(),
|
||||
(groups[network].begin() + k)->data_points.begin(),
|
||||
(groups[network].begin() + k)->data_points.end());
|
||||
|
||||
for (auto [i, a] : blt::enumerate(groups[network]))
|
||||
{
|
||||
if (i == static_cast<blt::size_t>(k))
|
||||
continue;
|
||||
training.data_points.insert(training.data_points.begin(), a.data_points.begin(), a.data_points.end());
|
||||
}
|
||||
|
||||
return {training, testing};
|
||||
}
|
||||
|
||||
#ifdef BLT_USE_GRAPHICS
|
||||
|
@ -52,24 +84,63 @@ blt::gfx::resource_manager resources;
|
|||
blt::gfx::batch_renderer_2d renderer_2d(resources, global_matrices);
|
||||
blt::gfx::first_person_camera_2d camera;
|
||||
|
||||
blt::hashmap_t<blt::i32, network_t> networks;
|
||||
|
||||
data_file_t current_training;
|
||||
data_file_t current_testing;
|
||||
std::atomic_int32_t run_epoch = -1;
|
||||
blt::i32 stop_at = -1;
|
||||
blt::i32 trains_per_data = 1;
|
||||
std::mutex vec_lock;
|
||||
|
||||
|
||||
|
||||
std::unique_ptr<std::thread> network_thread;
|
||||
std::atomic_bool running = true;
|
||||
std::atomic_bool run_exit = true;
|
||||
std::atomic_uint64_t epochs = 0;
|
||||
blt::i32 time_between_runs = 0;
|
||||
blt::size_t correct_recall = 0;
|
||||
blt::size_t wrong_recall = 0;
|
||||
blt::i32 number_before_switch = 10;
|
||||
bool swap_k_after = false;
|
||||
blt::size_t correct_recall_train = 0;
|
||||
blt::size_t correct_recall_test = 0;
|
||||
blt::size_t wrong_recall_train = 0;
|
||||
blt::size_t wrong_recall_test = 0;
|
||||
bool run_network = false;
|
||||
|
||||
float init_learn = learn_rate;
|
||||
float init_momentum = omega;
|
||||
|
||||
blt::i32 current_k = 0;
|
||||
|
||||
void update_current(int network)
|
||||
{
|
||||
if (groups[network].size() > 1)
|
||||
{
|
||||
std::scoped_lock lock(vec_lock);
|
||||
current_testing.data_points.clear();
|
||||
current_training.data_points.clear();
|
||||
|
||||
auto g = create_groups(network, current_k);
|
||||
current_testing = g.second;
|
||||
current_training = g.first;
|
||||
} else
|
||||
{
|
||||
std::scoped_lock lock(vec_lock);
|
||||
current_training = groups[network].front();
|
||||
current_testing = groups[network].front();
|
||||
}
|
||||
}
|
||||
|
||||
void reset_errors(int network)
|
||||
{
|
||||
save_error_info(std::to_string(network));
|
||||
errors_over_time.clear();
|
||||
correct_over_time.clear();
|
||||
correct_over_time_test.clear();
|
||||
error_derivative_over_time.clear();
|
||||
error_of_test.clear();
|
||||
error_of_test_derivative.clear();
|
||||
epochs = 0;
|
||||
run_network = false;
|
||||
}
|
||||
|
||||
void init(const blt::gfx::window_data&)
|
||||
{
|
||||
using namespace blt::gfx;
|
||||
|
@ -83,51 +154,70 @@ void init(const blt::gfx::window_data&)
|
|||
renderer_2d.create();
|
||||
ImPlot::CreateContext();
|
||||
|
||||
for (auto& f : data_files)
|
||||
{
|
||||
int input = static_cast<int>(f.data_points.begin()->bins.size());
|
||||
int hidden = input * 1;
|
||||
|
||||
BLT_INFO("Making network of size %d", input);
|
||||
layer_id_counter = 0;
|
||||
networks[input] = create_network(input, hidden);
|
||||
}
|
||||
|
||||
errors_over_time.reserve(25000);
|
||||
error_derivative_over_time.reserve(25000);
|
||||
correct_over_time.reserve(25000);
|
||||
error_of_test.reserve(25000);
|
||||
error_of_test_derivative.reserve(25000);
|
||||
update_current(networks.begin()->first);
|
||||
|
||||
network_thread = std::make_unique<std::thread>([]() {
|
||||
while (running)
|
||||
{
|
||||
if (run_epoch >= 0)
|
||||
{
|
||||
std::scoped_lock lock(vec_lock);
|
||||
auto error = networks.at(run_epoch).train_epoch(current_training);
|
||||
errors_over_time.push_back(error.error);
|
||||
error_derivative_over_time.push_back(error.d_error);
|
||||
|
||||
auto error_test = networks.at(run_epoch).error(current_testing);
|
||||
error_of_test.push_back(error_test.error);
|
||||
error_of_test_derivative.push_back(error_test.d_error);
|
||||
|
||||
blt::size_t right = 0;
|
||||
blt::size_t wrong = 0;
|
||||
for (auto& d : current_testing.data_points)
|
||||
if (swap_k_after && epochs % number_before_switch == static_cast<blt::size_t>(number_before_switch - 1))
|
||||
{
|
||||
auto out = networks.at(run_epoch).execute(d.bins);
|
||||
auto is_bad = is_thinks_bad(out);
|
||||
|
||||
if ((is_bad && d.is_bad) || (!is_bad && !d.is_bad))
|
||||
right++;
|
||||
else
|
||||
wrong++;
|
||||
current_k++;
|
||||
current_k %= static_cast<blt::i32>(groups[run_epoch].size());
|
||||
update_current(run_epoch);
|
||||
}
|
||||
correct_recall = right;
|
||||
wrong_recall = wrong;
|
||||
correct_over_time.push_back(static_cast<Scalar>(right) / static_cast<Scalar>(right + wrong) * 100);
|
||||
|
||||
blt::size_t right_t = 0;
|
||||
blt::size_t wrong_t = 0;
|
||||
blt::size_t right_a = 0;
|
||||
blt::size_t wrong_a = 0;
|
||||
{
|
||||
std::scoped_lock lock(vec_lock);
|
||||
auto error = networks.at(run_epoch).train_epoch(current_training, trains_per_data);
|
||||
errors_over_time.push_back(error.error);
|
||||
error_derivative_over_time.push_back(error.d_error);
|
||||
|
||||
auto error_test = networks.at(run_epoch).error(current_testing);
|
||||
error_of_test.push_back(error_test.error);
|
||||
error_of_test_derivative.push_back(error_test.d_error);
|
||||
|
||||
for (auto& d : current_testing.data_points)
|
||||
{
|
||||
auto out = networks.at(run_epoch).execute(d.bins);
|
||||
auto is_bad = is_thinks_bad(out);
|
||||
|
||||
if ((is_bad && d.is_bad) || (!is_bad && !d.is_bad))
|
||||
right_t++;
|
||||
else
|
||||
wrong_t++;
|
||||
}
|
||||
|
||||
for (auto& d : current_training.data_points)
|
||||
{
|
||||
auto out = networks.at(run_epoch).execute(d.bins);
|
||||
auto is_bad = is_thinks_bad(out);
|
||||
|
||||
if ((is_bad && d.is_bad) || (!is_bad && !d.is_bad))
|
||||
right_a++;
|
||||
else
|
||||
wrong_a++;
|
||||
}
|
||||
}
|
||||
correct_recall_test = right_t;
|
||||
correct_recall_train = right_a;
|
||||
wrong_recall_test = wrong_t;
|
||||
wrong_recall_train = wrong_a;
|
||||
correct_over_time
|
||||
.push_back(static_cast<Scalar>(correct_recall_train) / static_cast<Scalar>(correct_recall_train + wrong_recall_train) * 100);
|
||||
correct_over_time_test
|
||||
.push_back(static_cast<Scalar>(correct_recall_test) / static_cast<Scalar>(correct_recall_test + wrong_recall_test) * 100);
|
||||
|
||||
auto error = errors_over_time.back();
|
||||
// error = std::sqrt(error * error + error + 0.01f);
|
||||
error = std::max(0.0f, std::min(1.0f, error));
|
||||
learn_rate = error * init_learn;
|
||||
omega = error * init_momentum;
|
||||
|
||||
epochs++;
|
||||
run_epoch = -1;
|
||||
|
@ -179,6 +269,18 @@ void plot_vector(ImPlotRect& lims, const std::vector<Scalar>& v, std::string nam
|
|||
}
|
||||
}
|
||||
|
||||
static void HelpMarker(const char* desc)
|
||||
{
|
||||
ImGui::TextDisabled("(?)");
|
||||
if (ImGui::BeginItemTooltip())
|
||||
{
|
||||
ImGui::PushTextWrapPos(ImGui::GetFontSize() * 35.0f);
|
||||
ImGui::TextUnformatted(desc);
|
||||
ImGui::PopTextWrapPos();
|
||||
ImGui::EndTooltip();
|
||||
}
|
||||
}
|
||||
|
||||
void update(const blt::gfx::window_data& data)
|
||||
{
|
||||
global_matrices.update_perspectives(data.width, data.height, 90, 0.1, 2000);
|
||||
|
@ -207,55 +309,44 @@ void update(const blt::gfx::window_data& data)
|
|||
lists.push_back(ptr);
|
||||
}
|
||||
}
|
||||
static int selected = 1;
|
||||
static int selected = 0;
|
||||
for (int i = 0; i < selected; i++)
|
||||
net++;
|
||||
ImGui::Separator();
|
||||
ImGui::Text("Select Network Size");
|
||||
if (ImGui::ListBox("", &selected, lists.data(), static_cast<int>(lists.size()), 4))
|
||||
{
|
||||
errors_over_time.clear();
|
||||
correct_over_time.clear();
|
||||
error_derivative_over_time.clear();
|
||||
error_of_test.clear();
|
||||
error_of_test_derivative.clear();
|
||||
run_network = false;
|
||||
reset_errors(net->first);
|
||||
net = networks.begin();
|
||||
for (int i = 0; i < selected; i++)
|
||||
net++;
|
||||
update_current(net->first);
|
||||
}
|
||||
ImGui::Separator();
|
||||
ImGui::Text("Using network %d size %d", selected, net->first);
|
||||
ImGui::Checkbox("Train Network", &run_network);
|
||||
ImGui::InputInt("Stop At", &stop_at);
|
||||
if (static_cast<blt::i32>(epochs) >= stop_at && stop_at > 0)
|
||||
run_network = false;
|
||||
if (run_network)
|
||||
{
|
||||
if (groups[net->first].size() > 1)
|
||||
{
|
||||
std::scoped_lock lock(vec_lock);
|
||||
current_testing.data_points.clear();
|
||||
current_training.data_points.clear();
|
||||
|
||||
current_testing.data_points.insert(current_testing.data_points.begin(), groups[net->first].front().data_points.begin(),
|
||||
groups[net->first].front().data_points.end());
|
||||
for (auto a : blt::iterate(groups[net->first]).skip(1))
|
||||
current_training.data_points.insert(current_training.data_points.begin(), a.data_points.begin(), a.data_points.end());
|
||||
} else
|
||||
{
|
||||
std::scoped_lock lock(vec_lock);
|
||||
current_training = groups[net->first].front();
|
||||
current_testing = groups[net->first].front();
|
||||
}
|
||||
|
||||
// update_current(net->first);
|
||||
run_epoch = net->first;
|
||||
}
|
||||
ImGui::InputInt("Time Between Runs", &time_between_runs);
|
||||
if (time_between_runs < 0)
|
||||
time_between_runs = 0;
|
||||
std::string str = std::to_string(correct_recall) + "/" + std::to_string(wrong_recall + correct_recall);
|
||||
std::string str = std::to_string(correct_recall_test) + "/" + std::to_string(wrong_recall_test + correct_recall_test);
|
||||
ImGui::ProgressBar(
|
||||
(wrong_recall + correct_recall != 0) ? static_cast<float>(correct_recall) / static_cast<float>(wrong_recall + correct_recall) : 0,
|
||||
(wrong_recall_test + correct_recall_test != 0) ? static_cast<float>(correct_recall_test) /
|
||||
static_cast<float>(wrong_recall_test + correct_recall_test) : 0,
|
||||
ImVec2(0, 0), str.c_str());
|
||||
ImGui::Separator();
|
||||
str = std::to_string(correct_recall_train) + "/" + std::to_string(wrong_recall_train + correct_recall_train);
|
||||
ImGui::ProgressBar(
|
||||
(wrong_recall_train + correct_recall_train != 0) ? static_cast<float>(correct_recall_train) /
|
||||
static_cast<float>(wrong_recall_train + correct_recall_train) : 0,
|
||||
ImVec2(0, 0), str.c_str());
|
||||
// const float max_learn = 100000;
|
||||
// static float learn = max_learn;
|
||||
// ImGui::SliderFloat("Learn Rate", &learn, 1, max_learn, "", ImGuiSliderFlags_Logarithmic);
|
||||
// learn_rate = learn / (max_learn * 1000);
|
||||
ImGui::Text("Learn Rate %.9f", learn_rate);
|
||||
if (ImGui::Button("Print Current"))
|
||||
{
|
||||
|
@ -278,6 +369,35 @@ void update(const blt::gfx::window_data& data)
|
|||
}
|
||||
BLT_INFO("NN got %ld right and %ld wrong (%%%lf)", right, wrong, static_cast<double>(right) / static_cast<double>(right + wrong) * 100);
|
||||
}
|
||||
if (ImGui::SliderInt("K For Testing", ¤t_k, 0, static_cast<int>(groups[net->first].size() - 1)))
|
||||
update_current(net->first);
|
||||
ImGui::Checkbox("Auto-swap K", &swap_k_after);
|
||||
if (swap_k_after)
|
||||
{
|
||||
ImGui::InputInt("Number of epochs before switch", &number_before_switch);
|
||||
if (number_before_switch < 1)
|
||||
number_before_switch = 1;
|
||||
}
|
||||
ImGui::Checkbox("Momentum", &with_momentum);
|
||||
ImGui::SameLine();
|
||||
HelpMarker("You might want to reset the network after changing this");
|
||||
if (with_momentum)
|
||||
ImGui::SliderFloat("##MomentumSlider", &omega, 0, 0.1, "%.8f", ImGuiSliderFlags_Logarithmic);
|
||||
ImGui::InputInt("Trains per Epoch", &trains_per_data);
|
||||
ImGui::SameLine();
|
||||
HelpMarker("Number of times to run back-prop on a piece of data before moving on to the next");
|
||||
if (trains_per_data < 1)
|
||||
trains_per_data = 1;
|
||||
ImGui::Separator();
|
||||
if (ImGui::Button("Reset Network"))
|
||||
{
|
||||
reset_errors(net->first);
|
||||
layer_id_counter = 0;
|
||||
networks[net->first] = create_network(net->first, net->first);
|
||||
}
|
||||
ImGui::Separator();
|
||||
if (ImGui::Button("Save current to CSV"))
|
||||
save_error_info(std::to_string(net->first) + "_" + std::to_string(current_k));
|
||||
}
|
||||
ImGui::End();
|
||||
|
||||
|
@ -291,12 +411,10 @@ void update(const blt::gfx::window_data& data)
|
|||
x_points.push_back(i);
|
||||
}
|
||||
|
||||
auto domain = static_cast<int>(errors_over_time.size());
|
||||
blt::i32 history = std::min(100, domain);
|
||||
static ImPlotRect lims(0, 100, 0, 1);
|
||||
if (ImPlot::BeginAlignedPlots("AlignedGroup"))
|
||||
static ImPlotRect lims(0, 500, 0, 1);
|
||||
if (ImPlot::BeginSubplots("##LinkedGroup", 3, 2, ImVec2(-1, -1)))
|
||||
{
|
||||
plot_vector(lims, errors_over_time, "Global Error over epochs", "Epoch", "Error", [](auto v, bool b) {
|
||||
plot_vector(lims, errors_over_time, "Global Error (Training)", "Epoch", "Error", [](auto v, bool b) {
|
||||
float percent = 0.15;
|
||||
if (b)
|
||||
return v < 0 ? v * (1 + percent) : v * (1 - percent);
|
||||
|
@ -310,27 +428,33 @@ void update(const blt::gfx::window_data& data)
|
|||
else
|
||||
return v < 0 ? v * (1 - percent) : v * (1 + percent);
|
||||
});
|
||||
plot_vector(lims, correct_over_time, "% Correct over epochs", "Epoch", "Correct%", [](auto v, bool b) {
|
||||
plot_vector(lims, error_derivative_over_time, "DError/Dw (Training)", "Epoch", "DError", [](auto v, bool b) {
|
||||
float percent = 0.05;
|
||||
if (b)
|
||||
return v < 0 ? v * (1 + percent) : v * (1 - percent);
|
||||
else
|
||||
return v < 0 ? v * (1 - percent) : v * (1 + percent);
|
||||
});
|
||||
plot_vector(lims, error_of_test_derivative, "DError/Dw (Test)", "Epoch", "DError", [](auto v, bool b) {
|
||||
float percent = 0.05;
|
||||
if (b)
|
||||
return v < 0 ? v * (1 + percent) : v * (1 - percent);
|
||||
else
|
||||
return v < 0 ? v * (1 - percent) : v * (1 + percent);
|
||||
});
|
||||
plot_vector(lims, correct_over_time, "% Correct (Training)", "Epoch", "Correct%", [](auto v, bool b) {
|
||||
if (b)
|
||||
return v - 1;
|
||||
else
|
||||
return v + 1;
|
||||
});
|
||||
plot_vector(lims, error_derivative_over_time, "DError/Dw over epochs", "Epoch", "Error", [](auto v, bool b) {
|
||||
float percent = 0.05;
|
||||
plot_vector(lims, correct_over_time_test, "% Correct (Test)", "Epoch", "Correct%", [](auto v, bool b) {
|
||||
if (b)
|
||||
return v < 0 ? v * (1 + percent) : v * (1 - percent);
|
||||
return v - 1;
|
||||
else
|
||||
return v < 0 ? v * (1 - percent) : v * (1 + percent);
|
||||
return v + 1;
|
||||
});
|
||||
plot_vector(lims, error_of_test_derivative, "DError/Dw (Test)", "Epoch", "Error", [](auto v, bool b) {
|
||||
float percent = 0.05;
|
||||
if (b)
|
||||
return v < 0 ? v * (1 + percent) : v * (1 - percent);
|
||||
else
|
||||
return v < 0 ? v * (1 - percent) : v * (1 + percent);
|
||||
});
|
||||
ImPlot::EndAlignedPlots();
|
||||
ImPlot::EndSubplots();
|
||||
}
|
||||
}
|
||||
ImGui::End();
|
||||
|
@ -347,6 +471,7 @@ void update(const blt::gfx::window_data& data)
|
|||
|
||||
void destroy()
|
||||
{
|
||||
save_error_info(std::to_string(run_epoch));
|
||||
running = false;
|
||||
while (run_exit)
|
||||
{
|
||||
|
@ -369,12 +494,20 @@ void destroy()
|
|||
int main(int argc, const char** argv)
|
||||
{
|
||||
blt::arg_parse parser;
|
||||
parser.addArgument(blt::arg_builder("-f", "--file").setHelp("path to the data files").setDefault("../data").build());
|
||||
parser.addArgument(blt::arg_builder("-f", "--file").setHelp("Path to the data files").setDefault("../data").setMetavar("FOLDER").build());
|
||||
parser.addArgument(
|
||||
blt::arg_builder("-k", "--kfold").setHelp("Number of groups to split into").setAction(blt::arg_action_t::STORE).setNArgs('?')
|
||||
.setConst("3").build());
|
||||
blt::arg_builder("-k", "--kfold").setHelp("Number of groups to split into [Defaults to 3 if no number is provided]")
|
||||
.setAction(blt::arg_action_t::STORE).setNArgs('?').setConst("3").setMetavar("GROUPS").build());
|
||||
parser.addArgument(blt::arg_builder("-m", "--momentum").setHelp("Use momentum in weight calculations").setAction(blt::arg_action_t::STORE_TRUE)
|
||||
.setDefault(false).build());
|
||||
|
||||
auto args = parser.parse_args(argc, argv);
|
||||
if (args.get<bool>("momentum"))
|
||||
{
|
||||
BLT_INFO("Using Momentum");
|
||||
with_momentum = true;
|
||||
}
|
||||
|
||||
std::string data_directory = blt::string::ensure_ends_with_path_separator(args.get<std::string>("file"));
|
||||
|
||||
data_files = load_data_files(get_data_files(data_directory));
|
||||
|
@ -387,6 +520,7 @@ int main(int argc, const char** argv)
|
|||
for (auto& n : data_files)
|
||||
{
|
||||
std::vector<data_t> goods;
|
||||
// Big Airship of Doom (BAD)
|
||||
std::vector<data_t> bads;
|
||||
|
||||
for (auto& p : n.data_points)
|
||||
|
@ -407,6 +541,10 @@ int main(int argc, const char** argv)
|
|||
groups[size].emplace_back();
|
||||
|
||||
// then copy proportionally into the groups, creating roughly equal groups of data.
|
||||
// my previous setup randomly selected the group index
|
||||
// this resulted in wildly uneven groups, if you got unlucky.
|
||||
// 25 vs 13 in some groups
|
||||
// not sure if this is what we want, but it felt like this would create issues
|
||||
blt::size_t select = 0;
|
||||
for (auto& v : goods)
|
||||
{
|
||||
|
@ -436,6 +574,24 @@ int main(int argc, const char** argv)
|
|||
for (auto [i, f] : blt::enumerate(g))
|
||||
BLT_INFO("\tData file %ld contains %ld elements", i + 1, f.data_points.size());
|
||||
}
|
||||
|
||||
for (auto& f : data_files)
|
||||
{
|
||||
int input = static_cast<int>(f.data_points.begin()->bins.size());
|
||||
int hidden = input * 1;
|
||||
|
||||
BLT_INFO("Making network of size %d", input);
|
||||
layer_id_counter = 0;
|
||||
networks[input] = create_network(input, hidden);
|
||||
}
|
||||
|
||||
// this is to prevent threading issues due to expanding buffers.
|
||||
errors_over_time.reserve(25000);
|
||||
error_derivative_over_time.reserve(25000);
|
||||
correct_over_time.reserve(25000);
|
||||
correct_over_time_test.reserve(25000);
|
||||
error_of_test.reserve(25000);
|
||||
error_of_test_derivative.reserve(25000);
|
||||
|
||||
#ifdef BLT_USE_GRAPHICS
|
||||
blt::gfx::init(blt::gfx::window_data{"Freeplay Graphics", init, update, 1440, 720}.setSyncInterval(1).setMonitor(glfwGetPrimaryMonitor())
|
||||
|
@ -449,7 +605,7 @@ int main(int argc, const char** argv)
|
|||
int input = static_cast<int>(f.data_points.begin()->bins.size());
|
||||
int hidden = input;
|
||||
|
||||
if (input != 16)
|
||||
if (input != 32)
|
||||
continue;
|
||||
|
||||
BLT_INFO("-----------------");
|
||||
|
@ -459,8 +615,10 @@ int main(int argc, const char** argv)
|
|||
|
||||
network_t network = create_network(input, hidden);
|
||||
|
||||
for (blt::size_t i = 0; i < 2000; i++)
|
||||
network.train_epoch(f);
|
||||
float o = 0.00001;
|
||||
network.with_momentum(&o);
|
||||
for (blt::size_t i = 0; i < 300; i++)
|
||||
network.train_epoch(f, 1);
|
||||
|
||||
BLT_INFO("Test Cases:");
|
||||
blt::size_t right = 0;
|
||||
|
|
Loading…
Reference in New Issue