From 8a1710bdabee193eb86ee913bfac9bc81727796c Mon Sep 17 00:00:00 2001 From: Enno Hermann Date: Fri, 18 Oct 2024 14:14:40 +0200 Subject: [PATCH 1/4] build: enable python 3.13 support --- .github/workflows/tests.yml | 2 +- pyproject.toml | 13 ++++++++----- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 86750d1..72f763c 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -12,7 +12,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: ["3.10", "3.11", "3.12"] + python-version: ["3.10", "3.11", "3.12", "3.13"] uv-resolution: ["lowest-direct", "highest"] steps: - uses: actions/checkout@v4 diff --git a/pyproject.toml b/pyproject.toml index 9b88fe0..d19a481 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,7 +29,7 @@ name = "coqui-tts-trainer" version = "0.2.2" description = "General purpose model trainer for PyTorch that is more flexible than it should be, by 🐸Coqui." readme = "README.md" -requires-python = ">=3.10, <3.13" +requires-python = ">=3.10, <3.14" license = {text = "Apache-2.0"} authors = [ {name = "Eren Gölge", email = "egolge@coqui.ai"} @@ -50,21 +50,24 @@ classifiers = [ "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", "Topic :: Software Development", "Topic :: Software Development :: Libraries :: Python Modules", "Topic :: Scientific/Engineering :: Artificial Intelligence", ] dependencies = [ - "coqpit-config>=0.1.1", + "coqpit-config>=0.2.0,<0.3.0", "fsspec>=2023.6.0", "numpy>=1.25.2; python_version < '3.12'", - "numpy>=1.26.0; python_version >= '3.12'", + "numpy>=1.26.0; python_version == '3.12'", + "numpy>=2.1.0; python_version >= '3.13'", "packaging>=21.0", "psutil>=5", "soundfile>=0.12.0", "tensorboard>=2.17.0", "torch>=2.1; python_version < '3.12'", - "torch>=2.3; python_version >= '3.12'", + "torch>=2.3; python_version == '3.12'", + "torch>=2.6; python_version >= '3.13'", ] [dependency-groups] @@ -76,7 +79,7 @@ dev = [ ] test = [ "accelerate>=0.20.0", - "torchvision>=0.15.1", + "torchvision>=0.21.0", ] mypy = [ "matplotlib>=3.9.2", From c4cf9d32fc1675be37b157172e0f5257e34c6010 Mon Sep 17 00:00:00 2001 From: Enno Hermann Date: Mon, 6 Jan 2025 16:53:42 +0100 Subject: [PATCH 2/4] ci: add mypy check --- .github/actions/setup-uv/action.yml | 2 +- .github/workflows/style_check.yml | 8 ++++++++ .pre-commit-config.yaml | 2 +- Makefile | 5 ++++- pyproject.toml | 13 +++++++++++-- 5 files changed, 25 insertions(+), 5 deletions(-) diff --git a/.github/actions/setup-uv/action.yml b/.github/actions/setup-uv/action.yml index 5b2734e..ba431bf 100644 --- a/.github/actions/setup-uv/action.yml +++ b/.github/actions/setup-uv/action.yml @@ -5,7 +5,7 @@ runs: - name: Install uv uses: astral-sh/setup-uv@v4 with: - version: "0.5.10" + version: "0.5.14" enable-cache: true cache-dependency-glob: "**/pyproject.toml" python-version: ${{ matrix.python-version }} diff --git a/.github/workflows/style_check.yml b/.github/workflows/style_check.yml index 0342680..8a0a789 100644 --- a/.github/workflows/style_check.yml +++ b/.github/workflows/style_check.yml @@ -15,3 +15,11 @@ jobs: uses: ./.github/actions/setup-uv - name: Lint check run: make lint + mypy: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: Setup uv + uses: ./.github/actions/setup-uv + - name: Type check + run: make mypy diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml 
index f7d6c1d..73e0131 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -6,7 +6,7 @@ repos: - id: end-of-file-fixer - id: trailing-whitespace - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.8.3 + rev: v0.8.6 hooks: - id: ruff args: [--fix, --exit-non-zero-on-fix] diff --git a/Makefile b/Makefile index 1869cad..e2b0b8f 100644 --- a/Makefile +++ b/Makefile @@ -1,5 +1,5 @@ .DEFAULT_GOAL := help -.PHONY: test dev-deps deps style lint install help +.PHONY: test dev-deps deps style lint mypy install help help: @grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' @@ -22,6 +22,9 @@ lint: ## run linter. uv run --only-dev ruff check ${target_dirs} uv run --only-dev ruff format --check ${target_dirs} +mypy: ## run type checker. + uv run --group mypy mypy trainer + install: ## install 🐸 Trainer for development. uv sync --all-extras uv run pre-commit install diff --git a/pyproject.toml b/pyproject.toml index d19a481..7b2cab8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -75,7 +75,7 @@ dev = [ "coverage>=7", "pre-commit>=3", "pytest>=8", - "ruff==0.8.3", + "ruff==0.8.6", ] test = [ "accelerate>=0.20.0", @@ -84,7 +84,7 @@ test = [ mypy = [ "matplotlib>=3.9.2", "mlflow>=2.18.0", - "mypy>=1.13.0", + "mypy>=1.14.1", "types-psutil>=6.1.0.20241102", "wandb>=0.18.7", ] @@ -158,6 +158,15 @@ skip_empty = true source = ["trainer", "tests"] command_line = "-m pytest" +[tool.mypy] +warn_unused_configs = true +disallow_subclassing_any = true +disallow_untyped_decorators = true +warn_redundant_casts = true +warn_unused_ignores = true +strict_equality = true +extra_checks = true + [[tool.mypy.overrides]] module = [ "accelerate", From 1a42af07508e9cef56e876a57b95f8026b2ec9bb Mon Sep 17 00:00:00 2001 From: Enno Hermann Date: Thu, 30 Jan 2025 13:04:49 +0100 Subject: [PATCH 3/4] refactor: simplify ddp wrapping and implementation checks --- examples/train_simple_gan.py | 4 +- tests/test_train_gan.py | 8 +- trainer/callbacks.py | 84 ++------ trainer/generic_utils.py | 13 -- trainer/io.py | 9 +- trainer/model.py | 72 +++++-- trainer/trainer.py | 384 ++++++++++++++--------------------- 7 files changed, 239 insertions(+), 335 deletions(-) diff --git a/examples/train_simple_gan.py b/examples/train_simple_gan.py index be93983..1176666 100644 --- a/examples/train_simple_gan.py +++ b/examples/train_simple_gan.py @@ -137,10 +137,10 @@ def eval_step(self, batch, criterion): valid = valid.type_as(imgs) logits = self.discriminator(imgs_gen) - loss_gen = trainer.criterion(logits, valid) + loss_gen = criterion(logits, valid) return {"model_outputs": logits}, {"loss_gen": loss_gen} - def get_optimizer(self): + def get_optimizer(self) -> list[torch.optim.Optimizer]: discriminator_optimizer = torch.optim.Adam(self.discriminator.parameters(), lr=0.0001, betas=(0.5, 0.999)) generator_optimizer = torch.optim.Adam(self.generator.parameters(), lr=0.001, betas=(0.5, 0.999)) return [discriminator_optimizer, generator_optimizer] diff --git a/tests/test_train_gan.py b/tests/test_train_gan.py index 555241c..25955fc 100644 --- a/tests/test_train_gan.py +++ b/tests/test_train_gan.py @@ -301,7 +301,7 @@ def optimize(self, batch, trainer): return {"model_outputs": logits}, {"loss_gen": loss_gen, "loss_disc": loss_disc} @torch.inference_mode() - def eval_step(self, batch, trainer): + def eval_step(self, batch, criterion): imgs, _ = batch # sample noise @@ -313,7 +313,7 @@ def eval_step(self, batch, trainer): valid = 
valid.type_as(imgs) logits = self.discriminator(imgs_gen) - loss_gen = trainer.criterion(logits, valid) + loss_gen = criterion(logits, valid) return {"model_outputs": logits}, {"loss_gen": loss_gen} def get_optimizer(self): @@ -426,7 +426,7 @@ def eval_step(self, batch, criterion): valid = valid.type_as(imgs) logits = self.discriminator(imgs_gen) - loss_gen = trainer.criterion(logits, valid) + loss_gen = criterion(logits, valid) return {"model_outputs": logits}, {"loss_gen": loss_gen} def get_optimizer(self): @@ -541,7 +541,7 @@ def eval_step(self, batch, criterion): valid = valid.type_as(imgs) logits = self.discriminator(imgs_gen) - loss_gen = trainer.criterion(logits, valid) + loss_gen = criterion(logits, valid) return {"model_outputs": logits}, {"loss_gen": loss_gen} def get_optimizer(self): diff --git a/trainer/callbacks.py b/trainer/callbacks.py index 77f7b6e..cbefe58 100644 --- a/trainer/callbacks.py +++ b/trainer/callbacks.py @@ -43,14 +43,10 @@ def parse_callbacks_dict(self, callbacks_dict: dict[str, Callback]) -> None: raise ValueError(msg) def on_init_start(self, trainer: "Trainer") -> None: - if hasattr(trainer.model, "module"): - if hasattr(trainer.model.module, "on_init_start"): - trainer.model.module.on_init_start(trainer) - elif hasattr(trainer.model, "on_init_start"): - trainer.model.on_init_start(trainer) + trainer._get_model().on_init_start(trainer) if hasattr(trainer.criterion, "on_init_start"): - trainer.criterion.on_init_start(trainer) + trainer.criterion.on_init_start(trainer) # type: ignore[operator] if hasattr(trainer.optimizer, "on_init_start"): trainer.optimizer.on_init_start(trainer) @@ -60,14 +56,10 @@ def on_init_start(self, trainer: "Trainer") -> None: callback(trainer) def on_init_end(self, trainer: "Trainer") -> None: - if hasattr(trainer.model, "module"): - if hasattr(trainer.model.module, "on_init_end"): - trainer.model.module.on_init_end(trainer) - elif hasattr(trainer.model, "on_init_end"): - trainer.model.on_init_end(trainer) + trainer._get_model().on_init_end(trainer) if hasattr(trainer.criterion, "on_init_end"): - trainer.criterion.on_init_end(trainer) + trainer.criterion.on_init_end(trainer) # type: ignore[operator] if hasattr(trainer.optimizer, "on_init_end"): trainer.optimizer.on_init_end(trainer) @@ -77,14 +69,10 @@ def on_init_end(self, trainer: "Trainer") -> None: callback(trainer) def on_epoch_start(self, trainer: "Trainer") -> None: - if hasattr(trainer.model, "module"): - if hasattr(trainer.model.module, "on_epoch_start"): - trainer.model.module.on_epoch_start(trainer) - elif hasattr(trainer.model, "on_epoch_start"): - trainer.model.on_epoch_start(trainer) + trainer._get_model().on_epoch_start(trainer) if hasattr(trainer.criterion, "on_epoch_start"): - trainer.criterion.on_epoch_start(trainer) + trainer.criterion.on_epoch_start(trainer) # type: ignore[operator] if hasattr(trainer.optimizer, "on_epoch_start"): trainer.optimizer.on_epoch_start(trainer) @@ -94,14 +82,10 @@ def on_epoch_start(self, trainer: "Trainer") -> None: callback(trainer) def on_epoch_end(self, trainer: "Trainer") -> None: - if hasattr(trainer.model, "module"): - if hasattr(trainer.model.module, "on_epoch_end"): - trainer.model.module.on_epoch_end(trainer) - elif hasattr(trainer.model, "on_epoch_end"): - trainer.model.on_epoch_end(trainer) + trainer._get_model().on_epoch_end(trainer) if hasattr(trainer.criterion, "on_epoch_end"): - trainer.criterion.on_epoch_end(trainer) + trainer.criterion.on_epoch_end(trainer) # type: ignore[operator] if hasattr(trainer.optimizer, 
"on_epoch_end"): trainer.optimizer.on_epoch_end(trainer) @@ -111,14 +95,10 @@ def on_epoch_end(self, trainer: "Trainer") -> None: callback(trainer) def on_train_epoch_start(self, trainer: "Trainer") -> None: - if hasattr(trainer.model, "module"): - if hasattr(trainer.model.module, "on_train_epoch_start"): - trainer.model.module.on_train_epoch_start(trainer) - elif hasattr(trainer.model, "on_train_epoch_start"): - trainer.model.on_train_epoch_start(trainer) + trainer._get_model().on_train_epoch_start(trainer) if hasattr(trainer.criterion, "on_train_epoch_start"): - trainer.criterion.on_train_epoch_start(trainer) + trainer.criterion.on_train_epoch_start(trainer) # type: ignore[operator] if hasattr(trainer.optimizer, "on_train_epoch_start"): trainer.optimizer.on_train_epoch_start(trainer) @@ -128,14 +108,10 @@ def on_train_epoch_start(self, trainer: "Trainer") -> None: callback(trainer) def on_train_epoch_end(self, trainer: "Trainer") -> None: - if hasattr(trainer.model, "module"): - if hasattr(trainer.model.module, "on_train_epoch_end"): - trainer.model.module.on_train_epoch_end(trainer) - elif hasattr(trainer.model, "on_train_epoch_end"): - trainer.model.on_train_epoch_end(trainer) + trainer._get_model().on_train_epoch_end(trainer) if hasattr(trainer.criterion, "on_train_epoch_end"): - trainer.criterion.on_train_epoch_end(trainer) + trainer.criterion.on_train_epoch_end(trainer) # type: ignore[operator] if hasattr(trainer.optimizer, "on_train_epoch_end"): trainer.optimizer.on_train_epoch_end(trainer) @@ -146,29 +122,17 @@ def on_train_epoch_end(self, trainer: "Trainer") -> None: @staticmethod def before_backward_pass(trainer: "Trainer", loss_dict: dict[str, Any]) -> None: - if hasattr(trainer.model, "module"): - if hasattr(trainer.model.module, "before_backward_pass"): - trainer.model.module.before_backward_pass(loss_dict, trainer.optimizer) - elif hasattr(trainer.model, "before_backward_pass"): - trainer.model.before_backward_pass(loss_dict, trainer.optimizer) + trainer._get_model().before_backward_pass(loss_dict, trainer.optimizer) @staticmethod def before_gradient_clipping(trainer: "Trainer") -> None: - if hasattr(trainer.model, "module"): - if hasattr(trainer.model.module, "before_gradient_clipping"): - trainer.model.module.before_gradient_clipping() - elif hasattr(trainer.model, "before_gradient_clipping"): - trainer.model.before_gradient_clipping() + trainer._get_model().before_gradient_clipping() def on_train_step_start(self, trainer: "Trainer") -> None: - if hasattr(trainer.model, "module"): - if hasattr(trainer.model.module, "on_train_step_start"): - trainer.model.module.on_train_step_start(trainer) - elif hasattr(trainer.model, "on_train_step_start"): - trainer.model.on_train_step_start(trainer) + trainer._get_model().on_train_step_start(trainer) if hasattr(trainer.criterion, "on_train_step_start"): - trainer.criterion.on_train_step_start(trainer) + trainer.criterion.on_train_step_start(trainer) # type: ignore[operator] if hasattr(trainer.optimizer, "on_train_step_start"): trainer.optimizer.on_train_step_start(trainer) @@ -178,14 +142,10 @@ def on_train_step_start(self, trainer: "Trainer") -> None: callback(trainer) def on_train_step_end(self, trainer: "Trainer") -> None: - if hasattr(trainer.model, "module"): - if hasattr(trainer.model.module, "on_train_step_end"): - trainer.model.module.on_train_step_end(trainer) - elif hasattr(trainer.model, "on_train_step_end"): - trainer.model.on_train_step_end(trainer) + trainer._get_model().on_train_step_end(trainer) if 
hasattr(trainer.criterion, "on_train_step_end"): - trainer.criterion.on_train_step_end(trainer) + trainer.criterion.on_train_step_end(trainer) # type: ignore[operator] if hasattr(trainer.optimizer, "on_train_step_end"): trainer.optimizer.on_train_step_end(trainer) @@ -195,14 +155,10 @@ def on_train_step_end(self, trainer: "Trainer") -> None: callback(trainer) def on_keyboard_interrupt(self, trainer: "Trainer") -> None: - if hasattr(trainer.model, "module"): - if hasattr(trainer.model.module, "on_keyboard_interrupt"): - trainer.model.module.on_keyboard_interrupt(trainer) - elif hasattr(trainer.model, "on_keyboard_interrupt"): - trainer.model.on_keyboard_interrupt(trainer) + trainer._get_model().on_keyboard_interrupt(trainer) if hasattr(trainer.criterion, "on_keyboard_interrupt"): - trainer.criterion.on_keyboard_interrupt(trainer) + trainer.criterion.on_keyboard_interrupt(trainer) # type: ignore[operator] if hasattr(trainer.optimizer, "on_keyboard_interrupt"): trainer.optimizer.on_keyboard_interrupt(trainer) diff --git a/trainer/generic_utils.py b/trainer/generic_utils.py index 46a407e..4d40170 100644 --- a/trainer/generic_utils.py +++ b/trainer/generic_utils.py @@ -23,19 +23,6 @@ def is_pytorch_at_least_2_4() -> bool: return Version(torch.__version__) >= Version("2.4") -def isimplemented(obj: Any, method_name: str) -> bool: - """Check if a method is implemented in a class.""" - if method_name in dir(obj) and callable(getattr(obj, method_name)): - try: - obj.__getattribute__(method_name)() # pylint: disable=bad-option-value, unnecessary-dunder-call - except NotImplementedError: - return False - except Exception: - return True - return True - return False - - def to_cuda(x: torch.Tensor) -> torch.Tensor: if x is None: return None diff --git a/trainer/io.py b/trainer/io.py index d436d22..d456a1d 100644 --- a/trainer/io.py +++ b/trainer/io.py @@ -17,6 +17,7 @@ from trainer._types import LossDict from trainer.generic_utils import is_pytorch_at_least_2_4 from trainer.logger import logger +from trainer.model import TrainerModel # `torch.serialization.add_safe_globals` is needed for weights_only=True to work # with all Coqui models and only available from Pytorch >=2.4 @@ -124,7 +125,7 @@ def save_fsspec(state: Any, path: str | os.PathLike[Any], **kwargs: Any) -> None def save_model( config: dict[str, Any] | Coqpit, - model: torch.nn.Module, + model: TrainerModel, optimizer: torch.optim.Optimizer | list[torch.optim.Optimizer], scaler: "torch.GradScaler | None", current_step: int, @@ -133,7 +134,7 @@ def save_model( save_func: Callable[[Any, str | os.PathLike[Any]], None] | None = None, **kwargs: Any, ) -> None: - model_state = model.module.state_dict() if hasattr(model, "module") else model.state_dict() + model_state = model.state_dict() optimizer_state: StateDict | list[StateDict] | None if isinstance(optimizer, list): optimizer_state = [optim.state_dict() for optim in optimizer] @@ -169,7 +170,7 @@ def save_model( def save_checkpoint( config: dict[str, Any] | Coqpit, - model: torch.nn.Module, + model: TrainerModel, optimizer: torch.optim.Optimizer | list[torch.optim.Optimizer], scaler: "torch.GradScaler | None", current_step: int, @@ -202,7 +203,7 @@ def save_best_model( current_loss: LossDict | float, best_loss: LossDict | float, config: dict[str, Any] | Coqpit, - model: torch.nn.Module, + model: TrainerModel, optimizer: torch.optim.Optimizer | list[torch.optim.Optimizer], scaler: "torch.GradScaler | None", current_step: int, diff --git a/trainer/model.py b/trainer/model.py index 
097e5ae..58a4f44 100644 --- a/trainer/model.py +++ b/trainer/model.py @@ -8,9 +8,6 @@ from trainer.trainer import Trainer -# pylint: skip-file - - class TrainerModel(ABC, nn.Module): """Abstract 🐸TTS class. Every new 🐸TTS model must inherit this.""" @@ -128,6 +125,24 @@ def get_data_loader(*args: Any, **kwargs: Any) -> torch.utils.data.DataLoader[An msg = " [!] `get_data_loader()` is not implemented." raise NotImplementedError(msg) + def get_train_data_loader(*args: Any, **kwargs: Any) -> torch.utils.data.DataLoader[Any]: + raise NotImplementedError + + def get_eval_data_loader(*args: Any, **kwargs: Any) -> torch.utils.data.DataLoader[Any]: + raise NotImplementedError + + def get_test_data_loader(*args: Any, **kwargs: Any) -> torch.utils.data.DataLoader[Any]: + raise NotImplementedError + + def test_run(self, *args: Any, **kwargs: Any): + raise NotImplementedError + + def test(self, assets: dict[str, Any], data_loader: torch.utils.data.DataLoader[Any], outputs: Any | None = None): + raise NotImplementedError + + def test_log(self, *args: Any, **kwargs: Any): + raise NotImplementedError + def init_for_training(self) -> None: """Initialize model for training.""" @@ -172,22 +187,51 @@ def scaled_backward( # main model optimizer step loss.backward() - # def get_optimizer(self) -> Union["Optimizer", List["Optimizer"]]: - # """Setup an return optimizer or optimizers.""" - # ... + def get_optimizer(self) -> torch.optim.Optimizer | list[torch.optim.Optimizer]: + """Setup an return optimizer or optimizers.""" + raise NotImplementedError - # def get_lr(self) -> Union[float, List[float]]: - # """Return learning rate(s). + def get_lr(self) -> float | list[float]: + """Return learning rate(s). - # Returns: - # Union[float, List[float]]: Model's initial learning rates. - # """ - # ... + Returns: + Union[float, List[float]]: Model's initial learning rates. + """ + raise NotImplementedError - # def get_scheduler(self, optimizer: torch.optim.Optimizer): - # ... + def get_scheduler( + self, optimizer: torch.optim.Optimizer | list[torch.optim.Optimizer] | dict[str, torch.optim.Optimizer] + ): + raise NotImplementedError def get_criterion(self) -> nn.Module | list[nn.Module]: """Return model criterion.""" msg = "`get_criterion` is not implemented." raise NotImplementedError(msg) + + ## Callbacks + def on_init_start(self, trainer: "Trainer") -> None: ... + + def on_init_end(self, trainer: "Trainer") -> None: ... + + def on_epoch_start(self, trainer: "Trainer") -> None: ... + + def on_epoch_end(self, trainer: "Trainer") -> None: ... + + def on_train_epoch_start(self, trainer: "Trainer") -> None: ... + + def on_train_epoch_end(self, trainer: "Trainer") -> None: ... + + @staticmethod + def before_backward_pass( + loss_dict: dict[str, Any], optimizer: torch.optim.Optimizer | list[torch.optim.Optimizer] + ) -> None: ... + + @staticmethod + def before_gradient_clipping() -> None: ... + + def on_train_step_start(self, trainer: "Trainer") -> None: ... + + def on_train_step_end(self, trainer: "Trainer") -> None: ... + + def on_keyboard_interrupt(self, trainer: "Trainer") -> None: ... 
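The trainer/model.py hunk above and the trainer/trainer.py hunk below are two halves of the same convention: every optional `TrainerModel` hook (`get_optimizer`, `get_lr`, `get_scheduler`, the per-split data-loader getters, `test_run`, the `on_*` callbacks) now ships with a default body that either does nothing or raises a bare `NotImplementedError`, and the trainer calls each hook inside `try`/`except NotImplementedError` to fall back to its config-driven defaults, which is what lets the `isimplemented()` reflection helper be deleted. A minimal, self-contained sketch of that dispatch pattern follows; `BaseModel`, `MyModel`, `PlainModel` and `build_optimizer` are hypothetical stand-ins for illustration, not the trainer API.

import torch


class BaseModel(torch.nn.Module):
    """Stand-in for a base class with an optional hook (sketch only)."""

    def get_optimizer(self) -> torch.optim.Optimizer | list[torch.optim.Optimizer]:
        # Default body: the hook is "not implemented"; callers are expected
        # to catch NotImplementedError and fall back to their own defaults.
        raise NotImplementedError


class MyModel(BaseModel):
    def __init__(self) -> None:
        super().__init__()
        self.layer = torch.nn.Linear(4, 2)

    def get_optimizer(self) -> torch.optim.Optimizer:
        # This model supplies its own optimizer, so the fallback is never used.
        return torch.optim.Adam(self.parameters(), lr=1e-3)


class PlainModel(BaseModel):
    # No get_optimizer() override: the inherited hook raises NotImplementedError.
    def __init__(self) -> None:
        super().__init__()
        self.layer = torch.nn.Linear(4, 2)


def build_optimizer(model: BaseModel, config: dict) -> torch.optim.Optimizer | list[torch.optim.Optimizer]:
    # Same dispatch the refactored Trainer.get_optimizer() uses: try the model
    # hook first and only build from the config when it raises.
    try:
        return model.get_optimizer()
    except NotImplementedError:
        optimizer_cls = getattr(torch.optim, config["optimizer"])
        return optimizer_cls(model.parameters(), lr=config["lr"])


if __name__ == "__main__":
    cfg = {"optimizer": "AdamW", "lr": 1e-4}
    print(build_optimizer(MyModel(), cfg))     # uses the model's own hook (Adam)
    print(build_optimizer(PlainModel(), cfg))  # falls back to the config (AdamW)

Run as a script, this prints an Adam instance for `MyModel` and an AdamW instance for `PlainModel`, mirroring how the refactored `Trainer.get_optimizer()` prefers the model hook and only reads `config.optimizer` and `config.lr` when the hook is left unimplemented.
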
diff --git a/trainer/trainer.py b/trainer/trainer.py index eef0769..4453a66 100644 --- a/trainer/trainer.py +++ b/trainer/trainer.py @@ -29,7 +29,6 @@ get_git_branch, is_pytorch_at_least_2_3, is_pytorch_at_least_2_4, - isimplemented, remove_experiment_folder, set_partial_state_dict, to_cuda, @@ -292,8 +291,7 @@ def __init__( # pylint: disable=dangerous-default-value raise ValueError(msg) # init model's training assets - if isimplemented(self.model, "init_for_training"): - self.model.init_for_training() + self.model.init_for_training() # setup criterion self.criterion = self.get_criterion(self.model) @@ -321,25 +319,13 @@ def __init__( # pylint: disable=dangerous-default-value # setup optimizer self.optimizer = self.get_optimizer(self.model, self.config) - # If multiple-optimizer setup with grad accumulation and without custom optimize method raise an error - if ( - self.grad_accum_steps != 1 - and isinstance(self.optimizer, list) - and not isimplemented(self.model, "optimize") - ): - msg = " [!] Coqui Trainer does not support grad_accum_steps for multiple-optimizer setup, please set grad_accum_steps to 1 or implement in your model a custom method called `optimize` that need to deal with dangling gradients in multiple-optimizer setup!" - raise ValueError(msg) - # CALLBACK self.callbacks = TrainerCallback() self.callbacks.parse_callbacks_dict(callbacks) self.callbacks.on_init_start(self) # init AMP - if self.use_amp_scaler: - self.scaler = GradScaler() - else: - self.scaler = None + self.scaler = GradScaler() if self.use_amp_scaler else None # restore model if self.args.restore_path: @@ -354,8 +340,10 @@ def __init__( # pylint: disable=dangerous-default-value ) # DISTRIBUTED + self.wrapped_model: TrainerModel | None = None if self.use_pt_ddp: - self.model = DDP_th(self.model, device_ids=[args.rank], output_device=args.rank) # type: ignore[assignment] + ddp_model = DDP_th(self.model, device_ids=[args.rank], output_device=args.rank) + self.wrapped_model = ddp_model.module # cast(TrainerModel, ddp_model.module) # setup accelerator self.setup_accelerate() @@ -696,27 +684,26 @@ def _get_loader( verbose: bool, num_gpus: int, ) -> DataLoader[Any]: - if num_gpus > 1: - if isimplemented(model.module, "get_data_loader"): - loader = model.module.get_data_loader( - config, - assets, - is_eval, - samples, - verbose, - num_gpus, - self.args.rank, - ) - elif isimplemented(model, "get_data_loader"): - loader = model.get_data_loader( - config=config, assets=assets, is_eval=is_eval, samples=samples, verbose=verbose, num_gpus=num_gpus - ) + loader = model.get_data_loader( + config=config, + assets=assets, + is_eval=is_eval, + samples=samples, + verbose=verbose, + num_gpus=num_gpus, + rank=self.args.rank, + ) assert ( len(loader) > 0 ), " ❗ len(DataLoader) returns 0. Make sure your dataset is not empty or len(dataset) > 0. " return loader + def _get_model(self) -> TrainerModel: + if not hasattr(self, "wrapped_model") or self.wrapped_model is None: + return self.model + return self.wrapped_model + def get_train_dataloader( self, training_assets: dict[str, Any], samples: list[Any] | None, *, verbose: bool ) -> DataLoader[Any]: @@ -733,28 +720,26 @@ def get_train_dataloader( Returns: DataLoader: Initialized training data loader. 
""" - if self.num_gpus > 1: - if isimplemented(self.model.module, "get_train_data_loader"): - return self.model.module.get_train_data_loader( - self.config, - self.training_assets, - samples, - verbose, - self.num_gpus, - self.args.rank, - ) - elif isimplemented(self.model, "get_train_data_loader"): - return self.model.get_train_data_loader(self.config, self.training_assets, samples, verbose, self.num_gpus) - - return self._get_loader( - self.model, - self.config, - training_assets, - samples, - is_eval=False, - verbose=verbose, - num_gpus=self.num_gpus, - ) + model = self._get_model() + try: + return model.get_train_data_loader( + self.config, + self.training_assets, + samples, + verbose, + self.num_gpus, + self.args.rank, + ) + except NotImplementedError: + return self._get_loader( + model, + self.config, + training_assets, + samples, + is_eval=False, + verbose=verbose, + num_gpus=self.num_gpus, + ) def get_eval_dataloader( self, training_assets: dict[str, Any], samples: list[Any] | None, *, verbose: bool @@ -772,28 +757,26 @@ def get_eval_dataloader( Returns: DataLoader: Initialized training data loader. """ - if self.num_gpus > 1: - if isimplemented(self.model.module, "get_eval_data_loader"): - return self.model.module.get_eval_data_loader( - self.config, - self.training_assets, - samples, - verbose, - self.num_gpus, - self.args.rank, - ) - elif isimplemented(self.model, "get_eval_data_loader"): - return self.model.get_eval_data_loader(self.config, self.training_assets, samples, verbose, self.num_gpus) - - return self._get_loader( - self.model, - self.config, - training_assets, - samples, - is_eval=True, - verbose=verbose, - num_gpus=self.num_gpus, - ) + model = self._get_model() + try: + return model.get_eval_data_loader( + self.config, + self.training_assets, + samples, + verbose, + self.num_gpus, + self.args.rank, + ) + except NotImplementedError: + return self._get_loader( + model, + self.config, + training_assets, + samples, + is_eval=True, + verbose=verbose, + num_gpus=self.num_gpus, + ) def get_test_dataloader( self, training_assets: dict[str, Any], samples: list[Any] | None, *, verbose: bool @@ -811,28 +794,26 @@ def get_test_dataloader( Returns: DataLoader: Initialized training data loader. """ - if self.num_gpus > 1: - if isimplemented(self.model.module, "get_test_data_loader"): - return self.model.module.get_test_data_loader( - self.config, - self.training_assets, - samples, - verbose, - self.num_gpus, - self.args.rank, - ) - elif isimplemented(self.model, "get_test_data_loader"): - return self.model.get_test_data_loader(self.config, self.training_assets, samples, verbose, self.num_gpus) - - return self._get_loader( - self.model, - self.config, - training_assets, - samples, - is_eval=True, - verbose=verbose, - num_gpus=self.num_gpus, - ) + model = self._get_model() + try: + return model.get_test_data_loader( + self.config, + self.training_assets, + samples, + verbose, + self.num_gpus, + self.args.rank, + ) + except NotImplementedError: + return self._get_loader( + model, + self.config, + training_assets, + samples, + is_eval=True, + verbose=verbose, + num_gpus=self.num_gpus, + ) def format_batch(self, batch: dict[str, Any] | list[Any]) -> dict[str, Any] | list[Any]: """Format the dataloader output and return a batch. @@ -848,7 +829,11 @@ def format_batch(self, batch: dict[str, Any] | list[Any]) -> dict[str, Any] | li Dict: Formatted batch. 
""" with suppress(NotImplementedError): - batch = self.model.module.format_batch(batch) if self.num_gpus > 1 else self.model.format_batch(batch) + batch = ( + self.wrapped_model.format_batch(batch) + if self.wrapped_model is not None + else self.model.format_batch(batch) + ) if isinstance(batch, dict): for k, v in batch.items(): @@ -856,13 +841,12 @@ def format_batch(self, batch: dict[str, Any] | list[Any]) -> dict[str, Any] | li elif isinstance(batch, list): batch = [to_cuda(v) for v in batch] - try: - if self.num_gpus > 1: - batch = self.model.module.format_batch_on_device(batch) - else: - batch = self.model.format_batch_on_device(batch) - except NotImplementedError: - pass + with suppress(NotImplementedError): + batch = ( + self.wrapped_model.format_batch_on_device(batch) + if self.wrapped_model is not None + else self.model.format_batch_on_device(batch) + ) return batch ###################### @@ -881,10 +865,9 @@ def master_params(optimizer: torch.optim.Optimizer) -> Generator[Any]: for group in optimizer.param_groups: yield from group["params"] - @staticmethod def _model_train_step( + self, batch: dict[str, Any] | list[Any], - model: TrainerModel, criterion: nn.Module | list[nn.Module], optimizer_idx: int | None = None, ) -> tuple[dict[str, Any], dict[str, Any]]: @@ -892,7 +875,6 @@ def _model_train_step( Args: batch (Dict): [description] - model (TrainerModel): [description] criterion (nn.Module): [description] optimizer_idx (int, optional): [description]. Defaults to None. @@ -903,9 +885,9 @@ def _model_train_step( if optimizer_idx is not None: input_args.append(optimizer_idx) # unwrap model in DDP training - if hasattr(model, "module"): - return model.module.train_step(*input_args) - return model.train_step(*input_args) + if self.wrapped_model is not None: + return self.wrapped_model.train_step(*input_args) + return self.model.train_step(*input_args) def _get_autocast_args(self, *, mixed_precision: bool, precision: str) -> tuple[str, torch.dtype]: device = "cpu" @@ -948,17 +930,17 @@ def detach_loss_dict( def _compute_loss( self, batch: dict[str, Any] | list[Any], - model: TrainerModel, criterion: nn.Module | list[nn.Module], - config: TrainerConfig, optimizer_idx: int | None, ) -> tuple[dict[str, Any], dict[str, Any]]: - device, dtype = self._get_autocast_args(mixed_precision=config.mixed_precision, precision=config.precision) - with torch.autocast(device_type=device, dtype=dtype, enabled=config.mixed_precision): + device, dtype = self._get_autocast_args( + mixed_precision=self.config.mixed_precision, precision=self.config.precision + ) + with torch.autocast(device_type=device, dtype=dtype, enabled=self.config.mixed_precision): if optimizer_idx is not None: - outputs, loss_dict = self._model_train_step(batch, model, criterion, optimizer_idx=optimizer_idx) + outputs, loss_dict = self._model_train_step(batch, criterion, optimizer_idx=optimizer_idx) else: - outputs, loss_dict = self._model_train_step(batch, model, criterion) + outputs, loss_dict = self._model_train_step(batch, criterion) return outputs, loss_dict @staticmethod @@ -997,12 +979,10 @@ def _grad_clipping( def optimize( self, batch: dict[str, Any] | list[Any], - model: TrainerModel, optimizer: torch.optim.Optimizer, scaler: "torch.GradScaler | None", criterion: nn.Module | list[nn.Module], scheduler: LRScheduler | None, - config: TrainerConfig, *, optimizer_idx: int | None = None, step_optimizer: bool = True, @@ -1012,12 +992,10 @@ def optimize( Args: batch (Dict): Input batch. 
If - model (TrainerModel): Model for training. Defaults to None. optimizer (Union[nn.optim.Optimizer, List]): Model's optimizer. If it is a list then, `optimizer_idx` must be defined to indicate the optimizer in use. scaler (AMPScaler): AMP scaler. criterion (nn.Module): Model's criterion. scheduler (LRScheduler): LR scheduler used by the optimizer. - config (TrainerConfig): Model config. optimizer_idx (int, optional): Target optimizer being used. Defaults to None. step_optimizer (bool, optional): Whether step the optimizer. If False, gradients are accumulated and model parameters are not updated. Defaults to True. @@ -1032,16 +1010,14 @@ def optimize( step_start_time = time.time() # forward pass and loss computation - outputs, loss_dict = self._compute_loss( - batch=batch, model=model, criterion=criterion, config=config, optimizer_idx=optimizer_idx - ) + outputs, loss_dict = self._compute_loss(batch=batch, criterion=criterion, optimizer_idx=optimizer_idx) # skip the rest if not outputs from the model if not loss_dict: step_time = time.time() - step_start_time return outputs, {}, step_time - grad_clip = self._set_grad_clip_per_optimizer(config=config, optimizer_idx=optimizer_idx) + grad_clip = self._set_grad_clip_per_optimizer(config=self.config, optimizer_idx=optimizer_idx) # optimizer step grad_norm: float | torch.Tensor = 0.0 update_lr_scheduler = True @@ -1053,13 +1029,13 @@ def optimize( loss_dict["loss"] = loss_dict["loss"] / float(self.grad_accum_steps) if self.use_accelerate: - with self.accelerator.accumulate(model): - ctx_mgr = self.accelerator.autocast if config.mixed_precision else nullcontext + with self.accelerator.accumulate(self.model): + ctx_mgr = self.accelerator.autocast if self.config.mixed_precision else nullcontext with ctx_mgr(): self.accelerator.backward(loss_dict["loss"]) grad_norm = self._compute_grad_norm(optimizer) if self.accelerator.sync_gradients and grad_clip is not None and grad_clip > 0: - self.accelerator.clip_grad_norm_(model.parameters(), grad_clip) + self.accelerator.clip_grad_norm_(self.model.parameters(), grad_clip) optimizer.step() if ( scheduler is not None @@ -1141,7 +1117,7 @@ def train_step( loss_dict = {} # OPTIMIZATION - if isimplemented(self.model, "optimize"): # pylint: disable=too-many-nested-blocks + try: # custom optimize for the model step_time = time.time() device, dtype = self._get_autocast_args( @@ -1156,7 +1132,7 @@ def train_step( # TODO: find a way to log grad_norm for custom optimize loss_dict_new = self.detach_loss_dict(loss_dict_new, step_optimizer=True) loss_dict.update(loss_dict_new) - else: + except NotImplementedError as e: # gradient accumulation # TODO: grad accumulation for each optimizer step_optimizer = True @@ -1164,23 +1140,27 @@ def train_step( step_optimizer = False if not isinstance(self.optimizer, list): - if isinstance(self.scheduler, dict | list): - msg = "Can't use dict or list of schedulers with a single optimizer." - raise TypeError(msg) + if isinstance(self.scheduler, list): + msg = "Can't use list of schedulers with a single optimizer." 
+ raise TypeError(msg) from e + if isinstance(self.scheduler, dict): + msg = "Can only use dict of schedulers with custom `optimize()`" + raise TypeError(msg) from e # auto training with a single optimizer outputs, loss_dict_new, step_time = self.optimize( batch, - self.model, self.optimizer, self.scaler, self.criterion, self.scheduler, - self.config, step_optimizer=step_optimizer, num_optimizers=1, ) loss_dict.update(loss_dict_new) else: + if self.grad_accum_steps != 1: + msg = " [!] Coqui Trainer does not support grad_accum_steps for multiple-optimizer setup, please set grad_accum_steps to 1 or implement in your model a custom `optimize` method to deal with dangling gradients in multiple-optimizer setup!" + raise ValueError(msg) from e # auto training with multiple optimizers (e.g. GAN) outputs_per_optimizer = [] total_step_time = 0.0 @@ -1193,12 +1173,10 @@ def train_step( scheduler = self.scheduler[idx] optimizer_outputs, loss_dict_new, step_time = self.optimize( batch, - self.model, optimizer, scaler, criterion, scheduler, - self.config, optimizer_idx=idx, step_optimizer=step_optimizer, num_optimizers=len(self.optimizer), @@ -1301,10 +1279,7 @@ def train_epoch(self) -> None: ) self.train_loader = self.prepare_accelerate_loader(self.train_loader) # set model to training mode - if self.num_gpus > 1: - self.model.module.train() - else: - self.model.train() + self.model.train() epoch_start_time = time.time() self.callbacks.on_train_epoch_start(self) @@ -1324,10 +1299,7 @@ def train_epoch(self) -> None: # RUN EVAL -> run evaluation epoch in the middle of training. Useful for big datasets. if self.config.run_eval_steps is not None and (self.total_steps_done % self.config.run_eval_steps == 0): self.eval_epoch() - if self.num_gpus > 1: - self.model.module.train() - else: - self.model.train() + self.model.train() epoch_time = time.time() - epoch_start_time self.callbacks.on_train_epoch_end(self) @@ -1377,18 +1349,10 @@ def _model_eval_step( Tuple[Dict, Dict]: model outputs and losses. 
""" input_args: list[Any] = [batch, criterion] - - if isimplemented(model, "optimize"): - if hasattr(model, "module"): - return model.module.eval_step(batch, self) - return model.eval_step(batch, self) - if optimizer_idx is not None: input_args.append(optimizer_idx) - if hasattr(model, "module"): - return model.module.eval_step(*input_args) - return model.eval_step(*input_args) + return self._get_model().eval_step(*input_args) def eval_step( self, batch: dict[str, Any], step: int @@ -1405,15 +1369,15 @@ def eval_step( outputs: dict[str, Any] | list[dict[str, Any]] with torch.inference_mode(): loss_dict: dict[str, Any] = {} - if not isinstance(self.optimizer, list) or isimplemented(self.model, "optimize"): - outputs, loss_dict = self._model_eval_step(batch, self.model, self.criterion) + model = self._get_model() + if not isinstance(self.optimizer, list) or len(signature(model.eval_step).parameters) == 2: # noqa: PLR2004 + outputs, loss_dict = model.eval_step(batch, self.criterion) if outputs is None: return None, None else: optimizer_outputs = [] for idx, _ in enumerate(self.optimizer): - criterion = self.criterion - outputs_, loss_dict_new = self._model_eval_step(batch, self.model, criterion, idx) + outputs_, loss_dict_new = model.eval_step(batch, self.criterion, idx) if outputs_ is None: return None, None optimizer_outputs.append(outputs_) @@ -1474,16 +1438,9 @@ def eval_epoch(self) -> None: loader_start_time = time.time() # plot epoch stats, artifacts and figures if self.args.rank == 0 and outputs is not None: - if hasattr(self.model, "module") and isimplemented(self.model.module, "eval_log"): - self.model.module.eval_log( - batch, - outputs, - self.dashboard_logger, - self.training_assets, - self.total_steps_done, - ) - elif isimplemented(self.model, "eval_log"): - self.model.eval_log( + model = self._get_model() + with suppress(NotImplementedError): + model.eval_log( batch, outputs, self.dashboard_logger, @@ -1508,35 +1465,21 @@ def test_run(self) -> None: and iterate over it. 
""" self.model.eval() + model = self._get_model() test_outputs = None - if isimplemented(self.model, "test_run") or ( - self.num_gpus > 1 and isimplemented(self.model.module, "test_run") - ): - # handle everything in ```model.test_run()` - if self.num_gpus > 1: - test_outputs = self.model.module.test_run(self.training_assets) - else: - test_outputs = self.model.test_run(self.training_assets) - elif isimplemented(self.model, "test") or (self.num_gpus > 1 and isimplemented(self.model.module, "test")): + try: + test_outputs = model.test_run(self.training_assets) + except NotImplementedError: self.test_loader = self.get_test_dataloader( self.training_assets, self.test_samples if self.test_samples else self.eval_samples, verbose=True, ) # use test_loader to load test samples - if self.num_gpus > 1: - test_outputs = self.model.module.test(self.training_assets, self.test_loader, None) - else: - test_outputs = self.model.test(self.training_assets, self.test_loader, None) - if isimplemented(self.model, "test_log") or ( - self.num_gpus > 1 and isimplemented(self.model.module, "test_log") - ): - if self.num_gpus > 1: - self.model.module.test_log( - test_outputs, self.dashboard_logger, self.training_assets, self.total_steps_done - ) - else: - self.model.test_log(test_outputs, self.dashboard_logger, self.training_assets, self.total_steps_done) + with suppress(NotImplementedError): + test_outputs = model.test(self.training_assets, self.test_loader, None) + with suppress(NotImplementedError): + model.test_log(test_outputs, self.dashboard_logger, self.training_assets, self.total_steps_done) def _restore_best_loss(self) -> None: """Restore the best loss. @@ -1742,7 +1685,7 @@ def save_best_model(self) -> None: {"train_loss": train_loss, "eval_loss": eval_loss}, self.best_loss, self.config, - self.model, + self._get_model(), self.optimizer, self.scaler if self.use_amp_scaler else None, self.total_steps_done, @@ -1761,7 +1704,7 @@ def save_checkpoint(self) -> None: save_checkpoint( self.config, - self.model, + self._get_model(), self.optimizer, self.scaler if self.use_amp_scaler else None, self.total_steps_done, @@ -1786,16 +1729,9 @@ def update_training_dashboard_logger( # training visualizations if batch is not None and outputs is not None: - if hasattr(self.model, "module") and isimplemented(self.model.module, "train_log"): - self.model.module.train_log( - batch, - outputs, - self.dashboard_logger, - self.training_assets, - self.total_steps_done, - ) - elif isimplemented(self.model, "train_log"): - self.model.train_log( + model = self._get_model() + with suppress(NotImplementedError): + model.train_log( batch, outputs, self.dashboard_logger, @@ -1823,13 +1759,9 @@ def get_optimizer( Returns: Union[torch.optim.Optimizer, List]: A optimizer or a list of optimizers. GAN models define a list. 
""" - optimizer = None - if isimplemented(model, "get_optimizer"): - try: - optimizer = model.get_optimizer() - except NotImplementedError: - optimizer = None - if optimizer is None: + try: + return model.get_optimizer() + except NotImplementedError as e: if isinstance(config.optimizer, list): optimizers = [] for i, optimizer_name in enumerate(config.optimizer): @@ -1838,11 +1770,10 @@ def get_optimizer( return optimizers if config.optimizer is None: msg = "No name specified in `optimizer`" - raise ValueError(msg) + raise ValueError(msg) from e optimizer_name = config.optimizer optimizer_params = {} if config.optimizer_params is None else config.optimizer_params return get_optimizer(optimizer_name, optimizer_params, config.lr, model) # type: ignore[arg-type] - return optimizer @staticmethod def get_lr(model: TrainerModel, config: TrainerConfig) -> float | list[float] | dict[str, float]: @@ -1858,15 +1789,10 @@ def get_lr(model: TrainerModel, config: TrainerConfig) -> float | list[float] | Returns: Union[float, List[float]]: A single learning rate or a list of learning rates, one for each optimzier. """ - lr = None - if isimplemented(model, "get_lr"): - try: - lr = model.get_lr() - except NotImplementedError: - lr = None - if lr is None: - lr = config.lr - return lr + try: + return model.get_lr() + except NotImplementedError: + return config.lr @staticmethod def get_scheduler( @@ -1886,22 +1812,12 @@ def get_scheduler( Returns: Union[torch.optim.Optimizer, List, Dict]: A scheduler or a list of schedulers, one for each optimizer. """ - scheduler = None - if isimplemented(model, "get_scheduler"): - try: - scheduler = model.get_scheduler(optimizer) - except NotImplementedError: - scheduler = None - if isinstance(scheduler, dict) and not isimplemented(model, "optimize"): - msg = ( - " [!] Dictionary of schedulers are only supported with the manual optimization `model.optimize()`." - ) - raise ValueError(msg) - if scheduler is None: + try: + return model.get_scheduler(optimizer) + except NotImplementedError: lr_scheduler = config.lr_scheduler lr_scheduler_params = config.lr_scheduler_params return get_scheduler(lr_scheduler, lr_scheduler_params, optimizer) # type: ignore[arg-type] - return scheduler @staticmethod def restore_scheduler( From 77f64ba7a34b2b73b3cee676047a32fe7eba278f Mon Sep 17 00:00:00 2001 From: Enno Hermann Date: Wed, 29 Jan 2025 22:18:22 +0100 Subject: [PATCH 4/4] chore: bump version to 0.2.3 --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 7b2cab8..7669e41 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,7 +26,7 @@ build-backend = "hatchling.build" [project] name = "coqui-tts-trainer" -version = "0.2.2" +version = "0.2.3" description = "General purpose model trainer for PyTorch that is more flexible than it should be, by 🐸Coqui." readme = "README.md" requires-python = ">=3.10, <3.14"