support model_class option
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
kylesayrs committed Dec 20, 2024
1 parent a0a6cf7 commit 46af518
Showing 8 changed files with 73 additions and 4 deletions.
20 changes: 16 additions & 4 deletions tests/e2e/e2e_utils.py
@@ -1,9 +1,10 @@
+import transformers
 from datasets import load_dataset
 from loguru import logger
-from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoProcessor

 from llmcompressor.modifiers.quantization import GPTQModifier, QuantizationModifier
-from llmcompressor.transformers import oneshot
+from llmcompressor.transformers import oneshot, tracing
 from tests.testing_utils import preprocess_tokenize_dataset

@@ -18,13 +19,14 @@ def run_oneshot_for_e2e_testing(
     dataset_config: str,
     scheme: str,
     quant_type: str,
+    model_class: str = "AutoModelForCausalLM",
 ):
     # Load model.
     oneshot_kwargs = {}
-    loaded_model = AutoModelForCausalLM.from_pretrained(
+    loaded_model = get_model_class(model_class).from_pretrained(
         model, device_map=device, torch_dtype="auto"
     )
-    tokenizer = AutoTokenizer.from_pretrained(model)
+    tokenizer = AutoProcessor.from_pretrained(model)

     if dataset_id:
         ds = load_dataset(dataset_id, name=dataset_config, split=dataset_split)
@@ -56,3 +58,13 @@ def run_oneshot_for_e2e_testing(
         oneshot_device=device,
     )
     return oneshot_kwargs["model"], tokenizer
+
+
+def get_model_class(model_class: str):
+    model_class = getattr(
+        tracing, model_class, getattr(transformers, model_class, None)
+    )
+    if model_class is None:
+        raise ValueError(f"Could not import model class {model_class}")
+
+    return model_class
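
The new get_model_class helper checks llmcompressor.transformers.tracing first and falls back to the transformers namespace, so a test config can request either a traceable multimodal wrapper or a stock Hugging Face class by name. A minimal usage sketch (the class names come from this diff and the llava configs below; the rest is illustrative):

from tests.e2e.e2e_utils import get_model_class

# Found in llmcompressor.transformers.tracing, which is checked first:
llava_cls = get_model_class("TracableLlavaForConditionalGeneration")

# Not present in tracing, so the lookup falls back to transformers:
causal_cls = get_model_class("AutoModelForCausalLM")

# Names found in neither namespace fail fast instead of erroring later
# inside from_pretrained:
# get_model_class("NoSuchModelClass")  # ValueError: Could not import model class ...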
9 changes: 9 additions & 0 deletions tests/e2e/vLLM/llava_configs/fp8_dynamic_per_token.yaml
@@ -0,0 +1,9 @@
+cadence: "weekly"
+model: llava-hf/llava-1.5-7b-hf
+scheme: FP8_DYNAMIC
+num_fewshot: 5
+limit: 1000
+task: "gsm8k"
+exact_match,flexible-extract: 0.75
+exact_match,strict-match: 0.75
+model_class: "TracableLlavaForConditionalGeneration"
11 changes: 11 additions & 0 deletions tests/e2e/vLLM/llava_configs/fp8_static_per_tensor.yaml
@@ -0,0 +1,11 @@
+cadence: "weekly"
+model: llava-hf/llava-1.5-7b-hf
+scheme: FP8
+num_fewshot: 5
+limit: 1000
+task: "gsm8k"
+dataset_id: HuggingFaceH4/ultrachat_200k
+dataset_split: train_sft
+exact_match,flexible-extract: 0.75
+exact_match,strict-match: 0.75
+model_class: "TracableLlavaForConditionalGeneration"
12 changes: 12 additions & 0 deletions tests/e2e/vLLM/llava_configs/int8_w8a8_dynamic_per_token.yaml
@@ -0,0 +1,12 @@
+cadence: "weekly"
+model: llava-hf/llava-1.5-7b-hf
+scheme: INT8_dyn_per_token
+recipe: tests/e2e/vLLM/recipes/INT8/recipe_int8_channel_weight_dynamic_per_token.yaml
+num_fewshot: 5
+limit: 1000
+task: "gsm8k"
+dataset_id: HuggingFaceH4/ultrachat_200k
+dataset_split: train_sft
+exact_match,flexible-extract: 0.77
+exact_match,strict-match: 0.76
+model_class: "TracableLlavaForConditionalGeneration"
11 changes: 11 additions & 0 deletions tests/e2e/vLLM/llava_configs/w4a16_actorder_weight.yaml
@@ -0,0 +1,11 @@
+cadence: "weekly"
+model: llava-hf/llava-1.5-7b-hf
+recipe: tests/e2e/vLLM/recipes/actorder/recipe_w4a16_actorder_weight.yaml
+num_fewshot: 5
+limit: 1000
+task: "gsm8k"
+dataset_id: HuggingFaceH4/ultrachat_200k
+dataset_split: train_sft
+exact_match,flexible-extract: 0.72
+exact_match,strict-match: 0.72
+model_class: "TracableLlavaForConditionalGeneration"
12 changes: 12 additions & 0 deletions tests/e2e/vLLM/llava_configs/w4a16_grouped_quant.yaml
@@ -0,0 +1,12 @@
+cadence: "weekly"
+model: llava-hf/llava-1.5-7b-hf
+num_fewshot: 5
+limit: 1000
+task: "gsm8k"
+exact_match,flexible-extract: 0.72
+exact_match,strict-match: 0.72
+scheme: W4A16
+dataset_id: HuggingFaceH4/ultrachat_200k
+dataset_split: train_sft
+quant_type: "GPTQ"
+model_class: "TracableLlavaForConditionalGeneration"
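
Every new llava config sets model_class to the traceable wrapper, while existing configs that omit the key keep working: test_lmeval.py (below) reads it with a default of AutoModelForCausalLM. A hedged sketch of that lookup, assuming the configs are loaded with yaml.safe_load (the tests' actual loading helper is not shown in this diff):

import yaml

with open("tests/e2e/vLLM/llava_configs/fp8_dynamic_per_token.yaml") as f:
    eval_config = yaml.safe_load(f)

# Mirrors the default added in test_lmeval.py: absent key -> AutoModelForCausalLM.
model_class = eval_config.get("model_class", "AutoModelForCausalLM")
assert model_class == "TracableLlavaForConditionalGeneration"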
Empty file modified tests/e2e/vLLM/run_tests.sh (mode change 100644 → 100755)
2 changes: 2 additions & 0 deletions tests/e2e/vLLM/test_lmeval.py
@@ -51,6 +51,7 @@ def set_up(self):
             pytest.skip("Skipping test; cadence mismatch")

         self.model = eval_config["model"]
+        self.model_class = eval_config.get("model_class", "AutoModelForCausalLM")
         self.scheme = eval_config.get("scheme")
         self.dataset_id = eval_config.get("dataset_id")
         self.dataset_config = eval_config.get("dataset_config")
@@ -87,6 +88,7 @@ def test_lm_eval(self):
             dataset_split=self.dataset_split,
             recipe=self.recipe,
             quant_type=self.quant_type,
+            model_class=self.model_class,
         )

         logger.info("================= SAVING TO DISK ======================")
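
Taken together, the two hunks wire the option through the harness: set_up parses model_class from the YAML and test_lm_eval forwards it to the oneshot helper. A minimal sketch of the resulting call, reconstructed from the keyword names visible in this diff (argument order and the device value are assumptions):

# Inside TestLMEval.test_lm_eval, after set_up() has read the config:
oneshot_model, processor = run_oneshot_for_e2e_testing(
    model=self.model,
    device="cuda:0",  # illustrative; the test's actual device is not shown in this diff
    dataset_id=self.dataset_id,
    dataset_config=self.dataset_config,
    dataset_split=self.dataset_split,
    recipe=self.recipe,
    scheme=self.scheme,
    quant_type=self.quant_type,
    model_class=self.model_class,  # new in this commit; defaults to AutoModelForCausalLM
)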
