diff --git a/CMakeLists.txt b/CMakeLists.txt index 08fa53d..46b13e5 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,5 +1,5 @@ cmake_minimum_required(VERSION 3.25) -project(COSC-4P80-Assignment-1 VERSION 0.0.21) +project(COSC-4P80-Assignment-1 VERSION 0.0.22) option(ENABLE_ADDRSAN "Enable the address sanitizer" OFF) option(ENABLE_UBSAN "Enable the ub sanitizer" OFF) diff --git a/src/main.cpp b/src/main.cpp index 787b7bf..37a9995 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -6,6 +6,7 @@ #include #include #include +#include #include constexpr blt::u32 num_values_part_a = 3; @@ -14,10 +15,26 @@ constexpr blt::u32 num_values_part_c2 = 7; constexpr blt::u32 input_vec_size = 5; constexpr blt::u32 output_vec_size = 4; +bool print_latex = false; + using input_t = a1::matrix_t<1, input_vec_size>; using output_t = a1::matrix_t<1, output_vec_size>; using weight_t = decltype(std::declval().transpose() * std::declval()); +template +Os& print_vec_square(Os& o, const blt::vec& v) +{ + o << "["; + for (auto [i, f] : blt::enumerate(v)) + { + o << f; + if (i != size - 1) + o << ", "; + } + o << "]"; + return o; +} + struct correctness_t { blt::size_t correct_input = 0; @@ -184,6 +201,20 @@ class executor } while (steps.rbegin()[0] != steps.rbegin()[1]); } + [[nodiscard]] input_t correct(const input_t& v) const + { + // outputs here do not matter. + ping_pong current{weights, v, outputs.front()}; + ping_pong next{weights, v, outputs.front()}; + do + { + current = next; + next = current.run_step(); + // run until stability + } while (current != next); + return next.get_input(); + } + void print_execution_summary() { using namespace blt::logging; @@ -342,11 +373,6 @@ class executor { return steps.back(); } - - std::vector& get_inputs() - { - return inputs; - } private: weight_t weights; @@ -422,7 +448,7 @@ blt::size_t hdist(const input_t& a, const input_t& b) { blt::size_t diff = 0; for (auto [av, bv] : blt::in_pairs(a.vec_from_column_row(), b.vec_from_column_row())) - diff += (av == bv ? 1 : 0); + diff += (av != bv ? 1 : 0); return diff; } @@ -430,41 +456,102 @@ void part_d() { blt::log_box_t box(BLT_TRACE_STREAM, "Part D", 8); blt::random::random_t random(std::random_device{}()); - blt::size_t number_of_runs = 20; + executor cute(part_a_inputs, part_a_outputs); + constexpr blt::size_t number_of_runs = 80; + std::vector mutations; + std::vector corrections; + blt::size_t total_corrections = 0; + blt::size_t total_mutations = 0; + blt::size_t min_corrections = std::numeric_limits::max(); + blt::size_t max_corrections = 0; + blt::size_t min_mutations = std::numeric_limits::max(); + blt::size_t max_mutations = 0; for (blt::size_t run = 0; run < number_of_runs; run++) { - auto inputs = part_a_inputs; - - executor cute(part_a_inputs, part_a_outputs); - - auto pos = random.get_size_t(0, inputs.size()); - auto& input = inputs[pos]; - auto original = input; - for (blt::size_t i = 0; i < std::remove_reference_t::data_columns; i++) + auto pos = random.get_size_t(0, part_a_inputs.size()); + auto original = part_a_inputs[pos]; + auto modified = original; + for (blt::size_t i = 0; i < std::remove_reference_t::data_columns; i++) { - if (random.choice(0.8)) + if (random.choice(0.2)) { // flip value of this location - auto& d = input[i][0]; + auto& d = modified[i][0]; if (d >= 0) d = -1; else d = 1; } } - cute.get_inputs()[pos] = input; - cute.execute(); - auto corrected = cute.get_results()[pos].get_input(); + auto corrected = cute.correct(modified); - auto dist_o_m = hdist(original, input); - auto dist_m_c = hdist(input, corrected); + auto dist_o_m = hdist(original, modified); + auto dist_o_c = hdist(original, corrected); - BLT_TRACE("Run %ld mutated difference: %ld corrected difference: %ld", run, dist_o_m, dist_m_c); + corrections.push_back(dist_o_c); + mutations.push_back(dist_o_m); + total_corrections += dist_o_c; + total_mutations += dist_o_m; + + min_corrections = std::min(dist_o_c, min_corrections); + max_corrections = std::max(dist_o_c, max_corrections); + + min_mutations = std::min(dist_o_m, min_mutations); + max_mutations = std::max(dist_o_m, max_mutations); + + if (print_latex) + { + std::cout << run + 1 << " & "; + print_vec_square(std::cout, original.vec_from_column_row()) << " & "; + print_vec_square(std::cout, modified.vec_from_column_row()) << " & "; + std::cout << dist_o_m << " & "; + print_vec_square(std::cout, corrected.vec_from_column_row()) << " & "; + std::cout << dist_o_c << " \\\\ \n\\hline\n"; + } else + { + BLT_TRACE_STREAM << "Run " << run << " " << original.vec_from_column_row() << " || mutated " << modified.vec_from_column_row() + << " difference: " << dist_o_m << " || corrected " << corrected.vec_from_column_row() << " || difference: " << dist_o_c + << "\n"; + } } + double mean_corrections = static_cast(total_corrections) / number_of_runs; + double mean_mutations = static_cast(total_mutations) / number_of_runs; + + double stddev_corrections = 0; + double stddev_mutations = 0; + + for (const auto& v : corrections) + { + auto x = (static_cast(v) - mean_corrections); + stddev_corrections += x * x; + } + + for (const auto& v : mutations) + { + auto x = (static_cast(v) - mean_mutations); + stddev_mutations += x * x; + } + + stddev_corrections /= number_of_runs; + stddev_mutations /= number_of_runs; + + stddev_corrections = std::sqrt(stddev_corrections); + stddev_mutations = std::sqrt(stddev_mutations); + + std::cout << "Mean Distance Corrections: " << mean_corrections << " Stddev: " << stddev_corrections << " Min: " << min_corrections << " Max: " + << max_corrections << '\n'; + std::cout << "Mean Distance Mutations: " << mean_mutations << " Stddev: " << stddev_mutations << " Min: " << min_mutations << " Max: " + << max_mutations << '\n'; } -int main() +int main(int argc, const char** argv) { + blt::arg_parse parser; + parser.addArgument(blt::arg_builder{"--latex", "-l"}.setAction(blt::arg_action_t::STORE_TRUE).setDefault(false).build()); + + auto args = parser.parse_args(argc, argv); + print_latex = blt::arg_parse::get(args["latex"]); + blt::logging::setLogOutputFormat("\033[94m[${{TIME}}]${{RC}} \033[35m(${{FILE}}:${{LINE}})${{RC}} ${{LF}}${{CNR}}${{STR}}${{RC}}\n"); a1::test_math();