Skip to content

Commit 122d821

Browse files
gchananfacebook-github-bot
authored andcommitted
[RESUBMIT] Kill broadcasting from the codegen layer. (#37907)
Summary: Pull Request resolved: #37907 Test Plan: Imported from OSS Differential Revision: D21420872 Pulled By: gchanan fbshipit-source-id: c782c0c438bcb7e764a97b446f8c3cd168e188f0
1 parent 88c447b commit 122d821

File tree

1 file changed

+2
-81
lines changed

1 file changed

+2
-81
lines changed

aten/src/ATen/function_wrapper.py

Lines changed: 2 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -24,27 +24,12 @@ def TypedDict(name, attrs, total=True): # type: ignore
2424
# where they can be called directly by a native function, they can be wrapped
2525
# by a native function that handles dispatch
2626

27-
# Handle broadcasting for TH functions that need it
28-
LEGACY_TH_DECLARATION_BROADCAST = CodeTemplate("""\
29-
${return_type} ${api_name}(${type_method_formals});
30-
""")
31-
32-
LEGACY_TH_DEFINITION_BROADCAST = CodeTemplate("""\
33-
${return_type} ${api_name}(${type_method_formals}) {
34-
${named_guard_declaration}
35-
${device_guard_declaration}
36-
Tensor ${broadcast_returns};
37-
std::tie(${broadcast_returns}) = ${broadcast_function}(${broadcast_actuals}, "${api_name}");
38-
return ${method_prefix_derived}${api_name}(${broadcast_modified_actuals});
39-
}
40-
""")
41-
4227
LEGACY_TH_DECLARATION = CodeTemplate("""\
43-
${return_type} ${method_prefix_derived}${api_name}(${type_method_formals});
28+
${return_type} ${api_name}(${type_method_formals});
4429
""")
4530

4631
LEGACY_TH_DEFINITION = CodeTemplate("""\
47-
${return_type} ${method_prefix_derived}${api_name}(${type_method_formals}) {
32+
${return_type} ${api_name}(${type_method_formals}) {
4833
${named_guard_declaration}
4934
${device_guard_declaration}
5035
${type_definition_body}
@@ -414,8 +399,6 @@ def __getitem__(self, x):
414399
'annotation': str,
415400
'allocate': bool,
416401
'mask': bool,
417-
# Broadcast is originally a str but gets unwrapped to a List or Dict in-place
418-
'broadcast': Any,
419402
'resize': str,
420403
'zero': bool,
421404
}, total=False)
@@ -481,10 +464,6 @@ def __getitem__(self, x):
481464
'arguments': List[THFormal],
482465
'backend_types': Dict[str, List[str]],
483466
'backends': List[str],
484-
'broadcast_actuals': List[str],
485-
'broadcast_function': str,
486-
'broadcast_modified_actuals': List[str],
487-
'broadcast_returns': List[str],
488467
'buffers': List[NNBuffer],
489468
# cimpls is really a List[FunctionOption]
490469
'cimpls': List[Any],
@@ -519,7 +498,6 @@ def __getitem__(self, x):
519498
'method_actuals': List[str],
520499
'method_formals_with_defaults': List[str],
521500
'method_formals': List[str],
522-
'method_prefix_derived': str,
523501
'named_guard_declaration': str,
524502
'mode': str,
525503
'python_module': str,
@@ -554,7 +532,6 @@ def __getitem__(self, x):
554532
('category_override', str),
555533
('matches_jit_signature', bool),
556534
('schema_string', str),
557-
('method_prefix_derived', str),
558535
('arguments', List[AtFormal]),
559536
('method_of', List[str]),
560537
('mode', str),
@@ -898,33 +875,6 @@ def format_return_type(return_types):
898875
return return_types[0]['type']
899876
return "std::tuple<{}>".format(','.join(r['type'] for r in return_types))
900877

901-
def get_broadcast_argument(option):
902-
# type: (FunctionOption) -> Optional[THFormal]
903-
for argument in option['arguments']:
904-
if argument.get('broadcast'):
905-
return argument
906-
return None
907-
908-
def get_broadcast_actuals(broadcast_arg, broadcast_inplace, broadcast_dims):
909-
# type: (THFormal, bool, bool) -> List[str]
910-
# Note: broadcast_dims can change type...
911-
# return the actuals that will be passed to the broadcast function.
912-
# in the broadcast_dims case (the only currently supported case), this is
913-
# the broadcasted argument (e.g. "self") followed by the sizes it is broadcasted
914-
# to (as an initializer list), so e.g. the specification:
915-
# "mat1.dim0,mat2.dim1"
916-
# gets transformed to
917-
# "self, {mat1.size(0),mat2.size(1)}"
918-
assert broadcast_dims
919-
broadcast_dims_spec = broadcast_arg['broadcast'].split()[1].split(':')[1].split(',')
920-
# generate size call for each dimension
921-
broadcast_dims = ([x.split('.')[0] + '.size(' + x.split('.')[1].replace('dim', '') + ')' # type: ignore
922-
for x in broadcast_dims_spec])
923-
broadcast_dims_init_list = '{' + ','.join(broadcast_dims) + '}' # type: ignore
924-
broadcast_actuals = [broadcast_arg['name'], broadcast_dims_init_list]
925-
926-
return broadcast_actuals
927-
928878
def process_legacy_th_option(option):
929879
# type: (FunctionOption) -> None
930880
# Mutably populate option with derived values computed from values
@@ -961,9 +911,6 @@ def process_legacy_th_option(option):
961911
dispatch_tensor = find_dispatch_tensor(formals)
962912
is_namespace_function = is_function and dispatch_tensor is not None
963913

964-
broadcast_arg = get_broadcast_argument(option)
965-
# "s_" for "same size".
966-
option['method_prefix_derived'] = '' if broadcast_arg is None else 's_'
967914
if option['mode'] == 'TH':
968915
option['device_guard'] = False
969916
option['device_guard_declaration'] = device_guard(option, False, dispatch_tensor)
@@ -973,21 +920,6 @@ def process_legacy_th_option(option):
973920

974921
assert option['extended_method'], 'Expected legacy operator to be an extended method'
975922

976-
if broadcast_arg is not None:
977-
broadcast_inplace = 'inplace' in broadcast_arg['broadcast']
978-
broadcast_dims = 'dims:' in broadcast_arg['broadcast']
979-
option['broadcast_actuals'] = get_broadcast_actuals(broadcast_arg, broadcast_inplace, broadcast_dims)
980-
if not broadcast_dims:
981-
option['broadcast_returns'] = (["b_" + x for x in option['broadcast_actuals']
982-
if x != broadcast_arg['name'] or not broadcast_inplace])
983-
else:
984-
option['broadcast_returns'] = ["b_" + broadcast_arg['name']]
985-
986-
option['broadcast_function'] = 'expand_' + ('inplace' if broadcast_inplace
987-
else 'size' if broadcast_dims else 'outplace')
988-
option['broadcast_modified_actuals'] = ['b_' + y if 'b_' + y in option['broadcast_returns'] else y
989-
for y in option['actuals']]
990-
991923
def native_get_formals(option, include_constants=False):
992924
# type: (FunctionOption, bool) -> List[AtFormal]
993925
seen = set() # type: Set[str]
@@ -1221,7 +1153,6 @@ def gen_namespace_function(option, multidispatch_formals):
12211153

12221154
check_methods_do_not_start_with_underscore(option['name'], is_method)
12231155

1224-
option['method_prefix_derived'] = ''
12251156
# NB: Device guard and scalar type generated code is still based on the
12261157
# first argument. Scalar type test will be removed once TH is removed.
12271158
# If you need more complex device guard behavior, you should disable
@@ -1233,10 +1164,6 @@ def gen_namespace_function(option, multidispatch_formals):
12331164
find_tensorlists(formals))
12341165
option['dispatch_scalar_type_declaration'] = dispatch_scalar_type(option, dispatch_options, guard_tensor)
12351166

1236-
broadcast_arg = get_broadcast_argument(option)
1237-
if broadcast_arg is not None:
1238-
raise Exception("broadcasting is not yet supported for native functions, "
1239-
"but specified for function {}", option['name'])
12401167
top_env['aten_ops'].append(OPERATOR_NAME_FULL.substitute(option))
12411168

12421169
option['native_type_method_dispatch'] = type_method_dispatch
@@ -1311,7 +1238,6 @@ def gen_namespace_function(option, multidispatch_formals):
13111238
category_override=option['category_override'],
13121239
matches_jit_signature=option["matches_jit_signature"],
13131240
schema_string=option["schema_string"],
1314-
method_prefix_derived=option['method_prefix_derived'],
13151241
arguments=formals,
13161242
method_of=method_of,
13171243
mode=option['mode'],
@@ -1570,11 +1496,6 @@ def process_legacy_th_option(option):
15701496
env = nested_dict(option, backend_type_env)
15711497
body = emit_body(env, option, option['backend_types'][backend]) # type: ignore
15721498
option['type_definition_body'] = body
1573-
if option.get('broadcast_actuals', None):
1574-
legacy_th_declarations.append(
1575-
LEGACY_TH_DECLARATION_BROADCAST.substitute(env))
1576-
legacy_th_definitions.append(
1577-
LEGACY_TH_DEFINITION_BROADCAST.substitute(env))
15781499
legacy_th_declarations.append(
15791500
LEGACY_TH_DECLARATION.substitute(env))
15801501
legacy_th_definitions.append(

0 commit comments

Comments
 (0)