diff --git a/optimum/executorch/modeling.py b/optimum/executorch/modeling.py
index adb3b4a2..1c3af49d 100644
--- a/optimum/executorch/modeling.py
+++ b/optimum/executorch/modeling.py
@@ -92,8 +92,12 @@ def __init__(self, models: Dict[str, "ExecuTorchModule"], config: "PretrainedCon
                 f"This attribute is used to identify the corresponding AutoModel class."
             )

-        for key, value in models.items():
-            setattr(self, key, value)
+        if len(models) == 1:
+            # For single PTE, always set the attr to "model"
+            setattr(self, "model", next(iter(models.values())))
+        else:
+            for key, value in models.items():
+                setattr(self, key, value)

         self.stats = Stats()
@@ -570,8 +574,8 @@ class ExecuTorchModelForCausalLM(ExecuTorchModelBase):
             Data type of the model parameters.
         bos_token_id (`int`):
             Beginning-of-sequence token ID.
-        eos_token_id (`int`):
-            End-of-sequence token ID.
+        eos_token_ids (`List[int]`):
+            End-of-sequence token IDs.
         vocab_size (`int`):
             Size of the model vocabulary.
     """
@@ -594,8 +598,10 @@ def __init__(self, models: Dict[str, "ExecuTorchModule"], config: "PretrainedCon
             self.dtype = self.model.run_method("get_dtype")[0]
         if "get_bos_id" in metadata:
             self.bos_token_id = self.model.run_method("get_bos_id")[0]
-        if "get_eos_id" in metadata:
-            self.eos_token_id = self.model.run_method("get_eos_id")[0]
+        for key in ("get_eos_id", "get_eos_ids"):
+            if key in metadata:
+                self.eos_token_ids = self.model.run_method(key)
+                break
         if "get_vocab_size" in metadata:
             self.vocab_size = self.model.run_method("get_vocab_size")[0]
         if "use_sdpa_with_kv_cache" in metadata:
@@ -694,7 +700,7 @@ def generate(
             next_token = torch.argmax(logits, dim=-1).item()
             generated_tokens.append(next_token)
-            if next_token == self.eos_token_id:
+            if next_token in self.eos_token_ids:
                 break

         self.stats.set_num_generated_tokens(len(generated_tokens) - len(prompt_tokens))
@@ -730,9 +736,9 @@ def text_generation(
             raise ValueError(
                 f"The tokenizer's bos_token_id={self.tokenizer.bos_token_id} must be the same as the model's bos_token_id={self.bos_token_id}."
             )
-        if self.tokenizer.eos_token_id is not None and self.tokenizer.eos_token_id != self.eos_token_id:
+        if self.tokenizer.eos_token_id is not None and self.tokenizer.eos_token_id not in self.eos_token_ids:
             raise ValueError(
-                f"The tokenizer's eos_token_id={self.tokenizer.eos_token_id} must be the same as the model's eos_token_id={self.eos_token_id}."
+                f"The tokenizer's eos_token_id={self.tokenizer.eos_token_id} must be one of the model's eos_token_ids={self.eos_token_ids}."
             )

         # Reset stats for a new generation
diff --git a/optimum/executorch/passes/remove_padding_idx_embedding_pass.py b/optimum/executorch/passes/remove_padding_idx_embedding_pass.py
new file mode 100644
index 00000000..b3578175
--- /dev/null
+++ b/optimum/executorch/passes/remove_padding_idx_embedding_pass.py
@@ -0,0 +1,22 @@
+import torch
+from executorch.exir.dialects._ops import ops as exir_ops
+from executorch.exir.pass_base import ExportPass, PassResult
+
+
+class RemovePaddingIdxEmbeddingPass(ExportPass):
+    """
+    An ExportPass that removes the trailing `padding_idx` argument
+    from all aten.embedding.default operator calls.
+ """ + + def __init__(self) -> None: + super().__init__() + + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: + for node in graph_module.graph.nodes: + if node.op == "call_function" and node.target == exir_ops.edge.aten.embedding.default: + # node.args[2] is the padding_idx + if len(node.args) == 3: + node.args = (node.args[0], node.args[1]) + graph_module.recompile() + return PassResult(graph_module, True) diff --git a/optimum/exporters/executorch/recipes/xnnpack.py b/optimum/exporters/executorch/recipes/xnnpack.py index ce0220dd..14970e62 100644 --- a/optimum/exporters/executorch/recipes/xnnpack.py +++ b/optimum/exporters/executorch/recipes/xnnpack.py @@ -28,6 +28,7 @@ ExecutorchProgram, to_edge_transform_and_lower, ) +from optimum.executorch.passes.remove_padding_idx_embedding_pass import RemovePaddingIdxEmbeddingPass from ..integrations import ( CausalLMExportableModule, @@ -76,9 +77,11 @@ def _lower_to_executorch( exported_program, partitioner=[XnnpackPartitioner()], compile_config=EdgeCompileConfig( + _check_ir_validity=False, _skip_dim_order=True, ), constant_methods=metadata, + transform_passes=[RemovePaddingIdxEmbeddingPass()], ).to_executorch( config=ExecutorchBackendConfig(**backend_config_dict), ) diff --git a/optimum/exporters/executorch/tasks/causal_lm.py b/optimum/exporters/executorch/tasks/causal_lm.py index 2dca1221..b6df9048 100644 --- a/optimum/exporters/executorch/tasks/causal_lm.py +++ b/optimum/exporters/executorch/tasks/causal_lm.py @@ -77,6 +77,11 @@ def load_causal_lm_model(model_name_or_path: str, **kwargs) -> CausalLMExportabl ), ) + for param in eager_model.parameters(): + # Must disable gradient for quantized checkpoint + if isinstance(param, torchao.utils.TorchAOBaseTensor): + param.requires_grad = False + # TODO: Move quantization recipe out for better composability. # TODO: Should switch to `TorchAoConfig` once the quant issue on final lm_head layer is fixed. 
     qlinear_config = kwargs.get("qlinear", None)
diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py
index d3ff1b8b..4f30dc03 100644
--- a/tests/models/test_modeling_common.py
+++ b/tests/models/test_modeling_common.py
@@ -153,3 +153,33 @@ def test_eager_text_generation_with_custom_sdpa(self):
         eager_generated_text = tokenizer.batch_decode(eager_generated_ids, skip_special_tokens=True)[0]
         logging.info(f"\nEager generated text:\n\t{eager_generated_text}")
         self.assertTrue(check_causal_lm_output_quality(model_id, eager_generated_ids))
+
+    def test_removing_padding_idx_embedding_pass(self):
+        class ModuleWithEmbedding(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.emb = torch.nn.Embedding(10, 3, padding_idx=0)
+
+            def forward(self, x):
+                return self.emb(x) + torch.ops.aten.embedding.default(self.emb.weight, x, padding_idx=1)
+
+        test_model = ModuleWithEmbedding()
+        example_inputs = (torch.tensor([1, 2, 4, 5, 4, 3, 2, 9]),)
+        exported_model = torch.export.export(test_model, example_inputs)
+
+        from executorch.exir import to_edge_transform_and_lower
+        from executorch.exir.dialects._ops import ops as exir_ops
+
+        from optimum.executorch.passes.remove_padding_idx_embedding_pass import RemovePaddingIdxEmbeddingPass
+
+        et_model = to_edge_transform_and_lower(
+            exported_model,
+            transform_passes=[RemovePaddingIdxEmbeddingPass()],
+        )
+        self.assertTrue(
+            all(
+                len(node.args) < 3
+                for node in et_model.exported_program().graph_module.graph.nodes
+                if node.op == "call_function" and node.target == exir_ops.edge.aten.embedding.default
+            )
+        )
diff --git a/tests/models/test_modeling_phi4.py b/tests/models/test_modeling_phi4.py
index 3baa511b..7c989e10 100644
--- a/tests/models/test_modeling_phi4.py
+++ b/tests/models/test_modeling_phi4.py
@@ -15,10 +15,13 @@
 import gc
 import logging
+import os
 import unittest

 import pytest
+import torchao
 from executorch.extension.pybindings.portable_lib import ExecuTorchModule
+from packaging.version import parse
 from transformers import AutoConfig, AutoTokenizer
 from transformers.testing_utils import slow
@@ -27,13 +30,21 @@
 from ..utils import check_causal_lm_output_quality


-@pytest.mark.skip(reason="Test Phi-4-mini (3.8B) will require runner to be configured with larger RAM")
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+is_ci = os.environ.get("GITHUB_ACTIONS") == "true"
+
+
 class ExecuTorchModelIntegrationTest(unittest.TestCase):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)

     @slow
     @pytest.mark.run_slow
+    @pytest.mark.skipif(
+        is_ci,
+        reason="Phi-4-mini (3.8B) requires a runner configured with more RAM",
+    )
     def test_phi4_text_generation(self):
         model_id = "microsoft/Phi-4-mini-instruct"
         config = AutoConfig.from_pretrained(model_id)
@@ -61,3 +72,92 @@ def test_phi4_text_generation(self):
         gc.collect()

         self.assertTrue(check_causal_lm_output_quality(model_id, generated_tokens))
+
+    @slow
+    @pytest.mark.run_slow
+    @pytest.mark.skipif(
+        parse(torchao.__version__) < parse("0.11.0.dev0"),
+        reason="Only available on torchao >= 0.11.0.dev0",
+    )
+    def test_phi4_text_generation_with_quantized_pte_from_hub(self):
+        model_id = "pytorch/Phi-4-mini-instruct-8da4w"
+        config = AutoConfig.from_pretrained(model_id)
+        # NOTE: To make the model exportable we need to set the rope scaling to default to avoid hitting
+        # the data-dependent control flow in _longrope_frequency_update. Alternatively, we can rewrite
+        # that function to avoid the data-dependent control flow.
+        if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
+            config.rope_scaling["type"] = "default"
+        model = ExecuTorchModelForCausalLM.from_pretrained(
+            model_id, recipe="xnnpack", config=config, file_name="phi4-mini-8da4w.pte"
+        )
+        self.assertIsInstance(model, ExecuTorchModelForCausalLM)
+        self.assertIsInstance(model.model, ExecuTorchModule)
+
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+        generated_text = model.text_generation(
+            tokenizer=tokenizer,
+            prompt="My favourite condiment is ",
+            max_seq_len=64,
+        )
+        logging.info(f"\nGenerated text:\n\t{generated_text}")
+
+        if not is_ci:
+            generated_tokens = tokenizer(generated_text, return_tensors="pt").input_ids
+
+            # Free memory before loading eager for quality check
+            del model
+            del tokenizer
+            gc.collect()
+
+            self.assertTrue(
+                check_causal_lm_output_quality(
+                    "microsoft/Phi-4-mini-instruct",
+                    generated_tokens,
+                )
+            )
+
+    @slow
+    @pytest.mark.run_slow
+    @pytest.mark.skipif(
+        parse(torchao.__version__) < parse("0.11.0.dev0"),
+        reason="Only available on torchao >= 0.11.0.dev0",
+    )
+    def test_phi4_text_generation_with_quantized_ckp(self):
+        model_id = "pytorch/Phi-4-mini-instruct-8da4w"
+        config = AutoConfig.from_pretrained(model_id)
+        # NOTE: To make the model exportable we need to set the rope scaling to default to avoid hitting
+        # the data-dependent control flow in _longrope_frequency_update. Alternatively, we can rewrite
+        # that function to avoid the data-dependent control flow.
+        if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
+            config.rope_scaling["type"] = "default"
+        model = ExecuTorchModelForCausalLM.from_pretrained(
+            model_id,
+            recipe="xnnpack",
+            config=config,
+            export=True,
+        )
+        self.assertIsInstance(model, ExecuTorchModelForCausalLM)
+        self.assertIsInstance(model.model, ExecuTorchModule)
+
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+        generated_text = model.text_generation(
+            tokenizer=tokenizer,
+            prompt="My favourite condiment is ",
+            max_seq_len=64,
+        )
+        logging.info(f"\nGenerated text:\n\t{generated_text}")
+
+        if not is_ci:
+            generated_tokens = tokenizer(generated_text, return_tensors="pt").input_ids
+
+            # Free memory before loading eager for quality check
+            del model
+            del tokenizer
+            gc.collect()
+
+            self.assertTrue(
+                check_causal_lm_output_quality(
+                    "microsoft/Phi-4-mini-instruct",
+                    generated_tokens,
+                )
+            )
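
For reviewers who want to poke at the graph rewrite in isolation, below is a minimal sketch that mirrors what RemovePaddingIdxEmbeddingPass does, but on the plain aten graph produced by torch.export rather than the Edge-dialect graph the pass actually runs on inside to_edge_transform_and_lower. The TinyEmbedding module and its example input are hypothetical stand-ins used only for illustration.

# Sketch only (assumes just PyTorch, not ExecuTorch): the real pass matches
# exir_ops.edge.aten.embedding.default; this version operates on the aten graph
# from torch.export so the effect of dropping padding_idx is easy to inspect.
import torch


class TinyEmbedding(torch.nn.Module):
    # Hypothetical toy module: padding_idx=0 makes the exported
    # aten.embedding.default call carry a third positional argument.
    def __init__(self):
        super().__init__()
        self.emb = torch.nn.Embedding(10, 3, padding_idx=0)

    def forward(self, x):
        return self.emb(x)


ep = torch.export.export(TinyEmbedding(), (torch.tensor([0, 1, 2, 3]),))
gm = ep.graph_module  # mutated in place here for inspection only

for node in gm.graph.nodes:
    if node.op == "call_function" and node.target == torch.ops.aten.embedding.default:
        if len(node.args) >= 3:
            # Keep only (weight, indices) and drop padding_idx, as the pass does.
            node.args = node.args[:2]

gm.recompile()
print(gm.code)  # the embedding call no longer passes padding_idx

Since padding_idx only affects gradient masking during training and not the forward lookup, dropping the argument should leave inference numerics unchanged; in the XNNPACK recipe the same rewrite is applied through transform_passes=[RemovePaddingIdxEmbeddingPass()] so it composes with the rest of the to_edge_transform_and_lower pipeline.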