diff --git a/README.md b/README.md
index a8f5ce3f..938bf464 100644
--- a/README.md
+++ b/README.md
@@ -149,6 +149,16 @@ The following properties are defined in the top level of the model configuration
 - `training`
     - The training configuration for the model, varies based on `model_type`. Provides parameters for training as well as demos.
 
+
+### Optimizer config
+The optimizer config, inside the `training` section of the model config, selects the optimizer implementation to use, including 8-bit optimizers that make fine-tuning feasible on 24GB of VRAM.
+
+- `backend`
+    - The optimizer library to use, currently one of `"bnb"` or `"default"`.
+- `type`
+    - The optimizer class name to use. With the `"bnb"` backend, this enables `"AdamW8bit"` and other 8-bit optimizers.
+
+
 ## Dataset config
 `stable-audio-tools` currently supports two kinds of data sources: local directories of audio files, and WebDataset datasets stored in Amazon S3. More information can be found in [the dataset config documentation](docs/datasets.md)
 
diff --git a/setup.py b/setup.py
index 7e7470d3..929bfce5 100644
--- a/setup.py
+++ b/setup.py
@@ -11,6 +11,7 @@
         'aeiou==0.0.20',
         'alias-free-torch==0.0.6',
         'auraloss==0.4.0',
+        'bitsandbytes==0.35.0',
         'descript-audio-codec==1.0.0',
         'einops==0.7.0',
         'einops-exts==0.0.4',
diff --git a/stable_audio_tools/training/utils.py b/stable_audio_tools/training/utils.py
index 38a3fccc..48f60be7 100644
--- a/stable_audio_tools/training/utils.py
+++ b/stable_audio_tools/training/utils.py
@@ -84,13 +84,20 @@ def create_optimizer_from_config(optimizer_config, parameters):
     """
 
     optimizer_type = optimizer_config["type"]
+    optimizer_backend = optimizer_config.get("backend", "")
 
-    if optimizer_type == "FusedAdam":
-        from deepspeed.ops.adam import FusedAdam
-        optimizer = FusedAdam(parameters, **optimizer_config["config"])
-    else:
-        optimizer_fn = getattr(torch.optim, optimizer_type)
+    if optimizer_backend == "bnb":
+        import bitsandbytes as bnb
+        optimizer_fn = getattr(bnb.optim, optimizer_type)
         optimizer = optimizer_fn(parameters, **optimizer_config["config"])
+    else:
+        if optimizer_type == "FusedAdam":
+            from deepspeed.ops.adam import FusedAdam
+            optimizer = FusedAdam(parameters, **optimizer_config["config"])
+        else:
+            optimizer_fn = getattr(torch.optim, optimizer_type)
+            optimizer = optimizer_fn(parameters, **optimizer_config["config"])
+
     return optimizer
 
 def create_scheduler_from_config(scheduler_config, optimizer):
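For reference, a minimal sketch of how the new `bnb` backend would be selected through `create_optimizer_from_config`. The `"backend"`, `"type"`, and `"config"` keys follow the code above; the `nn.Linear` module and the hyperparameter values are placeholders for illustration only, not recommendations.

```python
import torch.nn as nn

from stable_audio_tools.training.utils import create_optimizer_from_config

# Illustrative optimizer config: "backend", "type", and "config" are the keys
# read by create_optimizer_from_config; the values here are placeholders.
optimizer_config = {
    "backend": "bnb",      # route optimizer creation through bitsandbytes
    "type": "AdamW8bit",   # resolved via getattr(bnb.optim, optimizer_type)
    "config": {
        "lr": 1e-4,
        "betas": (0.9, 0.999),
        "weight_decay": 1e-3,
    },
}

model = nn.Linear(128, 128)  # stand-in for a real model
optimizer = create_optimizer_from_config(optimizer_config, model.parameters())
```

Omitting `backend` (or setting it to `"default"`) falls through to the existing `FusedAdam` / `torch.optim` path, so current configs keep working unchanged.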