diff --git a/tests/e2e/vLLM/__init__.py b/tests/e2e/vLLM/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tests/e2e/vLLM/configs/FP8/fp8_dynamic_per_token.yaml b/tests/e2e/vLLM/configs/FP8/fp8_dynamic_per_token.yaml
new file mode 100644
index 000000000..b37bbde09
--- /dev/null
+++ b/tests/e2e/vLLM/configs/FP8/fp8_dynamic_per_token.yaml
@@ -0,0 +1,4 @@
+cadence: "nightly"
+test_type: "regression"
+model: TinyLlama/TinyLlama-1.1B-Chat-v1.0
+scheme: FP8_DYNAMIC
\ No newline at end of file
diff --git a/tests/e2e/vLLM/configs/FP8/fp8_static_per_tensor.yaml b/tests/e2e/vLLM/configs/FP8/fp8_static_per_tensor.yaml
new file mode 100644
index 000000000..9d0e3c1a1
--- /dev/null
+++ b/tests/e2e/vLLM/configs/FP8/fp8_static_per_tensor.yaml
@@ -0,0 +1,6 @@
+cadence: "nightly"
+test_type: "regression"
+model: TinyLlama/TinyLlama-1.1B-Chat-v1.0
+scheme: FP8
+dataset_id: HuggingFaceH4/ultrachat_200k
+dataset_split: train_sft
\ No newline at end of file
diff --git a/tests/e2e/vLLM/configs/FP8/fp8_weight_only_channel.yaml b/tests/e2e/vLLM/configs/FP8/fp8_weight_only_channel.yaml
new file mode 100644
index 000000000..89f845279
--- /dev/null
+++ b/tests/e2e/vLLM/configs/FP8/fp8_weight_only_channel.yaml
@@ -0,0 +1,5 @@
+cadence: "nightly"
+test_type: "regression"
+model: TinyLlama/TinyLlama-1.1B-Chat-v1.0
+recipe: tests/e2e/vLLM/recipes/FP8/recipe_fp8_weight_only_channel.yaml
+scheme: FP8A16_channel
\ No newline at end of file
diff --git a/tests/e2e/vLLM/configs/FP8/fp8_weight_only_tensor.yaml b/tests/e2e/vLLM/configs/FP8/fp8_weight_only_tensor.yaml
new file mode 100644
index 000000000..1239287f2
--- /dev/null
+++ b/tests/e2e/vLLM/configs/FP8/fp8_weight_only_tensor.yaml
@@ -0,0 +1,5 @@
+cadence: "nightly"
+test_type: "regression"
+model: TinyLlama/TinyLlama-1.1B-Chat-v1.0
+recipe: tests/e2e/vLLM/recipes/FP8/recipe_fp8_weight_only_per_tensor.yaml
+scheme: FP8A16_tensor
\ No newline at end of file
diff --git a/tests/e2e/vLLM/configs/INT8/int8_channel_weight_static_per_tensor_act.yaml b/tests/e2e/vLLM/configs/INT8/int8_channel_weight_static_per_tensor_act.yaml
new file mode 100644
index 000000000..ecdd84938
--- /dev/null
+++ b/tests/e2e/vLLM/configs/INT8/int8_channel_weight_static_per_tensor_act.yaml
@@ -0,0 +1,7 @@
+cadence: "nightly"
+test_type: "regression"
+model: TinyLlama/TinyLlama-1.1B-Chat-v1.0
+recipe: tests/e2e/vLLM/recipes/INT8/recipe_int8_channel_weight_static_per_tensor_act.yaml
+dataset_id: HuggingFaceH4/ultrachat_200k
+dataset_split: train_sft
+scheme: W8A8_channel_weight_static_per_tensor
\ No newline at end of file
diff --git a/tests/e2e/vLLM/configs/INT8/int8_dynamic_per_token.yaml b/tests/e2e/vLLM/configs/INT8/int8_dynamic_per_token.yaml
new file mode 100644
index 000000000..befa14beb
--- /dev/null
+++ b/tests/e2e/vLLM/configs/INT8/int8_dynamic_per_token.yaml
@@ -0,0 +1,6 @@
+cadence: "nightly"
+test_type: "regression"
+model: TinyLlama/TinyLlama-1.1B-Chat-v1.0
+scheme: W8A8
+dataset_id: HuggingFaceH4/ultrachat_200k
+dataset_split: train_sft
\ No newline at end of file
diff --git a/tests/e2e/vLLM/configs/INT8/int8_tensor_weight_static_per_tensor_act.yaml b/tests/e2e/vLLM/configs/INT8/int8_tensor_weight_static_per_tensor_act.yaml
new file mode 100644
index 000000000..4af8e65ad
--- /dev/null
+++ b/tests/e2e/vLLM/configs/INT8/int8_tensor_weight_static_per_tensor_act.yaml
@@ -0,0 +1,7 @@
+cadence: "nightly"
+test_type: "regression"
+model: TinyLlama/TinyLlama-1.1B-Chat-v1.0
+recipe: tests/e2e/vLLM/recipes/INT8/recipe_int8_tensor_weight_static_per_tensor_act.yaml
+dataset_id: HuggingFaceH4/ultrachat_200k
+dataset_split: train_sft
+scheme: W8A8_tensor_weight_static_per_tensor_act
diff --git a/tests/e2e/vLLM/configs/WNA16/w4a16_channel_quant.yaml b/tests/e2e/vLLM/configs/WNA16/w4a16_channel_quant.yaml
new file mode 100644
index 000000000..f08a64159
--- /dev/null
+++ b/tests/e2e/vLLM/configs/WNA16/w4a16_channel_quant.yaml
@@ -0,0 +1,7 @@
+cadence: "nightly"
+test_type: "regression"
+model: TinyLlama/TinyLlama-1.1B-Chat-v1.0
+scheme: W4A16_channel
+dataset_id: HuggingFaceH4/ultrachat_200k
+dataset_split: train_sft
+recipe: tests/e2e/vLLM/recipes/WNA16/recipe_w4a16_channel_quant.yaml
\ No newline at end of file
diff --git a/tests/e2e/vLLM/configs/WNA16/w4a16_grouped_quant.yaml b/tests/e2e/vLLM/configs/WNA16/w4a16_grouped_quant.yaml
new file mode 100644
index 000000000..bbd1406ce
--- /dev/null
+++ b/tests/e2e/vLLM/configs/WNA16/w4a16_grouped_quant.yaml
@@ -0,0 +1,6 @@
+cadence: "nightly"
+test_type: "regression"
+model: TinyLlama/TinyLlama-1.1B-Chat-v1.0
+scheme: W4A16
+dataset_id: HuggingFaceH4/ultrachat_200k
+dataset_split: train_sft
\ No newline at end of file
diff --git a/tests/e2e/vLLM/configs/WNA16/w8a16_channel_quant.yaml b/tests/e2e/vLLM/configs/WNA16/w8a16_channel_quant.yaml
new file mode 100644
index 000000000..f9adbc506
--- /dev/null
+++ b/tests/e2e/vLLM/configs/WNA16/w8a16_channel_quant.yaml
@@ -0,0 +1,7 @@
+cadence: "nightly"
+test_type: "regression"
+model: TinyLlama/TinyLlama-1.1B-Chat-v1.0
+scheme: W8A16_channel
+dataset_id: HuggingFaceH4/ultrachat_200k
+dataset_split: train_sft
+recipe: tests/e2e/vLLM/recipes/WNA16/recipe_w8a16_channel_quant.yaml
\ No newline at end of file
diff --git a/tests/e2e/vLLM/configs/WNA16/w8a16_grouped_quant.yaml b/tests/e2e/vLLM/configs/WNA16/w8a16_grouped_quant.yaml
new file mode 100644
index 000000000..4e9a278a5
--- /dev/null
+++ b/tests/e2e/vLLM/configs/WNA16/w8a16_grouped_quant.yaml
@@ -0,0 +1,6 @@
+cadence: "nightly"
+test_type: "regression"
+model: TinyLlama/TinyLlama-1.1B-Chat-v1.0
+scheme: W8A16
+dataset_id: HuggingFaceH4/ultrachat_200k
+dataset_split: train_sft
\ No newline at end of file
diff --git a/tests/e2e/vLLM/recipes/FP8/recipe_fp8_weight_only_channel.yaml b/tests/e2e/vLLM/recipes/FP8/recipe_fp8_weight_only_channel.yaml
new file mode 100644
index 000000000..84d6505cb
--- /dev/null
+++ b/tests/e2e/vLLM/recipes/FP8/recipe_fp8_weight_only_channel.yaml
@@ -0,0 +1,9 @@
+quant_stage:
+  quant_modifiers:
+    QuantizationModifier:
+      sequential_update: false
+      ignore: [lm_head]
+      config_groups:
+        group_0:
+          weights: {num_bits: 8, type: float, symmetric: true, strategy: channel, dynamic: false}
+          targets: [Linear]
diff --git a/tests/e2e/vLLM/recipes/FP8/recipe_fp8_weight_only_per_tensor.yaml b/tests/e2e/vLLM/recipes/FP8/recipe_fp8_weight_only_per_tensor.yaml
new file mode 100644
index 000000000..8a6dfbde6
--- /dev/null
+++ b/tests/e2e/vLLM/recipes/FP8/recipe_fp8_weight_only_per_tensor.yaml
@@ -0,0 +1,9 @@
+quant_stage:
+  quant_modifiers:
+    QuantizationModifier:
+      sequential_update: false
+      ignore: [lm_head]
+      config_groups:
+        group_0:
+          weights: {num_bits: 8, type: float, symmetric: true, strategy: tensor, dynamic: false}
+          targets: [Linear]
diff --git a/tests/e2e/vLLM/recipes/INT8/recipe_int8_channel_weight_static_per_tensor_act.yaml b/tests/e2e/vLLM/recipes/INT8/recipe_int8_channel_weight_static_per_tensor_act.yaml
new file mode 100644
index 000000000..6cfa275af
--- /dev/null
+++ b/tests/e2e/vLLM/recipes/INT8/recipe_int8_channel_weight_static_per_tensor_act.yaml
@@ -0,0 +1,10 @@
+quant_stage:
+  quant_modifiers:
+    QuantizationModifier:
+      sequential_update: false
+      ignore: [lm_head]
+      config_groups:
+        group_0:
+          weights: {num_bits: 8, type: int, symmetric: true, strategy: channel}
+          input_activations: {num_bits: 8, type: int, symmetric: true, strategy: tensor}
+          targets: [Linear]
diff --git a/tests/e2e/vLLM/recipes/INT8/recipe_int8_tensor_weight_static_per_tensor_act.yaml b/tests/e2e/vLLM/recipes/INT8/recipe_int8_tensor_weight_static_per_tensor_act.yaml
new file mode 100644
index 000000000..6ddcc63b4
--- /dev/null
+++ b/tests/e2e/vLLM/recipes/INT8/recipe_int8_tensor_weight_static_per_tensor_act.yaml
@@ -0,0 +1,10 @@
+quant_stage:
+  quant_modifiers:
+    QuantizationModifier:
+      sequential_update: false
+      ignore: [lm_head]
+      config_groups:
+        group_0:
+          weights: {num_bits: 8, type: int, symmetric: true, strategy: tensor}
+          input_activations: {num_bits: 8, type: int, symmetric: true, strategy: tensor}
+          targets: [Linear]
diff --git a/tests/e2e/vLLM/recipes/WNA16/recipe_w4a16_channel_quant.yaml b/tests/e2e/vLLM/recipes/WNA16/recipe_w4a16_channel_quant.yaml
new file mode 100644
index 000000000..b667b2d10
--- /dev/null
+++ b/tests/e2e/vLLM/recipes/WNA16/recipe_w4a16_channel_quant.yaml
@@ -0,0 +1,9 @@
+quant_stage:
+  quant_modifiers:
+    QuantizationModifier:
+      sequential_update: false
+      ignore: [lm_head]
+      config_groups:
+        group_0:
+          weights: {num_bits: 4, type: int, symmetric: true, strategy: channel, dynamic: false}
+          targets: [Linear]
diff --git a/tests/e2e/vLLM/recipes/WNA16/recipe_w8a16_channel_quant.yaml b/tests/e2e/vLLM/recipes/WNA16/recipe_w8a16_channel_quant.yaml
new file mode 100644
index 000000000..bafd7928d
--- /dev/null
+++ b/tests/e2e/vLLM/recipes/WNA16/recipe_w8a16_channel_quant.yaml
@@ -0,0 +1,9 @@
+quant_stage:
+  quant_modifiers:
+    QuantizationModifier:
+      sequential_update: false
+      ignore: [lm_head]
+      config_groups:
+        group_0:
+          weights: {num_bits: 8, type: int, symmetric: true, strategy: channel, dynamic: false}
+          targets: [Linear]
diff --git a/tests/e2e/vLLM/test_vllm.py b/tests/e2e/vLLM/test_vllm.py
new file mode 100644
index 000000000..f348d2965
--- /dev/null
+++ b/tests/e2e/vLLM/test_vllm.py
@@ -0,0 +1,144 @@
+import shutil
+import unittest
+
+import pytest
+from datasets import load_dataset
+from parameterized import parameterized_class
+from transformers import AutoTokenizer
+
+from llmcompressor.modifiers.quantization import QuantizationModifier
+from llmcompressor.transformers import SparseAutoModelForCausalLM, oneshot
+from tests.testing_utils import parse_params, requires_gpu, requires_torch
+
+try:
+    from vllm import LLM, SamplingParams
+
+    vllm_installed = True
+except ImportError:
+    vllm_installed = False
+
+# Defines the file paths to the directories containing the test configs
+# for each of the quantization schemes
+WNA16 = "tests/e2e/vLLM/configs/WNA16"
+FP8 = "tests/e2e/vLLM/configs/FP8"
+INT8 = "tests/e2e/vLLM/configs/INT8"
+CONFIGS = [WNA16, FP8, INT8]
+
+
+@requires_gpu
+@requires_torch
+@pytest.mark.skipif(not vllm_installed, reason="vLLM is not installed, skipping test")
+@parameterized_class(parse_params(CONFIGS))
+class TestvLLM(unittest.TestCase):
+    """
+    The following test quantizes a model using a preset scheme or recipe,
+    runs the model using vLLM, and then pushes the model to the hub for
+    future use. Each test case is focused on a specific quantization type
+    (e.g. W4A16 with grouped quantization, W4A16 with channel quantization).
+    To add a new test case, a new config has to be added to one of the folders
+    listed in `CONFIGS`. If the test case is for a data type not listed
+    in `CONFIGS`, a new folder can be created and added to the list. The tests
+    run on a cadence defined by the `cadence` field. Each config defines the model
+    to quantize. Optionally, a dataset id and split can be provided for calibration.
+    Finally, all config files must list a scheme. The scheme can be a preset scheme
+    from https://github.com/neuralmagic/compressed-tensors/blob/main/src/compressed_tensors/quantization/quant_scheme.py  # noqa: E501
+    or another identifier which can be used for the particular test case. If a recipe
+    is not provided, it is assumed that the scheme provided is a preset scheme and will
+    be used for quantization. Otherwise, the recipe will always be used if given.
+    """
+
+    model = None
+    scheme = None
+    dataset_id = None
+    dataset_split = None
+    recipe = None
+
+    def setUp(self):
+        print("========== RUNNING ==============")
+        print(self.scheme)
+
+        self.save_dir = None
+        self.device = "cuda:0"
+        self.oneshot_kwargs = {}
+        self.num_calibration_samples = 256
+        self.max_seq_length = 1048
+        self.prompts = [
+            "The capital of France is",
+            "The president of the US is",
+            "My name is",
+        ]
+
+    def test_vllm(self):
+        # Load model.
+        loaded_model = SparseAutoModelForCausalLM.from_pretrained(
+            self.model, device_map=self.device, torch_dtype="auto"
+        )
+        tokenizer = AutoTokenizer.from_pretrained(self.model)
+
+        def preprocess(example):
+            return {
+                "text": tokenizer.apply_chat_template(
+                    example["messages"],
+                    tokenize=False,
+                )
+            }
+
+        def tokenize(sample):
+            return tokenizer(
+                sample["text"],
+                padding=False,
+                max_length=self.max_seq_length,
+                truncation=True,
+                add_special_tokens=False,
+            )
+
+        if self.dataset_id:
+            ds = load_dataset(self.dataset_id, split=self.dataset_split)
+            ds = ds.shuffle(seed=42).select(range(self.num_calibration_samples))
+            ds = ds.map(preprocess)
+            ds = ds.map(tokenize, remove_columns=ds.column_names)
+            self.oneshot_kwargs["dataset"] = ds
+            self.oneshot_kwargs["max_seq_length"] = self.max_seq_length
+            self.oneshot_kwargs["num_calibration_samples"] = (
+                self.num_calibration_samples
+            )
+
+        self.save_dir = self.model.split("/")[1] + f"-{self.scheme}"
+        self.oneshot_kwargs["model"] = loaded_model
+        if self.recipe:
+            self.oneshot_kwargs["recipe"] = self.recipe
+        else:
+            # Test assumes that if a recipe was not provided, the given
+            # scheme is a compatible preset scheme
+            self.oneshot_kwargs["recipe"] = QuantizationModifier(
+                targets="Linear", scheme=self.scheme, ignore=["lm_head"]
+            )
+
+        # Apply quantization.
+ print("ONESHOT KWARGS", self.oneshot_kwargs) + oneshot( + **self.oneshot_kwargs, + clear_sparse_session=True, + oneshot_device=self.device, + ) + self.oneshot_kwargs["model"].save_pretrained(self.save_dir) + tokenizer.save_pretrained(self.save_dir) + # Run vLLM with saved model + print("================= RUNNING vLLM =========================") + sampling_params = SamplingParams(temperature=0.80, top_p=0.95) + llm = LLM(model=self.save_dir) + outputs = llm.generate(self.prompts, sampling_params) + print("================= vLLM GENERATION ======================") + for output in outputs: + assert output + prompt = output.prompt + generated_text = output.outputs[0].text + print("PROMPT", prompt) + print("GENERATED TEXT", generated_text) + + print("================= UPLOADING TO HUB ======================") + self.oneshot_kwargs["model"].push_to_hub(f"nm-testing/{self.save_dir}-e2e") + tokenizer.push_to_hub(f"nm-testing/{self.save_dir}-e2e") + + def tearDown(self): + shutil.rmtree(self.save_dir) diff --git a/tests/testing_utils.py b/tests/testing_utils.py index d1d9494df..ca1f05d74 100644 --- a/tests/testing_utils.py +++ b/tests/testing_utils.py @@ -68,39 +68,49 @@ def _validate_test_config(config: dict): # Set cadence in the config. The environment must set if nightly, weekly or commit # tests are running def parse_params( - configs_directory: str, type: Optional[str] = None + configs_directory: Union[list, str], type: Optional[str] = None ) -> List[Union[dict, CustomTestConfig]]: - # parses the config file provided - assert os.path.isdir( - configs_directory - ), f"Config_directory {configs_directory} is not a directory" + # parses the config files provided config_dicts = [] - for file in os.listdir(configs_directory): - config = _load_yaml(configs_directory, file) - if not config: - continue - - cadence = os.environ.get("CADENCE", "commit") - expected_cadence = config.get("cadence") - - if not isinstance(expected_cadence, list): - expected_cadence = [expected_cadence] - if cadence in expected_cadence: - if type == "custom": - config = CustomTestConfig(**config) + + def _parse_configs_dir(current_config_dir): + assert os.path.isdir( + current_config_dir + ), f"Config_directory {current_config_dir} is not a directory" + + for file in os.listdir(current_config_dir): + config = _load_yaml(current_config_dir, file) + if not config: + continue + + cadence = os.environ.get("CADENCE", "commit") + expected_cadence = config.get("cadence") + + if not isinstance(expected_cadence, list): + expected_cadence = [expected_cadence] + if cadence in expected_cadence: + if type == "custom": + config = CustomTestConfig(**config) + else: + if not _validate_test_config(config): + raise ValueError( + "The config provided does not comply with the expected " + "structure. See tests.data.TestConfig for the expected " + "fields." + ) + config_dicts.append(config) else: - if not _validate_test_config(config): - raise ValueError( - "The config provided does not comply with the expected " - "structure. See tests.data.TestConfig for the expected " - "fields." - ) - config_dicts.append(config) - else: - logging.info( - f"Skipping testing model: {file} for cadence: {config['cadence']}" - ) + logging.info( + f"Skipping testing model: {file} for cadence: {config['cadence']}" + ) + + if isinstance(configs_directory, list): + for config in configs_directory: + _parse_configs_dir(config) + else: + _parse_configs_dir(configs_directory) + return config_dicts