Skip to content

Commit 77aa35e

Browse files
authored
Replace image classification loss functions with self.loss_function (#40764)
1 parent 797859c commit 77aa35e

37 files changed, with 50 additions and 762 deletions

src/transformers/models/beit/modeling_beit.py

Lines changed: 3 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -23,7 +23,7 @@
2323
import torch
2424
import torch.utils.checkpoint
2525
from torch import Tensor, nn
26-
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
26+
from torch.nn import CrossEntropyLoss
2727

2828
from ...activations import ACT2FN
2929
from ...modeling_layers import GradientCheckpointingLayer
@@ -1020,26 +1020,8 @@ def forward(
10201020

10211021
loss = None
10221022
if labels is not None:
1023-
if self.config.problem_type is None:
1024-
if self.num_labels == 1:
1025-
self.config.problem_type = "regression"
1026-
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1027-
self.config.problem_type = "single_label_classification"
1028-
else:
1029-
self.config.problem_type = "multi_label_classification"
1030-
1031-
if self.config.problem_type == "regression":
1032-
loss_fct = MSELoss()
1033-
if self.num_labels == 1:
1034-
loss = loss_fct(logits.squeeze(), labels.squeeze())
1035-
else:
1036-
loss = loss_fct(logits, labels)
1037-
elif self.config.problem_type == "single_label_classification":
1038-
loss_fct = CrossEntropyLoss()
1039-
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1040-
elif self.config.problem_type == "multi_label_classification":
1041-
loss_fct = BCEWithLogitsLoss()
1042-
loss = loss_fct(logits, labels)
1023+
loss = self.loss_function(labels, logits, self.config)
1024+
10431025
if not return_dict:
10441026
output = (logits,) + outputs[2:]
10451027
return ((loss,) + output) if loss is not None else output

src/transformers/models/bit/modeling_bit.py

Lines changed: 1 addition & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -22,7 +22,6 @@
2222
import torch
2323
import torch.utils.checkpoint
2424
from torch import Tensor, nn
25-
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
2625

2726
from ...activations import ACT2FN
2827
from ...modeling_outputs import (
@@ -744,25 +743,7 @@ def forward(
744743
loss = None
745744

746745
if labels is not None:
747-
if self.config.problem_type is None:
748-
if self.num_labels == 1:
749-
self.config.problem_type = "regression"
750-
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
751-
self.config.problem_type = "single_label_classification"
752-
else:
753-
self.config.problem_type = "multi_label_classification"
754-
if self.config.problem_type == "regression":
755-
loss_fct = MSELoss()
756-
if self.num_labels == 1:
757-
loss = loss_fct(logits.squeeze(), labels.squeeze())
758-
else:
759-
loss = loss_fct(logits, labels)
760-
elif self.config.problem_type == "single_label_classification":
761-
loss_fct = CrossEntropyLoss()
762-
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
763-
elif self.config.problem_type == "multi_label_classification":
764-
loss_fct = BCEWithLogitsLoss()
765-
loss = loss_fct(logits, labels)
746+
loss = self.loss_function(labels, logits, self.config)
766747

767748
if not return_dict:
768749
output = (logits,) + outputs[2:]

src/transformers/models/clip/modeling_clip.py

Lines changed: 1 addition & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -19,7 +19,6 @@
1919

2020
import torch
2121
from torch import nn
22-
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
2322

2423
from ...activations import ACT2FN
2524
from ...modeling_attn_mask_utils import _create_4d_causal_attention_mask, _prepare_4d_attention_mask
@@ -1220,28 +1219,7 @@ def forward(
12201219

12211220
loss = None
12221221
if labels is not None:
1223-
# move labels to correct device to enable model parallelism
1224-
labels = labels.to(logits.device)
1225-
if self.config.problem_type is None:
1226-
if self.num_labels == 1:
1227-
self.config.problem_type = "regression"
1228-
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1229-
self.config.problem_type = "single_label_classification"
1230-
else:
1231-
self.config.problem_type = "multi_label_classification"
1232-
1233-
if self.config.problem_type == "regression":
1234-
loss_fct = MSELoss()
1235-
if self.num_labels == 1:
1236-
loss = loss_fct(logits.squeeze(), labels.squeeze())
1237-
else:
1238-
loss = loss_fct(logits, labels)
1239-
elif self.config.problem_type == "single_label_classification":
1240-
loss_fct = CrossEntropyLoss()
1241-
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1242-
elif self.config.problem_type == "multi_label_classification":
1243-
loss_fct = BCEWithLogitsLoss()
1244-
loss = loss_fct(logits, labels)
1222+
loss = self.loss_function(labels, logits, self.config)
12451223

12461224
return ImageClassifierOutput(
12471225
loss=loss,

src/transformers/models/data2vec/modeling_data2vec_vision.py

Lines changed: 3 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -23,7 +23,7 @@
2323
import torch
2424
import torch.utils.checkpoint
2525
from torch import nn
26-
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
26+
from torch.nn import CrossEntropyLoss
2727

2828
from ...activations import ACT2FN
2929
from ...modeling_layers import GradientCheckpointingLayer
@@ -935,26 +935,8 @@ def forward(
935935

936936
loss = None
937937
if labels is not None:
938-
if self.config.problem_type is None:
939-
if self.num_labels == 1:
940-
self.config.problem_type = "regression"
941-
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
942-
self.config.problem_type = "single_label_classification"
943-
else:
944-
self.config.problem_type = "multi_label_classification"
945-
946-
if self.config.problem_type == "regression":
947-
loss_fct = MSELoss()
948-
if self.num_labels == 1:
949-
loss = loss_fct(logits.squeeze(), labels.squeeze())
950-
else:
951-
loss = loss_fct(logits, labels)
952-
elif self.config.problem_type == "single_label_classification":
953-
loss_fct = CrossEntropyLoss()
954-
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
955-
elif self.config.problem_type == "multi_label_classification":
956-
loss_fct = BCEWithLogitsLoss()
957-
loss = loss_fct(logits, labels)
938+
loss = self.loss_function(labels, logits, self.config)
939+
958940
if not return_dict:
959941
output = (logits,) + outputs[2:]
960942
return ((loss,) + output) if loss is not None else output

src/transformers/models/deprecated/efficientformer/modeling_efficientformer.py

Lines changed: 1 addition & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -21,7 +21,6 @@
2121
import torch
2222
import torch.utils.checkpoint
2323
from torch import nn
24-
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
2524

2625
from ....activations import ACT2FN
2726
from ....modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput
@@ -660,26 +659,7 @@ def forward(
660659

661660
loss = None
662661
if labels is not None:
663-
if self.config.problem_type is None:
664-
if self.num_labels == 1:
665-
self.config.problem_type = "regression"
666-
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
667-
self.config.problem_type = "single_label_classification"
668-
else:
669-
self.config.problem_type = "multi_label_classification"
670-
671-
if self.config.problem_type == "regression":
672-
loss_fct = MSELoss()
673-
if self.num_labels == 1:
674-
loss = loss_fct(logits.squeeze(), labels.squeeze())
675-
else:
676-
loss = loss_fct(logits, labels)
677-
elif self.config.problem_type == "single_label_classification":
678-
loss_fct = CrossEntropyLoss()
679-
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
680-
elif self.config.problem_type == "multi_label_classification":
681-
loss_fct = BCEWithLogitsLoss()
682-
loss = loss_fct(logits, labels)
662+
loss = self.loss_function(labels, logits, self.config)
683663

684664
if not return_dict:
685665
output = (logits,) + outputs[1:]

src/transformers/models/deprecated/nat/modeling_nat.py

Lines changed: 1 addition & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -21,7 +21,6 @@
2121
import torch
2222
import torch.utils.checkpoint
2323
from torch import nn
24-
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
2524

2625
from ....activations import ACT2FN
2726
from ....modeling_outputs import BackboneOutput
@@ -810,26 +809,7 @@ def forward(
810809

811810
loss = None
812811
if labels is not None:
813-
if self.config.problem_type is None:
814-
if self.num_labels == 1:
815-
self.config.problem_type = "regression"
816-
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
817-
self.config.problem_type = "single_label_classification"
818-
else:
819-
self.config.problem_type = "multi_label_classification"
820-
821-
if self.config.problem_type == "regression":
822-
loss_fct = MSELoss()
823-
if self.num_labels == 1:
824-
loss = loss_fct(logits.squeeze(), labels.squeeze())
825-
else:
826-
loss = loss_fct(logits, labels)
827-
elif self.config.problem_type == "single_label_classification":
828-
loss_fct = CrossEntropyLoss()
829-
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
830-
elif self.config.problem_type == "multi_label_classification":
831-
loss_fct = BCEWithLogitsLoss()
832-
loss = loss_fct(logits, labels)
812+
loss = self.loss_function(labels, logits, self.config)
833813

834814
if not return_dict:
835815
output = (logits,) + outputs[2:]

src/transformers/models/deprecated/van/modeling_van.py

Lines changed: 1 addition & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -21,7 +21,6 @@
2121
import torch
2222
import torch.utils.checkpoint
2323
from torch import nn
24-
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
2524

2625
from ....activations import ACT2FN
2726
from ....modeling_outputs import (
@@ -510,26 +509,7 @@ def forward(
510509

511510
loss = None
512511
if labels is not None:
513-
if self.config.problem_type is None:
514-
if self.config.num_labels == 1:
515-
self.config.problem_type = "regression"
516-
elif self.config.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
517-
self.config.problem_type = "single_label_classification"
518-
else:
519-
self.config.problem_type = "multi_label_classification"
520-
521-
if self.config.problem_type == "regression":
522-
loss_fct = MSELoss()
523-
if self.config.num_labels == 1:
524-
loss = loss_fct(logits.squeeze(), labels.squeeze())
525-
else:
526-
loss = loss_fct(logits, labels)
527-
elif self.config.problem_type == "single_label_classification":
528-
loss_fct = CrossEntropyLoss()
529-
loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
530-
elif self.config.problem_type == "multi_label_classification":
531-
loss_fct = BCEWithLogitsLoss()
532-
loss = loss_fct(logits, labels)
512+
loss = self.loss_function(labels, logits, self.config)
533513

534514
if not return_dict:
535515
output = (logits,) + outputs[2:]

src/transformers/models/deprecated/vit_hybrid/modeling_vit_hybrid.py

Lines changed: 1 addition & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -21,7 +21,6 @@
2121
import torch
2222
import torch.utils.checkpoint
2323
from torch import nn
24-
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
2524

2625
from ....activations import ACT2FN
2726
from ....modeling_layers import GradientCheckpointingLayer
@@ -725,28 +724,7 @@ def forward(
725724

726725
loss = None
727726
if labels is not None:
728-
# move labels to correct device to enable model parallelism
729-
labels = labels.to(logits.device)
730-
if self.config.problem_type is None:
731-
if self.num_labels == 1:
732-
self.config.problem_type = "regression"
733-
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
734-
self.config.problem_type = "single_label_classification"
735-
else:
736-
self.config.problem_type = "multi_label_classification"
737-
738-
if self.config.problem_type == "regression":
739-
loss_fct = MSELoss()
740-
if self.num_labels == 1:
741-
loss = loss_fct(logits.squeeze(), labels.squeeze())
742-
else:
743-
loss = loss_fct(logits, labels)
744-
elif self.config.problem_type == "single_label_classification":
745-
loss_fct = CrossEntropyLoss()
746-
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
747-
elif self.config.problem_type == "multi_label_classification":
748-
loss_fct = BCEWithLogitsLoss()
749-
loss = loss_fct(logits, labels)
727+
loss = self.loss_function(labels, logits, self.config)
750728

751729
if not return_dict:
752730
output = (logits,) + outputs[1:]

src/transformers/models/dinat/modeling_dinat.py

Lines changed: 1 addition & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -21,7 +21,6 @@
2121
import torch
2222
import torch.utils.checkpoint
2323
from torch import nn
24-
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
2524

2625
from ...activations import ACT2FN
2726
from ...modeling_outputs import BackboneOutput
@@ -736,26 +735,7 @@ def forward(
736735

737736
loss = None
738737
if labels is not None:
739-
if self.config.problem_type is None:
740-
if self.num_labels == 1:
741-
self.config.problem_type = "regression"
742-
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
743-
self.config.problem_type = "single_label_classification"
744-
else:
745-
self.config.problem_type = "multi_label_classification"
746-
747-
if self.config.problem_type == "regression":
748-
loss_fct = MSELoss()
749-
if self.num_labels == 1:
750-
loss = loss_fct(logits.squeeze(), labels.squeeze())
751-
else:
752-
loss = loss_fct(logits, labels)
753-
elif self.config.problem_type == "single_label_classification":
754-
loss_fct = CrossEntropyLoss()
755-
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
756-
elif self.config.problem_type == "multi_label_classification":
757-
loss_fct = BCEWithLogitsLoss()
758-
loss = loss_fct(logits, labels)
738+
loss = self.loss_function(labels, logits, self.config)
759739

760740
if not return_dict:
761741
output = (logits,) + outputs[2:]

src/transformers/models/donut/modeling_donut_swin.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1014,7 +1014,7 @@ def forward(
10141014

10151015
loss = None
10161016
if labels is not None:
1017-
loss = self.loss_function(logits=logits, labels=labels, pooled_logits=logits, config=self.config)
1017+
loss = self.loss_function(labels, logits, self.config)
10181018

10191019
if not return_dict:
10201020
output = (logits,) + outputs[2:]

0 commit comments

Comments
 (0)