Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BREAKING] Get rid of REQUIRED_RATE #375

Merged
merged 1 commit into from
Feb 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 19 additions & 33 deletions nam/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@

logger = logging.getLogger(__name__)

REQUIRED_RATE = 48_000 # FIXME not "required" anymore!
_DEFAULT_RATE = REQUIRED_RATE # There we go :)
_REQUIRED_CHANNELS = 1 # Mono


Expand Down Expand Up @@ -242,8 +240,7 @@ def __init__(
x_path: Optional[Union[str, Path]] = None,
y_path: Optional[Union[str, Path]] = None,
input_gain: float = 0.0,
sample_rate: Optional[int] = None,
rate: Optional[int] = None,
sample_rate: Optional[float] = None,
require_input_pre_silence: Optional[float] = _DEFAULT_REQUIRE_INPUT_PRE_SILENCE,
):
"""
Expand Down Expand Up @@ -283,16 +280,13 @@ def __init__(
completely dry signal (i.e. connecting the interface output directly back
into the input with which the guitar was originally recorded.)
:param sample_rate: Sample rate for the data
:param rate: Sample rate for the data (deprecated)
:param require_input_pre_silence: If provided, require that this much time (in
seconds) preceding the start of the data set (`start`) have a silent input.
If it's not, then raise an exception because the output due to it will leak
into the data set that we're trying to use. If `None`, don't assert.
"""
self._validate_x_y(x, y)
self._sample_rate = self._validate_sample_rate(
sample_rate, rate, default=_DEFAULT_RATE
)
self._sample_rate = sample_rate
start, stop = self._validate_start_stop(
x,
y,
Expand All @@ -302,15 +296,15 @@ def __init__(
stop_samples,
start_seconds,
stop_seconds,
self._sample_rate,
self.sample_rate,
)
if not isinstance(delay_interpolation_method, _DelayInterpolationMethod):
delay_interpolation_method = _DelayInterpolationMethod(
delay_interpolation_method
)
if require_input_pre_silence is not None:
self._validate_preceding_silence(
x, start, int(require_input_pre_silence * self._sample_rate)
x, start, require_input_pre_silence, self.sample_rate
)
x, y = [z[start:stop] for z in (x, y)]
if delay is not None and delay != 0:
Expand Down Expand Up @@ -377,9 +371,12 @@ def y_offset(self) -> int:
@classmethod
def parse_config(cls, config):
config = deepcopy(config)
sample_rate = cls._validate_sample_rate(
config.pop("sample_rate", None), config.pop("rate", None)
)
if "rate" in config:
raise ValueError(
"use of `rate` was deprecated in version 0.8. Use `sample_rate` "
"instead."
)
sample_rate = config.pop("sample_rate", None)
x, x_wavinfo = wav_to_tensor(config.pop("x_path"), info=True, rate=sample_rate)
sample_rate = x_wavinfo.rate
try:
Expand Down Expand Up @@ -469,25 +466,6 @@ def _apply_delay_float(
y = _interpolate_delay(y, delay, method)
return x, y

@classmethod
def _validate_sample_rate(
cls, sample_rate: Optional[float], rate: Optional[int], default=None
) -> float:
if sample_rate is None and rate is None: # Default value
return default
if rate is not None:
if sample_rate is not None:
raise ValueError(
"Provided both sample_rate and rate. Provide only sample_rate!"
)
else:
logger.warning(
"Use of 'rate' is deprecated and will be removed. Use sample_rate instead"
)
return float(rate)
else:
return sample_rate

@classmethod
def _validate_start_stop(
cls,
Expand Down Expand Up @@ -632,19 +610,27 @@ def _validate_inputs_after_processing(self, x, y, nx, ny):

@classmethod
def _validate_preceding_silence(
cls, x: torch.Tensor, start: Optional[int], silent_samples: int
cls, x: torch.Tensor, start: Optional[int], silent_seconds: float, sample_rate: Optional[float]
):
"""
Make sure that the input is silent before the starting index.
If it's not, then the output caused by that non-silent input will leak into the
data set, where it can't be predicted!

This assumes that silence is indeed required. If it's not, then don't call this!

See: Issue #252

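:param x: Input
:param start: Where the data starts
:param silent_seconds: How many seconds before `start` are expected to be silent
:param sample_rate: Sample rate of `x`, used to convert `silent_seconds` into a
number of samples. Must not be `None`.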
"""
if sample_rate is None:
raise ValueError(
f"Pre-silence was required for {silent_seconds} seconds, but no sample "
"rate was provided!"
)
silent_samples = int(silent_seconds * sample_rate)
if start is None:
return
raw_check_start = start - silent_samples
Expand Down
8 changes: 6 additions & 2 deletions nam/models/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import torch.nn as nn

from .._core import InitializableFromConfig
from ..data import REQUIRED_RATE, wav_to_tensor
from ..data import wav_to_tensor
from ._exportable import Exportable


Expand Down Expand Up @@ -133,7 +133,11 @@ def _export_input_output_args(self) -> Tuple[Any]:

def _export_input_output(self) -> Tuple[np.ndarray, np.ndarray]:
args = self._export_input_output_args()
rate = REQUIRED_RATE
rate = self.sample_rate
if rate is None:
raise RuntimeError(
"Cannot export model's input and output without a sample rate."
)
x = torch.cat(
[
torch.zeros((rate,)),
Expand Down
8 changes: 6 additions & 2 deletions nam/models/conv_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@


from .. import __version__
from ..data import REQUIRED_RATE, wav_to_tensor
from ..data import wav_to_tensor
from ._activations import get_activation
from ._base import BaseNet
from ._names import ACTIVATION_NAME, BATCHNORM_NAME, CONV_NAME
Expand Down Expand Up @@ -217,7 +217,11 @@ def _export_input_signal(self):
"""
:return: (L,)
"""
rate = REQUIRED_RATE
rate = self.sample_rate
if rate is None:
raise RuntimeError(
"Cannot export model's input and output without a sample rate."
)
return torch.cat(
[
torch.zeros((rate,)),
Expand Down
13 changes: 8 additions & 5 deletions nam/train/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,17 @@
from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.utils.data import DataLoader

from ..data import REQUIRED_RATE, Split, init_dataset, wav_to_np, wav_to_tensor
from ..data import Split, init_dataset, wav_to_np, wav_to_tensor
from ..models import Model
from ..models.losses import esr
from ..util import filter_warnings
from ._version import Version

__all__ = ["train"]

# Training using the simplified trainers in NAM is done at 48k.
STANDARD_SAMPLE_RATE = 48_000.0


class Architecture(Enum):
STANDARD = "standard"
Expand Down Expand Up @@ -222,7 +225,7 @@ class _DataInfo(BaseModel):
"""

major_version: int
rate: Optional[int]
rate: Optional[float]
t_blips: int
first_blips_start: int
t_validate: int
Expand All @@ -234,7 +237,7 @@ class _DataInfo(BaseModel):

_V1_DATA_INFO = _DataInfo(
major_version=1,
rate=REQUIRED_RATE,
rate=STANDARD_SAMPLE_RATE,
t_blips=48_000,
first_blips_start=0,
t_validate=432_000,
Expand All @@ -254,7 +257,7 @@ class _DataInfo(BaseModel):
# (3:09-3:11) Blips at 3:09.5 and 3:10.5
_V2_DATA_INFO = _DataInfo(
major_version=2,
rate=REQUIRED_RATE,
rate=STANDARD_SAMPLE_RATE,
t_blips=96_000,
first_blips_start=0,
t_validate=432_000,
Expand All @@ -274,7 +277,7 @@ class _DataInfo(BaseModel):
# (3:01-3:10) Validation 2
_V3_DATA_INFO = _DataInfo(
major_version=3,
rate=REQUIRED_RATE,
rate=STANDARD_SAMPLE_RATE,
t_blips=96_000,
first_blips_start=480_000,
t_validate=432_000,
Expand Down
2 changes: 1 addition & 1 deletion tests/test_bin/test_train/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import pytest
import torch

from nam.data import REQUIRED_RATE, np_to_wav
from nam.data import np_to_wav

_BIN_TRAIN_MAIN_PY_PATH = Path(__file__).absolute().parent.parent.parent.parent / Path(
"bin", "train", "main.py"
Expand Down
25 changes: 14 additions & 11 deletions tests/test_nam/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@

from nam import data

_sample_rates = (44_100, 48_000, 88_200, 96_000)
_SAMPLE_RATES = (44_100.0, 48_000.0, 88_200.0, 96_000.0)
_DEFAULT_SAMPLE_RATE = 48_000.0


class _XYMethod(Enum):
Expand Down Expand Up @@ -85,11 +86,11 @@ def test_apply_delay_int_positive(self):

def test_init(self):
x, y = self._create_xy()
data.Dataset(x, y, 3, None)
data.Dataset(x, y, 3, None, sample_rate=_DEFAULT_SAMPLE_RATE)

def test_init_sample_rate(self):
x, y = self._create_xy()
sample_rate = 48_000.0
sample_rate = _DEFAULT_SAMPLE_RATE
d = data.Dataset(x, y, 3, None, sample_rate=sample_rate)
assert hasattr(d, "sample_rate")
assert isinstance(d.sample_rate, float)
Expand All @@ -100,7 +101,7 @@ def test_init_zero_delay(self):
Assert https://github.com/sdatkinson/neural-amp-modeler/issues/15 fixed
"""
x, y = self._create_xy()
data.Dataset(x, y, 3, None, delay=0)
data.Dataset(x, y, 3, None, delay=0, sample_rate=_DEFAULT_SAMPLE_RATE)

def test_input_gain(self):
"""
Expand All @@ -112,14 +113,16 @@ def test_input_gain(self):
nx = 3
ny = None
args = (x, y, nx, ny)
d1 = data.Dataset(*args)
d2 = data.Dataset(*args, input_gain=input_gain)
d1 = data.Dataset(*args, sample_rate=_DEFAULT_SAMPLE_RATE)
d2 = data.Dataset(
*args, sample_rate=_DEFAULT_SAMPLE_RATE, input_gain=input_gain
)

sample_x1 = d1[0][0]
sample_x2 = d2[0][0]
assert torch.allclose(sample_x1 * x_scale, sample_x2)

@pytest.mark.parametrize("sample_rate", _sample_rates)
@pytest.mark.parametrize("sample_rate", _SAMPLE_RATES)
def test_sample_rates(self, sample_rate: int):
"""
Test that datasets with various sample rates can be made
Expand Down Expand Up @@ -155,7 +158,7 @@ def test_validate_start(self, n: int, start: int, valid: bool):
"""

def init():
data.Dataset(x, y, nx, ny, start=start)
data.Dataset(x, y, nx, ny, start=start, sample_rate=_DEFAULT_SAMPLE_RATE)

nx = 1
ny = None
Expand Down Expand Up @@ -239,7 +242,7 @@ def f():
)
def test_validate_stop(self, n: int, stop: int, valid: bool):
def init():
data.Dataset(x, y, nx, ny, stop=stop)
data.Dataset(x, y, nx, ny, stop=stop, sample_rate=_DEFAULT_SAMPLE_RATE)

nx = 1
ny = None
Expand All @@ -257,7 +260,7 @@ def init():
)
def test_validate_x_y(self, lenx: int, leny: int, valid: bool):
def init():
data.Dataset(x, y, nx, ny)
data.Dataset(x, y, nx, ny, sample_rate=_DEFAULT_SAMPLE_RATE)

x, y = self._create_xy()
assert len(x) >= lenx, "Invalid test!"
Expand Down Expand Up @@ -345,7 +348,7 @@ def test_np_to_wav_to_np(self, tmpdir):
# Check if the two arrays are equal
assert y == pytest.approx(x, abs=self.tolerance)

@pytest.mark.parametrize("sample_rate", _sample_rates)
@pytest.mark.parametrize("sample_rate", _SAMPLE_RATES)
def test_np_to_wav_to_np_sample_rates(self, sample_rate: int):
with TemporaryDirectory() as tmpdir:
# Create random numpy array
Expand Down
3 changes: 2 additions & 1 deletion tests/test_nam/test_train/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,8 @@ def test_validation_preceded_by_silence(self):
Dataset._validate_preceding_silence(
x,
data_info.validation_start,
int(_DEFAULT_REQUIRE_INPUT_PRE_SILENCE * data_info.rate),
_DEFAULT_REQUIRE_INPUT_PRE_SILENCE,
data_info.rate,
)

return C
Expand Down
Loading