main
Brett 2024-12-10 14:13:55 -05:00
commit 3ac5de603a
23 changed files with 769 additions and 0 deletions

6
.gitignore vendored Normal file
View File

@ -0,0 +1,6 @@
cmake-build*/
build/
out/
./cmake-build*/
./build/
./out/

3
.gitmodules vendored Normal file
View File

@ -0,0 +1,3 @@
[submodule "lib/blt-with-graphics"]
path = lib/blt-with-graphics
url = https://git.tpgc.me/tri11paragon/BLT-With-Graphics-Template

8
.idea/.gitignore vendored Normal file
View File

@ -0,0 +1,8 @@
# Default ignored files
/shelf/
/workspace.xml
# Editor-based HTTP Client requests
/httpRequests/
# Datasource local storage ignored files
/dataSources/
/dataSources.local.xml

View File

@ -0,0 +1,2 @@
<?xml version="1.0" encoding="UTF-8"?>
<module classpath="CMake" type="CPP_MODULE" version="4" />

7
.idea/misc.xml Normal file
View File

@ -0,0 +1,7 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="CMakePythonSetting">
<option name="pythonIntegrationState" value="YES" />
</component>
<component name="CMakeWorkspace" PROJECT_DIR="$PROJECT_DIR$" />
</project>

8
.idea/modules.xml Normal file
View File

@ -0,0 +1,8 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectModuleManager">
<modules>
<module fileurl="file://$PROJECT_DIR$/.idea/COSC-4P80-Final-Project.iml" filepath="$PROJECT_DIR$/.idea/COSC-4P80-Final-Project.iml" />
</modules>
</component>
</project>

29
.idea/vcs.xml Normal file
View File

@ -0,0 +1,29 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="VcsDirectoryMappings">
<mapping directory="" vcs="Git" />
<mapping directory="$PROJECT_DIR$/cmake-build-debug/_deps/freetype-src" vcs="Git" />
<mapping directory="$PROJECT_DIR$/cmake-build-debug/_deps/freetype-src/subprojects/dlg" vcs="Git" />
<mapping directory="$PROJECT_DIR$/cmake-build-debug/_deps/glfw3-src" vcs="Git" />
<mapping directory="$PROJECT_DIR$/cmake-build-debug/_deps/imgui-src" vcs="Git" />
<mapping directory="$PROJECT_DIR$/cmake-build-debug/_deps/opennn-src" vcs="Git" />
<mapping directory="$PROJECT_DIR$/cmake-build-release/_deps/freetype-src" vcs="Git" />
<mapping directory="$PROJECT_DIR$/cmake-build-release/_deps/freetype-src/subprojects/dlg" vcs="Git" />
<mapping directory="$PROJECT_DIR$/cmake-build-release/_deps/glfw3-src" vcs="Git" />
<mapping directory="$PROJECT_DIR$/cmake-build-release/_deps/imgui-src" vcs="Git" />
<mapping directory="$PROJECT_DIR$/cmake-build-release/_deps/opennn-src" vcs="Git" />
<mapping directory="$PROJECT_DIR$/cmake-build-relwithdebinfo-addrsan/_deps/freetype-src" vcs="Git" />
<mapping directory="$PROJECT_DIR$/cmake-build-relwithdebinfo-addrsan/_deps/freetype-src/subprojects/dlg" vcs="Git" />
<mapping directory="$PROJECT_DIR$/cmake-build-relwithdebinfo-addrsan/_deps/glfw3-src" vcs="Git" />
<mapping directory="$PROJECT_DIR$/cmake-build-relwithdebinfo-addrsan/_deps/imgui-src" vcs="Git" />
<mapping directory="$PROJECT_DIR$/cmake-build-relwithdebinfo-addrsan/_deps/opennn-src" vcs="Git" />
<mapping directory="$PROJECT_DIR$/cmake-build-relwithdebinfo/_deps/freetype-src" vcs="Git" />
<mapping directory="$PROJECT_DIR$/cmake-build-relwithdebinfo/_deps/freetype-src/subprojects/dlg" vcs="Git" />
<mapping directory="$PROJECT_DIR$/cmake-build-relwithdebinfo/_deps/glfw3-src" vcs="Git" />
<mapping directory="$PROJECT_DIR$/cmake-build-relwithdebinfo/_deps/imgui-src" vcs="Git" />
<mapping directory="$PROJECT_DIR$/cmake-build-relwithdebinfo/_deps/opennn-src" vcs="Git" />
<mapping directory="$PROJECT_DIR$/lib/blt-with-graphics" vcs="Git" />
<mapping directory="$PROJECT_DIR$/lib/blt-with-graphics/libraries/BLT" vcs="Git" />
<mapping directory="$PROJECT_DIR$/lib/blt-with-graphics/libraries/BLT/libraries/parallel-hashmap" vcs="Git" />
</component>
</project>

44
CMakeLists.txt Normal file
View File

@ -0,0 +1,44 @@
cmake_minimum_required(VERSION 3.29)
project(COSC-4P80-Final-Project VERSION 0.0.2)
option(ENABLE_ADDRSAN "Enable the address sanitizer" OFF)
option(ENABLE_UBSAN "Enable the ub sanitizer" OFF)
option(ENABLE_TSAN "Enable the thread data race sanitizer" OFF)
set(CMAKE_CXX_STANDARD 17)
add_subdirectory(lib/blt-with-graphics)
add_compile_options("-fopenmp")
add_link_options("-fopenmp")
fetchcontent_declare(dlib
URL http://dlib.net/files/dlib-19.24.tar.bz2
URL_HASH MD5=8a98957a73eebd3cd7431c2bac79665f
FIND_PACKAGE_ARGS)
fetchcontent_makeavailable(dlib)
include_directories(include/)
file(GLOB_RECURSE PROJECT_BUILD_FILES "${CMAKE_CURRENT_SOURCE_DIR}/src/*.cpp")
add_executable(COSC-4P80-Final-Project ${PROJECT_BUILD_FILES})
target_compile_options(COSC-4P80-Final-Project PRIVATE -Wall -Wextra -Wpedantic -Wno-comment)
target_link_options(COSC-4P80-Final-Project PRIVATE -Wall -Wextra -Wpedantic -Wno-comment)
target_link_libraries(COSC-4P80-Final-Project PRIVATE BLT_WITH_GRAPHICS dlib)
if (${ENABLE_ADDRSAN} MATCHES ON)
target_compile_options(COSC-4P80-Final-Project PRIVATE -fsanitize=address)
target_link_options(COSC-4P80-Final-Project PRIVATE -fsanitize=address)
endif ()
if (${ENABLE_UBSAN} MATCHES ON)
target_compile_options(COSC-4P80-Final-Project PRIVATE -fsanitize=undefined)
target_link_options(COSC-4P80-Final-Project PRIVATE -fsanitize=undefined)
endif ()
if (${ENABLE_TSAN} MATCHES ON)
target_compile_options(COSC-4P80-Final-Project PRIVATE -fsanitize=thread)
target_link_options(COSC-4P80-Final-Project PRIVATE -fsanitize=thread)
endif ()

285
commit.py Executable file
View File

@ -0,0 +1,285 @@
#!python3
import subprocess
import argparse
import sys
import os
import itertools
import requests
import json
from pathlib import Path
#---------------------------------------
# CONFIG
#---------------------------------------
VERSION_BEGIN_STR = " VERSION "
VERSION_END_STR = ")"
#---------------------------------------
# DO NOT TOUCH
#---------------------------------------
USER_HOME = Path.home()
ENVIRONMENT_DATA_LOCATION = USER_HOME / ".brett_scripts.env"
if sys.platform.startswith("win"):
CONFIG_FILE_DIRECTORY = Path(os.getenv('APPDATA') + "\BLT")
CONFIG_FILE_LOCATION = Path(CONFIG_FILE_DIRECTORY + "\commit_config.env")
else:
XDG_CONFIG_HOME = os.environ.get('XDG_CONFIG_HOME')
if XDG_CONFIG_HOME is None:
XDG_CONFIG_HOME = USER_HOME / ".config"
else:
XDG_CONFIG_HOME = Path(XDG_CONFIG_HOME)
if len(str(XDG_CONFIG_HOME)) == 0:
XDG_CONFIG_HOME = USER_HOME
CONFIG_FILE_DIRECTORY = XDG_CONFIG_HOME / "blt"
CONFIG_FILE_LOCATION = CONFIG_FILE_DIRECTORY / "commit_config.env"
class Config:
def __init__(self):
# Inline with semantic versioning it doesn't make sense to branch / release on minor
self.branch_on_major = True
self.branch_on_minor = False
self.release_on_major = True
self.release_on_minor = False
self.main_branch = "main"
self.patch_limit = -1
def toJSON(self):
return json.dumps(self, default=lambda o: o.__dict__, sort_keys=True, indent=4)
def fromJSON(file):
with open(file, "r") as f:
j = json.load(f)
obj = Config()
[setattr(obj, key, val) for key, val in j.items() if hasattr(obj, key)]
return obj
def from_file(file):
values = {}
if (not os.path.exists(file)):
return Config()
with open(file, "r") as f:
j = json.load(f)
obj = Config()
[setattr(obj, key, val) for key, val in j.items() if hasattr(obj, key)]
return obj
def save_to_file(self, file):
dir_index = str(file).rfind("/")
dir = str(file)[:dir_index]
if not os.path.exists(dir):
print(f"Creating config directory {dir}")
os.makedirs(dir)
with open(file, "w") as f:
json.dump(self, f, default=lambda o: o.__dict__, sort_keys=True, indent=4)
class EnvData:
def __init__(self, github_username = '', github_token = ''):
self.github_token = github_token
self.github_username = github_username
def get_env_from_file(file):
f = open(file, "rt")
values = {}
for line in f:
if line.startswith("export"):
content = line.split("=")
for idx, c in enumerate(content):
content[idx] = c.replace("export", "").strip()
values[content[0]] = content[1].replace("\"", "").replace("'", "")
try:
github_token = values["github_token"]
except Exception:
print("Failed to parse github token!")
try:
github_username = values["github_username"]
except:
print("Failed to parse github username! Assuming you are me!")
github_username = "Tri11Paragon"
return EnvData(github_username=github_username, github_token=github_token)
def open_process(command, print_out = True):
process = subprocess.Popen(command, stderr=subprocess.PIPE, stdout=subprocess.PIPE)
stdout, stderr = process.communicate()
exit_code = process.wait()
str_out = stdout.decode('utf8')
str_err = stderr.decode('utf8')
if print_out and len(str_out) > 0:
print(str_out, end='')
if print_out and len(str_err) > 0:
print(str_err, end='')
#print(stdout, stderr, exit_code)
return (stdout, stderr, exit_code)
def load_cmake():
cmake_file = open("CMakeLists.txt", 'r')
cmake_text = cmake_file.read()
cmake_file.close()
return cmake_text
def write_cmake(cmake_text):
cmake_file = open("CMakeLists.txt", 'w')
cmake_file.write(cmake_text)
cmake_file.close()
def get_version(cmake_text):
begin = cmake_text.find(VERSION_BEGIN_STR) + len(VERSION_BEGIN_STR)
end = cmake_text.find(VERSION_END_STR, begin)
return (cmake_text[begin:end], begin, end)
def split_version(cmake_text):
version, begin, end = get_version(cmake_text)
version_parts = version.split('.')
return (version_parts, begin, end)
def recombine(cmake_text, version_parts, begin, end):
constructed_version = version_parts[0] + '.' + version_parts[1] + '.' + version_parts[2]
constructed_text_begin = cmake_text[0:begin]
constructed_text_end = cmake_text[end::]
return constructed_text_begin + constructed_version + constructed_text_end
def inc_major(cmake_text):
version_parts, begin, end = split_version(cmake_text)
version_parts[0] = str(int(version_parts[0]) + 1)
version_parts[1] = '0'
version_parts[2] = '0'
return recombine(cmake_text, version_parts, begin, end)
def inc_minor(cmake_text):
version_parts, begin, end = split_version(cmake_text)
version_parts[1] = str(int(version_parts[1]) + 1)
version_parts[2] = '0'
return recombine(cmake_text, version_parts, begin, end)
def inc_patch(config: Config, cmake_text):
version_parts, begin, end = split_version(cmake_text)
if config.patch_limit > 0 and int(version_parts[2]) + 1 >= config.patch_limit:
return inc_minor(cmake_text)
version_parts[2] = str(int(version_parts[2]) + 1)
return recombine(cmake_text, version_parts, begin, end)
def make_branch(config: Config, name):
print(f"Making new branch {name}")
subprocess.call(["git", "checkout", "-b", name])
subprocess.call(["git", "merge", config.main_branch])
subprocess.call(["git", "checkout", config.main_branch])
def make_release(env: EnvData, name):
print(f"Making new release {name}")
repos_v = open_process(["git", "remote", "-v"])[0].splitlines()
urls = []
for line in repos_v:
origin = ''.join(itertools.takewhile(str.isalpha, line.decode('utf8')))
urls.append("https://api.github.com/repos/" + open_process(["git", "remote", "get-url", origin], False)[0].decode('utf8').replace("\n", "").replace("https://github.com/", "") + "/releases")
urls = set(urls)
data = {
'tag_name': name,
'name': name,
'body': "Automated Release '" + name + "'",
'draft': False,
'prerelease': False
}
headers = {
'Authorization': f'Bearer {env.github_token}',
'Accept': 'application/vnd.github+json',
'X-GitHub-Api-Version': '2022-11-28'
}
for url in urls:
response = requests.post(url, headers=headers, data=json.dumps(data))
if response.status_code == 201:
print('Release created successfully!')
release_data = response.json()
print(f"Release URL: {release_data['html_url']}")
else:
print(f"Failed to create release: {response.status_code}")
print(response.json())
def main():
parser = argparse.ArgumentParser(
prog="Commit Helper",
description="Help you make pretty commits :3")
parser.add_argument("action", nargs='?', default=None)
parser.add_argument("-p", "--patch", action='store_true', default=False, required=False)
parser.add_argument("-m", "--minor", action='store_true', default=False, required=False)
parser.add_argument("-M", "--major", action='store_true', default=False, required=False)
parser.add_argument('-e', "--env", help="environment file", required=False, default=None)
parser.add_argument('-c', "--config", help="config file", required=False, default=None)
parser.add_argument("--create_default_config", action="store_true", default=False, required=False)
args = parser.parse_args()
if args.env is not None:
env = EnvData.get_env_from_file(args.e)
else:
env = EnvData.get_env_from_file(ENVIRONMENT_DATA_LOCATION)
if args.config is not None:
config = Config.from_file(args.config)
else:
config = Config.from_file(CONFIG_FILE_LOCATION)
if args.create_default_config:
config.save_to_file(args.config if args.config is not None else CONFIG_FILE_LOCATION)
cmake_text = load_cmake()
cmake_version = get_version(cmake_text)[0]
print(f"Current Version: {cmake_version}")
if not (args.patch or args.minor or args.major):
try:
if args.action is not None:
type = args.action
else:
type = input("What kind of commit is this ((M)ajor, (m)inor, (p)atch)? ")
if type.startswith('M'):
args.major = True
elif type.startswith('m'):
args.minor = True
elif type.startswith('p') or type.startswith('P') or len(type) == 0:
args.patch = True
except KeyboardInterrupt:
print("\nCancelling!")
return
if args.major:
print("Selected major")
write_cmake(inc_major(cmake_text))
elif args.minor:
print("Selected minor")
write_cmake(inc_minor(cmake_text))
elif args.patch:
print("Selected patch")
write_cmake(inc_patch(config, cmake_text))
subprocess.call(["git", "add", "*"])
subprocess.call(["git", "commit"])
cmake_text = load_cmake()
version_parts = split_version(cmake_text)[0]
if args.major:
if config.branch_on_major:
make_branch(config, "v" + str(version_parts[0]))
if args.minor:
if config.branch_on_minor:
make_branch(config, "v" + str(version_parts[0]) + "." + str(version_parts[1]))
subprocess.call(["sh", "-c", "git remote | xargs -L1 git push --all"])
if args.major:
if config.release_on_major:
make_release(env, "v" + str(version_parts[0]))
if args.minor:
if config.release_on_minor:
make_release(env, "v" + str(version_parts[0]) + "." + str(version_parts[1]))
if __name__ == "__main__":
main()

56
default.nix Normal file
View File

@ -0,0 +1,56 @@
{ pkgs ? (import <nixpkgs> {
config.allowUnfree = true;
config.segger-jlink.acceptLicense = true;
}), customPkgs ? (import /home/brett/my-nixpkgs {
config.allowUnfree = true;
config.segger-jlink.acceptLicense = true;
}), ... }:
pkgs.mkShell
{
buildInputs = with pkgs; [
cmake
gcc
clang
emscripten
ninja
customPkgs.jetbrains.clion
#clion = import ~/my-nixpkgs/pkgs/applications/editors/jetbrains {};
renderdoc
valgrind
];
propagatedBuildInputs = with pkgs; [
xorg.libX11
xorg.libX11.dev
xorg.libXcursor
xorg.libXcursor.dev
xorg.libXext
xorg.libXext.dev
xorg.libXinerama
xorg.libXinerama.dev
xorg.libXrandr
xorg.libXrandr.dev
xorg.libXrender
xorg.libXrender.dev
xorg.libxcb
xorg.libxcb.dev
xorg.libXi
xorg.libXi.dev
harfbuzz
harfbuzz.dev
zlib
zlib.dev
bzip2
bzip2.dev
pngpp
brotli
brotli.dev
pulseaudio.dev
git
libGL
libGL.dev
glfw
openblas
openblas.dev
];
LD_LIBRARY_PATH="/run/opengl-driver/lib:/run/opengl-driver-32/lib";
}

27
include/MNIST.h Normal file
View File

@ -0,0 +1,27 @@
#pragma once
/*
* Copyright (C) 2024 Brett Terpstra
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
#ifndef MNIST_H
#define MNIST_H
namespace fp
{
void run_mnist();
}
#endif //MNIST_H

1
lib/blt-with-graphics Submodule

@ -0,0 +1 @@
Subproject commit 29286e66daa724ef08692d9a65e9c88e5467d9b2

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

285
src/MNIST.cpp Normal file
View File

@ -0,0 +1,285 @@
/*
* <Short Description>
* Copyright (C) 2024 Brett Terpstra
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
#include <MNIST.h>
#include <blt/fs/loader.h>
#include <blt/std/memory.h>
#include <blt/std/memory_util.h>
#include <variant>
#include <blt/iterator/iterator.h>
#include <dlib/dnn.h>
#include <dlib/data_io.h>
namespace fp
{
class idx_file_t
{
template <typename T>
using mk_v = std::vector<T>;
using vec_t = std::variant<mk_v<blt::u8>, mk_v<blt::i8>, mk_v<blt::u16>, mk_v<blt::u32>, mk_v<blt::f32>, mk_v<blt::f64>>;
public:
explicit idx_file_t(const std::string& path)
{
std::ifstream file{path, std::ios::in | std::ios::binary};
using char_type = std::ifstream::char_type;
char_type magic_arr[4];
file.read(magic_arr, 4);
BLT_ASSERT(magic_arr[0] == 0 && magic_arr[1] == 0);
blt::u8 dims = magic_arr[3];
blt::size_t total_size = 1;
for (blt::i32 i = 0; i < dims; i++)
{
char_type dim_arr[4];
file.read(dim_arr, 4);
blt::u32 dim;
blt::mem::fromBytes(dim_arr, dim);
dimensions.push_back(dim);
total_size *= dim;
}
switch (magic_arr[2])
{
// unsigned char
case 0x08:
data = mk_v<blt::u8>{};
read_data<blt::u8>(file, total_size);
break;
// signed char
case 0x09:
data = mk_v<blt::i8>{};
read_data<blt::i8>(file, total_size);
break;
// short
case 0x0B:
data = mk_v<blt::u16>{};
read_data<blt::u16>(file, total_size);
reverse_data<blt::u16>();
break;
// int
case 0x0C:
data = mk_v<blt::u32>{};
read_data<blt::u32>(file, total_size);
reverse_data<blt::u32>();
break;
// float
case 0x0D:
data = mk_v<blt::f32>{};
read_data<blt::f32>(file, total_size);
reverse_data<blt::f32>();
break;
// double
case 0x0E:
data = mk_v<blt::f64>{};
read_data<blt::f64>(file, total_size);
reverse_data<blt::f64>();
break;
default:
BLT_ERROR("Unspported idx file type!");
}
if (file.eof())
{
BLT_ERROR("EOF reached. It's unlikely your file was read correctly!");
}
}
template <typename T>
[[nodiscard]] const std::vector<T>& get_data_as() const
{
return std::get<mk_v<T>>(data);
}
template <typename T>
std::vector<blt::span<T>> get_as_spans() const
{
std::vector<blt::span<T>> spans;
blt::size_t total_size = data_size(1);
for (blt::size_t i = 0; i < dimensions[0]; i++)
{
auto& array = std::get<mk_v<T>>(data);
spans.push_back({&array[i * total_size], total_size});
}
return spans;
}
[[nodiscard]] const std::vector<blt::u32>& get_dimensions() const
{
return dimensions;
}
[[nodiscard]] blt::size_t data_size(const blt::size_t starting_dimension = 0) const
{
blt::size_t total_size = 1;
for (const auto d : blt::iterate(dimensions).skip(starting_dimension))
total_size *= d;
return total_size;
}
private:
template <typename T>
void read_data(std::ifstream& file, blt::size_t total_size)
{
auto& array = std::get<mk_v<T>>(data);
array.resize(total_size);
file.read(reinterpret_cast<char*>(array.data()), static_cast<std::streamsize>(total_size) * sizeof(T));
}
template <typename T>
void reverse_data()
{
auto& array = std::get<mk_v<T>>(data);
for (auto& v : array)
blt::mem::reverse(v);
}
std::vector<blt::u32> dimensions;
vec_t data;
};
class image_t
{
public:
static constexpr blt::u32 target_size = 10;
image_t(const idx_file_t& image_data, const idx_file_t& label_data): samples(image_data.get_dimensions()[0]),
input_size(image_data.data_size(1))
{
BLT_ASSERT_MSG(samples == label_data.get_dimensions()[0],
("Mismatch in data sample sizes! " + std::to_string(samples) + " vs " + std::to_string(label_data.get_dimensions()[0])).
c_str());
auto& image_array = image_data.get_data_as<blt::u8>();
auto& label_array = label_data.get_data_as<blt::u8>();
for (const auto label : label_array)
image_labels.push_back(label);
const auto row_length = image_data.get_dimensions()[2];
const auto number_of_rows = image_data.get_dimensions()[1];
for (blt::u32 i = 0; i < samples; i++)
{
dlib::matrix<blt::u8> mat(number_of_rows, row_length);
for (blt::u32 y = 0; y < number_of_rows; y++)
{
for (blt::u32 x = 0; x < row_length; x++)
{
mat(x, y) = image_array[i * input_size + y * row_length + x];
}
}
data.push_back(mat);
}
}
[[nodiscard]] const std::vector<dlib::matrix<blt::u8>>& get_image_data() const
{
return data;
}
[[nodiscard]] const std::vector<blt::u64>& get_image_labels() const
{
return image_labels;
}
private:
blt::u32 samples;
blt::u32 input_size;
std::vector<dlib::matrix<blt::u8>> data;
std::vector<blt::u64> image_labels;
};
void run_mnist()
{
using namespace dlib;
idx_file_t test_images{"../problems/mnist/t10k-images.idx3-ubyte"};
idx_file_t test_labels{"../problems/mnist/t10k-labels.idx1-ubyte"};
idx_file_t train_images{"../problems/mnist/train-images.idx3-ubyte"};
idx_file_t train_labels{"../problems/mnist/train-labels.idx1-ubyte"};
auto train_samples = train_images.get_dimensions()[0];
auto test_samples = test_images.get_dimensions()[0];
auto columns = train_images.get_dimensions()[1];
auto rows = train_images.get_dimensions()[2];
auto input_size = static_cast<blt::u32>(train_images.data_size(1));
image_t train_image{train_images, train_labels};
image_t test_image{test_images, test_labels};
using net_type = loss_multiclass_log<
fc<10,
relu<fc<84,
relu<fc<120,
max_pool<2,2,2,2,relu<con<16,5,5,1,1,
max_pool<2,2,2,2,relu<con<6,5,5,1,1,
input<matrix<blt::u8>>>>>>>>>>>>>>;
net_type network{};
dnn_trainer trainer(network);
trainer.set_learning_rate(0.01);
trainer.set_min_learning_rate(0.00001);
trainer.set_mini_batch_size(128);
trainer.be_verbose();
trainer.set_synchronization_file("mnist_sync", std::chrono::seconds(20));
trainer.train(train_image.get_image_data(), train_image.get_image_labels());
trainer.train_one_step(train_image.get_image_data(), train_image.get_image_labels());
network.clean();
serialize("mnist_network.dat") << network;
std::vector<unsigned long> predicted_labels = network(train_image.get_image_data());
int num_right = 0;
int num_wrong = 0;
// And then let's see if it classified them correctly.
for (size_t i = 0; i < train_image.get_image_data().size(); ++i)
{
if (predicted_labels[i] == train_image.get_image_labels()[i])
++num_right;
else
++num_wrong;
}
std::cout << "training num_right: " << num_right << std::endl;
std::cout << "training num_wrong: " << num_wrong << std::endl;
std::cout << "training accuracy: " << num_right / static_cast<double>(num_right + num_wrong) << std::endl;
predicted_labels = network(test_image.get_image_data());
num_right = 0;
num_wrong = 0;
for (size_t i = 0; i < test_image.get_image_data().size(); ++i)
{
if (predicted_labels[i] == test_image.get_image_labels()[i])
++num_right;
else
++num_wrong;
}
std::cout << "testing num_right: " << num_right << std::endl;
std::cout << "testing num_wrong: " << num_wrong << std::endl;
std::cout << "testing accuracy: " << num_right / static_cast<double>(num_right + num_wrong) << std::endl;
}
}

8
src/main.cpp Normal file
View File

@ -0,0 +1,8 @@
#include <iostream>
#include <MNIST.h>
int main()
{
fp::run_mnist();
}