Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
fix binary classification for tensorflow segformer

fix binary classification for tf segformer huggingface#2
  • Loading branch information
nikolaJovisic committed Aug 22, 2023
1 parent 3629190 commit 0b534e6
Showing 1 changed file with 19 additions and 5 deletions.
24 changes: 19 additions & 5 deletions src/transformers/models/segformer/modeling_tf_segformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -759,7 +759,24 @@ def hf_compute_loss(self, logits, labels):

upsampled_logits = tf.image.resize(logits, size=label_interp_shape, method="bilinear")
# compute weighted loss
loss_fct = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction="none")

if self.config.num_labels > 1:
loss_fct = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction="none")
loss = loss_fct(upsampled_logits, labels)
elif self.config.num_labels == 1:
valid_mask = tf.cast(
(labels >= 0) & (labels != self.config.semantic_loss_ignore_index),
dtype=tf.float32
)
loss_fct = tf.keras.losses.BinaryCrossentropy(
from_logits=True,
reduction=tf.keras.losses.Reduction.NONE
)
loss = loss_fct(labels, upsampled_logits[:, 0]) # Assuming channel dimension is last
loss = tf.reduce_mean(loss * valid_mask)
else:
raise ValueError(f"Number of labels should be >=0: {self.config.num_labels}")


def masked_loss(real, pred):
unmasked_loss = loss_fct(real, pred)
Expand Down Expand Up @@ -829,10 +846,7 @@ def call(

loss = None
if labels is not None:
if not self.config.num_labels > 1:
raise ValueError("The number of labels should be greater than one")
else:
loss = self.hf_compute_loss(logits=logits, labels=labels)
loss = self.hf_compute_loss(logits=logits, labels=labels)

# make logits of shape (batch_size, num_labels, height, width) to
# keep them consistent across APIs
Expand Down

0 comments on commit 0b534e6

Please sign in to comment.