From 75e37e758d0b813f8b9dd0d044c97740be94d738 Mon Sep 17 00:00:00 2001 From: Steven Atkinson Date: Tue, 6 Feb 2024 22:36:25 -0800 Subject: [PATCH] Get rid of --- nam/data.py | 52 ++++++++++---------------- nam/models/_base.py | 8 +++- nam/models/conv_net.py | 8 +++- nam/train/core.py | 13 ++++--- tests/test_bin/test_train/test_main.py | 2 +- tests/test_nam/test_data.py | 25 +++++++------ tests/test_nam/test_train/test_core.py | 3 +- 7 files changed, 56 insertions(+), 55 deletions(-) diff --git a/nam/data.py b/nam/data.py index a56ab4da..3f93099a 100644 --- a/nam/data.py +++ b/nam/data.py @@ -22,8 +22,6 @@ logger = logging.getLogger(__name__) -REQUIRED_RATE = 48_000 # FIXME not "required" anymore! -_DEFAULT_RATE = REQUIRED_RATE # There we go :) _REQUIRED_CHANNELS = 1 # Mono @@ -242,8 +240,7 @@ def __init__( x_path: Optional[Union[str, Path]] = None, y_path: Optional[Union[str, Path]] = None, input_gain: float = 0.0, - sample_rate: Optional[int] = None, - rate: Optional[int] = None, + sample_rate: Optional[float] = None, require_input_pre_silence: Optional[float] = _DEFAULT_REQUIRE_INPUT_PRE_SILENCE, ): """ @@ -283,16 +280,13 @@ def __init__( completely dry signal (i.e. connecting the interface output directly back into the input with which the guitar was originally recorded.) :param sample_rate: Sample rate for the data - :param rate: Sample rate for the data (deprecated) :param require_input_pre_silence: If provided, require that this much time (in seconds) preceding the start of the data set (`start`) have a silent input. If it's not, then raise an exception because the output due to it will leak into the data set that we're trying to use. If `None`, don't assert. """ self._validate_x_y(x, y) - self._sample_rate = self._validate_sample_rate( - sample_rate, rate, default=_DEFAULT_RATE - ) + self._sample_rate = sample_rate start, stop = self._validate_start_stop( x, y, @@ -302,7 +296,7 @@ def __init__( stop_samples, start_seconds, stop_seconds, - self._sample_rate, + self.sample_rate, ) if not isinstance(delay_interpolation_method, _DelayInterpolationMethod): delay_interpolation_method = _DelayInterpolationMethod( @@ -310,7 +304,7 @@ def __init__( ) if require_input_pre_silence is not None: self._validate_preceding_silence( - x, start, int(require_input_pre_silence * self._sample_rate) + x, start, require_input_pre_silence, self.sample_rate ) x, y = [z[start:stop] for z in (x, y)] if delay is not None and delay != 0: @@ -377,9 +371,12 @@ def y_offset(self) -> int: @classmethod def parse_config(cls, config): config = deepcopy(config) - sample_rate = cls._validate_sample_rate( - config.pop("sample_rate", None), config.pop("rate", None) - ) + if "rate" in config: + raise ValueError( + "use of `rate` was deprecated in version 0.8. Use `sample_rate` " + "instead." + ) + sample_rate = config.pop("sample_rate", None) x, x_wavinfo = wav_to_tensor(config.pop("x_path"), info=True, rate=sample_rate) sample_rate = x_wavinfo.rate try: @@ -469,25 +466,6 @@ def _apply_delay_float( y = _interpolate_delay(y, delay, method) return x, y - @classmethod - def _validate_sample_rate( - cls, sample_rate: Optional[float], rate: Optional[int], default=None - ) -> float: - if sample_rate is None and rate is None: # Default value - return default - if rate is not None: - if sample_rate is not None: - raise ValueError( - "Provided both sample_rate and rate. Provide only sample_rate!" - ) - else: - logger.warning( - "Use of 'rate' is deprecated and will be removed. Use sample_rate instead" - ) - return float(rate) - else: - return sample_rate - @classmethod def _validate_start_stop( cls, @@ -632,19 +610,27 @@ def _validate_inputs_after_processing(self, x, y, nx, ny): @classmethod def _validate_preceding_silence( - cls, x: torch.Tensor, start: Optional[int], silent_samples: int + cls, x: torch.Tensor, start: Optional[int], silent_seconds: float, sample_rate: Optional[float] ): """ Make sure that the input is silent before the starting index. If it's not, then the output from that non-silent input will leak into the data set and couldn't be predicted! + This assumes that silence is indeed required. If it's not, then don't call this! + See: Issue #252 :param x: Input :param start: Where the data starts :param silent_samples: How many are expected to be silent """ + if sample_rate is None: + raise ValueError( + f"Pre-silence was required for {silent_seconds} seconds, but no sample " + "rate was provided!" + ) + silent_samples = int(silent_seconds * sample_rate) if start is None: return raw_check_start = start - silent_samples diff --git a/nam/models/_base.py b/nam/models/_base.py index 6b30fa77..5eb6e174 100644 --- a/nam/models/_base.py +++ b/nam/models/_base.py @@ -17,7 +17,7 @@ import torch.nn as nn from .._core import InitializableFromConfig -from ..data import REQUIRED_RATE, wav_to_tensor +from ..data import wav_to_tensor from ._exportable import Exportable @@ -133,7 +133,11 @@ def _export_input_output_args(self) -> Tuple[Any]: def _export_input_output(self) -> Tuple[np.ndarray, np.ndarray]: args = self._export_input_output_args() - rate = REQUIRED_RATE + rate = self.sample_rate + if rate is None: + raise RuntimeError( + "Cannot export model's input and output without a sample rate." + ) x = torch.cat( [ torch.zeros((rate,)), diff --git a/nam/models/conv_net.py b/nam/models/conv_net.py index 746c016e..aef18e4d 100644 --- a/nam/models/conv_net.py +++ b/nam/models/conv_net.py @@ -17,7 +17,7 @@ from .. import __version__ -from ..data import REQUIRED_RATE, wav_to_tensor +from ..data import wav_to_tensor from ._activations import get_activation from ._base import BaseNet from ._names import ACTIVATION_NAME, BATCHNORM_NAME, CONV_NAME @@ -217,7 +217,11 @@ def _export_input_signal(self): """ :return: (L,) """ - rate = REQUIRED_RATE + rate = self.sample_rate + if rate is None: + raise RuntimeError( + "Cannot export model's input and output without a sample rate." + ) return torch.cat( [ torch.zeros((rate,)), diff --git a/nam/train/core.py b/nam/train/core.py index b4118233..41f6692e 100644 --- a/nam/train/core.py +++ b/nam/train/core.py @@ -22,7 +22,7 @@ from pytorch_lightning.utilities.warnings import PossibleUserWarning from torch.utils.data import DataLoader -from ..data import REQUIRED_RATE, Split, init_dataset, wav_to_np, wav_to_tensor +from ..data import Split, init_dataset, wav_to_np, wav_to_tensor from ..models import Model from ..models.losses import esr from ..util import filter_warnings @@ -30,6 +30,9 @@ __all__ = ["train"] +# Training using the simplified trainers in NAM is done at 48k. +STANDARD_SAMPLE_RATE = 48_000.0 + class Architecture(Enum): STANDARD = "standard" @@ -222,7 +225,7 @@ class _DataInfo(BaseModel): """ major_version: int - rate: Optional[int] + rate: Optional[float] t_blips: int first_blips_start: int t_validate: int @@ -234,7 +237,7 @@ class _DataInfo(BaseModel): _V1_DATA_INFO = _DataInfo( major_version=1, - rate=REQUIRED_RATE, + rate=STANDARD_SAMPLE_RATE, t_blips=48_000, first_blips_start=0, t_validate=432_000, @@ -254,7 +257,7 @@ class _DataInfo(BaseModel): # (3:09-3:11) Blips at 3:09.5 and 3:10.5 _V2_DATA_INFO = _DataInfo( major_version=2, - rate=REQUIRED_RATE, + rate=STANDARD_SAMPLE_RATE, t_blips=96_000, first_blips_start=0, t_validate=432_000, @@ -274,7 +277,7 @@ class _DataInfo(BaseModel): # (3:01-3:10) Validation 2 _V3_DATA_INFO = _DataInfo( major_version=3, - rate=REQUIRED_RATE, + rate=STANDARD_SAMPLE_RATE, t_blips=96_000, first_blips_start=480_000, t_validate=432_000, diff --git a/tests/test_bin/test_train/test_main.py b/tests/test_bin/test_train/test_main.py index 20b60883..e1f7613f 100644 --- a/tests/test_bin/test_train/test_main.py +++ b/tests/test_bin/test_train/test_main.py @@ -13,7 +13,7 @@ import pytest import torch -from nam.data import REQUIRED_RATE, np_to_wav +from nam.data import np_to_wav _BIN_TRAIN_MAIN_PY_PATH = Path(__file__).absolute().parent.parent.parent.parent / Path( "bin", "train", "main.py" diff --git a/tests/test_nam/test_data.py b/tests/test_nam/test_data.py index 2a91b8e9..d91e9a2c 100644 --- a/tests/test_nam/test_data.py +++ b/tests/test_nam/test_data.py @@ -15,7 +15,8 @@ from nam import data -_sample_rates = (44_100, 48_000, 88_200, 96_000) +_SAMPLE_RATES = (44_100.0, 48_000.0, 88_200.0, 96_000.0) +_DEFAULT_SAMPLE_RATE = 48_000.0 class _XYMethod(Enum): @@ -85,11 +86,11 @@ def test_apply_delay_int_positive(self): def test_init(self): x, y = self._create_xy() - data.Dataset(x, y, 3, None) + data.Dataset(x, y, 3, None, sample_rate=_DEFAULT_SAMPLE_RATE) def test_init_sample_rate(self): x, y = self._create_xy() - sample_rate = 48_000.0 + sample_rate = _DEFAULT_SAMPLE_RATE d = data.Dataset(x, y, 3, None, sample_rate=sample_rate) assert hasattr(d, "sample_rate") assert isinstance(d.sample_rate, float) @@ -100,7 +101,7 @@ def test_init_zero_delay(self): Assert https://github.com/sdatkinson/neural-amp-modeler/issues/15 fixed """ x, y = self._create_xy() - data.Dataset(x, y, 3, None, delay=0) + data.Dataset(x, y, 3, None, delay=0, sample_rate=_DEFAULT_SAMPLE_RATE) def test_input_gain(self): """ @@ -112,14 +113,16 @@ def test_input_gain(self): nx = 3 ny = None args = (x, y, nx, ny) - d1 = data.Dataset(*args) - d2 = data.Dataset(*args, input_gain=input_gain) + d1 = data.Dataset(*args, sample_rate=_DEFAULT_SAMPLE_RATE) + d2 = data.Dataset( + *args, sample_rate=_DEFAULT_SAMPLE_RATE, input_gain=input_gain + ) sample_x1 = d1[0][0] sample_x2 = d2[0][0] assert torch.allclose(sample_x1 * x_scale, sample_x2) - @pytest.mark.parametrize("sample_rate", _sample_rates) + @pytest.mark.parametrize("sample_rate", _SAMPLE_RATES) def test_sample_rates(self, sample_rate: int): """ Test that datasets with various sample rates can be made @@ -155,7 +158,7 @@ def test_validate_start(self, n: int, start: int, valid: bool): """ def init(): - data.Dataset(x, y, nx, ny, start=start) + data.Dataset(x, y, nx, ny, start=start, sample_rate=_DEFAULT_SAMPLE_RATE) nx = 1 ny = None @@ -239,7 +242,7 @@ def f(): ) def test_validate_stop(self, n: int, stop: int, valid: bool): def init(): - data.Dataset(x, y, nx, ny, stop=stop) + data.Dataset(x, y, nx, ny, stop=stop, sample_rate=_DEFAULT_SAMPLE_RATE) nx = 1 ny = None @@ -257,7 +260,7 @@ def init(): ) def test_validate_x_y(self, lenx: int, leny: int, valid: bool): def init(): - data.Dataset(x, y, nx, ny) + data.Dataset(x, y, nx, ny, sample_rate=_DEFAULT_SAMPLE_RATE) x, y = self._create_xy() assert len(x) >= lenx, "Invalid test!" @@ -345,7 +348,7 @@ def test_np_to_wav_to_np(self, tmpdir): # Check if the two arrays are equal assert y == pytest.approx(x, abs=self.tolerance) - @pytest.mark.parametrize("sample_rate", _sample_rates) + @pytest.mark.parametrize("sample_rate", _SAMPLE_RATES) def test_np_to_wav_to_np_sample_rates(self, sample_rate: int): with TemporaryDirectory() as tmpdir: # Create random numpy array diff --git a/tests/test_nam/test_train/test_core.py b/tests/test_nam/test_train/test_core.py index 07b1bae0..bbcca592 100644 --- a/tests/test_nam/test_train/test_core.py +++ b/tests/test_nam/test_train/test_core.py @@ -180,7 +180,8 @@ def test_validation_preceded_by_silence(self): Dataset._validate_preceding_silence( x, data_info.validation_start, - int(_DEFAULT_REQUIRE_INPUT_PRE_SILENCE * data_info.rate), + _DEFAULT_REQUIRE_INPUT_PRE_SILENCE, + data_info.rate, ) return C