Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 0 additions & 19 deletions aten/src/ATen/function_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,8 +399,6 @@ def __getitem__(self, x):
'annotation': str,
'allocate': bool,
'mask': bool,
'resize': str,
'zero': bool,
}, total=False)

# Generic ATen formal or native_functions.yaml formal argument.
Expand Down Expand Up @@ -1339,15 +1337,6 @@ def allocate_arg(arg, output_count, backend, scalar_name):
'auto {} = Tensor({}::reclaim({}));'.format(name, intrusive_ptr_type, tensor_arg),
]

def resize_arg(arg):
    # type: (THFormal) -> str
    """Emit a C++ statement that resizes ``arg['name']`` to a reference size.

    ``arg['resize']`` takes one of two forms:
      * a tensor name (str) — copy that tensor's full size via ``.sizes()``;
      * a list of ``(tensor_name, dim)`` pairs — build an explicit size list
        where each entry is one dimension of the named tensor.
    """
    target = arg['resize']
    if isinstance(target, str):
        # Match the whole shape of the named reference tensor.
        return "{}.resize_({}.sizes());".format(arg['name'], target)
    # Assemble a brace-enclosed size list, one entry per (tensor, dim) pair.
    sizes = ','.join('{}.size({})'.format(tensor, d) for tensor, d in target)
    return "{}.resize_({{ {} }});".format(arg['name'], sizes)

def handle_call(env, option, cimpl):
# type: (Environment, FunctionOption, FunctionOption) -> str
is_nn = option['mode'] == 'NN'
Expand Down Expand Up @@ -1433,14 +1422,6 @@ def emit_body(env, option, scalar_type_cases):

initializers = []

# resize tensors for special ops that require it
if 'resize' in arg:
initializers.append(resize_arg(arg))

# also special handling where we zero some outputs.
if arg.get('zero', False):
initializers.append("{}.zero_();".format(arg['name']))

# only initialize non-null arguments
if nullable_argument(arg) and len(initializers) > 0:
case_body.append(CONDITIONAL_INITIALIZER.substitute({
Expand Down
9 changes: 0 additions & 9 deletions aten/src/ATen/nn_parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,15 +294,6 @@ def initialize_output_arg(arg):
arg['mask'] = True
arg['is_nullable'] = True

# grad_weight and grad_bias need to be resized and zeroed
if arg['name'] == 'grad_weight' and base['name'] != '_thnn_conv2d' and base['name'] != '_thnn_conv_depthwise2d':
arg['resize'] = 'weight'
arg['zero'] = True
if arg['name'] == 'grad_bias' and base['name'] != '_thnn_conv2d' and base['name'] != '_thnn_conv_depthwise2d':
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hahahahah NEVER MIND EARLIER COMMENTS

dim = 1 if 'transpose' in name else 0
arg['resize'] = [('weight', dim)]
arg['zero'] = True

is_batch_norm_backward = '_backward' in thnn_functions[0].name
grad_params = []
if len(thnn_functions) > 1 or is_batch_norm_backward:
Expand Down