Skip to content

Commit

Permalink
Revert "fix huggingface#4"
Browse files Browse the repository at this point in the history
This reverts commit 0b534e6.
  • Loading branch information
nikolaJovisic committed Aug 22, 2023
1 parent 98b0663 commit 5b097ec
Showing 1 changed file with 5 additions and 19 deletions.
24 changes: 5 additions & 19 deletions src/transformers/models/segformer/modeling_tf_segformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -759,24 +759,7 @@ def hf_compute_loss(self, logits, labels):

upsampled_logits = tf.image.resize(logits, size=label_interp_shape, method="bilinear")
# compute weighted loss

if self.config.num_labels > 1:
loss_fct = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction="none")
loss = loss_fct(upsampled_logits, labels)
elif self.config.num_labels == 1:
valid_mask = tf.cast(
(labels >= 0) & (labels != self.config.semantic_loss_ignore_index),
dtype=tf.float32
)
loss_fct = tf.keras.losses.BinaryCrossentropy(
from_logits=True,
reduction=tf.keras.losses.Reduction.NONE
)
loss = loss_fct(labels, upsampled_logits[:, 0]) # Assuming channel dimension is last
loss = tf.reduce_mean(loss * valid_mask)
else:
raise ValueError(f"Number of labels should be >=0: {self.config.num_labels}")

loss_fct = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction="none")

def masked_loss(real, pred):
unmasked_loss = loss_fct(real, pred)
Expand Down Expand Up @@ -846,7 +829,10 @@ def call(

loss = None
if labels is not None:
loss = self.hf_compute_loss(logits=logits, labels=labels)
if not self.config.num_labels > 1:
raise ValueError("The number of labels should be greater than one")
else:
loss = self.hf_compute_loss(logits=logits, labels=labels)

# make logits of shape (batch_size, num_labels, height, width) to
# keep them consistent across APIs
Expand Down

0 comments on commit 5b097ec

Please sign in to comment.