diff --git a/CMakeLists.txt b/CMakeLists.txt
index 1eede30..d68395f 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -1,5 +1,5 @@
cmake_minimum_required(VERSION 3.25)
-project(COSC-4P80-Assignment-1 VERSION 0.0.7)
+project(COSC-4P80-Assignment-1 VERSION 0.0.8)
option(ENABLE_ADDRSAN "Enable the address sanitizer" OFF)
option(ENABLE_UBSAN "Enable the ub sanitizer" OFF)
diff --git a/include/a1.h b/include/a1.h
new file mode 100644
index 0000000..4ac845e
--- /dev/null
+++ b/include/a1.h
@@ -0,0 +1,106 @@
+#pragma once
+/*
+ * Copyright (C) 2024 Brett Terpstra
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program. If not, see .
+ */
+
+#ifndef COSC_4P80_ASSIGNMENT_1_A1_H
+#define COSC_4P80_ASSIGNMENT_1_A1_H
+
+#include
+#include
+
+namespace a1
+{
+ void test_math()
+ {
+ blt::generalized_matrix input{1, -1, -1, 1};
+ blt::generalized_matrix output{1, 1, 1};
+ blt::generalized_matrix expected{
+ blt::vec4{1, -1, -1, 1},
+ blt::vec4{1, -1, -1, 1},
+ blt::vec4{1, -1, -1, 1}
+ };
+
+ auto w_matrix = input.transpose() * output;
+ BLT_ASSERT(w_matrix == expected && "MATH MATRIX FAILURE");
+
+ blt::vec4 one{5, 1, 3, 0};
+ blt::vec4 two{9, -5, -8, 3};
+
+ blt::generalized_matrix g1{5, 1, 3, 0};
+ blt::generalized_matrix g2{9, -5, -8, 3};
+
+ BLT_ASSERT(g1 * g2.transpose() == blt::vec4::dot(one, two) && "MATH DOT FAILURE");
+ }
+
+ template
+ float crosstalk(const input_t& i, const input_t& j)
+ {
+ return i * j.transpose();
+ }
+
+ template
+ blt::generalized_matrix threshold(const blt::generalized_matrix& y,
+ const blt::generalized_matrix& base)
+ {
+ blt::generalized_matrix result;
+ for (blt::u32 i = 0; i < columns; i++)
+ {
+ for (blt::u32 j = 0; j < rows; j++)
+ result[i][j] = y[i][j] > 1 ? 1 : (y[i][j] < -1 ? -1 : base[i][j]);
+ }
+ return result;
+ }
+
+ template
+ bool equal(const std::array& a, const std::array& b)
+ {
+ for (const auto& [index, val] : blt::enumerate(a))
+ {
+ if (b[index] != val)
+ return false;
+ }
+ return true;
+ }
+
+ template
+ std::pair run_step(const weight_t& associated_weights, const input_t& input, const output_t& output)
+ {
+ output_t output_recall = input * associated_weights;
+ input_t input_recall = output * associated_weights.transpose();
+
+ return std::pair{a1::threshold(input_recall, input), a1::threshold(output_recall, output)};
+ }
+
+ template
+ void check_recall(const weight_t& weights, const std::array& inputs, const std::array& outputs)
+ {
+ for (const auto& [index, val] : blt::enumerate(inputs))
+ {
+ auto result = run_step(weights, val, outputs[index]);
+ if (result.first != val)
+ BLT_ERROR("Recall of input #%ld failed", index + 1);
+ else
+ BLT_INFO("Recall of input #%ld passed", index + 1);
+ if (result.second != outputs[index])
+ BLT_ERROR("Recall of output #%ld failed", index + 1);
+ else
+ BLT_INFO("recall of output #%ld passed", index + 1);
+ }
+ }
+}
+
+#endif //COSC_4P80_ASSIGNMENT_1_A1_H
diff --git a/include/fwd_decl.h b/include/fwd_decl.h
deleted file mode 100644
index d824303..0000000
--- a/include/fwd_decl.h
+++ /dev/null
@@ -1,58 +0,0 @@
-#pragma once
-/*
- * Copyright (C) 2024 Brett Terpstra
- *
- * This program is free software: you can redistribute it and/or modify
- * it under the terms of the GNU General Public License as published by
- * the Free Software Foundation, either version 3 of the License, or
- * (at your option) any later version.
- *
- * This program is distributed in the hope that it will be useful,
- * but WITHOUT ANY WARRANTY; without even the implied warranty of
- * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
- * GNU General Public License for more details.
- *
- * You should have received a copy of the GNU General Public License
- * along with this program. If not, see .
- */
-
-#ifndef COSC_4P80_ASSIGNMENT_1_FWD_DECL_H
-#define COSC_4P80_ASSIGNMENT_1_FWD_DECL_H
-
-#include
-#include
-
-namespace a1
-{
- void test_math()
- {
- blt::generalized_matrix input{1, -1, -1, 1};
- blt::generalized_matrix output{1, 1, 1};
- blt::generalized_matrix expected{
- blt::vec4{1, -1, -1, 1},
- blt::vec4{1, -1, -1, 1},
- blt::vec4{1, -1, -1, 1}
- };
-
- auto w_matrix = input.transpose() * output;
- BLT_ASSERT(w_matrix == expected && "MATH MATRIX FAILURE");
-
- blt::vec4 one{5, 1, 3, 0};
- blt::vec4 two{9, -5, -8, 3};
-
- blt::generalized_matrix g1{5, 1, 3, 0};
- blt::generalized_matrix g2{9, -5, -8, 3};
-
- BLT_ASSERT(g1 * g2.transpose() == blt::vec4::dot(one, two) && "MATH DOT FAILURE");
- }
-
- enum class recall_error_t
- {
- // failed to predict input
- INPUT_FAILURE,
- // failed to predict output
- OUTPUT_FAILURE
- };
-}
-
-#endif //COSC_4P80_ASSIGNMENT_1_FWD_DECL_H
diff --git a/src/main.cpp b/src/main.cpp
index 5151083..fe20233 100644
--- a/src/main.cpp
+++ b/src/main.cpp
@@ -3,7 +3,7 @@
#include
#include "blt/std/assert.h"
#include
-#include
+#include
constexpr blt::u32 num_values = 4;
constexpr blt::u32 input_count = 5;
@@ -14,58 +14,83 @@ using output_t = blt::generalized_matrix;
using weight_t = decltype(std::declval().transpose() * std::declval());
using crosstalk_t = blt::generalized_matrix;
-float crosstalk(const input_t& i, const input_t& j)
-{
- return i * j.transpose();
-}
-
+// part a
input_t input_1{-1, 1, 1, 1, -1};
input_t input_2{-1, -1, -1, -1, 1};
input_t input_3{-1, -1, -1, 1, 1};
+// part c 1
input_t input_4{1, 1, 1, 1, 1};
+// part c 2
+input_t input_5{-1, 1, -1, 1, 1};
+input_t input_6{1, -1, 1, -1, 1};
+input_t input_7{-1, 1, -1, 1, -1};
+// part a
output_t output_1{1, 1, -1, 1};
output_t output_2{1, -1, -1, -1};
output_t output_3{-1, -1, 1, 1};
+// part c 1
output_t output_4{-1, 1, 1, -1};
+// part c 2
+output_t output_5{1, 1, 1, 1};
+output_t output_6{1, -1, -1, 1};
+output_t output_7{1, 1, 1, -1};
const weight_t weight_1 = input_1.transpose() * output_1;
const weight_t weight_2 = input_2.transpose() * output_2;
const weight_t weight_3 = input_3.transpose() * output_3;
const weight_t weight_4 = input_4.transpose() * output_4;
+const weight_t weight_5 = input_5.transpose() * output_5;
+const weight_t weight_6 = input_6.transpose() * output_6;
+const weight_t weight_7 = input_7.transpose() * output_7;
-auto starting_inputs = std::array{input_1, input_2, input_3, input_4};
-auto starting_outputs = std::array{output_1, output_2, output_3, output_4};
+auto part_a_inputs = std::array{input_1, input_2, input_3};
+auto part_a_outputs = std::array{output_1, output_2, output_3};
+
+auto part_c_1_inputs = std::array{input_1, input_2, input_3, input_4};
+auto part_c_1_outputs = std::array{output_1, output_2, output_3, output_4};
+
+auto part_c_2_inputs = std::array{input_1, input_2, input_3, input_4, input_5, input_6, input_7};
+auto part_c_2_outputs = std::array{output_1, output_2, output_3, output_4, output_5, output_6, output_7};
const auto weight_total_a = weight_1 + weight_2 + weight_3;
const auto weight_total_c = weight_total_a + weight_4;
+const auto weight_total_c_2 = weight_total_c + weight_5 + weight_6 + weight_7;
crosstalk_t crosstalk_values{};
-template
-blt::generalized_matrix threshold(const blt::generalized_matrix& y, const blt::generalized_matrix& base)
+template
+void execute_BAM(const Weights& weights, const Inputs& input, const Outputs& output)
{
- blt::generalized_matrix result;
- for (blt::u32 i = 0; i < columns; i++)
- {
- for (blt::u32 j = 0; j < rows; j++)
- result[i][j] = y[i][j] > 1 ? 1 : (y[i][j] < -1 ? -1 : base[i][j]);
- }
- return result;
-}
-
-std::pair run_step(const weight_t& associated_weights, const input_t& input, const output_t & output)
-{
- output_t output_recall = input * associated_weights;
- input_t input_recall = output * associated_weights.transpose();
+ auto current_inputs = input;
+ auto current_outputs = output;
+ auto next_inputs = current_inputs;
+ auto next_outputs = current_outputs;
+ blt::size_t iterations = 0;
+ constexpr blt::size_t max_iterations = 5;
- return std::pair{threshold(input_recall, input), threshold(output_recall, output)};
+ do
+ {
+ current_inputs = next_inputs;
+ current_outputs = next_outputs;
+ ++iterations;
+ for (const auto& [index, val] : blt::enumerate(current_inputs))
+ {
+ auto next = a1::run_step(weights, val, current_outputs[index]);
+ next_inputs[index] = next.first;
+ next_outputs[index] = next.second;
+ }
+ } while ((!a1::equal(current_inputs, next_inputs) || !a1::equal(current_outputs, next_outputs)) && iterations < max_iterations);
+
+ BLT_DEBUG("Tracked after %ld iterations", iterations);
+ a1::check_recall(weights, next_inputs, next_outputs);
}
void part_a()
{
blt::log_box_t box(BLT_TRACE_STREAM, "Part A", 8);
+ execute_BAM(weight_total_a, part_a_inputs, part_a_outputs);
}
void part_b()
@@ -78,7 +103,7 @@ void part_b()
{
if (i == k)
continue;
- accum += (outputs[k] * crosstalk(inputs[k].normalize(), inputs[i].normalize()));
+ accum += (part_a_outputs[k] * a1::crosstalk(part_a_inputs[k].normalize(), part_a_inputs[i].normalize()));
}
crosstalk_values.assign_to_column_from_column_rows(accum, i);
}
@@ -88,11 +113,20 @@ void part_b()
}
}
+void part_c()
+{
+ blt::log_box_t box(BLT_TRACE_STREAM, "Part C", 8);
+ execute_BAM(weight_total_c, part_c_1_inputs, part_c_1_outputs);
+ BLT_TRACE("--- { Part C with 3 extra pairs } ---");
+ execute_BAM(weight_total_c_2, part_c_2_inputs, part_c_2_outputs);
+}
+
int main()
{
blt::logging::setLogOutputFormat("\033[94m[${{TIME}}]${{RC}} \033[35m(${{FILE}}:${{LINE}})${{RC}} ${{LF}}${{CNR}}${{STR}}${{RC}}\n");
- test_math();
+ a1::test_math();
part_a();
part_b();
+ part_c();
}