1313# limitations under the License.
1414
1515import logging
16- from typing import Dict , Union
16+ from itertools import product
17+ from typing import Any , Dict , Union
1718
1819from tabulate import tabulate
1920from torch .export import ExportedProgram
3435from ..recipe_registry import register_recipe
3536
3637
37- @register_recipe ("coreml" )
38- def export_to_executorch_with_coreml (
38+ def _export_to_executorch (
3939 model : Union [CausalLMExportableModule , MaskedLMExportableModule , Seq2SeqLMExportableModule ],
4040 ** kwargs ,
4141):
@@ -63,23 +63,14 @@ def export_to_executorch_with_coreml(
6363
6464 def _lower_to_executorch (
6565 exported_programs : Dict [str , ExportedProgram ],
66- metadata = None ,
67- ** kwargs ,
66+ metadata ,
67+ compute_unit ,
68+ minimum_deployment_target ,
69+ compute_precision ,
6870 ) -> Dict [str , ExecutorchProgram ]:
69- compute_unit = kwargs .get ("compute_unit" , ct .ComputeUnit .ALL )
70- minimum_deployment_target = kwargs .get ("minimum_deployment_target" , ct .target .iOS15 )
71- compute_precision = kwargs .get ("compute_precision" , ct .precision .FLOAT16 )
72- model_type = kwargs .get ("model_type" , "model" )
73- model_type = {
74- "model" : CoreMLBackend .MODEL_TYPE .MODEL ,
75- "modelc" : CoreMLBackend .MODEL_TYPE .COMPILED_MODEL ,
76- }[model_type ]
77- take_over_mutable_buffer = kwargs .get ("take_over_mutable_buffer" , True )
78-
7971 et_progs = {}
8072 backend_config_dict = {}
8173 for pte_name , exported_program in exported_programs .items ():
82- exported_program = exported_program .run_decompositions ({})
8374 logging .debug (f"\n Exported program for { pte_name } .pte: { exported_program } " )
8475 et_progs [pte_name ] = to_edge_transform_and_lower (
8576 exported_program ,
@@ -89,14 +80,15 @@ def _lower_to_executorch(
8980 compute_unit = compute_unit ,
9081 minimum_deployment_target = minimum_deployment_target ,
9182 compute_precision = compute_precision ,
92- model_type = model_type ,
83+ model_type = CoreMLBackend . MODEL_TYPE . MODEL ,
9384 ),
94- take_over_mutable_buffer = take_over_mutable_buffer ,
85+ take_over_mutable_buffer = ( minimum_deployment_target >= ct . target . iOS18 ) ,
9586 )
9687 ],
9788 compile_config = EdgeCompileConfig (
9889 _check_ir_validity = False ,
99- _skip_dim_order = False ,
90+ # In ET 0.7, we can set _skip_dim_order=False
91+ _skip_dim_order = True ,
10092 ),
10193 constant_methods = metadata ,
10294 ).to_executorch (
@@ -114,3 +106,46 @@ def _lower_to_executorch(
114106
115107 exported_progs = model .export ()
116108 return _lower_to_executorch (exported_progs , model .metadata , ** kwargs )
109+
110+
111+ def _get_recipe_kwargs (dtype : str , compute_unit : str ) -> Dict [str , Any ]:
112+ import coremltools as ct
113+
114+ compute_precision = {
115+ "fp16" : ct .precision .FLOAT16 ,
116+ "fp32" : ct .precision .FLOAT32 ,
117+ }[dtype ]
118+
119+ compute_unit = {
120+ "cpu" : ct .ComputeUnit .CPU_ONLY ,
121+ "gpu" : ct .ComputeUnit .CPU_AND_GPU ,
122+ "ne" : ct .ComputeUnit .CPU_AND_NE ,
123+ "all" : ct .ComputeUnit .ALL ,
124+ }[compute_unit ]
125+
126+ recipe_kwargs = {
127+ "compute_precision" : compute_precision ,
128+ "compute_unit" : compute_unit ,
129+ "minimum_deployment_target" : ct .target .iOS18 ,
130+ }
131+ return recipe_kwargs
132+
133+
134+ def _make_recipe (recipe_name , recipe_kwargs ):
135+ @register_recipe (recipe_name )
136+ def recipe_fn (exported_programs : Dict [str , ExportedProgram ], ** kwargs ):
137+ return _export_to_executorch (
138+ exported_programs ,
139+ ** recipe_kwargs ,
140+ )
141+
142+ return recipe_fn
143+
144+
145+ # Register recipes for CoreML backend
146+ for dtype , compute_unit in product (["fp32" , "fp16" ], ["cpu" , "gpu" , "ne" , "all" ]):
147+ recipe_name = f"coreml_{ dtype } "
148+ if compute_unit != "all" :
149+ recipe_name += f"_{ compute_unit } "
150+ recipe_kwargs = _get_recipe_kwargs (dtype = dtype , compute_unit = compute_unit )
151+ _make_recipe (recipe_name , recipe_kwargs )
0 commit comments