
Commit

fix test
horheynm committed Nov 14, 2024
1 parent 5dbb911 commit 9418de1
Showing 4 changed files with 13 additions and 18 deletions.
5 changes: 3 additions & 2 deletions src/llmcompressor/transformers/finetune/text_generation.py
@@ -351,7 +351,6 @@ def main(
     )

     # wrap model.save_pretrained
-    model = trainer.model
     if is_fsdp_model(model):
         modify_fsdp_model_save_pretrained(trainer, tokenizer)
     else:
@@ -394,7 +393,9 @@ def main(
         training_args.output_dir
         != TrainingArguments.__dataclass_fields__["output_dir"].default
     ):
-        model.save_pretrained(training_args.output_dir)
+        model.save_pretrained(
+            training_args.output_dir, save_compressed=training_args.save_compressed
+        )
         if tokenizer is not None:
             tokenizer.save_pretrained(training_args.output_dir)

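For context on this file's second hunk: a hedged sketch of the save path it rewires. Only output_dir, save_compressed, the tokenizer save, and the wrapped model.save_pretrained call come from the diff; the standalone helper below and its default_output_dir parameter are illustrative assumptions, not the repository's actual code.

def save_if_output_dir_overridden(model, tokenizer, training_args, default_output_dir):
    # Persist the model only when the user overrode the default output_dir.
    if training_args.output_dir != default_output_dir:
        # save_compressed is forwarded to the wrapped save_pretrained
        # (see the save_pretrained_wrapper hunk below), which controls
        # whether the checkpoint is written in compressed form.
        model.save_pretrained(
            training_args.output_dir,
            save_compressed=training_args.save_compressed,
        )
        if tokenizer is not None:
            tokenizer.save_pretrained(training_args.output_dir)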
@@ -146,6 +146,15 @@ def save_pretrained_wrapper(
         # https://github.com/huggingface/transformers/pull/30488
         transformers.modeling_utils.dtype_byte_size = new_dtype_byte_size

+        def skip(*args, **kwargs):
+            pass
+
+        # Skip the initializer step. This accelerates the loading
+        # of the models, especially for the quantized models
+        torch.nn.init.kaiming_uniform_ = skip
+        torch.nn.init.uniform_ = skip
+        torch.nn.init.normal_ = skip
+
         # state_dict gets passed in as a kwarg for FSDP models
         state_dict = kwargs.pop("state_dict", None)
         if state_dict is None:
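The block added above disables PyTorch's in-place weight initializers so that constructing a large model skeleton stays cheap when its parameters are about to be overwritten by a loaded state dict anyway. A standalone sketch of the same trick, assuming only stock PyTorch (the Linear layer at the end is purely illustrative):

import torch

def skip(*args, **kwargs):
    pass

# Replace the default initializers with no-ops; modules constructed after
# this point keep whatever values their freshly allocated storage contains,
# which is fine when a checkpoint's state_dict is loaded over them.
torch.nn.init.kaiming_uniform_ = skip
torch.nn.init.uniform_ = skip
torch.nn.init.normal_ = skip

layer = torch.nn.Linear(4096, 4096)  # allocated, but not randomly initialized

Note that the patch is process-wide and is not undone here, so every module built afterwards in the same process is affected as well.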
16 changes: 0 additions & 16 deletions tests/llmcompressor/transformers/compression/test_quantization.py
@@ -59,15 +59,6 @@ def _run_oneshot(model, recipe, dataset, output_dir):
         num_calibration_samples = 512
         max_seq_length = 512
         pad_to_max_length = False

-        def skip(*args, **kwargs):
-            pass
-
-        # Skip the initializer step. This accelerates the loading
-        # of the models, especially for the quantized models
-        torch.nn.init.kaiming_uniform_ = skip
-        torch.nn.init.uniform_ = skip
-        torch.nn.init.normal_ = skip
-
         oneshot(
             model=model,
@@ -110,17 +101,11 @@ def _get_quant_info(self, model):
         return quant_info_weights, quant_info_inputs

     def test_quantization_reload(self):
-        breakpoint()
         model_reloaded = AutoModelForCausalLM.from_pretrained(
             os.path.join(self.test_dir, self.output),
             torch_dtype="auto",
             device_map="cuda:0",
         )
-        # model_reloaded = self.session_model
-        """
-        model_reloaded = AutoModelForCausalLM.from_pretrained(os.path.join(self.test_dir, self.output),torch_dtype="auto", device_map="cuda:0",)
-        model_reloaded = AutoModelForCausalLM.from_pretrained(os.path.join(self.test_dir, self.output),torch_dtype="auto",)
-        """

         og_weights, og_inputs = self._get_quant_info(self.model)
         reloaded_weights, reloaded_inputs = self._get_quant_info(model_reloaded)
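The comparison above depends on a _get_quant_info helper whose body is not part of this diff. A hedged sketch of what such a helper could collect (the weight_scale and weight_zero_point attribute names are assumptions about compressed-tensors-style quantized modules, not code from this repository):

def get_quant_info(model):
    # Gather per-module quantization parameters so two models can be
    # compared entry by entry after a save/load round trip.
    quant_info_weights = {}
    for name, module in model.named_modules():
        scale = getattr(module, "weight_scale", None)
        zero_point = getattr(module, "weight_zero_point", None)
        if scale is not None and zero_point is not None:
            quant_info_weights[name] = (scale.detach().cpu(), zero_point.detach().cpu())
    return quant_info_weights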
@@ -184,4 +169,3 @@ def test_perplexity(self):

         avg_ppl = total_ppl / total_non_nan
         assert avg_ppl <= self.ppl_threshold
-
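The tail of test_perplexity shown above averages per-batch perplexities while skipping NaN entries. A hedged sketch of that accounting (the function wrapper and the batch_losses iterable are assumptions; only the avg_ppl division and the threshold assert mirror the diff):

import torch

def average_perplexity(batch_losses, ppl_threshold):
    # Perplexity of a batch is exp(mean cross-entropy loss); NaN batches
    # are excluded from the average.
    total_ppl, total_non_nan = 0.0, 0
    for loss in batch_losses:
        ppl = torch.exp(loss)
        if not torch.isnan(ppl):
            total_ppl += ppl.item()
            total_non_nan += 1
    avg_ppl = total_ppl / total_non_nan
    assert avg_ppl <= ppl_threshold
    return avg_ppl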
1 change: 1 addition & 0 deletions tests/testing_utils.py
@@ -114,6 +114,7 @@ def _parse_configs_dir(current_config_dir):
             _parse_configs_dir(config)
     else:
         _parse_configs_dir(configs_directory)
+
     return config_dicts

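For the last file, a hedged sketch of the pattern the visible fragment belongs to (the outer function name parse_params, the closure body, the YAML reading, and the isinstance check are assumptions; only the _parse_configs_dir calls and the returned config_dicts appear in the diff):

import os
import yaml  # assumed format for the test config files

def parse_params(configs_directory):
    config_dicts = []

    def _parse_configs_dir(current_config_dir):
        # Hypothetical body: load every file in the directory into a dict.
        for file_name in os.listdir(current_config_dir):
            path = os.path.join(current_config_dir, file_name)
            if os.path.isfile(path):
                with open(path) as f:
                    config_dicts.append(yaml.safe_load(f))

    if isinstance(configs_directory, list):
        for config in configs_directory:
            _parse_configs_dir(config)
    else:
        _parse_configs_dir(configs_directory)

    return config_dicts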
