diff --git a/paddle/fluid/pir/dialect/op_generator/python_c_gen.py b/paddle/fluid/pir/dialect/op_generator/python_c_gen.py index 9ff34635406997..c4784bc64d8d7d 100644 --- a/paddle/fluid/pir/dialect/op_generator/python_c_gen.py +++ b/paddle/fluid/pir/dialect/op_generator/python_c_gen.py @@ -894,3 +894,5 @@ def ParseArguments(): python_c_def_h_file, python_c_def_cc_file, ) + +# diff --git a/paddle/phi/ops/yaml/python_api_info.yaml b/paddle/phi/ops/yaml/python_api_info.yaml index 0ded669db2248e..430d9a804bdae5 100644 --- a/paddle/phi/ops/yaml/python_api_info.yaml +++ b/paddle/phi/ops/yaml/python_api_info.yaml @@ -8,10 +8,10 @@ args_alias : use_default_mapping : True -- op : matmul - name : [paddle.matmul,paddle.Tensor.matmul] - args_alias : - use_default_mapping : True +# - op : matmul +# name : [paddle.matmul,paddle.Tensor.matmul] +# args_alias : +# use_default_mapping : True - op : multiply name : [paddle.multiply,paddle.Tensor.multiply] args_alias : diff --git a/python/paddle/_paddle_docs.py b/python/paddle/_paddle_docs.py index abb99cb9e03e90..f8bdb36a2998b6 100644 --- a/python/paddle/_paddle_docs.py +++ b/python/paddle/_paddle_docs.py @@ -522,111 +522,111 @@ def argmin( """, ) -add_doc_and_signature( - "matmul", - """ - Applies matrix multiplication to two tensors. `matmul` follows - the complete broadcast rules, - and its behavior is consistent with `np.matmul`. - - Currently, the input tensors' number of dimensions can be any, `matmul` can be used to - achieve the `dot`, `matmul` and `batchmatmul`. - - The actual behavior depends on the shapes of :math:`x`, :math:`y` and the - flag values of :attr:`transpose_x`, :attr:`transpose_y`. Specifically: - - - If a transpose flag is specified, the last two dimensions of the tensor - are transposed. If the tensor is ndim-1 of shape, the transpose is invalid. If the tensor - is ndim-1 of shape :math:`[D]`, then for :math:`x` it is treated as :math:`[1, D]`, whereas - for :math:`y` it is the opposite: It is treated as :math:`[D, 1]`. - - The multiplication behavior depends on the dimensions of `x` and `y`. Specifically: - - - If both tensors are 1-dimensional, the dot product result is obtained. - - - If both tensors are 2-dimensional, the matrix-matrix product is obtained. - - - If the `x` is 1-dimensional and the `y` is 2-dimensional, - a `1` is prepended to its dimension in order to conduct the matrix multiply. - After the matrix multiply, the prepended dimension is removed. - - - If the `x` is 2-dimensional and `y` is 1-dimensional, - the matrix-vector product is obtained. - - - If both arguments are at least 1-dimensional and at least one argument - is N-dimensional (where N > 2), then a batched matrix multiply is obtained. - If the first argument is 1-dimensional, a 1 is prepended to its dimension - in order to conduct the batched matrix multiply and removed after. - If the second argument is 1-dimensional, a 1 is appended to its - dimension for the purpose of the batched matrix multiple and removed after. - The non-matrix (exclude the last two dimensions) dimensions are - broadcasted according the broadcast rule. - For example, if input is a (j, 1, n, m) tensor and the other is a (k, m, p) tensor, - out will be a (j, k, n, p) tensor. - - Args: - x (Tensor): The input tensor which is a Tensor. - y (Tensor): The input tensor which is a Tensor. - transpose_x (bool, optional): Whether to transpose :math:`x` before multiplication. Default is False. - transpose_y (bool, optional): Whether to transpose :math:`y` before multiplication. 
Default is False. - name (str|None, optional): If set None, the layer will be named automatically. For more information, please refer to :ref:`api_guide_Name`. Default is None. - out (Tensor, optional): The output tensor. If set, the result will be stored in this tensor. Default is None. - - Returns: - Tensor: The output Tensor. - - Examples: - - .. code-block:: python - - >>> import paddle - - >>> # vector * vector - >>> x = paddle.rand([10]) - >>> y = paddle.rand([10]) - >>> z = paddle.matmul(x, y) - >>> print(z.shape) - [] - - >>> # matrix * vector - >>> x = paddle.rand([10, 5]) - >>> y = paddle.rand([5]) - >>> z = paddle.matmul(x, y) - >>> print(z.shape) - [10] - - >>> # batched matrix * broadcasted vector - >>> x = paddle.rand([10, 5, 2]) - >>> y = paddle.rand([2]) - >>> z = paddle.matmul(x, y) - >>> print(z.shape) - [10, 5] - - >>> # batched matrix * batched matrix - >>> x = paddle.rand([10, 5, 2]) - >>> y = paddle.rand([10, 2, 5]) - >>> z = paddle.matmul(x, y) - >>> print(z.shape) - [10, 5, 5] - - >>> # batched matrix * broadcasted matrix - >>> x = paddle.rand([10, 1, 5, 2]) - >>> y = paddle.rand([1, 3, 2, 5]) - >>> z = paddle.matmul(x, y) - >>> print(z.shape) - [10, 3, 5, 5] - - """, - """ def matmul( - x: Tensor, - y: Tensor, - transpose_x: bool = False, - transpose_y: bool = False, - name: str | None = None, - *, - out: Tensor | None = None, -) -> Tensor""", -) +# add_doc_and_signature( +# "matmul", +# """ +# Applies matrix multiplication to two tensors. `matmul` follows +# the complete broadcast rules, +# and its behavior is consistent with `np.matmul`. + +# Currently, the input tensors' number of dimensions can be any, `matmul` can be used to +# achieve the `dot`, `matmul` and `batchmatmul`. + +# The actual behavior depends on the shapes of :math:`x`, :math:`y` and the +# flag values of :attr:`transpose_x`, :attr:`transpose_y`. Specifically: + +# - If a transpose flag is specified, the last two dimensions of the tensor +# are transposed. If the tensor is ndim-1 of shape, the transpose is invalid. If the tensor +# is ndim-1 of shape :math:`[D]`, then for :math:`x` it is treated as :math:`[1, D]`, whereas +# for :math:`y` it is the opposite: It is treated as :math:`[D, 1]`. + +# The multiplication behavior depends on the dimensions of `x` and `y`. Specifically: + +# - If both tensors are 1-dimensional, the dot product result is obtained. + +# - If both tensors are 2-dimensional, the matrix-matrix product is obtained. + +# - If the `x` is 1-dimensional and the `y` is 2-dimensional, +# a `1` is prepended to its dimension in order to conduct the matrix multiply. +# After the matrix multiply, the prepended dimension is removed. + +# - If the `x` is 2-dimensional and `y` is 1-dimensional, +# the matrix-vector product is obtained. + +# - If both arguments are at least 1-dimensional and at least one argument +# is N-dimensional (where N > 2), then a batched matrix multiply is obtained. +# If the first argument is 1-dimensional, a 1 is prepended to its dimension +# in order to conduct the batched matrix multiply and removed after. +# If the second argument is 1-dimensional, a 1 is appended to its +# dimension for the purpose of the batched matrix multiple and removed after. +# The non-matrix (exclude the last two dimensions) dimensions are +# broadcasted according the broadcast rule. +# For example, if input is a (j, 1, n, m) tensor and the other is a (k, m, p) tensor, +# out will be a (j, k, n, p) tensor. + +# Args: +# x (Tensor): The input tensor which is a Tensor. 
+# y (Tensor): The input tensor which is a Tensor. +# transpose_x (bool, optional): Whether to transpose :math:`x` before multiplication. Default is False. +# transpose_y (bool, optional): Whether to transpose :math:`y` before multiplication. Default is False. +# name (str|None, optional): If set None, the layer will be named automatically. For more information, please refer to :ref:`api_guide_Name`. Default is None. +# out (Tensor, optional): The output tensor. If set, the result will be stored in this tensor. Default is None. + +# Returns: +# Tensor: The output Tensor. + +# Examples: + +# .. code-block:: python + +# >>> import paddle + +# >>> # vector * vector +# >>> x = paddle.rand([10]) +# >>> y = paddle.rand([10]) +# >>> z = paddle.matmul(x, y) +# >>> print(z.shape) +# [] + +# >>> # matrix * vector +# >>> x = paddle.rand([10, 5]) +# >>> y = paddle.rand([5]) +# >>> z = paddle.matmul(x, y) +# >>> print(z.shape) +# [10] + +# >>> # batched matrix * broadcasted vector +# >>> x = paddle.rand([10, 5, 2]) +# >>> y = paddle.rand([2]) +# >>> z = paddle.matmul(x, y) +# >>> print(z.shape) +# [10, 5] + +# >>> # batched matrix * batched matrix +# >>> x = paddle.rand([10, 5, 2]) +# >>> y = paddle.rand([10, 2, 5]) +# >>> z = paddle.matmul(x, y) +# >>> print(z.shape) +# [10, 5, 5] + +# >>> # batched matrix * broadcasted matrix +# >>> x = paddle.rand([10, 1, 5, 2]) +# >>> y = paddle.rand([1, 3, 2, 5]) +# >>> z = paddle.matmul(x, y) +# >>> print(z.shape) +# [10, 3, 5, 5] + +# """, +# """ def matmul( +# x: Tensor, +# y: Tensor, +# transpose_x: bool = False, +# transpose_y: bool = False, +# name: str | None = None, +# *, +# out: Tensor | None = None, +# ) -> Tensor""", +# ) add_doc_and_signature( "multiply", """ diff --git a/python/paddle/tensor/linalg.py b/python/paddle/tensor/linalg.py index 4f6969262833f6..801846c96a0fe0 100644 --- a/python/paddle/tensor/linalg.py +++ b/python/paddle/tensor/linalg.py @@ -21,13 +21,14 @@ import paddle from paddle import _C_ops -from paddle._C_ops import bmm, matmul # noqa: F401 +from paddle._C_ops import bmm # noqa: F401 from paddle.base.libpaddle import DataType from paddle.common_ops_import import VarDesc from paddle.tensor.math import broadcast_shape from paddle.utils.decorator_utils import ( ParamAliasDecorator, VariableArgsDecorator, + param_two_alias, transpose_decorator, ) from paddle.utils.inplace_utils import inplace_apis_in_dygraph_only @@ -261,6 +262,148 @@ def matrix_transpose( return x.mT +@param_two_alias(["x", "input"], ["y", "other"]) +def matmul( + x: Tensor, + y: Tensor, + transpose_x: bool = False, + transpose_y: bool = False, + name: str | None = None, + *, + out: Tensor | None = None, +) -> Tensor: + """ + Applies matrix multiplication to two tensors. `matmul` follows + the complete broadcast rules, + and its behavior is consistent with `np.matmul`. + + Currently, the input tensors' number of dimensions can be any, `matmul` can be used to + achieve the `dot`, `matmul` and `batchmatmul`. + + The actual behavior depends on the shapes of :math:`x`, :math:`y` and the + flag values of :attr:`transpose_x`, :attr:`transpose_y`. Specifically: + + - If a transpose flag is specified, the last two dimensions of the tensor + are transposed. If the tensor is ndim-1 of shape, the transpose is invalid. If the tensor + is ndim-1 of shape :math:`[D]`, then for :math:`x` it is treated as :math:`[1, D]`, whereas + for :math:`y` it is the opposite: It is treated as :math:`[D, 1]`. + + The multiplication behavior depends on the dimensions of `x` and `y`. 
Specifically:
+
+    - If both tensors are 1-dimensional, the dot product result is obtained.
+
+    - If both tensors are 2-dimensional, the matrix-matrix product is obtained.
+
+    - If `x` is 1-dimensional and `y` is 2-dimensional,
+      a `1` is prepended to its dimension in order to conduct the matrix multiply.
+      After the matrix multiply, the prepended dimension is removed.
+
+    - If `x` is 2-dimensional and `y` is 1-dimensional,
+      the matrix-vector product is obtained.
+
+    - If both arguments are at least 1-dimensional and at least one argument
+      is N-dimensional (where N > 2), then a batched matrix multiply is obtained.
+      If the first argument is 1-dimensional, a 1 is prepended to its dimension
+      in order to conduct the batched matrix multiply and removed after.
+      If the second argument is 1-dimensional, a 1 is appended to its
+      dimension for the purpose of the batched matrix multiply and removed after.
+      The non-matrix (excluding the last two) dimensions are
+      broadcast according to the broadcast rule.
+      For example, if the input is a (j, 1, n, m) tensor and the other is a (k, m, p) tensor,
+      out will be a (j, k, n, p) tensor.
+
+    Args:
+        x (Tensor): The first input tensor.
+        y (Tensor): The second input tensor.
+        transpose_x (bool, optional): Whether to transpose :math:`x` before multiplication. Default is False.
+        transpose_y (bool, optional): Whether to transpose :math:`y` before multiplication. Default is False.
+        name (str|None, optional): If set to None, the layer will be named automatically. For more information, please refer to :ref:`api_guide_Name`. Default is None.
+        out (Tensor, optional): The output tensor. If set, the result will be stored in this tensor. Default is None.
+
+    Returns:
+        Tensor: The output Tensor.
+
+    Examples:
+
+        ..
code-block:: python + + >>> import paddle + + >>> # vector * vector + >>> x = paddle.rand([10]) + >>> y = paddle.rand([10]) + >>> z = paddle.matmul(x, y) + >>> print(z.shape) + [] + + >>> # matrix * vector + >>> x = paddle.rand([10, 5]) + >>> y = paddle.rand([5]) + >>> z = paddle.matmul(x, y) + >>> print(z.shape) + [10] + + >>> # batched matrix * broadcasted vector + >>> x = paddle.rand([10, 5, 2]) + >>> y = paddle.rand([2]) + >>> z = paddle.matmul(x, y) + >>> print(z.shape) + [10, 5] + + >>> # batched matrix * batched matrix + >>> x = paddle.rand([10, 5, 2]) + >>> y = paddle.rand([10, 2, 5]) + >>> z = paddle.matmul(x, y) + >>> print(z.shape) + [10, 5, 5] + + >>> # batched matrix * broadcasted matrix + >>> x = paddle.rand([10, 1, 5, 2]) + >>> y = paddle.rand([1, 3, 2, 5]) + >>> z = paddle.matmul(x, y) + >>> print(z.shape) + [10, 3, 5, 5] + + """ + if in_dynamic_or_pir_mode(): + return _C_ops.matmul(x, y, transpose_x, transpose_y, out=out) + else: + attrs = { + 'trans_x': transpose_x, + 'trans_y': transpose_y, + } + + def __check_input(x, y): + var_names = {'x': x, 'y': y} + for name, val in var_names.items(): + check_variable_and_dtype( + val, + name, + [ + 'int8', + 'uint16', + 'float16', + 'float32', + 'float64', + 'complex64', + 'complex128', + ], + 'matmul', + ) + + __check_input(x, y) + + helper = LayerHelper('matmul_v2', **locals()) + out = helper.create_variable_for_type_inference(dtype=x.dtype) + helper.append_op( + type='matmul_v2', + inputs={'X': x, 'Y': y}, + outputs={'Out': out}, + attrs=attrs, + ) + return out + + def fp8_fp8_half_gemm_fused( x, y, diff --git a/test/auto_parallel/hybrid_strategy/semi_auto_parallel_global_input.py b/test/auto_parallel/hybrid_strategy/semi_auto_parallel_global_input.py index 4d62182992a087..8ebedb93e509f3 100644 --- a/test/auto_parallel/hybrid_strategy/semi_auto_parallel_global_input.py +++ b/test/auto_parallel/hybrid_strategy/semi_auto_parallel_global_input.py @@ -82,7 +82,7 @@ def forward(self, x): else: global_input1 = global_input x = x + global_input1 - y = paddle.matmul(x, self.w0) + y = x @ self.w0 # forward on mesh1 if self.run_single_process is False: y = dist.reshard(y, mesh1, [dist.Shard(0), dist.Shard(2)]) @@ -93,7 +93,7 @@ def forward(self, x): global_input2 = global_input y = y + global_input2 - z = paddle.matmul(y, self.w1) + z = y @ self.w1 return z diff --git a/test/auto_parallel/hybrid_strategy/semi_auto_parallel_multi_inputs.py b/test/auto_parallel/hybrid_strategy/semi_auto_parallel_multi_inputs.py index b544a89f867175..c577c6fbdc44ec 100644 --- a/test/auto_parallel/hybrid_strategy/semi_auto_parallel_multi_inputs.py +++ b/test/auto_parallel/hybrid_strategy/semi_auto_parallel_multi_inputs.py @@ -61,11 +61,11 @@ def forward(self, input1, input2): x = input1 + input2 # x: [bs, seq_len, hidden] # forward on mesh0 - y = paddle.matmul(x, self.w0) + y = x @ self.w0 # forward on mesh1 if self.run_single_process is False: y = dist.reshard(y, mesh1, [dist.Shard(0), dist.Shard(2)]) - z = paddle.matmul(y, self.w1) + z = y @ self.w1 return z diff --git a/test/legacy_test/test_imperative_hook_for_layer.py b/test/legacy_test/test_imperative_hook_for_layer.py index f7b289caa843d1..3538c81eed275d 100644 --- a/test/legacy_test/test_imperative_hook_for_layer.py +++ b/test/legacy_test/test_imperative_hook_for_layer.py @@ -18,14 +18,68 @@ import numpy as np sys.path.append("../deprecated/legacy_test") -# from test_imperative_lod_tensor_to_selected_rows_deprecated import SimpleNet +from op_test import get_places import paddle +from paddle import 
base call_forward_post_hook = False call_forward_pre_hook = False +class SimpleNet(paddle.nn.Layer): + def __init__( + self, + hidden_size, + vocab_size, + num_steps=20, + init_scale=0.1, + is_sparse=False, + dtype='float32', + ): + super().__init__() + self.hidden_size = hidden_size + self.vocab_size = vocab_size + self.init_scale = init_scale + self.num_steps = num_steps + paddle.set_default_dtype(dtype) + self.embedding = paddle.nn.Embedding( + vocab_size, + hidden_size, + sparse=is_sparse, + weight_attr=base.ParamAttr( + name='embedding_para', + initializer=paddle.nn.initializer.Uniform( + low=-init_scale, high=init_scale + ), + ), + ) + self.softmax_bias = self.create_parameter( + attr=base.ParamAttr(), + shape=[self.vocab_size], + dtype=dtype, + default_initializer=paddle.nn.initializer.Uniform( + low=-self.init_scale, high=self.init_scale + ), + ) + + def forward(self, input, label): + x_emb = self.embedding(input) + projection = paddle.matmul( + x_emb, paddle.transpose(self.embedding.weight, perm=[1, 0]) + ) + projection = paddle.add(projection, self.softmax_bias) + projection = paddle.reshape(projection, shape=[-1, self.vocab_size]) + loss = paddle.nn.functional.softmax_with_cross_entropy( + logits=projection, label=label, soft_label=False + ) + loss = paddle.reshape(loss, shape=[-1, self.num_steps]) + loss = paddle.mean(loss, axis=[0]) + loss = paddle.sum(loss) + + return loss + + def forward_post_hook(layer, input, output): global call_forward_post_hook call_forward_post_hook = True @@ -45,160 +99,160 @@ def forward_pre_hook1(layer, input): return input_return -# class Test_Forward_Hook(unittest.TestCase): -# # test forward_pre_hook and forward_post_hook that have return value -# def test_forward_hook_return_value(self): -# seed = 90 - -# for place in get_places(): -# with base.dygraph.guard(place): -# paddle.seed(seed) -# base.set_flags({'FLAGS_sort_sum_gradient': True}) - -# input_word = ( -# np.array( -# [0, 1, 2, 3, 4, 5, 6, 7, 8, 0, 1, 2, 3, 4, 5, 6, 7, 8] -# ) -# .reshape(6, 3) -# .astype('int64') -# ) -# input_word1 = input_word * 2 -# input_word = input_word.reshape((-1, 3, 1)) -# input_word1 = input_word1.reshape((-1, 3, 1)) -# y_data = ( -# np.array( -# [1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 6, 7, 8, 9] -# ) -# .reshape(6, 3) -# .astype('int64') -# ) -# y_data = y_data.reshape((-1, 1)) - -# input = paddle.to_tensor(input_word) -# input1 = paddle.to_tensor(input_word1) -# y = paddle.to_tensor(y_data) - -# simplenet = SimpleNet( -# hidden_size=20, -# vocab_size=32, -# num_steps=3, -# init_scale=0.1, -# is_sparse=False, -# dtype="float32", -# ) - -# # origin, don't register any hook -# outs_origin = simplenet(input, y) -# outs_origin1 = simplenet(input1, y) - -# # register forward_pre_hook -# forward_pre_hook_handle1 = simplenet.register_forward_pre_hook( -# forward_pre_hook1 -# ) -# outs_pre_hook = simplenet(input, y) -# np.testing.assert_array_equal( -# outs_pre_hook.numpy(), outs_origin1.numpy() -# ) - -# # remove forward_pre_hook -# forward_pre_hook_handle1.remove() -# outs_pre_hook = simplenet(input, y) -# np.testing.assert_array_equal( -# outs_pre_hook.numpy(), outs_origin.numpy() -# ) - -# # register forward_posst_hook -# forward_post_hook_handle1 = ( -# simplenet.register_forward_post_hook(forward_post_hook1) -# ) -# outs_forward_hook = simplenet(input, y) -# np.testing.assert_array_equal( -# outs_forward_hook.numpy(), outs_origin.numpy() * 2 -# ) - -# # remove forward_post_hook -# forward_post_hook_handle1.remove() -# outs_forward_hook = simplenet(input, y) 
-# np.testing.assert_array_equal( -# outs_forward_hook.numpy(), outs_origin.numpy() -# ) - -# # test forward_pre_hook and forward_post_hook that don't have return value -# def test_forward_hook(self): -# seed = 90 - -# for place in get_places(): -# with base.dygraph.guard(place): -# paddle.seed(seed) -# base.set_flags({'FLAGS_sort_sum_gradient': True}) - -# global call_forward_post_hook -# global call_forward_pre_hook - -# input_word = ( -# np.array( -# [0, 1, 2, 3, 4, 5, 6, 7, 8, 0, 1, 2, 3, 4, 5, 6, 7, 8] -# ) -# .reshape(6, 3) -# .astype('int64') -# ) -# input_word = input_word.reshape((-1, 3, 1)) -# y_data = ( -# np.array( -# [1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 6, 7, 8, 9] -# ) -# .reshape(6, 3) -# .astype('int64') -# ) -# y_data = y_data.reshape((-1, 1)) - -# input = paddle.to_tensor(input_word) -# y = paddle.to_tensor(y_data) - -# simplenet = SimpleNet( -# hidden_size=20, -# vocab_size=32, -# num_steps=3, -# init_scale=0.1, -# is_sparse=False, -# dtype="float32", -# ) - -# # origin, don't register any hook -# outs_origin = simplenet(input, y) -# self.assertFalse(call_forward_post_hook) -# self.assertFalse(call_forward_pre_hook) - -# # register forward_post_hook and forward_pre_hook -# forward_post_hook_handle = simplenet.register_forward_post_hook( -# forward_post_hook -# ) -# forward_pre_hook_handle = simplenet.register_forward_pre_hook( -# forward_pre_hook -# ) -# outs_hook = simplenet(input, y) -# self.assertTrue(call_forward_post_hook) -# self.assertTrue(call_forward_pre_hook) - -# outs_hook = simplenet(input, y) -# self.assertTrue(call_forward_post_hook) -# self.assertTrue(call_forward_pre_hook) - -# # remove forward_post_hook -# forward_post_hook_handle.remove() -# call_forward_post_hook = False -# call_forward_pre_hook = False -# outs_remove_forward_hook = simplenet(input, y) -# self.assertFalse(call_forward_post_hook) -# self.assertTrue(call_forward_pre_hook) - -# # remove forward_pre_hook -# forward_pre_hook_handle.remove() -# call_forward_post_hook = False -# call_forward_pre_hook = False -# outs_remove_hook = simplenet(input, y) -# self.assertFalse(call_forward_post_hook) -# self.assertFalse(call_forward_pre_hook) +class Test_Forward_Hook(unittest.TestCase): + # test forward_pre_hook and forward_post_hook that have return value + def test_forward_hook_return_value(self): + seed = 90 + + for place in get_places(): + with base.dygraph.guard(place): + paddle.seed(seed) + base.set_flags({'FLAGS_sort_sum_gradient': True}) + + input_word = ( + np.array( + [0, 1, 2, 3, 4, 5, 6, 7, 8, 0, 1, 2, 3, 4, 5, 6, 7, 8] + ) + .reshape(6, 3) + .astype('int64') + ) + input_word1 = input_word * 2 + input_word = input_word.reshape((-1, 3, 1)) + input_word1 = input_word1.reshape((-1, 3, 1)) + y_data = ( + np.array( + [1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 6, 7, 8, 9] + ) + .reshape(6, 3) + .astype('int64') + ) + y_data = y_data.reshape((-1, 1)) + + input = paddle.to_tensor(input_word) + input1 = paddle.to_tensor(input_word1) + y = paddle.to_tensor(y_data) + + simplenet = SimpleNet( + hidden_size=20, + vocab_size=32, + num_steps=3, + init_scale=0.1, + is_sparse=False, + dtype="float32", + ) + + # origin, don't register any hook + outs_origin = simplenet(input, y) + outs_origin1 = simplenet(input1, y) + + # register forward_pre_hook + forward_pre_hook_handle1 = simplenet.register_forward_pre_hook( + forward_pre_hook1 + ) + outs_pre_hook = simplenet(input, y) + np.testing.assert_array_equal( + outs_pre_hook.numpy(), outs_origin1.numpy() + ) + + # remove forward_pre_hook + 
forward_pre_hook_handle1.remove() + outs_pre_hook = simplenet(input, y) + np.testing.assert_array_equal( + outs_pre_hook.numpy(), outs_origin.numpy() + ) + + # register forward_posst_hook + forward_post_hook_handle1 = ( + simplenet.register_forward_post_hook(forward_post_hook1) + ) + outs_forward_hook = simplenet(input, y) + np.testing.assert_array_equal( + outs_forward_hook.numpy(), outs_origin.numpy() * 2 + ) + + # remove forward_post_hook + forward_post_hook_handle1.remove() + outs_forward_hook = simplenet(input, y) + np.testing.assert_array_equal( + outs_forward_hook.numpy(), outs_origin.numpy() + ) + + # test forward_pre_hook and forward_post_hook that don't have return value + def test_forward_hook(self): + seed = 90 + + for place in get_places(): + with base.dygraph.guard(place): + paddle.seed(seed) + base.set_flags({'FLAGS_sort_sum_gradient': True}) + + global call_forward_post_hook + global call_forward_pre_hook + + input_word = ( + np.array( + [0, 1, 2, 3, 4, 5, 6, 7, 8, 0, 1, 2, 3, 4, 5, 6, 7, 8] + ) + .reshape(6, 3) + .astype('int64') + ) + input_word = input_word.reshape((-1, 3, 1)) + y_data = ( + np.array( + [1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 6, 7, 8, 9] + ) + .reshape(6, 3) + .astype('int64') + ) + y_data = y_data.reshape((-1, 1)) + + input = paddle.to_tensor(input_word) + y = paddle.to_tensor(y_data) + + simplenet = SimpleNet( + hidden_size=20, + vocab_size=32, + num_steps=3, + init_scale=0.1, + is_sparse=False, + dtype="float32", + ) + + # origin, don't register any hook + outs_origin = simplenet(input, y) + self.assertFalse(call_forward_post_hook) + self.assertFalse(call_forward_pre_hook) + + # register forward_post_hook and forward_pre_hook + forward_post_hook_handle = simplenet.register_forward_post_hook( + forward_post_hook + ) + forward_pre_hook_handle = simplenet.register_forward_pre_hook( + forward_pre_hook + ) + outs_hook = simplenet(input, y) + self.assertTrue(call_forward_post_hook) + self.assertTrue(call_forward_pre_hook) + + outs_hook = simplenet(input, y) + self.assertTrue(call_forward_post_hook) + self.assertTrue(call_forward_pre_hook) + + # remove forward_post_hook + forward_post_hook_handle.remove() + call_forward_post_hook = False + call_forward_pre_hook = False + outs_remove_forward_hook = simplenet(input, y) + self.assertFalse(call_forward_post_hook) + self.assertTrue(call_forward_pre_hook) + + # remove forward_pre_hook + forward_pre_hook_handle.remove() + call_forward_post_hook = False + call_forward_pre_hook = False + outs_remove_hook = simplenet(input, y) + self.assertFalse(call_forward_post_hook) + self.assertFalse(call_forward_pre_hook) def forward_pre_hook_with_kwargs(layer, args, kwargs): diff --git a/test/legacy_test/test_matmul_out.py b/test/legacy_test/test_matmul_out.py index 49138d510028a1..6341bca827f828 100644 --- a/test/legacy_test/test_matmul_out.py +++ b/test/legacy_test/test_matmul_out.py @@ -17,6 +17,32 @@ import numpy as np import paddle +from paddle import base + + +def reference_matmul(X, Y, transpose_X=False, transpose_Y=False): + """Reference forward implementation using np.matmul.""" + # np.matmul does not support the transpose flags, so we manually + # transpose X and Y appropriately. 
+ if transpose_X: + if X.ndim == 1: + X = X.reshape((X.size,)) + elif X.ndim == 2: + X = X.T + else: + dim = list(range(len(X.shape))) + dim[-1], dim[len(X.shape) - 2] = dim[len(X.shape) - 2], dim[-1] + X = np.transpose(X, tuple(dim)) + if transpose_Y: + if Y.ndim == 1: + Y = Y.reshape((Y.size,)) + else: + dim = list(range(len(Y.shape))) + dim[-1], dim[len(Y.shape) - 2] = dim[len(Y.shape) - 2], dim[-1] + Y = np.transpose(Y, tuple(dim)) + + Out = np.matmul(X, Y) + return Out class TestMatmulOutAndParamDecorator(unittest.TestCase): @@ -77,5 +103,89 @@ def test_matmul_out(self): ) +class TestMatMulAPI_Compatibility(unittest.TestCase): + def setUp(self): + np.random.seed(123) + paddle.enable_static() + self.x_shape = [5, 6] + self.y_shape = [6, 4] + self.dtype = 'float32' + self.init_data() + + def init_data(self): + self.np_x_input = np.random.randint(0, 8, self.x_shape).astype( + self.dtype + ) + self.np_y_input = np.random.randint(3, 9, self.y_shape).astype( + self.dtype + ) + + def test_dygraph_Compatibility(self): + paddle.disable_static() + x = paddle.to_tensor(self.np_x_input) + y = paddle.to_tensor(self.np_y_input) + paddle_dygraph_out = [] + # Position args (args) + out1 = paddle.matmul(x, y) + paddle_dygraph_out.append(out1) + # Key words args (kwargs) for paddle + out2 = paddle.matmul(x=x, y=y) + paddle_dygraph_out.append(out2) + # Key words args for torch + out3 = paddle.matmul(input=x, other=y) + paddle_dygraph_out.append(out3) + # Combined args and kwargs + out4 = paddle.matmul(x, other=y) + paddle_dygraph_out.append(out4) + # Tensor method args + out5 = x.matmul(y) + paddle_dygraph_out.append(out5) + # Tensor method kwargs + out6 = x.matmul(other=y) + paddle_dygraph_out.append(out6) + # Test out + out7 = paddle.empty([]) + paddle.matmul(x, other=y, out=out7) + paddle_dygraph_out.append(out7) + # Numpy reference out + ref_out = reference_matmul(self.np_x_input, self.np_y_input) + # Check + for out in paddle_dygraph_out: + np.testing.assert_allclose(ref_out, out.numpy()) + paddle.enable_static() + + def test_static_Compatibility(self): + main = paddle.static.Program() + startup = paddle.static.Program() + with base.program_guard(main, startup): + x = paddle.static.data( + name="x", shape=self.x_shape, dtype=self.dtype + ) + y = paddle.static.data( + name="y", shape=self.y_shape, dtype=self.dtype + ) + # Position args (args) + out1 = paddle.matmul(x, y) + # Key words args (kwargs) for paddle + out2 = paddle.matmul(x=x, y=y) + # Key words args for torch + out3 = paddle.matmul(input=x, other=y) + # Combined args and kwargs + out4 = paddle.matmul(x, other=y) + # Tensor method args + out5 = x.matmul(y) + # Tensor method kwargs + out6 = x.matmul(other=y) + exe = base.Executor(paddle.CPUPlace()) + fetches = exe.run( + main, + feed={"x": self.np_x_input, "y": self.np_y_input}, + fetch_list=[out1, out2, out3, out4, out5, out6], + ) + ref_out = reference_matmul(self.np_x_input, self.np_y_input) + for out in fetches: + np.testing.assert_allclose(out, ref_out) + + if __name__ == "__main__": unittest.main()
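Note on the `param_two_alias` decorator applied to the new `paddle.matmul` wrapper in linalg.py: its implementation lives in `paddle.utils.decorator_utils` and is not part of this diff. As a rough sketch of the idea only (not the actual Paddle implementation), such a decorator remaps the PyTorch-style keyword names `input`/`other` to the canonical `x`/`y` before the wrapped function runs:

    # Hypothetical sketch; the real paddle.utils.decorator_utils.param_two_alias
    # may differ in signature handling and error checking.
    import functools


    def param_two_alias(group_1, group_2):
        # Each group is [canonical_name, alias_name], e.g. ["x", "input"].
        (canon_1, alias_1), (canon_2, alias_2) = group_1, group_2

        def decorator(func):
            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                # Rewrite alias keywords to their canonical names before dispatch.
                if alias_1 in kwargs:
                    kwargs[canon_1] = kwargs.pop(alias_1)
                if alias_2 in kwargs:
                    kwargs[canon_2] = kwargs.pop(alias_2)
                return func(*args, **kwargs)

            return wrapper

        return decorator

With a remapping of this kind in place, `paddle.matmul(input=x, other=y)`, `x.matmul(other=y)`, and `paddle.matmul(x, y)` all resolve to the same call, which is what `TestMatMulAPI_Compatibility` in test_matmul_out.py exercises.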