Skip to content

Commit

Permalink
Fix the slow running speed of kl_div when option 'reduction' is set (#…
Browse files Browse the repository at this point in the history
…37283)

* Fix the slow running speed of kl_div when option reduction is set

* fix unittest coverage
  • Loading branch information
LielinJiang authored Nov 18, 2021
1 parent 9990952 commit a6e9ff8
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 3 deletions.
4 changes: 3 additions & 1 deletion python/paddle/fluid/tests/unittests/test_kldiv_loss_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,9 @@ def test_kl_loss_static_api(self):
input = paddle.fluid.data(name='input', shape=[5, 20])
label = paddle.fluid.data(name='label', shape=[5, 20])

pred_loss = paddle.nn.functional.kl_div(input, label)
paddle.nn.functional.kl_div(input, label)
paddle.nn.functional.kl_div(input, label, 'sum')
paddle.nn.functional.kl_div(input, label, 'batchmean')


class TestKLDivLossTypePromotion(unittest.TestCase):
Expand Down
20 changes: 18 additions & 2 deletions python/paddle/nn/functional/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -903,7 +903,15 @@ def kl_div(input, label, reduction='mean', name=None):
label = paddle.cast(label, 'float64')

if paddle.in_dynamic_mode():
out = _C_ops.kldiv_loss(input, label, 'reduction', reduction)
out = _C_ops.kldiv_loss(input, label, 'reduction', 'none')
if reduction == 'mean':
out = paddle.mean(out)
elif reduction == 'sum':
out = paddle.sum(out)
elif reduction == 'batchmean':
if len(input.shape) > 0:
batch_size = input.shape[0]
out = paddle.sum(out) / batch_size
return out

helper = LayerHelper('kl_div', **locals())
Expand All @@ -920,7 +928,15 @@ def kl_div(input, label, reduction='mean', name=None):
inputs={'X': input,
'Target': label},
outputs={'Loss': loss},
attrs={'reduction': reduction})
attrs={'reduction': 'none'})

if reduction == 'mean':
loss = paddle.mean(loss)
elif reduction == 'sum':
loss = paddle.sum(loss)
elif reduction == 'batchmean':
batch_size = paddle.shape(input)[0]
loss = paddle.sum(loss) / batch_size
return loss


Expand Down

0 comments on commit a6e9ff8

Please sign in to comment.