Skip to content
This repository has been archived by the owner on Aug 10, 2022. It is now read-only.

Various fixes and cleanup #3

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
# spleeter-pytorch
Spleeter implementation in pytorch.

## Requirements

To install requirements, run `pip install -r requirements.txt`

## Usage

See [example](./test_estimator.py) for how to use it.
See [example](run_estimator.py) for how to use it.


## Note
Expand Down
Binary file removed output/out_0.wav
Binary file not shown.
Binary file removed output/out_1.wav
Binary file not shown.
5 changes: 5 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
numpy==1.18.5
tensorflow==2.3.1
torch==1.7.0
torchaudio==0.7.0
librosa==0.8.0
29 changes: 29 additions & 0 deletions run_estimator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import torchaudio
import soundfile as sf

from spleeter.estimator import Estimator
import os


# Module-level estimator shared by main(); built once at import time.
# NOTE(review): the first argument (2) is presumably the number of stems to
# separate, matching the '2stems' checkpoint path — confirm against Estimator.
es = Estimator(2, './checkpoints/2stems/model')


def main(original_audio='./audio_example.mp3', out_dir='./output'):
    """Separate an audio file with the module-level estimator and save the results.

    Each separated source returned by ``es.separate`` is written to
    ``<out_dir>/out_<i>.wav`` at the sample rate of the input file.

    Args:
        original_audio: Path of the audio file to load with ``torchaudio.load``.
        out_dir: Directory that receives the output WAV files. It is created
            if it does not exist yet.
    """
    # load audio
    wav, sr = torchaudio.load(original_audio)

    # normalize audio
    # NOTE(review): this divides by the signed maximum, not the absolute peak,
    # so a signal whose largest-magnitude sample is negative is not bounded
    # to [-1, 1] — confirm whether abs().max() was intended.
    wav_torch = wav / (wav.max() + 1e-8)

    wavs = es.separate(wav_torch)

    # The output directory is not guaranteed to exist (e.g. on a fresh
    # checkout), and sf.write does not create missing parent directories.
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    for i, separated in enumerate(wavs):
        fname = os.path.join(out_dir, f'out_{i}.wav')
        print('Writing:', fname)
        # Drop singleton dimensions, then swap the two remaining axes before
        # writing — presumably (channels, samples) -> (samples, channels),
        # the layout soundfile expects; confirm against es.separate's output.
        new_wav = separated.squeeze().permute(1, 0).numpy()
        sf.write(fname, new_wav, sr)


# When executed as a script, run the separation with main()'s default
# input file and output directory.
if __name__ == '__main__':
main()
3 changes: 1 addition & 2 deletions spleeter/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import torch
import torch.nn.functional as F
from torch import nn
from torchaudio.functional import istft

from .unet import UNet
from .util import tf2pytorch
Expand Down Expand Up @@ -88,7 +87,7 @@ def inverse_stft(self, stft):

pad = self.win_length // 2 + 1 - stft.size(1)
stft = F.pad(stft, (0, 0, 0, 0, 0, pad))
wav = istft(stft, self.win_length, hop_length=self.hop_length,
wav = torch.istft(stft, self.win_length, hop_length=self.hop_length,
window=self.win)
return wav.detach()

Expand Down
2 changes: 0 additions & 2 deletions spleeter/util.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import numpy as np
import tensorflow as tf

from .unet import UNet


def tf2pytorch(checkpoint_path, num_instrumments):
tf_vars = {}
Expand Down
21 changes: 0 additions & 21 deletions test_estimator.py

This file was deleted.