diff --git a/.gitmodules b/.gitmodules
index 484b3a7..cd52502 100644
--- a/.gitmodules
+++ b/.gitmodules
@@ -1,3 +1,6 @@
[submodule "lib/blt"]
path = lib/blt
url = https://github.com/Tri11Paragon/BLT.git
+[submodule "lib/blt-graphics"]
+ path = lib/blt-graphics
+ url = https://git.tpgc.me/tri11paragon/BLT-With-Graphics-Template
diff --git a/.idea/vcs.xml b/.idea/vcs.xml
index 0a294a1..ebc05d3 100644
--- a/.idea/vcs.xml
+++ b/.idea/vcs.xml
@@ -3,6 +3,7 @@
+
\ No newline at end of file
diff --git a/1.txt b/1.txt
new file mode 100644
index 0000000..277e5ce
--- /dev/null
+++ b/1.txt
@@ -0,0 +1,9 @@
+Bias: 0.0883882 0.498606 -0.192237 -0.225643 -0.202986 0.337387 0.492392 -0.378505
+Weights: 0.0883882, 0.498606, -0.192237, -0.225643, -0.202986, 0.337387, 0.492392, -0.378505, 0.174707, 0.289964, -0.176566, -0.0890466, 0.464011, 0.0977563, -0.0747116, -0.179186, -0.387672, 0.156685, 0.396763, 0.163871, -0.347555, 0.161215, 0.40124, 0.121386, -0.0582743, -0.178808, 0.443577, 0.0854848, 0.0230515, 0.343403, 0.201804, 0.101254, 0.413756, 0.228725, -0.392687, 0.305022, -0.0608484, 0.0060341, 0.393663, -0.0173984, 0.0559042, 0.357518, 0.30257, 0.376668, 0.14234, -0.0381443, 0.0163791, -0.180164, 0.195689, 0.0580821, -0.16776, 0.0569814, -0.271095, 0.223116, -0.475364, 0.0926328, 0.0061297, -0.338429, 0.393725, 0.0131179, -0.207723, -0.296026, 0.0254685, -0.240729, 0.290662, 0.180188, -0.306817, 0.00927613, 0.37434, 0.267375, -0.15453, -0.259941, -0.0329586, 0.0232179, -0.0882589, 0.211924, 0.302485, -0.465395, 0.297113, 0.0298, -0.255424, 0.0950433, -0.493201, 0.387198, -0.218439, -0.0878063, -0.27663, -0.18966, -0.348982, -0.210474, -0.232789, -0.0718565, -0.401074, -0.321769, -0.225299, -0.138247, 0.0411653, -0.20941, 0.468472, 0.0649849, -0.147813, 0.260647, -0.495487, 0.353673, -0.0956114, -0.0889727, 0.443833, 0.0468353, -0.487945, 0.200253, 0.108264, 0.0764869, -0.151083, -0.0761, 0.399845, -0.471762, -0.267245, -0.357395, -0.330219, 0.306514, -0.244428, -0.499103, -0.456246, 0.0216964, 0.490225, -0.178489, -0.118379, 0.477028
+Bias: 0.0883882 0.498606 -0.192237 -0.225643 -0.202986 0.337387 0.492392 -0.378505
+Weights: 0.0883882, 0.498606, -0.192237, -0.225643, -0.202986, 0.337387, 0.492392, -0.378505, 0.174707, 0.289964, -0.176566, -0.0890466, 0.464011, 0.0977563, -0.0747116, -0.179186, -0.387672, 0.156685, 0.396763, 0.163871, -0.347555, 0.161215, 0.40124, 0.121386, -0.0582743, -0.178808, 0.443577, 0.0854848, 0.0230515, 0.343403, 0.201804, 0.101254, 0.413756, 0.228725, -0.392687, 0.305022, -0.0608484, 0.0060341, 0.393663, -0.0173984, 0.0559042, 0.357518, 0.30257, 0.376668, 0.14234, -0.0381443, 0.0163791, -0.180164, 0.195689, 0.0580821, -0.16776, 0.0569814, -0.271095, 0.223116, -0.475364, 0.0926328, 0.0061297, -0.338429, 0.393725, 0.0131179, -0.207723, -0.296026, 0.0254685, -0.240729
+Bias: 0.0883882 0.498606 -0.192237 -0.225643 -0.202986 0.337387 0.492392 -0.378505
+Weights: 0.0883882, 0.498606, -0.192237, -0.225643, -0.202986, 0.337387, 0.492392, -0.378505, 0.174707, 0.289964, -0.176566, -0.0890466, 0.464011, 0.0977563, -0.0747116, -0.179186, -0.387672, 0.156685, 0.396763, 0.163871, -0.347555, 0.161215, 0.40124, 0.121386, -0.0582743, -0.178808, 0.443577, 0.0854848, 0.0230515, 0.343403, 0.201804, 0.101254, 0.413756, 0.228725, -0.392687, 0.305022, -0.0608484, 0.0060341, 0.393663, -0.0173984, 0.0559042, 0.357518, 0.30257, 0.376668, 0.14234, -0.0381443, 0.0163791, -0.180164, 0.195689, 0.0580821, -0.16776, 0.0569814, -0.271095, 0.223116, -0.475364, 0.0926328, 0.0061297, -0.338429, 0.393725, 0.0131179, -0.207723, -0.296026, 0.0254685, -0.240729
+Bias: 0.0883882 0.498606
+Weights: 0.0883882, 0.498606, -0.192237, -0.225643, -0.202986, 0.337387, 0.492392, -0.378505, 0.174707, 0.289964, -0.176566, -0.0890466, 0.464011, 0.0977563, -0.0747116, -0.179186
+389.05, 133.34, 62.14, 32.6, 36.97, 29.03, 62.28, 114.07, 79.41, 212.39, 70.29, 36.73, 31.3, 76.01, 103.99, 214.09
diff --git a/2.txt b/2.txt
new file mode 100644
index 0000000..277e5ce
--- /dev/null
+++ b/2.txt
@@ -0,0 +1,9 @@
+Bias: 0.0883882 0.498606 -0.192237 -0.225643 -0.202986 0.337387 0.492392 -0.378505
+Weights: 0.0883882, 0.498606, -0.192237, -0.225643, -0.202986, 0.337387, 0.492392, -0.378505, 0.174707, 0.289964, -0.176566, -0.0890466, 0.464011, 0.0977563, -0.0747116, -0.179186, -0.387672, 0.156685, 0.396763, 0.163871, -0.347555, 0.161215, 0.40124, 0.121386, -0.0582743, -0.178808, 0.443577, 0.0854848, 0.0230515, 0.343403, 0.201804, 0.101254, 0.413756, 0.228725, -0.392687, 0.305022, -0.0608484, 0.0060341, 0.393663, -0.0173984, 0.0559042, 0.357518, 0.30257, 0.376668, 0.14234, -0.0381443, 0.0163791, -0.180164, 0.195689, 0.0580821, -0.16776, 0.0569814, -0.271095, 0.223116, -0.475364, 0.0926328, 0.0061297, -0.338429, 0.393725, 0.0131179, -0.207723, -0.296026, 0.0254685, -0.240729, 0.290662, 0.180188, -0.306817, 0.00927613, 0.37434, 0.267375, -0.15453, -0.259941, -0.0329586, 0.0232179, -0.0882589, 0.211924, 0.302485, -0.465395, 0.297113, 0.0298, -0.255424, 0.0950433, -0.493201, 0.387198, -0.218439, -0.0878063, -0.27663, -0.18966, -0.348982, -0.210474, -0.232789, -0.0718565, -0.401074, -0.321769, -0.225299, -0.138247, 0.0411653, -0.20941, 0.468472, 0.0649849, -0.147813, 0.260647, -0.495487, 0.353673, -0.0956114, -0.0889727, 0.443833, 0.0468353, -0.487945, 0.200253, 0.108264, 0.0764869, -0.151083, -0.0761, 0.399845, -0.471762, -0.267245, -0.357395, -0.330219, 0.306514, -0.244428, -0.499103, -0.456246, 0.0216964, 0.490225, -0.178489, -0.118379, 0.477028
+Bias: 0.0883882 0.498606 -0.192237 -0.225643 -0.202986 0.337387 0.492392 -0.378505
+Weights: 0.0883882, 0.498606, -0.192237, -0.225643, -0.202986, 0.337387, 0.492392, -0.378505, 0.174707, 0.289964, -0.176566, -0.0890466, 0.464011, 0.0977563, -0.0747116, -0.179186, -0.387672, 0.156685, 0.396763, 0.163871, -0.347555, 0.161215, 0.40124, 0.121386, -0.0582743, -0.178808, 0.443577, 0.0854848, 0.0230515, 0.343403, 0.201804, 0.101254, 0.413756, 0.228725, -0.392687, 0.305022, -0.0608484, 0.0060341, 0.393663, -0.0173984, 0.0559042, 0.357518, 0.30257, 0.376668, 0.14234, -0.0381443, 0.0163791, -0.180164, 0.195689, 0.0580821, -0.16776, 0.0569814, -0.271095, 0.223116, -0.475364, 0.0926328, 0.0061297, -0.338429, 0.393725, 0.0131179, -0.207723, -0.296026, 0.0254685, -0.240729
+Bias: 0.0883882 0.498606 -0.192237 -0.225643 -0.202986 0.337387 0.492392 -0.378505
+Weights: 0.0883882, 0.498606, -0.192237, -0.225643, -0.202986, 0.337387, 0.492392, -0.378505, 0.174707, 0.289964, -0.176566, -0.0890466, 0.464011, 0.0977563, -0.0747116, -0.179186, -0.387672, 0.156685, 0.396763, 0.163871, -0.347555, 0.161215, 0.40124, 0.121386, -0.0582743, -0.178808, 0.443577, 0.0854848, 0.0230515, 0.343403, 0.201804, 0.101254, 0.413756, 0.228725, -0.392687, 0.305022, -0.0608484, 0.0060341, 0.393663, -0.0173984, 0.0559042, 0.357518, 0.30257, 0.376668, 0.14234, -0.0381443, 0.0163791, -0.180164, 0.195689, 0.0580821, -0.16776, 0.0569814, -0.271095, 0.223116, -0.475364, 0.0926328, 0.0061297, -0.338429, 0.393725, 0.0131179, -0.207723, -0.296026, 0.0254685, -0.240729
+Bias: 0.0883882 0.498606
+Weights: 0.0883882, 0.498606, -0.192237, -0.225643, -0.202986, 0.337387, 0.492392, -0.378505, 0.174707, 0.289964, -0.176566, -0.0890466, 0.464011, 0.0977563, -0.0747116, -0.179186
+389.05, 133.34, 62.14, 32.6, 36.97, 29.03, 62.28, 114.07, 79.41, 212.39, 70.29, 36.73, 31.3, 76.01, 103.99, 214.09
diff --git a/3.txt b/3.txt
new file mode 100644
index 0000000..277e5ce
--- /dev/null
+++ b/3.txt
@@ -0,0 +1,9 @@
+Bias: 0.0883882 0.498606 -0.192237 -0.225643 -0.202986 0.337387 0.492392 -0.378505
+Weights: 0.0883882, 0.498606, -0.192237, -0.225643, -0.202986, 0.337387, 0.492392, -0.378505, 0.174707, 0.289964, -0.176566, -0.0890466, 0.464011, 0.0977563, -0.0747116, -0.179186, -0.387672, 0.156685, 0.396763, 0.163871, -0.347555, 0.161215, 0.40124, 0.121386, -0.0582743, -0.178808, 0.443577, 0.0854848, 0.0230515, 0.343403, 0.201804, 0.101254, 0.413756, 0.228725, -0.392687, 0.305022, -0.0608484, 0.0060341, 0.393663, -0.0173984, 0.0559042, 0.357518, 0.30257, 0.376668, 0.14234, -0.0381443, 0.0163791, -0.180164, 0.195689, 0.0580821, -0.16776, 0.0569814, -0.271095, 0.223116, -0.475364, 0.0926328, 0.0061297, -0.338429, 0.393725, 0.0131179, -0.207723, -0.296026, 0.0254685, -0.240729, 0.290662, 0.180188, -0.306817, 0.00927613, 0.37434, 0.267375, -0.15453, -0.259941, -0.0329586, 0.0232179, -0.0882589, 0.211924, 0.302485, -0.465395, 0.297113, 0.0298, -0.255424, 0.0950433, -0.493201, 0.387198, -0.218439, -0.0878063, -0.27663, -0.18966, -0.348982, -0.210474, -0.232789, -0.0718565, -0.401074, -0.321769, -0.225299, -0.138247, 0.0411653, -0.20941, 0.468472, 0.0649849, -0.147813, 0.260647, -0.495487, 0.353673, -0.0956114, -0.0889727, 0.443833, 0.0468353, -0.487945, 0.200253, 0.108264, 0.0764869, -0.151083, -0.0761, 0.399845, -0.471762, -0.267245, -0.357395, -0.330219, 0.306514, -0.244428, -0.499103, -0.456246, 0.0216964, 0.490225, -0.178489, -0.118379, 0.477028
+Bias: 0.0883882 0.498606 -0.192237 -0.225643 -0.202986 0.337387 0.492392 -0.378505
+Weights: 0.0883882, 0.498606, -0.192237, -0.225643, -0.202986, 0.337387, 0.492392, -0.378505, 0.174707, 0.289964, -0.176566, -0.0890466, 0.464011, 0.0977563, -0.0747116, -0.179186, -0.387672, 0.156685, 0.396763, 0.163871, -0.347555, 0.161215, 0.40124, 0.121386, -0.0582743, -0.178808, 0.443577, 0.0854848, 0.0230515, 0.343403, 0.201804, 0.101254, 0.413756, 0.228725, -0.392687, 0.305022, -0.0608484, 0.0060341, 0.393663, -0.0173984, 0.0559042, 0.357518, 0.30257, 0.376668, 0.14234, -0.0381443, 0.0163791, -0.180164, 0.195689, 0.0580821, -0.16776, 0.0569814, -0.271095, 0.223116, -0.475364, 0.0926328, 0.0061297, -0.338429, 0.393725, 0.0131179, -0.207723, -0.296026, 0.0254685, -0.240729
+Bias: 0.0883882 0.498606 -0.192237 -0.225643 -0.202986 0.337387 0.492392 -0.378505
+Weights: 0.0883882, 0.498606, -0.192237, -0.225643, -0.202986, 0.337387, 0.492392, -0.378505, 0.174707, 0.289964, -0.176566, -0.0890466, 0.464011, 0.0977563, -0.0747116, -0.179186, -0.387672, 0.156685, 0.396763, 0.163871, -0.347555, 0.161215, 0.40124, 0.121386, -0.0582743, -0.178808, 0.443577, 0.0854848, 0.0230515, 0.343403, 0.201804, 0.101254, 0.413756, 0.228725, -0.392687, 0.305022, -0.0608484, 0.0060341, 0.393663, -0.0173984, 0.0559042, 0.357518, 0.30257, 0.376668, 0.14234, -0.0381443, 0.0163791, -0.180164, 0.195689, 0.0580821, -0.16776, 0.0569814, -0.271095, 0.223116, -0.475364, 0.0926328, 0.0061297, -0.338429, 0.393725, 0.0131179, -0.207723, -0.296026, 0.0254685, -0.240729
+Bias: 0.0883882 0.498606
+Weights: 0.0883882, 0.498606, -0.192237, -0.225643, -0.202986, 0.337387, 0.492392, -0.378505, 0.174707, 0.289964, -0.176566, -0.0890466, 0.464011, 0.0977563, -0.0747116, -0.179186
+389.05, 133.34, 62.14, 32.6, 36.97, 29.03, 62.28, 114.07, 79.41, 212.39, 70.29, 36.73, 31.3, 76.01, 103.99, 214.09
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 7823189..2143028 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -1,9 +1,10 @@
cmake_minimum_required(VERSION 3.25)
-project(COSC-4P80-Assignment-2 VERSION 0.0.3)
+project(COSC-4P80-Assignment-2 VERSION 0.0.4)
option(ENABLE_ADDRSAN "Enable the address sanitizer" OFF)
option(ENABLE_UBSAN "Enable the ub sanitizer" OFF)
option(ENABLE_TSAN "Enable the thread data race sanitizer" OFF)
+option(ENABLE_GRAPHICS "Enable usage of graphics package" OFF)
#option(EIGEN_TEST_CXX11 "Enable testing with C++11 and C++11 features (e.g. Tensor module)." ON)
set(CMAKE_CXX_STANDARD 17)
@@ -12,7 +13,12 @@ if (NOT CMAKE_BUILD_TYPE)
set(CMAKE_BUILD_TYPE "Release")
endif()
-add_subdirectory(lib/blt)
+if (ENABLE_GRAPHICS)
+ add_subdirectory(lib/blt-graphics)
+ add_compile_definitions(BLT_USE_GRAPHICS)
+else ()
+ add_subdirectory(lib/blt)
+endif ()
#add_subdirectory(lib/eigen-3.4.0)
@@ -25,7 +31,11 @@ target_compile_options(COSC-4P80-Assignment-2 PRIVATE -Wall -Wextra -Wpedantic -
target_link_options(COSC-4P80-Assignment-2 PRIVATE -Wall -Wextra -Wpedantic -Wno-comment)
#target_link_libraries(COSC-4P80-Assignment-2 PRIVATE BLT Eigen3::Eigen)
-target_link_libraries(COSC-4P80-Assignment-2 PRIVATE BLT)
+if (ENABLE_GRAPHICS)
+ target_link_libraries(COSC-4P80-Assignment-2 PRIVATE BLT_WITH_GRAPHICS)
+else ()
+ target_link_libraries(COSC-4P80-Assignment-2 PRIVATE BLT)
+endif ()
if (${ENABLE_ADDRSAN} MATCHES ON)
target_compile_options(COSC-4P80-Assignment-2 PRIVATE -fsanitize=address)
diff --git a/include/assign2/common.h b/include/assign2/common.h
index 705a86d..1b3af1c 100644
--- a/include/assign2/common.h
+++ b/include/assign2/common.h
@@ -19,10 +19,25 @@
#ifndef COSC_4P80_ASSIGNMENT_2_COMMON_H
#define COSC_4P80_ASSIGNMENT_2_COMMON_H
+#include
+#include
namespace assign2
{
using Scalar = float;
+ const inline Scalar learn_rate = 0.1;
+
+ template
+ decltype(std::cout)& print_vec(const std::vector& vec)
+ {
+ for (auto [i, v] : blt::enumerate(vec))
+ {
+ std::cout << v;
+ if (i != vec.size() - 1)
+ std::cout << ", ";
+ }
+ return std::cout;
+ }
struct data_t
{
@@ -36,15 +51,23 @@ namespace assign2
};
class layer_t;
+
class network_t;
+ struct function_t
+ {
+ [[nodiscard]] virtual Scalar call(Scalar) const = 0;
+
+ [[nodiscard]] virtual Scalar derivative(Scalar) const = 0;
+ };
+
struct weight_view
{
public:
- weight_view(double* data, blt::size_t size): m_data(data), m_size(size)
+ weight_view(Scalar* data, blt::size_t size): m_data(data), m_size(size)
{}
- inline double& operator[](blt::size_t index) const
+ inline Scalar& operator[](blt::size_t index) const
{
#if BLT_DEBUG_LEVEL > 0
if (index >= size)
@@ -69,7 +92,7 @@ namespace assign2
}
private:
- double* m_data;
+ Scalar* m_data;
blt::size_t m_size;
};
@@ -85,10 +108,17 @@ namespace assign2
data.resize(size + count);
return {&data[size], count};
}
+
+ void debug() const
+ {
+ std::cout << "Weights: ";
+ print_vec(data) << std::endl;
+ }
private:
- std::vector data;
+ std::vector data;
};
+
}
#endif //COSC_4P80_ASSIGNMENT_2_COMMON_H
diff --git a/include/assign2/functions.h b/include/assign2/functions.h
index 0f60d39..aa2b739 100644
--- a/include/assign2/functions.h
+++ b/include/assign2/functions.h
@@ -24,22 +24,44 @@
namespace assign2
{
- struct sigmoid_function
+ struct sigmoid_function : public function_t
{
- [[nodiscard]] Scalar call(Scalar s) const // NOLINT
+ [[nodiscard]] Scalar call(const Scalar s) const final
{
return 1 / (1 + std::exp(-s));
}
- [[nodiscard]] Scalar derivative(Scalar s) const
+ [[nodiscard]] Scalar derivative(const Scalar s) const final
{
- return call(s) * (1 - call(s));
+ auto v = call(s);
+ return v * (1 - v);
}
};
- struct linear_function
+ struct threshold_function : public function_t
{
+ [[nodiscard]] Scalar call(const Scalar s) const final
+ {
+ return s >= 0 ? 1 : 0;
+ }
+
+ [[nodiscard]] Scalar derivative(Scalar s) const final
+ {
+ return 0;
+ }
+ };
+ struct relu_function : public function_t
+ {
+ [[nodiscard]] Scalar call(const Scalar s) const final
+ {
+ return std::max(static_cast(0), s);
+ }
+
+ [[nodiscard]] Scalar derivative(Scalar s) const final
+ {
+ return 0;
+ }
};
}
diff --git a/include/assign2/layer.h b/include/assign2/layer.h
index a18f6e5..028a025 100644
--- a/include/assign2/layer.h
+++ b/include/assign2/layer.h
@@ -28,6 +28,7 @@ namespace assign2
{
class neuron_t
{
+ friend layer_t;
public:
// empty neuron for loading from a stream
explicit neuron_t(weight_view weights): weights(weights)
@@ -37,13 +38,13 @@ namespace assign2
explicit neuron_t(weight_view weights, Scalar bias): bias(bias), weights(weights)
{}
- template
- Scalar activate(const Scalar* inputs, ActFunc func) const
+ Scalar activate(const Scalar* inputs, function_t* act_func)
{
- auto sum = bias;
+ z = bias;
for (auto [x, w] : blt::zip_iterator_container({inputs, inputs + weights.size()}, {weights.begin(), weights.end()}))
- sum += x * w;
- return func.call(sum);
+ z += x * w;
+ a = act_func->call(z);
+ return a;
}
template
@@ -61,9 +62,17 @@ namespace assign2
stream >> d;
stream >> bias;
}
+
+ void debug() const
+ {
+ std::cout << bias << " ";
+ }
private:
- Scalar bias = 0;
+ float z = 0;
+ float a = 0;
+ float bias = 0;
+ float error = 0;
weight_view weights;
};
@@ -71,7 +80,8 @@ namespace assign2
{
public:
template
- layer_t(const blt::i32 in, const blt::i32 out, WeightFunc w, BiasFunc b): in_size(in), out_size(out)
+ layer_t(const blt::i32 in, const blt::i32 out, function_t* act_func, WeightFunc w, BiasFunc b):
+ in_size(in), out_size(out), act_func(act_func)
{
neurons.reserve(out_size);
for (blt::i32 i = 0; i < out_size; i++)
@@ -83,8 +93,7 @@ namespace assign2
}
}
- template
- std::vector call(const std::vector& in, ActFunction func = ActFunction{})
+ std::vector call(const std::vector& in)
{
std::vector out;
out.reserve(out_size);
@@ -93,10 +102,47 @@ namespace assign2
throw std::runtime_exception("Input vector doesn't match expected input size!");
#endif
for (auto& n : neurons)
- out.push_back(n.activate(in.data(), func));
+ out.push_back(n.activate(in.data(), act_func));
return out;
}
+ Scalar back_prop(const std::vector& prev_layer_output, Scalar error, const layer_t& next_layer, bool is_output)
+ {
+ std::vector dw;
+
+ // δ(h)
+ if (is_output)
+ {
+ // assign error to output layer
+ for (auto& n : neurons)
+ n.error = act_func->derivative(n.z) * error; // f'act(net(h)) * (error)
+ } else
+ {
+ // first calculate and assign input layer error
+ std::vector next_error;
+ next_error.resize(next_layer.neurons.size());
+ for (const auto& [i, w] : blt::enumerate(next_layer.neurons))
+ {
+ for (auto wv : w.weights)
+ next_error[i] += w.error * wv;
+ // needed?
+ next_error[i] /= static_cast(w.weights.size());
+ }
+
+ for (auto& n : neurons)
+ {
+ n.error = act_func->derivative(n.z);
+ }
+ }
+
+ for (const auto& v : prev_layer_output)
+ {
+
+ }
+
+ return error_at_current_layer;
+ }
+
template
OStream& serialize(OStream& stream)
{
@@ -120,9 +166,20 @@ namespace assign2
{
return out_size;
}
+
+ void debug() const
+ {
+ std::cout << "Bias: ";
+ for (auto& v : neurons)
+ v.debug();
+ std::cout << std::endl;
+ weights.debug();
+ }
+
private:
const blt::i32 in_size, out_size;
weight_t weights;
+ function_t* act_func;
std::vector neurons;
};
}
diff --git a/include/assign2/network.h b/include/assign2/network.h
index e56ea62..dd739d7 100644
--- a/include/assign2/network.h
+++ b/include/assign2/network.h
@@ -21,6 +21,7 @@
#include
#include
+#include "blt/std/assert.h"
namespace assign2
{
@@ -75,8 +76,7 @@ namespace assign2
network_t() = default;
- template
- std::vector execute(const std::vector& input, ActFunc func, ActFuncOut outFunc)
+ std::vector execute(const std::vector& input)
{
std::vector previous_output;
std::vector current_output;
@@ -85,39 +85,45 @@ namespace assign2
{
previous_output = current_output;
if (i == 0)
- current_output = v.call(input, func);
- else if (i == layers.size() - 1)
- current_output = v.call(previous_output, outFunc);
+ current_output = v.call(input);
else
- current_output = v.call(previous_output, func);
+ current_output = v.call(previous_output);
}
return current_output;
}
+ std::pair error(const std::vector& outputs, bool is_bad)
+ {
+ BLT_ASSERT(outputs.size() == 2);
+ auto g = is_bad ? 0.0f : 1.0f;
+ auto b = is_bad ? 1.0f : 0.0f;
+
+ auto g_diff = outputs[0] - g;
+ auto b_diff = outputs[1] - b;
+
+ auto error = g_diff * g_diff + b_diff * b_diff;
+ BLT_INFO("%f %f %f", error, g_diff, b_diff);
+
+ return {0.5f * (error * error), error};
+ }
+
Scalar train(const data_file_t& example)
{
- const Scalar learn_rate = 0.1;
-
Scalar total_error = 0;
+ Scalar total_d_error = 0;
for (const auto& x : example.data_points)
{
- auto o = execute(x.bins, sigmoid_function{}, sigmoid_function{});
- auto y = x.is_bad ? 1.0f : 0.0f;
-
- Scalar is_bad = 0;
- if (o[0] >= 1)
- is_bad = 0;
- else if (o[1] >= 1)
- is_bad = 1;
-
- auto error = y - is_bad;
- if (o[0] >= 1 && o[1] >= 1)
- error += 1;
-
- total_error += error;
-
+ print_vec(x.bins) << std::endl;
+ auto o = execute(x.bins);
+ print_vec(o) << std::endl;
+ auto [e, de] = error(o, x.is_bad);
+ total_error += e;
+ total_d_error += -learn_rate * de;
+ BLT_TRACE("\tError %f, %f, is bad? %s", e, -learn_rate * de, x.is_bad ? "True" : "False");
}
+ BLT_DEBUG("Total Errors found %f, %f", total_error, total_d_error);
+
return total_error;
}
diff --git a/lib/blt-graphics b/lib/blt-graphics
new file mode 160000
index 0000000..8103a3a
--- /dev/null
+++ b/lib/blt-graphics
@@ -0,0 +1 @@
+Subproject commit 8103a3ad0ff5341c61ba07c4b1b9803b2cc740e9
diff --git a/lib/eigen-3.4.0/unsupported/test/cxx11_tensor_block_access.cpp b/lib/eigen-3.4.0/unsupported/test/cxx11_tensor_block_access.cpp
index 5fb12e0..93e61fe 100644
--- a/lib/eigen-3.4.0/unsupported/test/cxx11_tensor_block_access.cpp
+++ b/lib/eigen-3.4.0/unsupported/test/cxx11_tensor_block_access.cpp
@@ -219,9 +219,9 @@ public:
, output_dims(output_dims_), output_strides(output_strides_)
{}
- void operator()(const Scalar* output_data) const
+ void operator()(const Scalar* neuron_data) const
{
- check_recursive(input_data, output_data);
+ check_recursive(input_data, neuron_data);
}
};
diff --git a/lib/eigen-3.4.0/unsupported/test/cxx11_tensor_block_io.cpp b/lib/eigen-3.4.0/unsupported/test/cxx11_tensor_block_io.cpp
index 52f7dde..7dd1cb3 100644
--- a/lib/eigen-3.4.0/unsupported/test/cxx11_tensor_block_io.cpp
+++ b/lib/eigen-3.4.0/unsupported/test/cxx11_tensor_block_io.cpp
@@ -86,7 +86,7 @@ static void test_block_io_copy_data_from_source_to_target() {
auto output_strides = internal::strides(dims);
const T* input_data = input.data();
- T* output_data = output.data();
+ T* neuron_data = output.data();
T* block_data = block.data();
for (int i = 0; i < block_mapper.blockCount(); ++i) {
@@ -105,7 +105,7 @@ static void test_block_io_copy_data_from_source_to_target() {
{
// Write from block buffer to output.
- IODst dst(blk_dims, output_strides, output_data, desc.offset());
+ IODst dst(blk_dims, output_strides, neuron_data, desc.offset());
IOSrc src(blk_strides, block_data, 0);
TensorBlockIO::Copy(dst, src);
@@ -113,7 +113,7 @@ static void test_block_io_copy_data_from_source_to_target() {
}
for (int i = 0; i < dims.TotalSize(); ++i) {
- VERIFY_IS_EQUAL(input_data[i], output_data[i]);
+ VERIFY_IS_EQUAL(input_data[i], neuron_data[i]);
}
}
@@ -159,7 +159,7 @@ static void test_block_io_copy_using_reordered_dimensions() {
auto output_strides = internal::strides(output_tensor_dims);
const T* input_data = input.data();
- T* output_data = output.data();
+ T* neuron_data = output.data();
T* block_data = block.data();
for (Index i = 0; i < block_mapper.blockCount(); ++i) {
@@ -198,7 +198,7 @@ static void test_block_io_copy_using_reordered_dimensions() {
}
// Write from block buffer to output.
- IODst dst(dst_dims, input_strides, output_data, first_coeff_index);
+ IODst dst(dst_dims, input_strides, neuron_data, first_coeff_index);
IOSrc src(blk_strides, block_data, 0);
// TODO(ezhulenev): Remove when fully switched to TensorBlock.
@@ -210,7 +210,7 @@ static void test_block_io_copy_using_reordered_dimensions() {
}
for (Index i = 0; i < dims.TotalSize(); ++i) {
- VERIFY_IS_EQUAL(input_data[i], output_data[i]);
+ VERIFY_IS_EQUAL(input_data[i], neuron_data[i]);
}
}
diff --git a/src/main.cpp b/src/main.cpp
index 68bf5fe..0d5504e 100644
--- a/src/main.cpp
+++ b/src/main.cpp
@@ -68,18 +68,6 @@ std::vector load_data_files(const std::vector& files)
return loaded_data;
}
-template
-decltype(std::cout)& print_vec(const std::vector& vec)
-{
- for (auto [i, v] : blt::enumerate(vec))
- {
- std::cout << v;
- if (i != vec.size() - 1)
- std::cout << ", ";
- }
- return std::cout;
-}
-
int main(int argc, const char** argv)
{
blt::arg_parse parser;
@@ -90,6 +78,22 @@ int main(int argc, const char** argv)
auto data_files = load_data_files(get_data_files(data_directory));
+ random_init randomizer{619};
+ sigmoid_function sig;
+ relu_function relu;
+ threshold_function thresh;
+
+ layer_t layer1{16, 8, &sig, randomizer, randomizer};
+ layer1.debug();
+ layer_t layer2{8, 8, &sig, randomizer, randomizer};
+ layer2.debug();
+ layer_t layer3{8, 8, &sig, randomizer, randomizer};
+ layer3.debug();
+ layer_t layer_output{8, 2, &relu, randomizer, randomizer};
+ layer_output.debug();
+
+ network_t network{{layer1, layer2, layer3, layer_output}};
+
std::vector input;
input.resize(16);
for (auto f : data_files)
@@ -98,20 +102,12 @@ int main(int argc, const char** argv)
{
for (auto [i, b] : blt::enumerate(f.data_points.begin()->bins))
input[i] = b;
+ network.train(f);
+ break;
}
}
- random_init randomizer{619};
- sigmoid_function sig;
-
- layer_t layer1{16, 4, randomizer, empty_init{}};
- layer_t layer2{4, 4, randomizer, empty_init{}};
- layer_t layer3{4, 4, randomizer, empty_init{}};
- layer_t layer_output{4, 1, randomizer, empty_init{}};
-
- network_t network{{layer1, layer2, layer3, layer_output}};
-
- auto output = network.execute(input, sig, sig);
+ auto output = network.execute(input);
print_vec(output) << std::endl;
// for (auto d : data_files)