Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support Proteus in Colab #378

Merged
merged 2 commits into from
Feb 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion nam/_version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.7.4"
__version__ = "0.8.0"
3 changes: 2 additions & 1 deletion nam/train/_names.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from typing import NamedTuple

from ._version import Version
from ._version import PROTEUS_VERSION, Version

__all__ = ["INPUT_BASENAMES", "LATEST_VERSION", "VersionAndName"]

Expand All @@ -20,6 +20,7 @@ class VersionAndName(NamedTuple):
VersionAndName(Version(2, 0, 0), "v2_0_0.wav"),
VersionAndName(Version(1, 1, 1), "v1_1_1.wav"),
VersionAndName(Version(1, 0, 0), "v1.wav"),
VersionAndName(PROTEUS_VERSION, "Proteus_Capture.wav"),
)

LATEST_VERSION = INPUT_BASENAMES[0]
5 changes: 5 additions & 0 deletions nam/train/_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
Version utility
"""

__all__ = ["PROTEUS_VERSION", "Version"]


class Version:
def __init__(self, major: int, minor: int, patch: int):
Expand All @@ -30,3 +32,6 @@ def __lt__(self, other) -> bool:

def __str__(self) -> str:
return f"{self.major}.{self.minor}.{self.patch}"


PROTEUS_VERSION = Version(4, 0, 0)
11 changes: 8 additions & 3 deletions nam/train/colab.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from ..models.metadata import UserMetadata
from ._names import INPUT_BASENAMES, LATEST_VERSION, Version
from ._version import Version
from ._version import PROTEUS_VERSION, Version
from .core import train


Expand All @@ -34,7 +34,9 @@ def _check_for_files() -> Tuple[Version, str]:
)
for input_version, input_basename in INPUT_BASENAMES:
if Path(input_basename).exists():
if input_version != LATEST_VERSION.version:
if input_version == PROTEUS_VERSION:
print(f"Using Proteus input file...")
elif input_version != LATEST_VERSION.version:
print(
f"WARNING: Using out-of-date input file {input_basename}. "
"Recommend downloading and using the latest version, "
Expand All @@ -49,7 +51,10 @@ def _check_for_files() -> Tuple[Version, str]:
raise FileNotFoundError(
f"Didn't find your reamped output audio file. Please upload {_OUTPUT_BASENAME}."
)
print(f"Found {input_basename}, version {input_version}")
if input_version != PROTEUS_VERSION:
print(f"Found {input_basename}, version {input_version}")
else:
print(f"Found Proteus input {input_basename}.")
return input_version, input_basename


Expand Down
6 changes: 3 additions & 3 deletions nam/train/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from ..models import Model
from ..models.losses import esr
from ..util import filter_warnings
from ._version import Version
from ._version import PROTEUS_VERSION, Version

__all__ = ["train"]

Expand Down Expand Up @@ -70,7 +70,7 @@ def assign_hash(path):
"7c3b6119c74465f79d96c761a0e27370": Version(1, 1, 1),
"ede3b9d82135ce10c7ace3bb27469422": Version(2, 0, 0),
"36cd1af62985c2fac3e654333e36431e": Version(3, 0, 0),
"80e224bd5622fd6153ff1fd9f34cb3bd": Version(4, 0, 0),
"80e224bd5622fd6153ff1fd9f34cb3bd": PROTEUS_VERSION,
}.get(file_hash)
if version is None:
print(
Expand Down Expand Up @@ -211,7 +211,7 @@ def assign_hash_v4(path) -> Hash:
}.get((start_hash_v1, end_hash_v1))
if version is not None:
return version
version = {"46151c8030798081acc00a725325a07d": Version(4, 0, 0)}.get(hash_v4)
version = {"46151c8030798081acc00a725325a07d": PROTEUS_VERSION}.get(hash_v4)
return version

version = detect_strong(input_path)
Expand Down
Loading