[SparseAutoModelForCausalLM Deprecation] Feature change #881

Merged · 45 commits · Nov 18, 2024

Changes from 7 commits

Commits (45)
5c8ff83
src and tests updates
horheynm Oct 31, 2024
4d9f4df
save model if output_dir is provided
horheynm Nov 1, 2024
588ee7e
save model if provided as a string
horheynm Nov 1, 2024
17d4a9c
typo
horheynm Nov 1, 2024
bd98f6d
save if model was provided as a string or custom output_dir was set
horheynm Nov 1, 2024
51e1ada
comments
horheynm Nov 4, 2024
7b8247b
save tokenizer also if model passed as a string or custom outputdir p…
horheynm Nov 4, 2024
2f6d9ef
revert to True
horheynm Nov 4, 2024
0425124
revert model to string
horheynm Nov 4, 2024
ce77540
Merge branch 'main' into deprecate-SparseAutoModelForCausalLM/depreca…
horheynm Nov 4, 2024
1ec6974
merge main
horheynm Nov 4, 2024
ff55775
merge main
horheynm Nov 4, 2024
aae73e2
Merge branch 'main' of github.com:vllm-project/llm-compressor into de…
horheynm Nov 4, 2024
a66147e
Merge branch 'main' into deprecate-SparseAutoModelForCausalLM/depreca…
horheynm Nov 4, 2024
8d146ad
fix transformers tests
horheynm Nov 4, 2024
d3073fe
Merge branch 'deprecate-SparseAutoModelForCausalLM/deprecation' of gi…
horheynm Nov 4, 2024
6d02fd5
Update tests/llmcompressor/transformers/obcq/test_consecutive_runs.py
horheynm Nov 5, 2024
30123c3
Merge branch 'main' into deprecate-SparseAutoModelForCausalLM/depreca…
horheynm Nov 5, 2024
9b869f8
lint:
horheynm Nov 5, 2024
1d29417
fix bug
horheynm Nov 5, 2024
9024d90
fix bug
horheynm Nov 5, 2024
e0ef750
comments
horheynm Nov 6, 2024
a221ca0
comments
horheynm Nov 6, 2024
57af085
Merge branch 'main' into deprecate-SparseAutoModelForCausalLM/depreca…
kylesayrs Nov 6, 2024
f2ed4e0
fix saving bug on example script and comments
horheynm Nov 6, 2024
71d8683
Merge branch 'deprecate-SparseAutoModelForCausalLM/deprecation' of gi…
horheynm Nov 6, 2024
45994c2
Merge branch 'main' into deprecate-SparseAutoModelForCausalLM/depreca…
horheynm Nov 7, 2024
d0ac63d
Merge branch 'main' into deprecate-SparseAutoModelForCausalLM/depreca…
dsikka Nov 7, 2024
8786407
fix test failure
horheynm Nov 7, 2024
8b9baab
Merge branch 'deprecate-SparseAutoModelForCausalLM/deprecation' of gi…
horheynm Nov 7, 2024
10f3883
comments
horheynm Nov 8, 2024
a7d0e3e
Merge branch 'main' into deprecate-SparseAutoModelForCausalLM/depreca…
horheynm Nov 8, 2024
f33793a
Merge branch 'main' into deprecate-SparseAutoModelForCausalLM/depreca…
dsikka Nov 11, 2024
0ce72a5
Merge branch 'main' into deprecate-SparseAutoModelForCausalLM/depreca…
dsikka Nov 12, 2024
2602648
comments
horheynm Nov 12, 2024
0284f0c
comments
horheynm Nov 13, 2024
c7c951e
lint
horheynm Nov 13, 2024
5a1cc95
fix test_quantization.py
horheynm Nov 14, 2024
acc2776
fix bugs
horheynm Nov 14, 2024
a2992ab
revert to default
horheynm Nov 14, 2024
4bcbe03
revert to default
horheynm Nov 14, 2024
5dbb911
draft
horheynm Nov 14, 2024
9418de1
fix test
horheynm Nov 14, 2024
5bc9a25
logging output fix
horheynm Nov 15, 2024
a516950
Merge branch 'main' into deprecate-SparseAutoModelForCausalLM/depreca…
dsikka Nov 17, 2024
29 changes: 22 additions & 7 deletions src/llmcompressor/pytorch/model_load/helpers.py
@@ -16,6 +16,7 @@
"log_model_load",
"initialize_recipe",
"save_model_and_recipe",
"copy_python_files_from_model_cache",
"fallback_to_cpu",
"parse_dtype",
"get_session_model",
@@ -99,7 +100,6 @@ def save_model_and_recipe(
):
"""
Save a model, tokenizer and the currently loaded recipe to file

:param model: pytorch model to save
:param save_path: path to save output to
:param tokenizer: model tokenizer to save
@@ -123,7 +123,7 @@
fp.write(recipe_yaml_str)

# copy python files from cache dir to save_path if any
_copy_python_files_from_model_cache(model, save_path)
copy_python_files_from_model_cache(model, save_path)


def fallback_to_cpu(device: str) -> str:
@@ -212,16 +212,31 @@ def load_safetensors_state_dict(file_path: str) -> Dict[str, torch.Tensor]:
return {key: f.get_tensor(key) for key in f.keys()}


def _copy_python_files_from_model_cache(model: Module, save_path: str):
def copy_python_files_from_model_cache(model, save_path: str):
config = model.config
cache_dir = None
cache_path = None
if hasattr(config, "_name_or_path"):
import os
import shutil

cache_dir = config._name_or_path
for file in os.listdir(cache_dir):
full_file_name = os.path.join(cache_dir, file)
from huggingface_hub import hf_hub_download
from transformers import TRANSFORMERS_CACHE
from transformers.utils import http_user_agent

cache_path = config._name_or_path
if not os.path.exists(cache_path):
user_agent = http_user_agent()
config_file_path = hf_hub_download(
repo_id=cache_path,
filename="config.json",
cache_dir=TRANSFORMERS_CACHE,
force_download=False,
user_agent=user_agent,
)
cache_path = os.path.sep.join(config_file_path.split(os.path.sep)[:-1])

for file in os.listdir(cache_path):
full_file_name = os.path.join(cache_path, file)
if file.endswith(".py") and os.path.isfile(full_file_name):
logger.debug(f"Transferring {full_file_name} to {save_path}")
shutil.copy(full_file_name, save_path)
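
This change also promotes `copy_python_files_from_model_cache` to a public helper and teaches it to resolve a Hub model ID to the local Transformers cache (via `hf_hub_download` of `config.json`) before copying any custom `.py` files. A minimal usage sketch follows; the repo ID and output path are placeholders, not part of this PR:

```python
from transformers import AutoModelForCausalLM

from llmcompressor.pytorch.model_load.helpers import copy_python_files_from_model_cache

# Hypothetical remote-code model and output directory, for illustration only.
model = AutoModelForCausalLM.from_pretrained(
    "some-org/model-with-remote-code", trust_remote_code=True
)
save_path = "./compressed-output"
model.save_pretrained(save_path)

# Copy any custom modeling/tokenization .py files from the HF cache into
# save_path so the saved checkpoint remains loadable with trust_remote_code=True.
copy_python_files_from_model_cache(model, save_path)
```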
3 changes: 1 addition & 2 deletions src/llmcompressor/transformers/__init__.py
@@ -7,9 +7,8 @@
# isort: skip_file
# (import order matters for circular import avoidance)
from .utils import *

from .sparsification import (
SparseAutoModel,
SparseAutoModelForCausalLM,
wrap_hf_model_class,
)
from .finetune import *
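
With `SparseAutoModelForCausalLM` removed from the package exports, callers load models with the stock `transformers` class instead. A sketch of the migration, using the small model already referenced in the README below:

```python
# Before (deprecated wrapper removed by this PR):
# from llmcompressor.transformers import SparseAutoModelForCausalLM
# model = SparseAutoModelForCausalLM.from_pretrained("Xenova/llama2.c-stories15M")

# After: plain Hugging Face Transformers
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("Xenova/llama2.c-stories15M")
```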
20 changes: 11 additions & 9 deletions src/llmcompressor/transformers/finetune/README.md
@@ -10,7 +10,6 @@ llmcompressor.transformers.text_generation.train
--distill_teacher PATH_TO_TEACHER
--dataset DATASET_NAME
--recipe PATH_TO_RECIPE
--output_dir PATH_TO_OUTPUT
--num_train_epochs 1
--splits "train"
```
@@ -33,7 +32,6 @@ accelerate launch
--distill_teacher PATH_TO_TEACHER
--dataset DATASET_NAME
--recipe PATH_TO_RECIPE
--output_dir PATH_TO_OUTPUT
--num_train_epochs 1
--splits "train"
```
@@ -44,8 +42,9 @@ See [configure_fsdp.md](../../../../examples/finetuning/configure_fsdp.md) for a

```python
from llmcompressor.transformers import train
from transformers import AutoModelForCausalLM

model = "./obcq_deployment"
model = AutoModelForCausalLM.from_pretrained("./obcq_deployment")
teacher_model = "Xenova/llama2.c-stories15M"
dataset_name = "open_platypus"
concatenate_data = False
@@ -61,13 +60,13 @@ train(
model=model,
distill_teacher=teacher_model,
dataset=dataset_name,
output_dir=output_dir,
recipe=recipe,
num_train_epochs=num_train_epochs,
overwrite_output_dir=overwrite_output_dir,
concatenate_data = concatenate_data,
splits = splits
)
model.save_pretrained(output_dir)
```

## Additional Configuration
@@ -91,7 +90,6 @@ accelerate launch
--max_seq_len OPTIONAL
--concatenate_data OPTIONAL
--recipe PATH_TO_RECIPE
--output_dir PATH_TO_OUTPUT
--splits "train"
--pad_to_max_length False
```
@@ -100,8 +98,9 @@
## Running One-shot from Python (without FSDP)
```python
from llmcompressor.transformers import oneshot
from transformers import AutoModelForCausalLM

model = "Xenova/llama2.c-stories15M"
model = AutoModelForCausalLM.from_pretrained("Xenova/llama2.c-stories15M")
dataset_name = "open_platypus"
concatenate_data = False
pad_to_max_length = False
@@ -116,13 +115,13 @@ oneshot(
model=model,
dataset=dataset_name,
concatenate_data=concatenate_data,
output_dir=output_dir,
recipe=recipe,
overwrite_output_dir=overwrite_output_dir,
concatenate_data = concatenate_data,
pad_to_max_length = pad_to_max_length,
splits = splits
)
model.save_pretrained(output_dir)
```

## Running Multi-Stage Recipes
@@ -141,8 +140,11 @@ of a staged recipe for Llama.
test_multi.py
```python
from llmcompressor.transformers import apply
from transformers import AutoModelForCausalLM

model = "../ml-experiments/nlg-text_generation/llama_pretrain-llama_7b-base/dense/training"
model = AutoModelForCausalLM.from_pretrained(
"../ml-experiments/nlg-text_generation/llama_pretrain-llama_7b-base/dense/training"
)
dataset_name = "open_platypus"
concatenate_data = False
run_stages=True
@@ -159,12 +161,12 @@ apply(
model_name_or_path=model,
dataset_name=dataset_name,
run_stages=run_stages,
output_dir=output_dir,
recipe=recipe,
num_train_epochs=num_train_epochs,
overwrite_output_dir=overwrite_output_dir,
concatenate_data = concatenate_data,
remove_unused_columns = False,
splits = splits
)
model.save_pretrained(output_dir)
```
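
Note that once `oneshot`/`train` has run, the model's `save_pretrained` is wrapped by llm-compressor (see `modify_save_pretrained` in `text_generation.py` below), so the explicit save shown above can also request compressed output. A hedged one-liner, continuing the example and assuming the wrapper has been applied; the path is a placeholder:

```python
# Save the compressed representation via the wrapped save_pretrained;
# save_compressed mirrors the training argument of the same name.
model.save_pretrained("./oneshot-compressed", save_compressed=True)
```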
36 changes: 1 addition & 35 deletions src/llmcompressor/transformers/finetune/runner.py
@@ -13,7 +13,6 @@
get_completed_stages,
get_session_model,
save_completed_stages,
save_model_and_recipe,
)
from llmcompressor.pytorch.utils import tensors_to_device
from llmcompressor.recipe import Recipe, StageRunType
@@ -25,11 +24,7 @@
)
from llmcompressor.transformers.finetune.model_args import ModelArguments
from llmcompressor.transformers.finetune.training_args import TrainingArguments
from llmcompressor.utils.fsdp.helpers import (
find_and_move_state_dicts_to_cpu,
is_fsdp_model,
unwrap_and_export_model,
)
from llmcompressor.utils.fsdp.helpers import is_fsdp_model


class StageRunner:
@@ -170,35 +165,6 @@ def one_shot(self, stage: Optional[str] = None):

self.trainer.one_shot(calibration_data=calib_data, stage=stage)

if is_fsdp_model(self.trainer.model):
try:
self.trainer.save_model(output_dir=self._output_dir, _is_oneshot=True)
except AssertionError:
# fallback to this in the case of quantization
unwrap_and_export_model(
model=self.trainer.model,
accelerator=self.trainer.accelerator,
output_dir=self._output_dir,
tokenizer=self.tokenizer,
)
# only allow the main process move the state
# dicts to cpu
if self.trainer.accelerator.is_main_process:
# assuming quantization is the last step
# we no longer need the original model
# and can safely delete it to save memory
del self.trainer.model
find_and_move_state_dicts_to_cpu(self._output_dir)

else:
save_model_and_recipe(
model=self.trainer.model,
save_path=self._output_dir,
tokenizer=self.tokenizer,
save_safetensors=self._training_args.save_safetensors,
save_compressed=self._training_args.save_compressed,
)

def train(self, checkpoint: str, stage: Optional[str] = None):
"""
Run trainer's training loop on train_dataset, saving the resulting model to
7 changes: 1 addition & 6 deletions src/llmcompressor/transformers/finetune/session_mixin.py
@@ -452,9 +452,7 @@ def one_shot(
# self.maybe_log_model_sparsification()
self.accelerator.wait_for_everyone()

def save_model(
self, output_dir: Optional[str] = None, _internal_call=False, _is_oneshot=False
):
def save_model(self, output_dir: str, _internal_call=False, _is_oneshot=False):
"""
Override of the save_model function and expects it to exist in the parent.
Calls into super() to save the model and additionally saves any recipes
@@ -465,9 +463,6 @@ def save_model(
if active_session() is None:
return # nothing to save

if output_dir is None:
output_dir = self.args.output_dir

# knowledge distillation requires making wrappers transparent during
if isinstance(self.model, KDModelWrapper):
self.model.prepare_for_save()
38 changes: 30 additions & 8 deletions src/llmcompressor/transformers/finetune/text_generation.py
@@ -23,6 +23,7 @@
from loguru import logger
from transformers import (
AutoConfig,
AutoModelForCausalLM,
AutoTokenizer,
DefaultDataCollator,
HfArgumentParser,
@@ -42,11 +43,15 @@
from llmcompressor.transformers.finetune.runner import StageRunner
from llmcompressor.transformers.finetune.trainer import Trainer
from llmcompressor.transformers.finetune.training_args import TrainingArguments
from llmcompressor.transformers.sparsification.compressed_tensors_utils import (
modify_fsdp_model_save_pretrained,
modify_save_pretrained,
)
from llmcompressor.transformers.sparsification.sparse_model import (
SparseAutoModel,
get_shared_tokenizer_src,
)
from llmcompressor.transformers.utils.helpers import detect_last_checkpoint
from llmcompressor.utils.fsdp.helpers import is_fsdp_model


def train(**kwargs):
@@ -199,21 +204,23 @@ def initialize_model_from_path(
"trust_remote_code": model_args.trust_remote_code_model,
}
# this calls from_pretrained under the hood so should be FSDP safe
model = SparseAutoModel.text_generation_from_pretrained(
model_name_or_path=model_path,
sequence_length=None, # use model default
model = AutoModelForCausalLM.from_pretrained(
model_path,
**model_kwargs,
)
if "sequence_length" in model_kwargs:
model.seqlen = model_kwargs["sequence_length"]

teacher = (
SparseAutoModel.text_generation_from_pretrained(
model_name_or_path=model_args.distill_teacher,
sequence_length=None, # use model default
AutoModelForCausalLM.from_pretrained(
model_args.distill_teacher,
**teacher_kwargs,
)
if model_args.distill_teacher is not None
else None
)
if teacher is not None and "sequence_length" in teacher_kwargs:
teacher.seqlen = teacher_kwargs["sequence_length"]

return teacher, model_path, model

@@ -348,7 +355,6 @@ def main(

# exit immediately
return

# Training
if training_args.do_train:
checkpoint = None
Expand All @@ -374,6 +380,22 @@ def main(
if training_args.clear_sparse_session:
reset_session()

# wrap model.save_pretrained
model = trainer.model
if is_fsdp_model(model):
modify_fsdp_model_save_pretrained(trainer, tokenizer)
else:
modify_save_pretrained(model)

# save if model was provided as a string or custom output_dir was set
if isinstance(model_args.model, str) or (
training_args.output_dir
!= TrainingArguments.__dataclass_fields__["output_dir"].default
):
model.save_pretrained(training_args.output_dir)
if tokenizer is not None:
tokenizer.save_pretrained(training_args.output_dir)


if __name__ == "__main__":
apply()
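
The new block at the end of `main()` wraps `save_pretrained` and then saves automatically only when the model was passed as a string (so the entrypoint, not the caller, holds the loaded object) or when a non-default `output_dir` was supplied. A sketch of the two resulting calling patterns; the recipe path and output directories are placeholders:

```python
from transformers import AutoModelForCausalLM

from llmcompressor.transformers import oneshot

# 1) Model passed as a string with an explicit output_dir:
#    the entrypoint loads the model and saves model + tokenizer for you.
oneshot(
    model="Xenova/llama2.c-stories15M",
    dataset="open_platypus",
    recipe="recipe.yaml",  # placeholder recipe path
    output_dir="./oneshot-from-string",
)

# 2) Model passed as an object with the default output_dir:
#    nothing is written automatically; save explicitly when ready.
model = AutoModelForCausalLM.from_pretrained("Xenova/llama2.c-stories15M")
oneshot(
    model=model,
    dataset="open_platypus",
    recipe="recipe.yaml",  # placeholder recipe path
)
model.save_pretrained("./oneshot-from-object")
```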