main
Brett 2024-10-30 03:02:28 -04:00
parent 6021b7e4bb
commit c7cf4721c8
4 changed files with 12 additions and 12 deletions

View File

@ -1,5 +1,5 @@
cmake_minimum_required(VERSION 3.25) cmake_minimum_required(VERSION 3.25)
project(COSC-4P80-Assignment-2 VERSION 0.0.10) project(COSC-4P80-Assignment-2 VERSION 0.0.11)
option(ENABLE_ADDRSAN "Enable the address sanitizer" OFF) option(ENABLE_ADDRSAN "Enable the address sanitizer" OFF)
option(ENABLE_UBSAN "Enable the ub sanitizer" OFF) option(ENABLE_UBSAN "Enable the ub sanitizer" OFF)

View File

@ -57,19 +57,19 @@ namespace assign2
{ {
// delta for weights // delta for weights
error = act->derivative(z) * next_error; error = act->derivative(z) * next_error;
db = learn_rate * error; db = -learn_rate * error;
BLT_ASSERT(previous_outputs.size() == dw.size()); BLT_ASSERT(previous_outputs.size() == dw.size());
for (auto [prev_out, d_weight] : blt::zip(previous_outputs, dw)) for (auto [prev_out, d_weight] : blt::zip(previous_outputs, dw))
{ {
// dw // dw
d_weight = -learn_rate * prev_out * error; d_weight = learn_rate * prev_out * error;
} }
} }
void update(float omega, bool reset) void update(float omega, bool reset)
{ {
// if omega is zero we are not using momentum. // if omega is zero we are not using momentum.
if (reset || omega == 0) if (omega == 0)
{ {
// BLT_TRACE("Momentum Reset"); // BLT_TRACE("Momentum Reset");
// for (auto& v : momentum) // for (auto& v : momentum)
@ -165,7 +165,7 @@ namespace assign2
[this, &prev_layer_output, &total_error, &total_derivative](const std::vector<Scalar>& expected) { [this, &prev_layer_output, &total_error, &total_derivative](const std::vector<Scalar>& expected) {
for (auto [i, n] : blt::enumerate(neurons)) for (auto [i, n] : blt::enumerate(neurons))
{ {
auto d = outputs[i] - expected[i]; auto d = expected[i] - outputs[i];
// if (outputs[0] > 0.3 && outputs[1] > 0.3) // if (outputs[0] > 0.3 && outputs[1] > 0.3)
// d *= 10 * (outputs[0] + outputs[1]); // d *= 10 * (outputs[0] + outputs[1]);
auto d2 = 0.5f * (d * d); auto d2 = 0.5f * (d * d);

View File

@ -98,7 +98,7 @@ namespace assign2
BLT_ASSERT(out.size() == expected.size()); BLT_ASSERT(out.size() == expected.size());
for (auto [o, e] : blt::in_pairs(out, expected)) for (auto [o, e] : blt::in_pairs(out, expected))
{ {
auto d_error = o - e; auto d_error = e - o;
auto error = 0.5f * (d_error * d_error); auto error = 0.5f * (d_error * d_error);

View File

@ -34,8 +34,8 @@ network_t create_network(blt::i32 input, blt::i32 hidden)
const auto inner_mul = 0.25; const auto inner_mul = 0.25;
auto layer1 = std::make_unique<layer_t>(input, hidden * mul, &sig, randomizer, empty); auto layer1 = std::make_unique<layer_t>(input, hidden * mul, &sig, randomizer, empty);
auto layer2 = std::make_unique<layer_t>(hidden * mul, hidden * inner_mul, &sig, randomizer, empty); auto layer2 = std::make_unique<layer_t>(hidden * mul, hidden * inner_mul, &sig, randomizer, empty);
// auto layer3 = std::make_unique<layer_t>(hidden * mul, hidden * mul, &sig, randomizer, empty); // auto layer3 = std::make_unique<layer_t>(hidden * inner_mul, hidden * inner_mul, &sig, randomizer, empty);
// auto layer4 = std::make_unique<layer_t>(hidden * mul, hidden * mul, &sig, randomizer, empty); // auto layer4 = std::make_unique<layer_t>(hidden * inner_mul, hidden * inner_mul, &sig, randomizer, empty);
auto layer_output = std::make_unique<layer_t>(hidden * inner_mul, 2, &sig, randomizer, empty); auto layer_output = std::make_unique<layer_t>(hidden * inner_mul, 2, &sig, randomizer, empty);
std::vector<std::unique_ptr<layer_t>> vec; std::vector<std::unique_ptr<layer_t>> vec;
@ -215,7 +215,7 @@ void init(const blt::gfx::window_data&)
auto error = errors_over_time.back(); auto error = errors_over_time.back();
// error = std::sqrt(error * error + error + 0.01f); // error = std::sqrt(error * error + error + 0.01f);
error = std::max(0.0f, std::min(1.0f, error)); // error = std::max(0.0f, std::min(1.0f, error));
learn_rate = error * init_learn; learn_rate = error * init_learn;
omega = error * init_momentum; omega = error * init_momentum;
@ -605,7 +605,7 @@ int main(int argc, const char** argv)
int input = static_cast<int>(f.data_points.begin()->bins.size()); int input = static_cast<int>(f.data_points.begin()->bins.size());
int hidden = input; int hidden = input;
if (input != 32) if (input != 64)
continue; continue;
BLT_INFO("-----------------"); BLT_INFO("-----------------");
@ -616,8 +616,8 @@ int main(int argc, const char** argv)
network_t network = create_network(input, hidden); network_t network = create_network(input, hidden);
float o = 0.00001; float o = 0.00001;
network.with_momentum(&o); // network.with_momentum(&o);
for (blt::size_t i = 0; i < 300; i++) for (blt::size_t i = 0; i < 10000; i++)
network.train_epoch(f, 1); network.train_epoch(f, 1);
BLT_INFO("Test Cases:"); BLT_INFO("Test Cases:");