Skip to content

Commit

Permalink
Add SNAFX test code
Browse files Browse the repository at this point in the history
  • Loading branch information
jatinchowdhury18 committed Dec 1, 2023
1 parent 5ca5629 commit 49bece3
Show file tree
Hide file tree
Showing 9 changed files with 367 additions and 0 deletions.
1 change: 1 addition & 0 deletions examples/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,4 @@ if(APPLE)
add_subdirectory(accelerate)
endif()
add_subdirectory(residual_connection)
add_subdirectory(snafx)
21 changes: 21 additions & 0 deletions examples/snafx/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Fetch {fmt} for formatted console output (pinned to a known-good release).
CPMAddPackage(
  NAME libfmt
  GIT_REPOSITORY https://github.com/fmtlib/fmt
  GIT_TAG 9.1.0
)

# Fetch libsndfile for WAV file reading/writing; build only the library.
CPMAddPackage(
  NAME libsndfile
  GIT_REPOSITORY https://github.com/libsndfile/libsndfile
  GIT_TAG v1.0.30
  OPTIONS
    "BUILD_PROGRAMS OFF"
    "BUILD_EXAMPLES OFF"
    "BUILD_TESTING OFF"
)

create_example(snafx)
target_link_libraries(snafx
  PRIVATE
    fmt::fmt
    sndfile
)
89 changes: 89 additions & 0 deletions examples/snafx/FiLM.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
#pragma once

#include <RTNeural/RTNeural.h>

namespace snafx
{
// Feature-wise Linear Modulation (FiLM) layer.
//
// A dense "adaptor" maps a conditioning vector (cond_dim) to per-feature
// gain (g) and bias (b) values; forward() then computes outs = g * ins + b.
// condition() must be called before forward() so the gains/biases are set.
// When use_batch_norm is true, the input is batch-normalized before the
// affine modulation.
template <typename T, int num_features, int cond_dim, bool use_batch_norm = false>
struct FiLM
{
#if RTNEURAL_USE_XSIMD
using v_type = xsimd::batch<T>;
static constexpr auto v_size = (int) v_type::size;
// condition() splits the adaptor output into gain/bias halves at SIMD-batch
// granularity, so num_features must fill whole batches (or be scalar).
static_assert (num_features == 1 || num_features % v_size == 0, "This implementation relies on num_features being an even multiple of batch_size!");
static constexpr auto num_features_vec = RTNeural::ceil_div (num_features, v_size);
#endif

// Loads the adaptor's dense weights/bias for block `block_index` from the
// model JSON (defined in FiLM.tpp).
void load_weights (const nlohmann::json& model_json, int block_index);

// Runs the conditioning vector through the adaptor and caches the results:
// the first num_features adaptor outputs become the gains (g_vals), the
// second num_features become the biases (b_vals).
void condition (const T (&cond_ins)[cond_dim]) noexcept
{
#if RTNEURAL_USE_XSIMD
std::copy (std::begin (cond_ins), std::end (cond_ins), reinterpret_cast<T*> (cond_in));
adaptor.forward (cond_in);
std::copy (adaptor.outs, adaptor.outs + num_features_vec, g_vals);
std::copy (adaptor.outs + num_features_vec, adaptor.outs + 2 * num_features_vec, b_vals);
#elif RTNEURAL_USE_EIGEN
std::copy (std::begin (cond_ins), std::end (cond_ins), cond_in.data());
adaptor.forward (cond_in);
std::copy (adaptor.outs.data(), adaptor.outs.data() + num_features, g_vals.data());
std::copy (adaptor.outs.data() + num_features, adaptor.outs.data() + 2 * num_features, b_vals.data());
#endif
}

#if RTNEURAL_USE_XSIMD
// Batch-norm variant: normalize the input, then apply the affine modulation.
template <bool B = use_batch_norm>
inline typename std::enable_if<B, void>::type
forward (const v_type (&ins)[num_features_vec]) noexcept
{
bn.forward (ins);
forward<false> (bn.outs);
}

// Plain variant: outs = g * ins + b, element-wise.
template <bool B = use_batch_norm>
inline typename std::enable_if<! B, void>::type
forward (const v_type (&ins)[num_features_vec]) noexcept
{
for (int i = 0; i < num_features_vec; ++i)
outs[i] = g_vals[i] * ins[i] + b_vals[i];
}
#elif RTNEURAL_USE_EIGEN
// Batch-norm variant: normalize the input, then apply the affine modulation.
template <bool B = use_batch_norm>
inline typename std::enable_if<B, void>::type
forward (const Eigen::Vector<T, num_features>& ins) noexcept
{
bn.forward (ins);
forward<false> (bn.outs);
}

// Plain variant: outs = g * ins + b, element-wise.
template <bool B = use_batch_norm>
inline typename std::enable_if<! B, void>::type
forward (const Eigen::Vector<T, num_features>& ins) noexcept
{
outs = g_vals.cwiseProduct (ins) + b_vals;
}
#endif

// Dense layer producing [gains..., biases...] from the conditioning input.
RTNeural::DenseT<T, cond_dim, 2 * num_features> adaptor;
// Only exercised when use_batch_norm == true (see forward() overloads above).
RTNeural::BatchNorm1DT<T, num_features, false> bn;

#if RTNEURAL_USE_XSIMD
v_type outs[num_features_vec] {};
#elif RTNEURAL_USE_EIGEN
Eigen::Vector<T, num_features> outs {};
#endif

private:
// Scratch input for the adaptor, plus the cached FiLM gains/biases.
#if RTNEURAL_USE_XSIMD
v_type cond_in[RTNeural::ceil_div (cond_dim, v_size)] {};
v_type g_vals[num_features_vec] {};
v_type b_vals[num_features_vec] {};
#elif RTNEURAL_USE_EIGEN
Eigen::Vector<T, cond_dim> cond_in {};
Eigen::Vector<T, num_features> g_vals {};
Eigen::Vector<T, num_features> b_vals {};
#endif
};
} // namespace snafx

#include "FiLM.tpp"
16 changes: 16 additions & 0 deletions examples/snafx/FiLM.tpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
namespace snafx
{
// Loads the FiLM adaptor's dense weights and bias from the model JSON,
// stored under "blocks.<block_index>.film.adaptor.{weight,bias}".
template <typename T, int num_features, int cond_dim, bool use_batch_norm>
void FiLM<T, num_features, cond_dim, use_batch_norm>::load_weights (const nlohmann::json& model_json, int block_index)
{
const auto adaptor_weights_key = std::string { "blocks." } + std::to_string (block_index) + ".film.adaptor.weight";
std::vector<std::vector<T>> adaptor_weights = model_json.at (adaptor_weights_key);
adaptor.setWeights (adaptor_weights);

const auto adaptor_bias_key = std::string { "blocks." } + std::to_string (block_index) + ".film.adaptor.bias";
std::vector<T> adaptor_bias = model_json.at (adaptor_bias_key).get<std::vector<T>>();
adaptor.setBias(adaptor_bias.data());

// Batch-norm weight loading is intentionally unimplemented: no test model
// with use_batch_norm == true exists yet, so instantiation is blocked here.
static_assert (! use_batch_norm, "TODO: figure out how to load the batch norm weights once we have a test model!");
}
} // namespace snafx
68 changes: 68 additions & 0 deletions examples/snafx/SNAFxModel.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
#pragma once

#include "TCNBlock.h"

namespace snafx
{
// A four-block temporal convolutional network (TCN) with FiLM conditioning,
// processing audio one sample at a time. Dilation rates grow 1 -> 10 ->
// 100 -> 1000 across the blocks; block0 expands 1 channel to n_channels and
// block3 collapses back down to a single output channel.
template <typename T, int n_channels = 32, int kernel_size = 9, int cond_dim = 2>
struct Model
{
// Loads the weights of all four blocks from the model JSON.
void load_model (const nlohmann::json& model_json)
{
block0.load_weights (model_json, 0);
block1.load_weights (model_json, 1);
block2.load_weights (model_json, 2);
block3.load_weights (model_json, 3);
}

// Clears the convolution state of every block.
void reset()
{
block0.reset();
block1.reset();
block2.reset();
block3.reset();
}

// Feeds the same conditioning vector to every block's FiLM layer.
// Must be called before forward() so the FiLM gains/biases are set.
void condition (const T (&cond_ins)[cond_dim]) noexcept
{
block0.film.condition (cond_ins);
block1.film.condition (cond_ins);
block2.film.condition (cond_ins);
block3.film.condition (cond_ins);
}

// Processes a single input sample through all four blocks and returns the
// corresponding output sample.
T forward (T input) noexcept
{
#if RTNEURAL_USE_XSIMD
// block0 has in_size == 1, so place the sample in lane 0 of a single
// SIMD batch with the remaining lanes zeroed.
std::fill (std::begin (arr), std::end (arr), T{});
arr[0] = input;
block0.forward ({ xsimd::load_aligned (arr) });
#elif RTNEURAL_USE_EIGEN
ins (0) = input;
block0.forward (ins);
#endif
block1.forward (block0.outs);
block2.forward (block1.outs);
block3.forward (block2.outs);

#if RTNEURAL_USE_XSIMD
// block3 has out_size == 1; extract lane 0 of its single output batch.
block3.outs[0].store_aligned (arr);
return arr[0];
#elif RTNEURAL_USE_EIGEN
return block3.outs (0);
#endif
}

TCNBlock<T, 1, n_channels, cond_dim, kernel_size, 1> block0;
TCNBlock<T, n_channels, n_channels, cond_dim, kernel_size, 10> block1;
TCNBlock<T, n_channels, n_channels, cond_dim, kernel_size, 100> block2;
TCNBlock<T, n_channels, 1, cond_dim, kernel_size, 1000> block3;

private:
#if RTNEURAL_USE_XSIMD
// Aligned scratch buffer for moving scalar samples in/out of SIMD batches.
alignas (RTNEURAL_DEFAULT_ALIGNMENT) T arr[xsimd::batch<T>::size] {};
#elif RTNEURAL_USE_EIGEN
Eigen::Vector<T, 1> ins {};
#endif
};
}
54 changes: 54 additions & 0 deletions examples/snafx/TCNBlock.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
#pragma once

#include "FiLM.h"

namespace snafx
{
// One TCN block: dilated Conv1D -> FiLM modulation -> PReLU activation,
// summed with a 1x1 "res" convolution of the block's raw input.
template <typename T, int in_size, int out_size, int cond_dim, int kernel_size, int dilation_rate>
struct TCNBlock
{
#if RTNEURAL_USE_XSIMD
// SIMD batch type and the number of batches covering the in/out channels.
using v_type = xsimd::batch<T>;
static constexpr auto v_size = (int) v_type::size;
static constexpr auto in_size_v = RTNeural::ceil_div (in_size, v_size);
static constexpr auto out_size_v = RTNeural::ceil_div (out_size, v_size);
#endif

// Loads conv/film/act/res weights for block `block_index` from the model
// JSON (defined in TCNBlock.tpp).
void load_weights (const nlohmann::json& model_json, int block_index);

// Clears the convolution state buffers (defined in TCNBlock.tpp).
void reset();

// Processes one frame: outs = PReLU(FiLM(conv(ins))) + res(ins).
#if RTNEURAL_USE_XSIMD
inline void forward (const v_type (&ins)[in_size_v]) noexcept
#elif RTNEURAL_USE_EIGEN
inline void forward (const Eigen::Vector<T, in_size>& ins) noexcept
#endif
{
// Main path: dilated convolution, FiLM modulation, PReLU activation.
conv.forward (ins);
film.forward (conv.outs);
act.forward (film.outs);

// Residual path: 1x1 convolution of the raw input.
res.forward (ins);

#if RTNEURAL_USE_XSIMD
for (int i = 0; i < out_size_v; ++i)
outs[i] = act.outs[i] + res.outs[i];
#elif RTNEURAL_USE_EIGEN
outs = act.outs + res.outs;
#endif
}

RTNeural::Conv1DT<T, in_size, out_size, kernel_size, dilation_rate, true> conv;
FiLM<T, out_size, cond_dim> film;
RTNeural::PReLUActivationT<T, out_size> act;
RTNeural::Conv1DT<T, in_size, out_size, 1, 1> res;

#if RTNEURAL_USE_XSIMD
v_type outs[out_size_v]{};
#elif RTNEURAL_USE_EIGEN
Eigen::Vector<T, out_size> outs {};
#endif
};
}

#include "TCNBlock.tpp"
47 changes: 47 additions & 0 deletions examples/snafx/TCNBlock.tpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
namespace snafx
{
// Reverses each 1D convolution kernel in-place (the innermost vectors of a
// [channels_out][channels_in][kernel] weight tensor).
//
// Templated on the element type so it matches whatever float type TCNBlock
// is instantiated with (the original hard-coded float, which would fail to
// compile for T == double), and so the definition is safe to include from
// multiple translation units — a non-template, non-inline function in this
// header-included .tpp would violate the one-definition rule.
template <typename T>
void reverse_kernels(std::vector<std::vector<std::vector<T>>>& conv_weights)
{
    for (auto& channel_weights : conv_weights)
    {
        for (auto& kernel : channel_weights)
        {
            std::reverse(kernel.begin(), kernel.end());
        }
    }
}

// Loads this block's weights from the model JSON, stored under
// "blocks.<block_index>.{conv,film,act,res}.*".
template <typename T, int in_size, int out_size, int cond_dim, int kernel_size, int dilation_rate>
void TCNBlock<T, in_size, out_size, cond_dim, kernel_size, dilation_rate>::load_weights (const nlohmann::json& model_json, int block_index)
{
// Kernels are flipped in time before loading — presumably the JSON stores
// them in PyTorch order while RTNeural's Conv1DT expects the reverse
// (TODO confirm against the exporter).
const auto conv_weights_key = std::string { "blocks." } + std::to_string (block_index) + ".conv.weight";
std::vector<std::vector<std::vector<T>>> conv_weights = model_json.at (conv_weights_key);
reverse_kernels (conv_weights);
conv.setWeights (conv_weights);

const auto conv_bias_key = std::string { "blocks." } + std::to_string (block_index) + ".conv.bias";
std::vector<T> conv_bias = model_json.at (conv_bias_key);
conv.setBias (conv_bias);

film.load_weights (model_json, block_index);

// The PReLU activation uses a single alpha value (element 0 of the stored
// weight) shared across all channels.
const auto act_weights_key = std::string { "blocks." } + std::to_string (block_index) + ".act.weight";
const auto act_weight = model_json.at (act_weights_key).at (0).get<T>();
act.setAlphaVals ({ act_weight });

const auto res_weights_key = std::string { "blocks." } + std::to_string (block_index) + ".res.weight";
std::vector<std::vector<std::vector<T>>> res_weights = model_json.at (res_weights_key);
reverse_kernels (res_weights);
res.setWeights (res_weights);

// The model stores no bias for the residual conv, so zero it explicitly.
std::vector<T> res_bias (res.out_size, (T) 0);
res.setBias (res_bias);
}

// Clears the state (sample history) buffers of both convolutions.
// The FiLM gains/biases set via condition() are not touched here — there is
// no film reset; they persist until condition() is called again.
template <typename T, int in_size, int out_size, int cond_dim, int kernel_size, int dilation_rate>
void TCNBlock<T, in_size, out_size, cond_dim, kernel_size, dilation_rate>::reset()
{
conv.reset();
res.reset();
}
} // namespace snafx
1 change: 1 addition & 0 deletions examples/snafx/model_comp.json

Large diffs are not rendered by default.

70 changes: 70 additions & 0 deletions examples/snafx/snafx.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
#include <algorithm>
#include <chrono>
#include <filesystem>
#include <fstream>
#include <string>
#include <utility>
#include <vector>

#include <fmt/format.h>
#include <sndfile.hh>

namespace chrono = std::chrono;
namespace fs = std::filesystem;

#include "SNAFxModel.h"

// Writes mono float samples to a 16-bit PCM WAV file at the given sample rate.
// Failures are reported to the console rather than silently ignored.
void write_file(const fs::path& file_path, const std::vector<float>& data, int sample_rate)
{
    fmt::print("Writing output to file: {}\n", file_path.string());
    SndfileHandle file { file_path.c_str(), SFM_WRITE, SF_FORMAT_WAV | SF_FORMAT_PCM_16, 1, sample_rate };

    if (file.error() != SF_ERR_NO_ERROR) // e.g. bad path or no write permission
    {
        fmt::print("Unable to open output file: {}\n", file.strError());
        return;
    }

    // write() returns the number of items actually written; report short writes.
    const auto items_written = file.write(data.data(), (sf_count_t)data.size());
    if (items_written != (sf_count_t)data.size())
        fmt::print("Warning: only wrote {} of {} samples!\n", items_written, data.size());
}

// Reads up to 10 seconds of audio from a mono WAV file.
// Returns the samples and the file's sample rate. On failure (unreadable
// file or non-mono audio) an error is printed and an empty vector with a
// zero sample rate is returned — callers must check before using the rate.
std::pair<std::vector<float>, int> read_file(const fs::path& file_path)
{
    fmt::print("Reading from input: {}\n", file_path.string());
    SndfileHandle file { file_path.c_str() };

    if (file.error() != SF_ERR_NO_ERROR) // e.g. missing file or unsupported format
    {
        fmt::print("Unable to open input file: {}\n", file.strError());
        return {};
    }

    fmt::print(" File sample rate: {}\n", file.samplerate());
    fmt::print(" File # channels: {}\n", file.channels());
    fmt::print(" File # samples: {}\n", file.frames());

    if(file.channels() != 1)
    {
        fmt::print("Expected a mono input file, but this file has {} channels!\n", file.channels());
        return {};
    }

    // Cap the amount of audio processed at the first 10 seconds.
    std::vector<float> data;
    data.resize(std::min((size_t)file.frames(), size_t(10.0 * file.samplerate())), 0.0f);
    file.read(data.data(), (sf_count_t)data.size());
    return { data, (int)file.samplerate() };
}

// Loads the SNAFX model, runs "test.wav" through it sample-by-sample,
// reports processing speed, and writes the result to "test_out.wav".
int main()
{
    const auto model_file_path = fs::path { std::string { RTNEURAL_EXPERIMENTS_SOURCE_DIR } + std::string { "/examples/snafx/model_comp.json" } };

    std::ifstream json_stream(model_file_path, std::ifstream::binary);
    if (! json_stream.is_open()) // fail with a clear message instead of a JSON parse exception
    {
        fmt::print("Unable to open model file: {}\n", model_file_path.string());
        return 1;
    }
    const auto model_json = nlohmann::json::parse(json_stream);
    // fmt::print ("{}\n", model_json.dump());

    snafx::Model<float, 32, 9, 2> model;
    model.load_model(model_json);
    model.reset();
    model.condition({ 0.0f, 0.0f }); // fixed conditioning for both control parameters

    const auto [test_ins, sample_rate] = read_file("test.wav");
    if (test_ins.empty() || sample_rate <= 0) // read_file returns {} on failure; also guards the divisions below
    {
        fmt::print("No input audio to process!\n");
        return 1;
    }
    const auto num_samples = test_ins.size();
    std::vector<float> test_outs(num_samples, 0.0f);

    const auto start_time = chrono::high_resolution_clock::now();

    // Sample-by-sample inference; the 0.25 factor scales down the output
    // level (presumably for headroom — TODO confirm).
    for(size_t i = 0; i < num_samples; ++i)
        test_outs[i] = 0.25f * model.forward(test_ins[i]);

    const auto stop_time = chrono::high_resolution_clock::now();
    const auto duration = chrono::duration_cast<chrono::milliseconds>(stop_time - start_time);
    const auto seconds_processed = (float)num_samples / (float)sample_rate;
    fmt::print("Processed {} seconds of audio at {} sample rate in {} milliseconds\n",
               seconds_processed,
               sample_rate,
               duration.count());
    fmt::print("Speed: {:.4f}x real-time\n", seconds_processed / (0.001 * (double)duration.count()));

    write_file("test_out.wav", test_outs, sample_rate);

    return 0;
}

0 comments on commit 49bece3

Please sign in to comment.