getting closer!

main
Brett 2024-10-25 01:22:32 -04:00
parent 1b79238114
commit 4216b53b28
14 changed files with 229 additions and 76 deletions

3
.gitmodules vendored
View File

@ -1,3 +1,6 @@
[submodule "lib/blt"]
path = lib/blt
url = https://github.com/Tri11Paragon/BLT.git
[submodule "lib/blt-graphics"]
path = lib/blt-graphics
url = https://git.tpgc.me/tri11paragon/BLT-With-Graphics-Template

View File

@ -3,6 +3,7 @@
<component name="VcsDirectoryMappings">
<mapping directory="" vcs="Git" />
<mapping directory="$PROJECT_DIR$/lib/blt" vcs="Git" />
<mapping directory="$PROJECT_DIR$/lib/blt-graphics" vcs="Git" />
<mapping directory="$PROJECT_DIR$/lib/blt/libraries/parallel-hashmap" vcs="Git" />
</component>
</project>

9
1.txt Normal file
View File

@ -0,0 +1,9 @@
Bias: 0.0883882 0.498606 -0.192237 -0.225643 -0.202986 0.337387 0.492392 -0.378505
Weights: 0.0883882, 0.498606, -0.192237, -0.225643, -0.202986, 0.337387, 0.492392, -0.378505, 0.174707, 0.289964, -0.176566, -0.0890466, 0.464011, 0.0977563, -0.0747116, -0.179186, -0.387672, 0.156685, 0.396763, 0.163871, -0.347555, 0.161215, 0.40124, 0.121386, -0.0582743, -0.178808, 0.443577, 0.0854848, 0.0230515, 0.343403, 0.201804, 0.101254, 0.413756, 0.228725, -0.392687, 0.305022, -0.0608484, 0.0060341, 0.393663, -0.0173984, 0.0559042, 0.357518, 0.30257, 0.376668, 0.14234, -0.0381443, 0.0163791, -0.180164, 0.195689, 0.0580821, -0.16776, 0.0569814, -0.271095, 0.223116, -0.475364, 0.0926328, 0.0061297, -0.338429, 0.393725, 0.0131179, -0.207723, -0.296026, 0.0254685, -0.240729, 0.290662, 0.180188, -0.306817, 0.00927613, 0.37434, 0.267375, -0.15453, -0.259941, -0.0329586, 0.0232179, -0.0882589, 0.211924, 0.302485, -0.465395, 0.297113, 0.0298, -0.255424, 0.0950433, -0.493201, 0.387198, -0.218439, -0.0878063, -0.27663, -0.18966, -0.348982, -0.210474, -0.232789, -0.0718565, -0.401074, -0.321769, -0.225299, -0.138247, 0.0411653, -0.20941, 0.468472, 0.0649849, -0.147813, 0.260647, -0.495487, 0.353673, -0.0956114, -0.0889727, 0.443833, 0.0468353, -0.487945, 0.200253, 0.108264, 0.0764869, -0.151083, -0.0761, 0.399845, -0.471762, -0.267245, -0.357395, -0.330219, 0.306514, -0.244428, -0.499103, -0.456246, 0.0216964, 0.490225, -0.178489, -0.118379, 0.477028
Bias: 0.0883882 0.498606 -0.192237 -0.225643 -0.202986 0.337387 0.492392 -0.378505
Weights: 0.0883882, 0.498606, -0.192237, -0.225643, -0.202986, 0.337387, 0.492392, -0.378505, 0.174707, 0.289964, -0.176566, -0.0890466, 0.464011, 0.0977563, -0.0747116, -0.179186, -0.387672, 0.156685, 0.396763, 0.163871, -0.347555, 0.161215, 0.40124, 0.121386, -0.0582743, -0.178808, 0.443577, 0.0854848, 0.0230515, 0.343403, 0.201804, 0.101254, 0.413756, 0.228725, -0.392687, 0.305022, -0.0608484, 0.0060341, 0.393663, -0.0173984, 0.0559042, 0.357518, 0.30257, 0.376668, 0.14234, -0.0381443, 0.0163791, -0.180164, 0.195689, 0.0580821, -0.16776, 0.0569814, -0.271095, 0.223116, -0.475364, 0.0926328, 0.0061297, -0.338429, 0.393725, 0.0131179, -0.207723, -0.296026, 0.0254685, -0.240729
Bias: 0.0883882 0.498606 -0.192237 -0.225643 -0.202986 0.337387 0.492392 -0.378505
Weights: 0.0883882, 0.498606, -0.192237, -0.225643, -0.202986, 0.337387, 0.492392, -0.378505, 0.174707, 0.289964, -0.176566, -0.0890466, 0.464011, 0.0977563, -0.0747116, -0.179186, -0.387672, 0.156685, 0.396763, 0.163871, -0.347555, 0.161215, 0.40124, 0.121386, -0.0582743, -0.178808, 0.443577, 0.0854848, 0.0230515, 0.343403, 0.201804, 0.101254, 0.413756, 0.228725, -0.392687, 0.305022, -0.0608484, 0.0060341, 0.393663, -0.0173984, 0.0559042, 0.357518, 0.30257, 0.376668, 0.14234, -0.0381443, 0.0163791, -0.180164, 0.195689, 0.0580821, -0.16776, 0.0569814, -0.271095, 0.223116, -0.475364, 0.0926328, 0.0061297, -0.338429, 0.393725, 0.0131179, -0.207723, -0.296026, 0.0254685, -0.240729
Bias: 0.0883882 0.498606
Weights: 0.0883882, 0.498606, -0.192237, -0.225643, -0.202986, 0.337387, 0.492392, -0.378505, 0.174707, 0.289964, -0.176566, -0.0890466, 0.464011, 0.0977563, -0.0747116, -0.179186
389.05, 133.34, 62.14, 32.6, 36.97, 29.03, 62.28, 114.07, 79.41, 212.39, 70.29, 36.73, 31.3, 76.01, 103.99, 214.09

9
2.txt Normal file
View File

@ -0,0 +1,9 @@
Bias: 0.0883882 0.498606 -0.192237 -0.225643 -0.202986 0.337387 0.492392 -0.378505
Weights: 0.0883882, 0.498606, -0.192237, -0.225643, -0.202986, 0.337387, 0.492392, -0.378505, 0.174707, 0.289964, -0.176566, -0.0890466, 0.464011, 0.0977563, -0.0747116, -0.179186, -0.387672, 0.156685, 0.396763, 0.163871, -0.347555, 0.161215, 0.40124, 0.121386, -0.0582743, -0.178808, 0.443577, 0.0854848, 0.0230515, 0.343403, 0.201804, 0.101254, 0.413756, 0.228725, -0.392687, 0.305022, -0.0608484, 0.0060341, 0.393663, -0.0173984, 0.0559042, 0.357518, 0.30257, 0.376668, 0.14234, -0.0381443, 0.0163791, -0.180164, 0.195689, 0.0580821, -0.16776, 0.0569814, -0.271095, 0.223116, -0.475364, 0.0926328, 0.0061297, -0.338429, 0.393725, 0.0131179, -0.207723, -0.296026, 0.0254685, -0.240729, 0.290662, 0.180188, -0.306817, 0.00927613, 0.37434, 0.267375, -0.15453, -0.259941, -0.0329586, 0.0232179, -0.0882589, 0.211924, 0.302485, -0.465395, 0.297113, 0.0298, -0.255424, 0.0950433, -0.493201, 0.387198, -0.218439, -0.0878063, -0.27663, -0.18966, -0.348982, -0.210474, -0.232789, -0.0718565, -0.401074, -0.321769, -0.225299, -0.138247, 0.0411653, -0.20941, 0.468472, 0.0649849, -0.147813, 0.260647, -0.495487, 0.353673, -0.0956114, -0.0889727, 0.443833, 0.0468353, -0.487945, 0.200253, 0.108264, 0.0764869, -0.151083, -0.0761, 0.399845, -0.471762, -0.267245, -0.357395, -0.330219, 0.306514, -0.244428, -0.499103, -0.456246, 0.0216964, 0.490225, -0.178489, -0.118379, 0.477028
Bias: 0.0883882 0.498606 -0.192237 -0.225643 -0.202986 0.337387 0.492392 -0.378505
Weights: 0.0883882, 0.498606, -0.192237, -0.225643, -0.202986, 0.337387, 0.492392, -0.378505, 0.174707, 0.289964, -0.176566, -0.0890466, 0.464011, 0.0977563, -0.0747116, -0.179186, -0.387672, 0.156685, 0.396763, 0.163871, -0.347555, 0.161215, 0.40124, 0.121386, -0.0582743, -0.178808, 0.443577, 0.0854848, 0.0230515, 0.343403, 0.201804, 0.101254, 0.413756, 0.228725, -0.392687, 0.305022, -0.0608484, 0.0060341, 0.393663, -0.0173984, 0.0559042, 0.357518, 0.30257, 0.376668, 0.14234, -0.0381443, 0.0163791, -0.180164, 0.195689, 0.0580821, -0.16776, 0.0569814, -0.271095, 0.223116, -0.475364, 0.0926328, 0.0061297, -0.338429, 0.393725, 0.0131179, -0.207723, -0.296026, 0.0254685, -0.240729
Bias: 0.0883882 0.498606 -0.192237 -0.225643 -0.202986 0.337387 0.492392 -0.378505
Weights: 0.0883882, 0.498606, -0.192237, -0.225643, -0.202986, 0.337387, 0.492392, -0.378505, 0.174707, 0.289964, -0.176566, -0.0890466, 0.464011, 0.0977563, -0.0747116, -0.179186, -0.387672, 0.156685, 0.396763, 0.163871, -0.347555, 0.161215, 0.40124, 0.121386, -0.0582743, -0.178808, 0.443577, 0.0854848, 0.0230515, 0.343403, 0.201804, 0.101254, 0.413756, 0.228725, -0.392687, 0.305022, -0.0608484, 0.0060341, 0.393663, -0.0173984, 0.0559042, 0.357518, 0.30257, 0.376668, 0.14234, -0.0381443, 0.0163791, -0.180164, 0.195689, 0.0580821, -0.16776, 0.0569814, -0.271095, 0.223116, -0.475364, 0.0926328, 0.0061297, -0.338429, 0.393725, 0.0131179, -0.207723, -0.296026, 0.0254685, -0.240729
Bias: 0.0883882 0.498606
Weights: 0.0883882, 0.498606, -0.192237, -0.225643, -0.202986, 0.337387, 0.492392, -0.378505, 0.174707, 0.289964, -0.176566, -0.0890466, 0.464011, 0.0977563, -0.0747116, -0.179186
389.05, 133.34, 62.14, 32.6, 36.97, 29.03, 62.28, 114.07, 79.41, 212.39, 70.29, 36.73, 31.3, 76.01, 103.99, 214.09

9
3.txt Normal file
View File

@ -0,0 +1,9 @@
Bias: 0.0883882 0.498606 -0.192237 -0.225643 -0.202986 0.337387 0.492392 -0.378505
Weights: 0.0883882, 0.498606, -0.192237, -0.225643, -0.202986, 0.337387, 0.492392, -0.378505, 0.174707, 0.289964, -0.176566, -0.0890466, 0.464011, 0.0977563, -0.0747116, -0.179186, -0.387672, 0.156685, 0.396763, 0.163871, -0.347555, 0.161215, 0.40124, 0.121386, -0.0582743, -0.178808, 0.443577, 0.0854848, 0.0230515, 0.343403, 0.201804, 0.101254, 0.413756, 0.228725, -0.392687, 0.305022, -0.0608484, 0.0060341, 0.393663, -0.0173984, 0.0559042, 0.357518, 0.30257, 0.376668, 0.14234, -0.0381443, 0.0163791, -0.180164, 0.195689, 0.0580821, -0.16776, 0.0569814, -0.271095, 0.223116, -0.475364, 0.0926328, 0.0061297, -0.338429, 0.393725, 0.0131179, -0.207723, -0.296026, 0.0254685, -0.240729, 0.290662, 0.180188, -0.306817, 0.00927613, 0.37434, 0.267375, -0.15453, -0.259941, -0.0329586, 0.0232179, -0.0882589, 0.211924, 0.302485, -0.465395, 0.297113, 0.0298, -0.255424, 0.0950433, -0.493201, 0.387198, -0.218439, -0.0878063, -0.27663, -0.18966, -0.348982, -0.210474, -0.232789, -0.0718565, -0.401074, -0.321769, -0.225299, -0.138247, 0.0411653, -0.20941, 0.468472, 0.0649849, -0.147813, 0.260647, -0.495487, 0.353673, -0.0956114, -0.0889727, 0.443833, 0.0468353, -0.487945, 0.200253, 0.108264, 0.0764869, -0.151083, -0.0761, 0.399845, -0.471762, -0.267245, -0.357395, -0.330219, 0.306514, -0.244428, -0.499103, -0.456246, 0.0216964, 0.490225, -0.178489, -0.118379, 0.477028
Bias: 0.0883882 0.498606 -0.192237 -0.225643 -0.202986 0.337387 0.492392 -0.378505
Weights: 0.0883882, 0.498606, -0.192237, -0.225643, -0.202986, 0.337387, 0.492392, -0.378505, 0.174707, 0.289964, -0.176566, -0.0890466, 0.464011, 0.0977563, -0.0747116, -0.179186, -0.387672, 0.156685, 0.396763, 0.163871, -0.347555, 0.161215, 0.40124, 0.121386, -0.0582743, -0.178808, 0.443577, 0.0854848, 0.0230515, 0.343403, 0.201804, 0.101254, 0.413756, 0.228725, -0.392687, 0.305022, -0.0608484, 0.0060341, 0.393663, -0.0173984, 0.0559042, 0.357518, 0.30257, 0.376668, 0.14234, -0.0381443, 0.0163791, -0.180164, 0.195689, 0.0580821, -0.16776, 0.0569814, -0.271095, 0.223116, -0.475364, 0.0926328, 0.0061297, -0.338429, 0.393725, 0.0131179, -0.207723, -0.296026, 0.0254685, -0.240729
Bias: 0.0883882 0.498606 -0.192237 -0.225643 -0.202986 0.337387 0.492392 -0.378505
Weights: 0.0883882, 0.498606, -0.192237, -0.225643, -0.202986, 0.337387, 0.492392, -0.378505, 0.174707, 0.289964, -0.176566, -0.0890466, 0.464011, 0.0977563, -0.0747116, -0.179186, -0.387672, 0.156685, 0.396763, 0.163871, -0.347555, 0.161215, 0.40124, 0.121386, -0.0582743, -0.178808, 0.443577, 0.0854848, 0.0230515, 0.343403, 0.201804, 0.101254, 0.413756, 0.228725, -0.392687, 0.305022, -0.0608484, 0.0060341, 0.393663, -0.0173984, 0.0559042, 0.357518, 0.30257, 0.376668, 0.14234, -0.0381443, 0.0163791, -0.180164, 0.195689, 0.0580821, -0.16776, 0.0569814, -0.271095, 0.223116, -0.475364, 0.0926328, 0.0061297, -0.338429, 0.393725, 0.0131179, -0.207723, -0.296026, 0.0254685, -0.240729
Bias: 0.0883882 0.498606
Weights: 0.0883882, 0.498606, -0.192237, -0.225643, -0.202986, 0.337387, 0.492392, -0.378505, 0.174707, 0.289964, -0.176566, -0.0890466, 0.464011, 0.0977563, -0.0747116, -0.179186
389.05, 133.34, 62.14, 32.6, 36.97, 29.03, 62.28, 114.07, 79.41, 212.39, 70.29, 36.73, 31.3, 76.01, 103.99, 214.09

View File

@ -1,9 +1,10 @@
cmake_minimum_required(VERSION 3.25)
project(COSC-4P80-Assignment-2 VERSION 0.0.3)
project(COSC-4P80-Assignment-2 VERSION 0.0.4)
option(ENABLE_ADDRSAN "Enable the address sanitizer" OFF)
option(ENABLE_UBSAN "Enable the ub sanitizer" OFF)
option(ENABLE_TSAN "Enable the thread data race sanitizer" OFF)
option(ENABLE_GRAPHICS "Enable usage of graphics package" OFF)
#option(EIGEN_TEST_CXX11 "Enable testing with C++11 and C++11 features (e.g. Tensor module)." ON)
set(CMAKE_CXX_STANDARD 17)
@ -12,7 +13,12 @@ if (NOT CMAKE_BUILD_TYPE)
set(CMAKE_BUILD_TYPE "Release")
endif()
if (ENABLE_GRAPHICS)
add_subdirectory(lib/blt-graphics)
add_compile_definitions(BLT_USE_GRAPHICS)
else ()
add_subdirectory(lib/blt)
endif ()
#add_subdirectory(lib/eigen-3.4.0)
@ -25,7 +31,11 @@ target_compile_options(COSC-4P80-Assignment-2 PRIVATE -Wall -Wextra -Wpedantic -
target_link_options(COSC-4P80-Assignment-2 PRIVATE -Wall -Wextra -Wpedantic -Wno-comment)
#target_link_libraries(COSC-4P80-Assignment-2 PRIVATE BLT Eigen3::Eigen)
if (ENABLE_GRAPHICS)
target_link_libraries(COSC-4P80-Assignment-2 PRIVATE BLT_WITH_GRAPHICS)
else ()
target_link_libraries(COSC-4P80-Assignment-2 PRIVATE BLT)
endif ()
if (${ENABLE_ADDRSAN} MATCHES ON)
target_compile_options(COSC-4P80-Assignment-2 PRIVATE -fsanitize=address)

View File

@ -19,10 +19,25 @@
#ifndef COSC_4P80_ASSIGNMENT_2_COMMON_H
#define COSC_4P80_ASSIGNMENT_2_COMMON_H
#include <iostream>
#include <blt/iterator/enumerate.h>
namespace assign2
{
using Scalar = float;
const inline Scalar learn_rate = 0.1;
/// Prints the elements of a vector to std::cout separated by ", " (no
/// trailing separator, no newline).
/// Returns std::cout so callers can chain further output, e.g.
/// print_vec(v) << std::endl;
template<typename T>
decltype(std::cout)& print_vec(const std::vector<T>& vec)
{
    // Emit the separator before every element except the first. This avoids
    // the per-element "is this the last index?" comparison of the original
    // and removes the dependency on blt::enumerate — only the standard
    // library is needed.
    for (std::size_t i = 0; i < vec.size(); ++i)
    {
        if (i != 0)
            std::cout << ", ";
        std::cout << vec[i];
    }
    return std::cout;
}
struct data_t
{
@ -36,15 +51,23 @@ namespace assign2
};
class layer_t;
class network_t;
/// Abstract activation-function interface: call() evaluates f(x) and
/// derivative() evaluates f'(x) for backpropagation.
/// Layers store non-owning function_t* pointers (see layer_t::act_func), so
/// concrete instances must outlive any layer that references them.
struct function_t
{
    // Virtual destructor: this type is used polymorphically through
    // function_t*, so destruction through a base pointer must be
    // well-defined (C++ Core Guidelines C.35).
    virtual ~function_t() = default;
    [[nodiscard]] virtual Scalar call(Scalar) const = 0;
    [[nodiscard]] virtual Scalar derivative(Scalar) const = 0;
};
struct weight_view
{
public:
weight_view(double* data, blt::size_t size): m_data(data), m_size(size)
weight_view(Scalar* data, blt::size_t size): m_data(data), m_size(size)
{}
inline double& operator[](blt::size_t index) const
inline Scalar& operator[](blt::size_t index) const
{
#if BLT_DEBUG_LEVEL > 0
if (index >= size)
@ -69,7 +92,7 @@ namespace assign2
}
private:
double* m_data;
Scalar* m_data;
blt::size_t m_size;
};
@ -86,9 +109,16 @@ namespace assign2
return {&data[size], count};
}
// Dumps the flat weight buffer to stdout as "Weights: w0, w1, ..." followed
// by a flushed newline. Debug aid only.
void debug() const
{
    std::cout << "Weights: ";
    print_vec(data);
    std::cout << std::endl;
}
private:
std::vector<double> data;
std::vector<Scalar> data;
};
}
#endif //COSC_4P80_ASSIGNMENT_2_COMMON_H

View File

@ -24,22 +24,44 @@
namespace assign2
{
struct sigmoid_function
struct sigmoid_function : public function_t
{
[[nodiscard]] Scalar call(Scalar s) const // NOLINT
[[nodiscard]] Scalar call(const Scalar s) const final
{
return 1 / (1 + std::exp(-s));
}
[[nodiscard]] Scalar derivative(Scalar s) const
[[nodiscard]] Scalar derivative(const Scalar s) const final
{
return call(s) * (1 - call(s));
auto v = call(s);
return v * (1 - v);
}
};
struct linear_function
/// Heaviside step activation: 1 for s >= 0, otherwise 0.
struct threshold_function : public function_t
{
    [[nodiscard]] Scalar call(const Scalar s) const final
    {
        return s >= 0 ? 1 : 0;
    }
    /// The step function is flat everywhere it is differentiable, so the
    /// derivative is reported as 0 (the jump at s == 0 is ignored).
    /// Parameter left unnamed: it is unused, and the project builds with
    /// -Wall -Wextra, which would otherwise warn.
    [[nodiscard]] Scalar derivative(Scalar) const final
    {
        return 0;
    }
};
/// Rectified linear activation: max(0, s).
struct relu_function : public function_t
{
    [[nodiscard]] Scalar call(const Scalar s) const final
    {
        return std::max(static_cast<Scalar>(0), s);
    }
    /// Derivative of ReLU: 1 for s > 0, otherwise 0.
    /// BUG FIX: the original returned 0 unconditionally, which zeroes every
    /// gradient flowing through a ReLU layer during backpropagation and
    /// prevents those weights from ever updating.
    [[nodiscard]] Scalar derivative(const Scalar s) const final
    {
        return s > 0 ? 1 : 0;
    }
};
}

View File

@ -28,6 +28,7 @@ namespace assign2
{
class neuron_t
{
friend layer_t;
public:
// empty neuron for loading from a stream
explicit neuron_t(weight_view weights): weights(weights)
@ -37,13 +38,13 @@ namespace assign2
explicit neuron_t(weight_view weights, Scalar bias): bias(bias), weights(weights)
{}
template<typename ActFunc>
Scalar activate(const Scalar* inputs, ActFunc func) const
Scalar activate(const Scalar* inputs, function_t* act_func)
{
auto sum = bias;
z = bias;
for (auto [x, w] : blt::zip_iterator_container({inputs, inputs + weights.size()}, {weights.begin(), weights.end()}))
sum += x * w;
return func.call(sum);
z += x * w;
a = act_func->call(z);
return a;
}
template<typename OStream>
@ -62,8 +63,16 @@ namespace assign2
stream >> bias;
}
// Prints this neuron's bias followed by a single space — no newline, so the
// enclosing layer_t::debug() can string all biases onto one "Bias: ..." line.
void debug() const
{
std::cout << bias << " ";
}
private:
Scalar bias = 0;
float z = 0;
float a = 0;
float bias = 0;
float error = 0;
weight_view weights;
};
@ -71,7 +80,8 @@ namespace assign2
{
public:
template<typename WeightFunc, typename BiasFunc>
layer_t(const blt::i32 in, const blt::i32 out, WeightFunc w, BiasFunc b): in_size(in), out_size(out)
layer_t(const blt::i32 in, const blt::i32 out, function_t* act_func, WeightFunc w, BiasFunc b):
in_size(in), out_size(out), act_func(act_func)
{
neurons.reserve(out_size);
for (blt::i32 i = 0; i < out_size; i++)
@ -83,8 +93,7 @@ namespace assign2
}
}
template<typename ActFunction>
std::vector<Scalar> call(const std::vector<Scalar>& in, ActFunction func = ActFunction{})
std::vector<Scalar> call(const std::vector<Scalar>& in)
{
std::vector<Scalar> out;
out.reserve(out_size);
@ -93,10 +102,47 @@ namespace assign2
throw std::runtime_exception("Input vector doesn't match expected input size!");
#endif
for (auto& n : neurons)
out.push_back(n.activate(in.data(), func));
out.push_back(n.activate(in.data(), act_func));
return out;
}
// Backpropagates error through this layer, assigning each neuron's delta
// (n.error) from the activation derivative at its pre-activation value n.z.
// NOTE(review): this function is clearly work-in-progress and does not
// compile as written:
//  - `dw` is declared but never filled or used;
//  - the final loop over prev_layer_output is empty (weight update missing);
//  - `error_at_current_layer` on the return line is never declared.
Scalar back_prop(const std::vector<Scalar>& prev_layer_output, Scalar error, const layer_t& next_layer, bool is_output)
{
std::vector<Scalar> dw;
// δ(h): per-neuron error term for this layer
if (is_output)
{
// output layer: delta = f'(net) scaled by the externally supplied error
for (auto& n : neurons)
n.error = act_func->derivative(n.z) * error; // f'act(net(h)) * (error)
} else
{
// hidden layer: accumulate each next-layer neuron's delta weighted by its
// incoming weights, one entry per next-layer neuron
std::vector<Scalar> next_error;
next_error.resize(next_layer.neurons.size());
for (const auto& [i, w] : blt::enumerate(next_layer.neurons))
{
for (auto wv : w.weights)
next_error[i] += w.error * wv;
// needed? — averaging by fan-in is non-standard backprop; TODO confirm
next_error[i] /= static_cast<Scalar>(w.weights.size());
}
// NOTE(review): next_error is computed above but never used below —
// presumably each delta should be f'(z) * Σ(w * δ_next); confirm intent.
for (auto& n : neurons)
{
n.error = act_func->derivative(n.z);
}
}
// TODO(review): weight update (e.g. Δw = learn_rate * δ * input) is missing
for (const auto& v : prev_layer_output)
{
}
return error_at_current_layer; // NOTE(review): undeclared identifier — does not compile
}
template<typename OStream>
OStream& serialize(OStream& stream)
{
@ -120,9 +166,20 @@ namespace assign2
{
return out_size;
}
void debug() const
{
std::cout << "Bias: ";
for (auto& v : neurons)
v.debug();
std::cout << std::endl;
weights.debug();
}
private:
const blt::i32 in_size, out_size;
weight_t weights;
function_t* act_func;
std::vector<neuron_t> neurons;
};
}

View File

@ -21,6 +21,7 @@
#include <assign2/common.h>
#include <assign2/layer.h>
#include "blt/std/assert.h"
namespace assign2
{
@ -75,8 +76,7 @@ namespace assign2
network_t() = default;
template<typename ActFunc, typename ActFuncOut>
std::vector<Scalar> execute(const std::vector<Scalar>& input, ActFunc func, ActFuncOut outFunc)
std::vector<Scalar> execute(const std::vector<Scalar>& input)
{
std::vector<Scalar> previous_output;
std::vector<Scalar> current_output;
@ -85,39 +85,45 @@ namespace assign2
{
previous_output = current_output;
if (i == 0)
current_output = v.call(input, func);
else if (i == layers.size() - 1)
current_output = v.call(previous_output, outFunc);
current_output = v.call(input);
else
current_output = v.call(previous_output, func);
current_output = v.call(previous_output);
}
return current_output;
}
std::pair<Scalar, Scalar> error(const std::vector<Scalar>& outputs, bool is_bad)
{
BLT_ASSERT(outputs.size() == 2);
auto g = is_bad ? 0.0f : 1.0f;
auto b = is_bad ? 1.0f : 0.0f;
auto g_diff = outputs[0] - g;
auto b_diff = outputs[1] - b;
auto error = g_diff * g_diff + b_diff * b_diff;
BLT_INFO("%f %f %f", error, g_diff, b_diff);
return {0.5f * (error * error), error};
}
Scalar train(const data_file_t& example)
{
const Scalar learn_rate = 0.1;
Scalar total_error = 0;
Scalar total_d_error = 0;
for (const auto& x : example.data_points)
{
auto o = execute(x.bins, sigmoid_function{}, sigmoid_function{});
auto y = x.is_bad ? 1.0f : 0.0f;
Scalar is_bad = 0;
if (o[0] >= 1)
is_bad = 0;
else if (o[1] >= 1)
is_bad = 1;
auto error = y - is_bad;
if (o[0] >= 1 && o[1] >= 1)
error += 1;
total_error += error;
print_vec(x.bins) << std::endl;
auto o = execute(x.bins);
print_vec(o) << std::endl;
auto [e, de] = error(o, x.is_bad);
total_error += e;
total_d_error += -learn_rate * de;
BLT_TRACE("\tError %f, %f, is bad? %s", e, -learn_rate * de, x.is_bad ? "True" : "False");
}
BLT_DEBUG("Total Errors found %f, %f", total_error, total_d_error);
return total_error;
}

1
lib/blt-graphics Submodule

@ -0,0 +1 @@
Subproject commit 8103a3ad0ff5341c61ba07c4b1b9803b2cc740e9

View File

@ -219,9 +219,9 @@ public:
, output_dims(output_dims_), output_strides(output_strides_)
{}
void operator()(const Scalar* output_data) const
void operator()(const Scalar* neuron_data) const
{
check_recursive(input_data, output_data);
check_recursive(input_data, neuron_data);
}
};

View File

@ -86,7 +86,7 @@ static void test_block_io_copy_data_from_source_to_target() {
auto output_strides = internal::strides<Layout>(dims);
const T* input_data = input.data();
T* output_data = output.data();
T* neuron_data = output.data();
T* block_data = block.data();
for (int i = 0; i < block_mapper.blockCount(); ++i) {
@ -105,7 +105,7 @@ static void test_block_io_copy_data_from_source_to_target() {
{
// Write from block buffer to output.
IODst dst(blk_dims, output_strides, output_data, desc.offset());
IODst dst(blk_dims, output_strides, neuron_data, desc.offset());
IOSrc src(blk_strides, block_data, 0);
TensorBlockIO::Copy(dst, src);
@ -113,7 +113,7 @@ static void test_block_io_copy_data_from_source_to_target() {
}
for (int i = 0; i < dims.TotalSize(); ++i) {
VERIFY_IS_EQUAL(input_data[i], output_data[i]);
VERIFY_IS_EQUAL(input_data[i], neuron_data[i]);
}
}
@ -159,7 +159,7 @@ static void test_block_io_copy_using_reordered_dimensions() {
auto output_strides = internal::strides<Layout>(output_tensor_dims);
const T* input_data = input.data();
T* output_data = output.data();
T* neuron_data = output.data();
T* block_data = block.data();
for (Index i = 0; i < block_mapper.blockCount(); ++i) {
@ -198,7 +198,7 @@ static void test_block_io_copy_using_reordered_dimensions() {
}
// Write from block buffer to output.
IODst dst(dst_dims, input_strides, output_data, first_coeff_index);
IODst dst(dst_dims, input_strides, neuron_data, first_coeff_index);
IOSrc src(blk_strides, block_data, 0);
// TODO(ezhulenev): Remove when fully switched to TensorBlock.
@ -210,7 +210,7 @@ static void test_block_io_copy_using_reordered_dimensions() {
}
for (Index i = 0; i < dims.TotalSize(); ++i) {
VERIFY_IS_EQUAL(input_data[i], output_data[i]);
VERIFY_IS_EQUAL(input_data[i], neuron_data[i]);
}
}

View File

@ -68,18 +68,6 @@ std::vector<data_file_t> load_data_files(const std::vector<std::string>& files)
return loaded_data;
}
template<typename T>
decltype(std::cout)& print_vec(const std::vector<T>& vec)
{
for (auto [i, v] : blt::enumerate(vec))
{
std::cout << v;
if (i != vec.size() - 1)
std::cout << ", ";
}
return std::cout;
}
int main(int argc, const char** argv)
{
blt::arg_parse parser;
@ -90,6 +78,22 @@ int main(int argc, const char** argv)
auto data_files = load_data_files(get_data_files(data_directory));
random_init randomizer{619};
sigmoid_function sig;
relu_function relu;
threshold_function thresh;
layer_t layer1{16, 8, &sig, randomizer, randomizer};
layer1.debug();
layer_t layer2{8, 8, &sig, randomizer, randomizer};
layer2.debug();
layer_t layer3{8, 8, &sig, randomizer, randomizer};
layer3.debug();
layer_t layer_output{8, 2, &relu, randomizer, randomizer};
layer_output.debug();
network_t network{{layer1, layer2, layer3, layer_output}};
std::vector<Scalar> input;
input.resize(16);
for (auto f : data_files)
@ -98,20 +102,12 @@ int main(int argc, const char** argv)
{
for (auto [i, b] : blt::enumerate(f.data_points.begin()->bins))
input[i] = b;
network.train(f);
break;
}
}
random_init randomizer{619};
sigmoid_function sig;
layer_t layer1{16, 4, randomizer, empty_init{}};
layer_t layer2{4, 4, randomizer, empty_init{}};
layer_t layer3{4, 4, randomizer, empty_init{}};
layer_t layer_output{4, 1, randomizer, empty_init{}};
network_t network{{layer1, layer2, layer3, layer_output}};
auto output = network.execute(input, sig, sig);
auto output = network.execute(input);
print_vec(output) << std::endl;
// for (auto d : data_files)