From d5be85435f58072c87623708381e6afb8420b8d4 Mon Sep 17 00:00:00 2001
From: Sara Adkins
Date: Thu, 25 Jul 2024 12:44:25 -0400
Subject: [PATCH] Fix End Epoch Default (#39)

* fix finetune end bug and resulting issues

* Update integration test
---
 .../modifiers/distillation/output/base.py        | 12 ++++++------
 .../distillation/utils/pytorch/kd_factory.py     |  2 +-
 src/llmcompressor/modifiers/modifier.py          |  4 ++--
 src/llmcompressor/recipe/recipe.py               |  2 +-
 src/llmcompressor/utils/fsdp/helpers.py          | 17 ++++++++---------
 tests/llmcompressor/recipe/test_recipe.py        |  6 ++----
 .../finetune/test_oneshot_and_finetune.py        | 15 +++++++++++++++
 .../transformers/finetune/test_session_mixin.py  |  5 ++---
 8 files changed, 37 insertions(+), 26 deletions(-)

diff --git a/src/llmcompressor/modifiers/distillation/output/base.py b/src/llmcompressor/modifiers/distillation/output/base.py
index 8de0e50e5..3716db359 100644
--- a/src/llmcompressor/modifiers/distillation/output/base.py
+++ b/src/llmcompressor/modifiers/distillation/output/base.py
@@ -44,7 +44,7 @@ def on_initialize(self, state: State, **kwargs) -> bool:
         hidden_size = (
             kwargs.get("metadata").get("per_device_train_batch_size", 1),
             kwargs.get("metadata").get("max_seq_length", 512),
-            state.model.model.config.hidden_size,
+            state.model.config.hidden_size,
         )
 
         for target in (
@@ -77,18 +77,18 @@ def on_initialize(self, state: State, **kwargs) -> bool:
             )
             self.wrappers_[key] = (student_wrapper, teacher_wrapper)
 
-        with summon_full_params_context(state.teacher_model.model, offload_to_cpu=True):
+        with summon_full_params_context(state.teacher_model, offload_to_cpu=True):
             for key, (student_wrapper, teacher_wrapper) in self.wrappers_.items():
                 set_layer(key, student_wrapper, state.model)
                 set_layer(key, teacher_wrapper, state.teacher_model)
 
         self.wrapped_kd_model_ = self._create_model_wrapper(
             student_model=maybe_get_wrapped(state.model),
-            teacher_model=state.teacher_model.model,
+            teacher_model=state.teacher_model,
             state=state,
         )
 
-        set_wrapped_model(state.model, self.wrapped_kd_model_)
+        set_wrapped_model(state, self.wrapped_kd_model_)
 
         # for square-head distillation we want to scale the loss by the number of
         # layers if the user doesn't alter the default scale. This is done so the
@@ -99,9 +99,9 @@ def on_initialize(self, state: State, **kwargs) -> bool:
         return True
 
     def on_finalize(self, state: State, **kwargs) -> bool:
-        set_wrapped_model(state.model, self.wrapped_kd_model_.student_model)
+        set_wrapped_model(state, self.wrapped_kd_model_.student_model)
 
-        with summon_full_params_context(state.teacher_model.model, offload_to_cpu=True):
+        with summon_full_params_context(state.teacher_model, offload_to_cpu=True):
             for key, (student_wrapper, teacher_wrapper) in self.wrappers_.items():
                 set_layer(key, student_wrapper.layer, state.model)
                 set_layer(key, teacher_wrapper.layer, state.teacher_model)
diff --git a/src/llmcompressor/modifiers/distillation/utils/pytorch/kd_factory.py b/src/llmcompressor/modifiers/distillation/utils/pytorch/kd_factory.py
index 93b4b3dd9..150a0e222 100644
--- a/src/llmcompressor/modifiers/distillation/utils/pytorch/kd_factory.py
+++ b/src/llmcompressor/modifiers/distillation/utils/pytorch/kd_factory.py
@@ -184,7 +184,7 @@ def recursive_combine(
     val_two: TensorOrCollectionType,
     func: Callable[[Tensor, Tensor], Tensor],
 ):
-    if isinstance(val_one, type(val_two)):
+    if not isinstance(val_one, type(val_two)):
         raise ValueError(
             f"val_one type of {type(val_one)} must match "
             f"val_two type of {type(val_two)}"
diff --git a/src/llmcompressor/modifiers/modifier.py b/src/llmcompressor/modifiers/modifier.py
index f48287c23..4fa56124e 100644
--- a/src/llmcompressor/modifiers/modifier.py
+++ b/src/llmcompressor/modifiers/modifier.py
@@ -25,8 +25,8 @@ class Modifier(BaseModel, ModifierInterface):
 
     index: Optional[int] = None
     group: Optional[str] = None
-    start: Optional[float] = -1
-    end: Optional[float] = -1
+    start: Optional[float] = None
+    end: Optional[float] = None
     update: Optional[float] = None
 
     initialized_structure_: bool = False
diff --git a/src/llmcompressor/recipe/recipe.py b/src/llmcompressor/recipe/recipe.py
index 9317521e6..169cef170 100644
--- a/src/llmcompressor/recipe/recipe.py
+++ b/src/llmcompressor/recipe/recipe.py
@@ -667,7 +667,7 @@ def create_recipe_string_from_modifiers(
     recipe_dict = {
         f"{modifier_group_name}_stage": {
             f"{default_group_name}_modifiers": {
-                modifier.__class__.__name__: modifier.model_dump()
+                modifier.__class__.__name__: modifier.model_dump(exclude_unset=True)
                 for modifier in modifiers
             }
         }
diff --git a/src/llmcompressor/utils/fsdp/helpers.py b/src/llmcompressor/utils/fsdp/helpers.py
index 37e2713ce..8cc0f5405 100644
--- a/src/llmcompressor/utils/fsdp/helpers.py
+++ b/src/llmcompressor/utils/fsdp/helpers.py
@@ -16,6 +16,7 @@
 import torch
 from torch.nn import Module
 
+from llmcompressor.core.state import State
 from llmcompressor.pytorch.model_load.helpers import save_model_and_recipe
 from llmcompressor.utils.pytorch import set_layer
 
@@ -56,20 +57,18 @@ def maybe_get_wrapped(model: Module) -> Module:
     return model
 
 
-def set_wrapped_model(model: Module, wrapped_model: Module):
+def set_wrapped_model(state: State, wrapped_model: Module):
     """
-    Given a model that may or may not have a distributed wrapper, set the underlying
-    wrapped model.
-
-    #TODO: will probably have to fix this
+    Given a state with a model that may or may not have a distributed wrapper, set
+    the underlying wrapped model.
 
-    :param input_model: input model to be updated
+    :param state: state to update model of
     :param updated_wrapped: model to inject into input_model
     """
-    if is_fsdp_model(model):
-        model._fsdp_wrapped_module = wrapped_model
+    if is_fsdp_model(state.model):
+        state.model._fsdp_wrapped_module = wrapped_model
     else:
-        model = wrapped_model
+        state.model = wrapped_model
 
 
 def unwrap_and_export_model(model, accelerator, output_dir, tokenizer):
diff --git a/tests/llmcompressor/recipe/test_recipe.py b/tests/llmcompressor/recipe/test_recipe.py
index 43aa8737c..8c721487e 100644
--- a/tests/llmcompressor/recipe/test_recipe.py
+++ b/tests/llmcompressor/recipe/test_recipe.py
@@ -101,13 +101,11 @@ def test_recipe_can_be_created_from_modifier_instances():
 
 
 class A_FirstDummyModifier(Modifier):
-    def model_dump(self):
-        return {}
+    pass
 
 
 class B_SecondDummyModifier(Modifier):
-    def model_dump(self):
-        return {}
+    pass
 
 
 def test_create_recipe_string_from_modifiers_with_default_group_name():
diff --git a/tests/llmcompressor/transformers/finetune/test_oneshot_and_finetune.py b/tests/llmcompressor/transformers/finetune/test_oneshot_and_finetune.py
index 6cbc58c75..98b79b3f5 100644
--- a/tests/llmcompressor/transformers/finetune/test_oneshot_and_finetune.py
+++ b/tests/llmcompressor/transformers/finetune/test_oneshot_and_finetune.py
@@ -1,8 +1,11 @@
+import os
 import shutil
 import unittest
 
 import pytest
+from compressed_tensors.compressors.model_compressor import ModelCompressor
 from parameterized import parameterized_class
+from transformers import AutoConfig
 
 from tests.testing_utils import parse_params, requires_gpu, requires_torch
 
@@ -35,6 +38,18 @@ def _test_oneshot_and_finetune(self):
             dataset_config_name=self.dataset_config_name,
         )
 
+        config_os = ModelCompressor.parse_sparsity_config(
+            AutoConfig.from_pretrained(
+                os.path.join(self.output, "stage_test_oneshot")
+            ).compression_config
+        )
+        config_ft = ModelCompressor.parse_sparsity_config(
+            AutoConfig.from_pretrained(
+                os.path.join(self.output, "stage_test_oneshot")
+            ).compression_config
+        )
+        assert config_ft["global_sparsity"] >= config_os["global_sparsity"]
+
     def tearDown(self):
         # TODO: we get really nice stats from finetune that we should log
         # stored in results.json
diff --git a/tests/llmcompressor/transformers/finetune/test_session_mixin.py b/tests/llmcompressor/transformers/finetune/test_session_mixin.py
index 520b20c8d..8d409fcb8 100644
--- a/tests/llmcompressor/transformers/finetune/test_session_mixin.py
+++ b/tests/llmcompressor/transformers/finetune/test_session_mixin.py
@@ -1,7 +1,6 @@
 from typing import Any, Dict, Optional, Union
 
 import pytest
-from datasets import load_dataset
 from torch.nn import Module
 from transformers import AutoModelForCausalLM, Trainer
 
@@ -44,8 +43,8 @@ def mixin_trainer():
     model_state_path = "Xenova/llama2.c-stories15M"
     model = AutoModelForCausalLM.from_pretrained(model_state_path)
     recipe = "tests/llmcompressor/transformers/finetune/test_quantization.yaml"
-    train_dataset = load_dataset("garage-bAInd/Open-Platypus", split="train[:5%]")
-    eval_dataset = load_dataset("garage-bAInd/Open-Platypus", split="train[5%:6%]")
+    train_dataset = "open-platypus"
+    eval_dataset = "open-platypus"
 
     return MixInTest(
         model=model,
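
Below is a minimal sketch (not part of the patch itself) of the pydantic behavior the recipe.py change relies on: with start/end now defaulting to None and serialization done via model_dump(exclude_unset=True), fields the user never set are omitted from the re-serialized recipe, so a placeholder default such as end = -1 can no longer be written back into the YAML and later treated as a real end epoch. DummyModifier and its targets field are hypothetical stand-ins, not llm-compressor classes; only pydantic v2's model_dump behavior is assumed.

    from typing import Optional

    from pydantic import BaseModel


    class DummyModifier(BaseModel):
        # hypothetical stand-in mirroring the new defaults in modifier.py
        start: Optional[float] = None
        end: Optional[float] = None
        targets: str = "__ALL__"


    m = DummyModifier(targets="Linear")

    print(m.model_dump())
    # -> {'start': None, 'end': None, 'targets': 'Linear'}  (defaults included)

    print(m.model_dump(exclude_unset=True))
    # -> {'targets': 'Linear'}  (only explicitly set fields survive re-serialization)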