getting closer!

main
Brett 2024-10-25 01:22:32 -04:00
parent 1b79238114
commit 4216b53b28
14 changed files with 229 additions and 76 deletions

3
.gitmodules vendored
View File

@ -1,3 +1,6 @@
[submodule "lib/blt"] [submodule "lib/blt"]
path = lib/blt path = lib/blt
url = https://github.com/Tri11Paragon/BLT.git url = https://github.com/Tri11Paragon/BLT.git
[submodule "lib/blt-graphics"]
path = lib/blt-graphics
url = https://git.tpgc.me/tri11paragon/BLT-With-Graphics-Template

View File

@ -3,6 +3,7 @@
<component name="VcsDirectoryMappings"> <component name="VcsDirectoryMappings">
<mapping directory="" vcs="Git" /> <mapping directory="" vcs="Git" />
<mapping directory="$PROJECT_DIR$/lib/blt" vcs="Git" /> <mapping directory="$PROJECT_DIR$/lib/blt" vcs="Git" />
<mapping directory="$PROJECT_DIR$/lib/blt-graphics" vcs="Git" />
<mapping directory="$PROJECT_DIR$/lib/blt/libraries/parallel-hashmap" vcs="Git" /> <mapping directory="$PROJECT_DIR$/lib/blt/libraries/parallel-hashmap" vcs="Git" />
</component> </component>
</project> </project>

9
1.txt Normal file
View File

@ -0,0 +1,9 @@
Bias: 0.0883882 0.498606 -0.192237 -0.225643 -0.202986 0.337387 0.492392 -0.378505
Weights: 0.0883882, 0.498606, -0.192237, -0.225643, -0.202986, 0.337387, 0.492392, -0.378505, 0.174707, 0.289964, -0.176566, -0.0890466, 0.464011, 0.0977563, -0.0747116, -0.179186, -0.387672, 0.156685, 0.396763, 0.163871, -0.347555, 0.161215, 0.40124, 0.121386, -0.0582743, -0.178808, 0.443577, 0.0854848, 0.0230515, 0.343403, 0.201804, 0.101254, 0.413756, 0.228725, -0.392687, 0.305022, -0.0608484, 0.0060341, 0.393663, -0.0173984, 0.0559042, 0.357518, 0.30257, 0.376668, 0.14234, -0.0381443, 0.0163791, -0.180164, 0.195689, 0.0580821, -0.16776, 0.0569814, -0.271095, 0.223116, -0.475364, 0.0926328, 0.0061297, -0.338429, 0.393725, 0.0131179, -0.207723, -0.296026, 0.0254685, -0.240729, 0.290662, 0.180188, -0.306817, 0.00927613, 0.37434, 0.267375, -0.15453, -0.259941, -0.0329586, 0.0232179, -0.0882589, 0.211924, 0.302485, -0.465395, 0.297113, 0.0298, -0.255424, 0.0950433, -0.493201, 0.387198, -0.218439, -0.0878063, -0.27663, -0.18966, -0.348982, -0.210474, -0.232789, -0.0718565, -0.401074, -0.321769, -0.225299, -0.138247, 0.0411653, -0.20941, 0.468472, 0.0649849, -0.147813, 0.260647, -0.495487, 0.353673, -0.0956114, -0.0889727, 0.443833, 0.0468353, -0.487945, 0.200253, 0.108264, 0.0764869, -0.151083, -0.0761, 0.399845, -0.471762, -0.267245, -0.357395, -0.330219, 0.306514, -0.244428, -0.499103, -0.456246, 0.0216964, 0.490225, -0.178489, -0.118379, 0.477028
Bias: 0.0883882 0.498606 -0.192237 -0.225643 -0.202986 0.337387 0.492392 -0.378505
Weights: 0.0883882, 0.498606, -0.192237, -0.225643, -0.202986, 0.337387, 0.492392, -0.378505, 0.174707, 0.289964, -0.176566, -0.0890466, 0.464011, 0.0977563, -0.0747116, -0.179186, -0.387672, 0.156685, 0.396763, 0.163871, -0.347555, 0.161215, 0.40124, 0.121386, -0.0582743, -0.178808, 0.443577, 0.0854848, 0.0230515, 0.343403, 0.201804, 0.101254, 0.413756, 0.228725, -0.392687, 0.305022, -0.0608484, 0.0060341, 0.393663, -0.0173984, 0.0559042, 0.357518, 0.30257, 0.376668, 0.14234, -0.0381443, 0.0163791, -0.180164, 0.195689, 0.0580821, -0.16776, 0.0569814, -0.271095, 0.223116, -0.475364, 0.0926328, 0.0061297, -0.338429, 0.393725, 0.0131179, -0.207723, -0.296026, 0.0254685, -0.240729
Bias: 0.0883882 0.498606 -0.192237 -0.225643 -0.202986 0.337387 0.492392 -0.378505
Weights: 0.0883882, 0.498606, -0.192237, -0.225643, -0.202986, 0.337387, 0.492392, -0.378505, 0.174707, 0.289964, -0.176566, -0.0890466, 0.464011, 0.0977563, -0.0747116, -0.179186, -0.387672, 0.156685, 0.396763, 0.163871, -0.347555, 0.161215, 0.40124, 0.121386, -0.0582743, -0.178808, 0.443577, 0.0854848, 0.0230515, 0.343403, 0.201804, 0.101254, 0.413756, 0.228725, -0.392687, 0.305022, -0.0608484, 0.0060341, 0.393663, -0.0173984, 0.0559042, 0.357518, 0.30257, 0.376668, 0.14234, -0.0381443, 0.0163791, -0.180164, 0.195689, 0.0580821, -0.16776, 0.0569814, -0.271095, 0.223116, -0.475364, 0.0926328, 0.0061297, -0.338429, 0.393725, 0.0131179, -0.207723, -0.296026, 0.0254685, -0.240729
Bias: 0.0883882 0.498606
Weights: 0.0883882, 0.498606, -0.192237, -0.225643, -0.202986, 0.337387, 0.492392, -0.378505, 0.174707, 0.289964, -0.176566, -0.0890466, 0.464011, 0.0977563, -0.0747116, -0.179186
389.05, 133.34, 62.14, 32.6, 36.97, 29.03, 62.28, 114.07, 79.41, 212.39, 70.29, 36.73, 31.3, 76.01, 103.99, 214.09

9
2.txt Normal file
View File

@ -0,0 +1,9 @@
Bias: 0.0883882 0.498606 -0.192237 -0.225643 -0.202986 0.337387 0.492392 -0.378505
Weights: 0.0883882, 0.498606, -0.192237, -0.225643, -0.202986, 0.337387, 0.492392, -0.378505, 0.174707, 0.289964, -0.176566, -0.0890466, 0.464011, 0.0977563, -0.0747116, -0.179186, -0.387672, 0.156685, 0.396763, 0.163871, -0.347555, 0.161215, 0.40124, 0.121386, -0.0582743, -0.178808, 0.443577, 0.0854848, 0.0230515, 0.343403, 0.201804, 0.101254, 0.413756, 0.228725, -0.392687, 0.305022, -0.0608484, 0.0060341, 0.393663, -0.0173984, 0.0559042, 0.357518, 0.30257, 0.376668, 0.14234, -0.0381443, 0.0163791, -0.180164, 0.195689, 0.0580821, -0.16776, 0.0569814, -0.271095, 0.223116, -0.475364, 0.0926328, 0.0061297, -0.338429, 0.393725, 0.0131179, -0.207723, -0.296026, 0.0254685, -0.240729, 0.290662, 0.180188, -0.306817, 0.00927613, 0.37434, 0.267375, -0.15453, -0.259941, -0.0329586, 0.0232179, -0.0882589, 0.211924, 0.302485, -0.465395, 0.297113, 0.0298, -0.255424, 0.0950433, -0.493201, 0.387198, -0.218439, -0.0878063, -0.27663, -0.18966, -0.348982, -0.210474, -0.232789, -0.0718565, -0.401074, -0.321769, -0.225299, -0.138247, 0.0411653, -0.20941, 0.468472, 0.0649849, -0.147813, 0.260647, -0.495487, 0.353673, -0.0956114, -0.0889727, 0.443833, 0.0468353, -0.487945, 0.200253, 0.108264, 0.0764869, -0.151083, -0.0761, 0.399845, -0.471762, -0.267245, -0.357395, -0.330219, 0.306514, -0.244428, -0.499103, -0.456246, 0.0216964, 0.490225, -0.178489, -0.118379, 0.477028
Bias: 0.0883882 0.498606 -0.192237 -0.225643 -0.202986 0.337387 0.492392 -0.378505
Weights: 0.0883882, 0.498606, -0.192237, -0.225643, -0.202986, 0.337387, 0.492392, -0.378505, 0.174707, 0.289964, -0.176566, -0.0890466, 0.464011, 0.0977563, -0.0747116, -0.179186, -0.387672, 0.156685, 0.396763, 0.163871, -0.347555, 0.161215, 0.40124, 0.121386, -0.0582743, -0.178808, 0.443577, 0.0854848, 0.0230515, 0.343403, 0.201804, 0.101254, 0.413756, 0.228725, -0.392687, 0.305022, -0.0608484, 0.0060341, 0.393663, -0.0173984, 0.0559042, 0.357518, 0.30257, 0.376668, 0.14234, -0.0381443, 0.0163791, -0.180164, 0.195689, 0.0580821, -0.16776, 0.0569814, -0.271095, 0.223116, -0.475364, 0.0926328, 0.0061297, -0.338429, 0.393725, 0.0131179, -0.207723, -0.296026, 0.0254685, -0.240729
Bias: 0.0883882 0.498606 -0.192237 -0.225643 -0.202986 0.337387 0.492392 -0.378505
Weights: 0.0883882, 0.498606, -0.192237, -0.225643, -0.202986, 0.337387, 0.492392, -0.378505, 0.174707, 0.289964, -0.176566, -0.0890466, 0.464011, 0.0977563, -0.0747116, -0.179186, -0.387672, 0.156685, 0.396763, 0.163871, -0.347555, 0.161215, 0.40124, 0.121386, -0.0582743, -0.178808, 0.443577, 0.0854848, 0.0230515, 0.343403, 0.201804, 0.101254, 0.413756, 0.228725, -0.392687, 0.305022, -0.0608484, 0.0060341, 0.393663, -0.0173984, 0.0559042, 0.357518, 0.30257, 0.376668, 0.14234, -0.0381443, 0.0163791, -0.180164, 0.195689, 0.0580821, -0.16776, 0.0569814, -0.271095, 0.223116, -0.475364, 0.0926328, 0.0061297, -0.338429, 0.393725, 0.0131179, -0.207723, -0.296026, 0.0254685, -0.240729
Bias: 0.0883882 0.498606
Weights: 0.0883882, 0.498606, -0.192237, -0.225643, -0.202986, 0.337387, 0.492392, -0.378505, 0.174707, 0.289964, -0.176566, -0.0890466, 0.464011, 0.0977563, -0.0747116, -0.179186
389.05, 133.34, 62.14, 32.6, 36.97, 29.03, 62.28, 114.07, 79.41, 212.39, 70.29, 36.73, 31.3, 76.01, 103.99, 214.09

9
3.txt Normal file
View File

@ -0,0 +1,9 @@
Bias: 0.0883882 0.498606 -0.192237 -0.225643 -0.202986 0.337387 0.492392 -0.378505
Weights: 0.0883882, 0.498606, -0.192237, -0.225643, -0.202986, 0.337387, 0.492392, -0.378505, 0.174707, 0.289964, -0.176566, -0.0890466, 0.464011, 0.0977563, -0.0747116, -0.179186, -0.387672, 0.156685, 0.396763, 0.163871, -0.347555, 0.161215, 0.40124, 0.121386, -0.0582743, -0.178808, 0.443577, 0.0854848, 0.0230515, 0.343403, 0.201804, 0.101254, 0.413756, 0.228725, -0.392687, 0.305022, -0.0608484, 0.0060341, 0.393663, -0.0173984, 0.0559042, 0.357518, 0.30257, 0.376668, 0.14234, -0.0381443, 0.0163791, -0.180164, 0.195689, 0.0580821, -0.16776, 0.0569814, -0.271095, 0.223116, -0.475364, 0.0926328, 0.0061297, -0.338429, 0.393725, 0.0131179, -0.207723, -0.296026, 0.0254685, -0.240729, 0.290662, 0.180188, -0.306817, 0.00927613, 0.37434, 0.267375, -0.15453, -0.259941, -0.0329586, 0.0232179, -0.0882589, 0.211924, 0.302485, -0.465395, 0.297113, 0.0298, -0.255424, 0.0950433, -0.493201, 0.387198, -0.218439, -0.0878063, -0.27663, -0.18966, -0.348982, -0.210474, -0.232789, -0.0718565, -0.401074, -0.321769, -0.225299, -0.138247, 0.0411653, -0.20941, 0.468472, 0.0649849, -0.147813, 0.260647, -0.495487, 0.353673, -0.0956114, -0.0889727, 0.443833, 0.0468353, -0.487945, 0.200253, 0.108264, 0.0764869, -0.151083, -0.0761, 0.399845, -0.471762, -0.267245, -0.357395, -0.330219, 0.306514, -0.244428, -0.499103, -0.456246, 0.0216964, 0.490225, -0.178489, -0.118379, 0.477028
Bias: 0.0883882 0.498606 -0.192237 -0.225643 -0.202986 0.337387 0.492392 -0.378505
Weights: 0.0883882, 0.498606, -0.192237, -0.225643, -0.202986, 0.337387, 0.492392, -0.378505, 0.174707, 0.289964, -0.176566, -0.0890466, 0.464011, 0.0977563, -0.0747116, -0.179186, -0.387672, 0.156685, 0.396763, 0.163871, -0.347555, 0.161215, 0.40124, 0.121386, -0.0582743, -0.178808, 0.443577, 0.0854848, 0.0230515, 0.343403, 0.201804, 0.101254, 0.413756, 0.228725, -0.392687, 0.305022, -0.0608484, 0.0060341, 0.393663, -0.0173984, 0.0559042, 0.357518, 0.30257, 0.376668, 0.14234, -0.0381443, 0.0163791, -0.180164, 0.195689, 0.0580821, -0.16776, 0.0569814, -0.271095, 0.223116, -0.475364, 0.0926328, 0.0061297, -0.338429, 0.393725, 0.0131179, -0.207723, -0.296026, 0.0254685, -0.240729
Bias: 0.0883882 0.498606 -0.192237 -0.225643 -0.202986 0.337387 0.492392 -0.378505
Weights: 0.0883882, 0.498606, -0.192237, -0.225643, -0.202986, 0.337387, 0.492392, -0.378505, 0.174707, 0.289964, -0.176566, -0.0890466, 0.464011, 0.0977563, -0.0747116, -0.179186, -0.387672, 0.156685, 0.396763, 0.163871, -0.347555, 0.161215, 0.40124, 0.121386, -0.0582743, -0.178808, 0.443577, 0.0854848, 0.0230515, 0.343403, 0.201804, 0.101254, 0.413756, 0.228725, -0.392687, 0.305022, -0.0608484, 0.0060341, 0.393663, -0.0173984, 0.0559042, 0.357518, 0.30257, 0.376668, 0.14234, -0.0381443, 0.0163791, -0.180164, 0.195689, 0.0580821, -0.16776, 0.0569814, -0.271095, 0.223116, -0.475364, 0.0926328, 0.0061297, -0.338429, 0.393725, 0.0131179, -0.207723, -0.296026, 0.0254685, -0.240729
Bias: 0.0883882 0.498606
Weights: 0.0883882, 0.498606, -0.192237, -0.225643, -0.202986, 0.337387, 0.492392, -0.378505, 0.174707, 0.289964, -0.176566, -0.0890466, 0.464011, 0.0977563, -0.0747116, -0.179186
389.05, 133.34, 62.14, 32.6, 36.97, 29.03, 62.28, 114.07, 79.41, 212.39, 70.29, 36.73, 31.3, 76.01, 103.99, 214.09

View File

@ -1,9 +1,10 @@
cmake_minimum_required(VERSION 3.25) cmake_minimum_required(VERSION 3.25)
project(COSC-4P80-Assignment-2 VERSION 0.0.3) project(COSC-4P80-Assignment-2 VERSION 0.0.4)
option(ENABLE_ADDRSAN "Enable the address sanitizer" OFF) option(ENABLE_ADDRSAN "Enable the address sanitizer" OFF)
option(ENABLE_UBSAN "Enable the ub sanitizer" OFF) option(ENABLE_UBSAN "Enable the ub sanitizer" OFF)
option(ENABLE_TSAN "Enable the thread data race sanitizer" OFF) option(ENABLE_TSAN "Enable the thread data race sanitizer" OFF)
option(ENABLE_GRAPHICS "Enable usage of graphics package" OFF)
#option(EIGEN_TEST_CXX11 "Enable testing with C++11 and C++11 features (e.g. Tensor module)." ON) #option(EIGEN_TEST_CXX11 "Enable testing with C++11 and C++11 features (e.g. Tensor module)." ON)
set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_STANDARD 17)
@ -12,7 +13,12 @@ if (NOT CMAKE_BUILD_TYPE)
set(CMAKE_BUILD_TYPE "Release") set(CMAKE_BUILD_TYPE "Release")
endif() endif()
add_subdirectory(lib/blt) if (ENABLE_GRAPHICS)
add_subdirectory(lib/blt-graphics)
add_compile_definitions(BLT_USE_GRAPHICS)
else ()
add_subdirectory(lib/blt)
endif ()
#add_subdirectory(lib/eigen-3.4.0) #add_subdirectory(lib/eigen-3.4.0)
@ -25,7 +31,11 @@ target_compile_options(COSC-4P80-Assignment-2 PRIVATE -Wall -Wextra -Wpedantic -
target_link_options(COSC-4P80-Assignment-2 PRIVATE -Wall -Wextra -Wpedantic -Wno-comment) target_link_options(COSC-4P80-Assignment-2 PRIVATE -Wall -Wextra -Wpedantic -Wno-comment)
#target_link_libraries(COSC-4P80-Assignment-2 PRIVATE BLT Eigen3::Eigen) #target_link_libraries(COSC-4P80-Assignment-2 PRIVATE BLT Eigen3::Eigen)
target_link_libraries(COSC-4P80-Assignment-2 PRIVATE BLT) if (ENABLE_GRAPHICS)
target_link_libraries(COSC-4P80-Assignment-2 PRIVATE BLT_WITH_GRAPHICS)
else ()
target_link_libraries(COSC-4P80-Assignment-2 PRIVATE BLT)
endif ()
if (${ENABLE_ADDRSAN} MATCHES ON) if (${ENABLE_ADDRSAN} MATCHES ON)
target_compile_options(COSC-4P80-Assignment-2 PRIVATE -fsanitize=address) target_compile_options(COSC-4P80-Assignment-2 PRIVATE -fsanitize=address)

View File

@ -19,10 +19,25 @@
#ifndef COSC_4P80_ASSIGNMENT_2_COMMON_H #ifndef COSC_4P80_ASSIGNMENT_2_COMMON_H
#define COSC_4P80_ASSIGNMENT_2_COMMON_H #define COSC_4P80_ASSIGNMENT_2_COMMON_H
#include <iostream>
#include <blt/iterator/enumerate.h>
namespace assign2 namespace assign2
{ {
using Scalar = float; using Scalar = float;
const inline Scalar learn_rate = 0.1;
template<typename T>
decltype(std::cout)& print_vec(const std::vector<T>& vec)
{
for (auto [i, v] : blt::enumerate(vec))
{
std::cout << v;
if (i != vec.size() - 1)
std::cout << ", ";
}
return std::cout;
}
struct data_t struct data_t
{ {
@ -36,15 +51,23 @@ namespace assign2
}; };
class layer_t; class layer_t;
class network_t; class network_t;
struct function_t
{
[[nodiscard]] virtual Scalar call(Scalar) const = 0;
[[nodiscard]] virtual Scalar derivative(Scalar) const = 0;
};
struct weight_view struct weight_view
{ {
public: public:
weight_view(double* data, blt::size_t size): m_data(data), m_size(size) weight_view(Scalar* data, blt::size_t size): m_data(data), m_size(size)
{} {}
inline double& operator[](blt::size_t index) const inline Scalar& operator[](blt::size_t index) const
{ {
#if BLT_DEBUG_LEVEL > 0 #if BLT_DEBUG_LEVEL > 0
if (index >= size) if (index >= size)
@ -69,7 +92,7 @@ namespace assign2
} }
private: private:
double* m_data; Scalar* m_data;
blt::size_t m_size; blt::size_t m_size;
}; };
@ -85,10 +108,17 @@ namespace assign2
data.resize(size + count); data.resize(size + count);
return {&data[size], count}; return {&data[size], count};
} }
void debug() const
{
std::cout << "Weights: ";
print_vec(data) << std::endl;
}
private: private:
std::vector<double> data; std::vector<Scalar> data;
}; };
} }
#endif //COSC_4P80_ASSIGNMENT_2_COMMON_H #endif //COSC_4P80_ASSIGNMENT_2_COMMON_H

View File

@ -24,22 +24,44 @@
namespace assign2 namespace assign2
{ {
struct sigmoid_function struct sigmoid_function : public function_t
{ {
[[nodiscard]] Scalar call(Scalar s) const // NOLINT [[nodiscard]] Scalar call(const Scalar s) const final
{ {
return 1 / (1 + std::exp(-s)); return 1 / (1 + std::exp(-s));
} }
[[nodiscard]] Scalar derivative(Scalar s) const [[nodiscard]] Scalar derivative(const Scalar s) const final
{ {
return call(s) * (1 - call(s)); auto v = call(s);
return v * (1 - v);
} }
}; };
struct linear_function struct threshold_function : public function_t
{ {
[[nodiscard]] Scalar call(const Scalar s) const final
{
return s >= 0 ? 1 : 0;
}
[[nodiscard]] Scalar derivative(Scalar s) const final
{
return 0;
}
};
struct relu_function : public function_t
{
[[nodiscard]] Scalar call(const Scalar s) const final
{
return std::max(static_cast<Scalar>(0), s);
}
[[nodiscard]] Scalar derivative(Scalar s) const final
{
return 0;
}
}; };
} }

View File

@ -28,6 +28,7 @@ namespace assign2
{ {
class neuron_t class neuron_t
{ {
friend layer_t;
public: public:
// empty neuron for loading from a stream // empty neuron for loading from a stream
explicit neuron_t(weight_view weights): weights(weights) explicit neuron_t(weight_view weights): weights(weights)
@ -37,13 +38,13 @@ namespace assign2
explicit neuron_t(weight_view weights, Scalar bias): bias(bias), weights(weights) explicit neuron_t(weight_view weights, Scalar bias): bias(bias), weights(weights)
{} {}
template<typename ActFunc> Scalar activate(const Scalar* inputs, function_t* act_func)
Scalar activate(const Scalar* inputs, ActFunc func) const
{ {
auto sum = bias; z = bias;
for (auto [x, w] : blt::zip_iterator_container({inputs, inputs + weights.size()}, {weights.begin(), weights.end()})) for (auto [x, w] : blt::zip_iterator_container({inputs, inputs + weights.size()}, {weights.begin(), weights.end()}))
sum += x * w; z += x * w;
return func.call(sum); a = act_func->call(z);
return a;
} }
template<typename OStream> template<typename OStream>
@ -61,9 +62,17 @@ namespace assign2
stream >> d; stream >> d;
stream >> bias; stream >> bias;
} }
void debug() const
{
std::cout << bias << " ";
}
private: private:
Scalar bias = 0; float z = 0;
float a = 0;
float bias = 0;
float error = 0;
weight_view weights; weight_view weights;
}; };
@ -71,7 +80,8 @@ namespace assign2
{ {
public: public:
template<typename WeightFunc, typename BiasFunc> template<typename WeightFunc, typename BiasFunc>
layer_t(const blt::i32 in, const blt::i32 out, WeightFunc w, BiasFunc b): in_size(in), out_size(out) layer_t(const blt::i32 in, const blt::i32 out, function_t* act_func, WeightFunc w, BiasFunc b):
in_size(in), out_size(out), act_func(act_func)
{ {
neurons.reserve(out_size); neurons.reserve(out_size);
for (blt::i32 i = 0; i < out_size; i++) for (blt::i32 i = 0; i < out_size; i++)
@ -83,8 +93,7 @@ namespace assign2
} }
} }
template<typename ActFunction> std::vector<Scalar> call(const std::vector<Scalar>& in)
std::vector<Scalar> call(const std::vector<Scalar>& in, ActFunction func = ActFunction{})
{ {
std::vector<Scalar> out; std::vector<Scalar> out;
out.reserve(out_size); out.reserve(out_size);
@ -93,10 +102,47 @@ namespace assign2
throw std::runtime_exception("Input vector doesn't match expected input size!"); throw std::runtime_exception("Input vector doesn't match expected input size!");
#endif #endif
for (auto& n : neurons) for (auto& n : neurons)
out.push_back(n.activate(in.data(), func)); out.push_back(n.activate(in.data(), act_func));
return out; return out;
} }
Scalar back_prop(const std::vector<Scalar>& prev_layer_output, Scalar error, const layer_t& next_layer, bool is_output)
{
std::vector<Scalar> dw;
// δ(h)
if (is_output)
{
// assign error to output layer
for (auto& n : neurons)
n.error = act_func->derivative(n.z) * error; // f'act(net(h)) * (error)
} else
{
// first calculate and assign input layer error
std::vector<Scalar> next_error;
next_error.resize(next_layer.neurons.size());
for (const auto& [i, w] : blt::enumerate(next_layer.neurons))
{
for (auto wv : w.weights)
next_error[i] += w.error * wv;
// needed?
next_error[i] /= static_cast<Scalar>(w.weights.size());
}
for (auto& n : neurons)
{
n.error = act_func->derivative(n.z);
}
}
for (const auto& v : prev_layer_output)
{
}
return error_at_current_layer;
}
template<typename OStream> template<typename OStream>
OStream& serialize(OStream& stream) OStream& serialize(OStream& stream)
{ {
@ -120,9 +166,20 @@ namespace assign2
{ {
return out_size; return out_size;
} }
void debug() const
{
std::cout << "Bias: ";
for (auto& v : neurons)
v.debug();
std::cout << std::endl;
weights.debug();
}
private: private:
const blt::i32 in_size, out_size; const blt::i32 in_size, out_size;
weight_t weights; weight_t weights;
function_t* act_func;
std::vector<neuron_t> neurons; std::vector<neuron_t> neurons;
}; };
} }

View File

@ -21,6 +21,7 @@
#include <assign2/common.h> #include <assign2/common.h>
#include <assign2/layer.h> #include <assign2/layer.h>
#include "blt/std/assert.h"
namespace assign2 namespace assign2
{ {
@ -75,8 +76,7 @@ namespace assign2
network_t() = default; network_t() = default;
template<typename ActFunc, typename ActFuncOut> std::vector<Scalar> execute(const std::vector<Scalar>& input)
std::vector<Scalar> execute(const std::vector<Scalar>& input, ActFunc func, ActFuncOut outFunc)
{ {
std::vector<Scalar> previous_output; std::vector<Scalar> previous_output;
std::vector<Scalar> current_output; std::vector<Scalar> current_output;
@ -85,39 +85,45 @@ namespace assign2
{ {
previous_output = current_output; previous_output = current_output;
if (i == 0) if (i == 0)
current_output = v.call(input, func); current_output = v.call(input);
else if (i == layers.size() - 1)
current_output = v.call(previous_output, outFunc);
else else
current_output = v.call(previous_output, func); current_output = v.call(previous_output);
} }
return current_output; return current_output;
} }
std::pair<Scalar, Scalar> error(const std::vector<Scalar>& outputs, bool is_bad)
{
BLT_ASSERT(outputs.size() == 2);
auto g = is_bad ? 0.0f : 1.0f;
auto b = is_bad ? 1.0f : 0.0f;
auto g_diff = outputs[0] - g;
auto b_diff = outputs[1] - b;
auto error = g_diff * g_diff + b_diff * b_diff;
BLT_INFO("%f %f %f", error, g_diff, b_diff);
return {0.5f * (error * error), error};
}
Scalar train(const data_file_t& example) Scalar train(const data_file_t& example)
{ {
const Scalar learn_rate = 0.1;
Scalar total_error = 0; Scalar total_error = 0;
Scalar total_d_error = 0;
for (const auto& x : example.data_points) for (const auto& x : example.data_points)
{ {
auto o = execute(x.bins, sigmoid_function{}, sigmoid_function{}); print_vec(x.bins) << std::endl;
auto y = x.is_bad ? 1.0f : 0.0f; auto o = execute(x.bins);
print_vec(o) << std::endl;
Scalar is_bad = 0; auto [e, de] = error(o, x.is_bad);
if (o[0] >= 1) total_error += e;
is_bad = 0; total_d_error += -learn_rate * de;
else if (o[1] >= 1) BLT_TRACE("\tError %f, %f, is bad? %s", e, -learn_rate * de, x.is_bad ? "True" : "False");
is_bad = 1;
auto error = y - is_bad;
if (o[0] >= 1 && o[1] >= 1)
error += 1;
total_error += error;
} }
BLT_DEBUG("Total Errors found %f, %f", total_error, total_d_error);
return total_error; return total_error;
} }

1
lib/blt-graphics Submodule

@ -0,0 +1 @@
Subproject commit 8103a3ad0ff5341c61ba07c4b1b9803b2cc740e9

View File

@ -219,9 +219,9 @@ public:
, output_dims(output_dims_), output_strides(output_strides_) , output_dims(output_dims_), output_strides(output_strides_)
{} {}
void operator()(const Scalar* output_data) const void operator()(const Scalar* neuron_data) const
{ {
check_recursive(input_data, output_data); check_recursive(input_data, neuron_data);
} }
}; };

View File

@ -86,7 +86,7 @@ static void test_block_io_copy_data_from_source_to_target() {
auto output_strides = internal::strides<Layout>(dims); auto output_strides = internal::strides<Layout>(dims);
const T* input_data = input.data(); const T* input_data = input.data();
T* output_data = output.data(); T* neuron_data = output.data();
T* block_data = block.data(); T* block_data = block.data();
for (int i = 0; i < block_mapper.blockCount(); ++i) { for (int i = 0; i < block_mapper.blockCount(); ++i) {
@ -105,7 +105,7 @@ static void test_block_io_copy_data_from_source_to_target() {
{ {
// Write from block buffer to output. // Write from block buffer to output.
IODst dst(blk_dims, output_strides, output_data, desc.offset()); IODst dst(blk_dims, output_strides, neuron_data, desc.offset());
IOSrc src(blk_strides, block_data, 0); IOSrc src(blk_strides, block_data, 0);
TensorBlockIO::Copy(dst, src); TensorBlockIO::Copy(dst, src);
@ -113,7 +113,7 @@ static void test_block_io_copy_data_from_source_to_target() {
} }
for (int i = 0; i < dims.TotalSize(); ++i) { for (int i = 0; i < dims.TotalSize(); ++i) {
VERIFY_IS_EQUAL(input_data[i], output_data[i]); VERIFY_IS_EQUAL(input_data[i], neuron_data[i]);
} }
} }
@ -159,7 +159,7 @@ static void test_block_io_copy_using_reordered_dimensions() {
auto output_strides = internal::strides<Layout>(output_tensor_dims); auto output_strides = internal::strides<Layout>(output_tensor_dims);
const T* input_data = input.data(); const T* input_data = input.data();
T* output_data = output.data(); T* neuron_data = output.data();
T* block_data = block.data(); T* block_data = block.data();
for (Index i = 0; i < block_mapper.blockCount(); ++i) { for (Index i = 0; i < block_mapper.blockCount(); ++i) {
@ -198,7 +198,7 @@ static void test_block_io_copy_using_reordered_dimensions() {
} }
// Write from block buffer to output. // Write from block buffer to output.
IODst dst(dst_dims, input_strides, output_data, first_coeff_index); IODst dst(dst_dims, input_strides, neuron_data, first_coeff_index);
IOSrc src(blk_strides, block_data, 0); IOSrc src(blk_strides, block_data, 0);
// TODO(ezhulenev): Remove when fully switched to TensorBlock. // TODO(ezhulenev): Remove when fully switched to TensorBlock.
@ -210,7 +210,7 @@ static void test_block_io_copy_using_reordered_dimensions() {
} }
for (Index i = 0; i < dims.TotalSize(); ++i) { for (Index i = 0; i < dims.TotalSize(); ++i) {
VERIFY_IS_EQUAL(input_data[i], output_data[i]); VERIFY_IS_EQUAL(input_data[i], neuron_data[i]);
} }
} }

View File

@ -68,18 +68,6 @@ std::vector<data_file_t> load_data_files(const std::vector<std::string>& files)
return loaded_data; return loaded_data;
} }
template<typename T>
decltype(std::cout)& print_vec(const std::vector<T>& vec)
{
for (auto [i, v] : blt::enumerate(vec))
{
std::cout << v;
if (i != vec.size() - 1)
std::cout << ", ";
}
return std::cout;
}
int main(int argc, const char** argv) int main(int argc, const char** argv)
{ {
blt::arg_parse parser; blt::arg_parse parser;
@ -90,6 +78,22 @@ int main(int argc, const char** argv)
auto data_files = load_data_files(get_data_files(data_directory)); auto data_files = load_data_files(get_data_files(data_directory));
random_init randomizer{619};
sigmoid_function sig;
relu_function relu;
threshold_function thresh;
layer_t layer1{16, 8, &sig, randomizer, randomizer};
layer1.debug();
layer_t layer2{8, 8, &sig, randomizer, randomizer};
layer2.debug();
layer_t layer3{8, 8, &sig, randomizer, randomizer};
layer3.debug();
layer_t layer_output{8, 2, &relu, randomizer, randomizer};
layer_output.debug();
network_t network{{layer1, layer2, layer3, layer_output}};
std::vector<Scalar> input; std::vector<Scalar> input;
input.resize(16); input.resize(16);
for (auto f : data_files) for (auto f : data_files)
@ -98,20 +102,12 @@ int main(int argc, const char** argv)
{ {
for (auto [i, b] : blt::enumerate(f.data_points.begin()->bins)) for (auto [i, b] : blt::enumerate(f.data_points.begin()->bins))
input[i] = b; input[i] = b;
network.train(f);
break;
} }
} }
random_init randomizer{619}; auto output = network.execute(input);
sigmoid_function sig;
layer_t layer1{16, 4, randomizer, empty_init{}};
layer_t layer2{4, 4, randomizer, empty_init{}};
layer_t layer3{4, 4, randomizer, empty_init{}};
layer_t layer_output{4, 1, randomizer, empty_init{}};
network_t network{{layer1, layer2, layer3, layer_output}};
auto output = network.execute(input, sig, sig);
print_vec(output) << std::endl; print_vec(output) << std::endl;
// for (auto d : data_files) // for (auto d : data_files)