Skip to content

Commit

Permalink
CI fix
Browse files Browse the repository at this point in the history
  • Loading branch information
jakpiase committed Jul 15, 2021
1 parent 1f4b963 commit 2104d0d
Showing 1 changed file with 17 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,15 @@ def set_dtype_attr(self):
def test_check_output(self):
self.check_output_with_place(core.CPUPlace())

def tmttml(self, x, transpose_x, y, transpose_y):
def test_check_grad(self):
self.calculate_grads()
self.check_grad_with_place(
core.CPUPlace(), ["X", "Y"],
"Out",
user_defined_grads=[self.dx, self.dy],
user_defined_grad_outputs=[convert_float_to_uint16(self.dout)])

def matmul_grad(self, x, transpose_x, y, transpose_y):
x = np.transpose(
x, self.shape_transpose_axes[x.ndim]) if transpose_x else x
y = np.transpose(
Expand Down Expand Up @@ -296,19 +304,19 @@ def calculate_grads(self):
is_broadcast = x.shape[0:-2] != y.shape[0:-2]

if self.attrs['trans_x'] is True and self.attrs['trans_y'] is True:
self.dx = self.tmttml(self.y_fp32, True, dout, True)
self.dy = self.tmttml(dout, True, self.x_fp32, True)
self.dx = self.matmul_grad(self.y_fp32, True, dout, True)
self.dy = self.matmul_grad(dout, True, self.x_fp32, True)
elif self.attrs['trans_x'] is True and self.attrs[
'trans_y'] is False:
self.dx = self.tmttml(self.y_fp32, False, dout, True)
self.dy = self.tmttml(self.x_fp32, False, dout, False)
self.dx = self.matmul_grad(self.y_fp32, False, dout, True)
self.dy = self.matmul_grad(self.x_fp32, False, dout, False)
elif self.attrs['trans_x'] is False and self.attrs[
'trans_y'] is True:
self.dx = self.tmttml(dout, False, self.y_fp32, False)
self.dy = self.tmttml(dout, True, self.x_fp32, False)
self.dx = self.matmul_grad(dout, False, self.y_fp32, False)
self.dy = self.matmul_grad(dout, True, self.x_fp32, False)
else:
self.dx = self.tmttml(dout, False, self.y_fp32, True)
self.dy = self.tmttml(self.x_fp32, True, dout, False)
self.dx = self.matmul_grad(dout, False, self.y_fp32, True)
self.dy = self.matmul_grad(self.x_fp32, True, dout, False)

if is_broadcast:
x_reduce_axis = []
Expand Down Expand Up @@ -340,14 +348,6 @@ def calculate_grads(self):

self.dout = dout

def test_check_grad(self):
self.calculate_grads()
self.check_grad_with_place(
core.CPUPlace(), ["X", "Y"],
"Out",
user_defined_grads=[self.dx, self.dy],
user_defined_grad_outputs=[convert_float_to_uint16(self.dout)])

cls_name = "{0}_{1}".format(parent.__name__, "BF16")
TestMatMulV2Bf16OneDNNOp.__name__ = cls_name
globals()[cls_name] = TestMatMulV2Bf16OneDNNOp
Expand Down

0 comments on commit 2104d0d

Please sign in to comment.