main
Brett 2024-11-06 16:14:48 -05:00
parent 7b2bd7679a
commit a69e3274dc
5 changed files with 55 additions and 36 deletions

View File

@ -1,5 +1,6 @@
cmake_minimum_required(VERSION 3.25) cmake_minimum_required(VERSION 3.25)
project(COSC-4P80-Assignment-3 VERSION 0.0.7) project(COSC-4P80-Assignment-3 VERSION 0.0.8)
include(FetchContent)
option(ENABLE_ADDRSAN "Enable the address sanitizer" OFF) option(ENABLE_ADDRSAN "Enable the address sanitizer" OFF)
option(ENABLE_UBSAN "Enable the ub sanitizer" OFF) option(ENABLE_UBSAN "Enable the ub sanitizer" OFF)
@ -7,12 +8,20 @@ option(ENABLE_TSAN "Enable the thread data race sanitizer" OFF)
set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_STANDARD 17)
FetchContent_Declare(implot
GIT_REPOSITORY https://github.com/epezent/implot
GIT_TAG 419a8a0f5fcb77e1e7c19ab540441686bfe21bca
FIND_PACKAGE_ARGS)
FetchContent_MakeAvailable(implot)
add_subdirectory(lib/blt-with-graphics) add_subdirectory(lib/blt-with-graphics)
include_directories(include/) include_directories(include/)
include_directories(${implot_SOURCE_DIR})
file(GLOB_RECURSE PROJECT_BUILD_FILES "${CMAKE_CURRENT_SOURCE_DIR}/src/*.cpp") file(GLOB_RECURSE PROJECT_BUILD_FILES "${CMAKE_CURRENT_SOURCE_DIR}/src/*.cpp")
file(GLOB IM_PLOT_FILES "${implot_SOURCE_DIR}/*.cpp")
add_executable(COSC-4P80-Assignment-3 ${PROJECT_BUILD_FILES}) add_executable(COSC-4P80-Assignment-3 ${PROJECT_BUILD_FILES} ${IM_PLOT_FILES})
target_compile_options(COSC-4P80-Assignment-3 PRIVATE -Wall -Wextra -Wpedantic -Wno-comment) target_compile_options(COSC-4P80-Assignment-3 PRIVATE -Wall -Wextra -Wpedantic -Wno-comment)
target_link_options(COSC-4P80-Assignment-3 PRIVATE -Wall -Wextra -Wpedantic -Wno-comment) target_link_options(COSC-4P80-Assignment-3 PRIVATE -Wall -Wextra -Wpedantic -Wno-comment)

View File

@ -33,7 +33,7 @@ namespace assign3
{ {
for (blt::size_t i = 0; i < width; i++) for (blt::size_t i = 0; i < width; i++)
for (blt::size_t j = 0; j < height; j++) for (blt::size_t j = 0; j < height; j++)
map.emplace_back(dimensions, i, j); map.emplace_back(dimensions, (j % 2 == 0 ? static_cast<Scalar>(i) : static_cast<Scalar>(i) + 0.5f), j);
} }
inline neuron_t& get(blt::size_t x, blt::size_t y) inline neuron_t& get(blt::size_t x, blt::size_t y)

View File

@ -19,6 +19,11 @@ blt::gfx::resource_manager resources;
blt::gfx::batch_renderer_2d renderer_2d(resources, global_matrices); blt::gfx::batch_renderer_2d renderer_2d(resources, global_matrices);
blt::gfx::first_person_camera_2d camera; blt::gfx::first_person_camera_2d camera;
blt::size_t som_width = 7;
blt::size_t som_height = 7;
blt::size_t max_epochs = 100;
Scalar initial_learn_rate = 0.1;
void init(const blt::gfx::window_data&) void init(const blt::gfx::window_data&)
{ {
using namespace blt::gfx; using namespace blt::gfx;
@ -28,10 +33,9 @@ void init(const blt::gfx::window_data&)
resources.load_resources(); resources.load_resources();
renderer_2d.create(); renderer_2d.create();
blt::size_t size = 5;
som = std::make_unique<som_t>( som = std::make_unique<som_t>(
*std::find_if(files.begin(), files.end(), [](const data_file_t& v) { return v.data_points.begin()->bins.size() == 32; }), *std::find_if(files.begin(), files.end(), [](const data_file_t& v) { return v.data_points.begin()->bins.size() == 32; }),
size, size, 100); som_width, som_height, max_epochs);
} }
void update(const blt::gfx::window_data& data) void update(const blt::gfx::window_data& data)
@ -45,47 +49,51 @@ void update(const blt::gfx::window_data& data)
if (ImGui::Begin("Controls")) if (ImGui::Begin("Controls"))
{ {
ImGui::Button("Run Epoch"); if (ImGui::Button("Run Epoch"))
if (ImGui::IsItemClicked())
{ {
static gaussian_function_t func; static gaussian_function_t func;
som->train_epoch(0.1, &func); som->train_epoch(initial_learn_rate, &func);
} }
static bool run;
ImGui::Checkbox("Run to completion", &run);
if (run)
{
static gaussian_function_t func;
if (som->get_current_epoch() < som->get_max_epochs())
som->train_epoch(initial_learn_rate, &func);
}
ImGui::Text("Epoch %ld / %ld", som->get_current_epoch(), som->get_max_epochs());
} }
ImGui::End(); ImGui::End();
static std::vector<blt::i64> activations;
activations.clear();
activations.resize(som->get_array().get_map().size());
auto& meow = *std::find_if(files.begin(), files.end(), [](const data_file_t& v) { return v.data_points.begin()->bins.size() == 32; }); auto& meow = *std::find_if(files.begin(), files.end(), [](const data_file_t& v) { return v.data_points.begin()->bins.size() == 32; });
for (auto& v : som->get_array().get_map()) for (auto& v : meow.data_points)
{
auto nearest = som->get_closest_neuron(v.bins);
activations[nearest] += v.is_bad ? -1 : 1;
}
blt::i64 max = *std::max_element(activations.begin(), activations.end());
blt::i64 min = *std::min_element(activations.begin(), activations.end());
for (auto [i, v] : blt::enumerate(som->get_array().get_map()))
{ {
float scale = 35; float scale = 35;
float total_good_distance = 0; auto activation = activations[i];
float total_bad_distance = 0;
float total_goods = 0;
float total_bads = 0;
for (auto& point : meow.data_points) blt::vec4 color = blt::make_color(1,1,1);
{ if (activation > 0)
auto dist = v.dist(point.bins); color = blt::make_color(0, static_cast<Scalar>(activation) / static_cast<Scalar>(max), 0);
if (point.is_bad) else if (activation < 0)
{ color = blt::make_color(std::abs(static_cast<Scalar>(activation) / static_cast<Scalar>(min)), 0, 0);
total_bads++;
total_bad_distance += dist;
} else
{
total_goods++;
total_good_distance += dist;
}
}
float good_ratio = total_goods > 0 ? total_good_distance / total_goods : 0; renderer_2d.drawPointInternal(color, point2d_t{v.get_x() * scale + scale, v.get_y() * scale + scale, scale});
float bad_ratio = total_bads > 0 ? total_bad_distance / total_bads : 0;
float good_to_bad = total_good_distance / total_bad_distance;
BLT_TRACE("%f %f %f", good_ratio, bad_ratio, good_to_bad);
renderer_2d.drawPointInternal(blt::make_color(good_ratio, bad_ratio, good_to_bad),
point2d_t{v.get_x() * scale + scale, v.get_y() * scale + scale, scale});
} }
renderer_2d.render(data.width, data.height); renderer_2d.render(data.width, data.height);

View File

@ -19,6 +19,7 @@
#include <blt/std/random.h> #include <blt/std/random.h>
#include <blt/iterator/iterator.h> #include <blt/iterator/iterator.h>
#include <cmath> #include <cmath>
#include "blt/std/logging.h"
namespace assign3 namespace assign3
{ {
@ -37,8 +38,8 @@ namespace assign3
static thread_local std::vector<Scalar> diff; static thread_local std::vector<Scalar> diff;
diff.clear(); diff.clear();
for (auto [x, v] : blt::in_pairs(new_data, data)) for (auto [v, x] : blt::in_pairs(data, new_data))
diff.push_back(v - x); diff.push_back(x - v);
for (auto [v, d] : blt::in_pairs(data, diff)) for (auto [v, d] : blt::in_pairs(data, diff))
v += eta * dist * d; v += eta * dist * d;

View File

@ -20,6 +20,7 @@
#include <algorithm> #include <algorithm>
#include <blt/std/random.h> #include <blt/std/random.h>
#include <blt/iterator/enumerate.h> #include <blt/iterator/enumerate.h>
#include <blt/std/logging.h>
namespace assign3 namespace assign3
{ {