diff --git a/src/transformers/models/vilt/modeling_vilt.py b/src/transformers/models/vilt/modeling_vilt.py index 36c5d38710ef10..364bc15b905fcf 100755 --- a/src/transformers/models/vilt/modeling_vilt.py +++ b/src/transformers/models/vilt/modeling_vilt.py @@ -1260,6 +1260,8 @@ def forward( loss = None if labels is not None: + # move labels to correct device to enable PP + labels = labels.to(logits.device) raise NotImplementedError("Training is not yet supported.") if not return_dict: