Skip to content

Commit 1564189

Browse files
authored
feat(model parallelism): moving the labels to the same device as the logits for gpt2 and bart (#22591)
1 parent e577bd0 commit 1564189

File tree

9 files changed

+17
-0
lines changed

9 files changed

+17
-0
lines changed

src/transformers/models/bart/modeling_bart.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1398,6 +1398,7 @@ def forward(
13981398

13991399
masked_lm_loss = None
14001400
if labels is not None:
1401+
labels = labels.to(lm_logits.device)
14011402
loss_fct = CrossEntropyLoss()
14021403
masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1))
14031404

@@ -1553,6 +1554,7 @@ def forward(
15531554

15541555
loss = None
15551556
if labels is not None:
1557+
labels = labels.to(logits.device)
15561558
if self.config.problem_type is None:
15571559
if self.config.num_labels == 1:
15581560
self.config.problem_type = "regression"
@@ -1896,6 +1898,7 @@ def forward(
18961898

18971899
loss = None
18981900
if labels is not None:
1901+
labels = labels.to(logits.device)
18991902
loss_fct = CrossEntropyLoss()
19001903
loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
19011904

src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -2581,6 +2581,7 @@ def forward(
25812581

25822582
masked_lm_loss = None
25832583
if labels is not None:
2584+
labels = labels.to(lm_logits.device)
25842585
loss_fct = CrossEntropyLoss()
25852586
masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1))
25862587

@@ -2735,6 +2736,7 @@ def forward(
27352736

27362737
loss = None
27372738
if labels is not None:
2739+
labels = labels.to(logits.device)
27382740
if self.config.problem_type is None:
27392741
if self.config.num_labels == 1:
27402742
self.config.problem_type = "regression"

src/transformers/models/blenderbot/modeling_blenderbot.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1596,6 +1596,7 @@ def forward(
15961596

15971597
loss = None
15981598
if labels is not None:
1599+
labels = labels.to(logits.device)
15991600
loss_fct = CrossEntropyLoss()
16001601
loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
16011602

src/transformers/models/blenderbot_small/modeling_blenderbot_small.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1563,6 +1563,7 @@ def forward(
15631563

15641564
loss = None
15651565
if labels is not None:
1566+
labels = labels.to(logits.device)
15661567
loss_fct = CrossEntropyLoss()
15671568
loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
15681569

src/transformers/models/gpt2/modeling_gpt2.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1098,6 +1098,8 @@ def forward(
10981098

10991099
loss = None
11001100
if labels is not None:
1101+
# move labels to correct device to enable model parallelism
1102+
labels = labels.to(lm_logits.device)
11011103
# Shift so that tokens < n predict n
11021104
shift_logits = lm_logits[..., :-1, :].contiguous()
11031105
shift_labels = labels[..., 1:].contiguous()
@@ -1318,6 +1320,7 @@ def forward(
13181320
mc_loss = loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1))
13191321
lm_loss = None
13201322
if labels is not None:
1323+
labels = labels.to(lm_logits.device)
13211324
shift_logits = lm_logits[..., :-1, :].contiguous()
13221325
shift_labels = labels[..., 1:].contiguous()
13231326
loss_fct = CrossEntropyLoss()
@@ -1569,6 +1572,7 @@ def forward(
15691572

15701573
loss = None
15711574
if labels is not None:
1575+
labels = labels.to(logits.device)
15721576
loss_fct = CrossEntropyLoss()
15731577
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
15741578

src/transformers/models/marian/modeling_marian.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1715,6 +1715,7 @@ def forward(
17151715

17161716
loss = None
17171717
if labels is not None:
1718+
labels = labels.to(logits.device)
17181719
loss_fct = CrossEntropyLoss()
17191720
loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
17201721

src/transformers/models/mbart/modeling_mbart.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1528,6 +1528,7 @@ def forward(
15281528

15291529
loss = None
15301530
if labels is not None:
1531+
labels = labels.to(logits.device)
15311532
if self.config.problem_type is None:
15321533
if self.config.num_labels == 1:
15331534
self.config.problem_type = "regression"
@@ -1866,6 +1867,7 @@ def forward(
18661867

18671868
loss = None
18681869
if labels is not None:
1870+
labels = labels.to(logits.device)
18691871
loss_fct = CrossEntropyLoss()
18701872
loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
18711873

src/transformers/models/pegasus/modeling_pegasus.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1694,6 +1694,7 @@ def forward(
16941694

16951695
loss = None
16961696
if labels is not None:
1697+
labels = labels.to(logits.device)
16971698
loss_fct = CrossEntropyLoss()
16981699
loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
16991700

src/transformers/models/plbart/modeling_plbart.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1499,6 +1499,7 @@ def forward(
14991499

15001500
loss = None
15011501
if labels is not None:
1502+
labels = labels.to(logits.device)
15021503
if self.config.problem_type is None:
15031504
if self.config.num_labels == 1:
15041505
self.config.problem_type = "regression"
@@ -1713,6 +1714,7 @@ def forward(
17131714

17141715
loss = None
17151716
if labels is not None:
1717+
labels = labels.to(logits.device)
17161718
loss_fct = CrossEntropyLoss()
17171719
loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
17181720

0 commit comments

Comments (0)