Skip to content

Commit

Permalink
Load model from state, don't overwrite params
Browse files Browse the repository at this point in the history
  • Loading branch information
jorshi committed Mar 21, 2024
1 parent 832ae76 commit 689bb81
Show file tree
Hide file tree
Showing 6 changed files with 18 additions and 13 deletions.
6 changes: 5 additions & 1 deletion .vscode/settings.json
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
{
"cmake.configureOnOpen": true,
"cmake.cmakePath": "/opt/homebrew/bin/cmake"
"cmake.cmakePath": "/opt/homebrew/bin/cmake",
"files.associations": {
"*.m": "matlab",
"iosfwd": "cpp"
}
}
3 changes: 2 additions & 1 deletion source/PluginProcessor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,9 @@ void TorchDrumProcessor::setStateInformation(const void* data,
auto modelPath = preset.getChildWithName("ModelPath");
if (modelPath.isValid())
{
// Load the model, but don't overwrite the parameters
std::string path = modelPath["Path"].toString().toStdString();
synthController.updateModel(path);
synthController.updateModel(path, false);
}
}
}
Expand Down
15 changes: 7 additions & 8 deletions source/SynthController.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,6 @@ void SynthController::prepare(double sr, int samplesPerBlock)
featureExtraction.prepare(sampleRate, ONSET_WINDOW_SIZE, ONSET_WINDOW_SIZE / 4);
featureBuffer.clear();
featureBuffer.setSize(1, ONSET_WINDOW_SIZE);

// Update synth parameters with the current patch
neuralMapper.getCurrentPatch(synth.getParameters().parameters);
synth.getParameters().updateAllParameters();
}

void SynthController::process(float x)
Expand Down Expand Up @@ -77,15 +73,18 @@ void SynthController::process(float x)
}
}

void SynthController::updateModel(const std::string& path)
void SynthController::updateModel(const std::string& path, bool updateParameters)
{
if (neuralMapper.loadModel(path))
{
modelPath = path;

// Update synth parameters with the current patch
neuralMapper.getCurrentPatch(synth.getParameters().parameters);
synth.getParameters().updateAllParameters();
// Update synth parameters with the preset stored in the model file
if (updateParameters)
{
neuralMapper.getCurrentPatch(synth.getParameters().parameters);
synth.getParameters().updateAllParameters();
}
}
}

Expand Down
2 changes: 1 addition & 1 deletion source/SynthController.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class SynthController
void process(float x);

// Update the neural network model
void updateModel(const std::string& path);
void updateModel(const std::string& path, bool updateParameters = true);

// Getters for audio buffers
const juce::AudioBuffer<float>& getBuffer() const { return buffer; }
Expand Down
3 changes: 2 additions & 1 deletion source/Utils/NeuralNetworkMock.cpp
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
#include "NeuralNetworkMock.h"

void NeuralNetwork::loadModel(const std::string& path)
bool NeuralNetwork::loadModel(const std::string& path)
{
modelLoaded = true;
return modelLoaded;
}

void NeuralNetwork::process(const std::vector<double>& input, std::vector<double>& output)
Expand Down
2 changes: 1 addition & 1 deletion source/Utils/NeuralNetworkMock.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class NeuralNetwork
NeuralNetwork() = default;
~NeuralNetwork() = default;

void loadModel(const std::string& path);
bool loadModel(const std::string& path);
void process(const std::vector<double>& input, std::vector<double>& output);
void getCurrentPatch(std::vector<juce::RangedAudioParameter*> parameters) {}

Expand Down

0 comments on commit 689bb81

Please sign in to comment.