Skip to content

Commit

Permalink
saving the loaded model path to state
Browse files Browse the repository at this point in the history
  • Loading branch information
jorshi committed Mar 21, 2024
1 parent c1f084b commit 832ae76
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 15 deletions.
5 changes: 4 additions & 1 deletion source/NeuralNetwork.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#include "NeuralNetwork.h"

void NeuralNetwork::loadModel(const std::string& path)
bool NeuralNetwork::loadModel(const std::string& path)
{
// Obtain the write lock - this will block until the lock is acquired
const juce::ScopedWriteLock writeLock(modelLock);
Expand All @@ -12,6 +12,9 @@ void NeuralNetwork::loadModel(const std::string& path)
{
_testModel();
}

// Return the model loaded status
return modelLoaded;
}

void NeuralNetwork::process(const std::vector<double>& input, std::vector<double>& output)
Expand Down
2 changes: 1 addition & 1 deletion source/NeuralNetwork.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class NeuralNetwork
NeuralNetwork() = default;
~NeuralNetwork() = default;

void loadModel(const std::string& path);
bool loadModel(const std::string& path);
void process(const std::vector<double>& input, std::vector<double>& output);
void getCurrentPatch(std::vector<juce::RangedAudioParameter*> parameters);

Expand Down
12 changes: 11 additions & 1 deletion source/PluginProcessor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,11 @@ void TorchDrumProcessor::getStateInformation(juce::MemoryBlock& destData)

juce::ValueTree pluginPreset("TorchDrum");
pluginPreset.appendChild(params, nullptr);
// This a good place to add any non-parameters to your preset

// Save model path and feature normalizers
juce::ValueTree modelPath("ModelPath");
modelPath.setProperty("Path", juce::String(synthController.getModelPath()), nullptr);
pluginPreset.appendChild(modelPath, nullptr);

copyXmlToBinary(*pluginPreset.createXml(), destData);
}
Expand All @@ -106,6 +110,12 @@ void TorchDrumProcessor::setStateInformation(const void* data,
}

// Load your non-parameter data now
auto modelPath = preset.getChildWithName("ModelPath");
if (modelPath.isValid())
{
std::string path = modelPath["Path"].toString().toStdString();
synthController.updateModel(path);
}
}
}

Expand Down
28 changes: 16 additions & 12 deletions source/SynthController.cpp
Original file line number Diff line number Diff line change
@@ -1,8 +1,15 @@
#include "SynthController.h"

SynthController::SynthController(SynthBase& synth, Parameters& params)
: synth(synth), parameters(params)
: synth(synth), parameters(params), modelPath("")
{
// Prepare input and output features for NN
size_t numSynthParams = synth.getParameters().parameters.size();
neuralInput.resize(3);
neuralOutput.resize(numSynthParams);

// Update input and output features for the neural network
neuralMapper.setInOutFeatures(3, numSynthParams);
}

void SynthController::prepare(double sr, int samplesPerBlock)
Expand All @@ -24,14 +31,6 @@ void SynthController::prepare(double sr, int samplesPerBlock)
featureBuffer.clear();
featureBuffer.setSize(1, ONSET_WINDOW_SIZE);

// Prepare input and output features for NN
size_t numSynthParams = synth.getParameters().parameters.size();
neuralInput.resize(3);
neuralOutput.resize(numSynthParams);

// Load the neural network model
neuralMapper.setInOutFeatures(3, numSynthParams);

// Update synth parameters with the current patch
neuralMapper.getCurrentPatch(synth.getParameters().parameters);
synth.getParameters().updateAllParameters();
Expand Down Expand Up @@ -80,9 +79,14 @@ void SynthController::process(float x)

void SynthController::updateModel(const std::string& path)
{
neuralMapper.loadModel(path);
neuralMapper.getCurrentPatch(synth.getParameters().parameters);
synth.getParameters().updateAllParameters();
if (neuralMapper.loadModel(path))
{
modelPath = path;

// Update synth parameters with the current patch
neuralMapper.getCurrentPatch(synth.getParameters().parameters);
synth.getParameters().updateAllParameters();
}
}

void SynthController::addSampleToBuffer(float x)
Expand Down
3 changes: 3 additions & 0 deletions source/SynthController.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ class SynthController
// Get the onset detection object
OnsetDetection& getOnsetDetection() { return onsetDetection; }

std::string& getModelPath() { return modelPath; }

private:
// Add a sample to the circular audio buffer
void addSampleToBuffer(float x);
Expand All @@ -80,6 +82,7 @@ class SynthController
NeuralNetwork neuralMapper;
std::vector<double> neuralInput;
std::vector<double> neuralOutput;
std::string modelPath;
juce::Random random;

// ActionBroadcaster for sending messages to the GUI
Expand Down

0 comments on commit 832ae76

Please sign in to comment.