Skip to content

Commit

Permalink
[RESUBMIT] Kill broadcasting from the codegen layer. (#37907)
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: #37907

Test Plan: Imported from OSS

Differential Revision: D21420872

Pulled By: gchanan

fbshipit-source-id: c782c0c438bcb7e764a97b446f8c3cd168e188f0
  • Loading branch information
gchanan authored and facebook-github-bot committed May 6, 2020
1 parent 88c447b commit 122d821
Showing 1 changed file with 2 additions and 81 deletions.
83 changes: 2 additions & 81 deletions aten/src/ATen/function_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,27 +24,12 @@ def TypedDict(name, attrs, total=True): # type: ignore
# where they can be called directly by a native function, they can be wrapped
# by a native function that handles dispatch

# Handle broadcasting for TH functions that need it
LEGACY_TH_DECLARATION_BROADCAST = CodeTemplate("""\
${return_type} ${api_name}(${type_method_formals});
""")

LEGACY_TH_DEFINITION_BROADCAST = CodeTemplate("""\
${return_type} ${api_name}(${type_method_formals}) {
${named_guard_declaration}
${device_guard_declaration}
Tensor ${broadcast_returns};
std::tie(${broadcast_returns}) = ${broadcast_function}(${broadcast_actuals}, "${api_name}");
return ${method_prefix_derived}${api_name}(${broadcast_modified_actuals});
}
""")

LEGACY_TH_DECLARATION = CodeTemplate("""\
${return_type} ${method_prefix_derived}${api_name}(${type_method_formals});
${return_type} ${api_name}(${type_method_formals});
""")

LEGACY_TH_DEFINITION = CodeTemplate("""\
${return_type} ${method_prefix_derived}${api_name}(${type_method_formals}) {
${return_type} ${api_name}(${type_method_formals}) {
${named_guard_declaration}
${device_guard_declaration}
${type_definition_body}
Expand Down Expand Up @@ -414,8 +399,6 @@ def __getitem__(self, x):
'annotation': str,
'allocate': bool,
'mask': bool,
# Broadcast is originally a str but gets unwrapped to a List or Dict in-place
'broadcast': Any,
'resize': str,
'zero': bool,
}, total=False)
Expand Down Expand Up @@ -481,10 +464,6 @@ def __getitem__(self, x):
'arguments': List[THFormal],
'backend_types': Dict[str, List[str]],
'backends': List[str],
'broadcast_actuals': List[str],
'broadcast_function': str,
'broadcast_modified_actuals': List[str],
'broadcast_returns': List[str],
'buffers': List[NNBuffer],
# cimpls is really a List[FunctionOption]
'cimpls': List[Any],
Expand Down Expand Up @@ -519,7 +498,6 @@ def __getitem__(self, x):
'method_actuals': List[str],
'method_formals_with_defaults': List[str],
'method_formals': List[str],
'method_prefix_derived': str,
'named_guard_declaration': str,
'mode': str,
'python_module': str,
Expand Down Expand Up @@ -554,7 +532,6 @@ def __getitem__(self, x):
('category_override', str),
('matches_jit_signature', bool),
('schema_string', str),
('method_prefix_derived', str),
('arguments', List[AtFormal]),
('method_of', List[str]),
('mode', str),
Expand Down Expand Up @@ -898,33 +875,6 @@ def format_return_type(return_types):
return return_types[0]['type']
return "std::tuple<{}>".format(','.join(r['type'] for r in return_types))

def get_broadcast_argument(option):
# type: (FunctionOption) -> Optional[THFormal]
for argument in option['arguments']:
if argument.get('broadcast'):
return argument
return None

def get_broadcast_actuals(broadcast_arg, broadcast_inplace, broadcast_dims):
# type: (THFormal, bool, bool) -> List[str]
# Note: broadcast_dims can change type...
# return the actuals that will be passed to the broadcast function.
# in the broadcast_dims case (the only currently supported case), this is
# the broadcasted argument (e.g. "self") followed by the sizes it is broadcasted
# to (as an initializer list), so e.g. the specification:
# "mat1.dim0,mat2.dim1"
# gets transformed to
# "self, {mat1.size(0),mat2.size(1)}"
assert broadcast_dims
broadcast_dims_spec = broadcast_arg['broadcast'].split()[1].split(':')[1].split(',')
# generate size call for each dimension
broadcast_dims = ([x.split('.')[0] + '.size(' + x.split('.')[1].replace('dim', '') + ')' # type: ignore
for x in broadcast_dims_spec])
broadcast_dims_init_list = '{' + ','.join(broadcast_dims) + '}' # type: ignore
broadcast_actuals = [broadcast_arg['name'], broadcast_dims_init_list]

return broadcast_actuals

def process_legacy_th_option(option):
# type: (FunctionOption) -> None
# Mutably populate option with derived values computed from values
Expand Down Expand Up @@ -961,9 +911,6 @@ def process_legacy_th_option(option):
dispatch_tensor = find_dispatch_tensor(formals)
is_namespace_function = is_function and dispatch_tensor is not None

broadcast_arg = get_broadcast_argument(option)
# "s_" for "same size".
option['method_prefix_derived'] = '' if broadcast_arg is None else 's_'
if option['mode'] == 'TH':
option['device_guard'] = False
option['device_guard_declaration'] = device_guard(option, False, dispatch_tensor)
Expand All @@ -973,21 +920,6 @@ def process_legacy_th_option(option):

assert option['extended_method'], 'Expected legacy operator to be an extended method'

if broadcast_arg is not None:
broadcast_inplace = 'inplace' in broadcast_arg['broadcast']
broadcast_dims = 'dims:' in broadcast_arg['broadcast']
option['broadcast_actuals'] = get_broadcast_actuals(broadcast_arg, broadcast_inplace, broadcast_dims)
if not broadcast_dims:
option['broadcast_returns'] = (["b_" + x for x in option['broadcast_actuals']
if x != broadcast_arg['name'] or not broadcast_inplace])
else:
option['broadcast_returns'] = ["b_" + broadcast_arg['name']]

option['broadcast_function'] = 'expand_' + ('inplace' if broadcast_inplace
else 'size' if broadcast_dims else 'outplace')
option['broadcast_modified_actuals'] = ['b_' + y if 'b_' + y in option['broadcast_returns'] else y
for y in option['actuals']]

def native_get_formals(option, include_constants=False):
# type: (FunctionOption, bool) -> List[AtFormal]
seen = set() # type: Set[str]
Expand Down Expand Up @@ -1221,7 +1153,6 @@ def gen_namespace_function(option, multidispatch_formals):

check_methods_do_not_start_with_underscore(option['name'], is_method)

option['method_prefix_derived'] = ''
# NB: Device guard and scalar type generated code is still based on the
# first argument. Scalar type test will be removed once TH is removed.
# If you need more complex device guard behavior, you should disable
Expand All @@ -1233,10 +1164,6 @@ def gen_namespace_function(option, multidispatch_formals):
find_tensorlists(formals))
option['dispatch_scalar_type_declaration'] = dispatch_scalar_type(option, dispatch_options, guard_tensor)

broadcast_arg = get_broadcast_argument(option)
if broadcast_arg is not None:
raise Exception("broadcasting is not yet supported for native functions, "
"but specified for function {}", option['name'])
top_env['aten_ops'].append(OPERATOR_NAME_FULL.substitute(option))

option['native_type_method_dispatch'] = type_method_dispatch
Expand Down Expand Up @@ -1311,7 +1238,6 @@ def gen_namespace_function(option, multidispatch_formals):
category_override=option['category_override'],
matches_jit_signature=option["matches_jit_signature"],
schema_string=option["schema_string"],
method_prefix_derived=option['method_prefix_derived'],
arguments=formals,
method_of=method_of,
mode=option['mode'],
Expand Down Expand Up @@ -1570,11 +1496,6 @@ def process_legacy_th_option(option):
env = nested_dict(option, backend_type_env)
body = emit_body(env, option, option['backend_types'][backend]) # type: ignore
option['type_definition_body'] = body
if option.get('broadcast_actuals', None):
legacy_th_declarations.append(
LEGACY_TH_DECLARATION_BROADCAST.substitute(env))
legacy_th_definitions.append(
LEGACY_TH_DEFINITION_BROADCAST.substitute(env))
legacy_th_declarations.append(
LEGACY_TH_DECLARATION.substitute(env))
legacy_th_definitions.append(
Expand Down

0 comments on commit 122d821

Please sign in to comment.