1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414
15- from typing import Dict
15+ from typing import Dict , Optional
1616
1717import torch
1818from torch .export import ExportedProgram
@@ -43,7 +43,13 @@ def __init__(self, model):
4343 self .config = model .config
4444 self .metadata = save_config_to_constant_methods (model .config , model .generation_config )
4545
46- def export (self , input_ids = None , cache_position = None ) -> Dict [str , ExportedProgram ]:
46+ def export (
47+ self ,
48+ input_ids = None ,
49+ cache_position = None ,
50+ dynamic_shapes : Optional [dict ] = None ,
51+ strict : Optional [bool ] = None ,
52+ ) -> Dict [str , ExportedProgram ]:
4753 example_input_ids = input_ids if input_ids is not None else torch .tensor ([[1 ]], dtype = torch .long )
4854 example_cache_position = cache_position if cache_position is not None else torch .tensor ([0 ], dtype = torch .long )
4955
@@ -57,13 +63,17 @@ def export(self, input_ids=None, cache_position=None) -> Dict[str, ExportedProgr
5763 exportable_module = TorchExportableModuleForDecoderOnlyLM (self .model , max_batch_size , max_cache_len )
5864
5965 with torch .no_grad ():
60- exported_program = exportable_module .export (example_input_ids , example_cache_position )
66+ exported_program = exportable_module .export (
67+ example_input_ids , example_cache_position , dynamic_shapes , strict
68+ )
6169 else :
6270 from transformers .integrations .executorch import (
6371 convert_and_export_with_cache ,
6472 )
6573
66- exported_program = convert_and_export_with_cache (self .model , example_input_ids , example_cache_position )
74+ exported_program = convert_and_export_with_cache (
75+ self .model , example_input_ids , example_cache_position , dynamic_shapes , strict
76+ )
6777
6878 return {"model" : exported_program }
6979
0 commit comments