From 34ac7b74c216bd02d44d9bc57b1537343adc0934 Mon Sep 17 00:00:00 2001
From: YuanRisheng
Date: Fri, 22 Apr 2022 19:44:09 +0800
Subject: [PATCH] Support triple grad check of op in Eager mode (#42131)

* support 3-rd order gradient

* change code format
---
 .../fluid/tests/unittests/gradient_checker.py | 222 +++++++++++++++---
 .../unittests/test_elementwise_nn_grad.py     |  21 ++
 2 files changed, 204 insertions(+), 39 deletions(-)

diff --git a/python/paddle/fluid/tests/unittests/gradient_checker.py b/python/paddle/fluid/tests/unittests/gradient_checker.py
index 562d52668ce5b..569d994b831b6 100644
--- a/python/paddle/fluid/tests/unittests/gradient_checker.py
+++ b/python/paddle/fluid/tests/unittests/gradient_checker.py
@@ -60,19 +60,6 @@ def _get_item(t, i, np_dtype):
         raise ValueError("Not supported data type " + str(np_dtype))
 
 
-def _get_item_for_dygraph(t, i, np_dtype):
-    if np_dtype == np.float16:
-        np_t = t.numpy().astype(np.float16)
-    elif np_dtype == np.float32:
-        np_t = t.numpy().astype(np.float32)
-    elif np_dtype == np.float64:
-        np_t = t.numpy().astype(np.float64)
-    else:
-        raise ValueError("Not supported data type " + str(np_dtype))
-    np_t = np_t.flatten()
-    return np_t[i]
-
-
 def _set_item(t, i, e, np_dtype):
     if np_dtype == np.float16:
         np_t = np.array(t).astype(np.float16)
@@ -89,22 +76,6 @@ def _set_item(t, i, e, np_dtype):
         raise ValueError("Not supported data type " + str(np_dtype))
 
 
-def _set_item_for_dygraph(t, i, e, np_dtype):
-    if np_dtype == np.float16:
-        np_t = t.numpy().astype(np.float16)
-    elif np_dtype == np.float32:
-        np_t = t.numpy().astype(np.float32)
-    elif np_dtype == np.float64:
-        np_t = t.numpy().astype(np.float64)
-    else:
-        raise ValueError("Not supported data type " + str(np_dtype))
-    shape = np_t.shape
-    np_t = np_t.flatten()
-    np_t[i] = e
-    np_t = np_t.reshape(shape)
-    paddle.assign(np_t, t)
-
-
 def set_var_in_scope(scope, place, name, value, recursive_seq_len=None):
     t = scope.var(name).get_tensor()
     t.set(value, place)
@@ -169,8 +140,6 @@ def run():
     np_type = dtype_to_np_dtype(x.dtype)
     jacobian = [make_jacobian(x, _product(yi.shape), np_type) for yi in y]
 
-    if np_type == np.float64:
-        delta = 1e-5
     for i in six.moves.xrange(x_size):
         orig = _get_item(x_t, i, np_type)
         x_pos = orig + delta
@@ -545,7 +514,12 @@ def triple_grad_check(x,
         rtol=rtol)
 
 
-def get_static_double_grad(x, y, x_init=None, dy_init=None, place=None):
+def get_static_double_grad(x,
+                           y,
+                           x_init=None,
+                           dy_init=None,
+                           place=None,
+                           program=None):
     """
     Get Double Grad result of static graph.
 
@@ -555,11 +529,14 @@ def get_static_double_grad(x, y, x_init=None, dy_init=None, place=None):
         x_init (numpy.array|list[numpy.array]|None): the init value for input x.
         dy_init (numpy.array|list[numpy.array]|None): the init value for output y.
         place (fluid.CPUPlace or fluid.CUDAPlace): the device.
+        program (Program|None): a Program with forward pass.
+            If None, use fluid.default_main_program().
     Returns:
         A list of numpy array that stores second derivative result calulated by static graph.
     """
-    program = fluid.default_main_program()
+    if program is None:
+        program = fluid.default_main_program()
     scope = fluid.executor.global_scope()
     y_grads = []
     for i in six.moves.xrange(len(y)):
@@ -635,7 +612,10 @@ def get_static_double_grad(x, y, x_init=None, dy_init=None, place=None):
     return ddx_res
 
 
-def get_eager_double_grad(func, x_init=None, dy_init=None):
+def get_eager_double_grad(func,
+                          x_init=None,
+                          dy_init=None,
+                          return_mid_result=False):
     """
     Get Double Grad result of dygraph.
 
@@ -643,8 +623,13 @@ def get_eager_double_grad(func, x_init=None, dy_init=None):
         func: A wrapped dygraph function that its logic is equal to static program
         x_init (numpy.array|list[numpy.array]|None): the init value for input x.
         dy_init (numpy.array|list[numpy.array]|None): the init value for gradient of output.
+        return_mid_result (bool): A flag that controls the return content.
     Returns:
-        A list of numpy array that stores second derivative result calulated by dygraph
+        If 'return_mid_result' is set to True,
+        the second order derivatives and the inputs used to compute them
+        are returned, so that higher order derivatives can be calculated.
+        If 'return_mid_result' is set to False,
+        a list of numpy arrays that stores the second derivative results calculated by dygraph is returned.
     """
     inputs = []
     dys = []
@@ -664,13 +649,25 @@ def get_eager_double_grad(func, x_init=None, dy_init=None):
     # calcluate second derivative
     inputs = inputs + dys
     ddys = []
+    if return_mid_result:
+        create_graph = True
+    else:
+        create_graph = False
+
     for d_input in d_inputs:
         d_input.stop_gradient = False
         ddy = paddle.ones(shape=d_input.shape, dtype=d_input.dtype)
         ddy.stop_gradient = False
         ddys.append(ddy)
-    dd_inputs = paddle.grad(outputs=d_inputs, inputs=inputs, grad_outputs=ddys)
-    return [dd_input.numpy() for dd_input in dd_inputs]
+    dd_inputs = paddle.grad(
+        outputs=d_inputs,
+        inputs=inputs,
+        grad_outputs=ddys,
+        create_graph=create_graph)
+    if return_mid_result:
+        return dd_inputs, inputs + ddys
+    else:
+        return [dd_input.numpy() for dd_input in dd_inputs]
 
 
 def double_grad_check_for_dygraph(func,
@@ -682,8 +679,9 @@
                                   rtol=1e-3,
                                   raise_exception=True):
     """
-    Check gradients of gradients. This function will append backward to the
-    program before second order gradient check.
+    Check second order gradients of dygraph. This function compares the
+    second order gradients of dygraph with those of the static graph
+    to validate the correctness of dygraph.
 
     Args:
         func: A wrapped dygraph function that its logic is equal to static program
@@ -734,3 +732,149 @@ def fail_test(msg):
                   'static:%s\n eager:%s\n' \
                   % (static_double_grad[i].name, eager_double_grad[i].name, str(place), static_double_grad[i], eager_double_grad[i])
             return fail_test(msg)
+
+
+def get_static_triple_grad(x,
+                           y,
+                           x_init=None,
+                           dy_init=None,
+                           place=None,
+                           program=None):
+    """
+    Get Triple Grad result of static graph.
+
+    Args:
+        x (Variable|list[Variable]): input variables to the program.
+        y (Variable|list[Variable]): output variables to the program.
+        x_init (numpy.array|list[numpy.array]|None): the init value for input x.
+        dy_init (numpy.array|list[numpy.array]|None): the init value for output y.
+        place (fluid.CPUPlace or fluid.CUDAPlace): the device.
+        program (Program|None): a Program with forward pass.
+            If None, use fluid.default_main_program().
+    Returns:
+        A list of numpy array that stores third derivative result calculated by static graph.
+    """
+    if program is None:
+        program = fluid.default_main_program()
+    scope = fluid.executor.global_scope()
+    y_grads = []
+    for i in six.moves.xrange(len(y)):
+        yi = y[i]
+        dyi_name = _append_grad_suffix_(yi.name)
+        np_type = dtype_to_np_dtype(yi.dtype)
+        dy = program.global_block().create_var(
+            name=dyi_name, shape=yi.shape, dtype=np_type, persistable=True)
+        dy.stop_gradient = False
+        set_var_in_scope(scope, place, dyi_name, dy_init[i])
+        y_grads.append(dy)
+
+    # append first order grads
+    dx = fluid.gradients(y, x, y_grads)
+
+    # y_grads are the input of first-order backward,
+    # so, they are also the input of second-order backward.
+    x += y_grads
+    x_init += dy_init
+    y = dx
+
+    x_grads_grads_init = []
+    for dxi in dx:
+        np_type = dtype_to_np_dtype(dxi.dtype)
+        value = np.ones(dxi.shape, dtype=np_type)
+        x_grads_grads_init.append(value)
+
+    return get_static_double_grad(
+        x, y, x_init, dy_init=x_grads_grads_init, place=place, program=program)
+
+
+def get_eager_triple_grad(func,
+                          x_init=None,
+                          dy_init=None,
+                          return_mid_result=False):
+    """
+    Get Triple Grad result of dygraph.
+
+    Args:
+        func: A wrapped dygraph function that its logic is equal to static program
+        x_init (numpy.array|list[numpy.array]|None): the init value for input x.
+        dy_init (numpy.array|list[numpy.array]|None): the init value for gradient of output.
+        return_mid_result (bool): A flag that controls the return content.
+    Returns:
+        A list of numpy array that stores third derivative result calculated by dygraph.
+    """
+    dd_y, dd_x = get_eager_double_grad(
+        func, x_init, dy_init, return_mid_result=True)
+
+    # calculate third derivative
+    dddys = []
+    for dd_yi in dd_y:
+        dd_yi.stop_gradient = False
+        dddy = paddle.ones(shape=dd_yi.shape, dtype=dd_yi.dtype)
+        dddy.stop_gradient = False
+        dddys.append(dddy)
+    ddd_inputs = paddle.grad(outputs=dd_y, inputs=dd_x, grad_outputs=dddys)
+    return [ddd_input.numpy() for ddd_input in ddd_inputs]
+
+
+def triple_grad_check_for_dygraph(func,
+                                  x,
+                                  y,
+                                  x_init=None,
+                                  place=None,
+                                  atol=1e-5,
+                                  rtol=1e-3,
+                                  raise_exception=True):
+    """
+    Check third order gradients of dygraph. This function compares the
+    third order gradients of dygraph with those of the static graph
+    to validate the correctness of dygraph.
+
+    Args:
+        func: A wrapped dygraph function that its logic is equal to static program
+        x (Variable|list[Variable]): input variables to the program.
+        y (Variable|list[Variable]): output variables to the program.
+        x_init (numpy.array|list[numpy.array]|None): the init value for input x.
+        place (fluid.CPUPlace or fluid.CUDAPlace): the device.
+        eps (float): perturbation for finite differences.
+        atol (float): absolute tolerance.
+        rtol (float): relative tolerance.
+        raise_exception (bool): whether to raise an exception if
+            the check fails. Default is True.
+    """
+
+    def fail_test(msg):
+        if raise_exception:
+            raise RuntimeError(msg)
+        return False
+
+    # check input arguments
+    x = _as_list(x)
+    for v in x:
+        v.stop_gradient = False
+        v.persistable = True
+    y = _as_list(y)
+
+    y_grads_init = []
+    for yi in y:
+        np_type = dtype_to_np_dtype(yi.dtype)
+        v = np.random.random(size=yi.shape).astype(np_type)
+        y_grads_init.append(v)
+
+    x_init = _as_list(x_init)
+
+    paddle.disable_static()
+    with _test_eager_guard():
+        eager_triple_grad = get_eager_triple_grad(func, x_init, y_grads_init)
+    paddle.enable_static()
+
+    static_triple_grad = get_static_triple_grad(x, y, x_init, y_grads_init,
+                                                place)
+
+    for i in six.moves.xrange(len(static_triple_grad)):
+        if not np.allclose(static_triple_grad[i], eager_triple_grad[i], rtol,
+                           atol):
+            msg = 'Check eager triple result fail. Mismatch between static_graph triple grad %s ' \
+                  'and eager triple grad %s on %s,\n' \
+                  'static:%s\n eager:%s\n' \
+                  % (static_triple_grad[i].name, eager_triple_grad[i].name, str(place), static_triple_grad[i], eager_triple_grad[i])
+            return fail_test(msg)
diff --git a/python/paddle/fluid/tests/unittests/test_elementwise_nn_grad.py b/python/paddle/fluid/tests/unittests/test_elementwise_nn_grad.py
index c51c8098706a6..8f6f9851c7006 100644
--- a/python/paddle/fluid/tests/unittests/test_elementwise_nn_grad.py
+++ b/python/paddle/fluid/tests/unittests/test_elementwise_nn_grad.py
@@ -17,6 +17,7 @@
 import unittest
 import numpy as np
 
+import paddle
 import paddle.fluid as fluid
 import paddle.fluid.layers as layers
 import paddle.fluid.core as core
@@ -45,6 +46,7 @@ def func(self, place):
             [x, y], out, x_init=[x_arr, y_arr], place=place, eps=eps)
 
     def test_grad(self):
+        paddle.enable_static()
         places = [fluid.CPUPlace()]
         if core.is_compiled_with_cuda():
             places.append(fluid.CUDAPlace(0))
@@ -72,6 +74,7 @@ def func(self, place):
             [x, y], out, x_init=[x_arr, y_arr], place=place, eps=eps)
 
     def test_grad(self):
+        paddle.enable_static()
         places = [fluid.CPUPlace()]
         if core.is_compiled_with_cuda():
             places.append(fluid.CUDAPlace(0))
@@ -99,6 +102,7 @@ def func(self, place):
             [x, y], out, x_init=[x_arr, y_arr], place=place, eps=eps)
 
     def test_grad(self):
+        paddle.enable_static()
         places = [fluid.CPUPlace()]
         if core.is_compiled_with_cuda():
             places.append(fluid.CUDAPlace(0))
@@ -126,6 +130,7 @@ def func(self, place):
             [x, y], out, x_init=[x_arr, y_arr], place=place, eps=eps)
 
     def test_grad(self):
+        paddle.enable_static()
         places = [fluid.CPUPlace()]
         if core.is_compiled_with_cuda():
             places.append(fluid.CUDAPlace(0))
@@ -153,6 +158,7 @@ def func(self, place):
             [x, y], out, x_init=[x_arr, y_arr], place=place, eps=eps)
 
     def test_grad(self):
+        paddle.enable_static()
         places = [fluid.CPUPlace()]
         if core.is_compiled_with_cuda():
             places.append(fluid.CUDAPlace(0))
@@ -180,6 +186,7 @@ def func(self, place):
             [x, y], out, x_init=[x_arr, y_arr], place=place, eps=eps)
 
     def test_grad(self):
+        paddle.enable_static()
         places = [fluid.CPUPlace()]
         if core.is_compiled_with_cuda():
             places.append(fluid.CUDAPlace(0))
@@ -208,6 +215,7 @@ def func(self, place):
             [x, y], out, x_init=[x_arr, y_arr], place=place, eps=eps, atol=1e-3)
 
     def test_grad(self):
+        paddle.enable_static()
         places = [fluid.CPUPlace()]
         if core.is_compiled_with_cuda():
             places.append(fluid.CUDAPlace(0))
@@ -236,6 +244,7 @@ def func(self, place):
             [x, y], out, x_init=[x_arr, y_arr], place=place, eps=eps, atol=1e-3)
 
     def test_grad(self):
+        paddle.enable_static()
         places = [fluid.CPUPlace()]
         if core.is_compiled_with_cuda():
             places.append(fluid.CUDAPlace(0))
@@ -263,6 +272,7 @@ def func(self, place):
             [x, y], out, x_init=[x_arr, y_arr], place=place, eps=eps)
 
     def test_grad(self):
+        paddle.enable_static()
         places = [fluid.CPUPlace()]
         if core.is_compiled_with_cuda():
             places.append(fluid.CUDAPlace(0))
@@ -290,6 +300,7 @@ def func(self, place):
             [x, y], out, x_init=[x_arr, y_arr], place=place, eps=eps)
 
     def test_grad(self):
+        paddle.enable_static()
         places = [fluid.CPUPlace()]
         if core.is_compiled_with_cuda():
             places.append(fluid.CUDAPlace(0))
@@ -298,6 +309,9 @@ def test_grad(self):
 
 
 class TestElementwiseMulTripleGradCheck(unittest.TestCase):
+    def multiply_wrapper(self, x):
+        return paddle.multiply(x[0], x[1])
+
     @prog_scope()
     def func(self, place):
         # the shape of input variable should be clearly specified, not inlcude -1.
@@ -315,8 +329,14 @@ def func(self, place):
 
         gradient_checker.triple_grad_check(
             [x, y], out, x_init=[x_arr, y_arr], place=place, eps=eps)
+        gradient_checker.triple_grad_check_for_dygraph(
+            self.multiply_wrapper, [x, y],
+            out,
+            x_init=[x_arr, y_arr],
+            place=place)
 
     def test_grad(self):
+        paddle.enable_static()
         places = [fluid.CPUPlace()]
         if core.is_compiled_with_cuda():
             places.append(fluid.CUDAPlace(0))
@@ -344,6 +364,7 @@ def func(self, place):
             [x, y], out, x_init=[x_arr, y_arr], place=place, eps=eps)
 
    def test_grad(self):
+        paddle.enable_static()
         places = [fluid.CPUPlace()]
         if core.is_compiled_with_cuda():
             places.append(fluid.CUDAPlace(0))
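
Usage sketch for reviewers: the snippet below shows, in a minimal self-contained form, how the new triple_grad_check_for_dygraph API is meant to be driven, mirroring the TestElementwiseMulTripleGradCheck change above. It assumes it is run from python/paddle/fluid/tests/unittests so that the gradient_checker module is importable; the wrapper function name, shapes, and dtype are illustrative and not part of the patch.

    import numpy as np

    import paddle
    import paddle.fluid as fluid
    import paddle.fluid.layers as layers

    import gradient_checker


    def multiply_wrapper(xs):
        # The wrapped dygraph function receives the list of input tensors and
        # must reproduce the logic of the static program built below.
        return paddle.multiply(xs[0], xs[1])


    def check_triple_grad(place):
        paddle.enable_static()  # the checker expects to start in static mode
        shape = [2, 3, 4, 5]
        dtype = 'float64'
        x = layers.data('x', shape, False, dtype)
        y = layers.data('y', shape, False, dtype)
        x.persistable = True
        y.persistable = True
        out = layers.elementwise_mul(x, y)
        x_arr = np.random.uniform(-1, 1, shape).astype(dtype)
        y_arr = np.random.uniform(-1, 1, shape).astype(dtype)
        # Compares third order gradients computed in eager/dygraph mode against
        # the static-graph reference built from the same forward definition.
        gradient_checker.triple_grad_check_for_dygraph(
            multiply_wrapper, [x, y], out, x_init=[x_arr, y_arr], place=place)


    check_triple_grad(fluid.CPUPlace())

The call shape matches the new function in gradient_checker.py: the first argument is the dygraph callable, [x, y] and out are the static-graph variables, and x_init seeds both the static and the eager evaluations.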