Skip to content

Commit

Permalink
fix bug
Browse files Browse the repository at this point in the history
  • Loading branch information
longranger2 committed Jun 15, 2023
1 parent 22e19e1 commit 68ebf55
Showing 1 changed file with 46 additions and 4 deletions.
50 changes: 46 additions & 4 deletions test/dygraph_to_static/test_basic_api_transformation.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,17 +342,17 @@ def dyfunc_CosineDecay():

def dyfunc_ExponentialDecay():
base_lr = 0.1
exponential_decay = fluid.dygraph.ExponentialDecay(
learning_rate=base_lr, decay_steps=10000, decay_rate=0.5, staircase=True
exponential_decay = paddle.optimizer.lr.ExponentialDecay(
learning_rate=base_lr, gamma=0.5
)
lr = exponential_decay()
return lr


def dyfunc_InverseTimeDecay():
base_lr = 0.1
inverse_time_decay = fluid.dygraph.InverseTimeDecay(
learning_rate=base_lr, decay_steps=10000, decay_rate=0.5, staircase=True
inverse_time_decay = paddle.optimizer.lr.InverseTimeDecay(
learning_rate=base_lr, gamma=0.5
)
lr = inverse_time_decay()
return lr
Expand Down Expand Up @@ -424,11 +424,53 @@ class TestDygraphBasicApi_ExponentialDecay(TestDygraphBasicApi_CosineDecay):
def setUp(self):
self.dygraph_func = dyfunc_ExponentialDecay

def get_dygraph_output(self):
with fluid.dygraph.guard():
fluid.default_startup_program.random_seed = SEED
fluid.default_main_program.random_seed = SEED
res = self.dygraph_func()
return res

def get_static_output(self):
startup_program = fluid.Program()
startup_program.random_seed = SEED
main_program = fluid.Program()
main_program.random_seed = SEED
with fluid.program_guard(main_program, startup_program):
static_out = dygraph_to_static_func(self.dygraph_func)()
static_out = paddle.to_tensor(static_out)

exe = fluid.Executor(fluid.CPUPlace())
exe.run(startup_program)
static_res = exe.run(main_program, fetch_list=static_out)
return static_res[0]


class TestDygraphBasicApi_InverseTimeDecay(TestDygraphBasicApi_CosineDecay):
def setUp(self):
self.dygraph_func = dyfunc_InverseTimeDecay

def get_dygraph_output(self):
with fluid.dygraph.guard():
fluid.default_startup_program.random_seed = SEED
fluid.default_main_program.random_seed = SEED
res = self.dygraph_func()
return res

def get_static_output(self):
startup_program = fluid.Program()
startup_program.random_seed = SEED
main_program = fluid.Program()
main_program.random_seed = SEED
with fluid.program_guard(main_program, startup_program):
static_out = dygraph_to_static_func(self.dygraph_func)()
static_out = paddle.to_tensor(static_out)

exe = fluid.Executor(fluid.CPUPlace())
exe.run(startup_program)
static_res = exe.run(main_program, fetch_list=static_out)
return static_res[0]


class TestDygraphBasicApi_NaturalExpDecay(TestDygraphBasicApi_CosineDecay):
def setUp(self):
Expand Down

0 comments on commit 68ebf55

Please sign in to comment.