-
Notifications
You must be signed in to change notification settings - Fork 11
/
utils.py
100 lines (84 loc) · 3.66 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Various utilities."""
from hashlib import sha256
from pathlib import Path
import typing as tp
import torch
import torchaudio
def _linear_overlap_add(frames: tp.List[torch.Tensor], stride: int):
# Generic overlap add, with linear fade-in/fade-out, supporting complex scenario
# e.g., more than 2 frames per position.
# The core idea is to use a weight function that is a triangle,
# with a maximum value at the middle of the segment.
# We use this weighting when summing the frames, and divide by the sum of weights
# for each positions at the end. Thus:
# - if a frame is the only one to cover a position, the weighting is a no-op.
# - if 2 frames cover a position:
# ... ...
# / \/ \
# / /\ \
# S T , i.e. S offset of second frame starts, T end of first frame.
# Then the weight function for each one is: (t - S), (T - t), with `t` a given offset.
# After the final normalization, the weight of the second frame at position `t` is
# (t - S) / (t - S + (T - t)) = (t - S) / (T - S), which is exactly what we want.
#
# - if more than 2 frames overlap at a given point, we hope that by induction
# something sensible happens.
assert len(frames)
device = frames[0].device
dtype = frames[0].dtype
shape = frames[0].shape[:-1]
total_size = stride * (len(frames) - 1) + frames[-1].shape[-1]
frame_length = frames[0].shape[-1]
t = torch.linspace(0, 1, frame_length + 2, device=device, dtype=dtype)[1: -1]
weight = 0.5 - (t - 0.5).abs()
sum_weight = torch.zeros(total_size, device=device, dtype=dtype)
out = torch.zeros(*shape, total_size, device=device, dtype=dtype)
offset: int = 0
for frame in frames:
frame_length = frame.shape[-1]
out[..., offset:offset + frame_length] += weight[:frame_length] * frame
sum_weight[offset:offset + frame_length] += weight[:frame_length]
offset += stride
assert sum_weight.min() > 0
return out / sum_weight
def _get_checkpoint_url(root_url: str, checkpoint: str):
if not root_url.endswith('/'):
root_url += '/'
return root_url + checkpoint
def _check_checksum(path: Path, checksum: str):
sha = sha256()
with open(path, 'rb') as file:
while True:
buf = file.read(2**20)
if not buf:
break
sha.update(buf)
actual_checksum = sha.hexdigest()[:len(checksum)]
if actual_checksum != checksum:
raise RuntimeError(f'Invalid checksum for file {path}, '
f'expected {checksum} but got {actual_checksum}')
def convert_audio(wav: torch.Tensor, sr: int, target_sr: int, target_channels: int):
    """Remix `wav` to `target_channels` channels and resample it from `sr` to `target_sr`.

    Args:
        wav: audio whose first dimension must have size 1 (mono) or 2
            (stereo). The time axis is the last dimension.
        sr: sample rate of `wav`.
        target_sr: desired output sample rate.
        target_channels: desired channel count. ``1`` averages the channels.
            ``2`` expands the channel axis (mono is duplicated, stereo is
            kept). Any other value only expands a mono input.

    Returns:
        The converted tensor, on the same device and (floating-point) dtype as
        `wav`. Channel duplication uses ``Tensor.expand``, so when no
        resampling happens the result may be a view that shares memory with
        `wav`. Clone it before modifying it in place.

    Raises:
        AssertionError: if ``wav.shape[0]`` is not 1 or 2 (only while
            assertions are enabled).
    """
    assert wav.shape[0] in [1, 2], "Audio must be mono or stereo."
    if target_channels == 1:
        wav = wav.mean(0, keepdim=True)
    elif target_channels == 2:
        *shape, _, length = wav.shape
        wav = wav.expand(*shape, target_channels, length)
    elif wav.shape[0] == 1:
        wav = wav.expand(target_channels, -1)
    # NOTE(review): a stereo input with target_channels not in {1, 2} falls
    # through every branch above and is returned with 2 channels. Confirm
    # whether callers rely on this or whether it should raise.
    if sr != target_sr:
        # The functional resampler builds its kernel on wav's device and dtype.
        # A freshly constructed `transforms.Resample` module keeps its kernel on
        # the CPU in float32 by default, which fails for CUDA or float64 input.
        # Rebuilding that module on every call also forfeits its kernel caching.
        wav = torchaudio.functional.resample(wav, sr, target_sr)
    return wav
def save_audio(wav: torch.Tensor, path: tp.Union[Path, str],
               sample_rate: int, rescale: bool = False):
    """Write `wav` to `path` as 16-bit signed PCM via ``torchaudio.save``.

    Args:
        wav: audio tensor to write.
        path: destination file.
        sample_rate: sample rate recorded in the file.
        rescale: if True, scale the whole signal down so its peak magnitude
            is at most 0.99 (it is never amplified, and an all-zero signal is
            left as is). If False, clip each sample to [-0.99, 0.99].
    """
    ceiling = 0.99
    # NOTE(review): the peak is only needed when rescale=True, but it is
    # computed unconditionally, so an empty `wav` raises here either way.
    peak = wav.abs().max()
    wav = wav * min(ceiling / peak, 1) if rescale else wav.clamp(-ceiling, ceiling)
    torchaudio.save(path, wav, sample_rate=sample_rate,
                    encoding='PCM_S', bits_per_sample=16)