diff --git a/python/paddle/nn/__init__.py b/python/paddle/nn/__init__.py
index 7818c2398494d..a9d8312bb4ca0 100644
--- a/python/paddle/nn/__init__.py
+++ b/python/paddle/nn/__init__.py
@@ -86,6 +86,7 @@
 from .layer.distance import PairwiseDistance
 from .layer.layers import Layer
 from .layer.loss import (
+    AdaptiveLogSoftmaxWithLoss,
     BCELoss,
     BCEWithLogitsLoss,
     CosineEmbeddingLoss,
@@ -295,6 +296,7 @@
     'TripletMarginLoss',
     'SoftMarginLoss',
     'GaussianNLLLoss',
+    'AdaptiveLogSoftmaxWithLoss',
     'Unflatten',
     'FractionalMaxPool2D',
     'FractionalMaxPool3D',
diff --git a/python/paddle/nn/functional/__init__.py b/python/paddle/nn/functional/__init__.py
index 7722ffb437389..4543d5c8ca14d 100644
--- a/python/paddle/nn/functional/__init__.py
+++ b/python/paddle/nn/functional/__init__.py
@@ -94,6 +94,7 @@
 )
 from .input import embedding, one_hot
 from .loss import (
+    adaptive_log_softmax_with_loss,
     binary_cross_entropy,
     binary_cross_entropy_with_logits,
     cosine_embedding_loss,
@@ -276,6 +277,7 @@
     'rrelu',
     'triplet_margin_with_distance_loss',
     'triplet_margin_loss',
+    'adaptive_log_softmax_with_loss',
     'multi_margin_loss',
     'soft_margin_loss',
     'gaussian_nll_loss',
diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py
index d15495993ce0e..47a2cd3116b67 100644
--- a/python/paddle/nn/functional/loss.py
+++ b/python/paddle/nn/functional/loss.py
@@ -4285,3 +4285,154 @@ def gaussian_nll_loss(
         return paddle.sum(loss, name=name)
     elif reduction == 'none':
         return loss
+
+
+def adaptive_log_softmax_with_loss(
+    input, label, head_weight, tail_weights, cutoffs, head_bias=None, name=None
+):
+    r"""Compute the adaptive log softmax result and the negative log likelihood between ``input`` and ``label``.
+    The parameters ``head_weight``, ``tail_weights``, ``cutoffs`` and ``head_bias`` are inner members of ``AdaptiveLogSoftmaxWithLoss``.
+    Please refer to :ref:`api_paddle_nn_AdaptiveLogSoftmaxWithLoss`.
+
+    Args:
+        input (Tensor): Input tensor, the data type should be float32 or float64.
+        label (Tensor): Label tensor holding class indices, the data type should be int32 or int64.
+        head_weight (Tensor): Weight tensor for the head linear computation, the data type should be float32 or float64. The shape should be ``[input.shape[1], shortlist_size + n_clusters]``, where ``shortlist_size`` is the first element in the cutoffs list and ``n_clusters`` is the length of the cutoffs list minus 1.
+        tail_weights (list[Tensor]): Weight tensor list for the tail linear computations, the data type should be float32 or float64. The list has ``n_clusters`` elements, and element ``i`` (for ``i`` from ``0`` to ``n_clusters - 1``) contains the weights of two linear layers with shapes ``[input.shape[1], hsz]`` and ``[hsz, osz]``, where ``hsz`` is ``in_features`` divided by ``div_value`` to the power ``(i + 1)`` and ``osz`` is the difference between the ``(i + 1)``-th cutoff and the ``i``-th cutoff.
+        cutoffs (Sequence): Cutoffs used to assign targets to their buckets.
+        head_bias (Tensor, optional): Bias tensor for the head linear computation, the data type should be float32 or float64. Default: ``None``.
+        name (str, optional): Name for the operation (optional, default is ``None``). For more information, please refer to :ref:`api_guide_Name`.
+
+    Returns:
+        - output (Tensor). The tensor storing the adaptive log softmax result; the shape of output is ``[N]``.
+        - loss (Tensor). The scalar tensor storing the adaptive_log_softmax_loss of input and label.
+
+    Examples:
+        .. code-block:: python

+            >>> import paddle
+            >>> import paddle.nn.functional as F
+
+            >>> paddle.seed(2024)
+            >>> input = paddle.randn([3, 5], dtype=paddle.float32)
+            >>> head_weight = paddle.randn([5, 3], dtype=paddle.float32)
+            >>> head_bias = paddle.randn([3], dtype=paddle.float32)
+            >>> # each entry of tail_weights holds the two projection weights of one tail cluster
+            >>> tail_weights = []
+            >>> tail_weights.append([paddle.randn([5, 2], dtype=paddle.float32), paddle.randn([2, 1], dtype=paddle.float32)])
+            >>> out, loss = F.adaptive_log_softmax_with_loss(input, paddle.full((3,), 1, dtype='int64'), head_weight, tail_weights, cutoffs=[2], head_bias=head_bias)
+            >>> print(out)
+            Tensor(shape=[3], dtype=float32, place=Place(cpu), stop_gradient=True,
+            [-0.99842924, -2.27753878, -0.16740258])
+            >>> print(loss)
+            Tensor(shape=[], dtype=float32, place=Place(cpu), stop_gradient=True,
+            1.14779019)
+    """
+    target_dim = label.dim()
+
+    if target_dim == 1:
+        if input.shape[0] != label.shape[0]:
+            raise ValueError(
+                'Input and label should have the same size '
+                'in the batch dimension.'
+            )
+        if input.dim() != 2:
+            raise ValueError(
+                '1D label tensor expects 2D input tensors, '
+                f'but found inputs with size {input.shape}'
+            )
+    elif target_dim == 0:
+        if input.dim() != 1:
+            raise ValueError(
+                '0D label tensor expects 1D input tensors, '
+                f'but found inputs with size {input.shape}'
+            )
+    else:
+        raise ValueError(
+            '0D or 1D label tensor expected, multi-label not supported'
+        )
+
+    is_batched = target_dim > 0
+    input = input if is_batched else input.unsqueeze(0)
+    label = label if is_batched else label.unsqueeze(0)
+
+    used_rows = 0
+    batch_size = label.shape[0]
+
+    output = paddle.zeros([batch_size], dtype=input.dtype)
+    gather_inds = paddle.empty([batch_size], dtype=label.dtype)
+
+    # Route every label either to the head shortlist (i == 0) or to one of the tail clusters.
+    cutoff_values = [0] + cutoffs
+    for i in range(len(cutoff_values) - 1):
+        low_idx = cutoff_values[i]
+        high_idx = cutoff_values[i + 1]
+
+        label_mask = (label >= low_idx) & (label < high_idx)
+        row_indices = label_mask.nonzero().squeeze()
+
+        if row_indices.numel() == 0:
+            continue
+
+        if i == 0:
+            # Shortlist targets keep their class index as the column gathered from the head output.
+            scatter_output = paddle.scatter_nd(
+                row_indices.unsqueeze(1),
+                label.masked_select(label_mask),
+                gather_inds.shape,
+            )
+            gather_inds = scatter_output
+        else:
+            # Tail targets combine the head log-probability of their cluster token
+            # with their relative log-probability inside the cluster.
+            relative_label = label[label_mask] - low_idx
+            input_subset = input.index_select(row_indices, axis=0)
+
+            cluster_output = paddle.nn.functional.linear(
+                x=input_subset, weight=tail_weights[i - 1][0]
+            )
+            cluster_output = paddle.nn.functional.linear(
+                x=cluster_output, weight=tail_weights[i - 1][1]
+            )
+
+            cluster_index = cutoffs[0] + i - 1
+
+            gather_inds = paddle.index_fill(
+                gather_inds, row_indices, 0, cluster_index
+            )
+
+            cluster_logprob = paddle.nn.functional.log_softmax(
+                cluster_output, axis=1
+            )
+
+            local_logprob = paddle.take_along_axis(
+                cluster_logprob, relative_label.unsqueeze(1), axis=1
+            )
+            scatter_output = paddle.scatter_nd(
+                row_indices.unsqueeze(1), local_logprob.squeeze(1), output.shape
+            )
+            output = (
+                output * (scatter_output == 0).astype('float32')
+                + scatter_output
+            )
+
+        used_rows += row_indices.numel()
+
+    if used_rows != batch_size:
+        raise ValueError(
+            f"label values should be in [0, n_classes - 1], "
+            f"but values in range [{label.min().item()}, {label.max().item()}] "
+            "were found."
+        )
+
+    head_output = paddle.nn.functional.linear(
+        x=input, weight=head_weight, bias=head_bias
+    )
+    head_logprob = paddle.nn.functional.log_softmax(head_output, axis=1)
+    output += paddle.take_along_axis(
+        head_logprob, gather_inds.unsqueeze(1), axis=1
+    ).squeeze()
+    loss = (-output).mean()
+
+    if not is_batched:
+        output = output.squeeze(0)
+
+    return output, loss
diff --git a/python/paddle/nn/layer/__init__.py b/python/paddle/nn/layer/__init__.py
index 6516c85bdefff..27d5cd4ecefa4 100644
--- a/python/paddle/nn/layer/__init__.py
+++ b/python/paddle/nn/layer/__init__.py
@@ -60,6 +60,7 @@
 from .distance import PairwiseDistance  # noqa: F401
 from .layers import Layer  # noqa: F401
 from .loss import (  # noqa: F401
+    AdaptiveLogSoftmaxWithLoss,
     BCELoss,
     BCEWithLogitsLoss,
     CrossEntropyLoss,
diff --git a/python/paddle/nn/layer/loss.py b/python/paddle/nn/layer/loss.py
index 087dc10a58e58..f93315c712d52 100644
--- a/python/paddle/nn/layer/loss.py
+++ b/python/paddle/nn/layer/loss.py
@@ -2226,3 +2226,258 @@ def forward(self, input, label, variance):
             self.name,
         )
         return out
+
+
+class AdaptiveLogSoftmaxWithLoss(Layer):
+    r"""Adaptive softmax is an approximate strategy for training models with large output spaces. It is most effective when
+    the label distribution is highly imbalanced, for example in natural language modelling, where the word frequency
+    distribution approximately follows `Zipf's law <https://en.wikipedia.org/wiki/Zipf%27s_law>`_.
+
+    Adaptive softmax partitions the labels into several clusters, according to their frequency. These clusters may each
+    contain a different number of targets. Additionally, clusters containing less frequent labels assign lower dimensional
+    embeddings to those labels, which speeds up the computation. For each minibatch, only the clusters for which at least
+    one target is present are evaluated.
+
+    The idea is that the clusters which are accessed frequently (like the first one, containing the most frequent labels)
+    should also be cheap to compute, that is, contain a small number of assigned labels. We highly recommend taking
+    a look at the original paper for more details.
+
+    :attr:`cutoffs` should be an ordered sequence of integers sorted in increasing order. It controls the number of
+    clusters and the partitioning of targets into clusters. For example, setting ``cutoffs = [10, 100, 1000]`` means that
+    the first ``10`` targets will be assigned to the 'head' of the adaptive softmax, targets ``11, 12, ..., 100`` will be assigned
+    to the first cluster, and targets ``101, 102, ..., 1000`` will be assigned to the second cluster, while targets
+    ``1001, 1002, ..., n_classes - 1`` will be assigned to the last, third cluster.
+
+    :attr:`div_value` is used to compute the size of each additional cluster, which is given as:
+
+    .. math::
+        \lfloor \frac{\text{in\_features}}{\text{div\_value}^{idx}} \rfloor
+
+    where :math:`idx` is the cluster index (with clusters for less frequent words having larger indices, and indices starting from :math:`1`).
+
+    :attr:`head_bias`, if set to True, adds a bias term to the 'head' of the adaptive softmax. See the paper for details; it is set to False in the official implementation.
+
+
+    Args:
+        in_features (int): Number of features in the input tensor.
+        n_classes (int): Number of classes in the dataset.
+        cutoffs (Sequence): Cutoffs used to assign targets to their buckets.
+        weight_attr (ParamAttr, optional): The attribute for the learnable
+            weight of this layer. The default value is None.
+            If the Initializer of the weight_attr is not set, the parameter is initialized with Xavier.
+            For detailed information, please refer to :ref:`api_paddle_ParamAttr`.
+        bias_attr (ParamAttr|bool, optional): The attribute for the learnable bias
+            of this layer. If it is set to False, no bias will be added to the output.
+            If it is set to None or one kind of ParamAttr, a bias parameter will
+            be created according to ParamAttr. For detailed information, please refer
+            to :ref:`api_paddle_ParamAttr`. The default value is None and the bias will be
+            initialized to zero.
+        div_value (float, optional): Value used as an exponent to compute the sizes of the clusters. Default: 4.0.
+        head_bias (bool, optional): If ``True``, adds a bias term to the 'head' of the adaptive softmax. Default: ``False``.
+        name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
+
+    Shape:
+        - input (Tensor): The input tensor. The shape is ``[N, in_features]``, where N is the batch size.
+        - label (Tensor): The ground truth labels. The shape is ``[N]``.
+        - output1 (Tensor): The shape is ``[N]``.
+        - output2 (Tensor): A scalar tensor storing the loss.
+
+    Returns:
+        A callable object of AdaptiveLogSoftmaxWithLoss.
+
+    Examples:
+        .. code-block:: python
+
+            >>> import paddle
+            >>> import paddle.nn as nn
+            >>> paddle.seed(2024)
+
+            >>> input = paddle.randn([3, 5], dtype="float32")
+            >>> target = paddle.full((3,), 1, dtype='int64')
+            >>> asfm = nn.AdaptiveLogSoftmaxWithLoss(in_features=5, n_classes=3, cutoffs=[2], div_value=2.0, head_bias=False)
+            >>> out, loss = asfm(input, target)
+            >>> print(out)
+            Tensor(shape=[3], dtype=float32, place=Place(cpu), stop_gradient=False,
+            [-1.04691017, -0.42341536, -1.16909981])
+            >>> print(loss)
+            Tensor(shape=[], dtype=float32, place=Place(cpu), stop_gradient=False,
+            0.87980843)
+            >>> out = asfm.log_prob(input)
+            >>> print(out)
+            Tensor(shape=[3, 3], dtype=float32, place=Place(cpu), stop_gradient=False,
+            [[-1.13710010, -1.04691017, -1.11403584],
+             [-1.51841831, -0.42341536, -2.07040048],
+             [-4.25405550, -1.16909981, -0.39282480]])
+            >>> out = asfm.predict(input)
+            >>> print(out)
+            Tensor(shape=[3], dtype=float32, place=Place(cpu), stop_gradient=True,
+            [1., 1., 2.])
+
+    Note:
+        Labels passed as inputs to this module should be sorted according to their frequency. This means that the most
+        frequent label should be represented by the index ``0``, and the least frequent label should be represented by
+        the index ``n_classes - 1``. To compute log-probabilities for all classes, the ``log_prob`` method can be used.
+ """ + + def __init__( + self, + in_features, + n_classes, + cutoffs, + weight_attr=None, + bias_attr=None, + div_value=4.0, + head_bias=False, + name=None, + ): + super().__init__() + self._dtype = self._helper.get_default_dtype() + cutoffs = list(cutoffs) + + if ( + (cutoffs != sorted(cutoffs)) + or (min(cutoffs) <= 0) + or (max(cutoffs) > (n_classes - 1)) + or (len(set(cutoffs)) != len(cutoffs)) + or any(int(c) != c for c in cutoffs) + ): + raise ValueError( + "cutoffs should be a sequence of unique, positive " + "integers sorted in an increasing order, where " + "each value is between 1 and n_classes-1" + ) + + self.in_features = in_features + self.n_classes = n_classes + self.cutoffs = cutoffs + [n_classes] + self.div_value = div_value + self._weight_attr = weight_attr + self._bias_attr = bias_attr + self.is_head_bias = head_bias + + self.shortlist_size = self.cutoffs[0] + self.n_clusters = len(self.cutoffs) - 1 + self.head_size = self.shortlist_size + self.n_clusters + + self.head_weight = self.create_parameter( + shape=[self.in_features, self.head_size], + attr=self._weight_attr, + dtype=self._dtype, + is_bias=False, + ) + if self.is_head_bias: + self.head_bias = self.create_parameter( + shape=[self.head_size], + attr=self._bias_attr, + dtype=self._dtype, + is_bias=True, + ) + else: + self.head_bias = None + + self.tail_weights = [] + + for i in range(self.n_clusters): + hsz = int(self.in_features // (self.div_value ** (i + 1))) + osz = self.cutoffs[i + 1] - self.cutoffs[i] + projection = [] + projection.append( + self.create_parameter( + shape=[self.in_features, hsz], + attr=self._weight_attr, + dtype=self._dtype, + is_bias=False, + ) + ) + projection.append( + self.create_parameter( + shape=[hsz, osz], + attr=self._weight_attr, + dtype=self._dtype, + is_bias=False, + ) + ) + self.tail_weights.append(projection) + + def forward(self, input, label): + return F.adaptive_log_softmax_with_loss( + input, + label, + self.head_weight, + self.tail_weights, + self.cutoffs, + self.head_bias, + ) + + def _get_full_log_prob(self, input, head_output): + out = paddle.empty((head_output.shape[0], self.n_classes)) + head_logprob = F.log_softmax(head_output, axis=1) + + if paddle.in_dynamic_mode(): + out[:, : self.shortlist_size] = head_logprob[ + :, : self.shortlist_size + ] + else: + paddle.static.setitem( + out, + ( + slice(None, None, None), + slice(None, self.shortlist_size, None), + ), + head_logprob, + ) + + for i, (start_idx, stop_idx) in enumerate( + zip(self.cutoffs, self.cutoffs[1:]) + ): + cluster_output = F.linear(x=input, weight=self.tail_weights[i][0]) + cluster_output = F.linear( + x=cluster_output, weight=self.tail_weights[i][1] + ) + cluster_logprob = F.log_softmax(cluster_output, axis=1) + output_logprob = cluster_logprob + head_logprob[ + :, self.shortlist_size + i + ].unsqueeze(1) + + if paddle.in_dynamic_mode(): + out[:, start_idx:stop_idx] = output_logprob + else: + paddle.static.setitem( + out, + (slice(None, None, None), slice(start_idx, stop_idx, None)), + output_logprob, + ) + + return out + + def log_prob(self, input): + head_output = F.linear( + x=input, weight=self.head_weight, bias=self.head_bias + ) + return self._get_full_log_prob(input, head_output) + + def predict(self, input): + head_output = F.linear( + x=input, weight=self.head_weight, bias=self.head_bias + ) + output = paddle.argmax(head_output, axis=1).cast('float32') + not_in_shortlist = output >= self.shortlist_size + all_in_shortlist = not (not_in_shortlist.any()) + + if all_in_shortlist: + return output 
+ elif not_in_shortlist.all(): + log_prob = self._get_full_log_prob(input, head_output) + return paddle.argmax(log_prob, axis=1) + else: + log_prob = self._get_full_log_prob( + input[not_in_shortlist], head_output[not_in_shortlist] + ) + indices = paddle.masked_select( + paddle.arange(len(not_in_shortlist)), not_in_shortlist + ) + result = paddle.scatter( + output, indices, paddle.argmax(log_prob, axis=1).cast('float32') + ) + return result diff --git a/test/legacy_test/test_adaptive_log_softmax_with_loss.py b/test/legacy_test/test_adaptive_log_softmax_with_loss.py new file mode 100644 index 0000000000000..7886046a4b113 --- /dev/null +++ b/test/legacy_test/test_adaptive_log_softmax_with_loss.py @@ -0,0 +1,579 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np + +import paddle +import paddle.optimizer as optim +from paddle import nn +from paddle.base import Program +from paddle.nn import functional as F + + +class SimpleModel(nn.Layer): + def __init__( + self, + in_features, + n_classes, + cutoffs, + div_value=4.0, + head_bias=False, + ): + super().__init__() + self.fc = paddle.nn.Linear(in_features, in_features) + self.adaptive_softmax = nn.AdaptiveLogSoftmaxWithLoss( + in_features, + n_classes, + cutoffs, + div_value=div_value, + head_bias=head_bias, + ) + + def forward(self, input, label=None): + x = self.fc(input) + if label is not None: + return self.adaptive_softmax(x, label) + else: + return self.adaptive_softmax.log_prob(x) + + def predict(self, input): + logprob = self.adaptive_softmax.predict(self.fc(input)) + return logprob + + +class TestNNAdaptiveLogSoftmaxWithLossAPI(unittest.TestCase): + def setUp(self): + paddle.seed(2024) + self.place = ['cpu'] + if paddle.is_compiled_with_cuda(): + self.place.append('gpu') + self.log_np = np.random.randn(4, 8).astype('float32') + self.predict_np = np.abs(np.random.randn(64, 8).astype('float32')) + + def test_dygraph(self): + paddle.disable_static() + for place in self.place: + paddle.device.set_device(place) + x = paddle.randn((4, 8)) + self._test_log_probs_dygraph(x) + x = paddle.abs(paddle.randn((4, 8))) + self._test_correct_dygraph(x) + + def _test_log_probs_dygraph(self, x): + model = SimpleModel(8, 4, [2], div_value=2.0) + logprob_out = model(x) + np.testing.assert_array_almost_equal( + paddle.exp(logprob_out).sum(1), paddle.ones([4]) + ) + + for v in [0, 1, 2, 3]: + y = paddle.full((4,), v, dtype='int64') + out, loss = model(x, y) + np.testing.assert_array_almost_equal( + out, + logprob_out.gather(y.unsqueeze(1), 1) + .slice([1], [0], [1]) + .squeeze(), + ) + np.testing.assert_array_almost_equal( + loss, F.nll_loss(logprob_out, y) + ) + + def _test_correct_dygraph(self, x): + model = SimpleModel(8, 10, [4, 8], div_value=2.0, head_bias=True) + model.adaptive_softmax.head_weight.detach().abs() + model.adaptive_softmax.head_bias.detach().abs() + model.adaptive_softmax.head_weight.detach()[ + model.adaptive_softmax.shortlist_size :, : + 
] *= 0.0 + + out = model.predict(x) + np.testing.assert_array_almost_equal(out, model(x).argmax(axis=1)) + + model = SimpleModel(8, 10, [4, 8], div_value=2.0, head_bias=True) + model.adaptive_softmax.head_weight.detach().abs() + model.adaptive_softmax.head_bias.detach().abs() + model.adaptive_softmax.head_weight.detach()[ + : model.adaptive_softmax.shortlist_size, : + ] *= 0.0 + + out = model.predict(x) + np.testing.assert_array_almost_equal(out, model(x).argmax(axis=1)) + + model = SimpleModel(8, 10, [4, 8], div_value=2.0, head_bias=True) + model.adaptive_softmax.head_weight.detach().abs() + model.adaptive_softmax.head_bias.detach().abs() + + x[:32, : model.adaptive_softmax.shortlist_size] *= 0.0 + x[32:, model.adaptive_softmax.shortlist_size :] *= 0.0 + model.adaptive_softmax.head_weight.detach()[ + : model.adaptive_softmax.shortlist_size, + model.adaptive_softmax.shortlist_size :, + ] *= 0.0 + model.adaptive_softmax.head_weight.detach()[ + model.adaptive_softmax.shortlist_size :, + : model.adaptive_softmax.shortlist_size, + ] *= 0.0 + + out = model.predict(x) + np.testing.assert_array_almost_equal(out, model(x).argmax(axis=1)) + + def _test_log_probs_static(self, place): + paddle.enable_static() + with paddle.static.program_guard(Program()): + model = SimpleModel(8, 4, [2], div_value=2.0) + x = paddle.static.data( + name="log_input", shape=[4, 8], dtype='float32' + ) + out = model(x) + exe = paddle.static.Executor(place=place) + feed_list = {"log_input": self.log_np} + logprob_out = exe.run( + paddle.static.default_main_program(), + feed=feed_list, + fetch_list=[out], + ) + + np.testing.assert_array_almost_equal( + paddle.exp(logprob_out).sum(1), paddle.ones([4]) + ) + + for v in [0, 1, 2, 3]: + y = paddle.full((4,), v, dtype='int64') + out, loss = model(x, y) + f_out, f_loss = exe.run( + paddle.static.default_main_program(), + feed=feed_list, + fetch_list=[out, loss], + ) + np.testing.assert_array_almost_equal( + f_out, + logprob_out.gather(y.unsqueeze(1), 1) + .slice([1], [0], [1]) + .squeeze(), + ) + np.testing.assert_array_almost_equal( + f_loss, F.nll_loss(logprob_out, y) + ) + + def _test_correct_static(self, place): + paddle.enable_static() + with paddle.static.program_guard(Program()): + model = SimpleModel(8, 10, [4, 8], div_value=2.0, head_bias=True) + exe = paddle.static.Executor(place=place) + feed_list = {"predict_input": self.predict_np} + x = paddle.static.data( + name="predict_input", shape=[64, 8], dtype='float32' + ) + model.adaptive_softmax.head_weight.detach().abs() + model.adaptive_softmax.head_bias.detach().abs() + paddle.static.setitem( + model.adaptive_softmax.head_weight.detach(), + ( + slice(model.adaptive_softmax.shortlist_size, None, None), + slice(None, None, None), + ), + 0.0, + ) + out = model.predict(x) + predict_out1 = exe.run( + paddle.static.default_main_program(), + feed=feed_list, + fetch_list=[out], + ) + np.testing.assert_array_almost_equal( + predict_out1, model(x).argmax(axis=1) + ) + + model = SimpleModel(8, 10, [4, 8], div_value=2.0, head_bias=True) + model.adaptive_softmax.head_weight.detach().abs() + model.adaptive_softmax.head_bias.detach().abs() + paddle.static.setitem( + model.adaptive_softmax.head_weight.detach(), + ( + slice(None, model.adaptive_softmax.shortlist_size, None), + slice(None, None, None), + ), + 0.0, + ) + out = model.predict(x) + predict_out2 = exe.run( + paddle.static.default_main_program(), + feed=feed_list, + fetch_list=[out], + ) + np.testing.assert_array_almost_equal( + predict_out2, model(x).argmax(axis=1) + ) + + 
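+            # Mixed case: zero complementary blocks of the input features and the head weight so that
+            # part of the batch resolves inside the shortlist and the rest falls into the tail clusters,
+            # which exercises the scatter branch of predict().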
model = SimpleModel(8, 10, [4, 8], div_value=2.0, head_bias=True) + model.adaptive_softmax.head_weight.detach().abs() + model.adaptive_softmax.head_bias.detach().abs() + paddle.static.setitem( + x, + ( + slice(None, 32, None), + slice(None, model.adaptive_softmax.shortlist_size, None), + ), + 0.0, + ) + paddle.static.setitem( + x, + ( + slice(32, None, None), + slice(model.adaptive_softmax.shortlist_size, None, None), + ), + 0.0, + ) + paddle.static.setitem( + model.adaptive_softmax.head_weight.detach(), + ( + slice( + None, model.adaptive_softmaxasfm.shortlist_size, None + ), + slice(model.adaptive_softmax.shortlist_size, None, None), + ), + 0.0, + ) + paddle.static.setitem( + model.adaptive_softmax.head_weight.detach(), + ( + slice(model.adaptive_softmax.shortlist_size, None, None), + slice(None, model.adaptive_softmax.shortlist_size, None), + ), + 0.0, + ) + out = model.predict(x) + predict_out3 = exe.run( + paddle.static.default_main_program(), + feed=feed_list, + fetch_list=[out], + ) + np.testing.assert_array_almost_equal( + predict_out3, model(x).argmax(axis=1) + ) + + def test_shape(self): + with self.assertRaises(ValueError): + model = SimpleModel(16, 20, [5, 10, 15], div_value=2.0) + x = paddle.randn((2, 16)) + y = paddle.to_tensor([0, 5, 10]) + model(x, y) + + def test_forward(self): + n_classes = 4 + in_features = 8 + cutoffs = [2] + + x = paddle.to_tensor( + [ + [ + 0.99785769, + -1.14492130, + 0.62956816, + 0.77550924, + -1.97198308, + 0.50906199, + 0.76702958, + 1.31143034, + ], + [ + 0.17371807, + 2.68322444, + 1.90870595, + 0.58601201, + -0.78898108, + 0.42098731, + -0.74253917, + -0.37492049, + ], + [ + -0.77694625, + -0.11529812, + 0.38232428, + 0.70575434, + 0.73429769, + 0.81399834, + 0.14212975, + 0.12567955, + ], + [ + 0.44165909, + 0.23613696, + 0.81143701, + 0.60473150, + 0.77017546, + 0.27865678, + -0.03236491, + 0.31634274, + ], + [ + 0.15336825, + -0.66177142, + -0.01784009, + 0.08901446, + 0.85228783, + 1.49427640, + -1.66938102, + 0.86154014, + ], + [ + -0.60814697, + 1.26191938, + -0.21735200, + -0.88890392, + 0.49093658, + -1.28960681, + 1.06943762, + 0.15803306, + ], + [ + -0.12136814, + -0.16133699, + 0.15643604, + 0.79464215, + -1.02201688, + 0.26957786, + -0.31038952, + 0.93334937, + ], + [ + 0.66997373, + 0.95807010, + -0.66944563, + -0.89887059, + 1.00404060, + 0.69594669, + -0.82105070, + 1.15200853, + ], + ], + dtype='float32', + ) + labels = paddle.to_tensor([3, 3, 3, 2, 3, 0, 0, 0], dtype='int64') + model = SimpleModel(in_features, n_classes, cutoffs, div_value=2.0) + + optimizer = optim.Adam( + parameters=model.parameters(), learning_rate=0.001 + ) + for _ in range(2): + _, loss = model(x, labels) + + optimizer.clear_grad() + loss.backward() + optimizer.step() + + tail_weights_before_training = [ + proj[0].numpy().copy() + for proj in model.adaptive_softmax.tail_weights + ] + + with paddle.no_grad(): + output, loss = model(x, labels) + + tail_weights_after_training = [ + proj[0].numpy() for proj in model.adaptive_softmax.tail_weights + ] + + for before, after in zip( + tail_weights_before_training, tail_weights_after_training + ): + assert not np.any(before != after) + + def test_cluster(self): + model = SimpleModel(16, 20, [5, 10, 15], div_value=2.0) + x = paddle.randn((128, 16)) + y = paddle.randint(low=0, high=20, shape=[128]) + output, _ = model(x, y) + self.assertEqual(model.adaptive_softmax.head_weight.shape, [16, 5 + 3]) + self.assertEqual( + model.adaptive_softmax.tail_weights[0][1].shape, [8, 5] + ) + self.assertEqual( + 
model.adaptive_softmax.tail_weights[1][1].shape, [4, 5] + ) + self.assertEqual( + model.adaptive_softmax.tail_weights[2][1].shape, [2, 5] + ) + + self.assertEqual(output.shape, [128]) + + def test_error(self): + with self.assertRaises(ValueError): + _ = SimpleModel(16, 20, [5, 15, 15], div_value=2.0) + + with self.assertRaises(ValueError): + _ = SimpleModel(16, 20, [5, 15, 10], div_value=2.0) + + with self.assertRaises(ValueError): + _ = SimpleModel(16, 20, [5, 10, 25], div_value=2.0) + + with self.assertRaisesRegex( + ValueError, "cutoffs should be a sequence of unique," + ): + _ = SimpleModel(16, 20, [5, 10, 20], div_value=2.0) + + def test_dim_error(self): + with self.assertRaises(ValueError): + model = SimpleModel(16, 20, [5, 10, 15], div_value=2.0) + x = paddle.randn((129, 16)) + y = paddle.randint(low=0, high=20, shape=[128]) + _ = model(x, y) + + with self.assertRaises(ValueError): + model = SimpleModel(16, 20, [5, 10, 15], div_value=2.0) + x = paddle.randn((128, 16)) + y = paddle.randint(low=0, high=20, shape=[]) + _ = model(x, y) + + with self.assertRaises(ValueError): + model = SimpleModel(16, 20, [5, 10, 15], div_value=2.0) + x = paddle.randn((128, 16)) + y = paddle.randint(low=0, high=20, shape=[128, 1]) + _ = model(x, y) + + def test_gard(self): + n_classes = 4 + in_features = 8 + cutoffs = [2] + + x = paddle.to_tensor( + [ + [ + 0.99785769, + -1.14492130, + 0.62956816, + 0.77550924, + -1.97198308, + 0.50906199, + 0.76702958, + 1.31143034, + ], + [ + 0.17371807, + 2.68322444, + 1.90870595, + 0.58601201, + -0.78898108, + 0.42098731, + -0.74253917, + -0.37492049, + ], + [ + -0.77694625, + -0.11529812, + 0.38232428, + 0.70575434, + 0.73429769, + 0.81399834, + 0.14212975, + 0.12567955, + ], + [ + 0.44165909, + 0.23613696, + 0.81143701, + 0.60473150, + 0.77017546, + 0.27865678, + -0.03236491, + 0.31634274, + ], + [ + 0.15336825, + -0.66177142, + -0.01784009, + 0.08901446, + 0.85228783, + 1.49427640, + -1.66938102, + 0.86154014, + ], + [ + -0.60814697, + 1.26191938, + -0.21735200, + -0.88890392, + 0.49093658, + -1.28960681, + 1.06943762, + 0.15803306, + ], + [ + -0.12136814, + -0.16133699, + 0.15643604, + 0.79464215, + -1.02201688, + 0.26957786, + -0.31038952, + 0.93334937, + ], + [ + 0.66997373, + 0.95807010, + -0.66944563, + -0.89887059, + 1.00404060, + 0.69594669, + -0.82105070, + 1.15200853, + ], + ], + dtype='float32', + ) + labels = paddle.to_tensor([3, 3, 3, 2, 3, 0, 0, 0], dtype='int64') + model = SimpleModel(in_features, n_classes, cutoffs, div_value=2.0) + + _, loss = model(x, labels) + + weights = model.adaptive_softmax.head_weight + loss.backward() + analytic_grads = weights.grad.numpy() + + h = 1e-5 + weights_np = weights.numpy().copy() + grad_numerical = np.zeros_like(weights_np) + + it = np.nditer( + weights_np, flags=['multi_index'], op_flags=['readwrite'] + ) + while not it.finished: + ix = it.multi_index + oldval = weights_np[ix] + weights_np[ix] = oldval + h + model.adaptive_softmax.head_weight.set_value( + paddle.to_tensor(weights_np) + ) + _, y_pos = model(x, labels) + loss_pos = y_pos.mean() + + weights_np[ix] = oldval - h + model.adaptive_softmax.head_weight.set_value( + paddle.to_tensor(weights_np) + ) + _, y_neg = model(x, labels) + loss_neg = y_neg.mean() + + grad_numerical[ix] = (loss_pos - loss_neg) / (2 * h) + weights_np[ix] = oldval + it.iternext() + + np.allclose(analytic_grads, grad_numerical, rtol=1e-5, atol=1e-5) + + +if __name__ == "__main__": + unittest.main()
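A minimal usage sketch for reviewers, kept outside the patch: it exercises only the public API added above (``paddle.nn.AdaptiveLogSoftmaxWithLoss`` together with its ``log_prob`` and ``predict`` methods); the feature size, class count, cutoffs, and batch shape below are illustrative assumptions rather than values taken from the tests.

import paddle
import paddle.nn as nn

paddle.seed(2024)

# 100-way classification with a 10-class shortlist and two tail clusters
# (classes 10-49 and 50-99), following the cutoff semantics documented above.
asfm = nn.AdaptiveLogSoftmaxWithLoss(
    in_features=32, n_classes=100, cutoffs=[10, 50], div_value=4.0
)

features = paddle.randn([8, 32])      # batch of 8 feature vectors
labels = paddle.randint(0, 100, [8])  # frequency-sorted class indices
out, loss = asfm(features, labels)    # per-sample log-probabilities and mean negative log likelihood
loss.backward()                       # gradients reach the head and tail projection parameters

full_logprob = asfm.log_prob(features)  # [8, 100] log-probabilities over all classes
predicted = asfm.predict(features)      # most likely class per sample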