@@ -678,8 +678,6 @@ def fgraph_to_python(
678
678
* ,
679
679
type_conversion_fn : Callable = lambda x , ** kwargs : x ,
680
680
order : Optional [List [Apply ]] = None ,
681
- input_storage : Optional ["InputStorageType" ] = None ,
682
- output_storage : Optional ["OutputStorageType" ] = None ,
683
681
storage_map : Optional ["StorageMapType" ] = None ,
684
682
fgraph_name : str = "fgraph_to_python" ,
685
683
global_env : Optional [Dict [Any , Any ]] = None ,
@@ -704,10 +702,6 @@ def fgraph_to_python(
704
702
``(value: Optional[Any], variable: Variable=None, storage: List[Optional[Any]]=None, **kwargs)``.
705
703
order
706
704
The `order` argument to `map_storage`.
707
- input_storage
708
- The `input_storage` argument to `map_storage`.
709
- output_storage
710
- The `output_storage` argument to `map_storage`.
711
705
storage_map
712
706
The `storage_map` argument to `map_storage`.
713
707
fgraph_name
@@ -730,9 +724,9 @@ def fgraph_to_python(
730
724
731
725
if order is None :
732
726
order = fgraph .toposort ()
733
- input_storage , output_storage , storage_map = map_storage (
734
- fgraph , order , input_storage , output_storage , storage_map
735
- )
727
+
728
+ if storage_map is None :
729
+ storage_map = {}
736
730
737
731
unique_name = unique_name_generator ([fgraph_name ])
738
732
@@ -752,31 +746,38 @@ def fgraph_to_python(
752
746
node_input_names = []
753
747
for i in node .inputs :
754
748
local_input_name = unique_name (i )
755
- if storage_map [i ][0 ] is not None or isinstance (i , Constant ):
749
+ input_storage = storage_map .setdefault (
750
+ i , [None if not isinstance (i , Constant ) else i .data ]
751
+ )
752
+ if input_storage [0 ] is not None or isinstance (i , Constant ):
756
753
# Constants need to be assigned locally and referenced
757
754
global_env [local_input_name ] = type_conversion_fn (
758
- storage_map [ i ][ 0 ], variable = i , storage = storage_map [ i ] , ** kwargs
755
+ input_storage [ 0 ], variable = i , storage = input_storage , ** kwargs
759
756
)
760
757
# TODO: We could attempt to use the storage arrays directly
761
758
# E.g. `local_input_name = f"{local_input_name}[0]"`
762
759
node_input_names .append (local_input_name )
763
760
764
761
node_output_names = [unique_name (v ) for v in node .outputs ]
765
762
766
- assign_comment_str = f"{ indent (str (node ), '# ' )} "
767
763
assign_str = f"{ ', ' .join (node_output_names )} = { local_compiled_func_name } ({ ', ' .join (node_input_names )} )"
768
- body_assigns .append (f"{ assign_comment_str } \n { assign_str } " )
764
+ assign_comment_str = f"{ indent (str (node ), '# ' )} "
765
+ assign_block_str = f"{ assign_comment_str } \n { assign_str } "
766
+ body_assigns .append (assign_block_str )
769
767
770
768
# Handle `Constant`-only outputs (these don't have associated `Apply`
771
769
# nodes, so the above isn't applicable)
772
770
for out in fgraph .outputs :
773
771
if isinstance (out , Constant ):
774
- local_input_name = unique_name (out )
775
- if local_input_name not in global_env :
776
- global_env [local_input_name ] = type_conversion_fn (
777
- storage_map [out ][0 ],
772
+ local_output_name = unique_name (out )
773
+ if local_output_name not in global_env :
774
+ output_storage = storage_map .setdefault (
775
+ out , [None if not isinstance (out , Constant ) else out .data ]
776
+ )
777
+ global_env [local_output_name ] = type_conversion_fn (
778
+ output_storage [0 ],
778
779
variable = out ,
779
- storage = storage_map [ out ] ,
780
+ storage = output_storage ,
780
781
** kwargs ,
781
782
)
782
783
@@ -794,7 +795,7 @@ def fgraph_to_python(
794
795
fgraph_def_src = dedent (
795
796
f"""
796
797
def { fgraph_name } ({ ", " .join (fgraph_input_names )} ):
797
- { indent (joined_body_assigns , " " * 4 )}
798
+ { indent (joined_body_assigns , " " * 4 )}
798
799
return { fgraph_return_src }
799
800
"""
800
801
).strip ()
0 commit comments