# 3. Models


Spleeter is provided with 3 pretrained models:

| Model | Description |
| --- | --- |
| 2 stems model | Performs vocals + accompaniment separation |
| 4 stems model¹ | Performs vocals + drums + bass + other separation |
| 5 stems model | Performs vocals + piano + drums + bass + other separation |

¹ The pretrained 4 stems model performs very well on the standard musDB benchmark, as shown in the associated paper.
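
Each pretrained model can be selected at separation time through the `spleeter:2stems`, `spleeter:4stems` and `spleeter:5stems` configuration names, for instance (`audio_example.mp3` below is a placeholder input file):

```bash
# Separate a track with the pretrained 4 stems model
# (vocals / drums / bass / other)
spleeter separate -i audio_example.mp3 -o output/ -p spleeter:4stems
```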

## Model Versions

We use GitHub Releases to distribute the pretrained models. Every change in model training results in a new release of Spleeter. By default, the model provider downloads models for the latest available release.

If you intend to use Spleeter for your research and want to enable full reproducibility, you should specify which version you used¹.

¹ A specific model version can be selected by setting the `GITHUB_RELEASE` environment variable.
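
In practice, pinning the model release could look like the following sketch (the `v1.4.0` tag is illustrative, not necessarily an actual release; check the Spleeter releases page for real tags):

```bash
# Pin the pretrained model release for reproducibility
export GITHUB_RELEASE=v1.4.0
spleeter separate -i audio_example.mp3 -o output/ -p spleeter:2stems
```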

## Configuration file

To set up a model for training, you must create a JSON configuration file that contains all the parameters needed for training. The same file is then used at prediction time.
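
Such a configuration file is passed to the training command through its `-p` option, along the lines of the following sketch (paths are placeholders):

```bash
# Train a model described by a custom JSON configuration
spleeter train -p path/to/my_config.json -d path/to/audio_dataset
```

The parameters that can be set in this file are the following: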

### Model description

The following parameters should be set for the model description:

| Key | Description |
| --- | --- |
| `instrument_list` | A list of instrument names (example: `["vocals", "accompaniment"]`). |
| `mix_name` | The prefix used for the mixture filename (default: `"mix"`). |
| `model_dir` | Path where the trained model is stored. |
| `model` | A dictionary that describes the model. |

The last parameter, `model`, should contain the following keys:

| Model key | Description |
| --- | --- |
| `type` | The type of the model. The model is dynamically loaded from the `model` subfolder of `spleeter`, so you can add your own model builder in this submodule. Modules for building a U-net and a Bi-LSTM are provided. Default: `"unet.unet"`. |
| `params` | A dictionary of model parameters. It depends on the model: this dictionary is passed to the model at building time and can be used to store model-specific parameters such as the number of layers or the number of units per layer. Check the code of the provided models for more details. |
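
Put together, the model description part of a configuration file could look like the following sketch (all values are illustrative, loosely modeled on a 2 stems setup):

```json
{
    "instrument_list": ["vocals", "accompaniment"],
    "mix_name": "mix",
    "model_dir": "my_models/2stems",
    "model": {
        "type": "unet.unet",
        "params": {}
    }
}
```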

### Audio parameters

| Key | Description |
| --- | --- |
| `sample_rate` | Sample rate used for loading the audio files (files are resampled to this sample rate). |
| `frame_length` | Frame length of the short-time Fourier transform. |
| `frame_step` | Frame step of the short-time Fourier transform. |
| `T` | Time length of the input spectrogram segment (expressed in short-time Fourier transform frames). Must be a power of 2. |
| `F` | Number of frequency bins to be processed (frequencies above `F` are not processed). Must be a power of 2. |
| `n_channels` | Number of channels of the input audio (audio with a different number of channels is discarded at training/validation time and remixed to the provided number of channels at prediction time). |
| `n_chunks_per_song` | Number of audio chunks to be used per song. |
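
As an illustration, an audio section could look like the following sketch (values are indicative only; check the configuration files shipped with Spleeter for the exact values used by the pretrained models):

```json
{
    "sample_rate": 44100,
    "frame_length": 4096,
    "frame_step": 1024,
    "T": 512,
    "F": 1024,
    "n_channels": 2,
    "n_chunks_per_song": 1
}
```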

### Separation parameters

| Key | Description |
| --- | --- |
| `separation_exponent` | Exponent applied to the magnitude spectrogram to compute ratio masks (default: 2, which corresponds to basic Wiener filtering). |
| `mask_extension` | As the model estimates instrument spectrograms up to frequency bin `F` (see above), ratio masks have to be extended for separation. `mask_extension` can be `"zeros"` (masks are set to zero above `F`, thus discarding these frequencies) or `"average"` (for each time frame, masks above `F` are set to the average of the estimated mask values below `F`). |
| `MWF` | Whether to use multi-channel Wiener filtering with Norbert. Note that multi-channel Wiener filtering is not implemented in TensorFlow and thus might be slow. |
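
For instance, a separation section could read as follows (`separation_exponent` is set to its stated default; the `mask_extension` and `MWF` values are just one possible choice):

```json
{
    "separation_exponent": 2,
    "mask_extension": "zeros",
    "MWF": false
}
```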

### Training parameters

| Key | Description |
| --- | --- |
| `train_csv` | Path to a CSV file that describes the training dataset. The CSV file has a `<mix_name>_path` column (path to the mixture audio file), one `<instrument_name>_path` column per instrument (path to the audio file of the corresponding isolated instrument), and a `duration` column (duration in seconds of the audio files, which should be the same for the mixture and the isolated instrument files). |
| `validation_csv` | Path to a CSV file that describes the validation dataset. The CSV file has the same format as the one provided in `train_csv`. |
| `learning_rate` | Learning rate of the optimizer. |
| `random_seed` | Seed used for random generators (randomness is used in model initialization, data providing, and possibly dropout). |
| `batch_size` | Size of the mini-batch used in optimization. |
| `training_cache` | Path where features are cached for training. Features (spectrograms of the mixture and isolated instruments) are cached to avoid recurrent computation. |
| `validation_cache` | Path where features are cached for validation. |
| `train_max_steps` | Number of training steps (see the TensorFlow Estimator documentation). |
| `throttle_secs` | See the TensorFlow Estimator documentation. |
| `save_checkpoints_steps` | See the TensorFlow Estimator documentation. |
| `save_summary_steps` | See the TensorFlow Estimator documentation. |
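
A training section could then look like the following sketch (all paths and numeric values are placeholders, not recommended settings):

```json
{
    "train_csv": "path/to/train.csv",
    "validation_csv": "path/to/validation.csv",
    "learning_rate": 0.0001,
    "random_seed": 0,
    "batch_size": 4,
    "training_cache": "cache/training",
    "validation_cache": "cache/validation",
    "train_max_steps": 100000,
    "throttle_secs": 300,
    "save_checkpoints_steps": 1000,
    "save_summary_steps": 100
}
```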

## Implementation details

The pre-processing pipeline is coded with the TensorFlow data API, which enables caching of features/labels (spectrograms) and asynchronous data providing. Training is done with the TensorFlow Estimator API, using a custom estimator. For efficiency reasons, the estimator does not behave the same way in train/evaluate mode and in predict mode:

- In train/evaluate mode, the estimator takes spectrograms as input (the mixture spectrogram as features, and the target instrument spectrograms as labels) and outputs estimated spectrograms. This makes it possible to leverage spectrogram caching, avoiding recomputation of features/labels during the whole training.
- In predict mode, the estimator takes the mixture waveform as input and outputs estimated isolated instrument waveforms. This makes it possible to run the whole separation graph on GPU, making the separation process very fast (more than 100x real time).
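
The caching behavior can be illustrated with a minimal `tf.data` sketch (this is not Spleeter's actual pipeline code; `compute_spectrogram` is a hypothetical stand-in for the real feature computation):

```python
import tensorflow as tf

def compute_spectrogram(waveform):
    # Hypothetical stand-in for the real feature computation:
    # the magnitude of a short-time Fourier transform.
    return tf.abs(tf.signal.stft(waveform, frame_length=4096, frame_step=1024))

# Dummy dataset: eight one-second mono waveforms at 44.1 kHz.
waveforms = tf.data.Dataset.from_tensor_slices(tf.random.normal([8, 44100]))

features = (
    waveforms
    .map(compute_spectrogram)
    .cache("training_cache")  # spectrograms computed once, reused on every epoch
    .batch(4)
    .prefetch(1)              # asynchronous data providing: prepare the next batch ahead of time
)
```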
