From 9e764d82036d91333e95a75348ba7c3b8f583005 Mon Sep 17 00:00:00 2001
From: Xiaoxu Chen
Date: Sat, 2 Apr 2022 06:51:55 +0800
Subject: [PATCH 01/93] Enhance vjp/jvp/Jacobian/Hessian API for supporting dynamic, static graph and batched, unbatched mode (#40692)

* modify vjp/jvp for both dynamic and static graph
* enforce jacobian class for supporting first/last batch
* add unittest for jvp, jacobian with last batch, jacobian with first batch
* fix the incorrect shape when multi-index Jacobian
* enforce Hessian class for supporting dynamic graph
* add Hessian class unittest
* bugfix, jvp double_backward_trick zeros_like returns stop_gradient=True in static graph
* add API beta warnings
* add white_list for cuda11.x ci windows.
* optimize some code snippets and documents
* set unittest timeout to 100 seconds
* move vjp, jvp, Jacobian, Hessian to incubate
* fix vjp, jvp import path of sample code
* fix code style error of autograd/__init__ file
---
 python/paddle/autograd/__init__.py            |   18 +-
 python/paddle/autograd/functional.py          | 1081 +++++++++------
 python/paddle/autograd/utils.py               |   45 -
 .../tests/unittests/autograd/CMakeLists.txt   |    5 +-
 .../fluid/tests/unittests/autograd/config.py  |   49 +
 .../test_autograd_functional_dynamic.py       | 1233 +++++++++++++++++
 .../test_autograd_functional_static.py        |  455 ++++++
 .../autograd/test_autograd_static.py          |  308 ----
 .../tests/unittests/autograd/test_hessian.py  |  263 ----
 .../tests/unittests/autograd/test_jacobian.py |  319 -----
 .../tests/unittests/autograd/test_vhp.py      |  182 ---
 .../tests/unittests/autograd/test_vjp_jvp.py  |  315 -----
 .../fluid/tests/unittests/autograd/utils.py   |  231 ++-
 python/paddle/incubate/__init__.py            |    1 +
 python/paddle/incubate/autograd/__init__.py   |   18 +
 python/setup.py.in                            |    1 +
 tools/windows/run_unittests.sh                |  114 +-
 17 files changed, 2731 insertions(+), 1907 deletions(-)
 delete mode 100644 python/paddle/autograd/utils.py
 create mode 100644 python/paddle/fluid/tests/unittests/autograd/config.py
 create mode 100644 python/paddle/fluid/tests/unittests/autograd/test_autograd_functional_dynamic.py
 create mode 100644 python/paddle/fluid/tests/unittests/autograd/test_autograd_functional_static.py
 delete mode 100644 python/paddle/fluid/tests/unittests/autograd/test_autograd_static.py
 delete mode 100644 python/paddle/fluid/tests/unittests/autograd/test_hessian.py
 delete mode 100644 python/paddle/fluid/tests/unittests/autograd/test_jacobian.py
 delete mode 100644 python/paddle/fluid/tests/unittests/autograd/test_vhp.py
 delete mode 100644 python/paddle/fluid/tests/unittests/autograd/test_vjp_jvp.py
 create mode 100644 python/paddle/incubate/autograd/__init__.py

diff --git a/python/paddle/autograd/__init__.py b/python/paddle/autograd/__init__.py
index 7aab7117de905..b13a4591b4ef2 100644
--- a/python/paddle/autograd/__init__.py
+++ b/python/paddle/autograd/__init__.py
@@ -1,11 +1,11 @@
 # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
-# 
+#
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
-# 
+#
 # http://www.apache.org/licenses/LICENSE-2.0
-# 
+#
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
from ..fluid.dygraph.base import grad # noqa: F401 +from ..fluid.dygraph.base import no_grad_ as no_grad # noqa: F401 +from ..framework import is_grad_enabled, set_grad_enabled # noqa: F401 from . import backward_mode # noqa: F401 from .backward_mode import backward # noqa: F401 from .py_layer import PyLayer, PyLayerContext, EagerPyLayer, EagerPyLayerContext # noqa: F401 from ..framework import set_grad_enabled, is_grad_enabled # noqa: F401 from ..fluid.dygraph.base import no_grad_ as no_grad # noqa: F401 -from .functional import jacobian, hessian, batch_jacobian, batch_hessian # noqa: F401 -from .functional import vjp, jvp, vhp # noqa: F401 +from .functional import vjp, jvp, Jacobian, Hessian # noqa: F401 +from .functional import jacobian, hessian, batch_jacobian, batch_hessian, vhp # noqa: F401 -__all__ = ['backward', 'PyLayer', 'PyLayerContext'] +__all__ = [ # noqa + 'backward', + 'PyLayer', + 'PyLayerContext', +] diff --git a/python/paddle/autograd/functional.py b/python/paddle/autograd/functional.py index c663d37e7f2ab..8e027c270b700 100644 --- a/python/paddle/autograd/functional.py +++ b/python/paddle/autograd/functional.py @@ -12,236 +12,686 @@ # See the License for the specific language governing permissions and # limitations under the License. -import contextlib +import functools +import typing + import paddle -from paddle.static import gradients -from ..fluid import framework -from ..fluid.dygraph import grad -from ..tensor.creation import assign -from ..tensor import reshape, zeros_like, to_tensor -from .utils import _tensors, _stack_tensor_or_return_none, _replace_none_with_zero_tensor - - -@contextlib.contextmanager -def gradient_scope(*var_lists, create_graph=False, allow_unused=False): - def grad_fn(ys, xs, v=None, create_graph=create_graph): - if v is not None: - assert len(ys) == len(v), ( - f'The argument {v} is expected to be of the same size as the output. ' - f'Here the output is {ys}, and `v` is {v}.') - if allow_unused: - ys = [ - to_tensor( - [0.0], stop_gradient=False) if y is None else y for y in ys - ] - return grad( - ys, xs, v, create_graph=create_graph, allow_unused=allow_unused) - - def return_fn(out): - if isinstance(out, paddle.Tensor): - if not create_graph: - out = out.detach() - return out - if isinstance(out, list): - return list(return_fn(x) for x in out) - elif isinstance(out, tuple): - return tuple(return_fn(x) for x in out) - else: - assert out is None - return out - - def process(vl): - if vl is None: - return None - out = [] - # If v is treated as constant in the outer scope, its gradient is guaranteed - # not to be taken beyond this scope. Within this scope, however, v's gradient - # may be computed. We only need to detach v in this case. - # Otherwise, v's gradient is valid, and is subject to update beyond this scope. - # In this case we must not confuse the gradient in the outer scope with the - # inner one's. Moreover, we need to make sure that the result from the inner - # scope can flow back to the outer scope. This can be satisfied by extending - # the original variable with a duplication operation v1 = v so that v still - # maintains the complete lineage. 
-        for v in vl:
-            if v is None:
-                out.append(v)
-                continue
-            if create_graph and not v.stop_gradient:
-                v = assign(v)
-            else:
-                v = v.detach()
-                v.stop_gradient = False
-            out.append(v)
-        return out
-
-    try:
-        var_lists = [process(vl) for vl in var_lists]
-        bundle = var_lists + [grad_fn, return_fn]
-        yield bundle
-    finally:
-        pass
+from paddle.fluid import framework
 
 
-@framework.dygraph_only
-def vjp(func, inputs, v=None, create_graph=False, allow_unused=False):
+def vjp(func, xs, v=None):
     r"""Computes the Vector-Jacobian product, a functional form of
     reverse mode automatic differentiation.
 
+    Warning:
+        This API is in beta, the signatures could be changed in future versions.
+
     Args:
-        func(Callable): `func` takes as input a tensor or a list/tuple
-            of tensors and returns a tensor or a list/tuple of tensors.
-        inputs(list[Tensor]|tuple[Tensor]|Tensor): used as positional
-            arguments to evaluate `func`. `inputs` is accepted as one
-            tensor or a list of tensors.
-        v(list[Tensor]|tuple[Tensor]|Tensor|None, optional): the
-            cotangent vector invovled in the VJP computation. `v` matches
-            the size and shape of `func`'s output. Default value is None
-            and in this case is equivalent to all ones the same size
-            of `func`'s output.
-        create_graph(bool, optional): if `True`, gradients can be
-            evaluated on the results. If `False`, taking gradients on
-            the results is invalid. Default value is False.
-        allow_unused(bool, optional): In case that some Tensors of
-            `inputs` do not contribute to the computation of the output.
-            If `allow_unused` is False, an error will be raised,
-            Otherwise, the gradients of the said inputs are returned
-            None. Default value is False.
+        func(Callable): A function that takes ``xs`` as its input parameter
+            and returns a sequence of Tensors or a Tensor.
+        xs(Tensor|Sequence[Tensor]): Used as positional arguments to evaluate
+            ``func``. ``xs`` is accepted as one Tensor or a sequence of Tensors.
+        v(Tensor|Sequence[Tensor]|None, optional): The cotangent vector involved
+            in the VJP computation. ``v`` matches the size and shape of
+            ``func`` 's output. Defaults to None, which is equivalent to all
+            ones the same size of ``func`` 's output.
 
     Returns:
         output(tuple):
-            func_out(list[Tensor]|tuple[Tensor]|Tensor): the output of
-                `func(inputs)`
-            vjp(list[Tensor]): the pullback results of `v` on `func`
+
+        - func_out(Tensor|tuple[Tensor]): The output of ``func(xs)`` .
+        - vjp(Tensor|tuple[Tensor]): The vjp result.
 
     Examples:
-      .. code-block:: python
-
-        def func(x):
-          return paddle.matmul(x, x)
-
-        x = paddle.ones(shape=[2, 2], dtype='float32')
-        output, inputs_grad = vjp(func, x)
-        print(inputs_grad)
-        # [Tensor(shape=[2, 2], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
-        #        [[4., 4.],
-        #         [4., 4.]])]
-
-        v = paddle.to_tensor([[1.0, 0.0], [0.0, 0.0]])
-        output, inputs_grad = vjp(func, x, v)
-        print(inputs_grad)
-        # [Tensor(shape=[2, 2], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
-        #        [[2., 1.],
-        #         [1., 0.]])]
-
-        output, inputs_grad = vjp(func, x, v, create_graph=True)
-        print(inputs_grad)
-        # [Tensor(shape=[2, 2], dtype=float32, place=CUDAPlace(0), stop_gradient=False,
-        #        [[2., 1.],
-        #         [1., 0.]])]
-
-        y = paddle.ones(shape=[2, 2], dtype='float32')
-        def func_unused(x, y):
-          return paddle.matmul(x, x)
-
-        output, inputs_grad = vjp(func, [x, y], v)
-        # ValueError: (InvalidArgument) The 1-th input does not appear in the backward graph.
-        # Please check the input variable or set allow_unused=True to get None result.
-        # [Hint: Expected allow_unused_ == true, but received allow_unused_:0 != true:1.]
-
-        output, inputs_grad = vjp(func, [x, y], v, allow_unused=True)
-        print(inputs_grad)
-        # [Tensor(shape=[2, 2], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
-        #        [[2., 1.],
-        #         [1., 0.]]), None]
+
+        .. code-block:: python
+
+            import paddle
+
+            def func(x):
+                return paddle.matmul(x, x)
+
+            x = paddle.ones(shape=[2, 2], dtype='float32')
+            _, vjp_result = paddle.incubate.autograd.vjp(func, x)
+            print(vjp_result)
+            # Tensor(shape=[2, 2], dtype=float32, place=Place(gpu:0), stop_gradient=False,
+            #        [[4., 4.],
+            #         [4., 4.]])
+
+            v = paddle.to_tensor([[1.0, 0.0], [0.0, 0.0]])
+            _, vjp_result = paddle.incubate.autograd.vjp(func, x, v)
+            print(vjp_result)
+            # Tensor(shape=[2, 2], dtype=float32, place=Place(gpu:0), stop_gradient=False,
+            #        [[2., 1.],
+            #         [1., 0.]])
     """
-    xs = _tensors(inputs, "inputs")
-    if v is not None:
-        v = _tensors(v, "v")
+    _check_inputs(func, xs, v)
 
-    with gradient_scope(
-            xs, v, create_graph=create_graph,
-            allow_unused=allow_unused) as [xs, v, grad_fn, return_fn]:
-        outputs = func(*xs)
-        ys = _tensors(outputs, "outputs")
-        grads = grad_fn(ys, xs, v)
-        outputs, grads = return_fn(outputs), return_fn(grads)
+    # ``_separate`` breaks the dependencies between ``xs`` and other
+    # variables. See more details in ``_separate`` .
+    xs, v = _separate(xs), _separate(v)
+    ys = func(*xs) if isinstance(xs, typing.Sequence) else func(xs)
+    _check_v_shape(v, ys)
 
-    return outputs, grads
+    return ys, _grad(ys, xs, v)
 
 
-@framework.dygraph_only
-def jvp(func, inputs, v=None, create_graph=False, allow_unused=False):
+def jvp(func, xs, v=None):
     r"""
     Computes the Jacobian-Vector product for a function at the given
     inputs and a vector in the tangent space induced by the inputs.
 
-    .. note::
-        **This API is ONLY available in imperative mode.**
+    Warning:
+        This API is in beta, the signatures could be changed in future versions.
 
     Args:
-        func(Callable): `func` takes as input a tensor or a list/tuple
-            of tensors and returns a tensor or a list/tuple of tensors.
-        inputs(list[Tensor]|tuple[Tensor]|Tensor): used as positional
-            arguments to evaluate `func`. `inputs` is accepted as one
-            tensor or a list/tuple of tensors.
-        v(list[Tensor]|tuple[Tensor]|Tensor|None, optional): the
-            tangent vector invovled in the JVP computation. `v` matches
-            the size and shape of `inputs`. `v` is Optional if `func`
-            returns a single tensor. Default value is None and in this
-            case is equivalent to all ones the same size of `inputs`.
-        create_graph(bool, optional): if `True`, gradients can
-            be evaluated on the results. If `False`, taking gradients
-            on the results is invalid. Default value is False.
-        allow_unused(bool, optional): In case that some Tensors of
-            `inputs` do not contribute to the computation of the output.
-            If `allow_unused` is False, an error will be raised,
-            Otherwise, the gradients of the said inputs are returned
-            None. Default value is False.
+        func(Callable): The ``func`` takes as input a Tensor or a Sequence
+            of Tensors and returns a Tensor or a Sequence of Tensors.
+        xs(Tensor|Sequence[Tensor]): Used as positional arguments to
+            evaluate ``func``. The ``xs`` is accepted as one Tensor or a
+            Sequence of Tensors.
+        v(Tensor|Sequence[Tensor]|None, optional): The tangent vector involved
+            in the JVP computation. The ``v`` matches the size and shape of
+            ``xs`` . Default value is None and in this case is equivalent to
+            all ones the same size of ``xs`` .
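+
+    Note:
+        Internally, ``jvp`` is computed with two vector-Jacobian products
+        using the double-backward trick. The following sketch only
+        illustrates the idea and is not part of the public API::
+
+            u = paddle.zeros_like(ys)  # auxiliary cotangent; value irrelevant
+            u.stop_gradient = False
+            g = paddle.grad(ys, xs, u, create_graph=True)  # vjp, linear in u
+            jvp_result = paddle.grad(g, u, v)  # differentiate w.r.t. u along v
+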
     Returns:
         output(tuple):
-            func_out(list[Tensor]|tuple[Tensor]|Tensor): the output of
-                `func(inputs)`
-            jvp(list[Tensor]): the pullback results of `v` on `func`
+
+        - func_out(Tensor|tuple[Tensor]): The output of ``func(xs)`` .
+        - jvp(Tensor|tuple[Tensor]): The jvp result.
+
+    Examples:
+
+        .. code-block:: python
+
+            import paddle
+
+
+            def func(x):
+                return paddle.matmul(x, x)
+
+
+            x = paddle.ones(shape=[2, 2], dtype='float32')
+            _, jvp_result = paddle.incubate.autograd.jvp(func, x)
+            print(jvp_result)
+            # Tensor(shape=[2, 2], dtype=float32, place=Place(gpu:0), stop_gradient=False,
+            #        [[4., 4.],
+            #         [4., 4.]])
+            v = paddle.to_tensor([[1.0, 0.0], [0.0, 0.0]])
+            _, jvp_result = paddle.incubate.autograd.jvp(func, x, v)
+            print(jvp_result)
+            # Tensor(shape=[2, 2], dtype=float32, place=Place(gpu:0), stop_gradient=False,
+            #        [[2., 1.],
+            #         [1., 0.]])
+
+    """
+    _check_inputs(func, xs, v)
+    # ``_separate`` breaks the dependencies between ``xs`` and other
+    # variables. See more details in ``_separate`` .
+    xs, v = _separate(xs), _separate(v)
+    ys = func(*xs) if isinstance(xs, typing.Sequence) else func(xs)
+    _check_v_shape(v, xs)
+    return ys, _double_backward_trick(ys, xs, v)
+
+
+def _double_backward_trick(ys, xs, v):
+    """Double backward trick for computing ``jvp`` by ``vjp``.
+    See details: https://j-towns.github.io/2017/06/12/A-new-trick.html
+    """
+    # The value of ys_grad is not important, it can be any random value in
+    # theory, but it's required to set stop_gradient=False.
+    ys_grad = _zeros_like_with_grad(ys)
+    xs_grad = _grad(ys, xs, ys_grad)
+    return _grad(xs_grad, ys_grad, v)
+
+
+def _zeros_like_with_grad(xs):
+    """Create a zero or zeros sequence Tensor like ``xs`` with a flag
+    ``stop_gradient=False`` .
+    """
+    if not isinstance(xs, typing.Sequence):
+        ys = paddle.zeros_like(xs)
+        ys.stop_gradient = False
+    else:
+        ys = []
+        for x in xs:
+            y = paddle.zeros_like(x)
+            y.stop_gradient = False
+            ys.append(y)
+    return ys
+
+
+class Jacobian(object):
+    r"""
+    Computes the Jacobian matrix of a given function.
+
+    If the function has multiple inputs and multiple outputs, during internal
+    implementation, all input tensors are concatenated after being flattened,
+    the batch dimension is retained, and the output is subject to the same
+    processing rules.
+
+    Once the Jacobian ``J`` is constructed, you can use a multidimensional
+    index to retrieve the submatrix of ``J``, the same as slicing a Tensor.
+    The submatrix is lazily evaluated along the row axis, and will be cached
+    once evaluated.
+
+    For example, supposing ``is_batched=True``, you can retrieve the submatrix
+    by the following methods:
+
+        * J[:], retrieving the full matrix.
+        * J[:, :, j], retrieving the partial derivatives w.r.t. the j'th input
+          variable.
+        * J[:, i, :], retrieving the partial derivatives w.r.t. the i'th output
+          variable.
+        * J[:, i, j], retrieving the partial derivatives w.r.t. the i'th output
+          variable and the j'th input variable.
+
+    Notes:
+
+        Ellipsis index is not supported currently.
+
+    Warning:
+        This API is in beta, the signatures could be changed in future versions.
+
+    Args:
+
+        func (Callable): A python function that takes a Tensor or a sequence of
+            Tensors as inputs(the first dimension is batch size) and
+            returns a Tensor or a sequence of Tensors.
+        xs (Tensor|Sequence[Tensor]): The input to the function ``func`` .
+        is_batched (bool): If true, the first axis is batch axis. Defaults to
+            False.
+
+    Returns:
+
+        Jacobian (Object): A python object that retains the Jacobian matrix.
+
+    Examples:
+
+        .. code-block:: python
+
+            import paddle
+
+
+            def func(x, y):
+                return paddle.matmul(x, y)
+
+
+            x = paddle.to_tensor([[1., 2.], [3., 4.]])
+            J = paddle.incubate.autograd.Jacobian(func, [x, x])
+            print(J[:, :])
+            # Tensor(shape=[4, 8], dtype=float32, place=Place(gpu:0), stop_gradient=False,
+            #        [[1., 3., 0., 0., 1., 0., 2., 0.],
+            #         [2., 4., 0., 0., 0., 1., 0., 2.],
+            #         [0., 0., 1., 3., 3., 0., 4., 0.],
+            #         [0., 0., 2., 4., 0., 3., 0., 4.]])
+
+            print(J[0, :])
+            # Tensor(shape=[8], dtype=float32, place=Place(gpu:0), stop_gradient=False,
+            #        [1., 3., 0., 0., 1., 0., 2., 0.])
+            print(J[:, 0])
+            # Tensor(shape=[4], dtype=float32, place=Place(gpu:0), stop_gradient=False,
+            #        [1., 2., 0., 0.])
+
+    """
+
+    def __init__(self, func, xs, is_batched=False):
+        if not is_batched:
+            self._jacobian = _JacobianNoBatch(func, xs)
+        else:
+            self._jacobian = _JacobianBatchFirst(func, xs)
+
+    def __getitem__(self, indexes):
+        return self._jacobian[indexes]
+
+    @property
+    def shape(self):
+        """The shape of flattened Jacobian matrix.
+        """
+        return self._jacobian.shape
+
+
+class Hessian(object):
+    """
+    Computes the Hessian matrix with a given ``func`` with respect to ``xs`` .
+
+    If the function has multiple inputs, during internal implementation,
+    all input tensors are concatenated after being flattened, and the batch
+    dimension is retained.
+
+    The Hessian submatrix is lazily evaluated, and can be retrieved with a
+    multidimensional index. See ``Jacobian`` for details.
+
+    Warning:
+        This API is in beta, the signatures could be changed in future versions.
+
+    Args:
+        func (Callable): A python function that takes a Tensor or a Tensor
+            sequence as inputs and returns a Tensor with shape
+            ``[batch_size, 1]`` with batch or ``[1]`` without batch.
+        xs (Tensor|Sequence(Tensor)): The input Tensor or Tensor sequence of
+            the function ``func``.
+        is_batched (bool): If true, the first axis is batch axis. Defaults to
+            False.
+
+    Returns:
+
+        Hessian (Object): A python object that retains the Hessian matrix.
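+
+    Note:
+        Internally, ``Hessian`` is constructed as the ``Jacobian`` of the
+        first-order gradient of ``func``, so the same lazy-evaluation and
+        multidimensional indexing rules as ``Jacobian`` apply to it.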
+
     Examples:
       .. code-block:: python
 
-        def func(x):
-          return paddle.matmul(x, x)
+            import paddle
 
-        x = paddle.ones(shape=[2, 2], dtype='float32')
-        output, inputs_grad = jvp(func, x)
-        print(inputs_grad)
-        # [Tensor(shape=[2, 2], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
-        #        [[2., 2.],
-        #         [2., 2.]])]
+            def reducer(x):
+                return paddle.sum(x * x)
 
-        v = paddle.to_tensor([[1.0, 0.0], [0.0, 0.0]])
-        output, inputs_grad = vjp(func, x, v)
-        print(inputs_grad)
-        # [Tensor(shape=[2, 2], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
-        #        [[1., 1.],
-        #         [0., 0.]])]
+            x = paddle.rand([2, 2])
+            h = paddle.incubate.autograd.Hessian(reducer, x)
+            print(h[:])
+            # Tensor(shape=[4, 4], dtype=float32, place=Place(gpu:0), stop_gradient=False,
+            #        [[2., 0., 0., 0.],
+            #         [0., 2., 0., 0.],
+            #         [0., 0., 2., 0.],
+            #         [0., 0., 0., 2.]])
     """
-    xs = _tensors(inputs, "inputs")
-    if v is not None:
-        v = _tensors(v, "v")
-
-    with gradient_scope(
-            xs, v, create_graph=create_graph,
-            allow_unused=allow_unused) as [xs, v, grad_fn, return_fn]:
-        outputs = func(*xs)
-        ys = _tensors(outputs, "outputs")
-        ys_grad = [zeros_like(y) for y in ys]
-        xs_grad = grad_fn(ys, xs, ys_grad, create_graph=True)
-        ys_grad = grad_fn(xs_grad, ys_grad, v)
-        outputs, ys_grad = return_fn(outputs), return_fn(ys_grad)
+    def __init__(self, func, xs, is_batched=False):
+        def _jac_func(*xs):
+            jac = Jacobian(func, xs, is_batched=is_batched)
+            if (is_batched and jac.shape[1] != 1) or (not is_batched and
+                                                      jac.shape[0] != 1):
+                raise RuntimeError(
+                    "The function given to Hessian should return a single element Tensor or a batched single element Tensor."
+                )
+            return jac[:, 0, :] if is_batched else jac[0, :]
+
+        self.symbolic = Jacobian(_jac_func, xs, is_batched=is_batched)
+
+    def __getitem__(self, indexes):
+        return self.symbolic[indexes]
 
-    return outputs, ys_grad
+    @property
+    def shape(self):
+        """The shape of flattened Hessian matrix.
+        """
+        return self.symbolic.shape
+
+
+class _Jacobian(object):
+    """The base class for computing Jacobian matrix.
+
+    ``_Jacobian`` implements the core logic of multidimensional indexing and
+    lazy evaluation for the Jacobian matrix; subclasses only need to overwrite
+    the following methods:
+
+    * ``_lazy_axis()``, returns the axis along which the matrix will be
+      lazily evaluated.
+    * ``_flatten(xs)``, flattens the inputs ``xs``.
+    * ``_evaluate(index)``, evaluates one slice along ``_lazy_axis`` .
+
+    Notes:
+
+        Because PaddlePaddle currently only supports reverse differentiation
+        via ``paddle.grad``, lazy evaluation is only supported along the row
+        of the Jacobian matrix, which means that slicing along the row will
+        get better performance.
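+
+        For example, for a non-batched Jacobian of shape ``(N, M)``,
+        ``J[0, :]`` triggers a single backward evaluation whose result is
+        cached, while ``J[:, 0]`` has to evaluate (and cache) all ``N`` rows
+        before the column can be assembled.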
+
+    """
+
+    def __init__(self, func, xs):
+        self._xs = _separate(xs)
+        self._ys = func(*_as_tensors(self._xs))
+        self._flatten_xs = self._flatten(_as_tensors(self._xs))
+        self._flatten_ys = self._flatten(_as_tensors(self._ys))
+        self._cache = {}
+
+    @property
+    def shape(self):
+        raise NotImplementedError
+
+    @property
+    def _lazy_axis(self):
+        """The axis along which the matrix is lazily evaluated."""
+        raise NotImplementedError
+
+    def _lazy_indexes(self, indexes):
+        idx = indexes[self._lazy_axis]
+        return (idx, ) if isinstance(
+            idx, int) else tuple(range(idx.start, idx.stop, idx.step))
+
+    def _flatten(self, xs):
+        raise NotImplementedError
+
+    def _shifted_indexes(self, indexes, lazy_axis_size=0):
+        idx = indexes[self._lazy_axis]
+        shifted_lazy_axis_idx = 0 if isinstance(
+            idx, int) else slice(0, lazy_axis_size, 1)
+        return indexes[:self._lazy_axis] + (shifted_lazy_axis_idx,
+                                            ) + indexes[self._lazy_axis + 1:]
+
+    def __getitem__(self, indexes):
+        indexes = _multi_index(indexes, self.shape)
+
+        if isinstance(indexes[self._lazy_axis], int):
+            other_indexes = indexes[:self._lazy_axis] + \
+                    indexes[self._lazy_axis+1:]
+            return self._cached_evaluate(indexes[self._lazy_axis])[
+                other_indexes]
+        lazy_indexes = self._lazy_indexes(indexes)
+        part_jac = paddle.stack(
+            [self._cached_evaluate(i) for i in lazy_indexes],
+            axis=self._lazy_axis)
+        return part_jac[self._shifted_indexes(indexes, len(lazy_indexes))]
+
+    def _cached_evaluate(self, k):
+        v = self._cache.get(k)
+        if v is None:
+            v = self._evaluate(k)
+            self._cache[k] = v
+        return v
+
+    def _evaluate(self, index):
+        """Evaluate one slice along the lazy axis."""
+        raise NotImplementedError
+
+
+class _JacobianNoBatch(_Jacobian):
+    """Compute Jacobian matrix without batch dimension.
+    Suppose the mapping is :math:`f: R^M \to R^N`, the output shape is
+    ``(N, M)`` .
+    """
+
+    def __init__(self, func, xs):
+        super(_JacobianNoBatch, self).__init__(func, xs)
+
+    @property
+    def shape(self):
+        return (self._flatten_ys.shape[0], self._flatten_xs.shape[0])
+
+    @property
+    def _lazy_axis(self):
+        return 0
+
+    def _flatten(self, xs):
+        return paddle.concat(tuple(x.reshape((-1, )) for x in xs))
+
+    def _evaluate(self, row_index):
+        return self._flatten(_grad(
+            self._flatten_ys[row_index],
+            self._xs, ))
+
+
+class _JacobianBatchLast(_Jacobian):
+    """Compute Jacobian matrix with batch at last axis.
+    Suppose the mapping is :math:`f: R^{M,B} \to R^{N,B}`, the output shape is
+    ``(N, M, B)`` .
+    """
+
+    def __init__(self, func, xs):
+        super(_JacobianBatchLast, self).__init__(func, xs)
+
+    @property
+    def shape(self):
+        return (self._flatten_ys.shape[0], self._flatten_xs.shape[0],
+                self._flatten_xs.shape[1])
+
+    @property
+    def _lazy_axis(self):
+        return 0
+
+    def _flatten(self, xs):
+        return paddle.concat(
+            tuple(x.reshape((-1, x.shape[-1])) for x in _as_tensors(xs)), 0)
+
+    def _evaluate(self, row):
+        return self._flatten(_grad(self._flatten_ys[row, :], self._xs))
+
+
+class _JacobianBatchFirst(_Jacobian):
+    """Compute Jacobian matrix with batch at first axis.
+    Suppose the mapping is :math:`f: R^{B,M} \to R^{B,N}`, the output shape is
+    ``(B, N, M)`` .
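+    For example (an illustrative shape walk-through), for ``xs`` of shape
+    ``(B, 4)`` and outputs of shape ``(B, 3)``, ``shape`` is ``(B, 3, 4)``
+    and the lazy axis is 1, i.e. one backward evaluation per flattened
+    output element.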
+    """
+
+    def __init__(self, func, xs):
+        super(_JacobianBatchFirst, self).__init__(func, xs)
+
+    @property
+    def shape(self):
+        return (self._flatten_xs.shape[0], self._flatten_ys.shape[1],
+                self._flatten_xs.shape[1])
+
+    @property
+    def _lazy_axis(self):
+        return 1
+
+    def _flatten(self, xs):
+        return paddle.concat(
+            tuple(x.reshape((x.shape[0], -1)) for x in _as_tensors(xs)), 1)
+
+    def _evaluate(self, row_index):
+        return self._flatten(_grad(self._flatten_ys[:, row_index], self._xs))
+
+
+def _multi_index(indexes, shape):
+    """A tool for parsing N-dimensional index into a standard format.
+
+    Currently supports the following input formats:
+        * ([positive|negative|slice], ...), the right-most elements can be
+          omitted.
+
+    The standard format after conversion is a slice tuple which contains N
+    elements:
+        * ([positive|slice], ..., [positive|slice])
+
+    Notes:
+        Ellipsis indexes such as ``(..., i), (i, ...)`` are not supported.
+
+    Args:
+        indexes (tuple): The input indexes.
+        shape (tuple): The input shape.
+
+    Returns:
+        tuple: The standard format index as the above description.
+    """
+    indexes = indexes if isinstance(indexes, typing.Sequence) else (indexes, )
+    if any(isinstance(i, type(Ellipsis)) for i in indexes):
+        raise IndexError('Ellipsis index currently is not supported.')
+    # Fill the right-most elements.
+    indexes = indexes + (slice(0, None, None), ) * (len(shape) - len(indexes))
+    # Convert to positive index.
+    positive_indexes = []
+    for i, index in enumerate(indexes):
+        if isinstance(index, slice):
+            index = slice(index.start or 0, index.stop or shape[i],
+                          index.step or 1)
+            positive_indexes.append(
+                slice(
+                    index.start + shape[i] if index.start < 0 else index.start,
+                    index.stop + shape[i] if index.stop < 0 else index.stop,
+                    # Negative step means index backward, no need to convert to
+                    # positive integer.
+                    index.step))
+        elif isinstance(index, int):
+            positive_indexes.append(index + shape[i] if index < 0 else index)
+        else:
+            raise TypeError(f'Not supported index type {index}.')
+    return tuple(positive_indexes)
+
+
+def _as_tensors(xs):
+    return (xs, ) if isinstance(xs, framework.Variable) else xs
+
+
+def _stack_tensor_or_return_none(origin_list):
+    assert len(origin_list) > 0, "Cannot stack an empty list"
+    return paddle.stack(
+        origin_list, axis=0) if isinstance(
+            origin_list[0], paddle.fluid.framework.Variable) else None
+
+
+def _replace_none_with_zero_tensor(xs, refs):
+    if xs is None:
+        xs = paddle.zeros_like(refs)
+        xs.stop_gradient = refs.stop_gradient
+        return xs
+    elif isinstance(xs, typing.Sequence):
+        return tuple(
+            _replace_none_with_zero_tensor(x, refs[i])
+            for i, x in enumerate(xs))
+    else:
+        return xs
+
+
+def _grad(ys, xs, v=None):
+    """A gradient function that can be used in dynamic graph and static graph.
+
+    The ``grad`` combines ``paddle.grad`` used in dynamic graph and
+    ``paddle.static.gradients`` used in static graph, and makes the following
+    changes:
+
+    * The ``allow_unused`` flag is removed and set to True internally;
+      None in the output will be replaced by a zero tensor.
+    * The ``create_graph`` flag is removed and set to True internally,
+      which only makes sense in dynamic graph.
+    * When ``xs`` is a single Tensor, ``paddle.grad`` returns a list which
+      only contains one Tensor. This may confuse users, so in this case
+      ``_grad`` returns a single Tensor instead.
+
+    Args:
+        ys (Tensor|Sequence[Tensor]): The output tensor or tensor sequence of
+            the graph to compute gradients.
+        xs (Tensor|Sequence[Tensor]): The input tensor or tensor sequence of
+            the graph to compute gradients. The returned values of this API
+            are the gradients of ``inputs`` .
+        v (Tensor|Sequence[Tensor]|None, optional): The initial gradient
+            values of ``outputs`` . If ``v`` is None, the initial gradient
+            values of ``outputs`` would be Tensors filled with 1; if ``v`` is
+            not None, it must have the same length as ``outputs`` , and in
+            this case, the initial gradient value of the i-th ``outputs``
+            would be: (1) a Tensor filled with 1 when the i-th element of
+            ``v`` is None; (2) the i-th element of ``v`` when the i-th
+            element of ``v`` is a Tensor. Default None.
+
+    Returns:
+        Tensor|tuple[Tensor]: Tensor or a tuple of Tensors, whose length is the
+            same as the Tensor number inside inputs, and the i-th returned
+            Tensor is the sum of gradients of outputs with respect to the i-th
+            inputs.
+    """
+    if paddle.fluid._non_static_mode():
+        xs_grad = paddle.grad(ys, xs, v, create_graph=True, allow_unused=True)
+    else:
+        xs_grad = paddle.static.gradients(ys, xs, v)
+
+    if isinstance(xs, paddle.fluid.framework.Variable):
+        xs_grad = xs_grad[0]
+
+    return _replace_none_with_zero_tensor(xs_grad, xs)
+
+
+def _separate(xs):
+    """
+    ``_separate`` separates ``xs`` from the computation graph through ``clone``
+    or ``detach`` .
+
+    Internally, ``paddle.grad(ys, xs)`` is a stateful API implemented based on
+    the computational graph, which will reduce gradients along all paths from
+    ys to xs.
+
+    However, functional autograd APIs such as ``vjp`` and ``jvp`` are
+    stateless, and only compute gradients with a given ``func`` .
+
+    For example, given a ``func`` :math:`y0=f(x0)`, supposing the forward
+    paths are: ``x0 -> y0``, ``x0 -> x1 -> y0`` .
+    ``paddle.grad(y0, x0)`` will reduce gradients along ``y0->x0`` and
+    ``y0->x1->x0``, while ``vjp`` only needs to reduce along ``y0->x0``.
+
+    So, it's necessary to clone or detach xs to break the dependencies with
+    other variables.
+
+    Examples:
+
+        .. code-block:: python
+
+            import paddle
+            from paddle.autograd.functional import _separate
+
+
+            def func(x, y):
+                return x * y
+
+
+            x = paddle.ones((1,))
+            x.stop_gradient = False
+
+            y = func(x, x)
+            print(paddle.grad(y, x))
+            # [Tensor(shape=[1], dtype=float32, place=Place(gpu:0), stop_gradient=True,
+            #        [2.])]
+
+            x1, x2 = _separate((x, x))
+            y = func(x1, x2)
+            print(paddle.grad(y, x1))
+            # [Tensor(shape=[1], dtype=float32, place=Place(gpu:0), stop_gradient=True,
+            #        [1.])]
+
+    """
+    if isinstance(xs, typing.Sequence):
+        return tuple(_single_separate(x) for x in xs)
+    else:
+        return _single_separate(xs)
+
+
+def _single_separate(x):
+    if x is None:  # x may be None because the grad input v defaults to None.
+        return x
+    if not x.stop_gradient:
+        return paddle.clone(x)
+    else:  # use detach to share memory when gradients are not needed.
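+        # Note: ``detach`` returns a Tensor that shares storage with ``x``,
+        # so no data is copied; resetting stop_gradient below re-enables
+        # gradient tracking on that shared buffer.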
+        x = x.detach()
+        x.stop_gradient = False
+        return x
+    return x
+
+
+def _check_inputs(func, xs, v=None):
+    if not callable(func):
+        raise TypeError(f"Expected 'func' is Callable, but got {type(func)}.")
+
+    if not isinstance(xs, (framework.Variable, typing.Sequence)):
+        raise TypeError(f"Expected 'xs' is a Tensor|Sequence[Tensor],"
+                        f"but got {type(xs)}.")
+    if isinstance(xs, typing.Sequence) and not all(
+            isinstance(x, framework.Variable) for x in xs):
+        raise TypeError("All elements of 'xs' should be Tensor.")
+
+    if not isinstance(v, (framework.Variable, typing.Sequence, type(None))):
+        raise TypeError(
+            f"Expected 'v' is Tensor|Sequence[Tensor]|None, but got {type(v)}.")
+
+    if isinstance(v, typing.Sequence) and not all(
+            isinstance(e, framework.Variable) for e in v):
+        raise TypeError("All elements of 'v' should be Tensor.")
+
+
+def _check_v_shape(v, refs):
+    if v is None:
+        return
+
+    v, refs = _as_tensors(v), _as_tensors(refs)
+    if len(refs) != len(v):
+        raise RuntimeError(f"The argument v is a tuple of invalid length:"
+                           f"should be {len(refs)} but got {len(v)}.")
+
+    for index, (element_v, element_ref) in enumerate(zip(v, refs)):
+        if element_v.shape != element_ref.shape:
+            raise RuntimeError(
+                f"The v[{index}] has invalid shape: should "
+                f"be {element_ref.shape} but got {element_v.shape}.")
 
 
 @framework.dygraph_only
@@ -354,16 +804,18 @@ def func(x, y):
     #         [0., 0., 0., 2.]]), None))
     '''
-    inputs = _tensors(inputs, "inputs")
-    outputs = _tensors(func(*inputs), "outputs")
+    inputs = _as_tensors(inputs)
+    outputs = _as_tensors(func(*inputs))
     fin_size = len(inputs)
     fout_size = len(outputs)
-    flat_outputs = tuple(reshape(output, shape=[-1]) for output in outputs)
+    flat_outputs = tuple(
+        paddle.reshape(
+            output, shape=[-1]) for output in outputs)
     jacobian = tuple()
     for i, flat_output in enumerate(flat_outputs):
         jac_i = list([] for _ in range(fin_size))
         for k in range(len(flat_output)):
-            row_k = grad(
+            row_k = paddle.grad(
                 flat_output[k],
                 inputs,
                 create_graph=create_graph,
@@ -371,7 +823,7 @@ def func(x, y):
                 allow_unused=allow_unused)
             for j in range(fin_size):
                 jac_i[j].append(
-                    reshape(
+                    paddle.reshape(
                         row_k[j], shape=[-1])
                     if isinstance(row_k[j], paddle.Tensor) else None)
         jacobian += (tuple(
@@ -419,7 +871,7 @@ def batch_jacobian(func, inputs, create_graph=False, allow_unused=False):
         be a tuple of Tensors. If both of inputs and outputs are Tensor
         list/tuple, then the Jacobian will be a tuple of tuple of
         Tensors. Noted that the first dimension of inputs is batch size.
- + For example, the inputs shape and outputs shape of function ``func` is [batch_size, num] and [batch_size, num] respectively, then the Jacobian will be a Tensor with @@ -489,10 +941,10 @@ def func(x, y): # [0., 1., 0., 1., 0., 1., 0., 1.]]), Tensor(shape=[2, 8], dtype=float64, place=CUDAPlace(0), stop_gradient=True, # [[1., 0., 1., 0., 1., 0., 1., 0.], # [0., 1., 0., 1., 0., 1., 0., 1.]])) - + ''' - inputs = _tensors(inputs, "inputs") - outputs = _tensors(func(*inputs), "outputs") + inputs = _as_tensors(inputs) + outputs = _as_tensors(func(*inputs)) batch_size = inputs[0].shape[0] for input in inputs: assert input.shape[ @@ -503,13 +955,13 @@ def func(x, y): fin_size = len(inputs) fout_size = len(outputs) flat_outputs = tuple( - reshape( + paddle.reshape( output, shape=[batch_size, -1]) for output in outputs) jacobian = tuple() for i, flat_output in enumerate(flat_outputs): jac_i = list([] for _ in range(fin_size)) for k in range(flat_output.shape[1]): - row_k = grad( + row_k = paddle.grad( flat_output[:, k], inputs, create_graph=create_graph, @@ -517,7 +969,7 @@ def func(x, y): allow_unused=allow_unused) for j in range(fin_size): jac_i[j].append( - reshape( + paddle.reshape( row_k[j], shape=[-1]) if isinstance(row_k[j], paddle.Tensor) else None) jacobian += (tuple( @@ -569,7 +1021,7 @@ def batch_hessian(func, inputs, create_graph=False, allow_unused=False): the inputs shape and outputs shape of function ``func` is [batch_size, num] and [batch_size, 1] respectively, then the batched Hessian will be a Tensor with a shape of [num, batch_size * num]. - + Why the final shape in this case is that? because batch_hessian will create a inner func(the wrapper of paddle.grad() func) to computes the sum of gradients of `outputs` with respect to each `inputs`, @@ -579,7 +1031,7 @@ def batch_hessian(func, inputs, create_graph=False, allow_unused=False): matrix of the ``i``th column output(Noted that this output means the first order differentiation) and the ``j``th input and will have same dtype and device as the corresponding input. Other situations can be deduced by analogy. - + Examples 1: .. code-block:: python @@ -592,8 +1044,8 @@ def batch_hessian(func, inputs, create_graph=False, allow_unused=False): def func(x): return paddle.matmul(x * x, weight)[:, 0:1] - - + + x.stop_gradient = False batch_hessian = paddle.autograd.batch_hessian(func, x) print(batch_hessian) @@ -612,7 +1064,7 @@ def func(x): def func(x, y): return paddle.matmul(x * x * y * y, weight)[:, 0:1] - + x.stop_gradient = False y.stop_gradient = False batch_hessian = paddle.autograd.batch_hessian(func, [x, y]) @@ -629,7 +1081,7 @@ def func(x, y): # Tensor(shape=[2, 8], dtype=float64, place=CUDAPlace(0), stop_gradient=True, # [[2., 0., 2., 0., 2., 0., 2., 0.], # [0., 2., 0., 2., 0., 2., 0., 2.]]))) - + Examples 3: .. 
code-block:: python @@ -639,7 +1091,7 @@ def func(x, y): x = paddle.ones(shape=(4, 2), dtype='float64') weight = paddle.ones(shape=(2, 4), dtype='float64') y = paddle.ones(shape=(4, 2), dtype='float64') - + def func(x, y): return paddle.matmul(x * x, weight)[:, 0:1] @@ -652,7 +1104,7 @@ def func(x, y): # [0., 2., 0., 2., 0., 2., 0., 2.]]), None), (None, None)) ''' - inputs = _tensors(inputs, "inputs") + inputs = _as_tensors(inputs) outputs = func(*inputs) batch_size = inputs[0].shape[0] for input in inputs: @@ -663,7 +1115,7 @@ def func(x, y): ], "The function to compute batched Hessian matrix should return a Tensor of shape [batch_size, 1]" def jac_func(*ins): - grad_inputs = grad( + grad_inputs = paddle.grad( outputs, ins, create_graph=True, @@ -715,7 +1167,7 @@ def hessian(func, inputs, create_graph=False, allow_unused=False): def func(x): return paddle.sum(paddle.matmul(x, x)) - + x = paddle.ones(shape=[2, 2], dtype='float32') x.stop_gradient = False hessian = paddle.autograd.hessian(func, x) @@ -733,7 +1185,7 @@ def func(x): def func(x, y): return paddle.sum(paddle.matmul(x, y)) - + x = paddle.ones(shape=[2, 2], dtype='float32') y = paddle.ones(shape=[2, 2], dtype='float32') x.stop_gradient = False @@ -768,7 +1220,7 @@ def func(x, y): def func(x, y): return paddle.sum(paddle.matmul(x, x)) - + x = paddle.ones(shape=[2, 2], dtype='float32') y = paddle.ones(shape=[2, 2], dtype='float32') x.stop_gradient = False @@ -782,14 +1234,14 @@ def func(x, y): # [0., 1., 1., 2.]]), None), (None, None)) ''' - inputs = _tensors(inputs, "inputs") + inputs = _as_tensors(inputs) outputs = func(*inputs) assert isinstance(outputs, paddle.Tensor) and outputs.shape == [ 1 ], "The function to compute Hessian matrix should return a Tensor with a single element" def jac_func(*ins): - grad_inputs = grad( + grad_inputs = paddle.grad( outputs, ins, create_graph=True, @@ -803,7 +1255,6 @@ def jac_func(*ins): jac_func, inputs, create_graph=create_graph, allow_unused=allow_unused) -@framework.dygraph_only def vhp(func, inputs, v=None, create_graph=False, allow_unused=False): ''' .. 
note:: @@ -839,7 +1290,7 @@ def vhp(func, inputs, v=None, create_graph=False, allow_unused=False): import paddle def func(x): return paddle.sum(paddle.matmul(x, x)) - + x = paddle.ones(shape=[2, 2], dtype='float32') x.stop_gradient = False vx = paddle.ones(shape=[2, 2], dtype='float32') * 2 @@ -856,7 +1307,7 @@ def func(x): import paddle def func(x): return paddle.sum(paddle.matmul(x, x)) - + x = paddle.ones(shape=[2, 2], dtype='float32') x.stop_gradient = False vhp_rslt = paddle.autograd.vhp(func, x) @@ -872,7 +1323,7 @@ def func(x): import paddle def func(x, y): return paddle.sum(paddle.matmul(x, x)) - + x = paddle.ones(shape=[2, 2], dtype='float32') x.stop_gradient = False y = paddle.ones(shape=[2, 2], dtype='float32') @@ -887,177 +1338,17 @@ def func(x, y): # [[8., 8.], # [8., 8.]]), None]) ''' - xs = _tensors(inputs, "inputs") + xs = _as_tensors(inputs) if v is not None: - v = _tensors(v, "v") - - with gradient_scope( - xs, v, create_graph=create_graph, - allow_unused=allow_unused) as [xs, v, grad_fn, return_fn]: - outputs = func(*xs) - ys = _tensors(outputs, "outputs") - assert len(ys) == 1 and isinstance( - ys[0], paddle.Tensor - ) and ys[0].shape == [ - 1 - ], "The function to compute vhp should return a Tensor with a single element" - jac = grad_fn(ys, xs, create_graph=True) - vhp = grad_fn(jac, xs, v) - outputs, vhp = return_fn(outputs), return_fn(vhp) + v = _as_tensors(v) + xs, v = _separate(xs), _separate(v) + outputs = func(*xs) + ys = _as_tensors(outputs) + assert len(ys) == 1 and isinstance( + ys[0], framework.Variable + ) and ys[0].shape == [ + 1 + ], "The function to compute vhp should return a Tensor with a single element" + jac = _grad(ys, xs) + vhp = _grad(jac, xs, v) return outputs, vhp - - -class Jacobian(object): - r""" - Computes the Jacobian matrix of function `func`, which may take as input - single or multiple tensor typed arguments and output a single tensor or - multiple tensors. - - In case `func` is multi-input and multi-output, i.e., - - func: Callable[[Tensor, ...], [Tensor, ...]] - - `func` is treated as a vector valued function with all its inputs flattened - into a single one dimensional tensor, or a two dimensional tensor with the - first dimension retained as the batching dimension. The same rule applies to - the function outputs. - - Once the Jacobian J is constructed, there are four ways to retrieve the - partial derivatives. - - - J[:], retrieving the full matrix. - - - J[:, j], retrieving the partial derivatives w.r.t. the j'th input - variable. - - - J[i, :], retrieving the partial derivatives w.r.t. the i'th output - variable. - - - J[i, j], retrieving the partial derivatives w.r.t. the i'th output - variable and the j'th input variable. - - Examples: - .. code-block:: python - import paddle - import numpy as np - - def func(xs): - x, y = xs - return paddle.matmul(x, y) - - main = fluid.Program() - startup = fluid.Program() - with fluid.program_guard(main, startup): - x = paddle.static.data(name='x', shape=[2, 2], dtype='float32') - JJ = paddle.autograd.functional.Jacobian(func, [x, x]) - nrow, ncol = JJ.shape() - full_jacobian = JJ[:] - place = fluid.CUDAPlace(0) - exe = fluid.Executor(place) - exe.run(startup) - - feeds = {'x': np.array([[2., 2.], [2., 1.]]).astype('float32')} - jacobian = exe.run(main, feed=feeds, fetch_list=[full_jacobian])[0] - print(jacobian) - # [[4. 2. 2. 0. 4. 2. 2. 0.] - # [2. 3. 0. 2. 2. 3. 0. 2.] - # [2. 0. 3. 2. 2. 0. 3. 2.] - # [0. 2. 2. 2. 0. 2. 2. 
2.]] - """ - - def __init__(self, func, inputs, batch=False): - r"""Constructing a Jacobian matrix. - - Parameters: - func (Callable): a Python function that takes as input a Tensor - or a Tensor list and outputs a Tensor or a Tensor list. - inputs (Tensor|list[Tensor]): a Tensor or a list of Tensors as - `func`'s input. - batch (bool): if True the 0'th axis is considered the batch - dimension, both on input and output. - """ - - def enable_grads(inputs): - if isinstance(inputs, (list, tuple)): - for x in inputs: - x.stop_gradient = False - else: - assert isinstance(inputs, paddle.fluid.framework.Variable), ( - f"Expecting {inputs} to be paddle.fluid.framework.Variable," - f" however it's found to be a(n) {type(inputs)}.") - inputs.stop_gradient = False - return inputs - - self.batch = batch - self.xs = enable_grads(inputs) - ys = func(inputs) - if not isinstance(ys, list): - ys = [ys] - self.y = self.flatten_all(ys) - self.ydim = self.y.shape[-1] - self.xdim = self.flatten_all(inputs).shape[-1] - self.bdim = self.y.shape[0] - self.jacobian = {} - - def flatten(self, x): - to = [x.shape[0], -1] if self.batch else [-1] - return x.reshape(to) - - def flatten_all(self, xs): - if isinstance(xs, (list, tuple)): - return paddle.concat([self.flatten(x) for x in xs], axis=-1) - else: - return self.flatten(xs) - - def shape(self): - return (self.ydim, self.xdim) - - def __getitem__(self, tup): - if hasattr(tup, '__iter__'): - i, j = tup - else: - i, j = tup, None - - full = isinstance(i, slice) - - if full: - if 'full' not in self.jacobian: - rows = [ - self.flatten_all(gradients(self.y[..., i], self.xs)) - for i in range(self.ydim) - ] - self.jacobian['full'] = full_jacobian = paddle.stack(rows) - else: - full_jacobian = self.jacobian['full'] - - return full_jacobian[i] if j is None else full_jacobian[i][..., j] - - assert 0 <= i < self.ydim, f"Jacobian index i={i} is not valid." - assert j is None or isinstance(j, slice) or (0 <= j < self.xdim), ( - f"Jacobian index j={j} is not valid.") - if 'full' in self.jacobian: - JJ = self.jacobian['full'] - else: - JJ = self.jacobian - if i not in self.jacobian: - self.jacobian[i] = self.flatten_all( - gradients(self.y[..., i], self.xs)) - - if j is None: - return JJ[i] - else: - return JJ[i][..., j] - - -class Hessian(object): - def __init__(self, func, inputs, batch=False): - f_x = lambda xs: Jacobian(func, xs, batch=batch)[0] - self.symbolic = Jacobian(f_x, inputs, batch=batch) - self.xs = inputs - self.batch = batch - - def __getitem__(self, tup): - return self.symbolic[tup] - - def shape(self): - return self.symbolic.shape() diff --git a/python/paddle/autograd/utils.py b/python/paddle/autograd/utils.py deleted file mode 100644 index 710c9ee18dfbf..0000000000000 --- a/python/paddle/autograd/utils.py +++ /dev/null @@ -1,45 +0,0 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-
-import paddle
-
-
-def _tensors(ts, name):
-    if isinstance(ts, (list, tuple)):
-        assert len(ts) > 0, "{} connot be empty".format(name)
-        for each_t in ts:
-            assert isinstance(
-                each_t, paddle.Tensor
-            ) or each_t is None, "Elements of {} must be paddle.Tensor or None".format(
-                name)
-        return list(ts)
-    else:
-        assert isinstance(ts, paddle.Tensor), "{} must be Tensor".format(name)
-        return [ts]
-
-
-def _stack_tensor_or_return_none(origin_list):
-    assert len(origin_list) > 0, "Can't not stack an empty list"
-    return paddle.stack(
-        origin_list, axis=0) if isinstance(origin_list[0],
-                                           paddle.Tensor) else None
-
-
-def _replace_none_with_zero_tensor(t, spec_t):
-    if t is None:
-        zero_t = paddle.zeros(shape=spec_t.shape, dtype=spec_t.dtype)
-        zero_t.stop_gradient = spec_t.stop_gradient
-        return zero_t
-    else:
-        return t
diff --git a/python/paddle/fluid/tests/unittests/autograd/CMakeLists.txt b/python/paddle/fluid/tests/unittests/autograd/CMakeLists.txt
index 6d9625483ea82..1f69abac01ac6 100644
--- a/python/paddle/fluid/tests/unittests/autograd/CMakeLists.txt
+++ b/python/paddle/fluid/tests/unittests/autograd/CMakeLists.txt
@@ -6,6 +6,5 @@ foreach(TEST_OP ${TEST_OPS})
   py_test_modules(${TEST_OP} MODULES ${TEST_OP} ENVS ${GC_ENVS})
 endforeach(TEST_OP)
 
-set_tests_properties(test_jacobian PROPERTIES TIMEOUT 50)
-set_tests_properties(test_hessian PROPERTIES TIMEOUT 50)
-set_tests_properties(test_vhp PROPERTIES TIMEOUT 50)
+set_tests_properties(test_autograd_functional_dynamic PROPERTIES TIMEOUT 100)
+set_tests_properties(test_autograd_functional_static PROPERTIES TIMEOUT 100)
diff --git a/python/paddle/fluid/tests/unittests/autograd/config.py b/python/paddle/fluid/tests/unittests/autograd/config.py
new file mode 100644
index 0000000000000..311ca49d39555
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/autograd/config.py
@@ -0,0 +1,49 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import paddle
+
+DEVICES = [paddle.CPUPlace()]
+if paddle.is_compiled_with_cuda():
+    DEVICES.append(paddle.CUDAPlace(0))
+
+DEFAULT_DTYPE = 'float64'
+
+# The numerical tolerances for derivatives of different orders and dtypes.
+# These are empirical values provided by the Paddle Science team.
+TOLERANCE = {
+    "float32": {
+        "first_order_grad": {
+            "rtol": 1e-3,
+            "atol": 1e-3,
+            "eps": 1e-4
+        },
+        "second_order_grad": {
+            "rtol": 1e-2,
+            "atol": 1e-2,
+            "eps": 1e-2
+        }
+    },
+    "float64": {
+        "first_order_grad": {
+            "rtol": 1e-7,
+            "atol": 1e-7,
+            "eps": 1e-7
+        },
+        "second_order_grad": {
+            "rtol": 1e-5,
+            "atol": 1e-5,
+            "eps": 1e-5
+        }
+    }
+}
diff --git a/python/paddle/fluid/tests/unittests/autograd/test_autograd_functional_dynamic.py b/python/paddle/fluid/tests/unittests/autograd/test_autograd_functional_dynamic.py
new file mode 100644
index 0000000000000..e46c532eb05db
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/autograd/test_autograd_functional_dynamic.py
@@ -0,0 +1,1233 @@
+# Copyright (c) 2021 PaddlePaddle Authors.
All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import collections +import typing +import unittest + +import numpy as np +import paddle +import paddle.compat as cpt +import paddle.nn.functional as F +from paddle.autograd.functional import _as_tensors + +import config +import utils +from utils import (_compute_numerical_batch_hessian, _compute_numerical_hessian, + _compute_numerical_vhp, _compute_numerical_jacobian, + _compute_numerical_batch_jacobian) +from utils import matmul, mul, nested, o2, pow, reduce, reduce_dim, unuse + + +def make_v(f, inputs): + outputs = _as_tensors(f(*inputs)) + return [paddle.ones_like(x) for x in outputs] + + +class TestAutogradFunctional(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.RAW_INPUTS = { + 'a': [1.0], + 'b': [1.0, 2.0], + 'c': [3.0, 4.0], + 'd': [[2.0], [3.0]], + 'A': [[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]], + 'B': [[1.0, 2.0, 3.0], [2.0, 3.0, 4.0]], + } + + def setUp(self): + pass + + def gen_input(self, inp, stop_gradient=False): + if isinstance(inp, paddle.Tensor): + return inp + return paddle.to_tensor( + self.RAW_INPUTS[inp], stop_gradient=stop_gradient) + + def gen_inputs(self, inputs): + if isinstance(inputs, list): + inputs = [self.gen_input(x) for x in inputs] + else: + inputs = [self.gen_input(inputs)] + return inputs + + def gen_test_pairs(self, + func, + inputs, + v=None, + create_graph=False, + allow_unused=False): + def vjp_test(): + nonlocal v + xs = self.gen_inputs(inputs) + if v is not None: + v = self.gen_inputs(v) + outputs, inputs_grad = paddle.autograd.vjp(func, xs, v) + else: + outputs, inputs_grad = paddle.autograd.vjp(func, xs) + return outputs, inputs_grad + + def grad_test(): + nonlocal v + xs = self.gen_inputs(inputs) + if v is not None: + v = self.gen_inputs(v) + outputs = func(*xs) + if v is not None: + inputs_grad = paddle.grad( + outputs, + xs, + v, + create_graph=create_graph, + allow_unused=allow_unused) + else: + inputs_grad = paddle.grad( + outputs, + xs, + create_graph=create_graph, + allow_unused=allow_unused) + return outputs, inputs_grad + + return vjp_test, grad_test + + def gen_jvp_tests(self, + func, + inputs, + v=None, + create_graph=False, + allow_unused=False): + def jvp_test(): + nonlocal v + xs = self.gen_inputs(inputs) + if v is not None: + v = self.gen_inputs(v) + outputs, outputs_grad = paddle.autograd.jvp( + func, + xs, + v, + create_graph=create_graph, + allow_unused=allow_unused) + else: + outputs, outputs_grad = paddle.autograd.jvp( + func, + xs, + create_graph=create_graph, + allow_unused=allow_unused) + return outputs, outputs_grad + + return jvp_test + + def check_results(self, ref, res): + type_error = 'Result is different than expected in shape or type' + value_error = 'Result is different than expected values' + if ref is None: + self.assertTrue(res is None, type_error) + elif isinstance(ref, paddle.Tensor): + self.assertTrue(isinstance(res, paddle.Tensor), type_error) + np.testing.assert_allclose(res, ref) + else: + self.assertTrue(len(res) == len(ref), 
type_error) + for i in range(len(ref)): + self.check_results(ref[i], res[i]) + return True + + +class TestVJP(TestAutogradFunctional): + def test_vjp_i1o1(self): + test_cases = [ + [reduce, 'A'], # noqa + [reduce_dim, 'A'], # noqa + ] # noqa + for f, inputs in test_cases: + vjp, grad = self.gen_test_pairs(f, inputs) + vjp_result, grad_result = vjp(), grad() + self.check_results(grad_result, vjp_result) + + def test_vjp_i2o1(self): + test_cases = [ + [matmul, ['A', 'B']], # noqa + [mul, ['b', 'c']], # noqa + ] # noqa + for f, inputs in test_cases: + vjp, grad = self.gen_test_pairs(f, inputs) + vjp_result, grad_result = vjp(), grad() + self.check_results(grad_result, vjp_result) + + def test_vjp_i2o2(self): + test_cases = [ + [o2, ['A', 'A']], # noqa + ] # noqa + for f, inputs in test_cases: + inputs = self.gen_inputs(inputs) + v = make_v(f, inputs) + vjp, grad = self.gen_test_pairs(f, inputs, v=v) + vjp_result, grad_result = vjp(), grad() + self.check_results(grad_result, vjp_result) + + def test_vjp_i2o2_omitting_v(self): + test_cases = [ + [o2, ['A', 'A']], # noqa + ] # noqa + for f, inputs in test_cases: + inputs = self.gen_inputs(inputs) + vjp, grad = self.gen_test_pairs(f, inputs) + vjp_result, grad_result = vjp(), grad() + self.check_results(grad_result, vjp_result) + + def test_vjp_nested(self): + x = self.gen_input('a') + test_cases = [ + [nested(x), 'a'], # noqa + ] + for f, inputs in test_cases: + vjp, grad = self.gen_test_pairs(f, inputs) + vjp_result, grad_result = vjp(), grad() + self.check_results(grad_result, vjp_result) + + def test_vjp_aliased_input(self): + x = self.gen_input('a') + ref = self.gen_test_pairs(nested(x), 'a')[0] + aliased = self.gen_test_pairs(nested(x), x)[0] + ref_result, aliased_result = ref(), aliased() + self.check_results(ref_result, aliased_result) + + +@utils.place(config.DEVICES) +@utils.parameterize( + (utils.TEST_CASE_NAME, 'fun', 'xs', 'v', 'expected_exception'), ( + ('v_shape_not_equal_ys', utils.square, np.random.rand(3), + np.random.rand(1), RuntimeError), )) +class TestVJPException(unittest.TestCase): + def test_vjp(self): + with self.assertRaises(self.expected_exception): + paddle.autograd.vjp(self.fun, + paddle.to_tensor(self.xs), + paddle.to_tensor(self.v)) + + +def jac(grad_fn, f, inputs): + assert grad_fn in [paddle.autograd.vjp, paddle.autograd.jvp] + if grad_fn is paddle.autograd.jvp: + vs = [paddle.zeros_like(x) for x in inputs] + else: + outputs = f(*inputs) + if isinstance(outputs, paddle.Tensor): + outputs = [outputs] + vs = [paddle.zeros_like(y) for y in outputs] + JJ_cols = [] + for i, v in enumerate(vs): + v = v.flatten() + for j in range(len(v)): + _v = paddle.zeros_like(v).detach() + _v[j] = 1.0 + _v = _v.reshape(vs[i].shape) + _vs = vs.copy() + _vs[i] = _v + _, grads = grad_fn(f, inputs, _vs) + d_outs = paddle.concat([d_out.flatten() for d_out in grads]) + JJ_cols.append(d_outs) + # JJ is the fully unrolled jacobian + JJ = paddle.stack(JJ_cols) + if grad_fn is paddle.autograd.vjp: + JJ = JJ.t() + return JJ + + +class TestJVP(TestAutogradFunctional): + def test_jvp_i1o1(self): + test_cases = [ + [reduce, 'A'], # noqa + [reduce_dim, 'A'], # noqa + ] # noqa + for f, inputs in test_cases: + inputs = self.gen_inputs(inputs) + forward_jac = jac(paddle.autograd.jvp, f, inputs) + reverse_jac = jac(paddle.autograd.vjp, f, inputs) + self.check_results(forward_jac, reverse_jac) + + def test_jvp_i2o1(self): + test_cases = [ # noqa + [matmul, ['A', 'B']], # noqa + ] # noqa + for f, inputs in test_cases: + inputs = self.gen_inputs(inputs) 
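+            # jac() unrolls the full Jacobian column-by-column via jvp and
+            # row-by-row via vjp; both constructions must agree for the same
+            # function and inputs.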
+ forward_jac = jac(paddle.autograd.jvp, f, inputs) + reverse_jac = jac(paddle.autograd.vjp, f, inputs) + self.check_results(forward_jac, reverse_jac) + + def test_jvp_i2o2(self): + test_cases = [ # noqa + [o2, ['A', 'A']], # noqa + ] # noqa + for f, inputs in test_cases: + inputs = self.gen_inputs(inputs) + forward_jac = jac(paddle.autograd.jvp, f, inputs) + reverse_jac = jac(paddle.autograd.vjp, f, inputs) + self.check_results(forward_jac, reverse_jac) + + def test_jvp_i2o2_omitting_v(self): + test_cases = [ # noqa + [o2, ['A', 'A']], # noqa + ] # noqa + for f, inputs in test_cases: + inputs = self.gen_inputs(inputs) + results_omitting_v = paddle.autograd.jvp(f, inputs) + v = [paddle.ones_like(x) for x in inputs] + results_with_v = paddle.autograd.jvp(f, inputs, v) + self.check_results(results_omitting_v, results_with_v) + + +@utils.place(config.DEVICES) +@utils.parameterize((utils.TEST_CASE_NAME, 'func', 'xs'), ( + ('1d_in_1d_out', utils.square, np.array([2., 3.])), + ('3d_in_3d_out', utils.square, np.random.rand(2, 3, 4)), + ('single_in_single_out', utils.square, np.random.rand(2, 3)), + ('multi_in_single_out', paddle.matmul, + (np.random.rand(2, 2), np.random.rand(2, 2))), )) +class TestJacobianClassNoBatch(unittest.TestCase): + def setUp(self): + self._dtype = self.xs[0].dtype if isinstance( + self.xs, typing.Sequence) else self.xs.dtype + self._eps = config.TOLERANCE.get(str(self._dtype)).get( + "first_order_grad").get("eps") + self._rtol = config.TOLERANCE.get(str(self._dtype)).get( + "first_order_grad").get("rtol") + self._atol = config.TOLERANCE.get(str(self._dtype)).get( + "first_order_grad").get("atol") + + self.xs = [paddle.to_tensor(x) for x in self.xs] if isinstance( + self.xs, typing.Sequence) else paddle.to_tensor(self.xs) + self._actual = paddle.autograd.Jacobian(self.func, self.xs, False) + self._expected = self._expected() + + def test_jacobian(self): + Index = collections.namedtuple('Index', ('type', 'value')) + indexes = (Index('all', (slice(0, None, None), slice(0, None, None))), + Index('row', (0, slice(0, None, None))), + Index('col', (slice(0, None, None), 0)), + Index('multi-row', (slice(0, 2, 1), slice(0, None, None)))) + self.assertEqual(self._actual[:].numpy().dtype, self._expected.dtype) + for index in indexes: + np.testing.assert_allclose( + self._actual.__getitem__(index.value), + self._expected.__getitem__(index.value), + rtol=self._rtol, + atol=self._atol, + err_msg=f'Testcase {index.type} index not passed, value is {index.value}' + ) + + def _expected(self): + jac = utils._compute_numerical_jacobian(self.func, self.xs, self._eps, + self._dtype) + return utils._np_concat_matrix_sequence(jac, utils.MatrixFormat.NM) + + +@utils.place(config.DEVICES) +@utils.parameterize((utils.TEST_CASE_NAME, 'func', 'xs'), ( + ('1d_in_1d_out', utils.square, np.array([[1., 2., 3.], [3., 4., 3.]])), + ('3d_in_3d_out', utils.square, np.random.rand(2, 3, 4)), + ('multi_in_single_out', utils.square, np.random.rand(2, 3)), )) +class TestJacobianClassBatchFirst(unittest.TestCase): + def setUp(self): + self._dtype = self.xs[0].dtype if isinstance( + self.xs, typing.Sequence) else self.xs.dtype + self._eps = config.TOLERANCE.get(str(self._dtype)).get( + "first_order_grad").get("eps") + self._rtol = config.TOLERANCE.get(str(self._dtype)).get( + "first_order_grad").get("rtol") + self._atol = config.TOLERANCE.get(str(self._dtype)).get( + "first_order_grad").get("atol") + + self.xs = [paddle.to_tensor(x) for x in self.xs] if isinstance( + self.xs, typing.Sequence) else 
paddle.to_tensor(self.xs) + self._actual = paddle.autograd.Jacobian(self.func, self.xs, True) + self._expected = self._expected() + + def test_jacobian(self): + Index = collections.namedtuple('Index', ('type', 'value')) + indexes = ( + Index('all', (slice(0, None, None), slice(0, None, None), + slice(0, None, None))), + Index('row', (slice(0, None, None), 0, slice(0, None, None))), + Index('col', + (slice(0, None, None), slice(0, None, None), 0)), Index( + 'batch', (slice(0, 2, None), slice(0, None, None), + slice(0, None, None))), + Index('multi_row', + (slice(0, 1, None), slice(0, 2, 1), slice(0, None, None)))) + self.assertEqual(self._actual[:].numpy().dtype, self._expected.dtype) + for index in indexes: + np.testing.assert_allclose( + self._actual.__getitem__(index.value), + self._expected.__getitem__(index.value), + rtol=self._rtol, + atol=self._atol, + err_msg=f'Testcase {index.type} index not passed, value is {index.value}' + ) + + def _expected(self): + jac = utils._compute_numerical_batch_jacobian( + self.func, self.xs, self._eps, self._dtype, False) + jac = utils._np_concat_matrix_sequence(jac, utils.MatrixFormat.NBM) + return utils._np_transpose_matrix_format(jac, utils.MatrixFormat.NBM, + utils.MatrixFormat.BNM) + + +class TestHessianClassNoBatch(unittest.TestCase): + @classmethod + def setUpClass(self): + self.shape = (2, 2) + self.dtype = 'float32' + self.np_dtype = np.float32 + self.numerical_delta = config.TOLERANCE.get(self.dtype).get( + "second_order_grad").get("eps") + self.rtol = config.TOLERANCE.get(self.dtype).get( + "second_order_grad").get("rtol") + self.atol = config.TOLERANCE.get(self.dtype).get( + "second_order_grad").get("atol") + self.x = paddle.rand(shape=self.shape, dtype=self.dtype) + self.y = paddle.rand(shape=self.shape, dtype=self.dtype) + + def test_single_input(self): + def func(x): + return paddle.sum(paddle.matmul(x, x)) + + numerical_hessian = utils._compute_numerical_hessian( + func, self.x, self.numerical_delta, self.np_dtype) + numerical_hessian = utils._np_concat_matrix_sequence(numerical_hessian) + + self.x.stop_gradient = False + hessian = paddle.autograd.Hessian(func, self.x) + np.testing.assert_allclose(hessian[:].numpy(), numerical_hessian, + self.rtol, self.atol) + + def test_multi_input(self): + def func(x, y): + return paddle.sum(paddle.matmul(x, y)) + + numerical_hessian = utils._compute_numerical_hessian( + func, [self.x, self.y], self.numerical_delta, self.np_dtype) + numerical_hessian = utils._np_concat_matrix_sequence(numerical_hessian) + self.x.stop_gradient = False + self.y.stop_gradient = False + hessian = paddle.autograd.Hessian(func, [self.x, self.y]) + np.testing.assert_allclose( + hessian[:].numpy(), + numerical_hessian, + rtol=self.rtol, + atol=self.atol) + + def test_allow_unused_true(self): + def func(x, y): + return paddle.sum(paddle.matmul(x, x)) + + numerical_hessian = utils._compute_numerical_hessian( + func, [self.x, self.y], self.numerical_delta, self.np_dtype) + numerical_hessian = utils._np_concat_matrix_sequence(numerical_hessian) + self.x.stop_gradient = False + self.y.stop_gradient = False + hessian = paddle.autograd.Hessian(func, [self.x, self.y]) + np.testing.assert_allclose(hessian[:].numpy(), numerical_hessian, + self.rtol, self.atol) + + def test_create_graph_true(self): + def func(x): + return paddle.sum(F.sigmoid(x)) + + numerical_hessian = utils._compute_numerical_hessian( + func, self.x, self.numerical_delta, self.np_dtype) + numerical_hessian = utils._np_concat_matrix_sequence(numerical_hessian) + 
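        # Keeping x differentiable should leave the resulting Hessian on the
+        # autograd graph, which the stop_gradient assertion below relies on.
+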
self.x.stop_gradient = False + hessian = paddle.autograd.Hessian(func, self.x) + assert hessian[:].stop_gradient == False + np.testing.assert_allclose(hessian[:].numpy(), numerical_hessian, + self.rtol, self.atol) + + def test_out_not_single(self): + def func(x): + return x * x + + with self.assertRaises(RuntimeError): + paddle.autograd.Hessian(func, paddle.ones([3])) + + +class TestHessianClassBatchFirst(unittest.TestCase): + @classmethod + def setUpClass(self): + self.x_shape = (5, 2) + self.weight_shape = (2, 4) + self.y_shape = (5, 2) + self.nbatch, self.nrow = 5, 2 + self.dtype = 'float32' + self.np_dtype = np.float32 + self.numerical_delta = config.TOLERANCE.get(self.dtype).get( + 'second_order_grad').get('eps') + self.rtol = config.TOLERANCE.get(self.dtype).get( + 'second_order_grad').get('rtol') + self.atol = config.TOLERANCE.get(self.dtype).get( + 'second_order_grad').get('atol') + self.x = paddle.rand(shape=self.x_shape, dtype=self.dtype) + self.weight = paddle.rand(shape=self.weight_shape, dtype=self.dtype) + self.y = paddle.rand(shape=self.y_shape, dtype=self.dtype) + + def test_single_input(self): + def func(x): + return paddle.matmul(x * x, self.weight)[:, 0:1] + + expected = utils._compute_numerical_batch_hessian( + func, self.x, self.numerical_delta, self.np_dtype) + + H = paddle.autograd.Hessian(func, self.x, is_batched=True) + actual = utils._np_transpose_matrix_format( + H[:].numpy(), utils.MatrixFormat.BNM, utils.MatrixFormat.NBM) + actual = actual.reshape((H.shape[1], -1)) + + np.testing.assert_allclose(actual, expected, self.rtol, self.atol) + + def test_multi_input(self): + def func(x, y): + return paddle.matmul(x * x * y * y, self.weight)[:, 0:1] + + xs_len = 2 + expected = utils._compute_numerical_batch_hessian( + func, [self.x, self.y], self.numerical_delta, self.np_dtype) + expected = np.reshape( + np.array(expected), + (xs_len, xs_len, self.nrow, self.nbatch, self.nrow)) + expected = [[n for n in row] for row in expected] + expected = utils._np_concat_matrix_sequence(expected) + + self.x.stop_gradient = False + self.y.stop_gradient = False + H = paddle.autograd.Hessian(func, [self.x, self.y], is_batched=True) + actual = utils._np_transpose_matrix_format( + H[:].numpy(), utils.MatrixFormat.BNM, utils.MatrixFormat.NBM) + + np.testing.assert_allclose(actual, expected, self.rtol, self.atol) + + def test_allow_unused(self): + def func(x, y): + return paddle.matmul(x * x, self.weight)[:, 0:1] + + xs_len = 2 + expected = utils._compute_numerical_batch_hessian( + func, [self.x, self.y], self.numerical_delta, self.np_dtype) + expected = np.reshape( + np.array(expected), + (xs_len, xs_len, self.nrow, self.nbatch, self.nrow)) + expected = [[n for n in row] for row in expected] + expected = utils._np_concat_matrix_sequence(expected) + expected = utils._np_transpose_matrix_format( + expected, utils.MatrixFormat.NBM, utils.MatrixFormat.BNM) + + actual = paddle.autograd.Hessian( + func, [self.x, self.y], is_batched=True)[:] + + np.testing.assert_allclose( + actual, expected, rtol=self.rtol, atol=self.atol) + + def test_stop_gradient(self): + def func(x): + return paddle.matmul(x * x, self.weight)[:, 0:1] + + expected = utils._compute_numerical_batch_hessian( + func, self.x, self.numerical_delta, self.np_dtype) + + x = self.x.clone() + x.stop_gradient = True + H = paddle.autograd.Hessian(func, self.x, is_batched=True)[:] + actual = utils._np_transpose_matrix_format( + H[:].numpy(), utils.MatrixFormat.BNM, utils.MatrixFormat.NBM) + actual = actual.reshape((H.shape[1], -1)) + + 
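        # Compare against the finite-difference reference computed above.
+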
np.testing.assert_allclose(actual, expected, self.rtol, self.atol) + + def test_out_not_single(self): + def func(x): + return (x * x) + + with self.assertRaises(RuntimeError): + paddle.autograd.Hessian(func, paddle.ones((3, 3)), is_batched=True) + + +class TestHessian(unittest.TestCase): + @classmethod + def setUpClass(self): + self.shape = (2, 2) + self.dtype = 'float32' + self.np_dtype = np.float32 + self.numerical_delta = config.TOLERANCE.get(self.dtype).get( + "second_order_grad").get("eps") + self.rtol = config.TOLERANCE.get(self.dtype).get( + "second_order_grad").get("rtol") + self.atol = config.TOLERANCE.get(self.dtype).get( + "second_order_grad").get("atol") + self.x = paddle.rand(shape=self.shape, dtype=self.dtype) + self.y = paddle.rand(shape=self.shape, dtype=self.dtype) + + def test_single_input(self): + def func(x): + return paddle.sum(paddle.matmul(x, x)) + + numerical_hessian = _compute_numerical_hessian( + func, self.x, self.numerical_delta, self.np_dtype) + + self.x.stop_gradient = False + hessian = paddle.autograd.hessian(func, self.x) + np.testing.assert_allclose(hessian.numpy(), numerical_hessian[0][0], + self.rtol, self.atol) + + def test_multi_input(self): + def func(x, y): + return paddle.sum(paddle.matmul(x, y)) + + numerical_hessian = _compute_numerical_hessian( + func, [self.x, self.y], self.numerical_delta, self.np_dtype) + + self.x.stop_gradient = False + self.y.stop_gradient = False + hessian = paddle.autograd.hessian(func, [self.x, self.y]) + for i in range(len(hessian)): + for j in range(len(hessian[0])): + np.testing.assert_allclose(hessian[i][j].numpy(), + numerical_hessian[i][j], self.rtol, + self.atol) + + def test_allow_unused_false(self): + def func(x, y): + return paddle.sum(paddle.matmul(x, x)) + + try: + self.x.stop_gradient = False + self.y.stop_gradient = False + hessian = paddle.autograd.hessian(func, [self.x, self.y]) + except ValueError as e: + error_msg = cpt.get_exception_message(e) + assert error_msg.find("allow_unused") > 0 + + def test_allow_unused_true(self): + def func(x, y): + return paddle.sum(paddle.matmul(x, x)) + + numerical_hessian = _compute_numerical_hessian( + func, [self.x, self.y], self.numerical_delta, self.np_dtype) + self.x.stop_gradient = False + self.y.stop_gradient = False + hessian = paddle.autograd.hessian( + func, [self.x, self.y], allow_unused=True) + for i in range(len(hessian)): + for j in range(len(hessian[0])): + if i == j == 0: + np.testing.assert_allclose(hessian[i][j].numpy(), + numerical_hessian[i][j], + self.rtol, self.atol) + else: + assert hessian[i][j] is None + + def test_create_graph_false(self): + def func(x): + return paddle.sum(paddle.matmul(x, x)) + + numerical_hessian = _compute_numerical_hessian( + func, self.x, self.numerical_delta, self.np_dtype) + self.x.stop_gradient = False + hessian = paddle.autograd.hessian(func, self.x) + assert hessian.stop_gradient == True + np.testing.assert_allclose(hessian.numpy(), numerical_hessian[0][0], + self.rtol, self.atol) + try: + paddle.grad(hessian, self.x) + except RuntimeError as e: + error_msg = cpt.get_exception_message(e) + assert error_msg.find("has no gradient") > 0 + + def test_create_graph_true(self): + def func(x): + return paddle.sum(F.sigmoid(x)) + + numerical_hessian = _compute_numerical_hessian( + func, self.x, self.numerical_delta, self.np_dtype) + self.x.stop_gradient = False + hessian = paddle.autograd.hessian(func, self.x, create_graph=True) + assert hessian.stop_gradient == False + np.testing.assert_allclose(hessian.numpy(), 
numerical_hessian[0][0], + self.rtol, self.atol) + triple_grad = paddle.grad(hessian, self.x) + assert triple_grad is not None + + +class TestHessianFloat64(TestHessian): + @classmethod + def setUpClass(self): + self.shape = (2, 2) + self.dtype = 'float64' + self.np_dtype = np.float64 + self.numerical_delta = config.TOLERANCE.get(self.dtype).get( + "second_order_grad").get("eps") + self.rtol = config.TOLERANCE.get(self.dtype).get( + "second_order_grad").get("rtol") + self.atol = config.TOLERANCE.get(self.dtype).get( + "second_order_grad").get("atol") + self.x = paddle.rand(shape=self.shape, dtype=self.dtype) + self.y = paddle.rand(shape=self.shape, dtype=self.dtype) + + +class TestBatchHessian(unittest.TestCase): + @classmethod + def setUpClass(self): + self.x_shape = (5, 2) + self.weight_shape = (2, 4) + self.y_shape = (5, 2) + self.dtype = 'float32' + self.np_dtype = np.float32 + self.numerical_delta = config.TOLERANCE.get(self.dtype).get( + "second_order_grad").get("eps") + self.rtol = config.TOLERANCE.get(self.dtype).get( + "second_order_grad").get("rtol") + self.atol = config.TOLERANCE.get(self.dtype).get( + "second_order_grad").get("atol") + self.x = paddle.rand(shape=self.x_shape, dtype=self.dtype) + self.weight = paddle.rand(shape=self.weight_shape, dtype=self.dtype) + self.y = paddle.rand(shape=self.y_shape, dtype=self.dtype) + + def test_single_input(self): + def func(x): + return paddle.matmul(x * x, self.weight)[:, 0:1] + + numerical_hessian = _compute_numerical_batch_hessian( + func, self.x, self.numerical_delta, self.np_dtype) + self.x.stop_gradient = False + hessian = paddle.autograd.batch_hessian(func, self.x, create_graph=True) + np.testing.assert_allclose(hessian, numerical_hessian, self.rtol, + self.atol) + + def test_multi_input(self): + def func(x, y): + return paddle.matmul(x * x * y * y, self.weight)[:, 0:1] + + numerical_hessian = _compute_numerical_batch_hessian( + func, [self.x, self.y], self.numerical_delta, self.np_dtype) + + self.x.stop_gradient = False + self.y.stop_gradient = False + hessian = paddle.autograd.batch_hessian(func, [self.x, self.y]) + + shape_tensor = paddle.to_tensor(numerical_hessian).astype("float64") + hessian_reshape = np.reshape(hessian, (shape_tensor.shape)) + np.testing.assert_allclose(hessian_reshape, numerical_hessian, + self.rtol, self.atol) + + def test_allow_unused_false(self): + def func(x, y): + return paddle.matmul(x * x, self.weight)[:, 0:1] + + try: + self.x.stop_gradient = False + self.y.stop_gradient = False + hessian = paddle.autograd.batch_hessian(func, [self.x, self.y]) + except ValueError as e: + error_msg = cpt.get_exception_message(e) + assert error_msg.find("allow_unused") > 0 + + def test_allow_unused_true(self): + def func(x, y): + return paddle.matmul(x * x, self.weight)[:, 0:1] + + numerical_hessian = _compute_numerical_batch_hessian( + func, [self.x, self.y], self.numerical_delta, self.np_dtype) + self.x.stop_gradient = False + self.y.stop_gradient = False + hessian = paddle.autograd.batch_hessian( + func, [self.x, self.y], allow_unused=True) + + for i in range(len(hessian)): + for j in range(len(hessian[0])): + if i == j == 0: + numerical_hessian = np.stack( + (numerical_hessian[i][j], numerical_hessian[i][j + 1]), + axis=0) + np.testing.assert_allclose(hessian[i][j], numerical_hessian, + self.rtol, self.atol) + else: + assert hessian[i][j] is None + + def test_create_graph_false(self): + def func(x): + return paddle.matmul(x * x, self.weight)[:, 0:1] + + numerical_hessian = _compute_numerical_batch_hessian( + 
func, self.x, self.numerical_delta, self.np_dtype) + self.x.stop_gradient = False + hessian = paddle.autograd.batch_hessian(func, self.x) + assert hessian.stop_gradient == True + np.testing.assert_allclose(hessian.numpy(), numerical_hessian, + self.rtol, self.atol) + try: + paddle.grad(hessian, self.x) + except RuntimeError as e: + error_msg = cpt.get_exception_message(e) + assert error_msg.find("has no gradient") > 0 + + def test_create_graph_true(self): + def func(x): + return paddle.matmul(x * x, self.weight)[:, 0:1] + + numerical_hessian = _compute_numerical_batch_hessian( + func, self.x, self.numerical_delta, self.np_dtype) + self.x.stop_gradient = False + hessian = paddle.autograd.batch_hessian(func, self.x, create_graph=True) + assert hessian.stop_gradient == False + np.testing.assert_allclose(hessian.numpy(), numerical_hessian, + self.rtol, self.atol) + triple_grad = paddle.grad(hessian, self.x) + assert triple_grad is not None + + +class TestBatchHessianFloat64(TestBatchHessian): + @classmethod + def setUpClass(self): + self.x_shape = (5, 2) + self.weight_shape = (2, 4) + self.y_shape = (5, 2) + self.dtype = 'float64' + self.np_dtype = np.float64 + self.numerical_delta = config.TOLERANCE.get(self.dtype).get( + "second_order_grad").get("eps") + self.rtol = config.TOLERANCE.get(self.dtype).get( + "second_order_grad").get("rtol") + self.atol = config.TOLERANCE.get(self.dtype).get( + "second_order_grad").get("atol") + self.x = paddle.rand(shape=self.x_shape, dtype=self.dtype) + self.weight = paddle.rand(shape=self.weight_shape, dtype=self.dtype) + self.y = paddle.rand(shape=self.y_shape, dtype=self.dtype) + + +class TestVHP(unittest.TestCase): + @classmethod + def setUpClass(self): + self.shape = (2, 2) + self.dtype = 'float32' + self.np_dtype = np.float32 + self.numerical_delta = config.TOLERANCE.get(self.dtype).get( + "second_order_grad").get("eps") + self.rtol = config.TOLERANCE.get(self.dtype).get( + "second_order_grad").get("rtol") + self.atol = config.TOLERANCE.get(self.dtype).get( + "second_order_grad").get("atol") + self.x = paddle.rand(shape=self.shape, dtype=self.dtype) + self.y = paddle.rand(shape=self.shape, dtype=self.dtype) + self.vx = paddle.rand(shape=self.shape, dtype=self.dtype) + self.vy = paddle.rand(shape=self.shape, dtype=self.dtype) + + def test_single_input(self): + def func(x): + return paddle.sum(paddle.matmul(x, x)) + + numerical_func_output = func(self.x).numpy() + numerical_vhp = _compute_numerical_vhp( + func, self.x, self.vx, self.numerical_delta, self.np_dtype) + + self.x.stop_gradient = False + func_output, vhp = paddle.autograd.vhp(func, self.x, self.vx) + np.testing.assert_allclose(func_output.numpy(), numerical_func_output, + self.rtol, self.atol) + np.testing.assert_allclose(vhp[0].numpy(), numerical_vhp[0], self.rtol, + self.atol) + + def test_multi_input(self): + def func(x, y): + return paddle.sum(paddle.matmul(x, y)) + + numerical_func_output = func(self.x, self.y).numpy() + numerical_vhp = _compute_numerical_vhp( + func, [self.x, self.y], [self.vx, self.vy], self.numerical_delta, + self.np_dtype) + + self.x.stop_gradient = False + self.y.stop_gradient = False + func_output, vhp = paddle.autograd.vhp(func, [self.x, self.y], + [self.vx, self.vy]) + np.testing.assert_allclose(func_output.numpy(), numerical_func_output, + self.rtol, self.atol) + for i in range(len(vhp)): + np.testing.assert_allclose(vhp[i].numpy(), numerical_vhp[i], + self.rtol, self.atol) + + def test_v_default(self): + def func(x, y): + return paddle.sum(paddle.matmul(x, y)) + + 
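        # With v omitted, vhp is expected to default to all-ones vectors, so
+        # the numerical reference below is built from explicit ones tensors.
+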
numerical_func_output = func(self.x, self.y).numpy() + vx = paddle.ones(self.vx.shape, dtype=self.vx.dtype) + vy = paddle.ones(self.vy.shape, dtype=self.vy.dtype) + numerical_vhp = _compute_numerical_vhp(func, [self.x, self.y], + [vx, vy], self.numerical_delta, + self.np_dtype) + + self.x.stop_gradient = False + self.y.stop_gradient = False + func_output, vhp = paddle.autograd.vhp(func, [self.x, self.y]) + np.testing.assert_allclose(func_output.numpy(), numerical_func_output, + self.rtol, self.atol) + for i in range(len(vhp)): + np.testing.assert_allclose(vhp[i].numpy(), numerical_vhp[i], + self.rtol, self.atol) + + def test_allow_unused_true(self): + def func(x, y): + return paddle.sum(paddle.matmul(x, x)) + + numerical_func_output = func(self.x, self.y).numpy() + numerical_vhp = _compute_numerical_vhp( + func, [self.x, self.y], [self.vx, self.vy], self.numerical_delta, + self.np_dtype) + + self.x.stop_gradient = False + self.y.stop_gradient = False + func_output, vhp = paddle.autograd.vhp(func, [self.x, self.y], + [self.vx, self.vy]) + np.testing.assert_allclose(func_output.numpy(), numerical_func_output, + self.rtol, self.atol) + np.testing.assert_allclose(vhp[0].numpy(), numerical_vhp[0], self.rtol, + self.atol) + + def test_create_graph_true(self): + def func(x): + return paddle.sum(F.sigmoid(x)) + + numerical_func_output = func(self.x).numpy() + numerical_vhp = _compute_numerical_vhp( + func, self.x, self.vx, self.numerical_delta, self.np_dtype) + + self.x.stop_gradient = False + func_output, vhp = paddle.autograd.vhp(func, self.x, self.vx) + np.testing.assert_allclose(func_output.numpy(), numerical_func_output, + self.rtol, self.atol) + assert vhp[0].stop_gradient == False + np.testing.assert_allclose(vhp[0].numpy(), numerical_vhp[0], self.rtol, + self.atol) + triple_grad = paddle.grad(vhp, self.x) + assert triple_grad is not None + + +class TestJacobian(unittest.TestCase): + @classmethod + def setUpClass(self): + self.shape = (4, 4) + self.dtype = 'float32' + self.np_dtype = np.float32 + self.numerical_delta = 1e-4 + self.rtol = 1e-3 + self.atol = 1e-3 + self.x = paddle.rand(shape=self.shape, dtype=self.dtype) + self.y = paddle.rand(shape=self.shape, dtype=self.dtype) + + def test_single_input_and_single_output(self): + def func(x): + return paddle.matmul(x, x) + + numerical_jacobian = _compute_numerical_jacobian( + func, self.x, self.numerical_delta, self.np_dtype) + self.x.stop_gradient = False + jacobian = paddle.autograd.jacobian(func, self.x) + np.testing.assert_allclose(jacobian.numpy(), numerical_jacobian[0][0], + self.rtol, self.atol) + + def test_single_input_and_multi_output(self): + def func(x): + return paddle.matmul(x, x), x * x + + numerical_jacobian = _compute_numerical_jacobian( + func, self.x, self.numerical_delta, self.np_dtype) + self.x.stop_gradient = False + jacobian = paddle.autograd.jacobian(func, self.x) + for i in range(len(jacobian)): + np.testing.assert_allclose(jacobian[i].numpy(), + numerical_jacobian[i][0], self.rtol, + self.atol) + + def test_multi_input_and_single_output(self): + def func(x, y): + return paddle.matmul(x, y) + + numerical_jacobian = _compute_numerical_jacobian( + func, [self.x, self.y], self.numerical_delta, self.np_dtype) + self.x.stop_gradient = False + self.y.stop_gradient = False + jacobian = paddle.autograd.jacobian(func, [self.x, self.y]) + for j in range(len(jacobian)): + np.testing.assert_allclose(jacobian[j].numpy(), + numerical_jacobian[0][j], self.rtol, + self.atol) + + def test_multi_input_and_multi_output(self): + def 
func(x, y): + return paddle.matmul(x, y), x * y + + numerical_jacobian = _compute_numerical_jacobian( + func, [self.x, self.y], self.numerical_delta, self.np_dtype) + self.x.stop_gradient = False + self.y.stop_gradient = False + jacobian = paddle.autograd.jacobian(func, [self.x, self.y]) + for i in range(len(jacobian)): + for j in range(len(jacobian[0])): + np.testing.assert_allclose(jacobian[i][j].numpy(), + numerical_jacobian[i][j], self.rtol, + self.atol) + + def test_allow_unused_false(self): + def func(x, y): + return paddle.matmul(x, x) + + try: + self.x.stop_gradient = False + self.y.stop_gradient = False + jacobian = paddle.autograd.jacobian(func, [self.x, self.y]) + except ValueError as e: + error_msg = cpt.get_exception_message(e) + assert error_msg.find("allow_unused") > 0 + + def test_allow_unused_true(self): + def func(x, y): + return paddle.matmul(x, x) + + numerical_jacobian = _compute_numerical_jacobian( + func, [self.x, self.y], self.numerical_delta, self.np_dtype) + self.x.stop_gradient = False + self.y.stop_gradient = False + jacobian = paddle.autograd.jacobian( + func, [self.x, self.y], allow_unused=True) + np.testing.assert_allclose( + jacobian[0].numpy(), numerical_jacobian[0][0], self.rtol, self.atol) + assert jacobian[1] is None + + def test_create_graph_false(self): + def func(x, y): + return paddle.matmul(x, y) + + numerical_jacobian = _compute_numerical_jacobian( + func, [self.x, self.y], self.numerical_delta, self.np_dtype) + self.x.stop_gradient = False + self.y.stop_gradient = False + jacobian = paddle.autograd.jacobian(func, [self.x, self.y]) + for j in range(len(jacobian)): + assert jacobian[j].stop_gradient == True + np.testing.assert_allclose(jacobian[j].numpy(), + numerical_jacobian[0][j], self.rtol, + self.atol) + try: + paddle.grad(jacobian[0], [self.x, self.y]) + except RuntimeError as e: + error_msg = cpt.get_exception_message(e) + assert error_msg.find("has no gradient") > 0 + + def test_create_graph_true(self): + def func(x, y): + return paddle.matmul(x, y) + + numerical_jacobian = _compute_numerical_jacobian( + func, [self.x, self.y], self.numerical_delta, self.np_dtype) + self.x.stop_gradient = False + self.y.stop_gradient = False + jacobian = paddle.autograd.jacobian( + func, [self.x, self.y], create_graph=True) + for j in range(len(jacobian)): + assert jacobian[j].stop_gradient == False + np.testing.assert_allclose(jacobian[j].numpy(), + numerical_jacobian[0][j], self.rtol, + self.atol) + double_grad = paddle.grad(jacobian[0], [self.x, self.y]) + assert double_grad is not None + + +class TestJacobianFloat64(TestJacobian): + @classmethod + def setUpClass(self): + self.shape = (4, 4) + self.dtype = 'float64' + self.np_dtype = np.float64 + self.numerical_delta = 1e-7 + self.rtol = 1e-7 + self.atol = 1e-7 + self.x = paddle.rand(shape=self.shape, dtype=self.dtype) + self.y = paddle.rand(shape=self.shape, dtype=self.dtype) + + +class TestJacobianBatch(unittest.TestCase): + @classmethod + def setUpClass(self): + self.x_shape = (4, 2) + self.weight_shape = (2, 4) + self.y_shape = (4, 2) + self.dtype = 'float32' + self.np_dtype = np.float32 + self.numerical_delta = 1e-4 + self.rtol = 1e-3 + self.atol = 1e-3 + self.x = paddle.rand(shape=self.x_shape, dtype=self.dtype) + self.weight = paddle.rand(shape=self.weight_shape, dtype=self.dtype) + self.y = paddle.rand(shape=self.y_shape, dtype=self.dtype) + + def test_batch_single_input_and_batch_single_output(self): + def func(x): + return paddle.matmul(paddle.matmul(x, self.weight), self.y) + + 
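        # The finite-difference batch Jacobian serves as the reference value.
+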
numerical_jacobian = _compute_numerical_batch_jacobian(
+            func, [self.x], self.numerical_delta, self.np_dtype)
+
+        self.x.stop_gradient = False
+        batch_jacobian = paddle.autograd.batch_jacobian(
+            func,
+            self.x, )
+
+        np.testing.assert_allclose(batch_jacobian.numpy(),
+                                   numerical_jacobian[0][0], self.rtol,
+                                   self.atol)
+
+    def test_batch_single_input_and_batch_multi_output(self):
+        def func(x):
+            return paddle.matmul(paddle.matmul(x, self.weight), self.y), x * x
+
+        numerical_jacobian = _compute_numerical_batch_jacobian(
+            func, [self.x], self.numerical_delta, self.np_dtype)
+
+        self.x.stop_gradient = False
+        batch_jacobian = paddle.autograd.batch_jacobian(
+            func,
+            self.x, )
+
+        for i in range(len(batch_jacobian)):
+            np.testing.assert_allclose(batch_jacobian[i].numpy(),
+                                       numerical_jacobian[i][0], self.rtol,
+                                       self.atol)
+
+    def test_batch_multi_input_and_batch_single_output(self):
+        def func(x, y):
+            return x * y
+
+        numerical_jacobian = _compute_numerical_batch_jacobian(
+            func, [self.x, self.y], self.numerical_delta, self.np_dtype)
+
+        self.x.stop_gradient = False
+        self.y.stop_gradient = False
+        batch_jacobian = paddle.autograd.batch_jacobian(func, [self.x, self.y])
+
+        for j in range(len(batch_jacobian)):
+            np.testing.assert_allclose(batch_jacobian[j].numpy(),
+                                       numerical_jacobian[0][j], self.rtol,
+                                       self.atol)
+
+    def test_batch_multi_input_and_batch_multi_output(self):
+        def func(x, y):
+            return x * y, x * y
+
+        numerical_jacobian = _compute_numerical_batch_jacobian(
+            func, [self.x, self.y], self.numerical_delta, self.np_dtype)
+
+        self.x.stop_gradient = False
+        self.y.stop_gradient = False
+        batch_jacobian = paddle.autograd.batch_jacobian(func, [self.x, self.y])
+
+        for i in range(len(batch_jacobian)):
+            np.testing.assert_allclose(batch_jacobian[i], numerical_jacobian[i],
+                                       self.rtol, self.atol)
+
+    def test_allow_unused_false(self):
+        def func(x, y):
+            return x * x
+
+        try:
+            self.x.stop_gradient = False
+            self.y.stop_gradient = False
+            jacobian = paddle.autograd.batch_jacobian(func, [self.x, self.y])
+        except ValueError as e:
+            error_msg = cpt.get_exception_message(e)
+            assert error_msg.find("allow_unused") > 0
+
+    def test_allow_unused_true(self):
+        def func(x, y):
+            return x * x
+
+        numerical_jacobian = _compute_numerical_batch_jacobian(
+            func, [self.x, self.y], self.numerical_delta, self.np_dtype)
+        self.x.stop_gradient = False
+        self.y.stop_gradient = False
+        jacobian = paddle.autograd.batch_jacobian(
+            func, [self.x, self.y], allow_unused=True)
+
+        np.testing.assert_allclose(
+            jacobian[0].numpy(), numerical_jacobian[0][0], self.rtol, self.atol)
+        assert jacobian[1] is None
+
+    def test_create_graph_false(self):
+        def func(x, y):
+            return x * y
+
+        numerical_jacobian = _compute_numerical_batch_jacobian(
+            func, [self.x, self.y], self.numerical_delta, self.np_dtype)
+        self.x.stop_gradient = False
+        self.y.stop_gradient = False
+        jacobian = paddle.autograd.batch_jacobian(func, [self.x, self.y])
+        for j in range(len(jacobian)):
+            assert jacobian[j].stop_gradient == True
+            np.testing.assert_allclose(jacobian[j].numpy(),
+                                       numerical_jacobian[0][j], self.rtol,
+                                       self.atol)
+        try:
+            paddle.grad(jacobian[0], [self.x, self.y])
+        except RuntimeError as e:
+            error_msg = cpt.get_exception_message(e)
+            assert error_msg.find("has no gradient") > 0
+
+    def test_create_graph_true(self):
+        def func(x, y):
+            return x * y
+
+        numerical_jacobian = _compute_numerical_batch_jacobian(
+            func, [self.x, self.y], self.numerical_delta, self.np_dtype)
+        self.x.stop_gradient = False
+
self.y.stop_gradient = False + jacobian = paddle.autograd.batch_jacobian( + func, [self.x, self.y], create_graph=True) + for j in range(len(jacobian)): + assert jacobian[j].stop_gradient == False + np.testing.assert_allclose(jacobian[j].numpy(), + numerical_jacobian[0][j], self.rtol, + self.atol) + double_grad = paddle.grad(jacobian[0], [self.x, self.y]) + assert double_grad is not None + + +class TestJacobianBatchFloat64(TestJacobianBatch): + @classmethod + def setUpClass(self): + self.x_shape = (12, 2) + self.weight_shape = (2, 12) + self.y_shape = (12, 2) + self.dtype = 'float64' + self.np_dtype = np.float64 + self.numerical_delta = config.TOLERANCE.get(self.dtype).get( + 'second_order_grad').get('eps') + self.rtol = config.TOLERANCE.get(self.dtype).get( + 'second_order_grad').get('rtol') + self.atol = config.TOLERANCE.get(self.dtype).get( + 'second_order_grad').get('atol') + self.x = paddle.rand(shape=self.x_shape, dtype=self.dtype) + self.weight = paddle.rand(shape=self.weight_shape, dtype=self.dtype) + self.y = paddle.rand(shape=self.y_shape, dtype=self.dtype) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/autograd/test_autograd_functional_static.py b/python/paddle/fluid/tests/unittests/autograd/test_autograd_functional_static.py new file mode 100644 index 0000000000000..8801664fdca9a --- /dev/null +++ b/python/paddle/fluid/tests/unittests/autograd/test_autograd_functional_static.py @@ -0,0 +1,455 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
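+
+# Tests for the functional autograd API (vjp, Jacobian, Hessian) under the
+# static graph mode.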
+
+import typing
+import unittest
+
+import numpy as np
+import paddle
+import paddle.fluid as fluid
+
+import config
+import utils
+from utils import (_compute_numerical_batch_jacobian,
+                   _compute_numerical_jacobian)
+from paddle.autograd.functional import _as_tensors
+
+paddle.enable_static()
+
+
+@utils.place(config.DEVICES)
+@utils.parameterize((utils.TEST_CASE_NAME, 'fun', 'xs', 'v', 'stop_gradient'), (
+    ('tensor_input', utils.reduce, np.random.rand(2, 3), None, False),
+    ('tensor_sequence_input', utils.reduce, np.random.rand(2, 3), None, False),
+    ('v_not_none', utils.reduce, np.random.rand(2, 3), np.random.rand(1),
+     False),
+    ('xs_stop_gradient', utils.reduce, np.random.rand(2, 3), np.random.rand(1),
+     True),
+    ('func_matmul', utils.matmul, (np.random.rand(3, 2), np.random.rand(2, 3)),
+     None, False),
+    ('func_mul', utils.mul, (np.random.rand(3, 3), np.random.rand(3, 3)), None,
+     False),
+    ('func_out_two', utils.o2, (np.random.rand(10), np.random.rand(10)), None,
+     False), ))
+class TestVJP(unittest.TestCase):
+    def setUp(self):
+        self.dtype = str(self.xs[0].dtype) if isinstance(
+            self.xs, typing.Sequence) else str(self.xs.dtype)
+        self._rtol = config.TOLERANCE.get(str(self.dtype)).get(
+            "first_order_grad").get("rtol")
+        self._atol = config.TOLERANCE.get(str(self.dtype)).get(
+            "first_order_grad").get("atol")
+
+    def _vjp(self):
+        exe = paddle.static.Executor()
+        sp = paddle.static.Program()
+        mp = paddle.static.Program()
+        with paddle.static.program_guard(mp, sp):
+            feed, static_xs, static_v = gen_static_data_and_feed(
+                self.xs, self.v, stop_gradient=self.stop_gradient)
+            ys, xs_grads = paddle.autograd.vjp(self.fun, static_xs, static_v)
+        exe.run(sp)
+        return exe.run(mp, feed=feed, fetch_list=[ys, xs_grads])
+
+    def _expected_vjp(self):
+        exe = paddle.static.Executor()
+        sp = paddle.static.Program()
+        mp = paddle.static.Program()
+        with paddle.static.program_guard(mp, sp):
+            feed, static_xs, static_v = gen_static_data_and_feed(self.xs,
+                                                                 self.v, False)
+            ys = self.fun(*static_xs) if isinstance(
+                static_xs, typing.Sequence) else self.fun(static_xs)
+            xs_grads = paddle.static.gradients(ys, static_xs, static_v)
+        exe.run(sp)
+        return exe.run(mp, feed=feed, fetch_list=[ys, xs_grads])
+
+    def test_vjp(self):
+        actual = self._vjp()
+        expected = self._expected_vjp()
+        self.assertEqual(len(actual), len(expected))
+        for i in range(len(actual)):
+            np.testing.assert_allclose(
+                actual[i], expected[i], rtol=self._rtol, atol=self._atol)
+
+
+@utils.place(config.DEVICES)
+@utils.parameterize(
+    (utils.TEST_CASE_NAME, 'fun', 'xs', 'v', 'expected_exception'), (
+        ('v_shape_not_equal_ys', utils.square, np.random.rand(3),
+         np.random.rand(1), RuntimeError), ))
+class TestVJPException(unittest.TestCase):
+    def setUp(self):
+        self.exe = paddle.static.Executor()
+
+    def _vjp(self):
+        sp = paddle.static.Program()
+        mp = paddle.static.Program()
+        with paddle.static.program_guard(mp, sp):
+            feed, static_xs, static_v = gen_static_data_and_feed(self.xs,
+                                                                 self.v)
+            ys, xs_grads = paddle.autograd.vjp(self.fun, static_xs, static_v)
+        self.exe.run(sp)
+        return self.exe.run(mp, feed, fetch_list=[ys, xs_grads])
+
+    def test_vjp(self):
+        with self.assertRaises(self.expected_exception):
+            self._vjp()
+
+
+def gen_static_data_and_feed(xs, v, stop_gradient=True):
+    feed = {}
+    if isinstance(xs, typing.Sequence):
+        static_xs = []
+        for i, x in enumerate(xs):
+            x = paddle.static.data(f"x{i}", x.shape, x.dtype)
+            x.stop_gradient = stop_gradient
+            static_xs.append(x)
+        feed.update({f'x{idx}': value for idx, value in enumerate(xs)})
+    else:
+        static_xs = paddle.static.data('x', xs.shape, xs.dtype)
+        static_xs.stop_gradient = stop_gradient
+        feed.update({'x': xs})
+
+    if isinstance(v, typing.Sequence):
+        static_v = []
+        for i, e in enumerate(v):
+            e = paddle.static.data(f'v{i}', e.shape, e.dtype)
+            e.stop_gradient = stop_gradient
+            static_v.append(e)
+        feed.update({f'v{idx}': value for idx, value in enumerate(v)})
+    elif v is not None:
+        static_v = paddle.static.data('v', v.shape, v.dtype)
+        static_v.stop_gradient = stop_gradient
+        feed.update({'v': v})
+    else:
+        static_v = v
+
+    return feed, static_xs, static_v
+
+
+def approx_jacobian(f, xs, dtype, eps=1e-5, batch=False):
+    r"""Computes an approximate Jacobian matrix of a multi-valued function
+    using finite differences.
+
+    The function input is required to be an np array or a list of np
+    arrays.
+    """
+
+    def flatten(x):
+        if len(x.shape) > 0:
+            to = [x.shape[0], -1] if batch else [-1]
+            return x.reshape(to)
+        else:
+            return x
+
+    def flatten_all(xs):
+        if isinstance(xs, list):
+            flattened = np.concatenate([flatten(x) for x in xs], axis=-1)
+        else:
+            flattened = flatten(xs)
+        return flattened
+
+    def x_like(x, orig_x):
+        return x.reshape(orig_x.shape)
+
+    def _f(x):
+        if multi_inps:
+            _xs = np.split(x, splits, axis=-1)
+            _xs = [x_like(_x, _o) for _x, _o in zip(_xs, xs)]
+            outs = f(_xs)
+        else:
+            outs = f(x)
+        return flatten_all(outs)
+
+    multi_inps = False if isinstance(xs, np.ndarray) else True
+    x = flatten_all(xs)
+    xdim = x.shape[-1]
+    splits = []
+
+    if multi_inps:
+        split = 0
+        for inp in xs:
+            split += flatten(inp).shape[-1]
+            splits.append(split)
+
+    ds = eps * np.eye(xdim, dtype=dtype)
+
+    fprimes_by_x = [(0.5 * (_f(x + d) - _f(x - d)) / eps) for d in ds]
+    fprimes_by_y = np.stack(fprimes_by_x, axis=-1)
+    return np.transpose(fprimes_by_y, [1, 0, 2]) if batch else fprimes_by_y
+
+
+def make_tensors(inps):
+    if isinstance(inps, list):
+        xs = [
+            paddle.static.data(
+                f'x{i}', inp.shape, dtype=inp.dtype)
+            for i, inp in enumerate(inps)
+        ]
+    else:
+        xs = paddle.static.data(name='x', shape=inps.shape, dtype=inps.dtype)
+    return xs
+
+
+all_data_shapes = {
+    'A': [[1., 2.]],
+    'B': [[1., 2.], [2., 1.]],
+    'C': [[2., 2.], [2., 1.]],
+    'D': [[[2., 2.], [2., 1.]], [[1., 2.], [2., 1.]]],
+    'E': [[[3., 4.], [2., 3.]], [[2., 1.], [1., 3.]]],
+}
+
+
+def prepare_data(test, input_shapes, dtype):
+    for name, shape in input_shapes.items():
+        setattr(test, name, np.array(shape, dtype=dtype))
+
+
+class TestJacobianFloat32(unittest.TestCase):
+    @classmethod
+    def setUpClass(self):
+        paddle.enable_static()
+        if fluid.core.is_compiled_with_cuda():
+            self.place = fluid.CUDAPlace(0)
+        else:
+            self.place = fluid.CPUPlace()
+        self.dtype = 'float32'
+        self.np_dtype = np.float32
+        prepare_data(self, all_data_shapes, self.dtype)
+        self.eps = config.TOLERANCE.get(self.dtype).get('first_order_grad').get(
+            'eps')
+        # self.rtol = config.TOLERANCE.get(self.dtype).get('first_order_grad').get('rtol')
+        # self.atol = config.TOLERANCE.get(self.dtype).get('first_order_grad').get('atol')
+        # Don't use the tolerances from config; they are too tight for float32
+        # and cause this test case to fail.
+ self.rtol = 1e-2 + self.atol = 1e-2 + + def run_test_by_fullmatrix(self, pd_f, np_f, inps, batch=False): + main = fluid.Program() + startup = fluid.Program() + with fluid.program_guard(main, startup): + xs = make_tensors(inps) + JJ = paddle.autograd.functional.Jacobian(pd_f, xs, is_batched=batch) + if batch: + _, nrow, ncol = JJ.shape + else: + nrow, ncol = JJ.shape + full_jacobian = JJ[:] + exe = fluid.Executor(self.place) + exe.run(startup) + if isinstance(inps, list): + feeds = {f'x{i}': x for i, x in enumerate(inps)} + else: + feeds = {'x': inps} + pd_jacobians = exe.run(main, feed=feeds, fetch_list=[full_jacobian])[0] + np_jacobians = approx_jacobian( + np_f, inps, self.dtype, self.eps, batch=batch) + if batch: + np_jacobians = utils._np_transpose_matrix_format( + np_jacobians, utils.MatrixFormat.NBM, utils.MatrixFormat.BNM) + + np.testing.assert_allclose(pd_jacobians, np_jacobians, self.rtol, + self.atol) + + def run_test_by_rows(self, pd_f, np_f, inps, batch=False): + main = fluid.Program() + startup = fluid.Program() + with fluid.program_guard(main, startup): + xs = make_tensors(inps) + JJ = paddle.autograd.functional.Jacobian(pd_f, xs, is_batched=batch) + if batch: + nbatch, nrow, ncol = JJ.shape + rows = [JJ[:, i, :] for i in range(nrow)] + else: + nrow, ncol = JJ.shape + rows = [JJ[i, :] for i in range(nrow)] + + exe = fluid.Executor(self.place) + exe.run(startup) + if isinstance(inps, list): + feeds = {f'x{i}': x for i, x in enumerate(inps)} + else: + feeds = {'x': inps} + pd_jac = exe.run(main, feed=feeds, fetch_list=[rows]) + np_jac = approx_jacobian(np_f, inps, self.dtype, self.eps, batch=batch) + for i in range(nrow): + np.testing.assert_allclose(pd_jac[i], np_jac[i], self.rtol, + self.atol) + + def run_test_by_entries(self, pd_f, np_f, inps, batch=False): + main = fluid.Program() + startup = fluid.Program() + with fluid.program_guard(main, startup): + xs = make_tensors(inps) + JJ = paddle.autograd.functional.Jacobian(pd_f, xs, is_batched=batch) + if batch: + nbatch, nrow, ncol = JJ.shape + entries = [ + JJ[:, i, j] for i in range(nrow) for j in range(ncol) + ] + else: + nrow, ncol = JJ.shape + entries = [JJ[i, j] for i in range(nrow) for j in range(ncol)] + exe = fluid.Executor(self.place) + exe.run(startup) + if isinstance(inps, list): + feeds = {f'x{i}': x for i, x in enumerate(inps)} + else: + feeds = {'x': inps} + pd_entries = exe.run(main, feed=feeds, fetch_list=[entries]) + np_jac = approx_jacobian(np_f, inps, self.dtype, self.eps, batch=batch) + np_entries = [ + np_jac[i, ..., j] for i in range(nrow) for j in range(ncol) + ] + for pd_entry, np_entry in zip(pd_entries, np_entries): + np.testing.assert_allclose(pd_entry, np_entry, self.rtol, self.atol) + + def test_square(self): + def pd_f(x): + return paddle.multiply(x, x) + + def np_f(x): + return np.multiply(x, x) + + self.run_test_by_fullmatrix(pd_f, np_f, self.A) + self.run_test_by_rows(pd_f, np_f, self.A) + self.run_test_by_entries(pd_f, np_f, self.A) + + def test_mul(self): + def pd_f(x, y): + return paddle.multiply(x, y) + + def np_f(xs): + x, y = xs + return np.multiply(x, y) + + self.run_test_by_fullmatrix( + pd_f, + np_f, + [self.B, self.C], ) + self.run_test_by_rows(pd_f, np_f, [self.B, self.C]) + self.run_test_by_entries(pd_f, np_f, [self.B, self.C]) + + def test_matmul(self): + def pd_f(x, y): + return paddle.matmul(x, y) + + def np_f(xs): + x, y = xs + return np.matmul(x, y) + + self.run_test_by_fullmatrix(pd_f, np_f, [self.B, self.C]) + self.run_test_by_rows(pd_f, np_f, [self.B, self.C]) + 
self.run_test_by_entries(pd_f, np_f, [self.B, self.C]) + + def test_batch_matmul(self): + def pd_f(x, y): + return paddle.matmul(x, y) + + def np_f(xs): + x, y = xs + return np.matmul(x, y) + + self.run_test_by_fullmatrix(pd_f, np_f, [self.D, self.E], batch=True) + self.run_test_by_rows(pd_f, np_f, [self.D, self.E], batch=True) + self.run_test_by_entries(pd_f, np_f, [self.D, self.E], batch=True) + + +class TestJacobianFloat64(TestJacobianFloat32): + @classmethod + def setUpClass(self): + paddle.enable_static() + if fluid.core.is_compiled_with_cuda(): + self.place = fluid.CUDAPlace(0) + else: + self.place = fluid.CPUPlace() + self.dtype = 'float64' + prepare_data(self, all_data_shapes, self.dtype) + self.eps = config.TOLERANCE.get(self.dtype).get('first_order_grad').get( + 'eps') + self.rtol = config.TOLERANCE.get(self.dtype).get( + 'first_order_grad').get('rtol') + self.atol = config.TOLERANCE.get(self.dtype).get( + 'first_order_grad').get('atol') + + +class TestHessianFloat32(unittest.TestCase): + @classmethod + def setUpClass(self): + paddle.enable_static() + if fluid.core.is_compiled_with_cuda(): + self.place = fluid.CUDAPlace(0) + else: + self.place = fluid.CPUPlace() + self.dtype = 'float32' + prepare_data(self, all_data_shapes, self.dtype) + self.eps = config.TOLERANCE.get(self.dtype).get( + 'second_order_grad').get('eps') + self.rtol = config.TOLERANCE.get(self.dtype).get( + 'second_order_grad').get('rtol') + self.atol = config.TOLERANCE.get(self.dtype).get( + 'second_order_grad').get('atol') + + def run_test_by_fullmatrix(self, pd_f, inps, np_hess, batch=False): + main = fluid.Program() + startup = fluid.Program() + with fluid.program_guard(main, startup): + xs = make_tensors(inps) + HH = paddle.autograd.functional.Hessian(pd_f, xs, is_batched=batch) + nrow, ncol = HH.shape + full_hessian = HH[:] + exe = fluid.Executor(self.place) + exe.run(startup) + if isinstance(inps, list): + feeds = {f'x{i}': x for i, x in enumerate(inps)} + else: + feeds = {'x': inps} + pd_hess = exe.run(main, feed=feeds, fetch_list=[full_hessian])[0] + np.testing.assert_allclose(pd_hess, np_hess, self.rtol, self.atol) + + def test_square(self): + def pd_f(x): + """Input is a square matrix.""" + return paddle.matmul(x, x.T).flatten().sum() + + def np_hess(x): + dim = x.shape[0] + upperleft = 2 * np.eye(dim, dtype=self.dtype) + upper = np.concatenate((upperleft, upperleft)) + return np.concatenate((upper, upper), axis=1) + + self.run_test_by_fullmatrix(pd_f, self.B, np_hess(self.B)) + + +class TestHessianFloat64(TestHessianFloat32): + @classmethod + def setUpClass(self): + paddle.enable_static() + if fluid.core.is_compiled_with_cuda(): + self.place = fluid.CUDAPlace(0) + else: + self.place = fluid.CPUPlace() + self.dtype = 'float64' + prepare_data(self, all_data_shapes, self.dtype) + self.eps = config.TOLERANCE.get(self.dtype).get( + 'second_order_grad').get('eps') + self.rtol = config.TOLERANCE.get(self.dtype).get( + 'second_order_grad').get('rtol') + self.atol = config.TOLERANCE.get(self.dtype).get( + 'second_order_grad').get('atol') + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/autograd/test_autograd_static.py b/python/paddle/fluid/tests/unittests/autograd/test_autograd_static.py deleted file mode 100644 index 60dc9d06b8a7f..0000000000000 --- a/python/paddle/fluid/tests/unittests/autograd/test_autograd_static.py +++ /dev/null @@ -1,308 +0,0 @@ -# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import unittest -import numpy as np -import paddle -import paddle.fluid as fluid -from utils import _compute_numerical_jacobian, _compute_numerical_batch_jacobian - - -def approx_jacobian(f, xs, dtype, eps=1e-5, batch=False): - r"""Computes an approximate Jacobian matrix of a multi-valued function - using finite differences. - - The function input is required to be an np array or a list of list of np - arrays. - """ - - def flatten(x): - if len(x.shape) > 0: - to = [x.shape[0], -1] if batch else [-1] - return x.reshape(to) - else: - return x - - def flatten_all(xs): - if isinstance(xs, list): - flattened = np.concatenate([flatten(x) for x in xs], axis=-1) - else: - flattened = flatten(xs) - return flattened - - def x_like(x, orig_x): - return x.reshape(orig_x.shape) - - def _f(x): - if multi_inps: - _xs = np.split(x, splits, axis=-1) - _xs = [x_like(_x, _o) for _x, _o in zip(_xs, xs)] - outs = f(_xs) - else: - outs = f(x) - return flatten_all(outs) - - multi_inps = False if isinstance(xs, np.ndarray) else True - x = flatten_all(xs) - xdim = x.shape[-1] - splits = [] - - if multi_inps: - split = 0 - for inp in xs: - split += flatten(inp).shape[-1] - splits.append(split) - - ds = eps * np.eye(xdim, dtype=dtype) - - fprimes_by_x = [(0.5 * (_f(x + d) - _f(x - d)) / eps) for d in ds] - fprimes_by_y = np.stack(fprimes_by_x, axis=-1) - return np.transpose(fprimes_by_y, [1, 0, 2]) if batch else fprimes_by_y - - -def make_tensors(inps): - if isinstance(inps, list): - xs = [ - paddle.static.data( - f'x{i}', inp.shape, dtype=inp.dtype) - for i, inp in enumerate(inps) - ] - else: - xs = paddle.static.data(name='x', shape=inps.shape, dtype=inps.dtype) - return xs - - -all_data_shapes = { - 'A': [[1., 2.]], - 'B': [[1., 2.], [2., 1.]], - 'C': [[2., 2.], [2., 1.]], - 'D': [[[2., 2.], [2., 1.]], [[1., 2.], [2., 1.]]], - 'E': [[[3., 4.], [2., 3.]], [[2., 1.], [1., 3.]]], -} - - -def prepare_data(test, input_shapes, dtype): - for name, shape in input_shapes.items(): - setattr(test, name, np.array(shape, dtype=dtype)) - - -class TestJacobianFloat32(unittest.TestCase): - @classmethod - def setUpClass(self): - paddle.enable_static() - if fluid.core.is_compiled_with_cuda(): - self.place = fluid.CUDAPlace(0) - else: - self.place = fluid.CPUPlace() - self.dtype = 'float32' - prepare_data(self, all_data_shapes, self.dtype) - self.eps = 1e-4 - self.rtol = 1e-2 - self.atol = 1e-2 - - def run_test_by_fullmatrix(self, pd_f, np_f, inps, batch=False): - main = fluid.Program() - startup = fluid.Program() - with fluid.program_guard(main, startup): - xs = make_tensors(inps) - JJ = paddle.autograd.functional.Jacobian(pd_f, xs, batch=batch) - nrow, ncol = JJ.shape() - full_jacobian = JJ[:] - exe = fluid.Executor(self.place) - exe.run(startup) - if isinstance(inps, list): - feeds = {f'x{i}': x for i, x in enumerate(inps)} - else: - feeds = {'x': inps} - pd_jacobians = exe.run(main, feed=feeds, fetch_list=[full_jacobian])[0] - np_jacobians = approx_jacobian( - np_f, inps, 
self.dtype, self.eps, batch=batch) - self.assertTrue( - np.allclose(pd_jacobians, np_jacobians, self.rtol, self.atol)) - - def run_test_by_rows(self, pd_f, np_f, inps, batch=False): - main = fluid.Program() - startup = fluid.Program() - with fluid.program_guard(main, startup): - xs = make_tensors(inps) - JJ = paddle.autograd.functional.Jacobian(pd_f, xs, batch=batch) - nrow, ncol = JJ.shape() - rows = [JJ[i] for i in range(nrow)] - exe = fluid.Executor(self.place) - exe.run(startup) - if isinstance(inps, list): - feeds = {f'x{i}': x for i, x in enumerate(inps)} - else: - feeds = {'x': inps} - pd_jac = exe.run(main, feed=feeds, fetch_list=[rows]) - np_jac = approx_jacobian(np_f, inps, self.dtype, self.eps, batch=batch) - for i in range(nrow): - self.assertTrue( - np.allclose(pd_jac[i], np_jac[i], self.rtol, self.atol)) - - def run_test_by_entries(self, pd_f, np_f, inps, batch=False): - main = fluid.Program() - startup = fluid.Program() - with fluid.program_guard(main, startup): - xs = make_tensors(inps) - JJ = paddle.autograd.functional.Jacobian(pd_f, xs, batch=batch) - nrow, ncol = JJ.shape() - entries = [JJ[i, j] for i in range(nrow) for j in range(ncol)] - exe = fluid.Executor(self.place) - exe.run(startup) - if isinstance(inps, list): - feeds = {f'x{i}': x for i, x in enumerate(inps)} - else: - feeds = {'x': inps} - pd_entries = exe.run(main, feed=feeds, fetch_list=[entries]) - np_jac = approx_jacobian(np_f, inps, self.dtype, self.eps, batch=batch) - np_entries = [ - np_jac[i, ..., j] for i in range(nrow) for j in range(ncol) - ] - for pd_entry, np_entry in zip(pd_entries, np_entries): - self.assertTrue( - np.allclose(pd_entry, np_entry, self.rtol, self.atol)) - - def test_square(self): - def pd_f(x): - return paddle.multiply(x, x) - - def np_f(x): - return np.multiply(x, x) - - self.run_test_by_fullmatrix(pd_f, np_f, self.A) - self.run_test_by_rows(pd_f, np_f, self.A) - self.run_test_by_entries(pd_f, np_f, self.A) - - def test_mul(self): - def pd_f(xs): - x, y = xs - return paddle.multiply(x, y) - - def np_f(xs): - x, y = xs - return np.multiply(x, y) - - self.run_test_by_fullmatrix( - pd_f, - np_f, - [self.B, self.C], ) - self.run_test_by_rows(pd_f, np_f, [self.B, self.C]) - self.run_test_by_entries(pd_f, np_f, [self.B, self.C]) - - def test_matmul(self): - def pd_f(xs): - x, y = xs - return paddle.matmul(x, y) - - def np_f(xs): - x, y = xs - return np.matmul(x, y) - - self.run_test_by_fullmatrix(pd_f, np_f, [self.B, self.C]) - self.run_test_by_rows(pd_f, np_f, [self.B, self.C]) - self.run_test_by_entries(pd_f, np_f, [self.B, self.C]) - - def test_batch_matmul(self): - def pd_f(xs): - x, y = xs - return paddle.matmul(x, y) - - def np_f(xs): - x, y = xs - return np.matmul(x, y) - - self.run_test_by_fullmatrix(pd_f, np_f, [self.D, self.E], batch=True) - self.run_test_by_rows(pd_f, np_f, [self.D, self.E], batch=True) - self.run_test_by_entries(pd_f, np_f, [self.D, self.E], batch=True) - - -class TestJacobianFloat64(TestJacobianFloat32): - @classmethod - def setUpClass(self): - paddle.enable_static() - if fluid.core.is_compiled_with_cuda(): - self.place = fluid.CUDAPlace(0) - else: - self.place = fluid.CPUPlace() - self.dtype = 'float64' - prepare_data(self, all_data_shapes, self.dtype) - self.eps = 1e-7 - self.rtol = 1e-6 - self.atol = 1e-6 - - -class TestHessianFloat64(unittest.TestCase): - @classmethod - def setUpClass(self): - paddle.enable_static() - if fluid.core.is_compiled_with_cuda(): - self.place = fluid.CUDAPlace(0) - else: - self.place = fluid.CPUPlace() - self.dtype = 
'float64' - prepare_data(self, all_data_shapes, self.dtype) - self.eps = 1e-7 - self.rtol = 1e-6 - self.atol = 1e-6 - - def run_test_by_fullmatrix(self, pd_f, inps, np_hess, batch=False): - main = fluid.Program() - startup = fluid.Program() - with fluid.program_guard(main, startup): - xs = make_tensors(inps) - HH = paddle.autograd.functional.Hessian(pd_f, xs, batch=batch) - nrow, ncol = HH.shape() - full_hessian = HH[:] - exe = fluid.Executor(self.place) - exe.run(startup) - if isinstance(inps, list): - feeds = {f'x{i}': x for i, x in enumerate(inps)} - else: - feeds = {'x': inps} - pd_hess = exe.run(main, feed=feeds, fetch_list=[full_hessian])[0] - self.assertTrue(np.allclose(pd_hess, np_hess, self.rtol, self.atol)) - - def test_square(self): - def pd_f(x): - """Input is a square matrix.""" - return paddle.matmul(x, x.T) - - def np_hess(x): - dim = x.shape[0] - f_xx_upperleft = 2 * np.eye(dim, dtype=self.dtype) - f_xx = np.zeros([dim * dim, dim * dim], dtype=self.dtype) - f_xx[:dim, :dim] = f_xx_upperleft - return f_xx - - self.run_test_by_fullmatrix(pd_f, self.B, np_hess(self.B)) - - def test_batch_square(self): - def pd_f(x): - """Input is a square matrix.""" - return paddle.matmul(x, paddle.transpose(x, [0, 2, 1])) - - def np_hess(x): - bat, dim, _ = x.shape - f_xx_upperleft = 2 * np.eye(dim, dtype=self.dtype) - f_xx = np.zeros([bat, dim * dim, dim * dim], dtype=self.dtype) - f_xx[..., :dim, :dim] = f_xx_upperleft - return f_xx - - self.run_test_by_fullmatrix( - pd_f, self.E, np_hess(self.E), batch=True) - - -if __name__ == "__main__": - unittest.main() diff --git a/python/paddle/fluid/tests/unittests/autograd/test_hessian.py b/python/paddle/fluid/tests/unittests/autograd/test_hessian.py deleted file mode 100644 index 7b3bd9fd55932..0000000000000 --- a/python/paddle/fluid/tests/unittests/autograd/test_hessian.py +++ /dev/null @@ -1,263 +0,0 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import unittest -import numpy as np -import paddle -import paddle.compat as cpt -import paddle.nn.functional as F -from utils import _compute_numerical_hessian, _compute_numerical_batch_hessian - - -class TestHessian(unittest.TestCase): - @classmethod - def setUpClass(self): - self.shape = (2, 2) - self.dtype = 'float32' - self.np_dtype = np.float32 - self.numerical_delta = 1e-2 - self.rtol = 1e-2 - self.atol = 1e-2 - self.x = paddle.rand(shape=self.shape, dtype=self.dtype) - self.y = paddle.rand(shape=self.shape, dtype=self.dtype) - - def test_single_input(self): - def func(x): - return paddle.sum(paddle.matmul(x, x)) - - numerical_hessian = _compute_numerical_hessian( - func, self.x, self.numerical_delta, self.np_dtype) - - self.x.stop_gradient = False - hessian = paddle.autograd.hessian(func, self.x) - assert np.allclose(hessian.numpy(), numerical_hessian[0][0], self.rtol, - self.atol) - - def test_multi_input(self): - def func(x, y): - return paddle.sum(paddle.matmul(x, y)) - - numerical_hessian = _compute_numerical_hessian( - func, [self.x, self.y], self.numerical_delta, self.np_dtype) - - self.x.stop_gradient = False - self.y.stop_gradient = False - hessian = paddle.autograd.hessian(func, [self.x, self.y]) - for i in range(len(hessian)): - for j in range(len(hessian[0])): - assert np.allclose(hessian[i][j].numpy(), - numerical_hessian[i][j], self.rtol, - self.atol) - - def test_allow_unused_false(self): - def func(x, y): - return paddle.sum(paddle.matmul(x, x)) - - try: - self.x.stop_gradient = False - self.y.stop_gradient = False - hessian = paddle.autograd.hessian(func, [self.x, self.y]) - except ValueError as e: - error_msg = cpt.get_exception_message(e) - assert error_msg.find("allow_unused") > 0 - - def test_allow_unused_true(self): - def func(x, y): - return paddle.sum(paddle.matmul(x, x)) - - numerical_hessian = _compute_numerical_hessian( - func, [self.x, self.y], self.numerical_delta, self.np_dtype) - self.x.stop_gradient = False - self.y.stop_gradient = False - hessian = paddle.autograd.hessian( - func, [self.x, self.y], allow_unused=True) - for i in range(len(hessian)): - for j in range(len(hessian[0])): - if i == j == 0: - assert np.allclose(hessian[i][j].numpy(), - numerical_hessian[i][j], self.rtol, - self.atol) - else: - assert hessian[i][j] is None - - def test_create_graph_false(self): - def func(x): - return paddle.sum(paddle.matmul(x, x)) - - numerical_hessian = _compute_numerical_hessian( - func, self.x, self.numerical_delta, self.np_dtype) - self.x.stop_gradient = False - hessian = paddle.autograd.hessian(func, self.x) - assert hessian.stop_gradient == True - assert np.allclose(hessian.numpy(), numerical_hessian[0][0], self.rtol, - self.atol) - try: - paddle.grad(hessian, self.x) - except RuntimeError as e: - error_msg = cpt.get_exception_message(e) - assert error_msg.find("has no gradient") > 0 - - def test_create_graph_true(self): - def func(x): - return paddle.sum(F.sigmoid(x)) - - numerical_hessian = _compute_numerical_hessian( - func, self.x, self.numerical_delta, self.np_dtype) - self.x.stop_gradient = False - hessian = paddle.autograd.hessian(func, self.x, create_graph=True) - assert hessian.stop_gradient == False - assert np.allclose(hessian.numpy(), numerical_hessian[0][0], self.rtol, - self.atol) - triple_grad = paddle.grad(hessian, self.x) - assert triple_grad is not None - - -class TestHessianFloat64(TestHessian): - @classmethod - def setUpClass(self): - self.shape = (2, 2) - self.dtype = 'float64' - self.np_dtype = np.float64 - 
self.numerical_delta = 1e-5 - self.rtol = 1e-5 - self.atol = 1e-5 - self.x = paddle.rand(shape=self.shape, dtype=self.dtype) - self.y = paddle.rand(shape=self.shape, dtype=self.dtype) - - -class TestBatchHessian(unittest.TestCase): - @classmethod - def setUpClass(self): - self.x_shape = (5, 2) - self.weight_shape = (2, 4) - self.y_shape = (5, 2) - self.dtype = 'float32' - self.np_dtype = np.float32 - self.numerical_delta = 1e-2 - self.rtol = 1e-3 - self.atol = 1e-3 - self.x = paddle.rand(shape=self.x_shape, dtype=self.dtype) - self.weight = paddle.rand(shape=self.weight_shape, dtype=self.dtype) - self.y = paddle.rand(shape=self.y_shape, dtype=self.dtype) - - def test_single_input(self): - def func(x): - return paddle.matmul(x * x, self.weight)[:, 0:1] - - numerical_hessian = _compute_numerical_batch_hessian( - func, self.x, self.numerical_delta, self.np_dtype) - self.x.stop_gradient = False - hessian = paddle.autograd.batch_hessian(func, self.x, create_graph=True) - assert np.allclose(hessian, numerical_hessian, self.rtol, self.atol) - - def test_multi_input(self): - def func(x, y): - return paddle.matmul(x * x * y * y, self.weight)[:, 0:1] - - numerical_hessian = _compute_numerical_batch_hessian( - func, [self.x, self.y], self.numerical_delta, self.np_dtype) - - self.x.stop_gradient = False - self.y.stop_gradient = False - hessian = paddle.autograd.batch_hessian(func, [self.x, self.y]) - - shape_tensor = paddle.to_tensor(numerical_hessian).astype("float64") - hessian_reshape = np.reshape(hessian, (shape_tensor.shape)) - assert np.allclose(hessian_reshape, numerical_hessian, self.rtol, - self.atol) - - def test_allow_unused_false(self): - def func(x, y): - return paddle.matmul(x * x, self.weight)[:, 0:1] - - try: - self.x.stop_gradient = False - self.y.stop_gradient = False - hessian = paddle.autograd.batch_hessian(func, [self.x, self.y]) - except ValueError as e: - error_msg = cpt.get_exception_message(e) - assert error_msg.find("allow_unused") > 0 - - def test_allow_unused_true(self): - def func(x, y): - return paddle.matmul(x * x, self.weight)[:, 0:1] - - numerical_hessian = _compute_numerical_batch_hessian( - func, [self.x, self.y], self.numerical_delta, self.np_dtype) - self.x.stop_gradient = False - self.y.stop_gradient = False - hessian = paddle.autograd.batch_hessian( - func, [self.x, self.y], allow_unused=True) - - for i in range(len(hessian)): - for j in range(len(hessian[0])): - if i == j == 0: - numerical_hessian = np.stack( - (numerical_hessian[i][j], numerical_hessian[i][j + 1]), - axis=0) - assert np.allclose(hessian[i][j], numerical_hessian, - self.rtol, self.atol) - else: - assert hessian[i][j] is None - - def test_create_graph_false(self): - def func(x): - return paddle.matmul(x * x, self.weight)[:, 0:1] - - numerical_hessian = _compute_numerical_batch_hessian( - func, self.x, self.numerical_delta, self.np_dtype) - self.x.stop_gradient = False - hessian = paddle.autograd.batch_hessian(func, self.x) - assert hessian.stop_gradient == True - assert np.allclose(hessian.numpy(), numerical_hessian, self.rtol, - self.atol) - try: - paddle.grad(hessian, self.x) - except RuntimeError as e: - error_msg = cpt.get_exception_message(e) - assert error_msg.find("has no gradient") > 0 - - def test_create_graph_true(self): - def func(x): - return paddle.matmul(x * x, self.weight)[:, 0:1] - - numerical_hessian = _compute_numerical_batch_hessian( - func, self.x, self.numerical_delta, self.np_dtype) - self.x.stop_gradient = False - hessian = paddle.autograd.batch_hessian(func, self.x, 
create_graph=True) - assert hessian.stop_gradient == False - assert np.allclose(hessian.numpy(), numerical_hessian, self.rtol, - self.atol) - triple_grad = paddle.grad(hessian, self.x) - assert triple_grad is not None - - -class TestBatchHessianFloat64(TestBatchHessian): - @classmethod - def setUpClass(self): - self.x_shape = (5, 2) - self.weight_shape = (2, 4) - self.y_shape = (5, 2) - self.dtype = 'float64' - self.np_dtype = np.float64 - self.numerical_delta = 1e-4 - self.rtol = 1e-5 - self.atol = 1e-5 - self.x = paddle.rand(shape=self.x_shape, dtype=self.dtype) - self.weight = paddle.rand(shape=self.weight_shape, dtype=self.dtype) - self.y = paddle.rand(shape=self.y_shape, dtype=self.dtype) - - -if __name__ == "__main__": - unittest.main() diff --git a/python/paddle/fluid/tests/unittests/autograd/test_jacobian.py b/python/paddle/fluid/tests/unittests/autograd/test_jacobian.py deleted file mode 100644 index 335ea4e519bef..0000000000000 --- a/python/paddle/fluid/tests/unittests/autograd/test_jacobian.py +++ /dev/null @@ -1,319 +0,0 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import unittest -import numpy as np -import paddle -import paddle.compat as cpt -from utils import _compute_numerical_jacobian, _compute_numerical_batch_jacobian - - -class TestJacobian(unittest.TestCase): - @classmethod - def setUpClass(self): - self.shape = (4, 4) - self.dtype = 'float32' - self.np_dtype = np.float32 - self.numerical_delta = 1e-4 - self.rtol = 1e-3 - self.atol = 1e-3 - self.x = paddle.rand(shape=self.shape, dtype=self.dtype) - self.y = paddle.rand(shape=self.shape, dtype=self.dtype) - - def test_single_input_and_single_output(self): - def func(x): - return paddle.matmul(x, x) - - numerical_jacobian = _compute_numerical_jacobian( - func, self.x, self.numerical_delta, self.np_dtype) - self.x.stop_gradient = False - jacobian = paddle.autograd.jacobian(func, self.x) - assert np.allclose(jacobian.numpy(), numerical_jacobian[0][0], - self.rtol, self.atol) - - def test_single_input_and_multi_output(self): - def func(x): - return paddle.matmul(x, x), x * x - - numerical_jacobian = _compute_numerical_jacobian( - func, self.x, self.numerical_delta, self.np_dtype) - self.x.stop_gradient = False - jacobian = paddle.autograd.jacobian(func, self.x) - for i in range(len(jacobian)): - assert np.allclose(jacobian[i].numpy(), numerical_jacobian[i][0], - self.rtol, self.atol) - - def test_multi_input_and_single_output(self): - def func(x, y): - return paddle.matmul(x, y) - - numerical_jacobian = _compute_numerical_jacobian( - func, [self.x, self.y], self.numerical_delta, self.np_dtype) - self.x.stop_gradient = False - self.y.stop_gradient = False - jacobian = paddle.autograd.jacobian(func, [self.x, self.y]) - for j in range(len(jacobian)): - assert np.allclose(jacobian[j].numpy(), numerical_jacobian[0][j], - self.rtol, self.atol) - - def test_multi_input_and_multi_output(self): - def func(x, y): - return paddle.matmul(x, y), x * y - - 
numerical_jacobian = _compute_numerical_jacobian( - func, [self.x, self.y], self.numerical_delta, self.np_dtype) - self.x.stop_gradient = False - self.y.stop_gradient = False - jacobian = paddle.autograd.jacobian(func, [self.x, self.y]) - for i in range(len(jacobian)): - for j in range(len(jacobian[0])): - assert np.allclose(jacobian[i][j].numpy(), - numerical_jacobian[i][j], self.rtol, - self.atol) - - def test_allow_unused_false(self): - def func(x, y): - return paddle.matmul(x, x) - - try: - self.x.stop_gradient = False - self.y.stop_gradient = False - jacobian = paddle.autograd.jacobian(func, [self.x, self.y]) - except ValueError as e: - error_msg = cpt.get_exception_message(e) - assert error_msg.find("allow_unused") > 0 - - def test_allow_unused_true(self): - def func(x, y): - return paddle.matmul(x, x) - - numerical_jacobian = _compute_numerical_jacobian( - func, [self.x, self.y], self.numerical_delta, self.np_dtype) - self.x.stop_gradient = False - self.y.stop_gradient = False - jacobian = paddle.autograd.jacobian( - func, [self.x, self.y], allow_unused=True) - assert np.allclose(jacobian[0].numpy(), numerical_jacobian[0][0], - self.rtol, self.atol) - assert jacobian[1] is None - - def test_create_graph_false(self): - def func(x, y): - return paddle.matmul(x, y) - - numerical_jacobian = _compute_numerical_jacobian( - func, [self.x, self.y], self.numerical_delta, self.np_dtype) - self.x.stop_gradient = False - self.y.stop_gradient = False - jacobian = paddle.autograd.jacobian(func, [self.x, self.y]) - for j in range(len(jacobian)): - assert jacobian[j].stop_gradient == True - assert np.allclose(jacobian[j].numpy(), numerical_jacobian[0][j], - self.rtol, self.atol) - try: - paddle.grad(jacobian[0], [self.x, self.y]) - except RuntimeError as e: - error_msg = cpt.get_exception_message(e) - assert error_msg.find("has no gradient") > 0 - - def test_create_graph_true(self): - def func(x, y): - return paddle.matmul(x, y) - - numerical_jacobian = _compute_numerical_jacobian( - func, [self.x, self.y], self.numerical_delta, self.np_dtype) - self.x.stop_gradient = False - self.y.stop_gradient = False - jacobian = paddle.autograd.jacobian( - func, [self.x, self.y], create_graph=True) - for j in range(len(jacobian)): - assert jacobian[j].stop_gradient == False - assert np.allclose(jacobian[j].numpy(), numerical_jacobian[0][j], - self.rtol, self.atol) - double_grad = paddle.grad(jacobian[0], [self.x, self.y]) - assert double_grad is not None - - -class TestJacobianFloat64(TestJacobian): - @classmethod - def setUpClass(self): - self.shape = (4, 4) - self.dtype = 'float64' - self.np_dtype = np.float64 - self.numerical_delta = 1e-7 - self.rtol = 1e-7 - self.atol = 1e-7 - self.x = paddle.rand(shape=self.shape, dtype=self.dtype) - self.y = paddle.rand(shape=self.shape, dtype=self.dtype) - - -class TestJacobianBatch(unittest.TestCase): - @classmethod - def setUpClass(self): - self.x_shape = (4, 2) - self.weight_shape = (2, 4) - self.y_shape = (4, 2) - self.dtype = 'float32' - self.np_dtype = np.float32 - self.numerical_delta = 1e-4 - self.rtol = 1e-3 - self.atol = 1e-3 - self.x = paddle.rand(shape=self.x_shape, dtype=self.dtype) - self.weight = paddle.rand(shape=self.weight_shape, dtype=self.dtype) - self.y = paddle.rand(shape=self.y_shape, dtype=self.dtype) - - def test_batch_single_input_and_batch_single_output(self): - def func(x): - return paddle.matmul(paddle.matmul(x, self.weight), self.y) - - numerical_jacobian = _compute_numerical_batch_jacobian( - func, [self.x], self.numerical_delta, 
self.np_dtype) - - self.x.stop_gradient = False - batch_jacobian = paddle.autograd.batch_jacobian( - func, - self.x, ) - - self.assertTrue( - np.allclose(batch_jacobian.numpy().all(), numerical_jacobian[0][0] - .all())) - - def test_batch_single_input_and_batch_multi_output(self): - def func(x): - return paddle.matmul(paddle.matmul(x, self.weight), self.y), x * x - - numerical_jacobian = _compute_numerical_batch_jacobian( - func, [self.x], self.numerical_delta, self.np_dtype) - - self.x.stop_gradient = False - batch_jacobian = paddle.autograd.batch_jacobian( - func, - self.x, ) - - for i in range(len(batch_jacobian)): - assert np.allclose(batch_jacobian[i].numpy(), - numerical_jacobian[i][0], self.rtol, self.atol) - - def test_batch_multi_input_and_batch_single_output(self): - def func(x, y): - return x * y - - numerical_jacobian = _compute_numerical_batch_jacobian( - func, [self.x, self.y], self.numerical_delta, self.np_dtype) - - self.x.stop_gradient = False - self.y.stop_gradient = False - batch_jacobian = paddle.autograd.batch_jacobian(func, [self.x, self.y]) - - for j in range(len(batch_jacobian)): - assert np.allclose(batch_jacobian[j].numpy(), - numerical_jacobian[0][j], self.rtol, self.atol) - - def test_batch_multi_input_and_batch_multi_output(self): - def func(x, y): - return x * y, x * y - - numerical_jacobian = _compute_numerical_batch_jacobian( - func, [self.x, self.y], self.numerical_delta, self.np_dtype) - - self.x.stop_gradient = False - self.y.stop_gradient = False - batch_jacobian = paddle.autograd.batch_jacobian(func, [self.x, self.y]) - - for i in range(len(batch_jacobian)): - assert np.allclose(batch_jacobian[i], numerical_jacobian[i], - self.rtol, self.atol) - - def test_allow_unused_false(self): - def func(x, y): - return x * x - - try: - self.x.stop_gradient = False - self.y.stop_gradient = False - jacobian = paddle.autograd.batch_jacobian(func, [self.x, self.y]) - except ValueError as e: - error_msg = cpt.get_exception_message(e) - assert error_msg.find("allow_unused") > 0 - - def test_allow_unused_true(self): - def func(x, y): - return x * x - - numerical_jacobian = _compute_numerical_batch_jacobian( - func, [self.x, self.y], self.numerical_delta, self.np_dtype) - self.x.stop_gradient = False - self.y.stop_gradient = False - jacobian = paddle.autograd.batch_jacobian( - func, [self.x, self.y], allow_unused=True) - - assert np.allclose(jacobian[0].numpy(), numerical_jacobian[0][0], - self.rtol, self.atol) - assert jacobian[1] is None - - def test_create_graph_false(self): - def func(x, y): - return x * y - - numerical_jacobian = _compute_numerical_batch_jacobian( - func, [self.x, self.y], self.numerical_delta, self.np_dtype) - self.x.stop_gradient = False - self.y.stop_gradient = False - jacobian = paddle.autograd.batch_jacobian(func, [self.x, self.y]) - for j in range(len(jacobian)): - assert jacobian[j].stop_gradient == True - assert np.allclose(jacobian[j].numpy(), numerical_jacobian[0][j], - self.rtol, self.atol) - try: - paddle.grad(jacobian[0], [self.x, self.y]) - except RuntimeError as e: - error_msg = cpt.get_exception_message(e) - assert error_msg.find("has no gradient") > 0 - - def test_create_graph_true(self): - def func(x, y): - return x * y - - numerical_jacobian = _compute_numerical_batch_jacobian( - func, [self.x, self.y], self.numerical_delta, self.np_dtype) - self.x.stop_gradient = False - self.y.stop_gradient = False - jacobian = paddle.autograd.batch_jacobian( - func, [self.x, self.y], create_graph=True) - for j in range(len(jacobian)): - assert 
jacobian[j].stop_gradient == False - assert np.allclose(jacobian[j].numpy(), numerical_jacobian[0][j], - self.rtol, self.atol) - double_grad = paddle.grad(jacobian[0], [self.x, self.y]) - assert double_grad is not None - - -class TestJacobianBatchFloat64(TestJacobianBatch): - @classmethod - def setUpClass(self): - self.x_shape = (12, 2) - self.weight_shape = (2, 12) - self.y_shape = (12, 2) - self.dtype = 'float64' - self.np_dtype = np.float64 - self.numerical_delta = 1e-7 - self.rtol = 1e-7 - self.atol = 1e-7 - self.x = paddle.rand(shape=self.x_shape, dtype=self.dtype) - self.weight = paddle.rand(shape=self.weight_shape, dtype=self.dtype) - self.y = paddle.rand(shape=self.y_shape, dtype=self.dtype) - - -if __name__ == "__main__": - unittest.main() diff --git a/python/paddle/fluid/tests/unittests/autograd/test_vhp.py b/python/paddle/fluid/tests/unittests/autograd/test_vhp.py deleted file mode 100644 index 09b25203e04a4..0000000000000 --- a/python/paddle/fluid/tests/unittests/autograd/test_vhp.py +++ /dev/null @@ -1,182 +0,0 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import unittest -import numpy as np -import paddle -import paddle.compat as cpt -import paddle.nn.functional as F -from utils import _compute_numerical_vhp - - -class TestVHP(unittest.TestCase): - @classmethod - def setUpClass(self): - self.shape = (2, 2) - self.dtype = 'float32' - self.np_dtype = np.float32 - self.numerical_delta = 1e-2 - self.rtol = 1e-2 - self.atol = 1e-2 - self.x = paddle.rand(shape=self.shape, dtype=self.dtype) - self.y = paddle.rand(shape=self.shape, dtype=self.dtype) - self.vx = paddle.rand(shape=self.shape, dtype=self.dtype) - self.vy = paddle.rand(shape=self.shape, dtype=self.dtype) - - def test_single_input(self): - def func(x): - return paddle.sum(paddle.matmul(x, x)) - - numerical_func_output = func(self.x).numpy() - numerical_vhp = _compute_numerical_vhp( - func, self.x, self.vx, self.numerical_delta, self.np_dtype) - - self.x.stop_gradient = False - func_output, vhp = paddle.autograd.vhp(func, self.x, self.vx) - assert np.allclose(func_output.numpy(), numerical_func_output, - self.rtol, self.atol) - assert np.allclose(vhp[0].numpy(), numerical_vhp[0], self.rtol, - self.atol) - - def test_multi_input(self): - def func(x, y): - return paddle.sum(paddle.matmul(x, y)) - - numerical_func_output = func(self.x, self.y).numpy() - numerical_vhp = _compute_numerical_vhp( - func, [self.x, self.y], [self.vx, self.vy], self.numerical_delta, - self.np_dtype) - - self.x.stop_gradient = False - self.y.stop_gradient = False - func_output, vhp = paddle.autograd.vhp(func, [self.x, self.y], - [self.vx, self.vy]) - assert np.allclose(func_output.numpy(), numerical_func_output, - self.rtol, self.atol) - for i in range(len(vhp)): - assert np.allclose(vhp[i].numpy(), numerical_vhp[i], self.rtol, - self.atol) - - def test_v_default(self): - def func(x, y): - return paddle.sum(paddle.matmul(x, y)) - - numerical_func_output = func(self.x, 
self.y).numpy() - vx = paddle.ones(self.vx.shape, dtype=self.vx.dtype) - vy = paddle.ones(self.vy.shape, dtype=self.vy.dtype) - numerical_vhp = _compute_numerical_vhp(func, [self.x, self.y], - [vx, vy], self.numerical_delta, - self.np_dtype) - - self.x.stop_gradient = False - self.y.stop_gradient = False - func_output, vhp = paddle.autograd.vhp(func, [self.x, self.y]) - assert np.allclose(func_output.numpy(), numerical_func_output, - self.rtol, self.atol) - for i in range(len(vhp)): - assert np.allclose(vhp[i].numpy(), numerical_vhp[i], self.rtol, - self.atol) - - def test_allow_unused_false(self): - def func(x, y): - return paddle.sum(paddle.matmul(x, x)) - - try: - self.x.stop_gradient = False - self.y.stop_gradient = False - _ = paddle.autograd.vhp(func, [self.x, self.y]) - except ValueError as e: - error_msg = cpt.get_exception_message(e) - assert error_msg.find("allow_unused") > 0 - - def test_allow_unused_true(self): - def func(x, y): - return paddle.sum(paddle.matmul(x, x)) - - numerical_func_output = func(self.x, self.y).numpy() - numerical_vhp = _compute_numerical_vhp( - func, [self.x, self.y], [self.vx, self.vy], self.numerical_delta, - self.np_dtype) - - self.x.stop_gradient = False - self.y.stop_gradient = False - func_output, vhp = paddle.autograd.vhp(func, [self.x, self.y], - [self.vx, self.vy], - allow_unused=True) - assert np.allclose(func_output.numpy(), numerical_func_output, - self.rtol, self.atol) - assert np.allclose(vhp[0].numpy(), numerical_vhp[0], self.rtol, - self.atol) - assert vhp[1] is None - - def test_create_graph_false(self): - def func(x): - return paddle.sum(F.sigmoid(x)) - - numerical_func_output = func(self.x).numpy() - numerical_vhp = _compute_numerical_vhp( - func, self.x, self.vx, self.numerical_delta, self.np_dtype) - - self.x.stop_gradient = False - func_output, vhp = paddle.autograd.vhp(func, self.x, self.vx) - assert np.allclose(func_output.numpy(), numerical_func_output, - self.rtol, self.atol) - assert vhp[0].stop_gradient == True - assert np.allclose(vhp[0].numpy(), numerical_vhp[0], self.rtol, - self.atol) - try: - paddle.grad(vhp, self.x) - except RuntimeError as e: - error_msg = cpt.get_exception_message(e) - assert error_msg.find("has no gradient") > 0 - - def test_create_graph_true(self): - def func(x): - return paddle.sum(F.sigmoid(x)) - - numerical_func_output = func(self.x).numpy() - numerical_vhp = _compute_numerical_vhp( - func, self.x, self.vx, self.numerical_delta, self.np_dtype) - - self.x.stop_gradient = False - func_output, vhp = paddle.autograd.vhp(func, - self.x, - self.vx, - create_graph=True) - assert np.allclose(func_output.numpy(), numerical_func_output, - self.rtol, self.atol) - assert vhp[0].stop_gradient == False - assert np.allclose(vhp[0].numpy(), numerical_vhp[0], self.rtol, - self.atol) - triple_grad = paddle.grad(vhp, self.x) - assert triple_grad is not None - - -class TestVHPFloat64(TestVHP): - @classmethod - def setUpClass(self): - self.shape = (2, 2) - self.dtype = 'float64' - self.np_dtype = np.float64 - self.numerical_delta = 1e-5 - self.rtol = 1e-5 - self.atol = 1e-5 - self.x = paddle.rand(shape=self.shape, dtype=self.dtype) - self.y = paddle.rand(shape=self.shape, dtype=self.dtype) - self.vx = paddle.rand(shape=self.shape, dtype=self.dtype) - self.vy = paddle.rand(shape=self.shape, dtype=self.dtype) - - -if __name__ == "__main__": - unittest.main() diff --git a/python/paddle/fluid/tests/unittests/autograd/test_vjp_jvp.py b/python/paddle/fluid/tests/unittests/autograd/test_vjp_jvp.py deleted file mode 100644 
index c228ad79321d4..0000000000000 --- a/python/paddle/fluid/tests/unittests/autograd/test_vjp_jvp.py +++ /dev/null @@ -1,315 +0,0 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import unittest -import paddle - -from paddle.autograd.functional import vjp, jvp, _tensors -from paddle import grad, ones_like, zeros_like - - -def reduce(x): - return paddle.sum(x) - - -def reduce_dim(x): - return paddle.sum(x, axis=0) - - -def matmul(x, y): - return paddle.matmul(x, y) - - -def mul(x, y): - return x * y - - -def pow(x, y): - return paddle.pow(x, y) - - -def o2(x, y): - return paddle.multiply(x, y), paddle.matmul(x, y.t()) - - -def unuse(x, y): - return paddle.sum(x) - - -def nested(x): - def inner(y): - return x * y - - return inner - - -def make_v(f, inputs): - outputs = _tensors(f(*inputs), "outputs") - return [ones_like(x) for x in outputs] - - -class TestAutogradFunctional(unittest.TestCase): - @classmethod - def setUpClass(cls): - cls.RAW_INPUTS = { - 'a': [1.0], - 'b': [1.0, 2.0], - 'c': [3.0, 4.0], - 'd': [[2.0], [3.0]], - 'A': [[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]], - 'B': [[1.0, 2.0, 3.0], [2.0, 3.0, 4.0]], - } - - def setUp(self): - pass - - def gen_input(self, inp, stop_gradient=False): - if isinstance(inp, paddle.Tensor): - return inp - return paddle.to_tensor( - self.RAW_INPUTS[inp], stop_gradient=stop_gradient) - - def gen_inputs(self, inputs): - if isinstance(inputs, list): - inputs = [self.gen_input(x) for x in inputs] - else: - inputs = [self.gen_input(inputs)] - return inputs - - def gen_test_pairs(self, - func, - inputs, - v=None, - create_graph=False, - allow_unused=False): - def vjp_test(): - nonlocal v - xs = self.gen_inputs(inputs) - if v is not None: - v = self.gen_inputs(v) - outputs, inputs_grad = vjp(func, - xs, - v, - create_graph=create_graph, - allow_unused=allow_unused) - else: - outputs, inputs_grad = vjp(func, - xs, - create_graph=create_graph, - allow_unused=allow_unused) - return outputs, inputs_grad - - def grad_test(): - nonlocal v - xs = self.gen_inputs(inputs) - if v is not None: - v = self.gen_inputs(v) - outputs = func(*xs) - if v is not None: - inputs_grad = grad( - outputs, - xs, - v, - create_graph=create_graph, - allow_unused=allow_unused) - else: - inputs_grad = grad( - outputs, - xs, - create_graph=create_graph, - allow_unused=allow_unused) - return outputs, inputs_grad - - return vjp_test, grad_test - - def gen_jvp_tests(self, - func, - inputs, - v=None, - create_graph=False, - allow_unused=False): - def jvp_test(): - nonlocal v - xs = self.gen_inputs(inputs) - if v is not None: - v = self.gen_inputs(v) - outputs, outputs_grad = jvp(func, - xs, - v, - create_graph=create_graph, - allow_unused=allow_unused) - else: - outputs, outputs_grad = jvp(func, - xs, - create_graph=create_graph, - allow_unused=allow_unused) - return outputs, outputs_grad - - return jvp_test - - def check_results(self, ref, res): - type_error = 'Result is different than expected in shape or type' - value_error = 
'Result is different than expected values' - if ref is None: - self.assertTrue(res is None, type_error) - elif isinstance(ref, paddle.Tensor): - self.assertTrue(isinstance(res, paddle.Tensor), type_error) - self.assertTrue(paddle.allclose(res, ref), value_error) - else: - self.assertTrue(len(res) == len(ref), type_error) - for i in range(len(ref)): - self.check_results(ref[i], res[i]) - return True - - -class TestVJP(TestAutogradFunctional): - def test_vjp_i1o1_no_create_graph(self): - test_cases = [ - [reduce, 'A'], #noqa - [reduce_dim, 'A'], #noqa - ] #noqa - for f, inputs in test_cases: - vjp, grad = self.gen_test_pairs(f, inputs) - vjp_result, grad_result = vjp(), grad() - self.check_results(grad_result, vjp_result) - - def test_vjp_i2o1_no_create_graph(self): - test_cases = [ - [matmul, ['A', 'B']], #noqa - [mul, ['b', 'c']], #noqa - ] #noqa - for f, inputs in test_cases: - vjp, grad = self.gen_test_pairs(f, inputs) - vjp_result, grad_result = vjp(), grad() - self.check_results(grad_result, vjp_result) - - def test_vjp_i2o2_no_create_graph(self): - test_cases = [ - [o2, ['A', 'A']], #noqa - ] #noqa - for f, inputs in test_cases: - inputs = self.gen_inputs(inputs) - v = make_v(f, inputs) - vjp, grad = self.gen_test_pairs(f, inputs, v=v) - vjp_result, grad_result = vjp(), grad() - self.check_results(grad_result, vjp_result) - - def test_vjp_i2o2_omitting_v_no_create_graph(self): - test_cases = [ - [o2, ['A', 'A']], #noqa - ] #noqa - for f, inputs in test_cases: - inputs = self.gen_inputs(inputs) - vjp, grad = self.gen_test_pairs(f, inputs) - vjp_result, grad_result = vjp(), grad() - self.check_results(grad_result, vjp_result) - - def test_vjp_nested_no_create_graph(self): - x = self.gen_input('a') - test_cases = [ - [nested(x), 'a'], #noqa - ] - for f, inputs in test_cases: - vjp, grad = self.gen_test_pairs(f, inputs) - vjp_result, grad_result = vjp(), grad() - self.check_results(grad_result, vjp_result) - - def test_vjp_aliased_input_no_create_graph(self): - x = self.gen_input('a') - ref = self.gen_test_pairs(nested(x), 'a')[0] - aliased = self.gen_test_pairs(nested(x), x)[0] - ref_result, aliased_result = ref(), aliased() - self.check_results(ref_result, aliased_result) - - def test_vjp_allowunused_no_create_graph(self): - x, y = self.gen_input('A'), self.gen_input('a') - vjp, grad = self.gen_test_pairs(unuse, [x, y], allow_unused=True) - vjp_result, grad_result = vjp(), grad() - self.check_results(grad_result, vjp_result) - - -def jac(grad_fn, f, inputs): - assert grad_fn in [vjp, jvp] - if grad_fn is jvp: - vs = [zeros_like(x) for x in inputs] - else: - outputs = f(*inputs) - if isinstance(outputs, paddle.Tensor): - outputs = [outputs] - vs = [zeros_like(y) for y in outputs] - JJ_cols = [] - for i, v in enumerate(vs): - v = v.flatten() - for j in range(len(v)): - _v = zeros_like(v).detach() - _v[j] = 1.0 - _v = _v.reshape(vs[i].shape) - _vs = vs.copy() - _vs[i] = _v - _, grads = grad_fn(f, inputs, vs) - d_outs = paddle.concat([d_out.flatten() for d_out in grads]) - JJ_cols.append(d_outs) - # JJ is the fully unrolled jacobian - JJ = paddle.stack(JJ_cols) - if grad_fn is vjp: - JJ = JJ.t() - return JJ - - -class TestJVP(TestAutogradFunctional): - def test_jvp_i1o1_no_create_graph(self): - test_cases = [ - [reduce, 'A'], #noqa - [reduce_dim, 'A'], #noqa - ] #noqa - for f, inputs in test_cases: - inputs = self.gen_inputs(inputs) - forward_jac = jac(jvp, f, inputs) - reverse_jac = jac(vjp, f, inputs) - self.check_results(forward_jac, reverse_jac) - - def 
test_jvp_i2o1_no_create_graph(self): - test_cases = [ #noqa - [matmul, ['A', 'B']], #noqa - ] #noqa - for f, inputs in test_cases: - inputs = self.gen_inputs(inputs) - forward_jac = jac(jvp, f, inputs) - reverse_jac = jac(vjp, f, inputs) - self.check_results(forward_jac, reverse_jac) - - def test_jvp_i2o2_no_create_graph(self): - test_cases = [ #noqa - [o2, ['A', 'A']], #noqa - ] #noqa - for f, inputs in test_cases: - inputs = self.gen_inputs(inputs) - forward_jac = jac(jvp, f, inputs) - reverse_jac = jac(vjp, f, inputs) - self.check_results(forward_jac, reverse_jac) - - def test_jvp_i2o2_omitting_v_no_create_graph(self): - test_cases = [ #noqa - [o2, ['A', 'A']], #noqa - ] #noqa - for f, inputs in test_cases: - inputs = self.gen_inputs(inputs) - results_omitting_v = jvp(f, inputs) - v = [ones_like(x) for x in inputs] - results_with_v = jvp(f, inputs, v) - self.check_results(results_omitting_v, results_with_v) - - -if __name__ == "__main__": - unittest.main() diff --git a/python/paddle/fluid/tests/unittests/autograd/utils.py b/python/paddle/fluid/tests/unittests/autograd/utils.py index b06ce6ed7cca3..0816b57fbf70b 100644 --- a/python/paddle/fluid/tests/unittests/autograd/utils.py +++ b/python/paddle/fluid/tests/unittests/autograd/utils.py @@ -1,22 +1,33 @@ # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -# +# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at -# +# # http://www.apache.org/licenses/LICENSE-2.0 -# +# # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import typing +import enum +import sys +import re +import inspect +import functools +import contextlib +import collections import numpy as np import paddle -from paddle.autograd.functional import _tensors +from paddle.autograd.functional import _as_tensors +########################################################## +# Finite Difference Utils +########################################################## def _product(t): if isinstance(t, int): return t @@ -25,7 +36,9 @@ def _product(t): def _get_item(t, idx): - assert isinstance(t, paddle.Tensor), "The first argument t must be Tensor." + assert isinstance( + t, + paddle.fluid.framework.Variable), "The first argument t must be Tensor." assert isinstance(idx, int), "The second argument idx must be an int number." flat_t = paddle.reshape(t, [-1]) @@ -33,7 +46,9 @@ def _get_item(t, idx): def _set_item(t, idx, value): - assert isinstance(t, paddle.Tensor), "The first argument t must be Tensor." + assert isinstance( + t, + paddle.fluid.framework.Variable), "The first argument t must be Tensor." assert isinstance(idx, int), "The second argument idx must be an int number." 
flat_t = paddle.reshape(t, [-1]) @@ -42,8 +57,8 @@ def _set_item(t, idx, value): def _compute_numerical_jacobian(func, xs, delta, np_dtype): - xs = _tensors(xs, "xs") - ys = _tensors(func(*xs), "ys") + xs = list(_as_tensors(xs)) + ys = list(_as_tensors(func(*xs))) fin_size = len(xs) fout_size = len(ys) jacobian = list([] for _ in range(fout_size)) @@ -59,11 +74,11 @@ def _compute_numerical_jacobian(func, xs, delta, np_dtype): orig = _get_item(xs[j], q) x_pos = orig + delta xs[j] = _set_item(xs[j], q, x_pos) - ys_pos = _tensors(func(*xs), "ys_pos") + ys_pos = _as_tensors(func(*xs)) x_neg = orig - delta xs[j] = _set_item(xs[j], q, x_neg) - ys_neg = _tensors(func(*xs), "ys_neg") + ys_neg = _as_tensors(func(*xs)) xs[j] = _set_item(xs[j], q, orig) @@ -76,8 +91,8 @@ def _compute_numerical_jacobian(func, xs, delta, np_dtype): def _compute_numerical_hessian(func, xs, delta, np_dtype): - xs = _tensors(xs, "xs") - ys = _tensors(func(*xs), "ys") + xs = list(_as_tensors(xs)) + ys = list(_as_tensors(func(*xs))) fin_size = len(xs) hessian = list([] for _ in range(fin_size)) for i in range(fin_size): @@ -107,10 +122,22 @@ def _compute_numerical_hessian(func, xs, delta, np_dtype): return hessian -def _compute_numerical_batch_jacobian(func, xs, delta, np_dtype): +def concat_to_matrix(xs, is_batched=False): + """Concats a tuple of tuple of Jacobian/Hessian matrix into one matrix""" + rows = [] + for i in range(len(xs)): + rows.append(np.concatenate([x for x in xs[i]], -1)) + return np.concatenate(rows, 1) if is_batched else np.concatenate(rows, 0) + + +def _compute_numerical_batch_jacobian(func, + xs, + delta, + np_dtype, + merge_batch=True): no_batch_jacobian = _compute_numerical_jacobian(func, xs, delta, np_dtype) - xs = _tensors(xs, "xs") - ys = _tensors(func(*xs), "ys") + xs = list(_as_tensors(xs)) + ys = list(_as_tensors(func(*xs))) fin_size = len(xs) fout_size = len(ys) bs = xs[0].shape[0] @@ -128,7 +155,8 @@ def _compute_numerical_batch_jacobian(func, xs, delta, np_dtype): for b in range(bs): for q in range(in_size): batch_jac_i_j[p][b][q] = jac[b][p][b][q] - batch_jac_i_j = np.reshape(batch_jac_i_j, (out_size, -1)) + if merge_batch: + batch_jac_i_j = np.reshape(batch_jac_i_j, (out_size, -1)) batch_jac_i.append(batch_jac_i_j) bat_jac.append(batch_jac_i) @@ -136,7 +164,7 @@ def _compute_numerical_batch_jacobian(func, xs, delta, np_dtype): def _compute_numerical_batch_hessian(func, xs, delta, np_dtype): - xs = _tensors(xs, "xs") + xs = list(_as_tensors(xs)) batch_size = xs[0].shape[0] fin_size = len(xs) hessian = [] @@ -175,8 +203,10 @@ def _compute_numerical_batch_hessian(func, xs, delta, np_dtype): def _compute_numerical_vjp(func, xs, v, delta, np_dtype): - xs = _tensors(xs, "xs") + xs = _as_tensors(xs) jacobian = np.array(_compute_numerical_jacobian(func, xs, delta, np_dtype)) + if v is None: + v = [paddle.ones_like(x) for x in xs] flat_v = np.array([v_el.numpy().reshape(-1) for v_el in v]) vjp = [np.zeros((_product(x.shape)), dtype=np_dtype) for x in xs] for j in range(len(xs)): @@ -188,7 +218,7 @@ def _compute_numerical_vjp(func, xs, v, delta, np_dtype): def _compute_numerical_vhp(func, xs, v, delta, np_dtype): - xs = _tensors(xs, "xs") + xs = list(_as_tensors(xs)) hessian = np.array(_compute_numerical_hessian(func, xs, delta, np_dtype)) flat_v = np.array([v_el.numpy().reshape(-1) for v_el in v]) vhp = [np.zeros((_product(x.shape)), dtype=np_dtype) for x in xs] @@ -198,3 +228,166 @@ def _compute_numerical_vhp(func, xs, v, delta, np_dtype): flat_v) vhp = [vhp[j].reshape(xs[j].shape) for j in 
range(len(xs))]
     return vhp
+
+
+##########################################################
+# Test cases of different functions.
+##########################################################
+def reduce(x):
+    return paddle.sum(x)
+
+
+def reduce_dim(x):
+    return paddle.sum(x, axis=0)
+
+
+def matmul(x, y):
+    return paddle.matmul(x, y)
+
+
+def mul(x, y):
+    return x * y
+
+
+def pow(x, y):
+    return paddle.pow(x, y)
+
+
+def o2(x, y):
+    return paddle.multiply(x, y), paddle.matmul(x, y.t())
+
+
+def unuse(x, y):
+    return paddle.sum(x)
+
+
+def nested(x):
+    def inner(y):
+        return x * y
+
+    return inner
+
+
+def square(x):
+    return x * x
+
+
+##########################################################
+# Parameterized Test Utils.
+##########################################################
+
+TEST_CASE_NAME = 'suffix'
+
+
+def place(devices, key='place'):
+    """A decorator for a test class that makes the class run on each of the
+    given devices.
+
+    Args:
+        devices (Sequence[paddle.CUDAPlace|paddle.CPUPlace]): Device list.
+        key (str, optional): The class attribute name under which the device
+            is stored on each generated class. Defaults to 'place'.
+    """
+
+    def decorate(cls):
+        module = sys.modules[cls.__module__].__dict__
+        raw_classes = {
+            k: v
+            for k, v in module.items() if k.startswith(cls.__name__)
+        }
+
+        for raw_name, raw_cls in raw_classes.items():
+            for d in devices:
+                test_cls = dict(raw_cls.__dict__)
+                test_cls.update({key: d})
+                new_name = raw_name + '.' + d.__class__.__name__
+                module[new_name] = type(new_name, (raw_cls, ), test_cls)
+            del module[raw_name]
+        return cls
+
+    return decorate
+
+
+def parameterize(fields, values=None):
+    """A decorator for a unittest class that runs the class over different
+    test cases.
+
+    Args:
+        fields (Sequence): The field name sequence of test cases.
+        values (Sequence, optional): The test cases sequence. Defaults to None.
+
+    """
+    fields = [fields] if isinstance(fields, str) else fields
+    params = [dict(zip(fields, vals)) for vals in values]
+
+    def decorate(cls):
+        test_cls_module = sys.modules[cls.__module__].__dict__
+        for i, values in enumerate(params):
+            test_cls = dict(cls.__dict__)
+            values = {
+                k: staticmethod(v) if callable(v) else v
+                for k, v in values.items()
+            }
+            test_cls.update(values)
+            name = cls.__name__ + str(i)
+            name = name + '.' + \
+                values.get('suffix') if values.get('suffix') else name
+
+            test_cls_module[name] = type(name, (cls, ), test_cls)
+
+        for m in list(cls.__dict__):
+            if m.startswith("test"):
+                delattr(cls, m)
+        return cls
+
+    return decorate
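+
+
+# A minimal usage sketch of the two decorators above. The test class and the
+# case values here are illustrative only and are not defined in this module:
+#
+#   @place([paddle.CPUPlace()])
+#   @parameterize(
+#       (TEST_CASE_NAME, 'func'), (('reduce', reduce), ('square', square)))
+#   class TestExample(unittest.TestCase):
+#       def test_func(self):
+#           ...  # runs once per device and once per test case
+
+
+##########################################################
+# Utils for transposing between different Jacobian/Hessian matrix formats.
+##########################################################
+
+# B is batch size, N is row size, M is column size.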
+MatrixFormat = enum.Enum('MatrixFormat', ('NBM', 'BNM', 'NMB', 'NM')) + + +def _np_transpose_matrix_format(src, src_format, des_format): + """Transpose Jacobian/Hessian matrix format.""" + supported_format = (MatrixFormat.NBM, MatrixFormat.BNM, MatrixFormat.NMB) + if src_format not in supported_format or des_format not in supported_format: + raise ValueError( + f"Supported Jacobian format is {supported_format}, but got src: {src_format}, des: {des_format}" + ) + + src_axis = {c: i for i, c in enumerate(src_format.name)} + dst_axis = tuple(src_axis[c] for c in des_format.name) + + return np.transpose(src, dst_axis) + + +def _np_concat_matrix_sequence(src, src_format=MatrixFormat.NM): + """Convert a sequence of sequence of Jacobian/Hessian matrix into one huge + matrix.""" + + def concat_col(xs): + if src_format in (MatrixFormat.NBM, MatrixFormat.BNM, MatrixFormat.NM): + return np.concatenate(xs, axis=-1) + else: + return np.concatenate(xs, axis=1) + + def concat_row(xs): + if src_format in (MatrixFormat.NBM, MatrixFormat.NM, MatrixFormat.NMB): + return np.concatenate(xs, axis=0) + else: + return np.concatenate(xs, axis=1) + + supported_format = (MatrixFormat.NBM, MatrixFormat.BNM, MatrixFormat.NMB, + MatrixFormat.NM) + if src_format not in supported_format: + raise ValueError( + f"Supported Jacobian format is {supported_format}, but got {src_format}" + ) + if not isinstance(src, typing.Sequence): + return src + if not isinstance(src[0], typing.Sequence): + src = [src] + return concat_row(tuple(concat_col(xs) for xs in src)) diff --git a/python/paddle/incubate/__init__.py b/python/paddle/incubate/__init__.py index 83dad710bad7d..182aae40f2982 100644 --- a/python/paddle/incubate/__init__.py +++ b/python/paddle/incubate/__init__.py @@ -26,6 +26,7 @@ from .tensor import segment_max from .tensor import segment_min from .passes import fuse_resnet_unit_pass +import paddle.incubate.autograd from . import nn #noqa: F401 diff --git a/python/paddle/incubate/autograd/__init__.py b/python/paddle/incubate/autograd/__init__.py new file mode 100644 index 0000000000000..5528bb4d06c6f --- /dev/null +++ b/python/paddle/incubate/autograd/__init__.py @@ -0,0 +1,18 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
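+
+# A usage sketch for the re-exported API below. The shapes and values are
+# illustrative only, and the Jacobian constructor is assumed to mirror the
+# Hessian usage shown in the tests above:
+#
+#   import paddle
+#   from paddle.incubate.autograd import Jacobian
+#
+#   x = paddle.rand([2, 2])
+#   J = Jacobian(lambda v: paddle.matmul(v, v), x)
+#   full_jacobian = J[:]  # evaluates and concatenates the full matrix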
+from paddle.autograd.functional import Hessian, Jacobian, jvp, vjp + +__all__ = [ # noqa + 'vjp', 'jvp', 'Jacobian', 'Hessian' +] diff --git a/python/setup.py.in b/python/setup.py.in index 3e59e22fcbc63..7f311feb4ee34 100755 --- a/python/setup.py.in +++ b/python/setup.py.in @@ -273,6 +273,7 @@ packages=['paddle', 'paddle.distributed.ps', 'paddle.distributed.ps.utils', 'paddle.incubate', + 'paddle.incubate.autograd', 'paddle.incubate.optimizer', 'paddle.incubate.checkpoint', 'paddle.incubate.operators', diff --git a/tools/windows/run_unittests.sh b/tools/windows/run_unittests.sh index dd6a4ad288140..44dc4eac26118 100644 --- a/tools/windows/run_unittests.sh +++ b/tools/windows/run_unittests.sh @@ -12,55 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -set -e -set +x -NIGHTLY_MODE=$1 -PRECISION_TEST=$2 -WITH_GPU=$3 - -export PADDLE_ROOT="$(cd "$PWD/../" && pwd )" -if [ ${NIGHTLY_MODE:-OFF} == "ON" ]; then - nightly_label="" -else - nightly_label="(RUN_TYPE=NIGHTLY|RUN_TYPE=DIST:NIGHTLY|RUN_TYPE=EXCLUSIVE:NIGHTLY)" - echo "=========================================" - echo "Unittests with nightly labels are only run at night" - echo "=========================================" -fi - -if disable_ut_quickly=$(python ${PADDLE_ROOT}/tools/get_quick_disable_lt.py); then - echo "=========================================" - echo "The following unittests have been disabled:" - echo ${disable_ut_quickly} - echo "=========================================" -else - disable_ut_quickly='' -fi - -# check added ut -set +e -cp $PADDLE_ROOT/tools/check_added_ut.sh $PADDLE_ROOT/tools/check_added_ut_win.sh -bash $PADDLE_ROOT/tools/check_added_ut_win.sh -rm -rf $PADDLE_ROOT/tools/check_added_ut_win.sh -if [ -f "$PADDLE_ROOT/added_ut" ];then - added_uts=^$(awk BEGIN{RS=EOF}'{gsub(/\n/,"$|^");print}' $PADDLE_ROOT/added_ut)$ - ctest -R "(${added_uts})" --output-on-failure -C Release --repeat-until-fail 3;added_ut_error=$? - rm -f $PADDLE_ROOT/added_ut - if [ "$added_ut_error" != 0 ];then - echo "========================================" - echo "Added UT should pass three additional executions" - echo "========================================" - exit 8; - fi - if nvcc --version | grep 11.2; then - echo "Only test added_ut temporarily when running in CI-Windows-inference of CUDA 11.2." 
- exit 0; - fi -fi -set -e - -# /*==================Fixed Disabled Windows GPU MKL unittests==============================*/ +# /*================Fixed Disabled Windows CUDA10.x MKL(PR-CI-Windows) unittests===========================*/ # TODO: fix these unittest that is bound to fail disable_wingpu_test="^test_model$|\ ^test_dataloader_early_reset$|\ @@ -97,7 +50,7 @@ disable_wingpu_test="^test_model$|\ ^test_bilinear_interp_op$|\ ^disable_wingpu_test$" -# /*==================Fixed Disabled Windows GPU MKL unittests==============================*/ +# /*=================Fixed Disabled Windows TRT MKL unittests=======================*/ # TODO: fix these unittest that is bound to fail disable_win_trt_test="^test_trt_convert_conv2d$|\ ^test_trt_convert_conv2d_fusion$|\ @@ -119,7 +72,13 @@ disable_win_trt_test="^test_trt_convert_conv2d$|\ ^test_trt_convert_matmul$|\ ^test_trt_convert_scale$" -# /*==================Fixed Disabled Windows GPU inference_api_test unittests==============================*/ +# /*=============Fixed Disabled Windows CUDA11.x MKL(PR-CI-Windows-Inference) unittests=================*/ +# TODO: fix these unittest that is bound to fail +disable_wingpu11_test="^test_autograd_functional_dynamic$|\ +^disable_wingpu_test$" + + +# /*==========Fixed Disabled Windows CUDA11.x inference_api_test(PR-CI-Windows-Inference) unittests=============*/ disable_win_inference_api_test="^trt_quant_int8_yolov3_r50_test$|\ ^test_trt_dynamic_shape_ernie$|\ ^test_trt_dynamic_shape_ernie_fp16_ser_deser$|\ @@ -128,9 +87,8 @@ disable_win_inference_api_test="^trt_quant_int8_yolov3_r50_test$|\ ^lite_mul_model_test$|\ ^paddle_infer_api_copy_tensor_tester$" -# /*============================================================================*/ -# /*==================Fixed Disabled Windows CPU OPENBLAS unittests==============================*/ +# /*==========Fixed Disabled Windows CPU OPENBLAS((PR-CI-Windows-OPENBLAS)) unittests==============================*/ # TODO: fix these unittest that is bound to fail disable_wincpu_test="^jit_kernel_test$|\ ^test_analyzer_transformer$|\ @@ -189,6 +147,58 @@ long_time_test="^test_gru_op$|\ ^test_trt_matmul_quant_dequant$|\ ^test_strided_slice_op$" + +# /*============================================================================*/ + +set -e +set +x +NIGHTLY_MODE=$1 +PRECISION_TEST=$2 +WITH_GPU=$3 + +export PADDLE_ROOT="$(cd "$PWD/../" && pwd )" +if [ ${NIGHTLY_MODE:-OFF} == "ON" ]; then + nightly_label="" +else + nightly_label="(RUN_TYPE=NIGHTLY|RUN_TYPE=DIST:NIGHTLY|RUN_TYPE=EXCLUSIVE:NIGHTLY)" + echo "=========================================" + echo "Unittests with nightly labels are only run at night" + echo "=========================================" +fi + +if disable_ut_quickly=$(python ${PADDLE_ROOT}/tools/get_quick_disable_lt.py); then + echo "=========================================" + echo "The following unittests have been disabled:" + echo ${disable_ut_quickly} + echo "=========================================" +else + disable_ut_quickly='' +fi + +# check added ut + +set +e +cp $PADDLE_ROOT/tools/check_added_ut.sh $PADDLE_ROOT/tools/check_added_ut_win.sh +bash $PADDLE_ROOT/tools/check_added_ut_win.sh +rm -rf $PADDLE_ROOT/tools/check_added_ut_win.sh +if [ -f "$PADDLE_ROOT/added_ut" ];then + added_uts=^$(awk BEGIN{RS=EOF}'{gsub(/\n/,"$|^");print}' $PADDLE_ROOT/added_ut)$ + ctest -R "(${added_uts})" -E "$disable_wingpu11_test" --output-on-failure -C Release --repeat-until-fail 3;added_ut_error=$? 
+ rm -f $PADDLE_ROOT/added_ut + if [ "$added_ut_error" != 0 ];then + echo "========================================" + echo "Added UT should pass three additional executions" + echo "========================================" + exit 8; + fi + if nvcc --version | grep 11.2; then + echo "Only test added_ut temporarily when running in CI-Windows-inference of CUDA 11.2." + exit 0; + fi +fi +set -e + + if [ ${WITH_GPU:-OFF} == "ON" ];then export CUDA_VISIBLE_DEVICES=0 From ad0c106cc840342a2e4e6368476b46120377262a Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Sat, 2 Apr 2022 09:39:51 +0800 Subject: [PATCH 02/93] Fix sparse conv and verify sparse conv backward (#40961) --- .../kernels/sparse/convolution_grad_kernel.h | 37 ++++++------- .../sparse/cpu/convolution_grad_kernel.cc | 28 ++++++---- .../sparse/gpu/convolution_grad_kernel.cu | 41 ++++++++------ .../kernels/test_sparse_conv3d_dev_api.cc | 33 +++++++----- .../tests/unittests/test_sparse_conv_op.py | 54 +++++++++++++++++++ .../paddle/utils/code_gen/sparse_bw_api.yaml | 2 +- 6 files changed, 137 insertions(+), 58 deletions(-) create mode 100644 python/paddle/fluid/tests/unittests/test_sparse_conv_op.py diff --git a/paddle/phi/kernels/sparse/convolution_grad_kernel.h b/paddle/phi/kernels/sparse/convolution_grad_kernel.h index 23e059c72e776..5a47575141a2d 100644 --- a/paddle/phi/kernels/sparse/convolution_grad_kernel.h +++ b/paddle/phi/kernels/sparse/convolution_grad_kernel.h @@ -25,37 +25,37 @@ namespace sparse { template void Conv3dGradKernel(const Context& dev_ctx, const SparseCooTensor& x, - const DenseTensor& rulebook, const DenseTensor& kernel, - const DenseTensor& out_grad, + const DenseTensor& rulebook, + const SparseCooTensor& out_grad, const std::vector& paddings, const std::vector& dilations, const std::vector& strides, const int groups, const bool subm, - DenseTensor* x_grad, + SparseCooTensor* x_grad, DenseTensor* kernel_grad); template -std::vector Conv3dGrad(const Context& dev_ctx, - const SparseCooTensor& x, - const DenseTensor& rulebook, - const DenseTensor& kernel, - const DenseTensor& out_grad, - const std::vector& paddings, - const std::vector& dilations, - const std::vector& strides, - const int groups, - const bool subm) { - DenseTensor x_grad = - phi::Empty(dev_ctx, DenseTensorMeta(x.dtype(), {1}, x.layout())); +std::tuple Conv3dGrad( + const Context& dev_ctx, + const SparseCooTensor& x, + const DenseTensor& kernel, + const DenseTensor& rulebook, + const SparseCooTensor& out_grad, + const std::vector& paddings, + const std::vector& dilations, + const std::vector& strides, + const int groups, + const bool subm) { + SparseCooTensor x_grad; DenseTensor kernel_grad = phi::Empty( dev_ctx, DenseTensorMeta(kernel.dtype(), {1}, kernel.layout())); // TODO(zhangkaihuo): call InferMeta func here Conv3dGradKernel(dev_ctx, x, - rulebook, kernel, + rulebook, out_grad, paddings, dilations, @@ -64,10 +64,7 @@ std::vector Conv3dGrad(const Context& dev_ctx, subm, &x_grad, &kernel_grad); - std::vector out(2); - out[0] = x_grad; - out[1] = kernel_grad; - return out; + return std::make_tuple(x_grad, kernel_grad); } } // namespace sparse diff --git a/paddle/phi/kernels/sparse/cpu/convolution_grad_kernel.cc b/paddle/phi/kernels/sparse/cpu/convolution_grad_kernel.cc index 3348d81cf6b4b..29079918cbf86 100644 --- a/paddle/phi/kernels/sparse/cpu/convolution_grad_kernel.cc +++ b/paddle/phi/kernels/sparse/cpu/convolution_grad_kernel.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the 
License. */ #include "paddle/phi/kernels/sparse/convolution_grad_kernel.h" +#include "paddle/phi/kernels/copy_kernel.h" #include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/sparse/cpu/convolution.h" @@ -31,15 +32,15 @@ namespace sparse { template void Conv3dGradKernel(const Context& dev_ctx, const SparseCooTensor& x, - const DenseTensor& rulebook, const DenseTensor& kernel, - const DenseTensor& out_grad, + const DenseTensor& rulebook, + const SparseCooTensor& out_grad, const std::vector& paddings, const std::vector& dilations, const std::vector& strides, const int groups, const bool subm, - DenseTensor* x_grad, + SparseCooTensor* x_grad, DenseTensor* kernel_grad) { const auto& kernel_dims = kernel.dims(); const int kernel_size = kernel_dims[0] * kernel_dims[1] * kernel_dims[2]; @@ -73,11 +74,18 @@ void Conv3dGradKernel(const Context& dev_ctx, int half_kernel_size = kernel_size / 2; auto blas = phi::funcs::GetBlas(dev_ctx); - x_grad->Resize(x.non_zero_elements().dims()); - dev_ctx.Alloc(x_grad, x_grad->dtype(), sizeof(T) * x_grad->numel()); - T* x_grad_values_ptr = x_grad->data(); - memset(x_grad_values_ptr, 0, sizeof(T) * x_grad->numel()); + DenseTensor x_grad_indices = + phi::EmptyLike(dev_ctx, x.non_zero_indices()); + DenseTensor x_grad_values = phi::EmptyLike(dev_ctx, x.non_zero_elements()); + T* x_grad_values_ptr = x_grad_values.data(); + memset(x_grad_values_ptr, 0, sizeof(T) * x_grad_values.numel()); memset(d_x_features_ptr, 0, sizeof(T) * d_x_features.numel()); + phi::Copy(dev_ctx, + x.non_zero_indices(), + dev_ctx.GetPlace(), + false, + &x_grad_indices); + x_grad->SetMember(x_grad_indices, x_grad_values, x.dims(), true); std::vector offsets(kernel_size + 1), counter(kernel_size, 0); for (int i = 0; i < rulebook_len; i++) { @@ -97,12 +105,12 @@ void Conv3dGradKernel(const Context& dev_ctx, phi::funcs::sparse::SubmPreProcess(dev_ctx, x, kernel, - out_grad, + out_grad.non_zero_elements(), in_channels, out_channels, half_kernel_size, kernel_grad, - x_grad); + &x_grad_values); if (max_count == 0) { return; } @@ -113,7 +121,7 @@ void Conv3dGradKernel(const Context& dev_ctx, rulebook_len, in_channels, in_features_ptr); - Gather(out_grad.data(), + Gather(out_grad.non_zero_elements().data(), rulebook_ptr + rulebook_len * 2, rulebook_len, out_channels, diff --git a/paddle/phi/kernels/sparse/gpu/convolution_grad_kernel.cu b/paddle/phi/kernels/sparse/gpu/convolution_grad_kernel.cu index 4db0a0b0011b5..4a6094c23bc79 100644 --- a/paddle/phi/kernels/sparse/gpu/convolution_grad_kernel.cu +++ b/paddle/phi/kernels/sparse/gpu/convolution_grad_kernel.cu @@ -12,11 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ +#include "glog/logging.h" #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/backends/gpu/gpu_info.h" #include "paddle/phi/backends/gpu/gpu_launch_config.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/tensor_meta.h" +#include "paddle/phi/kernels/copy_kernel.h" #include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/sparse/convolution_grad_kernel.h" @@ -36,15 +38,15 @@ namespace sparse { template void Conv3dGradKernel(const Context& dev_ctx, const SparseCooTensor& x, - const DenseTensor& rulebook, const DenseTensor& kernel, - const DenseTensor& out_grad, + const DenseTensor& rulebook, + const SparseCooTensor& out_grad, const std::vector& paddings, const std::vector& dilations, const std::vector& strides, const int groups, const bool subm, - DenseTensor* x_grad, + SparseCooTensor* x_grad, DenseTensor* kernel_grad) { const auto& kernel_dims = kernel.dims(); const int kernel_size = kernel_dims[0] * kernel_dims[1] * kernel_dims[2]; @@ -70,17 +72,25 @@ void Conv3dGradKernel(const Context& dev_ctx, T* in_features_ptr = in_features.data(); T* d_x_features_ptr = d_x_features.data(); T* out_grad_features_ptr = out_grad_features.data(); - kernel_grad->ResizeAndAllocate(kernel_dims); + *kernel_grad = phi::EmptyLike(dev_ctx, kernel); T* d_kernel_ptr = kernel_grad->data(); phi::funcs::SetConstant set_zero; set_zero(dev_ctx, kernel_grad, static_cast(0.0f)); int half_kernel_size = kernel_size / 2; auto blas = phi::funcs::GetBlas(dev_ctx); - x_grad->ResizeAndAllocate(x.non_zero_elements().dims()); - T* x_grad_values_ptr = x_grad->data(); - set_zero(dev_ctx, x_grad, static_cast(0.0f)); + DenseTensor x_grad_indices = + phi::EmptyLike(dev_ctx, x.non_zero_indices()); + DenseTensor x_grad_values = phi::EmptyLike(dev_ctx, x.non_zero_elements()); + T* x_grad_values_ptr = x_grad_values.data(); + set_zero(dev_ctx, &x_grad_values, static_cast(0.0f)); set_zero(dev_ctx, &d_x_features, static_cast(0.0f)); + phi::Copy(dev_ctx, + x.non_zero_indices(), + dev_ctx.GetPlace(), + false, + &x_grad_indices); + x_grad->SetMember(x_grad_indices, x_grad_values, x.dims(), true); std::vector offsets(kernel_size + 1), counter(kernel_size, 0), h_counter(rulebook_len, 0); @@ -113,12 +123,12 @@ void Conv3dGradKernel(const Context& dev_ctx, phi::funcs::sparse::SubmPreProcess(dev_ctx, x, kernel, - out_grad, + out_grad.non_zero_elements(), in_channels, out_channels, half_kernel_size, kernel_grad, - x_grad); + &x_grad_values); if (max_count == 0) { return; } @@ -140,11 +150,12 @@ void Conv3dGradKernel(const Context& dev_ctx, GatherKernel<<>>(out_grad.data(), - rulebook_ptr + rulebook_len * 2, - out_grad_features_ptr, - rulebook_len, - out_channels); + dev_ctx.stream()>>>( + out_grad.non_zero_elements().data(), + rulebook_ptr + rulebook_len * 2, + out_grad_features_ptr, + rulebook_len, + out_channels); const T* kernel_ptr = kernel.data(); for (int i = 0; i < kernel_size; i++) { @@ -189,7 +200,7 @@ void Conv3dGradKernel(const Context& dev_ctx, } // 4. 
scatter - x_grad->ResizeAndAllocate(x.non_zero_elements().dims()); + // x_grad->ResizeAndAllocate(x.non_zero_elements().dims()); DenseTensorMeta index_meta(DataType::INT32, {rulebook_len}, DataLayout::NCHW); DenseTensor out_index = phi::Empty(dev_ctx, std::move(index_meta)); DenseTensor unique_key = phi::Empty(dev_ctx, std::move(index_meta)); diff --git a/paddle/phi/tests/kernels/test_sparse_conv3d_dev_api.cc b/paddle/phi/tests/kernels/test_sparse_conv3d_dev_api.cc index 33f84db76e78e..c22464e538c21 100644 --- a/paddle/phi/tests/kernels/test_sparse_conv3d_dev_api.cc +++ b/paddle/phi/tests/kernels/test_sparse_conv3d_dev_api.cc @@ -71,6 +71,10 @@ void TestConv3dBase(const std::vector& indices, paddle::memory::allocation::AllocatorFacade::Instance() .GetAllocator(paddle::platform::CPUPlace()) .get()); + dev_ctx_cpu.SetHostAllocator( + paddle::memory::allocation::AllocatorFacade::Instance() + .GetAllocator(paddle::platform::CPUPlace()) + .get()); dev_ctx_cpu.Init(); const int in_channels = kernel_dims[3]; @@ -132,19 +136,19 @@ void TestConv3dBase(const std::vector& indices, f_verify(out.non_zero_elements().data(), correct_out_features); if (backward) { - std::vector grads = + std::tuple grads = sparse::Conv3dGrad(dev_ctx_cpu, x_tensor, - rulebook, kernel_tensor, - out.non_zero_elements(), + rulebook, + out, paddings, dilations, strides, 1, subm); - f_verify(grads[0].data(), features_grad); - f_verify(grads[1].data(), kernel_grad); + f_verify(std::get<0>(grads).non_zero_elements().data(), features_grad); + f_verify(std::get<1>(grads).data(), kernel_grad); } } @@ -233,23 +237,28 @@ void TestConv3dBase(const std::vector& indices, f_verify(h_features_tensor.data(), correct_out_features); if (backward) { - std::vector grads = + std::tuple grads = sparse::Conv3dGrad(dev_ctx_gpu, d_x_tensor, - d_rulebook, d_kernel_tensor, - d_out.non_zero_elements(), + d_rulebook, + d_out, paddings, dilations, strides, 1, subm); - DenseTensor h_features_grad = phi::EmptyLike(dev_ctx_cpu, grads[0]); - phi::Copy(dev_ctx_gpu, grads[0], phi::CPUPlace(), true, &h_features_grad); + DenseTensor d_features_grad = std::get<0>(grads).non_zero_elements(); + DenseTensor d_kernel_grad = std::get<1>(grads); + DenseTensor h_features_grad = + phi::EmptyLike(dev_ctx_cpu, d_features_grad); + phi::Copy( + dev_ctx_gpu, d_features_grad, phi::CPUPlace(), true, &h_features_grad); f_verify(h_features_grad.data(), features_grad); - DenseTensor h_kernel_grad = phi::EmptyLike(dev_ctx_cpu, grads[1]); - phi::Copy(dev_ctx_gpu, grads[1], phi::CPUPlace(), true, &h_kernel_grad); + DenseTensor h_kernel_grad = phi::EmptyLike(dev_ctx_cpu, d_kernel_grad); + phi::Copy( + dev_ctx_gpu, std::get<1>(grads), phi::CPUPlace(), true, &h_kernel_grad); f_verify(h_kernel_grad.data(), kernel_grad); } #endif diff --git a/python/paddle/fluid/tests/unittests/test_sparse_conv_op.py b/python/paddle/fluid/tests/unittests/test_sparse_conv_op.py new file mode 100644 index 0000000000000..075806a93b07d --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_sparse_conv_op.py @@ -0,0 +1,54 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function +import unittest +import numpy as np +import paddle +from paddle import _C_ops +from paddle.fluid import core +from paddle.fluid.framework import _test_eager_guard + + +class TestSparseConv(unittest.TestCase): + def test_conv3d(self): + with _test_eager_guard(): + kernel = [[[[[1], [1], [1]], [[1], [1], [1]], [[1], [1], [1]]]]] + dense_kernel = paddle.to_tensor( + kernel, dtype='float32', stop_gradient=False) + dense_kernel = paddle.reshape(dense_kernel, [1, 3, 3, 1, 1]) + paddings = [0, 0, 0] + strides = [1, 1, 1] + dilations = [1, 1, 1] + + indices = [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 1, 2], [1, 3, 2, 3]] + values = [1, 2, 3, 4] + indices = paddle.to_tensor(indices, dtype='int32') + values = paddle.to_tensor(values, dtype='float32') + dense_shape = [1, 1, 3, 4, 1] + correct_out_values = [[4], [10]] + sparse_input = core.eager.sparse_coo_tensor(indices, values, + dense_shape, False) + out = _C_ops.final_state_sparse_conv3d(sparse_input, dense_kernel, + paddings, dilations, strides, + 1, False) + out.backward(out) + # At present, only backward can be verified to work normally + # TODO(zhangkaihuo): compare the result with dense conv + print(sparse_input.grad.non_zero_elements()) + assert np.array_equal(correct_out_values, + out.non_zero_elements().numpy()) + + +# TODO: Add more test cases diff --git a/python/paddle/utils/code_gen/sparse_bw_api.yaml b/python/paddle/utils/code_gen/sparse_bw_api.yaml index 1f474d56a9022..7ffc906b22084 100644 --- a/python/paddle/utils/code_gen/sparse_bw_api.yaml +++ b/python/paddle/utils/code_gen/sparse_bw_api.yaml @@ -1,7 +1,7 @@ - backward_api : conv3d_grad forward : conv3d (Tensor x, Tensor kernel, int[] paddings, int[] dilations, int[] strides, int groups, bool subm) -> Tensor(out@SparseCooTensor), Tensor(rulebook@DenseTensor) args : (Tensor x, Tensor kernel, Tensor rulebook, Tensor out_grad, int[] paddings, int[] dilations, int[] strides, int groups, bool subm) - output : Tensor(x_grad@DenseTensor), Tensor(kernel_grad@DenseTensor) + output : Tensor(x_grad@SparseCooTensor), Tensor(kernel_grad@DenseTensor) kernel : func : sparse_conv3d_grad From 2012aeb6a03ca7bc313e0aeb0d56229d44ec18c7 Mon Sep 17 00:00:00 2001 From: Wilber Date: Sat, 2 Apr 2022 10:05:05 +0800 Subject: [PATCH 03/93] add trt pool and ut (#41258) --- paddle/infrt/dialect/tensorrt/convert.h | 156 +++++++++++++++++- .../infrt/dialect/tensorrt/pd_lower_to_trt.td | 3 +- .../dialect/tensorrt/trt_op_converter_pass.cc | 27 --- paddle/infrt/dialect/tensorrt/trt_ops.td | 5 +- paddle/infrt/kernel/tensorrt/trt_helper.h | 10 +- paddle/infrt/kernel/tensorrt/trt_kernels.cc | 3 + paddle/infrt/kernel/tensorrt/trt_layers.h | 56 ++++++- .../tests/dialect/tensorrt/disabled_trt.mlir | 37 ----- .../tensorrt/disabled_trt_activation.mlir | 21 +++ .../dialect/tensorrt/disabled_trt_fc.mlir | 69 +++----- .../dialect/tensorrt/disabled_trt_pool.mlir | 21 +++ 11 files changed, 289 insertions(+), 119 deletions(-) delete mode 100644 paddle/infrt/tests/dialect/tensorrt/disabled_trt.mlir create mode 100644 paddle/infrt/tests/dialect/tensorrt/disabled_trt_activation.mlir create 
mode 100644 paddle/infrt/tests/dialect/tensorrt/disabled_trt_pool.mlir diff --git a/paddle/infrt/dialect/tensorrt/convert.h b/paddle/infrt/dialect/tensorrt/convert.h index fc607aa112714..c1f87ecde7872 100644 --- a/paddle/infrt/dialect/tensorrt/convert.h +++ b/paddle/infrt/dialect/tensorrt/convert.h @@ -14,17 +14,49 @@ #pragma once #include +#include +#include #include +#include +#include #include - #include "paddle/infrt/dialect/infrt/common/types.h" #include "paddle/infrt/dialect/infrt/ir/infrt_dialect.h" #include "paddle/infrt/dialect/pd/ir/pd_ops.h" #include "paddle/infrt/dialect/phi/ir/infrt_phi_tensor.h" #include "paddle/infrt/dialect/tensorrt/trt_ops.h" +#include "paddle/infrt/kernel/tensorrt/trt_helper.h" namespace infrt { namespace trt { + +#ifdef INFRT_WITH_TRT + +#define STRING_TO_ENUM_TYPE(enum_type) enum_type +#define STRING_TO_ENUM_VALUE(enum_value) enum_value +#include + +#else // INFRT_WITH_TRT + +#define STRING_TO_ENUM_TYPE(enum_type) std::string +#define STRING_TO_ENUM_VALUE(enum_value) #enum_value + +#endif // INFRT_WITH_TRT + +template <typename T> +::mlir::IntegerAttr createNvinferEnumAttr( + ::mlir::PatternRewriter &rewriter, // NOLINT + T enum_value) { + return rewriter.getSI32IntegerAttr((int32_t)enum_value); +} + +template <> +::mlir::IntegerAttr createNvinferEnumAttr( + ::mlir::PatternRewriter &rewriter, std::string enum_value) { // NOLINT + (void)enum_value; + return rewriter.getSI32IntegerAttr(-1); +} + static mlir::Value createTRTConv2dOp(mlir::PatternRewriter &rewriter, // NOLINT mlir::Operation *op) { auto conv_op = ::llvm::dyn_cast(op); @@ -205,5 +237,127 @@ static mlir::Value createTRTShuffledOp( return rewriter.create( op->getLoc(), resultTypes, operands, attributes); } + +inline mlir::IntegerAttr CreatePoolingType( + mlir::PatternRewriter &builder, // NOLINT + mlir::StringAttr pool_type) { + // Map the pool_type string to the nvinfer enum. + auto ptype = pool_type.str(); + if (ptype == "max") { + return createNvinferEnumAttr(builder, nvinfer1::PoolingType::kMAX); + } else if (ptype == "avg") { + return createNvinferEnumAttr(builder, nvinfer1::PoolingType::kAVERAGE); + } else { + llvm_unreachable("unknown pool_type."); + return {}; + } +} + +inline mlir::IntegerAttr CreatePaddingMode( + mlir::PatternRewriter &builder, // NOLINT + mlir::StringAttr padding_algorithm, + mlir::BoolAttr ceil_mode) { + // TODO(Inference): Phi pool kernel seems not to process ceil_mode. + auto padding_algo = padding_algorithm.str(); + if (padding_algo == "SAME") { + return createNvinferEnumAttr(builder, nvinfer1::PaddingMode::kSAME_UPPER); + } + if (ceil_mode.getValue() && padding_algo != "SAME") { + return createNvinferEnumAttr(builder, + nvinfer1::PaddingMode::kEXPLICIT_ROUND_UP); + } else { + return createNvinferEnumAttr(builder, + nvinfer1::PaddingMode::kEXPLICIT_ROUND_DOWN); + } +} + +inline ::llvm::SmallVector<::mlir::Value, 4> CreatePaddleTrtPoolingOp( + mlir::PatternRewriter &builder, // NOLINT + mlir::Value input, + mlir::StringAttr pool_type, + mlir::ArrayAttr ksize, + mlir::BoolAttr global_pooling, + mlir::ArrayAttr strides, + mlir::ArrayAttr paddings, + mlir::BoolAttr exclusive, + mlir::BoolAttr adaptive, + mlir::BoolAttr ceil_mode, + mlir::StringAttr data_format, + mlir::StringAttr padding_algorithm) { + ::llvm::SmallVector<::mlir::Value, 4> tblgen_repl_values; + + // TODO(inference): Support NHWC. + if (data_format.str() != "NCHW") { + CHECK(false) << "The pool2d converter now only supports NCHW."; + } + + // TODO(Wilber): How to support dynamic shape? 
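+  // A minimal illustration of how the helpers above compose (names are the
+  // ones defined in this header; the nvinfer1 enums are assumed available
+  // under INFRT_WITH_TRT):
+  //
+  //   mlir::IntegerAttr ptype =
+  //       CreatePoolingType(builder, builder.getStringAttr("max"));
+  //   // ptype now wraps (int32_t)nvinfer1::PoolingType::kMAX as an si32
+  //   // IntegerAttr, or -1 in a build without TensorRT (see the
+  //   // std::string specialization of createNvinferEnumAttr above).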
+ + auto *input_producer = input.getDefiningOp(); + + // Process pool_type. + auto pool_type_attr = CreatePoolingType(builder, pool_type); + + // Update padding. + auto padding_algorithm_str = padding_algorithm.str(); + auto paddings_attr = paddings; + if (padding_algorithm_str == "EXPLICIT") { + // Do nothing on paddings. + } else if (padding_algorithm_str == "SAME") { + // We should process this case in trt network build phase. + } else if (padding_algorithm_str == "VALID") { + // Set padding to zero. + paddings_attr = builder.getI32ArrayAttr({0, 0}); + } else { + CHECK(false) << "Unknown padding_algorithm."; + } + + // if global_pooling == true or adaptive == true, padding will be ignored + if (global_pooling.getValue() || adaptive.getValue()) { + paddings_attr = builder.getI32ArrayAttr({0, 0}); + } + + // if global_pooling == true, then we should update the kernel size to the input dims. + if (global_pooling.getValue() == true) { + // Update ksize to input dims. + } + + // The adaptive logic should be processed when we get the context of + // INetworkDefinition, so we place the logic in infrt runtime (trt compile + // time). + + // The `exclusive` may be a naive attr, which can be forwarded to trt. + + auto padding_mode_attr = + CreatePaddingMode(builder, padding_algorithm, ceil_mode); + + if (global_pooling.getValue() == true) { + CHECK(false) << "Global pooling is temporarily not supported"; + return tblgen_repl_values; + } + + PoolingOp pool_op; + { + auto ods_loc = builder.getFusedLoc({input_producer->getLoc()}); + builder.create<PoolingOp>(ods_loc, + input.getType(), + input, + pool_type_attr, + ksize, + strides, + paddings_attr, + padding_mode_attr, + exclusive, + adaptive, + padding_algorithm); + } + + for (auto v : + ::llvm::SmallVector<::mlir::Value, 4>{pool_op.getODSResults(0)}) { + tblgen_repl_values.push_back(v); + } + return tblgen_repl_values; +} + } // namespace trt } // namespace infrt diff --git a/paddle/infrt/dialect/tensorrt/pd_lower_to_trt.td b/paddle/infrt/dialect/tensorrt/pd_lower_to_trt.td index ad60906ececbf..227b473c3fc19 100644 --- a/paddle/infrt/dialect/tensorrt/pd_lower_to_trt.td +++ b/paddle/infrt/dialect/tensorrt/pd_lower_to_trt.td @@ -31,9 +31,10 @@ def PD2TRT_Conv2d_Lower : Pat< (PD_Conv2dOp:$old_value $Input, $Filter, $strides, $paddings, $padding_algorithm, $groups, $dilations, $data_format), (createTRTConv2dOp $old_value)>; +def createTrtPoolingOp : NativeCodeCall<"::infrt::trt::CreatePaddleTrtPoolingOp($_builder, $0, $1, $2, $3, $4, $5, $6, $7, $8, $9, $10)">; def PD2TRT_Pooling_Lower : Pat< (PD_Pool2dOp $Input, $pooling_type, $ksize, $global_pooling, $strides, $paddings, $exclusive, $adaptive, $ceil_mode, $data_format, $padding_algorithm), - (TRT_PoolingOp $Input, (INFRT_createI32Attr<"0">)/*kmax*/, $ksize, $strides, $paddings, $padding_algorithm)>; + (createTrtPoolingOp $Input, $pooling_type, $ksize, $global_pooling, $strides, $paddings, $exclusive, $adaptive, $ceil_mode, $data_format, $padding_algorithm)>; def PD2TRT_MatrixMultipl_Lower : Pat< (PD_MulOp $Input1, $Input2, $x_num_col_dims, $y_num_col_dims), diff --git a/paddle/infrt/dialect/tensorrt/trt_op_converter_pass.cc b/paddle/infrt/dialect/tensorrt/trt_op_converter_pass.cc index 6bcef3d913d79..95dd31fcd5838 100644 --- a/paddle/infrt/dialect/tensorrt/trt_op_converter_pass.cc +++ b/paddle/infrt/dialect/tensorrt/trt_op_converter_pass.cc @@ -28,33 +28,6 @@ namespace infrt { namespace trt { -#ifdef INFRT_WITH_TRT - -#define STRING_TO_ENUM_TYPE(enum_type) enum_type -#define STRING_TO_ENUM_VALUE(enum_value) enum_value -#include 
- -#else // INFRT_WITH_TRT - -#define STRING_TO_ENUM_TYPE(enum_type) std::string -#define STRING_TO_ENUM_VALUE(enum_value) #enum_value - -#endif // INFRT_WITH_TRT - -template -::mlir::IntegerAttr createNvinferEnumAttr( - ::mlir::PatternRewriter &rewriter, // NOLINT - T enum_value) { - return rewriter.getSI32IntegerAttr((int32_t)enum_value); -} - -template <> -::mlir::IntegerAttr createNvinferEnumAttr( - ::mlir::PatternRewriter &rewriter, std::string enum_value) { // NOLINT - (void)enum_value; - return rewriter.getSI32IntegerAttr(-1); -} - #include "paddle/infrt/dialect/tensorrt/pd_lower_to_trt.cpp.inc" // NOLINT struct PD2TRT_GraphLower : public ::mlir::RewritePattern { diff --git a/paddle/infrt/dialect/tensorrt/trt_ops.td b/paddle/infrt/dialect/tensorrt/trt_ops.td index 3fd3f377f4ec7..68a593e440b50 100755 --- a/paddle/infrt/dialect/tensorrt/trt_ops.td +++ b/paddle/infrt/dialect/tensorrt/trt_ops.td @@ -101,7 +101,10 @@ def TRT_PoolingOp : TRT_Op<"Pooling", [NoSideEffect]> { I32ArrayAttr:$window_size, I32ArrayAttr:$strides, I32ArrayAttr:$paddings, - StrAttr:$padding_mode + I32Attr:$padding_mode, + BoolAttr:$exclusive, + BoolAttr:$adaptive, + StrAttr:$padding_algorithm ); let results = (outs DenseTensor:$output_tensor diff --git a/paddle/infrt/kernel/tensorrt/trt_helper.h b/paddle/infrt/kernel/tensorrt/trt_helper.h index 96122bffacdb2..13529430d683d 100644 --- a/paddle/infrt/kernel/tensorrt/trt_helper.h +++ b/paddle/infrt/kernel/tensorrt/trt_helper.h @@ -28,13 +28,13 @@ namespace infrt { namespace kernel { namespace tensorrt { -static nvinfer1::DataType TensorTypeToWeightType(phi::DataType tensor_type) { +static nvinfer1::DataType TensorTypeToWeightType(::phi::DataType tensor_type) { switch (tensor_type) { - case phi::DataType::FLOAT32: + case ::phi::DataType::FLOAT32: return nvinfer1::DataType::kFLOAT; - case phi::DataType::INT32: + case ::phi::DataType::INT32: return nvinfer1::DataType::kINT32; - case phi::DataType::FLOAT16: + case ::phi::DataType::FLOAT16: return nvinfer1::DataType::kHALF; default: llvm_unreachable("should not reach here"); @@ -52,7 +52,7 @@ static nvinfer1::Dims ArrayAttrToNvDims(const mlir::ArrayAttr& int_array_attr) { return dims; } -static nvinfer1::Weights TensorToWeights(phi::DenseTensor* tensor) { +static nvinfer1::Weights TensorToWeights(::phi::DenseTensor* tensor) { CHECK_NOTNULL(tensor); nvinfer1::Weights ret; ret.type = TensorTypeToWeightType(tensor->dtype()); diff --git a/paddle/infrt/kernel/tensorrt/trt_kernels.cc b/paddle/infrt/kernel/tensorrt/trt_kernels.cc index a6d740f01846d..92e3a624bb021 100644 --- a/paddle/infrt/kernel/tensorrt/trt_kernels.cc +++ b/paddle/infrt/kernel/tensorrt/trt_kernels.cc @@ -129,6 +129,7 @@ ::infrt::backends::tensorrt::TrtEngine CreateTrtEngine( // TODO(wilber): Find a way to add layer. 
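// The loop below is the op-dispatch point: each MLIR operation is matched via
// llvm::dyn_cast and handed to its layer builder, which is how PoolFunc gets
// wired in by this patch. A hedged sketch of extending it for some new op,
// where trt::FooOp and FooFunc are hypothetical names rather than real infrt
// symbols:
//
//   } else if (trt::FooOp op = llvm::dyn_cast<trt::FooOp>(operation)) {
//     // Build the corresponding TensorRT layer from the matched op.
//     FooFunc(op, network.get(), value_to_trt_tensor_map, value_to_tensor_map);
//   }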
for (auto& operation : block.without_terminator()) { + VLOG(1) << "process " << operation.getName().getStringRef().str() << " ..."; if (trt::ActivationOp op = llvm::dyn_cast(operation)) { ActivationFunc( op, network.get(), value_to_trt_tensor_map, value_to_tensor_map); @@ -138,6 +139,8 @@ ::infrt::backends::tensorrt::TrtEngine CreateTrtEngine( } else if (trt::ConvolutionOp op = llvm::dyn_cast(operation)) { ConvFunc(op, network.get(), value_to_trt_tensor_map, value_to_tensor_map); + } else if (trt::PoolingOp op = llvm::dyn_cast(operation)) { + PoolFunc(op, network.get(), value_to_trt_tensor_map, value_to_tensor_map); } else { CHECK(false) << "not supported operation."; } diff --git a/paddle/infrt/kernel/tensorrt/trt_layers.h b/paddle/infrt/kernel/tensorrt/trt_layers.h index 19e20c170ec83..3a300ad0c10af 100644 --- a/paddle/infrt/kernel/tensorrt/trt_layers.h +++ b/paddle/infrt/kernel/tensorrt/trt_layers.h @@ -15,13 +15,15 @@ #pragma once #include +#include +#include #include +#include #include #include "paddle/infrt/dialect/tensorrt/trt_ops.h" #include "paddle/infrt/kernel/tensorrt/trt_helper.h" - #include "paddle/phi/core/dense_tensor.h" namespace infrt { @@ -63,7 +65,12 @@ inline void ConvFunc(trt::ConvolutionOp& op, // NOLINT nvinfer1::Dims dims = ArrayAttrToNvDims(size_attrs); auto kernel_weights = TensorToWeights(value_to_tensor_map[op.kernel_weights()]); - auto bias_weights = TensorToWeights(value_to_tensor_map[op.bias_weights()]); + nvinfer1::Weights bias_weights; + if (op.bias_weights() == mlir::Value()) { + bias_weights = nvinfer1::Weights{}; + } else { + bias_weights = TensorToWeights(value_to_tensor_map[op.bias_weights()]); + } auto* layer = network->addConvolutionNd(*value_to_trt_tensor_map[input_tensor_repr], @@ -77,6 +84,51 @@ inline void ConvFunc(trt::ConvolutionOp& op, // NOLINT value_to_trt_tensor_map[out_repr] = out_tensor; } +inline void PoolFunc(trt::PoolingOp& op, // NOLINT + nvinfer1::INetworkDefinition* network, + ValueToITensorMap& value_to_trt_tensor_map, // NOLINT + ValueToTensorMap& value_to_tensor_map) { // NOLINT + mlir::Value input_tensor_repr = op.input_tensor(); + nvinfer1::ITensor* input_itensor = value_to_trt_tensor_map[input_tensor_repr]; + // nvinfer1::Dims input_shape = input_itensor->getDimensions(); + // int input_dims = input_shape.nbDims; + + auto padding_mode = op.padding_mode(); + auto pool_type = op.pool_type(); + mlir::ArrayAttr paddings = op.paddings(); + mlir::ArrayAttr strides = op.strides(); + mlir::ArrayAttr ksize = op.window_size(); + bool exclusive = op.exclusive(); + bool adaptive = op.adaptive(); + auto padding_algorithm = op.padding_algorithm().str(); + + if (padding_algorithm == "SAME") { + // TODO(wilber) + CHECK(false) << "Not supported `same` padding algorithm"; + } + + if (adaptive) { + // TODO(Inference) + CHECK(false) << "Not supported adaptive pool"; + } + + nvinfer1::Dims window_size = ArrayAttrToNvDims(ksize); + + auto* layer = + network->addPoolingNd(*input_itensor, + static_cast(pool_type), + window_size); + CHECK_NOTNULL(layer); + layer->setPaddingMode(static_cast(padding_mode)); + layer->setPaddingNd(ArrayAttrToNvDims(paddings)); + layer->setStrideNd(ArrayAttrToNvDims(strides)); + layer->setAverageCountExcludesPadding(exclusive); + + mlir::Value out_repr = op.output_tensor(); + nvinfer1::ITensor* out_tensor = layer->getOutput(0); + value_to_trt_tensor_map[out_repr] = out_tensor; +} + inline void FcFunc(trt::FullyConnectedOp& op, // NOLINT nvinfer1::INetworkDefinition* network, ValueToITensorMap& value_to_trt_tensor_map, // 
NOLINT diff --git a/paddle/infrt/tests/dialect/tensorrt/disabled_trt.mlir b/paddle/infrt/tests/dialect/tensorrt/disabled_trt.mlir deleted file mode 100644 index ef86dcf1e72a0..0000000000000 --- a/paddle/infrt/tests/dialect/tensorrt/disabled_trt.mlir +++ /dev/null @@ -1,37 +0,0 @@ -// RUN: infrtexec -i %s | FileCheck %s - -// CHECK-LABEL: @run_trt -func @run_trt(%0 : !infrt.dense_tensor, %ctx : !phi.context) { - %a = "trt.create_engine"(%0) ({ - %1 = "trt.Activation"(%0) {activation_type = 1 : si32, alpha = 1.0 : f32, beta = 6.0 : f32} : (!infrt.dense_tensor) -> !infrt.dense_tensor - "infrt.return"(%1) : (!infrt.dense_tensor) -> () - }) : (!infrt.dense_tensor) -> !trt.engine - "trt.inspect_engine"(%a) {} : (!trt.engine) -> () - - %res = "trt.compute"(%a, %ctx) {} : (!trt.engine, !phi.context) -> (!infrt.tensor_list) - %size = "dt.tensor_list_get_size"(%res) {} : (!infrt.tensor_list) -> (i32) - "infrt.print.i32"(%size) {} : (i32) -> () - - %ts0 = "dt.tensor_list_get_tensor"(%res) {id = 0 : i32} : (!infrt.tensor_list) -> (!infrt.dense_tensor) - "phi_dt.print_tensor" (%ts0) : (!infrt.dense_tensor) -> () - - infrt.return -} - -// CHECK-LABEL: @main -func @main() { - %ctx = "phi_dt.create_context.gpu" (): () -> !phi.context - %t = "phi_dt.create_dense_tensor.gpu" (%ctx) { - precision=#infrt.precision, - layout=#infrt.layout, - dims=[1:i64, 3:i64, 1:i64, 1:i64], lod=[1:i64]}: (!phi.context) -> (!infrt.dense_tensor) - - "phi_dt.fill_dense_tensor.f32"(%t) {value=[3.8:f32, 2.4:f32, 1.3:f32]} : (!infrt.dense_tensor) -> () - "phi_dt.print_tensor" (%t) : (!infrt.dense_tensor) -> () - - //%res = - infrt.call @run_trt(%t, %ctx) : (!infrt.dense_tensor, !phi.context) -> () - //-> (!infrt.dense_tensor) - - infrt.return -} diff --git a/paddle/infrt/tests/dialect/tensorrt/disabled_trt_activation.mlir b/paddle/infrt/tests/dialect/tensorrt/disabled_trt_activation.mlir new file mode 100644 index 0000000000000..557990677696e --- /dev/null +++ b/paddle/infrt/tests/dialect/tensorrt/disabled_trt_activation.mlir @@ -0,0 +1,21 @@ +module { + func @main_graph(%arg0: !infrt.dense_tensor) -> !infrt.dense_tensor { + %0 = "phi_dt.create_context.gpu"() : () -> !phi.context + %1 = "phi_dt.memcpy.gpu"(%arg0, %0) {d2h = false} : (!infrt.dense_tensor, !phi.context) -> !infrt.dense_tensor + %2 = "trt.create_engine"(%1) ( { + %6 = "trt.Activation"(%1) {activation_type = 1 : si32, alpha = 0.000000e+00 : f32, beta = 0.000000e+00 : f32} : (!infrt.dense_tensor) -> !infrt.dense_tensor + infrt.return %6 : !infrt.dense_tensor + }) {run_once = true} : (!infrt.dense_tensor) -> !trt.engine + %3 = "trt.compute"(%2, %0) : (!trt.engine, !phi.context) -> !infrt.tensor_list + %4 = "dt.tensor_list_get_tensor"(%3) {id = 0 : i32} : (!infrt.tensor_list) -> !infrt.dense_tensor + %5 = "phi_dt.memcpy.gpu"(%4, %0) {d2h = true} : (!infrt.dense_tensor, !phi.context) -> !infrt.dense_tensor + infrt.return %5 : !infrt.dense_tensor + } + func @main() { + %0 = "phi_dt.create_context.cpu"() : () -> !phi.context + %1 = "phi_dt.create_inited_dense_tensor.cpu.f32"(%0) {dims = [3, 6, 1, 1], layout = #infrt.layout, lod = [0], value = 1.500000e+00 : f32} : (!phi.context) -> !infrt.dense_tensor + %2 = infrt.call @main_graph(%1) : (!infrt.dense_tensor) -> !infrt.dense_tensor + phi_dt.print_tensor(%2 : !infrt.dense_tensor) + infrt.return + } +} diff --git a/paddle/infrt/tests/dialect/tensorrt/disabled_trt_fc.mlir b/paddle/infrt/tests/dialect/tensorrt/disabled_trt_fc.mlir index 78dc4ac1c1093..aba706df71843 100644 --- 
a/paddle/infrt/tests/dialect/tensorrt/disabled_trt_fc.mlir +++ b/paddle/infrt/tests/dialect/tensorrt/disabled_trt_fc.mlir @@ -1,46 +1,25 @@ -// RUN: infrtexec -i %s | FileCheck %s - -// CHECK-LABEL: @main -func @main() { - %ctx = "phi_dt.create_context.gpu" (): () -> !phi.context - %cpu_ctx = "phi_dt.create_context.cpu" (): () -> !phi.context - - %input_tensor = "phi_dt.create_dense_tensor.gpu" (%ctx) { - precision=#infrt.precision, - layout=#infrt.layout, - dims=[1:i64, 3:i64, 1:i64, 1:i64], lod=[1:i64]}: (!phi.context) -> (!infrt.dense_tensor) - "phi_dt.fill_dense_tensor.f32"(%input_tensor) {value=[3.8:f32, 2.4:f32, 1.3:f32]} : (!infrt.dense_tensor) -> () - //"phi_dt.print_tensor" (%input_tensor) : (!infrt.dense_tensor) -> () - - %kernel_weight = "phi_dt.create_dense_tensor.cpu"(%cpu_ctx) { - precision=#infrt.precision, - layout=#infrt.layout, - dims=[2:i64, 3:i64], lod=[1:i64]} : (!phi.context) -> (!infrt.dense_tensor) - "phi_dt.fill_dense_tensor.f32"(%kernel_weight) {value=[1.:f32, 2.:f32, 3.:f32, 4.:f32, 5.:f32, 6.:f32]} : (!infrt.dense_tensor) -> () - //"phi_dt.print_tensor" (%kernel_weight) : (!infrt.dense_tensor) -> () - - %kernel_bias = "phi_dt.create_dense_tensor.cpu"(%cpu_ctx) { - precision=#infrt.precision, - layout=#infrt.layout, - dims=[2:i64], lod=[1:i64]} : (!phi.context) -> (!infrt.dense_tensor) - "phi_dt.fill_dense_tensor.f32"(%kernel_bias) {value=[1.:f32, 2.:f32]} : (!infrt.dense_tensor) -> () - //"phi_dt.print_tensor" (%kernel_bias) : (!infrt.dense_tensor) -> () - - %engine = "trt.create_engine"(%input_tensor, %kernel_weight, %kernel_bias) ({ - %1 = "trt.Activation"(%input_tensor) {activation_type = 1 : si32, alpha = 1.0 : f32, beta = 6.0 : f32} : (!infrt.dense_tensor) -> !infrt.dense_tensor - %2 = "trt.FullyConnected"(%input_tensor, %kernel_weight, %kernel_bias) {out_channel_num = 2 : si32} : (!infrt.dense_tensor, !infrt.dense_tensor, !infrt.dense_tensor) -> !infrt.dense_tensor - "infrt.return"(%1, %2) : (!infrt.dense_tensor, !infrt.dense_tensor) -> () - }) : (!infrt.dense_tensor, !infrt.dense_tensor, !infrt.dense_tensor) -> !trt.engine - - %res = "trt.compute"(%engine, %ctx) {} : (!trt.engine, !phi.context) -> (!infrt.tensor_list) - %size = "dt.tensor_list_get_size"(%res) {} : (!infrt.tensor_list) -> (i32) - "infrt.print.i32"(%size) {} : (i32) -> () - - %ts0 = "dt.tensor_list_get_tensor"(%res) {id = 0 : i32} : (!infrt.tensor_list) -> (!infrt.dense_tensor) - "phi_dt.print_tensor" (%ts0) : (!infrt.dense_tensor) -> () - - %ts1 = "dt.tensor_list_get_tensor"(%res) {id = 1 : i32} : (!infrt.tensor_list) -> (!infrt.dense_tensor) - "phi_dt.print_tensor" (%ts1) : (!infrt.dense_tensor) -> () - - infrt.return +module { + func @main_graph(%arg0: !infrt.dense_tensor) -> !infrt.dense_tensor { + %ctx = "phi_dt.create_context.cpu" (): () -> !phi.context + %0 = "phi_dt.create_context.gpu"() : () -> !phi.context + %1 = "phi_dt.memcpy.gpu"(%arg0, %0) {d2h = false} : (!infrt.dense_tensor, !phi.context) -> !infrt.dense_tensor + %4 = "phi_dt.create_inited_dense_tensor.cpu.f32" (%ctx) {value=1.5:f32, layout=#infrt.layout, lod=[0], dims=[2, 6]}: (!phi.context) -> (!infrt.dense_tensor) + %3 = "phi_dt.create_inited_dense_tensor.cpu.f32" (%ctx) {value=1.5:f32, layout=#infrt.layout, lod=[0], dims=[2]}: (!phi.context) -> (!infrt.dense_tensor) + %5 = "trt.create_engine"(%1, %4, %3) ( { + %10 = "trt.FullyConnected"(%1, %4, %3) {out_channel_num = 2 : si32} : (!infrt.dense_tensor, !infrt.dense_tensor, !infrt.dense_tensor) -> !infrt.dense_tensor + infrt.return %10 : !infrt.dense_tensor + }) {run_once = 
true} : (!infrt.dense_tensor, !infrt.dense_tensor, !infrt.dense_tensor) -> !trt.engine + %6 = "trt.compute"(%5, %0) : (!trt.engine, !phi.context) -> !infrt.tensor_list + %7 = "dt.tensor_list_get_tensor"(%6) {id = 0 : i32} : (!infrt.tensor_list) -> !infrt.dense_tensor + %8 = "phi_dt.memcpy.gpu"(%7, %0) {d2h = true} : (!infrt.dense_tensor, !phi.context) -> !infrt.dense_tensor + infrt.return %8 : !infrt.dense_tensor + } + + func @main() { + %ctx = "phi_dt.create_context.cpu" (): () -> !phi.context + %input_tensor = "phi_dt.create_inited_dense_tensor.cpu.f32" (%ctx) {value=1.5:f32, layout=#infrt.layout, lod=[0], dims=[3, 6, 1, 1]}: (!phi.context) -> (!infrt.dense_tensor) + %res = infrt.call @main_graph(%input_tensor) {} : (!infrt.dense_tensor) -> !infrt.dense_tensor + "phi_dt.print_tensor" (%res) : (!infrt.dense_tensor) -> () + infrt.return + } } diff --git a/paddle/infrt/tests/dialect/tensorrt/disabled_trt_pool.mlir b/paddle/infrt/tests/dialect/tensorrt/disabled_trt_pool.mlir new file mode 100644 index 0000000000000..af24ac63d23fe --- /dev/null +++ b/paddle/infrt/tests/dialect/tensorrt/disabled_trt_pool.mlir @@ -0,0 +1,21 @@ +module { + func @main_graph(%arg0: !infrt.dense_tensor) -> !infrt.dense_tensor { + %0 = "phi_dt.create_context.gpu"() : () -> !phi.context + %1 = "phi_dt.memcpy.gpu"(%arg0, %0) {d2h = false} : (!infrt.dense_tensor, !phi.context) -> !infrt.dense_tensor + %2 = "trt.create_engine"(%1) ( { + %6 = "trt.Pooling"(%1) {padding_mode = 0 : i32, paddings = [1 : i32, 1 : i32], pool_type = 0 : i32, strides = [2 : i32, 2 : i32], window_size = [3 : i32, 3 : i32], exclusive = false, adaptive = false, padding_algorithm = "EXPLICIT"} : (!infrt.dense_tensor) -> !infrt.dense_tensor + infrt.return %6 : !infrt.dense_tensor + }) {run_once = true} : (!infrt.dense_tensor) -> !trt.engine + %3 = "trt.compute"(%2, %0) : (!trt.engine, !phi.context) -> !infrt.tensor_list + %4 = "dt.tensor_list_get_tensor"(%3) {id = 0 : i32} : (!infrt.tensor_list) -> !infrt.dense_tensor + %5 = "phi_dt.memcpy.gpu"(%4, %0) {d2h = true} : (!infrt.dense_tensor, !phi.context) -> !infrt.dense_tensor + infrt.return %5 : !infrt.dense_tensor + } + func @main() { + %0 = "phi_dt.create_context.cpu"() : () -> !phi.context + %1 = "phi_dt.create_inited_dense_tensor.cpu.f32"(%0) {dims = [1, 3, 10, 10], layout = #infrt.layout, lod = [0], value = 1.500000e+00 : f32} : (!phi.context) -> !infrt.dense_tensor + %2 = infrt.call @main_graph(%1) : (!infrt.dense_tensor) -> !infrt.dense_tensor + phi_dt.print_tensor(%2 : !infrt.dense_tensor) + infrt.return + } +} From 16bfcd18ada44866104b265f9970aeaaed389b34 Mon Sep 17 00:00:00 2001 From: xiongkun Date: Sat, 2 Apr 2022 10:15:35 +0800 Subject: [PATCH 04/93] [Yaml] transfer around 22 ops yaml file and pass the final state OpTest. (#41024) * 1. add the python api grad 2. add final and intermediate state vlog 3. 
change the python_api error logic * add python api or close the check_eager=True * fix the compatibility --- paddle/fluid/pybind/eager_utils.cc | 2 +- paddle/phi/infermeta/binary.cc | 2 +- paddle/phi/infermeta/binary.h | 4 +- paddle/phi/kernels/cpu/allclose_kernel.cc | 33 ++- .../phi/kernels/cpu/kthvalue_grad_kernel.cc | 2 +- paddle/phi/kernels/cpu/prelu_grad_kernel.cc | 2 +- paddle/phi/kernels/cpu/prelu_kernel.cc | 2 +- paddle/phi/kernels/gpu/allclose_kernel.cu | 33 ++- .../phi/kernels/gpu/kthvalue_grad_kernel.cu | 2 +- paddle/phi/kernels/gpu/prelu_grad_kernel.cu | 2 +- paddle/phi/kernels/gpu/prelu_kernel.cu | 2 +- .../kernels/impl/lgamma_grad_kernel_impl.h | 2 +- paddle/phi/kernels/kldiv_loss_grad_kernel.h | 1 - paddle/phi/kernels/kthvalue_grad_kernel.h | 2 +- paddle/phi/kernels/lgamma_grad_kernel.h | 2 +- paddle/phi/kernels/prelu_grad_kernel.h | 2 +- paddle/phi/kernels/prelu_kernel.h | 2 +- paddle/phi/ops/compat/kthvalue_sig.cc | 2 +- paddle/phi/ops/compat/lgamma_sig.cc | 2 +- paddle/phi/ops/compat/prelu_sig.cc | 8 +- .../fluid/layers/layer_function_generator.py | 1 + python/paddle/fluid/layers/nn.py | 15 +- .../tests/unittests/test_activation_op.py | 47 +++- .../fluid/tests/unittests/test_allclose_op.py | 11 +- .../fluid/tests/unittests/test_complex_abs.py | 12 +- .../fluid/tests/unittests/test_cumprod_op.py | 8 +- .../fluid/tests/unittests/test_fmax_op.py | 34 ++- .../fluid/tests/unittests/test_fmin_op.py | 34 ++- .../fluid/tests/unittests/test_gather_op.py | 15 +- .../fluid/tests/unittests/test_isclose_op.py | 13 +- .../tests/unittests/test_kldiv_loss_op.py | 7 +- .../fluid/tests/unittests/test_kthvalue_op.py | 10 +- .../fluid/tests/unittests/test_lgamma_op.py | 8 +- .../fluid/tests/unittests/test_log_softmax.py | 13 +- .../fluid/tests/unittests/test_max_op.py | 5 + .../fluid/tests/unittests/test_mean_op.py | 23 +- .../fluid/tests/unittests/test_min_op.py | 5 + .../fluid/tests/unittests/test_mode_op.py | 10 +- .../fluid/tests/unittests/test_norm_all.py | 31 ++- .../fluid/tests/unittests/test_normalize.py | 16 ++ .../fluid/tests/unittests/test_pad3d_op.py | 5 +- .../fluid/tests/unittests/test_prelu_op.py | 11 +- .../fluid/tests/unittests/test_reduce_op.py | 54 ++-- .../fluid/tests/unittests/test_squeeze2_op.py | 8 +- .../tests/unittests/test_unsqueeze2_op.py | 18 +- python/paddle/nn/functional/activation.py | 21 +- python/paddle/nn/functional/common.py | 8 +- python/paddle/nn/functional/loss.py | 7 +- python/paddle/nn/functional/norm.py | 8 +- python/paddle/nn/layer/distance.py | 8 +- python/paddle/tensor/linalg.py | 8 +- python/paddle/tensor/logic.py | 9 +- python/paddle/tensor/manipulation.py | 4 +- python/paddle/tensor/math.py | 65 ++++- python/paddle/tensor/search.py | 15 +- python/paddle/tensor/stat.py | 7 +- python/paddle/utils/code_gen/api.yaml | 253 +++++++++++++++++- python/paddle/utils/code_gen/backward.yaml | 249 +++++++++++++++++ 58 files changed, 990 insertions(+), 195 deletions(-) diff --git a/paddle/fluid/pybind/eager_utils.cc b/paddle/fluid/pybind/eager_utils.cc index bee3e27a55167..e245362c50be5 100644 --- a/paddle/fluid/pybind/eager_utils.cc +++ b/paddle/fluid/pybind/eager_utils.cc @@ -933,7 +933,7 @@ paddle::experimental::Scalar CastPyArg2Scalar(PyObject* obj, bool value = CastPyArg2Boolean(obj, op_type, arg_pos); return paddle::experimental::Scalar(value); - } else if (type_name == "paddle.Tensor") { + } else if (type_name == "Tensor") { paddle::experimental::Tensor& value = GetTensorFromPyObject( op_type, "" /*arg_name*/, obj, arg_pos, false /*dispensable*/); 
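    // Note the rename above: under the eager final-state mode this commit
    // targets, the tensor apparently reports its Python type name as plain
    // "Tensor" rather than "paddle.Tensor", so the branch matches the shorter
    // name before wrapping the fetched tensor into a Scalar below.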
return paddle::experimental::Scalar(value); diff --git a/paddle/phi/infermeta/binary.cc b/paddle/phi/infermeta/binary.cc index 8e285aba55145..44ae53a00d18e 100644 --- a/paddle/phi/infermeta/binary.cc +++ b/paddle/phi/infermeta/binary.cc @@ -1374,8 +1374,8 @@ void MvInferMeta(const MetaTensor& x, const MetaTensor& vec, MetaTensor* out) { void PReluInferMeta(const MetaTensor& x, const MetaTensor& alpha, - const std::string& mode, const std::string& data_format, + const std::string& mode, MetaTensor* out, MetaConfig config) { auto x_dim = x.dims(); diff --git a/paddle/phi/infermeta/binary.h b/paddle/phi/infermeta/binary.h index fc9d2642d9cc4..751422a4def48 100644 --- a/paddle/phi/infermeta/binary.h +++ b/paddle/phi/infermeta/binary.h @@ -196,10 +196,10 @@ void MvInferMeta(const MetaTensor& x, const MetaTensor& vec, MetaTensor* out); void PReluInferMeta(const MetaTensor& x, const MetaTensor& alpha, - const std::string& mode, const std::string& data_format, + const std::string& mode, MetaTensor* out, - MetaConfig config); + MetaConfig config = MetaConfig()); void SearchsortedInferMeta(const MetaTensor& sorted_sequence, const MetaTensor& value, diff --git a/paddle/phi/kernels/cpu/allclose_kernel.cc b/paddle/phi/kernels/cpu/allclose_kernel.cc index 7ffeadfeed8aa..80dea561956cf 100644 --- a/paddle/phi/kernels/cpu/allclose_kernel.cc +++ b/paddle/phi/kernels/cpu/allclose_kernel.cc @@ -29,21 +29,28 @@ void AllCloseKernel(const Context& dev_ctx, const Scalar& atol, bool equal_nan, DenseTensor* out) { - PADDLE_ENFORCE_EQ( - rtol.dtype(), - DataType::FLOAT64, - phi::errors::InvalidArgument( - "Input (Rtol) type must be double, but get %s.", rtol.dtype())); - PADDLE_ENFORCE_EQ( - atol.dtype(), - DataType::FLOAT64, - phi::errors::InvalidArgument( - "Input (Atol) type must be double, but get %s.", atol.dtype())); - + double rtol_v, atol_v; + if (rtol.dtype() == DataType::FLOAT64) { + rtol_v = rtol.to(); + } else if (rtol.dtype() == DataType::FLOAT32) { + rtol_v = rtol.to(); + } else { + PADDLE_THROW(phi::errors::InvalidArgument( + "Input (Rtol) type must be double or float, but get %s.", + rtol.dtype())); + } + if (atol.dtype() == DataType::FLOAT64) { + atol_v = atol.to(); + } else if (atol.dtype() == DataType::FLOAT32) { + atol_v = atol.to(); + } else { + PADDLE_THROW(phi::errors::InvalidArgument( + "Input (Atol) type must be double or float, but get %s.", + atol.dtype())); + } + VLOG(3) << "rtol and atol is : " << rtol_v << " " << atol_v; auto* in_a = x.data(); auto* in_b = y.data(); - auto rtol_v = rtol.to(); - auto atol_v = atol.to(); auto* out_data = dev_ctx.template Alloc(out); *out_data = true; diff --git a/paddle/phi/kernels/cpu/kthvalue_grad_kernel.cc b/paddle/phi/kernels/cpu/kthvalue_grad_kernel.cc index 185d6cbedc85d..de7dfd167b76d 100644 --- a/paddle/phi/kernels/cpu/kthvalue_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/kthvalue_grad_kernel.cc @@ -46,9 +46,9 @@ static void kthvalueAssign(const Type& input_height, template void KthvalueGradKernel(const Context& dev_ctx, - const DenseTensor& d_out, const DenseTensor& x, const DenseTensor& indices, + const DenseTensor& d_out, int k, int axis, bool keepdim, diff --git a/paddle/phi/kernels/cpu/prelu_grad_kernel.cc b/paddle/phi/kernels/cpu/prelu_grad_kernel.cc index 97558cdb31f66..17be3fc897917 100644 --- a/paddle/phi/kernels/cpu/prelu_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/prelu_grad_kernel.cc @@ -24,8 +24,8 @@ void PReluGradKernel(const Context& dev_ctx, const DenseTensor& x, const DenseTensor& alpha, const DenseTensor& out_grad, - const 
std::string& mode, const std::string& data_format, + const std::string& mode, DenseTensor* x_grad, DenseTensor* alpha_grad) { const T* alpha_ptr = alpha.data(); diff --git a/paddle/phi/kernels/cpu/prelu_kernel.cc b/paddle/phi/kernels/cpu/prelu_kernel.cc index 8f389ab9ff459..636a3a4d750d1 100644 --- a/paddle/phi/kernels/cpu/prelu_kernel.cc +++ b/paddle/phi/kernels/cpu/prelu_kernel.cc @@ -23,8 +23,8 @@ template void PReluKernel(const Context& dev_ctx, const DenseTensor& x, const DenseTensor& alpha, - const std::string& mode, const std::string& data_format, + const std::string& mode, DenseTensor* out) { const T* x_ptr = x.data(); const T* alpha_ptr = alpha.data(); diff --git a/paddle/phi/kernels/gpu/allclose_kernel.cu b/paddle/phi/kernels/gpu/allclose_kernel.cu index af2612bb10c9f..8abc6b272c511 100644 --- a/paddle/phi/kernels/gpu/allclose_kernel.cu +++ b/paddle/phi/kernels/gpu/allclose_kernel.cu @@ -51,21 +51,28 @@ void AllCloseKernel(const Context& dev_ctx, const Scalar& atol, bool equal_nan, DenseTensor* out) { - PADDLE_ENFORCE_EQ( - rtol.dtype(), - DataType::FLOAT64, - phi::errors::InvalidArgument( - "Input (Rtol) type must be double, but get %s.", rtol.dtype())); - PADDLE_ENFORCE_EQ( - atol.dtype(), - DataType::FLOAT64, - phi::errors::InvalidArgument( - "Input (Atol) type must be double, but get %s.", atol.dtype())); - + double rtol_v, atol_v; + if (rtol.dtype() == DataType::FLOAT64) { + rtol_v = rtol.to(); + } else if (rtol.dtype() == DataType::FLOAT32) { + rtol_v = rtol.to(); + } else { + PADDLE_THROW(phi::errors::InvalidArgument( + "Input (Rtol) type must be double or float, but get %s.", + rtol.dtype())); + } + if (atol.dtype() == DataType::FLOAT64) { + atol_v = atol.to(); + } else if (atol.dtype() == DataType::FLOAT32) { + atol_v = atol.to(); + } else { + PADDLE_THROW(phi::errors::InvalidArgument( + "Input (Atol) type must be double or float, but get %s.", + atol.dtype())); + } + VLOG(3) << "rtol and atol is : " << rtol_v << " " << atol_v; const T* in_data = x.data(); const T* other_data = y.data(); - auto rtol_v = rtol.to(); - auto atol_v = atol.to(); bool* out_data = dev_ctx.template Alloc(out); int num = x.numel(); diff --git a/paddle/phi/kernels/gpu/kthvalue_grad_kernel.cu b/paddle/phi/kernels/gpu/kthvalue_grad_kernel.cu index f6e96046a2bd7..bcd370a72d91d 100644 --- a/paddle/phi/kernels/gpu/kthvalue_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/kthvalue_grad_kernel.cu @@ -34,9 +34,9 @@ static int getBlockSize(int col) { template void KthvalueGradKernel(const Context& dev_ctx, - const DenseTensor& d_out, const DenseTensor& x, const DenseTensor& indices, + const DenseTensor& d_out, int k, int axis, bool keepdim, diff --git a/paddle/phi/kernels/gpu/prelu_grad_kernel.cu b/paddle/phi/kernels/gpu/prelu_grad_kernel.cu index d8661268e82c3..013ad1974a8fb 100644 --- a/paddle/phi/kernels/gpu/prelu_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/prelu_grad_kernel.cu @@ -102,8 +102,8 @@ void PReluGradKernel(const Context& dev_ctx, const DenseTensor& x, const DenseTensor& alpha, const DenseTensor& out_grad, - const std::string& mode, const std::string& data_format, + const std::string& mode, DenseTensor* x_grad, DenseTensor* alpha_grad) { dev_ctx.template Alloc(x_grad); diff --git a/paddle/phi/kernels/gpu/prelu_kernel.cu b/paddle/phi/kernels/gpu/prelu_kernel.cu index 8255a7ba2ed96..c4730768982bb 100644 --- a/paddle/phi/kernels/gpu/prelu_kernel.cu +++ b/paddle/phi/kernels/gpu/prelu_kernel.cu @@ -24,8 +24,8 @@ template void PReluKernel(const Context& dev_ctx, const DenseTensor& x, const DenseTensor& 
alpha, - const std::string& mode, const std::string& data_format, + const std::string& mode, DenseTensor* out) { const T* x_ptr = x.data(); T* o_ptr = dev_ctx.template Alloc(out); diff --git a/paddle/phi/kernels/impl/lgamma_grad_kernel_impl.h b/paddle/phi/kernels/impl/lgamma_grad_kernel_impl.h index 8fb1f1c4fa361..9ef6c61fd60fb 100644 --- a/paddle/phi/kernels/impl/lgamma_grad_kernel_impl.h +++ b/paddle/phi/kernels/impl/lgamma_grad_kernel_impl.h @@ -33,8 +33,8 @@ struct LgammaGradFunctor { }; template void LgammaGradKernel(const Context& dev_ctx, - const DenseTensor& d_out, const DenseTensor& x, + const DenseTensor& d_out, DenseTensor* d_x) { auto numel = d_out.numel(); auto* dout_data = d_out.data(); diff --git a/paddle/phi/kernels/kldiv_loss_grad_kernel.h b/paddle/phi/kernels/kldiv_loss_grad_kernel.h index 8f53898fa6816..6e05c7992eb61 100644 --- a/paddle/phi/kernels/kldiv_loss_grad_kernel.h +++ b/paddle/phi/kernels/kldiv_loss_grad_kernel.h @@ -19,7 +19,6 @@ namespace phi { template -// XKTODO (change name) void KLDivLossGradKernel(const Context& dev_ctx, const DenseTensor& x, const DenseTensor& label, diff --git a/paddle/phi/kernels/kthvalue_grad_kernel.h b/paddle/phi/kernels/kthvalue_grad_kernel.h index 488dde8237b08..c2eac0a3e3de9 100644 --- a/paddle/phi/kernels/kthvalue_grad_kernel.h +++ b/paddle/phi/kernels/kthvalue_grad_kernel.h @@ -20,9 +20,9 @@ namespace phi { template void KthvalueGradKernel(const Context& dev_ctx, - const DenseTensor& d_out, const DenseTensor& x, const DenseTensor& indices, + const DenseTensor& d_out, int k, int axis, bool keepdim, diff --git a/paddle/phi/kernels/lgamma_grad_kernel.h b/paddle/phi/kernels/lgamma_grad_kernel.h index 94173cc29c7a7..d7f0ef399eaa0 100644 --- a/paddle/phi/kernels/lgamma_grad_kernel.h +++ b/paddle/phi/kernels/lgamma_grad_kernel.h @@ -21,7 +21,7 @@ namespace phi { template void LgammaGradKernel(const Context& dev_ctx, - const DenseTensor& d_out, const DenseTensor& x, + const DenseTensor& d_out, DenseTensor* d_x); } // namespace phi diff --git a/paddle/phi/kernels/prelu_grad_kernel.h b/paddle/phi/kernels/prelu_grad_kernel.h index 15917e2e1f02e..d36f529640d7d 100644 --- a/paddle/phi/kernels/prelu_grad_kernel.h +++ b/paddle/phi/kernels/prelu_grad_kernel.h @@ -24,8 +24,8 @@ void PReluGradKernel(const Context& dev_ctx, const DenseTensor& x, const DenseTensor& alpha, const DenseTensor& out_grad, - const std::string& mode, const std::string& data_format, + const std::string& mode, DenseTensor* x_grad, DenseTensor* alpha_grad); } // namespace phi diff --git a/paddle/phi/kernels/prelu_kernel.h b/paddle/phi/kernels/prelu_kernel.h index 251332a8158dc..7e273ecfd2fa1 100644 --- a/paddle/phi/kernels/prelu_kernel.h +++ b/paddle/phi/kernels/prelu_kernel.h @@ -22,7 +22,7 @@ template void PReluKernel(const Context& dev_ctx, const DenseTensor& x, const DenseTensor& alpha, - const std::string& mode, const std::string& data_format, + const std::string& mode, DenseTensor* out); } // namespace phi diff --git a/paddle/phi/ops/compat/kthvalue_sig.cc b/paddle/phi/ops/compat/kthvalue_sig.cc index e59e9de1e4382..3b1a6a45f9a0a 100644 --- a/paddle/phi/ops/compat/kthvalue_sig.cc +++ b/paddle/phi/ops/compat/kthvalue_sig.cc @@ -20,7 +20,7 @@ namespace phi { KernelSignature KthvalueGradOpArgumentMapping( const ArgumentMappingContext& ctx) { return KernelSignature("kthvalue_grad", - {GradVarName("Out"), "X", "Indices"}, + {"X", "Indices", GradVarName("Out")}, {"k", "axis", "keepdim"}, {GradVarName("X")}); } diff --git a/paddle/phi/ops/compat/lgamma_sig.cc 
b/paddle/phi/ops/compat/lgamma_sig.cc index 968ad4923ba7b..452ba5e2b45a1 100644 --- a/paddle/phi/ops/compat/lgamma_sig.cc +++ b/paddle/phi/ops/compat/lgamma_sig.cc @@ -18,7 +18,7 @@ namespace phi { KernelSignature LgammaGradOpArgumentMapping(const ArgumentMappingContext& ctx) { return KernelSignature( - "lgamma_grad", {GradVarName("Out"), "X"}, {}, {GradVarName("X")}); + "lgamma_grad", {"X", GradVarName("Out")}, {}, {GradVarName("X")}); } } // namespace phi diff --git a/paddle/phi/ops/compat/prelu_sig.cc b/paddle/phi/ops/compat/prelu_sig.cc index bd296c5e95318..43e5f20a92676 100644 --- a/paddle/phi/ops/compat/prelu_sig.cc +++ b/paddle/phi/ops/compat/prelu_sig.cc @@ -16,13 +16,19 @@ namespace phi { +KernelSignature PReluOpArgumentMapping(const ArgumentMappingContext& ctx) { + return KernelSignature( + "prelu", {"X", "Alpha"}, {"data_format", "mode"}, {"Out"}); +} + KernelSignature PReluGradOpArgumentMapping(const ArgumentMappingContext& ctx) { return KernelSignature("prelu_grad", {"X", "Alpha", GradVarName("Out")}, - {"mode", "data_format"}, + {"data_format", "mode"}, {GradVarName("X"), GradVarName("Alpha")}); } } // namespace phi +PD_REGISTER_ARG_MAPPING_FN(prelu, phi::PReluOpArgumentMapping); PD_REGISTER_ARG_MAPPING_FN(prelu_grad, phi::PReluGradOpArgumentMapping); diff --git a/python/paddle/fluid/layers/layer_function_generator.py b/python/paddle/fluid/layers/layer_function_generator.py index a99838cb27d4c..ec99f7c64f36f 100755 --- a/python/paddle/fluid/layers/layer_function_generator.py +++ b/python/paddle/fluid/layers/layer_function_generator.py @@ -23,6 +23,7 @@ from ..framework import OpProtoHolder, Variable, core, convert_np_dtype_to_dtype_, _non_static_mode, in_dygraph_mode, _in_legacy_dygraph from ..layer_helper import LayerHelper from ..data_feeder import check_variable_and_dtype +from paddle.fluid.framework import in_dygraph_mode, _in_legacy_dygraph from paddle import _C_ops __all__ = [ diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 6260213face05..0d2c1f14f2ddd 100755 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -25,6 +25,7 @@ import paddle from ..layer_helper import LayerHelper +from paddle.fluid.framework import _in_legacy_dygraph from ..initializer import Normal, Constant, NumpyArrayInitializer from ..framework import Variable, OpProtoHolder, _non_static_mode, dygraph_only, _dygraph_tracer, default_main_program, _varbase_creator, static_only, _global_flags, _in_legacy_dygraph, in_dygraph_mode from .. 
import dygraph_utils @@ -6427,7 +6428,9 @@ def squeeze(input, axes, name=None): y = layers.squeeze(input=x, axes=[2]) # y.shape=[None, 5, 10] """ - if _non_static_mode(): + if in_dygraph_mode(): + return _C_ops.final_state_squeeze(input, axes)[1] + if _in_legacy_dygraph(): out, _ = _C_ops.squeeze2(input, 'axes', axes) return out @@ -6488,8 +6491,10 @@ def unsqueeze(input, axes, name=None): item.numpy().item(0) if isinstance(item, Variable) else item for item in axes ] - out, _ = _C_ops.unsqueeze2(input, 'axes', axes) - return out + if _in_legacy_dygraph(): + out, _ = _C_ops.unsqueeze2(input, 'axes', axes) + return out + return _C_ops.final_state_unsqueeze(input, axes)[1] check_type(axes, 'axis/axes', (int, list, tuple, Variable), 'unsqueeze') check_variable_and_dtype(input, 'input', [ @@ -8910,7 +8915,9 @@ def log(x, name=None): res = paddle.log(x) # [[0.693147, 1.09861, 1.38629], [1.94591, 2.07944, 2.19722]] """ - if _non_static_mode(): + if in_dygraph_mode(): + return _C_ops.final_state_log(x) + if _in_legacy_dygraph(): return _C_ops.log(x) check_variable_and_dtype(x, 'x', ['float32', 'float64'], "log") diff --git a/python/paddle/fluid/tests/unittests/test_activation_op.py b/python/paddle/fluid/tests/unittests/test_activation_op.py index 471d0245aa83c..ef47b841cf819 100755 --- a/python/paddle/fluid/tests/unittests/test_activation_op.py +++ b/python/paddle/fluid/tests/unittests/test_activation_op.py @@ -50,6 +50,7 @@ def setUp(self): self.op_type = "exp" self.init_dtype() self.init_kernel_type() + self.check_eager = False np.random.seed(2049) x = np.random.uniform(0.1, 1, [11, 17]).astype(self.dtype) @@ -59,12 +60,18 @@ def setUp(self): self.outputs = {'Out': out} def test_check_output(self): - self.check_output() + check_eager = False + if hasattr(self, 'check_eager'): + check_eager = self.check_eager + self.check_output(check_eager=check_eager) def test_check_grad(self): if self.dtype == np.float16: return - self.check_grad(['X'], 'Out') + check_eager = False + if hasattr(self, 'check_eager'): + check_eager = self.check_eager + self.check_grad(['X'], 'Out', check_eager=check_eager) def init_dtype(self): self.dtype = np.float64 @@ -876,6 +883,8 @@ def ref_softshrink(x, threshold=0.5): class TestSoftshrink(TestActivation): def setUp(self): self.op_type = "softshrink" + self.check_eager = True + self.python_api = paddle.nn.functional.softshrink self.init_dtype() threshold = 0.8 @@ -890,7 +899,7 @@ def setUp(self): def test_check_grad(self): if self.dtype == np.float16: return - self.check_grad(['X'], 'Out') + self.check_grad(['X'], 'Out', check_eager=True) class TestSoftshrinkAPI(unittest.TestCase): @@ -1050,6 +1059,8 @@ def test_check_grad(self): class TestCeil(TestActivation): def setUp(self): self.op_type = "ceil" + self.check_eager = True + self.python_api = paddle.ceil self.init_dtype() np.random.seed(1024) @@ -1067,6 +1078,8 @@ def test_check_grad(self): class TestFloor(TestActivation): def setUp(self): self.op_type = "floor" + self.check_eager = True + self.python_api = paddle.floor self.init_dtype() np.random.seed(1024) @@ -1263,6 +1276,8 @@ def test_check_grad(self): class TestRound(TestActivation): def setUp(self): self.op_type = "round" + self.check_eager = True + self.python_api = paddle.round self.init_dtype() np.random.seed(1024) @@ -2075,6 +2090,8 @@ def test_check_output(self): class TestLog(TestActivation): def setUp(self): self.op_type = "log" + self.check_eager = True + self.python_api = paddle.log self.init_dtype() np.random.seed(1024) @@ -2087,7 +2104,7 @@ def 
setUp(self): def test_check_grad(self): if self.dtype == np.float16: return - self.check_grad(['X'], 'Out') + self.check_grad(['X'], 'Out', check_eager=True) def test_error(self): in1 = fluid.layers.data( @@ -2102,6 +2119,8 @@ def test_error(self): class TestLog2(TestActivation): def setUp(self): self.op_type = "log2" + self.check_eager = True + self.python_api = paddle.log2 self.init_dtype() x = np.random.uniform(0.1, 1, [11, 17]).astype(self.dtype) @@ -2113,7 +2132,7 @@ def setUp(self): def test_check_grad(self): if self.dtype == np.float16: return - self.check_grad(['X'], 'Out') + self.check_grad(['X'], 'Out', check_eager=True) def test_error(self): in1 = paddle.static.data(name="in1", shape=[11, 17], dtype="int32") @@ -2151,6 +2170,8 @@ def test_api(self): class TestLog10(TestActivation): def setUp(self): self.op_type = "log10" + self.check_eager = True + self.python_api = paddle.log10 self.init_dtype() x = np.random.uniform(0.1, 1, [11, 17]).astype(self.dtype) @@ -2162,7 +2183,7 @@ def setUp(self): def test_check_grad(self): if self.dtype == np.float16: return - self.check_grad(['X'], 'Out') + self.check_grad(['X'], 'Out', check_eager=True) def test_error(self): in1 = paddle.static.data(name="in1", shape=[11, 17], dtype="int32") @@ -2200,6 +2221,8 @@ def test_api(self): class TestLog1p(TestActivation): def setUp(self): self.op_type = "log1p" + self.check_eager = True + self.python_api = paddle.log1p self.init_dtype() np.random.seed(1024) @@ -2212,7 +2235,7 @@ def setUp(self): def test_check_grad(self): if self.dtype == np.float16: return - self.check_grad(['X'], 'Out') + self.check_grad(['X'], 'Out', check_eager=True) def test_api(self): with fluid.program_guard(fluid.Program(), fluid.Program()): @@ -2298,6 +2321,8 @@ def test_check_grad(self): class TestPow(TestActivation): def setUp(self): self.op_type = "pow" + self.python_api = paddle.pow + self.check_eager = False self.init_dtype() np.random.seed(1024) @@ -2311,12 +2336,14 @@ def setUp(self): def test_check_grad(self): if self.dtype == np.float16: return - self.check_grad(['X'], 'Out') + self.check_grad(['X'], 'Out', check_eager=self.check_eager) class TestPow_factor_tensor(TestActivation): def setUp(self): self.op_type = "pow" + self.check_eager = False + self.python_api = paddle.pow self.init_dtype() np.random.seed(1024) @@ -2332,12 +2359,12 @@ def setUp(self): self.outputs = {'Out': out} def test_check_output(self): - self.check_output() + self.check_output(check_eager=self.check_eager) def test_check_grad(self): if self.dtype == np.float16: return - self.check_grad(['X'], 'Out') + self.check_grad(['X'], 'Out', check_eager=self.check_eager) def test_api(self): input = np.random.uniform(1, 2, [11, 17]).astype("float32") diff --git a/python/paddle/fluid/tests/unittests/test_allclose_op.py b/python/paddle/fluid/tests/unittests/test_allclose_op.py index e96bf951240e7..ec1c5363fcde1 100644 --- a/python/paddle/fluid/tests/unittests/test_allclose_op.py +++ b/python/paddle/fluid/tests/unittests/test_allclose_op.py @@ -29,6 +29,7 @@ def set_args(self): def setUp(self): self.set_args() self.op_type = "allclose" + self.python_api = paddle.allclose self.inputs = { 'Input': self.input, 'Other': self.other, @@ -48,7 +49,7 @@ def setUp(self): } def test_check_output(self): - self.check_output() + self.check_output(check_eager=True) class TestAllcloseOpException(TestAllcloseOp): @@ -56,28 +57,28 @@ def test_check_output(self): def test_rtol_num(): self.inputs['Rtol'] = np.array([1e-05, 1e-05]).astype("float64") self.inputs['Atol'] = 
np.array([1e-08]).astype("float64") - self.check_output() + self.check_output(check_eager=True) self.assertRaises(ValueError, test_rtol_num) def test_rtol_type(): self.inputs['Rtol'] = np.array([5]).astype("int32") self.inputs['Atol'] = np.array([1e-08]).astype("float64") - self.check_output() + self.check_output(check_eager=True) self.assertRaises(ValueError, test_rtol_type) def test_atol_num(): self.inputs['Rtol'] = np.array([1e-05]).astype("float64") self.inputs['Atol'] = np.array([1e-08, 1e-08]).astype("float64") - self.check_output() + self.check_output(check_eager=True) self.assertRaises(ValueError, test_atol_num) def test_atol_type(): self.inputs['Rtol'] = np.array([1e-05]).astype("float64") self.inputs['Atol'] = np.array([8]).astype("int32") - self.check_output() + self.check_output(check_eager=True) self.assertRaises(ValueError, test_atol_type) diff --git a/python/paddle/fluid/tests/unittests/test_complex_abs.py b/python/paddle/fluid/tests/unittests/test_complex_abs.py index 4bc6beacb689f..a29d9baadead0 100644 --- a/python/paddle/fluid/tests/unittests/test_complex_abs.py +++ b/python/paddle/fluid/tests/unittests/test_complex_abs.py @@ -46,7 +46,7 @@ def init_grad_input_output(self): self.grad_x = self.grad_out * (self.x / np.abs(self.x)) def test_check_output(self): - self.check_output(check_eager=True) + self.check_output(check_eager=False) def test_check_grad(self): self.check_grad( @@ -54,7 +54,7 @@ def test_check_grad(self): 'Out', user_defined_grads=[self.grad_x], user_defined_grad_outputs=[self.grad_out], - check_eager=True) + check_eager=False) class TestComplexAbsOpZeroValues(OpTest): @@ -80,7 +80,7 @@ def init_grad_input_output(self): self.grad_x = np.zeros(self.shape, self.dtype) def test_check_output(self): - self.check_output(check_eager=True) + self.check_output(check_eager=False) def test_check_grad(self): self.check_grad( @@ -88,7 +88,7 @@ def test_check_grad(self): 'Out', user_defined_grads=[self.grad_x], user_defined_grad_outputs=[self.grad_out], - check_eager=True) + check_eager=False) class TestAbs(unittest.TestCase): @@ -133,7 +133,7 @@ def init_grad_input_output(self): self.grad_x = self.grad_out * (self.x / np.abs(self.x)) def test_check_output(self): - self.check_output(check_eager=True) + self.check_output(check_eager=False) def test_check_grad(self): self.check_grad( @@ -141,7 +141,7 @@ def test_check_grad(self): 'Out', user_defined_grads=[self.grad_x], user_defined_grad_outputs=[self.grad_out], - check_eager=True) + check_eager=False) if __name__ == '__main__': diff --git a/python/paddle/fluid/tests/unittests/test_cumprod_op.py b/python/paddle/fluid/tests/unittests/test_cumprod_op.py index 31e7ee287f0ea..681b8d6cc0bdf 100644 --- a/python/paddle/fluid/tests/unittests/test_cumprod_op.py +++ b/python/paddle/fluid/tests/unittests/test_cumprod_op.py @@ -73,6 +73,7 @@ def setUp(self): self.init_params() self.init_dtype() self.op_type = "cumprod" + self.python_api = paddle.cumprod self.inputs = {'X': None} self.outputs = {'Out': None} self.attrs = {'dim': None} @@ -110,7 +111,7 @@ def test_check_output(self): for dim in range(-len(self.shape), len(self.shape)): for zero_num in self.zero_nums: self.prepare_inputs_outputs_attrs(dim, zero_num) - self.check_output() + self.check_output(check_eager=True) # test backward. 
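    # A hedged sketch of the eager-check recipe this patch applies across
    # these tests (using only names already present in this file): bind the
    # op to its public Python API, then ask OpTest to also exercise the
    # final-state path.
    #
    #   self.op_type = "cumprod"
    #   self.python_api = paddle.cumprod      # maps the op to its python api
    #   self.check_output(check_eager=True)   # also runs the final-state kernel
    #   self.check_grad(['X'], 'Out', check_eager=True)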
def test_check_grad(self): @@ -119,13 +120,14 @@ def test_check_grad(self): self.prepare_inputs_outputs_attrs(dim, zero_num) self.init_grad_input_output(dim) if self.dtype == np.float64: - self.check_grad(['X'], 'Out') + self.check_grad(['X'], 'Out', check_eager=True) else: self.check_grad( ['X'], 'Out', user_defined_grads=[self.grad_x], - user_defined_grad_outputs=[self.grad_out]) + user_defined_grad_outputs=[self.grad_out], + check_eager=True) # test float32 case. diff --git a/python/paddle/fluid/tests/unittests/test_fmax_op.py b/python/paddle/fluid/tests/unittests/test_fmax_op.py index 3981d63c00582..608d97b68ac22 100644 --- a/python/paddle/fluid/tests/unittests/test_fmax_op.py +++ b/python/paddle/fluid/tests/unittests/test_fmax_op.py @@ -125,6 +125,7 @@ class TestElementwiseFmaxOp(OpTest): def setUp(self): """setUp""" self.op_type = "elementwise_fmax" + self.python_api = paddle.fmax # If x and y have the same value, the max() is not differentiable. # So we generate test data by the following method # to avoid them being too close to each other. @@ -136,21 +137,29 @@ def setUp(self): def test_check_output(self): """test_check_output""" - self.check_output() + self.check_output(check_eager=True) def test_check_grad_normal(self): """test_check_grad_normal""" - self.check_grad(['X', 'Y'], 'Out') + self.check_grad(['X', 'Y'], 'Out', check_eager=True) def test_check_grad_ingore_x(self): """test_check_grad_ingore_x""" self.check_grad( - ['Y'], 'Out', max_relative_error=0.005, no_grad_set=set("X")) + ['Y'], + 'Out', + max_relative_error=0.005, + no_grad_set=set("X"), + check_eager=True) def test_check_grad_ingore_y(self): """test_check_grad_ingore_y""" self.check_grad( - ['X'], 'Out', max_relative_error=0.005, no_grad_set=set('Y')) + ['X'], + 'Out', + max_relative_error=0.005, + no_grad_set=set('Y'), + check_eager=True) class TestElementwiseFmax2Op(OpTest): @@ -159,6 +168,7 @@ class TestElementwiseFmax2Op(OpTest): def setUp(self): """setUp""" self.op_type = "elementwise_fmax" + self.python_api = paddle.fmax # If x and y have the same value, the max() is not differentiable. # So we generate test data by the following method # to avoid them being too close to each other. @@ -172,18 +182,26 @@ def setUp(self): def test_check_output(self): """test_check_output""" - self.check_output() + self.check_output(check_eager=True) def test_check_grad_normal(self): """test_check_grad_normal""" - self.check_grad(['X', 'Y'], 'Out') + self.check_grad(['X', 'Y'], 'Out', check_eager=True) def test_check_grad_ingore_x(self): """test_check_grad_ingore_x""" self.check_grad( - ['Y'], 'Out', max_relative_error=0.005, no_grad_set=set("X")) + ['Y'], + 'Out', + max_relative_error=0.005, + no_grad_set=set("X"), + check_eager=True) def test_check_grad_ingore_y(self): """test_check_grad_ingore_y""" self.check_grad( - ['X'], 'Out', max_relative_error=0.005, no_grad_set=set('Y')) + ['X'], + 'Out', + max_relative_error=0.005, + no_grad_set=set('Y'), + check_eager=True) diff --git a/python/paddle/fluid/tests/unittests/test_fmin_op.py b/python/paddle/fluid/tests/unittests/test_fmin_op.py index 7231823c37532..b9d26827988cd 100644 --- a/python/paddle/fluid/tests/unittests/test_fmin_op.py +++ b/python/paddle/fluid/tests/unittests/test_fmin_op.py @@ -127,6 +127,7 @@ class TestElementwiseFminOp(OpTest): def setUp(self): """setUp""" self.op_type = "elementwise_fmin" + self.python_api = paddle.fmin # If x and y have the same value, the min() is not differentiable. 
# So we generate test data by the following method # to avoid them being too close to each other. @@ -138,21 +139,29 @@ def setUp(self): def test_check_output(self): """test_check_output""" - self.check_output() + self.check_output(check_eager=True) def test_check_grad_normal(self): """test_check_grad_normal""" - self.check_grad(['X', 'Y'], 'Out') + self.check_grad(['X', 'Y'], 'Out', check_eager=True) def test_check_grad_ingore_x(self): """test_check_grad_ingore_x""" self.check_grad( - ['Y'], 'Out', max_relative_error=0.005, no_grad_set=set("X")) + ['Y'], + 'Out', + max_relative_error=0.005, + no_grad_set=set("X"), + check_eager=True) def test_check_grad_ingore_y(self): """test_check_grad_ingore_y""" self.check_grad( - ['X'], 'Out', max_relative_error=0.005, no_grad_set=set('Y')) + ['X'], + 'Out', + max_relative_error=0.005, + no_grad_set=set('Y'), + check_eager=True) class TestElementwiseFmin2Op(OpTest): @@ -161,6 +170,7 @@ class TestElementwiseFmin2Op(OpTest): def setUp(self): """setUp""" self.op_type = "elementwise_fmin" + self.python_api = paddle.fmin # If x and y have the same value, the min() is not differentiable. # So we generate test data by the following method # to avoid them being too close to each other. @@ -174,21 +184,29 @@ def setUp(self): def test_check_output(self): """test_check_output""" - self.check_output() + self.check_output(check_eager=True) def test_check_grad_normal(self): """test_check_grad_normal""" - self.check_grad(['X', 'Y'], 'Out') + self.check_grad(['X', 'Y'], 'Out', check_eager=True) def test_check_grad_ingore_x(self): """test_check_grad_ingore_x""" self.check_grad( - ['Y'], 'Out', max_relative_error=0.005, no_grad_set=set("X")) + ['Y'], + 'Out', + max_relative_error=0.005, + no_grad_set=set("X"), + check_eager=True) def test_check_grad_ingore_y(self): """test_check_grad_ingore_y""" self.check_grad( - ['X'], 'Out', max_relative_error=0.005, no_grad_set=set('Y')) + ['X'], + 'Out', + max_relative_error=0.005, + no_grad_set=set('Y'), + check_eager=True) if __name__ == "__main__": diff --git a/python/paddle/fluid/tests/unittests/test_gather_op.py b/python/paddle/fluid/tests/unittests/test_gather_op.py index 978a3d86d882a..9ec2d1acdb5f3 100644 --- a/python/paddle/fluid/tests/unittests/test_gather_op.py +++ b/python/paddle/fluid/tests/unittests/test_gather_op.py @@ -33,6 +33,7 @@ def gather_numpy(x, index, axis): class TestGatherOp(OpTest): def setUp(self): self.op_type = "gather" + self.python_api = paddle.gather self.config() xnp = np.random.random(self.x_shape).astype(self.x_type) self.inputs = { @@ -42,10 +43,10 @@ def setUp(self): self.outputs = {'Out': self.inputs["X"][self.inputs["Index"]]} def test_check_output(self): - self.check_output() + self.check_output(check_eager=False) def test_check_grad(self): - self.check_grad(['X'], 'Out') + self.check_grad(['X'], 'Out', check_eager=False) def config(self): """ @@ -120,6 +121,7 @@ def config(self): class TestGatherBF16Op(OpTest): def setUp(self): self.op_type = "gather" + self.python_api = paddle.gather self.dtype = np.uint16 self.config() xnp = np.random.random(self.x_shape).astype(np.float32) @@ -134,10 +136,10 @@ def setUp(self): self.outputs = {'Out': out} def test_check_output(self): - self.check_output() + self.check_output(check_eager=False) def test_check_grad(self): - self.check_grad(['X'], 'Out', numeric_grad_delta=0.5) + self.check_grad(['X'], 'Out', numeric_grad_delta=0.5, check_eager=False) def config(self): """ @@ -153,6 +155,7 @@ def config(self): class TestGatherOp1(OpTest): def setUp(self): 
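        # gather stays on check_eager=False for now: the final-state call in
        # paddle/tensor/manipulation.py is still commented out (see the gather
        # hunk near the end of this patch), so there is no eager kernel path
        # to validate yet.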
self.op_type = "gather" + self.python_api = paddle.gather self.config() xnp = np.random.random(self.x_shape).astype(self.x_type) axis_np = np.array(self.axis).astype(self.index_type) @@ -162,10 +165,10 @@ def setUp(self): self.outputs = {'Out': out} def test_check_output(self): - self.check_output() + self.check_output(check_eager=False) def test_check_grad(self): - self.check_grad(['X'], 'Out') + self.check_grad(['X'], 'Out', check_eager=False) def config(self): """ diff --git a/python/paddle/fluid/tests/unittests/test_isclose_op.py b/python/paddle/fluid/tests/unittests/test_isclose_op.py index 2bb58d7c5741f..245520e5ab666 100644 --- a/python/paddle/fluid/tests/unittests/test_isclose_op.py +++ b/python/paddle/fluid/tests/unittests/test_isclose_op.py @@ -30,6 +30,7 @@ def setUp(self): paddle.enable_static() self.set_args() self.op_type = "isclose" + self.python_api = paddle.isclose self.inputs = { 'Input': self.input, 'Other': self.other, @@ -49,7 +50,7 @@ def setUp(self): } def test_check_output(self): - self.check_output() + self.check_output(check_eager=True) class TestIscloseOpException(TestIscloseOp): @@ -57,28 +58,28 @@ def test_check_output(self): def test_rtol_num(): self.inputs['Rtol'] = np.array([1e-05, 1e-05]).astype("float64") self.inputs['Atol'] = np.array([1e-08]).astype("float64") - self.check_output() + self.check_output(check_eager=True) self.assertRaises(ValueError, test_rtol_num) def test_rtol_type(): self.inputs['Rtol'] = np.array([5]).astype("int32") self.inputs['Atol'] = np.array([1e-08]).astype("float64") - self.check_output() + self.check_output(check_eager=True) self.assertRaises(ValueError, test_rtol_type) def test_atol_num(): self.inputs['Rtol'] = np.array([1e-05]).astype("float64") self.inputs['Atol'] = np.array([1e-08, 1e-08]).astype("float64") - self.check_output() + self.check_output(check_eager=True) self.assertRaises(ValueError, test_atol_num) def test_atol_type(): self.inputs['Rtol'] = np.array([1e-05]).astype("float64") self.inputs['Atol'] = np.array([8]).astype("int32") - self.check_output() + self.check_output(check_eager=True) self.assertRaises(ValueError, test_atol_type) @@ -211,7 +212,7 @@ def set_args(self): self.equal_nan = False def test_check_output(self): - self.check_output() + self.check_output(check_eager=True) class TestIscloseOpLargeDimInput(TestIscloseOp): diff --git a/python/paddle/fluid/tests/unittests/test_kldiv_loss_op.py b/python/paddle/fluid/tests/unittests/test_kldiv_loss_op.py index a301748ed7bbb..aa94cf2d35cc7 100644 --- a/python/paddle/fluid/tests/unittests/test_kldiv_loss_op.py +++ b/python/paddle/fluid/tests/unittests/test_kldiv_loss_op.py @@ -17,6 +17,7 @@ import unittest import numpy as np from op_test import OpTest +from paddle.nn.functional import kl_div def kldiv_loss(x, target, reduction): @@ -40,6 +41,7 @@ class TestKLDivLossOp(OpTest): def setUp(self): self.initTestCase() self.op_type = 'kldiv_loss' + self.python_api = kl_div x = np.random.uniform(-10, 10, self.x_shape).astype('float64') target = np.random.uniform(-10, 10, self.x_shape).astype('float64') @@ -53,10 +55,11 @@ def setUp(self): self.outputs = {'Loss': loss.astype('float64')} def test_check_output(self): - self.check_output() + self.check_output(check_eager=True) def test_check_grad(self): - self.check_grad(['X'], 'Loss', no_grad_set=set(["Target"])) + self.check_grad( + ['X'], 'Loss', no_grad_set=set(["Target"]), check_eager=True) def initTestCase(self): self.x_shape = (4, 5, 5) diff --git a/python/paddle/fluid/tests/unittests/test_kthvalue_op.py 
b/python/paddle/fluid/tests/unittests/test_kthvalue_op.py index 68dd58835c56c..e1b1422580983 100644 --- a/python/paddle/fluid/tests/unittests/test_kthvalue_op.py +++ b/python/paddle/fluid/tests/unittests/test_kthvalue_op.py @@ -41,6 +41,7 @@ def init_args(self): def setUp(self): self.op_type = "kthvalue" + self.python_api = paddle.kthvalue self.dtype = np.float64 self.input_data = np.random.random((2, 1, 2, 4, 10)) self.init_args() @@ -52,11 +53,11 @@ def setUp(self): def test_check_output(self): paddle.enable_static() - self.check_output() + self.check_output(check_eager=True) def test_check_grad(self): paddle.enable_static() - self.check_grad(set(['X']), 'Out') + self.check_grad(set(['X']), 'Out', check_eager=True) class TestKthvalueOpWithKeepdim(OpTest): @@ -67,6 +68,7 @@ def init_args(self): def setUp(self): self.init_args() self.op_type = "kthvalue" + self.python_api = paddle.kthvalue self.dtype = np.float64 self.input_data = np.random.random((1, 3, 2, 4, 10)) self.inputs = {'X': self.input_data} @@ -77,11 +79,11 @@ def setUp(self): def test_check_output(self): paddle.enable_static() - self.check_output() + self.check_output(check_eager=True) def test_check_grad(self): paddle.enable_static() - self.check_grad(set(['X']), 'Out') + self.check_grad(set(['X']), 'Out', check_eager=True) class TestKthvalueOpKernels(unittest.TestCase): diff --git a/python/paddle/fluid/tests/unittests/test_lgamma_op.py b/python/paddle/fluid/tests/unittests/test_lgamma_op.py index 686d5b1eb6dfe..8e9edab55baf8 100644 --- a/python/paddle/fluid/tests/unittests/test_lgamma_op.py +++ b/python/paddle/fluid/tests/unittests/test_lgamma_op.py @@ -24,6 +24,7 @@ class TestLgammaOp(OpTest): def setUp(self): self.op_type = 'lgamma' + self.python_api = paddle.lgamma self.init_dtype_type() shape = (5, 20) data = np.random.random(shape).astype(self.dtype) + 1 @@ -38,10 +39,10 @@ def init_dtype_type(self): self.dtype = np.float64 def test_check_output(self): - self.check_output() + self.check_output(check_eager=True) def test_check_grad_normal(self): - self.check_grad(['X'], 'Out', numeric_grad_delta=1e-7) + self.check_grad(['X'], 'Out', numeric_grad_delta=1e-7, check_eager=True) class TestLgammaOpFp32(TestLgammaOp): @@ -49,7 +50,8 @@ def init_dtype_type(self): self.dtype = np.float32 def test_check_grad_normal(self): - self.check_grad(['X'], 'Out', numeric_grad_delta=0.005) + self.check_grad( + ['X'], 'Out', numeric_grad_delta=0.005, check_eager=True) if __name__ == "__main__": diff --git a/python/paddle/fluid/tests/unittests/test_log_softmax.py b/python/paddle/fluid/tests/unittests/test_log_softmax.py index 423eeaf3ada45..b3b164725fc34 100644 --- a/python/paddle/fluid/tests/unittests/test_log_softmax.py +++ b/python/paddle/fluid/tests/unittests/test_log_softmax.py @@ -42,6 +42,7 @@ def ref_log_softmax_grad(x, axis): class TestLogSoftmaxOp(OpTest): def setUp(self): self.op_type = 'log_softmax' + self.python_api = F.log_softmax self.dtype = 'float64' self.shape = [2, 3, 4, 5] self.axis = -1 @@ -59,10 +60,11 @@ def set_attrs(self): pass def test_check_output(self): - self.check_output() + self.check_output(check_eager=True) def test_check_grad(self): - self.check_grad(['X'], ['Out'], user_defined_grads=[self.x_grad]) + self.check_grad( + ['X'], ['Out'], user_defined_grads=[self.x_grad], check_eager=True) class TestLogSoftmaxShape(TestLogSoftmaxOp): @@ -80,6 +82,7 @@ def set_attrs(self): class TestLogSoftmaxBF16Op(OpTest): def setUp(self): self.op_type = 'log_softmax' + self.python_api = F.log_softmax self.dtype = np.uint16 
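        # np.uint16 is OpTest's container dtype for bfloat16 tensors, which is
        # why this case checks through check_output_with_place on a CUDA place
        # rather than plain check_output.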
self.shape = [2, 3, 4, 5] self.axis = -1 @@ -94,12 +97,14 @@ def setUp(self): def test_check_output(self): place = core.CUDAPlace(0) - self.check_output_with_place(place) + self.check_output_with_place(place, check_eager=True) def test_check_grad(self): place = core.CUDAPlace(0) self.check_grad_with_place( - place, ['X'], ['Out'], user_defined_grads=[self.x_grad]) + place, ['X'], ['Out'], + user_defined_grads=[self.x_grad], + check_eager=True) class TestNNLogSoftmaxAPI(unittest.TestCase): diff --git a/python/paddle/fluid/tests/unittests/test_max_op.py b/python/paddle/fluid/tests/unittests/test_max_op.py index 5e413e80d7143..d5b884dfcc93b 100644 --- a/python/paddle/fluid/tests/unittests/test_max_op.py +++ b/python/paddle/fluid/tests/unittests/test_max_op.py @@ -18,6 +18,7 @@ import numpy as np from op_test import OpTest, skip_check_grad_ci, check_out_dtype import paddle +from paddle.fluid.framework import _test_eager_guard import paddle.fluid.core as core @@ -86,6 +87,10 @@ def test_imperative_api(self): z_expected = np.array(np.max(np_x, axis=0)) self.assertEqual((np_z == z_expected).all(), True) + def test_eager_api(self): + with _test_eager_guard(): + self.test_imperative_api() + def test_big_dimension(self): paddle.disable_static() x = paddle.rand(shape=[2, 2, 2, 2, 2, 2, 2]) diff --git a/python/paddle/fluid/tests/unittests/test_mean_op.py b/python/paddle/fluid/tests/unittests/test_mean_op.py index 7a49770e57985..b20c2932f09dd 100644 --- a/python/paddle/fluid/tests/unittests/test_mean_op.py +++ b/python/paddle/fluid/tests/unittests/test_mean_op.py @@ -25,9 +25,22 @@ np.random.seed(10) +def mean_wrapper(x, axis=None, keepdim=False, reduce_all=False): + if reduce_all == True: + return paddle.mean(x, range(len(x.shape)), keepdim) + return paddle.mean(x, axis, keepdim) + + +def reduce_mean_wrapper(x, axis=0, keepdim=False, reduce_all=False): + if reduce_all == True: + return paddle.mean(x, range(len(x.shape)), keepdim) + return paddle.mean(x, axis, keepdim) + + class TestMeanOp(OpTest): def setUp(self): self.op_type = "mean" + self.python_api = mean_wrapper self.dtype = np.float64 self.init_dtype_type() self.inputs = {'X': np.random.random((10, 10)).astype(self.dtype)} @@ -37,10 +50,10 @@ def init_dtype_type(self): pass def test_check_output(self): - self.check_output() + self.check_output(check_eager=True) def test_checkout_grad(self): - self.check_grad(['X'], 'Out') + self.check_grad(['X'], 'Out', check_eager=True) class TestMeanOpError(unittest.TestCase): @@ -117,6 +130,7 @@ def ref_reduce_mean_grad(x, axis, dtype): class TestReduceMeanOp(OpTest): def setUp(self): self.op_type = 'reduce_mean' + self.python_api = reduce_mean_wrapper self.dtype = 'float64' self.shape = [2, 3, 4, 5] self.axis = [0] @@ -145,7 +159,7 @@ def set_attrs(self): def test_check_output(self): if self.dtype != 'float16': - self.check_output() + self.check_output(check_eager=True) else: if not core.is_compiled_with_cuda(): return @@ -154,7 +168,7 @@ def test_check_output(self): def test_check_grad(self): if self.dtype != 'float16': - self.check_grad(['X'], ['Out']) + self.check_grad(['X'], ['Out'], check_eager=True) else: return if not core.is_compiled_with_cuda(): @@ -175,6 +189,7 @@ def test_check_grad(self): class TestReduceMeanOpDefaultAttrs(TestReduceMeanOp): def setUp(self): self.op_type = 'reduce_mean' + self.python_api = reduce_mean_wrapper self.dtype = 'float64' self.shape = [2, 3, 4, 5] diff --git a/python/paddle/fluid/tests/unittests/test_min_op.py b/python/paddle/fluid/tests/unittests/test_min_op.py index 
f865c234a747c..13f82fb9bd7cb 100644 --- a/python/paddle/fluid/tests/unittests/test_min_op.py +++ b/python/paddle/fluid/tests/unittests/test_min_op.py @@ -19,6 +19,7 @@ from op_test import OpTest, skip_check_grad_ci, check_out_dtype import paddle import paddle.fluid.core as core +from paddle.fluid.framework import _test_eager_guard class ApiMinTest(unittest.TestCase): @@ -86,6 +87,10 @@ def test_imperative_api(self): z_expected = np.array(np.min(np_x, axis=0)) self.assertEqual((np_z == z_expected).all(), True) + def test_eager_api(self): + with _test_eager_guard(): + self.test_imperative_api() + class TestOutDtype(unittest.TestCase): def test_min(self): diff --git a/python/paddle/fluid/tests/unittests/test_mode_op.py b/python/paddle/fluid/tests/unittests/test_mode_op.py index 1b0458f2e255f..471904b0c9426 100644 --- a/python/paddle/fluid/tests/unittests/test_mode_op.py +++ b/python/paddle/fluid/tests/unittests/test_mode_op.py @@ -62,6 +62,7 @@ def init_args(self): def setUp(self): self.op_type = "mode" + self.python_api = paddle.mode self.dtype = np.float64 np.random.seed(666) self.input_data = np.random.rand(2, 64, 1) @@ -73,11 +74,11 @@ def setUp(self): def test_check_output(self): paddle.enable_static() - self.check_output() + self.check_output(check_eager=True) def test_check_grad(self): paddle.enable_static() - self.check_grad(set(['X']), 'Out') + self.check_grad(set(['X']), 'Out', check_eager=True) class TestModeOpLastdim(OpTest): @@ -86,6 +87,7 @@ def init_args(self): def setUp(self): self.op_type = "mode" + self.python_api = paddle.mode self.dtype = np.float64 np.random.seed(666) self.input_data = np.random.rand(2, 1, 1, 2, 30) @@ -97,11 +99,11 @@ def setUp(self): def test_check_output(self): paddle.enable_static() - self.check_output() + self.check_output(check_eager=True) def test_check_grad(self): paddle.enable_static() - self.check_grad(set(['X']), 'Out') + self.check_grad(set(['X']), 'Out', check_eager=True) class TestModeOpKernels(unittest.TestCase): diff --git a/python/paddle/fluid/tests/unittests/test_norm_all.py b/python/paddle/fluid/tests/unittests/test_norm_all.py index ef912699455d1..17c45299d0fc5 100644 --- a/python/paddle/fluid/tests/unittests/test_norm_all.py +++ b/python/paddle/fluid/tests/unittests/test_norm_all.py @@ -20,6 +20,24 @@ import paddle import paddle.fluid as fluid import paddle.fluid.core as core +from paddle import _C_ops +from paddle.fluid.framework import in_dygraph_mode, _in_legacy_dygraph + + +# hack method for test p_norm final state +def p_norm_python_api(x, + p=2.0, + axis=-1, + epsilon=1e-12, + keepdim=False, + as_vector=False): + if in_dygraph_mode(): + return _C_ops.final_state_p_norm(x, p, axis, epsilon, keepdim, + as_vector) + if _in_legacy_dygraph(): + return _C_ops.p_norm(x, 'axis', axis, 'porder', + float(p), 'keepdim', keepdim, 'epsilon', epsilon, + 'as_vector', as_vector) def p_norm(x, axis, porder, keepdims=False, reduce_all=False): @@ -110,6 +128,7 @@ def test_check_grad(self): class TestPnormOp(OpTest): def setUp(self): self.op_type = "p_norm" + self.python_api = p_norm_python_api self.init_test_case() x = (np.random.random(self.shape) + 0.5).astype(self.dtype) norm = p_norm(x, self.axis, self.porder, self.keepdim, self.asvector) @@ -125,10 +144,10 @@ def setUp(self): self.gradient = self.calc_gradient() def test_check_output(self): - self.check_output() + self.check_output(check_eager=True) def test_check_grad(self): - self.check_grad(['X'], 'Out') + self.check_grad(['X'], 'Out', check_eager=True) def init_test_case(self): self.shape 
= [2, 3, 4, 5] @@ -287,6 +306,7 @@ def init_test_case(self): class TestPnormBF16Op(OpTest): def setUp(self): self.op_type = "p_norm" + self.python_api = p_norm_python_api self.init_test_case() self.x = (np.random.random(self.shape) + 0.5).astype(np.float32) self.norm = p_norm(self.x, self.axis, self.porder, self.keepdim, @@ -304,12 +324,15 @@ def setUp(self): def test_check_output(self): place = core.CUDAPlace(0) - self.check_output_with_place(place, atol=1e-3) + self.check_output_with_place(place, atol=1e-3, check_eager=True) def test_check_grad(self): place = core.CUDAPlace(0) self.check_grad_with_place( - place, ['X'], 'Out', user_defined_grads=self.gradient) + place, ['X'], + 'Out', + user_defined_grads=self.gradient, + check_eager=True) def init_test_case(self): self.shape = [2, 3, 4, 5] diff --git a/python/paddle/fluid/tests/unittests/test_normalize.py b/python/paddle/fluid/tests/unittests/test_normalize.py index 274a4ebee7c3c..2f52ae391c7de 100644 --- a/python/paddle/fluid/tests/unittests/test_normalize.py +++ b/python/paddle/fluid/tests/unittests/test_normalize.py @@ -20,6 +20,7 @@ import paddle.fluid as fluid import paddle.fluid.core as core import numpy as np +from paddle.fluid.framework import _test_eager_guard def p_normalize(x, axis=1, p=2, epsilon=1e-12, keepdims=True): @@ -87,6 +88,12 @@ def test_cpu(self): with fluid.program_guard(fluid.Program()): self.run_static() + def test_cpu_eager(self): + with _test_eager_guard(): + paddle.disable_static(place=paddle.fluid.CPUPlace()) + self.run_imperative() + paddle.enable_static() + def test_gpu(self): if not fluid.core.is_compiled_with_cuda(): return @@ -98,6 +105,15 @@ def test_gpu(self): with fluid.program_guard(fluid.Program()): self.run_static(use_gpu=True) + def test_gpu_eager(self): + with _test_eager_guard(): + if not fluid.core.is_compiled_with_cuda(): + return + + paddle.disable_static(place=paddle.fluid.CUDAPlace(0)) + self.run_imperative() + paddle.enable_static() + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_pad3d_op.py b/python/paddle/fluid/tests/unittests/test_pad3d_op.py index 7abc314bc1ba0..12f6f7b572108 100644 --- a/python/paddle/fluid/tests/unittests/test_pad3d_op.py +++ b/python/paddle/fluid/tests/unittests/test_pad3d_op.py @@ -30,6 +30,7 @@ def setUp(self): self.variable_paddings = False self.initTestCase() self.op_type = "pad3d" + self.python_api = paddle.nn.functional.pad self.inputs = {'X': np.random.random(self.shape).astype("float64")} self.attrs = {} if self.variable_paddings: @@ -72,10 +73,10 @@ def setUp(self): self.outputs = {'Out': out} def test_check_output(self): - self.check_output() + self.check_output(check_eager=True) def test_check_grad_normal(self): - self.check_grad(['X'], 'Out') + self.check_grad(['X'], 'Out', check_eager=True) def initTestCase(self): self.shape = (2, 3, 4, 5, 6) diff --git a/python/paddle/fluid/tests/unittests/test_prelu_op.py b/python/paddle/fluid/tests/unittests/test_prelu_op.py index 6afc462322fba..56b32d41a9bd1 100644 --- a/python/paddle/fluid/tests/unittests/test_prelu_op.py +++ b/python/paddle/fluid/tests/unittests/test_prelu_op.py @@ -157,6 +157,7 @@ def setUp(self): self.init_input_shape() self.init_attr() self.op_type = "prelu" + self.python_api = paddle.nn.functional.prelu x_np = np.random.uniform(-1, 1, self.x_shape).astype(self.dtype) # Since zero point in prelu is not differentiable, avoid randomize @@ -207,10 +208,10 @@ def init_attr(self): self.attrs = {'mode': "channel", "data_format": "NCHW"} def 
test_check_output(self): - self.check_output() + self.check_output(check_eager=False) def test_check_grad(self): - self.check_grad(['X', 'Alpha'], 'Out') + self.check_grad(['X', 'Alpha'], 'Out', check_eager=False) @skip_check_grad_ci( @@ -373,7 +374,8 @@ def test_check_output(self): if core.is_compiled_with_cuda(): place = core.CUDAPlace(0) if core.is_float16_supported(place): - self.check_output_with_place(place, atol=atol) + self.check_output_with_place( + place, atol=atol, check_eager=False) def test_check_grad(self): place = core.CUDAPlace(0) @@ -381,7 +383,8 @@ def test_check_grad(self): self.check_grad_with_place( place, ['X', 'Alpha'], 'Out', - max_relative_error=max_relative_error) + max_relative_error=max_relative_error, + check_eager=False) cls_name = "{0}_{1}".format(parent.__name__, "Fp16Op") TestPReluFp16Case.__name__ = cls_name diff --git a/python/paddle/fluid/tests/unittests/test_reduce_op.py b/python/paddle/fluid/tests/unittests/test_reduce_op.py index d246356b4ec75..737e1af851fa7 100644 --- a/python/paddle/fluid/tests/unittests/test_reduce_op.py +++ b/python/paddle/fluid/tests/unittests/test_reduce_op.py @@ -172,6 +172,7 @@ class TestMaxOp(OpTest): def setUp(self): self.op_type = "reduce_max" + self.python_api = paddle.max self.inputs = {'X': np.random.random((5, 6, 10)).astype("float64")} self.attrs = {'dim': [-1]} self.outputs = { @@ -179,7 +180,7 @@ def setUp(self): } def test_check_output(self): - self.check_output() + self.check_output(check_eager=True) @skip_check_grad_ci( @@ -190,6 +191,7 @@ class TestMinOp(OpTest): def setUp(self): self.op_type = "reduce_min" + self.python_api = paddle.min self.inputs = {'X': np.random.random((5, 6, 10)).astype("float64")} self.attrs = {'dim': [2]} self.outputs = { @@ -197,7 +199,7 @@ def setUp(self): } def test_check_output(self): - self.check_output() + self.check_output(check_eager=True) class TestMin6DOp(OpTest): @@ -205,6 +207,7 @@ class TestMin6DOp(OpTest): def setUp(self): self.op_type = "reduce_min" + self.python_api = paddle.min self.inputs = { 'X': np.random.random((2, 4, 3, 5, 6, 10)).astype("float64") } @@ -214,7 +217,7 @@ def setUp(self): } def test_check_output(self): - self.check_output() + self.check_output(check_eager=True) class TestMin8DOp(OpTest): @@ -222,6 +225,7 @@ class TestMin8DOp(OpTest): def setUp(self): self.op_type = "reduce_min" + self.python_api = paddle.min self.inputs = { 'X': np.random.random((2, 4, 3, 5, 6, 3, 2, 4)).astype("float64") } @@ -231,7 +235,7 @@ def setUp(self): } def test_check_output(self): - self.check_output() + self.check_output(check_eager=True) class TestProdOp(OpTest): @@ -302,17 +306,19 @@ def test_check_grad(self): class TestAllOp(OpTest): def setUp(self): self.op_type = "reduce_all" + self.python_api = paddle.all self.inputs = {'X': np.random.randint(0, 2, (5, 6, 10)).astype("bool")} self.outputs = {'Out': self.inputs['X'].all()} self.attrs = {'reduce_all': True} def test_check_output(self): - self.check_output() + self.check_output(check_eager=True) class TestAll8DOp(OpTest): def setUp(self): self.op_type = "reduce_all" + self.python_api = paddle.all self.inputs = { 'X': np.random.randint(0, 2, (2, 5, 3, 2, 2, 3, 4, 2)).astype("bool") @@ -321,23 +327,25 @@ def setUp(self): self.outputs = {'Out': self.inputs['X'].all(axis=self.attrs['dim'])} def test_check_output(self): - self.check_output() + self.check_output(check_eager=True) class TestAllOpWithDim(OpTest): def setUp(self): self.op_type = "reduce_all" + self.python_api = paddle.all self.inputs = {'X': np.random.randint(0, 
2, (5, 6, 10)).astype("bool")} self.attrs = {'dim': (1, )} self.outputs = {'Out': self.inputs['X'].all(axis=self.attrs['dim'])} def test_check_output(self): - self.check_output() + self.check_output(check_eager=True) class TestAll8DOpWithDim(OpTest): def setUp(self): self.op_type = "reduce_all" + self.python_api = paddle.all self.inputs = { 'X': np.random.randint(0, 2, (2, 5, 3, 2, 2, 3, 4, 2)).astype("bool") @@ -346,12 +354,13 @@ def setUp(self): self.outputs = {'Out': self.inputs['X'].all(axis=self.attrs['dim'])} def test_check_output(self): - self.check_output() + self.check_output(check_eager=True) class TestAllOpWithKeepDim(OpTest): def setUp(self): self.op_type = "reduce_all" + self.python_api = paddle.all self.inputs = {'X': np.random.randint(0, 2, (5, 6, 10)).astype("bool")} self.attrs = {'dim': [1], 'keep_dim': True} self.outputs = { @@ -360,12 +369,13 @@ def setUp(self): } def test_check_output(self): - self.check_output() + self.check_output(check_eager=True) class TestAll8DOpWithKeepDim(OpTest): def setUp(self): self.op_type = "reduce_all" + self.python_api = paddle.all self.inputs = { 'X': np.random.randint(0, 2, (2, 5, 3, 2, 2, 3, 4, 2)).astype("bool") @@ -377,7 +387,7 @@ def setUp(self): } def test_check_output(self): - self.check_output() + self.check_output(check_eager=True) class TestAllOpError(unittest.TestCase): @@ -395,17 +405,19 @@ def test_errors(self): class TestAnyOp(OpTest): def setUp(self): self.op_type = "reduce_any" + self.python_api = paddle.any self.inputs = {'X': np.random.randint(0, 2, (5, 6, 10)).astype("bool")} self.outputs = {'Out': self.inputs['X'].any()} self.attrs = {'reduce_all': True} def test_check_output(self): - self.check_output() + self.check_output(check_eager=True) class TestAny8DOp(OpTest): def setUp(self): self.op_type = "reduce_any" + self.python_api = paddle.any self.inputs = { 'X': np.random.randint(0, 2, (2, 5, 3, 2, 2, 3, 4, 2)).astype("bool") @@ -414,23 +426,25 @@ def setUp(self): self.outputs = {'Out': self.inputs['X'].any(axis=self.attrs['dim'])} def test_check_output(self): - self.check_output() + self.check_output(check_eager=True) class TestAnyOpWithDim(OpTest): def setUp(self): self.op_type = "reduce_any" + self.python_api = paddle.any self.inputs = {'X': np.random.randint(0, 2, (5, 6, 10)).astype("bool")} self.attrs = {'dim': [1]} self.outputs = {'Out': self.inputs['X'].any(axis=1)} def test_check_output(self): - self.check_output() + self.check_output(check_eager=True) class TestAny8DOpWithDim(OpTest): def setUp(self): self.op_type = "reduce_any" + self.python_api = paddle.any self.inputs = { 'X': np.random.randint(0, 2, (2, 5, 3, 2, 2, 3, 4, 2)).astype("bool") @@ -439,12 +453,13 @@ def setUp(self): self.outputs = {'Out': self.inputs['X'].any(axis=self.attrs['dim'])} def test_check_output(self): - self.check_output() + self.check_output(check_eager=True) class TestAnyOpWithKeepDim(OpTest): def setUp(self): self.op_type = "reduce_any" + self.python_api = paddle.any self.inputs = {'X': np.random.randint(0, 2, (5, 6, 10)).astype("bool")} self.attrs = {'dim': (1, ), 'keep_dim': True} self.outputs = { @@ -453,12 +468,13 @@ def setUp(self): } def test_check_output(self): - self.check_output() + self.check_output(check_eager=True) class TestAny8DOpWithKeepDim(OpTest): def setUp(self): self.op_type = "reduce_any" + self.python_api = paddle.any self.inputs = { 'X': np.random.randint(0, 2, (2, 5, 3, 2, 2, 3, 4, 2)).astype("bool") @@ -470,7 +486,7 @@ def setUp(self): } def test_check_output(self): - self.check_output() + 
self.check_output(check_eager=True) class TestAnyOpError(unittest.TestCase): @@ -600,6 +616,7 @@ class TestReduceMaxOpMultiAxises(OpTest): def setUp(self): self.op_type = "reduce_max" + self.python_api = paddle.max self.inputs = {'X': np.random.random((5, 6, 10)).astype("float64")} self.attrs = {'dim': [-2, -1]} self.outputs = { @@ -607,7 +624,7 @@ def setUp(self): } def test_check_output(self): - self.check_output() + self.check_output(check_eager=True) @skip_check_grad_ci( @@ -618,6 +635,7 @@ class TestReduceMinOpMultiAxises(OpTest): def setUp(self): self.op_type = "reduce_min" + self.python_api = paddle.min self.inputs = {'X': np.random.random((5, 6, 10)).astype("float64")} self.attrs = {'dim': [1, 2]} self.outputs = { @@ -625,7 +643,7 @@ def setUp(self): } def test_check_output(self): - self.check_output() + self.check_output(check_eager=True) class TestKeepDimReduceSumMultiAxises(OpTest): diff --git a/python/paddle/fluid/tests/unittests/test_squeeze2_op.py b/python/paddle/fluid/tests/unittests/test_squeeze2_op.py index fc43a8e782382..7d7893cfda0b1 100755 --- a/python/paddle/fluid/tests/unittests/test_squeeze2_op.py +++ b/python/paddle/fluid/tests/unittests/test_squeeze2_op.py @@ -27,6 +27,10 @@ class TestSqueezeOp(OpTest): def setUp(self): self.op_type = "squeeze2" + self.python_api = paddle.squeeze + self.python_out_sig = [ + "Out" + ] # python out sig is customized output signature. self.init_test_case() self.inputs = {"X": np.random.random(self.ori_shape).astype("float64")} self.init_attrs() @@ -36,10 +40,10 @@ def setUp(self): } def test_check_output(self): - self.check_output(no_check_set=['XShape']) + self.check_output(no_check_set=['XShape'], check_eager=True) def test_check_grad(self): - self.check_grad(["X"], "Out") + self.check_grad(["X"], "Out", check_eager=True) def init_test_case(self): self.ori_shape = (1, 3, 1, 40) diff --git a/python/paddle/fluid/tests/unittests/test_unsqueeze2_op.py b/python/paddle/fluid/tests/unittests/test_unsqueeze2_op.py index b75e32f2bad14..af9d3db629581 100755 --- a/python/paddle/fluid/tests/unittests/test_unsqueeze2_op.py +++ b/python/paddle/fluid/tests/unittests/test_unsqueeze2_op.py @@ -29,6 +29,8 @@ class TestUnsqueezeOp(OpTest): def setUp(self): self.init_test_case() self.op_type = "unsqueeze2" + self.python_api = paddle.unsqueeze + self.python_out_sig = ["Out"] self.inputs = {"X": np.random.random(self.ori_shape).astype("float64")} self.init_attrs() self.outputs = { @@ -37,10 +39,10 @@ def setUp(self): } def test_check_output(self): - self.check_output(no_check_set=["XShape"]) + self.check_output(no_check_set=["XShape"], check_eager=True) def test_check_grad(self): - self.check_grad(["X"], "Out") + self.check_grad(["X"], "Out", check_eager=True) def init_test_case(self): self.ori_shape = (3, 40) @@ -88,6 +90,8 @@ class TestUnsqueezeOp_AxesTensorList(OpTest): def setUp(self): self.init_test_case() self.op_type = "unsqueeze2" + self.python_out_sig = ["Out"] + self.python_api = paddle.unsqueeze axes_tensor_list = [] for index, ele in enumerate(self.axes): @@ -105,10 +109,10 @@ def setUp(self): } def test_check_output(self): - self.check_output(no_check_set=["XShape"]) + self.check_output(no_check_set=["XShape"], check_eager=True) def test_check_grad(self): - self.check_grad(["X"], "Out") + self.check_grad(["X"], "Out", check_eager=True) def init_test_case(self): self.ori_shape = (20, 5) @@ -152,6 +156,8 @@ class TestUnsqueezeOp_AxesTensor(OpTest): def setUp(self): self.init_test_case() self.op_type = "unsqueeze2" + self.python_out_sig = 
["Out"] + self.python_api = paddle.unsqueeze self.inputs = { "X": np.random.random(self.ori_shape).astype("float64"), @@ -164,10 +170,10 @@ def setUp(self): } def test_check_output(self): - self.check_output(no_check_set=["XShape"]) + self.check_output(no_check_set=["XShape"], check_eager=True) def test_check_grad(self): - self.check_grad(["X"], "Out") + self.check_grad(["X"], "Out", check_eager=True) def init_test_case(self): self.ori_shape = (20, 5) diff --git a/python/paddle/nn/functional/activation.py b/python/paddle/nn/functional/activation.py index 6134badd79232..66c50d16e7201 100644 --- a/python/paddle/nn/functional/activation.py +++ b/python/paddle/nn/functional/activation.py @@ -23,7 +23,7 @@ import warnings from ...fluid.layer_helper import LayerHelper from ...fluid.framework import convert_np_dtype_to_dtype_ -from ...fluid.framework import _in_legacy_dygraph, in_dygraph_mode +from ...fluid.framework import _in_legacy_dygraph, in_dygraph_mode, _non_static_mode from ...fluid.data_feeder import check_variable_and_dtype, check_dtype import paddle from paddle import _C_ops, in_dynamic_mode @@ -519,7 +519,9 @@ def prelu(x, weight, data_format="NCHW", name=None): 1], "The weight size should be equal to x input channel in prelu() when weight shape is not [1]." mode = 'channel' - if in_dynamic_mode(): + if in_dygraph_mode(): + return _C_ops.final_state_prelu(x, weight, data_format, mode) + if _in_legacy_dygraph(): return _C_ops.prelu(x, weight, 'mode', mode, 'data_format', data_format) helper = LayerHelper('prelu', **locals()) @@ -578,9 +580,10 @@ def relu_(x, name=None): Inplace version of ``relu`` API, the output Tensor will be inplaced with input ``x``. Please refer to :ref:`api_nn_cn_relu`. """ - if paddle.fluid.framework._in_eager_mode_: + if in_dygraph_mode(): return _C_ops.final_state_relu_(x) - return _C_ops.relu_(x) + if _in_legacy_dygraph(): + return _C_ops.relu_(x) def log_sigmoid(x, name=None): @@ -1092,7 +1095,9 @@ def softshrink(x, threshold=0.5, name=None): "The threshold must be no less than zero. 
Received: {}.".format( threshold)) - if in_dynamic_mode(): + if in_dygraph_mode(): + return _C_ops.final_state_soft_shrink(x, threshold) + if _in_legacy_dygraph(): return _C_ops.softshrink(x, 'lambda', threshold) check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], @@ -1371,10 +1376,12 @@ def log_softmax(x, axis=-1, dtype=None, name=None): if (dtype is not None) and (not isinstance(dtype, core.VarDesc.VarType)): dtype = convert_np_dtype_to_dtype_(dtype) - if in_dynamic_mode(): + if _non_static_mode(): if dtype is not None: x = _C_ops.cast(x, 'in_dtype', x.dtype, 'out_dtype', dtype) - return _C_ops.log_softmax(x, 'axis', axis) + if _in_legacy_dygraph(): + return _C_ops.log_softmax(x, 'axis', axis) + return _C_ops.final_state_log_softmax(x, axis) if dtype is None: check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], diff --git a/python/paddle/nn/functional/common.py b/python/paddle/nn/functional/common.py index d988d1653ca69..131d31aa02405 100644 --- a/python/paddle/nn/functional/common.py +++ b/python/paddle/nn/functional/common.py @@ -38,6 +38,7 @@ from paddle.framework import in_dynamic_mode from paddle.tensor.creation import full from paddle.framework import core +from paddle.fluid.framework import _in_legacy_dygraph from paddle.static import default_main_program __all__ = [] @@ -1352,8 +1353,11 @@ def pad(x, pad, mode='constant', value=0, data_format="NCHW", name=None): if in_dynamic_mode(): if isinstance(pad, Variable): pad = pad.numpy() - out = _C_ops.pad3d(x, "paddings", pad, "mode", mode, "value", value, - "data_format", data_format, "name", name) + if _in_legacy_dygraph(): + out = _C_ops.pad3d(x, "paddings", pad, "mode", mode, "value", value, + "data_format", data_format, "name", name) + else: + out = _C_ops.final_state_pad3d(x, pad, mode, value, data_format) else: attrs = {'mode': mode, 'value': value, 'data_format': data_format} inputs = {'X': [x]} diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py index ca5629aab6790..3748a5904ba96 100755 --- a/python/paddle/nn/functional/loss.py +++ b/python/paddle/nn/functional/loss.py @@ -921,8 +921,11 @@ def kl_div(input, label, reduction='mean', name=None): label.dtype) == 'float32': label = paddle.cast(label, 'float64') - if paddle.in_dynamic_mode(): - out = _C_ops.kldiv_loss(input, label, 'reduction', 'none') + if _non_static_mode(): + if _in_legacy_dygraph(): + out = _C_ops.kldiv_loss(input, label, 'reduction', 'none') + else: + out = _C_ops.final_state_kldiv_loss(input, label, 'none') if reduction == 'mean': out = paddle.mean(out) elif reduction == 'sum': diff --git a/python/paddle/nn/functional/norm.py b/python/paddle/nn/functional/norm.py index 536c611d85f28..3f7e819f442c1 100644 --- a/python/paddle/nn/functional/norm.py +++ b/python/paddle/nn/functional/norm.py @@ -24,6 +24,7 @@ import numbers from paddle import _C_ops from paddle import in_dynamic_mode +from paddle.fluid.framework import in_dygraph_mode, _in_legacy_dygraph __all__ = [] @@ -78,7 +79,12 @@ def normalize(x, p=2, axis=1, epsilon=1e-12, name=None): # [[0. 0.24253564 0.37139067] # [1. 
0.97014254 0.9284767 ]] """ - if in_dynamic_mode(): + if in_dygraph_mode(): + eps = fluid.dygraph.base.to_variable([epsilon], dtype=x.dtype) + out = _C_ops.final_state_p_norm(x, float(p), axis, epsilon, True, False) + return x / _C_ops.elementwise_max(out, eps) + + if _in_legacy_dygraph(): eps = fluid.dygraph.base.to_variable([epsilon], dtype=x.dtype) out = _C_ops.p_norm(x, 'axis', axis, 'porder', float(p), 'keepdim', True, 'epsilon', epsilon) diff --git a/python/paddle/nn/layer/distance.py b/python/paddle/nn/layer/distance.py index 1fb7e8c4f2148..eb85de5711078 100644 --- a/python/paddle/nn/layer/distance.py +++ b/python/paddle/nn/layer/distance.py @@ -20,6 +20,7 @@ from ...fluid.layer_helper import LayerHelper from paddle import _C_ops from paddle import in_dynamic_mode +from paddle.fluid.framework import in_dygraph_mode, _in_legacy_dygraph __all__ = [] @@ -78,7 +79,12 @@ def __init__(self, p=2., epsilon=1e-6, keepdim=False, name=None): check_type(self.keepdim, 'keepdim', (bool), 'PairwiseDistance') def forward(self, x, y): - if in_dynamic_mode(): + if in_dygraph_mode(): + sub = _C_ops.elementwise_sub(x, y) + return _C_ops.final_state_p_norm(sub, self.p, 1, self.epsilon, + self.keepdim, False) + + if _in_legacy_dygraph(): sub = _C_ops.elementwise_sub(x, y) return _C_ops.p_norm(sub, 'axis', 1, 'porder', self.p, 'keepdim', self.keepdim, 'epsilon', self.epsilon) diff --git a/python/paddle/tensor/linalg.py b/python/paddle/tensor/linalg.py index 7c4c8a9b793c9..818ce2f5c6757 100644 --- a/python/paddle/tensor/linalg.py +++ b/python/paddle/tensor/linalg.py @@ -288,10 +288,16 @@ def vector_norm(input, axis (int, optional): None for last dimension. keepdim (bool, optional): Whether keep the dimensions as the `input`, Default False. """ - if paddle.in_dynamic_mode(): + if in_dygraph_mode(): + if axis is None: axis = -1 + return _C_ops.final_state_p_norm(input, porder, axis, 1e-12, + keepdim, asvector) + + if _in_legacy_dygraph(): if axis is None: axis = -1 return _C_ops.p_norm(input, 'porder', porder, 'axis', axis, 'keepdim', keepdim, 'asvector', asvector) + if porder is not None: check_type(porder, 'porder', (float, int), 'p_norm') if axis is not None: diff --git a/python/paddle/tensor/logic.py b/python/paddle/tensor/logic.py index 3c02c11b933c1..e3ffd36d77972 100755 --- a/python/paddle/tensor/logic.py +++ b/python/paddle/tensor/logic.py @@ -122,11 +122,12 @@ def allclose(x, y, rtol=1e-05, atol=1e-08, equal_nan=False, name=None): # [True] """ - if paddle.in_dynamic_mode(): + if in_dygraph_mode(): + return _C_ops.final_state_allclose(x, y, rtol, atol, equal_nan) + if _in_legacy_dygraph(): return _C_ops.allclose(x, y, 'rtol', str(rtol), 'atol', str(atol), 'equal_nan', equal_nan) - check_variable_and_dtype(x, "input", ['float32', 'float64'], 'allclose') check_variable_and_dtype(y, "input", ['float32', 'float64'], 'allclose') check_type(rtol, 'rtol', float, 'allclose') @@ -678,7 +679,9 @@ def isclose(x, y, rtol=1e-05, atol=1e-08, equal_nan=False, name=None): # [True, True] """ - if paddle.in_dynamic_mode(): + if in_dygraph_mode(): + return _C_ops.final_state_isclose(x, y, rtol, atol, equal_nan) + if _in_legacy_dygraph(): return _C_ops.isclose(x, y, 'rtol', str(rtol), 'atol', str(atol), 'equal_nan', equal_nan) diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index 01836eaed09c9..9fe3304bf2471 100755 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -1409,7 +1409,9 @@ def gather(x, index, axis=None, name=None): if axis is 
None: axis = 0 - if paddle.in_dynamic_mode(): + #if in_dygraph_mode(): + #return _C_ops.final_state_gather(x, index, axis) + if _non_static_mode(): axis = axis.item() if isinstance(axis, paddle.Tensor) else axis return _C_ops.gather(x, index, None, "axis", axis, "overwrite", False) diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index 10de77a44a910..e932595fc378e 100755 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -28,7 +28,7 @@ import paddle from paddle.static import Variable from ..framework import core -from ..fluid.framework import _in_legacy_dygraph, in_dygraph_mode +from ..fluid.framework import _in_legacy_dygraph, in_dygraph_mode, _non_static_mode from ..framework import _varbase_creator, convert_np_dtype_to_dtype_ from ..fluid.layer_helper import LayerHelper from ..fluid.data_feeder import check_variable_and_dtype, check_type, check_dtype, convert_dtype @@ -150,7 +150,17 @@ def pow(x, y, name=None): """ # in dynamic graph mode - if paddle.in_dynamic_mode(): + #if in_dygraph_mode(): + #if isinstance(y, (int, float)): + #return _C_ops.final_state_pow(x, y) + #elif isinstance(y, (paddle.Tensor, Variable)): + #return _elementwise_op_in_dygraph( + #x, y, axis=-1, act=None, op_name='elementwise_pow') + #else: + #raise TypeError('y must be scalar or tensor type, but received: %s '% (y.dtype)) + + #if _in_legacy_dygraph(): + if _non_static_mode(): if isinstance(y, (int, float)): return _C_ops.pow(x, 'factor', y) elif isinstance(y, (paddle.Tensor, Variable)): @@ -719,7 +729,9 @@ def fmax(x, y, name=None): op_type = 'elementwise_fmax' axis = -1 act = None - if paddle.in_dynamic_mode(): + if in_dygraph_mode(): + return _C_ops.final_state_fmax(x, y, axis) + if _in_legacy_dygraph(): return _elementwise_op_in_dygraph( x, y, axis=axis, act=act, op_name=op_type) return _elementwise_op(LayerHelper(op_type, **locals())) @@ -780,7 +792,9 @@ def fmin(x, y, name=None): op_type = 'elementwise_fmin' axis = -1 act = None - if paddle.in_dynamic_mode(): + if in_dygraph_mode(): + return _C_ops.final_state_fmin(x, y, axis) + if _in_legacy_dygraph(): return _elementwise_op_in_dygraph( x, y, axis=axis, act=act, op_name=op_type) return _elementwise_op(LayerHelper(op_type, **locals())) @@ -1711,7 +1725,11 @@ def max(x, axis=None, keepdim=False, name=None): """ reduce_all, axis = _get_reduce_all_value(axis) - if paddle.in_dynamic_mode(): + if in_dygraph_mode(): + if reduce_all: + axis = range(len(x.shape)) + return _C_ops.final_state_max(x, axis, keepdim) + if _in_legacy_dygraph(): return _C_ops.reduce_max(x, 'dim', axis, 'keep_dim', keepdim, 'reduce_all', reduce_all) @@ -1811,7 +1829,12 @@ def min(x, axis=None, keepdim=False, name=None): """ reduce_all, axis = _get_reduce_all_value(axis) - if paddle.in_dynamic_mode(): + if in_dygraph_mode(): + if reduce_all: + axis = range(len(x.shape)) + return _C_ops.final_state_min(x, axis, keepdim) + + if _in_legacy_dygraph(): return _C_ops.reduce_min(x, 'dim', axis, 'keep_dim', keepdim, 'reduce_all', reduce_all) @@ -2081,7 +2104,9 @@ def log1p(x, name=None): # [[0.], [0.6931472]] """ - if paddle.in_dynamic_mode(): + if in_dygraph_mode(): + return _C_ops.final_state_log1p(x) + if _in_legacy_dygraph(): return _C_ops.log1p(x) check_variable_and_dtype(x, 'x', ['float32', 'float64'], "log1p") @@ -2130,7 +2155,9 @@ def log2(x, name=None): res = paddle.log2(x_i) print(res) # [1.0] """ - if paddle.in_dynamic_mode(): + if in_dygraph_mode(): + return _C_ops.final_state_log2(x) + if _in_legacy_dygraph(): return _C_ops.log2(x) 
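Every Python-layer hunk in this patch replaces the single `paddle.in_dynamic_mode()` branch with the same three-way ladder: new eager mode dispatches to the code-generated `final_state_*` kernel, legacy dygraph keeps the old attribute-string `_C_ops` call, and everything else falls through to the static-graph builder. A condensed sketch of that ladder, modeled on the log2 hunk above (the static-graph branch is elided):

from paddle import _C_ops
from paddle.fluid.framework import in_dygraph_mode, _in_legacy_dygraph


def log2_dispatch(x):
    if in_dygraph_mode():
        # new eager mode: call the code-generated final-state kernel
        return _C_ops.final_state_log2(x)
    if _in_legacy_dygraph():
        # old dygraph: legacy calling convention with attribute strings
        return _C_ops.log2(x)
    # neither dygraph flavor: build the op into the static program, as the
    # check_variable_and_dtype/LayerHelper code below does
    raise NotImplementedError("static-graph path elided in this sketch")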
check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], "log2") @@ -2180,7 +2207,9 @@ def log10(x, name=None): res = paddle.log10(x_i) print(res) # [1.0] """ - if paddle.in_dynamic_mode(): + if in_dygraph_mode(): + return _C_ops.final_state_log10(x) + if _in_legacy_dygraph(): return _C_ops.log10(x) check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], "log10") @@ -2667,7 +2696,9 @@ def cumprod(x, dim=None, dtype=None, name=None): if dtype is not None and x.dtype != convert_np_dtype_to_dtype_(dtype): x = cast(x, dtype) - if paddle.in_dynamic_mode(): + if in_dygraph_mode(): + return _C_ops.final_state_cumprod(x, dim) + if _in_legacy_dygraph(): return _C_ops.cumprod(x, 'dim', dim) check_variable_and_dtype(x, "x", ['complex64', 'complex128', 'float32', 'float64', 'int32', 'int64'], 'cumprod') @@ -3028,7 +3059,12 @@ def all(x, axis=None, keepdim=False, name=None): else: reduce_all_flag = False - if paddle.in_dynamic_mode(): + if in_dygraph_mode(): + if reduce_all_flag: + axis = range(len(x.shape)) + return _C_ops.final_state_all(x, axis, keepdim) + + if _in_legacy_dygraph(): axis = axis if axis != None and axis != [] else [0] return _C_ops.reduce_all(x, 'dim', axis, 'keep_dim', keepdim, 'reduce_all', reduce_all_flag) @@ -3120,7 +3156,12 @@ def any(x, axis=None, keepdim=False, name=None): else: reduce_all_flag = False - if paddle.in_dynamic_mode(): + if in_dygraph_mode(): + if reduce_all_flag: + axis = range(len(x.shape)) + return _C_ops.final_state_any(x, axis, keepdim) + + if _in_legacy_dygraph(): axis = axis if axis != None and axis != [] else [0] return _C_ops.reduce_any(x, 'dim', axis, 'keep_dim', keepdim, 'reduce_all', reduce_all_flag) diff --git a/python/paddle/tensor/search.py b/python/paddle/tensor/search.py index e295431df3389..7a2dd22cff294 100644 --- a/python/paddle/tensor/search.py +++ b/python/paddle/tensor/search.py @@ -518,7 +518,9 @@ def mode(x, axis=-1, keepdim=False, name=None): # [1, 0]])) """ - if paddle.in_dynamic_mode(): + if in_dygraph_mode(): + return _C_ops.final_state_mode(x, axis, keepdim) + if _in_legacy_dygraph(): return _C_ops.mode(x, "axis", axis, "keepdim", keepdim) helper = LayerHelper("mode", **locals()) @@ -1002,11 +1004,16 @@ def kthvalue(x, k, axis=None, keepdim=False, name=None): # [[0, 2], # [1, 2]])) """ - if paddle.in_dynamic_mode(): + if _non_static_mode(): if axis is not None: - return _C_ops.kthvalue(x, 'k', k, "axis", axis, "keepdim", keepdim) + if _in_legacy_dygraph(): + return _C_ops.kthvalue(x, 'k', k, "axis", axis, "keepdim", + keepdim) + return _C_ops.final_state_kthvalue(x, k, axis, keepdim) else: - return _C_ops.kthvalue(x, 'k', k, "keepdim", keepdim) + if _in_legacy_dygraph(): + return _C_ops.kthvalue(x, 'k', k, "keepdim", keepdim) + return _C_ops.final_state_kthvalue(x, k, -1, keepdim) helper = LayerHelper("kthvalue", **locals()) inputs = {"X": [x]} diff --git a/python/paddle/tensor/stat.py b/python/paddle/tensor/stat.py index 5876b9180823e..89462e2a8721f 100644 --- a/python/paddle/tensor/stat.py +++ b/python/paddle/tensor/stat.py @@ -18,6 +18,7 @@ from ..static import Variable from ..fluid.layer_helper import LayerHelper from ..framework import core +from paddle.fluid.framework import _in_legacy_dygraph, in_dygraph_mode from .search import where from ..fluid.data_feeder import convert_dtype, check_variable_and_dtype, check_type, check_dtype import paddle @@ -87,7 +88,11 @@ def mean(x, axis=None, keepdim=False, name=None): if axis is None or len(axis) == 0: axis = [0] - if paddle.in_dynamic_mode(): + if 
in_dygraph_mode(): + if reduce_all: + axis = range(len(x.shape)) + return _C_ops.final_state_mean(x, axis, keepdim) + if _in_legacy_dygraph(): return _C_ops.reduce_mean(x, 'dim', axis, 'keep_dim', keepdim, 'reduce_all', reduce_all) diff --git a/python/paddle/utils/code_gen/api.yaml b/python/paddle/utils/code_gen/api.yaml index da79a928dba7a..ef1e4797874a8 100644 --- a/python/paddle/utils/code_gen/api.yaml +++ b/python/paddle/utils/code_gen/api.yaml @@ -72,6 +72,31 @@ func : addmm backward : addmm_grad +- api : all + args : (Tensor x, int64_t[] dims={}, bool keep_dim=false) + output : Tensor(out) + infer_meta : + func : ReduceInferMeta + kernel : + func : all + +- api : allclose + args : (Tensor x, Tensor y, Scalar rtol, Scalar atol, bool equal_nan) + output : Tensor(out) + infer_meta : + func : AllValueCompareInferMeta + param: [x, y] + kernel : + func : allclose + +- api : any + args : (Tensor x, int64_t[] dims={}, bool keep_dim=false) + output : Tensor(out) + infer_meta : + func : ReduceInferMeta + kernel : + func : any + # arg_max - api : argmax args : (Tensor x, int64_t axis, bool keepdims, bool flatten, int dtype) @@ -235,6 +260,15 @@ data_type : x backward : cast_grad +- api : ceil + args : (Tensor x) + output : Tensor(out) + infer_meta : + func : UnchangedInferMeta + kernel : + func : ceil + backward : ceil_grad + # cholesky - api : cholesky args : (Tensor x, bool upper) @@ -306,6 +340,16 @@ func : cross backward : cross_grad +- api : cumprod + args : (Tensor x, int dim) + output : Tensor(out) + infer_meta : + func : UnchangedInferMeta + param: [x] + kernel : + func : cumprod + backward : cumprod_grad + # cumsum - api : cumsum args : (Tensor x, int axis, bool flatten, bool exclusive, bool reverse) @@ -458,6 +502,35 @@ kernel : func : flip +- api : floor + args : (Tensor x) + output : Tensor(out) + infer_meta : + func : UnchangedInferMeta + kernel : + func : floor + backward : floor_grad + +- api : fmax + args : (Tensor x, Tensor y, int axis) + output : Tensor(out) + infer_meta : + param: [x, y] + func : ElementwiseInferMeta + kernel : + func : fmax + backward : fmax_grad + +- api : fmin + args : (Tensor x, Tensor y, int axis) + output : Tensor(out) + infer_meta : + param: [x, y] + func : ElementwiseInferMeta + kernel : + func : fmin + backward : fmin_grad + - api : full args : (IntArray shape, Scalar value, DataType dtype=DataType::FLOAT32, Place place=CPUPlace()) output: Tensor @@ -500,6 +573,16 @@ kernel : func : gather_tree +- api : gelu + args : (Tensor x, bool approximate) + output : Tensor(out) + infer_meta : + func : UnchangedInferMeta + param: [x] + kernel : + func : gelu + backward : gelu_grad + - api : greater args : (Tensor x, Tensor y, int axis = -1) output : Tensor @@ -594,6 +677,15 @@ kernel : func : is_empty +- api : isclose + args : (Tensor x, Tensor y, Scalar rtol, Scalar atol, bool equal_nan) + output : Tensor(out) + infer_meta : + func : ValueCompareInferMeta + param: [x, y] + kernel : + func : isclose + # isfinite - api : isfinite args : (Tensor x) @@ -621,6 +713,25 @@ kernel : func : isnan, isnan_sr +- api : kldiv_loss + args : (Tensor x, Tensor label, str reduction) + output : Tensor(out) + infer_meta : + func : KLDivInferMeta + kernel : + func : kldiv_loss + data_type : x + backward : kldiv_loss_grad + +- api : kthvalue + args : (Tensor x, int k, int axis, bool keepdim) + output : Tensor(out), Tensor(indices) + infer_meta : + func : KthvalueInferMeta + kernel : + func : kthvalue + backward : kthvalue_grad + # leaky_relu - api : leaky_relu args : (Tensor x, 
float alpha) @@ -657,6 +768,51 @@ kernel : func : less_than +- api : lgamma + args : (Tensor x) + output : Tensor(out) + infer_meta : + func : UnchangedInferMeta + kernel : + func : lgamma + backward : lgamma_grad + +- api : log + args : (Tensor x) + output : Tensor + infer_meta : + func : UnchangedInferMeta + kernel : + func : log + backward: log_grad + +- api : log10 + args : (Tensor x) + output : Tensor + infer_meta : + func : UnchangedInferMeta + kernel : + func : log10 + backward: log10_grad + +- api : log1p + args : (Tensor x) + output : Tensor + infer_meta : + func : UnchangedInferMeta + kernel : + func : log1p + backward: log1p_grad + +- api : log2 + args : (Tensor x) + output : Tensor + infer_meta : + func : UnchangedInferMeta + kernel : + func : log2 + backward: log2_grad + # log_loss - api : log_loss args : (Tensor input, Tensor label, float epsilon) @@ -667,6 +823,15 @@ func : log_loss backward : log_loss_grad +- api : log_softmax + args : (Tensor x, int axis) + output : Tensor(out) + infer_meta : + func : UnchangedInferMetaCheckAxis + kernel : + func : log_softmax + backward : log_softmax_grad + # logical_and - api : logical_and args : (Tensor x, Tensor y) @@ -744,6 +909,15 @@ func : matrix_power backward : matrix_power_grad +- api : max + args : (Tensor x, int64_t[] dims={}, bool keep_dim=false) + output : Tensor(out) + infer_meta : + func : ReduceInferMeta + kernel : + func : max + backward : max_grad + - api : maximum args : (Tensor x, Tensor y) output : Tensor(out) @@ -754,12 +928,22 @@ backward : maximum_grad - api : mean - args : (Tensor x, int64_t[] axis={}, bool keep_dim=false) - output : Tensor + args : (Tensor x, int64_t[] dims={}, bool keep_dim=false) + output : Tensor(out) infer_meta : func : ReduceInferMeta kernel : func : mean + backward : mean_grad + +- api : min + args : (Tensor x, int64_t[] dims={}, bool keep_dim=false) + output : Tensor(out) + infer_meta : + func : ReduceInferMeta + kernel : + func : min + backward : min_grad - api : minimum args : (Tensor x, Tensor y) @@ -770,6 +954,15 @@ func : minimum backward : minimum_grad +- api : mode + args : (Tensor x, int axis, bool keepdim) + output : Tensor(out), Tensor(indices) + infer_meta : + func : ModeInferMeta + kernel : + func : mode + backward : mode_grad + - api : modulo args : (Tensor x, Tensor y) output : Tensor @@ -838,6 +1031,15 @@ output : Tensor invoke : full_like(x, 1, dtype, place) +- api : p_norm + args : (Tensor x, float porder, int axis, float epsilon, bool keepdim, bool asvector=false) + output : Tensor(out) + infer_meta : + func : PNormInferMeta + kernel : + func : p_norm + backward : p_norm_grad + # pad - api : pad args : (Tensor x, int[] paddings, float pad_value) @@ -848,6 +1050,15 @@ func : pad # backward : pad_grad +- api : pad3d + args : (Tensor x, IntArray paddings, str mode, float pad_value, str data_format) + output : Tensor(out) + infer_meta : + func : Pad3dInferMeta + kernel : + func : pad3d + backward : pad3d_grad + # pixel_shuffle - api : pixel_shuffle args : (Tensor x, int upscale_factor, str data_format) @@ -875,6 +1086,15 @@ kernel: func : pool2d +- api : prelu + args : (Tensor x, Tensor alpha, str data_format, str mode) + output : Tensor(out) + infer_meta : + func : PReluInferMeta + kernel : + func : prelu + backward : prelu_grad + # put_along_axis - api : put_along_axis args : (Tensor x, Tensor index, Tensor value, int axis, str reduce) @@ -927,6 +1147,15 @@ intermediate : xshape backward: reshape_grad +- api : round + args : (Tensor x) + output : Tensor(out) + infer_meta : 
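+# Each entry added to api.yaml follows the same scheme: `args` fixes the
+# signature exposed to Python as _C_ops.final_state_<api>, `infer_meta`
+# names the shape/dtype inference function, `kernel` names the phi kernel,
+# and `backward` links the grad entry declared in backward.yaml below.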
+ func : UnchangedInferMeta + kernel : + func : round + backward : round_grad + - api : scale args : (Tensor x, Scalar scale, float bias, bool bias_after_scale) output : Tensor @@ -1107,6 +1336,16 @@ func : square backward : square_grad +- api : squeeze + args : (Tensor x, int[] axes) + output : Tensor(xshape), Tensor(out) + infer_meta : + func : SqueezeInferMeta + kernel : + func : squeeze + view: (x -> out) + backward : squeeze_grad + - api : strided_slice args : (Tensor x, int[] axes, IntArray starts, IntArray ends, IntArray strides) output : Tensor @@ -1256,6 +1495,16 @@ backward : unfold_grad # no_need_buffer : x +- api : unsqueeze + args : (Tensor x, IntArray axes) + output : Tensor(xshape), Tensor(out) + infer_meta : + func : UnsqueezeInferMeta + kernel : + func : unsqueeze + view: (x -> out) + backward : unsqueeze_grad + # viterbi_decode - api : viterbi_decode args : (Tensor input, Tensor transition, Tensor length, bool include_bos_eos_tag) diff --git a/python/paddle/utils/code_gen/backward.yaml b/python/paddle/utils/code_gen/backward.yaml index dc7261eef1650..a59b02c34cf76 100644 --- a/python/paddle/utils/code_gen/backward.yaml +++ b/python/paddle/utils/code_gen/backward.yaml @@ -142,6 +142,16 @@ func : cast_grad data_type : out_grad +- backward_api : ceil_grad + forward : ceil(Tensor x) -> Tensor(out) + args : (Tensor out_grad) + output : Tensor(x_grad) + infer_meta : + func : UnchangedInferMeta + param: [out_grad] + kernel : + func : ceil_grad + - backward_api : cholesky_grad forward : cholesky (Tensor x, bool upper) -> Tensor(out) args : (Tensor out, Tensor out_grad, bool upper) @@ -192,6 +202,25 @@ kernel : func : cross_grad +- backward_api : cumprod_grad + forward : cumprod (Tensor x, int dim) -> Tensor(out) + args : (Tensor x, Tensor out, Tensor out_grad, int dim) + output : Tensor(x_grad) + infer_meta : + func : UnchangedInferMeta + param: [x] + kernel : + func : cumprod_grad +# - backward_api : gumbel_softmax_grad +# forward : gumbel_softmax (Tensor x, float temperature, bool hard, int axis) -> Tensor(out) +# args : (Tensor out, Tensor out_grad, int axis) +# output : Tensor(x_grad) +# infer_meta : +# func : GumbelSoftmaxGradInferMeta +# param : [out, out_grad, axis] +# kernel : +# func : gumbel_softmax_grad + - backward_api : diagonal_grad forward : diagonal (Tensor x, int offset, int axis1, int axis2) -> Tensor(out) args : (Tensor x, Tensor out_grad, int offset = 0, int axis1 = 0, int axis2 = 1) @@ -273,6 +302,36 @@ kernel : func : erfinv_grad +- backward_api : floor_grad + forward : floor(Tensor x) -> Tensor(out) + args : (Tensor out_grad) + output : Tensor(x_grad) + infer_meta : + func : UnchangedInferMeta + param: [out_grad] + kernel : + func : floor_grad + +- backward_api : fmax_grad + forward : fmax(Tensor x, Tensor y, int axis) -> Tensor(out) + args : (Tensor x, Tensor y, Tensor out_grad, int axis) + output : Tensor(x_grad), Tensor(y_grad) + infer_meta : + func : GeneralBinaryGradInferMeta + param: [x, y] + kernel : + func : fmax_grad + +- backward_api : fmin_grad + forward : fmin(Tensor x, Tensor y, int axis) -> Tensor(out) + args : (Tensor x, Tensor y, Tensor out_grad, int axis) + output : Tensor(x_grad), Tensor(y_grad) + infer_meta : + func : GeneralBinaryGradInferMeta + param: [x, y] + kernel : + func : fmin_grad + - backward_api : gather_nd_grad forward : gather_nd (Tensor x, Tensor index) -> Tensor(out) args : (Tensor x, Tensor index, Tensor out_grad) @@ -283,6 +342,16 @@ kernel : func : gather_nd_grad +- backward_api : gelu_grad + forward : gelu(Tensor x, 
bool approximate) -> Tensor(out) + args : (Tensor x, Tensor out_grad, bool approximate) + output : Tensor(x_grad) + infer_meta : + func : UnchangedInferMeta + param: [x] + kernel : + func : gelu_grad + - backward_api : hard_shrink_grad forward : hard_shrink (Tensor x, float threshold) -> Tensor(out) args : (Tensor x, Tensor out_grad, float threshold) @@ -314,6 +383,26 @@ func : index_sample_grad data_type : out_grad +- backward_api : kldiv_loss_grad + forward : kldiv_loss(Tensor x, Tensor label, str reduction) -> Tensor(out) + args : (Tensor x, Tensor label, Tensor out_grad, str reduction) + output : Tensor(x_grad) + infer_meta : + func : UnchangedInferMeta + param: [x] + kernel : + func : kldiv_loss_grad + +- backward_api : kthvalue_grad + forward : kthvalue(Tensor x, int k, int axis, bool keepdim) -> Tensor(out), Tensor(indices) + args : (Tensor x, Tensor indices, Tensor out_grad, int k, int axis, bool keepdim) + output : Tensor(x_grad) + infer_meta : + func : UnchangedInferMeta + param: [x] + kernel : + func : kthvalue_grad + - backward_api : label_smooth_grad forward : label_smooth (Tensor label, Tensor prior_dist, float epsilon) -> Tensor(out) args : (Tensor out_grad, float epsilon) @@ -345,6 +434,56 @@ kernel : func : lerp_grad +- backward_api : lgamma_grad + forward : lgamma(Tensor x) -> Tensor(out) + args : (Tensor x, Tensor out_grad) + output : Tensor(x_grad) + infer_meta : + func : UnchangedInferMeta + param: [x] + kernel : + func : lgamma_grad + +- backward_api : log10_grad + forward : log10 (Tensor x) -> Tensor(out) + args : (Tensor x, Tensor out_grad) + output : Tensor(x_grad) + infer_meta : + func : UnchangedInferMeta + param : [x] + kernel : + func : log10_grad + +- backward_api : log1p_grad + forward : log1p (Tensor x) -> Tensor(out) + args : (Tensor x, Tensor out_grad) + output : Tensor(x_grad) + infer_meta : + func : UnchangedInferMeta + param : [x] + kernel : + func : log1p_grad + +- backward_api : log2_grad + forward : log2 (Tensor x) -> Tensor(out) + args : (Tensor x, Tensor out_grad) + output : Tensor(x_grad) + infer_meta : + func : UnchangedInferMeta + param : [x] + kernel : + func : log2_grad + +- backward_api : log_grad + forward : log (Tensor x) -> Tensor(out) + args : (Tensor x, Tensor out_grad) + output : Tensor(x_grad) + infer_meta : + func : UnchangedInferMeta + param : [x] + kernel : + func : log_grad + - backward_api : log_loss_grad forward : log_loss (Tensor input, Tensor label, float epsilon) -> Tensor(out) args : (Tensor input, Tensor label, Tensor out_grad, float epsilon) @@ -355,6 +494,16 @@ kernel : func : log_loss_grad +- backward_api : log_softmax_grad + forward : log_softmax(Tensor x, int axis) -> Tensor(out) + args : (Tensor out, Tensor out_grad, int axis) + output : Tensor(x_grad) + infer_meta : + func : UnchangedInferMeta + param: [out] + kernel : + func : log_softmax_grad + - backward_api : logsigmoid_grad forward : logsigmoid (Tensor x) -> Tensor(out) args : (Tensor x, Tensor out_grad) @@ -408,6 +557,16 @@ kernel : func : matrix_power_grad +- backward_api : max_grad + forward: max (Tensor x, int64_t[] dims={}, bool keep_dim=false) -> Tensor(out) + args : (Tensor x, Tensor out, Tensor out_grad, int64_t[] dims={}, bool keep_dim=false, bool reduce_all=false) + output : Tensor(x_grad) + infer_meta : + func : UnchangedInferMeta + param: [x] + kernel : + func : max_grad + - backward_api : maximum_grad forward : maximum(Tensor x, Tensor y) -> Tensor(out) args : (Tensor x, Tensor y, Tensor out_grad, int axis=-1) @@ -418,6 +577,26 @@ kernel : func : 
maximum_grad +- backward_api : mean_grad + forward: mean (Tensor x, int64_t[] dims={}, bool keep_dim=false) -> Tensor(out) + args : (Tensor x, Tensor out_grad, int64_t[] dims={}, bool keep_dim=false, bool reduce_all=false) + output : Tensor(x_grad) + infer_meta : + func : UnchangedInferMeta + param: [x] + kernel : + func : mean_grad + +- backward_api : min_grad + forward: min (Tensor x, int64_t[] dims={}, bool keep_dim=false) -> Tensor(out) + args : (Tensor x, Tensor out, Tensor out_grad, int64_t[] dims={}, bool keep_dim=false, bool reduce_all=false) + output : Tensor(x_grad) + infer_meta : + func : UnchangedInferMeta + param: [x] + kernel : + func : min_grad + - backward_api : minimum_grad forward : minimum(Tensor x, Tensor y) -> Tensor(out) args : (Tensor x, Tensor y, Tensor out_grad, int axis=-1) @@ -428,6 +607,16 @@ kernel : func : minimum_grad +- backward_api : mode_grad + forward : mode(Tensor x, int axis, bool keepdim) -> Tensor(out), Tensor(indices) + args : (Tensor x, Tensor indices, Tensor out_grad, int axis, bool keepdim) + output : Tensor(x_grad) + infer_meta : + func : UnchangedInferMeta + param: [x] + kernel : + func : mode_grad + - backward_api : modulo_grad forward : add (Tensor x, Tensor y) -> Tensor(out) args : (Tensor x, Tensor y, Tensor out_grad, int axis = -1) @@ -470,6 +659,36 @@ data_type : input optional : weight +- backward_api : p_norm_grad + forward : p_norm(Tensor x, float porder, int axis, float epsilon, bool keepdim, bool asvector=false) -> Tensor(out) + args : (Tensor x, Tensor out, Tensor out_grad, float porder, int axis, float epsilon, bool keepdim, bool asvector) + output : Tensor(x_grad) + infer_meta : + func : UnchangedInferMeta + param: [x] + kernel : + func : p_norm_grad + +- backward_api : pad3d_grad + forward : pad3d(Tensor x, IntArray paddings, str mode, float pad_value, str data_format) -> Tensor(out) + args : (Tensor x, Tensor out_grad, IntArray paddings, str mode, float pad_value, str data_format) + output : Tensor(x_grad) + infer_meta : + func : UnchangedInferMeta + param: [x] + kernel : + func : pad3d_grad + +- backward_api : prelu_grad + forward : prelu(Tensor x, Tensor alpha, str data_format, str mode) -> Tensor(out) + args : (Tensor x, Tensor alpha, Tensor out_grad, str data_format, str mode) + output : Tensor(x_grad), Tensor(alpha_grad) + infer_meta : + func : GeneralBinaryGradInferMeta + param: [x, alpha] + kernel : + func : prelu_grad + - backward_api : psroi_pool_grad forward : psroi_pool (Tensor x, Tensor rois, Tensor rois_num, int pooled_weight, int pooled_width, int output_channels, float spatial_scale ) -> Tensor(out) args : (Tensor x, Tensor rois, Tensor rois_num, Tensor out_grad, int pooled_weight, int pooled_width, int output_channels, float spatial_scale) @@ -537,6 +756,16 @@ backend: out_grad layout: out_grad +- backward_api : round_grad + forward : round(Tensor x) -> Tensor(out) + args : (Tensor out_grad) + output : Tensor(x_grad) + infer_meta : + func : UnchangedInferMeta + param: [out_grad] + kernel : + func : round_grad + - backward_api : scale_grad forward : scale (Tensor x, Scalar scale, float bias, bool bias_after_scale) -> Tensor(out) args : (Tensor out_grad, Scalar scale, float bias=0.0, bool bias_after_scale=true) @@ -680,6 +909,16 @@ kernel : func : square_grad +- backward_api : squeeze_grad + forward : squeeze(Tensor x, int[] axes) -> Tensor(xshape), Tensor(out) + args : (Tensor xshape, Tensor out_grad, int[] axes) + output : Tensor(x_grad) + infer_meta : + func : KernelWithXShapeInferMeta + param: [xshape] + kernel 
:
+    func : squeeze_grad
+
 - backward_api : strided_slice_grad
   forward : strided_slice (Tensor x, int[] axes, IntArray starts, IntArray ends, IntArray strides) -> Tensor(out)
   args : (Tensor x, Tensor out_grad, int[] axes, IntArray starts, IntArray ends, IntArray strides)
@@ -810,6 +1049,16 @@
   kernel :
     func : unfold_grad

+- backward_api : unsqueeze_grad
+  forward : unsqueeze(Tensor x, IntArray axes) -> Tensor(xshape), Tensor(out)
+  args : (Tensor xshape, Tensor out_grad)
+  output : Tensor(x_grad)
+  infer_meta :
+    func : KernelWithXShapeInferMeta
+    param: [xshape]
+  kernel :
+    func : unsqueeze_grad
+
 - backward_api : where_grad
   forward : where (Tensor condition, Tensor x, Tensor y) -> Tensor(out)
   args : (Tensor condition, Tensor x, Tensor y, Tensor out_grad)

From 3b686b189e81f57455abb6737b581d306987bbae Mon Sep 17 00:00:00 2001
From: Zhang Zheng <32410583+ZzSean@users.noreply.github.com>
Date: Sat, 2 Apr 2022 10:30:54 +0800
Subject: [PATCH 05/93] Limit the condition of entering optimized kernel
 (#41296)

Co-authored-by: root
---
 paddle/phi/kernels/gpu/top_k_kernel.cu | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/paddle/phi/kernels/gpu/top_k_kernel.cu b/paddle/phi/kernels/gpu/top_k_kernel.cu
index adaf5cc092b4e..8262023826b32 100644
--- a/paddle/phi/kernels/gpu/top_k_kernel.cu
+++ b/paddle/phi/kernels/gpu/top_k_kernel.cu
@@ -98,7 +98,7 @@ void TopkKernel(const Context& dev_ctx,
   }

 #if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 9000
-    if (input_width >= 1024 && input_height == 1) {
+    if (input_width >= 1024 && in_dims.size() == 1) {
       // 1. Gather TopK, but without sorting
       constexpr int max_num_threads = 1024;
       if (largest) {

From acec26a1f3e6b85c78f293e0418857ddd34df0c8 Mon Sep 17 00:00:00 2001
From: taixiurong
Date: Sat, 2 Apr 2022 10:52:36 +0800
Subject: [PATCH 06/93] xpu add dropout&cast unittest (#41120)

---
 paddle/fluid/operators/dropout_op_xpu.cc      |   8 +-
 .../fluid/tests/unittests/op_test_xpu.py      |  49 +++-
 .../tests/unittests/xpu/test_cast_op_xpu.py   |  38 ++-
 .../unittests/xpu/test_dropout_op_xpu.py      | 274 ++++++++++++------
 4 files changed, 259 insertions(+), 110 deletions(-)

diff --git a/paddle/fluid/operators/dropout_op_xpu.cc b/paddle/fluid/operators/dropout_op_xpu.cc
index 7d8660f238abc..851f26ee0e717 100644
--- a/paddle/fluid/operators/dropout_op_xpu.cc
+++ b/paddle/fluid/operators/dropout_op_xpu.cc
@@ -42,7 +42,13 @@ class DropoutXPUKernel : public framework::OpKernel {
     if (!context.Attr("is_test")) {
       int seed_data = 0;
       if (seed) {
-        seed_data = *(seed->data());
+        if (platform::is_xpu_place(seed->place())) {
+          memory::Copy(platform::CPUPlace(), &seed_data, seed->place(),
+                       seed->data(), sizeof(int));
+        } else {
+          seed_data = *(seed->data());
+        }
+
       } else {
         seed_data = context.Attr("fix_seed") ? 
context.Attr("seed") : 0; diff --git a/python/paddle/fluid/tests/unittests/op_test_xpu.py b/python/paddle/fluid/tests/unittests/op_test_xpu.py index 107f340d3a847..4a67af02bcff3 100644 --- a/python/paddle/fluid/tests/unittests/op_test_xpu.py +++ b/python/paddle/fluid/tests/unittests/op_test_xpu.py @@ -54,13 +54,11 @@ def tearDownClass(cls): """Restore random seeds""" def is_empty_grad_op(op_type): - all_op_kernels = core._get_all_register_op_kernels() grad_op = op_type + '_grad' - if grad_op in all_op_kernels.keys(): - grad_op_kernels = all_op_kernels[grad_op] - for grad_op_kernel in grad_op_kernels: - if 'XPU' in grad_op_kernel: - return False + xpu_version = core.get_xpu_device_version(0) + xpu_op_list = core.get_xpu_device_op_list(xpu_version) + if grad_op in xpu_op_list.keys(): + return False return True if cls.dtype == np.float16: @@ -70,9 +68,20 @@ def is_empty_grad_op(op_type): super().tearDownClass() def _get_places(self): - places = [fluid.XPUPlace(0)] + places = [paddle.XPUPlace(0)] return places + def check_output(self, + atol=0.001, + no_check_set=None, + equal_nan=False, + check_dygraph=True, + inplace_atol=None, + check_eager=False): + place = paddle.XPUPlace(0) + self.check_output_with_place(place, atol, no_check_set, equal_nan, + check_dygraph, inplace_atol, check_eager) + def check_output_with_place(self, place, atol=0.001, @@ -82,20 +91,37 @@ def check_output_with_place(self, inplace_atol=None, check_eager=False): self.infer_dtype_from_inputs_outputs(self.inputs, self.outputs) - #xpu not support float64 if self.dtype == np.float64: return - if place == None: - place = paddle.XPUPlace(0) if self.dtype == np.float16: if core.is_float16_supported(place) == False: return + if self.dtype == np.float16: atol = 0.1 return super().check_output_with_place( place, atol, no_check_set, equal_nan, check_dygraph, inplace_atol) + def check_grad(self, + inputs_to_check, + output_names, + no_grad_set=None, + numeric_grad_delta=0.005, + in_place=False, + max_relative_error=0.005, + user_defined_grads=None, + user_defined_grad_outputs=None, + check_dygraph=True, + numeric_place=None, + check_eager=False): + place = paddle.XPUPlace(0) + self.check_grad_with_place(place, inputs_to_check, output_names, + no_grad_set, numeric_grad_delta, in_place, + max_relative_error, user_defined_grads, + user_defined_grad_outputs, check_dygraph, + numeric_place, check_eager) + def check_grad_with_place(self, place, inputs_to_check, @@ -116,9 +142,6 @@ def check_grad_with_place(self, self._check_grad_helper() return - if place == None: - place = paddle.XPUPlace(0) - if self.dtype == np.float64: return diff --git a/python/paddle/fluid/tests/unittests/xpu/test_cast_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_cast_op_xpu.py index 08d4810a6530b..201e758c0acea 100644 --- a/python/paddle/fluid/tests/unittests/xpu/test_cast_op_xpu.py +++ b/python/paddle/fluid/tests/unittests/xpu/test_cast_op_xpu.py @@ -23,6 +23,9 @@ import paddle.fluid.core as core import paddle.fluid as fluid from paddle.fluid import compiler, Program, program_guard +from op_test_xpu import XPUOpTest + +from xpu.get_test_cover_info import create_test_class, get_xpu_op_support_types, XPUOpTestWrapper typeid_dict = { 'int32': int(core.VarDesc.VarType.INT32), @@ -33,10 +36,27 @@ } -def create_test_class(in_typename, out_typename): - class Cls(op_test.OpTest): +class XPUTestCastOp(XPUOpTestWrapper): + def __init__(self): + self.op_name = 'cast' + self.use_dynamic_create_class = True + + def dynamic_create_class(self): + base_class = 
self.TestCastOp + classes = [] + for out_type in {'float16', 'float32', 'int32', 'int64'}: + class_name = 'XPUTestCastOp_outtype_' + out_type + attr_dict = {'out_typename': out_type} + classes.append([class_name, attr_dict]) + return base_class, classes + + class TestCastOp(XPUOpTest): def setUp(self): ipt = np.random.random(size=[10, 10]) + in_typename = self.in_type_str + out_typename = 'float32' if not hasattr( + self, 'out_typename') else self.out_typename + self.inputs = {'X': ipt.astype(in_typename)} self.outputs = {'Out': ipt.astype(in_typename).astype(out_typename)} self.attrs = { @@ -47,18 +67,12 @@ def setUp(self): self.__class__.no_need_check_grad = True def test_check_output(self): - if paddle.is_compiled_with_xpu(): - place = paddle.XPUPlace(0) - self.check_output_with_place(place) - - cls_name = "cast_{0}_{1}".format(in_typename, out_typename) - Cls.__name__ = cls_name - globals()[cls_name] = Cls + self.check_output() -for in_type in {'float16', 'float32', 'int32', 'int64', 'bool'}: - for out_type in {'float16', 'float32', 'int32', 'int64'}: - create_test_class(in_type, out_type) +support_types = get_xpu_op_support_types('cast') +for stype in support_types: + create_test_class(globals(), XPUTestCastOp, stype) class TestCastOpError(unittest.TestCase): diff --git a/python/paddle/fluid/tests/unittests/xpu/test_dropout_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_dropout_op_xpu.py index ca3b3a418abf6..2baa837b23a07 100644 --- a/python/paddle/fluid/tests/unittests/xpu/test_dropout_op_xpu.py +++ b/python/paddle/fluid/tests/unittests/xpu/test_dropout_op_xpu.py @@ -25,90 +25,196 @@ from op_test_xpu import XPUOpTest paddle.enable_static() - -class TestDropoutOp(XPUOpTest): - def setUp(self): - self.op_type = "dropout" - self.inputs = {'X': np.random.random((32, 64)).astype("float32")} - self.attrs = {'dropout_prob': 0.0, 'fix_seed': True, 'is_test': False} - self.outputs = { - 'Out': self.inputs['X'], - 'Mask': np.ones((32, 64)).astype('uint8') - } - - def test_check_output(self): - if paddle.is_compiled_with_xpu(): - paddle.enable_static() - place = paddle.XPUPlace(0) - self.check_output_with_place(place) - - def test_check_grad_normal(self): - if paddle.is_compiled_with_xpu(): - paddle.enable_static() - place = paddle.XPUPlace(0) - self.check_grad_with_place(place, ['X'], 'Out') - - -class TestDropoutOpInput1d(XPUOpTest): - def setUp(self): - self.op_type = "dropout" - self.inputs = {'X': np.random.random((2000, )).astype("float32")} - self.attrs = {'dropout_prob': 0.0, 'fix_seed': True, 'is_test': False} - self.outputs = { - 'Out': self.inputs['X'], - 'Mask': np.ones((2000)).astype('uint8') - } - - def test_check_output(self): - if paddle.is_compiled_with_xpu(): - paddle.enable_static() - place = paddle.XPUPlace(0) - self.check_output_with_place(place) - - def test_check_grad_normal(self): - if paddle.is_compiled_with_xpu(): - paddle.enable_static() - place = paddle.XPUPlace(0) - self.check_grad_with_place(place, ['X'], 'Out') - - -class TestDropoutOp2(TestDropoutOp): - def setUp(self): - self.op_type = "dropout" - self.inputs = {'X': np.random.random((32, 64)).astype("float32")} - self.attrs = {'dropout_prob': 1.0, 'fix_seed': True, 'is_test': False} - self.outputs = { - 'Out': np.zeros((32, 64)).astype('float32'), - 'Mask': np.zeros((32, 64)).astype('uint8') - } - - -class TestDropoutOp3(TestDropoutOp): - def setUp(self): - self.op_type = "dropout" - self.inputs = {'X': np.random.random((32, 64, 2)).astype("float32")} - self.attrs = {'dropout_prob': 0.0, 'fix_seed': 
True, 'is_test': False} - self.outputs = { - 'Out': self.inputs['X'], - 'Mask': np.ones((32, 64, 2)).astype('uint8') - } - - -class TestDropoutOp6(TestDropoutOp): - def setUp(self): - self.op_type = "dropout" - self.inputs = {'X': np.random.random((32, 64, 2)).astype("float32")} - self.attrs = { - 'dropout_prob': 0.0, - 'fix_seed': True, - 'is_test': False, - 'dropout_implementation': 'upscale_in_train' - } - self.outputs = { - 'Out': self.inputs['X'], - 'Mask': np.ones((32, 64, 2)).astype('uint8') - } - +from xpu.get_test_cover_info import create_test_class, get_xpu_op_support_types, XPUOpTestWrapper + + +class XPUTestDropoutOp(XPUOpTestWrapper): + def __init__(self): + self.op_name = 'dropout' + self.use_dynamic_create_class = False + + class TestDropoutOp(XPUOpTest): + def setUp(self): + self.init_inputs_shape() + self.init_attrs() + self.dtype = self.in_type + self.op_type = 'dropout' + self.inputs = {'X': np.random.random(self.shape).astype(self.dtype)} + self.attrs = { + 'dropout_prob': self.dropout_prob, + 'fix_seed': self.fix_seed, + 'is_test': self.is_test, + 'dropout_implementation': self.dropout_implementation + } + + out = self.inputs['X'] * (1.0 - self.dropout_prob) + if self.is_test == False: + mask = None + if self.dropout_prob == 0.0: + mask = np.ones(self.shape).astype(self.dtype) + elif self.dropout_prob == 1.0: + mask = np.zeros(self.shape).astype(self.dtype) + self.outputs = {'Out': out, 'Mask': mask} + else: + self.outputs = {'Out': out} + + def init_inputs_shape(self): + self.shape = [32, 64] + + def init_attrs(self): + self.__class__.no_need_check_grad = False + self.dropout_prob = 0.0 + self.fix_seed = True + self.is_test = False + self.dropout_implementation = "upscale_in_train" + + def test_check_output(self): + self.check_output() + + def test_check_grad_normal(self): + if hasattr(self.__class__, "no_need_check_grad" + ) and self.__class__.no_need_check_grad == True: + return + + self.check_grad(['X'], 'Out') + + class TestDropoutOpInput1d(TestDropoutOp): + def init_inputs_shape(self): + self.shape = [2000] + + class TestDropoutOp2(TestDropoutOp): + def init_inputs_shape(self): + self.shape = [32, 64] + + def init_attrs(self): + self.dropout_prob = 1.0 + self.fix_seed = True + self.is_test = False + self.dropout_implementation = "upscale_in_train" + + class TestDropoutOp3(TestDropoutOp): + def init_inputs_shape(self): + self.shape = [32, 64, 2] + + class TestDropoutOp4(TestDropoutOp): + def init_attrs(self): + self.__class__.no_need_check_grad = True + self.dropout_prob = 0.35 + self.fix_seed = True + self.is_test = True + self.dropout_implementation = "downgrade_in_infer" + + class TestDropoutOp5(TestDropoutOp): + def init_inputs_shape(self): + self.shape = [32, 64, 3] + + def init_attrs(self): + self.__class__.no_need_check_grad = True + self.dropout_prob = 0.75 + self.fix_seed = True + self.is_test = True + self.dropout_implementation = "downgrade_in_infer" + + class TestDropoutOpError(unittest.TestCase): + def test_errors(self): + with program_guard(Program(), Program()): + + def test_Variable(): + # the input of dropout must be Variable. 
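+                # editor's note: create_lod_tensor returns a LoDTensor rather
+                # than a Variable, which is why the dropout call below is
+                # expected to raise TypeError.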
+ x1 = fluid.create_lod_tensor( + np.array([-1, 3, 5, 5]), [[1, 1, 1, 1]], + fluid.CPUPlace()) + fluid.layers.dropout(x1, dropout_prob=0.5) + + self.assertRaises(TypeError, test_Variable) + + def test_dtype(): + # the input dtype of dropout must be float16 or float32 or float64 + # float16 only can be set on GPU place + x2 = fluid.layers.data( + name='x2', shape=[3, 4, 5, 6], dtype="int32") + fluid.layers.dropout(x2, dropout_prob=0.5) + + self.assertRaises(TypeError, test_dtype) + + class TestDropoutCAPI(unittest.TestCase): + def setUp(self): + np.random.seed(123) + self.places = [fluid.CPUPlace()] + self.places.append(fluid.XPUPlace(0)) + + def test_dygraph(self): + for place in self.places: + with fluid.dygraph.guard(place): + input_np = np.random.random([40, 40]).astype(self.in_type) + result_np = input_np + input = fluid.dygraph.to_variable(input_np) + m = paddle.nn.Dropout(p=0.) + m.eval() + result = m(input) + self.assertTrue(np.allclose(result.numpy(), result_np)) + + class TestDropoutBackward(unittest.TestCase): + def setUp(self): + np.random.seed(123) + self.places = [fluid.CPUPlace()] + self.places.append(fluid.XPUPlace(0)) + + def cal_grad_upscale_train(self, mask, prob): + return mask.astype(self.in_type) / (1 - prob) + + def cal_grad_downscale_in_infer(self, mask): + return mask.astype(self.in_type) + + def test_backward_downscale_in_infer(self): + for place in self.places: + with fluid.dygraph.guard(place): + + input = paddle.uniform([40, 40], dtype=self.in_type) + input.stop_gradient = False + out, mask = core.ops.dropout(input, 'dropout_prob', 0.5) + out.backward() + + self.assertTrue( + np.array_equal(input.gradient( + ), self.cal_grad_downscale_in_infer(mask.numpy()))) + + def test_backward_upscale_train(self): + for place in self.places: + with fluid.dygraph.guard(place): + + prob = 0.5 + input = paddle.uniform([40, 40], dtype=self.in_type) + input.stop_gradient = False + out, mask = core.ops.dropout(input, 'dropout_prob', prob, + "dropout_implementation", + "upscale_in_train") + out.backward() + + self.assertTrue( + np.allclose(input.gradient( + ), self.cal_grad_upscale_train(mask.numpy(), prob))) + + def test_backward_upscale_train_2(self): + for place in self.places: + with fluid.dygraph.guard(place): + + prob = 0.3 + input = paddle.uniform([40, 40], dtype=self.in_type) + input.stop_gradient = False + out, mask = core.ops.dropout(input, 'dropout_prob', prob, + "dropout_implementation", + "upscale_in_train") + out.backward() + + self.assertTrue( + np.allclose(input.gradient( + ), self.cal_grad_upscale_train(mask.numpy(), prob))) + + +support_types = get_xpu_op_support_types('dropout') +for stype in support_types: + create_test_class(globals(), XPUTestDropoutOp, stype) if __name__ == '__main__': unittest.main() From 0fe2001a883f8307441a1bed8d2ab34f459b15d3 Mon Sep 17 00:00:00 2001 From: Leo Chen Date: Sat, 2 Apr 2022 10:53:19 +0800 Subject: [PATCH 07/93] make variable 'gradient_merge_cond' local (#41262) --- .../fleet/meta_optimizers/sharding_optimizer.py | 9 ++------- .../distributed/passes/auto_parallel_gradient_merge.py | 9 ++------- python/paddle/fluid/optimizer.py | 9 ++------- 3 files changed, 6 insertions(+), 21 deletions(-) diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py index 52468ab533496..c4d42f90615fc 100755 --- a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py +++ 
b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py @@ -1621,13 +1621,8 @@ def _create_gm_cond(self, main_block): persistable=True, force_cpu=True) - cond_var = layers.create_global_var( - name="gradient_merge_cond", - shape=[1], - value=bool(0), - dtype='bool', - persistable=False, - force_cpu=True) + cond_var = main_block.create_var( + name="gradient_merge_cond", shape=[1], dtype='bool') with device_guard("cpu"): # step_var = (step_var + 1) % k_step diff --git a/python/paddle/distributed/passes/auto_parallel_gradient_merge.py b/python/paddle/distributed/passes/auto_parallel_gradient_merge.py index 7668dff36207e..accac81133825 100644 --- a/python/paddle/distributed/passes/auto_parallel_gradient_merge.py +++ b/python/paddle/distributed/passes/auto_parallel_gradient_merge.py @@ -107,13 +107,8 @@ def _get_gm_cond_var(main_program, k_steps, dist_context): force_cpu=True) set_var_dist_attr(dist_context, step_var, [-1], world_process_group.ranks) - cond_var = layers.create_global_var( - name="gradient_merge_cond", - shape=[1], - value=bool(0), - dtype='bool', - persistable=False, - force_cpu=True) + cond_var = main_block.create_var( + name="gradient_merge_cond", shape=[1], dtype='bool') set_var_dist_attr(dist_context, cond_var, [-1], world_process_group.ranks) with device_guard("cpu"): diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py index 7bf4608de89c9..8242d8e3392ec 100755 --- a/python/paddle/fluid/optimizer.py +++ b/python/paddle/fluid/optimizer.py @@ -7098,13 +7098,8 @@ def _get_gm_cond_var(self, main_block): persistable=True, force_cpu=True) - cond_var = layers.create_global_var( - name="gradient_merge_cond", - shape=[1], - value=bool(0), - dtype='bool', - persistable=False, - force_cpu=True) + cond_var = main_block.create_var( + name="gradient_merge_cond", shape=[1], dtype='bool') with device_guard("cpu"): # step_var = (step_var + 1) % k_step From cb12415622351b82bbfab8df67985e844019281b Mon Sep 17 00:00:00 2001 From: Leo Chen Date: Sat, 2 Apr 2022 11:06:36 +0800 Subject: [PATCH 08/93] [new-exec] support to enable mkldnn by flags (#41274) --- .../fluid/framework/new_executor/interpretercore.cc | 11 ++++++++--- .../framework/new_executor/interpretercore_util.cc | 13 +++++++++++++ 2 files changed, 21 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/framework/new_executor/interpretercore.cc b/paddle/fluid/framework/new_executor/interpretercore.cc index a2f9d90406736..1b15ca6746257 100644 --- a/paddle/fluid/framework/new_executor/interpretercore.cc +++ b/paddle/fluid/framework/new_executor/interpretercore.cc @@ -425,13 +425,18 @@ void InterpreterCore::RunInstruction(const Instruction& instr_node) { : global_scope_->GetMutableScope(); auto op_with_kernel = dynamic_cast(op); { + // If it is OperatorBase, InferShape do nothing. if (op_with_kernel != nullptr) { platform::RecordEvent infershape_event( "infer_shape", platform::TracerEventType::OperatorInner, 1, platform::EventRole::kInnerOp); - // If it is OperatorBase, InferShape do nothing. 
-        op_with_kernel->Info().infer_shape_(
-            instr_node.InnerInferShapeContext().get());
+
+        // see OperatorWithKernel::RunImpl in operator.cc for why
+        if (!(op_with_kernel->HasAttr(kAllKernelsMustComputeRuntimeShape) &&
+              op_with_kernel->Attr(kAllKernelsMustComputeRuntimeShape))) {
+          op_with_kernel->Info().infer_shape_(
+              instr_node.InnerInferShapeContext().get());
+        }
     }
   }

diff --git a/paddle/fluid/framework/new_executor/interpretercore_util.cc b/paddle/fluid/framework/new_executor/interpretercore_util.cc
index d56082a91a61f..360e0222a516c 100644
--- a/paddle/fluid/framework/new_executor/interpretercore_util.cc
+++ b/paddle/fluid/framework/new_executor/interpretercore_util.cc
@@ -29,6 +29,8 @@ PADDLE_DEFINE_EXPORTED_bool(
     new_executor_sequential_run, false,
     "Enable sequential execution for standalone executor, used for debug");

+DECLARE_bool(use_mkldnn);
+
 namespace paddle {
 namespace framework {
 namespace interpreter {
@@ -192,6 +194,7 @@ void create_all_ops(const framework::BlockDesc& block,
     const VariableNameMap& inputs_names = op->Inputs();
     const VariableNameMap& outputs_names = op->Outputs();

+
     AttributeMap op_attr_map = op->GetAttrMap();
     if (info.Checker() != nullptr) {
@@ -199,6 +202,16 @@
     }
     auto op_base =
         info.Creator()(op->Type(), inputs_names, outputs_names, op_attr_map);
+
+#ifdef PADDLE_WITH_MKLDNN
+    if (FLAGS_use_mkldnn) {
+      if (op->HasAttr("use_mkldnn")) {
+        VLOG(4) << "Set use_mkldnn=True for " << op_base->Type();
+        op_base->SetAttr("use_mkldnn", true);
+      }
+    }
+#endif
+
     ops->emplace_back(std::unique_ptr(op_base));
   }
 }

From b3270adfe0c638ac582ef96565493c18e1b57989 Mon Sep 17 00:00:00 2001
From: zhaocaibei123 <48509226+zhaocaibei123@users.noreply.github.com>
Date: Sat, 2 Apr 2022 11:07:57 +0800
Subject: [PATCH 09/93] Unify ps refine (#41234)

* update name
* update name
* fix test
* fix fleet bind
* update name
* update name
* fix test
* fix gpups wrapper
* remove Push/Pull/Load/Save with context in client and wrapper base class
* fix
* fix

Co-authored-by: esythan
---
 .../distributed/ps/service/brpc_ps_client.cc  | 353 ++++++++----------
 .../distributed/ps/service/brpc_ps_client.h   | 185 +++++----
 .../distributed/ps/service/brpc_ps_server.cc  | 290 +++++++-------
 .../distributed/ps/service/brpc_ps_server.h   |  83 ++--
 .../ps/service/communicator/communicator.cc   |  56 +--
 .../ps/service/communicator/communicator.h    |  10 +-
 paddle/fluid/distributed/ps/service/env.h     |  89 +++--
 .../ps/service/graph_brpc_client.cc           |  49 +--
 .../ps/service/graph_brpc_client.h            |   4 +-
 .../ps/service/graph_brpc_server.cc           | 115 +++---
 .../ps/service/graph_brpc_server.h            |  38 +-
 .../fluid/distributed/ps/service/ps_client.cc |   8 +-
 .../fluid/distributed/ps/service/ps_client.h  | 187 ++++------
 .../distributed/ps/service/ps_local_client.cc | 210 ++++-------
 .../distributed/ps/service/ps_local_client.h  | 133 ++++---
 .../distributed/ps/service/ps_local_server.h  |  10 +-
 .../ps/service/ps_service/graph_py_service.cc |  28 +-
 .../ps/service/ps_service/graph_py_service.h  |   6 +-
 .../ps/service/ps_service/service.cc          |  50 +--
 .../ps/service/ps_service/service.h           |  22 +-
 paddle/fluid/distributed/ps/service/server.cc |  20 +-
 paddle/fluid/distributed/ps/service/server.h  |  32 +-
 .../distributed/ps/table/barrier_table.cc     |   8 +-
 .../ps/table/common_dense_table.cc            |  58 +--
 .../distributed/ps/table/common_dense_table.h |  34 +-
 .../ps/table/common_graph_table.cc            |  10 
+- .../distributed/ps/table/common_graph_table.h | 33 +- .../ps/table/common_sparse_table.cc | 88 ++--- .../ps/table/common_sparse_table.h | 60 ++- .../fluid/distributed/ps/table/common_table.h | 58 ++- .../distributed/ps/table/depends/dense.h | 14 +- .../distributed/ps/table/depends/sparse.h | 10 +- .../ps/table/memory_sparse_geo_table.cc | 47 ++- .../ps/table/memory_sparse_geo_table.h | 30 +- .../ps/table/memory_sparse_table.cc | 78 ++-- .../ps/table/memory_sparse_table.h | 54 ++- .../distributed/ps/table/sparse_geo_table.cc | 18 +- .../distributed/ps/table/sparse_geo_table.h | 12 +- .../distributed/ps/table/ssd_sparse_table.cc | 26 +- .../distributed/ps/table/ssd_sparse_table.h | 18 +- paddle/fluid/distributed/ps/table/table.cc | 10 +- paddle/fluid/distributed/ps/table/table.h | 80 ++-- .../fluid/distributed/ps/table/tensor_table.h | 108 +++--- paddle/fluid/distributed/ps/wrapper/fleet.cc | 152 +++----- paddle/fluid/distributed/ps/wrapper/fleet.h | 12 +- .../distributed/test/barrier_table_test.cc | 6 +- .../test/brpc_service_dense_sgd_test.cc | 30 +- .../test/brpc_service_sparse_sgd_test.cc | 30 +- .../distributed/test/dense_table_test.cc | 17 +- .../distributed/test/graph_node_split_test.cc | 32 +- .../fluid/distributed/test/graph_node_test.cc | 38 +- .../distributed/test/memory_geo_table_test.cc | 13 +- .../test/memory_sparse_table_test.cc | 13 +- paddle/fluid/distributed/test/table_test.cc | 2 +- .../fluid/framework/fleet/ps_gpu_wrapper.cc | 2 +- paddle/fluid/framework/multi_trainer.cc | 2 +- paddle/fluid/pybind/fleet_py.cc | 12 +- 57 files changed, 1449 insertions(+), 1744 deletions(-) diff --git a/paddle/fluid/distributed/ps/service/brpc_ps_client.cc b/paddle/fluid/distributed/ps/service/brpc_ps_client.cc index 5a92afb297c7e..893e0f9a97596 100755 --- a/paddle/fluid/distributed/ps/service/brpc_ps_client.cc +++ b/paddle/fluid/distributed/ps/service/brpc_ps_client.cc @@ -80,7 +80,7 @@ void DownpourPsClientService::service( const PsRequestMessage *request, PsResponseMessage *response, ::google::protobuf::Closure *done) { brpc::ClosureGuard done_guard(done); - int ret = _client->handle_client2client_msg( + int ret = _client->HandleClient2ClientMsg( request->cmd_id(), request->client_id(), request->data()); response->set_err_code(0); response->set_err_msg(""); @@ -91,8 +91,8 @@ void DownpourPsClientService::service( } // 启动client端RpcService 用于数据互发等操作 -int32_t BrpcPsClient::start_client_service() { - if (_service.configure(this, _client_id) != 0) { +int32_t BrpcPsClient::StartClientService() { + if (_service.Configure(this, _client_id) != 0) { LOG(ERROR) << "service initialize failed, service_name:DownpourPsClientService"; return -1; @@ -108,12 +108,12 @@ int32_t BrpcPsClient::start_client_service() { return -1; } _server_started = true; - _env->registe_ps_client(butil::my_ip_cstr(), _server.listen_address().port, - _client_id); + _env->RegistePsClient(butil::my_ip_cstr(), _server.listen_address().port, + _client_id); return 0; } -int32_t BrpcPsClient::create_client2client_connection( +int32_t BrpcPsClient::CreateClient2ClientConnection( int pserver_timeout_ms, int pserver_connect_timeout_ms, int max_retry) { brpc::ChannelOptions options; options.protocol = "baidu_std"; @@ -122,12 +122,12 @@ int32_t BrpcPsClient::create_client2client_connection( options.connect_timeout_ms = pserver_connect_timeout_ms; options.max_retry = max_retry; - std::vector client_list = _env->get_ps_clients(); + std::vector client_list = _env->GetPsClients(); VLOG(1) << "BrpcPsClient::create_c2c_connection 
client_list size: " << client_list.size(); for (auto cc : client_list) { VLOG(1) << "BrpcPsClient::create_c2c_connection client_list: " - << cc.to_string(); + << cc.ToString(); } _client_channels.resize(client_list.size()); std::ostringstream os; @@ -154,7 +154,7 @@ int32_t BrpcPsClient::create_client2client_connection( return 0; } -int32_t BrpcPsClient::initialize() { +int32_t BrpcPsClient::Initialize() { _async_call_num = 0; brpc::ChannelOptions options; @@ -169,7 +169,7 @@ int32_t BrpcPsClient::initialize() { std::string client_ip(butil::my_ip_cstr()); // 获取server列表,并连接 - std::vector server_list = _env->get_ps_servers(); + std::vector server_list = _env->GetPsServers(); _server_channels.resize(server_list.size()); for (size_t i = 0; i < server_list.size(); ++i) { server_ip_port.assign(server_list[i].ip.c_str()); @@ -194,7 +194,7 @@ int32_t BrpcPsClient::initialize() { os << server_ip_port << ","; } // 启动client探听接口, 并相互建立连接 - start_client_service(); + StartClientService(); // 异步push 请求队列初始化 const auto &worker_param = _config.worker_param().downpour_worker_param(); @@ -234,13 +234,13 @@ int32_t BrpcPsClient::initialize() { _flushing = false; // 启动异步push线程 _async_push_sparse_thread = - std::thread(std::bind(&BrpcPsClient::push_sparse_task_consume, this)); + std::thread(std::bind(&BrpcPsClient::PushSparseTaskConsume, this)); // _async_push_sparse_thread.detach(); _async_push_dense_thread = - std::thread(std::bind(&BrpcPsClient::push_dense_task_consume, this)); + std::thread(std::bind(&BrpcPsClient::PushDenseTaskConsume, this)); // for debug // _print_thread = - // std::thread(std::bind(&BrpcPsClient::print_queue_size_thread, this)); + // std::thread(std::bind(&BrpcPsClient::PrintQueueSizeThread, this)); return 0; } @@ -286,7 +286,7 @@ std::string DownpourBrpcClosure::get_response(size_t request_idx, int cmd_id) { return data; } -std::future BrpcPsClient::print_table_stat(uint32_t table_id) { +std::future BrpcPsClient::PrintTableStat(uint32_t table_id) { size_t request_call_num = _server_channels.size(); DownpourBrpcClosure *closure = new DownpourBrpcClosure( request_call_num, [request_call_num, table_id](void *done) { @@ -319,7 +319,7 @@ std::future BrpcPsClient::print_table_stat(uint32_t table_id) { closure->request(i)->set_cmd_id(PS_PRINT_TABLE_STAT); closure->request(i)->set_table_id(table_id); closure->request(i)->set_client_id(_client_id); - PsService_Stub rpc_stub(get_cmd_channel(i)); + PsService_Stub rpc_stub(GetCmdChannel(i)); closure->cntl(i)->set_timeout_ms( 10800000); // cmd msg don't limit timeout for save/load rpc_stub.service(closure->cntl(i), closure->request(i), @@ -327,7 +327,7 @@ std::future BrpcPsClient::print_table_stat(uint32_t table_id) { } return fut; } -std::future BrpcPsClient::send_cmd( +std::future BrpcPsClient::SendCmd( uint32_t table_id, int cmd_id, const std::vector ¶ms) { size_t request_call_num = _server_channels.size(); DownpourBrpcClosure *closure = new DownpourBrpcClosure( @@ -352,7 +352,7 @@ std::future BrpcPsClient::send_cmd( for (const auto ¶m : params) { closure->request(i)->add_params(param); } - PsService_Stub rpc_stub(get_cmd_channel(i)); + PsService_Stub rpc_stub(GetCmdChannel(i)); closure->cntl(i)->set_timeout_ms( 10800000 * 2); // cmd msg don't limit timeout for save/load rpc_stub.service(closure->cntl(i), closure->request(i), @@ -361,7 +361,7 @@ std::future BrpcPsClient::send_cmd( return fut; } -std::future BrpcPsClient::send_save_cmd( +std::future BrpcPsClient::SendSaveCmd( uint32_t table_id, int cmd_id, const std::vector ¶ms) { size_t 
request_call_num = _server_channels.size(); DownpourBrpcClosure *closure = new DownpourBrpcClosure( @@ -392,7 +392,7 @@ std::future BrpcPsClient::send_save_cmd( for (const auto ¶m : params) { closure->request(i)->add_params(param); } - PsService_Stub rpc_stub(get_cmd_channel(i)); + PsService_Stub rpc_stub(GetCmdChannel(i)); closure->cntl(i)->set_timeout_ms( 10800000); // cmd msg don't limit timeout for save/load rpc_stub.service(closure->cntl(i), closure->request(i), @@ -401,65 +401,42 @@ std::future BrpcPsClient::send_save_cmd( return fut; } -std::future BrpcPsClient::shrink(uint32_t table_id, +std::future BrpcPsClient::Shrink(uint32_t table_id, const std::string threshold) { - return send_cmd(table_id, PS_SHRINK_TABLE, {threshold}); + return SendCmd(table_id, PS_SHRINK_TABLE, {threshold}); } -std::future BrpcPsClient::load(const std::string &epoch, +std::future BrpcPsClient::Load(const std::string &epoch, const std::string &mode) { - return send_cmd(-1, PS_LOAD_ALL_TABLE, {epoch, mode}); + return SendCmd(-1, PS_LOAD_ALL_TABLE, {epoch, mode}); } -std::future BrpcPsClient::load(uint32_t table_id, +std::future BrpcPsClient::Load(uint32_t table_id, const std::string &epoch, const std::string &mode) { - return send_cmd(table_id, PS_LOAD_ONE_TABLE, {epoch, mode}); + return SendCmd(table_id, PS_LOAD_ONE_TABLE, {epoch, mode}); } -std::future BrpcPsClient::Load(const LoadSaveContext &load_context) { - if (load_context.table_id < 0) { - return send_cmd(-1, PS_LOAD_ALL_TABLE, - {load_context.epoch, load_context.mode}); - } else { - return send_cmd(load_context.table_id, PS_LOAD_ONE_TABLE, - {load_context.epoch, load_context.mode}); - } -} - -std::future BrpcPsClient::save(const std::string &epoch, +std::future BrpcPsClient::Save(const std::string &epoch, const std::string &mode) { VLOG(1) << "BrpcPsClient::save path " << epoch; - return send_save_cmd(-1, PS_SAVE_ALL_TABLE, {epoch, mode}); + return SendSaveCmd(-1, PS_SAVE_ALL_TABLE, {epoch, mode}); } -std::future BrpcPsClient::save(uint32_t table_id, +std::future BrpcPsClient::Save(uint32_t table_id, const std::string &epoch, const std::string &mode) { VLOG(1) << "BrpcPsClient::save one table path " << epoch << " table_id " << table_id; - return send_save_cmd(table_id, PS_SAVE_ONE_TABLE, {epoch, mode}); + return SendSaveCmd(table_id, PS_SAVE_ONE_TABLE, {epoch, mode}); } -std::future BrpcPsClient::Save(const LoadSaveContext &save_context) { - if (save_context.table_id < 0) { - VLOG(1) << "BrpcPsClient::save path " << save_context.epoch; - return send_save_cmd(-1, PS_SAVE_ALL_TABLE, - {save_context.epoch, save_context.mode}); - } else { - VLOG(1) << "BrpcPsClient::save one table path " << save_context.epoch - << " table_id " << save_context.table_id; - return send_save_cmd(save_context.table_id, PS_SAVE_ONE_TABLE, - {save_context.epoch, save_context.mode}); - } -} - -std::future BrpcPsClient::clear() { - return send_cmd(-1, PS_CLEAR_ALL_TABLE, {}); +std::future BrpcPsClient::Clear() { + return SendCmd(-1, PS_CLEAR_ALL_TABLE, {}); } -std::future BrpcPsClient::clear(uint32_t table_id) { - return send_cmd(table_id, PS_CLEAR_ONE_TABLE, {}); +std::future BrpcPsClient::Clear(uint32_t table_id) { + return SendCmd(table_id, PS_CLEAR_ONE_TABLE, {}); } -std::future BrpcPsClient::flush() { +std::future BrpcPsClient::Flush() { VLOG(0) << "BrpcPsClient::flush begin"; _flushing = true; std::promise promise; @@ -472,106 +449,69 @@ std::future BrpcPsClient::flush() { promise.set_value(0); _flushing = false; VLOG(0) << "BrpcPsClient::flush done"; - print_queue_size(); 
+ PrintQueueSize(); return fut; } -void BrpcPsClient::print_queue_size() { +void BrpcPsClient::PrintQueueSize() { for (auto &push_sparse_task_itr : _push_sparse_task_queue_map) { auto table_id = push_sparse_task_itr.first; auto queue_size = push_sparse_task_itr.second->Size(); - VLOG(0) << "BrpcPsClient::print_queue_size: table " << table_id + VLOG(0) << "BrpcPsClient::PrintQueueSize: table " << table_id << " size: " << queue_size; } for (auto &task_queue_itr : _push_dense_task_queue_map) { auto table_id = task_queue_itr.first; auto queue_size = task_queue_itr.second->Size(); - VLOG(0) << "BrpcPsClient::print_queue_size: table " << table_id + VLOG(0) << "BrpcPsClient::PrintQueueSize: table " << table_id << " size: " << queue_size; } } -void BrpcPsClient::print_queue_size_thread() { +void BrpcPsClient::PrintQueueSizeThread() { while (_running) { usleep(1000000 * 60 * 2); - print_queue_size(); + PrintQueueSize(); } } -void BrpcPsClient::finalize_worker() { - flush(); - VLOG(0) << "BrpcPsClient::finalize_worker begin join thread"; +void BrpcPsClient::FinalizeWorker() { + Flush(); + VLOG(0) << "BrpcPsClient::FinalizeWorker begin join thread"; _running = false; _async_push_dense_thread.join(); _async_push_sparse_thread.join(); // _print_thread.join(); - VLOG(0) << "BrpcPsClient::finalize_worker begin join server"; + VLOG(0) << "BrpcPsClient::FinalizeWorker begin join server"; _server.Stop(1000); _server.Join(); _server_started = false; - VLOG(0) << "BrpcPsClient::finalize_worker done"; + VLOG(0) << "BrpcPsClient::FinalizeWorker done"; } -std::future BrpcPsClient::stop_server() { - return send_cmd(-1, PS_STOP_SERVER, {}); +std::future BrpcPsClient::StopServer() { + return SendCmd(-1, PS_STOP_SERVER, {}); } -std::future BrpcPsClient::start_profiler() { - return send_cmd(-1, PS_START_PROFILER, {}); +std::future BrpcPsClient::StartProfiler() { + return SendCmd(-1, PS_START_PROFILER, {}); } -std::future BrpcPsClient::stop_profiler() { - return send_cmd(-1, PS_STOP_PROFILER, {}); +std::future BrpcPsClient::StopProfiler() { + return SendCmd(-1, PS_STOP_PROFILER, {}); } -std::future BrpcPsClient::barrier(size_t table_id, +std::future BrpcPsClient::Barrier(size_t table_id, uint32_t barrier_type) { - return send_cmd(table_id, PS_BARRIER, {std::to_string(barrier_type)}); -} - -std::future BrpcPsClient::Pull(RequestContext &pull_context) { - if (pull_context.value_type == Dense) { // pull dense - Region *dense_region = - reinterpret_cast(pull_context.dense_values); - return pull_dense(dense_region, pull_context.num, pull_context.table); - } else { // pull sparse - size_t table_id = pull_context.table; - size_t num = pull_context.num; - bool is_training = pull_context.is_training; - if (pull_context.training_mode == Geo) { // for geo - return pull_sparse_param(pull_context.sparse_values, table_id, - pull_context.keys, num, is_training); - } else if (pull_context.training_mode == Async) { // for async - return pull_sparse(pull_context.sparse_values, table_id, - pull_context.keys, num, is_training); - } - } + return SendCmd(table_id, PS_BARRIER, {std::to_string(barrier_type)}); } -std::future BrpcPsClient::Push(RequestContext &push_context) { - if (push_context.value_type == Dense) { // push dense - const Region *dense_region = push_context.push_context.push_dense_values; - return push_dense(dense_region, push_context.num, push_context.table); - } else { // push sparse - size_t table_id = push_context.table; - size_t num = push_context.num; - bool is_training = push_context.is_training; - if 
(push_context.training_mode == Geo) { // for geo - // TODO(zhaocaibei) - } else if (push_context.training_mode == Async) { // for async - const uint64_t *keys = push_context.push_context.keys; - const float **update_values = push_context.push_context.push_values; - return push_sparse(table_id, keys, update_values, num); - } - } -} - -std::future BrpcPsClient::pull_geo_param(size_t table_id, - std::vector *values, - std::vector *keys, - int pserver_idx) { - auto *accessor = table_accessor(table_id); +std::future BrpcPsClient::PullGeoParam(size_t table_id, + std::vector *values, + std::vector *keys, + int pserver_idx) { + auto *accessor = GetTableAccessor(table_id); DownpourBrpcClosure *closure = new DownpourBrpcClosure(1, [keys, values, accessor](void *done) { int ret = 0; @@ -600,7 +540,7 @@ std::future BrpcPsClient::pull_geo_param(size_t table_id, closure->request(0)->set_cmd_id(PS_PULL_GEO_PARAM); closure->request(0)->set_table_id(table_id); closure->request(0)->set_client_id(_client_id); - PsService_Stub rpc_stub(get_cmd_channel(pserver_idx)); + PsService_Stub rpc_stub(GetCmdChannel(pserver_idx)); closure->cntl(0)->set_log_id(butil::gettimeofday_ms()); rpc_stub.service(closure->cntl(0), closure->request(0), closure->response(0), closure); @@ -608,10 +548,11 @@ std::future BrpcPsClient::pull_geo_param(size_t table_id, } // for GEO -std::future BrpcPsClient::push_sparse_param( - size_t table_id, const uint64_t *keys, const float **update_values, - size_t num, void *done) { - auto *accessor = table_accessor(table_id); +std::future BrpcPsClient::PushSparseParam(size_t table_id, + const uint64_t *keys, + const float **update_values, + size_t num, void *done) { + auto *accessor = GetTableAccessor(table_id); // 发送RPC请求 DownpourBrpcClosure *closure = reinterpret_cast(done); auto promise = std::make_shared>(); @@ -649,7 +590,7 @@ std::future BrpcPsClient::push_sparse_param( memcpy(push_data_ptr, value_ptr[i], accessor->GetTableInfo(UPDATE_SIZE)); push_data_ptr += accessor->GetTableInfo(UPDATE_SIZE); } - PsService_Stub rpc_stub(get_sparse_channel(shard_idx)); + PsService_Stub rpc_stub(GetSparseChannel(shard_idx)); closure->cntl(shard_idx)->set_request_compress_type( (brpc::CompressType)FLAGS_pserver_communicate_compress_type); rpc_stub.service(closure->cntl(shard_idx), closure->request(shard_idx), @@ -658,16 +599,15 @@ std::future BrpcPsClient::push_sparse_param( return fut; } -std::future BrpcPsClient::pull_dense(Region *regions, - size_t region_num, - size_t table_id) { +std::future BrpcPsClient::PullDense(Region *regions, size_t region_num, + size_t table_id) { auto timer = std::make_shared("pserver_client_pull_dense"); - auto *accessor = table_accessor(table_id); + auto *accessor = GetTableAccessor(table_id); auto fea_dim = accessor->GetTableInfo(FEA_DIM); auto select_size = accessor->GetTableInfo(SELECT_SIZE); size_t request_call_num = _server_channels.size(); uint32_t num_per_shard = - dense_dim_per_shard(accessor->GetTableInfo(FEA_DIM), request_call_num); + DenseDimPerShard(accessor->GetTableInfo(FEA_DIM), request_call_num); // callback 将各shard结果,顺序填入region DownpourBrpcClosure *closure = new DownpourBrpcClosure( request_call_num, [request_call_num, num_per_shard, regions, region_num, @@ -730,22 +670,22 @@ std::future BrpcPsClient::pull_dense(Region *regions, closure->request(i)->set_client_id(_client_id); closure->request(i)->add_params((char *)&num_per_shard, // NOLINT sizeof(num_per_shard)); - PsService_Stub rpc_stub(get_dense_channel(i)); + PsService_Stub rpc_stub(GetDenseChannel(i)); 
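+    // Editor's note (not in the original patch): one request is sent per
+    // dense shard here; the shared DownpourBrpcClosure created above
+    // completes the returned future only after all request_call_num shard
+    // responses have been copied back into the output regions.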
rpc_stub.service(closure->cntl(i), closure->request(i), closure->response(i), closure); } return fut; } -std::future BrpcPsClient::push_dense_param(const Region *regions, - size_t region_num, - size_t table_id) { - auto *accessor = table_accessor(table_id); +std::future BrpcPsClient::PushDenseParam(const Region *regions, + size_t region_num, + size_t table_id) { + auto *accessor = GetTableAccessor(table_id); size_t request_call_num = _server_channels.size(); // 1.拆分Region数据到shard中,后续多shard并行拷贝数据 std::vector> regions_partition(request_call_num); uint32_t num_per_shard = - dense_dim_per_shard(accessor->GetTableInfo(FEA_DIM), request_call_num); + DenseDimPerShard(accessor->GetTableInfo(FEA_DIM), request_call_num); size_t shard_data_size = num_per_shard * accessor->GetTableInfo(UPDATE_SIZE); size_t current_region_idx = 0; size_t current_region_data_idx = 0; @@ -809,17 +749,17 @@ std::future BrpcPsClient::push_dense_param(const Region *regions, fill_num); fill_remain_size -= fill_num; } - PsService_Stub rpc_stub(get_dense_channel(i)); + PsService_Stub rpc_stub(GetDenseChannel(i)); rpc_stub.service(closure->cntl(i), closure->request(i), closure->response(i), closure); } return fut; } -std::future BrpcPsClient::push_sparse_raw_gradient( +std::future BrpcPsClient::PushSparseRawGradient( size_t table_id, const uint64_t *keys, const float **update_values, size_t num, void *done) { - auto *accessor = table_accessor(table_id); + auto *accessor = GetTableAccessor(table_id); // 发送RPC请求 DownpourBrpcClosure *closure = reinterpret_cast(done); auto promise = std::make_shared>(); @@ -872,7 +812,7 @@ std::future BrpcPsClient::push_sparse_raw_gradient( memcpy(push_data_ptr, value_ptr[i], accessor->GetTableInfo(UPDATE_SIZE)); push_data_ptr += accessor->GetTableInfo(UPDATE_SIZE); } - PsService_Stub rpc_stub(get_sparse_channel(shard_idx)); + PsService_Stub rpc_stub(GetSparseChannel(shard_idx)); closure->cntl(shard_idx)->set_request_compress_type( (brpc::CompressType)FLAGS_pserver_communicate_compress_type); rpc_stub.service(closure->cntl(shard_idx), closure->request(shard_idx), @@ -881,7 +821,7 @@ std::future BrpcPsClient::push_sparse_raw_gradient( return fut; } -std::future BrpcPsClient::push_dense_raw_gradient( +std::future BrpcPsClient::PushDenseRawGradient( int table_id, float *total_send_data, size_t total_send_data_size, void *done) { size_t request_call_num = _server_channels.size(); @@ -889,9 +829,9 @@ std::future BrpcPsClient::push_dense_raw_gradient( auto promise = std::make_shared>(); closure->add_promise(promise); std::future fut = promise->get_future(); - auto *accessor = table_accessor(table_id); + auto *accessor = GetTableAccessor(table_id); uint32_t num_per_shard = - dense_dim_per_shard(accessor->GetTableInfo(FEA_DIM), request_call_num); + DenseDimPerShard(accessor->GetTableInfo(FEA_DIM), request_call_num); for (size_t i = 0; i < request_call_num; ++i) { closure->request(i)->set_cmd_id(PS_PUSH_DENSE_TABLE); closure->request(i)->set_table_id(table_id); @@ -905,16 +845,16 @@ std::future BrpcPsClient::push_dense_raw_gradient( total_send_data + i * num_per_shard, num_per_shard * sizeof(float)); // closure->cntl(i)->set_request_compress_type( // (brpc::CompressType)FLAGS_pserver_communicate_compress_type); - PsService_Stub rpc_stub(get_dense_channel(i)); + PsService_Stub rpc_stub(GetDenseChannel(i)); rpc_stub.service(closure->cntl(i), closure->request(i), closure->response(i), closure); } return fut; } -std::future BrpcPsClient::push_global_step(int table_id, - int64_t *total_send_data, - void *done) { 
+std::future BrpcPsClient::PushGlobalStep(int table_id, + int64_t *total_send_data, + void *done) { size_t request_call_num = _server_channels.size(); DownpourBrpcClosure *closure = reinterpret_cast(done); auto promise = std::make_shared>(); @@ -933,17 +873,17 @@ std::future BrpcPsClient::push_global_step(int table_id, memcpy(push_data_ptr + sizeof(uint32_t), total_send_data, num_per_shard * sizeof(int64_t)); - PsService_Stub rpc_stub(get_dense_channel(i)); + PsService_Stub rpc_stub(GetDenseChannel(i)); rpc_stub.service(closure->cntl(i), closure->request(i), closure->response(i), closure); } return fut; } -std::future BrpcPsClient::pull_sparse(float **select_values, - size_t table_id, - const uint64_t *keys, size_t num, - bool is_training) { +std::future BrpcPsClient::PullSparse(float **select_values, + size_t table_id, + const uint64_t *keys, size_t num, + bool is_training) { auto timer = std::make_shared("pserver_client_pull_sparse"); auto local_timer = std::make_shared("pserver_client_pull_sparse_local"); @@ -968,7 +908,7 @@ std::future BrpcPsClient::pull_sparse(float **select_values, shard_sorted_kvs->at(shard_id).push_back({keys[i], select_values[i]}); } - auto *accessor = table_accessor(table_id); + auto *accessor = GetTableAccessor(table_id); size_t value_size = accessor->GetTableInfo(SELECT_SIZE); @@ -1055,7 +995,7 @@ std::future BrpcPsClient::pull_sparse(float **select_values, closure->request(i)->set_client_id(_client_id); closure->request(i)->add_params((char *)&kv_request_count, // NOLINT sizeof(uint32_t)); - PsService_Stub rpc_stub(get_cmd_channel(i)); + PsService_Stub rpc_stub(GetCmdChannel(i)); closure->cntl(i)->set_log_id(butil::gettimeofday_ms()); rpc_stub.service(closure->cntl(i), closure->request(i), closure->response(i), closure); @@ -1065,11 +1005,11 @@ std::future BrpcPsClient::pull_sparse(float **select_values, } // for GEO -std::future BrpcPsClient::pull_sparse_param(float **select_values, - size_t table_id, - const uint64_t *keys, - size_t num, - bool is_training) { +std::future BrpcPsClient::PullSparseParam(float **select_values, + size_t table_id, + const uint64_t *keys, + size_t num, + bool is_training) { auto timer = std::make_shared("pserver_client_pull_sparse_param"); size_t request_call_num = _server_channels.size(); @@ -1082,7 +1022,7 @@ std::future BrpcPsClient::pull_sparse_param(float **select_values, shard_sorted_kvs->at(shard_id).push_back({keys[i], select_values[i]}); } - auto *accessor = table_accessor(table_id); + auto *accessor = GetTableAccessor(table_id); size_t value_size = accessor->GetTableInfo(SELECT_SIZE); DownpourBrpcClosure *closure = new DownpourBrpcClosure( @@ -1169,7 +1109,7 @@ std::future BrpcPsClient::pull_sparse_param(float **select_values, closure->request(i)->set_client_id(_client_id); closure->request(i)->add_params((char *)&kv_request_count, // NOLINT sizeof(uint32_t)); - PsService_Stub rpc_stub(get_cmd_channel(i)); + PsService_Stub rpc_stub(GetCmdChannel(i)); closure->cntl(i)->set_log_id(butil::gettimeofday_ms()); rpc_stub.service(closure->cntl(i), closure->request(i), closure->response(i), closure); @@ -1178,7 +1118,7 @@ std::future BrpcPsClient::pull_sparse_param(float **select_values, return fut; } -std::future BrpcPsClient::send_client2client_msg( +std::future BrpcPsClient::SendClient2ClientMsg( int msg_type, int to_client_id, const std::string &msg) { auto promise = std::make_shared>(); std::future fut = promise->get_future(); @@ -1203,10 +1143,10 @@ std::future BrpcPsClient::send_client2client_msg( return fut; } 
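// ---------------------------------------------------------------------------
// Editorial sketch (not part of the patch): how a trainer is expected to
// drive the renamed client interface above. Only names visible in this diff
// are used (PullDense, PushDense, Flush, FinalizeWorker); `client`, `regions`
// and `table_id` are assumed to come from the surrounding setup code.
//
//   std::future<int32_t> pull = client->PullDense(regions.data(),
//                                                 regions.size(), table_id);
//   pull.wait();                  // every client call returns a future
//   /* ... run a training step, write dense grads back into regions ... */
//   auto push = client->PushDense(regions.data(), regions.size(), table_id);
//   push.wait();
//   client->Flush();              // drain the async push queues
//   client->FinalizeWorker();     // join push threads, stop local service
// ---------------------------------------------------------------------------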
-std::future BrpcPsClient::push_sparse_raw_gradient_partial( +std::future BrpcPsClient::PushSparseRawGradientPartial( size_t table_id, const uint64_t *keys, const float **update_values, uint32_t num, void *done, int pserver_idx) { - auto *accessor = table_accessor(table_id); + auto *accessor = GetTableAccessor(table_id); size_t value_size = accessor->GetTableInfo(UPDATE_SIZE); DownpourBrpcClosure *closure = reinterpret_cast(done); auto promise = std::make_shared>(); @@ -1228,7 +1168,7 @@ std::future BrpcPsClient::push_sparse_raw_gradient_partial( memcpy(push_data_ptr, update_values[i], value_size); push_data_ptr += value_size; } - PsService_Stub rpc_stub(get_sparse_channel(pserver_idx)); + PsService_Stub rpc_stub(GetSparseChannel(pserver_idx)); closure->cntl(0)->set_request_compress_type( (brpc::CompressType)FLAGS_pserver_communicate_compress_type); rpc_stub.service(closure->cntl(0), closure->request(0), closure->response(0), @@ -1236,8 +1176,8 @@ std::future BrpcPsClient::push_sparse_raw_gradient_partial( return fut; } -int32_t BrpcPsClient::recv_and_save_table(const uint64_t table_id, - const std::string &path) { +int32_t BrpcPsClient::RecvAndSaveTable(const uint64_t table_id, + const std::string &path) { // get var information std::string var_name = ""; int64_t var_num = 0; @@ -1271,17 +1211,17 @@ int32_t BrpcPsClient::recv_and_save_table(const uint64_t table_id, save_vec.push_back(save_huge_vec.data() + i * var_shape); } - VLOG(2) << "recv_and_save_table: table_class: " << table_class; + VLOG(2) << "RecvAndSaveTable: table_class: " << table_class; // TODO(zhaocaibei123): new GeoBrpcPSClient, move this to its - // recv_and_save_table + // RecvAndSaveTable if (table_class == "MemorySparseGeoTable") { auto status = - pull_sparse_param(reinterpret_cast(save_vec.data()), table_id, - save_key.data(), save_key.size(), true); + PullSparseParam(reinterpret_cast(save_vec.data()), table_id, + save_key.data(), save_key.size(), true); status.wait(); } else { - auto status = pull_sparse(reinterpret_cast(save_vec.data()), - table_id, save_key.data(), save_key.size(), true); + auto status = PullSparse(reinterpret_cast(save_vec.data()), + table_id, save_key.data(), save_key.size(), true); status.wait(); } @@ -1315,15 +1255,15 @@ int32_t BrpcPsClient::recv_and_save_table(const uint64_t table_id, return 0; } -std::future BrpcPsClient::push_sparse(size_t table_id, - const uint64_t *keys, - const float **update_values, - size_t num) { +std::future BrpcPsClient::PushSparse(size_t table_id, + const uint64_t *keys, + const float **update_values, + size_t num) { auto push_timer = std::make_shared("pserver_client_push_sparse"); CostTimer parse_timer("pserver_client_push_sparse_parse"); int push_sparse_async_num = _push_sparse_task_queue_map[table_id]->Size(); while (push_sparse_async_num > FLAGS_pserver_max_async_call_num) { - // LOG(INFO) << "push_sparse Waiting for async_call_num comsume, + // LOG(INFO) << "PushSparse Waiting for async_call_num comsume, // task_num:" // << push_sparse_async_num // << ", max_task_limit:" << FLAGS_pserver_max_async_call_num; @@ -1333,7 +1273,7 @@ std::future BrpcPsClient::push_sparse(size_t table_id, auto put_timer = std::make_shared("client_push_sparse_put"); thread_local std::vector>> shard_sorted_kv_list; - auto *accessor = table_accessor(table_id); + auto *accessor = GetTableAccessor(table_id); size_t request_call_num = _server_channels.size(); shard_sorted_kv_list.resize(request_call_num); for (auto &x : shard_sorted_kv_list) { @@ -1381,7 +1321,7 @@ std::future 
BrpcPsClient::push_sparse(size_t table_id, return fut; } -void BrpcPsClient::push_sparse_task_consume() { +void BrpcPsClient::PushSparseTaskConsume() { uint64_t merge_size = FLAGS_pserver_push_sparse_merge_limit; std::vector> task_list; size_t request_call_num = _server_channels.size(); @@ -1392,7 +1332,7 @@ void BrpcPsClient::push_sparse_task_consume() { // 所有sparseTable的pushTask 进行处理 for (auto &push_sparse_task_itr : _push_sparse_task_queue_map) { auto table_id = push_sparse_task_itr.first; - auto *accessor = table_accessor(table_id); + auto *accessor = GetTableAccessor(table_id); auto &task_queue = push_sparse_task_itr.second; auto queue_size = task_queue->Size(); if (queue_size == 0) { @@ -1471,7 +1411,7 @@ void BrpcPsClient::push_sparse_task_consume() { for (int shard_idx = 0; shard_idx < request_call_num; ++shard_idx) { merge_status[shard_idx] = async_push_sparse_shard_threads.enqueue(std::bind( - &BrpcPsClient::push_sparse_async_shard_push, this, task_list, + &BrpcPsClient::PushSparseAsyncShardPush, this, task_list, request_kv_num, table_id, shard_idx, closure, accessor)); } for (int shard_idx = 0; shard_idx < request_call_num; ++shard_idx) { @@ -1487,7 +1427,7 @@ void BrpcPsClient::push_sparse_task_consume() { for (int shard_idx = 0; shard_idx < request_call_num; ++shard_idx) { merge_status[shard_idx] = async_push_sparse_shard_threads.enqueue(std::bind( - &BrpcPsClient::push_sparse_async_shard_merge, this, task_list, + &BrpcPsClient::PushSparseAsyncShardMerge, this, task_list, request_kv_num, table_id, shard_idx, accessor)); } for (int shard_idx = 0; shard_idx < request_call_num; ++shard_idx) { @@ -1523,7 +1463,7 @@ void sparse_local_merge(ValueAccessor *accessor, float *merge_data, accessor->Merge(merge_data_shell, another_data_shell, 1); } -int BrpcPsClient::push_sparse_async_shard_merge( +int BrpcPsClient::PushSparseAsyncShardMerge( std::vector> &task_list, std::vector &request_kv_num, int table_id, int shard_idx, ValueAccessor *accessor) { @@ -1615,12 +1555,12 @@ int BrpcPsClient::push_sparse_async_shard_merge( return 0; } -int BrpcPsClient::push_sparse_async_shard_push( +int BrpcPsClient::PushSparseAsyncShardPush( std::vector> &task_list, std::vector &request_kv_num, int table_id, int shard_idx, DownpourBrpcClosure *closure, ValueAccessor *accessor) { - push_sparse_async_shard_merge(task_list, request_kv_num, table_id, shard_idx, - accessor); + PushSparseAsyncShardMerge(task_list, request_kv_num, table_id, shard_idx, + accessor); size_t merged_kv_count = task_list[0]->data()->shared_data[shard_idx].kv_num; auto &merged_key_list = task_list[0]->data()->shared_data[shard_idx].key_list; @@ -1649,7 +1589,7 @@ int BrpcPsClient::push_sparse_async_shard_push( accessor->GetTableInfo(UPDATE_SIZE)); push_data_ptr += accessor->GetTableInfo(UPDATE_SIZE); } - PsService_Stub rpc_stub(get_sparse_channel(shard_idx)); + PsService_Stub rpc_stub(GetSparseChannel(shard_idx)); closure->cntl(shard_idx)->set_request_compress_type( (brpc::CompressType)FLAGS_pserver_communicate_compress_type); rpc_stub.service(closure->cntl(shard_idx), closure->request(shard_idx), @@ -1658,10 +1598,10 @@ int BrpcPsClient::push_sparse_async_shard_push( return 0; } -std::future BrpcPsClient::push_dense(const Region *regions, - size_t region_num, - size_t table_id) { - auto *accessor = table_accessor(table_id); +std::future BrpcPsClient::PushDense(const Region *regions, + size_t region_num, + size_t table_id) { + auto *accessor = GetTableAccessor(table_id); int fea_dim = accessor->GetTableInfo(FEA_DIM); int update_dim = 
accessor->GetTableInfo(UPDATE_DIM); auto push_timer = std::make_shared("pserver_client_push_dense"); @@ -1669,7 +1609,7 @@ std::future BrpcPsClient::push_dense(const Region *regions, std::make_shared("pserver_client_push_dense_parse"); int push_dense_async_num = _push_dense_task_queue_map[table_id]->Size(); while (push_dense_async_num > FLAGS_pserver_max_async_call_num) { - // LOG(INFO) << "push_dense Waiting for async_call_num comsume, + // LOG(INFO) << "PushDense Waiting for async_call_num comsume, // task_num:" // << push_dense_async_num // << ", max_task_limit:" << FLAGS_pserver_max_async_call_num; @@ -1683,7 +1623,7 @@ std::future BrpcPsClient::push_dense(const Region *regions, size_t request_call_num = _server_channels.size(); uint32_t num_per_shard = - dense_dim_per_shard(accessor->GetTableInfo(FEA_DIM), request_call_num); + DenseDimPerShard(accessor->GetTableInfo(FEA_DIM), request_call_num); // 将region数据拷贝到转置矩阵中 async_task->data()->resize(num_per_shard * request_call_num * @@ -1705,7 +1645,7 @@ std::future BrpcPsClient::push_dense(const Region *regions, return fut; } -void BrpcPsClient::push_dense_task_consume() { +void BrpcPsClient::PushDenseTaskConsume() { uint64_t merge_size = FLAGS_pserver_push_dense_merge_limit; static bool scale_gradient = FLAGS_pserver_scale_gradient_by_merge; ::ThreadPool async_merge_dense_threads(10); @@ -1723,7 +1663,7 @@ void BrpcPsClient::push_dense_task_consume() { ++_async_call_num; DenseAsyncTask *task; task_queue->Get(task); - auto *accessor = table_accessor(task->table_id()); + auto *accessor = GetTableAccessor(task->table_id()); // 设置请求回调 size_t request_call_num = _server_channels.size(); @@ -1774,7 +1714,7 @@ void BrpcPsClient::push_dense_task_consume() { merge_status[i].wait(); } - VLOG(3) << "BrpcPsClient::push_dense_task_consume before merge " + VLOG(3) << "BrpcPsClient::PushDenseTaskConsume before merge " "total_send_data[0]" << total_send_data[0] << " total_send_data[-2]" << total_send_data[total_send_data_size - 2] @@ -1787,7 +1727,7 @@ void BrpcPsClient::push_dense_task_consume() { mat *= (1.0 / (merge_count + 1)); } - VLOG(3) << "BrpcPsClient::push_dense_task_consume after merge " + VLOG(3) << "BrpcPsClient::PushDenseTaskConsume after merge " "total_send_data[0]" << total_send_data[0] << " total_send_data[-2]" << total_send_data[total_send_data_size - 2] @@ -1796,8 +1736,8 @@ void BrpcPsClient::push_dense_task_consume() { << merge_count; } std::shared_ptr task_ptr(task); - push_dense_raw_gradient(task_ptr, total_send_data, total_send_data_size, - closure); + PushDenseRawGradient(task_ptr, total_send_data, total_send_data_size, + closure); } auto wait_ms = FLAGS_pserver_async_push_dense_interval_ms - (butil::gettimeofday_ms() - async_start_time_ms); @@ -1807,16 +1747,17 @@ void BrpcPsClient::push_dense_task_consume() { } } -void BrpcPsClient::push_dense_raw_gradient( - std::shared_ptr &task, float *total_send_data, - size_t total_send_data_size, DownpourBrpcClosure *closure) { - auto *accessor = table_accessor(task->table_id()); +void BrpcPsClient::PushDenseRawGradient(std::shared_ptr &task, + float *total_send_data, + size_t total_send_data_size, + DownpourBrpcClosure *closure) { + auto *accessor = GetTableAccessor(task->table_id()); size_t request_call_num = _server_channels.size(); // 将数据拷贝到请求buffer区 auto timer = std::make_shared("pserver_client_push_dense_rpc"); closure->add_timer(timer); uint32_t num_per_shard = - dense_dim_per_shard(accessor->GetTableInfo(FEA_DIM), request_call_num); + 
DenseDimPerShard(accessor->GetTableInfo(FEA_DIM), request_call_num); auto send_timer = std::make_shared("pserver_client_push_dense_send"); for (size_t i = 0; i < request_call_num; ++i) { @@ -1832,7 +1773,7 @@ void BrpcPsClient::push_dense_raw_gradient( total_send_data + i * num_per_shard, num_per_shard * sizeof(float)); closure->cntl(i)->set_request_compress_type( (brpc::CompressType)FLAGS_pserver_communicate_compress_type); - PsService_Stub rpc_stub(get_dense_channel(i)); + PsService_Stub rpc_stub(GetDenseChannel(i)); rpc_stub.service(closure->cntl(i), closure->request(i), closure->response(i), closure); } diff --git a/paddle/fluid/distributed/ps/service/brpc_ps_client.h b/paddle/fluid/distributed/ps/service/brpc_ps_client.h index 8b0cb0741b400..f109b473ca1f4 100644 --- a/paddle/fluid/distributed/ps/service/brpc_ps_client.h +++ b/paddle/fluid/distributed/ps/service/brpc_ps_client.h @@ -50,7 +50,7 @@ class DownpourPsClientService : public PsService { DownpourPsClientService() {} virtual ~DownpourPsClientService() {} - virtual int32_t configure(PSClient *client, size_t rank_id) { + virtual int32_t Configure(PSClient *client, size_t rank_id) { _client = client; _rank = rank_id; return 0; @@ -139,7 +139,7 @@ class BrpcPsClient : public PSClient { BrpcPsClient() {} virtual ~BrpcPsClient() { if (_running) { - flush(); + Flush(); _running = false; } if (_async_push_dense_thread.joinable()) { @@ -154,109 +154,98 @@ class BrpcPsClient : public PSClient { _server_started = false; } } - virtual int32_t create_client2client_connection( - int pserver_timeout_ms, int pserver_connect_timeout_ms, int max_retry); - std::future shrink(uint32_t table_id, + virtual int32_t CreateClient2ClientConnection(int pserver_timeout_ms, + int pserver_connect_timeout_ms, + int max_retry); + std::future Shrink(uint32_t table_id, const std::string threshold) override; - std::future load(const std::string &epoch, + std::future Load(const std::string &epoch, const std::string &mode) override; - std::future load(uint32_t table_id, const std::string &epoch, + std::future Load(uint32_t table_id, const std::string &epoch, const std::string &mode) override; - std::future Load(const LoadSaveContext &load_context) override; - - std::future save(const std::string &epoch, + std::future Save(const std::string &epoch, const std::string &mode) override; - std::future save(uint32_t table_id, const std::string &epoch, + std::future Save(uint32_t table_id, const std::string &epoch, const std::string &mode) override; - virtual std::future Save( - const LoadSaveContext &save_context) override; - - std::future clear() override; - - std::future clear(uint32_t table_id) override; + std::future Clear() override; - std::future stop_server() override; + std::future Clear(uint32_t table_id) override; - std::future start_profiler() override; - std::future stop_profiler() override; + std::future StopServer() override; - void finalize_worker() override; + std::future StartProfiler() override; + std::future StopProfiler() override; - virtual std::future pull_dense(Region *regions, size_t region_num, - size_t table_id); + void FinalizeWorker() override; - virtual std::future push_dense_param(const Region *regions, - size_t region_num, - size_t table_id); + virtual std::future PullDense(Region *regions, size_t region_num, + size_t table_id); - virtual std::future push_dense(const Region *regions, - size_t region_num, size_t table_id); - void push_dense_task_consume(); - virtual std::future pull_sparse(float **select_values, - size_t table_id, - const 
uint64_t *keys, size_t num, - bool is_training); - virtual std::future pull_sparse_param(float **select_values, - size_t table_id, - const uint64_t *keys, - size_t num, bool is_training); + virtual std::future PushDenseParam(const Region *regions, + size_t region_num, + size_t table_id); - virtual std::future Pull(RequestContext &pull_context) override; + virtual std::future PushDense(const Region *regions, + size_t region_num, size_t table_id); + void PushDenseTaskConsume(); + virtual std::future PullSparse(float **select_values, + size_t table_id, const uint64_t *keys, + size_t num, bool is_training); + virtual std::future PullSparseParam(float **select_values, + size_t table_id, + const uint64_t *keys, size_t num, + bool is_training); - virtual std::future Push(RequestContext &push_context) override; + virtual std::future PrintTableStat(uint32_t table_id); - virtual std::future print_table_stat(uint32_t table_id); + virtual std::future Barrier(size_t table_id, uint32_t barrier_type); - virtual std::future barrier(size_t table_id, uint32_t barrier_type); + virtual std::future PullGeoParam(size_t table_id, + std::vector *values, + std::vector *keys, + int pserver_idx); + virtual std::future PushGlobalStep(int table_id, + int64_t *total_send_data, + void *done); + virtual std::future Flush(); - virtual std::future pull_geo_param(size_t table_id, - std::vector *values, - std::vector *keys, - int pserver_idx); - virtual std::future push_global_step(int table_id, - int64_t *total_send_data, - void *done); - virtual std::future flush(); - - std::future send_client2client_msg(int msg_type, int to_client_id, - const std::string &msg) override; + std::future SendClient2ClientMsg(int msg_type, int to_client_id, + const std::string &msg) override; // for local save sparse - virtual int32_t recv_and_save_table(const uint64_t table_id, - const std::string &path); + virtual int32_t RecvAndSaveTable(const uint64_t table_id, + const std::string &path); - void print_queue_size(); - void print_queue_size_thread(); + void PrintQueueSize(); + void PrintQueueSizeThread(); protected: - virtual size_t get_server_nums() { return _server_channels.size(); } - inline brpc::Channel *get_sparse_channel(size_t server_id) { + virtual size_t GetServerNums() { return _server_channels.size(); } + inline brpc::Channel *GetSparseChannel(size_t server_id) { return _server_channels[server_id][0].get(); } - inline brpc::Channel *get_dense_channel(size_t server_id) { + inline brpc::Channel *GetDenseChannel(size_t server_id) { return _server_channels[server_id][1].get(); } - inline brpc::Channel *get_cmd_channel(size_t server_id) { + inline brpc::Channel *GetCmdChannel(size_t server_id) { return _server_channels[server_id][2].get(); } - int32_t initialize() override; + int32_t Initialize() override; private: - // virtual int32_t initialize() override; - - inline uint32_t dense_dim_per_shard(uint32_t dense_dim_total, - uint32_t shard_num) { + inline uint32_t DenseDimPerShard(uint32_t dense_dim_total, + uint32_t shard_num) { return dense_dim_total / shard_num + 1; } - std::future send_cmd(uint32_t table_id, int cmd_id, - const std::vector ¶m); + std::future SendCmd(uint32_t table_id, int cmd_id, + const std::vector ¶m); - std::future send_save_cmd(uint32_t table_id, int cmd_id, - const std::vector ¶m); + std::future SendSaveCmd(uint32_t table_id, int cmd_id, + const std::vector ¶m); bool _running = false; bool _flushing = false; @@ -276,12 +265,12 @@ class BrpcPsClient : public PSClient { std::thread _print_thread; - int 
push_sparse_async_shard_merge( + int PushSparseAsyncShardMerge( std::vector> &task_list, // NOLINT std::vector &request_kv_num, int table_id, int shard_idx, // NOLINT ValueAccessor *accessor); - int push_sparse_async_shard_push( + int PushSparseAsyncShardPush( std::vector> &task_list, // NOLINT std::vector &request_kv_num, int table_id, int shard_idx, // NOLINT DownpourBrpcClosure *closure, ValueAccessor *accessor); @@ -292,36 +281,36 @@ class BrpcPsClient : public PSClient { _client_channels; // client2client std::vector, 3>> _server_channels; // client2server - std::future push_dense_raw_gradient(int table_id, - float *total_send_data, - size_t total_send_data_size, - void *done) override; - - std::future push_sparse_raw_gradient(size_t table_id, - const uint64_t *keys, - const float **update_values, - size_t num, - void *done) override; - - std::future push_sparse_raw_gradient_partial( - size_t table_id, const uint64_t *keys, const float **update_values, - uint32_t num, void *done, int pserver_idx) override; - - std::future push_sparse_param(size_t table_id, const uint64_t *keys, - const float **update_values, - size_t num, void *done) override; - std::future push_sparse(size_t table_id, const uint64_t *keys, - const float **update_values, - size_t num) override; - void push_sparse_task_consume(); + std::future PushDenseRawGradient(int table_id, + float *total_send_data, + size_t total_send_data_size, + void *done) override; + + std::future PushSparseRawGradient(size_t table_id, + const uint64_t *keys, + const float **update_values, + size_t num, void *done) override; + + std::future PushSparseRawGradientPartial(size_t table_id, + const uint64_t *keys, + const float **update_values, + uint32_t num, void *done, + int pserver_idx) override; + + std::future PushSparseParam(size_t table_id, const uint64_t *keys, + const float **update_values, size_t num, + void *done) override; + std::future PushSparse(size_t table_id, const uint64_t *keys, + const float **update_values, + size_t num) override; + void PushSparseTaskConsume(); private: - int32_t start_client_service(); + int32_t StartClientService(); - void push_dense_raw_gradient(std::shared_ptr &task, // NOLINT - float *total_send_data, - size_t total_send_data_size, - DownpourBrpcClosure *closure); + void PushDenseRawGradient(std::shared_ptr &task, // NOLINT + float *total_send_data, size_t total_send_data_size, + DownpourBrpcClosure *closure); float _mae = 0; float _mse = 0; uint16_t _push_times = 0; diff --git a/paddle/fluid/distributed/ps/service/brpc_ps_server.cc b/paddle/fluid/distributed/ps/service/brpc_ps_server.cc index 2e77020c30751..1d88d88ebcf14 100644 --- a/paddle/fluid/distributed/ps/service/brpc_ps_server.cc +++ b/paddle/fluid/distributed/ps/service/brpc_ps_server.cc @@ -31,7 +31,7 @@ class RpcController; namespace paddle { namespace distributed { -int32_t BrpcPsServer::initialize() { +int32_t BrpcPsServer::Initialize() { auto &service_config = _config.downpour_server_param().service_param(); if (!service_config.has_service_class()) { LOG(ERROR) << "miss service_class in ServerServiceParameter"; @@ -46,7 +46,7 @@ int32_t BrpcPsServer::initialize() { } _service.reset(service); - if (service->configure(this) != 0 || service->initialize() != 0) { + if (service->Configure(this) != 0 || service->Initialize() != 0) { LOG(ERROR) << "service initialize failed, service_name:" << service_config.service_class(); return -1; @@ -59,7 +59,7 @@ int32_t BrpcPsServer::initialize() { return 0; } -uint64_t BrpcPsServer::start(const std::string 
&ip, uint32_t port) { +uint64_t BrpcPsServer::Start(const std::string &ip, uint32_t port) { std::unique_lock lock(mutex_); std::string ip_port = ip + ":" + std::to_string(port); @@ -68,7 +68,7 @@ uint64_t BrpcPsServer::start(const std::string &ip, uint32_t port) { brpc::ServerOptions options; int num_threads = std::thread::hardware_concurrency(); - auto trainers = _environment->get_trainers(); + auto trainers = _environment->GetTrainers(); options.num_threads = trainers > num_threads ? trainers : num_threads; if (_server.Start(ip_port.c_str(), &options) != 0) { @@ -83,7 +83,7 @@ uint64_t BrpcPsServer::start(const std::string &ip, uint32_t port) { } } - _environment->registe_ps_server(ip, port, _rank); + _environment->RegistePsServer(ip, port, _rank); cv_.wait(lock, [&] { return stoped_; }); PSHost host; @@ -93,31 +93,30 @@ uint64_t BrpcPsServer::start(const std::string &ip, uint32_t port) { return host.rank; } -int32_t BrpcPsServer::port() { return _server.listen_address().port; } +int32_t BrpcPsServer::Port() { return _server.listen_address().port; } -int32_t BrpcPsService::initialize() { +int32_t BrpcPsService::Initialize() { _is_initialize_shard_info = false; - _service_handler_map[PS_STOP_SERVER] = &BrpcPsService::stop_server; - _service_handler_map[PS_PULL_DENSE_TABLE] = &BrpcPsService::pull_dense; - _service_handler_map[PS_PUSH_DENSE_TABLE] = &BrpcPsService::push_dense; - _service_handler_map[PS_PULL_SPARSE_TABLE] = &BrpcPsService::pull_sparse; - _service_handler_map[PS_PUSH_SPARSE_TABLE] = &BrpcPsService::push_sparse; - _service_handler_map[PS_SAVE_ONE_TABLE] = &BrpcPsService::save_one_table; - _service_handler_map[PS_SAVE_ALL_TABLE] = &BrpcPsService::save_all_table; - _service_handler_map[PS_SHRINK_TABLE] = &BrpcPsService::shrink_table; - _service_handler_map[PS_LOAD_ONE_TABLE] = &BrpcPsService::load_one_table; - _service_handler_map[PS_LOAD_ALL_TABLE] = &BrpcPsService::load_all_table; - _service_handler_map[PS_CLEAR_ONE_TABLE] = &BrpcPsService::clear_one_table; - _service_handler_map[PS_CLEAR_ALL_TABLE] = &BrpcPsService::clear_all_table; - _service_handler_map[PS_PUSH_DENSE_PARAM] = &BrpcPsService::push_dense_param; - _service_handler_map[PS_PRINT_TABLE_STAT] = &BrpcPsService::print_table_stat; - _service_handler_map[PS_PULL_GEO_PARAM] = &BrpcPsService::pull_geo_param; - _service_handler_map[PS_PUSH_SPARSE_PARAM] = - &BrpcPsService::push_sparse_param; - _service_handler_map[PS_BARRIER] = &BrpcPsService::barrier; - _service_handler_map[PS_START_PROFILER] = &BrpcPsService::start_profiler; - _service_handler_map[PS_STOP_PROFILER] = &BrpcPsService::stop_profiler; - _service_handler_map[PS_PUSH_GLOBAL_STEP] = &BrpcPsService::push_global_step; + _service_handler_map[PS_STOP_SERVER] = &BrpcPsService::StopServer; + _service_handler_map[PS_PULL_DENSE_TABLE] = &BrpcPsService::PullDense; + _service_handler_map[PS_PUSH_DENSE_TABLE] = &BrpcPsService::PushDense; + _service_handler_map[PS_PULL_SPARSE_TABLE] = &BrpcPsService::PullSparse; + _service_handler_map[PS_PUSH_SPARSE_TABLE] = &BrpcPsService::PushSparse; + _service_handler_map[PS_SAVE_ONE_TABLE] = &BrpcPsService::SaveOneTable; + _service_handler_map[PS_SAVE_ALL_TABLE] = &BrpcPsService::SaveAllTable; + _service_handler_map[PS_SHRINK_TABLE] = &BrpcPsService::ShrinkTable; + _service_handler_map[PS_LOAD_ONE_TABLE] = &BrpcPsService::LoadOneTable; + _service_handler_map[PS_LOAD_ALL_TABLE] = &BrpcPsService::LoadAllTable; + _service_handler_map[PS_CLEAR_ONE_TABLE] = &BrpcPsService::ClearOneTable; + _service_handler_map[PS_CLEAR_ALL_TABLE] = 
&BrpcPsService::ClearAllTable; + _service_handler_map[PS_PUSH_DENSE_PARAM] = &BrpcPsService::PushDenseParam; + _service_handler_map[PS_PRINT_TABLE_STAT] = &BrpcPsService::PrintTableStat; + _service_handler_map[PS_PULL_GEO_PARAM] = &BrpcPsService::PullGeoParam; + _service_handler_map[PS_PUSH_SPARSE_PARAM] = &BrpcPsService::PushSparseParam; + _service_handler_map[PS_BARRIER] = &BrpcPsService::Barrier; + _service_handler_map[PS_START_PROFILER] = &BrpcPsService::StartProfiler; + _service_handler_map[PS_STOP_PROFILER] = &BrpcPsService::StopProfiler; + _service_handler_map[PS_PUSH_GLOBAL_STEP] = &BrpcPsService::PushGlobalStep; auto &profiler = CostProfiler::instance(); profiler.register_profiler("pserver_server_pull_dense"); profiler.register_profiler("pserver_server_push_dense"); @@ -125,7 +124,7 @@ int32_t BrpcPsService::initialize() { profiler.register_profiler("pserver_server_push_sparse"); // shard初始化,server启动后才可从env获取到server_list的shard信息 - initialize_shard_info(); + InitializeShardInfo(); return 0; } @@ -138,16 +137,16 @@ int32_t BrpcPsService::initialize() { return -1; \ } -int32_t BrpcPsService::initialize_shard_info() { +int32_t BrpcPsService::InitializeShardInfo() { if (!_is_initialize_shard_info) { std::lock_guard guard(_initialize_shard_mutex); if (_is_initialize_shard_info) { return 0; } - size_t shard_num = _server->environment()->get_ps_servers().size(); - auto &table_map = *(_server->table()); + size_t shard_num = _server->Environment()->GetPsServers().size(); + auto &table_map = *(_server->GetTable()); for (auto itr : table_map) { - itr.second->set_shard(_rank, shard_num); + itr.second->SetShard(_rank, shard_num); } _is_initialize_shard_info = true; } @@ -167,7 +166,7 @@ void BrpcPsService::service(google::protobuf::RpcController *cntl_base, response->set_err_code(0); response->set_err_msg(""); - auto *table = _server->table(request->table_id()); + auto *table = _server->GetTable(request->table_id()); brpc::Controller *cntl = static_cast(cntl_base); auto itr = _service_handler_map.find(request->cmd_id()); if (itr == _service_handler_map.end()) { @@ -185,11 +184,11 @@ void BrpcPsService::service(google::protobuf::RpcController *cntl_base, } } -int32_t BrpcPsService::pull_dense(Table *table, const PsRequestMessage &request, - PsResponseMessage &response, - brpc::Controller *cntl) { +int32_t BrpcPsService::PullDense(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { platform::RecordEvent record_event( - "PsService->pull_dense", platform::TracerEventType::Communication, 1); + "PsService->PullDense", platform::TracerEventType::Communication, 1); CHECK_TABLE_EXIST(table, request, response) if (request.params_size() < 1) { set_response_code( @@ -206,14 +205,15 @@ int32_t BrpcPsService::pull_dense(Table *table, const PsRequestMessage &request, } auto res_data = butil::get_object>(); - res_data->resize(num * table->value_accesor()->GetTableInfo(SELECT_SIZE) / + res_data->resize(num * table->ValueAccesor()->GetTableInfo(SELECT_SIZE) / sizeof(float)); + TableContext table_context; table_context.value_type = Dense; table_context.pull_context.values = res_data->data(); table_context.num = num; table->Pull(table_context); - // table->pull_dense(res_data->data(), num); + // table->PullDense(res_data->data(), num); cntl->response_attachment().append((char *)(res_data->data()), res_data->size() * sizeof(float)); @@ -222,13 +222,12 @@ int32_t BrpcPsService::pull_dense(Table *table, const PsRequestMessage &request, return 0; } -int32_t 
BrpcPsService::push_dense_param(Table *table, - const PsRequestMessage &request, - PsResponseMessage &response, - brpc::Controller *cntl) { - platform::RecordEvent record_event("PsService->push_dense_param", - platform::TracerEventType::Communication, - 1); +int32_t BrpcPsService::PushDenseParam(Table *table, + const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { + platform::RecordEvent record_event( + "PsService->PushDenseParam", platform::TracerEventType::Communication, 1); CHECK_TABLE_EXIST(table, request, response) thread_local std::string push_buffer; auto &req_io_buffer = cntl->request_attachment(); @@ -245,17 +244,17 @@ int32_t BrpcPsService::push_dense_param(Table *table, uint32_t num = *(const uint32_t *)data; const float *values = (const float *)(data + sizeof(uint32_t)); - if (table->push_dense_param(values, num) != 0) { - set_response_code(response, -1, "push_dense_param failed"); + if (table->PushDenseParam(values, num) != 0) { + set_response_code(response, -1, "PushDenseParam failed"); } return 0; } -int32_t BrpcPsService::push_dense(Table *table, const PsRequestMessage &request, - PsResponseMessage &response, - brpc::Controller *cntl) { +int32_t BrpcPsService::PushDense(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { platform::RecordEvent record_event( - "PsService->push_dense", platform::TracerEventType::Communication, 1); + "PsService->PushDense", platform::TracerEventType::Communication, 1); CHECK_TABLE_EXIST(table, request, response) auto req_buffer_size = request.data().size(); if (req_buffer_size < 1) { @@ -278,14 +277,14 @@ int32_t BrpcPsService::push_dense(Table *table, const PsRequestMessage &request, // const float *values = (const float *)(request.data().data() + // sizeof(uint32_t)); if (table->Push(table_context) != 0) { - // if (table->push_dense(values, num) != 0) { - set_response_code(response, -1, "push_dense failed"); + // if (table->PushDense(values, num) != 0) { + set_response_code(response, -1, "PushDense failed"); } return 0; } -int32_t BrpcPsService::barrier(Table *table, const PsRequestMessage &request, +int32_t BrpcPsService::Barrier(Table *table, const PsRequestMessage &request, PsResponseMessage &response, brpc::Controller *cntl) { CHECK_TABLE_EXIST(table, request, response) @@ -299,15 +298,15 @@ int32_t BrpcPsService::barrier(Table *table, const PsRequestMessage &request, auto trainer_id = request.client_id(); auto barrier_type = request.params(0); - table->barrier(trainer_id, barrier_type); + table->Barrier(trainer_id, barrier_type); return 0; } -int32_t BrpcPsService::push_sparse_param(Table *table, - const PsRequestMessage &request, - PsResponseMessage &response, - brpc::Controller *cntl) { - platform::RecordEvent record_event("PsService->push_sparse_param", +int32_t BrpcPsService::PushSparseParam(Table *table, + const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { + platform::RecordEvent record_event("PsService->PushSparseParam", platform::TracerEventType::Communication, 1); CHECK_TABLE_EXIST(table, request, response) @@ -331,16 +330,16 @@ int32_t BrpcPsService::push_sparse_param(Table *table, const uint64_t *keys = (const uint64_t *)push_data.data(); const float *values = (const float *)(push_data.data() + sizeof(uint64_t) * num); - if (table->push_sparse_param(keys, values, num) != 0) { - set_response_code(response, -1, "push_sparse_param error"); + if (table->PushSparseParam(keys, values, num) != 0) { + 
set_response_code(response, -1, "PushSparseParam error"); } return 0; } -int32_t BrpcPsService::pull_geo_param(Table *table, - const PsRequestMessage &request, - PsResponseMessage &response, - brpc::Controller *cntl) { +int32_t BrpcPsService::PullGeoParam(Table *table, + const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { platform::RecordEvent record_event( "PsService->pull_geo_param", platform::TracerEventType::Communication, 1); CHECK_TABLE_EXIST(table, request, response) @@ -350,7 +349,7 @@ int32_t BrpcPsService::pull_geo_param(Table *table, std::vector values; std::vector ids; - table->pull_geo_param(trainer_id, &values, &ids); + table->PullGeoParam(trainer_id, &values, &ids); uint32_t num = ids.size(); cntl->response_attachment().append((char *)(&num), sizeof(uint32_t)); @@ -361,12 +360,11 @@ int32_t BrpcPsService::pull_geo_param(Table *table, return 0; } -int32_t BrpcPsService::pull_sparse(Table *table, - const PsRequestMessage &request, - PsResponseMessage &response, - brpc::Controller *cntl) { +int32_t BrpcPsService::PullSparse(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { platform::RecordEvent record_event( - "PsService->pull_sparse", platform::TracerEventType::Communication, 1); + "PsService->PullSparse", platform::TracerEventType::Communication, 1); CHECK_TABLE_EXIST(table, request, response) auto &req_io_buffer = cntl->request_attachment(); @@ -386,7 +384,7 @@ int32_t BrpcPsService::pull_sparse(Table *table, CostTimer timer("pserver_server_pull_sparse"); uint32_t num = *(uint32_t *)(request.params(0).c_str()); - auto dim = table->value_accesor()->GetTableInfo(SELECT_DIM); + auto dim = table->ValueAccesor()->GetTableInfo(SELECT_DIM); thread_local std::string req_buffer; req_buffer.reserve(req_buffer_size); @@ -405,7 +403,7 @@ int32_t BrpcPsService::pull_sparse(Table *table, table_context.pull_context.pull_value = value; table_context.pull_context.values = res_data->data(); table->Pull(table_context); - // table->pull_sparse(res_data->data(), value); + // table->PullSparse(res_data->data(), value); cntl->response_attachment().append((char *)(res_data->data()), res_data->size() * sizeof(float)); @@ -413,12 +411,11 @@ int32_t BrpcPsService::pull_sparse(Table *table, return 0; } -int32_t BrpcPsService::push_sparse(Table *table, - const PsRequestMessage &request, - PsResponseMessage &response, - brpc::Controller *cntl) { +int32_t BrpcPsService::PushSparse(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { platform::RecordEvent record_event( - "PsService->push_sparse", platform::TracerEventType::Communication, 1); + "PsService->PushSparse", platform::TracerEventType::Communication, 1); CHECK_TABLE_EXIST(table, request, response) auto &push_data = request.data(); if (push_data.size() < 1) { @@ -448,18 +445,18 @@ int32_t BrpcPsService::push_sparse(Table *table, // const float *values = (const float *)(push_data.data() + sizeof(uint64_t) * // num); if (table->Push(table_context) != 0) { - // if (table->push_sparse(keys, values, num) != 0) { - set_response_code(response, -1, "push_sparse error"); + // if (table->PushSparse(keys, values, num) != 0) { + set_response_code(response, -1, "PushSparse error"); } return 0; } -int32_t BrpcPsService::print_table_stat(Table *table, - const PsRequestMessage &request, - PsResponseMessage &response, - brpc::Controller *cntl) { +int32_t BrpcPsService::PrintTableStat(Table *table, + const 
PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { CHECK_TABLE_EXIST(table, request, response) - std::pair ret = table->print_table_stat(); + std::pair ret = table->PrintTableStat(); paddle::framework::BinaryArchive ar; ar << ret.first << ret.second; std::string table_info(ar.Buffer(), ar.Length()); @@ -468,10 +465,10 @@ int32_t BrpcPsService::print_table_stat(Table *table, return 0; } -int32_t BrpcPsService::load_one_table(Table *table, - const PsRequestMessage &request, - PsResponseMessage &response, - brpc::Controller *cntl) { +int32_t BrpcPsService::LoadOneTable(Table *table, + const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { CHECK_TABLE_EXIST(table, request, response) if (request.params_size() < 2) { set_response_code( @@ -479,20 +476,20 @@ int32_t BrpcPsService::load_one_table(Table *table, "PsRequestMessage.datas is requeired at least 2 for path & load_param"); return -1; } - if (table->load(request.params(0), request.params(1)) != 0) { + if (table->Load(request.params(0), request.params(1)) != 0) { set_response_code(response, -1, "table load failed"); return -1; } return 0; } -int32_t BrpcPsService::load_all_table(Table *table, - const PsRequestMessage &request, - PsResponseMessage &response, - brpc::Controller *cntl) { - auto &table_map = *(_server->table()); +int32_t BrpcPsService::LoadAllTable(Table *table, + const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { + auto &table_map = *(_server->GetTable()); for (auto &itr : table_map) { - if (load_one_table(itr.second.get(), request, response, cntl) != 0) { + if (LoadOneTable(itr.second.get(), request, response, cntl) != 0) { LOG(ERROR) << "load table[" << itr.first << "] failed"; return -1; } @@ -500,10 +497,10 @@ int32_t BrpcPsService::load_all_table(Table *table, return 0; } -int32_t BrpcPsService::save_one_table(Table *table, - const PsRequestMessage &request, - PsResponseMessage &response, - brpc::Controller *cntl) { +int32_t BrpcPsService::SaveOneTable(Table *table, + const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { CHECK_TABLE_EXIST(table, request, response) if (request.params_size() < 2) { set_response_code( @@ -511,12 +508,12 @@ int32_t BrpcPsService::save_one_table(Table *table, "PsRequestMessage.datas is requeired at least 2, path&mode"); return -1; } - table->flush(); + table->Flush(); int32_t feasign_size = 0; VLOG(3) << "save table " << request.params(0) << " " << request.params(1); - feasign_size = table->save(request.params(0), request.params(1)); + feasign_size = table->Save(request.params(0), request.params(1)); if (feasign_size < 0) { set_response_code(response, -1, "table save failed"); return -1; @@ -524,16 +521,16 @@ int32_t BrpcPsService::save_one_table(Table *table, return feasign_size; } -int32_t BrpcPsService::save_all_table(Table *table, - const PsRequestMessage &request, - PsResponseMessage &response, - brpc::Controller *cntl) { - auto &table_map = *(_server->table()); +int32_t BrpcPsService::SaveAllTable(Table *table, + const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { + auto &table_map = *(_server->GetTable()); int32_t all_feasign_size = 0; int32_t feasign_size = 0; for (auto &itr : table_map) { - feasign_size = save_one_table(itr.second.get(), request, response, cntl); + feasign_size = SaveOneTable(itr.second.get(), request, response, cntl); if (feasign_size < 0) { LOG(ERROR) << "save table[" << 
itr.first << "] failed"; return -1; @@ -542,10 +539,10 @@ int32_t BrpcPsService::save_all_table(Table *table, return 0; } -int32_t BrpcPsService::shrink_table(Table *table, - const PsRequestMessage &request, - PsResponseMessage &response, - brpc::Controller *cntl) { +int32_t BrpcPsService::ShrinkTable(Table *table, + const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { CHECK_TABLE_EXIST(table, request, response) if (request.params_size() < 1) { set_response_code( @@ -553,8 +550,8 @@ int32_t BrpcPsService::shrink_table(Table *table, "PsRequestMessage.datas is requeired at least 1, threshold"); return -1; } - table->flush(); - if (table->shrink(request.params(0)) != 0) { + table->Flush(); + if (table->Shrink(request.params(0)) != 0) { set_response_code(response, -1, "table shrink failed"); return -1; } @@ -562,63 +559,62 @@ int32_t BrpcPsService::shrink_table(Table *table, return 0; } -int32_t BrpcPsService::clear_one_table(Table *table, - const PsRequestMessage &request, - PsResponseMessage &response, - brpc::Controller *cntl) { +int32_t BrpcPsService::ClearOneTable(Table *table, + const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { CHECK_TABLE_EXIST(table, request, response) - table->flush(); - table->clear(); + table->Flush(); + table->Clear(); return 0; } -int32_t BrpcPsService::clear_all_table(Table *table, - const PsRequestMessage &request, - PsResponseMessage &response, - brpc::Controller *cntl) { - auto &table_map = *(_server->table()); +int32_t BrpcPsService::ClearAllTable(Table *table, + const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { + auto &table_map = *(_server->GetTable()); for (auto &itr : table_map) { - if (clear_one_table(itr.second.get(), request, response, cntl) != 0) { + if (ClearOneTable(itr.second.get(), request, response, cntl) != 0) { return -1; } } return 0; } -int32_t BrpcPsService::stop_server(Table *table, - const PsRequestMessage &request, - PsResponseMessage &response, - brpc::Controller *cntl) { +int32_t BrpcPsService::StopServer(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { auto *p_server = _server; std::thread t_stop([p_server]() { - p_server->stop(); + p_server->Stop(); VLOG(3) << "Server Stoped"; }); t_stop.detach(); return 0; } -int32_t BrpcPsService::stop_profiler(Table *table, - const PsRequestMessage &request, - PsResponseMessage &response, - brpc::Controller *cntl) { +int32_t BrpcPsService::StopProfiler(Table *table, + const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { platform::DisableProfiler(platform::EventSortingKey::kDefault, string::Sprintf("server_%s_profile", _rank)); return 0; } -int32_t BrpcPsService::start_profiler(Table *table, - const PsRequestMessage &request, - PsResponseMessage &response, - brpc::Controller *cntl) { +int32_t BrpcPsService::StartProfiler(Table *table, + const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { platform::EnableProfiler(platform::ProfilerState::kCPU); return 0; } -int32_t BrpcPsService::push_global_step(Table *table, - const PsRequestMessage &request, - PsResponseMessage &response, - brpc::Controller *cntl) { +int32_t BrpcPsService::PushGlobalStep(Table *table, + const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { CHECK_TABLE_EXIST(table, request, response); auto req_buffer_size = request.data().size(); if 
(req_buffer_size < 1) { @@ -629,7 +625,7 @@ int32_t BrpcPsService::push_global_step(Table *table, const int64_t *values = (const int64_t *)(request.data().data() + sizeof(uint32_t)); auto trainer_id = request.client_id(); - if (table->push_dense(values, trainer_id) != 0) { + if (table->PushDense(values, trainer_id) != 0) { set_response_code(response, -1, "run_program failed"); } diff --git a/paddle/fluid/distributed/ps/service/brpc_ps_server.h b/paddle/fluid/distributed/ps/service/brpc_ps_server.h index d81a3a5df07f1..250f465d84253 100644 --- a/paddle/fluid/distributed/ps/service/brpc_ps_server.h +++ b/paddle/fluid/distributed/ps/service/brpc_ps_server.h @@ -41,8 +41,8 @@ class BrpcPsServer : public PSServer { public: BrpcPsServer() {} virtual ~BrpcPsServer() {} - virtual uint64_t start(const std::string &ip, uint32_t port); - virtual int32_t stop() { + virtual uint64_t Start(const std::string &ip, uint32_t port); + virtual int32_t Stop() { std::unique_lock lock(mutex_); stoped_ = true; cv_.notify_all(); @@ -51,10 +51,10 @@ class BrpcPsServer : public PSServer { _server.Join(); return 0; } - int32_t port(); + int32_t Port(); private: - virtual int32_t initialize(); + virtual int32_t Initialize(); mutable std::mutex mutex_; std::condition_variable cv_; bool stoped_ = false; @@ -71,7 +71,7 @@ typedef int32_t (BrpcPsService::*serviceHandlerFunc)( class BrpcPsService : public PsBaseService { public: - virtual int32_t initialize() override; + virtual int32_t Initialize() override; virtual void service(::google::protobuf::RpcController *controller, const PsRequestMessage *request, @@ -79,50 +79,49 @@ class BrpcPsService : public PsBaseService { ::google::protobuf::Closure *done) override; private: - int32_t initialize_shard_info(); - int32_t pull_dense(Table *table, const PsRequestMessage &request, - PsResponseMessage &response, brpc::Controller *cntl); - int32_t push_dense(Table *table, const PsRequestMessage &request, - PsResponseMessage &response, brpc::Controller *cntl); - int32_t push_dense_param(Table *table, const PsRequestMessage &request, - PsResponseMessage &response, brpc::Controller *cntl); - int32_t push_sparse_param(Table *table, const PsRequestMessage &request, - PsResponseMessage &response, - brpc::Controller *cntl); - int32_t pull_sparse(Table *table, const PsRequestMessage &request, - PsResponseMessage &response, brpc::Controller *cntl); - int32_t pull_geo_param(Table *table, const PsRequestMessage &request, + int32_t InitializeShardInfo(); + int32_t PullDense(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, brpc::Controller *cntl); + int32_t PushDense(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, brpc::Controller *cntl); + int32_t PushDenseParam(Table *table, const PsRequestMessage &request, PsResponseMessage &response, brpc::Controller *cntl); - int32_t barrier(Table *table, const PsRequestMessage &request, + int32_t PushSparseParam(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, brpc::Controller *cntl); + int32_t PullSparse(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, brpc::Controller *cntl); + int32_t PullGeoParam(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, brpc::Controller *cntl); + int32_t Barrier(Table *table, const PsRequestMessage &request, PsResponseMessage &response, brpc::Controller *cntl); - int32_t push_sparse(Table *table, const PsRequestMessage &request, - PsResponseMessage &response, brpc::Controller 
*cntl); - int32_t load_one_table(Table *table, const PsRequestMessage &request, - PsResponseMessage &response, brpc::Controller *cntl); - int32_t load_all_table(Table *table, const PsRequestMessage &request, - PsResponseMessage &response, brpc::Controller *cntl); - int32_t save_one_table(Table *table, const PsRequestMessage &request, - PsResponseMessage &response, brpc::Controller *cntl); - int32_t save_all_table(Table *table, const PsRequestMessage &request, - PsResponseMessage &response, brpc::Controller *cntl); - int32_t shrink_table(Table *table, const PsRequestMessage &request, + int32_t PushSparse(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, brpc::Controller *cntl); + int32_t LoadOneTable(Table *table, const PsRequestMessage &request, PsResponseMessage &response, brpc::Controller *cntl); - int32_t clear_one_table(Table *table, const PsRequestMessage &request, - PsResponseMessage &response, brpc::Controller *cntl); - int32_t clear_all_table(Table *table, const PsRequestMessage &request, - PsResponseMessage &response, brpc::Controller *cntl); - int32_t stop_server(Table *table, const PsRequestMessage &request, + int32_t LoadAllTable(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, brpc::Controller *cntl); + int32_t SaveOneTable(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, brpc::Controller *cntl); + int32_t SaveAllTable(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, brpc::Controller *cntl); + int32_t ShrinkTable(Table *table, const PsRequestMessage &request, PsResponseMessage &response, brpc::Controller *cntl); - int32_t start_profiler(Table *table, const PsRequestMessage &request, - PsResponseMessage &response, brpc::Controller *cntl); - int32_t stop_profiler(Table *table, const PsRequestMessage &request, + int32_t ClearOneTable(Table *table, const PsRequestMessage &request, PsResponseMessage &response, brpc::Controller *cntl); + int32_t ClearAllTable(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, brpc::Controller *cntl); + int32_t StopServer(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, brpc::Controller *cntl); + int32_t StartProfiler(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, brpc::Controller *cntl); + int32_t StopProfiler(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, brpc::Controller *cntl); - int32_t print_table_stat(Table *table, const PsRequestMessage &request, - PsResponseMessage &response, brpc::Controller *cntl); + int32_t PrintTableStat(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, brpc::Controller *cntl); - int32_t push_global_step(Table *table, const PsRequestMessage &request, - PsResponseMessage &response, brpc::Controller *cntl); + int32_t PushGlobalStep(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, brpc::Controller *cntl); bool _is_initialize_shard_info; std::mutex _initialize_shard_mutex; diff --git a/paddle/fluid/distributed/ps/service/communicator/communicator.cc b/paddle/fluid/distributed/ps/service/communicator/communicator.cc index 50c34bd319253..c4b833f294e17 100644 --- a/paddle/fluid/distributed/ps/service/communicator/communicator.cc +++ b/paddle/fluid/distributed/ps/service/communicator/communicator.cc @@ -39,7 +39,7 @@ inline double GetCurrentUS() { Communicator::Communicator() {} -void Communicator::init_gflag(const std::string &gflags) { +void 
Communicator::InitGFlag(const std::string &gflags) { VLOG(3) << "Init With Gflags:" << gflags; std::vector flags = paddle::string::split_string(gflags); if (flags.size() < 1) { @@ -73,7 +73,7 @@ void Communicator::InitBrpcClient( } std::vector Communicator::GetClientInfo() { - std::vector res = _ps_env.get_client_info(); + std::vector res = _ps_env.GetClientInfo(); for (auto rr : res) { VLOG(2) << "Communicator::GetClientInfo " << rr; } @@ -82,7 +82,7 @@ std::vector Communicator::GetClientInfo() { int Communicator::SetClients(std::vector &host_sign_list) { int node = host_sign_list.size(); - return _ps_env.set_ps_clients(host_sign_list.data(), node); + return _ps_env.SetPsClients(host_sign_list.data(), node); } void Communicator::RpcRecvDense(const std::vector &varnames, @@ -114,7 +114,7 @@ void Communicator::RpcRecvDense(const std::vector &varnames, } } auto status = - _worker_ptr->pull_dense(regions.data(), regions.size(), table_id); + _worker_ptr->PullDense(regions.data(), regions.size(), table_id); status.wait(); for (auto &t : varnames) { @@ -177,7 +177,7 @@ void Communicator::RpcSendDenseParam(const std::vector &varnames, } } auto status = - _worker_ptr->push_dense_param(regions.data(), regions.size(), table_id); + _worker_ptr->PushDenseParam(regions.data(), regions.size(), table_id); status.wait(); VLOG(4) << "RPC Send Dense Param " << table_id << " done!"; return; @@ -190,9 +190,9 @@ void Communicator::RpcSendDense(const CommContext &ctx, const Scope &scope) { auto &var_names = ctx.origin_varnames; auto &table_id = ctx.table_id; auto dense_data = std::make_shared>(); - size_t request_call_num = _worker_ptr->get_server_nums(); + size_t request_call_num = _worker_ptr->GetServerNums(); uint32_t num_per_shard = - dense_dim_per_shard(ctx.height_sections[0], request_call_num); + DenseDimPerShard(ctx.height_sections[0], request_call_num); dense_data->resize(num_per_shard * request_call_num); // accessor->update_dim() = 1 float *data = dense_data->data(); @@ -222,8 +222,8 @@ void Communicator::RpcSendDense(const CommContext &ctx, const Scope &scope) { closure->set_promise_value(ret); --_async_call_num; }); - auto status = _worker_ptr->push_dense_raw_gradient( - table_id, data, dense_data->size(), closure); + auto status = _worker_ptr->PushDenseRawGradient(table_id, data, + dense_data->size(), closure); status.wait(); return; } @@ -233,7 +233,7 @@ void Communicator::RpcSendSparseParam(const std::string &varname, int table_id, platform::RecordEvent record_event("Communicator->RpcSendSparseParam", platform::TracerEventType::Communication, 1); - size_t request_call_num = _worker_ptr->get_server_nums(); + size_t request_call_num = _worker_ptr->GetServerNums(); std::vector push_g_vec; auto *send_var = scope.FindVar(varname); @@ -260,9 +260,9 @@ void Communicator::RpcSendSparseParam(const std::string &varname, int table_id, } closure->set_promise_value(ret); }); - auto status = _worker_ptr->push_sparse_param( - table_id, sparse_push_keys.data(), (const float **)push_g_vec.data(), - sparse_push_keys.size(), closure); + auto status = _worker_ptr->PushSparseParam(table_id, sparse_push_keys.data(), + (const float **)push_g_vec.data(), + sparse_push_keys.size(), closure); status.wait(); return; } @@ -272,7 +272,7 @@ void Communicator::RpcSendSparse(const std::string &var_name, int table_id, platform::RecordEvent record_event("Communicator->RpcSendSparse", platform::TracerEventType::Communication, 1); - size_t request_call_num = _worker_ptr->get_server_nums(); + size_t request_call_num = 
_worker_ptr->GetServerNums(); std::vector sparse_push_keys; std::vector push_g_vec; @@ -313,7 +313,7 @@ void Communicator::RpcSendSparse(const std::string &var_name, int table_id, closure->set_promise_value(ret); --_async_call_num; }); - auto status = _worker_ptr->push_sparse_raw_gradient( + auto status = _worker_ptr->PushSparseRawGradient( table_id, sparse_push_keys.data(), (const float **)push_g_vec.data(), sparse_push_keys.size(), closure); status.wait(); @@ -340,7 +340,7 @@ void Communicator::RpcRecvSparse(const std::string &varname, int table_id, bool training = true; - auto status = _worker_ptr->pull_sparse_param( + auto status = _worker_ptr->PullSparseParam( (float **)push_g_vec.data(), table_id, // NOLINT sparse_push_keys.data(), sparse_push_keys.size(), training); status.wait(); @@ -376,11 +376,11 @@ void Communicator::RpcProfilerControl() { if (!do_server_profiler_ && platform::IsProfileEnabled()) { // send profiler start flag do_server_profiler_ = true; - auto start_status = _worker_ptr->start_profiler(); + auto start_status = _worker_ptr->StartProfiler(); start_status.wait(); } else if (do_server_profiler_ && !platform::IsProfileEnabled()) { // send profiler end flag - auto stop_status = _worker_ptr->stop_profiler(); + auto stop_status = _worker_ptr->StopProfiler(); stop_status.wait(); do_server_profiler_ = false; } @@ -396,7 +396,7 @@ void Communicator::SendGlobalStep(const CommContext &ctx, int batches, platform::TracerEventType::Communication, 1); auto &table_id = ctx.table_id; - size_t request_call_num = _worker_ptr->get_server_nums(); + size_t request_call_num = _worker_ptr->GetServerNums(); auto &var_name = STEP_COUNTER; auto *out_var = send_scope->Var(var_name); @@ -416,7 +416,7 @@ void Communicator::SendGlobalStep(const CommContext &ctx, int batches, } closure->set_promise_value(ret); }); - auto status = _worker_ptr->push_global_step(table_id, data, closure); + auto status = _worker_ptr->PushGlobalStep(table_id, data, closure); status.wait(); return; } @@ -605,8 +605,8 @@ void AsyncCommunicator::PullSparseToTensorSync( } } auto status = - _worker_ptr->pull_sparse(pull_result_ptr.data(), table_id, - fea_keys.data(), fea_keys.size(), is_training); + _worker_ptr->PullSparse(pull_result_ptr.data(), table_id, fea_keys.data(), + fea_keys.size(), is_training); status.wait(); auto ret = status.get(); if (ret != 0) { @@ -738,9 +738,9 @@ void AsyncCommunicator::PushSparseFromTensorAsync( this->Check(table_id), true, platform::errors::InvalidArgument( "can not find table: %s, please check your config", table_id)); - auto status = _worker_ptr->push_sparse(table_id, push_keys.data(), - (const float **)push_g_vec.data(), - push_keys.size()); + auto status = _worker_ptr->PushSparse(table_id, push_keys.data(), + (const float **)push_g_vec.data(), + push_keys.size()); } void HalfAsyncCommunicator::MainThread() { @@ -813,7 +813,7 @@ void AsyncCommunicator::Stop() { if (!communicator_) { VLOG(0) << "Communicator is not inited, do nothing"; } else { - // _worker_ptr->finalize_worker(); + // _worker_ptr->FinalizeWorker(); VLOG(1) << "client finalize_worker done"; if (recv_thread_) { VLOG(1) << "stop recv thread"; @@ -1327,7 +1327,7 @@ void GeoCommunicator::SendSparse(const std::string &varname, closure->set_promise_value(ret); --_async_call_num; }); - auto status = _worker_ptr->push_sparse_raw_gradient_partial( + auto status = _worker_ptr->PushSparseRawGradientPartial( table_id, (const uint64_t *)sparse_ids.data(), (const float **)push_g_vec.data(), sparse_ids.size(), closure, ep_idx); 
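Both BrpcPsClient and the Communicator size their per-shard dense buffers with the same DenseDimPerShard formula, dense_dim_total / shard_num + 1, then allocate num_per_shard * request_call_num floats so every shard receives an equal slice. A small self-contained check of that arithmetic; the table dimension and shard count below are hypothetical:

#include <cstdint>
#include <cstdio>

// Same formula as DenseDimPerShard in this patch: each shard gets an
// equal slice of the dense parameter, rounded up by one element.
inline std::uint32_t DenseDimPerShard(std::uint32_t dense_dim_total,
                                      std::uint32_t shard_num) {
  return dense_dim_total / shard_num + 1;
}

int main() {
  // Hypothetical table: 1000 floats split across 3 parameter servers.
  std::uint32_t per_shard = DenseDimPerShard(1000, 3);  // 333 + 1 = 334
  // The send buffer holds per_shard * shard_num floats (1002 here), so
  // the division never leaves a shard short; the tail is padding.
  std::printf("per_shard=%u buffer=%u\n", per_shard, per_shard * 3);
  return 0;
}

The +1 deliberately over-allocates a few slots rather than giving the last shard a shorter slice, which keeps every per-shard memcpy in the push/pull paths the same size.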
status.wait(); @@ -1345,7 +1345,7 @@ void GeoCommunicator::RecvSparse(const std::string &varname, int table_id, // 1. recv from pserver std::vector keys; std::vector values; - auto status = _worker_ptr->pull_geo_param(table_id, &values, &keys, ep_idx); + auto status = _worker_ptr->PullGeoParam(table_id, &values, &keys, ep_idx); status.wait(); std::string param = SplitedGradToParam(varname); diff --git a/paddle/fluid/distributed/ps/service/communicator/communicator.h b/paddle/fluid/distributed/ps/service/communicator/communicator.h index da4b46928d55c..75676c392435c 100644 --- a/paddle/fluid/distributed/ps/service/communicator/communicator.h +++ b/paddle/fluid/distributed/ps/service/communicator/communicator.h @@ -299,7 +299,7 @@ class Communicator { virtual void Barrier() {} virtual void BarrierWithTable(uint32_t barrier_type) { - auto rets = _worker_ptr->barrier(barrier_table_id_, barrier_type); + auto rets = _worker_ptr->Barrier(barrier_table_id_, barrier_type); rets.wait(); int status = rets.get(); PADDLE_ENFORCE_EQ(status, 0, @@ -310,7 +310,7 @@ class Communicator { virtual void CreateC2CConnection(int pserver_timeout_ms, int pserver_connect_timeout_ms, int max_retry) { - _worker_ptr->create_client2client_connection( + _worker_ptr->CreateClient2ClientConnection( pserver_timeout_ms, pserver_connect_timeout_ms, max_retry); } @@ -379,12 +379,12 @@ class Communicator { std::unordered_map envs; // 计算每个shard 对 dense的存储量 - inline uint32_t dense_dim_per_shard(uint32_t dense_dim_total, - uint32_t shard_num) { + inline uint32_t DenseDimPerShard(uint32_t dense_dim_total, + uint32_t shard_num) { return dense_dim_total / shard_num + 1; } - void init_gflag(const std::string &gflags); + void InitGFlag(const std::string &gflags); paddle::distributed::PSParameter _ps_param; paddle::distributed::PaddlePSEnvironment _ps_env; int servers_ = 0; diff --git a/paddle/fluid/distributed/ps/service/env.h b/paddle/fluid/distributed/ps/service/env.h index 0cc57229b7a82..162ee6f098422 100644 --- a/paddle/fluid/distributed/ps/service/env.h +++ b/paddle/fluid/distributed/ps/service/env.h @@ -40,7 +40,7 @@ struct PSHost { // |---ip---|---port---|--rank--| // |-32bit--|--20bit---|--12bit-| - uint64_t serialize_to_uint64() { + uint64_t SerializeToUint64() { uint64_t host_label = 0; host_label = inet_addr(ip.c_str()); host_label = host_label << 32; @@ -49,7 +49,7 @@ struct PSHost { return host_label; } - void parse_from_uint64(uint64_t host_label) { + void ParseFromUint64(uint64_t host_label) { static uint64_t rank_label_mask = (1L << 12) - 1; static uint64_t port_label_mask = (1L << 20) - 1; rank = host_label & rank_label_mask; @@ -58,17 +58,17 @@ struct PSHost { ip = inet_ntoa(*(in_addr *)&ip_addr); // NOLINT } - std::string to_string() { + std::string ToString() { std::stringstream s; s << "host: " << ip; s << " port: " << port; s << " rank: " << rank; - s << " uint: " << serialize_to_uint64(); + s << " uint: " << SerializeToUint64(); return s.str(); } // for open source parameter server - std::string serialize_to_string() { + std::string SerializeToString() { std::stringstream s; s << ip << ":"; s << port << ":"; @@ -76,16 +76,16 @@ struct PSHost { return s.str(); } - void parse_from_string(std::string endpoint) { + void ParseFromString(std::string endpoint) { std::vector endpoint_info; - string_split(endpoint, ':', &endpoint_info); + StringSplit(endpoint, ':', &endpoint_info); ip = endpoint_info[0]; port = std::stoi(endpoint_info[1]); rank = std::stoi(endpoint_info[2]); } - void string_split(const std::string &str, 
char sep, - std::vector *pieces, bool ignore_null = true) { + void StringSplit(const std::string &str, char sep, + std::vector *pieces, bool ignore_null = true) { pieces->clear(); if (str.empty()) { if (!ignore_null) { @@ -111,63 +111,60 @@ class PSEnvironment { explicit PSEnvironment() {} // NOLINT virtual ~PSEnvironment() {} - virtual int32_t set_ps_servers(uint64_t *host_sign_list, int node_num) { + virtual int32_t SetPsServers(uint64_t *host_sign_list, int node_num) { return 0; } - virtual int32_t set_ps_servers( + virtual int32_t SetPsServers( const std::vector *host_endpoint_list, int node_num) { return 0; } - virtual int32_t set_ps_clients(uint64_t *host_sign_list, int node_num) { + virtual int32_t SetPsClients(uint64_t *host_sign_list, int node_num) { return 0; } - virtual int32_t set_ps_clients(std::string *host_endpoint_list, - int node_num) { + virtual int32_t SetPsClients(std::string *host_endpoint_list, int node_num) { return 0; } - virtual uint64_t get_local_host_sign() { return 0; } - virtual std::vector get_ps_servers() const { return _ps_server_list; } - virtual int32_t registe_ps_server(const std::string &ip, uint32_t port, - int32_t rank) { - return registe_ps_host(ip, port, rank, _ps_server_list, - _ps_server_sign_set); + virtual uint64_t GetLocalHostSign() { return 0; } + virtual std::vector GetPsServers() const { return _ps_server_list; } + virtual int32_t RegistePsServer(const std::string &ip, uint32_t port, + int32_t rank) { + return RegistePsHost(ip, port, rank, _ps_server_list, _ps_server_sign_set); } - virtual std::vector get_ps_clients() const { return _ps_client_list; } - virtual int32_t registe_ps_client(const std::string &ip, uint32_t port, - int32_t rank) { - return registe_ps_host(ip, port, rank, _ps_client_list, - _ps_client_sign_set); + virtual std::vector GetPsClients() const { return _ps_client_list; } + virtual int32_t RegistePsClient(const std::string &ip, uint32_t port, + int32_t rank) { + return RegistePsHost(ip, port, rank, _ps_client_list, _ps_client_sign_set); } - virtual std::vector get_client_info() { + virtual std::vector GetClientInfo() { std::vector client_info; for (auto &i : _ps_client_list) { - client_info.push_back(i.serialize_to_uint64()); + client_info.push_back(i.SerializeToUint64()); } return client_info; } - virtual std::vector get_client_info(bool use_string_endpoint) { + virtual std::vector GetClientInfo(bool use_string_endpoint) { if (use_string_endpoint) { std::vector client_info; for (auto &i : _ps_client_list) { - client_info.push_back(i.serialize_to_string()); + client_info.push_back(i.SerializeToString()); } return client_info; } return {}; } - virtual void set_trainers(int trainers) { trainers_ = trainers; } + virtual void SetTrainers(int trainers) { trainers_ = trainers; } - virtual int get_trainers() { return trainers_; } + virtual int GetTrainers() { return trainers_; } protected: //注册一个host // NOLINT - virtual int32_t registe_ps_host( + virtual int32_t RegistePsHost( const std::string &ip, uint32_t port, int32_t rank, std::vector &host_list, // NOLINT std::unordered_set &sign_set) { // NOLINT @@ -198,15 +195,15 @@ class PaddlePSEnvironment : public PSEnvironment { explicit PaddlePSEnvironment() {} // NOLINT virtual ~PaddlePSEnvironment() {} - virtual int32_t set_ps_servers(uint64_t *host_sign_list, int node_num) { + virtual int32_t SetPsServers(uint64_t *host_sign_list, int node_num) { _ps_server_list.clear(); _ps_server_sign_set.clear(); for (int i = 0; i < node_num; ++i) { if (host_sign_list[i] > 0) { PSHost host; - 
host.parse_from_uint64(host_sign_list[i]); + host.ParseFromUint64(host_sign_list[i]); _ps_server_list.push_back(host); - _ps_server_sign_set.insert(host.serialize_to_uint64()); + _ps_server_sign_set.insert(host.SerializeToUint64()); } } std::sort( @@ -215,14 +212,14 @@ class PaddlePSEnvironment : public PSEnvironment { return 0; } - virtual int32_t set_ps_servers(const std::vector *host_sign_list, - int node_num) { + virtual int32_t SetPsServers(const std::vector *host_sign_list, + int node_num) { _ps_server_list.clear(); _ps_server_sign_set.clear(); for (int i = 0; i < node_num; ++i) { if (host_sign_list->at(i) != "") { PSHost host; - host.parse_from_string(host_sign_list->at(i)); + host.ParseFromString(host_sign_list->at(i)); _ps_server_list.push_back(host); _ps_server_sign_set.insert(host.rank); } @@ -233,15 +230,15 @@ class PaddlePSEnvironment : public PSEnvironment { return 0; } - virtual int32_t set_ps_clients(uint64_t *host_sign_list, int node_num) { + virtual int32_t SetPsClients(uint64_t *host_sign_list, int node_num) { _ps_client_list.clear(); _ps_client_sign_set.clear(); for (int i = 0; i < node_num; ++i) { if (host_sign_list[i] > 0) { PSHost host; - host.parse_from_uint64(host_sign_list[i]); + host.ParseFromUint64(host_sign_list[i]); _ps_client_list.push_back(host); - _ps_client_sign_set.insert(host.serialize_to_uint64()); + _ps_client_sign_set.insert(host.SerializeToUint64()); } } std::sort( @@ -250,14 +247,14 @@ class PaddlePSEnvironment : public PSEnvironment { return 0; } - virtual int32_t set_ps_clients(const std::vector *host_sign_list, - int node_num) { + virtual int32_t SetPsClients(const std::vector *host_sign_list, + int node_num) { _ps_client_list.clear(); _ps_client_sign_set.clear(); for (int i = 0; i < node_num; ++i) { if (host_sign_list->at(i) != "") { PSHost host; - host.parse_from_string(host_sign_list->at(i)); + host.ParseFromString(host_sign_list->at(i)); _ps_client_list.push_back(host); _ps_client_sign_set.insert(host.rank); } @@ -269,9 +266,9 @@ class PaddlePSEnvironment : public PSEnvironment { return 0; } - virtual uint64_t get_local_host_sign() { + virtual uint64_t GetLocalHostSign() { if (_ps_client_list.size() > 0) { - return _ps_client_list[0].serialize_to_uint64(); + return _ps_client_list[0].SerializeToUint64(); } else { return 0; } diff --git a/paddle/fluid/distributed/ps/service/graph_brpc_client.cc b/paddle/fluid/distributed/ps/service/graph_brpc_client.cc index a3db88e3b679d..827a643ee50d6 100644 --- a/paddle/fluid/distributed/ps/service/graph_brpc_client.cc +++ b/paddle/fluid/distributed/ps/service/graph_brpc_client.cc @@ -135,8 +135,7 @@ std::future GraphBrpcClient::get_node_feat( closure->request(request_idx) ->add_params(joint_feature_name.c_str(), joint_feature_name.size()); - GraphPsService_Stub rpc_stub = - getServiceStub(get_cmd_channel(server_index)); + GraphPsService_Stub rpc_stub = getServiceStub(GetCmdChannel(server_index)); closure->cntl(request_idx)->set_log_id(butil::gettimeofday_ms()); rpc_stub.service(closure->cntl(request_idx), closure->request(request_idx), closure->response(request_idx), closure); @@ -169,8 +168,7 @@ std::future GraphBrpcClient::clear_nodes(uint32_t table_id) { closure->request(server_index)->set_table_id(table_id); closure->request(server_index)->set_client_id(_client_id); - GraphPsService_Stub rpc_stub = - getServiceStub(get_cmd_channel(server_index)); + GraphPsService_Stub rpc_stub = getServiceStub(GetCmdChannel(server_index)); closure->cntl(server_index)->set_log_id(butil::gettimeofday_ms()); 
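Every GraphBrpcClient call in this file follows the same shape: allocate a closure that owns the per-request controllers and a promise, resolve the per-server channel via GetCmdChannel, fire the RPC, and hand the caller a std::future. A stripped-down model of that promise/closure plumbing, where do_rpc stands in for the generated rpc_stub.service(...) call:

#include <cstdint>
#include <functional>
#include <future>
#include <memory>

// done(status) plays the role of the closure: it records the RPC status
// and unblocks whoever waits on the returned future.
std::future<int32_t> AsyncCall(
    const std::function<void(std::function<void(int32_t)>)> &do_rpc) {
  auto prom = std::make_shared<std::promise<int32_t>>();
  std::future<int32_t> fut = prom->get_future();
  do_rpc([prom](int32_t status) { prom->set_value(status); });
  return fut;
}
// Usage: auto fut = AsyncCall(fire_rpc); /* ... */ int32_t rc = fut.get();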
rpc_stub.service(closure->cntl(server_index), closure->request(server_index), @@ -238,9 +236,8 @@ std::future GraphBrpcClient::add_graph_node( ->add_params((char *)weighted, sizeof(bool) * is_weighted_bucket[request_idx].size()); } - // PsService_Stub rpc_stub(get_cmd_channel(server_index)); - GraphPsService_Stub rpc_stub = - getServiceStub(get_cmd_channel(server_index)); + // PsService_Stub rpc_stub(GetCmdChannel(server_index)); + GraphPsService_Stub rpc_stub = getServiceStub(GetCmdChannel(server_index)); closure->cntl(request_idx)->set_log_id(butil::gettimeofday_ms()); rpc_stub.service(closure->cntl(request_idx), closure->request(request_idx), closure->response(request_idx), closure); @@ -292,9 +289,8 @@ std::future GraphBrpcClient::remove_graph_node( closure->request(request_idx) ->add_params((char *)request_bucket[request_idx].data(), sizeof(int64_t) * node_num); - // PsService_Stub rpc_stub(get_cmd_channel(server_index)); - GraphPsService_Stub rpc_stub = - getServiceStub(get_cmd_channel(server_index)); + // PsService_Stub rpc_stub(GetCmdChannel(server_index)); + GraphPsService_Stub rpc_stub = getServiceStub(GetCmdChannel(server_index)); closure->cntl(request_idx)->set_log_id(butil::gettimeofday_ms()); rpc_stub.service(closure->cntl(request_idx), closure->request(request_idx), closure->response(request_idx), closure); @@ -362,9 +358,8 @@ std::future GraphBrpcClient::batch_sample_neighbors( closure->request(0)->add_params((char *)&sample_size, sizeof(int)); closure->request(0)->add_params((char *)&need_weight, sizeof(bool)); ; - // PsService_Stub rpc_stub(get_cmd_channel(server_index)); - GraphPsService_Stub rpc_stub = - getServiceStub(get_cmd_channel(server_index)); + // PsService_Stub rpc_stub(GetCmdChannel(server_index)); + GraphPsService_Stub rpc_stub = getServiceStub(GetCmdChannel(server_index)); closure->cntl(0)->set_log_id(butil::gettimeofday_ms()); rpc_stub.service(closure->cntl(0), closure->request(0), closure->response(0), closure); @@ -464,9 +459,8 @@ std::future GraphBrpcClient::batch_sample_neighbors( ->add_params((char *)&sample_size, sizeof(int)); closure->request(request_idx) ->add_params((char *)&need_weight, sizeof(bool)); - // PsService_Stub rpc_stub(get_cmd_channel(server_index)); - GraphPsService_Stub rpc_stub = - getServiceStub(get_cmd_channel(server_index)); + // PsService_Stub rpc_stub(GetCmdChannel(server_index)); + GraphPsService_Stub rpc_stub = getServiceStub(GetCmdChannel(server_index)); closure->cntl(request_idx)->set_log_id(butil::gettimeofday_ms()); rpc_stub.service(closure->cntl(request_idx), closure->request(request_idx), closure->response(request_idx), closure); @@ -506,8 +500,8 @@ std::future GraphBrpcClient::random_sample_nodes( closure->request(0)->set_client_id(_client_id); closure->request(0)->add_params((char *)&sample_size, sizeof(int)); ; - // PsService_Stub rpc_stub(get_cmd_channel(server_index)); - GraphPsService_Stub rpc_stub = getServiceStub(get_cmd_channel(server_index)); + // PsService_Stub rpc_stub(GetCmdChannel(server_index)); + GraphPsService_Stub rpc_stub = getServiceStub(GetCmdChannel(server_index)); closure->cntl(0)->set_log_id(butil::gettimeofday_ms()); rpc_stub.service(closure->cntl(0), closure->request(0), closure->response(0), closure); @@ -541,8 +535,7 @@ std::future GraphBrpcClient::load_graph_split_config( closure->request(server_index)->set_table_id(table_id); closure->request(server_index)->set_client_id(_client_id); closure->request(server_index)->add_params(path); - GraphPsService_Stub rpc_stub = - 
getServiceStub(get_cmd_channel(server_index)); + GraphPsService_Stub rpc_stub = getServiceStub(GetCmdChannel(server_index)); closure->cntl(server_index)->set_log_id(butil::gettimeofday_ms()); rpc_stub.service(closure->cntl(server_index), closure->request(server_index), @@ -581,8 +574,7 @@ std::future GraphBrpcClient::use_neighbors_sample_cache( closure->request(server_index) ->add_params((char *)&size_limit, sizeof(size_t)); closure->request(server_index)->add_params((char *)&ttl, sizeof(size_t)); - GraphPsService_Stub rpc_stub = - getServiceStub(get_cmd_channel(server_index)); + GraphPsService_Stub rpc_stub = getServiceStub(GetCmdChannel(server_index)); closure->cntl(server_index)->set_log_id(butil::gettimeofday_ms()); rpc_stub.service(closure->cntl(server_index), closure->request(server_index), @@ -624,8 +616,8 @@ std::future GraphBrpcClient::pull_graph_list( closure->request(0)->add_params((char *)&start, sizeof(int)); closure->request(0)->add_params((char *)&size, sizeof(int)); closure->request(0)->add_params((char *)&step, sizeof(int)); - // PsService_Stub rpc_stub(get_cmd_channel(server_index)); - GraphPsService_Stub rpc_stub = getServiceStub(get_cmd_channel(server_index)); + // PsService_Stub rpc_stub(GetCmdChannel(server_index)); + GraphPsService_Stub rpc_stub = getServiceStub(GetCmdChannel(server_index)); closure->cntl(0)->set_log_id(butil::gettimeofday_ms()); rpc_stub.service(closure->cntl(0), closure->request(0), closure->response(0), closure); @@ -717,8 +709,7 @@ std::future GraphBrpcClient::set_node_feat( closure->request(request_idx) ->add_params(set_feature.c_str(), set_feature.size()); - GraphPsService_Stub rpc_stub = - getServiceStub(get_cmd_channel(server_index)); + GraphPsService_Stub rpc_stub = getServiceStub(GetCmdChannel(server_index)); closure->cntl(request_idx)->set_log_id(butil::gettimeofday_ms()); rpc_stub.service(closure->cntl(request_idx), closure->request(request_idx), closure->response(request_idx), closure); @@ -727,10 +718,10 @@ std::future GraphBrpcClient::set_node_feat( return fut; } -int32_t GraphBrpcClient::initialize() { +int32_t GraphBrpcClient::Initialize() { // set_shard_num(_config.shard_num()); - BrpcPsClient::initialize(); - server_size = get_server_nums(); + BrpcPsClient::Initialize(); + server_size = GetServerNums(); graph_service = NULL; local_channel = NULL; return 0; diff --git a/paddle/fluid/distributed/ps/service/graph_brpc_client.h b/paddle/fluid/distributed/ps/service/graph_brpc_client.h index e2b8a518615dc..d1d3c95260df4 100644 --- a/paddle/fluid/distributed/ps/service/graph_brpc_client.h +++ b/paddle/fluid/distributed/ps/service/graph_brpc_client.h @@ -97,12 +97,12 @@ class GraphBrpcClient : public BrpcPsClient { std::string path); virtual std::future remove_graph_node( uint32_t table_id, std::vector& node_id_list); - virtual int32_t initialize(); + virtual int32_t Initialize(); int get_shard_num() { return shard_num; } void set_shard_num(int shard_num) { this->shard_num = shard_num; } int get_server_index_by_id(int64_t id); void set_local_channel(int index) { - this->local_channel = get_cmd_channel(index); + this->local_channel = GetCmdChannel(index); } void set_local_graph_service(GraphBrpcService* graph_service) { this->graph_service = graph_service; diff --git a/paddle/fluid/distributed/ps/service/graph_brpc_server.cc b/paddle/fluid/distributed/ps/service/graph_brpc_server.cc index 20a55e4d11983..21e590997b178 100644 --- a/paddle/fluid/distributed/ps/service/graph_brpc_server.cc +++ 
b/paddle/fluid/distributed/ps/service/graph_brpc_server.cc @@ -33,7 +33,7 @@ namespace distributed { return -1; \ } -int32_t GraphBrpcServer::initialize() { +int32_t GraphBrpcServer::Initialize() { auto &service_config = _config.downpour_server_param().service_param(); if (!service_config.has_service_class()) { LOG(ERROR) << "miss service_class in ServerServiceParameter"; @@ -48,7 +48,7 @@ int32_t GraphBrpcServer::initialize() { } _service.reset(service); - if (service->configure(this) != 0 || service->initialize() != 0) { + if (service->Configure(this) != 0 || service->Initialize() != 0) { LOG(ERROR) << "service initialize failed, service_name:" << service_config.service_class(); return -1; @@ -61,11 +61,11 @@ int32_t GraphBrpcServer::initialize() { return 0; } -brpc::Channel *GraphBrpcServer::get_cmd_channel(size_t server_index) { +brpc::Channel *GraphBrpcServer::GetCmdChannel(size_t server_index) { return _pserver_channels[server_index].get(); } -uint64_t GraphBrpcServer::start(const std::string &ip, uint32_t port) { +uint64_t GraphBrpcServer::Start(const std::string &ip, uint32_t port) { std::unique_lock lock(mutex_); std::string ip_port = ip + ":" + std::to_string(port); @@ -73,20 +73,20 @@ uint64_t GraphBrpcServer::start(const std::string &ip, uint32_t port) { brpc::ServerOptions options; int num_threads = std::thread::hardware_concurrency(); - auto trainers = _environment->get_trainers(); + auto trainers = _environment->GetTrainers(); options.num_threads = trainers > num_threads ? trainers : num_threads; if (_server.Start(ip_port.c_str(), &options) != 0) { LOG(ERROR) << "GraphBrpcServer start failed, ip_port=" << ip_port; return 0; } - _environment->registe_ps_server(ip, port, _rank); + _environment->RegistePsServer(ip, port, _rank); return 0; } int32_t GraphBrpcServer::build_peer2peer_connection(int rank) { this->rank = rank; - auto _env = environment(); + auto _env = Environment(); brpc::ChannelOptions options; options.protocol = "baidu_std"; options.timeout_ms = 500000; @@ -94,7 +94,7 @@ int32_t GraphBrpcServer::build_peer2peer_connection(int rank) { options.connect_timeout_ms = 10000; options.max_retry = 3; - std::vector server_list = _env->get_ps_servers(); + std::vector server_list = _env->GetPsServers(); _pserver_channels.resize(server_list.size()); std::ostringstream os; std::string server_ip_port; @@ -172,19 +172,18 @@ int32_t GraphBrpcService::remove_graph_node(Table *table, ((GraphTable *)table)->remove_graph_node(node_ids); return 0; } -int32_t GraphBrpcServer::port() { return _server.listen_address().port; } +int32_t GraphBrpcServer::Port() { return _server.listen_address().port; } -int32_t GraphBrpcService::initialize() { +int32_t GraphBrpcService::Initialize() { _is_initialize_shard_info = false; - _service_handler_map[PS_STOP_SERVER] = &GraphBrpcService::stop_server; - _service_handler_map[PS_LOAD_ONE_TABLE] = &GraphBrpcService::load_one_table; - _service_handler_map[PS_LOAD_ALL_TABLE] = &GraphBrpcService::load_all_table; + _service_handler_map[PS_STOP_SERVER] = &GraphBrpcService::StopServer; + _service_handler_map[PS_LOAD_ONE_TABLE] = &GraphBrpcService::LoadOneTable; + _service_handler_map[PS_LOAD_ALL_TABLE] = &GraphBrpcService::LoadAllTable; - _service_handler_map[PS_PRINT_TABLE_STAT] = - &GraphBrpcService::print_table_stat; - _service_handler_map[PS_BARRIER] = &GraphBrpcService::barrier; - _service_handler_map[PS_START_PROFILER] = &GraphBrpcService::start_profiler; - _service_handler_map[PS_STOP_PROFILER] = &GraphBrpcService::stop_profiler; + 
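GraphBrpcService::Initialize in this hunk fills _service_handler_map with pointer-to-member handlers keyed by command id, and service() later dispatches through that map. A minimal model of the scheme, with an illustrative command id:

#include <cstdint>
#include <unordered_map>

class MiniService {
 public:
  using HandlerFunc = int32_t (MiniService::*)();
  MiniService() {
    // Mirrors _service_handler_map[PS_STOP_SERVER] = &GraphBrpcService::StopServer;
    handlers_[kStopServer] = &MiniService::StopServer;
  }
  int32_t Dispatch(int cmd_id) {
    auto itr = handlers_.find(cmd_id);
    if (itr == handlers_.end()) return -1;  // unknown cmd_id -> error response
    return (this->*(itr->second))();        // pointer-to-member invocation
  }

 private:
  static constexpr int kStopServer = 12;  // illustrative, not the real enum value
  int32_t StopServer() { return 0; }
  std::unordered_map<int, HandlerFunc> handlers_;
};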
_service_handler_map[PS_PRINT_TABLE_STAT] = &GraphBrpcService::PrintTableStat; + _service_handler_map[PS_BARRIER] = &GraphBrpcService::Barrier; + _service_handler_map[PS_START_PROFILER] = &GraphBrpcService::StartProfiler; + _service_handler_map[PS_STOP_PROFILER] = &GraphBrpcService::StopProfiler; _service_handler_map[PS_PULL_GRAPH_LIST] = &GraphBrpcService::pull_graph_list; _service_handler_map[PS_GRAPH_SAMPLE_NEIGHBORS] = @@ -207,21 +206,21 @@ int32_t GraphBrpcService::initialize() { _service_handler_map[PS_GRAPH_LOAD_GRAPH_SPLIT_CONFIG] = &GraphBrpcService::load_graph_split_config; // shard初始化,server启动后才可从env获取到server_list的shard信息 - initialize_shard_info(); + InitializeShardInfo(); return 0; } -int32_t GraphBrpcService::initialize_shard_info() { +int32_t GraphBrpcService::InitializeShardInfo() { if (!_is_initialize_shard_info) { std::lock_guard guard(_initialize_shard_mutex); if (_is_initialize_shard_info) { return 0; } - server_size = _server->environment()->get_ps_servers().size(); - auto &table_map = *(_server->table()); + server_size = _server->Environment()->GetPsServers().size(); + auto &table_map = *(_server->GetTable()); for (auto itr : table_map) { - itr.second->set_shard(_rank, server_size); + itr.second->SetShard(_rank, server_size); } _is_initialize_shard_info = true; } @@ -241,7 +240,7 @@ void GraphBrpcService::service(google::protobuf::RpcController *cntl_base, response->set_err_code(0); response->set_err_msg(""); - auto *table = _server->table(request->table_id()); + auto *table = _server->GetTable(request->table_id()); brpc::Controller *cntl = static_cast(cntl_base); auto itr = _service_handler_map.find(request->cmd_id()); if (itr == _service_handler_map.end()) { @@ -261,7 +260,7 @@ void GraphBrpcService::service(google::protobuf::RpcController *cntl_base, } } -int32_t GraphBrpcService::barrier(Table *table, const PsRequestMessage &request, +int32_t GraphBrpcService::Barrier(Table *table, const PsRequestMessage &request, PsResponseMessage &response, brpc::Controller *cntl) { CHECK_TABLE_EXIST(table, request, response) @@ -275,16 +274,16 @@ int32_t GraphBrpcService::barrier(Table *table, const PsRequestMessage &request, auto trainer_id = request.client_id(); auto barrier_type = request.params(0); - table->barrier(trainer_id, barrier_type); + table->Barrier(trainer_id, barrier_type); return 0; } -int32_t GraphBrpcService::print_table_stat(Table *table, - const PsRequestMessage &request, - PsResponseMessage &response, - brpc::Controller *cntl) { +int32_t GraphBrpcService::PrintTableStat(Table *table, + const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { CHECK_TABLE_EXIST(table, request, response) - std::pair ret = table->print_table_stat(); + std::pair ret = table->PrintTableStat(); paddle::framework::BinaryArchive ar; ar << ret.first << ret.second; std::string table_info(ar.Buffer(), ar.Length()); @@ -293,10 +292,10 @@ int32_t GraphBrpcService::print_table_stat(Table *table, return 0; } -int32_t GraphBrpcService::load_one_table(Table *table, - const PsRequestMessage &request, - PsResponseMessage &response, - brpc::Controller *cntl) { +int32_t GraphBrpcService::LoadOneTable(Table *table, + const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { CHECK_TABLE_EXIST(table, request, response) if (request.params_size() < 2) { set_response_code( @@ -304,20 +303,20 @@ int32_t GraphBrpcService::load_one_table(Table *table, "PsRequestMessage.datas is requeired at least 2 for path & load_param"); return -1; } - 
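LoadOneTable validates the request's parameter count above before forwarding to Table::Load just below: params(0) is the path and params(1) the load_param mode string, with failures surfaced through set_response_code. A hypothetical mirror of that control flow using stand-in types:

#include <cstdint>
#include <string>
#include <vector>

struct MiniTable {
  int32_t Load(const std::string &path, const std::string &load_param) {
    return 0;
  }
};

int32_t LoadOneTableSketch(MiniTable *table,
                           const std::vector<std::string> &params,
                           std::string *err_msg) {
  if (params.size() < 2) {  // need path & load_param, as the check above requires
    *err_msg = "at least 2 params required: path & load_param";
    return -1;
  }
  if (table->Load(params[0], params[1]) != 0) {
    *err_msg = "table load failed";
    return -1;
  }
  return 0;
}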
if (table->load(request.params(0), request.params(1)) != 0) { + if (table->Load(request.params(0), request.params(1)) != 0) { set_response_code(response, -1, "table load failed"); return -1; } return 0; } -int32_t GraphBrpcService::load_all_table(Table *table, - const PsRequestMessage &request, - PsResponseMessage &response, - brpc::Controller *cntl) { - auto &table_map = *(_server->table()); +int32_t GraphBrpcService::LoadAllTable(Table *table, + const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { + auto &table_map = *(_server->GetTable()); for (auto &itr : table_map) { - if (load_one_table(itr.second.get(), request, response, cntl) != 0) { + if (LoadOneTable(itr.second.get(), request, response, cntl) != 0) { LOG(ERROR) << "load table[" << itr.first << "] failed"; return -1; } @@ -325,13 +324,13 @@ int32_t GraphBrpcService::load_all_table(Table *table, return 0; } -int32_t GraphBrpcService::stop_server(Table *table, - const PsRequestMessage &request, - PsResponseMessage &response, - brpc::Controller *cntl) { +int32_t GraphBrpcService::StopServer(Table *table, + const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { GraphBrpcServer *p_server = (GraphBrpcServer *)_server; std::thread t_stop([p_server]() { - p_server->stop(); + p_server->Stop(); LOG(INFO) << "Server Stoped"; }); p_server->export_cv()->notify_all(); @@ -339,19 +338,19 @@ int32_t GraphBrpcService::stop_server(Table *table, return 0; } -int32_t GraphBrpcService::stop_profiler(Table *table, - const PsRequestMessage &request, - PsResponseMessage &response, - brpc::Controller *cntl) { +int32_t GraphBrpcService::StopProfiler(Table *table, + const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { platform::DisableProfiler(platform::EventSortingKey::kDefault, string::Sprintf("server_%s_profile", _rank)); return 0; } -int32_t GraphBrpcService::start_profiler(Table *table, - const PsRequestMessage &request, - PsResponseMessage &response, - brpc::Controller *cntl) { +int32_t GraphBrpcService::StartProfiler(Table *table, + const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { platform::EnableProfiler(platform::ProfilerState::kCPU); return 0; } @@ -475,7 +474,7 @@ int32_t GraphBrpcService::sample_neighbors_across_multi_servers( std::vector server2request(server_size, -1); std::vector local_id; std::vector local_query_idx; - size_t rank = get_rank(); + size_t rank = GetRank(); for (int query_idx = 0; query_idx < node_num; ++query_idx) { int server_index = ((GraphTable *)table)->get_server_index_by_id(node_data[query_idx]); @@ -589,9 +588,9 @@ int32_t GraphBrpcService::sample_neighbors_across_multi_servers( closure->request(request_idx) ->add_params((char *)&need_weight, sizeof(bool)); PsService_Stub rpc_stub( - ((GraphBrpcServer *)get_server())->get_cmd_channel(server_index)); + ((GraphBrpcServer *)GetServer())->GetCmdChannel(server_index)); // GraphPsService_Stub rpc_stub = - // getServiceStub(get_cmd_channel(server_index)); + // getServiceStub(GetCmdChannel(server_index)); closure->cntl(request_idx)->set_log_id(butil::gettimeofday_ms()); rpc_stub.service(closure->cntl(request_idx), closure->request(request_idx), closure->response(request_idx), closure); diff --git a/paddle/fluid/distributed/ps/service/graph_brpc_server.h b/paddle/fluid/distributed/ps/service/graph_brpc_server.h index a978d97b296b0..caf728701b289 100644 --- a/paddle/fluid/distributed/ps/service/graph_brpc_server.h +++ 
b/paddle/fluid/distributed/ps/service/graph_brpc_server.h @@ -31,10 +31,10 @@ class GraphBrpcServer : public PSServer { GraphBrpcServer() {} virtual ~GraphBrpcServer() {} PsBaseService *get_service() { return _service.get(); } - virtual uint64_t start(const std::string &ip, uint32_t port); + virtual uint64_t Start(const std::string &ip, uint32_t port); virtual int32_t build_peer2peer_connection(int rank); - virtual brpc::Channel *get_cmd_channel(size_t server_index); - virtual int32_t stop() { + virtual brpc::Channel *GetCmdChannel(size_t server_index); + virtual int32_t Stop() { std::unique_lock lock(mutex_); if (stoped_) return 0; stoped_ = true; @@ -43,12 +43,12 @@ class GraphBrpcServer : public PSServer { _server.Join(); return 0; } - int32_t port(); + int32_t Port(); std::condition_variable *export_cv() { return &cv_; } private: - virtual int32_t initialize(); + virtual int32_t Initialize(); mutable std::mutex mutex_; std::condition_variable cv_; bool stoped_ = false; @@ -66,7 +66,7 @@ typedef int32_t (GraphBrpcService::*serviceFunc)( class GraphBrpcService : public PsBaseService { public: - virtual int32_t initialize() override; + virtual int32_t Initialize() override; virtual void service(::google::protobuf::RpcController *controller, const PsRequestMessage *request, @@ -75,7 +75,7 @@ class GraphBrpcService : public PsBaseService { protected: std::unordered_map _service_handler_map; - int32_t initialize_shard_info(); + int32_t InitializeShardInfo(); int32_t pull_graph_list(Table *table, const PsRequestMessage &request, PsResponseMessage &response, brpc::Controller *cntl); int32_t graph_random_sample_neighbors(Table *table, @@ -100,21 +100,21 @@ class GraphBrpcService : public PsBaseService { int32_t remove_graph_node(Table *table, const PsRequestMessage &request, PsResponseMessage &response, brpc::Controller *cntl); - int32_t barrier(Table *table, const PsRequestMessage &request, + int32_t Barrier(Table *table, const PsRequestMessage &request, PsResponseMessage &response, brpc::Controller *cntl); - int32_t load_one_table(Table *table, const PsRequestMessage &request, - PsResponseMessage &response, brpc::Controller *cntl); - int32_t load_all_table(Table *table, const PsRequestMessage &request, - PsResponseMessage &response, brpc::Controller *cntl); - int32_t stop_server(Table *table, const PsRequestMessage &request, - PsResponseMessage &response, brpc::Controller *cntl); - int32_t start_profiler(Table *table, const PsRequestMessage &request, - PsResponseMessage &response, brpc::Controller *cntl); - int32_t stop_profiler(Table *table, const PsRequestMessage &request, + int32_t LoadOneTable(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, brpc::Controller *cntl); + int32_t LoadAllTable(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, brpc::Controller *cntl); + int32_t StopServer(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, brpc::Controller *cntl); + int32_t StartProfiler(Table *table, const PsRequestMessage &request, PsResponseMessage &response, brpc::Controller *cntl); + int32_t StopProfiler(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, brpc::Controller *cntl); - int32_t print_table_stat(Table *table, const PsRequestMessage &request, - PsResponseMessage &response, brpc::Controller *cntl); + int32_t PrintTableStat(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, brpc::Controller *cntl); int32_t 
sample_neighbors_across_multi_servers(Table *table, const PsRequestMessage &request, diff --git a/paddle/fluid/distributed/ps/service/ps_client.cc b/paddle/fluid/distributed/ps/service/ps_client.cc index 27f2d88fdd9fa..f7df99ec13cdf 100644 --- a/paddle/fluid/distributed/ps/service/ps_client.cc +++ b/paddle/fluid/distributed/ps/service/ps_client.cc @@ -25,7 +25,7 @@ REGISTER_PSCORE_CLASS(PSClient, BrpcPsClient); REGISTER_PSCORE_CLASS(PSClient, PsLocalClient); REGISTER_PSCORE_CLASS(PSClient, GraphBrpcClient); -int32_t PSClient::configure( +int32_t PSClient::Configure( const PSParameter &config, const std::map> ®ions, PSEnvironment &env, size_t client_id) { @@ -51,10 +51,10 @@ int32_t PSClient::configure( _table_accessors[work_param.downpour_table_param(i).table_id()].reset( accessor); } - return initialize(); + return Initialize(); } -PSClient *PSClientFactory::create(const PSParameter &ps_config) { +PSClient *PSClientFactory::Create(const PSParameter &ps_config) { const auto &config = ps_config.server_param(); if (!config.has_downpour_server_param()) { LOG(ERROR) << "miss downpour_server_param in ServerParameter"; @@ -81,7 +81,7 @@ PSClient *PSClientFactory::create(const PSParameter &ps_config) { return NULL; } - TableManager::instance().initialize(); + TableManager::Instance().Initialize(); VLOG(3) << "Create PSClient[" << service_param.client_class() << "] success"; return client; } diff --git a/paddle/fluid/distributed/ps/service/ps_client.h b/paddle/fluid/distributed/ps/service/ps_client.h index 83d2aba1db445..6f27b0eb04624 100644 --- a/paddle/fluid/distributed/ps/service/ps_client.h +++ b/paddle/fluid/distributed/ps/service/ps_client.h @@ -26,7 +26,6 @@ #include "paddle/fluid/distributed/ps/service/sendrecv.pb.h" #include "paddle/fluid/distributed/ps/table/accessor.h" #include "paddle/fluid/distributed/ps/table/graph/graph_node.h" -#include "paddle/fluid/distributed/ps/table/table.h" #include "paddle/fluid/platform/timer.h" namespace paddle { @@ -60,41 +59,6 @@ class PSClientClosure : public google::protobuf::Closure { std::vector>> _promises; }; -struct LoadSaveContext { - int table_id; - std::string epoch; - std::string mode; -}; - -enum TrainingMode { Async = 0, Sync = 1, Geo = 3 }; - -enum TrainingPhase { Init = 0, Train = 1, Save = 2 }; - -// enum ValueType { -// Sparse = 0, -// Dense = 1 -// }; - -struct PushContext { - const uint64_t *keys; - const float **push_values; - const Region *push_dense_values; -}; - -struct RequestContext { - int table; - TrainingMode training_mode; // 1 for async, 2 for geo, 3 for sync - TrainingPhase training_phase; // 1 for init, 2 for train - ValueType value_type; // 1 for sparse, 2 for dense - uint64_t *keys; - float **sparse_values; // for sparse values - Region *dense_values; // for dense values - PushContext push_context; - size_t num; - bool is_training; - void *callback; -}; - class PSClient { public: PSClient() {} @@ -102,41 +66,37 @@ class PSClient { PSClient(PSClient &&) = delete; PSClient(const PSClient &) = delete; - virtual int32_t configure( // NOLINT + virtual int32_t Configure( // NOLINT const PSParameter &config, const std::map> ®ions, PSEnvironment &_env, size_t client_id) final; // NOLINT - virtual int32_t create_client2client_connection( - int pserver_timeout_ms, int pserver_connect_timeout_ms, - int max_retry) = 0; + virtual int32_t CreateClient2ClientConnection(int pserver_timeout_ms, + int pserver_connect_timeout_ms, + int max_retry) = 0; // 触发table数据退场 - virtual std::future shrink(uint32_t table_id, + virtual std::future 
Shrink(uint32_t table_id, const std::string threshold) = 0; // 全量table进行数据load - virtual std::future load(const std::string &epoch, + virtual std::future Load(const std::string &epoch, const std::string &mode) = 0; // 指定table数据load - virtual std::future load(uint32_t table_id, const std::string &epoch, + virtual std::future Load(uint32_t table_id, const std::string &epoch, const std::string &mode) = 0; - // context配置load选项 - virtual std::future Load(const LoadSaveContext &load_context) = 0; // 全量table数据save value_accessor根据mode,可能有不同的save条件 - virtual std::future save(const std::string &epoch, + virtual std::future Save(const std::string &epoch, const std::string &mode) = 0; // 指定table数据save value_accessor根据mode,可能有不同的save条件 - virtual std::future save(uint32_t table_id, const std::string &epoch, + virtual std::future Save(uint32_t table_id, const std::string &epoch, const std::string &mode) = 0; - virtual std::future Save(const LoadSaveContext &save_context) = 0; - // 清空table数据 - virtual std::future clear() = 0; - virtual std::future clear(uint32_t table_id) = 0; + virtual std::future Clear() = 0; + virtual std::future Clear(uint32_t table_id) = 0; // pull dense的参数部分,并分块填充到本地网络参数中 // start和num用于拉取部分参数 @@ -145,23 +105,19 @@ class PSClient { // sender聚集同一区块的请求,累计多个填充buffer // server将参数区块中配置的某一维提取返回 // 返回数据解包后填充到累计的多个buffer中 - virtual std::future pull_dense(Region *regions, size_t region_num, - size_t table_id) = 0; // 保留 - - virtual std::future Push(RequestContext &push_context) = 0; + virtual std::future PullDense(Region *regions, size_t region_num, + size_t table_id) = 0; // 保留 // firstly push dense param for parameter server // this is neccessary because dense weight initialized in trainer on cold // start - virtual std::future push_dense_param(const Region *regions, - size_t region_num, - size_t table_id) = 0; - - virtual std::future push_dense(const Region *regions, - size_t region_num, - size_t table_id) = 0; + virtual std::future PushDenseParam(const Region *regions, + size_t region_num, + size_t table_id) = 0; - virtual std::future Pull(RequestContext &pull_context) = 0; + virtual std::future PushDense(const Region *regions, + size_t region_num, + size_t table_id) = 0; // 使用keys进行pull请求,结果填充values // keys和values的个数均为num个,每个value占用select_size空间 @@ -169,15 +125,14 @@ class PSClient { // 整合多个线程请求的keys,聚集并分散发送到server // 返回结果后,遍历buffer并对values赋值 // is_training 用于区分请求是训练/预测,server端对于特征和准入会有不同的处理. 
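The Chinese comment block above states the PullSparse contract, roughly: keys and values both have num entries, each value occupying select_size bytes; keys from multiple threads are merged, scattered to the servers, and the returned buffers are traversed to fill values; is_training distinguishes training from inference, since the server handles feature admission differently for each. An illustrative trainer-side call against that interface (MiniPsClient and dim are stand-ins for the real PSClient and the accessor's select width):

#include <cstdint>
#include <future>
#include <vector>

// Minimal stand-in for the PSClient::PullSparse signature declared below.
struct MiniPsClient {
  std::future<int32_t> PullSparse(float **select_values, size_t table_id,
                                  const uint64_t *keys, size_t num,
                                  bool is_training);
};

void PullEmbeddings(MiniPsClient *client, size_t table_id,
                    const std::vector<uint64_t> &keys, size_t dim) {
  // One dim-wide buffer per key; PullSparse fills them when the RPC lands.
  std::vector<std::vector<float>> buffers(keys.size(),
                                          std::vector<float>(dim));
  std::vector<float *> select_values;
  select_values.reserve(keys.size());
  for (auto &buf : buffers) select_values.push_back(buf.data());
  // is_training=true: the server may apply admission rules that differ
  // between training and inference, per the comment above.
  client->PullSparse(select_values.data(), table_id, keys.data(), keys.size(),
                     /*is_training=*/true)
      .get();  // block until every shard's response is scattered back
}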
- virtual std::future pull_sparse(float **select_values, - size_t table_id, - const uint64_t *keys, size_t num, - bool is_training) = 0; - - virtual std::future pull_sparse_param(float **select_values, - size_t table_id, - const uint64_t *keys, - size_t num, bool is_training) { + virtual std::future PullSparse(float **select_values, + size_t table_id, const uint64_t *keys, + size_t num, bool is_training) = 0; + + virtual std::future PullSparseParam(float **select_values, + size_t table_id, + const uint64_t *keys, size_t num, + bool is_training) { VLOG(0) << "Did not implement"; std::promise promise; std::future fut = promise.get_future(); @@ -185,10 +140,10 @@ class PSClient { return fut; } - virtual ::std::future pull_sparse_ptr(char **select_values, - size_t table_id, - const uint64_t *keys, - size_t num) { + virtual ::std::future PullSparsePtr(char **select_values, + size_t table_id, + const uint64_t *keys, + size_t num) { VLOG(0) << "Did not implement"; std::promise promise; std::future fut = promise.get_future(); @@ -196,38 +151,38 @@ class PSClient { return fut; } - virtual std::future print_table_stat(uint32_t table_id) = 0; + virtual std::future PrintTableStat(uint32_t table_id) = 0; // 确保所有积攒中的请求都发起发送 - virtual std::future flush() = 0; + virtual std::future Flush() = 0; // server优雅退出 - virtual std::future stop_server() = 0; + virtual std::future StopServer() = 0; // server profilera - virtual std::future start_profiler() = 0; - virtual std::future stop_profiler() = 0; + virtual std::future StartProfiler() = 0; + virtual std::future StopProfiler() = 0; - virtual std::future barrier(size_t table_id, + virtual std::future Barrier(size_t table_id, uint32_t barrier_type) = 0; - virtual std::future pull_geo_param(size_t table_id, - std::vector *values, - std::vector *keys, - int pserver_idx) = 0; + virtual std::future PullGeoParam(size_t table_id, + std::vector *values, + std::vector *keys, + int pserver_idx) = 0; - virtual std::future push_global_step(int table_id, - int64_t *total_send_data, - void *done) = 0; + virtual std::future PushGlobalStep(int table_id, + int64_t *total_send_data, + void *done) = 0; // recv table from server and save it in LodTensor - virtual int32_t recv_and_save_table(const uint64_t table_id, - const std::string &path) = 0; + virtual int32_t RecvAndSaveTable(const uint64_t table_id, + const std::string &path) = 0; - virtual void finalize_worker() = 0; + virtual void FinalizeWorker() = 0; // client to client, 消息发送 - virtual std::future send_client2client_msg(int msg_type, - int to_client_id, - const std::string &msg) { + virtual std::future SendClient2ClientMsg(int msg_type, + int to_client_id, + const std::string &msg) { VLOG(0) << "Did not implement"; std::promise promise; std::future fut = promise.get_future(); @@ -238,13 +193,13 @@ class PSClient { // client2client消息处理,std::function ret (msg_type, from_client_id, msg) typedef std::function MsgHandlerFunc; - virtual int registe_client2client_msg_handler(int msg_type, - MsgHandlerFunc handler) { + virtual int RegisteClient2ClientMsgHandler(int msg_type, + MsgHandlerFunc handler) { _msg_handler_map[msg_type] = handler; return 0; } - virtual int handle_client2client_msg(int msg_type, int from_client_id, - const std::string &msg) { + virtual int HandleClient2ClientMsg(int msg_type, int from_client_id, + const std::string &msg) { auto itr = _msg_handler_map.find(msg_type); if (itr == _msg_handler_map.end()) { LOG(WARNING) << "unknown client2client_msg type:" << msg_type; @@ -253,7 +208,7 @@ class PSClient { 
return itr->second(msg_type, from_client_id, msg); } - virtual ValueAccessor *table_accessor(size_t table_id) { + virtual ValueAccessor *GetTableAccessor(size_t table_id) { auto itr = _table_accessors.find(table_id); if (itr == _table_accessors.end()) { return NULL; @@ -261,31 +216,31 @@ class PSClient { return itr->second.get(); } - virtual size_t get_server_nums() = 0; + virtual size_t GetServerNums() = 0; - virtual std::future push_dense_raw_gradient( - int table_id, float *total_send_data, size_t total_send_data_size, - void *done) = 0; + virtual std::future PushDenseRawGradient(int table_id, + float *total_send_data, + size_t total_send_data_size, + void *done) = 0; - virtual std::future push_sparse_raw_gradient( + virtual std::future PushSparseRawGradient( size_t table_id, const uint64_t *keys, const float **update_values, size_t num, void *done) = 0; - virtual std::future push_sparse_raw_gradient_partial( + virtual std::future PushSparseRawGradientPartial( size_t table_id, const uint64_t *keys, const float **update_values, uint32_t num, void *done, int pserver_idx) = 0; - virtual std::future push_sparse_param(size_t table_id, - const uint64_t *keys, - const float **update_values, - size_t num, void *done) = 0; - virtual std::future push_sparse(size_t table_id, - const uint64_t *keys, - const float **update_values, - size_t num) = 0; + virtual std::future PushSparseParam(size_t table_id, + const uint64_t *keys, + const float **update_values, + size_t num, void *done) = 0; + virtual std::future PushSparse(size_t table_id, const uint64_t *keys, + const float **update_values, + size_t num) = 0; protected: - virtual int32_t initialize() = 0; + virtual int32_t Initialize() = 0; size_t _client_id; PSParameter _config; std::map> @@ -333,7 +288,7 @@ REGISTER_PSCORE_REGISTERER(PSClient); class PSClientFactory { public: - static PSClient *create(const PSParameter &config); + static PSClient *Create(const PSParameter &config); }; } // namespace distributed } // namespace paddle diff --git a/paddle/fluid/distributed/ps/service/ps_local_client.cc b/paddle/fluid/distributed/ps/service/ps_local_client.cc index dbf47f0df4116..bb8ba223d828e 100644 --- a/paddle/fluid/distributed/ps/service/ps_local_client.cc +++ b/paddle/fluid/distributed/ps/service/ps_local_client.cc @@ -19,166 +19,91 @@ namespace paddle { namespace distributed { -int32_t PsLocalClient::initialize() { +int32_t PsLocalClient::Initialize() { const auto& downpour_param = _config.server_param().downpour_server_param(); - TableManager::instance().initialize(); + TableManager::Instance().Initialize(); for (size_t i = 0; i < downpour_param.downpour_table_param_size(); ++i) { auto* table = CREATE_PSCORE_CLASS( Table, downpour_param.downpour_table_param(i).table_class()); - table->set_shard(0, 1); - table->initialize(downpour_param.downpour_table_param(i), + table->SetShard(0, 1); + table->Initialize(downpour_param.downpour_table_param(i), _config.fs_client_param()); _table_map[downpour_param.downpour_table_param(i).table_id()].reset(table); } return 0; } -::std::future PsLocalClient::shrink(uint32_t table_id, +::std::future PsLocalClient::Shrink(uint32_t table_id, const std::string threshold) { // TODO return done(); } -::std::future PsLocalClient::load(const std::string& epoch, +::std::future PsLocalClient::Load(const std::string& epoch, const std::string& mode) { // TODO for (auto& it : _table_map) { - load(it.first, epoch, mode); + Load(it.first, epoch, mode); } return done(); } -::std::future PsLocalClient::load(uint32_t table_id, 
+::std::future PsLocalClient::Load(uint32_t table_id, const std::string& epoch, const std::string& mode) { // TODO - auto* table_ptr = table(table_id); - table_ptr->load(epoch, mode); + auto* table_ptr = GetTable(table_id); + table_ptr->Load(epoch, mode); return done(); } -std::future PsLocalClient::Load(const LoadSaveContext& load_context) { - if (load_context.table_id < 0) { - for (auto& it : _table_map) { - load(it.first, load_context.epoch, load_context.mode); - } - return done(); - } else { - auto* table_ptr = table(load_context.table_id); - table_ptr->load(load_context.epoch, load_context.mode); - return done(); - } -} - -::std::future PsLocalClient::save(const std::string& epoch, +::std::future PsLocalClient::Save(const std::string& epoch, const std::string& mode) { // TODO for (auto& it : _table_map) { - save(it.first, epoch, mode); + Save(it.first, epoch, mode); } return done(); } -::std::future PsLocalClient::save(uint32_t table_id, +::std::future PsLocalClient::Save(uint32_t table_id, const std::string& epoch, const std::string& mode) { // TODO - auto* table_ptr = table(table_id); - table_ptr->flush(); - table_ptr->save(epoch, mode); + auto* table_ptr = GetTable(table_id); + table_ptr->Flush(); + table_ptr->Save(epoch, mode); return done(); } -::std::future PsLocalClient::Save( - const LoadSaveContext& save_context) { - if (save_context.table_id < 0) { - for (auto& it : _table_map) { - save(it.first, save_context.epoch, save_context.mode); - } - return done(); - } else { - auto* table_ptr = table(save_context.table_id); - table_ptr->flush(); - table_ptr->save(save_context.epoch, save_context.mode); - return done(); - } -} - -::std::future PsLocalClient::clear() { +::std::future PsLocalClient::Clear() { // TODO return done(); } -::std::future PsLocalClient::clear(uint32_t table_id) { +::std::future PsLocalClient::Clear(uint32_t table_id) { // TODO return done(); } -::std::future PsLocalClient::flush() { +::std::future PsLocalClient::Flush() { // no need return done(); } -::std::future PsLocalClient::stop_server() { +::std::future PsLocalClient::StopServer() { // no need return done(); } -::std::future PsLocalClient::Pull(RequestContext& pull_context) { - if (pull_context.value_type == Dense) { // pull dense - Region* dense_region = reinterpret_cast(pull_context.dense_values); - pull_dense(dense_region, pull_context.num, pull_context.table); - } else { // pull sparse - // uint64_t* keys = reinterpret_cast(pull_context.keys); - // char** select_values = - // reinterpret_cast(pull_context.sparse_values); - size_t table_id = pull_context.table; - size_t num = pull_context.num; - pull_sparse_ptr(reinterpret_cast(pull_context.sparse_values), - table_id, pull_context.keys, num); - } -} +::std::future PsLocalClient::PullDense(Region* regions, + size_t region_num, + size_t table_id) { + auto* accessor = GetTableAccessor(table_id); + auto* table_ptr = GetTable(table_id); -::std::future PsLocalClient::Push(RequestContext& push_context) { - if (push_context.value_type == Dense) { // push dense - if (push_context.training_phase == Init) { - const Region* regions = push_context.push_context.push_dense_values; - size_t region_num = push_context.num; - push_dense_param(regions, region_num, push_context.table); - } else { - if (push_context.training_mode == Geo) { // geo - float* total_send_data = - reinterpret_cast(push_context.dense_values); - size_t total_send_data_size = push_context.num; - push_dense_raw_gradient(push_context.table, total_send_data, - total_send_data_size, 
push_context.callback); - } else { // async and sync - const Region* regions = push_context.push_context.push_dense_values; - size_t region_num = push_context.num; - push_dense(regions, region_num, push_context.table); - } - } - } else { // push sparse - if (push_context.training_mode == Async) { - const uint64_t* keys = push_context.push_context.keys; - const float** update_values = push_context.push_context.push_values; - size_t table_id = push_context.table; - size_t num = push_context.num; - push_sparse(table_id, keys, update_values, num); - } else { - // TODO - } - } -} - -::std::future PsLocalClient::pull_dense(Region* regions, - size_t region_num, - size_t table_id) { - auto* accessor = table_accessor(table_id); - auto* table_ptr = table(table_id); + uint32_t num_per_shard = DenseDimPerShard(accessor->GetTableInfo(FEA_DIM), 1); - uint32_t num_per_shard = - dense_dim_per_shard(accessor->GetTableInfo(FEA_DIM), 1); std::vector region_buffer; region_buffer.resize(num_per_shard); - table_ptr->pull_dense(region_buffer.data(), region_buffer.size()); + table_ptr->PullDense(region_buffer.data(), region_buffer.size()); size_t region_idx = 0; size_t region_data_idx = 0; @@ -213,48 +138,49 @@ ::std::future PsLocalClient::pull_dense(Region* regions, return done(); } -::std::future PsLocalClient::push_dense_param(const Region* regions, - size_t region_num, - size_t table_id) { - auto* accessor = table_accessor(table_id); - auto* table_ptr = table(table_id); +::std::future PsLocalClient::PushDenseParam(const Region* regions, + size_t region_num, + size_t table_id) { + auto* accessor = GetTableAccessor(table_id); + auto* table_ptr = GetTable(table_id); std::vector region_buffer; - region_buffer.resize(dense_dim_per_shard(accessor->GetTableInfo(FEA_DIM), 1), - 0); + region_buffer.resize(DenseDimPerShard(accessor->GetTableInfo(FEA_DIM), 1), 0); + for (size_t i = 0, offset = 0; i < region_num; ++i) { uint32_t data_num = regions[i].size / sizeof(float); memcpy(region_buffer.data() + offset, regions[i].data, regions[i].size); offset += data_num; } - // table_ptr->push_dense_param(region_buffer.data(), region_buffer.size()); + // table_ptr->PushDenseParam(region_buffer.data(), region_buffer.size()); return done(); } -::std::future PsLocalClient::push_dense_raw_gradient( +::std::future PsLocalClient::PushDenseRawGradient( int table_id, float* total_send_data, size_t total_send_data_size, void* callback) { VLOG(1) << "wxx push_dense_raw_gradient"; PSClientClosure* closure = reinterpret_cast(callback); - auto* table_ptr = table(table_id); + auto* table_ptr = GetTable(table_id); - table_ptr->push_dense(total_send_data, total_send_data_size); + table_ptr->PushDense(total_send_data, total_send_data_size); delete closure; return done(); } -::std::future PsLocalClient::push_dense(const Region* regions, - size_t region_num, - size_t table_id) { - auto* accessor = table_accessor(table_id); - auto* table_ptr = table(table_id); +::std::future PsLocalClient::PushDense(const Region* regions, + size_t region_num, + size_t table_id) { + auto* accessor = GetTableAccessor(table_id); + auto* table_ptr = GetTable(table_id); std::vector region_buffer; - region_buffer.resize(dense_dim_per_shard(accessor->GetTableInfo(FEA_DIM), 1)); + region_buffer.resize(DenseDimPerShard(accessor->GetTableInfo(FEA_DIM), 1)); + size_t data_size = region_buffer.size(); for (size_t i = 0, offset = 0; i < region_num; ++i) { uint32_t data_num = regions[i].size / sizeof(float); @@ -267,12 +193,12 @@ ::std::future PsLocalClient::push_dense(const 
Region* regions, offset += data_num; } - table_ptr->push_dense(region_buffer.data(), region_buffer.size()); + table_ptr->PushDense(region_buffer.data(), region_buffer.size()); return done(); } -//::std::future PsLocalClient::pull_sparse(float** select_values, +//::std::future PsLocalClient::PullSparse(float** select_values, // size_t table_id, // const uint64_t* keys, // size_t num) { @@ -282,14 +208,14 @@ ::std::future PsLocalClient::push_dense(const Region* regions, // // auto local_timer = // // std::make_shared("pslib_downpour_client_pull_sparse_local"); // //将key拆分到各shard请求,并记录原始对应value指针 -// auto* accessor = table_accessor(table_id); -// auto* table_ptr = table(table_id); +// auto* accessor = GetTableAccessor(table_id); +// auto* table_ptr = GetTable(table_id); // size_t value_size = accessor->select_size(); // -// // table_ptr->pull_sparse(keys, num); +// // table_ptr->PullSparse(keys, num); // std::vector res_data; // res_data.resize(num * value_size / sizeof(float)); -// table_ptr->pull_sparse(res_data.data(), keys, num); +// table_ptr->PullSparse(res_data.data(), keys, num); // // memcpy(select_values[0], res_data->data(), res_data->size() * // // sizeof(float)); // size_t offset = 0; @@ -302,43 +228,43 @@ ::std::future PsLocalClient::push_dense(const Region* regions, // return done(); //} -::std::future PsLocalClient::pull_sparse_ptr(char** select_values, - size_t table_id, - const uint64_t* keys, - size_t num) { +::std::future PsLocalClient::PullSparsePtr(char** select_values, + size_t table_id, + const uint64_t* keys, + size_t num) { // FIXME // auto timer = // std::make_shared("pslib_downpour_client_pull_sparse"); // auto local_timer = // std::make_shared("pslib_downpour_client_pull_sparse_local"); //将key拆分到各shard请求,并记录原始对应value指针 - auto* table_ptr = table(table_id); + auto* table_ptr = GetTable(table_id); - table_ptr->pull_sparse_ptr(select_values, keys, num); + table_ptr->PullSparsePtr(select_values, keys, num); return done(); } -::std::future PsLocalClient::push_sparse_raw_gradient( +::std::future PsLocalClient::PushSparseRawGradient( size_t table_id, const uint64_t* keys, const float** update_values, size_t num, void* callback) { PSClientClosure* closure = reinterpret_cast(callback); - auto* accessor = table_accessor(table_id); - auto* table_ptr = table(table_id); + auto* accessor = GetTableAccessor(table_id); + auto* table_ptr = GetTable(table_id); - table_ptr->push_sparse(keys, update_values, num); + table_ptr->PushSparse(keys, update_values, num); delete closure; return done(); } -::std::future PsLocalClient::push_sparse(size_t table_id, - const uint64_t* keys, - const float** update_values, - size_t num) { - auto* accessor = table_accessor(table_id); - auto* table_ptr = table(table_id); +::std::future PsLocalClient::PushSparse(size_t table_id, + const uint64_t* keys, + const float** update_values, + size_t num) { + auto* accessor = GetTableAccessor(table_id); + auto* table_ptr = GetTable(table_id); - table_ptr->push_sparse(keys, update_values, num); + table_ptr->PushSparse(keys, update_values, num); return done(); } } diff --git a/paddle/fluid/distributed/ps/service/ps_local_client.h b/paddle/fluid/distributed/ps/service/ps_local_client.h index 83ca558e3d2cb..439ecf79f2f80 100644 --- a/paddle/fluid/distributed/ps/service/ps_local_client.h +++ b/paddle/fluid/distributed/ps/service/ps_local_client.h @@ -26,54 +26,46 @@ class PsLocalClient : public PSClient { public: PsLocalClient() {} virtual ~PsLocalClient() { _running = false; } - virtual int32_t 
create_client2client_connection(int pslib_timeout_ms, - int pslib_connect_timeout_ms, - int max_retry) { + virtual int32_t CreateClient2ClientConnection(int pslib_timeout_ms, + int pslib_connect_timeout_ms, + int max_retry) { return 0; } - virtual ::std::future shrink(uint32_t table_id, + virtual ::std::future Shrink(uint32_t table_id, const std::string threshold) override; - virtual ::std::future load(const std::string& epoch, + virtual ::std::future Load(const std::string& epoch, const std::string& mode) override; - virtual ::std::future load(uint32_t table_id, + virtual ::std::future Load(uint32_t table_id, const std::string& epoch, const std::string& mode) override; - virtual std::future Load( - const LoadSaveContext& load_context) override; - virtual ::std::future save(const std::string& epoch, + virtual ::std::future Save(const std::string& epoch, const std::string& mode) override; - virtual ::std::future save(uint32_t table_id, + virtual ::std::future Save(uint32_t table_id, const std::string& epoch, const std::string& mode) override; - virtual std::future Save( - const LoadSaveContext& save_context) override; - virtual ::std::future clear() override; - virtual ::std::future clear(uint32_t table_id) override; + virtual ::std::future Clear() override; + virtual ::std::future Clear(uint32_t table_id) override; - virtual ::std::future stop_server() override; + virtual ::std::future StopServer() override; - virtual void finalize_worker() override {} - virtual ::std::future pull_dense(Region* regions, size_t region_num, - size_t table_id); + virtual void FinalizeWorker() override {} + virtual ::std::future PullDense(Region* regions, size_t region_num, + size_t table_id); - virtual ::std::future Pull(RequestContext& pull_context) override; + virtual ::std::future PushDense(const Region* regions, + size_t region_num, size_t table_id); - virtual ::std::future Push(RequestContext& push_context) override; + virtual ::std::future PushDenseParam(const Region* regions, + size_t region_num, + size_t table_id); - virtual ::std::future push_dense(const Region* regions, - size_t region_num, size_t table_id); - - virtual ::std::future push_dense_param(const Region* regions, - size_t region_num, - size_t table_id); - - virtual ::std::future pull_sparse(float** select_values, - size_t table_id, - const uint64_t* keys, size_t num, - bool is_training) { + virtual ::std::future PullSparse(float** select_values, + size_t table_id, + const uint64_t* keys, size_t num, + bool is_training) { std::promise prom; std::future fut = prom.get_future(); prom.set_value(0); @@ -81,26 +73,26 @@ class PsLocalClient : public PSClient { return fut; } - virtual ::std::future pull_sparse_ptr(char** select_values, - size_t table_id, - const uint64_t* keys, - size_t num); + virtual ::std::future PullSparsePtr(char** select_values, + size_t table_id, + const uint64_t* keys, + size_t num); - virtual ::std::future print_table_stat(uint32_t table_id) { + virtual ::std::future PrintTableStat(uint32_t table_id) { std::promise prom; std::future fut = prom.get_future(); prom.set_value(0); return fut; } - virtual ::std::future push_sparse(size_t table_id, - const uint64_t* keys, - const float** update_values, - size_t num); + virtual ::std::future PushSparse(size_t table_id, + const uint64_t* keys, + const float** update_values, + size_t num); - virtual ::std::future flush(); + virtual ::std::future Flush(); // server profilera - virtual std::future start_profiler() { + virtual std::future StartProfiler() { std::promise prom; 
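PsLocalClient satisfies the asynchronous PSClient contract without any RPC: each stub in this header manufactures an already-resolved future, the same pattern the done() helper factors out further below. A sketch of that recurring idiom (ReadyFuture is an illustrative name):

#include <cstdint>
#include <future>

std::future<int32_t> ReadyFuture(int32_t value = 0) {
  std::promise<int32_t> prom;
  std::future<int32_t> fut = prom.get_future();
  prom.set_value(value);  // resolve immediately; .get() never blocks
  return fut;
}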
std::future fut = prom.get_future(); prom.set_value(0); @@ -108,7 +100,7 @@ class PsLocalClient : public PSClient { return fut; }; - virtual std::future stop_profiler() { + virtual std::future StopProfiler() { std::promise prom; std::future fut = prom.get_future(); prom.set_value(0); @@ -116,7 +108,7 @@ class PsLocalClient : public PSClient { return fut; } - virtual std::future barrier(size_t table_id, uint32_t barrier_type) { + virtual std::future Barrier(size_t table_id, uint32_t barrier_type) { std::promise prom; std::future fut = prom.get_future(); prom.set_value(0); @@ -124,10 +116,10 @@ class PsLocalClient : public PSClient { return fut; } - virtual std::future pull_geo_param(size_t table_id, - std::vector* values, - std::vector* keys, - int pserver_idx) { + virtual std::future PullGeoParam(size_t table_id, + std::vector* values, + std::vector* keys, + int pserver_idx) { std::promise prom; std::future fut = prom.get_future(); prom.set_value(0); @@ -135,9 +127,9 @@ class PsLocalClient : public PSClient { return fut; } - virtual std::future push_global_step(int table_id, - int64_t* total_send_data, - void* done) { + virtual std::future PushGlobalStep(int table_id, + int64_t* total_send_data, + void* done) { std::promise prom; std::future fut = prom.get_future(); prom.set_value(0); @@ -146,12 +138,12 @@ class PsLocalClient : public PSClient { } // recv table from server and save it in LodTensor - virtual int32_t recv_and_save_table(const uint64_t table_id, - const std::string& path) { + virtual int32_t RecvAndSaveTable(const uint64_t table_id, + const std::string& path) { return 0; } - virtual ::std::future send_client2client_msg( + virtual ::std::future SendClient2ClientMsg( int msg_type, int to_client_id, const std::string& msg) override { std::promise prom; std::future fut = prom.get_future(); @@ -159,17 +151,18 @@ class PsLocalClient : public PSClient { return fut; } - virtual size_t get_server_nums() { return 1; } + virtual size_t GetServerNums() { return 1; } - virtual std::future push_dense_raw_gradient( - int table_id, float* total_send_data, size_t total_send_data_size, - void* callback) override; + virtual std::future PushDenseRawGradient(int table_id, + float* total_send_data, + size_t total_send_data_size, + void* callback) override; - virtual std::future push_sparse_raw_gradient( + virtual std::future PushSparseRawGradient( size_t table_id, const uint64_t* keys, const float** update_values, size_t num, void* callback) override; - virtual std::future push_sparse_raw_gradient_partial( + virtual std::future PushSparseRawGradientPartial( size_t table_id, const uint64_t* keys, const float** update_values, uint32_t num, void* done, int pserver_idx) override { std::promise prom; @@ -179,11 +172,11 @@ class PsLocalClient : public PSClient { return fut; } - virtual std::future push_sparse_param(size_t table_id, - const uint64_t* keys, - const float** update_values, - size_t num, - void* done) override { + virtual std::future PushSparseParam(size_t table_id, + const uint64_t* keys, + const float** update_values, + size_t num, + void* done) override { std::promise prom; std::future fut = prom.get_future(); prom.set_value(0); @@ -192,7 +185,7 @@ class PsLocalClient : public PSClient { } private: - virtual int32_t initialize() override; + virtual int32_t Initialize() override; std::future done() { std::shared_ptr> prom = @@ -202,16 +195,16 @@ class PsLocalClient : public PSClient { return fut; } - inline uint32_t dense_dim_per_shard(uint32_t dense_dim_total, - uint32_t shard_num) { + 
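DenseDimPerShard, re-declared in CamelCase just below, pads the integer division by one so every shard's buffer can absorb the remainder rows; e.g. 10 dims over 4 shards yields 3 per shard. A tiny check of that arithmetic:

#include <cassert>
#include <cstdint>

inline uint32_t DenseDimPerShard(uint32_t dense_dim_total, uint32_t shard_num) {
  return dense_dim_total / shard_num + 1;  // +1 absorbs the division remainder
}

int main() {
  assert(DenseDimPerShard(10, 4) == 3);  // 10/4 == 2, plus one pad row
  assert(DenseDimPerShard(8, 1) == 9);   // shard_num == 1 still over-allocates by one
  return 0;
}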
inline uint32_t DenseDimPerShard(uint32_t dense_dim_total, + uint32_t shard_num) { return dense_dim_total / shard_num + 1; } - inline std::unordered_map>* table() { + inline std::unordered_map>* GetTable() { return &_table_map; } - inline Table* table(size_t table_id) { + inline Table* GetTable(size_t table_id) { auto itr = _table_map.find(table_id); if (itr != _table_map.end()) { return itr->second.get(); diff --git a/paddle/fluid/distributed/ps/service/ps_local_server.h b/paddle/fluid/distributed/ps/service/ps_local_server.h index 31b52126fc576..c09f8585b659d 100644 --- a/paddle/fluid/distributed/ps/service/ps_local_server.h +++ b/paddle/fluid/distributed/ps/service/ps_local_server.h @@ -25,17 +25,17 @@ class PsLocalServer : public PSServer { public: PsLocalServer() {} virtual ~PsLocalServer() {} - virtual uint64_t start() { return 0; } - virtual uint64_t start(const std::string &ip, uint32_t port) { return 0; } - virtual int32_t stop() { return 0; } - virtual int32_t configure( + virtual uint64_t Start() { return 0; } + virtual uint64_t Start(const std::string &ip, uint32_t port) { return 0; } + virtual int32_t Stop() { return 0; } + virtual int32_t Configure( const PSParameter &config, PSEnvironment &env, size_t server_rank, const std::vector &server_sub_program = {}) { return 0; } private: - virtual int32_t initialize() { return 0; } + virtual int32_t Initialize() { return 0; } }; } } diff --git a/paddle/fluid/distributed/ps/service/ps_service/graph_py_service.cc b/paddle/fluid/distributed/ps/service/ps_service/graph_py_service.cc index c8be0f7971090..92dfeb6818a28 100644 --- a/paddle/fluid/distributed/ps/service/ps_service/graph_py_service.cc +++ b/paddle/fluid/distributed/ps/service/ps_service/graph_py_service.cc @@ -70,7 +70,7 @@ void GraphPyService::set_up(std::string ips_str, int shard_num, port_list.push_back(ip_and_port[1]); uint32_t port = stoul(ip_and_port[1]); auto ph_host = paddle::distributed::PSHost(ip_and_port[0], port, index); - host_sign_list.push_back(ph_host.serialize_to_string()); + host_sign_list.push_back(ph_host.SerializeToString()); index++; } } @@ -83,11 +83,11 @@ void GraphPyClient::start_client() { paddle::distributed::PaddlePSEnvironment _ps_env; auto servers_ = host_sign_list.size(); _ps_env = paddle::distributed::PaddlePSEnvironment(); - _ps_env.set_ps_servers(&host_sign_list, servers_); + _ps_env.SetPsServers(&host_sign_list, servers_); worker_ptr = std::shared_ptr( (paddle::distributed::GraphBrpcClient*) - paddle::distributed::PSClientFactory::create(worker_proto)); - worker_ptr->configure(worker_proto, dense_regions, _ps_env, client_id); + paddle::distributed::PSClientFactory::Create(worker_proto)); + worker_ptr->Configure(worker_proto, dense_regions, _ps_env, client_id); worker_ptr->set_shard_num(get_shard_num()); } void GraphPyServer::start_server(bool block) { @@ -96,17 +96,17 @@ void GraphPyServer::start_server(bool block) { ::paddle::distributed::PSParameter server_proto = this->GetServerProto(); auto _ps_env = paddle::distributed::PaddlePSEnvironment(); - _ps_env.set_ps_servers(&this->host_sign_list, - this->host_sign_list.size()); // test + _ps_env.SetPsServers(&this->host_sign_list, + this->host_sign_list.size()); // test pserver_ptr = std::shared_ptr( (paddle::distributed::GraphBrpcServer*) - paddle::distributed::PSServerFactory::create(server_proto)); + paddle::distributed::PSServerFactory::Create(server_proto)); VLOG(0) << "pserver-ptr created "; std::vector empty_vec; framework::ProgramDesc empty_prog; empty_vec.push_back(empty_prog); - 
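GraphPyServer::start_server above wires the server up in a fixed order: publish the host sign list to the environment, create the server through the factory, Configure it with the proto and rank, then Start it and build peer-to-peer connections. A schematic of that sequence with stand-in types, where only the call ordering reflects the code above:

#include <cstdint>
#include <memory>
#include <string>

struct MiniEnv { void SetPsServers() {} };
struct MiniServer {
  int32_t Configure(MiniEnv *env, int rank) { return 0; }
  uint64_t Start(const std::string &ip, uint32_t port) { return 0; }
};

std::shared_ptr<MiniServer> Bootstrap(const std::string &ip, uint32_t port,
                                      int rank) {
  MiniEnv env;
  env.SetPsServers();                            // 1. publish the host list
  auto server = std::make_shared<MiniServer>();  // 2. factory creation
  server->Configure(&env, rank);                 // 3. bind services/tables
  server->Start(ip, port);                       // 4. serve and register host
  return server;
}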
void GraphPyServer::start_server(bool block) {
@@ -96,17 +96,17 @@ void GraphPyServer::start_server(bool block) {
  ::paddle::distributed::PSParameter server_proto = this->GetServerProto();

  auto _ps_env = paddle::distributed::PaddlePSEnvironment();
-  _ps_env.set_ps_servers(&this->host_sign_list,
-                         this->host_sign_list.size());  // test
+  _ps_env.SetPsServers(&this->host_sign_list,
+                       this->host_sign_list.size());  // test
  pserver_ptr = std::shared_ptr<paddle::distributed::GraphBrpcServer>(
      (paddle::distributed::GraphBrpcServer*)
-          paddle::distributed::PSServerFactory::create(server_proto));
+          paddle::distributed::PSServerFactory::Create(server_proto));
  VLOG(0) << "pserver-ptr created ";
  std::vector<framework::ProgramDesc> empty_vec;
  framework::ProgramDesc empty_prog;
  empty_vec.push_back(empty_prog);
-  pserver_ptr->configure(server_proto, _ps_env, rank, empty_vec);
-  pserver_ptr->start(ip, port);
+  pserver_ptr->Configure(server_proto, _ps_env, rank, empty_vec);
+  pserver_ptr->Start(ip, port);
  pserver_ptr->build_peer2peer_connection(rank);
  std::condition_variable* cv_ = pserver_ptr->export_cv();
  if (block) {
@@ -246,7 +246,7 @@ void GraphPyClient::load_edge_file(std::string name, std::string filepath,
    VLOG(0) << "loadding data with type " << name << " from " << filepath;
    uint32_t table_id = this->table_id_map[name];
    auto status =
-        get_ps_client()->load(table_id, std::string(filepath), params);
+        get_ps_client()->Load(table_id, std::string(filepath), params);
    status.wait();
  }
}
@@ -285,7 +285,7 @@ void GraphPyClient::load_node_file(std::string name, std::string filepath) {
  if (this->table_id_map.count(name)) {
    uint32_t table_id = this->table_id_map[name];
    auto status =
-        get_ps_client()->load(table_id, std::string(filepath), params);
+        get_ps_client()->Load(table_id, std::string(filepath), params);
    status.wait();
  }
}
@@ -396,13 +396,13 @@ std::vector<FeatureNode> GraphPyClient::pull_graph_list(std::string name,
  return res;
}

-void GraphPyClient::stop_server() {
+void GraphPyClient::StopServer() {
  VLOG(0) << "going to stop server";
  std::unique_lock<std::mutex> lock(mutex_);
  if (stoped_) return;
-  auto status = this->worker_ptr->stop_server();
+  auto status = this->worker_ptr->StopServer();
  if (status.get() == 0) stoped_ = true;
}
-void GraphPyClient::finalize_worker() { this->worker_ptr->finalize_worker(); }
+void GraphPyClient::FinalizeWorker() { this->worker_ptr->FinalizeWorker(); }
}
}
diff --git a/paddle/fluid/distributed/ps/service/ps_service/graph_py_service.h b/paddle/fluid/distributed/ps/service/ps_service/graph_py_service.h
index 85707137c1800..19f34dad80745 100644
--- a/paddle/fluid/distributed/ps/service/ps_service/graph_py_service.h
+++ b/paddle/fluid/distributed/ps/service/ps_service/graph_py_service.h
@@ -123,7 +123,7 @@ class GraphPyServer : public GraphPyService {
    set_rank(rank);
    GraphPyService::set_up(ips_str, shard_num, node_types, edge_types);
  }
-  int get_rank() { return rank; }
+  int GetRank() { return rank; }
  void set_rank(int rank) { this->rank = rank; }

  void start_server(bool block = true);
@@ -154,8 +154,8 @@ class GraphPyClient : public GraphPyService {
        (paddle::distributed::GraphBrpcService*)server.get_ps_server()
            ->get_service());
  }
-  void stop_server();
-  void finalize_worker();
+  void StopServer();
+  void FinalizeWorker();
  void load_edge_file(std::string name, std::string filepath, bool reverse);
  void load_node_file(std::string name, std::string filepath);
  void clear_nodes(std::string name);
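GraphPyClient::StopServer above is deliberately idempotent: a mutex plus a stoped_ flag ensure the server is asked to stop only once even when several threads race on it. A hedged sketch of that guard (names are stand-ins, not the patch's types):

#include <mutex>

class StopOnce {
 public:
  // Returns true only for the caller that actually performs the stop.
  bool Stop() {
    std::unique_lock<std::mutex> lock(mutex_);
    if (stopped_) return false;  // someone already stopped the server
    stopped_ = true;             // the real code flips this only after the RPC succeeds
    return true;
  }

 private:
  std::mutex mutex_;
  bool stopped_ = false;
};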
diff --git a/paddle/fluid/distributed/ps/service/ps_service/service.cc b/paddle/fluid/distributed/ps/service/ps_service/service.cc
index 73793d2f9bd0e..9c3a06c2212e6 100644
--- a/paddle/fluid/distributed/ps/service/ps_service/service.cc
+++ b/paddle/fluid/distributed/ps/service/ps_service/service.cc
@@ -46,7 +46,7 @@ paddle::distributed::PSParameter load_from_prototxt(
  return param;
}

-void PSCore::init_gflag(const std::string& gflags) {
+void PSCore::InitGFlag(const std::string& gflags) {
  VLOG(3) << "Init With Gflags:" << gflags;
  std::vector<std::string> flags = paddle::string::split_string(gflags);
  if (flags.size() < 1) {
@@ -65,67 +65,67 @@ void PSCore::init_gflag(const std::string& gflags) {
  ::GFLAGS_NAMESPACE::ParseCommandLineFlags(&params_cnt, &params_ptr, true);
}

-int PSCore::init_server(
+int PSCore::InitServer(
    const std::string& dist_desc,
    const std::vector<std::string>* host_sign_list, int node_num, int index,
    int trainers,
    const std::vector<framework::ProgramDesc>& server_sub_program) {
  google::protobuf::TextFormat::ParseFromString(dist_desc, &_ps_param);
-  init_gflag(_ps_param.init_gflags());
+  InitGFlag(_ps_param.init_gflags());
  _ps_env = paddle::distributed::PaddlePSEnvironment();
-  _ps_env.set_ps_servers(host_sign_list, node_num);
-  _ps_env.set_trainers(trainers);
+  _ps_env.SetPsServers(host_sign_list, node_num);
+  _ps_env.SetTrainers(trainers);
  int ret = 0;
  _server_ptr = std::shared_ptr<paddle::distributed::PSServer>(
-      paddle::distributed::PSServerFactory::create(_ps_param));
-  ret = _server_ptr->configure(_ps_param, _ps_env, index, server_sub_program);
+      paddle::distributed::PSServerFactory::Create(_ps_param));
+  ret = _server_ptr->Configure(_ps_param, _ps_env, index, server_sub_program);
  CHECK(ret == 0) << "failed to configure server";
  return ret;
}

-int PSCore::init_worker(
+int PSCore::InitWorker(
    const std::string& dist_desc,
    const std::map<uint64_t, std::vector<paddle::distributed::Region>>& regions,
    const std::vector<std::string>* host_sign_list, int node_num, int index) {
  google::protobuf::TextFormat::ParseFromString(dist_desc, &_ps_param);
-  init_gflag(_ps_param.init_gflags());
+  InitGFlag(_ps_param.init_gflags());
  _ps_env = paddle::distributed::PaddlePSEnvironment();
-  _ps_env.set_ps_servers(host_sign_list, node_num);
+  _ps_env.SetPsServers(host_sign_list, node_num);
  int ret = 0;
-  VLOG(1) << "PSCore::init_worker";
+  VLOG(1) << "PSCore::InitWorker";
  auto* communicator = Communicator::GetInstance();
-  ret = communicator->GetPsClient()->configure(_ps_param, regions, _ps_env,
+  ret = communicator->GetPsClient()->Configure(_ps_param, regions, _ps_env,
                                               index);
  communicator->Start();
  return ret;
}

-std::vector<uint64_t> PSCore::get_client_info() {
-  return _ps_env.get_client_info();
+std::vector<uint64_t> PSCore::GetClientInfo() {
+  return _ps_env.GetClientInfo();
}

-int PSCore::create_client2client_connection(int pserver_timeout_ms,
-                                            int pserver_connect_timeout_ms,
-                                            int max_retry) {
-  int ret = _worker_ptr->create_client2client_connection(
+int PSCore::CreateClient2ClientConnection(int pserver_timeout_ms,
+                                          int pserver_connect_timeout_ms,
+                                          int max_retry) {
+  int ret = _worker_ptr->CreateClient2ClientConnection(
      pserver_timeout_ms, pserver_connect_timeout_ms, max_retry);
  return ret;
}

-uint64_t PSCore::run_server(const std::string& ip, uint32_t port) {
-  return _server_ptr->start(ip, port);
+uint64_t PSCore::RunServer(const std::string& ip, uint32_t port) {
+  return _server_ptr->Start(ip, port);
}

-int PSCore::finalize_worker() {
-  _worker_ptr->finalize_worker();
+int PSCore::FinalizeWorker() {
+  _worker_ptr->FinalizeWorker();
  return 0;
}

-int PSCore::stop_server() {
-  auto stop_status = _worker_ptr->stop_server();
+int PSCore::StopServer() {
+  auto stop_status = _worker_ptr->StopServer();
  stop_status.wait();
  return 0;
}
-paddle::distributed::PSParameter* PSCore::get_param() { return &_ps_param; }
+paddle::distributed::PSParameter* PSCore::GetParam() { return &_ps_param; }
}  // namespace distributed
}  // namespace paddle
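InitServer and InitWorker above both funnel the serialized flag string from the config through InitGFlag, which splits it into an argv-style array before handing it to ParseCommandLineFlags. A self-contained sketch of that packing step, without the gflags dependency (flag values here are illustrative):

#include <iostream>
#include <sstream>
#include <string>
#include <vector>

int main() {
  std::string gflags = "-max_body_size=314217728 -bthread_concurrency=40";
  std::vector<std::string> flags;
  std::istringstream iss(gflags);
  for (std::string tok; iss >> tok;) flags.push_back(tok);  // whitespace split
  flags.insert(flags.begin(), "exe default");  // fake argv[0], as InitGFlag does
  std::vector<char*> argv;
  for (auto& f : flags) argv.push_back(&f[0]);
  int argc = static_cast<int>(argv.size());
  // The real code now hands (&argc, argv.data()) to
  // ::GFLAGS_NAMESPACE::ParseCommandLineFlags(); here we just echo it.
  for (int i = 0; i < argc; ++i) std::cout << argv[i] << "\n";
  return 0;
}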
diff --git a/paddle/fluid/distributed/ps/service/ps_service/service.h b/paddle/fluid/distributed/ps/service/ps_service/service.h
index 202c2407f15ae..112fdc3e14183 100644
--- a/paddle/fluid/distributed/ps/service/ps_service/service.h
+++ b/paddle/fluid/distributed/ps/service/ps_service/service.h
@@ -42,31 +42,31 @@ class PSCore {
  explicit PSCore() {}
  virtual ~PSCore() {}

-  virtual int init_server(
+  virtual int InitServer(
      const std::string& dist_desc,
      const std::vector<std::string>* host_sign_list, int node_num, int index,
      int trainers,
      const std::vector<framework::ProgramDesc>& server_sub_program = {});
-  virtual int init_worker(
+  virtual int InitWorker(
      const std::string& dist_desc,
      const std::map<uint64_t, std::vector<paddle::distributed::Region>>&
          regions,
      const std::vector<std::string>* host_sign_list, int node_num, int index);
-  virtual uint64_t run_server(const std::string& ip, uint32_t port);
-  virtual int stop_server();
-  virtual int finalize_worker();
-  virtual std::vector<uint64_t> get_client_info();
-  virtual int create_client2client_connection(int pserver_timeout_ms,
-                                              int pserver_connect_timeout_ms,
-                                              int max_retry);
+  virtual uint64_t RunServer(const std::string& ip, uint32_t port);
+  virtual int StopServer();
+  virtual int FinalizeWorker();
+  virtual std::vector<uint64_t> GetClientInfo();
+  virtual int CreateClient2ClientConnection(int pserver_timeout_ms,
+                                            int pserver_connect_timeout_ms,
+                                            int max_retry);
  std::shared_ptr<paddle::distributed::PSServer>
      _server_ptr;  // pointer to server
  std::shared_ptr<paddle::distributed::PSClient>
      _worker_ptr;  // pointer to worker
-  virtual paddle::distributed::PSParameter* get_param();
+  virtual paddle::distributed::PSParameter* GetParam();

 private:
-  void init_gflag(const std::string& gflags);
+  void InitGFlag(const std::string& gflags);
  paddle::distributed::PSParameter _ps_param;
  paddle::distributed::PaddlePSEnvironment _ps_env;
};
diff --git a/paddle/fluid/distributed/ps/service/server.cc b/paddle/fluid/distributed/ps/service/server.cc
index 893f671359e40..65f7ae821cef1 100644
--- a/paddle/fluid/distributed/ps/service/server.cc
+++ b/paddle/fluid/distributed/ps/service/server.cc
@@ -29,7 +29,7 @@ REGISTER_PSCORE_CLASS(PsBaseService, BrpcPsService);
REGISTER_PSCORE_CLASS(PSServer, GraphBrpcServer);
REGISTER_PSCORE_CLASS(PsBaseService, GraphBrpcService);

-PSServer *PSServerFactory::create(const PSParameter &ps_config) {
+PSServer *PSServerFactory::Create(const PSParameter &ps_config) {
  const auto &config = ps_config.server_param();

  if (!config.has_downpour_server_param()) {
@@ -56,18 +56,18 @@ PSServer *PSServerFactory::create(const PSParameter &ps_config) {
                << service_param.server_class();
    return NULL;
  }
-  TableManager::instance().initialize();
+  TableManager::Instance().Initialize();
  return server;
}

-int32_t PSServer::configure(
+int32_t PSServer::Configure(
    const PSParameter &config, PSEnvironment &env, size_t server_rank,
    const std::vector<framework::ProgramDesc> &server_sub_program) {
  scope_.reset(new framework::Scope());
  _config = config.server_param();
  _rank = server_rank;
  _environment = &env;
-  size_t shard_num = env.get_ps_servers().size();
+  size_t shard_num = env.GetPsServers().size();

  const auto &downpour_param = _config.downpour_server_param();

@@ -87,21 +87,21 @@ int32_t PSServer::configure(
      global_step_table = downpour_param.downpour_table_param(i).table_id();
    }

-    table->set_program_env(scope_.get(), place_, &server_sub_program);
-    table->set_shard(_rank, shard_num);
-    table->initialize(downpour_param.downpour_table_param(i),
+    table->SetProgramEnv(scope_.get(), place_, &server_sub_program);
+    table->SetShard(_rank, shard_num);
+    table->Initialize(downpour_param.downpour_table_param(i),
                      config.fs_client_param());
    _table_map[downpour_param.downpour_table_param(i).table_id()].reset(table);
  }

  if (barrier_table != UINT32_MAX) {
-    _table_map[barrier_table]->set_table_map(&_table_map);
+    _table_map[barrier_table]->SetTableMap(&_table_map);
  }
  if (global_step_table != UINT32_MAX) {
-    _table_map[global_step_table]->set_table_map(&_table_map);
+    _table_map[global_step_table]->SetTableMap(&_table_map);
  }

-  return initialize();
+  return Initialize();
}
}  // namespace distributed
}  // namespace paddle
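PSServerFactory::Create above instantiates the concrete server named in the config through the REGISTER_PSCORE_CLASS registry. A minimal sketch of that registry pattern, assuming a simple name-to-constructor map (the real macro generates this wiring at static-initialization time):

#include <functional>
#include <map>
#include <memory>
#include <string>

struct PSServerLike { virtual ~PSServerLike() = default; };
struct BrpcLikeServer : PSServerLike {};

using Ctor = std::function<std::unique_ptr<PSServerLike>()>;

// Central name -> constructor map that Create() consults.
std::map<std::string, Ctor>& Registry() {
  static std::map<std::string, Ctor> r;
  return r;
}

int main() {
  Registry()["BrpcPsServer"] = [] {
    return std::unique_ptr<PSServerLike>(new BrpcLikeServer);
  };
  auto server = Registry().at("BrpcPsServer")();  // Create() by class name
  return server ? 0 : 1;
}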
diff --git a/paddle/fluid/distributed/ps/service/server.h b/paddle/fluid/distributed/ps/service/server.h
index d2804405b4198..5da819326b052 100644
--- a/paddle/fluid/distributed/ps/service/server.h
+++ b/paddle/fluid/distributed/ps/service/server.h
@@ -65,19 +65,19 @@ class PSServer {
  PSServer(PSServer &&) = delete;
  PSServer(const PSServer &) = delete;

-  virtual int32_t configure(
+  virtual int32_t Configure(
      const PSParameter &config, PSEnvironment &env, size_t server_rank,
      const std::vector<framework::ProgramDesc> &server_sub_program = {});

-  virtual uint64_t start(const std::string &ip, uint32_t port) = 0;
-  virtual int32_t stop() = 0;
+  virtual uint64_t Start(const std::string &ip, uint32_t port) = 0;
+  virtual int32_t Stop() = 0;

-  inline size_t rank() const { return _rank; }
+  inline size_t Rank() const { return _rank; }

-  inline PSEnvironment *environment() { return _environment; }
+  inline PSEnvironment *Environment() { return _environment; }

-  inline const ServerParameter *config() const { return &_config; }
-  inline Table *table(size_t table_id) {
+  inline const ServerParameter *Config() const { return &_config; }
+  inline Table *GetTable(size_t table_id) {
    auto itr = _table_map.find(table_id);
    if (itr != _table_map.end()) {
      return itr->second.get();
@@ -85,12 +85,12 @@ class PSServer {
    return NULL;
  }

-  inline std::unordered_map<uint32_t, std::shared_ptr<Table>> *table() {
+  inline std::unordered_map<uint32_t, std::shared_ptr<Table>> *GetTable() {
    return &_table_map;
  }

 protected:
-  virtual int32_t initialize() = 0;
+  virtual int32_t Initialize() = 0;

 protected:
  size_t _rank;
@@ -129,11 +129,11 @@ class PsBaseService : public PsService {
 public:
  PsBaseService() : _rank(0), _server(NULL), _config(NULL) {}
  virtual ~PsBaseService() {}
-  virtual size_t get_rank() { return _rank; }
-  virtual int32_t configure(PSServer *server) {
+  virtual size_t GetRank() { return _rank; }
+  virtual int32_t Configure(PSServer *server) {
    _server = server;
-    _rank = _server->rank();
-    _config = _server->config();
+    _rank = _server->Rank();
+    _config = _server->Config();
    return 0;
  }
  virtual void service(::google::protobuf::RpcController *controller,
@@ -148,8 +148,8 @@ class PsBaseService : public PsService {
    LOG(WARNING) << "Resonse err_code:" << err_code << " msg:" << err_msg;
  }

-  virtual int32_t initialize() = 0;
-  PSServer *get_server() { return _server; }
+  virtual int32_t Initialize() = 0;
+  PSServer *GetServer() { return _server; }

 protected:
  size_t _rank;
@@ -160,7 +160,7 @@ REGISTER_PSCORE_REGISTERER(PsBaseService);

class PSServerFactory {
 public:
-  static PSServer *create(const PSParameter &config);
+  static PSServer *Create(const PSParameter &config);
};
}  // namespace distributed
}  // namespace paddle
optimize done"; @@ -66,7 +66,7 @@ int32_t BarrierTable::barrier(const uint32_t trainer_id, return 0; } -int32_t BarrierTable::set_table_map( +int32_t BarrierTable::SetTableMap( std::unordered_map>* table_map) { table_map_ = table_map; return 0; diff --git a/paddle/fluid/distributed/ps/table/common_dense_table.cc b/paddle/fluid/distributed/ps/table/common_dense_table.cc index caec575e33eef..f0cb586e45190 100644 --- a/paddle/fluid/distributed/ps/table/common_dense_table.cc +++ b/paddle/fluid/distributed/ps/table/common_dense_table.cc @@ -21,8 +21,8 @@ namespace distributed { int FLAGS_pslib_table_save_max_retry_dense = 3; -void CommonDenseTable::create_initializer(const std::string& attr, - const std::string& name) { +void CommonDenseTable::CreateInitializer(const std::string& attr, + const std::string& name) { auto slices = string::split_string(attr, "&"); if (slices[0] == "gaussian_random") { @@ -39,7 +39,7 @@ void CommonDenseTable::create_initializer(const std::string& attr, } } -int32_t CommonDenseTable::initialize() { +int32_t CommonDenseTable::Initialize() { _shards_task_pool.resize(task_pool_size_); for (int i = 0; i < _shards_task_pool.size(); ++i) { _shards_task_pool[i].reset(new ::ThreadPool(1)); @@ -49,12 +49,12 @@ int32_t CommonDenseTable::initialize() { VLOG(1) << "table " << _config.common().table_name() << " is sync: " << sync; _global_lr = new float(1.0); - initialize_value(); - initialize_optimizer(); + InitializeValue(); + InitializeOptimizer(); return 0; } -int32_t CommonDenseTable::initialize_value() { +int32_t CommonDenseTable::InitializeValue() { auto common = _config.common(); int size = static_cast(common.params().size()); values_.resize(size); @@ -70,7 +70,7 @@ int32_t CommonDenseTable::initialize_value() { auto& initializer = common.initializers()[x]; total_dim_ += dim; - create_initializer(initializer, varname); + CreateInitializer(initializer, varname); values_[x].resize(dim); names_index_[varname] = x; @@ -92,27 +92,27 @@ int32_t CommonDenseTable::initialize_value() { param_col_ids_.insert(param_col_ids_.begin() + 1, -1); } - VLOG(1) << "CommonDenseTable::initialize_value total dim: " << total_dim_ + VLOG(1) << "CommonDenseTable::InitializeValue total dim: " << total_dim_ << " fixed_len_params_dim: " << fixed_len_params_dim_; pull_reservoir_ = ReservoirValue(param_dim_); return 0; } -int32_t CommonDenseTable::initialize_optimizer() { +int32_t CommonDenseTable::InitializeOptimizer() { auto common = _config.common(); auto name = common.name(); auto attrs = common.attributes(); if (name == "sgd") { optimizer_ = std::make_shared(common, &values_); - optimizer_->set_global_lr(_global_lr); + optimizer_->SetGlobalLR(_global_lr); } else if (name == "adam") { optimizer_ = std::make_shared(common, &values_); - optimizer_->set_global_lr(_global_lr); + optimizer_->SetGlobalLR(_global_lr); } else if (name == "adam_d2sum") { optimizer_ = std::make_shared(common, &values_); - // optimizer_->set_global_lr(_global_lr); //no use + // optimizer_->SetGlobalLR(_global_lr); //no use } else if (name == "sum") { optimizer_ = std::make_shared(common, &values_); } else if (name == "summary") { @@ -124,34 +124,34 @@ int32_t CommonDenseTable::initialize_optimizer() { return 0; } -int32_t CommonDenseTable::set_global_lr(float* lr) { +int32_t CommonDenseTable::SetGlobalLR(float* lr) { _global_lr = lr; - optimizer_->set_global_lr(_global_lr); + optimizer_->SetGlobalLR(_global_lr); return 0; } int32_t CommonDenseTable::Pull(TableContext& context) { CHECK(context.value_type == Dense); float* 
diff --git a/paddle/fluid/distributed/ps/table/common_dense_table.cc b/paddle/fluid/distributed/ps/table/common_dense_table.cc
index caec575e33eef..f0cb586e45190 100644
--- a/paddle/fluid/distributed/ps/table/common_dense_table.cc
+++ b/paddle/fluid/distributed/ps/table/common_dense_table.cc
@@ -21,8 +21,8 @@ namespace distributed {

int FLAGS_pslib_table_save_max_retry_dense = 3;

-void CommonDenseTable::create_initializer(const std::string& attr,
-                                          const std::string& name) {
+void CommonDenseTable::CreateInitializer(const std::string& attr,
+                                         const std::string& name) {
  auto slices = string::split_string<std::vector<std::string>>(attr, "&");

  if (slices[0] == "gaussian_random") {
@@ -39,7 +39,7 @@ void CommonDenseTable::create_initializer(const std::string& attr,
  }
}

-int32_t CommonDenseTable::initialize() {
+int32_t CommonDenseTable::Initialize() {
  _shards_task_pool.resize(task_pool_size_);
  for (int i = 0; i < _shards_task_pool.size(); ++i) {
    _shards_task_pool[i].reset(new ::ThreadPool(1));
@@ -49,12 +49,12 @@ int32_t CommonDenseTable::initialize() {
  VLOG(1) << "table " << _config.common().table_name() << " is sync: " << sync;

  _global_lr = new float(1.0);

-  initialize_value();
-  initialize_optimizer();
+  InitializeValue();
+  InitializeOptimizer();
  return 0;
}

-int32_t CommonDenseTable::initialize_value() {
+int32_t CommonDenseTable::InitializeValue() {
  auto common = _config.common();
  int size = static_cast<int>(common.params().size());
  values_.resize(size);
@@ -70,7 +70,7 @@ int32_t CommonDenseTable::initialize_value() {
    auto& initializer = common.initializers()[x];
    total_dim_ += dim;

-    create_initializer(initializer, varname);
+    CreateInitializer(initializer, varname);
    values_[x].resize(dim);
    names_index_[varname] = x;

@@ -92,27 +92,27 @@ int32_t CommonDenseTable::initialize_value() {
    param_col_ids_.insert(param_col_ids_.begin() + 1, -1);
  }

-  VLOG(1) << "CommonDenseTable::initialize_value total dim: " << total_dim_
+  VLOG(1) << "CommonDenseTable::InitializeValue total dim: " << total_dim_
          << " fixed_len_params_dim: " << fixed_len_params_dim_;

  pull_reservoir_ = ReservoirValue<float>(param_dim_);
  return 0;
}

-int32_t CommonDenseTable::initialize_optimizer() {
+int32_t CommonDenseTable::InitializeOptimizer() {
  auto common = _config.common();
  auto name = common.name();
  auto attrs = common.attributes();

  if (name == "sgd") {
    optimizer_ = std::make_shared<DSGD>(common, &values_);
-    optimizer_->set_global_lr(_global_lr);
+    optimizer_->SetGlobalLR(_global_lr);
  } else if (name == "adam") {
    optimizer_ = std::make_shared<DAdam>(common, &values_);
-    optimizer_->set_global_lr(_global_lr);
+    optimizer_->SetGlobalLR(_global_lr);
  } else if (name == "adam_d2sum") {
    optimizer_ = std::make_shared<DAdamD2Sum>(common, &values_);
-    // optimizer_->set_global_lr(_global_lr);  //no use
+    // optimizer_->SetGlobalLR(_global_lr);  //no use
  } else if (name == "sum") {
    optimizer_ = std::make_shared<DSUM>(common, &values_);
  } else if (name == "summary") {
@@ -124,34 +124,34 @@ int32_t CommonDenseTable::initialize_optimizer() {
  return 0;
}

-int32_t CommonDenseTable::set_global_lr(float* lr) {
+int32_t CommonDenseTable::SetGlobalLR(float* lr) {
  _global_lr = lr;
-  optimizer_->set_global_lr(_global_lr);
+  optimizer_->SetGlobalLR(_global_lr);
  return 0;
}

int32_t CommonDenseTable::Pull(TableContext& context) {
  CHECK(context.value_type == Dense);
  float* pull_values = context.pull_context.values;
-  return pull_dense(pull_values, context.num);
+  return PullDense(pull_values, context.num);
}

int32_t CommonDenseTable::Push(TableContext& context) {
  CHECK(context.value_type == Dense);
  if (context.push_context.values != nullptr) {
    const float* values = context.push_context.values;
-    return push_dense(values, context.num);
+    return PushDense(values, context.num);
  }
  return 0;
}

-int32_t CommonDenseTable::pull_dense(float* pull_values, size_t num) {
+int32_t CommonDenseTable::PullDense(float* pull_values, size_t num) {
  std::copy(values_[param_idx_].begin(), values_[param_idx_].end(),
            pull_values);
  return 0;
}

-int32_t CommonDenseTable::push_dense_param(const float* values, size_t num) {
+int32_t CommonDenseTable::PushDenseParam(const float* values, size_t num) {
  PADDLE_ENFORCE_GE(
      num, param_dim_,
      paddle::platform::errors::InvalidArgument(
@@ -160,14 +160,14 @@ int32_t CommonDenseTable::push_dense_param(const float* values, size_t num) {
  return 0;
}

-int32_t CommonDenseTable::pour() {
+int32_t CommonDenseTable::Pour() {
  pull_reservoir_.avg();
-  _push_dense(pull_reservoir_.values.data(), pull_reservoir_.values.size());
+  _PushDense(pull_reservoir_.values.data(), pull_reservoir_.values.size());
  pull_reservoir_.reset();
  return 0;
}

-int32_t CommonDenseTable::push_dense(const float* values, size_t num) {
+int32_t CommonDenseTable::PushDense(const float* values, size_t num) {
  if (sync) {
    std::future<int> task =
        _shards_task_pool[0]->enqueue([this, &values]() -> int {
@@ -176,12 +176,12 @@ int32_t CommonDenseTable::push_dense(const float* values, size_t num) {
    });
    task.wait();
  } else {
-    _push_dense(values, num);
+    _PushDense(values, num);
  }
  return 0;
}

-int32_t CommonDenseTable::_push_dense(const float* values, size_t num) {
+int32_t CommonDenseTable::_PushDense(const float* values, size_t num) {
  PADDLE_ENFORCE_GE(
      num, param_dim_,
      paddle::platform::errors::InvalidArgument(
@@ -195,7 +195,7 @@ int32_t CommonDenseTable::_push_dense(const float* values, size_t num) {
        [this, shard_id, &buckets, &values]() -> int {
          auto begin = buckets[shard_id];
          auto end = buckets[shard_id + 1];
-          optimizer_->update(values, param_dim_, begin, end);
+          optimizer_->Update(values, param_dim_, begin, end);
          return 0;
        });
  }
@@ -207,12 +207,12 @@ int32_t CommonDenseTable::_push_dense(const float* values, size_t num) {
  return 0;
}

-int32_t CommonDenseTable::load(const std::string& path,
+int32_t CommonDenseTable::Load(const std::string& path,
                               const std::string& param) {
  if (param_dim_ <= 0) {
    return 0;
  }
-  std::string table_path = table_dir(path);
+  std::string table_path = TableDir(path);
  auto file_list = _afs_client.list(table_path);
  std::sort(file_list.begin(), file_list.end());
  for (auto ff : file_list) {
@@ -314,7 +314,7 @@ int32_t CommonDenseTable::load(const std::string& path,
  return 0;
}

-int32_t CommonDenseTable::save(const std::string& path,
+int32_t CommonDenseTable::Save(const std::string& path,
                               const std::string& param) {
  int save_param = atoi(param.c_str());
  uint32_t feasign_size;
@@ -323,10 +323,10 @@ int32_t CommonDenseTable::save(const std::string& path,
  FsChannelConfig channel_config;
  if (_config.compress_in_save()) {
    channel_config.path = paddle::string::format_string(
-        "%s/part-%03d.gz", table_dir(path).c_str(), _shard_idx);
+        "%s/part-%03d.gz", TableDir(path).c_str(), _shard_idx);
  } else {
    channel_config.path = paddle::string::format_string(
-        "%s/part-%03d", table_dir(path).c_str(), _shard_idx);
+        "%s/part-%03d", TableDir(path).c_str(), _shard_idx);
  }
  _afs_client.remove(channel_config.path);
  channel_config.converter = _value_accesor->Converter(save_param).converter;
diff --git a/paddle/fluid/distributed/ps/table/common_dense_table.h b/paddle/fluid/distributed/ps/table/common_dense_table.h
index cad49a0a449c4..8e4ff1ecaf487 100644
--- a/paddle/fluid/distributed/ps/table/common_dense_table.h
+++ b/paddle/fluid/distributed/ps/table/common_dense_table.h
@@ -34,29 +34,29 @@ class CommonDenseTable : public DenseTable {
 public:
  CommonDenseTable() {}
  virtual ~CommonDenseTable() {}
-  int32_t initialize() override;
-  int32_t initialize_shard() override { return 0; }
-  virtual void create_initializer(const std::string& attr,
-                                  const std::string& name);
-  virtual int32_t initialize_value();
-  virtual int32_t initialize_optimizer();
+  int32_t Initialize() override;
+  int32_t InitializeShard() override { return 0; }
+  virtual void CreateInitializer(const std::string& attr,
+                                 const std::string& name);
+  virtual int32_t InitializeValue();
+  virtual int32_t InitializeOptimizer();
  virtual int32_t Pull(TableContext& context);
  virtual int32_t Push(TableContext& context);
-  int32_t pull_dense(float* pull_values, size_t num) override;
-  int32_t push_dense_param(const float* values, size_t num) override;
-  int32_t push_dense(const float* values, size_t num) override;
-  int32_t pour() override;
-  int32_t set_global_lr(float* lr) override;
+  int32_t PullDense(float* pull_values, size_t num) override;
+  int32_t PushDenseParam(const float* values, size_t num) override;
+  int32_t PushDense(const float* values, size_t num) override;
+  int32_t Pour() override;
+  int32_t SetGlobalLR(float* lr) override;

-  int32_t load(const std::string& path, const std::string& param) override;
-  int32_t save(const std::string& path, const std::string& param) override;
+  int32_t Load(const std::string& path, const std::string& param) override;
+  int32_t Save(const std::string& path, const std::string& param) override;

-  int32_t flush() override { return 0; }
-  int32_t shrink(const std::string& param) override { return 0; }
-  void clear() override { return; }
+  int32_t Flush() override { return 0; }
+  int32_t Shrink(const std::string& param) override { return 0; }
+  void Clear() override { return; }

 protected:
-  int32_t _push_dense(const float* values, size_t num);
+  int32_t _PushDense(const float* values, size_t num);

 private:
  const int task_pool_size_ = 10;
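_PushDense above splits the [0, param_dim_) range into contiguous buckets and enqueues one optimizer update per bucket on the shard thread pool. A self-contained sketch of that slicing, with std::async standing in for the table's ::ThreadPool and a no-op update in place of optimizer_->Update():

#include <future>
#include <vector>

void UpdateSlice(const float* grad, int begin, int end) {
  (void)grad; (void)begin; (void)end;  // stand-in for optimizer_->Update(...)
}

void PushDenseSharded(const float* grad, int dim, int workers) {
  std::vector<int> buckets(workers + 1);
  for (int i = 0; i <= workers; ++i) buckets[i] = dim * i / workers;  // slice edges
  std::vector<std::future<void>> tasks;
  for (int w = 0; w < workers; ++w)
    tasks.push_back(std::async(std::launch::async, UpdateSlice, grad,
                               buckets[w], buckets[w + 1]));
  for (auto& t : tasks) t.wait();  // same barrier as the task.wait() loop above
}

int main() {
  std::vector<float> grad(100, 0.1f);
  PushDenseSharded(grad.data(), 100, 10);
  return 0;
}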
diff --git a/paddle/fluid/distributed/ps/table/common_graph_table.cc b/paddle/fluid/distributed/ps/table/common_graph_table.cc
index dcce46270d026..7aab679954709 100644
--- a/paddle/fluid/distributed/ps/table/common_graph_table.cc
+++ b/paddle/fluid/distributed/ps/table/common_graph_table.cc
@@ -448,7 +448,7 @@ int32_t GraphTable::load_graph_split_config(const std::string &path) {
  return 0;
}

-int32_t GraphTable::load(const std::string &path, const std::string &param) {
+int32_t GraphTable::Load(const std::string &path, const std::string &param) {
  bool load_edge = (param[0] == 'e');
  bool load_node = (param[0] == 'n');
  if (load_edge) {
@@ -1066,11 +1066,11 @@ int32_t GraphTable::pull_graph_list(int start, int total_size,
int32_t GraphTable::get_server_index_by_id(int64_t id) {
  return id % shard_num / shard_num_per_server;
}
-int32_t GraphTable::initialize(const TableParameter &config,
+int32_t GraphTable::Initialize(const TableParameter &config,
                               const FsClientParameter &fs_config) {
  LOG(INFO) << "in graphTable initialize";
  _config = config;
-  if (initialize_accessor() != 0) {
+  if (InitializeAccessor() != 0) {
    LOG(WARNING) << "Table accessor initialize failed";
    return -1;
  }
@@ -1082,9 +1082,9 @@ int32_t GraphTable::initialize(const TableParameter &config,
  auto graph = config.graph_parameter();
  shard_num = _config.shard_num();
  LOG(INFO) << "in graphTable initialize over";
-  return initialize(graph);
+  return Initialize(graph);
}
-int32_t GraphTable::initialize(const GraphParameter &graph) {
+int32_t GraphTable::Initialize(const GraphParameter &graph) {
#ifdef PADDLE_WITH_HETERPS
  if (graph.gpups_mode()) {
    gpups_mode = true;
diff --git a/paddle/fluid/distributed/ps/table/common_graph_table.h b/paddle/fluid/distributed/ps/table/common_graph_table.h
index 72600b42b8282..035a3de3eba63 100644
--- a/paddle/fluid/distributed/ps/table/common_graph_table.h
+++ b/paddle/fluid/distributed/ps/table/common_graph_table.h
@@ -280,7 +280,7 @@ class ScaledLRU {
          }
        }
        auto status =
-            thread_pool->enqueue([this]() -> int { return shrink(); });
+            thread_pool->enqueue([this]() -> int { return Shrink(); });
        status.wait();
      }
    });
@@ -298,7 +298,7 @@ class ScaledLRU {
  LRUResponse insert(size_t index, K *keys, V *data, size_t length) {
    return lru_pool[index].insert(keys, data, length);
  }
-  int shrink() {
+  int Shrink() {
    int node_size = 0;
    for (size_t i = 0; i < lru_pool.size(); i++) {
      node_size += lru_pool[i].node_size - lru_pool[i].remove_count;
@@ -329,7 +329,7 @@ class ScaledLRU {
    if (diff != 0) {
      __sync_fetch_and_add(&global_count, diff);
      if (global_count > int(1.25 * size_limit)) {
-        thread_pool->enqueue([this]() -> int { return shrink(); });
+        thread_pool->enqueue([this]() -> int { return Shrink(); });
      }
    }
  }
@@ -430,11 +430,11 @@ class GraphTable : public SparseTable {
  virtual int32_t get_nodes_ids_by_ranges(
      std::vector<std::pair<int, int>> ranges, std::vector<int64_t> &res);
-  virtual int32_t initialize() { return 0; }
-  virtual int32_t initialize(const TableParameter &config,
+  virtual int32_t Initialize() { return 0; }
+  virtual int32_t Initialize(const TableParameter &config,
                             const FsClientParameter &fs_config);
-  virtual int32_t initialize(const GraphParameter &config);
-  int32_t load(const std::string &path, const std::string &param);
+  virtual int32_t Initialize(const GraphParameter &config);
+  int32_t Load(const std::string &path, const std::string &param);
  int32_t load_graph_split_config(const std::string &path);

  int32_t load_edges(const std::string &path, bool reverse);
@@ -452,26 +452,25 @@ class GraphTable : public SparseTable {
  virtual int32_t Pull(TableContext &context) { return 0; }
  virtual int32_t Push(TableContext &context) { return 0; }

-  virtual int32_t pull_sparse(float *values,
-                              const PullSparseValue &pull_value) {
+  virtual int32_t PullSparse(float *values, const PullSparseValue &pull_value) {
    return 0;
  }

-  virtual int32_t push_sparse(const uint64_t *keys, const float *values,
-                              size_t num) {
+  virtual int32_t PushSparse(const uint64_t *keys, const float *values,
+                             size_t num) {
    return 0;
  }

  virtual int32_t clear_nodes();
-  virtual void clear() {}
-  virtual int32_t flush() { return 0; }
-  virtual int32_t shrink(const std::string &param) { return 0; }
  // specify the save path
-  virtual int32_t save(const std::string &path, const std::string &converter) {
+  virtual void Clear() {}
+  virtual int32_t Flush() { return 0; }
+  virtual int32_t Shrink(const std::string &param) { return 0; }
+  virtual int32_t Save(const std::string &path, const std::string &converter) {
    return 0;
  }
-  virtual int32_t initialize_shard() { return 0; }
-  virtual int32_t set_shard(size_t shard_idx, size_t server_num) {
+  virtual int32_t InitializeShard() { return 0; }
+  virtual int32_t SetShard(size_t shard_idx, size_t server_num) {
    _shard_idx = shard_idx;
    /*
    _shard_num is not used in graph_table, this following operation is for the
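The ScaledLRU above schedules a background Shrink only when the shared node count drifts 25% past its budget, so routine churn stays cheap. A sketch of just that trigger arithmetic (scheduling elided; names are stand-ins):

#include <atomic>

struct ShrinkTrigger {
  std::atomic<int> global_count{0};
  const int size_limit;
  explicit ShrinkTrigger(int limit) : size_limit(limit) {}

  // Mirrors the check in ScaledLRU: report a per-shard size delta, and
  // return whether a background Shrink should be enqueued now.
  bool Report(int diff) {
    if (diff == 0) return false;
    return global_count.fetch_add(diff) + diff > int(1.25 * size_limit);
  }
};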
diff --git a/paddle/fluid/distributed/ps/table/common_sparse_table.cc b/paddle/fluid/distributed/ps/table/common_sparse_table.cc
index 1fc8adc2b92eb..6b3d3a6ea1584 100644
--- a/paddle/fluid/distributed/ps/table/common_sparse_table.cc
+++ b/paddle/fluid/distributed/ps/table/common_sparse_table.cc
@@ -167,7 +167,7 @@ int64_t CommonSparseTable::LoadFromText(
  return 0;
}

-int32_t CommonSparseTable::initialize() {
+int32_t CommonSparseTable::Initialize() {
  _shards_task_pool.resize(task_pool_size_);
  for (int i = 0; i < _shards_task_pool.size(); ++i) {
    _shards_task_pool[i].reset(new ::ThreadPool(1));
@@ -200,15 +200,15 @@ int32_t CommonSparseTable::initialize() {
    offset += dim;
  }

-  initialize_value();
-  initialize_optimizer();
-  initialize_recorder();
+  InitializeValue();
+  InitializeOptimizer();
+  InitializeRecorder();
  return 0;
}

-int32_t CommonSparseTable::initialize_recorder() { return 0; }
+int32_t CommonSparseTable::InitializeRecorder() { return 0; }

-int32_t CommonSparseTable::initialize_value() {
+int32_t CommonSparseTable::InitializeValue() {
  auto common = _config.common();
  shard_values_.reserve(task_pool_size_);

@@ -223,18 +223,18 @@ int32_t CommonSparseTable::initialize_value() {
  return 0;
}

-int32_t CommonSparseTable::initialize_optimizer() {
+int32_t CommonSparseTable::InitializeOptimizer() {
  auto common = _config.common();
  auto name = common.name();

  if (name == "sgd") {
    optimizer_ = std::make_shared<SSGD>(value_names_, value_dims_,
                                        value_offsets_, value_idx_);
-    optimizer_->set_global_lr(_global_lr);
+    optimizer_->SetGlobalLR(_global_lr);
  } else if (name == "adam") {
    optimizer_ = std::make_shared<SAdam>(value_names_, value_dims_,
                                         value_offsets_, value_idx_);
-    optimizer_->set_global_lr(_global_lr);
+    optimizer_->SetGlobalLR(_global_lr);
  } else if (name == "sum") {
    optimizer_ = std::make_shared<SSUM>(value_names_, value_dims_,
                                        value_offsets_, value_idx_);
@@ -246,13 +246,13 @@ int32_t CommonSparseTable::initialize_optimizer() {
  return 0;
}

-int32_t CommonSparseTable::set_global_lr(float* lr) {
+int32_t CommonSparseTable::SetGlobalLR(float* lr) {
  _global_lr = lr;
-  optimizer_->set_global_lr(_global_lr);
+  optimizer_->SetGlobalLR(_global_lr);
  return 0;
}

-int32_t CommonSparseTable::load(const std::string& dirname,
+int32_t CommonSparseTable::Load(const std::string& dirname,
                                const std::string& param) {
  auto begin = GetCurrentUS();
  rwlock_->WRLock();
@@ -276,7 +276,7 @@ int32_t CommonSparseTable::load(const std::string& dirname,
  return 0;
}

-int32_t CommonSparseTable::save(const std::string& dirname,
+int32_t CommonSparseTable::Save(const std::string& dirname,
                                const std::string& param) {
  auto begin = GetCurrentUS();
  rwlock_->WRLock();
@@ -322,7 +322,7 @@ int32_t CommonSparseTable::save(const std::string& dirname,
  return 0;
}

-std::pair<int64_t, int64_t> CommonSparseTable::print_table_stat() {
+std::pair<int64_t, int64_t> CommonSparseTable::PrintTableStat() {
  int64_t feasign_size = 0;
  int64_t mf_size = 0;

@@ -335,7 +335,7 @@ std::pair<int64_t, int64_t> CommonSparseTable::print_table_stat() {
  return {feasign_size, mf_size};
}

-int32_t CommonSparseTable::pour() {
+int32_t CommonSparseTable::Pour() {
  std::vector<float> values;
  std::vector<uint64_t> keys;

@@ -349,7 +349,7 @@ int32_t CommonSparseTable::pour() {
    std::copy(reservoir.values.begin(), reservoir.values.end(),
              std::back_inserter(values));
  }
-  _push_sparse(keys.data(), values.data(), pull_reservoir_.size());
+  _PushSparse(keys.data(), values.data(), pull_reservoir_.size());
  pull_reservoir_.clear();

  return 0;
@@ -360,11 +360,11 @@ int32_t CommonSparseTable::Pull(TableContext& context) {
  if (context.use_ptr) {
    char** pull_values = context.pull_context.ptr_values;
    const uint64_t* keys = context.pull_context.keys;
-    return pull_sparse_ptr(pull_values, keys, context.num);
+    return PullSparsePtr(pull_values, keys, context.num);
  } else {
    float* pull_values = context.pull_context.values;
    const PullSparseValue& pull_value = context.pull_context.pull_value;
-    return pull_sparse(pull_values, pull_value);
+    return PullSparse(pull_values, pull_value);
  }
}

@@ -373,16 +373,16 @@ int32_t CommonSparseTable::Push(TableContext& context) {
  if (context.push_context.values != nullptr) {
    const float* values = context.push_context.values;
    const uint64_t* keys = context.push_context.keys;
-    return push_sparse(keys, values, context.num);
+    return PushSparse(keys, values, context.num);
  } else {
    const float** values = context.push_context.ptr_values;
    const uint64_t* keys = context.push_context.keys;
-    return push_sparse(keys, values, context.num);
+    return PushSparse(keys, values, context.num);
  }
}

-int32_t CommonSparseTable::pull_sparse(float* pull_values,
-                                       const PullSparseValue& pull_value) {
+int32_t CommonSparseTable::PullSparse(float* pull_values,
+                                      const PullSparseValue& pull_value) {
  auto shard_num = task_pool_size_;
  std::vector<std::future<int>> tasks(shard_num);

@@ -421,8 +421,8 @@ int32_t CommonSparseTable::pull_sparse(float* pull_values,
  return 0;
}

-int32_t CommonSparseTable::pull_sparse_ptr(char** pull_values,
-                                           const uint64_t* keys, size_t num) {
+int32_t CommonSparseTable::PullSparsePtr(char** pull_values,
+                                         const uint64_t* keys, size_t num) {
  std::vector<std::vector<int>> offset_bucket;
  offset_bucket.resize(task_pool_size_);

@@ -458,8 +458,8 @@ int32_t CommonSparseTable::pull_sparse_ptr(char** pull_values,
  return 0;
}

-int32_t CommonSparseTable::_push_sparse(const uint64_t* keys,
-                                        const float* values, size_t num) {
+int32_t CommonSparseTable::_PushSparse(const uint64_t* keys,
+                                       const float* values, size_t num) {
  std::vector<std::vector<int>> offset_bucket;
  offset_bucket.resize(task_pool_size_);

@@ -474,7 +474,7 @@ int32_t CommonSparseTable::_push_sparse(const uint64_t* keys,
    tasks[shard_id] = _shards_task_pool[shard_id]->enqueue(
        [this, shard_id, &keys, &values, num, &offset_bucket]() -> int {
          auto& offsets = offset_bucket[shard_id];
-          optimizer_->update(keys, values, num, offsets,
+          optimizer_->Update(keys, values, num, offsets,
                             shard_values_[shard_id].get());
          return 0;
        });
@@ -486,8 +486,8 @@ int32_t CommonSparseTable::_push_sparse(const uint64_t* keys,
  return 0;
}

-int32_t CommonSparseTable::push_sparse(const uint64_t* keys,
-                                       const float* values, size_t num) {
+int32_t CommonSparseTable::PushSparse(const uint64_t* keys, const float* values,
+                                      size_t num) {
  if (sync) {
    std::future<int> task =
        _shards_task_pool[0]->enqueue([this, &keys, &values, num]() -> int {
@@ -506,20 +506,20 @@ int32_t CommonSparseTable::push_sparse(const uint64_t* keys,
    });
    task.wait();
  } else {
-    _push_sparse(keys, values, num);
+    _PushSparse(keys, values, num);
  }

  return 0;
}
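PullSparse, PullSparsePtr and _PushSparse above all start the same way: route each key to shard key % task_pool_size_ and remember the key's position in the request so results can be scattered back in order. A sketch of that bucketing step on its own:

#include <cstdint>
#include <vector>

std::vector<std::vector<int>> BucketByShard(const uint64_t* keys, size_t num,
                                            int shard_num) {
  std::vector<std::vector<int>> offset_bucket(shard_num);
  for (size_t i = 0; i < num; ++i)
    offset_bucket[keys[i] % shard_num].push_back(static_cast<int>(i));
  return offset_bucket;  // the table then enqueues one task per shard
}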
-int32_t CommonSparseTable::push_sparse(const uint64_t* keys,
-                                       const float** values, size_t num) {
-  _push_sparse(keys, values, num);
+int32_t CommonSparseTable::PushSparse(const uint64_t* keys,
+                                      const float** values, size_t num) {
+  _PushSparse(keys, values, num);
  return 0;
}

-int32_t CommonSparseTable::_push_sparse(const uint64_t* keys,
-                                        const float** values, size_t num) {
+int32_t CommonSparseTable::_PushSparse(const uint64_t* keys,
+                                       const float** values, size_t num) {
  std::vector<std::vector<int>> offset_bucket;
  offset_bucket.resize(task_pool_size_);

@@ -536,7 +536,7 @@ int32_t CommonSparseTable::_push_sparse(const uint64_t* keys,
          auto& offsets = offset_bucket[shard_id];
          for (size_t i = 0; i < offsets.size(); ++i) {
            std::vector<uint64_t> tmp_off = {0};
-            optimizer_->update(keys + offsets[i], values[offsets[i]], num,
+            optimizer_->Update(keys + offsets[i], values[offsets[i]], num,
                               tmp_off, shard_values_[shard_id].get());
          }
          return 0;
@@ -549,8 +549,8 @@ int32_t CommonSparseTable::_push_sparse(const uint64_t* keys,
  return 0;
}

-int32_t CommonSparseTable::push_sparse_param(const uint64_t* keys,
-                                             const float* values, size_t num) {
+int32_t CommonSparseTable::PushSparseParam(const uint64_t* keys,
+                                           const float* values, size_t num) {
  std::vector<std::vector<int>> offset_bucket;
  offset_bucket.resize(task_pool_size_);

@@ -585,21 +585,21 @@ int32_t CommonSparseTable::push_sparse_param(const uint64_t* keys,
  return 0;
}

-int32_t CommonSparseTable::flush() { return 0; }
+int32_t CommonSparseTable::Flush() { return 0; }

-int32_t CommonSparseTable::shrink(const std::string& param) {
+int32_t CommonSparseTable::Shrink(const std::string& param) {
  int threshold = std::stoi(param);
-  VLOG(3) << "sparse table shrink: " << threshold;
+  VLOG(3) << "sparse table Shrink: " << threshold;

  for (int shard_id = 0; shard_id < task_pool_size_; ++shard_id) {
-    // shrink
-    VLOG(4) << shard_id << " " << task_pool_size_ << " begin shrink";
+    // Shrink
+    VLOG(4) << shard_id << " " << task_pool_size_ << " begin Shrink";
    shard_values_[shard_id]->Shrink(threshold);
  }
  return 0;
}

-void CommonSparseTable::clear() { VLOG(0) << "clear coming soon"; }
+void CommonSparseTable::Clear() { VLOG(0) << "clear coming soon"; }

}  // namespace distributed
}  // namespace paddle
diff --git a/paddle/fluid/distributed/ps/table/common_sparse_table.h b/paddle/fluid/distributed/ps/table/common_sparse_table.h
index 138c544742066..f6deaf0a82b13 100644
--- a/paddle/fluid/distributed/ps/table/common_sparse_table.h
+++ b/paddle/fluid/distributed/ps/table/common_sparse_table.h
@@ -114,25 +114,23 @@ class CommonSparseTable : public SparseTable {
  virtual ~CommonSparseTable() {}

  // unused method begin
-  virtual int32_t pull_dense(float* pull_values, size_t num) { return 0; }
-  virtual int32_t push_dense_param(const float* values, size_t num) {
-    return 0;
-  }
-  virtual int32_t push_dense(const float* values, size_t num) { return 0; }
+  virtual int32_t PullDense(float* pull_values, size_t num) { return 0; }
+  virtual int32_t PushDenseParam(const float* values, size_t num) { return 0; }
+  virtual int32_t PushDense(const float* values, size_t num) { return 0; }
  // unused method end

  virtual int32_t Pull(TableContext& context);
  virtual int32_t Push(TableContext& context);

-  virtual int32_t initialize();
-  virtual int32_t initialize_shard() { return 0; }
-  virtual int32_t initialize_value();
-  virtual int32_t initialize_optimizer();
-  virtual int32_t initialize_recorder();
+  virtual int32_t Initialize();
+  virtual int32_t InitializeShard() { return 0; }
+  virtual int32_t InitializeValue();
+  virtual int32_t InitializeOptimizer();
+  virtual int32_t InitializeRecorder();

-  virtual int32_t load(const std::string& path, const std::string& param);
+  virtual int32_t Load(const std::string& path, const std::string& param);

-  virtual int32_t save(const std::string& path, const std::string& param);
+  virtual int32_t Save(const std::string& path, const std::string& param);

  void SaveMetaToText(std::ostream* os, const CommonAccessorParameter& common,
                      const size_t shard_idx, const int64_t total);
@@ -150,34 +148,34 @@ class CommonSparseTable : public SparseTable {
      const int pserver_id, const int pserver_num, const int local_shard_num,
      std::vector<std::shared_ptr<ValueBlock>>* blocks);

-  virtual std::pair<int64_t, int64_t> print_table_stat();
-  virtual int32_t pull_sparse(float* values, const PullSparseValue& pull_value);
+  virtual std::pair<int64_t, int64_t> PrintTableStat();
+  virtual int32_t PullSparse(float* values, const PullSparseValue& pull_value);

-  virtual int32_t pull_sparse_ptr(char** pull_values, const uint64_t* keys,
-                                  size_t num);
+  virtual int32_t PullSparsePtr(char** pull_values, const uint64_t* keys,
+                                size_t num);

-  virtual int32_t push_sparse(const uint64_t* keys, const float* values,
-                              size_t num);
+  virtual int32_t PushSparse(const uint64_t* keys, const float* values,
+                             size_t num);

-  virtual int32_t push_sparse(const uint64_t* keys, const float** values,
-                              size_t num);
+  virtual int32_t PushSparse(const uint64_t* keys, const float** values,
+                             size_t num);

  // only for sparse geo table
-  virtual int32_t push_sparse_param(const uint64_t* keys, const float* values,
-                                    size_t num);
+  virtual int32_t PushSparseParam(const uint64_t* keys, const float* values,
+                                  size_t num);

-  virtual int32_t set_global_lr(float* lr) override;
+  virtual int32_t SetGlobalLR(float* lr) override;

-  virtual int32_t pour();
-  virtual int32_t flush();
-  virtual int32_t shrink(const std::string& param);
-  virtual void clear();
+  virtual int32_t Pour();
+  virtual int32_t Flush();
+  virtual int32_t Shrink(const std::string& param);
+  virtual void Clear();

 protected:
-  virtual int32_t _push_sparse(const uint64_t* keys, const float* values,
-                               size_t num);
-  virtual int32_t _push_sparse(const uint64_t* keys, const float** values,
-                               size_t num);
+  virtual int32_t _PushSparse(const uint64_t* keys, const float* values,
+                              size_t num);
+  virtual int32_t _PushSparse(const uint64_t* keys, const float** values,
+                              size_t num);

 protected:
  const int task_pool_size_ = 11;
diff --git a/paddle/fluid/distributed/ps/table/common_table.h b/paddle/fluid/distributed/ps/table/common_table.h
index 3d291c0152246..f5e263e8e7189 100644
--- a/paddle/fluid/distributed/ps/table/common_table.h
+++ b/paddle/fluid/distributed/ps/table/common_table.h
@@ -71,11 +71,11 @@ class SparseTable : public Table {
  SparseTable() {}
  virtual ~SparseTable() {}

-  virtual void *get_shard(size_t shard_idx) { return 0; }
+  virtual void *GetShard(size_t shard_idx) { return 0; }

-  int32_t pull_dense(float *values, size_t num) override { return 0; }
+  int32_t PullDense(float *values, size_t num) override { return 0; }

-  int32_t push_dense(const float *values, size_t num) override { return 0; }
+  int32_t PushDense(const float *values, size_t num) override { return 0; }

  static int32_t sparse_local_shard_num(uint32_t shard_num,
                                        uint32_t server_num) {
@@ -97,19 +97,17 @@ class DenseTable : public Table {
  DenseTable() {}
  virtual ~DenseTable() {}

-  virtual void *get_shard(size_t shard_idx) { return 0; }
-  int32_t pull_sparse(float *values,
-                      const PullSparseValue &pull_value) override {
+  virtual void *GetShard(size_t shard_idx) { return 0; }
+  int32_t PullSparse(float *values,
+                     const PullSparseValue &pull_value) override {
    return 0;
  }
-  int32_t push_sparse(const uint64_t *keys, const float *values,
-                      size_t num) override {
+  int32_t PushSparse(const uint64_t *keys, const float *values,
+                     size_t num) override {
    return 0;
  }
-  int32_t push_dense_param(const float *values, size_t num) override {
-    return 0;
-  }
-  int32_t shrink(const std::string &param) override { return 0; }
+  int32_t PushDenseParam(const float *values, size_t num) override { return 0; }
+  int32_t Shrink(const std::string &param) override { return 0; }
};

class BarrierTable : public Table {
@@ -117,44 +115,42 @@ class BarrierTable : public Table {
  BarrierTable() {}
  virtual ~BarrierTable() {}

-  virtual void *get_shard(size_t shard_idx) { return 0; }
+  virtual void *GetShard(size_t shard_idx) { return 0; }

  virtual int32_t Pull(TableContext &context) { return 0; }
  virtual int32_t Push(TableContext &context) { return 0; }

-  int32_t pull_dense(float *values, size_t num) override { return 0; }
+  int32_t PullDense(float *values, size_t num) override { return 0; }

-  int32_t push_dense(const float *values, size_t num) override { return 0; }
+  int32_t PushDense(const float *values, size_t num) override { return 0; }

-  int32_t pull_sparse(float *values,
-                      const PullSparseValue &pull_value) override {
-    return 0;
-  }
-  int32_t push_sparse(const uint64_t *keys, const float *values,
-                      size_t num) override {
+  int32_t PullSparse(float *values,
+                     const PullSparseValue &pull_value) override {
    return 0;
  }
-  int32_t push_dense_param(const float *values, size_t num) override {
+  int32_t PushSparse(const uint64_t *keys, const float *values,
+                     size_t num) override {
    return 0;
  }
-  int32_t shrink(const std::string &param) override { return 0; }
-  virtual void clear() {}
-  virtual int32_t flush() { return 0; }
-  virtual int32_t load(const std::string &path, const std::string &param) {
+  int32_t PushDenseParam(const float *values, size_t num) override {
    return 0;
  }
+  int32_t Shrink(const std::string &param) override { return 0; }
+  virtual void Clear() {}
+  virtual int32_t Flush() { return 0; }
+  virtual int32_t Load(const std::string &path, const std::string &param) {
    return 0;
  }
-  virtual int32_t save(const std::string &path, const std::string &param) {
+  virtual int32_t Save(const std::string &path, const std::string &param) {
    return 0;
  }
-  virtual int32_t initialize_shard() { return 0; }
+  virtual int32_t InitializeShard() { return 0; }

-  virtual int32_t initialize() override;
+  virtual int32_t Initialize() override;
  // only for barrier
  // 0: send_barrier 1: recv_barrier 2: complete
-  virtual int32_t barrier(const uint32_t trainer_id,
+  virtual int32_t Barrier(const uint32_t trainer_id,
                          const std::string barrier_type) override;

-  virtual int32_t set_table_map(
+  virtual int32_t SetTableMap(
      std::unordered_map<uint32_t, std::shared_ptr<Table>> *table_map) override;

 private:
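SparseTable::sparse_local_shard_num above decides how many of the table's global shards each server hosts. A hedged sketch of the ceil-style division it performs (body reconstructed under that assumption, consistent with the "+ 1" rounding used elsewhere in this patch):

#include <cassert>
#include <cstdint>

int32_t SparseLocalShardNum(uint32_t shard_num, uint32_t server_num) {
  if (shard_num % server_num == 0) return shard_num / server_num;
  return shard_num / server_num + 1;  // uneven split: round up per server
}

int main() {
  assert(SparseLocalShardNum(12, 4) == 3);
  assert(SparseLocalShardNum(10, 4) == 3);  // the last server ends up underfull
  return 0;
}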
diff --git a/paddle/fluid/distributed/ps/table/depends/dense.h b/paddle/fluid/distributed/ps/table/depends/dense.h
index 8661eb1feecc8..258c0f4b6a4e6 100644
--- a/paddle/fluid/distributed/ps/table/depends/dense.h
+++ b/paddle/fluid/distributed/ps/table/depends/dense.h
@@ -34,9 +34,9 @@ class DenseOptimizer {
  DenseOptimizer() {}
  explicit DenseOptimizer(const CommonAccessorParameter& accessor,
                          std::vector<std::vector<float>>* values) {}
-  virtual void update(const float* update_values, size_t num, int begin,
+  virtual void Update(const float* update_values, size_t num, int begin,
                      int end) = 0;
-  virtual void set_global_lr(float* lr) { global_learning_rate_ = lr; }
+  virtual void SetGlobalLR(float* lr) { global_learning_rate_ = lr; }

 protected:
  float* global_learning_rate_;
@@ -55,7 +55,7 @@ class DSUM : public DenseOptimizer {
    }
  }

-  void update(const float* update_values, size_t num, int begin,
+  void Update(const float* update_values, size_t num, int begin,
              int end) override {
    auto update_numel = end - begin;
    GetBlas<float>().VADD(update_numel, update_values + begin, param + begin,
@@ -81,7 +81,7 @@ class DSGD : public DenseOptimizer {
    }
  }

-  void update(const float* update_values, size_t num, int begin,
+  void Update(const float* update_values, size_t num, int begin,
              int end) override {
    auto update_numel = end - begin;
    std::vector<float> grads;
@@ -134,7 +134,7 @@ class DAdam : public DenseOptimizer {

  // make sure common_dense_table.task_pool_size_ == 1;
  // otherwise, task_pool_size_ times beta1_pow/beta2_pow multiplication
-  void update(const float* update_values, size_t num, int begin,
+  void Update(const float* update_values, size_t num, int begin,
              int end) override {
    auto update_numel = end - begin;
    std::vector<float> grad, grad2, tmp;
@@ -214,7 +214,7 @@ class DAdamD2Sum : public DenseOptimizer {
    }
  }

-  void update(const float* update_values, size_t num, int begin,
+  void Update(const float* update_values, size_t num, int begin,
              int end) override {
    auto update_numel = end - begin;
    Eigen::Map<Eigen::MatrixXf> mat_ada_g2sum(ada_g2sum + begin, 1,
@@ -276,7 +276,7 @@ class DSummary : public DenseOptimizer {
    }
  }

-  void update(const float* update_values, size_t num, int begin,
+  void Update(const float* update_values, size_t num, int begin,
              int end) override {
    auto update_numel = end - begin;
    Eigen::Map<Eigen::MatrixXf> mat_w(param + begin, 1, update_numel);
diff --git a/paddle/fluid/distributed/ps/table/depends/sparse.h b/paddle/fluid/distributed/ps/table/depends/sparse.h
index d4ea7829e45f8..7eed5ab6c794b 100644
--- a/paddle/fluid/distributed/ps/table/depends/sparse.h
+++ b/paddle/fluid/distributed/ps/table/depends/sparse.h
@@ -40,11 +40,11 @@ class SparseOptimizer {
        value_offsets_(value_offsets),
        value_idx_(value_idx) {}

-  virtual void update(const uint64_t* keys, const float* update_values,
+  virtual void Update(const uint64_t* keys, const float* update_values,
                      size_t num, const std::vector<uint64_t>& offsets,
                      ValueBlock* block) = 0;

-  virtual void set_global_lr(float* lr) { global_learning_rate_ = lr; }
+  virtual void SetGlobalLR(float* lr) { global_learning_rate_ = lr; }

  const std::vector<std::string>& value_names_;
  const std::vector<int>& value_dims_;
@@ -70,7 +70,7 @@ class SSUM : public SparseOptimizer {
    update_numel = value_dims.at(idx);
  }

-  void update(const uint64_t* keys, const float* update_values, size_t num,
+  void Update(const uint64_t* keys, const float* update_values, size_t num,
              const std::vector<uint64_t>& offsets,
              ValueBlock* block) override {
    auto blas = GetBlas<float>();
@@ -100,7 +100,7 @@ class SSGD : public SparseOptimizer {
    lr_offset = value_offsets.at(idx);
  }

-  void update(const uint64_t* keys, const float* update_values, size_t num,
+  void Update(const uint64_t* keys, const float* update_values, size_t num,
              const std::vector<uint64_t>& offsets,
              ValueBlock* block) override {
    auto blas = GetBlas<float>();
@@ -156,7 +156,7 @@ class SAdam : public SparseOptimizer {
    epsilon = 1.0e-8;
  }

-  void update(const uint64_t* keys, const float* update_values, size_t num,
+  void Update(const uint64_t* keys, const float* update_values, size_t num,
              const std::vector<uint64_t>& offsets,
              ValueBlock* block) override {
    auto blas = GetBlas<float>();
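Every optimizer in these two headers implements the same Update contract: apply a gradient slice to the parameters it owns, either a dense range [begin, end) or a set of sparse offsets. A minimal dense example in that shape (plain fixed-rate SGD; the real DSGD also consults global_learning_rate_ and a per-table learning-rate vector):

#include <vector>

struct MiniDSGD {
  std::vector<float>* param;
  float lr = 0.1f;

  void Update(const float* update_values, size_t /*num*/, int begin, int end) {
    for (int i = begin; i < end; ++i)
      (*param)[i] -= lr * update_values[i];  // descend along the gradient slice
  }
};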
diff --git a/paddle/fluid/distributed/ps/table/memory_sparse_geo_table.cc b/paddle/fluid/distributed/ps/table/memory_sparse_geo_table.cc
index f16f4fc7f34a5..979e1c482547c 100644
--- a/paddle/fluid/distributed/ps/table/memory_sparse_geo_table.cc
+++ b/paddle/fluid/distributed/ps/table/memory_sparse_geo_table.cc
@@ -17,11 +17,10 @@
namespace paddle {
namespace distributed {

-int32_t MemorySparseGeoTable::push_sparse_param(const uint64_t* keys,
-                                                const float* values,
-                                                size_t num) {
-  VLOG(5) << "DEBUG MemorySparseGeoTable::push_sparse_param begin "
-             "push_sparse_param "
+int32_t MemorySparseGeoTable::PushSparseParam(const uint64_t* keys,
+                                              const float* values, size_t num) {
+  VLOG(5) << "DEBUG MemorySparseGeoTable::PushSparseParam begin "
+             "PushSparseParam "
          << num;
  auto shard_num = _task_pool_size;
  std::vector<std::vector<int>> offset_bucket;
@@ -31,8 +30,8 @@ int32_t MemorySparseGeoTable::push_sparse_param(const uint64_t* keys,
    auto y = keys[x] % shard_num;
    offset_bucket[y].push_back(x);
    if (x < 10) {
-      VLOG(5) << "DEBUG MemorySparseGeoTable::push_sparse_param key: "
-              << keys[x] << " shard: " << y;
+      VLOG(5) << "DEBUG MemorySparseGeoTable::PushSparseParam key: " << keys[x]
+              << " shard: " << y;
    }
  }

@@ -51,8 +50,8 @@ int32_t MemorySparseGeoTable::push_sparse_param(const uint64_t* keys,
        feature_value.resize(_dim);
        std::copy_n(values + _dim * offset, _dim, feature_value.data());
        if (i < 10) {
-          VLOG(5) << "MemorySparseGeoTable::push_sparse_param "
-                     "push_sparse_param key "
+          VLOG(5) << "MemorySparseGeoTable::PushSparseParam "
+                     "PushSparseParam key "
                  << id << " value[0]: " << (values + _dim * offset)[0]
                  << " data: " << feature_value.data()[0]
                  << " value[-1]: " << (values + _dim * offset)[_dim - 1]
@@ -69,9 +68,9 @@ int32_t MemorySparseGeoTable::push_sparse_param(const uint64_t* keys,
  return 0;
}

-int32_t MemorySparseGeoTable::pull_geo_param(const uint32_t trainer_id,
-                                             std::vector<float>* values,
-                                             std::vector<uint64_t>* ids) {
+int32_t MemorySparseGeoTable::PullGeoParam(const uint32_t trainer_id,
+                                           std::vector<float>* values,
+                                           std::vector<uint64_t>* ids) {
  _geo_recorder->GetAndClear(trainer_id, ids);
  VLOG(5)
      << "DEBUG MemorySparseGeoTable::pull_geo_param pull_geo_param trainer_id "
@@ -86,23 +85,23 @@ int32_t MemorySparseGeoTable::pull_geo_param(const uint32_t trainer_id,
  pull_value.frequencies_ = frequencies.data();

  values->resize(ids->size() * _dim);
-  pull_sparse(values->data(), pull_value);
+  PullSparse(values->data(), pull_value);
  return 0;
}

-int32_t MemorySparseGeoTable::push_sparse(const uint64_t* keys,
-                                          const float* values, size_t num) {
-  VLOG(5) << "DEBUG MemorySparseGeoTable::push_sparse keys[0]" << keys[0]
+int32_t MemorySparseGeoTable::PushSparse(const uint64_t* keys,
+                                         const float* values, size_t num) {
+  VLOG(5) << "DEBUG MemorySparseGeoTable::PushSparse keys[0]" << keys[0]
          << " key_num: " << num;
  std::vector<uint64_t> ids;
  ids.resize(num);
  std::copy_n(keys, num, ids.begin());
  _geo_recorder->Update(ids);
-  _push_sparse(keys, values, num);
+  _PushSparse(keys, values, num);
  return 0;
}

-int32_t MemorySparseGeoTable::initialize() {
+int32_t MemorySparseGeoTable::Initialize() {
  if (!_geo_recorder) {
    auto trainers = _config.common().trainer_num();
    _geo_recorder = std::make_shared<GeoRecorder>(trainers);
@@ -118,8 +117,8 @@ int32_t MemorySparseGeoTable::initialize() {
  return 0;
}
" << key; itr = local_shard.find(key); } memcpy(select_data, itr.value().data(), _dim * sizeof(float)); - VLOG(5) << "DEBUG MemorySparseGeoTable::pull_sparse key: " << key + VLOG(5) << "DEBUG MemorySparseGeoTable::PullSparse key: " << key << " select_data[0] " << select_data[0] << " value[0]: " << itr.value().data()[0]; } @@ -167,8 +166,8 @@ int32_t MemorySparseGeoTable::pull_sparse(float* pull_values, return 0; } -int32_t MemorySparseGeoTable::_push_sparse(const uint64_t* keys, - const float* values, size_t num) { +int32_t MemorySparseGeoTable::_PushSparse(const uint64_t* keys, + const float* values, size_t num) { auto shard_num = _task_pool_size; std::vector> tasks(shard_num); std::vector>> task_keys(shard_num); diff --git a/paddle/fluid/distributed/ps/table/memory_sparse_geo_table.h b/paddle/fluid/distributed/ps/table/memory_sparse_geo_table.h index 3b43f99543fdd..1a74df32db8e7 100644 --- a/paddle/fluid/distributed/ps/table/memory_sparse_geo_table.h +++ b/paddle/fluid/distributed/ps/table/memory_sparse_geo_table.h @@ -40,31 +40,31 @@ class MemorySparseGeoTable : public SparseTable { MemorySparseGeoTable() { _geo_recorder = nullptr; } virtual ~MemorySparseGeoTable() {} - virtual int32_t initialize(); - virtual int32_t initialize_shard() { return 0; } - virtual int32_t load(const std::string& path, const std::string& param) { + virtual int32_t Initialize(); + virtual int32_t InitializeShard() { return 0; } + virtual int32_t Load(const std::string& path, const std::string& param) { return 0; } - virtual int32_t save(const std::string& path, const std::string& param) { + virtual int32_t Save(const std::string& path, const std::string& param) { return 0; } virtual int32_t Pull(TableContext& context) { return 0; } virtual int32_t Push(TableContext& context) { return 0; } - virtual int32_t flush() { return 0; } - virtual int32_t shrink(const std::string& param) { return 0; } - virtual void clear() { return; } - virtual int32_t pull_sparse(float* values, const PullSparseValue& pull_value); + virtual int32_t Flush() { return 0; } + virtual int32_t Shrink(const std::string& param) { return 0; } + virtual void Clear() { return; } + virtual int32_t PullSparse(float* values, const PullSparseValue& pull_value); - int32_t push_sparse_param(const uint64_t* keys, const float* values, - size_t num); + int32_t PushSparseParam(const uint64_t* keys, const float* values, + size_t num); // TODO(zhaocaibei123): change to pull_sparse, and rename pull_sparse - int32_t pull_geo_param(const uint32_t trainer_id, std::vector* values, - std::vector* keys); + int32_t PullGeoParam(const uint32_t trainer_id, std::vector* values, + std::vector* keys); - int32_t push_sparse(const uint64_t* keys, const float* values, - size_t num) override; + int32_t PushSparse(const uint64_t* keys, const float* values, + size_t num) override; - int32_t _push_sparse(const uint64_t* keys, const float* values, size_t num); + int32_t _PushSparse(const uint64_t* keys, const float* values, size_t num); // int32_t _pull_sparse(float* pull_values, const PullSparseValue& // pull_value); diff --git a/paddle/fluid/distributed/ps/table/memory_sparse_table.cc b/paddle/fluid/distributed/ps/table/memory_sparse_table.cc index 61ea2f8f2007e..97e3c008d9478 100644 --- a/paddle/fluid/distributed/ps/table/memory_sparse_table.cc +++ b/paddle/fluid/distributed/ps/table/memory_sparse_table.cc @@ -31,7 +31,7 @@ bool FLAGS_pserver_create_value_when_push = true; int FLAGS_pserver_table_save_max_retry = 3; bool FLAGS_pserver_enable_create_feasign_randomly = false; 
diff --git a/paddle/fluid/distributed/ps/table/memory_sparse_table.cc b/paddle/fluid/distributed/ps/table/memory_sparse_table.cc
index 61ea2f8f2007e..97e3c008d9478 100644
--- a/paddle/fluid/distributed/ps/table/memory_sparse_table.cc
+++ b/paddle/fluid/distributed/ps/table/memory_sparse_table.cc
@@ -31,7 +31,7 @@ bool FLAGS_pserver_create_value_when_push = true;
int FLAGS_pserver_table_save_max_retry = 3;
bool FLAGS_pserver_enable_create_feasign_randomly = false;

-int32_t MemorySparseTable::initialize() {
+int32_t MemorySparseTable::Initialize() {
  _shards_task_pool.resize(_task_pool_size);
  for (int i = 0; i < _shards_task_pool.size(); ++i) {
    _shards_task_pool[i].reset(new ::ThreadPool(1));
@@ -39,12 +39,12 @@ int32_t MemorySparseTable::initialize() {
  auto& profiler = CostProfiler::instance();
  profiler.register_profiler("pserver_sparse_update_all");
  profiler.register_profiler("pserver_sparse_select_all");
-  initialize_value();
+  InitializeValue();
  VLOG(0) << "initalize MemorySparseTable succ";
  return 0;
}

-int32_t MemorySparseTable::initialize_value() {
+int32_t MemorySparseTable::InitializeValue() {
  _sparse_table_shard_num = static_cast<int>(_config.shard_num());
  _avg_local_shard_num =
      SparseTable::sparse_local_shard_num(_sparse_table_shard_num, _shard_num);
@@ -64,14 +64,14 @@ int32_t MemorySparseTable::initialize_value() {
  return 0;
}

-int32_t MemorySparseTable::load(const std::string& path,
+int32_t MemorySparseTable::Load(const std::string& path,
                                const std::string& param) {
-  std::string table_path = table_dir(path);
+  std::string table_path = TableDir(path);
  auto file_list = _afs_client.list(table_path);

  std::sort(file_list.begin(), file_list.end());
  for (auto file : file_list) {
-    VLOG(1) << "MemorySparseTable::load() file list: " << file;
+    VLOG(1) << "MemorySparseTable::Load() file list: " << file;
  }

  int load_param = atoi(param.c_str());
@@ -154,9 +154,9 @@ int32_t MemorySparseTable::load(const std::string& path,
  return 0;
}

-int32_t MemorySparseTable::load_local_fs(const std::string& path,
-                                         const std::string& param) {
-  std::string table_path = table_dir(path);
+int32_t MemorySparseTable::LoadLocalFS(const std::string& path,
+                                       const std::string& param) {
+  std::string table_path = TableDir(path);
  auto file_list = paddle::framework::localfs_list(table_path);

  int load_param = atoi(param.c_str());
@@ -225,12 +225,12 @@ int32_t MemorySparseTable::load_local_fs(const std::string& path,
  return 0;
}

-int32_t MemorySparseTable::save(const std::string& dirname,
+int32_t MemorySparseTable::Save(const std::string& dirname,
                                const std::string& param) {
  VLOG(0) << "MemorySparseTable::save dirname: " << dirname;
  int save_param =
      atoi(param.c_str());  // checkpoint:0  xbox delta:1  xbox base:2
-  std::string table_path = table_dir(dirname);
+  std::string table_path = TableDir(dirname);
  _afs_client.remove(paddle::string::format_string(
      "%s/part-%03d-*", table_path.c_str(), _shard_idx));
  std::atomic<uint32_t> feasign_size_all{0};
@@ -309,12 +309,12 @@ int32_t MemorySparseTable::save(const std::string& dirname,
  return 0;
}

-int32_t MemorySparseTable::save_local_fs(const std::string& dirname,
-                                         const std::string& param,
-                                         const std::string& prefix) {
+int32_t MemorySparseTable::SaveLocalFS(const std::string& dirname,
+                                       const std::string& param,
+                                       const std::string& prefix) {
  int save_param =
      atoi(param.c_str());  // checkpoint:0  xbox delta:1  xbox base:2
-  std::string table_path = table_dir(dirname);
+  std::string table_path = TableDir(dirname);
  int feasign_cnt = 0;
  size_t file_start_idx = _avg_local_shard_num * _shard_idx;

@@ -349,7 +349,7 @@ int32_t MemorySparseTable::save_local_fs(const std::string& dirname,
  return 0;
}

-int64_t MemorySparseTable::local_size() {
+int64_t MemorySparseTable::LocalSize() {
  int64_t local_size = 0;
  for (size_t i = 0; i < _real_local_shard_num; ++i) {
    local_size += _local_shards[i].size();
@@ -357,7 +357,7 @@ int64_t MemorySparseTable::local_size() {
  return local_size;
}
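InitializeValue above gives every server _avg_local_shard_num consecutive global shards and clamps the tail server to whatever remains; Save/Load then derive file names from file_start_idx = _avg_local_shard_num * _shard_idx. A sketch of that layout arithmetic (reconstructed under the ceil-division convention used elsewhere in the patch):

#include <algorithm>
#include <cassert>

struct LocalShardLayout {
  int avg, real, file_start_idx;
  LocalShardLayout(int total_shards, int servers, int shard_idx) {
    avg = total_shards / servers + (total_shards % servers != 0 ? 1 : 0);
    real = std::min(avg, total_shards - avg * shard_idx);
    real = std::max(real, 0);  // servers past the tail own nothing
    file_start_idx = avg * shard_idx;
  }
};

int main() {
  LocalShardLayout l(10, 3, 2);  // 3 servers host 4 + 4 + 2 shards
  assert(l.avg == 4 && l.real == 2 && l.file_start_idx == 8);
  return 0;
}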
MemorySparseTable::LocalMFSize() { std::vector<int64_t> size_arr(_real_local_shard_num, 0); std::vector<std::future<int>> tasks(_real_local_shard_num); int64_t ret_size = 0; @@ -384,9 +384,9 @@ int64_t MemorySparseTable::local_mf_size() { return ret_size; } -std::pair<int64_t, int64_t> MemorySparseTable::print_table_stat() { - int64_t feasign_size = local_size(); - int64_t mf_size = local_mf_size(); +std::pair<int64_t, int64_t> MemorySparseTable::PrintTableStat() { + int64_t feasign_size = LocalSize(); + int64_t mf_size = LocalMFSize(); return {feasign_size, mf_size}; } @@ -395,11 +395,11 @@ int32_t MemorySparseTable::Pull(TableContext& context) { if (context.use_ptr) { char** pull_values = context.pull_context.ptr_values; const uint64_t* keys = context.pull_context.keys; - return pull_sparse_ptr(pull_values, keys, context.num); + return PullSparsePtr(pull_values, keys, context.num); } else { float* pull_values = context.pull_context.values; const PullSparseValue& pull_value = context.pull_context.pull_value; - return pull_sparse(pull_values, pull_value); + return PullSparse(pull_values, pull_value); } } @@ -407,11 +407,11 @@ int32_t MemorySparseTable::Push(TableContext& context) { CHECK(context.value_type == Sparse); const uint64_t* keys = context.push_context.keys; - return push_sparse(keys, context.push_context.values, context.num); + return PushSparse(keys, context.push_context.values, context.num); } -int32_t MemorySparseTable::pull_sparse(float* pull_values, - const PullSparseValue& pull_value) { +int32_t MemorySparseTable::PullSparse(float* pull_values, + const PullSparseValue& pull_value) { CostTimer timer("pserver_sparse_select_all"); std::vector<std::future<int>> tasks(_real_local_shard_num); @@ -479,8 +479,8 @@ int32_t MemorySparseTable::pull_sparse(float* pull_values, return 0; } -int32_t MemorySparseTable::pull_sparse_ptr(char** pull_values, - const uint64_t* keys, size_t num) { +int32_t MemorySparseTable::PullSparsePtr(char** pull_values, + const uint64_t* keys, size_t num) { CostTimer timer("pscore_sparse_select_all"); size_t value_size = _value_accesor->GetTableInfo(SIZE) / sizeof(float); size_t mf_value_size = _value_accesor->GetTableInfo(MF_SIZE) / sizeof(float); @@ -530,8 +530,8 @@ int32_t MemorySparseTable::pull_sparse_ptr(char** pull_values, return 0; } -int32_t MemorySparseTable::push_sparse(const uint64_t* keys, - const float* values, size_t num) { +int32_t MemorySparseTable::PushSparse(const uint64_t* keys, const float* values, + size_t num) { CostTimer timer("pserver_sparse_update_all"); std::vector<std::future<int>> tasks(_real_local_shard_num); std::vector<std::vector<std::pair<uint64_t, int>>> task_keys( _real_local_shard_num); @@ -603,14 +603,14 @@ int32_t MemorySparseTable::push_sparse(const uint64_t* keys, return 0; } -int32_t MemorySparseTable::push_sparse(const uint64_t* keys, - const float** values, size_t num) { - _push_sparse(keys, values, num); +int32_t MemorySparseTable::PushSparse(const uint64_t* keys, + const float** values, size_t num) { + _PushSparse(keys, values, num); return 0; } -int32_t MemorySparseTable::_push_sparse(const uint64_t* keys, - const float** values, size_t num) { +int32_t MemorySparseTable::_PushSparse(const uint64_t* keys, + const float** values, size_t num) { std::vector<std::future<int>> tasks(_real_local_shard_num); std::vector<std::vector<std::pair<uint64_t, int>>> task_keys( _real_local_shard_num); @@ -677,13 +677,13 @@ int32_t MemorySparseTable::_push_sparse(const uint64_t* keys, return 0; } -int32_t MemorySparseTable::flush() { return 0; } +int32_t MemorySparseTable::Flush() { return 0; } -int32_t MemorySparseTable::shrink(const std::string& param) { - VLOG(0) << "MemorySparseTable::shrink"; +int32_t
MemorySparseTable::Shrink(const std::string& param) { + VLOG(0) << "MemorySparseTable::Shrink"; // TODO(zhaocaibei123): implement with multi-thread for (int shard_id = 0; shard_id < _real_local_shard_num; ++shard_id) { - // shrink + // Shrink auto& shard = _local_shards[shard_id]; for (auto it = shard.begin(); it != shard.end();) { if (_value_accesor->Shrink(it.value().data())) { @@ -696,7 +696,7 @@ int32_t MemorySparseTable::shrink(const std::string& param) { return 0; } -void MemorySparseTable::clear() { VLOG(0) << "clear coming soon"; } +void MemorySparseTable::Clear() { VLOG(0) << "clear coming soon"; } } // namespace distributed } // namespace paddle diff --git a/paddle/fluid/distributed/ps/table/memory_sparse_table.h b/paddle/fluid/distributed/ps/table/memory_sparse_table.h index d26c67319760d..a4af4caa472d7 100644 --- a/paddle/fluid/distributed/ps/table/memory_sparse_table.h +++ b/paddle/fluid/distributed/ps/table/memory_sparse_table.h @@ -41,50 +41,48 @@ class MemorySparseTable : public SparseTable { virtual ~MemorySparseTable() {} // unused method begin - virtual int32_t pull_dense(float* pull_values, size_t num) { return 0; } - virtual int32_t push_dense_param(const float* values, size_t num) { - return 0; - } - virtual int32_t push_dense(const float* values, size_t num) { return 0; } + virtual int32_t PullDense(float* pull_values, size_t num) { return 0; } + virtual int32_t PushDenseParam(const float* values, size_t num) { return 0; } + virtual int32_t PushDense(const float* values, size_t num) { return 0; } // unused method end virtual int32_t Pull(TableContext& context); virtual int32_t Push(TableContext& context); - virtual int32_t initialize(); - virtual int32_t initialize_shard() { return 0; } - virtual int32_t initialize_value(); + virtual int32_t Initialize(); + virtual int32_t InitializeShard() { return 0; } + virtual int32_t InitializeValue(); - virtual int32_t load(const std::string& path, const std::string& param); + virtual int32_t Load(const std::string& path, const std::string& param); - virtual int32_t save(const std::string& path, const std::string& param); + virtual int32_t Save(const std::string& path, const std::string& param); - int32_t load_local_fs(const std::string& path, const std::string& param); - int32_t save_local_fs(const std::string& path, const std::string& param, - const std::string& prefix); + int32_t LoadLocalFS(const std::string& path, const std::string& param); + int32_t SaveLocalFS(const std::string& path, const std::string& param, + const std::string& prefix); - int64_t local_size(); - int64_t local_mf_size(); + int64_t LocalSize(); + int64_t LocalMFSize(); - virtual std::pair print_table_stat(); - virtual int32_t pull_sparse(float* values, const PullSparseValue& pull_value); + virtual std::pair PrintTableStat(); + virtual int32_t PullSparse(float* values, const PullSparseValue& pull_value); - virtual int32_t pull_sparse_ptr(char** pull_values, const uint64_t* keys, - size_t num); + virtual int32_t PullSparsePtr(char** pull_values, const uint64_t* keys, + size_t num); - virtual int32_t push_sparse(const uint64_t* keys, const float* values, - size_t num); + virtual int32_t PushSparse(const uint64_t* keys, const float* values, + size_t num); - virtual int32_t push_sparse(const uint64_t* keys, const float** values, - size_t num); + virtual int32_t PushSparse(const uint64_t* keys, const float** values, + size_t num); - virtual int32_t flush(); - virtual int32_t shrink(const std::string& param); - virtual void clear(); + virtual int32_t Flush(); 
+ virtual int32_t Shrink(const std::string& param); + virtual void Clear(); protected: - virtual int32_t _push_sparse(const uint64_t* keys, const float** values, - size_t num); + virtual int32_t _PushSparse(const uint64_t* keys, const float** values, + size_t num); protected: const int _task_pool_size = 24; diff --git a/paddle/fluid/distributed/ps/table/sparse_geo_table.cc b/paddle/fluid/distributed/ps/table/sparse_geo_table.cc index 6ef4330113e8f..de9628a5b5235 100644 --- a/paddle/fluid/distributed/ps/table/sparse_geo_table.cc +++ b/paddle/fluid/distributed/ps/table/sparse_geo_table.cc @@ -17,9 +17,9 @@ namespace paddle { namespace distributed { -int32_t SparseGeoTable::pull_geo_param(const uint32_t trainer_id, - std::vector* values, - std::vector* ids) { +int32_t SparseGeoTable::PullGeoParam(const uint32_t trainer_id, + std::vector* values, + std::vector* ids) { geo_recorder->GetAndClear(trainer_id, ids); auto dim = _config.common().dims()[0]; @@ -32,21 +32,21 @@ int32_t SparseGeoTable::pull_geo_param(const uint32_t trainer_id, pull_value.frequencies_ = frequencies.data(); values->resize(ids->size() * dim); - CommonSparseTable::pull_sparse(values->data(), pull_value); + CommonSparseTable::PullSparse(values->data(), pull_value); return 0; } -int32_t SparseGeoTable::push_sparse(const uint64_t* keys, const float* values, - size_t num) { +int32_t SparseGeoTable::PushSparse(const uint64_t* keys, const float* values, + size_t num) { std::vector ids; ids.resize(num); std::copy_n(keys, num, ids.begin()); geo_recorder->Update(ids); - CommonSparseTable::push_sparse(keys, values, num); + CommonSparseTable::PushSparse(keys, values, num); return 0; } -int32_t SparseGeoTable::initialize_value() { +int32_t SparseGeoTable::InitializeValue() { auto common = _config.common(); shard_values_.reserve(task_pool_size_); @@ -82,7 +82,7 @@ int32_t SparseGeoTable::initialize_value() { auto pull_value = PullSparseValue(ids, fres, param_dim_); std::vector pulls; pulls.resize(bucket_feasigns * param_dim_); - pull_sparse(pulls.data(), pull_value); + PullSparse(pulls.data(), pull_value); } return 0; } diff --git a/paddle/fluid/distributed/ps/table/sparse_geo_table.h b/paddle/fluid/distributed/ps/table/sparse_geo_table.h index 1151c9f81ac97..261338c2ba7b1 100644 --- a/paddle/fluid/distributed/ps/table/sparse_geo_table.h +++ b/paddle/fluid/distributed/ps/table/sparse_geo_table.h @@ -44,15 +44,15 @@ class SparseGeoTable : public CommonSparseTable { explicit SparseGeoTable() : CommonSparseTable() { geo_recorder = nullptr; } virtual ~SparseGeoTable() {} - virtual int32_t initialize_value(); + virtual int32_t InitializeValue(); - int32_t pull_geo_param(const uint32_t trainer_id, std::vector* values, - std::vector* keys); + int32_t PullGeoParam(const uint32_t trainer_id, std::vector* values, + std::vector* keys); - int32_t push_sparse(const uint64_t* keys, const float* values, - size_t num) override; + int32_t PushSparse(const uint64_t* keys, const float* values, + size_t num) override; - virtual int32_t initialize_recorder() { + virtual int32_t InitializeRecorder() { if (!geo_recorder) { auto trainers = _config.common().trainer_num(); geo_recorder = std::make_shared(trainers); diff --git a/paddle/fluid/distributed/ps/table/ssd_sparse_table.cc b/paddle/fluid/distributed/ps/table/ssd_sparse_table.cc index 5bc58bc5a1108..484fa9e1c6eea 100644 --- a/paddle/fluid/distributed/ps/table/ssd_sparse_table.cc +++ b/paddle/fluid/distributed/ps/table/ssd_sparse_table.cc @@ -20,7 +20,7 @@ DEFINE_string(rocksdb_path, "database", "path of 
sparse table rocksdb file"); namespace paddle { namespace distributed { -int32_t SSDSparseTable::initialize() { +int32_t SSDSparseTable::Initialize() { _shards_task_pool.resize(task_pool_size_); for (int i = 0; i < _shards_task_pool.size(); ++i) { _shards_task_pool[i].reset(new ::ThreadPool(1)); @@ -53,9 +53,9 @@ int32_t SSDSparseTable::initialize() { offset += dim; } - initialize_value(); - initialize_optimizer(); - initialize_recorder(); + InitializeValue(); + InitializeOptimizer(); + InitializeRecorder(); _db = paddle::distributed::RocksDBHandler::GetInstance(); _db->initialize(FLAGS_rocksdb_path, task_pool_size_); return 0; @@ -66,18 +66,18 @@ int32_t SSDSparseTable::Pull(TableContext& context) { if (context.use_ptr) { char** pull_values = context.pull_context.ptr_values; const uint64_t* keys = context.pull_context.keys; - return pull_sparse_ptr(pull_values, keys, context.num); + return PullSparsePtr(pull_values, keys, context.num); } else { float* pull_values = context.pull_context.values; const PullSparseValue& pull_value = context.pull_context.pull_value; - return pull_sparse(pull_values, pull_value); + return PullSparse(pull_values, pull_value); } } int32_t SSDSparseTable::Push(TableContext& context) { return 0; } -int32_t SSDSparseTable::pull_sparse(float* pull_values, - const PullSparseValue& pull_value) { +int32_t SSDSparseTable::PullSparse(float* pull_values, + const PullSparseValue& pull_value) { auto shard_num = task_pool_size_; std::vector> tasks(shard_num); @@ -140,8 +140,8 @@ int32_t SSDSparseTable::pull_sparse(float* pull_values, return 0; } -int32_t SSDSparseTable::pull_sparse_ptr(char** pull_values, - const uint64_t* keys, size_t num) { +int32_t SSDSparseTable::PullSparsePtr(char** pull_values, const uint64_t* keys, + size_t num) { auto shard_num = task_pool_size_; std::vector> tasks(shard_num); @@ -201,9 +201,9 @@ int32_t SSDSparseTable::pull_sparse_ptr(char** pull_values, return 0; } -int32_t SSDSparseTable::shrink(const std::string& param) { return 0; } +int32_t SSDSparseTable::Shrink(const std::string& param) { return 0; } -int32_t SSDSparseTable::update_table() { +int32_t SSDSparseTable::UpdateTable() { int count = 0; int value_size = shard_values_[0]->value_length_; int db_size = 3 + value_size; @@ -299,7 +299,7 @@ int64_t SSDSparseTable::SaveValueToText(std::ostream* os, return save_num; } -int32_t SSDSparseTable::load(const std::string& path, +int32_t SSDSparseTable::Load(const std::string& path, const std::string& param) { rwlock_->WRLock(); VLOG(3) << "ssd sparse table load with " << path << " with meta " << param; diff --git a/paddle/fluid/distributed/ps/table/ssd_sparse_table.h b/paddle/fluid/distributed/ps/table/ssd_sparse_table.h index 3a703d7d966d3..11a776bd9e847 100644 --- a/paddle/fluid/distributed/ps/table/ssd_sparse_table.h +++ b/paddle/fluid/distributed/ps/table/ssd_sparse_table.h @@ -23,7 +23,7 @@ class SSDSparseTable : public CommonSparseTable { SSDSparseTable() {} virtual ~SSDSparseTable() {} - virtual int32_t initialize() override; + virtual int32_t Initialize() override; void SaveMetaToText(std::ostream* os, const CommonAccessorParameter& common, const size_t shard_idx, const int64_t total); @@ -37,22 +37,22 @@ class SSDSparseTable : public CommonSparseTable { const int pserver_id, const int pserver_num, const int local_shard_num, std::vector>* blocks); - virtual int32_t load(const std::string& path, const std::string& param); + virtual int32_t Load(const std::string& path, const std::string& param); // exchange data - virtual int32_t 
update_table(); + virtual int32_t UpdateTable(); virtual int32_t Pull(TableContext& context); virtual int32_t Push(TableContext& context); - virtual int32_t pull_sparse(float* values, const PullSparseValue& pull_value); + virtual int32_t PullSparse(float* values, const PullSparseValue& pull_value); - virtual int32_t pull_sparse_ptr(char** pull_values, const uint64_t* keys, - size_t num); + virtual int32_t PullSparsePtr(char** pull_values, const uint64_t* keys, + size_t num); - virtual int32_t flush() override { return 0; } - virtual int32_t shrink(const std::string& param) override; - virtual void clear() override {} + virtual int32_t Flush() override { return 0; } + virtual int32_t Shrink(const std::string& param) override; + virtual void Clear() override {} private: RocksDBHandler* _db; diff --git a/paddle/fluid/distributed/ps/table/table.cc b/paddle/fluid/distributed/ps/table/table.cc index 99790606f0b31..9f17a2006d232 100644 --- a/paddle/fluid/distributed/ps/table/table.cc +++ b/paddle/fluid/distributed/ps/table/table.cc @@ -56,7 +56,7 @@ REGISTER_PSCORE_CLASS(SparseValueSGDRule, SparseAdamSGDRule); REGISTER_PSCORE_CLASS(SparseValueSGDRule, SparseNaiveSGDRule); REGISTER_PSCORE_CLASS(SparseValueSGDRule, SparseAdaGradSGDRule); -int32_t TableManager::initialize() { +int32_t TableManager::Initialize() { static bool initialized = false; if (initialized) { return 0; @@ -65,10 +65,10 @@ int32_t TableManager::initialize() { return 0; } -int32_t Table::initialize(const TableParameter &config, +int32_t Table::Initialize(const TableParameter &config, const FsClientParameter &fs_config) { _config = config; - if (initialize_accessor() != 0) { + if (InitializeAccessor() != 0) { LOG(WARNING) << "Table accessor initialize failed"; return -1; } @@ -77,10 +77,10 @@ int32_t Table::initialize(const TableParameter &config, LOG(WARNING) << "Table fs_client initialize failed"; // return -1; } - return initialize(); + return Initialize(); } -int32_t Table::initialize_accessor() { +int32_t Table::InitializeAccessor() { if (!_config.has_accessor() || !_config.accessor().has_accessor_class()) { LOG(ERROR) << "missing accessor config in table, table_id:" << _config.table_id(); diff --git a/paddle/fluid/distributed/ps/table/table.h b/paddle/fluid/distributed/ps/table/table.h index bba34d89377a7..c61efe769e2f8 100644 --- a/paddle/fluid/distributed/ps/table/table.h +++ b/paddle/fluid/distributed/ps/table/table.h @@ -60,101 +60,99 @@ class Table { public: Table() {} virtual ~Table() {} - virtual int32_t initialize(const TableParameter &config, + virtual int32_t Initialize(const TableParameter &config, const FsClientParameter &fs_config); virtual int32_t Pull(TableContext &context) = 0; virtual int32_t Push(TableContext &context) = 0; - virtual int32_t pull_dense(float *values, size_t num) = 0; - virtual int32_t push_dense(const float *values, size_t num) = 0; + virtual int32_t PullDense(float *values, size_t num) = 0; + virtual int32_t PushDense(const float *values, size_t num) = 0; // for push global_step - virtual int32_t push_dense(const int64_t *values, const int32_t trainer_id) { - return 0; - } - virtual int32_t push_dense_param(const float *values, size_t num) { + virtual int32_t PushDense(const int64_t *values, const int32_t trainer_id) { return 0; } + virtual int32_t PushDenseParam(const float *values, size_t num) { return 0; } - virtual int32_t pull_sparse_ptr(char **pull_values, const uint64_t *keys, - size_t num) { + virtual int32_t PullSparsePtr(char **pull_values, const uint64_t *keys, + size_t num) { 
VLOG(0) << "NOT IMPLEMENT"; return 0; } - virtual int32_t pull_sparse(float *values, - const PullSparseValue &pull_value) = 0; - virtual int32_t push_sparse(const uint64_t *keys, const float *values, - size_t num) = 0; - virtual int32_t push_sparse(const uint64_t *keys, const float **values, - size_t num) { + virtual int32_t PullSparse(float *values, + const PullSparseValue &pull_value) = 0; + virtual int32_t PushSparse(const uint64_t *keys, const float *values, + size_t num) = 0; + virtual int32_t PushSparse(const uint64_t *keys, const float **values, + size_t num) { return 0; } - virtual int32_t push_sparse_param(const uint64_t *keys, const float *values, - size_t num) { + virtual int32_t PushSparseParam(const uint64_t *keys, const float *values, + size_t num) { return 0; } // only for sparse geo table - virtual int32_t pull_geo_param(const uint32_t trainer_id, - std::vector *values, - std::vector *keys) { + virtual int32_t PullGeoParam(const uint32_t trainer_id, + std::vector *values, + std::vector *keys) { return 0; } // only for barrier - virtual int32_t barrier(const uint32_t trainer_id, + virtual int32_t Barrier(const uint32_t trainer_id, const std::string barrier_type) { return 0; } // only for barrier table - virtual int32_t set_table_map( + virtual int32_t SetTableMap( std::unordered_map> *table_map) { return 0; } // only for tensor table - virtual int32_t set_program_env( + virtual int32_t SetProgramEnv( framework::Scope *scope, platform::Place place, const std::vector *sub_program) { return 0; } - virtual int32_t set_global_lr(float *lr) { + virtual int32_t SetGlobalLR(float *lr) { _global_lr = lr; return 0; } - virtual int32_t pour() { return 0; } + virtual int32_t Pour() { return 0; } - virtual void clear() = 0; - virtual int32_t flush() = 0; - virtual int32_t shrink(const std::string ¶m) = 0; + virtual void Clear() = 0; + virtual int32_t Flush() = 0; + virtual int32_t Shrink(const std::string ¶m) = 0; // 指定加载路径 - virtual int32_t load(const std::string &path, + virtual int32_t Load(const std::string &path, const std::string &converter) = 0; // 指定保存路径 - virtual int32_t save(const std::string &path, + virtual int32_t Save(const std::string &path, const std::string &converter) = 0; - virtual int32_t set_shard(size_t shard_idx, size_t shard_num) { + virtual int32_t SetShard(size_t shard_idx, size_t shard_num) { _shard_idx = shard_idx; _shard_num = shard_num; - return initialize_shard(); + return InitializeShard(); } - inline std::shared_ptr value_accesor() { + inline std::shared_ptr ValueAccesor() { return _value_accesor; } - virtual void *get_shard(size_t shard_idx) = 0; - virtual std::pair print_table_stat() { return {0, 0}; } + virtual void *GetShard(size_t shard_idx) = 0; + virtual std::pair PrintTableStat() { return {0, 0}; } protected: - virtual int32_t initialize() = 0; - virtual int32_t initialize_accessor(); - virtual int32_t initialize_shard() = 0; - virtual std::string table_dir(const std::string &model_dir) { + virtual int32_t Initialize() = 0; + virtual int32_t InitializeAccessor(); + virtual int32_t InitializeShard() = 0; + virtual std::string TableDir(const std::string &model_dir) { return paddle::string::format_string("%s/%03d/", model_dir.c_str(), _config.table_id()); } @@ -171,11 +169,11 @@ REGISTER_PSCORE_REGISTERER(Table); class TableManager { public: - static TableManager &instance() { + static TableManager &Instance() { static TableManager manager; return manager; } - int32_t initialize(); + int32_t Initialize(); private: TableManager() {} diff --git 
a/paddle/fluid/distributed/ps/table/tensor_table.h b/paddle/fluid/distributed/ps/table/tensor_table.h index e59314923cdbc..175aa194fb80f 100644 --- a/paddle/fluid/distributed/ps/table/tensor_table.h +++ b/paddle/fluid/distributed/ps/table/tensor_table.h @@ -52,42 +52,42 @@ class TensorTable : public Table { virtual int32_t Pull(TableContext &context) { return 0; } virtual int32_t Push(TableContext &context) { return 0; } - int32_t pull_dense(float *values, size_t num) override { return 0; } + int32_t PullDense(float *values, size_t num) override { return 0; } - int32_t push_dense(const float *values, size_t num) override { return 0; } + int32_t PushDense(const float *values, size_t num) override { return 0; } - int32_t pull_sparse(float *values, - const PullSparseValue &pull_value) override { + int32_t PullSparse(float *values, + const PullSparseValue &pull_value) override { return 0; } - int32_t push_sparse(const uint64_t *keys, const float *values, - size_t num) override { + int32_t PushSparse(const uint64_t *keys, const float *values, + size_t num) override { return 0; } - int32_t shrink(const std::string &param) override { return 0; } + int32_t Shrink(const std::string &param) override { return 0; } - virtual void *get_shard(size_t shard_idx) { return 0; } + virtual void *GetShard(size_t shard_idx) { return 0; } - virtual int32_t initialize_shard() { return 0; } + virtual int32_t InitializeShard() { return 0; } - virtual int32_t flush() { return 0; } + virtual int32_t Flush() { return 0; } - virtual int32_t load(const std::string &path, const std::string &param) { + virtual int32_t Load(const std::string &path, const std::string &param) { return 0; } - virtual int32_t save(const std::string &path, const std::string &param) { + virtual int32_t Save(const std::string &path, const std::string &param) { return 0; } - virtual void clear() {} + virtual void Clear() {} - int32_t initialize() override { return 0; } + int32_t Initialize() override { return 0; } - int32_t push_dense(const int64_t *values, const int32_t trainer_id) override { + int32_t PushDense(const int64_t *values, const int32_t trainer_id) override { return 0; } - int32_t set_program_env( + int32_t SetProgramEnv( framework::Scope *scope, platform::Place place, const std::vector<framework::ProgramDesc> *sub_program) override { scope_ = scope; @@ -111,48 +111,48 @@ class DenseTensorTable : public TensorTable { DenseTensorTable() {} virtual ~DenseTensorTable() {} - int32_t pull_sparse(float *values, - const PullSparseValue &pull_value) override { + int32_t PullSparse(float *values, + const PullSparseValue &pull_value) override { return 0; } - int32_t push_sparse(const uint64_t *keys, const float *values, - size_t num) override { + int32_t PushSparse(const uint64_t *keys, const float *values, + size_t num) override { return 0; } - int32_t shrink(const std::string &param) override { return 0; } + int32_t Shrink(const std::string &param) override { return 0; } - virtual void *get_shard(size_t shard_idx) { return 0; } + virtual void *GetShard(size_t shard_idx) { return 0; } - virtual int32_t initialize_shard() { return 0; } + virtual int32_t InitializeShard() { return 0; } - virtual int32_t flush() { return 0; } + virtual int32_t Flush() { return 0; } - virtual void clear() {} + virtual void Clear() {} // Todo: Support program Load & Save - virtual int32_t load(const std::string &path, const std::string &param) { + virtual int32_t Load(const std::string &path, const std::string &param) { return 0; } - virtual int32_t save(const std::string &path, const std::string &param) { + virtual int32_t
Save(const std::string &path, const std::string &param) { return 0; } // Todo: Support pull dense - int32_t pull_dense(float *values, size_t num) override { return 0; } + int32_t PullDense(float *values, size_t num) override { return 0; } /*----------------------------------------------------------------------*/ - int32_t initialize() override { return 0; } + int32_t Initialize() override { return 0; } - int32_t push_dense(const float *values, size_t num) override { return 0; } + int32_t PushDense(const float *values, size_t num) override { return 0; } - int32_t push_dense(const int64_t *values, const int32_t trainer_id) { + int32_t PushDense(const int64_t *values, const int32_t trainer_id) { return 0; } protected: - virtual int32_t _run_program(const float *values, size_t num, - const uint32_t trainer_id) { + virtual int32_t _RunProgram(const float *values, size_t num, + const uint32_t trainer_id) { return 0; } @@ -167,36 +167,36 @@ class GlobalStepTable : public DenseTensorTable { GlobalStepTable() {} virtual ~GlobalStepTable() {} - int32_t pull_sparse(float *values, - const PullSparseValue &pull_value) override { + int32_t PullSparse(float *values, + const PullSparseValue &pull_value) override { return 0; } - int32_t push_sparse(const uint64_t *keys, const float *values, - size_t num) override { + int32_t PushSparse(const uint64_t *keys, const float *values, + size_t num) override { return 0; } - int32_t shrink(const std::string &param) override { return 0; } + int32_t Shrink(const std::string &param) override { return 0; } - virtual void *get_shard(size_t shard_idx) { return 0; } + virtual void *GetShard(size_t shard_idx) { return 0; } - virtual int32_t initialize_shard() { return 0; } + virtual int32_t InitializeShard() { return 0; } - virtual int32_t flush() { return 0; } + virtual int32_t Flush() { return 0; } - virtual void clear() {} + virtual void Clear() {} - virtual int32_t load(const std::string &path, const std::string &param) { + virtual int32_t Load(const std::string &path, const std::string &param) { return 0; } - virtual int32_t save(const std::string &path, const std::string &param) { + virtual int32_t Save(const std::string &path, const std::string &param) { return 0; } - int32_t pull_dense(float *values, size_t num) override { return 0; } + int32_t PullDense(float *values, size_t num) override { return 0; } /*----------------------------------------------------------------------*/ - int32_t initialize() override { + int32_t Initialize() override { auto _program_config = _config.tensor(); auto trainers_ = _config.common().trainer_num(); FLAGS_eager_delete_tensor_gb = -1; @@ -237,14 +237,14 @@ class GlobalStepTable : public DenseTensorTable { } } - int32_t push_dense(const float *values, size_t num) override { return 0; } + int32_t PushDense(const float *values, size_t num) override { return 0; } - int32_t push_dense(const int64_t *values, const int32_t trainer_id) { - return _run_program(values, trainer_id); + int32_t PushDense(const int64_t *values, const int32_t trainer_id) { + return _RunProgram(values, trainer_id); } - int32_t set_table_map(std::unordered_map<uint32_t, std::shared_ptr<Table>> - *table_map) override { + int32_t SetTableMap(std::unordered_map<uint32_t, std::shared_ptr<Table>> + *table_map) override { auto *lr_var = scope_->FindVar(fetch_var_name_); auto *lr_tensor = lr_var->GetMutable<framework::LoDTensor>(); auto *lr_value = lr_tensor->mutable_data<float>(platform::CPUPlace()); @@ -255,14 +255,14 @@ class GlobalStepTable : public DenseTensorTable { if (table_id == _config.table_id()) { continue; } - iter->second->set_global_lr(lr_value); +
iter->second->SetGlobalLR(lr_value); } return 0; } private: - virtual int32_t _run_program(const int64_t *values, - const uint32_t trainer_id) { + virtual int32_t _RunProgram(const int64_t *values, + const uint32_t trainer_id) { FLAGS_eager_delete_tensor_gb = -1; auto counter = decay_counters_.at(trainer_id); counter += int(values[0]); diff --git a/paddle/fluid/distributed/ps/wrapper/fleet.cc b/paddle/fluid/distributed/ps/wrapper/fleet.cc index c9093368c693e..7bc50a868104a 100644 --- a/paddle/fluid/distributed/ps/wrapper/fleet.cc +++ b/paddle/fluid/distributed/ps/wrapper/fleet.cc @@ -51,32 +51,6 @@ int32_t FleetWrapper::CopyTableByFeasign( return 0; } -void FleetWrapper::Stop() { StopServer(); } - -void FleetWrapper::Load(WrapperContext& context) { - auto table_id = context.table_id; - if (table_id >= 0 && context.meta != "") { - LoadSparseOnServer(context.path, context.meta, context.table_id); - return; - } - if (table_id < 0) { // laod all - LoadModel(context.path, context.mode); - } else { // load one table - LoadModelOneTable(table_id, context.path, context.mode); - } - return; -} - -void FleetWrapper::Save(WrapperContext& context) { - auto table_id = context.table_id; - if (table_id < 0) { - SaveModel(context.path, context.mode); - } else { - SaveModelOneTable(table_id, context.path, context.mode); - } - return; -} - void FleetWrapper::SetClient2ClientConfig(int request_timeout_ms, int connect_timeout_ms, int max_retry) { @@ -90,7 +64,7 @@ void FleetWrapper::LoadSparseOnServer(const std::string& path, uint32_t table_id) { VLOG(3) << "load sparse table " << table_id << " with " << path << " meta " << meta; - pserver_ptr_->_server_ptr->table(table_id)->load(path, meta); + pserver_ptr_->_server_ptr->GetTable(table_id)->Load(path, meta); } void FleetWrapper::InitServer( @@ -101,8 +75,8 @@ void FleetWrapper::InitServer( VLOG(3) << "Going to init server"; pserver_ptr_ = std::shared_ptr( new paddle::distributed::PSCore()); - pserver_ptr_->init_server(dist_desc, &host_sign_list, host_sign_list.size(), - index, trainers, server_sub_program); + pserver_ptr_->InitServer(dist_desc, &host_sign_list, host_sign_list.size(), + index, trainers, server_sub_program); is_initialized_ = true; } else { VLOG(3) << "Server can be initialized only once"; @@ -143,10 +117,10 @@ void FleetWrapper::InitWorker(const std::string& dist_desc, google::protobuf::TextFormat::ParseFromString(dist_desc, &ps_param); InitGFlag(ps_param.init_gflags()); int servers = host_sign_list.size(); - ps_env_.set_ps_servers(&host_sign_list, servers); + ps_env_.SetPsServers(&host_sign_list, servers); worker_ptr_ = std::shared_ptr( - paddle::distributed::PSClientFactory::create(ps_param)); - worker_ptr_->configure(ps_param, dense_pull_regions, ps_env_, index); + paddle::distributed::PSClientFactory::Create(ps_param)); + worker_ptr_->Configure(ps_param, dense_pull_regions, ps_env_, index); } } else { VLOG(3) << "Client can be initialized only once"; @@ -155,13 +129,13 @@ void FleetWrapper::InitWorker(const std::string& dist_desc, void FleetWrapper::StopServer() { VLOG(3) << "Going to stop server"; - auto status = worker_ptr_->stop_server(); + auto status = worker_ptr_->StopServer(); status.wait(); } void FleetWrapper::FinalizeWorker() { VLOG(3) << "Going to finalize worker"; - worker_ptr_->finalize_worker(); + worker_ptr_->FinalizeWorker(); } void FleetWrapper::BarrierWithTable(uint32_t barrier_type) { @@ -172,13 +146,13 @@ void FleetWrapper::BarrierWithTable(uint32_t barrier_type) { uint64_t FleetWrapper::RunServer(const std::string& ip, 
uint32_t port) { VLOG(3) << "Going to run server with ip " << ip << " port " << port; - auto ret = pserver_ptr_->run_server(ip, port); + auto ret = pserver_ptr_->RunServer(ip, port); return ret; } std::vector FleetWrapper::GetClientsInfo() { VLOG(3) << "Going to get client info"; - std::vector res = ps_env_.get_client_info(); + std::vector res = ps_env_.GetClientInfo(); for (auto rr : res) { VLOG(2) << "FleetWrapper::GetClientInfo " << rr; } @@ -187,14 +161,14 @@ std::vector FleetWrapper::GetClientsInfo() { int FleetWrapper::SetClients(std::vector& host_sign_list) { int node = host_sign_list.size(); - return ps_env_.set_ps_clients(host_sign_list.data(), node); + return ps_env_.SetPsClients(host_sign_list.data(), node); } void FleetWrapper::CreateClient2ClientConnection() { VLOG(1) << "Going to create client2client connection"; - worker_ptr_->create_client2client_connection( - client2client_request_timeout_ms_, client2client_connect_timeout_ms_, - client2client_max_retry_); + worker_ptr_->CreateClient2ClientConnection(client2client_request_timeout_ms_, + client2client_connect_timeout_ms_, + client2client_max_retry_); } std::future FleetWrapper::PullSparseVarsAsync( @@ -230,9 +204,9 @@ std::future FleetWrapper::PullSparseVarsAsync( } bool training = true; - return pserver_ptr_->_worker_ptr->pull_sparse(pull_result_ptr.data(), - table_id, fea_keys->data(), - fea_keys->size(), training); + return pserver_ptr_->_worker_ptr->PullSparse(pull_result_ptr.data(), table_id, + fea_keys->data(), + fea_keys->size(), training); } void FleetWrapper::PullSparseVarsSync( @@ -279,7 +253,7 @@ void FleetWrapper::PullSparseVarsSync( pull_result_ptr.push_back(t.data()); } bool training = true; - auto status = pserver_ptr_->_worker_ptr->pull_sparse( + auto status = pserver_ptr_->_worker_ptr->PullSparse( pull_result_ptr.data(), table_id, fea_keys->data(), fea_keys->size(), training); pull_sparse_status.push_back(std::move(status)); @@ -337,21 +311,10 @@ void FleetWrapper::PullSparseToTensorSync(const uint64_t table_id, int fea_dim, pull_result_ptr.push_back(output_data + output_len); } } - // ps client pull sparse - // construct client request context - RequestContext req_context; - req_context.value_type = Sparse; - req_context.training_mode = Async; - req_context.table = table_id; - req_context.sparse_values = pull_result_ptr.data(); - req_context.keys = fea_keys.data(); - req_context.num = fea_keys.size(); - req_context.is_training = is_training; - auto status = worker_ptr_->Pull(req_context); - // auto status = - // worker_ptr_->pull_sparse(pull_result_ptr.data(), table_id, - // fea_keys.data(), fea_keys.size(), - // is_training); + + auto status = + worker_ptr_->PullSparse(pull_result_ptr.data(), table_id, fea_keys.data(), + fea_keys.size(), is_training); status.wait(); auto ret = status.get(); if (ret != 0) { @@ -364,7 +327,7 @@ void FleetWrapper::PullDenseVarsAsync( const Scope& scope, const uint64_t tid, const std::vector& var_names, std::vector>* pull_dense_status, bool in_cpu) { - auto& regions = _regions[tid]; + auto& regions = regions_[tid]; regions.clear(); regions.resize(var_names.size()); for (auto i = 0u; i < var_names.size(); ++i) { @@ -378,21 +341,15 @@ void FleetWrapper::PullDenseVarsAsync( paddle::distributed::Region reg(w, tensor->numel()); regions[i] = std::move(reg); } - RequestContext req_context; - req_context.value_type = Dense; - req_context.training_mode = Async; - req_context.table = tid; - req_context.dense_values = regions.data(); - req_context.num = regions.size(); - auto status = 
worker_ptr_->Pull(req_context); - // auto status = worker_ptr_->pull_dense(regions.data(), regions.size(), tid); + + auto status = worker_ptr_->PullDense(regions.data(), regions.size(), tid); pull_dense_status->push_back(std::move(status)); } void FleetWrapper::PullDenseVarsSync( const Scope& scope, const uint64_t tid, const std::vector& var_names) { - auto& regions = _regions[tid]; + auto& regions = regions_[tid]; regions.clear(); regions.reserve(var_names.size()); for (auto& t : var_names) { @@ -404,7 +361,7 @@ void FleetWrapper::PullDenseVarsSync( regions.emplace_back(std::move(reg)); } } - auto status = worker_ptr_->pull_dense(regions.data(), regions.size(), tid); + auto status = worker_ptr_->PullDense(regions.data(), regions.size(), tid); status.wait(); } @@ -424,7 +381,7 @@ void FleetWrapper::PushDenseParamSync( } } auto push_status = - worker_ptr_->push_dense_param(regions.data(), regions.size(), table_id); + worker_ptr_->PushDenseParam(regions.data(), regions.size(), table_id); push_status.wait(); auto status = push_status.get(); CHECK(status == 0) << "push dense param failed, status[" << status << "]"; @@ -470,15 +427,8 @@ void FleetWrapper::PushDenseVarsAsync( << g[tensor->numel() - 1]; } - RequestContext req_context; - req_context.value_type = Dense; - req_context.training_mode = Async; - req_context.table = table_id; - req_context.push_context.push_dense_values = regions.data(); - req_context.num = regions.size(); - // auto push_status = - // worker_ptr_->push_dense(regions.data(), regions.size(), table_id); - auto push_status = worker_ptr_->Push(req_context); + auto push_status = + worker_ptr_->PushDense(regions.data(), regions.size(), table_id); } void FleetWrapper::PushSparseVarsAsync( @@ -650,23 +600,13 @@ void FleetWrapper::PushSparseFromTensorAsync( push_g_vec[i] = push_values.at(i).data(); } - // ps client push sparse - // construct request context - RequestContext req_context; - req_context.value_type = Sparse; - req_context.training_mode = Async; - req_context.table = table_id; - req_context.push_context.push_values = (const float**)push_g_vec.data(); - req_context.push_context.keys = push_keys.data(); - req_context.num = push_keys.size(); - auto status = worker_ptr_->Push(req_context); - // auto status = worker_ptr_->push_sparse(table_id, push_keys.data(), - // (const float**)push_g_vec.data(), - // push_keys.size()); + auto status = worker_ptr_->PushSparse(table_id, push_keys.data(), + (const float**)push_g_vec.data(), + push_keys.size()); } void FleetWrapper::LoadModel(const std::string& path, const int mode) { - auto ret = worker_ptr_->load(path, std::to_string(mode)); + auto ret = worker_ptr_->Load(path, std::to_string(mode)); ret.wait(); if (ret.get() != 0) { LOG(ERROR) << "load model from path:" << path << " failed"; @@ -675,7 +615,7 @@ void FleetWrapper::LoadModel(const std::string& path, const int mode) { void FleetWrapper::LoadModelOneTable(const uint64_t table_id, const std::string& path, const int mode) { - auto ret = worker_ptr_->load(table_id, path, std::to_string(mode)); + auto ret = worker_ptr_->Load(table_id, path, std::to_string(mode)); ret.wait(); if (ret.get() != 0) { LOG(ERROR) << "load model of table id: " << table_id @@ -684,7 +624,7 @@ void FleetWrapper::LoadModelOneTable(const uint64_t table_id, } void FleetWrapper::SaveModel(const std::string& path, const int mode) { - auto ret = worker_ptr_->save(path, std::to_string(mode)); + auto ret = worker_ptr_->Save(path, std::to_string(mode)); ret.wait(); int32_t feasign_cnt = ret.get(); if 
(feasign_cnt == -1) { @@ -694,7 +634,7 @@ void FleetWrapper::SaveModel(const std::string& path, const int mode) { void FleetWrapper::SaveModelOneTable(const uint64_t table_id, const std::string& path, const int mode) { - auto ret = worker_ptr_->save(table_id, path, std::to_string(mode)); + auto ret = worker_ptr_->Save(table_id, path, std::to_string(mode)); ret.wait(); if (ret.get() != 0) { LOG(ERROR) << "save model of table id: " << table_id @@ -704,7 +644,7 @@ void FleetWrapper::SaveModelOneTable(const uint64_t table_id, void FleetWrapper::RecvAndSaveTable(const uint64_t table_id, const std::string& path) { - auto ret = worker_ptr_->recv_and_save_table(table_id, path); + auto ret = worker_ptr_->RecvAndSaveTable(table_id, path); if (ret != 0) { LOG(ERROR) << "save model of table id: " << table_id << ", to path: " << path << " failed"; @@ -712,7 +652,7 @@ void FleetWrapper::RecvAndSaveTable(const uint64_t table_id, } void FleetWrapper::PrintTableStat(const uint64_t table_id) { - auto ret = worker_ptr_->print_table_stat(table_id); + auto ret = worker_ptr_->PrintTableStat(table_id); ret.wait(); int32_t err_code = ret.get(); if (err_code == -1) { @@ -721,7 +661,7 @@ void FleetWrapper::PrintTableStat(const uint64_t table_id) { } void FleetWrapper::ShrinkSparseTable(int table_id, int threshold) { - auto ret = worker_ptr_->shrink(table_id, std::to_string(threshold)); + auto ret = worker_ptr_->Shrink(table_id, std::to_string(threshold)); ret.wait(); int32_t err_code = ret.get(); if (err_code == -1) { @@ -730,12 +670,12 @@ void FleetWrapper::ShrinkSparseTable(int table_id, int threshold) { } void FleetWrapper::ClearModel() { - auto ret = pserver_ptr_->_worker_ptr->clear(); + auto ret = pserver_ptr_->_worker_ptr->Clear(); ret.wait(); } void FleetWrapper::ClearOneTable(const uint64_t table_id) { - auto ret = pserver_ptr_->_worker_ptr->clear(table_id); + auto ret = pserver_ptr_->_worker_ptr->Clear(table_id); ret.wait(); } @@ -774,7 +714,7 @@ void FleetWrapper::ShrinkDenseTable(int table_id, Scope* scope, regions.emplace_back(std::move(reg)); } } - auto push_status = pserver_ptr_->_worker_ptr->push_dense_param( + auto push_status = pserver_ptr_->_worker_ptr->PushDenseParam( regions.data(), regions.size(), table_id); push_status.wait(); auto status = push_status.get(); @@ -791,7 +731,7 @@ void FleetWrapper::ClientFlush() { VLOG(0) << "worker_ptr null, do nothing"; return; } - auto ret = worker_ptr_->flush(); + auto ret = worker_ptr_->Flush(); ret.wait(); int32_t err_code = ret.get(); if (err_code == -1) { @@ -805,13 +745,13 @@ int FleetWrapper::RegisterClientToClientMsgHandler(int msg_type, VLOG(0) << "FleetWrapper::Client is null"; return -1; } else { - return worker_ptr_->registe_client2client_msg_handler(msg_type, handler); + return worker_ptr_->RegisteClient2ClientMsgHandler(msg_type, handler); } } std::future FleetWrapper::SendClientToClientMsg( int msg_type, int to_client_id, const std::string& msg) { - return worker_ptr_->send_client2client_msg(msg_type, to_client_id, msg); + return worker_ptr_->SendClient2ClientMsg(msg_type, to_client_id, msg); } std::default_random_engine& FleetWrapper::LocalRandomEngine() { diff --git a/paddle/fluid/distributed/ps/wrapper/fleet.h b/paddle/fluid/distributed/ps/wrapper/fleet.h index 13b7ea7609ee6..e6ec09a12637d 100644 --- a/paddle/fluid/distributed/ps/wrapper/fleet.h +++ b/paddle/fluid/distributed/ps/wrapper/fleet.h @@ -25,7 +25,6 @@ limitations under the License. 
*/ #include "paddle/fluid/distributed/ps/service/communicator/communicator_common.h" #include "paddle/fluid/distributed/ps/service/ps_service/service.h" -#include "paddle/fluid/distributed/ps/wrapper/ps_wrapper.h" #include "paddle/fluid/framework/archive.h" #include "paddle/fluid/framework/io/fs.h" #include "paddle/fluid/framework/io/shell.h" @@ -55,7 +54,7 @@ using framework::Variable; using RpcCtxMap = std::unordered_map; -class FleetWrapper : public PSWrapper { +class FleetWrapper { public: virtual ~FleetWrapper() {} FleetWrapper() { @@ -69,7 +68,6 @@ class FleetWrapper : public PSWrapper { // pserver request max retry client2client_max_retry_ = 3; } - virtual int32_t Initialize(InitContext& context) { return 0; } // TODO(zhaocaibei123: later) int32_t CopyTable(const uint64_t src_table_id, const uint64_t dest_table_id); @@ -81,12 +79,6 @@ class FleetWrapper : public PSWrapper { typedef std::function HeterCallBackFunc; int RegisterHeterCallback(HeterCallBackFunc handler); - virtual void Stop() override; - - virtual void Load(WrapperContext& context) override; - - virtual void Save(WrapperContext& context) override; - // set client to client communication config void SetClient2ClientConfig(int request_timeout_ms, int connect_timeout_ms, int max_retry); @@ -278,7 +270,7 @@ class FleetWrapper : public PSWrapper { protected: static bool is_initialized_; - std::map> _regions; + std::map> regions_; bool scale_sparse_gradient_with_batch_size_; int32_t sleep_seconds_before_fail_exit_; int client2client_request_timeout_ms_; diff --git a/paddle/fluid/distributed/test/barrier_table_test.cc b/paddle/fluid/distributed/test/barrier_table_test.cc index 0715f777fa5cb..c4c5b22992804 100644 --- a/paddle/fluid/distributed/test/barrier_table_test.cc +++ b/paddle/fluid/distributed/test/barrier_table_test.cc @@ -39,19 +39,19 @@ TEST(BarrierTable, Barrier) { common_config->set_trainer_num(trainers); common_config->set_sync(sync); - auto ret = table->initialize(table_config, fs_config); + auto ret = table->Initialize(table_config, fs_config); std::unordered_map> maps = std::unordered_map>(); - table->set_table_map(&maps); + table->SetTableMap(&maps); std::shared_ptr<::ThreadPool> pool_ = std::make_shared<::ThreadPool>(trainers); std::vector> task_status; for (auto x = 0; x < trainers; x++) { - auto task = [table, x] { table->barrier(x, 0); }; + auto task = [table, x] { table->Barrier(x, 0); }; task_status.push_back(pool_->enqueue(std::move(task))); } diff --git a/paddle/fluid/distributed/test/brpc_service_dense_sgd_test.cc b/paddle/fluid/distributed/test/brpc_service_dense_sgd_test.cc index 19ff50ec2a43b..d5e196ff3219f 100644 --- a/paddle/fluid/distributed/test/brpc_service_dense_sgd_test.cc +++ b/paddle/fluid/distributed/test/brpc_service_dense_sgd_test.cc @@ -155,16 +155,16 @@ void RunServer() { auto _ps_env = paddle::distributed::PaddlePSEnvironment(); LOG(INFO) << "RUN set_ps_servers"; - _ps_env.set_ps_servers(&host_sign_list_, 1); + _ps_env.SetPsServers(&host_sign_list_, 1); pserver_ptr_ = std::shared_ptr( - paddle::distributed::PSServerFactory::create(server_proto)); + paddle::distributed::PSServerFactory::Create(server_proto)); LOG(INFO) << "RUN configure"; std::vector empty_vec; framework::ProgramDesc empty_prog; empty_vec.push_back(empty_prog); - pserver_ptr_->configure(server_proto, _ps_env, 0, empty_vec); + pserver_ptr_->Configure(server_proto, _ps_env, 0, empty_vec); LOG(INFO) << "RUN start"; - pserver_ptr_->start(ip_, port_); + pserver_ptr_->Start(ip_, port_); LOG(INFO) << "End start"; } @@ -175,19 
+175,19 @@ void RunClient(std::map<uint64_t, std::vector<paddle::distributed::Region>>& auto servers_ = host_sign_list_.size(); _ps_env = paddle::distributed::PaddlePSEnvironment(); LOG(INFO) << "Run set_ps_servers"; - _ps_env.set_ps_servers(&host_sign_list_, servers_); + _ps_env.SetPsServers(&host_sign_list_, servers_); LOG(INFO) << "Run Create PSClient"; worker_ptr_ = std::shared_ptr<paddle::distributed::PSClient>( - paddle::distributed::PSClientFactory::create(worker_proto)); + paddle::distributed::PSClientFactory::Create(worker_proto)); LOG(INFO) << "Run configure"; - worker_ptr_->configure(worker_proto, dense_regions, _ps_env, 0); + worker_ptr_->Configure(worker_proto, dense_regions, _ps_env, 0); } void RunBrpcPushDense() { setenv("http_proxy", "", 1); setenv("https_proxy", "", 1); auto ph_host = paddle::distributed::PSHost(ip_, port_, 0); - host_sign_list_.push_back(ph_host.serialize_to_string()); + host_sign_list_.push_back(ph_host.SerializeToString()); // Start Server std::thread server_thread(RunServer); @@ -218,7 +218,7 @@ void RunBrpcPushDense() { paddle::distributed::Region temp_reg(temp, tensor->numel()); temp_region.emplace_back(std::move(temp_reg)); auto pull_status = - worker_ptr_->pull_dense(temp_region.data(), temp_region.size(), 0); + worker_ptr_->PullDense(temp_region.data(), temp_region.size(), 0); pull_status.wait(); for (size_t idx = 0; idx < tensor->numel(); ++idx) { @@ -229,10 +229,10 @@ void RunBrpcPushDense() { LOG(INFO) << "Run push_dense_param"; auto push_status = - worker_ptr_->push_dense_param(regions.data(), regions.size(), 0); + worker_ptr_->PushDenseParam(regions.data(), regions.size(), 0); push_status.wait(); - pull_status = worker_ptr_->pull_dense(regions.data(), regions.size(), 0); + pull_status = worker_ptr_->PullDense(regions.data(), regions.size(), 0); pull_status.wait(); for (size_t idx = 0; idx < tensor->numel(); ++idx) { @@ -257,11 +257,11 @@ void RunBrpcPushDense() { LOG(INFO) << "Run pull_dense_grad"; auto push_grad_status = - worker_ptr_->push_dense_raw_gradient(0, temp, tensor->numel(), closure); + worker_ptr_->PushDenseRawGradient(0, temp, tensor->numel(), closure); push_grad_status.wait(); auto pull_update_status = - worker_ptr_->pull_dense(regions.data(), regions.size(), 0); + worker_ptr_->PullDense(regions.data(), regions.size(), 0); pull_update_status.wait(); for (size_t idx = 0; idx < tensor->numel(); ++idx) { @@ -269,9 +269,9 @@ void RunBrpcPushDense() { } LOG(INFO) << "Run stop_server"; - worker_ptr_->stop_server(); + worker_ptr_->StopServer(); LOG(INFO) << "Run finalize_worker"; - worker_ptr_->finalize_worker(); + worker_ptr_->FinalizeWorker(); server_thread.join(); } diff --git a/paddle/fluid/distributed/test/brpc_service_sparse_sgd_test.cc b/paddle/fluid/distributed/test/brpc_service_sparse_sgd_test.cc index 633f3b2f3c550..f7d287af84472 100644 --- a/paddle/fluid/distributed/test/brpc_service_sparse_sgd_test.cc +++ b/paddle/fluid/distributed/test/brpc_service_sparse_sgd_test.cc @@ -156,14 +156,14 @@ void RunServer() { ::paddle::distributed::PSParameter server_proto = GetServerProto(); auto _ps_env = paddle::distributed::PaddlePSEnvironment(); - _ps_env.set_ps_servers(&host_sign_list_, 1); + _ps_env.SetPsServers(&host_sign_list_, 1); pserver_ptr_ = std::shared_ptr<paddle::distributed::PSServer>( - paddle::distributed::PSServerFactory::create(server_proto)); + paddle::distributed::PSServerFactory::Create(server_proto)); std::vector<framework::ProgramDesc> empty_vec; framework::ProgramDesc empty_prog; empty_vec.push_back(empty_prog); - pserver_ptr_->configure(server_proto, _ps_env, 0, empty_vec); - pserver_ptr_->start(ip_, port_); + pserver_ptr_->Configure(server_proto,
_ps_env, 0, empty_vec); + pserver_ptr_->Start(ip_, port_); } void RunClient(std::map<uint64_t, std::vector<paddle::distributed::Region>>& @@ -172,17 +172,17 @@ void RunClient(std::map<uint64_t, std::vector<paddle::distributed::Region>>& paddle::distributed::PaddlePSEnvironment _ps_env; auto servers_ = host_sign_list_.size(); _ps_env = paddle::distributed::PaddlePSEnvironment(); - _ps_env.set_ps_servers(&host_sign_list_, servers_); + _ps_env.SetPsServers(&host_sign_list_, servers_); worker_ptr_ = std::shared_ptr<paddle::distributed::PSClient>( - paddle::distributed::PSClientFactory::create(worker_proto)); - worker_ptr_->configure(worker_proto, dense_regions, _ps_env, 0); + paddle::distributed::PSClientFactory::Create(worker_proto)); + worker_ptr_->Configure(worker_proto, dense_regions, _ps_env, 0); } void RunBrpcPushSparse() { setenv("http_proxy", "", 1); setenv("https_proxy", "", 1); auto ph_host = paddle::distributed::PSHost(ip_, port_, 0); - host_sign_list_.push_back(ph_host.serialize_to_string()); + host_sign_list_.push_back(ph_host.SerializeToString()); // Start Server std::thread server_thread(RunServer); @@ -214,7 +214,7 @@ void RunBrpcPushSparse() { /*-----------------------Test Server Init----------------------------------*/ LOG(INFO) << "Run pull_sparse_param"; - auto pull_status = worker_ptr_->pull_sparse( + auto pull_status = worker_ptr_->PullSparse( fea_value_ptr.data(), 0, fea_keys.data(), fea_keys.size(), true); pull_status.wait(); for (size_t idx = 0; idx < tensor->numel(); ++idx) { @@ -237,12 +237,12 @@ void RunBrpcPushSparse() { } closure->set_promise_value(ret); }); - auto push_status = worker_ptr_->push_sparse_param( + auto push_status = worker_ptr_->PushSparseParam( 0, fea_keys.data(), (const float**)fea_value_ptr.data(), fea_keys.size(), closure_push_param); push_status.wait(); - auto pull_param_status = worker_ptr_->pull_sparse( + auto pull_param_status = worker_ptr_->PullSparse( fea_temp_value_ptr.data(), 0, fea_keys.data(), fea_keys.size(), true); pull_param_status.wait(); @@ -271,12 +271,12 @@ void RunBrpcPushSparse() { for (auto i = 0; i < static_cast<int>(fea_keys.size()); ++i) { push_g_vec.push_back(tensor->data<float>() + i * 10); } - auto push_grad_status = worker_ptr_->push_sparse_raw_gradient( + auto push_grad_status = worker_ptr_->PushSparseRawGradient( 0, fea_keys.data(), (const float**)push_g_vec.data(), fea_keys.size(), closure_push_grad); push_grad_status.wait(); - auto pull_update_status = worker_ptr_->pull_sparse( + auto pull_update_status = worker_ptr_->PullSparse( fea_temp_value_ptr.data(), 0, fea_keys.data(), fea_keys.size(), true); pull_update_status.wait(); @@ -285,9 +285,9 @@ void RunBrpcPushSparse() { } LOG(INFO) << "Run stop_server"; - worker_ptr_->stop_server(); + worker_ptr_->StopServer(); LOG(INFO) << "Run finalize_worker"; - worker_ptr_->finalize_worker(); + worker_ptr_->FinalizeWorker(); server_thread.join(); } diff --git a/paddle/fluid/distributed/test/dense_table_test.cc b/paddle/fluid/distributed/test/dense_table_test.cc index c9a038e000e14..49346c2898fc6 100644 --- a/paddle/fluid/distributed/test/dense_table_test.cc +++ b/paddle/fluid/distributed/test/dense_table_test.cc @@ -63,13 +63,13 @@ TEST(CommonDenseTable, Adam) { common_config->add_params("LearningRate"); common_config->add_dims(1); common_config->add_initializers("fill_constant&5e-6"); - auto ret = table->initialize(table_config, fs_config); + auto ret = table->Initialize(table_config, fs_config); ASSERT_EQ(ret, 0); // pull parameters for create and check std::vector<float> init_values; init_values.resize(fea_dim); - table->pull_dense(init_values.data(), fea_dim); + table->PullDense(init_values.data(), fea_dim); // push
gradient std::vector> trainer_gradient_values; @@ -85,12 +85,12 @@ TEST(CommonDenseTable, Adam) { // for adam for (int i = 0; i < trainers; i++) { auto &push_values = trainer_gradient_values[i]; - table->push_dense(push_values.data(), push_values.size()); + table->PushDense(push_values.data(), push_values.size()); } std::vector pull_values; pull_values.resize(fea_dim); - table->pull_dense(pull_values.data(), fea_dim); + table->PullDense(pull_values.data(), fea_dim); float mom_rate = 0.99; float decay_rate = 0.9999; @@ -118,6 +118,7 @@ TEST(CommonDenseTable, Adam) { } } for (int j = 0; j < fea_dim; j++) { + VLOG(0) << param[j] << " " << pull_values[j]; ASSERT_TRUE(abs(param[j] - pull_values[j]) < 1e-5); } } @@ -143,13 +144,13 @@ TEST(CommonDenseTable, SGD) { common_config->add_params("LearningRate"); common_config->add_dims(1); common_config->add_initializers("fill_constant&1.0"); - auto ret = table->initialize(table_config, fs_config); + auto ret = table->Initialize(table_config, fs_config); ASSERT_EQ(ret, 0); // pull parameters for create and check std::vector init_values; init_values.resize(fea_dim); - table->pull_dense(init_values.data(), fea_dim); + table->PullDense(init_values.data(), fea_dim); std::vector total_gradients; total_gradients.resize(fea_dim); @@ -172,7 +173,7 @@ TEST(CommonDenseTable, SGD) { for (int i = 0; i < trainers; i++) { auto &push_values = trainer_gradient_values[i]; auto task = [table, &push_values] { - table->push_dense(push_values.data(), push_values.size()); + table->PushDense(push_values.data(), push_values.size()); }; task_status.push_back(pool_->enqueue(std::move(task))); } @@ -182,7 +183,7 @@ TEST(CommonDenseTable, SGD) { std::vector pull_values; pull_values.resize(fea_dim); - table->pull_dense(pull_values.data(), fea_dim); + table->PullDense(pull_values.data(), fea_dim); for (int j = 0; j < fea_dim; j++) { auto update_val = init_values[j] - 1.0 * total_gradients[j]; ASSERT_TRUE(abs(update_val - pull_values[j]) < 1e-5); diff --git a/paddle/fluid/distributed/test/graph_node_split_test.cc b/paddle/fluid/distributed/test/graph_node_split_test.cc index a2f495de3c953..ce4f38f6cec9f 100644 --- a/paddle/fluid/distributed/test/graph_node_split_test.cc +++ b/paddle/fluid/distributed/test/graph_node_split_test.cc @@ -166,16 +166,16 @@ void RunServer() { ::paddle::distributed::PSParameter server_proto = GetServerProto(); auto _ps_env = paddle::distributed::PaddlePSEnvironment(); - _ps_env.set_ps_servers(&host_sign_list_, 2); // test + _ps_env.SetPsServers(&host_sign_list_, 2); // test pserver_ptr_ = std::shared_ptr( (paddle::distributed::GraphBrpcServer*) - paddle::distributed::PSServerFactory::create(server_proto)); + paddle::distributed::PSServerFactory::Create(server_proto)); std::vector empty_vec; framework::ProgramDesc empty_prog; empty_vec.push_back(empty_prog); - pserver_ptr_->configure(server_proto, _ps_env, 0, empty_vec); + pserver_ptr_->Configure(server_proto, _ps_env, 0, empty_vec); LOG(INFO) << "first server, run start(ip,port)"; - pserver_ptr_->start(ip_, port_); + pserver_ptr_->Start(ip_, port_); pserver_ptr_->build_peer2peer_connection(0); LOG(INFO) << "init first server Done"; } @@ -185,15 +185,15 @@ void RunServer2() { ::paddle::distributed::PSParameter server_proto2 = GetServerProto(); auto _ps_env2 = paddle::distributed::PaddlePSEnvironment(); - _ps_env2.set_ps_servers(&host_sign_list_, 2); // test + _ps_env2.SetPsServers(&host_sign_list_, 2); // test pserver_ptr2 = std::shared_ptr( (paddle::distributed::GraphBrpcServer*) - 
paddle::distributed::PSServerFactory::create(server_proto2)); + paddle::distributed::PSServerFactory::Create(server_proto2)); std::vector<framework::ProgramDesc> empty_vec2; framework::ProgramDesc empty_prog2; empty_vec2.push_back(empty_prog2); - pserver_ptr2->configure(server_proto2, _ps_env2, 1, empty_vec2); - pserver_ptr2->start(ip2, port2); + pserver_ptr2->Configure(server_proto2, _ps_env2, 1, empty_vec2); + pserver_ptr2->Start(ip2, port2); pserver_ptr2->build_peer2peer_connection(1); } @@ -204,11 +204,11 @@ void RunClient( paddle::distributed::PaddlePSEnvironment _ps_env; auto servers_ = host_sign_list_.size(); _ps_env = paddle::distributed::PaddlePSEnvironment(); - _ps_env.set_ps_servers(&host_sign_list_, servers_); + _ps_env.SetPsServers(&host_sign_list_, servers_); worker_ptr_ = std::shared_ptr<paddle::distributed::GraphBrpcClient>( (paddle::distributed::GraphBrpcClient*) - paddle::distributed::PSClientFactory::create(worker_proto)); - worker_ptr_->configure(worker_proto, dense_regions, _ps_env, 0); + paddle::distributed::PSClientFactory::Create(worker_proto)); + worker_ptr_->Configure(worker_proto, dense_regions, _ps_env, 0); worker_ptr_->set_shard_num(127); worker_ptr_->set_local_channel(index); worker_ptr_->set_local_graph_service( @@ -222,11 +222,11 @@ void RunGraphSplit() { prepare_file(node_file_name, nodes); prepare_file(graph_split_file_name, graph_split); auto ph_host = paddle::distributed::PSHost(ip_, port_, 0); - host_sign_list_.push_back(ph_host.serialize_to_string()); + host_sign_list_.push_back(ph_host.SerializeToString()); // test-start auto ph_host2 = paddle::distributed::PSHost(ip2, port2, 1); - host_sign_list_.push_back(ph_host2.serialize_to_string()); + host_sign_list_.push_back(ph_host2.SerializeToString()); // test-end // Start Server std::thread* server_thread = new std::thread(RunServer); @@ -247,7 +247,7 @@ void RunGraphSplit() { 0, std::string(graph_split_file_name)); pull_status.wait(); pull_status = - worker_ptr_->load(0, std::string(edge_file_name), std::string("e>")); + worker_ptr_->Load(0, std::string(edge_file_name), std::string("e>")); srand(time(0)); pull_status.wait(); std::vector> _vs; @@ -266,9 +266,9 @@ void RunGraphSplit() { std::remove(node_file_name); std::remove(graph_split_file_name); LOG(INFO) << "Run stop_server"; - worker_ptr_->stop_server(); + worker_ptr_->StopServer(); LOG(INFO) << "Run finalize_worker"; - worker_ptr_->finalize_worker(); + worker_ptr_->FinalizeWorker(); } TEST(RunGraphSplit, Run) { RunGraphSplit(); } diff --git a/paddle/fluid/distributed/test/graph_node_test.cc b/paddle/fluid/distributed/test/graph_node_test.cc index e55d39cd4834d..b2c741df7a5dd 100644 --- a/paddle/fluid/distributed/test/graph_node_test.cc +++ b/paddle/fluid/distributed/test/graph_node_test.cc @@ -348,16 +348,16 @@ void RunServer() { ::paddle::distributed::PSParameter server_proto = GetServerProto(); auto _ps_env = paddle::distributed::PaddlePSEnvironment(); - _ps_env.set_ps_servers(&host_sign_list_, 2); // test + _ps_env.SetPsServers(&host_sign_list_, 2); // test pserver_ptr_ = std::shared_ptr<paddle::distributed::GraphBrpcServer>( (paddle::distributed::GraphBrpcServer*) - paddle::distributed::PSServerFactory::create(server_proto)); + paddle::distributed::PSServerFactory::Create(server_proto)); std::vector<framework::ProgramDesc> empty_vec; framework::ProgramDesc empty_prog; empty_vec.push_back(empty_prog); - pserver_ptr_->configure(server_proto, _ps_env, 0, empty_vec); + pserver_ptr_->Configure(server_proto, _ps_env, 0, empty_vec); LOG(INFO) << "first server, run start(ip,port)"; - pserver_ptr_->start(ip_, port_); + pserver_ptr_->Start(ip_, port_);
pserver_ptr_->build_peer2peer_connection(0); LOG(INFO) << "init first server Done"; } @@ -367,15 +367,15 @@ void RunServer2() { ::paddle::distributed::PSParameter server_proto2 = GetServerProto(); auto _ps_env2 = paddle::distributed::PaddlePSEnvironment(); - _ps_env2.set_ps_servers(&host_sign_list_, 2); // test + _ps_env2.SetPsServers(&host_sign_list_, 2); // test pserver_ptr2 = std::shared_ptr( (paddle::distributed::GraphBrpcServer*) - paddle::distributed::PSServerFactory::create(server_proto2)); + paddle::distributed::PSServerFactory::Create(server_proto2)); std::vector empty_vec2; framework::ProgramDesc empty_prog2; empty_vec2.push_back(empty_prog2); - pserver_ptr2->configure(server_proto2, _ps_env2, 1, empty_vec2); - pserver_ptr2->start(ip2, port2); + pserver_ptr2->Configure(server_proto2, _ps_env2, 1, empty_vec2); + pserver_ptr2->Start(ip2, port2); pserver_ptr2->build_peer2peer_connection(1); } @@ -386,11 +386,11 @@ void RunClient( paddle::distributed::PaddlePSEnvironment _ps_env; auto servers_ = host_sign_list_.size(); _ps_env = paddle::distributed::PaddlePSEnvironment(); - _ps_env.set_ps_servers(&host_sign_list_, servers_); + _ps_env.SetPsServers(&host_sign_list_, servers_); worker_ptr_ = std::shared_ptr( (paddle::distributed::GraphBrpcClient*) - paddle::distributed::PSClientFactory::create(worker_proto)); - worker_ptr_->configure(worker_proto, dense_regions, _ps_env, 0); + paddle::distributed::PSClientFactory::Create(worker_proto)); + worker_ptr_->Configure(worker_proto, dense_regions, _ps_env, 0); worker_ptr_->set_shard_num(127); worker_ptr_->set_local_channel(index); worker_ptr_->set_local_graph_service( @@ -404,11 +404,11 @@ void RunBrpcPushSparse() { prepare_file(edge_file_name, 1); prepare_file(node_file_name, 0); auto ph_host = paddle::distributed::PSHost(ip_, port_, 0); - host_sign_list_.push_back(ph_host.serialize_to_string()); + host_sign_list_.push_back(ph_host.SerializeToString()); // test-start auto ph_host2 = paddle::distributed::PSHost(ip2, port2, 1); - host_sign_list_.push_back(ph_host2.serialize_to_string()); + host_sign_list_.push_back(ph_host2.SerializeToString()); // test-end // Srart Server std::thread* server_thread = new std::thread(RunServer); @@ -424,7 +424,7 @@ void RunBrpcPushSparse() { /*-----------------------Test Server Init----------------------------------*/ auto pull_status = - worker_ptr_->load(0, std::string(edge_file_name), std::string("e>")); + worker_ptr_->Load(0, std::string(edge_file_name), std::string("e>")); srand(time(0)); pull_status.wait(); std::vector> _vs; @@ -438,7 +438,7 @@ void RunBrpcPushSparse() { pull_status.wait(); ASSERT_EQ(0, _vs[0].size()); paddle::distributed::GraphTable* g = - (paddle::distributed::GraphTable*)pserver_ptr_->table(0); + (paddle::distributed::GraphTable*)pserver_ptr_->GetTable(0); size_t ttl = 6; g->make_neighbor_sample_cache(4, ttl); int round = 5; @@ -622,15 +622,15 @@ void RunBrpcPushSparse() { std::remove(node_file_name); testAddNode(worker_ptr_); LOG(INFO) << "Run stop_server"; - worker_ptr_->stop_server(); + worker_ptr_->StopServer(); LOG(INFO) << "Run finalize_worker"; - worker_ptr_->finalize_worker(); + worker_ptr_->FinalizeWorker(); testFeatureNodeSerializeInt(); testFeatureNodeSerializeInt64(); testFeatureNodeSerializeFloat32(); testFeatureNodeSerializeFloat64(); testGraphToBuffer(); - client1.stop_server(); + client1.StopServer(); } void testCache() { @@ -700,4 +700,4 @@ void testGraphToBuffer() { VLOG(0) << s1.get_feature(0); } -TEST(RunBrpcPushSparse, Run) { RunBrpcPushSparse(); } \ No newline at 
end of file +TEST(RunBrpcPushSparse, Run) { RunBrpcPushSparse(); } diff --git a/paddle/fluid/distributed/test/memory_geo_table_test.cc b/paddle/fluid/distributed/test/memory_geo_table_test.cc index fb48b38c76a28..965f67992d000 100644 --- a/paddle/fluid/distributed/test/memory_geo_table_test.cc +++ b/paddle/fluid/distributed/test/memory_geo_table_test.cc @@ -48,7 +48,7 @@ TEST(MemorySparseGeoTable, SSUM) { common_config->add_dims(emb_dim); common_config->add_initializers("fill_constant&1.0"); - auto ret = table->initialize(table_config, fs_config); + auto ret = table->Initialize(table_config, fs_config); ASSERT_EQ(ret, 0); // test push_sparse_param, and create params @@ -58,12 +58,12 @@ TEST(MemorySparseGeoTable, SSUM) { for (size_t i = 0; i < init_keys.size() * emb_dim; i++) { init_values.push_back(0.0); } - table->push_sparse_param(init_keys.data(), init_values.data(), - init_keys.size()); + table->PushSparseParam(init_keys.data(), init_values.data(), + init_keys.size()); std::vector pull_values(init_values.size()); auto value = PullSparseValue(init_keys, init_fres, emb_dim); - table->pull_sparse(pull_values.data(), value); + table->PullSparse(pull_values.data(), value); for (size_t i = 0; i < init_keys.size() * emb_dim; i++) { ASSERT_TRUE(abs(pull_values[i] - init_values[i]) < 1e-5); @@ -93,8 +93,7 @@ TEST(MemorySparseGeoTable, SSUM) { auto &push_keys = trainer_keys[i]; auto &push_values = trainer_values[i]; auto task = [table, &push_keys, &push_values] { - table->push_sparse(push_keys.data(), push_values.data(), - push_keys.size()); + table->PushSparse(push_keys.data(), push_values.data(), push_keys.size()); }; task_status.push_back(pool_->enqueue(std::move(task))); } @@ -107,7 +106,7 @@ TEST(MemorySparseGeoTable, SSUM) { geo_pull_ids.resize(trainers); geo_pull_values.resize(trainers); for (int i = 0; i < trainers; i++) { - table->pull_geo_param(i, &geo_pull_values[i], &geo_pull_ids[i]); + table->PullGeoParam(i, &geo_pull_values[i], &geo_pull_ids[i]); ASSERT_EQ(geo_pull_values[i].size(), geo_pull_ids[i].size() * emb_dim); for (size_t j = 0; j < geo_pull_ids[i].size(); ++j) { auto id = geo_pull_ids[i][j]; diff --git a/paddle/fluid/distributed/test/memory_sparse_table_test.cc b/paddle/fluid/distributed/test/memory_sparse_table_test.cc index aec02e8aec558..73fa7272280b2 100644 --- a/paddle/fluid/distributed/test/memory_sparse_table_test.cc +++ b/paddle/fluid/distributed/test/memory_sparse_table_test.cc @@ -36,7 +36,7 @@ TEST(MemorySparseTable, SGD) { table_config.set_shard_num(10); FsClientParameter fs_config; Table *table = new MemorySparseTable(); - table->set_shard(0, 1); + table->SetShard(0, 1); TableAccessorParameter *accessor_config = table_config.mutable_accessor(); accessor_config->set_accessor_class("CtrCommonAccessor"); @@ -66,7 +66,7 @@ TEST(MemorySparseTable, SGD) { naive_param->add_weight_bounds(-10.0); naive_param->add_weight_bounds(10.0); - auto ret = table->initialize(table_config, fs_config); + auto ret = table->Initialize(table_config, fs_config); ASSERT_EQ(ret, 0); // pull parameters for create and check @@ -76,7 +76,7 @@ TEST(MemorySparseTable, SGD) { std::vector init_values; init_values.resize(init_keys.size() * (emb_dim + 3)); auto value = PullSparseValue(init_keys, init_fres, emb_dim); - table->pull_sparse(init_values.data(), value); + table->PullSparse(init_values.data(), value); // for check std::vector total_gradients; @@ -109,8 +109,7 @@ TEST(MemorySparseTable, SGD) { auto &push_keys = trainer_keys[i]; auto &push_values = trainer_gradient_values[i]; auto task = 
[table, &push_keys, &push_values] { - table->push_sparse(push_keys.data(), push_values.data(), - push_keys.size()); + table->PushSparse(push_keys.data(), push_values.data(), push_keys.size()); }; task_status.push_back(pool_->enqueue(std::move(task))); } @@ -120,7 +119,7 @@ TEST(MemorySparseTable, SGD) { std::vector pull_values; pull_values.resize(init_keys.size() * (emb_dim + 3)); - table->pull_sparse(pull_values.data(), value); + table->PullSparse(pull_values.data(), value); for (size_t i = 0; i < init_keys.size(); ++i) { for (size_t j = 2; j < emb_dim + 3; ++j) { @@ -133,7 +132,7 @@ TEST(MemorySparseTable, SGD) { } MemorySparseTable *ctr_table = dynamic_cast(table); - ctr_table->save_local_fs("./work/table.save", "0", "test"); + ctr_table->SaveLocalFS("./work/table.save", "0", "test"); } } // namespace distributed diff --git a/paddle/fluid/distributed/test/table_test.cc b/paddle/fluid/distributed/test/table_test.cc index 6a29781158b83..8690aee39f69c 100644 --- a/paddle/fluid/distributed/test/table_test.cc +++ b/paddle/fluid/distributed/test/table_test.cc @@ -26,7 +26,7 @@ TEST(Table, Initialize) { FsClientParameter fs_config; // case 1. no accessor Table *table = new SparseGeoTable(); - auto ret = table->initialize(table_config, fs_config); + auto ret = table->Initialize(table_config, fs_config); ASSERT_EQ(ret, -1); } } // namespace distributed diff --git a/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc b/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc index 72f998a772764..75f5c24af5a99 100755 --- a/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc +++ b/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc @@ -343,7 +343,7 @@ void PSGPUWrapper::BuildPull(std::shared_ptr gpu_task) { #ifdef PADDLE_WITH_PSCORE int32_t cnt = 0; while (true) { - auto tt = fleet_ptr->worker_ptr_->pull_sparse_ptr( + auto tt = fleet_ptr->worker_ptr_->PullSparsePtr( reinterpret_cast(local_ptr[i].data()), this->table_id_, local_keys[i].data(), key_size); bool flag = true; diff --git a/paddle/fluid/framework/multi_trainer.cc b/paddle/fluid/framework/multi_trainer.cc index 83926336cbec8..61cd7ad01696e 100644 --- a/paddle/fluid/framework/multi_trainer.cc +++ b/paddle/fluid/framework/multi_trainer.cc @@ -276,7 +276,7 @@ void MultiTrainer::Finalize() { if (communicator == nullptr) { VLOG(0) << "MultiTrainer::Finalize communicator is null!"; } else { - communicator->_worker_ptr->flush(); + communicator->_worker_ptr->Flush(); VLOG(1) << "MultiTrainer::Finalize ps client flush done"; } #endif diff --git a/paddle/fluid/pybind/fleet_py.cc b/paddle/fluid/pybind/fleet_py.cc index befcf36b41c24..330719762ae08 100644 --- a/paddle/fluid/pybind/fleet_py.cc +++ b/paddle/fluid/pybind/fleet_py.cc @@ -86,11 +86,11 @@ void BindDistFleetWrapper(py::module* m) { void BindPSHost(py::module* m) { py::class_(*m, "PSHost") .def(py::init()) - .def("serialize_to_string", &distributed::PSHost::serialize_to_string) - .def("parse_from_string", &distributed::PSHost::parse_from_string) - .def("to_uint64", &distributed::PSHost::serialize_to_uint64) - .def("from_uint64", &distributed::PSHost::parse_from_uint64) - .def("to_string", &distributed::PSHost::to_string); + .def("serialize_to_string", &distributed::PSHost::SerializeToString) + .def("parse_from_string", &distributed::PSHost::ParseFromString) + .def("to_uint64", &distributed::PSHost::SerializeToUint64) + .def("from_uint64", &distributed::PSHost::ParseFromUint64) + .def("to_string", &distributed::PSHost::ToString); } void BindSparseShardingTools(py::module* m) { @@ -224,7 +224,7 @@ void 
BindGraphPyClient(py::module* m) { &GraphPyClient::use_neighbors_sample_cache) .def("remove_graph_node", &GraphPyClient::remove_graph_node) .def("random_sample_nodes", &GraphPyClient::random_sample_nodes) - .def("stop_server", &GraphPyClient::stop_server) + .def("stop_server", &GraphPyClient::StopServer) .def("get_node_feat", [](GraphPyClient& self, std::string node_type, std::vector node_ids, From 8df4622981339a61f9ecf4e09463a23205c75550 Mon Sep 17 00:00:00 2001 From: lilong12 Date: Sat, 2 Apr 2022 11:12:58 +0800 Subject: [PATCH 10/93] wrapper the usage of distributed functions (#39720) --- .../distributed/collective/ProcessGroup.h | 13 +- python/paddle/distributed/collective.py | 367 ++++++++---------- python/paddle/distributed/parallel.py | 99 ++++- python/paddle/fluid/dygraph/parallel.py | 7 +- .../fluid/tests/unittests/CMakeLists.txt | 3 + .../tests/unittests/init_process_group.py | 14 +- .../tests/unittests/process_group_nccl.py | 157 ++++++-- .../tests/unittests/test_eager_dist_api.py | 33 ++ .../tests/unittests/test_fleet_base_single.py | 2 +- ...t_parallel_dygraph_dataparallel_cpuonly.py | 3 + 10 files changed, 436 insertions(+), 262 deletions(-) create mode 100644 python/paddle/fluid/tests/unittests/test_eager_dist_api.py diff --git a/paddle/fluid/distributed/collective/ProcessGroup.h b/paddle/fluid/distributed/collective/ProcessGroup.h index 36a00a7d31758..c2ad1aa2c93ea 100644 --- a/paddle/fluid/distributed/collective/ProcessGroup.h +++ b/paddle/fluid/distributed/collective/ProcessGroup.h @@ -158,16 +158,17 @@ class ProcessGroupMapFromGid { } void insert(int gid, ProcessGroup* pg) { - PADDLE_ENFORCE_EQ(has(gid), false, - platform::errors::PreconditionNotMet( - "The process group with id %d doesnot exist.", gid)); + // PADDLE_ENFORCE_EQ(has(gid), false, + // platform::errors::PreconditionNotMet( + // "The process group with id %d does exist.", gid)); map_[gid] = pg; } ProcessGroup* get(int gid) { - PADDLE_ENFORCE_EQ(has(gid), false, - platform::errors::PreconditionNotMet( - "The process group with id %d doesnot exist.", gid)); + // PADDLE_ENFORCE_EQ(has(gid), true, + // platform::errors::PreconditionNotMet( + // "The process group with id %d doesnot exist.", + // gid)); return map_.find(gid)->second; } diff --git a/python/paddle/distributed/collective.py b/python/paddle/distributed/collective.py index 6dbd7d228eefa..ecd31386a2334 100644 --- a/python/paddle/distributed/collective.py +++ b/python/paddle/distributed/collective.py @@ -16,7 +16,9 @@ import os from datetime import timedelta from ..fluid.layer_helper import LayerHelper +import paddle.fluid.framework as framework from ..fluid.framework import Variable +from ..fluid.framework import in_dygraph_mode from ..fluid.framework import OpProtoHolder from ..fluid.framework import _non_static_mode from ..fluid.framework import convert_np_dtype_to_dtype_ @@ -174,10 +176,6 @@ def _new_ring_id(): return len(_get_group_map()) + max(_get_global_env().nrings, 9) -def _new_group_name_id(): - return len(_get_group_map_by_name()) + max(_get_global_env().nrings, 9) - - def get_group(id=0): """ @@ -202,194 +200,24 @@ def get_group(id=0): return gm[id] if id in gm else None -def _new_process_group_impl(backend, store, rank, world_size, group_name, - pg_options): - if backend == "gloo": - gloo_store = core.GlooStore(store) - +def _new_process_group_impl(backend, + store, + rank, + world_size, + group_name, + pg_options, + group_id=0): pg = None if backend == "gloo": - pg = core.ProcessGroupGloo(gloo_store, rank, world_size) + pg = 
core.ProcessGroupGloo(store, rank, world_size, group_id) elif backend == "nccl": - pg = core.ProcessGroupNCCL(store, rank, world_size) + pg = core.ProcessGroupNCCL(store, rank, world_size, group_id) elif backend == "hccl": - pg = core.ProcessGroupHCCL(store, rank, world_size) + pg = core.ProcessGroupHCCL(store, rank, world_size, group_id) return pg -def _init_parallel_env(rank=None, - world_size=None, - backend="nccl", - timeout=timedelta(0), - pg_options=None): - """ - - Initializes the default distributed environment. - - Args: - rank (int, optional): the rank of the current process or device from 0 to world_size (exclusive). - If you launch your training with paddle.distributed.run or - paddle.distributed.launch module, None can be given. Default: None. - world_size (int, optional): total number of processes or devices. - If you launch your training with paddle.distributed.run or - paddle.distributed.launch module, None can be given. Default: None. - backend (str, optional): the name of the backend used to initialize - the distributed environment. The value can be one of 'nccl' for - GPU, 'gloo' for CPU or 'hccl' for NPU. Default: 'nccl'. - timeout (datetime.timedelta, optional): timeout used for operations of - the group. Default: datetime.timedelta(0) which means no timeout. - pg_options (dict, optional): options for the group. Default: None. - - Returns: - Group: a group. - - Examples: - - .. code-block:: python - - # filename: train.py - import paddle - paddle.distributed.init_parallel_env(0, 1) - - # how to start - # python paddle.distributed.run --gpus="0,1" train.py - - """ - - global _group_map_by_name - global _default_group_name - assert _default_group_name not in _group_map_by_name, ( - "The default distributed environment has been initialized.") - - assert backend in _valid_backend_list, ( - "Backend must be one of {}, but the given one is: {}".format( - _valid_backend_list, backend)) - _default_backend = backend - - assert isinstance(timeout, timedelta), ( - "timeout must be of the type datetime.timedelta.") - - if rank is None or world_size is None: - assert rank is None and world_size is None, ( - "rank and world_size should be unset at the same time.") - trainer_id = os.getenv("PADDLE_TRAINER_ID", None) - trainer_num = os.getenv("PADDLE_TRAINERS_NUM", None) - if trainer_id is None or trainer_num is None: - warnings.warn("If rank and world_size are both None, please start " - "your training with paddle.distributed.run or " - "paddle.distributed.launch module. Otherwise, " - "init_parallel_env will do nothing.") - return None - rank = int(trainer_id) - world_size = int(trainer_num) - - assert rank >= 0 and world_size > rank and world_size > 1, ( - "rank must be non-negative and world_size must be the " - "maximum rank plus one. Moreover, at least two processes are " - "required to create a process group.") - - master_addr = os.getenv("MASTER_ADDR", None) - master_port = os.getenv("MASTER_PORT", None) - if not master_addr or not master_port: - endpoints = os.getenv("PADDLE_MASTER", None) - if endpoints is None: - endpoints = os.getenv("PADDLE_TRAINER_ENDPOINTS", None) - if not endpoints: - raise ValueError( - "The environment variable 'MASTER_ADDR' and 'MASTER_PORT' " - "must be specified, for example 'export MASTER_ADDR=127.0.0.1' " - "and 'export MASTER_ADDR=54612'. 
Or you can start your training" - "with paddle.distributed.run or " - "paddle.distributed.luanch module.") - if ',' in endpoints: - endpoints = endpoints.split(',')[0] - master_addr, master_port = endpoints.split(":") - - master_port = int(master_port) - - is_master = rank == 0 - global _default_store - _default_store = core.TCPStore(master_addr, master_port, is_master, - world_size, timeout) - - pg = _new_process_group_impl(backend, _default_store, rank, world_size, - _default_group_name, pg_options) - ranks = list(range(world_size)) - group = Group( - rank, world_size, id=0, ranks=ranks, pg=pg, name=_default_group_name) - - paddle.fluid.dygraph.parallel_helper._set_parallel_ctx(True) - _group_map_by_name[_default_group_name] = group - return group - - -def _new_group(ranks=None, - backend=None, - group_name=None, - timeout=timedelta(0), - pg_options=None): - """ - Create a new process group. - - Args: - ranks (list, optional): list of ranks for the new group. If None is given, - all processes is used. Default: None. - backend (str, optional): the name of the backend used to initialize - the distributed environment. Default: the one for init_parallel_env. - timeout (datetime.timedelta, optional): timeout used for operations of - the group. Default: datetime.timedelta(0). - pg_options (dict, optional): options for the group. Default: None. - - Examples: - - .. code-block:: python - - import paddle - paddle.distributed.init_parallel_env(0, 1) - paddle.distributed.new_group([0, 1]) - - # how to start - # python paddle.distributed.run --gpus="0,1" train.py - - """ - global _default_group_name - if group_name is None: - group_name = _default_group_name + str(_new_group_name_id()) - if group_name == _default_group_name: - raise ValueError("group_name must be specified and it cannot be '{}' " - "which is used for the default process group created " - "by init_parallel_env.".format(_default_group_name)) - global_group = _get_default_group() - global_rank = global_group.rank - global_ranks = global_group.ranks - if ranks is None: - ranks = global_ranks - assert len(ranks) <= len(global_ranks), ( - "Size of new group must be less than or " - "equal to that of the default global group.") - size = len(ranks) - assert size > 1, "A group must have at least two memebers." 
-    ranks = sorted(ranks)
-    if global_rank in ranks:
-        rank = ranks.index(global_rank)
-        pg = _new_process_group_impl(backend, _default_store, rank, size,
-                                     group_name, pg_options)
-    else:
-        rank = -1
-        pg = None
-    group = Group(
-        rank,
-        size,
-        id=_new_group_name_id(),
-        ranks=ranks,
-        pg=pg,
-        name=group_name)
-    _group_map_by_name[group_name] = group
-
-    return group
-
-
 def barrier(group=None):
     """
@@ -414,6 +242,12 @@ def barrier(group=None):
     if group is not None and not group.is_member():
         return
 
+    if framework._in_eager_mode_ and in_dygraph_mode():
+        group = _get_default_group() if group is None else group
+        task = group.process_group.barrier()
+        task.wait()
+        return
+
     ring_id = 0 if group is None else group.id
 
     temp = fill_constant([1], dtype="int32", value="1")
@@ -455,6 +289,40 @@ def new_group(ranks=None, backend=None):
         paddle.distributed.all_reduce(tindata, group=gp, use_calc_stream=False)
 
     """
+    global _group_map
+    if framework._in_eager_mode_:
+        global _default_group_name
+        gid = _new_ring_id()
+        group_name = _default_group_name + str(gid)
+        global_group = _get_default_group()
+        global_rank = global_group.rank
+        global_ranks = global_group.ranks
+        if ranks is None:
+            ranks = global_ranks
+        assert len(ranks) <= len(global_ranks), (
+            "Size of new group must be less than or "
+            "equal to that of the default global group.")
+        size = len(ranks)
+        assert size > 1, "A group must have at least two members."
+        ranks = sorted(ranks)
+        if global_rank in ranks:
+            rank = ranks.index(global_rank)
+            pg = _new_process_group_impl(
+                backend,
+                _default_store,
+                rank,
+                size,
+                group_name,
+                pg_options=None,
+                group_id=gid)
+        else:
+            rank = -1
+            pg = None
+        group = Group(rank, size, id=gid, ranks=ranks, pg=pg, name=group_name)
+        _group_map_by_name[group_name] = group
+        _group_map[gid] = group
+
+        return group
 
     if not backend:
         backend = 'nccl'
@@ -465,7 +333,6 @@ def new_group(ranks=None, backend=None):
 
     ring_id = _new_ring_id()
 
-    global _group_map
     if global_rank not in ranks:
         gp = Group(-1, -1, ring_id, ranks)
         _group_map[ring_id] = gp
@@ -628,7 +495,18 @@ def broadcast(tensor, src, group=None, use_calc_stream=True):
     if not isinstance(src, int):
         raise ValueError("src should be int.")
 
-    ring_id = 0 if group is None else group.id
+    if framework._in_eager_mode_ and in_dygraph_mode():
+        group = _get_default_group() if group is None else group
+        gsrc = group.get_group_rank(src)
+        assert gsrc >= 0, ("src rank out of group, need global rank")
+        task = group.process_group.broadcast(tensor, gsrc)
+        if use_calc_stream:
+            task.wait()
+            return None
+        else:
+            return task
+
+    ring_id = 0 if group is None else group.id
     gsrc = src if group is None else group.get_group_rank(src)
     assert gsrc >= 0, ("src rank out of group, need global rank")
@@ -701,6 +579,23 @@ def all_reduce(tensor, op=ReduceOp.SUM,
group=None, use_calc_stream=True): check_variable_and_dtype( tensor, 'tensor', ['float16', 'float32', 'float64', 'int32', 'int64'], 'all_reduce') - if not op in [ReduceOp.SUM, ReduceOp.MAX, ReduceOp.MIN, ReduceOp.PROD]: - raise ValueError("The op for all_reduce must be one of educeOp.PROD, " - "ReduceOp.SUM, ReduceOp.MAX, ReduceOp.MIN.") if op == ReduceOp.SUM: op_type = 'c_allreduce_sum' elif op == ReduceOp.MAX: @@ -789,8 +681,24 @@ def reduce(tensor, dst, op=ReduceOp.SUM, group=None, use_calc_stream=True): if group is not None and not group.is_member(): return - if not isinstance(dst, int): - raise ValueError("dst should be int.") + if framework._in_eager_mode_ and in_dygraph_mode(): + if op == ReduceOp.SUM: + op_type = core.ReduceOp.SUM + elif op == ReduceOp.MAX: + op_type = core.ReduceOp.MAX + elif op == ReduceOp.MIN: + op_type = core.ReduceOp.MIN + else: + raise ValueError("Unknown reduce_op type for reduce.") + group = _get_default_group() if group is None else group + gdst = group.get_group_rank(dst) + assert gdst >= 0, ("dst rank out of group, need global rank") + task = group.process_group.reduce(tensor, gdst, op_type) + if use_calc_stream: + task.wait() + return None + else: + return task ring_id = 0 if group is None else group.id gdst = dst if group is None else group.get_group_rank(dst) @@ -820,9 +728,6 @@ def reduce(tensor, dst, op=ReduceOp.SUM, group=None, use_calc_stream=True): check_variable_and_dtype( tensor, 'tensor', ['float16', 'float32', 'float64', 'int32', 'int64'], 'all_reduce') - if not op in [ReduceOp.SUM, ReduceOp.MAX, ReduceOp.MIN, ReduceOp.PROD]: - raise ValueError("The op for reduce must be one of educeOp.PROD, " - "ReduceOp.SUM, ReduceOp.MAX, ReduceOp.MIN.") if op == ReduceOp.SUM: op_type = 'c_reduce_sum' @@ -897,6 +802,15 @@ def all_gather(tensor_list, tensor, group=None, use_calc_stream=True): if group is not None and not group.is_member(): return + if framework._in_eager_mode_ and in_dygraph_mode(): + group = _get_default_group() if group is None else group + out = paddle.concat(tensor_list) + task = group.process_group.all_gather(tensor, out) + task.wait() + tensor_list.clear() + tensor_list.extend(paddle.split(out, group.nranks, 0)) + return + ring_id = 0 if group is None else group.id nranks = _get_global_group().nranks if group is None else group.nranks @@ -985,18 +899,32 @@ def scatter(tensor, tensor_list=None, src=0, group=None, use_calc_stream=True): if not isinstance(src, int): raise ValueError("src should be int.") - ring_id = 0 if group is None else group.id - gsrc = src if group is None else group.get_group_rank(src) + if framework._in_eager_mode_ and in_dygraph_mode(): + group = _get_default_group() if group is None else group + gsrc = group.get_group_rank(src) + rank = group.rank + nranks = group.nranks + else: + ring_id = 0 if group is None else group.id + gsrc = src if group is None else group.get_group_rank(src) + rank = _get_global_group().rank if group is None else group.rank + nranks = _get_global_group().nranks if group is None else group.nranks assert gsrc >= 0, ("src rank out of group, need global rank") - rank = _get_global_group().rank if group is None else group.rank - nranks = _get_global_group().nranks if group is None else group.nranks if rank != gsrc: tensor_list = [] for _ in range(nranks): tensor_list.append(tensor) temp = paddle.concat(tensor_list, axis=0) - if _non_static_mode(): + if framework._in_eager_mode_ and in_dygraph_mode(): + task = group.process_group.scatter(temp, tensor, gsrc) + if use_calc_stream: + task.wait() + 
return None + else: + return task + + if in_dygraph_mode(): return _C_ops.c_scatter(temp, tensor, 'use_calc_stream', use_calc_stream, 'ring_id', ring_id, 'nranks', nranks, 'root', gsrc) @@ -1070,11 +998,12 @@ def _c_concat(tensor, group=None): """ if group is not None and not group.is_member(): return - ring_id = 0 if group is None else group.id + group = _get_default_group() if group is None else group + ring_id = group.id global_rank = _get_global_env().rank - rank = global_rank if group is None else group.get_group_rank(global_rank) - nranks = _get_global_env().world_size if group is None else group.nranks + rank = group.rank + nranks = group.nranks if _non_static_mode(): return _C_ops.c_concat(tensor, 'ring_id', ring_id, 'use_calc_stream', @@ -1765,9 +1694,21 @@ def alltoall(in_tensor_list, out_tensor_list, group=None, use_calc_stream=True): if group is not None and not group.is_member(): return - ring_id = 0 if group is None else group.id + if framework._in_eager_mode_ and in_dygraph_mode(): + group = _get_default_group() if group is None else group + else: + ring_id = 0 if group is None else group.id + temp = paddle.concat(in_tensor_list, axis=0) nranks = len(in_tensor_list) + if framework._in_eager_mode_ and in_dygraph_mode(): + out = paddle.concat(out_tensor_list, axis=0) + task = group.process_group.alltoall(temp, out) + task.wait() + out_tensor_list.clear() + out_tensor_list.extend(paddle.split(out, nranks, 0)) + return + if _non_static_mode(): out = _C_ops.alltoall(temp, 'use_calc_stream', use_calc_stream, 'ring_id', ring_id) @@ -1834,6 +1775,16 @@ def send(tensor, dst=0, group=None, use_calc_stream=True): """ if group is not None and not group.is_member(): return + + if framework._in_eager_mode_ and in_dygraph_mode(): + group = _get_default_group() if group is None else group + task = group.process_group.send(tensor, dst) + if use_calc_stream: + task.wait() + return None + else: + return task + ring_id = 0 if group is None else group.id if _non_static_mode(): @@ -1887,6 +1838,16 @@ def recv(tensor, src=0, group=None, use_calc_stream=True): """ if group is not None and not group.is_member(): return + + if framework._in_eager_mode_ and in_dygraph_mode(): + group = _get_default_group() if group is None else group + task = group.process_group.recv(tensor, src) + if use_calc_stream: + task.wait() + return None + else: + return task + ring_id = 0 if group is None else group.id if _non_static_mode(): diff --git a/python/paddle/distributed/parallel.py b/python/paddle/distributed/parallel.py index 16ed528b64f0c..71ac15bd4b097 100644 --- a/python/paddle/distributed/parallel.py +++ b/python/paddle/distributed/parallel.py @@ -24,11 +24,21 @@ # deprecated module import from paddle.fluid import core +import paddle.fluid.framework as framework from paddle.fluid.framework import _set_expected_place from paddle.fluid.dygraph import parallel_helper from paddle.distributed.fleet.launch_utils import check_backend from paddle.fluid.dygraph.parallel import ParallelEnv from paddle.distributed.fleet.base.private_helper_function import wait_server_ready # noqa: F401 +import paddle.distributed.collective as collective +from paddle.distributed.collective import _group_map_by_name +from paddle.distributed.collective import _group_map +from paddle.distributed.collective import _default_group_name +from paddle.distributed.collective import _valid_backend_list +from paddle.distributed.collective import _default_backend +from paddle.distributed.collective import _default_store +from 
paddle.distributed.collective import _new_process_group_impl +from paddle.distributed.collective import Group __all__ = [] @@ -159,18 +169,88 @@ def train(): if not is_cpu_only and core.is_compiled_with_cuda(): _check_var_exists("FLAGS_selected_gpus") + backend = "nccl" if backend == "auto" else backend elif not is_cpu_only and core.is_compiled_with_xpu(): _check_var_exists('FLAGS_selected_xpus') + backend = "bkcl" if backend == "auto" else backend elif not is_cpu_only and core.is_compiled_with_npu(): _check_var_exists('FLAGS_selected_npus') + backend = "hccl" if backend == "auto" else backend elif not is_cpu_only and core.is_compiled_with_mlu(): _check_var_exists('FLAGS_selected_mlus') + backend = "cncl" if backend == "auto" else backend _check_var_exists("PADDLE_TRAINER_ID") _check_var_exists("PADDLE_CURRENT_ENDPOINT") _check_var_exists("PADDLE_TRAINERS_NUM") _check_var_exists("PADDLE_TRAINER_ENDPOINTS") + # NOTE(chenweihang): [ why config global place here? ] + # the dygraph mode will be set to default mode, + # users will not call `dygraph.guard` or `enable_dygraph` + # directly, if they want to switch default place, + # they need to call a function to change default place, + # here just set correctly place to users + if is_cpu_only: + place = core.CPUPlace() + elif core.is_compiled_with_cuda(): + place = core.CUDAPlace(parallel_env.device_id) + elif core.is_compiled_with_xpu(): + place = core.XPUPlace(parallel_env.device_id) + elif core.is_compiled_with_npu(): + place = core.NPUPlace(parallel_env.device_id) + elif core.is_compiled_with_mlu(): + place = core.MLUPlace(parallel_env.device_id) + + _set_expected_place(place) + + group = None + if backend in _valid_backend_list and framework._in_eager_mode_: + if _default_group_name in collective._group_map_by_name: + return collective._group_map_by_name[_default_group_name] + _default_backend = backend + rank = int(os.getenv("PADDLE_TRAINER_ID")) + world_size = int(os.getenv("PADDLE_TRAINERS_NUM")) + assert rank >= 0 and world_size > rank and world_size > 1, ( + "rank must be non-negative and world_size must be the " + "maximum rank plus one. Moreover, at least two processes are " + "required to create a process group.") + master_addr = os.getenv("MASTER_ADDR", None) + master_port = os.getenv("MASTER_PORT", None) + if not master_addr or not master_port: + endpoints = os.getenv("PADDLE_MASTER", None) + if endpoints is None: + endpoints = os.getenv("PADDLE_TRAINER_ENDPOINTS").split(',')[0] + assert endpoints, ( + "The environment variable 'MASTER_ADDR' and 'MASTER_PORT' " + "must be specified, for example 'export MASTER_ADDR=127.0.0.1' " + "and 'export MASTER_ADDR=54612'. 
Or you can start your training" + "with paddle.distributed.run module.") + master_addr, master_port = endpoints.split(":") + master_port = int(master_port) + is_master = rank == 0 + _default_store = core.TCPStore(master_addr, master_port, is_master, + world_size) + pg = _new_process_group_impl( + backend, + _default_store, + rank, + world_size, + _default_group_name, + pg_options=None) + ranks = list(range(world_size)) + group = Group( + rank, + world_size, + id=0, + ranks=ranks, + pg=pg, + name=_default_group_name) + collective._group_map_by_name[_default_group_name] = group + _group_map[0] = group + parallel_helper._set_parallel_ctx(True) + return group + node_num = set([i.split(":")[0] for i in parallel_env.trainer_endpoints]) # 3: init gloo context (step 1: httpsever start) init_gloo = int(os.getenv("PADDLE_WITH_GLOO", "0")) @@ -202,24 +282,6 @@ def train(): strategy.current_endpoint = parallel_env.current_endpoint strategy.nrings = parallel_env.nrings - # NOTE(chenweihang): [ why config global place here? ] - # the dygraph mode will be set to default mode, - # users will not call `dygraph.guard` or `enable_dygraph` - # directly, if they want to switch default place, - # they need to call a function to change default place, - # here just set correctly place to users - if is_cpu_only: - place = core.CPUPlace() - elif core.is_compiled_with_cuda(): - place = core.CUDAPlace(parallel_env.device_id) - elif core.is_compiled_with_xpu(): - place = core.XPUPlace(parallel_env.device_id) - elif core.is_compiled_with_npu(): - place = core.NPUPlace(parallel_env.device_id) - elif core.is_compiled_with_mlu(): - place = core.MLUPlace(parallel_env.device_id) - - _set_expected_place(place) # init nccl or hccl or bkcl or heter context if is_cpu_only: parallel_helper._set_parallel_ctx( @@ -274,6 +336,7 @@ def train(): if parallel_env.rank == 0: http_server_d["running"] = False http_server.join() + return group def get_rank(): diff --git a/python/paddle/fluid/dygraph/parallel.py b/python/paddle/fluid/dygraph/parallel.py index 64388aadb2f02..cac67a02ddec2 100644 --- a/python/paddle/fluid/dygraph/parallel.py +++ b/python/paddle/fluid/dygraph/parallel.py @@ -360,9 +360,10 @@ def sync_params_buffers(model, is_model_parallel=False): model_vars = [] for _, param in model._obtain_parameters_buffers().items(): - if not isinstance(param, core.VarBase): - raise TypeError("The data type of '%s' must be Varbase" % - param.name) + if not isinstance(param, (core.VarBase, core.eager.Tensor)): + raise TypeError( + "The data type of '%s' must be Varbase or eager.Tensor" % + param.name) # is_distributed param not need to sync when in mp mode if isinstance(param, ParamBase): diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index c816a8c4c231f..272ca806747ed 100755 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -60,6 +60,7 @@ list(APPEND DIST_TEST_OPS test_auto_parallel_data_unshard) list(APPEND DIST_TEST_OPS test_auto_parallel_save_load) list(APPEND DIST_TEST_OPS test_auto_parallel_autoconvert) list(APPEND DIST_TEST_OPS test_collective_process_group) +list(APPEND DIST_TEST_OPS test_eager_dist_api) set(MIXED_DIST_TEST_OPS ${DIST_TEST_OPS}) #remove distribute unittests. 
list(APPEND MIXED_DIST_TEST_OPS test_dgc_op) @@ -311,6 +312,7 @@ if ((NOT WITH_GPU) AND (NOT WITH_ROCM)) LIST(REMOVE_ITEM TEST_OPS test_auto_parallel_save_load) LIST(REMOVE_ITEM TEST_OPS test_auto_parallel_autoconvert) LIST(REMOVE_ITEM TEST_OPS test_collective_process_group) + LIST(REMOVE_ITEM TEST_OPS test_eager_dist_api) elseif(WITH_GPU) if (${CUDNN_VERSION} VERSION_LESS 7100) LIST(REMOVE_ITEM TEST_OPS test_conv2d_fusion_op) @@ -1147,6 +1149,7 @@ if(WITH_DISTRIBUTE AND WITH_GPU AND WITH_NCCL) set_tests_properties(test_auto_parallel_save_load PROPERTIES TIMEOUT 120) set_tests_properties(test_auto_parallel_autoconvert PROPERTIES TIMEOUT 120) set_tests_properties(test_collective_process_group PROPERTIES TIMEOUT 120) + set_tests_properties(test_eager_dist_api PROPERTIES TIMEOUT 300) if(${NCCL_VERSION} VERSION_GREATER_EQUAL 2212) set_tests_properties(test_parallel_dygraph_sparse_embedding PROPERTIES TIMEOUT 120) diff --git a/python/paddle/fluid/tests/unittests/init_process_group.py b/python/paddle/fluid/tests/unittests/init_process_group.py index 90926b1a021d3..c9c957572c515 100644 --- a/python/paddle/fluid/tests/unittests/init_process_group.py +++ b/python/paddle/fluid/tests/unittests/init_process_group.py @@ -37,11 +37,15 @@ def config(self): pass def test_init_process_group(self): - paddle.distributed.collective._init_parallel_env() - paddle.distributed.collective._new_group() - with self.assertRaises(ValueError): - paddle.distributed.collective._new_group( - backend="gloo", group_name="_default_pg") + with _test_eager_guard(): + paddle.distributed.init_parallel_env() + paddle.distributed.new_group() + group = paddle.distributed.new_group([-1, -2]) + assert group.process_group == None + + group = paddle.distributed.collective.Group(-1, 2, 0, [-1, -2]) + ret = paddle.distributed.barrier(group) + assert ret == None print("test ok\n") diff --git a/python/paddle/fluid/tests/unittests/process_group_nccl.py b/python/paddle/fluid/tests/unittests/process_group_nccl.py index b1da0777feb3d..7ae38b3bbc4d2 100644 --- a/python/paddle/fluid/tests/unittests/process_group_nccl.py +++ b/python/paddle/fluid/tests/unittests/process_group_nccl.py @@ -26,16 +26,16 @@ import paddle.fluid.core as core from paddle.fluid.framework import _test_eager_guard from paddle.fluid.dygraph.parallel import ParallelEnv +import paddle.distributed as dist def init_process_group(strategy=None): nranks = ParallelEnv().nranks rank = ParallelEnv().local_rank is_master = True if rank == 0 else False - store = paddle.fluid.core.TCPStore("127.0.0.1", 6173, is_master, nranks) - pg_group = core.ProcessGroupNCCL(store, rank, nranks) + pg_group = dist.init_parallel_env() - return pg_group + return pg_group.process_group class TestProcessGroupFp32(unittest.TestCase): @@ -68,12 +68,10 @@ def test_create_process_group_nccl(self): sum_result = tensor_x + tensor_y if pg.rank() == 0: - task = pg.allreduce(tensor_x) - task.wait() + task = dist.all_reduce(tensor_x) assert np.array_equal(tensor_x, sum_result) else: - task = pg.allreduce(tensor_y) - task.wait() + task = dist.all_reduce(tensor_y) assert np.array_equal(tensor_y, sum_result) print("test allreduce sum api ok") @@ -89,16 +87,41 @@ def test_create_process_group_nccl(self): max_result = paddle.maximum(tensor_x, tensor_y) if pg.rank() == 0: - task = pg.allreduce(tensor_x, core.ReduceOp.MAX) + task = dist.all_reduce( + tensor_x, dist.ReduceOp.MAX, use_calc_stream=False) task.wait() assert np.array_equal(tensor_x, max_result) else: - task = pg.allreduce(tensor_y, core.ReduceOp.MAX) + task = 
dist.all_reduce( + tensor_y, dist.ReduceOp.MAX, use_calc_stream=False) task.wait() assert np.array_equal(tensor_y, max_result) print("test allreduce max api ok") + # test allreduce min + # rank 0 + x = np.random.random(self.shape).astype(self.dtype) + tensor_x = paddle.to_tensor(x) + # rank 1 + y = np.random.random(self.shape).astype(self.dtype) + tensor_y = paddle.to_tensor(y) + + min_result = paddle.minimum(tensor_x, tensor_y) + + if pg.rank() == 0: + task = dist.all_reduce( + tensor_x, dist.ReduceOp.MIN, use_calc_stream=False) + task.wait() + assert np.array_equal(tensor_x, min_result) + else: + task = dist.all_reduce( + tensor_y, dist.ReduceOp.MIN, use_calc_stream=False) + task.wait() + assert np.array_equal(tensor_y, min_result) + + print("test allreduce min api ok") + # test broadcast # rank 0 x = np.random.random(self.shape).astype(self.dtype) @@ -109,16 +132,14 @@ def test_create_process_group_nccl(self): broadcast_result = paddle.assign(tensor_x) if pg.rank() == 0: - task = pg.broadcast(tensor_x, 0) + task = dist.broadcast(tensor_x, 0, use_calc_stream=False) task.synchronize() paddle.device.cuda.synchronize() assert task.is_completed() assert np.array_equal(broadcast_result, tensor_x) else: - task = pg.broadcast(tensor_y, 0) - task.synchronize() + task = dist.broadcast(tensor_y, 0) paddle.device.cuda.synchronize() - assert task.is_completed() assert np.array_equal(broadcast_result, tensor_y) print("test broadcast api ok") @@ -126,8 +147,7 @@ def test_create_process_group_nccl(self): # test barrier # rank 0 if pg.rank() == 0: - task = pg.barrier() - task.wait() + dist.barrier() # rank 1 else: task = pg.barrier() @@ -151,9 +171,13 @@ def test_create_process_group_nccl(self): paddle.device.cuda.synchronize() # rank 1 else: - task = pg.all_gather(tensor_y, tensor_out) - task.wait() + tensor_out_list = [ + paddle.empty_like(tensor_x), paddle.empty_like(tensor_x) + ] + task = dist.all_gather( + tensor_out_list, tensor_y, use_calc_stream=False) paddle.device.cuda.synchronize() + tensor_out = paddle.concat(tensor_out_list) out_1 = paddle.slice(tensor_out, [0], [0], [out_shape[0] // 2]) out_2 = paddle.slice(tensor_out, [0], [out_shape[0] // 2], [out_shape[0]]) @@ -178,12 +202,14 @@ def test_create_process_group_nccl(self): if pg.rank() == 0: task = pg.alltoall(tensor_x, tensor_out1) task.wait() - paddle.device.cuda.synchronize() # rank 1 else: - task = pg.alltoall(tensor_y, tensor_out2) - task.wait() + in_1, in_2 = paddle.split(tensor_y, 2) + out_1, out_2 = paddle.split(tensor_out2, 2) + out_tensor_list = [out_1, out_2] + task = dist.alltoall([in_1, in_2], out_tensor_list) paddle.device.cuda.synchronize() + tensor_out2 = paddle.concat(out_tensor_list) out1_2 = paddle.slice(tensor_out1, [0], [self.shape[0] // 2], [self.shape[0]]) out2_1 = paddle.slice(tensor_out2, [0], [0], [self.shape[0] // 2]) @@ -201,18 +227,61 @@ def test_create_process_group_nccl(self): tensor_y = paddle.to_tensor(y) sum_result = tensor_x + tensor_y if pg.rank() == 0: - task = pg.reduce(tensor_x, 0) - task.wait() + task = dist.reduce(tensor_x, 0, use_calc_stream=True) paddle.device.cuda.synchronize() # rank 1 else: - task = pg.reduce(tensor_y, 0) + task = dist.reduce(tensor_y, 0, use_calc_stream=False) task.wait() paddle.device.cuda.synchronize() if pg.rank() == 0: assert np.array_equal(tensor_x, sum_result) print("test reduce sum api ok\n") + # test reduce max + # rank 0 + x = np.random.random(self.shape).astype(self.dtype) + tensor_x = paddle.to_tensor(x) + # rank 1 + y = 
np.random.random(self.shape).astype(self.dtype)
+        tensor_y = paddle.to_tensor(y)
+
+        max_result = paddle.maximum(tensor_x, tensor_y)
+
+        if pg.rank() == 0:
+            task = dist.reduce(
+                tensor_x, 0, dist.ReduceOp.MAX, use_calc_stream=False)
+            task.wait()
+            assert np.array_equal(tensor_x, max_result)
+        else:
+            task = dist.reduce(
+                tensor_y, 0, dist.ReduceOp.MAX, use_calc_stream=False)
+            task.wait()
+
+        print("test reduce max api ok")
+
+        # test reduce min
+        # rank 0
+        x = np.random.random(self.shape).astype(self.dtype)
+        tensor_x = paddle.to_tensor(x)
+        # rank 1
+        y = np.random.random(self.shape).astype(self.dtype)
+        tensor_y = paddle.to_tensor(y)
+
+        min_result = paddle.minimum(tensor_x, tensor_y)
+
+        if pg.rank() == 0:
+            task = dist.reduce(
+                tensor_x, 0, dist.ReduceOp.MIN, use_calc_stream=False)
+            task.wait()
+            assert np.array_equal(tensor_x, min_result)
+        else:
+            task = dist.reduce(
+                tensor_y, 0, dist.ReduceOp.MIN, use_calc_stream=False)
+            task.wait()
+
+        print("test reduce min api ok")
+
         # test Scatter
         # rank 0
         in_shape = list(self.shape)
@@ -222,12 +291,14 @@ def test_create_process_group_nccl(self):
         tensor_x = paddle.to_tensor(x)
         tensor_y = paddle.to_tensor(y)
         if pg.rank() == 0:
-            task = pg.scatter(tensor_x, tensor_y, 0)
-            task.wait()
+            in_1, in_2 = paddle.split(tensor_x, 2)
+            task = dist.scatter(
+                tensor_y, [in_1, in_2], 0, use_calc_stream=True)
+            #task.wait()
             paddle.device.cuda.synchronize()
         # rank 1
         else:
-            task = pg.scatter(tensor_x, tensor_y, 0)
+            task = dist.scatter(tensor_y, [], 0, use_calc_stream=False)
             task.wait()
             paddle.device.cuda.synchronize()
         out1 = paddle.slice(tensor_x, [0], [0], [self.shape[0]])
@@ -239,6 +310,40 @@ def test_create_process_group_nccl(self):
             assert np.array_equal(tensor_y, out2)
         print("test scatter api ok\n")
 
+        # test send/recv (async)
+        # rank 0
+        x = np.random.random(self.shape).astype(self.dtype)
+        tensor_x = paddle.to_tensor(x)
+        # rank 1
+        y = np.random.random(self.shape).astype(self.dtype)
+        tensor_y = paddle.to_tensor(y)
+
+        if pg.rank() == 0:
+            task = dist.send(tensor_x, 1, use_calc_stream=False)
+            task.wait()
+        else:
+            task = dist.recv(tensor_y, 0, use_calc_stream=False)
+            task.wait()
+            assert np.array_equal(tensor_y, tensor_x)
+
+        print("test send api ok")
+
+        # test send/recv (sync)
+        # rank 0
+        x = np.random.random(self.shape).astype(self.dtype)
+        tensor_x = paddle.to_tensor(x)
+        # rank 1
+        y = np.random.random(self.shape).astype(self.dtype)
+        tensor_y = paddle.to_tensor(y)
+
+        if pg.rank() == 0:
+            task = dist.send(tensor_x, 1, use_calc_stream=True)
+        else:
+            task = dist.recv(tensor_y, 0, use_calc_stream=True)
+            assert np.array_equal(tensor_y, tensor_x)
+
+        print("test send api ok")
+
 
 class TestProcessGroupFp16(TestProcessGroupFp32):
     def setUp(self):
diff --git a/python/paddle/fluid/tests/unittests/test_eager_dist_api.py b/python/paddle/fluid/tests/unittests/test_eager_dist_api.py
new file mode 100644
index 0000000000000..e00f90f4b0d5f
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_eager_dist_api.py
@@ -0,0 +1,33 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +from test_parallel_dygraph_dataparallel import TestMultipleGpus + + +class TestProcessGroup(TestMultipleGpus): + def test_process_group_nccl(self): + self.run_mnist_2gpu('process_group_nccl.py') + + def test_process_group_gloo(self): + self.run_mnist_2gpu('process_group_gloo.py') + + def test_init_process_group(self): + self.run_mnist_2gpu('init_process_group.py') + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_fleet_base_single.py b/python/paddle/fluid/tests/unittests/test_fleet_base_single.py index 589d6adb0f52d..ff54035045b2e 100644 --- a/python/paddle/fluid/tests/unittests/test_fleet_base_single.py +++ b/python/paddle/fluid/tests/unittests/test_fleet_base_single.py @@ -46,7 +46,7 @@ def setUp(self): def test_dygraph_single(self): paddle.disable_static() - fleet.init(is_collective=True) + paddle.distributed.init_parallel_env() layer = LinearNet() loss_fn = nn.MSELoss() diff --git a/python/paddle/fluid/tests/unittests/test_parallel_dygraph_dataparallel_cpuonly.py b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_dataparallel_cpuonly.py index 587824a1dc74c..6c5a2375f6e51 100644 --- a/python/paddle/fluid/tests/unittests/test_parallel_dygraph_dataparallel_cpuonly.py +++ b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_dataparallel_cpuonly.py @@ -70,6 +70,9 @@ def start_local_trainers(cluster, "PADDLE_CURRENT_ENDPOINT": "%s" % t.endpoint, "PADDLE_TRAINERS_NUM": "%d" % cluster.trainers_nranks(), "PADDLE_TRAINER_ENDPOINTS": ",".join(cluster.trainers_endpoints()), + "MASTER_ADDR": "127.0.0.1", + "MASTER_PORT": "6170", + "NCCL_DEBUG": "INFO", "PADDLE_DISTRI_BACKEND": "gloo", # make init_parallel_env get 'gloo' argument. 
} From 7dd4a9fe686ea7ef31673e596d5b7eb1e601213c Mon Sep 17 00:00:00 2001 From: niuliling123 <51102941+niuliling123@users.noreply.github.com> Date: Sat, 2 Apr 2022 11:18:31 +0800 Subject: [PATCH 11/93] Fix a bug when reduceHigherDim in HIP (#41273) --- paddle/phi/kernels/funcs/reduce_function.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/paddle/phi/kernels/funcs/reduce_function.h b/paddle/phi/kernels/funcs/reduce_function.h index 0ee668c9ac1d9..39d708cad6b9b 100644 --- a/paddle/phi/kernels/funcs/reduce_function.h +++ b/paddle/phi/kernels/funcs/reduce_function.h @@ -808,7 +808,7 @@ __global__ void ReduceHigherDimKernel(const Tx* x, 1, 1, left_num); - kps::ElementwiseUnary( + kps::ElementwiseUnary( &reduce_compute, &reduce_input, transformer); kps::Reduce( + kps::ElementwiseUnary( &reduce_compute, &reduce_input, transformer); kps::Reduce Date: Sat, 2 Apr 2022 11:35:04 +0800 Subject: [PATCH 12/93] add topk cast (#41304) --- python/paddle/fluid/layers/nn.py | 8 +++++++- python/paddle/fluid/layers/tensor.py | 7 ++++++- python/paddle/fluid/tests/unittests/op_test.py | 2 ++ .../paddle/fluid/tests/unittests/test_cast_op.py | 15 +++++++++++++++ .../fluid/tests/unittests/test_reduce_op.py | 1 + .../fluid/tests/unittests/test_top_k_v2_op.py | 4 ++-- python/paddle/tensor/search.py | 6 ++++++ python/paddle/utils/code_gen/api.yaml | 3 ++- python/paddle/utils/code_gen/backward.yaml | 10 ++++++++++ 9 files changed, 51 insertions(+), 5 deletions(-) diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 0d2c1f14f2ddd..75583fb5c109a 100755 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -4791,7 +4791,7 @@ def reduce_prod(input, dim=None, keep_dim=False, name=None): fluid.layers.reduce_prod(y, dim=[1, 2]) # [24.0, 1680.0] fluid.layers.reduce_prod(y, dim=[0, 1]) # [105.0, 384.0] """ - helper = LayerHelper('reduce_prod', **locals()) + if dim is not None and not isinstance(dim, list): if isinstance(dim, tuple): dim = list(dim) @@ -4801,6 +4801,12 @@ def reduce_prod(input, dim=None, keep_dim=False, name=None): raise TypeError( "The type of axis must be int, list or tuple, but received {}". 
format(type(dim))) + if in_dygraph_mode(): + return _C_ops.final_state_reduce_prod( + input, dim if dim != None and dim != [] else [0], keep_dim, True if + dim == None or dim == [] or len(dim) == len(input.shape) else False) + + helper = LayerHelper('reduce_prod', **locals()) check_variable_and_dtype( input, 'input', ['float32', 'float64', 'int32', 'int64'], 'reduce_prod') out = helper.create_variable_for_type_inference(dtype=helper.input_dtype()) diff --git a/python/paddle/fluid/layers/tensor.py b/python/paddle/fluid/layers/tensor.py index 252e4931b39a4..ff7008fddd47d 100644 --- a/python/paddle/fluid/layers/tensor.py +++ b/python/paddle/fluid/layers/tensor.py @@ -21,7 +21,7 @@ from ..layer_helper import LayerHelper from ..param_attr import ParamAttr from ..initializer import Initializer -from ..framework import convert_np_dtype_to_dtype_, _non_static_mode, _varbase_creator, device_guard, _in_legacy_dygraph +from ..framework import convert_np_dtype_to_dtype_, _non_static_mode, _varbase_creator, device_guard, _in_legacy_dygraph, in_dygraph_mode from ..framework import Variable from ..initializer import Constant from ..core import VarDesc @@ -243,6 +243,11 @@ def cast(x, dtype): x = paddle.to_tensor([2, 3, 4], 'float64') y = paddle.cast(x, 'uint8') """ + if in_dygraph_mode(): + if not isinstance(dtype, core.VarDesc.VarType): + dtype = convert_np_dtype_to_dtype_(dtype) + return _C_ops.final_state_cast(x, dtype) + if _non_static_mode(): if not isinstance(dtype, core.VarDesc.VarType): dtype = convert_np_dtype_to_dtype_(dtype) diff --git a/python/paddle/fluid/tests/unittests/op_test.py b/python/paddle/fluid/tests/unittests/op_test.py index be883d243f795..1756537ba6240 100644 --- a/python/paddle/fluid/tests/unittests/op_test.py +++ b/python/paddle/fluid/tests/unittests/op_test.py @@ -1559,6 +1559,8 @@ def calculate_output(self): def _compare_numpy(self, name, actual_np, expect_np): with _test_eager_guard(): + print(actual_np) + print(expect_np) super()._compare_numpy(name, actual_np, expect_np) def convert_uint16_to_float_ifneed(self, actual_np, expect_np): diff --git a/python/paddle/fluid/tests/unittests/test_cast_op.py b/python/paddle/fluid/tests/unittests/test_cast_op.py index d80a9dc920076..a828eca4f4ba7 100644 --- a/python/paddle/fluid/tests/unittests/test_cast_op.py +++ b/python/paddle/fluid/tests/unittests/test_cast_op.py @@ -22,6 +22,7 @@ import paddle.fluid as fluid from paddle.fluid import compiler, Program, program_guard from op_test import OpTest, convert_uint16_to_float, convert_float_to_uint16 +from paddle.fluid.framework import _test_eager_guard class TestCastOpFp32ToFp64(OpTest): @@ -115,6 +116,20 @@ def test_errors(self): self.assertRaises(TypeError, fluid.layers.cast, x1, 'int32') +class TestCastOpEager(unittest.TestCase): + def test_eager(self): + with paddle.fluid.dygraph.base.guard(): + with _test_eager_guard(): + x = paddle.ones([2, 2], dtype="float16") + x.stop_gradient = False + out = paddle.cast(x, "float32") + self.assertTrue( + np.array_equal(out, np.ones([2, 2]).astype("float32"))) + out.backward() + self.assertTrue(np.array_equal(x.gradient(), x.numpy())) + self.assertTrue(x.gradient().dtype == np.float16) + + if __name__ == '__main__': paddle.enable_static() unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_reduce_op.py b/python/paddle/fluid/tests/unittests/test_reduce_op.py index 737e1af851fa7..98607fb07fedf 100644 --- a/python/paddle/fluid/tests/unittests/test_reduce_op.py +++ b/python/paddle/fluid/tests/unittests/test_reduce_op.py @@ -241,6 
+241,7 @@ def test_check_output(self): class TestProdOp(OpTest): def setUp(self): self.op_type = "reduce_prod" + self.python_api = paddle.prod self.init_data_type() self.inputs = {'X': np.random.random((5, 6, 10)).astype(self.data_type)} self.outputs = {'Out': self.inputs['X'].prod(axis=0)} diff --git a/python/paddle/fluid/tests/unittests/test_top_k_v2_op.py b/python/paddle/fluid/tests/unittests/test_top_k_v2_op.py index f1c4ca18da72b..c4f50414f954e 100644 --- a/python/paddle/fluid/tests/unittests/test_top_k_v2_op.py +++ b/python/paddle/fluid/tests/unittests/test_top_k_v2_op.py @@ -57,10 +57,10 @@ def setUp(self): self.outputs = {'Out': output, 'Indices': indices} def test_check_output(self): - self.check_output(check_eager=False) + self.check_output(check_eager=True) def test_check_grad(self): - self.check_grad(set(['X']), 'Out', check_eager=False) + self.check_grad(set(['X']), 'Out', check_eager=True) class TestTopkOp1(TestTopkOp): diff --git a/python/paddle/tensor/search.py b/python/paddle/tensor/search.py index 7a2dd22cff294..15c9e060c5517 100644 --- a/python/paddle/tensor/search.py +++ b/python/paddle/tensor/search.py @@ -858,6 +858,12 @@ def topk(x, k, axis=None, largest=True, sorted=True, name=None): """ + if in_dygraph_mode(): + if axis == None: + axis = -1 + out, indices = _C_ops.final_state_top_k(x, k, axis, largest, sorted) + return out, indices + if _non_static_mode(): if axis is None: out, indices = _C_ops.top_k_v2(x, 'k', diff --git a/python/paddle/utils/code_gen/api.yaml b/python/paddle/utils/code_gen/api.yaml index ef1e4797874a8..466c26d3f46c9 100644 --- a/python/paddle/utils/code_gen/api.yaml +++ b/python/paddle/utils/code_gen/api.yaml @@ -1123,7 +1123,8 @@ infer_meta : func : ReduceInferMetaBase kernel : - func : reduce_prod + func : prod_raw + backward : reduce_prod_grad - api : relu args : (Tensor x) diff --git a/python/paddle/utils/code_gen/backward.yaml b/python/paddle/utils/code_gen/backward.yaml index a59b02c34cf76..48faa4682d742 100644 --- a/python/paddle/utils/code_gen/backward.yaml +++ b/python/paddle/utils/code_gen/backward.yaml @@ -721,6 +721,16 @@ kernel : func : reciprocal_grad +- backward_api : reduce_prod_grad + forward : reduce_prod (Tensor x, int64_t[] dims, bool keep_dim, bool reduce_all) -> Tensor(out) + args : (Tensor x, Tensor out, Tensor out_grad, int64_t[] dims, bool keep_dim, bool reduce_all) + output : Tensor(x_grad) + infer_meta : + func : UnchangedInferMeta + param : [x] + kernel : + func : reduce_prod_grad + - backward_api : relu_double_grad forward : relu_grad (Tensor out, Tensor grad_out) -> Tensor(grad_x) args : (Tensor out, Tensor grad_x_grad) From d0f46aacbdf381bef3bae146f6b41c6d0ca5d6aa Mon Sep 17 00:00:00 2001 From: Liu-xiandong <85323580+Liu-xiandong@users.noreply.github.com> Date: Sat, 2 Apr 2022 12:40:02 +0800 Subject: [PATCH 13/93] [KP] fix bug in phi static graph mode (#41269) * [KP] fix bug in phi static graph mode * modify the useless code --- paddle/fluid/framework/operator.cc | 79 ++++++++++++++++++-- paddle/fluid/imperative/prepared_operator.cc | 7 +- 2 files changed, 80 insertions(+), 6 deletions(-) diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index 19fa0f66739ce..49248edd322d2 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -1293,16 +1293,54 @@ void OperatorWithKernel::RunImpl(const Scope& scope, } } else { pt_kernel_name = pt_kernel_signature_->name; +// NOTE(Liu-xiandong): The register kernel used KP have library_type[KP], +// But the 
default library_type is Plain, so we need to modify the
+// library_type here, otherwise it can't work.
+#ifdef PADDLE_WITH_XPU_KP
+    if (paddle::platform::is_xpu_place(kernel_type_->place_)) {
+      bool use_xpu_kp_kernel_rt =
+          FLAGS_run_kp_kernel &&
+          paddle::platform::is_xpu_kp_support_op(type_, *kernel_type_);
+      bool use_xpu_kp_kernel_debug =
+          paddle::platform::is_in_xpu_kpwhite_list(type_);
+      if (use_xpu_kp_kernel_rt) {
+        VLOG(3) << "phi xpu_kp using rt mode in static graph";
+      }
+      if (use_xpu_kp_kernel_debug) {
+        VLOG(3) << "phi xpu_kp using debug mode in static graph";
+      }
+      bool is_xpu_kp_support =
+          (use_xpu_kp_kernel_rt || use_xpu_kp_kernel_debug);
+      if (is_xpu_kp_support) {
+        auto expected_kernel_key_library_type = kernel_type_->library_type_;
+        kernel_type_->library_type_ = LibraryType::kKP;
+        VLOG(3) << "modifying XPU KP kernel in static graph: " << type_
+                << ", using_kernel_key:" << *kernel_type_.get();
+        auto try_pt_kernel_key =
+            TransOpKernelTypeToPhiKernelKey(*kernel_type_.get());
+        if (!phi::KernelFactory::Instance().IsSelectKernelValid(
+                pt_kernel_name, try_pt_kernel_key)) {
+          kernel_type_->library_type_ = expected_kernel_key_library_type;
+          VLOG(3) << "modifying XPU KP kernel in static graph: " << type_
+                  << " failed " << *kernel_type_.get();
+        }
+      }
+    }
+#endif
     pt_kernel_key = TransOpKernelTypeToPhiKernelKey(*kernel_type_.get());
   }
-#ifdef PADDLE_WITH_XPU
+
+// NOTE(Liu-xiandong): Determine whether the selected kernel is valid.
+// If not, use the kernel registered in fluid; and if fluid does not
+// contain the related heterogeneous kernel, use the phi CPU kernel.
+#if defined(PADDLE_WITH_XPU) && !defined(PADDLE_WITH_XPU_KP)
   bool is_xpu_unsupport =
       paddle::platform::is_xpu_place(kernel_type_->place_) &&
           !paddle::platform::is_xpu_support_op(type_, *kernel_type_.get()) ||
       paddle::platform::is_in_xpu_black_list(type_);
 #endif
   if (pt_kernel_->IsValid()
-#ifdef PADDLE_WITH_XPU
+#if defined(PADDLE_WITH_XPU) && !defined(PADDLE_WITH_XPU_KP)
       && !is_xpu_unsupport
 #endif
       ) {
@@ -1310,10 +1348,29 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
   } else {
     auto& all_op_kernels = AllOpKernels();
     auto kernels_iter = all_op_kernels.find(type_);
+
+// NOTE(Liu-xiandong): If we can't find the heterogeneous kernel in phi,
+// we need to select the heterogeneous kernel in fluid, but the kernel
+// registered for KP uses library_type[KP], so we need to modify it.
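// For orientation, the probe-and-rollback pattern added earlier in this hunk
// boils down to the sketch below. It reuses identifiers from the diff
// (kernel_type_, pt_kernel_name); saved_library_type is a renamed local, and
// the is_xpu_kp_support guard and logging are omitted, so treat it as an
// illustrative summary rather than part of the patch:
auto saved_library_type = kernel_type_->library_type_;
kernel_type_->library_type_ = LibraryType::kKP;  // tentatively select KP
auto try_key = TransOpKernelTypeToPhiKernelKey(*kernel_type_.get());
if (!phi::KernelFactory::Instance().IsSelectKernelValid(pt_kernel_name,
                                                        try_key)) {
  // phi has no KP kernel registered for this op: restore the original key.
  kernel_type_->library_type_ = saved_library_type;
}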
+#ifdef PADDLE_WITH_XPU_KP
+      bool use_xpu_kp_kernel_rt =
+          paddle::platform::is_xpu_place(kernel_type_->place_) &&
+          FLAGS_run_kp_kernel &&
+          paddle::platform::is_xpu_kp_support_op(type_, *kernel_type_);
+      bool use_xpu_kp_kernel_debug =
+          paddle::platform::is_xpu_place(kernel_type_->place_) &&
+          paddle::platform::is_in_xpu_kpwhite_list(type_);
+      bool is_xpu_kp_support =
+          (use_xpu_kp_kernel_rt || use_xpu_kp_kernel_debug);
+      if (is_xpu_kp_support) {
+        kernel_type_->library_type_ = LibraryType::kKP;
+      }
+#endif
+
     if (kernels_iter == all_op_kernels.end() ||
         kernels_iter->second.find(*kernel_type_.get()) ==
             kernels_iter->second.end()
-#ifdef PADDLE_WITH_XPU
+#if defined(PADDLE_WITH_XPU) && !defined(PADDLE_WITH_XPU_KP)
         || is_xpu_unsupport
 #endif
         ) {
@@ -1552,10 +1609,22 @@ void OperatorWithKernel::ChooseKernel(const ExecutionContext& ctx) const {
   }
   bool is_xpu_kp_support = (use_xpu_kp_kernel_rt || use_xpu_kp_kernel_debug);
   if (is_xpu_kp_support) {
+    auto cache_expected_kernel_key_library_type =
+        expected_kernel_key.library_type_;
     expected_kernel_key.library_type_ = LibraryType::kKP;
     kernel_iter = kernels.find(expected_kernel_key);
-    VLOG(3) << "using XPU KP kernel: " << type_
-            << ", using_kernel_key:" << expected_kernel_key;
+    // If no KP kernel is found here while is_xpu_kp_support is on, fluid has
+    // not registered the related kernel either; it cannot work and would
+    // raise the same error as before, so fall back to the original key.
+    if (kernel_iter == kernels.end()) {
+      expected_kernel_key.library_type_ =
+          cache_expected_kernel_key_library_type;
+      expected_kernel_key.place_ = platform::CPUPlace();
+      kernel_iter = kernels.find(expected_kernel_key);
+    } else {
+      VLOG(3) << "using XPU KP kernel: " << type_
+              << ", using_kernel_key:" << expected_kernel_key;
+    }
   }
   bool is_xpu_unsupport =
       (!paddle::platform::is_xpu_support_op(type_, expected_kernel_key) ||
diff --git a/paddle/fluid/imperative/prepared_operator.cc b/paddle/fluid/imperative/prepared_operator.cc
index 077dd54bc9fa5..b56d113937d69 100644
--- a/paddle/fluid/imperative/prepared_operator.cc
+++ b/paddle/fluid/imperative/prepared_operator.cc
@@ -174,7 +174,9 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
   VLOG(6) << pt_kernel_signature;
   pt_kernel_name = pt_kernel_signature.name;
-// modify the expected_kernel_key for KP in phi
+// NOTE(Liu-xiandong): Kernels registered for KP have library_type[KP],
+// but the default library_type is Plain, so we need to modify the
+// library_type here, otherwise it can't work.
 #ifdef PADDLE_WITH_XPU_KP
   if (paddle::platform::is_xpu_place(expected_kernel_key.place_)) {
     bool use_xpu_kp_kernel_rt =
@@ -238,6 +240,9 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
   auto& all_op_kernels = op.AllOpKernels();
   auto kernels_iter = all_op_kernels.find(op.Type());
+
+// NOTE(Liu-xiandong): If we can't find the heterogeneous kernel in phi,
+// we need to select the heterogeneous kernel in fluid, but the kernel
+// registered for KP uses library_type[KP], so we need to modify it.
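// The same KP gating recurs in the static graph path above and the dygraph
// path below. Distilled into a self-contained sketch at namespace scope
// (Lib and ChooseLibrary are hypothetical names that do not exist in this
// diff; they only restate the three booleans the real code computes):
enum class Lib { kPlain, kKP };

inline Lib ChooseLibrary(bool run_kp_flag,         // FLAGS_run_kp_kernel
                         bool op_supports_kp,      // is_xpu_kp_support_op(...)
                         bool in_kp_white_list) {  // is_in_xpu_kpwhite_list(...)
  const bool use_rt = run_kp_flag && op_supports_kp;  // normal runtime path
  const bool use_debug = in_kp_white_list;            // forced debug path
  return (use_rt || use_debug) ? Lib::kKP : Lib::kPlain;
}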
#ifdef PADDLE_WITH_XPU_KP bool use_xpu_kp_kernel_rt = paddle::platform::is_xpu_place(expected_kernel_key.place_) && From 66d1b1f6b0b554040bc6b30eced5cfad459f555b Mon Sep 17 00:00:00 2001 From: Shang Zhizhou Date: Sat, 2 Apr 2022 12:47:47 +0800 Subject: [PATCH 14/93] update infrt build parallel (#41278) --- paddle/scripts/infrt_build.sh | 1 + 1 file changed, 1 insertion(+) diff --git a/paddle/scripts/infrt_build.sh b/paddle/scripts/infrt_build.sh index 6634f5396ac74..6b0611bf61cdc 100755 --- a/paddle/scripts/infrt_build.sh +++ b/paddle/scripts/infrt_build.sh @@ -81,6 +81,7 @@ function init() { } function infrt_gen_and_build() { + parallel_number=24 if [ "$1" != "" ]; then parallel_number=$1 fi From 5d3fd4fee7df4c2dda48212d263fc7d5ac6f6260 Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Sat, 2 Apr 2022 13:53:41 +0800 Subject: [PATCH 15/93] Sparse conv and pool support indices as template (#41137) --- paddle/phi/kernels/empty_kernel.cc | 4 + paddle/phi/kernels/funcs/sparse/convolution.h | 37 +-- .../kernels/sparse/convolution_grad_kernel.h | 4 +- .../phi/kernels/sparse/convolution_kernel.h | 6 +- paddle/phi/kernels/sparse/cpu/convolution.h | 75 +++--- .../sparse/cpu/convolution_grad_kernel.cc | 131 ++++++---- .../kernels/sparse/cpu/convolution_kernel.cc | 96 ++++--- .../sparse/cpu/sparse_pool_grad_kernel.cc | 55 +++- .../kernels/sparse/cpu/sparse_pool_kernel.cc | 72 ++++-- .../phi/kernels/sparse/gpu/convolution.cu.h | 241 +++++++++--------- .../sparse/gpu/convolution_grad_kernel.cu | 143 +++++++---- .../kernels/sparse/gpu/convolution_kernel.cu | 117 +++++---- .../sparse/gpu/sparse_pool_grad_kernel.cu | 77 ++++-- .../kernels/sparse/gpu/sparse_pool_kernel.cu | 99 ++++--- .../kernels/sparse/sparse_pool_grad_kernel.h | 20 +- .../phi/kernels/sparse/sparse_pool_kernel.h | 6 +- .../kernels/test_sparse_conv3d_dev_api.cc | 148 +++++++---- .../tests/kernels/test_sparse_pool_dev_api.cc | 120 +++++---- 18 files changed, 862 insertions(+), 589 deletions(-) diff --git a/paddle/phi/kernels/empty_kernel.cc b/paddle/phi/kernels/empty_kernel.cc index e547e0ea1318d..06d258a8a4e80 100644 --- a/paddle/phi/kernels/empty_kernel.cc +++ b/paddle/phi/kernels/empty_kernel.cc @@ -45,6 +45,7 @@ PD_REGISTER_KERNEL(empty, phi::EmptyKernel, float, double, + int8_t, uint8_t, int16_t, int, @@ -61,6 +62,7 @@ PD_REGISTER_KERNEL(empty_like, phi::EmptyLikeKernel, float, double, + int8_t, uint8_t, int16_t, int, @@ -80,6 +82,7 @@ PD_REGISTER_KERNEL(empty, phi::EmptyKernel, float, double, + int8_t, uint8_t, int16_t, int, @@ -95,6 +98,7 @@ PD_REGISTER_KERNEL(empty_like, phi::EmptyLikeKernel, float, double, + int8_t, uint8_t, int16_t, int, diff --git a/paddle/phi/kernels/funcs/sparse/convolution.h b/paddle/phi/kernels/funcs/sparse/convolution.h index 19f1f3d3cd2fa..f3caa2a62f4a8 100644 --- a/paddle/phi/kernels/funcs/sparse/convolution.h +++ b/paddle/phi/kernels/funcs/sparse/convolution.h @@ -33,28 +33,30 @@ struct Dims4D { }; // Judge whether the current position x is in (lower, upper) -inline HOSTDEVICE bool Check(const int& x, +template +inline HOSTDEVICE bool Check(const IntT& x, const int& kx, const int& pad, const int& stride, const int dilation, const int kdim, const int xdim) { - const int lower = x - dilation * kx + pad; - const int uper = x + (kdim - kx - 1) * dilation - pad; + const IntT lower = x - dilation * kx + pad; + const IntT uper = x + (kdim - kx - 1) * dilation - pad; return (lower >= 0 && lower % stride == 0 && uper < xdim); } // Check whether the current position(x, y, z) is legal: // Judge the minimum and 
maximum values at each latitude +template inline HOSTDEVICE bool Check(const Dims4D& dims, const Dims4D& kernel_dims, const Dims4D& paddings, const Dims4D& dilations, const Dims4D& strides, - const int x, - const int y, - const int z, + const IntT x, + const IntT y, + const IntT z, const int kx, const int ky, const int kz) { @@ -67,22 +69,22 @@ inline HOSTDEVICE bool Check(const Dims4D& dims, return (x_valid && y_valid && z_valid); } -template -inline HOSTDEVICE int PointToIndex(const int& batch, - const int& x, - const int& y, - const int& z, - const Dim& dims) { +template +inline HOSTDEVICE IntT PointToIndex(const IntT& batch, + const IntT& x, + const IntT& y, + const IntT& z, + const Dim& dims) { return batch * dims[1] * dims[2] * dims[3] + z * dims[2] * dims[3] + y * dims[3] + x; } // TODO(zhangkaihuo): use division and multiply to optimize // modulo operation -template +template inline HOSTDEVICE void IndexToPoint( - const int index, const Dim& dims, int* batch, int* x, int* y, int* z) { - int n = index; + const IntT index, const Dim& dims, IntT* batch, IntT* x, IntT* y, IntT* z) { + IntT n = index; *x = n % dims[3]; n /= dims[3]; *y = n % dims[2]; @@ -176,8 +178,9 @@ inline const std::vector PoolResetKernel( return res; } -inline void PrefixSum(const int* counter, int* offsets, const int n) { - int offset = 0; +template +inline void PrefixSum(const T* counter, T* offsets, const int n) { + T offset = 0; for (int i = 0; i < n; i++) { offsets[i] = offset; offset += counter[i]; diff --git a/paddle/phi/kernels/sparse/convolution_grad_kernel.h b/paddle/phi/kernels/sparse/convolution_grad_kernel.h index 5a47575141a2d..eebfcddfc7a9e 100644 --- a/paddle/phi/kernels/sparse/convolution_grad_kernel.h +++ b/paddle/phi/kernels/sparse/convolution_grad_kernel.h @@ -49,8 +49,8 @@ std::tuple Conv3dGrad( const int groups, const bool subm) { SparseCooTensor x_grad; - DenseTensor kernel_grad = phi::Empty( - dev_ctx, DenseTensorMeta(kernel.dtype(), {1}, kernel.layout())); + DenseTensor kernel_grad; + // TODO(zhangkaihuo): call InferMeta func here Conv3dGradKernel(dev_ctx, x, diff --git a/paddle/phi/kernels/sparse/convolution_kernel.h b/paddle/phi/kernels/sparse/convolution_kernel.h index ff2cf94edb5a3..6120d6339a7eb 100644 --- a/paddle/phi/kernels/sparse/convolution_kernel.h +++ b/paddle/phi/kernels/sparse/convolution_kernel.h @@ -45,11 +45,7 @@ SparseCooTensor Conv3d(const Context& dev_ctx, const int groups, const bool subm, DenseTensor* rulebook) { - DenseTensor indices = phi::Empty( - dev_ctx, DenseTensorMeta(DataType::INT32, {1}, DataLayout::NCHW)); - DenseTensor values = - phi::Empty(dev_ctx, DenseTensorMeta(x.dtype(), {1}, x.layout())); - SparseCooTensor coo(indices, values, x.dims()); + SparseCooTensor coo; Conv3dKernel(dev_ctx, x, kernel, diff --git a/paddle/phi/kernels/sparse/cpu/convolution.h b/paddle/phi/kernels/sparse/cpu/convolution.h index 4ea93f4ad5aaf..b2544619774c2 100644 --- a/paddle/phi/kernels/sparse/cpu/convolution.h +++ b/paddle/phi/kernels/sparse/cpu/convolution.h @@ -31,7 +31,7 @@ using Dims4D = phi::funcs::sparse::Dims4D; // such as: kernel(3, 3, 3), kernel_size = 27 // counter_per_weight: (kernel_size) // TODO(zhangkaihuo): optimize performance with multithreading -template +template void ProductRuleBook(const Context& dev_ctx, const SparseCooTensor& x, const std::vector& kernel_sizes, @@ -44,7 +44,7 @@ void ProductRuleBook(const Context& dev_ctx, DenseTensor* counter_per_kernel) { const int64_t non_zero_num = x.nnz(); const auto& non_zero_indices = x.non_zero_indices(); - const 
int* indices_ptr = non_zero_indices.data(); + const IntT* indices_ptr = non_zero_indices.data(); int* counter_ptr = counter_per_kernel->data(); int kernel_size = kernel_sizes[0] * kernel_sizes[1] * kernel_sizes[2]; memset(counter_ptr, 0, kernel_size * sizeof(int)); @@ -60,33 +60,33 @@ void ProductRuleBook(const Context& dev_ctx, const Dims4D c_strides(1, strides[2], strides[1], strides[0]); const Dims4D c_dilations(1, dilations[2], dilations[1], dilations[0]); - std::set hash_in; + std::set hash_in; if (subm) { for (int i = 0; i < non_zero_num; i++) { - int batch = indices_ptr[i]; - int in_z = indices_ptr[i + non_zero_num]; - int in_y = indices_ptr[i + 2 * non_zero_num]; - int in_x = indices_ptr[i + 3 * non_zero_num]; - int index = phi::funcs::sparse::PointToIndex( + IntT batch = indices_ptr[i]; + IntT in_z = indices_ptr[i + non_zero_num]; + IntT in_y = indices_ptr[i + 2 * non_zero_num]; + IntT in_x = indices_ptr[i + 3 * non_zero_num]; + IntT index = phi::funcs::sparse::PointToIndex( batch, in_x, in_y, in_z, x_dims); hash_in.insert(index); } } - auto f_calc_rulebook = [&](int* rulebook_ptr) { + auto f_calc_rulebook = [&](IntT* rulebook_ptr) { int kernel_index = 0, rulebook_index = 0; for (int kz = 0; kz < kernel_sizes[0]; kz++) { for (int ky = 0; ky < kernel_sizes[1]; ky++) { for (int kx = 0; kx < kernel_sizes[2]; kx++) { ++kernel_index; for (int64_t i = 0; i < non_zero_num; i++) { - int batch = indices_ptr[i]; - int in_z = indices_ptr[i + non_zero_num]; - int in_y = indices_ptr[i + 2 * non_zero_num]; - int in_x = indices_ptr[i + 3 * non_zero_num]; - int out_z = (in_z + paddings[0] - kz * dilations[0]) / strides[0]; - int out_y = (in_y + paddings[1] - ky * dilations[1]) / strides[1]; - int out_x = (in_x + paddings[2] - kx * dilations[2]) / strides[2]; + IntT batch = indices_ptr[i]; + IntT in_z = indices_ptr[i + non_zero_num]; + IntT in_y = indices_ptr[i + 2 * non_zero_num]; + IntT in_x = indices_ptr[i + 3 * non_zero_num]; + IntT out_z = (in_z + paddings[0] - kz * dilations[0]) / strides[0]; + IntT out_y = (in_y + paddings[1] - ky * dilations[1]) / strides[1]; + IntT out_x = (in_x + paddings[2] - kx * dilations[2]) / strides[2]; if (phi::funcs::sparse::Check(c_x_dims, c_kernel_dims, c_paddings, @@ -99,7 +99,7 @@ void ProductRuleBook(const Context& dev_ctx, ky, kz)) { if (subm) { - int out_index = phi::funcs::sparse::PointToIndex( + IntT out_index = phi::funcs::sparse::PointToIndex( batch, out_x, out_y, out_z, out_dims); if (hash_in.find(out_index) == hash_in.end()) { continue; @@ -126,15 +126,16 @@ void ProductRuleBook(const Context& dev_ctx, f_calc_rulebook(nullptr); // alloc the rulebook - DenseTensorMeta rulebook_meta( - DataType::INT32, {3, rulebook_len}, DataLayout::NCHW); - rulebook->set_meta(rulebook_meta); - dev_ctx.Alloc(rulebook, rulebook->dtype(), rulebook->numel() * sizeof(int)); - int* rulebook_ptr = rulebook->data(); + *rulebook = phi::Empty( + dev_ctx, + DenseTensorMeta(paddle::experimental::CppTypeToDataType::Type(), + {3, rulebook_len}, + DataLayout::NCHW)); + IntT* rulebook_ptr = rulebook->data(); f_calc_rulebook(rulebook_ptr); } -template +template void UpdateRulebookAndOutIndex(const Context& dev_ctx, const SparseCooTensor& x, const int kernel_size, @@ -142,9 +143,9 @@ void UpdateRulebookAndOutIndex(const Context& dev_ctx, const DDim& out_dims, DenseTensor* rulebook, SparseCooTensor* out) { - std::set out_indexs; + std::set out_indexs; int n = rulebook->dims()[1]; - int* rulebook_ptr = rulebook->data(); + IntT* rulebook_ptr = rulebook->data(); for (int i = 0; i < n; i++) 
{ out_indexs.insert(rulebook_ptr[i + n * 2]); } @@ -152,17 +153,19 @@ void UpdateRulebookAndOutIndex(const Context& dev_ctx, int out_non_zero_num = out_indexs.size(); const int64_t sparse_dim = 4; DenseTensorMeta indices_meta( - DataType::INT32, {sparse_dim, out_non_zero_num}, DataLayout::NCHW); + paddle::experimental::CppTypeToDataType::Type(), + {sparse_dim, out_non_zero_num}, + DataLayout::NCHW); DenseTensorMeta values_meta(x.dtype(), {out_non_zero_num, out_channels}, x.non_zero_elements().layout()); phi::DenseTensor out_indices = phi::Empty(dev_ctx, std::move(indices_meta)); phi::DenseTensor out_values = phi::Empty(dev_ctx, std::move(values_meta)); - int* out_indices_ptr = out_indices.data(); + IntT* out_indices_ptr = out_indices.data(); int i = 0; for (auto it = out_indexs.begin(); it != out_indexs.end(); it++, i++) { - const int index = *it; - int batch, x, y, z; + const IntT index = *it; + IntT batch, x, y, z; phi::funcs::sparse::IndexToPoint(index, out_dims, &batch, &x, &y, &z); out_indices_ptr[i] = batch; out_indices_ptr[i + out_non_zero_num] = z; @@ -170,7 +173,7 @@ void UpdateRulebookAndOutIndex(const Context& dev_ctx, out_indices_ptr[i + out_non_zero_num * 3] = x; } for (i = 0; i < n; i++) { - int out_index = rulebook_ptr[i + n * 2]; + IntT out_index = rulebook_ptr[i + n * 2]; rulebook_ptr[i + n * 2] = std::distance(out_indexs.begin(), out_indexs.find(out_index)); } @@ -178,20 +181,20 @@ void UpdateRulebookAndOutIndex(const Context& dev_ctx, out->SetMember(out_indices, out_values, out_dims, true); } -template +template void Gather( - const T* x, const int* indexs, const int n, const int channels, T* out) { + const T* x, const IntT* indexs, const int n, const int channels, T* out) { for (int i = 0; i < n; i++) { - int real_i = indexs[i]; + IntT real_i = indexs[i]; memcpy(out + i * channels, x + real_i * channels, channels * sizeof(T)); } } -template +template void Scatter( - const T* x, const int* indexs, const int n, const int channels, T* out) { + const T* x, const IntT* indexs, const int n, const int channels, T* out) { for (int i = 0; i < n; i++) { - int real_i = indexs[i]; + IntT real_i = indexs[i]; for (int j = 0; j < channels; j++) { out[real_i * channels + j] += x[i * channels + j]; } diff --git a/paddle/phi/kernels/sparse/cpu/convolution_grad_kernel.cc b/paddle/phi/kernels/sparse/cpu/convolution_grad_kernel.cc index 29079918cbf86..80693c90d1e7f 100644 --- a/paddle/phi/kernels/sparse/cpu/convolution_grad_kernel.cc +++ b/paddle/phi/kernels/sparse/cpu/convolution_grad_kernel.cc @@ -18,6 +18,8 @@ limitations under the License. 
*/ #include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/sparse/cpu/convolution.h" +#include "paddle/phi/api/ext/dispatch.h" + namespace phi { namespace sparse { @@ -29,24 +31,24 @@ namespace sparse { //] // x_grad = out_grad * transpose(kenrel) // kernel_grad = transpose(x) * out_grad -template -void Conv3dGradKernel(const Context& dev_ctx, - const SparseCooTensor& x, - const DenseTensor& kernel, - const DenseTensor& rulebook, - const SparseCooTensor& out_grad, - const std::vector& paddings, - const std::vector& dilations, - const std::vector& strides, - const int groups, - const bool subm, - SparseCooTensor* x_grad, - DenseTensor* kernel_grad) { +template +void Conv3dGradCPUKernel(const CPUContext& dev_ctx, + const SparseCooTensor& x, + const DenseTensor& kernel, + const DenseTensor& rulebook, + const SparseCooTensor& out_grad, + const std::vector& paddings, + const std::vector& dilations, + const std::vector& strides, + const int groups, + const bool subm, + SparseCooTensor* x_grad, + DenseTensor* kernel_grad) { const auto& kernel_dims = kernel.dims(); const int kernel_size = kernel_dims[0] * kernel_dims[1] * kernel_dims[2]; const int in_channels = kernel_dims[3]; const int out_channels = kernel_dims[4]; - const int* rulebook_ptr = rulebook.data(); + const IntT* rulebook_ptr = rulebook.data(); const int rulebook_len = rulebook.dims()[1]; @@ -66,32 +68,30 @@ void Conv3dGradKernel(const Context& dev_ctx, T* in_features_ptr = in_features.data(); T* d_x_features_ptr = d_x_features.data(); T* out_grad_features_ptr = out_grad_features.data(); - kernel_grad->Resize(kernel_dims); - dev_ctx.Alloc( - kernel_grad, kernel_grad->dtype(), kernel_grad->numel() * sizeof(T)); + *kernel_grad = phi::EmptyLike(dev_ctx, kernel); T* d_kernel_ptr = kernel_grad->data(); memset(d_kernel_ptr, 0, sizeof(T) * kernel_grad->numel()); int half_kernel_size = kernel_size / 2; - auto blas = phi::funcs::GetBlas(dev_ctx); + auto blas = phi::funcs::GetBlas(dev_ctx); DenseTensor x_grad_indices = - phi::EmptyLike(dev_ctx, x.non_zero_indices()); + phi::EmptyLike(dev_ctx, x.non_zero_indices()); DenseTensor x_grad_values = phi::EmptyLike(dev_ctx, x.non_zero_elements()); T* x_grad_values_ptr = x_grad_values.data(); memset(x_grad_values_ptr, 0, sizeof(T) * x_grad_values.numel()); memset(d_x_features_ptr, 0, sizeof(T) * d_x_features.numel()); - phi::Copy(dev_ctx, - x.non_zero_indices(), - dev_ctx.GetPlace(), - false, - &x_grad_indices); + phi::Copy(dev_ctx, + x.non_zero_indices(), + dev_ctx.GetPlace(), + false, + &x_grad_indices); x_grad->SetMember(x_grad_indices, x_grad_values, x.dims(), true); - std::vector offsets(kernel_size + 1), counter(kernel_size, 0); + std::vector offsets(kernel_size + 1), counter(kernel_size, 0); for (int i = 0; i < rulebook_len; i++) { counter[rulebook_ptr[i]] += 1; } - int offset = 0, max_count = 0; + IntT offset = 0, max_count = 0; for (int i = 0; i < kernel_size; i++) { offsets[i] = offset; offset += counter[i]; @@ -102,30 +102,31 @@ void Conv3dGradKernel(const Context& dev_ctx, offsets[kernel_size] = offset; if (subm) { - phi::funcs::sparse::SubmPreProcess(dev_ctx, - x, - kernel, - out_grad.non_zero_elements(), - in_channels, - out_channels, - half_kernel_size, - kernel_grad, - &x_grad_values); + phi::funcs::sparse::SubmPreProcess( + dev_ctx, + x, + kernel, + out_grad.non_zero_elements(), + in_channels, + out_channels, + half_kernel_size, + kernel_grad, + &x_grad_values); if (max_count == 0) { return; } } - Gather(x.non_zero_elements().data(), - rulebook_ptr + 
rulebook_len, - rulebook_len, - in_channels, - in_features_ptr); - Gather(out_grad.non_zero_elements().data(), - rulebook_ptr + rulebook_len * 2, - rulebook_len, - out_channels, - out_grad_features_ptr); + Gather(x.non_zero_elements().data(), + rulebook_ptr + rulebook_len, + rulebook_len, + in_channels, + in_features_ptr); + Gather(out_grad.non_zero_elements().data(), + rulebook_ptr + rulebook_len * 2, + rulebook_len, + out_channels, + out_grad_features_ptr); const T* kernel_ptr = kernel.data(); for (int i = 0; i < kernel_size; i++) { @@ -170,11 +171,41 @@ void Conv3dGradKernel(const Context& dev_ctx, } // 4. scatter - Scatter(d_x_features_ptr, - rulebook.data() + rulebook_len, - rulebook_len, - in_channels, - x_grad_values_ptr); + Scatter(d_x_features_ptr, + rulebook.data() + rulebook_len, + rulebook_len, + in_channels, + x_grad_values_ptr); +} + +template +void Conv3dGradKernel(const Context& dev_ctx, + const SparseCooTensor& x, + const DenseTensor& kernel, + const DenseTensor& rulebook, + const SparseCooTensor& out_grad, + const std::vector& paddings, + const std::vector& dilations, + const std::vector& strides, + const int groups, + const bool subm, + SparseCooTensor* x_grad, + DenseTensor* kernel_grad) { + PD_DISPATCH_INTEGRAL_TYPES( + x.non_zero_indices().dtype(), "Conv3dGradCPUKernel", ([&] { + Conv3dGradCPUKernel(dev_ctx, + x, + kernel, + rulebook, + out_grad, + paddings, + dilations, + strides, + groups, + subm, + x_grad, + kernel_grad); + })); } } // namespace sparse diff --git a/paddle/phi/kernels/sparse/cpu/convolution_kernel.cc b/paddle/phi/kernels/sparse/cpu/convolution_kernel.cc index f022e4ef4bb63..a1c8cf014c7fb 100644 --- a/paddle/phi/kernels/sparse/cpu/convolution_kernel.cc +++ b/paddle/phi/kernels/sparse/cpu/convolution_kernel.cc @@ -17,6 +17,8 @@ limitations under the License. 
*/ #include "paddle/phi/core/tensor_meta.h" #include "paddle/phi/kernels/funcs/blas/blas.h" +#include "paddle/phi/api/ext/dispatch.h" + namespace phi { namespace sparse { @@ -25,17 +27,17 @@ namespace sparse { * kernel: (D, H, W, C, OC) * out: (N, D, H, W, OC) **/ -template -void Conv3dKernel(const Context& dev_ctx, - const SparseCooTensor& x, - const DenseTensor& kernel, - const std::vector& paddings, - const std::vector& dilations, - const std::vector& strides, - const int groups, - const bool subm, - SparseCooTensor* out, - DenseTensor* rulebook) { +template +void Conv3dCPUKernel(const CPUContext& dev_ctx, + const SparseCooTensor& x, + const DenseTensor& kernel, + const std::vector& paddings, + const std::vector& dilations, + const std::vector& strides, + const int groups, + const bool subm, + SparseCooTensor* out, + DenseTensor* rulebook) { // update padding and dilation // Currently, only support x.layout is NDHWC, groups = 1 // if x.layout != NDHWC then transpose(x), transpose(weight) @@ -66,18 +68,18 @@ void Conv3dKernel(const Context& dev_ctx, DataType::INT32, {kernel_size}, DataLayout::NCHW); DenseTensor counter_per_kernel = phi::Empty(dev_ctx, std::move(counter_meta)); - ProductRuleBook(dev_ctx, - x, - kernel_sizes, - subm_paddings, - dilations, - subm_strides, - out_dims, - subm, - rulebook, - &counter_per_kernel); - - UpdateRulebookAndOutIndex( + ProductRuleBook(dev_ctx, + x, + kernel_sizes, + subm_paddings, + dilations, + subm_strides, + out_dims, + subm, + rulebook, + &counter_per_kernel); + + UpdateRulebookAndOutIndex( dev_ctx, x, kernel_size, out_channels, out_dims, rulebook, out); int n = rulebook->dims()[1]; @@ -95,14 +97,14 @@ void Conv3dKernel(const Context& dev_ctx, T* in_features_ptr = in_features.data(); T* out_features_ptr = out_features.data(); - Gather(x.non_zero_elements().data(), - rulebook->data() + n, - n, - in_channels, - in_features_ptr); + Gather(x.non_zero_elements().data(), + rulebook->data() + n, + n, + in_channels, + in_features_ptr); // 3. call gemm for every werght - auto blas = phi::funcs::GetBlas(dev_ctx); + auto blas = phi::funcs::GetBlas(dev_ctx); std::vector offsets(kernel_size + 1); int offset = 0; for (int i = 0; i < kernel_size; i++) { @@ -139,11 +141,37 @@ void Conv3dKernel(const Context& dev_ctx, // 4. scatter T* out_values_ptr = out->mutable_non_zero_elements()->data(); memset(out_values_ptr, 0, sizeof(T) * out->nnz() * out_channels); - Scatter(out_features_ptr, - rulebook->data() + n * 2, - n, - out_channels, - out_values_ptr); + Scatter(out_features_ptr, + rulebook->data() + n * 2, + n, + out_channels, + out_values_ptr); +} + +template +void Conv3dKernel(const Context& dev_ctx, + const SparseCooTensor& x, + const DenseTensor& kernel, + const std::vector& paddings, + const std::vector& dilations, + const std::vector& strides, + const int groups, + const bool subm, + SparseCooTensor* out, + DenseTensor* rulebook) { + PD_DISPATCH_INTEGRAL_TYPES( + x.non_zero_indices().dtype(), "Conv3dCPUKernel", ([&] { + Conv3dCPUKernel(dev_ctx, + x, + kernel, + paddings, + dilations, + strides, + groups, + subm, + out, + rulebook); + })); } } // namespace sparse diff --git a/paddle/phi/kernels/sparse/cpu/sparse_pool_grad_kernel.cc b/paddle/phi/kernels/sparse/cpu/sparse_pool_grad_kernel.cc index 3010d480b55c9..30221975e7756 100644 --- a/paddle/phi/kernels/sparse/cpu/sparse_pool_grad_kernel.cc +++ b/paddle/phi/kernels/sparse/cpu/sparse_pool_grad_kernel.cc @@ -14,24 +14,28 @@ limitations under the License. 
*/ #include "paddle/phi/kernels/sparse/sparse_pool_grad_kernel.h" #include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/copy_kernel.h" +#include "paddle/phi/kernels/empty_kernel.h" #include "paddle/phi/kernels/funcs/pooling.h" #include "paddle/phi/kernels/funcs/sparse/convolution.h" +#include "paddle/phi/api/ext/dispatch.h" + namespace phi { namespace sparse { -template -void MaxPoolGradKernel(const Context& dev_ctx, - const SparseCooTensor& x, - const DenseTensor& rulebook, - const SparseCooTensor& out, - const DenseTensor& out_grad, - const std::vector& kernel_sizes, - DenseTensor* x_grad) { +template +void MaxPoolGradCPUKernel(const CPUContext& dev_ctx, + const SparseCooTensor& x, + const DenseTensor& rulebook, + const SparseCooTensor& out, + const SparseCooTensor& out_grad, + const std::vector& kernel_sizes, + SparseCooTensor* x_grad) { int kernel_size = kernel_sizes[0] * kernel_sizes[1] * kernel_sizes[2]; const int channels = x.dims()[4]; int rulebook_len = rulebook.dims()[1]; - const int* rulebook_ptr = rulebook.data(); + const IntT* rulebook_ptr = rulebook.data(); std::vector offsets(kernel_size + 1), counter(kernel_size, 0); for (int i = 0; i < rulebook_len; i++) { counter[rulebook_ptr[i]] += 1; @@ -40,15 +44,25 @@ void MaxPoolGradKernel(const Context& dev_ctx, const T* in_features_ptr = x.non_zero_elements().data(); const T* out_features_ptr = out.non_zero_elements().data(); - const T* out_grad_ptr = out_grad.data(); - T* x_grad_ptr = x_grad->data(); + const T* out_grad_ptr = out_grad.non_zero_elements().data(); + // TODO(zhangkaihuo): call phi::sparse::EmptyLike + DenseTensor x_grad_indices = + phi::EmptyLike(dev_ctx, x.non_zero_indices()); + DenseTensor x_grad_values = phi::EmptyLike(dev_ctx, x.non_zero_elements()); + x_grad->SetMember(x_grad_indices, x_grad_values, x.dims(), true); + T* x_grad_ptr = x_grad_values.data(); memset(x_grad_ptr, 0, sizeof(T) * x_grad->numel()); + phi::Copy(dev_ctx, + x.non_zero_indices(), + dev_ctx.GetPlace(), + false, + &x_grad_indices); phi::funcs::MaxPoolGrad grad_functor; for (int i = 0; i < kernel_size; i++) { for (int j = 0; j < counter[i]; j++) { - int in_i = rulebook_ptr[rulebook_len + offsets[i] + j]; - int out_i = rulebook_ptr[rulebook_len * 2 + offsets[i] + j]; + IntT in_i = rulebook_ptr[rulebook_len + offsets[i] + j]; + IntT out_i = rulebook_ptr[rulebook_len * 2 + offsets[i] + j]; for (int c = 0; c < channels; c++) { grad_functor.compute(in_features_ptr[in_i * channels + c], out_features_ptr[out_i * channels + c], @@ -60,6 +74,21 @@ void MaxPoolGradKernel(const Context& dev_ctx, } } +template +void MaxPoolGradKernel(const Context& dev_ctx, + const SparseCooTensor& x, + const DenseTensor& rulebook, + const SparseCooTensor& out, + const SparseCooTensor& out_grad, + const std::vector& kernel_sizes, + SparseCooTensor* x_grad) { + PD_DISPATCH_INTEGRAL_TYPES( + x.non_zero_indices().dtype(), "MaxPoolGradCPUKernel", ([&] { + MaxPoolGradCPUKernel( + dev_ctx, x, rulebook, out, out_grad, kernel_sizes, x_grad); + })); +} + } // namespace sparse } // namespace phi diff --git a/paddle/phi/kernels/sparse/cpu/sparse_pool_kernel.cc b/paddle/phi/kernels/sparse/cpu/sparse_pool_kernel.cc index 86971242df5ae..ed6e0200587e8 100644 --- a/paddle/phi/kernels/sparse/cpu/sparse_pool_kernel.cc +++ b/paddle/phi/kernels/sparse/cpu/sparse_pool_kernel.cc @@ -19,6 +19,8 @@ limitations under the License. 
*/ #include "paddle/phi/kernels/funcs/sparse/convolution.h" #include "paddle/phi/kernels/sparse/cpu/convolution.h" +#include "paddle/phi/api/ext/dispatch.h" + namespace phi { namespace sparse { @@ -27,15 +29,15 @@ namespace sparse { * kernel: (D, H, W, C, OC) * out: (N, D, H, W, OC) **/ -template -void MaxPoolKernel(const Context& dev_ctx, - const SparseCooTensor& x, - const std::vector& kernel_sizes, - const std::vector& paddings, - const std::vector& dilations, - const std::vector& strides, - SparseCooTensor* out, - DenseTensor* rulebook) { +template +void MaxPoolCPUKernel(const CPUContext& dev_ctx, + const SparseCooTensor& x, + const std::vector& kernel_sizes, + const std::vector& paddings, + const std::vector& dilations, + const std::vector& strides, + SparseCooTensor* out, + DenseTensor* rulebook) { const auto& x_dims = x.dims(); int kernel_size = kernel_sizes[0] * kernel_sizes[1] * kernel_sizes[2]; const std::vector& real_kernel_sizes = @@ -51,22 +53,22 @@ void MaxPoolKernel(const Context& dev_ctx, const T* in_features_ptr = x.non_zero_elements().data(); // 1. product rule book - ProductRuleBook(dev_ctx, - x, - real_kernel_sizes, - paddings, - dilations, - strides, - out_dims, - false, - rulebook, - &counter_per_kernel); - - UpdateRulebookAndOutIndex( + ProductRuleBook(dev_ctx, + x, + real_kernel_sizes, + paddings, + dilations, + strides, + out_dims, + false, + rulebook, + &counter_per_kernel); + + UpdateRulebookAndOutIndex( dev_ctx, x, kernel_size, in_channels, out_dims, rulebook, out); int rulebook_len = rulebook->dims()[1]; - const int* rulebook_ptr = rulebook->data(); + const IntT* rulebook_ptr = rulebook->data(); const int* counter_ptr = counter_per_kernel.data(); std::vector offsets(kernel_size + 1); @@ -78,8 +80,8 @@ void MaxPoolKernel(const Context& dev_ctx, phi::funcs::MaxPool max_pool_functor; for (int i = 0; i < kernel_size; i++) { for (int j = 0; j < counter_ptr[i]; j++) { - int in_i = rulebook_ptr[rulebook_len + offsets[i] + j]; - int out_i = rulebook_ptr[rulebook_len * 2 + offsets[i] + j]; + IntT in_i = rulebook_ptr[rulebook_len + offsets[i] + j]; + IntT out_i = rulebook_ptr[rulebook_len * 2 + offsets[i] + j]; if (!out_flags[out_i]) { out_flags[out_i] = true; memcpy(&out_features_ptr[out_i * in_channels], @@ -95,6 +97,28 @@ void MaxPoolKernel(const Context& dev_ctx, } } +template +void MaxPoolKernel(const Context& dev_ctx, + const SparseCooTensor& x, + const std::vector& kernel_sizes, + const std::vector& paddings, + const std::vector& dilations, + const std::vector& strides, + SparseCooTensor* out, + DenseTensor* rulebook) { + PD_DISPATCH_INTEGRAL_TYPES( + x.non_zero_indices().dtype(), "MaxPoolCPUKernel", ([&] { + MaxPoolCPUKernel(dev_ctx, + x, + kernel_sizes, + paddings, + dilations, + strides, + out, + rulebook); + })); +} + } // namespace sparse } // namespace phi diff --git a/paddle/phi/kernels/sparse/gpu/convolution.cu.h b/paddle/phi/kernels/sparse/gpu/convolution.cu.h index a512a60b94ff8..5662a4fac71c5 100644 --- a/paddle/phi/kernels/sparse/gpu/convolution.cu.h +++ b/paddle/phi/kernels/sparse/gpu/convolution.cu.h @@ -98,21 +98,21 @@ __global__ void ScatterKernel(const T* input, } } -template -inline int* SortedAndUniqueIndex(const Context& dev_ctx, - const int* rulebook_ptr, - const int len, - DenseTensor* out_index, - DenseTensor* unique_key, - DenseTensor* unique_value) { +template +inline IntT* SortedAndUniqueIndex(const Context& dev_ctx, + const IntT* rulebook_ptr, + const int len, + DenseTensor* out_index, + DenseTensor* unique_key, + DenseTensor* 
unique_value) { phi::IndexKernel>( dev_ctx, out_index, kps::IdentityFunctor()); phi::IndexKernel>( dev_ctx, unique_value, kps::IdentityFunctor()); - phi::backends::gpu::GpuMemcpyAsync(unique_key->data(), + phi::backends::gpu::GpuMemcpyAsync(unique_key->data(), rulebook_ptr, - sizeof(int) * len, + sizeof(IntT) * len, #ifdef PADDLE_WITH_HIP hipMemcpyDeviceToDevice, #else @@ -126,19 +126,19 @@ inline int* SortedAndUniqueIndex(const Context& dev_ctx, #else thrust::sort_by_key(thrust::cuda::par.on(dev_ctx.stream()), #endif - unique_key->data(), - unique_key->data() + len, + unique_key->data(), + unique_key->data() + len, out_index->data()); // 4. unique - thrust::pair new_end = + thrust::pair new_end = #ifdef PADDLE_WITH_HIP thrust::unique_by_key(thrust::hip::par.on(dev_ctx.stream()), #else thrust::unique_by_key(thrust::cuda::par.on(dev_ctx.stream()), #endif - unique_key->data(), - unique_key->data() + len, + unique_key->data(), + unique_key->data() + len, unique_value->data()); return new_end.first; } @@ -159,7 +159,7 @@ __global__ void SetFlagAndUpdateCounterKernel(const int* indexs, for (int i = tid; i < n; i += gridDim.x * blockDim.x) { int index = indexs[i]; - int kernel_index = rulebook_ptr[index]; + T kernel_index = rulebook_ptr[index]; rulebook_ptr[index + rulebook_len] = -1; rulebook_ptr[index + 2 * rulebook_len] = -1; rulebook_ptr[index] = -1; @@ -183,18 +183,18 @@ __global__ void SetFlagAndUpdateCounterKernel(const int* indexs, * rulebook_out_indexs: the output index in rulebook **/ template -__global__ void UpdateIndexKernel(const int* unique_keys, +__global__ void UpdateIndexKernel(const T* unique_keys, const int* unique_values, const int* out_indexs, - const int non_zero_num, + const int64_t non_zero_num, const int rulebook_len, const Dims4D out_dims, T* out_indices, T* rulebook_out_indexs) { int tid = threadIdx.x + blockIdx.x * blockDim.x; for (int i = tid; i < non_zero_num; i += gridDim.x * blockDim.x) { - const int index = unique_keys[i]; - int batch, x, y, z; + const T index = unique_keys[i]; + T batch, x, y, z; phi::funcs::sparse::IndexToPoint( index, out_dims, &batch, &x, &y, &z); // get out indices @@ -207,7 +207,7 @@ __global__ void UpdateIndexKernel(const int* unique_keys, int start = unique_values[i]; int end = i == non_zero_num - 1 ? 
rulebook_len : unique_values[i + 1]; // max(end-start) = kernel_size - for (int j = start; j < end; j++) { + for (T j = start; j < end; j++) { rulebook_out_indexs[out_indexs[j]] = i; } } @@ -215,7 +215,7 @@ __global__ void UpdateIndexKernel(const int* unique_keys, // brief: calculation the distance between start and end template -__global__ void DistanceKernel(const T* start, const T* end, int* distance) { +__global__ void DistanceKernel(const T* start, const T* end, T* distance) { if (threadIdx.x == 0) { *distance = end - start; } @@ -249,7 +249,7 @@ __global__ void ProductRuleBookKernel(const T* x_indices, const bool subm, T* rulebook, int* counter, - int* in_indexs) { + T* in_indexs) { int tid = threadIdx.x + blockIdx.x * blockDim.x; extern __shared__ int counter_buf[]; // kernel_size const int kernel_size = kernel_dims[3] * kernel_dims[2] * kernel_dims[1]; @@ -261,10 +261,10 @@ __global__ void ProductRuleBookKernel(const T* x_indices, for (int i = tid; i < non_zero_num; i += gridDim.x * blockDim.x) { int kernel_index = 0; - int batch = x_indices[i]; - int in_z = x_indices[i + non_zero_num]; - int in_y = x_indices[i + 2 * non_zero_num]; - int in_x = x_indices[i + 3 * non_zero_num]; + T batch = x_indices[i]; + T in_z = x_indices[i + non_zero_num]; + T in_y = x_indices[i + 2 * non_zero_num]; + T in_x = x_indices[i + 3 * non_zero_num]; if (subm) { in_indexs[i] = PointToIndex(batch, in_x, in_y, in_z, x_dims); } @@ -283,9 +283,9 @@ __global__ void ProductRuleBookKernel(const T* x_indices, kx, ky, kz)) { - int out_z = (in_z + paddings[1] - kz * dilations[1]) / strides[1]; - int out_y = (in_y + paddings[2] - ky * dilations[2]) / strides[2]; - int out_x = (in_x + paddings[3] - kx * dilations[3]) / strides[3]; + T out_z = (in_z + paddings[1] - kz * dilations[1]) / strides[1]; + T out_y = (in_y + paddings[2] - ky * dilations[2]) / strides[2]; + T out_x = (in_x + paddings[3] - kx * dilations[3]) / strides[3]; in_i = i; out_index = phi::funcs::sparse::PointToIndex( batch, out_x, out_y, out_z, out_dims); @@ -321,7 +321,7 @@ __global__ void ProductRuleBookKernel(const T* x_indices, // 5. 
update the out_index by unique_key, uniqe_value and the index of // unique_value: // the new out_index: 0, 2, 3, 2, 3, 0, 1 -template +template int ProductRuleBook(const Context& dev_ctx, const SparseCooTensor& x, const std::vector& kernel_sizes, @@ -334,26 +334,26 @@ int ProductRuleBook(const Context& dev_ctx, DenseTensor* counter_per_kernel, DenseTensor* offsets_per_kernel, DenseTensor* out_index, - DenseTensor* unique_key, DenseTensor* unique_value, SparseCooTensor* out, std::vector* h_counter, std::vector* h_offsets) { + // TODO(zhangkaihuo): use PD_DISPATCH_INTEGRAL_TYPES for secondary dispatch + auto indices_dtype = paddle::experimental::CppTypeToDataType::Type(); const int64_t non_zero_num = x.nnz(); const auto& non_zero_indices = x.non_zero_indices(); - const int* indices_ptr = non_zero_indices.data(); + const IntT* indices_ptr = non_zero_indices.data(); DenseTensor in_indexs = phi::Empty( - dev_ctx, DenseTensorMeta(DataType::INT32, {x.nnz()}, DataLayout::NCHW)); + dev_ctx, DenseTensorMeta(indices_dtype, {x.nnz()}, DataLayout::NCHW)); int* counter_ptr = counter_per_kernel->data(); int* offsets_ptr = offsets_per_kernel->data(); int kernel_size = kernel_sizes[0] * kernel_sizes[1] * kernel_sizes[2]; const int rulebook_rows = 3; const int rulebook_cols = kernel_size * non_zero_num; DenseTensorMeta rulebook_meta( - DataType::INT32, {rulebook_rows, rulebook_cols}, DataLayout::NCHW); - rulebook->set_meta(rulebook_meta); - dev_ctx.Alloc(rulebook, rulebook->dtype(), rulebook->numel() * sizeof(int)); - int* rulebook_ptr = rulebook->data(); + indices_dtype, {rulebook_rows, rulebook_cols}, DataLayout::NCHW); + *rulebook = phi::Empty(dev_ctx, std::move(rulebook_meta)); + IntT* rulebook_ptr = rulebook->data(); const auto x_dims = x.dims(); Dims4D d_x_dims(x_dims[0], x_dims[3], x_dims[2], x_dims[1]); @@ -369,39 +369,39 @@ int ProductRuleBook(const Context& dev_ctx, auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, non_zero_num, 1); - ProductRuleBookKernel<<>>(indices_ptr, - d_x_dims, - d_kernel_dims, - d_out_dims, - non_zero_num, - d_paddings, - d_dilations, - d_strides, - subm, - rulebook_ptr, - counter_ptr, - in_indexs.data()); + ProductRuleBookKernel<<>>(indices_ptr, + d_x_dims, + d_kernel_dims, + d_out_dims, + non_zero_num, + d_paddings, + d_dilations, + d_strides, + subm, + rulebook_ptr, + counter_ptr, + in_indexs.data()); // 2. remove -1 #ifdef PADDLE_WITH_HIP - int* last = thrust::remove(thrust::hip::par.on(dev_ctx.stream()), + IntT* last = thrust::remove(thrust::hip::par.on(dev_ctx.stream()), #else - int* last = thrust::remove(thrust::cuda::par.on(dev_ctx.stream()), + IntT* last = thrust::remove(thrust::cuda::par.on(dev_ctx.stream()), #endif - rulebook_ptr, - rulebook_ptr + rulebook_rows * rulebook_cols, - -1); + rulebook_ptr, + rulebook_ptr + rulebook_rows * rulebook_cols, + -1); - DistanceKernel<<<1, 1, 0, dev_ctx.stream()>>>( + DistanceKernel<<<1, 1, 0, dev_ctx.stream()>>>( rulebook_ptr, last, rulebook_ptr + 3 * kernel_size * non_zero_num - 1); - int rulebook_len = 0; + IntT rulebook_len = 0; phi::backends::gpu::GpuMemcpyAsync( &rulebook_len, rulebook_ptr + 3 * kernel_size * non_zero_num - 1, - sizeof(int), + sizeof(IntT), #ifdef PADDLE_WITH_HIP hipMemcpyDeviceToHost, #else @@ -418,11 +418,10 @@ int ProductRuleBook(const Context& dev_ctx, // and then the intermediate output index is subtracted from the input index // to obain the rulebook. 
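// As a concrete illustration of the subtraction described in the NOTE above:
// if the candidate output indices are {0, 1, 3, 5} and the input indices are
// {0, 3}, the difference {1, 5} marks the rulebook entries the subm branch
// must drop. A minimal host-side sketch of the same idea, assuming the usual
// <vector>/<algorithm>/<iterator> headers (std::set_difference stands in for
// the thrust::set_difference_by_key call below; both require sorted inputs):
std::vector<int> out_idx = {0, 1, 3, 5}, in_idx = {0, 3}, drop;
std::set_difference(out_idx.begin(), out_idx.end(),
                    in_idx.begin(), in_idx.end(),
                    std::back_inserter(drop));  // drop == {1, 5}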
// get difference - int32_t* A_key_ptr = rulebook_ptr + 2 * rulebook_len; - int32_t* B_key_ptr = in_indexs.data(); - DenseTensor A_val = phi::Empty( - dev_ctx, - DenseTensorMeta(DataType::INT32, {rulebook_len}, DataLayout::NCHW)); + IntT* A_key_ptr = rulebook_ptr + 2 * rulebook_len; + IntT* B_key_ptr = in_indexs.data(); + DenseTensorMeta val_meta(DataType::INT32, {rulebook_len}, DataLayout::NCHW); + DenseTensor A_val = phi::Empty(dev_ctx, std::move(val_meta)); DenseTensor B_val = phi::Empty( dev_ctx, DenseTensorMeta(DataType::INT32, {x.nnz()}, DataLayout::NCHW)); phi::IndexKernel>( @@ -431,10 +430,8 @@ int ProductRuleBook(const Context& dev_ctx, dev_ctx, &B_val, kps::IdentityFunctor()); DenseTensor key_result = phi::Empty( dev_ctx, - DenseTensorMeta(DataType::INT32, {rulebook_len + 1}, DataLayout::NCHW)); - DenseTensor val_result = phi::Empty( - dev_ctx, - DenseTensorMeta(DataType::INT32, {rulebook_len}, DataLayout::NCHW)); + DenseTensorMeta(indices_dtype, {rulebook_len + 1}, DataLayout::NCHW)); + DenseTensor val_result = phi::Empty(dev_ctx, std::move(val_meta)); #ifdef PADDLE_WITH_HIP thrust::exclusive_scan(thrust::hip::par.on(dev_ctx.stream()), @@ -457,7 +454,7 @@ int ProductRuleBook(const Context& dev_ctx, dev_ctx.stream()); dev_ctx.Wait(); - thrust::pair end; + thrust::pair end; // Because set_diff does not support duplicate data, set_diff is performed // separately for each segment of data. // TODO(zhangkaihuo): Using hashtable here may get better performance, @@ -465,7 +462,7 @@ int ProductRuleBook(const Context& dev_ctx, for (int i = 0; i < kernel_size; i++) { int start = offsets[i]; int stop = i == kernel_size - 1 ? rulebook_len : offsets[i + 1]; - int* key_result_start = (i == 0 ? key_result.data() : end.first); + IntT* key_result_start = (i == 0 ? key_result.data() : end.first); int* val_result_start = i == 0 ? 
val_result.data() : end.second; end = #ifdef PADDLE_WITH_HIP @@ -483,14 +480,14 @@ int ProductRuleBook(const Context& dev_ctx, val_result_start); } - DistanceKernel<<<1, 1, 0, dev_ctx.stream()>>>( - key_result.data(), + DistanceKernel<<<1, 1, 0, dev_ctx.stream()>>>( + key_result.data(), end.first, - key_result.data() + rulebook_len); - int len = 0; + key_result.data() + rulebook_len); + IntT len = 0; phi::backends::gpu::GpuMemcpyAsync(&len, - key_result.data() + rulebook_len, - sizeof(int), + key_result.data() + rulebook_len, + sizeof(IntT), #ifdef PADDLE_WITH_HIP hipMemcpyDeviceToHost, #else @@ -500,10 +497,10 @@ int ProductRuleBook(const Context& dev_ctx, dev_ctx.Wait(); // set the diff value = -1, and update counter auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, len, 1); - SetFlagAndUpdateCounterKernel<<>>( + SetFlagAndUpdateCounterKernel<<>>( val_result.data(), len, rulebook_len, @@ -512,18 +509,18 @@ int ProductRuleBook(const Context& dev_ctx, counter_ptr); // remove -1 #ifdef PADDLE_WITH_HIP - int* last = thrust::remove(thrust::hip::par.on(dev_ctx.stream()), + IntT* last = thrust::remove(thrust::hip::par.on(dev_ctx.stream()), #else - int* last = thrust::remove(thrust::cuda::par.on(dev_ctx.stream()), + IntT* last = thrust::remove(thrust::cuda::par.on(dev_ctx.stream()), #endif - rulebook_ptr, - rulebook_ptr + 3 * rulebook_len, - -1); - DistanceKernel<<<1, 1, 0, dev_ctx.stream()>>>( - rulebook_ptr, last, key_result.data() + rulebook_len); + rulebook_ptr, + rulebook_ptr + 3 * rulebook_len, + -1); + DistanceKernel<<<1, 1, 0, dev_ctx.stream()>>>( + rulebook_ptr, last, key_result.data() + rulebook_len); phi::backends::gpu::GpuMemcpyAsync(&rulebook_len, - key_result.data() + rulebook_len, - sizeof(int), + key_result.data() + rulebook_len, + sizeof(IntT), #ifdef PADDLE_WITH_HIP hipMemcpyDeviceToHost, #else @@ -566,42 +563,47 @@ int ProductRuleBook(const Context& dev_ctx, cudaMemcpyDeviceToHost, dev_ctx.stream()); #endif - rulebook->Resize({rulebook_rows, rulebook_len}); + rulebook->Resize({rulebook_rows, static_cast(rulebook_len)}); // 3. 
sorted or merge the out index - out_index->ResizeAndAllocate({rulebook_len}); - unique_value->ResizeAndAllocate({rulebook_len}); - unique_key->ResizeAndAllocate({rulebook_len}); + out_index->ResizeAndAllocate({static_cast(rulebook_len)}); + unique_value->ResizeAndAllocate({static_cast(rulebook_len)}); + DenseTensor unique_key = phi::Empty( + dev_ctx, + DenseTensorMeta(paddle::experimental::CppTypeToDataType::Type(), + {static_cast(rulebook_len)}, + DataLayout::NCHW)); int* out_index_ptr = out_index->data(); int* unique_value_ptr = unique_value->data(); - int* unique_key_ptr = unique_key->data(); - - int* new_end = SortedAndUniqueIndex(dev_ctx, - rulebook_ptr + 2 * rulebook_len, - rulebook_len, - out_index, - unique_key, - unique_value); + IntT* unique_key_ptr = unique_key.data(); + + IntT* new_end = + SortedAndUniqueIndex(dev_ctx, + rulebook_ptr + 2 * rulebook_len, + rulebook_len, + out_index, + &unique_key, + unique_value); // thrust::distance doesn't support stream parameters // const int out_non_zero_num = thrust::distance(unique_key_ptr, // new_end.first); - DistanceKernel<<<1, 1>>>( + DistanceKernel<<<1, 1>>>( unique_key_ptr, new_end, rulebook_ptr + rulebook_rows * rulebook_cols - 1); - int out_non_zero_num = 0; + IntT out_non_zero_num = 0; #ifdef PADDLE_WITH_HIP phi::backends::gpu::GpuMemcpyAsync( &out_non_zero_num, rulebook_ptr + rulebook_rows * rulebook_cols - 1, - sizeof(int), + sizeof(IntT), hipMemcpyDeviceToHost, dev_ctx.stream()); #else phi::backends::gpu::GpuMemcpyAsync( &out_non_zero_num, rulebook_ptr + rulebook_rows * rulebook_cols - 1, - sizeof(int), + sizeof(IntT), cudaMemcpyDeviceToHost, dev_ctx.stream()); #endif @@ -610,28 +612,29 @@ int ProductRuleBook(const Context& dev_ctx, // 5. update out_indices and rulebook by unique_value_ptr const int64_t sparse_dim = 4; DenseTensorMeta indices_meta( - DataType::INT32, {sparse_dim, out_non_zero_num}, DataLayout::NCHW); + indices_dtype, {sparse_dim, out_non_zero_num}, DataLayout::NCHW); DenseTensorMeta values_meta(x.dtype(), {out_non_zero_num, kernel_sizes[4]}, x.non_zero_elements().layout()); phi::DenseTensor out_indices = phi::Empty(dev_ctx, std::move(indices_meta)); phi::DenseTensor out_values = phi::Empty(dev_ctx, std::move(values_meta)); - int* out_indices_ptr = out_indices.data(); + IntT* out_indices_ptr = out_indices.data(); config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, out_non_zero_num, 1); - UpdateIndexKernel<<>>(unique_key_ptr, - unique_value_ptr, - out_index_ptr, - out_non_zero_num, - rulebook_len, - d_out_dims, - out_indices_ptr, - rulebook_ptr + 2 * rulebook_len); + UpdateIndexKernel<<>>( + unique_key_ptr, + unique_value_ptr, + out_index_ptr, + out_non_zero_num, + rulebook_len, + d_out_dims, + out_indices_ptr, + rulebook_ptr + 2 * rulebook_len); out->SetMember(out_indices, out_values, out_dims, true); return rulebook_len; } diff --git a/paddle/phi/kernels/sparse/gpu/convolution_grad_kernel.cu b/paddle/phi/kernels/sparse/gpu/convolution_grad_kernel.cu index 4a6094c23bc79..2b61be7289646 100644 --- a/paddle/phi/kernels/sparse/gpu/convolution_grad_kernel.cu +++ b/paddle/phi/kernels/sparse/gpu/convolution_grad_kernel.cu @@ -24,6 +24,8 @@ limitations under the License. 
*/ #include "paddle/phi/kernels/sparse/convolution_grad_kernel.h" #include "paddle/phi/kernels/sparse/gpu/convolution.cu.h" +#include "paddle/phi/api/ext/dispatch.h" + namespace phi { namespace sparse { @@ -35,24 +37,24 @@ namespace sparse { //] // x_grad = out_grad * transpose(kenrel) // kernel_grad = transpose(x) * out_grad -template -void Conv3dGradKernel(const Context& dev_ctx, - const SparseCooTensor& x, - const DenseTensor& kernel, - const DenseTensor& rulebook, - const SparseCooTensor& out_grad, - const std::vector& paddings, - const std::vector& dilations, - const std::vector& strides, - const int groups, - const bool subm, - SparseCooTensor* x_grad, - DenseTensor* kernel_grad) { +template +void Conv3dGradGPUKernel(const GPUContext& dev_ctx, + const SparseCooTensor& x, + const DenseTensor& kernel, + const DenseTensor& rulebook, + const SparseCooTensor& out_grad, + const std::vector& paddings, + const std::vector& dilations, + const std::vector& strides, + const int groups, + const bool subm, + SparseCooTensor* x_grad, + DenseTensor* kernel_grad) { const auto& kernel_dims = kernel.dims(); const int kernel_size = kernel_dims[0] * kernel_dims[1] * kernel_dims[2]; const int in_channels = kernel_dims[3]; const int out_channels = kernel_dims[4]; - const int* rulebook_ptr = rulebook.data(); + const IntT* rulebook_ptr = rulebook.data(); const int rulebook_len = rulebook.dims()[1]; @@ -74,29 +76,29 @@ void Conv3dGradKernel(const Context& dev_ctx, T* out_grad_features_ptr = out_grad_features.data(); *kernel_grad = phi::EmptyLike(dev_ctx, kernel); T* d_kernel_ptr = kernel_grad->data(); - phi::funcs::SetConstant set_zero; + phi::funcs::SetConstant set_zero; set_zero(dev_ctx, kernel_grad, static_cast(0.0f)); int half_kernel_size = kernel_size / 2; - auto blas = phi::funcs::GetBlas(dev_ctx); + auto blas = phi::funcs::GetBlas(dev_ctx); DenseTensor x_grad_indices = - phi::EmptyLike(dev_ctx, x.non_zero_indices()); + phi::EmptyLike(dev_ctx, x.non_zero_indices()); DenseTensor x_grad_values = phi::EmptyLike(dev_ctx, x.non_zero_elements()); T* x_grad_values_ptr = x_grad_values.data(); set_zero(dev_ctx, &x_grad_values, static_cast(0.0f)); set_zero(dev_ctx, &d_x_features, static_cast(0.0f)); - phi::Copy(dev_ctx, - x.non_zero_indices(), - dev_ctx.GetPlace(), - false, - &x_grad_indices); + phi::Copy(dev_ctx, + x.non_zero_indices(), + dev_ctx.GetPlace(), + false, + &x_grad_indices); x_grad->SetMember(x_grad_indices, x_grad_values, x.dims(), true); - std::vector offsets(kernel_size + 1), counter(kernel_size, 0), + std::vector offsets(kernel_size + 1), counter(kernel_size, 0), h_counter(rulebook_len, 0); phi::backends::gpu::GpuMemcpyAsync(&h_counter[0], rulebook_ptr, - rulebook_len * sizeof(int), + rulebook_len * sizeof(IntT), #ifdef PADDLE_WITH_HIP hipMemcpyDeviceToHost, #else @@ -109,7 +111,7 @@ void Conv3dGradKernel(const Context& dev_ctx, for (int i = 0; i < rulebook_len; i++) { counter[h_counter[i]] += 1; } - int offset = 0, max_count = 0; + IntT offset = 0, max_count = 0; for (int i = 0; i < kernel_size; i++) { offsets[i] = offset; offset += counter[i]; @@ -120,15 +122,16 @@ void Conv3dGradKernel(const Context& dev_ctx, offsets[kernel_size] = offset; if (subm) { - phi::funcs::sparse::SubmPreProcess(dev_ctx, - x, - kernel, - out_grad.non_zero_elements(), - in_channels, - out_channels, - half_kernel_size, - kernel_grad, - &x_grad_values); + phi::funcs::sparse::SubmPreProcess( + dev_ctx, + x, + kernel, + out_grad.non_zero_elements(), + in_channels, + out_channels, + half_kernel_size, + kernel_grad, + 
&x_grad_values); if (max_count == 0) { return; } @@ -136,21 +139,21 @@ void Conv3dGradKernel(const Context& dev_ctx, auto config = phi::backends::gpu::GetGpuLaunchConfig1D( dev_ctx, rulebook_len * in_channels, 1); - GatherKernel<<>>(x.non_zero_elements().data(), - rulebook_ptr + rulebook_len, - in_features_ptr, - rulebook_len, - in_channels); + GatherKernel<<>>(x.non_zero_elements().data(), + rulebook_ptr + rulebook_len, + in_features_ptr, + rulebook_len, + in_channels); config = phi::backends::gpu::GetGpuLaunchConfig1D( dev_ctx, rulebook_len * out_channels, 1); - GatherKernel<<>>( + GatherKernel<<>>( out_grad.non_zero_elements().data(), rulebook_ptr + rulebook_len * 2, out_grad_features_ptr, @@ -203,15 +206,19 @@ void Conv3dGradKernel(const Context& dev_ctx, // x_grad->ResizeAndAllocate(x.non_zero_elements().dims()); DenseTensorMeta index_meta(DataType::INT32, {rulebook_len}, DataLayout::NCHW); DenseTensor out_index = phi::Empty(dev_ctx, std::move(index_meta)); - DenseTensor unique_key = phi::Empty(dev_ctx, std::move(index_meta)); + DenseTensor unique_key = phi::Empty( + dev_ctx, + DenseTensorMeta(paddle::experimental::CppTypeToDataType::Type(), + {rulebook_len}, + DataLayout::NCHW)); DenseTensor unique_value = phi::Empty(dev_ctx, std::move(index_meta)); - SortedAndUniqueIndex(dev_ctx, - rulebook_ptr + rulebook_len, - rulebook_len, - &out_index, - &unique_key, - &unique_value); + SortedAndUniqueIndex(dev_ctx, + rulebook_ptr + rulebook_len, + rulebook_len, + &out_index, + &unique_key, + &unique_value); config = phi::backends::gpu::GetGpuLaunchConfig1D( dev_ctx, rulebook_len * in_channels, 1); @@ -229,6 +236,36 @@ void Conv3dGradKernel(const Context& dev_ctx, subm); } +template +void Conv3dGradKernel(const Context& dev_ctx, + const SparseCooTensor& x, + const DenseTensor& kernel, + const DenseTensor& rulebook, + const SparseCooTensor& out_grad, + const std::vector& paddings, + const std::vector& dilations, + const std::vector& strides, + const int groups, + const bool subm, + SparseCooTensor* x_grad, + DenseTensor* kernel_grad) { + PD_DISPATCH_INTEGRAL_TYPES( + x.non_zero_indices().dtype(), "Conv3dGradGPUKernel", ([&] { + Conv3dGradGPUKernel(dev_ctx, + x, + kernel, + rulebook, + out_grad, + paddings, + dilations, + strides, + groups, + subm, + x_grad, + kernel_grad); + })); +} + } // namespace sparse } // namespace phi diff --git a/paddle/phi/kernels/sparse/gpu/convolution_kernel.cu b/paddle/phi/kernels/sparse/gpu/convolution_kernel.cu index 214e689e9370a..2d212eadffac1 100644 --- a/paddle/phi/kernels/sparse/gpu/convolution_kernel.cu +++ b/paddle/phi/kernels/sparse/gpu/convolution_kernel.cu @@ -19,29 +19,25 @@ limitations under the License. 
*/ #include "paddle/phi/kernels/sparse/convolution_kernel.h" #include "paddle/phi/kernels/sparse/gpu/convolution.cu.h" +#include "paddle/phi/api/ext/dispatch.h" + namespace phi { namespace sparse { -/** - * x: (N, D, H, W, C) - * kernel: (D, H, W, C, OC) - * out: (N, D, H, W, OC) -**/ -template -void Conv3dKernel(const Context& dev_ctx, - const SparseCooTensor& x, - const DenseTensor& kernel, - const std::vector& paddings, - const std::vector& dilations, - const std::vector& strides, - const int groups, - const bool subm, - SparseCooTensor* out, - DenseTensor* rulebook) { +template +void Conv3dGPUKernel(const GPUContext& dev_ctx, + const SparseCooTensor& x, + const DenseTensor& kernel, + const std::vector& paddings, + const std::vector& dilations, + const std::vector& strides, + const int groups, + const bool subm, + SparseCooTensor* out, + DenseTensor* rulebook) { // update padding and dilation // Currently, only support x.layout is NDHWC, groups = 1 // if x.layout != NDHWC then transpose(x), transpose(weight) - const auto& x_dims = x.dims(); const auto& kernel_dims = kernel.dims(); int kernel_size = kernel_dims[0] * kernel_dims[1] * kernel_dims[2]; @@ -67,7 +63,6 @@ void Conv3dKernel(const Context& dev_ctx, DenseTensor offsets_per_kernel = phi::Empty(dev_ctx, std::move(offsets_meta)); DenseTensorMeta index_meta(DataType::INT32, {1}, DataLayout::NCHW); DenseTensor out_index = phi::Empty(dev_ctx, std::move(index_meta)); - DenseTensor unique_key = phi::Empty(dev_ctx, std::move(index_meta)); DenseTensor unique_value = phi::Empty(dev_ctx, std::move(index_meta)); std::vector subm_paddings(paddings), subm_strides(strides); @@ -75,28 +70,26 @@ void Conv3dKernel(const Context& dev_ctx, phi::funcs::sparse::ResetSubmKernelSizeAndStrides( kernel.dims(), &subm_paddings, &subm_strides); } - - int n = ProductRuleBook(dev_ctx, - x, - kernel_sizes, - subm_paddings, - dilations, - subm_strides, - out_dims, - subm, - rulebook, - &counter_per_kernel, - &offsets_per_kernel, - &out_index, - &unique_key, - &unique_value, - out, - &h_counter, - &offsets); + int n = ProductRuleBook(dev_ctx, + x, + kernel_sizes, + subm_paddings, + dilations, + subm_strides, + out_dims, + subm, + rulebook, + &counter_per_kernel, + &offsets_per_kernel, + &out_index, + &unique_value, + out, + &h_counter, + &offsets); const int* counter_ptr = counter_per_kernel.data(); const int* offsets_ptr = counter_per_kernel.data(); - const int* rulebook_ptr = rulebook->data(); + const IntT* rulebook_ptr = rulebook->data(); // 2. gather DenseTensorMeta in_features_meta( @@ -109,22 +102,22 @@ void Conv3dKernel(const Context& dev_ctx, phi::Empty(dev_ctx, std::move(out_features_meta)); T* in_features_ptr = in_features.data(); T* out_features_ptr = out_features.data(); - phi::funcs::SetConstant set_zero; + phi::funcs::SetConstant set_zero; set_zero(dev_ctx, &out_features, static_cast(0.0f)); auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, n * in_channels, 1); - GatherKernel<<>>(x.non_zero_elements().data(), - rulebook_ptr + n, - in_features_ptr, - n, - in_channels); + GatherKernel<<>>(x.non_zero_elements().data(), + rulebook_ptr + n, + in_features_ptr, + n, + in_channels); // 3. 
call gemm for every weight - auto blas = phi::funcs::GetBlas(dev_ctx); + auto blas = phi::funcs::GetBlas(dev_ctx); auto* out_values = out->mutable_non_zero_elements(); T* out_values_ptr = out_values->data(); @@ -168,6 +161,36 @@ void Conv3dKernel(const Context& dev_ctx, out_channels, out_values_ptr); } +/** + * x: (N, D, H, W, C) + * kernel: (D, H, W, C, OC) + * out: (N, D, H, W, OC) +**/ +template +void Conv3dKernel(const Context& dev_ctx, + const SparseCooTensor& x, + const DenseTensor& kernel, + const std::vector& paddings, + const std::vector& dilations, + const std::vector& strides, + const int groups, + const bool subm, + SparseCooTensor* out, + DenseTensor* rulebook) { + PD_DISPATCH_INTEGRAL_TYPES( + x.non_zero_indices().dtype(), "Conv3dGPUKernel", ([&] { + Conv3dGPUKernel(dev_ctx, + x, + kernel, + paddings, + dilations, + strides, + groups, + subm, + out, + rulebook); + })); +} } // namespace sparse } // namespace phi diff --git a/paddle/phi/kernels/sparse/gpu/sparse_pool_grad_kernel.cu b/paddle/phi/kernels/sparse/gpu/sparse_pool_grad_kernel.cu index 1048dd1be0c01..8657e7319d8ca 100644 --- a/paddle/phi/kernels/sparse/gpu/sparse_pool_grad_kernel.cu +++ b/paddle/phi/kernels/sparse/gpu/sparse_pool_grad_kernel.cu @@ -12,24 +12,28 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include "paddle/phi/kernels/sparse/sparse_pool_grad_kernel.h" + #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/backends/gpu/gpu_info.h" #include "paddle/phi/backends/gpu/gpu_launch_config.h" #include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/copy_kernel.h" +#include "paddle/phi/kernels/empty_kernel.h" #include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/funcs/pooling.h" #include "paddle/phi/kernels/funcs/sparse/convolution.h" -#include "paddle/phi/kernels/sparse/sparse_pool_grad_kernel.h" +#include "paddle/phi/api/ext/dispatch.h" namespace phi { namespace sparse { -template +template __global__ void MaxPoolGradCudaKernel(const T* in_features_ptr, const T* out_features_ptr, const T* out_grad_ptr, - const int* rulebook_ptr, + const IntT* rulebook_ptr, const int n, const int rulebook_len, const int channels, @@ -38,8 +42,8 @@ __global__ void MaxPoolGradCudaKernel(const T* in_features_ptr, CUDA_KERNEL_LOOP_TYPE(i, n * channels, int64_t) { int real_i = i / channels; int c = i - real_i * channels; - int in_i = rulebook_ptr[real_i]; - int out_i = rulebook_ptr[real_i + rulebook_len]; + IntT in_i = rulebook_ptr[real_i]; + IntT out_i = rulebook_ptr[real_i + rulebook_len]; grad_functor.compute(in_features_ptr[in_i * channels + c], out_features_ptr[out_i * channels + c], out_grad_ptr[out_i * channels + c], @@ -48,23 +52,23 @@ __global__ void MaxPoolGradCudaKernel(const T* in_features_ptr, } } -template -void MaxPoolGradKernel(const Context& dev_ctx, - const SparseCooTensor& x, - const DenseTensor& rulebook, - const SparseCooTensor& out, - const DenseTensor& out_grad, - const std::vector& kernel_sizes, - DenseTensor* x_grad) { +template +void MaxPoolGradGPUKernel(const GPUContext& dev_ctx, + const SparseCooTensor& x, + const DenseTensor& rulebook, + const SparseCooTensor& out, + const SparseCooTensor& out_grad, + const std::vector& kernel_sizes, + SparseCooTensor* x_grad) { int kernel_size = kernel_sizes[0] * kernel_sizes[1] * kernel_sizes[2]; const int in_channels = x.dims()[4]; int rulebook_len = rulebook.dims()[1]; -
const int* rulebook_ptr = rulebook.data(); - std::vector offsets(kernel_size + 1), counter(kernel_size, 0), + const IntT* rulebook_ptr = rulebook.data(); + std::vector offsets(kernel_size + 1), counter(kernel_size, 0), h_counter(kernel_size); phi::backends::gpu::GpuMemcpyAsync(&h_counter[0], rulebook_ptr, - rulebook_len * sizeof(int), + rulebook_len * sizeof(IntT), #ifdef PADDLE_WITH_HIP hipMemcpyDeviceToHost, #else @@ -80,10 +84,20 @@ void MaxPoolGradKernel(const Context& dev_ctx, const T* in_features_ptr = x.non_zero_elements().data(); const T* out_features_ptr = out.non_zero_elements().data(); - const T* out_grad_ptr = out_grad.data(); - T* x_grad_ptr = x_grad->data(); - phi::funcs::SetConstant set_zero; - set_zero(dev_ctx, x_grad, static_cast(0.0f)); + const T* out_grad_ptr = out_grad.non_zero_elements().data(); + // TODO(zhangkaihuo): call phi::sparse::EmptyLike + DenseTensor x_grad_indices = + phi::EmptyLike(dev_ctx, x.non_zero_indices()); + DenseTensor x_grad_values = phi::EmptyLike(dev_ctx, x.non_zero_elements()); + x_grad->SetMember(x_grad_indices, x_grad_values, x.dims(), true); + T* x_grad_ptr = x_grad_values.data(); + phi::funcs::SetConstant set_zero; + set_zero(dev_ctx, &x_grad_values, static_cast(0.0f)); + phi::Copy(dev_ctx, + x.non_zero_indices(), + dev_ctx.GetPlace(), + false, + &x_grad_indices); for (int i = 0; i < kernel_size; i++) { if (counter[i] <= 0) { @@ -92,10 +106,10 @@ void MaxPoolGradKernel(const Context& dev_ctx, auto config = phi::backends::gpu::GetGpuLaunchConfig1D( dev_ctx, counter[i] * in_channels, 1); - MaxPoolGradCudaKernel<<>>( + MaxPoolGradCudaKernel<<>>( in_features_ptr, out_features_ptr, out_grad_ptr, @@ -107,6 +121,21 @@ void MaxPoolGradKernel(const Context& dev_ctx, } } +template +void MaxPoolGradKernel(const Context& dev_ctx, + const SparseCooTensor& x, + const DenseTensor& rulebook, + const SparseCooTensor& out, + const SparseCooTensor& out_grad, + const std::vector& kernel_sizes, + SparseCooTensor* x_grad) { + PD_DISPATCH_INTEGRAL_TYPES( + x.non_zero_indices().dtype(), "MaxPoolGradGPUKernel", ([&] { + MaxPoolGradGPUKernel( + dev_ctx, x, rulebook, out, out_grad, kernel_sizes, x_grad); + })); +} + } // namespace sparse } // namespace phi diff --git a/paddle/phi/kernels/sparse/gpu/sparse_pool_kernel.cu b/paddle/phi/kernels/sparse/gpu/sparse_pool_kernel.cu index 0f6a0d13b1ddb..a59cd3c7a5a78 100644 --- a/paddle/phi/kernels/sparse/gpu/sparse_pool_kernel.cu +++ b/paddle/phi/kernels/sparse/gpu/sparse_pool_kernel.cu @@ -12,19 +12,22 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ +#include "paddle/phi/kernels/sparse/sparse_pool_kernel.h" + #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/tensor_meta.h" #include "paddle/phi/kernels/funcs/pooling.h" #include "paddle/phi/kernels/funcs/sparse/convolution.h" #include "paddle/phi/kernels/sparse/gpu/convolution.cu.h" -#include "paddle/phi/kernels/sparse/sparse_pool_kernel.h" + +#include "paddle/phi/api/ext/dispatch.h" namespace phi { namespace sparse { -template +template __global__ void MaxPoolCudaKernel(const T* in_features_ptr, - const int* rulebook_ptr, + const IntT* rulebook_ptr, const int n, const int rulebook_len, const int channels, @@ -33,8 +36,8 @@ __global__ void MaxPoolCudaKernel(const T* in_features_ptr, CUDA_KERNEL_LOOP_TYPE(i, n * channels, int64_t) { int real_i = i / channels; int channel_i = i - real_i * channels; - int in_i = rulebook_ptr[real_i]; - int out_i = rulebook_ptr[real_i + rulebook_len]; + IntT in_i = rulebook_ptr[real_i]; + IntT out_i = rulebook_ptr[real_i + rulebook_len]; max_pool_functor.compute(in_features_ptr[in_i * channels + channel_i], &out_features_ptr[out_i * channels + channel_i]); } @@ -45,15 +48,15 @@ __global__ void MaxPoolCudaKernel(const T* in_features_ptr, * kernel: (D, H, W, C, OC) * out: (N, D, H, W, OC) **/ -template -void MaxPoolKernel(const Context& dev_ctx, - const SparseCooTensor& x, - const std::vector& kernel_sizes, - const std::vector& paddings, - const std::vector& dilations, - const std::vector& strides, - SparseCooTensor* out, - DenseTensor* rulebook) { +template +void MaxPoolGPUKernel(const GPUContext& dev_ctx, + const SparseCooTensor& x, + const std::vector& kernel_sizes, + const std::vector& paddings, + const std::vector& dilations, + const std::vector& strides, + SparseCooTensor* out, + DenseTensor* rulebook) { const auto& x_dims = x.dims(); int kernel_size = kernel_sizes[0] * kernel_sizes[1] * kernel_sizes[2]; const std::vector& real_kernel_sizes = @@ -70,29 +73,27 @@ void MaxPoolKernel(const Context& dev_ctx, DenseTensor offsets_per_kernel = phi::Empty(dev_ctx, std::move(counter_meta)); DenseTensorMeta index_meta(DataType::INT32, {1}, DataLayout::NCHW); DenseTensor out_index = phi::Empty(dev_ctx, std::move(index_meta)); - DenseTensor unique_key = phi::Empty(dev_ctx, std::move(index_meta)); DenseTensor unique_value = phi::Empty(dev_ctx, std::move(index_meta)); // 1. 
product rulebook - int rulebook_len = ProductRuleBook(dev_ctx, - x, - real_kernel_sizes, - paddings, - dilations, - strides, - out_dims, - false, - rulebook, - &counter_per_kernel, - &offsets_per_kernel, - &out_index, - &unique_key, - &unique_value, - out, - &counter, - &offsets); - - const int* rulebook_ptr = rulebook->data(); + int rulebook_len = ProductRuleBook(dev_ctx, + x, + real_kernel_sizes, + paddings, + dilations, + strides, + out_dims, + false, + rulebook, + &counter_per_kernel, + &offsets_per_kernel, + &out_index, + &unique_value, + out, + &counter, + &offsets); + + const IntT* rulebook_ptr = rulebook->data(); T* out_features_ptr = out->mutable_non_zero_elements()->data(); const T* in_features_ptr = x.non_zero_elements().data(); @@ -113,10 +114,10 @@ void MaxPoolKernel(const Context& dev_ctx, auto config = phi::backends::gpu::GetGpuLaunchConfig1D( dev_ctx, counter[i] * in_channels, 1); - MaxPoolCudaKernel<<>>( + MaxPoolCudaKernel<<>>( in_features_ptr, rulebook_ptr + offsets[i] + rulebook_len, counter[i], @@ -126,6 +127,28 @@ void MaxPoolKernel(const Context& dev_ctx, } } +template +void MaxPoolKernel(const Context& dev_ctx, + const SparseCooTensor& x, + const std::vector& kernel_sizes, + const std::vector& paddings, + const std::vector& dilations, + const std::vector& strides, + SparseCooTensor* out, + DenseTensor* rulebook) { + PD_DISPATCH_INTEGRAL_TYPES( + x.non_zero_indices().dtype(), "MaxPoolGPUKernel", ([&] { + MaxPoolGPUKernel(dev_ctx, + x, + kernel_sizes, + paddings, + dilations, + strides, + out, + rulebook); + })); +} + } // namespace sparse } // namespace phi diff --git a/paddle/phi/kernels/sparse/sparse_pool_grad_kernel.h b/paddle/phi/kernels/sparse/sparse_pool_grad_kernel.h index 572ade76281bc..2f7366a010aaa 100644 --- a/paddle/phi/kernels/sparse/sparse_pool_grad_kernel.h +++ b/paddle/phi/kernels/sparse/sparse_pool_grad_kernel.h @@ -26,20 +26,18 @@ void MaxPoolGradKernel(const Context& dev_ctx, const SparseCooTensor& x, const DenseTensor& rulebook, const SparseCooTensor& out, - const DenseTensor& out_grad, + const SparseCooTensor& out_grad, const std::vector& kernel_sizes, - DenseTensor* x_grad); + SparseCooTensor* x_grad); template -DenseTensor MaxPoolGrad(const Context& dev_ctx, - const SparseCooTensor& x, - const DenseTensor& rulebook, - const SparseCooTensor& out, - const DenseTensor& out_grad, - const std::vector& kernel_sizes) { - DenseTensor x_grad = phi::Empty( - dev_ctx, - DenseTensorMeta(x.dtype(), x.non_zero_elements().dims(), x.layout())); +SparseCooTensor MaxPoolGrad(const Context& dev_ctx, + const SparseCooTensor& x, + const DenseTensor& rulebook, + const SparseCooTensor& out, + const SparseCooTensor& out_grad, + const std::vector& kernel_sizes) { + SparseCooTensor x_grad; MaxPoolGradKernel( dev_ctx, x, rulebook, out, out_grad, kernel_sizes, &x_grad); return x_grad; diff --git a/paddle/phi/kernels/sparse/sparse_pool_kernel.h b/paddle/phi/kernels/sparse/sparse_pool_kernel.h index bfadbf72e300f..d5248a1ad250e 100644 --- a/paddle/phi/kernels/sparse/sparse_pool_kernel.h +++ b/paddle/phi/kernels/sparse/sparse_pool_kernel.h @@ -39,11 +39,7 @@ SparseCooTensor MaxPool(const Context& dev_ctx, const std::vector& dilations, const std::vector& strides, DenseTensor* rulebook) { - DenseTensor indices = phi::Empty( - dev_ctx, DenseTensorMeta(DataType::INT32, {1}, DataLayout::NCHW)); - DenseTensor values = - phi::Empty(dev_ctx, DenseTensorMeta(x.dtype(), {1}, x.layout())); - SparseCooTensor coo(indices, values, x.dims()); + SparseCooTensor coo; MaxPoolKernel( dev_ctx, 
x, kernel_sizes, paddings, dilations, strides, &coo, rulebook); return coo; diff --git a/paddle/phi/tests/kernels/test_sparse_conv3d_dev_api.cc b/paddle/phi/tests/kernels/test_sparse_conv3d_dev_api.cc index c22464e538c21..9fb0e5692645d 100644 --- a/paddle/phi/tests/kernels/test_sparse_conv3d_dev_api.cc +++ b/paddle/phi/tests/kernels/test_sparse_conv3d_dev_api.cc @@ -48,13 +48,13 @@ std::vector cast(const std::vector& in) { return out; } -template -void TestConv3dBase(const std::vector& indices, +template +void TestConv3dBase(const std::vector& indices, const std::vector& features, const DDim& x_dims, const std::vector& kernel, const DDim& kernel_dims, - const std::vector& correct_out_indices, + const std::vector& correct_out_indices, const std::vector& correct_out_features, const DDim& correct_out_dims, const int non_zero_num, @@ -80,11 +80,13 @@ void TestConv3dBase(const std::vector& indices, const int in_channels = kernel_dims[3]; const int out_channels = kernel_dims[4]; + auto indices_dtype = paddle::experimental::CppTypeToDataType::Type(); DenseTensor indices_tensor = phi::Empty( dev_ctx_cpu, - DenseTensorMeta(DataType::INT32, {4, non_zero_num}, DataLayout::NCHW)); - memcpy( - indices_tensor.data(), indices.data(), indices.size() * sizeof(int)); + DenseTensorMeta(indices_dtype, {4, non_zero_num}, DataLayout::NCHW)); + memcpy(indices_tensor.data(), + indices.data(), + indices.size() * sizeof(IntT)); DenseTensor features_tensor = phi::Empty( dev_ctx_cpu, DenseTensorMeta(paddle::experimental::CppTypeToDataType::Type(), @@ -111,7 +113,7 @@ void TestConv3dBase(const std::vector& indices, if (!std::is_same::value) { DenseTensor rulebook = phi::Empty( - dev_ctx_cpu, DenseTensorMeta(DataType::INT32, {1}, DataLayout::NCHW)); + dev_ctx_cpu, DenseTensorMeta(indices_dtype, {1}, DataLayout::NCHW)); SparseCooTensor out = sparse::Conv3d(dev_ctx_cpu, x_tensor, kernel_tensor, @@ -129,8 +131,8 @@ void TestConv3dBase(const std::vector& indices, ASSERT_EQ((int64_t)correct_out_features.size() / out_channels, out.nnz()); int cmp_indices = memcmp(correct_out_indices.data(), - out.non_zero_indices().data(), - correct_out_indices.size() * sizeof(int)); + out.non_zero_indices().data(), + correct_out_indices.size() * sizeof(IntT)); ASSERT_EQ(cmp_indices, 0); f_verify(out.non_zero_elements().data(), correct_out_features); @@ -172,7 +174,7 @@ void TestConv3dBase(const std::vector& indices, DenseTensor d_indices_tensor = phi::Empty( dev_ctx_gpu, - DenseTensorMeta(DataType::INT32, {4, non_zero_num}, DataLayout::NCHW)); + DenseTensorMeta(indices_dtype, {4, non_zero_num}, DataLayout::NCHW)); phi::Copy( dev_ctx_gpu, indices_tensor, phi::GPUPlace(), true, &d_indices_tensor); @@ -195,7 +197,7 @@ void TestConv3dBase(const std::vector& indices, dev_ctx_gpu, kernel_tensor, phi::GPUPlace(), true, &d_kernel_tensor); DenseTensor d_rulebook = phi::Empty( - dev_ctx_gpu, DenseTensorMeta(DataType::INT32, {1}, DataLayout::NCHW)); + dev_ctx_gpu, DenseTensorMeta(indices_dtype, {1}, DataLayout::NCHW)); SparseCooTensor d_out = sparse::Conv3d(dev_ctx_gpu, d_x_tensor, d_kernel_tensor, @@ -214,7 +216,7 @@ void TestConv3dBase(const std::vector& indices, DenseTensor h_indices_tensor = phi::Empty( dev_ctx_cpu, - DenseTensorMeta(DataType::INT32, {4, d_out.nnz()}, DataLayout::NCHW)); + DenseTensorMeta(indices_dtype, {4, d_out.nnz()}, DataLayout::NCHW)); phi::Copy(dev_ctx_gpu, d_out.non_zero_indices(), phi::CPUPlace(), @@ -222,8 +224,8 @@ void TestConv3dBase(const std::vector& indices, &h_indices_tensor); int cmp_indices2 = 
memcmp(correct_out_indices.data(), - h_indices_tensor.data(), - correct_out_indices.size() * sizeof(int)); + h_indices_tensor.data(), + correct_out_indices.size() * sizeof(IntT)); ASSERT_EQ(cmp_indices2, 0); DenseTensor h_features_tensor = @@ -264,12 +266,13 @@ void TestConv3dBase(const std::vector& indices, #endif } -void TestConv3d(const std::vector& indices, +template +void TestConv3d(const std::vector& indices, const std::vector& features, const DDim& x_dims, const std::vector& kernel, const DDim& kernel_dims, - const std::vector& correct_out_indices, + const std::vector& correct_out_indices, const std::vector& correct_out_features, const DDim& correct_out_dims, const int non_zero_num, @@ -282,41 +285,41 @@ void TestConv3d(const std::vector& indices, const std::vector kernel_grad = {}, const bool subm = false) { // test float - TestConv3dBase(indices, - features, - x_dims, - kernel, - kernel_dims, - correct_out_indices, - correct_out_features, - correct_out_dims, - non_zero_num, - paddings, - strides, - dilations, - diff, - backward, - features_grad, - kernel_grad, - subm); + TestConv3dBase(indices, + features, + x_dims, + kernel, + kernel_dims, + correct_out_indices, + correct_out_features, + correct_out_dims, + non_zero_num, + paddings, + strides, + dilations, + diff, + backward, + features_grad, + kernel_grad, + subm); // test double - TestConv3dBase(indices, - cast(features), - x_dims, - cast(kernel), - kernel_dims, - correct_out_indices, - cast(correct_out_features), - correct_out_dims, - non_zero_num, - paddings, - strides, - dilations, - diff, - backward, - cast(features_grad), - cast(kernel_grad), - subm); + TestConv3dBase(indices, + cast(features), + x_dims, + cast(kernel), + kernel_dims, + correct_out_indices, + cast(correct_out_features), + correct_out_dims, + non_zero_num, + paddings, + strides, + dilations, + diff, + backward, + cast(features_grad), + cast(kernel_grad), + subm); } TEST(DEV_API, sparse_conv3d) { @@ -616,6 +619,51 @@ TEST(DEV_API, sparse_conv2d) { dilations); } +TEST(DEV_API, sparse_conv2d_int64) { + const int in_channels = 1; + const int out_channels = 1; + DDim x_dims = {1, 1, 5, 5, in_channels}; + DDim kernel_dims = {1, 3, 3, in_channels, out_channels}; + DDim out_dims = {1, 1, 3, 3, out_channels}; + std::vector paddings = {0, 0, 0}; + std::vector strides = {1, 1, 1}; + std::vector dilations = {1, 1, 1}; + + const int non_zero_num = 3; + std::vector indices_flatten = {0, 0, 0, 0, 0, 0, 0, 4, 0, 3, 2, 4}; + + std::vector features = {-0.79394531, -0.3125, -0.55029297}; + // 3*3*3=27 + std::vector kernel = {0.65820312, + 0.75048828, + 0.21411133, + 0.17370605, + 0.85546875, + 0.53076172, + 0.28833008, + 0.71044922, + 0.00659943}; + + std::vector out_indices_flatten = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 2, 2, 2, 1, 2, 0, 1, 2}; + + std::vector out_features = { + -0.17004, -0.71338, -0.00206, -0.22205, -0.09009}; + + TestConv3d(indices_flatten, + features, + x_dims, + kernel, + kernel_dims, + out_indices_flatten, + out_features, + out_dims, + non_zero_num, + paddings, + strides, + dilations); +} + TEST(DEV_API, sparse_conv3d_backward) { const int in_channels = 1; const int out_channels = 1; diff --git a/paddle/phi/tests/kernels/test_sparse_pool_dev_api.cc b/paddle/phi/tests/kernels/test_sparse_pool_dev_api.cc index 632beadf3de0e..8f7288d70d7d0 100644 --- a/paddle/phi/tests/kernels/test_sparse_pool_dev_api.cc +++ b/paddle/phi/tests/kernels/test_sparse_pool_dev_api.cc @@ -36,11 +36,11 @@ std::vector cast(const std::vector& in) { } return out; } -template 
-void TestMaxPoolBase(const std::vector& indices, +template +void TestMaxPoolBase(const std::vector& indices, const std::vector& features, const DDim& x_dims, - const std::vector& correct_out_indices, + const std::vector& correct_out_indices, const std::vector& correct_out_features, const DDim& correct_out_dims, const int non_zero_num, @@ -65,11 +65,13 @@ void TestMaxPoolBase(const std::vector& indices, const int in_channels = x_dims[4]; const int out_channels = in_channels; + auto indices_dtype = paddle::experimental::CppTypeToDataType::Type(); DenseTensor indices_tensor = phi::Empty( dev_ctx_cpu, - DenseTensorMeta(DataType::INT32, {4, non_zero_num}, DataLayout::NCHW)); - memcpy( - indices_tensor.data(), indices.data(), indices.size() * sizeof(int)); + DenseTensorMeta(indices_dtype, {4, non_zero_num}, DataLayout::NCHW)); + memcpy(indices_tensor.data(), + indices.data(), + indices.size() * sizeof(IntT)); DenseTensor features_tensor = phi::Empty( dev_ctx_cpu, DenseTensorMeta(paddle::experimental::CppTypeToDataType::Type(), @@ -88,8 +90,7 @@ void TestMaxPoolBase(const std::vector& indices, }; if (!std::is_same::value) { - DenseTensor rulebook = phi::Empty( - dev_ctx_cpu, DenseTensorMeta(DataType::INT32, {1}, DataLayout::NCHW)); + DenseTensor rulebook; SparseCooTensor out = sparse::MaxPool(dev_ctx_cpu, x_tensor, kernel_sizes, @@ -105,20 +106,16 @@ void TestMaxPoolBase(const std::vector& indices, ASSERT_EQ((int64_t)correct_out_features.size() / out_channels, out.nnz()); int cmp_indices = memcmp(correct_out_indices.data(), - out.non_zero_indices().data(), - correct_out_indices.size() * sizeof(int)); + out.non_zero_indices().data(), + correct_out_indices.size() * sizeof(IntT)); ASSERT_EQ(cmp_indices, 0); f_verify(out.non_zero_elements().data(), correct_out_features); if (backward) { - DenseTensor x_grad = sparse::MaxPoolGrad(dev_ctx_cpu, - x_tensor, - rulebook, - out, - out.non_zero_elements(), - kernel_sizes); - f_verify(x_grad.data(), features_grad); + SparseCooTensor x_grad = sparse::MaxPoolGrad( + dev_ctx_cpu, x_tensor, rulebook, out, out, kernel_sizes); + f_verify(x_grad.non_zero_elements().data(), features_grad); } } @@ -142,7 +139,7 @@ void TestMaxPoolBase(const std::vector& indices, DenseTensor d_indices_tensor = phi::Empty( dev_ctx_gpu, - DenseTensorMeta(DataType::INT32, {4, non_zero_num}, DataLayout::NCHW)); + DenseTensorMeta(indices_dtype, {4, non_zero_num}, DataLayout::NCHW)); phi::Copy( dev_ctx_gpu, indices_tensor, phi::GPUPlace(), true, &d_indices_tensor); @@ -153,8 +150,7 @@ void TestMaxPoolBase(const std::vector& indices, SparseCooTensor d_x_tensor(d_indices_tensor, d_features_tensor, x_dims); - DenseTensor d_rulebook = phi::Empty( - dev_ctx_gpu, DenseTensorMeta(DataType::INT32, {1}, DataLayout::NCHW)); + DenseTensor d_rulebook; SparseCooTensor d_out = sparse::MaxPool(dev_ctx_gpu, d_x_tensor, kernel_sizes, @@ -171,7 +167,7 @@ void TestMaxPoolBase(const std::vector& indices, DenseTensor h_indices_tensor = phi::Empty( dev_ctx_cpu, - DenseTensorMeta(DataType::INT32, {4, d_out.nnz()}, DataLayout::NCHW)); + DenseTensorMeta(indices_dtype, {4, d_out.nnz()}, DataLayout::NCHW)); phi::Copy(dev_ctx_gpu, d_out.non_zero_indices(), phi::CPUPlace(), @@ -179,8 +175,8 @@ void TestMaxPoolBase(const std::vector& indices, &h_indices_tensor); int cmp_indices2 = memcmp(correct_out_indices.data(), - h_indices_tensor.data(), - correct_out_indices.size() * sizeof(int)); + h_indices_tensor.data(), + correct_out_indices.size() * sizeof(IntT)); ASSERT_EQ(cmp_indices2, 0); DenseTensor h_features_tensor = @@ 
-194,23 +190,25 @@ void TestMaxPoolBase(const std::vector& indices, f_verify(h_features_tensor.data(), correct_out_features); if (backward) { - DenseTensor x_grad = sparse::MaxPoolGrad(dev_ctx_gpu, - d_x_tensor, - d_rulebook, - d_out, - d_out.non_zero_elements(), - kernel_sizes); - DenseTensor h_features_grad = phi::EmptyLike(dev_ctx_cpu, x_grad); - phi::Copy(dev_ctx_gpu, x_grad, phi::CPUPlace(), true, &h_features_grad); + SparseCooTensor x_grad = sparse::MaxPoolGrad( + dev_ctx_gpu, d_x_tensor, d_rulebook, d_out, d_out, kernel_sizes); + DenseTensor h_features_grad = + phi::EmptyLike(dev_ctx_cpu, x_grad.non_zero_elements()); + phi::Copy(dev_ctx_gpu, + x_grad.non_zero_elements(), + phi::CPUPlace(), + true, + &h_features_grad); f_verify(h_features_grad.data(), features_grad); } #endif } -void TestMaxPool(const std::vector& indices, +template +void TestMaxPool(const std::vector& indices, const std::vector& features, const DDim& x_dims, - const std::vector& correct_out_indices, + const std::vector& correct_out_indices, const std::vector& correct_out_features, const DDim& correct_out_dims, const int non_zero_num, @@ -222,35 +220,35 @@ void TestMaxPool(const std::vector& indices, const bool backward = false, const std::vector features_grad = {}) { // test float - TestMaxPoolBase(indices, - features, - x_dims, - correct_out_indices, - correct_out_features, - correct_out_dims, - non_zero_num, - kernel_sizes, - paddings, - strides, - dilations, - diff, - backward, - features_grad); + TestMaxPoolBase(indices, + features, + x_dims, + correct_out_indices, + correct_out_features, + correct_out_dims, + non_zero_num, + kernel_sizes, + paddings, + strides, + dilations, + diff, + backward, + features_grad); // test double - TestMaxPoolBase(indices, - cast(features), - x_dims, - correct_out_indices, - cast(correct_out_features), - correct_out_dims, - non_zero_num, - kernel_sizes, - paddings, - strides, - dilations, - diff, - backward, - cast(features_grad)); + TestMaxPoolBase(indices, + cast(features), + x_dims, + correct_out_indices, + cast(correct_out_features), + correct_out_dims, + non_zero_num, + kernel_sizes, + paddings, + strides, + dilations, + diff, + backward, + cast(features_grad)); } TEST(DEV_API, sparse_maxpool) { From 56f108ff0373d143d8dd0e8d7bae44d3783dca8f Mon Sep 17 00:00:00 2001 From: wangxinxin08 <69842442+wangxinxin08@users.noreply.github.com> Date: Sat, 2 Apr 2022 15:03:56 +0800 Subject: [PATCH 16/93] filter unsupported inputs for elementwise op in op teller (#41253) * filter unsupported inputs for elementwise op in op teller * add unittest for corner case --- paddle/fluid/inference/tensorrt/op_teller.cc | 15 ++ .../inference/test_trt_convert_elementwise.py | 134 ++++++++++++++++++ 2 files changed, 149 insertions(+) diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc index 13c16ab6897e3..cfdccecb5c8f7 100644 --- a/paddle/fluid/inference/tensorrt/op_teller.cc +++ b/paddle/fluid/inference/tensorrt/op_teller.cc @@ -1011,6 +1011,21 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8, VLOG(3) << "Now trt may not support two 1d tensor elementwise op."; return false; } + if (op_type == "elementwise_add" || op_type == "elementwise_mul") { + if (x_var_desc->Persistable()) { + VLOG(3) << "Input X is a parameter which is not supported for " + "elementwise_add/elementwise_mul in tensorrt, swap x and " + "y will work"; + return false; + } + } + if (op_type == "elementwise_sub" || op_type == "elementwise_div") { + if 
(x_var_desc->Persistable() || y_var_desc->Persistable()) { + VLOG(3) << "Input X or Input Y is a parameter which is not supported " + "for elementwise_sub/elementwise_div in tensorrt"; + return false; + } + } } if (op_type == "stack") { diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_elementwise.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_elementwise.py index 047a6094ec1e1..e849496621a10 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_elementwise.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_elementwise.py @@ -397,5 +397,139 @@ def test(self): self.run_test() +class TrtConvertElementwiseTest_one_input_corner_case(TrtLayerAutoScanTest): + def is_program_valid(self, program_config: ProgramConfig) -> bool: + return True + + def sample_program_configs(self): + def generate_input(shape): + return np.random.random(shape).astype(np.float32) + + def generate_weight(): + return np.random.randn(32).astype(np.float32) + + for batch in [1, 2, 4]: + for shape in [[32], [batch, 32], [batch, 32, 32], + [batch, 32, 16, 32]]: + for op_type in [ + "elementwise_add", "elementwise_mul", "elementwise_sub", + "elementwise_div" + ]: + for axis in [-1 if len(shape) == 1 else 1]: + self.dims = len(shape) + dics = [{"axis": axis}] + ops_config = [{ + "op_type": op_type, + "op_inputs": { + "X": ["weight"], + "Y": ["input_data"] + }, + "op_outputs": { + "Out": ["output_data"] + }, + "op_attrs": dics[0] + }] + ops = self.generate_op_config(ops_config) + + program_config = ProgramConfig( + ops=ops, + weights={ + "weight": + TensorConfig(data_gen=partial(generate_weight)) + }, + inputs={ + "input_data": TensorConfig( + data_gen=partial(generate_input, shape)), + }, + outputs=["output_data"]) + + yield program_config + + def sample_predictor_configs( + self, program_config) -> (paddle_infer.Config, List[int], float): + def generate_dynamic_shape(attrs): + # The input.dims[1] must be equal to the weight's length. 
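Before the dynamic-shape branches of the test resume below, a note on the rule the new teller code above enforces: for elementwise_add and elementwise_mul, TensorRT can still handle a constant operand as long as it arrives on Y (hence the "swap x and y will work" hint), while elementwise_sub and elementwise_div reject a persistable parameter on either input. A minimal sketch of that rule as a standalone predicate (hypothetical helper, not part of this patch):

def trt_supports_elementwise(op_type, x_is_weight, y_is_weight):
    # Mirrors the op_teller change: a weight on X is unsupported for
    # add/mul (swapping X and Y works around it); sub/div accept no
    # weight on either input.
    if op_type in ("elementwise_add", "elementwise_mul"):
        return not x_is_weight
    if op_type in ("elementwise_sub", "elementwise_div"):
        return not (x_is_weight or y_is_weight)
    return True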
+ if self.dims == 1: + self.dynamic_shape.min_input_shape = {"input_data": [4]} + self.dynamic_shape.max_input_shape = {"input_data": [256]} + self.dynamic_shape.opt_input_shape = {"input_data": [16]} + elif self.dims == 2: + self.dynamic_shape.min_input_shape = {"input_data": [1, 32]} + self.dynamic_shape.max_input_shape = {"input_data": [4, 32]} + self.dynamic_shape.opt_input_shape = {"input_data": [2, 32]} + elif self.dims == 3: + self.dynamic_shape.min_input_shape = {"input_data": [1, 32, 4]} + self.dynamic_shape.max_input_shape = { + "input_data": [4, 32, 256] + } + self.dynamic_shape.opt_input_shape = {"input_data": [2, 32, 16]} + elif self.dims == 4: + self.dynamic_shape.min_input_shape = { + "input_data": [1, 32, 4, 4] + } + self.dynamic_shape.max_input_shape = { + "input_data": [4, 32, 128, 256] + } + self.dynamic_shape.opt_input_shape = { + "input_data": [2, 32, 32, 16] + } + + def clear_dynamic_shape(): + self.dynamic_shape.max_input_shape = {} + self.dynamic_shape.min_input_shape = {} + self.dynamic_shape.opt_input_shape = {} + + def generate_trt_nodes_num(attrs, dynamic_shape): + if self.dims == 1: + return 0, 3 + return 1, 2 + + attrs = [ + program_config.ops[i].attrs + for i in range(len(program_config.ops)) + ] + + # for static_shape + clear_dynamic_shape() + self.trt_param.precision = paddle_infer.PrecisionType.Float32 + yield self.create_inference_config(), generate_trt_nodes_num( + attrs, False), 1e-5 + self.trt_param.precision = paddle_infer.PrecisionType.Half + yield self.create_inference_config(), generate_trt_nodes_num( + attrs, False), 1e-5 + + # for dynamic_shape + generate_dynamic_shape(attrs) + self.trt_param.precision = paddle_infer.PrecisionType.Float32 + yield self.create_inference_config(), generate_trt_nodes_num(attrs, + True), 1e-5 + self.trt_param.precision = paddle_infer.PrecisionType.Half + yield self.create_inference_config(), generate_trt_nodes_num(attrs, + True), 1e-5 + + def add_skip_trt_case(self): + def teller1(program_config, predictor_config): + input_x_names = program_config.ops[0].inputs["X"] + for weight_name in program_config.weights: + if weight_name in input_x_names: + return True + op_type = program_config.ops[0].type + if op_type in ["elementwise_sub", "elementwise_div"]: + input_y_names = program_config.ops[0].inputs["Y"] + for weight_name in program_config.weights: + if weight_name in input_y_names: + return True + return False + + self.add_skip_case( + teller1, SkipReasons.TRT_NOT_SUPPORT, + "Input X should not be parameters in elementwise op and Input Y should not be parameters in elementwise_sub or elementwise_div op" + ) + + def test(self): + self.add_skip_trt_case() + self.run_test() + + if __name__ == "__main__": unittest.main() From afadb8c5b90165f612e91d9c4200f7c431f90ef3 Mon Sep 17 00:00:00 2001 From: Zhanlue Yang Date: Sat, 2 Apr 2022 15:40:32 +0800 Subject: [PATCH 17/93] [DoubleGrad PR #5] Enabled gradient computations for grad_tensors passed to paddle.grad() (#41198) * [Refactor] refactored eager_gen.py PR #2 * [DoubleGrad PR #1] Decoupled code generation logics for Dygraph ForwardFunctions and GradNodes * Fixed minor issue * Adjusted logics of GenerateNodeCreationCodes and GenerateForwardDefinition * Fixed issues * Supported higher-order grad node generation * [DoubleGrad PR #4] Supported higher-order GradNode generation * [DoubleGrad #4] Bug Fixes to Double Grad Node Generation * Fixed yaml typo * Fixed yaml typo * fixed minor issues * [DoubleGrad PR #5] Enabled gradient computations for grad_tensors passed to 
paddle.grad() * Fixed minor issue * Fixed CI-Inference issue * Fixed CI-inference issues --- paddle/fluid/eager/CMakeLists.txt | 10 +- paddle/fluid/eager/api/utils/hook_utils.cc | 1 + paddle/fluid/eager/backward.cc | 17 +-- paddle/fluid/eager/grad_tensor_holder.cc | 118 ++++++++++++------ paddle/fluid/eager/grad_tensor_holder.h | 6 +- paddle/fluid/eager/tests/CMakeLists.txt | 5 +- .../tests/data_structure_tests/CMakeLists.txt | 5 +- .../grad_tensor_holder_test.cc | 11 +- .../eager/tests/task_tests/CMakeLists.txt | 10 +- .../eager/tests/task_tests/backward_test.cc | 1 + .../tests/task_tests/fwd_bwd_joint_test.cc | 2 + .../fluid/eager/tests/task_tests/grad_test.cc | 2 + paddle/phi/api/include/tensor.h | 1 + paddle/phi/api/lib/tensor.cc | 5 + 14 files changed, 124 insertions(+), 70 deletions(-) diff --git a/paddle/fluid/eager/CMakeLists.txt b/paddle/fluid/eager/CMakeLists.txt index d8089bedf924e..da326ff7d76d7 100644 --- a/paddle/fluid/eager/CMakeLists.txt +++ b/paddle/fluid/eager/CMakeLists.txt @@ -13,12 +13,16 @@ add_subdirectory(accumulation) add_subdirectory(custom_operator) if(NOT ((NOT WITH_PYTHON) AND ON_INFER)) add_subdirectory(pylayer) + cc_library(grad_tensor_holder SRCS grad_tensor_holder.cc DEPS grad_node_info gradient_accumulator) + add_dependencies(grad_tensor_holder eager_final_state_codegen) + cc_library(backward SRCS backward.cc DEPS grad_tensor_holder utils autograd_meta grad_node_info) endif() + cc_library(grad_node_info SRCS grad_node_info.cc DEPS phi_api phi_tensor) -cc_library(grad_tensor_holder SRCS grad_tensor_holder.cc DEPS grad_node_info gradient_accumulator) cc_library(autograd_meta SRCS autograd_meta.cc DEPS phi_api phi_tensor) cc_library(utils SRCS utils.cc DEPS phi_api phi_tensor global_utils layer proto_desc operator op_registry variable_helper memcpy scale_op autograd_meta hook_utils) -cc_library(backward SRCS backward.cc DEPS grad_tensor_holder utils autograd_meta grad_node_info) -add_subdirectory(tests) +if(NOT ((NOT WITH_PYTHON) AND ON_INFER)) + add_subdirectory(tests) +endif() diff --git a/paddle/fluid/eager/api/utils/hook_utils.cc b/paddle/fluid/eager/api/utils/hook_utils.cc index 9abd7be49d44c..8ee646b718c2f 100644 --- a/paddle/fluid/eager/api/utils/hook_utils.cc +++ b/paddle/fluid/eager/api/utils/hook_utils.cc @@ -76,6 +76,7 @@ void RetainGradForTensor(const paddle::experimental::Tensor& tensor) { VLOG(7) << "Set impl for RetainGrad Hook for tensor: " << t.name(); // Simply Copy impl() to grad_tensor grad_tensor->set_impl(t.impl()); + grad_tensor->set_autograd_meta(t.mutable_autograd_meta()); return *grad_tensor.get(); } else { VLOG(7) << "Retain NULL paddle::experimental::Tensor in Grad Hook"; diff --git a/paddle/fluid/eager/backward.cc b/paddle/fluid/eager/backward.cc index 0ce2f17cb45be..ed286dd5fd960 100644 --- a/paddle/fluid/eager/backward.cc +++ b/paddle/fluid/eager/backward.cc @@ -466,6 +466,7 @@ std::vector RunBackward( continue; } + // TODO(zhanlve): Copy and Modify GradNode if is_general_grad GradNodeBase* grad_node = shared_grad_node.get(); // Prepare GradTensorHolder @@ -486,16 +487,9 @@ std::vector RunBackward( // Feed given tensor if it's provided VLOG(6) << "Fill grad input tensor " << i << "with give grad tensor"; - if (grad_tensors[i].is_initialized()) { - // Deep copy - paddle::experimental::Tensor tmp_tensor; - tmp_tensor.copy_(grad_tensors[i], grad_tensors[i].inner_place(), false); - node_input_buffers_dict[grad_node]->add(input_info.first, - input_info.second, tmp_tensor); - } else { - node_input_buffers_dict[grad_node]->add( - 
input_info.first, input_info.second, grad_tensors[i]); - } + // Deep copy + node_input_buffers_dict[grad_node]->CopyValueFromTensor( + input_info.first, input_info.second, grad_tensors[i]); } else { VLOG(6) << "Fill grad input tensor " << i << " with 1.0"; @@ -504,7 +498,7 @@ std::vector RunBackward( // dims // GradTensorHolder will initialize another tensor with same tensortype, // datatype and dims but filled with 1.0 - node_input_buffers_dict[grad_node]->add( + node_input_buffers_dict[grad_node]->CopyValueFromTensor( input_info.first, input_info.second, tensor, true /*fill_one=true*/); } @@ -686,6 +680,7 @@ std::vector RunBackward( } } } + if (!is_general_grad) return {}; return GeneralGrad::Instance().GetResults(inputs, allow_unused, create_graph); } diff --git a/paddle/fluid/eager/grad_tensor_holder.cc b/paddle/fluid/eager/grad_tensor_holder.cc index 038ad09aa4d8b..b15d9b892f810 100644 --- a/paddle/fluid/eager/grad_tensor_holder.cc +++ b/paddle/fluid/eager/grad_tensor_holder.cc @@ -15,6 +15,7 @@ #include "paddle/fluid/eager/grad_tensor_holder.h" #include "paddle/fluid/imperative/gradient_accumulator.h" +#include "paddle/fluid/eager/api/generated/eager_generated/forwards/dygraph_functions.h" #include "paddle/fluid/framework/convert_utils.h" #include "paddle/fluid/framework/var_type.h" #include "paddle/phi/kernels/funcs/math_function.h" @@ -26,9 +27,9 @@ void GradTensorHolder::SetBufferSlotRankZeros(size_t slot_id, size_t rank) { paddle::experimental::zeros_like(buffer_[slot_id][rank]); } -void GradTensorHolder::add(size_t slot_id, size_t rank, - const paddle::experimental::Tensor& t, - bool fill_one) { +void GradTensorHolder::CopyValueFromTensor( + size_t slot_id, size_t rank, const paddle::experimental::Tensor& t, + bool fill_one) { // TODO(jiabin): We need to deal with empty input_buffer with slot size not // empty; PADDLE_ENFORCE(slot_id < buffer_.size(), @@ -50,44 +51,15 @@ void GradTensorHolder::add(size_t slot_id, size_t rank, slot_id, buffer_[slot_id].size(), rank)); if (!fill_one) { paddle::experimental::Tensor& buffer_tensor = buffer_[slot_id][rank]; - // TODO(jiabin): Code bellow is ugly to divide which inner var we used, - // remove framework::Variable - // related code later. - // This if statement is trying to test neither phi::Tensor nor - // framework::Variable is initialized. 
if ((!buffer_tensor.defined() || !buffer_tensor.initialized())) { - // Simply copy tensor->impl - buffer_tensor = t; + // Perform deep copy here + buffer_tensor.copy_(t, t.inner_place(), false); + buffer_tensor.set_autograd_meta(t.mutable_autograd_meta()); + } else { - // Accumulation - PADDLE_ENFORCE_EQ(t.initialized(), true, - paddle::platform::errors::Fatal( - "We can only accumulate initialized tensor, but we " - "got tensor: %s is empty please check you network " - "and make sure it creates grads.", - t.name())); - if (t.is_dense_tensor()) { - if (buffer_tensor.is_dense_tensor()) { - paddle::imperative::TensorAdd( - t, &buffer_tensor); - } else { - // TODO(jiabin): Support Other TensorBase later - paddle::experimental::Tensor new_buffer( - std::make_shared(), "tmp_accumulator"); - paddle::imperative::SelectedRowsAddTensor(buffer_tensor, t, - &new_buffer); - buffer_tensor.set_impl(new_buffer.impl()); - } - } else { - // TODO(jiabin): Support Other TensorBase later - if (buffer_tensor.is_dense_tensor()) { - paddle::imperative::SelectedRowsAddToTensor(t, &buffer_tensor); - } else { - buffer_tensor = - std::move(*paddle::imperative::SelectedRowsMerge< - paddle::experimental::Tensor>(t, buffer_tensor)); - } - } + PADDLE_THROW(paddle::platform::errors::Fatal( + "Cannot copy grad_tensors' value to grad tensor holders," + "input buffer has already been initialized.")); } } else { // Create new tensor->impl and fill it with 1.0 @@ -98,4 +70,72 @@ void GradTensorHolder::add(size_t slot_id, size_t rank, } } +void GradTensorHolder::add(size_t slot_id, size_t rank, + const paddle::experimental::Tensor& t) { + // TODO(jiabin): We need to deal with empty input_buffer with slot size not + // empty; + PADDLE_ENFORCE(slot_id < buffer_.size(), + paddle::platform::errors::Fatal( + "Invalid slot_id for GradTensorHolder::add() " + "which exceeds size of buffer")); + VLOG(6) << "Add Tensor for buffer_ slot: " << slot_id + << ", size: " << buffer_[slot_id].size(); + if (buffer_[slot_id].empty()) { + VLOG(6) << "Pass add Tensor for buffer_ slot: " << slot_id + << " since its buffer_ is empty "; + return; + } + PADDLE_ENFORCE( + rank < buffer_[slot_id].size(), + paddle::platform::errors::Fatal( + "Invalid rank for GradTensorHolder::add() which exceeds size " + "of buffer slot %d, got slot size is: %d rank is: %d", + slot_id, buffer_[slot_id].size(), rank)); + + paddle::experimental::Tensor& buffer_tensor = buffer_[slot_id][rank]; + // TODO(jiabin): Code bellow is ugly to divide which inner var we used, + // remove framework::Variable + // related code later. + // This if statement is trying to test neither phi::Tensor nor + // framework::Variable is initialized. 
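A note before the accumulation branch that follows: add() now routes the dense-dense case through add_final_state_dygraph_function instead of an in-place TensorAdd, so the accumulation op is itself recorded on the graph; combined with the deep copy in CopyValueFromTensor, this is what lets gradients flow through user-supplied grad_tensors. A minimal usage sketch of the behavior being enabled (illustrative only, not a test from this patch):

import paddle

x = paddle.randn([3])
x.stop_gradient = False
y = x * x
# A differentiable vector passed as grad_outputs to paddle.grad().
v = paddle.ones_like(y)
v.stop_gradient = False
dx, = paddle.grad([y], [x], grad_outputs=[v], create_graph=True)
# dx equals 2 * x * v and stays on the graph, so a second-order
# gradient through it is now possible:
d2x, = paddle.grad([dx], [x])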
+ if ((!buffer_tensor.defined() || !buffer_tensor.initialized())) { + // Simply copy tensor->impl + buffer_tensor = t; + } else { + // Accumulation + PADDLE_ENFORCE_EQ(t.initialized(), true, + paddle::platform::errors::Fatal( + "We can only accumulate initialized tensor, but we " + "got tensor: %s is empty please check you network " + "and make sure it creates grads.", + t.name())); + if (t.is_dense_tensor()) { + if (buffer_tensor.is_dense_tensor()) { + buffer_tensor = add_final_state_dygraph_function(t, buffer_tensor); + + } else { + // TODO(jiabin): Support Other TensorBase later + // TODO(zhanlve): Replace SelectedRowsAddTensor with + // add_dygraph_function once it's supported + paddle::experimental::Tensor new_buffer( + std::make_shared(), "tmp_accumulator"); + paddle::imperative::SelectedRowsAddTensor(buffer_tensor, t, + &new_buffer); + buffer_tensor.set_impl(new_buffer.impl()); + } + } else { + // TODO(jiabin): Support Other TensorBase later + // TODO(zhanlve): Replace SelectedRowsAddTensor with add_dygraph_function + // once it's supported + if (buffer_tensor.is_dense_tensor()) { + paddle::imperative::SelectedRowsAddToTensor(t, &buffer_tensor); + } else { + buffer_tensor = + std::move(*paddle::imperative::SelectedRowsMerge< + paddle::experimental::Tensor>(t, buffer_tensor)); + } + } + } +} + } // namespace egr diff --git a/paddle/fluid/eager/grad_tensor_holder.h b/paddle/fluid/eager/grad_tensor_holder.h index db03789ea7632..a4f2507728c64 100644 --- a/paddle/fluid/eager/grad_tensor_holder.h +++ b/paddle/fluid/eager/grad_tensor_holder.h @@ -45,8 +45,10 @@ class GradTensorHolder { GradTensorHolder& operator=(const GradTensorHolder& other) = default; // Create new tensor and copy tensor->impl - void add(size_t slot_id, size_t rank, const paddle::experimental::Tensor& t, - bool fill_one = false); + void add(size_t slot_id, size_t rank, const paddle::experimental::Tensor& t); + void CopyValueFromTensor(size_t slot_id, size_t rank, + const paddle::experimental::Tensor& t, + bool fill_one = false); const std::vector& operator[]( const size_t& pos) { diff --git a/paddle/fluid/eager/tests/CMakeLists.txt b/paddle/fluid/eager/tests/CMakeLists.txt index 2bfb9937c8c91..6bcd34262c8ab 100644 --- a/paddle/fluid/eager/tests/CMakeLists.txt +++ b/paddle/fluid/eager/tests/CMakeLists.txt @@ -1,6 +1,3 @@ add_subdirectory(data_structure_tests) add_subdirectory(task_tests) - -if(NOT ((NOT WITH_PYTHON) AND ON_INFER)) - add_subdirectory(performance_tests) -endif() +add_subdirectory(performance_tests) diff --git a/paddle/fluid/eager/tests/data_structure_tests/CMakeLists.txt b/paddle/fluid/eager/tests/data_structure_tests/CMakeLists.txt index e1cd9939aca77..76c59561fc0bb 100644 --- a/paddle/fluid/eager/tests/data_structure_tests/CMakeLists.txt +++ b/paddle/fluid/eager/tests/data_structure_tests/CMakeLists.txt @@ -1,6 +1,9 @@ cc_test(test_egr_ds_eager_tensor SRCS eager_tensor_test.cc DEPS ${eager_deps}) cc_test(test_egr_ds_auotgrad_meta SRCS autograd_meta_test.cc DEPS ${eager_deps}) cc_test(test_egr_ds_grad_node_info SRCS grad_node_info_test.cc DEPS ${eager_deps}) -cc_test(test_egr_ds_grad_tensor_holder SRCS grad_tensor_holder_test.cc DEPS ${eager_deps}) cc_test(test_egr_ds_accumulation_node SRCS accumulation_node_test.cc DEPS ${eager_deps}) cc_test(test_egr_ds_tensor_wrapper SRCS tensor_wrapper_test.cc DEPS ${eager_deps}) + +if(NOT ((NOT WITH_PYTHON) AND ON_INFER)) + cc_test(test_egr_ds_grad_tensor_holder SRCS grad_tensor_holder_test.cc DEPS ${eager_deps} ${generated_deps}) +endif() diff --git 
a/paddle/fluid/eager/tests/data_structure_tests/grad_tensor_holder_test.cc b/paddle/fluid/eager/tests/data_structure_tests/grad_tensor_holder_test.cc index 645eac06ddda5..7d2aafc63628e 100644 --- a/paddle/fluid/eager/tests/data_structure_tests/grad_tensor_holder_test.cc +++ b/paddle/fluid/eager/tests/data_structure_tests/grad_tensor_holder_test.cc @@ -25,6 +25,7 @@ #include "paddle/phi/core/kernel_registry.h" PD_DECLARE_KERNEL(full_like, CPU, ALL_LAYOUT); +PD_DECLARE_KERNEL(add, CPU, ALL_LAYOUT); // TODO(jiabin): remove nolint here!!! using namespace egr; // NOLINT @@ -77,11 +78,11 @@ TEST(GradTensorHolder, Interfaces) { // add(): // fill one - grad_tensor_holder.add(0, 0, et0, true); + grad_tensor_holder.CopyValueFromTensor(0, 0, et0, true); // accumulation - grad_tensor_holder.add(1, 0, et0, false); - grad_tensor_holder.add(1, 0, et1, false); + grad_tensor_holder.add(1, 0, et0); + grad_tensor_holder.add(1, 0, et1); // Buffers() const auto& buffers = grad_tensor_holder.Buffers(); @@ -141,8 +142,8 @@ TEST(GradTensorHolder, SelectedRowsMergeAdd) { GradTensorHolder({slot_meta, slot_meta}); // accumulation - grad_tensor_holder.add(0, 0, t1, false); - grad_tensor_holder.add(0, 0, t2, false); + grad_tensor_holder.add(0, 0, t1); + grad_tensor_holder.add(0, 0, t2); // Buffers() const auto& buffers = grad_tensor_holder.Buffers(); diff --git a/paddle/fluid/eager/tests/task_tests/CMakeLists.txt b/paddle/fluid/eager/tests/task_tests/CMakeLists.txt index 52dba6b9218c7..5a09ffd6a1e5f 100644 --- a/paddle/fluid/eager/tests/task_tests/CMakeLists.txt +++ b/paddle/fluid/eager/tests/task_tests/CMakeLists.txt @@ -1,13 +1,13 @@ cc_test(test_egr_task_tensor_utils SRCS tensor_utils_test.cc DEPS ${eager_deps}) cc_test(test_egr_task_eager_utils SRCS eager_utils_test.cc DEPS ${eager_deps}) cc_test(test_egr_task_forward_autograd SRCS forward_autograd_test.cc DEPS ${eager_deps} ${fluid_deps} eager_scale scale_node) -cc_test(test_egr_task_backward SRCS backward_test.cc DEPS ${eager_deps} ${fluid_deps} eager_scale scale_node) -cc_test(test_egr_task_hook SRCS hook_test.cc DEPS ${eager_deps} ${fluid_deps} eager_scale scale_node) -cc_test(test_egr_task_cross_batch SRCS cross_batch_accumulation_test.cc DEPS ${eager_deps} ${fluid_deps} eager_scale scale_node) -cc_test(test_egr_task_fwd_bwd_joint SRCS fwd_bwd_joint_test.cc DEPS ${eager_deps} ${fluid_deps} eager_scale scale_node) -cc_test(test_egr_task_grad SRCS grad_test.cc DEPS ${eager_deps} ${fluid_deps} eager_scale scale_node) if(NOT ((NOT WITH_PYTHON) AND ON_INFER)) + cc_test(test_egr_task_hook SRCS hook_test.cc DEPS ${eager_deps} ${fluid_deps} ${generated_deps} eager_scale scale_node) + cc_test(test_egr_task_backward SRCS backward_test.cc DEPS ${eager_deps} ${fluid_deps} ${generated_deps} eager_scale scale_node) + cc_test(test_egr_task_grad SRCS grad_test.cc DEPS ${eager_deps} ${fluid_deps} ${generated_deps} eager_scale scale_node) + cc_test(test_egr_task_fwd_bwd_joint SRCS fwd_bwd_joint_test.cc DEPS ${eager_deps} ${fluid_deps} ${generated_deps} eager_scale scale_node) + cc_test(test_egr_task_cross_batch SRCS cross_batch_accumulation_test.cc DEPS ${eager_deps} ${fluid_deps} ${generated_deps} eager_scale scale_node) cc_test(test_egr_task_hook_intermidiate SRCS hook_test_intermidiate.cc DEPS ${eager_deps} ${fluid_deps} ${generated_deps} dygraph_node) cc_test(test_egr_task_autocodegen SRCS generated_test.cc DEPS ${eager_deps} ${fluid_deps} ${generated_deps}) endif() diff --git a/paddle/fluid/eager/tests/task_tests/backward_test.cc 
b/paddle/fluid/eager/tests/task_tests/backward_test.cc index 87f8f6eca1f88..8c127efa4f7f3 100644 --- a/paddle/fluid/eager/tests/task_tests/backward_test.cc +++ b/paddle/fluid/eager/tests/task_tests/backward_test.cc @@ -34,6 +34,7 @@ PD_DECLARE_KERNEL(full, CPU, ALL_LAYOUT); PD_DECLARE_KERNEL(copy, CPU, ALL_LAYOUT); +PD_DECLARE_KERNEL(add, CPU, ALL_LAYOUT); namespace egr { diff --git a/paddle/fluid/eager/tests/task_tests/fwd_bwd_joint_test.cc b/paddle/fluid/eager/tests/task_tests/fwd_bwd_joint_test.cc index 882695e98d109..d2bef100ca2b5 100644 --- a/paddle/fluid/eager/tests/task_tests/fwd_bwd_joint_test.cc +++ b/paddle/fluid/eager/tests/task_tests/fwd_bwd_joint_test.cc @@ -33,8 +33,10 @@ #include "paddle/phi/core/kernel_registry.h" PD_DECLARE_KERNEL(full, CPU, ALL_LAYOUT); +PD_DECLARE_KERNEL(add, CPU, ALL_LAYOUT); #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) PD_DECLARE_KERNEL(full, GPU, ALL_LAYOUT); +PD_DECLARE_KERNEL(add, GPU, ALL_LAYOUT); #endif namespace egr { diff --git a/paddle/fluid/eager/tests/task_tests/grad_test.cc b/paddle/fluid/eager/tests/task_tests/grad_test.cc index 6b03799c48659..7e64c65d8205e 100644 --- a/paddle/fluid/eager/tests/task_tests/grad_test.cc +++ b/paddle/fluid/eager/tests/task_tests/grad_test.cc @@ -33,6 +33,8 @@ PD_DECLARE_KERNEL(full, CPU, ALL_LAYOUT); PD_DECLARE_KERNEL(copy, CPU, ALL_LAYOUT); +PD_DECLARE_KERNEL(add, CPU, ALL_LAYOUT); + namespace egr { TEST(Grad, SingleNodeEmptyGrad) { diff --git a/paddle/phi/api/include/tensor.h b/paddle/phi/api/include/tensor.h index 9b0371fc380b9..0a2e815be8411 100644 --- a/paddle/phi/api/include/tensor.h +++ b/paddle/phi/api/include/tensor.h @@ -494,6 +494,7 @@ class PADDLE_API Tensor final { * @return AbstractAutogradMeta* */ AbstractAutogradMeta* get_autograd_meta() const; + const std::shared_ptr& mutable_autograd_meta() const; /** * @brief Set the autograd meta object diff --git a/paddle/phi/api/lib/tensor.cc b/paddle/phi/api/lib/tensor.cc index 5cd1fcb919638..3790384c8af16 100644 --- a/paddle/phi/api/lib/tensor.cc +++ b/paddle/phi/api/lib/tensor.cc @@ -354,6 +354,11 @@ AbstractAutogradMeta *Tensor::get_autograd_meta() const { return autograd_meta_.get(); } +const std::shared_ptr &Tensor::mutable_autograd_meta() + const { + return autograd_meta_; +} + void Tensor::set_autograd_meta( std::shared_ptr autograd_meta) { autograd_meta_ = std::move(autograd_meta); From 6b5cff5462fd7c37d0da57510585a847f67ae7f4 Mon Sep 17 00:00:00 2001 From: 0x45f <23097963+0x45f@users.noreply.github.com> Date: Sat, 2 Apr 2022 15:48:08 +0800 Subject: [PATCH 18/93] Add UT for full_like after migration YAML (#41290) * Add UT for full_like after migration YAML * rename test class --- .../tests/unittests/test_full_like_op.py | 41 +++++++++++++++++++ python/paddle/tensor/creation.py | 5 ++- 2 files changed, 45 insertions(+), 1 deletion(-) diff --git a/python/paddle/fluid/tests/unittests/test_full_like_op.py b/python/paddle/fluid/tests/unittests/test_full_like_op.py index 3ae2e9ff6bdaf..05a310a9c5033 100644 --- a/python/paddle/fluid/tests/unittests/test_full_like_op.py +++ b/python/paddle/fluid/tests/unittests/test_full_like_op.py @@ -21,6 +21,7 @@ import unittest import numpy as np from op_test import OpTest +from paddle.fluid.framework import convert_np_dtype_to_dtype_ class TestFullOp(unittest.TestCase): @@ -92,5 +93,45 @@ def test_input_dtype(): dtype='uint4') +class TestFullLikeOp1(OpTest): + # test basic + def setUp(self): + self.op_type = "fill_any_like" + self.python_api = paddle.full_like + self.init_data() + + x = 
np.zeros(self.shape) + out = np.full_like(x, self.fill_value, self.dtype) + + self.inputs = {'X': x} + self.outputs = {'Out': out} + self.attrs = { + 'value': self.fill_value, + 'dtype': convert_np_dtype_to_dtype_(self.dtype) + } + + def init_data(self): + self.fill_value = 5 + self.shape = [10, 10] + self.dtype = np.float32 + + def test_check_output(self): + self.check_output(check_eager=True) + + +class TestFullLikeOp2(TestFullLikeOp1): + def init_data(self): + self.fill_value = 1000 + self.shape = [1024, 1024] + self.dtype = np.float64 + + +class TestFullLikeOp3(TestFullLikeOp1): + def init_data(self): + self.fill_value = 8888 + self.shape = [5000, 5000] + self.dtype = np.int64 + + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/tensor/creation.py b/python/paddle/tensor/creation.py index 6e7e5678be0b0..ca16995f84d2f 100644 --- a/python/paddle/tensor/creation.py +++ b/python/paddle/tensor/creation.py @@ -224,7 +224,10 @@ def full_like(x, fill_value, dtype=None, name=None): if not isinstance(dtype, core.VarDesc.VarType): dtype = convert_np_dtype_to_dtype_(dtype) - if paddle.in_dynamic_mode(): + if in_dygraph_mode(): + return _C_ops.final_state_full_like(x, fill_value, dtype, x.place) + + if _in_legacy_dygraph(): return _C_ops.fill_any_like(x, 'value', fill_value, 'dtype', dtype) helper = LayerHelper("full_like", **locals()) From a9d66025a378f03c71f5bfb74481c6348f4448b3 Mon Sep 17 00:00:00 2001 From: Jiabin Yang <360788950@qq.com> Date: Sat, 2 Apr 2022 16:40:36 +0800 Subject: [PATCH 19/93] Fix ci problem2 (#41263) * support test_create_paramter * support fused_transformer_encoder_layer * skip program_desc tracer related tests in eager mode * fix ci tests on eager --- .../tests/unittests/test_create_parameter.py | 8 +- .../test_fused_transformer_encoder_layer.py | 4 +- ...imperative_trace_non_persistable_inputs.py | 2 + .../fluid/tests/unittests/test_initializer.py | 35 +++- .../unittests/test_op_function_generator.py | 4 +- .../fluid/tests/unittests/test_parameter.py | 18 +- .../tests/unittests/test_retain_graph.py | 10 +- .../unittests/test_traced_layer_err_msg.py | 11 ++ python/paddle/nn/initializer/dirac.py | 157 +++++++++++------- .../paddle/nn/utils/transform_parameters.py | 60 ++++--- python/paddle/tests/test_model.py | 2 + 11 files changed, 215 insertions(+), 96 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_create_parameter.py b/python/paddle/fluid/tests/unittests/test_create_parameter.py index 199558acd4ef6..fb4b5e4b6fa88 100644 --- a/python/paddle/fluid/tests/unittests/test_create_parameter.py +++ b/python/paddle/fluid/tests/unittests/test_create_parameter.py @@ -22,7 +22,8 @@ class TestCreateParameterError(unittest.TestCase): - def test_errors(self): + def func_errors(self): + paddle.enable_static() with program_guard(Program(), Program()): def test_shape(): @@ -49,6 +50,11 @@ def test_default_initializer(): self.assertRaises(TypeError, test_default_initializer) + def test_errors(self): + with fluid.framework._test_eager_guard(): + self.func_errors() + self.func_errors() + if __name__ == '__main__': paddle.enable_static() diff --git a/python/paddle/fluid/tests/unittests/test_fused_transformer_encoder_layer.py b/python/paddle/fluid/tests/unittests/test_fused_transformer_encoder_layer.py index e0281d6e21e5a..7dc86d0dea382 100644 --- a/python/paddle/fluid/tests/unittests/test_fused_transformer_encoder_layer.py +++ b/python/paddle/fluid/tests/unittests/test_fused_transformer_encoder_layer.py @@ -16,7 +16,7 @@ import paddle from 
paddle.incubate.nn import FusedTransformerEncoderLayer from paddle.nn import TransformerEncoderLayer -from paddle.fluid.framework import default_main_program +from paddle.fluid.framework import default_main_program, in_dygraph_mode import unittest @@ -61,6 +61,8 @@ def fused_qkv(self, q, k, v, num_head): return paddle.concat(x=[fq, fk, fv], axis=0) def test_out(self): + if in_dygraph_mode(): + return default_main_program().random_seed = 42 base_encoder = TransformerEncoderLayer( self.d_model, self.nhead, self.dim_feedforward, self.dropout_rate, diff --git a/python/paddle/fluid/tests/unittests/test_imperative_trace_non_persistable_inputs.py b/python/paddle/fluid/tests/unittests/test_imperative_trace_non_persistable_inputs.py index 645a05e75f6fb..a621105f5084c 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_trace_non_persistable_inputs.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_trace_non_persistable_inputs.py @@ -33,6 +33,8 @@ def forward(self, x): class TestTracedLayerRecordNonPersistableInput(unittest.TestCase): def test_main(self): + if fluid.framework.in_dygraph_mode(): + return traced_layer = None with fluid.dygraph.guard(): feature_size = 3 diff --git a/python/paddle/fluid/tests/unittests/test_initializer.py b/python/paddle/fluid/tests/unittests/test_initializer.py index 8dc822c69b2c5..91c2800836c9d 100644 --- a/python/paddle/fluid/tests/unittests/test_initializer.py +++ b/python/paddle/fluid/tests/unittests/test_initializer.py @@ -655,7 +655,7 @@ def test_set_global_bias_initilizer(self): class TestUniformInitializerDygraph(unittest.TestCase): - def test_uniform_initializer(self, dtype="float32"): + def func_uniform_initializer(self, dtype="float32"): """ In dygraph mode, we can use initializer directly to initialize a tensor. 
""" @@ -679,9 +679,14 @@ def test_uniform_initializer(self, dtype="float32"): paddle.enable_static() + def test_uniform_initializer(self, dtype="float32"): + with framework._test_eager_guard(): + self.func_uniform_initializer() + self.func_uniform_initializer() + class TesetconsistencyOfDynamicAndStaticGraph(unittest.TestCase): - def test_order(self): + def func_order(self): paddle.set_device('cpu') SEED = 123 weight_attr = paddle.framework.ParamAttr( @@ -723,6 +728,11 @@ def run_static_graph(): self.assertTrue(np.array_equal(dynamic_res[0], static_res[0])) self.assertTrue(np.array_equal(dynamic_res[1], static_res[1])) + def test_order(self): + with framework._test_eager_guard(): + self.func_order() + self.func_order() + # 2-D Parameter with shape: [10, 15] class TestOrthogonalInitializer1(unittest.TestCase): @@ -742,7 +752,7 @@ def check_result(self, a, b): self.assertTrue(np.array_equal(a, b)) self.assertTrue(np.allclose(np.matmul(a, a.T), 9 * np.eye(10))) - def test_orthogonal(self): + def func_orthogonal(self): self.config() paddle.set_default_dtype(self.dtype) @@ -777,6 +787,11 @@ def test_orthogonal(self): self.check_result(res_dygraph, res_static) + def test_orthogonal(self): + with framework._test_eager_guard(): + self.func_orthogonal() + self.func_orthogonal() + # 2-D Parameter with shape: [15, 10] class TestOrthogonalInitializer2(TestOrthogonalInitializer1): @@ -841,7 +856,7 @@ def check_result(self, a, b): a = a.reshape(6, -1) self.assertTrue(np.allclose(np.matmul(a, a.T), 9 * np.eye(6))) - def test_orthogonal(self): + def func_orthogonal(self): self.config() paddle.set_default_dtype(self.dtype) @@ -869,6 +884,11 @@ def test_orthogonal(self): fetch_list=[conv2d.weight])[0] self.check_result(res_dygraph, res_static) + def test_orthogonal(self): + with framework._test_eager_guard(): + self.func_orthogonal() + self.func_orthogonal() + # 4-D Parameter with shape: [50, 4, 3, 3] class TestOrthogonalInitializer5(TestOrthogonalInitializer4): @@ -928,7 +948,7 @@ def check_result(self, w_dygraph, w_static, conv_in, conv_out): self.assertTrue(np.array_equal(w_dygraph, w_static)) self.assertTrue(np.array_equal(conv_out, conv_in[:, 0:2, 1:9])) - def test_dirac(self): + def func_dirac(self): self.config() paddle.set_default_dtype(self.dtype) @@ -971,6 +991,11 @@ def test_dirac(self): self.check_result(weight_dygraph, weight_static, conv_input, conv_output) + def test_dirac(self): + with framework._test_eager_guard(): + self.func_dirac() + self.func_dirac() + # initialize Conv2D weight class TestDiracInitializer2(TestDiracInitializer1): diff --git a/python/paddle/fluid/tests/unittests/test_op_function_generator.py b/python/paddle/fluid/tests/unittests/test_op_function_generator.py index 216deddb9ef98..c712b5db0f31f 100644 --- a/python/paddle/fluid/tests/unittests/test_op_function_generator.py +++ b/python/paddle/fluid/tests/unittests/test_op_function_generator.py @@ -15,7 +15,7 @@ from __future__ import print_function import unittest -from paddle.fluid.framework import default_main_program, Program, convert_np_dtype_to_dtype_, _non_static_mode +from paddle.fluid.framework import default_main_program, Program, convert_np_dtype_to_dtype_, _non_static_mode, in_dygraph_mode import paddle.fluid as fluid import paddle.fluid.layers as layers import paddle.fluid.core as core @@ -92,6 +92,8 @@ def test_trace_backward(self): self.assertTrue(np.array_equal(y_grad, loss.gradient() * a)) def test_traced_layer(self): + if in_dygraph_mode(): + return with fluid.dygraph.guard(): layer = 
TestTracedLayer("test_traced_layer") a = np.random.uniform(-1, 1, self.shape).astype(self.dtype) diff --git a/python/paddle/fluid/tests/unittests/test_parameter.py b/python/paddle/fluid/tests/unittests/test_parameter.py index 85ba69cd438a7..61d75fca2745e 100644 --- a/python/paddle/fluid/tests/unittests/test_parameter.py +++ b/python/paddle/fluid/tests/unittests/test_parameter.py @@ -18,7 +18,7 @@ import copy import paddle from paddle.fluid.dygraph import guard -from paddle.fluid.framework import default_main_program, Variable +from paddle.fluid.framework import default_main_program, Variable, _test_eager_guard import paddle.fluid.core as core from paddle.fluid.executor import Executor import paddle.fluid.io as io @@ -50,7 +50,7 @@ def test_parameter(self): p = io.get_parameter_value_by_name('fc.w', exe, main_program) self.assertTrue(np.array_equal(p, np.ones(shape) * val)) - def test_parambase(self): + def func_parambase(self): with guard(): linear = paddle.nn.Linear(10, 10) param = linear.weight @@ -72,7 +72,12 @@ def test_parambase(self): pram_copy2 = copy.deepcopy(param, memo) self.assertEqual(id(param_copy), id(pram_copy2)) - def test_exception(self): + def test_parambase(self): + with _test_eager_guard(): + self.func_parambase() + self.func_parambase() + + def func_exception(self): b = main_program.global_block() with self.assertRaises(ValueError): b.create_parameter( @@ -87,7 +92,7 @@ def test_exception(self): b.create_parameter( name='test', shape=[-1], dtype='float32', initializer=None) - def test_parambase_to_vector(self): + def func_parambase_to_vector(self): with guard(): initializer = paddle.ParamAttr( initializer=paddle.nn.initializer.Constant(3.)) @@ -112,6 +117,11 @@ def test_parambase_to_vector(self): self.assertTrue(linear2.weight.is_leaf, True) self.assertTrue(linear2.bias.is_leaf, True) + def test_parambase_to_vector(self): + with _test_eager_guard(): + self.func_parambase_to_vector() + self.func_parambase_to_vector() + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_retain_graph.py b/python/paddle/fluid/tests/unittests/test_retain_graph.py index 79664fe4b12fb..0259b898a488e 100644 --- a/python/paddle/fluid/tests/unittests/test_retain_graph.py +++ b/python/paddle/fluid/tests/unittests/test_retain_graph.py @@ -134,9 +134,15 @@ def run_retain(self, need_retain): loss_g.backward() optim_g.minimize(loss_g) - def test_retain(self): + def func_retain(self): self.run_retain(need_retain=True) - self.assertRaises(RuntimeError, self.run_retain, need_retain=False) + if not fluid.framework.in_dygraph_mode(): + self.assertRaises(RuntimeError, self.run_retain, need_retain=False) + + def test_retain(self): + with fluid.framework._test_eager_guard(): + self.func_retain() + self.func_retain() if __name__ == '__main__': diff --git a/python/paddle/fluid/tests/unittests/test_traced_layer_err_msg.py b/python/paddle/fluid/tests/unittests/test_traced_layer_err_msg.py index 3b9fbd69e9d0a..5703ce1313176 100644 --- a/python/paddle/fluid/tests/unittests/test_traced_layer_err_msg.py +++ b/python/paddle/fluid/tests/unittests/test_traced_layer_err_msg.py @@ -1,4 +1,5 @@ # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -52,6 +53,8 @@ def setUp(self): self.type_str = 'class' def test_trace_err(self): + if fluid.framework.in_dygraph_mode(): + return with fluid.dygraph.guard(): in_x = fluid.dygraph.to_variable( np.random.random((self.batch_size, self.feature_size)).astype( @@ -80,6 +83,8 @@ def test_trace_err(self): self.layer, [in_x]) def test_set_strategy_err(self): + if fluid.framework.in_dygraph_mode(): + return with fluid.dygraph.guard(): in_x = fluid.dygraph.to_variable( np.random.random((self.batch_size, self.feature_size)).astype( @@ -105,6 +110,8 @@ def test_set_strategy_err(self): fluid.ExecutionStrategy()) def test_save_inference_model_err(self): + if fluid.framework.in_dygraph_mode(): + return with fluid.dygraph.guard(): in_x = fluid.dygraph.to_variable( np.random.random((self.batch_size, self.feature_size)).astype( @@ -169,6 +176,8 @@ def _train_simple_net(self): class TestOutVarWithNoneErrMsg(unittest.TestCase): def test_linear_net_with_none(self): + if fluid.framework.in_dygraph_mode(): + return model = LinearNetWithNone(100, 16) in_x = paddle.to_tensor(np.random.random((4, 100)).astype('float32')) with self.assertRaises(TypeError): @@ -186,6 +195,8 @@ def setUp(self): shutil.rmtree(os.path.dirname(self.save_path)) def test_mkdir_when_input_path_non_exist(self): + if fluid.framework.in_dygraph_mode(): + return fc_layer = SimpleFCLayer(3, 4, 2) input_var = paddle.to_tensor(np.random.random([4, 3]).astype('float32')) with fluid.dygraph.guard(): diff --git a/python/paddle/nn/initializer/dirac.py b/python/paddle/nn/initializer/dirac.py index da3266ab33694..46f47fbc7b639 100644 --- a/python/paddle/nn/initializer/dirac.py +++ b/python/paddle/nn/initializer/dirac.py @@ -18,7 +18,8 @@ from ...fluid import framework from paddle import in_dynamic_mode from paddle.utils import unique_name - +from paddle import _C_ops +from ... 
import fluid __all__ = [] @@ -123,17 +124,24 @@ def __call__(self, var, block=None): persistable=False) else: out_var = var - - block.append_op( - type='fill_constant', - inputs={}, - outputs={'Out': out_var}, - attrs={ - 'value': float(0), - 'dtype': out_var.dtype, - 'shape': out_var.shape, - }, - stop_gradient=True) + op = None + if framework.in_dygraph_mode(): + with fluid.dygraph.no_grad(): + _C_ops.fill_constant(out_var, 'value', + float(0), 'force_cpu', False, 'dtype', + out_var.dtype, 'str_value', + str(float(0)), 'shape', out_var.shape) + else: + block.append_op( + type='fill_constant', + inputs={}, + outputs={'Out': out_var}, + attrs={ + 'value': float(0), + 'dtype': out_var.dtype, + 'shape': out_var.shape, + }, + stop_gradient=True) origin_shape = var.shape num_per_group = origin_shape[0] // self._groups @@ -158,71 +166,100 @@ def __call__(self, var, block=None): else: offset += origin_shape[k] // 2 * stride idx_list.append(offset) - - block.append_op( - type="reshape", - inputs={"X": out_var}, - attrs={'shape': [-1]}, - outputs={"Out": out_var}, - stop_gradient=True) + if framework.in_dygraph_mode(): + with fluid.dygraph.no_grad(): + tmp_out = _C_ops.reshape(out_var, 'shape', [-1]) + tmp_out._share_underline_tensor_to(out_var) + else: + block.append_op( + type="reshape", + inputs={"X": out_var}, + attrs={'shape': [-1]}, + outputs={"Out": out_var}, + stop_gradient=True) index_tensor = block.create_var( name=unique_name.generate('scatter_index'), persistable=False, stop_gradient=True) - block.append_op( - type='assign_value', - outputs={'Out': index_tensor}, - attrs={ - 'dtype': VarDesc.VarType.INT64, - 'shape': [len(idx_list)], - 'int64_values': idx_list - }, - stop_gradient=True) + if framework.in_dygraph_mode(): + with fluid.dygraph.no_grad(): + tmp_tensor = _C_ops.assign_value('shape', [len(idx_list)], + 'dtype', VarDesc.VarType.INT64, + 'int64_values', idx_list) + tmp_tensor._share_underline_tensor_to(index_tensor) + else: + block.append_op( + type='assign_value', + outputs={'Out': index_tensor}, + attrs={ + 'dtype': VarDesc.VarType.INT64, + 'shape': [len(idx_list)], + 'int64_values': idx_list + }, + stop_gradient=True) value_tensor = block.create_var( name=unique_name.generate('scatter_value'), persistable=False, stop_gradient=True) - block.append_op( - type='assign_value', - outputs={'Out': value_tensor}, - attrs={ - 'dtype': VarDesc.VarType.FP32, - 'shape': [len(value_list)], - 'fp32_values': value_list - }, - stop_gradient=True) - - op = block.append_op( - type="scatter", - inputs={ - "X": out_var, - "Ids": index_tensor, - "Updates": value_tensor - }, - attrs={'overwrite': True}, - outputs={"Out": out_var}, - stop_gradient=True) + if framework.in_dygraph_mode(): + with fluid.dygraph.no_grad(): + tmp_tensor = _C_ops.assign_value('shape', [len(value_list)], + 'dtype', VarDesc.VarType.FP32, + 'fp32_values', value_list) + tmp_tensor._share_underline_tensor_to(value_tensor) + else: + block.append_op( + type='assign_value', + outputs={'Out': value_tensor}, + attrs={ + 'dtype': VarDesc.VarType.FP32, + 'shape': [len(value_list)], + 'fp32_values': value_list + }, + stop_gradient=True) - block.append_op( - type="reshape", - inputs={"X": out_var}, - attrs={'shape': origin_shape}, - outputs={"Out": out_var}, - stop_gradient=True) + if framework.in_dygraph_mode(): + with fluid.dygraph.no_grad(): + tmp_out = _C_ops.final_state_scatter(out_var, index_tensor, + value_tensor, True) + tmp_out._share_underline_tensor_to(out_var) + tmp_reshape_out = _C_ops.reshape(out_var, 'shape', 
origin_shape) + tmp_reshape_out._share_underline_tensor_to(out_var) + if var.dtype != VarDesc.VarType.FP32: + tmp_cast_out = _C_ops.cast(out_var, 'in_dtype', + out_var.dtype, 'out_dtype', + var.dtype) + tmp_cast_out._share_underline_tensor_to(var) - if var.dtype != VarDesc.VarType.FP32: + else: + op = block.append_op( + type="scatter", + inputs={ + "X": out_var, + "Ids": index_tensor, + "Updates": value_tensor + }, + attrs={'overwrite': True}, + outputs={"Out": out_var}, + stop_gradient=True) block.append_op( - type="cast", + type="reshape", inputs={"X": out_var}, - outputs={"Out": var}, - attrs={"in_dtype": out_var.dtype, - "out_dtype": var.dtype}, + attrs={'shape': origin_shape}, + outputs={"Out": out_var}, stop_gradient=True) - + if var.dtype != VarDesc.VarType.FP32: + block.append_op( + type="cast", + inputs={"X": out_var}, + outputs={"Out": var}, + attrs={"in_dtype": out_var.dtype, + "out_dtype": var.dtype}, + stop_gradient=True) if not in_dynamic_mode(): var.op = op return op diff --git a/python/paddle/nn/utils/transform_parameters.py b/python/paddle/nn/utils/transform_parameters.py index 03d2fa514869d..ef5cd8700761f 100644 --- a/python/paddle/nn/utils/transform_parameters.py +++ b/python/paddle/nn/utils/transform_parameters.py @@ -15,20 +15,25 @@ from functools import reduce import paddle -from paddle.fluid.framework import dygraph_only, _dygraph_tracer, _varbase_creator +from paddle.fluid.framework import dygraph_only, _dygraph_tracer, _varbase_creator, in_dygraph_mode from paddle import _C_ops #input==output, inplace strategy of reshape has no cost almostly def _inplace_reshape_dygraph(x, shape): - x_shape = _varbase_creator(dtype=x.dtype) - _dygraph_tracer().trace_op( - type="reshape2", - inputs={'X': x}, - outputs={'Out': x, - 'XShape': x_shape}, - attrs={'shape': shape}, - stop_gradient=True) + x_shape = _varbase_creator(dtype='int64') + if in_dygraph_mode(): + with paddle.fluid.dygraph.no_grad(): + tmp_out, _ = _C_ops.reshape2(x, None, 'shape', shape) + tmp_out._share_underline_tensor_to(x) + else: + _dygraph_tracer().trace_op( + type="reshape2", + inputs={'X': x}, + outputs={'Out': x, + 'XShape': x_shape}, + attrs={'shape': shape}, + stop_gradient=True) @dygraph_only @@ -62,12 +67,16 @@ def parameters_to_vector(parameters, name=None): _inplace_reshape_dygraph(param, [-1]) out = _varbase_creator(dtype=dtype) - _dygraph_tracer().trace_op( - type='concat', - inputs={'X': parameters}, - outputs={'Out': [out]}, - attrs={'axis': 0}, - stop_gradient=True) + if in_dygraph_mode(): + with paddle.fluid.dygraph.no_grad(): + _C_ops.concat(parameters, 'axis', 0)._share_underline_tensor_to(out) + else: + _dygraph_tracer().trace_op( + type='concat', + inputs={'X': parameters}, + outputs={'Out': [out]}, + attrs={'axis': 0}, + stop_gradient=True) for i, param in enumerate(parameters): _inplace_reshape_dygraph(param, origin_shapes[i]) return out @@ -109,13 +118,20 @@ def vector_to_parameters(vec, parameters, name=None): numel = reduce(lambda x, y: x * y, shape) sections.append(numel) - _dygraph_tracer().trace_op( - type='split', - inputs={'X': [vec]}, - outputs={'Out': parameters}, - attrs={'axis': 0, - 'sections': sections}, - stop_gradient=True) + if in_dygraph_mode(): + with paddle.fluid.dygraph.no_grad(): + res = _C_ops.split(vec, + len(parameters), 'axis', 0, 'sections', sections) + for i in range(0, len(res)): + res[i]._share_underline_tensor_to(parameters[i]) + else: + _dygraph_tracer().trace_op( + type='split', + inputs={'X': [vec]}, + outputs={'Out': parameters}, + attrs={'axis': 0, + 
'sections': sections}, + stop_gradient=True) for i, param in enumerate(parameters): _inplace_reshape_dygraph(param, origin_shapes[i]) diff --git a/python/paddle/tests/test_model.py b/python/paddle/tests/test_model.py index 53dce286b71e9..ce3a3bd4b02fe 100644 --- a/python/paddle/tests/test_model.py +++ b/python/paddle/tests/test_model.py @@ -715,6 +715,8 @@ def test_summary_error(self): paddle.summary(nlp_net, (1, 1, 2)) def test_static_flops(self): + if paddle.fluid.framework._in_eager_without_dygraph_check(): + return paddle.disable_static() net = models.__dict__['mobilenet_v2'](pretrained=False) inputs = paddle.randn([1, 3, 224, 224]) From f48261373343ff5ad1c9093296cbedf932070c36 Mon Sep 17 00:00:00 2001 From: wangguanqun Date: Sat, 2 Apr 2022 16:52:28 +0800 Subject: [PATCH 20/93] Delete function in accessor and update function name in accessor and sgd (#41292) * delete function * fix bug * update name * fix bug in strategy --- .../distributed/ps/service/brpc_ps_client.cc | 73 +++---- .../distributed/ps/service/brpc_ps_server.cc | 4 +- .../distributed/ps/service/ps_local_client.cc | 11 +- paddle/fluid/distributed/ps/table/accessor.h | 26 +-- .../ps/table/common_dense_table.cc | 3 +- .../distributed/ps/table/ctr_accessor.cc | 178 ++++++---------- .../fluid/distributed/ps/table/ctr_accessor.h | 75 +++---- .../ps/table/ctr_double_accessor.cc | 192 ++++++------------ .../ps/table/ctr_double_accessor.h | 86 +++----- .../ps/table/downpour_ctr_accessor.cc | 189 ++++++----------- .../ps/table/downpour_ctr_accessor.h | 86 +++----- .../ps/table/memory_sparse_table.cc | 32 +-- .../distributed/ps/table/sparse_accessor.cc | 180 ++++++---------- .../distributed/ps/table/sparse_accessor.h | 72 +++---- .../distributed/ps/table/sparse_sgd_rule.cc | 93 +++++---- .../distributed/ps/table/sparse_sgd_rule.h | 90 ++++---- paddle/fluid/distributed/ps/table/table.cc | 1 - paddle/fluid/distributed/ps/table/table.h | 1 - .../distributed/ps/table/tensor_accessor.cc | 52 +---- .../distributed/ps/table/tensor_accessor.h | 18 +- .../distributed/test/ctr_accessor_test.cc | 32 +-- .../distributed/test/sparse_sgd_rule_test.cc | 46 ++--- .../distributed/fleet/base/fleet_base.py | 4 +- 23 files changed, 586 insertions(+), 958 deletions(-) mode change 100755 => 100644 paddle/fluid/distributed/ps/service/brpc_ps_client.cc diff --git a/paddle/fluid/distributed/ps/service/brpc_ps_client.cc b/paddle/fluid/distributed/ps/service/brpc_ps_client.cc old mode 100755 new mode 100644 index 893e0f9a97596..971c448bf2714 --- a/paddle/fluid/distributed/ps/service/brpc_ps_client.cc +++ b/paddle/fluid/distributed/ps/service/brpc_ps_client.cc @@ -525,12 +525,12 @@ std::future BrpcPsClient::PullGeoParam(size_t table_id, io_buffer_itr.copy_and_forward(reinterpret_cast(&shard_nums), sizeof(uint32_t)); keys->resize(shard_nums); - values->resize(shard_nums * accessor->GetTableInfo(UPDATE_DIM)); + values->resize(shard_nums * accessor->GetAccessorInfo().update_dim); io_buffer_itr.copy_and_forward((void *)(keys->data()), // NOLINT sizeof(uint64_t) * shard_nums); io_buffer_itr.copy_and_forward( (void *)(values->data()), // NOLINT - shard_nums * accessor->GetTableInfo(UPDATE_SIZE)); + shard_nums * accessor->GetAccessorInfo().update_size); closure->set_promise_value(ret); }); auto promise = std::make_shared>(); @@ -573,7 +573,7 @@ std::future BrpcPsClient::PushSparseParam(size_t table_id, auto kvs = ids[shard_idx]; auto value_ptr = value_ptrs[shard_idx]; size_t kv_size = kvs.size(); - uint32_t value_size = accessor->GetTableInfo(UPDATE_SIZE); + 
uint32_t value_size = accessor->GetAccessorInfo().update_size; // 发送RPC请求 auto *push_request = closure->request(shard_idx); push_request->set_cmd_id(PS_PUSH_SPARSE_PARAM); @@ -581,14 +581,13 @@ std::future BrpcPsClient::PushSparseParam(size_t table_id, push_request->set_client_id(_client_id); push_request->add_params((char *)&kv_size, sizeof(uint32_t)); // NOLINT auto *push_data = push_request->mutable_data(); - push_data->resize(kv_size * - (sizeof(uint64_t) + accessor->GetTableInfo(UPDATE_SIZE))); + push_data->resize(kv_size * (sizeof(uint64_t) + value_size)); char *push_data_ptr = const_cast(push_data->data()); memcpy(push_data_ptr, kvs.data(), kv_size * sizeof(uint64_t)); push_data_ptr += kv_size * sizeof(uint64_t); for (int i = 0; i < kv_size; ++i) { - memcpy(push_data_ptr, value_ptr[i], accessor->GetTableInfo(UPDATE_SIZE)); - push_data_ptr += accessor->GetTableInfo(UPDATE_SIZE); + memcpy(push_data_ptr, value_ptr[i], value_size); + push_data_ptr += value_size; } PsService_Stub rpc_stub(GetSparseChannel(shard_idx)); closure->cntl(shard_idx)->set_request_compress_type( @@ -603,11 +602,9 @@ std::future BrpcPsClient::PullDense(Region *regions, size_t region_num, size_t table_id) { auto timer = std::make_shared("pserver_client_pull_dense"); auto *accessor = GetTableAccessor(table_id); - auto fea_dim = accessor->GetTableInfo(FEA_DIM); - auto select_size = accessor->GetTableInfo(SELECT_SIZE); + auto fea_dim = accessor->GetAccessorInfo().fea_dim; size_t request_call_num = _server_channels.size(); - uint32_t num_per_shard = - DenseDimPerShard(accessor->GetTableInfo(FEA_DIM), request_call_num); + uint32_t num_per_shard = DenseDimPerShard(fea_dim, request_call_num); // callback 将各shard结果,顺序填入region DownpourBrpcClosure *closure = new DownpourBrpcClosure( request_call_num, [request_call_num, num_per_shard, regions, region_num, @@ -617,7 +614,7 @@ std::future BrpcPsClient::PullDense(Region *regions, size_t region_num, size_t region_data_idx = 0; // 当前填充的region内data偏移 auto *closure = reinterpret_cast(done); size_t shard_data_size = - num_per_shard * accessor->GetTableInfo(SELECT_SIZE); + num_per_shard * accessor->GetAccessorInfo().select_size; for (size_t i = 0; i < request_call_num; ++i) { if (closure->check_response(i, PS_PULL_DENSE_TABLE) != 0) { ret = -1; @@ -681,12 +678,13 @@ std::future BrpcPsClient::PushDenseParam(const Region *regions, size_t region_num, size_t table_id) { auto *accessor = GetTableAccessor(table_id); + auto accessor_info = accessor->GetAccessorInfo(); size_t request_call_num = _server_channels.size(); // 1.拆分Region数据到shard中,后续多shard并行拷贝数据 std::vector> regions_partition(request_call_num); uint32_t num_per_shard = - DenseDimPerShard(accessor->GetTableInfo(FEA_DIM), request_call_num); - size_t shard_data_size = num_per_shard * accessor->GetTableInfo(UPDATE_SIZE); + DenseDimPerShard(accessor_info.fea_dim, request_call_num); + size_t shard_data_size = num_per_shard * accessor_info.update_size; size_t current_region_idx = 0; size_t current_region_data_idx = 0; for (size_t i = 0; i < request_call_num; ++i) { @@ -793,7 +791,7 @@ std::future BrpcPsClient::PushSparseRawGradient( auto value_ptr = value_ptrs[shard_idx]; size_t kv_size = kvs.size(); - uint32_t value_size = accessor->GetTableInfo(UPDATE_SIZE); + uint32_t value_size = accessor->GetAccessorInfo().update_size; // 发送RPC请求 auto *push_request = closure->request(shard_idx); @@ -802,15 +800,14 @@ std::future BrpcPsClient::PushSparseRawGradient( push_request->set_client_id(_client_id); push_request->add_params((char *)&kv_size, 
sizeof(uint32_t)); // NOLINT auto *push_data = push_request->mutable_data(); - push_data->resize(kv_size * - (sizeof(uint64_t) + accessor->GetTableInfo(UPDATE_SIZE))); + push_data->resize(kv_size * (sizeof(uint64_t) + value_size)); char *push_data_ptr = const_cast(push_data->data()); memcpy(push_data_ptr, kvs.data(), kv_size * sizeof(uint64_t)); push_data_ptr += kv_size * sizeof(uint64_t); for (int i = 0; i < kv_size; ++i) { - memcpy(push_data_ptr, value_ptr[i], accessor->GetTableInfo(UPDATE_SIZE)); - push_data_ptr += accessor->GetTableInfo(UPDATE_SIZE); + memcpy(push_data_ptr, value_ptr[i], value_size); + push_data_ptr += value_size; } PsService_Stub rpc_stub(GetSparseChannel(shard_idx)); closure->cntl(shard_idx)->set_request_compress_type( @@ -831,7 +828,7 @@ std::future BrpcPsClient::PushDenseRawGradient( std::future fut = promise->get_future(); auto *accessor = GetTableAccessor(table_id); uint32_t num_per_shard = - DenseDimPerShard(accessor->GetTableInfo(FEA_DIM), request_call_num); + DenseDimPerShard(accessor->GetAccessorInfo().fea_dim, request_call_num); for (size_t i = 0; i < request_call_num; ++i) { closure->request(i)->set_cmd_id(PS_PUSH_DENSE_TABLE); closure->request(i)->set_table_id(table_id); @@ -910,7 +907,7 @@ std::future BrpcPsClient::PullSparse(float **select_values, auto *accessor = GetTableAccessor(table_id); - size_t value_size = accessor->GetTableInfo(SELECT_SIZE); + size_t value_size = accessor->GetAccessorInfo().select_size; DownpourBrpcClosure *closure = new DownpourBrpcClosure( request_call_num, [shard_sorted_kvs, value_size](void *done) { @@ -1023,8 +1020,7 @@ std::future BrpcPsClient::PullSparseParam(float **select_values, } auto *accessor = GetTableAccessor(table_id); - size_t value_size = accessor->GetTableInfo(SELECT_SIZE); - + size_t value_size = accessor->GetAccessorInfo().select_size; DownpourBrpcClosure *closure = new DownpourBrpcClosure( request_call_num, [shard_sorted_kvs, value_size](void *done) { int ret = 0; @@ -1147,7 +1143,7 @@ std::future BrpcPsClient::PushSparseRawGradientPartial( size_t table_id, const uint64_t *keys, const float **update_values, uint32_t num, void *done, int pserver_idx) { auto *accessor = GetTableAccessor(table_id); - size_t value_size = accessor->GetTableInfo(UPDATE_SIZE); + size_t value_size = accessor->GetAccessorInfo().update_size; DownpourBrpcClosure *closure = reinterpret_cast(done); auto promise = std::make_shared>(); closure->add_promise(promise); @@ -1307,7 +1303,7 @@ std::future BrpcPsClient::PushSparse(size_t table_id, shard_kv_data.kv_num = 0; continue; } - uint32_t value_size = accessor->GetTableInfo(UPDATE_SIZE); + uint32_t value_size = accessor->GetAccessorInfo().update_size; for (size_t kv_idx = 0; kv_idx < sorted_kv_size; ++kv_idx) { shard_kv_data.key_list[kv_idx] = sorted_kv_list[kv_idx].first; shard_kv_data.value_list[kv_idx].assign( @@ -1453,7 +1449,7 @@ void BrpcPsClient::PushSparseTaskConsume() { void sparse_local_merge(ValueAccessor *accessor, float *merge_data, const float *another_data) { - size_t col_num = accessor->GetTableInfo(UPDATE_SIZE) / sizeof(float); + size_t col_num = accessor->GetAccessorInfo().update_dim; float *merge_data_shell[col_num]; const float *another_data_shell[col_num]; for (int i = 0; i < col_num; ++i) { @@ -1469,7 +1465,7 @@ int BrpcPsClient::PushSparseAsyncShardMerge( ValueAccessor *accessor) { size_t merged_kv_count = 0; uint64_t min_key = UINT64_MAX; - uint32_t value_size = accessor->GetTableInfo(UPDATE_SIZE); + uint32_t value_size = accessor->GetAccessorInfo().update_size; 
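// Editor's note: the refactor running through this client file is the same at
// every call site: per-field lookups such as accessor->GetTableInfo(UPDATE_SIZE)
// become a single accessor->GetAccessorInfo() call returning the plain
// AccessorInfo struct (see accessor.h further below). A minimal sketch of the
// new call shape, with illustrative variable names only:
//
//   AccessorInfo info = accessor->GetAccessorInfo();
//   size_t bytes_per_kv = sizeof(uint64_t) + info.update_size;
//   push_data->resize(kv_num * bytes_per_kv);
//
// Fetching the struct once and reusing the cached fields replaces repeated
// virtual calls keyed by the old InfoKey enum.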
thread_local std::vector> sorted_kv_list; sorted_kv_list.clear(); @@ -1575,9 +1571,8 @@ int BrpcPsClient::PushSparseAsyncShardPush( push_request->add_params(reinterpret_cast(&merged_kv_count), sizeof(uint32_t)); // NOLINT auto *push_data = push_request->mutable_data(); - int update_size = accessor->GetTableInfo(UPDATE_SIZE); - push_data->resize(merged_kv_count * - (sizeof(uint64_t) + accessor->GetTableInfo(UPDATE_SIZE))); + int update_size = accessor->GetAccessorInfo().update_size; + push_data->resize(merged_kv_count * (sizeof(uint64_t) + update_size)); char *push_data_ptr = const_cast(push_data->data()); memcpy(push_data_ptr, merged_key_list.data(), merged_kv_count * sizeof(uint64_t)); @@ -1586,8 +1581,8 @@ int BrpcPsClient::PushSparseAsyncShardPush( const char *task_data_ptr = merged_value_list[i].data(); memcpy(push_data_ptr, (float *)(task_data_ptr), // NOLINT - accessor->GetTableInfo(UPDATE_SIZE)); - push_data_ptr += accessor->GetTableInfo(UPDATE_SIZE); + update_size); + push_data_ptr += update_size; } PsService_Stub rpc_stub(GetSparseChannel(shard_idx)); closure->cntl(shard_idx)->set_request_compress_type( @@ -1602,8 +1597,8 @@ std::future BrpcPsClient::PushDense(const Region *regions, size_t region_num, size_t table_id) { auto *accessor = GetTableAccessor(table_id); - int fea_dim = accessor->GetTableInfo(FEA_DIM); - int update_dim = accessor->GetTableInfo(UPDATE_DIM); + int fea_dim = accessor->GetAccessorInfo().fea_dim; + int update_dim = accessor->GetAccessorInfo().update_dim; auto push_timer = std::make_shared("pserver_client_push_dense"); auto parse_timer = std::make_shared("pserver_client_push_dense_parse"); @@ -1621,13 +1616,9 @@ std::future BrpcPsClient::PushDense(const Region *regions, auto dense_data = std::make_shared>(); auto async_task = new DenseAsyncTask(dense_data, table_id, push_timer); size_t request_call_num = _server_channels.size(); - - uint32_t num_per_shard = - DenseDimPerShard(accessor->GetTableInfo(FEA_DIM), request_call_num); - + uint32_t num_per_shard = DenseDimPerShard(fea_dim, request_call_num); // 将region数据拷贝到转置矩阵中 - async_task->data()->resize(num_per_shard * request_call_num * - accessor->GetTableInfo(UPDATE_DIM)); + async_task->data()->resize(num_per_shard * request_call_num * update_dim); float *data = async_task->data()->data(); size_t data_size = async_task->data()->size(); uint32_t pos = 0; @@ -1757,7 +1748,7 @@ void BrpcPsClient::PushDenseRawGradient(std::shared_ptr &task, auto timer = std::make_shared("pserver_client_push_dense_rpc"); closure->add_timer(timer); uint32_t num_per_shard = - DenseDimPerShard(accessor->GetTableInfo(FEA_DIM), request_call_num); + DenseDimPerShard(accessor->GetAccessorInfo().fea_dim, request_call_num); auto send_timer = std::make_shared("pserver_client_push_dense_send"); for (size_t i = 0; i < request_call_num; ++i) { diff --git a/paddle/fluid/distributed/ps/service/brpc_ps_server.cc b/paddle/fluid/distributed/ps/service/brpc_ps_server.cc index 1d88d88ebcf14..a1690cbb9353b 100644 --- a/paddle/fluid/distributed/ps/service/brpc_ps_server.cc +++ b/paddle/fluid/distributed/ps/service/brpc_ps_server.cc @@ -205,7 +205,7 @@ int32_t BrpcPsService::PullDense(Table *table, const PsRequestMessage &request, } auto res_data = butil::get_object>(); - res_data->resize(num * table->ValueAccesor()->GetTableInfo(SELECT_SIZE) / + res_data->resize(num * table->ValueAccesor()->GetAccessorInfo().select_size / sizeof(float)); TableContext table_context; @@ -384,7 +384,7 @@ int32_t BrpcPsService::PullSparse(Table *table, const PsRequestMessage 
&request, CostTimer timer("pserver_server_pull_sparse"); uint32_t num = *(uint32_t *)(request.params(0).c_str()); - auto dim = table->ValueAccesor()->GetTableInfo(SELECT_DIM); + auto dim = table->ValueAccesor()->GetAccessorInfo().select_dim; thread_local std::string req_buffer; req_buffer.reserve(req_buffer_size); diff --git a/paddle/fluid/distributed/ps/service/ps_local_client.cc b/paddle/fluid/distributed/ps/service/ps_local_client.cc index bb8ba223d828e..3e93f861d4e0e 100644 --- a/paddle/fluid/distributed/ps/service/ps_local_client.cc +++ b/paddle/fluid/distributed/ps/service/ps_local_client.cc @@ -99,7 +99,8 @@ ::std::future PsLocalClient::PullDense(Region* regions, auto* accessor = GetTableAccessor(table_id); auto* table_ptr = GetTable(table_id); - uint32_t num_per_shard = DenseDimPerShard(accessor->GetTableInfo(FEA_DIM), 1); + uint32_t num_per_shard = + DenseDimPerShard(accessor->GetAccessorInfo().fea_dim, 1); std::vector region_buffer; region_buffer.resize(num_per_shard); @@ -145,8 +146,8 @@ ::std::future PsLocalClient::PushDenseParam(const Region* regions, auto* table_ptr = GetTable(table_id); std::vector region_buffer; - region_buffer.resize(DenseDimPerShard(accessor->GetTableInfo(FEA_DIM), 1), 0); - + region_buffer.resize(DenseDimPerShard(accessor->GetAccessorInfo().fea_dim, 1), + 0); for (size_t i = 0, offset = 0; i < region_num; ++i) { uint32_t data_num = regions[i].size / sizeof(float); memcpy(region_buffer.data() + offset, regions[i].data, regions[i].size); @@ -179,8 +180,8 @@ ::std::future PsLocalClient::PushDense(const Region* regions, auto* table_ptr = GetTable(table_id); std::vector region_buffer; - region_buffer.resize(DenseDimPerShard(accessor->GetTableInfo(FEA_DIM), 1)); - + region_buffer.resize( + DenseDimPerShard(accessor->GetAccessorInfo().fea_dim, 1)); size_t data_size = region_buffer.size(); for (size_t i = 0, offset = 0; i < region_num; ++i) { uint32_t data_num = regions[i].size / sizeof(float); diff --git a/paddle/fluid/distributed/ps/table/accessor.h b/paddle/fluid/distributed/ps/table/accessor.h index efc1e604dc9d0..024af327a33af 100644 --- a/paddle/fluid/distributed/ps/table/accessor.h +++ b/paddle/fluid/distributed/ps/table/accessor.h @@ -46,27 +46,24 @@ struct DataConverter { }; struct AccessorInfo { + // value dimension size_t dim; + // size of each value dimension size_t size; - size_t select_size; + // pull value dimension size_t select_dim; - size_t update_size; + // total size across all pull value dimensions + size_t select_size; + // push value dimension size_t update_dim; + // size of each push value dimension + size_t update_size; + // total size of the dynamic-length mf part of the value; takes effect for sparse size_t mf_size; + // total value dimension; takes effect for dense size_t fea_dim; }; -enum InfoKey { - DIM = 0, - SIZE = 1, - SELECT_SIZE = 2, - SELECT_DIM = 3, - UPDATE_SIZE = 4, - UPDATE_DIM = 5, - MF_SIZE = 6, - FEA_DIM = 7 -}; - class ValueAccessor { public: ValueAccessor() {} @@ -90,8 +87,7 @@ class ValueAccessor { } virtual int Initialize() = 0; - virtual void SetTableInfo(AccessorInfo& info) = 0; - virtual size_t GetTableInfo(InfoKey key) = 0; + virtual AccessorInfo GetAccessorInfo() { return _accessor_info; } virtual bool NeedExtendMF(float* value) { return false; } virtual bool HasMF(size_t size) { return false; } diff --git a/paddle/fluid/distributed/ps/table/common_dense_table.cc b/paddle/fluid/distributed/ps/table/common_dense_table.cc index f0cb586e45190..4242b65dea023 100644 --- a/paddle/fluid/distributed/ps/table/common_dense_table.cc +++ b/paddle/fluid/distributed/ps/table/common_dense_table.cc @@ -220,7 +220,8 @@ int32_t CommonDenseTable::Load(const std::string& path,
} size_t dim_num_per_file = _config.accessor().fea_dim() / file_list.size() + 1; // param_dim_ in last node != _config.accesor().fea_dim() / _shard_num + 1 - size_t dim_num_per_shard = _table_info.fea_dim / _shard_num + 1; + size_t dim_num_per_shard = + _value_accesor->GetAccessorInfo().fea_dim / _shard_num + 1; size_t start_dim_idx = dim_num_per_shard * _shard_idx; size_t start_file_idx = start_dim_idx / dim_num_per_file; size_t end_file_idx = (start_dim_idx + param_dim_) / dim_num_per_file; diff --git a/paddle/fluid/distributed/ps/table/ctr_accessor.cc b/paddle/fluid/distributed/ps/table/ctr_accessor.cc index 8380177963ed9..2eda47ccaa505 100644 --- a/paddle/fluid/distributed/ps/table/ctr_accessor.cc +++ b/paddle/fluid/distributed/ps/table/ctr_accessor.cc @@ -23,87 +23,35 @@ namespace distributed { int CtrCommonAccessor::Initialize() { auto name = _config.embed_sgd_param().name(); _embed_sgd_rule = CREATE_PSCORE_CLASS(SparseValueSGDRule, name); - _embed_sgd_rule->load_config(_config.embed_sgd_param(), 1); + _embed_sgd_rule->LoadConfig(_config.embed_sgd_param(), 1); name = _config.embedx_sgd_param().name(); _embedx_sgd_rule = CREATE_PSCORE_CLASS(SparseValueSGDRule, name); - _embedx_sgd_rule->load_config(_config.embedx_sgd_param(), - _config.embedx_dim()); + _embedx_sgd_rule->LoadConfig(_config.embedx_sgd_param(), + _config.embedx_dim()); - common_feature_value.embed_sgd_dim = _embed_sgd_rule->dim(); + common_feature_value.embed_sgd_dim = _embed_sgd_rule->Dim(); common_feature_value.embedx_dim = _config.embedx_dim(); - common_feature_value.embedx_sgd_dim = _embedx_sgd_rule->dim(); + common_feature_value.embedx_sgd_dim = _embedx_sgd_rule->Dim(); _show_click_decay_rate = _config.ctr_accessor_param().show_click_decay_rate(); + InitAccessorInfo(); return 0; } -void CtrCommonAccessor::SetTableInfo(AccessorInfo& info) { - info.dim = Dim(); - info.size = Size(); - info.select_dim = SelectDim(); - info.select_size = SelectSize(); - info.update_dim = UpdateDim(); - info.update_size = UpdateSize(); - info.mf_size = MFSize(); -} - -size_t CtrCommonAccessor::GetTableInfo(InfoKey key) { - switch (key) { - case DIM: - return Dim(); - case SIZE: - return Size(); - case SELECT_DIM: - return SelectDim(); - case SELECT_SIZE: - return SelectSize(); - case UPDATE_DIM: - return UpdateDim(); - case UPDATE_SIZE: - return UpdateSize(); - case MF_SIZE: - return MFSize(); - default: - return 0; - } - return 0; -} - -size_t CtrCommonAccessor::Dim() { return common_feature_value.Dim(); } - -size_t CtrCommonAccessor::DimSize(size_t dim) { - auto embedx_dim = _config.embedx_dim(); - return common_feature_value.DimSize(dim, embedx_dim); -} +void CtrCommonAccessor::InitAccessorInfo() { + _accessor_info.dim = common_feature_value.Dim(); + _accessor_info.size = common_feature_value.Size(); -size_t CtrCommonAccessor::Size() { return common_feature_value.Size(); } - -size_t CtrCommonAccessor::MFSize() { - return (_config.embedx_dim() + common_feature_value.embedx_sgd_dim) * - sizeof(float); // embedx embedx_g2sum -} - -// pull value -size_t CtrCommonAccessor::SelectDim() { - auto embedx_dim = _config.embedx_dim(); - return 3 + embedx_dim; -} - -size_t CtrCommonAccessor::SelectDimSize(size_t dim) { return sizeof(float); } - -size_t CtrCommonAccessor::SelectSize() { return SelectDim() * sizeof(float); } - -// push value -size_t CtrCommonAccessor::UpdateDim() { auto embedx_dim = _config.embedx_dim(); - return 4 + embedx_dim; + _accessor_info.select_dim = 3 + embedx_dim; + _accessor_info.select_size = _accessor_info.select_dim * 
sizeof(float); + _accessor_info.update_dim = 4 + embedx_dim; + _accessor_info.update_size = _accessor_info.update_dim * sizeof(float); + _accessor_info.mf_size = + (embedx_dim + common_feature_value.embedx_sgd_dim) * sizeof(float); } -size_t CtrCommonAccessor::UpdateDimSize(size_t dim) { return sizeof(float); } - -size_t CtrCommonAccessor::UpdateSize() { return UpdateDim() * sizeof(float); } - bool CtrCommonAccessor::Shrink(float* value) { auto base_threshold = _config.ctr_accessor_param().base_threshold(); auto delta_threshold = _config.ctr_accessor_param().delta_threshold(); @@ -116,9 +64,9 @@ bool CtrCommonAccessor::Shrink(float* value) { common_feature_value.Click(value) *= _show_click_decay_rate; // shrink after - auto score = show_click_score(common_feature_value.Show(value), - common_feature_value.Click(value)); - auto unseen_days = common_feature_value.unseen_days(value); + auto score = ShowClickScore(common_feature_value.Show(value), + common_feature_value.Click(value)); + auto unseen_days = common_feature_value.UnseenDays(value); if (score < delete_threshold || unseen_days > delete_after_unseen_days) { return true; } @@ -141,14 +89,13 @@ bool CtrCommonAccessor::Save(float* value, int param) { case 1: // save xbox base case 2: { - if (show_click_score(common_feature_value.Show(value), - common_feature_value.Click(value)) >= - base_threshold && - common_feature_value.delta_score(value) >= delta_threshold && - common_feature_value.unseen_days(value) <= delta_keep_days) { + if (ShowClickScore(common_feature_value.Show(value), + common_feature_value.Click(value)) >= base_threshold && + common_feature_value.DeltaScore(value) >= delta_threshold && + common_feature_value.UnseenDays(value) <= delta_keep_days) { // do this after save, because it must not be modified when retry if (param == 2) { - common_feature_value.delta_score(value) = 0; + common_feature_value.DeltaScore(value) = 0; } return true; } else { @@ -158,7 +105,7 @@ bool CtrCommonAccessor::Save(float* value, int param) { // already decayed in shrink case 3: { // do this after save, because it must not be modified when retry - // common_feature_value.unseen_days(value)++; + // common_feature_value.UnseenDays(value)++; return true; } // save revert batch_model @@ -179,17 +126,16 @@ void CtrCommonAccessor::UpdateStatAfterSave(float* value, int param) { } switch (param) { case 1: { - if (show_click_score(common_feature_value.Show(value), - common_feature_value.Click(value)) >= - base_threshold && - common_feature_value.delta_score(value) >= delta_threshold && - common_feature_value.unseen_days(value) <= delta_keep_days) { - common_feature_value.delta_score(value) = 0; + if (ShowClickScore(common_feature_value.Show(value), + common_feature_value.Click(value)) >= base_threshold && + common_feature_value.DeltaScore(value) >= delta_threshold && + common_feature_value.UnseenDays(value) <= delta_keep_days) { + common_feature_value.DeltaScore(value) = 0; } } return; case 3: { - common_feature_value.unseen_days(value)++; + common_feature_value.UnseenDays(value)++; } return; default: @@ -201,17 +147,16 @@ int32_t CtrCommonAccessor::Create(float** values, size_t num) { auto embedx_dim = _config.embedx_dim(); for (size_t value_item = 0; value_item < num; ++value_item) { float* value = values[value_item]; - value[common_feature_value.unseen_days_index()] = 0; - value[common_feature_value.delta_score_index()] = 0; + value[common_feature_value.UnseenDaysIndex()] = 0; + value[common_feature_value.DeltaScoreIndex()] = 0; 
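// Editor's note on InitAccessorInfo() above: every size is plain arithmetic
// over embedx_dim. A worked example, assuming embedx_dim = 8 purely for
// illustration (the real value comes from the table config):
//   select_dim = 3 + 8 = 11 floats -> select_size = 44 bytes
//     (show, click, embed_w, plus 8 embedx_w)
//   update_dim = 4 + 8 = 12 floats -> update_size = 48 bytes
//     (slot, show, click, embed_g, plus 8 embedx_g)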
value[common_feature_value.ShowIndex()] = 0; value[common_feature_value.ClickIndex()] = 0; value[common_feature_value.SlotIndex()] = -1; - _embed_sgd_rule->init_value( - value + common_feature_value.Embed_W_Index(), - value + common_feature_value.embed_g2sum_index()); - _embedx_sgd_rule->init_value( - value + common_feature_value.Embedx_W_Index(), - value + common_feature_value.embedx_g2sum_index(), false); + _embed_sgd_rule->InitValue(value + common_feature_value.EmbedWIndex(), + value + common_feature_value.EmbedG2SumIndex()); + _embedx_sgd_rule->InitValue(value + common_feature_value.EmbedxWIndex(), + value + common_feature_value.EmbedxG2SumIndex(), + false); } return 0; } @@ -225,7 +170,7 @@ bool CtrCommonAccessor::NeedExtendMF(float* value) { } bool CtrCommonAccessor::HasMF(size_t size) { - return size > common_feature_value.embedx_g2sum_index(); + return size > common_feature_value.EmbedxG2SumIndex(); } // from CommonFeatureValue to CtrCommonPullValue @@ -239,10 +184,10 @@ int32_t CtrCommonAccessor::Select(float** select_values, const float** values, value[common_feature_value.ShowIndex()]; select_value[CtrCommonPullValue::ClickIndex()] = value[common_feature_value.ClickIndex()]; - select_value[CtrCommonPullValue::Embed_W_Index()] = - value[common_feature_value.Embed_W_Index()]; - memcpy(select_value + CtrCommonPullValue::Embedx_W_Index(), - value + common_feature_value.Embedx_W_Index(), + select_value[CtrCommonPullValue::EmbedWIndex()] = + value[common_feature_value.EmbedWIndex()]; + memcpy(select_value + CtrCommonPullValue::EmbedxWIndex(), + value + common_feature_value.EmbedxWIndex(), embedx_dim * sizeof(float)); } return 0; @@ -283,18 +228,18 @@ int32_t CtrCommonAccessor::Update(float** update_values, update_value[common_feature_value.ShowIndex()] += push_show; update_value[common_feature_value.ClickIndex()] += push_click; update_value[common_feature_value.SlotIndex()] = slot; - update_value[common_feature_value.delta_score_index()] += + update_value[common_feature_value.DeltaScoreIndex()] += (push_show - push_click) * _config.ctr_accessor_param().nonclk_coeff() + push_click * _config.ctr_accessor_param().click_coeff(); - update_value[common_feature_value.unseen_days_index()] = 0; - _embed_sgd_rule->update_value( - update_value + common_feature_value.Embed_W_Index(), - update_value + common_feature_value.embed_g2sum_index(), - push_value + CtrCommonPushValue::Embed_G_Index()); - _embedx_sgd_rule->update_value( - update_value + common_feature_value.Embedx_W_Index(), - update_value + common_feature_value.embedx_g2sum_index(), - push_value + CtrCommonPushValue::Embedx_G_Index()); + update_value[common_feature_value.UnseenDaysIndex()] = 0; + _embed_sgd_rule->UpdateValue( + update_value + common_feature_value.EmbedWIndex(), + update_value + common_feature_value.EmbedG2SumIndex(), + push_value + CtrCommonPushValue::EmbedGIndex()); + _embedx_sgd_rule->UpdateValue( + update_value + common_feature_value.EmbedxWIndex(), + update_value + common_feature_value.EmbedxG2SumIndex(), + push_value + CtrCommonPushValue::EmbedxGIndex()); } return 0; } @@ -308,7 +253,7 @@ bool CtrCommonAccessor::CreateValue(int stage, const float* value) { // operation auto show = CtrCommonPushValue::Show(const_cast(value)); auto click = CtrCommonPushValue::Click(const_cast(value)); - auto score = show_click_score(show, click); + auto score = ShowClickScore(show, click); if (score <= 0) { return false; } @@ -322,7 +267,7 @@ bool CtrCommonAccessor::CreateValue(int stage, const float* value) { } } -float 
CtrCommonAccessor::show_click_score(float show, float click) { +float CtrCommonAccessor::ShowClickScore(float show, float click) { auto nonclk_coeff = _config.ctr_accessor_param().nonclk_coeff(); auto click_coeff = _config.ctr_accessor_param().click_coeff(); return (show - click) * nonclk_coeff + click * click_coeff; @@ -334,16 +279,16 @@ std::string CtrCommonAccessor::ParseToString(const float* v, int param) { os.str(""); os << v[0] << " " << v[1] << " " << v[2] << " " << v[3] << " " << v[4] << " " << v[5]; - for (int i = common_feature_value.embed_g2sum_index(); - i < common_feature_value.Embedx_W_Index(); i++) { + for (int i = common_feature_value.EmbedG2SumIndex(); + i < common_feature_value.EmbedxWIndex(); i++) { os << " " << v[i]; } auto show = common_feature_value.Show(const_cast(v)); auto click = common_feature_value.Click(const_cast(v)); - auto score = show_click_score(show, click); + auto score = ShowClickScore(show, click); if (score >= _config.embedx_threshold() && - param > common_feature_value.Embedx_W_Index()) { - for (auto i = common_feature_value.Embedx_W_Index(); + param > common_feature_value.EmbedxWIndex()) { + for (auto i = common_feature_value.EmbedxWIndex(); i < common_feature_value.Dim(); ++i) { os << " " << v[i]; } @@ -354,9 +299,8 @@ std::string CtrCommonAccessor::ParseToString(const float* v, int param) { int CtrCommonAccessor::ParseFromString(const std::string& str, float* value) { int embedx_dim = _config.embedx_dim(); - _embedx_sgd_rule->init_value( - value + common_feature_value.Embedx_W_Index(), - value + common_feature_value.embedx_g2sum_index()); + _embedx_sgd_rule->InitValue(value + common_feature_value.EmbedxWIndex(), + value + common_feature_value.EmbedxG2SumIndex()); auto ret = paddle::string::str_to_float(str.data(), value); CHECK(ret >= 6) << "expect more than 6 real:" << ret; return ret; diff --git a/paddle/fluid/distributed/ps/table/ctr_accessor.h b/paddle/fluid/distributed/ps/table/ctr_accessor.h index 21dfc6a5c1c38..b8895e74d1d09 100644 --- a/paddle/fluid/distributed/ps/table/ctr_accessor.h +++ b/paddle/fluid/distributed/ps/table/ctr_accessor.h @@ -44,24 +44,24 @@ class CtrCommonAccessor : public ValueAccessor { int DimSize(size_t dim, int embedx_dim) { return sizeof(float); } int Size() { return Dim() * sizeof(float); } int SlotIndex() { return 0; } - int unseen_days_index() { return SlotIndex() + 1; } - int delta_score_index() { return unseen_days_index() + 1; } - int ShowIndex() { return delta_score_index() + 1; } + int UnseenDaysIndex() { return SlotIndex() + 1; } + int DeltaScoreIndex() { return UnseenDaysIndex() + 1; } + int ShowIndex() { return DeltaScoreIndex() + 1; } int ClickIndex() { return ShowIndex() + 1; } - int Embed_W_Index() { return ClickIndex() + 1; } - int embed_g2sum_index() { return Embed_W_Index() + 1; } - int Embedx_W_Index() { return embed_g2sum_index() + embed_sgd_dim; } - int embedx_g2sum_index() { return Embedx_W_Index() + embedx_dim; } + int EmbedWIndex() { return ClickIndex() + 1; } + int EmbedG2SumIndex() { return EmbedWIndex() + 1; } + int EmbedxWIndex() { return EmbedG2SumIndex() + embed_sgd_dim; } + int EmbedxG2SumIndex() { return EmbedxWIndex() + embedx_dim; } - float& unseen_days(float* val) { return val[unseen_days_index()]; } - float& delta_score(float* val) { return val[delta_score_index()]; } + float& UnseenDays(float* val) { return val[UnseenDaysIndex()]; } + float& DeltaScore(float* val) { return val[DeltaScoreIndex()]; } float& Show(float* val) { return val[ShowIndex()]; } float& Click(float* val) { 
return val[ClickIndex()]; } float& Slot(float* val) { return val[SlotIndex()]; } - float& EmbedW(float* val) { return val[Embed_W_Index()]; } - float& embed_g2sum(float* val) { return val[embed_g2sum_index()]; } - float& EmbedxW(float* val) { return val[Embedx_W_Index()]; } - float& embedx_g2sum(float* val) { return val[embedx_g2sum_index()]; } + float& EmbedW(float* val) { return val[EmbedWIndex()]; } + float& EmbedG2Sum(float* val) { return val[EmbedG2SumIndex()]; } + float& EmbedxW(float* val) { return val[EmbedxWIndex()]; } + float& EmbedxG2Sum(float* val) { return val[EmbedxG2SumIndex()]; } int embed_sgd_dim; int embedx_dim; @@ -84,10 +84,8 @@ class CtrCommonAccessor : public ValueAccessor { static int SlotIndex() { return 0; } static int ShowIndex() { return CtrCommonPushValue::SlotIndex() + 1; } static int ClickIndex() { return CtrCommonPushValue::ShowIndex() + 1; } - static int Embed_G_Index() { return CtrCommonPushValue::ClickIndex() + 1; } - static int Embedx_G_Index() { - return CtrCommonPushValue::Embed_G_Index() + 1; - } + static int EmbedGIndex() { return CtrCommonPushValue::ClickIndex() + 1; } + static int EmbedxGIndex() { return CtrCommonPushValue::EmbedGIndex() + 1; } static float& Slot(float* val) { return val[CtrCommonPushValue::SlotIndex()]; } @@ -98,10 +96,10 @@ class CtrCommonAccessor : public ValueAccessor { return val[CtrCommonPushValue::ClickIndex()]; } static float& EmbedG(float* val) { - return val[CtrCommonPushValue::Embed_G_Index()]; + return val[CtrCommonPushValue::EmbedGIndex()]; } static float* EmbedxG(float* val) { - return val + CtrCommonPushValue::Embedx_G_Index(); + return val + CtrCommonPushValue::EmbedxGIndex(); } }; @@ -118,8 +116,8 @@ class CtrCommonAccessor : public ValueAccessor { static int Size(int embedx_dim) { return Dim(embedx_dim) * sizeof(float); } static int ShowIndex() { return 0; } static int ClickIndex() { return 1; } - static int Embed_W_Index() { return 2; } - static int Embedx_W_Index() { return 3; } + static int EmbedWIndex() { return 2; } + static int EmbedxWIndex() { return 3; } static float& Show(float* val) { return val[CtrCommonPullValue::ShowIndex()]; } @@ -127,38 +125,17 @@ class CtrCommonAccessor : public ValueAccessor { return val[CtrCommonPullValue::ClickIndex()]; } static float& EmbedW(float* val) { - return val[CtrCommonPullValue::Embed_W_Index()]; + return val[CtrCommonPullValue::EmbedWIndex()]; } static float* EmbedxW(float* val) { - return val + CtrCommonPullValue::Embedx_W_Index(); + return val + CtrCommonPullValue::EmbedxWIndex(); } }; CtrCommonAccessor() {} - virtual int Initialize(); virtual ~CtrCommonAccessor() {} - - virtual void SetTableInfo(AccessorInfo& info); - virtual size_t GetTableInfo(InfoKey key); - // value dimension - size_t Dim(); - // size of each value dimension - size_t DimSize(size_t dim); - // total size across all value dimensions - size_t Size(); - // total size of the dynamic-length mf part of the value; takes effect for sparse - size_t MFSize(); - // pull value dimension - size_t SelectDim(); - // size of each pull value dimension - size_t SelectDimSize(size_t dim); - // total size across all pull value dimensions - size_t SelectSize(); - // push value dimension - size_t UpdateDim(); - // size of each push value dimension - size_t UpdateDimSize(size_t dim); - // total size across all push value dimensions - size_t UpdateSize(); + virtual int Initialize(); + // initialize AccessorInfo + virtual void InitAccessorInfo(); // decide whether this value should be shrunk virtual bool Shrink(float* value); // decide whether this value should be saved to ssd @@ -202,7 +179,7 @@ class CtrCommonAccessor : public ValueAccessor { } private: - // float ShowClickScore(float show, float click);
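// Editor's note: the renamed index helpers above chain into one flat float
// array per feature. Assuming embed_sgd_dim = 1 and embedx_dim = 8
// (illustrative values only; both really come from the SGD rule and table
// config), the CtrCommonFeatureValue slots work out to:
//   [0] slot   [1] unseen_days   [2] delta_score   [3] show   [4] click
//   [5] embed_w   [6] embed_g2sum   [7..14] embedx_w   [15..] embedx_g2sum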
// SparseValueSGDRule* _embed_sgd_rule; // SparseValueSGDRule* _embedx_sgd_rule; @@ -213,7 +190,7 @@ class CtrCommonAccessor : public ValueAccessor { public: // TODO(zhaocaibei123): it should be private, but we make it public // for unit test CtrCommonFeatureValue common_feature_value; - float show_click_score(float show, float click); + float ShowClickScore(float show, float click); SparseValueSGDRule* _embed_sgd_rule; SparseValueSGDRule* _embedx_sgd_rule; }; diff --git a/paddle/fluid/distributed/ps/table/ctr_double_accessor.cc b/paddle/fluid/distributed/ps/table/ctr_double_accessor.cc index ed21a6dac317e..740b03a84e461 100644 --- a/paddle/fluid/distributed/ps/table/ctr_double_accessor.cc +++ b/paddle/fluid/distributed/ps/table/ctr_double_accessor.cc @@ -23,89 +23,32 @@ namespace distributed { int DownpourCtrDoubleAccessor::Initialize() { auto name = _config.embed_sgd_param().name(); _embed_sgd_rule = CREATE_PSCORE_CLASS(SparseValueSGDRule, name); - _embed_sgd_rule->load_config(_config.embed_sgd_param(), 1); + _embed_sgd_rule->LoadConfig(_config.embed_sgd_param(), 1); name = _config.embedx_sgd_param().name(); _embedx_sgd_rule = CREATE_PSCORE_CLASS(SparseValueSGDRule, name); - _embedx_sgd_rule->load_config(_config.embedx_sgd_param(), - _config.embedx_dim()); + _embedx_sgd_rule->LoadConfig(_config.embedx_sgd_param(), + _config.embedx_dim()); _show_click_decay_rate = _config.ctr_accessor_param().show_click_decay_rate(); _ssd_unseenday_threshold = _config.ctr_accessor_param().ssd_unseenday_threshold(); + InitAccessorInfo(); return 0; } -void DownpourCtrDoubleAccessor::SetTableInfo(AccessorInfo& info) { - info.dim = Dim(); - info.size = Size(); - info.select_dim = SelectDim(); - info.select_size = SelectSize(); - info.update_dim = UpdateDim(); - info.update_size = UpdateSize(); - info.mf_size = MFSize(); -} - -size_t DownpourCtrDoubleAccessor::GetTableInfo(InfoKey key) { - switch (key) { - case DIM: - return Dim(); - case SIZE: - return Size(); - case SELECT_DIM: - return SelectDim(); - case SELECT_SIZE: - return SelectSize(); - case UPDATE_DIM: - return UpdateDim(); - case UPDATE_SIZE: - return UpdateSize(); - case MF_SIZE: - return MFSize(); - default: - return 0; - } - return 0; -} - -size_t DownpourCtrDoubleAccessor::Dim() { - auto embedx_dim = _config.embedx_dim(); - return DownpourCtrDoubleFeatureValue::Dim(embedx_dim); -} -size_t DownpourCtrDoubleAccessor::DimSize(size_t dim) { - auto embedx_dim = _config.embedx_dim(); - return DownpourCtrDoubleFeatureValue::DimSize(dim, embedx_dim); -} -size_t DownpourCtrDoubleAccessor::Size() { +void DownpourCtrDoubleAccessor::InitAccessorInfo() { auto embedx_dim = _config.embedx_dim(); - return DownpourCtrDoubleFeatureValue::Size(embedx_dim); -} -size_t DownpourCtrDoubleAccessor::MFSize() { - return (_config.embedx_dim() + 1) * sizeof(float); // embedx embedx_g2sum -} -// pull value -size_t DownpourCtrDoubleAccessor::SelectDim() { - auto embedx_dim = _config.embedx_dim(); - return 3 + embedx_dim; -} -size_t DownpourCtrDoubleAccessor::SelectDimSize(size_t dim) { - return sizeof(float); -} -size_t DownpourCtrDoubleAccessor::SelectSize() { - return SelectDim() * sizeof(float); -} -// push value -size_t DownpourCtrDoubleAccessor::UpdateDim() { - auto embedx_dim = _config.embedx_dim(); - return 4 + embedx_dim; -} -size_t DownpourCtrDoubleAccessor::UpdateDimSize(size_t dim) { - return sizeof(float); -} -size_t DownpourCtrDoubleAccessor::UpdateSize() { - return UpdateDim() * sizeof(float); + _accessor_info.dim = 
DownpourCtrDoubleFeatureValue::Dim(embedx_dim); + _accessor_info.size = DownpourCtrDoubleFeatureValue::Size(embedx_dim); + _accessor_info.select_dim = 3 + embedx_dim; + _accessor_info.select_size = _accessor_info.select_dim * sizeof(float); + _accessor_info.update_dim = 4 + embedx_dim; + _accessor_info.update_size = _accessor_info.update_dim * sizeof(float); + _accessor_info.mf_size = (embedx_dim + 1) * sizeof(float); } + bool DownpourCtrDoubleAccessor::Shrink(float* value) { // auto base_threshold = _config.ctr_accessor_param().base_threshold(); // auto delta_threshold = _config.ctr_accessor_param().delta_threshold(); @@ -119,16 +62,16 @@ bool DownpourCtrDoubleAccessor::Shrink(float* value) { DownpourCtrDoubleFeatureValue::Show(value) *= _show_click_decay_rate; DownpourCtrDoubleFeatureValue::Click(value) *= _show_click_decay_rate; // shrink after - auto score = show_click_score(DownpourCtrDoubleFeatureValue::Show(value), - DownpourCtrDoubleFeatureValue::Click(value)); - auto unseen_days = DownpourCtrDoubleFeatureValue::unseen_days(value); + auto score = ShowClickScore(DownpourCtrDoubleFeatureValue::Show(value), + DownpourCtrDoubleFeatureValue::Click(value)); + auto unseen_days = DownpourCtrDoubleFeatureValue::UnseenDays(value); if (score < delete_threshold || unseen_days > delete_after_unseen_days) { return true; } return false; } bool DownpourCtrDoubleAccessor::save_ssd(float* value) { - if (DownpourCtrDoubleFeatureValue::unseen_days(value) > + if (DownpourCtrDoubleFeatureValue::UnseenDays(value) > _ssd_unseenday_threshold) { return true; } @@ -138,9 +81,9 @@ bool DownpourCtrDoubleAccessor::save_ssd(float* value) { // float* value, int param, double global_cache_threshold) { // auto base_threshold = _config.ctr_accessor_param().base_threshold(); // auto delta_keep_days = _config.ctr_accessor_param().delta_keep_days(); -// if (show_click_score(DownpourCtrDoubleFeatureValue::Show(value), +// if (ShowClickScore(DownpourCtrDoubleFeatureValue::Show(value), // DownpourCtrDoubleFeatureValue::Click(value)) >= base_threshold -// && DownpourCtrDoubleFeatureValue::unseen_days(value) <= +// && DownpourCtrDoubleFeatureValue::UnseenDays(value) <= // delta_keep_days) { // return DownpourCtrDoubleFeatureValue::Show(value) > // global_cache_threshold; @@ -166,16 +109,14 @@ bool DownpourCtrDoubleAccessor::Save(float* value, int param) { case 1: // save xbox base case 2: { - if (show_click_score(DownpourCtrDoubleFeatureValue::Show(value), - DownpourCtrDoubleFeatureValue::Click(value)) >= + if (ShowClickScore(DownpourCtrDoubleFeatureValue::Show(value), + DownpourCtrDoubleFeatureValue::Click(value)) >= base_threshold && - DownpourCtrDoubleFeatureValue::delta_score(value) >= - delta_threshold && - DownpourCtrDoubleFeatureValue::unseen_days(value) <= - delta_keep_days) { + DownpourCtrDoubleFeatureValue::DeltaScore(value) >= delta_threshold && + DownpourCtrDoubleFeatureValue::UnseenDays(value) <= delta_keep_days) { // do this after save, because it must not be modified when retry if (param == 2) { - DownpourCtrDoubleFeatureValue::delta_score(value) = 0; + DownpourCtrDoubleFeatureValue::DeltaScore(value) = 0; } return true; } else { @@ -187,7 +128,7 @@ bool DownpourCtrDoubleAccessor::Save(float* value, int param) { // DownpourCtrFeatureValue::Show(value) *= _show_click_decay_rate; // DownpourCtrFeatureValue::Click(value) *= _show_click_decay_rate; // do this after save, because it must not be modified when retry - // DownpourCtrDoubleFeatureValue::unseen_days(value)++; + // 
DownpourCtrDoubleFeatureValue::UnseenDays(value)++; return true; } default: @@ -204,19 +145,17 @@ void DownpourCtrDoubleAccessor::UpdateStatAfterSave(float* value, int param) { } switch (param) { case 1: { - if (show_click_score(DownpourCtrDoubleFeatureValue::Show(value), - DownpourCtrDoubleFeatureValue::Click(value)) >= + if (ShowClickScore(DownpourCtrDoubleFeatureValue::Show(value), + DownpourCtrDoubleFeatureValue::Click(value)) >= base_threshold && - DownpourCtrDoubleFeatureValue::delta_score(value) >= - delta_threshold && - DownpourCtrDoubleFeatureValue::unseen_days(value) <= - delta_keep_days) { - DownpourCtrDoubleFeatureValue::delta_score(value) = 0; + DownpourCtrDoubleFeatureValue::DeltaScore(value) >= delta_threshold && + DownpourCtrDoubleFeatureValue::UnseenDays(value) <= delta_keep_days) { + DownpourCtrDoubleFeatureValue::DeltaScore(value) = 0; } } return; case 3: { - DownpourCtrDoubleFeatureValue::unseen_days(value)++; + DownpourCtrDoubleFeatureValue::UnseenDays(value)++; } return; default: @@ -228,17 +167,17 @@ int32_t DownpourCtrDoubleAccessor::Create(float** values, size_t num) { auto embedx_dim = _config.embedx_dim(); for (size_t value_item = 0; value_item < num; ++value_item) { float* value = values[value_item]; - value[DownpourCtrDoubleFeatureValue::unseen_days_index()] = 0; - value[DownpourCtrDoubleFeatureValue::delta_score_index()] = 0; + value[DownpourCtrDoubleFeatureValue::UnseenDaysIndex()] = 0; + value[DownpourCtrDoubleFeatureValue::DeltaScoreIndex()] = 0; *(double*)(value + DownpourCtrDoubleFeatureValue::ShowIndex()) = 0; *(double*)(value + DownpourCtrDoubleFeatureValue::ClickIndex()) = 0; value[DownpourCtrDoubleFeatureValue::SlotIndex()] = -1; - _embed_sgd_rule->init_value( - value + DownpourCtrDoubleFeatureValue::Embed_W_Index(), - value + DownpourCtrDoubleFeatureValue::embed_g2sum_index()); - _embedx_sgd_rule->init_value( - value + DownpourCtrDoubleFeatureValue::Embedx_W_Index(), - value + DownpourCtrDoubleFeatureValue::embedx_g2sum_index(), false); + _embed_sgd_rule->InitValue( + value + DownpourCtrDoubleFeatureValue::EmbedWIndex(), + value + DownpourCtrDoubleFeatureValue::EmbedG2SumIndex()); + _embedx_sgd_rule->InitValue( + value + DownpourCtrDoubleFeatureValue::EmbedxWIndex(), + value + DownpourCtrDoubleFeatureValue::EmbedxG2SumIndex(), false); } return 0; } @@ -264,10 +203,10 @@ int32_t DownpourCtrDoubleAccessor::Select(float** select_values, (float)*(double*)(value + DownpourCtrDoubleFeatureValue::ShowIndex()); select_value[DownpourCtrDoublePullValue::ClickIndex()] = (float)*(double*)(value + DownpourCtrDoubleFeatureValue::ClickIndex()); - select_value[DownpourCtrDoublePullValue::Embed_W_Index()] = - value[DownpourCtrDoubleFeatureValue::Embed_W_Index()]; - memcpy(select_value + DownpourCtrDoublePullValue::Embedx_W_Index(), - value + DownpourCtrDoubleFeatureValue::Embedx_W_Index(), + select_value[DownpourCtrDoublePullValue::EmbedWIndex()] = + value[DownpourCtrDoubleFeatureValue::EmbedWIndex()]; + memcpy(select_value + DownpourCtrDoublePullValue::EmbedxWIndex(), + value + DownpourCtrDoubleFeatureValue::EmbedxWIndex(), embedx_dim * sizeof(float)); } return 0; @@ -316,20 +255,20 @@ int32_t DownpourCtrDoubleAccessor::Update(float** update_values, *(double*)(update_value + DownpourCtrDoubleFeatureValue::ClickIndex()) += (double)push_click; update_value[DownpourCtrDoubleFeatureValue::SlotIndex()] = slot; - update_value[DownpourCtrDoubleFeatureValue::delta_score_index()] += + update_value[DownpourCtrDoubleFeatureValue::DeltaScoreIndex()] += (push_show - push_click) 
* _config.ctr_accessor_param().nonclk_coeff() + push_click * _config.ctr_accessor_param().click_coeff(); //(push_show - push_click) * _config.ctr_accessor_param().nonclk_coeff() + // push_click * _config.ctr_accessor_param().click_coeff(); - update_value[DownpourCtrDoubleFeatureValue::unseen_days_index()] = 0; - _embed_sgd_rule->update_value( - update_value + DownpourCtrDoubleFeatureValue::Embed_W_Index(), - update_value + DownpourCtrDoubleFeatureValue::embed_g2sum_index(), - push_value + DownpourCtrDoublePushValue::Embed_G_Index(), push_show); - _embedx_sgd_rule->update_value( - update_value + DownpourCtrDoubleFeatureValue::Embedx_W_Index(), - update_value + DownpourCtrDoubleFeatureValue::embedx_g2sum_index(), - push_value + DownpourCtrDoublePushValue::Embedx_G_Index(), push_show); + update_value[DownpourCtrDoubleFeatureValue::UnseenDaysIndex()] = 0; + _embed_sgd_rule->UpdateValue( + update_value + DownpourCtrDoubleFeatureValue::EmbedWIndex(), + update_value + DownpourCtrDoubleFeatureValue::EmbedG2SumIndex(), + push_value + DownpourCtrDoublePushValue::EmbedGIndex(), push_show); + _embedx_sgd_rule->UpdateValue( + update_value + DownpourCtrDoubleFeatureValue::EmbedxWIndex(), + update_value + DownpourCtrDoubleFeatureValue::EmbedxG2SumIndex(), + push_value + DownpourCtrDoublePushValue::EmbedxGIndex(), push_show); } return 0; } @@ -341,7 +280,7 @@ bool DownpourCtrDoubleAccessor::CreateValue(int stage, const float* value) { } else if (stage == 1) { auto show = DownpourCtrDoublePushValue::Show(const_cast(value)); auto click = DownpourCtrDoublePushValue::Click(const_cast(value)); - auto score = show_click_score(show, click); + auto score = ShowClickScore(show, click); if (score <= 0) { return false; } @@ -354,7 +293,7 @@ bool DownpourCtrDoubleAccessor::CreateValue(int stage, const float* value) { return true; } } -double DownpourCtrDoubleAccessor::show_click_score(double show, double click) { +double DownpourCtrDoubleAccessor::ShowClickScore(double show, double click) { // auto nonclk_coeff = _config.ctr_accessor_param().nonclk_coeff(); // auto click_coeff = _config.ctr_accessor_param().click_coeff(); auto nonclk_coeff = _config.ctr_accessor_param().nonclk_coeff(); @@ -371,7 +310,7 @@ std::string DownpourCtrDoubleAccessor::ParseToString(const float* v, << v[8]; auto show = DownpourCtrDoubleFeatureValue::Show(const_cast(v)); auto click = DownpourCtrDoubleFeatureValue::Click(const_cast(v)); - auto score = show_click_score(show, click); + auto score = ShowClickScore(show, click); if (score >= _config.embedx_threshold() && param_size > 9) { os << " " << v[9]; for (auto i = 0; i < _config.embedx_dim(); ++i) { @@ -383,19 +322,19 @@ std::string DownpourCtrDoubleAccessor::ParseToString(const float* v, int DownpourCtrDoubleAccessor::ParseFromString(const std::string& str, float* value) { int embedx_dim = _config.embedx_dim(); - float data_buff[Dim() + 2]; + float data_buff[_accessor_info.dim + 2]; float* data_buff_ptr = data_buff; - _embedx_sgd_rule->init_value( - data_buff_ptr + DownpourCtrDoubleFeatureValue::Embedx_W_Index(), - data_buff_ptr + DownpourCtrDoubleFeatureValue::embedx_g2sum_index()); + _embedx_sgd_rule->InitValue( + data_buff_ptr + DownpourCtrDoubleFeatureValue::EmbedxWIndex(), + data_buff_ptr + DownpourCtrDoubleFeatureValue::EmbedxG2SumIndex()); auto str_len = paddle::string::str_to_float(str.data(), data_buff_ptr); CHECK(str_len >= 6) << "expect more than 6 real:" << str_len; int show_index = DownpourCtrDoubleFeatureValue::ShowIndex(); int click_index = 
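// Note on the renamed ShowClickScore above: only the name changes; the score
// itself stays the weighted sum of non-click impressions and clicks. A
// minimal standalone sketch; the coefficient and count values below are
// illustrative assumptions, not values from any real ctr_accessor_param:
#include <iostream>

static double ShowClickScoreSketch(double show, double click,
                                   double nonclk_coeff, double click_coeff) {
  // (show - click) impressions did not convert; click impressions did.
  return (show - click) * nonclk_coeff + click * click_coeff;
}

int main() {
  // e.g. 100 shows, 3 clicks, nonclk_coeff = 0.1, click_coeff = 1.0
  std::cout << ShowClickScoreSketch(100, 3, 0.1, 1.0) << "\n";  // 12.7
  return 0;
}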
DownpourCtrDoubleFeatureValue::ClickIndex(); - int embed_w_index = DownpourCtrDoubleFeatureValue::Embed_W_Index(); + int embed_w_index = DownpourCtrDoubleFeatureValue::EmbedWIndex(); // no slot, embedx - int value_dim = Dim(); - int embedx_g2sum_index = DownpourCtrDoubleFeatureValue::embedx_g2sum_index(); + int value_dim = _accessor_info.dim; + int embedx_g2sum_index = DownpourCtrDoubleFeatureValue::EmbedxG2SumIndex(); value[DownpourCtrDoubleFeatureValue::SlotIndex()] = -1; // other case if (str_len == (value_dim - 1)) { @@ -405,9 +344,8 @@ int DownpourCtrDoubleAccessor::ParseFromString(const std::string& str, *(double*)(value + show_index) = (double)data_buff_ptr[2]; *(double*)(value + click_index) = (double)data_buff_ptr[3]; // copy others - value[DownpourCtrDoubleFeatureValue::Embed_W_Index()] = data_buff_ptr[4]; - value[DownpourCtrDoubleFeatureValue::embed_g2sum_index()] = - data_buff_ptr[5]; + value[DownpourCtrDoubleFeatureValue::EmbedWIndex()] = data_buff_ptr[4]; + value[DownpourCtrDoubleFeatureValue::EmbedG2SumIndex()] = data_buff_ptr[5]; memcpy(value + embedx_g2sum_index, data_buff_ptr + 6, (embedx_dim + 1) * sizeof(float)); } else { diff --git a/paddle/fluid/distributed/ps/table/ctr_double_accessor.h b/paddle/fluid/distributed/ps/table/ctr_double_accessor.h index 29ddcbc86d7c7..3995903463637 100644 --- a/paddle/fluid/distributed/ps/table/ctr_double_accessor.h +++ b/paddle/fluid/distributed/ps/table/ctr_double_accessor.h @@ -43,38 +43,38 @@ class DownpourCtrDoubleAccessor : public ValueAccessor { static int Size(int embedx_dim) { return (Dim(embedx_dim) + 2) * sizeof(float); } - static int unseen_days_index() { return 0; } - static int delta_score_index() { - return DownpourCtrDoubleFeatureValue::unseen_days_index() + 1; + static int UnseenDaysIndex() { return 0; } + static int DeltaScoreIndex() { + return DownpourCtrDoubleFeatureValue::UnseenDaysIndex() + 1; } static int ShowIndex() { - return DownpourCtrDoubleFeatureValue::delta_score_index() + 1; + return DownpourCtrDoubleFeatureValue::DeltaScoreIndex() + 1; } // show is double static int ClickIndex() { return DownpourCtrDoubleFeatureValue::ShowIndex() + 2; } // click is double - static int Embed_W_Index() { + static int EmbedWIndex() { return DownpourCtrDoubleFeatureValue::ClickIndex() + 2; } - static int embed_g2sum_index() { - return DownpourCtrDoubleFeatureValue::Embed_W_Index() + 1; + static int EmbedG2SumIndex() { + return DownpourCtrDoubleFeatureValue::EmbedWIndex() + 1; } static int SlotIndex() { - return DownpourCtrDoubleFeatureValue::embed_g2sum_index() + 1; + return DownpourCtrDoubleFeatureValue::EmbedG2SumIndex() + 1; } - static int embedx_g2sum_index() { + static int EmbedxG2SumIndex() { return DownpourCtrDoubleFeatureValue::SlotIndex() + 1; } - static int Embedx_W_Index() { - return DownpourCtrDoubleFeatureValue::embedx_g2sum_index() + 1; + static int EmbedxWIndex() { + return DownpourCtrDoubleFeatureValue::EmbedxG2SumIndex() + 1; } - static float& unseen_days(float* val) { - return val[DownpourCtrDoubleFeatureValue::unseen_days_index()]; + static float& UnseenDays(float* val) { + return val[DownpourCtrDoubleFeatureValue::UnseenDaysIndex()]; } - static float& delta_score(float* val) { - return val[DownpourCtrDoubleFeatureValue::delta_score_index()]; + static float& DeltaScore(float* val) { + return val[DownpourCtrDoubleFeatureValue::DeltaScoreIndex()]; } static double& Show(float* val) { return ((double*)(val + DownpourCtrDoubleFeatureValue::ShowIndex()))[0]; @@ -86,16 +86,16 @@ class DownpourCtrDoubleAccessor : 
public ValueAccessor { return val[DownpourCtrDoubleFeatureValue::SlotIndex()]; } static float& EmbedW(float* val) { - return val[DownpourCtrDoubleFeatureValue::Embed_W_Index()]; + return val[DownpourCtrDoubleFeatureValue::EmbedWIndex()]; } - static float& embed_g2sum(float* val) { - return val[DownpourCtrDoubleFeatureValue::embed_g2sum_index()]; + static float& EmbedG2Sum(float* val) { + return val[DownpourCtrDoubleFeatureValue::EmbedG2SumIndex()]; } - static float& embedx_g2sum(float* val) { - return val[DownpourCtrDoubleFeatureValue::embedx_g2sum_index()]; + static float& EmbedxG2Sum(float* val) { + return val[DownpourCtrDoubleFeatureValue::EmbedxG2SumIndex()]; } static float* EmbedxW(float* val) { - return (val + DownpourCtrDoubleFeatureValue::Embedx_W_Index()); + return (val + DownpourCtrDoubleFeatureValue::EmbedxWIndex()); } }; struct DownpourCtrDoublePushValue { @@ -116,11 +116,11 @@ class DownpourCtrDoubleAccessor : public ValueAccessor { static int ClickIndex() { return DownpourCtrDoublePushValue::ShowIndex() + 1; } - static int Embed_G_Index() { + static int EmbedGIndex() { return DownpourCtrDoublePushValue::ClickIndex() + 1; } - static int Embedx_G_Index() { - return DownpourCtrDoublePushValue::Embed_G_Index() + 1; + static int EmbedxGIndex() { + return DownpourCtrDoublePushValue::EmbedGIndex() + 1; } static float& Slot(float* val) { return val[DownpourCtrDoublePushValue::SlotIndex()]; @@ -132,10 +132,10 @@ class DownpourCtrDoubleAccessor : public ValueAccessor { return val[DownpourCtrDoublePushValue::ClickIndex()]; } static float& EmbedG(float* val) { - return val[DownpourCtrDoublePushValue::Embed_G_Index()]; + return val[DownpourCtrDoublePushValue::EmbedGIndex()]; } static float* EmbedxG(float* val) { - return val + DownpourCtrDoublePushValue::Embedx_G_Index(); + return val + DownpourCtrDoublePushValue::EmbedxGIndex(); } }; struct DownpourCtrDoublePullValue { @@ -150,8 +150,8 @@ class DownpourCtrDoubleAccessor : public ValueAccessor { static int Size(int embedx_dim) { return Dim(embedx_dim) * sizeof(float); } static int ShowIndex() { return 0; } static int ClickIndex() { return 1; } - static int Embed_W_Index() { return 2; } - static int Embedx_W_Index() { return 3; } + static int EmbedWIndex() { return 2; } + static int EmbedxWIndex() { return 3; } static float& Show(float* val) { return val[DownpourCtrDoublePullValue::ShowIndex()]; } @@ -159,37 +159,17 @@ class DownpourCtrDoubleAccessor : public ValueAccessor { return val[DownpourCtrDoublePullValue::ClickIndex()]; } static float& EmbedW(float* val) { - return val[DownpourCtrDoublePullValue::Embed_W_Index()]; + return val[DownpourCtrDoublePullValue::EmbedWIndex()]; } static float* EmbedxW(float* val) { - return val + DownpourCtrDoublePullValue::Embedx_W_Index(); + return val + DownpourCtrDoublePullValue::EmbedxWIndex(); } }; DownpourCtrDoubleAccessor() {} virtual ~DownpourCtrDoubleAccessor() {} virtual int Initialize(); - virtual void SetTableInfo(AccessorInfo& info); - virtual size_t GetTableInfo(InfoKey key); - // value维度 - size_t Dim(); - // value各个维度的size - size_t DimSize(size_t dim); - // value各维度相加总size - size_t Size(); - // value中mf动态长度部分总size大小, sparse下生效 - size_t MFSize(); - // pull value维度 - size_t SelectDim(); - // pull value各个维度的size - size_t SelectDimSize(size_t dim); - // pull value各维度相加总size - size_t SelectSize(); - // push value维度 - size_t UpdateDim(); - // push value各个维度的size - size_t UpdateDimSize(size_t dim); - // push value各维度相加总size - size_t UpdateSize(); + // 初始化AccessorInfo + virtual void 
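// The CamelCase index helpers above chain off one another, so each field's
// offset in the packed float array is defined in exactly one place, and
// show/click, being doubles, advance the chain by two float slots. A sketch
// of the idiom with a simplified, hypothetical layout (not the real one);
// the double access reuses the same type-punning cast as the production
// Show()/Click() helpers, with the same strict-aliasing caveat:
#include <cstdio>

struct PackedValueSketch {
  static int ShowIndex() { return 0; }                 // double: 2 slots
  static int ClickIndex() { return ShowIndex() + 2; }  // double: 2 slots
  static int SlotIndex() { return ClickIndex() + 2; }  // float: 1 slot
  static double& Show(float* v) {
    return *reinterpret_cast<double*>(v + ShowIndex());
  }
  static double& Click(float* v) {
    return *reinterpret_cast<double*>(v + ClickIndex());
  }
  static float& Slot(float* v) { return v[SlotIndex()]; }
};

int main() {
  alignas(8) float value[5] = {0};  // keep the double slots 8-byte aligned
  PackedValueSketch::Show(value) = 100.0;
  PackedValueSketch::Click(value) = 3.0;
  PackedValueSketch::Slot(value) = -1.0f;
  std::printf("%.0f %.0f\n", PackedValueSketch::Show(value),
              PackedValueSketch::Click(value));  // prints: 100 3
  return 0;
}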
InitAccessorInfo(); // 判断该value是否进行shrink virtual bool Shrink(float* value); virtual bool NeedExtendMF(float* value); @@ -235,7 +215,7 @@ class DownpourCtrDoubleAccessor : public ValueAccessor { // DEFINE_GET_INDEX(DownpourCtrDoubleFeatureValue, embed_w) // DEFINE_GET_INDEX(DownpourCtrDoubleFeatureValue, embedx_w) private: - double show_click_score(double show, double click); + double ShowClickScore(double show, double click); private: SparseValueSGDRule* _embed_sgd_rule; diff --git a/paddle/fluid/distributed/ps/table/downpour_ctr_accessor.cc b/paddle/fluid/distributed/ps/table/downpour_ctr_accessor.cc index 1140afd1c1e09..bad75d2de16ba 100644 --- a/paddle/fluid/distributed/ps/table/downpour_ctr_accessor.cc +++ b/paddle/fluid/distributed/ps/table/downpour_ctr_accessor.cc @@ -23,91 +23,32 @@ namespace distributed { int DownpourCtrAccessor::Initialize() { auto name = _config.embed_sgd_param().name(); _embed_sgd_rule = CREATE_PSCORE_CLASS(SparseValueSGDRule, name); - _embed_sgd_rule->load_config(_config.embed_sgd_param(), 1); + _embed_sgd_rule->LoadConfig(_config.embed_sgd_param(), 1); name = _config.embedx_sgd_param().name(); _embedx_sgd_rule = CREATE_PSCORE_CLASS(SparseValueSGDRule, name); - _embedx_sgd_rule->load_config(_config.embedx_sgd_param(), - _config.embedx_dim()); + _embedx_sgd_rule->LoadConfig(_config.embedx_sgd_param(), + _config.embedx_dim()); _show_click_decay_rate = _config.ctr_accessor_param().show_click_decay_rate(); _ssd_unseenday_threshold = _config.ctr_accessor_param().ssd_unseenday_threshold(); set_time_decay_rates(); + InitAccessorInfo(); return 0; } -void DownpourCtrAccessor::SetTableInfo(AccessorInfo& info) { - info.dim = Dim(); - info.size = Size(); - info.select_dim = SelectDim(); - info.select_size = SelectSize(); - info.update_dim = UpdateDim(); - info.update_size = UpdateSize(); - info.mf_size = MFSize(); -} - -size_t DownpourCtrAccessor::GetTableInfo(InfoKey key) { - switch (key) { - case DIM: - return Dim(); - case SIZE: - return Size(); - case SELECT_DIM: - return SelectDim(); - case SELECT_SIZE: - return SelectSize(); - case UPDATE_DIM: - return UpdateDim(); - case UPDATE_SIZE: - return UpdateSize(); - case MF_SIZE: - return MFSize(); - default: - return 0; - } - return 0; -} - -size_t DownpourCtrAccessor::Dim() { - auto embedx_dim = _config.embedx_dim(); - return DownpourCtrFeatureValue::Dim(embedx_dim); -} - -size_t DownpourCtrAccessor::DimSize(size_t dim) { - auto embedx_dim = _config.embedx_dim(); - return DownpourCtrFeatureValue::DimSize(dim, embedx_dim); -} - -size_t DownpourCtrAccessor::Size() { +void DownpourCtrAccessor::InitAccessorInfo() { auto embedx_dim = _config.embedx_dim(); - return DownpourCtrFeatureValue::Size(embedx_dim); + _accessor_info.dim = DownpourCtrFeatureValue::Dim(embedx_dim); + _accessor_info.size = DownpourCtrFeatureValue::Size(embedx_dim); + _accessor_info.select_dim = 3 + embedx_dim; + _accessor_info.select_size = _accessor_info.select_dim * sizeof(float); + _accessor_info.update_dim = 4 + embedx_dim; + _accessor_info.update_size = _accessor_info.update_dim * sizeof(float); + _accessor_info.mf_size = (embedx_dim + 1) * sizeof(float); } -size_t DownpourCtrAccessor::MFSize() { - return (_config.embedx_dim() + 1) * sizeof(float); // embedx embedx_g2sum -} - -// pull value -size_t DownpourCtrAccessor::SelectDim() { - auto embedx_dim = _config.embedx_dim(); - return 3 + embedx_dim; -} - -size_t DownpourCtrAccessor::SelectDimSize(size_t dim) { return sizeof(float); } - -size_t DownpourCtrAccessor::SelectSize() { return SelectDim() * 
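// The InitAccessorInfo() refactor above replaces one virtual getter per
// dimension (Dim/Size/SelectDim/SelectSize/UpdateDim/UpdateSize/MFSize) with
// a single struct filled once at Initialize() time; callers then read plain
// fields. A minimal sketch, assuming AccessorInfo is a plain aggregate of
// the seven fields used in the hunk above:
#include <cstddef>

struct AccessorInfoSketch {
  size_t dim, size, select_dim, select_size, update_dim, update_size, mf_size;
};

static AccessorInfoSketch MakeInfo(size_t embedx_dim, size_t value_dim) {
  AccessorInfoSketch info{};
  info.dim = value_dim;
  info.size = value_dim * sizeof(float);
  info.select_dim = 3 + embedx_dim;  // pull value: show, click, embed_w + mf
  info.select_size = info.select_dim * sizeof(float);
  info.update_dim = 4 + embedx_dim;  // push value: slot, show, click, g + mf
  info.update_size = info.update_dim * sizeof(float);
  info.mf_size = (embedx_dim + 1) * sizeof(float);  // embedx + embedx_g2sum
  return info;
}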
sizeof(float); } - -// push value -size_t DownpourCtrAccessor::UpdateDim() { - auto embedx_dim = _config.embedx_dim(); - return 4 + embedx_dim; -} - -size_t DownpourCtrAccessor::UpdateDimSize(size_t dim) { return sizeof(float); } - -size_t DownpourCtrAccessor::UpdateSize() { return UpdateDim() * sizeof(float); } - bool DownpourCtrAccessor::Shrink(float* value) { // auto base_threshold = _config.ctr_accessor_param().base_threshold(); // auto delta_threshold = _config.ctr_accessor_param().delta_threshold(); @@ -119,7 +60,7 @@ bool DownpourCtrAccessor::Shrink(float* value) { auto delete_threshold = _config.ctr_accessor_param().delete_threshold(); // time_decay first - auto unseen_days = DownpourCtrFeatureValue::unseen_days(value); + auto unseen_days = DownpourCtrFeatureValue::UnseenDays(value); int16_t day_diff = _day_id - unseen_days; if (day_diff < 0 || day_diff > delete_after_unseen_days) { return true; @@ -130,7 +71,7 @@ bool DownpourCtrAccessor::Shrink(float* value) { DownpourCtrFeatureValue::Click(value) * _time_decay_rates[day_diff]; // shrink after - auto score = show_click_score(show_right, click_right); + auto score = ShowClickScore(show_right, click_right); if (score < delete_threshold) { return true; } @@ -145,7 +86,7 @@ bool DownpourCtrAccessor::save_ssd(float* value) { if (_day_id == 0) { return true; } - auto unseen_days = DownpourCtrFeatureValue::unseen_days(value); + auto unseen_days = DownpourCtrFeatureValue::UnseenDays(value); if (unseen_days == 0) { return false; } @@ -164,9 +105,9 @@ bool DownpourCtrAccessor::save_ssd(float* value) { // float* value, int param, double global_cache_threshold) { // auto base_threshold = _config.ctr_accessor_param().base_threshold(); // auto delta_keep_days = _config.ctr_accessor_param().delta_keep_days(); -// auto unseen_days = DownpourCtrFeatureValue::unseen_days(value); +// auto unseen_days = DownpourCtrFeatureValue::UnseenDays(value); // int16_t day_diff = _day_id - unseen_days; -// if (show_click_score(DownpourCtrFeatureValue::Show(value), +// if (ShowClickScore(DownpourCtrFeatureValue::Show(value), // DownpourCtrFeatureValue::Click(value)) >= base_threshold // && day_diff <= delta_keep_days) { // return DownpourCtrFeatureValue::Show(value) > global_cache_threshold; @@ -193,7 +134,7 @@ bool DownpourCtrAccessor::Save(float* value, int param) { case 1: // save xbox base case 2: { - auto unseen_days = DownpourCtrFeatureValue::unseen_days(value); + auto unseen_days = DownpourCtrFeatureValue::UnseenDays(value); int16_t day_diff = _day_id - unseen_days; auto show_right = @@ -201,12 +142,12 @@ bool DownpourCtrAccessor::Save(float* value, int param) { auto click_right = DownpourCtrFeatureValue::Click(value) * _time_decay_rates[day_diff]; - if (show_click_score(show_right, click_right) >= base_threshold && - DownpourCtrFeatureValue::delta_score(value) >= delta_threshold && + if (ShowClickScore(show_right, click_right) >= base_threshold && + DownpourCtrFeatureValue::DeltaScore(value) >= delta_threshold && day_diff <= delta_keep_days) { // do this after save, because it must not be modified when retry if (param == 2) { - DownpourCtrFeatureValue::delta_score(value) = 0; + DownpourCtrFeatureValue::DeltaScore(value) = 0; } return true; } else { @@ -218,7 +159,7 @@ bool DownpourCtrAccessor::Save(float* value, int param) { // DownpourCtrFeatureValue::Show(value) *= _show_click_decay_rate; // DownpourCtrFeatureValue::Click(value) *= _show_click_decay_rate; // do this after save, because it must not be modified when retry - // 
DownpourCtrFeatureValue::unseen_days(value)++; + // DownpourCtrFeatureValue::UnseenDays(value)++; return true; } default: @@ -235,23 +176,23 @@ void DownpourCtrAccessor::UpdateStatAfterSave(float* value, int param) { } switch (param) { case 1: { - auto unseen_days = DownpourCtrFeatureValue::unseen_days(value); + auto unseen_days = DownpourCtrFeatureValue::UnseenDays(value); int16_t day_diff = _day_id - unseen_days; auto show_right = DownpourCtrFeatureValue::Show(value) * _time_decay_rates[day_diff]; auto click_right = DownpourCtrFeatureValue::Click(value) * _time_decay_rates[day_diff]; - if (show_click_score(show_right, click_right) >= base_threshold && - DownpourCtrFeatureValue::delta_score(value) >= delta_threshold && + if (ShowClickScore(show_right, click_right) >= base_threshold && + DownpourCtrFeatureValue::DeltaScore(value) >= delta_threshold && day_diff <= delta_keep_days) { - DownpourCtrFeatureValue::delta_score(value) = 0; + DownpourCtrFeatureValue::DeltaScore(value) = 0; } } return; // case 3: // { - // DownpourCtrFeatureValue::unseen_days(value)++; + // DownpourCtrFeatureValue::UnseenDays(value)++; // } // return; default: @@ -263,17 +204,17 @@ int32_t DownpourCtrAccessor::Create(float** values, size_t num) { auto embedx_dim = _config.embedx_dim(); for (size_t value_item = 0; value_item < num; ++value_item) { float* value = values[value_item]; - value[DownpourCtrFeatureValue::unseen_days_index()] = 0; - value[DownpourCtrFeatureValue::delta_score_index()] = 0; + value[DownpourCtrFeatureValue::UnseenDaysIndex()] = 0; + value[DownpourCtrFeatureValue::DeltaScoreIndex()] = 0; value[DownpourCtrFeatureValue::ShowIndex()] = 0; value[DownpourCtrFeatureValue::ClickIndex()] = 0; value[DownpourCtrFeatureValue::SlotIndex()] = -1; - _embed_sgd_rule->init_value( - value + DownpourCtrFeatureValue::Embed_W_Index(), - value + DownpourCtrFeatureValue::embed_g2sum_index(), true); - _embedx_sgd_rule->init_value( - value + DownpourCtrFeatureValue::Embedx_W_Index(), - value + DownpourCtrFeatureValue::embedx_g2sum_index()); + _embed_sgd_rule->InitValue( + value + DownpourCtrFeatureValue::EmbedWIndex(), + value + DownpourCtrFeatureValue::EmbedG2SumIndex(), true); + _embedx_sgd_rule->InitValue( + value + DownpourCtrFeatureValue::EmbedxWIndex(), + value + DownpourCtrFeatureValue::EmbedxG2SumIndex()); } return 0; } @@ -289,7 +230,7 @@ bool DownpourCtrAccessor::NeedExtendMF(float* value) { } bool DownpourCtrAccessor::HasMF(size_t size) { - return size > DownpourCtrFeatureValue::embedx_g2sum_index(); + return size > DownpourCtrFeatureValue::EmbedxG2SumIndex(); } // from DownpourCtrFeatureValue to DownpourCtrPullValue @@ -303,10 +244,10 @@ int32_t DownpourCtrAccessor::Select(float** select_values, const float** values, value[DownpourCtrFeatureValue::ShowIndex()]; select_value[DownpourCtrPullValue::ClickIndex()] = value[DownpourCtrFeatureValue::ClickIndex()]; - select_value[DownpourCtrPullValue::Embed_W_Index()] = - value[DownpourCtrFeatureValue::Embed_W_Index()]; - memcpy(select_value + DownpourCtrPullValue::Embedx_W_Index(), - value + DownpourCtrFeatureValue::Embedx_W_Index(), + select_value[DownpourCtrPullValue::EmbedWIndex()] = + value[DownpourCtrFeatureValue::EmbedWIndex()]; + memcpy(select_value + DownpourCtrPullValue::EmbedxWIndex(), + value + DownpourCtrFeatureValue::EmbedxWIndex(), embedx_dim * sizeof(float)); } return 0; @@ -347,20 +288,20 @@ int32_t DownpourCtrAccessor::Update(float** update_values, update_value[DownpourCtrFeatureValue::ShowIndex()] += push_show; 
update_value[DownpourCtrFeatureValue::ClickIndex()] += push_click; update_value[DownpourCtrFeatureValue::SlotIndex()] = slot; - update_value[DownpourCtrFeatureValue::delta_score_index()] += + update_value[DownpourCtrFeatureValue::DeltaScoreIndex()] += (push_show - push_click) * _config.ctr_accessor_param().nonclk_coeff() + push_click * _config.ctr_accessor_param().click_coeff(); //(push_show - push_click) * _config.ctr_accessor_param().nonclk_coeff() + // push_click * _config.ctr_accessor_param().click_coeff(); - update_value[DownpourCtrFeatureValue::unseen_days_index()] = 0; - _embed_sgd_rule->update_value( - update_value + DownpourCtrFeatureValue::Embed_W_Index(), - update_value + DownpourCtrFeatureValue::embed_g2sum_index(), - push_value + DownpourCtrPushValue::Embed_G_Index(), push_show); - _embedx_sgd_rule->update_value( - update_value + DownpourCtrFeatureValue::Embedx_W_Index(), - update_value + DownpourCtrFeatureValue::embedx_g2sum_index(), - push_value + DownpourCtrPushValue::Embedx_G_Index(), push_show); + update_value[DownpourCtrFeatureValue::UnseenDaysIndex()] = 0; + _embed_sgd_rule->UpdateValue( + update_value + DownpourCtrFeatureValue::EmbedWIndex(), + update_value + DownpourCtrFeatureValue::EmbedG2SumIndex(), + push_value + DownpourCtrPushValue::EmbedGIndex(), push_show); + _embedx_sgd_rule->UpdateValue( + update_value + DownpourCtrFeatureValue::EmbedxWIndex(), + update_value + DownpourCtrFeatureValue::EmbedxG2SumIndex(), + push_value + DownpourCtrPushValue::EmbedxGIndex(), push_show); } return 0; } @@ -373,7 +314,7 @@ bool DownpourCtrAccessor::CreateValue(int stage, const float* value) { } else if (stage == 1) { auto show = DownpourCtrPushValue::Show(const_cast(value)); auto click = DownpourCtrPushValue::Click(const_cast(value)); - auto score = show_click_score(show, click); + auto score = ShowClickScore(show, click); if (score <= 0) { return false; } @@ -387,7 +328,7 @@ bool DownpourCtrAccessor::CreateValue(int stage, const float* value) { } } -float DownpourCtrAccessor::show_click_score(float show, float click) { +float DownpourCtrAccessor::ShowClickScore(float show, float click) { // auto nonclk_coeff = _config.ctr_accessor_param().nonclk_coeff(); // auto click_coeff = _config.ctr_accessor_param().click_coeff(); auto nonclk_coeff = _config.ctr_accessor_param().nonclk_coeff(); @@ -403,7 +344,7 @@ std::string DownpourCtrAccessor::ParseToString(const float* v, int param_size) { << v[5] << " " << v[6]; auto show = DownpourCtrFeatureValue::Show(const_cast(v)); auto click = DownpourCtrFeatureValue::Click(const_cast(v)); - auto score = show_click_score(show, click); + auto score = ShowClickScore(show, click); if (score >= _config.embedx_threshold() && param_size > 7) { os << " " << v[7]; for (auto i = 0; i < _config.embedx_dim(); ++i) { @@ -415,18 +356,18 @@ std::string DownpourCtrAccessor::ParseToString(const float* v, int param_size) { int DownpourCtrAccessor::ParseFromString(const std::string& str, float* value) { int embedx_dim = _config.embedx_dim(); - float data_buff[Dim()]; + float data_buff[_accessor_info.dim]; float* data_buff_ptr = data_buff; - _embedx_sgd_rule->init_value( - data_buff_ptr + DownpourCtrFeatureValue::Embedx_W_Index(), - data_buff_ptr + DownpourCtrFeatureValue::embedx_g2sum_index()); + _embedx_sgd_rule->InitValue( + data_buff_ptr + DownpourCtrFeatureValue::EmbedxWIndex(), + data_buff_ptr + DownpourCtrFeatureValue::EmbedxG2SumIndex()); auto str_len = paddle::string::str_to_float(str.data(), data_buff_ptr); CHECK(str_len >= 6) << "expect more than 6 real:" 
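// UnseenDays() records the last day a feature was seen, and Show/Click are
// decayed by a rate indexed by the number of elapsed days, day_diff. A sketch
// of the precomputed-table pattern; that _time_decay_rates[d] equals
// pow(show_click_decay_rate, d), and the table length, are assumptions
// inferred from how day_diff indexes the table in the hunks above:
#include <cmath>
#include <vector>

struct DecaySketch {
  std::vector<float> rates;
  void SetTimeDecayRates(float decay_rate, int max_days) {
    rates.resize(max_days + 1);
    for (int d = 0; d <= max_days; ++d)
      rates[d] = static_cast<float>(std::pow(decay_rate, d));
  }
  void Apply(double& show, double& click, int day_diff) const {
    // mirrors Show(value) *= _time_decay_rates[day_diff] above
    show *= rates[day_diff];
    click *= rates[day_diff];
  }
};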
<< str_len; // no slot, embedx - int value_dim = Dim(); - int embedx_g2sum_index = DownpourCtrFeatureValue::embedx_g2sum_index(); + int value_dim = _accessor_info.dim; + int embedx_g2sum_index = DownpourCtrFeatureValue::EmbedxG2SumIndex(); value[DownpourCtrFeatureValue::SlotIndex()] = -1; // other case if (str_len == (value_dim - 1)) { @@ -459,25 +400,25 @@ void DownpourCtrAccessor::update_time_decay(float* value, if (_day_id == 0) { return; } - auto unseen_days = DownpourCtrFeatureValue::unseen_days(value); + auto unseen_days = DownpourCtrFeatureValue::UnseenDays(value); if (unseen_days == 0) { - DownpourCtrFeatureValue::unseen_days(value) = _day_id; + DownpourCtrFeatureValue::UnseenDays(value) = _day_id; return; } // for the origin load (unseenday = 0 -15) if (unseen_days < _config.ctr_accessor_param().delete_after_unseen_days()) { // pull if (is_update_seen_day) { - DownpourCtrFeatureValue::unseen_days(value) = _day_id; + DownpourCtrFeatureValue::UnseenDays(value) = _day_id; return; // save 舍弃原始的unseenday,都变为上一天出现,保证show/click不被重复decay } else { - DownpourCtrFeatureValue::unseen_days(value) = _day_id - 1; + DownpourCtrFeatureValue::UnseenDays(value) = _day_id - 1; } } int16_t day_diff = _day_id - unseen_days; if (day_diff < 0) { - DownpourCtrFeatureValue::unseen_days(value) = _day_id; + DownpourCtrFeatureValue::UnseenDays(value) = _day_id; return; } if (day_diff >= _config.ctr_accessor_param().delete_after_unseen_days()) { @@ -486,7 +427,7 @@ void DownpourCtrAccessor::update_time_decay(float* value, DownpourCtrFeatureValue::Show(value) *= _time_decay_rates[day_diff]; DownpourCtrFeatureValue::Click(value) *= _time_decay_rates[day_diff]; if (is_update_seen_day) { - DownpourCtrFeatureValue::unseen_days(value) = _day_id; + DownpourCtrFeatureValue::UnseenDays(value) = _day_id; } } diff --git a/paddle/fluid/distributed/ps/table/downpour_ctr_accessor.h b/paddle/fluid/distributed/ps/table/downpour_ctr_accessor.h index de1f080f42e1f..785acaf8ea5a4 100644 --- a/paddle/fluid/distributed/ps/table/downpour_ctr_accessor.h +++ b/paddle/fluid/distributed/ps/table/downpour_ctr_accessor.h @@ -45,34 +45,34 @@ class DownpourCtrAccessor : public ValueAccessor { static int Dim(int embedx_dim) { return 8 + embedx_dim; } static int DimSize(size_t dim, int embedx_dim) { return sizeof(float); } static int Size(int embedx_dim) { return Dim(embedx_dim) * sizeof(float); } - static int unseen_days_index() { return 0; } - static int delta_score_index() { - return DownpourCtrFeatureValue::unseen_days_index() + 1; + static int UnseenDaysIndex() { return 0; } + static int DeltaScoreIndex() { + return DownpourCtrFeatureValue::UnseenDaysIndex() + 1; } static int ShowIndex() { - return DownpourCtrFeatureValue::delta_score_index() + 1; + return DownpourCtrFeatureValue::DeltaScoreIndex() + 1; } static int ClickIndex() { return DownpourCtrFeatureValue::ShowIndex() + 1; } - static int Embed_W_Index() { + static int EmbedWIndex() { return DownpourCtrFeatureValue::ClickIndex() + 1; } - static int embed_g2sum_index() { - return DownpourCtrFeatureValue::Embed_W_Index() + 1; + static int EmbedG2SumIndex() { + return DownpourCtrFeatureValue::EmbedWIndex() + 1; } static int SlotIndex() { - return DownpourCtrFeatureValue::embed_g2sum_index() + 1; + return DownpourCtrFeatureValue::EmbedG2SumIndex() + 1; } - static int embedx_g2sum_index() { + static int EmbedxG2SumIndex() { return DownpourCtrFeatureValue::SlotIndex() + 1; } - static int Embedx_W_Index() { - return DownpourCtrFeatureValue::embedx_g2sum_index() + 1; + static int 
EmbedxWIndex() { + return DownpourCtrFeatureValue::EmbedxG2SumIndex() + 1; } - static float& unseen_days(float* val) { - return val[DownpourCtrFeatureValue::unseen_days_index()]; + static float& UnseenDays(float* val) { + return val[DownpourCtrFeatureValue::UnseenDaysIndex()]; } - static float& delta_score(float* val) { - return val[DownpourCtrFeatureValue::delta_score_index()]; + static float& DeltaScore(float* val) { + return val[DownpourCtrFeatureValue::DeltaScoreIndex()]; } static float& Show(float* val) { return val[DownpourCtrFeatureValue::ShowIndex()]; @@ -84,16 +84,16 @@ class DownpourCtrAccessor : public ValueAccessor { return val[DownpourCtrFeatureValue::SlotIndex()]; } static float& EmbedW(float* val) { - return val[DownpourCtrFeatureValue::Embed_W_Index()]; + return val[DownpourCtrFeatureValue::EmbedWIndex()]; } - static float& embed_g2sum(float* val) { - return val[DownpourCtrFeatureValue::embed_g2sum_index()]; + static float& EmbedG2Sum(float* val) { + return val[DownpourCtrFeatureValue::EmbedG2SumIndex()]; } - static float& embedx_g2sum(float* val) { - return val[DownpourCtrFeatureValue::embedx_g2sum_index()]; + static float& EmbedxG2Sum(float* val) { + return val[DownpourCtrFeatureValue::EmbedxG2SumIndex()]; } static float* EmbedxW(float* val) { - return (val + DownpourCtrFeatureValue::Embedx_W_Index()); + return (val + DownpourCtrFeatureValue::EmbedxWIndex()); } }; @@ -113,11 +113,9 @@ class DownpourCtrAccessor : public ValueAccessor { static int SlotIndex() { return 0; } static int ShowIndex() { return DownpourCtrPushValue::SlotIndex() + 1; } static int ClickIndex() { return DownpourCtrPushValue::ShowIndex() + 1; } - static int Embed_G_Index() { - return DownpourCtrPushValue::ClickIndex() + 1; - } - static int Embedx_G_Index() { - return DownpourCtrPushValue::Embed_G_Index() + 1; + static int EmbedGIndex() { return DownpourCtrPushValue::ClickIndex() + 1; } + static int EmbedxGIndex() { + return DownpourCtrPushValue::EmbedGIndex() + 1; } static float& Slot(float* val) { return val[0]; } static float& Show(float* val) { return val[1]; } @@ -139,8 +137,8 @@ class DownpourCtrAccessor : public ValueAccessor { static int Size(int embedx_dim) { return Dim(embedx_dim) * sizeof(float); } static int ShowIndex() { return 0; } static int ClickIndex() { return 1; } - static int Embed_W_Index() { return 2; } - static int Embedx_W_Index() { return 3; } + static int EmbedWIndex() { return 2; } + static int EmbedxWIndex() { return 3; } static float& Show(float* val) { return val[DownpourCtrPullValue::ShowIndex()]; } @@ -148,38 +146,18 @@ class DownpourCtrAccessor : public ValueAccessor { return val[DownpourCtrPullValue::ClickIndex()]; } static float& EmbedW(float* val) { - return val[DownpourCtrPullValue::Embed_W_Index()]; + return val[DownpourCtrPullValue::EmbedWIndex()]; } static float* EmbedxW(float* val) { - return val + DownpourCtrPullValue::Embedx_W_Index(); + return val + DownpourCtrPullValue::EmbedxWIndex(); } }; DownpourCtrAccessor() {} virtual ~DownpourCtrAccessor() {} virtual int Initialize(); - virtual void SetTableInfo(AccessorInfo& info); - virtual size_t GetTableInfo(InfoKey key); - // value维度 - size_t Dim(); - // value各个维度的size - size_t DimSize(size_t dim); - // value各维度相加总size - size_t Size(); - // value中mf动态长度部分总size大小, sparse下生效 - size_t MFSize(); - // pull value维度 - size_t SelectDim(); - // pull value各个维度的size - size_t SelectDimSize(size_t dim); - // pull value各维度相加总size - size_t SelectSize(); - // push value维度 - size_t UpdateDim(); - // push value各个维度的size - size_t 
UpdateDimSize(size_t dim); - // push value各维度相加总size - size_t UpdateSize(); + // 初始化AccessorInfo + virtual void InitAccessorInfo(); // 判断该value是否进行shrink virtual bool Shrink(float* value); // 判断该value是否保存到ssd @@ -219,7 +197,7 @@ class DownpourCtrAccessor : public ValueAccessor { virtual float GetField(float* value, const std::string& name) override { CHECK(name == "show"); if (name == "show") { - auto unseen_days = DownpourCtrFeatureValue::unseen_days(value); + auto unseen_days = DownpourCtrFeatureValue::UnseenDays(value); int16_t day_diff = _day_id - unseen_days; auto show_right = DownpourCtrFeatureValue::Show(value) * _time_decay_rates[day_diff]; @@ -238,7 +216,7 @@ class DownpourCtrAccessor : public ValueAccessor { bool test_func() { return false; } private: - float show_click_score(float show, float click); + float ShowClickScore(float show, float click); void set_time_decay_rates(); private: diff --git a/paddle/fluid/distributed/ps/table/memory_sparse_table.cc b/paddle/fluid/distributed/ps/table/memory_sparse_table.cc index 97e3c008d9478..b4b2263ed77bf 100644 --- a/paddle/fluid/distributed/ps/table/memory_sparse_table.cc +++ b/paddle/fluid/distributed/ps/table/memory_sparse_table.cc @@ -89,7 +89,7 @@ int32_t MemorySparseTable::Load(const std::string& path, size_t file_start_idx = _shard_idx * _avg_local_shard_num; size_t feature_value_size = - _value_accesor->GetTableInfo(SIZE) / sizeof(float); + _value_accesor->GetAccessorInfo().size / sizeof(float); int thread_num = _real_local_shard_num < 15 ? _real_local_shard_num : 15; omp_set_num_threads(thread_num); @@ -174,7 +174,7 @@ int32_t MemorySparseTable::LoadLocalFS(const std::string& path, size_t file_start_idx = _shard_idx * _avg_local_shard_num; size_t feature_value_size = - _value_accesor->GetTableInfo(SIZE) / sizeof(float); + _value_accesor->GetAccessorInfo().size / sizeof(float); int thread_num = _real_local_shard_num < 15 ? 
_real_local_shard_num : 15; omp_set_num_threads(thread_num); @@ -415,10 +415,12 @@ int32_t MemorySparseTable::PullSparse(float* pull_values, CostTimer timer("pserver_sparse_select_all"); std::vector> tasks(_real_local_shard_num); - const size_t value_size = _value_accesor->GetTableInfo(SIZE) / sizeof(float); - size_t mf_value_size = _value_accesor->GetTableInfo(MF_SIZE) / sizeof(float); + const size_t value_size = + _value_accesor->GetAccessorInfo().size / sizeof(float); + size_t mf_value_size = + _value_accesor->GetAccessorInfo().mf_size / sizeof(float); size_t select_value_size = - _value_accesor->GetTableInfo(SELECT_SIZE) / sizeof(float); + _value_accesor->GetAccessorInfo().select_size / sizeof(float); // std::atomic missed_keys{0}; std::vector>> task_keys( @@ -482,8 +484,9 @@ int32_t MemorySparseTable::PullSparse(float* pull_values, int32_t MemorySparseTable::PullSparsePtr(char** pull_values, const uint64_t* keys, size_t num) { CostTimer timer("pscore_sparse_select_all"); - size_t value_size = _value_accesor->GetTableInfo(SIZE) / sizeof(float); - size_t mf_value_size = _value_accesor->GetTableInfo(MF_SIZE) / sizeof(float); + size_t value_size = _value_accesor->GetAccessorInfo().size / sizeof(float); + size_t mf_value_size = + _value_accesor->GetAccessorInfo().mf_size / sizeof(float); std::vector> tasks(_real_local_shard_num); std::vector>> task_keys( @@ -541,10 +544,12 @@ int32_t MemorySparseTable::PushSparse(const uint64_t* keys, const float* values, task_keys[shard_id].push_back({keys[i], i}); } - const size_t value_col = _value_accesor->GetTableInfo(SIZE) / sizeof(float); - size_t mf_value_col = _value_accesor->GetTableInfo(MF_SIZE) / sizeof(float); + const size_t value_col = + _value_accesor->GetAccessorInfo().size / sizeof(float); + size_t mf_value_col = + _value_accesor->GetAccessorInfo().mf_size / sizeof(float); size_t update_value_col = - _value_accesor->GetTableInfo(UPDATE_SIZE) / sizeof(float); + _value_accesor->GetAccessorInfo().update_size / sizeof(float); for (size_t shard_id = 0; shard_id < _real_local_shard_num; ++shard_id) { tasks[shard_id] = _shards_task_pool[shard_id % _task_pool_size]->enqueue( @@ -619,10 +624,11 @@ int32_t MemorySparseTable::_PushSparse(const uint64_t* keys, task_keys[shard_id].push_back({keys[i], i}); } - size_t value_col = _value_accesor->GetTableInfo(SIZE) / sizeof(float); - size_t mf_value_col = _value_accesor->GetTableInfo(MF_SIZE) / sizeof(float); + size_t value_col = _value_accesor->GetAccessorInfo().size / sizeof(float); + size_t mf_value_col = + _value_accesor->GetAccessorInfo().mf_size / sizeof(float); size_t update_value_col = - _value_accesor->GetTableInfo(UPDATE_SIZE) / sizeof(float); + _value_accesor->GetAccessorInfo().update_size / sizeof(float); for (int shard_id = 0; shard_id < _real_local_shard_num; ++shard_id) { tasks[shard_id] = _shards_task_pool[shard_id % _task_pool_size]->enqueue( diff --git a/paddle/fluid/distributed/ps/table/sparse_accessor.cc b/paddle/fluid/distributed/ps/table/sparse_accessor.cc index 511b36389aaee..bc537880f1c21 100644 --- a/paddle/fluid/distributed/ps/table/sparse_accessor.cc +++ b/paddle/fluid/distributed/ps/table/sparse_accessor.cc @@ -23,87 +23,35 @@ namespace distributed { int SparseAccessor::Initialize() { auto name = _config.embed_sgd_param().name(); _embed_sgd_rule = CREATE_PSCORE_CLASS(SparseValueSGDRule, name); - _embed_sgd_rule->load_config(_config.embed_sgd_param(), 1); + _embed_sgd_rule->LoadConfig(_config.embed_sgd_param(), 1); name = _config.embedx_sgd_param().name(); _embedx_sgd_rule = 
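// MemorySparseTable stores feature values as flat float arrays, so every
// byte size read from GetAccessorInfo() is converted to a float count before
// indexing, as at the call sites above. A trivial sketch of that shared
// conversion (field grouping here is illustrative, not a real struct):
#include <cstddef>

struct InfoBytesSketch { size_t size, mf_size, select_size, update_size; };
struct InfoFloatsSketch { size_t value, mf, select, update; };

static InfoFloatsSketch ToFloatCounts(const InfoBytesSketch& b) {
  return {b.size / sizeof(float), b.mf_size / sizeof(float),
          b.select_size / sizeof(float), b.update_size / sizeof(float)};
}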
CREATE_PSCORE_CLASS(SparseValueSGDRule, name); - _embedx_sgd_rule->load_config(_config.embedx_sgd_param(), - _config.embedx_dim()); + _embedx_sgd_rule->LoadConfig(_config.embedx_sgd_param(), + _config.embedx_dim()); - sparse_feature_value.embed_sgd_dim = _embed_sgd_rule->dim(); + sparse_feature_value.embed_sgd_dim = _embed_sgd_rule->Dim(); sparse_feature_value.embedx_dim = _config.embedx_dim(); - sparse_feature_value.embedx_sgd_dim = _embedx_sgd_rule->dim(); + sparse_feature_value.embedx_sgd_dim = _embedx_sgd_rule->Dim(); _show_click_decay_rate = _config.ctr_accessor_param().show_click_decay_rate(); + InitAccessorInfo(); return 0; } -void SparseAccessor::SetTableInfo(AccessorInfo& info) { - info.dim = Dim(); - info.size = Size(); - info.select_dim = SelectDim(); - info.select_size = SelectSize(); - info.update_dim = UpdateDim(); - info.update_size = UpdateSize(); - info.mf_size = MFSize(); -} - -size_t SparseAccessor::GetTableInfo(InfoKey key) { - switch (key) { - case DIM: - return Dim(); - case SIZE: - return Size(); - case SELECT_DIM: - return SelectDim(); - case SELECT_SIZE: - return SelectSize(); - case UPDATE_DIM: - return UpdateDim(); - case UPDATE_SIZE: - return UpdateSize(); - case MF_SIZE: - return MFSize(); - default: - return 0; - } - return 0; -} - -size_t SparseAccessor::Dim() { return sparse_feature_value.Dim(); } - -size_t SparseAccessor::DimSize(size_t dim) { +void SparseAccessor::InitAccessorInfo() { + _accessor_info.dim = sparse_feature_value.Dim(); + _accessor_info.size = sparse_feature_value.Size(); auto embedx_dim = _config.embedx_dim(); - return sparse_feature_value.DimSize(dim, embedx_dim); -} - -size_t SparseAccessor::Size() { return sparse_feature_value.Size(); } - -size_t SparseAccessor::MFSize() { - return (_config.embedx_dim() + sparse_feature_value.embedx_sgd_dim) * - sizeof(float); // embedx embedx_g2sum + _accessor_info.select_dim = 1 + embedx_dim; + _accessor_info.select_size = _accessor_info.select_dim * sizeof(float); + ; + _accessor_info.update_dim = 4 + embedx_dim; + _accessor_info.update_size = _accessor_info.update_dim * sizeof(float); + _accessor_info.mf_size = + (embedx_dim + sparse_feature_value.embedx_sgd_dim) * sizeof(float); } -// pull value -size_t SparseAccessor::SelectDim() { - auto embedx_dim = _config.embedx_dim(); - return 1 + embedx_dim; -} - -size_t SparseAccessor::SelectDimSize(size_t dim) { return sizeof(float); } - -size_t SparseAccessor::SelectSize() { return SelectDim() * sizeof(float); } - -// push value -size_t SparseAccessor::UpdateDim() { - auto embedx_dim = _config.embedx_dim(); - return 4 + embedx_dim; -} - -size_t SparseAccessor::UpdateDimSize(size_t dim) { return sizeof(float); } - -size_t SparseAccessor::UpdateSize() { return UpdateDim() * sizeof(float); } - bool SparseAccessor::Shrink(float* value) { auto base_threshold = _config.ctr_accessor_param().base_threshold(); auto delta_threshold = _config.ctr_accessor_param().delta_threshold(); @@ -116,9 +64,9 @@ bool SparseAccessor::Shrink(float* value) { sparse_feature_value.Click(value) *= _show_click_decay_rate; // shrink after - auto score = show_click_score(sparse_feature_value.Show(value), - sparse_feature_value.Click(value)); - auto unseen_days = sparse_feature_value.unseen_days(value); + auto score = ShowClickScore(sparse_feature_value.Show(value), + sparse_feature_value.Click(value)); + auto unseen_days = sparse_feature_value.UnseenDays(value); if (score < delete_threshold || unseen_days > delete_after_unseen_days) { return true; } @@ -141,14 +89,13 @@ bool 
SparseAccessor::Save(float* value, int param) { case 1: // save xbox base case 2: { - if (show_click_score(sparse_feature_value.Show(value), - sparse_feature_value.Click(value)) >= - base_threshold && - sparse_feature_value.delta_score(value) >= delta_threshold && - sparse_feature_value.unseen_days(value) <= delta_keep_days) { + if (ShowClickScore(sparse_feature_value.Show(value), + sparse_feature_value.Click(value)) >= base_threshold && + sparse_feature_value.DeltaScore(value) >= delta_threshold && + sparse_feature_value.UnseenDays(value) <= delta_keep_days) { // do this after save, because it must not be modified when retry if (param == 2) { - sparse_feature_value.delta_score(value) = 0; + sparse_feature_value.DeltaScore(value) = 0; } return true; } else { @@ -158,7 +105,7 @@ bool SparseAccessor::Save(float* value, int param) { // already decayed in shrink case 3: { // do this after save, because it must not be modified when retry - // sparse_feature_value.unseen_days(value)++; + // sparse_feature_value.UnseenDays(value)++; return true; } // save revert batch_model @@ -179,17 +126,16 @@ void SparseAccessor::UpdateStatAfterSave(float* value, int param) { } switch (param) { case 1: { - if (show_click_score(sparse_feature_value.Show(value), - sparse_feature_value.Click(value)) >= - base_threshold && - sparse_feature_value.delta_score(value) >= delta_threshold && - sparse_feature_value.unseen_days(value) <= delta_keep_days) { - sparse_feature_value.delta_score(value) = 0; + if (ShowClickScore(sparse_feature_value.Show(value), + sparse_feature_value.Click(value)) >= base_threshold && + sparse_feature_value.DeltaScore(value) >= delta_threshold && + sparse_feature_value.UnseenDays(value) <= delta_keep_days) { + sparse_feature_value.DeltaScore(value) = 0; } } return; case 3: { - sparse_feature_value.unseen_days(value)++; + sparse_feature_value.UnseenDays(value)++; } return; default: @@ -201,17 +147,16 @@ int32_t SparseAccessor::Create(float** values, size_t num) { auto embedx_dim = _config.embedx_dim(); for (size_t value_item = 0; value_item < num; ++value_item) { float* value = values[value_item]; - value[sparse_feature_value.unseen_days_index()] = 0; - value[sparse_feature_value.delta_score_index()] = 0; + value[sparse_feature_value.UnseenDaysIndex()] = 0; + value[sparse_feature_value.DeltaScoreIndex()] = 0; value[sparse_feature_value.ShowIndex()] = 0; value[sparse_feature_value.ClickIndex()] = 0; value[sparse_feature_value.SlotIndex()] = -1; - _embed_sgd_rule->init_value( - value + sparse_feature_value.Embed_W_Index(), - value + sparse_feature_value.embed_g2sum_index()); - _embedx_sgd_rule->init_value( - value + sparse_feature_value.Embedx_W_Index(), - value + sparse_feature_value.embedx_g2sum_index(), false); + _embed_sgd_rule->InitValue(value + sparse_feature_value.EmbedWIndex(), + value + sparse_feature_value.EmbedG2SumIndex()); + _embedx_sgd_rule->InitValue(value + sparse_feature_value.EmbedxWIndex(), + value + sparse_feature_value.EmbedxG2SumIndex(), + false); } return 0; } @@ -225,7 +170,7 @@ bool SparseAccessor::NeedExtendMF(float* value) { } bool SparseAccessor::HasMF(size_t size) { - return size > sparse_feature_value.embedx_g2sum_index(); + return size > sparse_feature_value.EmbedxG2SumIndex(); } // from SparseFeatureValue to SparsePullValue @@ -235,10 +180,10 @@ int32_t SparseAccessor::Select(float** select_values, const float** values, for (size_t value_item = 0; value_item < num; ++value_item) { float* select_value = select_values[value_item]; const float* value = 
values[value_item]; - select_value[SparsePullValue::Embed_W_Index()] = - value[sparse_feature_value.Embed_W_Index()]; - memcpy(select_value + SparsePullValue::Embedx_W_Index(), - value + sparse_feature_value.Embedx_W_Index(), + select_value[SparsePullValue::EmbedWIndex()] = + value[sparse_feature_value.EmbedWIndex()]; + memcpy(select_value + SparsePullValue::EmbedxWIndex(), + value + sparse_feature_value.EmbedxWIndex(), embedx_dim * sizeof(float)); } return 0; @@ -278,18 +223,18 @@ int32_t SparseAccessor::Update(float** update_values, const float** push_values, update_value[sparse_feature_value.ShowIndex()] += push_show; update_value[sparse_feature_value.ClickIndex()] += push_click; update_value[sparse_feature_value.SlotIndex()] = slot; - update_value[sparse_feature_value.delta_score_index()] += + update_value[sparse_feature_value.DeltaScoreIndex()] += (push_show - push_click) * _config.ctr_accessor_param().nonclk_coeff() + push_click * _config.ctr_accessor_param().click_coeff(); - update_value[sparse_feature_value.unseen_days_index()] = 0; - _embed_sgd_rule->update_value( - update_value + sparse_feature_value.Embed_W_Index(), - update_value + sparse_feature_value.embed_g2sum_index(), - push_value + SparsePushValue::Embed_G_Index()); - _embedx_sgd_rule->update_value( - update_value + sparse_feature_value.Embedx_W_Index(), - update_value + sparse_feature_value.embedx_g2sum_index(), - push_value + SparsePushValue::Embedx_G_Index()); + update_value[sparse_feature_value.UnseenDaysIndex()] = 0; + _embed_sgd_rule->UpdateValue( + update_value + sparse_feature_value.EmbedWIndex(), + update_value + sparse_feature_value.EmbedG2SumIndex(), + push_value + SparsePushValue::EmbedGIndex()); + _embedx_sgd_rule->UpdateValue( + update_value + sparse_feature_value.EmbedxWIndex(), + update_value + sparse_feature_value.EmbedxG2SumIndex(), + push_value + SparsePushValue::EmbedxGIndex()); } return 0; } @@ -303,7 +248,7 @@ bool SparseAccessor::CreateValue(int stage, const float* value) { // operation auto show = SparsePushValue::Show(const_cast(value)); auto click = SparsePushValue::Click(const_cast(value)); - auto score = show_click_score(show, click); + auto score = ShowClickScore(show, click); if (score <= 0) { return false; } @@ -317,7 +262,7 @@ bool SparseAccessor::CreateValue(int stage, const float* value) { } } -float SparseAccessor::show_click_score(float show, float click) { +float SparseAccessor::ShowClickScore(float show, float click) { auto nonclk_coeff = _config.ctr_accessor_param().nonclk_coeff(); auto click_coeff = _config.ctr_accessor_param().click_coeff(); return (show - click) * nonclk_coeff + click * click_coeff; @@ -329,16 +274,16 @@ std::string SparseAccessor::ParseToString(const float* v, int param) { os.str(""); os << v[0] << " " << v[1] << " " << v[2] << " " << v[3] << " " << v[4] << " " << v[5]; - for (int i = sparse_feature_value.embed_g2sum_index(); - i < sparse_feature_value.Embedx_W_Index(); i++) { + for (int i = sparse_feature_value.EmbedG2SumIndex(); + i < sparse_feature_value.EmbedxWIndex(); i++) { os << " " << v[i]; } auto show = sparse_feature_value.Show(const_cast(v)); auto click = sparse_feature_value.Click(const_cast(v)); - auto score = show_click_score(show, click); + auto score = ShowClickScore(show, click); if (score >= _config.embedx_threshold() && - param > sparse_feature_value.Embedx_W_Index()) { - for (auto i = sparse_feature_value.Embedx_W_Index(); + param > sparse_feature_value.EmbedxWIndex()) { + for (auto i = sparse_feature_value.EmbedxWIndex(); i < 
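// Each push in Update() above accumulates the same weighted show/click score
// into DeltaScore and resets UnseenDays before delegating the weight update
// to the SGD rule. A compressed sketch of that bookkeeping step, with
// pointer parameters standing in for the indexed value slots:
static void ApplyPushSketch(float* show, float* click, float* delta_score,
                            float* unseen_days, float push_show,
                            float push_click, float nonclk_coeff,
                            float click_coeff) {
  *show += push_show;
  *click += push_click;
  *delta_score += (push_show - push_click) * nonclk_coeff +
                  push_click * click_coeff;
  *unseen_days = 0;  // the feature was just seen
}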
sparse_feature_value.Dim(); ++i) { os << " " << v[i]; } @@ -349,9 +294,8 @@ std::string SparseAccessor::ParseToString(const float* v, int param) { int SparseAccessor::ParseFromString(const std::string& str, float* value) { int embedx_dim = _config.embedx_dim(); - _embedx_sgd_rule->init_value( - value + sparse_feature_value.Embedx_W_Index(), - value + sparse_feature_value.embedx_g2sum_index()); + _embedx_sgd_rule->InitValue(value + sparse_feature_value.EmbedxWIndex(), + value + sparse_feature_value.EmbedxG2SumIndex()); auto ret = paddle::string::str_to_float(str.data(), value); CHECK(ret >= 6) << "expect more than 6 real:" << ret; return ret; diff --git a/paddle/fluid/distributed/ps/table/sparse_accessor.h b/paddle/fluid/distributed/ps/table/sparse_accessor.h index b11acff6aaaa3..5ca5d21707a2b 100644 --- a/paddle/fluid/distributed/ps/table/sparse_accessor.h +++ b/paddle/fluid/distributed/ps/table/sparse_accessor.h @@ -44,24 +44,24 @@ class SparseAccessor : public ValueAccessor { int DimSize(size_t dim, int embedx_dim) { return sizeof(float); } int Size() { return Dim() * sizeof(float); } int SlotIndex() { return 0; } - int unseen_days_index() { return SlotIndex() + 1; } - int delta_score_index() { return unseen_days_index() + 1; } - int ShowIndex() { return delta_score_index() + 1; } + int UnseenDaysIndex() { return SlotIndex() + 1; } + int DeltaScoreIndex() { return UnseenDaysIndex() + 1; } + int ShowIndex() { return DeltaScoreIndex() + 1; } int ClickIndex() { return ShowIndex() + 1; } - int Embed_W_Index() { return ClickIndex() + 1; } - int embed_g2sum_index() { return Embed_W_Index() + 1; } - int Embedx_W_Index() { return embed_g2sum_index() + embed_sgd_dim; } - int embedx_g2sum_index() { return Embedx_W_Index() + embedx_dim; } + int EmbedWIndex() { return ClickIndex() + 1; } + int EmbedG2SumIndex() { return EmbedWIndex() + 1; } + int EmbedxWIndex() { return EmbedG2SumIndex() + embed_sgd_dim; } + int EmbedxG2SumIndex() { return EmbedxWIndex() + embedx_dim; } - float& unseen_days(float* val) { return val[unseen_days_index()]; } - float& delta_score(float* val) { return val[delta_score_index()]; } + float& UnseenDays(float* val) { return val[UnseenDaysIndex()]; } + float& DeltaScore(float* val) { return val[DeltaScoreIndex()]; } float& Show(float* val) { return val[ShowIndex()]; } float& Click(float* val) { return val[ClickIndex()]; } float& Slot(float* val) { return val[SlotIndex()]; } - float& EmbedW(float* val) { return val[Embed_W_Index()]; } - float& embed_g2sum(float* val) { return val[embed_g2sum_index()]; } - float& EmbedxW(float* val) { return val[Embedx_W_Index()]; } - float& embedx_g2sum(float* val) { return val[embedx_g2sum_index()]; } + float& EmbedW(float* val) { return val[EmbedWIndex()]; } + float& EmbedG2Sum(float* val) { return val[EmbedG2SumIndex()]; } + float& EmbedxW(float* val) { return val[EmbedxWIndex()]; } + float& EmbedxG2Sum(float* val) { return val[EmbedxG2SumIndex()]; } int embed_sgd_dim; int embedx_dim; @@ -84,18 +84,18 @@ class SparseAccessor : public ValueAccessor { static int SlotIndex() { return 0; } static int ShowIndex() { return SparsePushValue::SlotIndex() + 1; } static int ClickIndex() { return SparsePushValue::ShowIndex() + 1; } - static int Embed_G_Index() { return SparsePushValue::ClickIndex() + 1; } - static int Embedx_G_Index() { return SparsePushValue::Embed_G_Index() + 1; } + static int EmbedGIndex() { return SparsePushValue::ClickIndex() + 1; } + static int EmbedxGIndex() { return SparsePushValue::EmbedGIndex() + 1; } static float& Slot(float* 
val) { return val[SparsePushValue::SlotIndex()]; } static float& Show(float* val) { return val[SparsePushValue::ShowIndex()]; } static float& Click(float* val) { return val[SparsePushValue::ClickIndex()]; } static float& EmbedG(float* val) { - return val[SparsePushValue::Embed_G_Index()]; + return val[SparsePushValue::EmbedGIndex()]; } static float* EmbedxG(float* val) { - return val + SparsePushValue::Embedx_G_Index(); + return val + SparsePushValue::EmbedxGIndex(); } }; @@ -108,41 +108,21 @@ class SparseAccessor : public ValueAccessor { static int Dim(int embedx_dim) { return 1 + embedx_dim; } static int DimSize(size_t dim) { return sizeof(float); } static int Size(int embedx_dim) { return Dim(embedx_dim) * sizeof(float); } - static int Embed_W_Index() { return 0; } - static int Embedx_W_Index() { return 1; } + static int EmbedWIndex() { return 0; } + static int EmbedxWIndex() { return 1; } static float& EmbedW(float* val) { - return val[SparsePullValue::Embed_W_Index()]; + return val[SparsePullValue::EmbedWIndex()]; } static float* EmbedxW(float* val) { - return val + SparsePullValue::Embedx_W_Index(); + return val + SparsePullValue::EmbedxWIndex(); } }; SparseAccessor() {} - virtual int Initialize(); - virtual void SetTableInfo(AccessorInfo& info); - virtual size_t GetTableInfo(InfoKey key); virtual ~SparseAccessor() {} - // value维度 - size_t Dim(); - // value各个维度的size - size_t DimSize(size_t dim); - // value各维度相加总size - size_t Size(); - // value中mf动态长度部分总size大小, sparse下生效 - size_t MFSize(); - // pull value维度 - size_t SelectDim(); - // pull value各个维度的size - size_t SelectDimSize(size_t dim); - // pull value各维度相加总size - size_t SelectSize(); - // push value维度 - size_t UpdateDim(); - // push value各个维度的size - size_t UpdateDimSize(size_t dim); - // push value各维度相加总size - size_t UpdateSize(); + virtual int Initialize(); + // 初始化AccessorInfo + virtual void InitAccessorInfo(); // 判断该value是否进行shrink virtual bool Shrink(float* value); // 判断该value是否保存到ssd @@ -186,7 +166,7 @@ class SparseAccessor : public ValueAccessor { } private: - // float show_click_score(float show, float click); + // float ShowClickScore(float show, float click); // SparseValueSGDRule* _embed_sgd_rule; // SparseValueSGDRule* _embedx_sgd_rule; @@ -197,7 +177,7 @@ class SparseAccessor : public ValueAccessor { public: // TODO(zhaocaibei123): it should be private, but we make it public // for unit test SparseFeatureValue sparse_feature_value; - float show_click_score(float show, float click); + float ShowClickScore(float show, float click); SparseValueSGDRule* _embed_sgd_rule; SparseValueSGDRule* _embedx_sgd_rule; }; diff --git a/paddle/fluid/distributed/ps/table/sparse_sgd_rule.cc b/paddle/fluid/distributed/ps/table/sparse_sgd_rule.cc index 3e39d6f976d12..8471b93612828 100644 --- a/paddle/fluid/distributed/ps/table/sparse_sgd_rule.cc +++ b/paddle/fluid/distributed/ps/table/sparse_sgd_rule.cc @@ -21,8 +21,8 @@ DEFINE_bool(enable_show_scale_gradient, true, "enable show scale gradient"); namespace paddle { namespace distributed { -void SparseNaiveSGDRule::load_config(const SparseCommonSGDRuleParameter& param, - size_t emb_dim) { +void SparseNaiveSGDRule::LoadConfig(const SparseCommonSGDRuleParameter& param, + size_t emb_dim) { _embedding_dim = emb_dim; auto naive_param = param.naive(); learning_rate_ = naive_param.learning_rate(); @@ -39,17 +39,16 @@ void SparseNaiveSGDRule::load_config(const SparseCommonSGDRuleParameter& param, } } -void SparseNaiveSGDRule::update_value_work(float* w, float* sgd, - const float* push_value, - float 
scale) { +void SparseNaiveSGDRule::UpdateValueWork(float* w, float* sgd, + const float* push_value, float scale) { for (size_t i = 0; i < _embedding_dim; ++i) { w[i] -= learning_rate_ * push_value[i]; - bound_value(w[i]); + BoundValue(w[i]); } } -void SparseNaiveSGDRule::init_value_work(float* value, float* sgd, - bool zero_init) { +void SparseNaiveSGDRule::InitValueWork(float* value, float* sgd, + bool zero_init) { if (zero_init) { for (size_t i = 0; i < _embedding_dim; ++i) { value[i] = 0; @@ -60,12 +59,12 @@ void SparseNaiveSGDRule::init_value_work(float* value, float* sgd, (local_uniform_real_distribution()(local_random_engine()) * 2 - 1) * _initial_range; - bound_value(value[i]); + BoundValue(value[i]); } } } -void SparseAdaGradSGDRule::load_config( - const SparseCommonSGDRuleParameter& param, size_t emb_dim) { +void SparseAdaGradSGDRule::LoadConfig(const SparseCommonSGDRuleParameter& param, + size_t emb_dim) { _embedding_dim = emb_dim; auto adagrad_param = param.adagrad(); learning_rate_ = adagrad_param.learning_rate(); @@ -84,42 +83,42 @@ void SparseAdaGradSGDRule::load_config( } } -void SparseAdaGradSGDRule::update_value_work(float* w, float* sgd, - const float* grad, float scale) { - float& g2sum = sgd[g2sum_index()]; +void SparseAdaGradSGDRule::UpdateValueWork(float* w, float* sgd, + const float* grad, float scale) { + float& g2sum = sgd[G2SumIndex()]; double add_g2sum = 0; for (int i = 0; i < _embedding_dim; i++) { double scaled_grad = grad[i] / scale; w[i] -= learning_rate_ * scaled_grad * sqrt(_initial_g2sum / (_initial_g2sum + g2sum)); - bound_value(w[i]); + BoundValue(w[i]); add_g2sum += scaled_grad * scaled_grad; } g2sum += add_g2sum / _embedding_dim; } -void SparseAdaGradSGDRule::init_value_work(float* value, float* sgd, - bool zero_init) { +void SparseAdaGradSGDRule::InitValueWork(float* value, float* sgd, + bool zero_init) { for (int i = 0; i < _embedding_dim; ++i) { if (zero_init) { value[i] = 0.0; - bound_value(value[i]); + BoundValue(value[i]); } else { value[i] = (local_uniform_real_distribution()(local_random_engine()) * 2 - 1) * _initial_range; - bound_value(value[i]); + BoundValue(value[i]); } } - sgd[g2sum_index()] = 0; + sgd[G2SumIndex()] = 0; } -void StdAdaGradSGDRule::load_config(const SparseCommonSGDRuleParameter& param, - size_t emb_dim) { +void StdAdaGradSGDRule::LoadConfig(const SparseCommonSGDRuleParameter& param, + size_t emb_dim) { _embedding_dim = emb_dim; auto adagrad_param = param.adagrad(); learning_rate_ = adagrad_param.learning_rate(); @@ -138,38 +137,38 @@ void StdAdaGradSGDRule::load_config(const SparseCommonSGDRuleParameter& param, } } -void StdAdaGradSGDRule::update_value_work(float* w, float* sgd, - const float* grad, float scale) { +void StdAdaGradSGDRule::UpdateValueWork(float* w, float* sgd, const float* grad, + float scale) { for (int i = 0; i < _embedding_dim; i++) { - float& g2sum = sgd[g2sum_index() + i]; + float& g2sum = sgd[G2SumIndex() + i]; double scaled_grad = grad[i] / scale; w[i] -= learning_rate_ * scaled_grad * sqrt(_initial_g2sum / (_initial_g2sum + g2sum)); - bound_value(w[i]); + BoundValue(w[i]); g2sum += scaled_grad * scaled_grad; } } -void StdAdaGradSGDRule::init_value_work(float* value, float* sgd, - bool zero_init) { +void StdAdaGradSGDRule::InitValueWork(float* value, float* sgd, + bool zero_init) { for (int i = 0; i < _embedding_dim; ++i) { if (zero_init) { value[i] = 0.0; - bound_value(value[i]); + BoundValue(value[i]); } else { value[i] = (local_uniform_real_distribution()(local_random_engine()) * 2 - 1) * 
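// Both AdaGrad rules above scale the step by sqrt(initial_g2sum /
// (initial_g2sum + g2sum)); they differ only in whether g2sum is one shared
// scalar (SparseAdaGradSGDRule) or one accumulator per dimension
// (StdAdaGradSGDRule). Self-contained sketch of the shared-scalar variant:
#include <cmath>

static void AdaGradStepSketch(float* w, float& g2sum, const float* grad,
                              int dim, float lr, float initial_g2sum,
                              float scale) {
  double add_g2sum = 0;
  for (int i = 0; i < dim; ++i) {
    double g = grad[i] / scale;  // gradients arrive pre-scaled by show
    w[i] -= lr * g * std::sqrt(initial_g2sum / (initial_g2sum + g2sum));
    add_g2sum += g * g;
  }
  g2sum += add_g2sum / dim;  // one scalar accumulator shared by all dims
}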
_initial_range; - bound_value(value[i]); + BoundValue(value[i]); } - sgd[g2sum_index() + i] = 0; + sgd[G2SumIndex() + i] = 0; } } -void SparseAdamSGDRule::load_config(const SparseCommonSGDRuleParameter& param, - size_t emb_dim) { +void SparseAdamSGDRule::LoadConfig(const SparseCommonSGDRuleParameter& param, + size_t emb_dim) { _embedding_dim = emb_dim; auto adam_param = param.adam(); learning_rate_ = adam_param.learning_rate(); @@ -189,12 +188,12 @@ void SparseAdamSGDRule::load_config(const SparseCommonSGDRuleParameter& param, } } -void SparseAdamSGDRule::update_value_work(float* w, float* sgd, - const float* grad, float scale) { - float* gsum = sgd + gsum_index(); - float* g2sum = sgd + g2sum_index(); - float* beta1_pow = sgd + beta1_pow_index(); - float* beta2_pow = sgd + beta2_pow_index(); +void SparseAdamSGDRule::UpdateValueWork(float* w, float* sgd, const float* grad, + float scale) { + float* gsum = sgd + GSumIndex(); + float* g2sum = sgd + G2SumIndex(); + float* beta1_pow = sgd + Beta1PowIndex(); + float* beta2_pow = sgd + Beta2PowIndex(); const float* g = grad; float lr = learning_rate_; @@ -209,35 +208,35 @@ void SparseAdamSGDRule::update_value_work(float* w, float* sgd, g2sum[i] = _beta2_decay_rate * g2sum[i] + (1 - _beta2_decay_rate) * g[i] * g[i]; w[i] = w[i] - lr * (gsum[i] / (sqrt(g2sum[i]) + _ada_epsilon)); - bound_value(w[i]); + BoundValue(w[i]); } // update beta_pow_decay (*beta1_pow) *= _beta1_decay_rate; (*beta2_pow) *= _beta2_decay_rate; } -void SparseAdamSGDRule::init_value_work(float* value, float* sgd, - bool zero_init) { +void SparseAdamSGDRule::InitValueWork(float* value, float* sgd, + bool zero_init) { for (int i = 0; i < _embedding_dim; ++i) { if (zero_init) { value[i] = 0.0; - bound_value(value[i]); + BoundValue(value[i]); } else { value[i] = (local_uniform_real_distribution()(local_random_engine()) * 2 - 1) * _initial_range; - bound_value(value[i]); + BoundValue(value[i]); } } // init rule gsum and g2sum - for (int i = gsum_index(); i < beta1_pow_index(); i++) { + for (int i = GSumIndex(); i < Beta1PowIndex(); i++) { sgd[i] = 0.0; } // init beta1_pow and beta2_pow - *(sgd + beta1_pow_index()) = _beta1_decay_rate; - *(sgd + beta2_pow_index()) = _beta2_decay_rate; + *(sgd + Beta1PowIndex()) = _beta1_decay_rate; + *(sgd + Beta2PowIndex()) = _beta2_decay_rate; } } // namespace distributed } // namespace paddle diff --git a/paddle/fluid/distributed/ps/table/sparse_sgd_rule.h b/paddle/fluid/distributed/ps/table/sparse_sgd_rule.h index ba2baa42f742a..55a37b5941921 100644 --- a/paddle/fluid/distributed/ps/table/sparse_sgd_rule.h +++ b/paddle/fluid/distributed/ps/table/sparse_sgd_rule.h @@ -28,33 +28,33 @@ class SparseValueSGDRule { public: SparseValueSGDRule() {} virtual ~SparseValueSGDRule() {} - virtual void load_config(const SparseCommonSGDRuleParameter& param, - size_t emb_dim) { + virtual void LoadConfig(const SparseCommonSGDRuleParameter& param, + size_t emb_dim) { _embedding_dim = emb_dim; _name = param.name(); } - virtual void update_value_work(float* w, float* sgd, const float* push_value, - float scale) = 0; - virtual void init_value_work(float* value, float* sgd, bool zero_init) = 0; - virtual size_t dim() = 0; - const std::string& get_name() const { return _name; } - void init_value(float* value, float* sgd, bool zero_init = true) { - init_value_work(value, sgd, zero_init); + virtual void UpdateValueWork(float* w, float* sgd, const float* push_value, + float scale) = 0; + virtual void InitValueWork(float* value, float* sgd, bool zero_init) = 0; + virtual 
size_t Dim() = 0; + const std::string& GetName() const { return _name; } + void InitValue(float* value, float* sgd, bool zero_init = true) { + InitValueWork(value, sgd, zero_init); } - void update_value(float* w, float* sgd, const float* push_value, - float scale = 1) { - update_value_work(w, sgd, push_value, scale); + void UpdateValue(float* w, float* sgd, const float* push_value, + float scale = 1) { + UpdateValueWork(w, sgd, push_value, scale); } template - void bound_value(T& w) { // NOLINT + void BoundValue(T& w) { // NOLINT if (!(w >= _min_bound)) { w = (T)_min_bound; } else if (!(w <= _max_bound)) { w = (T)_max_bound; } } - float& min_bound() { return _min_bound; } - float& max_bound() { return _max_bound; } + float& MinBound() { return _min_bound; } + float& MaxBound() { return _max_bound; } protected: float _min_bound; @@ -70,12 +70,12 @@ REGISTER_PSCORE_REGISTERER(SparseValueSGDRule); class SparseNaiveSGDRule : public SparseValueSGDRule { public: - virtual void load_config(const SparseCommonSGDRuleParameter& param, - size_t emb_dim); - virtual void update_value_work(float* w, float* sgd, const float* push_value, - float scale); - virtual void init_value_work(float* value, float* sgd, bool zero_init); - virtual size_t dim() { return 0; } + virtual void LoadConfig(const SparseCommonSGDRuleParameter& param, + size_t emb_dim); + virtual void UpdateValueWork(float* w, float* sgd, const float* push_value, + float scale); + virtual void InitValueWork(float* value, float* sgd, bool zero_init); + virtual size_t Dim() { return 0; } private: float learning_rate_; @@ -83,13 +83,13 @@ class SparseNaiveSGDRule : public SparseValueSGDRule { class SparseAdaGradSGDRule : public SparseValueSGDRule { public: - virtual void load_config(const SparseCommonSGDRuleParameter& param, - size_t emb_dim); - virtual void update_value_work(float* w, float* sgd, const float* push_value, - float scale); - virtual void init_value_work(float* value, float* sgd, bool zero_init); - virtual size_t dim() { return 1; } - size_t g2sum_index() { return 0; } + virtual void LoadConfig(const SparseCommonSGDRuleParameter& param, + size_t emb_dim); + virtual void UpdateValueWork(float* w, float* sgd, const float* push_value, + float scale); + virtual void InitValueWork(float* value, float* sgd, bool zero_init); + virtual size_t Dim() { return 1; } + size_t G2SumIndex() { return 0; } private: float learning_rate_; @@ -98,13 +98,13 @@ class SparseAdaGradSGDRule : public SparseValueSGDRule { class StdAdaGradSGDRule : public SparseValueSGDRule { public: - virtual void load_config(const SparseCommonSGDRuleParameter& param, - size_t emb_dim); - virtual void update_value_work(float* w, float* sgd, const float* push_value, - float scale); - virtual void init_value_work(float* value, float* sgd, bool zero_init); - virtual size_t dim() { return _embedding_dim; } - size_t g2sum_index() { return 0; } + virtual void LoadConfig(const SparseCommonSGDRuleParameter& param, + size_t emb_dim); + virtual void UpdateValueWork(float* w, float* sgd, const float* push_value, + float scale); + virtual void InitValueWork(float* value, float* sgd, bool zero_init); + virtual size_t Dim() { return _embedding_dim; } + size_t G2SumIndex() { return 0; } private: float learning_rate_; @@ -113,16 +113,16 @@ class StdAdaGradSGDRule : public SparseValueSGDRule { class SparseAdamSGDRule : public SparseValueSGDRule { public: - virtual void load_config(const SparseCommonSGDRuleParameter& param, - size_t emb_dim); - virtual void update_value_work(float* w, float* 
sgd, const float* push_value, - float scale); - virtual void init_value_work(float* value, float* sgd, bool zero_init); - virtual size_t dim() { return _embedding_dim * 2 + 2; } - size_t gsum_index() { return 0; } - size_t g2sum_index() { return gsum_index() + _embedding_dim; } - size_t beta1_pow_index() { return g2sum_index() + _embedding_dim; } - size_t beta2_pow_index() { return beta1_pow_index() + 1; } + virtual void LoadConfig(const SparseCommonSGDRuleParameter& param, + size_t emb_dim); + virtual void UpdateValueWork(float* w, float* sgd, const float* push_value, + float scale); + virtual void InitValueWork(float* value, float* sgd, bool zero_init); + virtual size_t Dim() { return _embedding_dim * 2 + 2; } + size_t GSumIndex() { return 0; } + size_t G2SumIndex() { return GSumIndex() + _embedding_dim; } + size_t Beta1PowIndex() { return G2SumIndex() + _embedding_dim; } + size_t Beta2PowIndex() { return Beta1PowIndex() + 1; } protected: float learning_rate_; diff --git a/paddle/fluid/distributed/ps/table/table.cc b/paddle/fluid/distributed/ps/table/table.cc index 9f17a2006d232..0a7352c97731f 100644 --- a/paddle/fluid/distributed/ps/table/table.cc +++ b/paddle/fluid/distributed/ps/table/table.cc @@ -103,7 +103,6 @@ int32_t Table::InitializeAccessor() { return -1; } _value_accesor.reset(accessor); - // _value_accesor->SetTableInfo(_table_info); return 0; } diff --git a/paddle/fluid/distributed/ps/table/table.h b/paddle/fluid/distributed/ps/table/table.h index c61efe769e2f8..f55c30b774059 100644 --- a/paddle/fluid/distributed/ps/table/table.h +++ b/paddle/fluid/distributed/ps/table/table.h @@ -162,7 +162,6 @@ class Table { TableParameter _config; float *_global_lr = nullptr; std::shared_ptr<ValueAccessor> _value_accesor; - AccessorInfo _table_info; AfsClient _afs_client; }; REGISTER_PSCORE_REGISTERER(Table); diff --git a/paddle/fluid/distributed/ps/table/tensor_accessor.cc b/paddle/fluid/distributed/ps/table/tensor_accessor.cc index 43b791b6ac03b..5d1f69b7463da 100644 --- a/paddle/fluid/distributed/ps/table/tensor_accessor.cc +++ b/paddle/fluid/distributed/ps/table/tensor_accessor.cc @@ -18,51 +18,19 @@ namespace paddle { namespace distributed { -int CommMergeAccessor::Initialize() { return 0; } - -void CommMergeAccessor::SetTableInfo(AccessorInfo &info) { - info.select_dim = SelectDim(); - info.select_size = SelectSize(); - info.update_dim = UpdateDim(); - info.update_size = UpdateSize(); - info.fea_dim = fea_dim(); -} - -size_t CommMergeAccessor::GetTableInfo(InfoKey key) { - switch (key) { - case SELECT_DIM: - return SelectDim(); - case SELECT_SIZE: - return SelectSize(); - case UPDATE_DIM: - return UpdateDim(); - case UPDATE_SIZE: - return UpdateSize(); - case FEA_DIM: - return fea_dim(); - default: - return 0; - } +int CommMergeAccessor::Initialize() { + InitAccessorInfo(); return 0; } -// pull value dimension count -size_t CommMergeAccessor::SelectDim() { return _config.embedx_dim(); } - -// size of each pull value dimension -size_t CommMergeAccessor::SelectDimSize(size_t dim) { return sizeof(float); } - -// total size summed over all pull value dimensions -size_t CommMergeAccessor::SelectSize() { return SelectDim() * sizeof(float); } - -// push value dimension count -size_t CommMergeAccessor::UpdateDim() { return _config.embedx_dim(); } - -// size of each push value dimension -size_t CommMergeAccessor::UpdateDimSize(size_t dim) { return sizeof(float); } - -// total size summed over all push value dimensions -size_t CommMergeAccessor::UpdateSize() { return UpdateDim() * sizeof(float); } +void CommMergeAccessor::InitAccessorInfo() { + auto embedx_dim = _config.embedx_dim(); + _accessor_info.select_dim 
= embedx_dim; + _accessor_info.select_size = _accessor_info.select_dim * sizeof(float); + _accessor_info.update_dim = embedx_dim; + _accessor_info.update_size = _accessor_info.update_dim * sizeof(float); + _accessor_info.fea_dim = _config.fea_dim(); +} // determine whether this value should be shrunk bool CommMergeAccessor::Shrink(float * /*value*/) { return false; } diff --git a/paddle/fluid/distributed/ps/table/tensor_accessor.h b/paddle/fluid/distributed/ps/table/tensor_accessor.h index 1b454fe0c734b..60951598482ad 100644 --- a/paddle/fluid/distributed/ps/table/tensor_accessor.h +++ b/paddle/fluid/distributed/ps/table/tensor_accessor.h @@ -30,22 +30,8 @@ class CommMergeAccessor : public ValueAccessor { CommMergeAccessor() {} virtual ~CommMergeAccessor() {} virtual int Initialize(); - virtual void SetTableInfo(AccessorInfo &info); - virtual size_t GetTableInfo(InfoKey key); - // value dimension count - // pull value dimension count - size_t SelectDim(); - // size of each pull value dimension - size_t SelectDimSize(size_t dim); - // total size summed over all pull value dimensions - size_t SelectSize(); - // push value dimension count - size_t UpdateDim(); - // size of each push value dimension - size_t UpdateDimSize(size_t dim); - // total size summed over all push value dimensions - size_t UpdateSize(); - size_t fea_dim() { return _config.fea_dim(); } + // initialize AccessorInfo + virtual void InitAccessorInfo(); // determine whether this value should be shrunk virtual bool Shrink(float * /*value*/); // determine whether this value should be dumped during the save phase, diff --git a/paddle/fluid/distributed/test/ctr_accessor_test.cc b/paddle/fluid/distributed/test/ctr_accessor_test.cc index 8d9d0abd2394c..844aa54946c4c 100644 --- a/paddle/fluid/distributed/test/ctr_accessor_test.cc +++ b/paddle/fluid/distributed/test/ctr_accessor_test.cc @@ -75,8 +75,8 @@ TEST(downpour_feature_value_accessor_test, test_shrink) { << acc->common_feature_value.embedx_sgd_dim << " " << acc->common_feature_value.Dim() << "\n"; - float* value = new float[acc->Dim()]; - for (auto i = 0u; i < acc->Dim(); ++i) { + float* value = new float[acc->GetAccessorInfo().dim]; + for (auto i = 0u; i < acc->GetAccessorInfo().dim; ++i) { value[i] = i * 1.0; } ASSERT_TRUE(!acc->Shrink(value)); @@ -94,8 +94,8 @@ TEST(downpour_feature_value_accessor_test, test_save) { ASSERT_EQ(acc->Configure(parameter), 0); ASSERT_EQ(acc->Initialize(), 0); - float* value = new float[acc->Dim()]; - for (auto i = 0u; i < acc->Dim(); ++i) { + float* value = new float[acc->GetAccessorInfo().dim]; + for (auto i = 0u; i < acc->GetAccessorInfo().dim; ++i) { value[i] = i * 1.0; } @@ -109,7 +109,7 @@ TEST(downpour_feature_value_accessor_test, test_save) { ASSERT_TRUE(acc->Save(value, 2)); VLOG(3) << "test_save:"; - for (auto i = 0u; i < acc->Dim(); ++i) { + for (auto i = 0u; i < acc->GetAccessorInfo().dim; ++i) { VLOG(3) << value[i]; } } @@ -145,7 +145,7 @@ TEST(downpour_feature_value_accessor_test, test_update) { ASSERT_EQ(acc->Initialize(), 0); VLOG(3) << "dim: " << acc->common_feature_value.Dim() << "\n"; - VLOG(3) << "update_dim: " << acc->GetTableInfo(UPDATE_DIM) << "\n"; + VLOG(3) << "update_dim: " << acc->GetAccessorInfo().update_dim << "\n"; const int field_size = 7 + 8; const int item_size = 10; @@ -162,8 +162,8 @@ TEST(downpour_feature_value_accessor_test, test_update) { typedef const float* const_float_ptr; const_float_ptr* grad = new const_float_ptr[item_size]; for (auto i = 0u; i < item_size; ++i) { - float* p = new float[acc->GetTableInfo(UPDATE_DIM)]; - for (auto j = 0u; j < acc->GetTableInfo(UPDATE_DIM); ++j) { + float* p = new float[acc->GetAccessorInfo().update_dim]; + for (auto j = 0u; j < acc->GetAccessorInfo().update_dim; ++j) { p[j] = i; } grad[i] = p; @@ 
-244,21 +244,21 @@ TEST(downpour_feature_value_accessor_test, test_update) { v.unseen_days = 0; v.show += push_v.show; v.click += push_v.click; - v.delta_score += acc->show_click_score(push_v.show, push_v.click); + v.delta_score += acc->ShowClickScore(push_v.show, push_v.click); - acc->_embed_sgd_rule->update_value(&v.embed_w, &v.embed_g2sum[0], - &push_v.embed_g); - acc->_embedx_sgd_rule->update_value(&v.embedx_w[0], &v.embedx_g2sum[0], - &push_v.embedx_g[0]); + acc->_embed_sgd_rule->UpdateValue(&v.embed_w, &v.embed_g2sum[0], + &push_v.embed_g); + acc->_embedx_sgd_rule->UpdateValue(&v.embedx_w[0], &v.embedx_g2sum[0], + &push_v.embedx_g[0]); - float* ptr = new float[acc->Dim()]; + float* ptr = new float[acc->GetAccessorInfo().dim]; v.to_array(ptr, parameter.embedx_dim()); exp_value.push_back(ptr); } acc->Update(value, grad, item_size); for (auto i = 0u; i < item_size; ++i) { - for (auto j = 0u; j < acc->Dim(); ++j) { + for (auto j = 0u; j < acc->GetAccessorInfo().dim; ++j) { VLOG(3) << value[i][j] << ":" << exp_value[i][j] << " "; ASSERT_FLOAT_EQ(value[i][j], exp_value[i][j]); } @@ -273,7 +273,7 @@ TEST(downpour_feature_value_accessor_test, test_show_click_score) { float show = 10; float click = 6; - ASSERT_FLOAT_EQ(acc->show_click_score(show, click), 6.8); + ASSERT_FLOAT_EQ(acc->ShowClickScore(show, click), 6.8); } TEST(downpour_feature_value_accessor_test, test_string_related) { diff --git a/paddle/fluid/distributed/test/sparse_sgd_rule_test.cc b/paddle/fluid/distributed/test/sparse_sgd_rule_test.cc index c895231d93ec5..1a4e16b926619 100644 --- a/paddle/fluid/distributed/test/sparse_sgd_rule_test.cc +++ b/paddle/fluid/distributed/test/sparse_sgd_rule_test.cc @@ -31,22 +31,22 @@ TEST(sparse_value_naive_sgd_test, init_and_update) { naive_param->add_weight_bounds(-10.0); naive_param->add_weight_bounds(10.0); - rule.load_config(param, 10); + rule.LoadConfig(param, 10); // check init_value for zero const int kItemSize = 10; float w[kItemSize]; float grad[kItemSize]; - rule.init_value(w, w + 9, true); + rule.InitValue(w, w + 9, true); for (auto i = 0u; i < kItemSize; ++i) { ASSERT_FLOAT_EQ(w[i], 0); } // check init_value for random - rule.init_value(w, w + 9, false); + rule.InitValue(w, w + 9, false); for (auto i = 0u; i < kItemSize; ++i) { - ASSERT_TRUE(w[i] >= rule.min_bound() && w[i] <= rule.max_bound()); + ASSERT_TRUE(w[i] >= rule.MinBound() && w[i] <= rule.MaxBound()); } // check update_value for one field @@ -59,7 +59,7 @@ TEST(sparse_value_naive_sgd_test, init_and_update) { float label[] = {-0.100000, -0.200000, -0.300000, -0.400000, -0.500000, -0.600000, -0.700000, -0.800000, -0.900000, -1.000000}; const float* ptr_grad = grad; - rule.update_value(w, w + 9, ptr_grad); + rule.UpdateValue(w, w + 9, ptr_grad); for (auto i = 0u; i < kItemSize; ++i) { VLOG(3) << w[i] << "\n"; @@ -78,14 +78,14 @@ TEST(downpour_sparse_adagrad_test, test_init_and_update) { adagrad_param->add_weight_bounds(-10.0); adagrad_param->add_weight_bounds(10.0); - rule.load_config(param, 10); + rule.LoadConfig(param, 10); // check init_value for zero const int kValueSize = 11; int kEmbSize = 10; float w[kValueSize]; - rule.init_value(w, w + 10, true); + rule.InitValue(w, w + 10, true); for (auto i = 0u; i < kEmbSize; ++i) { ASSERT_FLOAT_EQ(w[i], 0); @@ -93,9 +93,9 @@ TEST(downpour_sparse_adagrad_test, test_init_and_update) { ASSERT_FLOAT_EQ(w[kEmbSize], 0); // check init_value for random - rule.init_value(w, w + 10, false); + rule.InitValue(w, w + 10, false); for (auto i = 0u; i < kEmbSize; ++i) { - ASSERT_TRUE(w[i] >= 
rule.min_bound() && w[i] <= rule.max_bound()); + ASSERT_TRUE(w[i] >= rule.MinBound() && w[i] <= rule.MaxBound()); } ASSERT_FLOAT_EQ(w[kEmbSize], 0); @@ -110,7 +110,7 @@ TEST(downpour_sparse_adagrad_test, test_init_and_update) { } const float* ptr_grad = grad; - rule.update_value(w, w + 10, ptr_grad); + rule.UpdateValue(w, w + 10, ptr_grad); float label[] = {-0.100000, -0.200000, -0.300000, -0.400000, -0.500000, -0.600000, -0.700000, -0.800000, -0.900000, -1.000000, 38.500000}; @@ -140,33 +140,33 @@ TEST(downpour_sparse_adam_test, test_init_and_update) { SparseAdamSGDRule rule; - rule.load_config(param, embed_dim); + rule.LoadConfig(param, embed_dim); // check init_value for zero const int rule_dim = - rule.dim(); // dims of gsum + g2sum + beta1_pow + beta2_pow in adam + rule.Dim(); // dims of gsum + g2sum + beta1_pow + beta2_pow in adam const int value_dim = embed_dim + rule_dim; // total dims of w + rule float* value = new float[value_dim]; - rule.init_value(value, value + embed_dim, true); - for (auto i = 0u; i < rule.beta1_pow_index(); ++i) { + rule.InitValue(value, value + embed_dim, true); + for (auto i = 0u; i < rule.Beta1PowIndex(); ++i) { ASSERT_FLOAT_EQ(value[i], 0); } - ASSERT_FLOAT_EQ(*(value + embed_dim + rule.beta1_pow_index()), 0.9); - ASSERT_FLOAT_EQ(*(value + embed_dim + rule.beta2_pow_index()), 0.999); + ASSERT_FLOAT_EQ(*(value + embed_dim + rule.Beta1PowIndex()), 0.9); + ASSERT_FLOAT_EQ(*(value + embed_dim + rule.Beta2PowIndex()), 0.999); // check init_value for random - rule.init_value(value, value + embed_dim, false); + rule.InitValue(value, value + embed_dim, false); for (auto i = 0u; i < embed_dim; ++i) { - ASSERT_TRUE(value[i] >= rule.min_bound() && value[i] <= rule.max_bound()); + ASSERT_TRUE(value[i] >= rule.MinBound() && value[i] <= rule.MaxBound()); } - for (auto i = rule.gsum_index(); i < rule.beta1_pow_index(); ++i) { + for (auto i = rule.GSumIndex(); i < rule.Beta1PowIndex(); ++i) { ASSERT_FLOAT_EQ(value[i + embed_dim], 0); } - ASSERT_FLOAT_EQ(*(value + embed_dim + rule.beta1_pow_index()), 0.9); - ASSERT_FLOAT_EQ(*(value + embed_dim + rule.beta2_pow_index()), 0.999); + ASSERT_FLOAT_EQ(*(value + embed_dim + rule.Beta1PowIndex()), 0.9); + ASSERT_FLOAT_EQ(*(value + embed_dim + rule.Beta2PowIndex()), 0.999); // check update_value - rule.init_value(value, value + embed_dim, true); + rule.InitValue(value, value + embed_dim, true); float* grad = new float[embed_dim]; for (auto i = 0u; i < embed_dim; ++i) { grad[i] = (i + 1) * 1.0; @@ -181,7 +181,7 @@ TEST(downpour_sparse_adam_test, test_init_and_update) { 0.0249996781, 0.0359995365, 0.0489993691, 0.063999176, 0.0809989572, 0.0999987125, 0.809999943, 0.998001039}; - rule.update_value(value, value + embed_dim, grad); + rule.UpdateValue(value, value + embed_dim, grad); for (auto i = 0u; i < value_dim; ++i) { // check update ASSERT_FLOAT_EQ(value[i], label[i]) << "i is " << i; diff --git a/python/paddle/distributed/fleet/base/fleet_base.py b/python/paddle/distributed/fleet/base/fleet_base.py index ae1a63d72a5cf..4e975e74bdb14 100755 --- a/python/paddle/distributed/fleet/base/fleet_base.py +++ b/python/paddle/distributed/fleet/base/fleet_base.py @@ -1668,7 +1668,7 @@ def _minimize_impl(self, opt_info["mpi_rank"] = self.worker_index() for k, v in self._user_defined_strategy.trainer_desc_configs.items( ): - if v: + if v or k not in opt_info: opt_info[k] = v program._fleet_opt = opt_info @@ -1745,7 +1745,7 @@ def _minimize_losses_impl(self, opt_info["mpi_rank"] = self.worker_index() for k, v in 
self._user_defined_strategy.trainer_desc_configs.items( ): - if v: + if v or k not in opt_info: opt_info[k] = v program._fleet_opt = opt_info # print("fleet base opt info:", id(program), program._fleet_opt) From a5e00bb7239956e10766c3b89d1919416af9c646 Mon Sep 17 00:00:00 2001 From: Zhanlue Yang Date: Sat, 2 Apr 2022 16:54:31 +0800 Subject: [PATCH 21/93] [DoubleGrad PR #6] Fixed issues with TensorWrapper::recover() interface (#41287) --- .../final_state_generator/eager_gen.py | 4 ++-- paddle/fluid/eager/grad_node_info.h | 2 +- paddle/fluid/eager/tensor_wrapper.h | 11 ++++++----- 3 files changed, 9 insertions(+), 8 deletions(-) diff --git a/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py b/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py index fb86c5da6856c..0d1d3ab722522 100644 --- a/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py +++ b/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py @@ -1249,9 +1249,9 @@ def GenerateNodeDefinition(self, grad_node_creation_str): is_optional = (name in self.optional_inputs) if is_optional: - tensor_wrapper_recover_str = f"{indent}auto {transformed_tensor_name} = egr::EagerUtils::RecoverOptionalTensorWrapper(&this->{tensor_wrapper_name}, nullptr);" + tensor_wrapper_recover_str = f"{indent}auto {transformed_tensor_name} = egr::EagerUtils::RecoverOptionalTensorWrapper(&this->{tensor_wrapper_name}, this->shared_from_this());" else: - tensor_wrapper_recover_str = f"{indent}auto {transformed_tensor_name} = egr::EagerUtils::RecoverTensorWrapper(&this->{tensor_wrapper_name}, nullptr);" + tensor_wrapper_recover_str = f"{indent}auto {transformed_tensor_name} = egr::EagerUtils::RecoverTensorWrapper(&this->{tensor_wrapper_name}, this->shared_from_this());" grad_api_args[grad_api_position] = transformed_tensor_name get_grad_in_args_list.append(tensor_wrapper_recover_str) diff --git a/paddle/fluid/eager/grad_node_info.h b/paddle/fluid/eager/grad_node_info.h index 0d07f780dda9d..70fc4afa0ac71 100644 --- a/paddle/fluid/eager/grad_node_info.h +++ b/paddle/fluid/eager/grad_node_info.h @@ -87,7 +87,7 @@ class GradSlotMeta { std::shared_ptr meta_ = nullptr; }; -class GradNodeBase { +class GradNodeBase : public std::enable_shared_from_this { public: GradNodeBase() { VLOG(6) << "Construct GradNodeBase"; } GradNodeBase(size_t bwd_in_slot_num, size_t bwd_out_slot_num); diff --git a/paddle/fluid/eager/tensor_wrapper.h b/paddle/fluid/eager/tensor_wrapper.h index e7886339f06b1..dc4cf379390f1 100644 --- a/paddle/fluid/eager/tensor_wrapper.h +++ b/paddle/fluid/eager/tensor_wrapper.h @@ -95,18 +95,19 @@ class TensorWrapper { } check_inplace_version(); + // if it's full_reserved just return the full copy of tensor - if (full_reserved_) { - return intermidiate_tensor_; - } else { + paddle::experimental::Tensor recovered_tensor = intermidiate_tensor_; + if (!full_reserved_) { std::shared_ptr new_grad_node = grad_node; auto p_ab_autograd_meta = std::make_shared(Edge(new_grad_node, out_rank_info_)); - intermidiate_tensor_.set_autograd_meta( + recovered_tensor.set_autograd_meta( std::static_pointer_cast( p_ab_autograd_meta)); - return intermidiate_tensor_; } + + return recovered_tensor; } void check_inplace_version() { From e59a693ead47ef75756782fdde5f2f96c5088a7e Mon Sep 17 00:00:00 2001 From: Leo Chen Date: Sat, 2 Apr 2022 16:57:14 +0800 Subject: [PATCH 22/93] enable new-executor on windows to test it (#41301) * enable new-executor on windows to test it * add message * fix ut --- 
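Note on the range_kernel.cu change below: `start.data<T>()[0]` dereferences the tensor's buffer on the host, which is only valid while the scalar start/end/step tensors live in CPU memory; once the standalone executor can place them on the GPU, the value has to be copied to the host before it is read, which is what the new GetValue helper does via a blocking phi::Copy. A minimal standalone sketch of the same device-to-host scalar read in plain CUDA follows; ReadScalar and the use of cudaPointerGetAttributes (CUDA 10+ `.type` field) are illustrative assumptions here, not Paddle APIs.

#include <cstdio>
#include <cuda_runtime.h>

// Read one scalar that may live in device memory. Dereferencing a device
// pointer on the host is undefined behaviour, so the value is copied to the
// host with a synchronous memcpy first, mirroring the blocking Copy in
// GetValue below.
template <typename T>
T ReadScalar(const T* ptr) {
  cudaPointerAttributes attr{};
  cudaPointerGetAttributes(&attr, ptr);
  if (attr.type == cudaMemoryTypeDevice) {
    T host_value{};
    cudaMemcpy(&host_value, ptr, sizeof(T), cudaMemcpyDeviceToHost);
    return host_value;
  }
  return *ptr;  // plain host pointer, safe to dereference directly
}

int main() {
  float* device_scalar = nullptr;
  cudaMalloc(&device_scalar, sizeof(float));
  float v = 3.5f;
  cudaMemcpy(device_scalar, &v, sizeof(float), cudaMemcpyHostToDevice);
  std::printf("%f\n", ReadScalar(device_scalar));  // prints 3.500000
  cudaFree(device_scalar);
  return 0;
}

The copy is deliberately synchronous: the kernel launch cannot be sized until the scalar is known, so there is nothing to overlap it with.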
paddle/phi/kernels/gpu/range_kernel.cu | 19 ++++++++++++++++--- python/paddle/fluid/executor.py | 13 ++++++++++++- .../tests/unittests/check_nan_inf_base.py | 15 ++++++++------- .../fluid/tests/unittests/test_nan_inf.py | 8 +++++--- 4 files changed, 41 insertions(+), 14 deletions(-) diff --git a/paddle/phi/kernels/gpu/range_kernel.cu b/paddle/phi/kernels/gpu/range_kernel.cu index 65d9b45efbcdd..d9a98f06d0795 100644 --- a/paddle/phi/kernels/gpu/range_kernel.cu +++ b/paddle/phi/kernels/gpu/range_kernel.cu @@ -21,6 +21,19 @@ namespace phi { +template <typename T, typename Context> +inline T GetValue(const Context& dev_ctx, const DenseTensor& x) { + T value = static_cast<T>(0); + if (x.place() != CPUPlace()) { + DenseTensor cpu_x; + Copy(dev_ctx, x, CPUPlace(), true, &cpu_x); + value = cpu_x.data<T>()[0]; + } else { + value = x.data<T>()[0]; + } + return value; +} + template <typename T> __global__ void Range(T start, T step, int64_t size, T* out) { CUDA_KERNEL_LOOP(index, size) { out[index] = start + step * index; } @@ -32,9 +45,9 @@ void RangeKernel(const Context& dev_ctx, const DenseTensor& end, const DenseTensor& step, DenseTensor* out) { - T start_value = start.data<T>()[0]; - T end_value = end.data<T>()[0]; - T step_value = step.data<T>()[0]; + T start_value = GetValue<T, Context>(dev_ctx, start); + T end_value = GetValue<T, Context>(dev_ctx, end); + T step_value = GetValue<T, Context>(dev_ctx, step); int64_t size = 0; phi::funcs::GetSize(start_value, end_value, step_value, &size); diff --git a/python/paddle/fluid/executor.py b/python/paddle/fluid/executor.py index a7971763f53e1..eb833428afa42 100644 --- a/python/paddle/fluid/executor.py +++ b/python/paddle/fluid/executor.py @@ -394,9 +394,20 @@ def _is_enable_standalone_executor(): Whether to use experimental executor `StandaloneExecutor`. """ flag = False - env_val = os.environ.get('FLAGS_USE_STANDALONE_EXECUTOR', None) + # NOTE(zhiqiu): enable STANDALONE_EXECUTOR on windows platform by default + # It should be enabled on all platforms in the future. + + import platform + sysstr = platform.system().lower() + if sysstr == 'windows': + env_val = os.environ.get('FLAGS_USE_STANDALONE_EXECUTOR', 1) + else: + env_val = os.environ.get('FLAGS_USE_STANDALONE_EXECUTOR', None) + if env_val in [1, '1', True, 'True', 'true']: flag = True + warnings.warn("STANDALONE_EXECUTOR is enabled.") + return flag diff --git a/python/paddle/fluid/tests/unittests/check_nan_inf_base.py b/python/paddle/fluid/tests/unittests/check_nan_inf_base.py index 1c5db616306ca..13a7ff6860e4d 100644 --- a/python/paddle/fluid/tests/unittests/check_nan_inf_base.py +++ b/python/paddle/fluid/tests/unittests/check_nan_inf_base.py @@ -103,6 +103,14 @@ def check(use_cuda): if __name__ == '__main__': + try: + check(use_cuda=False) + assert False + except Exception as e: + print(e) + print(type(e)) + assert type(e) == RuntimeError + if core.is_compiled_with_cuda(): try: check(use_cuda=True) @@ -113,10 +121,3 @@ def check(use_cuda): # Note. 
Enforce in cuda kernel may not catch in paddle, and # Exception type will be RuntimeError assert type(e) == OSError or type(e) == RuntimeError - try: - check(use_cuda=False) - assert False - except Exception as e: - print(e) - print(type(e)) - assert type(e) == RuntimeError diff --git a/python/paddle/fluid/tests/unittests/test_nan_inf.py b/python/paddle/fluid/tests/unittests/test_nan_inf.py index cb7e673c6ca29..84559048a2b8a 100644 --- a/python/paddle/fluid/tests/unittests/test_nan_inf.py +++ b/python/paddle/fluid/tests/unittests/test_nan_inf.py @@ -47,10 +47,12 @@ def check_nan_inf(self): print(out) print(err) - assert returncode == 0 # in python3, type(out+err) is 'bytes', need use encode - assert (out + err - ).find('There are `nan` or `inf` in tensor'.encode()) != -1 + if paddle.fluid.core.is_compiled_with_cuda(): + assert (out + err).find('find nan or inf==='.encode()) != -1 + else: + assert (out + err + ).find('There are `nan` or `inf` in tensor'.encode()) != -1 def test_nan_inf_in_static_mode(self): self._python_interp += " check_nan_inf_base.py" From c06580451271133d185e3f32dbaf8101f5a00333 Mon Sep 17 00:00:00 2001 From: wuyefeilin <30919197+wuyefeilin@users.noreply.github.com> Date: Sat, 2 Apr 2022 17:09:50 +0800 Subject: [PATCH 23/93] [phi] Move clip op to phi (#40602) * move clip op to phi * fix as review * update hierarchical_sigmoid_kernel.cc * update selected_rows * update clip_kernel.cu * fix as review --- paddle/fluid/operators/clip_op.cc | 53 ++--- paddle/fluid/operators/clip_op.cu | 32 --- paddle/fluid/operators/clip_op.h | 196 ------------------ paddle/fluid/operators/clip_op_npu.cc | 30 +-- paddle/fluid/operators/clip_op_xpu.cc | 27 ++- paddle/fluid/operators/fake_quantize_op.cc | 14 +- paddle/phi/kernels/clip_grad_kernel.h | 31 +++ paddle/phi/kernels/clip_kernel.h | 31 +++ paddle/phi/kernels/cpu/clip_grad_kernel.cc | 27 +++ paddle/phi/kernels/cpu/clip_kernel.cc | 21 ++ .../cpu/hierarchical_sigmoid_kernel.cc | 5 +- paddle/phi/kernels/gpu/clip_grad_kernel.cu | 28 +++ paddle/phi/kernels/gpu/clip_kernel.cu | 30 +++ .../phi/kernels/impl/clip_grad_kernel_impl.h | 74 +++++++ paddle/phi/kernels/impl/clip_kernel_impl.h | 79 +++++++ .../phi/kernels/selected_rows/clip_kernel.h | 34 +++ .../kernels/selected_rows/cpu/clip_kernel.cc | 28 +++ .../kernels/selected_rows/gpu/clip_kernel.cu | 30 +++ .../selected_rows/impl/clip_kernel_impl.h | 62 ++++++ paddle/phi/ops/compat/clip_sig.cc | 88 ++++++++ 20 files changed, 619 insertions(+), 301 deletions(-) delete mode 100644 paddle/fluid/operators/clip_op.cu delete mode 100644 paddle/fluid/operators/clip_op.h create mode 100644 paddle/phi/kernels/clip_grad_kernel.h create mode 100644 paddle/phi/kernels/clip_kernel.h create mode 100644 paddle/phi/kernels/cpu/clip_grad_kernel.cc create mode 100644 paddle/phi/kernels/cpu/clip_kernel.cc create mode 100644 paddle/phi/kernels/gpu/clip_grad_kernel.cu create mode 100644 paddle/phi/kernels/gpu/clip_kernel.cu create mode 100644 paddle/phi/kernels/impl/clip_grad_kernel_impl.h create mode 100644 paddle/phi/kernels/impl/clip_kernel_impl.h create mode 100644 paddle/phi/kernels/selected_rows/clip_kernel.h create mode 100644 paddle/phi/kernels/selected_rows/cpu/clip_kernel.cc create mode 100644 paddle/phi/kernels/selected_rows/gpu/clip_kernel.cu create mode 100644 paddle/phi/kernels/selected_rows/impl/clip_kernel_impl.h create mode 100644 paddle/phi/ops/compat/clip_sig.cc diff --git a/paddle/fluid/operators/clip_op.cc b/paddle/fluid/operators/clip_op.cc index 436d1edcedf1e..6e898d31663fa 100644 --- 
a/paddle/fluid/operators/clip_op.cc +++ b/paddle/fluid/operators/clip_op.cc @@ -1,21 +1,23 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/operators/clip_op.h" #include +#include "paddle/fluid/framework/infershape_utils.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_version_registry.h" +#include "paddle/phi/core/infermeta_utils.h" +#include "paddle/phi/infermeta/unary.h" namespace paddle { namespace operators { @@ -23,15 +25,6 @@ namespace operators { class ClipOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - - void InferShape(framework::InferShapeContext* ctx) const override { - OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "clip"); - OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "clip"); - auto x_dims = ctx->GetInputDim("X"); - ctx->SetOutputDim("Out", x_dims); - ctx->ShareLoD("X", /*->*/ "Out"); - } - framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto input_data_type = @@ -176,23 +169,15 @@ class ClipDoubleGradOpMaker : public framework::SingleGradOpMaker { } // namespace paddle namespace ops = paddle::operators; +DECLARE_INFER_SHAPE_FUNCTOR(clip, ClipInferShapeFunctor, + PD_INFER_META(phi::UnchangedInferMeta)); REGISTER_OPERATOR(clip, ops::ClipOp, ops::ClipOpMaker, ops::ClipGradOpMaker, ops::ClipGradOpMaker, - ops::ClipInplaceInferer); + ops::ClipInplaceInferer, ClipInferShapeFunctor); REGISTER_OPERATOR(clip_grad, ops::ClipOpGrad, ops::ClipGradInplaceInferer, ops::ClipDoubleGradOpMaker, ops::ClipDoubleGradOpMaker); -REGISTER_OP_CPU_KERNEL( - clip, ops::ClipKernel, - ops::ClipKernel, - ops::ClipKernel, - ops::ClipKernel); -REGISTER_OP_CPU_KERNEL( - clip_grad, ops::ClipGradKernel, - ops::ClipGradKernel, - ops::ClipGradKernel, - ops::ClipGradKernel); REGISTER_OP_VERSION(clip) .AddCheckpoint( diff --git a/paddle/fluid/operators/clip_op.cu b/paddle/fluid/operators/clip_op.cu deleted file mode 100644 index 846354fcb81c5..0000000000000 --- a/paddle/fluid/operators/clip_op.cu +++ /dev/null @@ -1,32 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/operators/clip_op.h" - -namespace ops = paddle::operators; -REGISTER_OP_CUDA_KERNEL( - clip, ops::ClipKernel, - ops::ClipKernel, - ops::ClipKernel, - ops::ClipKernel, - ops::ClipKernel); - -REGISTER_OP_CUDA_KERNEL( - clip_grad, ops::ClipGradKernel, - ops::ClipGradKernel, - ops::ClipGradKernel, - ops::ClipGradKernel, - ops::ClipGradKernel); diff --git a/paddle/fluid/operators/clip_op.h b/paddle/fluid/operators/clip_op.h deleted file mode 100644 index 3b815cd1fa74a..0000000000000 --- a/paddle/fluid/operators/clip_op.h +++ /dev/null @@ -1,196 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once - -#include "paddle/fluid/framework/eigen.h" -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/math/selected_rows_functor.h" -#include "paddle/fluid/platform/transform.h" -#if defined(__NVCC__) || defined(__HIPCC__) -#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h" -#endif - -namespace paddle { -namespace operators { - -using framework::Tensor; -using platform::Transform; - -template -class ClipFunctor { - public: - explicit ClipFunctor(const T min, const T max) : min_(min), max_(max) {} - HOSTDEVICE T operator()(const T x) const { - return x < min_ ? min_ : x > max_ ? max_ : x; - } - - private: - T min_; - T max_; -}; - -template -class ClipGradFunctor { - public: - explicit ClipGradFunctor(const T min, const T max) : min_(min), max_(max) {} - HOSTDEVICE T operator()(const T x, const T y) const { - return (y > min_ && y < max_) ? 
x : static_cast(0); - } - - private: - T min_; - T max_; -}; - -template -class ClipKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - auto max = static_cast(context.Attr("max")); - Tensor max_cpu; - if (context.HasInput("Max")) { - auto* max_t = context.Input("Max"); - auto* max_data = max_t->data(); - if (platform::is_gpu_place(max_t->place())) { - paddle::framework::TensorCopySync(*max_t, platform::CPUPlace(), - &max_cpu); - max_data = max_cpu.data(); - } - max = max_data[0]; - } - max = static_cast(max); - - auto min = static_cast(context.Attr("min")); - Tensor min_cpu; - if (context.HasInput("Min")) { - auto* min_t = context.Input("Min"); - auto* min_data = min_t->data(); - if (platform::is_gpu_place(min_t->place())) { - paddle::framework::TensorCopySync(*min_t, platform::CPUPlace(), - &min_cpu); - min_data = min_cpu.data(); - } - min = min_data[0]; - } - - PADDLE_ENFORCE_LE(min, max, - platform::errors::InvalidArgument( - "max should be greater than or equal to min. " - "But received min = %f, max = %f", - static_cast(min), static_cast(max))); - - auto* x_var = context.InputVar("X"); - if (x_var->IsType()) { - auto* x = context.Input("X"); - auto* out = context.Output("Out"); - T* out_data = out->mutable_data(context.GetPlace()); - const T* x_data = x->data(); - int64_t numel = x->numel(); - if (platform::is_gpu_place(context.GetPlace())) { -#if defined(__NVCC__) || defined(__HIPCC__) - std::vector ins = {x}; - std::vector outs = {out}; - auto functor = ClipFunctor(min, max); - paddle::operators::LaunchSameDimsElementwiseCudaKernel( - context.template device_context(), ins, - &outs, functor); -#endif - } else { - Transform trans; - trans(context.template device_context(), x_data, - x_data + numel, out_data, ClipFunctor(min, max)); - } - } else if (x_var->IsType()) { - auto* x = context.Input("X"); - auto* out = context.Output("Out"); - PADDLE_ENFORCE_NE(x, out, platform::errors::InvalidArgument( - "Inplace clip is not allowed " - "when x is SelectedRows")); - math::scatter::MergeAdd merge_func; - merge_func(context.template device_context(), *x, out); - auto* out_tensor = out->mutable_value(); - auto* out_data = out_tensor->data(); - int64_t numel = out_tensor->numel(); - Transform trans; - trans(context.template device_context(), out_data, - out_data + numel, out_data, ClipFunctor(min, max)); - } else { - PADDLE_THROW(platform::errors::Unavailable( - "ClipOp only supports LoDTensor and SelectedRows.")); - } - } -}; - -template -class ClipGradKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - auto max = static_cast(context.Attr("max")); - Tensor max_cpu; - if (context.HasInput("Max")) { - auto* max_t = context.Input("Max"); - auto* max_data = max_t->data(); - if (platform::is_gpu_place(max_t->place())) { - paddle::framework::TensorCopySync(*max_t, platform::CPUPlace(), - &max_cpu); - max_data = max_cpu.data(); - } - max = max_data[0]; - } - max = static_cast(max); - - auto min = static_cast(context.Attr("min")); - Tensor min_cpu; - if (context.HasInput("Min")) { - auto* min_t = context.Input("Min"); - auto* min_data = min_t->data(); - if (platform::is_gpu_place(min_t->place())) { - paddle::framework::TensorCopySync(*min_t, platform::CPUPlace(), - &min_cpu); - min_data = min_cpu.data(); - } - min = min_data[0]; - } - min = static_cast(min); - - auto* d_out = - context.Input(framework::GradVarName("Out")); - auto* d_x = - 
context.Output(framework::GradVarName("X")); - if (d_x != nullptr) { - auto* x = context.Input("X"); -#if defined(__NVCC__) || defined(__HIPCC__) - std::vector ins = {d_out, x}; - std::vector outs = {d_x}; - auto functor = ClipGradFunctor(min, max); - d_x->mutable_data(context.GetPlace()); - LaunchSameDimsElementwiseCudaKernel( - context.template device_context(), ins, - &outs, functor); -#else - int64_t numel = d_out->numel(); - auto* d_x_data = d_x->mutable_data(context.GetPlace()); - const T* d_out_data = d_out->data(); - const T* x_data = x->data(); - Transform trans; - trans(context.template device_context(), d_out_data, - d_out_data + numel, x_data, d_x_data, ClipGradFunctor(min, max)); -#endif - } - } -}; - -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/clip_op_npu.cc b/paddle/fluid/operators/clip_op_npu.cc index 372ba707329bb..17d7ad9796504 100644 --- a/paddle/fluid/operators/clip_op_npu.cc +++ b/paddle/fluid/operators/clip_op_npu.cc @@ -1,18 +1,18 @@ -/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/operators/clip_op.h" +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/platform/device/npu/npu_op_runner.h" namespace paddle { diff --git a/paddle/fluid/operators/clip_op_xpu.cc b/paddle/fluid/operators/clip_op_xpu.cc index c53bb2d9e4d0c..c551312837274 100644 --- a/paddle/fluid/operators/clip_op_xpu.cc +++ b/paddle/fluid/operators/clip_op_xpu.cc @@ -1,20 +1,19 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #ifdef PADDLE_WITH_XPU -#include "paddle/fluid/operators/clip_op.h" #include "paddle/fluid/framework/op_registry.h" namespace paddle { diff --git a/paddle/fluid/operators/fake_quantize_op.cc b/paddle/fluid/operators/fake_quantize_op.cc index 4544386718813..ac72f23d46ea8 100644 --- a/paddle/fluid/operators/fake_quantize_op.cc +++ b/paddle/fluid/operators/fake_quantize_op.cc @@ -17,8 +17,8 @@ limitations under the License. */ #include #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_version_registry.h" -#include "paddle/fluid/operators/clip_op.h" #include "paddle/fluid/platform/transform.h" +#include "paddle/phi/kernels/impl/clip_kernel_impl.h" namespace paddle { namespace operators { @@ -91,7 +91,7 @@ struct ClipAndFakeQuantFunctor { T inv_s = inverse(s); platform::Transform trans; trans(ctx, in.data(), in.data() + in.numel(), - out->mutable_data(ctx.GetPlace()), ClipFunctor(-s, s)); + out->mutable_data(ctx.GetPlace()), phi::ClipFunctor(-s, s)); auto out_e = framework::EigenVector::Flatten(*out); out_e.device(*ctx.eigen_device()) = (bin_cnt * inv_s * out_e).round(); } @@ -109,7 +109,7 @@ struct ClipAndFakeQuantDequantFunctor { platform::Transform trans; trans(ctx, in.data(), in.data() + in.numel(), - out->mutable_data(ctx.GetPlace()), ClipFunctor(-s, s)); + out->mutable_data(ctx.GetPlace()), phi::ClipFunctor(-s, s)); auto out_e = framework::EigenVector::Flatten(*out); out_e.device(*ctx.eigen_device()) = (bin_cnt * inv_s * out_e).round() * s / static_cast(bin_cnt); @@ -144,7 +144,7 @@ struct ChannelClipAndFakeQuantFunctor { auto* start = in_data + i * channel_size; auto* end = in_data + (i + 1) * channel_size; trans(ctx, start, end, out_data + i * channel_size, - ClipFunctor(-s, s)); + phi::ClipFunctor(-s, s)); } for (int64_t i = 0; i < channel; i++) { T s = scale_data[i]; @@ -163,7 +163,7 @@ struct ChannelClipAndFakeQuantFunctor { auto* start = in_data + i * step_i + j * step_j; auto* end = in_data + i * step_i + (j + 1) * step_j; auto* cur_out_data = out_data + i * step_i + j * step_j; - trans(ctx, start, end, cur_out_data, ClipFunctor(-s, s)); + trans(ctx, start, end, cur_out_data, phi::ClipFunctor(-s, s)); for (int k = 0; k < step_j; k++) { cur_out_data[k] = std::round(bin_cnt * inv_s * cur_out_data[k]); } @@ -200,7 +200,7 @@ struct ChannelClipFakeQuantDequantFunctor { auto* start = in_data + i * channel_size; auto* end = in_data + (i + 1) * channel_size; trans(ctx, start, end, out_data + i * channel_size, - ClipFunctor(-s, s)); + phi::ClipFunctor(-s, s)); } for (int i = 0; i < channel; i++) { T s = scale_data[i]; @@ -220,7 +220,7 @@ struct ChannelClipFakeQuantDequantFunctor { auto* start = in_data + i * step_i + j * step_j; auto* end = in_data + i * step_i + (j + 1) * step_j; auto* cur_out_data = out_data + i * step_i + j * step_j; - trans(ctx, start, end, cur_out_data, ClipFunctor(-s, s)); + trans(ctx, start, end, cur_out_data, phi::ClipFunctor(-s, s)); for (int k = 0; k < step_j; k++) { cur_out_data[k] = std::round(bin_cnt * inv_s * cur_out_data[k]) * s / static_cast(bin_cnt); diff --git a/paddle/phi/kernels/clip_grad_kernel.h 
b/paddle/phi/kernels/clip_grad_kernel.h new file mode 100644 index 0000000000000..8a7e5b99fd924 --- /dev/null +++ b/paddle/phi/kernels/clip_grad_kernel.h @@ -0,0 +1,31 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/common/scalar.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/device_context.h" + +namespace phi { + +template +void ClipGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& out_grad, + const Scalar& min, + const Scalar& max, + DenseTensor* x_grad); + +} // namespace phi diff --git a/paddle/phi/kernels/clip_kernel.h b/paddle/phi/kernels/clip_kernel.h new file mode 100644 index 0000000000000..14ac8342e03bc --- /dev/null +++ b/paddle/phi/kernels/clip_kernel.h @@ -0,0 +1,31 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/common/scalar.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/device_context.h" +#include "paddle/phi/core/selected_rows.h" + +namespace phi { + +template +void ClipKernel(const Context& dev_ctx, + const DenseTensor& x, + const Scalar& min, + const Scalar& max, + DenseTensor* out); + +} // namespace phi diff --git a/paddle/phi/kernels/cpu/clip_grad_kernel.cc b/paddle/phi/kernels/cpu/clip_grad_kernel.cc new file mode 100644 index 0000000000000..bccdc0746d51c --- /dev/null +++ b/paddle/phi/kernels/cpu/clip_grad_kernel.cc @@ -0,0 +1,27 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/phi/kernels/clip_grad_kernel.h" +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/clip_grad_kernel_impl.h" + +PD_REGISTER_KERNEL(clip_grad, + CPU, + ALL_LAYOUT, + phi::ClipGradKernel, + float, + double, + int, + int64_t) {} diff --git a/paddle/phi/kernels/cpu/clip_kernel.cc b/paddle/phi/kernels/cpu/clip_kernel.cc new file mode 100644 index 0000000000000..5fd9aea966f8d --- /dev/null +++ b/paddle/phi/kernels/cpu/clip_kernel.cc @@ -0,0 +1,21 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/clip_kernel.h" +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/clip_kernel_impl.h" + +PD_REGISTER_KERNEL( + clip, CPU, ALL_LAYOUT, phi::ClipKernel, float, double, int, int64_t) {} diff --git a/paddle/phi/kernels/cpu/hierarchical_sigmoid_kernel.cc b/paddle/phi/kernels/cpu/hierarchical_sigmoid_kernel.cc index 096a54f9fb263..4c4f1aa125a33 100644 --- a/paddle/phi/kernels/cpu/hierarchical_sigmoid_kernel.cc +++ b/paddle/phi/kernels/cpu/hierarchical_sigmoid_kernel.cc @@ -14,7 +14,6 @@ #include "paddle/phi/kernels/hierarchical_sigmoid_kernel.h" -#include "paddle/fluid/operators/clip_op.h" #include "paddle/fluid/operators/math/matrix_bit_code.h" #include "paddle/fluid/platform/transform.h" #include "paddle/phi/backends/cpu/cpu_context.h" @@ -22,6 +21,7 @@ #include "paddle/phi/kernels/funcs/eigen/common.h" #include "paddle/phi/kernels/funcs/eigen/eigen_function.h" #include "paddle/phi/kernels/funcs/math_function_impl.h" +#include "paddle/phi/kernels/impl/clip_kernel_impl.h" namespace phi { @@ -92,8 +92,7 @@ void HierarchicalSigmoidKernel(const Context& ctx, pre_out_data, pre_out_data + pre_out->numel(), pre_out_data, - paddle::operators::ClipFunctor(static_cast(-40.0), - static_cast(40.0))); + ClipFunctor(static_cast(-40.0), static_cast(40.0))); bit_code->Sum(*pre_out, out, static_cast(-1)); // use softrelu to calculate cross entropy pre_out_mat.device(place) = (static_cast(1.0) + pre_out_mat.exp()).log(); diff --git a/paddle/phi/kernels/gpu/clip_grad_kernel.cu b/paddle/phi/kernels/gpu/clip_grad_kernel.cu new file mode 100644 index 0000000000000..b76086be64887 --- /dev/null +++ b/paddle/phi/kernels/gpu/clip_grad_kernel.cu @@ -0,0 +1,28 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/clip_grad_kernel.h" + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/clip_grad_kernel_impl.h" + +PD_REGISTER_KERNEL(clip_grad, + GPU, + ALL_LAYOUT, + phi::ClipGradKernel, + float, + double, + int, + int64_t) {} diff --git a/paddle/phi/kernels/gpu/clip_kernel.cu b/paddle/phi/kernels/gpu/clip_kernel.cu new file mode 100644 index 0000000000000..9e0050db7fdbf --- /dev/null +++ b/paddle/phi/kernels/gpu/clip_kernel.cu @@ -0,0 +1,30 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/clip_kernel.h" + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/common/float16.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/clip_kernel_impl.h" + +PD_REGISTER_KERNEL(clip, + GPU, + ALL_LAYOUT, + phi::ClipKernel, + float, + double, + int, + int64_t, + phi::dtype::float16) {} diff --git a/paddle/phi/kernels/impl/clip_grad_kernel_impl.h b/paddle/phi/kernels/impl/clip_grad_kernel_impl.h new file mode 100644 index 0000000000000..7ce86492327ba --- /dev/null +++ b/paddle/phi/kernels/impl/clip_grad_kernel_impl.h @@ -0,0 +1,74 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/kernels/clip_kernel.h" + +#include "paddle/phi/backends/all_context.h" +#include "paddle/phi/core/kernel_registry.h" + +#include "paddle/fluid/platform/device_context.h" +#include "paddle/fluid/platform/transform.h" +#if defined(__NVCC__) || defined(__HIPCC__) +#include "paddle/phi/kernels/funcs/broadcast_function.h" +#endif + +namespace phi { + +template +class ClipGradFunctor { + public: + explicit ClipGradFunctor(const T min, const T max) : min_(min), max_(max) {} + HOSTDEVICE T operator()(const T x, const T y) const { + return (y > min_ && y < max_) ? 
+
+template <typename T, typename Context>
+void ClipGradKernel(const Context& dev_ctx,
+                    const DenseTensor& x,
+                    const DenseTensor& out_grad,
+                    const Scalar& min,
+                    const Scalar& max,
+                    DenseTensor* x_grad) {
+  auto max_ = max.to<T>();
+  auto min_ = min.to<T>();
+
+#if defined(__NVCC__) || defined(__HIPCC__)
+  std::vector<const DenseTensor*> ins = {&out_grad, &x};
+  std::vector<DenseTensor*> outs = {x_grad};
+  auto functor = ClipGradFunctor<T>(min_, max_);
+  dev_ctx.template Alloc<T>(x_grad);
+  phi::funcs::ElementwiseKernel<T>(dev_ctx, ins, &outs, functor);
+#else
+  int64_t numel = out_grad.numel();
+  auto* d_x_data = dev_ctx.template Alloc<T>(x_grad);
+  const T* d_out_data = out_grad.data<T>();
+  const T* x_data = x.data<T>();
+  paddle::platform::Transform<Context> trans;
+  trans(dev_ctx,
+        d_out_data,
+        d_out_data + numel,
+        x_data,
+        d_x_data,
+        ClipGradFunctor<T>(min_, max_));
+#endif
+}
+
+}  // namespace phi
diff --git a/paddle/phi/kernels/impl/clip_kernel_impl.h b/paddle/phi/kernels/impl/clip_kernel_impl.h
new file mode 100644
index 0000000000000..17c04c31a598a
--- /dev/null
+++ b/paddle/phi/kernels/impl/clip_kernel_impl.h
@@ -0,0 +1,79 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/kernels/clip_kernel.h"
+
+#include "paddle/phi/backends/all_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+
+#include "paddle/fluid/operators/math/selected_rows_functor.h"
+#include "paddle/fluid/platform/transform.h"
+#if defined(__NVCC__) || defined(__HIPCC__)
+#include "paddle/phi/kernels/funcs/broadcast_function.h"
+#endif
+
+namespace phi {
+
+template <typename T>
+class ClipFunctor {
+ public:
+  explicit ClipFunctor(const T min, const T max) : min_(min), max_(max) {}
+  HOSTDEVICE T operator()(const T x) const {
+    return x < min_ ? min_ : x > max_ ? max_ : x;
+  }
+
+ private:
+  T min_;
+  T max_;
+};
+
+template <typename T, typename Context>
+void ClipKernel(const Context& dev_ctx,
+                const DenseTensor& x,
+                const Scalar& min,
+                const Scalar& max,
+                DenseTensor* out) {
+  auto max_ = max.to<T>();
+  auto min_ = min.to<T>();
+
+  PADDLE_ENFORCE_LE(
+      min_,
+      max_,
+      errors::InvalidArgument("max should be greater than or equal to min. "
+                              "But received min = %f, max = %f",
+                              static_cast<float>(min_),
+                              static_cast<float>(max_)));
+
+  T* out_data = dev_ctx.template Alloc<T>(out);
+  // const T* x_data = x->data<T>();
+  // int64_t numel = x->numel();
+  const T* x_data = x.data<T>();
+  int64_t numel = x.numel();
+  if (paddle::platform::is_gpu_place(dev_ctx.GetPlace())) {
+#if defined(__NVCC__) || defined(__HIPCC__)
+    std::vector<const DenseTensor*> ins = {&x};
+    std::vector<DenseTensor*> outs = {out};
+    auto functor = ClipFunctor<T>(min_, max_);
+    phi::funcs::ElementwiseKernel<T>(dev_ctx, ins, &outs, functor);
+#endif
+  } else {
+    paddle::platform::Transform<Context> trans;
+    trans(
+        dev_ctx, x_data, x_data + numel, out_data, ClipFunctor<T>(min_, max_));
+  }
+}
+
+}  // namespace phi
diff --git a/paddle/phi/kernels/selected_rows/clip_kernel.h b/paddle/phi/kernels/selected_rows/clip_kernel.h
new file mode 100644
index 0000000000000..ec56d92c513ea
--- /dev/null
+++ b/paddle/phi/kernels/selected_rows/clip_kernel.h
@@ -0,0 +1,34 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/fluid/operators/math/selected_rows_functor.h"
+#include "paddle/phi/common/scalar.h"
+#include "paddle/phi/core/dense_tensor.h"
+#include "paddle/phi/core/device_context.h"
+#include "paddle/phi/core/selected_rows.h"
+#include "paddle/phi/kernels/impl/clip_kernel_impl.h"
+
+namespace phi {
+namespace sr {
+
+template <typename T, typename Context>
+void ClipSparseKernel(const Context& dev_ctx,
+                      const SelectedRows& x,
+                      const Scalar& min,
+                      const Scalar& max,
+                      SelectedRows* out);
+}  // namespace sr
+}  // namespace phi
+ +#include "paddle/phi/kernels/selected_rows/clip_kernel.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/selected_rows/impl/clip_kernel_impl.h" + +PD_REGISTER_KERNEL(clip_sr, + CPU, + ALL_LAYOUT, + phi::sr::ClipSparseKernel, + float, + double, + int, + int64_t) {} diff --git a/paddle/phi/kernels/selected_rows/gpu/clip_kernel.cu b/paddle/phi/kernels/selected_rows/gpu/clip_kernel.cu new file mode 100644 index 0000000000000..a8d659559e19e --- /dev/null +++ b/paddle/phi/kernels/selected_rows/gpu/clip_kernel.cu @@ -0,0 +1,30 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/selected_rows/clip_kernel.h" + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/common/float16.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/selected_rows/impl/clip_kernel_impl.h" + +PD_REGISTER_KERNEL(clip_sr, + GPU, + ALL_LAYOUT, + phi::sr::ClipSparseKernel, + float, + double, + int, + int64_t, + phi::dtype::float16) {} diff --git a/paddle/phi/kernels/selected_rows/impl/clip_kernel_impl.h b/paddle/phi/kernels/selected_rows/impl/clip_kernel_impl.h new file mode 100644 index 0000000000000..1d95e633b93a6 --- /dev/null +++ b/paddle/phi/kernels/selected_rows/impl/clip_kernel_impl.h @@ -0,0 +1,62 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/kernels/selected_rows/clip_kernel.h" + +#include "paddle/fluid/operators/math/selected_rows_functor.h" +#include "paddle/phi/common/scalar.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/device_context.h" +#include "paddle/phi/core/selected_rows.h" + +namespace phi { +namespace sr { + +template +void ClipSparseKernel(const Context& dev_ctx, + const SelectedRows& x, + const Scalar& min, + const Scalar& max, + SelectedRows* out) { + auto max_ = max.to(); + auto min_ = min.to(); + + PADDLE_ENFORCE_LE( + min_, + max_, + errors::InvalidArgument("max should be greater than or equal to min. 
" + "But received min = %f, max = %f", + static_cast(min_), + static_cast(max_))); + + PADDLE_ENFORCE_NE(&x, + out, + errors::InvalidArgument("Inplace clip is not allowed " + "when x is SelectedRows")); + paddle::operators::math::scatter::MergeAdd merge_func; + merge_func(dev_ctx, x, out); + auto* out_tensor = out->mutable_value(); + auto* out_data = out_tensor->data(); + int64_t numel = out_tensor->numel(); + paddle::platform::Transform trans; + trans(dev_ctx, + out_data, + out_data + numel, + out_data, + ClipFunctor(min_, max_)); +} +} // namespace sr +} // namespace phi diff --git a/paddle/phi/ops/compat/clip_sig.cc b/paddle/phi/ops/compat/clip_sig.cc new file mode 100644 index 0000000000000..78fa6c36a5149 --- /dev/null +++ b/paddle/phi/ops/compat/clip_sig.cc @@ -0,0 +1,88 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/core/compat/op_utils.h" +#include "paddle/utils/small_vector.h" + +namespace phi { + +KernelSignature ClipOpArgumentMapping(const ArgumentMappingContext& ctx) { + paddle::SmallVector attr_names; + attr_names.emplace_back(ctx.HasInput("Min") ? "Min" : "min"); + attr_names.emplace_back(ctx.HasInput("Max") ? "Max" : "max"); + if (ctx.IsDenseTensorInput("X")) { + if (ctx.HasInput("Min")) { + if (ctx.HasInput("Max")) { + return KernelSignature("clip", {"X"}, {"Min", "Max"}, {"Out"}); + } else { + return KernelSignature("clip", {"X"}, {"Min", "max"}, {"Out"}); + } + } else { + if (ctx.HasInput("Max")) { + return KernelSignature("clip", {"X"}, {"min", "Max"}, {"Out"}); + } else { + return KernelSignature("clip", {"X"}, {"min", "max"}, {"Out"}); + } + } + } else if (ctx.IsSelectedRowsInput("X")) { + if (ctx.HasInput("Min")) { + if (ctx.HasInput("Max")) { + return KernelSignature("clip_sr", {"X"}, {"Min", "Max"}, {"Out"}); + } else { + return KernelSignature("clip_sr", {"X"}, {"Min", "max"}, {"Out"}); + } + } else { + if (ctx.HasInput("Max")) { + return KernelSignature("clip_sr", {"X"}, {"min", "Max"}, {"Out"}); + } else { + return KernelSignature("clip_sr", {"X"}, {"min", "max"}, {"Out"}); + } + } + } + + return KernelSignature("unregistered", {}, {}, {}); +} + +KernelSignature ClipGradOpArgumentMapping(const ArgumentMappingContext& ctx) { + if (ctx.HasInput("Min")) { + if (ctx.HasInput("Max")) { + return KernelSignature("clip_grad", + {"X", GradVarName("Out")}, + {"Min", "Max"}, + {GradVarName("X")}); + } else { + return KernelSignature("clip_grad", + {"X", GradVarName("Out")}, + {"Min", "max"}, + {GradVarName("X")}); + } + } else { + if (ctx.HasInput("Max")) { + return KernelSignature("clip_grad", + {"X", GradVarName("Out")}, + {"min", "Max"}, + {GradVarName("X")}); + } else { + return KernelSignature("clip_grad", + {"X", GradVarName("Out")}, + {"min", "max"}, + {GradVarName("X")}); + } + } +} + +} // namespace phi + +PD_REGISTER_ARG_MAPPING_FN(clip, phi::ClipOpArgumentMapping); +PD_REGISTER_ARG_MAPPING_FN(clip_grad, phi::ClipGradOpArgumentMapping); From 
From: pangyoki
Date: Sat, 2 Apr 2022 17:19:33 +0800
Subject: [PATCH 24/93] fix test_tunable_variable (#41268)

---
 .../tests/unittests/auto_parallel/test_tunable_variable.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_tunable_variable.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_tunable_variable.py
index c36fca7a9d09a..ade228f6c494b 100644
--- a/python/paddle/fluid/tests/unittests/auto_parallel/test_tunable_variable.py
+++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_tunable_variable.py
@@ -76,7 +76,7 @@ def test_float_range(self):
             "float_range", start=0.4, stop=4.4, default=2.0)
         float_range = tv.FloatRange.from_state(float_range.get_state())
         self.assertEqual(float_range.default, 2.0)
-        self.assertGreater(float_range.random(), 0.4)
+        self.assertGreaterEqual(float_range.random(), 0.4)
         self.assertLess(float_range.random(1234), 4.4)
         self.assertNotAlmostEqual(float_range.random(), 1)
         self.assertNotAlmostEqual(float_range.random(), 4.4)
@@ -90,7 +90,7 @@ def test_float_range(self):
             endpoint=True)
         float_range = tv.FloatRange.from_state(float_range.get_state())
         self.assertEqual(float_range.default, 3.0)
-        self.assertGreater(float_range.random(), 0.4)
+        self.assertGreaterEqual(float_range.random(), 0.4)
         self.assertLessEqual(float_range.random(1234), 8.4)
         self.assertNotAlmostEqual(float_range.random(), 2)

From 1b58ce144a340ff895dedeeab68e3a3a3ab36c06 Mon Sep 17 00:00:00 2001
From: Wangzheee <634486483@qq.com>
Date: Sat, 2 Apr 2022 17:23:43 +0800
Subject: [PATCH 25/93] [Paddle inference] support new quant_model (#41049)

* paddle inference support new quant_model
---
 paddle/fluid/framework/ir/CMakeLists.txt      |   2 +
 .../framework/ir/add_support_int8_pass.cc     |  61 ++-
 .../ir/delete_quant_dequant_linear_op_pass.cc | 148 +++++++
 .../ir/delete_quant_dequant_linear_op_pass.h  |  35 ++
 .../ir/delete_quant_dequant_op_pass.cc        |  10 +-
 .../delete_weight_dequant_linear_op_pass.cc   | 415 ++++++++++++++++++
 .../ir/delete_weight_dequant_linear_op_pass.h |  35 ++
 paddle/fluid/framework/ir/fc_fuse_pass.cc     |  29 +-
 .../ir/gpu_cpu_map_matmul_to_mul_pass.cc      |  19 +-
 .../framework/ir/graph_pattern_detector.cc    | 101 ++++-
 .../framework/ir/graph_pattern_detector.h     |  36 +-
 .../ir/multihead_matmul_fuse_pass.cc          |  51 +--
 .../ir/quant_conv2d_dequant_fuse_pass.cc      |  11 +-
 .../ir/trt_map_matmul_to_mul_pass.cc          | 101 ++++-
 .../inference/api/paddle_pass_builder.cc      |  16 +-
 .../tensorrt/convert/activation_op.cc         |   6 -
 .../tensorrt/convert/affine_channel_op.cc     |   4 +-
 .../inference/tensorrt/convert/conv2d_op.cc   |  13 +-
 .../inference/tensorrt/convert/conv3d_op.cc   |  11 +-
 .../tensorrt/convert/deformable_conv_op.cc    |   3 +-
 .../tensorrt/convert/elementwise_op.cc        |  20 +-
 .../tensorrt/convert/emb_eltwise_layernorm.cc |   2 +-
 .../fluid/inference/tensorrt/convert/fc_op.cc |  60 +--
 .../tensorrt/convert/group_norm_op.cc         |   2 +-
 .../tensorrt/convert/leaky_relu_op.cc         |   4 +-
 .../inference/tensorrt/convert/matmul_op.cc   |   4 +-
 .../tensorrt/convert/multihead_matmul_op.cc   |  46 +-
 .../inference/tensorrt/convert/op_converter.h |  88 ++--
 .../inference/tensorrt/convert/pool2d_op.cc   |   7 +-
 .../inference/tensorrt/convert/pool3d_op.cc   |   5 +-
 .../convert/preln_emb_eltwise_layernorm.cc    |   2 +-
 .../tensorrt/convert/preln_skip_layernorm.cc  |   2 +-
 .../inference/tensorrt/convert/prelu_op.cc    |   4 +-
 .../tensorrt/convert/skip_layernorm.cc        |   2 +-
 paddle/fluid/inference/tensorrt/engine.cc     |   4 +-
 paddle/fluid/inference/tensorrt/engine.h      |   3 +-
 .../operators/compat/dequantize_linear.pbtxt  |  25 ++
 paddle/fluid/operators/compat/mul.pbtxt       |  10 +-
 .../operators/compat/quantize_linear.pbtxt    |  25 ++
 .../test_trt_convert_multihead_matmul.py      |   9 +-
 40 files changed, 1146 insertions(+), 285 deletions(-)
 create mode 100644 paddle/fluid/framework/ir/delete_quant_dequant_linear_op_pass.cc
 create mode 100644 paddle/fluid/framework/ir/delete_quant_dequant_linear_op_pass.h
 create mode 100644 paddle/fluid/framework/ir/delete_weight_dequant_linear_op_pass.cc
 create mode 100644 paddle/fluid/framework/ir/delete_weight_dequant_linear_op_pass.h
 create mode 100644 paddle/fluid/operators/compat/dequantize_linear.pbtxt
 create mode 100644 paddle/fluid/operators/compat/quantize_linear.pbtxt

diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt
index 7aaaef712a6e9..8cacf34834a16 100755
--- a/paddle/fluid/framework/ir/CMakeLists.txt
+++ b/paddle/fluid/framework/ir/CMakeLists.txt
@@ -86,6 +86,8 @@ pass_library(quant_conv2d_dequant_fuse_pass inference)
 pass_library(shuffle_channel_detect_pass inference)
 pass_library(delete_quant_dequant_op_pass inference)
 pass_library(delete_quant_dequant_filter_op_pass inference)
+pass_library(delete_weight_dequant_linear_op_pass inference)
+pass_library(delete_quant_dequant_linear_op_pass inference)
 pass_library(delete_dropout_op_pass inference)
 pass_library(simplify_with_basic_ops_pass base)
 pass_library(fc_elementwise_layernorm_fuse_pass base)
diff --git a/paddle/fluid/framework/ir/add_support_int8_pass.cc b/paddle/fluid/framework/ir/add_support_int8_pass.cc
index d157d2e934ace..3a3f5c3741f4d 100644
--- a/paddle/fluid/framework/ir/add_support_int8_pass.cc
+++ b/paddle/fluid/framework/ir/add_support_int8_pass.cc
@@ -1,4 +1,4 @@
-// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -19,11 +19,7 @@ namespace framework {
 namespace ir {
 
 #define GET_IR_NODE(node__) GET_IR_NODE_FROM_SUBGRAPH(node__, node__, pattern);
-#define GET_NODES        \
-  GET_IR_NODE(prev_op);  \
-  GET_IR_NODE(prev_out); \
-  GET_IR_NODE(quant_op); \
-  GET_IR_NODE(quant_out);
+#define GET_NODES GET_IR_NODE(quant_op);
 
 void AddSupportInt8Pass::ApplyImpl(ir::Graph* graph) const {
   const std::string pattern_name = "add_support_int8";
@@ -37,10 +33,57 @@ void AddSupportInt8Pass::ApplyImpl(ir::Graph* graph) const {
   auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                      Graph* g) {
     GET_NODES;
-    if (prev_op->Op()->HasAttr("out_threshold") &&
-        quant_op->Op()->HasAttr("out_threshold")) {
-      quant_op->Op()->SetAttr("support_int8", true);
+
+    bool inscale_flag = false;
+    bool outscale_flag = false;
+    auto* quanted_op_desc = quant_op->Op();
+    // If an input tensor carries an "Input_scale_<name>" attribute, copy that
+    // scale onto this op under the input slot's name. OpDesc attributes can't
+    // hold std::vector<std::vector<float>>. TODO: support multi-tensor scales
+    // for one input.
+    for (size_t i = 0; i < quanted_op_desc->InputNames().size(); i++) {
+      if (quanted_op_desc->Input(quanted_op_desc->InputNames()[i]).size() >
+              0 &&
+          quanted_op_desc->HasAttr(
+              "Input_scale_" +
+              quanted_op_desc->Input(quanted_op_desc->InputNames()[i])[0])) {
+        inscale_flag = true;
+        quanted_op_desc->SetAttr(
+            quanted_op_desc->InputNames()[i],
+            quanted_op_desc->GetAttr(
+                "Input_scale_" +
+                quanted_op_desc->Input(quanted_op_desc->InputNames()[i])[0]));
+      }
+    }
+
+    // If an output tensor's consumer op carries "Input_scale_<name>" for that
+    // tensor, copy the scale onto this op under the output slot's name.
+    // TODO: support multi-tensor scales for one output.
+    for (auto out_node : quant_op->outputs) {
+      for (auto out_op_node : out_node->outputs) {
+        for (auto name : out_op_node->Op()->InputNames()) {
+          for (auto input_name : out_op_node->Op()->Input(name)) {
+            if (out_op_node->Op()->HasAttr("Input_scale_" + input_name)) {
+              for (size_t i = 0; i < quanted_op_desc->OutputNames().size();
+                   i++) {
+                if (quanted_op_desc->Output(quanted_op_desc->OutputNames()[i])
+                            .size() > 0 &&
+                    input_name ==
+                        quanted_op_desc->Output(
+                            quanted_op_desc->OutputNames()[i])[0]) {
+                  outscale_flag = true;
+                  quanted_op_desc->SetAttr(
+                      quanted_op_desc->OutputNames()[i],
+                      out_op_node->Op()->GetAttr("Input_scale_" + input_name));
+                }
+              }
+            }
+          }
+        }
+      }
     }
+    quanted_op_desc->SetAttr("support_int8", inscale_flag && outscale_flag);
+    quanted_op_desc->Flush();
     found_count++;
   };
   gpd(graph, handler);
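// [Editorial note, not part of the patch] A hedged sketch of the attribute
// convention used by add_support_int8_pass above: for a quantized op whose
// input slot "X" reads a tensor named, say, "conv_in", the pass effectively
// performs
//
//   op_desc->SetAttr("X", op_desc->GetAttr("Input_scale_conv_in"));
//   op_desc->SetAttr("support_int8", inscale_flag && outscale_flag);
//
// The tensor name "conv_in" is illustrative only; the per-tensor
// "Input_scale_<name>" attributes are recorded by the
// delete_quant_dequant_linear_op_pass introduced next.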
+ +#include "paddle/fluid/framework/ir/delete_quant_dequant_linear_op_pass.h" + +#include +#include +#include +#include +#include + +namespace paddle { +namespace framework { +namespace ir { + +#define GET_IR_NODE(node__) GET_IR_NODE_FROM_SUBGRAPH(node__, node__, pattern); +#define GET_NODES \ + GET_IR_NODE(quantize_linear_op_x); \ + GET_IR_NODE(quantize_linear_op_scale); \ + GET_IR_NODE(quantize_linear_op); \ + GET_IR_NODE(quantize_linear_op_out); \ + GET_IR_NODE(dequantize_linear_op); \ + GET_IR_NODE(dequantize_linear_op_out); \ + GET_IR_NODE(any_op2); + +DeleteQuantDequantLinearOpPass::DeleteQuantDequantLinearOpPass() { + AddOpCompat(OpCompat("quantize_linear")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Scale") + .IsTensor() + .End() + .AddInput("ZeroPoint") + .IsTensor() + .IsOptional() + .End() + .AddOutput("Y") + .IsTensor() + .End() + .AddAttr("bit_length") + .IsType() + .End() + .AddAttr("quant_axis") + .IsType() + .End(); + AddOpCompat(OpCompat("dequantize_linear")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Scale") + .IsTensor() + .End() + .AddInput("ZeroPoint") + .IsTensor() + .IsOptional() + .End() + .AddOutput("Y") + .IsTensor() + .End() + .AddAttr("bit_length") + .IsType() + .End() + .AddAttr("quant_axis") + .IsType() + .End(); +} +// Delete quantize_linear_op dequantize_linear_op, then add input_scales +void DeleteQuantDequantLinearOpPass::ApplyImpl(ir::Graph* graph) const { + const std::string pattern_name = "delete_quantdequant_linear_op_pattern"; + FusePassBase::Init(pattern_name, graph); + + GraphPatternDetector gpd; + auto* scope = param_scope(); + PADDLE_ENFORCE_NOT_NULL( + scope, + platform::errors::InvalidArgument( + "Scope in DeleteQuantDequantLinearOpPass should not be null.")); + // Create pattern + patterns::DeleteQuantDequantLinearOpPattern pattern(gpd.mutable_pattern(), + pattern_name); + pattern(); + int found_count = 0; + + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* g) { + GET_NODES; + /* + if (!IsCompat(subgraph, g)) { + LOG(WARNING) << "delete_quant_dequant_linear_op_pass " + "compat check failed."; + return; + } + */ + std::unordered_set nodes2rm = {}; + int bit_length = + BOOST_GET_CONST(int, quantize_linear_op->Op()->GetAttr("bit_length")); + int range = ((1 << (bit_length - 1)) - 1); + + // Get input scale from tensor + const LoDTensor& input_scale_tensor = + scope->GetVar(quantize_linear_op_scale->Name())->Get(); + PADDLE_ENFORCE_EQ( + paddle::platform::is_cpu_place(input_scale_tensor.place()), true, + platform::errors::InvalidArgument( + "Input scale tensor's place should be CPU.")); + const float* input_scale_data = input_scale_tensor.data(); + float input_scale = input_scale_data[0] / range; + + auto* any_op2_desc = any_op2->Op(); + any_op2_desc->SetAttr("Input_scale_" + quantize_linear_op_x->Var()->Name(), + input_scale); + + nodes2rm.insert(quantize_linear_op_scale); + nodes2rm.insert(quantize_linear_op); + nodes2rm.insert(quantize_linear_op_out); + nodes2rm.insert(dequantize_linear_op); + nodes2rm.insert(dequantize_linear_op_out); + + // link x to any_op2 + any_op2_desc->RenameInput(dequantize_linear_op_out->Var()->Name(), + quantize_linear_op_x->Var()->Name()); + any_op2_desc->Flush(); + IR_NODE_LINK_TO(quantize_linear_op_x, any_op2); + GraphSafeRemoveNodes(graph, nodes2rm); + found_count++; + }; + gpd(graph, handler); + AddStatis(found_count); +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(delete_quant_dequant_linear_op_pass, + 
diff --git a/paddle/fluid/framework/ir/delete_quant_dequant_linear_op_pass.h b/paddle/fluid/framework/ir/delete_quant_dequant_linear_op_pass.h
new file mode 100644
index 0000000000000..b00e3cb5c468b
--- /dev/null
+++ b/paddle/fluid/framework/ir/delete_quant_dequant_linear_op_pass.h
@@ -0,0 +1,35 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include <vector>
+
+#include "paddle/fluid/framework/ir/fuse_pass_base.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+class DeleteQuantDequantLinearOpPass : public FusePassBase {
+ public:
+  DeleteQuantDequantLinearOpPass();
+  virtual ~DeleteQuantDequantLinearOpPass() {}
+
+ protected:
+  void ApplyImpl(ir::Graph* graph) const override;
+};
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
diff --git a/paddle/fluid/framework/ir/delete_quant_dequant_op_pass.cc b/paddle/fluid/framework/ir/delete_quant_dequant_op_pass.cc
index 63d68bd04b5f0..e2bb62dba7cf0 100644
--- a/paddle/fluid/framework/ir/delete_quant_dequant_op_pass.cc
+++ b/paddle/fluid/framework/ir/delete_quant_dequant_op_pass.cc
@@ -61,7 +61,6 @@ void DeleteQuantDequantOpPass::ApplyImpl(ir::Graph* graph) const {
     GET_NODES;
     int bit_length =
         BOOST_GET_CONST(int, quant_dequant_op->Op()->GetAttr("bit_length"));
-    int range = ((1 << (bit_length - 1)) - 1);
 
     // Get input scale from tensor
     std::string input_scale_var_name =
@@ -76,7 +75,7 @@ void DeleteQuantDequantOpPass::ApplyImpl(ir::Graph* graph) const {
                       platform::errors::InvalidArgument(
                           "Input scale tensor's place should be CPU."));
     const float* input_scale_data = input_scale_tensor.data<float>();
-    float input_scale = input_scale_data[0] / range;
+    float input_scale = input_scale_data[0];
 
     // Set input scale in attr, and relink nodes
     std::string input_name = input->Var()->Name();
@@ -85,12 +84,7 @@ void DeleteQuantDequantOpPass::ApplyImpl(ir::Graph* graph) const {
     for (auto* quantized_node : outlinks) {
       auto op_desc = quantized_node->Op();
       std::string quantized_op_type = op_desc->Type();
-      if (quantized_op_type == "mul" || quantized_op_type == "matmul" ||
-          quantized_op_type == "matmul_v2") {
-        op_desc->SetAttr("X_scale", input_scale);
-      } else {
-        op_desc->SetAttr("Input_scale", input_scale);
-      }
+      op_desc->SetAttr("Input_scale", input_scale);
       op_desc->SetAttr("bit_length", bit_length);
       op_desc->RenameInput(quant_dequant_output_name, input_name);
       op_desc->Flush();
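// [Editorial note, not part of the patch] Observe the asymmetry between the
// two passes: the old fake-quant path above now stores the raw scale, while
// the new quantize_linear path keeps dividing by range. In rough arithmetic
// (bit_length = 8):
//
//   old pass (after this change): Input_scale = in_scale;
//   new pass:                     Input_scale = in_scale / 127;
//
// This reading is inferred from the two hunks and is not an authoritative
// statement of the scale formats.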
diff --git a/paddle/fluid/framework/ir/delete_weight_dequant_linear_op_pass.cc b/paddle/fluid/framework/ir/delete_weight_dequant_linear_op_pass.cc
new file mode 100644
index 0000000000000..8ebea231e7a2a
--- /dev/null
+++ b/paddle/fluid/framework/ir/delete_weight_dequant_linear_op_pass.cc
@@ -0,0 +1,415 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/framework/ir/delete_weight_dequant_linear_op_pass.h"
+
+#include <algorithm>
+#include <memory>
+#include <string>
+#include <unordered_set>
+#include <vector>
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+#define GET_IR_NODE(node__) GET_IR_NODE_FROM_SUBGRAPH(node__, node__, pattern);
+#define GET_NODES                                 \
+  GET_IR_NODE(weight_dequantize_linear_op_x);     \
+  GET_IR_NODE(weight_dequantize_linear_op_scale); \
+  GET_IR_NODE(weight_dequantize_linear_op);       \
+  GET_IR_NODE(weight_dequantize_linear_op_out);   \
+  GET_IR_NODE(any_op2);
+
+DeleteWeightQuantDequantLinearOpPass::DeleteWeightQuantDequantLinearOpPass() {
+  AddOpCompat(OpCompat("quantize_linear"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddInput("Scale")
+      .IsTensor()
+      .End()
+      .AddInput("ZeroPoint")
+      .IsTensor()
+      .IsOptional()
+      .End()
+      .AddOutput("Y")
+      .IsTensor()
+      .End()
+      .AddAttr("bit_length")
+      .IsType<int>()
+      .End()
+      .AddAttr("quant_axis")
+      .IsType<int>()
+      .End();
+  AddOpCompat(OpCompat("dequantize_linear"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddInput("Scale")
+      .IsTensor()
+      .End()
+      .AddInput("ZeroPoint")
+      .IsTensor()
+      .IsOptional()
+      .End()
+      .AddOutput("Y")
+      .IsTensor()
+      .End()
+      .AddAttr("bit_length")
+      .IsType<int>()
+      .End()
+      .AddAttr("quant_axis")
+      .IsType<int>()
+      .End();
+  AddOpCompat(OpCompat("conv2d"))
+      .AddInput("Input")
+      .IsTensor()
+      .End()
+      .AddInput("Filter")
+      .IsTensor()
+      .End()
+      .AddInput("Bias")
+      .IsTensor()
+      .IsOptional()
+      .End()
+      .AddInput("ResidualData")
+      .IsTensor()
+      .IsOptional()
+      .End()
+      .AddOutput("Output")
+      .IsTensor()
+      .End()
+      .AddAttr("strides")
+      .IsType<std::vector<int>>()
+      .End()
+      .AddAttr("paddings")
+      .IsType<std::vector<int>>()
+      .End()
+      .AddAttr("padding_algorithm")
+      .IsOptional()
+      .IsStringIn({"EXPLICIT", "SAME", "VALID"})
+      .End()
+      .AddAttr("groups")
+      .IsNumGE(1)
+      .End()
+      .AddAttr("dilations")
+      .IsType<std::vector<int>>()
+      .End()
+      .AddAttr("data_format")
+      .IsStringIn({"NCHW", "NHWC", "AnyLayout"})
+      .End();
+  AddOpCompat(OpCompat("depthwise_conv2d"))
+      .AddInput("Input")
+      .IsTensor()
+      .End()
+      .AddInput("Filter")
+      .IsTensor()
+      .End()
+      .AddInput("Bias")
+      .IsTensor()
+      .IsOptional()
+      .End()
+      .AddInput("ResidualData")
+      .IsTensor()
+      .IsOptional()
+      .End()
+      .AddOutput("Output")
+      .IsTensor()
+      .End()
+      .AddAttr("strides")
+      .IsType<std::vector<int>>()
+      .End()
+      .AddAttr("paddings")
+      .IsType<std::vector<int>>()
+      .End()
+      .AddAttr("padding_algorithm")
+      .IsOptional()
+      .IsStringIn({"EXPLICIT", "SAME", "VALID"})
+      .End()
+      .AddAttr("groups")
+      .IsNumGE(1)
+      .End()
+      .AddAttr("dilations")
+      .IsType<std::vector<int>>()
+      .End()
+      .AddAttr("data_format")
+      .IsStringIn({"NCHW", "NHWC", "AnyLayout"})
+      .End();
+  AddOpCompat(OpCompat("mul"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddInput("Y")
+      .IsTensor()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End()
+      .AddAttr("x_num_col_dims")
+      .IsNumGE(1)
+      .End()
+      .AddAttr("y_num_col_dims")
+      .IsNumEQ(1)
+      .End();
+  AddOpCompat(OpCompat("matmul_v2"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddInput("Y")
+      .IsTensor()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End()
+      .AddAttr("trans_x")
+      .IsBoolEQ(false)
+      .End()
+      .AddAttr("trans_y")
+      .IsBoolEQ(false)
+      .End();
+  AddOpCompat(OpCompat("matmul"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddInput("Y")
+      .IsTensor()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End()
+      .AddAttr("alpha")
+      .IsNumGE(0.99f)
+      .IsNumLE(1.01f)
+      .End()
+      .AddAttr("transpose_X")
+      .IsBoolEQ(false)
+      .End()
+      .AddAttr("transpose_Y")
+      .IsBoolEQ(false)
+      .End();
+  AddOpCompat(OpCompat("fc"))
+      .AddInput("Input")
+      .IsTensor()
+      .End()
+      .AddInput("W")
+      .IsTensor()
+      .End()
+      .AddInput("Bias")
+      .IsTensor()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End()
+      .AddAttr("in_num_col_dims")
+      .IsNumGE(1)
+      .End()
+      .AddAttr("activation_type")
+      .IsStringIn({"relu", ""})
+      .End();
+  AddOpCompat(OpCompat("conv2d_transpose"))
+      .AddInput("Input")
+      .IsTensor()
+      .End()
+      .AddInput("Filter")
+      .IsTensor()
+      .End()
+      .AddInput("Bias")
+      .IsTensor()
+      .IsOptional()
+      .End()
+      .AddOutput("Output")
+      .IsTensor()
+      .End()
+      .AddAttr("output_padding")
+      .IsType<std::vector<int>>()
+      .IsOptional()
+      .End()
+      .AddAttr("output_size")
+      .IsType<std::vector<int>>()
+      .IsOptional()
+      .End()
+      .AddAttr("groups")
+      .IsNumGE(1)
+      .End()
+      .AddAttr("dilations")
+      .IsType<std::vector<int>>()
+      .End()
+      .AddAttr("strides")
+      .IsType<std::vector<int>>()
+      .End()
+      .AddAttr("paddings")
+      .IsType<std::vector<int>>()
+      .End()
+      .AddAttr("padding_algorithm")
+      .IsOptional()
+      .IsStringIn({"EXPLICIT", "SAME", "VALID"})
+      .End()
+      .AddAttr("data_format")
+      .IsStringIn({"NCHW", "NHWC", "AnyLayout"})
+      .End();
+}
+// Delete dequantize_linear_op, then dequantize the weight
+void DeleteWeightQuantDequantLinearOpPass::ApplyImpl(ir::Graph* graph) const {
+  const std::string pattern_name =
+      "delete_weight_quantdequant_linear_op_pattern";
+  FusePassBase::Init(pattern_name, graph);
+
+  GraphPatternDetector gpd;
+  auto* scope = param_scope();
+  PADDLE_ENFORCE_NOT_NULL(
+      scope,
+      platform::errors::InvalidArgument(
+          "Scope in DeleteWeightQuantDequantLinearOpPass should not be null."));
+  // Create pattern
+  patterns::DeleteWeightQuantDequantLinearOpPattern pattern(
+      gpd.mutable_pattern(), pattern_name);
+  pattern();
+  int found_count = 0;
+
+  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
+                     Graph* g) {
+    GET_NODES;
+    /*
+    if (!IsCompat(subgraph, g)) {
+      LOG(WARNING) << "delete_weight_dequant_linear_op_pass "
+                      "compat check failed.";
+      return;
+    }
+    */
+    std::unordered_set<const Node*> nodes2rm = {};
+    int bit_length = BOOST_GET_CONST(
+        int, weight_dequantize_linear_op->Op()->GetAttr("bit_length"));
+    int range = ((1 << (bit_length - 1)) - 1);
+
+    auto* any_op2_desc = any_op2->Op();
+
+    // get weight tensor
+    auto* weight_tensor = scope->GetVar(weight_dequantize_linear_op_x->Name())
+                              ->GetMutable<LoDTensor>();
+    int8_t* quantized_weight_data =
+        weight_tensor->mutable_data<int8_t>(platform::CPUPlace());
+    auto w_dims = weight_tensor->dims();
+
+    // Get weight scale
+    std::vector<float> weight_scale;
+    auto* weight_scale_tensor =
+        scope->GetVar(weight_dequantize_linear_op_scale->Name())
+            ->GetMutable<LoDTensor>();
+    float* weight_scale_data =
+        weight_scale_tensor->mutable_data<float>(platform::CPUPlace());
+
+    auto weight_scale_nums = weight_scale_tensor->numel();
+    for (int i = 0; i < weight_scale_nums; i++) {
+      weight_scale.push_back(weight_scale_data[i] / range);
+    }
+
+    // dequant weight
+    std::vector<float> weight_data_tmp;
+    weight_data_tmp.resize(weight_tensor->numel());
+
+    int quant_axis = BOOST_GET_CONST(
+        int, weight_dequantize_linear_op->Op()->GetAttr("quant_axis"));
+    if (quant_axis == -1) {  // per_layer quant_dequant: all OP
+      PADDLE_ENFORCE_EQ(weight_scale_nums, 1,
+                        platform::errors::InvalidArgument(
+                            "When quant_axis == -1, per_layer quant_dequant "
+                            "is used and weight_scale's number should be 1."));
+
+      // float(weight) * scale
+      for (int i = 0; i < weight_tensor->numel(); i++) {
+        weight_data_tmp[i] =
+            static_cast<float>(quantized_weight_data[i]) * weight_scale[0];
+      }
+    } else if (quant_axis == 0) {  // per_channel quant_dequant: conv2d,
+                                   // depthwise_conv2d, conv2d_fusion
+      PADDLE_ENFORCE_EQ(
+          weight_scale_nums, w_dims[quant_axis],
+          platform::errors::InvalidArgument(
+              "When quant_axis == 0, per_channel quant_dequant is used and "
+              "weight_scale's number should equal the channel count."));
+      PADDLE_ENFORCE_EQ(w_dims.size(), 4,
+                        platform::errors::InvalidArgument(
+                            "When quant_axis == 0, per_channel quant_dequant "
+                            "is used and (conv2d, depthwise_conv2d, "
+                            "conv2d_fusion)'s weight dims should be 4."));
+
+      for (int i = 0; i < weight_tensor->numel(); i++) {
+        int inner_size = w_dims[1] * w_dims[2] * w_dims[3];
+        weight_data_tmp[i] = static_cast<float>(quantized_weight_data[i]) *
+                             weight_scale[i / inner_size];
+      }
+    } else if (quant_axis == 1) {
+      PADDLE_ENFORCE_EQ(
+          weight_scale_nums, w_dims[quant_axis],
+          platform::errors::InvalidArgument(
+              "When quant_axis == 1, per_channel quant_dequant is used and "
+              "weight_scale's number should equal the channel count."));
+
+      if (w_dims.size() == 4) {  // conv2d_transpose
+        std::string quantized_op_type = any_op2->Op()->Type();
+        PADDLE_ENFORCE_EQ(
+            quantized_op_type, "conv2d_transpose",
+            platform::errors::InvalidArgument(
+                "When quant_axis == 1, per_channel quant_dequant is used and "
+                "only conv2d_transpose weights may have 4 dims."));
+        for (int i = 0; i < weight_tensor->numel(); i++) {
+          int inner_size = w_dims[2] * w_dims[3];
+          weight_data_tmp[i] = static_cast<float>(quantized_weight_data[i]) *
+                               weight_scale[(i / inner_size) % w_dims[1]];
+        }
+      } else if (w_dims.size() == 2) {
+        for (int i = 0; i < weight_tensor->numel(); i++) {
+          weight_data_tmp[i] = static_cast<float>(quantized_weight_data[i]) *
+                               weight_scale[i % w_dims[1]];
+        }
+      } else {
+        PADDLE_THROW(platform::errors::InvalidArgument(
+            "When quant_axis == 1, weight dims should be 2 or 4, please "
+            "check your model."));
+      }
+    } else {
+      PADDLE_THROW(platform::errors::InvalidArgument(
+          "quant_axis should be -1, 0 or 1; please check your model's "
+          "OP attributes."));
+    }
+    weight_tensor->clear();  // clear int weight
+    weight_tensor->Resize(phi::make_ddim(phi::vectorize(w_dims)));
+    float* new_quantized_weight_data =
+        weight_tensor->mutable_data<float>(platform::CPUPlace());
+    memcpy(new_quantized_weight_data, weight_data_tmp.data(),
+           weight_tensor->numel() * sizeof(float));
+
+    nodes2rm.insert(weight_dequantize_linear_op_scale);
+    nodes2rm.insert(weight_dequantize_linear_op);
+    nodes2rm.insert(weight_dequantize_linear_op_out);
+
+    // relink weight to any_op2
+    any_op2_desc->RenameInput(weight_dequantize_linear_op_out->Var()->Name(),
+                              weight_dequantize_linear_op_x->Var()->Name());
+    any_op2_desc->Flush();
+    IR_NODE_LINK_TO(weight_dequantize_linear_op_x, any_op2);
+    GraphSafeRemoveNodes(graph, nodes2rm);
+    found_count++;
+  };
+  gpd(graph, handler);
+  AddStatis(found_count);
+}
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
+
+REGISTER_PASS(delete_weight_dequant_linear_op_pass,
+              paddle::framework::ir::DeleteWeightQuantDequantLinearOpPass);
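// [Editorial note, not part of the patch] Hedged sketch of the per-channel
// dequantization above for quant_axis == 0 and an illustrative conv2d weight
// of shape [C_out, C_in, kH, kW] = [2, 1, 2, 2] (made-up values):
//
//   int inner_size = 1 * 2 * 2;  // w_dims[1] * w_dims[2] * w_dims[3] == 4
//   // element i belongs to output channel i / inner_size, hence
//   w_fp32[i] = static_cast<float>(w_int8[i]) * weight_scale[i / inner_size];
//
// For quant_axis == 1 with a 2-D [in, out] weight, the scale index is
// i % w_dims[1] instead.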
diff --git a/paddle/fluid/framework/ir/delete_weight_dequant_linear_op_pass.h b/paddle/fluid/framework/ir/delete_weight_dequant_linear_op_pass.h
new file mode 100644
index 0000000000000..e240b6212b84a
--- /dev/null
+++ b/paddle/fluid/framework/ir/delete_weight_dequant_linear_op_pass.h
@@ -0,0 +1,35 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include <vector>
+
+#include "paddle/fluid/framework/ir/fuse_pass_base.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+class DeleteWeightQuantDequantLinearOpPass : public FusePassBase {
+ public:
+  DeleteWeightQuantDequantLinearOpPass();
+  virtual ~DeleteWeightQuantDequantLinearOpPass() {}
+
+ protected:
+  void ApplyImpl(ir::Graph* graph) const override;
+};
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
diff --git a/paddle/fluid/framework/ir/fc_fuse_pass.cc b/paddle/fluid/framework/ir/fc_fuse_pass.cc
index e246a10961c0c..1e25b21483b82 100644
--- a/paddle/fluid/framework/ir/fc_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/fc_fuse_pass.cc
@@ -226,23 +226,34 @@ int FCFusePass::ApplyFCPattern(Graph* graph, bool with_relu) const {
   // For anakin subgraph int8
   // When in anakin subgraph int8 mode, the pattern like "fake_quant + mul +
   // fake_dequant" can be detected by the quant_dequant_fuse_pass. This pass
-  // will add "input_scale", "weight_scale" which are extracted from
+  // will add "input_scale", which is extracted from
   // fake_quant op and fake_dequant op to mul op, and then delete the
   // fake_quant op and fake_dequant op in the graph. If the mul op has the
   // scale info, we should add those to the fused fc.
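  // [Editorial note, not part of the patch] Hedged illustration of the
  // hand-off implemented below: when the matched mul carries the int8
  // attributes, the fused fc op can end up with, e.g.,
  //   Input_scale = 0.1, X = 0.1, Out = 0.05, support_int8 = true
  // where "X"/"Out" are the slot-name scale attributes set earlier by
  // add_support_int8_pass; all values here are illustrative only.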
   auto* mul_op_desc = mul->Op();
+  auto* elementwise_add_op_desc = elementwise_add->Op();
+
   if (mul_op_desc->HasAttr("enable_int8")) {
     desc.SetAttr("enable_int8", mul_op_desc->GetAttr("enable_int8"));
-    desc.SetAttr("Input_scale", mul_op_desc->GetAttr("X_scale"));
-    desc.SetAttr("weight_scale", mul_op_desc->GetAttr("weight_scale"));
-    if (mul_op_desc->HasAttr("out_scale"))
-      desc.SetAttr("out_scale", mul_op_desc->GetAttr("out_scale"));
-    auto elementwise_desc = elementwise_add->Op();
-    if (elementwise_desc->HasAttr("out_scale"))
-      desc.SetAttr("out_scale", elementwise_desc->GetAttr("out_scale"));
   }
-  auto* elementwise_add_op_desc = elementwise_add->Op();
+
+  if (mul_op_desc->HasAttr("Input_scale")) {
+    desc.SetAttr("Input_scale", mul_op_desc->GetAttr("Input_scale"));
+  }
+
+  bool inscale_flag = false;
+  bool outscale_flag = false;
+
+  if (mul_op_desc->HasAttr("X")) {
+    desc.SetAttr("X", mul_op_desc->GetAttr("X"));
+    inscale_flag = true;
+  }
+  if (elementwise_add_op_desc->HasAttr("Out")) {
+    desc.SetAttr("Out", elementwise_add_op_desc->GetAttr("Out"));
+    outscale_flag = true;
+  }
+  desc.SetAttr("support_int8", inscale_flag && outscale_flag);
+
   // if we can find out_threshold in elementwise_add, then set it as the
   // out_threshold of fc
   auto out_threshold_attr =
diff --git a/paddle/fluid/framework/ir/gpu_cpu_map_matmul_to_mul_pass.cc b/paddle/fluid/framework/ir/gpu_cpu_map_matmul_to_mul_pass.cc
index 1759d18761da3..ac580b99b5c95 100644
--- a/paddle/fluid/framework/ir/gpu_cpu_map_matmul_to_mul_pass.cc
+++ b/paddle/fluid/framework/ir/gpu_cpu_map_matmul_to_mul_pass.cc
@@ -298,8 +298,7 @@ void GpuCpuMapMatmul2MulPass::ApplyImpl(ir::Graph* graph) const {
         desc.SetAttr("y_num_col_dims", 1);
         if (matmul_op->Op()->HasAttr("enable_int8")) {
           desc.SetAttr("enable_int8", matmul_op->Op()->GetAttr("enable_int8"));
-          desc.SetAttr("X_scale", matmul_op->Op()->GetAttr("X_scale"));
-          desc.SetAttr("weight_scale", matmul_op->Op()->GetAttr("weight_scale"));
+          desc.SetAttr("Input_scale", matmul_op->Op()->GetAttr("Input_scale"));
           desc.SetAttr("out_threshold",
                        matmul_op->Op()->GetAttr("out_threshold"));
         }
@@ -372,9 +371,7 @@ void GpuCpuMapMatmulV2ToMulPass::ApplyImpl(ir::Graph* graph) const {
         desc.SetAttr("y_num_col_dims", 1);
         if (matmul_v2_op->Op()->HasAttr("enable_int8")) {
           desc.SetAttr("enable_int8", matmul_v2_op->Op()->GetAttr("enable_int8"));
-          desc.SetAttr("X_scale", matmul_v2_op->Op()->GetAttr("X_scale"));
-          desc.SetAttr("weight_scale",
-                       matmul_v2_op->Op()->GetAttr("weight_scale"));
+          desc.SetAttr("Input_scale", matmul_v2_op->Op()->GetAttr("Input_scale"));
           desc.SetAttr("out_threshold",
                        matmul_v2_op->Op()->GetAttr("out_threshold"));
         }
@@ -451,8 +448,7 @@ void GpuCpuMapMatmulV2ToMatmulPass::ApplyImpl(ir::Graph* graph) const {
         }
         if (matmul_v2_op->Op()->HasAttr("enable_int8")) {
           desc.SetAttr("enable_int8", matmul_v2_op->Op()->GetAttr("enable_int8"));
-          desc.SetAttr("X_scale", matmul_v2_op->Op()->GetAttr("X_scale"));
-          desc.SetAttr("weight_scale", matmul_v2_op->Op()->GetAttr("weight_scale"));
+          desc.SetAttr("Input_scale", matmul_v2_op->Op()->GetAttr("Input_scale"));
           desc.SetAttr("out_threshold",
                        matmul_v2_op->Op()->GetAttr("out_threshold"));
         }
@@ -532,8 +528,7 @@ void GpuCpuSqueeze2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
         desc.SetAttr("y_num_col_dims", 1);
         if (matmul_op->Op()->HasAttr("enable_int8")) {
           desc.SetAttr("enable_int8", matmul_op->Op()->GetAttr("enable_int8"));
-          desc.SetAttr("X_scale", matmul_op->Op()->GetAttr("X_scale"));
-          desc.SetAttr("weight_scale", matmul_op->Op()->GetAttr("weight_scale"));
+          desc.SetAttr("Input_scale", matmul_op->Op()->GetAttr("Input_scale"));
desc.SetAttr("Input_scale", matmul_op->Op()->GetAttr("Input_scale")); desc.SetAttr("out_threshold", matmul_op->Op()->GetAttr("out_threshold")); } @@ -677,8 +672,7 @@ void GpuCpuReshape2MatmulFusePass::ApplyImpl(ir::Graph* graph) const { desc.SetAttr("y_num_col_dims", 1); if (matmul_op->Op()->HasAttr("enable_int8")) { desc.SetAttr("enable_int8", matmul_op->Op()->GetAttr("enable_int8")); - desc.SetAttr("X_scale", matmul_op->Op()->GetAttr("X_scale")); - desc.SetAttr("weight_scale", matmul_op->Op()->GetAttr("weight_scale")); + desc.SetAttr("Input_scale", matmul_op->Op()->GetAttr("Input_scale")); desc.SetAttr("out_threshold", matmul_op->Op()->GetAttr("out_threshold")); } @@ -765,8 +759,7 @@ void GpuCpuFlatten2MatmulFusePass::ApplyImpl(ir::Graph* graph) const { desc.SetAttr("y_num_col_dims", 1); if (matmul_op->Op()->HasAttr("enable_int8")) { desc.SetAttr("enable_int8", matmul_op->Op()->GetAttr("enable_int8")); - desc.SetAttr("X_scale", matmul_op->Op()->GetAttr("X_scale")); - desc.SetAttr("weight_scale", matmul_op->Op()->GetAttr("weight_scale")); + desc.SetAttr("Input_scale", matmul_op->Op()->GetAttr("Input_scale")); desc.SetAttr("out_threshold", matmul_op->Op()->GetAttr("out_threshold")); } diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc index 164a13d1560f4..03da1289205e4 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -2949,6 +2949,84 @@ void patterns::DeleteQuantDequantFilterOpPattern::operator()() { any_op2->LinksFrom({quant_dequant_out}); } +void patterns::DeleteWeightQuantDequantLinearOpPattern::operator()() { + auto weight_dequantize_linear_op_x = + pattern->NewNode(weight_dequantize_linear_op_x_repr()) + ->AsInput() + ->assert_is_op_input("dequantize_linear", "X") + ->assert_is_persistable_var(); + + auto weight_dequantize_linear_op_scale = + pattern->NewNode(weight_dequantize_linear_op_scale_repr()) + ->AsInput() + ->assert_is_op_input("dequantize_linear", "Scale") + ->assert_is_persistable_var(); + + auto weight_dequantize_linear_op = + pattern->NewNode(weight_dequantize_linear_op_repr()) + ->assert_is_op("dequantize_linear"); + + auto weight_dequantize_linear_op_out = + pattern->NewNode(weight_dequantize_linear_op_out_repr()) + ->AsIntermediate() + ->assert_is_op_output("dequantize_linear", "Y"); + + auto any_op2 = pattern->NewNode(any_op2_repr())->assert_is_op()->AsOutput(); + + weight_dequantize_linear_op + ->LinksFrom( + {weight_dequantize_linear_op_x, weight_dequantize_linear_op_scale}) + .LinksTo({weight_dequantize_linear_op_out}); + any_op2->LinksFrom({weight_dequantize_linear_op_out}); +} + +void patterns::DeleteQuantDequantLinearOpPattern::operator()() { + auto quantize_linear_op_x = pattern->NewNode(quantize_linear_op_x_repr()) + ->AsInput() + ->assert_is_op_input("quantize_linear", "X"); + + auto quantize_linear_op_scale = + pattern->NewNode(quantize_linear_op_scale_repr()) + ->AsInput() + ->assert_is_op_input("quantize_linear", "Scale") + ->assert_is_persistable_var(); + + auto quantize_linear_op = pattern->NewNode(quantize_linear_op_repr()) + ->assert_is_op("quantize_linear"); + + auto quantize_linear_op_out = + pattern->NewNode(quantize_linear_op_out_repr()) + ->AsIntermediate() + ->assert_is_op_output("quantize_linear", "Y") + ->assert_is_op_input("dequantize_linear", "X") + ->assert_var_not_persistable(); + + // Can not add this node. 
Todo: Wangzheee + /* + auto dequantize_linear_op_scale = + pattern->NewNode(dequantize_linear_op_scale_repr()) + ->assert_is_op_input("dequantize_linear", "Scale") + ->AsIntermediate(); + */ + + auto dequantize_linear_op = pattern->NewNode(dequantize_linear_op_repr()) + ->assert_is_op("dequantize_linear"); + + auto dequantize_linear_op_out = + pattern->NewNode(dequantize_linear_op_out_repr()) + ->AsIntermediate() + ->assert_is_op_output("dequantize_linear", "Y"); + + auto any_op2 = pattern->NewNode(any_op2_repr())->assert_is_op()->AsOutput(); + + quantize_linear_op + ->LinksFrom({quantize_linear_op_x, quantize_linear_op_scale}) + .LinksTo({quantize_linear_op_out}); + dequantize_linear_op->LinksFrom({quantize_linear_op_out}) + .LinksTo({dequantize_linear_op_out}); + any_op2->LinksFrom({dequantize_linear_op_out}); +} + PDNode *patterns::ReshapeTransposeMatmulPattern::operator()( const std::string &op_name, bool with_reshape_xshape, bool with_transpose_xshape) { @@ -3311,25 +3389,14 @@ PDNode *patterns::LayerNorm::operator()() { return shift_out; } -// Add support int8 flag +// Add support int8 flag and out_threshold PDNode *patterns::AddSupportInt8::operator()() { - auto prev_op = - pattern->NewNode(prev_op_repr()) - ->assert_is_op() - ->assert_more([&](Node *node) { - return node->Op()->HasAttr("out_threshold") ? true : false; - }); - auto prev_out = pattern->NewNode(prev_out_repr())->assert_is_var(); - auto quant_op = - pattern->NewNode(quant_op_repr()) - ->assert_is_op() - ->assert_more([&](Node *node) { - return node->Op()->HasAttr("out_threshold") ? true : false; - }); + auto quant_op = pattern->NewNode(quant_op_repr())->assert_is_op(); auto quant_out = - pattern->NewNode(quant_out_repr())->assert_is_var()->AsOutput(); - prev_op->LinksTo({prev_out}); - prev_out->LinksTo({quant_op}); + pattern->NewNode(quant_out_repr()) + ->assert_is_var() + ->assert_more([&](Node *node) { return node->outputs.size() > 0; }) + ->AsOutput(); quant_op->LinksTo({quant_out}); return quant_out; } diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h index 17c70ace301d3..1f253c6b91043 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.h +++ b/paddle/fluid/framework/ir/graph_pattern_detector.h @@ -1702,6 +1702,40 @@ struct DeleteQuantDequantFilterOpPattern : public PatternBase { PATTERN_DECL_NODE(any_op2); }; +struct DeleteWeightQuantDequantLinearOpPattern : public PatternBase { + DeleteWeightQuantDequantLinearOpPattern(PDPattern* pattern, + const std::string& name_scope) + : PatternBase(pattern, name_scope, + "delete_weight_quant_dequant_linear_op_pattern") {} + + void operator()(); + + PATTERN_DECL_NODE(weight_dequantize_linear_op_x); + PATTERN_DECL_NODE(weight_dequantize_linear_op_scale); + PATTERN_DECL_NODE(weight_dequantize_linear_op); + PATTERN_DECL_NODE(weight_dequantize_linear_op_out); + PATTERN_DECL_NODE(any_op2); +}; + +struct DeleteQuantDequantLinearOpPattern : public PatternBase { + DeleteQuantDequantLinearOpPattern(PDPattern* pattern, + const std::string& name_scope) + : PatternBase(pattern, name_scope, + "delete_quant_dequant_linear_op_pattern") {} + + void operator()(); + + PATTERN_DECL_NODE(quantize_linear_op_x); + PATTERN_DECL_NODE(quantize_linear_op_scale); + PATTERN_DECL_NODE(quantize_linear_op); + PATTERN_DECL_NODE(quantize_linear_op_out); + PATTERN_DECL_NODE(dequantize_linear_op); + // PATTERN_DECL_NODE(dequantize_linear_op_scale); // Can not add this node. 
+  // TODO: Wangzheee
+  PATTERN_DECL_NODE(dequantize_linear_op_out);
+  PATTERN_DECL_NODE(any_op2);
+};
+
 // Reshape + Transpose + Matmul
 // named nodes:
 // reshape_op, reshape_out, reshape_xshape,
@@ -1887,8 +1921,6 @@ struct AddSupportInt8 : public PatternBase {
       : PatternBase(pattern, name_scope, "Add_support_int8") {}
 
   PDNode* operator()();
 
-  PATTERN_DECL_NODE(prev_op);
-  PATTERN_DECL_NODE(prev_out);
   PATTERN_DECL_NODE(quant_op);
   PATTERN_DECL_NODE(quant_out);
 };
diff --git a/paddle/fluid/framework/ir/multihead_matmul_fuse_pass.cc b/paddle/fluid/framework/ir/multihead_matmul_fuse_pass.cc
index 989b5460743b0..a8595d55b31b0 100644
--- a/paddle/fluid/framework/ir/multihead_matmul_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/multihead_matmul_fuse_pass.cc
@@ -862,43 +862,30 @@ int MultiHeadMatmulV2FusePass::BuildFusionV2(Graph* graph,
     multihead_op_desc.SetAttr("head_number", head_number);
 
     auto* mul0_op_desc = mul0->Op();
-    auto* mul1_op_desc = mul1->Op();
-    auto* mul2_op_desc = mul2->Op();
-    if (mul0_op_desc->HasAttr("enable_int8")) {
-      multihead_op_desc.SetAttr("enable_int8",
-                                mul0_op_desc->GetAttr("enable_int8"));
-      // all mul op has same input.
+
+    // All mul ops share the same input.
+    if (multihead_op_desc.HasAttr("Input_scale")) {
       multihead_op_desc.SetAttr("Input_scale",
-                                mul0_op_desc->GetAttr("X_scale"));
-      auto weight_scale0 = BOOST_GET_CONST(
-          std::vector<float>, mul0_op_desc->GetAttr("weight_scale"));
-      auto weight_scale1 = BOOST_GET_CONST(
-          std::vector<float>, mul1_op_desc->GetAttr("weight_scale"));
-      auto weight_scale2 = BOOST_GET_CONST(
-          std::vector<float>, mul2_op_desc->GetAttr("weight_scale"));
-      auto weight_max = std::max(weight_scale0, weight_scale1);
-      weight_max = std::max(weight_max, weight_scale2);
-      multihead_op_desc.SetAttr("weight_scale", weight_max);
-
-      auto* add0_op_desc = eltadd0->Op();
-      auto* add1_op_desc = eltadd1->Op();
-      auto* add2_op_desc = eltadd2->Op();
-      if (add0_op_desc->HasAttr("out_threshold")) {
-        auto out_scale0 =
-            BOOST_GET_CONST(float, add0_op_desc->GetAttr("out_threshold"));
-        auto out_scale1 =
-            BOOST_GET_CONST(float, add1_op_desc->GetAttr("out_threshold"));
-        auto out_scale2 =
-            BOOST_GET_CONST(float, add2_op_desc->GetAttr("out_threshold"));
-        auto out_scale_max = std::max(out_scale0, out_scale1);
-        out_scale_max = std::max(out_scale_max, out_scale2);
-        multihead_op_desc.SetAttr("fc_out_threshold", out_scale_max);
-      }
+                                mul0_op_desc->GetAttr("Input_scale"));
+    }
+    auto* add0_op_desc = eltadd0->Op();
+    auto* add1_op_desc = eltadd1->Op();
+    auto* add2_op_desc = eltadd2->Op();
+    if (add0_op_desc->HasAttr("out_threshold")) {
+      auto out_scale0 =
+          BOOST_GET_CONST(float, add0_op_desc->GetAttr("out_threshold"));
+      auto out_scale1 =
+          BOOST_GET_CONST(float, add1_op_desc->GetAttr("out_threshold"));
+      auto out_scale2 =
+          BOOST_GET_CONST(float, add2_op_desc->GetAttr("out_threshold"));
+      auto out_scale_max = std::max(out_scale0, out_scale1);
+      out_scale_max = std::max(out_scale_max, out_scale2);
+      multihead_op_desc.SetAttr("fc_out_threshold", out_scale_max);
     }
 
     auto* softmax_qk_op_desc = softmax_qk->Op();
     auto* matmul_qk_op_desc = matmul_qk->Op();
-    if (matmul_qk_op_desc->HasAttr("X_scale")) {
+    if (matmul_qk_op_desc->HasAttr("Input_scale")) {
       multihead_op_desc.SetAttr("qkv2context_plugin_int8", true);
       if (softmax_qk_op_desc->HasAttr("out_threshold")) {
         auto qkv_plugin_scale = BOOST_GET_CONST(
diff --git a/paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.cc b/paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.cc
index 619fe7ab4f738..281e0b9910619 100644
--- a/paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.cc
@@ -341,7 +341,6 @@ void QuantDequantFusePass::DeleteQuant(ir::Graph* graph, Scope* scope,
     Node* output_scale = subgraph.at(pattern.GetPDNode("output_scale_node"));
     Node* output_act = subgraph.at(pattern.GetPDNode("output_act_node"));
     int bit_length = BOOST_GET_CONST(int, quant->Op()->GetAttr("bit_length"));
-    int range = ((1 << (bit_length - 1)) - 1);
 
     // Get input scale from tensor
     std::string input_scale_var_name = quant->Op()->Input("InScale").front();
@@ -356,7 +355,7 @@ void QuantDequantFusePass::DeleteQuant(ir::Graph* graph, Scope* scope,
                           "Input scale tensor's place should be CPU."));
     const float* input_scale_data = input_scale_tensor.data<float>();
     float in_scale = input_scale_data[0];
-    float scale_value = in_scale / range;
+    float scale_value = in_scale;
 
     // Set input scale in attr, and relink nodes
     std::string input_act_name = input_act->Var()->Name();
@@ -369,11 +368,10 @@ void QuantDequantFusePass::DeleteQuant(ir::Graph* graph, Scope* scope,
           quantized_op_type == "conv2d_fusion" ||
           quantized_op_type == "depthwise_conv2d" ||
           quantized_op_type == "fc" ||
-          quantized_op_type == "conv2d_transpose") {
+          quantized_op_type == "conv2d_transpose" ||
+          quantized_op_type == "mul" || quantized_op_type == "matmul" ||
+          quantized_op_type == "matmul_v2") {
         op_desc->SetAttr("Input_scale", scale_value);
-      } else if (quantized_op_type == "mul" || quantized_op_type == "matmul" ||
-                 quantized_op_type == "matmul_v2") {
-        op_desc->SetAttr("X_scale", scale_value);
       } else {
         PADDLE_THROW(platform::errors::Unimplemented(
             "Unsupported quantized op type %s.", quantized_op_type));
@@ -619,7 +617,6 @@ void QuantDequantFusePass::FuseDequant(ir::Graph* graph, Scope* scope,
       new_op_desc.SetInput("X", {new_input});
       new_op_desc.SetOutput("Out", {new_output});
     }
-    new_op_desc.SetAttr("weight_scale", weight_scale);
     new_op_desc.Flush();
     auto* new_op = graph->CreateOpNode(&new_op_desc);
     IR_NODE_LINK_TO(quantized_op_input_node, new_op);
diff --git a/paddle/fluid/framework/ir/trt_map_matmul_to_mul_pass.cc b/paddle/fluid/framework/ir/trt_map_matmul_to_mul_pass.cc
index 3caaf08dc9cb5..d3211c0841416 100644
--- a/paddle/fluid/framework/ir/trt_map_matmul_to_mul_pass.cc
+++ b/paddle/fluid/framework/ir/trt_map_matmul_to_mul_pass.cc
@@ -297,11 +297,24 @@ void TrtMapMatmul2MulPass::ApplyImpl(ir::Graph* graph) const {
         desc.SetAttr("transpose_Y", matmul_op->Op()->GetAttr("transpose_Y"));
         if (matmul_op->Op()->HasAttr("enable_int8")) {
           desc.SetAttr("enable_int8", matmul_op->Op()->GetAttr("enable_int8"));
-          desc.SetAttr("X_scale", matmul_op->Op()->GetAttr("X_scale"));
-          desc.SetAttr("weight_scale", matmul_op->Op()->GetAttr("weight_scale"));
+          desc.SetAttr("Input_scale", matmul_op->Op()->GetAttr("Input_scale"));
           desc.SetAttr("out_threshold",
                        matmul_op->Op()->GetAttr("out_threshold"));
         }
+
+        bool inscale_flag = false;
+        bool outscale_flag = false;
+
+        if (matmul_op->Op()->HasAttr("X")) {
+          desc.SetAttr("X", matmul_op->Op()->GetAttr("X"));
+          inscale_flag = true;
+        }
+        if (matmul_op->Op()->HasAttr("Out")) {
+          desc.SetAttr("Out", matmul_op->Op()->GetAttr("Out"));
+          outscale_flag = true;
+        }
+        desc.SetAttr("support_int8", inscale_flag && outscale_flag);
+
         auto mul_node = g->CreateOpNode(&desc);
         IR_NODE_LINK_TO(matmul_in_x, mul_node);
         IR_NODE_LINK_TO(matmul_in_y, mul_node);
@@ -370,12 +383,23 @@ void TrtMapMatmulV2ToMulPass::ApplyImpl(ir::Graph* graph) const {
         desc.SetAttr("transpose_Y",
matmul_v2_op->Op()->GetAttr("trans_y")); if (matmul_v2_op->Op()->HasAttr("enable_int8")) { desc.SetAttr("enable_int8", matmul_v2_op->Op()->GetAttr("enable_int8")); - desc.SetAttr("X_scale", matmul_v2_op->Op()->GetAttr("X_scale")); - desc.SetAttr("weight_scale", - matmul_v2_op->Op()->GetAttr("weight_scale")); + desc.SetAttr("Input_scale", matmul_v2_op->Op()->GetAttr("Input_scale")); desc.SetAttr("out_threshold", matmul_v2_op->Op()->GetAttr("out_threshold")); } + + bool inscale_flag = false; + bool outscale_flag = false; + if (matmul_v2_op->Op()->HasAttr("X")) { + desc.SetAttr("X", matmul_v2_op->Op()->GetAttr("X")); + inscale_flag = true; + } + if (matmul_v2_op->Op()->HasAttr("Out")) { + desc.SetAttr("Out", matmul_v2_op->Op()->GetAttr("Out")); + outscale_flag = true; + } + desc.SetAttr("support_int8", inscale_flag && outscale_flag); + auto mul_node = g->CreateOpNode(&desc); IR_NODE_LINK_TO(matmul_v2_in_x, mul_node); IR_NODE_LINK_TO(matmul_v2_in_y, mul_node); @@ -448,11 +472,23 @@ void TrtMapMatmulV2ToMatmulPass::ApplyImpl(ir::Graph* graph) const { } if (matmul_v2_op->Op()->HasAttr("enable_int8")) { desc.SetAttr("enable_int8", matmul_v2_op->Op()->GetAttr("enable_int8")); - desc.SetAttr("X_scale", matmul_v2_op->Op()->GetAttr("X_scale")); - desc.SetAttr("weight_scale", matmul_v2_op->Op()->GetAttr("weight_scale")); + desc.SetAttr("Input_scale", matmul_v2_op->Op()->GetAttr("Input_scale")); desc.SetAttr("out_threshold", matmul_v2_op->Op()->GetAttr("out_threshold")); } + + bool inscale_flag = false; + bool outscale_flag = false; + if (matmul_v2_op->Op()->HasAttr("X")) { + desc.SetAttr("X", matmul_v2_op->Op()->GetAttr("X")); + inscale_flag = true; + } + if (matmul_v2_op->Op()->HasAttr("Out")) { + desc.SetAttr("Out", matmul_v2_op->Op()->GetAttr("Out")); + outscale_flag = true; + } + desc.SetAttr("support_int8", inscale_flag && outscale_flag); + auto matmul_node = g->CreateOpNode(&desc); IR_NODE_LINK_TO(matmul_v2_in_x, matmul_node); IR_NODE_LINK_TO(matmul_v2_in_y, matmul_node); @@ -530,11 +566,24 @@ void TrtSqueeze2MatmulFusePass::ApplyImpl(ir::Graph* graph) const { desc.SetAttr("y_num_col_dims", 1); if (matmul_op->Op()->HasAttr("enable_int8")) { desc.SetAttr("enable_int8", matmul_op->Op()->GetAttr("enable_int8")); - desc.SetAttr("X_scale", matmul_op->Op()->GetAttr("X_scale")); - desc.SetAttr("weight_scale", matmul_op->Op()->GetAttr("weight_scale")); + desc.SetAttr("Input_scale", matmul_op->Op()->GetAttr("Input_scale")); desc.SetAttr("out_threshold", matmul_op->Op()->GetAttr("out_threshold")); } + + bool inscale_flag_x = false; + bool outscale_flag = false; + + if (squeeze2_op->Op()->HasAttr("X")) { + desc.SetAttr("X", squeeze2_op->Op()->GetAttr("X")); + inscale_flag_x = true; + } + if (matmul_op->Op()->HasAttr("Out")) { + desc.SetAttr("Out", matmul_op->Op()->GetAttr("Out")); + outscale_flag = true; + } + desc.SetAttr("support_int8", inscale_flag_x && outscale_flag); + auto mul_node = g->CreateOpNode(&desc); IR_NODE_LINK_TO(squeeze2_in_x, mul_node); IR_NODE_LINK_TO(matmul_in_y, mul_node); @@ -675,11 +724,24 @@ void TrtReshape2MatmulFusePass::ApplyImpl(ir::Graph* graph) const { desc.SetAttr("y_num_col_dims", 1); if (matmul_op->Op()->HasAttr("enable_int8")) { desc.SetAttr("enable_int8", matmul_op->Op()->GetAttr("enable_int8")); - desc.SetAttr("X_scale", matmul_op->Op()->GetAttr("X_scale")); - desc.SetAttr("weight_scale", matmul_op->Op()->GetAttr("weight_scale")); + desc.SetAttr("Input_scale", matmul_op->Op()->GetAttr("Input_scale")); desc.SetAttr("out_threshold", 
matmul_op->Op()->GetAttr("out_threshold")); } + + bool inscale_flag_x = false; + bool outscale_flag = false; + + if (reshape2_op->Op()->HasAttr("X")) { + desc.SetAttr("X", reshape2_op->Op()->GetAttr("X")); + inscale_flag_x = true; + } + if (matmul_op->Op()->HasAttr("Out")) { + desc.SetAttr("Out", matmul_op->Op()->GetAttr("Out")); + outscale_flag = true; + } + desc.SetAttr("support_int8", inscale_flag_x && outscale_flag); + if (!IsCompat(desc)) { LOG(WARNING) << "TrtReshape2MatmulFusePass in out mul op compat failed."; @@ -763,11 +825,24 @@ void TrtFlatten2MatmulFusePass::ApplyImpl(ir::Graph* graph) const { desc.SetAttr("y_num_col_dims", 1); if (matmul_op->Op()->HasAttr("enable_int8")) { desc.SetAttr("enable_int8", matmul_op->Op()->GetAttr("enable_int8")); - desc.SetAttr("X_scale", matmul_op->Op()->GetAttr("X_scale")); - desc.SetAttr("weight_scale", matmul_op->Op()->GetAttr("weight_scale")); + desc.SetAttr("Input_scale", matmul_op->Op()->GetAttr("Input_scale")); desc.SetAttr("out_threshold", matmul_op->Op()->GetAttr("out_threshold")); } + + bool inscale_flag_x = false; + bool outscale_flag = false; + + if (flatten2_op->Op()->HasAttr("X")) { + desc.SetAttr("X", flatten2_op->Op()->GetAttr("X")); + inscale_flag_x = true; + } + if (matmul_op->Op()->HasAttr("Out")) { + desc.SetAttr("Out", matmul_op->Op()->GetAttr("Out")); + outscale_flag = true; + } + desc.SetAttr("support_int8", inscale_flag_x && outscale_flag); + auto mul_node = g->CreateOpNode(&desc); IR_NODE_LINK_TO(flatten2_in_x, mul_node); IR_NODE_LINK_TO(matmul_in_y, mul_node); diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc index 95975d8f2a892..20418e37a7b94 100644 --- a/paddle/fluid/inference/api/paddle_pass_builder.cc +++ b/paddle/fluid/inference/api/paddle_pass_builder.cc @@ -76,10 +76,13 @@ void PaddlePassBuilder::ClearPasses() { passes_.clear(); } const std::vector kTRTSubgraphPasses({ "adaptive_pool2d_convert_global_pass", - "shuffle_channel_detect_pass", // - "quant_conv2d_dequant_fuse_pass", // - "delete_quant_dequant_op_pass", // - "delete_quant_dequant_filter_op_pass", // + "shuffle_channel_detect_pass", // + "quant_conv2d_dequant_fuse_pass", // + "delete_quant_dequant_op_pass", // + "delete_quant_dequant_filter_op_pass", // + "delete_weight_dequant_linear_op_pass", // + "delete_quant_dequant_linear_op_pass", // + "add_support_int8_pass", // // "fc_fuse_pass", // "simplify_with_basic_ops_pass", // "embedding_eltwise_layernorm_fuse_pass", // @@ -98,9 +101,8 @@ const std::vector kTRTSubgraphPasses({ "trt_map_matmul_to_mul_pass", // "fc_fuse_pass", // "conv_elementwise_add_fuse_pass", // - "add_support_int8_pass", - "tensorrt_subgraph_pass", // - "conv_bn_fuse_pass", // + "tensorrt_subgraph_pass", // + "conv_bn_fuse_pass", // #if CUDNN_VERSION >= 7100 // To run conv_fusion, the version of cudnn must be // guaranteed at least v7 // cudnn8.0 has memory leak problem in conv + eltwise + act, so we diff --git a/paddle/fluid/inference/tensorrt/convert/activation_op.cc b/paddle/fluid/inference/tensorrt/convert/activation_op.cc index e6a0ecf4aecec..b86351e394bd1 100644 --- a/paddle/fluid/inference/tensorrt/convert/activation_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/activation_op.cc @@ -68,12 +68,6 @@ class ActivationOpConverter : public OpConverter { auto output_name = op_desc.Output("Out")[0]; RreplenishLayerAndOutput(layer, op_type_, {output_name}, test_mode); - if (op_desc.HasAttr("out_scale")) { -#if IS_TRT_VERSION_GE(5130) - float out_scale = 
BOOST_GET_CONST(float, op_desc.GetAttr("out_scale")); - engine_->SetTensorDynamicRange(layer->getOutput(0), out_scale); -#endif - } } protected: diff --git a/paddle/fluid/inference/tensorrt/convert/affine_channel_op.cc b/paddle/fluid/inference/tensorrt/convert/affine_channel_op.cc index eba67c3c098ca..cc06f82ae3901 100644 --- a/paddle/fluid/inference/tensorrt/convert/affine_channel_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/affine_channel_op.cc @@ -49,11 +49,11 @@ class AffineChannelOpConverter : public OpConverter { auto* scale_v = scope.FindVar(scale_name); auto* scale_t = scale_v->GetMutable<framework::LoDTensor>(); - float* scale_ptr = engine_->GetWeightCPUData(scale_name, scale_t, false); + float* scale_ptr = engine_->GetWeightCPUData(scale_name, scale_t); auto* bias_v = scope.FindVar(bias_name); auto* bias_t = bias_v->GetMutable<framework::LoDTensor>(); - float* bias_ptr = engine_->GetWeightCPUData(bias_name, bias_t, false); + float* bias_ptr = engine_->GetWeightCPUData(bias_name, bias_t); // tensorrt scale layer only supports spatial dims >= 2, // so nhwc is not available (spatial dims == 0) diff --git a/paddle/fluid/inference/tensorrt/convert/conv2d_op.cc b/paddle/fluid/inference/tensorrt/convert/conv2d_op.cc index a296a2641db65..1b2abeac6c19f 100644 --- a/paddle/fluid/inference/tensorrt/convert/conv2d_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/conv2d_op.cc @@ -49,18 +49,11 @@ void ConvertConv2d(TensorRTEngine* engine, const framework::proto::OpDesc& op, if (enable_int8) { #if IS_TRT_VERSION_GE(5000) - float in_scale = - BOOST_GET_CONST(float, op_desc.GetAttr("Input_scale")) * 127; - auto weight_scale = - BOOST_GET_CONST(std::vector<float>, op_desc.GetAttr("weight_scale")); - weight_data = engine->GetWeightCPUData(op_desc.Input("Filter").front(), Y_t, - true, weight_scale); + float in_scale = BOOST_GET_CONST(float, op_desc.GetAttr("Input_scale")); engine->SetTensorDynamicRange(X, in_scale); #endif - } else { - weight_data = - engine->GetWeightCPUData(op_desc.Input("Filter").front(), Y_t, false); } + weight_data = engine->GetWeightCPUData(op_desc.Input("Filter").front(), Y_t); PADDLE_ENFORCE_EQ(Y_t->dims().size(), 4UL, platform::errors::InvalidArgument( @@ -115,7 +108,7 @@ void ConvertConv2d(TensorRTEngine* engine, const framework::proto::OpDesc& op, auto* bias_tensor = scope.GetVar(op_desc.Input("Bias").front()); auto* bias_tensor_data = bias_tensor->GetMutable<framework::LoDTensor>(); bias_data = engine->GetWeightCPUData(op_desc.Input("Bias").front(), - bias_tensor_data, false); + bias_tensor_data); bias_size = static_cast<int>(bias_tensor_data->numel()); } diff --git a/paddle/fluid/inference/tensorrt/convert/conv3d_op.cc b/paddle/fluid/inference/tensorrt/convert/conv3d_op.cc index dae92264d2c3e..dbb2786ed78ab 100644 --- a/paddle/fluid/inference/tensorrt/convert/conv3d_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/conv3d_op.cc @@ -48,17 +48,10 @@ void ConvertConv3d(TensorRTEngine* engine, const framework::proto::OpDesc& op, bool enable_int8 = op_desc.HasAttr("enable_int8"); if (enable_int8) { - float in_scale = - BOOST_GET_CONST(float, op_desc.GetAttr("Input_scale")) * 127; - auto weight_scale = - BOOST_GET_CONST(std::vector<float>, op_desc.GetAttr("weight_scale")); - weight_data = engine->GetWeightCPUData(op_desc.Input("Filter").front(), Y_t, - true, weight_scale); + float in_scale = BOOST_GET_CONST(float, op_desc.GetAttr("Input_scale")); engine->SetTensorDynamicRange(X, in_scale); - } else { - weight_data = - engine->GetWeightCPUData(op_desc.Input("Filter").front(), Y_t, false); } + weight_data =
engine->GetWeightCPUData(op_desc.Input("Filter").front(), Y_t); PADDLE_ENFORCE_EQ(Y_t->dims().size(), 5UL, platform::errors::InvalidArgument( diff --git a/paddle/fluid/inference/tensorrt/convert/deformable_conv_op.cc b/paddle/fluid/inference/tensorrt/convert/deformable_conv_op.cc index d8534a4183bdd..2bbe6ea3d2fa8 100644 --- a/paddle/fluid/inference/tensorrt/convert/deformable_conv_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/deformable_conv_op.cc @@ -47,8 +47,7 @@ class DeformableConvOpConverter : public OpConverter { auto* filter_var = scope.FindVar(filter_name); auto* filter_tensor = filter_var->GetMutable(); - float* filter_data = - engine_->GetWeightCPUData(filter_name, filter_tensor, false); + float* filter_data = engine_->GetWeightCPUData(filter_name, filter_tensor); const int c_o = filter_tensor->dims()[0]; const int c_i = filter_tensor->dims()[1]; diff --git a/paddle/fluid/inference/tensorrt/convert/elementwise_op.cc b/paddle/fluid/inference/tensorrt/convert/elementwise_op.cc index a66a97b4be9da..8fd0e1bbd068d 100644 --- a/paddle/fluid/inference/tensorrt/convert/elementwise_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/elementwise_op.cc @@ -51,8 +51,7 @@ class ElementwiseWeightOpConverter : public OpConverter { auto* Y_t = Y_v->GetMutable(); float* weight_data = nullptr; auto output_name = op_desc.Output("Out")[0]; - weight_data = - engine_->GetWeightCPUData(op_desc.Input("Y").front(), Y_t, false); + weight_data = engine_->GetWeightCPUData(op_desc.Input("Y").front(), Y_t); nvinfer1::Dims dims_x = X->getDimensions(); auto regist_eltwise_weight = [&](nvinfer1::ScaleMode scale_mode) { @@ -112,13 +111,6 @@ class ElementwiseWeightOpConverter : public OpConverter { RreplenishLayerAndOutput(layer, "elementwise_" + op_type_, {output_name}, test_mode); } - if (op_desc.HasAttr("enable_int8")) { -#if IS_TRT_VERSION_GE(5000) - CHECK(op_desc.HasAttr("X_scale")); - float x_scale = BOOST_GET_CONST(float, op_desc.GetAttr("X_scale")); - engine_->SetTensorDynamicRange(X, x_scale); -#endif - } }; if (engine_->with_dynamic_shape()) { @@ -222,16 +214,6 @@ class ElementwiseTensorOpConverter : public OpConverter { auto common_func = [&](nvinfer1::ILayer* layer) { RreplenishLayerAndOutput(layer, "elementwise", {output_name}, test_mode); - if (op_desc.HasAttr("enable_int8")) { -#if IS_TRT_VERSION_GE(5000) - CHECK(op_desc.HasAttr("X_scale")); - CHECK(op_desc.HasAttr("Y_scale")); - float x_scale = BOOST_GET_CONST(float, op_desc.GetAttr("X_scale")); - float y_scale = BOOST_GET_CONST(float, op_desc.GetAttr("Y_scale")); - engine_->SetTensorDynamicRange(X, x_scale); - engine_->SetTensorDynamicRange(Y, y_scale); -#endif - } }; if (dims_x.nbDims == dims_y.nbDims) { diff --git a/paddle/fluid/inference/tensorrt/convert/emb_eltwise_layernorm.cc b/paddle/fluid/inference/tensorrt/convert/emb_eltwise_layernorm.cc index 9741aab32dea5..7a494860e6fa1 100644 --- a/paddle/fluid/inference/tensorrt/convert/emb_eltwise_layernorm.cc +++ b/paddle/fluid/inference/tensorrt/convert/emb_eltwise_layernorm.cc @@ -77,7 +77,7 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter { auto* temp_tensor = temp_var->GetMutable(); (*dims) = temp_tensor->dims(); - auto* temp_data = engine_->GetWeightCPUData(var_name, temp_tensor, false); + auto* temp_data = engine_->GetWeightCPUData(var_name, temp_tensor); return temp_data; }; diff --git a/paddle/fluid/inference/tensorrt/convert/fc_op.cc b/paddle/fluid/inference/tensorrt/convert/fc_op.cc index bdea14c9e9f89..a631332dae360 100644 --- 
a/paddle/fluid/inference/tensorrt/convert/fc_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/fc_op.cc @@ -113,22 +113,20 @@ class FcOpConverter : public OpConverter { // assigned from CPU memory, which can't be avoided. float* weight_data = nullptr; bool enable_int8 = op_desc.HasAttr("enable_int8"); - float in_scale = 0.; - if (enable_int8) { -#if IS_TRT_VERSION_GE(5000) - CHECK(op_desc.HasAttr(i_name + "_scale")); - in_scale = - BOOST_GET_CONST(float, op_desc.GetAttr(i_name + "_scale")) * 127; - auto weight_scale = - BOOST_GET_CONST(std::vector, op_desc.GetAttr("weight_scale")); - weight_data = engine_->GetWeightCPUData(op_desc.Input(w_name).front(), - Y_t, true, weight_scale); + bool support_int8 = false; + if (op_desc.HasAttr("support_int8")) { + support_int8 = BOOST_GET_CONST(bool, op_desc.GetAttr("support_int8")); + } + float in_scale = 0; + if (enable_int8 || support_int8) { + if (enable_int8) { + in_scale = BOOST_GET_CONST(float, op_desc.GetAttr("Input_scale")); + } else { + in_scale = BOOST_GET_CONST(float, op_desc.GetAttr("X")); + } engine_->SetTensorDynamicRange(X, in_scale); -#endif - } else { - weight_data = - engine_->GetWeightCPUData(op_desc.Input(w_name).front(), Y_t, false); } + weight_data = engine_->GetWeightCPUData(op_desc.Input(w_name).front(), Y_t); PADDLE_ENFORCE_EQ(Y_t->dims().size(), 2UL, platform::errors::InvalidArgument( @@ -148,14 +146,18 @@ class FcOpConverter : public OpConverter { auto regist_fc = [&](nvinfer1::ITensor* inputs, int n_output, TensorRTEngine::Weight& weight, TensorRTEngine::Weight& bias) { - if (enable_int8) { + if (enable_int8 || support_int8) { // add conv layer - PADDLE_ENFORCE_EQ( - op_desc.HasAttr("out_threshold"), true, - platform::errors::InvalidArgument( - "must have out threshold in fc layers in int8 mode")); - float out_scale = - BOOST_GET_CONST(float, op_desc.GetAttr("out_threshold")); + float out_scale = 0; + if (enable_int8) { + PADDLE_ENFORCE_EQ( + op_desc.HasAttr("out_threshold"), true, + platform::errors::InvalidArgument( + "must have out threshold in fc layers in int8 mode")); + out_scale = BOOST_GET_CONST(float, op_desc.GetAttr("out_threshold")); + } else { + out_scale = BOOST_GET_CONST(float, op_desc.GetAttr("Out")); + } nvinfer1::DimsHW nv_ksize(1, 1); auto* fc_layer_int8 = TRT_ENGINE_ADD_LAYER(engine_, Convolution, *inputs, n_output, @@ -235,8 +237,7 @@ class FcOpConverter : public OpConverter { if (with_bias) { auto* b_v = scope.GetVar(op_desc.Input("Bias").front()); auto* b_t = b_v->GetMutable(); - bias_data = - engine_->GetWeightCPUData(op_desc.Input("Bias").front(), b_t, false); + bias_data = engine_->GetWeightCPUData(op_desc.Input("Bias").front(), b_t); bias_num = b_t->numel(); } TensorRTEngine::Weight bias{nvinfer1::DataType::kFLOAT, @@ -251,7 +252,7 @@ class FcOpConverter : public OpConverter { // not add Shuffle layer in ernie's multihead. 
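// A minimal sketch (not lines from this patch) of the scale-lookup convention
// the converters in this series now share; op_desc, engine_ and X mirror the
// surrounding converter code:
//
//   float in_scale = 0.f;
//   if (op_desc.HasAttr("Input_scale")) {  // set by quant_conv2d_dequant_fuse_pass
//     in_scale = BOOST_GET_CONST(float, op_desc.GetAttr("Input_scale"));
//   } else if (op_desc.HasAttr("X")) {     // set by the quantize_linear passes
//     in_scale = BOOST_GET_CONST(float, op_desc.GetAttr("X"));
//   }
//   engine_->SetTensorDynamicRange(X, in_scale);  // int8 range, no calibration needed
//
// Weight dequantization moved out of the converters entirely, which is why
// GetWeightCPUData() loses its enable_int8/weight-scale parameters throughout.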
if (engine_->use_oss() && engine_->with_ernie() && x_dim.nbDims == 4 && x_dim.d[3] == 1 && x_num_col_dims == 2) { - if (enable_int8) { + if (enable_int8 || support_int8) { // add conv1x1 layer nvinfer1::DimsHW nv_ksize(1, 1); auto* fc_layer_int8 = @@ -265,8 +266,13 @@ class FcOpConverter : public OpConverter { op_desc.HasAttr("out_threshold"), true, platform::errors::InvalidArgument( "must have out threshold in fc layers in int8 mode")); - float out_scale = - BOOST_GET_CONST(float, op_desc.GetAttr("out_threshold")); + float out_scale = 0; + if (enable_int8) { + out_scale = + BOOST_GET_CONST(float, op_desc.GetAttr("out_threshold")); + } else { + out_scale = BOOST_GET_CONST(float, op_desc.GetAttr("Out")); + } engine_->SetTensorDynamicRange(fc_layer_int8->getOutput(0), out_scale); nvinfer1::IActivationLayer* relu_layer_int8 = TRT_ENGINE_ADD_LAYER( @@ -308,7 +314,7 @@ class FcOpConverter : public OpConverter { auto* reshape_before_fc_layer = reshape_before_fc(X, x_dim, x_num_col_dims, output_name); auto* reshape_itensor = reshape_before_fc_layer->getOutput(0); - if (enable_int8) { + if (enable_int8 || support_int8) { engine_->SetTensorDynamicRange(reshape_itensor, in_scale); } regist_fc(reshape_itensor, n_output, weight, bias); diff --git a/paddle/fluid/inference/tensorrt/convert/group_norm_op.cc b/paddle/fluid/inference/tensorrt/convert/group_norm_op.cc index b3c1f986aa030..910a807d3626a 100644 --- a/paddle/fluid/inference/tensorrt/convert/group_norm_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/group_norm_op.cc @@ -48,7 +48,7 @@ class GroupNormOpConverter : public OpConverter { auto* temp_tensor = temp_var->GetMutable(); (*dims) = temp_tensor->dims(); - auto* temp_data = engine_->GetWeightCPUData(var_name, temp_tensor, false); + auto* temp_data = engine_->GetWeightCPUData(var_name, temp_tensor); return temp_data; }; diff --git a/paddle/fluid/inference/tensorrt/convert/leaky_relu_op.cc b/paddle/fluid/inference/tensorrt/convert/leaky_relu_op.cc index c6dbfc832201b..c7a551b7436db 100644 --- a/paddle/fluid/inference/tensorrt/convert/leaky_relu_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/leaky_relu_op.cc @@ -49,8 +49,8 @@ class LeakyReluOpConverter : public OpConverter { bool enable_int8 = op_desc.HasAttr("enable_int8"); if (enable_int8) { - CHECK(op_desc.HasAttr("X_scale")); - float in_scale = BOOST_GET_CONST(float, op_desc.GetAttr("X_scale")); + CHECK(op_desc.HasAttr("Input_scale")); + float in_scale = BOOST_GET_CONST(float, op_desc.GetAttr("Input_scale")); engine_->SetTensorDynamicRange(input, in_scale); } #else diff --git a/paddle/fluid/inference/tensorrt/convert/matmul_op.cc b/paddle/fluid/inference/tensorrt/convert/matmul_op.cc index b2e76b9a0e61b..7568f67d64d04 100644 --- a/paddle/fluid/inference/tensorrt/convert/matmul_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/matmul_op.cc @@ -64,7 +64,9 @@ class MatMulOpConverter : public OpConverter { : nvinfer1::MatrixOperation::kNONE; if (op_desc.HasAttr("support_int8") && - engine_->precision() == AnalysisConfig::Precision::kInt8) { + BOOST_GET_CONST(bool, op_desc.GetAttr("support_int8")) && + engine_->precision() == AnalysisConfig::Precision::kInt8 && + platform::GetGPUComputeCapability(0) >= 75) { if (engine_->with_dynamic_shape()) { VLOG(3) << "Convert a fluid matmul_op_int8_dynamic to TensorRT " "MatmulPluginLayer"; diff --git a/paddle/fluid/inference/tensorrt/convert/multihead_matmul_op.cc b/paddle/fluid/inference/tensorrt/convert/multihead_matmul_op.cc index f19b21d3e6326..21c79f0edd27f 100644 --- 
a/paddle/fluid/inference/tensorrt/convert/multihead_matmul_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/multihead_matmul_op.cc @@ -40,22 +40,16 @@ class MultiheadMatMulOpConverter : public OpConverter { auto* bias_t = bias_v->GetMutable<framework::LoDTensor>(); float* weight_data = nullptr; - bool enable_int8 = op_desc.HasAttr("enable_int8"); bool qkv2context_plugin_int8 = op_desc.HasAttr("qkv2context_plugin_int8"); float in_scale = 0.; - if (enable_int8) { - in_scale = BOOST_GET_CONST(float, op_desc.GetAttr("Input_scale")) * 127; - auto weight_scale = - BOOST_GET_CONST(std::vector<float>, op_desc.GetAttr("weight_scale")); - weight_data = - engine_->GetWeightCPUData(weight_name, weight_t, true, weight_scale); + if (op_desc.HasAttr("Input_scale")) { + in_scale = BOOST_GET_CONST(float, op_desc.GetAttr("Input_scale")); engine_->SetTensorDynamicRange(input, in_scale); - } else { - weight_data = engine_->GetWeightCPUData(weight_name, weight_t, false); } + weight_data = engine_->GetWeightCPUData(weight_name, weight_t); - float* bias_data = engine_->GetWeightCPUData(bias_name, bias_t, false); + float* bias_data = engine_->GetWeightCPUData(bias_name, bias_t); std::vector<float> weight_data_tmp; weight_data_tmp.reserve(weight_t->numel()); memcpy(weight_data_tmp.data(), weight_data, @@ -85,6 +79,10 @@ class MultiheadMatMulOpConverter : public OpConverter { if (engine_->with_dynamic_shape()) { if (engine_->use_oss()) { + if (engine_->precision() == AnalysisConfig::Precision::kFloat32) { + PADDLE_THROW(platform::errors::Fatal( + "when use_oss is enabled, precision must be int8 or half, not float32.")); + } nvinfer1::Weights weight{nvinfer1::DataType::kFLOAT, static_cast<void*>(weight_data), static_cast<size_t>(weight_t->numel())}; nvinfer1::Weights bias{nvinfer1::DataType::kFLOAT, static_cast<void*>(bias_data), static_cast<size_t>(bias_t->numel())}; if (engine_->with_interleaved()) { VLOG(4) << "fused multihead_matmul op: use_oss and with_interleaved"; - if (!enable_int8) { + if (!op_desc.HasAttr("Input_scale")) { PADDLE_THROW( platform::errors::Fatal("with_interleaved requires int8 precision.")); } @@ -213,7 +211,7 @@ class MultiheadMatMulOpConverter : public OpConverter { nvinfer1::ILayer* fc_layer = nullptr; float dp_probs = 1.0 / 127.0; - if (enable_int8) { + if (op_desc.HasAttr("Input_scale")) { nvinfer1::DimsHW nv_ksize(1, 1); fc_layer = TRT_ENGINE_ADD_LAYER(engine_, Convolution, *input, n, nv_ksize, weight, bias); @@ -222,7 +220,7 @@ class MultiheadMatMulOpConverter : public OpConverter { weight, bias); } - if (enable_int8) { + if (op_desc.HasAttr("fc_out_threshold")) { PADDLE_ENFORCE_EQ(op_desc.HasAttr("fc_out_threshold"), true, platform::errors::InvalidArgument( "must have out threshold in multihead layers " @@ -241,14 +239,10 @@ class MultiheadMatMulOpConverter : public OpConverter { auto creator = GetPluginRegistry()->getPluginCreator( "CustomQKVToContextPluginDynamic", "2"); assert(creator != nullptr); - int type = static_cast<int>((engine_->WithFp16() == 1) - ?
nvinfer1::DataType::kHALF - : nvinfer1::DataType::kFLOAT); - if (enable_int8) { - type = static_cast<int>(nvinfer1::DataType::kHALF); - if (qkv2context_plugin_int8) { - type = static_cast<int>(nvinfer1::DataType::kINT8); - } + int type = static_cast<int>(nvinfer1::DataType::kHALF); + if (qkv2context_plugin_int8 && + (engine_->precision() == AnalysisConfig::Precision::kInt8)) { + type = static_cast<int>(nvinfer1::DataType::kINT8); } bool has_mask = true; int var_seqlen = 1; @@ -335,7 +329,7 @@ class MultiheadMatMulOpConverter : public OpConverter { reshape_before_fc_dim.d[4] = 1; auto* reshape_before_fc_layer = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *input); - if (enable_int8) { + if (op_desc.HasAttr("Input_scale")) { engine_->SetTensorDynamicRange(reshape_before_fc_layer->getOutput(0), in_scale); } @@ -346,7 +340,7 @@ class MultiheadMatMulOpConverter : public OpConverter { // add layer fc nvinfer1::ILayer* fc_layer = nullptr; - if (enable_int8) { + if (op_desc.HasAttr("Input_scale")) { nvinfer1::DimsHW nv_ksize(1, 1); fc_layer = TRT_ENGINE_ADD_LAYER( engine_, Convolution, *reshape_before_fc_layer->getOutput(0), n, @@ -357,7 +351,7 @@ class MultiheadMatMulOpConverter : public OpConverter { n, weight.get(), bias.get()); } - if (enable_int8) { + if (op_desc.HasAttr("fc_out_threshold")) { PADDLE_ENFORCE_EQ( op_desc.HasAttr("fc_out_threshold"), true, platform::errors::InvalidArgument( @@ -382,8 +376,8 @@ class MultiheadMatMulOpConverter : public OpConverter { bool with_fp16 = engine_->WithFp16() && !engine_->disable_trt_plugin_fp16(); - if (enable_int8) { - with_fp16 = 1; + if (engine_->precision() == AnalysisConfig::Precision::kInt8) { + with_fp16 = true; } plugin::DynamicPluginTensorRT* plugin = new plugin::QkvToContextPluginDynamic(hidden_in, head_number, diff --git a/paddle/fluid/inference/tensorrt/convert/op_converter.h b/paddle/fluid/inference/tensorrt/convert/op_converter.h index 7e0c8bf1da177..f7eb7f859afaa 100644 --- a/paddle/fluid/inference/tensorrt/convert/op_converter.h +++ b/paddle/fluid/inference/tensorrt/convert/op_converter.h @@ -145,42 +145,68 @@ class OpConverter { (*it)(op, scope, test_mode); size_t output_num = op_desc.OutputNames().size(); - if (output_num == 1) { // The number of output is 1 - if (op_desc.HasAttr("out_threshold")) { - float out_scale = - BOOST_GET_CONST(float, op_desc.GetAttr("out_threshold")); - std::string output_name = ""; - if (op_desc.HasOutput("Output")) { - output_name = op_desc.Output("Output").front(); - } else if (op_desc.HasOutput("Out")) { - output_name = op_desc.Output("Out").front(); - } else if (op_desc.HasOutput("Y")) { - output_name = op_desc.Output("Y").front(); - } else { - PADDLE_THROW( - platform::errors::NotFound("Op %s has out threshold but doesn't " - "have an output named \"Output\", " - "\"Out\" or \"Y\".", - op_desc.Type())); - } + // single-output case: set the output tensor's dynamic range + if (op_desc.HasAttr("out_threshold")) { + float out_scale = + BOOST_GET_CONST(float, op_desc.GetAttr("out_threshold")); + std::string output_name = ""; + if (op_desc.HasOutput("Output")) { + output_name = op_desc.Output("Output").front(); + } else if (op_desc.HasOutput("Out")) { + output_name = op_desc.Output("Out").front(); + } else if (op_desc.HasOutput("Y")) { + output_name = op_desc.Output("Y").front(); + } else { + PADDLE_THROW( + platform::errors::NotFound("Op %s has out threshold but doesn't " + "have an output named \"Output\", " + "\"Out\" or \"Y\".", + op_desc.Type())); + } + auto* output_itensor = engine->GetITensor(output_name); +
engine->SetTensorDynamicRange(output_itensor, out_scale); + VLOG(1) << "Set out scale = " << out_scale << " for tensor " + << output_name << "."; + } + // multi-output case: set each output tensor's dynamic range + for (size_t i = 0; i < output_num; ++i) { + if (op_desc.HasAttr("out_" + std::to_string(i) + "_threshold")) { + float out_scale = BOOST_GET_CONST( + float, op_desc.GetAttr("out_" + std::to_string(i) + "_threshold")); + std::string output_name = + op_desc.Output(op_desc.OutputNames()[i]).front(); auto* output_itensor = engine->GetITensor(output_name); engine->SetTensorDynamicRange(output_itensor, out_scale); VLOG(1) << "Set out scale = " << out_scale << " for tensor " << output_name << "."; } - } else if (output_num > 1) { // The number of outputs greater than 1 - for (size_t i = 0; i < output_num; ++i) { - if (op_desc.HasAttr("out_" + std::to_string(i) + "_threshold")) { - float out_scale = BOOST_GET_CONST( - float, - op_desc.GetAttr("out_" + std::to_string(i) + "_threshold")); - std::string output_name = - op_desc.Output(op_desc.OutputNames()[i]).front(); - auto* output_itensor = engine->GetITensor(output_name); - engine->SetTensorDynamicRange(output_itensor, out_scale); - VLOG(1) << "Set out scale = " << out_scale << " for tensor " - << output_name << "."; - } + } + + // quant_dequant_linear support for Paddle-TRT + + std::vector<std::string> inputs_name = op_desc.InputNames(); + std::vector<std::string> outputs_name = op_desc.OutputNames(); + + for (size_t i = 0; i < inputs_name.size(); i++) { + if (op_desc.HasAttr(inputs_name[i])) { + std::string input_tensor_name = op_desc.Input(inputs_name[i])[0]; + auto* input_itensor = engine->GetITensor(input_tensor_name); + float input_scale = + BOOST_GET_CONST(float, op_desc.GetAttr(inputs_name[i])); + engine->SetTensorDynamicRange(input_itensor, input_scale); + VLOG(1) << "Set input tensor scale = " << input_scale + << " for tensor: " << input_tensor_name << "."; + } + } + for (size_t i = 0; i < outputs_name.size(); i++) { + if (op_desc.HasAttr(outputs_name[i])) { + std::string output_tensor_name = op_desc.Output(outputs_name[i])[0]; + auto* output_itensor = engine->GetITensor(output_tensor_name); + float output_scale = + BOOST_GET_CONST(float, op_desc.GetAttr(outputs_name[i])); + engine->SetTensorDynamicRange(output_itensor, output_scale); + VLOG(1) << "Set output tensor scale = " << output_scale + << " for tensor: " << output_tensor_name << "."; } } } diff --git a/paddle/fluid/inference/tensorrt/convert/pool2d_op.cc b/paddle/fluid/inference/tensorrt/convert/pool2d_op.cc index 7b65d2d7c97cc..7824a0f1e29f4 100644 --- a/paddle/fluid/inference/tensorrt/convert/pool2d_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/pool2d_op.cc @@ -132,11 +132,10 @@ class Pool2dOpConverter : public OpConverter { } if (op_desc.HasAttr("enable_int8")) { -#if IS_TRT_VERSION_GE(5000) - CHECK(op_desc.HasAttr("X_scale")); - float input_scale = BOOST_GET_CONST(float, op_desc.GetAttr("X_scale")); + CHECK(op_desc.HasAttr("Input_scale")); + float input_scale = + BOOST_GET_CONST(float, op_desc.GetAttr("Input_scale")); engine_->SetTensorDynamicRange(input1, input_scale); -#endif } std::vector<int> real_paddings = paddings; diff --git a/paddle/fluid/inference/tensorrt/convert/pool3d_op.cc b/paddle/fluid/inference/tensorrt/convert/pool3d_op.cc index 5a306f622adbe..665bf9c8d22ed 100644 --- a/paddle/fluid/inference/tensorrt/convert/pool3d_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/pool3d_op.cc @@ -123,8 +123,9 @@ class Pool3dOpConverter : public OpConverter { nvinfer1::Dims3 nv_paddings(paddings[0], paddings[1],
paddings[2]); nvinfer1::ILayer *layer = nullptr; if (op_desc.HasAttr("enable_int8")) { - CHECK(op_desc.HasAttr("X_scale")); - float input_scale = BOOST_GET_CONST(float, op_desc.GetAttr("X_scale")); + CHECK(op_desc.HasAttr("Input_scale")); + float input_scale = + BOOST_GET_CONST(float, op_desc.GetAttr("Input_scale")); engine_->SetTensorDynamicRange(input1, input_scale); } diff --git a/paddle/fluid/inference/tensorrt/convert/preln_emb_eltwise_layernorm.cc b/paddle/fluid/inference/tensorrt/convert/preln_emb_eltwise_layernorm.cc index daa3b186ab4c4..87fdbb71a3faf 100644 --- a/paddle/fluid/inference/tensorrt/convert/preln_emb_eltwise_layernorm.cc +++ b/paddle/fluid/inference/tensorrt/convert/preln_emb_eltwise_layernorm.cc @@ -70,7 +70,7 @@ class PrelnEmbEltwiseLayerNormOpConverter : public OpConverter { auto* temp_tensor = temp_var->GetMutable(); (*dims) = temp_tensor->dims(); - auto* temp_data = engine_->GetWeightCPUData(var_name, temp_tensor, false); + auto* temp_data = engine_->GetWeightCPUData(var_name, temp_tensor); return temp_data; }; diff --git a/paddle/fluid/inference/tensorrt/convert/preln_skip_layernorm.cc b/paddle/fluid/inference/tensorrt/convert/preln_skip_layernorm.cc index d9eca65fc45dc..8053135cc452c 100644 --- a/paddle/fluid/inference/tensorrt/convert/preln_skip_layernorm.cc +++ b/paddle/fluid/inference/tensorrt/convert/preln_skip_layernorm.cc @@ -48,7 +48,7 @@ class PrelnSkipLayerNormOpConverter : public OpConverter { auto* temp_tensor = temp_var->GetMutable(); (*dims) = temp_tensor->dims(); - auto* temp_data = engine_->GetWeightCPUData(var_name, temp_tensor, false); + auto* temp_data = engine_->GetWeightCPUData(var_name, temp_tensor); return temp_data; }; diff --git a/paddle/fluid/inference/tensorrt/convert/prelu_op.cc b/paddle/fluid/inference/tensorrt/convert/prelu_op.cc index 9e81d1177cfe1..d5b5d9bc81b6a 100644 --- a/paddle/fluid/inference/tensorrt/convert/prelu_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/prelu_op.cc @@ -57,8 +57,8 @@ class PReluOpConverter : public OpConverter { layer = engine_->AddDynamicPlugin(&input, input_num, plugin); } else { #if IS_TRT_VERSION_GE(7000) - float* alpha_weight_data = engine_->GetWeightCPUData( - op_desc.Input("Alpha")[0], alpha_tensor, false); + float* alpha_weight_data = + engine_->GetWeightCPUData(op_desc.Input("Alpha")[0], alpha_tensor); TensorRTEngine::Weight alpha_weight{ nvinfer1::DataType::kFLOAT, static_cast(alpha_weight_data), static_cast(alpha_tensor->numel())}; diff --git a/paddle/fluid/inference/tensorrt/convert/skip_layernorm.cc b/paddle/fluid/inference/tensorrt/convert/skip_layernorm.cc index 753cd70727643..831e117311771 100644 --- a/paddle/fluid/inference/tensorrt/convert/skip_layernorm.cc +++ b/paddle/fluid/inference/tensorrt/convert/skip_layernorm.cc @@ -40,7 +40,7 @@ class SkipLayerNormOpConverter : public OpConverter { auto* temp_tensor = temp_var->GetMutable(); (*dims) = temp_tensor->dims(); - auto* temp_data = engine_->GetWeightCPUData(var_name, temp_tensor, false); + auto* temp_data = engine_->GetWeightCPUData(var_name, temp_tensor); return temp_data; }; diff --git a/paddle/fluid/inference/tensorrt/engine.cc b/paddle/fluid/inference/tensorrt/engine.cc index 794475dfc10ca..33386c746ae5a 100644 --- a/paddle/fluid/inference/tensorrt/engine.cc +++ b/paddle/fluid/inference/tensorrt/engine.cc @@ -356,9 +356,7 @@ void TensorRTEngine::SetRuntimeBatch(size_t batch_size) { } float *TensorRTEngine::GetWeightCPUData(const std::string &name, - framework::Tensor *weight_tensor, - bool enable_int8, - const std::vector 
&scale) { + framework::Tensor *weight_tensor) { static int name_suffix_counter = 0; std::string name_suffix = std::to_string(name_suffix_counter); std::string splitter = "__"; diff --git a/paddle/fluid/inference/tensorrt/engine.h b/paddle/fluid/inference/tensorrt/engine.h index d53a8923af612..f781cd0cb3a8d 100644 --- a/paddle/fluid/inference/tensorrt/engine.h +++ b/paddle/fluid/inference/tensorrt/engine.h @@ -389,8 +389,7 @@ class TensorRTEngine { } float* GetWeightCPUData(const std::string& name, - framework::Tensor* weight_tensor, bool enable_int8, - const std::vector& scale = {}); + framework::Tensor* weight_tensor); // A pointer to CPU memory is needed of the TRT weight. // Before TRT runs, fluid loads weight into GPU storage. diff --git a/paddle/fluid/operators/compat/dequantize_linear.pbtxt b/paddle/fluid/operators/compat/dequantize_linear.pbtxt new file mode 100644 index 0000000000000..73b61f8bc29fb --- /dev/null +++ b/paddle/fluid/operators/compat/dequantize_linear.pbtxt @@ -0,0 +1,25 @@ +type: "dequantize_linear" +def { + inputs { + name: "X" + } + inputs { + name: "Scale" + } + inputs { + name: "ZeroPoint" + } + outputs { + name: "Y" + } + attrs { + name: "bit_length" + type: INT + } + attrs { + name: "quant_axis" + type: INT + } +} +extra { +} diff --git a/paddle/fluid/operators/compat/mul.pbtxt b/paddle/fluid/operators/compat/mul.pbtxt index 617775eaaae9e..056f799c6c49c 100644 --- a/paddle/fluid/operators/compat/mul.pbtxt +++ b/paddle/fluid/operators/compat/mul.pbtxt @@ -60,15 +60,7 @@ extra { type: BOOLEAN } attrs { - name: "X_scale" - type: FLOAT - } - attrs { - name: "weight_scale" - type: FLOAT - } - attrs { - name: "out_scale" + name: "Input_scale" type: FLOAT } attrs { diff --git a/paddle/fluid/operators/compat/quantize_linear.pbtxt b/paddle/fluid/operators/compat/quantize_linear.pbtxt new file mode 100644 index 0000000000000..7a3ca515029c3 --- /dev/null +++ b/paddle/fluid/operators/compat/quantize_linear.pbtxt @@ -0,0 +1,25 @@ +type: "quantize_linear" +def { + inputs { + name: "X" + } + inputs { + name: "Scale" + } + inputs { + name: "ZeroPoint" + } + outputs { + name: "Y" + } + attrs { + name: "bit_length" + type: INT + } + attrs { + name: "quant_axis" + type: INT + } +} +extra { +} diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_multihead_matmul.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_multihead_matmul.py index 97a94ef348a67..26066be7dc787 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_multihead_matmul.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_multihead_matmul.py @@ -491,8 +491,7 @@ def generate_weight2(): "x_num_col_dims": 2, "y_num_col_dims": 1, "enable_int8": True, - "X_scale": 1.0, - "weight_scale": [1.0], + "Input_scale": 1.0, }, { "axis": 2, "out_threshold": 1.0, @@ -504,8 +503,7 @@ def generate_weight2(): "x_num_col_dims": 2, "y_num_col_dims": 1, "enable_int8": True, - "X_scale": 1.0, - "weight_scale": [1.0], + "Input_scale": 1.0, }, { "axis": 2, "out_threshold": 1.0, @@ -517,8 +515,7 @@ def generate_weight2(): "x_num_col_dims": 2, "y_num_col_dims": 1, "enable_int8": True, - "X_scale": 1.0, - "weight_scale": [1.0], + "Input_scale": 1.0, }, { "axis": 2, "out_threshold": 1.0, From 0f6412c0c645e9a3c901cbcf4fa83c314ab85a37 Mon Sep 17 00:00:00 2001 From: Leo Chen Date: Sat, 2 Apr 2022 19:08:56 +0800 Subject: [PATCH 26/93] do not use scope in op kernel (#41316) --- .../pscore/distributed_lookup_table_op.h | 48 +++++++------------ 1 file 
changed, 17 insertions(+), 31 deletions(-) diff --git a/paddle/fluid/operators/pscore/distributed_lookup_table_op.h b/paddle/fluid/operators/pscore/distributed_lookup_table_op.h index da439407a422b..c2717c19b2d8e 100644 --- a/paddle/fluid/operators/pscore/distributed_lookup_table_op.h +++ b/paddle/fluid/operators/pscore/distributed_lookup_table_op.h @@ -26,17 +26,13 @@ template class DistributedLookupTableKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext &context) const override { - auto &scope = context.scope(); - auto padding_idx = context.Attr("padding_idx"); auto table_id = context.Attr("table_id"); bool is_test = context.Attr("is_test"); - auto embedding_name = context.InputNames("W").front(); + auto *var = context.InputVar("W"); int64_t emb_dim = 0; - auto *var = scope.FindVar(embedding_name); - if (var->IsType()) { emb_dim = var->Get().dims()[1]; } else if (var->IsType()) { @@ -61,35 +57,31 @@ class DistributedLookupTableKernel : public framework::OpKernel { } else { auto inputs_variable = context.MultiInputVar("Ids"); auto outputs_variable = context.MultiOutputVar("Outputs"); - auto inputs_name = context.InputNames("Ids"); - auto outputs_name = context.OutputNames("Outputs"); auto cpu_place = platform::CPUPlace(); - framework::Scope *tmp_scope = scope.NewTmpScope().release(); std::vector tmp_input_vec; auto input_var_size = inputs_variable.size(); std::vector tmp_output_vec; auto output_var_size = outputs_variable.size(); + std::vector> tmp_tensors; + // create temp input for (size_t idx = 0; idx < input_var_size; ++idx) { - framework::Variable *tmp_input_var = tmp_scope->Var(inputs_name[idx]); - framework::LoDTensor *tmp_input_tensor = - tmp_input_var->GetMutable(); + tmp_tensors.emplace_back(std::make_shared()); + auto *p = tmp_tensors.back().get(); framework::TensorCopy(inputs_variable[idx]->Get(), - cpu_place, context.device_context(), - tmp_input_tensor); - tmp_input_vec.push_back(tmp_input_tensor); + cpu_place, context.device_context(), p); + tmp_input_vec.push_back(p); } // create temp output for (size_t idx = 0; idx < output_var_size; ++idx) { - framework::Variable *tmp_output_var = tmp_scope->Var(outputs_name[idx]); - framework::LoDTensor *tmp_output_tensor = - tmp_output_var->GetMutable(); - tmp_output_tensor->Resize(outputs[idx]->dims()); - tmp_output_vec.push_back(tmp_output_tensor); + tmp_tensors.emplace_back(std::make_shared()); + auto *p = tmp_tensors.back().get(); + p->Resize(outputs[idx]->dims()); + tmp_output_vec.push_back(p); } // use fleet->PullSparse @@ -100,27 +92,21 @@ class DistributedLookupTableKernel : public framework::OpKernel { // cp temp to origin for (size_t idx = 0; idx < output_var_size; ++idx) { - framework::Variable *tmp_output_var = tmp_scope->Var(outputs_name[idx]); - framework::LoDTensor *tmp_output_tensor = - tmp_output_var->GetMutable(); framework::TensorCopy( - *tmp_output_tensor, context.GetPlace(), context.device_context(), + *tmp_output_vec[idx], context.GetPlace(), context.device_context(), outputs_variable[idx]->GetMutable()); } - delete tmp_scope; } - auto id_names = context.InputNames("Ids"); - auto out_names = context.OutputNames("Outputs"); auto lookup_table_version = context.Attr("lookup_table_version"); + auto id_vars = context.MultiInputVar("Ids"); + auto out_vars = context.MultiOutputVar("Outputs"); if (lookup_table_version == "lookup_table_v2") { - for (size_t i = 0; i < id_names.size(); ++i) { - auto *id_var = scope.FindVar(id_names[i]); - auto *out_var = scope.FindVar(out_names[i]); 
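// The change at this step, distilled (a hedged sketch, not lines from the
// patch): instead of materializing named temporaries in a child scope,
//
//   framework::Scope *tmp_scope = scope.NewTmpScope().release();
//   framework::Variable *v = tmp_scope->Var(name);  // later: delete tmp_scope;
//
// the kernel now owns its scratch tensors directly and reads variables from
// the ExecutionContext,
//
//   std::vector<std::shared_ptr<framework::LoDTensor>> tmp_tensors;
//   tmp_tensors.emplace_back(std::make_shared<framework::LoDTensor>());
//   auto id_vars = context.MultiInputVar("Ids");
//
// so no Scope lookups remain and cleanup happens automatically when the
// shared_ptr vector goes out of scope.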
- auto *id_tensor = id_var->GetMutable<framework::LoDTensor>(); - auto *out_tensor = out_var->GetMutable<framework::LoDTensor>(); + for (size_t i = 0; i < id_vars.size(); ++i) { + auto *id_tensor = id_vars[i]->GetMutable<framework::LoDTensor>(); + auto *out_tensor = out_vars[i]->GetMutable<framework::LoDTensor>(); auto id_dims = id_tensor->dims(); out_tensor->Resize(phi::make_ddim({static_cast<int64_t>(id_dims[0]), From 90b95becee9b2d828fd98b5793296b6eb9ce0a4c Mon Sep 17 00:00:00 2001 From: kuizhiqing Date: Sat, 2 Apr 2022 19:22:57 +0800 Subject: [PATCH 27/93] [launch] make log handling more stable; default to stdout (#41314) --- .../paddle/distributed/launch/context/node.py | 1 + .../launch/controllers/controller.py | 5 ++-- .../distributed/launch/job/container.py | 25 +++++++++++-------- python/paddle/distributed/launch/main.py | 2 +- 4 files changed, 19 insertions(+), 14 deletions(-) diff --git a/python/paddle/distributed/launch/context/node.py b/python/paddle/distributed/launch/context/node.py index 2fa8b892275a0..8082541ffe06c 100644 --- a/python/paddle/distributed/launch/context/node.py +++ b/python/paddle/distributed/launch/context/node.py @@ -44,6 +44,7 @@ def get_ports_occupied(self): return self.free_ports def get_free_port(self): + # retry in a loop to avoid port conflicts for _ in range(100): with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s: diff --git a/python/paddle/distributed/launch/controllers/controller.py b/python/paddle/distributed/launch/controllers/controller.py index fbe9df4c9a223..9527ae35c4b6b 100644 --- a/python/paddle/distributed/launch/controllers/controller.py +++ b/python/paddle/distributed/launch/controllers/controller.py @@ -75,8 +75,9 @@ def watch(self) -> bool: while not self.ctx.status.is_done(): status = self.pod.watch(timeout=2) - if self.ctx.continous_log(): - self.pod.logs() + #if self.ctx.continous_log(): + # print logs by default + self.pod.logs() # completed if status == self.ctx.status.COMPLETED: diff --git a/python/paddle/distributed/launch/job/container.py b/python/paddle/distributed/launch/job/container.py index 1f43b6ce04bac..a1ad6dbe24e8e 100644 --- a/python/paddle/distributed/launch/job/container.py +++ b/python/paddle/distributed/launch/job/container.py @@ -145,31 +145,34 @@ def __str__(self): self.errfile, self._env, ) - def logs(self, fn=None, offset=0, whence=1, lines=1000): + def logs(self, fn=None, offset=0, whence=1, limit=1000): if not self._log_handler: self._log_handler = open(self._out) if fn is None: fn = sys.stdout - self._log_handler.seek(offset, whence) try: - idx = 0 - for line in self._log_handler: - fn.write(line) - idx += 1 - if idx > lines: + if offset != 0 or whence != 1: + self._log_handler.seek(offset, whence) + + for _ in range(limit): + line = self._log_handler.readline() + if not line: break - finally: + fn.write(line) + except: return def tail(self, length=3000): if not self._log_handler: self._log_handler = open(self._out) - self._log_handler.seek(0, 2) - ed = self._log_handler.tell() + try: + self._log_handler.seek(0, 2) + ed = self._log_handler.tell() + except: + pass if ed > length: self.logs(offset=ed - length, whence=0) diff --git a/python/paddle/distributed/launch/main.py b/python/paddle/distributed/launch/main.py index dd7edba35a474..400a447260252 100644 --- a/python/paddle/distributed/launch/main.py +++ b/python/paddle/distributed/launch/main.py @@ -40,7 +40,7 @@ def launch(): - ``--rank``: The rank of the node, can be auto assigned by master. Default ``--rank=-1``.
- - ``--log_level``: The log level to set for logging.setLevel which can be CRITICAL/ERROR/WARNING/INFO/DEBUG/NOTSET, case insensitive. The rank 0 log will not print in the terminal by default, while you can enable it by adding --log_level=debug. Default ``--log_level=INFO``. + - ``--log_level``: The log level to set for logging.setLevel which can be CRITICAL/ERROR/WARNING/INFO/DEBUG/NOTSET, case insensitive. Default ``--log_level=INFO``. - ``--nnodes``: The number of nodes for a distributed job, it can be a range in elastic mode, e.g., ``--nnodes=2:3``. Default ``--nnodes=1``. From 1d8246b08290780e2400f9b3b4682a76fb0edf9a Mon Sep 17 00:00:00 2001 From: wanghuancoder Date: Sat, 2 Apr 2022 19:41:53 +0800 Subject: [PATCH 28/93] [Eager] Fix Pylayer compile error (#41240) * fix bug, test=develop * refine, test=develop --- paddle/fluid/pybind/eager_py_layer.cc | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/pybind/eager_py_layer.cc b/paddle/fluid/pybind/eager_py_layer.cc index 59f21a1e1face..e9ddfd80bb867 100644 --- a/paddle/fluid/pybind/eager_py_layer.cc +++ b/paddle/fluid/pybind/eager_py_layer.cc @@ -34,6 +34,8 @@ limitations under the License. */ #include "paddle/phi/core/compat/convert_utils.h" #include "paddle/phi/core/dense_tensor.h" #include "pybind11/detail/internals.h" +#pragma GCC diagnostic ignored "-Wwrite-strings" +#pragma GCC diagnostic ignored "-Wmissing-field-initializers" namespace paddle { namespace pybind { @@ -479,7 +481,7 @@ void BindEagerPyLayer(PyObject* module) { type->tp_dealloc = (destructor)PyLayerDealloc; type->tp_methods = pylayer_methods; type->tp_getset = pylayer_properties; - type->tp_new = PyLayerNew; + type->tp_new = (newfunc)PyLayerNew; Py_INCREF(&PyBaseObject_Type); type->tp_base = reinterpret_cast<PyTypeObject*>(&PyBaseObject_Type); type->tp_flags |= From 36f97cdca2a13ee952cc89a4f4b186fa6284ebb1 Mon Sep 17 00:00:00 2001 From: xiongkun Date: Sat, 2 Apr 2022 20:21:57 +0800 Subject: [PATCH 29/93] [Yaml] add yaml for 5 ops [ elementwise_pow, expm1, floor_divide, logsumexp, mish ] (#41288) * add yaml for ele_max ele_min * add yaml for: mish / logsumexp / expm1 / elementwise_pow / elementwise_floordiv --- .../kernels/impl/logsumexp_grad_kernel_impl.h | 15 ++++-- paddle/phi/kernels/logsumexp_grad_kernel.h | 2 +- python/paddle/fluid/layers/nn.py | 4 +- .../tests/unittests/test_activation_op.py | 12 ++++- .../unittests/test_elementwise_floordiv_op.py | 3 +- .../unittests/test_elementwise_pow_op.py | 27 +++++++++-- .../fluid/tests/unittests/test_logsumexp.py | 12 ++++- python/paddle/nn/functional/activation.py | 4 +- python/paddle/tensor/math.py | 6 ++- python/paddle/utils/code_gen/api.yaml | 46 +++++++++++++++++++ python/paddle/utils/code_gen/backward.yaml | 40 ++++++++++++++++ 11 files changed, 154 insertions(+), 17 deletions(-) diff --git a/paddle/phi/kernels/impl/logsumexp_grad_kernel_impl.h b/paddle/phi/kernels/impl/logsumexp_grad_kernel_impl.h index c2583ce8d32df..23e4414858a78 100644 --- a/paddle/phi/kernels/impl/logsumexp_grad_kernel_impl.h +++ b/paddle/phi/kernels/impl/logsumexp_grad_kernel_impl.h @@ -46,7 +46,7 @@ void LogsumexpGradKernel(const Context& dev_ctx, const DenseTensor& in, const DenseTensor& out, const DenseTensor& out_grad, - const std::vector<int>& axis, + const std::vector<int64_t>& axis, bool keepdim, bool reduce_all, DenseTensor* in_grad) { @@ -67,22 +67,27 @@ void LogsumexpGradKernel(const Context& dev_ctx, } else { int rank = in.dims().size(); LogsumexpGradFunctor functor; + std::vector<int> axis32; + axis32.reserve(axis.size()); +
std::for_each(axis.begin(), axis.end(), [&axis32](const int64_t& t) { + axis32.push_back(t); + }); switch (rank) { case 1: phi::funcs::ReduceGradFunctor( - dev_ctx, in, out, out_grad, in_grad, functor, axis); + dev_ctx, in, out, out_grad, in_grad, functor, axis32); break; case 2: phi::funcs::ReduceGradFunctor( - dev_ctx, in, out, out_grad, in_grad, functor, axis); + dev_ctx, in, out, out_grad, in_grad, functor, axis32); break; case 3: phi::funcs::ReduceGradFunctor( - dev_ctx, in, out, out_grad, in_grad, functor, axis); + dev_ctx, in, out, out_grad, in_grad, functor, axis32); break; case 4: phi::funcs::ReduceGradFunctor( - dev_ctx, in, out, out_grad, in_grad, functor, axis); + dev_ctx, in, out, out_grad, in_grad, functor, axis32); break; } } diff --git a/paddle/phi/kernels/logsumexp_grad_kernel.h b/paddle/phi/kernels/logsumexp_grad_kernel.h index d68c447aa65cb..170f1c6c557ea 100644 --- a/paddle/phi/kernels/logsumexp_grad_kernel.h +++ b/paddle/phi/kernels/logsumexp_grad_kernel.h @@ -23,7 +23,7 @@ void LogsumexpGradKernel(const Context& ctx, const DenseTensor& in, const DenseTensor& out, const DenseTensor& out_grad, - const std::vector<int>& axis, + const std::vector<int64_t>& axis, bool keepdim, bool reduce_all, DenseTensor* in_grad); diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 75583fb5c109a..0dcc8ee517fb1 100755 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -15349,7 +15349,9 @@ def mish(x, threshold=20, name=None): out, = exe.run(feed={'x':x_data}, fetch_list=[y.name]) print(out) # [[0.66666667, 1.66666667, 3., 4.]] """ - if _non_static_mode(): + if in_dygraph_mode(): + return _C_ops.final_state_mish(x, threshold) + if _in_legacy_dygraph(): return _C_ops.mish(x, 'threshold', threshold) check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'mish') diff --git a/python/paddle/fluid/tests/unittests/test_activation_op.py b/python/paddle/fluid/tests/unittests/test_activation_op.py index ef47b841cf819..5573ecf33687b 100755 --- a/python/paddle/fluid/tests/unittests/test_activation_op.py +++ b/python/paddle/fluid/tests/unittests/test_activation_op.py @@ -83,6 +83,7 @@ def init_kernel_type(self): class TestExpm1(TestActivation): def setUp(self): self.op_type = "expm1" + self.python_api = paddle.expm1 self.init_dtype() np.random.seed(2049) @@ -93,7 +94,10 @@ def setUp(self): self.outputs = {'Out': out} def test_check_grad(self): - self.check_grad(['X'], 'Out') + self.check_grad(['X'], 'Out', check_eager=True) + + def test_check_output(self): + self.check_output(check_eager=True) class TestExpm1API(unittest.TestCase): @@ -3002,6 +3006,7 @@ def ref_mish(x, threshold=20.): class TestMish(TestActivation): def setUp(self): self.op_type = "mish" + self.python_api = paddle.fluid.layers.nn.mish self.init_dtype() np.random.seed(1024) @@ -3010,10 +3015,13 @@ def setUp(self): self.inputs = {'X': x} self.outputs = {'Out': out} + def test_check_output(self): + self.check_output(check_eager=True) + def test_check_grad(self): if self.dtype == np.float16: return - self.check_grad(['X'], 'Out') + self.check_grad(['X'], 'Out', check_eager=True) class TestMishAPI(unittest.TestCase): diff --git a/python/paddle/fluid/tests/unittests/test_elementwise_floordiv_op.py b/python/paddle/fluid/tests/unittests/test_elementwise_floordiv_op.py index 007affc140849..6ea24b4543f3f 100644 --- a/python/paddle/fluid/tests/unittests/test_elementwise_floordiv_op.py +++ b/python/paddle/fluid/tests/unittests/test_elementwise_floordiv_op.py @@ -29,6 +29,7 @@ def
init_kernel_type(self): def setUp(self): self.op_type = "elementwise_floordiv" + self.python_api = paddle.floor_divide self.dtype = np.int32 self.axis = -1 self.init_dtype() @@ -44,7 +45,7 @@ def setUp(self): self.outputs = {'Out': self.out} def test_check_output(self): - self.check_output() + self.check_output(check_eager=True) def init_input_output(self): self.x = np.random.uniform(0, 10000, [10, 10]).astype(self.dtype) diff --git a/python/paddle/fluid/tests/unittests/test_elementwise_pow_op.py b/python/paddle/fluid/tests/unittests/test_elementwise_pow_op.py index 08ffb564484b3..3c9e350360dd1 100644 --- a/python/paddle/fluid/tests/unittests/test_elementwise_pow_op.py +++ b/python/paddle/fluid/tests/unittests/test_elementwise_pow_op.py @@ -17,11 +17,13 @@ import numpy as np from op_test import OpTest, skip_check_grad_ci import paddle.fluid as fluid +import paddle class TestElementwisePowOp(OpTest): def setUp(self): self.op_type = "elementwise_pow" + self.python_api = paddle.pow self.inputs = { 'X': np.random.uniform(1, 2, [20, 5]).astype("float64"), 'Y': np.random.uniform(1, 2, [20, 5]).astype("float64") @@ -29,15 +31,22 @@ def setUp(self): self.outputs = {'Out': np.power(self.inputs['X'], self.inputs['Y'])} def test_check_output(self): - self.check_output() + if hasattr(self, 'attrs'): + self.check_output(check_eager=False) + else: + self.check_output(check_eager=True) def test_check_grad_normal(self): - self.check_grad(['X', 'Y'], 'Out') + if hasattr(self, 'attrs'): + self.check_grad(['X', 'Y'], 'Out', check_eager=False) + else: + self.check_grad(['X', 'Y'], 'Out', check_eager=True) class TestElementwisePowOp_big_shape_1(TestElementwisePowOp): def setUp(self): self.op_type = "elementwise_pow" + self.python_api = paddle.pow self.inputs = { 'X': np.random.uniform(1, 2, [10, 10]).astype("float64"), 'Y': np.random.uniform(0.1, 1, [10, 10]).astype("float64") @@ -48,6 +57,7 @@ def setUp(self): class TestElementwisePowOp_big_shape_2(TestElementwisePowOp): def setUp(self): self.op_type = "elementwise_pow" + self.python_api = paddle.pow self.inputs = { 'X': np.random.uniform(1, 2, [10, 10]).astype("float64"), 'Y': np.random.uniform(0.2, 2, [10, 10]).astype("float64") @@ -60,6 +70,7 @@ def setUp(self): class TestElementwisePowOp_scalar(TestElementwisePowOp): def setUp(self): self.op_type = "elementwise_pow" + self.python_api = paddle.pow self.inputs = { 'X': np.random.uniform(0.1, 1, [3, 3, 4]).astype(np.float64), 'Y': np.random.uniform(0.1, 1, [1]).astype(np.float64) @@ -70,6 +81,7 @@ def setUp(self): class TestElementwisePowOp_tensor(TestElementwisePowOp): def setUp(self): self.op_type = "elementwise_pow" + self.python_api = paddle.pow self.inputs = { 'X': np.random.uniform(0.1, 1, [100]).astype("float64"), 'Y': np.random.uniform(1, 3, [100]).astype("float64") @@ -80,6 +92,7 @@ def setUp(self): class TestElementwisePowOp_broadcast_0(TestElementwisePowOp): def setUp(self): self.op_type = "elementwise_pow" + self.python_api = paddle.pow self.inputs = { 'X': np.random.uniform(0.1, 1, [2, 1, 100]).astype("float64"), 'Y': np.random.uniform(0.1, 1, [100]).astype("float64") @@ -90,6 +103,7 @@ def setUp(self): class TestElementwisePowOp_broadcast_1(TestElementwisePowOp): def setUp(self): self.op_type = "elementwise_pow" + self.python_api = paddle.pow self.inputs = { 'X': np.random.uniform(0.1, 1, [2, 100, 1]).astype("float64"), 'Y': np.random.uniform(0.1, 1, [100]).astype("float64") @@ -103,6 +117,7 @@ def setUp(self): class TestElementwisePowOp_broadcast_2(TestElementwisePowOp): def setUp(self): 
self.op_type = "elementwise_pow" + self.python_api = paddle.pow self.inputs = { 'X': np.random.uniform(0.1, 1, [100, 3, 1]).astype("float64"), 'Y': np.random.uniform(0.1, 1, [100]).astype("float64") @@ -117,6 +132,7 @@ def setUp(self): class TestElementwisePowOp_broadcast_3(TestElementwisePowOp): def setUp(self): self.op_type = "elementwise_pow" + self.python_api = paddle.pow self.inputs = { 'X': np.random.uniform(0.1, 1, [2, 20, 5, 1]).astype("float64"), 'Y': np.random.uniform(0.1, 1, [20, 5]).astype("float64") @@ -131,6 +147,7 @@ def setUp(self): class TestElementwisePowOp_broadcast_4(TestElementwisePowOp): def setUp(self): self.op_type = "elementwise_pow" + self.python_api = paddle.pow self.inputs = { 'X': np.random.uniform(0.1, 1, [2, 10, 3, 5]).astype("float64"), 'Y': np.random.uniform(0.1, 1, [2, 10, 1, 5]).astype("float64") @@ -141,11 +158,15 @@ def setUp(self): class TestElementwisePowOpInt(OpTest): def setUp(self): self.op_type = "elementwise_pow" + self.python_api = paddle.pow self.inputs = {'X': np.asarray([1, 3, 6]), 'Y': np.asarray([1, 1, 1])} self.outputs = {'Out': np.power(self.inputs['X'], self.inputs['Y'])} def test_check_output(self): - self.check_output() + if hasattr(self, 'attrs'): + self.check_output(check_eager=False) + else: + self.check_output(check_eager=True) class TestElementwisePowGradOpInt(unittest.TestCase): diff --git a/python/paddle/fluid/tests/unittests/test_logsumexp.py b/python/paddle/fluid/tests/unittests/test_logsumexp.py index 31c68b88b86a7..91eb65ef284a5 100644 --- a/python/paddle/fluid/tests/unittests/test_logsumexp.py +++ b/python/paddle/fluid/tests/unittests/test_logsumexp.py @@ -29,9 +29,16 @@ def ref_logsumexp(x, axis=None, keepdim=False, reduce_all=False): return out +def logsumexp_wrapper(x, axis=None, keepdim=False, allreduce=False): + if allreduce: + return paddle.logsumexp(x, None, keepdim) + return paddle.logsumexp(x, axis, keepdim) + + class TestLogsumexp(OpTest): def setUp(self): self.op_type = 'logsumexp' + self.python_api = logsumexp_wrapper self.shape = [2, 3, 4, 5] self.dtype = 'float64' self.axis = [-1] @@ -61,13 +68,14 @@ def set_attrs_addition(self): pass def test_check_output(self): - self.check_output() + self.check_output(check_eager=True) def test_check_grad(self): self.check_grad( ['X'], ['Out'], user_defined_grads=self.user_defined_grads, - user_defined_grad_outputs=self.user_defined_grad_outputs) + user_defined_grad_outputs=self.user_defined_grad_outputs, + check_eager=True) def calc_grad(self): dy = np.ones(1, dtype=self.dtype) diff --git a/python/paddle/nn/functional/activation.py b/python/paddle/nn/functional/activation.py index 66c50d16e7201..3bdda982ff4f1 100644 --- a/python/paddle/nn/functional/activation.py +++ b/python/paddle/nn/functional/activation.py @@ -1220,7 +1220,9 @@ def mish(x, name=None): x = paddle.to_tensor([-5., 0., 5.]) out = F.mish(x) # [-0.03357624, 0., 4.99955208] """ - if in_dynamic_mode(): + if in_dygraph_mode(): + return _C_ops.final_state_mish(x, 20) + if _in_legacy_dygraph(): return _C_ops.mish(x) check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], 'mish') diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index e932595fc378e..ccd5efbd580af 100755 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -1565,7 +1565,11 @@ def logsumexp(x, axis=None, keepdim=False, name=None): if axis is None or len(axis) == 0: axis = [0] - if paddle.in_dynamic_mode(): + if in_dygraph_mode(): + if reduce_all: + axis = range(len(x.shape)) + return 
_C_ops.final_state_logsumexp(x, axis, keepdim, reduce_all) + if _in_legacy_dygraph(): return _C_ops.logsumexp(x, 'axis', axis, 'keepdim', keepdim, 'reduce_all', reduce_all) check_variable_and_dtype(x, 'x', diff --git a/python/paddle/utils/code_gen/api.yaml b/python/paddle/utils/code_gen/api.yaml index 466c26d3f46c9..ece46837c6def 100644 --- a/python/paddle/utils/code_gen/api.yaml +++ b/python/paddle/utils/code_gen/api.yaml @@ -422,6 +422,15 @@ func : eigh backward : eigh_grad +- api : elementwise_pow + args : (Tensor x, Tensor y) + output : Tensor(out) + infer_meta : + func : ElementwiseInferMeta + kernel : + func : elementwise_pow + backward : elementwise_pow_grad + # elu - api : elu args : (Tensor x, float alpha) @@ -485,6 +494,16 @@ func : erfinv backward : erfinv_grad +- api : expm1 + args : (Tensor x) + output : Tensor + infer_meta : + func : UnchangedInferMeta + param : [x] + kernel : + func : expm1 + backward : expm1_grad + - api : flatten args : (Tensor x, int start_axis, int stop_axis) output : Tensor @@ -511,6 +530,14 @@ func : floor backward : floor_grad +- api : floor_divide + args : (Tensor x, Tensor y) + output : Tensor(out) + infer_meta : + func : ElementwiseInferMeta + kernel : + func : floor_divide + - api : fmax args : (Tensor x, Tensor y, int axis) output : Tensor(out) @@ -878,6 +905,15 @@ func : logsigmoid backward : logsigmoid_grad +- api : logsumexp + args : (Tensor x, int64_t[] axis, bool keepdim, bool reduce_all) + output : Tensor(out) + infer_meta : + func : LogsumexpInferMeta + kernel : + func : logsumexp + backward : logsumexp_grad + # masked_select - api : masked_select args : (Tensor x, Tensor mask) @@ -954,6 +990,16 @@ func : minimum backward : minimum_grad +- api : mish + args : (Tensor x, float lambda) + output : Tensor + infer_meta : + func : UnchangedInferMeta + param : [x] + kernel : + func : mish + backward : mish_grad + - api : mode args : (Tensor x, int axis, bool keepdim) output : Tensor(out), Tensor(indices) diff --git a/python/paddle/utils/code_gen/backward.yaml b/python/paddle/utils/code_gen/backward.yaml index 48faa4682d742..6d046cb68d93d 100644 --- a/python/paddle/utils/code_gen/backward.yaml +++ b/python/paddle/utils/code_gen/backward.yaml @@ -271,6 +271,16 @@ kernel : func : eigh_grad +- backward_api : elementwise_pow_grad + forward : elementwise_pow(Tensor x, Tensor y) -> Tensor(out) + args : (Tensor x, Tensor y, Tensor out_grad, int axis=-1) + output : Tensor(x_grad), Tensor(y_grad) + infer_meta : + func : GeneralBinaryGradInferMeta + param: [x, y] + kernel : + func : elementwise_pow_grad + - backward_api : elu_grad forward : elu (Tensor x, float alpha) -> Tensor(out) args : (Tensor x, Tensor out, Tensor out_grad, float alpha) @@ -302,6 +312,16 @@ kernel : func : erfinv_grad +- backward_api : expm1_grad + forward : expm1 (Tensor x) -> Tensor(out) + args : (Tensor out, Tensor out_grad) + output : Tensor(x_grad) + infer_meta : + func : UnchangedInferMeta + param : [out] + kernel : + func : expm1_grad + - backward_api : floor_grad forward : floor(Tensor x) -> Tensor(out) args : (Tensor out_grad) @@ -514,6 +534,16 @@ kernel : func : logsigmoid_grad +- backward_api : logsumexp_grad + forward : logsumexp(Tensor x, int64_t[] axis, bool keepdim, bool reduce_all) -> Tensor(out) + args : (Tensor x, Tensor out, Tensor out_grad, int64_t[] axis, bool keepdim, bool reduce_all) + output : Tensor(x_grad) + infer_meta : + func : UnchangedInferMeta + param: [x] + kernel : + func : logsumexp_grad + - backward_api : masked_select_grad forward : masked_select 
(Tensor x, Tensor mask) -> Tensor(out) args : (Tensor x, Tensor mask, Tensor out_grad) @@ -607,6 +637,16 @@ kernel : func : minimum_grad +- backward_api : mish_grad + forward : mish (Tensor x, float threshold) -> Tensor(out) + args : (Tensor x, Tensor out_grad, float threshold) + output : Tensor(x_grad) + infer_meta : + func : UnchangedInferMeta + param : [x] + kernel : + func : mish_grad + - backward_api : mode_grad forward : mode(Tensor x, int axis, bool keepdim) -> Tensor(out), Tensor(indices) args : (Tensor x, Tensor indices, Tensor out_grad, int axis, bool keepdim) From b0398c8e9db4f4608fd57b7b42df03558fb23366 Mon Sep 17 00:00:00 2001 From: Siming Dai <908660116@qq.com> Date: Sat, 2 Apr 2022 20:26:34 +0800 Subject: [PATCH 30/93] Add graph apis (#40809) * Add graph_reindex API * add graph_sample_neighbors api * Add buffer * delete VLOG * delete thrust::copy for output * add ShareDataWith * delete graph_reindex hashtable output * add graph_reindex dispensable * add reindex unittest, move memset to cuda kernel, change api * fix conflict * add reindex buffer for gpu version note * fix conflicts for op_func_generator * Add fisher_yates sampling, add dispensable, change infermeta * add dtype for edge_id * fix rocm ci and static check ci * add unittest * fix unittest * fix unittest * fix bug --- paddle/fluid/operators/graph_reindex_op.cc | 77 ++++ .../operators/graph_sample_neighbors_op.cc | 82 ++++ paddle/fluid/pybind/op_function_generator.h | 3 + paddle/phi/infermeta/multiary.cc | 97 +++++ paddle/phi/infermeta/multiary.h | 23 + .../phi/kernels/cpu/graph_reindex_kernel.cc | 84 ++++ .../cpu/graph_sample_neighbors_kernel.cc | 151 +++++++ paddle/phi/kernels/gpu/graph_reindex_funcs.h | 203 +++++++++ .../phi/kernels/gpu/graph_reindex_kernel.cu | 363 ++++++++++++++++ .../gpu/graph_sample_neighbors_kernel.cu | 393 ++++++++++++++++++ paddle/phi/kernels/graph_reindex_kernel.h | 33 ++ .../kernels/graph_sample_neighbors_kernel.h | 36 ++ paddle/phi/ops/compat/graph_reindex_sig.cc | 30 ++ .../ops/compat/graph_sample_neighbors_sig.cc | 30 ++ .../tests/unittests/test_graph_reindex.py | 113 +++++ .../unittests/test_graph_sample_neighbors.py | 209 ++++++++++ python/paddle/incubate/__init__.py | 4 + python/paddle/incubate/operators/__init__.py | 2 + .../incubate/operators/graph_reindex.py | 127 ++++++ .../operators/graph_sample_neighbors.py | 150 +++++++ 20 files changed, 2210 insertions(+) create mode 100644 paddle/fluid/operators/graph_reindex_op.cc create mode 100644 paddle/fluid/operators/graph_sample_neighbors_op.cc create mode 100644 paddle/phi/kernels/cpu/graph_reindex_kernel.cc create mode 100644 paddle/phi/kernels/cpu/graph_sample_neighbors_kernel.cc create mode 100644 paddle/phi/kernels/gpu/graph_reindex_funcs.h create mode 100644 paddle/phi/kernels/gpu/graph_reindex_kernel.cu create mode 100644 paddle/phi/kernels/gpu/graph_sample_neighbors_kernel.cu create mode 100644 paddle/phi/kernels/graph_reindex_kernel.h create mode 100644 paddle/phi/kernels/graph_sample_neighbors_kernel.h create mode 100644 paddle/phi/ops/compat/graph_reindex_sig.cc create mode 100644 paddle/phi/ops/compat/graph_sample_neighbors_sig.cc create mode 100644 python/paddle/fluid/tests/unittests/test_graph_reindex.py create mode 100644 python/paddle/fluid/tests/unittests/test_graph_sample_neighbors.py create mode 100644 python/paddle/incubate/operators/graph_reindex.py create mode 100644 python/paddle/incubate/operators/graph_sample_neighbors.py diff --git a/paddle/fluid/operators/graph_reindex_op.cc 
b/paddle/fluid/operators/graph_reindex_op.cc
new file mode 100644
index 0000000000000..593de659c7608
--- /dev/null
+++ b/paddle/fluid/operators/graph_reindex_op.cc
@@ -0,0 +1,77 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/framework/infershape_utils.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/phi/core/infermeta_utils.h"
+#include "paddle/phi/infermeta/multiary.h"
+
+namespace paddle {
+namespace operators {
+
+class GraphReindexOP : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
+        ctx.device_context());
+  }
+};
+
+class GraphReindexOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddInput("X", "The destination nodes of the input graph.");
+    AddInput("Neighbors", "The neighbor nodes of the destination nodes `X`.");
+    AddInput("Count", "The number of neighbor nodes of each destination node.");
+    // Note(daisiming): If using the buffer hashtable, we must ensure that the
+    // number of nodes of the input graph is no larger than maximum(int32).
+    AddInput("HashTable_Value",
+             "One of the buffer tensors of the hashtable for reindex.")
+        .AsDispensable();
+    AddInput("HashTable_Index",
+             "One of the buffer tensors of the hashtable for reindex.")
+        .AsDispensable();
+    AddAttr<bool>("flag_buffer_hashtable",
+                  "Define whether to use the buffer hashtable.")
+        .SetDefault(false);
+    AddOutput("Reindex_Src",
+              "The source node index of graph edges after reindex.");
+    AddOutput("Reindex_Dst",
+              "The destination node index of graph edges after reindex.");
+    AddOutput("Out_Nodes",
+              "The original index of graph nodes before reindex.");
+
+    AddComment(R"DOC(
+Graph Reindex operator.
+)DOC");
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+
+DECLARE_INFER_SHAPE_FUNCTOR(graph_reindex, GraphReindexInferShapeFunctor,
+                            PD_INFER_META(phi::GraphReindexInferMeta));
+
+REGISTER_OPERATOR(
+    graph_reindex, ops::GraphReindexOP, ops::GraphReindexOpMaker,
+    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
+    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
+    GraphReindexInferShapeFunctor);
diff --git a/paddle/fluid/operators/graph_sample_neighbors_op.cc b/paddle/fluid/operators/graph_sample_neighbors_op.cc
new file mode 100644
index 0000000000000..5ac9e2d4e4519
--- /dev/null
+++ b/paddle/fluid/operators/graph_sample_neighbors_op.cc
@@ -0,0 +1,82 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/framework/infershape_utils.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/phi/core/infermeta_utils.h"
+#include "paddle/phi/infermeta/multiary.h"
+
+namespace paddle {
+namespace operators {
+
+class GraphSampleNeighborsOP : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "Row"),
+        ctx.device_context());
+  }
+};
+
+class GraphSampleNeighborsOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddInput("Row",
+             "One of the components of the CSC format of the input graph.");
+    AddInput("Col_Ptr",
+             "One of the components of the CSC format of the input graph.");
+    AddInput("X", "The input center nodes index tensor.");
+    AddInput("Eids", "The edge ids of the input graph.").AsDispensable();
+    AddInput("Perm_Buffer", "Permutation buffer for fisher-yates sampling.")
+        .AsDispensable();
+    AddOutput("Out", "The neighbors of input nodes X after sampling.");
+    AddOutput("Out_Count",
+              "The number of sampled neighbors of each input node.");
+    AddOutput("Out_Eids", "The eids of the sampled edges.");
+    AddAttr<int>(
+        "sample_size",
+        "The sample size of the graph sample neighbors method. "
+        "Set default value as -1, meaning return all neighbors of nodes.")
+        .SetDefault(-1);
+    AddAttr<bool>("return_eids",
+                  "Whether to return the eids of the sampled edges.")
+        .SetDefault(false);
+    AddAttr<bool>("flag_perm_buffer",
+                  "Whether to use the permutation buffer for fisher-yates "
+                  "sampling in GPU. Set default value as false, meaning not "
+                  "using it.")
+        .SetDefault(false);
+    AddComment(R"DOC(
+Graph Learning Sampling Neighbors operator, for the GraphSAGE sampling method.
+)DOC");
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+
+DECLARE_INFER_SHAPE_FUNCTOR(graph_sample_neighbors,
+                            GraphSampleNeighborsInferShapeFunctor,
+                            PD_INFER_META(phi::GraphSampleNeighborsInferMeta));
+
+REGISTER_OPERATOR(
+    graph_sample_neighbors, ops::GraphSampleNeighborsOP,
+    ops::GraphSampleNeighborsOpMaker,
+    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
+    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
+    GraphSampleNeighborsInferShapeFunctor);
diff --git a/paddle/fluid/pybind/op_function_generator.h b/paddle/fluid/pybind/op_function_generator.h
index 10c8a90ae0a36..1e501a0c9e024 100644
--- a/paddle/fluid/pybind/op_function_generator.h
+++ b/paddle/fluid/pybind/op_function_generator.h
@@ -105,6 +105,9 @@ std::map<std::string, std::vector<std::string>> op_ins_map = {
     {"linear_chain_crf", {"Emission", "Transition", "Label", "Length"}},
     {"crf_decoding", {"Emission", "Transition", "Label", "Length"}},
     {"chunk_eval", {"Inference", "Label", "SeqLength"}},
+    {"graph_reindex",
+     {"X", "Neighbors", "Count", "HashTable_Value", "HashTable_Index"}},
+    {"graph_sample_neighbors", {"Row", "Col_Ptr", "X", "Eids", "Perm_Buffer"}},
 };
 
 // NOTE(zhiqiu): Like op_ins_map.
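The two op_ins_map entries above keep the dispensable hashtable and permutation-buffer inputs visible to the generated dygraph op functions. As a rough sketch of the intended Python-side calls (this mirrors the unit tests added later in this patch; treat it as illustrative rather than a stable API contract):

    import numpy as np
    import paddle

    x = paddle.to_tensor(np.arange(5).astype("int64"))
    neighbors = paddle.to_tensor(
        np.random.randint(100, size=20).astype("int64"))
    count = paddle.to_tensor(np.array([2, 8, 4, 3, 3], dtype="int32"))

    # Default path: the id mapping is built on the fly.
    src, dst, out_nodes = paddle.incubate.graph_reindex(x, neighbors, count)

    # Buffer path: two preallocated int32 tensors act as the dispensable
    # HashTable_Value / HashTable_Index inputs, so repeated sampling steps can
    # reuse the same storage (assumes all node ids are smaller than num_nodes).
    num_nodes = 100
    value_buffer = paddle.full([num_nodes], -1, dtype="int32")
    index_buffer = paddle.full([num_nodes], -1, dtype="int32")
    src, dst, out_nodes = paddle.incubate.graph_reindex(
        x, neighbors, count, value_buffer, index_buffer,
        flag_buffer_hashtable=True)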
diff --git a/paddle/phi/infermeta/multiary.cc b/paddle/phi/infermeta/multiary.cc index 1f6cf1a6882d8..8e4f0b1fbb5c9 100644 --- a/paddle/phi/infermeta/multiary.cc +++ b/paddle/phi/infermeta/multiary.cc @@ -1775,6 +1775,103 @@ void WhereInferMeta(const MetaTensor& condition, out->share_meta(x); } +void GraphReindexInferMeta(const MetaTensor& x, + const MetaTensor& neighbors, + const MetaTensor& count, + paddle::optional hashtable_value, + paddle::optional hashtable_index, + bool flag_buffer_hashtable, + MetaTensor* reindex_src, + MetaTensor* reindex_dst, + MetaTensor* out_nodes) { + auto GraphReindexShapeCheck = [](const phi::DDim& dims, + std::string tensor_name) { + if (dims.size() == 2) { + PADDLE_ENFORCE_EQ( + dims[1], + 1, + phi::errors::InvalidArgument("The last dim of %s should be 1 when it " + "is 2D, but we get %d", + tensor_name, + dims[1])); + } else { + PADDLE_ENFORCE_EQ( + dims.size(), + 1, + phi::errors::InvalidArgument( + "The %s should be 1D, when it is not 2D, but we get %d", + tensor_name, + dims.size())); + } + }; + + GraphReindexShapeCheck(x.dims(), "X"); + GraphReindexShapeCheck(neighbors.dims(), "Neighbors"); + GraphReindexShapeCheck(count.dims(), "Count"); + if (flag_buffer_hashtable) { + GraphReindexShapeCheck(hashtable_value->dims(), "HashTable_Value"); + GraphReindexShapeCheck(hashtable_index->dims(), "HashTable_Index"); + } + + reindex_src->set_dims({-1}); + reindex_src->set_dtype(neighbors.dtype()); + reindex_dst->set_dims({-1}); + reindex_dst->set_dtype(neighbors.dtype()); + out_nodes->set_dims({-1}); + out_nodes->set_dtype(x.dtype()); +} + +void GraphSampleNeighborsInferMeta( + const MetaTensor& row, + const MetaTensor& col_ptr, + const MetaTensor& x, + paddle::optional eids, + paddle::optional perm_buffer, + int sample_size, + bool return_eids, + bool flag_perm_buffer, + MetaTensor* out, + MetaTensor* out_count, + MetaTensor* out_eids) { + // GSN: GraphSampleNeighbors + auto GSNShapeCheck = [](const phi::DDim& dims, std::string tensor_name) { + if (dims.size() == 2) { + PADDLE_ENFORCE_EQ( + dims[1], + 1, + phi::errors::InvalidArgument("The last dim of %s should be 1 when it " + "is 2D, but we get %d", + tensor_name, + dims[1])); + } else { + PADDLE_ENFORCE_EQ( + dims.size(), + 1, + phi::errors::InvalidArgument( + "The %s should be 1D, when it is not 2D, but we get %d", + tensor_name, + dims.size())); + } + }; + + GSNShapeCheck(row.dims(), "Row"); + GSNShapeCheck(col_ptr.dims(), "Col_Ptr"); + GSNShapeCheck(x.dims(), "X"); + if (return_eids) { + GSNShapeCheck(eids->dims(), "Eids"); + out_eids->set_dims({-1}); + out_eids->set_dtype(row.dtype()); + } + if (flag_perm_buffer) { + GSNShapeCheck(perm_buffer->dims(), "Perm_Buffer"); + } + + out->set_dims({-1}); + out->set_dtype(row.dtype()); + out_count->set_dims({-1}); + out_count->set_dtype(DataType::INT32); +} + void Yolov3LossInferMeta(const MetaTensor& x, const MetaTensor& gt_box, const MetaTensor& gt_label, diff --git a/paddle/phi/infermeta/multiary.h b/paddle/phi/infermeta/multiary.h index b748d898c1e4e..72c64e8500ad2 100644 --- a/paddle/phi/infermeta/multiary.h +++ b/paddle/phi/infermeta/multiary.h @@ -265,6 +265,29 @@ void WhereInferMeta(const MetaTensor& condition, const MetaTensor& y, MetaTensor* out); +void GraphReindexInferMeta(const MetaTensor& x, + const MetaTensor& neighbors, + const MetaTensor& count, + paddle::optional hashtable_value, + paddle::optional hashtable_index, + bool flag_buffer_hashtable, + MetaTensor* reindex_src, + MetaTensor* reindex_dst, + MetaTensor* out_nodes); + +void 
GraphSampleNeighborsInferMeta( + const MetaTensor& row, + const MetaTensor& col_ptr, + const MetaTensor& x, + paddle::optional eids, + paddle::optional perm_buffer, + int sample_size, + bool return_eids, + bool flag_perm_buffer, + MetaTensor* out, + MetaTensor* out_count, + MetaTensor* out_eids); + void Yolov3LossInferMeta(const MetaTensor& x, const MetaTensor& gt_box, const MetaTensor& gt_label, diff --git a/paddle/phi/kernels/cpu/graph_reindex_kernel.cc b/paddle/phi/kernels/cpu/graph_reindex_kernel.cc new file mode 100644 index 0000000000000..d6454b4796430 --- /dev/null +++ b/paddle/phi/kernels/cpu/graph_reindex_kernel.cc @@ -0,0 +1,84 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "paddle/phi/kernels/graph_reindex_kernel.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { + +template +void GraphReindexKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& neighbors, + const DenseTensor& count, + paddle::optional hashtable_value, + paddle::optional hashtable_index, + bool flag_buffer_hashtable, + DenseTensor* reindex_src, + DenseTensor* reindex_dst, + DenseTensor* out_nodes) { + const T* x_data = x.data(); + const T* neighbors_data = neighbors.data(); + const int* count_data = count.data(); + const int bs = x.dims()[0]; + const int num_edges = neighbors.dims()[0]; + + std::unordered_map node_map; + std::vector unique_nodes; + int reindex_id = 0; + for (int i = 0; i < bs; i++) { + T node = x_data[i]; + unique_nodes.emplace_back(node); + node_map[node] = reindex_id++; + } + // Reindex Src + std::vector src(num_edges); + std::vector dst(num_edges); + for (int i = 0; i < num_edges; i++) { + T node = neighbors_data[i]; + if (node_map.find(node) == node_map.end()) { + unique_nodes.emplace_back(node); + node_map[node] = reindex_id++; + } + src[i] = node_map[node]; + } + // Reindex Dst + int cnt = 0; + for (int i = 0; i < bs; i++) { + for (int j = 0; j < count_data[i]; j++) { + T node = x_data[i]; + dst[cnt++] = node_map[node]; + } + } + + reindex_src->Resize({num_edges}); + T* reindex_src_data = dev_ctx.template Alloc(reindex_src); + std::copy(src.begin(), src.end(), reindex_src_data); + reindex_dst->Resize({num_edges}); + T* reindex_dst_data = dev_ctx.template Alloc(reindex_dst); + std::copy(dst.begin(), dst.end(), reindex_dst_data); + out_nodes->Resize({static_cast(unique_nodes.size())}); + T* out_nodes_data = dev_ctx.template Alloc(out_nodes); + std::copy(unique_nodes.begin(), unique_nodes.end(), out_nodes_data); +} + +} // namespace phi + +PD_REGISTER_KERNEL( + graph_reindex, CPU, ALL_LAYOUT, phi::GraphReindexKernel, int, int64_t) {} diff --git a/paddle/phi/kernels/cpu/graph_sample_neighbors_kernel.cc b/paddle/phi/kernels/cpu/graph_sample_neighbors_kernel.cc new file mode 100644 index 0000000000000..e18848af0dc08 --- /dev/null +++ b/paddle/phi/kernels/cpu/graph_sample_neighbors_kernel.cc @@ 
-0,0 +1,151 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "paddle/phi/kernels/graph_sample_neighbors_kernel.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { + +template +void SampleUniqueNeighbors( + bidiiter begin, + bidiiter end, + int num_samples, + std::mt19937& rng, + std::uniform_int_distribution& dice_distribution) { + int left_num = std::distance(begin, end); + for (int i = 0; i < num_samples; i++) { + bidiiter r = begin; + int random_step = dice_distribution(rng) % left_num; + std::advance(r, random_step); + std::swap(*begin, *r); + ++begin; + --left_num; + } +} + +template +void SampleNeighbors(const T* row, + const T* col_ptr, + const T* input, + std::vector* output, + std::vector* output_count, + int sample_size, + int bs) { + // Allocate the memory of output + // Collect the neighbors size + std::vector> out_src_vec; + // `sample_cumsum_sizes` record the start position and end position + // after sampling. + std::vector sample_cumsum_sizes(bs + 1); + // `total_neighbors` the size of output after sample. + int total_neighbors = 0; + sample_cumsum_sizes[0] = total_neighbors; + for (int i = 0; i < bs; i++) { + T node = input[i]; + int cap = col_ptr[node + 1] - col_ptr[node]; + int k = cap > sample_size ? sample_size : cap; + total_neighbors += k; + sample_cumsum_sizes[i + 1] = total_neighbors; + std::vector out_src; + out_src.resize(cap); + out_src_vec.emplace_back(out_src); + } + + output_count->resize(bs); + output->resize(total_neighbors); + + std::random_device rd; + std::mt19937 rng{rd()}; + std::uniform_int_distribution dice_distribution( + 0, std::numeric_limits::max()); + +#ifdef PADDLE_WITH_MKLML +#pragma omp parallel for +#endif + // Sample the neighbors in parallelism. + for (int i = 0; i < bs; i++) { + T node = input[i]; + T begin = col_ptr[node], end = col_ptr[node + 1]; + int cap = end - begin; + if (sample_size < cap) { + std::copy(row + begin, row + end, out_src_vec[i].begin()); + // TODO(daisiming): Check whether is correct. 
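+      // SampleUniqueNeighbors is a partial Fisher-Yates shuffle: each of the
+      // `sample_size` leading slots receives a uniformly chosen element
+      // swapped in from the not-yet-visited tail, so the first `sample_size`
+      // entries form a uniform sample without replacement (no duplicates).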
+ SampleUniqueNeighbors(out_src_vec[i].begin(), + out_src_vec[i].end(), + sample_size, + rng, + dice_distribution); + *(output_count->data() + i) = sample_size; + } else { + std::copy(row + begin, row + end, out_src_vec[i].begin()); + *(output_count->data() + i) = cap; + } + } + +#ifdef PADDLE_WITH_MKLML +#pragma omp parallel for +#endif + // Copy the results parallelism + for (int i = 0; i < bs; i++) { + int k = sample_cumsum_sizes[i + 1] - sample_cumsum_sizes[i]; + std::copy(out_src_vec[i].begin(), + out_src_vec[i].begin() + k, + output->data() + sample_cumsum_sizes[i]); + } +} + +template +void GraphSampleNeighborsKernel( + const Context& dev_ctx, + const DenseTensor& row, + const DenseTensor& col_ptr, + const DenseTensor& x, + paddle::optional eids, + paddle::optional perm_buffer, + int sample_size, + bool return_eids, + bool flag_perm_buffer, + DenseTensor* out, + DenseTensor* out_count, + DenseTensor* out_eids) { + const T* row_data = row.data(); + const T* col_ptr_data = col_ptr.data(); + const T* x_data = x.data(); + int bs = x.dims()[0]; + + std::vector output; + std::vector output_count; + SampleNeighbors( + row_data, col_ptr_data, x_data, &output, &output_count, sample_size, bs); + out->Resize({static_cast(output.size())}); + T* out_data = dev_ctx.template Alloc(out); + std::copy(output.begin(), output.end(), out_data); + out_count->Resize({bs}); + int* out_count_data = dev_ctx.template Alloc(out_count); + std::copy(output_count.begin(), output_count.end(), out_count_data); +} + +} // namespace phi + +PD_REGISTER_KERNEL(graph_sample_neighbors, + CPU, + ALL_LAYOUT, + phi::GraphSampleNeighborsKernel, + int, + int64_t) {} diff --git a/paddle/phi/kernels/gpu/graph_reindex_funcs.h b/paddle/phi/kernels/gpu/graph_reindex_funcs.h new file mode 100644 index 0000000000000..ea4f67e9d47e3 --- /dev/null +++ b/paddle/phi/kernels/gpu/graph_reindex_funcs.h @@ -0,0 +1,203 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
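+
+// The device helpers defined below implement an open-addressing hash table:
+// Hash() maps an id to `id % size`, Insert()/Search() probe with an
+// increasing stride until a free slot (keys are initialized to -1) or the id
+// is found, AttemptInsert() claims a slot via atomicCAS, and the atomicMin on
+// key_index keeps the smallest original position of each id, so duplicates
+// always resolve to the first occurrence.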
+ +#pragma once + +#include "paddle/phi/kernels/graph_reindex_kernel.h" + +#include "paddle/fluid/platform/device/gpu/gpu_primitives.h" +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/hostdevice.h" + +namespace phi { + +template +inline __device__ size_t Hash(T id, int64_t size) { + return id % size; +} + +template +inline __device__ bool AttemptInsert( + size_t pos, T id, int index, T* keys, int* key_index) { + if (sizeof(T) == 4) { + const T key = atomicCAS(reinterpret_cast(&keys[pos]), + static_cast(-1), + static_cast(id)); + if (key == -1 || key == id) { + atomicMin(reinterpret_cast(&key_index[pos]), // NOLINT + static_cast(index)); // NOLINT + return true; + } else { + return false; + } + } else if (sizeof(T) == 8) { + const T key = atomicCAS( + reinterpret_cast(&keys[pos]), // NOLINT + static_cast(-1), // NOLINT + static_cast(id)); // NOLINT + if (key == -1 || key == id) { + atomicMin(reinterpret_cast(&key_index[pos]), // NOLINT + static_cast(index)); // NOLINT + return true; + } else { + return false; + } + } +} + +template +inline __device__ void Insert( + T id, int index, int64_t size, T* keys, int* key_index) { + size_t pos = Hash(id, size); + size_t delta = 1; + while (!AttemptInsert(pos, id, index, keys, key_index)) { + pos = Hash(pos + delta, size); + delta += 1; + } +} + +template +inline __device__ int64_t Search(T id, const T* keys, int64_t size) { + int64_t pos = Hash(id, size); + + int64_t delta = 1; + while (keys[pos] != id) { + pos = Hash(pos + delta, size); + delta += 1; + } + + return pos; +} + +template +__global__ void BuildHashTable( + const T* items, int num_items, int64_t size, T* keys, int* key_index) { + CUDA_KERNEL_LOOP(index, num_items) { + Insert(items[index], index, size, keys, key_index); + } +} + +template +__global__ void BuildHashTable(const T* items, int num_items, int* key_index) { + CUDA_KERNEL_LOOP(index, num_items) { + atomicMin( + reinterpret_cast(&key_index[items[index]]), // NOLINT + static_cast(index)); // NOLINT + } +} + +template +__global__ void ResetHashTable(const T* items, + int num_items, + int* key_index, + int* values) { + CUDA_KERNEL_LOOP(index, num_items) { + key_index[items[index]] = -1; + values[items[index]] = -1; + } +} + +template +__global__ void GetItemIndexCount(const T* items, + int* item_count, + int num_items, + int64_t size, + const T* keys, + int* key_index) { + CUDA_KERNEL_LOOP(i, num_items) { + int64_t pos = Search(items[i], keys, size); + if (key_index[pos] == i) { + item_count[i] = 1; + } + } +} + +template +__global__ void GetItemIndexCount(const T* items, + int* item_count, + int num_items, + int* key_index) { + CUDA_KERNEL_LOOP(i, num_items) { + if (key_index[items[i]] == i) { + item_count[i] = 1; + } + } +} + +template +__global__ void FillUniqueItems(const T* items, + int num_items, + int64_t size, + T* unique_items, + const int* item_count, + const T* keys, + int* values, + int* key_index) { + CUDA_KERNEL_LOOP(i, num_items) { + int64_t pos = Search(items[i], keys, size); + if (key_index[pos] == i) { + values[pos] = item_count[i]; + unique_items[item_count[i]] = items[i]; + } + } +} + +template +__global__ void FillUniqueItems(const T* items, + int num_items, + T* unique_items, + const int* item_count, + int* values, + int* key_index) { + CUDA_KERNEL_LOOP(i, num_items) { + if (key_index[items[i]] == i) { + values[items[i]] = item_count[i]; + unique_items[item_count[i]] = items[i]; + } + } +} + +template +__global__ void ReindexSrcOutput(T* src_output, + int num_items, + int64_t size, 
+ const T* keys, + const int* values) { + CUDA_KERNEL_LOOP(i, num_items) { + int64_t pos = Search(src_output[i], keys, size); + src_output[i] = values[pos]; + } +} + +template +__global__ void ReindexSrcOutput(T* src_output, + int num_items, + const int* values) { + CUDA_KERNEL_LOOP(i, num_items) { src_output[i] = values[src_output[i]]; } +} + +template +__global__ void ReindexInputNodes(const T* nodes, + int num_items, + T* reindex_nodes, + int64_t size, + const T* keys, + const int* values) { + CUDA_KERNEL_LOOP(i, num_items) { + int64_t pos = Search(nodes[i], keys, size); + reindex_nodes[i] = values[pos]; + } +} + +} // namespace phi diff --git a/paddle/phi/kernels/gpu/graph_reindex_kernel.cu b/paddle/phi/kernels/gpu/graph_reindex_kernel.cu new file mode 100644 index 0000000000000..34bd1d6db77da --- /dev/null +++ b/paddle/phi/kernels/gpu/graph_reindex_kernel.cu @@ -0,0 +1,363 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include + +#include "paddle/phi/kernels/gpu/graph_reindex_funcs.h" +#include "paddle/phi/kernels/graph_reindex_kernel.h" + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { + +constexpr int WARP_SIZE = 32; + +template +void FillHashTable(const Context& dev_ctx, + const T* input, + int num_input, + int64_t len_hashtable, + thrust::device_vector* unique_items, + T* keys, + int* values, + int* key_index) { +#ifdef PADDLE_WITH_HIP + int block = 256; +#else + int block = 1024; +#endif + int max_grid_dimx = dev_ctx.GetCUDAMaxGridDimSize()[0]; + int grid_tmp = (num_input + block - 1) / block; + int grid = grid_tmp < max_grid_dimx ? grid_tmp : max_grid_dimx; + // Insert data into keys and values. + BuildHashTable<<>>( + input, num_input, len_hashtable, keys, key_index); + + // Get item index count. + thrust::device_vector item_count(num_input + 1, 0); + GetItemIndexCount<<>>( + input, + thrust::raw_pointer_cast(item_count.data()), + num_input, + len_hashtable, + keys, + key_index); + + thrust::exclusive_scan( + item_count.begin(), item_count.end(), item_count.begin()); + size_t total_unique_items = item_count[num_input]; + unique_items->resize(total_unique_items); + + // Get unique items + FillUniqueItems<<>>( + input, + num_input, + len_hashtable, + thrust::raw_pointer_cast(unique_items->data()), + thrust::raw_pointer_cast(item_count.data()), + keys, + values, + key_index); +} + +template +void FillBufferHashTable(const Context& dev_ctx, + const T* input, + int num_input, + thrust::device_vector* unique_items, + int* values, + int* key_index) { +#ifdef PADDLE_WITH_HIP + int block = 256; +#else + int block = 1024; +#endif + int max_grid_dimx = dev_ctx.GetCUDAMaxGridDimSize()[0]; + int grid_tmp = (num_input + block - 1) / block; + int grid = grid_tmp < max_grid_dimx ? grid_tmp : max_grid_dimx; + // Insert data. + BuildHashTable<<>>( + input, num_input, key_index); + + // Get item index count. 
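+  // item_count[i] is set to 1 only by the thread whose input position is the
+  // first occurrence of that id; the exclusive_scan below turns these flags
+  // into compact output offsets, and item_count[num_input] then equals the
+  // number of unique items.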
+ thrust::device_vector item_count(num_input + 1, 0); + GetItemIndexCount<<>>( + input, thrust::raw_pointer_cast(item_count.data()), num_input, key_index); + + thrust::exclusive_scan( + item_count.begin(), item_count.end(), item_count.begin()); + size_t total_unique_items = item_count[num_input]; + unique_items->resize(total_unique_items); + + // Get unique items + FillUniqueItems<<>>( + input, + num_input, + thrust::raw_pointer_cast(unique_items->data()), + thrust::raw_pointer_cast(item_count.data()), + values, + key_index); +} + +template +void ResetBufferHashTable(const Context& dev_ctx, + const T* input, + int num_input, + thrust::device_vector* unique_items, + int* values, + int* key_index) { +#ifdef PADDLE_WITH_HIP + int block = 256; +#else + int block = 1024; +#endif + int max_grid_dimx = dev_ctx.GetCUDAMaxGridDimSize()[0]; + int grid_tmp = (unique_items->size() + block - 1) / block; + int grid = grid_tmp < max_grid_dimx ? grid_tmp : max_grid_dimx; + ResetHashTable<<>>( + thrust::raw_pointer_cast(unique_items->data()), + unique_items->size(), + key_index, + values); +} + +template +void Reindex(const Context& dev_ctx, + const T* inputs, + thrust::device_ptr src_outputs, + thrust::device_vector* out_nodes, + int num_inputs, + int num_edges) { + out_nodes->resize(num_inputs + num_edges); + thrust::copy(inputs, inputs + num_inputs, out_nodes->begin()); + thrust::copy( + src_outputs, src_outputs + num_edges, out_nodes->begin() + num_inputs); + thrust::device_vector unique_nodes; + unique_nodes.clear(); + + // Fill hash table + int64_t num = out_nodes->size(); + int64_t log_num = 1 << static_cast(1 + std::log2(num >> 1)); + int64_t table_size = log_num << 1; + T* keys; + int *values, *key_index; + +#ifdef PADDLE_WITH_HIP + hipMalloc(&keys, table_size * sizeof(T)); + hipMalloc(&values, table_size * sizeof(int)); + hipMalloc(&key_index, table_size * sizeof(int)); + hipMemset(keys, -1, table_size * sizeof(T)); + hipMemset(values, -1, table_size * sizeof(int)); + hipMemset(key_index, -1, table_size * sizeof(int)); +#else + cudaMalloc(&keys, table_size * sizeof(T)); + cudaMalloc(&values, table_size * sizeof(int)); + cudaMalloc(&key_index, table_size * sizeof(int)); + cudaMemset(keys, -1, table_size * sizeof(T)); + cudaMemset(values, -1, table_size * sizeof(int)); + cudaMemset(key_index, -1, table_size * sizeof(int)); +#endif + + FillHashTable(dev_ctx, + thrust::raw_pointer_cast(out_nodes->data()), + out_nodes->size(), + table_size, + &unique_nodes, + keys, + values, + key_index); + out_nodes->resize(unique_nodes.size()); + thrust::copy(unique_nodes.begin(), unique_nodes.end(), out_nodes->begin()); + +// Fill outputs with reindex result. +#ifdef PADDLE_WITH_HIP + int block = 256; +#else + int block = 1024; +#endif + int max_grid_dimx = dev_ctx.GetCUDAMaxGridDimSize()[0]; + int grid_tmp = (num_edges + block - 1) / block; + int grid = grid_tmp < max_grid_dimx ? 
grid_tmp : max_grid_dimx; + ReindexSrcOutput<<>>( + thrust::raw_pointer_cast(src_outputs), + num_edges, + table_size, + keys, + values); +#ifdef PADDLE_WITH_HIP + hipFree(keys); + hipFree(values); + hipFree(key_index); +#else + cudaFree(keys); + cudaFree(values); + cudaFree(key_index); +#endif +} + +template +void BufferReindex(const Context& dev_ctx, + const T* inputs, + thrust::device_ptr src_outputs, + thrust::device_vector* out_nodes, + int num_inputs, + int* hashtable_value, + int* hashtable_index, + int num_edges) { + out_nodes->resize(num_inputs + num_edges); + thrust::copy(inputs, inputs + num_inputs, out_nodes->begin()); + thrust::copy( + src_outputs, src_outputs + num_edges, out_nodes->begin() + num_inputs); + thrust::device_vector unique_nodes; + unique_nodes.clear(); + + // Fill hash table + FillBufferHashTable(dev_ctx, + thrust::raw_pointer_cast(out_nodes->data()), + out_nodes->size(), + &unique_nodes, + hashtable_value, + hashtable_index); + out_nodes->resize(unique_nodes.size()); + thrust::copy(unique_nodes.begin(), unique_nodes.end(), out_nodes->begin()); + +// Fill outputs with reindex result. +#ifdef PADDLE_WITH_HIP + int block = 256; +#else + int block = 1024; +#endif + int max_grid_dimx = dev_ctx.GetCUDAMaxGridDimSize()[0]; + int grid_tmp = (num_edges + block - 1) / block; + int grid = grid_tmp < max_grid_dimx ? grid_tmp : max_grid_dimx; + ReindexSrcOutput<<>>( + thrust::raw_pointer_cast(src_outputs), num_edges, hashtable_value); + + ResetBufferHashTable(dev_ctx, + thrust::raw_pointer_cast(out_nodes->data()), + out_nodes->size(), + &unique_nodes, + hashtable_value, + hashtable_index); +} + +template +__global__ void GetDstEdgeCUDAKernel(const int64_t num_rows, + const int* in_rows, + const int* dst_counts, + const int* dst_ptr, + T* dst_outputs) { + assert(blockDim.x == WARP_SIZE); + assert(blockDim.y == BLOCK_WARPS); + + int64_t out_row = blockIdx.x * TILE_SIZE + threadIdx.y; + const int64_t last_row = + min(static_cast(blockIdx.x + 1) * TILE_SIZE, num_rows); + + while (out_row < last_row) { + const int row = in_rows[out_row]; + const int dst_sample_size = dst_counts[out_row]; + const int out_row_start = dst_ptr[out_row]; + for (int idx = threadIdx.x; idx < dst_sample_size; idx += WARP_SIZE) { + dst_outputs[out_row_start + idx] = row; + } + out_row += BLOCK_WARPS; + } +} + +template +void GraphReindexKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& neighbors, + const DenseTensor& count, + paddle::optional hashtable_value, + paddle::optional hashtable_index, + bool flag_buffer_hashtable, + DenseTensor* reindex_src, + DenseTensor* reindex_dst, + DenseTensor* out_nodes) { + const T* x_data = x.data(); + const T* neighbors_data = neighbors.data(); + const int* count_data = count.data(); + const int bs = x.dims()[0]; + const int num_edges = neighbors.dims()[0]; + reindex_src->Resize({num_edges}); + + T* reindex_src_data = dev_ctx.template Alloc(reindex_src); + thrust::device_ptr src_outputs(reindex_src_data); + + thrust::device_vector unique_nodes; + thrust::copy(neighbors_data, neighbors_data + num_edges, src_outputs); + + if (flag_buffer_hashtable) { + // Here we directly use buffer tensor to act as a hash table. 
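+    // In the buffer variant the node id itself is the slot index, so no
+    // probing table has to be allocated per call; this is why the op maker
+    // requires the number of nodes to fit in int32 when flag_buffer_hashtable
+    // is set.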
+ DenseTensor hashtable_value_out(hashtable_value->type()); + const auto* ph_value = hashtable_value.get_ptr(); + hashtable_value_out.ShareDataWith(*ph_value); + DenseTensor hashtable_index_out(hashtable_index->type()); + const auto* ph_index = hashtable_index.get_ptr(); + hashtable_index_out.ShareDataWith(*ph_index); + int* hashtable_value_data = + hashtable_value_out.mutable_data(dev_ctx.GetPlace()); + int* hashtable_index_data = + hashtable_index_out.mutable_data(dev_ctx.GetPlace()); + BufferReindex(dev_ctx, + x_data, + src_outputs, + &unique_nodes, + bs, + hashtable_value_data, + hashtable_index_data, + num_edges); + } else { + Reindex( + dev_ctx, x_data, src_outputs, &unique_nodes, bs, num_edges); + } + + // Get reindex dst edge. + thrust::device_vector unique_dst_reindex(bs); + thrust::sequence(unique_dst_reindex.begin(), unique_dst_reindex.end()); + thrust::device_vector dst_ptr(bs); + thrust::exclusive_scan(count_data, count_data + bs, dst_ptr.begin()); + constexpr int BLOCK_WARPS = 128 / WARP_SIZE; + constexpr int TILE_SIZE = BLOCK_WARPS * 16; + const dim3 block(WARP_SIZE, BLOCK_WARPS); + const dim3 grid((bs + TILE_SIZE - 1) / TILE_SIZE); + + reindex_dst->Resize({num_edges}); + T* reindex_dst_data = dev_ctx.template Alloc(reindex_dst); + + GetDstEdgeCUDAKernel<<>>( + bs, + thrust::raw_pointer_cast(unique_dst_reindex.data()), + count_data, + thrust::raw_pointer_cast(dst_ptr.data()), + reindex_dst_data); + + out_nodes->Resize({static_cast(unique_nodes.size())}); + T* out_nodes_data = dev_ctx.template Alloc(out_nodes); + thrust::copy(unique_nodes.begin(), unique_nodes.end(), out_nodes_data); +} + +} // namespace phi + +PD_REGISTER_KERNEL( + graph_reindex, GPU, ALL_LAYOUT, phi::GraphReindexKernel, int, int64_t) {} diff --git a/paddle/phi/kernels/gpu/graph_sample_neighbors_kernel.cu b/paddle/phi/kernels/gpu/graph_sample_neighbors_kernel.cu new file mode 100644 index 0000000000000..1757b6b98dbf9 --- /dev/null +++ b/paddle/phi/kernels/gpu/graph_sample_neighbors_kernel.cu @@ -0,0 +1,393 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
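+
+// Two sampling paths are implemented below: SampleKernel performs a
+// reservoir-style in-place sample over each node's neighbor range (one warp
+// per node), while FisherYatesSampleKernel + GatherEdge first shuffle a
+// caller-provided permutation buffer and then gather, which is the path taken
+// when flag_perm_buffer is true.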
+ +#include +#include +#include +#include +#include +#include + +#ifdef PADDLE_WITH_HIP +#include +#include +#else +#include +#include +#endif + +#include "paddle/phi/kernels/graph_sample_neighbors_kernel.h" + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/hostdevice.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { + +template +struct DegreeFunctor { + const T* col_ptr; + HOSTDEVICE explicit inline DegreeFunctor(const T* x) { this->col_ptr = x; } + HOSTDEVICE inline int operator()(T i) const { + return col_ptr[i + 1] - col_ptr[i]; + } +}; + +struct MaxFunctor { + int cap; + HOSTDEVICE explicit inline MaxFunctor(int cap) { this->cap = cap; } + HOSTDEVICE inline int operator()(int x) const { + if (x > cap) { + return cap; + } + return x; + } +}; + +template +__global__ void SampleKernel(const uint64_t rand_seed, + int k, + const int64_t num_nodes, + const T* nodes, + const T* row, + const T* col_ptr, + T* output, + int* output_ptr, + int* output_idxs) { + assert(blockDim.x == WARP_SIZE); + assert(blockDim.y == BLOCK_WARPS); + + int64_t out_row = blockIdx.x * TILE_SIZE + threadIdx.y; + const int64_t last_row = + min(static_cast(blockIdx.x + 1) * TILE_SIZE, num_nodes); +#ifdef PADDLE_WITH_HIP + hiprandState rng; + hiprand_init(rand_seed * gridDim.x + blockIdx.x, + threadIdx.y * WARP_SIZE + threadIdx.x, + 0, + &rng); +#else + curandState rng; + curand_init(rand_seed * gridDim.x + blockIdx.x, + threadIdx.y * WARP_SIZE + threadIdx.x, + 0, + &rng); +#endif + + while (out_row < last_row) { + T node = nodes[out_row]; + T in_row_start = col_ptr[node]; + int deg = col_ptr[node + 1] - in_row_start; + int out_row_start = output_ptr[out_row]; + + if (deg <= k) { + for (int idx = threadIdx.x; idx < deg; idx += WARP_SIZE) { + output[out_row_start + idx] = row[in_row_start + idx]; + } + } else { + for (int idx = threadIdx.x; idx < k; idx += WARP_SIZE) { + output_idxs[out_row_start + idx] = idx; + } +#ifdef PADDLE_WITH_CUDA + __syncwarp(); +#endif + + for (int idx = k + threadIdx.x; idx < deg; idx += WARP_SIZE) { +#ifdef PADDLE_WITH_HIP + const int num = hiprand(&rng) % (idx + 1); +#else + const int num = curand(&rng) % (idx + 1); +#endif + if (num < k) { + atomicMax(reinterpret_cast( // NOLINT + output_idxs + out_row_start + num), + static_cast(idx)); // NOLINT + } + } +#ifdef PADDLE_WITH_CUDA + __syncwarp(); +#endif + + for (int idx = threadIdx.x; idx < k; idx += WARP_SIZE) { + T perm_idx = output_idxs[out_row_start + idx] + in_row_start; + output[out_row_start + idx] = row[perm_idx]; + } + } + + out_row += BLOCK_WARPS; + } +} + +template +int GetTotalSampleNum(const thrust::device_ptr input, + const T* col_ptr, + thrust::device_ptr output_count, + int sample_size, + int bs) { + thrust::transform(input, input + bs, output_count, DegreeFunctor(col_ptr)); + if (sample_size >= 0) { + thrust::transform( + output_count, output_count + bs, output_count, MaxFunctor(sample_size)); + } + int total_sample_num = thrust::reduce(output_count, output_count + bs); + return total_sample_num; +} + +template +void SampleNeighbors(const Context& dev_ctx, + const T* row, + const T* col_ptr, + const thrust::device_ptr input, + thrust::device_ptr output, + thrust::device_ptr output_count, + int sample_size, + int bs, + int total_sample_num) { + thrust::device_vector output_ptr; + thrust::device_vector output_idxs; + output_ptr.resize(bs); + output_idxs.resize(total_sample_num); + thrust::exclusive_scan( + output_count, output_count + bs, output_ptr.begin(), 0); + + constexpr int 
WARP_SIZE = 32; + constexpr int BLOCK_WARPS = 128 / WARP_SIZE; + constexpr int TILE_SIZE = BLOCK_WARPS * 16; + const dim3 block(WARP_SIZE, BLOCK_WARPS); + const dim3 grid((bs + TILE_SIZE - 1) / TILE_SIZE); + SampleKernel<<>>( + 0, + sample_size, + bs, + thrust::raw_pointer_cast(input), + row, + col_ptr, + thrust::raw_pointer_cast(output), + thrust::raw_pointer_cast(output_ptr.data()), + thrust::raw_pointer_cast(output_idxs.data())); +} + +template +__global__ void FisherYatesSampleKernel(const uint64_t rand_seed, + int k, + const int64_t num_rows, + const T* in_rows, + T* src, + const T* dst_count) { +#ifdef PADDLE_WITH_HIP + hiprandState rng; + hiprand_init( + rand_seed * gridDim.x + blockIdx.x, threadIdx.y + threadIdx.x, 0, &rng); +#else + curandState rng; + curand_init( + rand_seed * gridDim.x + blockIdx.x, threadIdx.y + threadIdx.x, 0, &rng); +#endif + CUDA_KERNEL_LOOP(out_row, num_rows) { + const T row = in_rows[out_row]; + const T in_row_start = dst_count[row]; + const int deg = dst_count[row + 1] - in_row_start; + int split; + T tmp; + + if (k < deg) { + if (deg < 2 * k) { + split = k; + } else { + split = deg - k; + } + for (int idx = deg - 1; idx >= split; idx--) { +#ifdef PADDLE_WITH_HIP + const int num = hiprand(&rng) % (idx + 1); +#else + const int num = curand(&rng) % (idx + 1); +#endif + src[in_row_start + idx] = static_cast( + atomicExch(reinterpret_cast( // NOLINT + src + in_row_start + num), + static_cast( // NOLINT + src[in_row_start + idx]))); + } + } + } +} + +template +__global__ void GatherEdge(int k, + int64_t num_rows, + const T* in_rows, + const T* src, + const T* dst_count, + T* outputs, + int* output_ptr, + T* perm_data) { + assert(blockDim.x == WARP_SIZE); + assert(blockDim.y == BLOCK_WARPS); + + int64_t out_row = blockIdx.x * TILE_SIZE + threadIdx.y; + const int64_t last_row = + min(static_cast(blockIdx.x + 1) * TILE_SIZE, num_rows); + + while (out_row < last_row) { + const T row = in_rows[out_row]; + const T in_row_start = dst_count[row]; + const int deg = dst_count[row + 1] - in_row_start; + const T out_row_start = output_ptr[out_row]; + + if (deg <= k) { + for (int idx = threadIdx.x; idx < deg; idx += WARP_SIZE) { + const T in_idx = in_row_start + idx; + outputs[out_row_start + idx] = src[in_idx]; + } + } else { + int split = k; + int begin, end; + if (deg < 2 * k) { + begin = 0; + end = k; + } else { + begin = deg - k; + end = deg; + } + + for (int idx = begin + threadIdx.x; idx < end; idx += WARP_SIZE) { + outputs[out_row_start + idx - begin] = + src[perm_data[in_row_start + idx]]; + } + } + out_row += BLOCK_WARPS; + } +} + +template +void FisherYatesSampleNeighbors(const Context& dev_ctx, + const T* row, + const T* col_ptr, + T* perm_data, + const thrust::device_ptr input, + thrust::device_ptr output, + thrust::device_ptr output_count, + int sample_size, + int bs, + int total_sample_num) { + thrust::device_vector output_ptr; + output_ptr.resize(bs); + thrust::exclusive_scan( + output_count, output_count + bs, output_ptr.begin(), 0); + +#ifdef PADDLE_WITH_HIP + int block = 256; +#else + int block = 1024; +#endif + int max_grid_dimx = dev_ctx.GetCUDAMaxGridDimSize()[0]; + int grid_tmp = (bs + block - 1) / block; + int grid = grid_tmp < max_grid_dimx ? 
grid_tmp : max_grid_dimx; + + FisherYatesSampleKernel<<>>( + 0, sample_size, bs, thrust::raw_pointer_cast(input), perm_data, col_ptr); + + constexpr int GATHER_WARP_SIZE = 32; + constexpr int GATHER_BLOCK_WARPS = 128 / GATHER_WARP_SIZE; + constexpr int GATHER_TILE_SIZE = GATHER_BLOCK_WARPS * 16; + const dim3 gather_block(GATHER_WARP_SIZE, GATHER_BLOCK_WARPS); + const dim3 gather_grid((bs + GATHER_TILE_SIZE - 1) / GATHER_TILE_SIZE); + + GatherEdge< + T, + GATHER_WARP_SIZE, + GATHER_BLOCK_WARPS, + GATHER_TILE_SIZE><<>>( + sample_size, + bs, + thrust::raw_pointer_cast(input), + row, + col_ptr, + thrust::raw_pointer_cast(output), + thrust::raw_pointer_cast(output_ptr.data()), + perm_data); +} + +template +void GraphSampleNeighborsKernel( + const Context& dev_ctx, + const DenseTensor& row, + const DenseTensor& col_ptr, + const DenseTensor& x, + paddle::optional eids, + paddle::optional perm_buffer, + int sample_size, + bool return_eids, + bool flag_perm_buffer, + DenseTensor* out, + DenseTensor* out_count, + DenseTensor* out_eids) { + auto* row_data = row.data(); + auto* col_ptr_data = col_ptr.data(); + auto* x_data = x.data(); + int bs = x.dims()[0]; + + const thrust::device_ptr input(x_data); + + out_count->Resize({bs}); + int* out_count_data = dev_ctx.template Alloc(out_count); + thrust::device_ptr output_count(out_count_data); + + int total_sample_size = GetTotalSampleNum( + input, col_ptr_data, output_count, sample_size, bs); + + out->Resize({static_cast(total_sample_size)}); + T* out_data = dev_ctx.template Alloc(out); + thrust::device_ptr output(out_data); + + if (!flag_perm_buffer) { + SampleNeighbors(dev_ctx, + row_data, + col_ptr_data, + input, + output, + output_count, + sample_size, + bs, + total_sample_size); + } else { + DenseTensor perm_buffer_out(perm_buffer->type()); + const auto* p_perm_buffer = perm_buffer.get_ptr(); + perm_buffer_out.ShareDataWith(*p_perm_buffer); + T* perm_buffer_out_data = + perm_buffer_out.mutable_data(dev_ctx.GetPlace()); + FisherYatesSampleNeighbors(dev_ctx, + row_data, + col_ptr_data, + perm_buffer_out_data, + input, + output, + output_count, + sample_size, + bs, + total_sample_size); + } +} + +} // namespace phi + +PD_REGISTER_KERNEL(graph_sample_neighbors, + GPU, + ALL_LAYOUT, + phi::GraphSampleNeighborsKernel, + int, + int64_t) {} diff --git a/paddle/phi/kernels/graph_reindex_kernel.h b/paddle/phi/kernels/graph_reindex_kernel.h new file mode 100644 index 0000000000000..68f1ebc6f5cc4 --- /dev/null +++ b/paddle/phi/kernels/graph_reindex_kernel.h @@ -0,0 +1,33 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
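+
+// Expected contract (see the CPU kernel above): reindex_src / reindex_dst
+// hold the edge endpoints renumbered into [0, num_unique_nodes), and
+// out_nodes lists the original ids in order of first appearance (the center
+// nodes in X first, then previously unseen neighbors), so that
+// out_nodes[new_id] == old_id.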
+
+#pragma once
+
+#include "paddle/phi/core/dense_tensor.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void GraphReindexKernel(const Context& dev_ctx,
+                        const DenseTensor& x,
+                        const DenseTensor& neighbors,
+                        const DenseTensor& count,
+                        paddle::optional<const DenseTensor&> hashtable_value,
+                        paddle::optional<const DenseTensor&> hashtable_index,
+                        bool flag_buffer_hashtable,
+                        DenseTensor* reindex_src,
+                        DenseTensor* reindex_dst,
+                        DenseTensor* out_nodes);
+
+}  // namespace phi
diff --git a/paddle/phi/kernels/graph_sample_neighbors_kernel.h b/paddle/phi/kernels/graph_sample_neighbors_kernel.h
new file mode 100644
index 0000000000000..f7d205bd08ad0
--- /dev/null
+++ b/paddle/phi/kernels/graph_sample_neighbors_kernel.h
@@ -0,0 +1,36 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/core/dense_tensor.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void GraphSampleNeighborsKernel(
+    const Context& dev_ctx,
+    const DenseTensor& row,
+    const DenseTensor& col_ptr,
+    const DenseTensor& x,
+    paddle::optional<const DenseTensor&> eids,
+    paddle::optional<const DenseTensor&> perm_buffer,
+    int sample_size,
+    bool return_eids,
+    bool flag_perm_buffer,
+    DenseTensor* out,
+    DenseTensor* out_count,
+    DenseTensor* out_eids);
+
+}  // namespace phi
diff --git a/paddle/phi/ops/compat/graph_reindex_sig.cc b/paddle/phi/ops/compat/graph_reindex_sig.cc
new file mode 100644
index 0000000000000..4e1e7ccedc19d
--- /dev/null
+++ b/paddle/phi/ops/compat/graph_reindex_sig.cc
@@ -0,0 +1,30 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/phi/core/compat/op_utils.h"
+
+namespace phi {
+
+KernelSignature GraphReindexOpArgumentMapping(
+    const ArgumentMappingContext& ctx) {
+  return KernelSignature(
+      "graph_reindex",
+      {"X", "Neighbors", "Count", "HashTable_Value", "HashTable_Index"},
+      {"flag_buffer_hashtable"},
+      {"Reindex_Src", "Reindex_Dst", "Out_Nodes"});
+}
+
+}  // namespace phi
+
+PD_REGISTER_ARG_MAPPING_FN(graph_reindex, phi::GraphReindexOpArgumentMapping);
diff --git a/paddle/phi/ops/compat/graph_sample_neighbors_sig.cc b/paddle/phi/ops/compat/graph_sample_neighbors_sig.cc
new file mode 100644
index 0000000000000..dd8aaa95c583d
--- /dev/null
+++ b/paddle/phi/ops/compat/graph_sample_neighbors_sig.cc
@@ -0,0 +1,30 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/phi/core/compat/op_utils.h" + +namespace phi { + +KernelSignature GraphSampleNeighborsOpArgumentMapping( + const ArgumentMappingContext& ctx) { + return KernelSignature("graph_sample_neighbors", + {"Row", "Col_Ptr", "X", "Eids", "Perm_Buffer"}, + {"sample_size", "return_eids", "flag_perm_buffer"}, + {"Out", "Out_Count", "Out_Eids"}); +} + +} // namespace phi + +PD_REGISTER_ARG_MAPPING_FN(graph_sample_neighbors, + phi::GraphSampleNeighborsOpArgumentMapping); diff --git a/python/paddle/fluid/tests/unittests/test_graph_reindex.py b/python/paddle/fluid/tests/unittests/test_graph_reindex.py new file mode 100644 index 0000000000000..52abbbe81aef9 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_graph_reindex.py @@ -0,0 +1,113 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import numpy as np +import paddle +import paddle.fluid as fluid + + +class TestGraphReindex(unittest.TestCase): + def setUp(self): + self.x = np.arange(5).astype("int64") + self.neighbors = np.random.randint(100, size=20).astype("int64") + self.count = np.array([2, 8, 4, 3, 3], dtype="int32") + + # Get numpy result. 
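+        # Reference semantics: out_nodes lists original ids in order of first
+        # appearance (x first, then unseen neighbors); reindex_src maps each
+        # neighbor to its position in out_nodes, and reindex_dst repeats the
+        # i-th center node's new id count[i] times.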
+ out_nodes = list(self.x) + for neighbor in self.neighbors: + if neighbor not in out_nodes: + out_nodes.append(neighbor) + self.out_nodes = np.array(out_nodes, dtype="int64") + reindex_dict = {node: ind for ind, node in enumerate(self.out_nodes)} + self.reindex_src = np.array( + [reindex_dict[node] for node in self.neighbors]) + reindex_dst = [] + for node, c in zip(self.x, self.count): + for i in range(c): + reindex_dst.append(reindex_dict[node]) + self.reindex_dst = np.array(reindex_dst, dtype="int64") + self.num_nodes = np.max(np.concatenate([self.x, self.neighbors])) + 1 + + def test_reindex_result(self): + paddle.disable_static() + x = paddle.to_tensor(self.x) + neighbors = paddle.to_tensor(self.neighbors) + count = paddle.to_tensor(self.count) + value_buffer = paddle.full([self.num_nodes], -1, dtype="int32") + index_buffer = paddle.full([self.num_nodes], -1, dtype="int32") + + reindex_src, reindex_dst, out_nodes = \ + paddle.incubate.graph_reindex(x, neighbors, count) + self.assertTrue(np.allclose(self.reindex_src, reindex_src)) + self.assertTrue(np.allclose(self.reindex_dst, reindex_dst)) + self.assertTrue(np.allclose(self.out_nodes, out_nodes)) + + reindex_src, reindex_dst, out_nodes = \ + paddle.incubate.graph_reindex(x, neighbors, count, + value_buffer, index_buffer, + flag_buffer_hashtable=True) + self.assertTrue(np.allclose(self.reindex_src, reindex_src)) + self.assertTrue(np.allclose(self.reindex_dst, reindex_dst)) + self.assertTrue(np.allclose(self.out_nodes, out_nodes)) + + def test_reindex_result_static(self): + paddle.enable_static() + with paddle.static.program_guard(paddle.static.Program()): + x = paddle.static.data( + name="x", shape=self.x.shape, dtype=self.x.dtype) + neighbors = paddle.static.data( + name="neighbors", + shape=self.neighbors.shape, + dtype=self.neighbors.dtype) + count = paddle.static.data( + name="count", shape=self.count.shape, dtype=self.count.dtype) + value_buffer = paddle.static.data( + name="value_buffer", shape=[self.num_nodes], dtype="int32") + index_buffer = paddle.static.data( + name="index_buffer", shape=[self.num_nodes], dtype="int32") + + reindex_src_1, reindex_dst_1, out_nodes_1 = \ + paddle.incubate.graph_reindex(x, neighbors, count) + reindex_src_2, reindex_dst_2, out_nodes_2 = \ + paddle.incubate.graph_reindex(x, neighbors, count, + value_buffer, index_buffer, + flag_buffer_hashtable=True) + + exe = paddle.static.Executor(paddle.CPUPlace()) + ret = exe.run(feed={ + 'x': self.x, + 'neighbors': self.neighbors, + 'count': self.count, + 'value_buffer': np.full( + [self.num_nodes], -1, dtype="int32"), + 'index_buffer': np.full( + [self.num_nodes], -1, dtype="int32") + }, + fetch_list=[ + reindex_src_1, reindex_dst_1, out_nodes_1, + reindex_src_2, reindex_dst_2, out_nodes_2 + ]) + reindex_src_1, reindex_dst_1, out_nodes_1, reindex_src_2, \ + reindex_dst_2, out_nodes_2 = ret + self.assertTrue(np.allclose(self.reindex_src, reindex_src_1)) + self.assertTrue(np.allclose(self.reindex_dst, reindex_dst_1)) + self.assertTrue(np.allclose(self.out_nodes, out_nodes_1)) + self.assertTrue(np.allclose(self.reindex_src, reindex_src_2)) + self.assertTrue(np.allclose(self.reindex_dst, reindex_dst_2)) + self.assertTrue(np.allclose(self.out_nodes, out_nodes_2)) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_graph_sample_neighbors.py b/python/paddle/fluid/tests/unittests/test_graph_sample_neighbors.py new file mode 100644 index 0000000000000..d2fbeab3fd42c --- /dev/null +++ 
b/python/paddle/fluid/tests/unittests/test_graph_sample_neighbors.py @@ -0,0 +1,209 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import numpy as np +import paddle +import paddle.fluid as fluid + + +class TestGraphSampleNeighbors(unittest.TestCase): + def setUp(self): + num_nodes = 20 + edges = np.random.randint(num_nodes, size=(100, 2)) + edges = np.unique(edges, axis=0) + self.edges_id = np.arange(0, len(edges)).astype("int64") + sorted_edges = edges[np.argsort(edges[:, 1])] + + # Calculate dst index cumsum counts, also means colptr + dst_count = np.zeros(num_nodes) + dst_src_dict = {} + for dst in range(0, num_nodes): + true_index = sorted_edges[:, 1] == dst + dst_count[dst] = np.sum(true_index) + dst_src_dict[dst] = sorted_edges[:, 0][true_index] + dst_count = dst_count.astype("int64") + colptr = np.cumsum(dst_count) + colptr = np.insert(colptr, 0, 0) + + self.row = sorted_edges[:, 0].astype("int64") + self.colptr = colptr.astype("int64") + self.nodes = np.unique(np.random.randint( + num_nodes, size=5)).astype("int64") + self.sample_size = 5 + self.dst_src_dict = dst_src_dict + + def test_sample_result(self): + paddle.disable_static() + row = paddle.to_tensor(self.row) + colptr = paddle.to_tensor(self.colptr) + nodes = paddle.to_tensor(self.nodes) + + out_neighbors, out_count = paddle.incubate.graph_sample_neighbors( + row, colptr, nodes, sample_size=self.sample_size) + out_count_cumsum = paddle.cumsum(out_count) + for i in range(len(out_count)): + if i == 0: + neighbors = out_neighbors[0:out_count_cumsum[i]] + else: + neighbors = out_neighbors[out_count_cumsum[i - 1]: + out_count_cumsum[i]] + # Ensure the correct sample size. + self.assertTrue( + out_count[i] == self.sample_size or + out_count[i] == len(self.dst_src_dict[self.nodes[i]])) + # Ensure no repetitive sample neighbors. + self.assertTrue( + neighbors.shape[0] == paddle.unique(neighbors).shape[0]) + # Ensure the correct sample neighbors. + in_neighbors = np.isin(neighbors.numpy(), + self.dst_src_dict[self.nodes[i]]) + self.assertTrue(np.sum(in_neighbors) == in_neighbors.shape[0]) + + def test_sample_result_fisher_yates_sampling(self): + paddle.disable_static() + if fluid.core.is_compiled_with_cuda(): + row = paddle.to_tensor(self.row) + colptr = paddle.to_tensor(self.colptr) + nodes = paddle.to_tensor(self.nodes) + perm_buffer = paddle.to_tensor(self.edges_id) + + out_neighbors, out_count = paddle.incubate.graph_sample_neighbors( + row, + colptr, + nodes, + perm_buffer=perm_buffer, + sample_size=self.sample_size, + flag_perm_buffer=True) + out_count_cumsum = paddle.cumsum(out_count) + for i in range(len(out_count)): + if i == 0: + neighbors = out_neighbors[0:out_count_cumsum[i]] + else: + neighbors = out_neighbors[out_count_cumsum[i - 1]: + out_count_cumsum[i]] + # Ensure the correct sample size. 
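+                # (out_count[i] is smaller than sample_size only when node i
+                # has fewer neighbors than sample_size.)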
+ self.assertTrue( + out_count[i] == self.sample_size or + out_count[i] == len(self.dst_src_dict[self.nodes[i]])) + # Ensure no repetitive sample neighbors. + self.assertTrue( + neighbors.shape[0] == paddle.unique(neighbors).shape[0]) + # Ensure the correct sample neighbors. + in_neighbors = np.isin(neighbors.numpy(), + self.dst_src_dict[self.nodes[i]]) + self.assertTrue(np.sum(in_neighbors) == in_neighbors.shape[0]) + + def test_sample_result_static(self): + paddle.enable_static() + with paddle.static.program_guard(paddle.static.Program()): + row = paddle.static.data( + name="row", shape=self.row.shape, dtype=self.row.dtype) + colptr = paddle.static.data( + name="colptr", shape=self.colptr.shape, dtype=self.colptr.dtype) + nodes = paddle.static.data( + name="nodes", shape=self.nodes.shape, dtype=self.nodes.dtype) + + out_neighbors, out_count = paddle.incubate.graph_sample_neighbors( + row, colptr, nodes, sample_size=self.sample_size) + exe = paddle.static.Executor(paddle.CPUPlace()) + ret = exe.run(feed={ + 'row': self.row, + 'colptr': self.colptr, + 'nodes': self.nodes + }, + fetch_list=[out_neighbors, out_count]) + out_neighbors, out_count = ret + out_count_cumsum = np.cumsum(out_count) + out_neighbors = np.split(out_neighbors, out_count_cumsum)[:-1] + for neighbors, node, count in zip(out_neighbors, self.nodes, + out_count): + self.assertTrue(count == self.sample_size or + count == len(self.dst_src_dict[node])) + self.assertTrue( + neighbors.shape[0] == np.unique(neighbors).shape[0]) + in_neighbors = np.isin(neighbors, self.dst_src_dict[node]) + self.assertTrue(np.sum(in_neighbors) == in_neighbors.shape[0]) + + def test_raise_errors(self): + paddle.disable_static() + row = paddle.to_tensor(self.row) + colptr = paddle.to_tensor(self.colptr) + nodes = paddle.to_tensor(self.nodes) + + def check_eid_error(): + paddle.incubate.graph_sample_neighbors( + row, + colptr, + nodes, + sample_size=self.sample_size, + return_eids=True) + + def check_perm_buffer_error(): + paddle.incubate.graph_sample_neighbors( + row, + colptr, + nodes, + sample_size=self.sample_size, + flag_perm_buffer=True) + + self.assertRaises(ValueError, check_eid_error) + self.assertRaises(ValueError, check_perm_buffer_error) + + def test_sample_result_with_eids(self): + # Note: Currently return eid results is not initialized. 
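+        # The returned eids are therefore not checked against a reference;
+        # this test only verifies that the API runs with return_eids=True.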
+ paddle.disable_static() + row = paddle.to_tensor(self.row) + colptr = paddle.to_tensor(self.colptr) + nodes = paddle.to_tensor(self.nodes) + eids = paddle.to_tensor(self.edges_id) + + out_neighbors, out_count, _ = paddle.incubate.graph_sample_neighbors( + row, + colptr, + nodes, + eids=eids, + sample_size=self.sample_size, + return_eids=True) + + paddle.enable_static() + with paddle.static.program_guard(paddle.static.Program()): + row = paddle.static.data( + name="row", shape=self.row.shape, dtype=self.row.dtype) + colptr = paddle.static.data( + name="colptr", shape=self.colptr.shape, dtype=self.colptr.dtype) + nodes = paddle.static.data( + name="nodes", shape=self.nodes.shape, dtype=self.nodes.dtype) + eids = paddle.static.data( + name="eids", shape=self.edges_id.shape, dtype=self.nodes.dtype) + + out_neighbors, out_count, _ = paddle.incubate.graph_sample_neighbors( + row, + colptr, + nodes, + eids, + sample_size=self.sample_size, + return_eids=True) + exe = paddle.static.Executor(paddle.CPUPlace()) + ret = exe.run(feed={ + 'row': self.row, + 'colptr': self.colptr, + 'nodes': self.nodes, + 'eids': self.edges_id + }, + fetch_list=[out_neighbors, out_count]) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/incubate/__init__.py b/python/paddle/incubate/__init__.py index 182aae40f2982..d8cc322a66e27 100644 --- a/python/paddle/incubate/__init__.py +++ b/python/paddle/incubate/__init__.py @@ -21,6 +21,8 @@ from .operators import softmax_mask_fuse # noqa: F401 from .operators import graph_send_recv from .operators import graph_khop_sampler +from .operators import graph_sample_neighbors +from .operators import graph_reindex from .tensor import segment_sum from .tensor import segment_mean from .tensor import segment_max @@ -37,6 +39,8 @@ 'softmax_mask_fuse', 'graph_send_recv', 'graph_khop_sampler', + 'graph_sample_neighbors', + 'graph_reindex', 'segment_sum', 'segment_mean', 'segment_max', diff --git a/python/paddle/incubate/operators/__init__.py b/python/paddle/incubate/operators/__init__.py index 073c3afcbcbfc..bc4ba8c3890fd 100644 --- a/python/paddle/incubate/operators/__init__.py +++ b/python/paddle/incubate/operators/__init__.py @@ -17,3 +17,5 @@ from .resnet_unit import ResNetUnit #noqa: F401 from .graph_send_recv import graph_send_recv #noqa: F401 from .graph_khop_sampler import graph_khop_sampler #noqa: F401 +from .graph_sample_neighbors import graph_sample_neighbors #noqa: F401 +from .graph_reindex import graph_reindex #noqa: F401 diff --git a/python/paddle/incubate/operators/graph_reindex.py b/python/paddle/incubate/operators/graph_reindex.py new file mode 100644 index 0000000000000..328b87a699750 --- /dev/null +++ b/python/paddle/incubate/operators/graph_reindex.py @@ -0,0 +1,127 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+import paddle
+from paddle.fluid.layer_helper import LayerHelper
+from paddle.fluid.framework import _non_static_mode
+from paddle.fluid.data_feeder import check_variable_and_dtype
+from paddle.fluid import core
+from paddle import _C_ops
+
+
+def graph_reindex(x,
+                  neighbors,
+                  count,
+                  value_buffer=None,
+                  index_buffer=None,
+                  flag_buffer_hashtable=False,
+                  name=None):
+    """
+    Graph Reindex API.
+
+    This API is mainly used in the Graph Learning domain and should be used
+    in conjunction with the `graph_sample_neighbors` API. Its main purpose
+    is to reindex the ids of the input nodes and to return the corresponding
+    graph edges after reindexing.
+
+    Take input nodes x = [0, 1, 2] as an example.
+    If we have neighbors = [8, 9, 0, 4, 7, 6, 7] and count = [2, 3, 2],
+    then we know that the neighbors of 0 are [8, 9], the neighbors of 1
+    are [0, 4, 7], and the neighbors of 2 are [6, 7].
+
+    Args:
+        x (Tensor): The input nodes for which we sample neighbors. The available
+                    data type is int32, int64.
+        neighbors (Tensor): The neighbors of the input nodes `x`. The data type
+                            should be the same with `x`.
+        count (Tensor): The neighbor count of the input nodes `x`. The data type
+                        should be int32.
+        value_buffer (Tensor|None): Value buffer for the hashtable. The data type
+                                    should be int32, and it should be filled with -1.
+        index_buffer (Tensor|None): Index buffer for the hashtable. The data type
+                                    should be int32, and it should be filled with -1.
+        flag_buffer_hashtable (bool): Whether to use a hashtable buffer to speed up.
+                                      Default is False. Currently only useful for the
+                                      GPU version.
+        name (str, optional): Name for the operation (optional, default is None).
+                              For more information, please refer to :ref:`api_guide_Name`.
+
+    Returns:
+        reindex_src (Tensor): The source node index of the graph edges after reindexing.
+        reindex_dst (Tensor): The destination node index of the graph edges after reindexing.
+        out_nodes (Tensor): The unique input nodes and neighbors before reindexing,
+                            with the input nodes `x` in the front and the neighbor
+                            nodes in the back.
+
+    Examples:
+
+        .. code-block:: python
+
+            import paddle
+
+            x = [0, 1, 2]
+            neighbors = [8, 9, 0, 4, 7, 6, 7]
+            count = [2, 3, 2]
+            x = paddle.to_tensor(x, dtype="int64")
+            neighbors = paddle.to_tensor(neighbors, dtype="int64")
+            count = paddle.to_tensor(count, dtype="int32")
+
+            reindex_src, reindex_dst, out_nodes = \
+                paddle.incubate.graph_reindex(x, neighbors, count)
+            # reindex_src: [3, 4, 0, 5, 6, 7, 6]
+            # reindex_dst: [0, 0, 1, 1, 1, 2, 2]
+            # out_nodes: [0, 1, 2, 8, 9, 4, 7, 6]
+
+    """
+    if flag_buffer_hashtable:
+        if value_buffer is None or index_buffer is None:
+            raise ValueError("`value_buffer` and `index_buffer` should not "
+                             "be None if `flag_buffer_hashtable` is True.")
+
+    if _non_static_mode():
+        reindex_src, reindex_dst, out_nodes = \
+            _C_ops.graph_reindex(x, neighbors, count, value_buffer, index_buffer,
+                                 "flag_buffer_hashtable", flag_buffer_hashtable)
+        return reindex_src, reindex_dst, out_nodes
+
+    check_variable_and_dtype(x, "X", ("int32", "int64"), "graph_reindex")
+    check_variable_and_dtype(neighbors, "Neighbors", ("int32", "int64"),
+                             "graph_reindex")
+    check_variable_and_dtype(count, "Count", ("int32"), "graph_reindex")
+
+    if flag_buffer_hashtable:
+        check_variable_and_dtype(value_buffer, "HashTable_Value", ("int32"),
+                                 "graph_reindex")
+        check_variable_and_dtype(index_buffer, "HashTable_Index", ("int32"),
+                                 "graph_reindex")
+
+    helper = LayerHelper("graph_reindex", **locals())
+    reindex_src = helper.create_variable_for_type_inference(dtype=x.dtype)
+    reindex_dst = helper.create_variable_for_type_inference(dtype=x.dtype)
+    out_nodes = helper.create_variable_for_type_inference(dtype=x.dtype)
+    helper.append_op(
+        type="graph_reindex",
+        inputs={
+            "X": x,
+            "Neighbors": neighbors,
+            "Count": count,
+            "HashTable_Value": value_buffer if flag_buffer_hashtable else None,
+            "HashTable_Index": index_buffer if flag_buffer_hashtable else None,
+        },
+        outputs={
+            "Reindex_Src": reindex_src,
+            "Reindex_Dst": reindex_dst,
+            "Out_Nodes": out_nodes
+        },
+        attrs={"flag_buffer_hashtable": flag_buffer_hashtable})
+    return reindex_src, reindex_dst, out_nodes
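For context, the two new incubate APIs are meant to be chained: sample a
neighborhood first, then compact its ids with `graph_reindex`. A minimal
end-to-end sketch (the tiny CSC graph below is made up for illustration, and
the dtype handling simply follows the docstrings in this patch):

    import paddle

    # A made-up 5-node graph in CSC form: `row` holds source nodes grouped by
    # destination node, `colptr` holds the per-destination offsets into `row`.
    row = paddle.to_tensor([3, 4, 0, 2, 1, 4, 2], dtype="int64")
    colptr = paddle.to_tensor([0, 2, 4, 5, 6, 7], dtype="int64")
    nodes = paddle.to_tensor([0, 1, 2], dtype="int64")

    # Step 1: sample up to two in-neighbors for every input node.
    neighbors, count = paddle.incubate.graph_sample_neighbors(
        row, colptr, nodes, sample_size=2)

    # Step 2: relabel the sampled subgraph into a compact, contiguous id space.
    reindex_src, reindex_dst, out_nodes = \
        paddle.incubate.graph_reindex(nodes, neighbors, count)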
diff --git a/python/paddle/incubate/operators/graph_sample_neighbors.py b/python/paddle/incubate/operators/graph_sample_neighbors.py
new file mode 100644
index 0000000000000..d5a85af7272e7
--- /dev/null
+++ b/python/paddle/incubate/operators/graph_sample_neighbors.py
@@ -0,0 +1,150 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+from paddle.fluid.layer_helper import LayerHelper
+from paddle.fluid.framework import _non_static_mode
+from paddle.fluid.data_feeder import check_variable_and_dtype
+from paddle.fluid import core
+from paddle import _C_ops
+
+
+def graph_sample_neighbors(row,
+                           colptr,
+                           input_nodes,
+                           eids=None,
+                           perm_buffer=None,
+                           sample_size=-1,
+                           return_eids=False,
+                           flag_perm_buffer=False,
+                           name=None):
+    """
+    Graph Sample Neighbors API.
+
+    This API is mainly used in the Graph Learning domain, and its main purpose
+    is to provide a high-performance graph sampling method. For example, we take
+    the CSC (Compressed Sparse Column) format of the input graph edges as `row`
+    and `colptr`, which converts the graph data into a format suitable for
+    sampling. `input_nodes` means the nodes from which we need to sample
+    neighbors, and `sample_size` means the number of neighbors we want to sample
+    for each input node.
+
+    Besides, we support Fisher-Yates sampling in the GPU version.
+
+    Args:
+        row (Tensor): One of the components of the CSC format of the input graph, and
+                      the shape should be [num_edges, 1] or [num_edges]. The available
+                      data type is int32, int64.
+        colptr (Tensor): One of the components of the CSC format of the input graph,
+                         and the shape should be [num_nodes + 1, 1] or [num_nodes + 1].
+                         The data type should be the same with `row`.
+        input_nodes (Tensor): The input nodes we need to sample neighbors for, and the
+                              data type should be the same with `row`.
+        eids (Tensor): The eid information of the input graph. If return_eids is True,
+                       then `eids` should not be None. The data type should be the
+                       same with `row`. Default is None.
+        perm_buffer (Tensor): Permutation buffer for Fisher-Yates sampling. If
+                              `flag_perm_buffer` is True, then `perm_buffer` should not
+                              be None. The data type should be the same with `row`.
+                              Default is None.
+        sample_size (int): The number of neighbors we need to sample. Default value is
+                           -1, which means returning all the neighbors of the input nodes.
+        return_eids (bool): Whether to return the eid information of the sampled edges.
+                            Default is False.
+        flag_perm_buffer (bool): Whether to use the permutation buffer for Fisher-Yates
+                                 sampling on GPU. Default value is False, which means
+                                 not using it.
+        name (str, optional): Name for the operation (optional, default is None).
+                              For more information, please refer to :ref:`api_guide_Name`.
+
+    Returns:
+        out_neighbors (Tensor): The sampled neighbors of the input nodes.
+        out_count (Tensor): The number of sampled neighbors of each input node; the shape
+                            is the same as `input_nodes`.
+        out_eids (Tensor): If `return_eids` is True, we will return the eid information of
+                           the sampled edges.
+
+    Examples:
+        .. code-block:: python
+
+            import paddle
+
+            # edges: (3, 0), (7, 0), (0, 1), (9, 1), (1, 2), (4, 3), (2, 4),
+            #        (9, 5), (3, 5), (9, 6), (1, 6), (9, 8), (7, 8)
+            row = [3, 7, 0, 9, 1, 4, 2, 9, 3, 9, 1, 9, 7]
+            colptr = [0, 2, 4, 5, 6, 7, 9, 11, 11, 13, 13]
+            nodes = [0, 8, 1, 2]
+            sample_size = 2
+            row = paddle.to_tensor(row, dtype="int64")
+            colptr = paddle.to_tensor(colptr, dtype="int64")
+            nodes = paddle.to_tensor(nodes, dtype="int64")
+            out_neighbors, out_count = \
+                paddle.incubate.graph_sample_neighbors(row, colptr, nodes,
+                                                       sample_size=sample_size)
+
+    """
+
+    if return_eids:
+        if eids is None:
+            raise ValueError(
+                "`eids` should not be None if `return_eids` is True.")
+
+    if flag_perm_buffer:
+        if perm_buffer is None:
+            raise ValueError(
+                "`perm_buffer` should not be None if `flag_perm_buffer` "
+                "is True.")
+
+    if _non_static_mode():
+        out_neighbors, out_count, out_eids = _C_ops.graph_sample_neighbors(
+            row, colptr, input_nodes, eids, perm_buffer, "sample_size",
+            sample_size, "return_eids", return_eids, "flag_perm_buffer",
+            flag_perm_buffer)
+        if return_eids:
+            return out_neighbors, out_count, out_eids
+        return out_neighbors, out_count
+
+    check_variable_and_dtype(row, "Row", ("int32", "int64"),
+                             "graph_sample_neighbors")
+    check_variable_and_dtype(colptr, "Col_Ptr", ("int32", "int64"),
+                             "graph_sample_neighbors")
+    check_variable_and_dtype(input_nodes, "X", ("int32", "int64"),
+                             "graph_sample_neighbors")
+    if return_eids:
+        check_variable_and_dtype(eids, "Eids", ("int32", "int64"),
+                                 "graph_sample_neighbors")
+    if flag_perm_buffer:
+        check_variable_and_dtype(perm_buffer, "Perm_Buffer", ("int32", "int64"),
+                                 "graph_sample_neighbors")
+
+    helper = LayerHelper("graph_sample_neighbors", **locals())
+    out_neighbors = helper.create_variable_for_type_inference(dtype=row.dtype)
+    out_count = helper.create_variable_for_type_inference(dtype=row.dtype)
+    out_eids = helper.create_variable_for_type_inference(dtype=row.dtype)
+    helper.append_op(
+        type="graph_sample_neighbors",
+        inputs={
+            "Row": row,
+            "Col_Ptr": colptr,
+            "X": input_nodes,
+            "Eids": eids if return_eids else None,
+            "Perm_Buffer": perm_buffer if flag_perm_buffer else None
+        },
+        outputs={
+            "Out": out_neighbors,
+            "Out_Count": out_count,
+            "Out_Eids": out_eids
+        },
+        attrs={
+            "sample_size": sample_size,
+            "return_eids": return_eids,
+            "flag_perm_buffer": flag_perm_buffer
+        })
+    if return_eids:
+        return out_neighbors, out_count, out_eids
+    return out_neighbors, out_count
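For reference, the CSC inputs (`row`, `colptr`) can be derived from a plain
edge list the same way the unit tests earlier in this patch do it; a minimal
NumPy sketch over a made-up four-node graph, using `np.bincount` in place of
the test's explicit counting loop:

    import numpy as np

    num_nodes = 4
    edges = np.array([[1, 0], [2, 0], [0, 1], [3, 2]])  # (src, dst) pairs

    # Group the edges by destination so each node's in-neighbors are contiguous.
    sorted_edges = edges[np.argsort(edges[:, 1])]
    row = sorted_edges[:, 0].astype("int64")

    # colptr[i]:colptr[i + 1] slices the in-neighbors of node i out of `row`.
    counts = np.bincount(sorted_edges[:, 1], minlength=num_nodes)
    colptr = np.insert(np.cumsum(counts), 0, 0).astype("int64")
    # row    -> [1, 2, 0, 3]
    # colptr -> [0, 2, 3, 4, 4]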

From 78200976e33428e8da03e29289873cf577cf51f8 Mon Sep 17 00:00:00 2001
From: Chen Weihang
Date: Sat, 2 Apr 2022 20:59:13 +0800
Subject: [PATCH 31/93] [Phi] Fix no pinned transform (#41300)

* fix no pinned trans

* fix cond error
---
 paddle/phi/api/lib/data_transform.cc    | 7 ++++---
 paddle/phi/core/compat/convert_utils.cc | 6 +++++-
 2 files changed, 9 insertions(+), 4 deletions(-)

diff --git a/paddle/phi/api/lib/data_transform.cc b/paddle/phi/api/lib/data_transform.cc
index c1fc0fd907bba..90d47977cdf60 100644
--- a/paddle/phi/api/lib/data_transform.cc
+++ b/paddle/phi/api/lib/data_transform.cc
@@ -37,9 +37,10 @@ inline bool NeedTransformDataType(const DataType& input,
 inline bool NeedTransformPlace(const paddle::platform::Place& input,
                                const Backend& target,
                                const TransformFlag& transform_flag) {
-  bool ret = transform_flag.need_trans_backend() &&
-             target != Backend::ALL_BACKEND &&
-             phi::TransToPhiBackend(input) != target;
+  bool ret =
+      input.GetType() == AllocationType::GPUPINNED ||
+      (transform_flag.need_trans_backend() && target != Backend::ALL_BACKEND &&
+       phi::TransToPhiBackend(input) != target);
   return ret;
 }
 
diff --git a/paddle/phi/core/compat/convert_utils.cc b/paddle/phi/core/compat/convert_utils.cc
index cc9c2caa88991..c08dfa64c7f1b 100644
--- a/paddle/phi/core/compat/convert_utils.cc
+++ b/paddle/phi/core/compat/convert_utils.cc
@@ -18,6 +18,7 @@ limitations under the License. */
 #include "paddle/phi/backends/xpu/xpu_info.h"
 #include "paddle/phi/common/place.h"
 #include "paddle/phi/core/compat/op_utils.h"
+#include "paddle/phi/core/enforce.h"
 
 #ifdef PADDLE_WITH_CUSTOM_DEVICE
 #include "paddle/phi/backends/device_manager.h"
@@ -31,6 +32,8 @@ Backend TransToPhiBackend(const phi::Place& place) {
     return Backend::CPU;
   } else if (allocation_type == phi::AllocationType::GPU) {
     return Backend::GPU;
+  } else if (allocation_type == phi::AllocationType::GPUPINNED) {
+    return Backend::GPU;
   } else if (allocation_type == phi::AllocationType::XPU) {
     return Backend::XPU;
   } else if (allocation_type == phi::AllocationType::NPU) {
@@ -40,7 +43,8 @@
         static_cast<size_t>(Backend::NUM_BACKENDS) +
         GetOrRegisterGlobalDeviceTypeId(place.GetDeviceType()));
   } else {
-    return Backend::UNDEFINED;
+    PADDLE_THROW(phi::errors::InvalidArgument(
+        "Unsupported transform %s to phi Backend.", place));
   }
 }
 

From 50714d5cc41d121b9bb979023bc58eabc2a3a49a Mon Sep 17 00:00:00 2001
From: Aurelius84
Date: Sat, 2 Apr 2022 21:07:11 +0800
Subject: [PATCH 32/93] [Eager]Fix eager no take effect problem (#41291)

* [Eager]Fix eager no take effect problem

* add element_wise and fix greater_than
---
 paddle/fluid/pybind/eager_method.cc                  | 11 +++++++++++
 python/paddle/__init__.py                            |  5 ++++-
 python/paddle/fluid/tests/unittests/test_cross_op.py |  4 ++--
 python/paddle/tensor/linalg.py                       |  4 ++++
 python/paddle/tensor/logic.py                        |  3 ++-
 python/paddle/utils/code_gen/api.yaml                |  8 ++++----
 6 files changed, 27 insertions(+), 8 deletions(-)

diff --git a/paddle/fluid/pybind/eager_method.cc b/paddle/fluid/pybind/eager_method.cc
index 37ace14d145c6..d9face124bd82 100644
--- a/paddle/fluid/pybind/eager_method.cc
+++ b/paddle/fluid/pybind/eager_method.cc
@@ -1279,6 +1279,15 @@ static PyObject* tensor__inplace_version(TensorObject* self, PyObject* args,
   EAGER_CATCH_AND_THROW_RETURN_NULL
 }
 
+static PyObject* tensor_method_element_size(TensorObject* self, PyObject* args,
+                                            PyObject* kwargs) {
+  EAGER_TRY
+  uint32_t element_size = framework::DataTypeSize(self->tensor.dtype());
+
+  return ToPyObject(element_size);
+  EAGER_CATCH_AND_THROW_RETURN_NULL
+}
+
 static PyObject* tensor__bump_inplace_version(TensorObject* self,
                                               PyObject* args,
                                               PyObject* kwargs) {
@@ -1417,6 +1426,8 @@ PyMethodDef variable_methods[] = {
      METH_VARARGS | METH_KEYWORDS, NULL},
     {"to_dense", (PyCFunction)(void (*)(void))tensor_method_to_dense,
      METH_VARARGS | METH_KEYWORDS, NULL},
+    {"element_size", (PyCFunction)(void (*)(void))tensor_method_element_size,
+     METH_VARARGS | METH_KEYWORDS, NULL},
     /***the method of sparse tensor****/
     {"_inplace_version", (PyCFunction)(void (*)(void))tensor__inplace_version,
      METH_VARARGS | METH_KEYWORDS, NULL},
diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py
index bba9c226dc07b..e532633b6eb35 100755
--- a/python/paddle/__init__.py
+++ b/python/paddle/__init__.py
@@ -48,7 +48,10 @@ from .framework.dtype import bool  # noqa: F401
 from .framework.dtype import complex64  # noqa: F401
 from .framework.dtype import complex128  # noqa: F401
-from .framework import VarBase as Tensor  # noqa: F401
+if fluid.framework._in_eager_mode_:
+    Tensor = 
framework.core.eager.Tensor +else: + from .framework import VarBase as Tensor # noqa: F401 Tensor.__qualname__ = 'Tensor' # noqa: F401 import paddle.compat # noqa: F401 import paddle.distributed # noqa: F401 diff --git a/python/paddle/fluid/tests/unittests/test_cross_op.py b/python/paddle/fluid/tests/unittests/test_cross_op.py index 6cba72213ff97..8b884583646a7 100644 --- a/python/paddle/fluid/tests/unittests/test_cross_op.py +++ b/python/paddle/fluid/tests/unittests/test_cross_op.py @@ -48,10 +48,10 @@ def init_output(self): self.outputs = {'Out': np.array(z_list).reshape(self.shape)} def test_check_output(self): - self.check_output(check_eager=False) + self.check_output(check_eager=True) def test_check_grad_normal(self): - self.check_grad(['X', 'Y'], 'Out', check_eager=False) + self.check_grad(['X', 'Y'], 'Out', check_eager=True) class TestCrossOpCase1(TestCrossOp): diff --git a/python/paddle/tensor/linalg.py b/python/paddle/tensor/linalg.py index 818ce2f5c6757..8afab2e05f26b 100644 --- a/python/paddle/tensor/linalg.py +++ b/python/paddle/tensor/linalg.py @@ -27,6 +27,9 @@ __all__ = [] +# Consistent with kDefaultDim from C++ Backend +K_DEFAULT_DIM = 9 + def matmul(x, y, transpose_x=False, transpose_y=False, name=None): """ @@ -1157,6 +1160,7 @@ def cross(x, y, axis=None, name=None): # [0. 0. 0.]] """ if in_dygraph_mode(): + axis = K_DEFAULT_DIM if axis is None else axis return _C_ops.final_state_cross(x, y, axis) else: if _in_legacy_dygraph(): diff --git a/python/paddle/tensor/logic.py b/python/paddle/tensor/logic.py index e3ffd36d77972..3896fa535ff22 100755 --- a/python/paddle/tensor/logic.py +++ b/python/paddle/tensor/logic.py @@ -280,7 +280,8 @@ def greater_than(x, y, name=None): print(result1) # result1 = [False False True] """ if in_dygraph_mode(): - return _C_ops.final_state_greater_than(x, y) + axis = -1 # default value + return _C_ops.final_state_greater_than(x, y, axis) else: if _in_legacy_dygraph(): return _C_ops.greater_than(x, y) diff --git a/python/paddle/utils/code_gen/api.yaml b/python/paddle/utils/code_gen/api.yaml index ece46837c6def..b46accfb11b01 100644 --- a/python/paddle/utils/code_gen/api.yaml +++ b/python/paddle/utils/code_gen/api.yaml @@ -610,21 +610,21 @@ func : gelu backward : gelu_grad -- api : greater +- api : greater_equal args : (Tensor x, Tensor y, int axis = -1) output : Tensor infer_meta : func : CompareInferMeta kernel : - func : greater + func : greater_equal -- api : greater_equal +- api : greater_than args : (Tensor x, Tensor y, int axis = -1) output : Tensor infer_meta : func : CompareInferMeta kernel : - func : greater_equal + func : greater_than - api : gumbel_softmax args : (Tensor x, float temperature, bool hard, int axis) From 2a01a15742c38ff9b6c392e4554fa06111bdd22b Mon Sep 17 00:00:00 2001 From: huzhiqiang <912790387@qq.com> Date: Sat, 2 Apr 2022 21:50:20 +0800 Subject: [PATCH 33/93] [Infrt] skip grad kernel in infrt frame (#41315) * code * code --- tools/infrt/get_compat_kernel_signature.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/tools/infrt/get_compat_kernel_signature.py b/tools/infrt/get_compat_kernel_signature.py index 45dc931fac19d..a66a236b0f975 100644 --- a/tools/infrt/get_compat_kernel_signature.py +++ b/tools/infrt/get_compat_kernel_signature.py @@ -19,6 +19,13 @@ skip_list = ["adam_sig.cc", "adamw_sig.cc"] +def is_grad_kernel(kernel_info): + kernel_name = kernel_info.split(",")[0] + if kernel_name.endswith("_grad"): + return True + return False + + def parse_compat_registry(kernel_info): name, inputs_str, 
attrs_str, outputs_str = kernel_info.split(",{")
     kernel_info = {}
@@ -62,6 +69,8 @@ def get_compat_kernels_info():
                             "").strip("return").strip("KernelSignature(").strip(
                                 "\);").replace("\"", "").replace("\\", "")
                     registry = False
+                    if is_grad_kernel(data):
+                        continue
                     name, registry_info = parse_compat_registry(data)
 
                     if name in kernels_info:

From e0ccaeafaf64d8c8cd2e2579b0d973e4cec622f7 Mon Sep 17 00:00:00 2001
From: Leo Chen
Date: Sat, 2 Apr 2022 22:07:37 +0800
Subject: [PATCH 34/93] [new-exec] fit empty program for new executor (#41328)

---
 .../fluid/framework/new_executor/interpretercore.cc |  8 ++++++--
 .../interpreter/test_standalone_executor.py         | 12 ++++++++++++
 2 files changed, 18 insertions(+), 2 deletions(-)

diff --git a/paddle/fluid/framework/new_executor/interpretercore.cc b/paddle/fluid/framework/new_executor/interpretercore.cc
index 1b15ca6746257..cf0b64cbc3a70 100644
--- a/paddle/fluid/framework/new_executor/interpretercore.cc
+++ b/paddle/fluid/framework/new_executor/interpretercore.cc
@@ -516,6 +516,12 @@ void InterpreterCore::RunInstruction(const Instruction& instr_node) {
 
 void InterpreterCore::ExecuteInstructionList(
     const std::vector<Instruction>& vec_instr) {
+  unfinished_op_numer_ = vec_instr.size();
+  if (unfinished_op_numer_ == 0) {
+    VLOG(4) << "No op to run, return";
+    return;
+  }
+
   // NOTE(zhiqiu): get the prepared deps from std::future, and async prepare
   // those for the next step
   auto atomic_deps = async_work_queue_->AtomicDeps();
@@ -524,8 +530,6 @@ void InterpreterCore::ExecuteInstructionList(
   async_work_queue_->PrepareAtomicDeps(dependecy_count_);
   async_work_queue_->PrepareAtomicVarRef(global_scope_->VecMetaInfo());
 
-  unfinished_op_numer_ = vec_instr.size();
-
   exception_holder_.Clear();
 
   for (size_t i = 0; i < dependecy_count_.size(); ++i) {
diff --git a/python/paddle/fluid/tests/unittests/interpreter/test_standalone_executor.py b/python/paddle/fluid/tests/unittests/interpreter/test_standalone_executor.py
index cff4f7f41d02b..c07d4cc15bee0 100644
--- a/python/paddle/fluid/tests/unittests/interpreter/test_standalone_executor.py
+++ b/python/paddle/fluid/tests/unittests/interpreter/test_standalone_executor.py
@@ -277,6 +277,18 @@ def test_compiled_program(self):
         for x, y in zip(gt, res):
             self.assertTrue(np.array_equal(x, y))
 
+    def test_empty_program(self):
+        program = paddle.static.Program()
+        exe = paddle.static.Executor(self.place)
+        for i in range(10):
+            out = exe.run()  # old executor
+
+        for i in range(10):
+            print(i, flush=1)
+            os.environ['FLAGS_USE_STANDALONE_EXECUTOR'] = '1'
+            out = exe.run(program, feed=None)
+            del os.environ['FLAGS_USE_STANDALONE_EXECUTOR']
+
 
 class TestException(unittest.TestCase):
     def setUp(self):

From af247f958295930f5b15b4e26a6bcb55c7c08370 Mon Sep 17 00:00:00 2001
From: hong <43953930+phlrain@users.noreply.github.com>
Date: Sun, 3 Apr 2022 10:59:54 +0800
Subject: [PATCH 35/93] fix reduce prod backward bug (#41357)

---
 .../paddle/fluid/tests/unittests/op_test.py |  2 --
 .../fluid/tests/unittests/test_reduce_op.py | 20 ++++++++++++-------
 python/paddle/utils/code_gen/backward.yaml  |  2 +-
 3 files changed, 14 insertions(+), 10 deletions(-)

diff --git a/python/paddle/fluid/tests/unittests/op_test.py b/python/paddle/fluid/tests/unittests/op_test.py
index 1756537ba6240..be883d243f795 100644
--- a/python/paddle/fluid/tests/unittests/op_test.py
+++ b/python/paddle/fluid/tests/unittests/op_test.py
@@ -1559,8 +1559,6 @@ def calculate_output(self):
 
     def _compare_numpy(self, name, actual_np, expect_np):
         with _test_eager_guard():
-            print(actual_np)
-            
print(expect_np) super()._compare_numpy(name, actual_np, expect_np) def convert_uint16_to_float_ifneed(self, actual_np, expect_np): diff --git a/python/paddle/fluid/tests/unittests/test_reduce_op.py b/python/paddle/fluid/tests/unittests/test_reduce_op.py index 98607fb07fedf..69693f57bb2f3 100644 --- a/python/paddle/fluid/tests/unittests/test_reduce_op.py +++ b/python/paddle/fluid/tests/unittests/test_reduce_op.py @@ -238,10 +238,14 @@ def test_check_output(self): self.check_output(check_eager=True) +def raw_reduce_prod(x, dim=[0], keep_dim=False): + return paddle.prod(x, dim, keep_dim) + + class TestProdOp(OpTest): def setUp(self): self.op_type = "reduce_prod" - self.python_api = paddle.prod + self.python_api = raw_reduce_prod self.init_data_type() self.inputs = {'X': np.random.random((5, 6, 10)).astype(self.data_type)} self.outputs = {'Out': self.inputs['X'].prod(axis=0)} @@ -251,15 +255,16 @@ def init_data_type(self): ) else "float64" def test_check_output(self): - self.check_output() + self.check_output(check_eager=True) def test_check_grad(self): - self.check_grad(['X'], 'Out') + self.check_grad(['X'], 'Out', check_eager=True) class TestProd6DOp(OpTest): def setUp(self): self.op_type = "reduce_prod" + self.python_api = raw_reduce_prod self.init_data_type() self.inputs = { 'X': np.random.random((5, 6, 2, 3, 4, 2)).astype(self.data_type) @@ -274,15 +279,16 @@ def init_data_type(self): ) else "float64" def test_check_output(self): - self.check_output() + self.check_output(check_eager=True) def test_check_grad(self): - self.check_grad(['X'], 'Out') + self.check_grad(['X'], 'Out', check_eager=True) class TestProd8DOp(OpTest): def setUp(self): self.op_type = "reduce_prod" + self.python_api = raw_reduce_prod self.init_data_type() self.inputs = { 'X': np.random.random( @@ -298,10 +304,10 @@ def init_data_type(self): ) else "float64" def test_check_output(self): - self.check_output() + self.check_output(check_eager=True) def test_check_grad(self): - self.check_grad(['X'], 'Out') + self.check_grad(['X'], 'Out', check_eager=True) class TestAllOp(OpTest): diff --git a/python/paddle/utils/code_gen/backward.yaml b/python/paddle/utils/code_gen/backward.yaml index 6d046cb68d93d..ad22723c994cf 100644 --- a/python/paddle/utils/code_gen/backward.yaml +++ b/python/paddle/utils/code_gen/backward.yaml @@ -769,7 +769,7 @@ func : UnchangedInferMeta param : [x] kernel : - func : reduce_prod_grad + func : prod_grad - backward_api : relu_double_grad forward : relu_grad (Tensor out, Tensor grad_out) -> Tensor(grad_x) From bce9c8c4e97e30406e5bfd78feeeec3c31a80601 Mon Sep 17 00:00:00 2001 From: Weilong Wu Date: Sun, 3 Apr 2022 11:19:52 +0800 Subject: [PATCH 36/93] [Eager] Support two callback related tests (#41275) --- .../tests/test_callback_reduce_lr_on_plateau.py | 15 +++++++++++++-- python/paddle/tests/test_callback_visualdl.py | 8 +++++++- 2 files changed, 20 insertions(+), 3 deletions(-) diff --git a/python/paddle/tests/test_callback_reduce_lr_on_plateau.py b/python/paddle/tests/test_callback_reduce_lr_on_plateau.py index e950528ee4b65..d7680537f378b 100644 --- a/python/paddle/tests/test_callback_reduce_lr_on_plateau.py +++ b/python/paddle/tests/test_callback_reduce_lr_on_plateau.py @@ -29,6 +29,7 @@ from paddle.vision.datasets import MNIST from paddle.metric import Accuracy from paddle.nn.layer.loss import CrossEntropyLoss +from paddle.fluid.framework import _test_eager_guard, _in_legacy_dygraph # Accelerate unittest @@ -38,7 +39,7 @@ def __len__(self): class TestReduceLROnPlateau(unittest.TestCase): - def 
test_reduce_lr_on_plateau(self): + def func_reduce_lr_on_plateau(self): transform = T.Compose([T.Transpose(), T.Normalize([127.5], [127.5])]) train_dataset = CustomMnist(mode='train', transform=transform) val_dataset = CustomMnist(mode='test', transform=transform) @@ -59,7 +60,12 @@ def test_reduce_lr_on_plateau(self): epochs=10, callbacks=[callbacks]) - def test_warn_or_error(self): + def test_reduce_lr_on_plateau(self): + with _test_eager_guard(): + self.func_reduce_lr_on_plateau() + self.func_reduce_lr_on_plateau() + + def func_warn_or_error(self): with self.assertRaises(ValueError): paddle.callbacks.ReduceLROnPlateau(factor=2.0) # warning @@ -101,6 +107,11 @@ def test_warn_or_error(self): epochs=3, callbacks=[callbacks]) + def test_warn_or_error(self): + with _test_eager_guard(): + self.func_warn_or_error() + self.func_warn_or_error() + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/tests/test_callback_visualdl.py b/python/paddle/tests/test_callback_visualdl.py index db3b83f2b1414..355e88edd2bec 100644 --- a/python/paddle/tests/test_callback_visualdl.py +++ b/python/paddle/tests/test_callback_visualdl.py @@ -29,6 +29,7 @@ from paddle.vision.datasets import MNIST from paddle.metric import Accuracy from paddle.nn.layer.loss import CrossEntropyLoss +from paddle.fluid.framework import _test_eager_guard, _in_legacy_dygraph class MnistDataset(MNIST): @@ -43,7 +44,7 @@ def setUp(self): def tearDown(self): shutil.rmtree(self.save_dir) - def test_visualdl_callback(self): + def func_visualdl_callback(self): # visualdl not support python2 if sys.version_info < (3, ): return @@ -70,6 +71,11 @@ def test_visualdl_callback(self): batch_size=64, callbacks=callback) + def test_visualdl_callback(self): + with _test_eager_guard(): + self.func_visualdl_callback() + self.func_visualdl_callback() + if __name__ == '__main__': unittest.main() From 2ae10efd0916d39a397f8a46c0a0e31aa46c279c Mon Sep 17 00:00:00 2001 From: Weilong Wu Date: Sun, 3 Apr 2022 11:20:34 +0800 Subject: [PATCH 37/93] [Eager] Support transformer tests in eager mode (#41347) --- ..._imperative_transformer_sorted_gradient.py | 38 ++++++++++--------- 1 file changed, 20 insertions(+), 18 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_imperative_transformer_sorted_gradient.py b/python/paddle/fluid/tests/unittests/test_imperative_transformer_sorted_gradient.py index 010c8aeccacd6..531c89fb19ec6 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_transformer_sorted_gradient.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_transformer_sorted_gradient.py @@ -21,7 +21,7 @@ from paddle.fluid.dygraph import to_variable, guard from paddle.fluid.dygraph import TracedLayer from test_imperative_base import new_program_scope -from paddle.fluid.framework import _test_eager_guard +from paddle.fluid.framework import _test_eager_guard, in_dygraph_mode, _in_legacy_dygraph from paddle.fluid import core import numpy as np import six @@ -1041,8 +1041,9 @@ def run_dygraph(): with guard(): fluid.set_flags({'FLAGS_sort_sum_gradient': True}) - dy_avg_cost_value, dy_sum_cost_value, dy_predict_value, dy_token_num_value, \ - dy_param_init, dy_param_updated = run_dygraph() + if _in_legacy_dygraph(): + dy_avg_cost_value, dy_sum_cost_value, dy_predict_value, dy_token_num_value, \ + dy_param_init, dy_param_updated = run_dygraph() with new_program_scope(): paddle.seed(seed) @@ -1116,21 +1117,22 @@ def run_dygraph(): for k in range(4, len(out)): static_param_updated[static_param_name_list[k - 4]] = out[k] - - 
self.assertTrue( - np.array_equal(static_avg_cost_value, dy_avg_cost_value)) - self.assertTrue( - np.array_equal(static_sum_cost_value, dy_sum_cost_value)) - self.assertTrue(np.array_equal(static_predict_value, dy_predict_value)) - self.assertTrue( - np.array_equal(static_token_num_value, dy_token_num_value)) - - for key, value in six.iteritems(static_param_init): - self.assertTrue(np.array_equal(value, dy_param_init[key])) - for key, value in six.iteritems(static_param_updated): - self.assertTrue(np.array_equal(value, dy_param_updated[key])) - - # check eager result + if _in_legacy_dygraph(): + self.assertTrue( + np.array_equal(static_avg_cost_value, dy_avg_cost_value)) + self.assertTrue( + np.array_equal(static_sum_cost_value, dy_sum_cost_value)) + self.assertTrue( + np.array_equal(static_predict_value, dy_predict_value)) + self.assertTrue( + np.array_equal(static_token_num_value, dy_token_num_value)) + + for key, value in six.iteritems(static_param_init): + self.assertTrue(np.array_equal(value, dy_param_init[key])) + for key, value in six.iteritems(static_param_updated): + self.assertTrue(np.array_equal(value, dy_param_updated[key])) + + # compare eager result with imperative result with guard(): fluid.set_flags({'FLAGS_sort_sum_gradient': False}) dy_avg_cost_value, dy_sum_cost_value, dy_predict_value, dy_token_num_value, \ From e4914734980979f0413412ecfcf6d92893687295 Mon Sep 17 00:00:00 2001 From: From00 Date: Sun, 3 Apr 2022 11:31:39 +0800 Subject: [PATCH 38/93] Add some yaml config (#41053) * Add yaml config * Add yaml for flatten_contiguous_range_op * Remove h_sigmoid yaml * Fix CI errors * Fix code format * Fix flatten OP errors * Fix conflicts * Fix CI errors * Remove flatten_contiguous_range OP * Remove redundant code * Fix typos --- .../kernels/cpu/hierarchical_sigmoid_grad.h | 4 +- .../cpu/hierarchical_sigmoid_grad_kernel.cc | 8 +- .../hierarchical_sigmoid_grad_kernel.h | 4 +- .../hierarchical_sigmoid_grad_kernel.cc | 8 +- .../hierarchical_sigmoid_grad_kernel.h | 4 +- .../ops/compat/hierarchical_sigmoid_sig.cc | 12 +- .../test_functional_conv2d_transpose.py | 23 +++- .../test_functional_conv3d_transpose.py | 21 +++- .../tests/unittests/test_index_select_op.py | 5 +- .../fluid/tests/unittests/test_norm_all.py | 11 +- .../fluid/tests/unittests/test_pool1d_api.py | 19 +++- .../fluid/tests/unittests/test_pool2d_api.py | 17 ++- .../fluid/tests/unittests/test_pool3d_api.py | 14 ++- .../fluid/tests/unittests/test_roll_op.py | 5 +- .../tests/unittests/test_searchsorted_op.py | 4 +- .../tests/unittests/test_tril_triu_op.py | 5 +- python/paddle/nn/functional/conv.py | 35 ++++-- python/paddle/nn/functional/pooling.py | 100 ++++++++++++----- python/paddle/tensor/creation.py | 10 +- python/paddle/tensor/linalg.py | 7 +- python/paddle/tensor/manipulation.py | 5 +- python/paddle/tensor/search.py | 10 +- python/paddle/utils/code_gen/api.yaml | 102 ++++++++++++++++- python/paddle/utils/code_gen/api_base.py | 2 +- python/paddle/utils/code_gen/backward.yaml | 105 ++++++++++++++++++ 25 files changed, 449 insertions(+), 91 deletions(-) diff --git a/paddle/phi/kernels/cpu/hierarchical_sigmoid_grad.h b/paddle/phi/kernels/cpu/hierarchical_sigmoid_grad.h index b79aab96c0fc2..cc67f8e7f210c 100644 --- a/paddle/phi/kernels/cpu/hierarchical_sigmoid_grad.h +++ b/paddle/phi/kernels/cpu/hierarchical_sigmoid_grad.h @@ -31,11 +31,11 @@ void HierarchicalSigmoidGradKernelImpl( const DenseTensor& x, const DenseTensor& w, const DenseTensor& label, - const DenseTensor& pre_out, - const DenseTensor& out_grad, 
paddle::optional<const DenseTensor&> path,
     paddle::optional<const DenseTensor&> code,
     paddle::optional<const DenseTensor&> bias,
+    const DenseTensor& pre_out,
+    const DenseTensor& out_grad,
     int num_classes,
     bool remote_prefetch,
     int trainer_id,
diff --git a/paddle/phi/kernels/cpu/hierarchical_sigmoid_grad_kernel.cc b/paddle/phi/kernels/cpu/hierarchical_sigmoid_grad_kernel.cc
index f64a1a8162a37..9edc9f87d4b1f 100644
--- a/paddle/phi/kernels/cpu/hierarchical_sigmoid_grad_kernel.cc
+++ b/paddle/phi/kernels/cpu/hierarchical_sigmoid_grad_kernel.cc
@@ -25,11 +25,11 @@ void HierarchicalSigmoidGradKernel(const Context& ctx,
                                    const DenseTensor& x,
                                    const DenseTensor& w,
                                    const DenseTensor& label,
-                                   const DenseTensor& pre_out,
-                                   const DenseTensor& out_grad,
                                    paddle::optional<const DenseTensor&> path,
                                    paddle::optional<const DenseTensor&> code,
                                    paddle::optional<const DenseTensor&> bias,
+                                   const DenseTensor& pre_out,
+                                   const DenseTensor& out_grad,
                                    int num_classes,
                                    bool remote_prefetch,
                                    int trainer_id,
@@ -44,11 +44,11 @@ void HierarchicalSigmoidGradKernel(const Context& ctx,
       x,
       w,
       label,
-      pre_out,
-      out_grad,
       path,
       code,
       bias,
+      pre_out,
+      out_grad,
       num_classes,
       remote_prefetch,
       trainer_id,
diff --git a/paddle/phi/kernels/hierarchical_sigmoid_grad_kernel.h b/paddle/phi/kernels/hierarchical_sigmoid_grad_kernel.h
index f7a327cd3f566..7922a767db23c 100644
--- a/paddle/phi/kernels/hierarchical_sigmoid_grad_kernel.h
+++ b/paddle/phi/kernels/hierarchical_sigmoid_grad_kernel.h
@@ -23,11 +23,11 @@ void HierarchicalSigmoidGradKernel(const Context& ctx,
                                    const DenseTensor& x,
                                    const DenseTensor& w,
                                    const DenseTensor& label,
-                                   const DenseTensor& pre_out,
-                                   const DenseTensor& out_grad,
                                    paddle::optional<const DenseTensor&> path,
                                    paddle::optional<const DenseTensor&> code,
                                    paddle::optional<const DenseTensor&> bias,
+                                   const DenseTensor& pre_out,
+                                   const DenseTensor& out_grad,
                                    int num_classes,
                                    bool remote_prefetch,
                                    int trainer_id,
diff --git a/paddle/phi/kernels/selected_rows/hierarchical_sigmoid_grad_kernel.cc b/paddle/phi/kernels/selected_rows/hierarchical_sigmoid_grad_kernel.cc
index 80b2a1f6678a2..1660601bbd36e 100644
--- a/paddle/phi/kernels/selected_rows/hierarchical_sigmoid_grad_kernel.cc
+++ b/paddle/phi/kernels/selected_rows/hierarchical_sigmoid_grad_kernel.cc
@@ -40,11 +40,11 @@ void HierarchicalSigmoidGradKernel(const Context& ctx,
                                    const DenseTensor& x,
                                    const DenseTensor& w,
                                    const DenseTensor& label,
-                                   const DenseTensor& pre_out,
-                                   const DenseTensor& out_grad,
                                    paddle::optional<const DenseTensor&> path,
                                    paddle::optional<const DenseTensor&> code,
                                    paddle::optional<const DenseTensor&> bias,
+                                   const DenseTensor& pre_out,
+                                   const DenseTensor& out_grad,
                                    int num_classes,
                                    bool remote_prefetch,
                                    int trainer_id,
@@ -70,11 +70,11 @@ void HierarchicalSigmoidGradKernel(const Context& ctx,
       x,
       w,
       label,
-      pre_out,
-      out_grad,
       path,
       code,
       bias,
+      pre_out,
+      out_grad,
       num_classes,
       remote_prefetch,
       trainer_id,
diff --git a/paddle/phi/kernels/selected_rows/hierarchical_sigmoid_grad_kernel.h b/paddle/phi/kernels/selected_rows/hierarchical_sigmoid_grad_kernel.h
index 557c8b1bc5eed..4c03b83d80fff 100644
--- a/paddle/phi/kernels/selected_rows/hierarchical_sigmoid_grad_kernel.h
+++ b/paddle/phi/kernels/selected_rows/hierarchical_sigmoid_grad_kernel.h
@@ -25,11 +25,11 @@ void HierarchicalSigmoidGradKernel(const Context& ctx,
                                    const DenseTensor& x,
                                    const DenseTensor& w,
                                    const DenseTensor& label,
-                                   const DenseTensor& pre_out,
-                                   const DenseTensor& out_grad,
                                    paddle::optional<const DenseTensor&> path,
                                    paddle::optional<const DenseTensor&> code,
                                    paddle::optional<const DenseTensor&> bias,
+                                   const DenseTensor& pre_out,
+                                   const DenseTensor& out_grad,
                                    int num_classes,
                                    bool remote_prefetch,
                                    int trainer_id,
diff --git a/paddle/phi/ops/compat/hierarchical_sigmoid_sig.cc b/paddle/phi/ops/compat/hierarchical_sigmoid_sig.cc
index 
20183d1a9b066..58c190fb657bb 100644 --- a/paddle/phi/ops/compat/hierarchical_sigmoid_sig.cc +++ b/paddle/phi/ops/compat/hierarchical_sigmoid_sig.cc @@ -38,11 +38,11 @@ KernelSignature HierarchicalSigmoidGradOpArgumentMapping( {"X", "W", "Label", - "PreOut", - GradVarName("Out"), "PathTable", "PathCode", - "Bias"}, + "Bias", + "PreOut", + GradVarName("Out")}, {"num_classes", "remote_prefetch", "trainer_id", @@ -57,11 +57,11 @@ KernelSignature HierarchicalSigmoidGradOpArgumentMapping( {"X", "W", "Label", - "PreOut", - GradVarName("Out"), "PathTable", "PathCode", - "Bias"}, + "Bias", + "PreOut", + GradVarName("Out")}, {"num_classes", "remote_prefetch", "trainer_id", diff --git a/python/paddle/fluid/tests/unittests/test_functional_conv2d_transpose.py b/python/paddle/fluid/tests/unittests/test_functional_conv2d_transpose.py index f25a15106c491..781169d70c17c 100644 --- a/python/paddle/fluid/tests/unittests/test_functional_conv2d_transpose.py +++ b/python/paddle/fluid/tests/unittests/test_functional_conv2d_transpose.py @@ -13,12 +13,13 @@ # limitations under the License. import paddle -import paddle.nn.functional as F -from paddle import fluid +import unittest +import numpy as np import paddle.fluid.dygraph as dg import paddle.fluid.initializer as I -import numpy as np -import unittest +import paddle.nn.functional as F +from paddle import fluid +from paddle.fluid.framework import _test_eager_guard from unittest import TestCase @@ -159,12 +160,22 @@ def test_identity_cpu(self): self.place = fluid.CPUPlace() self._test_identity() + def test_identity_cpu_check_eager(self): + with _test_eager_guard(): + self.test_identity_cpu() + @unittest.skipIf(not fluid.core.is_compiled_with_cuda(), "core is not compiled with CUDA") def test_identity_gpu(self): self.place = fluid.CUDAPlace(0) self._test_identity() + @unittest.skipIf(not fluid.core.is_compiled_with_cuda(), + "core is not compiled with CUDA") + def test_identity_gpu_check_eager(self): + with _test_eager_guard(): + self.test_identity_gpu() + class TestFunctionalConv2DError(TestCase): batch_size = 4 @@ -520,6 +531,10 @@ def test_dygraph_exception(self): with self.assertRaises(ValueError): self.dygraph_case() + def test_dygraph_exception_check_eager(self): + with _test_eager_guard(): + self.test_dygraph_exception() + def test_static_exception(self): with self.assertRaises(ValueError): self.static_graph_case() diff --git a/python/paddle/fluid/tests/unittests/test_functional_conv3d_transpose.py b/python/paddle/fluid/tests/unittests/test_functional_conv3d_transpose.py index a003de6596822..6f25d65aac227 100644 --- a/python/paddle/fluid/tests/unittests/test_functional_conv3d_transpose.py +++ b/python/paddle/fluid/tests/unittests/test_functional_conv3d_transpose.py @@ -13,12 +13,13 @@ # limitations under the License. 
import paddle -import paddle.nn.functional as F -from paddle import fluid +import numpy as np import paddle.fluid.dygraph as dg import paddle.fluid.initializer as I -import numpy as np +import paddle.nn.functional as F import unittest +from paddle import fluid +from paddle.fluid.framework import _test_eager_guard from unittest import TestCase @@ -165,12 +166,22 @@ def test_identity_cpu(self): self.place = fluid.CPUPlace() self._test_identity() + def test_identity_cpu_check_eager(self): + with _test_eager_guard(): + self.test_identity_cpu() + @unittest.skipIf(not fluid.core.is_compiled_with_cuda(), "core is not compiled with CUDA") def test_identity_gpu(self): self.place = fluid.CUDAPlace(0) self._test_identity() + @unittest.skipIf(not fluid.core.is_compiled_with_cuda(), + "core is not compiled with CUDA") + def test_identity_gpu_check_eager(self): + with _test_eager_guard(): + self.test_identity_gpu() + class TestFunctionalConv3DTransposeError(TestCase): batch_size = 4 @@ -540,6 +551,10 @@ def test_dygraph_exception(self): with self.assertRaises(ValueError): self.dygraph_case() + def test_dygraph_exception_check_eager(self): + with _test_eager_guard(): + self.test_dygraph_exception() + def test_static_exception(self): with self.assertRaises(ValueError): self.static_graph_case() diff --git a/python/paddle/fluid/tests/unittests/test_index_select_op.py b/python/paddle/fluid/tests/unittests/test_index_select_op.py index f4545d406901c..0c0e946fddede 100644 --- a/python/paddle/fluid/tests/unittests/test_index_select_op.py +++ b/python/paddle/fluid/tests/unittests/test_index_select_op.py @@ -25,6 +25,7 @@ class TestIndexSelectOp(OpTest): def setUp(self): + self.python_api = paddle.index_select self.op_type = "index_select" self.init_dtype_type() index_np = np.random.randint( @@ -54,10 +55,10 @@ def init_dtype_type(self): self.index_size = 100 def test_check_output(self): - self.check_output() + self.check_output(check_eager=True) def test_check_grad_normal(self): - self.check_grad(['X'], 'Out') + self.check_grad(['X'], 'Out', check_eager=True) class TestIndexSelectOpCase2(TestIndexSelectOp): diff --git a/python/paddle/fluid/tests/unittests/test_norm_all.py b/python/paddle/fluid/tests/unittests/test_norm_all.py index 17c45299d0fc5..5b0a9599bf84e 100644 --- a/python/paddle/fluid/tests/unittests/test_norm_all.py +++ b/python/paddle/fluid/tests/unittests/test_norm_all.py @@ -86,8 +86,13 @@ def frobenius_norm(x, axis=None, keepdims=False): return r +def final_state_frobenius_norm(x, dim, keep_dim, reduce_all): + return paddle.linalg.norm(x, p='fro', axis=dim, keepdim=keep_dim) + + class TestFrobeniusNormOp(OpTest): def setUp(self): + self.python_api = final_state_frobenius_norm self.op_type = "frobenius_norm" self.init_test_case() x = (np.random.random(self.shape) + 1.0).astype(self.dtype) @@ -102,10 +107,10 @@ def setUp(self): self.outputs = {'Out': norm} def test_check_output(self): - self.check_output() + self.check_output(check_eager=True) def test_check_grad(self): - self.check_grad(['X'], 'Out') + self.check_grad(['X'], 'Out', check_eager=True) def init_test_case(self): self.shape = [2, 3, 4, 5] @@ -122,7 +127,7 @@ def init_test_case(self): self.dtype = "float32" def test_check_grad(self): - self.check_grad(['X'], 'Out') + self.check_grad(['X'], 'Out', check_eager=True) class TestPnormOp(OpTest): diff --git a/python/paddle/fluid/tests/unittests/test_pool1d_api.py b/python/paddle/fluid/tests/unittests/test_pool1d_api.py index 9e7b0c8a1efa7..e1cfcc3f06602 100644 --- 
a/python/paddle/fluid/tests/unittests/test_pool1d_api.py +++ b/python/paddle/fluid/tests/unittests/test_pool1d_api.py @@ -12,16 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -import numpy as np +import paddle import unittest +import paddle.fluid as fluid +import paddle.fluid.core as core +import paddle.nn.functional as F import numpy as np from op_test import OpTest -import paddle.fluid.core as core -import paddle.fluid as fluid from paddle.fluid import compiler, Program, program_guard -import paddle -import paddle.nn.functional as F -import paddle.fluid as fluid +from paddle.fluid.framework import _test_eager_guard def adaptive_start_index(index, input_size, output_size): @@ -244,6 +243,10 @@ def test_pool1d(self): self.check_avg_dygraph_padding_same(place) self.check_max_dygraph_return_index_results(place) + def test_dygraph_final_state_api(self): + with _test_eager_guard(): + self.test_pool1d() + class TestPool2DError_API(unittest.TestCase): def test_error_api(self): @@ -370,6 +373,10 @@ def run_stride_out_of_range(): self.assertRaises(ValueError, run_stride_out_of_range) + def test_dygraph_final_state_api(self): + with _test_eager_guard(): + self.test_error_api() + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_pool2d_api.py b/python/paddle/fluid/tests/unittests/test_pool2d_api.py index 872bec666bf8c..e86fa0ec48330 100644 --- a/python/paddle/fluid/tests/unittests/test_pool2d_api.py +++ b/python/paddle/fluid/tests/unittests/test_pool2d_api.py @@ -12,14 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -from test_pool2d_op import adaptive_start_index, adaptive_end_index, pool2D_forward_naive, avg_pool2D_forward_naive, max_pool2D_forward_naive import unittest -from op_test import OpTest +import paddle import numpy as np +import paddle.fluid as fluid import paddle.fluid.core as core +from op_test import OpTest +from paddle.fluid.framework import _test_eager_guard from paddle.nn.functional import avg_pool2d, max_pool2d -import paddle.fluid as fluid -import paddle +from test_pool2d_op import adaptive_start_index, adaptive_end_index, pool2D_forward_naive, avg_pool2D_forward_naive, max_pool2D_forward_naive class TestPool2D_API(unittest.TestCase): @@ -324,6 +325,10 @@ def test_pool2d(self): self.check_max_dygraph_ceilmode_results(place) self.check_max_dygraph_nhwc_results(place) + def test_dygraph_final_state_api(self): + with _test_eager_guard(): + self.test_pool2d() + class TestPool2DError_API(unittest.TestCase): def test_error_api(self): @@ -524,6 +529,10 @@ def run_stride_out_of_range(): self.assertRaises(ValueError, run_stride_out_of_range) + def test_dygraph_final_state_api(self): + with _test_eager_guard(): + self.test_error_api() + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_pool3d_api.py b/python/paddle/fluid/tests/unittests/test_pool3d_api.py index cddb09e5daa41..f20d2aad49f27 100644 --- a/python/paddle/fluid/tests/unittests/test_pool3d_api.py +++ b/python/paddle/fluid/tests/unittests/test_pool3d_api.py @@ -15,13 +15,15 @@ from __future__ import print_function from __future__ import division +import paddle import unittest import numpy as np -import paddle +import paddle.fluid as fluid import paddle.fluid.core as core from op_test import OpTest -import paddle.fluid as fluid +from paddle.fluid.framework import _test_eager_guard from paddle.nn.functional 
import avg_pool3d, max_pool3d from test_pool3d_op import adaptive_start_index, adaptive_end_index, pool3D_forward_naive, avg_pool3D_forward_naive, max_pool3D_forward_naive @@ -326,6 +328,10 @@ def test_pool3d(self): self.check_max_dygraph_ndhwc_results(place) self.check_max_dygraph_ceilmode_results(place) + def test_dygraph_final_state_api(self): + with _test_eager_guard(): + self.test_pool3d() + class TestPool3DError_API(unittest.TestCase): def test_error_api(self): @@ -499,6 +505,10 @@ def run_size_out_of_range(): self.assertRaises(ValueError, run_size_out_of_range) + def test_dygraph_final_state_api(self): + with _test_eager_guard(): + self.test_error_api() + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_roll_op.py b/python/paddle/fluid/tests/unittests/test_roll_op.py index bca7665b814db..c315aa9b74618 100644 --- a/python/paddle/fluid/tests/unittests/test_roll_op.py +++ b/python/paddle/fluid/tests/unittests/test_roll_op.py @@ -25,6 +25,7 @@ class TestRollOp(OpTest): def setUp(self): + self.python_api = paddle.roll self.op_type = "roll" self.init_dtype_type() self.inputs = {'X': np.random.random(self.x_shape).astype(self.dtype)} @@ -41,10 +42,10 @@ def init_dtype_type(self): self.axis = [0, -2] def test_check_output(self): - self.check_output() + self.check_output(check_eager=True) def test_check_grad_normal(self): - self.check_grad(['X'], 'Out') + self.check_grad(['X'], 'Out', check_eager=True) class TestRollOpCase2(TestRollOp): diff --git a/python/paddle/fluid/tests/unittests/test_searchsorted_op.py b/python/paddle/fluid/tests/unittests/test_searchsorted_op.py index f595d06d5bce7..f802b0adfcb2a 100644 --- a/python/paddle/fluid/tests/unittests/test_searchsorted_op.py +++ b/python/paddle/fluid/tests/unittests/test_searchsorted_op.py @@ -25,7 +25,7 @@ class TestSearchSorted(OpTest): def setUp(self): - + self.python_api = paddle.searchsorted self.op_type = "searchsorted" self.init_test_case() @@ -41,7 +41,7 @@ def setUp(self): } def test_check_output(self): - self.check_output() + self.check_output(check_eager=True) def init_test_case(self): self.sorted_sequence = np.array([1, 3, 5, 7, 9]).astype("float32") diff --git a/python/paddle/fluid/tests/unittests/test_tril_triu_op.py b/python/paddle/fluid/tests/unittests/test_tril_triu_op.py index cdb5f66f57892..00f6169fa3103 100644 --- a/python/paddle/fluid/tests/unittests/test_tril_triu_op.py +++ b/python/paddle/fluid/tests/unittests/test_tril_triu_op.py @@ -28,6 +28,7 @@ class TrilTriuOpDefaultTest(OpTest): def setUp(self): self.initTestCase() + self.python_api = paddle.tril if self.real_op_type == 'tril' else paddle.triu self.real_np_op = getattr(np, self.real_op_type) self.op_type = "tril_triu" @@ -42,10 +43,10 @@ def setUp(self): } def test_check_output(self): - self.check_output() + self.check_output(check_eager=True) def test_check_grad_normal(self): - self.check_grad(['X'], 'Out') + self.check_grad(['X'], 'Out', check_eager=True) def initTestCase(self): self.real_op_type = np.random.choice(['triu', 'tril']) diff --git a/python/paddle/nn/functional/conv.py b/python/paddle/nn/functional/conv.py index f7d765d854116..414f5cefff498 100644 --- a/python/paddle/nn/functional/conv.py +++ b/python/paddle/nn/functional/conv.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
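The unit-test diffs above all apply one refactoring: each case gains coverage of the new eager final-state mode, either by re-running an existing test method inside _test_eager_guard() or by moving the test body into a func_* helper that is run once under each mode. A minimal sketch of the pattern, with illustrative class and method names (only _test_eager_guard itself is taken from the diffs):

    import unittest
    from paddle.fluid.framework import _test_eager_guard

    class ExampleEagerCase(unittest.TestCase):
        def func_example(self):
            # the original assertions live here, unchanged
            pass

        def test_example(self):
            # first pass exercises the new eager final-state mode
            with _test_eager_guard():
                self.func_example()
            # second pass exercises legacy dygraph
            self.func_example()
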
from __future__ import print_function -from paddle.fluid.framework import _global_flags import numpy as np from ...device import get_cudnn_version @@ -22,15 +21,18 @@ from ...fluid.data_feeder import check_variable_and_dtype from ...framework import ParamAttr from ...fluid.layer_helper import LayerHelper -from paddle import _C_ops from ...tensor.manipulation import unsqueeze, squeeze from ...tensor.math import add from ...fluid.layers import nn +from paddle import _C_ops +from paddle import get_flags +from paddle import in_dynamic_mode from paddle.device import is_compiled_with_cuda -from paddle.device import is_compiled_with_rocm from paddle.device import is_compiled_with_npu -from paddle import in_dynamic_mode -from paddle import get_flags +from paddle.device import is_compiled_with_rocm +from paddle.fluid.framework import _global_flags +from paddle.fluid.framework import _in_legacy_dygraph +from paddle.fluid.framework import in_dygraph_mode __all__ = [] @@ -1061,7 +1063,17 @@ def conv2d_transpose(x, op_type = 'depthwise_conv2d_transpose' use_cudnn = False - if in_dynamic_mode(): + if in_dygraph_mode(): + final_state_op = _C_ops.final_state_conv2d_transpose if op_type == 'conv2d_transpose' else _C_ops.final_state_depthwise_conv2d_transpose + pre_bias = final_state_op(x, weight, stride, padding, output_padding, + output_size, padding_algorithm, groups, + dilation, data_format) + if bias is not None: + return nn.elementwise_add(pre_bias, bias, axis=channel_dim) + else: + return pre_bias + + if _in_legacy_dygraph(): attrs = ('output_padding', output_padding, 'output_size', output_size, 'strides', stride, 'paddings', padding, 'padding_algorithm', padding_algorithm, 'dilations', dilation, 'groups', groups, @@ -1468,7 +1480,16 @@ def conv3d_transpose(x, op_type = 'conv3d_transpose' data_format_ = "NHWC" if channel_last else "NCHW" - if in_dynamic_mode(): + if in_dygraph_mode(): + pre_bias = _C_ops.final_state_conv3d_transpose( + x, weight, stride, padding, output_padding, output_size, + padding_algorithm, groups, dilation, data_format_) + if bias is not None: + return nn.elementwise_add(pre_bias, bias, axis=channel_dim) + else: + return pre_bias + + if _in_legacy_dygraph(): attrs = ('output_padding', output_padding, 'output_size', output_size, 'paddings', padding, "padding_algorithm", padding_algorithm, 'strides', stride, 'dilations', dilation, 'groups', groups, diff --git a/python/paddle/nn/functional/pooling.py b/python/paddle/nn/functional/pooling.py index 34a0159fbb0dc..b9cae4784725d 100755 --- a/python/paddle/nn/functional/pooling.py +++ b/python/paddle/nn/functional/pooling.py @@ -18,6 +18,8 @@ from ...fluid.data_feeder import check_type, check_variable_and_dtype from paddle import _C_ops from paddle import in_dynamic_mode +from paddle.fluid.framework import _in_legacy_dygraph +from paddle.fluid.framework import in_dygraph_mode __all__ = [] @@ -344,13 +346,18 @@ def avg_pool2d(x, padding, padding_algorithm = _update_padding_nd( padding, 2, channel_last, ceil_mode=ceil_mode) - if in_dynamic_mode(): - output = _C_ops.pool2d(x, 'pooling_type', 'avg', 'ksize', kernel_size, - 'global_pooling', False, 'padding_algorithm', - padding_algorithm, 'strides', stride, 'paddings', - padding, 'use_cudnn', True, 'ceil_mode', - ceil_mode, 'use_mkldnn', False, 'exclusive', - exclusive, 'data_format', data_format) + if in_dygraph_mode() or _in_legacy_dygraph(): + if in_dygraph_mode(): + output = _C_ops.final_state_pool2d( + x, kernel_size, stride, padding, ceil_mode, exclusive, + data_format, 'avg', False, 
False, padding_algorithm) + else: + output = _C_ops.pool2d( + x, 'pooling_type', 'avg', 'ksize', kernel_size, + 'global_pooling', False, 'padding_algorithm', padding_algorithm, + 'strides', stride, 'paddings', padding, 'use_cudnn', True, + 'ceil_mode', ceil_mode, 'use_mkldnn', False, 'exclusive', + exclusive, 'data_format', data_format) if divisor_override is None: return output else: @@ -466,13 +473,18 @@ def avg_pool3d(x, _check_value_limitation(kernel_size, "kernel_size", min_limit=1e-3) _check_value_limitation(stride, "stride", min_limit=1e-3) - if in_dynamic_mode(): - output = _C_ops.pool3d( - x, 'pooling_type', 'avg', 'ksize', kernel_size, 'strides', stride, - 'paddings', padding, 'global_pooling', False, 'padding_algorithm', - padding_algorithm, 'use_cudnn', True, 'ceil_mode', ceil_mode, - 'use_mkldnn', False, 'exclusive', exclusive, 'data_format', - data_format) + if in_dygraph_mode() or _in_legacy_dygraph(): + if in_dygraph_mode(): + output = _C_ops.final_state_pool3d( + x, kernel_size, stride, padding, ceil_mode, exclusive, + data_format, 'avg', False, False, padding_algorithm) + if _in_legacy_dygraph(): + output = _C_ops.pool3d( + x, 'pooling_type', 'avg', 'ksize', kernel_size, 'strides', + stride, 'paddings', padding, 'global_pooling', False, + 'padding_algorithm', padding_algorithm, 'use_cudnn', True, + 'ceil_mode', ceil_mode, 'use_mkldnn', False, 'exclusive', + exclusive, 'data_format', data_format) if divisor_override is None: return output else: @@ -585,7 +597,20 @@ def max_pool1d(x, # use 2d to implenment 1d should expand padding in advance. padding = _expand_low_nd_padding(padding) - if in_dynamic_mode(): + if in_dygraph_mode(): + if return_mask: + pool_out = _C_ops.final_state_max_pool2d_with_index( + x, kernel_size, stride, padding, False, False) + return (squeeze(pool_out[0], [2]), + squeeze(pool_out[1], + [2])) if return_mask else squeeze(pool_out[0], [2]) + else: + pool_out = _C_ops.final_state_pool2d( + x, kernel_size, stride, padding, ceil_mode, True, data_format, + 'max', False, False, padding_algorithm) + return squeeze(pool_out, [2]) + + if _in_legacy_dygraph(): if return_mask: pool_out = _C_ops.max_pool2d_with_index( x, 'ksize', kernel_size, 'global_pooling', False, 'strides', @@ -1027,7 +1052,17 @@ def max_pool2d(x, "When setting return_mask to true, data_format must be set to NCHW in API:max_pool2d" ) - if in_dynamic_mode(): + if in_dygraph_mode(): + if return_mask: + output = _C_ops.final_state_max_pool2d_with_index( + x, kernel_size, stride, padding, False, False) + return output if return_mask else output[0] + else: + return _C_ops.final_state_pool2d( + x, kernel_size, stride, padding, ceil_mode, True, data_format, + 'max', False, False, padding_algorithm) + + if _in_legacy_dygraph(): if return_mask: output = _C_ops.max_pool2d_with_index( x, 'ksize', kernel_size, 'global_pooling', False, 'strides', @@ -1158,7 +1193,17 @@ def max_pool3d(x, "When setting return_mask to true, data_format must be set to NCDHW in API:max_pool3d" ) - if in_dynamic_mode(): + if in_dygraph_mode(): + if return_mask: + output = _C_ops.final_state_max_pool3d_with_index( + x, kernel_size, stride, padding, False, False) + return output if return_mask else output[0] + else: + return _C_ops.final_state_pool3d( + x, kernel_size, stride, padding, ceil_mode, True, data_format, + 'max', False, False, padding_algorithm) + + if _in_legacy_dygraph(): if return_mask: output = _C_ops.max_pool3d_with_index( x, 'pooling_type', 'max', 'ksize', kernel_size, 'strides', @@ -1355,11 +1400,15 @@ def 
adaptive_avg_pool2d(x, output_size, data_format='NCHW', name=None): if output_size[1] == None: output_size[1] = in_w - if in_dynamic_mode(): - output = _C_ops.pool2d(x, 'pooling_type', 'avg', 'ksize', output_size, - 'global_pooling', False, 'adaptive', True, - 'data_format', data_format) - return output + if in_dygraph_mode(): + return _C_ops.final_state_pool2d(x, output_size, [1, 1], [0, 0], False, + True, data_format, 'avg', False, True, + "EXPLICIT") + + if _in_legacy_dygraph(): + return _C_ops.pool2d(x, 'pooling_type', 'avg', 'ksize', output_size, + 'global_pooling', False, 'adaptive', True, + 'data_format', data_format) l_type = 'pool2d' @@ -1462,10 +1511,9 @@ def adaptive_avg_pool3d(x, output_size, data_format='NCDHW', name=None): output_size[2] = in_w if in_dynamic_mode(): - output = _C_ops.pool3d(x, 'pooling_type', 'avg', 'ksize', output_size, - 'global_pooling', False, 'adaptive', True, - 'data_format', data_format) - return output + return _C_ops.pool3d(x, 'pooling_type', 'avg', 'ksize', output_size, + 'global_pooling', False, 'adaptive', True, + 'data_format', data_format) l_type = 'pool3d' diff --git a/python/paddle/tensor/creation.py b/python/paddle/tensor/creation.py index ca16995f84d2f..166ae58a19770 100644 --- a/python/paddle/tensor/creation.py +++ b/python/paddle/tensor/creation.py @@ -661,7 +661,10 @@ def tril(x, diagonal=0, name=None): # [ 9, 10, 0, 0]]) """ - if paddle.in_dynamic_mode(): + if in_dygraph_mode(): + return _C_ops.final_state_tril_triu(x, diagonal, True) + + if _in_legacy_dygraph(): op = getattr(_C_ops, 'tril_triu') return op(x, 'diagonal', diagonal, "lower", True) @@ -728,7 +731,10 @@ def triu(x, diagonal=0, name=None): # [ 0, 10, 11, 12]]) """ - if paddle.in_dynamic_mode(): + if in_dygraph_mode(): + return _C_ops.final_state_tril_triu(x, diagonal, False) + + if _in_legacy_dygraph(): op = getattr(_C_ops, 'tril_triu') return op(x, 'diagonal', diagonal, "lower", False) diff --git a/python/paddle/tensor/linalg.py b/python/paddle/tensor/linalg.py index 8afab2e05f26b..81c99c5a41e03 100644 --- a/python/paddle/tensor/linalg.py +++ b/python/paddle/tensor/linalg.py @@ -254,7 +254,12 @@ def frobenius_norm(input, dim=None, keepdim=False, name=None): raise ValueError( "The dim of frobenius norm op should be None or two elements list!" ) - if paddle.in_dynamic_mode(): + + if in_dygraph_mode(): + if dim is None: + return _C_ops.final_state_frobenius_norm(input, keepdim, True) + return _C_ops.final_state_frobenius_norm(input, dim, keepdim, False) + if _in_legacy_dygraph(): if dim is None: return _C_ops.frobenius_norm(input, 'keep_dim', keepdim, 'reduce_all', True) diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index 9fe3304bf2471..ca807c286a05b 100755 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -796,7 +796,10 @@ def roll(x, shifts, axis=None, name=None): else: axis = [] - if paddle.in_dynamic_mode(): + if in_dygraph_mode(): + return _C_ops.final_state_roll(x, shifts, axis) + + if _in_legacy_dygraph(): return _C_ops.roll(x, 'axis', axis, 'shifts', shifts) helper = LayerHelper("roll", **locals()) diff --git a/python/paddle/tensor/search.py b/python/paddle/tensor/search.py index 15c9e060c5517..5c290aa0eb760 100644 --- a/python/paddle/tensor/search.py +++ b/python/paddle/tensor/search.py @@ -319,7 +319,10 @@ def index_select(x, index, axis=0, name=None): # [ 9. 10. 
10.]] """ - if paddle.in_dynamic_mode(): + if in_dygraph_mode(): + return _C_ops.final_state_index_select(x, index, axis) + + if _in_legacy_dygraph(): return _C_ops.index_select(x, index, 'dim', axis) helper = LayerHelper("index_select", **locals()) @@ -946,8 +949,11 @@ def searchsorted(sorted_sequence, # [1, 3, 4, 5]]) """ + if in_dygraph_mode(): + return _C_ops.final_state_searchsorted(sorted_sequence, values, + out_int32, right) - if paddle.in_dynamic_mode(): + if _in_legacy_dygraph(): return _C_ops.searchsorted(sorted_sequence, values, "out_int32", out_int32, "right", right) diff --git a/python/paddle/utils/code_gen/api.yaml b/python/paddle/utils/code_gen/api.yaml index b46accfb11b01..b3bf1f7890400 100644 --- a/python/paddle/utils/code_gen/api.yaml +++ b/python/paddle/utils/code_gen/api.yaml @@ -306,6 +306,24 @@ kernel : func : conj +- api : conv2d_transpose + args : (Tensor x, Tensor filter, int[] strides, int[] paddings, int[] output_padding, int[] output_size, str padding_algorithm, int groups, int[] dilations, str data_format) + output : Tensor(out) + infer_meta : + func : ConvTransposeInferMeta + kernel : + func : conv2d_transpose + backward : conv2d_transpose_grad + +- api : conv3d_transpose + args : (Tensor x, Tensor filter, int[] strides, int[] paddings, int[] output_padding, int[] output_size, str padding_algorithm, int groups, int[] dilations, str data_format) + output : Tensor(out) + infer_meta : + func : ConvTransposeInferMeta + kernel : + func : conv3d_transpose + backward : conv3d_transpose_grad + - api : copy_to args : (Tensor x, Place place, bool blocking) output : Tensor @@ -359,6 +377,15 @@ kernel : func : cumsum +- api : depthwise_conv2d_transpose + args : (Tensor x, Tensor filter, int[] strides, int[] paddings, int[] output_padding, int[] output_size, str padding_algorithm, int groups, int[] dilations, str data_format) + output : Tensor(out) + infer_meta : + func : ConvTransposeInferMeta + kernel : + func : depthwise_conv2d_transpose + backward : depthwise_conv2d_transpose_grad + - api : diag args : (Tensor x, int offset, float padding_value) output : Tensor @@ -558,6 +585,15 @@ func : fmin backward : fmin_grad +- api : frobenius_norm + args : (Tensor x, int64_t[] axis, bool keep_dim, bool reduce_all) + output : Tensor(out) + infer_meta : + func : ReduceInferMetaBase + kernel : + func : frobenius_norm + backward : frobenius_norm_grad + - api : full args : (IntArray shape, Scalar value, DataType dtype=DataType::FLOAT32, Place place=CPUPlace()) output: Tensor @@ -695,6 +731,16 @@ backward : index_sample_grad # no_need_buffer : x +- api : index_select + args : (Tensor x, Tensor index, int dim) + output : Tensor(out) + infer_meta : + func : IndexSelectInferMeta + kernel : + func : index_select + data_type : x + backward : index_select_grad + # is_empty - api : is_empty args : (Tensor x) @@ -954,6 +1000,24 @@ func : max backward : max_grad +- api : max_pool2d_with_index + args : (Tensor x, int[] kernel_size, int[] strides, int[] paddings, bool global_pooling, bool adaptive) + output : Tensor(out), Tensor(mask) + infer_meta : + func : MaxPoolWithIndexInferMeta + kernel : + func : max_pool2d_with_index + backward : max_pool2d_with_index_grad + +- api : max_pool3d_with_index + args : (Tensor x, int[] kernel_size, int[] strides, int[] paddings, bool global_pooling, bool adaptive) + output : Tensor(out), Tensor(mask) + infer_meta : + func : MaxPoolWithIndexInferMeta + kernel : + func : max_pool3d_with_index + backward : max_pool3d_with_index_grad + - api : maximum args : 
(Tensor x, Tensor y) output : Tensor(out) @@ -1129,8 +1193,18 @@ output : Tensor(out) infer_meta : func : PoolInferMeta - kernel: + kernel : func : pool2d + backward : pool2d_grad + +- api : pool3d + args : (Tensor x, int[] kernel_size, int[] strides, int[] paddings, bool ceil_mode, bool exclusive, str data_format, str pooling_type, bool global_pooling, bool adaptive, str padding_algorithm) + output : Tensor(out) + infer_meta : + func : PoolInferMeta + kernel : + func : pool3d + backward : pool3d_grad - api : prelu args : (Tensor x, Tensor alpha, str data_format, str mode) @@ -1194,6 +1268,15 @@ intermediate : xshape backward: reshape_grad +- api : roll + args : (Tensor x, IntArray shifts, int64_t[] axis) + output : Tensor(out) + infer_meta : + func : RollInferMeta + kernel : + func : roll + backward : roll_grad + - api : round args : (Tensor x) output : Tensor(out) @@ -1235,6 +1318,14 @@ backward : scatter_nd_add_grad # no_need_buffer : updates +- api : searchsorted + args : (Tensor sorted_sequence, Tensor value, bool out_int32, bool right) + output : Tensor(out) + infer_meta : + func : SearchsortedInferMeta + kernel : + func : searchsorted + # segment_pool - api : segment_pool args : (Tensor x, Tensor segment_ids, str pooltype) @@ -1522,6 +1613,15 @@ func : triangular_solve # backward : triangular_solve_grad +- api : tril_triu + args : (Tensor x, int diagonal, bool lower) + output : Tensor(out) + infer_meta : + func : TrilTriuInferMeta + kernel : + func : tril_triu + backward : tril_triu_grad + - api : trunc args : (Tensor x) output : Tensor diff --git a/python/paddle/utils/code_gen/api_base.py b/python/paddle/utils/code_gen/api_base.py index e281484f69744..d3c3177827b28 100644 --- a/python/paddle/utils/code_gen/api_base.py +++ b/python/paddle/utils/code_gen/api_base.py @@ -710,9 +710,9 @@ def gen_dense_tensor_kernel_code(self, code_indent, inplace_flag=False): self.outputs['types'], 'SetKernelOutput', code_indent, inplace_flag) api_func_name = self.get_api_func_name() + ('_' if inplace_flag else '') return f""" +{code_indent} VLOG(6) << "{self.api} API kernel key: [" << kernel_backend << ", " << kernel_layout << ", "<< kernel_data_type << "]"; {code_indent} const auto& kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError( {code_indent} "{self.kernel['func'][0]}", {{kernel_backend, kernel_layout, kernel_data_type}}); -{code_indent} VLOG(6) << "{self.api} API kernel key: [" << kernel_backend << ", " << kernel_layout << ", "<< kernel_data_type << "]"; {code_indent} VLOG(6) << "{self.api} API kernel: " << kernel; {code_indent} auto* dev_ctx = GetDeviceContextByBackend(kernel_backend); diff --git a/python/paddle/utils/code_gen/backward.yaml b/python/paddle/utils/code_gen/backward.yaml index ad22723c994cf..d3d589d00f7f2 100644 --- a/python/paddle/utils/code_gen/backward.yaml +++ b/python/paddle/utils/code_gen/backward.yaml @@ -172,6 +172,24 @@ kernel : func : cholesky_solve_grad +- backward_api : conv2d_transpose_grad + forward : conv2d_transpose(Tensor x, Tensor filter, int[] strides, int[] paddings, int[] output_padding, int[] output_size, str padding_algorithm, int groups, int[] dilations, str data_format) -> Tensor(out) + args : (Tensor x, Tensor filter, Tensor out_grad, int[] strides, int[] paddings, int[] output_padding, int[] output_size, str padding_algorithm, int groups, int[] dilations, str data_format) + output : Tensor(x_grad), Tensor(filter_grad) + infer_meta : + func : ConvTransposeGradInferMeta + kernel : + func : conv2d_transpose_grad + +- backward_api : 
conv3d_transpose_grad + forward : conv3d_transpose(Tensor x, Tensor filter, int[] strides, int[] paddings, int[] output_padding, int[] output_size, str padding_algorithm, int groups, int[] dilations, str data_format) -> Tensor(out) + args : (Tensor x, Tensor filter, Tensor out_grad, int[] strides, int[] paddings, int[] output_padding, int[] output_size, str padding_algorithm, int groups, int[] dilations, str data_format) + output : Tensor(x_grad), Tensor(filter_grad) + infer_meta : + func : ConvTransposeGradInferMeta + kernel : + func : conv3d_transpose_grad + - backward_api : cos_grad forward : cos (Tensor x) -> Tensor(out) args : (Tensor x, Tensor out_grad) @@ -221,6 +239,15 @@ # kernel : # func : gumbel_softmax_grad +- backward_api : depthwise_conv2d_transpose_grad + forward : depthwise_conv2d_transpose(Tensor x, Tensor filter, int[] strides, int[] paddings, int[] output_padding, int[] output_size, str padding_algorithm, int groups, int[] dilations, str data_format) -> Tensor(out) + args : (Tensor x, Tensor filter, Tensor out_grad, int[] strides, int[] paddings, int[] output_padding, int[] output_size, str padding_algorithm, int groups, int[] dilations, str data_format) + output : Tensor(x_grad), Tensor(filter_grad) + infer_meta : + func : ConvTransposeGradInferMeta + kernel : + func : depthwise_conv2d_transpose_grad + - backward_api : diagonal_grad forward : diagonal (Tensor x, int offset, int axis1, int axis2) -> Tensor(out) args : (Tensor x, Tensor out_grad, int offset = 0, int axis1 = 0, int axis2 = 1) @@ -352,6 +379,16 @@ kernel : func : fmin_grad +- backward_api : frobenius_norm_grad + forward : frobenius_norm(Tensor x, int64_t[] axis, bool keep_dim, bool reduce_all) -> Tensor(out) + args : (Tensor x, Tensor out, Tensor out_grad, int64_t[] axis, bool keep_dim, bool reduce_all) + output : Tensor(x_grad) + infer_meta : + func : UnchangedInferMeta + param : [x] + kernel : + func : frobenius_norm_grad + - backward_api : gather_nd_grad forward : gather_nd (Tensor x, Tensor index) -> Tensor(out) args : (Tensor x, Tensor index, Tensor out_grad) @@ -403,6 +440,17 @@ func : index_sample_grad data_type : out_grad +- backward_api : index_select_grad + forward : index_select(Tensor x, Tensor index, int dim) -> Tensor(out) + args : (Tensor x, Tensor index, Tensor out_grad, int dim) + output : Tensor(x_grad) + infer_meta : + func : UnchangedInferMeta + param : [x] + kernel : + func : index_select_grad + data_type : x + - backward_api : kldiv_loss_grad forward : kldiv_loss(Tensor x, Tensor label, str reduction) -> Tensor(out) args : (Tensor x, Tensor label, Tensor out_grad, str reduction) @@ -597,6 +645,24 @@ kernel : func : max_grad +- backward_api : max_pool2d_with_index_grad + forward : max_pool2d_with_index(Tensor x, int[] kernel_size, int[] strides, int[] paddings, bool global_pooling, bool adaptive) -> Tensor(out), Tensor(mask) + args : (Tensor x, Tensor mask, Tensor out_grad, int[] kernel_size, int[] strides, int[] paddings, bool global_pooling, bool adaptive) + output : Tensor(x_grad) + infer_meta : + func : MaxPoolWithIndexGradInferMeta + kernel : + func : max_pool2d_with_index_grad + +- backward_api : max_pool3d_with_index_grad + forward : max_pool3d_with_index(Tensor x, int[] kernel_size, int[] strides, int[] paddings, bool global_pooling, bool adaptive) -> Tensor(out), Tensor(mask) + args : (Tensor x, Tensor mask, Tensor out_grad, int[] kernel_size, int[] strides, int[] paddings, bool global_pooling, bool adaptive) + output : Tensor(x_grad) + infer_meta : + func : 
MaxPoolWithIndexGradInferMeta + kernel : + func : max_pool3d_with_index_grad + - backward_api : maximum_grad forward : maximum(Tensor x, Tensor y) -> Tensor(out) args : (Tensor x, Tensor y, Tensor out_grad, int axis=-1) @@ -719,6 +785,24 @@ kernel : func : pad3d_grad +- backward_api : pool2d_grad + forward : pool2d(Tensor x, int[] kernel_size, int[] strides, int[] paddings, bool ceil_mode, bool exclusive, str data_format, str pooling_type, bool global_pooling, bool adaptive, str padding_algorithm) -> Tensor(out) + args : (Tensor x, Tensor out, Tensor out_grad, int[] kernel_size, int[] strides, int[] paddings, bool ceil_mode, bool exclusive, str data_format, str pooling_type, bool global_pooling, bool adaptive, str padding_algorithm) + output : Tensor(x_grad) + infer_meta : + func : PoolGradInferMeta + kernel : + func : pool2d_grad + +- backward_api : pool3d_grad + forward : pool3d(Tensor x, int[] kernel_size, int[] strides, int[] paddings, bool ceil_mode, bool exclusive, str data_format, str pooling_type, bool global_pooling, bool adaptive, str padding_algorithm) -> Tensor(out) + args : (Tensor x, Tensor out, Tensor out_grad, int[] kernel_size, int[] strides, int[] paddings, bool ceil_mode, bool exclusive, str data_format, str pooling_type, bool global_pooling, bool adaptive, str padding_algorithm) + output : Tensor(x_grad) + infer_meta : + func : PoolGradInferMeta + kernel : + func : pool3d_grad + - backward_api : prelu_grad forward : prelu(Tensor x, Tensor alpha, str data_format, str mode) -> Tensor(out) args : (Tensor x, Tensor alpha, Tensor out_grad, str data_format, str mode) @@ -806,6 +890,17 @@ backend: out_grad layout: out_grad +- backward_api : roll_grad + forward : roll(Tensor x, IntArray shifts, int64_t[] axis) -> Tensor(out) + args : (Tensor x, Tensor out_grad, IntArray shifts, int64_t[] axis) + output : Tensor(x_grad) + infer_meta : + func : UnchangedInferMeta + param : [x] + kernel : + func : roll_grad + data_type : x + - backward_api : round_grad forward : round(Tensor x) -> Tensor(out) args : (Tensor out_grad) @@ -1079,6 +1174,16 @@ kernel : func : transpose_grad +- backward_api : tril_triu_grad + forward : tril_triu(Tensor x, int diagonal, bool lower) -> Tensor(out) + args : (Tensor out_grad, int diagonal, bool lower) + output : Tensor(x_grad) + infer_meta : + func : UnchangedInferMeta + param : [out_grad] + kernel : + func : tril_triu_grad + - backward_api : trunc_grad forward : trunc (Tensor x) -> Tensor(out) args : (Tensor out_grad) From 7315fb2d310eeee418be456105e9a502b30e0728 Mon Sep 17 00:00:00 2001 From: Weilong Wu Date: Sun, 3 Apr 2022 11:45:30 +0800 Subject: [PATCH 39/93] [Eager] Support adamax, fill_diagonal, fill_diagonal_tensor_, to_list, ... in eager mode (#41117) * Update ResNet test cases * [Eager] Support uva, adamax, fill_diagonal_, to_list and so on. 
* Fix CI * Updated CUDA defined statement * Fix CI * Update headers, Fix CI * Remove useless setting * Updated func name to Fix windows-CI * Remove tensor uva related codes * Remove uva related code * recover original test --- paddle/fluid/pybind/eager_functions.cc | 5 +- python/paddle/fluid/framework.py | 6 +- .../fluid/tests/unittests/test_Tensor_type.py | 22 +++++++- .../fluid/tests/unittests/test_adamax_api.py | 22 +++++++- .../unittests/test_tensor_fill_diagonal_.py | 43 +++++++++++++-- .../test_tensor_fill_diagonal_tensor_.py | 36 ++++++++++-- .../tests/unittests/test_tensor_to_list.py | 8 ++- python/paddle/optimizer/adamax.py | 55 +++++++++++-------- 8 files changed, 154 insertions(+), 43 deletions(-) diff --git a/paddle/fluid/pybind/eager_functions.cc b/paddle/fluid/pybind/eager_functions.cc index 7a6705e63b420..0c6707748ef5a 100644 --- a/paddle/fluid/pybind/eager_functions.cc +++ b/paddle/fluid/pybind/eager_functions.cc @@ -26,6 +26,7 @@ limitations under the License. */ #include "paddle/fluid/framework/convert_utils.h" #include "paddle/fluid/framework/custom_operator.h" #include "paddle/fluid/framework/op_meta_info_helper.h" +#include "paddle/fluid/framework/python_headers.h" #include "paddle/fluid/memory/allocation/allocator.h" #include "paddle/fluid/memory/memcpy.h" #include "paddle/fluid/platform/device/gpu/gpu_info.h" @@ -35,6 +36,7 @@ limitations under the License. */ #include "paddle/fluid/pybind/eager.h" #include "paddle/fluid/pybind/eager_utils.h" #include "paddle/fluid/pybind/exception.h" +#include "paddle/fluid/pybind/tensor_py.h" #include "paddle/phi/api/ext/op_meta_info.h" #include "paddle/phi/api/lib/utils/allocator.h" #include "paddle/phi/api/lib/utils/storage.h" @@ -771,6 +773,7 @@ static PyObject* eager_api_async_write(PyObject* self, PyObject* args, EAGER_CATCH_AND_THROW_RETURN_NULL } #endif + PyMethodDef variable_functions[] = { // TODO(jiabin): Remove scale when we have final state tests {"scale", (PyCFunction)(void (*)(void))eager_api_scale, @@ -794,13 +797,13 @@ PyMethodDef variable_functions[] = { {"sparse_csr_tensor", (PyCFunction)(void (*)(void))eager_api_sparse_csr_tensor, METH_VARARGS | METH_KEYWORDS, NULL}, +/**sparse functions**/ #if defined(PADDLE_WITH_CUDA) {"async_read", (PyCFunction)(void (*)(void))eager_api_async_read, METH_VARARGS | METH_KEYWORDS, NULL}, {"async_write", (PyCFunction)(void (*)(void))eager_api_async_write, METH_VARARGS | METH_KEYWORDS, NULL}, #endif - /**sparse functions**/ {NULL, NULL, 0, NULL}}; void BindFunctions(PyObject* module) { diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index b8ed2716fc7d5..dc1f82d235e31 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -173,9 +173,13 @@ def _test_eager_guard(place=None): monkey_patch_math_varbase() # Ugly setting - from paddle.tensor.manipulation import fill_, zero_ + from paddle.tensor.manipulation import fill_, zero_, fill_diagonal_, fill_diagonal_tensor_, tolist setattr(core.eager.Tensor, 'fill_', fill_) setattr(core.eager.Tensor, 'zero_', zero_) + setattr(core.eager.Tensor, 'fill_diagonal_', fill_diagonal_) + setattr(core.eager.Tensor, 'fill_diagonal_tensor_', + fill_diagonal_tensor_) + setattr(core.eager.Tensor, 'tolist', tolist) _already_patch_eager_tensor = True try: diff --git a/python/paddle/fluid/tests/unittests/test_Tensor_type.py b/python/paddle/fluid/tests/unittests/test_Tensor_type.py index f1427d29782b9..c40981c073724 100644 --- a/python/paddle/fluid/tests/unittests/test_Tensor_type.py +++ 
b/python/paddle/fluid/tests/unittests/test_Tensor_type.py @@ -18,10 +18,11 @@ import numpy as np import paddle import paddle.fluid.core as core +from paddle.fluid.framework import _test_eager_guard class TensorTypeTest(unittest.TestCase): - def test_type_totensor(self): + def func_type_totensor(self): paddle.disable_static() inx = np.array([1, 2]) tensorx = paddle.to_tensor(inx) @@ -29,7 +30,12 @@ def test_type_totensor(self): expectx = "" self.assertEqual((typex_str == expectx), True) - def test_type_Tensor(self): + def test_type_totensor(self): + with _test_eager_guard(): + self.func_type_totensor() + self.func_type_totensor() + + def func_type_Tensor(self): paddle.disable_static() inx = np.array([1, 2]) tensorx = paddle.Tensor(inx) @@ -43,7 +49,12 @@ def test_type_Tensor(self): expectx = "" self.assertEqual((typex_str == expectx), True) - def test_type_core(self): + def test_type_Tensor(self): + with _test_eager_guard(): + self.func_type_Tensor() + self.func_type_Tensor() + + def func_type_core(self): paddle.disable_static() inx = np.array([1, 2]) tensorx = core.VarBase(inx) @@ -56,6 +67,11 @@ def test_type_core(self): expectx = "" self.assertEqual((typex_str == expectx), True) + def test_type_core(self): + with _test_eager_guard(): + pass + self.func_type_core() + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_adamax_api.py b/python/paddle/fluid/tests/unittests/test_adamax_api.py index 57cb9d3cb5f7d..1698ac90a9f2d 100644 --- a/python/paddle/fluid/tests/unittests/test_adamax_api.py +++ b/python/paddle/fluid/tests/unittests/test_adamax_api.py @@ -19,10 +19,11 @@ from op_test import OpTest import paddle import paddle.fluid as fluid +from paddle.fluid.framework import _test_eager_guard class TestAdamaxAPI(unittest.TestCase): - def test_adamax_api_dygraph(self): + def func_adamax_api_dygraph(self): paddle.disable_static() value = np.arange(26).reshape(2, 13).astype("float32") a = paddle.to_tensor(value) @@ -36,7 +37,12 @@ def test_adamax_api_dygraph(self): adam.step() adam.clear_gradients() - def test_adamax_api(self): + def test_adamax_api_dygraph(self): + with _test_eager_guard(): + self.func_adamax_api_dygraph() + self.func_adamax_api_dygraph() + + def func_adamax_api(self): paddle.enable_static() place = fluid.CPUPlace() shape = [2, 3, 8, 8] @@ -63,9 +69,14 @@ def test_adamax_api(self): rets = exe.run(train_prog, feed={"data": data_np}, fetch_list=[loss]) assert rets[0] is not None + def test_adamax_api(self): + with _test_eager_guard(): + self.func_adamax_api() + self.func_adamax_api() + class TestAdamaxAPIGroup(TestAdamaxAPI): - def test_adamax_api_dygraph(self): + def func_adamax_api_dygraph(self): paddle.disable_static() value = np.arange(26).reshape(2, 13).astype("float32") a = paddle.to_tensor(value) @@ -89,6 +100,11 @@ def test_adamax_api_dygraph(self): adam.step() adam.clear_gradients() + def test_adamax_api_dygraph(self): + with _test_eager_guard(): + self.func_adamax_api_dygraph() + self.func_adamax_api_dygraph() + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_tensor_fill_diagonal_.py b/python/paddle/fluid/tests/unittests/test_tensor_fill_diagonal_.py index 3beb6a537eca0..ca0c97adedb94 100644 --- a/python/paddle/fluid/tests/unittests/test_tensor_fill_diagonal_.py +++ b/python/paddle/fluid/tests/unittests/test_tensor_fill_diagonal_.py @@ -17,10 +17,11 @@ import numpy as np import six import paddle +from paddle.fluid.framework import _test_eager_guard class 
TensorFillDiagonal_Test(unittest.TestCase): - def test_dim2_normal(self): + def func_dim2_normal(self): expected_np = np.array( [[1, 2, 2], [2, 1, 2], [2, 2, 1]]).astype('float32') expected_grad = np.array( @@ -50,7 +51,12 @@ def test_dim2_normal(self): (y.grad.numpy().astype('float32') == expected_grad).all(), True) - def test_offset(self): + def test_dim2_normal(self): + with _test_eager_guard(): + self.func_dim2_normal() + self.func_dim2_normal() + + def func_offset(self): expected_np = np.array( [[2, 2, 1], [2, 2, 2], [2, 2, 2]]).astype('float32') expected_grad = np.array( @@ -80,7 +86,12 @@ def test_offset(self): (y.grad.numpy().astype('float32') == expected_grad).all(), True) - def test_bool(self): + def test_offset(self): + with _test_eager_guard(): + self.func_offset() + self.func_offset() + + def func_bool(self): expected_np = np.array( [[False, True, True], [True, False, True], [True, True, False]]) @@ -101,7 +112,12 @@ def test_bool(self): self.assertEqual((x.numpy() == expected_np).all(), True) - def test_dim2_unnormal_wrap(self): + def test_bool(self): + with _test_eager_guard(): + self.func_bool() + self.func_bool() + + def func_dim2_unnormal_wrap(self): expected_np = np.array([[1, 2, 2], [2, 1, 2], [2, 2, 1], [2, 2, 2], [1, 2, 2], [2, 1, 2], [2, 2, 1]]).astype('float32') @@ -133,7 +149,12 @@ def test_dim2_unnormal_wrap(self): (y.grad.numpy().astype('float32') == expected_grad).all(), True) - def test_dim2_unnormal_unwrap(self): + def test_dim2_unnormal_wrap(self): + with _test_eager_guard(): + self.func_dim2_unnormal_wrap() + self.func_dim2_unnormal_wrap() + + def func_dim2_unnormal_unwrap(self): expected_np = np.array([[1, 2, 2], [2, 1, 2], [2, 2, 1], [2, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2]]).astype('float32') @@ -165,7 +186,12 @@ def test_dim2_unnormal_unwrap(self): (y.grad.numpy().astype('float32') == expected_grad).all(), True) - def test_dim_larger2_normal(self): + def test_dim2_unnormal_unwrap(self): + with _test_eager_guard(): + self.func_dim2_unnormal_unwrap() + self.func_dim2_unnormal_unwrap() + + def func_dim_larger2_normal(self): expected_np = np.array([[[1, 2, 2], [2, 2, 2], [2, 2, 2]], [[2, 2, 2], [ 2, 1, 2 ], [2, 2, 2]], [[2, 2, 2], [2, 2, 2], [2, 2, 1]]]).astype('float32') @@ -198,6 +224,11 @@ def test_dim_larger2_normal(self): (y.grad.numpy().astype('float32') == expected_grad).all(), True) + def test_dim_larger2_normal(self): + with _test_eager_guard(): + self.func_dim_larger2_normal() + self.func_dim_larger2_normal() + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_tensor_fill_diagonal_tensor_.py b/python/paddle/fluid/tests/unittests/test_tensor_fill_diagonal_tensor_.py index 2f37ccf219eb0..81ec1daa6691d 100644 --- a/python/paddle/fluid/tests/unittests/test_tensor_fill_diagonal_tensor_.py +++ b/python/paddle/fluid/tests/unittests/test_tensor_fill_diagonal_tensor_.py @@ -18,6 +18,7 @@ import numpy as np import six import paddle +from paddle.fluid.framework import _test_eager_guard class TensorFillDiagTensor_Test(unittest.TestCase): @@ -27,7 +28,7 @@ def setUp(self): if fluid.core.is_compiled_with_cuda(): self.places.append(fluid.CUDAPlace(0)) - def test_dim2(self): + def func_dim2(self): expected_np = np.array( [[1, 2, 2], [2, 1, 2], [2, 2, 1], [2, 2, 2]]).astype('float32') expected_grad = np.array( @@ -54,7 +55,12 @@ def test_dim2(self): (y.grad.numpy().astype('float32') == expected_grad).all(), True) - def test_dim2_offset_1(self): + def test_dim2(self): + with _test_eager_guard(): + self.func_dim2() + 
self.func_dim2() + + def func_dim2_offset_1(self): expected_np = np.array( [[2, 2, 2], [1, 2, 2], [2, 1, 2], [2, 2, 1]]).astype('float32') expected_grad = np.array( @@ -81,7 +87,12 @@ def test_dim2_offset_1(self): (y.grad.numpy().astype('float32') == expected_grad).all(), True) - def test_dim2_offset1(self): + def test_dim2_offset_1(self): + with _test_eager_guard(): + self.func_dim2_offset_1() + self.func_dim2_offset_1() + + def func_dim2_offset1(self): expected_np = np.array( [[2, 1, 2], [2, 2, 1], [2, 2, 2], [2, 2, 2]]).astype('float32') expected_grad = np.array( @@ -108,7 +119,12 @@ def test_dim2_offset1(self): (y.grad.numpy().astype('float32') == expected_grad).all(), True) - def test_dim4(self): + def test_dim2_offset1(self): + with _test_eager_guard(): + self.func_dim2_offset1() + self.func_dim2_offset1() + + def func_dim4(self): expected_np = np.array( [[[[0, 3], [2, 2], [2, 2]], [[2, 2], [1, 4], [2, 2]], [[2, 2], [2, 2], [2, 5]], [[2, 2], [2, 2], [2, 2]]], @@ -144,7 +160,12 @@ def test_dim4(self): (y.grad.numpy().astype('float32') == expected_grad).all(), True) - def test_largedim(self): + def test_func_dim4(self): + with _test_eager_guard(): + self.func_dim4() + self.func_dim4() + + def func_largedim(self): #large dim only test on gpu because the cpu version is too slow for ci test, and the memory is limited if len(self.places) > 1: bsdim = 1024 @@ -168,6 +189,11 @@ def test_largedim(self): self.assertEqual((y == expected_pred).all(), True) self.assertEqual((y.grad == expected_grad).all(), True) + def test_largedim(self): + with _test_eager_guard(): + self.func_largedim() + self.func_largedim() + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_tensor_to_list.py b/python/paddle/fluid/tests/unittests/test_tensor_to_list.py index 73b91297e6fd6..a78113030ed53 100644 --- a/python/paddle/fluid/tests/unittests/test_tensor_to_list.py +++ b/python/paddle/fluid/tests/unittests/test_tensor_to_list.py @@ -17,13 +17,14 @@ import numpy as np import six import paddle +from paddle.fluid.framework import _test_eager_guard class TensorToListTest(unittest.TestCase): def setUp(self): self.shape = [11, 25, 32, 43] - def test_tensor_tolist(self): + def func_tensor_tolist(self): places = [fluid.CPUPlace()] if fluid.core.is_compiled_with_cuda(): places.append(fluid.CUDAPlace(0)) @@ -39,6 +40,11 @@ def test_tensor_tolist(self): self.assertEqual(tensorlist, expectlist) + def test_tensor_tolist(self): + with _test_eager_guard(): + self.func_tensor_tolist() + self.func_tensor_tolist() + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/optimizer/adamax.py b/python/paddle/optimizer/adamax.py index de70e2e72a9c6..4c4a85559c0d9 100644 --- a/python/paddle/optimizer/adamax.py +++ b/python/paddle/optimizer/adamax.py @@ -16,6 +16,7 @@ from ..fluid import core from ..fluid import framework from ..fluid.framework import Variable, name_scope +from paddle import _C_ops __all__ = [] @@ -190,30 +191,38 @@ def _append_optimize_op(self, block, param_and_grad): param_and_grad[0]) beta1_pow_acc = self._get_accumulator(self._beta1_pow_acc_str, param_and_grad[0]) - # create the adamax optimize op - adamax_op = block.append_op( - type=self.type, - inputs={ - "Param": param_and_grad[0], - "Grad": param_and_grad[1], - "LearningRate": self._create_param_lr(param_and_grad), - "Moment": moment, - "InfNorm": inf_norm, - "Beta1Pow": beta1_pow_acc - }, - outputs={ - "ParamOut": param_and_grad[0], - "MomentOut": moment, - "InfNormOut": inf_norm - }, - attrs={ - 
"beta1": self._beta1, - "beta2": self._beta2, - "epsilon": self._epsilon - }, - stop_gradient=True) - return adamax_op + if framework._non_static_mode(): + _C_ops.adamax(param_and_grad[0], param_and_grad[1], + self._create_param_lr(param_and_grad), moment, + inf_norm, beta1_pow_acc, param_and_grad[0], moment, + inf_norm, "beta1", self._beta1, "beta2", self._beta2, + "epsilon", self._epsilon) + else: + # create the adamax optimize op + adamax_op = block.append_op( + type=self.type, + inputs={ + "Param": param_and_grad[0], + "Grad": param_and_grad[1], + "LearningRate": self._create_param_lr(param_and_grad), + "Moment": moment, + "InfNorm": inf_norm, + "Beta1Pow": beta1_pow_acc + }, + outputs={ + "ParamOut": param_and_grad[0], + "MomentOut": moment, + "InfNormOut": inf_norm + }, + attrs={ + "beta1": self._beta1, + "beta2": self._beta2, + "epsilon": self._epsilon + }, + stop_gradient=True) + + return adamax_op def _finish_update(self, block, parameters_and_grads): """Update Beta1 Power accumulator From fd1ecfc50a99886d47263700b8d3ff439f3bb34d Mon Sep 17 00:00:00 2001 From: zyfncg Date: Sun, 3 Apr 2022 11:50:17 +0800 Subject: [PATCH 40/93] Add randperm and range yaml (#41265) * add randperm and range yaml * add eager test for randperm --- paddle/fluid/operators/range_op.cc | 2 +- paddle/phi/infermeta/nullary.cc | 5 + paddle/phi/infermeta/nullary.h | 2 + paddle/phi/infermeta/ternary.cc | 100 +++++++++--------- paddle/phi/infermeta/ternary.h | 10 +- .../{range_kernel.h => arange_kernel.h} | 10 +- .../cpu/{range_kernel.cc => arange_kernel.cc} | 14 +-- .../gpu/{range_kernel.cu => arange_kernel.cu} | 14 +-- paddle/phi/ops/compat/range_sig.cc | 17 +++ python/paddle/fluid/layers/tensor.py | 8 +- .../fluid/tests/unittests/test_randperm_op.py | 19 ++++ .../fluid/tests/unittests/test_range.py | 13 ++- python/paddle/tensor/random.py | 7 +- python/paddle/utils/code_gen/api.yaml | 26 +++++ 14 files changed, 167 insertions(+), 80 deletions(-) rename paddle/phi/kernels/{range_kernel.h => arange_kernel.h} (78%) rename paddle/phi/kernels/cpu/{range_kernel.cc => arange_kernel.cc} (78%) rename paddle/phi/kernels/gpu/{range_kernel.cu => arange_kernel.cu} (86%) create mode 100644 paddle/phi/ops/compat/range_sig.cc diff --git a/paddle/fluid/operators/range_op.cc b/paddle/fluid/operators/range_op.cc index ddfbdbace054d..80fdb2ce6c345 100644 --- a/paddle/fluid/operators/range_op.cc +++ b/paddle/fluid/operators/range_op.cc @@ -61,6 +61,6 @@ class RangeOpMaker : public framework::OpProtoAndCheckerMaker { namespace ops = paddle::operators; DECLARE_INFER_SHAPE_FUNCTOR(range, RangeInferMetaFunctor, - PD_INFER_META(phi::RangeInferMeta)); + PD_INFER_META(phi::ArangeInferMeta)); REGISTER_OP_WITHOUT_GRADIENT(range, ops::RangeOp, ops::RangeOpMaker, RangeInferMetaFunctor); diff --git a/paddle/phi/infermeta/nullary.cc b/paddle/phi/infermeta/nullary.cc index 4a11d24a9868b..6a05e1b4d7f30 100644 --- a/paddle/phi/infermeta/nullary.cc +++ b/paddle/phi/infermeta/nullary.cc @@ -58,6 +58,11 @@ void GaussianRandomInferMeta(const IntArray& shape, out->set_layout(DataLayout::NCHW); } +void RandpermInferMeta(int n, DataType dtype, MetaTensor* out) { + out->set_dims(phi::make_ddim({n})); + out->set_dtype(dtype); +} + void TruncatedGaussianRandomInferMeta(const std::vector& shape, float mean, float std, diff --git a/paddle/phi/infermeta/nullary.h b/paddle/phi/infermeta/nullary.h index 4c9eb0b62a74e..ada44658a2c25 100644 --- a/paddle/phi/infermeta/nullary.h +++ b/paddle/phi/infermeta/nullary.h @@ -53,6 +53,8 @@ void 
GaussianRandomInferMeta(const IntArray& shape, DataType dtype, MetaTensor* out); +void RandpermInferMeta(int n, DataType dtype, MetaTensor* out); + void TruncatedGaussianRandomInferMeta(const std::vector& shape, float mean, float std, diff --git a/paddle/phi/infermeta/ternary.cc b/paddle/phi/infermeta/ternary.cc index 582dcb0137894..3e4aa7b4448e3 100644 --- a/paddle/phi/infermeta/ternary.cc +++ b/paddle/phi/infermeta/ternary.cc @@ -141,6 +141,56 @@ void AddmmInferMeta(const MetaTensor& input, out->set_dtype(input.dtype()); } +void ArangeInferMeta(const MetaTensor& start, + const MetaTensor& end, + const MetaTensor& step, + MetaTensor* out) { + auto start_dims = start.dims(); + auto end_dims = end.dims(); + auto step_dims = step.dims(); + PADDLE_ENFORCE_EQ( + start_dims.size(), + 1, + phi::errors::InvalidArgument( + "The dim of the shape of Input(Start) should be 1, but got %d", + start_dims.size())); + + PADDLE_ENFORCE_EQ(start_dims[0], + 1, + phi::errors::InvalidArgument( + "The first dim of the shape of Input(Start) should " + "be 1, but got %d", + start_dims[0])); + PADDLE_ENFORCE_EQ( + end_dims.size(), + 1, + phi::errors::InvalidArgument( + "The dim of the shape of Input(End) should be 1, but got %d", + end_dims.size())); + + PADDLE_ENFORCE_EQ( + end_dims[0], + 1, + phi::errors::InvalidArgument("The first dim of the shape of " + "Input(End) should be 1, but got %d", + end_dims[0])); + PADDLE_ENFORCE_EQ( + step_dims.size(), + 1, + phi::errors::InvalidArgument( + "The dim of the shape of Input(Step) should be 1, but got %d", + step_dims.size())); + + PADDLE_ENFORCE_EQ(step_dims[0], + 1, + phi::errors::InvalidArgument( + "The first dim of the shape of Input(Step) should " + "be 1, but got %d", + step_dims[0])); + out->set_dims({-1}); + out->set_dtype(start.dtype()); +} + void GraphSendRecvInferMeta(const MetaTensor& x, const MetaTensor& src_index, const MetaTensor& dst_index, @@ -345,56 +395,6 @@ void PutAlongAxisInferMeta(const MetaTensor& x, out->set_dtype(x.dtype()); } -void RangeInferMeta(const MetaTensor& start, - const MetaTensor& end, - const MetaTensor& step, - MetaTensor* out) { - auto start_dims = start.dims(); - auto end_dims = end.dims(); - auto step_dims = step.dims(); - PADDLE_ENFORCE_EQ( - start_dims.size(), - 1, - phi::errors::InvalidArgument( - "The dim of the shape of Input(Start) should be 1, but got %d", - start_dims.size())); - - PADDLE_ENFORCE_EQ(start_dims[0], - 1, - phi::errors::InvalidArgument( - "The first dim of the shape of Input(Start) should " - "be 1, but got %d", - start_dims[0])); - PADDLE_ENFORCE_EQ( - end_dims.size(), - 1, - phi::errors::InvalidArgument( - "The dim of the shape of Input(End) should be 1, but got %d", - end_dims.size())); - - PADDLE_ENFORCE_EQ( - end_dims[0], - 1, - phi::errors::InvalidArgument("The first dim of the shape of " - "Input(End) should be 1, but got %d", - end_dims[0])); - PADDLE_ENFORCE_EQ( - step_dims.size(), - 1, - phi::errors::InvalidArgument( - "The dim of the shape of Input(Step) should be 1, but got %d", - step_dims.size())); - - PADDLE_ENFORCE_EQ(step_dims[0], - 1, - phi::errors::InvalidArgument( - "The first dim of the shape of Input(Step) should " - "be 1, but got %d", - step_dims[0])); - out->set_dims({-1}); - out->set_dtype(start.dtype()); -} - void RoiAlignInferMeta(const MetaTensor& x, const MetaTensor& boxes, paddle::optional boxes_num, diff --git a/paddle/phi/infermeta/ternary.h b/paddle/phi/infermeta/ternary.h index c18dde42f1ed2..00e49811688ac 100644 --- a/paddle/phi/infermeta/ternary.h +++ 
b/paddle/phi/infermeta/ternary.h @@ -47,6 +47,11 @@ void AddmmInferMeta(const MetaTensor& input, float beta, MetaTensor* out); +void ArangeInferMeta(const MetaTensor& start, + const MetaTensor& end, + const MetaTensor& step, + MetaTensor* out); + void GraphSendRecvInferMeta(const MetaTensor& x, const MetaTensor& src_index, const MetaTensor& dst_index, @@ -81,11 +86,6 @@ void PutAlongAxisInferMeta(const MetaTensor& x, const std::string& reduce, MetaTensor* out); -void RangeInferMeta(const MetaTensor& start, - const MetaTensor& end, - const MetaTensor& step, - MetaTensor* out); - void RoiAlignInferMeta(const MetaTensor& x, const MetaTensor& boxes, paddle::optional boxes_num, diff --git a/paddle/phi/kernels/range_kernel.h b/paddle/phi/kernels/arange_kernel.h similarity index 78% rename from paddle/phi/kernels/range_kernel.h rename to paddle/phi/kernels/arange_kernel.h index c76308193ae5e..be60824ac2be2 100644 --- a/paddle/phi/kernels/range_kernel.h +++ b/paddle/phi/kernels/arange_kernel.h @@ -19,10 +19,10 @@ namespace phi { template -void RangeKernel(const Context& dev_ctx, - const DenseTensor& start, - const DenseTensor& end, - const DenseTensor& step, - DenseTensor* out); +void ArangeKernel(const Context& dev_ctx, + const DenseTensor& start, + const DenseTensor& end, + const DenseTensor& step, + DenseTensor* out); } // namespace phi diff --git a/paddle/phi/kernels/cpu/range_kernel.cc b/paddle/phi/kernels/cpu/arange_kernel.cc similarity index 78% rename from paddle/phi/kernels/cpu/range_kernel.cc rename to paddle/phi/kernels/cpu/arange_kernel.cc index 8731696f61760..478251b0d3b6a 100644 --- a/paddle/phi/kernels/cpu/range_kernel.cc +++ b/paddle/phi/kernels/cpu/arange_kernel.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/phi/kernels/range_kernel.h" +#include "paddle/phi/kernels/arange_kernel.h" #include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/funcs/range_function.h" @@ -20,11 +20,11 @@ limitations under the License. */ namespace phi { template -void RangeKernel(const Context& dev_ctx, - const DenseTensor& start, - const DenseTensor& end, - const DenseTensor& step, - DenseTensor* out) { +void ArangeKernel(const Context& dev_ctx, + const DenseTensor& start, + const DenseTensor& end, + const DenseTensor& step, + DenseTensor* out) { T start_value = start.data()[0]; T end_value = end.data()[0]; T step_value = step.data()[0]; @@ -42,4 +42,4 @@ void RangeKernel(const Context& dev_ctx, } // namespace phi PD_REGISTER_KERNEL( - range, CPU, ALL_LAYOUT, phi::RangeKernel, float, double, int, int64_t) {} + arange, CPU, ALL_LAYOUT, phi::ArangeKernel, float, double, int, int64_t) {} diff --git a/paddle/phi/kernels/gpu/range_kernel.cu b/paddle/phi/kernels/gpu/arange_kernel.cu similarity index 86% rename from paddle/phi/kernels/gpu/range_kernel.cu rename to paddle/phi/kernels/gpu/arange_kernel.cu index d9a98f06d0795..916f6aa5537a6 100644 --- a/paddle/phi/kernels/gpu/range_kernel.cu +++ b/paddle/phi/kernels/gpu/arange_kernel.cu @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. 
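On the Python side, the renamed kernel is reached through the same three-branch dispatch this series applies everywhere: the new eager mode calls the generated final-state API, legacy dygraph calls the old fluid operator, and static graph falls through to op construction. A rough sketch of that control flow, simplified from the range()/randperm() changes shown below (tensor conversion of start/end/step and the static-graph branch are omitted):

    import paddle
    from paddle import _C_ops
    from paddle.fluid.framework import in_dygraph_mode, _in_legacy_dygraph

    def arange_dispatch_sketch(start, end, step, dtype):
        if in_dygraph_mode():
            # eager final-state path; ends up in the phi ArangeKernel above
            place = paddle.fluid.framework._current_expected_place()
            return _C_ops.final_state_arange(start, end, step, dtype, place)
        if _in_legacy_dygraph():
            # legacy dygraph path; runs the old range op
            return _C_ops.range(start, end, step)
        # otherwise: append a range op to the static-graph block (omitted)
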
-#include "paddle/phi/kernels/range_kernel.h" +#include "paddle/phi/kernels/arange_kernel.h" #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/core/kernel_registry.h" @@ -40,11 +40,11 @@ __global__ void Range(T start, T step, int64_t size, T* out) { } template -void RangeKernel(const Context& dev_ctx, - const DenseTensor& start, - const DenseTensor& end, - const DenseTensor& step, - DenseTensor* out) { +void ArangeKernel(const Context& dev_ctx, + const DenseTensor& start, + const DenseTensor& end, + const DenseTensor& step, + DenseTensor* out) { T start_value = GetValue(dev_ctx, start); T end_value = GetValue(dev_ctx, end); T step_value = GetValue(dev_ctx, step); @@ -63,7 +63,7 @@ void RangeKernel(const Context& dev_ctx, } // namespace phi PD_REGISTER_KERNEL( - range, GPU, ALL_LAYOUT, phi::RangeKernel, float, double, int64_t, int) { + arange, GPU, ALL_LAYOUT, phi::ArangeKernel, float, double, int64_t, int) { kernel->InputAt(0).SetBackend(phi::Backend::CPU); kernel->InputAt(1).SetBackend(phi::Backend::CPU); kernel->InputAt(2).SetBackend(phi::Backend::CPU); diff --git a/paddle/phi/ops/compat/range_sig.cc b/paddle/phi/ops/compat/range_sig.cc new file mode 100644 index 0000000000000..d48898bd8487c --- /dev/null +++ b/paddle/phi/ops/compat/range_sig.cc @@ -0,0 +1,17 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/phi/core/compat/op_utils.h" + +PD_REGISTER_BASE_KERNEL_NAME(range, arange); diff --git a/python/paddle/fluid/layers/tensor.py b/python/paddle/fluid/layers/tensor.py index ff7008fddd47d..81a60bf517522 100644 --- a/python/paddle/fluid/layers/tensor.py +++ b/python/paddle/fluid/layers/tensor.py @@ -21,7 +21,7 @@ from ..layer_helper import LayerHelper from ..param_attr import ParamAttr from ..initializer import Initializer -from ..framework import convert_np_dtype_to_dtype_, _non_static_mode, _varbase_creator, device_guard, _in_legacy_dygraph, in_dygraph_mode +from ..framework import _current_expected_place, convert_np_dtype_to_dtype_, _non_static_mode, _varbase_creator, device_guard, _in_legacy_dygraph, in_dygraph_mode from ..framework import Variable from ..initializer import Constant from ..core import VarDesc @@ -1433,6 +1433,10 @@ def range(start, end, step, dtype, name=None): if not isinstance(dtype, core.VarDesc.VarType): dtype = convert_np_dtype_to_dtype_(dtype) + if in_dygraph_mode(): + return _C_ops.final_state_arange(start, end, step, dtype, + _current_expected_place()) + if not isinstance(start, Variable): with device_guard("cpu"): start = fill_constant([1], dtype, start, force_cpu=True) @@ -1451,7 +1455,7 @@ def range(start, end, step, dtype, name=None): elif step.dtype != dtype: step = cast(step, dtype) - if _non_static_mode(): + if _in_legacy_dygraph(): out = _C_ops.range(start, end, step) out.stop_gradient = True return out diff --git a/python/paddle/fluid/tests/unittests/test_randperm_op.py b/python/paddle/fluid/tests/unittests/test_randperm_op.py index 2380ccb14aaee..5c9ab36fa34bc 100644 --- a/python/paddle/fluid/tests/unittests/test_randperm_op.py +++ b/python/paddle/fluid/tests/unittests/test_randperm_op.py @@ -18,6 +18,7 @@ import paddle import paddle.fluid.core as core from paddle.static import program_guard, Program +from paddle.fluid.framework import _test_eager_guard import os @@ -50,6 +51,7 @@ class TestRandpermOp(OpTest): def setUp(self): self.op_type = "randperm" + self.python_api = paddle.randperm self.n = 200 self.dtype = "int64" @@ -72,6 +74,10 @@ def verify_output(self, outs): self.assertTrue( check_randperm_out(self.n, out_np), msg=error_msg(out_np)) + def test_eager(self): + with _test_eager_guard(): + self.test_check_output() + class TestRandpermOpN(TestRandpermOp): def init_attrs(self): @@ -130,6 +136,19 @@ def test_out(self): paddle.enable_static() +class TestRandpermEager(unittest.TestCase): + def test_out(self): + paddle.disable_static() + n = 10 + with _test_eager_guard(): + for dtype in ['int32', np.int64, 'float32', 'float64']: + data_p = paddle.randperm(n, dtype) + data_np = data_p.numpy() + self.assertTrue( + check_randperm_out(n, data_np), msg=error_msg(data_np)) + paddle.enable_static() + + class TestRandomValue(unittest.TestCase): def test_fixed_random_number(self): # Test GPU Fixed random number, which is generated by 'curandStatePhilox4_32_10_t' diff --git a/python/paddle/fluid/tests/unittests/test_range.py b/python/paddle/fluid/tests/unittests/test_range.py index f129ae78cbf7e..e19c1b227f531 100644 --- a/python/paddle/fluid/tests/unittests/test_range.py +++ b/python/paddle/fluid/tests/unittests/test_range.py @@ -14,9 +14,15 @@ from __future__ import print_function +import paddle import unittest import numpy as np from op_test import OpTest +from functools import partial + + +def arange_wrapper(start, end, step, dtype=None): + return paddle.arange(start, end, step, dtype) class TestRangeOp(OpTest): @@ -36,33 +42,38 @@ def 
setUp(self): def init_config(self): self.dtype = np.float32 + self.python_api = partial(arange_wrapper, dtype=self.dtype) self.case = (0, 1, 0.2) def test_check_output(self): - self.check_output() + self.check_output(check_eager=True) class TestFloatRangeOpCase0(TestRangeOp): def init_config(self): self.dtype = np.float32 + self.python_api = partial(arange_wrapper, dtype=self.dtype) self.case = (0, 5, 1) class TestInt32RangeOpCase0(TestRangeOp): def init_config(self): self.dtype = np.int32 + self.python_api = partial(arange_wrapper, dtype=self.dtype) self.case = (0, 5, 2) class TestInt32RangeOpCase1(TestRangeOp): def init_config(self): self.dtype = np.int32 + self.python_api = partial(arange_wrapper, dtype=self.dtype) self.case = (10, 1, -2) class TestInt32RangeOpCase2(TestRangeOp): def init_config(self): self.dtype = np.int32 + self.python_api = partial(arange_wrapper, dtype=self.dtype) self.case = (-1, -10, -2) diff --git a/python/paddle/tensor/random.py b/python/paddle/tensor/random.py index 1fa91ae148f60..20f4e73b2718a 100644 --- a/python/paddle/tensor/random.py +++ b/python/paddle/tensor/random.py @@ -22,7 +22,7 @@ import paddle from paddle import _C_ops from paddle.static import Variable -from paddle.fluid.framework import in_dygraph_mode, _in_legacy_dygraph +from ..fluid.framework import _in_legacy_dygraph, in_dygraph_mode __all__ = [] @@ -919,7 +919,10 @@ def randperm(n, dtype="int64", name=None): if not isinstance(dtype, core.VarDesc.VarType): dtype = convert_np_dtype_to_dtype_(dtype) - if paddle.in_dynamic_mode(): + if in_dygraph_mode(): + return _C_ops.final_state_randperm( + n, dtype, paddle.fluid.framework._current_expected_place()) + if _in_legacy_dygraph(): return _C_ops.randperm('n', n, 'seed', 0, 'dtype', dtype) if n < 1: diff --git a/python/paddle/utils/code_gen/api.yaml b/python/paddle/utils/code_gen/api.yaml index b3bf1f7890400..0b855b0f967ba 100644 --- a/python/paddle/utils/code_gen/api.yaml +++ b/python/paddle/utils/code_gen/api.yaml @@ -97,6 +97,20 @@ kernel : func : any +- api : arange + args : (Tensor start, Tensor end, Tensor step, DataType dtype, Place place={}) + output : Tensor + infer_meta : + func : ArangeInferMeta + param : [start, end, step] + kernel : + func : arange + param : [start, end, step] + data_type : dtype + backend : place + data_transform : + support_trans_dtype : start, end, step + # arg_max - api : argmax args : (Tensor x, int64_t axis, bool keepdims, bool flatten, int dtype) @@ -1227,6 +1241,18 @@ data_type : x backward : put_along_axis_grad +- api : randperm + args : (int n, DataType dtype, Place place={}) + output : Tensor + infer_meta : + func : RandpermInferMeta + param : [n, dtype] + kernel : + func : randperm + param : [n, dtype] + data_type : dtype + backend : place + - api : reciprocal args : (Tensor x) output : Tensor From 61e60e683d8ab13388b15020c3d6fc78ef976ff2 Mon Sep 17 00:00:00 2001 From: Aurelius84 Date: Sun, 3 Apr 2022 11:59:54 +0800 Subject: [PATCH 41/93] [Eager]Fix 17 unittest and open check_eager=True (#41270) * [Eager]Enhance eager_trace_op logic to support Optimizer Op * fix AsDispensable * [Eager]Fix 17 unittest and open check_eager=True * remove print * fix unittests * fix op_testa * fix coverage CI failed * fix ci --- paddle/fluid/eager/grad_tensor_holder.cc | 5 ++-- paddle/fluid/pybind/op_function_generator.h | 1 + python/paddle/fluid/framework.py | 21 ++++++++++++++ .../paddle/fluid/tests/unittests/op_test.py | 6 +++- .../tests/unittests/test_bicubic_interp_op.py | 11 ++++++-- 
.../unittests/test_bicubic_interp_v2_op.py    |  14 ++++++++--
 .../unittests/test_bilinear_interp_op.py      |  23 +++++++++++----
 .../fluid/tests/unittests/test_crop_op.py     |   6 ++--
 .../tests/unittests/test_crop_tensor_op.py    |  10 ++++---
 .../unittests/test_decayed_adagrad_op.py      |   6 ++--
 .../fluid/tests/unittests/test_dpsgd_op.py    |   4 ++-
 .../fluid/tests/unittests/test_ftrl_op.py     |   4 ++-
 .../fluid/tests/unittests/test_mean_iou.py    |   7 ++++-
 .../tests/unittests/test_nearest_interp_op.py |  25 +++++++++++++----
 .../tests/unittests/test_prroi_pool_op.py     |  10 ++++---
 .../tests/unittests/test_smooth_l1_loss_op.py |  28 +++++++++++++------
 .../unittests/test_sparse_momentum_op.py      |   7 ++++-
 .../fluid/tests/unittests/test_stft_op.py     |   4 +--
 .../unittests/test_trilinear_interp_op.py     |  23 +++++++++++----
 .../unittests/test_trilinear_interp_v2_op.py  |  26 +++++++++++++----
 20 files changed, 187 insertions(+), 54 deletions(-)

diff --git a/paddle/fluid/eager/grad_tensor_holder.cc b/paddle/fluid/eager/grad_tensor_holder.cc
index b15d9b892f810..2dacb588ff847 100644
--- a/paddle/fluid/eager/grad_tensor_holder.cc
+++ b/paddle/fluid/eager/grad_tensor_holder.cc
@@ -64,8 +64,9 @@ void GradTensorHolder::CopyValueFromTensor(
   } else {
     // Create new tensor->impl and fill it with 1.0
     if (t.defined()) {
-      // Fill 1.0
-      buffer_[slot_id][rank] = paddle::experimental::ones_like(t, t.dtype());
+      // Fill 1.0; use full to support complex, since ones_like doesn't support it.
+      buffer_[slot_id][rank] =
+          paddle::experimental::full(t.shape(), 1, t.dtype(), t.inner_place());
     }
   }
 }
diff --git a/paddle/fluid/pybind/op_function_generator.h b/paddle/fluid/pybind/op_function_generator.h
index 1e501a0c9e024..b8202fe8c51fd 100644
--- a/paddle/fluid/pybind/op_function_generator.h
+++ b/paddle/fluid/pybind/op_function_generator.h
@@ -52,6 +52,7 @@ std::map<std::string, std::vector<std::string>> op_ins_map = {
     {"fake_quantize_dequantize_moving_average_abs_max",
      {"X", "InScale", "InAccum", "InState"}},
     {"nll_loss", {"X", "Label", "Weight"}},
+    {"smooth_l1_loss", {"X", "Y", "InsideWeight", "OutsideWeight"}},
     {"bilinear_tensor_product", {"X", "Y", "Weight", "Bias"}},
     {"gather", {"X", "Index", "Axis"}},
     {"repeat_interleave", {"X", "RepeatsTensor"}},
diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py
index dc1f82d235e31..20c441f364145 100644
--- a/python/paddle/fluid/framework.py
+++ b/python/paddle/fluid/framework.py
@@ -106,14 +106,35 @@
 # to make sure in most case, we find new dygraph mode first with only one if statement.
 
 
+def _update_monkey_methods(is_eager):
+    """
+    Update monkey methods of VarBase or eager.Tensor while
+    switching between eager mode and legacy mode.
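+
+    Both method sets dispatch through paddle._C_ops, so the op bindings are
+    switched first and the Tensor methods are then re-patched to pick up the
+    new bindings, e.g.:
+
+        _enable_legacy_dygraph()    # rebinds to the legacy core ops
+        _disable_legacy_dygraph()   # rebinds to the eager final-state ops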
+ """ + from paddle import _C_ops + from .dygraph.varbase_patch_methods import monkey_patch_varbase + from .dygraph import monkey_patch_math_varbase + + assert isinstance(is_eager, bool) + if is_eager: + _C_ops.switch_to_eager_ops() + else: + _C_ops.switch_to_core_ops() + + monkey_patch_varbase() + monkey_patch_math_varbase() + + def _enable_legacy_dygraph(): global _in_eager_mode_ _in_eager_mode_ = False + _update_monkey_methods(is_eager=False) def _disable_legacy_dygraph(): global _in_eager_mode_ _in_eager_mode_ = True + _update_monkey_methods(is_eager=True) def _in_eager_without_dygraph_check(): diff --git a/python/paddle/fluid/tests/unittests/op_test.py b/python/paddle/fluid/tests/unittests/op_test.py index be883d243f795..60064340b198a 100644 --- a/python/paddle/fluid/tests/unittests/op_test.py +++ b/python/paddle/fluid/tests/unittests/op_test.py @@ -1937,6 +1937,9 @@ def check_grad_with_place(self, "Gradient Check On %s" % str(place)) if check_dygraph: + # ensure switch into legacy dygraph + g_enable_legacy_dygraph() + dygraph_grad = self._get_dygraph_grad( inputs_to_check, place, output_names, user_defined_grad_outputs, no_grad_set, False) @@ -1950,6 +1953,8 @@ def check_grad_with_place(self, self._assert_is_close(numeric_grads, dygraph_grad, inputs_to_check, max_relative_error, "Gradient Check On %s" % str(place)) + # ensure switch back eager dygraph + g_disable_legacy_dygraph() if check_eager: with fluid.dygraph.base.guard(place): @@ -2087,7 +2092,6 @@ def _get_dygraph_grad(self, inputs={"X": loss_sum}, outputs={"Out": loss}, attrs={'scale': 1.0 / float(len(avg_sum))}) - loss.backward() fetch_list_grad = [] diff --git a/python/paddle/fluid/tests/unittests/test_bicubic_interp_op.py b/python/paddle/fluid/tests/unittests/test_bicubic_interp_op.py index f3f3431c9fb3e..8d7dd0d81180e 100644 --- a/python/paddle/fluid/tests/unittests/test_bicubic_interp_op.py +++ b/python/paddle/fluid/tests/unittests/test_bicubic_interp_op.py @@ -127,6 +127,9 @@ def setUp(self): self.data_layout = 'NCHW' self.init_test_case() self.op_type = "bicubic_interp" + # NOTE(dev): some AsDispensible input is not used under imperative mode. + # Skip check_eager while found them in Inputs. 
+        self.check_eager = True
 
         input_np = np.random.random(self.input_shape).astype("float64")
 
         if self.data_layout == "NCHW":
@@ -149,8 +152,10 @@ def setUp(self):
         self.inputs = {'X': input_np}
         if self.out_size is not None:
             self.inputs['OutSize'] = self.out_size
+            self.check_eager = False
         if self.actual_shape is not None:
             self.inputs['OutSize'] = self.actual_shape
+            self.check_eager = False
 
         self.attrs = {
             'out_h': self.out_h,
@@ -163,10 +168,11 @@ def setUp(self):
         self.outputs = {'Out': output_np}
 
     def test_check_output(self):
-        self.check_output()
+        self.check_output(check_eager=self.check_eager)
 
     def test_check_grad(self):
-        self.check_grad(['X'], 'Out', in_place=True)
+        self.check_grad(
+            ['X'], 'Out', in_place=True, check_eager=self.check_eager)
 
     def init_test_case(self):
         self.interp_method = 'bicubic'
@@ -442,4 +448,5 @@ def test_outshape_and_scale():
 
 
 if __name__ == "__main__":
+    paddle.enable_static()
     unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/test_bicubic_interp_v2_op.py b/python/paddle/fluid/tests/unittests/test_bicubic_interp_v2_op.py
index 43e418addf2bf..d5c3aee2f4372 100644
--- a/python/paddle/fluid/tests/unittests/test_bicubic_interp_v2_op.py
+++ b/python/paddle/fluid/tests/unittests/test_bicubic_interp_v2_op.py
@@ -16,7 +16,7 @@
 
 import unittest
 import numpy as np
-from op_test import OpTest
+from op_test import OpTest, _in_eager_without_dygraph_check
 import paddle.fluid.core as core
 import paddle.fluid as fluid
 import paddle
@@ -135,6 +135,10 @@ def setUp(self):
         self.data_layout = 'NCHW'
         self.init_test_case()
         self.op_type = "bicubic_interp_v2"
+        # NOTE(dev): some AsDispensable inputs are not used under imperative mode.
+        # Skip check_eager when they are found in Inputs.
+        # TODO(dev): add self.python_api
+        self.check_eager = False
         input_np = np.random.random(self.input_shape).astype("float64")
         scale_h = 0
         scale_w = 0
@@ -166,8 +170,10 @@ def setUp(self):
         self.inputs = {'X': input_np}
         if self.out_size is not None:
             self.inputs['OutSize'] = self.out_size
+            self.check_eager = False
         if self.actual_shape is not None:
             self.inputs['OutSize'] = self.actual_shape
+            self.check_eager = False
 
         self.attrs = {
             'out_h': self.out_h,
@@ -186,10 +192,11 @@ def setUp(self):
         self.outputs = {'Out': output_np}
 
     def test_check_output(self):
-        self.check_output()
+        self.check_output(check_eager=self.check_eager)
 
     def test_check_grad(self):
-        self.check_grad(['X'], 'Out', in_place=True)
+        self.check_grad(
+            ['X'], 'Out', in_place=True, check_eager=self.check_eager)
 
     def init_test_case(self):
         self.interp_method = 'bicubic'
@@ -543,4 +550,5 @@ def test_input_shape_1():
 
 
 if __name__ == "__main__":
+    paddle.enable_static()
     unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/test_bilinear_interp_op.py b/python/paddle/fluid/tests/unittests/test_bilinear_interp_op.py
index 083b671c283a0..1817ef160c70a 100755
--- a/python/paddle/fluid/tests/unittests/test_bilinear_interp_op.py
+++ b/python/paddle/fluid/tests/unittests/test_bilinear_interp_op.py
@@ -102,6 +102,9 @@ def setUp(self):
         self.data_layout = 'NCHW'
         self.init_test_case()
         self.op_type = "bilinear_interp"
+        # NOTE(dev): some AsDispensable inputs are not used under imperative mode.
+        # Skip check_eager when they are found in Inputs.
+ self.check_eager = True input_np = np.random.random(self.input_shape).astype("float64") if self.data_layout == "NCHW": @@ -124,8 +127,10 @@ def setUp(self): self.inputs = {'X': input_np} if self.out_size is not None: self.inputs['OutSize'] = self.out_size + self.check_eager = False if self.actual_shape is not None: self.inputs['OutSize'] = self.actual_shape + self.check_eager = False self.attrs = { 'out_h': self.out_h, @@ -139,10 +144,11 @@ def setUp(self): self.outputs = {'Out': output_np} def test_check_output(self): - self.check_output() + self.check_output(check_eager=self.check_eager) def test_check_grad(self): - self.check_grad(['X'], 'Out', in_place=True) + self.check_grad( + ['X'], 'Out', in_place=True, check_eager=self.check_eager) def init_test_case(self): self.interp_method = 'bilinear' @@ -266,6 +272,7 @@ def setUp(self): self.actual_shape = None self.init_test_case() self.op_type = "bilinear_interp" + self.check_eager = True input_np = np.random.randint( low=0, high=256, size=self.input_shape).astype("uint8") @@ -282,6 +289,7 @@ def setUp(self): self.inputs = {'X': input_np} if self.out_size is not None: self.inputs['OutSize'] = self.out_size + self.check_eager = False self.attrs = { 'out_h': self.out_h, @@ -294,7 +302,8 @@ def setUp(self): self.outputs = {'Out': output_np} def test_check_output(self): - self.check_output_with_place(place=core.CPUPlace(), atol=1) + self.check_output_with_place( + place=core.CPUPlace(), atol=1, check_eager=self.check_eager) def init_test_case(self): self.interp_method = 'bilinear' @@ -397,6 +406,7 @@ def setUp(self): self.actual_shape = None self.init_test_case() self.op_type = "bilinear_interp" + self.check_eager = True self.shape_by_1Dtensor = False self.scale_by_1Dtensor = False self.attrs = { @@ -419,12 +429,14 @@ def setUp(self): if self.shape_by_1Dtensor: self.inputs['OutSize'] = self.out_size + self.check_eager = False elif self.out_size is not None: size_tensor = [] for index, ele in enumerate(self.out_size): size_tensor.append(("x" + str(index), np.ones( (1)).astype('int32') * ele)) self.inputs['SizeTensor'] = size_tensor + self.check_eager = False self.attrs['out_h'] = self.out_h self.attrs['out_w'] = self.out_w @@ -433,10 +445,11 @@ def setUp(self): self.outputs = {'Out': output_np} def test_check_output(self): - self.check_output() + self.check_output(check_eager=self.check_eager) def test_check_grad(self): - self.check_grad(['X'], 'Out', in_place=True) + self.check_grad( + ['X'], 'Out', in_place=True, check_eager=self.check_eager) def init_test_case(self): self.interp_method = 'bilinear' diff --git a/python/paddle/fluid/tests/unittests/test_crop_op.py b/python/paddle/fluid/tests/unittests/test_crop_op.py index b08648b99f123..acb652ad6f9e8 100644 --- a/python/paddle/fluid/tests/unittests/test_crop_op.py +++ b/python/paddle/fluid/tests/unittests/test_crop_op.py @@ -71,10 +71,10 @@ def initTestCase(self): self.offsets = [1, 2] def test_check_output(self): - self.check_output() + self.check_output(check_eager=True) def test_check_grad_normal(self): - self.check_grad(['X'], 'Out') + self.check_grad(['X'], 'Out', check_eager=True) class TestCase1(TestCropOp): @@ -125,4 +125,6 @@ def initTestCase(self): if __name__ == '__main__': + import paddle + paddle.enable_static() unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_crop_tensor_op.py b/python/paddle/fluid/tests/unittests/test_crop_tensor_op.py index 0808f99ff1a94..a4552c8f5ddbb 100644 --- a/python/paddle/fluid/tests/unittests/test_crop_tensor_op.py +++ 
b/python/paddle/fluid/tests/unittests/test_crop_tensor_op.py
@@ -77,10 +77,10 @@ def initTestCase(self):
         self.offsets = [1, 2]
 
     def test_check_output(self):
-        self.check_output()
+        self.check_output(check_eager=True)
 
     def test_check_grad_normal(self):
-        self.check_grad(['X'], 'Out')
+        self.check_grad(['X'], 'Out', check_eager=True)
 
 
 class TestCase1(TestCropTensorOp):
@@ -175,10 +175,10 @@ def initTestCase(self):
         self.shape_attr = [0, 0]
 
     def test_check_output(self):
-        self.check_output()
+        self.check_output(check_eager=True)
 
     def test_check_grad_normal(self):
-        self.check_grad(["X"], "Out")
+        self.check_grad(["X"], "Out", check_eager=True)
 
 
 class TestCropTensorOpTensorAttrCase1(TestCropTensorOpTensorAttr):
@@ -262,4 +262,6 @@ def input_dtype():
 
 
 if __name__ == '__main__':
+    import paddle
+    paddle.enable_static()
     unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/test_decayed_adagrad_op.py b/python/paddle/fluid/tests/unittests/test_decayed_adagrad_op.py
index a664a1529f4de..e2f6d17cc96a8 100644
--- a/python/paddle/fluid/tests/unittests/test_decayed_adagrad_op.py
+++ b/python/paddle/fluid/tests/unittests/test_decayed_adagrad_op.py
@@ -48,7 +48,7 @@ def setUp(self):
         self.outputs = {'ParamOut': param_out, 'MomentOut': moment_out}
 
     def test_check_output(self):
-        self.check_output()
+        self.check_output(check_eager=True)
 
 
 class TestDecayedAdagradOp2(OpTest):
@@ -80,8 +80,10 @@ def setUp(self):
         self.outputs = {'ParamOut': param_out, 'MomentOut': moment_out}
 
     def test_check_output(self):
-        self.check_output()
+        self.check_output(check_eager=True)
 
 
 if __name__ == "__main__":
+    import paddle
+    paddle.enable_static()
     unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/test_dpsgd_op.py b/python/paddle/fluid/tests/unittests/test_dpsgd_op.py
index 48bf786e139dd..35a922b78205f 100644
--- a/python/paddle/fluid/tests/unittests/test_dpsgd_op.py
+++ b/python/paddle/fluid/tests/unittests/test_dpsgd_op.py
@@ -45,7 +45,7 @@ def setUp(self):
         self.outputs = {'ParamOut': param_out}
 
     def test_check_output(self):
-        self.check_output()
+        self.check_output(check_eager=True)
 
 
 def dpsgd_step(inputs, attributes):
@@ -70,4 +70,6 @@ def dpsgd_step(inputs, attributes):
 
 
 if __name__ == "__main__":
+    import paddle
+    paddle.enable_static()
     unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/test_ftrl_op.py b/python/paddle/fluid/tests/unittests/test_ftrl_op.py
index f58672a7a1e89..1826fdc3c0604 100644
--- a/python/paddle/fluid/tests/unittests/test_ftrl_op.py
+++ b/python/paddle/fluid/tests/unittests/test_ftrl_op.py
@@ -101,7 +101,7 @@ def setUp(self):
         }
 
     def test_check_output(self):
-        self.check_output()
+        self.check_output(check_eager=True)
 
 
 class TestSparseFTRLOp(unittest.TestCase):
@@ -201,4 +201,6 @@ def init_kernel(self):
 
 
 if __name__ == "__main__":
+    import paddle
+    paddle.enable_static()
     unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/test_mean_iou.py b/python/paddle/fluid/tests/unittests/test_mean_iou.py
index 4e89a9034a341..b392a328494b3 100644
--- a/python/paddle/fluid/tests/unittests/test_mean_iou.py
+++ b/python/paddle/fluid/tests/unittests/test_mean_iou.py
@@ -15,6 +15,7 @@
 
 from __future__ import print_function
 from __future__ import division
+
 import unittest
 import numpy as np
 from op_test import OpTest
@@ -113,6 +114,11 @@ def config(self):
         self.in_correct_num = 2
         self.in_mean_iou_num = 2
 
+    # NOTE(dev): Skip check_dygraph because the Python API doesn't expose
+    # the in_wrong_num/in_correct_num/in_mean_iou_num arguments
+    def test_check_output(self):
+        
self.check_output(check_dygraph=False, check_eager=False)
+
 
 class TestMeanIOUOpError(unittest.TestCase):
     def test_errors(self):
@@ -130,5 +136,4 @@ def test_errors(self):
 
 
 if __name__ == '__main__':
-    paddle.enable_static()
     unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/test_nearest_interp_op.py b/python/paddle/fluid/tests/unittests/test_nearest_interp_op.py
index eda530f30df26..5df085d4febac 100755
--- a/python/paddle/fluid/tests/unittests/test_nearest_interp_op.py
+++ b/python/paddle/fluid/tests/unittests/test_nearest_interp_op.py
@@ -79,6 +79,7 @@ def setUp(self):
         self.data_layout = 'NCHW'
         self.init_test_case()
         self.op_type = "nearest_interp"
+        self.check_eager = True
         input_np = np.random.random(self.input_shape).astype("float64")
 
         if self.data_layout == "NCHW":
@@ -101,8 +102,10 @@ def setUp(self):
         self.inputs = {'X': input_np}
         if self.out_size is not None:
             self.inputs['OutSize'] = self.out_size
+            self.check_eager = False
         if self.actual_shape is not None:
             self.inputs['OutSize'] = self.actual_shape
+            self.check_eager = False
         self.attrs = {
             'out_h': self.out_h,
             'out_w': self.out_w,
@@ -114,10 +117,11 @@ def setUp(self):
         self.outputs = {'Out': output_np}
 
     def test_check_output(self):
-        self.check_output()
+        self.check_output(check_eager=self.check_eager)
 
     def test_check_grad(self):
-        self.check_grad(['X'], 'Out', in_place=True)
+        self.check_grad(
+            ['X'], 'Out', in_place=True, check_eager=self.check_eager)
 
     def init_test_case(self):
         self.interp_method = 'nearest'
@@ -231,6 +235,7 @@ def setUp(self):
         self.actual_shape = None
         self.init_test_case()
         self.op_type = "nearest_interp"
+        self.check_eager = True
         input_np = np.random.randint(
             low=0, high=256, size=self.input_shape).astype("uint8")
 
@@ -247,6 +252,7 @@ def setUp(self):
         self.inputs = {'X': input_np}
         if self.out_size is not None:
             self.inputs['OutSize'] = self.out_size
+            self.check_eager = False
         self.attrs = {
             'out_h': self.out_h,
             'out_w': self.out_w,
@@ -257,7 +263,8 @@ def setUp(self):
         self.outputs = {'Out': output_np}
 
     def test_check_output(self):
-        self.check_output_with_place(place=core.CPUPlace(), atol=1)
+        self.check_output_with_place(
+            place=core.CPUPlace(), atol=1, check_eager=self.check_eager)
 
     def init_test_case(self):
         self.interp_method = 'nearest'
@@ -339,6 +346,9 @@ def setUp(self):
             'interp_method': self.interp_method,
             'align_corners': self.align_corners,
         }
+        # NOTE(dev): some AsDispensable inputs are not used under imperative mode.
+        # Skip check_eager when they are found in Inputs.
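+        # In this attr-tensor test both 'OutSize' and 'SizeTensor' are such
+        # inputs, so either branch below clears the flag again.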
+ self.check_eager = True input_np = np.random.random(self.input_shape).astype("float64") self.inputs = {'X': input_np} @@ -355,12 +365,14 @@ def setUp(self): if self.shape_by_1Dtensor: self.inputs['OutSize'] = self.out_size + self.check_eager = False elif self.out_size is not None: size_tensor = [] for index, ele in enumerate(self.out_size): size_tensor.append(("x" + str(index), np.ones( (1)).astype('int32') * ele)) self.inputs['SizeTensor'] = size_tensor + self.check_eager = False self.attrs['out_h'] = self.out_h self.attrs['out_w'] = self.out_w @@ -370,10 +382,11 @@ def setUp(self): self.outputs = {'Out': output_np} def test_check_output(self): - self.check_output() + self.check_output(check_eager=self.check_eager) def test_check_grad(self): - self.check_grad(['X'], 'Out', in_place=True) + self.check_grad( + ['X'], 'Out', in_place=True, check_eager=self.check_eager) def init_test_case(self): self.interp_method = 'nearest' @@ -495,4 +508,6 @@ def attr_scale_value(): if __name__ == "__main__": + import paddle + paddle.enable_static() unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_prroi_pool_op.py b/python/paddle/fluid/tests/unittests/test_prroi_pool_op.py index efb5e05bdebca..8e5ba7c3363a1 100644 --- a/python/paddle/fluid/tests/unittests/test_prroi_pool_op.py +++ b/python/paddle/fluid/tests/unittests/test_prroi_pool_op.py @@ -80,14 +80,14 @@ def setUp(self): self.set_data() def test_check_output(self): - self.check_output() + self.check_output(check_eager=True) def test_backward(self): places = [fluid.CPUPlace()] if fluid.core.is_compiled_with_cuda(): places.append(fluid.CUDAPlace(0)) for place in places: - self.check_grad_with_place(place, ['X'], 'Out') + self.check_grad_with_place(place, ['X'], 'Out', check_eager=True) def run_net(self, place): with program_guard(Program(), Program()): @@ -197,14 +197,14 @@ def setUp(self): self.set_data() def test_check_output(self): - self.check_output() + self.check_output(check_eager=True) def test_backward(self): places = [fluid.CPUPlace()] if fluid.core.is_compiled_with_cuda(): places.append(fluid.CUDAPlace(0)) for place in places: - self.check_grad_with_place(place, ['X'], 'Out') + self.check_grad_with_place(place, ['X'], 'Out', check_eager=True) def run_net(self, place): with program_guard(Program(), Program()): @@ -280,4 +280,6 @@ def test_bad_y(): if __name__ == '__main__': + import paddle + paddle.enable_static() unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_smooth_l1_loss_op.py b/python/paddle/fluid/tests/unittests/test_smooth_l1_loss_op.py index 3c825c08e8c3f..63e8568048d13 100644 --- a/python/paddle/fluid/tests/unittests/test_smooth_l1_loss_op.py +++ b/python/paddle/fluid/tests/unittests/test_smooth_l1_loss_op.py @@ -48,18 +48,27 @@ def setUp(self): } def test_check_output(self): - self.check_output() + self.check_output(check_eager=True) def test_check_grad_normal(self): - self.check_grad(['X', 'Y'], 'Out', max_relative_error=0.02) + self.check_grad( + ['X', 'Y'], 'Out', max_relative_error=0.02, check_eager=True) def test_check_grad_ingore_x(self): self.check_grad( - ['Y'], 'Out', max_relative_error=0.03, no_grad_set=set("X")) + ['Y'], + 'Out', + max_relative_error=0.03, + no_grad_set=set("X"), + check_eager=True) def test_check_grad_ingore_y(self): self.check_grad( - ['X'], 'Out', max_relative_error=0.03, no_grad_set=set('Y')) + ['X'], + 'Out', + max_relative_error=0.03, + no_grad_set=set('Y'), + check_eager=True) class TestSmoothL1LossOp2(OpTest): @@ -86,24 +95,27 @@ def setUp(self): } def 
test_check_output(self):
-        self.check_output()
+        self.check_output(check_eager=True)
 
     def test_check_grad_normal(self):
-        self.check_grad(['X', 'Y'], 'Out', max_relative_error=0.03)
+        self.check_grad(
+            ['X', 'Y'], 'Out', max_relative_error=0.03, check_eager=True)
 
     def test_check_grad_ingore_x(self):
         self.check_grad(
             ['Y'],
             'Out',
             max_relative_error=0.03,
-            no_grad_set=set(['X', 'InsideWeight', 'OutsideWeight']))
+            no_grad_set=set(['X', 'InsideWeight', 'OutsideWeight']),
+            check_eager=True)
 
     def test_check_grad_ingore_y(self):
         self.check_grad(
             ['X'],
             'Out',
             max_relative_error=0.03,
-            no_grad_set=set(['Y', 'InsideWeight', 'OutsideWeight']))
+            no_grad_set=set(['Y', 'InsideWeight', 'OutsideWeight']),
+            check_eager=True)
 
 
 class TestSmoothL1LossOpError(unittest.TestCase):
diff --git a/python/paddle/fluid/tests/unittests/test_sparse_momentum_op.py b/python/paddle/fluid/tests/unittests/test_sparse_momentum_op.py
index e36cb72efc725..033dbd250ed61 100644
--- a/python/paddle/fluid/tests/unittests/test_sparse_momentum_op.py
+++ b/python/paddle/fluid/tests/unittests/test_sparse_momentum_op.py
@@ -163,7 +163,8 @@ def init_use_nesterov(self):
         pass
 
     def test_check_output(self):
-        self.check_output(atol=5e-3 if self.multi_precision else 1e-5)
+        self.check_output(
+            atol=5e-3 if self.multi_precision else 1e-5, check_eager=True)
 
 
 class TestSparseMomentumOpDtype1(TestSparseMomentumOp):
@@ -240,3 +241,7 @@ def init_multi_precision(self):
 
     def init_use_nesterov(self):
         self.use_nesterov = False
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/test_stft_op.py b/python/paddle/fluid/tests/unittests/test_stft_op.py
index f228c148d6e17..41e950606b3db 100644
--- a/python/paddle/fluid/tests/unittests/test_stft_op.py
+++ b/python/paddle/fluid/tests/unittests/test_stft_op.py
@@ -77,12 +77,12 @@ def initTestCase(self):
 
     def test_check_output(self):
         paddle.enable_static()
-        self.check_output()
+        self.check_output(check_eager=True)
         paddle.disable_static()
 
     def test_check_grad_normal(self):
         paddle.enable_static()
-        self.check_grad(['X'], 'Out')
+        self.check_grad(['X'], 'Out', check_eager=True)
         paddle.disable_static()
 
 
diff --git a/python/paddle/fluid/tests/unittests/test_trilinear_interp_op.py b/python/paddle/fluid/tests/unittests/test_trilinear_interp_op.py
index 2778fa0c6ace4..49699b8fafd03 100755
--- a/python/paddle/fluid/tests/unittests/test_trilinear_interp_op.py
+++ b/python/paddle/fluid/tests/unittests/test_trilinear_interp_op.py
@@ -131,6 +131,9 @@ def setUp(self):
         self.data_layout = 'NCDHW'
         self.init_test_case()
         self.op_type = "trilinear_interp"
+        # NOTE(dev): some AsDispensable inputs are not used under imperative mode.
+        # Skip check_eager when they are found in Inputs.
+        self.check_eager = True
 
         input_np = np.random.random(self.input_shape).astype("float32")
 
         if self.data_layout == "NCDHW":
@@ -157,8 +160,10 @@ def setUp(self):
         self.inputs = {'X': input_np}
         if self.out_size is not None:
             self.inputs['OutSize'] = self.out_size
+            self.check_eager = False
         if self.actual_shape is not None:
             self.inputs['OutSize'] = self.actual_shape
+            self.check_eager = False
         # c++ end treat NCDHW the same way as NCHW
         if self.data_layout == 'NCDHW':
             data_layout = 'NCHW'
@@ -177,10 +182,11 @@ def setUp(self):
         self.outputs = {'Out': output_np}
 
     def test_check_output(self):
-        self.check_output()
+        self.check_output(check_eager=self.check_eager)
 
     def test_check_grad(self):
-        self.check_grad(['X'], 'Out', in_place=True)
+        self.check_grad(
+            ['X'], 'Out', in_place=True, check_eager=self.check_eager)
 
     def init_test_case(self):
         self.interp_method = 'trilinear'
@@ -326,6 +332,7 @@ def setUp(self):
         self.actual_shape = None
         self.init_test_case()
         self.op_type = "trilinear_interp"
+        self.check_eager = True
         input_np = np.random.randint(
             low=0, high=256, size=self.input_shape).astype("uint8")
 
@@ -344,6 +351,7 @@ def setUp(self):
         self.inputs = {'X': input_np}
         if self.out_size is not None:
             self.inputs['OutSize'] = self.out_size
+            self.check_eager = False
 
         self.attrs = {
             'out_d': self.out_d,
@@ -357,7 +365,8 @@ def setUp(self):
         self.outputs = {'Out': output_np}
 
     def test_check_output(self):
-        self.check_output_with_place(place=core.CPUPlace(), atol=1)
+        self.check_output_with_place(
+            place=core.CPUPlace(), atol=1, check_eager=self.check_eager)
 
     def init_test_case(self):
         self.interp_method = 'trilinear'
@@ -467,6 +476,7 @@ def setUp(self):
         self.actual_shape = None
         self.init_test_case()
         self.op_type = "trilinear_interp"
+        self.check_eager = True
         self.shape_by_1Dtensor = False
         self.scale_by_1Dtensor = False
         self.attrs = {
@@ -492,12 +502,14 @@ def setUp(self):
 
         if self.shape_by_1Dtensor:
             self.inputs['OutSize'] = self.out_size
+            self.check_eager = False
         elif self.out_size is not None:
             size_tensor = []
             for index, ele in enumerate(self.out_size):
                 size_tensor.append(("x" + str(index), np.ones(
                     (1)).astype('int32') * ele))
             self.inputs['SizeTensor'] = size_tensor
+            self.check_eager = False
 
         self.attrs['out_d'] = self.out_d
         self.attrs['out_h'] = self.out_h
@@ -508,10 +520,11 @@ def setUp(self):
         self.outputs = {'Out': output_np}
 
     def test_check_output(self):
-        self.check_output()
+        self.check_output(check_eager=self.check_eager)
 
     def test_check_grad(self):
-        self.check_grad(['X'], 'Out', in_place=True)
+        self.check_grad(
+            ['X'], 'Out', in_place=True, check_eager=self.check_eager)
 
     def init_test_case(self):
         self.interp_method = 'trilinear'
diff --git a/python/paddle/fluid/tests/unittests/test_trilinear_interp_v2_op.py b/python/paddle/fluid/tests/unittests/test_trilinear_interp_v2_op.py
index 9f46b539a04b6..6d072e3c377fe 100755
--- a/python/paddle/fluid/tests/unittests/test_trilinear_interp_v2_op.py
+++ b/python/paddle/fluid/tests/unittests/test_trilinear_interp_v2_op.py
@@ -145,6 +145,10 @@ def setUp(self):
         self.data_layout = 'NCDHW'
         self.init_test_case()
         self.op_type = "trilinear_interp_v2"
+        # NOTE(dev): some AsDispensable inputs are not used under imperative mode.
+        # Skip check_eager when they are found in Inputs.
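+        # The v2 op also still lacks a self.python_api mapping for the eager
+        # check, hence the TODO below and the flag starting as False here.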
+ # TODO(dev): add self.python_api + self.check_eager = False input_np = np.random.random(self.input_shape).astype("float32") scale_w = 0 @@ -183,8 +187,10 @@ def setUp(self): self.inputs = {'X': input_np} if self.out_size is not None: self.inputs['OutSize'] = self.out_size + self.check_eager = False if self.actual_shape is not None: self.inputs['OutSize'] = self.actual_shape + self.check_eager = False # c++ end treat NCDHW the same way as NCHW if self.data_layout == 'NCDHW': data_layout = 'NCHW' @@ -208,10 +214,11 @@ def setUp(self): self.outputs = {'Out': output_np} def test_check_output(self): - self.check_output() + self.check_output(check_eager=self.check_eager) def test_check_grad(self): - self.check_grad(['X'], 'Out', in_place=True) + self.check_grad( + ['X'], 'Out', in_place=True, check_eager=self.check_eager) def init_test_case(self): self.interp_method = 'trilinear' @@ -357,6 +364,8 @@ def setUp(self): self.actual_shape = None self.init_test_case() self.op_type = "trilinear_interp_v2" + # TODO(dev): add self.python_api + self.check_eager = False input_np = np.random.randint( low=0, high=256, size=self.input_shape).astype("uint8") @@ -383,6 +392,7 @@ def setUp(self): self.inputs = {'X': input_np} if self.out_size is not None: self.inputs['OutSize'] = self.out_size + self.check_eager = False self.attrs = { 'out_d': self.out_d, @@ -401,7 +411,8 @@ def setUp(self): self.outputs = {'Out': output_np} def test_check_output(self): - self.check_output_with_place(place=core.CPUPlace(), atol=1) + self.check_output_with_place( + place=core.CPUPlace(), atol=1, check_eager=self.check_eager) def init_test_case(self): self.interp_method = 'trilinear' @@ -511,6 +522,8 @@ def setUp(self): self.actual_shape = None self.init_test_case() self.op_type = "trilinear_interp_v2" + # TODO(dev): add self.python_api + self.check_eager = False self.shape_by_1Dtensor = False self.scale_by_1Dtensor = False self.attrs = { @@ -543,12 +556,14 @@ def setUp(self): if self.shape_by_1Dtensor: self.inputs['OutSize'] = self.out_size + self.check_eager = False elif self.out_size is not None: size_tensor = [] for index, ele in enumerate(self.out_size): size_tensor.append(("x" + str(index), np.ones( (1)).astype('int32') * ele)) self.inputs['SizeTensor'] = size_tensor + self.check_eager = False self.attrs['out_d'] = self.out_d self.attrs['out_h'] = self.out_h @@ -565,10 +580,11 @@ def setUp(self): self.outputs = {'Out': output_np} def test_check_output(self): - self.check_output() + self.check_output(check_eager=self.check_eager) def test_check_grad(self): - self.check_grad(['X'], 'Out', in_place=True) + self.check_grad( + ['X'], 'Out', in_place=True, check_eager=self.check_eager) def init_test_case(self): self.interp_method = 'trilinear' From af8d248215a0e6f725179c772bb97252cf84a545 Mon Sep 17 00:00:00 2001 From: FlyingQianMM <245467267@qq.com> Date: Sun, 3 Apr 2022 13:12:55 +0800 Subject: [PATCH 42/93] add maximum limit for grid of index_select (#41127) * limit grid dim for index select * mv LimitGridDim into gpu_launch_config.h * fix conflicts * fix conflicts * fix code style * set block to 256 * fix grid setting * set dtype of block_dim to unsigned int --- .../platform/device/gpu/gpu_launch_config.h | 8 ++++ .../phi/kernels/funcs/elementwise_grad_base.h | 44 ++++++++----------- paddle/phi/kernels/funcs/reduce_function.h | 16 ++----- .../kernels/gpu/index_sample_grad_kernel.cu | 9 +--- paddle/phi/kernels/gpu/index_sample_kernel.cu | 9 +--- .../kernels/gpu/index_select_grad_kernel.cu | 23 +++++----- 
paddle/phi/kernels/gpu/index_select_kernel.cu |  35 +++++++--------
 7 files changed, 58 insertions(+), 86 deletions(-)

diff --git a/paddle/fluid/platform/device/gpu/gpu_launch_config.h b/paddle/fluid/platform/device/gpu/gpu_launch_config.h
index 4e8b790fa63d1..4a550e61d42da 100644
--- a/paddle/fluid/platform/device/gpu/gpu_launch_config.h
+++ b/paddle/fluid/platform/device/gpu/gpu_launch_config.h
@@ -170,6 +170,14 @@ inline GpuLaunchConfig GetGpuLaunchConfig2D(
   return config;
 }
 
+template <typename Context>
+void LimitGridDim(const Context& ctx, dim3* grid_dim) {
+  auto max_grid_dim = reinterpret_cast<const phi::GPUContext&>(ctx)
+                          .GetCUDAMaxGridDimSize();
+  grid_dim->x = grid_dim->x < max_grid_dim[0] ? grid_dim->x : max_grid_dim[0];
+  grid_dim->y = grid_dim->y < max_grid_dim[1] ? grid_dim->y : max_grid_dim[1];
+  grid_dim->z = grid_dim->z < max_grid_dim[2] ? grid_dim->z : max_grid_dim[2];
+}
 }  // namespace platform
 }  // namespace paddle
 
diff --git a/paddle/phi/kernels/funcs/elementwise_grad_base.h b/paddle/phi/kernels/funcs/elementwise_grad_base.h
index 23b8388c74589..1021b510b26cd 100644
--- a/paddle/phi/kernels/funcs/elementwise_grad_base.h
+++ b/paddle/phi/kernels/funcs/elementwise_grad_base.h
@@ -24,6 +24,7 @@ limitations under the License. */
 // See Note [ Why still include the fluid headers? ]
 #include "paddle/fluid/memory/memcpy.h"
 #include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
+#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
 #include "paddle/phi/kernels/primitive/kernel_primitives.h"
 #endif
 
@@ -49,14 +50,6 @@ namespace phi {
 namespace funcs {
 
 using DDim = phi::DDim;
 
-template <typename T>
-void LimitGridDim(const GPUContext &ctx, T *grid_dim) {
-  auto max_grid_dim = ctx.GetCUDAMaxGridDimSize()[0];
-  if (*grid_dim > max_grid_dim) {
-    *grid_dim = max_grid_dim;
-  }
-}
-
 template <typename T, typename DX_OP, typename DY_OP, typename Tout = T>
 void CommonGradBroadcastCPU(const DenseTensor &x,
                             const DenseTensor &y,
@@ -978,17 +971,17 @@ static void ElemwiseGradBroadcast1CUDA(gpuStream_t stream,
   constexpr int half_walf = 16;
   if (w < half_walf || h < half_walf) {
     int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, h);
-    int gird_size = w;
-    ElemwiseGradBroadcast1CUDAKernel<<<gird_size, block_size, 0, stream>>>(
+    int grid_size = w;
+    ElemwiseGradBroadcast1CUDAKernel<<<grid_size, block_size, 0, stream>>>(
         x, y, out, dout, h, w, is_xsize_larger, dx_op, dy_op, dx, dy);
   } else {  // suppose performance improves with h increased.
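+    // NOTE: grid_size below becomes a dim3 so that the shared
+    // paddle::platform::LimitGridDim helper can clamp each launch axis
+    // against GetCUDAMaxGridDimSize() before the kernel is launched.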
dim3 block_size = dim3(BLOCK_X, BLOCK_Y); - int grid_size = (w + BLOCK_X - 1) / BLOCK_X; + dim3 grid_size = dim3((w + BLOCK_X - 1) / BLOCK_X); auto gplace = phi::GPUPlace(); auto *ctx = static_cast( paddle::platform::DeviceContextPool::Instance().Get(gplace)); - LimitGridDim(*ctx, &grid_size); + paddle::platform::LimitGridDim(*ctx, &grid_size); FastElemwiseGradBroadcast1CUDAKernel<<>>( x, y, out, dout, h, w, is_xsize_larger, dx_op, dy_op, dx, dy); } @@ -1009,13 +1002,12 @@ static void ElemwiseGradBroadcast2CUDA(gpuStream_t stream, T *dx, T *dy) { int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, pre * post); - int gird_size = n; - int grid_size = n; + dim3 grid_size = dim3(n); auto gplace = phi::GPUPlace(); auto *ctx = static_cast( paddle::platform::DeviceContextPool::Instance().Get(gplace)); - LimitGridDim(*ctx, &grid_size); - ElemwiseGradBroadcast2CUDAKernel<<>>( + paddle::platform::LimitGridDim(*ctx, &grid_size); + ElemwiseGradBroadcast2CUDAKernel<<>>( x, y, out, dout, pre, n, post, is_xsize_larger, dx_op, dy_op, dx, dy); } @@ -1216,8 +1208,8 @@ void CommonGradBroadcastCUDA(const DenseTensor &x, is_y); } else { dim3 block_size = dim3(BLOCK_X, BLOCK_Y); - int grid_size = (w + BLOCK_X - 1) / BLOCK_X; - LimitGridDim(ctx, &grid_size); + dim3 grid_size = dim3((w + BLOCK_X - 1) / BLOCK_X); + paddle::platform::LimitGridDim(ctx, &grid_size); FastCommonGradBroadcastCUDAKernelHeight<<>>( x_data, @@ -1392,8 +1384,8 @@ void CommonGradBroadcastCUDA(const DenseTensor &x, 1, std::multiplies()); int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, mid); - int grid_size = pre * post; - LimitGridDim(ctx, &grid_size); + dim3 grid_size = dim3(pre * post); + paddle::platform::LimitGridDim(ctx, &grid_size); // we need to calc y offset with blockid, so do x_pre/y_pre to get left // size. if (k_pre != pre) k_pre = pre / k_pre; @@ -1423,8 +1415,8 @@ void CommonGradBroadcastCUDA(const DenseTensor &x, 1, std::multiplies()); int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, mid); - int grid_size = pre * post; - LimitGridDim(ctx, &grid_size); + dim3 grid_size = dim3(pre * post); + paddle::platform::LimitGridDim(ctx, &grid_size); if (k_pre != pre) k_pre = pre / k_pre; FastCommonGradBroadcastOneCUDAKernel<<( - paddle::platform::DeviceContextPool::Instance().Get(place)); - std::array max_grid_dim = ctx->GetCUDAMaxGridDimSize(); - grid.x = grid.x < max_grid_dim[0] ? grid.x : max_grid_dim[0]; - grid.y = grid.y < max_grid_dim[1] ? grid.y : max_grid_dim[1]; - grid.z = grid.z < max_grid_dim[2] ? grid.z : max_grid_dim[2]; - } - public: std::vector reduce_dims_origin; std::vector reduce_dim; @@ -1072,7 +1064,7 @@ void ReduceKernel(const KPDevice& dev_ctx, auto x_dim = phi::vectorize(x.dims()); auto config = ReduceConfig(origin_reduce_dims, x_dim); - config.Run(x.place()); + config.Run(dev_ctx); int numel = x.numel(); // after config.run() // SetOutputData for ReduceHigherDim when should_reduce_again is true, diff --git a/paddle/phi/kernels/gpu/index_sample_grad_kernel.cu b/paddle/phi/kernels/gpu/index_sample_grad_kernel.cu index 669ae11543950..c8c025c7fc18f 100644 --- a/paddle/phi/kernels/gpu/index_sample_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/index_sample_grad_kernel.cu @@ -26,13 +26,6 @@ namespace phi { namespace { -template -void LimitGridDim(const Context& ctx, dim3* grid_dim) { - auto max_grid_dim = - reinterpret_cast(ctx).GetCUDAMaxGridDimSize(); - grid_dim->x = grid_dim->x < max_grid_dim[0] ? grid_dim->x : max_grid_dim[0]; - grid_dim->y = grid_dim->y < max_grid_dim[1] ? 
grid_dim->y : max_grid_dim[1]; -} #define PREDEFINED_BLOCK_SIZE_X 512 #define PREDEFINED_BLOCK_SIZE 1024 #define MIN(a, b) ((a) < (b) ? (a) : (b)) @@ -107,7 +100,7 @@ void IndexSampleGradKernel(const Context& ctx, dim3 block_dim(block_width, block_height); dim3 grid_dim((index_length + block_dim.x - 1) / block_dim.x, (batch_size + block_dim.y - 1) / block_dim.y); - LimitGridDim(ctx, &grid_dim); + paddle::platform::LimitGridDim(ctx, &grid_dim); phi::funcs::SetConstant set_zero; set_zero(ctx, x_grad, static_cast(0)); diff --git a/paddle/phi/kernels/gpu/index_sample_kernel.cu b/paddle/phi/kernels/gpu/index_sample_kernel.cu index 68573d5596646..0eca473a565a8 100644 --- a/paddle/phi/kernels/gpu/index_sample_kernel.cu +++ b/paddle/phi/kernels/gpu/index_sample_kernel.cu @@ -25,13 +25,6 @@ namespace phi { namespace { -template -void LimitGridDim(const Context& ctx, dim3* grid_dim) { - auto max_grid_dim = - reinterpret_cast(ctx).GetCUDAMaxGridDimSize(); - grid_dim->x = grid_dim->x < max_grid_dim[0] ? grid_dim->x : max_grid_dim[0]; - grid_dim->y = grid_dim->y < max_grid_dim[1] ? grid_dim->y : max_grid_dim[1]; -} #define PREDEFINED_BLOCK_SIZE_X 512 #define PREDEFINED_BLOCK_SIZE 1024 #define MIN(a, b) ((a) < (b) ? (a) : (b)) @@ -95,7 +88,7 @@ void IndexSampleKernel(const Context& ctx, dim3 block_dim(block_width, block_height); dim3 grid_dim((index_length + block_dim.x - 1) / block_dim.x, (batch_size + block_dim.y - 1) / block_dim.y); - LimitGridDim(ctx, &grid_dim); + paddle::platform::LimitGridDim(ctx, &grid_dim); if (index_type == DataType::INT64) { const int64_t* index_data = index.data(); diff --git a/paddle/phi/kernels/gpu/index_select_grad_kernel.cu b/paddle/phi/kernels/gpu/index_select_grad_kernel.cu index b3bd307e2aad6..209ce1ccf5c80 100644 --- a/paddle/phi/kernels/gpu/index_select_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/index_select_grad_kernel.cu @@ -14,6 +14,7 @@ #include "paddle/phi/kernels/index_select_grad_kernel.h" +#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h" #include "paddle/fluid/platform/device/gpu/gpu_primitives.h" #include "paddle/phi/backends/gpu/gpu_info.h" #include "paddle/phi/core/kernel_registry.h" @@ -89,25 +90,23 @@ void IndexSelectGradKernel(const Context& ctx, auto stream = ctx.stream(); - index_select_grad_init< - T><<<(numel + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS, - PADDLE_CUDA_NUM_THREADS, - 0, - stream>>>(in_grad_data, numel); + unsigned int block_dim = PADDLE_CUDA_NUM_THREADS; + dim3 grid_dim = dim3((numel + block_dim - 1) / block_dim); + paddle::platform::LimitGridDim(ctx, &grid_dim); - int blocks = - (out_nums + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS; - int threads = PADDLE_CUDA_NUM_THREADS; + index_select_grad_init<<>>(in_grad_data, + numel); if (FLAGS_cudnn_deterministic) { VLOG(2) << "Run grad kernel of index_select with single thread."; - blocks = 1; - threads = 1; + block_dim = 1; + grid_dim.x = 1; } if (index_type == phi::DataType::INT64) { const int64_t* index_data = index.data(); - index_select_grad_cuda_kernel<<>>( + index_select_grad_cuda_kernel<<>>( output_grad_data, in_grad_data, index_data, @@ -118,7 +117,7 @@ void IndexSelectGradKernel(const Context& ctx, delta); } else { const int* index_data = index.data(); - index_select_grad_cuda_kernel<<>>( + index_select_grad_cuda_kernel<<>>( output_grad_data, in_grad_data, index_data, diff --git a/paddle/phi/kernels/gpu/index_select_kernel.cu b/paddle/phi/kernels/gpu/index_select_kernel.cu index e82976d46e68b..57a13a9aefc2c 100644 --- 
a/paddle/phi/kernels/gpu/index_select_kernel.cu +++ b/paddle/phi/kernels/gpu/index_select_kernel.cu @@ -14,6 +14,7 @@ #include "paddle/phi/kernels/index_select_kernel.h" +#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h" #include "paddle/fluid/platform/device/gpu/gpu_primitives.h" #include "paddle/phi/backends/gpu/gpu_info.h" #include "paddle/phi/core/kernel_registry.h" @@ -31,16 +32,14 @@ __global__ void index_select_cuda_kernel(const T* input, int64_t stride, int64_t size, int64_t delta) { - int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx >= N) { - return; + CUDA_KERNEL_LOOP(idx, N) { + int64_t pre_idx = idx / (stride * size); + int64_t dim_idx = idx % (stride * size) / stride; + IndexT src_dim_idx = index[dim_idx]; + int64_t input_idx = + idx + (delta * pre_idx + src_dim_idx - dim_idx) * stride; + output[idx] = input[input_idx]; } - - int64_t pre_idx = idx / (stride * size); - int64_t dim_idx = idx % (stride * size) / stride; - IndexT src_dim_idx = index[dim_idx]; - int64_t input_idx = idx + (delta * pre_idx + src_dim_idx - dim_idx) * stride; - output[idx] = input[input_idx]; } template @@ -75,21 +74,17 @@ void IndexSelectKernel(const Context& ctx, int64_t numel = output->numel(); auto stream = ctx.stream(); + unsigned int block_dim = PADDLE_CUDA_NUM_THREADS; + dim3 grid_dim = dim3((numel + block_dim - 1) / block_dim); + paddle::platform::LimitGridDim(ctx, &grid_dim); + if (index_type == phi::DataType::INT64) { const int64_t* index_data = index.data(); - index_select_cuda_kernel<<< - (numel + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS, - PADDLE_CUDA_NUM_THREADS, - 0, - stream>>>(in_data, out_data, index_data, numel, stride, size, delta); + index_select_cuda_kernel<<>>( + in_data, out_data, index_data, numel, stride, size, delta); } else { const int* index_data = index.data(); - index_select_cuda_kernel< - T, - int><<<(numel + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS, - PADDLE_CUDA_NUM_THREADS, - 0, - stream>>>( + index_select_cuda_kernel<<>>( in_data, out_data, index_data, numel, stride, size, delta); } } From 2bc72a06be7a2df79b2324bc97ea6eb5f3c847b3 Mon Sep 17 00:00:00 2001 From: Chen Weihang Date: Sun, 3 Apr 2022 13:27:41 +0800 Subject: [PATCH 43/93] fix eager gen grad multi out error (#41358) --- .../auto_code_generator/final_state_generator/eager_gen.py | 4 ++-- python/paddle/utils/code_gen/api_base.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py b/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py index 0d1d3ab722522..88688672b18b5 100644 --- a/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py +++ b/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py @@ -88,9 +88,9 @@ def ParseArguments(): CLEAR_VECTOR_TENSOR_WRAPPERS_TEMPLATE = \ """ - for (auto tw: {}) { + for (auto& tw : {}) {{ tw.clear(); - }; + }} """ SET_ATTR_METHOD_TEMPLATE = \ diff --git a/python/paddle/utils/code_gen/api_base.py b/python/paddle/utils/code_gen/api_base.py index d3c3177827b28..14f22fced9230 100644 --- a/python/paddle/utils/code_gen/api_base.py +++ b/python/paddle/utils/code_gen/api_base.py @@ -193,7 +193,7 @@ def parse_output_item(output_item): f"{api_name} : Output type error: the output type only support Tensor and Tensor[], \ but now is {out_type}." 
- return out_type, result.group('name') + return output_type_map[out_type], result.group('name') else: if output_item.strip() in output_type_map: From 868a3203eba4745d43be8dec1adad32994cb80c4 Mon Sep 17 00:00:00 2001 From: hong <43953930+phlrain@users.noreply.github.com> Date: Sun, 3 Apr 2022 14:54:15 +0800 Subject: [PATCH 44/93] Add infer meta (#41054) * add some infer meta * fix bug * fix bugs; * fix bug and add set data type * revert infer shape of lookup table * recover test --- paddle/fluid/operators/meshgrid_op.cc | 33 ++--- .../fluid/operators/optimizers/adagrad_op.cc | 42 ++---- .../fluid/operators/optimizers/rmsprop_op.cc | 88 ++----------- paddle/fluid/operators/optimizers/sgd_op.cc | 48 +------ paddle/fluid/operators/temporal_shift_op.cc | 52 ++------ paddle/phi/infermeta/binary.cc | 26 ++++ paddle/phi/infermeta/binary.h | 5 + paddle/phi/infermeta/multiary.cc | 124 ++++++++++++++++++ paddle/phi/infermeta/multiary.h | 34 +++++ paddle/phi/infermeta/unary.cc | 46 +++++++ paddle/phi/infermeta/unary.h | 7 + 11 files changed, 281 insertions(+), 224 deletions(-) diff --git a/paddle/fluid/operators/meshgrid_op.cc b/paddle/fluid/operators/meshgrid_op.cc index 103169fedb90e..5a6862f380da1 100644 --- a/paddle/fluid/operators/meshgrid_op.cc +++ b/paddle/fluid/operators/meshgrid_op.cc @@ -19,6 +19,10 @@ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/framework/infershape_utils.h" +#include "paddle/phi/core/infermeta_utils.h" +#include "paddle/phi/infermeta/multiary.h" + namespace paddle { namespace operators { @@ -28,30 +32,6 @@ class MeshgridOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - protected: - void InferShape(framework::InferShapeContext* ctx) const override { - PADDLE_ENFORCE_GE( - ctx->Inputs("X").size(), 1UL, - platform::errors::InvalidArgument("Input(X) should not be empty.")); - PADDLE_ENFORCE_GE( - ctx->Outputs("Out").size(), 1UL, - platform::errors::InvalidArgument("Output(Out) should not be empty.")); - - auto inputs_dims = ctx->GetInputsDim("X"); - const size_t inputs_num = inputs_dims.size(); - auto outs_names = ctx->Outputs("Out"); - const size_t outputs_num = outs_names.size(); - - auto out_shape = std::vector(inputs_num); - - for (size_t i = 0; i < inputs_num; i++) { - out_shape[i] = inputs_dims[i][0]; - } - auto out_dims = phi::make_ddim(std::vector(out_shape)); - std::vector outs_dims(outputs_num, out_dims); - ctx->SetOutputsDim("Out", outs_dims); - } - protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { @@ -142,7 +122,10 @@ class MeshgridGradOpMaker : public framework::SingleGradOpMaker { } // namespace paddle namespace ops = paddle::operators; +DECLARE_INFER_SHAPE_FUNCTOR(meshgrid, MeshgridInferShapeFunctor, + PD_INFER_META(phi::MeshgridInferMeta)); REGISTER_OPERATOR(meshgrid, ops::MeshgridOp, ops::MeshgridOpMaker, ops::MeshgridGradOpMaker, - ops::MeshgridGradOpMaker); + ops::MeshgridGradOpMaker, + MeshgridInferShapeFunctor); REGISTER_OPERATOR(meshgrid_grad, ops::MeshgridGradOp); diff --git a/paddle/fluid/operators/optimizers/adagrad_op.cc b/paddle/fluid/operators/optimizers/adagrad_op.cc index 33c4cf94cf25a..91bad1430615f 100644 --- a/paddle/fluid/operators/optimizers/adagrad_op.cc +++ b/paddle/fluid/operators/optimizers/adagrad_op.cc @@ -19,6 +19,10 @@ limitations under the License. 
*/ #include "paddle/fluid/operators/math/selected_rows_functor.h" #include "paddle/phi/kernels/funcs/math_function.h" +#include "paddle/fluid/framework/infershape_utils.h" +#include "paddle/phi/core/infermeta_utils.h" +#include "paddle/phi/infermeta/multiary.h" + namespace paddle { namespace operators { @@ -27,39 +31,6 @@ class AdagradOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext* ctx) const override { - OP_INOUT_CHECK(ctx->HasInput("Param"), "Input", "Param", "Adagrad"); - OP_INOUT_CHECK(ctx->HasInput("Grad"), "Input", "Grad", "Adagrad"); - OP_INOUT_CHECK(ctx->HasInput("Moment"), "Input", "Moment", "Adagrad"); - OP_INOUT_CHECK(ctx->HasInput("LearningRate"), "Input", "LearningRate", - "Adagrad"); - OP_INOUT_CHECK(ctx->HasOutput("ParamOut"), "Output", "ParamOut", "Adagrad"); - OP_INOUT_CHECK(ctx->HasOutput("MomentOut"), "Output", "MomentOut", - "Adagrad"); - - auto lr_dims = ctx->GetInputDim("LearningRate"); - PADDLE_ENFORCE_NE(phi::product(lr_dims), 0, - platform::errors::InvalidArgument( - "Maybe the Input variable LearningRate has not " - "been initialized. You may need to confirm " - "if you put exe.run(startup_program) " - "after optimizer.minimize function.")); - PADDLE_ENFORCE_EQ(phi::product(lr_dims), 1, - platform::errors::InvalidArgument( - "LearningRate should have one element")); - auto param_dims = ctx->GetInputDim("Param"); - PADDLE_ENFORCE_EQ( - param_dims, ctx->GetInputDim("Grad"), - platform::errors::InvalidArgument("Param and Grad input of AdagradOp " - "should have the same dimension.")); - PADDLE_ENFORCE_EQ( - param_dims, ctx->GetInputDim("Moment"), - platform::errors::InvalidArgument("Param and Moment input of AdagradOp " - "should have the same dimension.")); - - ctx->SetOutputDim("ParamOut", param_dims); - ctx->SetOutputDim("MomentOut", param_dims); - } framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { return framework::OpKernelType( @@ -105,4 +76,7 @@ for numerical stability to avoid the division by zero error. } // namespace paddle namespace ops = paddle::operators; -REGISTER_OP_WITHOUT_GRADIENT(adagrad, ops::AdagradOp, ops::AdagradOpMaker); +DECLARE_INFER_SHAPE_FUNCTOR(adagrad, AdagradInferShapeFunctor, + PD_INFER_META(phi::AdagradInferMeta)); +REGISTER_OP_WITHOUT_GRADIENT(adagrad, ops::AdagradOp, ops::AdagradOpMaker, + AdagradInferShapeFunctor); diff --git a/paddle/fluid/operators/optimizers/rmsprop_op.cc b/paddle/fluid/operators/optimizers/rmsprop_op.cc index cd6fdcf34e95f..b3458724482e9 100644 --- a/paddle/fluid/operators/optimizers/rmsprop_op.cc +++ b/paddle/fluid/operators/optimizers/rmsprop_op.cc @@ -14,91 +14,16 @@ limitations under the License. 
*/ #include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/infershape_utils.h" +#include "paddle/phi/core/infermeta_utils.h" +#include "paddle/phi/infermeta/multiary.h" + namespace paddle { namespace operators { class RmspropOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - - void InferShape(framework::InferShapeContext *ctx) const override { - PADDLE_ENFORCE_EQ(ctx->HasInput("Param"), true, - platform::errors::NotFound( - "Input(Param) of RmspropOp should not be null.")); - PADDLE_ENFORCE_EQ( - ctx->HasInput("MeanSquare"), true, - platform::errors::NotFound( - "Input(MeanSquare) of RmspropOp should not be null.")); - PADDLE_ENFORCE_EQ( - ctx->HasInput("LearningRate"), true, - platform::errors::NotFound( - "Input(LearningRate) of RmspropOp should not be null.")); - PADDLE_ENFORCE_EQ(ctx->HasInput("Grad"), true, - platform::errors::NotFound( - "Input(Grad) of RmspropOp should not be null.")); - PADDLE_ENFORCE_EQ(ctx->HasInput("Moment"), true, - platform::errors::NotFound( - "Input(Moment) of RmspropOp should not be null.")); - PADDLE_ENFORCE_EQ(ctx->GetInputsVarType("Param").front(), - framework::proto::VarType::LOD_TENSOR, - platform::errors::InvalidArgument( - "The input var's type in RmspropOp should be " - "LoDTensor, but the received is %s", - ctx->GetInputsVarType("Param").front())); - - PADDLE_ENFORCE_EQ( - ctx->HasOutput("ParamOut"), true, - platform::errors::NotFound( - "Output(param_out) of RmspropOp should not be null.")); - PADDLE_ENFORCE_EQ( - ctx->HasOutput("MomentOut"), true, - platform::errors::NotFound( - "Output(MomentOut) of RmspropOp should not be null.")); - PADDLE_ENFORCE_EQ( - ctx->HasOutput("MeanSquareOut"), true, - platform::errors::NotFound( - "Output(MeanSquareOut) of RmspropOp should not be null.")); - if (ctx->Attrs().Get("centered")) { - PADDLE_ENFORCE_EQ( - ctx->HasOutput("MeanGradOut"), true, - platform::errors::NotFound( - "Output(MeanGradOut) of RmspropOp should not be null.")); - } - - auto param_dim = ctx->GetInputDim("Param"); - PADDLE_ENFORCE_EQ( - param_dim, ctx->GetInputDim("Grad"), - platform::errors::InvalidArgument( - "Param and grad input of RmspropOp should have the same dimension. " - "But received Param's dim [%s] and Grad's dim [%s].", - param_dim, ctx->GetInputDim("Grad"))); - PADDLE_ENFORCE_EQ(param_dim, ctx->GetInputDim("Moment"), - platform::errors::InvalidArgument( - "Param and Momentum input of RmspropOp " - "should have the same dimension. But received " - "Param's dim [%s] and Moment [%s]", - param_dim, ctx->GetInputDim("Moment"))); - PADDLE_ENFORCE_EQ(param_dim, ctx->GetInputDim("MeanSquare"), - platform::errors::InvalidArgument( - "Param and Momentum input of RmspropOp " - "should have the same dimension. But received " - "Param's dim [%s] and MeanSquare [%s]", - param_dim, ctx->GetInputDim("MeanSquare"))); - - auto lr_dim = ctx->GetInputDim("LearningRate"); - PADDLE_ENFORCE_EQ(phi::product(lr_dim), 1, - platform::errors::InvalidArgument( - "Learning Rate of RmspropOp should be a scalar. 
But " - "received LearningRate's dim [%s]", - phi::product(lr_dim))); - - ctx->SetOutputDim("ParamOut", param_dim); - ctx->SetOutputDim("MomentOut", param_dim); - ctx->SetOutputDim("MeanSquareOut", param_dim); - if (ctx->Attrs().Get("centered")) { - ctx->SetOutputDim("MeanGradOut", param_dim); - } - } }; class RmspropOpMaker : public framework::OpProtoAndCheckerMaker { @@ -169,4 +94,7 @@ The original slides that proposed Rmsprop: Slide 29 of } // namespace paddle namespace ops = paddle::operators; -REGISTER_OP_WITHOUT_GRADIENT(rmsprop, ops::RmspropOp, ops::RmspropOpMaker); +DECLARE_INFER_SHAPE_FUNCTOR(rmsprop, RmspropInferShapeFunctor, + PD_INFER_META(phi::RmspropInferMeta)); +REGISTER_OP_WITHOUT_GRADIENT(rmsprop, ops::RmspropOp, ops::RmspropOpMaker, + RmspropInferShapeFunctor); diff --git a/paddle/fluid/operators/optimizers/sgd_op.cc b/paddle/fluid/operators/optimizers/sgd_op.cc index 0e3f895d276af..f51d776d7195c 100644 --- a/paddle/fluid/operators/optimizers/sgd_op.cc +++ b/paddle/fluid/operators/optimizers/sgd_op.cc @@ -19,6 +19,10 @@ limitations under the License. */ #include "paddle/fluid/platform/mkldnn_helper.h" #endif +#include "paddle/fluid/framework/infershape_utils.h" +#include "paddle/phi/core/infermeta_utils.h" +#include "paddle/phi/infermeta/multiary.h" + namespace paddle { namespace operators { @@ -26,46 +30,6 @@ class SGDOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext *ctx) const override { - PADDLE_ENFORCE_EQ(ctx->HasInput("Param"), true, - platform::errors::NotFound( - "Input(Param) of SGDOp should not be null.")); - PADDLE_ENFORCE_EQ( - ctx->HasInput("Grad"), true, - platform::errors::NotFound("Input(Grad) of SGDOp should not be null.")); - PADDLE_ENFORCE_EQ(ctx->HasInput("LearningRate"), true, - platform::errors::NotFound( - "Input(LearningRate) of SGDOp should not be null.")); - PADDLE_ENFORCE_EQ(ctx->HasOutput("ParamOut"), true, - platform::errors::NotFound( - "Output(ParamOut) of SGDOp should not be null.")); - - auto lr_dims = ctx->GetInputDim("LearningRate"); - PADDLE_ENFORCE_NE(phi::product(lr_dims), 0, - platform::errors::NotFound( - "Maybe the Input variable LearningRate has not " - "been initialized. You may need to confirm " - "if you put exe.run(startup_program) " - "after optimizer.minimize function.")); - PADDLE_ENFORCE_EQ(phi::product(lr_dims), 1, - platform::errors::InvalidArgument( - "Learning rate should have 1 element. But received " - "LearningRate dims [%s]", - phi::product(lr_dims))); - auto param_dim = ctx->GetInputDim("Param"); - if (ctx->GetInputsVarType("Grad")[0] == - framework::proto::VarType::LOD_TENSOR) { - PADDLE_ENFORCE_EQ( - param_dim, ctx->GetInputDim("Grad"), - platform::errors::InvalidArgument( - "SGD Operator's input Param and Grad dimensions do not match. " - "The Param %s shape is [%s], but the Grad %s shape is [%s].", - ctx->Inputs("Param")[0], param_dim, ctx->Inputs("Grad")[0], - ctx->GetInputDim("Grad"))); - } - ctx->SetOutputDim("ParamOut", param_dim); - } - protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { @@ -161,8 +125,10 @@ This operator implements one step of the stochastic gradient descent algorithm. 
} // namespace paddle namespace ops = paddle::operators; +DECLARE_INFER_SHAPE_FUNCTOR(sgd, SGDInferShapeFunctor, + PD_INFER_META(phi::SGDInferMeta)); REGISTER_OPERATOR( sgd, ops::SGDOp, ops::SGDOpMaker, paddle::framework::EmptyGradOpMaker, paddle::framework::EmptyGradOpMaker, - ops::SGDOpInferVarType); + ops::SGDOpInferVarType, SGDInferShapeFunctor); diff --git a/paddle/fluid/operators/temporal_shift_op.cc b/paddle/fluid/operators/temporal_shift_op.cc index acf99d09ffb90..3bdb9cb972fc6 100644 --- a/paddle/fluid/operators/temporal_shift_op.cc +++ b/paddle/fluid/operators/temporal_shift_op.cc @@ -15,6 +15,10 @@ #include #include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/infershape_utils.h" +#include "paddle/phi/core/infermeta_utils.h" +#include "paddle/phi/infermeta/unary.h" + namespace paddle { namespace operators { @@ -24,49 +28,6 @@ class TemporalShiftOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - protected: - void InferShape(framework::InferShapeContext* ctx) const override { - OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "SpectralNorm"); - OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "SpectralNorm"); - - auto dim_x = ctx->GetInputDim("X"); - PADDLE_ENFORCE_EQ(dim_x.size(), 4, - platform::errors::InvalidArgument( - "Input(X) rank should be 4 in shape of [N*T, C, H, " - "W], but received X rank(%d)", - dim_x.size())); - - int seg_num = ctx->Attrs().Get("seg_num"); - float shift_ratio = ctx->Attrs().Get("shift_ratio"); - PADDLE_ENFORCE_GT( - seg_num, 0, - platform::errors::InvalidArgument( - "Attr(seg_num) should be greater than 0, but received %d", - seg_num)); - PADDLE_ENFORCE_GT( - shift_ratio, 0., - platform::errors::InvalidArgument( - "Attr(shift_ratio) should be greater than 0, but received %d", - shift_ratio)); - PADDLE_ENFORCE_LT( - shift_ratio, 0.5, - platform::errors::InvalidArgument( - "Attr(shift_ratio) should be less than 0.5, but received %d", - shift_ratio)); - - if (ctx->IsRuntime()) { - PADDLE_ENFORCE_EQ(dim_x[0] % seg_num, 0, - platform::errors::InvalidArgument( - "Input(X) dimension[0] should be divided exactly " - "by Attr(seg_num), but received X dimension[0](%d) " - "mod seg_num(%d) != 0", - dim_x[0], seg_num)); - } - - ctx->SetOutputDim("Out", dim_x); - ctx->ShareLoD("X", "Out"); - } - protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { @@ -186,10 +147,13 @@ class TemporalShiftGradOpMaker : public framework::SingleGradOpMaker { } // namespace paddle namespace ops = paddle::operators; +DECLARE_INFER_SHAPE_FUNCTOR(temporal_shift, TemporalShiftInferShapeFunctor, + PD_INFER_META(phi::TemporalShiftInferMeta)); REGISTER_OPERATOR(temporal_shift, ops::TemporalShiftOp, ops::TemporalShiftOpMaker, ops::TemporalShiftGradOpMaker, - ops::TemporalShiftGradOpMaker); + ops::TemporalShiftGradOpMaker, + TemporalShiftInferShapeFunctor); REGISTER_OPERATOR(temporal_shift_grad, ops::TemporalShiftOpGrad); REGISTER_OP_CPU_KERNEL(temporal_shift, ops::TemporalShiftKernel, ops::TemporalShiftKernel); diff --git a/paddle/phi/infermeta/binary.cc b/paddle/phi/infermeta/binary.cc index 44ae53a00d18e..ab13df081aa28 100644 --- a/paddle/phi/infermeta/binary.cc +++ b/paddle/phi/infermeta/binary.cc @@ -75,6 +75,32 @@ void AllValueCompareInferMeta(const MetaTensor& x, out->set_dtype(DataType::BOOL); } +void EmbeddingInferMeta(const MetaTensor& input, + const MetaTensor& weight, + int64_t padding_idx, + MetaTensor* out) { + auto 
table_dims = weight.dims(); + auto ids_dims = input.dims(); + int ids_rank = ids_dims.size(); + VLOG(5) << "ids rank is " << ids_rank << std::endl; + PADDLE_ENFORCE_EQ( + table_dims.size(), + 2, + phi::errors::InvalidArgument( + "ShapeError: The dimensions of the 'lookup table' must be 2. " + "But received lookup table's dimensions = %d, " + "lookup table's shape = [%s].", + table_dims.size(), + table_dims)); + + auto output_dims = phi::vectorize(ids_dims); + output_dims.push_back(table_dims[1]); + + out->set_dims(phi::make_ddim(output_dims)); + out->set_dtype(weight.dtype()); + out->share_lod(input); +} + void KLDivInferMeta(const MetaTensor& x, const MetaTensor& label, const std::string& reduction, diff --git a/paddle/phi/infermeta/binary.h b/paddle/phi/infermeta/binary.h index 751422a4def48..3fcbf69c35e25 100644 --- a/paddle/phi/infermeta/binary.h +++ b/paddle/phi/infermeta/binary.h @@ -37,6 +37,11 @@ void AllValueCompareInferMeta(const MetaTensor& x, MetaTensor* out, MetaConfig config = MetaConfig()); +void EmbeddingInferMeta(const MetaTensor& input, + const MetaTensor& weight, + int64_t padding_idx, + MetaTensor* out); + void KLDivInferMeta(const MetaTensor& x, const MetaTensor& label, const std::string& reduction, diff --git a/paddle/phi/infermeta/multiary.cc b/paddle/phi/infermeta/multiary.cc index 8e4f0b1fbb5c9..4fbd264f10f9f 100644 --- a/paddle/phi/infermeta/multiary.cc +++ b/paddle/phi/infermeta/multiary.cc @@ -66,6 +66,32 @@ void AdadeltaInferMeta(const MetaTensor& param, avg_squared_update_out->set_dtype(avg_squared_update.dtype()); } +void AdagradInferMeta(const MetaTensor& param, + const MetaTensor& grad, + const MetaTensor& moment, + const MetaTensor& learning_rate, + float epsilon, + MetaTensor* param_out, + MetaTensor* moment_out) { + auto lr_dims = learning_rate.dims(); + PADDLE_ENFORCE_EQ( + phi::product(lr_dims), + 1, + phi::errors::InvalidArgument("LearningRate should have one element")); + auto param_dims = param.dims(); + + PADDLE_ENFORCE_EQ( + param_dims, + moment.dims(), + phi::errors::InvalidArgument("Param and Moment input of AdagradOp " + "should have the same dimension.")); + + param_out->set_dims(param_dims); + param_out->set_dtype(param.dtype()); + moment_out->set_dims(param_dims); + moment_out->set_dtype(moment.dtype()); +} + void AdamInferMeta(const MetaTensor& param, const MetaTensor& grad, const MetaTensor& learning_rate, @@ -1390,6 +1416,22 @@ void InterpolateInferMeta( } } +void MeshgridInferMeta(const std::vector& inputs, + std::vector outputs) { + const size_t inputs_num = inputs.size(); + + auto out_shape = std::vector(inputs_num); + + for (size_t i = 0; i < inputs.size(); i++) { + out_shape[i] = inputs[i]->dims()[0]; + } + auto out_dims = phi::make_ddim(std::vector(out_shape)); + for (size_t i = 0; i < outputs.size(); ++i) { + outputs[i]->set_dims(out_dims); + outputs[i]->set_dtype(inputs[0]->dtype()); + } +} + void MultiDotInferMeta(const std::vector& x, MetaTensor* out) { auto inputs_dims = GetMetaTensorsDim(x); @@ -1582,6 +1624,65 @@ void PsroiPoolInferMeta(const MetaTensor& x, out->set_dtype(x.dtype()); } +void RmspropInferMeta(const MetaTensor& param, + const MetaTensor& mean_square, + const MetaTensor& grad, + const MetaTensor& moment, + const MetaTensor& learning_rate, + paddle::optional mean_grad, + float epsilon, + float decay, + float momentum, + bool centered, + MetaTensor* param_out, + MetaTensor* moment_out, + MetaTensor* mean_square_out, + MetaTensor* mean_grad_out) { + if (centered) { + PADDLE_ENFORCE_NOT_NULL( + mean_grad_out, + 
phi::errors::InvalidArgument( + "Output(MeanGradOut) of RmspropOp should not be null.")); + } + + auto param_dim = param.dims(); + PADDLE_ENFORCE_EQ(param_dim, + moment.dims(), + phi::errors::InvalidArgument( + "Param and Momentum input of RmspropOp " + "should have the same dimension. But received " + "Param's dim [%s] and Moment [%s]", + param_dim, + moment.dims())); + PADDLE_ENFORCE_EQ(param_dim, + mean_square.dims(), + phi::errors::InvalidArgument( + "Param and Momentum input of RmspropOp " + "should have the same dimension. But received " + "Param's dim [%s] and MeanSquare [%s]", + param_dim, + mean_square.dims())); + + auto lr_dim = learning_rate.dims(); + PADDLE_ENFORCE_EQ(phi::product(lr_dim), + 1, + phi::errors::InvalidArgument( + "Learning Rate of RmspropOp should be a scalar. But " + "received LearningRate's dim [%s]", + phi::product(lr_dim))); + + param_out->set_dims(param_dim); + param_out->set_dtype(param.dtype()); + moment_out->set_dims(param_dim); + moment_out->set_dtype(moment.dtype()); + mean_square_out->set_dims(param_dim); + mean_square_out->set_dtype(mean_square.dtype()); + if (centered) { + mean_grad_out->set_dims(param_dim); + mean_grad_out->set_dtype(mean_grad.get_ptr()->dtype()); + } +} + void RnnInferMeta(const MetaTensor& x, const std::vector& pre_state, const std::vector& weight_list, @@ -1667,6 +1768,29 @@ void RnnInferMeta(const MetaTensor& x, } } +void SGDInferMeta(const MetaTensor& param, + const MetaTensor& learning_rate, + const MetaTensor& grad, + paddle::optional master_param, + bool multi_precision, + MetaTensor* param_out, + MetaTensor* master_param_out) { + PADDLE_ENFORCE_NOT_NULL(param_out, + phi::errors::InvalidArgument( + "Output(ParamOut) of SGDOp should not be null.")); + + auto lr_dims = learning_rate.dims(); + PADDLE_ENFORCE_EQ(phi::product(lr_dims), + 1, + phi::errors::InvalidArgument( + "Learning rate should have 1 element. 
But received " + "LearningRate dims [%s]", + phi::product(lr_dims))); + + param_out->set_dims(param.dims()); + param_out->set_dtype(param.dtype()); +} + void StackInferMeta(const std::vector& x, int axis, MetaTensor* out) { diff --git a/paddle/phi/infermeta/multiary.h b/paddle/phi/infermeta/multiary.h index 72c64e8500ad2..64a11ed0b2621 100644 --- a/paddle/phi/infermeta/multiary.h +++ b/paddle/phi/infermeta/multiary.h @@ -47,6 +47,14 @@ void AdadeltaInferMeta(const MetaTensor& param, MetaTensor* avg_squared_grad_out, MetaTensor* avg_squared_update_out); +void AdagradInferMeta(const MetaTensor& param, + const MetaTensor& grad, + const MetaTensor& moment, + const MetaTensor& learning_rate, + float epsilon, + MetaTensor* param_out, + MetaTensor* moment_out); + void AdamaxInferMeta(const MetaTensor& param, const MetaTensor& grad, const MetaTensor& learning_rate, @@ -215,6 +223,9 @@ void InterpolateInferMeta( MetaTensor* output, MetaConfig config = MetaConfig()); +void MeshgridInferMeta(const std::vector& inputs, + std::vector outputs); + void MultiDotInferMeta(const std::vector& x, MetaTensor* out); void MultiplexInferMeta(const std::vector& ins, @@ -230,6 +241,21 @@ void PsroiPoolInferMeta(const MetaTensor& x, float spatial_scale, MetaTensor* out); +void RmspropInferMeta(const MetaTensor& param, + const MetaTensor& mean_square, + const MetaTensor& grad, + const MetaTensor& moment, + const MetaTensor& learning_rate, + paddle::optional mean_grad, + float epsilon, + float decay, + float momentum, + bool centered, + MetaTensor* param_out, + MetaTensor* moment_out, + MetaTensor* mean_square_out, + MetaTensor* mean_grad_out); + void RnnInferMeta(const MetaTensor& x, const std::vector& pre_state, const std::vector& weight_list, @@ -247,6 +273,14 @@ void RnnInferMeta(const MetaTensor& x, std::vector state, MetaTensor* reserve); +void SGDInferMeta(const MetaTensor& param, + const MetaTensor& learning_rate, + const MetaTensor& grad, + paddle::optional master_param, + bool multi_precision, + MetaTensor* param_out, + MetaTensor* master_param_out); + void StackInferMeta(const std::vector& x, int axis, MetaTensor* out); diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index 6bf7a36b06534..36c192cbf2748 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -2102,6 +2102,52 @@ void SumRawInferMeta(const MetaTensor& x, out->set_layout(x.layout()); } +void TemporalShiftInferMeta(const MetaTensor& x, + int seg_num, + float shift_ratio, + const std::string& data_format, + MetaTensor* out, + MetaConfig config) { + auto dim_x = x.dims(); + PADDLE_ENFORCE_EQ(dim_x.size(), + 4, + phi::errors::InvalidArgument( + "Input(X) rank should be 4 in shape of [N*T, C, H, " + "W], but received X rank(%d)", + dim_x.size())); + + PADDLE_ENFORCE_GT( + seg_num, + 0, + phi::errors::InvalidArgument( + "Attr(seg_num) should be greater than 0, but received %d", seg_num)); + PADDLE_ENFORCE_GT( + shift_ratio, + 0., + phi::errors::InvalidArgument( + "Attr(shift_ratio) should be greater than 0, but received %d", + shift_ratio)); + PADDLE_ENFORCE_LT( + shift_ratio, + 0.5, + phi::errors::InvalidArgument( + "Attr(shift_ratio) should be less than 0.5, but received %d", + shift_ratio)); + + if (config.is_runtime) { + PADDLE_ENFORCE_EQ(dim_x[0] % seg_num, + 0, + phi::errors::InvalidArgument( + "Input(X) dimension[0] should be divided exactly " + "by Attr(seg_num), but received X dimension[0](%d) " + "mod seg_num(%d) != 0", + dim_x[0], + seg_num)); + } + + out->share_meta(x); +} + void 
TileInferMeta(const MetaTensor& x, const IntArray& repeat_times, MetaTensor* out, diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h index 54f70d8d55405..bda9c83fce1f2 100644 --- a/paddle/phi/infermeta/unary.h +++ b/paddle/phi/infermeta/unary.h @@ -315,6 +315,13 @@ void SumRawInferMeta(const MetaTensor& x, DataType dtype, MetaTensor* out); +void TemporalShiftInferMeta(const MetaTensor& x, + int seg_num, + float shift_ratio, + const std::string& data_format, + MetaTensor* out, + MetaConfig config = MetaConfig()); + void TileInferMeta(const MetaTensor& x, const IntArray& repeat_times, MetaTensor* out, From 4da467370f3be2e6336d51760fba9debb0304318 Mon Sep 17 00:00:00 2001 From: wanghuancoder Date: Sun, 3 Apr 2022 15:39:41 +0800 Subject: [PATCH 45/93] [Eager] do not mutabledata when init (#41331) * do not mutabledata when init, test=develop * refine, test=develop * fix copy_, test=develop * refine, test=develop --- paddle/fluid/pybind/eager.cc | 7 ++--- paddle/fluid/pybind/eager_method.cc | 11 ++++++-- .../test_cuda_max_memory_allocated.py | 28 +++++++++++++++---- .../unittests/test_cuda_memory_reserved.py | 28 +++++++++++++++---- 4 files changed, 54 insertions(+), 20 deletions(-) diff --git a/paddle/fluid/pybind/eager.cc b/paddle/fluid/pybind/eager.cc index 5278f371dd4e7..657c79e7bd3aa 100644 --- a/paddle/fluid/pybind/eager.cc +++ b/paddle/fluid/pybind/eager.cc @@ -77,9 +77,6 @@ void EmptyTensorInitializer(TensorObject* self, const std::string& name, phi::make_intrusive(place), phi::DenseTensorMeta(paddle::framework::TransToPhiDataType(dtype), ddims)); - if (phi::product(ddims) > 0) { - dense_tensor->mutable_data(place); - } self->tensor.set_impl(dense_tensor); } @@ -92,6 +89,7 @@ void EmptyTensorInitializer(TensorObject* self, const std::string& name, } void InitTensorWithNumpyValue(TensorObject* self, const py::object& array, + const paddle::platform::Place& place, bool zero_copy = false) { PADDLE_ENFORCE_EQ( self->tensor.defined(), true, @@ -102,7 +100,6 @@ void InitTensorWithNumpyValue(TensorObject* self, const py::object& array, "eager tensor before init it with NumPy.")); phi::DenseTensor* impl_ptr = static_cast(self->tensor.impl().get()); - paddle::platform::Place place = impl_ptr->place(); if (platform::is_cpu_place(place)) { SetTensorFromPyArray(impl_ptr, array, place, zero_copy); } else if (platform::is_xpu_place(place)) { @@ -289,7 +286,7 @@ void AutoInitTensorByPyArray(TensorObject* py_tensor_ptr, EmptyTensorInitializer(py_tensor_ptr, act_name, place, persistable, stop_gradient); - InitTensorWithNumpyValue(py_tensor_ptr, numpy_value, zero_copy); + InitTensorWithNumpyValue(py_tensor_ptr, numpy_value, place, zero_copy); } // initialize Tensor by Tensor or framework::Tensor (mix args and diff --git a/paddle/fluid/pybind/eager_method.cc b/paddle/fluid/pybind/eager_method.cc index d9face124bd82..814243e0a5774 100644 --- a/paddle/fluid/pybind/eager_method.cc +++ b/paddle/fluid/pybind/eager_method.cc @@ -330,17 +330,22 @@ static PyObject* tensor_method_copy_(TensorObject* self, PyObject* args, bool blocking = CastPyArg2AttrBoolean(PyTuple_GET_ITEM(args, 1), 1); VLOG(6) << "Start Copy Tensor " << src_tensor.name() << " to " << self->tensor.name(); - if (!self->tensor.defined()) { + if (!self->tensor.initialized()) { egr::EagerUtils::autograd_meta(&(self->tensor)) ->SetStopGradient( egr::EagerUtils::autograd_meta(&(src_tensor))->StopGradient()); egr::EagerUtils::autograd_meta(&(self->tensor)) ->SetPersistable( 
egr::EagerUtils::autograd_meta(&(src_tensor))->Persistable()); + if (src_tensor.initialized()) { + self->tensor.copy_(src_tensor, src_tensor.inner_place(), blocking); + } + } else { + if (src_tensor.initialized()) { + self->tensor.copy_(src_tensor, self->tensor.inner_place(), blocking); + } } - self->tensor.copy_(src_tensor, self->tensor.inner_place(), blocking); - VLOG(6) << "Finish Copy Tensor " << src_tensor.name() << " to " << self->tensor.name(); Py_INCREF(Py_None); diff --git a/python/paddle/fluid/tests/unittests/test_cuda_max_memory_allocated.py b/python/paddle/fluid/tests/unittests/test_cuda_max_memory_allocated.py index 51c9ba182ab72..ae8bdeed1ef7a 100644 --- a/python/paddle/fluid/tests/unittests/test_cuda_max_memory_allocated.py +++ b/python/paddle/fluid/tests/unittests/test_cuda_max_memory_allocated.py @@ -16,10 +16,11 @@ import unittest from paddle.fluid import core from paddle.device.cuda import device_count, memory_allocated, max_memory_allocated +from paddle.fluid.framework import _test_eager_guard, _in_legacy_dygraph class TestMaxMemoryAllocated(unittest.TestCase): - def test_max_memory_allocated(self, device=None): + def func_test_max_memory_allocated(self, device=None): if core.is_compiled_with_cuda(): alloc_time = 100 max_alloc_size = 10000 @@ -35,16 +36,26 @@ def test_max_memory_allocated(self, device=None): self.assertEqual(peak_memory_allocated_size, max_memory_allocated(device)) - def test_max_memory_allocated_for_all_places(self): + def test_max_memory_allocated(self): + with _test_eager_guard(): + self.func_test_max_memory_allocated() + self.func_test_max_memory_allocated() + + def func_test_max_memory_allocated_for_all_places(self): if core.is_compiled_with_cuda(): gpu_num = device_count() for i in range(gpu_num): paddle.device.set_device("gpu:" + str(i)) - self.test_max_memory_allocated(core.CUDAPlace(i)) - self.test_max_memory_allocated(i) - self.test_max_memory_allocated("gpu:" + str(i)) + self.func_test_max_memory_allocated(core.CUDAPlace(i)) + self.func_test_max_memory_allocated(i) + self.func_test_max_memory_allocated("gpu:" + str(i)) - def test_max_memory_allocated_exception(self): + def test_max_memory_allocated_for_all_places(self): + with _test_eager_guard(): + self.func_test_max_memory_allocated_for_all_places() + self.func_test_max_memory_allocated_for_all_places() + + def func_test_max_memory_allocated_exception(self): if core.is_compiled_with_cuda(): wrong_device = [ core.CPUPlace(), device_count() + 1, -2, 0.5, "gpu1", "npu" @@ -56,6 +67,11 @@ def test_max_memory_allocated_exception(self): with self.assertRaises(BaseException): max_memory_allocated() + def test_max_memory_allocated_exception(self): + with _test_eager_guard(): + self.func_test_max_memory_allocated_exception() + self.func_test_max_memory_allocated_exception() + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_cuda_memory_reserved.py b/python/paddle/fluid/tests/unittests/test_cuda_memory_reserved.py index 149760de8b231..ca551ab4a3f28 100644 --- a/python/paddle/fluid/tests/unittests/test_cuda_memory_reserved.py +++ b/python/paddle/fluid/tests/unittests/test_cuda_memory_reserved.py @@ -17,26 +17,37 @@ import numpy as np from paddle.fluid import core from paddle.device.cuda import device_count, memory_reserved +from paddle.fluid.framework import _test_eager_guard, _in_legacy_dygraph class TestMemoryreserved(unittest.TestCase): - def test_memory_reserved(self, device=None): + def func_test_memory_reserved(self, device=None): if 
core.is_compiled_with_cuda(): tensor = paddle.zeros(shape=[256]) alloc_size = 4 * 256 # 256 float32 data, with 4 bytes for each one memory_reserved_size = memory_reserved(device) self.assertEqual(memory_reserved_size, alloc_size) - def test_memory_reserved_for_all_places(self): + def test_memory_reserved(self): + with _test_eager_guard(): + self.func_test_memory_reserved() + self.func_test_memory_reserved() + + def func_test_memory_reserved_for_all_places(self): if core.is_compiled_with_cuda(): gpu_num = device_count() for i in range(gpu_num): paddle.device.set_device("gpu:" + str(i)) - self.test_memory_reserved(core.CUDAPlace(i)) - self.test_memory_reserved(i) - self.test_memory_reserved("gpu:" + str(i)) + self.func_test_memory_reserved(core.CUDAPlace(i)) + self.func_test_memory_reserved(i) + self.func_test_memory_reserved("gpu:" + str(i)) - def test_memory_reserved_exception(self): + def test_memory_reserved_for_all_places(self): + with _test_eager_guard(): + self.func_test_memory_reserved_for_all_places() + self.func_test_memory_reserved_for_all_places() + + def func_test_memory_reserved_exception(self): if core.is_compiled_with_cuda(): wrong_device = [ core.CPUPlace(), device_count() + 1, -2, 0.5, "gpu1", "npu" @@ -48,6 +59,11 @@ def test_memory_reserved_exception(self): with self.assertRaises(BaseException): memory_reserved() + def test_memory_reserved_exception(self): + with _test_eager_guard(): + self.func_test_memory_reserved_exception() + self.func_test_memory_reserved_exception() + if __name__ == "__main__": unittest.main() From 3f57ef7a1fedd598d9d171261df66c50b0fa5222 Mon Sep 17 00:00:00 2001 From: chentianyu03 Date: Sun, 3 Apr 2022 17:22:56 +0800 Subject: [PATCH 46/93] [Phi]Concat grad (#41112) * add concat_grad kernel * fix error * remove comment code * fix outs nullptr error * change to phi header * add concat_grad declare for standalone_executor_test --- .../new_executor/standalone_executor_test.cc | 3 +- paddle/fluid/operators/concat_op.cc | 15 ---- paddle/fluid/operators/concat_op.cu.cc | 36 ---------- paddle/fluid/operators/concat_op.h | 56 --------------- paddle/phi/kernels/concat_grad_kernel.h | 30 ++++++++ paddle/phi/kernels/cpu/concat_grad_kernel.cc | 35 ++++++++++ paddle/phi/kernels/gpu/concat_grad_kernel.cu | 37 ++++++++++ .../kernels/impl/concat_grad_kernel_impl.h | 69 +++++++++++++++++++ paddle/phi/ops/compat/concat_sig.cc | 14 ++++ 9 files changed, 187 insertions(+), 108 deletions(-) delete mode 100644 paddle/fluid/operators/concat_op.cu.cc create mode 100644 paddle/phi/kernels/concat_grad_kernel.h create mode 100644 paddle/phi/kernels/cpu/concat_grad_kernel.cc create mode 100644 paddle/phi/kernels/gpu/concat_grad_kernel.cu create mode 100644 paddle/phi/kernels/impl/concat_grad_kernel_impl.h diff --git a/paddle/fluid/framework/new_executor/standalone_executor_test.cc b/paddle/fluid/framework/new_executor/standalone_executor_test.cc index b5670565e2a64..fbcbb2ca23bcb 100644 --- a/paddle/fluid/framework/new_executor/standalone_executor_test.cc +++ b/paddle/fluid/framework/new_executor/standalone_executor_test.cc @@ -46,7 +46,7 @@ USE_OP_ITSELF(elementwise_add_grad); USE_OP_ITSELF(matmul_grad); USE_OP_ITSELF(square); USE_OP_ITSELF(transpose2_grad); -USE_OP(concat_grad); +USE_OP_ITSELF(concat_grad); USE_OP_ITSELF(elementwise_mul_grad); USE_OP_ITSELF(sigmoid_grad); USE_OP_ITSELF(tanh_grad); @@ -67,6 +67,7 @@ PD_DECLARE_KERNEL(transpose, GPU, ALL_LAYOUT); PD_DECLARE_KERNEL(reshape, GPU, ALL_LAYOUT); PD_DECLARE_KERNEL(split, GPU, ALL_LAYOUT); PD_DECLARE_KERNEL(concat, 
GPU, ALL_LAYOUT); +PD_DECLARE_KERNEL(concat_grad, GPU, ALL_LAYOUT); PD_DECLARE_KERNEL(matmul, GPU, ALL_LAYOUT); PD_DECLARE_KERNEL(add_raw, GPU, ALL_LAYOUT); PD_DECLARE_KERNEL(add, GPU, ALL_LAYOUT); diff --git a/paddle/fluid/operators/concat_op.cc b/paddle/fluid/operators/concat_op.cc index 059fafa3e7f4d..a467f2dbee7c9 100644 --- a/paddle/fluid/operators/concat_op.cc +++ b/paddle/fluid/operators/concat_op.cc @@ -216,18 +216,3 @@ REGISTER_OPERATOR(concat_grad, ops::ConcatOpGrad, ops::ConcatDoubleGradOpMaker, ops::ConcatDoubleGradOpMaker, ops::ConcatOpGradNoNeedBufferVarInferer); - -REGISTER_OP_CPU_KERNEL( - concat_grad, - ops::ConcatGradKernel, - ops::ConcatGradKernel, - ops::ConcatGradKernel, - ops::ConcatGradKernel, - ops::ConcatGradKernel, - ops::ConcatGradKernel, - ops::ConcatGradKernel, - ops::ConcatGradKernel>, - ops::ConcatGradKernel>); diff --git a/paddle/fluid/operators/concat_op.cu.cc b/paddle/fluid/operators/concat_op.cu.cc deleted file mode 100644 index f7b64f16e2d8b..0000000000000 --- a/paddle/fluid/operators/concat_op.cu.cc +++ /dev/null @@ -1,36 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/operators/concat_op.h" -#include "paddle/fluid/platform/bfloat16.h" -#include "paddle/fluid/platform/complex.h" -#include "paddle/fluid/platform/float16.h" - -namespace ops = paddle::operators; -namespace plat = paddle::platform; - -REGISTER_OP_CUDA_KERNEL( - concat_grad, - ops::ConcatGradKernel, - ops::ConcatGradKernel, - ops::ConcatGradKernel, - ops::ConcatGradKernel, - ops::ConcatGradKernel, - ops::ConcatGradKernel, - ops::ConcatGradKernel, - ops::ConcatGradKernel, - ops::ConcatGradKernel>, - ops::ConcatGradKernel>); diff --git a/paddle/fluid/operators/concat_op.h b/paddle/fluid/operators/concat_op.h index ec43e2ad374db..50aca54c12dec 100644 --- a/paddle/fluid/operators/concat_op.h +++ b/paddle/fluid/operators/concat_op.h @@ -39,62 +39,6 @@ static inline int64_t ComputeAxis(int64_t axis, int64_t rank) { } return axis > 0 ? 
axis : 0; } -template -class ConcatGradKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const { - auto* out_grad = - ctx.Input(framework::GradVarName("Out")); - auto ins = ctx.MultiInput("X"); - auto out_var_names = ctx.OutputNames(framework::GradVarName("X")); - auto outs = - ctx.MultiOutput(framework::GradVarName("X")); - - { - auto dx = outs; - auto x = ins; - for (size_t i = 0; i < dx.size(); ++i) { - if (dx[i] != nullptr) { - dx[i]->set_lod(x[i]->lod()); - } - } - } - PADDLE_ENFORCE_NOT_NULL(ins[0], - platform::errors::NotFound( - "The first input tensor is not initalized.")); - - auto axis = ctx.Attr("axis"); - if (ctx.HasInput("AxisTensor")) { - auto* axis_tensor = ctx.Input("AxisTensor"); - axis = GetDataFromTensor(axis_tensor)[0]; - } - axis = ComputeAxis(static_cast(axis), - static_cast(ins[0]->dims().size())); - // get output tensor that the name is not kEmptyVarName - std::vector outputs; - for (size_t j = 0; j < outs.size(); ++j) { - if (out_var_names[j] != framework::kEmptyVarName && - outs[j]->numel() != 0UL) { - outs[j]->mutable_data(ctx.GetPlace()); - outputs.push_back(outs[j]); - } else { - outputs.push_back(nullptr); - } - } - auto& dev_ctx = ctx.template device_context(); - - // Sometimes direct copies will be faster, this maybe need deeply analysis. - if (axis == 0 && outs.size() < 10) { - std::vector ref_shape; - ref_shape.insert(ref_shape.begin(), ins.begin(), ins.end()); - StridedMemcpyWithAxis0(dev_ctx, *out_grad, ref_shape, &outputs); - } else { - math::SplitFunctor split_functor; - split_functor(dev_ctx, *out_grad, ctx.MultiInput("X"), - static_cast(axis), &outputs); - } - } -}; } // namespace operators } // namespace paddle diff --git a/paddle/phi/kernels/concat_grad_kernel.h b/paddle/phi/kernels/concat_grad_kernel.h new file mode 100644 index 0000000000000..e407d73bb49ee --- /dev/null +++ b/paddle/phi/kernels/concat_grad_kernel.h @@ -0,0 +1,30 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/common/scalar.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/infermeta/multiary.h" +#include "paddle/phi/kernels/empty_kernel.h" +namespace phi { + +template +void ConcatGradKernel(const Context& dev_ctx, + const std::vector& x, + const DenseTensor& out_grad, + const Scalar& axis_scalar, + std::vector x_grad); + +} // namespace phi diff --git a/paddle/phi/kernels/cpu/concat_grad_kernel.cc b/paddle/phi/kernels/cpu/concat_grad_kernel.cc new file mode 100644 index 0000000000000..56ed95769fef4 --- /dev/null +++ b/paddle/phi/kernels/cpu/concat_grad_kernel.cc @@ -0,0 +1,35 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/concat_grad_kernel.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/common/bfloat16.h" +#include "paddle/phi/common/complex.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/concat_grad_kernel_impl.h" + +PD_REGISTER_KERNEL(concat_grad, + CPU, + ALL_LAYOUT, + phi::ConcatGradKernel, + double, + float, + bool, + int64_t, + int, + uint8_t, + phi::dtype::float16, + phi::dtype::complex, + phi::dtype::complex) {} diff --git a/paddle/phi/kernels/gpu/concat_grad_kernel.cu b/paddle/phi/kernels/gpu/concat_grad_kernel.cu new file mode 100644 index 0000000000000..2445978daca46 --- /dev/null +++ b/paddle/phi/kernels/gpu/concat_grad_kernel.cu @@ -0,0 +1,37 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/concat_grad_kernel.h" + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/common/bfloat16.h" +#include "paddle/phi/common/complex.h" +#include "paddle/phi/common/float16.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/concat_grad_kernel_impl.h" + +PD_REGISTER_KERNEL(concat_grad, + GPU, + ALL_LAYOUT, + phi::ConcatGradKernel, + float, + double, + bool, + int64_t, + int, + uint8_t, + phi::dtype::float16, + phi::dtype::bfloat16, + phi::dtype::complex, + phi::dtype::complex) {} diff --git a/paddle/phi/kernels/impl/concat_grad_kernel_impl.h b/paddle/phi/kernels/impl/concat_grad_kernel_impl.h new file mode 100644 index 0000000000000..e89920340ff18 --- /dev/null +++ b/paddle/phi/kernels/impl/concat_grad_kernel_impl.h @@ -0,0 +1,69 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
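// Annotation: the new header and the CPU/GPU registration files above, plus
// the impl header that follows, replace the fluid ConcatGradKernel deleted
// from concat_op.h. The computation moves into phi::ConcatGradKernel, and
// per-dtype registration moves from REGISTER_OP_CPU_KERNEL /
// REGISTER_OP_CUDA_KERNEL to PD_REGISTER_KERNEL. A minimal sketch of the
// registration half of the pattern (hypothetical `my_grad` kernel, shown
// only for illustration):
//
//   template <typename T, typename Context>
//   void MyGradKernel(const Context& dev_ctx, const DenseTensor& out_grad,
//                     DenseTensor* x_grad);
//
//   PD_REGISTER_KERNEL(my_grad, CPU, ALL_LAYOUT, phi::MyGradKernel,
//                      float, double) {}
//
// The compat signature added in concat_sig.cc at the end of this patch maps
// the legacy op's inputs, attrs, and outputs (including GradVarName("Out"))
// onto the new kernel's arguments, so existing fluid programs keep resolving
// to the phi kernel.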
+#pragma once + +#include "paddle/phi/kernels/concat_grad_kernel.h" + +#include "paddle/fluid/operators/strided_memcpy.h" +#include "paddle/phi/kernels/funcs/concat_and_split_functor.h" +#include "paddle/phi/kernels/funcs/concat_funcs.h" + +namespace phi { + +template +void ConcatGradKernel(const Context& dev_ctx, + const std::vector& x, + const DenseTensor& out_grad, + const Scalar& axis_scalar, + std::vector x_grad) { + auto outs = x_grad; + { + auto dx = x_grad; + for (size_t i = 0; i < dx.size(); ++i) { + if (dx[i] != nullptr) { + dx[i]->set_lod(x[i]->lod()); + } + } + } + PADDLE_ENFORCE_NOT_NULL( + x[0], phi::errors::NotFound("The first input tensor is not initalized.")); + + auto axis = axis_scalar.to(); + axis = funcs::ComputeAxis(static_cast(axis), + static_cast(x[0]->dims().size())); + // get output tensor that the name is not kEmptyVarName + std::vector outputs; + for (size_t j = 0; j < outs.size(); ++j) { + if (outs[j] && outs[j]->numel() != 0UL) { + dev_ctx.template Alloc(outs[j]); + + outputs.push_back(outs[j]); + } else { + outputs.push_back(nullptr); + } + } + + // Sometimes direct copies will be faster, this maybe need deeply analysis. + if (axis == 0 && outs.size() < 10) { + std::vector ref_shape; + ref_shape.insert(ref_shape.begin(), x.begin(), x.end()); + paddle::operators::StridedMemcpyWithAxis0( + dev_ctx, out_grad, ref_shape, &outputs); + } else { + phi::funcs::SplitFunctor split_functor; + split_functor(dev_ctx, out_grad, x, static_cast(axis), &outputs); + } +} + +} // namespace phi diff --git a/paddle/phi/ops/compat/concat_sig.cc b/paddle/phi/ops/compat/concat_sig.cc index 21e653ccfe90f..d443f521c6146 100644 --- a/paddle/phi/ops/compat/concat_sig.cc +++ b/paddle/phi/ops/compat/concat_sig.cc @@ -23,6 +23,20 @@ KernelSignature ConcatOpArgumentMapping(const ArgumentMappingContext& ctx) { return KernelSignature("concat", {"X"}, {"axis"}, {"Out"}); } +KernelSignature ConcatGradOpArgumentMapping(const ArgumentMappingContext& ctx) { + if (ctx.HasInput("AxisTensor")) { + return KernelSignature("concat_grad", + {"X", {GradVarName("Out")}}, + {"AxisTensor"}, + {{GradVarName("X")}}); + } + return KernelSignature("concat_grad", + {"X", {GradVarName("Out")}}, + {"axis"}, + {{GradVarName("X")}}); +} + } // namespace phi PD_REGISTER_ARG_MAPPING_FN(concat, phi::ConcatOpArgumentMapping); +PD_REGISTER_ARG_MAPPING_FN(concat_grad, phi::ConcatGradOpArgumentMapping); From ea4b56f2d0eb5cc146b83abb574f2796de03c0d4 Mon Sep 17 00:00:00 2001 From: 0x45f <23097963+0x45f@users.noreply.github.com> Date: Sun, 3 Apr 2022 21:15:04 +0800 Subject: [PATCH 47/93] Switch dy2st UT to eager mode by cmake (#41317) * Switch dy2st UT to eager mode by cmake * Rename ENVS * Remove invalid UT * Remove error UT * Remove test_bert --- .../unittests/dygraph_to_static/CMakeLists.txt | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/CMakeLists.txt b/python/paddle/fluid/tests/unittests/dygraph_to_static/CMakeLists.txt index 7ee5e83e76d6e..eeb377ff3b4a2 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/CMakeLists.txt @@ -1,12 +1,17 @@ file(GLOB TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_*.py") string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}") set(GC_ENVS FLAGS_eager_delete_tensor_gb=0.0) +set(DY2ST_EAGER_TEST_ENVS ${GC_ENVS} FLAGS_enable_eager_mode=1) +set(TEST_EAGER_OPS test_bmn test_break_continue test_ifelse test_loop 
test_mnist_amp + test_mnist_pure_fp16 test_mobile_net test_program_translator test_ptb_lm test_reinforcement_learning + test_resnet test_resnet_amp test_resnet_pure_fp16 test_se_resnet test_sentiment test_seq2seq + test_tsm test_word2vec test_yolov3) list(REMOVE_ITEM TEST_OPS test_lac) # NOTE(Aurelius84): In case of Windows CI, if open ON_INFER, RWLOCK of Scope will # be removed and will cause some random failed in multi-thread. if(NOT ON_INFER) - py_test_modules(test_lac MODULES test_lac) + py_test_modules(test_lac MODULES test_lac ENVS FLAGS_enable_eager_mode=1) set_tests_properties(test_lac PROPERTIES TIMEOUT 120) endif() @@ -15,7 +20,12 @@ if(WIN32 AND NOT WITH_GPU) endif() foreach(TEST_OP ${TEST_OPS}) - py_test_modules(${TEST_OP} MODULES ${TEST_OP} ENVS ${GC_ENVS}) + list(FIND TEST_EAGER_OPS ${TEST_OP} WAS_FOUND) + if (NOT WAS_FOUND EQUAL -1) + py_test_modules(${TEST_OP} MODULES ${TEST_OP} ENVS ${DY2ST_EAGER_TEST_ENVS}) + else() + py_test_modules(${TEST_OP} MODULES ${TEST_OP} ENVS ${GC_ENVS}) + endif() endforeach(TEST_OP) set_tests_properties(test_se_resnet PROPERTIES TIMEOUT 900) From 1ae0730f3800c5975e9c6287c0a7c3fd6521d187 Mon Sep 17 00:00:00 2001 From: zyfncg Date: Sun, 3 Apr 2022 23:04:31 +0800 Subject: [PATCH 48/93] fix bug caused by arange (#41372) --- python/paddle/fluid/layers/tensor.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/paddle/fluid/layers/tensor.py b/python/paddle/fluid/layers/tensor.py index 81a60bf517522..b47ddd0dc9fc3 100644 --- a/python/paddle/fluid/layers/tensor.py +++ b/python/paddle/fluid/layers/tensor.py @@ -1433,10 +1433,6 @@ def range(start, end, step, dtype, name=None): if not isinstance(dtype, core.VarDesc.VarType): dtype = convert_np_dtype_to_dtype_(dtype) - if in_dygraph_mode(): - return _C_ops.final_state_arange(start, end, step, dtype, - _current_expected_place()) - if not isinstance(start, Variable): with device_guard("cpu"): start = fill_constant([1], dtype, start, force_cpu=True) @@ -1455,6 +1451,10 @@ def range(start, end, step, dtype, name=None): elif step.dtype != dtype: step = cast(step, dtype) + if in_dygraph_mode(): + return _C_ops.final_state_arange(start, end, step, dtype, + _current_expected_place()) + if _in_legacy_dygraph(): out = _C_ops.range(start, end, step) out.stop_gradient = True From fd591ecb457d0f7f76ef6ddaa6c2ef02248bdb5f Mon Sep 17 00:00:00 2001 From: Aurelius84 Date: Mon, 4 Apr 2022 06:04:33 +0800 Subject: [PATCH 49/93] [Eager]Polish enable/disable_legacy_dygraph logic (#41364) * [Eager]Polish enable/disable_legacy_dygraph logic * merge yunfei PR * merge other pr --- python/paddle/fluid/framework.py | 53 ++++++++++++++-------------- python/paddle/tensor/manipulation.py | 35 ++++++++---------- 2 files changed, 40 insertions(+), 48 deletions(-) diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index 20c441f364145..a329610eeae83 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -115,14 +115,38 @@ def _update_monkey_methods(is_eager): from .dygraph.varbase_patch_methods import monkey_patch_varbase from .dygraph import monkey_patch_math_varbase + global _already_patch_eager_tensor + global _already_patch_varbase + assert isinstance(is_eager, bool) + # switch into eager mode if is_eager: _C_ops.switch_to_eager_ops() + if not _already_patch_eager_tensor: + monkey_patch_varbase() + monkey_patch_math_varbase() + + _already_patch_eager_tensor = True + # switch back into legacy mode else: _C_ops.switch_to_core_ops() + if not 
_already_patch_varbase: + monkey_patch_varbase() + monkey_patch_math_varbase() + + _already_patch_varbase = True - monkey_patch_varbase() - monkey_patch_math_varbase() + # switch Paddle.Tensor bind type + _switch_tensor_bind_type(is_eager) + + +def _switch_tensor_bind_type(is_eager): + import paddle + if is_eager: + paddle.Tensor = core.eager.Tensor + else: + paddle.Tensor = core.VarBase + paddle.Tensor.__qualname__ = 'Tensor' def _enable_legacy_dygraph(): @@ -183,35 +207,10 @@ def _non_static_mode(): @signature_safe_contextmanager def _test_eager_guard(place=None): _disable_legacy_dygraph() - from paddle import _C_ops - _C_ops.switch_to_eager_ops() - global _already_patch_eager_tensor - global _already_patch_varbase - from .dygraph.varbase_patch_methods import monkey_patch_varbase - from .dygraph import monkey_patch_math_varbase - if not _already_patch_eager_tensor: - monkey_patch_varbase() - monkey_patch_math_varbase() - - # Ugly setting - from paddle.tensor.manipulation import fill_, zero_, fill_diagonal_, fill_diagonal_tensor_, tolist - setattr(core.eager.Tensor, 'fill_', fill_) - setattr(core.eager.Tensor, 'zero_', zero_) - setattr(core.eager.Tensor, 'fill_diagonal_', fill_diagonal_) - setattr(core.eager.Tensor, 'fill_diagonal_tensor_', - fill_diagonal_tensor_) - setattr(core.eager.Tensor, 'tolist', tolist) - - _already_patch_eager_tensor = True try: yield finally: _enable_legacy_dygraph() - if not _already_patch_varbase: - monkey_patch_varbase() - monkey_patch_math_varbase() - _already_patch_varbase = True - _C_ops.switch_to_core_ops() global_ipu_index = None diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index ca807c286a05b..f6bbadf98726f 100755 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -76,9 +76,6 @@ def fill_(x, value): float(value), "value_int", int(value)) -setattr(core.VarBase, 'fill_', fill_) - - @dygraph_only def zero_(x): """ @@ -107,9 +104,6 @@ def zero_(x): return _C_ops.fill_any_(x, "value_float", 0., "value_int", int(0)) -setattr(core.VarBase, 'zero_', zero_) - - @dygraph_only def fill_diagonal_(x, value, offset=0, wrap=False, name=None): """ @@ -156,9 +150,6 @@ def fill_diagonal_(x, value, offset=0, wrap=False, name=None): True) -setattr(core.VarBase, 'fill_diagonal_', fill_diagonal_) - - def _fill_diagonal_tensor_impl(x, y, offset=0, dim1=0, dim2=1, inplace=False): inshape = x.shape assert dim1 < len(inshape) and dim1 >= -len(inshape), ( @@ -226,9 +217,6 @@ def fill_diagonal_tensor_(x, y, offset=0, dim1=0, dim2=1, name=None): x, y, offset=offset, dim1=dim1, dim2=dim2, inplace=True) -setattr(core.VarBase, 'fill_diagonal_tensor_', fill_diagonal_tensor_) - - def fill_diagonal_tensor(x, y, offset=0, dim1=0, dim2=1, name=None): """ This function fill the source Tensor y into the x Tensor's diagonal. 
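# Annotation: the framework.py hunks above centralize the eager/legacy
# switch -- _update_monkey_methods() now patches methods only once per mode
# and rebinds paddle.Tensor via _switch_tensor_bind_type() -- while the
# manipulation.py hunks here drop the scattered setattr(core.VarBase, ...)
# calls in favor of one __METHODS loop that attaches each helper to both
# core.VarBase and core.eager.Tensor. A small usage sketch of the resulting
# behavior (illustrative only):
#
#   import paddle
#   from paddle.fluid.framework import _test_eager_guard
#
#   with _test_eager_guard():     # eager mode: paddle.Tensor is
#       t = paddle.ones([3])      # core.eager.Tensor inside the guard
#       t.fill_(2.0)              # helper bound by the __METHODS loop
#   t = paddle.ones([3])          # legacy VarBase binding restored on exit
#   t.fill_(2.0)                  # same helper, same behavior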
@@ -262,12 +250,6 @@ def fill_diagonal_tensor(x, y, offset=0, dim1=0, dim2=1, name=None): x, y, offset=offset, dim1=dim1, dim2=dim2, inplace=False) -setattr(core.VarBase, 'fill_diagonal_tensor', fill_diagonal_tensor) - -if _in_eager_without_dygraph_check(): - setattr(core.eager.Tensor, 'fill_diagonal_tensor', fill_diagonal_tensor) - - @dygraph_only def tolist(x): """ @@ -301,9 +283,6 @@ def tolist(x): return x.numpy().tolist() -setattr(core.VarBase, 'tolist', tolist) - - def concat(x, axis=0, name=None): """ @@ -2961,3 +2940,17 @@ def put_along_axis_(arr, indices, values, axis, reduce='assign'): values = paddle.broadcast_to(values, indices.shape) return _C_ops.put_along_axis_(arr, indices, values, "Axis", axis, "Reduce", reduce) + + +# TODO(dev): We need avoid implementing it by this way. +__METHODS = { + 'fill_': fill_, + 'zero_': zero_, + 'fill_diagonal_': fill_diagonal_, + 'fill_diagonal_tensor_': fill_diagonal_tensor_, + "fill_diagonal_tensor": fill_diagonal_tensor, + 'tolist': tolist +} +for name, func in __METHODS.items(): + setattr(core.VarBase, name, func) + setattr(core.eager.Tensor, name, func) From 3152f3fb2fae568adc7f8443102b432453278b71 Mon Sep 17 00:00:00 2001 From: xiongkun Date: Mon, 4 Apr 2022 06:06:13 +0800 Subject: [PATCH 50/93] [Yaml] add yaml for gather op and elementwise_mod op . (#41348) * gather op * add mod --- .../tests/unittests/test_activation_op.py | 5 +- .../unittests/test_elementwise_mod_op.py | 11 +++- .../fluid/tests/unittests/test_gather_op.py | 12 ++--- python/paddle/tensor/manipulation.py | 6 +-- python/paddle/tensor/math.py | 50 +++++++++---------- python/paddle/utils/code_gen/api.yaml | 20 ++++++++ python/paddle/utils/code_gen/backward.yaml | 23 ++++++++- 7 files changed, 88 insertions(+), 39 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_activation_op.py b/python/paddle/fluid/tests/unittests/test_activation_op.py index 5573ecf33687b..04e37a9b0379a 100755 --- a/python/paddle/fluid/tests/unittests/test_activation_op.py +++ b/python/paddle/fluid/tests/unittests/test_activation_op.py @@ -2326,7 +2326,7 @@ class TestPow(TestActivation): def setUp(self): self.op_type = "pow" self.python_api = paddle.pow - self.check_eager = False + self.check_eager = True self.init_dtype() np.random.seed(1024) @@ -2337,6 +2337,9 @@ def setUp(self): self.attrs = {'factor': 3.0} self.outputs = {'Out': out} + def test_check_output(self): + self.check_output(check_eager=self.check_eager) + def test_check_grad(self): if self.dtype == np.float16: return diff --git a/python/paddle/fluid/tests/unittests/test_elementwise_mod_op.py b/python/paddle/fluid/tests/unittests/test_elementwise_mod_op.py index 2a8ca51693ecf..c6973255f2644 100644 --- a/python/paddle/fluid/tests/unittests/test_elementwise_mod_op.py +++ b/python/paddle/fluid/tests/unittests/test_elementwise_mod_op.py @@ -29,6 +29,7 @@ def init_kernel_type(self): def setUp(self): self.op_type = "elementwise_mod" + self.python_api = paddle.remainder self.axis = -1 self.init_dtype() self.init_input_output() @@ -43,7 +44,10 @@ def setUp(self): self.outputs = {'Out': self.out} def test_check_output(self): - self.check_output() + if self.attrs['axis'] == -1: + self.check_output(check_eager=True) + else: + self.check_output(check_eager=False) def init_input_output(self): self.x = np.random.uniform(0, 10000, [10, 10]).astype(self.dtype) @@ -76,7 +80,10 @@ def init_input_output(self): self.out = np.fmod(self.y + np.fmod(self.x, self.y), self.y) def test_check_output(self): - self.check_output() + if self.attrs['axis'] 
== -1: + self.check_output(check_eager=True) + else: + self.check_output(check_eager=False) class TestElementwiseModOpDouble(TestElementwiseModOpFloat): diff --git a/python/paddle/fluid/tests/unittests/test_gather_op.py b/python/paddle/fluid/tests/unittests/test_gather_op.py index 9ec2d1acdb5f3..3d7dc2da052f3 100644 --- a/python/paddle/fluid/tests/unittests/test_gather_op.py +++ b/python/paddle/fluid/tests/unittests/test_gather_op.py @@ -43,10 +43,10 @@ def setUp(self): self.outputs = {'Out': self.inputs["X"][self.inputs["Index"]]} def test_check_output(self): - self.check_output(check_eager=False) + self.check_output(check_eager=True) def test_check_grad(self): - self.check_grad(['X'], 'Out', check_eager=False) + self.check_grad(['X'], 'Out', check_eager=True) def config(self): """ @@ -136,10 +136,10 @@ def setUp(self): self.outputs = {'Out': out} def test_check_output(self): - self.check_output(check_eager=False) + self.check_output(check_eager=True) def test_check_grad(self): - self.check_grad(['X'], 'Out', numeric_grad_delta=0.5, check_eager=False) + self.check_grad(['X'], 'Out', numeric_grad_delta=0.5, check_eager=True) def config(self): """ @@ -165,10 +165,10 @@ def setUp(self): self.outputs = {'Out': out} def test_check_output(self): - self.check_output(check_eager=False) + self.check_output(check_eager=True) def test_check_grad(self): - self.check_grad(['X'], 'Out', check_eager=False) + self.check_grad(['X'], 'Out', check_eager=True) def config(self): """ diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index f6bbadf98726f..30e559151ed9e 100755 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -1391,9 +1391,9 @@ def gather(x, index, axis=None, name=None): if axis is None: axis = 0 - #if in_dygraph_mode(): - #return _C_ops.final_state_gather(x, index, axis) - if _non_static_mode(): + if in_dygraph_mode(): + return _C_ops.final_state_gather(x, index, axis) + if _in_legacy_dygraph(): axis = axis.item() if isinstance(axis, paddle.Tensor) else axis return _C_ops.gather(x, index, None, "axis", axis, "overwrite", False) diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index ccd5efbd580af..adca732dfdaa0 100755 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -150,41 +150,38 @@ def pow(x, y, name=None): """ # in dynamic graph mode - #if in_dygraph_mode(): - #if isinstance(y, (int, float)): - #return _C_ops.final_state_pow(x, y) - #elif isinstance(y, (paddle.Tensor, Variable)): - #return _elementwise_op_in_dygraph( - #x, y, axis=-1, act=None, op_name='elementwise_pow') - #else: - #raise TypeError('y must be scalar or tensor type, but received: %s '% (y.dtype)) - - #if _in_legacy_dygraph(): - if _non_static_mode(): + if in_dygraph_mode(): if isinstance(y, (int, float)): - return _C_ops.pow(x, 'factor', y) + return _C_ops.final_state_pow(x, y) elif isinstance(y, (paddle.Tensor, Variable)): return _elementwise_op_in_dygraph( x, y, axis=-1, act=None, op_name='elementwise_pow') else: raise TypeError('y must be scalar or tensor type, but received: %s '% (y.dtype)) - # in static graph mode - else: + if _in_legacy_dygraph(): if isinstance(y, (int, float)): - helper = LayerHelper('pow', **locals()) - inputs = {'X': x} - attrs = {'factor': y} - out = helper.create_variable_for_type_inference(dtype=x.dtype) - helper.append_op( - type='pow', inputs=inputs, outputs={'Out': out}, attrs=attrs) - return out + return _C_ops.pow(x, 'factor', y) elif isinstance(y, (paddle.Tensor, 
Variable)): - # TODO A potential speed improvement is supporting different types in C++ and removing the cast ops here - helper = LayerHelper('elementwise_pow', **locals()) - out = helper.create_variable_for_type_inference(dtype=x.dtype) - return _elementwise_op(LayerHelper('elementwise_pow', **locals())) + return _elementwise_op_in_dygraph( + x, y, axis=-1, act=None, op_name='elementwise_pow') else: - raise TypeError('y must be scalar or tensor type, but received: %s '% (type(y))) + raise TypeError('y must be scalar or tensor type, but received: %s '% (y.dtype)) + # in static graph mode + if isinstance(y, (int, float)): + helper = LayerHelper('pow', **locals()) + inputs = {'X': x} + attrs = {'factor': y} + out = helper.create_variable_for_type_inference(dtype=x.dtype) + helper.append_op( + type='pow', inputs=inputs, outputs={'Out': out}, attrs=attrs) + return out + elif isinstance(y, (paddle.Tensor, Variable)): + # TODO A potential speed improvement is supporting different types in C++ and removing the cast ops here + helper = LayerHelper('elementwise_pow', **locals()) + out = helper.create_variable_for_type_inference(dtype=x.dtype) + return _elementwise_op(LayerHelper('elementwise_pow', **locals())) + else: + raise TypeError('y must be scalar or tensor type, but received: %s '% (type(y))) OP_NAMEMAPPING = { @@ -192,6 +189,7 @@ def pow(x, y, name=None): 'elementwise_min': 'final_state_minimum', 'elementwise_pow': 'final_state_elementwise_pow', 'elementwise_floordiv': 'final_state_floor_divide', + 'elementwise_mod': 'final_state_modulo', } @dygraph_only diff --git a/python/paddle/utils/code_gen/api.yaml b/python/paddle/utils/code_gen/api.yaml index 0b855b0f967ba..139eb3556b058 100644 --- a/python/paddle/utils/code_gen/api.yaml +++ b/python/paddle/utils/code_gen/api.yaml @@ -632,6 +632,16 @@ data_type : dtype > x backend : place > x +- api : gather + args : (Tensor x, Tensor index, Scalar axis=0) + output : Tensor(out) + infer_meta : + func : GatherInferMeta + kernel : + func : gather + data_type: x + backward : gather_grad + - api : gather_nd args : (Tensor x, Tensor index) output : Tensor @@ -1220,6 +1230,16 @@ func : pool3d backward : pool3d_grad +- api : pow + args : (Tensor x, Scalar s) + output : Tensor(out) + infer_meta : + func : UnchangedInferMeta + param: [x] + kernel : + func : pow + backward : pow_grad + - api : prelu args : (Tensor x, Tensor alpha, str data_format, str mode) output : Tensor(out) diff --git a/python/paddle/utils/code_gen/backward.yaml b/python/paddle/utils/code_gen/backward.yaml index d3d589d00f7f2..6ce0ae1b78a85 100644 --- a/python/paddle/utils/code_gen/backward.yaml +++ b/python/paddle/utils/code_gen/backward.yaml @@ -178,7 +178,7 @@ output : Tensor(x_grad), Tensor(filter_grad) infer_meta : func : ConvTransposeGradInferMeta - kernel : + kernel : func : conv2d_transpose_grad - backward_api : conv3d_transpose_grad @@ -389,6 +389,17 @@ kernel : func : frobenius_norm_grad +- backward_api : gather_grad + forward : gather(Tensor x, Tensor index, Scalar axis=0) -> Tensor(out) + args : (Tensor x, Tensor index, Tensor out_grad, Scalar axis=0, bool overwrite=false) + output : Tensor(x_grad) + infer_meta : + func : UnchangedInferMeta + param: [x] + kernel : + data_type: x + func : gather_grad + - backward_api : gather_nd_grad forward : gather_nd (Tensor x, Tensor index) -> Tensor(out) args : (Tensor x, Tensor index, Tensor out_grad) @@ -803,6 +814,16 @@ kernel : func : pool3d_grad +- backward_api : pow_grad + forward : pow(Tensor x, Scalar s) -> Tensor(out) + args : 
(Tensor x, Tensor out_grad, Scalar s=-1) + output : Tensor(x_grad) + infer_meta : + func : UnchangedInferMeta + param: [x] + kernel : + func : pow_grad + - backward_api : prelu_grad forward : prelu(Tensor x, Tensor alpha, str data_format, str mode) -> Tensor(out) args : (Tensor x, Tensor alpha, Tensor out_grad, str data_format, str mode) From c5285cc5834406d87f6763d6fa77bfc1ca5c8c26 Mon Sep 17 00:00:00 2001 From: From00 Date: Mon, 4 Apr 2022 07:07:19 +0800 Subject: [PATCH 51/93] Add yaml for flatten_contiguous_range OP (#41345) * Add yaml for flatten_contiguous_range OP * update * Fix typos Co-authored-by: Shixiaowei02 <39303645+Shixiaowei02@users.noreply.github.com> --- .../{phi_test.mlir => disabled_phi_test.mlir} | 0 ...50_ops.mlir => disabled_resnet50_ops.mlir} | 0 paddle/phi/kernels/flatten_grad_kernel.cc | 2 +- paddle/phi/kernels/flatten_grad_kernel.h | 2 +- paddle/phi/ops/compat/flatten_sig.cc | 2 +- paddle/phi/tests/api/CMakeLists.txt | 1 - paddle/phi/tests/api/test_flatten_api.cc | 75 ------------------- .../test_flatten_contiguous_range_op.py | 6 +- python/paddle/tensor/manipulation.py | 6 +- python/paddle/utils/code_gen/api.yaml | 10 ++- python/paddle/utils/code_gen/backward.yaml | 13 ++++ tools/infrt/skipped_phi_api.json | 2 +- 12 files changed, 33 insertions(+), 86 deletions(-) rename paddle/infrt/tests/dialect/phi/{phi_test.mlir => disabled_phi_test.mlir} (100%) rename paddle/infrt/tests/dialect/phi/kernels/{resnet50_ops.mlir => disabled_resnet50_ops.mlir} (100%) delete mode 100644 paddle/phi/tests/api/test_flatten_api.cc diff --git a/paddle/infrt/tests/dialect/phi/phi_test.mlir b/paddle/infrt/tests/dialect/phi/disabled_phi_test.mlir similarity index 100% rename from paddle/infrt/tests/dialect/phi/phi_test.mlir rename to paddle/infrt/tests/dialect/phi/disabled_phi_test.mlir diff --git a/paddle/infrt/tests/dialect/phi/kernels/resnet50_ops.mlir b/paddle/infrt/tests/dialect/phi/kernels/disabled_resnet50_ops.mlir similarity index 100% rename from paddle/infrt/tests/dialect/phi/kernels/resnet50_ops.mlir rename to paddle/infrt/tests/dialect/phi/kernels/disabled_resnet50_ops.mlir diff --git a/paddle/phi/kernels/flatten_grad_kernel.cc b/paddle/phi/kernels/flatten_grad_kernel.cc index b7b45e46cf414..83f96c1f9f521 100644 --- a/paddle/phi/kernels/flatten_grad_kernel.cc +++ b/paddle/phi/kernels/flatten_grad_kernel.cc @@ -21,8 +21,8 @@ namespace phi { template void FlattenGradKernel(const Context& dev_ctx, - const DenseTensor& out_grad, const DenseTensor& xshape, + const DenseTensor& out_grad, DenseTensor* x_grad) { auto xshape_dims = xshape.dims(); dev_ctx.Alloc(x_grad, out_grad.dtype()); diff --git a/paddle/phi/kernels/flatten_grad_kernel.h b/paddle/phi/kernels/flatten_grad_kernel.h index 3ad27b430eb72..abd120e69b2e9 100644 --- a/paddle/phi/kernels/flatten_grad_kernel.h +++ b/paddle/phi/kernels/flatten_grad_kernel.h @@ -20,8 +20,8 @@ namespace phi { template void FlattenGradKernel(const Context& dev_ctx, - const DenseTensor& out_grad, const DenseTensor& xshape, + const DenseTensor& out_grad, DenseTensor* x_grad); } // namespace phi diff --git a/paddle/phi/ops/compat/flatten_sig.cc b/paddle/phi/ops/compat/flatten_sig.cc index b72ad05ea09d8..3e8119c38cf51 100644 --- a/paddle/phi/ops/compat/flatten_sig.cc +++ b/paddle/phi/ops/compat/flatten_sig.cc @@ -31,7 +31,7 @@ KernelSignature FlattenOpArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature FlattenGradOpArgumentMapping( const ArgumentMappingContext& ctx) { return KernelSignature( - "flatten_grad", {GradVarName("Out"), 
"XShape"}, {}, {GradVarName("X")}); + "flatten_grad", {"XShape", GradVarName("Out")}, {}, {GradVarName("X")}); } } // namespace phi diff --git a/paddle/phi/tests/api/CMakeLists.txt b/paddle/phi/tests/api/CMakeLists.txt index cc05c0194804a..94378aceff58c 100644 --- a/paddle/phi/tests/api/CMakeLists.txt +++ b/paddle/phi/tests/api/CMakeLists.txt @@ -12,7 +12,6 @@ cc_test(test_dot_api SRCS test_dot_api.cc DEPS ${COMMON_API_TEST_DEPS}) cc_test(test_matmul_api SRCS test_matmul_api.cc DEPS ${COMMON_API_TEST_DEPS}) cc_test(test_empty_api SRCS test_empty_api.cc DEPS ${COMMON_API_TEST_DEPS}) cc_test(test_fill_api SRCS test_fill_api.cc DEPS ${COMMON_API_TEST_DEPS}) -cc_test(test_flatten_api SRCS test_flatten_api.cc DEPS ${COMMON_API_TEST_DEPS}) cc_test(test_elementwise_api SRCS test_elementwise_api.cc DEPS ${COMMON_API_TEST_DEPS}) cc_test(test_cast_api SRCS test_cast_api.cc DEPS ${COMMON_API_TEST_DEPS}) cc_test(test_reshape_api SRCS test_reshape_api.cc DEPS ${COMMON_API_TEST_DEPS}) diff --git a/paddle/phi/tests/api/test_flatten_api.cc b/paddle/phi/tests/api/test_flatten_api.cc deleted file mode 100644 index f1c8935e26640..0000000000000 --- a/paddle/phi/tests/api/test_flatten_api.cc +++ /dev/null @@ -1,75 +0,0 @@ -/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include -#include - -#include "paddle/phi/api/include/api.h" - -#include "paddle/phi/api/lib/utils/allocator.h" -#include "paddle/phi/core/dense_tensor.h" -#include "paddle/phi/core/kernel_registry.h" - -PD_DECLARE_KERNEL(flatten, CPU, ALL_LAYOUT); - -namespace paddle { -namespace tests { - -namespace framework = paddle::framework; -using DDim = phi::DDim; - -// TODO(chenweihang): Remove this test after the API is used in the dygraph -TEST(API, flatten) { - // 1. create tensor - const auto alloc = std::make_unique( - paddle::platform::CPUPlace()); - auto dense_x = std::make_shared( - alloc.get(), - phi::DenseTensorMeta(phi::DataType::FLOAT32, - phi::make_ddim({3, 2, 2, 3}), - phi::DataLayout::NCHW)); - auto* dense_x_data = - dense_x->mutable_data(paddle::platform::CPUPlace()); - - for (int i = 0; i < dense_x->numel(); i++) { - dense_x_data[i] = i; - } - - paddle::experimental::Tensor x(dense_x); - int start_axis = 1, stop_axis = 2; - // 2. test API - auto out = paddle::experimental::flatten(x, start_axis, stop_axis); - - // 3. 
check result - std::vector<int> expect_shape = {3, 4, 3}; - ASSERT_EQ(out.dims()[0], expect_shape[0]); - ASSERT_EQ(out.dims()[1], expect_shape[1]); - ASSERT_EQ(out.dims()[2], expect_shape[2]); - ASSERT_EQ(out.numel(), 36); - ASSERT_EQ(out.is_cpu(), true); - ASSERT_EQ(out.type(), phi::DataType::FLOAT32); - ASSERT_EQ(out.layout(), phi::DataLayout::NCHW); - ASSERT_EQ(out.initialized(), true); - bool value_equal = true; - auto dense_out = std::dynamic_pointer_cast<phi::DenseTensor>(out.impl()); - auto* dense_out_data = dense_out->data<float>(); - for (int i = 0; i < dense_x->numel(); i++) { - if (std::abs(dense_x_data[i] - dense_out_data[i]) > 1e-6f) - value_equal = false; - } - ASSERT_EQ(value_equal, true); -} - -} // namespace tests -} // namespace paddle
diff --git a/python/paddle/fluid/tests/unittests/test_flatten_contiguous_range_op.py b/python/paddle/fluid/tests/unittests/test_flatten_contiguous_range_op.py index 9093050d6d5c6..ac352fcdf87ea 100644 --- a/python/paddle/fluid/tests/unittests/test_flatten_contiguous_range_op.py +++ b/python/paddle/fluid/tests/unittests/test_flatten_contiguous_range_op.py @@ -23,6 +23,8 @@ class TestFlattenOp(OpTest): def setUp(self): + self.python_api = paddle.flatten + self.python_out_sig = ["Out"] self.op_type = "flatten_contiguous_range" self.start_axis = 0 self.stop_axis = -1 @@ -35,10 +37,10 @@ def setUp(self): } def test_check_output(self): - self.check_output(no_check_set=["XShape"]) + self.check_output(no_check_set=["XShape"], check_eager=True) def test_check_grad(self): - self.check_grad(["X"], "Out") + self.check_grad(["X"], "Out", check_eager=True) def init_test_case(self): self.in_shape = (3, 2, 5, 4)
diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index 30e559151ed9e..b055abcf845f9 100755 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -676,7 +676,11 @@ def flatten(x, start_axis=0, stop_axis=-1, name=None): if start_axis > stop_axis: raise ValueError("The stop_axis should be larger than start_axis") - if paddle.in_dynamic_mode(): + if in_dygraph_mode(): + dy_out, _ = _C_ops.final_state_flatten(x, start_axis, stop_axis) + return dy_out + + if _in_legacy_dygraph(): dy_out, _ = _C_ops.flatten_contiguous_range(x, 'start_axis', start_axis, 'stop_axis', stop_axis) return dy_out
diff --git a/python/paddle/utils/code_gen/api.yaml b/python/paddle/utils/code_gen/api.yaml index 139eb3556b058..2a0026fb50933 100644 --- a/python/paddle/utils/code_gen/api.yaml +++ b/python/paddle/utils/code_gen/api.yaml @@ -547,11 +547,15 @@ - api : flatten args : (Tensor x, int start_axis, int stop_axis) - output : Tensor + output : Tensor(out), Tensor(xshape) infer_meta : - func : FlattenInferMeta + func : FlattenWithXShapeInferMeta kernel : - func : flatten + func : flatten_with_xshape + backend : x + inplace : (x -> out) + view : (x -> out) + backward : flatten_grad # flip - api : flip
diff --git a/python/paddle/utils/code_gen/backward.yaml b/python/paddle/utils/code_gen/backward.yaml index 6ce0ae1b78a85..80ec2d9b84e54 100644 --- a/python/paddle/utils/code_gen/backward.yaml +++ b/python/paddle/utils/code_gen/backward.yaml @@ -349,6 +349,19 @@ kernel : func : expm1_grad +- backward_api : flatten_grad + forward : flatten(Tensor x, int start_axis, int stop_axis) -> Tensor(out), Tensor(xshape) + args : (Tensor xshape, Tensor out_grad) + output : Tensor(x_grad) + infer_meta : + func : KernelWithXShapeInferMeta + param : [xshape] + kernel : + func : flatten_grad + data_type: out_grad + backend: out_grad + layout: out_grad + -
backward_api : floor_grad forward : floor(Tensor x) -> Tensor(out) args : (Tensor out_grad) diff --git a/tools/infrt/skipped_phi_api.json b/tools/infrt/skipped_phi_api.json index 74650846921b6..eef57a2d6b7bc 100644 --- a/tools/infrt/skipped_phi_api.json +++ b/tools/infrt/skipped_phi_api.json @@ -1,4 +1,4 @@ { -"phi_apis":["conj", "nll_loss"], +"phi_apis":["conj", "nll_loss", "flatten"], "phi_kernels":["equal_all"] } From bcb663ccbe2c19fb0cbaba9fbe25fc9cfcdb3be0 Mon Sep 17 00:00:00 2001 From: Chen Weihang Date: Mon, 4 Apr 2022 07:50:23 +0800 Subject: [PATCH 52/93] [Phi] Support scale dygraph final state (#41321) * support scale final state * fix inplace error * pass arg directly * pass arg directly for inplace api * fix type --- .../final_state_generator/python_c_gen.py | 2 +- python/paddle/fluid/layers/nn.py | 3 +++ .../fluid/tests/unittests/test_scale_op.py | 19 +++++++++++-------- python/paddle/tensor/math.py | 11 +++++++---- python/paddle/utils/code_gen/api.yaml | 1 + python/paddle/utils/code_gen/backward.yaml | 2 +- 6 files changed, 24 insertions(+), 14 deletions(-) diff --git a/paddle/fluid/eager/auto_code_generator/final_state_generator/python_c_gen.py b/paddle/fluid/eager/auto_code_generator/final_state_generator/python_c_gen.py index 463c50658cd32..8075b65b1945b 100644 --- a/paddle/fluid/eager/auto_code_generator/final_state_generator/python_c_gen.py +++ b/paddle/fluid/eager/auto_code_generator/final_state_generator/python_c_gen.py @@ -23,7 +23,7 @@ ########################### ## Global Configurations ## ########################### -skipped_forward_api_names = set(["scale"]) +skipped_forward_api_names = set([]) def SkipAPIGeneration(forward_api_name): diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 0dcc8ee517fb1..d7ec3276d8b79 100755 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -11779,6 +11779,9 @@ def scale(x, scale=1.0, bias=0.0, bias_after_scale=True, act=None, name=None): """ + if in_dygraph_mode(): + out = _C_ops.final_state_scale(x, scale, float(bias), bias_after_scale) + return dygraph_utils._append_activation_in_dygraph(out) if _non_static_mode(): _scale = scale.numpy().item(0) if isinstance(scale, Variable) else scale out = _C_ops.scale(x, 'scale', diff --git a/python/paddle/fluid/tests/unittests/test_scale_op.py b/python/paddle/fluid/tests/unittests/test_scale_op.py index d432b8057f624..04ddb5a788d6f 100644 --- a/python/paddle/fluid/tests/unittests/test_scale_op.py +++ b/python/paddle/fluid/tests/unittests/test_scale_op.py @@ -27,6 +27,7 @@ class TestScaleOp(OpTest): def setUp(self): self.op_type = "scale" + self.python_api = paddle.scale self.dtype = np.float64 self.init_dtype_type() self.inputs = {'X': np.random.random((10, 10)).astype(self.dtype)} @@ -39,15 +40,16 @@ def init_dtype_type(self): pass def test_check_output(self): - self.check_output() + self.check_output(check_eager=True) def test_check_grad(self): - self.check_grad(['X'], 'Out') + self.check_grad(['X'], 'Out', check_eager=True) class TestScaleOpScaleVariable(OpTest): def setUp(self): self.op_type = "scale" + self.python_api = paddle.scale self.dtype = np.float64 self.init_dtype_type() self.scale = -2.3 @@ -62,10 +64,10 @@ def init_dtype_type(self): pass def test_check_output(self): - self.check_output() + self.check_output(check_eager=True) def test_check_grad(self): - self.check_grad(['X'], 'Out') + self.check_grad(['X'], 'Out', check_eager=True) class TestScaleOpSelectedRows(unittest.TestCase): @@ -144,18 +146,19 @@ def 
init_dtype_type(self): def test_check_output(self): place = core.CUDAPlace(0) if core.is_float16_supported(place): - self.check_output_with_place(place, atol=0.002) + self.check_output_with_place(place, atol=0.002, check_eager=True) def test_check_grad(self): place = core.CUDAPlace(0) if core.is_float16_supported(place): self.check_grad_with_place( - place, ["X"], "Out", max_relative_error=0.05) + place, ["X"], "Out", max_relative_error=0.05, check_eager=True) class TestScaleBF16Op(OpTest): def setUp(self): self.op_type = "scale" + self.python_api = paddle.scale self.dtype = np.uint16 self.attrs = {'scale': -2.3} x = np.random.random((10, 10)).astype(np.float32) @@ -164,10 +167,10 @@ def setUp(self): self.outputs = {'Out': convert_float_to_uint16(out)} def test_check_output(self): - self.check_output() + self.check_output(check_eager=True) def test_check_grad(self): - self.check_grad(['X'], 'Out', numeric_grad_delta=0.8) + self.check_grad(['X'], 'Out', numeric_grad_delta=0.8, check_eager=True) @unittest.skipIf(not core.is_compiled_with_cuda(), diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index adca732dfdaa0..c552fb4c09ca5 100755 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -98,10 +98,13 @@ def scale_(x, scale=1.0, bias=0.0, bias_after_scale=True, act=None, name=None): Inplace version of ``scale`` API, the output Tensor will be inplaced with input ``x``. Please refer to :ref:`api_tensor_scale`. """ - _scale = scale.numpy().item(0) if isinstance(scale, Variable) else scale - return _C_ops.scale_(x, 'scale', - float(_scale), 'bias', - float(bias), 'bias_after_scale', bias_after_scale) + if in_dygraph_mode(): + return _C_ops.final_state_scale_(x, scale, float(bias), bias_after_scale) + if _in_legacy_dygraph(): + _scale = scale.numpy().item(0) if isinstance(scale, Variable) else scale + return _C_ops.scale_(x, 'scale', + float(_scale), 'bias', + float(bias), 'bias_after_scale', bias_after_scale) def pow(x, y, name=None): diff --git a/python/paddle/utils/code_gen/api.yaml b/python/paddle/utils/code_gen/api.yaml index 2a0026fb50933..507f8b3f36097 100644 --- a/python/paddle/utils/code_gen/api.yaml +++ b/python/paddle/utils/code_gen/api.yaml @@ -1345,6 +1345,7 @@ kernel : func : scale, scale_sr inplace : (x -> out) + backward : scale_grad - api : scatter args : (Tensor x, Tensor index, Tensor updates, bool overwrite) diff --git a/python/paddle/utils/code_gen/backward.yaml b/python/paddle/utils/code_gen/backward.yaml index 80ec2d9b84e54..cb72040aa4ea5 100644 --- a/python/paddle/utils/code_gen/backward.yaml +++ b/python/paddle/utils/code_gen/backward.yaml @@ -947,7 +947,7 @@ - backward_api : scale_grad forward : scale (Tensor x, Scalar scale, float bias, bool bias_after_scale) -> Tensor(out) - args : (Tensor out_grad, Scalar scale, float bias=0.0, bool bias_after_scale=true) + args : (Tensor out_grad, Scalar scale=1.0, float bias=0.0, bool bias_after_scale=true) output : Tensor(x_grad) invoke : scale(out_grad, scale, bias, bias_after_scale) From 0f165f0b34d9278620489c8323d57a23bfe58021 Mon Sep 17 00:00:00 2001 From: From00 Date: Mon, 4 Apr 2022 08:46:04 +0800 Subject: [PATCH 53/93] Add yaml for randint OP (#41375) --- paddle/phi/infermeta/nullary.cc | 28 +++++++++++ paddle/phi/infermeta/nullary.h | 3 ++ .../fluid/tests/unittests/test_randint_op.py | 47 +++++++++++++++++-- python/paddle/tensor/random.py | 11 +++-- python/paddle/utils/code_gen/api.yaml | 14 +++++- 5 files changed, 93 insertions(+), 10 deletions(-) diff --git 
a/paddle/phi/infermeta/nullary.cc b/paddle/phi/infermeta/nullary.cc index 6a05e1b4d7f30..f76e7910d77b5 100644 --- a/paddle/phi/infermeta/nullary.cc +++ b/paddle/phi/infermeta/nullary.cc @@ -63,6 +63,34 @@ void RandpermInferMeta(int n, DataType dtype, MetaTensor* out) { out->set_dtype(dtype); } +void RandintInferMeta( + int low, int high, const IntArray& shape, DataType dtype, MetaTensor* out) { + PADDLE_ENFORCE_NOT_NULL( + out, errors::InvalidArgument("Output(Out) of RandintOp is null.")); + PADDLE_ENFORCE_LT( + low, + high, + errors::InvalidArgument("randint's low must be less than high, " + "but received: low = %d, high = %d.", + low, + high)); + + auto& shape_vector = shape.GetData(); + PADDLE_ENFORCE_EQ( + shape_vector.empty(), + false, + errors::InvalidArgument("The shape information should not be empty, it " + "must be set by Attr(shape).")); + + std::vector<int64_t> tensor_shape; + tensor_shape.reserve(shape_vector.size()); + for (auto dim : shape_vector) { + tensor_shape.push_back(static_cast<int64_t>(dim)); + } + out->set_dims(make_ddim(tensor_shape)); + out->set_dtype(dtype); +} + void TruncatedGaussianRandomInferMeta(const std::vector<int>& shape, float mean, float std,
diff --git a/paddle/phi/infermeta/nullary.h b/paddle/phi/infermeta/nullary.h index ada44658a2c25..f84ac01d002d3 100644 --- a/paddle/phi/infermeta/nullary.h +++ b/paddle/phi/infermeta/nullary.h @@ -55,6 +55,9 @@ void GaussianRandomInferMeta(const IntArray& shape, void RandpermInferMeta(int n, DataType dtype, MetaTensor* out); +void RandintInferMeta( + int low, int high, const IntArray& shape, DataType dtype, MetaTensor* out); + void TruncatedGaussianRandomInferMeta(const std::vector<int>& shape, float mean, float std,
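(Editor's note, not part of the patch: a quick smoke test of the eager randint path wired up above; the shape/dtype values follow the unit test that comes next.)

import paddle

paddle.disable_static()
x = paddle.randint(low=0, high=10, shape=[2, 3], dtype='int64')
assert x.shape == [2, 3]
assert int(x.min()) >= 0 and int(x.max()) < 10  # low is inclusive, high is exclusive

diff --git a/python/paddle/fluid/tests/unittests/test_randint_op.py b/python/paddle/fluid/tests/unittests/test_randint_op.py index 5f58054d7efc9..1eb99e08bb8e1 100644 --- a/python/paddle/fluid/tests/unittests/test_randint_op.py +++ b/python/paddle/fluid/tests/unittests/test_randint_op.py @@ -14,13 +14,14 @@ from __future__ import print_function +import os +import paddle import unittest import numpy as np from op_test import OpTest -import paddle from paddle.fluid import core +from paddle.fluid.framework import _test_eager_guard from paddle.static import program_guard, Program -import os paddle.enable_static() @@ -53,6 +54,10 @@ def verify_output(self, outs): np.allclose( hist, prob, rtol=0, atol=0.001), "hist: " + str(hist)) + def test_check_output_eager(self): + with _test_eager_guard(): + self.test_check_output() + class TestRandintOpError(unittest.TestCase): def test_errors(self): @@ -67,6 +72,10 @@ def test_errors(self): self.assertRaises( TypeError, paddle.randint, 5, shape=[shape_tensor]) + def test_errors_eager(self): + with _test_eager_guard(): + self.test_errors() + class TestRandintOp_attr_tensorlist(OpTest): def setUp(self): @@ -93,6 +102,10 @@ def verify_output(self, outs): np.allclose( hist, prob, rtol=0, atol=0.001), "hist: " + str(hist)) + def test_check_output_eager(self): + with _test_eager_guard(): + self.test_check_output() + class TestRandint_attr_tensor(OpTest): def setUp(self): @@ -114,6 +127,10 @@ def verify_output(self, outs): np.allclose( hist, prob, rtol=0, atol=0.001), "hist: " + str(hist)) + def test_check_output_eager(self): + with _test_eager_guard(): + self.test_check_output() + # Test python API class TestRandintAPI(unittest.TestCase): @@ -145,18 +162,30 @@ def test_api(self): feed={'var_shape': np.array([100, 100]).astype('int64')}, fetch_list=[out1, out2, out3, out4, out5]) + def 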
test_api_eager(self): + with _test_eager_guard(): + self.test_api() + class TestRandintImperative(unittest.TestCase): def test_api(self): - n = 10 paddle.disable_static() + + self.run_test_case() + + with _test_eager_guard(): + self.run_test_case() + + paddle.enable_static() + + def run_test_case(self): + n = 10 x1 = paddle.randint(n, shape=[10], dtype="int32") x2 = paddle.tensor.randint(n) x3 = paddle.tensor.random.randint(n) for i in [x1, x2, x3]: for j in i.numpy().tolist(): self.assertTrue((j >= 0 and j < n)) - paddle.enable_static() class TestRandomValue(unittest.TestCase): @@ -174,6 +203,15 @@ def test_fixed_random_number(self): print("Test Fixed Random number on GPU------>") paddle.disable_static() + + self.run_test_case() + + with _test_eager_guard(): + self.run_test_case() + + paddle.enable_static() + + def run_test_case(self): paddle.set_device('gpu') paddle.seed(100) @@ -198,7 +236,6 @@ def test_fixed_random_number(self): self.assertTrue(np.array_equal(x[20, 1, 600, 600:605], expect)) expect = [3581, 3420, -8027, -5237, -2436] self.assertTrue(np.array_equal(x[30, 2, 1000, 1000:1005], expect)) - paddle.enable_static() if __name__ == "__main__": diff --git a/python/paddle/tensor/random.py b/python/paddle/tensor/random.py index 20f4e73b2718a..d2e4363443720 100644 --- a/python/paddle/tensor/random.py +++ b/python/paddle/tensor/random.py @@ -22,7 +22,7 @@ import paddle from paddle import _C_ops from paddle.static import Variable -from ..fluid.framework import _in_legacy_dygraph, in_dygraph_mode +from paddle.fluid.framework import in_dygraph_mode, _in_legacy_dygraph, _current_expected_place __all__ = [] @@ -687,7 +687,11 @@ def randint(low=0, high=None, shape=[1], dtype=None, name=None): if not isinstance(dtype, core.VarDesc.VarType): dtype = convert_np_dtype_to_dtype_(dtype) - if paddle.in_dynamic_mode(): + if in_dygraph_mode(): + shape = utils.convert_shape_to_list(shape) + place = _current_expected_place() + return _C_ops.final_state_randint(low, high, shape, dtype, place) + if _in_legacy_dygraph(): shape = utils.convert_shape_to_list(shape) return _C_ops.randint('shape', shape, 'low', low, 'high', high, 'seed', 0, 'dtype', dtype) @@ -920,8 +924,7 @@ def randperm(n, dtype="int64", name=None): dtype = convert_np_dtype_to_dtype_(dtype) if in_dygraph_mode(): - return _C_ops.final_state_randperm( - n, dtype, paddle.fluid.framework._current_expected_place()) + return _C_ops.final_state_randperm(n, dtype, _current_expected_place()) if _in_legacy_dygraph(): return _C_ops.randperm('n', n, 'seed', 0, 'dtype', dtype) diff --git a/python/paddle/utils/code_gen/api.yaml b/python/paddle/utils/code_gen/api.yaml index 507f8b3f36097..fb0c6e294a0f0 100644 --- a/python/paddle/utils/code_gen/api.yaml +++ b/python/paddle/utils/code_gen/api.yaml @@ -1265,6 +1265,18 @@ data_type : x backward : put_along_axis_grad +- api : randint + args : (int low, int high, IntArray shape, DataType dtype=DataType::INT64, Place place={}) + output : Tensor(out) + infer_meta : + func : RandintInferMeta + param : [low, high, shape, dtype] + kernel : + func : randint + param : [low, high, shape, dtype] + data_type : dtype + backend : place + - api : randperm args : (int n, DataType dtype, Place place={}) output : Tensor @@ -1276,7 +1288,7 @@ param : [n, dtype] data_type : dtype backend : place - + - api : reciprocal args : (Tensor x) output : Tensor From 84b63a26bcb109e56cbe7223aa98dd308fb19136 Mon Sep 17 00:00:00 2001 From: Chen Weihang Date: Mon, 4 Apr 2022 10:01:38 +0800 Subject: [PATCH 54/93] [Phi] Add add_n(sum) 
infermeta and yaml (#41362) * add add_n infermeta * forward run success * add add_n grad yaml --- paddle/phi/api/lib/api_custom_impl.cc | 46 ++++++++++++ paddle/phi/api/lib/api_custom_impl.h | 3 + paddle/phi/infermeta/multiary.cc | 72 +++++++++++++++++++ paddle/phi/infermeta/multiary.h | 4 ++ .../fluid/tests/unittests/test_sum_op.py | 22 ++++++ python/paddle/tensor/math.py | 6 +- python/paddle/utils/code_gen/api.yaml | 9 +++ python/paddle/utils/code_gen/backward.yaml | 7 ++ 8 files changed, 168 insertions(+), 1 deletion(-)
diff --git a/paddle/phi/api/lib/api_custom_impl.cc b/paddle/phi/api/lib/api_custom_impl.cc index 152873fe41072..3818572db0c20 100644 --- a/paddle/phi/api/lib/api_custom_impl.cc +++ b/paddle/phi/api/lib/api_custom_impl.cc @@ -31,6 +31,52 @@ limitations under the License. */ namespace paddle { namespace experimental { +// TODO(chenweihang): the original sum grad op supports higher-order +// differentiation, but this impl does not; we need to be able to reuse +// the autograd API here, which is not yet implemented +// TODO(chenweihang): we should support calling the generated API in custom API impls +std::vector<Tensor> add_n_grad_impl(const std::vector<Tensor>& x, + const Tensor& out_grad) { + auto kernel_key_set = ParseKernelKeyByInputArgs(out_grad); + auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey(); + + Backend kernel_backend = kernel_key.backend(); + DataLayout kernel_layout = kernel_key.layout(); + DataType kernel_data_type = kernel_key.dtype(); + + auto kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError( + "scale", {kernel_backend, kernel_layout, kernel_data_type}); + VLOG(6) << "add_n_grad API kernel key: [" << kernel_backend << ", " + << kernel_layout << ", " << kernel_data_type << "]"; + VLOG(6) << "add_n_grad API kernel: " << kernel; + + auto* dev_ctx = GetDeviceContextByBackend(kernel_backend); + + auto dense_out_grad = PrepareData(out_grad, kernel.InputAt(0), {}); + + size_t out_number = x.size(); + std::vector<Tensor> x_grad; + auto dense_x_grad = SetKernelOutput(out_number, kernel_backend, &x_grad); + + using kernel_signature = void (*)(const platform::DeviceContext&, + const phi::DenseTensor&, + const phi::Scalar&, + float, + bool, + phi::DenseTensor*); + auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>(); + + for (auto* dense_x_grad_t : dense_x_grad) { + phi::MetaTensor meta_out(dense_x_grad_t); + phi::UnchangedInferMeta(MakeMetaTensor(*dense_out_grad), &meta_out); + (*kernel_fn)( + *dev_ctx, *dense_out_grad, phi::Scalar(1.0), 0.0, true, dense_x_grad_t); + } + + return x_grad; +} + Tensor copy_to_impl(const Tensor& x, Place place, bool blocking) { auto kernel_key_set = ParseKernelKeyByInputArgs(x); kernel_key_set.backend_set =
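(Editor's note, not part of the patch: a minimal eager-mode check of add_n and its scale-based gradient above, mirroring the new case in test_sum_op.py below. Since the gradient of a sum w.r.t. each input is the output gradient itself, each input grad comes back as all ones.)

import paddle

x = paddle.ones([2, 3])
y = paddle.ones([2, 3])
x.stop_gradient = False
y.stop_gradient = False
out = paddle.add_n([x, y])  # elementwise sum -> all 2.0
out.backward()              # x.grad and y.grad -> all 1.0

diff --git a/paddle/phi/api/lib/api_custom_impl.h b/paddle/phi/api/lib/api_custom_impl.h index b2f5a074d9288..f9a11b4bd9683 100644 --- a/paddle/phi/api/lib/api_custom_impl.h +++ b/paddle/phi/api/lib/api_custom_impl.h @@ -22,6 +22,9 @@ limitations under the License.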
*/ namespace paddle { namespace experimental { +std::vector<Tensor> add_n_grad_impl(const std::vector<Tensor>& x, + const Tensor& out_grad); + Tensor copy_to_impl(const Tensor& x, Place place, bool blocking); std::vector<Tensor> split_impl(const Tensor& x,
diff --git a/paddle/phi/infermeta/multiary.cc b/paddle/phi/infermeta/multiary.cc index 4fbd264f10f9f..42041af2dfe9e 100644 --- a/paddle/phi/infermeta/multiary.cc +++ b/paddle/phi/infermeta/multiary.cc @@ -279,6 +279,78 @@ void AdamwInferMeta(const MetaTensor& param, master_param_outs); } +void AddNInferMeta(const std::vector<const MetaTensor*>& x, + MetaTensor* out, + MetaConfig config) { + auto N = x.size(); + PADDLE_ENFORCE_GT( + N, + 0, + phi::errors::InvalidArgument( + "The number of inputs X of SumOp should be " + "larger than 0. But received %d.", + N)); + if (N == 1) { + VLOG(3) << "Warning: SumOp has only one input, which may waste memory"; + } + + phi::DDim in_dim({0}); + for (size_t i = 0; i < x.size(); ++i) { + auto x_dim = x[i]->dims(); + if (phi::product(x_dim) == 0) { + continue; + } + if (phi::product(in_dim) == 0) { + in_dim = x_dim; + } else { + if (config.is_runtime) { + PADDLE_ENFORCE_EQ(in_dim, + x_dim, + phi::errors::InvalidArgument( + "The input tensor X of SumOp must" + " have same shape. But received X[0]'s shape = " + "[%s], X[%d]'s shape = [%s].", + in_dim, + i, + x_dim)); + } else { + PADDLE_ENFORCE_EQ( + in_dim.size(), + x_dim.size(), + phi::errors::InvalidArgument( + "The input tensor X of SumOp must have same " + "dimensions. But received X[0]'s dimensions = %d, X[0]'s " + "shape = " + "[%s], X[%d]'s dimensions = %d, X[%d]'s shape = [%s].", + in_dim.size(), + in_dim, + i, + x_dim.size(), + i, + x_dim)); + // if in_dim or x_dim contains -1, skip the equality check + for (int j = 0; j < x_dim.size(); ++j) { + if (x_dim[j] == -1 || in_dim[j] == -1) { + continue; + } + PADDLE_ENFORCE_EQ( + in_dim[j], + x_dim[j], + phi::errors::InvalidArgument( + "The input tensor X of SumOp must have same shape " + "if not -1." + "But received X[0]'s shape = [%s], X[%d]'s shape = [%s].", + in_dim, + i, + x_dim)); + } + } + } + } + out->set_dims(in_dim); + out->share_lod(*x[0]); +} + void AucInferMeta(const MetaTensor& input, const MetaTensor& label, const MetaTensor& stat_pos,
diff --git a/paddle/phi/infermeta/multiary.h b/paddle/phi/infermeta/multiary.h index 64a11ed0b2621..0b1ccfcb90541 100644 --- a/paddle/phi/infermeta/multiary.h +++ b/paddle/phi/infermeta/multiary.h @@ -117,6 +117,10 @@ void AdamwInferMeta(const MetaTensor& param, MetaTensor* beta2_pow_out, MetaTensor* master_param_outs); +void AddNInferMeta(const std::vector<const MetaTensor*>& x, + MetaTensor* out, + MetaConfig config = MetaConfig()); + void AucInferMeta(const MetaTensor& input, const MetaTensor& label, const MetaTensor& stat_pos,
diff --git a/python/paddle/fluid/tests/unittests/test_sum_op.py b/python/paddle/fluid/tests/unittests/test_sum_op.py index 7040145a76833..6f625c097979b 100644 --- a/python/paddle/fluid/tests/unittests/test_sum_op.py +++ b/python/paddle/fluid/tests/unittests/test_sum_op.py @@ -25,6 +25,7 @@ from paddle.fluid.tests.unittests.op_test import ( OpTest, convert_float_to_uint16, convert_uint16_to_float) from paddle import _C_ops +from paddle.fluid.framework import _test_eager_guard class TestSumOp(OpTest): @@ -347,6 +348,27 @@ def test_api(self): self.assertEqual((sum_value.numpy() == expected_result).all(), True) + def test_dygraph_final_state_api(self): + with fluid.dygraph.guard(): + with _test_eager_guard(): + input0 = paddle.ones(shape=[2, 3], dtype='float32') + input1 = paddle.ones(shape=[2, 3], dtype='float32') + input0.stop_gradient = False + input1.stop_gradient = False + expected_result = np.empty((2, 3)) + expected_result.fill(2) + sum_value = paddle.add_n([input0, input1]) + self.assertEqual((sum_value.numpy() == expected_result).all(), + True) + + expected_grad_result = np.empty((2, 3)) + expected_grad_result.fill(1) + sum_value.backward() + self.assertEqual( + (input0.grad.numpy() == expected_grad_result).all(), True) + self.assertEqual( + (input1.grad.numpy() == expected_grad_result).all(), True) + class TestRaiseSumError(unittest.TestCase): def test_errors(self):
diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index c552fb4c09ca5..3408dd7ce9384 100755 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -1068,7 +1068,11 @@ def add_n(inputs, name=None): # [[8., 10., 12.], # [14., 16., 18.]] """ - if paddle.in_dynamic_mode(): + if in_dygraph_mode(): + if isinstance(inputs, Variable): + inputs = [inputs] + return _C_ops.final_state_add_n(inputs) + if _in_legacy_dygraph(): if isinstance(inputs, Variable): inputs = [inputs] return _C_ops.sum(inputs, 'use_mkldnn', False)
diff --git a/python/paddle/utils/code_gen/api.yaml b/python/paddle/utils/code_gen/api.yaml index fb0c6e294a0f0..f38a9bc619eba 100644 --- a/python/paddle/utils/code_gen/api.yaml +++ b/python/paddle/utils/code_gen/api.yaml @@ -63,6 +63,15 @@ backward : add_grad # no_need_buffer : x, y +- api : add_n + args : (Tensor[] x) + output : Tensor + infer_meta : + func : AddNInferMeta + kernel : + func : add_n + backward : add_n_grad + - api : addmm args : (Tensor input, Tensor x, Tensor y, float alpha, float beta) output : Tensor
diff --git a/python/paddle/utils/code_gen/backward.yaml b/python/paddle/utils/code_gen/backward.yaml index cb72040aa4ea5..7b6c383286601 100644 --- a/python/paddle/utils/code_gen/backward.yaml +++ b/python/paddle/utils/code_gen/backward.yaml @@ -41,6 +41,13 @@ func : add_grad no_need_buffer : x, y +-
backward_api : add_n_grad + forward : add_n (Tensor[] x) -> Tensor(out) + args : (Tensor[] x, Tensor out_grad) + output : Tensor[](x_grad) + invoke : add_n_grad_impl(x, out_grad) + no_need_buffer : x + - backward_api : addmm_grad forward : scatter (Tensor input, Tensor x, Tensor y, float alpha, float beta) -> Tensor(out) args : (Tensor input, Tensor x, Tensor y, Tensor out_grad, float alpha, float beta) From 42075ddcadcca03d2ad39414f8b636bdde443b28 Mon Sep 17 00:00:00 2001 From: Weilong Wu Date: Mon, 4 Apr 2022 10:06:33 +0800 Subject: [PATCH 55/93] [Eager] support tensor uva, test=windows_ci (#41310) * [Eager] support tensor uva, test=windows_ci * Add headers to fix CI, test=windows_ci * Expose _uva python interface, Fix windows ci issue --- paddle/fluid/pybind/eager_functions.cc | 49 +++++++++++++++ paddle/fluid/pybind/eager_method.cc | 25 ++++++++ paddle/fluid/pybind/imperative.cc | 33 +--------- paddle/fluid/pybind/tensor_py.h | 39 ++++++++++-- paddle/fluid/pybind/uva_utils.h | 60 +++++++++++++++++++ .../fluid/dygraph/varbase_patch_methods.py | 5 ++ .../fluid/tests/unittests/test_tensor_uva.py | 21 ++++++- 7 files changed, 194 insertions(+), 38 deletions(-) create mode 100644 paddle/fluid/pybind/uva_utils.h diff --git a/paddle/fluid/pybind/eager_functions.cc b/paddle/fluid/pybind/eager_functions.cc index 0c6707748ef5a..fb115455357dd 100644 --- a/paddle/fluid/pybind/eager_functions.cc +++ b/paddle/fluid/pybind/eager_functions.cc @@ -772,6 +772,53 @@ static PyObject* eager_api_async_write(PyObject* self, PyObject* args, return Py_None; EAGER_CATCH_AND_THROW_RETURN_NULL } + +static PyObject* eager_api_to_uva_tensor(PyObject* self, PyObject* args, + PyObject* kwargs) { + EAGER_TRY + VLOG(4) << "Running in eager_api_to_uva_tensor."; + auto new_tensor = std::shared_ptr( + new paddle::experimental::Tensor( + egr::Controller::Instance().GenerateUniqueName())); + PyObject* obj = PyTuple_GET_ITEM(args, 0); + auto array = py::cast(py::handle(obj)); + + int device_id = 0; + PyObject* Py_device_id = PyTuple_GET_ITEM(args, 1); + if (Py_device_id) { + device_id = CastPyArg2AttrLong(Py_device_id, 1); + } + + if (py::isinstance>(array)) { + SetUVATensorFromPyArray(new_tensor, array, device_id); + } else if (py::isinstance>(array)) { + SetUVATensorFromPyArray(new_tensor, array, device_id); + } else if (py::isinstance>(array)) { + SetUVATensorFromPyArray(new_tensor, array, device_id); + } else if (py::isinstance>(array)) { + SetUVATensorFromPyArray(new_tensor, array, device_id); + } else if (py::isinstance>(array)) { + SetUVATensorFromPyArray(new_tensor, array, device_id); + } else if (py::isinstance>(array)) { + SetUVATensorFromPyArray(new_tensor, array, device_id); + } else if (py::isinstance>(array)) { + SetUVATensorFromPyArray(new_tensor, array, + device_id); + } else if (py::isinstance>(array)) { + SetUVATensorFromPyArray(new_tensor, array, device_id); + } else { + // obj may be any type, obj.cast() may be failed, + // then the array.dtype will be string of unknown meaning. + PADDLE_THROW(platform::errors::InvalidArgument( + "Input object type error or incompatible array data type. 
" + "tensor.set() supports array with bool, float16, float32, " + "float64, int8, int16, int32, int64," + "please check your input or input array data type.")); + } + + return ToPyObject(*(new_tensor.get())); + EAGER_CATCH_AND_THROW_RETURN_NULL +} #endif PyMethodDef variable_functions[] = { @@ -803,6 +850,8 @@ PyMethodDef variable_functions[] = { METH_VARARGS | METH_KEYWORDS, NULL}, {"async_write", (PyCFunction)(void (*)(void))eager_api_async_write, METH_VARARGS | METH_KEYWORDS, NULL}, + {"to_uva_tensor", (PyCFunction)(void (*)(void))eager_api_to_uva_tensor, + METH_VARARGS | METH_KEYWORDS, NULL}, #endif {NULL, NULL, 0, NULL}}; diff --git a/paddle/fluid/pybind/eager_method.cc b/paddle/fluid/pybind/eager_method.cc index 814243e0a5774..66fba92f67b83 100644 --- a/paddle/fluid/pybind/eager_method.cc +++ b/paddle/fluid/pybind/eager_method.cc @@ -32,6 +32,7 @@ limitations under the License. */ #include "paddle/fluid/pybind/eager_utils.h" #include "paddle/fluid/pybind/exception.h" #include "paddle/fluid/pybind/slice_utils.h" +#include "paddle/fluid/pybind/uva_utils.h" #include "paddle/phi/api/include/api.h" #include "paddle/phi/common/data_type.h" #include "paddle/phi/core/compat/convert_utils.h" @@ -1343,6 +1344,26 @@ static PyObject* tensor__reset_grad_inplace_version(TensorObject* self, EAGER_CATCH_AND_THROW_RETURN_NULL } +#if defined(PADDLE_WITH_CUDA) +static PyObject* tensor_method__uva(TensorObject* self, PyObject* args, + PyObject* kwargs) { + EAGER_TRY + VLOG(4) << "Running in tensor_method__uva."; + PADDLE_ENFORCE_EQ(platform::is_cpu_place(self->tensor.inner_place()), true, + platform::errors::InvalidArgument( + "Unified virtual addressing only support " + "CPU Tensor currently.")); + int device_id = pybind::CastPyArg2AttrLong(PyTuple_GET_ITEM(args, 0), 0); + auto* self_tensor = + static_cast(self->tensor.impl().get()); + tensor_uva(self_tensor, device_id); + + Py_INCREF(Py_None); + return Py_None; + EAGER_CATCH_AND_THROW_RETURN_NULL +} +#endif + PyMethodDef variable_methods[] = { {"numpy", (PyCFunction)(void (*)(void))tensor_method_numpy, METH_VARARGS | METH_KEYWORDS, NULL}, @@ -1447,6 +1468,10 @@ PyMethodDef variable_methods[] = { {"_reset_grad_inplace_version", (PyCFunction)(void (*)(void))tensor__reset_grad_inplace_version, METH_VARARGS | METH_KEYWORDS, NULL}, +#if defined(PADDLE_WITH_CUDA) + {"_tensor_uva", (PyCFunction)(void (*)(void))tensor_method__uva, + METH_VARARGS | METH_KEYWORDS, NULL}, +#endif {NULL, NULL, 0, NULL}}; } // namespace pybind diff --git a/paddle/fluid/pybind/imperative.cc b/paddle/fluid/pybind/imperative.cc index 0286560ec9982..7df6d8f7f791c 100644 --- a/paddle/fluid/pybind/imperative.cc +++ b/paddle/fluid/pybind/imperative.cc @@ -57,6 +57,7 @@ limitations under the License. 
*/ #include "paddle/fluid/pybind/pybind_boost_headers.h" #include "paddle/fluid/pybind/slice_utils.h" #include "paddle/fluid/pybind/tensor_py.h" +#include "paddle/fluid/pybind/uva_utils.h" #include "paddle/phi/core/compat/arg_map_context.h" #include "paddle/phi/core/compat/type_defs.h" @@ -1629,39 +1630,9 @@ void BindImperative(py::module *m_ptr) { platform::errors::InvalidArgument( "Unified virtual addressing only support " "CPU Tensor currently.")); - platform::DeviceContextPool &pool = - platform::DeviceContextPool::Instance(); - auto *dev_ctx = pool.Get(platform::CUDAPlace(device_id)); - VLOG(4) << "Init the DeviceContext, and the place is " - << dev_ctx->GetPlace(); auto *self_tensor = self->MutableVar()->GetMutable(); - // Register the cpu memory as the cuda host memory - const auto &data_numel = self_tensor->numel(); - const size_t &need_allocate_size = - data_numel * - framework::SizeOfType( - framework::TransToProtoVarType(self_tensor->dtype())); - void *data_ptr = self_tensor->data(); - auto result = cudaHostRegister(data_ptr, need_allocate_size, - cudaHostRegisterDefault); - if (cudaSuccess != result) { - VLOG(4) << "UVA(unified virtual addressing) failed allocate:" - << need_allocate_size << ", the error code:" << result; - } - - // Get device pointer from the function of cudaHostGetDevicePointer - void *cuda_device_pointer = nullptr; - cudaHostGetDevicePointer( - reinterpret_cast(&cuda_device_pointer), - reinterpret_cast(data_ptr), 0); - - // Reset the memory with device pointer - std::shared_ptr holder = - std::make_shared( - cuda_device_pointer, need_allocate_size, - platform::CUDAPlace(device_id)); - self_tensor->ResetHolderWithType(holder, self_tensor->dtype()); + tensor_uva(self_tensor, device_id); }, py::arg("device_id") = 0, py::return_value_policy::reference, R"DOC( Returns self tensor with the UVA(unified virtual addressing). 
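(Editor's note, not part of the patch: a hedged usage sketch of the UVA path this commit exposes in eager mode, mirroring test_tensor_uva.py below. _uva() performs no copy; it registers the tensor's existing CPU buffer with CUDA and rebinds it to the matching device pointer, so it requires a CUDA build.)

import numpy as np
import paddle

if paddle.fluid.core.is_compiled_with_cuda():
    t = paddle.to_tensor(np.random.random([10, 30]).astype('float32'),
                         place=paddle.CPUPlace())
    t._uva()                       # register host memory with CUDA, rebind holder
    assert t.place.is_gpu_place()  # same buffer, now device-addressable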
diff --git a/paddle/fluid/pybind/tensor_py.h b/paddle/fluid/pybind/tensor_py.h index bf459bd468421..3f7ce8b63f968 100644 --- a/paddle/fluid/pybind/tensor_py.h +++ b/paddle/fluid/pybind/tensor_py.h @@ -529,11 +529,10 @@ void SetTensorFromPyArray(framework::Tensor *self, const py::object &obj, } template -void SetUVATensorFromPyArray( - const std::shared_ptr &self, - const py::array_t &array, int device_id) { +void SetUVATensorFromPyArrayImpl(framework::LoDTensor *self_tensor, + const py::array_t &array, int device_id) { #if defined(PADDLE_WITH_CUDA) - auto *self_tensor = self->MutableVar()->GetMutable(); + VLOG(4) << "Running in SetUVATensorFromPyArrayImpl."; std::vector dims; dims.reserve(array.ndim()); int64_t numel = 1; @@ -562,6 +561,38 @@ void SetUVATensorFromPyArray( #endif } +template +void SetUVATensorFromPyArray( + const std::shared_ptr &self, + const py::array_t &array, int device_id) { +#if defined(PADDLE_WITH_CUDA) + VLOG(4) << "Running in SetUVATensorFromPyArray for VarBase."; + auto *self_tensor = self->MutableVar()->GetMutable(); + SetUVATensorFromPyArrayImpl(self_tensor, array, device_id); +#endif +} + +template +void SetUVATensorFromPyArray( + const std::shared_ptr &self, + const py::array_t &array, int device_id) { +#if defined(PADDLE_WITH_CUDA) + VLOG(4) << "Running in SetUVATensorFromPyArray for Phi::Tensor."; + phi::DenseTensorMeta meta = + phi::DenseTensorMeta(phi::DataType::FLOAT32, phi::make_ddim({1, 1})); + std::shared_ptr tmp_t = std::make_shared( + std::make_unique( + paddle::platform::CPUPlace()) + .get(), + meta); + self.get()->set_impl(tmp_t); + auto *self_tensor = + static_cast(self.get()->impl().get()); + + SetUVATensorFromPyArrayImpl(self_tensor, array, device_id); +#endif +} + template void _sliceCompute(const framework::Tensor *in, framework::Tensor *out, const platform::CPUDeviceContext &ctx, diff --git a/paddle/fluid/pybind/uva_utils.h b/paddle/fluid/pybind/uva_utils.h new file mode 100644 index 0000000000000..94f55769b7356 --- /dev/null +++ b/paddle/fluid/pybind/uva_utils.h @@ -0,0 +1,60 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
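+
+// (Editor's note, not part of the original patch.) tensor_uva() below takes
+// an existing CPU LoDTensor, registers its host buffer as CUDA host memory
+// via cudaHostRegister, fetches the matching device pointer with
+// cudaHostGetDevicePointer, and rebinds the tensor's allocation holder to
+// that pointer, making the same memory addressable from host and device.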
+ +#pragma once + +#include +#include "paddle/fluid/operators/utils.h" +#include "paddle/phi/common/data_type.h" +#include "paddle/phi/core/compat/convert_utils.h" +#include "paddle/phi/core/dense_tensor.h" + +namespace paddle { +namespace pybind { + +static void tensor_uva(paddle::framework::LoDTensor *self_tensor, + int device_id) { + VLOG(4) << "Running in _uva interface."; +#if defined(PADDLE_WITH_CUDA) + platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); + auto *dev_ctx = pool.Get(platform::CUDAPlace(device_id)); + VLOG(4) << "Init the DeviceContext, and the place is " << dev_ctx->GetPlace(); + // Register the cpu memory as the cuda host memory + const auto &data_numel = self_tensor->numel(); + const size_t &need_allocate_size = + data_numel * framework::SizeOfType( + framework::TransToProtoVarType(self_tensor->dtype())); + void *data_ptr = self_tensor->data(); + auto result = + cudaHostRegister(data_ptr, need_allocate_size, cudaHostRegisterDefault); + if (cudaSuccess != result) { + VLOG(4) << "UVA(unified virtual addressing) failed allocate:" + << need_allocate_size << ", the error code:" << result; + } + // Get device pointer from the function of cudaHostGetDevicePointer + void *cuda_device_pointer = nullptr; + cudaHostGetDevicePointer(reinterpret_cast(&cuda_device_pointer), + reinterpret_cast(data_ptr), 0); + + // Reset the memory with device pointer + std::shared_ptr holder = + std::make_shared( + cuda_device_pointer, need_allocate_size, + platform::CUDAPlace(device_id)); + self_tensor->ResetHolderWithType(holder, self_tensor->dtype()); +#endif +} + +} // namespace pybind +} // namespace paddle diff --git a/python/paddle/fluid/dygraph/varbase_patch_methods.py b/python/paddle/fluid/dygraph/varbase_patch_methods.py index f4871ba64e571..c97471d25f19c 100644 --- a/python/paddle/fluid/dygraph/varbase_patch_methods.py +++ b/python/paddle/fluid/dygraph/varbase_patch_methods.py @@ -816,6 +816,10 @@ def _slice(self, begin_idx, end_idx): def _numel(self): return self.get_tensor()._numel() + @framework.dygraph_only + def _uva(self, device_id=0): + self._tensor_uva(device_id) + @framework.dygraph_only def cpu(self): if self.place.is_cpu_place(): @@ -874,6 +878,7 @@ def pin_memory(self): setattr(core.eager.Tensor, "pin_memory", pin_memory) setattr(core.eager.Tensor, "_slice", _slice) setattr(core.eager.Tensor, "_numel", _numel) + setattr(core.eager.Tensor, "_uva", _uva) else: setattr(core.VarBase, "__name__", "Tensor") setattr(core.VarBase, "grad", grad) diff --git a/python/paddle/fluid/tests/unittests/test_tensor_uva.py b/python/paddle/fluid/tests/unittests/test_tensor_uva.py index c60d4d98d7154..4af04b8f6d41e 100644 --- a/python/paddle/fluid/tests/unittests/test_tensor_uva.py +++ b/python/paddle/fluid/tests/unittests/test_tensor_uva.py @@ -15,10 +15,12 @@ import paddle import unittest import numpy as np +from paddle.fluid import core +from paddle.fluid.framework import _test_eager_guard, _in_legacy_dygraph class TestTensorCopyFrom(unittest.TestCase): - def test_main(self): + def func_main(self): if paddle.fluid.core.is_compiled_with_cuda(): place = paddle.CPUPlace() np_value = np.random.random(size=[10, 30]).astype('float32') @@ -26,9 +28,14 @@ def test_main(self): tensor._uva() self.assertTrue(tensor.place.is_gpu_place()) + def test_main(self): + with _test_eager_guard(): + self.func_main() + self.func_main() + class TestUVATensorFromNumpy(unittest.TestCase): - def test_uva_tensor_creation(self): + def func_uva_tensor_creation(self): if 
paddle.fluid.core.is_compiled_with_cuda(): dtype_list = [ "int32", "int64", "float32", "float64", "float16", "int8", @@ -36,10 +43,18 @@ def test_uva_tensor_creation(self): ] for dtype in dtype_list: data = np.random.randint(10, size=[4, 5]).astype(dtype) - tensor = paddle.fluid.core.to_uva_tensor(data, 0) + if _in_legacy_dygraph(): + tensor = paddle.fluid.core.to_uva_tensor(data, 0) + else: + tensor = core.eager.to_uva_tensor(data, 0) self.assertTrue(tensor.place.is_gpu_place()) self.assertTrue(np.allclose(tensor.numpy(), data)) + def test_uva_tensor_creation(self): + with _test_eager_guard(): + self.func_uva_tensor_creation() + self.func_uva_tensor_creation() + if __name__ == "__main__": unittest.main() From 49e4e2f9afbba1cc46de1f6e17ff931930ca5b14 Mon Sep 17 00:00:00 2001 From: Weilong Wu Date: Mon, 4 Apr 2022 10:08:17 +0800 Subject: [PATCH 56/93] [Eager] Support rnn_decode switch to eager mode (#41333) --- paddle/fluid/pybind/op_function_generator.h | 1 + python/paddle/fluid/layers/sequence_lod.py | 18 +++++++++++++++++- .../tests/unittests/test_rnn_decode_api.py | 19 ++++++++++++++++--- 3 files changed, 34 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/pybind/op_function_generator.h b/paddle/fluid/pybind/op_function_generator.h index b8202fe8c51fd..ba4abc8d13536 100644 --- a/paddle/fluid/pybind/op_function_generator.h +++ b/paddle/fluid/pybind/op_function_generator.h @@ -106,6 +106,7 @@ std::map> op_ins_map = { {"linear_chain_crf", {"Emission", "Transition", "Label", "Length"}}, {"crf_decoding", {"Emission", "Transition", "Label", "Length"}}, {"chunk_eval", {"Inference", "Label", "SeqLength"}}, + {"sequence_mask", {"X", "MaxLenTensor"}}, {"graph_reindex", {"X", "Neighbors", "Count", "HashTable_Value", "HashTable_Index"}}, {"graph_sample_neighbors", {"Row", "Col_Ptr", "X", "Eids", "Perm_Buffer"}}, diff --git a/python/paddle/fluid/layers/sequence_lod.py b/python/paddle/fluid/layers/sequence_lod.py index 1aa3e357c4fd7..1758123f0e608 100644 --- a/python/paddle/fluid/layers/sequence_lod.py +++ b/python/paddle/fluid/layers/sequence_lod.py @@ -15,10 +15,11 @@ from __future__ import print_function from .layer_function_generator import templatedoc -from ..framework import Variable, _non_static_mode +from ..framework import core, Variable, _non_static_mode, in_dygraph_mode, _in_legacy_dygraph, convert_np_dtype_to_dtype_ from ..layer_helper import LayerHelper from ..data_feeder import check_variable_and_dtype, check_type, check_dtype from ..core import VarDesc +from paddle import _C_ops __all__ = [ 'sequence_conv', @@ -1380,6 +1381,21 @@ def sequence_mask(x, maxlen=None, dtype='int64', name=None): # [1 1 1 1 1 1 1 1 0 0]] """ + + if _non_static_mode(): + if not isinstance(dtype, core.VarDesc.VarType): + dtype = convert_np_dtype_to_dtype_(dtype) + if in_dygraph_mode(): + if maxlen is not None: + if isinstance(maxlen, core.eager.Tensor): + attrs = ('out_dtype', dtype) + out = _C_ops.sequence_mask(x, maxlen, *attrs) + else: + attrs = ('out_dtype', dtype, 'maxlen', maxlen) + out = _C_ops.sequence_mask(x, None, *attrs) + out.stop_gradient = True + return out + helper = LayerHelper('sequence_mask', **locals()) out = helper.create_variable_for_type_inference(dtype=dtype) diff --git a/python/paddle/fluid/tests/unittests/test_rnn_decode_api.py b/python/paddle/fluid/tests/unittests/test_rnn_decode_api.py index a0009a71b3ef7..bf848357e3195 100644 --- a/python/paddle/fluid/tests/unittests/test_rnn_decode_api.py +++ b/python/paddle/fluid/tests/unittests/test_rnn_decode_api.py @@ -31,7 +31,7 @@ from 
paddle.fluid.executor import Executor from paddle.fluid import framework - +from paddle.fluid.framework import _test_eager_guard paddle.enable_static() @@ -554,7 +554,7 @@ def test_beam_search_infer(self): }, fetch_list=[output])[0] - def test_dynamic_basic_decoder(self): + def func_dynamic_basic_decoder(self): paddle.disable_static() src = paddle.to_tensor(np.random.randint(8, size=(8, 4))) src_length = paddle.to_tensor(np.random.randint(8, size=(8))) @@ -562,6 +562,11 @@ def test_dynamic_basic_decoder(self): probs, samples, sample_length = model(src, src_length) paddle.enable_static() + def test_dynamic_basic_decoder(self): + with _test_eager_guard(): + self.func_dynamic_basic_decoder() + self.func_dynamic_basic_decoder() + class ModuleApiTest(unittest.TestCase): @classmethod @@ -708,9 +713,17 @@ def make_inputs(self): ] return inputs - def test_check_output(self): + def func_check_output(self): + self.setUp() + self.make_inputs() + self.make_inputs() self.check_output() + def test_check_output(self): + with _test_eager_guard(): + self.func_check_output() + self.func_check_output() + if __name__ == '__main__': unittest.main() From 0bcfc4747410a52e138e63cd5b1edb4062f3fa4b Mon Sep 17 00:00:00 2001 From: hong <43953930+phlrain@users.noreply.github.com> Date: Mon, 4 Apr 2022 10:49:20 +0800 Subject: [PATCH 57/93] fix eager gen opti bug (#41302) * fix eager gen opti bug * polish code * fix some bug * fix some bugs; --- .../final_state_generator/eager_gen.py | 19 ++++++++++++++++--- paddle/fluid/eager/utils.cc | 16 ---------------- paddle/fluid/eager/utils.h | 3 --- paddle/fluid/pybind/eager_utils.cc | 2 +- paddle/phi/api/include/tensor.h | 2 +- paddle/phi/api/lib/api_gen_utils.cc | 16 ---------------- paddle/phi/api/lib/api_gen_utils.h | 6 ------ python/paddle/utils/code_gen/api_base.py | 14 +++++++++----- 8 files changed, 27 insertions(+), 51 deletions(-) diff --git a/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py b/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py index 88688672b18b5..3a7e5fbcc0f86 100644 --- a/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py +++ b/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py @@ -359,6 +359,12 @@ class {} : public egr::GradNodeBase {{ if({}.initialized()) {}_optional = paddle::make_optional({}); """ +CREATE_RECOVER_OPTIONAL_TENSOR_TEMPLATE = \ +""" + paddle::optional {}_optional = paddle::none; + if( {}.impl() ) {}_optional = paddle::make_optional({}); +""" + ####################### ## Generator Helpers ## @@ -1248,11 +1254,18 @@ def GenerateNodeDefinition(self, grad_node_creation_str): name) is_optional = (name in self.optional_inputs) + tensor_wrapper_recover_str = f"{indent}auto {transformed_tensor_name} = egr::EagerUtils::RecoverTensorWrapper(&this->{tensor_wrapper_name}, this->shared_from_this());" if is_optional: - tensor_wrapper_recover_str = f"{indent}auto {transformed_tensor_name} = egr::EagerUtils::RecoverOptionalTensorWrapper(&this->{tensor_wrapper_name}, this->shared_from_this());" + tensor_wrapper_recover_str += "\n" + CREATE_RECOVER_OPTIONAL_TENSOR_TEMPLATE.format( + transformed_tensor_name, transformed_tensor_name, + transformed_tensor_name, transformed_tensor_name) + + grad_api_args[ + grad_api_position] = transformed_tensor_name + "_optional" + else: - tensor_wrapper_recover_str = f"{indent}auto {transformed_tensor_name} = egr::EagerUtils::RecoverTensorWrapper(&this->{tensor_wrapper_name}, this->shared_from_this());" - 
grad_api_args[grad_api_position] = transformed_tensor_name + grad_api_args[grad_api_position] = transformed_tensor_name + get_grad_in_args_list.append(tensor_wrapper_recover_str) # Grad Ins from grads diff --git a/paddle/fluid/eager/utils.cc b/paddle/fluid/eager/utils.cc index dfbc96a9db836..bcf4a4627bb76 100644 --- a/paddle/fluid/eager/utils.cc +++ b/paddle/fluid/eager/utils.cc @@ -364,22 +364,6 @@ paddle::experimental::Tensor EagerUtils::RecoverTensorWrapper( return tw->recover(grad_node); } -paddle::optional -EagerUtils::RecoverOptionalTensorWrapper( - TensorWrapper* tw, const std::shared_ptr& grad_node) { - PADDLE_ENFORCE_NOT_NULL( - tw, phi::errors::InvalidArgument("TensorWrapper in " - "RecoverOptionalTensorWrapper function " - "should not be null")); - auto tmp = tw->recover(grad_node); - - paddle::optional res{paddle::none}; - if (tmp.initialized()) { - res = tmp; - } - return res; -} - std::vector EagerUtils::RecoverTensorWrapper( std::vector* tw, const std::shared_ptr& grad_node) { diff --git a/paddle/fluid/eager/utils.h b/paddle/fluid/eager/utils.h index beb46d876c4a1..be534d4440561 100644 --- a/paddle/fluid/eager/utils.h +++ b/paddle/fluid/eager/utils.h @@ -179,9 +179,6 @@ class EagerUtils { static std::vector RecoverTensorWrapper( std::vector* tw, const std::shared_ptr& grad_node); - static paddle::optional - RecoverOptionalTensorWrapper(TensorWrapper* tw, - const std::shared_ptr& grad_node); // Intermidate needed remove this once we don't need legacy // Inner Method diff --git a/paddle/fluid/pybind/eager_utils.cc b/paddle/fluid/pybind/eager_utils.cc index e245362c50be5..bdc96e85e44ae 100644 --- a/paddle/fluid/pybind/eager_utils.cc +++ b/paddle/fluid/pybind/eager_utils.cc @@ -971,7 +971,7 @@ paddle::experimental::IntArray CastPyArg2IntArray(PyObject* obj, std::vector value = CastPyArg2Ints(obj, op_type, arg_pos); return paddle::experimental::IntArray(value); - } else if (type_name == "paddle.Tensor") { + } else if (type_name == "paddle.Tensor" || type_name == "Tensor") { paddle::experimental::Tensor& value = GetTensorFromPyObject( op_type, "" /*arg_name*/, obj, arg_pos, false /*dispensable*/); return paddle::experimental::IntArray(value); diff --git a/paddle/phi/api/include/tensor.h b/paddle/phi/api/include/tensor.h index 0a2e815be8411..3c5c1531c4a2d 100644 --- a/paddle/phi/api/include/tensor.h +++ b/paddle/phi/api/include/tensor.h @@ -567,7 +567,7 @@ class PADDLE_API Tensor final { * heterogeneous Tensor implementation, so that the API level can be unified * to one `Tensor`. */ - std::shared_ptr impl_; + std::shared_ptr impl_{nullptr}; /** * [ Why need abstract AbstractAutogradMeta here? 
] diff --git a/paddle/phi/api/lib/api_gen_utils.cc b/paddle/phi/api/lib/api_gen_utils.cc index 7cbb4344e81d7..732ecacde94d7 100644 --- a/paddle/phi/api/lib/api_gen_utils.cc +++ b/paddle/phi/api/lib/api_gen_utils.cc @@ -66,14 +66,6 @@ phi::MetaTensor MakeMetaTensor(const phi::DenseTensor& tensor) { return phi::MetaTensor(tensor); } -paddle::optional MakeMetaTensor( - const paddle::optional& tensor) { - if (tensor) { - return {phi::MetaTensor(*tensor)}; - } - return {paddle::none}; -} - std::vector MakeMetaTensor( const std::vector& tensors) { std::vector meta_tensors; @@ -88,14 +80,6 @@ phi::MetaTensor MakeMetaTensor(const phi::SelectedRows& tensor) { return phi::MetaTensor(tensor); } -paddle::optional MakeMetaTensor( - const paddle::optional& tensor) { - if (tensor) { - return {phi::MetaTensor(*tensor)}; - } - return {paddle::none}; -} - phi::MetaTensor MakeMetaTensor(const phi::StringTensor& tensor) { return phi::MetaTensor(tensor); } diff --git a/paddle/phi/api/lib/api_gen_utils.h b/paddle/phi/api/lib/api_gen_utils.h index 2a4c8417b5e6d..d7ecef61c5be3 100644 --- a/paddle/phi/api/lib/api_gen_utils.h +++ b/paddle/phi/api/lib/api_gen_utils.h @@ -50,17 +50,11 @@ std::shared_ptr TensorToStringTensor(const Tensor& tensor); phi::MetaTensor MakeMetaTensor(const phi::DenseTensor& tensor); -paddle::optional MakeMetaTensor( - const paddle::optional& tensor); - std::vector MakeMetaTensor( const std::vector& tensors); phi::MetaTensor MakeMetaTensor(const phi::SelectedRows& tensor); -paddle::optional MakeMetaTensor( - const paddle::optional& tensor); - phi::MetaTensor MakeMetaTensor(const phi::StringTensor& tensor); /* ------------------ for output ----------------------- */ diff --git a/python/paddle/utils/code_gen/api_base.py b/python/paddle/utils/code_gen/api_base.py index 14f22fced9230..c1a987d06ba39 100644 --- a/python/paddle/utils/code_gen/api_base.py +++ b/python/paddle/utils/code_gen/api_base.py @@ -480,11 +480,15 @@ def gene_infer_meta(self, kernel_output_names, code_indent) -> str: param_code = param_code + param + "_metas, " elif param in self.optional_vars: meta_tensor_code = meta_tensor_code + f""" -{code_indent} paddle::optional {PREFIX_TENSOR_NAME}meta_ref_{param}(paddle::none); -{code_indent} auto {PREFIX_TENSOR_NAME}meta_{param} = MakeMetaTensor({PREFIX_TENSOR_NAME}{param}); -{code_indent} if ({PREFIX_TENSOR_NAME}meta_{param}) {{ -{code_indent} {PREFIX_TENSOR_NAME}meta_ref_{param} = paddle::make_optional(*{PREFIX_TENSOR_NAME}meta_{param}); -{code_indent} }}""" +{code_indent} paddle::optional {PREFIX_TENSOR_NAME}meta_ref_{param} = paddle::none; +{code_indent} phi::DenseTensor dt; +{code_indent} phi::MetaTensor {PREFIX_TENSOR_NAME}meta_tmp_{param}(dt); +{code_indent} if ({PREFIX_TENSOR_NAME}{param}_ptr) {{ +{code_indent} {PREFIX_TENSOR_NAME}meta_tmp_{param}.set_dtype( {PREFIX_TENSOR_NAME}{param}_ptr->dtype() ); +{code_indent} {PREFIX_TENSOR_NAME}meta_tmp_{param}.set_dims( {PREFIX_TENSOR_NAME}{param}_ptr->dims() ); +{code_indent} {PREFIX_TENSOR_NAME}meta_tmp_{param}.set_layout( {PREFIX_TENSOR_NAME}{param}_ptr->layout() ); +{code_indent} {PREFIX_TENSOR_NAME}meta_ref_{param} = {PREFIX_TENSOR_NAME}meta_tmp_{param}; +{code_indent} }}\n""" param_code = param_code + f"{PREFIX_TENSOR_NAME}meta_ref_{param}, " else: From 119816f98b339e013ef16ea044aafb90517f2bfe Mon Sep 17 00:00:00 2001 From: chentianyu03 Date: Mon, 4 Apr 2022 11:33:28 +0800 Subject: [PATCH 58/93] [Yaml]Add concat grad yaml (#41365) * add concat_grad kernel * fix error * remove comment code * fix outs nullptr error * change to phi 
header * add concat_grad declare for standalone_executor_test * add concat_grad yaml * add concat api * fix test concat op error * fix test concat op error --- paddle/phi/api/lib/CMakeLists.txt | 2 +- paddle/phi/api/lib/api_custom_impl.cc | 66 +++++++++++++++++++ paddle/phi/api/lib/api_custom_impl.h | 4 ++ paddle/phi/infermeta/multiary.cc | 7 ++ paddle/phi/infermeta/multiary.h | 3 + python/paddle/fluid/layers/tensor.py | 10 ++- .../fluid/tests/unittests/test_concat_op.py | 26 +++++--- python/paddle/utils/code_gen/api.yaml | 1 + python/paddle/utils/code_gen/backward.yaml | 6 ++ 9 files changed, 115 insertions(+), 10 deletions(-) diff --git a/paddle/phi/api/lib/CMakeLists.txt b/paddle/phi/api/lib/CMakeLists.txt index af2533019156c..d4d8a0fa8a304 100644 --- a/paddle/phi/api/lib/CMakeLists.txt +++ b/paddle/phi/api/lib/CMakeLists.txt @@ -165,7 +165,7 @@ cc_library(context_pool SRCS context_pool.cc DEPS phi_context phi_enforce place) cc_library(kernel_dispatch SRCS kernel_dispatch.cc DEPS phi_tensor_raw phi_context kernel_factory context_pool) cc_library(api_gen_utils SRCS api_gen_utils.cc DEPS phi_tensor_raw selected_rows sparse_csr_tensor sparse_coo_tensor) cc_library(phi_data_transform SRCS data_transform.cc DEPS phi_tensor_raw transfer_layout_kernel cast_kernel data_device_transform) -cc_library(api_custom_impl SRCS api_custom_impl.cc DEPS phi_tensor_raw phi kernel_dispatch api_gen_utils phi_data_transform) +cc_library(api_custom_impl SRCS api_custom_impl.cc DEPS phi_tensor_raw phi kernel_dispatch api_gen_utils phi_data_transform backward_infermeta) cc_library(sparse_api_custom_impl SRCS sparse_api_custom_impl.cc DEPS phi_tensor_raw phi kernel_dispatch api_gen_utils phi_data_transform) cc_library(phi_function_api SRCS ${api_source_file} DEPS phi_tensor_raw phi kernel_dispatch api_gen_utils phi_data_transform api_custom_impl) diff --git a/paddle/phi/api/lib/api_custom_impl.cc b/paddle/phi/api/lib/api_custom_impl.cc index 3818572db0c20..ce49680586caa 100644 --- a/paddle/phi/api/lib/api_custom_impl.cc +++ b/paddle/phi/api/lib/api_custom_impl.cc @@ -21,6 +21,7 @@ limitations under the License. 
*/ #include "paddle/phi/core/compat/convert_utils.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/meta_tensor.h" +#include "paddle/phi/infermeta/backward.h" #include "paddle/phi/infermeta/binary.h" #include "paddle/phi/infermeta/multiary.h" #include "paddle/phi/infermeta/nullary.h" @@ -166,5 +167,70 @@ std::vector split_impl(const Tensor& x, return out; } +std::vector concat_grad_impl(const std::vector& x, + const Tensor& out_grad, + const Scalar& axis) { + auto kernel_key_set = ParseKernelKeyByInputArgs(out_grad); + auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey(); + + Backend kernel_backend = kernel_key.backend(); + DataLayout kernel_layout = kernel_key.layout(); + DataType kernel_data_type = kernel_key.dtype(); + + auto kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError( + "concat_grad", {kernel_backend, kernel_layout, kernel_data_type}); + VLOG(6) << "concat_grad API kernel key: [" << kernel_backend << ", " + << kernel_layout << ", " << kernel_data_type << "]"; + VLOG(6) << "concat_grad API kernel: " << kernel; + + auto* dev_ctx = GetDeviceContextByBackend(kernel_backend); + + // std::unique_ptr> + auto dense_x = PrepareData(x, kernel.InputAt(0), {}); + auto dense_out_grad = PrepareData(out_grad, kernel.InputAt(1), {}); + + // Calculate the number of out tensors + size_t out_number = x.size(); + std::vector x_grad; + auto dense_x_grad = SetKernelOutput(out_number, kernel_backend, &x_grad); + + std::vector meta_x; + meta_x.reserve(x.size()); + std::vector meta_x_ptrs; + meta_x_ptrs.reserve(x.size()); + for (const auto& t : *dense_x) { + meta_x.push_back(t); + meta_x_ptrs.push_back(&meta_x.back()); + } + + std::vector meta_x_grad; + meta_x_grad.reserve(x.size()); + std::vector meta_x_grad_ptrs; + meta_x_grad_ptrs.reserve(x.size()); + for (size_t i = 0; i < out_number; ++i) { + meta_x_grad.push_back(*dense_x_grad[i]); + meta_x_grad_ptrs.push_back(&meta_x_grad.back()); + } + + phi::UnchangedMultiInferMeta(meta_x_ptrs, meta_x_grad_ptrs); + + std::vector dense_x_ptr; + dense_x_ptr.reserve(x.size()); + for (const auto& t : *dense_x) { + dense_x_ptr.push_back(&t); + } + + using kernel_signature = void (*)(const platform::DeviceContext&, + const std::vector&, + const phi::DenseTensor&, + const phi::Scalar&, + std::vector); + auto* kernel_fn = kernel.GetVariadicKernelFn(); + (*kernel_fn)( + *dev_ctx, dense_x_ptr, *dense_out_grad, phi::Scalar(axis), dense_x_grad); + + return x_grad; +} + } // namespace experimental } // namespace paddle diff --git a/paddle/phi/api/lib/api_custom_impl.h b/paddle/phi/api/lib/api_custom_impl.h index f9a11b4bd9683..1f84eab10353d 100644 --- a/paddle/phi/api/lib/api_custom_impl.h +++ b/paddle/phi/api/lib/api_custom_impl.h @@ -31,5 +31,9 @@ std::vector split_impl(const Tensor& x, const IntArray& num_or_sections, const Scalar& axis); +std::vector concat_grad_impl(const std::vector& x, + const Tensor& out_grad, + const Scalar& axis); + } // namespace experimental } // namespace paddle diff --git a/paddle/phi/infermeta/multiary.cc b/paddle/phi/infermeta/multiary.cc index 42041af2dfe9e..76951669c66f2 100644 --- a/paddle/phi/infermeta/multiary.cc +++ b/paddle/phi/infermeta/multiary.cc @@ -1909,6 +1909,13 @@ void StackInferMeta(const std::vector& x, out->share_lod(*x.at(0)); } +void UnchangedMultiInferMeta(const std::vector& x, + std::vector out) { + for (size_t i = 0; i < x.size(); ++i) { + out[i]->share_meta(*x[i]); + } +} + void WarpctcInferMeta(const MetaTensor& logits, const MetaTensor& label, const paddle::optional 
logits_length, diff --git a/paddle/phi/infermeta/multiary.h b/paddle/phi/infermeta/multiary.h index 0b1ccfcb90541..c63960c7b9b79 100644 --- a/paddle/phi/infermeta/multiary.h +++ b/paddle/phi/infermeta/multiary.h @@ -289,6 +289,9 @@ void StackInferMeta(const std::vector& x, int axis, MetaTensor* out); +void UnchangedMultiInferMeta(const std::vector& x, + std::vector out); + void WarpctcInferMeta(const MetaTensor& logits, const MetaTensor& label, const paddle::optional logits_length, diff --git a/python/paddle/fluid/layers/tensor.py b/python/paddle/fluid/layers/tensor.py index b47ddd0dc9fc3..a49b4b79fbf0c 100644 --- a/python/paddle/fluid/layers/tensor.py +++ b/python/paddle/fluid/layers/tensor.py @@ -323,7 +323,15 @@ def concat(input, axis=0, name=None): # [14 15 16]] """ - if _non_static_mode(): + if in_dygraph_mode(): + if isinstance(axis, Variable): + axis = axis.numpy() + axis = axis.item(0) + if not isinstance(input, Variable): + input = [t for t in input if t.shape.count(0) == 0] + return _C_ops.final_state_concat(input, axis) + + if _in_legacy_dygraph(): if isinstance(axis, Variable): axis = axis.numpy() axis = axis.item(0) diff --git a/python/paddle/fluid/tests/unittests/test_concat_op.py b/python/paddle/fluid/tests/unittests/test_concat_op.py index 4feca1b92505b..629ddb31d7b62 100644 --- a/python/paddle/fluid/tests/unittests/test_concat_op.py +++ b/python/paddle/fluid/tests/unittests/test_concat_op.py @@ -19,6 +19,7 @@ from paddle.fluid.tests.unittests.op_test import OpTest, skip_check_grad_ci, convert_float_to_uint16 import paddle.fluid as fluid from paddle.fluid import compiler, Program, program_guard, core +from paddle.fluid.framework import _test_eager_guard import paddle @@ -49,7 +50,7 @@ def test_check_output(self): place = core.CUDAPlace(0) self.check_output_with_place(place) else: - self.check_output() + self.check_output(check_eager=True) def test_check_grad(self): if self.dtype == np.uint16: @@ -58,9 +59,9 @@ def test_check_grad(self): self.check_grad_with_place(place, ['x1'], 'Out') self.check_grad_with_place(place, ['x2'], 'Out') else: - self.check_grad(['x0'], 'Out') - self.check_grad(['x1'], 'Out') - self.check_grad(['x2'], 'Out') + self.check_grad(['x0'], 'Out', check_eager=True) + self.check_grad(['x1'], 'Out', check_eager=True) + self.check_grad(['x2'], 'Out', check_eager=True) def init_test_data(self): if self.dtype == np.uint16: @@ -124,6 +125,7 @@ class TestConcatOp6(TestConcatOp): def setUp(self): self.op_type = "concat" self.dtype = self.get_dtype() + self.python_api = paddle.concat self.init_test_data() self.lod = [[20, 80]] self.out_lod = [[20, 80, 20, 80, 20, 80]] @@ -141,12 +143,12 @@ def setUp(self): self.outputs = {'Out': (out, self.out_lod)} def test_check_output(self): - self.check_output(check_dygraph=False) + self.check_output(check_eager=True) def test_check_grad(self): - self.check_grad(['x0'], 'Out', check_dygraph=False) - self.check_grad(['x1'], 'Out', check_dygraph=False) - self.check_grad(['x2'], 'Out', check_dygraph=False) + self.check_grad(['x0'], 'Out', check_eager=True) + self.check_grad(['x1'], 'Out', check_eager=True) + self.check_grad(['x2'], 'Out', check_eager=True) def init_test_data(self): self.x0 = np.random.random([100]).astype(self.dtype) @@ -159,6 +161,7 @@ def create_test_AxisTensor(parent): class TestConcatAxisTensor(parent): def setUp(self): self.op_type = "concat" + self.python_api = paddle.concat self.dtype = self.get_dtype() self.init_test_data() @@ -334,6 +337,12 @@ def test_imperative(self): self.assertEqual((out1.numpy() 
== np_out1).all(), True) self.assertEqual((out2.numpy() == np_out2).all(), True) + def test_eager(self): + with _test_eager_guard(): + self.test_api() + self.test_fluid_api() + self.test_imperative() + def test_errors(self): with program_guard(Program(), Program()): # The item in input must be Variable. @@ -370,6 +379,7 @@ class TestConcatAPIWithLoDTensorArray(unittest.TestCase): def setUp(self): self.axis = 1 + self.python = paddle.concat self.iter_num = 3 self.input_shape = [2, 3] self.x = np.random.random(self.input_shape).astype("float32") diff --git a/python/paddle/utils/code_gen/api.yaml b/python/paddle/utils/code_gen/api.yaml index f38a9bc619eba..4f05f107bc2fc 100644 --- a/python/paddle/utils/code_gen/api.yaml +++ b/python/paddle/utils/code_gen/api.yaml @@ -320,6 +320,7 @@ param : [x, axis] kernel : func : concat + backward : concat_grad - api : conj args : (Tensor x) diff --git a/python/paddle/utils/code_gen/backward.yaml b/python/paddle/utils/code_gen/backward.yaml index 7b6c383286601..db1fe6cdf5220 100644 --- a/python/paddle/utils/code_gen/backward.yaml +++ b/python/paddle/utils/code_gen/backward.yaml @@ -179,6 +179,12 @@ kernel : func : cholesky_solve_grad +- backward_api : concat_grad + forward : concat (Tensor[] x, Scalar axis) -> Tensor(out) + args : (Tensor[] x, Tensor out_grad, Scalar axis = 0) + output : Tensor[](x_grad) + invoke : concat_grad_impl(x, out_grad, axis) + - backward_api : conv2d_transpose_grad forward : conv2d_transpose(Tensor x, Tensor filter, int[] strides, int[] paddings, int[] output_padding, int[] output_size, str padding_algorithm, int groups, int[] dilations, str data_format) -> Tensor(out) args : (Tensor x, Tensor filter, Tensor out_grad, int[] strides, int[] paddings, int[] output_padding, int[] output_size, str padding_algorithm, int groups, int[] dilations, str data_format) From 1c7001e731099061370447b1e1f0e1d0ba164742 Mon Sep 17 00:00:00 2001 From: hong <43953930+phlrain@users.noreply.github.com> Date: Mon, 4 Apr 2022 12:14:43 +0800 Subject: [PATCH 59/93] Add dropout yaml (#41355) * add dropout slice yaml * remove useless code * fix infer shape error * skip infrt compile for dropout --- paddle/fluid/framework/op_desc.cc | 11 +++- paddle/fluid/framework/op_desc.h | 2 +- paddle/fluid/operators/dropout_op.cc | 2 +- paddle/phi/infermeta/binary.cc | 20 +++++++ paddle/phi/infermeta/binary.h | 10 ++++ paddle/phi/infermeta/unary.cc | 57 +++++++++++++++---- paddle/phi/infermeta/unary.h | 11 +++- python/paddle/fluid/backward.py | 1 + python/paddle/fluid/layers/nn.py | 17 ++---- .../fluid/tests/unittests/test_dropout_op.py | 31 ++++++++++ .../fluid/tests/unittests/test_slice_op.py | 26 +++++++++ python/paddle/nn/functional/common.py | 10 +++- python/paddle/utils/code_gen/api.yaml | 19 +++++++ python/paddle/utils/code_gen/backward.yaml | 21 +++++++ tools/infrt/skipped_phi_api.json | 2 +- 15 files changed, 209 insertions(+), 31 deletions(-) diff --git a/paddle/fluid/framework/op_desc.cc b/paddle/fluid/framework/op_desc.cc index a02466c04e913..f31fefcfade89 100644 --- a/paddle/fluid/framework/op_desc.cc +++ b/paddle/fluid/framework/op_desc.cc @@ -777,10 +777,17 @@ void OpDesc::CheckAttrs() { checker->Check(&attrs_); } -void OpDesc::InferShape(const BlockDesc &block) const { +void OpDesc::InferShape(const BlockDesc &block) { try { VLOG(3) << "CompileTime infer shape on " << Type(); - auto &infer_shape = OpInfoMap::Instance().Get(this->Type()).infer_shape_; + auto &op_info = OpInfoMap::Instance().Get(this->Type()); + auto *checker = op_info.Checker(); + if 
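
As a quick check of the concat backward registered above, here is a minimal eager-mode sketch using only the public paddle API (tensors and shapes are illustrative, not taken from the tests):

import paddle

x = paddle.to_tensor([[1., 2.], [3., 4.]], stop_gradient=False)
y = paddle.to_tensor([[5., 6.], [7., 8.]], stop_gradient=False)
out = paddle.concat([x, y], axis=0)  # dispatched to final_state_concat in eager mode
out.sum().backward()                 # concat_grad splits the incoming gradient per input
print(x.grad.shape)                  # [2, 2]
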
(checker != nullptr) {
+      // set default value here
+      VLOG(10) << "begin to check attribute of " << Type();
+      checker->Check(&attrs_);
+    }
+    auto &infer_shape = op_info.infer_shape_;
     PADDLE_ENFORCE_EQ(
         static_cast<bool>(infer_shape), true,
         platform::errors::NotFound(
diff --git a/paddle/fluid/framework/op_desc.h b/paddle/fluid/framework/op_desc.h
index 82e15d40bee78..0afe6796dad7a 100644
--- a/paddle/fluid/framework/op_desc.h
+++ b/paddle/fluid/framework/op_desc.h
@@ -142,7 +142,7 @@ class OpDesc {
 
   void CheckAttrs();
 
-  void InferShape(const BlockDesc &block) const;
+  void InferShape(const BlockDesc &block);
 
   void InferVarType(BlockDesc *block) const;
 
diff --git a/paddle/fluid/operators/dropout_op.cc b/paddle/fluid/operators/dropout_op.cc
index 3d9950902acfe..8d033ea3194b9 100644
--- a/paddle/fluid/operators/dropout_op.cc
+++ b/paddle/fluid/operators/dropout_op.cc
@@ -16,7 +16,7 @@ limitations under the License. */
 #include
 #include "paddle/fluid/framework/infershape_utils.h"
 #include "paddle/fluid/framework/op_registry.h"
-#include "paddle/phi/infermeta/unary.h"
+#include "paddle/phi/infermeta/binary.h"
 
 namespace paddle {
 namespace operators {
diff --git a/paddle/phi/infermeta/binary.cc b/paddle/phi/infermeta/binary.cc
index ab13df081aa28..60db5d342b8b3 100644
--- a/paddle/phi/infermeta/binary.cc
+++ b/paddle/phi/infermeta/binary.cc
@@ -776,6 +776,26 @@ void DistInferMeta(const MetaTensor& x,
   out->set_dtype(x.dtype());
 }
 
+void DropoutInferMeta(const MetaTensor& x,
+                      paddle::optional<const MetaTensor&> seed_tensor,
+                      float p,
+                      bool is_test,
+                      const std::string& mode,
+                      int seed,
+                      bool fix_seed,
+                      MetaTensor* out,
+                      MetaTensor* mask) {
+  auto x_dims = x.dims();
+  out->set_dims(x_dims);
+  out->share_lod(x);
+  out->set_dtype(x.dtype());
+
+  if (mask != nullptr) {
+    mask->set_dims(x_dims);
+    mask->set_dtype(DataType::UINT8);
+  }
+}
+
 void DotInferMeta(const MetaTensor& x, const MetaTensor& y, MetaTensor* out) {
   auto x_dims = x.dims();
   auto x_rank = static_cast<size_t>(x_dims.size());
diff --git a/paddle/phi/infermeta/binary.h b/paddle/phi/infermeta/binary.h
index 3fcbf69c35e25..296c05756f291 100644
--- a/paddle/phi/infermeta/binary.h
+++ b/paddle/phi/infermeta/binary.h
@@ -124,6 +124,16 @@ void DistInferMeta(const MetaTensor& x,
 
 void DotInferMeta(const MetaTensor& x, const MetaTensor& y, MetaTensor* out);
 
+void DropoutInferMeta(const MetaTensor& x,
+                      paddle::optional<const MetaTensor&> seed_tensor,
+                      float p,
+                      bool is_test,
+                      const std::string& mode,
+                      int seed,
+                      bool fix_seed,
+                      MetaTensor* out,
+                      MetaTensor* mask);
+
 void ElementwiseInferMeta(const MetaTensor& x,
                           const MetaTensor& y,
                           MetaTensor* out);
diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc
index 36c192cbf2748..e0ea637074c20 100644
--- a/paddle/phi/infermeta/unary.cc
+++ b/paddle/phi/infermeta/unary.cc
@@ -24,6 +24,7 @@ limitations under the License.
 */
 #include "paddle/phi/core/infermeta_utils.h"
 #include "paddle/phi/kernels/funcs/parse_qr_mode.h"
 #include "paddle/phi/kernels/funcs/pooling.h"
+#include "paddle/phi/kernels/funcs/slice_utils.h"
 #include "paddle/phi/kernels/funcs/strided_slice.h"
 #include "paddle/phi/kernels/funcs/unfold_functor.h"
 #include "paddle/phi/kernels/funcs/unsqueeze.h"
@@ -360,17 +361,6 @@ void DiagonalInferMeta(const MetaTensor& input,
   out->set_dims(phi::make_ddim(out_dims));
 }
 
-void DropoutInferMeta(const MetaTensor& x, MetaTensor* out, MetaTensor* mask) {
-  auto x_dims = x.dims();
-  out->set_dims(x_dims);
-  out->share_lod(x);
-  out->set_dtype(x.dtype());
-
-  if (mask != nullptr) {
-    mask->set_dims(x_dims);
-  }
-}
-
 void EighInferMeta(const MetaTensor& x,
                    const std::string& uplo,
                    MetaTensor* out_w,
@@ -1738,6 +1728,51 @@ void SizeInferMeta(const MetaTensor& input, MetaTensor* out) {
   out->set_dims({1});
 }
 
+void SliceRawInferMeta(const MetaTensor& input,
+                       const std::vector<int64_t>& axes,
+                       const IntArray& starts_arr,
+                       const IntArray& ends_arr,
+                       const std::vector<int64_t>& infer_flags_t,
+                       const std::vector<int64_t>& decrease_axis,
+                       MetaTensor* out,
+                       MetaConfig config) {
+  auto in_dims = input.dims();
+  PADDLE_ENFORCE_LT(
+      in_dims.size(),
+      7,
+      phi::errors::InvalidArgument("The rank of input should be less than 7."));
+  DDim out_dims(in_dims);
+
+  std::vector<int64_t> infer_flags = infer_flags_t;
+  if (infer_flags.empty()) {
+    // Initialize infer_flags with 1.
+    // To be compatible with other op tests in which infer_flags is not set.
+    infer_flags = std::vector<int64_t>(axes.size(), 1);
+  }
+
+  // 2.1 Check attrs.
+  std::vector<int64_t> starts = starts_arr.GetData();
+  std::vector<int64_t> ends = ends_arr.GetData();
+
+  phi::funcs::CheckAndUpdateSliceAttrs(
+      in_dims, axes, &starts, &ends, nullptr, &infer_flags);
+
+  auto slice_dims = phi::funcs::GetSliceDims(
+      in_dims, axes, starts, ends, nullptr, &infer_flags);
+  if (config.is_runtime) {
+    out_dims = phi::funcs::GetDecreasedDims(
+        slice_dims, decrease_axis, &infer_flags);
+  } else {
+    out_dims = phi::funcs::GetDecreasedDims(
+        slice_dims, decrease_axis, nullptr);
+  }
+
+  out->set_dims(out_dims);
+  if (axes.size() > 0 && axes[0] != 0) {
+    out->share_lod(input);
+  }
+}
+
 void SoftmaxInferMeta(const MetaTensor& x, int axis, MetaTensor* out) {
   auto dim_x = x.dims();
   auto rank_x = dim_x.size();
diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h
index bda9c83fce1f2..5106c6f448733 100644
--- a/paddle/phi/infermeta/unary.h
+++ b/paddle/phi/infermeta/unary.h
@@ -80,8 +80,6 @@ void DiagInferMeta(const MetaTensor& x,
 void DiagonalInferMeta(
     const MetaTensor& input, int offset, int axis1, int axis2, MetaTensor* out);
 
-void DropoutInferMeta(const MetaTensor& x, MetaTensor* out, MetaTensor* mask);
-
 void EighInferMeta(const MetaTensor& x,
                    const std::string& uplo,
                    MetaTensor* out_w,
@@ -271,6 +269,15 @@ void ShardIndexInferMeta(const MetaTensor& in,
 
 void SizeInferMeta(const MetaTensor& input, MetaTensor* out);
 
+void SliceRawInferMeta(const MetaTensor& input,
+                       const std::vector<int64_t>& axes,
+                       const IntArray& starts,
+                       const IntArray& ends,
+                       const std::vector<int64_t>& infer_flags,
+                       const std::vector<int64_t>& decrease_axis,
+                       MetaTensor* out,
+                       MetaConfig config = MetaConfig());
+
 void SoftmaxInferMeta(const MetaTensor& x, int axis, MetaTensor* out);
 
 void SplitInferMeta(const MetaTensor& x_meta,
diff --git a/python/paddle/fluid/backward.py b/python/paddle/fluid/backward.py
index 0988f6709552b..ba7692b442f82 100755
--- a/python/paddle/fluid/backward.py
+++ b/python/paddle/fluid/backward.py
@@ -1337,6 +1337,7 @@ def
_append_backward_vars_(block, start_op_idx, grad_to_var, grad_info_map): continue grad_info_map[grad_to_var[grad_var_name]] = (grad_var_name, block) # infer_shape and infer_type + op_desc.check_attrs() op_desc.infer_var_type(block.desc) op_desc.infer_shape(block.desc) diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index d7ec3276d8b79..9f971faed3435 100755 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -5141,7 +5141,6 @@ def l2_normalize(x, axis, epsilon=1e-12, name=None): # [-0.33972208 -0.43014923 0.31772556 0.76617881 -0.10761525]] """ - if len(x.shape) == 1: axis = 0 if _non_static_mode(): @@ -11199,18 +11198,15 @@ def slice(input, axes, starts, ends): infer_flags = list(1 for i in range(len(axes))) tmp_tensor_type = core.eager.Tensor - if isinstance(starts, (list, tuple)): starts = [ item.numpy().item(0) if isinstance(item, tmp_tensor_type) else item for item in starts ] - attrs += ('starts', starts) elif isinstance(starts, tmp_tensor_type): - starts_tensor = starts - starts.stop_gradient = True - infer_flags = list(-1 for i in range(len(axes))) + tensor_t = starts.numpy() + starts = [ele for ele in tensor_t] if isinstance(ends, (list, tuple)): ends = [ @@ -11219,12 +11215,11 @@ def slice(input, axes, starts, ends): ] attrs += ('ends', ends) elif isinstance(ends, tmp_tensor_type): - ends_tensor = ends - ends_tensor.stop_gradient = True - infer_flags = list(-1 for i in range(len(axes))) + tensor_t = ends.numpy() + ends = [ele for ele in tensor_t] - return _C_ops.slice(input, starts_tensor, ends_tensor, None, None, - 'axes', axes, 'infer_flags', infer_flags, *attrs) + return _C_ops.final_state_slice(input, axes, starts, ends, infer_flags, + []) else: if _in_legacy_dygraph(): attrs = () diff --git a/python/paddle/fluid/tests/unittests/test_dropout_op.py b/python/paddle/fluid/tests/unittests/test_dropout_op.py index 09712005d4125..d8a4eb8f45f7d 100644 --- a/python/paddle/fluid/tests/unittests/test_dropout_op.py +++ b/python/paddle/fluid/tests/unittests/test_dropout_op.py @@ -22,8 +22,11 @@ import paddle.static as static import paddle.fluid as fluid from paddle.fluid import Program, program_guard +from paddle.fluid.framework import _test_eager_guard import os +from paddle import _C_ops + class TestDropoutOp(OpTest): def setUp(self): @@ -960,6 +963,19 @@ def test_backward_downscale_in_infer(self): np.array_equal(input.gradient( ), self.cal_grad_downscale_in_infer(mask.numpy()))) + def test_backward_downscale_in_infer_eager(self): + for place in self.places: + with fluid.dygraph.guard(place): + with _test_eager_guard(): + input = paddle.uniform([40, 40], dtype="float32") + input.stop_gradient = False + out, mask = _C_ops.final_state_dropout( + input, None, 0.5, False, "downgrade_in_infer", 0, False) + out.backward() + self.assertTrue( + np.array_equal(input.gradient( + ), self.cal_grad_downscale_in_infer(mask.numpy()))) + def test_backward_upscale_train(self): for place in self.places: with fluid.dygraph.guard(place): @@ -976,6 +992,21 @@ def test_backward_upscale_train(self): np.allclose(input.gradient( ), self.cal_grad_upscale_train(mask.numpy(), prob))) + def test_backward_upscale_train_eager(self): + for place in self.places: + with fluid.dygraph.guard(place): + with _test_eager_guard(): + prob = 0.5 + input = paddle.uniform([40, 40], dtype="float32") + input.stop_gradient = False + out, mask = _C_ops.final_state_dropout( + input, None, 0.5, False, "upscale_in_train", 0, False) + out.backward() + + self.assertTrue( + 
np.allclose(input.gradient( + ), self.cal_grad_upscale_train(mask.numpy(), prob))) + def test_backward_upscale_train_2(self): for place in self.places: with fluid.dygraph.guard(place): diff --git a/python/paddle/fluid/tests/unittests/test_slice_op.py b/python/paddle/fluid/tests/unittests/test_slice_op.py index 71869b96aedf0..a565bba304184 100644 --- a/python/paddle/fluid/tests/unittests/test_slice_op.py +++ b/python/paddle/fluid/tests/unittests/test_slice_op.py @@ -21,6 +21,7 @@ import paddle.fluid as fluid import paddle.fluid.layers as layers import paddle +from paddle.fluid.framework import _test_eager_guard paddle.enable_static() @@ -599,6 +600,31 @@ def test_bool_tensor(self): self.assertTrue(np.array_equal(y_paddle.numpy(), y_np)) +class TestSliceApiEager(unittest.TestCase): + def test_slice_api(self): + with paddle.fluid.dygraph.guard(): + with _test_eager_guard(): + a = paddle.rand(shape=[4, 5, 6], dtype='float32') + a.stop_gradient = False + axes = [0, 1, 2] + starts = [-3, 0, 2] + ends = [3, 2, 4] + a_1 = paddle.slice(a, axes=axes, starts=starts, ends=ends) + + a_2 = paddle.slice( + a, + axes=axes, + starts=paddle.to_tensor(starts), + ends=paddle.to_tensor(ends)) + + a_1.backward() + grad_truth = paddle.zeros_like(a) + grad_truth[-3:3, 0:2, 2:4] = 1 + self.assertTrue(np.array_equal(grad_truth, a.gradient())) + + self.assertTrue(np.allclose(a_1.numpy(), a[-3:3, 0:2, 2:4])) + + class TestSliceApiWithLoDTensorArray(unittest.TestCase): def setUp(self): self.shape = (3, 4) diff --git a/python/paddle/nn/functional/common.py b/python/paddle/nn/functional/common.py index 131d31aa02405..74df8f6ed5c34 100644 --- a/python/paddle/nn/functional/common.py +++ b/python/paddle/nn/functional/common.py @@ -28,7 +28,7 @@ from ...tensor import sum from ...tensor import sqrt from ...fluid.data_feeder import check_variable_and_dtype, check_dtype -from ...fluid.framework import _varbase_creator, _in_legacy_dygraph, in_dygraph_mode +from ...fluid.framework import _varbase_creator, _in_legacy_dygraph, in_dygraph_mode, _non_static_mode from ...fluid import dygraph_utils from ...fluid import layers @@ -895,9 +895,15 @@ def dropout(x, seed = None mode = 'downgrade_in_infer' if mode == 'downscale_in_infer' else mode #semantic transfer - if in_dynamic_mode(): + if _non_static_mode(): if default_main_program().random_seed != 0: seed = default_main_program().random_seed + + if in_dygraph_mode(): + out, mask = _C_ops.final_state_dropout( x, None, p, not training, mode, \ + seed if seed is not None else 0, seed is not None) + + return out out, mask = _C_ops.dropout( x, 'dropout_prob', p, 'is_test', not training, 'fix_seed', seed is not None, 'seed', seed diff --git a/python/paddle/utils/code_gen/api.yaml b/python/paddle/utils/code_gen/api.yaml index 4f05f107bc2fc..2b0c562dbf9bd 100644 --- a/python/paddle/utils/code_gen/api.yaml +++ b/python/paddle/utils/code_gen/api.yaml @@ -463,6 +463,16 @@ kernel : func : dot +- api : dropout + args : (Tensor x, Tensor seed_tensor, float p, bool is_test, str mode, int seed, bool fix_seed) + output : Tensor(out), Tensor(mask) + infer_meta : + func : DropoutInferMeta + kernel : + func : dropout + optional : seed_tensor + backward : dropout_grad + # eigh - api : eigh args : (Tensor x, str uplo) @@ -1504,6 +1514,15 @@ kernel : func : size +- api : slice + args : (Tensor input, int64_t[] axes, IntArray starts, IntArray ends, int64_t[] infer_flags, int64_t[] decrease_axis) + output : Tensor + infer_meta : + func : SliceRawInferMeta + kernel : + func : slice + backward : slice_grad + # 
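
For the two final-state APIs added to api.yaml above, a hedged dygraph sketch (public API only; the shapes are illustrative):

import paddle

x = paddle.rand([4, 5, 6])
x.stop_gradient = False
y = paddle.nn.functional.dropout(x, p=0.5)                    # eager: final_state_dropout
z = paddle.slice(x, axes=[0, 1], starts=[0, 1], ends=[2, 3])  # eager: final_state_slice
(y.sum() + z.sum()).backward()
print(x.grad.shape)  # [4, 5, 6]
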
# soft_shrink
- api : soft_shrink
  args : (Tensor x, float lambda)
diff --git a/python/paddle/utils/code_gen/backward.yaml b/python/paddle/utils/code_gen/backward.yaml
index db1fe6cdf5220..cbcfc02ea0992 100644
--- a/python/paddle/utils/code_gen/backward.yaml
+++ b/python/paddle/utils/code_gen/backward.yaml
@@ -301,6 +301,17 @@
   kernel :
     func : divide_grad
 
+- backward_api : dropout_grad
+  forward : dropout (Tensor x, Tensor seed_tensor, float p, bool is_test, str mode, int seed, bool fix_seed) -> Tensor(out), Tensor(mask)
+  args : (Tensor mask, Tensor out_grad, float p, bool is_test, str mode)
+  output : Tensor(x_grad)
+  infer_meta :
+    func : UnchangedInferMeta
+    param : [out_grad]
+  kernel :
+    func : dropout_grad
+  optional : seed_tensor
+
 - backward_api : eigh_grad
   forward : eigh (Tensor x, str uplo) -> Tensor(out_w), Tensor(out_v)
   args : (Tensor out_w, Tensor out_v, Tensor out_w_grad, Tensor out_v_grad)
@@ -1054,6 +1065,16 @@
   kernel :
     func : sinh_grad
 
+- backward_api : slice_grad
+  forward : slice (Tensor input, int64_t[] axes, IntArray starts, IntArray ends, int64_t[] infer_flags, int64_t[] decrease_axis) -> Tensor(out)
+  args : (Tensor input, Tensor out_grad, int64_t[] axes, IntArray starts, IntArray ends, int64_t[] infer_flags, int64_t[] decrease_axis)
+  output : Tensor(input_grad)
+  infer_meta :
+    func : UnchangedInferMeta
+    param : [input]
+  kernel :
+    func : slice_grad
+
 - backward_api : soft_shrink_grad
   forward : soft_shrink (Tensor x, float lambda) -> Tensor(out)
   args : (Tensor x, Tensor out_grad, float lambda)
diff --git a/tools/infrt/skipped_phi_api.json b/tools/infrt/skipped_phi_api.json
index eef57a2d6b7bc..74cb6fb0e5356 100644
--- a/tools/infrt/skipped_phi_api.json
+++ b/tools/infrt/skipped_phi_api.json
@@ -1,4 +1,4 @@
 {
-"phi_apis":["conj", "nll_loss", "flatten"],
+"phi_apis":["conj", "nll_loss", "dropout", "flatten"],
 "phi_kernels":["equal_all"]
 }

From c02eeb969c15a7276e2e6ea1b651d4dff3e41973 Mon Sep 17 00:00:00 2001
From: Weilong Wu
Date: Mon, 4 Apr 2022 15:49:41 +0800
Subject: [PATCH 60/93] Updated uva related code (#41391)

---
 paddle/fluid/pybind/eager_method.cc               |  4 ++++
 .../paddle/fluid/dygraph/varbase_patch_methods.py | 15 +++++++++++++++
 2 files changed, 19 insertions(+)

diff --git a/paddle/fluid/pybind/eager_method.cc b/paddle/fluid/pybind/eager_method.cc
index 66fba92f67b83..1a7eb629a0eaa 100644
--- a/paddle/fluid/pybind/eager_method.cc
+++ b/paddle/fluid/pybind/eager_method.cc
@@ -1349,6 +1349,10 @@ static PyObject* tensor_method__uva(TensorObject* self, PyObject* args,
                                     PyObject* kwargs) {
   EAGER_TRY
   VLOG(4) << "Running in tensor_method__uva.";
+  PADDLE_ENFORCE_EQ(self->tensor.is_dense_tensor(), true,
+                    platform::errors::InvalidArgument(
+                        "Unified virtual addressing only support "
+                        "DenseTensor currently."));
   PADDLE_ENFORCE_EQ(platform::is_cpu_place(self->tensor.inner_place()), true,
                     platform::errors::InvalidArgument(
                         "Unified virtual addressing only support "
diff --git a/python/paddle/fluid/dygraph/varbase_patch_methods.py b/python/paddle/fluid/dygraph/varbase_patch_methods.py
index c97471d25f19c..bd1ca1aa26dda 100644
--- a/python/paddle/fluid/dygraph/varbase_patch_methods.py
+++ b/python/paddle/fluid/dygraph/varbase_patch_methods.py
@@ -818,6 +818,21 @@ def _numel(self):
 
     @framework.dygraph_only
     def _uva(self, device_id=0):
+        '''
+        Returns self tensor with UVA (unified virtual addressing).
+
+        Args:
+            device_id(int, optional): The destination GPU device id. Default: 0, the current device.
+
+        Examples:
+            ..
code-block:: python + + # required: gpu + import paddle + x = paddle.to_tensor([1, 2, 3], place=paddle.CPUPlace()) + x._uva() + print(x) + ''' self._tensor_uva(device_id) @framework.dygraph_only From 5b8c5b7bc0fbf0a0e8a70442eefd7432011dfbf5 Mon Sep 17 00:00:00 2001 From: 0x45f <23097963+0x45f@users.noreply.github.com> Date: Mon, 4 Apr 2022 15:51:11 +0800 Subject: [PATCH 61/93] Fix some PaddleTest UT (#41373) * Fix some PaddleTest UT * refine code * set default value --- python/paddle/tensor/logic.py | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/python/paddle/tensor/logic.py b/python/paddle/tensor/logic.py index 3896fa535ff22..a4ff87246631a 100755 --- a/python/paddle/tensor/logic.py +++ b/python/paddle/tensor/logic.py @@ -182,7 +182,8 @@ def equal(x, y, name=None): y = full(shape=[1], dtype=x.dtype, fill_value=y) if in_dygraph_mode(): - return _C_ops.final_state_equal(x, y) + axis = -1 + return _C_ops.final_state_equal(x, y, axis) else: if _in_legacy_dygraph(): return _C_ops.equal(x, y) @@ -231,7 +232,8 @@ def greater_equal(x, y, name=None): print(result1) # result1 = [True False True] """ if in_dygraph_mode(): - return _C_ops.final_state_greater_equal(x, y) + axis = -1 + return _C_ops.final_state_greater_equal(x, y, axis) else: if _in_legacy_dygraph(): return _C_ops.greater_equal(x, y) @@ -331,7 +333,8 @@ def less_equal(x, y, name=None): print(result1) # result1 = [True True False] """ if in_dygraph_mode(): - return _C_ops.final_state_less_equal(x, y) + axis = -1 + return _C_ops.final_state_less_equal(x, y, axis) else: if _in_legacy_dygraph(): return _C_ops.less_equal(x, y) @@ -381,7 +384,8 @@ def less_than(x, y, name=None): print(result1) # result1 = [False True False] """ if in_dygraph_mode(): - return _C_ops.final_state_less_than(x, y) + axis = -1 + return _C_ops.final_state_less_than(x, y, axis) else: if _in_legacy_dygraph(): return _C_ops.less_than(x, y) @@ -431,7 +435,8 @@ def not_equal(x, y, name=None): print(result1) # result1 = [False True True] """ if in_dygraph_mode(): - return _C_ops.final_state_not_equal(x, y) + axis = -1 + return _C_ops.final_state_not_equal(x, y, axis) else: if _in_legacy_dygraph(): return _C_ops.not_equal(x, y) @@ -538,7 +543,7 @@ def bitwise_and(x, y, out=None, name=None): res = paddle.bitwise_and(x, y) print(res) # [0, 2, 1] """ - if in_dygraph_mode() and out == None: + if in_dygraph_mode() and out is None: return _C_ops.final_state_bitwise_and(x, y) return _bitwise_op( op_name="bitwise_and", x=x, y=y, name=name, out=out, binary_op=True) @@ -566,7 +571,7 @@ def bitwise_or(x, y, out=None, name=None): res = paddle.bitwise_or(x, y) print(res) # [-1, -1, -3] """ - if in_dygraph_mode() and out == None: + if in_dygraph_mode() and out is None: return _C_ops.final_state_bitwise_or(x, y) return _bitwise_op( @@ -595,7 +600,7 @@ def bitwise_xor(x, y, out=None, name=None): res = paddle.bitwise_xor(x, y) print(res) # [-1, -3, -4] """ - if in_dygraph_mode() and out == None: + if in_dygraph_mode() and out is None: return _C_ops.final_state_bitwise_xor(x, y) return _bitwise_op( op_name="bitwise_xor", x=x, y=y, name=name, out=out, binary_op=True) @@ -621,7 +626,7 @@ def bitwise_not(x, out=None, name=None): res = paddle.bitwise_not(x) print(res) # [4, 0, -2] """ - if in_dygraph_mode() and out == None: + if in_dygraph_mode() and out is None: return _C_ops.final_state_bitwise_not(x) return _bitwise_op( From 75a17cdb29b1e3c5f307369d81cbf0ccf8e04a3d Mon Sep 17 00:00:00 2001 From: Zhanlue Yang Date: Mon, 4 Apr 2022 15:56:12 +0800 
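
The eager branches above now route the comparisons through their final-state kernels with a fixed axis of -1, and the bitwise ops guard on `out is None`. A small usage sketch, assuming only the public API (values illustrative):

import paddle

x = paddle.to_tensor([1, 2, 3])
y = paddle.to_tensor([1, 3, 2])
print(paddle.equal(x, y))        # [True, False, False]
print(paddle.less_equal(x, y))   # [True, True, False]
print(paddle.bitwise_and(x, y))  # [1, 2, 2]
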
Subject: [PATCH 62/93] Skip DoubleGrad-related unit tests under eager mode (#41380) --- .../test_autograd_functional_dynamic.py | 205 +++++++++++++----- ...perative_star_gan_with_gradient_penalty.py | 7 +- .../unittests/test_imperative_triple_grad.py | 16 +- 3 files changed, 168 insertions(+), 60 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/autograd/test_autograd_functional_dynamic.py b/python/paddle/fluid/tests/unittests/autograd/test_autograd_functional_dynamic.py index e46c532eb05db..8c725fe24e59c 100644 --- a/python/paddle/fluid/tests/unittests/autograd/test_autograd_functional_dynamic.py +++ b/python/paddle/fluid/tests/unittests/autograd/test_autograd_functional_dynamic.py @@ -21,6 +21,7 @@ import paddle.compat as cpt import paddle.nn.functional as F from paddle.autograd.functional import _as_tensors +from paddle.fluid.framework import _test_eager_guard, _in_legacy_dygraph, _in_eager_without_dygraph_check import config import utils @@ -145,7 +146,7 @@ def check_results(self, ref, res): class TestVJP(TestAutogradFunctional): - def test_vjp_i1o1(self): + def func_vjp_i1o1(self): test_cases = [ [reduce, 'A'], # noqa [reduce_dim, 'A'], # noqa @@ -155,7 +156,7 @@ def test_vjp_i1o1(self): vjp_result, grad_result = vjp(), grad() self.check_results(grad_result, vjp_result) - def test_vjp_i2o1(self): + def func_vjp_i2o1(self): test_cases = [ [matmul, ['A', 'B']], # noqa [mul, ['b', 'c']], # noqa @@ -165,7 +166,7 @@ def test_vjp_i2o1(self): vjp_result, grad_result = vjp(), grad() self.check_results(grad_result, vjp_result) - def test_vjp_i2o2(self): + def func_vjp_i2o2(self): test_cases = [ [o2, ['A', 'A']], # noqa ] # noqa @@ -176,7 +177,7 @@ def test_vjp_i2o2(self): vjp_result, grad_result = vjp(), grad() self.check_results(grad_result, vjp_result) - def test_vjp_i2o2_omitting_v(self): + def func_vjp_i2o2_omitting_v(self): test_cases = [ [o2, ['A', 'A']], # noqa ] # noqa @@ -186,7 +187,7 @@ def test_vjp_i2o2_omitting_v(self): vjp_result, grad_result = vjp(), grad() self.check_results(grad_result, vjp_result) - def test_vjp_nested(self): + def func_vjp_nested(self): x = self.gen_input('a') test_cases = [ [nested(x), 'a'], # noqa @@ -196,13 +197,22 @@ def test_vjp_nested(self): vjp_result, grad_result = vjp(), grad() self.check_results(grad_result, vjp_result) - def test_vjp_aliased_input(self): + def func_vjp_aliased_input(self): x = self.gen_input('a') ref = self.gen_test_pairs(nested(x), 'a')[0] aliased = self.gen_test_pairs(nested(x), x)[0] ref_result, aliased_result = ref(), aliased() self.check_results(ref_result, aliased_result) + def test_all_cases(self): + if _in_legacy_dygraph(): + self.func_vjp_i1o1() + self.func_vjp_i2o1() + self.func_vjp_i2o2() + self.func_vjp_i2o2_omitting_v() + self.func_vjp_nested() + self.func_vjp_aliased_input() + @utils.place(config.DEVICES) @utils.parameterize( @@ -210,12 +220,16 @@ def test_vjp_aliased_input(self): ('v_shape_not_equal_ys', utils.square, np.random.rand(3), np.random.rand(1), RuntimeError), )) class TestVJPException(unittest.TestCase): - def test_vjp(self): + def func_vjp(self): with self.assertRaises(self.expected_exception): paddle.autograd.vjp(self.fun, paddle.to_tensor(self.xs), paddle.to_tensor(self.v)) + def test_all_cases(self): + if _in_legacy_dygraph(): + self.func_vjp() + def jac(grad_fn, f, inputs): assert grad_fn in [paddle.autograd.vjp, paddle.autograd.jvp] @@ -246,7 +260,7 @@ def jac(grad_fn, f, inputs): class TestJVP(TestAutogradFunctional): - def test_jvp_i1o1(self): + def func_jvp_i1o1(self): test_cases = [ 
[reduce, 'A'], # noqa [reduce_dim, 'A'], # noqa @@ -257,7 +271,7 @@ def test_jvp_i1o1(self): reverse_jac = jac(paddle.autograd.vjp, f, inputs) self.check_results(forward_jac, reverse_jac) - def test_jvp_i2o1(self): + def func_jvp_i2o1(self): test_cases = [ # noqa [matmul, ['A', 'B']], # noqa ] # noqa @@ -267,7 +281,7 @@ def test_jvp_i2o1(self): reverse_jac = jac(paddle.autograd.vjp, f, inputs) self.check_results(forward_jac, reverse_jac) - def test_jvp_i2o2(self): + def func_jvp_i2o2(self): test_cases = [ # noqa [o2, ['A', 'A']], # noqa ] # noqa @@ -277,7 +291,7 @@ def test_jvp_i2o2(self): reverse_jac = jac(paddle.autograd.vjp, f, inputs) self.check_results(forward_jac, reverse_jac) - def test_jvp_i2o2_omitting_v(self): + def func_jvp_i2o2_omitting_v(self): test_cases = [ # noqa [o2, ['A', 'A']], # noqa ] # noqa @@ -288,6 +302,13 @@ def test_jvp_i2o2_omitting_v(self): results_with_v = paddle.autograd.jvp(f, inputs, v) self.check_results(results_omitting_v, results_with_v) + def test_all_cases(self): + if _in_legacy_dygraph(): + self.func_jvp_i1o1() + self.func_jvp_i2o1() + self.func_jvp_i2o2() + self.func_jvp_i2o2_omitting_v() + @utils.place(config.DEVICES) @utils.parameterize((utils.TEST_CASE_NAME, 'func', 'xs'), ( @@ -312,7 +333,7 @@ def setUp(self): self._actual = paddle.autograd.Jacobian(self.func, self.xs, False) self._expected = self._expected() - def test_jacobian(self): + def func_jacobian(self): Index = collections.namedtuple('Index', ('type', 'value')) indexes = (Index('all', (slice(0, None, None), slice(0, None, None))), Index('row', (0, slice(0, None, None))), @@ -333,6 +354,10 @@ def _expected(self): self._dtype) return utils._np_concat_matrix_sequence(jac, utils.MatrixFormat.NM) + def test_all_cases(self): + if _in_legacy_dygraph(): + self.func_jacobian() + @utils.place(config.DEVICES) @utils.parameterize((utils.TEST_CASE_NAME, 'func', 'xs'), ( @@ -355,7 +380,7 @@ def setUp(self): self._actual = paddle.autograd.Jacobian(self.func, self.xs, True) self._expected = self._expected() - def test_jacobian(self): + def func_jacobian(self): Index = collections.namedtuple('Index', ('type', 'value')) indexes = ( Index('all', (slice(0, None, None), slice(0, None, None), @@ -384,6 +409,10 @@ def _expected(self): return utils._np_transpose_matrix_format(jac, utils.MatrixFormat.NBM, utils.MatrixFormat.BNM) + def test_all_cases(self): + if _in_legacy_dygraph(): + self.func_jacobian() + class TestHessianClassNoBatch(unittest.TestCase): @classmethod @@ -400,7 +429,7 @@ def setUpClass(self): self.x = paddle.rand(shape=self.shape, dtype=self.dtype) self.y = paddle.rand(shape=self.shape, dtype=self.dtype) - def test_single_input(self): + def func_single_input(self): def func(x): return paddle.sum(paddle.matmul(x, x)) @@ -413,7 +442,7 @@ def func(x): np.testing.assert_allclose(hessian[:].numpy(), numerical_hessian, self.rtol, self.atol) - def test_multi_input(self): + def func_multi_input(self): def func(x, y): return paddle.sum(paddle.matmul(x, y)) @@ -429,7 +458,7 @@ def func(x, y): rtol=self.rtol, atol=self.atol) - def test_allow_unused_true(self): + def func_allow_unused_true(self): def func(x, y): return paddle.sum(paddle.matmul(x, x)) @@ -442,7 +471,7 @@ def func(x, y): np.testing.assert_allclose(hessian[:].numpy(), numerical_hessian, self.rtol, self.atol) - def test_create_graph_true(self): + def func_create_graph_true(self): def func(x): return paddle.sum(F.sigmoid(x)) @@ -455,13 +484,21 @@ def func(x): np.testing.assert_allclose(hessian[:].numpy(), numerical_hessian, self.rtol, self.atol) 
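
The renames above all follow one pattern: each test body moves to a func_* method and a single test_all_cases entry point runs it only under legacy dygraph, so the double-grad cases are skipped under eager mode for now. A condensed sketch of that pattern (class and method names are illustrative):

import unittest

from paddle.fluid.framework import _in_legacy_dygraph

class ExampleDoubleGradTest(unittest.TestCase):
    def func_case(self):
        pass  # the original assertions move here unchanged

    def test_all_cases(self):
        if _in_legacy_dygraph():  # eager mode skips the body for now
            self.func_case()
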
- def test_out_not_single(self): + def func_out_not_single(self): def func(x): return x * x with self.assertRaises(RuntimeError): paddle.autograd.Hessian(func, paddle.ones([3])) + def test_all_cases(self): + if _in_legacy_dygraph(): + self.func_single_input() + self.func_multi_input() + self.func_allow_unused_true() + self.func_create_graph_true() + self.func_out_not_single() + class TestHessianClassBatchFirst(unittest.TestCase): @classmethod @@ -482,7 +519,7 @@ def setUpClass(self): self.weight = paddle.rand(shape=self.weight_shape, dtype=self.dtype) self.y = paddle.rand(shape=self.y_shape, dtype=self.dtype) - def test_single_input(self): + def func_single_input(self): def func(x): return paddle.matmul(x * x, self.weight)[:, 0:1] @@ -496,7 +533,7 @@ def func(x): np.testing.assert_allclose(actual, expected, self.rtol, self.atol) - def test_multi_input(self): + def func_multi_input(self): def func(x, y): return paddle.matmul(x * x * y * y, self.weight)[:, 0:1] @@ -517,7 +554,7 @@ def func(x, y): np.testing.assert_allclose(actual, expected, self.rtol, self.atol) - def test_allow_unused(self): + def func_allow_unused(self): def func(x, y): return paddle.matmul(x * x, self.weight)[:, 0:1] @@ -538,7 +575,7 @@ def func(x, y): np.testing.assert_allclose( actual, expected, rtol=self.rtol, atol=self.atol) - def test_stop_gradient(self): + def func_stop_gradient(self): def func(x): return paddle.matmul(x * x, self.weight)[:, 0:1] @@ -554,13 +591,21 @@ def func(x): np.testing.assert_allclose(actual, expected, self.rtol, self.atol) - def test_out_not_single(self): + def func_out_not_single(self): def func(x): return (x * x) with self.assertRaises(RuntimeError): paddle.autograd.Hessian(func, paddle.ones((3, 3)), is_batched=True) + def test_all_cases(self): + if _in_legacy_dygraph(): + self.func_single_input() + self.func_multi_input() + self.func_allow_unused() + self.func_stop_gradient() + self.func_out_not_single() + class TestHessian(unittest.TestCase): @classmethod @@ -577,7 +622,7 @@ def setUpClass(self): self.x = paddle.rand(shape=self.shape, dtype=self.dtype) self.y = paddle.rand(shape=self.shape, dtype=self.dtype) - def test_single_input(self): + def func_single_input(self): def func(x): return paddle.sum(paddle.matmul(x, x)) @@ -589,7 +634,7 @@ def func(x): np.testing.assert_allclose(hessian.numpy(), numerical_hessian[0][0], self.rtol, self.atol) - def test_multi_input(self): + def func_multi_input(self): def func(x, y): return paddle.sum(paddle.matmul(x, y)) @@ -605,7 +650,7 @@ def func(x, y): numerical_hessian[i][j], self.rtol, self.atol) - def test_allow_unused_false(self): + def func_allow_unused_false(self): def func(x, y): return paddle.sum(paddle.matmul(x, x)) @@ -617,7 +662,7 @@ def func(x, y): error_msg = cpt.get_exception_message(e) assert error_msg.find("allow_unused") > 0 - def test_allow_unused_true(self): + def func_allow_unused_true(self): def func(x, y): return paddle.sum(paddle.matmul(x, x)) @@ -636,7 +681,7 @@ def func(x, y): else: assert hessian[i][j] is None - def test_create_graph_false(self): + def func_create_graph_false(self): def func(x): return paddle.sum(paddle.matmul(x, x)) @@ -653,7 +698,7 @@ def func(x): error_msg = cpt.get_exception_message(e) assert error_msg.find("has no gradient") > 0 - def test_create_graph_true(self): + def func_create_graph_true(self): def func(x): return paddle.sum(F.sigmoid(x)) @@ -667,6 +712,15 @@ def func(x): triple_grad = paddle.grad(hessian, self.x) assert triple_grad is not None + def test_all_cases(self): + if _in_legacy_dygraph(): 
+ self.func_single_input() + self.func_multi_input() + self.func_allow_unused_false() + self.func_allow_unused_true() + self.func_create_graph_false() + self.func_create_graph_true() + class TestHessianFloat64(TestHessian): @classmethod @@ -702,7 +756,7 @@ def setUpClass(self): self.weight = paddle.rand(shape=self.weight_shape, dtype=self.dtype) self.y = paddle.rand(shape=self.y_shape, dtype=self.dtype) - def test_single_input(self): + def func_single_input(self): def func(x): return paddle.matmul(x * x, self.weight)[:, 0:1] @@ -713,7 +767,7 @@ def func(x): np.testing.assert_allclose(hessian, numerical_hessian, self.rtol, self.atol) - def test_multi_input(self): + def func_multi_input(self): def func(x, y): return paddle.matmul(x * x * y * y, self.weight)[:, 0:1] @@ -729,7 +783,7 @@ def func(x, y): np.testing.assert_allclose(hessian_reshape, numerical_hessian, self.rtol, self.atol) - def test_allow_unused_false(self): + def func_allow_unused_false(self): def func(x, y): return paddle.matmul(x * x, self.weight)[:, 0:1] @@ -741,7 +795,7 @@ def func(x, y): error_msg = cpt.get_exception_message(e) assert error_msg.find("allow_unused") > 0 - def test_allow_unused_true(self): + def func_allow_unused_true(self): def func(x, y): return paddle.matmul(x * x, self.weight)[:, 0:1] @@ -763,7 +817,7 @@ def func(x, y): else: assert hessian[i][j] is None - def test_create_graph_false(self): + def func_create_graph_false(self): def func(x): return paddle.matmul(x * x, self.weight)[:, 0:1] @@ -780,7 +834,7 @@ def func(x): error_msg = cpt.get_exception_message(e) assert error_msg.find("has no gradient") > 0 - def test_create_graph_true(self): + def func_create_graph_true(self): def func(x): return paddle.matmul(x * x, self.weight)[:, 0:1] @@ -794,6 +848,15 @@ def func(x): triple_grad = paddle.grad(hessian, self.x) assert triple_grad is not None + def test_all_cases(self): + if _in_legacy_dygraph(): + self.func_single_input() + self.func_multi_input() + self.func_allow_unused_false() + self.func_allow_unused_true() + self.func_create_graph_false() + self.func_create_graph_true() + class TestBatchHessianFloat64(TestBatchHessian): @classmethod @@ -831,7 +894,7 @@ def setUpClass(self): self.vx = paddle.rand(shape=self.shape, dtype=self.dtype) self.vy = paddle.rand(shape=self.shape, dtype=self.dtype) - def test_single_input(self): + def func_single_input(self): def func(x): return paddle.sum(paddle.matmul(x, x)) @@ -846,7 +909,7 @@ def func(x): np.testing.assert_allclose(vhp[0].numpy(), numerical_vhp[0], self.rtol, self.atol) - def test_multi_input(self): + def func_multi_input(self): def func(x, y): return paddle.sum(paddle.matmul(x, y)) @@ -865,7 +928,7 @@ def func(x, y): np.testing.assert_allclose(vhp[i].numpy(), numerical_vhp[i], self.rtol, self.atol) - def test_v_default(self): + def func_v_default(self): def func(x, y): return paddle.sum(paddle.matmul(x, y)) @@ -885,7 +948,7 @@ def func(x, y): np.testing.assert_allclose(vhp[i].numpy(), numerical_vhp[i], self.rtol, self.atol) - def test_allow_unused_true(self): + def func_allow_unused_true(self): def func(x, y): return paddle.sum(paddle.matmul(x, x)) @@ -903,7 +966,7 @@ def func(x, y): np.testing.assert_allclose(vhp[0].numpy(), numerical_vhp[0], self.rtol, self.atol) - def test_create_graph_true(self): + def func_create_graph_true(self): def func(x): return paddle.sum(F.sigmoid(x)) @@ -921,6 +984,14 @@ def func(x): triple_grad = paddle.grad(vhp, self.x) assert triple_grad is not None + def test_all_cases(self): + if _in_legacy_dygraph(): + 
self.func_v_default() + self.func_multi_input() + self.func_single_input() + self.func_allow_unused_true() + self.func_create_graph_true() + class TestJacobian(unittest.TestCase): @classmethod @@ -934,7 +1005,7 @@ def setUpClass(self): self.x = paddle.rand(shape=self.shape, dtype=self.dtype) self.y = paddle.rand(shape=self.shape, dtype=self.dtype) - def test_single_input_and_single_output(self): + def func_single_input_and_single_output(self): def func(x): return paddle.matmul(x, x) @@ -945,7 +1016,7 @@ def func(x): np.testing.assert_allclose(jacobian.numpy(), numerical_jacobian[0][0], self.rtol, self.atol) - def test_single_input_and_multi_output(self): + def func_single_input_and_multi_output(self): def func(x): return paddle.matmul(x, x), x * x @@ -958,7 +1029,7 @@ def func(x): numerical_jacobian[i][0], self.rtol, self.atol) - def test_multi_input_and_single_output(self): + def func_multi_input_and_single_output(self): def func(x, y): return paddle.matmul(x, y) @@ -972,7 +1043,7 @@ def func(x, y): numerical_jacobian[0][j], self.rtol, self.atol) - def test_multi_input_and_multi_output(self): + def func_multi_input_and_multi_output(self): def func(x, y): return paddle.matmul(x, y), x * y @@ -987,7 +1058,7 @@ def func(x, y): numerical_jacobian[i][j], self.rtol, self.atol) - def test_allow_unused_false(self): + def func_allow_unused_false(self): def func(x, y): return paddle.matmul(x, x) @@ -999,7 +1070,7 @@ def func(x, y): error_msg = cpt.get_exception_message(e) assert error_msg.find("allow_unused") > 0 - def test_allow_unused_true(self): + def func_allow_unused_true(self): def func(x, y): return paddle.matmul(x, x) @@ -1013,7 +1084,7 @@ def func(x, y): jacobian[0].numpy(), numerical_jacobian[0][0], self.rtol, self.atol) assert jacobian[1] is None - def test_create_graph_false(self): + def func_create_graph_false(self): def func(x, y): return paddle.matmul(x, y) @@ -1033,7 +1104,7 @@ def func(x, y): error_msg = cpt.get_exception_message(e) assert error_msg.find("has no gradient") > 0 - def test_create_graph_true(self): + def func_create_graph_true(self): def func(x, y): return paddle.matmul(x, y) @@ -1051,6 +1122,17 @@ def func(x, y): double_grad = paddle.grad(jacobian[0], [self.x, self.y]) assert double_grad is not None + def test_all_cases(self): + if _in_legacy_dygraph(): + self.func_multi_input_and_multi_output() + self.func_multi_input_and_single_output() + self.func_single_input_and_multi_output() + self.func_single_input_and_single_output() + self.func_allow_unused_false() + self.func_allow_unused_true() + self.func_create_graph_false() + self.func_create_graph_true() + class TestJacobianFloat64(TestJacobian): @classmethod @@ -1080,7 +1162,7 @@ def setUpClass(self): self.weight = paddle.rand(shape=self.weight_shape, dtype=self.dtype) self.y = paddle.rand(shape=self.y_shape, dtype=self.dtype) - def test_batch_single_input_and_batch_single_output(self): + def func_batch_single_input_and_batch_single_output(self): def func(x): return paddle.matmul(paddle.matmul(x, self.weight), self.y) @@ -1096,7 +1178,7 @@ def func(x): np.allclose(batch_jacobian.numpy().all(), numerical_jacobian[0][0] .all())) - def test_batch_single_input_and_batch_multi_output(self): + def func_batch_single_input_and_batch_multi_output(self): def func(x): return paddle.matmul(paddle.matmul(x, self.weight), self.y), x * x @@ -1113,7 +1195,7 @@ def func(x): numerical_jacobian[i][0], self.rtol, self.atol) - def test_batch_multi_input_and_batch_single_output(self): + def 
func_batch_multi_input_and_batch_single_output(self): def func(x, y): return x * y @@ -1129,7 +1211,7 @@ def func(x, y): numerical_jacobian[0][j], self.rtol, self.atol) - def test_batch_multi_input_and_batch_multi_output(self): + def func_batch_multi_input_and_batch_multi_output(self): def func(x, y): return x * y, x * y @@ -1144,7 +1226,7 @@ def func(x, y): np.testing.assert_allclose(batch_jacobian[i], numerical_jacobian[i], self.rtol, self.atol) - def test_allow_unused_false(self): + def func_allow_unused_false(self): def func(x, y): return x * x @@ -1156,7 +1238,7 @@ def func(x, y): error_msg = cpt.get_exception_message(e) assert error_msg.find("allow_unused") > 0 - def test_allow_unused_true(self): + def func_allow_unused_true(self): def func(x, y): return x * x @@ -1171,7 +1253,7 @@ def func(x, y): jacobian[0].numpy(), numerical_jacobian[0][0], self.rtol, self.atol) assert jacobian[1] is None - def test_create_graph_false(self): + def func_create_graph_false(self): def func(x, y): return x * y @@ -1191,7 +1273,7 @@ def func(x, y): error_msg = cpt.get_exception_message(e) assert error_msg.find("has no gradient") > 0 - def test_create_graph_true(self): + def func_create_graph_true(self): def func(x, y): return x * y @@ -1209,6 +1291,17 @@ def func(x, y): double_grad = paddle.grad(jacobian[0], [self.x, self.y]) assert double_grad is not None + def test_all_cases(self): + if _in_legacy_dygraph(): + self.func_batch_single_input_and_batch_single_output() + self.func_batch_single_input_and_batch_multi_output() + self.func_batch_multi_input_and_batch_single_output() + self.func_batch_multi_input_and_batch_multi_output() + self.func_allow_unused_false() + self.func_allow_unused_true() + self.func_create_graph_false() + self.func_create_graph_true() + class TestJacobianBatchFloat64(TestJacobianBatch): @classmethod diff --git a/python/paddle/fluid/tests/unittests/test_imperative_star_gan_with_gradient_penalty.py b/python/paddle/fluid/tests/unittests/test_imperative_star_gan_with_gradient_penalty.py index 2b8e10d779256..be81c15677a3a 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_star_gan_with_gradient_penalty.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_star_gan_with_gradient_penalty.py @@ -17,6 +17,7 @@ import numpy as np import unittest from paddle import _C_ops +from paddle.fluid.framework import _test_eager_guard, _in_legacy_dygraph, _in_eager_without_dygraph_check if fluid.is_compiled_with_cuda(): fluid.core.globals()['FLAGS_cudnn_deterministic'] = True @@ -583,7 +584,7 @@ def run(self, image_real, label_org, label_trg): class TestStarGANWithGradientPenalty(unittest.TestCase): - def test_main(self): + def func_main(self): self.place_test(fluid.CPUPlace()) if fluid.is_compiled_with_cuda(): @@ -615,6 +616,10 @@ def place_test(self, place): self.assertEqual(g_loss_s, g_loss_d) self.assertEqual(d_loss_s, d_loss_d) + def test_all_cases(self): + if _in_legacy_dygraph(): + self.func_main() + if __name__ == '__main__': paddle.enable_static() diff --git a/python/paddle/fluid/tests/unittests/test_imperative_triple_grad.py b/python/paddle/fluid/tests/unittests/test_imperative_triple_grad.py index 3644eead6bc65..027c0002c7103 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_triple_grad.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_triple_grad.py @@ -19,6 +19,7 @@ import unittest from unittest import TestCase import numpy as np +from paddle.fluid.framework import _test_eager_guard, _in_legacy_dygraph, _in_eager_without_dygraph_check def 
_dygraph_guard_(func): @@ -65,7 +66,7 @@ def grad(self, allow_unused=allow_unused) @dygraph_guard - def test_exception(self): + def func_exception(self): with self.assertRaises(AssertionError): self.grad(None, None) @@ -95,7 +96,7 @@ def test_exception(self): self.grad([random_var(shape)], [random_var(shape)], no_grad_vars=1) @dygraph_guard - def test_example_with_gradient_and_create_graph(self): + def func_example_with_gradient_and_create_graph(self): x = random_var(self.shape) x_np = x.numpy() x.stop_gradient = False @@ -145,6 +146,11 @@ def test_example_with_gradient_and_create_graph(self): dddx_grad_actual = x.gradient() self.assertTrue(np.allclose(dddx_grad_actual, dddx_expected)) + def test_all_cases(self): + if _in_legacy_dygraph(): + self.func_exception() + self.func_example_with_gradient_and_create_graph() + class TestDygraphTripleGradBradcastCase(TestCase): def setUp(self): @@ -172,7 +178,7 @@ def grad(self, allow_unused=allow_unused) @dygraph_guard - def test_example_with_gradient_and_create_graph(self): + def func_example_with_gradient_and_create_graph(self): x = random_var(self.x_shape) x_np = x.numpy() x.stop_gradient = False @@ -227,6 +233,10 @@ def test_example_with_gradient_and_create_graph(self): dddx_grad_actual = x.gradient() self.assertTrue(np.allclose(dddx_grad_actual, dddx_expected)) + def test_all_cases(self): + if _in_legacy_dygraph(): + self.func_example_with_gradient_and_create_graph() + if __name__ == '__main__': unittest.main() From e5e0b726e5c2c561d6afd4765bbb75d30e0ff417 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C5=82awomir=20Siwek?= Date: Mon, 4 Apr 2022 10:01:42 +0200 Subject: [PATCH 63/93] conv + elementwise_add refactor (#41286) * DRY * change nodes names * add const prefix * change asX to as_x in all files --- .../framework/ir/graph_pattern_detector.cc | 23 +++ .../framework/ir/graph_pattern_detector.h | 16 ++ paddle/fluid/framework/ir/graph_traits.cc | 48 +++++ paddle/fluid/framework/ir/graph_traits.h | 3 + .../conv_elementwise_add_mkldnn_fuse_pass.cc | 166 ++---------------- .../conv_elementwise_add_mkldnn_fuse_pass.h | 16 +- ...t_mkldnn_conv_elementwise_add_fuse_pass.py | 136 +------------- 7 files changed, 113 insertions(+), 295 deletions(-) diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc index 03da1289205e4..8eb1b64a2763a 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -2069,6 +2069,29 @@ PDNode *patterns::Elementwise::operator()(PDNode *x_var, PDNode *y_var, return out_var; } +PDNode *patterns::ResidualElementwise::operator()( + PDNode *op_var, PDNode *residual_var, const std::string elementwise_type, + bool as_x) { + auto elementwise_op = + pattern->NewNode(elementwise_op_repr())->assert_is_op(elementwise_type); + + if (as_x) { + op_var->AsInput()->assert_is_op_input(elementwise_type, "X"); + residual_var->AsInput()->assert_is_op_input(elementwise_type, "Y"); + } else { + op_var->AsInput()->assert_is_op_input(elementwise_type, "Y"); + residual_var->AsInput()->assert_is_op_input(elementwise_type, "X"); + } + auto out_var = pattern->NewNode(elementwise_out_repr()) + ->AsOutput() + ->assert_is_op_output(elementwise_type, "Out"); + + elementwise_op->LinksFrom({op_var, residual_var}); + elementwise_op->LinksTo({out_var}); + + return out_var; +} + PDNode *patterns::Concat::operator()() { auto concat_op = pattern->NewNode(concat_op_repr())->assert_is_op("concat"); diff --git 
a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h index 1f253c6b91043..434ede6cf7a3b 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.h +++ b/paddle/fluid/framework/ir/graph_pattern_detector.h @@ -1032,6 +1032,22 @@ struct Elementwise : public PatternBase { PATTERN_DECL_NODE(elementwise_out); }; +// Residual Elementwise ops +// This pattern allows operator output to be X or Y +// and residual data Y or X, based on as_x flag +struct ResidualElementwise : public PatternBase { + ResidualElementwise(PDPattern* pattern, const std::string& name_scope, + bool as_x) + : PatternBase(pattern, name_scope, "residual_elementwise") {} + PDNode* operator()(PDNode* op_var, PDNode* residual_var, + const std::string elementwise_type, bool as_x); + + PATTERN_DECL_NODE(operator_output); + PATTERN_DECL_NODE(residual_data); + PATTERN_DECL_NODE(elementwise_op); + PATTERN_DECL_NODE(elementwise_out); +}; + // Transpose op // Forward pass for transpose. // transpose_out is a result of the operator. diff --git a/paddle/fluid/framework/ir/graph_traits.cc b/paddle/fluid/framework/ir/graph_traits.cc index 262a523bd8e0e..b06314563025a 100644 --- a/paddle/fluid/framework/ir/graph_traits.cc +++ b/paddle/fluid/framework/ir/graph_traits.cc @@ -12,6 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include +#include + #include "paddle/fluid/framework/ir/graph_traits.h" namespace paddle { @@ -23,6 +26,51 @@ namespace ir { // class Node; +bool IsReachable(ir::Graph *graph, Node *from, Node *to) { + if (from == to) { + return true; + } + + std::map visited; + + for (auto &node : GraphTraits::DFS(*graph)) { + visited[&node] = false; + } + + visited[from] = true; + + std::list queue; + queue.push_back(from); + + while (!queue.empty()) { + auto cur = FindNode(graph, queue.front()); + queue.pop_front(); + + if (!cur) return false; + + for (const auto &n : cur->outputs) { + if (n == to) { + return true; + } + + if (!visited[n]) { + visited[n] = true; + queue.push_back(n); + } + } + } + return false; +} + +Node *FindNode(ir::Graph *graph, const Node *node) { + for (const auto &n : graph->Nodes()) { + if (n == node) { + return n; + } + } + return nullptr; +} + NodesDFSIterator::NodesDFSIterator(const std::vector &source) { for (auto *x : source) stack_.push(x); } diff --git a/paddle/fluid/framework/ir/graph_traits.h b/paddle/fluid/framework/ir/graph_traits.h index a54cc61a63fde..7e313e17f422e 100644 --- a/paddle/fluid/framework/ir/graph_traits.h +++ b/paddle/fluid/framework/ir/graph_traits.h @@ -29,6 +29,9 @@ namespace ir { class Graph; class Node; +bool IsReachable(ir::Graph *graph, Node *from, Node *to); +Node *FindNode(ir::Graph *graph, const Node *node); + template class iterator_range { IteratorT begin_, end_; diff --git a/paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.cc index fc2758c273450..16c4f251e0bde 100644 --- a/paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.cc @@ -14,12 +14,6 @@ #include "paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.h" -#include -#include -#include -#include -#include - #include "paddle/fluid/framework/ir/graph_traits.h" #include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/string/pretty_log.h" @@ -28,60 +22,6 @@ namespace 
paddle { namespace framework { namespace ir { -bool IsReachable(ir::Graph* graph, Node* from, Node* to) { - auto find_node = [](ir::Graph* graph, const Node* node) -> Node* { - for (auto n : graph->Nodes()) { - if (n == node) { - return n; - } - } - - return nullptr; - }; - - if (from == to) { - return true; - } - - std::map visited; - - for (auto& node : GraphTraits::DFS(*graph)) { - visited[&node] = false; - } - - visited[from] = true; - - std::list queue; - queue.push_back(from); - - while (!queue.empty()) { - auto cur = find_node(graph, queue.front()); - queue.pop_front(); - - if (!cur) return false; - - for (auto n : cur->outputs) { - if (n == to) { - return true; - } - - if (!visited[n]) { - visited[n] = true; - queue.push_back(n); - } - } - } - return false; -} - -template -paddle::optional HasAttribute(const Node& op, const std::string& attr) { - if (op.Op()->HasAttr(attr)) - return BOOST_GET_CONST(T, op.Op()->GetAttr(attr)); - else - return paddle::none; -} - ResidualConnectionMKLDNNFusePass::ResidualConnectionMKLDNNFusePass() { AddOpCompat(OpCompat("conv2d")) .AddInput("Input") @@ -136,89 +76,22 @@ ResidualConnectionMKLDNNFusePass::ResidualConnectionMKLDNNFusePass() { .End(); } -GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConvAsX( - const std::string& name_scope, - const GraphWithStats& graph_with_stats) const { - GraphPatternDetector gpd; - auto pattern = gpd.mutable_pattern(); - - patterns::Conv conv_pattern{pattern, name_scope}; - auto conv_output = conv_pattern(); - - patterns::Elementwise elementwise_pattern{pattern, name_scope}; - elementwise_pattern( - conv_output, pattern->NewNode(elementwise_pattern.elementwise_y_repr()), - "elementwise_add"); - conv_output->AsIntermediate(); - - int found_conv_as_x_count = 0; - - auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, - Graph* g) { - GET_IR_NODE_FROM_SUBGRAPH(conv_op, conv_op, conv_pattern); - GET_IR_NODE_FROM_SUBGRAPH(conv_input, conv_input, conv_pattern); - GET_IR_NODE_FROM_SUBGRAPH(conv_filter, conv_filter, conv_pattern); - GET_IR_NODE_FROM_SUBGRAPH(conv_output, conv_output, conv_pattern); - - GET_IR_NODE_FROM_SUBGRAPH(elementwise_op, elementwise_op, - elementwise_pattern); - GET_IR_NODE_FROM_SUBGRAPH(elementwise_identity, elementwise_y, - elementwise_pattern); - GET_IR_NODE_FROM_SUBGRAPH(elementwise_out, elementwise_out, - elementwise_pattern); - - if (FindFuseOption(*conv_op, *elementwise_op) != FUSE_MKLDNN) return; - - if (!IsReachable(g, elementwise_identity, conv_output)) return; - - if (HasFusedActivation(conv_op)) return; - - if (!IsCompat(subgraph, g)) { - LOG(WARNING) - << "conv_elementwise_add_mkldnn_fuse_pass in op compat failed."; - return; - } - - conv_op->Op()->SetInput("ResidualData", {elementwise_identity->Name()}); - conv_op->Op()->SetOutput("Output", {elementwise_out->Name()}); - conv_op->Op()->SetAttr("fuse_residual_connection", true); - - GraphSafeRemoveNodes(g, {conv_output, elementwise_op}); - - IR_NODE_LINK_TO(elementwise_identity, conv_op); - IR_NODE_LINK_TO(conv_op, elementwise_out); - - found_conv_as_x_count++; - }; - - gpd(graph_with_stats.first, handler); - if (!Has("disable_logs") || !Get("disable_logs")) { - std::stringstream msg_ss; - msg_ss << "--- Fused " << found_conv_as_x_count - << " conv (as x) + elementwise_add patterns"; - paddle::string::PrettyLogDetail(msg_ss.str().c_str()); - } - - return std::make_pair(graph_with_stats.first, - found_conv_as_x_count + graph_with_stats.second); -} - -GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConvAsY( - const 
std::string& name_scope, - const GraphWithStats& graph_with_stats) const { +GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConv( + const std::string& name_scope, const GraphWithStats& graph_with_stats, + bool as_x) const { GraphPatternDetector gpd; auto pattern = gpd.mutable_pattern(); patterns::Conv conv_pattern{pattern, name_scope}; auto conv_output = conv_pattern(); - patterns::Elementwise elementwise_pattern{pattern, name_scope}; + patterns::ResidualElementwise elementwise_pattern{pattern, name_scope, as_x}; elementwise_pattern( - pattern->NewNode(elementwise_pattern.elementwise_x_repr()), conv_output, - "elementwise_add"); + conv_output, pattern->NewNode(elementwise_pattern.residual_data_repr()), + "elementwise_add", as_x); conv_output->AsIntermediate(); - int found_conv_as_y_count = 0; + int found_conv_count = 0; auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* g) { @@ -229,15 +102,13 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConvAsY( GET_IR_NODE_FROM_SUBGRAPH(elementwise_op, elementwise_op, elementwise_pattern); - GET_IR_NODE_FROM_SUBGRAPH(elementwise_x, elementwise_x, + GET_IR_NODE_FROM_SUBGRAPH(residual_data, residual_data, elementwise_pattern); GET_IR_NODE_FROM_SUBGRAPH(elementwise_out, elementwise_out, elementwise_pattern); if (FindFuseOption(*conv_op, *elementwise_op) != FUSE_MKLDNN) return; - - if (!IsReachable(g, elementwise_x, conv_output)) return; - + if (!IsReachable(g, residual_data, conv_output)) return; if (HasFusedActivation(conv_op)) return; if (!IsCompat(subgraph, g)) { @@ -246,28 +117,29 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConvAsY( return; } - conv_op->Op()->SetInput("ResidualData", {elementwise_x->Name()}); + conv_op->Op()->SetInput("ResidualData", {residual_data->Name()}); conv_op->Op()->SetOutput("Output", {elementwise_out->Name()}); conv_op->Op()->SetAttr("fuse_residual_connection", true); GraphSafeRemoveNodes(g, {conv_output, elementwise_op}); - IR_NODE_LINK_TO(elementwise_x, conv_op); + IR_NODE_LINK_TO(residual_data, conv_op); IR_NODE_LINK_TO(conv_op, elementwise_out); - found_conv_as_y_count++; + found_conv_count++; }; gpd(graph_with_stats.first, handler); if (!Has("disable_logs") || !Get("disable_logs")) { std::stringstream msg_ss; - msg_ss << "--- Fused " << found_conv_as_y_count - << " conv (as y) + elementwise_add patterns"; + std::string fusionMode = as_x ? 
"x" : "y"; + msg_ss << "--- Fused " << found_conv_count << " conv (as " << fusionMode + << ") + elementwise_add patterns"; paddle::string::PrettyLogDetail(msg_ss.str().c_str()); } return std::make_pair(graph_with_stats.first, - found_conv_as_y_count + graph_with_stats.second); + found_conv_count + graph_with_stats.second); } GraphWithStats ResidualConnectionMKLDNNFusePass::FuseProjectionConv( @@ -308,7 +180,7 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseProjectionConv( if (!IsCompat(subgraph, g)) { LOG(WARNING) - << "conv_elementwise_add_mkldnn_fuse_pass in op compat failed."; + << "op compat for conv_elementwise_add_mkldnn_fuse_pass failed."; return; } @@ -361,8 +233,8 @@ void ResidualConnectionMKLDNNFusePass::ApplyImpl(ir::Graph* graph) const { FusePassBase::Init(name_scope_, graph); auto graph_with_stats = FuseProjectionConv(name_scope_, std::make_pair(graph, 0)); - graph_with_stats = FuseConvAsX(name_scope_, graph_with_stats); - graph_with_stats = FuseConvAsY(name_scope_, graph_with_stats); + graph_with_stats = FuseConv(name_scope_, graph_with_stats, true); + graph_with_stats = FuseConv(name_scope_, graph_with_stats, false); AddStatis(graph_with_stats.second); } diff --git a/paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.h b/paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.h index c4351b382187d..7c6e9927163c7 100644 --- a/paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.h +++ b/paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.h @@ -14,30 +14,20 @@ #pragma once -#include -#include -#include -#include #include "paddle/fluid/framework/ir/fuse_pass_base.h" -#include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/ir/graph_pattern_detector.h" -#include - namespace paddle { namespace framework { namespace ir { using GraphWithStats = std::pair; -bool IsReachable(ir::Graph* graph, Node* from, Node* to); - class ResidualConnectionMKLDNNFusePass : public FusePassBase { private: - GraphWithStats FuseConvAsX(const std::string& name_scope, - const GraphWithStats& graph_with_stats) const; - GraphWithStats FuseConvAsY(const std::string& name_scope, - const GraphWithStats& graph_with_stats) const; + GraphWithStats FuseConv(const std::string& name_scope, + const GraphWithStats& graph_with_stats, + bool as_x) const; GraphWithStats FuseProjectionConv( const std::string& name_scope, const GraphWithStats& graph_with_stats) const; diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_elementwise_add_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_elementwise_add_fuse_pass.py index 2e84607e2f5c2..58d09a880619c 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_elementwise_add_fuse_pass.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_elementwise_add_fuse_pass.py @@ -26,7 +26,7 @@ # the two inputs of elementwise_add are tensor -class TestConvElementwiseAddMkldnnFusePass1(PassAutoScanTest): +class TestConvElementwiseAddMkldnnFusePass(PassAutoScanTest): def is_program_valid(self, program_config: ProgramConfig) -> bool: attrs = [ program_config.ops[i].attrs @@ -125,139 +125,5 @@ def test(self): quant=False, passes=["conv_elementwise_add_mkldnn_fuse_pass"]) -''' -class TestConvElementwiseAddMkldnnFusePass(PassAutoScanTest): - def is_program_valid(self, program_config: ProgramConfig) -> bool: - attrs = [ - program_config.ops[i].attrs - for i in range(len(program_config.ops)) - ] - if 
"elementwise_weight" in program_config.weights: - if program_config.weights["elementwise_weight"].shape[0] == program_config.inputs["input_data1"].shape[1]: - if attrs[2]['axis'] != 1: - return False - if program_config.weights["elementwise_weight"].shape[0] == program_config.inputs["input_data1"].shape[3]: - if attrs[2]['axis'] != -1: - return False - return True - - def sample_program_config(self, draw): - data_format = draw(st.sampled_from(["NCHW", "NHWC"])) - dilations = draw(st.sampled_from([[1, 1], [2, 2], [1, 2]])) - padding_algorithm = draw(st.sampled_from(["EXPLICIT", "SAME", "VALID"])) - groups = draw(st.sampled_from([1, 2, 4])) - paddings = draw(st.sampled_from([[0, 3], [1, 1], [1, 2, 3, 4]])) - strides = draw(st.sampled_from([[1, 1], [2, 2], [1, 2]])) - axis = draw(st.sampled_from([-1, 0, 1])) - batch_size = draw(st.integers(min_value=1, max_value=4)) - - def generate_input1(): - if data_format == "NCHW": - return np.random.random( - [batch_size, 48, 64, 64]).astype(np.float32) - else: - return np.random.random( - [batch_size, 64, 64, 48]).astype(np.float32) - - def generate_weight1(): - return np.random.random( - [48, int(48 / groups), 3, 3]).astype(np.float32) - - def compute_out_shape(padding_alg): - import paddle - import paddle.nn as nn - - x_var = paddle.uniform( - (batch_size, 48, 64, 64), dtype='float32', min=-1., max=1.) - if padding_alg == "EXPLICIT": - conv = nn.Conv2D(48, 48, (3, 3), strides, paddings, dilations, - 1) - else: - conv = nn.Conv2D(48, 48, (3, 3), strides, padding_alg, - dilations, 1) - y_var = conv(x_var) - return y_var.shape - - def generate_weight2(): - return np.random.random([48]).astype(np.float32) - - if compute_out_shape(padding_algorithm) != (batch_size, 48, 64, 64): - axis = 1 - - relu_op = OpConfig( - type="relu", - inputs={"X": ["input_data1"]}, - outputs={"Out": ["sigmoid_out"]}, - attrs={}) - - conv2d_op = OpConfig( - type="conv2d", - inputs={"Input": ["sigmoid_out"], - "Filter": ["conv_weight"]}, - outputs={"Output": ["conv_output"]}, - attrs={ - "data_format": data_format, - "dilations": dilations, - "padding_algorithm": padding_algorithm, - "groups": groups, - "paddings": paddings, - "strides": strides - }) - - if axis == 0: - elt_op = OpConfig( - type="elementwise_add", - inputs={"X": ["input_data1"], - "Y": ["conv_output"]}, - outputs={"Out": ["elementwise_output"]}, - attrs={'axis': axis}) - else: - elt_op = OpConfig( - type="elementwise_add", - inputs={"X": ["conv_output"], - "Y": ["elementwise_weight"]}, - outputs={"Out": ["elementwise_output"]}, - attrs={'axis': axis}) - - model_net = [relu_op, conv2d_op, elt_op] - - if axis == 0: - program_config = ProgramConfig( - ops=model_net, - weights={ - "conv_weight": - TensorConfig(data_gen=partial(generate_weight1)) - }, - inputs={ - "input_data1": - TensorConfig(data_gen=partial(generate_input1)) - }, - outputs=["elementwise_output"]) - else: - program_config = ProgramConfig( - ops=model_net, - weights={ - "conv_weight": - TensorConfig(data_gen=partial(generate_weight1)), - "elementwise_weight": - TensorConfig(data_gen=partial(generate_weight2)) - }, - inputs={ - "input_data1": - TensorConfig(data_gen=partial(generate_input1)) - }, - outputs=["elementwise_output"]) - - return program_config - - def sample_predictor_configs(self, program_config): - config = self.create_inference_config(use_mkldnn=True) - yield config, ["relu", "conv2d"], (1e-5, 1e-5) - - def test(self): - self.run_and_statis( - quant=False, passes=["conv_elementwise_add_mkldnn_fuse_pass"]) -''' - if __name__ == 
"__main__": unittest.main() From 08811d9b873948d2d5b1bf2f9b9811fc7a2d6e60 Mon Sep 17 00:00:00 2001 From: Weilong Wu Date: Mon, 4 Apr 2022 16:38:10 +0800 Subject: [PATCH 64/93] Update sequence_mask related code (#41393) --- python/paddle/fluid/layers/sequence_lod.py | 21 +++++++++---------- .../tests/unittests/test_rnn_decode_api.py | 1 - 2 files changed, 10 insertions(+), 12 deletions(-) diff --git a/python/paddle/fluid/layers/sequence_lod.py b/python/paddle/fluid/layers/sequence_lod.py index 1758123f0e608..80dc990af4556 100644 --- a/python/paddle/fluid/layers/sequence_lod.py +++ b/python/paddle/fluid/layers/sequence_lod.py @@ -1382,19 +1382,18 @@ def sequence_mask(x, maxlen=None, dtype='int64', name=None): """ - if _non_static_mode(): + if in_dygraph_mode(): if not isinstance(dtype, core.VarDesc.VarType): dtype = convert_np_dtype_to_dtype_(dtype) - if in_dygraph_mode(): - if maxlen is not None: - if isinstance(maxlen, core.eager.Tensor): - attrs = ('out_dtype', dtype) - out = _C_ops.sequence_mask(x, maxlen, *attrs) - else: - attrs = ('out_dtype', dtype, 'maxlen', maxlen) - out = _C_ops.sequence_mask(x, None, *attrs) - out.stop_gradient = True - return out + if maxlen is not None: + if isinstance(maxlen, core.eager.Tensor): + attrs = ('out_dtype', dtype) + out = _C_ops.sequence_mask(x, maxlen, *attrs) + else: + attrs = ('out_dtype', dtype, 'maxlen', maxlen) + out = _C_ops.sequence_mask(x, None, *attrs) + out.stop_gradient = True + return out helper = LayerHelper('sequence_mask', **locals()) out = helper.create_variable_for_type_inference(dtype=dtype) diff --git a/python/paddle/fluid/tests/unittests/test_rnn_decode_api.py b/python/paddle/fluid/tests/unittests/test_rnn_decode_api.py index bf848357e3195..dacb7a5b59957 100644 --- a/python/paddle/fluid/tests/unittests/test_rnn_decode_api.py +++ b/python/paddle/fluid/tests/unittests/test_rnn_decode_api.py @@ -716,7 +716,6 @@ def make_inputs(self): def func_check_output(self): self.setUp() self.make_inputs() - self.make_inputs() self.check_output() def test_check_output(self): From 780c7a1dadd741243f32feb30c665a16fe07526d Mon Sep 17 00:00:00 2001 From: Weilong Wu Date: Mon, 4 Apr 2022 16:38:48 +0800 Subject: [PATCH 65/93] [Eager] Support test_var_base bf16 case (#41377) * [Eager]Polish enable/disable_legacy_dygraph logic * fix test_var_base print_tensor * fix bug caused by arange * Updated bf16 cast case * BF16 astype to float32 Co-authored-by: Aurelius84 Co-authored-by: pangyoki Co-authored-by: zyfncg --- .../fluid/tests/unittests/test_var_base.py | 56 ++++++++++++++++--- python/paddle/tensor/to_string.py | 9 ++- 2 files changed, 56 insertions(+), 9 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_var_base.py b/python/paddle/fluid/tests/unittests/test_var_base.py index b648caf750e96..b426b0d810ac5 100644 --- a/python/paddle/fluid/tests/unittests/test_var_base.py +++ b/python/paddle/fluid/tests/unittests/test_var_base.py @@ -996,7 +996,7 @@ def _assert_to_static(self, var_base, static_var, is_param=False): self.assertListEqual(list(var_base.shape), list(static_var.shape)) - def test_tensor_str(self): + def func_test_tensor_str(self): paddle.enable_static() paddle.disable_static(paddle.CPUPlace()) paddle.seed(10) @@ -1016,7 +1016,12 @@ def test_tensor_str(self): self.assertEqual(a_str, expected) paddle.enable_static() - def test_tensor_str2(self): + def test_tensor_str(self): + with _test_eager_guard(): + self.func_test_tensor_str() + self.func_test_tensor_str() + + def func_test_tensor_str2(self): 
paddle.disable_static(paddle.CPUPlace()) a = paddle.to_tensor([[1.5111111, 1.0], [0, 0]]) a_str = str(a) @@ -1028,7 +1033,12 @@ def test_tensor_str2(self): self.assertEqual(a_str, expected) paddle.enable_static() - def test_tensor_str3(self): + def test_tensor_str2(self): + with _test_eager_guard(): + self.func_test_tensor_str2() + self.func_test_tensor_str2() + + def func_test_tensor_str3(self): paddle.disable_static(paddle.CPUPlace()) a = paddle.to_tensor([[-1.5111111, 1.0], [0, -0.5]]) a_str = str(a) @@ -1040,7 +1050,12 @@ def test_tensor_str3(self): self.assertEqual(a_str, expected) paddle.enable_static() - def test_tensor_str_scaler(self): + def test_tensor_str3(self): + with _test_eager_guard(): + self.func_test_tensor_str3() + self.func_test_tensor_str3() + + def func_test_tensor_str_scaler(self): paddle.disable_static(paddle.CPUPlace()) a = paddle.to_tensor(np.array(False)) a_str = str(a) @@ -1051,7 +1066,12 @@ def test_tensor_str_scaler(self): self.assertEqual(a_str, expected) paddle.enable_static() - def test_tensor_str_shape_with_zero(self): + def test_tensor_str_scaler(self): + with _test_eager_guard(): + self.func_test_tensor_str_scaler() + self.func_test_tensor_str_scaler() + + def func_test_tensor_str_shape_with_zero(self): paddle.disable_static(paddle.CPUPlace()) x = paddle.ones((10, 10)) y = paddle.fluid.layers.where(x == 0) @@ -1063,7 +1083,12 @@ def test_tensor_str_shape_with_zero(self): self.assertEqual(a_str, expected) paddle.enable_static() - def test_tensor_str_linewidth(self): + def test_tensor_str_shape_with_zero(self): + with _test_eager_guard(): + self.func_test_tensor_str_shape_with_zero() + self.func_test_tensor_str_shape_with_zero() + + def func_test_tensor_str_linewidth(self): paddle.disable_static(paddle.CPUPlace()) paddle.seed(2021) x = paddle.rand([128]) @@ -1091,7 +1116,12 @@ def test_tensor_str_linewidth(self): self.assertEqual(a_str, expected) paddle.enable_static() - def test_tensor_str_linewidth2(self): + def test_tensor_str_linewidth(self): + with _test_eager_guard(): + self.func_test_tensor_str_linewidth() + self.func_test_tensor_str_linewidth() + + def func_test_tensor_str_linewidth2(self): paddle.disable_static(paddle.CPUPlace()) paddle.seed(2021) x = paddle.rand([128]) @@ -1114,7 +1144,12 @@ def test_tensor_str_linewidth2(self): self.assertEqual(a_str, expected) paddle.enable_static() - def test_tensor_str_bf16(self): + def test_tensor_str_linewidth2(self): + with _test_eager_guard(): + self.func_test_tensor_str_linewidth2() + self.func_test_tensor_str_linewidth2() + + def func_tensor_str_bf16(self): paddle.disable_static(paddle.CPUPlace()) a = paddle.to_tensor([[1.5, 1.0], [0, 0]]) a = paddle.cast(a, dtype=core.VarDesc.VarType.BF16) @@ -1128,6 +1163,11 @@ def test_tensor_str_bf16(self): self.assertEqual(a_str, expected) paddle.enable_static() + def test_tensor_str_bf16(self): + with _test_eager_guard(): + self.func_tensor_str_bf16() + self.func_tensor_str_bf16() + def test_print_tensor_dtype(self): paddle.disable_static(paddle.CPUPlace()) a = paddle.rand([1]) diff --git a/python/paddle/tensor/to_string.py b/python/paddle/tensor/to_string.py index 6caa792adb159..a65257b7ee798 100644 --- a/python/paddle/tensor/to_string.py +++ b/python/paddle/tensor/to_string.py @@ -264,6 +264,9 @@ def to_string(var, prefix='Tensor'): def _format_dense_tensor(tensor, indent): + if tensor.dtype == core.VarDesc.VarType.BF16: + tensor = tensor.astype('float32') + np_tensor = tensor.numpy() if len(tensor.shape) == 0: @@ -330,6 +333,10 @@ def 
sparse_tensor_to_string(tensor, prefix='Tensor'): def tensor_to_string(tensor, prefix='Tensor'): indent = len(prefix) + 1 + dtype = convert_dtype(tensor.dtype) + if tensor.dtype == core.VarDesc.VarType.BF16: + dtype = 'bfloat16' + _template = "{prefix}(shape={shape}, dtype={dtype}, place={place}, stop_gradient={stop_gradient},\n{indent}{data})" if tensor.is_sparse(): @@ -342,7 +349,7 @@ def tensor_to_string(tensor, prefix='Tensor'): return _template.format( prefix=prefix, shape=tensor.shape, - dtype=tensor.dtype, + dtype=dtype, place=tensor._place_str, stop_gradient=tensor.stop_gradient, indent=' ' * indent, From 50f8e974589e87e7785301c34a211bba3eb454d1 Mon Sep 17 00:00:00 2001 From: Weilong Wu Date: Mon, 4 Apr 2022 16:39:24 +0800 Subject: [PATCH 66/93] [Eager] Support test_var_base _offset in eager mode (#41369) * [Eager]Polish enable/disable_legacy_dygraph logic * Support _offset in eager mode * Update framework.py * Update framework.py Co-authored-by: Aurelius84 --- paddle/fluid/pybind/eager_method.cc | 15 +++++++++++++++ paddle/fluid/pybind/eager_utils.cc | 2 ++ paddle/fluid/pybind/eager_utils.h | 1 + .../paddle/fluid/tests/unittests/test_var_base.py | 7 ++++++- 4 files changed, 24 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/pybind/eager_method.cc b/paddle/fluid/pybind/eager_method.cc index 1a7eb629a0eaa..dfe2fab9fc468 100644 --- a/paddle/fluid/pybind/eager_method.cc +++ b/paddle/fluid/pybind/eager_method.cc @@ -1344,6 +1344,19 @@ static PyObject* tensor__reset_grad_inplace_version(TensorObject* self, EAGER_CATCH_AND_THROW_RETURN_NULL } +static PyObject* tensor__offset(TensorObject* self, PyObject* args, + PyObject* kwargs) { + EAGER_TRY + auto t = std::dynamic_pointer_cast(self->tensor.impl()); + PADDLE_ENFORCE_EQ( + t->IsInitialized(), true, + platform::errors::InvalidArgument("Tensor %s has not been initialized!", + self->tensor.name())); + + return ToPyObject(t->offset()); + EAGER_CATCH_AND_THROW_RETURN_NULL +} + #if defined(PADDLE_WITH_CUDA) static PyObject* tensor_method__uva(TensorObject* self, PyObject* args, PyObject* kwargs) { @@ -1472,6 +1485,8 @@ PyMethodDef variable_methods[] = { {"_reset_grad_inplace_version", (PyCFunction)(void (*)(void))tensor__reset_grad_inplace_version, METH_VARARGS | METH_KEYWORDS, NULL}, + {"_offset", (PyCFunction)(void (*)(void))tensor__offset, + METH_VARARGS | METH_KEYWORDS, NULL}, #if defined(PADDLE_WITH_CUDA) {"_tensor_uva", (PyCFunction)(void (*)(void))tensor_method__uva, METH_VARARGS | METH_KEYWORDS, NULL}, diff --git a/paddle/fluid/pybind/eager_utils.cc b/paddle/fluid/pybind/eager_utils.cc index bdc96e85e44ae..a6047f36ad98f 100644 --- a/paddle/fluid/pybind/eager_utils.cc +++ b/paddle/fluid/pybind/eager_utils.cc @@ -426,6 +426,8 @@ PyObject* ToPyObject(int value) { return PyLong_FromLong(value); } PyObject* ToPyObject(uint32_t value) { return PyLong_FromUnsignedLong(value); } +PyObject* ToPyObject(size_t value) { return PyLong_FromLong(value); } + PyObject* ToPyObject(int64_t value) { return PyLong_FromLongLong(value); } PyObject* ToPyObject(float value) { return PyLong_FromDouble(value); } diff --git a/paddle/fluid/pybind/eager_utils.h b/paddle/fluid/pybind/eager_utils.h index bd78342e21f4b..2fe73c24ee3a0 100644 --- a/paddle/fluid/pybind/eager_utils.h +++ b/paddle/fluid/pybind/eager_utils.h @@ -55,6 +55,7 @@ framework::proto::VarType::Type CastPyArg2ProtoType(PyObject* obj, PyObject* ToPyObject(int value); PyObject* ToPyObject(uint32_t value); +PyObject* ToPyObject(size_t value); PyObject* ToPyObject(bool value); PyObject* 
ToPyObject(int64_t value); PyObject* ToPyObject(float value); diff --git a/python/paddle/fluid/tests/unittests/test_var_base.py b/python/paddle/fluid/tests/unittests/test_var_base.py index b426b0d810ac5..11d77ecc6226b 100644 --- a/python/paddle/fluid/tests/unittests/test_var_base.py +++ b/python/paddle/fluid/tests/unittests/test_var_base.py @@ -1396,7 +1396,7 @@ def test_clear(self): class TestVarBaseOffset(unittest.TestCase): - def test_offset(self): + def func_offset(self): paddle.disable_static() np_x = np.random.random((3, 8, 8)) x = paddle.to_tensor(np_x, dtype="float64") @@ -1405,6 +1405,11 @@ def test_offset(self): actual_x = paddle.to_tensor(actual_x) self.assertEqual(actual_x._offset(), expected_offset) + def test_offset(self): + with _test_eager_guard(): + self.func_offset() + self.func_offset() + class TestVarBaseShareBufferTo(unittest.TestCase): def test_share_buffer_To(self): From 5936fa6e560d3c5fc235d11552760fd0460662be Mon Sep 17 00:00:00 2001 From: From00 Date: Mon, 4 Apr 2022 17:15:25 +0800 Subject: [PATCH 67/93] Add yaml for reduce_sum OP (#41295) * Add yaml for reduce_sum OP * Fix CI errors * Fix CI errors * Fix CI errors * Fix CI errors --- .../fluid/tests/unittests/CMakeLists.txt | 2 +- .../fluid/tests/unittests/test_reduce_op.py | 43 +++++++++++++------ python/paddle/tensor/math.py | 13 +++++- python/paddle/utils/code_gen/api.yaml | 5 ++- python/paddle/utils/code_gen/backward.yaml | 10 +++++ 5 files changed, 55 insertions(+), 18 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index 272ca806747ed..4a771990d91e1 100755 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -1077,7 +1077,7 @@ set_tests_properties(test_generator_dataloader PROPERTIES TIMEOUT 120) set_tests_properties(test_partial_concat_op PROPERTIES TIMEOUT 120) set_tests_properties(test_fuse_optimizer_pass PROPERTIES TIMEOUT 120) set_tests_properties(test_softmax_with_cross_entropy_op PROPERTIES TIMEOUT 120) -set_tests_properties(test_reduce_op PROPERTIES TIMEOUT 120) +set_tests_properties(test_reduce_op PROPERTIES TIMEOUT 500) set_tests_properties(test_adam_optimizer_fp32_fp64 PROPERTIES TIMEOUT 120) set_tests_properties(test_elementwise_nn_grad PROPERTIES TIMEOUT 120) set_tests_properties(test_buffer_shared_memory_reuse_pass_and_fuse_optimization_op_pass PROPERTIES TIMEOUT 120) diff --git a/python/paddle/fluid/tests/unittests/test_reduce_op.py b/python/paddle/fluid/tests/unittests/test_reduce_op.py index 69693f57bb2f3..01d386724d161 100644 --- a/python/paddle/fluid/tests/unittests/test_reduce_op.py +++ b/python/paddle/fluid/tests/unittests/test_reduce_op.py @@ -26,19 +26,22 @@ class TestSumOp(OpTest): def setUp(self): + self.python_api = paddle.sum self.op_type = "reduce_sum" self.inputs = {'X': np.random.random((5, 6, 10)).astype("float64")} self.outputs = {'Out': self.inputs['X'].sum(axis=0)} + self.attrs = {'dim': [0]} def test_check_output(self): - self.check_output() + self.check_output(check_eager=True) def test_check_grad(self): - self.check_grad(['X'], 'Out') + self.check_grad(['X'], 'Out', check_eager=True) class TestSumOp_fp16(OpTest): def setUp(self): + self.python_api = paddle.sum self.op_type = "reduce_sum" self.inputs = { 'X': np.random.uniform(0, 0.1, (5, 6, 10)).astype("float16") @@ -50,7 +53,7 @@ def setUp(self): self.gradient = self.calc_gradient() def test_check_output(self): - self.check_output() + self.check_output(check_eager=True) def 
calc_gradient(self): x = self.inputs["X"] @@ -58,7 +61,8 @@ def calc_gradient(self): return grad, def test_check_grad(self): - self.check_grad(['X'], 'Out', user_defined_grads=self.gradient) + self.check_grad( + ['X'], 'Out', user_defined_grads=self.gradient, check_eager=True) @unittest.skipIf(not core.is_compiled_with_cuda(), @@ -66,6 +70,7 @@ def test_check_grad(self): class TestSumOp_bf16(OpTest): def setUp(self): np.random.seed(100) + self.python_api = paddle.sum self.op_type = "reduce_sum" self.dtype = np.uint16 self.x = np.random.uniform(0, 0.1, (2, 5, 10)).astype(np.float32) @@ -79,12 +84,15 @@ def setUp(self): def test_check_output(self): place = core.CUDAPlace(0) - self.check_output_with_place(place) + self.check_output_with_place(place, check_eager=True) def test_check_grad(self): place = core.CUDAPlace(0) self.check_grad_with_place( - place, ['X'], 'Out', user_defined_grads=self.gradient) + place, ['X'], + 'Out', + user_defined_grads=self.gradient, + check_eager=True) def calc_gradient(self): x = self.x @@ -94,6 +102,7 @@ def calc_gradient(self): class TestSumOp_fp16_withInt(OpTest): def setUp(self): + self.python_api = paddle.sum self.op_type = "reduce_sum" self.inputs = { # ref to https://en.wikipedia.org/wiki/Half-precision_floating-point_format @@ -107,7 +116,7 @@ def setUp(self): self.gradient = self.calc_gradient() def test_check_output(self): - self.check_output() + self.check_output(check_eager=True) def calc_gradient(self): x = self.inputs["X"] @@ -115,41 +124,47 @@ def calc_gradient(self): return grad, def test_check_grad(self): - self.check_grad(['X'], 'Out', user_defined_grads=self.gradient) + self.check_grad( + ['X'], 'Out', user_defined_grads=self.gradient, check_eager=True) class TestSumOp5D(OpTest): def setUp(self): + self.python_api = paddle.sum self.op_type = "reduce_sum" self.inputs = { 'X': np.random.random((1, 2, 5, 6, 10)).astype("float64") } + self.attrs = {'dim': [0]} self.outputs = {'Out': self.inputs['X'].sum(axis=0)} def test_check_output(self): - self.check_output() + self.check_output(check_eager=True) def test_check_grad(self): - self.check_grad(['X'], 'Out') + self.check_grad(['X'], 'Out', check_eager=True) class TestSumOp6D(OpTest): def setUp(self): + self.python_api = paddle.sum self.op_type = "reduce_sum" self.inputs = { 'X': np.random.random((1, 1, 2, 5, 6, 10)).astype("float64") } + self.attrs = {'dim': [0]} self.outputs = {'Out': self.inputs['X'].sum(axis=0)} def test_check_output(self): - self.check_output() + self.check_output(check_eager=True) def test_check_grad(self): - self.check_grad(['X'], 'Out') + self.check_grad(['X'], 'Out', check_eager=True) class TestSumOp8D(OpTest): def setUp(self): + self.python_api = paddle.sum self.op_type = "reduce_sum" self.inputs = { 'X': np.random.random((1, 3, 1, 2, 1, 4, 3, 10)).astype("float64") @@ -158,10 +173,10 @@ def setUp(self): self.outputs = {'Out': self.inputs['X'].sum(axis=(0, 3))} def test_check_output(self): - self.check_output() + self.check_output(check_eager=True) def test_check_grad(self): - self.check_grad(['X'], 'Out') + self.check_grad(['X'], 'Out', check_eager=True) @skip_check_grad_ci( diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index 3408dd7ce9384..d2ed985fb8651 100755 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -904,7 +904,18 @@ def get_dtype(x, dtype): return (False, src_type) dtype_flag, dtype = get_dtype(x, dtype) - if paddle.in_dynamic_mode(): + + if in_dygraph_mode(): + if reduce_all_flag: + axis = range(len(x.shape)) 
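            # (sketch, not part of this patch) reduce_all_flag is set when the
            # reduction covers every dimension, so the final-state kernel
            # receives the full axis list; e.g. for a rank-3 input:
            #
            #     paddle.sum(x)            # axis == range(3): reduce all dims
            #     paddle.sum(x, axis=1)    # axis == [1]: reduce along dim 1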
+ else: + axis = axis if axis != None and axis != [] else [0] + + out_dtype = convert_np_dtype_to_dtype_(dtype) + out = _C_ops.final_state_sum(x, axis, out_dtype, keepdim) + return out + + if _in_legacy_dygraph(): axis = axis if axis != None and axis != [] else [0] if dtype_flag: return _C_ops.reduce_sum(x, 'dim', axis, 'keep_dim', keepdim, diff --git a/python/paddle/utils/code_gen/api.yaml b/python/paddle/utils/code_gen/api.yaml index 2b0c562dbf9bd..b137399b71c88 100644 --- a/python/paddle/utils/code_gen/api.yaml +++ b/python/paddle/utils/code_gen/api.yaml @@ -1596,13 +1596,14 @@ # no_need_buffer : x, y - api : sum - args : (Tensor x, int64_t[] axis={}, DataType dtype=DataType::UNDEFINED, bool keep_dim=false) - output : Tensor + args : (Tensor x, int64_t[] dims={}, DataType out_dtype=paddle::experimental::DataType::UNDEFINED, bool keep_dim=false) + output : Tensor(out) infer_meta : func : SumInferMeta kernel : func : sum data_type : x + backward : sum_grad # take_along_axis - api : take_along_axis diff --git a/python/paddle/utils/code_gen/backward.yaml b/python/paddle/utils/code_gen/backward.yaml index cbcfc02ea0992..c6951fa8fc1d4 100644 --- a/python/paddle/utils/code_gen/backward.yaml +++ b/python/paddle/utils/code_gen/backward.yaml @@ -1152,6 +1152,16 @@ kernel : func : subtract_grad +- backward_api : sum_grad + forward : sum (Tensor x, int64_t[] dims={}, DataType out_dtype=paddle::experimental::DataType::UNDEFINED, bool keep_dim=false) -> Tensor(out) + args : (Tensor x, Tensor out_grad, int64_t[] dims, bool keep_dim, bool reduce_all=false) + output : Tensor(x_grad) + infer_meta : + func : UnchangedInferMeta + param : [x] + kernel : + func : sum_grad + - backward_api : take_along_axis_grad forward : take_along_axis (Tensor x, Tensor index, int axis) -> Tensor(out) args : (Tensor x, Tensor index, Tensor out_grad, int axis) From a2b80145eb4ea69fe5853a966a973576e7301e8c Mon Sep 17 00:00:00 2001 From: Zhanlue Yang Date: Mon, 4 Apr 2022 17:16:48 +0800 Subject: [PATCH 68/93] [DoubleGrad PR #7] paddle.grad() to copy backward graph before backward run (#41306) * [Refactor] refactored eager_gen.py PR #2 * [DoubleGrad PR #1] Decoupled code generation logics for Dygraph ForwardFunctions and GradNodes * Fixed minor issue * Adjusted logics of GenerateNodeCreationCodes and GenerateForwardDefinition * Fixed issues * Supported higher-order grad node generation * [DoubleGrad PR #4] Supported higher-order GradNode generation * [DoubleGrad #4] Bug Fixes to Double Grad Node Generation * Fixed yaml typo * Fixed yaml typo * fixed minor issues * [DoubleGrad PR #5] Enabled gradient computations for grad_tensors passed to paddle.grad() * Fixed minor issue * Fixed CI-Inference issue * Fixed CI-inference issues * [DoubleGrad PR #7] paddle.grad() to copy backward graph before backward run * Fixed minor issues * Fixed issue with backward graph construction logic * Fixed implementation issues with backward graph reconstruction * Fixed unittest issue * Fixed issues --- .../eager/accumulation/accumulation_node.h | 15 ++- .../eager_generated/backwards/scale_node.h | 11 +- .../auto_code_generator/eager_generator.cc | 18 +-- .../final_state_generator/eager_gen.py | 20 +-- paddle/fluid/eager/backward.cc | 116 +++++++++++++++++- .../custom_operator/custom_operator_node.h | 17 +-- paddle/fluid/eager/grad_node_info.cc | 4 + paddle/fluid/eager/grad_node_info.h | 22 +++- paddle/fluid/eager/pylayer/py_layer_node.h | 11 +- .../data_structure_tests/grad_node_test.h | 12 +- .../eager/to_static/run_program_op_node.h | 10 +- 
.../unittests/test_imperative_double_grad.py | 35 ++++++ 12 files changed, 237 insertions(+), 54 deletions(-) diff --git a/paddle/fluid/eager/accumulation/accumulation_node.h b/paddle/fluid/eager/accumulation/accumulation_node.h index 2e38d7e9e91e2..38d5533c3d606 100644 --- a/paddle/fluid/eager/accumulation/accumulation_node.h +++ b/paddle/fluid/eager/accumulation/accumulation_node.h @@ -25,7 +25,10 @@ class GradNodeAccumulation : public GradNodeBase { // Constructor: configure fwd input tensors to grad node explicit GradNodeAccumulation(AutogradMeta* meta) : GradNodeBase(1, 1) { VLOG(6) << "Construct GradNodeAccumulation"; - weak_grad_ = meta->WeakGrad(); + if (meta) { + weak_grad_ = meta->WeakGrad(); + } + SetDefaultGradInOutMeta(); } @@ -40,11 +43,6 @@ class GradNodeAccumulation : public GradNodeBase { void ClearTensorWrappers() override { VLOG(6) << "Do nothing here now"; } - bool IsTensorWrappersCleared() override { - VLOG(6) << "Do nothing here now"; - return false; - } - std::string name() { return "GradNodeAccumulation"; } /** @@ -58,6 +56,11 @@ class GradNodeAccumulation : public GradNodeBase { inline bool ReduceHooksRegistered() { return reduce_hooks_.size() != 0; } void ApplyReduceHooks(); + std::shared_ptr Copy() const override { + return std::shared_ptr( + new GradNodeAccumulation(nullptr)); + } + private: std::weak_ptr weak_grad_; diff --git a/paddle/fluid/eager/api/generated/eager_generated/backwards/scale_node.h b/paddle/fluid/eager/api/generated/eager_generated/backwards/scale_node.h index 0b942d2a06707..dd61ddc486eef 100644 --- a/paddle/fluid/eager/api/generated/eager_generated/backwards/scale_node.h +++ b/paddle/fluid/eager/api/generated/eager_generated/backwards/scale_node.h @@ -44,11 +44,6 @@ class GradNodeScale : public GradNodeBase { void ClearTensorWrappers() override { VLOG(6) << "Do nothing here now"; } - bool IsTensorWrappersCleared() override { - VLOG(6) << "Do nothing here now"; - return false; - } - void SetTensorWrappers_X( const std::vector& tensors); @@ -56,6 +51,12 @@ class GradNodeScale : public GradNodeBase { std::string name() override { return ""; } // Members: define fwd input tensors // For Scale there is no fwd input tensor needed + + std::shared_ptr Copy() const override { + auto copied_node = std::make_shared(*this); + return copied_node; + } + private: float scale_{1.0}; }; diff --git a/paddle/fluid/eager/auto_code_generator/eager_generator.cc b/paddle/fluid/eager/auto_code_generator/eager_generator.cc index f5bdbcd968452..b1be15ac86ade 100644 --- a/paddle/fluid/eager/auto_code_generator/eager_generator.cc +++ b/paddle/fluid/eager/auto_code_generator/eager_generator.cc @@ -2479,22 +2479,23 @@ static std::string GenerateGradNodeHeaderContents( "\n" " void ClearTensorWrappers() override { \n" "%s\n" - " is_tensor_wrappers_cleared = true;\n" + " SetIsTensorWrappersCleared(true);\n" " }\n" " std::string name() override { return \" GradNode%s \"; } \n " "\n" + "std::shared_ptr Copy() const override {{\n " + " auto copied_node = std::shared_ptr(new " + "GradNode%s(*this));\n " + " return copied_node;\n " + "}}\n " + "\n" " // SetX, SetY, ...\n" "%s\n" " // SetAttrMap\n" "%s\n" - " bool IsTensorWrappersCleared() override { \n" - " return is_tensor_wrappers_cleared;\n" - " }\n" " private:\n" " // TensorWrappers\n" "%s\n" - " bool is_tensor_wrappers_cleared = false;\n" - "\n" " // Attribute Map\n" "%s\n" "};"; @@ -2601,8 +2602,9 @@ static std::string GenerateGradNodeHeaderContents( std::string grad_node_str = paddle::string::Sprintf( GRAD_NODE_TEMPLATE, 
op_type, op_type, op_type, op_type, op_type, op_type, - op_type, clear_tensor_wrappers_str, op_type, set_tensor_wrappers_str, - set_attr_map_str, tensor_wrapper_members_str, attr_members_str); + op_type, clear_tensor_wrappers_str, op_type, op_type, op_type, + set_tensor_wrappers_str, set_attr_map_str, tensor_wrapper_members_str, + attr_members_str); return grad_node_str; } diff --git a/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py b/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py index 3a7e5fbcc0f86..12738b7206276 100644 --- a/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py +++ b/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py @@ -125,7 +125,13 @@ class {} : public egr::GradNodeBase {{ void ClearTensorWrappers() override {{ {} - is_tensor_wrappers_cleared = true; + SetIsTensorWrappersCleared(true); + }} + + std::shared_ptr Copy() const override {{ + auto copied_node = std::shared_ptr<{}>(new {}(*this)); + + return copied_node; }} // SetTensorWrapperX, SetTensorWrapperY, ... @@ -133,15 +139,10 @@ class {} : public egr::GradNodeBase {{ // SetAttributes {} - bool IsTensorWrappersCleared() override {{ - return is_tensor_wrappers_cleared; - }} private: // TensorWrappers {} - bool is_tensor_wrappers_cleared = false; - // Attributes {} }}; @@ -1218,9 +1219,10 @@ def GenerateNodeDeclaration(self): grad_node_name = GetGradNodeName(forward_op_name) self.node_declaration_str = NODE_DECLARATION_TEMPLATE.format( grad_node_name, grad_node_name, grad_node_name, grad_node_name, - grad_node_name, clear_tensor_wrapper_str, - set_tensor_wrapper_methods_str, set_attribute_methods_str, - tensor_wrapper_members_str, attribute_members_str) + grad_node_name, clear_tensor_wrapper_str, grad_node_name, + grad_node_name, set_tensor_wrapper_methods_str, + set_attribute_methods_str, tensor_wrapper_members_str, + attribute_members_str) logging.info(f"Generated Node Declaration: {self.node_declaration_str}") diff --git a/paddle/fluid/eager/backward.cc b/paddle/fluid/eager/backward.cc index ed286dd5fd960..3e86ad6f59b53 100644 --- a/paddle/fluid/eager/backward.cc +++ b/paddle/fluid/eager/backward.cc @@ -50,7 +50,16 @@ class GeneralGrad { for (size_t i = 0; i < num_inputs; i++) { AutogradMeta* auto_grad_meta = EagerUtils::unsafe_autograd_meta(inputs[i]); - auto target_node = auto_grad_meta->GetMutableGradNode().get(); + auto* target_node = auto_grad_meta->GetMutableGradNode().get(); + + if (orig_to_copied_node_mapping_.count(target_node)) { + target_node = orig_to_copied_node_mapping_[target_node]; + } else { + VLOG(6) << "Unable to find target node in " + "orig_to_copied_node_mapping_, likely indicating an " + "unused input"; + } + PADDLE_ENFORCE_NOT_NULL(target_node, paddle::platform::errors::Fatal( "There is no grad op for %s:[%d] or it's" @@ -249,7 +258,15 @@ class GeneralGrad { for (size_t i = 0; i < inputs.size(); ++i) { auto& input = inputs[i]; AutogradMeta* auto_grad_meta = EagerUtils::unsafe_autograd_meta(input); - auto target_node = auto_grad_meta->GetMutableGradNode().get(); + + auto* target_node = auto_grad_meta->GetMutableGradNode().get(); + if (orig_to_copied_node_mapping_.count(target_node)) { + target_node = orig_to_copied_node_mapping_[target_node]; + } else { + VLOG(6) << "Unable to find target node in " + "orig_to_copied_node_mapping_, likely indicating an unused " + "input"; + } auto iter = results_map.find(target_node); if (iter != results_map.end()) { @@ -326,6 +343,78 @@ class GeneralGrad { 
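  // (sketch, not lines from this patch) The additions below make paddle.grad()
  // clone the backward graph before it runs: every reachable GradNodeBase is
  // duplicated once through the virtual Copy() and the pair is cached in
  // orig_to_copied_node_mapping_, so RunBackward traverses and mutates only
  // the copies:
  //
  //   original:  A --edge--> B --edge--> C     (left untouched, reusable)
  //   copied:    A' --edge--> B' --edge--> C'  (what backward actually walks)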
potential_stop_nodes.clear(); depending_nodes.clear(); results_map.clear(); + copied_grad_nodes_.clear(); + orig_to_copied_node_mapping_.clear(); + } + + GradNodeBase* CopyGradNode(const std::shared_ptr& orig_node) { + if (orig_to_copied_node_mapping_.count(orig_node.get())) { + return orig_to_copied_node_mapping_[orig_node.get()]; + } + std::shared_ptr copied_node = orig_node->Copy(); + + // Save node and update mapping + orig_to_copied_node_mapping_[orig_node.get()] = copied_node.get(); + copied_grad_nodes_.push_back(copied_node); + + return copied_node.get(); + } + + void ReconstructBackwardGraph( + const std::queue& orig_init_queue) { + std::queue queue = orig_init_queue; + std::unordered_set visited; + + // BFS and recursively copy the grad nodes + while (!queue.empty()) { + GradNodeBase* orig_node = queue.front(); + queue.pop(); + if (visited.count(orig_node)) { + continue; + } + visited.insert(orig_node); + + PADDLE_ENFORCE( + orig_to_copied_node_mapping_.count(orig_node), + paddle::platform::errors::Fatal( + "Cannot reconstruct backward graph," + "unable to find copied target for certain grad node.")); + GradNodeBase* copied_node = orig_to_copied_node_mapping_[orig_node]; + + const std::vector>& orig_edges = orig_node->GetEdges(); + std::vector>& copied_edges = + copied_node->GetMutableEdges(); + for (size_t i = 0; i < orig_edges.size(); i++) { + for (size_t j = 0; j < orig_edges[i].size(); j++) { + const Edge& orig_edge = orig_edges[i][j]; + Edge& copied_edge = copied_edges[i][j]; + + std::shared_ptr orig_next_node = + orig_edge.GetMutableGradNode(); + if (!orig_next_node) continue; + + // Copy Next Node + std::shared_ptr copied_next_node; + if (orig_to_copied_node_mapping_.count(orig_next_node.get())) { + copied_next_node = + orig_to_copied_node_mapping_[orig_next_node.get()] + ->shared_from_this(); + + } else { + copied_next_node = orig_next_node->Copy(); + orig_to_copied_node_mapping_[orig_next_node.get()] = + copied_next_node.get(); + copied_grad_nodes_.push_back(copied_next_node); + } + + // Update Edge's Grad Node + copied_edge.SetGradNode(copied_next_node); + + // Update BFS queue + queue.push(orig_next_node.get()); + } + } + } } private: @@ -345,6 +434,10 @@ class GeneralGrad { std::unordered_set /* pre nodes */> depending_nodes; std::unordered_map results_map; + + std::vector> copied_grad_nodes_; + std::unordered_map orig_to_copied_node_mapping_; + DISABLE_COPY_AND_ASSIGN(GeneralGrad); }; @@ -444,6 +537,7 @@ std::vector RunBackward( // 1. Init queue with starting nodes // 2. 
Prepare initial input buffers std::queue queue; + std::queue orig_queue; std::unordered_map> node_input_buffers_dict; for (size_t i = 0; i < tensors.size(); i++) { @@ -468,6 +562,16 @@ std::vector RunBackward( // TODO(zhanlve): Copy and Modify GradNode if is_general_grad GradNodeBase* grad_node = shared_grad_node.get(); + if (is_general_grad) { + // Save orig grad node + orig_queue.push(grad_node); + + // Replace grad_node with copied grad_node + grad_node = GeneralGrad::Instance().CopyGradNode(shared_grad_node); + + // Record potential startup grad node + GeneralGrad::Instance().GetPotentialStartupNodes()->insert(grad_node); + } // Prepare GradTensorHolder if (!node_input_buffers_dict.count(grad_node)) { @@ -504,9 +608,11 @@ std::vector RunBackward( // Prepare queue, potential startup_nodes queue.push(grad_node); - if (is_general_grad) { - GeneralGrad::Instance().GetPotentialStartupNodes()->emplace(grad_node); - } + } + + if (is_general_grad) { + // Copy Backward Graph + GeneralGrad::Instance().ReconstructBackwardGraph(orig_queue); } VLOG(6) << "Update In degree Map for backward"; diff --git a/paddle/fluid/eager/custom_operator/custom_operator_node.h b/paddle/fluid/eager/custom_operator/custom_operator_node.h index 33b56fc8c863a..c483dc0ebd177 100644 --- a/paddle/fluid/eager/custom_operator/custom_operator_node.h +++ b/paddle/fluid/eager/custom_operator/custom_operator_node.h @@ -36,9 +36,10 @@ class RunCustomOpNode : public GradNodeBase { } // Functor: perform backward computations - virtual std::vector> operator()( - std::vector>& grads, - bool create_graph = false) // NOLINT + virtual std::vector> + operator()( // NOLINT + std::vector>& grads, // NOLINT + bool create_graph = false) // NOLINT override; std::string name() { @@ -64,13 +65,15 @@ class RunCustomOpNode : public GradNodeBase { } void ClearTensorWrappers() override { VLOG(6) << "Do nothing here now"; } - bool IsTensorWrappersCleared() override { - VLOG(6) << "Do nothing here now"; - return false; - } void SetAttrs(const std::vector& attr) { attrs_ = attr; } + std::shared_ptr Copy() const override { + auto copied_node = + std::shared_ptr(new RunCustomOpNode(*this)); + return copied_node; + } + public: std::unordered_map> fwd_outs; std::unordered_map> fwd_ins; diff --git a/paddle/fluid/eager/grad_node_info.cc b/paddle/fluid/eager/grad_node_info.cc index 22266ff386293..23c7ea7c5e9b4 100644 --- a/paddle/fluid/eager/grad_node_info.cc +++ b/paddle/fluid/eager/grad_node_info.cc @@ -326,6 +326,10 @@ const std::vector>& GradNodeBase::GetEdges() const { return adj_edges_; } +std::vector>& GradNodeBase::GetMutableEdges() { + return adj_edges_; +} + std::vector> GradNodeBase::ApplyGradientHooks( const std::vector>& tensors) { diff --git a/paddle/fluid/eager/grad_node_info.h b/paddle/fluid/eager/grad_node_info.h index 70fc4afa0ac71..6a70a16a2416f 100644 --- a/paddle/fluid/eager/grad_node_info.h +++ b/paddle/fluid/eager/grad_node_info.h @@ -113,7 +113,11 @@ class GradNodeBase : public std::enable_shared_from_this { virtual void ClearTensorWrappers() = 0; - virtual bool IsTensorWrappersCleared() = 0; + /** + * Self-Copy interface designed for use in DoubleGrad + * **/ + virtual std::shared_ptr Copy() const = 0; + /** * AddEdges is designed to set input tensors' backward Node as current * node's Edges. 
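// (note on the hunk above) The pure virtual Copy() added there is the
// self-copy hook used by double grad; with its template argument spelled out
// it presumably reads:
//
//   virtual std::shared_ptr<GradNodeBase> Copy() const = 0;
//
// Each concrete node touched by this patch (accumulation, scale, PyLayer,
// custom op, run_program, and the generated nodes) overrides it to clone
// itself for the copied graph.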
@@ -191,6 +195,16 @@ class GradNodeBase : public std::enable_shared_from_this { /** * GetEdges is designed to get all edges of current node**/ const std::vector>& GetEdges() const; + std::vector>& GetMutableEdges(); + + /** + * The following interfaces are designed for no_need_buffer + * **/ + bool IsTensorWrappersCleared() { return is_tensor_wrappers_cleared_; } + + void SetIsTensorWrappersCleared(bool is_tensor_wrappers_cleared) { + is_tensor_wrappers_cleared_ = is_tensor_wrappers_cleared; + } private: // TODO(zhanlve): Merge adj_edges_ into GradOutMeta @@ -218,6 +232,7 @@ class GradNodeBase : public std::enable_shared_from_this { // We handle complex to real conversion only if any complex GradIn is involved bool need_complex_to_real_ = false; int64_t next_hook_id_{0}; + bool is_tensor_wrappers_cleared_ = false; }; class Edge { @@ -246,6 +261,11 @@ class Edge { return grad_node_; } + void SetGradNode(const std::shared_ptr& node) { + VLOG(6) << "Reseting Edge's Grad Node"; + grad_node_ = node; + } + std::pair GetEdgeRankInfo() const { return std::make_pair(in_slot_id_, in_rank_); } diff --git a/paddle/fluid/eager/pylayer/py_layer_node.h b/paddle/fluid/eager/pylayer/py_layer_node.h index cd0a517afbf0f..f2e50494467c7 100644 --- a/paddle/fluid/eager/pylayer/py_layer_node.h +++ b/paddle/fluid/eager/pylayer/py_layer_node.h @@ -40,11 +40,6 @@ class GradNodePyLayer : public GradNodeBase { void ClearTensorWrappers() override { VLOG(6) << "Do nothing here now"; } - bool IsTensorWrappersCleared() override { - VLOG(6) << "Do nothing here now"; - return false; - } - std::string name() { return "GradNodePyLayer_" + std::string(Py_TYPE(ctx_)->tp_name); } @@ -72,6 +67,12 @@ class GradNodePyLayer : public GradNodeBase { } } + std::shared_ptr Copy() const override { + auto copied_node = + std::shared_ptr(new GradNodePyLayer(*this)); + return copied_node; + } + private: PyObject* ctx_{nullptr}; PyObject* outputs_{nullptr}; diff --git a/paddle/fluid/eager/tests/data_structure_tests/grad_node_test.h b/paddle/fluid/eager/tests/data_structure_tests/grad_node_test.h index dff12fdfc34a1..8500ec79ef9ba 100644 --- a/paddle/fluid/eager/tests/data_structure_tests/grad_node_test.h +++ b/paddle/fluid/eager/tests/data_structure_tests/grad_node_test.h @@ -32,7 +32,7 @@ class GradTestNode : public egr::GradNodeBase { GradTestNode() : GradNodeBase() { val_ = 1.0; } std::string name() override { return "GradTestNode"; } std::vector> operator()( - std::vector>& grads, + std::vector>& grads, // NOLINT bool create_graph = false) override { val_ = std::dynamic_pointer_cast(grads[0][0].impl()) ->data()[0]; @@ -50,10 +50,14 @@ class GradTestNode : public egr::GradNodeBase { return res; } void ClearTensorWrappers() override { VLOG(6) << "Do nothing here now"; } - bool IsTensorWrappersCleared() override { - VLOG(6) << "Do nothing here now"; - return false; + + std::shared_ptr Copy() const override { + { + auto copied_node = std::shared_ptr(new GradTestNode(*this)); + return copied_node; + } } + float val_; }; } // namespace eager_test diff --git a/paddle/fluid/eager/to_static/run_program_op_node.h b/paddle/fluid/eager/to_static/run_program_op_node.h index 79703ce06dc9b..46f48778a9656 100644 --- a/paddle/fluid/eager/to_static/run_program_op_node.h +++ b/paddle/fluid/eager/to_static/run_program_op_node.h @@ -407,10 +407,6 @@ class GradNodeRunProgram : public egr::GradNodeBase { } void ClearTensorWrappers() override { VLOG(6) << "Do nothing here now"; } - bool IsTensorWrappersCleared() override { - VLOG(6) << "Do nothing here now"; - 
return false; - } // SetAttrMap void SetAttrMap(const paddle::framework::AttributeMap &attrs) { @@ -468,6 +464,12 @@ class GradNodeRunProgram : public egr::GradNodeBase { } } + std::shared_ptr Copy() const override { + auto copied_node = + std::shared_ptr(new GradNodeRunProgram(*this)); + return copied_node; + } + private: // TensorWrappers std::vector x_; diff --git a/python/paddle/fluid/tests/unittests/test_imperative_double_grad.py b/python/paddle/fluid/tests/unittests/test_imperative_double_grad.py index 9977756f406d5..c9e41fe93ebe1 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_double_grad.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_double_grad.py @@ -639,5 +639,40 @@ def test_resnet_resnet101(self): self.assertTrue(np.array_equal(egr_g_numpy, g_numpy)) +class TestDoubleGradBasics(TestCase): + def test_matmul(self): + input_numpy = np.ones([3, 3]) * 2 + with _test_eager_guard(): + x = paddle.to_tensor( + input_numpy, stop_gradient=False, dtype='float32') + y = paddle.to_tensor( + input_numpy, stop_gradient=False, dtype='float32') + grad_out = paddle.to_tensor( + np.ones([3, 3]), stop_gradient=False, dtype='float32') + + out = paddle.matmul(x, y, False, False) + new_x_g, new_y_g = paddle.grad( + [out], [x, y], [grad_out], retain_graph=True, create_graph=True) + new_x_g.backward() + + out_ref = np.ones([3, 3]) * 12.0 + self.assertTrue(np.array_equal(out.numpy(), out_ref)) + + new_x_g_ref = np.ones([3, 3]) * 6.0 + new_y_g_ref = np.ones([3, 3]) * 6.0 + self.assertTrue(np.array_equal(new_x_g.numpy(), new_x_g_ref)) + self.assertTrue(np.array_equal(new_y_g.numpy(), new_y_g_ref)) + + x_grad_ref = np.ones([3, 3]) * 0.0 + self.assertTrue(np.array_equal(x.grad.numpy(), x_grad_ref)) + + y_grad_ref = np.ones([3, 3]) * 3.0 + self.assertTrue(np.array_equal(y.grad.numpy(), y_grad_ref)) + + grad_out_grad_ref = np.ones([3, 3]) * 6.0 + self.assertTrue( + np.array_equal(grad_out.grad.numpy(), grad_out_grad_ref)) + + if __name__ == '__main__': unittest.main() From a6b6bcbf52b31012f615aaaa76925dc3b808cebd Mon Sep 17 00:00:00 2001 From: Chen Weihang Date: Mon, 4 Apr 2022 17:22:43 +0800 Subject: [PATCH 69/93] [Phi] Add softmax with cross entropy infershape & yaml (#41351) * add infershape and forward yaml * add final_state call * add base unittests * add backward yaml and test * fix without softmax test error * add cross_entropy test --- paddle/phi/infermeta/backward.cc | 65 ++++++++ paddle/phi/infermeta/backward.h | 11 ++ paddle/phi/infermeta/binary.cc | 77 +++++++++ paddle/phi/infermeta/binary.h | 11 ++ python/paddle/fluid/layers/loss.py | 15 +- .../fluid/tests/unittests/CMakeLists.txt | 1 + .../unittests/test_cross_entropy_loss.py | 38 +++++ .../test_softmax_with_cross_entropy_op.py | 146 +++++++++++++++++- python/paddle/nn/functional/loss.py | 16 +- python/paddle/utils/code_gen/api.yaml | 11 ++ python/paddle/utils/code_gen/backward.yaml | 10 ++ 11 files changed, 390 insertions(+), 11 deletions(-) diff --git a/paddle/phi/infermeta/backward.cc b/paddle/phi/infermeta/backward.cc index e7682d78a14a1..7282c0695086a 100644 --- a/paddle/phi/infermeta/backward.cc +++ b/paddle/phi/infermeta/backward.cc @@ -14,6 +14,8 @@ limitations under the License. 
*/ #include "paddle/phi/infermeta/backward.h" +#include "paddle/phi/kernels/funcs/axis_utils.h" + namespace phi { void BilinearTensorProductGradInferMeta(const MetaTensor& x, @@ -103,6 +105,69 @@ void Conv2dTransposeDoubleGradInferMeta(const MetaTensor& x, } } +void CrossEntropyWithSoftmaxGradInferMeta(const MetaTensor& label, + const MetaTensor& softmax, + const MetaTensor& loss_grad, + bool soft_label, + bool use_softmax, + bool numeric_stable_mode, + int ignore_index, + int axis, + MetaTensor* logits_grad, + MetaConfig config) { + auto softmax_dims = softmax.dims(); + auto labels_dims = label.dims(); + auto softmax_rank = softmax_dims.size(); + PADDLE_ENFORCE_GE(axis, + -softmax_rank, + phi::errors::InvalidArgument( + "Attr(axis) value should be in range [-R, R-1], " + "R is the rank of Input(Logits).")); + PADDLE_ENFORCE_LT(axis, + softmax_rank, + phi::errors::InvalidArgument( + "Attr(axis) value should be in range [-R, R-1], " + "R is the rank of Input(Logits).")); + + axis = phi::funcs::CanonicalAxis(axis, softmax_rank); + for (int i = 0; i < softmax_rank; i++) { + if (i != axis) { + if (config.is_runtime || (softmax_dims[i] > 0 && labels_dims[i] > 0)) { + PADDLE_ENFORCE_EQ( + softmax_dims[i], + labels_dims[i], + phi::errors::InvalidArgument( + "Input(Logits) and Input(Label) should in same shape in " + "dimensions except axis.")); + } + } + } + + if (soft_label) { + if (config.is_runtime || + (softmax_dims[axis] > 0 && labels_dims[axis] > 0)) { + PADDLE_ENFORCE_EQ(softmax_dims[axis], + labels_dims[axis], + phi::errors::InvalidArgument( + "If Attr(soft_label) == true, " + "the axis dimension of " + "Input(X) and Input(Label) should be equal.")); + } + } else { + if (config.is_runtime || labels_dims[axis] > 0) { + PADDLE_ENFORCE_EQ( + labels_dims[axis], + 1UL, + phi::errors::InvalidArgument("If Attr(soft_label) == false, " + "the axis dimension of " + "Input(Label) should be 1.")); + } + } + + logits_grad->set_dims(softmax.dims()); + logits_grad->set_dtype(softmax.dtype()); +} + void GatherNdGradInferMeta(const MetaTensor& x, const MetaTensor& index, const MetaTensor& out_grad, diff --git a/paddle/phi/infermeta/backward.h b/paddle/phi/infermeta/backward.h index 4cdc048b24964..92266811de057 100644 --- a/paddle/phi/infermeta/backward.h +++ b/paddle/phi/infermeta/backward.h @@ -68,6 +68,17 @@ void Conv2dTransposeDoubleGradInferMeta(const MetaTensor& x, MetaTensor* dfilter, MetaTensor* ddout); +void CrossEntropyWithSoftmaxGradInferMeta(const MetaTensor& label, + const MetaTensor& softmax, + const MetaTensor& loss_grad, + bool soft_label, + bool use_softmax, + bool numeric_stable_mode, + int ignore_index, + int axis, + MetaTensor* logits_grad, + MetaConfig config = MetaConfig()); + void GatherNdGradInferMeta(const MetaTensor& x, const MetaTensor& index, const MetaTensor& out_grad, diff --git a/paddle/phi/infermeta/binary.cc b/paddle/phi/infermeta/binary.cc index 60db5d342b8b3..298ad14f9e04b 100644 --- a/paddle/phi/infermeta/binary.cc +++ b/paddle/phi/infermeta/binary.cc @@ -21,6 +21,7 @@ limitations under the License. 
*/ #include "paddle/phi/core/ddim.h" #include "paddle/phi/core/infermeta_utils.h" #include "paddle/phi/kernels/cpu/conv_util.h" +#include "paddle/phi/kernels/funcs/axis_utils.h" #include "paddle/phi/kernels/funcs/common_shape.h" namespace phi { @@ -753,6 +754,82 @@ void CrossInferMeta(const MetaTensor& x, out->share_lod(x); } +void CrossEntropyWithSoftmaxInferMeta(const MetaTensor& logits, + const MetaTensor& label, + bool soft_label, + bool use_softmax, + bool numeric_stable_mode, + int ignore_index, + int axis, + MetaTensor* softmax, + MetaTensor* loss, + MetaConfig config) { + auto logits_dims = logits.dims(); + auto labels_dims = label.dims(); + auto logits_rank = logits_dims.size(); + PADDLE_ENFORCE_GE(axis, + -logits_rank, + phi::errors::InvalidArgument( + "Attr(axis) value should be in range [-R, R-1], " + "R is the rank of Input(Logits).")); + PADDLE_ENFORCE_LT(axis, + logits_rank, + phi::errors::InvalidArgument( + "Attr(axis) value should be in range [-R, R-1], " + "R is the rank of Input(Logits).")); + + axis = phi::funcs::CanonicalAxis(axis, logits_rank); + for (int i = 0; i < logits_rank; i++) { + if (i != axis) { + if (config.is_runtime || (logits_dims[i] > 0 && labels_dims[i] > 0)) { + PADDLE_ENFORCE_EQ(logits_dims[i], + labels_dims[i], + phi::errors::InvalidArgument( + "Input(Logits) and Input(Label) should in " + "same shape in dimensions except axis.")); + } + } + } + + if (axis != logits_rank - 1) { + PADDLE_ENFORCE_EQ( + numeric_stable_mode, + true, + phi::errors::InvalidArgument("Attr(axis) can only be -1 " + "when not in numeric_stable_mode.")); + } + + if (soft_label) { + if (config.is_runtime || (logits_dims[axis] > 0 && labels_dims[axis] > 0)) { + PADDLE_ENFORCE_EQ(logits_dims[axis], + labels_dims[axis], + phi::errors::InvalidArgument( + "If Attr(soft_label) == true, " + "the axis dimension of " + "Input(X) and Input(Label) should be equal.")); + } + } else { + if (config.is_runtime || labels_dims[axis] > 0) { + PADDLE_ENFORCE_EQ( + labels_dims[axis], + 1UL, + phi::errors::InvalidArgument("If Attr(soft_label) == false, " + "the axis dimension of " + "Input(Label) should be 1.")); + } + } + + softmax->set_dims(logits_dims); + softmax->set_dtype(logits.dtype()); + + logits_dims[axis] = 1; + loss->set_dims(logits_dims); + loss->set_dtype(logits.dtype()); + + softmax->share_lod(logits); + loss->share_lod(logits); +} + void DistInferMeta(const MetaTensor& x, const MetaTensor& y, float p, diff --git a/paddle/phi/infermeta/binary.h b/paddle/phi/infermeta/binary.h index 296c05756f291..70c3c9dfe849d 100644 --- a/paddle/phi/infermeta/binary.h +++ b/paddle/phi/infermeta/binary.h @@ -117,6 +117,17 @@ void CrossInferMeta(const MetaTensor& x, int axis, MetaTensor* out); +void CrossEntropyWithSoftmaxInferMeta(const MetaTensor& logits, + const MetaTensor& label, + bool soft_label, + bool use_softmax, + bool numeric_stable_mode, + int ignore_index, + int axis, + MetaTensor* softmax, + MetaTensor* loss, + MetaConfig config = MetaConfig()); + void DistInferMeta(const MetaTensor& x, const MetaTensor& y, float p, diff --git a/python/paddle/fluid/layers/loss.py b/python/paddle/fluid/layers/loss.py index a1cebc2f369bd..1efcbe4ee8871 100644 --- a/python/paddle/fluid/layers/loss.py +++ b/python/paddle/fluid/layers/loss.py @@ -21,7 +21,7 @@ from . 
import nn
from .layer_function_generator import templatedoc
from ..layer_helper import LayerHelper
-from ..framework import Variable, _non_static_mode, static_only, _in_legacy_dygraph
+from ..framework import Variable, _non_static_mode, static_only, _in_legacy_dygraph, in_dygraph_mode
from .. import core
from ..data_feeder import check_variable_and_dtype, check_type
from ..param_attr import ParamAttr
@@ -1267,10 +1267,15 @@ def softmax_with_cross_entropy(logits,
                ignore_index, 'numeric_stable_mode', numeric_stable_mode,
                'axis', axis)
        else:
-            softmax, loss = _C_ops.softmax_with_cross_entropy(
-                logits, label, 'soft_label', soft_label, 'ignore_index',
-                ignore_index, 'numeric_stable_mode', numeric_stable_mode,
-                'axis', axis)
+            if in_dygraph_mode():
+                softmax, loss = _C_ops.final_state_cross_entropy_with_softmax(
+                    logits, label, soft_label, True, numeric_stable_mode,
+                    ignore_index, axis)
+            if _in_legacy_dygraph():
+                softmax, loss = _C_ops.softmax_with_cross_entropy(
+                    logits, label, 'soft_label', soft_label, 'ignore_index',
+                    ignore_index, 'numeric_stable_mode', numeric_stable_mode,
+                    'axis', axis)
        if not return_softmax:
            return loss
        else:
diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt
index 4a771990d91e1..81849606370d6 100755
--- a/python/paddle/fluid/tests/unittests/CMakeLists.txt
+++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt
@@ -969,6 +969,7 @@ set_tests_properties(test_nearest_interp_op PROPERTIES TIMEOUT 120)
set_tests_properties(test_profiler PROPERTIES TIMEOUT 120)
set_tests_properties(test_inplace_softmax_with_cross_entropy PROPERTIES TIMEOUT 120)
set_tests_properties(test_cross_entropy2_op PROPERTIES TIMEOUT 120)
+set_tests_properties(test_cross_entropy_loss PROPERTIES TIMEOUT 150)
set_tests_properties(test_fetch_unmerged PROPERTIES TIMEOUT 120)
set_tests_properties(test_gru_unit_op PROPERTIES TIMEOUT 120)
set_tests_properties(test_activation_nn_grad PROPERTIES TIMEOUT 200)
diff --git a/python/paddle/fluid/tests/unittests/test_cross_entropy_loss.py b/python/paddle/fluid/tests/unittests/test_cross_entropy_loss.py
index d3ed76e34a614..4402d875a41f6 100644
--- a/python/paddle/fluid/tests/unittests/test_cross_entropy_loss.py
+++ b/python/paddle/fluid/tests/unittests/test_cross_entropy_loss.py
@@ -21,6 +21,7 @@
from test_softmax_op import stable_softmax
from test_softmax_with_cross_entropy_op import cross_entropy
from paddle.fluid import Program, program_guard
+from paddle.fluid.framework import _test_eager_guard


def log_softmax(x, axis=-1):
@@ -1447,6 +1448,43 @@ def test_cross_entropy_loss_2d_sum(self):
        self.assertTrue(np.allclose(static_ret, expected))
        self.assertTrue(np.allclose(dy_ret_value, expected))

+    def test_soft_1d_dygraph_final_state_api(self):
+        with _test_eager_guard():
+            self.test_cross_entropy_loss_soft_1d()
+            self.test_cross_entropy_loss_soft_1d_weight()
+            self.test_cross_entropy_loss_soft_1d_mean()
+            self.test_cross_entropy_loss_soft_1d_weight_mean()
+
+    # putting all test cases in one test would fail
+    def test_soft_2d_dygraph_final_state_api(self):
+        with _test_eager_guard():
+            self.test_cross_entropy_loss_soft_2d()
+            self.test_cross_entropy_loss_soft_2d_weight_mean()
+
+    def test_other_dygraph_final_state_api(self):
+        with _test_eager_guard():
+            self.test_cross_entropy_loss_1d_with_mean_ignore()
+            self.test_cross_entropy_loss_1d_with_mean_ignore_negative()
+            self.test_cross_entropy_loss_1d_with_weight_mean_ignore()
+            self.test_cross_entropy_loss_1d_with_weight_mean_ignore_exceedlabel(
+            )
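            # (illustrative, not part of this patch) each wrapped call here
            # reruns an existing hard-label test under the eager guard; the
            # API being exercised is of the form:
            #
            #     loss = paddle.nn.functional.cross_entropy(
            #         logits, labels, reduction='mean')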
+ self.test_cross_entropy_loss_1d_with_weight_mean() + self.test_cross_entropy_loss_1d_with_weight_sum() + self.test_cross_entropy_loss_1d_with_weight_none() + self.test_cross_entropy_loss_1d_with_weight_none_func() + self.test_cross_entropy_loss_1d_mean() + self.test_cross_entropy_loss_1d_sum() + self.test_cross_entropy_loss_1d_none() + self.test_cross_entropy_loss_2d_with_weight_none() + self.test_cross_entropy_loss_2d_with_weight_axis_change_mean() + self.test_cross_entropy_loss_2d_with_weight_mean_ignore_exceedlabel( ) + self.test_cross_entropy_loss_2d_with_weight_mean() + self.test_cross_entropy_loss_2d_with_weight_sum() + self.test_cross_entropy_loss_2d_none() + self.test_cross_entropy_loss_2d_mean() + self.test_cross_entropy_loss_2d_sum() + class TestCrossEntropyFAPIError(unittest.TestCase): def test_errors(self): diff --git a/python/paddle/fluid/tests/unittests/test_softmax_with_cross_entropy_op.py b/python/paddle/fluid/tests/unittests/test_softmax_with_cross_entropy_op.py index 69f6a87dd9ed1..75d09e3df0c30 100644 --- a/python/paddle/fluid/tests/unittests/test_softmax_with_cross_entropy_op.py +++ b/python/paddle/fluid/tests/unittests/test_softmax_with_cross_entropy_op.py @@ -26,7 +26,6 @@ def cross_entropy(softmax, label, soft_label, axis, ignore_index=-1): if soft_label: return (-label * np.log(softmax)).sum(axis=axis, keepdims=True) - shape = softmax.shape axis %= len(shape) n = int(np.prod(shape[:axis])) @@ -43,6 +42,41 @@ def cross_entropy(softmax, label, soft_label, axis, ignore_index=-1): return result.reshape(label.shape) +def python_api(logits, + label, + soft_label=False, + use_softmax=True, + numeric_stable_mode=True, + ignore_index=-100, + axis=-1): + # here we can only test paddle.nn.functional.softmax_with_cross_entropy, + # since paddle.nn.functional.cross_entropy contains other math ops + return paddle.nn.functional.softmax_with_cross_entropy( + logits, + label, + soft_label=soft_label, + ignore_index=ignore_index, + numeric_stable_mode=numeric_stable_mode, + return_softmax=use_softmax, + axis=axis) + + +def python_core_api_without_softmax(logits, + label, + soft_label=False, + use_softmax=False, + numeric_stable_mode=True, + ignore_index=-100, + axis=-1): + # the API paddle.nn.functional.softmax_with_cross_entropy cannot + # set use_softmax=False, so call the core API manually + assert use_softmax is False + _, loss = paddle._C_ops.final_state_cross_entropy_with_softmax( + logits, label, soft_label, use_softmax, numeric_stable_mode, + ignore_index, axis) + return loss + + class TestSoftmaxWithCrossEntropyOp(OpTest): """ Test softmax with cross entropy operator with discrete one-hot labels.
@@ -50,6 +84,8 @@ class TestSoftmaxWithCrossEntropyOp(OpTest): def initParams(self): self.op_type = "softmax_with_cross_entropy" + self.python_api = python_api + self.python_out_sig = ["Loss", "Softmax"] self.numeric_stable_mode = False self.soft_label = False # explicitly use float32 for ROCm, as MIOpen does not yet support float64 @@ -102,13 +138,27 @@ def setUp(self): self.attrs['axis'] = self.axis def test_check_output(self): + if self.python_api is not None: + self.check_output(check_eager=True) self.check_output() def test_check_grad(self): if core.is_compiled_with_rocm(): + if self.python_api is not None: + self.check_grad( + ["Logits"], + "Loss", + max_relative_error=5e-1, + check_eager=True) # HIP will have accuracy fail when using float32 in CPU place self.check_grad(["Logits"], "Loss", max_relative_error=5e-1) else: + if self.python_api is not None: + self.check_grad( + ["Logits"], + "Loss", + numeric_grad_delta=0.001, + check_eager=True) self.check_grad(["Logits"], "Loss", numeric_grad_delta=0.001) @@ -136,6 +186,8 @@ class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_SoftLabel_1D( TestSoftmaxWithCrossEntropyOp): def initParams(self): self.op_type = "softmax_with_cross_entropy" + self.python_api = python_core_api_without_softmax + self.python_out_sig = ["Loss"] self.numeric_stable_mode = True self.soft_label = True self.shape = [13, 8] @@ -149,6 +201,8 @@ class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_HardLabel_1D( TestSoftmaxWithCrossEntropyOp): def initParams(self): self.op_type = "softmax_with_cross_entropy" + self.python_api = python_core_api_without_softmax + self.python_out_sig = ["Loss"] self.numeric_stable_mode = True self.soft_label = False self.shape = [13, 8] @@ -165,6 +219,8 @@ class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_SoftLabel_2D( TestSoftmaxWithCrossEntropyOp): def initParams(self): self.op_type = "softmax_with_cross_entropy" + self.python_api = python_core_api_without_softmax + self.python_out_sig = ["Loss"] self.numeric_stable_mode = True self.soft_label = True self.shape = [3, 5, 7, 11] @@ -178,6 +234,8 @@ class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_SoftLabel_2D_Axis2( TestSoftmaxWithCrossEntropyOp): def initParams(self): self.op_type = "softmax_with_cross_entropy" + self.python_api = python_core_api_without_softmax + self.python_out_sig = ["Loss"] self.numeric_stable_mode = True self.soft_label = True self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64 @@ -191,6 +249,8 @@ class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_SoftLabel_2D_Axis3( TestSoftmaxWithCrossEntropyOp): def initParams(self): self.op_type = "softmax_with_cross_entropy" + self.python_api = python_core_api_without_softmax + self.python_out_sig = ["Loss"] self.numeric_stable_mode = True self.soft_label = True self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64 @@ -204,6 +264,8 @@ class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_SoftLabel_2D_Axis4( TestSoftmaxWithCrossEntropyOp): def initParams(self): self.op_type = "softmax_with_cross_entropy" + self.python_api = python_core_api_without_softmax + self.python_out_sig = ["Loss"] self.numeric_stable_mode = True self.soft_label = True self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64 @@ -226,6 +288,8 @@ class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_HardLabel_2D( TestSoftmaxWithCrossEntropyOp): def initParams(self): self.op_type = "softmax_with_cross_entropy" + self.python_api = python_core_api_without_softmax + self.python_out_sig = ["Loss"]
self.numeric_stable_mode = True self.soft_label = False self.shape = [3, 5, 7, 11] @@ -239,6 +303,8 @@ class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_HardLabel_2D_Axis2( TestSoftmaxWithCrossEntropyOp): def initParams(self): self.op_type = "softmax_with_cross_entropy" + self.python_api = python_core_api_without_softmax + self.python_out_sig = ["Loss"] self.numeric_stable_mode = True self.soft_label = False self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64 @@ -252,6 +318,8 @@ class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_HardLabel_2D_Axis3( TestSoftmaxWithCrossEntropyOp): def initParams(self): self.op_type = "softmax_with_cross_entropy" + self.python_api = python_core_api_without_softmax + self.python_out_sig = ["Loss"] self.numeric_stable_mode = True self.soft_label = False self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64 @@ -265,6 +333,8 @@ class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_HardLabel_2D_Axis4( TestSoftmaxWithCrossEntropyOp): def initParams(self): self.op_type = "softmax_with_cross_entropy" + self.python_api = python_core_api_without_softmax + self.python_out_sig = ["Loss"] self.numeric_stable_mode = True self.soft_label = False self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64 @@ -287,6 +357,8 @@ class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_HardLabel_Ignore( TestSoftmaxWithCrossEntropyOp): def initParams(self): self.op_type = "softmax_with_cross_entropy" + self.python_api = python_core_api_without_softmax + self.python_out_sig = ["Loss"] self.numeric_stable_mode = False self.soft_label = False self.shape = [13, 8] @@ -300,6 +372,8 @@ class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_HardLabel_Ignore_Axis( TestSoftmaxWithCrossEntropyOp): def initParams(self): self.op_type = "softmax_with_cross_entropy" + self.python_api = python_core_api_without_softmax + self.python_out_sig = ["Loss"] self.numeric_stable_mode = False self.soft_label = False self.shape = [13, 8] @@ -313,6 +387,8 @@ class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_HardLabel_2D_Ignore( TestSoftmaxWithCrossEntropyOp): def initParams(self): self.op_type = "softmax_with_cross_entropy" + self.python_api = python_core_api_without_softmax + self.python_out_sig = ["Loss"] self.numeric_stable_mode = True self.soft_label = False self.shape = [3, 5, 7, 11] @@ -326,6 +402,8 @@ class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_HardLabel_2D_Ignore_Axis3( TestSoftmaxWithCrossEntropyOp): def initParams(self): self.op_type = "softmax_with_cross_entropy" + self.python_api = python_core_api_without_softmax + self.python_out_sig = ["Loss"] self.numeric_stable_mode = True self.soft_label = False self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64 @@ -343,6 +421,8 @@ def initParams(self): class TestSoftmaxWithCrossEntropyOpNoCudnn(TestSoftmaxWithCrossEntropyOp): def initParams(self): self.op_type = "softmax_with_cross_entropy" + self.python_api = python_api + self.python_out_sig = ["Loss", "Softmax"] self.numeric_stable_mode = True self.soft_label = False self.shape = [3, 5, 7, 11] @@ -357,6 +437,8 @@ def initParams(self): class TestSoftmaxWithCrossEntropyOpFp16(TestSoftmaxWithCrossEntropyOp): def initParams(self): self.op_type = "softmax_with_cross_entropy" + self.python_api = python_api + self.python_out_sig = ["Loss", "Softmax"] self.numeric_stable_mode = False self.soft_label = False self.shape = [3, 5, 7, 11] @@ -394,9 +476,14 @@ def setUp(self): self.attrs['axis'] = self.axis def test_check_output(self): + if 
self.python_api is not None: + self.check_output(atol=1e-2, check_eager=True) self.check_output(atol=1e-2) def test_check_grad(self): + if self.python_api is not None: + self.check_grad( + ["Logits"], "Loss", max_relative_error=0.1, check_eager=True) self.check_grad(["Logits"], "Loss", max_relative_error=0.1) @@ -404,6 +491,8 @@ class TestSoftmaxWithCrossEntropyOpNoCudnnFp16( TestSoftmaxWithCrossEntropyOpFp16): def initParams(self): self.op_type = "softmax_with_cross_entropy" + self.python_api = python_api + self.python_out_sig = ["Loss", "Softmax"] self.numeric_stable_mode = True self.soft_label = False self.shape = [3, 5, 7, 11] @@ -412,6 +501,9 @@ def initParams(self): self.dtype = np.float16 def test_check_grad(self): + if self.python_api is not None: + self.check_grad( + ["Logits"], "Loss", max_relative_error=0.1, check_eager=True) self.check_grad(["Logits"], "Loss", max_relative_error=0.1) @@ -422,6 +514,8 @@ class TestSoftmaxWithCrossEntropyOp2(TestSoftmaxWithCrossEntropyOp): def initParams(self): self.op_type = "softmax_with_cross_entropy" + self.python_api = python_api + self.python_out_sig = ["Loss", "Softmax"] self.numeric_stable_mode = True self.soft_label = True self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64 @@ -431,13 +525,23 @@ def initParams(self): self.use_softmax = True def test_check_output(self): + if self.python_api is not None: + self.check_output(check_eager=True) self.check_output() def test_check_grad(self): if core.is_compiled_with_rocm(): # HIP will have accuracy fail when using float32 in CPU place + if self.python_api is not None: + self.check_grad( + ["Logits"], + "Loss", + max_relative_error=0.1, + check_eager=True) self.check_grad(["Logits"], "Loss", max_relative_error=0.1) else: + if self.python_api is not None: + self.check_grad(["Logits"], "Loss", check_eager=True) self.check_grad(["Logits"], "Loss") @@ -448,6 +552,8 @@ class TestSoftmaxWithCrossEntropyOp3(TestSoftmaxWithCrossEntropyOp): def initParams(self): self.op_type = "softmax_with_cross_entropy" + self.python_api = python_api + self.python_out_sig = ["Loss", "Softmax"] self.numeric_stable_mode = False self.soft_label = False self.shape = [41, 37] @@ -460,6 +566,8 @@ def initParams(self): class TestSoftmaxWithCrossEntropyOp3NoCudnn(TestSoftmaxWithCrossEntropyOp3): def initParams(self): self.op_type = "softmax_with_cross_entropy" + self.python_api = python_api + self.python_out_sig = ["Loss", "Softmax"] self.numeric_stable_mode = True self.soft_label = False self.shape = [3, 5, 7, 11] @@ -477,6 +585,8 @@ class TestSoftmaxWithCrossEntropyOpAxis1(TestSoftmaxWithCrossEntropyOp): def initParams(self): self.op_type = "softmax_with_cross_entropy" + self.python_api = python_api + self.python_out_sig = ["Loss", "Softmax"] self.numeric_stable_mode = True self.soft_label = False self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64 @@ -494,6 +604,8 @@ class TestSoftmaxWithCrossEntropyOpAxis2(TestSoftmaxWithCrossEntropyOp): def initParams(self): self.op_type = "softmax_with_cross_entropy" + self.python_api = python_api + self.python_out_sig = ["Loss", "Softmax"] self.numeric_stable_mode = True self.soft_label = False self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64 @@ -511,6 +623,8 @@ class TestSoftmaxWithCrossEntropyOpAxis3(TestSoftmaxWithCrossEntropyOp): def initParams(self): self.op_type = "softmax_with_cross_entropy" + self.python_api = python_api + self.python_out_sig = ["Loss", "Softmax"] self.numeric_stable_mode = True self.soft_label = 
False self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64 @@ -528,6 +642,8 @@ class TestSoftmaxWithCrossEntropyOpAxis4(TestSoftmaxWithCrossEntropyOp): def initParams(self): self.op_type = "softmax_with_cross_entropy" + self.python_api = python_api + self.python_out_sig = ["Loss", "Softmax"] self.numeric_stable_mode = True self.soft_label = False self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64 @@ -546,6 +662,8 @@ class TestSoftmaxWithCrossEntropyOpAxisDimEqualOne( def initParams(self): self.op_type = "softmax_with_cross_entropy" + self.python_api = python_api + self.python_out_sig = ["Loss", "Softmax"] self.numeric_stable_mode = True self.soft_label = False self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64 @@ -559,6 +677,8 @@ class TestSoftmaxWithCrossEntropyOpNoCudnnFp16Axis1( TestSoftmaxWithCrossEntropyOpNoCudnnFp16): def initParams(self): self.op_type = "softmax_with_cross_entropy" + self.python_api = python_api + self.python_out_sig = ["Loss", "Softmax"] self.numeric_stable_mode = True self.soft_label = False self.shape = [3, 5, 7, 11] @@ -572,6 +692,8 @@ class TestSoftmaxWithCrossEntropyOpNoCudnnFp16Axis2( TestSoftmaxWithCrossEntropyOpNoCudnnFp16): def initParams(self): self.op_type = "softmax_with_cross_entropy" + self.python_api = python_api + self.python_out_sig = ["Loss", "Softmax"] self.numeric_stable_mode = True self.soft_label = False self.shape = [3, 5, 7, 11] @@ -585,6 +707,8 @@ class TestSoftmaxWithCrossEntropyOpNoCudnnFp16Axis3( TestSoftmaxWithCrossEntropyOpNoCudnnFp16): def initParams(self): self.op_type = "softmax_with_cross_entropy" + self.python_api = python_api + self.python_out_sig = ["Loss", "Softmax"] self.numeric_stable_mode = True self.soft_label = False self.shape = [3, 5, 7, 11] @@ -598,6 +722,8 @@ class TestSoftmaxWithCrossEntropyOpSoftLabelAxis1( TestSoftmaxWithCrossEntropyOp2): def initParams(self): self.op_type = "softmax_with_cross_entropy" + self.python_api = python_api + self.python_out_sig = ["Loss", "Softmax"] self.numeric_stable_mode = True self.soft_label = True self.shape = [3, 5, 7, 11] @@ -611,6 +737,8 @@ class TestSoftmaxWithCrossEntropyOpSoftLabelAxis2( TestSoftmaxWithCrossEntropyOp2): def initParams(self): self.op_type = "softmax_with_cross_entropy" + self.python_api = python_api + self.python_out_sig = ["Loss", "Softmax"] self.numeric_stable_mode = True self.soft_label = True self.shape = [3, 5, 7, 11] @@ -624,6 +752,8 @@ class TestSoftmaxWithCrossEntropyOpSoftLabelAxis3( TestSoftmaxWithCrossEntropyOp2): def initParams(self): self.op_type = "softmax_with_cross_entropy" + self.python_api = python_api + self.python_out_sig = ["Loss", "Softmax"] self.numeric_stable_mode = True self.soft_label = True self.shape = [3, 5, 7, 11] @@ -637,6 +767,8 @@ class TestSoftmaxWithCrossEntropyOpSoftLabelAxis4( TestSoftmaxWithCrossEntropyOp2): def initParams(self): self.op_type = "softmax_with_cross_entropy" + self.python_api = python_api + self.python_out_sig = ["Loss", "Softmax"] self.numeric_stable_mode = True self.soft_label = True self.shape = [3, 5, 7, 11] @@ -650,6 +782,8 @@ class TestSoftmaxWithCrossEntropyOpIgnoreIndexNoCudnnAxis1( TestSoftmaxWithCrossEntropyOp3): def initParams(self): self.op_type = "softmax_with_cross_entropy" + self.python_api = python_api + self.python_out_sig = ["Loss", "Softmax"] self.numeric_stable_mode = True self.soft_label = False self.shape = [3, 5, 7, 11] @@ -663,6 +797,8 @@ class TestSoftmaxWithCrossEntropyOpIgnoreIndexNoCudnnAxis2( 
TestSoftmaxWithCrossEntropyOp3): def initParams(self): self.op_type = "softmax_with_cross_entropy" + self.python_api = python_api + self.python_out_sig = ["Loss", "Softmax"] self.numeric_stable_mode = True self.soft_label = False self.shape = [3, 5, 7, 11] @@ -676,6 +812,8 @@ class TestSoftmaxWithCrossEntropyOpIgnoreIndexNoCudnnAxis3( TestSoftmaxWithCrossEntropyOp3): def initParams(self): self.op_type = "softmax_with_cross_entropy" + self.python_api = python_api + self.python_out_sig = ["Loss", "Softmax"] self.numeric_stable_mode = True self.soft_label = False self.shape = [3, 5, 7, 11] @@ -689,6 +827,8 @@ class TestSoftmaxWithCrossEntropyOpIgnoreIndexNoCudnnAxis4( TestSoftmaxWithCrossEntropyOp3): def initParams(self): self.op_type = "softmax_with_cross_entropy" + self.python_api = python_api + self.python_out_sig = ["Loss", "Softmax"] self.numeric_stable_mode = True self.soft_label = False self.shape = [3, 5, 7, 11] @@ -706,6 +846,8 @@ class TestSoftmaxWithCrossEntropyOpBoundary0(TestSoftmaxWithCrossEntropyOp): def initParams(self): self.op_type = "softmax_with_cross_entropy" + self.python_api = python_api + self.python_out_sig = ["Loss", "Softmax"] self.numeric_stable_mode = True self.soft_label = False self.shape = [3, 5, 7, 11] @@ -724,6 +866,8 @@ class TestSoftmaxWithCrossEntropyOpBoundary1(TestSoftmaxWithCrossEntropyOp): def initParams(self): self.op_type = "softmax_with_cross_entropy" + self.python_api = python_api + self.python_out_sig = ["Loss", "Softmax"] self.numeric_stable_mode = True self.soft_label = False self.shape = [3, 5, 7, 11] diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py index 3748a5904ba96..8a2b5cbb8b334 100755 --- a/python/paddle/nn/functional/loss.py +++ b/python/paddle/nn/functional/loss.py @@ -1700,7 +1700,8 @@ def cross_entropy(input, (got input_dims{}, label_dims{})'.format(input_dims, label_dims)) if input_dims - 1 == label_dims: label = paddle.unsqueeze(label, axis=axis) - if in_dynamic_mode(): + + if _non_static_mode(): if soft_label == False: valid_label = paddle.cast( label != ignore_index, dtype=label.dtype) * label @@ -1718,10 +1719,15 @@ def cross_entropy(input, ignore_index, 'numeric_stable_mode', True, 'axis', axis, 'use_softmax', use_softmax) else: - _, out = _C_ops.softmax_with_cross_entropy( - input, label, 'soft_label', soft_label, 'ignore_index', - ignore_index, 'numeric_stable_mode', True, 'axis', axis, - 'use_softmax', use_softmax) + if in_dygraph_mode(): + _, out = _C_ops.final_state_cross_entropy_with_softmax( + input, label, soft_label, use_softmax, True, ignore_index, + axis) + if _in_legacy_dygraph(): + _, out = _C_ops.softmax_with_cross_entropy( + input, label, 'soft_label', soft_label, 'ignore_index', + ignore_index, 'numeric_stable_mode', True, 'axis', axis, + 'use_softmax', use_softmax) if weight is not None: diff --git a/python/paddle/utils/code_gen/api.yaml b/python/paddle/utils/code_gen/api.yaml index b137399b71c88..af4e7a5b3bb32 100644 --- a/python/paddle/utils/code_gen/api.yaml +++ b/python/paddle/utils/code_gen/api.yaml @@ -382,6 +382,17 @@ func : cross backward : cross_grad +# Part of python API paddle.nn.functional.cross_entropy +- api : cross_entropy_with_softmax + args : (Tensor input, Tensor label, bool soft_label, bool use_softmax, bool numeric_stable_mode, int ignore_index, int axis) + output : Tensor(softmax), Tensor(loss) + infer_meta : + func : CrossEntropyWithSoftmaxInferMeta + kernel : + func : cross_entropy_with_softmax + data_type : input + backward :
cross_entropy_with_softmax_grad + - api : cumprod args : (Tensor x, int dim) output : Tensor(out) diff --git a/python/paddle/utils/code_gen/backward.yaml b/python/paddle/utils/code_gen/backward.yaml index c6951fa8fc1d4..f94d0a9e50523 100644 --- a/python/paddle/utils/code_gen/backward.yaml +++ b/python/paddle/utils/code_gen/backward.yaml @@ -223,6 +223,16 @@ kernel : func : cosh_grad +- backward_api : cross_entropy_with_softmax_grad + forward : cross_entropy_with_softmax (Tensor input, Tensor label, bool soft_label, bool use_softmax, bool numeric_stable_mode, int ignore_index, int axis) -> Tensor(softmax), Tensor(loss) + args : (Tensor label, Tensor softmax, Tensor loss_grad, bool soft_label, bool use_softmax, bool numeric_stable_mode, int ignore_index, int axis) + output : Tensor(input_grad) + infer_meta : + func : CrossEntropyWithSoftmaxGradInferMeta + kernel : + func : cross_entropy_with_softmax_grad + data_type : softmax + - backward_api : cross_grad forward : cross (Tensor x, Tensor y, int axis = 9) -> Tensor(out) args : (Tensor x, Tensor y, Tensor out_grad, int axis) From 625dd72276d9673f16ebcc889f145340a73fe679 Mon Sep 17 00:00:00 2001 From: ShenLiang <1422485404@qq.com> Date: Mon, 4 Apr 2022 18:41:53 +0800 Subject: [PATCH 70/93] fix recompute (#41396) --- .../distributed/fleet/utils/recompute.py | 147 +++++++++++++++++- .../tests/unittests/test_dygraph_recompute.py | 111 ++++++------- 2 files changed, 191 insertions(+), 67 deletions(-) diff --git a/python/paddle/distributed/fleet/utils/recompute.py b/python/paddle/distributed/fleet/utils/recompute.py index 4ccb48ef72e71..c767be77d8384 100755 --- a/python/paddle/distributed/fleet/utils/recompute.py +++ b/python/paddle/distributed/fleet/utils/recompute.py @@ -14,9 +14,11 @@ import paddle from paddle.fluid import core -from paddle.autograd import PyLayer +from paddle.autograd import PyLayer, EagerPyLayer + from paddle.fluid import framework import contextlib +from paddle.fluid.framework import in_dygraph_mode import logging logger = logging.getLogger(__name__) @@ -32,7 +34,7 @@ def detach_variable(inputs): out = [] for inp in inputs: - if not isinstance(inp, core.VarBase): + if not isinstance(inp, (core.eager.Tensor, core.VarBase)): out.append(inp) continue @@ -44,7 +46,7 @@ def detach_variable(inputs): def check_recompute_necessary(inputs): if not any(input_.stop_gradient == False for input_ in inputs - if isinstance(input_, paddle.Tensor)): + if isinstance(input_, (core.eager.Tensor, paddle.Tensor))): logger.warn( "[Recompute]: None of the inputs to current recompute block need grad, " "therefore there is NO need to recompute this block in backward !") @@ -60,6 +62,140 @@ def swith_rng_state(rng_state): paddle.set_cuda_rng_state(orig_cuda_rng_state) +class EagerRecomputeFunction(EagerPyLayer): + @staticmethod + def forward(ctx, run_function, preserve_rng_state, *args): + if framework._dygraph_tracer()._has_grad: + check_recompute_necessary(args) + + # store for recomputing + ctx.run_function = run_function + ctx.preserve_rng_state = preserve_rng_state + + # NOTE the number of outputs of backward() should be equal to the number of tensors in forward()'s input + # the order of tensors in backward()'s output should be the same as tensors in forward()'s input + # None tensor inputs will be filtered in backward inputs. 
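+ # Illustration (hypothetical call, names are examples only, not from this patch):
+ # forward(ctx, fn, True, t0, 7, t1) receives args == (t0, 7, t1); the code below
+ # then stores ctx.inputs == [None, 7, None] and ctx.tensor_indices == [0, 2],
+ # so backward() must return gradients for t0 and t1 only, in that order.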
+ + # save input for backward + ctx.inputs = [] + ctx.tensor_indices = [] + tensor_inputs = [] + for i, arg in enumerate(args): + if paddle.is_tensor(arg): + tensor_inputs.append(arg) + ctx.tensor_indices.append(i) + ctx.inputs.append(None) + else: + ctx.inputs.append(arg) + ctx.save_for_backward(*tensor_inputs) + + # NOTE recompute with RNG state restore only supports one scenario: one process for one cuda gpu. + # one process with multiple gpus and mixed gpu-cpu scenarios are not supported + if ctx.preserve_rng_state: + cur_device = paddle.get_device() + if 'gpu:' not in cur_device: + raise RuntimeError( + "Recompute with RNG preserve is not supported on current device: {}.". + format(cur_device)) + ctx.fw_cuda_rng_state = paddle.get_cuda_rng_state() + + # TODO support AMP + tracer = framework._dygraph_tracer() + ctx.is_fw_autocast = False if tracer._amp_level == core.AmpLevel.O0 else True + if tracer._amp_level == core.AmpLevel.O2: + ctx.amp_level = 'O2' + elif tracer._amp_level in (core.AmpLevel.O1, core.AmpLevel.O0): + ctx.amp_level = 'O1' + else: + raise ValueError("unsupported amp level: {}".format( + tracer._amp_level)) + + if tracer._amp_dtype == 'float16': + ctx.amp_dtype = 'float16' + elif tracer._amp_dtype in ('bfloat16', 'float32'): + ctx.amp_dtype = 'bfloat16' + else: + raise ValueError("unsupported amp dtype: {}".format( + tracer._amp_dtype)) + + ctx.amp_white_list, ctx.amp_black_list = tracer._get_amp_op_list() + + with paddle.no_grad(): + outputs = run_function(*args) + return outputs + + @staticmethod + def backward(ctx, *args): + with paddle.fluid.dygraph.guard(): + # TODO need to check whether the recompute call is valid or not + + # Restore inputs + inputs = list(ctx.inputs) + tensor_indices = ctx.tensor_indices + tensors = ctx.saved_tensor() + for i, idx in enumerate(tensor_indices): + inputs[idx] = tensors[i] + + # paddle.enable_grad() + tracer = framework._dygraph_tracer() + tracer._has_grad = True + + # NOTE support AMP + # need to restore the auto_cast state as well as the white/black op lists + if ctx.preserve_rng_state: + with swith_rng_state(ctx.fw_cuda_rng_state): + with paddle.amp.auto_cast( + enable=ctx.is_fw_autocast, + custom_white_list=ctx.amp_white_list, + custom_black_list=ctx.amp_black_list, + level=ctx.amp_level, + dtype=ctx.amp_dtype): + detached_inputs = detach_variable(tuple(inputs)) + outputs = ctx.run_function(*detached_inputs) + else: + with paddle.amp.auto_cast( + enable=ctx.is_fw_autocast, + custom_white_list=ctx.amp_white_list, + custom_black_list=ctx.amp_black_list, + level=ctx.amp_level, + dtype=ctx.amp_dtype): + detached_inputs = detach_variable(tuple(inputs)) + outputs = ctx.run_function(*detached_inputs) + + if isinstance(outputs, core.eager.Tensor): + outputs = (outputs, ) + assert len(outputs) == len(args) + + # run backward() with only the tensors that require grad + forward_outputs_with_grad = [] + # NOTE In Transformer-like networks, if the user puts the attention mask into the recompute segment output, + # pylayer will force the stop_gradient of the attention mask to be False, which will make the number of + # tensors that need grad mismatch. + # the following backward_inputs_with_grad is used to avoid this case.
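+ # For example (illustrative): if outputs == (loss, attn_mask) and only loss
+ # requires grad, the loop below pairs forward_outputs_with_grad == [loss]
+ # with backward_inputs_with_grad == [args[0]], so each recomputed output
+ # stays aligned with its incoming gradient.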
+ backward_inputs_with_grad = [] + for i in range(len(outputs)): + if isinstance( + outputs[i], + core.eager.Tensor) and not outputs[i].stop_gradient: + forward_outputs_with_grad.append(outputs[i]) + backward_inputs_with_grad.append(args[i]) + + if len(forward_outputs_with_grad) == 0: + raise RuntimeError( + "none of the outputs has requires_grad=True, so this recompute() is not necessary" + ) + + # actually run backward + with paddle.amp.auto_cast(enable=False): + paddle.autograd.backward(forward_outputs_with_grad, + backward_inputs_with_grad) + + grads = tuple( + inp.grad for inp in detached_inputs + if isinstance(inp, core.eager.Tensor)) + return grads + + class RecomputeFunction(PyLayer): @staticmethod def forward(ctx, run_function, preserve_rng_state, *args): @@ -315,4 +451,7 @@ def run_model(cuda_state, recompute_block=[], recompute_kwargs={}): raise ValueError("Unexpected keyword arguments: " + ",".join( arg for arg in kwargs)) - return RecomputeFunction.apply(function, preserve, *args) + if in_dygraph_mode(): + return EagerRecomputeFunction.apply(function, preserve, *args) + else: + return RecomputeFunction.apply(function, preserve, *args) diff --git a/python/paddle/fluid/tests/unittests/test_dygraph_recompute.py b/python/paddle/fluid/tests/unittests/test_dygraph_recompute.py index 4a4bcd2b8163c..fa9ea5d086c03 100755 --- a/python/paddle/fluid/tests/unittests/test_dygraph_recompute.py +++ b/python/paddle/fluid/tests/unittests/test_dygraph_recompute.py @@ -23,6 +23,7 @@ import random import paddle.fluid.layers as layers +from paddle.fluid.framework import _test_eager_guard def get_fc_block(block_idx, input_size, is_last=False): @@ -141,96 +142,75 @@ def run_model(recompute_block=[], class TestPyLayer(unittest.TestCase): - def test_fc_net_with_dropout(self): + def test_base_case(self, enable_autocast=False, pure_fp16=False): def check_identical(loss_ref, param_ref, grad_ref, loss, param, grad): self.assertEqual(loss_ref, loss) self.assertEqual(param_ref, param) self.assertEqual(grad_ref, grad) # without recompute - loss_ref, param_ref, grad_ref = run_model(recompute_block=[]) - - # recompute second block - loss, param, grad = run_model(recompute_block=[1]) - check_identical(loss_ref, param_ref, grad_ref, loss, param, grad) - - # recompute fourth block - loss, param, grad = run_model(recompute_block=[3]) - check_identical(loss_ref, param_ref, grad_ref, loss, param, grad) - - # recompute second to fourth block - loss, param, grad = run_model(recompute_block=[1, 2, 3]) - check_identical(loss_ref, param_ref, grad_ref, loss, param, grad) - - # recompute second & fourth block - loss, param, grad = run_model(recompute_block=[1, 3]) - check_identical(loss_ref, param_ref, grad_ref, loss, param, grad) - - def test_fc_net_without_restore_rng(self): loss_ref, param_ref, grad_ref = run_model( - recompute_block=[2], - recompute_kwargs={"preserve_rng_state": False}, - enable_autocast=True) - - def test_fc_net_with_amp(self): - def check_identical(loss_ref, param_ref, grad_ref, loss, param, grad): - self.assertEqual(loss_ref, loss) - self.assertEqual(param_ref, param) - self.assertEqual(grad_ref, grad) - - # without recompute - loss_ref, param_ref, grad_ref = run_model( - recompute_block=[], enable_autocast=True) + recompute_block=[], + enable_autocast=enable_autocast, + pure_fp16=pure_fp16) # recompute second block - loss, param, grad = run_model(recompute_block=[1], enable_autocast=True) + loss, param, grad = run_model( + recompute_block=[1], + enable_autocast=enable_autocast, + pure_fp16=pure_fp16)
check_identical(loss_ref, param_ref, grad_ref, loss, param, grad) # recompute fourth block - loss, param, grad = run_model(recompute_block=[3], enable_autocast=True) + loss, param, grad = run_model( + recompute_block=[3], + enable_autocast=enable_autocast, + pure_fp16=pure_fp16) check_identical(loss_ref, param_ref, grad_ref, loss, param, grad) # recompute second to fourth block loss, param, grad = run_model( - recompute_block=[1, 2, 3], enable_autocast=True) + recompute_block=[1, 2, 3], + enable_autocast=enable_autocast, + pure_fp16=pure_fp16) check_identical(loss_ref, param_ref, grad_ref, loss, param, grad) # recompute second & fourth block loss, param, grad = run_model( - recompute_block=[1, 3], enable_autocast=True) + recompute_block=[1, 3], + enable_autocast=enable_autocast, + pure_fp16=pure_fp16) check_identical(loss_ref, param_ref, grad_ref, loss, param, grad) - def test_fc_net_with_fp16(self): - def check_identical(loss_ref, param_ref, grad_ref, loss, param, grad): - self.assertEqual(loss_ref, loss) - self.assertEqual(param_ref, param) - self.assertEqual(grad_ref, grad) - - # without recompute - loss_ref, param_ref, grad_ref = run_model( - recompute_block=[], enable_autocast=True, pure_fp16=True) - - # recompute second block - loss, param, grad = run_model( - recompute_block=[1], enable_autocast=True, pure_fp16=True) - check_identical(loss_ref, param_ref, grad_ref, loss, param, grad) + def test_fc_net_with_dropout(self): + with _test_eager_guard(): + self.test_base_case() + self.test_base_case() - # recompute fourth block - loss, param, grad = run_model( - recompute_block=[3], enable_autocast=True, pure_fp16=True) - check_identical(loss_ref, param_ref, grad_ref, loss, param, grad) + def test_fc_net_without_restore_rng(self): + with _test_eager_guard(): + loss_ref, param_ref, grad_ref = run_model( + recompute_block=[2], + recompute_kwargs={"preserve_rng_state": False}, + enable_autocast=True) - # recompute second to fourth block - loss, param, grad = run_model( - recompute_block=[1, 2, 3], enable_autocast=True, pure_fp16=True) - check_identical(loss_ref, param_ref, grad_ref, loss, param, grad) + def test_fc_net_with_amp(self): + with _test_eager_guard(): + self.test_base_case(enable_autocast=True) + self.test_base_case(enable_autocast=True) - # recompute second & fourth block - loss, param, grad = run_model( - recompute_block=[1, 3], enable_autocast=True, pure_fp16=True) - check_identical(loss_ref, param_ref, grad_ref, loss, param, grad) + def test_fc_net_with_fp16(self): + with _test_eager_guard(): + self.test_base_case(enable_autocast=True, pure_fp16=True) + self.test_base_case(enable_autocast=True, pure_fp16=True) def test_recompute_kwargs(self): + with _test_eager_guard(): + paddle.set_device("gpu") + kwargs = {"is_test": False} + with self.assertRaises(ValueError): + loss_ref, param_ref, grad_ref = run_model( + recompute_block=[2], recompute_kwargs=kwargs) paddle.set_device("gpu") kwargs = {"is_test": False} with self.assertRaises(ValueError): @@ -238,6 +218,11 @@ def test_recompute_kwargs(self): recompute_block=[2], recompute_kwargs=kwargs) def test_recompute_cpu_rng(self): + with _test_eager_guard(): + paddle.set_device("cpu") + with self.assertRaises(RuntimeError): + loss_ref, param_ref, grad_ref = run_model(recompute_block=[2]) + paddle.set_device("cpu") with self.assertRaises(RuntimeError): loss_ref, param_ref, grad_ref = run_model(recompute_block=[2]) From 1b031987c563038dc33370182e978ffe32b54abe Mon Sep 17 00:00:00 2001 From: Haohongxiang 
<86215757+haohongxiang@users.noreply.github.com> Date: Mon, 4 Apr 2022 18:48:21 +0800 Subject: [PATCH 71/93] [Dygraph] Support sparse tensor in refactored reducer (#40836) * [Dygraph] Support sparse tensor in refactored reducer * add uts * refactor * update * fix bugs --- .../fluid/distributed/collective/reducer.cc | 233 +++++++++++++++--- paddle/fluid/distributed/collective/reducer.h | 3 + .../fluid/tests/unittests/CMakeLists.txt | 7 +- .../parallel_dygraph_sparse_embedding.py | 5 +- .../parallel_dygraph_sparse_embedding_fp64.py | 1 - .../parallel_dygraph_unused_variables.py | 1 - .../test_parallel_dygraph_sparse_embedding.py | 42 ++++ ..._parallel_dygraph_sparse_embedding_gloo.py | 30 +++ ...el_dygraph_sparse_embedding_over_height.py | 27 ++ ...graph_sparse_embedding_over_height_gloo.py | 15 ++ .../test_parallel_dygraph_sync_batch_norm.py | 16 ++ .../test_parallel_dygraph_transformer.py | 16 ++ .../test_parallel_dygraph_transformer_gloo.py | 15 ++ .../test_parallel_dygraph_unused_variables.py | 66 +++++ 14 files changed, 440 insertions(+), 37 deletions(-) diff --git a/paddle/fluid/distributed/collective/reducer.cc b/paddle/fluid/distributed/collective/reducer.cc index ec02406efc818..71741515c90d5 100644 --- a/paddle/fluid/distributed/collective/reducer.cc +++ b/paddle/fluid/distributed/collective/reducer.cc @@ -360,6 +360,7 @@ void EagerReducer::InitializeGroups( is_sparse_gradient_[tensor_indices_.front()]) { // process the sparse gradient. one sparse, one group group.dtype_ = first_var.dtype(); + group.is_sparse_ = true; } else { // process the dense gradient. InitializeDenseGroups(tensor_indices_, &group); @@ -391,6 +392,12 @@ void EagerReducer::InitializeDenseGroups( auto &tensor = tensors_[tensor_index]; auto &tensor_name = tensor.name(); + PADDLE_ENFORCE_EQ(is_sparse_gradient_[tensor_index], false, + platform::errors::PreconditionNotMet( + "Tensor %s's GRAD must be Tensor, but received " + "GRAD is SelectedRows", + tensor_name)); + PADDLE_ENFORCE_EQ(tensor.is_initialized(), true, platform::errors::PreconditionNotMet( "Tensor %s is not initialized.", tensor_name)); @@ -480,6 +487,7 @@ void EagerReducer::PrepareForBackward(const std::vector &outputs) { next_group_ = 0; std::for_each(groups_.begin(), groups_.end(), [](EagerGroup &group) { group.pending_ = group.tensor_indices_.size(); + group.sparse_contents_ = Tensor(); }); // reinitialize vars_marked_ready_ for next iteration @@ -544,9 +552,6 @@ void EagerReducer::AddDistHook(size_t var_index) { return; } - auto &tensor = tensors_[var_index]; - const auto &grad_node = GetGradNodeFromTensor(&tensor); - VLOG(3) << "Tensor[" << var_index << "] [" << tensors_[var_index].name() << "@Grad] arrived and triggered disthook"; @@ -608,33 +613,69 @@ void EagerReducer::MarkVarReady(const size_t var_index, auto &group_tensor = group.dense_tensors_[inside_group_index]; const auto length = group.length_[inside_group_index]; - if (is_used_var) { - auto *autograd_meta = tensors_[var_index].get_autograd_meta(); - auto &grad_tensor = static_cast(autograd_meta)->Grad(); - group_tensor - .ShareDataWith( - *(std::dynamic_pointer_cast(grad_tensor.impl()))) - .Resize({grad_tensor.numel()}); - } else { - // TODO(shenliang03): maybe save the memory by avoiding tensor construction - if (!group_tensor.initialized()) { - group_tensor.Resize({static_cast(length)}); - group_tensor.mutable_data(inner_place_, group.dtype_); - } - if (HasGrad(var_index)) { - VLOG(3) << "Tensor[" << tensors_[var_index].name() << "] has grad"; - auto grad_tensor = 
egr::EagerUtils::mutable_grad(tensors_[var_index]); + if (!group.is_sparse_) { + if (is_used_var) { + auto *autograd_meta = tensors_[var_index].get_autograd_meta(); + auto &grad_tensor = + static_cast(autograd_meta)->Grad(); group_tensor .ShareDataWith(*( - std::dynamic_pointer_cast(grad_tensor.impl()))) - .Resize({grad_tensor.numel()}); + std::dynamic_pointer_cast(grad_tensor.impl()))) + .Resize({grad_tensor.numel()}); } else { - VLOG(3) << "Tensor[" << tensors_[var_index].name() - << "] doesn't have grad"; - auto *dev_ctx = platform::DeviceContextPool::Instance().Get(inner_place_); - group_tensor.Resize({static_cast(length)}); - phi::funcs::set_constant(*dev_ctx, &group_tensor, 0.0); + // TODO(shenliang03): maybe save the memory by avoiding tensor + // construction + if (!group_tensor.initialized()) { + group_tensor.Resize({static_cast(length)}); + group_tensor.mutable_data(inner_place_, group.dtype_); + } + if (HasGrad(var_index)) { + VLOG(3) << "Tensor[" << tensors_[var_index].name() << "] has grad"; + auto grad_tensor = egr::EagerUtils::mutable_grad(tensors_[var_index]); + group_tensor + .ShareDataWith(*(std::dynamic_pointer_cast( + grad_tensor->impl()))) + .Resize({length}); + } else { + VLOG(3) << "Tensor[" << tensors_[var_index].name() + << "] doesn't have grad"; + auto *dev_ctx = + platform::DeviceContextPool::Instance().Get(inner_place_); + group_tensor.Resize({static_cast(length)}); + phi::funcs::set_constant(*dev_ctx, &group_tensor, 0.0); + } } + } else { + auto *autograd_meta = tensors_[var_index].get_autograd_meta(); + auto &grad_tensor = static_cast(autograd_meta)->Grad(); + + // process sparse group + PADDLE_ENFORCE_EQ( + HasGrad(var_index), true, + platform::errors::PreconditionNotMet( + "The sparse parameter[%d][%s] should have gradient. " + "Currently, DataParallel does not support sparse " + "parameters without generating gradients during training. " + "For example, if is_sparse=True is used in Embedding, " + "the current step of this parameter cannot generate gradient " + "because of stop_gradient/detach, so an error will occur.", + var_index, tensors_[var_index].name())); + + // need to check tensor type + PADDLE_ENFORCE_EQ( + grad_tensor.is_selected_rows(), true, + platform::errors::PreconditionNotMet( + "The sparse parameter[%d][%s] must have a SelectedRows gradient. " + "Before forward pass, the parameter type is inferred to be " + "SelectedRows, but after backward pass, its actual type becomes " + "LoDTensor. It is currently not supported by DataParallel. 
" + "For example, if sparse embedding is used, and the weight of " + "embedding is shared with subsequent dense parameters, then " + "the parameter gradient of the embedding will be converted " + "to dense parameters.", + var_index, tensors_[var_index].name())); + + group.sparse_contents_.set_impl(grad_tensor.impl()); } if (--group.pending_ == 0) { @@ -666,7 +707,11 @@ void EagerReducer::MarkGroupReady(size_t group_index) { for (; next_group_ < groups_.size() && groups_[next_group_].pending_ == 0; ++next_group_) { UNUSED auto &group = groups_[next_group_]; - FusedAllReduceSchedule(&group, next_group_); + if (group.is_sparse_) { + AllReduceSparse(&group, next_group_); + } else { + FusedAllReduceSchedule(&group, next_group_); + } } } @@ -725,6 +770,11 @@ void EagerReducer::ProcessUnusedDenseVars() { const auto inside_group_index = var_locator.inside_group_index; auto &src_tensor = group.dense_tensors_[inside_group_index]; + // sparse no need to check and no support find_unused_parameters + if (group.is_sparse_) { + continue; + } + Tensor grad_value(std::make_shared(src_tensor)); auto dest_var_base = tensors_[var_index]; @@ -739,11 +789,15 @@ void EagerReducer::FinalizeBackward() { groups_need_finalize_ = false; grad_need_hooks_ = false; for (auto &group : groups_) { - group.task->Synchronize(); + if (!group.is_sparse_) { + group.task->Synchronize(); + } } for (auto &group : groups_) { - group.SplitTensors(inner_place_); + if (!group.is_sparse_) { + group.SplitTensors(inner_place_); + } } if (find_unused_vars_each_step_) { @@ -778,6 +832,127 @@ void EagerReducer::FusedAllReduceSchedule(EagerGroup *group, // split in FinalizeBackward() } +void EagerReducer::AllReduceSparse(EagerGroup *group, + const int curr_group_index) { + // div nranks + Tensor sparse_tensor(group->sparse_contents_); + paddle::experimental::scale_(sparse_tensor, 1.0 / nranks_, 0.0, false); + + VLOG(3) << "sparse_group [" << curr_group_index << "] start allreduce."; + + auto *dev_ctx = platform::DeviceContextPool::Instance().Get(inner_place_); + if (platform::is_gpu_place(inner_place_)) { +#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) + dev_ctx = static_cast( + platform::DeviceContextPool::Instance().Get(inner_place_)); +#else + PADDLE_THROW(platform::errors::PermissionDenied( + "Paddle can't concat grad tensors since it's not compiled with NCCL," + "Please recompile or reinstall Paddle with NCCL support.")); +#endif + } else if (platform::is_cpu_place(inner_place_)) { + dev_ctx = static_cast( + platform::DeviceContextPool::Instance().Get(inner_place_)); + } else { + PADDLE_THROW(platform::errors::Unimplemented( + "Split grad tensor not supported on place (%s)", inner_place_)); + } + + auto src = std::dynamic_pointer_cast( + group->sparse_contents_.impl()); + const auto &src_rows = src->rows(); + + const auto &rank_ = process_group_->GetRank(); + const auto &size_ = process_group_->GetSize(); + + framework::Vector rows_num_vector(size_); + rows_num_vector[rank_] = static_cast(src_rows.size()); + + Tensor rows_num_tensor = paddle::experimental::empty( + IntArray({static_cast(size_)}), DataType::INT64, inner_place_); + auto *rows_num_dense_tensor = + std::dynamic_pointer_cast(rows_num_tensor.impl()).get(); + framework::TensorFromVector(rows_num_vector, *dev_ctx, + rows_num_dense_tensor); + + distributed::AllreduceOptions opts; + opts.reduce_op = ReduceOp::SUM; + std::vector reduce_tensors = {rows_num_tensor}; + process_group_->AllReduce(reduce_tensors, opts)->Synchronize(); + + 
framework::TensorToVector(*rows_num_dense_tensor, *dev_ctx, + &rows_num_vector); + dev_ctx->Wait(); + + const auto *cpu_rows_num_ptr = rows_num_vector.data(); + auto rows_num = std::accumulate(cpu_rows_num_ptr, cpu_rows_num_ptr + size_, + static_cast(0)); + + VLOG(3) << "Gather rows: " << string::join_strings(rows_num_vector, ',') + << ", total rows number: " << rows_num + << ", height: " << src->height(); + + dev_ctx->Wait(); + + if (std::all_of(cpu_rows_num_ptr, cpu_rows_num_ptr + size_, + [&](int64_t row) { return row == cpu_rows_num_ptr[0]; })) { + // During sparse communication, the number of rows on each card is the same. + // allgather is used to speed up the allreduce by replacing broadcast. + + VLOG(3) << "allgather replaces broadcast to speed up sparse allreduce"; + + Tensor dst_rows_tensor = + paddle::experimental::empty(IntArray({static_cast(rows_num)}), + DataType::INT64, inner_place_); + Tensor src_rows_tensor = paddle::experimental::empty( + IntArray({static_cast((*src).rows().size())}), DataType::INT64, + inner_place_); + auto *src_rows_dense_tensor = + std::dynamic_pointer_cast(src_rows_tensor.impl()) + .get(); + framework::TensorFromVector((*src).rows(), *dev_ctx, + src_rows_dense_tensor); + + std::vector src_rows_tensors = {src_rows_tensor}; + std::vector dst_rows_tensors = {dst_rows_tensor}; + process_group_->AllGather(src_rows_tensors, dst_rows_tensors) + ->Synchronize(); + + framework::Vector dst_rows_vector(rows_num, 0); + auto *dst_rows_dense_tensor = + std::dynamic_pointer_cast(dst_rows_tensor.impl()) + .get(); + framework::TensorToVector(*dst_rows_dense_tensor, *dev_ctx, + &dst_rows_vector); + dev_ctx->Wait(); + + Tensor src_value_tensor(std::make_shared(src->value())); + std::vector dst_shape = src_value_tensor.shape(); + dst_shape[dst_shape.size() - 2] = rows_num; + auto dst_dense_tensor = std::dynamic_pointer_cast( + paddle::experimental::full(IntArray(dst_shape), 0, + src_value_tensor.dtype(), inner_place_) + .impl()); + + auto dst = + std::make_shared(dst_rows_vector, (*src).height()); + *(dst->mutable_value()) = *dst_dense_tensor; + Tensor dst_value_tensor(std::make_shared(dst->value())); + + std::vector src_value_tensors = {src_value_tensor}; + std::vector dst_value_tensors = {dst_value_tensor}; + process_group_->AllGather(src_value_tensors, dst_value_tensors) + ->Synchronize(); + + src->set_rows(dst_rows_vector); + *(src->mutable_value()) = + *(std::dynamic_pointer_cast(dst_value_tensor.impl())); + } else { + PADDLE_THROW( + platform::errors::Unimplemented("This case is not supported.")); + } +} + std::ostream &operator<<(std::ostream &out, const EagerGroup &group) { const auto &tensors_ = group.tensor_indices_; out << "numel: " << group.all_length_ << " ;var number: " << tensors_.size() diff --git a/paddle/fluid/distributed/collective/reducer.h b/paddle/fluid/distributed/collective/reducer.h index 848277f5fad4e..12c02509884e9 100644 --- a/paddle/fluid/distributed/collective/reducer.h +++ b/paddle/fluid/distributed/collective/reducer.h @@ -47,6 +47,8 @@ std::vector> Eager_AssignGroupBySize( class EagerGroup { public: Tensor dense_contents_; + Tensor sparse_contents_; + bool is_sparse_ = false; // for concat kernel std::vector dense_tensors_; @@ -104,6 +106,7 @@ class EagerReducer { void MarkVarReady(const size_t var_index, const bool is_used_var); void MarkGroupReady(const size_t group_index); void FusedAllReduceSchedule(EagerGroup *group, const int curr_group_index); + void AllReduceSparse(EagerGroup *group, const int curr_group_index); void FinalizeBackward(); 
void TraverseBackwardGraph(const std::vector &outputs); void ProcessUnusedDenseVars(); diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index 81849606370d6..663dd9b9e1257 100755 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -1128,7 +1128,7 @@ set_tests_properties(test_split_program PROPERTIES TIMEOUT 120) if(WITH_DISTRIBUTE AND WITH_GPU AND WITH_NCCL) set_tests_properties(test_parallel_dygraph_dataparallel PROPERTIES TIMEOUT 120) set_tests_properties(test_parallel_dygraph_mnist PROPERTIES TIMEOUT 120) - set_tests_properties(test_parallel_dygraph_unused_variables PROPERTIES TIMEOUT 150) + set_tests_properties(test_parallel_dygraph_unused_variables PROPERTIES TIMEOUT 300) set_tests_properties(test_parallel_dygraph_control_flow PROPERTIES TIMEOUT 200) set_tests_properties(test_parallel_dygraph_control_flow_in_eager_mode PROPERTIES TIMEOUT 150) set_tests_properties(test_parallel_dygraph_no_sync PROPERTIES TIMEOUT 150) @@ -1153,8 +1153,9 @@ if(WITH_DISTRIBUTE AND WITH_GPU AND WITH_NCCL) set_tests_properties(test_eager_dist_api PROPERTIES TIMEOUT 300) if(${NCCL_VERSION} VERSION_GREATER_EQUAL 2212) - set_tests_properties(test_parallel_dygraph_sparse_embedding PROPERTIES TIMEOUT 120) - set_tests_properties(test_parallel_dygraph_transformer PROPERTIES TIMEOUT 120) + set_tests_properties(test_parallel_dygraph_sparse_embedding PROPERTIES TIMEOUT 200) + set_tests_properties(test_parallel_dygraph_sparse_embedding_over_height PROPERTIES TIMEOUT 150) + set_tests_properties(test_parallel_dygraph_transformer PROPERTIES TIMEOUT 150) endif() endif() diff --git a/python/paddle/fluid/tests/unittests/parallel_dygraph_sparse_embedding.py b/python/paddle/fluid/tests/unittests/parallel_dygraph_sparse_embedding.py index 226f1293ef688..33ae0acf43d12 100644 --- a/python/paddle/fluid/tests/unittests/parallel_dygraph_sparse_embedding.py +++ b/python/paddle/fluid/tests/unittests/parallel_dygraph_sparse_embedding.py @@ -42,7 +42,6 @@ def __init__(self, dtype=dtype, is_sparse=is_sparse, param_attr=fluid.ParamAttr( - name='embedding_param', initializer=fluid.initializer.UniformInitializer( low=-init_scale, high=init_scale))) self.softmax_weight = self.create_parameter( @@ -103,8 +102,8 @@ def get_model(self): train_reader = paddle.batch( fake_sample_reader(), batch_size=batch_size, drop_last=True) - optimizer = fluid.optimizer.SGD(learning_rate=0.001, - parameter_list=model.parameters()) + optimizer = paddle.optimizer.SGD(learning_rate=0.001, + parameters=model.parameters()) return model, train_reader, optimizer diff --git a/python/paddle/fluid/tests/unittests/parallel_dygraph_sparse_embedding_fp64.py b/python/paddle/fluid/tests/unittests/parallel_dygraph_sparse_embedding_fp64.py index a15b263a29508..b341a227285b1 100644 --- a/python/paddle/fluid/tests/unittests/parallel_dygraph_sparse_embedding_fp64.py +++ b/python/paddle/fluid/tests/unittests/parallel_dygraph_sparse_embedding_fp64.py @@ -40,7 +40,6 @@ def __init__(self, self.hidden_size, sparse=True, weight_attr=paddle.ParamAttr( - name='embedding_param', initializer=paddle.nn.initializer.Uniform( low=-init_scale, high=init_scale))) self.softmax_weight = self.create_parameter( diff --git a/python/paddle/fluid/tests/unittests/parallel_dygraph_unused_variables.py b/python/paddle/fluid/tests/unittests/parallel_dygraph_unused_variables.py index 9f877381101e9..b4dd03aecfaf3 100644 --- 
a/python/paddle/fluid/tests/unittests/parallel_dygraph_unused_variables.py +++ b/python/paddle/fluid/tests/unittests/parallel_dygraph_unused_variables.py @@ -39,7 +39,6 @@ def __init__(self, self.hidden_size, sparse=is_sparse, weight_attr=paddle.ParamAttr( - name='embedding_param', initializer=paddle.nn.initializer.Uniform( low=-init_scale, high=init_scale))) self.softmax_weight = self.create_parameter( diff --git a/python/paddle/fluid/tests/unittests/test_parallel_dygraph_sparse_embedding.py b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_sparse_embedding.py index 43907da609803..30349270b9ead 100644 --- a/python/paddle/fluid/tests/unittests/test_parallel_dygraph_sparse_embedding.py +++ b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_sparse_embedding.py @@ -64,5 +64,47 @@ def test_sparse_embedding_with_spawn(self): test_class=TestSparseEmbedding, delta=1e-5) +class TestParallelDygraphSparseEmdeddingEager(TestDistBase): + def _setup_config(self): + self._sync_mode = False + self._nccl2_mode = True + self._eager_mode = True + self._dygraph = True + + def test_sparse_embedding(self): + if fluid.core.is_compiled_with_cuda(): + self.check_with_place( + "parallel_dygraph_sparse_embedding.py", + delta=1e-5, + check_error_log=True, + log_name=flag_name) + + +class TestParallelDygraphSparseEmdeddingFP64Eager(TestDistBase): + def _setup_config(self): + self._sync_mode = False + self._eager_mode = True + self._nccl2_mode = True + self._dygraph = True + + def test_sparse_embedding_fp64(self): + if fluid.core.is_compiled_with_cuda(): + self.check_with_place( + "parallel_dygraph_sparse_embedding_fp64.py", + delta=1e-5, + check_error_log=True, + log_name=flag_name) + + +class TestParallelDygraphSparseEmdeddingSpawnEager(TestDistSpawnRunner): + def _args_config(self, args): + args.eager_mode = True + + def test_sparse_embedding_with_spawn(self): + if fluid.core.is_compiled_with_cuda() and sys.version_info >= (3, 4): + self.check_dist_result_with_spawn( + test_class=TestSparseEmbedding, delta=1e-5) + + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_parallel_dygraph_sparse_embedding_gloo.py b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_sparse_embedding_gloo.py index 56fcf806c4717..e461bf2a26f41 100644 --- a/python/paddle/fluid/tests/unittests/test_parallel_dygraph_sparse_embedding_gloo.py +++ b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_sparse_embedding_gloo.py @@ -55,5 +55,35 @@ def test_sparse_embedding_fp64(self): log_name=flag_name) +class TestParallelDygraphSparseEmdeddingEager_GLOO(TestDistBase): + def _setup_config(self): + self._sync_mode = False + self._eager_mode = True + self._gloo_mode = True + self._dygraph = True + + def test_sparse_embedding(self): + self.check_with_place( + "parallel_dygraph_sparse_embedding.py", + delta=1e-5, + check_error_log=True, + log_name=flag_name) + + +class TestParallelDygraphSparseEmdeddingEagerFP64_GLOO(TestDistBase): + def _setup_config(self): + self._sync_mode = False + self._eager_mode = True + self._gloo_mode = True + self._dygraph = True + + def test_sparse_embedding_fp64(self): + self.check_with_place( + "parallel_dygraph_sparse_embedding_fp64.py", + delta=1e-5, + check_error_log=True, + log_name=flag_name) + + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_parallel_dygraph_sparse_embedding_over_height.py b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_sparse_embedding_over_height.py index 
9aca448f16121..fb4c992d35fe9 100644 --- a/python/paddle/fluid/tests/unittests/test_parallel_dygraph_sparse_embedding_over_height.py +++ b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_sparse_embedding_over_height.py @@ -48,5 +48,32 @@ def test_sparse_embedding_with_spawn(self): test_class=TestSparseEmbeddingOverHeight, delta=1e-5) +class TestParallelDygraphSparseEmdeddingOverHeightEager(TestDistBase): + def _setup_config(self): + self._sync_mode = False + self._eager_mode = True + self._nccl2_mode = True + self._dygraph = True + + def test_sparse_embedding(self): + if fluid.core.is_compiled_with_cuda(): + self.check_with_place( + "parallel_dygraph_sparse_embedding_over_height.py", + delta=1e-5, + check_error_log=True, + log_name=flag_name) + + +class TestParallelDygraphSparseEmdeddingOverHeightSpawnEager( + TestDistSpawnRunner): + def _args_config(self, args): + args.eager_mode = True + + def test_sparse_embedding_with_spawn(self): + if fluid.core.is_compiled_with_cuda() and sys.version_info >= (3, 4): + self.check_dist_result_with_spawn( + test_class=TestSparseEmbeddingOverHeight, delta=1e-5) + + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_parallel_dygraph_sparse_embedding_over_height_gloo.py b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_sparse_embedding_over_height_gloo.py index ba43e26e23a4e..0acec54ca62b3 100644 --- a/python/paddle/fluid/tests/unittests/test_parallel_dygraph_sparse_embedding_over_height_gloo.py +++ b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_sparse_embedding_over_height_gloo.py @@ -40,5 +40,20 @@ def test_sparse_embedding(self): log_name=flag_name) +class TestParallelDygraphSparseEmdeddingOverHeightEager_GLOO(TestDistBase): + def _setup_config(self): + self._sync_mode = False + self._eager_mode = True + self._gloo_mode = True + self._dygraph = True + + def test_sparse_embedding(self): + self.check_with_place( + "parallel_dygraph_sparse_embedding_over_height.py", + delta=1e-7, + check_error_log=True, + log_name=flag_name) + + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_parallel_dygraph_sync_batch_norm.py b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_sync_batch_norm.py index 7cf1e9711b74b..3a7a32c2ec9dc 100644 --- a/python/paddle/fluid/tests/unittests/test_parallel_dygraph_sync_batch_norm.py +++ b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_sync_batch_norm.py @@ -36,5 +36,21 @@ def test_mnist(self): log_name=flag_name) +class TestParallelDygraphMnistEager(TestDistBase): + def _setup_config(self): + self._sync_mode = False + self._eager_mode = True + self._nccl2_mode = True + self._dygraph = True + + def test_mnist(self): + if fluid.core.is_compiled_with_cuda(): + self.check_with_place( + "parallel_dygraph_sync_batch_norm.py", + delta=1e-5, + check_error_log=True, + log_name=flag_name) + + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_parallel_dygraph_transformer.py b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_transformer.py index e0aab8541a542..2141cceb790fe 100644 --- a/python/paddle/fluid/tests/unittests/test_parallel_dygraph_transformer.py +++ b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_transformer.py @@ -65,5 +65,21 @@ def test_transformer(self): log_name=flag_name) +class TestParallelDygraphTransformerEager(TestDistBase): + def _setup_config(self): + self._sync_mode = False + self._eager_mode = True + 
self._nccl2_mode = True + self._dygraph = True + + def test_transformer(self): + if fluid.core.is_compiled_with_cuda(): + self.check_with_place( + "parallel_dygraph_transformer.py", + delta=1e-5, + check_error_log=True, + log_name=flag_name) + + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_parallel_dygraph_transformer_gloo.py b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_transformer_gloo.py index d3619cc1b9a00..6d4dd6433ae03 100644 --- a/python/paddle/fluid/tests/unittests/test_parallel_dygraph_transformer_gloo.py +++ b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_transformer_gloo.py @@ -57,5 +57,20 @@ def test_transformer(self): log_name=flag_name) +class TestParallelDygraphTransformerEager_GLOO(TestDistBase): + def _setup_config(self): + self._sync_mode = False + self._eager_mode = True + self._gloo_mode = True + self._dygraph = True + + def test_transformer(self): + self.check_with_place( + "parallel_dygraph_transformer.py", + delta=1e-5, + check_error_log=True, + log_name=flag_name) + + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_parallel_dygraph_unused_variables.py b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_unused_variables.py index 75fa6f7c71d0a..f2225111d1ee7 100644 --- a/python/paddle/fluid/tests/unittests/test_parallel_dygraph_unused_variables.py +++ b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_unused_variables.py @@ -86,5 +86,71 @@ def test_mnist(self): log_name=flag_name) +class TestParallelDygraphUnusedVarEager(TestDistBase): + def _setup_config(self): + self._sync_mode = False + self._eager_mode = True + self._nccl2_mode = True + self._dygraph = True + + def test_net(self): + if fluid.core.is_compiled_with_cuda(): + self.check_with_place( + "parallel_dygraph_unused_variables.py", + delta=1e-5, + check_error_log=True, + log_name=flag_name) + + +class TestDygraphUnusedVarEager(TestParallelDygraphUnusedVar): + def _setup_config(self): + self._sync_mode = False + self._eager_mode = True + self._nccl2_mode = True + self._dygraph = True + + +class TestSparseEmbeddingUnusedVarsSpawnEager(TestDistSpawnRunner): + def _args_config(self, args): + args.eager_mode = True + + def test_mnist_with_spawn(self): + if fluid.core.is_compiled_with_cuda() and sys.version_info >= (3, 4): + self.check_dist_result_with_spawn( + test_class=TestSparseEmbeddingUnusedVars, delta=1e-5) + + +class TestParallelDygraphNoVarEager(TestDistBase): + def _setup_config(self): + self._sync_mode = False + self._eager_mode = True + self._nccl2_mode = True + self._dygraph = True + + def test_net(self): + if fluid.core.is_compiled_with_cuda(): + self.check_with_place( + "parallel_dygraph_none_var.py", + delta=1e-5, + check_error_log=True, + log_name=flag_name) + + +class TestParallelDygraphSharedUnusedVariablesEager(TestDistBase): + def _setup_config(self): + self._sync_mode = False + self._eager_mode = True + self._nccl2_mode = True + self._dygraph = True + + def test_mnist(self): + if fluid.core.is_compiled_with_cuda(): + self.check_with_place( + "parallel_dygraph_shared_unused_var.py", + delta=1e-5, + check_error_log=True, + log_name=flag_name) + + if __name__ == "__main__": unittest.main() From fa250aa13246e456b405973484acff06e6313804 Mon Sep 17 00:00:00 2001 From: hong <43953930+phlrain@users.noreply.github.com> Date: Mon, 4 Apr 2022 18:49:33 +0800 Subject: [PATCH 72/93] Add expand as sigmoid api (#41311) * update expand and sigmoid with cross entropy * 
skip expand as infrt check * fix sigmoid cross entropy bug * remove no grad set white list * remove no grad set * fix bug * fix sigmoid error * fix bug --- python/paddle/fluid/layers/loss.py | 4 + .../paddle/fluid/tests/unittests/op_test.py | 2 +- .../unittests/test_bce_with_logits_loss.py | 44 +++-- .../tests/unittests/test_expand_as_v2_op.py | 42 ++--- ...st_sigmoid_cross_entropy_with_logits_op.py | 153 ++++++++++-------- .../unittests/test_sigmoid_focal_loss.py | 6 + python/paddle/nn/functional/loss.py | 16 +- python/paddle/tensor/manipulation.py | 3 + python/paddle/utils/code_gen/api.yaml | 11 ++ python/paddle/utils/code_gen/backward.yaml | 10 ++ tools/infrt/skipped_phi_api.json | 2 +- 11 files changed, 173 insertions(+), 120 deletions(-) diff --git a/python/paddle/fluid/layers/loss.py b/python/paddle/fluid/layers/loss.py index 1efcbe4ee8871..f3ebfb9de10cf 100644 --- a/python/paddle/fluid/layers/loss.py +++ b/python/paddle/fluid/layers/loss.py @@ -1463,6 +1463,10 @@ def sigmoid_cross_entropy_with_logits(x, ignore_index=-1, normalize=True) print(loss) """ + + if in_dygraph_mode(): + return _C_ops.final_state_sigmoid_cross_entropy_with_logits( + x, label, normalize, int(ignore_index)) check_variable_and_dtype(x, 'input', ['float16', 'float32', 'float64'], 'sigmoid_cross_entropy_with_logits') diff --git a/python/paddle/fluid/tests/unittests/op_test.py b/python/paddle/fluid/tests/unittests/op_test.py index 60064340b198a..cfe0d4e32ef7a 100644 --- a/python/paddle/fluid/tests/unittests/op_test.py +++ b/python/paddle/fluid/tests/unittests/op_test.py @@ -2106,7 +2106,7 @@ def _get_dygraph_grad(self, grad_outputs = [] for grad_out_value in user_defined_grad_outputs: grad_outputs.append(paddle.to_tensor(grad_out_value)) - # delete the inputs which no need to calculate grad + # delete the inputs which no need to calculate grad for no_grad_val in no_grad_set: del (inputs[no_grad_val]) diff --git a/python/paddle/fluid/tests/unittests/test_bce_with_logits_loss.py b/python/paddle/fluid/tests/unittests/test_bce_with_logits_loss.py index 153b8fd3e7f6b..ea6d82d15ce0c 100644 --- a/python/paddle/fluid/tests/unittests/test_bce_with_logits_loss.py +++ b/python/paddle/fluid/tests/unittests/test_bce_with_logits_loss.py @@ -17,6 +17,7 @@ import numpy as np import unittest from op_test import OpTest +from paddle.fluid.framework import _test_eager_guard def call_bce_layer(logit, label, weight=None, reduction='mean', @@ -81,23 +82,22 @@ def test_dygraph(place, reduction='mean', pos_weight_np=None, functional=False): - paddle.disable_static() - logit = paddle.to_tensor(logit_np) - label = paddle.to_tensor(label_np) - weight = None - pos_weight = None - if weight_np is not None: - weight = paddle.to_tensor(weight_np) - if pos_weight_np is not None: - pos_weight = paddle.to_tensor(pos_weight_np) - if functional: - dy_res = call_bce_functional(logit, label, weight, reduction, - pos_weight) - else: - dy_res = call_bce_layer(logit, label, weight, reduction, pos_weight) - dy_result = dy_res.numpy() - paddle.enable_static() - return dy_result + with paddle.fluid.dygraph.base.guard(): + logit = paddle.to_tensor(logit_np) + label = paddle.to_tensor(label_np) + weight = None + pos_weight = None + if weight_np is not None: + weight = paddle.to_tensor(weight_np) + if pos_weight_np is not None: + pos_weight = paddle.to_tensor(pos_weight_np) + if functional: + dy_res = call_bce_functional(logit, label, weight, reduction, + pos_weight) + else: + dy_res = call_bce_layer(logit, label, weight, reduction, pos_weight) + dy_result = 
dy_res.numpy() + return dy_result def calc_bce_with_logits_loss(logit_np, @@ -154,9 +154,19 @@ def test_BCEWithLogitsLoss(self): label_np, reduction=reduction, functional=True) + + with _test_eager_guard(): + eager_functional = test_dygraph( + place, + logit_np, + label_np, + reduction=reduction, + functional=True) + self.assertTrue(np.allclose(static_functional, expected)) self.assertTrue(np.allclose(static_functional, dy_functional)) self.assertTrue(np.allclose(dy_functional, expected)) + self.assertTrue(np.allclose(eager_functional, expected)) def test_BCEWithLogitsLoss_weight(self): logit_np = np.random.uniform( diff --git a/python/paddle/fluid/tests/unittests/test_expand_as_v2_op.py b/python/paddle/fluid/tests/unittests/test_expand_as_v2_op.py index 416a60b8ba200..3bf6868fed9c9 100755 --- a/python/paddle/fluid/tests/unittests/test_expand_as_v2_op.py +++ b/python/paddle/fluid/tests/unittests/test_expand_as_v2_op.py @@ -21,78 +21,63 @@ import paddle.fluid as fluid -class TestExpandAsOpRank1(OpTest): +class TestExpandAsBasic(OpTest): def setUp(self): self.op_type = "expand_as_v2" self.python_api = paddle.expand_as x = np.random.rand(100).astype("float64") target_tensor = np.random.rand(2, 100).astype("float64") - self.inputs = {'X': x} + self.inputs = {'X': x, "Y": target_tensor} self.attrs = {'target_shape': target_tensor.shape} bcast_dims = [2, 1] output = np.tile(self.inputs['X'], bcast_dims) self.outputs = {'Out': output} def test_check_output(self): - self.check_output() + self.check_output(check_eager=True) def test_check_grad(self): - self.check_grad(['X'], 'Out') + self.check_grad(['X'], 'Out', check_eager=True) -class TestExpandAsOpRank2(OpTest): +class TestExpandAsOpRank2(TestExpandAsBasic): def setUp(self): self.op_type = "expand_as_v2" + self.python_api = paddle.expand_as x = np.random.rand(10, 12).astype("float64") target_tensor = np.random.rand(10, 12).astype("float64") - self.inputs = {'X': x} + self.inputs = {'X': x, "Y": target_tensor} self.attrs = {'target_shape': target_tensor.shape} bcast_dims = [1, 1] output = np.tile(self.inputs['X'], bcast_dims) self.outputs = {'Out': output} - def test_check_output(self): - self.check_output() - - def test_check_grad(self): - self.check_grad(['X'], 'Out') - -class TestExpandAsOpRank3(OpTest): +class TestExpandAsOpRank3(TestExpandAsBasic): def setUp(self): self.op_type = "expand_as_v2" + self.python_api = paddle.expand_as x = np.random.rand(2, 3, 20).astype("float64") target_tensor = np.random.rand(2, 3, 20).astype("float64") - self.inputs = {'X': x} + self.inputs = {'X': x, "Y": target_tensor} self.attrs = {'target_shape': target_tensor.shape} bcast_dims = [1, 1, 1] output = np.tile(self.inputs['X'], bcast_dims) self.outputs = {'Out': output} - def test_check_output(self): - self.check_output() - def test_check_grad(self): - self.check_grad(['X'], 'Out') - - -class TestExpandAsOpRank4(OpTest): +class TestExpandAsOpRank4(TestExpandAsBasic): def setUp(self): self.op_type = "expand_as_v2" + self.python_api = paddle.expand_as x = np.random.rand(1, 1, 7, 16).astype("float64") target_tensor = np.random.rand(4, 6, 7, 16).astype("float64") - self.inputs = {'X': x} + self.inputs = {'X': x, "Y": target_tensor} self.attrs = {'target_shape': target_tensor.shape} bcast_dims = [4, 6, 1, 1] output = np.tile(self.inputs['X'], bcast_dims) self.outputs = {'Out': output} - def test_check_output(self): - self.check_output() - - def test_check_grad(self): - self.check_grad(['X'], 'Out') - class TestExpandAsV2Error(unittest.TestCase): def 
test_errors(self): @@ -130,4 +115,5 @@ def test_api(self): if __name__ == "__main__": + paddle.enable_static() unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_sigmoid_cross_entropy_with_logits_op.py b/python/paddle/fluid/tests/unittests/test_sigmoid_cross_entropy_with_logits_op.py index 51751588f7b94..e5406f4d0c224 100644 --- a/python/paddle/fluid/tests/unittests/test_sigmoid_cross_entropy_with_logits_op.py +++ b/python/paddle/fluid/tests/unittests/test_sigmoid_cross_entropy_with_logits_op.py @@ -22,6 +22,12 @@ import unittest from paddle.fluid import compiler, Program, program_guard import paddle.fluid as fluid +import paddle + + +def test_fluid_sigmoid(x, label, normalize=False, ignore_index=-100): + return paddle.fluid.layers.sigmoid_cross_entropy_with_logits( + x, label, int(ignore_index), normalize=normalize) class TestSigmoidCrossEntropyWithLogitsOp1(OpTest): @@ -30,6 +36,7 @@ class TestSigmoidCrossEntropyWithLogitsOp1(OpTest): def setUp(self): self.op_type = "sigmoid_cross_entropy_with_logits" + self.python_api = test_fluid_sigmoid batch_size = 64 num_classes = 20 self.inputs = { @@ -49,10 +56,10 @@ def setUp(self): self.outputs = {'Out': -term1 - term2} def test_check_output(self): - self.check_output() + self.check_output(check_eager=True) def test_check_grad(self): - self.check_grad(['X'], 'Out') + self.check_grad(['X'], 'Out', check_eager=True) class TestSigmoidCrossEntropyWithLogitsOp2(OpTest): @@ -61,6 +68,7 @@ class TestSigmoidCrossEntropyWithLogitsOp2(OpTest): def setUp(self): self.op_type = "sigmoid_cross_entropy_with_logits" + self.python_api = test_fluid_sigmoid batch_size = 64 num_classes = 20 ignore_index = -1 @@ -83,10 +91,10 @@ def setUp(self): self.outputs = {'Out': out} def test_check_output(self): - self.check_output() + self.check_output(check_eager=True) def test_check_grad(self): - self.check_grad(['X'], 'Out') + self.check_grad(['X'], 'Out', check_eager=True) class TestSigmoidCrossEntropyWithLogitsOp3(OpTest): @@ -95,6 +103,7 @@ class TestSigmoidCrossEntropyWithLogitsOp3(OpTest): def setUp(self): self.op_type = "sigmoid_cross_entropy_with_logits" + self.python_api = test_fluid_sigmoid batch_size = 64 num_classes = 20 self.inputs = { @@ -114,15 +123,16 @@ def setUp(self): self.outputs = {'Out': -term1 - term2} def test_check_output(self): - self.check_output() + self.check_output(check_eager=True) def test_check_grad(self): - self.check_grad(['X'], 'Out') + self.check_grad(['X'], 'Out', check_eager=True) class TestSigmoidCrossEntropyWithNorm(OpTest): def setUp(self): self.op_type = "sigmoid_cross_entropy_with_logits" + self.python_api = test_fluid_sigmoid batch_size = 64 num_classes = 20 ignore_index = -1 @@ -145,10 +155,10 @@ def setUp(self): self.outputs = {'Out': out} def test_check_output(self): - self.check_output() + self.check_output(check_eager=True) def test_check_grad(self): - self.check_grad(['X'], 'Out') + self.check_grad(['X'], 'Out', check_eager=True) class TestSigmoidCrossEntropyWithLogitsOp5(OpTest): @@ -157,6 +167,7 @@ class TestSigmoidCrossEntropyWithLogitsOp5(OpTest): def setUp(self): self.op_type = "sigmoid_cross_entropy_with_logits" + self.python_api = test_fluid_sigmoid batch_size = [10, 10] num_classes = 20 self.inputs = { @@ -176,15 +187,16 @@ def setUp(self): self.outputs = {'Out': -term1 - term2} def test_check_output(self): - self.check_output() + self.check_output(check_eager=True) def test_check_grad(self): - self.check_grad(['X'], 'Out') + self.check_grad(['X'], 'Out', check_eager=True) class 
TestSigmoidCrossEntropyWithNorm2(OpTest): def setUp(self): self.op_type = "sigmoid_cross_entropy_with_logits" + self.python_api = test_fluid_sigmoid batch_size = [10, 10] num_classes = 20 ignore_index = -1 @@ -207,68 +219,71 @@ def setUp(self): self.outputs = {'Out': out} def test_check_output(self): - self.check_output() + self.check_output(check_eager=True) def test_check_grad(self): - self.check_grad(['X'], 'Out') - - -class TestSigmoidCrossEntropyWithLogitsOp6(OpTest): - """Test sigmoid_cross_entropy_with_logit_op with binary label - """ - - def setUp(self): - self.op_type = "sigmoid_cross_entropy_with_logits" - batch_size = [10, 10] - num_classes = 20 - self.inputs = { - 'X': logit( - np.random.uniform(0, 1, tuple(batch_size + [num_classes])) - .astype("float64")), - 'Label': np.random.randint(0, 2, tuple(batch_size + [num_classes])) - .astype("float64") - } - - # Fw Pass is implemented as elementwise sigmoid followed by - # elementwise logistic loss - # Label * -log(sigmoid(X)) + (1 - label) * -log(1 - sigmoid(X)) - sigmoid_X = expit(self.inputs['X']) - term1 = self.inputs['Label'] * np.log(sigmoid_X) - term2 = (1 - self.inputs['Label']) * np.log(1 - sigmoid_X) - self.outputs = {'Out': -term1 - term2} - - def test_check_output(self): - self.check_output() - - def test_check_grad(self): - self.check_grad(['X'], 'Out') - - -class TestSigmoidCrossEntropyWithLogitsOpError(unittest.TestCase): - def test_errors(self): - with program_guard(Program(), Program()): - - def test_Variable(): - # the input of sigmoid_cross_entropy_with_logits must be Variable. - x1 = fluid.create_lod_tensor( - np.array([-1, 3, 5, 5]), [[1, 1, 1, 1]], fluid.CPUPlace()) - lab1 = fluid.create_lod_tensor( - np.array([-1, 3, 5, 5]), [[1, 1, 1, 1]], fluid.CPUPlace()) - fluid.layers.sigmoid_cross_entropy_with_logits(x1, lab1) - - self.assertRaises(TypeError, test_Variable) - - def test_dtype(): - # the input dtype of sigmoid_cross_entropy_with_logits must be float16 or float32 or float64 - # float16 only can be set on GPU place - x2 = fluid.layers.data( - name='x2', shape=[3, 4, 5, 6], dtype="int32") - lab2 = fluid.layers.data( - name='lab2', shape=[3, 4, 5, 6], dtype="int32") - fluid.layers.sigmoid_cross_entropy_with_logits(x2, lab2) - - self.assertRaises(TypeError, test_dtype) + self.check_grad(['X'], 'Out', check_eager=True) + + class TestSigmoidCrossEntropyWithLogitsOp6(OpTest): + """Test sigmoid_cross_entropy_with_logit_op with binary label + """ + + def setUp(self): + self.op_type = "sigmoid_cross_entropy_with_logits" + self.python_api = test_fluid_sigmoid + batch_size = [10, 10] + num_classes = 20 + self.inputs = { + 'X': logit( + np.random.uniform(0, 1, tuple(batch_size + [num_classes])) + .astype("float64")), + 'Label': + np.random.randint(0, 2, tuple(batch_size + [num_classes])) + .astype("float64") + } + + # Fw Pass is implemented as elementwise sigmoid followed by + # elementwise logistic loss + # Label * -log(sigmoid(X)) + (1 - label) * -log(1 - sigmoid(X)) + sigmoid_X = expit(self.inputs['X']) + term1 = self.inputs['Label'] * np.log(sigmoid_X) + term2 = (1 - self.inputs['Label']) * np.log(1 - sigmoid_X) + self.outputs = {'Out': -term1 - term2} + + def test_check_output(self): + self.check_output(check_eager=True) + + def test_check_grad(self): + self.check_grad(['X'], 'Out', check_eager=True) + + class TestSigmoidCrossEntropyWithLogitsOpError(unittest.TestCase): + def test_errors(self): + with program_guard(Program(), Program()): + + def test_Variable(): + # the input of sigmoid_cross_entropy_with_logits 
must be Variable. + x1 = fluid.create_lod_tensor( + np.array([-1, 3, 5, 5]), [[1, 1, 1, 1]], + fluid.CPUPlace()) + lab1 = fluid.create_lod_tensor( + np.array([-1, 3, 5, 5]), [[1, 1, 1, 1]], + fluid.CPUPlace()) + fluid.layers.sigmoid_cross_entropy_with_logits(x1, lab1) + + self.assertRaises(TypeError, test_Variable) + + def test_dtype(): + # the input dtype of sigmoid_cross_entropy_with_logits must be float16 or float32 or float64 + # float16 only can be set on GPU place + x2 = fluid.layers.data( + name='x2', shape=[3, 4, 5, 6], dtype="int32") + lab2 = fluid.layers.data( + name='lab2', shape=[3, 4, 5, 6], dtype="int32") + fluid.layers.sigmoid_cross_entropy_with_logits(x2, lab2) + + self.assertRaises(TypeError, test_dtype) if __name__ == '__main__': + paddle.enable_static() unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_sigmoid_focal_loss.py b/python/paddle/fluid/tests/unittests/test_sigmoid_focal_loss.py index 2ef04d9cbfa73..15a4827cecba3 100644 --- a/python/paddle/fluid/tests/unittests/test_sigmoid_focal_loss.py +++ b/python/paddle/fluid/tests/unittests/test_sigmoid_focal_loss.py @@ -18,6 +18,7 @@ import unittest from op_test import OpTest from test_sigmoid_focal_loss_op import sigmoid_focal_loss_forward +from paddle.fluid.framework import _test_eager_guard def call_sfl_functional(logit, @@ -140,6 +141,10 @@ def test_SigmoidFocalLoss(self): dy_result = test_dygraph(place, logit_np, label_np, normalizer_np, alpha, gamma, reduction) + with _test_eager_guard(): + eager_result = test_dygraph( + place, logit_np, label_np, normalizer_np, + alpha, gamma, reduction) expected = calc_sigmoid_focal_loss( logit_np, label_np, normalizer_np, alpha, gamma, reduction) @@ -148,6 +153,7 @@ def test_SigmoidFocalLoss(self): self.assertTrue( np.allclose(static_result, dy_result)) self.assertTrue(np.allclose(dy_result, expected)) + self.assertTrue(np.allclose(eager_result, expected)) def test_SigmoidFocalLoss_error(self): paddle.disable_static() diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py index 8a2b5cbb8b334..593cea2d2cf64 100755 --- a/python/paddle/nn/functional/loss.py +++ b/python/paddle/nn/functional/loss.py @@ -259,12 +259,16 @@ def binary_cross_entropy_with_logits(logit, "should be 'sum', 'mean' or 'none', but received %s, which is not allowed." % reduction) - if in_dynamic_mode(): + if _non_static_mode(): one = _varbase_creator(dtype=logit.dtype) _C_ops.fill_constant(one, 'value', float(1.0), 'force_cpu', False, 'dtype', one.dtype, 'str_value', '1.0', 'shape', [1]) - out = _C_ops.sigmoid_cross_entropy_with_logits(logit, label) + if in_dygraph_mode(): + out = _C_ops.final_state_sigmoid_cross_entropy_with_logits( + logit, label, False, -100) + else: + out = _C_ops.sigmoid_cross_entropy_with_logits(logit, label) if pos_weight is not None: log_weight = _C_ops.elementwise_add( _C_ops.elementwise_mul(label, @@ -2024,12 +2028,16 @@ def sigmoid_focal_loss(logit, "Expected one dimension of normalizer in sigmoid_focal_loss but got {}.". 
format(normalizer_dims)) - if in_dynamic_mode(): + if _non_static_mode(): one = _varbase_creator(dtype=logit.dtype) _C_ops.fill_constant(one, 'value', float(1.0), 'force_cpu', False, 'dtype', one.dtype, 'str_value', '1.0', 'shape', logit.shape) - loss = _C_ops.sigmoid_cross_entropy_with_logits(logit, label) + if in_dygraph_mode(): + loss = _C_ops.final_state_sigmoid_cross_entropy_with_logits( + logit, label, False, -100) + else: + loss = _C_ops.sigmoid_cross_entropy_with_logits(logit, label) pred = _C_ops.sigmoid(logit) p_t = _C_ops.elementwise_add( _C_ops.elementwise_mul(pred, label), diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index b055abcf845f9..92fec23c6c769 100755 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -1837,6 +1837,9 @@ def expand_as(x, y, name=None): np_out = out.numpy() # [[1, 2, 3], [1, 2, 3]] """ + if in_dygraph_mode(): + return _C_ops.final_state_expand_as(x, None, y.shape) + if _non_static_mode(): return _C_ops.expand_as_v2(x, 'target_shape', y.shape) diff --git a/python/paddle/utils/code_gen/api.yaml b/python/paddle/utils/code_gen/api.yaml index af4e7a5b3bb32..4c17644792fbd 100644 --- a/python/paddle/utils/code_gen/api.yaml +++ b/python/paddle/utils/code_gen/api.yaml @@ -566,6 +566,17 @@ func : erfinv backward : erfinv_grad +# expand_as +- api : expand_as + args : (Tensor x, Tensor y, int[] target_shape) + output : Tensor + infer_meta : + func : ExpandAsInferMeta + kernel : + func : expand_as + optional : y + backward : expand_as_grad + - api : expm1 args : (Tensor x) output : Tensor diff --git a/python/paddle/utils/code_gen/backward.yaml b/python/paddle/utils/code_gen/backward.yaml index f94d0a9e50523..da60dae431695 100644 --- a/python/paddle/utils/code_gen/backward.yaml +++ b/python/paddle/utils/code_gen/backward.yaml @@ -373,6 +373,16 @@ kernel : func : erfinv_grad +- backward_api : expand_as_grad + forward : expand_as (Tensor x, Tensor y, int[] target_shape) -> Tensor(out) + args : (Tensor x, Tensor out_grad, int[] target_shape) + output : Tensor(x_grad) + infer_meta : + func : UnchangedInferMeta + param : [x] + kernel : + func : expand_as_grad + - backward_api : expm1_grad forward : expm1 (Tensor x) -> Tensor(out) args : (Tensor out, Tensor out_grad) diff --git a/tools/infrt/skipped_phi_api.json b/tools/infrt/skipped_phi_api.json index 74cb6fb0e5356..5638cf506c84d 100644 --- a/tools/infrt/skipped_phi_api.json +++ b/tools/infrt/skipped_phi_api.json @@ -1,4 +1,4 @@ { -"phi_apis":["conj", "nll_loss", "dropout", "flatten"], +"phi_apis":["conj", "nll_loss", "flatten", "expand_as", "dropout"], "phi_kernels":["equal_all"] } From 489b8a88a1cd10a4d09ec29ffa23b0834d9b3faf Mon Sep 17 00:00:00 2001 From: chentianyu03 Date: Mon, 4 Apr 2022 19:30:16 +0800 Subject: [PATCH 73/93] [Yaml]add clip yaml (#41337) * add clip yaml * import _test_eager_guad * add default value to scalar * add clip_grad default value * fix test failed --- .../fluid/tests/unittests/test_clip_op.py | 10 ++++++++-- python/paddle/tensor/math.py | 18 ++++++++++++++++-- python/paddle/utils/code_gen/api.yaml | 11 +++++++++++ python/paddle/utils/code_gen/backward.yaml | 10 ++++++++++ 4 files changed, 45 insertions(+), 4 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_clip_op.py b/python/paddle/fluid/tests/unittests/test_clip_op.py index 74c5f693a37f1..f4423ccd0294c 100644 --- a/python/paddle/fluid/tests/unittests/test_clip_op.py +++ b/python/paddle/fluid/tests/unittests/test_clip_op.py @@ -20,11 +20,13 @@ 
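(Context for the test_clip_op.py hunk below: check_eager=True makes OpTest re-verify the op through the eager final-state path, and _test_eager_guard re-runs dygraph assertions under eager mode. A minimal sketch of that pattern, assuming a built paddle wheel; the helper name check_clip_dygraph is illustrative and not part of the patch:)

import paddle
from paddle.fluid.framework import _test_eager_guard

def check_clip_dygraph():
    # the same assertion should hold in legacy dygraph and in eager mode
    x = paddle.to_tensor([0.1, 0.5, 0.9])
    out = paddle.clip(x, min=0.2, max=0.8)
    assert float(out.max()) <= 0.8

check_clip_dygraph()       # legacy dygraph
with _test_eager_guard():  # same checks under eager mode
    check_clip_dygraph()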
import paddle.fluid as fluid from paddle.fluid import Program, program_guard from op_test import OpTest +from paddle.fluid.framework import _test_eager_guard class TestClipOp(OpTest): def setUp(self): self.max_relative_error = 0.006 + self.python_api = paddle.clip self.inputs = {} self.initTestCase() @@ -51,12 +53,12 @@ def setUp(self): def test_check_output(self): paddle.enable_static() - self.check_output() + self.check_output(check_eager=True) paddle.disable_static() def test_check_grad_normal(self): paddle.enable_static() - self.check_grad(['X'], 'Out') + self.check_grad(['X'], 'Out', check_eager=True) paddle.disable_static() def initTestCase(self): @@ -228,6 +230,10 @@ def test_clip_dygraph(self): self.assertTrue( np.allclose(out_5.numpy(), (data * 10).astype(np.int64).clip(2, 8))) + def test_eager(self): + with _test_eager_guard(): + self.test_clip_dygraph() + def test_errors(self): paddle.enable_static() x1 = fluid.data(name='x1', shape=[1], dtype="int16") diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index d2ed985fb8651..e4faa573ffb26 100755 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -2290,7 +2290,16 @@ def clip(x, min=None, max=None, name=None): min_ = float(np.finfo(np.float32).min) max_ = float(np.finfo(np.float32).max) - if paddle.in_dynamic_mode(): + if in_dygraph_mode(): + if isinstance(min, Variable): + min = min.numpy().item(0) + if isinstance(max, Variable): + max = max.numpy().item(0) + min = min_ if min is None else min + max = max_ if max is None else max + return _C_ops.final_state_clip(x, min, max) + + if _in_legacy_dygraph(): if isinstance(min, Variable): min = min.numpy().item(0) if isinstance(max, Variable): @@ -2350,7 +2359,12 @@ def clip_(x, min=None, max=None, name=None): max = max.numpy().item(0) min = fmin if min is None else min max = fmax if max is None else max - return _C_ops.clip_(x, "min", min, "max", max) + + if in_dygraph_mode(): + return _C_ops.final_state_clip_(x, min, max) + + if _in_legacy_dygraph(): + return _C_ops.clip_(x, "min", min, "max", max) diff --git a/python/paddle/utils/code_gen/api.yaml b/python/paddle/utils/code_gen/api.yaml index 4c17644792fbd..08cf04f692806 100644 --- a/python/paddle/utils/code_gen/api.yaml +++ b/python/paddle/utils/code_gen/api.yaml @@ -312,6 +312,17 @@ func : cholesky_solve backward : cholesky_solve_grad +- api : clip + args : (Tensor x, Scalar(float) min, Scalar(float) max) + output : Tensor(out) + inplace : (x -> out) + infer_meta : + func : UnchangedInferMeta + param : [x] + kernel : + func : clip + backward : clip_grad + - api : concat args : (Tensor[] x, Scalar(int64_t) axis) output : Tensor diff --git a/python/paddle/utils/code_gen/backward.yaml b/python/paddle/utils/code_gen/backward.yaml index da60dae431695..570e64dcd5e12 100644 --- a/python/paddle/utils/code_gen/backward.yaml +++ b/python/paddle/utils/code_gen/backward.yaml @@ -179,6 +179,16 @@ kernel : func : cholesky_solve_grad +- backward_api : clip_grad + forward : clip (Tensor x, Scalar min, Scalar max) -> Tensor(out) + args : (Tensor x, Tensor out_grad, Scalar min = 0., Scalar max = 0.) 
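(The clip_grad backward entry being added here follows the standard clipping gradient rule: the upstream gradient passes through where x lies inside [min, max] and is zeroed elsewhere. A minimal NumPy reference sketch of that rule, illustrative only and not the kernel itself:)

import numpy as np

def clip_grad_reference(x, dout, min_v, max_v):
    # gradient of clip(x, min_v, max_v): pass-through inside the range, zero outside
    mask = (x >= min_v) & (x <= max_v)
    return dout * mask.astype(dout.dtype)

x = np.array([-2.0, 0.5, 3.0])
print(clip_grad_reference(x, np.ones_like(x), 0.0, 1.0))  # [0. 1. 0.]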
+ output : Tensor(x_grad) + infer_meta : + func : UnchangedInferMeta + param : [x] + kernel : + func : clip_grad + - backward_api : concat_grad forward : concat (Tensor[] x, Scalar axis) -> Tensor(out) args : (Tensor[] x, Tensor out_grad, Scalar axis = 0) From ac4a422d5a741093703e0c510a287f7ef8c5c274 Mon Sep 17 00:00:00 2001 From: Aurelius84 Date: Mon, 4 Apr 2022 20:54:16 +0800 Subject: [PATCH 74/93] [Eager]Fix tile API final_state and Backward bug (#41385) * [Eager]Fix tile API final_state bug * fix backward bug --- paddle/fluid/eager/backward.cc | 6 +++--- paddle/fluid/pybind/eager_utils.cc | 6 ++++++ python/paddle/tensor/manipulation.py | 5 +++++ 3 files changed, 14 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/eager/backward.cc b/paddle/fluid/eager/backward.cc index 3e86ad6f59b53..d5397e20e7d68 100644 --- a/paddle/fluid/eager/backward.cc +++ b/paddle/fluid/eager/backward.cc @@ -580,8 +580,9 @@ std::vector<paddle::experimental::Tensor> RunBackward( node_input_buffers_dict[grad_node] = std::make_unique<GradTensorHolder>(grad_node->InputMeta()); } - - if (grad_tensors.size() > 0) { + bool copy_from_grad_t = + grad_tensors.size() > 0 && grad_tensors[i].initialized(); + if (copy_from_grad_t) { PADDLE_ENFORCE( grad_tensors.size() == tensors.size(), paddle::platform::errors::Fatal( @@ -594,7 +595,6 @@ std::vector<paddle::experimental::Tensor> RunBackward( // Deep copy node_input_buffers_dict[grad_node]->CopyValueFromTensor( input_info.first, input_info.second, grad_tensors[i]); - } else { VLOG(6) << "Fill grad input tensor " << i << " with 1.0"; // Initialize tensor with 1.0 diff --git a/paddle/fluid/pybind/eager_utils.cc b/paddle/fluid/pybind/eager_utils.cc index a6047f36ad98f..ef1359ac04772 100644 --- a/paddle/fluid/pybind/eager_utils.cc +++ b/paddle/fluid/pybind/eager_utils.cc @@ -213,6 +213,9 @@ std::vector<paddle::experimental::Tensor> CastPyArg2VectorOfTensor( if (PyObject_IsInstance(item, reinterpret_cast<PyObject *>(p_tensor_type))) { result.emplace_back(reinterpret_cast<TensorObject *>(item)->tensor); + } else if (item == Py_None) { + // emplace empty Tensor for None + result.emplace_back(); } else { PADDLE_THROW(platform::errors::InvalidArgument( "argument (position %d) must be " @@ -229,6 +232,9 @@ std::vector<paddle::experimental::Tensor> CastPyArg2VectorOfTensor( if (PyObject_IsInstance(item, reinterpret_cast<PyObject *>(p_tensor_type))) { result.emplace_back(reinterpret_cast<TensorObject *>(item)->tensor); + } else if (item == Py_None) { + // emplace empty Tensor for None + result.emplace_back(); } else { PADDLE_THROW(platform::errors::InvalidArgument( "argument (position %d) must be " diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index 92fec23c6c769..f1e2938b205c7 100755 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -1751,6 +1751,11 @@ def tile(x, repeat_times, name=None): # [[1, 2, 3, 1, 2, 3]] """ if in_dygraph_mode(): + if isinstance(repeat_times, core.eager.Tensor): + assert (repeat_times.ndim == 1, + "Only support ndim == 1 while repeat_times is a Tensor.") + repeat_times = repeat_times.numpy().tolist() + return _C_ops.final_state_tile(x, repeat_times) if _in_legacy_dygraph(): From 1071bafc45d18feb99637bbd130b12fd2d786ee2 Mon Sep 17 00:00:00 2001 From: Zhong Hui Date: Mon, 4 Apr 2022 21:14:23 +0800 Subject: [PATCH 75/93] quick fix package. 
(#41339) --- python/setup.py.in | 1 + 1 file changed, 1 insertion(+) diff --git a/python/setup.py.in b/python/setup.py.in index 7f311feb4ee34..a1beab8c665ec 100755 --- a/python/setup.py.in +++ b/python/setup.py.in @@ -278,6 +278,7 @@ packages=['paddle', 'paddle.incubate.checkpoint', 'paddle.incubate.operators', 'paddle.incubate.tensor', + 'paddle.incubate.multiprocessing', 'paddle.incubate.nn', 'paddle.incubate.passes', 'paddle.distribution', From 19cb0d189f53e41e12829da360cd8e605d5c4758 Mon Sep 17 00:00:00 2001 From: zhaocaibei123 <48509226+zhaocaibei123@users.noreply.github.com> Date: Mon, 4 Apr 2022 21:29:18 +0800 Subject: [PATCH 76/93] Table refine: Pull/Push(TableContext) (#41320) * update name * update name * fix test * fix fleet bind * update name * update name * fix test * fix gpups wrapper * remove Push/Pull/Load/Save with context in client and wrapper base class * fix * fix * remove some interface * fix * remove * code style * recover * fix * remove code unused * fix * recover * fix Co-authored-by: esythan --- .../distributed/ps/service/brpc_ps_server.cc | 36 +++++- .../distributed/ps/service/ps_local_client.cc | 60 +++++++++- .../ps/table/common_dense_table.cc | 7 +- .../distributed/ps/table/common_dense_table.h | 22 ++-- .../distributed/ps/table/common_graph_table.h | 28 +++-- .../ps/table/common_sparse_table.h | 14 ++- .../fluid/distributed/ps/table/common_table.h | 57 --------- .../ps/table/memory_sparse_geo_table.cc | 24 ++++ .../ps/table/memory_sparse_geo_table.h | 32 ++--- .../ps/table/memory_sparse_table.cc | 18 ++- .../ps/table/memory_sparse_table.h | 58 ++++----- paddle/fluid/distributed/ps/table/table.h | 53 +++------ .../fluid/distributed/ps/table/tensor_table.h | 89 ++++---------- .../test/brpc_service_sparse_sgd_test.cc | 110 ++++++++++-------- .../distributed/test/dense_table_test.cc | 47 +++++++- .../distributed/test/memory_geo_table_test.cc | 37 +++++- .../test/memory_sparse_table_test.cc | 25 +++- python/paddle/distributed/ps/the_one_ps.py | 2 +- 18 files changed, 406 insertions(+), 313 deletions(-) diff --git a/paddle/fluid/distributed/ps/service/brpc_ps_server.cc b/paddle/fluid/distributed/ps/service/brpc_ps_server.cc index a1690cbb9353b..d22cca91f7816 100644 --- a/paddle/fluid/distributed/ps/service/brpc_ps_server.cc +++ b/paddle/fluid/distributed/ps/service/brpc_ps_server.cc @@ -244,7 +244,14 @@ int32_t BrpcPsService::PushDenseParam(Table *table, uint32_t num = *(const uint32_t *)data; const float *values = (const float *)(data + sizeof(uint32_t)); - if (table->PushDenseParam(values, num) != 0) { + TableContext table_context; + table_context.value_type = Dense; + table_context.push_context.values = values; + table_context.push_context.is_param = true; + table_context.num = num; + + // if (table->PushDenseParam(values, num) != 0) { + if (table->Push(table_context) != 0) { set_response_code(response, -1, "PushDenseParam failed"); } return 0; @@ -330,7 +337,15 @@ int32_t BrpcPsService::PushSparseParam(Table *table, const uint64_t *keys = (const uint64_t *)push_data.data(); const float *values = (const float *)(push_data.data() + sizeof(uint64_t) * num); - if (table->PushSparseParam(keys, values, num) != 0) { + + TableContext table_context; + table_context.value_type = Sparse; + table_context.push_context.keys = keys; + table_context.push_context.values = values; + table_context.push_context.is_param = true; + table_context.num = num; + // if (table->PushSparseParam(keys, values, num) != 0) { + if (table->Push(table_context) != 0) { set_response_code(response, 
-1, "PushSparseParam error"); } return 0; @@ -349,7 +364,14 @@ int32_t BrpcPsService::PullGeoParam(Table *table, std::vector values; std::vector ids; - table->PullGeoParam(trainer_id, &values, &ids); + + TableContext table_context; + table_context.value_type = Sparse; + table_context.pull_context.geo_pull_keys = &ids; + table_context.pull_context.geo_pull_values = &values; + table_context.trainer_id = trainer_id; + table->Pull(table_context); + // table->PullGeoParam(trainer_id, &values, &ids); uint32_t num = ids.size(); cntl->response_attachment().append((char *)(&num), sizeof(uint32_t)); @@ -625,7 +647,13 @@ int32_t BrpcPsService::PushGlobalStep(Table *table, const int64_t *values = (const int64_t *)(request.data().data() + sizeof(uint32_t)); auto trainer_id = request.client_id(); - if (table->PushDense(values, trainer_id) != 0) { + + TableContext context; + context.trainer_id = trainer_id; + context.push_context.push_steps = values; + + // if (table->PushDense(values, trainer_id) != 0) { + if (table->Push(context) != 0) { set_response_code(response, -1, "run_program failed"); } diff --git a/paddle/fluid/distributed/ps/service/ps_local_client.cc b/paddle/fluid/distributed/ps/service/ps_local_client.cc index 3e93f861d4e0e..bc024ed3175bc 100644 --- a/paddle/fluid/distributed/ps/service/ps_local_client.cc +++ b/paddle/fluid/distributed/ps/service/ps_local_client.cc @@ -104,7 +104,13 @@ ::std::future PsLocalClient::PullDense(Region* regions, std::vector region_buffer; region_buffer.resize(num_per_shard); - table_ptr->PullDense(region_buffer.data(), region_buffer.size()); + + TableContext table_context; + table_context.value_type = Dense; + table_context.pull_context.values = region_buffer.data(); + table_context.num = region_buffer.size(); + table_ptr->Pull(table_context); + // table_ptr->PullDense(region_buffer.data(), region_buffer.size()); size_t region_idx = 0; size_t region_data_idx = 0; @@ -154,6 +160,13 @@ ::std::future PsLocalClient::PushDenseParam(const Region* regions, offset += data_num; } + TableContext table_context; + table_context.value_type = Dense; + table_context.push_context.values = region_buffer.data(); + table_context.push_context.is_param = true; + table_context.num = region_buffer.size(); + + table_ptr->Push(table_context); // table_ptr->PushDenseParam(region_buffer.data(), region_buffer.size()); return done(); @@ -168,7 +181,13 @@ ::std::future PsLocalClient::PushDenseRawGradient( auto* table_ptr = GetTable(table_id); - table_ptr->PushDense(total_send_data, total_send_data_size); + TableContext table_context; + table_context.value_type = Dense; + table_context.push_context.values = total_send_data; + table_context.num = total_send_data_size; + // table_ptr->PushDense(total_send_data, total_send_data_size); + table_ptr->Push(table_context); + delete closure; return done(); } @@ -194,7 +213,12 @@ ::std::future PsLocalClient::PushDense(const Region* regions, offset += data_num; } - table_ptr->PushDense(region_buffer.data(), region_buffer.size()); + TableContext table_context; + table_context.value_type = Dense; + table_context.push_context.values = region_buffer.data(); + table_context.num = region_buffer.size(); + // table_ptr->PushDense(total_send_data, total_send_data_size); + table_ptr->Push(table_context); return done(); } @@ -241,7 +265,15 @@ ::std::future PsLocalClient::PullSparsePtr(char** select_values, //将key拆分到各shard请求,并记录原始对应value指针 auto* table_ptr = GetTable(table_id); - table_ptr->PullSparsePtr(select_values, keys, num); + TableContext table_context; + 
table_context.value_type = Sparse; + table_context.pull_context.keys = keys; + table_context.pull_context.ptr_values = select_values; + table_context.use_ptr = true; + table_context.num = num; + + // table_ptr->PullSparsePtr(select_values, keys, num); + table_ptr->Pull(table_context); return done(); } @@ -253,7 +285,15 @@ ::std::future<int32_t> PsLocalClient::PushSparseRawGradient( auto* accessor = GetTableAccessor(table_id); auto* table_ptr = GetTable(table_id); - table_ptr->PushSparse(keys, update_values, num); + TableContext table_context; + table_context.value_type = Sparse; + table_context.push_context.keys = keys; + table_context.push_context.ptr_values = update_values; + table_context.num = num; + table_context.use_ptr = true; + + // table_ptr->PushSparse(keys, update_values, num); + table_ptr->Push(table_context); delete closure; return done(); } @@ -265,7 +305,15 @@ ::std::future<int32_t> PsLocalClient::PushSparse(size_t table_id, auto* accessor = GetTableAccessor(table_id); auto* table_ptr = GetTable(table_id); - table_ptr->PushSparse(keys, update_values, num); + TableContext table_context; + table_context.value_type = Sparse; + table_context.push_context.keys = keys; + table_context.push_context.ptr_values = update_values; + table_context.num = num; + table_context.use_ptr = true; + + // table_ptr->PushSparse(keys, update_values, num); + table_ptr->Push(table_context); return done(); } } diff --git a/paddle/fluid/distributed/ps/table/common_dense_table.cc b/paddle/fluid/distributed/ps/table/common_dense_table.cc index 4242b65dea023..45208670f9d4c 100644 --- a/paddle/fluid/distributed/ps/table/common_dense_table.cc +++ b/paddle/fluid/distributed/ps/table/common_dense_table.cc @@ -139,8 +139,11 @@ int32_t CommonDenseTable::Pull(TableContext& context) { int32_t CommonDenseTable::Push(TableContext& context) { CHECK(context.value_type == Dense); if (context.push_context.values != nullptr) { - const float* values = context.push_context.values; - return PushDense(values, context.num); + if (!context.push_context.is_param) { + return PushDense(context.push_context.values, context.num); + } else { + return PushDenseParam(context.push_context.values, context.num); + } } return 0; } diff --git a/paddle/fluid/distributed/ps/table/common_dense_table.h b/paddle/fluid/distributed/ps/table/common_dense_table.h index 8e4ff1ecaf487..acda009d02402 100644 --- a/paddle/fluid/distributed/ps/table/common_dense_table.h +++ b/paddle/fluid/distributed/ps/table/common_dense_table.h @@ -30,21 +30,22 @@ namespace distributed { class DenseOptimizer; -class CommonDenseTable : public DenseTable { +class CommonDenseTable : public Table { public: CommonDenseTable() {} virtual ~CommonDenseTable() {} int32_t Initialize() override; int32_t InitializeShard() override { return 0; } - virtual void CreateInitializer(const std::string& attr, - const std::string& name); - virtual int32_t InitializeValue(); - virtual int32_t InitializeOptimizer(); - virtual int32_t Pull(TableContext& context); - virtual int32_t Push(TableContext& context); - int32_t PullDense(float* pull_values, size_t num) override; - int32_t PushDenseParam(const float* values, size_t num) override; - int32_t PushDense(const float* values, size_t num) override; + void CreateInitializer(const std::string& attr, const std::string& name); + int32_t InitializeValue(); + int32_t InitializeOptimizer(); + + int32_t Pull(TableContext& context) override; + int32_t Push(TableContext& context) override; + + int32_t PullDense(float* pull_values, size_t num); + int32_t 
PushDenseParam(const float* values, size_t num); + int32_t PushDense(const float* values, size_t num); int32_t Pour() override; int32_t SetGlobalLR(float* lr) override; @@ -54,6 +55,7 @@ class CommonDenseTable : public DenseTable { int32_t Flush() override { return 0; } int32_t Shrink(const std::string& param) override { return 0; } void Clear() override { return; } + void* GetShard(size_t shard_idx) override { return 0; } protected: int32_t _PushDense(const float* values, size_t num); diff --git a/paddle/fluid/distributed/ps/table/common_graph_table.h b/paddle/fluid/distributed/ps/table/common_graph_table.h index 035a3de3eba63..acc484e6098d4 100644 --- a/paddle/fluid/distributed/ps/table/common_graph_table.h +++ b/paddle/fluid/distributed/ps/table/common_graph_table.h @@ -404,7 +404,7 @@ class GraphSampler { }; #endif -class GraphTable : public SparseTable { +class GraphTable : public Table { public: GraphTable() { use_cache = false; @@ -415,6 +415,23 @@ class GraphTable : public SparseTable { rw_lock.reset(new pthread_rwlock_t()); } virtual ~GraphTable(); + + virtual void *GetShard(size_t shard_idx) { return 0; } + + static int32_t sparse_local_shard_num(uint32_t shard_num, + uint32_t server_num) { + if (shard_num % server_num == 0) { + return shard_num / server_num; + } + size_t local_shard_num = shard_num / server_num + 1; + return local_shard_num; + } + + static size_t get_sparse_shard(uint32_t shard_num, uint32_t server_num, + uint64_t key) { + return (key % shard_num) / sparse_local_shard_num(shard_num, server_num); + } + virtual int32_t pull_graph_list(int start, int size, std::unique_ptr &buffer, int &actual_size, bool need_feature, @@ -452,15 +469,6 @@ class GraphTable : public SparseTable { virtual int32_t Pull(TableContext &context) { return 0; } virtual int32_t Push(TableContext &context) { return 0; } - virtual int32_t PullSparse(float *values, const PullSparseValue &pull_value) { - return 0; - } - - virtual int32_t PushSparse(const uint64_t *keys, const float *values, - size_t num) { - return 0; - } - virtual int32_t clear_nodes(); virtual void Clear() {} virtual int32_t Flush() { return 0; } diff --git a/paddle/fluid/distributed/ps/table/common_sparse_table.h b/paddle/fluid/distributed/ps/table/common_sparse_table.h index f6deaf0a82b13..2673e8dfae3c6 100644 --- a/paddle/fluid/distributed/ps/table/common_sparse_table.h +++ b/paddle/fluid/distributed/ps/table/common_sparse_table.h @@ -108,15 +108,16 @@ struct Meta { } }; -class CommonSparseTable : public SparseTable { +class CommonSparseTable : public Table { public: CommonSparseTable() { rwlock_.reset(new phi::RWLock); } virtual ~CommonSparseTable() {} // unused method begin - virtual int32_t PullDense(float* pull_values, size_t num) { return 0; } - virtual int32_t PushDenseParam(const float* values, size_t num) { return 0; } - virtual int32_t PushDense(const float* values, size_t num) { return 0; } + // virtual int32_t PullDense(float* pull_values, size_t num) { return 0; } + // virtual int32_t PushDenseParam(const float* values, size_t num) { return + // 0; } + // virtual int32_t PushDense(const float* values, size_t num) { return 0; } // unused method end virtual int32_t Pull(TableContext& context); @@ -163,14 +164,15 @@ class CommonSparseTable : public SparseTable { // only for sparse geo table virtual int32_t PushSparseParam(const uint64_t* keys, const float* values, size_t num); - - virtual int32_t SetGlobalLR(float* lr) override; + virtual int32_t SetGlobalLR(float* lr); virtual int32_t Pour(); virtual int32_t 
Flush(); virtual int32_t Shrink(const std::string& param); virtual void Clear(); + virtual void* GetShard(size_t shard_idx) { return 0; } + protected: virtual int32_t _PushSparse(const uint64_t* keys, const float* values, size_t num); diff --git a/paddle/fluid/distributed/ps/table/common_table.h b/paddle/fluid/distributed/ps/table/common_table.h index f5e263e8e7189..f69d9ccbf1453 100644 --- a/paddle/fluid/distributed/ps/table/common_table.h +++ b/paddle/fluid/distributed/ps/table/common_table.h @@ -66,50 +66,6 @@ struct ReservoirValue { } }; -class SparseTable : public Table { - public: - SparseTable() {} - virtual ~SparseTable() {} - - virtual void *GetShard(size_t shard_idx) { return 0; } - - int32_t PullDense(float *values, size_t num) override { return 0; } - - int32_t PushDense(const float *values, size_t num) override { return 0; } - - static int32_t sparse_local_shard_num(uint32_t shard_num, - uint32_t server_num) { - if (shard_num % server_num == 0) { - return shard_num / server_num; - } - size_t local_shard_num = shard_num / server_num + 1; - return local_shard_num; - } - - static size_t get_sparse_shard(uint32_t shard_num, uint32_t server_num, - uint64_t key) { - return (key % shard_num) / sparse_local_shard_num(shard_num, server_num); - } -}; - -class DenseTable : public Table { - public: - DenseTable() {} - virtual ~DenseTable() {} - - virtual void *GetShard(size_t shard_idx) { return 0; } - int32_t PullSparse(float *values, - const PullSparseValue &pull_value) override { - return 0; - } - int32_t PushSparse(const uint64_t *keys, const float *values, - size_t num) override { - return 0; - } - int32_t PushDenseParam(const float *values, size_t num) override { return 0; } - int32_t Shrink(const std::string ¶m) override { return 0; } -}; - class BarrierTable : public Table { public: BarrierTable() {} @@ -120,19 +76,6 @@ class BarrierTable : public Table { virtual int32_t Pull(TableContext &context) { return 0; } virtual int32_t Push(TableContext &context) { return 0; } - int32_t PullDense(float *values, size_t num) override { return 0; } - - int32_t PushDense(const float *values, size_t num) override { return 0; } - - int32_t PullSparse(float *values, - const PullSparseValue &pull_value) override { - return 0; - } - int32_t PushSparse(const uint64_t *keys, const float *values, - size_t num) override { - return 0; - } - int32_t PushDenseParam(const float *values, size_t num) override { return 0; } int32_t Shrink(const std::string ¶m) override { return 0; } virtual void Clear() {} virtual int32_t Flush() { return 0; } diff --git a/paddle/fluid/distributed/ps/table/memory_sparse_geo_table.cc b/paddle/fluid/distributed/ps/table/memory_sparse_geo_table.cc index 979e1c482547c..1567d31d0f3ee 100644 --- a/paddle/fluid/distributed/ps/table/memory_sparse_geo_table.cc +++ b/paddle/fluid/distributed/ps/table/memory_sparse_geo_table.cc @@ -17,6 +17,29 @@ namespace paddle { namespace distributed { +int32_t MemorySparseGeoTable::Pull(TableContext& context) { + CHECK(context.value_type == Sparse); + if (context.pull_context.geo_pull_keys != nullptr) { + return PullGeoParam(context.trainer_id, + context.pull_context.geo_pull_values, + context.pull_context.geo_pull_keys); + } else { + return PullSparse(context.pull_context.values, + context.pull_context.pull_value); + } +} + +int32_t MemorySparseGeoTable::Push(TableContext& context) { + CHECK(context.value_type == Sparse); + if (!context.push_context.is_param) { + return PushSparse(context.push_context.keys, context.push_context.values, + 
context.num); + } else { + return PushSparseParam(context.push_context.keys, + context.push_context.values, context.num); + } +} + int32_t MemorySparseGeoTable::PushSparseParam(const uint64_t* keys, const float* values, size_t num) { VLOG(5) << "DEBUG MemorySparseGeoTable::PushSparseParam begin " @@ -117,6 +140,7 @@ int32_t MemorySparseGeoTable::Initialize() { return 0; } +// hash different from MemorySparseTable int32_t MemorySparseGeoTable::PullSparse(float* pull_values, const PullSparseValue& pull_value) { auto shard_num = _task_pool_size; diff --git a/paddle/fluid/distributed/ps/table/memory_sparse_geo_table.h b/paddle/fluid/distributed/ps/table/memory_sparse_geo_table.h index 1a74df32db8e7..60ba5d9602e44 100644 --- a/paddle/fluid/distributed/ps/table/memory_sparse_geo_table.h +++ b/paddle/fluid/distributed/ps/table/memory_sparse_geo_table.h @@ -34,40 +34,44 @@ namespace distributed { class GeoRecorder; -class MemorySparseGeoTable : public SparseTable { +class MemorySparseGeoTable : public Table { public: typedef SparseTableShard shard_type; MemorySparseGeoTable() { _geo_recorder = nullptr; } virtual ~MemorySparseGeoTable() {} - virtual int32_t Initialize(); - virtual int32_t InitializeShard() { return 0; } - virtual int32_t Load(const std::string& path, const std::string& param) { + int32_t Initialize() override; + int32_t InitializeShard() override { return 0; } + int32_t Load(const std::string& path, const std::string& param) override { return 0; } - virtual int32_t Save(const std::string& path, const std::string& param) { + int32_t Save(const std::string& path, const std::string& param) override { return 0; } - virtual int32_t Pull(TableContext& context) { return 0; } - virtual int32_t Push(TableContext& context) { return 0; } - virtual int32_t Flush() { return 0; } - virtual int32_t Shrink(const std::string& param) { return 0; } - virtual void Clear() { return; } - virtual int32_t PullSparse(float* values, const PullSparseValue& pull_value); + int32_t Pull(TableContext& context) override; + int32_t Push(TableContext& context) override; + int32_t Flush() override { return 0; } + int32_t Shrink(const std::string& param) override { return 0; } + void Clear() override { return; } + + int32_t PullSparse(float* values, const PullSparseValue& pull_value); int32_t PushSparseParam(const uint64_t* keys, const float* values, size_t num); - // TODO(zhaocaibei123): change to pull_sparse, and rename pull_sparse + int32_t PullGeoParam(const uint32_t trainer_id, std::vector* values, std::vector* keys); - int32_t PushSparse(const uint64_t* keys, const float* values, - size_t num) override; + int32_t PushSparse(const uint64_t* keys, const float* values, size_t num); int32_t _PushSparse(const uint64_t* keys, const float* values, size_t num); // int32_t _pull_sparse(float* pull_values, const PullSparseValue& // pull_value); + void* GetShard(size_t shard_idx) override { + return &_local_shards[shard_idx]; + } + private: std::shared_ptr _geo_recorder; const int _task_pool_size = 10; diff --git a/paddle/fluid/distributed/ps/table/memory_sparse_table.cc b/paddle/fluid/distributed/ps/table/memory_sparse_table.cc index b4b2263ed77bf..e6c52e0b9b0c8 100644 --- a/paddle/fluid/distributed/ps/table/memory_sparse_table.cc +++ b/paddle/fluid/distributed/ps/table/memory_sparse_table.cc @@ -47,7 +47,7 @@ int32_t MemorySparseTable::Initialize() { int32_t MemorySparseTable::InitializeValue() { _sparse_table_shard_num = static_cast(_config.shard_num()); _avg_local_shard_num = - 
SparseTable::sparse_local_shard_num(_sparse_table_shard_num, _shard_num); + sparse_local_shard_num(_sparse_table_shard_num, _shard_num); _real_local_shard_num = _avg_local_shard_num; if (_real_local_shard_num * (_shard_idx + 1) > _sparse_table_shard_num) { _real_local_shard_num = @@ -405,9 +405,13 @@ int32_t MemorySparseTable::Pull(TableContext& context) { int32_t MemorySparseTable::Push(TableContext& context) { CHECK(context.value_type == Sparse); - - const uint64_t* keys = context.push_context.keys; - return PushSparse(keys, context.push_context.values, context.num); + if (!context.use_ptr) { + return PushSparse(context.push_context.keys, context.push_context.values, + context.num); + } else { + return PushSparse(context.push_context.keys, + context.push_context.ptr_values, context.num); + } } int32_t MemorySparseTable::PullSparse(float* pull_values, @@ -610,12 +614,6 @@ int32_t MemorySparseTable::PushSparse(const uint64_t* keys, const float* values, int32_t MemorySparseTable::PushSparse(const uint64_t* keys, const float** values, size_t num) { - _PushSparse(keys, values, num); - return 0; -} - -int32_t MemorySparseTable::_PushSparse(const uint64_t* keys, - const float** values, size_t num) { std::vector<std::future<int>> tasks(_real_local_shard_num); std::vector<std::vector<std::pair<uint64_t, int>>> task_keys( _real_local_shard_num); diff --git a/paddle/fluid/distributed/ps/table/memory_sparse_table.h b/paddle/fluid/distributed/ps/table/memory_sparse_table.h index a4af4caa472d7..87a73bd22fa2f 100644 --- a/paddle/fluid/distributed/ps/table/memory_sparse_table.h +++ b/paddle/fluid/distributed/ps/table/memory_sparse_table.h @@ -34,28 +34,37 @@ namespace paddle { namespace distributed { -class MemorySparseTable : public SparseTable { +class MemorySparseTable : public Table { public: typedef SparseTableShard<uint64_t, FixedFeatureValue> shard_type; MemorySparseTable() {} virtual ~MemorySparseTable() {} - // unused method begin - virtual int32_t PullDense(float* pull_values, size_t num) { return 0; } - virtual int32_t PushDenseParam(const float* values, size_t num) { return 0; } - virtual int32_t PushDense(const float* values, size_t num) { return 0; } // unused method end + static int32_t sparse_local_shard_num(uint32_t shard_num, + uint32_t server_num) { + if (shard_num % server_num == 0) { + return shard_num / server_num; + } + size_t local_shard_num = shard_num / server_num + 1; + return local_shard_num; + } - virtual int32_t Pull(TableContext& context); - virtual int32_t Push(TableContext& context); + static size_t get_sparse_shard(uint32_t shard_num, uint32_t server_num, + uint64_t key) { + return (key % shard_num) / sparse_local_shard_num(shard_num, server_num); + } - virtual int32_t Initialize(); - virtual int32_t InitializeShard() { return 0; } - virtual int32_t InitializeValue(); + int32_t Pull(TableContext& context) override; + int32_t Push(TableContext& context) override; - virtual int32_t Load(const std::string& path, const std::string& param); + int32_t Initialize() override; + int32_t InitializeShard() override { return 0; } + int32_t InitializeValue(); - virtual int32_t Save(const std::string& path, const std::string& param); + int32_t Load(const std::string& path, const std::string& param) override; + + int32_t Save(const std::string& path, const std::string& param) override; int32_t LoadLocalFS(const std::string& path, const std::string& param); int32_t SaveLocalFS(const std::string& path, const std::string& param, @@ -64,25 +73,22 @@ class MemorySparseTable : public SparseTable { int64_t LocalSize(); int64_t LocalMFSize(); - virtual std::pair<int64_t, int64_t> 
PrintTableStat(); - virtual int32_t PullSparse(float* values, const PullSparseValue& pull_value); + std::pair<int64_t, int64_t> PrintTableStat() override; + int32_t PullSparse(float* values, const PullSparseValue& pull_value); - virtual int32_t PullSparsePtr(char** pull_values, const uint64_t* keys, - size_t num); + int32_t PullSparsePtr(char** pull_values, const uint64_t* keys, size_t num); - virtual int32_t PushSparse(const uint64_t* keys, const float* values, - size_t num); + int32_t PushSparse(const uint64_t* keys, const float* values, size_t num); - virtual int32_t PushSparse(const uint64_t* keys, const float** values, - size_t num); + int32_t PushSparse(const uint64_t* keys, const float** values, size_t num); - virtual int32_t Flush(); - virtual int32_t Shrink(const std::string& param); - virtual void Clear(); + int32_t Flush() override; + int32_t Shrink(const std::string& param) override; + void Clear() override; - protected: - virtual int32_t _PushSparse(const uint64_t* keys, const float** values, - size_t num); + void* GetShard(size_t shard_idx) override { + return &_local_shards[shard_idx]; + } protected: const int _task_pool_size = 24; diff --git a/paddle/fluid/distributed/ps/table/table.h b/paddle/fluid/distributed/ps/table/table.h index f55c30b774059..c515e03e3fa48 100644 --- a/paddle/fluid/distributed/ps/table/table.h +++ b/paddle/fluid/distributed/ps/table/table.h @@ -35,25 +35,30 @@ namespace distributed { enum ValueType { Sparse = 0, Dense = 1 }; -struct PullContext { - const uint64_t *keys; +struct TablePullContext { + const uint64_t *keys = nullptr; PullSparseValue pull_value; - float *values; - char **ptr_values; + float *values = nullptr; + char **ptr_values = nullptr; + std::vector<uint64_t> *geo_pull_keys = nullptr; // for GEO + std::vector<float> *geo_pull_values = nullptr; // for GEO }; struct TablePushContext { - const uint64_t *keys; - const float *values; - const float **ptr_values; + const uint64_t *keys = nullptr; + const float *values = nullptr; + const float **ptr_values = nullptr; + const int64_t *push_steps = nullptr; // for global step + bool is_param = false; // true: push param, false: push gradient }; struct TableContext { ValueType value_type; - PullContext pull_context; + TablePullContext pull_context; TablePushContext push_context; size_t num; bool use_ptr = false; + uint32_t trainer_id; // for GEO and global step }; class Table { @@ -65,38 +70,6 @@ class Table { virtual int32_t Pull(TableContext &context) = 0; virtual int32_t Push(TableContext &context) = 0; - virtual int32_t PullDense(float *values, size_t num) = 0; - virtual int32_t PushDense(const float *values, size_t num) = 0; - // for push global_step - virtual int32_t PushDense(const int64_t *values, const int32_t trainer_id) { - return 0; - } - virtual int32_t PushDenseParam(const float *values, size_t num) { return 0; } - - virtual int32_t PullSparsePtr(char **pull_values, const uint64_t *keys, - size_t num) { - VLOG(0) << "NOT IMPLEMENT"; - return 0; - } - virtual int32_t PullSparse(float *values, - const PullSparseValue &pull_value) = 0; - virtual int32_t PushSparse(const uint64_t *keys, const float *values, - size_t num) = 0; - virtual int32_t PushSparse(const uint64_t *keys, const float **values, - size_t num) { - return 0; - } - virtual int32_t PushSparseParam(const uint64_t *keys, const float *values, - size_t num) { - return 0; - } - - // only for sparse geo table - virtual int32_t PullGeoParam(const uint32_t trainer_id, - std::vector<float> *values, - std::vector<uint64_t> *keys) { - return 0; - } // only for barrier virtual int32_t 
Barrier(const uint32_t trainer_id, diff --git a/paddle/fluid/distributed/ps/table/tensor_table.h b/paddle/fluid/distributed/ps/table/tensor_table.h index 175aa194fb80f..7bb236d02c985 100644 --- a/paddle/fluid/distributed/ps/table/tensor_table.h +++ b/paddle/fluid/distributed/ps/table/tensor_table.h @@ -50,43 +50,28 @@ class TensorTable : public Table { TensorTable() {} virtual ~TensorTable() {} - virtual int32_t Pull(TableContext &context) { return 0; } - virtual int32_t Push(TableContext &context) { return 0; } - int32_t PullDense(float *values, size_t num) override { return 0; } + int32_t Pull(TableContext &context) override { return 0; } + int32_t Push(TableContext &context) override { return 0; } - int32_t PushDense(const float *values, size_t num) override { return 0; } - - int32_t PullSparse(float *values, - const PullSparseValue &pull_value) override { - return 0; - } - int32_t PushSparse(const uint64_t *keys, const float *values, - size_t num) override { - return 0; - } int32_t Shrink(const std::string ¶m) override { return 0; } - virtual void *GetShard(size_t shard_idx) { return 0; } + void *GetShard(size_t shard_idx) override { return 0; } - virtual int32_t InitializeShard() { return 0; } + int32_t InitializeShard() override { return 0; } - virtual int32_t Flush() { return 0; } + int32_t Flush() override { return 0; } - virtual int32_t Load(const std::string &path, const std::string ¶m) { + int32_t Load(const std::string &path, const std::string ¶m) override { return 0; } - virtual int32_t Save(const std::string &path, const std::string ¶m) { + int32_t Save(const std::string &path, const std::string ¶m) override { return 0; } - virtual void Clear() {} + void Clear() override {} int32_t Initialize() override { return 0; } - int32_t PushDense(const int64_t *values, const int32_t trainer_id) override { - return 0; - } - int32_t SetProgramEnv( framework::Scope *scope, platform::Place place, const std::vector *sub_program) override { @@ -111,45 +96,28 @@ class DenseTensorTable : public TensorTable { DenseTensorTable() {} virtual ~DenseTensorTable() {} - int32_t PullSparse(float *values, - const PullSparseValue &pull_value) override { - return 0; - } - int32_t PushSparse(const uint64_t *keys, const float *values, - size_t num) override { - return 0; - } int32_t Shrink(const std::string ¶m) override { return 0; } - virtual void *GetShard(size_t shard_idx) { return 0; } + void *GetShard(size_t shard_idx) override { return 0; } - virtual int32_t InitializeShard() { return 0; } + int32_t InitializeShard() override { return 0; } - virtual int32_t Flush() { return 0; } + int32_t Flush() override { return 0; } - virtual void Clear() {} + void Clear() override {} // Todo: Support program Load & Save - virtual int32_t Load(const std::string &path, const std::string ¶m) { + int32_t Load(const std::string &path, const std::string ¶m) override { return 0; } - virtual int32_t Save(const std::string &path, const std::string ¶m) { + int32_t Save(const std::string &path, const std::string ¶m) override { return 0; } - // Todo: Support pull dense - int32_t PullDense(float *values, size_t num) override { return 0; } - /*----------------------------------------------------------------------*/ int32_t Initialize() override { return 0; } - int32_t PushDense(const float *values, size_t num) override { return 0; } - - int32_t PushDense(const int64_t *values, const int32_t trainer_id) { - return 0; - } - protected: virtual int32_t _RunProgram(const float *values, size_t num, const uint32_t trainer_id) { @@ 
-167,33 +135,23 @@ class GlobalStepTable : public DenseTensorTable { GlobalStepTable() {} virtual ~GlobalStepTable() {} - int32_t PullSparse(float *values, - const PullSparseValue &pull_value) override { - return 0; - } - int32_t PushSparse(const uint64_t *keys, const float *values, - size_t num) override { - return 0; - } int32_t Shrink(const std::string ¶m) override { return 0; } - virtual void *GetShard(size_t shard_idx) { return 0; } + void *GetShard(size_t shard_idx) override { return 0; } - virtual int32_t InitializeShard() { return 0; } + int32_t InitializeShard() override { return 0; } - virtual int32_t Flush() { return 0; } + int32_t Flush() override { return 0; } - virtual void Clear() {} + void Clear() override {} - virtual int32_t Load(const std::string &path, const std::string ¶m) { + int32_t Load(const std::string &path, const std::string ¶m) override { return 0; } - virtual int32_t Save(const std::string &path, const std::string ¶m) { + int32_t Save(const std::string &path, const std::string ¶m) override { return 0; } - int32_t PullDense(float *values, size_t num) override { return 0; } - /*----------------------------------------------------------------------*/ int32_t Initialize() override { @@ -235,12 +193,13 @@ class GlobalStepTable : public DenseTensorTable { decay_counters_[i] = 0; } } + return 0; } - int32_t PushDense(const float *values, size_t num) override { return 0; } + // int32_t PushDense(const float *values, size_t num) override { return 0; } - int32_t PushDense(const int64_t *values, const int32_t trainer_id) { - return _RunProgram(values, trainer_id); + virtual int32_t Push(TableContext context) { + return _RunProgram(context.push_context.push_steps, context.trainer_id); } int32_t SetTableMap(std::unordered_map> diff --git a/paddle/fluid/distributed/test/brpc_service_sparse_sgd_test.cc b/paddle/fluid/distributed/test/brpc_service_sparse_sgd_test.cc index f7d287af84472..29195d9985728 100644 --- a/paddle/fluid/distributed/test/brpc_service_sparse_sgd_test.cc +++ b/paddle/fluid/distributed/test/brpc_service_sparse_sgd_test.cc @@ -49,6 +49,8 @@ namespace distributed = paddle::distributed; void CreateVarsOnScope(framework::Scope* scope, platform::CPUPlace* place) { auto x_var = scope->Var("x"); x_var->GetMutable(); + auto x_g_var = scope->Var("x@GRAD"); + x_g_var->GetMutable(); } void InitTensorsOnClient(framework::Scope* scope, platform::CPUPlace* place, @@ -59,34 +61,49 @@ void InitTensorsOnClient(framework::Scope* scope, platform::CPUPlace* place, float* x_ptr = x_var->mutable_data(framework::DDim({1, rows_numel}), *place); for (int64_t i = 0; i < rows_numel; ++i) x_ptr[i] = 1.0; + + auto g_size = rows_numel + + 30; // hard code here: key_num * (fea_dim + 3), show/clk/slot + auto x_g_var = scope->Var("x@GRAD")->GetMutable(); + float* x_g_ptr = + x_g_var->mutable_data(framework::DDim({1, g_size}), *place); + for (int64_t i = 0; i < g_size; ++i) x_g_ptr[i] = 1.0; } void GetDownpourSparseTableProto( ::paddle::distributed::TableParameter* sparse_table_proto) { sparse_table_proto->set_table_id(0); - sparse_table_proto->set_table_class("CommonSparseTable"); - sparse_table_proto->set_shard_num(256); - sparse_table_proto->set_type(::paddle::distributed::PS_SPARSE_TABLE); - ::paddle::distributed::TableAccessorParameter* accessor_proto = + sparse_table_proto->set_table_class("MemorySparseTable"); + sparse_table_proto->set_shard_num(10); + ::paddle::distributed::TableAccessorParameter* accessor_config = sparse_table_proto->mutable_accessor(); - 
::paddle::distributed::CommonAccessorParameter* common_proto = - sparse_table_proto->mutable_common(); - - accessor_proto->set_accessor_class("CommMergeAccessor"); - accessor_proto->set_fea_dim(0); - accessor_proto->set_embedx_dim(10); - - common_proto->set_name("sgd"); - common_proto->set_table_name("MergedDense"); - common_proto->set_trainer_num(1); - common_proto->set_sync(false); - common_proto->set_entry("none"); - common_proto->add_params("Param"); - common_proto->add_dims(10); - common_proto->add_initializers("uniform_random&0&-1.0&1.0"); - common_proto->add_params("LearningRate"); - common_proto->add_dims(1); - common_proto->add_initializers("fill_constant&1.0"); + + accessor_config->set_accessor_class("SparseAccessor"); + accessor_config->set_fea_dim(10); + accessor_config->set_embedx_dim(9); + accessor_config->set_embedx_threshold(0); + accessor_config->mutable_ctr_accessor_param()->set_nonclk_coeff(0.2); + accessor_config->mutable_ctr_accessor_param()->set_click_coeff(1); + accessor_config->mutable_ctr_accessor_param()->set_base_threshold(0.5); + accessor_config->mutable_ctr_accessor_param()->set_delta_threshold(0.2); + accessor_config->mutable_ctr_accessor_param()->set_delta_keep_days(16); + accessor_config->mutable_ctr_accessor_param()->set_show_click_decay_rate( + 0.99); + + accessor_config->mutable_embed_sgd_param()->set_name("SparseNaiveSGDRule"); + auto* naive_param = + accessor_config->mutable_embed_sgd_param()->mutable_naive(); + naive_param->set_learning_rate(1.0); + naive_param->set_initial_range(0.3); + naive_param->add_weight_bounds(-10.0); + naive_param->add_weight_bounds(10.0); + + accessor_config->mutable_embedx_sgd_param()->set_name("SparseNaiveSGDRule"); + naive_param = accessor_config->mutable_embedx_sgd_param()->mutable_naive(); + naive_param->set_learning_rate(1.0); + naive_param->set_initial_range(0.3); + naive_param->add_weight_bounds(-10.0); + naive_param->add_weight_bounds(10.0); } ::paddle::distributed::PSParameter GetServerProto() { @@ -217,42 +234,42 @@ void RunBrpcPushSparse() { auto pull_status = worker_ptr_->PullSparse( fea_value_ptr.data(), 0, fea_keys.data(), fea_keys.size(), true); pull_status.wait(); - for (size_t idx = 0; idx < tensor->numel(); ++idx) { - fea_values.data()[idx] *= 2.0; - } - - /*-----------------------Test Push Param----------------------------------*/ - LOG(INFO) << "Run push_sparse_param"; - paddle::distributed::DownpourBrpcClosure* closure_push_param = + /*-----------------------Test Push Grad----------------------------------*/ + // first to expand embedx, init + paddle::distributed::DownpourBrpcClosure* closure_push_grad = new paddle::distributed::DownpourBrpcClosure(1, [&](void* done) { int ret = 0; auto* closure = (paddle::distributed::DownpourBrpcClosure*)done; for (size_t i = 0; i < 1; ++i) { if (closure->check_response( - i, paddle::distributed::PS_PUSH_SPARSE_PARAM) != 0) { + i, paddle::distributed::PS_PUSH_SPARSE_TABLE) != 0) { ret = -1; break; } } closure->set_promise_value(ret); }); - auto push_status = worker_ptr_->PushSparseParam( - 0, fea_keys.data(), (const float**)fea_value_ptr.data(), fea_keys.size(), - closure_push_param); - push_status.wait(); - auto pull_param_status = worker_ptr_->PullSparse( - fea_temp_value_ptr.data(), 0, fea_keys.data(), fea_keys.size(), true); - pull_param_status.wait(); + framework::Variable* g_var = client_scope.FindVar("x@GRAD"); + framework::LoDTensor* g_tensor = g_var->GetMutable(); - for (size_t idx = 0; idx < tensor->numel(); ++idx) { - EXPECT_FLOAT_EQ(fea_temp_values[idx], 
fea_values[idx]); + LOG(INFO) << "Run push_sparse_grad"; + std::vector push_g_vec; + for (auto i = 0; i < static_cast(fea_keys.size()); ++i) { + push_g_vec.push_back(g_tensor->data() + i * 13); } + auto push_grad_status = worker_ptr_->PushSparseRawGradient( + 0, fea_keys.data(), (const float**)push_g_vec.data(), fea_keys.size(), + closure_push_grad); + push_grad_status.wait(); - /*-----------------------Test Push Grad----------------------------------*/ + // pull + pull_status = worker_ptr_->PullSparse(fea_value_ptr.data(), 0, + fea_keys.data(), fea_keys.size(), true); + pull_status.wait(); - paddle::distributed::DownpourBrpcClosure* closure_push_grad = + paddle::distributed::DownpourBrpcClosure* closure_push_grad1 = new paddle::distributed::DownpourBrpcClosure(1, [&](void* done) { int ret = 0; auto* closure = (paddle::distributed::DownpourBrpcClosure*)done; @@ -266,16 +283,13 @@ void RunBrpcPushSparse() { closure->set_promise_value(ret); }); - LOG(INFO) << "Run pull_sparse_grad"; - std::vector push_g_vec; - for (auto i = 0; i < static_cast(fea_keys.size()); ++i) { - push_g_vec.push_back(tensor->data() + i * 10); - } - auto push_grad_status = worker_ptr_->PushSparseRawGradient( + // push again, embedx update this time + push_grad_status = worker_ptr_->PushSparseRawGradient( 0, fea_keys.data(), (const float**)push_g_vec.data(), fea_keys.size(), - closure_push_grad); + closure_push_grad1); push_grad_status.wait(); + // pull update auto pull_update_status = worker_ptr_->PullSparse( fea_temp_value_ptr.data(), 0, fea_keys.data(), fea_keys.size(), true); pull_update_status.wait(); diff --git a/paddle/fluid/distributed/test/dense_table_test.cc b/paddle/fluid/distributed/test/dense_table_test.cc index 49346c2898fc6..40992b1b53b89 100644 --- a/paddle/fluid/distributed/test/dense_table_test.cc +++ b/paddle/fluid/distributed/test/dense_table_test.cc @@ -69,7 +69,13 @@ TEST(CommonDenseTable, Adam) { // pull parameters for create and check std::vector init_values; init_values.resize(fea_dim); - table->PullDense(init_values.data(), fea_dim); + + TableContext table_context1; + table_context1.value_type = Dense; + table_context1.pull_context.values = init_values.data(); + table_context1.num = fea_dim; + table->Pull(table_context1); + // table->PullDense(init_values.data(), fea_dim); // push gradient std::vector> trainer_gradient_values; @@ -85,12 +91,24 @@ TEST(CommonDenseTable, Adam) { // for adam for (int i = 0; i < trainers; i++) { auto &push_values = trainer_gradient_values[i]; - table->PushDense(push_values.data(), push_values.size()); + + TableContext table_context; + table_context.value_type = Dense; + table_context.push_context.values = push_values.data(); + table_context.num = push_values.size(); + table->Push(table_context); + // table->PushDense(push_values.data(), push_values.size()); } std::vector pull_values; pull_values.resize(fea_dim); - table->PullDense(pull_values.data(), fea_dim); + + TableContext table_context; + table_context.value_type = Dense; + table_context.pull_context.values = pull_values.data(); + table_context.num = fea_dim; + table->Pull(table_context); + // table->PullDense(pull_values.data(), fea_dim); float mom_rate = 0.99; float decay_rate = 0.9999; @@ -150,7 +168,13 @@ TEST(CommonDenseTable, SGD) { // pull parameters for create and check std::vector init_values; init_values.resize(fea_dim); - table->PullDense(init_values.data(), fea_dim); + + TableContext table_context1; + table_context1.value_type = Dense; + table_context1.pull_context.values = init_values.data(); + 
table_context1.num = fea_dim; + table->Pull(table_context1); + // table->PullDense(init_values.data(), fea_dim); std::vector total_gradients; total_gradients.resize(fea_dim); @@ -173,7 +197,12 @@ TEST(CommonDenseTable, SGD) { for (int i = 0; i < trainers; i++) { auto &push_values = trainer_gradient_values[i]; auto task = [table, &push_values] { - table->PushDense(push_values.data(), push_values.size()); + TableContext table_context; + table_context.value_type = Dense; + table_context.push_context.values = push_values.data(); + table_context.num = push_values.size(); + table->Push(table_context); + // table->PushDense(push_values.data(), push_values.size()); }; task_status.push_back(pool_->enqueue(std::move(task))); } @@ -183,7 +212,13 @@ TEST(CommonDenseTable, SGD) { std::vector pull_values; pull_values.resize(fea_dim); - table->PullDense(pull_values.data(), fea_dim); + + TableContext table_context; + table_context.value_type = Dense; + table_context.pull_context.values = pull_values.data(); + table_context.num = fea_dim; + table->Pull(table_context); + // table->PullDense(pull_values.data(), fea_dim); for (int j = 0; j < fea_dim; j++) { auto update_val = init_values[j] - 1.0 * total_gradients[j]; ASSERT_TRUE(abs(update_val - pull_values[j]) < 1e-5); diff --git a/paddle/fluid/distributed/test/memory_geo_table_test.cc b/paddle/fluid/distributed/test/memory_geo_table_test.cc index 965f67992d000..ca3b51fade177 100644 --- a/paddle/fluid/distributed/test/memory_geo_table_test.cc +++ b/paddle/fluid/distributed/test/memory_geo_table_test.cc @@ -58,12 +58,26 @@ TEST(MemorySparseGeoTable, SSUM) { for (size_t i = 0; i < init_keys.size() * emb_dim; i++) { init_values.push_back(0.0); } - table->PushSparseParam(init_keys.data(), init_values.data(), - init_keys.size()); + + TableContext table_context1; + table_context1.value_type = Sparse; + table_context1.push_context.keys = init_keys.data(); + table_context1.push_context.values = init_values.data(); + table_context1.push_context.is_param = true; + table_context1.num = init_keys.size(); + + table->Push(table_context1); + // table->PushSparseParam(init_keys.data(), init_values.data(), + // init_keys.size()); std::vector pull_values(init_values.size()); auto value = PullSparseValue(init_keys, init_fres, emb_dim); - table->PullSparse(pull_values.data(), value); + TableContext table_context; + table_context.value_type = Sparse; + table_context.pull_context.pull_value = value; + table_context.pull_context.values = pull_values.data(); + table->Pull(table_context); + // table->PullSparse(pull_values.data(), value); for (size_t i = 0; i < init_keys.size() * emb_dim; i++) { ASSERT_TRUE(abs(pull_values[i] - init_values[i]) < 1e-5); @@ -93,7 +107,14 @@ TEST(MemorySparseGeoTable, SSUM) { auto &push_keys = trainer_keys[i]; auto &push_values = trainer_values[i]; auto task = [table, &push_keys, &push_values] { - table->PushSparse(push_keys.data(), push_values.data(), push_keys.size()); + TableContext table_context; + table_context.value_type = Sparse; + table_context.push_context.keys = push_keys.data(); + table_context.push_context.values = push_values.data(); + table_context.num = push_keys.size(); + table->Push(table_context); + // table->PushSparse(push_keys.data(), push_values.data(), + // push_keys.size()); }; task_status.push_back(pool_->enqueue(std::move(task))); } @@ -106,7 +127,13 @@ TEST(MemorySparseGeoTable, SSUM) { geo_pull_ids.resize(trainers); geo_pull_values.resize(trainers); for (int i = 0; i < trainers; i++) { - table->PullGeoParam(i, 
&geo_pull_values[i], &geo_pull_ids[i]); + TableContext table_context; + table_context.value_type = Sparse; + table_context.pull_context.geo_pull_keys = &geo_pull_ids[i]; + table_context.pull_context.geo_pull_values = &geo_pull_values[i]; + table_context.trainer_id = i; + table->Pull(table_context); + // table->PullGeoParam(i, &geo_pull_values[i], &geo_pull_ids[i]); ASSERT_EQ(geo_pull_values[i].size(), geo_pull_ids[i].size() * emb_dim); for (size_t j = 0; j < geo_pull_ids[i].size(); ++j) { auto id = geo_pull_ids[i][j]; diff --git a/paddle/fluid/distributed/test/memory_sparse_table_test.cc b/paddle/fluid/distributed/test/memory_sparse_table_test.cc index 73fa7272280b2..68bc50373ffad 100644 --- a/paddle/fluid/distributed/test/memory_sparse_table_test.cc +++ b/paddle/fluid/distributed/test/memory_sparse_table_test.cc @@ -76,7 +76,13 @@ TEST(MemorySparseTable, SGD) { std::vector init_values; init_values.resize(init_keys.size() * (emb_dim + 3)); auto value = PullSparseValue(init_keys, init_fres, emb_dim); - table->PullSparse(init_values.data(), value); + + TableContext table_context; + table_context.value_type = Sparse; + table_context.pull_context.pull_value = value; + table_context.pull_context.values = init_values.data(); + table->Pull(table_context); + // table->PullSparse(init_values.data(), value); // for check std::vector total_gradients; @@ -109,7 +115,14 @@ TEST(MemorySparseTable, SGD) { auto &push_keys = trainer_keys[i]; auto &push_values = trainer_gradient_values[i]; auto task = [table, &push_keys, &push_values] { - table->PushSparse(push_keys.data(), push_values.data(), push_keys.size()); + TableContext table_context; + table_context.value_type = Sparse; + table_context.push_context.keys = push_keys.data(); + table_context.push_context.values = push_values.data(); + table_context.num = push_keys.size(); + table->Push(table_context); + // table->PushSparse(push_keys.data(), push_values.data(), + // push_keys.size()); }; task_status.push_back(pool_->enqueue(std::move(task))); } @@ -119,7 +132,13 @@ TEST(MemorySparseTable, SGD) { std::vector pull_values; pull_values.resize(init_keys.size() * (emb_dim + 3)); - table->PullSparse(pull_values.data(), value); + + TableContext table_context1; + table_context1.value_type = Sparse; + table_context1.pull_context.pull_value = value; + table_context1.pull_context.values = pull_values.data(); + table->Pull(table_context1); + // table->PullSparse(pull_values.data(), value); for (size_t i = 0; i < init_keys.size(); ++i) { for (size_t j = 2; j < emb_dim + 3; ++j) { diff --git a/python/paddle/distributed/ps/the_one_ps.py b/python/paddle/distributed/ps/the_one_ps.py index 007aaeb4fed67..1fd435cca1107 100755 --- a/python/paddle/distributed/ps/the_one_ps.py +++ b/python/paddle/distributed/ps/the_one_ps.py @@ -621,7 +621,7 @@ def _set(self, table_proto): class GeoSparseTable(SparseTable): def __init__(self, context, send_ctx): super(GeoSparseTable, self).__init__(context, send_ctx) - self.table_class = "SparseGeoTable" + self.table_class = "MemorySparseGeoTable" if self.context['ps_mode'] != DistributedMode.GEO: raise ValueError("not geo sparse table!") From 77cf305f0e08ce3057d7c4c74416743fa9b7104c Mon Sep 17 00:00:00 2001 From: hong <43953930+phlrain@users.noreply.github.com> Date: Mon, 4 Apr 2022 21:46:06 +0800 Subject: [PATCH 77/93] Add batch norm yaml (#41386) * update * fix bug --- paddle/fluid/operators/inplace_abn_op.cc | 4 +- paddle/fluid/operators/inplace_abn_op.cu | 8 +- paddle/phi/api/lib/api_custom_impl.cc | 129 ++++++++++++++++++ 
paddle/phi/api/lib/api_custom_impl.h | 14 ++ paddle/phi/kernels/batch_norm_grad_kernel.h | 12 +- .../phi/kernels/cpu/batch_norm_grad_kernel.cc | 26 ++-- .../phi/kernels/gpu/batch_norm_grad_kernel.cu | 18 +-- paddle/phi/ops/compat/batch_norm_sig.cc | 20 +-- python/paddle/fluid/dygraph/nn.py | 25 ++-- .../tests/unittests/test_batch_norm_op_v2.py | 34 +++++ python/paddle/nn/functional/norm.py | 11 +- python/paddle/utils/code_gen/api.yaml | 7 + python/paddle/utils/code_gen/backward.yaml | 12 ++ 13 files changed, 269 insertions(+), 51 deletions(-) diff --git a/paddle/fluid/operators/inplace_abn_op.cc b/paddle/fluid/operators/inplace_abn_op.cc index 77951ff394e74..89459d00ae813 100644 --- a/paddle/fluid/operators/inplace_abn_op.cc +++ b/paddle/fluid/operators/inplace_abn_op.cc @@ -312,8 +312,8 @@ class InplaceABNGradKernel : public framework::OpKernel { phi::BatchNormGradRawKernel( static_cast::TYPE&>(dev_ctx), - *d_y, *y, *scale, *bias, *saved_mean, *saved_variance, space_opt, - mean_opt, variance_opt, momentum, epsilon, data_layout, is_test, + *y, *scale, *bias, mean_opt, variance_opt, *saved_mean, *saved_variance, + space_opt, *d_y, momentum, epsilon, data_layout, is_test, use_global_stats, trainable_statistics, fuse_with_relu, true, d_x, scale_grad, bias_grad); } diff --git a/paddle/fluid/operators/inplace_abn_op.cu b/paddle/fluid/operators/inplace_abn_op.cu index db8f8c72d13f8..6c16210ced022 100644 --- a/paddle/fluid/operators/inplace_abn_op.cu +++ b/paddle/fluid/operators/inplace_abn_op.cu @@ -140,10 +140,10 @@ class InplaceABNGradKernel phi::BatchNormGradRawKernel( static_cast::TYPE&>(dev_ctx), - *d_y, *y, *scale, *bias, *saved_mean, *saved_variance, space_opt, - mean_opt, variance_opt, momentum, epsilon, data_layout, is_test, - use_global_stats, trainable_statistics, fuse_with_relu, true, d_x, - scale_grad, bias_grad); + *y, *scale, *bias, mean_opt, variance_opt, *saved_mean, + *saved_variance, space_opt, *d_y, momentum, epsilon, data_layout, + is_test, use_global_stats, trainable_statistics, fuse_with_relu, true, + d_x, scale_grad, bias_grad); } } }; diff --git a/paddle/phi/api/lib/api_custom_impl.cc b/paddle/phi/api/lib/api_custom_impl.cc index ce49680586caa..6325322b63c6f 100644 --- a/paddle/phi/api/lib/api_custom_impl.cc +++ b/paddle/phi/api/lib/api_custom_impl.cc @@ -167,6 +167,135 @@ std::vector split_impl(const Tensor& x, return out; } +std::tuple batch_norm_impl( + const Tensor& x, + const Tensor& scale, + const Tensor& bias, + const Tensor& mean, + const Tensor& variance, + float momentum, + float epsilon, + const std::string& data_layout, + bool is_test, + bool use_global_stats, + bool trainable_statistics, + bool fuse_with_relu) { + Backend kernel_backend = Backend::UNDEFINED; + DataLayout kernel_layout = DataLayout::UNDEFINED; + DataType kernel_data_type = DataType::UNDEFINED; + + kernel_data_type = ParseDataType(x); + + if (kernel_backend == Backend::UNDEFINED || + kernel_layout == DataLayout::UNDEFINED || + kernel_data_type == DataType::UNDEFINED) { + auto kernel_key_set = ParseKernelKeyByInputArgs(x); + auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey(); + if (kernel_backend == Backend::UNDEFINED) { + kernel_backend = kernel_key.backend(); + } + if (kernel_layout == DataLayout::UNDEFINED) { + kernel_layout = kernel_key.layout(); + } + if (kernel_data_type == DataType::UNDEFINED) { + kernel_data_type = kernel_key.dtype(); + } + } + + const auto& kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError( + "batch_norm", {kernel_backend, kernel_layout, 
kernel_data_type}); + VLOG(6) << "batch_norm API kernel key: [" << kernel_backend << ", " + << kernel_layout << ", " << kernel_data_type << "]"; + VLOG(6) << "batch_norm API kernel: " << kernel; + + auto* dev_ctx = GetDeviceContextByBackend(kernel_backend); + + auto input_x = PrepareData(x, kernel.InputAt(0), {}); + auto input_scale = PrepareData(scale, kernel.InputAt(1), {}); + auto input_bias = PrepareData(bias, kernel.InputAt(2), {}); + auto input_mean = PrepareData(mean, kernel.InputAt(3), {}); + auto input_variance = PrepareData(variance, kernel.InputAt(4), {}); + + std::tuple api_output; + auto kernel_out_0 = SetKernelOutput(kernel_backend, &std::get<0>(api_output)); + std::get<1>(api_output).set_impl(mean.impl()); + std::get<2>(api_output).set_impl(variance.impl()); + auto kernel_out_1 = SetKernelOutput(kernel_backend, &std::get<1>(api_output)); + auto kernel_out_2 = SetKernelOutput(kernel_backend, &std::get<2>(api_output)); + auto kernel_out_3 = SetKernelOutput(kernel_backend, &std::get<3>(api_output)); + auto kernel_out_4 = SetKernelOutput(kernel_backend, &std::get<4>(api_output)); + auto kernel_out_5 = SetKernelOutput(kernel_backend, &std::get<5>(api_output)); + phi::MetaTensor meta_out_0(kernel_out_0); + phi::MetaTensor meta_out_1(kernel_out_1); + phi::MetaTensor meta_out_2(kernel_out_2); + phi::MetaTensor meta_out_3(kernel_out_3); + phi::MetaTensor meta_out_4(kernel_out_4); + phi::MetaTensor meta_out_5(kernel_out_5); + + phi::BatchNormInferMeta(MakeMetaTensor(*input_x), + MakeMetaTensor(*input_scale), + MakeMetaTensor(*input_bias), + MakeMetaTensor(*input_mean), + MakeMetaTensor(*input_variance), + momentum, + epsilon, + data_layout, + is_test, + use_global_stats, + trainable_statistics, + fuse_with_relu, + &meta_out_0, + &meta_out_1, + &meta_out_2, + &meta_out_3, + &meta_out_4, + &meta_out_5); + + using kernel_signature = void (*)(const platform::DeviceContext&, + const phi::DenseTensor&, + const phi::DenseTensor&, + const phi::DenseTensor&, + const phi::DenseTensor&, + const phi::DenseTensor&, + float, + float, + const std::string&, + bool, + bool, + bool, + bool, + phi::DenseTensor*, + phi::DenseTensor*, + phi::DenseTensor*, + phi::DenseTensor*, + phi::DenseTensor*, + phi::DenseTensor*); + auto* kernel_fn = kernel.GetVariadicKernelFn(); + { + (*kernel_fn)(*dev_ctx, + *input_x, + *input_scale, + *input_bias, + *input_mean, + *input_variance, + momentum, + epsilon, + data_layout, + is_test, + use_global_stats, + trainable_statistics, + fuse_with_relu, + kernel_out_0, + kernel_out_1, + kernel_out_2, + kernel_out_3, + kernel_out_4, + kernel_out_5); + } + + return api_output; +} + std::vector concat_grad_impl(const std::vector& x, const Tensor& out_grad, const Scalar& axis) { diff --git a/paddle/phi/api/lib/api_custom_impl.h b/paddle/phi/api/lib/api_custom_impl.h index 1f84eab10353d..e8893cc2476a0 100644 --- a/paddle/phi/api/lib/api_custom_impl.h +++ b/paddle/phi/api/lib/api_custom_impl.h @@ -31,6 +31,20 @@ std::vector split_impl(const Tensor& x, const IntArray& num_or_sections, const Scalar& axis); +std::tuple batch_norm_impl( + const Tensor& x, + const Tensor& scale, + const Tensor& bias, + const Tensor& mean, + const Tensor& variance, + float momentum, + float epsilon, + const std::string& data_layout, + bool is_test, + bool use_global_stats, + bool trainable_statistics, + bool fuse_with_relu); + std::vector concat_grad_impl(const std::vector& x, const Tensor& out_grad, const Scalar& axis); diff --git a/paddle/phi/kernels/batch_norm_grad_kernel.h 
b/paddle/phi/kernels/batch_norm_grad_kernel.h index c15dbd2f63f58..73752f015ca3a 100644 --- a/paddle/phi/kernels/batch_norm_grad_kernel.h +++ b/paddle/phi/kernels/batch_norm_grad_kernel.h @@ -21,15 +21,15 @@ namespace phi { template void BatchNormGradRawKernel(const Context& dev_ctx, - const DenseTensor& y_grad, const DenseTensor& x, const DenseTensor& scale, const DenseTensor& bias, + paddle::optional mean, + paddle::optional variance, const DenseTensor& saved_mean, const DenseTensor& saved_variance, paddle::optional reserve_space, - paddle::optional mean, - paddle::optional variance, + const DenseTensor& y_grad, float momentum, float epsilon, const std::string& data_layout, @@ -44,15 +44,15 @@ void BatchNormGradRawKernel(const Context& dev_ctx, template void BatchNormGradKernel(const Context& dev_ctx, - const DenseTensor& y_grad, const DenseTensor& x, const DenseTensor& scale, const DenseTensor& bias, + paddle::optional mean, + paddle::optional variance, const DenseTensor& saved_mean, const DenseTensor& saved_variance, paddle::optional reserve_space, - paddle::optional mean, - paddle::optional variance, + const DenseTensor& y_grad, float momentum, float epsilon, const std::string& data_layout, diff --git a/paddle/phi/kernels/cpu/batch_norm_grad_kernel.cc b/paddle/phi/kernels/cpu/batch_norm_grad_kernel.cc index de2343a384a5b..ae87886b89bff 100644 --- a/paddle/phi/kernels/cpu/batch_norm_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/batch_norm_grad_kernel.cc @@ -37,15 +37,16 @@ using ConstEigenVectorArrayMap = template void BatchNormGradRawKernel(const Context& ctx, - const DenseTensor& y_grad, + const DenseTensor& x, const DenseTensor& scale, const DenseTensor& bias, + paddle::optional mean, + paddle::optional variance, const DenseTensor& saved_mean, const DenseTensor& saved_variance, paddle::optional reserve_space, - paddle::optional mean, - paddle::optional variance, + const DenseTensor& y_grad, float momentum, float epsilon, const std::string& data_layout_str, @@ -122,8 +123,8 @@ void BatchNormGradRawKernel(const Context& ctx, ctx.template Alloc(d_x); } - const T* mean_data = saved_mean.data(); - const T* inv_var_data = saved_variance.data(); + const T* mean_data = nullptr; + const T* inv_var_data = nullptr; DenseTensor inv_var_tensor; if (use_global_stats) { const auto* running_mean = mean.get_ptr(); @@ -136,6 +137,9 @@ void BatchNormGradRawKernel(const Context& ctx, inv_var_tmp = (var_arr + epsilon).sqrt().inverse(); inv_var_data = running_inv_var_data; + } else { + mean_data = saved_mean.data(); + inv_var_data = saved_variance.data(); } ConstEigenVectorArrayMap scale_arr(scale.data(), C); @@ -293,15 +297,15 @@ void BatchNormGradRawKernel(const Context& ctx, template void BatchNormGradKernel(const Context& dev_ctx, - const DenseTensor& y_grad, const DenseTensor& x, const DenseTensor& scale, const DenseTensor& bias, + paddle::optional mean, + paddle::optional variance, const DenseTensor& saved_mean, const DenseTensor& saved_variance, paddle::optional reserve_space, - paddle::optional mean, - paddle::optional variance, + const DenseTensor& y_grad, float momentum, float epsilon, const std::string& data_layout, @@ -313,15 +317,15 @@ void BatchNormGradKernel(const Context& dev_ctx, DenseTensor* scale_grad, DenseTensor* bias_grad) { BatchNormGradRawKernel(dev_ctx, - y_grad, x, scale, bias, + mean, + variance, saved_mean, saved_variance, reserve_space, - mean, - variance, + y_grad, momentum, epsilon, data_layout, diff --git a/paddle/phi/kernels/gpu/batch_norm_grad_kernel.cu 
b/paddle/phi/kernels/gpu/batch_norm_grad_kernel.cu index 339c3536d7a7f..09bce3c9895b3 100644 --- a/paddle/phi/kernels/gpu/batch_norm_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/batch_norm_grad_kernel.cu @@ -306,15 +306,15 @@ static __global__ LAUNCH_BOUNDS(BlockDim) void BNBackwardData( template void BatchNormGradRawKernel(const Context &ctx, - const DenseTensor &y_grad, const DenseTensor &x, const DenseTensor &scale, const DenseTensor &bias, + paddle::optional mean, + paddle::optional variance, const DenseTensor &saved_mean, const DenseTensor &saved_variance, paddle::optional reserve_space, - paddle::optional mean, - paddle::optional variance, + const DenseTensor &y_grad, float momentum, float epsilon_f, const std::string &data_layout_str, @@ -863,15 +863,15 @@ void BatchNormGradRawKernel(const Context &ctx, template void BatchNormGradKernel(const Context &dev_ctx, - const DenseTensor &y_grad, const DenseTensor &x, const DenseTensor &scale, const DenseTensor &bias, + paddle::optional mean, + paddle::optional variance, const DenseTensor &saved_mean, const DenseTensor &saved_variance, paddle::optional reserve_space, - paddle::optional mean, - paddle::optional variance, + const DenseTensor &y_grad, float momentum, float epsilon, const std::string &data_layout, @@ -883,15 +883,15 @@ void BatchNormGradKernel(const Context &dev_ctx, DenseTensor *scale_grad, DenseTensor *bias_grad) { BatchNormGradRawKernel(dev_ctx, - y_grad, x, scale, bias, + mean, + variance, saved_mean, saved_variance, reserve_space, - mean, - variance, + y_grad, momentum, epsilon, data_layout, diff --git a/paddle/phi/ops/compat/batch_norm_sig.cc b/paddle/phi/ops/compat/batch_norm_sig.cc index 803bb50b438a5..cfd9f4def933a 100644 --- a/paddle/phi/ops/compat/batch_norm_sig.cc +++ b/paddle/phi/ops/compat/batch_norm_sig.cc @@ -59,15 +59,17 @@ KernelSignature BatchNormGradOpArgumentMapping( const ArgumentMappingContext& ctx) { return KernelSignature( "batch_norm_grad", - {GradVarName("Y"), - "X", - "Scale", - "Bias", - "SavedMean", - "SavedVariance", - "ReserveSpace", - "Mean", - "Variance"}, + { + "X", + "Scale", + "Bias", + "Mean", + "Variance", + "SavedMean", + "SavedVariance", + "ReserveSpace", + GradVarName("Y"), + }, {"momentum", "epsilon", "data_layout", diff --git a/python/paddle/fluid/dygraph/nn.py b/python/paddle/fluid/dygraph/nn.py index 531adc9e456b8..0ae3cf6ba2fdb 100644 --- a/python/paddle/fluid/dygraph/nn.py +++ b/python/paddle/fluid/dygraph/nn.py @@ -1339,15 +1339,22 @@ def forward(self, input): variance_out = self._variance if _non_static_mode(): - attrs = ("momentum", self._momentum, "epsilon", self._epsilon, - "is_test", not self.training, "data_layout", - self._data_layout, "use_mkldnn", self._use_mkldnn, - "fuse_with_relu", self._fuse_with_relu, "use_global_stats", - self._use_global_stats, 'trainable_statistics', - self._trainable_statistics) - batch_norm_out, _, _, _, _, _ = _C_ops.batch_norm( - input, self.weight, self.bias, self._mean, self._variance, - mean_out, variance_out, *attrs) + if in_dygraph_mode(): + batch_norm_out, t1, t2, t3, t4, _ = _C_ops.final_state_batch_norm( + input, self.weight, self.bias, self._mean, self._variance, + self._momentum, self._epsilon, self._data_layout, + not self.training, self._use_global_stats, + self._trainable_statistics, False) + else: + attrs = ("momentum", self._momentum, "epsilon", self._epsilon, + "is_test", not self.training, "data_layout", + self._data_layout, "use_mkldnn", self._use_mkldnn, + "fuse_with_relu", self._fuse_with_relu, + "use_global_stats", 
self._use_global_stats, + 'trainable_statistics', self._trainable_statistics) + batch_norm_out, _, _, _, _, _ = _C_ops.batch_norm( + input, self.weight, self.bias, self._mean, self._variance, + mean_out, variance_out, *attrs) return dygraph_utils._append_activation_in_dygraph( batch_norm_out, act=self._act, use_mkldnn=self._use_mkldnn) diff --git a/python/paddle/fluid/tests/unittests/test_batch_norm_op_v2.py b/python/paddle/fluid/tests/unittests/test_batch_norm_op_v2.py index dda10fdd84fff..ac09d9f5fdfd0 100644 --- a/python/paddle/fluid/tests/unittests/test_batch_norm_op_v2.py +++ b/python/paddle/fluid/tests/unittests/test_batch_norm_op_v2.py @@ -81,6 +81,40 @@ def error3d(): self.assertRaises(ValueError, error2d_dataformat) self.assertRaises(ValueError, error3d_dataformat) + def test_eager_api(self): + places = [fluid.CPUPlace()] + if core.is_compiled_with_cuda(): + places.append(fluid.CUDAPlace(0)) + for p in places: + shape = [4, 10, 4, 4] + + def compute_v1(x): + with fluid.dygraph.guard(p): + bn = fluid.dygraph.BatchNorm(shape[1]) + #bn = paddle.nn.BatchNorm2D(shape[1]) + x1 = paddle.to_tensor(x) + x1.stop_gradient = False + y = bn(x1) + y.backward() + return y.numpy(), x1.gradient() + + def compute_v2(x): + with fluid.dygraph.guard(p): + with _test_eager_guard(): + print("v2") + bn = paddle.nn.BatchNorm2D(shape[1]) + x1 = paddle.to_tensor(x) + x1.stop_gradient = False + y = bn(x1) + y.backward() + return y.numpy(), x1.gradient() + + x = np.random.randn(*shape).astype("float32") + y1, g1 = compute_v1(x) + y2, g2 = compute_v2(x) + self.assertTrue(np.allclose(g1, g2)) + self.assertTrue(np.allclose(y1, y2)) + def test_dygraph(self): places = [fluid.CPUPlace()] if core.is_compiled_with_cuda(): diff --git a/python/paddle/nn/functional/norm.py b/python/paddle/nn/functional/norm.py index 3f7e819f442c1..38a6d7a09d208 100644 --- a/python/paddle/nn/functional/norm.py +++ b/python/paddle/nn/functional/norm.py @@ -186,15 +186,24 @@ def batch_norm(x, else: trainable_statistics = not use_global_stats - if in_dynamic_mode(): + if in_dygraph_mode(): + batch_norm_out, _, _, _, _, _ = _C_ops.final_state_batch_norm( + x, weight, bias, running_mean, running_var, momentum, epsilon, + data_format, not training, use_global_stats, trainable_statistics, + False) + return batch_norm_out + if _in_legacy_dygraph(): + # for dygraph need tuple attrs = ("momentum", momentum, "epsilon", epsilon, "is_test", not training, "data_layout", data_format, "use_mkldnn", False, "fuse_with_relu", False, "use_global_stats", use_global_stats, "trainable_statistics", trainable_statistics) + batch_norm_out, _, _, _, _, _ = _C_ops.batch_norm( x, weight, bias, running_mean, running_var, mean_out, variance_out, *attrs) + return dygraph_utils._append_activation_in_dygraph( batch_norm_out, act=None) diff --git a/python/paddle/utils/code_gen/api.yaml b/python/paddle/utils/code_gen/api.yaml index 08cf04f692806..b41ccf8ddb545 100644 --- a/python/paddle/utils/code_gen/api.yaml +++ b/python/paddle/utils/code_gen/api.yaml @@ -207,6 +207,13 @@ kernel : func : auc +# batch_norm +- api : batch_norm + args : (Tensor x, Tensor scale, Tensor bias, Tensor mean, Tensor variance, float momentum, float epsilon, str data_layout, bool is_test, bool use_global_stats, bool trainable_statistics, bool fuse_with_relu) + output : Tensor(out), Tensor(mean_out), Tensor(variance_out), Tensor(saved_mean), Tensor(saved_variance), Tensor(reserve_space) + invoke : batch_norm_impl(x, scale, bias, mean, variance, momentum, epsilon, data_layout, is_test, 
use_global_stats, trainable_statistics, fuse_with_relu) + backward : batch_norm_grad + - api : bce_loss args : (Tensor input, Tensor label) output : Tensor diff --git a/python/paddle/utils/code_gen/backward.yaml b/python/paddle/utils/code_gen/backward.yaml index 570e64dcd5e12..814c56d7d222c 100644 --- a/python/paddle/utils/code_gen/backward.yaml +++ b/python/paddle/utils/code_gen/backward.yaml @@ -118,6 +118,18 @@ kernel : func : atanh_grad +- backward_api : batch_norm_grad + forward : batch_norm (Tensor x, Tensor scale, Tensor bias, Tensor mean, Tensor variance, float momentum, float epsilon, str data_layout, bool is_test, bool use_global_stats, bool trainable_statistics, bool fuse_with_relu) -> Tensor(out), Tensor(mean_out), Tensor(variance_out), Tensor(saved_mean), Tensor(saved_variance), Tensor(reserve_space) + args : (Tensor x, Tensor scale, Tensor bias, Tensor mean_out, Tensor variance_out, Tensor saved_mean, Tensor saved_variance, Tensor reserve_space, Tensor out_grad, float momentum, float epsilon, str data_layout, bool is_test, bool use_global_stats, bool trainable_statistics, bool fuse_with_relu) + output : Tensor(x_grad), Tensor(scale_grad), Tensor(bias_grad) + infer_meta : + func : GeneralTernaryGradInferMeta + param : [x, scale, bias] + kernel : + func : batch_norm_grad + data_type : out_grad + optional : mean_out, variance_out, reserve_space + - backward_api : bce_loss_grad forward : bce_loss (Tensor input, Tensor label) -> Tensor(out) args : (Tensor input, Tensor label, Tensor out_grad) From 1888d874b2cc62e10adc0d22b60cdce48f90fd65 Mon Sep 17 00:00:00 2001 From: zyfncg Date: Mon, 4 Apr 2022 21:53:36 +0800 Subject: [PATCH 78/93] add cudnn flag in yaml (#41368) --- paddle/phi/core/kernel_factory.cc | 20 ++++++++++++++++++- paddle/phi/core/kernel_factory.h | 3 ++- python/paddle/utils/code_gen/api_base.py | 11 ++++++++-- python/paddle/utils/code_gen/api_gen.py | 2 ++ .../paddle/utils/code_gen/backward_api_gen.py | 2 ++ 5 files changed, 34 insertions(+), 4 deletions(-) diff --git a/paddle/phi/core/kernel_factory.cc b/paddle/phi/core/kernel_factory.cc index 81c43764fee9e..a1ce90c2c78ae 100644 --- a/paddle/phi/core/kernel_factory.cc +++ b/paddle/phi/core/kernel_factory.cc @@ -75,13 +75,31 @@ bool KernelFactory::IsSelectKernelValid(const std::string& kernel_name, } const Kernel& KernelFactory::SelectKernelOrThrowError( - const std::string& kernel_name, const KernelKey& kernel_key) const { + const std::string& kernel_name, + const KernelKey& kernel_key, + bool use_cudnn) const { auto iter = kernels_.find(kernel_name); PADDLE_ENFORCE_NE( iter, kernels_.end(), phi::errors::NotFound("The kernel `%s` is not registered.", kernel_name)); +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) + if (use_cudnn && kernel_key.backend() == Backend::GPU) { + auto kernel_iter = iter->second.find( + {Backend::GPUDNN, kernel_key.layout(), kernel_key.dtype()}); + if (kernel_iter == iter->second.end() && + kernel_key.layout() != phi::DataLayout::ALL_LAYOUT) { + kernel_iter = iter->second.find( + {Backend::GPUDNN, DataLayout::ALL_LAYOUT, kernel_key.dtype()}); + } + if (kernel_iter != iter->second.end()) { + return kernel_iter->second; + } + LOG(WARNING) << "The cudnn kernel for [" << kernel_name + << "] is not registered."; + } +#endif auto kernel_iter = iter->second.find(kernel_key); // TODO(chenweihang): polish refind impl here if (kernel_iter == iter->second.end() && diff --git a/paddle/phi/core/kernel_factory.h b/paddle/phi/core/kernel_factory.h index 6c098c75a0eda..8fd25b691bdeb 100644 --- 
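The batch_norm patch above (PATCH 77) moves the op onto the code-generated final-state path: an api.yaml entry invoking batch_norm_impl, a batch_norm_grad backward entry, and eager-mode dispatch to _C_ops.final_state_batch_norm. A minimal dygraph sketch of the user-facing call this serves — a usage illustration only, assuming a Paddle build of this era with paddle.nn.functional.batch_norm available:

    import numpy as np
    import paddle

    x = paddle.to_tensor(np.random.randn(4, 10, 4, 4).astype('float32'))
    # Running statistics and affine parameters; batch_norm updates the
    # running stats in place when training=True.
    running_mean = paddle.zeros([10])
    running_var = paddle.ones([10])
    weight = paddle.ones([10])
    bias = paddle.zeros([10])

    # In eager mode this now routes through final_state_batch_norm; in
    # legacy dygraph it falls back to the attrs-tuple _C_ops.batch_norm call.
    y = paddle.nn.functional.batch_norm(
        x, running_mean, running_var, weight, bias,
        training=True, momentum=0.9, epsilon=1e-5, data_format='NCHW')
    print(y.shape)  # [4, 10, 4, 4]
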
a/paddle/phi/core/kernel_factory.h +++ b/paddle/phi/core/kernel_factory.h @@ -238,7 +238,8 @@ class KernelFactory { } const Kernel& SelectKernelOrThrowError(const std::string& kernel_name, - const KernelKey& kernel_key) const; + const KernelKey& kernel_key, + bool use_cudnn = false) const; const Kernel& SelectKernelOrThrowError(const std::string& kernel_name, Backend backend, diff --git a/python/paddle/utils/code_gen/api_base.py b/python/paddle/utils/code_gen/api_base.py index c1a987d06ba39..c51e2b0acd268 100644 --- a/python/paddle/utils/code_gen/api_base.py +++ b/python/paddle/utils/code_gen/api_base.py @@ -238,7 +238,8 @@ def parse_kernel(self, kernel_config): 'param': None, 'backend': None, 'layout': None, - 'data_type': None + 'data_type': None, + 'use_cudnn': 'false' } if 'backend' in kernel_config and len(kernel_config['backend']) > 0: kernel['backend'] = kernel_config['backend'] @@ -248,6 +249,10 @@ def parse_kernel(self, kernel_config): kernel['data_type'] = kernel_config['data_type'] if 'param' in kernel_config: kernel['param'] = kernel_config['param'] + if 'use_cudnn' in kernel_config: + kernel['use_cudnn'] = kernel_config['use_cudnn'] + if isinstance(kernel['use_cudnn'], bool): + kernel['use_cudnn'] = str(kernel['use_cudnn']).lower() kernel['func'] = [ kernel_fn.strip() for kernel_fn in kernel_config['func'].split(',') ] @@ -713,10 +718,12 @@ def gen_dense_tensor_kernel_code(self, code_indent, inplace_flag=False): outputs_args, kernel_output_names, output_create = self.gene_output( self.outputs['types'], 'SetKernelOutput', code_indent, inplace_flag) api_func_name = self.get_api_func_name() + ('_' if inplace_flag else '') + cudnn_args = '' if self.kernel[ + 'use_cudnn'] == 'false' else ', ' + self.kernel['use_cudnn'] return f""" {code_indent} VLOG(6) << "{self.api} API kernel key: [" << kernel_backend << ", " << kernel_layout << ", "<< kernel_data_type << "]"; {code_indent} const auto& kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError( -{code_indent} "{self.kernel['func'][0]}", {{kernel_backend, kernel_layout, kernel_data_type}}); +{code_indent} "{self.kernel['func'][0]}", {{kernel_backend, kernel_layout, kernel_data_type}}{cudnn_args}); {code_indent} VLOG(6) << "{self.api} API kernel: " << kernel; {code_indent} auto* dev_ctx = GetDeviceContextByBackend(kernel_backend); diff --git a/python/paddle/utils/code_gen/api_gen.py b/python/paddle/utils/code_gen/api_gen.py index f95edf6c591ab..4087b55b51324 100644 --- a/python/paddle/utils/code_gen/api_gen.py +++ b/python/paddle/utils/code_gen/api_gen.py @@ -163,6 +163,8 @@ def source_include(header_file_path): #include "paddle/phi/infermeta/ternary.h" #include "paddle/fluid/platform/profiler/event_tracing.h" + +DECLARE_bool(conv2d_disable_cudnn); """ diff --git a/python/paddle/utils/code_gen/backward_api_gen.py b/python/paddle/utils/code_gen/backward_api_gen.py index e26f65387878c..970ac022473d1 100644 --- a/python/paddle/utils/code_gen/backward_api_gen.py +++ b/python/paddle/utils/code_gen/backward_api_gen.py @@ -179,6 +179,8 @@ def source_include(header_file_path): #include "paddle/fluid/eager/api/utils/global_utils.h" #include "paddle/fluid/platform/profiler/event_tracing.h" + +DECLARE_bool(conv2d_disable_cudnn); """ From 3e9ad093c67492288c03ee61cfe6edf93438488a Mon Sep 17 00:00:00 2001 From: FlyingQianMM <245467267@qq.com> Date: Mon, 4 Apr 2022 21:54:33 +0800 Subject: [PATCH 79/93] fix index_select kernel configuration error where input numel is 0 (#41383) --- paddle/phi/kernels/gpu/index_select_grad_kernel.cu | 3 +++ 
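The two hunks below add an early return when the tensor being read or written has zero elements, so no CUDA kernel is launched with an empty grid. A small repro sketch of the situation being guarded — hypothetical usage; on GPU builds without the guard, the zero-element launch could raise a kernel-configuration error:

    import paddle

    x = paddle.zeros([0, 4])  # numel == 0
    index = paddle.to_tensor([0, 2], dtype='int64')
    # Selecting along axis 1 of an empty tensor should simply produce an
    # empty result of shape [0, 2] rather than launching a kernel.
    out = paddle.index_select(x, index, axis=1)
    print(out.shape)  # [0, 2]
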
paddle/phi/kernels/gpu/index_select_kernel.cu | 3 +++ 2 files changed, 6 insertions(+) diff --git a/paddle/phi/kernels/gpu/index_select_grad_kernel.cu b/paddle/phi/kernels/gpu/index_select_grad_kernel.cu index 209ce1ccf5c80..75ae1bbcd0a08 100644 --- a/paddle/phi/kernels/gpu/index_select_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/index_select_grad_kernel.cu @@ -85,6 +85,9 @@ void IndexSelectGradKernel(const Context& ctx, phi::DataType::INT64)); int64_t numel = x_grad->numel(); + if (numel == 0) { + return; + } int64_t index_nums = index.numel(); int64_t out_nums = out_grad.numel(); diff --git a/paddle/phi/kernels/gpu/index_select_kernel.cu b/paddle/phi/kernels/gpu/index_select_kernel.cu index 57a13a9aefc2c..38a6582d790f8 100644 --- a/paddle/phi/kernels/gpu/index_select_kernel.cu +++ b/paddle/phi/kernels/gpu/index_select_kernel.cu @@ -72,6 +72,9 @@ void IndexSelectKernel(const Context& ctx, T* out_data = ctx.template Alloc<T>(output); int64_t numel = output->numel(); + if (numel == 0) { + return; + } auto stream = ctx.stream(); unsigned int block_dim = PADDLE_CUDA_NUM_THREADS; From eb6d7da947a9ec9151503d069d6329750e5a764c Mon Sep 17 00:00:00 2001 From: FlyingQianMM <245467267@qq.com> Date: Mon, 4 Apr 2022 21:54:50 +0800 Subject: [PATCH 80/93] support getitem when index is an all-false bool tensor (#41297) * support getitem when index is an all-false bool tensor * use cond to replace if * add static_graph getitem unit test when index is a bool tensor --- .../fluid/tests/unittests/test_var_base.py | 11 ++-- .../fluid/tests/unittests/test_variable.py | 55 +++++++++++++++++++ python/paddle/fluid/variable_index.py | 49 +++++++++++------ 3 files changed, 94 insertions(+), 21 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_var_base.py b/python/paddle/fluid/tests/unittests/test_var_base.py index 11d77ecc6226b..ef57ba1530299 100644 --- a/python/paddle/fluid/tests/unittests/test_var_base.py +++ b/python/paddle/fluid/tests/unittests/test_var_base.py @@ -795,17 +795,17 @@ def _test_bool_index(self): np_value = np.random.random(shape).astype('float32') var_tensor = paddle.to_tensor(np_value) index = [[True, True, True, True], [True, False, True, True], - [True, False, False, True], [False, 0, 1, True, True]] + [True, False, False, True], [False, 0, 1, True, True], + [False, False, False, False]] index2d = np.array([[True, True], [False, False], [True, False], [True, True]]) tensor_index = paddle.to_tensor(index2d) var = [ - var_tensor[index[0]].numpy(), - var_tensor[index[1]].numpy(), - var_tensor[index[2]].numpy(), - var_tensor[index[3]].numpy(), + var_tensor[index[0]].numpy(), var_tensor[index[1]].numpy(), + var_tensor[index[2]].numpy(), var_tensor[index[3]].numpy(), var_tensor[paddle.to_tensor(index[0])].numpy(), var_tensor[tensor_index].numpy(), + var_tensor[paddle.to_tensor(index[4])].numpy() ] self.assertTrue(np.array_equal(var[0], np_value[index[0]])) self.assertTrue(np.array_equal(var[1], np_value[index[1]])) @@ -813,6 +813,7 @@ def _test_bool_index(self): self.assertTrue(np.array_equal(var[3], np_value[index[3]])) self.assertTrue(np.array_equal(var[4], np_value[index[0]])) self.assertTrue(np.array_equal(var[5], np_value[index2d])) + self.assertTrue(np.array_equal(var[6], np_value[index[4]])) self.assertTrue( np.array_equal(var_tensor[var_tensor > 0.67], np_value[np_value > 0.67])) diff --git a/python/paddle/fluid/tests/unittests/test_variable.py b/python/paddle/fluid/tests/unittests/test_variable.py index b218739ff9527..3a924669b0020 100644 ---
a/python/paddle/fluid/tests/unittests/test_variable.py +++ b/python/paddle/fluid/tests/unittests/test_variable.py @@ -690,6 +690,61 @@ def test_dygraph_list_index_muti_dim(self): y = x[index_t1, index_t2] self.assertTrue(np.array_equal(y.numpy(), y_np)) + def run_getitem_list_index(self, array, index): + x = paddle.static.data(name='x', shape=array.shape, dtype='float32') + + y = x[index] + place = paddle.fluid.CPUPlace() + + prog = paddle.static.default_main_program() + exe = paddle.static.Executor(place) + + exe.run(paddle.static.default_startup_program()) + fetch_list = [y.name] + array2 = array.copy() + + try: + value_np = array2[index] + except: + with self.assertRaises(ValueError): + getitem_pp = exe.run(prog, + feed={x.name: array}, + fetch_list=fetch_list) + return + getitem_pp = exe.run(prog, feed={x.name: array}, fetch_list=fetch_list) + + print(getitem_pp) + self.assertTrue( + np.array_equal(value_np, getitem_pp[0]), + msg='\n numpy:{},\n paddle:{}'.format(value_np, getitem_pp[0])) + + def test_static_graph_getitem_bool_index(self): + paddle.enable_static() + + # case 1: + array = np.ones((4, 2, 3), dtype='float32') + value_np = np.random.random((2, 3)).astype('float32') + index = np.array([True, False, False, False]) + program = paddle.static.Program() + with paddle.static.program_guard(program): + self.run_getitem_list_index(array, index) + + # case 2: + array = np.ones((4, 2, 3), dtype='float32') + value_np = np.random.random((2, 3)).astype('float32') + index = np.array([False, True, False, False]) + program = paddle.static.Program() + with paddle.static.program_guard(program): + self.run_getitem_list_index(array, index) + + # case 3: + array = np.ones((4, 2, 3), dtype='float32') + value_np = np.random.random((2, 3)).astype('float32') + index = np.array([True, True, True, True]) + program = paddle.static.Program() + with paddle.static.program_guard(program): + self.run_getitem_list_index(array, index) + def run_setitem_list_index(self, array, index, value_np): x = paddle.static.data(name='x', shape=array.shape, dtype='float32') diff --git a/python/paddle/fluid/variable_index.py b/python/paddle/fluid/variable_index.py index e6990e25a08af..257ddc96d9c87 100644 --- a/python/paddle/fluid/variable_index.py +++ b/python/paddle/fluid/variable_index.py @@ -279,6 +279,37 @@ def deal_attrs(attrs, attr, attr_name, tensor_attr_name, inputs, infer_flags): attrs[attr_name] = attr +# the item is a tensor of bool +def get_value_for_bool_tensor(var, item): + if len(item.shape) > len(var.shape): + raise IndexError("The dims of bool index don't match indexed array, " + "the dims of bool index are expected to be equal or less " + "than {}, but received {}.".format( + len(var.shape), len(item.shape))) + for i, dim_len in enumerate(item.shape): + if dim_len != var.shape[i]: + raise IndexError( + "The dimension of bool index doesn't match indexed array along "\ + "dimension {}, the target dimension is {}, but received {}.". + format(i, var.shape[i], dim_len)) + + def idx_not_empty(var, item): + from .layers.nn import where + from ..tensor import gather_nd + + bool_2_idx = where(item == True) + return gather_nd(var, bool_2_idx) + + def idx_empty(var): + var_shape = list(var.shape) + var_shape[0] = 0 + return paddle.empty(var_shape, dtype=var.dtype) + + from .layers.control_flow import cond + return cond(item.any(), lambda: idx_not_empty(var, item), + lambda: idx_empty(var)) + + def _getitem_impl_(var, item): """ Slice the variable.
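The get_value_for_bool_tensor helper above wraps the gather in cond, so an all-False mask takes the empty branch instead of feeding a zero-row index into gather_nd. A minimal sketch of the resulting behavior (dygraph shown for brevity; the static-graph path is what the new unit tests exercise):

    import paddle

    x = paddle.ones([4, 2, 3])
    mask = paddle.to_tensor([False, False, False, False])
    # With this patch, an all-False bool index yields an empty tensor whose
    # leading dimension is 0, matching NumPy semantics.
    y = x[mask]
    print(y.shape)  # [0, 2, 3]
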
@@ -393,24 +424,10 @@ def _getitem_impl_(var, item): elif isinstance(slice_item, (Variable, core.eager.Tensor)): if len(item) == 1: - from ..tensor import index_select, gather_nd - from .layers.nn import where + from ..tensor import index_select if slice_item.dtype == paddle.bool: - if len(slice_item.shape) > len(var.shape): - raise IndexError( - "The dims of bool index doesn't match indexed array, " - "the dims of bool index except to be equal or less " - "than {}, but received {}.".format( - len(var.shape), len(slice_item.shape))) - for i, dim_len in enumerate(slice_item.shape): - if dim_len != var.shape[i]: - raise IndexError( - "The dimension of bool index doesn't match indexed array along "\ - "dimension {}, the target dimension is {}, but received {}.". - format(i, var.shape[i], dim_len)) - bool_2_idx = where(slice_item == True) - return gather_nd(var, bool_2_idx) + return get_value_for_bool_tensor(var, slice_item) else: if len(slice_item.shape) == 1: return index_select(var, index=slice_item, axis=0) From afb56e8ca6d552b51b6c9da556209094f139a4d4 Mon Sep 17 00:00:00 2001 From: Sing_chan <51314274+betterpig@users.noreply.github.com> Date: Mon, 4 Apr 2022 22:08:26 +0800 Subject: [PATCH 81/93] cut off relation between xk and initial_position's graph (#41371) * cut off relation between xk and initial_position's graph * fix bug * add detach to cut xk off from the original graph --- python/paddle/incubate/optimizer/functional/bfgs.py | 3 ++- python/paddle/incubate/optimizer/functional/lbfgs.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/python/paddle/incubate/optimizer/functional/bfgs.py b/python/paddle/incubate/optimizer/functional/bfgs.py index 9147444f5a6bb..abdab457fda00 100644 --- a/python/paddle/incubate/optimizer/functional/bfgs.py +++ b/python/paddle/incubate/optimizer/functional/bfgs.py @@ -126,7 +126,8 @@ def func(x): check_initial_inverse_hessian_estimate(initial_inverse_hessian_estimate) Hk = paddle.assign(initial_inverse_hessian_estimate) - xk = initial_position + # use detach and assign to create a new tensor rather than `=`, otherwise xk will share memory and grad with initial_position + xk = paddle.assign(initial_position.detach()) value, g1 = _value_and_gradient(objective_func, xk) num_func_calls = paddle.full(shape=[1], fill_value=1, dtype='int64') diff --git a/python/paddle/incubate/optimizer/functional/lbfgs.py b/python/paddle/incubate/optimizer/functional/lbfgs.py index 1fbae18a4c65a..d4bf511f85a99 100644 --- a/python/paddle/incubate/optimizer/functional/lbfgs.py +++ b/python/paddle/incubate/optimizer/functional/lbfgs.py @@ -113,7 +113,8 @@ def func(x): check_initial_inverse_hessian_estimate(initial_inverse_hessian_estimate) H0 = initial_inverse_hessian_estimate - xk = initial_position + # use detach and assign to create a new tensor rather than `=`, otherwise xk will share memory and grad with initial_position + xk = paddle.assign(initial_position.detach()) value, g1 = _value_and_gradient(objective_func, xk) k = paddle.full(shape=[1], fill_value=0, dtype='int64') From 5d6d14bc7e6021e2e36b8c6a9b359fc9754fb550 Mon Sep 17 00:00:00 2001 From: wanghuancoder <wanghuan29@baidu.com> Date: Mon, 4 Apr 2022 22:28:31 +0800 Subject: [PATCH 82/93] [Eager] fix test_var_base (#41397) * eager test var base * refine, test=develop --- paddle/fluid/pybind/eager.cc | 4 + paddle/fluid/pybind/eager_method.cc | 50 +++ paddle/fluid/pybind/eager_properties.cc | 15 +- paddle/fluid/pybind/eager_utils.cc | 6 + paddle/fluid/pybind/eager_utils.h | 1 + paddle/phi/api/lib/tensor.cc | 6 +-
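The detach-plus-assign idiom in the two hunks above matters because a plain xk = initial_position only binds a second name to the caller's tensor: solver updates would then be recorded on the caller's autograd graph and share its storage. A small sketch of the distinction, assuming standard Paddle dygraph semantics (variable names are illustrative only):

    import paddle

    x0 = paddle.to_tensor([1.0, 2.0], stop_gradient=False)

    # Plain assignment aliases: any in-place update through xk_alias would
    # mutate x0 and extend x0's autograd graph.
    xk_alias = x0

    # The patched form: detach() cuts the lineage back to x0, and assign()
    # copies the buffer, so the solver iterates on private storage.
    xk = paddle.assign(x0.detach())
    xk[0] = 100.0

    print(x0.numpy())  # [1. 2.] -- unchanged by the update to xk
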
.../fluid/dygraph/varbase_patch_methods.py | 6 +- .../fluid/tests/unittests/test_var_base.py | 339 ++++++++++++++---- 8 files changed, 360 insertions(+), 67 deletions(-) diff --git a/paddle/fluid/pybind/eager.cc b/paddle/fluid/pybind/eager.cc index 657c79e7bd3aa..e39a9199b1cb9 100644 --- a/paddle/fluid/pybind/eager.cc +++ b/paddle/fluid/pybind/eager.cc @@ -78,6 +78,10 @@ void EmptyTensorInitializer(TensorObject* self, const std::string& name, phi::DenseTensorMeta(paddle::framework::TransToPhiDataType(dtype), ddims)); self->tensor.set_impl(dense_tensor); + } else if (var_type == paddle::framework::proto::VarType::SELECTED_ROWS) { + std::shared_ptr tensor = + std::make_shared(); + self->tensor.set_impl(tensor); } if (!autograd_meta->GetMutableGradNode()) { diff --git a/paddle/fluid/pybind/eager_method.cc b/paddle/fluid/pybind/eager_method.cc index dfe2fab9fc468..74b866355f070 100644 --- a/paddle/fluid/pybind/eager_method.cc +++ b/paddle/fluid/pybind/eager_method.cc @@ -465,6 +465,9 @@ static PyObject* tensor__share_buffer_to(TensorObject* self, PyObject* args, self->tensor.name())); auto* src_tensor = static_cast(self->tensor.impl().get()); + if (!dst_ptr->defined()) { + dst_ptr->set_impl(std::make_shared()); + } auto dst_tensor = static_cast(dst_ptr->impl().get()); dst_tensor->ShareDataWith(*src_tensor); @@ -565,6 +568,10 @@ static PyObject* tensor_method_get_underline_tensor(TensorObject* self, PyObject* args, PyObject* kwargs) { EAGER_TRY + if (!self->tensor.defined()) { + Py_IncRef(Py_None); + return Py_None; + } if (self->tensor.is_dense_tensor()) { auto* tensor = static_cast(self->tensor.impl().get()); @@ -577,6 +584,25 @@ static PyObject* tensor_method_get_underline_tensor(TensorObject* self, EAGER_CATCH_AND_THROW_RETURN_NULL } +static PyObject* tensor_method_get_underline_selected_rows(TensorObject* self, + PyObject* args, + PyObject* kwargs) { + EAGER_TRY + if (!self->tensor.defined()) { + Py_IncRef(Py_None); + return Py_None; + } + if (self->tensor.is_selected_rows()) { + auto* selected_rows = + static_cast(self->tensor.impl().get()); + return ToPyObject(selected_rows); + } else { + Py_IncRef(Py_None); + return Py_None; + } + EAGER_CATCH_AND_THROW_RETURN_NULL +} + static PyObject* tensor__getitem_index_not_tensor(TensorObject* self, PyObject* args, PyObject* kwargs) { @@ -1214,6 +1240,9 @@ static PyObject* tensor_method_get_non_zero_cols(TensorObject* self, static PyObject* tensor_method_is_sparse(TensorObject* self, PyObject* args, PyObject* kwargs) { EAGER_TRY + if (!self->tensor.defined()) { + return ToPyObject(false); + } return ToPyObject(self->tensor.is_sparse_coo_tensor() || self->tensor.is_sparse_csr_tensor()); EAGER_CATCH_AND_THROW_RETURN_NULL @@ -1222,6 +1251,9 @@ static PyObject* tensor_method_is_sparse(TensorObject* self, PyObject* args, static PyObject* tensor_method_is_sparse_coo(TensorObject* self, PyObject* args, PyObject* kwargs) { EAGER_TRY + if (!self->tensor.defined()) { + return ToPyObject(false); + } return ToPyObject(self->tensor.is_sparse_coo_tensor()); EAGER_CATCH_AND_THROW_RETURN_NULL } @@ -1229,6 +1261,9 @@ static PyObject* tensor_method_is_sparse_coo(TensorObject* self, PyObject* args, static PyObject* tensor_method_is_sparse_csr(TensorObject* self, PyObject* args, PyObject* kwargs) { EAGER_TRY + if (!self->tensor.defined()) { + return ToPyObject(false); + } return ToPyObject(self->tensor.is_sparse_csr_tensor()); EAGER_CATCH_AND_THROW_RETURN_NULL } @@ -1307,6 +1342,9 @@ static PyObject* tensor_method_is_selected_rows(TensorObject* self, PyObject* args, 
PyObject* kwargs) {
   EAGER_TRY
+  if (!self->tensor.defined()) {
+    return ToPyObject(false);
+  }
   return ToPyObject(self->tensor.is_selected_rows());
   EAGER_CATCH_AND_THROW_RETURN_NULL
 }
@@ -1323,6 +1361,13 @@ static PyObject* tensor_method_get_rows(TensorObject* self, PyObject* args,
   EAGER_CATCH_AND_THROW_RETURN_NULL
 }

+static PyObject* tensor_method_element_size(TensorObject* self, PyObject* args,
+                                            PyObject* kwargs) {
+  EAGER_TRY
+  return ToPyObject(paddle::experimental::SizeOf(self->tensor.dtype()));
+  EAGER_CATCH_AND_THROW_RETURN_NULL
+}
+
 static PyObject* tensor__reset_grad_inplace_version(TensorObject* self,
                                                     PyObject* args,
                                                     PyObject* kwargs) {
@@ -1420,6 +1465,9 @@ PyMethodDef variable_methods[] = {
     {"get_tensor",
      (PyCFunction)(void (*)(void))tensor_method_get_underline_tensor,
      METH_VARARGS | METH_KEYWORDS, NULL},
+    {"get_selected_rows",
+     (PyCFunction)(void (*)(void))tensor_method_get_underline_selected_rows,
+     METH_VARARGS | METH_KEYWORDS, NULL},
     {"_getitem_index_not_tensor",
      (PyCFunction)(void (*)(void))tensor__getitem_index_not_tensor,
      METH_VARARGS | METH_KEYWORDS, NULL},
@@ -1482,6 +1530,8 @@ PyMethodDef variable_methods[] = {
      METH_VARARGS | METH_KEYWORDS, NULL},
     {"rows", (PyCFunction)(void (*)(void))tensor_method_get_rows,
      METH_VARARGS | METH_KEYWORDS, NULL},
+    {"element_size", (PyCFunction)(void (*)(void))tensor_method_element_size,
+     METH_VARARGS | METH_KEYWORDS, NULL},
     {"_reset_grad_inplace_version",
      (PyCFunction)(void (*)(void))tensor__reset_grad_inplace_version,
      METH_VARARGS | METH_KEYWORDS, NULL},
diff --git a/paddle/fluid/pybind/eager_properties.cc b/paddle/fluid/pybind/eager_properties.cc
index d8c297b1a94c7..4c11fcc7c98c1 100644
--- a/paddle/fluid/pybind/eager_properties.cc
+++ b/paddle/fluid/pybind/eager_properties.cc
@@ -43,8 +43,14 @@ PyObject* tensor_properties_get_name(TensorObject* self, void* closure) {

 PyObject* tensor_properties_get_type(TensorObject* self, void* closure) {
   EAGER_TRY
+  if (!self->tensor.defined()) {
+    // keep the same behavior as old dygraph
+    return ToPyObject(paddle::framework::proto::VarType::LOD_TENSOR);
+  }
   if (self->tensor.is_dense_tensor()) {
     return ToPyObject(paddle::framework::proto::VarType::LOD_TENSOR);
+  } else if (self->tensor.is_selected_rows()) {
+    return ToPyObject(paddle::framework::proto::VarType::SELECTED_ROWS);
   } else {
     Py_INCREF(Py_None);
     return Py_None;
@@ -137,8 +143,11 @@ int tensor_properties_set_persistable(TensorObject* self, PyObject* value,

 PyObject* tensor_properties_get_shape(TensorObject* self, void* closure) {
   EAGER_TRY
-  auto ddim = self->tensor.shape();
   std::vector value;
+  if (!self->tensor.defined()) {
+    return ToPyObject(value);
+  }
+  auto ddim = self->tensor.shape();
   size_t rank = static_cast(ddim.size());
   value.resize(rank);
   for (size_t i = 0; i < rank; i++) {
@@ -165,6 +174,10 @@ PyObject* tensor_properties_get_place_str(TensorObject* self, void* closure) {

 PyObject* tensor_properties_get_dtype(TensorObject* self, void* closure) {
   EAGER_TRY
+  if (!self->tensor.defined()) {
+    // keep the same behavior as old dygraph
+    return ToPyObject(framework::proto::VarType::FP32);
+  }
   return ToPyObject(
       paddle::framework::TransToProtoVarType(self->tensor.type()));
   EAGER_CATCH_AND_THROW_RETURN_NULL
diff --git a/paddle/fluid/pybind/eager_utils.cc b/paddle/fluid/pybind/eager_utils.cc
index ef1359ac04772..427f21dc1a4b9 100644
--- a/paddle/fluid/pybind/eager_utils.cc
+++ b/paddle/fluid/pybind/eager_utils.cc
@@ -577,6 +577,12 @@ PyObject* ToPyObject(const paddle::framework::LoDTensor* value) {
   return obj.ptr();
 }

+PyObject* ToPyObject(const
phi::SelectedRows* value) { + auto obj = ::pybind11::cast(value, py::return_value_policy::reference); + obj.inc_ref(); + return obj.ptr(); +} + PyObject* ToPyObject(const void* value) { if (value == nullptr) { Py_INCREF(Py_None); diff --git a/paddle/fluid/pybind/eager_utils.h b/paddle/fluid/pybind/eager_utils.h index 2fe73c24ee3a0..49075fb44486c 100644 --- a/paddle/fluid/pybind/eager_utils.h +++ b/paddle/fluid/pybind/eager_utils.h @@ -75,6 +75,7 @@ PyObject* ToPyObject(const std::vector& value, bool return_py_none_if_not_initialize = false); PyObject* ToPyObject(const platform::Place& value); PyObject* ToPyObject(const framework::LoDTensor* value); +PyObject* ToPyObject(const phi::SelectedRows* value); PyObject* ToPyObject(const paddle::framework::proto::VarType::Type& dtype); PyObject* ToPyObject(const paddle::framework::proto::VarType& type); PyObject* ToPyObject(const void* value); diff --git a/paddle/phi/api/lib/tensor.cc b/paddle/phi/api/lib/tensor.cc index 3790384c8af16..ffc754feaed98 100644 --- a/paddle/phi/api/lib/tensor.cc +++ b/paddle/phi/api/lib/tensor.cc @@ -101,7 +101,11 @@ int64_t Tensor::size() const { return impl_->numel(); } phi::DDim Tensor::dims() const { return impl_->dims(); } std::vector Tensor::shape() const { - return phi::vectorize(impl_->dims()); + auto dims = impl_->dims(); + if (dims.size() == 1 && dims.at(0) == 0) { + return {}; + } + return phi::vectorize(dims); } void Tensor::reshape(const std::vector &shape) { diff --git a/python/paddle/fluid/dygraph/varbase_patch_methods.py b/python/paddle/fluid/dygraph/varbase_patch_methods.py index bd1ca1aa26dda..a62a260969c68 100644 --- a/python/paddle/fluid/dygraph/varbase_patch_methods.py +++ b/python/paddle/fluid/dygraph/varbase_patch_methods.py @@ -846,7 +846,11 @@ def cpu(self): return res @framework.dygraph_only - def cuda(self, device_id, blocking): + def cuda(self, device_id=0, blocking=True): + if device_id is None: + device_id = 0 + if not isinstance(device_id, int): + raise ValueError("\'device_id\' must be a positive integer") if self.place.is_gpu_place(): return self else: diff --git a/python/paddle/fluid/tests/unittests/test_var_base.py b/python/paddle/fluid/tests/unittests/test_var_base.py index ef57ba1530299..724a71ebe3dda 100644 --- a/python/paddle/fluid/tests/unittests/test_var_base.py +++ b/python/paddle/fluid/tests/unittests/test_var_base.py @@ -31,7 +31,7 @@ def setUp(self): self.dtype = np.float32 self.array = np.random.uniform(0.1, 1, self.shape).astype(self.dtype) - def test_to_tensor(self): + def func_test_to_tensor(self): def _test_place(place): with fluid.dygraph.guard(): paddle.set_default_dtype('float32') @@ -262,7 +262,12 @@ def _test_place(place): _test_place(core.NPUPlace(0)) _test_place("npu:0") - def test_to_tensor_not_change_input_stop_gradient(self): + def test_to_tensor(self): + with _test_eager_guard(): + self.func_test_to_tensor() + self.func_test_to_tensor() + + def func_test_to_tensor_not_change_input_stop_gradient(self): with paddle.fluid.dygraph.guard(core.CPUPlace()): a = paddle.zeros([1024]) a.stop_gradient = False @@ -270,7 +275,12 @@ def test_to_tensor_not_change_input_stop_gradient(self): self.assertEqual(a.stop_gradient, False) self.assertEqual(b.stop_gradient, True) - def test_to_tensor_change_place(self): + def test_to_tensor_not_change_input_stop_gradient(self): + with _test_eager_guard(): + self.func_test_to_tensor_not_change_input_stop_gradient() + self.func_test_to_tensor_not_change_input_stop_gradient() + + def func_test_to_tensor_change_place(self): if 
core.is_compiled_with_cuda(): a_np = np.random.rand(1024, 1024) with paddle.fluid.dygraph.guard(core.CPUPlace()): @@ -288,7 +298,12 @@ def test_to_tensor_change_place(self): a = paddle.to_tensor(a, place=paddle.CUDAPinnedPlace()) self.assertEqual(a.place.__repr__(), "Place(gpu_pinned)") - def test_to_tensor_with_lodtensor(self): + def test_to_tensor_change_place(self): + with _test_eager_guard(): + self.func_test_to_tensor_change_place() + self.func_test_to_tensor_change_place() + + def func_test_to_tensor_with_lodtensor(self): if core.is_compiled_with_cuda(): a_np = np.random.rand(1024, 1024) with paddle.fluid.dygraph.guard(core.CPUPlace()): @@ -304,7 +319,12 @@ def test_to_tensor_with_lodtensor(self): self.assertTrue(np.array_equal(a_np, a.numpy())) self.assertTrue(a.place.__repr__(), "Place(cpu)") - def test_to_variable(self): + def test_to_tensor_with_lodtensor(self): + with _test_eager_guard(): + self.func_test_to_tensor_with_lodtensor() + self.func_test_to_tensor_with_lodtensor() + + def func_test_to_variable(self): with fluid.dygraph.guard(): var = fluid.dygraph.to_variable(self.array, name="abc") self.assertTrue(np.array_equal(var.numpy(), self.array)) @@ -323,7 +343,12 @@ def test_to_variable(self): linear = fluid.dygraph.Linear(32, 64) var = linear._helper.to_variable("test", name="abc") - def test_list_to_variable(self): + def test_to_variable(self): + with _test_eager_guard(): + self.func_test_to_variable() + self.func_test_to_variable() + + def func_test_list_to_variable(self): with fluid.dygraph.guard(): array = [[[1, 2], [1, 2], [1.0, 2]], [[1, 2], [1, 2], [1, 2]]] var = fluid.dygraph.to_variable(array, dtype='int32') @@ -332,7 +357,12 @@ def test_list_to_variable(self): self.assertEqual(var.dtype, core.VarDesc.VarType.INT32) self.assertEqual(var.type, core.VarDesc.VarType.LOD_TENSOR) - def test_tuple_to_variable(self): + def test_list_to_variable(self): + with _test_eager_guard(): + self.func_test_list_to_variable() + self.func_test_list_to_variable() + + def func_test_tuple_to_variable(self): with fluid.dygraph.guard(): array = (((1, 2), (1, 2), (1, 2)), ((1, 2), (1, 2), (1, 2))) var = fluid.dygraph.to_variable(array, dtype='float32') @@ -341,14 +371,24 @@ def test_tuple_to_variable(self): self.assertEqual(var.dtype, core.VarDesc.VarType.FP32) self.assertEqual(var.type, core.VarDesc.VarType.LOD_TENSOR) - def test_tensor_to_variable(self): + def test_tuple_to_variable(self): + with _test_eager_guard(): + self.func_test_tuple_to_variable() + self.func_test_tuple_to_variable() + + def func_test_tensor_to_variable(self): with fluid.dygraph.guard(): t = fluid.Tensor() t.set(np.random.random((1024, 1024)), fluid.CPUPlace()) var = fluid.dygraph.to_variable(t) self.assertTrue(np.array_equal(t, var.numpy())) - def test_leaf_tensor(self): + def test_tensor_to_variable(self): + with _test_eager_guard(): + self.func_test_tensor_to_variable() + self.func_test_tensor_to_variable() + + def func_test_leaf_tensor(self): with fluid.dygraph.guard(): x = paddle.to_tensor(np.random.uniform(-1, 1, size=[10, 10])) self.assertTrue(x.is_leaf) @@ -374,7 +414,12 @@ def test_leaf_tensor(self): self.assertTrue(linear.bias.is_leaf) self.assertFalse(out.is_leaf) - def test_detach(self): + def test_leaf_tensor(self): + with _test_eager_guard(): + self.func_test_leaf_tensor() + self.func_test_leaf_tensor() + + def func_test_detach(self): with fluid.dygraph.guard(): x = paddle.to_tensor(1.0, dtype="float64", stop_gradient=False) detach_x = x.detach() @@ -407,7 +452,12 @@ def test_detach(self): detach_x[:] 
= 5.0 y.backward() - def test_write_property(self): + def test_detach(self): + with _test_eager_guard(): + self.func_test_detach() + self.func_test_detach() + + def func_test_write_property(self): with fluid.dygraph.guard(): var = fluid.dygraph.to_variable(self.array) @@ -423,9 +473,17 @@ def test_write_property(self): var.stop_gradient = False self.assertEqual(var.stop_gradient, False) - def test_deep_copy(self): + def test_write_property(self): + with _test_eager_guard(): + self.func_test_write_property() + self.func_test_write_property() + + def func_test_deep_copy(self): with fluid.dygraph.guard(): - empty_var = core.VarBase() + if _in_legacy_dygraph(): + empty_var = core.VarBase() + else: + empty_var = core.eager.Tensor() empty_var_copy = copy.deepcopy(empty_var) self.assertEqual(empty_var.stop_gradient, empty_var_copy.stop_gradient) @@ -462,9 +520,15 @@ def test_deep_copy(self): self.assertEqual(id(y_copy), id(y_copy2)) # test copy selected rows - x = core.VarBase(core.VarDesc.VarType.FP32, [3, 100], - "selected_rows", - core.VarDesc.VarType.SELECTED_ROWS, True) + if _in_legacy_dygraph(): + x = core.VarBase(core.VarDesc.VarType.FP32, [3, 100], + "selected_rows", + core.VarDesc.VarType.SELECTED_ROWS, True) + else: + x = core.eager.Tensor(core.VarDesc.VarType.FP32, [3, 100], + "selected_rows", + core.VarDesc.VarType.SELECTED_ROWS, True) + selected_rows = x.value().get_selected_rows() selected_rows.get_tensor().set( np.random.rand(3, 100), core.CPUPlace()) @@ -486,8 +550,13 @@ def test_deep_copy(self): np.array(copy_selected_rows.get_tensor()), np.array(selected_rows.get_tensor()))) + def test_deep_copy(self): + with _test_eager_guard(): + self.func_test_deep_copy() + self.func_test_deep_copy() + # test some patched methods - def test_set_value(self): + def func_test_set_value(self): with fluid.dygraph.guard(): var = fluid.dygraph.to_variable(self.array) tmp1 = np.random.uniform(0.1, 1, [2, 2, 3]).astype(self.dtype) @@ -497,12 +566,22 @@ def test_set_value(self): var.set_value(tmp2) self.assertTrue(np.array_equal(var.numpy(), tmp2)) - def test_to_string(self): + def test_set_value(self): + with _test_eager_guard(): + self.func_test_set_value() + self.func_test_set_value() + + def func_test_to_string(self): with fluid.dygraph.guard(): var = fluid.dygraph.to_variable(self.array) self.assertTrue(isinstance(str(var), str)) - def test_element_size(self): + def test_to_string(self): + with _test_eager_guard(): + self.func_test_to_string() + self.func_test_to_string() + + def func_test_element_size(self): with fluid.dygraph.guard(): x = paddle.to_tensor(1, dtype='bool') self.assertEqual(x.element_size(), 1) @@ -537,7 +616,12 @@ def test_element_size(self): x = paddle.to_tensor(1, dtype='complex128') self.assertEqual(x.element_size(), 16) - def test_backward(self): + def test_element_size(self): + with _test_eager_guard(): + self.func_test_element_size() + self.func_test_element_size() + + def func_test_backward(self): with fluid.dygraph.guard(): var = fluid.dygraph.to_variable(self.array) var.stop_gradient = False @@ -546,7 +630,12 @@ def test_backward(self): grad_var = var._grad_ivar() self.assertEqual(grad_var.shape, self.shape) - def test_gradient(self): + def test_backward(self): + with _test_eager_guard(): + self.func_test_backward() + self.func_test_backward() + + def func_test_gradient(self): with fluid.dygraph.guard(): var = fluid.dygraph.to_variable(self.array) var.stop_gradient = False @@ -555,12 +644,22 @@ def test_gradient(self): grad_var = var.gradient() 
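Background: every change in this long test_var_base.py diff follows one mechanical pattern, so that each case exercises the eager (dygraph) code path as well as the legacy VarBase path. The original `test_*` body is renamed to `func_test_*` (hiding it from unittest discovery), and a thin wrapper with the original name runs it twice. Schematically (hypothetical test case; `_test_eager_guard` is the helper this test file imports):

    import unittest
    from paddle.fluid.framework import _test_eager_guard

    class ExampleCase(unittest.TestCase):
        def func_test_something(self):
            ...  # the original assertions, unchanged

        def test_something(self):
            with _test_eager_guard():       # first pass: eager mode
                self.func_test_something()
            self.func_test_something()      # second pass: legacy mode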
self.assertEqual(grad_var.shape, self.array.shape) - def test_block(self): + def test_gradient(self): + with _test_eager_guard(): + self.func_test_gradient() + self.func_test_gradient() + + def func_test_block(self): with fluid.dygraph.guard(): var = fluid.dygraph.to_variable(self.array) self.assertEqual(var.block, fluid.default_main_program().global_block()) + def test_block(self): + with _test_eager_guard(): + self.func_test_block() + self.func_test_block() + def _test_slice(self): w = fluid.dygraph.to_variable( np.random.random((784, 100, 100)).astype('float64')) @@ -916,14 +1015,19 @@ def test_slice(self): self.func_test_slice() self.func_test_slice() - def test_var_base_to_np(self): + def func_test_var_base_to_np(self): with fluid.dygraph.guard(): var = fluid.dygraph.to_variable(self.array) self.assertTrue( np.array_equal(var.numpy(), fluid.framework._var_base_to_np(var))) - def test_var_base_as_np(self): + def test_var_base_to_np(self): + with _test_eager_guard(): + self.func_test_var_base_to_np() + self.func_test_var_base_to_np() + + def func_test_var_base_as_np(self): with fluid.dygraph.guard(): var = fluid.dygraph.to_variable(self.array) self.assertTrue(np.array_equal(var.numpy(), np.array(var))) @@ -932,7 +1036,12 @@ def test_var_base_as_np(self): var.numpy(), np.array( var, dtype=np.float32))) - def test_if(self): + def test_var_base_as_np(self): + with _test_eager_guard(): + self.func_test_var_base_as_np() + self.func_test_var_base_as_np() + + def func_test_if(self): with fluid.dygraph.guard(): var1 = fluid.dygraph.to_variable(np.array([[[0]]])) var2 = fluid.dygraph.to_variable(np.array([[[1]]])) @@ -951,7 +1060,12 @@ def test_if(self): assert bool(var1) == False, "bool(var1) is False" assert bool(var2) == True, "bool(var2) is True" - def test_to_static_var(self): + def test_if(self): + with _test_eager_guard(): + self.func_test_if() + self.func_test_if() + + def func_test_to_static_var(self): with fluid.dygraph.guard(): # Convert VarBase into Variable or Parameter var_base = fluid.dygraph.to_variable(self.array, name="var_base_1") @@ -974,6 +1088,11 @@ def test_to_static_var(self): static_param = weight._to_static_var() self._assert_to_static(weight, static_param, True) + def test_to_static_var(self): + with _test_eager_guard(): + self.func_test_to_static_var() + self.func_test_to_static_var() + def _assert_to_static(self, var_base, static_var, is_param=False): if is_param: self.assertTrue(isinstance(static_var, fluid.framework.Parameter)) @@ -1015,7 +1134,6 @@ def func_test_tensor_str(self): [0.2665, 0.8483, 0.5389, ..., 0.4956, 0.6862, 0.9178]])''' self.assertEqual(a_str, expected) - paddle.enable_static() def test_tensor_str(self): with _test_eager_guard(): @@ -1032,7 +1150,6 @@ def func_test_tensor_str2(self): [0. , 0. ]])''' self.assertEqual(a_str, expected) - paddle.enable_static() def test_tensor_str2(self): with _test_eager_guard(): @@ -1049,7 +1166,6 @@ def func_test_tensor_str3(self): [ 0. 
, -0.5000]])'''

         self.assertEqual(a_str, expected)
-        paddle.enable_static()

     def test_tensor_str3(self):
         with _test_eager_guard():
             self.func_test_tensor_str3()
         self.func_test_tensor_str3()
@@ -1065,7 +1181,6 @@ def func_test_tensor_str_scaler(self):
        False)'''

         self.assertEqual(a_str, expected)
-        paddle.enable_static()

     def test_tensor_str_scaler(self):
         with _test_eager_guard():
@@ -1082,7 +1197,6 @@ def func_test_tensor_str_shape_with_zero(self):
        [])'''

         self.assertEqual(a_str, expected)
-        paddle.enable_static()

     def test_tensor_str_shape_with_zero(self):
         with _test_eager_guard():
@@ -1115,7 +1229,6 @@ def func_test_tensor_str_linewidth(self):
         0.4678, 0.5047])'''

         self.assertEqual(a_str, expected)
-        paddle.enable_static()

     def test_tensor_str_linewidth(self):
         with _test_eager_guard():
@@ -1143,7 +1256,6 @@ def func_test_tensor_str_linewidth2(self):
         8.9448e-01, 7.0981e-01, 8.0783e-01, 4.7065e-01, 5.7154e-01, 7.2319e-01,
         4.6777e-01, 5.0465e-01])'''

         self.assertEqual(a_str, expected)
-        paddle.enable_static()

     def test_tensor_str_linewidth2(self):
         with _test_eager_guard():
@@ -1162,14 +1274,13 @@ def func_tensor_str_bf16(self):
        [0. , 0. ]])'''

         self.assertEqual(a_str, expected)
-        paddle.enable_static()

     def test_tensor_str_bf16(self):
         with _test_eager_guard():
             self.func_tensor_str_bf16()
         self.func_tensor_str_bf16()

-    def test_print_tensor_dtype(self):
+    def func_test_print_tensor_dtype(self):
         paddle.disable_static(paddle.CPUPlace())
         a = paddle.rand([1])
         a_str = str(a.dtype)
@@ -1177,11 +1293,15 @@ def test_print_tensor_dtype(self):
         expected = 'paddle.float32'

         self.assertEqual(a_str, expected)
-        paddle.enable_static()
+
+    def test_print_tensor_dtype(self):
+        with _test_eager_guard():
+            self.func_test_print_tensor_dtype()
+        self.func_test_print_tensor_dtype()


 class TestVarBaseSetitem(unittest.TestCase):
-    def setUp(self):
+    def func_setUp(self):
         self.set_dtype()
         self.tensor_x = paddle.to_tensor(np.ones((4, 2, 3)).astype(self.dtype))
         self.np_value = np.random.random((2, 3)).astype(self.dtype)
@@ -1225,9 +1345,9 @@ def func_test_value_tensor(self):

     def test_value_tensor(self):
         with _test_eager_guard():
-            self.setUp()
+            self.func_setUp()
             self.func_test_value_tensor()
-        self.setUp()
+        self.func_setUp()
         self.func_test_value_tensor()

     def func_test_value_numpy(self):
@@ -1235,9 +1355,9 @@ def func_test_value_numpy(self):

     def test_value_numpy(self):
         with _test_eager_guard():
-            self.setUp()
+            self.func_setUp()
             self.func_test_value_numpy()
-        self.setUp()
+        self.func_setUp()
         self.func_test_value_numpy()

     def func_test_value_int(self):
@@ -1245,9 +1365,9 @@ def func_test_value_int(self):

     def test_value_int(self):
         with _test_eager_guard():
-            self.setUp()
+            self.func_setUp()
             self.func_test_value_int()
-        self.setUp()
+        self.func_setUp()
         self.func_test_value_int()


@@ -1260,10 +1380,17 @@ class TestVarBaseSetitemFp32(TestVarBaseSetitem):
     def set_dtype(self):
         self.dtype = "float32"

-    def test_value_float(self):
+    def func_test_value_float(self):
         paddle.disable_static()
         self._test(3.3)

+    def test_value_float(self):
+        with _test_eager_guard():
+            self.func_setUp()
+            self.func_test_value_float()
+        self.func_setUp()
+        self.func_test_value_float()
+

 class TestVarBaseSetitemFp64(TestVarBaseSetitem):
     def set_dtype(self):
@@ -1271,7 +1398,7 @@ def set_dtype(self):


 class TestVarBaseSetitemBoolIndex(unittest.TestCase):
-    def setUp(self):
+    def func_setUp(self):
         paddle.disable_static()
         self.set_dtype()
         self.set_input()
@@ -1314,18 +1441,39 @@ def _test(self, value):
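Background for the `inplace_version` assertions a bit further down: every in-place mutation of a tensor's buffer bumps a per-tensor counter, and autograd compares that counter with the value recorded when the tensor was saved for backward, so stale in-place writes are detected instead of silently corrupting gradients. Roughly, as exercised by the tests (a sketch, not a formal spec):

    import paddle

    var = paddle.ones(shape=[4, 2, 3], dtype="float32")
    assert var.inplace_version == 0

    var[1] = 2.0                  # setitem mutates the buffer in place
    assert var.inplace_version == 1

    var._bump_inplace_version()   # manual bump used by some in-place kernels
    assert var.inplace_version == 2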
self.assertTrue(np.array_equal(self.tensor_x[3].numpy(), result)) self.assertEqual(id_origin, id(self.tensor_x)) - def test_value_tensor(self): + def func_test_value_tensor(self): paddle.disable_static() self._test(self.tensor_value) - def test_value_numpy(self): + def test_value_tensor(self): + with _test_eager_guard(): + self.func_setUp() + self.func_test_value_tensor() + self.func_setUp() + self.func_test_value_tensor() + + def func_test_value_numpy(self): paddle.disable_static() self._test(self.np_value) - def test_value_int(self): + def test_value_numpy(self): + with _test_eager_guard(): + self.func_setUp() + self.func_test_value_numpy() + self.func_setUp() + self.func_test_value_numpy() + + def func_test_value_int(self): paddle.disable_static() self._test(10) + def test_value_int(self): + with _test_eager_guard(): + self.func_setUp() + self.func_test_value_int() + self.func_setUp() + self.func_test_value_int() + class TestVarBaseSetitemBoolScalarIndex(unittest.TestCase): def set_input(self): @@ -1353,7 +1501,7 @@ def _test(self, value): class TestVarBaseInplaceVersion(unittest.TestCase): - def test_setitem(self): + def func_test_setitem(self): paddle.disable_static() var = paddle.ones(shape=[4, 2, 3], dtype="float32") @@ -1365,7 +1513,12 @@ def test_setitem(self): var[1:2] = 1 self.assertEqual(var.inplace_version, 2) - def test_bump_inplace_version(self): + def test_setitem(self): + with _test_eager_guard(): + self.func_test_setitem() + self.func_test_setitem() + + def func_test_bump_inplace_version(self): paddle.disable_static() var = paddle.ones(shape=[4, 2, 3], dtype="float32") self.assertEqual(var.inplace_version, 0) @@ -1376,9 +1529,14 @@ def test_bump_inplace_version(self): var._bump_inplace_version() self.assertEqual(var.inplace_version, 2) + def test_bump_inplace_version(self): + with _test_eager_guard(): + self.func_test_bump_inplace_version() + self.func_test_bump_inplace_version() + class TestVarBaseSlice(unittest.TestCase): - def test_slice(self): + def func_test_slice(self): paddle.disable_static() np_x = np.random.random((3, 8, 8)) x = paddle.to_tensor(np_x, dtype="float64") @@ -1386,15 +1544,25 @@ def test_slice(self): actual_x = paddle.to_tensor(actual_x) self.assertEqual(actual_x.numpy().all(), np_x[0:1].all()) + def test_slice(self): + with _test_eager_guard(): + self.func_test_slice() + self.func_test_slice() + class TestVarBaseClear(unittest.TestCase): - def test_clear(self): + def func_test_clear(self): paddle.disable_static() np_x = np.random.random((3, 8, 8)) x = paddle.to_tensor(np_x, dtype="float64") x._clear() self.assertEqual(str(x), "Tensor(Not initialized)") + def test_clear(self): + with _test_eager_guard(): + self.func_test_clear() + self.func_test_clear() + class TestVarBaseOffset(unittest.TestCase): def func_offset(self): @@ -1413,23 +1581,31 @@ def test_offset(self): class TestVarBaseShareBufferTo(unittest.TestCase): - def test_share_buffer_To(self): + def func_test_share_buffer_To(self): paddle.disable_static() np_src = np.random.random((3, 8, 8)) src = paddle.to_tensor(np_src, dtype="float64") # empty_var - dst = core.VarBase() + if _in_legacy_dygraph(): + dst = core.VarBase() + else: + dst = core.eager.Tensor() src._share_buffer_to(dst) self.assertEqual(src._is_shared_buffer_with(dst), True) + def test_share_buffer_To(self): + with _test_eager_guard(): + self.func_test_share_buffer_To() + self.func_test_share_buffer_To() + class TestVarBaseTo(unittest.TestCase): - def setUp(self): + def func_setUp(self): paddle.disable_static() self.np_x = 
np.random.random((3, 8, 8))
         self.x = paddle.to_tensor(self.np_x, dtype="float32")

-    def test_to_api(self):
+    def func_test_to_api(self):
         x_double = self.x._to(dtype='double')
         self.assertEqual(x_double.dtype, paddle.fluid.core.VarDesc.VarType.FP64)
         self.assertTrue(np.allclose(self.np_x, x_double))
@@ -1476,9 +1652,16 @@ def test_to_api(self):
         self.assertRaises(ValueError, self.x._to, device=1)
         self.assertRaises(AssertionError, self.x._to, blocking=1)

+    def test_to_api(self):
+        with _test_eager_guard():
+            self.func_setUp()
+            self.func_test_to_api()
+        self.func_setUp()
+        self.func_test_to_api()
+

 class TestVarBaseInitVarBaseFromTensorWithDevice(unittest.TestCase):
-    def test_varbase_init(self):
+    def func_test_varbase_init(self):
         paddle.disable_static()
         t = fluid.Tensor()
         np_x = np.random.random((3, 8, 8))
@@ -1486,17 +1669,28 @@ def test_varbase_init(self):

         if paddle.fluid.is_compiled_with_cuda():
             device = paddle.CUDAPlace(0)
-            tmp = fluid.core.VarBase(t, device)
+            if _in_legacy_dygraph():
+                tmp = fluid.core.VarBase(t, device)
+            else:
+                tmp = fluid.core.eager.Tensor(t, device)
             self.assertTrue(tmp.place.is_gpu_place())
             self.assertEqual(tmp.numpy().all(), np_x.all())

         device = paddle.CPUPlace()
-        tmp = fluid.core.VarBase(t, device)
+        if _in_legacy_dygraph():
+            tmp = fluid.core.VarBase(t, device)
+        else:
+            tmp = fluid.core.eager.Tensor(t, device)
         self.assertEqual(tmp.numpy().all(), np_x.all())

+    def test_varbase_init(self):
+        with _test_eager_guard():
+            self.func_test_varbase_init()
+        self.func_test_varbase_init()
+

 class TestVarBaseNumel(unittest.TestCase):
-    def test_numel_normal(self):
+    def func_test_numel_normal(self):
         paddle.disable_static()
         np_x = np.random.random((3, 8, 8))
         x = paddle.to_tensor(np_x, dtype="float64")
@@ -1504,15 +1698,28 @@ def test_numel_normal(self):
         x_expected_numel = np.product((3, 8, 8))
         self.assertEqual(x_actual_numel, x_expected_numel)

-    def test_numel_without_holder(self):
+    def test_numel_normal(self):
+        with _test_eager_guard():
+            self.func_test_numel_normal()
+        self.func_test_numel_normal()
+
+    def func_test_numel_without_holder(self):
         paddle.disable_static()
-        x_without_holder = core.VarBase()
+        if _in_legacy_dygraph():
+            x_without_holder = core.VarBase()
+        else:
+            x_without_holder = core.eager.Tensor()
         x_actual_numel = x_without_holder._numel()
         self.assertEqual(x_actual_numel, 0)

+    def test_numel_without_holder(self):
+        with _test_eager_guard():
+            self.func_test_numel_without_holder()
+        self.func_test_numel_without_holder()
+

 class TestVarBaseCopyGradientFrom(unittest.TestCase):
-    def test_copy_gradient_from(self):
+    def func_test_copy_gradient_from(self):
         paddle.disable_static()
         np_x = np.random.random((2, 2))
         np_y = np.random.random((2, 2))
@@ -1523,7 +1730,11 @@ def test_copy_gradient_from(self):
         x._copy_gradient_from(y)
         self.assertEqual(x.grad.numpy().all(), np_y.all())

+    def test_copy_gradient_from(self):
+        with _test_eager_guard():
+            self.func_test_copy_gradient_from()
+        self.func_test_copy_gradient_from()
+

 if __name__ == '__main__':
-    paddle.enable_static()
     unittest.main()

From f8b3e576146fd70e6037088ee564f9ced0914678 Mon Sep 17 00:00:00 2001
From: 0x45f <23097963+0x45f@users.noreply.github.com>
Date: Mon, 4 Apr 2022 23:07:08 +0800
Subject: [PATCH 83/93] Fix Warpctc error when using multi-gpu (#41389)

---
 paddle/phi/kernels/impl/warpctc_kernel_impl.h | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/paddle/phi/kernels/impl/warpctc_kernel_impl.h b/paddle/phi/kernels/impl/warpctc_kernel_impl.h
index 8a18f2500a512..ef6be7a9dfa88 100644
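One detail worth making explicit about the naming convention used throughout the test file above: unittest collects only methods whose names start with `test`, which is exactly why the bodies become `func_test_*` (invisible to discovery, run only through the wrappers) while the wrappers keep the `test_` prefix; a misspelled wrapper name would make its case silently disappear rather than fail. A quick way to see what gets collected:

    import unittest

    class Demo(unittest.TestCase):
        def func_test_body(self):   # not collected: wrong prefix
            pass

        def test_wrapper(self):     # collected
            self.func_test_body()

    print(unittest.defaultTestLoader.getTestCaseNames(Demo))  # ['test_wrapper']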
--- a/paddle/phi/kernels/impl/warpctc_kernel_impl.h +++ b/paddle/phi/kernels/impl/warpctc_kernel_impl.h @@ -203,7 +203,7 @@ class WarpCTCFunctor { void init(const Context& dev_ctx, const size_t blank) { warpctc_version_ = phi::dynload::get_warpctc_version(); - if (dev_ctx.GetPlace() == phi::GPUPlace()) { + if (paddle::platform::is_gpu_place(dev_ctx.GetPlace())) { #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) options_.loc = CTC_GPU; options_.stream = From e90f93675c89e2b63c15c93fe653d26a1eb0627c Mon Sep 17 00:00:00 2001 From: hong <43953930+phlrain@users.noreply.github.com> Date: Mon, 4 Apr 2022 23:31:41 +0800 Subject: [PATCH 84/93] add no need buffer; (#41367) --- .../final_state_generator/eager_gen.py | 3 +- .../unittests/test_elementwise_add_op.py | 4 +-- python/paddle/utils/code_gen/api.yaml | 29 ++++++------------ python/paddle/utils/code_gen/backward.yaml | 30 ++++++++++++------- 4 files changed, 33 insertions(+), 33 deletions(-) diff --git a/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py b/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py index 12738b7206276..b2db256f6026a 100644 --- a/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py +++ b/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py @@ -724,10 +724,11 @@ def GenerateNodeCreationCodes(self): is_optional = (name in optional_inputs) if is_fwd_input: + need_input_data = "false" if name in self.no_need_buffers else "true" if is_optional: set_tensor_wrappers = f"{indent}if({name}.get_ptr() != nullptr) grad_node->SetTensorWrapper{name}(*({name}.get_ptr()), true);" else: - set_tensor_wrappers = f"{indent}grad_node->SetTensorWrapper{name}({name}, true);" + set_tensor_wrappers = f"{indent}grad_node->SetTensorWrapper{name}({name}, {need_input_data});" else: if num_fwd_outputs > 1: # Aligned with forward output position diff --git a/python/paddle/fluid/tests/unittests/test_elementwise_add_op.py b/python/paddle/fluid/tests/unittests/test_elementwise_add_op.py index 4ddfe9d1559de..22787a23feadf 100644 --- a/python/paddle/fluid/tests/unittests/test_elementwise_add_op.py +++ b/python/paddle/fluid/tests/unittests/test_elementwise_add_op.py @@ -28,6 +28,7 @@ def init_kernel_type(self): def setUp(self): self.op_type = "elementwise_add" + self.python_api = paddle.add self.init_dtype() self.init_input_output() self.init_kernel_type() @@ -41,8 +42,7 @@ def setUp(self): self.outputs = {'Out': self.out} def check_eager(self): - return False - #return (self.use_mkldnn == False and self.axis == -1) + return (self.use_mkldnn == False and self.axis == -1) def test_check_output(self): # TODO(wangzhongpu): support mkldnn op in dygraph mode diff --git a/python/paddle/utils/code_gen/api.yaml b/python/paddle/utils/code_gen/api.yaml index b41ccf8ddb545..050cb058f7df7 100644 --- a/python/paddle/utils/code_gen/api.yaml +++ b/python/paddle/utils/code_gen/api.yaml @@ -61,7 +61,6 @@ kernel : func : add backward : add_grad - # no_need_buffer : x, y - api : add_n args : (Tensor[] x) @@ -147,7 +146,6 @@ kernel : func : argsort backward : argsort_grad - # no_need_buffer : x # asin - api : asin @@ -455,7 +453,6 @@ kernel : func : diagonal backward : diagonal_grad - # no_need_buffer : x - api : digamma args : (Tensor x) @@ -666,9 +663,9 @@ - api : frobenius_norm args : (Tensor x, int64_t[] axis, bool keep_dim, bool reduce_all) output : Tensor(out) - infer_meta : + infer_meta : func : ReduceInferMetaBase - kernel : + kernel : func : frobenius_norm backward : 
frobenius_norm_grad @@ -817,14 +814,13 @@ func : index_sample data_type : x backward : index_sample_grad - # no_need_buffer : x - api : index_select args : (Tensor x, Tensor index, int dim) output : Tensor(out) - infer_meta : + infer_meta : func : IndexSelectInferMeta - kernel : + kernel : func : index_select data_type : x backward : index_select_grad @@ -1283,7 +1279,7 @@ func : PoolInferMeta kernel : func : pool2d - backward : pool2d_grad + backward : pool2d_grad - api : pool3d args : (Tensor x, int[] kernel_size, int[] strides, int[] paddings, bool ceil_mode, bool exclusive, str data_format, str pooling_type, bool global_pooling, bool adaptive, str padding_algorithm) @@ -1393,9 +1389,9 @@ - api : roll args : (Tensor x, IntArray shifts, int64_t[] axis) output : Tensor(out) - infer_meta : + infer_meta : func : RollInferMeta - kernel : + kernel : func : roll backward : roll_grad @@ -1428,7 +1424,6 @@ kernel : func : scatter backward : scatter_grad - # no_need_buffer : updates - api : scatter_nd_add args : (Tensor x, Tensor index, Tensor updates) @@ -1439,7 +1434,6 @@ kernel : func : scatter_nd_add backward : scatter_nd_add_grad - # no_need_buffer : updates - api : searchsorted args : (Tensor sorted_sequence, Tensor value, bool out_int32, bool right) @@ -1633,7 +1627,6 @@ kernel : func : subtract backward : subtract_grad - # no_need_buffer : x, y - api : sum args : (Tensor x, int64_t[] dims={}, DataType out_dtype=paddle::experimental::DataType::UNDEFINED, bool keep_dim=false) @@ -1707,7 +1700,6 @@ kernel : func : tile backward : tile_grad - # no_need_buffer : x - api : top_k args : (Tensor x, Scalar k, int axis = -1, bool largest = true, bool sorted = true) @@ -1726,7 +1718,6 @@ kernel : func : trace backward : trace_grad - no_need_buffer : x - api : transpose args : (Tensor x, int[] axis) @@ -1749,9 +1740,9 @@ - api : tril_triu args : (Tensor x, int diagonal, bool lower) output : Tensor(out) - infer_meta : + infer_meta : func : TrilTriuInferMeta - kernel : + kernel : func : tril_triu backward : tril_triu_grad @@ -1773,7 +1764,6 @@ kernel : func : unfold backward : unfold_grad - # no_need_buffer : x - api : unsqueeze args : (Tensor x, IntArray axes) @@ -1812,7 +1802,6 @@ func : WhereIndexInferMeta kernel : func : where_index - # no_need_buffer : x, y # yolo_box - api : yolo_box diff --git a/python/paddle/utils/code_gen/backward.yaml b/python/paddle/utils/code_gen/backward.yaml index 814c56d7d222c..a45220843b230 100644 --- a/python/paddle/utils/code_gen/backward.yaml +++ b/python/paddle/utils/code_gen/backward.yaml @@ -1,3 +1,13 @@ +# - backward_api : gumbel_softmax_grad +# forward : gumbel_softmax (Tensor x, float temperature, bool hard, int axis) -> Tensor(out) +# args : (Tensor out, Tensor out_grad, int axis) +# output : Tensor(x_grad) +# infer_meta : +# func : GumbelSoftmaxGradInferMeta +# param : [out, out_grad, axis] +# kernel : +# func : gumbel_softmax_grad + - backward_api : abs_grad forward : abs (Tensor x) -> Tensor(out) args : (Tensor x, Tensor out_grad) @@ -49,7 +59,7 @@ no_need_buffer : x - backward_api : addmm_grad - forward : scatter (Tensor input, Tensor x, Tensor y, float alpha, float beta) -> Tensor(out) + forward : addmm (Tensor input, Tensor x, Tensor y, float alpha, float beta) -> Tensor(out) args : (Tensor input, Tensor x, Tensor y, Tensor out_grad, float alpha, float beta) output : Tensor(input_grad), Tensor(x_grad), Tensor(y_grad) infer_meta : @@ -67,6 +77,7 @@ param : [x] kernel : func : argsort_grad + no_need_buffer : x - backward_api : asin_grad forward : asin 
(Tensor x) -> Tensor(out) @@ -274,15 +285,6 @@ param: [x] kernel : func : cumprod_grad -# - backward_api : gumbel_softmax_grad -# forward : gumbel_softmax (Tensor x, float temperature, bool hard, int axis) -> Tensor(out) -# args : (Tensor out, Tensor out_grad, int axis) -# output : Tensor(x_grad) -# infer_meta : -# func : GumbelSoftmaxGradInferMeta -# param : [out, out_grad, axis] -# kernel : -# func : gumbel_softmax_grad - backward_api : depthwise_conv2d_transpose_grad forward : depthwise_conv2d_transpose(Tensor x, Tensor filter, int[] strides, int[] paddings, int[] output_padding, int[] output_size, str padding_algorithm, int groups, int[] dilations, str data_format) -> Tensor(out) @@ -302,6 +304,7 @@ param : [x] kernel : func : diagonal_grad + no_need_buffer : x - backward_api : digamma_grad forward : digamma (Tensor x) -> Tensor(out) @@ -529,6 +532,7 @@ kernel : func : index_sample_grad data_type : out_grad + no_need_buffer : x - backward_api : index_select_grad forward : index_select(Tensor x, Tensor index, int dim) -> Tensor(out) @@ -1026,6 +1030,7 @@ param : [index, updates, out_grad, overwrite] kernel : func : scatter_grad + no_need_buffer : updates - backward_api : scatter_nd_add_grad forward : scatter (Tensor x, Tensor index, Tensor updates) -> Tensor(out) @@ -1036,6 +1041,7 @@ param : [index, updates, out_grad] kernel : func : scatter_nd_grad + no_need_buffer : updates - backward_api : segment_pool_grad forward : segment_pool (Tensor x, Tensor segment_ids, str pooltype) -> Tensor(out), Tensor(summed_ids) @@ -1193,6 +1199,7 @@ param : [x, y] kernel : func : subtract_grad + no_need_buffer : x, y - backward_api : sum_grad forward : sum (Tensor x, int64_t[] dims={}, DataType out_dtype=paddle::experimental::DataType::UNDEFINED, bool keep_dim=false) -> Tensor(out) @@ -1263,6 +1270,7 @@ param : [x] kernel : func : tile_grad + no_need_buffer : x - backward_api : top_k_grad forward : top_k (Tensor x, Scalar k, int axis = -1, bool largest = true, bool sorted = true) -> Tensor(out), Tensor(indices) @@ -1283,6 +1291,7 @@ param : [x] kernel : func : trace_grad + no_need_buffer : x - backward_api : transpose_grad forward : transpose (Tensor x, int[] axis) -> Tensor(out) @@ -1323,6 +1332,7 @@ param : [x] kernel : func : unfold_grad + no_need_buffer : x - backward_api : unsqueeze_grad forward : unsqueeze(Tensor x, IntArray axes) -> Tensor(xshape), Tensor(out) From 69b79e6f09a954b4cd6bc3b0d16f03534db24134 Mon Sep 17 00:00:00 2001 From: pangyoki Date: Tue, 5 Apr 2022 08:23:18 +0800 Subject: [PATCH 85/93] ignore no_need_buffer tensor_wrapper in inplace checking (#41350) * support inplace no_need_buffer * fix * use padle.add --- paddle/fluid/eager/tensor_wrapper.h | 2 +- python/paddle/fluid/tests/unittests/test_inplace.py | 9 +++++---- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/paddle/fluid/eager/tensor_wrapper.h b/paddle/fluid/eager/tensor_wrapper.h index dc4cf379390f1..3d5d3139de14c 100644 --- a/paddle/fluid/eager/tensor_wrapper.h +++ b/paddle/fluid/eager/tensor_wrapper.h @@ -51,6 +51,7 @@ class TensorWrapper { * to avoid recursive depends on GradNodeBase * **/ full_reserved_ = full_reserved; + no_need_buffer_ = no_need_buffer; if (full_reserved_) { VLOG(6) << "Fully reserved tensor: " << tensor.name(); intermidiate_tensor_ = tensor; @@ -58,7 +59,6 @@ class TensorWrapper { } // shallow copy tensor_impl here - no_need_buffer_ = no_need_buffer; if (no_need_buffer) { if (phi::DenseTensor::classof(tensor.impl().get())) { // Only Copy Meta diff --git 
a/python/paddle/fluid/tests/unittests/test_inplace.py b/python/paddle/fluid/tests/unittests/test_inplace.py index b4f1dc22f4ee4..ee0d5bcdde6f2 100644 --- a/python/paddle/fluid/tests/unittests/test_inplace.py +++ b/python/paddle/fluid/tests/unittests/test_inplace.py @@ -103,7 +103,9 @@ def func_test_backward_success_2(self): var_b[1:2] = 3 # var_b is modified inplace before using it - var_c = var_b + var_b # Here, the grad op of sum doesn't use the value of var_b + var_c = paddle.add( + var_b, + var_b) # Here, the grad op of sum doesn't use the value of var_b loss = var_c.sum() var_b[1:2] = 3 # var_b is modified inplace after using it @@ -111,9 +113,8 @@ def func_test_backward_success_2(self): loss.backward() def test_backward_success_2(self): - # TODO: need to process no_need_buffer in eager mode - # with _test_eager_guard(): - # self.func_test_backward_success_2() + with _test_eager_guard(): + self.func_test_backward_success_2() self.func_test_backward_success_2() From cce176bfbad78c1960e10b558f1f315470db8de7 Mon Sep 17 00:00:00 2001 From: Chen Weihang Date: Tue, 5 Apr 2022 08:41:40 +0800 Subject: [PATCH 86/93] [Phi] add stack yaml and adapt eager mode (#41334) * add stack yaml * add stack yaml * add stack yaml * add no_need_buffer * refine no_need_buffer declare * remove original grad infershape * revert stack op --- paddle/phi/api/lib/api_custom_impl.cc | 139 ++++++++++++------ paddle/phi/api/lib/api_custom_impl.h | 15 +- paddle/phi/infermeta/backward.cc | 41 ++++++ paddle/phi/infermeta/backward.h | 4 + python/paddle/fluid/layers/nn.py | 5 +- .../fluid/tests/unittests/test_stack_op.py | 14 +- python/paddle/utils/code_gen/api.yaml | 9 ++ python/paddle/utils/code_gen/backward.yaml | 7 + 8 files changed, 180 insertions(+), 54 deletions(-) diff --git a/paddle/phi/api/lib/api_custom_impl.cc b/paddle/phi/api/lib/api_custom_impl.cc index 6325322b63c6f..40f5b8b297508 100644 --- a/paddle/phi/api/lib/api_custom_impl.cc +++ b/paddle/phi/api/lib/api_custom_impl.cc @@ -32,51 +32,7 @@ limitations under the License. */ namespace paddle { namespace experimental { -// TODO(chenweihang): the original sum grad op can support higher-level -// differentiation, -// but if we use this impl, it will not support. 
We need to be able to reuse -// the autograd API here, which is not yet implemented -// TODO(chenweihang): we should support call generated api in custom api impl -std::vector add_n_grad_impl(const std::vector& x, - const Tensor& out_grad) { - auto kernel_key_set = ParseKernelKeyByInputArgs(out_grad); - auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey(); - - Backend kernel_backend = kernel_key.backend(); - DataLayout kernel_layout = kernel_key.layout(); - DataType kernel_data_type = kernel_key.dtype(); - - auto kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError( - "scale", {kernel_backend, kernel_layout, kernel_data_type}); - VLOG(6) << "add_n_grad API kernel key: [" << kernel_backend << ", " - << kernel_layout << ", " << kernel_data_type << "]"; - VLOG(6) << "add_n_grad API kernel: " << kernel; - - auto* dev_ctx = GetDeviceContextByBackend(kernel_backend); - - auto dense_out_grad = PrepareData(out_grad, kernel.InputAt(0), {}); - - size_t out_number = x.size(); - std::vector x_grad; - auto dense_x_grad = SetKernelOutput(out_number, kernel_backend, &x_grad); - - using kernel_signature = void (*)(const platform::DeviceContext&, - const phi::DenseTensor&, - const phi::Scalar&, - float, - bool, - phi::DenseTensor*); - auto* kernel_fn = kernel.GetVariadicKernelFn(); - - for (auto* dense_x_grad_t : dense_x_grad) { - phi::MetaTensor meta_out(dense_x_grad_t); - phi::UnchangedInferMeta(MakeMetaTensor(*dense_out_grad), &meta_out); - (*kernel_fn)( - *dev_ctx, *dense_out_grad, phi::Scalar(1.0), 0.0, true, dense_x_grad_t); - } - - return x_grad; -} +////////////////// Forward api impls ////////////////////// Tensor copy_to_impl(const Tensor& x, Place place, bool blocking) { auto kernel_key_set = ParseKernelKeyByInputArgs(x); @@ -167,6 +123,54 @@ std::vector split_impl(const Tensor& x, return out; } +////////////////// Backward(grad) api impls ////////////////////// + +// TODO(chenweihang): the original sum grad op can support higher-level +// differentiation, +// but if we use this impl, it will not support. 
We need to be able to reuse +// the autograd API here, which is not yet implemented +// TODO(chenweihang): we should support call generated api in custom api impl +std::vector add_n_grad_impl(const std::vector& x, + const Tensor& out_grad) { + auto kernel_key_set = ParseKernelKeyByInputArgs(out_grad); + auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey(); + + Backend kernel_backend = kernel_key.backend(); + DataLayout kernel_layout = kernel_key.layout(); + DataType kernel_data_type = kernel_key.dtype(); + + auto kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError( + "scale", {kernel_backend, kernel_layout, kernel_data_type}); + VLOG(6) << "add_n_grad API kernel key: [" << kernel_backend << ", " + << kernel_layout << ", " << kernel_data_type << "]"; + VLOG(6) << "add_n_grad API kernel: " << kernel; + + auto* dev_ctx = GetDeviceContextByBackend(kernel_backend); + + auto dense_out_grad = PrepareData(out_grad, kernel.InputAt(0), {}); + + size_t out_number = x.size(); + std::vector x_grad; + auto dense_x_grad = SetKernelOutput(out_number, kernel_backend, &x_grad); + + using kernel_signature = void (*)(const platform::DeviceContext&, + const phi::DenseTensor&, + const phi::Scalar&, + float, + bool, + phi::DenseTensor*); + auto* kernel_fn = kernel.GetVariadicKernelFn(); + + for (auto* dense_x_grad_t : dense_x_grad) { + phi::MetaTensor meta_out(dense_x_grad_t); + phi::UnchangedInferMeta(MakeMetaTensor(*dense_out_grad), &meta_out); + (*kernel_fn)( + *dev_ctx, *dense_out_grad, phi::Scalar(1.0), 0.0, true, dense_x_grad_t); + } + + return x_grad; +} + std::tuple batch_norm_impl( const Tensor& x, const Tensor& scale, @@ -361,5 +365,50 @@ std::vector concat_grad_impl(const std::vector& x, return x_grad; } +std::vector stack_grad_impl(const std::vector& x, + const Tensor& out_grad, + int axis) { + auto kernel_key_set = ParseKernelKeyByInputArgs(out_grad); + auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey(); + + Backend kernel_backend = kernel_key.backend(); + DataLayout kernel_layout = kernel_key.layout(); + DataType kernel_data_type = kernel_key.dtype(); + + auto kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError( + "stack_grad", {kernel_backend, kernel_layout, kernel_data_type}); + VLOG(6) << "stack_grad API kernel key: [" << kernel_backend << ", " + << kernel_layout << ", " << kernel_data_type << "]"; + VLOG(6) << "stack_grad API kernel: " << kernel; + + auto* dev_ctx = GetDeviceContextByBackend(kernel_backend); + + auto dense_out_grad = PrepareData(out_grad, kernel.InputAt(0), {}); + + size_t out_number = x.size(); + std::vector x_grad; + auto dense_x_grad = SetKernelOutput(out_number, kernel_backend, &x_grad); + std::vector meta_x_grad; + meta_x_grad.reserve(out_number); + std::vector meta_x_grad_ptrs; + meta_x_grad_ptrs.reserve(out_number); + for (size_t i = 0; i < out_number; ++i) { + meta_x_grad.push_back(dense_x_grad[i]); + meta_x_grad_ptrs.push_back(&meta_x_grad.back()); + } + + phi::StackGradInferMeta( + MakeMetaTensor(*dense_out_grad), axis, meta_x_grad_ptrs); + + using kernel_signature = void (*)(const platform::DeviceContext&, + const phi::DenseTensor&, + int axis, + std::vector); + auto* kernel_fn = kernel.GetVariadicKernelFn(); + (*kernel_fn)(*dev_ctx, *dense_out_grad, axis, dense_x_grad); + + return x_grad; +} + } // namespace experimental } // namespace paddle diff --git a/paddle/phi/api/lib/api_custom_impl.h b/paddle/phi/api/lib/api_custom_impl.h index e8893cc2476a0..25d70d6477de1 100644 --- a/paddle/phi/api/lib/api_custom_impl.h +++ 
b/paddle/phi/api/lib/api_custom_impl.h @@ -22,8 +22,10 @@ limitations under the License. */ namespace paddle { namespace experimental { -std::vector add_n_grad_impl(const std::vector& x, - const Tensor& out_grad); +// NOTE: Separate forward and backward(grad) api impl +// NOTE: The api_impl in this file are arranged in alphabetic order. + +////////////////// Forward api impls ////////////////////// Tensor copy_to_impl(const Tensor& x, Place place, bool blocking); @@ -31,6 +33,11 @@ std::vector split_impl(const Tensor& x, const IntArray& num_or_sections, const Scalar& axis); +////////////////// Backward(grad) api impls ////////////////////// + +std::vector add_n_grad_impl(const std::vector& x, + const Tensor& out_grad); + std::tuple batch_norm_impl( const Tensor& x, const Tensor& scale, @@ -49,5 +56,9 @@ std::vector concat_grad_impl(const std::vector& x, const Tensor& out_grad, const Scalar& axis); +std::vector stack_grad_impl(const std::vector& x, + const Tensor& out_grad, + int axis); + } // namespace experimental } // namespace paddle diff --git a/paddle/phi/infermeta/backward.cc b/paddle/phi/infermeta/backward.cc index 7282c0695086a..9ee472c5c8843 100644 --- a/paddle/phi/infermeta/backward.cc +++ b/paddle/phi/infermeta/backward.cc @@ -375,4 +375,45 @@ void ScatterNdAddGradInferMeta(const MetaTensor& index, } } +void StackGradInferMeta(const MetaTensor& out_grad, + int axis, + std::vector x_grad) { + auto dy_dim = out_grad.dims(); + int rank = dy_dim.size(); + PADDLE_ENFORCE_GE( + axis, + -rank, + phi::errors::InvalidArgument( + "Attr(axis) must be inside [-rank, rank), where rank = %d, " + "but received axis is:%d.", + rank, + axis)); + PADDLE_ENFORCE_LT( + axis, + rank, + phi::errors::InvalidArgument( + "Attr(axis) must be inside [-rank, rank), where rank = %d, " + "but received axis is:%d.", + rank, + axis)); + + if (axis < 0) axis += rank; + PADDLE_ENFORCE_LE( + x_grad.size(), + static_cast(dy_dim[axis]), + phi::errors::InvalidArgument( + "Number of Outputs(X@Grad) should be less than or equal to dy dim " + "at axis, but received outputs size is:%d, dy dims is:%d.", + x_grad.size(), + static_cast(dy_dim[axis]))); + + auto vec = phi::vectorize(dy_dim); + vec.erase(vec.begin() + axis); + + for (size_t i = 0; i < x_grad.size(); ++i) { + x_grad[i]->set_dims(phi::make_ddim(vec)); + x_grad[i]->set_dtype(out_grad.dtype()); + } +} + } // namespace phi diff --git a/paddle/phi/infermeta/backward.h b/paddle/phi/infermeta/backward.h index 92266811de057..fb13b4281ae6e 100644 --- a/paddle/phi/infermeta/backward.h +++ b/paddle/phi/infermeta/backward.h @@ -163,4 +163,8 @@ void ScatterNdAddGradInferMeta(const MetaTensor& index, MetaTensor* x_grad, MetaTensor* updates_grad); +void StackGradInferMeta(const MetaTensor& out_grad, + int axis, + std::vector x_grad); + } // namespace phi diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 9f971faed3435..c489b362ccf9e 100755 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -10309,7 +10309,10 @@ def stack(x, axis=0, name=None): """ axis = 0 if axis is None else axis - if _non_static_mode(): + if in_dygraph_mode(): + return _C_ops.final_state_stack(x, axis) + + if _in_legacy_dygraph(): return _C_ops.stack(x, 'axis', axis) if not isinstance(x, list) and not isinstance(x, tuple): diff --git a/python/paddle/fluid/tests/unittests/test_stack_op.py b/python/paddle/fluid/tests/unittests/test_stack_op.py index 76f9cf1128ac4..faabcea13aec7 100644 --- 
a/python/paddle/fluid/tests/unittests/test_stack_op.py +++ b/python/paddle/fluid/tests/unittests/test_stack_op.py @@ -40,6 +40,7 @@ def setUp(self): self.initDefaultParameters() self.initParameters() self.op_type = 'stack' + self.python_api = paddle.stack self.x = [] for i in range(self.num_inputs): self.x.append( @@ -55,20 +56,20 @@ def setUp(self): self.attrs = {'axis': self.axis} def test_check_output(self): - self.check_output() + self.check_output(check_eager=True) def test_check_grad(self): - self.check_grad(self.get_x_names(), 'Y') + self.check_grad(self.get_x_names(), 'Y', check_eager=True) class TestStackOp1(TestStackOpBase): def initParameters(self): - self.num_inputs = 16 + self.num_inputs = 8 class TestStackOp2(TestStackOpBase): def initParameters(self): - self.num_inputs = 20 + self.num_inputs = 10 class TestStackOp3(TestStackOpBase): @@ -111,6 +112,7 @@ def setUp(self): self.initDefaultParameters() self.initParameters() self.op_type = 'stack' + self.python_api = paddle.stack self.x = [] for i in range(self.num_inputs): self.x.append( @@ -128,10 +130,10 @@ def setUp(self): self.attrs = {'axis': self.axis} def test_check_output(self): - self.check_output() + self.check_output(check_eager=True) def test_check_grad(self): - self.check_grad(self.get_x_names(), 'Y') + self.check_grad(self.get_x_names(), 'Y', check_eager=True) class TestStackAPIWithLoDTensorArray(unittest.TestCase): diff --git a/python/paddle/utils/code_gen/api.yaml b/python/paddle/utils/code_gen/api.yaml index 050cb058f7df7..615bcb01f5690 100644 --- a/python/paddle/utils/code_gen/api.yaml +++ b/python/paddle/utils/code_gen/api.yaml @@ -1610,6 +1610,15 @@ view: (x -> out) backward : squeeze_grad +- api : stack + args : (Tensor[] x, int axis) + output : Tensor + infer_meta : + func : StackInferMeta + kernel : + func : stack + backward : stack_grad + - api : strided_slice args : (Tensor x, int[] axes, IntArray starts, IntArray ends, IntArray strides) output : Tensor diff --git a/python/paddle/utils/code_gen/backward.yaml b/python/paddle/utils/code_gen/backward.yaml index a45220843b230..317610679854f 100644 --- a/python/paddle/utils/code_gen/backward.yaml +++ b/python/paddle/utils/code_gen/backward.yaml @@ -1180,6 +1180,13 @@ kernel : func : squeeze_grad +- backward_api : stack_grad + forward : stack (Tensor[] x, int axis) -> Tensor(out) + args : (Tensor[] x, Tensor out_grad, int axis) + output : Tensor[](x_grad) + invoke : stack_grad_impl(x, out_grad, axis) + no_need_buffer : x + - backward_api : strided_slice_grad forward : strided_slice (Tensor x, int[] axes, IntArray starts, IntArray ends, IntArray strides) -> Tensor(out) args : (Tensor x, Tensor out_grad, int[] axes, IntArray starts, IntArray ends, IntArray strides) From feaa97984592e08af313acc9d09c7e07e2fc0499 Mon Sep 17 00:00:00 2001 From: Chen Weihang Date: Tue, 5 Apr 2022 09:39:47 +0800 Subject: [PATCH 87/93] add test time, test=document_fix (#41405) --- python/paddle/fluid/tests/unittests/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index 663dd9b9e1257..ac3c708cc001e 100755 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -969,7 +969,7 @@ set_tests_properties(test_nearest_interp_op PROPERTIES TIMEOUT 120) set_tests_properties(test_profiler PROPERTIES TIMEOUT 120) set_tests_properties(test_inplace_softmax_with_cross_entropy PROPERTIES TIMEOUT 120) 
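Background on the `no_need_buffer` declarations that this series keeps promoting from api.yaml comments into real backward.yaml entries: the marker tells the autograd layer that a grad kernel needs only an input's metadata (shape/dtype), never its data, so the forward activation can be freed early instead of being held alive by the saved tensor wrapper. stack is the canonical example: x_grad is just out_grad sliced along `axis`, with shapes derived from out_grad alone. A schematic of the saved-input idea (hypothetical helper, loosely mirroring eager::TensorWrapper, not the real codegen):

    class SavedInput:
        def __init__(self, tensor, no_need_buffer):
            self.shape, self.dtype = tensor.shape, tensor.dtype
            # With no_need_buffer only the metadata survives; the (possibly
            # large) forward buffer is released as soon as forward finishes.
            self.data = None if no_need_buffer else tensor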
set_tests_properties(test_cross_entropy2_op PROPERTIES TIMEOUT 120)
-set_tests_properties(test_cross_entropy_loss PROPERTIES TIMEOUT 150)
+set_tests_properties(test_cross_entropy_loss PROPERTIES TIMEOUT 180)
 set_tests_properties(test_fetch_unmerged PROPERTIES TIMEOUT 120)
 set_tests_properties(test_gru_unit_op PROPERTIES TIMEOUT 120)
 set_tests_properties(test_activation_nn_grad PROPERTIES TIMEOUT 200)

From 93ea1297f753419f73dc365ab3b5d3b0f5562641 Mon Sep 17 00:00:00 2001
From: Leo Chen
Date: Tue, 5 Apr 2022 09:45:20 +0800
Subject: [PATCH 88/93] [new-exec] enable the new standalone executor by default (#41179)

* enable new executor by default

* enable stream safe allocator

* test=document_fix;test=coverage

* do not use scope in op kernel

* fit empty program for new executor

* fix communication depend

* fix test_sync_batch_norm

* skip unsupported place

* refine datatransfer

* fit for distributed program

* fix dependency

* fix some ut
---
 .../framework/new_executor/data_transfer.cc   |  17 +-
 .../framework/new_executor/interpretercore.cc |  15 +-
 .../new_executor/interpretercore_util.cc      | 151 +++++++++++++++++-
 .../memory/allocation/allocator_facade.cc     |   2 +-
 python/paddle/fluid/executor.py               |  17 +-
 .../fluid/tests/unittests/CMakeLists.txt      |   6 +-
 .../unittests/collective_reducescatter.py     |   1 +
 .../distributed_passes/dist_pass_test_base.py |   3 +-
 .../unittests/ir/inference/CMakeLists.txt     |   2 +-
 .../fluid/tests/unittests/test_nn_grad.py     |   1 +
 .../unittests/test_sync_batch_norm_op.py      |   2 +
 11 files changed, 187 insertions(+), 30 deletions(-)

diff --git a/paddle/fluid/framework/new_executor/data_transfer.cc b/paddle/fluid/framework/new_executor/data_transfer.cc
index 1d0727b80baf7..d0e5565139c54 100644
--- a/paddle/fluid/framework/new_executor/data_transfer.cc
+++ b/paddle/fluid/framework/new_executor/data_transfer.cc
@@ -319,6 +319,7 @@ void ApplyDataTransform(const OpKernelType& expected_kernel_key,
     }
   }

+  bool transfered = false;
   DataTranferHelper data_transfer_helper(place, var_scope);
   for (auto& var_name_item : *ins_map_temp) {
     bool should_skip_input =
@@ -334,6 +335,9 @@ void ApplyDataTransform(const OpKernelType& expected_kernel_key,
       if (var->IsType() || var->IsType()) {
         tensor_in = GetLoDTensorOrSelectedRowsValueFromVar(*var);
       } else if (var->IsType()) {
+        if (var->Get().size() == 0) {
+          continue;
+        }
         tensor_in =
             static_cast(&(var->Get()[0]));
       } else {
@@ -389,6 +393,7 @@ void ApplyDataTransform(const OpKernelType& expected_kernel_key,
       }

       if (is_transferred) {
+        transfered = true;
         // update RuntimeContext.inputs and original op_func_node inputs
         op_func_node->input_index[var_name_item.first][i] =
             var_scope->VarId(new_var_name);
@@ -426,11 +431,13 @@ void ApplyDataTransform(const OpKernelType& expected_kernel_key,
     }
   }

-  // NOTE(zhiqiu): UPDATE the corresponding OeratorBase to make it consistent
-  // with instruction. (hot fix, it is not good design here)
-  op_func_node->operator_base_ =
-      std::shared_ptr(framework::OpRegistry::CreateOp(
-          op_base->Type(), new_ins, new_outs, op_base->Attrs()));
+  if (transfered) {
+    // NOTE(zhiqiu): UPDATE the corresponding OperatorBase to make it consistent
+    // with instruction.
(hot fix, it is not good design here) + op_func_node->operator_base_ = + std::shared_ptr(framework::OpRegistry::CreateOp( + op_base->Type(), new_ins, new_outs, op_base->Attrs())); + } op_func_node->no_data_transform_index = std::move(no_data_transform_index); } diff --git a/paddle/fluid/framework/new_executor/interpretercore.cc b/paddle/fluid/framework/new_executor/interpretercore.cc index cf0b64cbc3a70..29aa7b13a270e 100644 --- a/paddle/fluid/framework/new_executor/interpretercore.cc +++ b/paddle/fluid/framework/new_executor/interpretercore.cc @@ -300,8 +300,16 @@ void InterpreterCore::Convert( gc_event_.emplace_back(vec_instruction_[i].DeviceContext().GetPlace(), platform::GenerateDeviceEventFlag()); } + bool inplaced = false; + for (auto inst : vec_instruction_) { + if (inst.OpBase()->Type() == "share_buffer" || + inst.OpBase()->Type() == "share_data") { + VLOG(4) << "Already inplaced, skip inplace now."; + inplaced = true; + } + } - if (FLAGS_new_executor_use_inplace) { + if (FLAGS_new_executor_use_inplace && !inplaced) { BuildInplace(); } @@ -565,12 +573,11 @@ void InterpreterCore::RunNextInstructions( const Instruction& instr, std::queue* reserved_next_ops, std::vector>* atomic_deps, std::vector>* atomic_var_ref) { - VLOG(4) << "atomic 1:" << atomic_deps; auto& next_instr = instr.NextInstructions(); auto IsReady = [atomic_deps](size_t next_id) { - VLOG(4) << "atomic:" << atomic_deps << " " << &(*atomic_deps)[next_id] - << " " << next_id; + VLOG(4) << "atomic:" << atomic_deps << " op_id: " << next_id + << ", remain deps: " << (*atomic_deps)[next_id]; return (*atomic_deps)[next_id].fetch_sub(1, std::memory_order_relaxed) == 1; }; diff --git a/paddle/fluid/framework/new_executor/interpretercore_util.cc b/paddle/fluid/framework/new_executor/interpretercore_util.cc index 360e0222a516c..a704411f3bb71 100644 --- a/paddle/fluid/framework/new_executor/interpretercore_util.cc +++ b/paddle/fluid/framework/new_executor/interpretercore_util.cc @@ -428,19 +428,19 @@ void build_op_func_list(const platform::Place& place, op_func_node.dev_ctx_ = dev_ctx; VLOG(3) << op_with_kernel->Type() << " : expected_kernel_key : " << expected_kernel_key; - auto exec_ctx = - ExecutionContext(*op_with_kernel, scope, *dev_ctx, runtime_context); // see OperatorWithKernel::RunImpl in operator.cc for why if (!(op->HasAttr(kAllKernelsMustComputeRuntimeShape) && op->Attr(kAllKernelsMustComputeRuntimeShape))) { InterpretercoreInferShapeContext infer_shape_ctx(*op, runtime_context); // TODO(Aurelius84): In case of control flow ops, they are NOT - // inheritted - // from OperatorWithKernel. + // inheritted from OperatorWithKernel. 
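Background on the `fetch_sub(1, std::memory_order_relaxed) == 1` check above: it is the classic counter-based topological scheduler. Each instruction starts with a counter equal to its number of unmet dependencies, every finishing predecessor decrements the counter atomically, and the thread that performs the final decrement (observing the old value 1) is the one that enqueues the instruction. A single-threaded sketch of the same idea (not the Paddle API):

    from collections import deque

    def run_in_dependency_order(dep_count, downstream, run):
        ready = deque(op for op, n in dep_count.items() if n == 0)
        while ready:
            op = ready.popleft()
            run(op)
            for nxt in downstream.get(op, ()):
                dep_count[nxt] -= 1      # fetch_sub(1) in the C++ version
                if dep_count[nxt] == 0:  # old value was 1 -> now ready
                    ready.append(nxt)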
op_with_kernel->Info().infer_shape_(&infer_shape_ctx); } + auto exec_ctx = + ExecutionContext(*op_with_kernel, scope, *dev_ctx, runtime_context); + auto run_phi_kernel = false; if (phi::KernelFactory::Instance().HasCompatiblePhiKernel( op_with_kernel->Type())) { @@ -476,7 +476,6 @@ void build_op_func_list(const platform::Place& place, op_with_kernel->BuildPhiKernelContext(runtime_context, dev_ctx, &pt_kernel_context); op_func_node.pt_kernel_ = op_with_kernel->PhiKernel(); - (*op_func_node.pt_kernel_)(&pt_kernel_context); } else { auto kernels_iter = all_op_kernels.find(op->Type()); @@ -711,6 +710,7 @@ std::map> build_op_downstream_map( const std::set random_op_set = { "bernoulli", "poisson", "multinomial", "gaussian_random", "uniform_random", "randint", "randperm", "exponential"}; + int dependence_op_idx = -1; for (size_t op_idx = 0; op_idx < vec_instruction.size(); ++op_idx) { if (random_op_set.count(vec_instruction[op_idx].OpBase()->Type())) { @@ -721,6 +721,147 @@ std::map> build_op_downstream_map( } } + // add dependency for communication op + const std::string communication_op_prefix = "c_"; + dependence_op_idx = -1; + for (size_t op_idx = 0; op_idx < vec_instruction.size(); ++op_idx) { + if (vec_instruction[op_idx].OpBase()->Type().find( + communication_op_prefix) != std::string::npos) { + if (dependence_op_idx != -1) { + op2dependences[op_idx].insert(dependence_op_idx); + } + dependence_op_idx = op_idx; + } + } + + // TODO(zhiqiu): there still some cases not handled + // add dependency for c_sync_comm_stream + + // in program, we can add only one c_sync_comm_stream to sync all + // communication ops. + // c_allreduce_sum(a) + // c_allreduce_sum(b) + // c_allreduce_sum(c) + // c_sync_comm_stream(a) + const std::string kSyncComm = "c_sync_comm_stream"; + dependence_op_idx = -1; + for (size_t op_idx = 0; op_idx < vec_instruction.size(); ++op_idx) { + if (vec_instruction[op_idx].OpBase()->Type() == kSyncComm) { + dependence_op_idx = op_idx; + } else { + if (dependence_op_idx != -1) { + VLOG(4) << "Add depend from " + << vec_instruction[dependence_op_idx].OpBase()->Type() << " to " + << vec_instruction[op_idx].OpBase()->Type(); + op2dependences[op_idx].insert(dependence_op_idx); + } + } + } + + // add dependency for coalesce_tensor + const std::string kCoalesceTensor = "coalesce_tensor"; + for (size_t op_idx = 0; op_idx < vec_instruction.size(); ++op_idx) { + if (vec_instruction[op_idx].OpBase()->Type() == kCoalesceTensor) { + VLOG(4) << "Add depend for " << kCoalesceTensor << " " << op_idx; + auto fused_out = vec_instruction[op_idx].Outputs().at("FusedOutput")[0]; + auto outputs = vec_instruction[op_idx].Outputs().at("Output"); + + auto is_read = [](const Instruction& inst, int var_id) -> bool { + for (auto pair : inst.Inputs()) { + for (auto item : pair.second) { + if (item == var_id) { + return true; + } + } + } + return false; + }; + + auto is_write = [](const Instruction& inst, int var_id) -> bool { + for (auto pair : inst.Outputs()) { + for (auto item : pair.second) { + if (item == var_id) { + return true; + } + } + } + return false; + }; + + // find first op that reads fused_out + auto first_read_fused_out_op = -1; + for (auto j = op_idx + 1; j < vec_instruction.size(); ++j) { + if (is_read(vec_instruction[j], fused_out)) { + first_read_fused_out_op = j; + break; + } + } + + if (UNLIKELY(first_read_fused_out_op == -1)) { + VLOG(4) << "No op read FusedOutput"; + continue; + } + + // find ops that write 'outputs' between (op_index, + // first_read_fused_out_op) + // add depend: 
+      //   them -> first_read_fused_out_op
+      for (auto j = op_idx + 1;
+           j < static_cast(first_read_fused_out_op); ++j) {
+        for (auto var_id : outputs) {
+          if (is_write(vec_instruction[j], var_id)) {
+            op2dependences[first_read_fused_out_op].insert(j);
+            VLOG(4) << j << " -> " << first_read_fused_out_op;
+            VLOG(4)
+                << "Add depend from " << vec_instruction[j].OpBase()->Type()
+                << " to "
+                << vec_instruction[first_read_fused_out_op].OpBase()->Type();
+          }
+        }
+      }
+
+      // find first op that reads 'outputs' between (first_read_fused_out_op, end)
+      // add depend: first_read_fused_out_op -> first op that reads 'outputs'
+
+      // special case for consecutive communication ops, for example,
+      // FusedOutput = c_sync_calc_stream(FusedOutput)
+      // FusedOutput = c_allreduce_sum(FusedOutput)
+      // FusedOutput = c_sync_comm_stream(FusedOutput)
+      // we should take the last one to add depend instead of
+      // 'first_read_fused_out_op'
+      size_t target = first_read_fused_out_op;
+      for (size_t j = first_read_fused_out_op + 1; j < vec_instruction.size();
+           ++j) {
+        if (j == target + 1 &&
+            vec_instruction[target].OpBase()->Type().find(
+                communication_op_prefix) != std::string::npos &&
+            vec_instruction[j].OpBase()->Type().find(communication_op_prefix) !=
+                std::string::npos) {
+          VLOG(4) << "Found consecutive communication ops, "
+                  << vec_instruction[target].OpBase()->Type() << " -> "
+                  << vec_instruction[j].OpBase()->Type();
+          target = j;
+          continue;
+        }
+
+        for (auto var_id : outputs) {
+          if (is_read(vec_instruction[j], var_id)) {
+            op2dependences[j].insert(target);
+            VLOG(4) << target << " -> " << j;
+            VLOG(4) << "Add depend from "
+                    << vec_instruction[target].OpBase()->Type() << " to "
+                    << vec_instruction[j].OpBase()->Type();
+          }
+        }
+      }
+    }
+  }
+
+  for (auto pair : op2dependences) {
+    VLOG(10) << pair.first << " Depends on " << pair.second.size();
+    std::ostringstream oss;
+    std::copy(pair.second.begin(), pair.second.end(),
+              std::ostream_iterator(oss, " "));
+    VLOG(10) << oss.str();
+  }
  return std::move(get_downstream_map(op2dependences));
}

diff --git a/paddle/fluid/memory/allocation/allocator_facade.cc b/paddle/fluid/memory/allocation/allocator_facade.cc
index f4dfb76884f17..e2730a1b825e9 100644
--- a/paddle/fluid/memory/allocation/allocator_facade.cc
+++ b/paddle/fluid/memory/allocation/allocator_facade.cc
@@ -85,7 +85,7 @@ PADDLE_DEFINE_EXPORTED_bool(use_virtual_memory_auto_growth, false,
 // NOTE(Ruibiao): This FLAGS is just to be compatibled with
 // the old single-stream CUDA allocator. It will be removed
 // after StreamSafeCudaAllocator has been fully tested.
-PADDLE_DEFINE_EXPORTED_bool(use_stream_safe_cuda_allocator, false,
+PADDLE_DEFINE_EXPORTED_bool(use_stream_safe_cuda_allocator, true,
                             "Enable StreamSafeCUDAAllocator");

 PADDLE_DEFINE_EXPORTED_bool(use_cuda_managed_memory, false,

diff --git a/python/paddle/fluid/executor.py b/python/paddle/fluid/executor.py
index eb833428afa42..935f7b53eba57 100644
--- a/python/paddle/fluid/executor.py
+++ b/python/paddle/fluid/executor.py
@@ -394,19 +394,10 @@ def _is_enable_standalone_executor():
     Whether to use experimental executor `StandaloneExecutor`.
     """
     flag = False
-    # NOTE(zhiqiu): enable STANDALONE_EXECUTOR on windows platform by default
-    # It should be enabled on all platform in the future.
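
This executor.py hunk (it continues just below) collapses the platform-specific logic into a single environment lookup that defaults to on. A standalone sketch of the resulting behavior, with a hypothetical helper name:

    import os

    def standalone_executor_enabled(environ=None):
        environ = os.environ if environ is None else environ
        # After this patch the executor defaults to on for every platform;
        # setting FLAGS_USE_STANDALONE_EXECUTOR=0 is the opt-out, which the
        # CMake ENVIRONMENT properties below apply to a few incompatible tests.
        env_val = environ.get('FLAGS_USE_STANDALONE_EXECUTOR', '1')
        return env_val in [1, '1', True, 'True', 'true']

    print(standalone_executor_enabled({}))  # True: enabled by default
    print(standalone_executor_enabled({'FLAGS_USE_STANDALONE_EXECUTOR': '0'}))  # False
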
- - import platform - sysstr = platform.system().lower() - if sysstr == 'windows': - env_val = os.environ.get('FLAGS_USE_STANDALONE_EXECUTOR', 1) - else: - env_val = os.environ.get('FLAGS_USE_STANDALONE_EXECUTOR', None) + env_val = os.environ.get('FLAGS_USE_STANDALONE_EXECUTOR', '1') if env_val in [1, '1', True, 'True', 'true']: flag = True - warnings.warn("STANDALONE_EXECUTOR is enabled.") return flag @@ -1386,6 +1377,10 @@ def _run_impl(self, program, feed, fetch_list, feed_var_name, program = pruned_program def _can_use_interpreter_core(program, place): + if core.is_compiled_with_npu() or core.is_compiled_with_xpu( + ) or core.is_compiled_with_mlu() or core.is_compiled_with_ipu(): + return False + compiled = isinstance(program, compiler.CompiledProgram) # NOTE(zhiqiu): do not support compiled program now if compiled: @@ -1396,6 +1391,8 @@ def _can_use_interpreter_core(program, place): # else: # return False else: + if isinstance(program._graph, compiler.CompiledProgram): + return False assert isinstance(program, Program) return True diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index ac3c708cc001e..8b84a9c524adf 100755 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -951,7 +951,7 @@ endif() if (WITH_DISTRIBUTE AND NOT APPLE) if(WITH_GPU OR WITH_ROCM) set_tests_properties(test_c_comm_init_op PROPERTIES TIMEOUT 120) - set_tests_properties(test_dist_mnist_gradient_merge PROPERTIES TIMEOUT 160) + set_tests_properties(test_dist_mnist_gradient_merge PROPERTIES TIMEOUT 360) endif() endif() @@ -1033,7 +1033,7 @@ set_tests_properties(test_parallel_executor_seresnext_with_reduce_gpu PROPERTIES set_tests_properties(test_dropout_op PROPERTIES TIMEOUT 120) set_tests_properties(test_argsort_op PROPERTIES TIMEOUT 120) set_tests_properties(test_gather_nd_op PROPERTIES TIMEOUT 120) -set_tests_properties(test_nn_grad PROPERTIES TIMEOUT 120) +set_tests_properties(test_nn_grad PROPERTIES TIMEOUT 120 ENVIRONMENT FLAGS_USE_STANDALONE_EXECUTOR=0) set_tests_properties(test_elementwise_sub_op PROPERTIES TIMEOUT 120) set_tests_properties(test_row_conv_op PROPERTIES TIMEOUT 120) set_tests_properties(test_parallel_executor_seresnext_with_fuse_all_reduce_gpu PROPERTIES TIMEOUT 120) @@ -1072,7 +1072,7 @@ set_tests_properties(test_space_to_depth_op PROPERTIES TIMEOUT 200) set_tests_properties(test_dyn_rnn PROPERTIES TIMEOUT 120) set_tests_properties(test_sgd_op PROPERTIES TIMEOUT 250) set_tests_properties(test_parallel_executor_seresnext_base_gpu PROPERTIES TIMEOUT 120) -set_tests_properties(test_norm_nn_grad PROPERTIES TIMEOUT 120) +set_tests_properties(test_norm_nn_grad PROPERTIES TIMEOUT 120 ENVIRONMENT FLAGS_USE_STANDALONE_EXECUTOR=0) set_tests_properties(test_matrix_nms_op PROPERTIES TIMEOUT 120) set_tests_properties(test_generator_dataloader PROPERTIES TIMEOUT 120) set_tests_properties(test_partial_concat_op PROPERTIES TIMEOUT 120) diff --git a/python/paddle/fluid/tests/unittests/collective_reducescatter.py b/python/paddle/fluid/tests/unittests/collective_reducescatter.py index 8b989c73d4deb..00d4a1c4cf6bd 100644 --- a/python/paddle/fluid/tests/unittests/collective_reducescatter.py +++ b/python/paddle/fluid/tests/unittests/collective_reducescatter.py @@ -48,6 +48,7 @@ def get_model(self, main_prog, startup_program): tindata = layers.data( name="tindata", shape=[10, 1000], dtype='float32') toutdata = fluid.layers.collective._c_reducescatter(tindata, nranks) + toutdata = 
fluid.layers.collective._c_sync_comm_stream(toutdata, 0) return toutdata diff --git a/python/paddle/fluid/tests/unittests/distributed_passes/dist_pass_test_base.py b/python/paddle/fluid/tests/unittests/distributed_passes/dist_pass_test_base.py index 488e7c809fc39..f0ed2cdc04950 100644 --- a/python/paddle/fluid/tests/unittests/distributed_passes/dist_pass_test_base.py +++ b/python/paddle/fluid/tests/unittests/distributed_passes/dist_pass_test_base.py @@ -32,7 +32,7 @@ def prepare_python_path_and_return_module(path): assert filename.endswith(py_suffix), filename env_name = 'PYTHONPATH' - python_path = env_name + python_path = os.environ.get(env_name, '') if python_path: paths = [p for p in python_path.split(":") if p] if dirname not in paths: @@ -41,6 +41,7 @@ def prepare_python_path_and_return_module(path): else: python_path = path os.environ[env_name] = python_path + print('GLOG_v=', os.environ.get('GLOG_v', None), flush=1) return filename[:-len(py_suffix)] diff --git a/python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt b/python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt index 5be531258edac..808821f06cbae 100755 --- a/python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt @@ -91,7 +91,7 @@ set_tests_properties(test_trt_matmul_quant_dequant PROPERTIES TIMEOUT 100) set_tests_properties(test_trt_conv3d_op PROPERTIES TIMEOUT 60) set_tests_properties(test_trt_conv3d_transpose_op PROPERTIES TIMEOUT 60) set_tests_properties(test_trt_nearest_interp_v2_op PROPERTIES TIMEOUT 30) -set_tests_properties(test_trt_multiclass_nms3_op PROPERTIES TIMEOUT 60) +set_tests_properties(test_trt_multiclass_nms3_op PROPERTIES TIMEOUT 60 ENVIRONMENT FLAGS_USE_STANDALONE_EXECUTOR=0) if (WITH_MKLDNN AND TENSORRT_FOUND AND WITH_GPU) set_tests_properties(test_emb_eltwise_layernorm_fuse_pass PROPERTIES TIMEOUT 120) diff --git a/python/paddle/fluid/tests/unittests/test_nn_grad.py b/python/paddle/fluid/tests/unittests/test_nn_grad.py index 722926b0d77f7..55f87540c1b8a 100644 --- a/python/paddle/fluid/tests/unittests/test_nn_grad.py +++ b/python/paddle/fluid/tests/unittests/test_nn_grad.py @@ -27,6 +27,7 @@ class TestSliceOpDoubleGradCheck(unittest.TestCase): + @prog_scope() def func(self, place): self.config() diff --git a/python/paddle/fluid/tests/unittests/test_sync_batch_norm_op.py b/python/paddle/fluid/tests/unittests/test_sync_batch_norm_op.py index 47a6d2b811552..6bf811be2ad0d 100644 --- a/python/paddle/fluid/tests/unittests/test_sync_batch_norm_op.py +++ b/python/paddle/fluid/tests/unittests/test_sync_batch_norm_op.py @@ -30,6 +30,7 @@ from paddle.fluid import Program, program_guard from op_test import OpTest, _set_use_system_allocator +from decorator_helper import prog_scope _set_use_system_allocator(True) @@ -105,6 +106,7 @@ def _build_program(self, sgd_opt.backward(out) return main, startup, [out, conv, bn] + @prog_scope() def _compare(self, place, layout, only_forward): """Compare results.""" seed = 10 From ceb3382bc31c3748bd5077274bde976c1ed11210 Mon Sep 17 00:00:00 2001 From: Aurelius84 Date: Tue, 5 Apr 2022 10:03:58 +0800 Subject: [PATCH 89/93] [Eager] Fix empty tensor Initializer bug with shape=[] (#41374) * [Eager] Fix empty tensor Initializer bug with shape=[] * [Eager] Fix empty tensor Initializer bug with shape=[] * ignore two unittest * fix unittest --- paddle/fluid/pybind/eager.cc | 19 ++++++++++++++----- paddle/fluid/pybind/eager_method.cc | 1 + 2 files changed, 15 insertions(+), 5 deletions(-) 
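
The crux of the eager.cc diff below: with shape=[] there is nothing to allocate yet, so the initializer now hands the DenseTensor a null holder instead of asking the allocator for storage. A small runnable Python model of that branching; the names are illustrative stand-ins, not the pybind API:

    from dataclasses import dataclass
    from typing import Optional, Tuple

    @dataclass
    class DenseTensorSketch:
        """Toy stand-in for phi::DenseTensor: a holder plus (dtype, dims) meta."""
        allocation: Optional[bytearray]
        dtype: str
        dims: Tuple[int, ...]

    def empty_tensor_initializer(dims, dtype='float32', elem_size=4):
        if len(dims) == 0:
            # shape=[]: defer allocation entirely, mirroring the new
            # allocation_ptr = nullptr branch.
            return DenseTensorSketch(None, dtype, tuple(dims))
        # Non-empty shape: allocate eagerly, as before.
        numel = 1
        for d in dims:
            numel *= d
        return DenseTensorSketch(bytearray(numel * elem_size), dtype, tuple(dims))

    print(empty_tensor_initializer([]).allocation)           # None: no storage yet
    print(len(empty_tensor_initializer([2, 3]).allocation))  # 24 bytes for float32
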
diff --git a/paddle/fluid/pybind/eager.cc b/paddle/fluid/pybind/eager.cc
index e39a9199b1cb9..1f72af8d79d17 100644
--- a/paddle/fluid/pybind/eager.cc
+++ b/paddle/fluid/pybind/eager.cc
@@ -72,11 +72,20 @@ void EmptyTensorInitializer(TensorObject* self, const std::string& name,
   }
   if (var_type == paddle::framework::proto::VarType::LOD_TENSOR) {
     // TODO(jiabin): Maybe support LOD later
-    std::shared_ptr dense_tensor =
-        std::make_shared(
-            phi::make_intrusive(place),
-            phi::DenseTensorMeta(paddle::framework::TransToPhiDataType(dtype),
-                                 ddims));
+    std::shared_ptr dense_tensor = nullptr;
+    if (dims.empty()) {
+      std::shared_ptr allocation_ptr = nullptr;
+      dense_tensor = std::make_shared(
+          allocation_ptr,
+          phi::DenseTensorMeta(paddle::framework::TransToPhiDataType(dtype),
+                               ddims));
+    } else {
+      // TODO(dev): we need to enhance the check for ddims.
+      dense_tensor = std::make_shared(
+          phi::make_intrusive(place),
+          phi::DenseTensorMeta(paddle::framework::TransToPhiDataType(dtype),
+                               ddims));
+    }
     self->tensor.set_impl(dense_tensor);
   } else if (var_type == paddle::framework::proto::VarType::SELECTED_ROWS) {
     std::shared_ptr tensor =
diff --git a/paddle/fluid/pybind/eager_method.cc b/paddle/fluid/pybind/eager_method.cc
index 74b866355f070..9f75b5c70b24d 100644
--- a/paddle/fluid/pybind/eager_method.cc
+++ b/paddle/fluid/pybind/eager_method.cc
@@ -125,6 +125,7 @@ class PyTensorVoidHook : public egr::TensorVoidHook {

 extern void InitTensorWithNumpyValue(TensorObject* self,
                                      const pybind11::object& array,
+                                     const paddle::platform::Place& place,
                                      bool zero_copy);

 extern PyTypeObject* p_tensor_type;

From 3b0e911c7c10cb97c7366d6a00c66fa579073330 Mon Sep 17 00:00:00 2001
From: wanghuancoder
Date: Tue, 5 Apr 2022 12:40:48 +0800
Subject: [PATCH 90/93] [Eager] dataloader2 (#41338)

* eager math op, test=develop
* eager support lookahead, test=develop
* refine, test=develop
* refine doc, test=develop
* refine, test=develop
* refine, test=develop
* refine, test=develop
* refine, test=develop
* test_paddle_multiprocessing
* refine, test=develop
* refine, test=develop
* fix bug, test=develop
* refine, test=develop
* dataloader, test=develop
* refine, test=develop
* refine, test=develop
* refine, test=develop
* test_datasets timeout, test=develop
* refine, test=develop
---
 .../auto_code_generator/eager_generator.cc    |   3 +-
 paddle/fluid/pybind/eager_method.cc           |  38 +++
 paddle/fluid/pybind/eager_utils.cc            |  80 ++++++-
 paddle/fluid/pybind/eager_utils.h             |   8 +-
 paddle/fluid/pybind/op_function_generator.h   |   2 +
 python/paddle/fluid/dataloader/collate.py     |   4 +-
 .../fluid/dataloader/dataloader_iter.py       |   5 +-
 python/paddle/fluid/dataloader/flat.py        |   6 +-
 python/paddle/fluid/initializer.py            |  12 +-
 .../unittests/test_dataloader_dataset.py      |  22 +-
 .../fluid/tests/unittests/test_lookahead.py   |   8 +-
 .../unittests/test_math_op_patch_var_base.py  | 218 +++++++++++++++---
 .../test_multiprocess_dataloader_dataset.py   |  99 ++++++--
 .../unittests/test_paddle_multiprocessing.py  |  37 ++-
 python/paddle/nn/initializer/dirac.py         |  14 +-
 python/paddle/tensor/linalg.py                |   4 +-
 python/paddle/tensor/logic.py                 |   3 +-
 python/paddle/tensor/math.py                  |  14 +-
 python/paddle/tests/CMakeLists.txt            |   2 +-
 python/paddle/tests/test_datasets.py          |  85 ++++++-
 20 files changed, 563 insertions(+), 101 deletions(-)

diff --git a/paddle/fluid/eager/auto_code_generator/eager_generator.cc b/paddle/fluid/eager/auto_code_generator/eager_generator.cc
index b1be15ac86ade..de44a833f6e73 100644
--- a/paddle/fluid/eager/auto_code_generator/eager_generator.cc
+++
b/paddle/fluid/eager/auto_code_generator/eager_generator.cc @@ -2653,7 +2653,8 @@ static void GenerateForwardDygraphFile(const std::string& forward_cc_path, "#include \"paddle/fluid/eager/api/utils/global_utils.h\"\n" "#include \"paddle/fluid/eager/amp_utils.h\"\n" "#include \"paddle/fluid/eager/amp_auto_cast.h\"\n" - "#include \"paddle/fluid/platform/profiler/event_tracing.h\"\n\n"; + "#include \"paddle/fluid/platform/profiler/event_tracing.h\"\n" + "#pragma GCC diagnostic ignored \"-Wunused-variable\"\n\n"; std::string forward_cc_include_str = paddle::string::Sprintf(FORWARD_INCLUDE_TEMPLATE); std::ofstream forward_cc_stream(forward_cc_path, std::ios::out); diff --git a/paddle/fluid/pybind/eager_method.cc b/paddle/fluid/pybind/eager_method.cc index 9f75b5c70b24d..4e18d4bbfbccb 100644 --- a/paddle/fluid/pybind/eager_method.cc +++ b/paddle/fluid/pybind/eager_method.cc @@ -42,7 +42,9 @@ limitations under the License. */ #include "pybind11/detail/internals.h" #pragma GCC diagnostic ignored "-Wmissing-field-initializers" #include "paddle/fluid/framework/python_headers.h" +#include "paddle/fluid/memory/allocation/mmap_allocator.h" #include "paddle/fluid/pybind/tensor_py.h" +#include "paddle/phi/core/ddim.h" namespace paddle { namespace pybind { @@ -1390,6 +1392,40 @@ static PyObject* tensor__reset_grad_inplace_version(TensorObject* self, EAGER_CATCH_AND_THROW_RETURN_NULL } +static PyObject* tensor_method__share_memory(TensorObject* self, PyObject* args, + PyObject* kwargs) { + EAGER_TRY +#ifndef _WIN32 + PADDLE_ENFORCE_EQ(platform::is_cpu_place(self->tensor.inner_place()), true, + platform::errors::InvalidArgument( + "Sharing memory only support CPU Tensor currently")); + // 1. get LoDTensor + auto* t = + std::dynamic_pointer_cast(self->tensor.impl()).get(); + // 2. allocate shared memory + void* data_ptr = t->data(); + size_t data_size = + t->numel() * + framework::SizeOfType(framework::TransToProtoVarType(t->dtype())); + auto shared_writer_holder = + memory::allocation::AllocateMemoryMapWriterAllocation(data_size); + // 3. maintain mmap fd set & backup ipc_name + const std::string& ipc_name = shared_writer_holder->ipc_name(); + memory::allocation::MemoryMapFdSet::Instance().Insert(ipc_name); + // 4. 
copy data & reset holder + memory::Copy(platform::CPUPlace(), shared_writer_holder->ptr(), + platform::CPUPlace(), data_ptr, data_size); + t->ResetHolder(shared_writer_holder); + return ToPyObject(t); +#else + PADDLE_THROW(platform::errors::PermissionDenied( + "Sharing memory in Windows OS is not supported currently")); + Py_INCREF(Py_None); + return Py_None; +#endif + EAGER_CATCH_AND_THROW_RETURN_NULL +} + static PyObject* tensor__offset(TensorObject* self, PyObject* args, PyObject* kwargs) { EAGER_TRY @@ -1536,6 +1572,8 @@ PyMethodDef variable_methods[] = { {"_reset_grad_inplace_version", (PyCFunction)(void (*)(void))tensor__reset_grad_inplace_version, METH_VARARGS | METH_KEYWORDS, NULL}, + {"_share_memory", (PyCFunction)(void (*)(void))tensor_method__share_memory, + METH_VARARGS | METH_KEYWORDS, NULL}, {"_offset", (PyCFunction)(void (*)(void))tensor__offset, METH_VARARGS | METH_KEYWORDS, NULL}, #if defined(PADDLE_WITH_CUDA) diff --git a/paddle/fluid/pybind/eager_utils.cc b/paddle/fluid/pybind/eager_utils.cc index 427f21dc1a4b9..8baea3d0dbfe1 100644 --- a/paddle/fluid/pybind/eager_utils.cc +++ b/paddle/fluid/pybind/eager_utils.cc @@ -156,6 +156,17 @@ int64_t CastPyArg2AttrLong(PyObject* obj, ssize_t arg_pos) { } } +size_t CastPyArg2AttrSize_t(PyObject* obj, ssize_t arg_pos) { + if (PyObject_CheckLongOrConvertToLong(&obj)) { + return PyLong_AsSize_t(obj); + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "argument (position %d) must be " + "long, but got %s", + arg_pos + 1, (reinterpret_cast(obj->ob_type))->tp_name)); + } +} + float CastPyArg2AttrFloat(PyObject* obj, ssize_t arg_pos) { if (PyObject_CheckFloatOrConvertToFloat(&obj)) { return static_cast(PyFloat_AsDouble(obj)); @@ -297,6 +308,51 @@ std::vector CastPyArg2VectorOfInt(PyObject* obj, size_t arg_pos) { return result; } +std::vector CastPyArg2VectorOfSize_t(PyObject* obj, size_t arg_pos) { + std::vector result; + if (PyList_Check(obj)) { + Py_ssize_t len = PyList_Size(obj); + PyObject* item = nullptr; + for (Py_ssize_t i = 0; i < len; i++) { + item = PyList_GetItem(obj, i); + if (PyObject_CheckLongOrConvertToLong(&item)) { + result.emplace_back(PyLong_AsSize_t(item)); + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "argument (position %d) must be " + "list of int, but got %s at pos %d", + arg_pos + 1, + reinterpret_cast(item->ob_type)->tp_name, i)); + } + } + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "argument (position %d) must be " + "list, but got %s", + arg_pos + 1, reinterpret_cast(obj->ob_type)->tp_name)); + } + return result; +} + +std::vector> CastPyArg2VectorOfVectorOfSize_t( + PyObject* obj, size_t arg_pos) { + std::vector> result; + if (PyList_Check(obj)) { + Py_ssize_t len = PyList_Size(obj); + PyObject* item = nullptr; + for (Py_ssize_t i = 0; i < len; i++) { + item = PyList_GetItem(obj, i); + result.emplace_back(CastPyArg2VectorOfSize_t(item, arg_pos)); + } + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "argument (position %d) must be " + "list but got %s", + arg_pos + 1, reinterpret_cast(obj->ob_type)->tp_name)); + } + return result; +} + platform::Place CastPyArg2Place(PyObject* obj, ssize_t arg_pos) { platform::Place place; if (PyObject_IsInstance(obj, reinterpret_cast(g_place_pytype))) { @@ -432,10 +488,10 @@ PyObject* ToPyObject(int value) { return PyLong_FromLong(value); } PyObject* ToPyObject(uint32_t value) { return PyLong_FromUnsignedLong(value); } -PyObject* ToPyObject(size_t value) { return PyLong_FromLong(value); } - PyObject* 
ToPyObject(int64_t value) { return PyLong_FromLongLong(value); } +PyObject* ToPyObject(size_t value) { return PyLong_FromSize_t(value); } + PyObject* ToPyObject(float value) { return PyLong_FromDouble(value); } PyObject* ToPyObject(double value) { return PyLong_FromDouble(value); } @@ -508,6 +564,16 @@ PyObject* ToPyObject(const std::vector& value) { return result; } +PyObject* ToPyObject(const std::vector& value) { + PyObject* result = PyList_New((Py_ssize_t)value.size()); + + for (size_t i = 0; i < value.size(); i++) { + PyList_SET_ITEM(result, (Py_ssize_t)i, ToPyObject(value[i])); + } + + return result; +} + PyObject* ToPyObject(const std::vector& value) { PyObject* result = PyList_New((Py_ssize_t)value.size()); @@ -528,6 +594,16 @@ PyObject* ToPyObject(const std::vector& value) { return result; } +PyObject* ToPyObject(const std::vector>& value) { + PyObject* result = PyList_New((Py_ssize_t)value.size()); + + for (size_t i = 0; i < value.size(); i++) { + PyList_SET_ITEM(result, static_cast(i), ToPyObject(value[i])); + } + + return result; +} + PyObject* ToPyObject(const std::vector& value, bool return_py_none_if_not_initialize) { PyObject* result = PyList_New((Py_ssize_t)value.size()); diff --git a/paddle/fluid/pybind/eager_utils.h b/paddle/fluid/pybind/eager_utils.h index 49075fb44486c..90c4d727923d0 100644 --- a/paddle/fluid/pybind/eager_utils.h +++ b/paddle/fluid/pybind/eager_utils.h @@ -36,6 +36,7 @@ bool PyObject_CheckStr(PyObject* obj); bool CastPyArg2AttrBoolean(PyObject* obj, ssize_t arg_pos); int CastPyArg2AttrInt(PyObject* obj, ssize_t arg_pos); int64_t CastPyArg2AttrLong(PyObject* obj, ssize_t arg_pos); +size_t CastPyArg2AttrSize_t(PyObject* obj, ssize_t arg_pos); float CastPyArg2AttrFloat(PyObject* obj, ssize_t arg_pos); std::string CastPyArg2AttrString(PyObject* obj, ssize_t arg_pos); paddle::CustomOpKernelContext CastPyArg2CustomOpKernelContext(PyObject* obj, @@ -50,14 +51,17 @@ framework::Tensor CastPyArg2FrameworkTensor(PyObject* obj, ssize_t arg_pos); std::vector CastPyArg2VectorOfTensorBase(PyObject* obj, ssize_t arg_pos); std::vector CastPyArg2VectorOfInt(PyObject* obj, size_t arg_pos); +std::vector CastPyArg2VectorOfSize_t(PyObject* obj, size_t arg_pos); +std::vector> CastPyArg2VectorOfVectorOfSize_t( + PyObject* obj, size_t arg_pos); framework::proto::VarType::Type CastPyArg2ProtoType(PyObject* obj, ssize_t arg_pos); PyObject* ToPyObject(int value); PyObject* ToPyObject(uint32_t value); -PyObject* ToPyObject(size_t value); PyObject* ToPyObject(bool value); PyObject* ToPyObject(int64_t value); +PyObject* ToPyObject(size_t value); PyObject* ToPyObject(float value); PyObject* ToPyObject(double value); PyObject* ToPyObject(const char* value); @@ -69,8 +73,10 @@ PyObject* ToPyObject(const paddle::experimental::Tensor& value, PyObject* ToPyObject(const std::vector& value); PyObject* ToPyObject(const std::vector& value); PyObject* ToPyObject(const std::vector& value); +PyObject* ToPyObject(const std::vector& value); PyObject* ToPyObject(const std::vector& value); PyObject* ToPyObject(const std::vector& value); +PyObject* ToPyObject(const std::vector>& value); PyObject* ToPyObject(const std::vector& value, bool return_py_none_if_not_initialize = false); PyObject* ToPyObject(const platform::Place& value); diff --git a/paddle/fluid/pybind/op_function_generator.h b/paddle/fluid/pybind/op_function_generator.h index ba4abc8d13536..d9aab3dbb04ce 100644 --- a/paddle/fluid/pybind/op_function_generator.h +++ b/paddle/fluid/pybind/op_function_generator.h @@ -241,6 +241,8 @@ std::map> 
op_passing_outs_map = { {"run_program", {"Out", "DOut", "OutScope"}}, {"clear_float_status", {"FloatStatusOut"}}, {"get_float_status", {"FloatStatusOut"}}, + {"assign", {"Out"}}, + {"assign_value", {"Out"}}, }; // NOTE(pangyoki): Tensor View Strategy. diff --git a/python/paddle/fluid/dataloader/collate.py b/python/paddle/fluid/dataloader/collate.py index 2086827258128..0bf041007eb38 100644 --- a/python/paddle/fluid/dataloader/collate.py +++ b/python/paddle/fluid/dataloader/collate.py @@ -57,7 +57,7 @@ def default_collate_fn(batch): if isinstance(sample, np.ndarray): batch = np.stack(batch, axis=0) return batch - elif isinstance(sample, paddle.Tensor): + elif isinstance(sample, (paddle.Tensor, core.eager.Tensor)): return layers.stack(batch, axis=0) elif isinstance(sample, numbers.Number): batch = np.array(batch) @@ -99,7 +99,7 @@ def default_convert_fn(batch): Batched data: batched each number, numpy array and paddle.Tensor in input data. """ - if isinstance(batch, (paddle.Tensor, np.ndarray)): + if isinstance(batch, (paddle.Tensor, np.ndarray, core.eager.Tensor)): return batch elif isinstance(batch, (str, bytes)): return batch diff --git a/python/paddle/fluid/dataloader/dataloader_iter.py b/python/paddle/fluid/dataloader/dataloader_iter.py index 0dc733440fada..bbf2a4377c767 100644 --- a/python/paddle/fluid/dataloader/dataloader_iter.py +++ b/python/paddle/fluid/dataloader/dataloader_iter.py @@ -229,7 +229,7 @@ def _thread_loop(self, legacy_expected_place): # pack as LoDTensorArray array = core.LoDTensorArray() for slot in batch: - if isinstance(slot, paddle.Tensor): + if isinstance(slot, (paddle.Tensor, core.eager.Tensor)): slot = slot.value().get_tensor() elif not isinstance(slot, core.LoDTensor): tmp = core.LoDTensor() @@ -543,7 +543,8 @@ def _thread_loop(self, legacy_expected_place): # LoDTensor not in shared memory is not # serializable, cannot be create in workers for slot in batch: - if isinstance(slot, paddle.Tensor): + if isinstance(slot, (paddle.Tensor, + core.eager.Tensor)): slot = slot.value().get_tensor() elif not isinstance(slot, core.LoDTensor): tmp = core.LoDTensor() diff --git a/python/paddle/fluid/dataloader/flat.py b/python/paddle/fluid/dataloader/flat.py index 32c8ef02dd915..5baf4cc853e27 100644 --- a/python/paddle/fluid/dataloader/flat.py +++ b/python/paddle/fluid/dataloader/flat.py @@ -36,7 +36,8 @@ def _flatten_batch(batch): def _flatten(batch, flat_batch, structure, field_idx): if isinstance(batch, Sequence): for field in batch: - if isinstance(field, (np.ndarray, paddle.Tensor)): + if isinstance(field, (np.ndarray, paddle.Tensor, + paddle.fluid.core.eager.Tensor)): structure.append('{}{}'.format(FIELD_PREFIX, field_idx)) flat_batch.append(field) field_idx += 1 @@ -54,7 +55,8 @@ def _flatten(batch, flat_batch, structure, field_idx): structure.append(field) elif isinstance(batch, Mapping): for k, field in batch.items(): - if isinstance(field, (np.ndarray, paddle.Tensor)): + if isinstance(field, (np.ndarray, paddle.Tensor, + paddle.fluid.core.eager.Tensor)): structure[k] = '{}{}'.format(FIELD_PREFIX, field_idx) flat_batch.append(field) field_idx += 1 diff --git a/python/paddle/fluid/initializer.py b/python/paddle/fluid/initializer.py index a416d139a9111..bdc97eca0d84f 100644 --- a/python/paddle/fluid/initializer.py +++ b/python/paddle/fluid/initializer.py @@ -876,9 +876,9 @@ def __call__(self, var, block=None): raise ValueError("The size of input is too big. 
") if framework._non_static_mode(): - out_var = _C_ops.assign_value('shape', - list(shape), 'dtype', out_dtype, - value_name, values) + _C_ops.assign_value(out_var, 'shape', + list(shape), 'dtype', out_dtype, value_name, + values) if var.dtype in [ VarDesc.VarType.FP16, VarDesc.VarType.BF16, VarDesc.VarType.FP64 @@ -985,9 +985,9 @@ def __call__(self, var, block=None): "saving it to file and 'load_op' to load it") if framework._non_static_mode(): - out_var = _C_ops.assign_value('shape', - list(self._value.shape), 'dtype', - out_dtype, value_name, values) + _C_ops.assign_value(out_var, 'shape', + list(self._value.shape), 'dtype', out_dtype, + value_name, values) if var.dtype in [VarDesc.VarType.FP16, VarDesc.VarType.BF16]: var_tmp = _C_ops.cast(out_var, 'in_dtype', out_var.dtype, 'out_dtype', var.dtype) diff --git a/python/paddle/fluid/tests/unittests/test_dataloader_dataset.py b/python/paddle/fluid/tests/unittests/test_dataloader_dataset.py index c54a1406e39bf..786d04272e3eb 100644 --- a/python/paddle/fluid/tests/unittests/test_dataloader_dataset.py +++ b/python/paddle/fluid/tests/unittests/test_dataloader_dataset.py @@ -22,10 +22,11 @@ import paddle.vision.transforms as transforms import paddle.fluid as fluid from paddle.io import * +from paddle.fluid.framework import _test_eager_guard, _in_legacy_dygraph class TestDatasetAbstract(unittest.TestCase): - def test_main(self): + def func_test_main(self): dataset = Dataset() try: d = dataset[0] @@ -39,6 +40,11 @@ def test_main(self): except NotImplementedError: pass + def test_main(self): + with _test_eager_guard(): + self.func_test_main() + self.func_test_main() + class TestDatasetWithDiffOutputPlace(unittest.TestCase): def get_dataloader(self, num_workers): @@ -60,7 +66,7 @@ def run_check_on_cpu(self): self.assertTrue(label.place.is_cpu_place()) break - def test_single_process(self): + def func_test_single_process(self): self.run_check_on_cpu() if paddle.is_compiled_with_cuda(): # Get (image, label) tuple from MNIST dataset @@ -72,7 +78,12 @@ def test_single_process(self): self.assertTrue(label.place.is_cuda_pinned_place()) break - def test_multi_process(self): + def test_single_process(self): + with _test_eager_guard(): + self.func_test_single_process() + self.func_test_single_process() + + def func_test_multi_process(self): # DataLoader with multi-process mode is not supported on MacOs and Windows currently if sys.platform != 'darwin' and sys.platform != 'win32': self.run_check_on_cpu() @@ -86,6 +97,11 @@ def test_multi_process(self): self.assertTrue(label.place.is_cuda_pinned_place()) break + def test_multi_process(self): + with _test_eager_guard(): + self.func_test_multi_process() + self.func_test_multi_process() + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_lookahead.py b/python/paddle/fluid/tests/unittests/test_lookahead.py index a4b5e6d0d9576..263310043a5f7 100644 --- a/python/paddle/fluid/tests/unittests/test_lookahead.py +++ b/python/paddle/fluid/tests/unittests/test_lookahead.py @@ -22,6 +22,7 @@ import paddle.fluid as fluid import paddle import paddle.nn as nn +from paddle.fluid.framework import _test_eager_guard, _in_legacy_dygraph LOOKAHEAD_K = 5 LOOKAHEAD_ALPHA = 0.2 @@ -68,7 +69,7 @@ def test_lookahead_static(self): slow_param.all(), latest_b.all(), delta=5e-3) fast_param = latest_b - SGD_LR * b_grad - def test_look_ahead_dygraph(self): + def func_test_look_ahead_dygraph(self): BATCH_SIZE = 16 BATCH_NUM = 4 EPOCH_NUM = 4 @@ -142,6 +143,11 @@ def train(layer, loader, loss_fn, 
opt): train(layer, loader, loss_fn, lookahead) + def test_look_ahead_dygraph(self): + with _test_eager_guard(): + self.func_test_look_ahead_dygraph() + self.func_test_look_ahead_dygraph() + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_math_op_patch_var_base.py b/python/paddle/fluid/tests/unittests/test_math_op_patch_var_base.py index 44876c9bd5773..48aa530ff87f9 100644 --- a/python/paddle/fluid/tests/unittests/test_math_op_patch_var_base.py +++ b/python/paddle/fluid/tests/unittests/test_math_op_patch_var_base.py @@ -19,6 +19,7 @@ import paddle.fluid as fluid import numpy as np import inspect +from paddle.fluid.framework import _test_eager_guard, _in_legacy_dygraph class TestMathOpPatchesVarBase(unittest.TestCase): @@ -26,7 +27,7 @@ def setUp(self): self.shape = [10, 1024] self.dtype = np.float32 - def test_add(self): + def func_test_add(self): a_np = np.random.random(self.shape).astype(self.dtype) b_np = np.random.random(self.shape).astype(self.dtype) with fluid.dygraph.guard(): @@ -35,7 +36,12 @@ def test_add(self): res = a + b self.assertTrue(np.array_equal(res.numpy(), a_np + b_np)) - def test_sub(self): + def test_add(self): + with _test_eager_guard(): + self.func_test_add() + self.func_test_add() + + def func_test_sub(self): a_np = np.random.random(self.shape).astype(self.dtype) b_np = np.random.random(self.shape).astype(self.dtype) with fluid.dygraph.guard(): @@ -44,7 +50,12 @@ def test_sub(self): res = a - b self.assertTrue(np.array_equal(res.numpy(), a_np - b_np)) - def test_mul(self): + def test_sub(self): + with _test_eager_guard(): + self.func_test_sub() + self.func_test_sub() + + def func_test_mul(self): a_np = np.random.random(self.shape).astype(self.dtype) b_np = np.random.random(self.shape).astype(self.dtype) with fluid.dygraph.guard(): @@ -53,7 +64,12 @@ def test_mul(self): res = a * b self.assertTrue(np.array_equal(res.numpy(), a_np * b_np)) - def test_div(self): + def test_mul(self): + with _test_eager_guard(): + self.func_test_mul() + self.func_test_mul() + + def func_test_div(self): a_np = np.random.random(self.shape).astype(self.dtype) b_np = np.random.random(self.shape).astype(self.dtype) with fluid.dygraph.guard(): @@ -63,7 +79,12 @@ def test_div(self): #NOTE: Not sure why array_equal fails on windows, allclose is acceptable self.assertTrue(np.allclose(res.numpy(), a_np / b_np)) - def test_add_scalar(self): + def test_div(self): + with _test_eager_guard(): + self.func_test_div() + self.func_test_div() + + def func_test_add_scalar(self): a_np = np.random.random(self.shape).astype(self.dtype) with fluid.dygraph.guard(): a = fluid.dygraph.to_variable(a_np) @@ -71,7 +92,12 @@ def test_add_scalar(self): res = a + b self.assertTrue(np.array_equal(res.numpy(), a_np + b)) - def test_add_scalar_reverse(self): + def test_add_scalar(self): + with _test_eager_guard(): + self.func_test_add_scalar() + self.func_test_add_scalar() + + def func_test_add_scalar_reverse(self): a_np = np.random.random(self.shape).astype(self.dtype) with fluid.dygraph.guard(): a = fluid.dygraph.to_variable(a_np) @@ -79,7 +105,12 @@ def test_add_scalar_reverse(self): res = b + a self.assertTrue(np.array_equal(res.numpy(), b + a_np)) - def test_sub_scalar(self): + def test_add_scalar_reverse(self): + with _test_eager_guard(): + self.func_test_add_scalar_reverse() + self.func_test_add_scalar_reverse() + + def func_test_sub_scalar(self): a_np = np.random.random(self.shape).astype(self.dtype) with fluid.dygraph.guard(): a = fluid.dygraph.to_variable(a_np) @@ 
-87,7 +118,12 @@ def test_sub_scalar(self): res = a - b self.assertTrue(np.array_equal(res.numpy(), a_np - b)) - def test_sub_scalar_reverse(self): + def test_sub_scalar(self): + with _test_eager_guard(): + self.func_test_sub_scalar() + self.func_test_sub_scalar() + + def func_test_sub_scalar_reverse(self): a_np = np.random.random(self.shape).astype(self.dtype) with fluid.dygraph.guard(): a = fluid.dygraph.to_variable(a_np) @@ -95,7 +131,12 @@ def test_sub_scalar_reverse(self): res = b - a self.assertTrue(np.array_equal(res.numpy(), b - a_np)) - def test_mul_scalar(self): + def test_sub_scalar_reverse(self): + with _test_eager_guard(): + self.func_test_sub_scalar_reverse() + self.func_test_sub_scalar_reverse() + + def func_test_mul_scalar(self): a_np = np.random.random(self.shape).astype(self.dtype) with fluid.dygraph.guard(): a = fluid.dygraph.to_variable(a_np) @@ -103,8 +144,13 @@ def test_mul_scalar(self): res = a * b self.assertTrue(np.array_equal(res.numpy(), a_np * b)) + def test_mul_scalar(self): + with _test_eager_guard(): + self.func_test_mul_scalar() + self.func_test_mul_scalar() + # div_scalar, not equal - def test_div_scalar(self): + def func_test_div_scalar(self): a_np = np.random.random(self.shape).astype(self.dtype) with fluid.dygraph.guard(): a = fluid.dygraph.to_variable(a_np) @@ -112,8 +158,13 @@ def test_div_scalar(self): res = a / b self.assertTrue(np.allclose(res.numpy(), a_np / b)) + def test_div_scalar(self): + with _test_eager_guard(): + self.func_test_div_scalar() + self.func_test_div_scalar() + # pow of float type, not equal - def test_pow(self): + def func_test_pow(self): a_np = np.random.random(self.shape).astype(self.dtype) b_np = np.random.random(self.shape).astype(self.dtype) with fluid.dygraph.guard(): @@ -122,7 +173,12 @@ def test_pow(self): res = a**b self.assertTrue(np.allclose(res.numpy(), a_np**b_np)) - def test_floor_div(self): + def test_pow(self): + with _test_eager_guard(): + self.func_test_pow() + self.func_test_pow() + + def func_test_floor_div(self): a_np = np.random.randint(1, 100, size=self.shape) b_np = np.random.randint(1, 100, size=self.shape) with fluid.dygraph.guard(): @@ -131,7 +187,12 @@ def test_floor_div(self): res = a // b self.assertTrue(np.array_equal(res.numpy(), a_np // b_np)) - def test_mod(self): + def test_floor_div(self): + with _test_eager_guard(): + self.func_test_floor_div() + self.func_test_floor_div() + + def func_test_mod(self): a_np = np.random.randint(1, 100, size=self.shape) b_np = np.random.randint(1, 100, size=self.shape) with fluid.dygraph.guard(): @@ -140,8 +201,13 @@ def test_mod(self): res = a % b self.assertTrue(np.array_equal(res.numpy(), a_np % b_np)) + def test_mod(self): + with _test_eager_guard(): + self.func_test_mod() + self.func_test_mod() + # for bitwise and/or/xor/not - def test_bitwise(self): + def func_test_bitwise(self): paddle.disable_static() x_np = np.random.randint(-100, 100, [2, 3, 5]) @@ -165,8 +231,13 @@ def test_bitwise(self): out = ~x self.assertTrue(np.array_equal(out.numpy(), out_np)) + def test_bitwise(self): + with _test_eager_guard(): + self.func_test_bitwise() + self.func_test_bitwise() + # for logical compare - def test_equal(self): + def func_test_equal(self): a_np = np.asarray([1, 2, 3, 4, 5]) b_np = np.asarray([1, 2, 3, 4, 5]) c_np = np.asarray([1, 2, 2, 4, 5]) @@ -179,7 +250,12 @@ def test_equal(self): self.assertTrue(np.array_equal(res1.numpy(), a_np == b_np)) self.assertTrue(np.array_equal(res2.numpy(), a_np == c_np)) - def test_not_equal(self): + def test_equal(self): + with 
_test_eager_guard(): + self.func_test_equal() + self.func_test_equal() + + def func_test_not_equal(self): a_np = np.asarray([1, 2, 3, 4, 5]) b_np = np.asarray([1, 2, 3, 4, 5]) c_np = np.asarray([1, 2, 2, 4, 5]) @@ -192,7 +268,12 @@ def test_not_equal(self): self.assertTrue(np.array_equal(res1.numpy(), a_np != b_np)) self.assertTrue(np.array_equal(res2.numpy(), a_np != c_np)) - def test_less_than(self): + def test_not_equal(self): + with _test_eager_guard(): + self.func_test_not_equal() + self.func_test_not_equal() + + def func_test_less_than(self): a_np = np.random.random(self.shape).astype(self.dtype) b_np = np.random.random(self.shape).astype(self.dtype) with fluid.dygraph.guard(): @@ -201,7 +282,12 @@ def test_less_than(self): res = (a < b) self.assertTrue(np.array_equal(res.numpy(), a_np < b_np)) - def test_less_equal(self): + def test_less_than(self): + with _test_eager_guard(): + self.func_test_less_than() + self.func_test_less_than() + + def func_test_less_equal(self): a_np = np.random.random(self.shape).astype(self.dtype) b_np = np.random.random(self.shape).astype(self.dtype) with fluid.dygraph.guard(): @@ -210,7 +296,12 @@ def test_less_equal(self): res = (a <= b) self.assertTrue(np.array_equal(res.numpy(), a_np <= b_np)) - def test_greater_than(self): + def test_less_equal(self): + with _test_eager_guard(): + self.func_test_less_equal() + self.func_test_less_equal() + + def func_test_greater_than(self): a_np = np.random.random(self.shape).astype(self.dtype) b_np = np.random.random(self.shape).astype(self.dtype) with fluid.dygraph.guard(): @@ -219,7 +310,12 @@ def test_greater_than(self): res = (a > b) self.assertTrue(np.array_equal(res.numpy(), a_np > b_np)) - def test_greater_equal(self): + def test_greater_than(self): + with _test_eager_guard(): + self.func_test_greater_than() + self.func_test_greater_than() + + def func_test_greater_equal(self): a_np = np.random.random(self.shape).astype(self.dtype) b_np = np.random.random(self.shape).astype(self.dtype) with fluid.dygraph.guard(): @@ -228,27 +324,47 @@ def test_greater_equal(self): res = (a >= b) self.assertTrue(np.array_equal(res.numpy(), a_np >= b_np)) - def test_neg(self): + def test_greater_equal(self): + with _test_eager_guard(): + self.func_test_greater_equal() + self.func_test_greater_equal() + + def func_test_neg(self): a_np = np.random.uniform(-1, 1, self.shape).astype(self.dtype) with fluid.dygraph.guard(): a = fluid.dygraph.to_variable(a_np) res = -a self.assertTrue(np.array_equal(res.numpy(), -a_np)) - def test_float_int_long(self): + def test_neg(self): + with _test_eager_guard(): + self.func_test_neg() + self.func_test_neg() + + def func_test_float_int_long(self): with fluid.dygraph.guard(): a = fluid.dygraph.to_variable(np.array([100.1])) self.assertTrue(float(a) == 100.1) self.assertTrue(int(a) == 100) self.assertTrue(int(a) == 100) - def test_len(self): + def test_float_int_long(self): + with _test_eager_guard(): + self.func_test_float_int_long() + self.func_test_float_int_long() + + def func_test_len(self): a_np = np.random.uniform(-1, 1, self.shape).astype(self.dtype) with fluid.dygraph.guard(): a = fluid.dygraph.to_variable(a_np) self.assertTrue(len(a) == 10) - def test_index(self): + def test_len(self): + with _test_eager_guard(): + self.func_test_len() + self.func_test_len() + + def func_test_index(self): with fluid.dygraph.guard(): var1 = fluid.dygraph.to_variable(np.array([2])) i_tmp = 0 @@ -260,7 +376,12 @@ def test_index(self): str1 = "just test" self.assertTrue(str1[var1] == 's') - def 
test_np_left_mul(self): + def test_index(self): + with _test_eager_guard(): + self.func_test_index() + self.func_test_index() + + def func_test_np_left_mul(self): with fluid.dygraph.guard(): t = np.sqrt(2.0 * np.pi) x = fluid.layers.ones((2, 2), dtype="float32") @@ -274,7 +395,12 @@ def test_np_left_mul(self): rtol=1e-05, atol=0.0)) - def test_add_different_dtype(self): + def test_np_left_mul(self): + with _test_eager_guard(): + self.func_test_np_left_mul() + self.func_test_np_left_mul() + + def func_test_add_different_dtype(self): a_np = np.random.random(self.shape).astype(np.float32) b_np = np.random.random(self.shape).astype(np.float16) with fluid.dygraph.guard(): @@ -283,7 +409,12 @@ def test_add_different_dtype(self): res = a + b self.assertTrue(np.array_equal(res.numpy(), a_np + b_np)) - def test_floordiv_different_dtype(self): + def test_add_different_dtype(self): + with _test_eager_guard(): + self.func_test_add_different_dtype() + self.func_test_add_different_dtype() + + def func_test_floordiv_different_dtype(self): a_np = np.full(self.shape, 10, np.int64) b_np = np.full(self.shape, 2, np.int32) with fluid.dygraph.guard(): @@ -292,7 +423,12 @@ def test_floordiv_different_dtype(self): res = a // b self.assertTrue(np.array_equal(res.numpy(), a_np // b_np)) - def test_astype(self): + def test_floordiv_different_dtype(self): + with _test_eager_guard(): + self.func_test_floordiv_different_dtype() + self.func_test_floordiv_different_dtype() + + def func_test_astype(self): a_np = np.random.uniform(-1, 1, self.shape).astype(self.dtype) with fluid.dygraph.guard(): a = fluid.dygraph.to_variable(a_np) @@ -306,7 +442,12 @@ def test_astype(self): self.assertTrue(np.array_equal(res1.numpy(), res2.numpy())) self.assertTrue(np.array_equal(res1.numpy(), res3.numpy())) - def test_conpare_op_broadcast(self): + def test_astype(self): + with _test_eager_guard(): + self.func_test_astype() + self.func_test_astype() + + def func_test_conpare_op_broadcast(self): a_np = np.random.uniform(-1, 1, [10, 1, 10]).astype(self.dtype) b_np = np.random.uniform(-1, 1, [1, 1, 10]).astype(self.dtype) with fluid.dygraph.guard(): @@ -316,7 +457,12 @@ def test_conpare_op_broadcast(self): self.assertEqual((a != b).dtype, fluid.core.VarDesc.VarType.BOOL) self.assertTrue(np.array_equal((a != b).numpy(), a_np != b_np)) - def test_tensor_patch_method(self): + def test_conpare_op_broadcast(self): + with _test_eager_guard(): + self.func_test_conpare_op_broadcast() + self.func_test_conpare_op_broadcast() + + def func_test_tensor_patch_method(self): paddle.disable_static() x_np = np.random.uniform(-1, 1, [2, 3]).astype(self.dtype) y_np = np.random.uniform(-1, 1, [2, 3]).astype(self.dtype) @@ -590,13 +736,23 @@ def test_tensor_patch_method(self): self.assertTrue(inspect.ismethod(a.std)) self.assertTrue(inspect.ismethod(a.numel)) - def test_complex_scalar(self): + def test_tensor_patch_method(self): + with _test_eager_guard(): + self.func_test_tensor_patch_method() + self.func_test_tensor_patch_method() + + def func_test_complex_scalar(self): a_np = np.random.random(self.shape).astype(self.dtype) with fluid.dygraph.guard(): a = fluid.dygraph.to_variable(a_np) res = 1J * a self.assertTrue(np.array_equal(res.numpy(), 1J * a_np)) + def test_complex_scalar(self): + with _test_eager_guard(): + self.func_test_complex_scalar() + self.func_test_complex_scalar() + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_dataset.py 
b/python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_dataset.py index 8f1febcdeddf7..e23905005df56 100755 --- a/python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_dataset.py +++ b/python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_dataset.py @@ -21,6 +21,7 @@ import paddle.fluid as fluid from paddle.io import Dataset, IterableDataset, TensorDataset, \ ComposeDataset, ChainDataset, DataLoader, random_split, Subset +from paddle.fluid.framework import _test_eager_guard, _in_legacy_dygraph IMAGE_SIZE = 32 @@ -76,21 +77,28 @@ def run_main(self, num_workers, places): assert len(label) == 1 assert input.shape == [1, 3, 4] assert label.shape == [1, 1] - assert isinstance(input, paddle.Tensor) - assert isinstance(label, paddle.Tensor) + assert isinstance(input, + (fluid.core.VarBase, fluid.core.eager.Tensor)) + assert isinstance(label, + (fluid.core.VarBase, fluid.core.eager.Tensor)) assert np.allclose(input.numpy(), input_np[i]) assert np.allclose(label.numpy(), label_np[i]) - def test_main(self): + def func_test_main(self): places = [paddle.CPUPlace()] if paddle.is_compiled_with_cuda(): places.append(paddle.CUDAPlace(0)) for p in places: self.run_main(num_workers=0, places=p) + def test_main(self): + with _test_eager_guard(): + self.func_test_main() + self.func_test_main() + class TestComposeDataset(unittest.TestCase): - def test_main(self): + def func_test_main(self): paddle.static.default_startup_program().random_seed = 1 paddle.static.default_main_program().random_seed = 1 @@ -108,9 +116,14 @@ def test_main(self): assert np.allclose(input2, input2_t) assert np.allclose(label2, label2_t) + def test_main(self): + with _test_eager_guard(): + self.func_test_main() + self.func_test_main() + class TestRandomSplitApi(unittest.TestCase): - def test_main(self): + def func_test_main(self): paddle.static.default_startup_program().random_seed = 1 paddle.static.default_main_program().random_seed = 1 @@ -129,9 +142,14 @@ def test_main(self): self.assertTrue(len(elements_list) == 0) + def test_main(self): + with _test_eager_guard(): + self.func_test_main() + self.func_test_main() + class TestRandomSplitError(unittest.TestCase): - def test_errors(self): + def func_test_errors(self): paddle.static.default_startup_program().random_seed = 1 paddle.static.default_main_program().random_seed = 1 @@ -139,6 +157,11 @@ def test_errors(self): self.assertRaises(ValueError, paddle.io.random_split, range(5), [8]) self.assertRaises(ValueError, paddle.io.random_split, range(5), []) + def test_errors(self): + with _test_eager_guard(): + self.func_test_errors() + self.func_test_errors() + class TestSubsetDataset(unittest.TestCase): def run_main(self, num_workers, places): @@ -173,8 +196,10 @@ def assert_basic(input, label): assert len(label) == 1 assert input.shape == [1, 3, 4] assert label.shape == [1, 1] - assert isinstance(input, paddle.Tensor) - assert isinstance(label, paddle.Tensor) + assert isinstance(input, + (fluid.core.VarBase, fluid.core.eager.Tensor)) + assert isinstance(label, + (fluid.core.VarBase, fluid.core.eager.Tensor)) elements_list = list() for _, (input, label) in enumerate(dataloader()): @@ -192,7 +217,7 @@ def assert_basic(input, label): self.assertEqual(odd_list, elements_list) - def test_main(self): + def func_test_main(self): paddle.static.default_startup_program().random_seed = 1 paddle.static.default_main_program().random_seed = 1 @@ -202,6 +227,11 @@ def test_main(self): for p in places: self.run_main(num_workers=0, places=p) + def test_main(self): + with 
_test_eager_guard(): + self.func_test_main() + self.func_test_main() + class TestChainDataset(unittest.TestCase): def run_main(self, num_workers, places): @@ -227,13 +257,18 @@ def run_main(self, num_workers, places): assert np.allclose(label, samples[idx][1]) idx += 1 - def test_main(self): + def func_test_main(self): places = [paddle.CPUPlace()] if paddle.is_compiled_with_cuda(): places.append(paddle.CUDAPlace(0)) for p in places: self.run_main(num_workers=0, places=p) + def test_main(self): + with _test_eager_guard(): + self.func_test_main() + self.func_test_main() + class NumpyMixTensorDataset(Dataset): def __init__(self, sample_num): @@ -269,8 +304,10 @@ def run_main(self, num_workers, places): assert len(label) == 1 assert input.shape == [1, IMAGE_SIZE] assert label.shape == [1, 1] - assert isinstance(input, paddle.Tensor) - assert isinstance(label, paddle.Tensor) + assert isinstance(input, + (fluid.core.VarBase, fluid.core.eager.Tensor)) + assert isinstance(label, + (fluid.core.VarBase, fluid.core.eager.Tensor)) class ComplextDataset(Dataset): @@ -325,10 +362,15 @@ def run_main(self, num_workers): assert data[4]['a'].shape == [2] assert data[4]['b'].shape == [2, 2] - def test_main(self): + def func_test_main(self): for num_workers in [0, 2]: self.run_main(num_workers) + def test_main(self): + with _test_eager_guard(): + self.func_test_main() + self.func_test_main() + class SingleFieldDataset(Dataset): def __init__(self, sample_num): @@ -360,13 +402,19 @@ def run_main(self, num_workers): drop_last=True) for i, data in enumerate(dataloader()): - assert isinstance(data, paddle.Tensor) + assert isinstance(data, + (fluid.core.VarBase, fluid.core.eager.Tensor)) assert data.shape == [2, 2, 3] - def test_main(self): + def func_test_main(self): for num_workers in [0, 2]: self.run_main(num_workers) + def test_main(self): + with _test_eager_guard(): + self.func_test_main() + self.func_test_main() + class SingleFieldIterableDataset(IterableDataset): def __init__(self, sample_num): @@ -390,12 +438,17 @@ def setUp(self): [2834126987, 2358157858, 1860244682, 1437227251], [457190280, 2660306227, 859341110, 354512857]] - def test_main(self): + def func_test_main(self): from paddle.fluid.dataloader.worker import _generate_states for inp, outp in zip(self.inputs, self.outputs): out = _generate_states(*inp) assert out == outp + def test_main(self): + with _test_eager_guard(): + self.func_test_main() + self.func_test_main() + class TestDatasetWithDropLast(unittest.TestCase): def run_main(self, dataset, num_samples, batch_size): @@ -413,14 +466,24 @@ def run_main(self, dataset, num_samples, batch_size): datas.append(data) assert len(datas) == steps - def test_map_dataset(self): + def func_test_map_dataset(self): dataset = RandomDataset(10) self.run_main(dataset, 10, 3) - def test_iterable_dataset(self): + def test_map_dataset(self): + with _test_eager_guard(): + self.func_test_map_dataset() + self.func_test_map_dataset() + + def func_test_iterable_dataset(self): dataset = RandomIterableDataset(10) self.run_main(dataset, 10, 3) + def test_iterable_dataset(self): + with _test_eager_guard(): + self.func_test_iterable_dataset() + self.func_test_iterable_dataset() + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_paddle_multiprocessing.py b/python/paddle/fluid/tests/unittests/test_paddle_multiprocessing.py index 1e31356a6bc81..7825b13001f28 100644 --- a/python/paddle/fluid/tests/unittests/test_paddle_multiprocessing.py +++ 
b/python/paddle/fluid/tests/unittests/test_paddle_multiprocessing.py @@ -19,6 +19,7 @@ import time import paddle import paddle.incubate.multiprocessing as mp +from paddle.fluid.framework import _test_eager_guard, _in_legacy_dygraph, in_dygraph_mode REPEAT = 20 HAS_SHM_FILES = os.path.isdir('/dev/shm') @@ -174,26 +175,54 @@ def test_receive(): class TestMultiprocessingCpu(TestMultiprocessingBase): - def test_pass_tensor(self): + def func_test_pass_tensor(self): + if in_dygraph_mode(): + return paddle.set_device("cpu") self._test_sharing(repeat=REPEAT) - def test_pass_parambase(self): + def test_pass_tensor(self): + with _test_eager_guard(): + self.func_test_pass_tensor() + self.func_test_pass_tensor() + + def func_test_pass_parambase(self): + if in_dygraph_mode(): + return paddle.set_device("cpu") self._test_sharing(repeat=1, param=True) - def test_pass_empty(self): + def test_pass_parambase(self): + with _test_eager_guard(): + self.func_test_pass_parambase() + self.func_test_pass_parambase() + + def func_test_pass_empty(self): + if in_dygraph_mode(): + return paddle.set_device("cpu") self._test_empty() + def test_pass_empty(self): + with _test_eager_guard(): + self.func_test_pass_empty() + self.func_test_pass_empty() + class TestMultiprocessingGpu(TestMultiprocessingBase): @unittest.skipIf(not paddle.fluid.core.is_compiled_with_cuda(), "core is not compiled with CUDA") - def test_pass_tensor(self): + def func_test_pass_tensor(self): + if in_dygraph_mode(): + return paddle.set_device("gpu") self._test_sharing(mp.get_context("spawn"), "gpu") + def test_pass_tensor(self): + with _test_eager_guard(): + self.func_test_pass_tensor() + self.func_test_pass_tensor() + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/nn/initializer/dirac.py b/python/paddle/nn/initializer/dirac.py index 46f47fbc7b639..c7cb1052d2f78 100644 --- a/python/paddle/nn/initializer/dirac.py +++ b/python/paddle/nn/initializer/dirac.py @@ -185,9 +185,10 @@ def __call__(self, var, block=None): if framework.in_dygraph_mode(): with fluid.dygraph.no_grad(): - tmp_tensor = _C_ops.assign_value('shape', [len(idx_list)], - 'dtype', VarDesc.VarType.INT64, - 'int64_values', idx_list) + tmp_tensor = framework._varbase_creator() + _C_ops.assign_value(tmp_tensor, 'shape', [len(idx_list)], + 'dtype', VarDesc.VarType.INT64, + 'int64_values', idx_list) tmp_tensor._share_underline_tensor_to(index_tensor) else: block.append_op( @@ -207,9 +208,10 @@ def __call__(self, var, block=None): if framework.in_dygraph_mode(): with fluid.dygraph.no_grad(): - tmp_tensor = _C_ops.assign_value('shape', [len(value_list)], - 'dtype', VarDesc.VarType.FP32, - 'fp32_values', value_list) + tmp_tensor = framework._varbase_creator() + _C_ops.assign_value(tmp_tensor, 'shape', [len(value_list)], + 'dtype', VarDesc.VarType.FP32, + 'fp32_values', value_list) tmp_tensor._share_underline_tensor_to(value_tensor) else: block.append_op( diff --git a/python/paddle/tensor/linalg.py b/python/paddle/tensor/linalg.py index 81c99c5a41e03..c4814bd2b2f9c 100644 --- a/python/paddle/tensor/linalg.py +++ b/python/paddle/tensor/linalg.py @@ -1126,7 +1126,7 @@ def t(input, name=None): return out -def cross(x, y, axis=None, name=None): +def cross(x, y, axis=9, name=None): """ Computes the cross product between two tensors along an axis. @@ -1136,7 +1136,7 @@ def cross(x, y, axis=None, name=None): Args: x (Tensor): The first input tensor. y (Tensor): The second input tensor. - axis (int, optional): The axis along which to compute the cross product. 
It defaults to the first axis found with the length 3. + axis (int, optional): The axis along which to compute the cross product. It defaults to 9, which means the first axis found with length 3 is used. name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. Returns: diff --git a/python/paddle/tensor/logic.py b/python/paddle/tensor/logic.py index a4ff87246631a..f11e21e65da0b 100755 --- a/python/paddle/tensor/logic.py +++ b/python/paddle/tensor/logic.py @@ -282,8 +282,7 @@ def greater_than(x, y, name=None): print(result1) # result1 = [False False True] """ if in_dygraph_mode(): - axis = -1 # default value - return _C_ops.final_state_greater_than(x, y, axis) + return _C_ops.final_state_greater_than(x, y, -1) else: if _in_legacy_dygraph(): return _C_ops.greater_than(x, y) diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index e4faa573ffb26..5376d393ea432 100755 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -205,13 +205,17 @@ def _elementwise_op_in_dygraph(x, def is_inplace(op_name): return op_name[-1] == "_" - if in_dygraph_mode(): - op = getattr(_C_ops, OP_NAMEMAPPING[op_name] if not is_inplace(op_name) else op_name) - out = op(x, y) - - if _in_legacy_dygraph(): + if op_name not in OP_NAMEMAPPING.keys(): op = getattr(_C_ops, op_name) out = op(x, y, 'axis', axis, 'use_mkldnn', use_mkldnn) + else: + if in_dygraph_mode(): + op = getattr(_C_ops, OP_NAMEMAPPING[op_name] if not is_inplace(op_name) else op_name) + out = op(x, y) + + if _in_legacy_dygraph(): + op = getattr(_C_ops, op_name) + out = op(x, y, 'axis', axis, 'use_mkldnn', use_mkldnn) return dygraph_utils._append_activation_in_dygraph( out, act, use_mkldnn=use_mkldnn) diff --git a/python/paddle/tests/CMakeLists.txt b/python/paddle/tests/CMakeLists.txt index 0babdee3a0884..bc9f402ed9686 100644 --- a/python/paddle/tests/CMakeLists.txt +++ b/python/paddle/tests/CMakeLists.txt @@ -47,7 +47,7 @@ set_tests_properties(test_dataset_cifar PROPERTIES TIMEOUT 120) set_tests_properties(test_pretrained_model PROPERTIES TIMEOUT 120) set_tests_properties(test_model PROPERTIES TIMEOUT 300) set_tests_properties(test_dataset_movielens PROPERTIES TIMEOUT 120) -set_tests_properties(test_datasets PROPERTIES TIMEOUT 150) +set_tests_properties(test_datasets PROPERTIES TIMEOUT 300) set_tests_properties(test_dataset_wmt PROPERTIES TIMEOUT 120) set_tests_properties(test_vision_models PROPERTIES TIMEOUT 120) set_tests_properties(test_dataset_uci_housing PROPERTIES TIMEOUT 120) diff --git a/python/paddle/tests/test_datasets.py b/python/paddle/tests/test_datasets.py index c93bac3ac27e8..be26dff6c0426 100644 --- a/python/paddle/tests/test_datasets.py +++ b/python/paddle/tests/test_datasets.py @@ -22,6 +22,7 @@ import paddle.vision.transforms as T from paddle.vision.datasets import DatasetFolder, ImageFolder, MNIST, FashionMNIST, Flowers from paddle.dataset.common import _check_exists_and_download +from paddle.fluid.framework import _test_eager_guard, _in_legacy_dygraph class TestFolderDatasets(unittest.TestCase): @@ -39,7 +40,7 @@ def setUp(self): def tearDown(self): shutil.rmtree(self.data_dir) - def test_dataset(self): + def func_test_dataset(self): dataset_folder = DatasetFolder(self.data_dir) for _ in dataset_folder: @@ -52,7 +53,12 @@ def test_dataset(self): for _ in dataset_folder: pass - def test_folder(self): + def test_dataset(self): + with _test_eager_guard(): + self.func_test_dataset() + self.func_test_dataset() + + def
func_test_folder(self): loader = ImageFolder(self.data_dir) for _ in loader: @@ -64,7 +70,12 @@ def test_folder(self): assert len(loader) == 4 - def test_transform(self): + def test_folder(self): + with _test_eager_guard(): + self.func_test_folder() + self.func_test_folder() + + def func_test_transform(self): def fake_transform(img): return img @@ -78,7 +89,12 @@ def fake_transform(img): for _ in loader: pass - def test_errors(self): + def test_transform(self): + with _test_eager_guard(): + self.func_test_transform() + self.func_test_transform() + + def func_test_errors(self): with self.assertRaises(RuntimeError): ImageFolder(self.empty_dir) with self.assertRaises(RuntimeError): @@ -87,9 +103,14 @@ def test_errors(self): with self.assertRaises(ValueError): _check_exists_and_download('temp_paddle', None, None, None, False) + def test_errors(self): + with _test_eager_guard(): + self.func_test_errors() + self.func_test_errors() + class TestMNISTTest(unittest.TestCase): - def test_main(self): + def func_test_main(self): transform = T.Transpose() mnist = MNIST(mode='test', transform=transform) self.assertTrue(len(mnist) == 10000) @@ -102,9 +123,14 @@ def test_main(self): self.assertTrue(label.shape[0] == 1) self.assertTrue(0 <= int(label) <= 9) + def test_main(self): + with _test_eager_guard(): + self.func_test_main() + self.func_test_main() + class TestMNISTTrain(unittest.TestCase): - def test_main(self): + def func_test_main(self): transform = T.Transpose() mnist = MNIST(mode='train', transform=transform) self.assertTrue(len(mnist) == 60000) @@ -133,9 +159,14 @@ def test_main(self): with self.assertRaises(ValueError): mnist = MNIST(mode='train', transform=transform, backend=1) + def test_main(self): + with _test_eager_guard(): + self.func_test_main() + self.func_test_main() + class TestFASHIONMNISTTest(unittest.TestCase): - def test_main(self): + def func_test_main(self): transform = T.Transpose() mnist = FashionMNIST(mode='test', transform=transform) self.assertTrue(len(mnist) == 10000) @@ -148,9 +179,14 @@ def test_main(self): self.assertTrue(label.shape[0] == 1) self.assertTrue(0 <= int(label) <= 9) + def test_main(self): + with _test_eager_guard(): + self.func_test_main() + self.func_test_main() + class TestFASHIONMNISTTrain(unittest.TestCase): - def test_main(self): + def func_test_main(self): transform = T.Transpose() mnist = FashionMNIST(mode='train', transform=transform) self.assertTrue(len(mnist) == 60000) @@ -179,16 +215,26 @@ def test_main(self): with self.assertRaises(ValueError): mnist = FashionMNIST(mode='train', transform=transform, backend=1) - def test_dataset_value(self): + def test_main(self): + with _test_eager_guard(): + self.func_test_main() + self.func_test_main() + + def func_test_dataset_value(self): fmnist = FashionMNIST(mode='train') value = np.mean([np.array(x[0]) for x in fmnist]) # 72.94035223214286 was getted from competitive products np.testing.assert_allclose(value, 72.94035223214286) + def test_dataset_value(self): + with _test_eager_guard(): + self.func_test_dataset_value() + self.func_test_dataset_value() + class TestFlowersTrain(unittest.TestCase): - def test_main(self): + def func_test_main(self): flowers = Flowers(mode='train') self.assertTrue(len(flowers) == 6149) @@ -201,9 +247,14 @@ def test_main(self): self.assertTrue(image.shape[2] == 3) self.assertTrue(label.shape[0] == 1) + def test_main(self): + with _test_eager_guard(): + self.func_test_main() + self.func_test_main() + class TestFlowersValid(unittest.TestCase): - def test_main(self): + def 
func_test_main(self): flowers = Flowers(mode='valid') self.assertTrue(len(flowers) == 1020) @@ -216,9 +267,14 @@ def test_main(self): self.assertTrue(image.shape[2] == 3) self.assertTrue(label.shape[0] == 1) + def test_main(self): + with _test_eager_guard(): + self.func_test_main() + self.func_test_main() + class TestFlowersTest(unittest.TestCase): - def test_main(self): + def func_test_main(self): flowers = Flowers(mode='test') self.assertTrue(len(flowers) == 1020) @@ -247,6 +303,11 @@ def test_main(self): with self.assertRaises(ValueError): flowers = Flowers(mode='test', backend=1) + def test_main(self): + with _test_eager_guard(): + self.func_test_main() + self.func_test_main() + if __name__ == '__main__': unittest.main() From b0f8000e141c61dcefc3fe2d0587826f9b515363 Mon Sep 17 00:00:00 2001 From: Zhang Ting Date: Tue, 5 Apr 2022 13:18:23 +0800 Subject: [PATCH 91/93] Implement AutoTuneStatus class for Kernel Auto Tune (#41218) * switch autotune * implement AutoTuneCache * implement AutoTuneCache class * add pybind api * add dygraph test * support static mode and eager mode and improve unittests * rename the SwitchAutoTune Class and improve tests * improve AutoTuneStatus and reduce the cost of tests --- paddle/fluid/imperative/basic_engine.cc | 3 + paddle/fluid/pybind/pybind.cc | 30 ++++ paddle/phi/kernels/autotune/CMakeLists.txt | 4 +- paddle/phi/kernels/autotune/cache.cc | 36 +++++ paddle/phi/kernels/autotune/cache.h | 72 +++++++-- paddle/phi/kernels/autotune/cache_test.cc | 9 +- paddle/phi/kernels/autotune/switch_autotune.h | 130 ++++++++++++++++ python/paddle/fluid/executor.py | 4 +- .../tests/unittests/test_switch_autotune.py | 147 ++++++++++++++++++ 9 files changed, 415 insertions(+), 20 deletions(-) create mode 100644 paddle/phi/kernels/autotune/cache.cc create mode 100644 paddle/phi/kernels/autotune/switch_autotune.h create mode 100644 python/paddle/fluid/tests/unittests/test_switch_autotune.py diff --git a/paddle/fluid/imperative/basic_engine.cc b/paddle/fluid/imperative/basic_engine.cc index d7478b18dba06..ce3c5dd2fe562 100644 --- a/paddle/fluid/imperative/basic_engine.cc +++ b/paddle/fluid/imperative/basic_engine.cc @@ -30,6 +30,7 @@ #include "paddle/fluid/imperative/op_base.h" #include "paddle/fluid/imperative/tracer.h" #include "paddle/fluid/platform/profiler.h" +#include "paddle/phi/kernels/autotune/switch_autotune.h" #include "paddle/phi/kernels/funcs/math_function.h" DECLARE_bool(sort_sum_gradient); @@ -645,6 +646,8 @@ void BasicEngine::Execute() { Clear(); VLOG(1) << "Backward op number: " << op_num; + + phi::autotune::AutoTuneStatus::Instance().Update(); } void BasicEngine::Clear() { diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 982aa52913d63..96d86ee1a3100 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -168,6 +168,8 @@ limitations under the License. */ #include "paddle/fluid/eager/api/utils/global_utils.h" #include "paddle/fluid/pybind/eager_utils.h" #include "paddle/phi/api/ext/op_meta_info.h" +#include "paddle/phi/kernels/autotune/cache.h" +#include "paddle/phi/kernels/autotune/switch_autotune.h" #include "pybind11/stl.h" DECLARE_bool(use_mkldnn); @@ -4419,6 +4421,34 @@ All parameter, weight, gradient are variables in Paddle. 
.def("is_pattern_enabled", &platform::ipu::IpuStrategy::IsPatternEnabled); #endif + m.def("enable_autotune", [] { + return phi::autotune::AutoTuneStatus::Instance().EnableAutoTune(); + }); + + m.def("disable_autotune", [] { + return phi::autotune::AutoTuneStatus::Instance().DisableAutoTune(); + }); + + m.def("autotune_range", [](int64_t start, int64_t stop) { + return phi::autotune::AutoTuneStatus::Instance().SetAutoTuneRange(start, + stop); + }); + + m.def("update_autotune_status", + [] { return phi::autotune::AutoTuneStatus::Instance().Update(); }); + + m.def("autotune_status", [] { + phi::autotune::AutoTuneCache::Instance().UpdateStatus(); + py::dict res; + res["use_autotune"] = + phi::autotune::AutoTuneStatus::Instance().UseAutoTune(); + res["step_id"] = phi::autotune::AutoTuneStatus::Instance().StepID(); + res["cache_size"] = phi::autotune::AutoTuneCache::Instance().Size(); + res["cache_hit_rate"] = + phi::autotune::AutoTuneCache::Instance().CacheHitRate(); + return res; + }); + BindFleetWrapper(&m); BindIO(&m); diff --git a/paddle/phi/kernels/autotune/CMakeLists.txt b/paddle/phi/kernels/autotune/CMakeLists.txt index a3fb9a06fe671..db094d85bf3fd 100644 --- a/paddle/phi/kernels/autotune/CMakeLists.txt +++ b/paddle/phi/kernels/autotune/CMakeLists.txt @@ -6,4 +6,6 @@ elseif (WITH_ROCM) hip_test(auto_tune_test SRCS auto_tune_test.cu DEPS gtest) endif() -cc_test(cache_test SRCS cache_test.cc DEPS gtest) +cc_library(cache SRCS cache.cc DEPS) + +cc_test(cache_test SRCS cache_test.cc DEPS gtest cache) diff --git a/paddle/phi/kernels/autotune/cache.cc b/paddle/phi/kernels/autotune/cache.cc new file mode 100644 index 0000000000000..bf68e2010151b --- /dev/null +++ b/paddle/phi/kernels/autotune/cache.cc @@ -0,0 +1,36 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/phi/kernels/autotune/cache.h" + +namespace phi { +namespace autotune { + +// Define the cache key of an operator +size_t ConvKey(const std::vector<int64_t>& x_dims, + const std::vector<int64_t>& w_dims, + const std::vector<int64_t>& strides, + const std::vector<int64_t>& paddings, + const std::vector<int64_t>& dilations, + phi::DataType dtype) { + return GetKey(x_dims, + w_dims, + strides, + paddings, + dilations, + static_cast<int64_t>(dtype)); +} + +} // namespace autotune +} // namespace phi diff --git a/paddle/phi/kernels/autotune/cache.h b/paddle/phi/kernels/autotune/cache.h index 990843e58f7f2..d492e7c151f91 100644 --- a/paddle/phi/kernels/autotune/cache.h +++ b/paddle/phi/kernels/autotune/cache.h @@ -64,14 +64,7 @@ size_t ConvKey(const std::vector<int64_t>& x_dims, const std::vector<int64_t>& w_dims, const std::vector<int64_t>& strides, const std::vector<int64_t>& paddings, const std::vector<int64_t>& dilations, - phi::DataType dtype) { - return GetKey(x_dims, - w_dims, - strides, - paddings, - dilations, - static_cast<int64_t>(dtype)); -} + phi::DataType dtype); template <typename AlgorithmT> class AlgorithmsCache { @@ -104,14 +97,21 @@ class AlgorithmsCache { hash_[key] = algo; } + int64_t CacheMisses() const { return cache_misses_; } + + int64_t CacheHits() const { return cache_hits_; } + float CacheHitRate() const { int64_t num_accesses = cache_hits_ + cache_misses_; - float cache_hit_rate = - static_cast<float>(cache_hits_) / static_cast<float>(num_accesses); + float cache_hit_rate = 0.; + if (num_accesses != 0) { + cache_hit_rate = + static_cast<float>(cache_hits_) / static_cast<float>(num_accesses); + } return cache_hit_rate; } - int64_t Size() { return hash_.size(); } + int64_t Size() const { return hash_.size(); } private: std::unordered_map<size_t, AlgorithmT> hash_; @@ -142,20 +142,58 @@ class AutoTuneCache { return auto_tune_map_[algo_type]; } - // The number of total config cached - int64_t Size() { - int64_t total = 0; + void Clean(float miss_rate) { + std::lock_guard<std::mutex> lock(*autotune_cache_mutex_); + // Set a small tolerance to avoid performance degradation + // due to large cache size under dynamic shape.
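+ // A miss rate above one percent suggests the recorded configurations are + // rarely reused (e.g. shapes keep changing), so the whole map is dropped + // rather than left to grow.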
+ if (miss_rate > 0.01) { + auto_tune_map_.clear(); + } + } + + void UpdateStatus() { + int64_t size = 0; + int64_t cache_hits = 0; + int64_t cache_misses = 0; for (auto& v : auto_tune_map_) { - VLOG(3) << v.first << " " << v.second.Size(); - total += v.second.Size(); + VLOG(4) << "AlgoType: " << v.first << " Cache Size: " << v.second.Size() + << " Hits: " << v.second.CacheHits() + << " Misses: " << v.second.CacheMisses() + << " Hit Rate: " << v.second.CacheHitRate(); + size += v.second.Size(); + cache_hits += v.second.CacheHits(); + cache_misses += v.second.CacheMisses(); } - return total; + total_size_ = size; + total_cache_hits_ = cache_hits; + total_cache_misses_ = cache_misses; + } + + // The number of total config cached + int64_t Size() const { return total_size_; } + + int64_t CacheHits() const { return total_cache_hits_; } + + int64_t CacheMisses() const { return total_cache_misses_; } + + float CacheHitRate() const { + float total_cache_hit_rate = 0.; + int64_t total_num_accesses = total_cache_hits_ + total_cache_misses_; + if (total_num_accesses != 0) { + total_cache_hit_rate = static_cast<float>(total_cache_hits_) / + static_cast<float>(total_num_accesses); + } + + return total_cache_hit_rate; } private: AutoTuneCache() : autotune_cache_mutex_(new std::mutex()) {} AlgorithmsTypeMap auto_tune_map_; std::shared_ptr<std::mutex> autotune_cache_mutex_; + int64_t total_cache_hits_ = 0; + int64_t total_cache_misses_ = 0; + int64_t total_size_ = 0; }; } // namespace autotune diff --git a/paddle/phi/kernels/autotune/cache_test.cc b/paddle/phi/kernels/autotune/cache_test.cc index 9fcd9b796d0ae..92ba411624fc0 100644 --- a/paddle/phi/kernels/autotune/cache_test.cc +++ b/paddle/phi/kernels/autotune/cache_test.cc @@ -46,8 +46,15 @@ TEST(AlgosCache, AlgosCache) { EXPECT_EQ(cache.Find(key), false); cache.Set(key, ConvAlgos::CuDNNKernel_1); EXPECT_EQ(cache.Size(), 2); - EXPECT_EQ(autotune_cache.Size(), 2); + EXPECT_EQ(cache.CacheHits(), 1); + EXPECT_EQ(cache.CacheMisses(), 2); float cache_hit_rate = static_cast<float>(1) / static_cast<float>(3); EXPECT_LT(std::abs(cache_hit_rate - cache.CacheHitRate()), 1e-5); + + autotune_cache.UpdateStatus(); + EXPECT_EQ(autotune_cache.Size(), 2); + EXPECT_EQ(autotune_cache.CacheHits(), 1); + EXPECT_EQ(autotune_cache.CacheMisses(), 2); + EXPECT_LT(std::abs(cache_hit_rate - autotune_cache.CacheHitRate()), 1e-5); } diff --git a/paddle/phi/kernels/autotune/switch_autotune.h b/paddle/phi/kernels/autotune/switch_autotune.h new file mode 100644 index 0000000000000..2f9621ed2079e --- /dev/null +++ b/paddle/phi/kernels/autotune/switch_autotune.h @@ -0,0 +1,130 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
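+// AutoTuneStatus is the global switch driven by Update(): tuning is enabled +// only for steps in [start_step_id_, stop_step_id_); at stop_step_id_ the +// switch is frozen and the cache is cleaned if the recent miss rate is high.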
+ +#pragma once +#include <cmath> +#include <numeric> +#include <vector> +#include "glog/logging.h" +#include "paddle/phi/kernels/autotune/cache.h" + +namespace phi { +namespace autotune { + +class AutoTuneStatus { + public: + static AutoTuneStatus& Instance() { + static AutoTuneStatus switch_autotune; + return switch_autotune; + } + + bool UseAutoTune() { return use_autotune_; } + + // EnableAutoTune and DisableAutoTune should be used for debugging only. + void EnableAutoTune() { + use_autotune_ = true; + Init(); + } + + void DisableAutoTune() { + use_autotune_ = false; + Init(); + } + + void Update() { + current_steps_id_ += 1; + + if (!use_autotune_ && !update_use_autotune_) { + return; + } + + if (current_steps_id_ < start_step_id_) { + use_autotune_ = false; + } else if (current_steps_id_ >= start_step_id_ && + current_steps_id_ < stop_step_id_) { + use_autotune_ = true; + AutoTuneCache::Instance().UpdateStatus(); + step_hit_rates_.push_back(StepHitRate()); + VLOG(3) << "Step ID " << current_steps_id_ + << ", Accumulative Cache Hit Rate: " + << AutoTuneCache::Instance().CacheHitRate() + << ", Cache Size: " << AutoTuneCache::Instance().Size() + << ", Current Step Hit Rate: " << StepHitRate(); + } else if (current_steps_id_ == stop_step_id_) { + use_autotune_ = false; + update_use_autotune_ = false; + // Clean the cache according to the miss rate + float miss_rate = static_cast<float>(1) - RecentHitRate(); + AutoTuneCache::Instance().Clean(miss_rate); + VLOG(3) << "Recent Miss Rate: " << miss_rate; + } + } + + int64_t StepID() { return current_steps_id_; } + + float RecentHitRate() { + int recent_step_nums = std::ceil(step_hit_rates_.size() * 0.3); + float sum = std::accumulate(step_hit_rates_.rbegin(), + step_hit_rates_.rbegin() + recent_step_nums, + 0.0); + float mean = sum / recent_step_nums; + return mean; + } + + // Hit Rate of Current Step + float StepHitRate() { + int64_t current_hits = AutoTuneCache::Instance().CacheHits(); + int64_t current_misses = AutoTuneCache::Instance().CacheMisses(); + int64_t step_hits_ = current_hits - previous_hits_; + int64_t step_misses_ = current_misses - previous_misses_; + float step_hit_rate = 0.; + int64_t step_num_accesses = step_hits_ + step_misses_; + if (step_num_accesses != 0) { + step_hit_rate = static_cast<float>(step_hits_) / + static_cast<float>(step_num_accesses); + } + previous_hits_ = current_hits; + previous_misses_ = current_misses; + return step_hit_rate; + } + + void SetAutoTuneRange(int64_t start, int64_t stop) { + start_step_id_ = start; + stop_step_id_ = stop; + } + + private: + AutoTuneStatus() = default; + + void Init() { + update_use_autotune_ = use_autotune_; + current_steps_id_ = -1; + previous_hits_ = 0; + previous_misses_ = 0; + step_hit_rates_.clear(); + AutoTuneCache::Instance().Clean(1.0); + } + + int64_t start_step_id_ = 0; + int64_t stop_step_id_ = 10; + int64_t current_steps_id_ = -1; + bool use_autotune_ = false; + bool update_use_autotune_ = false; + int64_t previous_hits_ = 0; + int64_t previous_misses_ = 0; + std::vector<float> step_hit_rates_; +}; + +} // namespace autotune +} // namespace phi diff --git a/python/paddle/fluid/executor.py b/python/paddle/fluid/executor.py index 935f7b53eba57..2232c34e63bd0 100644 --- a/python/paddle/fluid/executor.py +++ b/python/paddle/fluid/executor.py @@ -1276,7 +1276,7 @@ def run(self, """ try: - return self._run_impl( + res = self._run_impl( program=program, feed=feed, fetch_list=fetch_list, @@ -1287,6 +1287,8 @@ def run(self, use_program_cache=use_program_cache, use_prune=use_prune, return_merged=return_merged) +
core.update_autotune_status() + return res except Exception as e: six.reraise(*sys.exc_info()) diff --git a/python/paddle/fluid/tests/unittests/test_switch_autotune.py b/python/paddle/fluid/tests/unittests/test_switch_autotune.py new file mode 100644 index 0000000000000..08cf120a0366e --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_switch_autotune.py @@ -0,0 +1,147 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import unittest +import numpy + + +class SimpleNet(paddle.nn.Layer): + def __init__(self): + super(SimpleNet, self).__init__() + self.conv = paddle.nn.Conv2D(1, 2, (3, 3)) + + def forward(self, image, label=None): + return self.conv(image) + + +def train_dygraph(net, data): + out = net(data) + loss = paddle.mean(out) + adam = paddle.optimizer.Adam(parameters=net.parameters()) + out.backward() + adam.step() + adam.clear_grad() + + +def static_program(net, data): + out = net(data) + loss = paddle.mean(out) + adam = paddle.optimizer.Adam() + adam.minimize(loss) + return loss + + +class TestAutoTune(unittest.TestCase): + def test_autotune(self): + paddle.fluid.core.disable_autotune() + status = paddle.fluid.core.autotune_status() + self.assertEqual(status["use_autotune"], False) + + paddle.fluid.core.enable_autotune() + status = paddle.fluid.core.autotune_status() + self.assertEqual(status["use_autotune"], True) + + def check_status(self, expected_res): + status = paddle.fluid.core.autotune_status() + for key in status.keys(): + self.assertEqual(status[key], expected_res[key]) + + +class TestDygraphAutoTuneStatus(TestAutoTune): + def run_program(self, enable_autotune): + if enable_autotune: + paddle.fluid.core.enable_autotune() + else: + paddle.fluid.core.disable_autotune() + paddle.fluid.core.autotune_range(1, 2) + x_var = paddle.uniform((1, 1, 8, 8), dtype='float32', min=-1., max=1.) 
+ net = SimpleNet() + for i in range(3): + train_dygraph(net, x_var) + if i >= 1 and i < 2: + expected_res = { + "step_id": i, + "use_autotune": enable_autotune, + "cache_size": 0, + "cache_hit_rate": 0 + } + self.check_status(expected_res) + else: + expected_res = { + "step_id": i, + "use_autotune": False, + "cache_size": 0, + "cache_hit_rate": 0 + } + self.check_status(expected_res) + + def test_enable_autotune(self): + self.run_program(enable_autotune=True) + + def test_disable_autotune(self): + self.run_program(enable_autotune=False) + + +class TestStaticAutoTuneStatus(TestAutoTune): + def run_program(self, enable_autotune): + paddle.enable_static() + if enable_autotune: + paddle.fluid.core.enable_autotune() + else: + paddle.fluid.core.disable_autotune() + paddle.fluid.core.autotune_range(1, 2) + + data_shape = [1, 1, 8, 8] + data = paddle.static.data(name='X', shape=data_shape, dtype='float32') + net = SimpleNet() + loss = static_program(net, data) + place = paddle.CUDAPlace(0) if paddle.fluid.core.is_compiled_with_cuda( + ) else paddle.CPUPlace() + exe = paddle.static.Executor(place) + exe.run(paddle.static.default_startup_program()) + x = numpy.random.random(size=data_shape).astype('float32') + + for i in range(3): + exe.run(feed={'X': x}, fetch_list=[loss]) + status = paddle.fluid.core.autotune_status() + # In static mode, the startup_program will run at first. + # The expected step_id will be increased by 1. + if i >= 0 and i < 1: + expected_res = { + "step_id": i + 1, + "use_autotune": enable_autotune, + "cache_size": 0, + "cache_hit_rate": 0 + } + self.check_status(expected_res) + else: + expected_res = { + "step_id": i + 1, + "use_autotune": False, + "cache_size": 0, + "cache_hit_rate": 0 + } + self.check_status(expected_res) + paddle.disable_static() + + def test_enable_autotune(self): + self.run_program(enable_autotune=True) + + def test_disable_autotune(self): + self.run_program(enable_autotune=False) + + +if __name__ == '__main__': + unittest.main() From 510347f95b4d4970d36589665e66c522dd2956b8 Mon Sep 17 00:00:00 2001 From: 0x45f <23097963+0x45f@users.noreply.github.com> Date: Tue, 5 Apr 2022 14:00:22 +0800 Subject: [PATCH 92/93] Fix divide_grad yaml args error (#41406) --- python/paddle/utils/code_gen/backward.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/utils/code_gen/backward.yaml b/python/paddle/utils/code_gen/backward.yaml index 317610679854f..f073529fcd280 100644 --- a/python/paddle/utils/code_gen/backward.yaml +++ b/python/paddle/utils/code_gen/backward.yaml @@ -328,7 +328,7 @@ - backward_api : divide_grad forward : divide (Tensor x, Tensor y) -> Tensor(out) - args : (Tensor x, Tensor y, Tensor out_grad, int axis = -1) + args : (Tensor x, Tensor y, Tensor out, Tensor out_grad, int axis = -1) output : Tensor(x_grad), Tensor(y_grad) infer_meta : func : GeneralBinaryGradInferMeta From 7554f428f59d630283b59dd8cf604062b57cff6a Mon Sep 17 00:00:00 2001 From: RichardWooSJTU <37864677+RichardWooSJTU@users.noreply.github.com> Date: Tue, 5 Apr 2022 14:22:06 +0800 Subject: [PATCH 93/93] Add nms op and batched_nms api (#40962) * add nms op and batched_nms api --- .../fluid/operators/detection/CMakeLists.txt | 1 + paddle/fluid/operators/detection/nms_op.cc | 147 ++++++++++++++ paddle/fluid/operators/detection/nms_op.cu | 108 ++++++++++ paddle/fluid/operators/detection/nms_op.h | 51 +++++ .../fluid/tests/unittests/CMakeLists.txt | 1 + .../fluid/tests/unittests/test_nms_op.py | 92 +++++++++ .../fluid/tests/unittests/test_ops_nms.py | 190 
++++++++++++++++++ python/paddle/vision/ops.py | 149 ++++++++++++++ tools/static_mode_white_list.py | 1 + 9 files changed, 740 insertions(+) create mode 100644 paddle/fluid/operators/detection/nms_op.cc create mode 100644 paddle/fluid/operators/detection/nms_op.cu create mode 100644 paddle/fluid/operators/detection/nms_op.h create mode 100644 python/paddle/fluid/tests/unittests/test_nms_op.py create mode 100644 python/paddle/fluid/tests/unittests/test_ops_nms.py diff --git a/paddle/fluid/operators/detection/CMakeLists.txt b/paddle/fluid/operators/detection/CMakeLists.txt index 568c7982cfc7c..f10c801919999 100644 --- a/paddle/fluid/operators/detection/CMakeLists.txt +++ b/paddle/fluid/operators/detection/CMakeLists.txt @@ -66,6 +66,7 @@ detection_library(yolo_box_op SRCS yolo_box_op.cc) detection_library(box_decoder_and_assign_op SRCS box_decoder_and_assign_op.cc box_decoder_and_assign_op.cu) detection_library(sigmoid_focal_loss_op SRCS sigmoid_focal_loss_op.cc sigmoid_focal_loss_op.cu) detection_library(retinanet_detection_output_op SRCS retinanet_detection_output_op.cc) +detection_library(nms_op SRCS nms_op.cc nms_op.cu) if(WITH_GPU OR WITH_ROCM) set(TMPDEPS memory) diff --git a/paddle/fluid/operators/detection/nms_op.cc b/paddle/fluid/operators/detection/nms_op.cc new file mode 100644 index 0000000000000..f6dc44eb5fc2d --- /dev/null +++ b/paddle/fluid/operators/detection/nms_op.cc @@ -0,0 +1,147 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/detection/nms_op.h" +#include <vector> + +namespace paddle { +namespace operators { + +using framework::Tensor; + +class NMSOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("Boxes", + "(Tensor) " + "Boxes is a Tensor with shape [N, 4] " + "N is the number of boxes " + "in last dimension in format [x1, y1, x2, y2] " + "the relation should be ``0 <= x1 < x2 && 0 <= y1 < y2``."); + + AddOutput("KeepBoxesIdxs", + "(Tensor) " + "KeepBoxesIdxs is a Tensor with shape [N] "); + AddAttr<float>( + "iou_threshold", + "iou_threshold is a threshold value used to suppress similar boxes; " + "boxes with IoU > iou_threshold will be considered as overlapping " + "and just one of them can be kept.") + .SetDefault(1.0f) + .AddCustomChecker([](const float& iou_threshold) { + PADDLE_ENFORCE_LE(iou_threshold, 1.0f, + platform::errors::InvalidArgument( + "iou_threshold should be less than or equal to 1.0, " + "but got %f", + iou_threshold)); + PADDLE_ENFORCE_GE(iou_threshold, 0.0f, + platform::errors::InvalidArgument( + "iou_threshold should be greater than or equal to 0.0, " + "but got %f", + iou_threshold)); + }); + AddComment(R"DOC( + NMS Operator. + This Operator is used to perform Non-Maximum Suppression for input boxes. + Indices of boxes kept by NMS will be sorted by scores and output.
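+ The output KeepBoxesIdxs always has length N; entries after the last kept + index are padded with 0.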
+ )DOC"); + } +}; + +class NMSOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + void InferShape(framework::InferShapeContext* ctx) const override { + OP_INOUT_CHECK(ctx->HasInput("Boxes"), "Input", "Boxes", "NMS"); + OP_INOUT_CHECK(ctx->HasOutput("KeepBoxesIdxs"), "Output", "KeepBoxesIdxs", + "NMS"); + + auto boxes_dim = ctx->GetInputDim("Boxes"); + PADDLE_ENFORCE_EQ(boxes_dim.size(), 2, + platform::errors::InvalidArgument( + "The Input Boxes must be 2-dimention " + "whose shape must be [N, 4] " + "N is the number of boxes " + "in last dimension in format [x1, x2, y1, y2]. ")); + auto num_boxes = boxes_dim[0]; + + ctx->SetOutputDim("KeepBoxesIdxs", {num_boxes}); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "Boxes"), ctx.GetPlace()); + } +}; + +template +static void NMS(const T* boxes_data, int64_t* output_data, float threshold, + int64_t num_boxes) { + auto num_masks = CeilDivide(num_boxes, 64); + std::vector masks(num_masks, 0); + + for (int64_t i = 0; i < num_boxes; ++i) { + if (masks[i / 64] & 1ULL << (i % 64)) continue; + T box_1[4]; + for (int k = 0; k < 4; ++k) { + box_1[k] = boxes_data[i * 4 + k]; + } + for (int64_t j = i + 1; j < num_boxes; ++j) { + if (masks[j / 64] & 1ULL << (j % 64)) continue; + T box_2[4]; + for (int k = 0; k < 4; ++k) { + box_2[k] = boxes_data[j * 4 + k]; + } + bool is_overlap = CalculateIoU(box_1, box_2, threshold); + if (is_overlap) { + masks[j / 64] |= 1ULL << (j % 64); + } + } + } + + int64_t output_data_idx = 0; + for (int64_t i = 0; i < num_boxes; ++i) { + if (masks[i / 64] & 1ULL << (i % 64)) continue; + output_data[output_data_idx++] = i; + } + + for (; output_data_idx < num_boxes; ++output_data_idx) { + output_data[output_data_idx] = 0; + } +} + +template +class NMSKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + const Tensor* boxes = context.Input("Boxes"); + Tensor* output = context.Output("KeepBoxesIdxs"); + int64_t* output_data = output->mutable_data(context.GetPlace()); + auto threshold = context.template Attr("iou_threshold"); + NMS(boxes->data(), output_data, threshold, boxes->dims()[0]); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OPERATOR( + nms, ops::NMSOp, ops::NMSOpMaker, + paddle::framework::EmptyGradOpMaker, + paddle::framework::EmptyGradOpMaker); +REGISTER_OP_CPU_KERNEL(nms, ops::NMSKernel, ops::NMSKernel); diff --git a/paddle/fluid/operators/detection/nms_op.cu b/paddle/fluid/operators/detection/nms_op.cu new file mode 100644 index 0000000000000..b6027e67d6ced --- /dev/null +++ b/paddle/fluid/operators/detection/nms_op.cu @@ -0,0 +1,108 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include +#include "paddle/fluid/operators/detection/nms_op.h" +#include "paddle/fluid/platform/device/gpu/gpu_primitives.h" + +static const int64_t threadsPerBlock = sizeof(int64_t) * 8; + +namespace paddle { +namespace operators { + +using framework::Tensor; + +template +static __global__ void NMS(const T* boxes_data, float threshold, + int64_t num_boxes, uint64_t* masks) { + auto raw_start = blockIdx.y; + auto col_start = blockIdx.x; + if (raw_start > col_start) return; + + const int raw_last_storage = + min(num_boxes - raw_start * threadsPerBlock, threadsPerBlock); + const int col_last_storage = + min(num_boxes - col_start * threadsPerBlock, threadsPerBlock); + + if (threadIdx.x < raw_last_storage) { + uint64_t mask = 0; + auto current_box_idx = raw_start * threadsPerBlock + threadIdx.x; + const T* current_box = boxes_data + current_box_idx * 4; + for (int i = 0; i < col_last_storage; ++i) { + const T* target_box = boxes_data + (col_start * threadsPerBlock + i) * 4; + if (CalculateIoU(current_box, target_box, threshold)) { + mask |= 1ULL << i; + } + } + const int blocks_per_line = CeilDivide(num_boxes, threadsPerBlock); + masks[current_box_idx * blocks_per_line + col_start] = mask; + } +} + +template +class NMSCudaKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + const Tensor* boxes = context.Input("Boxes"); + Tensor* output = context.Output("KeepBoxesIdxs"); + auto* output_data = output->mutable_data(context.GetPlace()); + + auto threshold = context.template Attr("iou_threshold"); + const int64_t num_boxes = boxes->dims()[0]; + const auto blocks_per_line = CeilDivide(num_boxes, threadsPerBlock); + + dim3 block(threadsPerBlock); + dim3 grid(blocks_per_line, blocks_per_line); + + auto mask_data = + memory::Alloc(context.cuda_device_context(), + num_boxes * blocks_per_line * sizeof(uint64_t)); + uint64_t* mask_dev = reinterpret_cast(mask_data->ptr()); + NMS<<>>( + boxes->data(), threshold, num_boxes, mask_dev); + + std::vector mask_host(num_boxes * blocks_per_line); + memory::Copy(platform::CPUPlace(), mask_host.data(), context.GetPlace(), + mask_dev, num_boxes * blocks_per_line * sizeof(uint64_t), + context.cuda_device_context().stream()); + + std::vector remv(blocks_per_line); + + std::vector keep_boxes_idxs(num_boxes); + int64_t* output_host = keep_boxes_idxs.data(); + + int64_t last_box_num = 0; + for (int64_t i = 0; i < num_boxes; ++i) { + auto remv_element_id = i / threadsPerBlock; + auto remv_bit_id = i % threadsPerBlock; + if (!(remv[remv_element_id] & 1ULL << remv_bit_id)) { + output_host[last_box_num++] = i; + uint64_t* current_mask = mask_host.data() + i * blocks_per_line; + for (auto j = remv_element_id; j < blocks_per_line; ++j) { + remv[j] |= current_mask[j]; + } + } + } + memory::Copy(context.GetPlace(), output_data, platform::CPUPlace(), + output_host, sizeof(int64_t) * num_boxes, + context.cuda_device_context().stream()); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_CUDA_KERNEL(nms, ops::NMSCudaKernel, + ops::NMSCudaKernel); diff --git a/paddle/fluid/operators/detection/nms_op.h b/paddle/fluid/operators/detection/nms_op.h new file mode 100644 index 0000000000000..dce8f47f0174e --- /dev/null +++ b/paddle/fluid/operators/detection/nms_op.h @@ -0,0 +1,51 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" + +namespace paddle { +namespace operators { + +HOSTDEVICE static inline int64_t CeilDivide(int64_t n, int64_t m) { + return (n + m - 1) / m; +} + +template +HOSTDEVICE inline bool CalculateIoU(const T* const box_1, const T* const box_2, + const float threshold) { + auto box_1_x0 = box_1[0], box_1_y0 = box_1[1]; + auto box_1_x1 = box_1[2], box_1_y1 = box_1[3]; + auto box_2_x0 = box_2[0], box_2_y0 = box_2[1]; + auto box_2_x1 = box_2[2], box_2_y1 = box_2[3]; + + auto inter_box_x0 = box_1_x0 > box_2_x0 ? box_1_x0 : box_2_x0; + auto inter_box_y0 = box_1_y0 > box_2_y0 ? box_1_y0 : box_2_y0; + auto inter_box_x1 = box_1_x1 < box_2_x1 ? box_1_x1 : box_2_x1; + auto inter_box_y1 = box_1_y1 < box_2_y1 ? box_1_y1 : box_2_y1; + + auto inter_width = + inter_box_x1 - inter_box_x0 > 0 ? inter_box_x1 - inter_box_x0 : 0; + auto inter_height = + inter_box_y1 - inter_box_y0 > 0 ? inter_box_y1 - inter_box_y0 : 0; + auto inter_area = inter_width * inter_height; + auto union_area = (box_1_x1 - box_1_x0) * (box_1_y1 - box_1_y0) + + (box_2_x1 - box_2_x0) * (box_2_y1 - box_2_y0) - inter_area; + return inter_area / union_area > threshold; +} + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index 8b84a9c524adf..b4d6f9b941d4f 100755 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -234,6 +234,7 @@ endif() if(WIN32) LIST(REMOVE_ITEM TEST_OPS test_complex_matmul) + LIST(REMOVE_ITEM TEST_OPS test_ops_nms) endif() LIST(REMOVE_ITEM TEST_OPS test_fleet_checkpoint) diff --git a/python/paddle/fluid/tests/unittests/test_nms_op.py b/python/paddle/fluid/tests/unittests/test_nms_op.py new file mode 100644 index 0000000000000..1b5ac1f1337d0 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_nms_op.py @@ -0,0 +1,92 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
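+ +# The reference implementations below mirror the C++ kernel: iou() computes +# the plain intersection-over-union of two [x1, y1, x2, y2] boxes, and nms() +# greedily keeps a box only if its IoU with every previously kept box is at or +# below the threshold, recording the kept indices in order.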
+ +import unittest +import numpy as np +from op_test import OpTest + + +def iou(box_a, box_b): + """Compute the intersection-over-union overlap between box_a and box_b + """ + xmin_a = min(box_a[0], box_a[2]) + ymin_a = min(box_a[1], box_a[3]) + xmax_a = max(box_a[0], box_a[2]) + ymax_a = max(box_a[1], box_a[3]) + + xmin_b = min(box_b[0], box_b[2]) + ymin_b = min(box_b[1], box_b[3]) + xmax_b = max(box_b[0], box_b[2]) + ymax_b = max(box_b[1], box_b[3]) + + area_a = (ymax_a - ymin_a) * (xmax_a - xmin_a) + area_b = (ymax_b - ymin_b) * (xmax_b - xmin_b) + if area_a <= 0 and area_b <= 0: + return 0.0 + + xa = max(xmin_a, xmin_b) + ya = max(ymin_a, ymin_b) + xb = min(xmax_a, xmax_b) + yb = min(ymax_a, ymax_b) + + inter_area = max(xb - xa, 0.0) * max(yb - ya, 0.0) + + iou_ratio = inter_area / (area_a + area_b - inter_area) + return iou_ratio + + +def nms(boxes, nms_threshold): + selected_indices = np.zeros(boxes.shape[0], dtype=np.int64) + keep = np.ones(boxes.shape[0], dtype=int) + io_ratio = np.ones((boxes.shape[0], boxes.shape[0]), dtype=np.float64) + cnt = 0 + for i in range(boxes.shape[0]): + if keep[i] == 0: + continue + selected_indices[cnt] = i + cnt += 1 + for j in range(i + 1, boxes.shape[0]): + io_ratio[i][j] = iou(boxes[i], boxes[j]) + if keep[j]: + overlap = iou(boxes[i], boxes[j]) + keep[j] = 1 if overlap <= nms_threshold else 0 + else: + continue + + return selected_indices + + +class TestNMSOp(OpTest): + def setUp(self): + self.op_type = 'nms' + self.dtype = np.float64 + self.init_dtype_type() + boxes = np.random.rand(32, 4).astype(self.dtype) + boxes[:, 2] = boxes[:, 0] + boxes[:, 2] + boxes[:, 3] = boxes[:, 1] + boxes[:, 3] + + self.inputs = {'Boxes': boxes} + self.attrs = {'iou_threshold': 0.5} + out_py = nms(boxes, self.attrs['iou_threshold']) + self.outputs = {'KeepBoxesIdxs': out_py} + + def init_dtype_type(self): + pass + + def test_check_output(self): + self.check_output() + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_ops_nms.py b/python/paddle/fluid/tests/unittests/test_ops_nms.py new file mode 100644 index 0000000000000..c0bbe82d3581a --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_ops_nms.py @@ -0,0 +1,190 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import numpy as np +import paddle +from test_nms_op import nms + + +def _find(condition): + """ + Find the indices of elements that satisfy the condition. + + Args: + condition(Tensor[N] or np.ndarray([N,])): Element should be bool type. + + Returns: + Tensor: Indices of True elements.
+ """ + res = [] + for i in range(condition.shape[0]): + if condition[i]: + res.append(i) + return np.array(res) + + +def multiclass_nms(boxes, scores, category_idxs, iou_threshold, top_k): + mask = np.zeros_like(scores) + + for category_id in np.unique(category_idxs): + cur_category_boxes_idxs = _find(category_idxs == category_id) + cur_category_boxes = boxes[cur_category_boxes_idxs] + cur_category_scores = scores[cur_category_boxes_idxs] + cur_category_sorted_indices = np.argsort(-cur_category_scores) + cur_category_sorted_boxes = cur_category_boxes[ + cur_category_sorted_indices] + + cur_category_keep_boxes_sub_idxs = cur_category_sorted_indices[nms( + cur_category_sorted_boxes, iou_threshold)] + + mask[cur_category_boxes_idxs[cur_category_keep_boxes_sub_idxs]] = True + + keep_boxes_idxs = _find(mask == True) + topK_sub_indices = np.argsort(-scores[keep_boxes_idxs])[:top_k] + return keep_boxes_idxs[topK_sub_indices] + + +def gen_args(num_boxes, dtype): + boxes = np.random.rand(num_boxes, 4).astype(dtype) + boxes[:, 2] = boxes[:, 0] + boxes[:, 2] + boxes[:, 3] = boxes[:, 1] + boxes[:, 3] + + scores = np.random.rand(num_boxes).astype(dtype) + + categories = [0, 1, 2, 3] + category_idxs = np.random.choice(categories, num_boxes) + + return boxes, scores, category_idxs, categories + + +class TestOpsNMS(unittest.TestCase): + def setUp(self): + self.num_boxes = 64 + self.threshold = 0.5 + self.topk = 20 + self.dtypes = ['float32'] + self.devices = ['cpu'] + if paddle.is_compiled_with_cuda(): + self.devices.append('gpu') + + def test_nms(self): + for device in self.devices: + for dtype in self.dtypes: + boxes, scores, category_idxs, categories = gen_args( + self.num_boxes, dtype) + paddle.set_device(device) + out = paddle.vision.ops.nms( + paddle.to_tensor(boxes), self.threshold, + paddle.to_tensor(scores)) + out = paddle.vision.ops.nms( + paddle.to_tensor(boxes), self.threshold) + out_py = nms(boxes, self.threshold) + + self.assertTrue( + np.array_equal(out.numpy(), out_py), + "paddle out: {}\n py out: {}\n".format(out, out_py)) + + def test_multiclass_nms_dynamic(self): + for device in self.devices: + for dtype in self.dtypes: + boxes, scores, category_idxs, categories = gen_args( + self.num_boxes, dtype) + paddle.set_device(device) + out = paddle.vision.ops.nms( + paddle.to_tensor(boxes), self.threshold, + paddle.to_tensor(scores), + paddle.to_tensor(category_idxs), categories, self.topk) + out_py = multiclass_nms(boxes, scores, category_idxs, + self.threshold, self.topk) + + self.assertTrue( + np.array_equal(out.numpy(), out_py), + "paddle out: {}\n py out: {}\n".format(out, out_py)) + + def test_multiclass_nms_static(self): + for device in self.devices: + for dtype in self.dtypes: + paddle.enable_static() + boxes, scores, category_idxs, categories = gen_args( + self.num_boxes, dtype) + boxes_static = paddle.static.data( + shape=boxes.shape, dtype=boxes.dtype, name="boxes") + scores_static = paddle.static.data( + shape=scores.shape, dtype=scores.dtype, name="scores") + category_idxs_static = paddle.static.data( + shape=category_idxs.shape, + dtype=category_idxs.dtype, + name="category_idxs") + out = paddle.vision.ops.nms(boxes_static, self.threshold, + scores_static, category_idxs_static, + categories, self.topk) + place = paddle.CPUPlace() + if device == 'gpu': + place = paddle.CUDAPlace(0) + exe = paddle.static.Executor(place) + out = exe.run(paddle.static.default_main_program(), + feed={ + 'boxes': boxes, + 'scores': scores, + 'category_idxs': category_idxs + }, + fetch_list=[out]) + 
paddle.disable_static() + out_py = multiclass_nms(boxes, scores, category_idxs, + self.threshold, self.topk) + out = np.array(out) + out = np.squeeze(out) + self.assertTrue( + np.array_equal(out, out_py), + "paddle out: {}\n py out: {}\n".format(out, out_py)) + + def test_multiclass_nms_dynamic_to_static(self): + for device in self.devices: + for dtype in self.dtypes: + paddle.set_device(device) + + def fun(x): + scores = np.arange(0, 64).astype('float32') + categories = np.array([0, 1, 2, 3]) + category_idxs = categories.repeat(16) + out = paddle.vision.ops.nms(x, 0.1, + paddle.to_tensor(scores), + paddle.to_tensor(category_idxs), + categories, 10) + return out + + path = "./net" + boxes = np.random.rand(64, 4).astype('float32') + boxes[:, 2] = boxes[:, 0] + boxes[:, 2] + boxes[:, 3] = boxes[:, 1] + boxes[:, 3] + + origin = fun(paddle.to_tensor(boxes)) + paddle.jit.save( + fun, + path, + input_spec=[ + paddle.static.InputSpec( + shape=[None, 4], dtype='float32', name='x') + ], ) + load_func = paddle.jit.load(path) + res = load_func(paddle.to_tensor(boxes)) + self.assertTrue( + np.array_equal(origin, res), + "origin out: {}\n inference model out: {}\n".format(origin, + res)) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/vision/ops.py b/python/paddle/vision/ops.py index b510b7c8bdfe8..7797909e3b52c 100644 --- a/python/paddle/vision/ops.py +++ b/python/paddle/vision/ops.py @@ -36,6 +36,7 @@ 'PSRoIPool', 'roi_align', 'RoIAlign', + 'nms', ] @@ -1357,3 +1358,151 @@ def __init__(self, if activation_layer is not None: layers.append(activation_layer()) super().__init__(*layers) + + +def nms(boxes, + iou_threshold=0.3, + scores=None, + category_idxs=None, + categories=None, + top_k=None): + r""" + This operator implements non-maximum suppression. Non-maximum suppression (NMS) + is used to select one bounding box out of many overlapping bounding boxes in object detection. + Boxes with IoU > iou_threshold will be considered as overlapping boxes, + and only the one with the highest score is kept. Here IoU is Intersection Over Union, + which can be computed by: + + .. math:: + + IoU = \frac{intersection\_area(box1, box2)}{union\_area(box1, box2)} + + If scores are provided, input boxes will be sorted by their scores firstly. + If category_idxs and categories are provided, NMS will be performed in a batched style, + which means NMS will be applied to each category respectively and results of each category + will be concatenated and sorted by scores. + If top_k is provided, only the first top_k elements will be returned. Otherwise, all box indices sorted by scores will be returned. + + Args: + boxes(Tensor): The input boxes data to be computed, it's a 2D-Tensor with + the shape of [num_boxes, 4] and boxes should be sorted by their + confidence scores. The data type is float32 or float64. + Given as [[x1, y1, x2, y2], …], (x1, y1) is the top left coordinates, + and (x2, y2) is the bottom right coordinates. + Their relation should be ``0 <= x1 < x2 && 0 <= y1 < y2``. + iou_threshold(float32): IoU threshold for determining overlapping boxes. Default value: 0.3. + scores(Tensor, optional): Scores corresponding to boxes, it's a 1D-Tensor with + shape of [num_boxes]. The data type is float32 or float64. + category_idxs(Tensor, optional): Category indices corresponding to boxes. + it's a 1D-Tensor with shape of [num_boxes]. The data type is int64. + categories(List, optional): A list of the unique ids of all categories. The data type is int64.
+ top_k(int64, optional): The number of boxes with the highest scores to keep after NMS. + top_k should be no larger than num_boxes. + + Returns: + Tensor: 1D-Tensor with the shape of [num_boxes]. Indices of boxes kept by NMS. + + Examples: + .. code-block:: python + + import paddle + import numpy as np + + boxes = np.random.rand(4, 4).astype('float32') + boxes[:, 2] = boxes[:, 0] + boxes[:, 2] + boxes[:, 3] = boxes[:, 1] + boxes[:, 3] + # [[0.06287421 0.5809351 0.3443958 0.8713329 ] + # [0.0749094 0.9713205 0.99241287 1.2799143 ] + # [0.46246734 0.6753201 1.346266 1.3821303 ] + # [0.8984796 0.5619834 1.1254641 1.0201943 ]] + + out = paddle.vision.ops.nms(paddle.to_tensor(boxes), 0.1) + # [0, 1, 3, 0] + + scores = np.random.rand(4).astype('float32') + # [0.98015213 0.3156527 0.8199343 0.874901 ] + + categories = [0, 1, 2, 3] + category_idxs = np.random.choice(categories, 4) + # [2 0 0 3] + + out = paddle.vision.ops.nms(paddle.to_tensor(boxes), + 0.1, + paddle.to_tensor(scores), + paddle.to_tensor(category_idxs), + categories, + 4) + # [0, 3, 2] + """ + + def _nms(boxes, iou_threshold): + if _non_static_mode(): + return _C_ops.nms(boxes, 'iou_threshold', iou_threshold) + + helper = LayerHelper('nms', **locals()) + out = helper.create_variable_for_type_inference('int64') + helper.append_op( + type='nms', + inputs={'Boxes': boxes}, + outputs={'KeepBoxesIdxs': out}, + attrs={'iou_threshold': iou_threshold}) + return out + + if scores is None: + return _nms(boxes, iou_threshold) + + import paddle + if category_idxs is None: + sorted_global_indices = paddle.argsort(scores, descending=True) + return _nms(boxes[sorted_global_indices], iou_threshold) + + if top_k is not None: + assert top_k <= scores.shape[ + 0], "top_k should be no larger than the number of boxes" + assert categories is not None, "if category_idxs is given, categories, a list of the unique ids of all categories, is required" + + mask = paddle.zeros_like(scores, dtype=paddle.int32) + + for category_id in categories: + cur_category_boxes_idxs = paddle.where(category_idxs == category_id)[0] + shape = cur_category_boxes_idxs.shape[0] + cur_category_boxes_idxs = paddle.reshape(cur_category_boxes_idxs, + [shape]) + if shape == 0: + continue + elif shape == 1: + mask[cur_category_boxes_idxs] = 1 + continue + cur_category_boxes = boxes[cur_category_boxes_idxs] + cur_category_scores = scores[cur_category_boxes_idxs] + cur_category_sorted_indices = paddle.argsort( + cur_category_scores, descending=True) + cur_category_sorted_boxes = cur_category_boxes[ + cur_category_sorted_indices] + + cur_category_keep_boxes_sub_idxs = cur_category_sorted_indices[_nms( + cur_category_sorted_boxes, iou_threshold)] + + updates = paddle.ones_like( + cur_category_boxes_idxs[cur_category_keep_boxes_sub_idxs], + dtype=paddle.int32) + mask = paddle.scatter( + mask, + cur_category_boxes_idxs[cur_category_keep_boxes_sub_idxs], + updates, + overwrite=True) + keep_boxes_idxs = paddle.where(mask)[0] + shape = keep_boxes_idxs.shape[0] + keep_boxes_idxs = paddle.reshape(keep_boxes_idxs, [shape]) + sorted_sub_indices = paddle.argsort( + scores[keep_boxes_idxs], descending=True) + + if top_k is None: + return keep_boxes_idxs[sorted_sub_indices] + + if _non_static_mode(): + top_k = shape if shape < top_k else top_k + _, topk_sub_indices = paddle.topk(scores[keep_boxes_idxs], top_k) + return keep_boxes_idxs[topk_sub_indices] + + return keep_boxes_idxs[sorted_sub_indices][:top_k] diff --git a/tools/static_mode_white_list.py
b/tools/static_mode_white_list.py index 365047f7e8382..f907d51e4d038 100755 --- a/tools/static_mode_white_list.py +++ b/tools/static_mode_white_list.py @@ -349,6 +349,7 @@ 'test_nearest_interp_v2_op', 'test_network_with_dtype', 'test_nll_loss', + 'test_nms_op', 'test_nn_functional_embedding_static', 'test_nn_functional_hot_op', 'test_nonzero_api',
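
For reviewers who want to exercise the auto-tune switch end to end, here is a minimal dynamic-graph sketch assembled from the pieces above. It is an illustration, not part of the patch: the paddle.fluid.core entry points come from the pybind.cc bindings added in PATCH 91, and the tiny conv net mirrors test_switch_autotune.py.

    import paddle

    # Tune kernels only on step 1; steps are counted by AutoTuneStatus::Update(),
    # which the patch calls at the end of every dynamic-graph backward pass.
    paddle.fluid.core.enable_autotune()
    paddle.fluid.core.autotune_range(1, 2)

    net = paddle.nn.Conv2D(1, 2, (3, 3))
    adam = paddle.optimizer.Adam(parameters=net.parameters())
    for step in range(3):
        x = paddle.uniform((1, 1, 8, 8), dtype='float32', min=-1., max=1.)
        loss = paddle.mean(net(x))
        loss.backward()
        adam.step()
        adam.clear_grad()
        # Dict with use_autotune / step_id / cache_size / cache_hit_rate.
        print(paddle.fluid.core.autotune_status())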