Skip to content

Commit

Permalink
support some prim ops bf16 dtype (#54263)
Browse files Browse the repository at this point in the history
  • Loading branch information
Charles-hit authored Jun 2, 2023
1 parent 585f113 commit d1d43c2
Show file tree
Hide file tree
Showing 6 changed files with 32 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -542,7 +542,15 @@ void log_grad(const Tensor& x, const Tensor& out_grad, Tensor* x_grad) {
template <typename T>
void exp_grad(const Tensor& out, const Tensor& out_grad, Tensor* x_grad) {
if (x_grad) {
set_output<T>(out_grad * out, x_grad);
if (out.dtype() == phi::DataType::FLOAT16 ||
out.dtype() == phi::DataType::BFLOAT16) {
Tensor out_promote = cast<T>(out, phi::DataType::FLOAT32);
Tensor out_grad_promote = cast<T>(out_grad, phi::DataType::FLOAT32);
set_output<T>(cast<T>(out_promote * out_grad_promote, out.dtype()),
x_grad);
} else {
set_output<T>(out_grad * out, x_grad);
}
}
}

Expand Down
6 changes: 6 additions & 0 deletions test/legacy_test/test_activation_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ def setUp(self):
self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
self.outputs = {'Out': out}
self.if_enable_cinn()
self.convert_input_output()

def test_check_output(self):
self.check_output()
Expand All @@ -129,6 +130,9 @@ def init_shape(self):
def if_enable_cinn(self):
pass

def convert_input_output(self):
pass


class TestExpFp64_Prim(TestExpFp32_Prim):
def init_dtype(self):
Expand Down Expand Up @@ -4003,6 +4007,7 @@ def test_check_grad(self):


create_test_act_fp16_class(TestActivation)
create_test_act_fp16_class(TestExpFp32_Prim, check_prim=True, enable_cinn=True)
create_test_act_fp16_class(TestExpm1)
create_test_act_fp16_class(TestSigmoid, check_prim=True, enable_cinn=True)
create_test_act_fp16_class(TestSilu, check_prim=True, enable_cinn=True)
Expand Down Expand Up @@ -4133,6 +4138,7 @@ def test_check_grad(self):


create_test_act_bf16_class(TestActivation)
create_test_act_bf16_class(TestExpFp32_Prim, check_prim=True)
create_test_act_bf16_class(TestExpm1)
create_test_act_bf16_class(TestSigmoid, check_prim=True)
create_test_act_bf16_class(TestSilu, check_prim=True)
Expand Down
6 changes: 6 additions & 0 deletions test/legacy_test/test_cast_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,9 @@ def setUp(self):
self.prim_op_type = "prim"
self.python_api = cast_wrapper
self.public_python_api = cast_wrapper
self.if_enable_cinn()

def if_enable_cinn(self):
self.enable_cinn = False

def test_check_output(self):
Expand All @@ -130,6 +133,9 @@ def setUp(self):
self.prim_op_type = "prim"
self.python_api = cast_wrapper
self.public_python_api = cast_wrapper
self.if_enable_cinn()

def if_enable_cinn(self):
self.enable_cinn = False

def test_check_output(self):
Expand Down
3 changes: 1 addition & 2 deletions test/legacy_test/test_elementwise_div_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,9 +222,8 @@ def test_check_gradient(self):
check_args.insert(0, self.place)
self.check_grad_with_place(*check_args, **check_kwargs)

# elementwise_pow does't support bfloat16
def if_check_prim(self):
self.check_prim = False
self.check_prim = True

def if_enable_cinn(self):
self.enable_cinn = False
Expand Down
5 changes: 4 additions & 1 deletion test/legacy_test/test_mean_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ def setUp(self):
self.axis = [0]
self.keepdim = False
self.set_attrs()
self.enable_cinn = False
self.if_enable_cinn()

np.random.seed(10)
x_np = np.random.uniform(-1, 1, self.shape).astype(np.float32)
Expand All @@ -227,6 +227,9 @@ def setUp(self):
'reduce_all': self.reduce_all,
}

def if_enable_cinn(self):
self.enable_cinn = False

def set_attrs(self):
pass

Expand Down
7 changes: 6 additions & 1 deletion test/legacy_test/test_softmax_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,7 +398,9 @@ def get_x_shape(self):
class TestSoftmaxBF16Op(OpTest):
def setUp(self):
self.op_type = "softmax"
self.prim_op_type = "comp"
self.python_api = F.softmax
self.public_python_api = F.softmax
self.use_cudnn = self.init_cudnn()
self.use_mkldnn = False
self.dtype = np.uint16
Expand All @@ -424,7 +426,9 @@ def init_cudnn(self):

def test_check_output(self):
place = core.CUDAPlace(0)
self.check_output_with_place(place, check_dygraph=(not self.use_mkldnn))
self.check_output_with_place(
place, check_dygraph=(not self.use_mkldnn), check_prim=True
)

def test_check_grad(self):
place = core.CUDAPlace(0)
Expand All @@ -434,6 +438,7 @@ def test_check_grad(self):
"Out",
numeric_grad_delta=0.05,
check_dygraph=(not self.use_mkldnn),
check_prim=True,
)


Expand Down

0 comments on commit d1d43c2

Please sign in to comment.