time_inversion.py
import torch
from torch import Tensor
from typing import Optional

from ..core.transforms_interface import BaseWaveformTransform
from ..utils.object_dict import ObjectDict

class TimeInversion(BaseWaveformTransform):
    """
    Reverse (invert) the audio along the time axis, similar to a random horizontal
    flip of an image in the visual domain. This can be relevant in the context of
    audio classification. It was successfully applied in the paper
    "AudioCLIP: Extending CLIP to Image, Text and Audio"
    https://arxiv.org/pdf/2106.13043.pdf
    """

    supported_modes = {"per_batch", "per_example", "per_channel"}
    supports_multichannel = True
    requires_sample_rate = False
    supports_target = True
    requires_target = False
    def __init__(
        self,
        mode: str = "per_example",
        p: float = 0.5,
        p_mode: Optional[str] = None,
        sample_rate: Optional[int] = None,
        target_rate: Optional[int] = None,
        output_type: Optional[str] = None,
    ):
        """
        :param mode: "per_batch", "per_example" or "per_channel" - the granularity
            at which the transform is applied
        :param p: The probability of applying the transform
        :param p_mode: The granularity at which p is sampled (defaults to mode)
        :param sample_rate: The sample rate of the input audio, in Hz
        :param target_rate: The rate of the targets, in Hz
        :param output_type: "dict" or "tensor"
        """
        super().__init__(
            mode=mode,
            p=p,
            p_mode=p_mode,
            sample_rate=sample_rate,
            target_rate=target_rate,
            output_type=output_type,
        )
    def apply_transform(
        self,
        samples: Optional[Tensor] = None,
        sample_rate: Optional[int] = None,
        targets: Optional[Tensor] = None,
        target_rate: Optional[int] = None,
    ) -> ObjectDict:
        # torch.flip() is reported to be slower than np.flip(), since it copies the
        # data rather than returning a view.
        # An alternative is to use advanced indexing: https://github.com/pytorch/pytorch/issues/16424
        # reverse_index = torch.arange(selected_samples.size(-1) - 1, -1, -1).to(selected_samples.device)
        # transformed_samples = selected_samples[..., reverse_index]
        # return transformed_samples
        flipped_samples = torch.flip(samples, dims=(-1,))
        if targets is None:
            flipped_targets = targets
        else:
            # Targets keep time as the second-to-last dimension
            # (e.g. batch, channels, num_frames, num_classes), so flip dim -2
            flipped_targets = torch.flip(targets, dims=(-2,))
        return ObjectDict(
            samples=flipped_samples,
            sample_rate=sample_rate,
            targets=flipped_targets,
            target_rate=target_rate,
        )
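
A minimal usage sketch (an illustration, not part of the file above): it assumes the transform is importable from an installed torch_audiomentations package and that waveforms follow the library's (batch_size, num_channels, num_samples) shape convention.

import torch
from torch_audiomentations import TimeInversion

# Two mono examples, one second each at 16 kHz: (batch_size, num_channels, num_samples)
audio = torch.randn(2, 1, 16000)

# p=1.0 makes the flip deterministic, which makes the effect easy to verify
transform = TimeInversion(p=1.0, output_type="dict")
output = transform(audio, sample_rate=16000)

# Each waveform should come back reversed along the time axis
assert torch.allclose(output.samples, torch.flip(audio, dims=(-1,)))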