[NeMo-UX] Add PEFT (#9490)
* initial commit for PEFT in nemo2

* Apply isort and black reformatting

Signed-off-by: cuichenx <cuichenx@users.noreply.github.com>

* address comments

Signed-off-by: Chen Cui <chcui@nvidia.com>

* make import easier

Signed-off-by: Chen Cui <chcui@nvidia.com>

* Apply isort and black reformatting

Signed-off-by: cuichenx <cuichenx@users.noreply.github.com>

* address comments

Signed-off-by: Chen Cui <chcui@nvidia.com>

* Update nemo/collections/llm/peft/lora.py

Signed-off-by: Marc Romeyn <marcromeyn@gmail.com>

* Some small fixes + adding more doc-strings

* Apply isort and black reformatting

Signed-off-by: marcromeyn <marcromeyn@users.noreply.github.com>

* Adding ModelTransform callback

* Apply isort and black reformatting

Signed-off-by: marcromeyn <marcromeyn@users.noreply.github.com>

* Fixing type-hint for model_transform

* Apply isort and black reformatting

Signed-off-by: marcromeyn <marcromeyn@users.noreply.github.com>

* fix import

Signed-off-by: Chen Cui <chcui@nvidia.com>

* model transform for gemma llama

Signed-off-by: Chen Cui <chcui@nvidia.com>

* Apply isort and black reformatting

Signed-off-by: cuichenx <cuichenx@users.noreply.github.com>

* fix model transform

Signed-off-by: Chen Cui <chcui@nvidia.com>

* Apply isort and black reformatting

Signed-off-by: cuichenx <cuichenx@users.noreply.github.com>

* change lora target default to all linear modules

Signed-off-by: Chen Cui <chcui@nvidia.com>

* Apply isort and black reformatting

Signed-off-by: cuichenx <cuichenx@users.noreply.github.com>

* Small fix in mixtral

* Apply isort and black reformatting

Signed-off-by: marcromeyn <marcromeyn@users.noreply.github.com>

* Integrating PEFT to the public-API + some fixes

* Big refactor to allow to load adapter-states

* Some fixes to support adapter_path

* Apply isort and black reformatting

Signed-off-by: marcromeyn <marcromeyn@users.noreply.github.com>

* Disabling ckpt reloading when adapter_path is passed

* Fix CLI

* Apply isort and black reformatting

Signed-off-by: marcromeyn <marcromeyn@users.noreply.github.com>

* Remove commented-out code

* Remove commented-out code

* Remove un-used import

* Fix callback imports

* Apply isort and black reformatting

Signed-off-by: marcromeyn <marcromeyn@users.noreply.github.com>

* Fixing llm.pretrain

* Some small fixes

* Apply isort and black reformatting

Signed-off-by: marcromeyn <marcromeyn@users.noreply.github.com>

* Fix missing import + type-hint in finetune

* Adding PreemptionCallback + some more tests

* Apply isort and black reformatting

Signed-off-by: marcromeyn <marcromeyn@users.noreply.github.com>

* Clean up imports & clean up llm.api

* Apply isort and black reformatting

Signed-off-by: marcromeyn <marcromeyn@users.noreply.github.com>

* Trying to fix failing tests

* Remove __init__.py 2

* Apply isort and black reformatting

Signed-off-by: marcromeyn <marcromeyn@users.noreply.github.com>

* Fix failing test

* Trying to fix last failing test

---------

Signed-off-by: cuichenx <cuichenx@users.noreply.github.com>
Signed-off-by: Chen Cui <chcui@nvidia.com>
Signed-off-by: Marc Romeyn <marcromeyn@gmail.com>
Signed-off-by: marcromeyn <marcromeyn@users.noreply.github.com>
Co-authored-by: cuichenx <cuichenx@users.noreply.github.com>
Co-authored-by: Marc Romeyn <mromeijn@nvidia.com>
Co-authored-by: marcromeyn <marcromeyn@users.noreply.github.com>
4 people authored Jul 5, 2024
1 parent d862499 commit f89bca0
Showing 33 changed files with 1,434 additions and 186 deletions.
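
This change set adds a `finetune` entry point and a `peft` namespace to `nemo.collections.llm`, wired into training through a new `ModelTransform` callback. A minimal usage sketch, assembled from the docstring examples in the diff below (the model, data module, and parallelism settings are illustrative, not prescriptive):

from nemo import lightning as nl
from nemo.collections import llm

# Model and fine-tuning data module, as in the docstring examples.
model = llm.MistralModel()
data = llm.SquadDataModule(seq_length=4096, global_batch_size=16, micro_batch_size=2)

# Megatron strategy with bf16 mixed precision.
precision = nl.MegatronMixedPrecision(precision="bf16-mixed")
trainer = nl.Trainer(strategy=nl.MegatronStrategy(tensor_model_parallel_size=2), plugins=precision)

# PEFT is applied as a model transform; finetune() defaults to the model's tokenizer.
llm.finetune(model, data, trainer, peft=llm.peft.LoRA())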
6 changes: 4 additions & 2 deletions nemo/collections/llm/__init__.py
@@ -4,8 +4,8 @@
except ImportError:
pass

from nemo.collections.llm import tokenizer
from nemo.collections.llm.api import export_ckpt, import_ckpt, pretrain, train, validate
from nemo.collections.llm import peft, tokenizer
from nemo.collections.llm.api import export_ckpt, finetune, import_ckpt, pretrain, train, validate
from nemo.collections.llm.gpt.data import (
DollyDataModule,
FineTuningDataModule,
@@ -98,6 +98,7 @@
"export_ckpt",
"pretrain",
"validate",
"finetune",
"tokenizer",
"mock",
"squad",
@@ -118,4 +119,5 @@
"gemma_7b",
"code_gemma_2b",
"code_gemma_7b",
"peft",
]
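
With `finetune` and `peft` added to the imports and `__all__` above, both are reachable straight from the collection namespace; for example (a sketch, assuming NeMo is installed with this commit):

from nemo.collections import llm

lora = llm.peft.LoRA()         # PEFT schemes live under the new llm.peft namespace
assert callable(llm.finetune)  # finetune is now part of the public API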
285 changes: 219 additions & 66 deletions nemo/collections/llm/api.py
@@ -1,11 +1,17 @@
from copy import deepcopy
from pathlib import Path
from typing import Callable, Optional
from typing import Any, Callable, Optional, Union

import pytorch_lightning as pl
from typing_extensions import Annotated

from nemo.collections.llm.utils import Config, task
from nemo.lightning import AutoResume, MegatronStrategy, NeMoLogger, OptimizerModule, Trainer, io, teardown
from nemo.lightning import AutoResume, NeMoLogger, OptimizerModule, Trainer, io
from nemo.lightning.pytorch.callbacks import PEFT, ModelTransform
from nemo.utils import logging


TokenizerType = Any


@task(namespace="llm")
@@ -16,7 +22,8 @@ def train(
log: Annotated[Optional[NeMoLogger], Config[NeMoLogger]] = None,
resume: Annotated[Optional[AutoResume], Config[AutoResume]] = None,
optim: Optional[OptimizerModule] = None,
tokenizer: Optional[str] = None,
tokenizer: Optional[TokenizerType] = None,
model_transform: Optional[Union[PEFT, ModelTransform, Callable]] = None,
# TODO: Fix export export: Optional[str] = None,
) -> Path:
"""
@@ -30,42 +37,38 @@ def train(
resume (Optional[Union[AutoResume, Resume]]): Resume training from a checkpoint.
optim (Optional[OptimizerModule]): The optimizer module to be used. If not provided, the default optimizer
from the model will be used.
tokenizer (Optional[str]): Tokenizer setting to be applied. Can be 'data' or 'model'.
tokenizer (Optional[TokenizerType]): Tokenizer setting to be applied. Can be 'data' or 'model' or an instance of TokenizerSpec.
export (Optional[str]): Filename to save the exported checkpoint after training.
model_transform (Optional[Union[Callable[[nn.Module], nn.Module], PEFT]]): A model transform to be applied.
Returns
-------
Path: The directory path where training artifacts are saved.
Raises
------
ValueError: If the trainer's strategy is not MegatronStrategy.
Examples
--------
>>> model = MyModel()
>>> data = MyDataModule()
>>> trainer = Trainer(strategy=MegatronStrategy())
>>> train(model, data, trainer, tokenizer='data', source='path/to/ckpt.ckpt', export='final.ckpt')
>>> from nemo.collections import llm
>>> from nemo import lightning as nl
>>> model = llm.MistralModel()
>>> data = llm.SquadDataModule(seq_length=4096, global_batch_size=16, micro_batch_size=2)
>>> precision = nl.MegatronMixedPrecision(precision="bf16-mixed")
>>> trainer = nl.Trainer(strategy=nl.MegatronStrategy(tensor_model_parallel_size=2), plugins=precision)
>>> train(model, data, trainer, tokenizer="data")
PosixPath('/path/to/log_dir')
"""
_log = log or NeMoLogger()
app_state = _log.setup(
trainer,
resume_if_exists=getattr(resume, "resume_if_exists", False),
task_config=getattr(train, "__io__", None),
app_state = _setup(
model=model,
data=data,
trainer=trainer,
log=log,
resume=resume,
optim=optim,
tokenizer=tokenizer,
model_transform=model_transform,
)
if resume is not None:
resume.setup(model, trainer)
if optim:
optim.connect(model)
if tokenizer: # TODO: Improve this
_use_tokenizer(model, data, tokenizer)

trainer.fit(model, data)

_log.teardown()

return app_state.exp_dir


@@ -74,41 +77,152 @@ def pretrain(
model: pl.LightningModule,
data: pl.LightningDataModule,
trainer: Trainer,
source: Optional[str] = None,
# export: Optional[str] = None
log: Annotated[Optional[NeMoLogger], Config[NeMoLogger]] = None,
resume: Annotated[Optional[AutoResume], Config[AutoResume]] = None,
optim: Optional[OptimizerModule] = None,
) -> Path:
return train(model=model, data=data, trainer=trainer, tokenizer="data", source=source)
"""
Pretrains a model using the specified data and trainer, with optional logging, resuming, and optimization.
This function is a wrapper around the `train` function, specifically configured for pretraining tasks.
Note that, by default, it will use the tokenizer from the data module.
Args:
model (pl.LightningModule): The model to be pretrained.
data (pl.LightningDataModule): The data module containing pretraining data.
trainer (Trainer): The trainer instance configured with a MegatronStrategy.
log (NeMoLogger): A NeMoLogger instance.
resume (Optional[AutoResume]): Resume training from a checkpoint.
optim (Optional[OptimizerModule]): The optimizer module to be used. If not provided, the default
optimizer from the model will be used.
Returns:
Path: The directory path where pretraining artifacts are saved.
Examples:
>>> from nemo.collections import llm
>>> from nemo import lightning as nl
>>> model = llm.MistralModel()
>>> data = llm.PretrainingDataModule(paths=[...], seq_length=4096, global_batch_size=16, micro_batch_size=2)
>>> precision = nl.MegatronMixedPrecision(precision="bf16-mixed")
>>> trainer = nl.Trainer(strategy=nl.MegatronStrategy(tensor_model_parallel_size=2), plugins=precision)
>>> llm.pretrain(model, data, trainer)
PosixPath('/path/to/log_dir')
"""
return train(
model=model,
data=data,
trainer=trainer,
log=log,
resume=resume,
optim=optim,
tokenizer="data",
)


@task(namespace="llm")
def validate(
def finetune(
model: pl.LightningModule,
data: pl.LightningDataModule,
trainer: Trainer,
tokenizer: Optional[str] = None,
source: Optional[str] = None,
export: Optional[str] = None,
log: Annotated[Optional[NeMoLogger], Config[NeMoLogger]] = None,
resume: Annotated[Optional[AutoResume], Config[AutoResume]] = None,
optim: Optional[OptimizerModule] = None,
peft: Optional[Union[PEFT, ModelTransform, Callable]] = None,
) -> Path:
if not isinstance(trainer.strategy, MegatronStrategy):
raise ValueError("Only MegatronStrategy is supported")
"""
Finetunes a model using the specified data and trainer, with optional logging, resuming, and PEFT.
validate_kwargs = {}
run_dir = Path(trainer.logger.log_dir)
export_dir = run_dir / "export"
Note that, by default, it will use the tokenizer from the model.
if tokenizer: # TODO: Improve this
_use_tokenizer(model, data, tokenizer)
if source:
_add_ckpt_path(source, model, validate_kwargs)
Args:
model (pl.LightningModule): The model to be finetuned.
data (pl.LightningDataModule): The data module containing finetuning data.
trainer (Trainer): The trainer instance configured with a MegatronStrategy.
log (NeMoLogger): A NeMoLogger instance.
resume (Optional[AutoResume]): Resume training from a checkpoint.
optim (Optional[OptimizerModule]): The optimizer module to be used. If not provided, the default
optimizer from the model will be used.
peft (Optional[PEFT]): A PEFT (Parameter-Efficient Fine-Tuning) configuration to be applied.
Returns:
Path: The directory path where finetuning artifacts are saved.
Examples:
>>> from nemo.collections import llm
>>> from nemo import lightning as nl
>>> model = llm.MistralModel()
>>> data = llm.SquadDataModule(seq_length=4096, global_batch_size=16, micro_batch_size=2)
>>> precision = nl.MegatronMixedPrecision(precision="bf16-mixed")
>>> trainer = nl.Trainer(strategy=nl.MegatronStrategy(tensor_model_parallel_size=2), plugins=precision)
>>> finetune(model, data, trainer, peft=llm.peft.LoRA())
PosixPath('/path/to/log_dir')
"""

trainer.validate(model, data, **validate_kwargs)
trainer.save_checkpoint(export_dir)
if export:
teardown(trainer)
del trainer, model, data
export_ckpt(export_dir, export)
return train(
model=model,
data=data,
trainer=trainer,
log=log,
resume=resume,
optim=optim,
tokenizer="model",
model_transform=peft,
)

return run_dir

@task(namespace="llm")
def validate(
model: pl.LightningModule,
data: pl.LightningDataModule,
trainer: Trainer,
log: Annotated[Optional[NeMoLogger], Config[NeMoLogger]] = None,
resume: Annotated[Optional[AutoResume], Config[AutoResume]] = None,
optim: Optional[OptimizerModule] = None,
tokenizer: Optional[TokenizerType] = None,
model_transform: Optional[Union[PEFT, ModelTransform, Callable]] = None,
) -> Path:
"""
Validates a model using the specified data and trainer, with optional logging, resuming, and model transformations.
Args:
model (pl.LightningModule): The model to be validated.
data (pl.LightningDataModule): The data module containing validation data.
trainer (Trainer): The trainer instance configured with a MegatronStrategy.
log (NeMoLogger): A NeMoLogger instance.
resume (Optional[AutoResume]): Resume from a checkpoint for validation.
optim (Optional[OptimizerModule]): The optimizer module to be used. If not provided, the default optimizer
from the model will be used.
tokenizer (Optional[TokenizerType]): Tokenizer setting to be applied. Can be 'data' or 'model' or an instance of TokenizerSpec.
model_transform (Optional[Union[Callable[[nn.Module], nn.Module], PEFT]]): A model transform to be applied.
Returns:
Path: The directory path where validation artifacts are saved.
Examples:
>>> from nemo.collections import llm
>>> from nemo import lightning as nl
>>> model = llm.MistralModel()
>>> data = llm.SquadDataModule(seq_length=4096, global_batch_size=16, micro_batch_size=2)
>>> precision = nl.MegatronMixedPrecision(precision="bf16-mixed")
>>> trainer = nl.Trainer(strategy=nl.MegatronStrategy(tensor_model_parallel_size=2), plugins=precision)
>>> validate(model, data, trainer, tokenizer="data")
PosixPath('/path/to/log_dir')
"""
app_state = _setup(
model=model,
data=data,
trainer=trainer,
log=log,
resume=resume,
optim=optim,
tokenizer=tokenizer,
model_transform=model_transform,
)

trainer.validate(model, data)

return app_state.exp_dir


@task(name="import", namespace="llm")
@@ -136,28 +250,67 @@ def export_ckpt(
return io.export_ckpt(path, target, output_path, overwrite, load_connector)


def _use_tokenizer(model: pl.LightningModule, data: pl.LightningDataModule, tokenizer: str) -> None:
def _use_tokenizer(model: pl.LightningModule, data: pl.LightningDataModule, tokenizer: TokenizerType) -> None:
if tokenizer == "data":
model.tokenizer = data.tokenizer
if hasattr(model, "__io__"):
model.__io__.tokenizer = data.tokenizer
_set_with_io(model, "tokenizer", data.tokenizer)
elif tokenizer == "model":
data.tokenizer = model.tokenizer
if hasattr(data, "__io__"):
data.__io__.tokenizer = model.tokenizer
_set_with_io(data, "tokenizer", model.tokenizer)
else:
try:
from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec

if isinstance(tokenizer, TokenizerSpec):
_set_with_io(model, "tokenizer", tokenizer)
_set_with_io(data, "tokenizer", tokenizer)
else:
raise ValueError(f"Expected TokenizerSpec or 'data' or 'model', got: {tokenizer}")
except ImportError:
raise ValueError("TokenizerSpec is not available")

def _add_ckpt_path(source, model, kwargs) -> None:
if io.is_distributed_ckpt(source):
kwargs["ckpt_path"] = source
else:
kwargs["ckpt_path"] = model.import_ckpt(source)

def _setup(
model: pl.LightningModule,
data: pl.LightningDataModule,
trainer: Trainer,
log: Optional[NeMoLogger],
resume: Optional[AutoResume],
optim: Optional[OptimizerModule],
tokenizer: Optional[TokenizerType],
model_transform: Optional[Union[PEFT, ModelTransform, Callable]],
) -> Any: # Return type is Any because app_state's type is not specified
_log = log or NeMoLogger()
if resume and resume.adapter_path and _log.ckpt:
logging.info("Disabling try_restore_best_ckpt restoration for adapters")
_log.ckpt.try_restore_best_ckpt = False

app_state = _log.setup(
trainer,
resume_if_exists=getattr(resume, "resume_if_exists", False),
task_config=getattr(train, "__io__", None),
)
if resume is not None:
resume.setup(model, trainer)

if optim:
optim.connect(model)
if tokenizer: # TODO: Improve this
_use_tokenizer(model, data, tokenizer)

if model_transform:
_set_with_io(model, "model_transform", model_transform)

# Add ModelTransform callback to Trainer if needed
if getattr(model, "model_transform", None):
if not any(isinstance(cb, ModelTransform) for cb in trainer.callbacks):
if isinstance(model_transform, ModelTransform):
trainer.callbacks.append(model_transform)
else:
trainer.callbacks.append(ModelTransform())

return app_state

def _save_config_img(*args, **kwargs):
try:
from nemo_sdk.utils import save_config_img

save_config_img(*args, **kwargs)
except ImportError:
pass
def _set_with_io(obj, attr, value):
setattr(obj, attr, value)
if hasattr(obj, "__io__") and hasattr(value, "__io__"):
setattr(obj.__io__, attr, deepcopy(value.__io__))
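
`_set_with_io` keeps a live object and its serialized `__io__` configuration in sync, so that changes such as swapping a tokenizer also land in the config captured for reproducibility. A toy illustration of the pattern (the `SimpleNamespace` stand-ins are hypothetical, not NeMo classes):

from copy import deepcopy
from types import SimpleNamespace

def set_with_io(obj, attr, value):
    # Set the attribute on the live object...
    setattr(obj, attr, value)
    # ...and, if both sides carry an __io__ config, mirror a deep copy of the
    # value's config onto the object's config.
    if hasattr(obj, "__io__") and hasattr(value, "__io__"):
        setattr(obj.__io__, attr, deepcopy(value.__io__))

model = SimpleNamespace(__io__=SimpleNamespace())                 # stand-in for a module with an io config
tokenizer = SimpleNamespace(__io__=SimpleNamespace(name="demo"))  # stand-in tokenizer
set_with_io(model, "tokenizer", tokenizer)

assert model.tokenizer is tokenizer                    # the live attribute points at the real object
assert model.__io__.tokenizer is not tokenizer.__io__  # the config holds an independent deep copy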
(Diffs for the remaining 31 changed files are not shown here.)
