【Hackathon 6th No.1】Add AdaptiveLogSoftmaxWithLoss API to Paddle -part #63302
@@ -4285,3 +4285,154 @@ def gaussian_nll_loss(
        return paddle.sum(loss, name=name)
    elif reduction == 'none':
        return loss


def adaptive_log_softmax_with_loss(
    input, label, head_weight, tail_weights, cutoffs, head_bias=None, name=None
):
r"""Compute adaptive logsoftmax result and negative log likelihood between ``input`` and ``label``. | ||||||||||
Parameter ``head``, ``tail_weights``, ``cutoffs`` are inner members of AdaptiveLogSoftmaxWithLoss | ||||||||||
Please refer to :ref:`_cn_api_paddle_nn_AdaptiveLogSoftmaxWithLoss`. | ||||||||||
|
||||||||||
Args: | ||||||||||
input (Tensor): Input tensor, the data type should be float32 or float64. | ||||||||||
label (Tensor): Label tensor, the data type should be float32 or float64. | ||||||||||
head_weight (Tensor): weight tensor for linear computation, the data type should be float32 or float64, the shape should be [input.shape[1], shortlist_size + n_clusters], where shortlist_size is the first element in the cutoffs list, and n_clusters is the length of the cutoffs list minus 1. | ||||||||||
[Review comment] Please format this so it displays more neatly on the docs site; right now everything is crammed together.
        tail_weights (list[Tensor]): Weight tensor list for linear computation, the data type should be float32 or float64.
            The number of elements in tail_weights depends on n_clusters, and each element contains the weights of two
            linear layers with shapes [input.shape[1], hsz] and [hsz, osz], where hsz is in_features divided by
            div_value to the power (i + 1) (i is the loop variable, running from 0 to n_clusters - 1), and osz is the
            difference between the (i + 1)-th cutoff and the i-th cutoff.
        cutoffs (Sequence): Cutoffs used to assign labels to their buckets.
        head_bias (Tensor, optional): Bias tensor for linear computation, the data type should be float32 or float64. Default is None.
[Review comment] Please add the default value.
        name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.

    Returns:
        output (Tensor): The tensor storing the adaptive log softmax result, the shape of output is [N].
        loss (Tensor): The tensor storing the adaptive_log_softmax_loss of input and label.

    Examples:
        .. code-block:: python

            >>> import paddle
            >>> import paddle.nn.functional as F

            >>> paddle.seed(2024)
            >>> input = paddle.randn([3, 5], dtype=paddle.float32)
            >>> head_weight = paddle.randn([5, 3], dtype=paddle.float32)
            >>> head_bias = paddle.randn([3], dtype=paddle.float32)
            >>> tail_weights = []
            >>> tail_weights.append(paddle.randn([5, 2], dtype=paddle.float32))
            >>> tail_weights.append(paddle.randn([2, 1], dtype=paddle.float32))
            >>> out, loss = F.adaptive_log_softmax_with_loss(input, paddle.full((3,), 1, dtype='int64'), head_weight, tail_weights, cutoffs=[2], head_bias=head_bias)
            >>> print(out)
            Tensor(shape=[3], dtype=float32, place=Place(cpu), stop_gradient=True,
                   [-0.99842924, -2.27753878, -0.16740258])
            >>> print(loss)
            Tensor(shape=[], dtype=float32, place=Place(cpu), stop_gradient=True,
                   1.14779019)
    """
    target_dim = label.dim()

    if target_dim == 1:
        if input.shape[0] != label.shape[0]:
            raise ValueError(
                'Input and label should have the same size '
                'in the batch dimension.'
            )
        if input.dim() != 2:
            raise ValueError(
                '1D label tensor expects 2D input tensors, '
                'but found inputs with size',
                input.shape,
            )
    elif target_dim == 0:
        if input.dim() != 1:
            raise ValueError(
                '0D label tensor expects 1D input tensors, '
                'but found inputs with size',
                input.shape,
            )
    else:
        raise ValueError(
            '0D or 1D label tensor expected, multi-label not supported'
        )

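    # Treat a 0-D label as a batch of one so that the rest of the computation
    # can assume a batched 2-D input and a 1-D label.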
    is_batched = target_dim > 0
    input = input if is_batched else input.unsqueeze(0)
    label = label if is_batched else label.unsqueeze(0)

    used_rows = 0
    batch_size = label.shape[0]

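    # ``output`` accumulates the per-row log-probability of the target class.
    # ``gather_inds`` records, for each row, which column of the head
    # distribution to read later: the label itself for shortlist rows, or the
    # reserved cluster slot for rows whose label falls in a tail cluster.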
    output = paddle.zeros([batch_size], dtype=input.dtype)
    gather_inds = paddle.empty([batch_size], dtype=label.dtype)

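    # Walk over the half-open ranges [cutoff_values[i], cutoff_values[i + 1])
    # and handle the rows whose labels fall into each bucket.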
    cutoff_values = [0] + cutoffs
    for i in range(len(cutoff_values) - 1):
        low_idx = cutoff_values[i]
        high_idx = cutoff_values[i + 1]

        label_mask = (label >= low_idx) & (label < high_idx)
        row_indices = label_mask.nonzero().squeeze()

        if row_indices.numel() == 0:
            continue

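        # Shortlist bucket: these labels are scored directly by the head
        # distribution, so their gather index is the label value itself.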
        if i == 0:
            scatter_output = paddle.scatter_nd(
                row_indices.unsqueeze(1),
                label.masked_select(label_mask),
                gather_inds.shape,
            )
            gather_inds = scatter_output
        else:
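            # Tail bucket: project the selected rows through the two-layer
            # tail projection for this cluster and score the label within the
            # cluster with a log-softmax.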
            relative_label = label[label_mask] - low_idx
            input_subset = input.index_select(row_indices, axis=0)

            cluster_output = paddle.nn.functional.linear(
                x=input_subset, weight=tail_weights[i - 1][0]
            )
            cluster_output = paddle.nn.functional.linear(
                x=cluster_output, weight=tail_weights[i - 1][1]
            )

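            # The head distribution reserves one slot per tail cluster right
            # after the shortlist; rows in this cluster will read that slot
            # from the head log-softmax below.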
||||||||||
cluster_index = cutoffs[0] + i - 1 | ||||||||||
|
||||||||||
gather_inds = paddle.index_fill( | ||||||||||
gather_inds, row_indices, 0, cluster_index | ||||||||||
) | ||||||||||
|
||||||||||
cluster_logprob = paddle.nn.functional.log_softmax( | ||||||||||
cluster_output, axis=1 | ||||||||||
) | ||||||||||
|
||||||||||
local_logprob = paddle.take_along_axis( | ||||||||||
cluster_logprob, relative_label.unsqueeze(1), axis=1 | ||||||||||
) | ||||||||||
scatter_output = paddle.scatter_nd( | ||||||||||
row_indices.unsqueeze(1), local_logprob.squeeze(1), output.shape | ||||||||||
) | ||||||||||
output = ( | ||||||||||
output * (scatter_output == 0).astype('float32') | ||||||||||
+ scatter_output | ||||||||||
) | ||||||||||
|
||||||||||
used_rows += row_indices.numel() | ||||||||||
|
||||||||||
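    # Every row must have been claimed by a bucket; otherwise some label lies
    # outside [0, n_classes - 1].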
    if used_rows != batch_size:
        raise ValueError(
            f"label values should be in [0, n_classes - 1], "
            f"but values in range [{label.min().item()}, {label.max().item()}] "
            "were found. "
        )

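    # Head pass: compute logits over the shortlist plus one slot per tail
    # cluster, take the log-softmax, and add each row's head log-probability
    # at its gather index. For shortlist rows this is the final log p(class);
    # for tail rows it adds log p(cluster) to the already accumulated
    # log p(class | cluster).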
    head_output = paddle.nn.functional.linear(
        x=input, weight=head_weight, bias=head_bias
    )
    head_logprob = paddle.nn.functional.log_softmax(head_output, axis=1)
    output += paddle.take_along_axis(
        head_logprob, gather_inds.unsqueeze(1), axis=1
    ).squeeze()
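    # ``output`` now holds the per-row log-probability of the target class;
    # the loss is its negative mean (negative log likelihood).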
    loss = (-output).mean()

    if not is_batched:
        output = output.squeeze(0)

    return output, loss