diff --git a/CMakeLists.txt b/CMakeLists.txt index d68395f..8c726bc 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,5 +1,5 @@ cmake_minimum_required(VERSION 3.25) -project(COSC-4P80-Assignment-1 VERSION 0.0.8) +project(COSC-4P80-Assignment-1 VERSION 0.0.9) option(ENABLE_ADDRSAN "Enable the address sanitizer" OFF) option(ENABLE_UBSAN "Enable the ub sanitizer" OFF) diff --git a/include/a1.h b/include/a1.h index 4ac845e..30fb813 100644 --- a/include/a1.h +++ b/include/a1.h @@ -19,6 +19,7 @@ #ifndef COSC_4P80_ASSIGNMENT_1_A1_H #define COSC_4P80_ASSIGNMENT_1_A1_H +#include #include #include @@ -65,22 +66,33 @@ namespace a1 return result; } + template + blt::size_t difference(const std::array& a, const std::array& b) + { + blt::size_t count = 0; + for (const auto& [a_val, b_val] : blt::in_pairs(a, b)) + { + if (a_val != b_val) + count++; + } + return count; + } + template bool equal(const std::array& a, const std::array& b) { - for (const auto& [index, val] : blt::enumerate(a)) - { - if (b[index] != val) - return false; - } - return true; + return difference(a, b) == 0; } template std::pair run_step(const weight_t& associated_weights, const input_t& input, const output_t& output) { output_t output_recall = input * associated_weights; - input_t input_recall = output * associated_weights.transpose(); + input_t input_recall = output_recall * associated_weights.transpose(); + +// BLT_DEBUG_STREAM << "Input: " << input.vec_from_column_row() << "\nOutput: " << output.vec_from_column_row() << '\n'; +// BLT_DEBUG_STREAM << "Recalled Input: " << a1::threshold(input_recall, input).vec_from_column_row() << "\nRecalled Output: " +// << a1::threshold(output_recall, output).vec_from_column_row() << '\n'; return std::pair{a1::threshold(input_recall, input), a1::threshold(output_recall, output)}; } diff --git a/lib/blt b/lib/blt index 7300f89..96b071e 160000 --- a/lib/blt +++ b/lib/blt @@ -1 +1 @@ -Subproject commit 7300f895bb8c1e7f1c5a96866d466126ee861281 +Subproject commit 96b071e337a7124d4afe2f633f11f9be6d965ac8 diff --git a/src/main.cpp b/src/main.cpp index fe20233..0b4d999 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -80,6 +80,7 @@ void execute_BAM(const Weights& weights, const Inputs& input, const Outputs& out next_inputs[index] = next.first; next_outputs[index] = next.second; } + // loop until no changes or we hit the iteration limit } while ((!a1::equal(current_inputs, next_inputs) || !a1::equal(current_outputs, next_outputs)) && iterations < max_iterations); BLT_DEBUG("Tracked after %ld iterations", iterations);