Commit

add ut for use_global_beta_pow
zhiqiu committed May 11, 2021
1 parent f09b6cf commit d9430dd
Showing 7 changed files with 146 additions and 19 deletions.
1 change: 1 addition & 0 deletions paddle/fluid/operators/optimizers/adam_op.cu
@@ -155,6 +155,7 @@ class AdamOpCUDAKernel : public framework::OpKernel<T> {
ctx.Attr<int64_t>("min_row_size_to_use_multithread");
bool lazy_mode = ctx.Attr<bool>("lazy_mode");
bool use_global_beta_pow = ctx.Attr<bool>("use_global_beta_pow");
VLOG(4) << "use_global_beta_pow:" << use_global_beta_pow;

auto* param = ctx.Input<LoDTensor>("Param");
auto* grad_var = ctx.InputVar("Grad");
1 change: 1 addition & 0 deletions paddle/fluid/operators/optimizers/adam_op.h
@@ -407,6 +407,7 @@ class AdamOpKernel : public framework::OpKernel<T> {
ctx.Attr<int64_t>("min_row_size_to_use_multithread");
bool lazy_mode = ctx.Attr<bool>("lazy_mode");
bool use_global_beta_pow = ctx.Attr<bool>("use_global_beta_pow");
VLOG(4) << "use_global_beta_pow:" << use_global_beta_pow;

auto* param = ctx.Input<LoDTensor>("Param");
auto* grad_var = ctx.InputVar("Grad");
38 changes: 20 additions & 18 deletions paddle/fluid/operators/optimizers/adam_op_npu.cc
@@ -36,7 +36,6 @@ class AdamNPUKernel : public framework::OpKernel<T> {
"but the received is %s",
ctx.InputNames("Param").front(),
framework::ToTypeName(param_var->Type())));
T epsilon = static_cast<T>(ctx.Attr<float>("epsilon"));
auto* param = ctx.Input<LoDTensor>("Param");
auto* grad_var = ctx.InputVar("Grad");
PADDLE_ENFORCE_EQ(grad_var->IsType<framework::LoDTensor>(), true,
@@ -60,28 +59,12 @@ class AdamNPUKernel : public framework::OpKernel<T> {
auto* beta2_pow_out = ctx.Output<LoDTensor>("Beta2PowOut");

bool use_global_beta_pow = ctx.Attr<bool>("use_global_beta_pow");
VLOG(4) << "use_global_beta_pow:" << use_global_beta_pow;

param_out->mutable_data<T>(ctx.GetPlace());
mom1_out->mutable_data<T>(ctx.GetPlace());
mom2_out->mutable_data<T>(ctx.GetPlace());

// NOTE(zhiqiu): beta1_pow and beta2_pow may be on CPU and not transformed to the NPU place.
if (beta1_pow->place() == platform::CPUPlace()) {
T beta1 = *beta1_pow->data<T>();
// `mutable_data` operation needs to be done after getting data
beta1_pow_out->mutable_data<T>(ctx.GetPlace());
FillNpuTensorWithConstant<T>(beta1_pow_out, beta1);
} else {
beta1_pow_out->mutable_data<T>(ctx.GetPlace());
}
if (beta2_pow->place() == platform::CPUPlace()) {
T beta2 = *beta2_pow->data<T>();
beta2_pow_out->mutable_data<T>(ctx.GetPlace());
FillNpuTensorWithConstant<T>(beta2_pow_out, beta2);
} else {
beta2_pow_out->mutable_data<T>(ctx.GetPlace());
}

const Tensor* beta1_tensor = nullptr;
const Tensor* beta2_tensor = nullptr;
const Tensor* epsilon_tensor = nullptr;
@@ -177,6 +160,25 @@ class AdamNPUKernel : public framework::OpKernel<T> {
ctx.template device_context<platform::DeviceContext>(), mom2_out);
}
if (!use_global_beta_pow) {
// NOTE(zhiqiu): beta1_pow and beta2_pow may be on CPU and not
// transformed to the NPU place.
Tensor beta1_pow_tmp;
Tensor beta2_pow_tmp;
if (beta1_pow->place() == platform::CPUPlace()) {
T beta1 = *beta1_pow->data<T>();
beta1_pow_tmp.mutable_data<T>(ctx.GetPlace());
FillNpuTensorWithConstant<T>(&beta1_pow_tmp, beta1);
beta1_pow = &beta1_pow_tmp;
}
if (beta2_pow->place() == platform::CPUPlace()) {
T beta2 = *beta2_pow->data<T>();
beta2_pow_tmp.mutable_data<T>(ctx.GetPlace());
FillNpuTensorWithConstant<T>(&beta2_pow_tmp, beta2);
beta2_pow = &beta2_pow_tmp;
}

beta1_pow_out->mutable_data<T>(ctx.GetPlace());
beta2_pow_out->mutable_data<T>(ctx.GetPlace());
auto runner_m1 =
NpuOpRunner("Mul", {*beta1_pow, *beta1_tensor}, {*beta1_pow_out}, {});
runner_m1.Run(stream);
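Note on the NPU change above: the beta-pow handling now lives inside the if (!use_global_beta_pow) branch, so the op only writes Beta1PowOut and Beta2PowOut when the accumulators are not maintained globally. The following NumPy sketch mirrors that control flow using the textbook Adam update; it is illustrative only, the names are made up, and the kernel's exact epsilon placement may differ.

    import numpy as np

    def adam_update(param, grad, moment1, moment2, beta1_pow, beta2_pow,
                    lr=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8,
                    use_global_beta_pow=False):
        # Standard Adam moment updates.
        moment1_out = beta1 * moment1 + (1.0 - beta1) * grad
        moment2_out = beta2 * moment2 + (1.0 - beta2) * grad * grad
        # Fold the bias correction into the learning rate.
        lr_t = lr * np.sqrt(1.0 - beta2_pow) / (1.0 - beta1_pow)
        param_out = param - lr_t * moment1_out / (np.sqrt(moment2_out) + epsilon)
        if use_global_beta_pow:
            # Beta pows are advanced once, globally; the op leaves its outputs empty.
            beta1_pow_out, beta2_pow_out = np.array([]), np.array([])
        else:
            beta1_pow_out, beta2_pow_out = beta1 * beta1_pow, beta2 * beta2_pow
        return param_out, moment1_out, moment2_out, beta1_pow_out, beta2_pow_out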
1 change: 1 addition & 0 deletions paddle/fluid/operators/optimizers/adam_op_xpu.cc
@@ -74,6 +74,7 @@ class AdamOpXPUKernel : public framework::OpKernel<T> {
beta2_pow_out->numel()));

bool use_global_beta_pow = ctx.Attr<bool>("use_global_beta_pow");
VLOG(4) << "use_global_beta_pow:" << use_global_beta_pow;

T beta1 = static_cast<T>(ctx.Attr<float>("beta1"));
if (ctx.HasInput("Beta1Tensor")) {
59 changes: 59 additions & 0 deletions python/paddle/fluid/tests/unittests/npu/test_adam_op_npu.py
@@ -134,6 +134,65 @@ def test_check_output(self):
self.check_output_with_place(self.place, atol=1e-5, check_dygraph=False)


@unittest.skipIf(not paddle.is_compiled_with_npu(),
"core is not compiled with NPU")
class TestAdamOpWithGlobalBetaPow(OpTest):
def setUp(self):
self.set_npu()
self.place = paddle.NPUPlace(0)
self.op_type = "adam"
param = np.random.uniform(-1, 1, (102, 105)).astype("float32")
grad = np.random.uniform(-1, 1, (102, 105)).astype("float32")
moment1 = np.random.uniform(-1, 1, (102, 105)).astype("float32")
# The second moment is positive
moment2 = np.random.random((102, 105)).astype("float32")

learning_rate = 0.004
beta1 = 0.78
beta2 = 0.836
epsilon = 1e-4
beta1_pow = beta1**10
beta2_pow = beta2**10

self.inputs = {
'Param': param,
'Grad': grad,
'Moment1': moment1,
'Moment2': moment2,
'LearningRate': np.array([learning_rate]).astype("float32"),
'Beta1Pow': np.array([beta1_pow]).astype("float32"),
'Beta2Pow': np.array([beta2_pow]).astype("float32"),
'Beta1Tensor': np.array([beta1]).astype("float32"),
'Beta2Tensor': np.array([beta2]).astype("float32"),
'EpsilonTensor': np.array([epsilon]).astype("float32"),
}

attributes = {'epsilon': epsilon}

param_out, moment1_out, \
moment2_out = adam_step(self.inputs, attributes)

self.attrs = {'use_global_beta_pow': True}

# When use_global_beta_pow=True, Beta1PowOut and Beta2PowOut are empty.
self.outputs = {
'Moment1Out': moment1_out,
'Moment2Out': moment2_out,
'ParamOut': param_out,
'Beta1PowOut': np.array([]),
'Beta2PowOut': np.array([])
}

def set_npu(self):
self.__class__.use_npu = True

def init_dtype(self):
self.dtype = np.float32

def test_check_output(self):
self.check_output_with_place(self.place, atol=1e-5, check_dygraph=False)


@unittest.skipIf(not paddle.is_compiled_with_npu(),
"core is not compiled with NPU")
class TestNet(unittest.TestCase):
6 changes: 6 additions & 0 deletions python/paddle/fluid/tests/unittests/op_test.py
@@ -1087,6 +1087,7 @@ def check_output_with_place(self,
dygraph_outs = self._calc_dygraph_output(
place, no_check_set=no_check_set)
outs, fetch_list = self._calc_output(place, no_check_set=no_check_set)

for out_name, out_dup in Operator.get_op_outputs(self.op_type):
if out_name not in self.outputs:
continue
@@ -1177,6 +1178,11 @@ def find_actual(target_name, fetch_list):
actual_t = convert_uint16_to_float(actual_t)
atol = 0.03

# NOTE(zhiqiu): np.allclose([], [1.]) returns True
# see details: https://stackoverflow.com/questions/38331703/why-does-numpys-broadcasting-sometimes-allow-comparing-arrays-of-different-leng
if expect_t.size == 0:
self.assertTrue(actual_t.size == 0)

self.assertTrue(
np.allclose(
actual_t, expect_t, atol=atol, equal_nan=equal_nan),
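Why the size check added above is needed: np.allclose compares an empty expected array against a one-element actual array as vacuously True, because broadcasting produces an empty comparison and np.all over an empty array is True. Without the check, an empty expected output such as Beta1PowOut under use_global_beta_pow would always pass. A quick illustration:

    import numpy as np

    # Broadcasting [] against [1.] yields an empty comparison, and
    # np.all() over an empty array is True by convention.
    print(np.allclose(np.array([]), np.array([1.0])))  # True
    # Hence the explicit size assertion before the allclose-based comparison.
    print(np.array([]).size == 0)                      # True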
59 changes: 58 additions & 1 deletion python/paddle/fluid/tests/unittests/test_adam_op.py
@@ -404,7 +404,7 @@ def test_check_output(self):

class TestAdamOpBetaEpsilonVariable(OpTest):
def setUp(self):
'''Test Adam Op with beta as Variable
'''Test Adam Op with beta/epsilon as Variable
'''
self.op_type = "adam"
param = np.random.uniform(-1, 1, (102, 105)).astype("float32")
@@ -450,6 +450,57 @@ def test_check_output(self):
self.check_output()


class TestAdamOpWithGlobalBetaPow(OpTest):
def setUp(self):
'''Test Adam Op with global_beta_pow
'''
self.op_type = "adam"
param = np.random.uniform(-1, 1, (102, 105)).astype("float32")
grad = np.random.uniform(-1, 1, (102, 105)).astype("float32")
moment1 = np.random.uniform(-1, 1, (102, 105)).astype("float32")
# The second moment is positive
moment2 = np.random.random((102, 105)).astype("float32")
beta1 = 0.85
beta2 = 0.95

learning_rate = 0.001
epsilon = 1e-8
beta1_pow = beta1**10
beta2_pow = beta2**10

self.inputs = {
'Param': param,
'Grad': grad,
'Moment1': moment1,
'Moment2': moment2,
'LearningRate': np.array([learning_rate]).astype("float32"),
'Beta1Pow': np.array([beta1_pow]).astype("float32"),
'Beta2Pow': np.array([beta2_pow]).astype("float32"),
"Beta1Tensor": np.array([beta1]).astype("float32"),
"Beta2Tensor": np.array([beta2]).astype("float32"),
"EpsilonTensor": np.array([epsilon]).astype("float32"),
}

attributes = {'epsilon': epsilon}

param_out, moment1_out, \
moment2_out = adam_step(self.inputs, attributes)

self.attrs = {'use_global_beta_pow': True}

# When use_global_beta_pow=True, Beta1PowOut and Beta2PowOut are empty.
self.outputs = {
'Moment1Out': moment1_out,
'Moment2Out': moment2_out,
'ParamOut': param_out,
'Beta1PowOut': np.array([]),
'Beta2PowOut': np.array([])
}

def test_check_output(self):
self.check_output()


class TestAdamOpV2(unittest.TestCase):
def test_adam_op(self):
place = fluid.CPUPlace()
@@ -493,6 +544,7 @@ def test_adam_op_dygraph(self):
out.backward()
adam.step()
adam.clear_gradients()
paddle.enable_static()

def test_adam_op_with_state_dict(self):

@@ -523,6 +575,7 @@ def test_adam_op_with_state_dict(self):

params = adam.get_opti_var_name_list()
assert (params is not None)
paddle.enable_static()

def test_adam_with_grad_clip(self):
paddle.disable_static()
@@ -536,6 +589,7 @@ def test_adam_with_grad_clip(self):
out.backward()
adam.step()
adam.clear_gradients()
paddle.enable_static()

def test_adam_op_with_set_lr(self):
paddle.disable_static()
@@ -550,6 +604,7 @@ def test_adam_op_with_set_lr(self):
lr_var = paddle.fluid.layers.create_global_var(
shape=[1], value=lr, dtype='float32')
adam.set_lr(lr_var)
paddle.enable_static()

def test_adam_op_invalid_input(self):
paddle.disable_static()
@@ -563,6 +618,7 @@ def test_adam_op_invalid_input(self):
with self.assertRaises(ValueError):
adam = paddle.optimizer.Adam(
0.1, epsilon=-1, parameters=linear.parameters())
paddle.enable_static()

def test_adam_op_with_sparse_input_and_weight_decay(self):

@@ -577,6 +633,7 @@ def test_adam_op_with_sparse_input_and_weight_decay(self):
out = emb(x)
out.backward()
adam.step()
paddle.enable_static()


class TestNetWithEpsilonTensor(unittest.TestCase):
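A side change visible in the same test file: every dygraph-mode test now ends with paddle.enable_static(), presumably so that switching to imperative mode inside one test does not leak into later static-graph (OpTest-based) cases. A minimal sketch of that pattern, with an illustrative layer and shapes not taken from the commit:

    import paddle

    def run_one_dygraph_adam_step():
        paddle.disable_static()            # imperative mode for this test only
        try:
            linear = paddle.nn.Linear(4, 4)
            x = paddle.rand([2, 4])
            adam = paddle.optimizer.Adam(learning_rate=0.01,
                                         parameters=linear.parameters())
            loss = linear(x).sum()
            loss.backward()
            adam.step()
            adam.clear_grad()
        finally:
            paddle.enable_static()         # restore static-graph mode for later tests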
