-
Notifications
You must be signed in to change notification settings - Fork 37
/
filter_audio_separation.py
95 lines (73 loc) · 4.12 KB
/
filter_audio_separation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
from demucs import pretrained
from demucs.apply import apply_model
from demucs.audio import convert_audio
from datasets import load_dataset
from multiprocess import set_start_method
import torch
import argparse
from datasets import Audio
# Load the pretrained "htdemucs" Demucs source-separation model once, at import
# time (presumably on the CPU; filter_stems may later move it to a GPU).
# NOTE(review): the map workers are started with the "spawn" method, so each
# worker re-imports this module and loads its own copy of the model — confirm the
# per-worker memory cost is acceptable.
demucs = pretrained.get_model('htdemucs')
# Stem names reported by the model (presumably in the order of its output stems).
# Not referenced elsewhere in the visible part of this file.
source = demucs.sources
def wrap_audio(audio, sr):
    """Package a waveform tensor as a ``datasets`` Audio-style dict.

    Args:
        audio: waveform tensor on any device. It is detached from the autograd
            graph and copied to host memory before conversion.
        sr: sampling rate of ``audio``.

    Returns:
        ``{"array": numpy.ndarray, "sampling_rate": sr}``.
    """
    return {
        # detach() so tensors that still require grad can be converted;
        # Tensor.numpy() refuses them otherwise.
        "array": audio.detach().cpu().numpy(),
        "sampling_rate": sr,
    }
# TODO(YL): make compatible with other stem layouts; the vocals stem is looked up
# by name, but every other stem is summed into a single "others" track.


def _select_device(rank=None):
    """Return the torch device string this map worker should run demucs on.

    Workers are spread round-robin over the visible GPUs by ``rank``. ``rank``
    may be None (e.g. when ``datasets`` runs the map in a single process), which
    maps to the first GPU. Without CUDA everything runs on the CPU.
    """
    if torch.cuda.is_available():
        return f"cuda:{(rank or 0) % torch.cuda.device_count()}"
    return "cpu"


def _to_channels_first(array, device):
    """Convert a decoded audio array to a float32 ``(channels, time)`` tensor on ``device``.

    Singleton dimensions are dropped and a mono 1-D signal gets an explicit
    channel axis, because demucs' ``convert_audio`` works on
    ``(..., channels, time)`` tensors.
    """
    wav = torch.tensor(array, dtype=torch.float32, device=device).squeeze()
    if wav.dim() <= 1:
        wav = wav.reshape(1, -1)
    return wav


def _split_vocals(stems, length):
    """Split one item's separated stems into mono ``(vocals, others)`` tracks.

    ``stems`` is indexed as ``(source, channel, time)`` with sources in the order
    of ``demucs.sources``. Only the first ``length`` samples (the item's
    un-padded length) are kept. "others" is the sum of every non-vocals stem.
    """
    vocals_idx = demucs.sources.index("vocals")
    others_idx = [i for i in range(len(demucs.sources)) if i != vocals_idx]
    trimmed = stems[..., :length]
    vocals = trimmed[vocals_idx].mean(0)
    others = trimmed[others_idx].sum(0).mean(0)
    return vocals, others


def filter_stems(batch, rank=None, audio_column_name="audio"):
    """Separate the vocals from everything else in ``batch[audio_column_name]``.

    Meant for ``Dataset.map(..., with_rank=True)``. It works both batched (the
    column holds a list of audio dicts) and unbatched (a single audio dict). Each
    audio dict needs ``"array"`` and ``"sampling_rate"`` keys.

    Args:
        batch: a ``datasets`` batch (or single example) dict; it is updated in place.
        rank: worker rank supplied by ``datasets``; selects the GPU. May be None.
        audio_column_name: name of the column holding the audio to separate.

    Returns:
        ``batch`` with two added mono audio-dict columns at the model's sample
        rate: ``"vocals"`` and ``"others"`` (the sum of all non-vocal stems).
    """
    device = _select_device(rank)
    # Cheap when the model already lives on this device.
    demucs.to(device)

    audios = batch[audio_column_name]
    if isinstance(audios, list):
        if not audios:
            # pad_sequence cannot handle an empty list.
            batch["vocals"] = []
            batch["others"] = []
            return batch
        # Resample/remix each clip to what the model expects, laid out as
        # (time, channels) so pad_sequence can right-pad along time.
        wavs = [
            convert_audio(
                _to_channels_first(audio["array"], device),
                audio["sampling_rate"],
                demucs.samplerate,
                demucs.audio_channels,
            ).T
            for audio in audios
        ]
        lengths = [wav.shape[0] for wav in wavs]
        # (batch, time, channels) -> (batch, channels, time)
        padded = torch.nn.utils.rnn.pad_sequence(
            wavs, batch_first=True, padding_value=0.0
        ).transpose(1, 2)
        stems = apply_model(demucs, padded)
        separated = [_split_vocals(s, n) for s, n in zip(stems, lengths)]
        batch["vocals"] = [wrap_audio(v, demucs.samplerate) for v, _ in separated]
        batch["others"] = [wrap_audio(o, demucs.samplerate) for _, o in separated]
    else:
        wav = convert_audio(
            _to_channels_first(audios["array"], device),
            audios["sampling_rate"],
            demucs.samplerate,
            demucs.audio_channels,
        )
        stems = apply_model(demucs, wav[None])
        vocals, others = _split_vocals(stems[0], stems.shape[-1])
        batch["vocals"] = wrap_audio(vocals, demucs.samplerate)
        batch["others"] = wrap_audio(others, demucs.samplerate)
    return batch


if __name__ == "__main__":
    # Spawn (rather than fork) workers so each one can initialise CUDA itself.
    set_start_method("spawn")

    parser = argparse.ArgumentParser()
    parser.add_argument("dataset_name", type=str, help="Path or name of the dataset. See: https://huggingface.co/docs/datasets/v2.17.0/en/package_reference/loading_methods#datasets.load_dataset.path")
    parser.add_argument("--configuration", default=None, type=str, help="Dataset configuration to use, if necessary.")
    parser.add_argument("--output_dir", default=None, type=str, help="If specified, save the dataset on disk with this path.")
    parser.add_argument("--repo_id", default=None, type=str, help="If specified, push the dataset to the hub.")
    parser.add_argument("--audio_column_name", default="audio", type=str, help="Column name of the audio column to be separated.")
    parser.add_argument("--batch_size", default=8, type=int, help="Batch size. Speeds up operations on GPU.")
    parser.add_argument("--num_workers_per_gpu", default=1, type=int, help="Number of workers per GPU for transformations that use GPUs if GPUs are available. Defaults to 1 if some are available. Useful if you want multiple processes per GPU to maximise GPU usage.")
    args = parser.parse_args()

    # A None configuration selects the dataset's default configuration.
    dataset = load_dataset(args.dataset_name, args.configuration)

    gpu_count = torch.cuda.device_count()
    num_proc = gpu_count * args.num_workers_per_gpu if gpu_count >= 1 else None

    updated_dataset = dataset.map(
        filter_stems,
        batched=True,
        batch_size=args.batch_size,
        with_rank=True,
        num_proc=num_proc,
        fn_kwargs={"audio_column_name": args.audio_column_name},
    )
    updated_dataset = updated_dataset.cast_column("vocals", Audio())
    updated_dataset = updated_dataset.cast_column("others", Audio())

    if args.output_dir:
        print("Saving to disk...")
        updated_dataset.save_to_disk(args.output_dir)

    if args.repo_id:
        print("Pushing to the hub...")
        # push_to_hub's config_name must be a string, so only pass it when set.
        if args.configuration:
            updated_dataset.push_to_hub(args.repo_id, args.configuration)
        else:
            updated_dataset.push_to_hub(args.repo_id)