diff --git a/include/fwd_decl.h b/include/fwd_decl.h
new file mode 100644
index 0000000..d824303
--- /dev/null
+++ b/include/fwd_decl.h
@@ -0,0 +1,58 @@
+#pragma once
+/*
+ * Copyright (C) 2024 Brett Terpstra
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program. If not, see .
+ */
+
+#ifndef COSC_4P80_ASSIGNMENT_1_FWD_DECL_H
+#define COSC_4P80_ASSIGNMENT_1_FWD_DECL_H
+
+#include
+#include
+
+namespace a1
+{
+ void test_math()
+ {
+ blt::generalized_matrix input{1, -1, -1, 1};
+ blt::generalized_matrix output{1, 1, 1};
+ blt::generalized_matrix expected{
+ blt::vec4{1, -1, -1, 1},
+ blt::vec4{1, -1, -1, 1},
+ blt::vec4{1, -1, -1, 1}
+ };
+
+ auto w_matrix = input.transpose() * output;
+ BLT_ASSERT(w_matrix == expected && "MATH MATRIX FAILURE");
+
+ blt::vec4 one{5, 1, 3, 0};
+ blt::vec4 two{9, -5, -8, 3};
+
+ blt::generalized_matrix g1{5, 1, 3, 0};
+ blt::generalized_matrix g2{9, -5, -8, 3};
+
+ BLT_ASSERT(g1 * g2.transpose() == blt::vec4::dot(one, two) && "MATH DOT FAILURE");
+ }
+
+ enum class recall_error_t
+ {
+ // failed to predict input
+ INPUT_FAILURE,
+ // failed to predict output
+ OUTPUT_FAILURE
+ };
+}
+
+#endif //COSC_4P80_ASSIGNMENT_1_FWD_DECL_H
diff --git a/lib/blt b/lib/blt
index b5ea7a1..7300f89 160000
--- a/lib/blt
+++ b/lib/blt
@@ -1 +1 @@
-Subproject commit b5ea7a1e1500dc695490c730dcedbc93dae3ba73
+Subproject commit 7300f895bb8c1e7f1c5a96866d466126ee861281
diff --git a/src/main.cpp b/src/main.cpp
index 683d7ed..5151083 100644
--- a/src/main.cpp
+++ b/src/main.cpp
@@ -3,20 +3,7 @@
#include
#include "blt/std/assert.h"
#include
-
-void test_math()
-{
- blt::generalized_matrix input{1, -1, -1, 1};
- blt::generalized_matrix output{1, 1, 1};
- blt::generalized_matrix expected{
- blt::vec4{1, -1, -1, 1},
- blt::vec4{1, -1, -1, 1},
- blt::vec4{1, -1, -1, 1}
- };
-
- auto w_matrix = input.transpose() * output;
- BLT_ASSERT(w_matrix == expected && "MATH FAILURE");
-}
+#include
constexpr blt::u32 num_values = 4;
constexpr blt::u32 input_count = 5;
@@ -42,67 +29,43 @@ output_t output_2{1, -1, -1, -1};
output_t output_3{-1, -1, 1, 1};
output_t output_4{-1, 1, 1, -1};
-weight_t weight_1 = input_1.transpose() * output_1;
-weight_t weight_2 = input_2.transpose() * output_2;
-weight_t weight_3 = input_3.transpose() * output_3;
-weight_t weight_4 = input_4.transpose() * output_4;
+const weight_t weight_1 = input_1.transpose() * output_1;
+const weight_t weight_2 = input_2.transpose() * output_2;
+const weight_t weight_3 = input_3.transpose() * output_3;
+const weight_t weight_4 = input_4.transpose() * output_4;
-auto inputs = std::array{input_1, input_2, input_3, input_4};
-auto outputs = std::array{output_1, output_2, output_3, output_4};
+auto starting_inputs = std::array{input_1, input_2, input_3, input_4};
+auto starting_outputs = std::array{output_1, output_2, output_3, output_4};
-auto weight_total_a = weight_1 + weight_2 + weight_3;
-auto weight_total_c = weight_total_a + weight_4;
+const auto weight_total_a = weight_1 + weight_2 + weight_3;
+const auto weight_total_c = weight_total_a + weight_4;
crosstalk_t crosstalk_values{};
template
-blt::generalized_matrix normalize(const blt::generalized_matrix& in)
+blt::generalized_matrix threshold(const blt::generalized_matrix& y, const blt::generalized_matrix& base)
{
blt::generalized_matrix result;
for (blt::u32 i = 0; i < columns; i++)
{
for (blt::u32 j = 0; j < rows; j++)
- result[i][j] = in[i][j] >= 0 ? 1 : -1;
+ result[i][j] = y[i][j] > 1 ? 1 : (y[i][j] < -1 ? -1 : base[i][j]);
}
return result;
}
-auto calculate_recall()
+std::pair run_step(const weight_t& associated_weights, const input_t& input, const output_t & output)
{
-
-}
-
-void test_recall(blt::size_t index, const weight_t& associated_weights)
-{
- auto& input = inputs[index];
- auto& output = outputs[index];
+ output_t output_recall = input * associated_weights;
+ input_t input_recall = output * associated_weights.transpose();
- auto output_recall = normalize(input * associated_weights);
- auto input_recall = normalize(output * associated_weights.transpose());
-
- if (output_recall != output)
- {
- BLT_ERROR_STREAM << "Output '" << index + 1 << "' recalled failed!" << '\n';
- BLT_WARN_STREAM << "\t- Found: " << output_recall.vec_from_column_row() << '\n';
- BLT_WARN_STREAM << "\t- Expected: " << output.vec_from_column_row() << '\n';
- } else
- BLT_INFO("Output '%ld' recall passed!", index + 1);
-
- if (input_recall != input)
- {
- BLT_ERROR_STREAM << "Input '" << index + 1 << "' recalled failed!" << "\n";
- BLT_WARN_STREAM << "\t- Found: " << input_recall.vec_from_column_row() << '\n';
- BLT_WARN_STREAM << "\t- Expected: " << input.vec_from_column_row() << '\n';
- } else
- BLT_INFO("Input '%ld' recall passed!", index + 1);
+ return std::pair{threshold(input_recall, input), threshold(output_recall, output)};
}
void part_a()
{
blt::log_box_t box(BLT_TRACE_STREAM, "Part A", 8);
- test_recall(0, weight_total_a);
- test_recall(1, weight_total_a);
- test_recall(2, weight_total_a);
+
}
void part_b()