
the accuracy issue of left padding and right padding #29419

Closed
2 tasks
hijkzzz opened this issue Mar 4, 2024 · 2 comments
Comments


hijkzzz commented Mar 4, 2024

System Info

transformers v4.38.2
docker container: nvcr.io/nvidia/pytorch:23.12-py3

Who can help?

@ArthurZucker
@younesbelkada

Information

The outputs of left padding and right padding are inconsistent.

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

# any llama2 model
modelname = "OpenLLMAI/Llama-2-7b-sft-model-ocra-500k"
model = AutoModelForCausalLM.from_pretrained(modelname).cuda()

# left pad
inputs={'input_ids': torch.tensor([[    1,  7251,   727, 29901, 29871],
        [    2,     2,     1, 29871, 29896]]).cuda(), 'attention_mask': torch.tensor([[1, 1, 1, 1, 1],
        [0, 0, 1, 1, 1]]).cuda()}
# right pad
inputs2={'input_ids': torch.tensor([[    1,  7251,   727, 29901, 29871],
        [    1, 29871, 29896,     2,     2]]).cuda(), 'attention_mask': torch.tensor([[1, 1, 1, 1, 1],
        [1, 1, 1, 0, 0]]).cuda()}

# baseline: compare logits at the three valid tokens of sequence 1
# (last 3 positions when left-padded, first 3 when right-padded)
output = model(**inputs)
output2 = model(**inputs2)

output2.logits[1][:3] - output.logits[1][-3:]
tensor([[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0020, -0.0010, -0.0001,  ...,  0.0006,  0.0013,  0.0007],
        [ 0.0025,  0.0040, -0.0005,  ...,  0.0025,  0.0015,  0.0008]],
       device='cuda:0', grad_fn=<SubBackward0>)

# fix: pass explicit position_ids derived from the attention mask
position_ids = inputs['attention_mask'].long().cumsum(-1) - 1
position_ids2 = inputs2['attention_mask'].long().cumsum(-1) - 1

output = model(**inputs, position_ids=position_ids)
output2 = model(**inputs2, position_ids=position_ids2)

output2.logits[1][:3] - output.logits[1][-3:]
tensor([[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        [ 9.3555e-04, -8.1062e-06, -8.5831e-05,  ...,  7.9441e-04,
          5.7936e-04,  4.6229e-04]], device='cuda:0', grad_fn=<SubBackward0>)
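As a sanity check, the `cumsum` trick above can be reproduced with plain NumPy (a sketch; here the padded slots are clamped to 0, whereas the snippet above leaves them at -1 — the assumption is that either choice is harmless because those positions are masked out of attention anyway):

```python
import numpy as np

# Same attention masks as in the reproduction above
left_mask = np.array([[1, 1, 1, 1, 1],
                      [0, 0, 1, 1, 1]])
right_mask = np.array([[1, 1, 1, 1, 1],
                       [1, 1, 1, 0, 0]])

def positions(mask):
    # Running count of real tokens, minus 1, gives 0-based positions
    pos = mask.cumsum(-1) - 1
    # Give padded slots a dummy position; they are masked out of attention
    pos[mask == 0] = 0
    return pos

print(positions(left_mask))   # valid tokens of row 1 get positions 0, 1, 2
print(positions(right_mask))  # valid tokens of row 1 also get 0, 1, 2
```

With explicit positions, the valid tokens of the second sequence receive the same positions (0, 1, 2) regardless of padding side, which is why the fixed version above agrees much more closely.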

Expected behavior

No numerical discrepancy between left-padded and right-padded outputs.


hijkzzz commented Mar 4, 2024

related issue: OpenRLHF/OpenRLHF#217


ArthurZucker commented Mar 4, 2024

This is a duplicate of #25921 and #25420, so I am going to close it; feel free to read this great comment: #25420 (comment)
TL;DR: small numerical differences are expected when you pad inputs.
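For context on why padding alone produces tiny differences: changing the padding changes which elements the GPU kernels reduce over, so sums inside the matmuls are accumulated in a different order, and floating-point addition is not associative. A minimal pure-Python illustration of order-dependent accumulation:

```python
a, b, c = 0.1, 0.2, 0.3

left_grouped = (a + b) + c   # one accumulation order
right_grouped = a + (b + c)  # the same sum, grouped differently

print(left_grouped)   # 0.6000000000000001
print(right_grouped)  # 0.6
print(left_grouped == right_grouped)  # False
```

The same effect, compounded across thousands of float16/float32 accumulations per layer, accounts for differences on the order of 1e-3 in the logits.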
