
Why is last token dropped in loglikelihood computation? Gives different result than when calculating loss. #942

Closed
sorenmulli opened this issue Oct 23, 2023 · 4 comments

Comments

@sorenmulli

Question

In

(context_enc + continuation_enc)[-(self.max_length + 1) :][:-1],

(and the corresponding line in the refactor), the last input token is dropped before the model call.
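For concreteness, here is a small sketch (not harness code) of what that slicing does, with max_length = 8 chosen just for illustration:

context_enc = [0, 1, 2, 3]
continuation_enc = [4, 5, 6, 7, 8, 9]
max_length = 8

# Keep at most max_length + 1 trailing tokens, then drop the final token.
inp = (context_enc + continuation_enc)[-(max_length + 1):][:-1]
print(inp)  # [1, 2, 3, 4, 5, 6, 7, 8]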

This is motivated by this diagram:

# how this all works:
#          CTX      CONT
# inp    0 1 2 3|4 5 6 7 8 9   <- last token is deleted by inp[:, :-1]
# gpt2    \               \
# logits   1 2 3|4 5 6 7 8 9   <- the ctx half gets tossed out by the
# cont_toks      4 5 6 7 8 9      [:, -len(continuation_enc):, :self.vocab_size] slice

I must admit that I do not understand why this is: does anyone have some pointers as to why removing this token still yields correct probabilities? (Surely the value of the last token matters for the overall likelihood?)

Minimal Example

The computation below shows that I can reproduce the result of _loglikelihood_tokens only if I remove the [:-1]; otherwise there is a discrepancy stemming from the last token:

import torch
import torch.nn
from transformers import AutoModelForCausalLM, AutoTokenizer
from lm_eval.models.huggingface import AutoCausalLM


lm_key = "sshleifer/tiny-gpt2"
context = "we are the"
cont = " koala bears of the world"

model = AutoModelForCausalLM.from_pretrained(lm_key)
model.eval()
tokenizer = AutoTokenizer.from_pretrained(lm_key)


encodings = tokenizer(context, text_target=cont, return_tensors="pt")
input_ids = torch.cat((encodings.input_ids, encodings.labels), dim=1)
target_ids = input_ids.clone()
# Makes context ignored by loss function
target_ids[:, : encodings.input_ids.size(1)] = -100

with torch.no_grad():
    logits = model(input_ids).logits

# Move vocab dimension last as we do classification over these
logits = logits.permute(0, 2, 1)
losses = torch.nn.CrossEntropyLoss(reduction="none")(logits, target_ids)
print(-losses.sum().item())
# Result: -64.88402557373047

lm_eval_model = AutoCausalLM(lm_key, device="cpu")
print(lm_eval_model.loglikelihood([(context, cont)])[0][0])
# Result: -65.07632446289062
# If I remove the `[:-1]` in _loglikelihood_tokens:
# -64.88401794433594

Similar issues

A similar question was asked in #337, where @jon-tow, who had asked it, closed the issue with the message:

Update: I confused position indexing (next-token distribution)

@haileyschoelkopf
Collaborator

haileyschoelkopf commented Oct 24, 2023

Hi! I'll add a further note on this to the comment and the documentation, as this is a frequent question.

The reason behind chopping off the last completion token is that autoregressive LLMs take in tokens up to position N and return a logit distribution for position N+1. Therefore, the logit the model assigns to token N is obtained by feeding in 0 1 2 3 ... (N - 1) and then taking the last logit position; this is the logit for the Nth token.

When we're feeding in

0 1 2 3 | 4 

what we want is the logits predicting 4. To get the logits for 4 conditioned on 0 1 2 3, we must feed in 0 1 2 3 without passing in 4. Then the final logit position is the predicted distribution over tokens at position 4, which is exactly what we wanted! The same applies to multi-token continuations.
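As a minimal sketch of that indexing (illustration only, not harness code), using the same tiny model as in the example above:

import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")
model.eval()

input_ids = torch.tensor([[0, 1, 2, 3]])  # context tokens 0 1 2 3
with torch.no_grad():
    logits = model(input_ids).logits  # shape (1, 4, vocab_size)

# The logits at the last position are the distribution over the *next*
# token, i.e. position 4; token 4 itself never has to be fed in.
log_probs = torch.log_softmax(logits[0, -1], dim=-1)
print(log_probs[4])  # log P(token 4 | 0 1 2 3)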

Leaving open until I update the documentation. If this doesn't make sense, I'm happy to clarify further!

@sasaadi

sasaadi commented Oct 24, 2023

For multi-token continuations, do we only drop the last token? If the input is 0 1 2 3 and the continuation is 4 5 6, do we condition on 0 1 2 3 4 5?
Thanks

@sorenmulli
Author

Thank you very much, @haileyschoelkopf, for the swift reply and the good explanation of the indexing!

I don't know why I thought that calling cross-entropy loss on the logits would magically handle this for me; this shifting is of course also implemented in decoder models' loss computation. For completeness, I have updated my little code snippet so that it gives the same result:

import torch
import torch.nn
from transformers import AutoModelForCausalLM, AutoTokenizer
from lm_eval.models.huggingface import AutoCausalLM


lm_key = "sshleifer/tiny-gpt2"
context = "we are the"
cont = " koala bears of the world"

model = AutoModelForCausalLM.from_pretrained(lm_key)
model.eval()
tokenizer = AutoTokenizer.from_pretrained(lm_key)


encodings = tokenizer(context, text_target=cont, return_tensors="pt")
input_ids = torch.cat((encodings.input_ids, encodings.labels), dim=1)
target_ids = input_ids.clone()
# Makes context ignored by loss function
target_ids[:, : encodings.input_ids.size(1)] = -100


with torch.no_grad():
    logits = model(input_ids).logits

# Move vocab dimension last as we do classification over these
logits = logits.permute(0, 2, 1)
# Task: Next-token-prediction => shift tokens
target_ids = target_ids[:, 1:]
logits = logits[:, :, :-1]
losses = torch.nn.CrossEntropyLoss(reduction="none")(logits, target_ids)
print(-losses.sum().item())
# Result: -65.07633972167969
lm_eval_model = AutoCausalLM(lm_key, device="cpu")
print(lm_eval_model.loglikelihood([(context, cont)])[0][0])
# Result: -65.07632446289062

# Same results - yay!

@haileyschoelkopf
Collaborator

Glad this is helpful!!

For multi-token continuations, do we only drop the last token? If the input is 0 1 2 3 and the continuation is 4 5 6, do we condition on 0 1 2 3 4 5?
Thanks

@sasaadi yes, we would feed 0 1 2 3 4 5 into the model, which then gives us logits of shape (seqlen, vocabsize) = (6, vocabsize). The last sequence position of these logits is the model's prediction for position 6 conditioned on everything up to 5, the second-to-last position gives the prediction for 5 conditioned on 0 1 2 3 4, and so on.

So if the continuation is 4 5 6, we want:

  • last logit position (predicting 6, conditional on 0 1 2 3 4 5)
  • second-to-last (predicting 5, conditional on 0 1 2 3 4)
  • third-to-last (predicting 4, conditional on 0 1 2 3)
    and we don't care about how likely the model would be to generate the input/context.

And the loglikelihood of the completion is the loglikelihood of producing all 3 completion tokens in turn, starting from 0 1 2 3. So to get the probability we'd multiply the probabilities of producing each completion token, or equivalently add the log-probabilities of producing each completion token, assuming we got the previous ones right.
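A small sketch of that computation (illustration only, not how the harness implements it), scoring continuation 4 5 6 given context 0 1 2 3 with the same tiny model:

import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")
model.eval()

context = [0, 1, 2, 3]
continuation = [4, 5, 6]

# Feed everything except the final continuation token.
inp = torch.tensor([context + continuation[:-1]])  # 0 1 2 3 4 5
with torch.no_grad():
    logits = model(inp).logits  # shape (1, 6, vocab_size)

# The last len(continuation) positions predict tokens 4, 5 and 6 in turn.
log_probs = torch.log_softmax(logits[0, -len(continuation):], dim=-1)
cont_toks = torch.tensor(continuation)
print(log_probs[torch.arange(len(cont_toks)), cont_toks].sum())  # log P(4 5 6 | 0 1 2 3)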
