From 4b37151b0f4244bfb0b73609f9a8d67b33b7a7c0 Mon Sep 17 00:00:00 2001 From: Birdylx <29754889+Birdylx@users.noreply.github.com> Date: Fri, 23 Aug 2024 14:11:11 +0800 Subject: [PATCH] [NPU] replace ce loss with nll loss (#3759) (#3782) --- paddleseg/models/losses/cross_entropy_loss.py | 21 ++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/paddleseg/models/losses/cross_entropy_loss.py b/paddleseg/models/losses/cross_entropy_loss.py index 6958a35fea..3e04e4769d 100644 --- a/paddleseg/models/losses/cross_entropy_loss.py +++ b/paddleseg/models/losses/cross_entropy_loss.py @@ -18,6 +18,8 @@ from paddleseg.cvlibs import manager +_IS_NPU = "npu" in paddle.get_device() + @manager.LOSSES.add_component class CrossEntropyLoss(nn.Layer): @@ -81,11 +83,20 @@ def forward(self, logit, label, semantic_weights=None): logit = paddle.transpose(logit, [0, 2, 3, 1]) label = label.astype('int64') - loss = F.cross_entropy(logit, - label, - ignore_index=self.ignore_index, - reduction='none', - weight=self.weight) + if _IS_NPU: + logit = logit.transpose([0, 3, 1, 2]) + logit = F.log_softmax(logit, axis=1) + loss = F.nll_loss(logit, + label, + weight=self.weight, + ignore_index=self.ignore_index, + reduction='none') + else: + loss = F.cross_entropy(logit, + label, + ignore_index=self.ignore_index, + reduction='none', + weight=self.weight) return self._post_process_loss(logit, label, semantic_weights, loss)