From fcb28332a5ec525618e38630974741a6cf3ea393 Mon Sep 17 00:00:00 2001 From: Brett Date: Sun, 22 Sep 2024 16:53:09 -0400 Subject: [PATCH] idk if this is supposed to work so well? --- CMakeLists.txt | 2 +- include/a1.h | 106 +++++++++++++++++++++++++++++++++++++++++++++ include/fwd_decl.h | 58 ------------------------- src/main.cpp | 86 +++++++++++++++++++++++++----------- 4 files changed, 167 insertions(+), 85 deletions(-) create mode 100644 include/a1.h delete mode 100644 include/fwd_decl.h diff --git a/CMakeLists.txt b/CMakeLists.txt index 1eede30..d68395f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,5 +1,5 @@ cmake_minimum_required(VERSION 3.25) -project(COSC-4P80-Assignment-1 VERSION 0.0.7) +project(COSC-4P80-Assignment-1 VERSION 0.0.8) option(ENABLE_ADDRSAN "Enable the address sanitizer" OFF) option(ENABLE_UBSAN "Enable the ub sanitizer" OFF) diff --git a/include/a1.h b/include/a1.h new file mode 100644 index 0000000..4ac845e --- /dev/null +++ b/include/a1.h @@ -0,0 +1,106 @@ +#pragma once +/* + * Copyright (C) 2024 Brett Terpstra + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +#ifndef COSC_4P80_ASSIGNMENT_1_A1_H +#define COSC_4P80_ASSIGNMENT_1_A1_H + +#include +#include + +namespace a1 +{ + void test_math() + { + blt::generalized_matrix input{1, -1, -1, 1}; + blt::generalized_matrix output{1, 1, 1}; + blt::generalized_matrix expected{ + blt::vec4{1, -1, -1, 1}, + blt::vec4{1, -1, -1, 1}, + blt::vec4{1, -1, -1, 1} + }; + + auto w_matrix = input.transpose() * output; + BLT_ASSERT(w_matrix == expected && "MATH MATRIX FAILURE"); + + blt::vec4 one{5, 1, 3, 0}; + blt::vec4 two{9, -5, -8, 3}; + + blt::generalized_matrix g1{5, 1, 3, 0}; + blt::generalized_matrix g2{9, -5, -8, 3}; + + BLT_ASSERT(g1 * g2.transpose() == blt::vec4::dot(one, two) && "MATH DOT FAILURE"); + } + + template + float crosstalk(const input_t& i, const input_t& j) + { + return i * j.transpose(); + } + + template + blt::generalized_matrix threshold(const blt::generalized_matrix& y, + const blt::generalized_matrix& base) + { + blt::generalized_matrix result; + for (blt::u32 i = 0; i < columns; i++) + { + for (blt::u32 j = 0; j < rows; j++) + result[i][j] = y[i][j] > 1 ? 1 : (y[i][j] < -1 ? -1 : base[i][j]); + } + return result; + } + + template + bool equal(const std::array& a, const std::array& b) + { + for (const auto& [index, val] : blt::enumerate(a)) + { + if (b[index] != val) + return false; + } + return true; + } + + template + std::pair run_step(const weight_t& associated_weights, const input_t& input, const output_t& output) + { + output_t output_recall = input * associated_weights; + input_t input_recall = output * associated_weights.transpose(); + + return std::pair{a1::threshold(input_recall, input), a1::threshold(output_recall, output)}; + } + + template + void check_recall(const weight_t& weights, const std::array& inputs, const std::array& outputs) + { + for (const auto& [index, val] : blt::enumerate(inputs)) + { + auto result = run_step(weights, val, outputs[index]); + if (result.first != val) + BLT_ERROR("Recall of input #%ld failed", index + 1); + else + BLT_INFO("Recall of input #%ld passed", index + 1); + if (result.second != outputs[index]) + BLT_ERROR("Recall of output #%ld failed", index + 1); + else + BLT_INFO("recall of output #%ld passed", index + 1); + } + } +} + +#endif //COSC_4P80_ASSIGNMENT_1_A1_H diff --git a/include/fwd_decl.h b/include/fwd_decl.h deleted file mode 100644 index d824303..0000000 --- a/include/fwd_decl.h +++ /dev/null @@ -1,58 +0,0 @@ -#pragma once -/* - * Copyright (C) 2024 Brett Terpstra - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU General Public License for more details. - * - * You should have received a copy of the GNU General Public License - * along with this program. If not, see . - */ - -#ifndef COSC_4P80_ASSIGNMENT_1_FWD_DECL_H -#define COSC_4P80_ASSIGNMENT_1_FWD_DECL_H - -#include -#include - -namespace a1 -{ - void test_math() - { - blt::generalized_matrix input{1, -1, -1, 1}; - blt::generalized_matrix output{1, 1, 1}; - blt::generalized_matrix expected{ - blt::vec4{1, -1, -1, 1}, - blt::vec4{1, -1, -1, 1}, - blt::vec4{1, -1, -1, 1} - }; - - auto w_matrix = input.transpose() * output; - BLT_ASSERT(w_matrix == expected && "MATH MATRIX FAILURE"); - - blt::vec4 one{5, 1, 3, 0}; - blt::vec4 two{9, -5, -8, 3}; - - blt::generalized_matrix g1{5, 1, 3, 0}; - blt::generalized_matrix g2{9, -5, -8, 3}; - - BLT_ASSERT(g1 * g2.transpose() == blt::vec4::dot(one, two) && "MATH DOT FAILURE"); - } - - enum class recall_error_t - { - // failed to predict input - INPUT_FAILURE, - // failed to predict output - OUTPUT_FAILURE - }; -} - -#endif //COSC_4P80_ASSIGNMENT_1_FWD_DECL_H diff --git a/src/main.cpp b/src/main.cpp index 5151083..fe20233 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -3,7 +3,7 @@ #include #include "blt/std/assert.h" #include -#include +#include constexpr blt::u32 num_values = 4; constexpr blt::u32 input_count = 5; @@ -14,58 +14,83 @@ using output_t = blt::generalized_matrix; using weight_t = decltype(std::declval().transpose() * std::declval()); using crosstalk_t = blt::generalized_matrix; -float crosstalk(const input_t& i, const input_t& j) -{ - return i * j.transpose(); -} - +// part a input_t input_1{-1, 1, 1, 1, -1}; input_t input_2{-1, -1, -1, -1, 1}; input_t input_3{-1, -1, -1, 1, 1}; +// part c 1 input_t input_4{1, 1, 1, 1, 1}; +// part c 2 +input_t input_5{-1, 1, -1, 1, 1}; +input_t input_6{1, -1, 1, -1, 1}; +input_t input_7{-1, 1, -1, 1, -1}; +// part a output_t output_1{1, 1, -1, 1}; output_t output_2{1, -1, -1, -1}; output_t output_3{-1, -1, 1, 1}; +// part c 1 output_t output_4{-1, 1, 1, -1}; +// part c 2 +output_t output_5{1, 1, 1, 1}; +output_t output_6{1, -1, -1, 1}; +output_t output_7{1, 1, 1, -1}; const weight_t weight_1 = input_1.transpose() * output_1; const weight_t weight_2 = input_2.transpose() * output_2; const weight_t weight_3 = input_3.transpose() * output_3; const weight_t weight_4 = input_4.transpose() * output_4; +const weight_t weight_5 = input_5.transpose() * output_5; +const weight_t weight_6 = input_6.transpose() * output_6; +const weight_t weight_7 = input_7.transpose() * output_7; -auto starting_inputs = std::array{input_1, input_2, input_3, input_4}; -auto starting_outputs = std::array{output_1, output_2, output_3, output_4}; +auto part_a_inputs = std::array{input_1, input_2, input_3}; +auto part_a_outputs = std::array{output_1, output_2, output_3}; + +auto part_c_1_inputs = std::array{input_1, input_2, input_3, input_4}; +auto part_c_1_outputs = std::array{output_1, output_2, output_3, output_4}; + +auto part_c_2_inputs = std::array{input_1, input_2, input_3, input_4, input_5, input_6, input_7}; +auto part_c_2_outputs = std::array{output_1, output_2, output_3, output_4, output_5, output_6, output_7}; const auto weight_total_a = weight_1 + weight_2 + weight_3; const auto weight_total_c = weight_total_a + weight_4; +const auto weight_total_c_2 = weight_total_c + weight_5 + weight_6 + weight_7; crosstalk_t crosstalk_values{}; -template -blt::generalized_matrix threshold(const blt::generalized_matrix& y, const blt::generalized_matrix& base) +template +void execute_BAM(const Weights& weights, const Inputs& input, const Outputs& output) { - blt::generalized_matrix result; - for (blt::u32 i = 0; i < columns; i++) - { - for (blt::u32 j = 0; j < rows; j++) - result[i][j] = y[i][j] > 1 ? 1 : (y[i][j] < -1 ? -1 : base[i][j]); - } - return result; -} - -std::pair run_step(const weight_t& associated_weights, const input_t& input, const output_t & output) -{ - output_t output_recall = input * associated_weights; - input_t input_recall = output * associated_weights.transpose(); + auto current_inputs = input; + auto current_outputs = output; + auto next_inputs = current_inputs; + auto next_outputs = current_outputs; + blt::size_t iterations = 0; + constexpr blt::size_t max_iterations = 5; - return std::pair{threshold(input_recall, input), threshold(output_recall, output)}; + do + { + current_inputs = next_inputs; + current_outputs = next_outputs; + ++iterations; + for (const auto& [index, val] : blt::enumerate(current_inputs)) + { + auto next = a1::run_step(weights, val, current_outputs[index]); + next_inputs[index] = next.first; + next_outputs[index] = next.second; + } + } while ((!a1::equal(current_inputs, next_inputs) || !a1::equal(current_outputs, next_outputs)) && iterations < max_iterations); + + BLT_DEBUG("Tracked after %ld iterations", iterations); + a1::check_recall(weights, next_inputs, next_outputs); } void part_a() { blt::log_box_t box(BLT_TRACE_STREAM, "Part A", 8); + execute_BAM(weight_total_a, part_a_inputs, part_a_outputs); } void part_b() @@ -78,7 +103,7 @@ void part_b() { if (i == k) continue; - accum += (outputs[k] * crosstalk(inputs[k].normalize(), inputs[i].normalize())); + accum += (part_a_outputs[k] * a1::crosstalk(part_a_inputs[k].normalize(), part_a_inputs[i].normalize())); } crosstalk_values.assign_to_column_from_column_rows(accum, i); } @@ -88,11 +113,20 @@ void part_b() } } +void part_c() +{ + blt::log_box_t box(BLT_TRACE_STREAM, "Part C", 8); + execute_BAM(weight_total_c, part_c_1_inputs, part_c_1_outputs); + BLT_TRACE("--- { Part C with 3 extra pairs } ---"); + execute_BAM(weight_total_c_2, part_c_2_inputs, part_c_2_outputs); +} + int main() { blt::logging::setLogOutputFormat("\033[94m[${{TIME}}]${{RC}} \033[35m(${{FILE}}:${{LINE}})${{RC}} ${{LF}}${{CNR}}${{STR}}${{RC}}\n"); - test_math(); + a1::test_math(); part_a(); part_b(); + part_c(); }