Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Python 3.13 support #9

Merged
merged 4 commits into from
Feb 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/actions/setup-uv/action.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ runs:
- name: Install uv
uses: astral-sh/setup-uv@v4
with:
version: "0.5.10"
version: "0.5.14"
enable-cache: true
cache-dependency-glob: "**/pyproject.toml"
python-version: ${{ matrix.python-version }}
8 changes: 8 additions & 0 deletions .github/workflows/style_check.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,11 @@ jobs:
uses: ./.github/actions/setup-uv
- name: Lint check
run: make lint
mypy:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Setup uv
uses: ./.github/actions/setup-uv
- name: Type check
run: make mypy
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: ["3.10", "3.11", "3.12"]
python-version: ["3.10", "3.11", "3.12", "3.13"]
uv-resolution: ["lowest-direct", "highest"]
steps:
- uses: actions/checkout@v4
Expand Down
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ repos:
- id: end-of-file-fixer
- id: trailing-whitespace
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.8.3
rev: v0.8.6
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
Expand Down
5 changes: 4 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
.DEFAULT_GOAL := help
.PHONY: test dev-deps deps style lint install help
.PHONY: test dev-deps deps style lint mypy install help

help:
@grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}'
Expand All @@ -22,6 +22,9 @@ lint: ## run linter.
uv run --only-dev ruff check ${target_dirs}
uv run --only-dev ruff format --check ${target_dirs}

mypy: ## run type checker.
uv run --group mypy mypy trainer

install: ## install 🐸 Trainer for development.
uv sync --all-extras
uv run pre-commit install
4 changes: 2 additions & 2 deletions examples/train_simple_gan.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,10 +137,10 @@ def eval_step(self, batch, criterion):
valid = valid.type_as(imgs)

logits = self.discriminator(imgs_gen)
loss_gen = trainer.criterion(logits, valid)
loss_gen = criterion(logits, valid)
return {"model_outputs": logits}, {"loss_gen": loss_gen}

def get_optimizer(self):
def get_optimizer(self) -> list[torch.optim.Optimizer]:
discriminator_optimizer = torch.optim.Adam(self.discriminator.parameters(), lr=0.0001, betas=(0.5, 0.999))
generator_optimizer = torch.optim.Adam(self.generator.parameters(), lr=0.001, betas=(0.5, 0.999))
return [discriminator_optimizer, generator_optimizer]
Expand Down
28 changes: 20 additions & 8 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,10 @@ build-backend = "hatchling.build"

[project]
name = "coqui-tts-trainer"
version = "0.2.2"
version = "0.2.3"
description = "General purpose model trainer for PyTorch that is more flexible than it should be, by 🐸Coqui."
readme = "README.md"
requires-python = ">=3.10, <3.13"
requires-python = ">=3.10, <3.14"
license = {text = "Apache-2.0"}
authors = [
{name = "Eren Gölge", email = "egolge@coqui.ai"}
Expand All @@ -50,38 +50,41 @@ classifiers = [
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.13",
"Topic :: Software Development",
"Topic :: Software Development :: Libraries :: Python Modules",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
]
dependencies = [
"coqpit-config>=0.1.1",
"coqpit-config>=0.2.0,<0.3.0",
"fsspec>=2023.6.0",
"numpy>=1.25.2; python_version < '3.12'",
"numpy>=1.26.0; python_version >= '3.12'",
"numpy>=1.26.0; python_version == '3.12'",
"numpy>=2.1.0; python_version >= '3.13'",
"packaging>=21.0",
"psutil>=5",
"soundfile>=0.12.0",
"tensorboard>=2.17.0",
"torch>=2.1; python_version < '3.12'",
"torch>=2.3; python_version >= '3.12'",
"torch>=2.3; python_version == '3.12'",
"torch>=2.6; python_version >= '3.13'",
]

[dependency-groups]
dev = [
"coverage>=7",
"pre-commit>=3",
"pytest>=8",
"ruff==0.8.3",
"ruff==0.8.6",
]
test = [
"accelerate>=0.20.0",
"torchvision>=0.15.1",
"torchvision>=0.21.0",
]
mypy = [
"matplotlib>=3.9.2",
"mlflow>=2.18.0",
"mypy>=1.13.0",
"mypy>=1.14.1",
"types-psutil>=6.1.0.20241102",
"wandb>=0.18.7",
]
Expand Down Expand Up @@ -155,6 +158,15 @@ skip_empty = true
source = ["trainer", "tests"]
command_line = "-m pytest"

[tool.mypy]
warn_unused_configs = true
disallow_subclassing_any = true
disallow_untyped_decorators = true
warn_redundant_casts = true
warn_unused_ignores = true
strict_equality = true
extra_checks = true

[[tool.mypy.overrides]]
module = [
"accelerate",
Expand Down
8 changes: 4 additions & 4 deletions tests/test_train_gan.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ def optimize(self, batch, trainer):
return {"model_outputs": logits}, {"loss_gen": loss_gen, "loss_disc": loss_disc}

@torch.inference_mode()
def eval_step(self, batch, trainer):
def eval_step(self, batch, criterion):
imgs, _ = batch

# sample noise
Expand All @@ -313,7 +313,7 @@ def eval_step(self, batch, trainer):
valid = valid.type_as(imgs)

logits = self.discriminator(imgs_gen)
loss_gen = trainer.criterion(logits, valid)
loss_gen = criterion(logits, valid)
return {"model_outputs": logits}, {"loss_gen": loss_gen}

def get_optimizer(self):
Expand Down Expand Up @@ -426,7 +426,7 @@ def eval_step(self, batch, criterion):
valid = valid.type_as(imgs)

logits = self.discriminator(imgs_gen)
loss_gen = trainer.criterion(logits, valid)
loss_gen = criterion(logits, valid)
return {"model_outputs": logits}, {"loss_gen": loss_gen}

def get_optimizer(self):
Expand Down Expand Up @@ -541,7 +541,7 @@ def eval_step(self, batch, criterion):
valid = valid.type_as(imgs)

logits = self.discriminator(imgs_gen)
loss_gen = trainer.criterion(logits, valid)
loss_gen = criterion(logits, valid)
return {"model_outputs": logits}, {"loss_gen": loss_gen}

def get_optimizer(self):
Expand Down
84 changes: 20 additions & 64 deletions trainer/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,14 +43,10 @@ def parse_callbacks_dict(self, callbacks_dict: dict[str, Callback]) -> None:
raise ValueError(msg)

def on_init_start(self, trainer: "Trainer") -> None:
if hasattr(trainer.model, "module"):
if hasattr(trainer.model.module, "on_init_start"):
trainer.model.module.on_init_start(trainer)
elif hasattr(trainer.model, "on_init_start"):
trainer.model.on_init_start(trainer)
trainer._get_model().on_init_start(trainer)

if hasattr(trainer.criterion, "on_init_start"):
trainer.criterion.on_init_start(trainer)
trainer.criterion.on_init_start(trainer) # type: ignore[operator]

if hasattr(trainer.optimizer, "on_init_start"):
trainer.optimizer.on_init_start(trainer)
Expand All @@ -60,14 +56,10 @@ def on_init_start(self, trainer: "Trainer") -> None:
callback(trainer)

def on_init_end(self, trainer: "Trainer") -> None:
if hasattr(trainer.model, "module"):
if hasattr(trainer.model.module, "on_init_end"):
trainer.model.module.on_init_end(trainer)
elif hasattr(trainer.model, "on_init_end"):
trainer.model.on_init_end(trainer)
trainer._get_model().on_init_end(trainer)

if hasattr(trainer.criterion, "on_init_end"):
trainer.criterion.on_init_end(trainer)
trainer.criterion.on_init_end(trainer) # type: ignore[operator]

if hasattr(trainer.optimizer, "on_init_end"):
trainer.optimizer.on_init_end(trainer)
Expand All @@ -77,14 +69,10 @@ def on_init_end(self, trainer: "Trainer") -> None:
callback(trainer)

def on_epoch_start(self, trainer: "Trainer") -> None:
if hasattr(trainer.model, "module"):
if hasattr(trainer.model.module, "on_epoch_start"):
trainer.model.module.on_epoch_start(trainer)
elif hasattr(trainer.model, "on_epoch_start"):
trainer.model.on_epoch_start(trainer)
trainer._get_model().on_epoch_start(trainer)

if hasattr(trainer.criterion, "on_epoch_start"):
trainer.criterion.on_epoch_start(trainer)
trainer.criterion.on_epoch_start(trainer) # type: ignore[operator]

if hasattr(trainer.optimizer, "on_epoch_start"):
trainer.optimizer.on_epoch_start(trainer)
Expand All @@ -94,14 +82,10 @@ def on_epoch_start(self, trainer: "Trainer") -> None:
callback(trainer)

def on_epoch_end(self, trainer: "Trainer") -> None:
if hasattr(trainer.model, "module"):
if hasattr(trainer.model.module, "on_epoch_end"):
trainer.model.module.on_epoch_end(trainer)
elif hasattr(trainer.model, "on_epoch_end"):
trainer.model.on_epoch_end(trainer)
trainer._get_model().on_epoch_end(trainer)

if hasattr(trainer.criterion, "on_epoch_end"):
trainer.criterion.on_epoch_end(trainer)
trainer.criterion.on_epoch_end(trainer) # type: ignore[operator]

if hasattr(trainer.optimizer, "on_epoch_end"):
trainer.optimizer.on_epoch_end(trainer)
Expand All @@ -111,14 +95,10 @@ def on_epoch_end(self, trainer: "Trainer") -> None:
callback(trainer)

def on_train_epoch_start(self, trainer: "Trainer") -> None:
if hasattr(trainer.model, "module"):
if hasattr(trainer.model.module, "on_train_epoch_start"):
trainer.model.module.on_train_epoch_start(trainer)
elif hasattr(trainer.model, "on_train_epoch_start"):
trainer.model.on_train_epoch_start(trainer)
trainer._get_model().on_train_epoch_start(trainer)

if hasattr(trainer.criterion, "on_train_epoch_start"):
trainer.criterion.on_train_epoch_start(trainer)
trainer.criterion.on_train_epoch_start(trainer) # type: ignore[operator]

if hasattr(trainer.optimizer, "on_train_epoch_start"):
trainer.optimizer.on_train_epoch_start(trainer)
Expand All @@ -128,14 +108,10 @@ def on_train_epoch_start(self, trainer: "Trainer") -> None:
callback(trainer)

def on_train_epoch_end(self, trainer: "Trainer") -> None:
if hasattr(trainer.model, "module"):
if hasattr(trainer.model.module, "on_train_epoch_end"):
trainer.model.module.on_train_epoch_end(trainer)
elif hasattr(trainer.model, "on_train_epoch_end"):
trainer.model.on_train_epoch_end(trainer)
trainer._get_model().on_train_epoch_end(trainer)

if hasattr(trainer.criterion, "on_train_epoch_end"):
trainer.criterion.on_train_epoch_end(trainer)
trainer.criterion.on_train_epoch_end(trainer) # type: ignore[operator]

if hasattr(trainer.optimizer, "on_train_epoch_end"):
trainer.optimizer.on_train_epoch_end(trainer)
Expand All @@ -146,29 +122,17 @@ def on_train_epoch_end(self, trainer: "Trainer") -> None:

@staticmethod
def before_backward_pass(trainer: "Trainer", loss_dict: dict[str, Any]) -> None:
if hasattr(trainer.model, "module"):
if hasattr(trainer.model.module, "before_backward_pass"):
trainer.model.module.before_backward_pass(loss_dict, trainer.optimizer)
elif hasattr(trainer.model, "before_backward_pass"):
trainer.model.before_backward_pass(loss_dict, trainer.optimizer)
trainer._get_model().before_backward_pass(loss_dict, trainer.optimizer)

@staticmethod
def before_gradient_clipping(trainer: "Trainer") -> None:
if hasattr(trainer.model, "module"):
if hasattr(trainer.model.module, "before_gradient_clipping"):
trainer.model.module.before_gradient_clipping()
elif hasattr(trainer.model, "before_gradient_clipping"):
trainer.model.before_gradient_clipping()
trainer._get_model().before_gradient_clipping()

def on_train_step_start(self, trainer: "Trainer") -> None:
if hasattr(trainer.model, "module"):
if hasattr(trainer.model.module, "on_train_step_start"):
trainer.model.module.on_train_step_start(trainer)
elif hasattr(trainer.model, "on_train_step_start"):
trainer.model.on_train_step_start(trainer)
trainer._get_model().on_train_step_start(trainer)

if hasattr(trainer.criterion, "on_train_step_start"):
trainer.criterion.on_train_step_start(trainer)
trainer.criterion.on_train_step_start(trainer) # type: ignore[operator]

if hasattr(trainer.optimizer, "on_train_step_start"):
trainer.optimizer.on_train_step_start(trainer)
Expand All @@ -178,14 +142,10 @@ def on_train_step_start(self, trainer: "Trainer") -> None:
callback(trainer)

def on_train_step_end(self, trainer: "Trainer") -> None:
if hasattr(trainer.model, "module"):
if hasattr(trainer.model.module, "on_train_step_end"):
trainer.model.module.on_train_step_end(trainer)
elif hasattr(trainer.model, "on_train_step_end"):
trainer.model.on_train_step_end(trainer)
trainer._get_model().on_train_step_end(trainer)

if hasattr(trainer.criterion, "on_train_step_end"):
trainer.criterion.on_train_step_end(trainer)
trainer.criterion.on_train_step_end(trainer) # type: ignore[operator]

if hasattr(trainer.optimizer, "on_train_step_end"):
trainer.optimizer.on_train_step_end(trainer)
Expand All @@ -195,14 +155,10 @@ def on_train_step_end(self, trainer: "Trainer") -> None:
callback(trainer)

def on_keyboard_interrupt(self, trainer: "Trainer") -> None:
if hasattr(trainer.model, "module"):
if hasattr(trainer.model.module, "on_keyboard_interrupt"):
trainer.model.module.on_keyboard_interrupt(trainer)
elif hasattr(trainer.model, "on_keyboard_interrupt"):
trainer.model.on_keyboard_interrupt(trainer)
trainer._get_model().on_keyboard_interrupt(trainer)

if hasattr(trainer.criterion, "on_keyboard_interrupt"):
trainer.criterion.on_keyboard_interrupt(trainer)
trainer.criterion.on_keyboard_interrupt(trainer) # type: ignore[operator]

if hasattr(trainer.optimizer, "on_keyboard_interrupt"):
trainer.optimizer.on_keyboard_interrupt(trainer)
Expand Down
13 changes: 0 additions & 13 deletions trainer/generic_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,19 +23,6 @@ def is_pytorch_at_least_2_4() -> bool:
return Version(torch.__version__) >= Version("2.4")


def isimplemented(obj: Any, method_name: str) -> bool:
"""Check if a method is implemented in a class."""
if method_name in dir(obj) and callable(getattr(obj, method_name)):
try:
obj.__getattribute__(method_name)() # pylint: disable=bad-option-value, unnecessary-dunder-call
except NotImplementedError:
return False
except Exception:
return True
return True
return False


def to_cuda(x: torch.Tensor) -> torch.Tensor:
if x is None:
return None
Expand Down
Loading