Skip to content

Commit

Permalink
fix some bugs, support python_out_sig
Browse files Browse the repository at this point in the history
  • Loading branch information
2742195759 committed Mar 24, 2022
1 parent 733fc9c commit f864790
Showing 1 changed file with 12 additions and 8 deletions.
20 changes: 12 additions & 8 deletions python/paddle/fluid/tests/unittests/op_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -731,11 +731,12 @@ def parse_attri_value(name, op_inputs, op_attrs):
if name in op_proto_attrs:
return op_proto_attrs[name]
elif name in op_inputs:
if len(op_inputs[name]) == 1:
if len(op_inputs[name]) == 1:
# why don't use numpy().item() : if the Tensor is float64, we will change it to python.float32, where we loss accuracy: [allclose_op]
# why we reconstruct a tensor: because we want the tensor in cpu.
return paddle.to_tensor(op_inputs[name][0].numpy(), place='cpu')
else:
return paddle.to_tensor(
op_inputs[name][0].numpy(), place='cpu')
else:
# if this is a list (test_unsqueeze2_op): we just pass it into the python api.
return op_inputs[name]
else:
Expand Down Expand Up @@ -828,7 +829,9 @@ def _get_kernel_signature(eager_tensor_inputs, eager_tensor_outputs,
""" we think the kernel_sig is missing.
"""
kernel_sig = None
print ("[Warning: op_test.py] Kernel Signature is not found for %s, fall back to intermediate state." % self.op_type)
print(
"[Warning: op_test.py] Kernel Signature is not found for %s, fall back to intermediate state."
% self.op_type)
return kernel_sig

def cal_python_api(python_api, args, kernel_sig):
Expand Down Expand Up @@ -1946,16 +1949,17 @@ def _get_dygraph_grad(self,
attrs_outputs[attrs_name] = self.attrs[attrs_name]

if check_eager:
eager_outputs = self._calc_python_api_output(place, inputs, outputs)
eager_outputs = self._calc_python_api_output(place, inputs,
outputs)
# if outputs is None, kernel sig is empty or other error is happens.
if not check_eager or outputs is None:
if not check_eager or eager_outputs is None:
block.append_op(
type=self.op_type,
inputs=inputs,
outputs=outputs,
attrs=attrs_outputs if hasattr(self, "attrs") else None)
else:
output = eager_outputs
else:
outputs = eager_outputs

if self.dtype == np.uint16:
cast_inputs = self._find_var_in_dygraph(outputs,
Expand Down

0 comments on commit f864790

Please sign in to comment.