From 433f477e3f3acd671f18b2a6d77d24390e0e1ce6 Mon Sep 17 00:00:00 2001 From: fxmarty <9808326+fxmarty@users.noreply.github.com> Date: Fri, 23 Dec 2022 15:23:08 +0100 Subject: [PATCH] Allow kwargs in all generate_dummy_inputs() methods (#638) fix generate_dummy_inputs args --- optimum/exporters/onnx/model_configs.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py index ee99668144..d3f9cafb7c 100644 --- a/optimum/exporters/onnx/model_configs.py +++ b/optimum/exporters/onnx/model_configs.py @@ -555,8 +555,8 @@ def outputs(self) -> Mapping[str, Mapping[int, str]]: "pooler_output": {0: "batch_size", 1: "feature_dim"}, } - def generate_dummy_inputs(self, framework: str = "pt"): - dummy_inputs = super().generate_dummy_inputs(framework=framework) + def generate_dummy_inputs(self, framework: str = "pt", **kwargs): + dummy_inputs = super().generate_dummy_inputs(framework=framework, **kwargs) if framework == "pt": import torch @@ -599,8 +599,8 @@ def outputs(self) -> Mapping[str, Mapping[int, str]]: def output_names_for_validation(self, reference_output_names: List[str]) -> List[str]: return ["sample"] - def generate_dummy_inputs(self, framework: str = "pt"): - dummy_inputs = super().generate_dummy_inputs(framework=framework) + def generate_dummy_inputs(self, framework: str = "pt", **kwargs): + dummy_inputs = super().generate_dummy_inputs(framework=framework, **kwargs) dummy_inputs["encoder_hidden_states"] = dummy_inputs["encoder_hidden_states"][0] return dummy_inputs