getting closer!
parent
1b79238114
commit
4216b53b28
|
@ -1,3 +1,6 @@
|
||||||
[submodule "lib/blt"]
|
[submodule "lib/blt"]
|
||||||
path = lib/blt
|
path = lib/blt
|
||||||
url = https://github.com/Tri11Paragon/BLT.git
|
url = https://github.com/Tri11Paragon/BLT.git
|
||||||
|
[submodule "lib/blt-graphics"]
|
||||||
|
path = lib/blt-graphics
|
||||||
|
url = https://git.tpgc.me/tri11paragon/BLT-With-Graphics-Template
|
||||||
|
|
|
@ -3,6 +3,7 @@
|
||||||
<component name="VcsDirectoryMappings">
|
<component name="VcsDirectoryMappings">
|
||||||
<mapping directory="" vcs="Git" />
|
<mapping directory="" vcs="Git" />
|
||||||
<mapping directory="$PROJECT_DIR$/lib/blt" vcs="Git" />
|
<mapping directory="$PROJECT_DIR$/lib/blt" vcs="Git" />
|
||||||
|
<mapping directory="$PROJECT_DIR$/lib/blt-graphics" vcs="Git" />
|
||||||
<mapping directory="$PROJECT_DIR$/lib/blt/libraries/parallel-hashmap" vcs="Git" />
|
<mapping directory="$PROJECT_DIR$/lib/blt/libraries/parallel-hashmap" vcs="Git" />
|
||||||
</component>
|
</component>
|
||||||
</project>
|
</project>
|
|
@ -0,0 +1,9 @@
|
||||||
|
Bias: 0.0883882 0.498606 -0.192237 -0.225643 -0.202986 0.337387 0.492392 -0.378505
|
||||||
|
Weights: 0.0883882, 0.498606, -0.192237, -0.225643, -0.202986, 0.337387, 0.492392, -0.378505, 0.174707, 0.289964, -0.176566, -0.0890466, 0.464011, 0.0977563, -0.0747116, -0.179186, -0.387672, 0.156685, 0.396763, 0.163871, -0.347555, 0.161215, 0.40124, 0.121386, -0.0582743, -0.178808, 0.443577, 0.0854848, 0.0230515, 0.343403, 0.201804, 0.101254, 0.413756, 0.228725, -0.392687, 0.305022, -0.0608484, 0.0060341, 0.393663, -0.0173984, 0.0559042, 0.357518, 0.30257, 0.376668, 0.14234, -0.0381443, 0.0163791, -0.180164, 0.195689, 0.0580821, -0.16776, 0.0569814, -0.271095, 0.223116, -0.475364, 0.0926328, 0.0061297, -0.338429, 0.393725, 0.0131179, -0.207723, -0.296026, 0.0254685, -0.240729, 0.290662, 0.180188, -0.306817, 0.00927613, 0.37434, 0.267375, -0.15453, -0.259941, -0.0329586, 0.0232179, -0.0882589, 0.211924, 0.302485, -0.465395, 0.297113, 0.0298, -0.255424, 0.0950433, -0.493201, 0.387198, -0.218439, -0.0878063, -0.27663, -0.18966, -0.348982, -0.210474, -0.232789, -0.0718565, -0.401074, -0.321769, -0.225299, -0.138247, 0.0411653, -0.20941, 0.468472, 0.0649849, -0.147813, 0.260647, -0.495487, 0.353673, -0.0956114, -0.0889727, 0.443833, 0.0468353, -0.487945, 0.200253, 0.108264, 0.0764869, -0.151083, -0.0761, 0.399845, -0.471762, -0.267245, -0.357395, -0.330219, 0.306514, -0.244428, -0.499103, -0.456246, 0.0216964, 0.490225, -0.178489, -0.118379, 0.477028
|
||||||
|
Bias: 0.0883882 0.498606 -0.192237 -0.225643 -0.202986 0.337387 0.492392 -0.378505
|
||||||
|
Weights: 0.0883882, 0.498606, -0.192237, -0.225643, -0.202986, 0.337387, 0.492392, -0.378505, 0.174707, 0.289964, -0.176566, -0.0890466, 0.464011, 0.0977563, -0.0747116, -0.179186, -0.387672, 0.156685, 0.396763, 0.163871, -0.347555, 0.161215, 0.40124, 0.121386, -0.0582743, -0.178808, 0.443577, 0.0854848, 0.0230515, 0.343403, 0.201804, 0.101254, 0.413756, 0.228725, -0.392687, 0.305022, -0.0608484, 0.0060341, 0.393663, -0.0173984, 0.0559042, 0.357518, 0.30257, 0.376668, 0.14234, -0.0381443, 0.0163791, -0.180164, 0.195689, 0.0580821, -0.16776, 0.0569814, -0.271095, 0.223116, -0.475364, 0.0926328, 0.0061297, -0.338429, 0.393725, 0.0131179, -0.207723, -0.296026, 0.0254685, -0.240729
|
||||||
|
Bias: 0.0883882 0.498606 -0.192237 -0.225643 -0.202986 0.337387 0.492392 -0.378505
|
||||||
|
Weights: 0.0883882, 0.498606, -0.192237, -0.225643, -0.202986, 0.337387, 0.492392, -0.378505, 0.174707, 0.289964, -0.176566, -0.0890466, 0.464011, 0.0977563, -0.0747116, -0.179186, -0.387672, 0.156685, 0.396763, 0.163871, -0.347555, 0.161215, 0.40124, 0.121386, -0.0582743, -0.178808, 0.443577, 0.0854848, 0.0230515, 0.343403, 0.201804, 0.101254, 0.413756, 0.228725, -0.392687, 0.305022, -0.0608484, 0.0060341, 0.393663, -0.0173984, 0.0559042, 0.357518, 0.30257, 0.376668, 0.14234, -0.0381443, 0.0163791, -0.180164, 0.195689, 0.0580821, -0.16776, 0.0569814, -0.271095, 0.223116, -0.475364, 0.0926328, 0.0061297, -0.338429, 0.393725, 0.0131179, -0.207723, -0.296026, 0.0254685, -0.240729
|
||||||
|
Bias: 0.0883882 0.498606
|
||||||
|
Weights: 0.0883882, 0.498606, -0.192237, -0.225643, -0.202986, 0.337387, 0.492392, -0.378505, 0.174707, 0.289964, -0.176566, -0.0890466, 0.464011, 0.0977563, -0.0747116, -0.179186
|
||||||
|
389.05, 133.34, 62.14, 32.6, 36.97, 29.03, 62.28, 114.07, 79.41, 212.39, 70.29, 36.73, 31.3, 76.01, 103.99, 214.09
|
|
@ -0,0 +1,9 @@
|
||||||
|
Bias: 0.0883882 0.498606 -0.192237 -0.225643 -0.202986 0.337387 0.492392 -0.378505
|
||||||
|
Weights: 0.0883882, 0.498606, -0.192237, -0.225643, -0.202986, 0.337387, 0.492392, -0.378505, 0.174707, 0.289964, -0.176566, -0.0890466, 0.464011, 0.0977563, -0.0747116, -0.179186, -0.387672, 0.156685, 0.396763, 0.163871, -0.347555, 0.161215, 0.40124, 0.121386, -0.0582743, -0.178808, 0.443577, 0.0854848, 0.0230515, 0.343403, 0.201804, 0.101254, 0.413756, 0.228725, -0.392687, 0.305022, -0.0608484, 0.0060341, 0.393663, -0.0173984, 0.0559042, 0.357518, 0.30257, 0.376668, 0.14234, -0.0381443, 0.0163791, -0.180164, 0.195689, 0.0580821, -0.16776, 0.0569814, -0.271095, 0.223116, -0.475364, 0.0926328, 0.0061297, -0.338429, 0.393725, 0.0131179, -0.207723, -0.296026, 0.0254685, -0.240729, 0.290662, 0.180188, -0.306817, 0.00927613, 0.37434, 0.267375, -0.15453, -0.259941, -0.0329586, 0.0232179, -0.0882589, 0.211924, 0.302485, -0.465395, 0.297113, 0.0298, -0.255424, 0.0950433, -0.493201, 0.387198, -0.218439, -0.0878063, -0.27663, -0.18966, -0.348982, -0.210474, -0.232789, -0.0718565, -0.401074, -0.321769, -0.225299, -0.138247, 0.0411653, -0.20941, 0.468472, 0.0649849, -0.147813, 0.260647, -0.495487, 0.353673, -0.0956114, -0.0889727, 0.443833, 0.0468353, -0.487945, 0.200253, 0.108264, 0.0764869, -0.151083, -0.0761, 0.399845, -0.471762, -0.267245, -0.357395, -0.330219, 0.306514, -0.244428, -0.499103, -0.456246, 0.0216964, 0.490225, -0.178489, -0.118379, 0.477028
|
||||||
|
Bias: 0.0883882 0.498606 -0.192237 -0.225643 -0.202986 0.337387 0.492392 -0.378505
|
||||||
|
Weights: 0.0883882, 0.498606, -0.192237, -0.225643, -0.202986, 0.337387, 0.492392, -0.378505, 0.174707, 0.289964, -0.176566, -0.0890466, 0.464011, 0.0977563, -0.0747116, -0.179186, -0.387672, 0.156685, 0.396763, 0.163871, -0.347555, 0.161215, 0.40124, 0.121386, -0.0582743, -0.178808, 0.443577, 0.0854848, 0.0230515, 0.343403, 0.201804, 0.101254, 0.413756, 0.228725, -0.392687, 0.305022, -0.0608484, 0.0060341, 0.393663, -0.0173984, 0.0559042, 0.357518, 0.30257, 0.376668, 0.14234, -0.0381443, 0.0163791, -0.180164, 0.195689, 0.0580821, -0.16776, 0.0569814, -0.271095, 0.223116, -0.475364, 0.0926328, 0.0061297, -0.338429, 0.393725, 0.0131179, -0.207723, -0.296026, 0.0254685, -0.240729
|
||||||
|
Bias: 0.0883882 0.498606 -0.192237 -0.225643 -0.202986 0.337387 0.492392 -0.378505
|
||||||
|
Weights: 0.0883882, 0.498606, -0.192237, -0.225643, -0.202986, 0.337387, 0.492392, -0.378505, 0.174707, 0.289964, -0.176566, -0.0890466, 0.464011, 0.0977563, -0.0747116, -0.179186, -0.387672, 0.156685, 0.396763, 0.163871, -0.347555, 0.161215, 0.40124, 0.121386, -0.0582743, -0.178808, 0.443577, 0.0854848, 0.0230515, 0.343403, 0.201804, 0.101254, 0.413756, 0.228725, -0.392687, 0.305022, -0.0608484, 0.0060341, 0.393663, -0.0173984, 0.0559042, 0.357518, 0.30257, 0.376668, 0.14234, -0.0381443, 0.0163791, -0.180164, 0.195689, 0.0580821, -0.16776, 0.0569814, -0.271095, 0.223116, -0.475364, 0.0926328, 0.0061297, -0.338429, 0.393725, 0.0131179, -0.207723, -0.296026, 0.0254685, -0.240729
|
||||||
|
Bias: 0.0883882 0.498606
|
||||||
|
Weights: 0.0883882, 0.498606, -0.192237, -0.225643, -0.202986, 0.337387, 0.492392, -0.378505, 0.174707, 0.289964, -0.176566, -0.0890466, 0.464011, 0.0977563, -0.0747116, -0.179186
|
||||||
|
389.05, 133.34, 62.14, 32.6, 36.97, 29.03, 62.28, 114.07, 79.41, 212.39, 70.29, 36.73, 31.3, 76.01, 103.99, 214.09
|
|
@ -0,0 +1,9 @@
|
||||||
|
Bias: 0.0883882 0.498606 -0.192237 -0.225643 -0.202986 0.337387 0.492392 -0.378505
|
||||||
|
Weights: 0.0883882, 0.498606, -0.192237, -0.225643, -0.202986, 0.337387, 0.492392, -0.378505, 0.174707, 0.289964, -0.176566, -0.0890466, 0.464011, 0.0977563, -0.0747116, -0.179186, -0.387672, 0.156685, 0.396763, 0.163871, -0.347555, 0.161215, 0.40124, 0.121386, -0.0582743, -0.178808, 0.443577, 0.0854848, 0.0230515, 0.343403, 0.201804, 0.101254, 0.413756, 0.228725, -0.392687, 0.305022, -0.0608484, 0.0060341, 0.393663, -0.0173984, 0.0559042, 0.357518, 0.30257, 0.376668, 0.14234, -0.0381443, 0.0163791, -0.180164, 0.195689, 0.0580821, -0.16776, 0.0569814, -0.271095, 0.223116, -0.475364, 0.0926328, 0.0061297, -0.338429, 0.393725, 0.0131179, -0.207723, -0.296026, 0.0254685, -0.240729, 0.290662, 0.180188, -0.306817, 0.00927613, 0.37434, 0.267375, -0.15453, -0.259941, -0.0329586, 0.0232179, -0.0882589, 0.211924, 0.302485, -0.465395, 0.297113, 0.0298, -0.255424, 0.0950433, -0.493201, 0.387198, -0.218439, -0.0878063, -0.27663, -0.18966, -0.348982, -0.210474, -0.232789, -0.0718565, -0.401074, -0.321769, -0.225299, -0.138247, 0.0411653, -0.20941, 0.468472, 0.0649849, -0.147813, 0.260647, -0.495487, 0.353673, -0.0956114, -0.0889727, 0.443833, 0.0468353, -0.487945, 0.200253, 0.108264, 0.0764869, -0.151083, -0.0761, 0.399845, -0.471762, -0.267245, -0.357395, -0.330219, 0.306514, -0.244428, -0.499103, -0.456246, 0.0216964, 0.490225, -0.178489, -0.118379, 0.477028
|
||||||
|
Bias: 0.0883882 0.498606 -0.192237 -0.225643 -0.202986 0.337387 0.492392 -0.378505
|
||||||
|
Weights: 0.0883882, 0.498606, -0.192237, -0.225643, -0.202986, 0.337387, 0.492392, -0.378505, 0.174707, 0.289964, -0.176566, -0.0890466, 0.464011, 0.0977563, -0.0747116, -0.179186, -0.387672, 0.156685, 0.396763, 0.163871, -0.347555, 0.161215, 0.40124, 0.121386, -0.0582743, -0.178808, 0.443577, 0.0854848, 0.0230515, 0.343403, 0.201804, 0.101254, 0.413756, 0.228725, -0.392687, 0.305022, -0.0608484, 0.0060341, 0.393663, -0.0173984, 0.0559042, 0.357518, 0.30257, 0.376668, 0.14234, -0.0381443, 0.0163791, -0.180164, 0.195689, 0.0580821, -0.16776, 0.0569814, -0.271095, 0.223116, -0.475364, 0.0926328, 0.0061297, -0.338429, 0.393725, 0.0131179, -0.207723, -0.296026, 0.0254685, -0.240729
|
||||||
|
Bias: 0.0883882 0.498606 -0.192237 -0.225643 -0.202986 0.337387 0.492392 -0.378505
|
||||||
|
Weights: 0.0883882, 0.498606, -0.192237, -0.225643, -0.202986, 0.337387, 0.492392, -0.378505, 0.174707, 0.289964, -0.176566, -0.0890466, 0.464011, 0.0977563, -0.0747116, -0.179186, -0.387672, 0.156685, 0.396763, 0.163871, -0.347555, 0.161215, 0.40124, 0.121386, -0.0582743, -0.178808, 0.443577, 0.0854848, 0.0230515, 0.343403, 0.201804, 0.101254, 0.413756, 0.228725, -0.392687, 0.305022, -0.0608484, 0.0060341, 0.393663, -0.0173984, 0.0559042, 0.357518, 0.30257, 0.376668, 0.14234, -0.0381443, 0.0163791, -0.180164, 0.195689, 0.0580821, -0.16776, 0.0569814, -0.271095, 0.223116, -0.475364, 0.0926328, 0.0061297, -0.338429, 0.393725, 0.0131179, -0.207723, -0.296026, 0.0254685, -0.240729
|
||||||
|
Bias: 0.0883882 0.498606
|
||||||
|
Weights: 0.0883882, 0.498606, -0.192237, -0.225643, -0.202986, 0.337387, 0.492392, -0.378505, 0.174707, 0.289964, -0.176566, -0.0890466, 0.464011, 0.0977563, -0.0747116, -0.179186
|
||||||
|
389.05, 133.34, 62.14, 32.6, 36.97, 29.03, 62.28, 114.07, 79.41, 212.39, 70.29, 36.73, 31.3, 76.01, 103.99, 214.09
|
|
@ -1,9 +1,10 @@
|
||||||
cmake_minimum_required(VERSION 3.25)
|
cmake_minimum_required(VERSION 3.25)
|
||||||
project(COSC-4P80-Assignment-2 VERSION 0.0.3)
|
project(COSC-4P80-Assignment-2 VERSION 0.0.4)
|
||||||
|
|
||||||
option(ENABLE_ADDRSAN "Enable the address sanitizer" OFF)
|
option(ENABLE_ADDRSAN "Enable the address sanitizer" OFF)
|
||||||
option(ENABLE_UBSAN "Enable the ub sanitizer" OFF)
|
option(ENABLE_UBSAN "Enable the ub sanitizer" OFF)
|
||||||
option(ENABLE_TSAN "Enable the thread data race sanitizer" OFF)
|
option(ENABLE_TSAN "Enable the thread data race sanitizer" OFF)
|
||||||
|
option(ENABLE_GRAPHICS "Enable usage of graphics package" OFF)
|
||||||
#option(EIGEN_TEST_CXX11 "Enable testing with C++11 and C++11 features (e.g. Tensor module)." ON)
|
#option(EIGEN_TEST_CXX11 "Enable testing with C++11 and C++11 features (e.g. Tensor module)." ON)
|
||||||
|
|
||||||
set(CMAKE_CXX_STANDARD 17)
|
set(CMAKE_CXX_STANDARD 17)
|
||||||
|
@ -12,7 +13,12 @@ if (NOT CMAKE_BUILD_TYPE)
|
||||||
set(CMAKE_BUILD_TYPE "Release")
|
set(CMAKE_BUILD_TYPE "Release")
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
add_subdirectory(lib/blt)
|
if (ENABLE_GRAPHICS)
|
||||||
|
add_subdirectory(lib/blt-graphics)
|
||||||
|
add_compile_definitions(BLT_USE_GRAPHICS)
|
||||||
|
else ()
|
||||||
|
add_subdirectory(lib/blt)
|
||||||
|
endif ()
|
||||||
|
|
||||||
#add_subdirectory(lib/eigen-3.4.0)
|
#add_subdirectory(lib/eigen-3.4.0)
|
||||||
|
|
||||||
|
@ -25,7 +31,11 @@ target_compile_options(COSC-4P80-Assignment-2 PRIVATE -Wall -Wextra -Wpedantic -
|
||||||
target_link_options(COSC-4P80-Assignment-2 PRIVATE -Wall -Wextra -Wpedantic -Wno-comment)
|
target_link_options(COSC-4P80-Assignment-2 PRIVATE -Wall -Wextra -Wpedantic -Wno-comment)
|
||||||
|
|
||||||
#target_link_libraries(COSC-4P80-Assignment-2 PRIVATE BLT Eigen3::Eigen)
|
#target_link_libraries(COSC-4P80-Assignment-2 PRIVATE BLT Eigen3::Eigen)
|
||||||
target_link_libraries(COSC-4P80-Assignment-2 PRIVATE BLT)
|
if (ENABLE_GRAPHICS)
|
||||||
|
target_link_libraries(COSC-4P80-Assignment-2 PRIVATE BLT_WITH_GRAPHICS)
|
||||||
|
else ()
|
||||||
|
target_link_libraries(COSC-4P80-Assignment-2 PRIVATE BLT)
|
||||||
|
endif ()
|
||||||
|
|
||||||
if (${ENABLE_ADDRSAN} MATCHES ON)
|
if (${ENABLE_ADDRSAN} MATCHES ON)
|
||||||
target_compile_options(COSC-4P80-Assignment-2 PRIVATE -fsanitize=address)
|
target_compile_options(COSC-4P80-Assignment-2 PRIVATE -fsanitize=address)
|
||||||
|
|
|
@ -19,10 +19,25 @@
|
||||||
#ifndef COSC_4P80_ASSIGNMENT_2_COMMON_H
|
#ifndef COSC_4P80_ASSIGNMENT_2_COMMON_H
|
||||||
#define COSC_4P80_ASSIGNMENT_2_COMMON_H
|
#define COSC_4P80_ASSIGNMENT_2_COMMON_H
|
||||||
|
|
||||||
|
#include <iostream>
|
||||||
|
#include <blt/iterator/enumerate.h>
|
||||||
|
|
||||||
namespace assign2
|
namespace assign2
|
||||||
{
|
{
|
||||||
using Scalar = float;
|
using Scalar = float;
|
||||||
|
const inline Scalar learn_rate = 0.1;
|
||||||
|
|
||||||
|
template<typename T>
|
||||||
|
decltype(std::cout)& print_vec(const std::vector<T>& vec)
|
||||||
|
{
|
||||||
|
for (auto [i, v] : blt::enumerate(vec))
|
||||||
|
{
|
||||||
|
std::cout << v;
|
||||||
|
if (i != vec.size() - 1)
|
||||||
|
std::cout << ", ";
|
||||||
|
}
|
||||||
|
return std::cout;
|
||||||
|
}
|
||||||
|
|
||||||
struct data_t
|
struct data_t
|
||||||
{
|
{
|
||||||
|
@ -36,15 +51,23 @@ namespace assign2
|
||||||
};
|
};
|
||||||
|
|
||||||
class layer_t;
|
class layer_t;
|
||||||
|
|
||||||
class network_t;
|
class network_t;
|
||||||
|
|
||||||
|
struct function_t
|
||||||
|
{
|
||||||
|
[[nodiscard]] virtual Scalar call(Scalar) const = 0;
|
||||||
|
|
||||||
|
[[nodiscard]] virtual Scalar derivative(Scalar) const = 0;
|
||||||
|
};
|
||||||
|
|
||||||
struct weight_view
|
struct weight_view
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
weight_view(double* data, blt::size_t size): m_data(data), m_size(size)
|
weight_view(Scalar* data, blt::size_t size): m_data(data), m_size(size)
|
||||||
{}
|
{}
|
||||||
|
|
||||||
inline double& operator[](blt::size_t index) const
|
inline Scalar& operator[](blt::size_t index) const
|
||||||
{
|
{
|
||||||
#if BLT_DEBUG_LEVEL > 0
|
#if BLT_DEBUG_LEVEL > 0
|
||||||
if (index >= size)
|
if (index >= size)
|
||||||
|
@ -69,7 +92,7 @@ namespace assign2
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
double* m_data;
|
Scalar* m_data;
|
||||||
blt::size_t m_size;
|
blt::size_t m_size;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -85,10 +108,17 @@ namespace assign2
|
||||||
data.resize(size + count);
|
data.resize(size + count);
|
||||||
return {&data[size], count};
|
return {&data[size], count};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void debug() const
|
||||||
|
{
|
||||||
|
std::cout << "Weights: ";
|
||||||
|
print_vec(data) << std::endl;
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::vector<double> data;
|
std::vector<Scalar> data;
|
||||||
};
|
};
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#endif //COSC_4P80_ASSIGNMENT_2_COMMON_H
|
#endif //COSC_4P80_ASSIGNMENT_2_COMMON_H
|
||||||
|
|
|
@ -24,22 +24,44 @@
|
||||||
|
|
||||||
namespace assign2
|
namespace assign2
|
||||||
{
|
{
|
||||||
struct sigmoid_function
|
struct sigmoid_function : public function_t
|
||||||
{
|
{
|
||||||
[[nodiscard]] Scalar call(Scalar s) const // NOLINT
|
[[nodiscard]] Scalar call(const Scalar s) const final
|
||||||
{
|
{
|
||||||
return 1 / (1 + std::exp(-s));
|
return 1 / (1 + std::exp(-s));
|
||||||
}
|
}
|
||||||
|
|
||||||
[[nodiscard]] Scalar derivative(Scalar s) const
|
[[nodiscard]] Scalar derivative(const Scalar s) const final
|
||||||
{
|
{
|
||||||
return call(s) * (1 - call(s));
|
auto v = call(s);
|
||||||
|
return v * (1 - v);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct linear_function
|
struct threshold_function : public function_t
|
||||||
{
|
{
|
||||||
|
[[nodiscard]] Scalar call(const Scalar s) const final
|
||||||
|
{
|
||||||
|
return s >= 0 ? 1 : 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
[[nodiscard]] Scalar derivative(Scalar s) const final
|
||||||
|
{
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct relu_function : public function_t
|
||||||
|
{
|
||||||
|
[[nodiscard]] Scalar call(const Scalar s) const final
|
||||||
|
{
|
||||||
|
return std::max(static_cast<Scalar>(0), s);
|
||||||
|
}
|
||||||
|
|
||||||
|
[[nodiscard]] Scalar derivative(Scalar s) const final
|
||||||
|
{
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -28,6 +28,7 @@ namespace assign2
|
||||||
{
|
{
|
||||||
class neuron_t
|
class neuron_t
|
||||||
{
|
{
|
||||||
|
friend layer_t;
|
||||||
public:
|
public:
|
||||||
// empty neuron for loading from a stream
|
// empty neuron for loading from a stream
|
||||||
explicit neuron_t(weight_view weights): weights(weights)
|
explicit neuron_t(weight_view weights): weights(weights)
|
||||||
|
@ -37,13 +38,13 @@ namespace assign2
|
||||||
explicit neuron_t(weight_view weights, Scalar bias): bias(bias), weights(weights)
|
explicit neuron_t(weight_view weights, Scalar bias): bias(bias), weights(weights)
|
||||||
{}
|
{}
|
||||||
|
|
||||||
template<typename ActFunc>
|
Scalar activate(const Scalar* inputs, function_t* act_func)
|
||||||
Scalar activate(const Scalar* inputs, ActFunc func) const
|
|
||||||
{
|
{
|
||||||
auto sum = bias;
|
z = bias;
|
||||||
for (auto [x, w] : blt::zip_iterator_container({inputs, inputs + weights.size()}, {weights.begin(), weights.end()}))
|
for (auto [x, w] : blt::zip_iterator_container({inputs, inputs + weights.size()}, {weights.begin(), weights.end()}))
|
||||||
sum += x * w;
|
z += x * w;
|
||||||
return func.call(sum);
|
a = act_func->call(z);
|
||||||
|
return a;
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename OStream>
|
template<typename OStream>
|
||||||
|
@ -61,9 +62,17 @@ namespace assign2
|
||||||
stream >> d;
|
stream >> d;
|
||||||
stream >> bias;
|
stream >> bias;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void debug() const
|
||||||
|
{
|
||||||
|
std::cout << bias << " ";
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
Scalar bias = 0;
|
float z = 0;
|
||||||
|
float a = 0;
|
||||||
|
float bias = 0;
|
||||||
|
float error = 0;
|
||||||
weight_view weights;
|
weight_view weights;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -71,7 +80,8 @@ namespace assign2
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
template<typename WeightFunc, typename BiasFunc>
|
template<typename WeightFunc, typename BiasFunc>
|
||||||
layer_t(const blt::i32 in, const blt::i32 out, WeightFunc w, BiasFunc b): in_size(in), out_size(out)
|
layer_t(const blt::i32 in, const blt::i32 out, function_t* act_func, WeightFunc w, BiasFunc b):
|
||||||
|
in_size(in), out_size(out), act_func(act_func)
|
||||||
{
|
{
|
||||||
neurons.reserve(out_size);
|
neurons.reserve(out_size);
|
||||||
for (blt::i32 i = 0; i < out_size; i++)
|
for (blt::i32 i = 0; i < out_size; i++)
|
||||||
|
@ -83,8 +93,7 @@ namespace assign2
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename ActFunction>
|
std::vector<Scalar> call(const std::vector<Scalar>& in)
|
||||||
std::vector<Scalar> call(const std::vector<Scalar>& in, ActFunction func = ActFunction{})
|
|
||||||
{
|
{
|
||||||
std::vector<Scalar> out;
|
std::vector<Scalar> out;
|
||||||
out.reserve(out_size);
|
out.reserve(out_size);
|
||||||
|
@ -93,10 +102,47 @@ namespace assign2
|
||||||
throw std::runtime_exception("Input vector doesn't match expected input size!");
|
throw std::runtime_exception("Input vector doesn't match expected input size!");
|
||||||
#endif
|
#endif
|
||||||
for (auto& n : neurons)
|
for (auto& n : neurons)
|
||||||
out.push_back(n.activate(in.data(), func));
|
out.push_back(n.activate(in.data(), act_func));
|
||||||
return out;
|
return out;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Scalar back_prop(const std::vector<Scalar>& prev_layer_output, Scalar error, const layer_t& next_layer, bool is_output)
|
||||||
|
{
|
||||||
|
std::vector<Scalar> dw;
|
||||||
|
|
||||||
|
// δ(h)
|
||||||
|
if (is_output)
|
||||||
|
{
|
||||||
|
// assign error to output layer
|
||||||
|
for (auto& n : neurons)
|
||||||
|
n.error = act_func->derivative(n.z) * error; // f'act(net(h)) * (error)
|
||||||
|
} else
|
||||||
|
{
|
||||||
|
// first calculate and assign input layer error
|
||||||
|
std::vector<Scalar> next_error;
|
||||||
|
next_error.resize(next_layer.neurons.size());
|
||||||
|
for (const auto& [i, w] : blt::enumerate(next_layer.neurons))
|
||||||
|
{
|
||||||
|
for (auto wv : w.weights)
|
||||||
|
next_error[i] += w.error * wv;
|
||||||
|
// needed?
|
||||||
|
next_error[i] /= static_cast<Scalar>(w.weights.size());
|
||||||
|
}
|
||||||
|
|
||||||
|
for (auto& n : neurons)
|
||||||
|
{
|
||||||
|
n.error = act_func->derivative(n.z);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (const auto& v : prev_layer_output)
|
||||||
|
{
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
return error_at_current_layer;
|
||||||
|
}
|
||||||
|
|
||||||
template<typename OStream>
|
template<typename OStream>
|
||||||
OStream& serialize(OStream& stream)
|
OStream& serialize(OStream& stream)
|
||||||
{
|
{
|
||||||
|
@ -120,9 +166,20 @@ namespace assign2
|
||||||
{
|
{
|
||||||
return out_size;
|
return out_size;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void debug() const
|
||||||
|
{
|
||||||
|
std::cout << "Bias: ";
|
||||||
|
for (auto& v : neurons)
|
||||||
|
v.debug();
|
||||||
|
std::cout << std::endl;
|
||||||
|
weights.debug();
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
const blt::i32 in_size, out_size;
|
const blt::i32 in_size, out_size;
|
||||||
weight_t weights;
|
weight_t weights;
|
||||||
|
function_t* act_func;
|
||||||
std::vector<neuron_t> neurons;
|
std::vector<neuron_t> neurons;
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
|
@ -21,6 +21,7 @@
|
||||||
|
|
||||||
#include <assign2/common.h>
|
#include <assign2/common.h>
|
||||||
#include <assign2/layer.h>
|
#include <assign2/layer.h>
|
||||||
|
#include "blt/std/assert.h"
|
||||||
|
|
||||||
namespace assign2
|
namespace assign2
|
||||||
{
|
{
|
||||||
|
@ -75,8 +76,7 @@ namespace assign2
|
||||||
|
|
||||||
network_t() = default;
|
network_t() = default;
|
||||||
|
|
||||||
template<typename ActFunc, typename ActFuncOut>
|
std::vector<Scalar> execute(const std::vector<Scalar>& input)
|
||||||
std::vector<Scalar> execute(const std::vector<Scalar>& input, ActFunc func, ActFuncOut outFunc)
|
|
||||||
{
|
{
|
||||||
std::vector<Scalar> previous_output;
|
std::vector<Scalar> previous_output;
|
||||||
std::vector<Scalar> current_output;
|
std::vector<Scalar> current_output;
|
||||||
|
@ -85,39 +85,45 @@ namespace assign2
|
||||||
{
|
{
|
||||||
previous_output = current_output;
|
previous_output = current_output;
|
||||||
if (i == 0)
|
if (i == 0)
|
||||||
current_output = v.call(input, func);
|
current_output = v.call(input);
|
||||||
else if (i == layers.size() - 1)
|
|
||||||
current_output = v.call(previous_output, outFunc);
|
|
||||||
else
|
else
|
||||||
current_output = v.call(previous_output, func);
|
current_output = v.call(previous_output);
|
||||||
}
|
}
|
||||||
|
|
||||||
return current_output;
|
return current_output;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::pair<Scalar, Scalar> error(const std::vector<Scalar>& outputs, bool is_bad)
|
||||||
|
{
|
||||||
|
BLT_ASSERT(outputs.size() == 2);
|
||||||
|
auto g = is_bad ? 0.0f : 1.0f;
|
||||||
|
auto b = is_bad ? 1.0f : 0.0f;
|
||||||
|
|
||||||
|
auto g_diff = outputs[0] - g;
|
||||||
|
auto b_diff = outputs[1] - b;
|
||||||
|
|
||||||
|
auto error = g_diff * g_diff + b_diff * b_diff;
|
||||||
|
BLT_INFO("%f %f %f", error, g_diff, b_diff);
|
||||||
|
|
||||||
|
return {0.5f * (error * error), error};
|
||||||
|
}
|
||||||
|
|
||||||
Scalar train(const data_file_t& example)
|
Scalar train(const data_file_t& example)
|
||||||
{
|
{
|
||||||
const Scalar learn_rate = 0.1;
|
|
||||||
|
|
||||||
Scalar total_error = 0;
|
Scalar total_error = 0;
|
||||||
|
Scalar total_d_error = 0;
|
||||||
for (const auto& x : example.data_points)
|
for (const auto& x : example.data_points)
|
||||||
{
|
{
|
||||||
auto o = execute(x.bins, sigmoid_function{}, sigmoid_function{});
|
print_vec(x.bins) << std::endl;
|
||||||
auto y = x.is_bad ? 1.0f : 0.0f;
|
auto o = execute(x.bins);
|
||||||
|
print_vec(o) << std::endl;
|
||||||
Scalar is_bad = 0;
|
auto [e, de] = error(o, x.is_bad);
|
||||||
if (o[0] >= 1)
|
total_error += e;
|
||||||
is_bad = 0;
|
total_d_error += -learn_rate * de;
|
||||||
else if (o[1] >= 1)
|
BLT_TRACE("\tError %f, %f, is bad? %s", e, -learn_rate * de, x.is_bad ? "True" : "False");
|
||||||
is_bad = 1;
|
|
||||||
|
|
||||||
auto error = y - is_bad;
|
|
||||||
if (o[0] >= 1 && o[1] >= 1)
|
|
||||||
error += 1;
|
|
||||||
|
|
||||||
total_error += error;
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
BLT_DEBUG("Total Errors found %f, %f", total_error, total_d_error);
|
||||||
|
|
||||||
return total_error;
|
return total_error;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1 @@
|
||||||
|
Subproject commit 8103a3ad0ff5341c61ba07c4b1b9803b2cc740e9
|
|
@ -219,9 +219,9 @@ public:
|
||||||
, output_dims(output_dims_), output_strides(output_strides_)
|
, output_dims(output_dims_), output_strides(output_strides_)
|
||||||
{}
|
{}
|
||||||
|
|
||||||
void operator()(const Scalar* output_data) const
|
void operator()(const Scalar* neuron_data) const
|
||||||
{
|
{
|
||||||
check_recursive(input_data, output_data);
|
check_recursive(input_data, neuron_data);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
@ -86,7 +86,7 @@ static void test_block_io_copy_data_from_source_to_target() {
|
||||||
auto output_strides = internal::strides<Layout>(dims);
|
auto output_strides = internal::strides<Layout>(dims);
|
||||||
|
|
||||||
const T* input_data = input.data();
|
const T* input_data = input.data();
|
||||||
T* output_data = output.data();
|
T* neuron_data = output.data();
|
||||||
T* block_data = block.data();
|
T* block_data = block.data();
|
||||||
|
|
||||||
for (int i = 0; i < block_mapper.blockCount(); ++i) {
|
for (int i = 0; i < block_mapper.blockCount(); ++i) {
|
||||||
|
@ -105,7 +105,7 @@ static void test_block_io_copy_data_from_source_to_target() {
|
||||||
|
|
||||||
{
|
{
|
||||||
// Write from block buffer to output.
|
// Write from block buffer to output.
|
||||||
IODst dst(blk_dims, output_strides, output_data, desc.offset());
|
IODst dst(blk_dims, output_strides, neuron_data, desc.offset());
|
||||||
IOSrc src(blk_strides, block_data, 0);
|
IOSrc src(blk_strides, block_data, 0);
|
||||||
|
|
||||||
TensorBlockIO::Copy(dst, src);
|
TensorBlockIO::Copy(dst, src);
|
||||||
|
@ -113,7 +113,7 @@ static void test_block_io_copy_data_from_source_to_target() {
|
||||||
}
|
}
|
||||||
|
|
||||||
for (int i = 0; i < dims.TotalSize(); ++i) {
|
for (int i = 0; i < dims.TotalSize(); ++i) {
|
||||||
VERIFY_IS_EQUAL(input_data[i], output_data[i]);
|
VERIFY_IS_EQUAL(input_data[i], neuron_data[i]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -159,7 +159,7 @@ static void test_block_io_copy_using_reordered_dimensions() {
|
||||||
auto output_strides = internal::strides<Layout>(output_tensor_dims);
|
auto output_strides = internal::strides<Layout>(output_tensor_dims);
|
||||||
|
|
||||||
const T* input_data = input.data();
|
const T* input_data = input.data();
|
||||||
T* output_data = output.data();
|
T* neuron_data = output.data();
|
||||||
T* block_data = block.data();
|
T* block_data = block.data();
|
||||||
|
|
||||||
for (Index i = 0; i < block_mapper.blockCount(); ++i) {
|
for (Index i = 0; i < block_mapper.blockCount(); ++i) {
|
||||||
|
@ -198,7 +198,7 @@ static void test_block_io_copy_using_reordered_dimensions() {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Write from block buffer to output.
|
// Write from block buffer to output.
|
||||||
IODst dst(dst_dims, input_strides, output_data, first_coeff_index);
|
IODst dst(dst_dims, input_strides, neuron_data, first_coeff_index);
|
||||||
IOSrc src(blk_strides, block_data, 0);
|
IOSrc src(blk_strides, block_data, 0);
|
||||||
|
|
||||||
// TODO(ezhulenev): Remove when fully switched to TensorBlock.
|
// TODO(ezhulenev): Remove when fully switched to TensorBlock.
|
||||||
|
@ -210,7 +210,7 @@ static void test_block_io_copy_using_reordered_dimensions() {
|
||||||
}
|
}
|
||||||
|
|
||||||
for (Index i = 0; i < dims.TotalSize(); ++i) {
|
for (Index i = 0; i < dims.TotalSize(); ++i) {
|
||||||
VERIFY_IS_EQUAL(input_data[i], output_data[i]);
|
VERIFY_IS_EQUAL(input_data[i], neuron_data[i]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
42
src/main.cpp
42
src/main.cpp
|
@ -68,18 +68,6 @@ std::vector<data_file_t> load_data_files(const std::vector<std::string>& files)
|
||||||
return loaded_data;
|
return loaded_data;
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename T>
|
|
||||||
decltype(std::cout)& print_vec(const std::vector<T>& vec)
|
|
||||||
{
|
|
||||||
for (auto [i, v] : blt::enumerate(vec))
|
|
||||||
{
|
|
||||||
std::cout << v;
|
|
||||||
if (i != vec.size() - 1)
|
|
||||||
std::cout << ", ";
|
|
||||||
}
|
|
||||||
return std::cout;
|
|
||||||
}
|
|
||||||
|
|
||||||
int main(int argc, const char** argv)
|
int main(int argc, const char** argv)
|
||||||
{
|
{
|
||||||
blt::arg_parse parser;
|
blt::arg_parse parser;
|
||||||
|
@ -90,6 +78,22 @@ int main(int argc, const char** argv)
|
||||||
|
|
||||||
auto data_files = load_data_files(get_data_files(data_directory));
|
auto data_files = load_data_files(get_data_files(data_directory));
|
||||||
|
|
||||||
|
random_init randomizer{619};
|
||||||
|
sigmoid_function sig;
|
||||||
|
relu_function relu;
|
||||||
|
threshold_function thresh;
|
||||||
|
|
||||||
|
layer_t layer1{16, 8, &sig, randomizer, randomizer};
|
||||||
|
layer1.debug();
|
||||||
|
layer_t layer2{8, 8, &sig, randomizer, randomizer};
|
||||||
|
layer2.debug();
|
||||||
|
layer_t layer3{8, 8, &sig, randomizer, randomizer};
|
||||||
|
layer3.debug();
|
||||||
|
layer_t layer_output{8, 2, &relu, randomizer, randomizer};
|
||||||
|
layer_output.debug();
|
||||||
|
|
||||||
|
network_t network{{layer1, layer2, layer3, layer_output}};
|
||||||
|
|
||||||
std::vector<Scalar> input;
|
std::vector<Scalar> input;
|
||||||
input.resize(16);
|
input.resize(16);
|
||||||
for (auto f : data_files)
|
for (auto f : data_files)
|
||||||
|
@ -98,20 +102,12 @@ int main(int argc, const char** argv)
|
||||||
{
|
{
|
||||||
for (auto [i, b] : blt::enumerate(f.data_points.begin()->bins))
|
for (auto [i, b] : blt::enumerate(f.data_points.begin()->bins))
|
||||||
input[i] = b;
|
input[i] = b;
|
||||||
|
network.train(f);
|
||||||
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
random_init randomizer{619};
|
auto output = network.execute(input);
|
||||||
sigmoid_function sig;
|
|
||||||
|
|
||||||
layer_t layer1{16, 4, randomizer, empty_init{}};
|
|
||||||
layer_t layer2{4, 4, randomizer, empty_init{}};
|
|
||||||
layer_t layer3{4, 4, randomizer, empty_init{}};
|
|
||||||
layer_t layer_output{4, 1, randomizer, empty_init{}};
|
|
||||||
|
|
||||||
network_t network{{layer1, layer2, layer3, layer_output}};
|
|
||||||
|
|
||||||
auto output = network.execute(input, sig, sig);
|
|
||||||
print_vec(output) << std::endl;
|
print_vec(output) << std::endl;
|
||||||
|
|
||||||
// for (auto d : data_files)
|
// for (auto d : data_files)
|
||||||
|
|
Loading…
Reference in New Issue