3. Models
Spleeter is provided with 3 pretrained models:
Model | Description |
---|---|
2 stems model | Performs vocals + accompaniment separation |
4 stems model (1) | Performs vocals + drums + bass + other separation |
5 stems model | Performs vocals + piano + drums + bass + other separation |
(1) The pretrained 4 stems model performs very well on the standard musDB benchmark, as shown in the associated paper.
We use GitHub Releases to distribute the pretrained models. Every change in the model training will result in a new release of Spleeter. By default, the model provider will download models for the latest available version.
If you intend to use Spleeter for your research and want to enable full reproducibility, you should specify which version you used (1).
(1) A specific model version can be set using the `GITHUB_RELEASE` environment variable.
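As an illustration, the release can be pinned from Python before Spleeter is used. This is only a minimal sketch: the tag `v1.4.0` is a placeholder for whichever release you actually used, and it assumes the variable is read when the pretrained model is first downloaded.

```python
import os

# Placeholder tag (assumption): replace it with the release you want to pin.
PINNED_RELEASE = "v1.4.0"

# Set the variable before Spleeter downloads any pretrained model,
# so that the model provider can pick it up.
os.environ["GITHUB_RELEASE"] = PINNED_RELEASE

print("Pinned model release:", os.environ["GITHUB_RELEASE"])
```

Setting the variable in the shell that launches Spleeter should have the same effect.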
To set up a model for training, you must create a JSON configuration file containing all the parameters needed for training. The same file is then used at prediction time. The parameters that can be set in this file are described below.
The following parameters should be set for the model description:
Key | Description |
---|---|
`instrument_list` | A list of instrument names (example: `["vocals", "accompaniment"]`). |
`mix_name` | The prefix used for the mixture filename (default is `"mix"`). |
`model_dir` | Path where the trained model is stored. |
`model` | A dictionary that describes the model. |
The last parameter, `model`, should contain the following keys:
Model key | Description |
---|---|
`type` | The type of the model. The model is dynamically loaded from the `model` subfolder of `spleeter`. You can add your own model builder in this submodule. Modules for building a U-net and a Bi-LSTM are provided. Default: `"unet.unet"`. |
`params` | A dictionary of model parameters. Its content depends on the model: it is passed to the model at build time and can store model-specific parameters such as the number of layers or the number of units per layer. Check the code of the provided models for more details. |
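
The model description keys above could be assembled as follows. This is a sketch with illustrative values: the `model_dir` path is hypothetical and `params` is left empty, so it does not reproduce the settings of any pretrained model.

```python
import json

# Illustrative values only (assumptions), not the configuration of a pretrained model.
model_description = {
    "instrument_list": ["vocals", "accompaniment"],
    "mix_name": "mix",
    "model_dir": "my_model",  # hypothetical output directory
    "model": {
        "type": "unet.unet",
        "params": {},  # model-specific parameters would go here
    },
}

# Serialize to the JSON text that would be written to the configuration file.
config_text = json.dumps(model_description, indent=4)
print(config_text)
```

A complete configuration file would also contain the other parameters described in this section.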
Key | Description |
---|---|
`sample_rate` | Sample rate used for loading the audio files (files are resampled to this sample rate). |
`frame_length` | Frame length of the short-time Fourier transform. |
`frame_step` | Frame step of the short-time Fourier transform. |
`T` | Time length of the input spectrogram segment (expressed in short-time Fourier transform frames). Must be a power of 2. |
`F` | Number of frequency bins to be processed (frequencies above `F` are not processed). Must be a power of 2. |
`n_channels` | Number of channels of the input audio (audio with a different number of channels is discarded at training/validation time and remixed to the provided number of channels at prediction time). |
`n_chunks_per_song` | Number of audio chunks to be used per song. |
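
To get a feel for how these parameters interact, the sketch below checks the power-of-2 constraints and derives two rough quantities. The numeric values are illustrative assumptions, and the frequency estimate assumes the FFT size equals `frame_length`.

```python
def is_power_of_two(n: int) -> bool:
    """Return True when n is a positive power of 2."""
    return n > 0 and (n & (n - 1)) == 0


# Illustrative values (assumptions, not taken from the tables above).
audio_params = {
    "sample_rate": 44100,
    "frame_length": 4096,
    "frame_step": 1024,
    "T": 512,
    "F": 1024,
    "n_channels": 2,
}

# T and F must both be powers of 2.
for key in ("T", "F"):
    if not is_power_of_two(audio_params[key]):
        raise ValueError(f"{key} must be a power of 2")

# Approximate duration of one input segment:
# T frames spaced frame_step samples apart.
segment_seconds = (
    audio_params["T"] * audio_params["frame_step"] / audio_params["sample_rate"]
)

# Approximate upper frequency covered by the first F bins
# (bin k sits at k * sample_rate / frame_length).
max_frequency_hz = (
    audio_params["F"] * audio_params["sample_rate"] / audio_params["frame_length"]
)

print(f"segment duration: {segment_seconds:.2f} s")
print(f"frequencies processed up to about {max_frequency_hz:.0f} Hz")
```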
Key | Description |
---|---|
`separation_exponent` | Exponent applied to the magnitude spectrograms to compute ratio masks (default: 2, which corresponds to basic Wiener filtering). |
`mask_extension` | As the model estimates instrument spectrograms only up to frequency bin `F` (see above), ratio masks have to be extended for separation. `mask_extension` can be `"zeros"` (masks are set to zero above `F`, thus discarding these frequencies) or `"average"` (for each time frame, masks above `F` are set to the average value of the estimated mask below `F`). |
`MWF` | Whether to use multi-channel Wiener filtering with Norbert. Note that multi-channel Wiener filtering is not implemented in TensorFlow and thus might be slow. |
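
The roles of `separation_exponent` and `mask_extension` can be illustrated on a single time frame. The sketch below is written in plain Python for readability and is not Spleeter's implementation: the helper names `ratio_masks` and `extend_mask`, the small `eps` term, and the example magnitudes are all assumptions.

```python
def ratio_masks(magnitudes, exponent=2, eps=1e-10):
    """Compute ratio masks for one time frame.

    magnitudes maps each instrument name to a list of estimated
    magnitudes, one value per frequency bin.
    """
    names = list(magnitudes)
    n_bins = len(magnitudes[names[0]])
    masks = {name: [] for name in names}
    for k in range(n_bins):
        powered = {name: magnitudes[name][k] ** exponent for name in names}
        total = sum(powered.values()) + eps  # eps avoids division by zero
        for name in names:
            masks[name].append(powered[name] / total)
    return masks


def extend_mask(mask, n_total_bins, mode="zeros"):
    """Extend a mask estimated on the first bins to n_total_bins bins."""
    if mode == "zeros":
        fill = 0.0
    elif mode == "average":
        fill = sum(mask) / len(mask)
    else:
        raise ValueError(f"unknown mask extension mode: {mode}")
    return list(mask) + [fill] * (n_total_bins - len(mask))


# Hypothetical estimated magnitudes for two bins of one frame.
estimates = {"vocals": [3.0, 1.0], "accompaniment": [1.0, 1.0]}
masks = ratio_masks(estimates, exponent=2)

vocals_zeros = extend_mask(masks["vocals"], 4, mode="zeros")
vocals_average = extend_mask(masks["vocals"], 4, mode="average")
print(vocals_zeros)
print(vocals_average)
```

With `"zeros"` the bins above the estimated range are silenced for every instrument, while `"average"` keeps them and shares them according to the mean mask value of the frame.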
Key | Description |
---|---|
`train_csv` | Path to a CSV file that describes the training dataset. The CSV file has a column `<mix_name>_path` (the path to the mixture audio file), one column per instrument named `<instrument_name>_path` (the path to the audio file of the corresponding isolated instrument), and a column `duration` (the duration in seconds of the audio files, which should be the same for the mixture and the separated instrument audio files). |
`validation_csv` | Path to a CSV file that describes the validation dataset. The CSV file has the same format as the one provided in `train_csv`. |
`learning_rate` | Learning rate of the optimizer. |
`random_seed` | Seed used for random generators (randomness is used in model initialization, data providing, and possibly dropout). |
`batch_size` | Size of the mini-batch used in optimization. |
`training_cache` | Path where features are cached for training. Features (spectrograms of the mix and of the isolated instruments) are cached to avoid recomputing them. |
`validation_cache` | Path where features are cached for validation. |
`train_max_steps` | Number of training steps (see the TensorFlow estimator documentation). |
`throttle_secs` | See the TensorFlow estimator documentation. |
`save_checkpoints_steps` | See the TensorFlow estimator documentation. |
`save_summary_steps` | See the TensorFlow estimator documentation. |
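
The CSV layout expected by `train_csv` and `validation_csv` can be sketched with the standard `csv` module. The file paths and durations below are hypothetical, and the column names follow from assuming `mix_name` is `"mix"` and the instruments are `vocals` and `accompaniment`.

```python
import csv
import io

mix_name = "mix"
instrument_list = ["vocals", "accompaniment"]

# One path column for the mixture, one per instrument, then the duration.
fieldnames = (
    [f"{mix_name}_path"]
    + [f"{name}_path" for name in instrument_list]
    + ["duration"]
)

# Hypothetical dataset entries (paths and durations are made up).
rows = [
    {
        "mix_path": "song1/mix.wav",
        "vocals_path": "song1/vocals.wav",
        "accompaniment_path": "song1/accompaniment.wav",
        "duration": 215.3,
    },
    {
        "mix_path": "song2/mix.wav",
        "vocals_path": "song2/vocals.wav",
        "accompaniment_path": "song2/accompaniment.wav",
        "duration": 187.0,
    },
]

# Write to an in-memory buffer; a real script would open a file instead.
buffer = io.StringIO()
writer = csv.DictWriter(buffer, fieldnames=fieldnames, lineterminator="\n")
writer.writeheader()
writer.writerows(rows)
csv_text = buffer.getvalue()
print(csv_text)
```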
The pre-processing pipeline is implemented with the TensorFlow data API, which enables caching of features and labels (spectrograms) and asynchronous data loading. Training is done with the TensorFlow estimator API, using a custom estimator. For efficiency reasons, the estimator behaves differently in train/evaluate mode and in predict mode. In train/evaluate mode, it takes spectrograms as input (the mix spectrogram as features and the target instrument spectrograms as labels) and outputs estimated spectrograms. This makes it possible to leverage spectrogram caching, which avoids recomputing features and labels throughout training. In predict mode, the estimator takes the mixture waveform as input and outputs the estimated isolated instrument waveforms. This makes it possible to run the whole separation graph on GPU, which makes the separation process very fast (more than 100x realtime).