diff --git a/paddleseg/models/losses/cross_entropy_loss.py b/paddleseg/models/losses/cross_entropy_loss.py index 6958a35fea..3e04e4769d 100644 --- a/paddleseg/models/losses/cross_entropy_loss.py +++ b/paddleseg/models/losses/cross_entropy_loss.py @@ -18,6 +18,8 @@ from paddleseg.cvlibs import manager +_IS_NPU = "npu" in paddle.get_device() + @manager.LOSSES.add_component class CrossEntropyLoss(nn.Layer): @@ -81,11 +83,20 @@ def forward(self, logit, label, semantic_weights=None): logit = paddle.transpose(logit, [0, 2, 3, 1]) label = label.astype('int64') - loss = F.cross_entropy(logit, - label, - ignore_index=self.ignore_index, - reduction='none', - weight=self.weight) + if _IS_NPU: + logit = logit.transpose([0, 3, 1, 2]) + logit = F.log_softmax(logit, axis=1) + loss = F.nll_loss(logit, + label, + weight=self.weight, + ignore_index=self.ignore_index, + reduction='none') + else: + loss = F.cross_entropy(logit, + label, + ignore_index=self.ignore_index, + reduction='none', + weight=self.weight) return self._post_process_loss(logit, label, semantic_weights, loss)