@@ -24,27 +24,12 @@ def TypedDict(name, attrs, total=True): # type: ignore
24
24
# where they can be called directly by a native function, they can be wrapped
25
25
# by a native function that handles dispatch
26
26
27
- # Handle broadcasting for TH functions that need it
28
- LEGACY_TH_DECLARATION_BROADCAST = CodeTemplate ("""\
29
- ${return_type} ${api_name}(${type_method_formals});
30
- """ )
31
-
32
- LEGACY_TH_DEFINITION_BROADCAST = CodeTemplate ("""\
33
- ${return_type} ${api_name}(${type_method_formals}) {
34
- ${named_guard_declaration}
35
- ${device_guard_declaration}
36
- Tensor ${broadcast_returns};
37
- std::tie(${broadcast_returns}) = ${broadcast_function}(${broadcast_actuals}, "${api_name}");
38
- return ${method_prefix_derived}${api_name}(${broadcast_modified_actuals});
39
- }
40
- """ )
41
-
42
27
LEGACY_TH_DECLARATION = CodeTemplate ("""\
43
- ${return_type} ${method_prefix_derived}${ api_name}(${type_method_formals});
28
+ ${return_type} ${api_name}(${type_method_formals});
44
29
""" )
45
30
46
31
LEGACY_TH_DEFINITION = CodeTemplate ("""\
47
- ${return_type} ${method_prefix_derived}${ api_name}(${type_method_formals}) {
32
+ ${return_type} ${api_name}(${type_method_formals}) {
48
33
${named_guard_declaration}
49
34
${device_guard_declaration}
50
35
${type_definition_body}
@@ -414,8 +399,6 @@ def __getitem__(self, x):
414
399
'annotation' : str ,
415
400
'allocate' : bool ,
416
401
'mask' : bool ,
417
- # Broadcast is originally a str but gets unwrapped to a List or Dict in-place
418
- 'broadcast' : Any ,
419
402
'resize' : str ,
420
403
'zero' : bool ,
421
404
}, total = False )
@@ -481,10 +464,6 @@ def __getitem__(self, x):
481
464
'arguments' : List [THFormal ],
482
465
'backend_types' : Dict [str , List [str ]],
483
466
'backends' : List [str ],
484
- 'broadcast_actuals' : List [str ],
485
- 'broadcast_function' : str ,
486
- 'broadcast_modified_actuals' : List [str ],
487
- 'broadcast_returns' : List [str ],
488
467
'buffers' : List [NNBuffer ],
489
468
# cimpls is really a List[FunctionOption]
490
469
'cimpls' : List [Any ],
@@ -519,7 +498,6 @@ def __getitem__(self, x):
519
498
'method_actuals' : List [str ],
520
499
'method_formals_with_defaults' : List [str ],
521
500
'method_formals' : List [str ],
522
- 'method_prefix_derived' : str ,
523
501
'named_guard_declaration' : str ,
524
502
'mode' : str ,
525
503
'python_module' : str ,
@@ -554,7 +532,6 @@ def __getitem__(self, x):
554
532
('category_override' , str ),
555
533
('matches_jit_signature' , bool ),
556
534
('schema_string' , str ),
557
- ('method_prefix_derived' , str ),
558
535
('arguments' , List [AtFormal ]),
559
536
('method_of' , List [str ]),
560
537
('mode' , str ),
@@ -898,33 +875,6 @@ def format_return_type(return_types):
898
875
return return_types [0 ]['type' ]
899
876
return "std::tuple<{}>" .format (',' .join (r ['type' ] for r in return_types ))
900
877
901
- def get_broadcast_argument (option ):
902
- # type: (FunctionOption) -> Optional[THFormal]
903
- for argument in option ['arguments' ]:
904
- if argument .get ('broadcast' ):
905
- return argument
906
- return None
907
-
908
- def get_broadcast_actuals (broadcast_arg , broadcast_inplace , broadcast_dims ):
909
- # type: (THFormal, bool, bool) -> List[str]
910
- # Note: broadcast_dims can change type...
911
- # return the actuals that will be passed to the broadcast function.
912
- # in the broadcast_dims case (the only currently supported case), this is
913
- # the broadcasted argument (e.g. "self") followed by the sizes it is broadcasted
914
- # to (as an initializer list), so e.g. the specification:
915
- # "mat1.dim0,mat2.dim1"
916
- # gets transformed to
917
- # "self, {mat1.size(0),mat2.size(1)}"
918
- assert broadcast_dims
919
- broadcast_dims_spec = broadcast_arg ['broadcast' ].split ()[1 ].split (':' )[1 ].split (',' )
920
- # generate size call for each dimension
921
- broadcast_dims = ([x .split ('.' )[0 ] + '.size(' + x .split ('.' )[1 ].replace ('dim' , '' ) + ')' # type: ignore
922
- for x in broadcast_dims_spec ])
923
- broadcast_dims_init_list = '{' + ',' .join (broadcast_dims ) + '}' # type: ignore
924
- broadcast_actuals = [broadcast_arg ['name' ], broadcast_dims_init_list ]
925
-
926
- return broadcast_actuals
927
-
928
878
def process_legacy_th_option (option ):
929
879
# type: (FunctionOption) -> None
930
880
# Mutably populate option with derived values computed from values
@@ -961,9 +911,6 @@ def process_legacy_th_option(option):
961
911
dispatch_tensor = find_dispatch_tensor (formals )
962
912
is_namespace_function = is_function and dispatch_tensor is not None
963
913
964
- broadcast_arg = get_broadcast_argument (option )
965
- # "s_" for "same size".
966
- option ['method_prefix_derived' ] = '' if broadcast_arg is None else 's_'
967
914
if option ['mode' ] == 'TH' :
968
915
option ['device_guard' ] = False
969
916
option ['device_guard_declaration' ] = device_guard (option , False , dispatch_tensor )
@@ -973,21 +920,6 @@ def process_legacy_th_option(option):
973
920
974
921
assert option ['extended_method' ], 'Expected legacy operator to be an extended method'
975
922
976
- if broadcast_arg is not None :
977
- broadcast_inplace = 'inplace' in broadcast_arg ['broadcast' ]
978
- broadcast_dims = 'dims:' in broadcast_arg ['broadcast' ]
979
- option ['broadcast_actuals' ] = get_broadcast_actuals (broadcast_arg , broadcast_inplace , broadcast_dims )
980
- if not broadcast_dims :
981
- option ['broadcast_returns' ] = (["b_" + x for x in option ['broadcast_actuals' ]
982
- if x != broadcast_arg ['name' ] or not broadcast_inplace ])
983
- else :
984
- option ['broadcast_returns' ] = ["b_" + broadcast_arg ['name' ]]
985
-
986
- option ['broadcast_function' ] = 'expand_' + ('inplace' if broadcast_inplace
987
- else 'size' if broadcast_dims else 'outplace' )
988
- option ['broadcast_modified_actuals' ] = ['b_' + y if 'b_' + y in option ['broadcast_returns' ] else y
989
- for y in option ['actuals' ]]
990
-
991
923
def native_get_formals (option , include_constants = False ):
992
924
# type: (FunctionOption, bool) -> List[AtFormal]
993
925
seen = set () # type: Set[str]
@@ -1221,7 +1153,6 @@ def gen_namespace_function(option, multidispatch_formals):
1221
1153
1222
1154
check_methods_do_not_start_with_underscore (option ['name' ], is_method )
1223
1155
1224
- option ['method_prefix_derived' ] = ''
1225
1156
# NB: Device guard and scalar type generated code is still based on the
1226
1157
# first argument. Scalar type test will be removed once TH is removed.
1227
1158
# If you need more complex device guard behavior, you should disable
@@ -1233,10 +1164,6 @@ def gen_namespace_function(option, multidispatch_formals):
1233
1164
find_tensorlists (formals ))
1234
1165
option ['dispatch_scalar_type_declaration' ] = dispatch_scalar_type (option , dispatch_options , guard_tensor )
1235
1166
1236
- broadcast_arg = get_broadcast_argument (option )
1237
- if broadcast_arg is not None :
1238
- raise Exception ("broadcasting is not yet supported for native functions, "
1239
- "but specified for function {}" , option ['name' ])
1240
1167
top_env ['aten_ops' ].append (OPERATOR_NAME_FULL .substitute (option ))
1241
1168
1242
1169
option ['native_type_method_dispatch' ] = type_method_dispatch
@@ -1311,7 +1238,6 @@ def gen_namespace_function(option, multidispatch_formals):
1311
1238
category_override = option ['category_override' ],
1312
1239
matches_jit_signature = option ["matches_jit_signature" ],
1313
1240
schema_string = option ["schema_string" ],
1314
- method_prefix_derived = option ['method_prefix_derived' ],
1315
1241
arguments = formals ,
1316
1242
method_of = method_of ,
1317
1243
mode = option ['mode' ],
@@ -1570,11 +1496,6 @@ def process_legacy_th_option(option):
1570
1496
env = nested_dict (option , backend_type_env )
1571
1497
body = emit_body (env , option , option ['backend_types' ][backend ]) # type: ignore
1572
1498
option ['type_definition_body' ] = body
1573
- if option .get ('broadcast_actuals' , None ):
1574
- legacy_th_declarations .append (
1575
- LEGACY_TH_DECLARATION_BROADCAST .substitute (env ))
1576
- legacy_th_definitions .append (
1577
- LEGACY_TH_DEFINITION_BROADCAST .substitute (env ))
1578
1499
legacy_th_declarations .append (
1579
1500
LEGACY_TH_DECLARATION .substitute (env ))
1580
1501
legacy_th_definitions .append (
0 commit comments