From d04ec99bec8a0b432fc03ed60cea9a1a20ebaf3c Mon Sep 17 00:00:00 2001 From: SUSHMANTH REDDY <73489688+sushmanthreddy@users.noreply.github.com> Date: Sat, 22 Apr 2023 05:31:25 +0530 Subject: [PATCH] vilt_model (#22930) --- src/transformers/models/vilt/modeling_vilt.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/transformers/models/vilt/modeling_vilt.py b/src/transformers/models/vilt/modeling_vilt.py index 36c5d38710ef10..364bc15b905fcf 100755 --- a/src/transformers/models/vilt/modeling_vilt.py +++ b/src/transformers/models/vilt/modeling_vilt.py @@ -1260,6 +1260,8 @@ def forward( loss = None if labels is not None: + # move labels to correct device to enable PP + labels = labels.to(logits.device) raise NotImplementedError("Training is not yet supported.") if not return_dict: