Switch to the Composer integration of LoRA (works with FSDP) #886

Merged 75 commits into main from composer_lora on Feb 2, 2024
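For context on what the switch enables: the PR replaces LLM Foundry's custom LoRA wiring with PEFT models created through Composer's HuggingFace integration, so the adapters are ordinary module parameters that compose with FSDP sharding. A minimal sketch of the underlying PEFT/LoRA setup; the model name, target modules, and hyperparameters below are illustrative stand-ins, not taken from this PR:

# Minimal LoRA sketch, assuming the `peft` and `transformers` packages;
# 'gpt2' and all hyperparameters are illustrative.
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

base = AutoModelForCausalLM.from_pretrained('gpt2')
lora_config = LoraConfig(
    r=16,                        # rank of the low-rank update matrices
    lora_alpha=32,               # scaling applied to the adapter output
    lora_dropout=0.05,
    target_modules=['c_attn'],   # GPT-2's fused attention projection
    task_type='CAUSAL_LM',
)
model = get_peft_model(base, lora_config)
model.print_trainable_parameters()  # only adapter weights require grad

Because the wrapped model is a regular torch module with only the adapter parameters trainable, the same object can be handed to Composer's Trainer and sharded with FSDP, which is the property the PR title highlights.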
Changes from all commits. Commits (75, all authored by dakinggg):
b409125  switch up (Dec 19, 2023)
471328d  propagate (Dec 19, 2023)
7feca49  fix (Dec 19, 2023)
da09cfe  fix (Dec 22, 2023)
0ec48ae  fix (Dec 22, 2023)
c3f94cc  fix (Dec 22, 2023)
9bca493  Merge branch 'main' into composer_lora (Dec 22, 2023)
b743767  fix (Dec 22, 2023)
a05fcf1  fix (Dec 22, 2023)
4482e76  wip test (Dec 24, 2023)
bcd45cc  quick adapt ckpter for peft (Jan 5, 2024)
5f4ebd9  default adapter (Jan 5, 2024)
8f21297  convert (Jan 5, 2024)
935bce8  remove args overriding (Jan 18, 2024)
beaa86c  Merge branch 'main' into composer_lora (Jan 19, 2024)
03a1c57  wip (Jan 19, 2024)
7cb401b  Merge branch 'main' into composer_lora (Jan 22, 2024)
0757eab  wip (Jan 23, 2024)
b0acf3c  temp (Jan 23, 2024)
8d6366e  temp (Jan 23, 2024)
f1c7e0b  Merge branch 'main' into composer_lora (Jan 24, 2024)
5bb5ec1  temp version allowance (Jan 24, 2024)
7d65980  fix (Jan 25, 2024)
47a5d6b  fix (Jan 25, 2024)
787fb98  fix (Jan 25, 2024)
0cce9e8  fix torch pin (Jan 25, 2024)
fde45d6  clean up config parsing (Jan 25, 2024)
141ed38  fixing up (Jan 25, 2024)
300f366  add better test (Jan 25, 2024)
d665a85  fix (Jan 25, 2024)
8bef086  fix test (Jan 25, 2024)
83eecd2  precommit (Jan 25, 2024)
d4465ab  remove commented out code (Jan 25, 2024)
038d3a9  new test (Jan 25, 2024)
7ff2451  wt tests (Jan 25, 2024)
f060a20  more test (Jan 25, 2024)
2165a3d  more test (Jan 25, 2024)
cfbae1f  precommit (Jan 25, 2024)
25e2278  Merge branch 'main' into composer_lora (Jan 27, 2024)
2bac9b9  clean up peft checking (Jan 27, 2024)
9b0d5b3  precommit (Jan 27, 2024)
cbdd80f  Merge branch 'main' into composer_lora (Jan 27, 2024)
a2a26e0  fix peft deps (Jan 27, 2024)
d8fdafe  first pr comments (Jan 29, 2024)
abdc3d9  remove init return (Jan 29, 2024)
e8e60e9  explicit save peft only (Jan 30, 2024)
e3f84cd  more pr comments (Jan 30, 2024)
b04b4e1  pull out underlying model (Jan 30, 2024)
fbc600f  pull out underlying model (Jan 30, 2024)
869ebe9  precommit (Jan 30, 2024)
3c7f87a  update docstring (Jan 30, 2024)
1238971  add more docs (Jan 30, 2024)
5b39238  pull out peft config stuff (Jan 30, 2024)
164063b  refactor init (Jan 30, 2024)
8f172e4  merge (Jan 30, 2024)
7c91255  Merge branch 'main' into composer_lora (Jan 30, 2024)
35481c7  fix duplicate tuple (Jan 30, 2024)
3535e00  first try (Jan 31, 2024)
ebc9541  precommit (Jan 31, 2024)
f16bb64  fix (Jan 31, 2024)
4f8a258  fix duplicate tuple (Jan 30, 2024)
45ee0e9  change comment (Feb 1, 2024)
f9e0cff  adjust (Feb 1, 2024)
24b4ecc  fix (Feb 1, 2024)
86d76de  add trainable params (Feb 1, 2024)
f15209c  Merge pull request #1 from dakinggg/save-peft-mlflow (Feb 1, 2024)
b2a0b0e  precommit (Feb 1, 2024)
8ae469c  merge (Feb 1, 2024)
cfeeb11  bump (Feb 2, 2024)
c8d1f6b  fix tests (Feb 2, 2024)
df17f5b  precommit (Feb 2, 2024)
7f32684  Merge branch 'main' into composer_lora (Feb 2, 2024)
6cfc408  fix (Feb 2, 2024)
5fb9587  pr comments (Feb 2, 2024)
b9f1d20  Merge branch 'main' into composer_lora (Feb 2, 2024)
52 changes: 41 additions & 11 deletions llmfoundry/callbacks/hf_checkpointer.py
@@ -9,7 +9,7 @@
 import re
 import tempfile
 from pathlib import Path
-from typing import Optional, Sequence, Union
+from typing import Any, Dict, Optional, Sequence, Union
 
 import torch
 from composer.core import Callback, Event, State, Time, TimeUnit
@@ -203,14 +203,17 @@ def _save_checkpoint(self, state: State, logger: Logger):
         from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 
         if state.is_model_ddp:
+            composer_model = state.model.module
             original_model: PreTrainedModel = state.model.module.model
             state_dict_model = state.model.module.model
             original_tokenizer = state.model.module.tokenizer
         elif isinstance(state.model.model, FSDP):
+            composer_model = state.model
             original_model: PreTrainedModel = state.model.model.module
             state_dict_model = state.model.model
             original_tokenizer = state.model.tokenizer
         else:
+            composer_model = state.model
             original_model: PreTrainedModel = state.model.model
             state_dict_model = state.model.model
             original_tokenizer = state.model.tokenizer
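The three branches above exist because the underlying transformers model is nested differently under each parallelism strategy; both DistributedDataParallel and FSDP expose the wrapped module as .module. A generic helper expressing the same idea, as an illustrative sketch only (the callback keeps the explicit branches so it can also pick out the tokenizer and the Composer model):

# Illustrative: peel distributed wrappers (DDP, FSDP) until the bare model
# remains. The real callback branches explicitly instead.
import torch.nn as nn

def unwrap_model(module: nn.Module) -> nn.Module:
    while hasattr(module, 'module') and isinstance(module.module, nn.Module):
        module = module.module
    return module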
@@ -237,10 +240,25 @@ def _save_checkpoint(self, state: State, logger: Logger):
             copied_config.init_device = 'cpu'
 
             log.debug(f'Creating new model instance')
-            # First create the model instance on meta device to avoid the
-            # initialization cost.
-            with init_empty_weights():
-                new_model_instance = type(original_model)(copied_config)
+
+            if composer_model.using_peft:
+                # We don't use meta here because the state dict does not contain the full
+                # model, only the adapter weights.
+                active_adapter = original_model.active_adapter
+                base_model = original_model.get_base_model()
+                new_base_model_instance = type(base_model)(copied_config)
+
+                new_model_instance = type(original_model)(
+                    new_base_model_instance,
+                    original_model.peft_config[active_adapter])
+            else:
+                # First create the model instance on meta device to avoid the
+                # initialization cost.
+                with init_empty_weights():
+                    new_model_instance = type(original_model)(copied_config)
 
+            new_model_instance.to(dtype=self.dtype)
+            new_model_instance.load_state_dict(state_dict)
+
             # Then load the state dict in with "assign" so that the state dict
             # is loaded properly even though the model is initially on meta device.
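The new branch rebuilds the PEFT wrapper around a freshly constructed base model rather than a meta-device shell, because the gathered state dict holds only the adapter weights and the base parameters would otherwise never be materialized. A standalone sketch of that rebuild; class names follow the peft package, while 'gpt2' and the adapter settings are illustrative:

# Standalone sketch of the PEFT rebuild path, assuming `peft` and
# `transformers`; mirrors type(original_model)(new_base_model_instance,
# peft_config) in the diff above.
from peft import LoraConfig, PeftModel
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained('gpt2')            # stands in for copied_config
base_model = AutoModelForCausalLM.from_config(config)  # materialized, not meta
peft_config = LoraConfig(task_type='CAUSAL_LM', r=16, target_modules=['c_attn'])

# Wrap the fresh base model with the stored adapter config; only the adapter
# tensors then need to be loaded from the checkpoint's state dict.
new_model = PeftModel(base_model, peft_config)

The else branch keeps the old fast path: accelerate's init_empty_weights builds the architecture on the meta device at no initialization cost, and the subsequent state-dict load hands real tensors over in place of the meta ones.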
@@ -295,12 +313,24 @@ def _save_checkpoint(self, state: State, logger: Logger):
                 # TODO: Remove after mlflow fixes the bug that makes this necessary
                 import mlflow
                 mlflow.store._unity_catalog.registry.rest_store.get_feature_dependencies = lambda *args, **kwargs: ''
-                mlflow_logger.save_model(
-                    flavor='transformers',
-                    transformers_model=components,
-                    path=local_save_path,
-                    **self.mlflow_logging_config,
-                )
+                model_saving_kwargs: Dict[str, Any] = {
+                    'path': local_save_path
+                }
+                if composer_model.using_peft:
+                    model_saving_kwargs['flavor'] = 'peft'
+                    model_saving_kwargs[
+                        'save_pretrained_dir'] = temp_save_dir
+                    model_saving_kwargs[
+                        'metadata'] = self.mlflow_logging_config[
+                            'metadata']
+                else:
+                    model_saving_kwargs['flavor'] = 'transformers'
+                    model_saving_kwargs[
+                        'transformers_model'] = components
+                    model_saving_kwargs.update(
+                        self.mlflow_logging_config)
+
+                mlflow_logger.save_model(**model_saving_kwargs)
 
                 license_filename = _maybe_get_license_filename(
                     local_save_path)
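Pulled out of the callback, the flavor dispatch above reads as follows. This is a sketch for clarity: the function name is ours, and 'peft'-flavor support in Composer's MLFlowLogger.save_model is assumed to arrive with the Composer-side half of this change (see the save-peft-mlflow merge in the commit list).

from typing import Any, Dict

def build_model_saving_kwargs(
    using_peft: bool,
    local_save_path: str,
    temp_save_dir: str,
    components: Dict[str, Any],
    mlflow_logging_config: Dict[str, Any],
) -> Dict[str, Any]:
    # PEFT checkpoints are logged from the directory save_pretrained already
    # wrote, so only the metadata portion of the logging config is forwarded;
    # the transformers flavor logs in-memory components and takes the full config.
    kwargs: Dict[str, Any] = {'path': local_save_path}
    if using_peft:
        kwargs['flavor'] = 'peft'
        kwargs['save_pretrained_dir'] = temp_save_dir
        kwargs['metadata'] = mlflow_logging_config['metadata']
    else:
        kwargs['flavor'] = 'transformers'
        kwargs['transformers_model'] = components
        kwargs.update(mlflow_logging_config)
    return kwargs

# Usage in the callback would then be:
# mlflow_logger.save_model(**build_model_saving_kwargs(...))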