Brett 2024-09-23 15:51:35 -04:00
parent fcb28332a5
commit 3bf43d0043
4 changed files with 22 additions and 9 deletions

View File

@ -1,5 +1,5 @@
cmake_minimum_required(VERSION 3.25) cmake_minimum_required(VERSION 3.25)
project(COSC-4P80-Assignment-1 VERSION 0.0.8) project(COSC-4P80-Assignment-1 VERSION 0.0.9)
option(ENABLE_ADDRSAN "Enable the address sanitizer" OFF) option(ENABLE_ADDRSAN "Enable the address sanitizer" OFF)
option(ENABLE_UBSAN "Enable the ub sanitizer" OFF) option(ENABLE_UBSAN "Enable the ub sanitizer" OFF)

View File

@ -19,6 +19,7 @@
#ifndef COSC_4P80_ASSIGNMENT_1_A1_H #ifndef COSC_4P80_ASSIGNMENT_1_A1_H
#define COSC_4P80_ASSIGNMENT_1_A1_H #define COSC_4P80_ASSIGNMENT_1_A1_H
#include <blt/std/logging.h>
#include <blt/math/matrix.h> #include <blt/math/matrix.h>
#include <blt/math/log_util.h> #include <blt/math/log_util.h>
@ -65,22 +66,33 @@ namespace a1
return result; return result;
} }
template<typename T, blt::size_t size>
blt::size_t difference(const std::array<T, size>& a, const std::array<T, size>& b)
{
blt::size_t count = 0;
for (const auto& [a_val, b_val] : blt::in_pairs(a, b))
{
if (a_val != b_val)
count++;
}
return count;
}
template<typename T, blt::size_t size> template<typename T, blt::size_t size>
bool equal(const std::array<T, size>& a, const std::array<T, size>& b) bool equal(const std::array<T, size>& a, const std::array<T, size>& b)
{ {
for (const auto& [index, val] : blt::enumerate(a)) return difference(a, b) == 0;
{
if (b[index] != val)
return false;
}
return true;
} }
template<typename weight_t, typename input_t, typename output_t> template<typename weight_t, typename input_t, typename output_t>
std::pair<input_t, output_t> run_step(const weight_t& associated_weights, const input_t& input, const output_t& output) std::pair<input_t, output_t> run_step(const weight_t& associated_weights, const input_t& input, const output_t& output)
{ {
output_t output_recall = input * associated_weights; output_t output_recall = input * associated_weights;
input_t input_recall = output * associated_weights.transpose(); input_t input_recall = output_recall * associated_weights.transpose();
// BLT_DEBUG_STREAM << "Input: " << input.vec_from_column_row() << "\nOutput: " << output.vec_from_column_row() << '\n';
// BLT_DEBUG_STREAM << "Recalled Input: " << a1::threshold(input_recall, input).vec_from_column_row() << "\nRecalled Output: "
// << a1::threshold(output_recall, output).vec_from_column_row() << '\n';
return std::pair{a1::threshold(input_recall, input), a1::threshold(output_recall, output)}; return std::pair{a1::threshold(input_recall, input), a1::threshold(output_recall, output)};
} }

@ -1 +1 @@
Subproject commit 7300f895bb8c1e7f1c5a96866d466126ee861281 Subproject commit 96b071e337a7124d4afe2f633f11f9be6d965ac8

View File

@ -80,6 +80,7 @@ void execute_BAM(const Weights& weights, const Inputs& input, const Outputs& out
next_inputs[index] = next.first; next_inputs[index] = next.first;
next_outputs[index] = next.second; next_outputs[index] = next.second;
} }
// loop until no changes or we hit the iteration limit
} while ((!a1::equal(current_inputs, next_inputs) || !a1::equal(current_outputs, next_outputs)) && iterations < max_iterations); } while ((!a1::equal(current_inputs, next_inputs) || !a1::equal(current_outputs, next_outputs)) && iterations < max_iterations);
BLT_DEBUG("Tracked after %ld iterations", iterations); BLT_DEBUG("Tracked after %ld iterations", iterations);