Brett 2024-11-15 19:06:12 -05:00
parent 7179be26e3
commit 6480c060c6
5 changed files with 81 additions and 11 deletions

View File

@ -1,5 +1,5 @@
cmake_minimum_required(VERSION 3.25) cmake_minimum_required(VERSION 3.25)
project(COSC-4P80-Assignment-3 VERSION 0.0.16) project(COSC-4P80-Assignment-3 VERSION 0.0.17)
include(FetchContent) include(FetchContent)
option(ENABLE_ADDRSAN "Enable the address sanitizer" OFF) option(ENABLE_ADDRSAN "Enable the address sanitizer" OFF)

View File

@ -92,9 +92,9 @@ namespace assign3
void update_graphics(); void update_graphics();
void generate_network(int selection) void regenerate_network()
{ {
som = std::make_unique<som_t>(motor_data.files[selection].normalize(), som_width, som_height, max_epochs); som = std::make_unique<som_t>(motor_data.files[currently_selected_network].normalize(), som_width, som_height, max_epochs);
} }
private: private:
@ -109,9 +109,9 @@ namespace assign3
float draw_height = 0; float draw_height = 0;
float neuron_scale = 35; float neuron_scale = 35;
blt::size_t som_width = 5; blt::i32 som_width = 5;
blt::size_t som_height = 5; blt::i32 som_height = 5;
blt::size_t max_epochs = 100; blt::i32 max_epochs = 100;
Scalar initial_learn_rate = 0.1; Scalar initial_learn_rate = 0.1;
int currently_selected_network = 0; int currently_selected_network = 0;

View File

@ -37,6 +37,8 @@ namespace assign3
void train_epoch(Scalar initial_learn_rate, topology_function_t* basis_func); void train_epoch(Scalar initial_learn_rate, topology_function_t* basis_func);
blt::vec2 get_topological_position(const std::vector<Scalar>& data);
[[nodiscard]] const array_t& get_array() const [[nodiscard]] const array_t& get_array() const
{ return array; } { return array; }

View File

@ -78,7 +78,7 @@ namespace assign3
topology_function = std::make_unique<gaussian_function_t>(); topology_function = std::make_unique<gaussian_function_t>();
generate_network(currently_selected_network); regenerate_network();
update_graphics(); update_graphics();
} }
@ -117,13 +117,21 @@ namespace assign3
ImGui::Text("Network Select"); ImGui::Text("Network Select");
if (ImGui::ListBox("##Network Select", &currently_selected_network, get_selection_string, motor_data.map_files_names.data(), if (ImGui::ListBox("##Network Select", &currently_selected_network, get_selection_string, motor_data.map_files_names.data(),
static_cast<int>(motor_data.map_files_names.size()))) static_cast<int>(motor_data.map_files_names.size())))
generate_network(currently_selected_network); regenerate_network();
if (ImGui::Button("Run Epoch")) if (ImGui::Button("Run Epoch"))
som->train_epoch(initial_learn_rate, topology_function.get()); som->train_epoch(initial_learn_rate, topology_function.get());
ImGui::Checkbox("Run to completion", &running); ImGui::Checkbox("Run to completion", &running);
ImGui::Text("Epoch %ld / %ld", som->get_current_epoch(), som->get_max_epochs()); ImGui::Text("Epoch %ld / %ld", som->get_current_epoch(), som->get_max_epochs());
} }
if (ImGui::CollapsingHeader("SOM Settings"))
{
if (ImGui::InputInt("SOM Width", &som_width) || ImGui::InputInt("SOM Height", &som_height) ||
ImGui::InputInt("Max Epochs", &max_epochs))
regenerate_network();
if (ImGui::InputFloat("Initial Learn Rate", &initial_learn_rate))
regenerate_network();
}
if (ImGui::CollapsingHeader("Debug")) if (ImGui::CollapsingHeader("Debug"))
{ {
ImGui::Checkbox("Debug Visuals", &debug_mode); ImGui::Checkbox("Debug Visuals", &debug_mode);
@ -230,11 +238,25 @@ namespace assign3
{ {
case debug_type::DATA_POINT: case debug_type::DATA_POINT:
{ {
std::vector<blt::vec2> data_positions;
for (const auto& [i, v] : blt::enumerate(file.data_points))
{
auto pos = som->get_topological_position(v.bins) * 120 + blt::vec2{370, 145};
auto color = blt::make_color(1,1,1);
float z_index = 1;
if (i == static_cast<blt::size_t>(selected_data_point))
{
color = blt::make_color(1, 0, 1);
z_index = 2;
}
br2d.drawRectangleInternal(color, blt::gfx::rectangle2d_t{pos, blt::vec2{8,8}}, z_index);
}
const auto& data_point = file.data_points[selected_data_point]; const auto& data_point = file.data_points[selected_data_point];
auto closest_type = get_neuron_activations(file); auto closest_type = get_neuron_activations(file);
draw_som(neuron_render_info_t{}.set_base_pos({370, 145}).set_neuron_scale(120).set_neuron_padding({50, 50}), draw_som(neuron_render_info_t{}.set_base_pos({370, 145}).set_neuron_scale(120).set_neuron_padding({0, 0}),
[this, &closest_type](render_data_t context) { [this, &data_point, &closest_type](render_data_t context) {
auto& text = fr2d.render_text(std::to_string(closest_type[context.index]), 13); auto& text = fr2d.render_text(std::to_string(context.neuron.dist(data_point.bins)), 18).setColor(0.2, 0.2, 0.8);
auto text_width = text.getAssociatedText().getTextWidth(); auto text_width = text.getAssociatedText().getTextWidth();
auto text_height = text.getAssociatedText().getTextHeight(); auto text_height = text.getAssociatedText().getTextHeight();
text.setPosition(context.neuron_padded - blt::vec2{text_width / 2.0f, text_height / 2.0f}).setZIndex(1); text.setPosition(context.neuron_padded - blt::vec2{text_width / 2.0f, text_height / 2.0f}).setZIndex(1);

View File

@ -92,5 +92,51 @@ namespace assign3
return distance_min; return distance_min;
} }
struct distance_data_t
{
Scalar data;
blt::size_t index;
distance_data_t(Scalar data, size_t index): data(data), index(index)
{}
inline friend bool operator<(const distance_data_t& a, const distance_data_t& b)
{
return a.data < b.data;
}
inline friend bool operator==(const distance_data_t& a, const distance_data_t& b)
{
return a.data == b.data;
}
};
blt::vec2 som_t::get_topological_position(const std::vector<Scalar>& data)
{
std::vector<distance_data_t> distances;
for (auto [i, d] : blt::enumerate(get_array().get_map()))
distances.emplace_back(d.dist(data), i);
std::sort(distances.begin(), distances.end());
auto [dist_1, ni_1] = distances[0];
auto [dist_2, ni_2] = distances[1];
auto [dist_3, ni_3] = distances[2];
float dt = dist_1 + dist_2 + dist_3;
float dp1 = dist_1 / dt;
float dp2 = dist_2 / dt;
float dp3 = dist_3 / dt;
auto n_1 = array.get_map()[ni_1];
auto n_2 = array.get_map()[ni_2];
auto n_3 = array.get_map()[ni_3];
auto p_1 = blt::vec2{n_1.get_x(), n_1.get_y()};
auto p_2 = blt::vec2{n_2.get_x(), n_2.get_y()};
auto p_3 = blt::vec2{n_3.get_x(), n_3.get_y()};
return (dp1 * p_1) + (dp2 * p_2) + (dp3 * p_3);
}
} }