From 2e73a42ae53036f879a7236ad6a3a3d510f5e062 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=90=A7?= <2085127827@qq.com> Date: Sun, 7 Apr 2024 19:44:10 +0800 Subject: [PATCH 01/21] Add AdaptiveLogSoftmaxWithLoss API --- python/paddle/nn/__init__.py | 2 + python/paddle/nn/functional/__init__.py | 2 + python/paddle/nn/functional/loss.py | 141 ++++++++ python/paddle/nn/layer/__init__.py | 1 + python/paddle/nn/layer/loss.py | 233 ++++++++++++++ .../test_adaptive_log_softmax_with_loss.py | 302 ++++++++++++++++++ 6 files changed, 681 insertions(+) create mode 100644 test/legacy_test/test_adaptive_log_softmax_with_loss.py diff --git a/python/paddle/nn/__init__.py b/python/paddle/nn/__init__.py index 7818c2398494d2..a9d8312bb4ca0a 100644 --- a/python/paddle/nn/__init__.py +++ b/python/paddle/nn/__init__.py @@ -86,6 +86,7 @@ from .layer.distance import PairwiseDistance from .layer.layers import Layer from .layer.loss import ( + AdaptiveLogSoftmaxWithLoss, BCELoss, BCEWithLogitsLoss, CosineEmbeddingLoss, @@ -295,6 +296,7 @@ 'TripletMarginLoss', 'SoftMarginLoss', 'GaussianNLLLoss', + 'AdaptiveLogSoftmaxWithLoss', 'Unflatten', 'FractionalMaxPool2D', 'FractionalMaxPool3D', diff --git a/python/paddle/nn/functional/__init__.py b/python/paddle/nn/functional/__init__.py index 8f48a83575748e..295fdc9d82e328 100644 --- a/python/paddle/nn/functional/__init__.py +++ b/python/paddle/nn/functional/__init__.py @@ -91,6 +91,7 @@ ) from .input import embedding, one_hot from .loss import ( + adaptive_log_softmax_with_loss, binary_cross_entropy, binary_cross_entropy_with_logits, cosine_embedding_loss, @@ -273,6 +274,7 @@ 'rrelu', 'triplet_margin_with_distance_loss', 'triplet_margin_loss', + 'adaptive_log_softmax_with_loss', 'multi_margin_loss', 'soft_margin_loss', 'gaussian_nll_loss', diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py index 5741f0a643db0e..d447bc76e13d2d 100644 --- a/python/paddle/nn/functional/loss.py +++ b/python/paddle/nn/functional/loss.py @@ -4272,3 +4272,144 @@ def gaussian_nll_loss( return paddle.sum(loss, name=name) elif reduction == 'none': return loss + + +def adaptive_log_softmax_with_loss( + input, label, head_weight, tail_weights, cutoffs, head_bias=None +): + r"""Compute adaptive logsoftmax result and negative log likelihood between `input` and `label`. + parameter `head_weight`, `tail_weights`, `cutoffs` and `head_bias` are inner members of AdaptiveLogSoftmaxWithLoss + Please refer to :ref:`_cn_api_paddle_nn_AdaptiveLogSoftmaxWithLoss`. + Args: + input (Tensor): Input tensor, the data type should be float32 or float64. + label (Tensor): Label tensor, the data type should be float32 or float64. + head_weight (Tensor): weight tensor for linear computation, the data type should be float32 or float64. + tail_weights (Tensor): weight tensor for linear computation, the data type should be float32 or float64. + cutoffs (Sequence): Cutoffs used to assign targets to their buckets. + head_bias (Tensor, optional): bias tensor for linear computation, the data type should be float32 or float64. + name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. + Returns: + output (Tensor): The tensor sotring adaptive logsoftmax result, the shape of output is [N] + loss (Tensor): The tensor variable storing the adaptive_log_softmax_loss of input and label. + Examples:: + .. code-block:: python + >>> import paddle + >>> import paddle.nn.functional as F + >>> input = paddle.randn([3, 5], dtype=paddle.float32) + >>> head_weight = paddle.randn([5, 3], dtype=paddle.float32) + >>> head_bias = paddle.randn([3], dtype=paddle.float32) + >>> tail_weights = [] + >>> tail_weights.append(paddle.randn([5, 1], dtype=paddle.float32)) + >>> tail_weights.append(paddle.randn([1, 2], dtype=paddle.float32)) + >>> out, loss = F.adaptive_log_softmax_with_loss(input, paddle.full((3,), 1, dtype='int64'), head_weight, head_bias, tail_weights, cutoffs=[2]) + >>> print(out) + Tensor(shape=[3], dtype=float32, place=Place(gpu:0), stop_gradient=True, + [-4.26640177, -5.79977274, -0.00650562]) + >>> print(loss) + Tensor(shape=[], dtype=float32, place=Place(gpu:0), stop_gradient=True, + 3.35756016) + """ + targ_dim = label.dim() + if targ_dim == 1: + if input.shape[0] != label.shape[0]: + raise ValueError( + 'Input and label should have the same size ' + 'in the batch dimension.' + ) + if input.dim() != 2: + raise ValueError( + '1D label tensor expects 2D input tensors, ' + 'but found inputs with size', + input.size(), + ) + elif targ_dim == 0: + if input.dim() != 1: + raise ValueError( + '0D label tensor expects 1D input tensors, ' + 'but found inputs with size', + input.size(), + ) + else: + raise ValueError( + '0D or 1D label tensor expected, ' 'multi-label not supported' + ) + + is_batched = targ_dim > 0 + input = input if is_batched else input.unsqueeze(0) + label = label if is_batched else label.unsqueeze(0) + + used_rows = 0 + batch_size = label.shape[0] + + output = paddle.zeros([batch_size], dtype=input.dtype) + gather_inds = paddle.empty([batch_size], dtype=label.dtype) + + cutoff_values = [0] + cutoffs + for i in range(len(cutoff_values) - 1): + low_idx = cutoff_values[i] + high_idx = cutoff_values[i + 1] + + label_mask = (label >= low_idx) & (label < high_idx) + row_indices = label_mask.nonzero().squeeze() + + if row_indices.numel() == 0: + continue + + if i == 0: + scatter_output = paddle.scatter_nd( + row_indices.unsqueeze(1), + label.masked_select(label_mask), + gather_inds.shape, + ) + gather_inds = scatter_output + + else: + relative_label = label[label_mask] - low_idx + input_subset = input.index_select(row_indices, axis=0) + + cluster_output = paddle.nn.functional.linear( + x=input_subset, weight=tail_weights[i - 1][0] + ) + cluster_output = paddle.nn.functional.linear( + x=cluster_output, weight=tail_weights[i - 1][1] + ) + cluster_index = cutoffs[0] + i - 1 + + gather_inds = paddle.index_fill( + gather_inds, row_indices, 0, cluster_index + ) + + cluster_logprob = paddle.nn.functional.log_softmax( + cluster_output, axis=1 + ) + + local_logprob = paddle.take_along_axis( + cluster_logprob, relative_label.unsqueeze(1), axis=1 + ) + scatter_output = paddle.scatter_nd( + row_indices.unsqueeze( + 1), local_logprob.squeeze(1), output.shape + ) + output = output * (scatter_output == 0) + scatter_output + + used_rows += row_indices.numel() + if used_rows != batch_size: + raise ValueError( + f"label values should be in [0, n_classes - 1], " + f"but values in range [{label.min().item()}, {label.max().item()}] " + "were found. " + ) + + head_output = paddle.nn.functional.linear( + x=input, weight=head_weight, bias=head_bias + ) + head_logprob = paddle.nn.functional.log_softmax(head_output, axis=1) + output += paddle.take_along_axis( + head_logprob, gather_inds.unsqueeze(1), axis=1 + ).squeeze() + loss = (-output).mean() + + if not is_batched: + output = output.squeeze(0) + + return output, loss \ No newline at end of file diff --git a/python/paddle/nn/layer/__init__.py b/python/paddle/nn/layer/__init__.py index 6516c85bdefffb..27d5cd4ecefa4b 100644 --- a/python/paddle/nn/layer/__init__.py +++ b/python/paddle/nn/layer/__init__.py @@ -60,6 +60,7 @@ from .distance import PairwiseDistance # noqa: F401 from .layers import Layer # noqa: F401 from .loss import ( # noqa: F401 + AdaptiveLogSoftmaxWithLoss, BCELoss, BCEWithLogitsLoss, CrossEntropyLoss, diff --git a/python/paddle/nn/layer/loss.py b/python/paddle/nn/layer/loss.py index 1fd2501698c2f2..626732680a6387 100644 --- a/python/paddle/nn/layer/loss.py +++ b/python/paddle/nn/layer/loss.py @@ -2208,3 +2208,236 @@ def forward(self, input, label, variance): self.name, ) return out + + +class AdaptiveLogSoftmaxWithLoss(Layer): + r"""Efficient softmax approximation as described in `Efficient softmax approximation for GPUs by Edouard Grave, + Armand Joulin, Moustapha Cissé, David Grangier, and Hervé Jégou `__. + Adaptive softmax is an approximate strategy for training models with large output spaces. It is most effective when + the label distribution is highly imbalanced, for example in natural language modelling, where the word frequency + distribution approximately follows the `Zipf's law`_. _Zipf's law: https://en.wikipedia.org/wiki/Zipf%27s_law + Adaptive softmax partitions the labels into several clusters, according to their frequency. These clusters may contain + different number of targets each. Additionally, clusters containing less frequent labels assign lower dimensional + embeddings to those labels, which speeds up the computation. For each minibatch, only clusters for which at least + one target is present are evaluated. + The idea is that the clusters which are accessed frequently (like the first one, containing most frequent labels), + should also be cheap to compute -- that is, contain a small number of assigned labels. We highly recommend taking + a look at the original paper for more details. + For :attr:`cutoffs` should be an ordered Sequence of integers sorted in the increasing order. It controls number of + clusters and the partitioning of targets into clusters. For example setting ``cutoffs = [10, 100, 1000]`` means that + first `10` targets will be assigned to the 'head' of the adaptive softmax, targets `11, 12, ..., 100` will be assigned + to the first cluster, and targets `101, 102, ..., 1000` will be assigned to the second cluster, while targets + `1001, 1002, ..., n_classes - 1` will be assigned to the last, third cluster. + For :attr:`div_value` is used to compute the size of each additional cluster, which is given as + :math:`\left\lfloor\frac{\texttt{in\_features}}{\texttt{div\_value}^{idx}}\right\rfloor`, + where :math:`idx` is the cluster index (with clusters for less frequent words having larger indices, and indices starting from :math:`1`). + For :attr:`head_bias` if set to True, adds a bias term to the 'head' of the adaptive softmax. See paper for details. + Set to False in the official implementation. + Note: + Labels passed as inputs to this module should be sorted according to their frequency. This means that the most + frequent label should be represented by the index `0`, and the least frequent label should be represented by + the index `n_classes - 1`. To compute log-probabilities for all classes, the ``log_prob`` method can be used. + Args: + in_features (int): Number of features in the input tensor + n_classes (int): Number of classes in the dataset. + cutoffs (Sequence): Cutoffs used to assign targets to their buckets. + div_value (float, optional): value used as an exponent to compute sizes of the clusters. Default: 4.0. + head_bias (bool, optional): If ``True``, adds a bias term to the 'head' of the adaptive softmax. Default: ``False``. + name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. + Shape: + - input (Tensor): The input tensor. The shapes is [N, in_features]. N is batch size. + - label (Tensor): target. The shapes is `[N]` + - output1 (Tensor): The shape is `[N]` + - output2 (Scalar): + Returns: + A callable object of AdaptiveLogSoftmaxWithLoss. + Examples:: + .. code-block:: python + >>> import paddle + >>> import paddle.nn as nn + >>> paddle.seed(2023) + >>> input = paddle.randn([3, 5], dtype=paddle.float32) + >>> asfm = nn.AdaptiveLogSoftmaxWithLoss(in_features=5, n_classes=3, cutoffs=[2], div_value=2.0, head_bias=False) + >>> out, loss = asfm(input, paddle.full((3,), 1, dtype='int64')) + >>> print(out) + Tensor(shape=[3], dtype=float32, place=Place(gpu:0), stop_gradient=False, + [-1.21106601, -0.88425100, -0.86460060]) + >>> print(loss) + Tensor(shape=[], dtype=float32, place=Place(gpu:0), stop_gradient=False, + 0.98663920) + >>> out = asfm.log_prob(input) + >>> print(out) + Tensor(shape=[3, 3], dtype=float32, place=Place(gpu:0), stop_gradient=False, + [[-1.50912428, -1.21106601, -0.73185283], + [-1.75451684, -0.88425100, -0.88192356], + [-2.56547689, -0.86460060, -0.68935889]]) + >>> out = asfm.predict(input) + >>> print(out) + Tensor(shape=[3], dtype=int64, place=Place(gpu:0), stop_gradient=True, + [2, 2, 2]) + """ + + def __init__( + self, + in_features, + n_classes, + cutoffs, + div_value=4.0, + head_bias=False, + name=None, + ): + super().__init__() + self._dtype = self._helper.get_default_dtype() + cutoffs = list(cutoffs) + + if ( + (cutoffs != sorted(cutoffs)) + or (min(cutoffs) <= 0) + or (max(cutoffs) > (n_classes - 1)) + or (len(set(cutoffs)) != len(cutoffs)) + or any(int(c) != c for c in cutoffs) + ): + raise ValueError( + "cutoffs should be a sequence of unique, positive " + "integers sorted in an increasing order, where " + "each value is between 1 and n_classes-1" + ) + + self.in_features = in_features + self.n_classes = n_classes + self.cutoffs = cutoffs + [n_classes] + self.div_value = div_value + self.is_head_bias = head_bias + + self.shortlist_size = self.cutoffs[0] + self.n_clusters = len(self.cutoffs) - 1 + self.head_size = self.shortlist_size + self.n_clusters + + self.head_weight = self.create_parameter( + shape=[self.in_features, self.head_size], + attr=None, + dtype=self._dtype, + is_bias=False, + ) + self.head_bias = None + if self.is_head_bias: + self.head_bias = self.create_parameter( + shape=[self.head_size], + attr=self.is_head_bias, + dtype=self._dtype, + is_bias=True, + ) + + self.tail_weights = [] + + for i in range(self.n_clusters): + hsz = int(self.in_features // (self.div_value ** (i + 1))) + osz = self.cutoffs[i + 1] - self.cutoffs[i] + projection = [] + projection.append( + self.create_parameter( + shape=[self.in_features, hsz], + attr=None, + dtype=self._dtype, + is_bias=False, + ) + ) + projection.append( + self.create_parameter( + shape=[hsz, osz], + attr=None, + dtype=self._dtype, + is_bias=False, + ) + ) + + self.tail_weights.append(projection) + + def forward(self, input, label): + return F.adaptive_log_softmax_with_loss( + input, + label, + self.head_weight, + self.tail_weights, + self.cutoffs, + self.head_bias, + ) + + def _get_full_log_prob(self, input, head_output): + out = paddle.empty((head_output.shape[0], self.n_classes)) + head_logprob = F.log_softmax(head_output, axis=1) + + if paddle.in_dynamic_mode(): + out[:, : self.shortlist_size] = head_logprob[ + :, : self.shortlist_size + ] + else: + paddle.static.setitem( + out, + ( + slice(None, None, None), + slice(None, self.shortlist_size, None), + ), + head_logprob, + ) + + for i, (start_idx, stop_idx) in enumerate( + zip(self.cutoffs, self.cutoffs[1:]) + ): + cluster_output = F.linear(x=input, weight=self.tail_weights[i][0]) + cluster_output = F.linear( + x=cluster_output, weight=self.tail_weights[i][1] + ) + cluster_logprob = F.log_softmax(cluster_output, axis=1) + output_logprob = cluster_logprob + head_logprob[ + :, self.shortlist_size + i + ].unsqueeze(1) + if paddle.in_dynamic_mode(): + out[:, start_idx:stop_idx] = output_logprob + else: + paddle.static.setitem( + out, + (slice(None, None, None), slice(start_idx, stop_idx, None)), + output_logprob, + ) + + return out + + def log_prob(self, input): + head_output = F.linear( + x=input, weight=self.head_weight, bias=self.head_bias + ) + return self._get_full_log_prob(input, head_output) + + def predict(self, input): + r"""This is equivalent to `self.log_pob(input).argmax(axis=1)`, but is more efficient in some cases. + Args: + input (Tensor): a minibatch of examples, The shape is [N, in_features] + Returns: + output (Tensor): a class with the highest probability for each example + Examples:: + Please refer to the example of AdaptiveLogSoftmaxWithLoss. + """ + + head_output = F.linear( + x=input, weight=self.head_weight, bias=self.head_bias + ) + output = paddle.argmax(head_output, axis=1).cast('float32') + not_in_shortlist = output >= self.shortlist_size + all_in_shortlist = not (not_in_shortlist.any()) + + if all_in_shortlist: + return output + + elif not_in_shortlist.all(): + log_prob = self._get_full_log_prob(input, head_output) + return paddle.argmax(log_prob, axis=1) + + else: + log_prob = self._get_full_log_prob( + input[not_in_shortlist], head_output[not_in_shortlist] + ) + output[not_in_shortlist] = paddle.argmax(log_prob, axis=1).cast( + 'float32' + ) + return output diff --git a/test/legacy_test/test_adaptive_log_softmax_with_loss.py b/test/legacy_test/test_adaptive_log_softmax_with_loss.py new file mode 100644 index 00000000000000..8fba453a5b227f --- /dev/null +++ b/test/legacy_test/test_adaptive_log_softmax_with_loss.py @@ -0,0 +1,302 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np + +import paddle +from paddle import nn +from paddle.base import Program +from paddle.nn import functional as F + + +class TestNNAdaptiveLogSoftmaxWithLossAPI(unittest.TestCase): + def setUp(self): + + self.place = ['cpu'] + if paddle.is_compiled_with_cuda(): + self.place.append('gpu') + self.log_np = np.random.randn(4, 8).astype('float32') + self.predict_np = np.abs(np.random.randn(64, 8).astype('float32')) + + def test_dygraph(self): + paddle.disable_static() + for place in self.place: + paddle.device.set_device(place) + x = paddle.randn((4, 8)) + self._test_log_probs_dygraph(x) + x = paddle.abs(paddle.randn((64, 8))) + self._test_correct_dygraph(x) + + def _test_log_probs_dygraph(self, x): + asfm = nn.AdaptiveLogSoftmaxWithLoss(8, 4, [2], div_value=2.0) + logprob_out = asfm.log_prob(x) + np.testing.assert_array_almost_equal( + paddle.exp(logprob_out).sum(1), paddle.ones([4]) + ) + + for v in [0, 1, 2, 3]: + y = paddle.full((4,), v, dtype='int64') + out, loss = asfm(x, y) + np.testing.assert_array_almost_equal( + out, + logprob_out.gather(y.unsqueeze(1), 1) + .slice([1], [0], [1]) + .squeeze(), + ) + np.testing.assert_array_almost_equal( + loss, F.nll_loss(logprob_out, y) + ) + + def _test_correct_dygraph(self, x): + asfm = nn.AdaptiveLogSoftmaxWithLoss( + 8, 10, [4, 8], div_value=2.0, head_bias=True + ) + asfm.head_weight.detach().abs() + asfm.head_bias.detach().abs() + asfm.head_weight.detach()[asfm.shortlist_size:, :] *= 0.0 + + out = asfm.predict(x) + np.testing.assert_array_almost_equal( + out, asfm.log_prob(x).argmax(axis=1) + ) + + asfm = nn.AdaptiveLogSoftmaxWithLoss( + 8, 10, [4, 8], div_value=2.0, head_bias=True + ) + asfm.head_weight.detach().abs() + asfm.head_bias.detach().abs() + asfm.head_weight.detach()[: asfm.shortlist_size, :] *= 0.0 + + out = asfm.predict(x) + np.testing.assert_array_almost_equal( + out, asfm.log_prob(x).argmax(axis=1) + ) + + asfm = nn.AdaptiveLogSoftmaxWithLoss( + 8, 10, [4, 8], div_value=2.0, head_bias=True + ) + asfm.head_weight.detach().abs() + asfm.head_bias.detach().abs() + + x[:32, : asfm.shortlist_size] *= 0.0 + x[32:, asfm.shortlist_size:] *= 0.0 + asfm.head_weight.detach()[ + : asfm.shortlist_size, asfm.shortlist_size: + ] *= 0.0 + asfm.head_weight.detach()[ + asfm.shortlist_size:, : asfm.shortlist_size + ] *= 0.0 + + out = asfm.predict(x) + np.testing.assert_array_almost_equal( + out, asfm.log_prob(x).argmax(axis=1) + ) + + # def test_static(self): + # paddle.enable_static() + # for place in self.place: + # self._test_log_probs_static(place) + # self._test_correct_static(place) + + def _test_log_probs_static(self, place): + paddle.enable_static() + with paddle.static.program_guard(Program()): + asfm = nn.AdaptiveLogSoftmaxWithLoss(8, 4, [2], div_value=2.0) + x = paddle.static.data( + name="log_input", shape=[4, 8], dtype='float32' + ) + out = asfm.log_prob(x) + exe = paddle.static.Executor(place=place) + feed_list = {"log_input": self.log_np} + logprob_out = exe.run( + paddle.static.default_main_program(), + feed=feed_list, + fetch_list=[out], + ) + + np.testing.assert_array_almost_equal( + paddle.exp(logprob_out).sum(1), paddle.ones([4]) + ) + + for v in [0, 1, 2, 3]: + y = paddle.full((4,), v, dtype='int64') + out, loss = asfm(x, y) + f_out, f_loss = exe.run( + paddle.static.default_main_program(), + feed=feed_list, + fetch_list=[out, loss], + ) + np.testing.assert_array_almost_equal( + f_out, + logprob_out.gather(y.unsqueeze(1), 1) + .slice([1], [0], [1]) + .squeeze(), + ) + np.testing.assert_array_almost_equal( + f_loss, F.nll_loss(logprob_out, y) + ) + + def _test_correct_static(self, place): + paddle.enable_static() + with paddle.static.program_guard(Program()): + asfm = nn.AdaptiveLogSoftmaxWithLoss( + 8, 10, [4, 8], div_value=2.0, head_bias=True + ) + exe = paddle.static.Executor(place=place) + feed_list = {"predict_input": self.predict_np} + x = paddle.static.data( + name="predict_input", shape=[64, 8], dtype='float32' + ) + asfm.head_weight.detach().abs() + asfm.head_bias.detach().abs() + paddle.static.setitem( + asfm.head_weight.detach(), + ( + slice(asfm.shortlist_size, None, None), + slice(None, None, None), + ), + 0.0, + ) + out = asfm.predict(x) + predict_out1 = exe.run( + paddle.static.default_main_program(), + feed=feed_list, + fetch_list=[out], + ) + np.testing.assert_array_almost_equal( + predict_out1, asfm.log_prob(x).argmax(axis=1) + ) + + asfm = nn.AdaptiveLogSoftmaxWithLoss( + 8, 10, [4, 8], div_value=2.0, head_bias=True + ) + asfm.head_weight.detach().abs() + asfm.head_bias.detach().abs() + paddle.static.setitem( + asfm.head_weight.detach(), + ( + slice(None, asfm.shortlist_size, None), + slice(None, None, None), + ), + 0.0, + ) + out = asfm.predict(x) + predict_out2 = exe.run( + paddle.static.default_main_program(), + feed=feed_list, + fetch_list=[out], + ) + np.testing.assert_array_almost_equal( + predict_out2, asfm.log_prob(x).argmax(axis=1) + ) + + asfm = nn.AdaptiveLogSoftmaxWithLoss( + 8, 10, [4, 8], div_value=2.0, head_bias=True + ) + asfm.head_weight.detach().abs() + asfm.head_bias.detach().abs() + paddle.static.setitem( + x, + (slice(None, 32, None), slice(None, asfm.shortlist_size, None)), + 0.0, + ) + paddle.static.setitem( + x, + (slice(32, None, None), slice(asfm.shortlist_size, None, None)), + 0.0, + ) + paddle.static.setitem( + asfm.head_weight.detach(), + ( + slice(None, asfm.shortlist_size, None), + slice(asfm.shortlist_size, None, None), + ), + 0.0, + ) + paddle.static.setitem( + asfm.head_weight.detach(), + ( + slice(asfm.shortlist_size, None, None), + slice(None, asfm.shortlist_size, None), + ), + 0.0, + ) + out = asfm.predict(x) + predict_out3 = exe.run( + paddle.static.default_main_program(), + feed=feed_list, + fetch_list=[out], + ) + np.testing.assert_array_almost_equal( + predict_out3, asfm.log_prob(x).argmax(axis=1) + ) + + def test_shape(self): + with self.assertRaises(ValueError): + asfm = nn.AdaptiveLogSoftmaxWithLoss( + 16, 20, [5, 10, 15], div_value=2.0 + ) + x = paddle.randn((2, 16)) + y = paddle.to_tensor([0, 5, 10]) + asfm(x, y) + + with self.assertRaises(ValueError): + asfm = nn.AdaptiveLogSoftmaxWithLoss( + 16, 20, [5, 10, 15], div_value=2.0 + ) + x = paddle.randn((128, 16)) + y = paddle.randint(low=21, high=200, shape=[128]) + asfm(x, y) + + def test_cluster(self): + asfm = nn.AdaptiveLogSoftmaxWithLoss( + 16, 20, [5, 10, 15], div_value=2.0) + x = paddle.randn((128, 16)) + y = paddle.randint(low=0, high=20, shape=[128]) + output, loss = asfm(x, y) + self.assertEqual(asfm.head_weight.shape, [16, 5 + 3]) + self.assertEqual(asfm.tail_weights[0][1].shape, [8, 5]) + self.assertEqual(asfm.tail_weights[1][1].shape, [4, 5]) + self.assertEqual(asfm.tail_weights[2][1].shape, [2, 5]) + + self.assertEqual(output.shape, [128]) + + def test_error(self): + with self.assertRaises(ValueError): + _ = nn.AdaptiveLogSoftmaxWithLoss( + 16, 20, [5, 15, 15], div_value=2.0 + ) + + with self.assertRaises(ValueError): + _ = nn.AdaptiveLogSoftmaxWithLoss( + 16, 20, [5, 15, 10], div_value=2.0 + ) + + with self.assertRaises(ValueError): + _ = nn.AdaptiveLogSoftmaxWithLoss( + 16, 20, [5, 10, 25], div_value=2.0 + ) + + with self.assertRaisesRegex( + ValueError, "cutoffs should be a sequence of unique," + ): + _ = nn.AdaptiveLogSoftmaxWithLoss( + 16, 20, [5, 10, 20], div_value=2.0 + ) + + +if __name__ == "__main__": + unittest.main() From 127263786c7d786a000069d0b6163ecb5a727bae Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=90=A7?= <2085127827@qq.com> Date: Mon, 8 Apr 2024 14:00:44 +0800 Subject: [PATCH 02/21] update codestyle --- python/paddle/nn/functional/loss.py | 5 ++--- .../test_adaptive_log_softmax_with_loss.py | 12 +++++------- 2 files changed, 7 insertions(+), 10 deletions(-) diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py index d447bc76e13d2d..2b463cb2e95e95 100644 --- a/python/paddle/nn/functional/loss.py +++ b/python/paddle/nn/functional/loss.py @@ -4387,8 +4387,7 @@ def adaptive_log_softmax_with_loss( cluster_logprob, relative_label.unsqueeze(1), axis=1 ) scatter_output = paddle.scatter_nd( - row_indices.unsqueeze( - 1), local_logprob.squeeze(1), output.shape + row_indices.unsqueeze(1), local_logprob.squeeze(1), output.shape ) output = output * (scatter_output == 0) + scatter_output @@ -4412,4 +4411,4 @@ def adaptive_log_softmax_with_loss( if not is_batched: output = output.squeeze(0) - return output, loss \ No newline at end of file + return output, loss diff --git a/test/legacy_test/test_adaptive_log_softmax_with_loss.py b/test/legacy_test/test_adaptive_log_softmax_with_loss.py index 8fba453a5b227f..6c0082b41765a4 100644 --- a/test/legacy_test/test_adaptive_log_softmax_with_loss.py +++ b/test/legacy_test/test_adaptive_log_softmax_with_loss.py @@ -24,7 +24,6 @@ class TestNNAdaptiveLogSoftmaxWithLossAPI(unittest.TestCase): def setUp(self): - self.place = ['cpu'] if paddle.is_compiled_with_cuda(): self.place.append('gpu') @@ -66,7 +65,7 @@ def _test_correct_dygraph(self, x): ) asfm.head_weight.detach().abs() asfm.head_bias.detach().abs() - asfm.head_weight.detach()[asfm.shortlist_size:, :] *= 0.0 + asfm.head_weight.detach()[asfm.shortlist_size :, :] *= 0.0 out = asfm.predict(x) np.testing.assert_array_almost_equal( @@ -92,12 +91,12 @@ def _test_correct_dygraph(self, x): asfm.head_bias.detach().abs() x[:32, : asfm.shortlist_size] *= 0.0 - x[32:, asfm.shortlist_size:] *= 0.0 + x[32:, asfm.shortlist_size :] *= 0.0 asfm.head_weight.detach()[ - : asfm.shortlist_size, asfm.shortlist_size: + : asfm.shortlist_size, asfm.shortlist_size : ] *= 0.0 asfm.head_weight.detach()[ - asfm.shortlist_size:, : asfm.shortlist_size + asfm.shortlist_size :, : asfm.shortlist_size ] *= 0.0 out = asfm.predict(x) @@ -262,8 +261,7 @@ def test_shape(self): asfm(x, y) def test_cluster(self): - asfm = nn.AdaptiveLogSoftmaxWithLoss( - 16, 20, [5, 10, 15], div_value=2.0) + asfm = nn.AdaptiveLogSoftmaxWithLoss(16, 20, [5, 10, 15], div_value=2.0) x = paddle.randn((128, 16)) y = paddle.randint(low=0, high=20, shape=[128]) output, loss = asfm(x, y) From c81de8628604d973ee41aef0a326b40f981ea51f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=90=A7?= <2085127827@qq.com> Date: Thu, 11 Apr 2024 23:45:50 +0800 Subject: [PATCH 03/21] update loss --- python/paddle/nn/layer/loss.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/python/paddle/nn/layer/loss.py b/python/paddle/nn/layer/loss.py index 626732680a6387..f4208703795bae 100644 --- a/python/paddle/nn/layer/loss.py +++ b/python/paddle/nn/layer/loss.py @@ -2220,23 +2220,31 @@ class AdaptiveLogSoftmaxWithLoss(Layer): different number of targets each. Additionally, clusters containing less frequent labels assign lower dimensional embeddings to those labels, which speeds up the computation. For each minibatch, only clusters for which at least one target is present are evaluated. + The idea is that the clusters which are accessed frequently (like the first one, containing most frequent labels), should also be cheap to compute -- that is, contain a small number of assigned labels. We highly recommend taking a look at the original paper for more details. + For :attr:`cutoffs` should be an ordered Sequence of integers sorted in the increasing order. It controls number of clusters and the partitioning of targets into clusters. For example setting ``cutoffs = [10, 100, 1000]`` means that first `10` targets will be assigned to the 'head' of the adaptive softmax, targets `11, 12, ..., 100` will be assigned to the first cluster, and targets `101, 102, ..., 1000` will be assigned to the second cluster, while targets `1001, 1002, ..., n_classes - 1` will be assigned to the last, third cluster. + For :attr:`div_value` is used to compute the size of each additional cluster, which is given as - :math:`\left\lfloor\frac{\texttt{in\_features}}{\texttt{div\_value}^{idx}}\right\rfloor`, + + .. math:: + `\left\lfloor\frac{\texttt{in\_features}}{\texttt{div\_value}^{idx}}\right\rfloor`, + where :math:`idx` is the cluster index (with clusters for less frequent words having larger indices, and indices starting from :math:`1`). For :attr:`head_bias` if set to True, adds a bias term to the 'head' of the adaptive softmax. See paper for details. Set to False in the official implementation. + Note: Labels passed as inputs to this module should be sorted according to their frequency. This means that the most frequent label should be represented by the index `0`, and the least frequent label should be represented by the index `n_classes - 1`. To compute log-probabilities for all classes, the ``log_prob`` method can be used. + Args: in_features (int): Number of features in the input tensor n_classes (int): Number of classes in the dataset. @@ -2244,15 +2252,19 @@ class AdaptiveLogSoftmaxWithLoss(Layer): div_value (float, optional): value used as an exponent to compute sizes of the clusters. Default: 4.0. head_bias (bool, optional): If ``True``, adds a bias term to the 'head' of the adaptive softmax. Default: ``False``. name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. + Shape: - input (Tensor): The input tensor. The shapes is [N, in_features]. N is batch size. - label (Tensor): target. The shapes is `[N]` - output1 (Tensor): The shape is `[N]` - output2 (Scalar): + Returns: A callable object of AdaptiveLogSoftmaxWithLoss. + Examples:: .. code-block:: python + >>> import paddle >>> import paddle.nn as nn >>> paddle.seed(2023) From 20368308140543bc5b117985ed300ec01b0fca98 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=90=A7?= <2085127827@qq.com> Date: Fri, 12 Apr 2024 15:13:36 +0800 Subject: [PATCH 04/21] test --- python/paddle/nn/functional/loss.py | 37 ++++--- python/paddle/nn/layer/loss.py | 96 +++++++++---------- .../test_adaptive_log_softmax_with_loss.py | 4 +- 3 files changed, 70 insertions(+), 67 deletions(-) diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py index cab307943e606e..2b7a4e55efa911 100644 --- a/python/paddle/nn/functional/loss.py +++ b/python/paddle/nn/functional/loss.py @@ -4269,9 +4269,10 @@ def gaussian_nll_loss( def adaptive_log_softmax_with_loss( input, label, head_weight, tail_weights, cutoffs, head_bias=None ): - r"""Compute adaptive logsoftmax result and negative log likelihood between `input` and `label`. - parameter `head_weight`, `tail_weights`, `cutoffs` and `head_bias` are inner members of AdaptiveLogSoftmaxWithLoss + r"""Compute adaptive logsoftmax result and negative log likelihood between ``input`` and ``label``. + Parameter ``head``, ``tail_weights``, ``cutoffs`` are inner members of AdaptiveLogSoftmaxWithLoss Please refer to :ref:`_cn_api_paddle_nn_AdaptiveLogSoftmaxWithLoss`. + Args: input (Tensor): Input tensor, the data type should be float32 or float64. label (Tensor): Label tensor, the data type should be float32 or float64. @@ -4279,30 +4280,35 @@ def adaptive_log_softmax_with_loss( tail_weights (Tensor): weight tensor for linear computation, the data type should be float32 or float64. cutoffs (Sequence): Cutoffs used to assign targets to their buckets. head_bias (Tensor, optional): bias tensor for linear computation, the data type should be float32 or float64. - name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. + Returns: output (Tensor): The tensor sotring adaptive logsoftmax result, the shape of output is [N] loss (Tensor): The tensor variable storing the adaptive_log_softmax_loss of input and label. + Examples:: .. code-block:: python + >>> import paddle >>> import paddle.nn.functional as F + + >>> paddle.seed(2024) >>> input = paddle.randn([3, 5], dtype=paddle.float32) >>> head_weight = paddle.randn([5, 3], dtype=paddle.float32) >>> head_bias = paddle.randn([3], dtype=paddle.float32) >>> tail_weights = [] - >>> tail_weights.append(paddle.randn([5, 1], dtype=paddle.float32)) - >>> tail_weights.append(paddle.randn([1, 2], dtype=paddle.float32)) - >>> out, loss = F.adaptive_log_softmax_with_loss(input, paddle.full((3,), 1, dtype='int64'), head_weight, head_bias, tail_weights, cutoffs=[2]) + >>> tail_weights.append(paddle.randn([5, 2], dtype=paddle.float32)) + >>> tail_weights.append(paddle.randn([2, 1], dtype=paddle.float32)) + >>> out, loss = F.adaptive_log_softmax_with_loss(input, paddle.full((3,), 1, dtype='int64'), head_weight, tail_weights, cutoffs=[2], head_bias=head_bias) >>> print(out) - Tensor(shape=[3], dtype=float32, place=Place(gpu:0), stop_gradient=True, - [-4.26640177, -5.79977274, -0.00650562]) + Tensor(shape=[3], dtype=float32, place=Place(cpu), stop_gradient=True, + [-0.99842924, -2.27753878, -0.16740258]) >>> print(loss) - Tensor(shape=[], dtype=float32, place=Place(gpu:0), stop_gradient=True, - 3.35756016) + Tensor(shape=[], dtype=float32, place=Place(cpu), stop_gradient=True, + 1.14779019) """ - targ_dim = label.dim() - if targ_dim == 1: + targt_dim = label.dim() + + if targt_dim == 1: if input.shape[0] != label.shape[0]: raise ValueError( 'Input and label should have the same size ' @@ -4314,7 +4320,7 @@ def adaptive_log_softmax_with_loss( 'but found inputs with size', input.size(), ) - elif targ_dim == 0: + elif targt_dim == 0: if input.dim() != 1: raise ValueError( '0D label tensor expects 1D input tensors, ' @@ -4326,7 +4332,7 @@ def adaptive_log_softmax_with_loss( '0D or 1D label tensor expected, ' 'multi-label not supported' ) - is_batched = targ_dim > 0 + is_batched = targt_dim > 0 input = input if is_batched else input.unsqueeze(0) label = label if is_batched else label.unsqueeze(0) @@ -4354,7 +4360,6 @@ def adaptive_log_softmax_with_loss( gather_inds.shape, ) gather_inds = scatter_output - else: relative_label = label[label_mask] - low_idx input_subset = input.index_select(row_indices, axis=0) @@ -4365,6 +4370,7 @@ def adaptive_log_softmax_with_loss( cluster_output = paddle.nn.functional.linear( x=cluster_output, weight=tail_weights[i - 1][1] ) + cluster_index = cutoffs[0] + i - 1 gather_inds = paddle.index_fill( @@ -4384,6 +4390,7 @@ def adaptive_log_softmax_with_loss( output = output * (scatter_output == 0) + scatter_output used_rows += row_indices.numel() + if used_rows != batch_size: raise ValueError( f"label values should be in [0, n_classes - 1], " diff --git a/python/paddle/nn/layer/loss.py b/python/paddle/nn/layer/loss.py index f4208703795bae..a58c5e1a131be8 100644 --- a/python/paddle/nn/layer/loss.py +++ b/python/paddle/nn/layer/loss.py @@ -2211,11 +2211,10 @@ def forward(self, input, label, variance): class AdaptiveLogSoftmaxWithLoss(Layer): - r"""Efficient softmax approximation as described in `Efficient softmax approximation for GPUs by Edouard Grave, - Armand Joulin, Moustapha Cissé, David Grangier, and Hervé Jégou `__. - Adaptive softmax is an approximate strategy for training models with large output spaces. It is most effective when + r"""Adaptive softmax is an approximate strategy for training models with large output spaces. It is most effective when the label distribution is highly imbalanced, for example in natural language modelling, where the word frequency - distribution approximately follows the `Zipf's law`_. _Zipf's law: https://en.wikipedia.org/wiki/Zipf%27s_law + distribution approximately follows the ``Zipf's law``. + Adaptive softmax partitions the labels into several clusters, according to their frequency. These clusters may contain different number of targets each. Additionally, clusters containing less frequent labels assign lower dimensional embeddings to those labels, which speeds up the computation. For each minibatch, only clusters for which at least @@ -2225,28 +2224,24 @@ class AdaptiveLogSoftmaxWithLoss(Layer): should also be cheap to compute -- that is, contain a small number of assigned labels. We highly recommend taking a look at the original paper for more details. - For :attr:`cutoffs` should be an ordered Sequence of integers sorted in the increasing order. It controls number of + For :attr:``cutoffs`` should be an ordered Sequence of integers sorted in the increasing order. It controls number of clusters and the partitioning of targets into clusters. For example setting ``cutoffs = [10, 100, 1000]`` means that - first `10` targets will be assigned to the 'head' of the adaptive softmax, targets `11, 12, ..., 100` will be assigned - to the first cluster, and targets `101, 102, ..., 1000` will be assigned to the second cluster, while targets - `1001, 1002, ..., n_classes - 1` will be assigned to the last, third cluster. + first ``10`` targets will be assigned to the 'head' of the adaptive softmax, targets ``11, 12, ..., 100`` will be assigned + to the first cluster, and targets ``101, 102, ..., 1000`` will be assigned to the second cluster, while targets + ``1001, 1002, ..., n_classes - 1`` will be assigned to the last, third cluster. - For :attr:`div_value` is used to compute the size of each additional cluster, which is given as + For :attr:``div_value`` is used to compute the size of each additional cluster, which is given as follow: .. math:: - `\left\lfloor\frac{\texttt{in\_features}}{\texttt{div\_value}^{idx}}\right\rfloor`, + \lfloor \frac{\text{in\_features}}{\text{div\_value}^{idx}} \rfloor - where :math:`idx` is the cluster index (with clusters for less frequent words having larger indices, and indices starting from :math:`1`). - For :attr:`head_bias` if set to True, adds a bias term to the 'head' of the adaptive softmax. See paper for details. - Set to False in the official implementation. + where :math:``idx`` is the cluster index (with clusters for less frequent words having larger indices, and indices starting from :math:``1``). + + For :attr:``head_bias`` if set to True, adds a bias term to the 'head' of the adaptive softmax. See paper for details. Set to False in the official implementation. - Note: - Labels passed as inputs to this module should be sorted according to their frequency. This means that the most - frequent label should be represented by the index `0`, and the least frequent label should be represented by - the index `n_classes - 1`. To compute log-probabilities for all classes, the ``log_prob`` method can be used. Args: - in_features (int): Number of features in the input tensor + in_features (int): Number of features in the input tensor. n_classes (int): Number of classes in the dataset. cutoffs (Sequence): Cutoffs used to assign targets to their buckets. div_value (float, optional): value used as an exponent to compute sizes of the clusters. Default: 4.0. @@ -2262,31 +2257,39 @@ class AdaptiveLogSoftmaxWithLoss(Layer): Returns: A callable object of AdaptiveLogSoftmaxWithLoss. - Examples:: + Examples:: .. code-block:: python >>> import paddle >>> import paddle.nn as nn - >>> paddle.seed(2023) - >>> input = paddle.randn([3, 5], dtype=paddle.float32) - >>> asfm = nn.AdaptiveLogSoftmaxWithLoss(in_features=5, n_classes=3, cutoffs=[2], div_value=2.0, head_bias=False) - >>> out, loss = asfm(input, paddle.full((3,), 1, dtype='int64')) + >>> paddle.seed(2024) + + >>> input = paddle.randn([3, 5], dtype="float32") + >>> target = paddle.full((3,), 1, dtype='int64') + >>> asfm = nn.AdaptiveLogSoftmaxWithLoss(in_features=5, n_classes=3, cutoffs=[ + 2], div_value=2.0, head_bias=False) + >>> out, loss = asfm(input, target) >>> print(out) - Tensor(shape=[3], dtype=float32, place=Place(gpu:0), stop_gradient=False, - [-1.21106601, -0.88425100, -0.86460060]) + Tensor(shape=[3], dtype=float32, place=Place(cpu), stop_gradient=False, + [-1.04691017, -0.42341536, -1.16909981]) >>> print(loss) - Tensor(shape=[], dtype=float32, place=Place(gpu:0), stop_gradient=False, - 0.98663920) + Tensor(shape=[], dtype=float32, place=Place(cpu), stop_gradient=False, + 0.87980843) >>> out = asfm.log_prob(input) >>> print(out) - Tensor(shape=[3, 3], dtype=float32, place=Place(gpu:0), stop_gradient=False, - [[-1.50912428, -1.21106601, -0.73185283], - [-1.75451684, -0.88425100, -0.88192356], - [-2.56547689, -0.86460060, -0.68935889]]) + Tensor(shape=[3, 3], dtype=float32, place=Place(cpu), stop_gradient=False, + [[-1.13710010, -1.04691017, -1.11403584], + [-1.51841831, -0.42341536, -2.07040048], + [-4.25405550, -1.16909981, -0.39282480]]) >>> out = asfm.predict(input) >>> print(out) - Tensor(shape=[3], dtype=int64, place=Place(gpu:0), stop_gradient=True, - [2, 2, 2]) + Tensor(shape=[3], dtype=float32, place=Place(cpu), stop_gradient=True, + [1., 1., 2.]) + + Note: + Labels passed as inputs to this module should be sorted according to their frequency. This means that the most + frequent label should be represented by the index `0`, and the least frequent label should be represented by + the index `n_classes - 1`. To compute log-probabilities for all classes, the ``log_prob`` method can be used. """ def __init__( @@ -2331,14 +2334,15 @@ def __init__( dtype=self._dtype, is_bias=False, ) - self.head_bias = None if self.is_head_bias: self.head_bias = self.create_parameter( shape=[self.head_size], - attr=self.is_head_bias, + attr=None, dtype=self._dtype, is_bias=True, ) + else: + self.head_bias = None self.tail_weights = [] @@ -2362,7 +2366,6 @@ def __init__( is_bias=False, ) ) - self.tail_weights.append(projection) def forward(self, input, label): @@ -2404,6 +2407,7 @@ def _get_full_log_prob(self, input, head_output): output_logprob = cluster_logprob + head_logprob[ :, self.shortlist_size + i ].unsqueeze(1) + if paddle.in_dynamic_mode(): out[:, start_idx:stop_idx] = output_logprob else: @@ -2422,15 +2426,6 @@ def log_prob(self, input): return self._get_full_log_prob(input, head_output) def predict(self, input): - r"""This is equivalent to `self.log_pob(input).argmax(axis=1)`, but is more efficient in some cases. - Args: - input (Tensor): a minibatch of examples, The shape is [N, in_features] - Returns: - output (Tensor): a class with the highest probability for each example - Examples:: - Please refer to the example of AdaptiveLogSoftmaxWithLoss. - """ - head_output = F.linear( x=input, weight=self.head_weight, bias=self.head_bias ) @@ -2440,16 +2435,17 @@ def predict(self, input): if all_in_shortlist: return output - elif not_in_shortlist.all(): log_prob = self._get_full_log_prob(input, head_output) return paddle.argmax(log_prob, axis=1) - else: log_prob = self._get_full_log_prob( input[not_in_shortlist], head_output[not_in_shortlist] ) - output[not_in_shortlist] = paddle.argmax(log_prob, axis=1).cast( - 'float32' + indices = paddle.masked_select( + paddle.arange(len(not_in_shortlist)), not_in_shortlist ) - return output + result = paddle.scatter( + output, indices, paddle.argmax(log_prob, axis=1).cast('float32') + ) + return result diff --git a/test/legacy_test/test_adaptive_log_softmax_with_loss.py b/test/legacy_test/test_adaptive_log_softmax_with_loss.py index 6c0082b41765a4..4df557ca414a24 100644 --- a/test/legacy_test/test_adaptive_log_softmax_with_loss.py +++ b/test/legacy_test/test_adaptive_log_softmax_with_loss.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -36,7 +36,7 @@ def test_dygraph(self): paddle.device.set_device(place) x = paddle.randn((4, 8)) self._test_log_probs_dygraph(x) - x = paddle.abs(paddle.randn((64, 8))) + x = paddle.abs(paddle.randn((4, 8))) self._test_correct_dygraph(x) def _test_log_probs_dygraph(self, x): From 9a489a95084d247f780a26ed526b4001ccf40ab6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=90=A7?= <2085127827@qq.com> Date: Fri, 19 Apr 2024 16:47:38 +0800 Subject: [PATCH 05/21] update test --- .../test_adaptive_log_softmax_with_loss.py | 42 ++++++++++++++++--- 1 file changed, 36 insertions(+), 6 deletions(-) diff --git a/test/legacy_test/test_adaptive_log_softmax_with_loss.py b/test/legacy_test/test_adaptive_log_softmax_with_loss.py index 4df557ca414a24..dd4b646ffc3b6a 100644 --- a/test/legacy_test/test_adaptive_log_softmax_with_loss.py +++ b/test/legacy_test/test_adaptive_log_softmax_with_loss.py @@ -18,6 +18,7 @@ import paddle from paddle import nn +import paddle.optimizer as optim from paddle.base import Program from paddle.nn import functional as F @@ -104,12 +105,6 @@ def _test_correct_dygraph(self, x): out, asfm.log_prob(x).argmax(axis=1) ) - # def test_static(self): - # paddle.enable_static() - # for place in self.place: - # self._test_log_probs_static(place) - # self._test_correct_static(place) - def _test_log_probs_static(self, place): paddle.enable_static() with paddle.static.program_guard(Program()): @@ -260,6 +255,41 @@ def test_shape(self): y = paddle.randint(low=21, high=200, shape=[128]) asfm(x, y) + def test_output(self): + n_classes = 1000 + in_features = 128 + cutoffs = [200, 500, 900] + + x = paddle.randn([32, in_features]) + labels = paddle.randint(0, n_classes, [32]) + + model = nn.AdaptiveLogSoftmaxWithLoss(in_features, n_classes, cutoffs) + + optimizer = optim.Adam(parameters=model.parameters(), learning_rate=0.001) + + for epoch in range(10): + output, loss = model(x, labels) + + + optimizer.clear_grad() + loss.backward() + optimizer.step() + + with paddle.no_grad(): + log_probs = model.log_prob(x) + + predictions = model.predict(x) + + tail_weights_before_training = [proj[0].numpy().copy() for proj in model.tail_weights] + + with paddle.no_grad(): + output, loss = model(x, labels) + + tail_weights_after_training = [proj[0].numpy() for proj in model.tail_weights] + + for before, after in zip(tail_weights_before_training, tail_weights_after_training): + assert not np.any(before != after) + def test_cluster(self): asfm = nn.AdaptiveLogSoftmaxWithLoss(16, 20, [5, 10, 15], div_value=2.0) x = paddle.randn((128, 16)) From 5f989be0b57433f702fde4413185caf12b1162a0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=90=A7?= <2085127827@qq.com> Date: Fri, 19 Apr 2024 19:05:00 +0800 Subject: [PATCH 06/21] add weight_attr --- python/paddle/nn/layer/loss.py | 22 +- .../test_adaptive_log_softmax_with_loss.py | 261 ++++++++---------- 2 files changed, 138 insertions(+), 145 deletions(-) diff --git a/python/paddle/nn/layer/loss.py b/python/paddle/nn/layer/loss.py index a58c5e1a131be8..8f5b9fdb1efbd4 100644 --- a/python/paddle/nn/layer/loss.py +++ b/python/paddle/nn/layer/loss.py @@ -2244,6 +2244,16 @@ class AdaptiveLogSoftmaxWithLoss(Layer): in_features (int): Number of features in the input tensor. n_classes (int): Number of classes in the dataset. cutoffs (Sequence): Cutoffs used to assign targets to their buckets. + weight_attr (ParamAttr, optional): The attribute for the learnable + weight of this layer. The default value is None. If the Initializer of the + param_attr is not set, the parameter is initialized with Xavier. + For detailed information, please refer to paddle.ParamAttr. + bias_attr (ParamAttr|bool, optional): The attribute for the learnable bias + of this layer. If it is set to False, no bias will be added to the output. + If it is set to None or one kind of ParamAttr, a bias parameter will + be created according to ParamAttr. For detailed information, please refer + to paddle.ParamAttr. The default value is None and the bias will be + initialized to zero. div_value (float, optional): value used as an exponent to compute sizes of the clusters. Default: 4.0. head_bias (bool, optional): If ``True``, adds a bias term to the 'head' of the adaptive softmax. Default: ``False``. name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. @@ -2297,6 +2307,8 @@ def __init__( in_features, n_classes, cutoffs, + weight_attr=None, + bias_attr=None, div_value=4.0, head_bias=False, name=None, @@ -2322,6 +2334,8 @@ def __init__( self.n_classes = n_classes self.cutoffs = cutoffs + [n_classes] self.div_value = div_value + self._weight_attr = weight_attr + self._bias_attr = bias_attr self.is_head_bias = head_bias self.shortlist_size = self.cutoffs[0] @@ -2330,14 +2344,14 @@ def __init__( self.head_weight = self.create_parameter( shape=[self.in_features, self.head_size], - attr=None, + attr=self._weight_attr, dtype=self._dtype, is_bias=False, ) if self.is_head_bias: self.head_bias = self.create_parameter( shape=[self.head_size], - attr=None, + attr=self._bias_attr, dtype=self._dtype, is_bias=True, ) @@ -2353,7 +2367,7 @@ def __init__( projection.append( self.create_parameter( shape=[self.in_features, hsz], - attr=None, + attr=self._weight_attr, dtype=self._dtype, is_bias=False, ) @@ -2361,7 +2375,7 @@ def __init__( projection.append( self.create_parameter( shape=[hsz, osz], - attr=None, + attr=self._weight_attr, dtype=self._dtype, is_bias=False, ) diff --git a/test/legacy_test/test_adaptive_log_softmax_with_loss.py b/test/legacy_test/test_adaptive_log_softmax_with_loss.py index dd4b646ffc3b6a..b35e439437d0b5 100644 --- a/test/legacy_test/test_adaptive_log_softmax_with_loss.py +++ b/test/legacy_test/test_adaptive_log_softmax_with_loss.py @@ -18,11 +18,41 @@ import paddle from paddle import nn -import paddle.optimizer as optim from paddle.base import Program from paddle.nn import functional as F +class SimpleModel(nn.Layer): + def __init__( + self, + in_features, + n_classes, + cutoffs, + div_value=4.0, + head_bias=False, + ): + super().__init__() + self.fc = paddle.nn.Linear(in_features, in_features) + self.adaptive_softmax = nn.AdaptiveLogSoftmaxWithLoss( + in_features, + n_classes, + cutoffs, + div_value=div_value, + head_bias=head_bias, + ) + + def forward(self, input, label=None): + x = self.fc(input) + if label is not None: + return self.adaptive_softmax(x, label) + else: + return self.adaptive_softmax.log_prob(x) + + def predict(self, input): + logprob = self.adaptive_softmax.log_prob(self.fc(input)) + return logprob.argmax(axis=1) + + class TestNNAdaptiveLogSoftmaxWithLossAPI(unittest.TestCase): def setUp(self): self.place = ['cpu'] @@ -30,7 +60,7 @@ def setUp(self): self.place.append('gpu') self.log_np = np.random.randn(4, 8).astype('float32') self.predict_np = np.abs(np.random.randn(64, 8).astype('float32')) - + def test_dygraph(self): paddle.disable_static() for place in self.place: @@ -41,15 +71,15 @@ def test_dygraph(self): self._test_correct_dygraph(x) def _test_log_probs_dygraph(self, x): - asfm = nn.AdaptiveLogSoftmaxWithLoss(8, 4, [2], div_value=2.0) - logprob_out = asfm.log_prob(x) + model = SimpleModel(8, 4, [2], div_value=2.0) + logprob_out = model(x) np.testing.assert_array_almost_equal( paddle.exp(logprob_out).sum(1), paddle.ones([4]) ) for v in [0, 1, 2, 3]: y = paddle.full((4,), v, dtype='int64') - out, loss = asfm(x, y) + out, loss = model(x, y) np.testing.assert_array_almost_equal( out, logprob_out.gather(y.unsqueeze(1), 1) @@ -61,58 +91,52 @@ def _test_log_probs_dygraph(self, x): ) def _test_correct_dygraph(self, x): - asfm = nn.AdaptiveLogSoftmaxWithLoss( - 8, 10, [4, 8], div_value=2.0, head_bias=True - ) - asfm.head_weight.detach().abs() - asfm.head_bias.detach().abs() - asfm.head_weight.detach()[asfm.shortlist_size :, :] *= 0.0 + model = SimpleModel(8, 10, [4, 8], div_value=2.0, head_bias=True) + model.adaptive_softmax.head_weight.detach().abs() + model.adaptive_softmax.head_bias.detach().abs() + model.adaptive_softmax.head_weight.detach()[ + model.adaptive_softmax.shortlist_size :, : + ] *= 0.0 - out = asfm.predict(x) - np.testing.assert_array_almost_equal( - out, asfm.log_prob(x).argmax(axis=1) - ) + out = model.predict(x) + np.testing.assert_array_almost_equal(out, model(x).argmax(axis=1)) - asfm = nn.AdaptiveLogSoftmaxWithLoss( - 8, 10, [4, 8], div_value=2.0, head_bias=True - ) - asfm.head_weight.detach().abs() - asfm.head_bias.detach().abs() - asfm.head_weight.detach()[: asfm.shortlist_size, :] *= 0.0 + model = SimpleModel(8, 10, [4, 8], div_value=2.0, head_bias=True) + model.adaptive_softmax.head_weight.detach().abs() + model.adaptive_softmax.head_bias.detach().abs() + model.adaptive_softmax.head_weight.detach()[ + : model.adaptive_softmax.shortlist_size, : + ] *= 0.0 - out = asfm.predict(x) - np.testing.assert_array_almost_equal( - out, asfm.log_prob(x).argmax(axis=1) - ) + out = model.predict(x) + np.testing.assert_array_almost_equal(out, model(x).argmax(axis=1)) - asfm = nn.AdaptiveLogSoftmaxWithLoss( - 8, 10, [4, 8], div_value=2.0, head_bias=True - ) - asfm.head_weight.detach().abs() - asfm.head_bias.detach().abs() + model = SimpleModel(8, 10, [4, 8], div_value=2.0, head_bias=True) + model.adaptive_softmax.head_weight.detach().abs() + model.adaptive_softmax.head_bias.detach().abs() - x[:32, : asfm.shortlist_size] *= 0.0 - x[32:, asfm.shortlist_size :] *= 0.0 - asfm.head_weight.detach()[ - : asfm.shortlist_size, asfm.shortlist_size : + x[:32, : model.adaptive_softmax.shortlist_size] *= 0.0 + x[32:, model.adaptive_softmax.shortlist_size :] *= 0.0 + model.adaptive_softmax.head_weight.detach()[ + : model.adaptive_softmax.shortlist_size, + model.adaptive_softmax.shortlist_size :, ] *= 0.0 - asfm.head_weight.detach()[ - asfm.shortlist_size :, : asfm.shortlist_size + model.adaptive_softmax.head_weight.detach()[ + model.adaptive_softmax.shortlist_size :, + : model.adaptive_softmax.shortlist_size, ] *= 0.0 - out = asfm.predict(x) - np.testing.assert_array_almost_equal( - out, asfm.log_prob(x).argmax(axis=1) - ) + out = model.predict(x) + np.testing.assert_array_almost_equal(out, model(x).argmax(axis=1)) def _test_log_probs_static(self, place): paddle.enable_static() with paddle.static.program_guard(Program()): - asfm = nn.AdaptiveLogSoftmaxWithLoss(8, 4, [2], div_value=2.0) + model = SimpleModel(8, 4, [2], div_value=2.0) x = paddle.static.data( name="log_input", shape=[4, 8], dtype='float32' ) - out = asfm.log_prob(x) + out = model(x) exe = paddle.static.Executor(place=place) feed_list = {"log_input": self.log_np} logprob_out = exe.run( @@ -127,7 +151,7 @@ def _test_log_probs_static(self, place): for v in [0, 1, 2, 3]: y = paddle.full((4,), v, dtype='int64') - out, loss = asfm(x, y) + out, loss = model(x, y) f_out, f_loss = exe.run( paddle.static.default_main_program(), feed=feed_list, @@ -146,184 +170,139 @@ def _test_log_probs_static(self, place): def _test_correct_static(self, place): paddle.enable_static() with paddle.static.program_guard(Program()): - asfm = nn.AdaptiveLogSoftmaxWithLoss( - 8, 10, [4, 8], div_value=2.0, head_bias=True - ) + model = SimpleModel(8, 10, [4, 8], div_value=2.0, head_bias=True) exe = paddle.static.Executor(place=place) feed_list = {"predict_input": self.predict_np} x = paddle.static.data( name="predict_input", shape=[64, 8], dtype='float32' ) - asfm.head_weight.detach().abs() - asfm.head_bias.detach().abs() + model.adaptive_softmax.head_weight.detach().abs() + model.adaptive_softmax.head_bias.detach().abs() paddle.static.setitem( - asfm.head_weight.detach(), + model.adaptive_softmax.head_weight.detach(), ( - slice(asfm.shortlist_size, None, None), + slice(model.adaptive_softmax.shortlist_size, None, None), slice(None, None, None), ), 0.0, ) - out = asfm.predict(x) + out = model.predict(x) predict_out1 = exe.run( paddle.static.default_main_program(), feed=feed_list, fetch_list=[out], ) np.testing.assert_array_almost_equal( - predict_out1, asfm.log_prob(x).argmax(axis=1) + predict_out1, model(x).argmax(axis=1) ) - asfm = nn.AdaptiveLogSoftmaxWithLoss( - 8, 10, [4, 8], div_value=2.0, head_bias=True - ) - asfm.head_weight.detach().abs() - asfm.head_bias.detach().abs() + model = SimpleModel(8, 10, [4, 8], div_value=2.0, head_bias=True) + model.adaptive_softmax.head_weight.detach().abs() + model.adaptive_softmax.head_bias.detach().abs() paddle.static.setitem( - asfm.head_weight.detach(), + model.adaptive_softmax.head_weight.detach(), ( - slice(None, asfm.shortlist_size, None), + slice(None, model.adaptive_softmax.shortlist_size, None), slice(None, None, None), ), 0.0, ) - out = asfm.predict(x) + out = model.predict(x) predict_out2 = exe.run( paddle.static.default_main_program(), feed=feed_list, fetch_list=[out], ) np.testing.assert_array_almost_equal( - predict_out2, asfm.log_prob(x).argmax(axis=1) + predict_out2, model(x).argmax(axis=1) ) - asfm = nn.AdaptiveLogSoftmaxWithLoss( - 8, 10, [4, 8], div_value=2.0, head_bias=True - ) - asfm.head_weight.detach().abs() - asfm.head_bias.detach().abs() + model = SimpleModel(8, 10, [4, 8], div_value=2.0, head_bias=True) + model.adaptive_softmax.head_weight.detach().abs() + model.adaptive_softmax.head_bias.detach().abs() paddle.static.setitem( x, - (slice(None, 32, None), slice(None, asfm.shortlist_size, None)), + ( + slice(None, 32, None), + slice(None, model.adaptive_softmax.shortlist_size, None), + ), 0.0, ) paddle.static.setitem( x, - (slice(32, None, None), slice(asfm.shortlist_size, None, None)), + ( + slice(32, None, None), + slice(model.adaptive_softmax.shortlist_size, None, None), + ), 0.0, ) paddle.static.setitem( - asfm.head_weight.detach(), + model.adaptive_softmax.head_weight.detach(), ( - slice(None, asfm.shortlist_size, None), - slice(asfm.shortlist_size, None, None), + slice( + None, model.adaptive_softmaxasfm.shortlist_size, None + ), + slice(model.adaptive_softmax.shortlist_size, None, None), ), 0.0, ) paddle.static.setitem( - asfm.head_weight.detach(), + model.adaptive_softmax.head_weight.detach(), ( - slice(asfm.shortlist_size, None, None), - slice(None, asfm.shortlist_size, None), + slice(model.adaptive_softmax.shortlist_size, None, None), + slice(None, model.adaptive_softmax.shortlist_size, None), ), 0.0, ) - out = asfm.predict(x) + out = model.predict(x) predict_out3 = exe.run( paddle.static.default_main_program(), feed=feed_list, fetch_list=[out], ) np.testing.assert_array_almost_equal( - predict_out3, asfm.log_prob(x).argmax(axis=1) + predict_out3, model(x).argmax(axis=1) ) def test_shape(self): with self.assertRaises(ValueError): - asfm = nn.AdaptiveLogSoftmaxWithLoss( - 16, 20, [5, 10, 15], div_value=2.0 - ) + model = SimpleModel(16, 20, [5, 10, 15], div_value=2.0) x = paddle.randn((2, 16)) y = paddle.to_tensor([0, 5, 10]) - asfm(x, y) - - with self.assertRaises(ValueError): - asfm = nn.AdaptiveLogSoftmaxWithLoss( - 16, 20, [5, 10, 15], div_value=2.0 - ) - x = paddle.randn((128, 16)) - y = paddle.randint(low=21, high=200, shape=[128]) - asfm(x, y) - - def test_output(self): - n_classes = 1000 - in_features = 128 - cutoffs = [200, 500, 900] - - x = paddle.randn([32, in_features]) - labels = paddle.randint(0, n_classes, [32]) - - model = nn.AdaptiveLogSoftmaxWithLoss(in_features, n_classes, cutoffs) - - optimizer = optim.Adam(parameters=model.parameters(), learning_rate=0.001) - - for epoch in range(10): - output, loss = model(x, labels) - - - optimizer.clear_grad() - loss.backward() - optimizer.step() - - with paddle.no_grad(): - log_probs = model.log_prob(x) - - predictions = model.predict(x) - - tail_weights_before_training = [proj[0].numpy().copy() for proj in model.tail_weights] - - with paddle.no_grad(): - output, loss = model(x, labels) - - tail_weights_after_training = [proj[0].numpy() for proj in model.tail_weights] - - for before, after in zip(tail_weights_before_training, tail_weights_after_training): - assert not np.any(before != after) + model(x, y) def test_cluster(self): - asfm = nn.AdaptiveLogSoftmaxWithLoss(16, 20, [5, 10, 15], div_value=2.0) + model = SimpleModel(16, 20, [5, 10, 15], div_value=2.0) x = paddle.randn((128, 16)) y = paddle.randint(low=0, high=20, shape=[128]) - output, loss = asfm(x, y) - self.assertEqual(asfm.head_weight.shape, [16, 5 + 3]) - self.assertEqual(asfm.tail_weights[0][1].shape, [8, 5]) - self.assertEqual(asfm.tail_weights[1][1].shape, [4, 5]) - self.assertEqual(asfm.tail_weights[2][1].shape, [2, 5]) + output, _ = model(x, y) + self.assertEqual(model.adaptive_softmax.head_weight.shape, [16, 5 + 3]) + self.assertEqual( + model.adaptive_softmax.tail_weights[0][1].shape, [8, 5] + ) + self.assertEqual( + model.adaptive_softmax.tail_weights[1][1].shape, [4, 5] + ) + self.assertEqual( + model.adaptive_softmax.tail_weights[2][1].shape, [2, 5] + ) self.assertEqual(output.shape, [128]) def test_error(self): with self.assertRaises(ValueError): - _ = nn.AdaptiveLogSoftmaxWithLoss( - 16, 20, [5, 15, 15], div_value=2.0 - ) + _ = SimpleModel(16, 20, [5, 15, 15], div_value=2.0) with self.assertRaises(ValueError): - _ = nn.AdaptiveLogSoftmaxWithLoss( - 16, 20, [5, 15, 10], div_value=2.0 - ) + _ = SimpleModel(16, 20, [5, 15, 10], div_value=2.0) with self.assertRaises(ValueError): - _ = nn.AdaptiveLogSoftmaxWithLoss( - 16, 20, [5, 10, 25], div_value=2.0 - ) + _ = SimpleModel(16, 20, [5, 10, 25], div_value=2.0) with self.assertRaisesRegex( ValueError, "cutoffs should be a sequence of unique," ): - _ = nn.AdaptiveLogSoftmaxWithLoss( - 16, 20, [5, 10, 20], div_value=2.0 - ) + _ = SimpleModel(16, 20, [5, 10, 20], div_value=2.0) if __name__ == "__main__": From 129095e57a241e6e8e22f4a4b8487a232ad04a21 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=90=A7?= <2085127827@qq.com> Date: Mon, 22 Apr 2024 14:51:32 +0800 Subject: [PATCH 07/21] update forward --- test/legacy_test/test_adaptive_log_softmax_with_loss.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/legacy_test/test_adaptive_log_softmax_with_loss.py b/test/legacy_test/test_adaptive_log_softmax_with_loss.py index b35e439437d0b5..800a2a77c99f4a 100644 --- a/test/legacy_test/test_adaptive_log_softmax_with_loss.py +++ b/test/legacy_test/test_adaptive_log_softmax_with_loss.py @@ -60,7 +60,7 @@ def setUp(self): self.place.append('gpu') self.log_np = np.random.randn(4, 8).astype('float32') self.predict_np = np.abs(np.random.randn(64, 8).astype('float32')) - + def test_dygraph(self): paddle.disable_static() for place in self.place: From 65da77eb6ee19fd21f0bc90a81bd532c507b42f4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=90=A7?= <2085127827@qq.com> Date: Wed, 24 Apr 2024 13:36:31 +0800 Subject: [PATCH 08/21] update forward --- .../test_adaptive_log_softmax_with_loss.py | 40 +++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/test/legacy_test/test_adaptive_log_softmax_with_loss.py b/test/legacy_test/test_adaptive_log_softmax_with_loss.py index 800a2a77c99f4a..046fc9c2ef860c 100644 --- a/test/legacy_test/test_adaptive_log_softmax_with_loss.py +++ b/test/legacy_test/test_adaptive_log_softmax_with_loss.py @@ -18,6 +18,7 @@ import paddle from paddle import nn +import paddle.optimizer as optim from paddle.base import Program from paddle.nn import functional as F @@ -55,6 +56,7 @@ def predict(self, input): class TestNNAdaptiveLogSoftmaxWithLossAPI(unittest.TestCase): def setUp(self): + paddle.seed(2024) self.place = ['cpu'] if paddle.is_compiled_with_cuda(): self.place.append('gpu') @@ -271,6 +273,44 @@ def test_shape(self): y = paddle.to_tensor([0, 5, 10]) model(x, y) + def test_forwadr(self): + n_classes = 4 + in_features = 8 + cutoffs = [2] + + x = paddle.randn([8, in_features]) + print(x) + labels = paddle.randint(0, n_classes, [8]) + print(labels) + model = SimpleModel(in_features, n_classes, cutoffs, div_value=2.0) + + optimizer = optim.Adam( + parameters=model.parameters(), learning_rate=0.001 + ) + for _ in range(2): + _, loss = model(x, labels) + + optimizer.clear_grad() + loss.backward() + optimizer.step() + + + tail_weights_before_training = [ + proj[0].numpy().copy() for proj in model.adaptive_softmax.tail_weights + ] + + with paddle.no_grad(): + output, loss = model(x, labels) + + tail_weights_after_training = [ + proj[0].numpy() for proj in model.adaptive_softmax.tail_weights + ] + + for before, after in zip( + tail_weights_before_training, tail_weights_after_training + ): + assert not np.any(before != after) + def test_cluster(self): model = SimpleModel(16, 20, [5, 10, 15], div_value=2.0) x = paddle.randn((128, 16)) From 12cb2ff66be90babb84a320277b11b8f669872e3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=90=A7?= <2085127827@qq.com> Date: Wed, 24 Apr 2024 15:53:59 +0800 Subject: [PATCH 09/21] update --- test/legacy_test/test_adaptive_log_softmax_with_loss.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/legacy_test/test_adaptive_log_softmax_with_loss.py b/test/legacy_test/test_adaptive_log_softmax_with_loss.py index 046fc9c2ef860c..dfb870d9ce79d9 100644 --- a/test/legacy_test/test_adaptive_log_softmax_with_loss.py +++ b/test/legacy_test/test_adaptive_log_softmax_with_loss.py @@ -17,8 +17,8 @@ import numpy as np import paddle -from paddle import nn import paddle.optimizer as optim +from paddle import nn from paddle.base import Program from paddle.nn import functional as F @@ -294,9 +294,9 @@ def test_forwadr(self): loss.backward() optimizer.step() - tail_weights_before_training = [ - proj[0].numpy().copy() for proj in model.adaptive_softmax.tail_weights + proj[0].numpy().copy() + for proj in model.adaptive_softmax.tail_weights ] with paddle.no_grad(): @@ -310,7 +310,7 @@ def test_forwadr(self): tail_weights_before_training, tail_weights_after_training ): assert not np.any(before != after) - + def test_cluster(self): model = SimpleModel(16, 20, [5, 10, 15], div_value=2.0) x = paddle.randn((128, 16)) From cca16367cf757fc1007da674f4673332b2f932c2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=90=A7?= <2085127827@qq.com> Date: Thu, 25 Apr 2024 15:04:23 +0800 Subject: [PATCH 10/21] update --- python/paddle/nn/functional/loss.py | 4 +- .../test_adaptive_log_softmax_with_loss.py | 113 +++++++++++++++++- 2 files changed, 109 insertions(+), 8 deletions(-) diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py index 2b7a4e55efa911..04aefbfc0ab577 100644 --- a/python/paddle/nn/functional/loss.py +++ b/python/paddle/nn/functional/loss.py @@ -4318,14 +4318,14 @@ def adaptive_log_softmax_with_loss( raise ValueError( '1D label tensor expects 2D input tensors, ' 'but found inputs with size', - input.size(), + input.shape, ) elif targt_dim == 0: if input.dim() != 1: raise ValueError( '0D label tensor expects 1D input tensors, ' 'but found inputs with size', - input.size(), + input.shape, ) else: raise ValueError( diff --git a/test/legacy_test/test_adaptive_log_softmax_with_loss.py b/test/legacy_test/test_adaptive_log_softmax_with_loss.py index dfb870d9ce79d9..da55f5850323f5 100644 --- a/test/legacy_test/test_adaptive_log_softmax_with_loss.py +++ b/test/legacy_test/test_adaptive_log_softmax_with_loss.py @@ -50,8 +50,8 @@ def forward(self, input, label=None): return self.adaptive_softmax.log_prob(x) def predict(self, input): - logprob = self.adaptive_softmax.log_prob(self.fc(input)) - return logprob.argmax(axis=1) + logprob = self.adaptive_softmax.predict(self.fc(input)) + return logprob class TestNNAdaptiveLogSoftmaxWithLossAPI(unittest.TestCase): @@ -278,10 +278,92 @@ def test_forwadr(self): in_features = 8 cutoffs = [2] - x = paddle.randn([8, in_features]) - print(x) - labels = paddle.randint(0, n_classes, [8]) - print(labels) + x = paddle.to_tensor( + [ + [ + 0.99785769, + -1.14492130, + 0.62956816, + 0.77550924, + -1.97198308, + 0.50906199, + 0.76702958, + 1.31143034, + ], + [ + 0.17371807, + 2.68322444, + 1.90870595, + 0.58601201, + -0.78898108, + 0.42098731, + -0.74253917, + -0.37492049, + ], + [ + -0.77694625, + -0.11529812, + 0.38232428, + 0.70575434, + 0.73429769, + 0.81399834, + 0.14212975, + 0.12567955, + ], + [ + 0.44165909, + 0.23613696, + 0.81143701, + 0.60473150, + 0.77017546, + 0.27865678, + -0.03236491, + 0.31634274, + ], + [ + 0.15336825, + -0.66177142, + -0.01784009, + 0.08901446, + 0.85228783, + 1.49427640, + -1.66938102, + 0.86154014, + ], + [ + -0.60814697, + 1.26191938, + -0.21735200, + -0.88890392, + 0.49093658, + -1.28960681, + 1.06943762, + 0.15803306, + ], + [ + -0.12136814, + -0.16133699, + 0.15643604, + 0.79464215, + -1.02201688, + 0.26957786, + -0.31038952, + 0.93334937, + ], + [ + 0.66997373, + 0.95807010, + -0.66944563, + -0.89887059, + 1.00404060, + 0.69594669, + -0.82105070, + 1.15200853, + ], + ], + dtype='float32', + ) + labels = paddle.to_tensor([3, 3, 3, 2, 3, 0, 0, 0], dtype='int64') model = SimpleModel(in_features, n_classes, cutoffs, div_value=2.0) optimizer = optim.Adam( @@ -344,6 +426,25 @@ def test_error(self): ): _ = SimpleModel(16, 20, [5, 10, 20], div_value=2.0) + def test_dim_error(self): + with self.assertRaises(ValueError): + model = SimpleModel(16, 20, [5, 10, 15], div_value=2.0) + x = paddle.randn((129, 16)) + y = paddle.randint(low=0, high=20, shape=[128]) + _ = model(x, y) + + with self.assertRaises(ValueError): + model = SimpleModel(16, 20, [5, 10, 15], div_value=2.0) + x = paddle.randn((128, 16)) + y = paddle.randint(low=0, high=20, shape=[]) + _ = model(x, y) + + with self.assertRaises(ValueError): + model = SimpleModel(16, 20, [5, 10, 15], div_value=2.0) + x = paddle.randn((128, 16)) + y = paddle.randint(low=0, high=20, shape=[128, 1]) + _ = model(x, y) + if __name__ == "__main__": unittest.main() From 6c637ec72f6283689e9626ede44a5e5e546cc324 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=90=A7?= <2085127827@qq.com> Date: Mon, 29 Apr 2024 10:27:05 +0800 Subject: [PATCH 11/21] update --- python/paddle/nn/functional/loss.py | 3 ++- test/legacy_test/test_adaptive_log_softmax_with_loss.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py index 04aefbfc0ab577..07c70187040c6e 100644 --- a/python/paddle/nn/functional/loss.py +++ b/python/paddle/nn/functional/loss.py @@ -4267,7 +4267,7 @@ def gaussian_nll_loss( def adaptive_log_softmax_with_loss( - input, label, head_weight, tail_weights, cutoffs, head_bias=None + input, label, head_weight, tail_weights, cutoffs, head_bias=None, name=None ): r"""Compute adaptive logsoftmax result and negative log likelihood between ``input`` and ``label``. Parameter ``head``, ``tail_weights``, ``cutoffs`` are inner members of AdaptiveLogSoftmaxWithLoss @@ -4280,6 +4280,7 @@ def adaptive_log_softmax_with_loss( tail_weights (Tensor): weight tensor for linear computation, the data type should be float32 or float64. cutoffs (Sequence): Cutoffs used to assign targets to their buckets. head_bias (Tensor, optional): bias tensor for linear computation, the data type should be float32 or float64. + name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. Returns: output (Tensor): The tensor sotring adaptive logsoftmax result, the shape of output is [N] diff --git a/test/legacy_test/test_adaptive_log_softmax_with_loss.py b/test/legacy_test/test_adaptive_log_softmax_with_loss.py index da55f5850323f5..1ea804ee7eb16e 100644 --- a/test/legacy_test/test_adaptive_log_softmax_with_loss.py +++ b/test/legacy_test/test_adaptive_log_softmax_with_loss.py @@ -273,7 +273,7 @@ def test_shape(self): y = paddle.to_tensor([0, 5, 10]) model(x, y) - def test_forwadr(self): + def test_forward(self): n_classes = 4 in_features = 8 cutoffs = [2] From d2190fcec862e1b447216054ceba88a0cf276b4a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=90=A7?= <2085127827@qq.com> Date: Mon, 29 Apr 2024 11:14:16 +0800 Subject: [PATCH 12/21] update test_gard --- .../test_adaptive_log_softmax_with_loss.py | 124 ++++++++++++++++++ 1 file changed, 124 insertions(+) diff --git a/test/legacy_test/test_adaptive_log_softmax_with_loss.py b/test/legacy_test/test_adaptive_log_softmax_with_loss.py index 1ea804ee7eb16e..4690b824004e61 100644 --- a/test/legacy_test/test_adaptive_log_softmax_with_loss.py +++ b/test/legacy_test/test_adaptive_log_softmax_with_loss.py @@ -446,5 +446,129 @@ def test_dim_error(self): _ = model(x, y) + def test_gard(self): + n_classes = 4 + in_features = 8 + cutoffs = [2] + + x = paddle.to_tensor( + [ + [ + 0.99785769, + -1.14492130, + 0.62956816, + 0.77550924, + -1.97198308, + 0.50906199, + 0.76702958, + 1.31143034, + ], + [ + 0.17371807, + 2.68322444, + 1.90870595, + 0.58601201, + -0.78898108, + 0.42098731, + -0.74253917, + -0.37492049, + ], + [ + -0.77694625, + -0.11529812, + 0.38232428, + 0.70575434, + 0.73429769, + 0.81399834, + 0.14212975, + 0.12567955, + ], + [ + 0.44165909, + 0.23613696, + 0.81143701, + 0.60473150, + 0.77017546, + 0.27865678, + -0.03236491, + 0.31634274, + ], + [ + 0.15336825, + -0.66177142, + -0.01784009, + 0.08901446, + 0.85228783, + 1.49427640, + -1.66938102, + 0.86154014, + ], + [ + -0.60814697, + 1.26191938, + -0.21735200, + -0.88890392, + 0.49093658, + -1.28960681, + 1.06943762, + 0.15803306, + ], + [ + -0.12136814, + -0.16133699, + 0.15643604, + 0.79464215, + -1.02201688, + 0.26957786, + -0.31038952, + 0.93334937, + ], + [ + 0.66997373, + 0.95807010, + -0.66944563, + -0.89887059, + 1.00404060, + 0.69594669, + -0.82105070, + 1.15200853, + ], + ], + dtype='float32', + ) + labels = paddle.to_tensor([3, 3, 3, 2, 3, 0, 0, 0], dtype='int64') + model = SimpleModel(in_features, n_classes, cutoffs, div_value=2.0) + + _, loss = model(x, labels) + + weights = model.adaptive_softmax.head_weight + loss.backward() + analytic_grads = weights.grad.numpy() + + h = 1e-5 + weights_np = weights.numpy().copy() + grad_numerical = np.zeros_like(weights_np) + + it = np.nditer(weights_np, flags=['multi_index'], op_flags=['readwrite']) + while not it.finished: + ix = it.multi_index + oldval = weights_np[ix] + weights_np[ix] = oldval + h + model.adaptive_softmax.head_weight.set_value(paddle.to_tensor(weights_np)) + _, y_pos = model(x, labels) + loss_pos = y_pos.mean() + + weights_np[ix] = oldval - h + model.adaptive_softmax.head_weight.set_value(paddle.to_tensor(weights_np)) + _, y_neg = model(x, labels) + loss_neg = y_neg.mean() + + grad_numerical[ix] = (loss_pos - loss_neg) / (2 * h) + weights_np[ix] = oldval + it.iternext() + + np.allclose(analytic_grads, grad_numerical, rtol=1e-5, atol=1e-5) + + if __name__ == "__main__": unittest.main() From f04cb40fa6890a61fca59984e10671ddb61be77d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=90=A7?= <2085127827@qq.com> Date: Tue, 30 Apr 2024 09:53:37 +0800 Subject: [PATCH 13/21] update --- .../test_adaptive_log_softmax_with_loss.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/test/legacy_test/test_adaptive_log_softmax_with_loss.py b/test/legacy_test/test_adaptive_log_softmax_with_loss.py index 4690b824004e61..7886046a4b1137 100644 --- a/test/legacy_test/test_adaptive_log_softmax_with_loss.py +++ b/test/legacy_test/test_adaptive_log_softmax_with_loss.py @@ -445,7 +445,6 @@ def test_dim_error(self): y = paddle.randint(low=0, high=20, shape=[128, 1]) _ = model(x, y) - def test_gard(self): n_classes = 4 in_features = 8 @@ -549,24 +548,30 @@ def test_gard(self): weights_np = weights.numpy().copy() grad_numerical = np.zeros_like(weights_np) - it = np.nditer(weights_np, flags=['multi_index'], op_flags=['readwrite']) + it = np.nditer( + weights_np, flags=['multi_index'], op_flags=['readwrite'] + ) while not it.finished: ix = it.multi_index oldval = weights_np[ix] weights_np[ix] = oldval + h - model.adaptive_softmax.head_weight.set_value(paddle.to_tensor(weights_np)) + model.adaptive_softmax.head_weight.set_value( + paddle.to_tensor(weights_np) + ) _, y_pos = model(x, labels) loss_pos = y_pos.mean() weights_np[ix] = oldval - h - model.adaptive_softmax.head_weight.set_value(paddle.to_tensor(weights_np)) + model.adaptive_softmax.head_weight.set_value( + paddle.to_tensor(weights_np) + ) _, y_neg = model(x, labels) loss_neg = y_neg.mean() grad_numerical[ix] = (loss_pos - loss_neg) / (2 * h) weights_np[ix] = oldval it.iternext() - + np.allclose(analytic_grads, grad_numerical, rtol=1e-5, atol=1e-5) From 9118438ea711c0b48f599cecb5c300ed6e7287d2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=90=A7?= <2085127827@qq.com> Date: Thu, 9 May 2024 20:30:17 +0800 Subject: [PATCH 14/21] update information --- python/paddle/nn/functional/loss.py | 4 ++-- python/paddle/nn/layer/loss.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py index fa708873a3b017..a7b1781444ec17 100644 --- a/python/paddle/nn/functional/loss.py +++ b/python/paddle/nn/functional/loss.py @@ -4292,8 +4292,8 @@ def adaptive_log_softmax_with_loss( Args: input (Tensor): Input tensor, the data type should be float32 or float64. label (Tensor): Label tensor, the data type should be float32 or float64. - head_weight (Tensor): weight tensor for linear computation, the data type should be float32 or float64. - tail_weights (Tensor): weight tensor for linear computation, the data type should be float32 or float64. + head_weight (Tensor): weight tensor for linear computation, the data type should be float32 or float64, the shape should be [input.shape[1], shortlist_size + n_clusters], where shortlist_size is the first element in the cutoffs list, and n_clusters is the length of the cutoffs list minus 1. + tail_weights (list[Tensor]): weight tensor list for linear computation, the data type should be float32 or float64. The number of elements in the tail_weights depends on the value of the n_clusters, and each element contains the weights of two linear layers, their dimensions are [input.shape[1], hsz] and [hsz, osz], where hsz is the number of input features in_features divided by div_value to the power (i + 1), where i is the cyclic variable, from 0 to n_clusters - 1, and osz is the (i + 1) The difference between the cutoff and the ith cutoff. cutoffs (Sequence): Cutoffs used to assign targets to their buckets. head_bias (Tensor, optional): bias tensor for linear computation, the data type should be float32 or float64. name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. diff --git a/python/paddle/nn/layer/loss.py b/python/paddle/nn/layer/loss.py index 9b40234cb0ed9c..3c6b9ce4f1c28f 100644 --- a/python/paddle/nn/layer/loss.py +++ b/python/paddle/nn/layer/loss.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +[# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. From c5e1eb11e980bb445f693266c3b40769c3cb74fc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=90=A7?= <2085127827@qq.com> Date: Sat, 11 May 2024 23:26:07 +0800 Subject: [PATCH 15/21] update --- python/paddle/nn/layer/loss.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/nn/layer/loss.py b/python/paddle/nn/layer/loss.py index 3c6b9ce4f1c28f..9b40234cb0ed9c 100644 --- a/python/paddle/nn/layer/loss.py +++ b/python/paddle/nn/layer/loss.py @@ -1,4 +1,4 @@ -[# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. From a367a902e77d6b42e55712d783c67ef63f75b00d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=90=A7?= <2085127827@qq.com> Date: Thu, 16 May 2024 17:08:19 +0800 Subject: [PATCH 16/21] update --- python/paddle/nn/functional/loss.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py index 9b8629a8482edd..9065d7a27c95d9 100644 --- a/python/paddle/nn/functional/loss.py +++ b/python/paddle/nn/functional/loss.py @@ -4409,7 +4409,7 @@ def adaptive_log_softmax_with_loss( scatter_output = paddle.scatter_nd( row_indices.unsqueeze(1), local_logprob.squeeze(1), output.shape ) - output = output * (scatter_output == 0) + scatter_output + output = output * (scatter_output == 0).astype('float32') + scatter_output used_rows += row_indices.numel() From 66adc443a8756afe49d75318b388cdebc756aa81 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=90=A7?= <2085127827@qq.com> Date: Thu, 16 May 2024 20:39:47 +0800 Subject: [PATCH 17/21] codestyle --- python/paddle/nn/functional/loss.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py index 9065d7a27c95d9..d44de01818dcdb 100644 --- a/python/paddle/nn/functional/loss.py +++ b/python/paddle/nn/functional/loss.py @@ -4409,7 +4409,10 @@ def adaptive_log_softmax_with_loss( scatter_output = paddle.scatter_nd( row_indices.unsqueeze(1), local_logprob.squeeze(1), output.shape ) - output = output * (scatter_output == 0).astype('float32') + scatter_output + output = ( + output * (scatter_output == 0).astype('float32') + + scatter_output + ) used_rows += row_indices.numel() From 30ded8c0f7bc5f10001a1491fda1c8bc7a9da3cb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=90=A7?= <2085127827@qq.com> Date: Tue, 21 May 2024 16:42:41 +0800 Subject: [PATCH 18/21] update --- python/paddle/nn/functional/loss.py | 12 ++++++------ python/paddle/nn/layer/loss.py | 10 +++++----- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py index d44de01818dcdb..77167990064a02 100644 --- a/python/paddle/nn/functional/loss.py +++ b/python/paddle/nn/functional/loss.py @@ -4292,22 +4292,22 @@ def adaptive_log_softmax_with_loss( ): r"""Compute adaptive logsoftmax result and negative log likelihood between ``input`` and ``label``. Parameter ``head``, ``tail_weights``, ``cutoffs`` are inner members of AdaptiveLogSoftmaxWithLoss - Please refer to :ref:`_cn_api_paddle_nn_AdaptiveLogSoftmaxWithLoss`. + Please refer to :ref:`api_paddle_nn_AdaptiveLogSoftmaxWithLoss`. Args: input (Tensor): Input tensor, the data type should be float32 or float64. label (Tensor): Label tensor, the data type should be float32 or float64. - head_weight (Tensor): weight tensor for linear computation, the data type should be float32 or float64, the shape should be [input.shape[1], shortlist_size + n_clusters], where shortlist_size is the first element in the cutoffs list, and n_clusters is the length of the cutoffs list minus 1. - tail_weights (list[Tensor]): weight tensor list for linear computation, the data type should be float32 or float64. The number of elements in the tail_weights depends on the value of the n_clusters, and each element contains the weights of two linear layers, their dimensions are [input.shape[1], hsz] and [hsz, osz], where hsz is the number of input features in_features divided by div_value to the power (i + 1), where i is the cyclic variable, from 0 to n_clusters - 1, and osz is the (i + 1) The difference between the cutoff and the ith cutoff. + head_weight (Tensor): weight tensor for linear computation, the data type should be float32 or float64, the shape should be ``[input.shape[1], shortlist_size + n_clusters]``, where ``shortlist_size`` is the first element in the cutoffs list, and ``n_clusters`` is the length of the cutoffs list minus 1. + tail_weights (list[Tensor]): weight tensor list for linear computation, the data type should be float32 or float64. The number of elements in the tail_weights depends on the value of the n_clusters, and each element contains the weights of two linear layers, their dimensions are ``[input.shape[1], hsz]`` and ``[hsz, osz]``, where ``hsz`` is the number of input features in_features divided by div_value to the power (i + 1), where i is the cyclic variable, from 0 to n_clusters - 1, and ``osz`` is the (i + 1) The difference between the cutoff and the ith cutoff. cutoffs (Sequence): Cutoffs used to assign targets to their buckets. head_bias (Tensor, optional): bias tensor for linear computation, the data type should be float32 or float64. name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. Returns: - output (Tensor): The tensor sotring adaptive logsoftmax result, the shape of output is [N] - loss (Tensor): The tensor variable storing the adaptive_log_softmax_loss of input and label. + - output (Tensor). The tensor sotring adaptive logsoftmax result, the shape of output is [N] + - loss (Tensor). The tensor variable storing the adaptive_log_softmax_loss of input and label. - Examples:: + Examples: .. code-block:: python >>> import paddle diff --git a/python/paddle/nn/layer/loss.py b/python/paddle/nn/layer/loss.py index 9b40234cb0ed9c..e8246c3916bb89 100644 --- a/python/paddle/nn/layer/loss.py +++ b/python/paddle/nn/layer/loss.py @@ -2231,7 +2231,7 @@ def forward(self, input, label, variance): class AdaptiveLogSoftmaxWithLoss(Layer): r"""Adaptive softmax is an approximate strategy for training models with large output spaces. It is most effective when the label distribution is highly imbalanced, for example in natural language modelling, where the word frequency - distribution approximately follows the ``Zipf's law``. + distribution approximately follows the `Zipf's law `_. Adaptive softmax partitions the labels into several clusters, according to their frequency. These clusters may contain different number of targets each. Additionally, clusters containing less frequent labels assign lower dimensional @@ -2265,12 +2265,12 @@ class AdaptiveLogSoftmaxWithLoss(Layer): weight_attr (ParamAttr, optional): The attribute for the learnable weight of this layer. The default value is None. If the Initializer of the param_attr is not set, the parameter is initialized with Xavier. - For detailed information, please refer to paddle.ParamAttr. + For detailed information, please refer to :ref:`api_paddle_ParamAttr`. bias_attr (ParamAttr|bool, optional): The attribute for the learnable bias of this layer. If it is set to False, no bias will be added to the output. If it is set to None or one kind of ParamAttr, a bias parameter will be created according to ParamAttr. For detailed information, please refer - to paddle.ParamAttr. The default value is None and the bias will be + to :ref:`api_paddle_ParamAttr`. The default value is None and the bias will be initialized to zero. div_value (float, optional): value used as an exponent to compute sizes of the clusters. Default: 4.0. head_bias (bool, optional): If ``True``, adds a bias term to the 'head' of the adaptive softmax. Default: ``False``. @@ -2280,12 +2280,12 @@ class AdaptiveLogSoftmaxWithLoss(Layer): - input (Tensor): The input tensor. The shapes is [N, in_features]. N is batch size. - label (Tensor): target. The shapes is `[N]` - output1 (Tensor): The shape is `[N]` - - output2 (Scalar): + - output2 (Scalar). Returns: A callable object of AdaptiveLogSoftmaxWithLoss. - Examples:: + Examples: .. code-block:: python >>> import paddle From ad2d0c412d17efa1ea5b5b34e34903c8f6fead9e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=90=A7?= <2085127827@qq.com> Date: Wed, 22 May 2024 14:38:53 +0800 Subject: [PATCH 19/21] update --- python/paddle/nn/functional/loss.py | 2 +- python/paddle/nn/layer/loss.py | 12 ++++++------ 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py index 77167990064a02..560f7cd10e7b3c 100644 --- a/python/paddle/nn/functional/loss.py +++ b/python/paddle/nn/functional/loss.py @@ -4304,7 +4304,7 @@ def adaptive_log_softmax_with_loss( name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. Returns: - - output (Tensor). The tensor sotring adaptive logsoftmax result, the shape of output is [N] + - output (Tensor). The tensor sotring adaptive logsoftmax result, the shape of output is ``[N]`` - loss (Tensor). The tensor variable storing the adaptive_log_softmax_loss of input and label. Examples: diff --git a/python/paddle/nn/layer/loss.py b/python/paddle/nn/layer/loss.py index e8246c3916bb89..4e6bb7b5df9329 100644 --- a/python/paddle/nn/layer/loss.py +++ b/python/paddle/nn/layer/loss.py @@ -2242,20 +2242,20 @@ class AdaptiveLogSoftmaxWithLoss(Layer): should also be cheap to compute -- that is, contain a small number of assigned labels. We highly recommend taking a look at the original paper for more details. - For :attr:``cutoffs`` should be an ordered Sequence of integers sorted in the increasing order. It controls number of + For :attr:`cutoffs` should be an ordered Sequence of integers sorted in the increasing order. It controls number of clusters and the partitioning of targets into clusters. For example setting ``cutoffs = [10, 100, 1000]`` means that first ``10`` targets will be assigned to the 'head' of the adaptive softmax, targets ``11, 12, ..., 100`` will be assigned to the first cluster, and targets ``101, 102, ..., 1000`` will be assigned to the second cluster, while targets ``1001, 1002, ..., n_classes - 1`` will be assigned to the last, third cluster. - For :attr:``div_value`` is used to compute the size of each additional cluster, which is given as follow: + For :attr:`div_value` is used to compute the size of each additional cluster, which is given as follow: .. math:: \lfloor \frac{\text{in\_features}}{\text{div\_value}^{idx}} \rfloor where :math:``idx`` is the cluster index (with clusters for less frequent words having larger indices, and indices starting from :math:``1``). - For :attr:``head_bias`` if set to True, adds a bias term to the 'head' of the adaptive softmax. See paper for details. Set to False in the official implementation. + For :attr:`head_bias` if set to True, adds a bias term to the 'head' of the adaptive softmax. See paper for details. Set to False in the official implementation. Args: @@ -2278,14 +2278,14 @@ class AdaptiveLogSoftmaxWithLoss(Layer): Shape: - input (Tensor): The input tensor. The shapes is [N, in_features]. N is batch size. - - label (Tensor): target. The shapes is `[N]` - - output1 (Tensor): The shape is `[N]` + - label (Tensor): target. The shapes is ``[N]`` + - output1 (Tensor): The shape is ``[N]`` - output2 (Scalar). Returns: A callable object of AdaptiveLogSoftmaxWithLoss. - Examples: + Examples: .. code-block:: python >>> import paddle From b231a1d632c1c02e9e3d735e520a40cc9c3a9e62 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=90=A7?= <2085127827@qq.com> Date: Wed, 22 May 2024 14:45:10 +0800 Subject: [PATCH 20/21] update --- python/paddle/nn/functional/loss.py | 2 +- python/paddle/nn/layer/loss.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py index 560f7cd10e7b3c..8ebd38cf414886 100644 --- a/python/paddle/nn/functional/loss.py +++ b/python/paddle/nn/functional/loss.py @@ -4298,7 +4298,7 @@ def adaptive_log_softmax_with_loss( input (Tensor): Input tensor, the data type should be float32 or float64. label (Tensor): Label tensor, the data type should be float32 or float64. head_weight (Tensor): weight tensor for linear computation, the data type should be float32 or float64, the shape should be ``[input.shape[1], shortlist_size + n_clusters]``, where ``shortlist_size`` is the first element in the cutoffs list, and ``n_clusters`` is the length of the cutoffs list minus 1. - tail_weights (list[Tensor]): weight tensor list for linear computation, the data type should be float32 or float64. The number of elements in the tail_weights depends on the value of the n_clusters, and each element contains the weights of two linear layers, their dimensions are ``[input.shape[1], hsz]`` and ``[hsz, osz]``, where ``hsz`` is the number of input features in_features divided by div_value to the power (i + 1), where i is the cyclic variable, from 0 to n_clusters - 1, and ``osz`` is the (i + 1) The difference between the cutoff and the ith cutoff. + tail_weights (list[Tensor]): weight tensor list for linear computation, the data type should be float32 or float64. The number of elements in the tail_weights depends on the value of the n_clusters, and each element contains the weights of two linear layers, their dimensions are ``[input.shape[1], hsz]`` and ``[hsz, osz]``, where ``hsz`` is the number of input features in_features divided by div_value to the power ``(i + 1)``, where i is the cyclic variable, from ``0`` to ``n_clusters - 1``, and ``osz`` is the ``(i + 1)`` The difference between the cutoff and the ith cutoff. cutoffs (Sequence): Cutoffs used to assign targets to their buckets. head_bias (Tensor, optional): bias tensor for linear computation, the data type should be float32 or float64. name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. diff --git a/python/paddle/nn/layer/loss.py b/python/paddle/nn/layer/loss.py index 4e6bb7b5df9329..f93315c712d527 100644 --- a/python/paddle/nn/layer/loss.py +++ b/python/paddle/nn/layer/loss.py @@ -2253,7 +2253,7 @@ class AdaptiveLogSoftmaxWithLoss(Layer): .. math:: \lfloor \frac{\text{in\_features}}{\text{div\_value}^{idx}} \rfloor - where :math:``idx`` is the cluster index (with clusters for less frequent words having larger indices, and indices starting from :math:``1``). + where :math:`idx` is the cluster index (with clusters for less frequent words having larger indices, and indices starting from :math:`1`). For :attr:`head_bias` if set to True, adds a bias term to the 'head' of the adaptive softmax. See paper for details. Set to False in the official implementation. @@ -2277,7 +2277,7 @@ class AdaptiveLogSoftmaxWithLoss(Layer): name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. Shape: - - input (Tensor): The input tensor. The shapes is [N, in_features]. N is batch size. + - input (Tensor): The input tensor. The shapes is ``[N, in_features]``. N is batch size. - label (Tensor): target. The shapes is ``[N]`` - output1 (Tensor): The shape is ``[N]`` - output2 (Scalar). @@ -2316,8 +2316,8 @@ class AdaptiveLogSoftmaxWithLoss(Layer): Note: Labels passed as inputs to this module should be sorted according to their frequency. This means that the most - frequent label should be represented by the index `0`, and the least frequent label should be represented by - the index `n_classes - 1`. To compute log-probabilities for all classes, the ``log_prob`` method can be used. + frequent label should be represented by the index ``0``, and the least frequent label should be represented by + the index ``n_classes - 1``. To compute log-probabilities for all classes, the ``log_prob`` method can be used. """ def __init__( From 45c860bbcd7705f60e0329dc01c73a302d01b1a0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=90=A7?= <2085127827@qq.com> Date: Wed, 22 May 2024 15:24:34 +0800 Subject: [PATCH 21/21] update --- python/paddle/nn/functional/loss.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py index 8ebd38cf414886..47a2cd3116b675 100644 --- a/python/paddle/nn/functional/loss.py +++ b/python/paddle/nn/functional/loss.py @@ -4300,8 +4300,8 @@ def adaptive_log_softmax_with_loss( head_weight (Tensor): weight tensor for linear computation, the data type should be float32 or float64, the shape should be ``[input.shape[1], shortlist_size + n_clusters]``, where ``shortlist_size`` is the first element in the cutoffs list, and ``n_clusters`` is the length of the cutoffs list minus 1. tail_weights (list[Tensor]): weight tensor list for linear computation, the data type should be float32 or float64. The number of elements in the tail_weights depends on the value of the n_clusters, and each element contains the weights of two linear layers, their dimensions are ``[input.shape[1], hsz]`` and ``[hsz, osz]``, where ``hsz`` is the number of input features in_features divided by div_value to the power ``(i + 1)``, where i is the cyclic variable, from ``0`` to ``n_clusters - 1``, and ``osz`` is the ``(i + 1)`` The difference between the cutoff and the ith cutoff. cutoffs (Sequence): Cutoffs used to assign targets to their buckets. - head_bias (Tensor, optional): bias tensor for linear computation, the data type should be float32 or float64. - name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. + head_bias (Tensor, optional): bias tensor for linear computation, the data type should be float32 or float64. Default: ``None``. + name (str, optional): Name for the operation (optional, default is ``None``). For more information, please refer to :ref:`api_guide_Name`. Returns: - output (Tensor). The tensor sotring adaptive logsoftmax result, the shape of output is ``[N]``