Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[NeMo-UX] Add PEFT #9490

Merged
merged 48 commits into from
Jul 5, 2024
Merged
Show file tree
Hide file tree
Changes from 47 commits
Commits
Show all changes
48 commits
Select commit Hold shift + click to select a range
a140dcf
initial commit for PEFT in nemo2
cuichenx Jun 18, 2024
a62fd06
Apply isort and black reformatting
cuichenx Jun 18, 2024
76b23be
address comments
cuichenx Jun 25, 2024
c87896b
make import easier
cuichenx Jun 25, 2024
e356a1f
Apply isort and black reformatting
cuichenx Jun 25, 2024
0d649de
address comments
cuichenx Jun 25, 2024
6b98061
Update nemo/collections/llm/peft/lora.py
marcromeyn Jun 25, 2024
76f3a48
Some small fixes + adding more doc-strings
marcromeyn Jun 25, 2024
1b92d15
Apply isort and black reformatting
marcromeyn Jun 25, 2024
21dc1c5
Adding ModelTransform callback
marcromeyn Jun 25, 2024
6666535
Apply isort and black reformatting
marcromeyn Jun 25, 2024
38cb190
Fixing type-hint for model_transform
marcromeyn Jun 25, 2024
3d8351f
Apply isort and black reformatting
marcromeyn Jun 25, 2024
741c998
fix import
cuichenx Jun 25, 2024
ac4eec1
model transform for gemma llama
cuichenx Jun 25, 2024
eadf9fe
Apply isort and black reformatting
cuichenx Jun 25, 2024
3228035
fix model transform
cuichenx Jun 27, 2024
21a2b10
Apply isort and black reformatting
cuichenx Jun 27, 2024
85bfad6
change lora target default to all linear modules
cuichenx Jun 27, 2024
b64789b
Apply isort and black reformatting
cuichenx Jun 27, 2024
077469e
Small fix in mixtral
marcromeyn Jul 3, 2024
7cd790a
Apply isort and black reformatting
marcromeyn Jul 3, 2024
e083e3c
Integrating PEFT to the public-API + some fixes
marcromeyn Jul 3, 2024
3df37df
Big refactor to allow to load adapter-states
marcromeyn Jul 3, 2024
ef76686
Some fixes to support adapter_path
marcromeyn Jul 3, 2024
908feed
Apply isort and black reformatting
marcromeyn Jul 3, 2024
6c7be3d
Disabling ckpt reloading when adapter_path is passed
marcromeyn Jul 3, 2024
f14a3c5
Fix CLI
marcromeyn Jul 3, 2024
cdabd75
Apply isort and black reformatting
marcromeyn Jul 3, 2024
08fb761
Remove commented-out code
marcromeyn Jul 3, 2024
7d96473
Remove commented-out code
marcromeyn Jul 3, 2024
454c23f
Remove un-used import
marcromeyn Jul 3, 2024
0e43371
Fix callback imports
marcromeyn Jul 4, 2024
aa711bd
Apply isort and black reformatting
marcromeyn Jul 4, 2024
baca872
Fixing llm.pretrain
marcromeyn Jul 4, 2024
2d12a84
Some small fixes
marcromeyn Jul 4, 2024
553b3f3
Apply isort and black reformatting
marcromeyn Jul 4, 2024
45690f3
Fix missing import + type-hint in finetune
marcromeyn Jul 4, 2024
835cdd3
Adding PreemptionCallback + some more tests
marcromeyn Jul 4, 2024
e4d809e
Apply isort and black reformatting
marcromeyn Jul 4, 2024
30b967b
Clean up imports & clean up llm.api
marcromeyn Jul 4, 2024
2617a15
Apply isort and black reformatting
marcromeyn Jul 4, 2024
2e9857e
Trying to fix failing tests
marcromeyn Jul 4, 2024
c80e7b7
Merge branch 'main' into nemo-ux/peft
marcromeyn Jul 4, 2024
d541a47
Remove __init__.py 2
marcromeyn Jul 4, 2024
7a153d3
Apply isort and black reformatting
marcromeyn Jul 4, 2024
37b5b73
Fix failing test
marcromeyn Jul 4, 2024
eb72dad
Trying to fix last failing test
marcromeyn Jul 4, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions nemo/collections/llm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
except ImportError:
pass

from nemo.collections.llm import tokenizer
from nemo.collections.llm.api import export_ckpt, import_ckpt, pretrain, train, validate
from nemo.collections.llm import peft, tokenizer
from nemo.collections.llm.api import export_ckpt, finetune, import_ckpt, pretrain, train, validate
from nemo.collections.llm.gpt.data import (
DollyDataModule,
FineTuningDataModule,
Expand Down Expand Up @@ -98,6 +98,7 @@
"export_ckpt",
"pretrain",
"validate",
"finetune",
"tokenizer",
"mock",
"squad",
Expand All @@ -118,4 +119,5 @@
"gemma_7b",
"code_gemma_2b",
"code_gemma_7b",
"peft",
]
285 changes: 219 additions & 66 deletions nemo/collections/llm/api.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
from copy import deepcopy
from pathlib import Path
from typing import Callable, Optional
from typing import Any, Callable, Optional, Union

import pytorch_lightning as pl
from typing_extensions import Annotated

from nemo.collections.llm.utils import Config, task
from nemo.lightning import AutoResume, MegatronStrategy, NeMoLogger, OptimizerModule, Trainer, io, teardown
from nemo.lightning import AutoResume, NeMoLogger, OptimizerModule, Trainer, io
from nemo.lightning.pytorch.callbacks import PEFT, ModelTransform
from nemo.utils import logging


TokenizerType = Any


@task(namespace="llm")
Expand All @@ -16,7 +22,8 @@ def train(
log: Annotated[Optional[NeMoLogger], Config[NeMoLogger]] = None,
resume: Annotated[Optional[AutoResume], Config[AutoResume]] = None,
optim: Optional[OptimizerModule] = None,
tokenizer: Optional[str] = None,
tokenizer: Optional[TokenizerType] = None,
model_transform: Optional[Union[PEFT, ModelTransform, Callable]] = None,
# TODO: Fix export export: Optional[str] = None,
) -> Path:
"""
Expand All @@ -30,42 +37,38 @@ def train(
resume (Optional[Union[AutoResume, Resume]]): Resume training from a checkpoint.
optim (Optional[OptimizerModule]): The optimizer module to be used. If not provided, the default optimizer
from the model will be used.
tokenizer (Optional[str]): Tokenizer setting to be applied. Can be 'data' or 'model'.
tokenizer (Optional[TokenizerType]): Tokenizer setting to be applied. Can be 'data' or 'model' or an instance of TokenizerSpec.
export (Optional[str]): Filename to save the exported checkpoint after training.
model_transform (Optional[Union[Callable[[nn.Module], nn.Module], PEFT]]): A model transform to be applied.

Returns
-------
Path: The directory path where training artifacts are saved.

Raises
------
ValueError: If the trainer's strategy is not MegatronStrategy.

Examples
--------
>>> model = MyModel()
>>> data = MyDataModule()
>>> trainer = Trainer(strategy=MegatronStrategy())
>>> train(model, data, trainer, tokenizer='data', source='path/to/ckpt.ckpt', export='final.ckpt')
>>> from nemo.collections import llm
>>> from nemo import lightning as nl
>>> model = llm.MistralModel()
>>> data = llm.SquadDataModule(seq_length=4096, global_batch_size=16, micro_batch_size=2)
>>> precision = nl.MegatronMixedPrecision(precision="bf16-mixed")
>>> trainer = nl.Trainer(strategy=nl.MegatronStrategy(tensor_model_parallel_size=2), plugins=precision)
>>> train(model, data, trainer, tokenizer="data")
PosixPath('/path/to/log_dir')
"""
_log = log or NeMoLogger()
app_state = _log.setup(
trainer,
resume_if_exists=getattr(resume, "resume_if_exists", False),
task_config=getattr(train, "__io__", None),
app_state = _setup(
model=model,
data=data,
trainer=trainer,
log=log,
resume=resume,
optim=optim,
tokenizer=tokenizer,
model_transform=model_transform,
)
if resume is not None:
resume.setup(model, trainer)
if optim:
optim.connect(model)
if tokenizer: # TODO: Improve this
_use_tokenizer(model, data, tokenizer)

trainer.fit(model, data)

_log.teardown()

return app_state.exp_dir


Expand All @@ -74,41 +77,152 @@ def pretrain(
model: pl.LightningModule,
data: pl.LightningDataModule,
trainer: Trainer,
source: Optional[str] = None,
# export: Optional[str] = None
log: Annotated[Optional[NeMoLogger], Config[NeMoLogger]] = None,
resume: Annotated[Optional[AutoResume], Config[AutoResume]] = None,
optim: Optional[OptimizerModule] = None,
) -> Path:
return train(model=model, data=data, trainer=trainer, tokenizer="data", source=source)
"""
Pretrains a model using the specified data and trainer, with optional logging, resuming, and optimization.

This function is a wrapper around the `train` function, specifically configured for pretraining tasks.
Note, by default it will use the tokenizer from the model.

Args:
model (pl.LightningModule): The model to be pretrained.
data (pl.LightningDataModule): The data module containing pretraining data.
trainer (Trainer): The trainer instance configured with a MegatronStrategy.
log (NeMoLogger): A nemologger instance.
resume (Optional[AutoResume]): Resume training from a checkpoint.
optim (Optional[OptimizerModule]): The optimizer module to be used. If not provided, the default
optimizer from the model will be used.

Returns:
Path: The directory path where pretraining artifacts are saved.

Examples:
>>> from nemo.collections import llm
>>> from nemo import lightning as nl
>>> model = llm.MistralModel()
>>> data = llm.PretrainingDataModule(paths=[...], seq_length=4096, global_batch_size=16, micro_batch_size=2)
>>> precision = nl.MegatronMixedPrecision(precision="bf16-mixed")
>>> trainer = nl.Trainer(strategy=nl.MegatronStrategy(tensor_model_parallel_size=2), plugins=precision)
>>> llm.pretrain(model, data, trainer)
PosixPath('/path/to/log_dir')
"""
return train(
model=model,
data=data,
trainer=trainer,
log=log,
resume=resume,
optim=optim,
tokenizer="data",
)


@task(namespace="llm")
def validate(
def finetune(
model: pl.LightningModule,
data: pl.LightningDataModule,
trainer: Trainer,
tokenizer: Optional[str] = None,
source: Optional[str] = None,
export: Optional[str] = None,
log: Annotated[Optional[NeMoLogger], Config[NeMoLogger]] = None,
resume: Annotated[Optional[AutoResume], Config[AutoResume]] = None,
optim: Optional[OptimizerModule] = None,
peft: Optional[Union[PEFT, ModelTransform, Callable]] = None,
) -> Path:
if not isinstance(trainer.strategy, MegatronStrategy):
raise ValueError("Only MegatronStrategy is supported")
"""
Finetunes a model using the specified data and trainer, with optional logging, resuming, and PEFT.

validate_kwargs = {}
run_dir = Path(trainer.logger.log_dir)
export_dir = run_dir / "export"
Note, by default it will use the tokenizer from the model.

if tokenizer: # TODO: Improve this
_use_tokenizer(model, data, tokenizer)
if source:
_add_ckpt_path(source, model, validate_kwargs)
Args:
model (pl.LightningModule): The model to be finetuned.
data (pl.LightningDataModule): The data module containing finetuning data.
trainer (Trainer): The trainer instance configured with a MegatronStrategy.
log (NeMoLogger): A nemologger instance.
resume (Optional[AutoResume]): Resume training from a checkpoint.
optim (Optional[OptimizerModule]): The optimizer module to be used. If not provided, the default
optimizer from the model will be used.
peft (Optional[PEFT]): A PEFT (Parameter-Efficient Fine-Tuning) configuration to be applied.

Returns:
Path: The directory path where finetuning artifacts are saved.

Examples:
>>> from nemo.collections import llm
>>> from nemo import lightning as nl
>>> model = llm.MistralModel()
>>> data = llm.SquadDataModule(seq_length=4096, global_batch_size=16, micro_batch_size=2)
>>> precision = nl.MegatronMixedPrecision(precision="bf16-mixed")
>>> trainer = nl.Trainer(strategy=nl.MegatronStrategy(tensor_model_parallel_size=2), plugins=precision)
>>> finetune(model, data, trainer, peft=llm.peft.LoRA()])
PosixPath('/path/to/log_dir')
"""

trainer.validate(model, data, **validate_kwargs)
trainer.save_checkpoint(export_dir)
if export:
teardown(trainer)
del trainer, model, data
export_ckpt(export_dir, export)
return train(
model=model,
data=data,
trainer=trainer,
log=log,
resume=resume,
optim=optim,
tokenizer="model",
model_transform=peft,
)

return run_dir

@task(namespace="llm")
def validate(
model: pl.LightningModule,
data: pl.LightningDataModule,
trainer: Trainer,
log: Annotated[Optional[NeMoLogger], Config[NeMoLogger]] = None,
resume: Annotated[Optional[AutoResume], Config[AutoResume]] = None,
optim: Optional[OptimizerModule] = None,
tokenizer: Optional[TokenizerType] = None,
model_transform: Optional[Union[PEFT, ModelTransform, Callable]] = None,
) -> Path:
"""
Validates a model using the specified data and trainer, with optional logging, resuming, and model transformations.

Args:
model (pl.LightningModule): The model to be validated.
data (pl.LightningDataModule): The data module containing validation data.
trainer (Trainer): The trainer instance configured with a MegatronStrategy.
log (NeMoLogger): A nemologger instance.
resume (Optional[AutoResume]): Resume from a checkpoint for validation.
optim (Optional[OptimizerModule]): The optimizer module to be used. If not provided, the default optimizer
from the model will be used.
tokenizer (Optional[TokenizerType]): Tokenizer setting to be applied. Can be 'data' or 'model' or an instance of TokenizerSpec.
model_transform (Optional[Union[Callable[[nn.Module], nn.Module], PEFT]]): A model transform to be applied.

Returns:
Path: The directory path where validation artifacts are saved.

Examples:
>>> from nemo.collections import llm
>>> from nemo import lightning as nl
>>> model = llm.MistralModel()
>>> data = llm.SquadDataModule(seq_length=4096, global_batch_size=16, micro_batch_size=2)
>>> precision = nl.MegatronMixedPrecision(precision="bf16-mixed")
>>> trainer = nl.Trainer(strategy=nl.MegatronStrategy(tensor_model_parallel_size=2), plugins=precision)
>>> validate(model, data, trainer, tokenizer="data")
PosixPath('/path/to/log_dir')
"""
app_state = _setup(
model=model,
data=data,
trainer=trainer,
log=log,
resume=resume,
optim=optim,
tokenizer=tokenizer,
model_transform=model_transform,
)

trainer.validate(model, data)

return app_state.exp_dir


@task(name="import", namespace="llm")
Expand Down Expand Up @@ -136,28 +250,67 @@ def export_ckpt(
return io.export_ckpt(path, target, output_path, overwrite, load_connector)


def _use_tokenizer(model: pl.LightningModule, data: pl.LightningDataModule, tokenizer: str) -> None:
def _use_tokenizer(model: pl.LightningModule, data: pl.LightningDataModule, tokenizer: TokenizerType) -> None:
if tokenizer == "data":
model.tokenizer = data.tokenizer
if hasattr(model, "__io__"):
model.__io__.tokenizer = data.tokenizer
_set_with_io(model, "tokenizer", data.tokenizer)
elif tokenizer == "model":
data.tokenizer = model.tokenizer
if hasattr(data, "__io__"):
data.__io__.tokenizer = model.tokenizer
_set_with_io(data, "tokenizer", model.tokenizer)
else:
try:
from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec

if isinstance(tokenizer, TokenizerSpec):
_set_with_io(model, "tokenizer", tokenizer)
_set_with_io(data, "tokenizer", tokenizer)
else:
raise ValueError(f"Expected TokenizerSpec or 'data' or 'model', got: {tokenizer}")
except ImportError:
raise ValueError("TokenizerSpec is not available")

def _add_ckpt_path(source, model, kwargs) -> None:
if io.is_distributed_ckpt(source):
kwargs["ckpt_path"] = source
else:
kwargs["ckpt_path"] = model.import_ckpt(source)

def _setup(
model: pl.LightningModule,
data: pl.LightningDataModule,
trainer: Trainer,
log: Optional[NeMoLogger],
resume: Optional[AutoResume],
optim: Optional[OptimizerModule],
tokenizer: Optional[TokenizerType],
model_transform: Optional[Union[PEFT, ModelTransform, Callable]],
) -> Any: # Return type is Any because app_state's type is not specified
_log = log or NeMoLogger()
if resume and resume.adapter_path and _log.ckpt:
logging.info("Disabling try_restore_best_ckpt restoration for adapters")
_log.ckpt.try_restore_best_ckpt = False

app_state = _log.setup(
trainer,
resume_if_exists=getattr(resume, "resume_if_exists", False),
task_config=getattr(train, "__io__", None),
)
if resume is not None:
resume.setup(model, trainer)

if optim:
optim.connect(model)
if tokenizer: # TODO: Improve this
_use_tokenizer(model, data, tokenizer)

if model_transform:
_set_with_io(model, "model_transform", model_transform)

# Add ModelTransform callback to Trainer if needed
if getattr(model, "model_transform", None):
if not any(isinstance(cb, ModelTransform) for cb in trainer.callbacks):
if isinstance(model_transform, ModelTransform):
trainer.callbacks.append(model_transform)
else:
trainer.callbacks.append(ModelTransform())

return app_state

def _save_config_img(*args, **kwargs):
try:
from nemo_sdk.utils import save_config_img

save_config_img(*args, **kwargs)
except ImportError:
pass
def _set_with_io(obj, attr, value):
setattr(obj, attr, value)
if hasattr(obj, "__io__") and hasattr(value, "__io__"):
setattr(obj.__io__, attr, deepcopy(value.__io__))
Loading
Loading