pairs
parent
fcb28332a5
commit
3bf43d0043
|
@ -1,5 +1,5 @@
|
|||
cmake_minimum_required(VERSION 3.25)
|
||||
project(COSC-4P80-Assignment-1 VERSION 0.0.8)
|
||||
project(COSC-4P80-Assignment-1 VERSION 0.0.9)
|
||||
|
||||
option(ENABLE_ADDRSAN "Enable the address sanitizer" OFF)
|
||||
option(ENABLE_UBSAN "Enable the ub sanitizer" OFF)
|
||||
|
|
26
include/a1.h
26
include/a1.h
|
@ -19,6 +19,7 @@
|
|||
#ifndef COSC_4P80_ASSIGNMENT_1_A1_H
|
||||
#define COSC_4P80_ASSIGNMENT_1_A1_H
|
||||
|
||||
#include <blt/std/logging.h>
|
||||
#include <blt/math/matrix.h>
|
||||
#include <blt/math/log_util.h>
|
||||
|
||||
|
@ -65,22 +66,33 @@ namespace a1
|
|||
return result;
|
||||
}
|
||||
|
||||
template<typename T, blt::size_t size>
|
||||
blt::size_t difference(const std::array<T, size>& a, const std::array<T, size>& b)
|
||||
{
|
||||
blt::size_t count = 0;
|
||||
for (const auto& [a_val, b_val] : blt::in_pairs(a, b))
|
||||
{
|
||||
if (a_val != b_val)
|
||||
count++;
|
||||
}
|
||||
return count;
|
||||
}
|
||||
|
||||
template<typename T, blt::size_t size>
|
||||
bool equal(const std::array<T, size>& a, const std::array<T, size>& b)
|
||||
{
|
||||
for (const auto& [index, val] : blt::enumerate(a))
|
||||
{
|
||||
if (b[index] != val)
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
return difference(a, b) == 0;
|
||||
}
|
||||
|
||||
template<typename weight_t, typename input_t, typename output_t>
|
||||
std::pair<input_t, output_t> run_step(const weight_t& associated_weights, const input_t& input, const output_t& output)
|
||||
{
|
||||
output_t output_recall = input * associated_weights;
|
||||
input_t input_recall = output * associated_weights.transpose();
|
||||
input_t input_recall = output_recall * associated_weights.transpose();
|
||||
|
||||
// BLT_DEBUG_STREAM << "Input: " << input.vec_from_column_row() << "\nOutput: " << output.vec_from_column_row() << '\n';
|
||||
// BLT_DEBUG_STREAM << "Recalled Input: " << a1::threshold(input_recall, input).vec_from_column_row() << "\nRecalled Output: "
|
||||
// << a1::threshold(output_recall, output).vec_from_column_row() << '\n';
|
||||
|
||||
return std::pair{a1::threshold(input_recall, input), a1::threshold(output_recall, output)};
|
||||
}
|
||||
|
|
2
lib/blt
2
lib/blt
|
@ -1 +1 @@
|
|||
Subproject commit 7300f895bb8c1e7f1c5a96866d466126ee861281
|
||||
Subproject commit 96b071e337a7124d4afe2f633f11f9be6d965ac8
|
|
@ -80,6 +80,7 @@ void execute_BAM(const Weights& weights, const Inputs& input, const Outputs& out
|
|||
next_inputs[index] = next.first;
|
||||
next_outputs[index] = next.second;
|
||||
}
|
||||
// loop until no changes or we hit the iteration limit
|
||||
} while ((!a1::equal(current_inputs, next_inputs) || !a1::equal(current_outputs, next_outputs)) && iterations < max_iterations);
|
||||
|
||||
BLT_DEBUG("Tracked after %ld iterations", iterations);
|
||||
|
|
Loading…
Reference in New Issue