From 4216b53b28badf8b7b87e0f100b38f7de834094a Mon Sep 17 00:00:00 2001 From: Brett Date: Fri, 25 Oct 2024 01:22:32 -0400 Subject: [PATCH] getting closer! --- .gitmodules | 3 + .idea/vcs.xml | 1 + 1.txt | 9 +++ 2.txt | 9 +++ 3.txt | 9 +++ CMakeLists.txt | 16 +++- include/assign2/common.h | 38 ++++++++- include/assign2/functions.h | 32 ++++++-- include/assign2/layer.h | 77 ++++++++++++++++--- include/assign2/network.h | 52 +++++++------ lib/blt-graphics | 1 + .../test/cxx11_tensor_block_access.cpp | 4 +- .../test/cxx11_tensor_block_io.cpp | 12 +-- src/main.cpp | 42 +++++----- 14 files changed, 229 insertions(+), 76 deletions(-) create mode 100644 1.txt create mode 100644 2.txt create mode 100644 3.txt create mode 160000 lib/blt-graphics diff --git a/.gitmodules b/.gitmodules index 484b3a7..cd52502 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,3 +1,6 @@ [submodule "lib/blt"] path = lib/blt url = https://github.com/Tri11Paragon/BLT.git +[submodule "lib/blt-graphics"] + path = lib/blt-graphics + url = https://git.tpgc.me/tri11paragon/BLT-With-Graphics-Template diff --git a/.idea/vcs.xml b/.idea/vcs.xml index 0a294a1..ebc05d3 100644 --- a/.idea/vcs.xml +++ b/.idea/vcs.xml @@ -3,6 +3,7 @@ + \ No newline at end of file diff --git a/1.txt b/1.txt new file mode 100644 index 0000000..277e5ce --- /dev/null +++ b/1.txt @@ -0,0 +1,9 @@ +Bias: 0.0883882 0.498606 -0.192237 -0.225643 -0.202986 0.337387 0.492392 -0.378505 +Weights: 0.0883882, 0.498606, -0.192237, -0.225643, -0.202986, 0.337387, 0.492392, -0.378505, 0.174707, 0.289964, -0.176566, -0.0890466, 0.464011, 0.0977563, -0.0747116, -0.179186, -0.387672, 0.156685, 0.396763, 0.163871, -0.347555, 0.161215, 0.40124, 0.121386, -0.0582743, -0.178808, 0.443577, 0.0854848, 0.0230515, 0.343403, 0.201804, 0.101254, 0.413756, 0.228725, -0.392687, 0.305022, -0.0608484, 0.0060341, 0.393663, -0.0173984, 0.0559042, 0.357518, 0.30257, 0.376668, 0.14234, -0.0381443, 0.0163791, -0.180164, 0.195689, 0.0580821, -0.16776, 0.0569814, -0.271095, 0.223116, -0.475364, 0.0926328, 0.0061297, -0.338429, 0.393725, 0.0131179, -0.207723, -0.296026, 0.0254685, -0.240729, 0.290662, 0.180188, -0.306817, 0.00927613, 0.37434, 0.267375, -0.15453, -0.259941, -0.0329586, 0.0232179, -0.0882589, 0.211924, 0.302485, -0.465395, 0.297113, 0.0298, -0.255424, 0.0950433, -0.493201, 0.387198, -0.218439, -0.0878063, -0.27663, -0.18966, -0.348982, -0.210474, -0.232789, -0.0718565, -0.401074, -0.321769, -0.225299, -0.138247, 0.0411653, -0.20941, 0.468472, 0.0649849, -0.147813, 0.260647, -0.495487, 0.353673, -0.0956114, -0.0889727, 0.443833, 0.0468353, -0.487945, 0.200253, 0.108264, 0.0764869, -0.151083, -0.0761, 0.399845, -0.471762, -0.267245, -0.357395, -0.330219, 0.306514, -0.244428, -0.499103, -0.456246, 0.0216964, 0.490225, -0.178489, -0.118379, 0.477028 +Bias: 0.0883882 0.498606 -0.192237 -0.225643 -0.202986 0.337387 0.492392 -0.378505 +Weights: 0.0883882, 0.498606, -0.192237, -0.225643, -0.202986, 0.337387, 0.492392, -0.378505, 0.174707, 0.289964, -0.176566, -0.0890466, 0.464011, 0.0977563, -0.0747116, -0.179186, -0.387672, 0.156685, 0.396763, 0.163871, -0.347555, 0.161215, 0.40124, 0.121386, -0.0582743, -0.178808, 0.443577, 0.0854848, 0.0230515, 0.343403, 0.201804, 0.101254, 0.413756, 0.228725, -0.392687, 0.305022, -0.0608484, 0.0060341, 0.393663, -0.0173984, 0.0559042, 0.357518, 0.30257, 0.376668, 0.14234, -0.0381443, 0.0163791, -0.180164, 0.195689, 0.0580821, -0.16776, 0.0569814, -0.271095, 0.223116, -0.475364, 0.0926328, 0.0061297, -0.338429, 0.393725, 0.0131179, -0.207723, -0.296026, 0.0254685, -0.240729 +Bias: 0.0883882 0.498606 -0.192237 -0.225643 -0.202986 0.337387 0.492392 -0.378505 +Weights: 0.0883882, 0.498606, -0.192237, -0.225643, -0.202986, 0.337387, 0.492392, -0.378505, 0.174707, 0.289964, -0.176566, -0.0890466, 0.464011, 0.0977563, -0.0747116, -0.179186, -0.387672, 0.156685, 0.396763, 0.163871, -0.347555, 0.161215, 0.40124, 0.121386, -0.0582743, -0.178808, 0.443577, 0.0854848, 0.0230515, 0.343403, 0.201804, 0.101254, 0.413756, 0.228725, -0.392687, 0.305022, -0.0608484, 0.0060341, 0.393663, -0.0173984, 0.0559042, 0.357518, 0.30257, 0.376668, 0.14234, -0.0381443, 0.0163791, -0.180164, 0.195689, 0.0580821, -0.16776, 0.0569814, -0.271095, 0.223116, -0.475364, 0.0926328, 0.0061297, -0.338429, 0.393725, 0.0131179, -0.207723, -0.296026, 0.0254685, -0.240729 +Bias: 0.0883882 0.498606 +Weights: 0.0883882, 0.498606, -0.192237, -0.225643, -0.202986, 0.337387, 0.492392, -0.378505, 0.174707, 0.289964, -0.176566, -0.0890466, 0.464011, 0.0977563, -0.0747116, -0.179186 +389.05, 133.34, 62.14, 32.6, 36.97, 29.03, 62.28, 114.07, 79.41, 212.39, 70.29, 36.73, 31.3, 76.01, 103.99, 214.09 diff --git a/2.txt b/2.txt new file mode 100644 index 0000000..277e5ce --- /dev/null +++ b/2.txt @@ -0,0 +1,9 @@ +Bias: 0.0883882 0.498606 -0.192237 -0.225643 -0.202986 0.337387 0.492392 -0.378505 +Weights: 0.0883882, 0.498606, -0.192237, -0.225643, -0.202986, 0.337387, 0.492392, -0.378505, 0.174707, 0.289964, -0.176566, -0.0890466, 0.464011, 0.0977563, -0.0747116, -0.179186, -0.387672, 0.156685, 0.396763, 0.163871, -0.347555, 0.161215, 0.40124, 0.121386, -0.0582743, -0.178808, 0.443577, 0.0854848, 0.0230515, 0.343403, 0.201804, 0.101254, 0.413756, 0.228725, -0.392687, 0.305022, -0.0608484, 0.0060341, 0.393663, -0.0173984, 0.0559042, 0.357518, 0.30257, 0.376668, 0.14234, -0.0381443, 0.0163791, -0.180164, 0.195689, 0.0580821, -0.16776, 0.0569814, -0.271095, 0.223116, -0.475364, 0.0926328, 0.0061297, -0.338429, 0.393725, 0.0131179, -0.207723, -0.296026, 0.0254685, -0.240729, 0.290662, 0.180188, -0.306817, 0.00927613, 0.37434, 0.267375, -0.15453, -0.259941, -0.0329586, 0.0232179, -0.0882589, 0.211924, 0.302485, -0.465395, 0.297113, 0.0298, -0.255424, 0.0950433, -0.493201, 0.387198, -0.218439, -0.0878063, -0.27663, -0.18966, -0.348982, -0.210474, -0.232789, -0.0718565, -0.401074, -0.321769, -0.225299, -0.138247, 0.0411653, -0.20941, 0.468472, 0.0649849, -0.147813, 0.260647, -0.495487, 0.353673, -0.0956114, -0.0889727, 0.443833, 0.0468353, -0.487945, 0.200253, 0.108264, 0.0764869, -0.151083, -0.0761, 0.399845, -0.471762, -0.267245, -0.357395, -0.330219, 0.306514, -0.244428, -0.499103, -0.456246, 0.0216964, 0.490225, -0.178489, -0.118379, 0.477028 +Bias: 0.0883882 0.498606 -0.192237 -0.225643 -0.202986 0.337387 0.492392 -0.378505 +Weights: 0.0883882, 0.498606, -0.192237, -0.225643, -0.202986, 0.337387, 0.492392, -0.378505, 0.174707, 0.289964, -0.176566, -0.0890466, 0.464011, 0.0977563, -0.0747116, -0.179186, -0.387672, 0.156685, 0.396763, 0.163871, -0.347555, 0.161215, 0.40124, 0.121386, -0.0582743, -0.178808, 0.443577, 0.0854848, 0.0230515, 0.343403, 0.201804, 0.101254, 0.413756, 0.228725, -0.392687, 0.305022, -0.0608484, 0.0060341, 0.393663, -0.0173984, 0.0559042, 0.357518, 0.30257, 0.376668, 0.14234, -0.0381443, 0.0163791, -0.180164, 0.195689, 0.0580821, -0.16776, 0.0569814, -0.271095, 0.223116, -0.475364, 0.0926328, 0.0061297, -0.338429, 0.393725, 0.0131179, -0.207723, -0.296026, 0.0254685, -0.240729 +Bias: 0.0883882 0.498606 -0.192237 -0.225643 -0.202986 0.337387 0.492392 -0.378505 +Weights: 0.0883882, 0.498606, -0.192237, -0.225643, -0.202986, 0.337387, 0.492392, -0.378505, 0.174707, 0.289964, -0.176566, -0.0890466, 0.464011, 0.0977563, -0.0747116, -0.179186, -0.387672, 0.156685, 0.396763, 0.163871, -0.347555, 0.161215, 0.40124, 0.121386, -0.0582743, -0.178808, 0.443577, 0.0854848, 0.0230515, 0.343403, 0.201804, 0.101254, 0.413756, 0.228725, -0.392687, 0.305022, -0.0608484, 0.0060341, 0.393663, -0.0173984, 0.0559042, 0.357518, 0.30257, 0.376668, 0.14234, -0.0381443, 0.0163791, -0.180164, 0.195689, 0.0580821, -0.16776, 0.0569814, -0.271095, 0.223116, -0.475364, 0.0926328, 0.0061297, -0.338429, 0.393725, 0.0131179, -0.207723, -0.296026, 0.0254685, -0.240729 +Bias: 0.0883882 0.498606 +Weights: 0.0883882, 0.498606, -0.192237, -0.225643, -0.202986, 0.337387, 0.492392, -0.378505, 0.174707, 0.289964, -0.176566, -0.0890466, 0.464011, 0.0977563, -0.0747116, -0.179186 +389.05, 133.34, 62.14, 32.6, 36.97, 29.03, 62.28, 114.07, 79.41, 212.39, 70.29, 36.73, 31.3, 76.01, 103.99, 214.09 diff --git a/3.txt b/3.txt new file mode 100644 index 0000000..277e5ce --- /dev/null +++ b/3.txt @@ -0,0 +1,9 @@ +Bias: 0.0883882 0.498606 -0.192237 -0.225643 -0.202986 0.337387 0.492392 -0.378505 +Weights: 0.0883882, 0.498606, -0.192237, -0.225643, -0.202986, 0.337387, 0.492392, -0.378505, 0.174707, 0.289964, -0.176566, -0.0890466, 0.464011, 0.0977563, -0.0747116, -0.179186, -0.387672, 0.156685, 0.396763, 0.163871, -0.347555, 0.161215, 0.40124, 0.121386, -0.0582743, -0.178808, 0.443577, 0.0854848, 0.0230515, 0.343403, 0.201804, 0.101254, 0.413756, 0.228725, -0.392687, 0.305022, -0.0608484, 0.0060341, 0.393663, -0.0173984, 0.0559042, 0.357518, 0.30257, 0.376668, 0.14234, -0.0381443, 0.0163791, -0.180164, 0.195689, 0.0580821, -0.16776, 0.0569814, -0.271095, 0.223116, -0.475364, 0.0926328, 0.0061297, -0.338429, 0.393725, 0.0131179, -0.207723, -0.296026, 0.0254685, -0.240729, 0.290662, 0.180188, -0.306817, 0.00927613, 0.37434, 0.267375, -0.15453, -0.259941, -0.0329586, 0.0232179, -0.0882589, 0.211924, 0.302485, -0.465395, 0.297113, 0.0298, -0.255424, 0.0950433, -0.493201, 0.387198, -0.218439, -0.0878063, -0.27663, -0.18966, -0.348982, -0.210474, -0.232789, -0.0718565, -0.401074, -0.321769, -0.225299, -0.138247, 0.0411653, -0.20941, 0.468472, 0.0649849, -0.147813, 0.260647, -0.495487, 0.353673, -0.0956114, -0.0889727, 0.443833, 0.0468353, -0.487945, 0.200253, 0.108264, 0.0764869, -0.151083, -0.0761, 0.399845, -0.471762, -0.267245, -0.357395, -0.330219, 0.306514, -0.244428, -0.499103, -0.456246, 0.0216964, 0.490225, -0.178489, -0.118379, 0.477028 +Bias: 0.0883882 0.498606 -0.192237 -0.225643 -0.202986 0.337387 0.492392 -0.378505 +Weights: 0.0883882, 0.498606, -0.192237, -0.225643, -0.202986, 0.337387, 0.492392, -0.378505, 0.174707, 0.289964, -0.176566, -0.0890466, 0.464011, 0.0977563, -0.0747116, -0.179186, -0.387672, 0.156685, 0.396763, 0.163871, -0.347555, 0.161215, 0.40124, 0.121386, -0.0582743, -0.178808, 0.443577, 0.0854848, 0.0230515, 0.343403, 0.201804, 0.101254, 0.413756, 0.228725, -0.392687, 0.305022, -0.0608484, 0.0060341, 0.393663, -0.0173984, 0.0559042, 0.357518, 0.30257, 0.376668, 0.14234, -0.0381443, 0.0163791, -0.180164, 0.195689, 0.0580821, -0.16776, 0.0569814, -0.271095, 0.223116, -0.475364, 0.0926328, 0.0061297, -0.338429, 0.393725, 0.0131179, -0.207723, -0.296026, 0.0254685, -0.240729 +Bias: 0.0883882 0.498606 -0.192237 -0.225643 -0.202986 0.337387 0.492392 -0.378505 +Weights: 0.0883882, 0.498606, -0.192237, -0.225643, -0.202986, 0.337387, 0.492392, -0.378505, 0.174707, 0.289964, -0.176566, -0.0890466, 0.464011, 0.0977563, -0.0747116, -0.179186, -0.387672, 0.156685, 0.396763, 0.163871, -0.347555, 0.161215, 0.40124, 0.121386, -0.0582743, -0.178808, 0.443577, 0.0854848, 0.0230515, 0.343403, 0.201804, 0.101254, 0.413756, 0.228725, -0.392687, 0.305022, -0.0608484, 0.0060341, 0.393663, -0.0173984, 0.0559042, 0.357518, 0.30257, 0.376668, 0.14234, -0.0381443, 0.0163791, -0.180164, 0.195689, 0.0580821, -0.16776, 0.0569814, -0.271095, 0.223116, -0.475364, 0.0926328, 0.0061297, -0.338429, 0.393725, 0.0131179, -0.207723, -0.296026, 0.0254685, -0.240729 +Bias: 0.0883882 0.498606 +Weights: 0.0883882, 0.498606, -0.192237, -0.225643, -0.202986, 0.337387, 0.492392, -0.378505, 0.174707, 0.289964, -0.176566, -0.0890466, 0.464011, 0.0977563, -0.0747116, -0.179186 +389.05, 133.34, 62.14, 32.6, 36.97, 29.03, 62.28, 114.07, 79.41, 212.39, 70.29, 36.73, 31.3, 76.01, 103.99, 214.09 diff --git a/CMakeLists.txt b/CMakeLists.txt index 7823189..2143028 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,9 +1,10 @@ cmake_minimum_required(VERSION 3.25) -project(COSC-4P80-Assignment-2 VERSION 0.0.3) +project(COSC-4P80-Assignment-2 VERSION 0.0.4) option(ENABLE_ADDRSAN "Enable the address sanitizer" OFF) option(ENABLE_UBSAN "Enable the ub sanitizer" OFF) option(ENABLE_TSAN "Enable the thread data race sanitizer" OFF) +option(ENABLE_GRAPHICS "Enable usage of graphics package" OFF) #option(EIGEN_TEST_CXX11 "Enable testing with C++11 and C++11 features (e.g. Tensor module)." ON) set(CMAKE_CXX_STANDARD 17) @@ -12,7 +13,12 @@ if (NOT CMAKE_BUILD_TYPE) set(CMAKE_BUILD_TYPE "Release") endif() -add_subdirectory(lib/blt) +if (ENABLE_GRAPHICS) + add_subdirectory(lib/blt-graphics) + add_compile_definitions(BLT_USE_GRAPHICS) +else () + add_subdirectory(lib/blt) +endif () #add_subdirectory(lib/eigen-3.4.0) @@ -25,7 +31,11 @@ target_compile_options(COSC-4P80-Assignment-2 PRIVATE -Wall -Wextra -Wpedantic - target_link_options(COSC-4P80-Assignment-2 PRIVATE -Wall -Wextra -Wpedantic -Wno-comment) #target_link_libraries(COSC-4P80-Assignment-2 PRIVATE BLT Eigen3::Eigen) -target_link_libraries(COSC-4P80-Assignment-2 PRIVATE BLT) +if (ENABLE_GRAPHICS) + target_link_libraries(COSC-4P80-Assignment-2 PRIVATE BLT_WITH_GRAPHICS) +else () + target_link_libraries(COSC-4P80-Assignment-2 PRIVATE BLT) +endif () if (${ENABLE_ADDRSAN} MATCHES ON) target_compile_options(COSC-4P80-Assignment-2 PRIVATE -fsanitize=address) diff --git a/include/assign2/common.h b/include/assign2/common.h index 705a86d..1b3af1c 100644 --- a/include/assign2/common.h +++ b/include/assign2/common.h @@ -19,10 +19,25 @@ #ifndef COSC_4P80_ASSIGNMENT_2_COMMON_H #define COSC_4P80_ASSIGNMENT_2_COMMON_H +#include +#include namespace assign2 { using Scalar = float; + const inline Scalar learn_rate = 0.1; + + template + decltype(std::cout)& print_vec(const std::vector& vec) + { + for (auto [i, v] : blt::enumerate(vec)) + { + std::cout << v; + if (i != vec.size() - 1) + std::cout << ", "; + } + return std::cout; + } struct data_t { @@ -36,15 +51,23 @@ namespace assign2 }; class layer_t; + class network_t; + struct function_t + { + [[nodiscard]] virtual Scalar call(Scalar) const = 0; + + [[nodiscard]] virtual Scalar derivative(Scalar) const = 0; + }; + struct weight_view { public: - weight_view(double* data, blt::size_t size): m_data(data), m_size(size) + weight_view(Scalar* data, blt::size_t size): m_data(data), m_size(size) {} - inline double& operator[](blt::size_t index) const + inline Scalar& operator[](blt::size_t index) const { #if BLT_DEBUG_LEVEL > 0 if (index >= size) @@ -69,7 +92,7 @@ namespace assign2 } private: - double* m_data; + Scalar* m_data; blt::size_t m_size; }; @@ -85,10 +108,17 @@ namespace assign2 data.resize(size + count); return {&data[size], count}; } + + void debug() const + { + std::cout << "Weights: "; + print_vec(data) << std::endl; + } private: - std::vector data; + std::vector data; }; + } #endif //COSC_4P80_ASSIGNMENT_2_COMMON_H diff --git a/include/assign2/functions.h b/include/assign2/functions.h index 0f60d39..aa2b739 100644 --- a/include/assign2/functions.h +++ b/include/assign2/functions.h @@ -24,22 +24,44 @@ namespace assign2 { - struct sigmoid_function + struct sigmoid_function : public function_t { - [[nodiscard]] Scalar call(Scalar s) const // NOLINT + [[nodiscard]] Scalar call(const Scalar s) const final { return 1 / (1 + std::exp(-s)); } - [[nodiscard]] Scalar derivative(Scalar s) const + [[nodiscard]] Scalar derivative(const Scalar s) const final { - return call(s) * (1 - call(s)); + auto v = call(s); + return v * (1 - v); } }; - struct linear_function + struct threshold_function : public function_t { + [[nodiscard]] Scalar call(const Scalar s) const final + { + return s >= 0 ? 1 : 0; + } + + [[nodiscard]] Scalar derivative(Scalar s) const final + { + return 0; + } + }; + struct relu_function : public function_t + { + [[nodiscard]] Scalar call(const Scalar s) const final + { + return std::max(static_cast(0), s); + } + + [[nodiscard]] Scalar derivative(Scalar s) const final + { + return 0; + } }; } diff --git a/include/assign2/layer.h b/include/assign2/layer.h index a18f6e5..028a025 100644 --- a/include/assign2/layer.h +++ b/include/assign2/layer.h @@ -28,6 +28,7 @@ namespace assign2 { class neuron_t { + friend layer_t; public: // empty neuron for loading from a stream explicit neuron_t(weight_view weights): weights(weights) @@ -37,13 +38,13 @@ namespace assign2 explicit neuron_t(weight_view weights, Scalar bias): bias(bias), weights(weights) {} - template - Scalar activate(const Scalar* inputs, ActFunc func) const + Scalar activate(const Scalar* inputs, function_t* act_func) { - auto sum = bias; + z = bias; for (auto [x, w] : blt::zip_iterator_container({inputs, inputs + weights.size()}, {weights.begin(), weights.end()})) - sum += x * w; - return func.call(sum); + z += x * w; + a = act_func->call(z); + return a; } template @@ -61,9 +62,17 @@ namespace assign2 stream >> d; stream >> bias; } + + void debug() const + { + std::cout << bias << " "; + } private: - Scalar bias = 0; + float z = 0; + float a = 0; + float bias = 0; + float error = 0; weight_view weights; }; @@ -71,7 +80,8 @@ namespace assign2 { public: template - layer_t(const blt::i32 in, const blt::i32 out, WeightFunc w, BiasFunc b): in_size(in), out_size(out) + layer_t(const blt::i32 in, const blt::i32 out, function_t* act_func, WeightFunc w, BiasFunc b): + in_size(in), out_size(out), act_func(act_func) { neurons.reserve(out_size); for (blt::i32 i = 0; i < out_size; i++) @@ -83,8 +93,7 @@ namespace assign2 } } - template - std::vector call(const std::vector& in, ActFunction func = ActFunction{}) + std::vector call(const std::vector& in) { std::vector out; out.reserve(out_size); @@ -93,10 +102,47 @@ namespace assign2 throw std::runtime_exception("Input vector doesn't match expected input size!"); #endif for (auto& n : neurons) - out.push_back(n.activate(in.data(), func)); + out.push_back(n.activate(in.data(), act_func)); return out; } + Scalar back_prop(const std::vector& prev_layer_output, Scalar error, const layer_t& next_layer, bool is_output) + { + std::vector dw; + + // δ(h) + if (is_output) + { + // assign error to output layer + for (auto& n : neurons) + n.error = act_func->derivative(n.z) * error; // f'act(net(h)) * (error) + } else + { + // first calculate and assign input layer error + std::vector next_error; + next_error.resize(next_layer.neurons.size()); + for (const auto& [i, w] : blt::enumerate(next_layer.neurons)) + { + for (auto wv : w.weights) + next_error[i] += w.error * wv; + // needed? + next_error[i] /= static_cast(w.weights.size()); + } + + for (auto& n : neurons) + { + n.error = act_func->derivative(n.z); + } + } + + for (const auto& v : prev_layer_output) + { + + } + + return error_at_current_layer; + } + template OStream& serialize(OStream& stream) { @@ -120,9 +166,20 @@ namespace assign2 { return out_size; } + + void debug() const + { + std::cout << "Bias: "; + for (auto& v : neurons) + v.debug(); + std::cout << std::endl; + weights.debug(); + } + private: const blt::i32 in_size, out_size; weight_t weights; + function_t* act_func; std::vector neurons; }; } diff --git a/include/assign2/network.h b/include/assign2/network.h index e56ea62..dd739d7 100644 --- a/include/assign2/network.h +++ b/include/assign2/network.h @@ -21,6 +21,7 @@ #include #include +#include "blt/std/assert.h" namespace assign2 { @@ -75,8 +76,7 @@ namespace assign2 network_t() = default; - template - std::vector execute(const std::vector& input, ActFunc func, ActFuncOut outFunc) + std::vector execute(const std::vector& input) { std::vector previous_output; std::vector current_output; @@ -85,39 +85,45 @@ namespace assign2 { previous_output = current_output; if (i == 0) - current_output = v.call(input, func); - else if (i == layers.size() - 1) - current_output = v.call(previous_output, outFunc); + current_output = v.call(input); else - current_output = v.call(previous_output, func); + current_output = v.call(previous_output); } return current_output; } + std::pair error(const std::vector& outputs, bool is_bad) + { + BLT_ASSERT(outputs.size() == 2); + auto g = is_bad ? 0.0f : 1.0f; + auto b = is_bad ? 1.0f : 0.0f; + + auto g_diff = outputs[0] - g; + auto b_diff = outputs[1] - b; + + auto error = g_diff * g_diff + b_diff * b_diff; + BLT_INFO("%f %f %f", error, g_diff, b_diff); + + return {0.5f * (error * error), error}; + } + Scalar train(const data_file_t& example) { - const Scalar learn_rate = 0.1; - Scalar total_error = 0; + Scalar total_d_error = 0; for (const auto& x : example.data_points) { - auto o = execute(x.bins, sigmoid_function{}, sigmoid_function{}); - auto y = x.is_bad ? 1.0f : 0.0f; - - Scalar is_bad = 0; - if (o[0] >= 1) - is_bad = 0; - else if (o[1] >= 1) - is_bad = 1; - - auto error = y - is_bad; - if (o[0] >= 1 && o[1] >= 1) - error += 1; - - total_error += error; - + print_vec(x.bins) << std::endl; + auto o = execute(x.bins); + print_vec(o) << std::endl; + auto [e, de] = error(o, x.is_bad); + total_error += e; + total_d_error += -learn_rate * de; + BLT_TRACE("\tError %f, %f, is bad? %s", e, -learn_rate * de, x.is_bad ? "True" : "False"); } + BLT_DEBUG("Total Errors found %f, %f", total_error, total_d_error); + return total_error; } diff --git a/lib/blt-graphics b/lib/blt-graphics new file mode 160000 index 0000000..8103a3a --- /dev/null +++ b/lib/blt-graphics @@ -0,0 +1 @@ +Subproject commit 8103a3ad0ff5341c61ba07c4b1b9803b2cc740e9 diff --git a/lib/eigen-3.4.0/unsupported/test/cxx11_tensor_block_access.cpp b/lib/eigen-3.4.0/unsupported/test/cxx11_tensor_block_access.cpp index 5fb12e0..93e61fe 100644 --- a/lib/eigen-3.4.0/unsupported/test/cxx11_tensor_block_access.cpp +++ b/lib/eigen-3.4.0/unsupported/test/cxx11_tensor_block_access.cpp @@ -219,9 +219,9 @@ public: , output_dims(output_dims_), output_strides(output_strides_) {} - void operator()(const Scalar* output_data) const + void operator()(const Scalar* neuron_data) const { - check_recursive(input_data, output_data); + check_recursive(input_data, neuron_data); } }; diff --git a/lib/eigen-3.4.0/unsupported/test/cxx11_tensor_block_io.cpp b/lib/eigen-3.4.0/unsupported/test/cxx11_tensor_block_io.cpp index 52f7dde..7dd1cb3 100644 --- a/lib/eigen-3.4.0/unsupported/test/cxx11_tensor_block_io.cpp +++ b/lib/eigen-3.4.0/unsupported/test/cxx11_tensor_block_io.cpp @@ -86,7 +86,7 @@ static void test_block_io_copy_data_from_source_to_target() { auto output_strides = internal::strides(dims); const T* input_data = input.data(); - T* output_data = output.data(); + T* neuron_data = output.data(); T* block_data = block.data(); for (int i = 0; i < block_mapper.blockCount(); ++i) { @@ -105,7 +105,7 @@ static void test_block_io_copy_data_from_source_to_target() { { // Write from block buffer to output. - IODst dst(blk_dims, output_strides, output_data, desc.offset()); + IODst dst(blk_dims, output_strides, neuron_data, desc.offset()); IOSrc src(blk_strides, block_data, 0); TensorBlockIO::Copy(dst, src); @@ -113,7 +113,7 @@ static void test_block_io_copy_data_from_source_to_target() { } for (int i = 0; i < dims.TotalSize(); ++i) { - VERIFY_IS_EQUAL(input_data[i], output_data[i]); + VERIFY_IS_EQUAL(input_data[i], neuron_data[i]); } } @@ -159,7 +159,7 @@ static void test_block_io_copy_using_reordered_dimensions() { auto output_strides = internal::strides(output_tensor_dims); const T* input_data = input.data(); - T* output_data = output.data(); + T* neuron_data = output.data(); T* block_data = block.data(); for (Index i = 0; i < block_mapper.blockCount(); ++i) { @@ -198,7 +198,7 @@ static void test_block_io_copy_using_reordered_dimensions() { } // Write from block buffer to output. - IODst dst(dst_dims, input_strides, output_data, first_coeff_index); + IODst dst(dst_dims, input_strides, neuron_data, first_coeff_index); IOSrc src(blk_strides, block_data, 0); // TODO(ezhulenev): Remove when fully switched to TensorBlock. @@ -210,7 +210,7 @@ static void test_block_io_copy_using_reordered_dimensions() { } for (Index i = 0; i < dims.TotalSize(); ++i) { - VERIFY_IS_EQUAL(input_data[i], output_data[i]); + VERIFY_IS_EQUAL(input_data[i], neuron_data[i]); } } diff --git a/src/main.cpp b/src/main.cpp index 68bf5fe..0d5504e 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -68,18 +68,6 @@ std::vector load_data_files(const std::vector& files) return loaded_data; } -template -decltype(std::cout)& print_vec(const std::vector& vec) -{ - for (auto [i, v] : blt::enumerate(vec)) - { - std::cout << v; - if (i != vec.size() - 1) - std::cout << ", "; - } - return std::cout; -} - int main(int argc, const char** argv) { blt::arg_parse parser; @@ -90,6 +78,22 @@ int main(int argc, const char** argv) auto data_files = load_data_files(get_data_files(data_directory)); + random_init randomizer{619}; + sigmoid_function sig; + relu_function relu; + threshold_function thresh; + + layer_t layer1{16, 8, &sig, randomizer, randomizer}; + layer1.debug(); + layer_t layer2{8, 8, &sig, randomizer, randomizer}; + layer2.debug(); + layer_t layer3{8, 8, &sig, randomizer, randomizer}; + layer3.debug(); + layer_t layer_output{8, 2, &relu, randomizer, randomizer}; + layer_output.debug(); + + network_t network{{layer1, layer2, layer3, layer_output}}; + std::vector input; input.resize(16); for (auto f : data_files) @@ -98,20 +102,12 @@ int main(int argc, const char** argv) { for (auto [i, b] : blt::enumerate(f.data_points.begin()->bins)) input[i] = b; + network.train(f); + break; } } - random_init randomizer{619}; - sigmoid_function sig; - - layer_t layer1{16, 4, randomizer, empty_init{}}; - layer_t layer2{4, 4, randomizer, empty_init{}}; - layer_t layer3{4, 4, randomizer, empty_init{}}; - layer_t layer_output{4, 1, randomizer, empty_init{}}; - - network_t network{{layer1, layer2, layer3, layer_output}}; - - auto output = network.execute(input, sig, sig); + auto output = network.execute(input); print_vec(output) << std::endl; // for (auto d : data_files)