main
Brett 2024-10-07 16:28:39 -04:00
parent e3ef31ffc2
commit 1a8f828975
2 changed files with 112 additions and 25 deletions

View File

@ -1,5 +1,5 @@
cmake_minimum_required(VERSION 3.25)
project(COSC-4P80-Assignment-1 VERSION 0.0.21)
project(COSC-4P80-Assignment-1 VERSION 0.0.22)
option(ENABLE_ADDRSAN "Enable the address sanitizer" OFF)
option(ENABLE_UBSAN "Enable the ub sanitizer" OFF)

View File

@ -6,6 +6,7 @@
#include <blt/std/random.h>
#include <blt/format/boxing.h>
#include <blt/iterator/iterator.h>
#include <blt/parse/argparse.h>
#include <a1.h>
constexpr blt::u32 num_values_part_a = 3;
@ -14,10 +15,26 @@ constexpr blt::u32 num_values_part_c2 = 7;
constexpr blt::u32 input_vec_size = 5;
constexpr blt::u32 output_vec_size = 4;
bool print_latex = false;
using input_t = a1::matrix_t<1, input_vec_size>;
using output_t = a1::matrix_t<1, output_vec_size>;
using weight_t = decltype(std::declval<input_t>().transpose() * std::declval<output_t>());
template<typename Os, typename T, blt::u32 size>
Os& print_vec_square(Os& o, const blt::vec<T, size>& v)
{
o << "[";
for (auto [i, f] : blt::enumerate(v))
{
o << f;
if (i != size - 1)
o << ", ";
}
o << "]";
return o;
}
struct correctness_t
{
blt::size_t correct_input = 0;
@ -184,6 +201,20 @@ class executor
} while (steps.rbegin()[0] != steps.rbegin()[1]);
}
[[nodiscard]] input_t correct(const input_t& v) const
{
// outputs here do not matter.
ping_pong current{weights, v, outputs.front()};
ping_pong next{weights, v, outputs.front()};
do
{
current = next;
next = current.run_step();
// run until stability
} while (current != next);
return next.get_input();
}
void print_execution_summary()
{
using namespace blt::logging;
@ -343,11 +374,6 @@ class executor
return steps.back();
}
std::vector<input_t>& get_inputs()
{
return inputs;
}
private:
weight_t weights;
std::vector<input_t> inputs;
@ -422,7 +448,7 @@ blt::size_t hdist(const input_t& a, const input_t& b)
{
blt::size_t diff = 0;
for (auto [av, bv] : blt::in_pairs(a.vec_from_column_row(), b.vec_from_column_row()))
diff += (av == bv ? 1 : 0);
diff += (av != bv ? 1 : 0);
return diff;
}
@ -430,41 +456,102 @@ void part_d()
{
blt::log_box_t box(BLT_TRACE_STREAM, "Part D", 8);
blt::random::random_t random(std::random_device{}());
blt::size_t number_of_runs = 20;
executor cute(part_a_inputs, part_a_outputs);
constexpr blt::size_t number_of_runs = 80;
std::vector<blt::size_t> mutations;
std::vector<blt::size_t> corrections;
blt::size_t total_corrections = 0;
blt::size_t total_mutations = 0;
blt::size_t min_corrections = std::numeric_limits<blt::size_t>::max();
blt::size_t max_corrections = 0;
blt::size_t min_mutations = std::numeric_limits<blt::size_t>::max();
blt::size_t max_mutations = 0;
for (blt::size_t run = 0; run < number_of_runs; run++)
{
auto inputs = part_a_inputs;
executor cute(part_a_inputs, part_a_outputs);
auto pos = random.get_size_t(0, inputs.size());
auto& input = inputs[pos];
auto original = input;
for (blt::size_t i = 0; i < std::remove_reference_t<decltype(input)>::data_columns; i++)
auto pos = random.get_size_t(0, part_a_inputs.size());
auto original = part_a_inputs[pos];
auto modified = original;
for (blt::size_t i = 0; i < std::remove_reference_t<decltype(modified)>::data_columns; i++)
{
if (random.choice(0.8))
if (random.choice(0.2))
{
// flip value of this location
auto& d = input[i][0];
auto& d = modified[i][0];
if (d >= 0)
d = -1;
else
d = 1;
}
}
cute.get_inputs()[pos] = input;
cute.execute();
auto corrected = cute.get_results()[pos].get_input();
auto corrected = cute.correct(modified);
auto dist_o_m = hdist(original, input);
auto dist_m_c = hdist(input, corrected);
auto dist_o_m = hdist(original, modified);
auto dist_o_c = hdist(original, corrected);
BLT_TRACE("Run %ld mutated difference: %ld corrected difference: %ld", run, dist_o_m, dist_m_c);
}
}
corrections.push_back(dist_o_c);
mutations.push_back(dist_o_m);
total_corrections += dist_o_c;
total_mutations += dist_o_m;
int main()
min_corrections = std::min(dist_o_c, min_corrections);
max_corrections = std::max(dist_o_c, max_corrections);
min_mutations = std::min(dist_o_m, min_mutations);
max_mutations = std::max(dist_o_m, max_mutations);
if (print_latex)
{
std::cout << run + 1 << " & ";
print_vec_square(std::cout, original.vec_from_column_row()) << " & ";
print_vec_square(std::cout, modified.vec_from_column_row()) << " & ";
std::cout << dist_o_m << " & ";
print_vec_square(std::cout, corrected.vec_from_column_row()) << " & ";
std::cout << dist_o_c << " \\\\ \n\\hline\n";
} else
{
BLT_TRACE_STREAM << "Run " << run << " " << original.vec_from_column_row() << " || mutated " << modified.vec_from_column_row()
<< " difference: " << dist_o_m << " || corrected " << corrected.vec_from_column_row() << " || difference: " << dist_o_c
<< "\n";
}
}
double mean_corrections = static_cast<double>(total_corrections) / number_of_runs;
double mean_mutations = static_cast<double>(total_mutations) / number_of_runs;
double stddev_corrections = 0;
double stddev_mutations = 0;
for (const auto& v : corrections)
{
auto x = (static_cast<double>(v) - mean_corrections);
stddev_corrections += x * x;
}
for (const auto& v : mutations)
{
auto x = (static_cast<double>(v) - mean_mutations);
stddev_mutations += x * x;
}
stddev_corrections /= number_of_runs;
stddev_mutations /= number_of_runs;
stddev_corrections = std::sqrt(stddev_corrections);
stddev_mutations = std::sqrt(stddev_mutations);
std::cout << "Mean Distance Corrections: " << mean_corrections << " Stddev: " << stddev_corrections << " Min: " << min_corrections << " Max: "
<< max_corrections << '\n';
std::cout << "Mean Distance Mutations: " << mean_mutations << " Stddev: " << stddev_mutations << " Min: " << min_mutations << " Max: "
<< max_mutations << '\n';
}
int main(int argc, const char** argv)
{
blt::arg_parse parser;
parser.addArgument(blt::arg_builder{"--latex", "-l"}.setAction(blt::arg_action_t::STORE_TRUE).setDefault(false).build());
auto args = parser.parse_args(argc, argv);
print_latex = blt::arg_parse::get<bool>(args["latex"]);
blt::logging::setLogOutputFormat("\033[94m[${{TIME}}]${{RC}} \033[35m(${{FILE}}:${{LINE}})${{RC}} ${{LF}}${{CNR}}${{STR}}${{RC}}\n");
a1::test_math();