idk if this is supposed to work so well?
parent
8f0361df0f
commit
fcb28332a5
|
@ -1,5 +1,5 @@
|
|||
cmake_minimum_required(VERSION 3.25)
|
||||
project(COSC-4P80-Assignment-1 VERSION 0.0.7)
|
||||
project(COSC-4P80-Assignment-1 VERSION 0.0.8)
|
||||
|
||||
option(ENABLE_ADDRSAN "Enable the address sanitizer" OFF)
|
||||
option(ENABLE_UBSAN "Enable the ub sanitizer" OFF)
|
||||
|
|
|
@ -0,0 +1,106 @@
|
|||
#pragma once
|
||||
/*
|
||||
* Copyright (C) 2024 Brett Terpstra
|
||||
*
|
||||
* This program is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* This program is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
#ifndef COSC_4P80_ASSIGNMENT_1_A1_H
|
||||
#define COSC_4P80_ASSIGNMENT_1_A1_H
|
||||
|
||||
#include <blt/math/matrix.h>
|
||||
#include <blt/math/log_util.h>
|
||||
|
||||
namespace a1
|
||||
{
|
||||
void test_math()
|
||||
{
|
||||
blt::generalized_matrix<float, 1, 4> input{1, -1, -1, 1};
|
||||
blt::generalized_matrix<float, 1, 3> output{1, 1, 1};
|
||||
blt::generalized_matrix<float, 4, 3> expected{
|
||||
blt::vec4{1, -1, -1, 1},
|
||||
blt::vec4{1, -1, -1, 1},
|
||||
blt::vec4{1, -1, -1, 1}
|
||||
};
|
||||
|
||||
auto w_matrix = input.transpose() * output;
|
||||
BLT_ASSERT(w_matrix == expected && "MATH MATRIX FAILURE");
|
||||
|
||||
blt::vec4 one{5, 1, 3, 0};
|
||||
blt::vec4 two{9, -5, -8, 3};
|
||||
|
||||
blt::generalized_matrix<float, 1, 4> g1{5, 1, 3, 0};
|
||||
blt::generalized_matrix<float, 1, 4> g2{9, -5, -8, 3};
|
||||
|
||||
BLT_ASSERT(g1 * g2.transpose() == blt::vec4::dot(one, two) && "MATH DOT FAILURE");
|
||||
}
|
||||
|
||||
template<typename input_t>
|
||||
float crosstalk(const input_t& i, const input_t& j)
|
||||
{
|
||||
return i * j.transpose();
|
||||
}
|
||||
|
||||
template<typename T, blt::u32 rows, blt::u32 columns>
|
||||
blt::generalized_matrix<T, rows, columns> threshold(const blt::generalized_matrix<T, rows, columns>& y,
|
||||
const blt::generalized_matrix<T, rows, columns>& base)
|
||||
{
|
||||
blt::generalized_matrix<T, rows, columns> result;
|
||||
for (blt::u32 i = 0; i < columns; i++)
|
||||
{
|
||||
for (blt::u32 j = 0; j < rows; j++)
|
||||
result[i][j] = y[i][j] > 1 ? 1 : (y[i][j] < -1 ? -1 : base[i][j]);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
template<typename T, blt::size_t size>
|
||||
bool equal(const std::array<T, size>& a, const std::array<T, size>& b)
|
||||
{
|
||||
for (const auto& [index, val] : blt::enumerate(a))
|
||||
{
|
||||
if (b[index] != val)
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
template<typename weight_t, typename input_t, typename output_t>
|
||||
std::pair<input_t, output_t> run_step(const weight_t& associated_weights, const input_t& input, const output_t& output)
|
||||
{
|
||||
output_t output_recall = input * associated_weights;
|
||||
input_t input_recall = output * associated_weights.transpose();
|
||||
|
||||
return std::pair{a1::threshold(input_recall, input), a1::threshold(output_recall, output)};
|
||||
}
|
||||
|
||||
template<typename weight_t, typename T, typename G, blt::size_t size>
|
||||
void check_recall(const weight_t& weights, const std::array<G, size>& inputs, const std::array<T, size>& outputs)
|
||||
{
|
||||
for (const auto& [index, val] : blt::enumerate(inputs))
|
||||
{
|
||||
auto result = run_step(weights, val, outputs[index]);
|
||||
if (result.first != val)
|
||||
BLT_ERROR("Recall of input #%ld failed", index + 1);
|
||||
else
|
||||
BLT_INFO("Recall of input #%ld passed", index + 1);
|
||||
if (result.second != outputs[index])
|
||||
BLT_ERROR("Recall of output #%ld failed", index + 1);
|
||||
else
|
||||
BLT_INFO("recall of output #%ld passed", index + 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#endif //COSC_4P80_ASSIGNMENT_1_A1_H
|
|
@ -1,58 +0,0 @@
|
|||
#pragma once
|
||||
/*
|
||||
* Copyright (C) 2024 Brett Terpstra
|
||||
*
|
||||
* This program is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* This program is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
#ifndef COSC_4P80_ASSIGNMENT_1_FWD_DECL_H
|
||||
#define COSC_4P80_ASSIGNMENT_1_FWD_DECL_H
|
||||
|
||||
#include <blt/math/matrix.h>
|
||||
#include <blt/math/log_util.h>
|
||||
|
||||
namespace a1
|
||||
{
|
||||
void test_math()
|
||||
{
|
||||
blt::generalized_matrix<float, 1, 4> input{1, -1, -1, 1};
|
||||
blt::generalized_matrix<float, 1, 3> output{1, 1, 1};
|
||||
blt::generalized_matrix<float, 4, 3> expected{
|
||||
blt::vec4{1, -1, -1, 1},
|
||||
blt::vec4{1, -1, -1, 1},
|
||||
blt::vec4{1, -1, -1, 1}
|
||||
};
|
||||
|
||||
auto w_matrix = input.transpose() * output;
|
||||
BLT_ASSERT(w_matrix == expected && "MATH MATRIX FAILURE");
|
||||
|
||||
blt::vec4 one{5, 1, 3, 0};
|
||||
blt::vec4 two{9, -5, -8, 3};
|
||||
|
||||
blt::generalized_matrix<float, 1, 4> g1{5, 1, 3, 0};
|
||||
blt::generalized_matrix<float, 1, 4> g2{9, -5, -8, 3};
|
||||
|
||||
BLT_ASSERT(g1 * g2.transpose() == blt::vec4::dot(one, two) && "MATH DOT FAILURE");
|
||||
}
|
||||
|
||||
enum class recall_error_t
|
||||
{
|
||||
// failed to predict input
|
||||
INPUT_FAILURE,
|
||||
// failed to predict output
|
||||
OUTPUT_FAILURE
|
||||
};
|
||||
}
|
||||
|
||||
#endif //COSC_4P80_ASSIGNMENT_1_FWD_DECL_H
|
86
src/main.cpp
86
src/main.cpp
|
@ -3,7 +3,7 @@
|
|||
#include <blt/math/log_util.h>
|
||||
#include "blt/std/assert.h"
|
||||
#include <blt/format/boxing.h>
|
||||
#include <fwd_decl.h>
|
||||
#include <a1.h>
|
||||
|
||||
constexpr blt::u32 num_values = 4;
|
||||
constexpr blt::u32 input_count = 5;
|
||||
|
@ -14,58 +14,83 @@ using output_t = blt::generalized_matrix<float, 1, output_count>;
|
|||
using weight_t = decltype(std::declval<input_t>().transpose() * std::declval<output_t>());
|
||||
using crosstalk_t = blt::generalized_matrix<float, output_count, num_values>;
|
||||
|
||||
float crosstalk(const input_t& i, const input_t& j)
|
||||
{
|
||||
return i * j.transpose();
|
||||
}
|
||||
|
||||
// part a
|
||||
input_t input_1{-1, 1, 1, 1, -1};
|
||||
input_t input_2{-1, -1, -1, -1, 1};
|
||||
input_t input_3{-1, -1, -1, 1, 1};
|
||||
// part c 1
|
||||
input_t input_4{1, 1, 1, 1, 1};
|
||||
// part c 2
|
||||
input_t input_5{-1, 1, -1, 1, 1};
|
||||
input_t input_6{1, -1, 1, -1, 1};
|
||||
input_t input_7{-1, 1, -1, 1, -1};
|
||||
|
||||
// part a
|
||||
output_t output_1{1, 1, -1, 1};
|
||||
output_t output_2{1, -1, -1, -1};
|
||||
output_t output_3{-1, -1, 1, 1};
|
||||
// part c 1
|
||||
output_t output_4{-1, 1, 1, -1};
|
||||
// part c 2
|
||||
output_t output_5{1, 1, 1, 1};
|
||||
output_t output_6{1, -1, -1, 1};
|
||||
output_t output_7{1, 1, 1, -1};
|
||||
|
||||
const weight_t weight_1 = input_1.transpose() * output_1;
|
||||
const weight_t weight_2 = input_2.transpose() * output_2;
|
||||
const weight_t weight_3 = input_3.transpose() * output_3;
|
||||
const weight_t weight_4 = input_4.transpose() * output_4;
|
||||
const weight_t weight_5 = input_5.transpose() * output_5;
|
||||
const weight_t weight_6 = input_6.transpose() * output_6;
|
||||
const weight_t weight_7 = input_7.transpose() * output_7;
|
||||
|
||||
auto starting_inputs = std::array{input_1, input_2, input_3, input_4};
|
||||
auto starting_outputs = std::array{output_1, output_2, output_3, output_4};
|
||||
auto part_a_inputs = std::array{input_1, input_2, input_3};
|
||||
auto part_a_outputs = std::array{output_1, output_2, output_3};
|
||||
|
||||
auto part_c_1_inputs = std::array{input_1, input_2, input_3, input_4};
|
||||
auto part_c_1_outputs = std::array{output_1, output_2, output_3, output_4};
|
||||
|
||||
auto part_c_2_inputs = std::array{input_1, input_2, input_3, input_4, input_5, input_6, input_7};
|
||||
auto part_c_2_outputs = std::array{output_1, output_2, output_3, output_4, output_5, output_6, output_7};
|
||||
|
||||
const auto weight_total_a = weight_1 + weight_2 + weight_3;
|
||||
const auto weight_total_c = weight_total_a + weight_4;
|
||||
const auto weight_total_c_2 = weight_total_c + weight_5 + weight_6 + weight_7;
|
||||
|
||||
crosstalk_t crosstalk_values{};
|
||||
|
||||
template<typename T, blt::u32 rows, blt::u32 columns>
|
||||
blt::generalized_matrix<T, rows, columns> threshold(const blt::generalized_matrix<T, rows, columns>& y, const blt::generalized_matrix<T, rows, columns>& base)
|
||||
template<typename Weights, typename Inputs, typename Outputs>
|
||||
void execute_BAM(const Weights& weights, const Inputs& input, const Outputs& output)
|
||||
{
|
||||
blt::generalized_matrix<T, rows, columns> result;
|
||||
for (blt::u32 i = 0; i < columns; i++)
|
||||
{
|
||||
for (blt::u32 j = 0; j < rows; j++)
|
||||
result[i][j] = y[i][j] > 1 ? 1 : (y[i][j] < -1 ? -1 : base[i][j]);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
std::pair<input_t, output_t> run_step(const weight_t& associated_weights, const input_t& input, const output_t & output)
|
||||
{
|
||||
output_t output_recall = input * associated_weights;
|
||||
input_t input_recall = output * associated_weights.transpose();
|
||||
auto current_inputs = input;
|
||||
auto current_outputs = output;
|
||||
auto next_inputs = current_inputs;
|
||||
auto next_outputs = current_outputs;
|
||||
blt::size_t iterations = 0;
|
||||
constexpr blt::size_t max_iterations = 5;
|
||||
|
||||
return std::pair{threshold(input_recall, input), threshold(output_recall, output)};
|
||||
do
|
||||
{
|
||||
current_inputs = next_inputs;
|
||||
current_outputs = next_outputs;
|
||||
++iterations;
|
||||
for (const auto& [index, val] : blt::enumerate(current_inputs))
|
||||
{
|
||||
auto next = a1::run_step(weights, val, current_outputs[index]);
|
||||
next_inputs[index] = next.first;
|
||||
next_outputs[index] = next.second;
|
||||
}
|
||||
} while ((!a1::equal(current_inputs, next_inputs) || !a1::equal(current_outputs, next_outputs)) && iterations < max_iterations);
|
||||
|
||||
BLT_DEBUG("Tracked after %ld iterations", iterations);
|
||||
a1::check_recall(weights, next_inputs, next_outputs);
|
||||
}
|
||||
|
||||
void part_a()
|
||||
{
|
||||
blt::log_box_t box(BLT_TRACE_STREAM, "Part A", 8);
|
||||
|
||||
execute_BAM(weight_total_a, part_a_inputs, part_a_outputs);
|
||||
}
|
||||
|
||||
void part_b()
|
||||
|
@ -78,7 +103,7 @@ void part_b()
|
|||
{
|
||||
if (i == k)
|
||||
continue;
|
||||
accum += (outputs[k] * crosstalk(inputs[k].normalize(), inputs[i].normalize()));
|
||||
accum += (part_a_outputs[k] * a1::crosstalk(part_a_inputs[k].normalize(), part_a_inputs[i].normalize()));
|
||||
}
|
||||
crosstalk_values.assign_to_column_from_column_rows(accum, i);
|
||||
}
|
||||
|
@ -88,11 +113,20 @@ void part_b()
|
|||
}
|
||||
}
|
||||
|
||||
void part_c()
|
||||
{
|
||||
blt::log_box_t box(BLT_TRACE_STREAM, "Part C", 8);
|
||||
execute_BAM(weight_total_c, part_c_1_inputs, part_c_1_outputs);
|
||||
BLT_TRACE("--- { Part C with 3 extra pairs } ---");
|
||||
execute_BAM(weight_total_c_2, part_c_2_inputs, part_c_2_outputs);
|
||||
}
|
||||
|
||||
int main()
|
||||
{
|
||||
blt::logging::setLogOutputFormat("\033[94m[${{TIME}}]${{RC}} \033[35m(${{FILE}}:${{LINE}})${{RC}} ${{LF}}${{CNR}}${{STR}}${{RC}}\n");
|
||||
test_math();
|
||||
a1::test_math();
|
||||
|
||||
part_a();
|
||||
part_b();
|
||||
part_c();
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue