@@ -24,27 +24,12 @@ def TypedDict(name, attrs, total=True): # type: ignore
2424# where they can be called directly by a native function, they can be wrapped
2525# by a native function that handles dispatch
2626
27- # Handle broadcasting for TH functions that need it
28- LEGACY_TH_DECLARATION_BROADCAST = CodeTemplate ("""\
29- ${return_type} ${api_name}(${type_method_formals});
30- """ )
31-
32- LEGACY_TH_DEFINITION_BROADCAST = CodeTemplate ("""\
33- ${return_type} ${api_name}(${type_method_formals}) {
34- ${named_guard_declaration}
35- ${device_guard_declaration}
36- Tensor ${broadcast_returns};
37- std::tie(${broadcast_returns}) = ${broadcast_function}(${broadcast_actuals}, "${api_name}");
38- return ${method_prefix_derived}${api_name}(${broadcast_modified_actuals});
39- }
40- """ )
41-
4227LEGACY_TH_DECLARATION = CodeTemplate ("""\
43- ${return_type} ${method_prefix_derived}${ api_name}(${type_method_formals});
28+ ${return_type} ${api_name}(${type_method_formals});
4429""" )
4530
4631LEGACY_TH_DEFINITION = CodeTemplate ("""\
47- ${return_type} ${method_prefix_derived}${ api_name}(${type_method_formals}) {
32+ ${return_type} ${api_name}(${type_method_formals}) {
4833 ${named_guard_declaration}
4934 ${device_guard_declaration}
5035 ${type_definition_body}
@@ -414,8 +399,6 @@ def __getitem__(self, x):
414399 'annotation' : str ,
415400 'allocate' : bool ,
416401 'mask' : bool ,
417- # Broadcast is originally a str but gets unwrapped to a List or Dict in-place
418- 'broadcast' : Any ,
419402 'resize' : str ,
420403 'zero' : bool ,
421404}, total = False )
@@ -481,10 +464,6 @@ def __getitem__(self, x):
481464 'arguments' : List [THFormal ],
482465 'backend_types' : Dict [str , List [str ]],
483466 'backends' : List [str ],
484- 'broadcast_actuals' : List [str ],
485- 'broadcast_function' : str ,
486- 'broadcast_modified_actuals' : List [str ],
487- 'broadcast_returns' : List [str ],
488467 'buffers' : List [NNBuffer ],
489468 # cimpls is really a List[FunctionOption]
490469 'cimpls' : List [Any ],
@@ -519,7 +498,6 @@ def __getitem__(self, x):
519498 'method_actuals' : List [str ],
520499 'method_formals_with_defaults' : List [str ],
521500 'method_formals' : List [str ],
522- 'method_prefix_derived' : str ,
523501 'named_guard_declaration' : str ,
524502 'mode' : str ,
525503 'python_module' : str ,
@@ -554,7 +532,6 @@ def __getitem__(self, x):
554532 ('category_override' , str ),
555533 ('matches_jit_signature' , bool ),
556534 ('schema_string' , str ),
557- ('method_prefix_derived' , str ),
558535 ('arguments' , List [AtFormal ]),
559536 ('method_of' , List [str ]),
560537 ('mode' , str ),
@@ -898,33 +875,6 @@ def format_return_type(return_types):
898875 return return_types [0 ]['type' ]
899876 return "std::tuple<{}>" .format (',' .join (r ['type' ] for r in return_types ))
900877
901- def get_broadcast_argument (option ):
902- # type: (FunctionOption) -> Optional[THFormal]
903- for argument in option ['arguments' ]:
904- if argument .get ('broadcast' ):
905- return argument
906- return None
907-
908- def get_broadcast_actuals (broadcast_arg , broadcast_inplace , broadcast_dims ):
909- # type: (THFormal, bool, bool) -> List[str]
910- # Note: broadcast_dims can change type...
911- # return the actuals that will be passed to the broadcast function.
912- # in the broadcast_dims case (the only currently supported case), this is
913- # the broadcasted argument (e.g. "self") followed by the sizes it is broadcasted
914- # to (as an initializer list), so e.g. the specification:
915- # "mat1.dim0,mat2.dim1"
916- # gets transformed to
917- # "self, {mat1.size(0),mat2.size(1)}"
918- assert broadcast_dims
919- broadcast_dims_spec = broadcast_arg ['broadcast' ].split ()[1 ].split (':' )[1 ].split (',' )
920- # generate size call for each dimension
921- broadcast_dims = ([x .split ('.' )[0 ] + '.size(' + x .split ('.' )[1 ].replace ('dim' , '' ) + ')' # type: ignore
922- for x in broadcast_dims_spec ])
923- broadcast_dims_init_list = '{' + ',' .join (broadcast_dims ) + '}' # type: ignore
924- broadcast_actuals = [broadcast_arg ['name' ], broadcast_dims_init_list ]
925-
926- return broadcast_actuals
927-
928878 def process_legacy_th_option (option ):
929879 # type: (FunctionOption) -> None
930880 # Mutably populate option with derived values computed from values
@@ -961,9 +911,6 @@ def process_legacy_th_option(option):
961911 dispatch_tensor = find_dispatch_tensor (formals )
962912 is_namespace_function = is_function and dispatch_tensor is not None
963913
964- broadcast_arg = get_broadcast_argument (option )
965- # "s_" for "same size".
966- option ['method_prefix_derived' ] = '' if broadcast_arg is None else 's_'
967914 if option ['mode' ] == 'TH' :
968915 option ['device_guard' ] = False
969916 option ['device_guard_declaration' ] = device_guard (option , False , dispatch_tensor )
@@ -973,21 +920,6 @@ def process_legacy_th_option(option):
973920
974921 assert option ['extended_method' ], 'Expected legacy operator to be an extended method'
975922
976- if broadcast_arg is not None :
977- broadcast_inplace = 'inplace' in broadcast_arg ['broadcast' ]
978- broadcast_dims = 'dims:' in broadcast_arg ['broadcast' ]
979- option ['broadcast_actuals' ] = get_broadcast_actuals (broadcast_arg , broadcast_inplace , broadcast_dims )
980- if not broadcast_dims :
981- option ['broadcast_returns' ] = (["b_" + x for x in option ['broadcast_actuals' ]
982- if x != broadcast_arg ['name' ] or not broadcast_inplace ])
983- else :
984- option ['broadcast_returns' ] = ["b_" + broadcast_arg ['name' ]]
985-
986- option ['broadcast_function' ] = 'expand_' + ('inplace' if broadcast_inplace
987- else 'size' if broadcast_dims else 'outplace' )
988- option ['broadcast_modified_actuals' ] = ['b_' + y if 'b_' + y in option ['broadcast_returns' ] else y
989- for y in option ['actuals' ]]
990-
991923 def native_get_formals (option , include_constants = False ):
992924 # type: (FunctionOption, bool) -> List[AtFormal]
993925 seen = set () # type: Set[str]
@@ -1221,7 +1153,6 @@ def gen_namespace_function(option, multidispatch_formals):
12211153
12221154 check_methods_do_not_start_with_underscore (option ['name' ], is_method )
12231155
1224- option ['method_prefix_derived' ] = ''
12251156 # NB: Device guard and scalar type generated code is still based on the
12261157 # first argument. Scalar type test will be removed once TH is removed.
12271158 # If you need more complex device guard behavior, you should disable
@@ -1233,10 +1164,6 @@ def gen_namespace_function(option, multidispatch_formals):
12331164 find_tensorlists (formals ))
12341165 option ['dispatch_scalar_type_declaration' ] = dispatch_scalar_type (option , dispatch_options , guard_tensor )
12351166
1236- broadcast_arg = get_broadcast_argument (option )
1237- if broadcast_arg is not None :
1238- raise Exception ("broadcasting is not yet supported for native functions, "
1239- "but specified for function {}" , option ['name' ])
12401167 top_env ['aten_ops' ].append (OPERATOR_NAME_FULL .substitute (option ))
12411168
12421169 option ['native_type_method_dispatch' ] = type_method_dispatch
@@ -1311,7 +1238,6 @@ def gen_namespace_function(option, multidispatch_formals):
13111238 category_override = option ['category_override' ],
13121239 matches_jit_signature = option ["matches_jit_signature" ],
13131240 schema_string = option ["schema_string" ],
1314- method_prefix_derived = option ['method_prefix_derived' ],
13151241 arguments = formals ,
13161242 method_of = method_of ,
13171243 mode = option ['mode' ],
@@ -1570,11 +1496,6 @@ def process_legacy_th_option(option):
15701496 env = nested_dict (option , backend_type_env )
15711497 body = emit_body (env , option , option ['backend_types' ][backend ]) # type: ignore
15721498 option ['type_definition_body' ] = body
1573- if option .get ('broadcast_actuals' , None ):
1574- legacy_th_declarations .append (
1575- LEGACY_TH_DECLARATION_BROADCAST .substitute (env ))
1576- legacy_th_definitions .append (
1577- LEGACY_TH_DEFINITION_BROADCAST .substitute (env ))
15781499 legacy_th_declarations .append (
15791500 LEGACY_TH_DECLARATION .substitute (env ))
15801501 legacy_th_definitions .append (
0 commit comments