linear regressioin

main
Brett 2024-07-19 03:16:51 -04:00
parent 088d879dc8
commit c59298d4b5
5 changed files with 251 additions and 47 deletions

View File

@ -1,5 +1,5 @@
cmake_minimum_required(VERSION 3.25) cmake_minimum_required(VERSION 3.25)
project(image-gp-6 VERSION 0.0.11) project(image-gp-6 VERSION 0.0.12)
include(FetchContent) include(FetchContent)

BIN
GSab4SWWcAA1TNR.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 22 KiB

97
include/slr.h Normal file
View File

@ -0,0 +1,97 @@
#pragma once
/*
* Copyright (C) 2024 Brett Terpstra
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
#ifndef IMAGE_GP_6_SLR_H
#define IMAGE_GP_6_SLR_H
template<typename T, blt::size_t sample_size>
float mean(const std::array<T, sample_size>& data)
{
T x = 0;
for (blt::size_t n = 0; n < sample_size; n++)
{
x = x + data[n];
}
x = x / sample_size;
return x;
}
// https://github.com/georgemaier/simple-linear-regression/blob/master/slr.cpp
template<typename T, blt::size_t sample_size>
class slr
{
private:
T WN1 = 0, WN2 = 0, WN3 = 0, WN4 = 0, Sy = 0, Sx = 0;
public:
T r = 0, rsquared = 0, alpha = 0, beta = 0, x = 0, y = 0;
T yhat = 0, ybar = 0, xbar = 0;
T SSR = 0, SSE = 0, SST = 0;
T residualSE = 0, residualmax = 0, residualmin = 0, residualmean = 0, t = 0;
T SEBeta = 0, sample = 0, residuals[sample_size]{};
slr(const std::array<T, sample_size>& datax, const std::array<T, sample_size>& datay)
{
//This is the main regression function that is called when a new SLR object is created.
//Calculate means
sample = sample_size;
xbar = mean(datax);
ybar = mean(datay);
//Calculate r correlation
for (blt::size_t n = 0; n < sample_size; ++n)
{
WN1 += (datax[n] - xbar) * (datay[n] - ybar);
WN2 += pow((datax[n] - xbar), 2);
WN3 += pow((datay[n] - ybar), 2);
}
WN4 = WN2 * WN3;
r = WN1 / (std::sqrt(WN4));
//Calculate alpha and beta
Sy = std::sqrt(WN3 / (sample_size - 1));
Sx = std::sqrt(WN2 / (sample_size - 1));
beta = r * (Sy / Sx);
alpha = ybar - beta * xbar;
//Calculate SSR, SSE, R-Squared, residuals
for (blt::size_t n = 0; n < sample_size; n++)
{
yhat = alpha + (beta * datax[n]);
SSE += std::pow((yhat - ybar), 2);
SSR += std::pow((datay[n] - yhat), 2);
residuals[n] = (datay[n] - yhat);
if (residuals[n] > residualmax)
residualmax = residuals[n];
if (residuals[n] < residualmin)
residualmin = residuals[n];
residualmean += std::fabs(residuals[n]);
}
residualmean = (residualmean / sample_size);
SST = SSR + SSE;
rsquared = SSE / SST; //Can also be obtained by r ^ 2 for simple regression (i.e. 1 independent variable)
//Calculate T-test for Beta
residualSE = std::sqrt(SSR / (sample_size - 2));
SEBeta = (residualSE / (Sx * std::sqrt(sample_size - 1)));
t = beta / SEBeta;
}
};
#endif //IMAGE_GP_6_SLR_H

@ -1 +1 @@
Subproject commit c7bb4a434b25d3c918cc7908e0b527aa0101b73d Subproject commit 8e5a3f3b7c52a361a47199108be082730b1aeddd

View File

@ -35,14 +35,23 @@
#include "opencv2/imgcodecs.hpp" #include "opencv2/imgcodecs.hpp"
#include "opencv2/imgproc.hpp" #include "opencv2/imgproc.hpp"
#include <random> #include <random>
#include "slr.h"
constexpr size_t log2(size_t n) // NOLINT
{
return ((n < 2) ? 1 : 1 + log2(n / 2));
}
static const blt::u64 SEED = std::random_device()(); static const blt::u64 SEED = std::random_device()();
static constexpr long IMAGE_SIZE = 128; static constexpr blt::size_t IMAGE_SIZE = 128;
static constexpr long IMAGE_PADDING = 16; static constexpr blt::size_t IMAGE_PADDING = 16;
static constexpr long POP_SIZE = 64; static constexpr blt::size_t POP_SIZE = 64;
static constexpr blt::size_t CHANNELS = 3; static constexpr blt::size_t CHANNELS = 3;
static constexpr blt::size_t DATA_SIZE = IMAGE_SIZE * IMAGE_SIZE; static constexpr blt::size_t DATA_SIZE = IMAGE_SIZE * IMAGE_SIZE;
static constexpr blt::size_t DATA_CHANNELS_SIZE = DATA_SIZE * CHANNELS; static constexpr blt::size_t DATA_CHANNELS_SIZE = DATA_SIZE * CHANNELS;
static constexpr blt::size_t BOX_COUNT = static_cast<blt::size_t>(log2(IMAGE_SIZE / 2));
static constexpr float THRESHOLD = 0.3;
static constexpr auto load_image = "../silly.png";
blt::gfx::matrix_state_manager global_matrices; blt::gfx::matrix_state_manager global_matrices;
blt::gfx::resource_manager resources; blt::gfx::resource_manager resources;
@ -72,11 +81,17 @@ inline context get_pop_ctx(blt::size_t i)
return ctx; return ctx;
} }
inline blt::size_t get_index(blt::size_t x, blt::size_t y)
{
return y * IMAGE_SIZE + x;
}
struct full_image_t struct full_image_t
{ {
float rgb_data[DATA_SIZE * CHANNELS]{}; float rgb_data[DATA_SIZE * CHANNELS]{};
full_image_t() { full_image_t()
{
for (auto& v : rgb_data) for (auto& v : rgb_data)
v = 0; v = 0;
} }
@ -110,19 +125,30 @@ blt::i32 time_between_runs = 100;
bool is_running = false; bool is_running = false;
blt::gp::prog_config_t config = blt::gp::prog_config_t() blt::gp::prog_config_t config = blt::gp::prog_config_t()
.set_initial_min_tree_size(2) .set_initial_min_tree_size(4)
.set_initial_max_tree_size(6) .set_initial_max_tree_size(8)
.set_elite_count(1) .set_elite_count(2)
.set_max_generations(50) .set_max_generations(50)
.set_mutation_chance(0.8) .set_mutation_chance(1.0)
.set_crossover_chance(1.0) .set_crossover_chance(1.0)
.set_reproduction_chance(0) .set_reproduction_chance(0.5)
.set_pop_size(POP_SIZE) .set_pop_size(POP_SIZE)
.set_thread_count(0); .set_thread_count(16);
blt::gp::type_provider type_system; blt::gp::type_provider type_system;
blt::gp::gp_program program{type_system, SEED, config}; blt::gp::gp_program program{type_system, SEED, config};
template<typename SINGLE_FUNC>
constexpr static auto make_single(SINGLE_FUNC&& func)
{
return [func](const full_image_t& a) {
full_image_t img{};
for (blt::size_t i = 0; i < DATA_CHANNELS_SIZE; i++)
img.rgb_data[i] = func(a.rgb_data[i]);
return img;
};
}
blt::gp::operation_t add([](const full_image_t& a, const full_image_t& b) { blt::gp::operation_t add([](const full_image_t& a, const full_image_t& b) {
full_image_t img{}; full_image_t img{};
for (blt::size_t i = 0; i < DATA_CHANNELS_SIZE; i++) for (blt::size_t i = 0; i < DATA_CHANNELS_SIZE; i++)
@ -147,30 +173,12 @@ blt::gp::operation_t pro_div([](const full_image_t& a, const full_image_t& b) {
img.rgb_data[i] = b.rgb_data[i] == 0 ? 0 : (a.rgb_data[i] / b.rgb_data[i]); img.rgb_data[i] = b.rgb_data[i] == 0 ? 0 : (a.rgb_data[i] / b.rgb_data[i]);
return img; return img;
}, "div"); }, "div");
blt::gp::operation_t op_sin([](const full_image_t& a) { blt::gp::operation_t op_sin(make_single((float (*)(float)) &std::sin), "sin");
full_image_t img{}; blt::gp::operation_t op_cos(make_single((float (*)(float)) &std::cos), "cos");
for (blt::size_t i = 0; i < DATA_CHANNELS_SIZE; i++) blt::gp::operation_t op_atan(make_single((float (*)(float)) &std::atan), "atan");
img.rgb_data[i] = std::sin(a.rgb_data[i]); blt::gp::operation_t op_exp(make_single((float (*)(float)) &std::exp), "exp");
return img; blt::gp::operation_t op_abs(make_single((float (*)(float)) &std::abs), "abs");
}, "sin"); blt::gp::operation_t op_log(make_single((float (*)(float)) &std::log), "log");
blt::gp::operation_t op_cos([](const full_image_t& a) {
full_image_t img{};
for (blt::size_t i = 0; i < DATA_CHANNELS_SIZE; i++)
img.rgb_data[i] = std::cos(a.rgb_data[i]);
return img;
}, "cos");
blt::gp::operation_t op_exp([](const full_image_t& a) {
full_image_t img{};
for (blt::size_t i = 0; i < DATA_CHANNELS_SIZE; i++)
img.rgb_data[i] = std::exp(a.rgb_data[i]);
return img;
}, "exp");
blt::gp::operation_t op_log([](const full_image_t& a) {
full_image_t img{};
for (blt::size_t i = 0; i < DATA_CHANNELS_SIZE; i++)
img.rgb_data[i] = a.rgb_data[i] == 0 ? 0 : std::log(a.rgb_data[i]);
return img;
}, "log");
blt::gp::operation_t op_v_mod([](const full_image_t& a, const full_image_t& b) { blt::gp::operation_t op_v_mod([](const full_image_t& a, const full_image_t& b) {
full_image_t img{}; full_image_t img{};
for (blt::size_t i = 0; i < DATA_CHANNELS_SIZE; i++) for (blt::size_t i = 0; i < DATA_CHANNELS_SIZE; i++)
@ -209,8 +217,15 @@ blt::gp::operation_t bitwise_xor([](const full_image_t& a, const full_image_t& b
blt::gp::operation_t lit([]() { blt::gp::operation_t lit([]() {
full_image_t img{}; full_image_t img{};
for (auto& i : img.rgb_data) auto r = program.get_random().get_float(0.0f, 1.0f);
i = program.get_random().get_float(0.0f, 1.0f); auto g = program.get_random().get_float(0.0f, 1.0f);
auto b = program.get_random().get_float(0.0f, 1.0f);
for (blt::size_t i = 0; i < DATA_SIZE; i++)
{
img.rgb_data[i * CHANNELS] = r;
img.rgb_data[i * CHANNELS + 1] = g;
img.rgb_data[i * CHANNELS + 2] = b;
}
return img; return img;
}, "lit"); }, "lit");
blt::gp::operation_t random_val([]() { blt::gp::operation_t random_val([]() {
@ -309,6 +324,80 @@ static blt::gp::operation_t op_y_b([]() {
return img; return img;
}, "y_b"); }, "y_b");
constexpr float compare_values(float a, float b)
{
if (std::isnan(a) || std::isnan(b) || std::isinf(a) || std::isinf(b))
return IMAGE_SIZE;
auto dist = a - b;
//BLT_TRACE(std::sqrt(dist * dist));
return std::sqrt(dist * dist);
}
struct fractal_stats
{
blt::f64 box_size, num_boxes, xy, x2, y2;
};
bool in_box(full_image_t& image, blt::size_t channel, blt::size_t box_size, blt::size_t i, blt::size_t j)
{
// TODO: this could be made better by starting from the smallest boxes, moving upwards and using the last set of boxes
// instead of pixels, since they contain already computed information about if a box is in foam
for (blt::size_t x = i; x < i + box_size; x++)
{
for (blt::size_t y = j; y < j + box_size; y++)
{
if (image.rgb_data[get_index(x, y) * CHANNELS + channel] > THRESHOLD)
return true;
}
}
return false;
}
blt::f64 get_fractal_value(full_image_t& image, blt::size_t channel)
{
std::array<fractal_stats, BOX_COUNT> box_data{};
std::array<double, BOX_COUNT> x_data{};
std::array<double, BOX_COUNT> y_data{};
for (blt::size_t box_size = IMAGE_SIZE / 2; box_size > 1; box_size /= 2)
{
blt::ptrdiff_t num_boxes = 0;
for (blt::size_t i = 0; i < IMAGE_SIZE; i += box_size)
{
for (blt::size_t j = 0; j < IMAGE_SIZE; j += box_size)
{
if (in_box(image, channel, box_size, i, j))
num_boxes++;
}
}
auto x = static_cast<blt::f64>(std::log2(box_size));
auto y = static_cast<blt::f64>(num_boxes == 0 ? 0 : std::log2(num_boxes));
//auto y = static_cast<blt::f64>(num_boxes);
box_data[static_cast<blt::size_t>(std::log2(box_size)) - 1] = {x, y, x * y, x * x, y * y};
x_data[static_cast<blt::size_t>(std::log2(box_size)) - 1] = x;
y_data[static_cast<blt::size_t>(std::log2(box_size)) - 1] = y;
//BLT_DEBUG("%lf vs %lf", x, y);
}
// fractal_stats total{};
// for (const auto& b : box_data)
// {
// total.box_size += b.box_size;
// total.num_boxes += b.num_boxes;
// total.xy += b.xy;
// total.x2 += b.x2;
// total.y2 += b.y2;
// }
//
// auto n = static_cast<blt::f64>(BOX_COUNT);
// auto b0 = ((total.num_boxes * total.x2) - (total.box_size * total.xy)) / ((n * total.x2) - (total.box_size * total.box_size));
// auto b1 = ((n * total.xy) - (total.box_size * total.num_boxes)) / ((n * total.x2) - (total.box_size * total.box_size));
//
// return b1;
slr count{x_data, y_data};
return count.beta;
}
constexpr auto create_fitness_function() constexpr auto create_fitness_function()
{ {
return [](blt::gp::tree_t& current_tree, blt::gp::fitness_t& fitness, blt::size_t index) { return [](blt::gp::tree_t& current_tree, blt::gp::fitness_t& fitness, blt::size_t index) {
@ -317,17 +406,23 @@ constexpr auto create_fitness_function()
fitness.raw_fitness = 0; fitness.raw_fitness = 0;
for (blt::size_t i = 0; i < DATA_CHANNELS_SIZE; i++) for (blt::size_t i = 0; i < DATA_CHANNELS_SIZE; i++)
fitness.raw_fitness += compare_values(v.rgb_data[i], base_image.rgb_data[i]);
fitness.raw_fitness /= (IMAGE_SIZE * IMAGE_SIZE);
for (blt::size_t channel = 0; channel < CHANNELS; channel++)
{ {
auto base = base_image.rgb_data[i]; auto raw = -get_fractal_value(v, channel);
auto set = v.rgb_data[i]; auto fit = 1.0 - std::max(0.0, 1.0 - std::abs(1.35 - raw));
if (std::isnan(set)) BLT_DEBUG("Fitness %lf (raw: %lf) for channel %lu", fit, raw, channel);
set = 1 - base; if (std::isnan(raw))
auto dist = set - base; fitness.raw_fitness += 400;
fitness.raw_fitness += std::sqrt(dist * dist); else
fitness.raw_fitness += raw;
} }
//BLT_TRACE("Raw fitness: %lf for %ld", fitness.raw_fitness, index); //BLT_TRACE("Raw fitness: %lf for %ld", fitness.raw_fitness, index);
fitness.standardized_fitness = fitness.raw_fitness / IMAGE_SIZE; fitness.standardized_fitness = fitness.raw_fitness;
fitness.adjusted_fitness = (1.0 / (1.0 + fitness.standardized_fitness)); fitness.adjusted_fitness = (1.0 / (1.0 + fitness.standardized_fitness));
}; };
} }
@ -370,13 +465,13 @@ void init(const blt::gfx::window_data&)
BLT_INFO("Using Seed: %ld", SEED); BLT_INFO("Using Seed: %ld", SEED);
BLT_START_INTERVAL("Image Test", "Main"); BLT_START_INTERVAL("Image Test", "Main");
BLT_DEBUG("Setup Base Image"); BLT_DEBUG("Setup Base Image");
base_image.load("../my_pride_flag.png"); base_image.load(load_image);
BLT_DEBUG("Setup Types and Operators"); BLT_DEBUG("Setup Types and Operators");
type_system.register_type<full_image_t>(); type_system.register_type<full_image_t>();
blt::gp::operator_builder<context> builder{type_system}; blt::gp::operator_builder<context> builder{type_system};
builder.add_operator(perlin); //builder.add_operator(perlin);
builder.add_operator(perlin_terminal); builder.add_operator(perlin_terminal);
builder.add_operator(add); builder.add_operator(add);
@ -385,8 +480,10 @@ void init(const blt::gfx::window_data&)
builder.add_operator(pro_div); builder.add_operator(pro_div);
builder.add_operator(op_sin); builder.add_operator(op_sin);
builder.add_operator(op_cos); builder.add_operator(op_cos);
builder.add_operator(op_atan);
builder.add_operator(op_exp); builder.add_operator(op_exp);
builder.add_operator(op_log); builder.add_operator(op_log);
builder.add_operator(op_abs);
builder.add_operator(op_v_mod); builder.add_operator(op_v_mod);
builder.add_operator(bitwise_and); builder.add_operator(bitwise_and);
builder.add_operator(bitwise_or); builder.add_operator(bitwise_or);
@ -472,6 +569,11 @@ void update(const blt::gfx::window_data& data)
if (io.WantCaptureMouse) if (io.WantCaptureMouse)
continue; continue;
if (blt::gfx::mousePressedLastFrame())
{
program.get_current_pop().get_individuals()[i].tree.print(program, std::cout, false);
}
// if (blt::gfx::mousePressedLastFrame()) // if (blt::gfx::mousePressedLastFrame())
// { // {
// if (blt::gfx::isKeyPressed(GLFW_KEY_LEFT_SHIFT)) // if (blt::gfx::isKeyPressed(GLFW_KEY_LEFT_SHIFT))
@ -515,6 +617,11 @@ int main()
BLT_END_INTERVAL("Image Test", "Main"); BLT_END_INTERVAL("Image Test", "Main");
base_image.save("input.png"); base_image.save("input.png");
for (blt::size_t i = 0; i < CHANNELS; i++)
{
auto v = -get_fractal_value(base_image, i);
BLT_INFO("Base image values per channel: %lf", v);
}
BLT_PRINT_PROFILE("Image Test", blt::PRINT_CYCLES | blt::PRINT_THREAD | blt::PRINT_WALL); BLT_PRINT_PROFILE("Image Test", blt::PRINT_CYCLES | blt::PRINT_THREAD | blt::PRINT_WALL);