diff --git a/python/paddle/fluid/tests/unittests/op_test.py b/python/paddle/fluid/tests/unittests/op_test.py index 5ac607018856f..260455ad61251 100644 --- a/python/paddle/fluid/tests/unittests/op_test.py +++ b/python/paddle/fluid/tests/unittests/op_test.py @@ -360,7 +360,7 @@ def try_call_once(self, data_type): def is_bfloat16_op(self): return self.dtype == np.uint16 or ( hasattr(self, 'mkldnn_data_type') and - getattr(self, 'mkldnn_data_type') is "bfloat16") + getattr(self, 'mkldnn_data_type') is "bfloat16") or ('mkldnn_data_type' in self.attrs and self.attrs['mkldnn_data_type'] == 'bfloat16') def infer_dtype_from_inputs_outputs(self, inputs, outputs): def is_np_data(input):