From 0b534e62d03db5ef74f77b61837e0561a1fc129a Mon Sep 17 00:00:00 2001
From: nikola-jovisic
Date: Tue, 22 Aug 2023 13:05:13 +0200
Subject: [PATCH] fix binary classification for tf segformer

fix #4 fix binary classification for tensorflow segformer
fix binary classification for tf segformer #2
---
 .../models/segformer/modeling_tf_segformer.py | 24 +++++++++++++++----
 1 file changed, 19 insertions(+), 5 deletions(-)

diff --git a/src/transformers/models/segformer/modeling_tf_segformer.py b/src/transformers/models/segformer/modeling_tf_segformer.py
index 632382f95ed0a7..9464180c2bbfd0 100644
--- a/src/transformers/models/segformer/modeling_tf_segformer.py
+++ b/src/transformers/models/segformer/modeling_tf_segformer.py
@@ -759,7 +759,24 @@ def hf_compute_loss(self, logits, labels):
 
         upsampled_logits = tf.image.resize(logits, size=label_interp_shape, method="bilinear")
         # compute weighted loss
-        loss_fct = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction="none")
+
+        if self.config.num_labels > 1:
+            loss_fct = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction="none")
+            loss = loss_fct(labels, upsampled_logits)
+        elif self.config.num_labels == 1:
+            valid_mask = tf.cast(
+                (labels >= 0) & (labels != self.config.semantic_loss_ignore_index),
+                dtype=tf.float32
+            )
+            loss_fct = tf.keras.losses.BinaryCrossentropy(
+                from_logits=True,
+                reduction=tf.keras.losses.Reduction.NONE
+            )
+            loss = loss_fct(labels, upsampled_logits[..., 0])  # channel dimension is last in TF (NHWC)
+            loss = tf.reduce_mean(loss * valid_mask)
+        else:
+            raise ValueError(f"Number of labels should be > 0: {self.config.num_labels}")
+
 
         def masked_loss(real, pred):
             unmasked_loss = loss_fct(real, pred)
@@ -829,10 +846,7 @@ def call(
 
         loss = None
         if labels is not None:
-            if not self.config.num_labels > 1:
-                raise ValueError("The number of labels should be greater than one")
-            else:
-                loss = self.hf_compute_loss(logits=logits, labels=labels)
+            loss = self.hf_compute_loss(logits=logits, labels=labels)
 
         # make logits of shape (batch_size, num_labels, height, width) to
         # keep them consistent across APIs