
Why do we make the loss computable on the first part of the input sequence (the question)? #27

Open
thuann2cats opened this issue Aug 4, 2024 · 1 comment

Comments

@thuann2cats

Can anyone tell me why, when training a GPT-2 model to solve math problems from the GSM8K dataset, this code makes the loss computable on the first part of the input sequence (which contains the question)? Why should we compute the GPT-2 loss on the question? Shouldn't we compute the loss only on the generated answer? Thanks!

```python
import torch as th


class GSMDataset(th.utils.data.Dataset):
    def __init__(self, tokenizer, examples, loss_on_prefix=True):
        self.examples = examples
        self.qns = [ex["question"] for ex in self.examples]
        self.ans = [ex["answer"] for ex in self.examples]
        # Tokenize questions and answers separately, without padding.
        self.qns = tokenizer(self.qns, padding=False)
        self.ans = tokenizer(self.ans, padding=False)
        # If True, question (prefix) tokens also contribute to the LM loss.
        self.loss_on_prefix = loss_on_prefix
        # Length in tokens of the longest question + answer pair.
        self.max_len = max(
            [
                len(self.qns["input_ids"][i]) + len(self.ans["input_ids"][i])
                for i in range(len(self.examples))
            ]
        )
        print(f"Max tokens: {self.max_len}")
```
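The excerpt stops before `loss_on_prefix` is actually used. A common pattern, and roughly what the flag name suggests, is to concatenate question and answer tokens and build a 0/1 loss mask per position. The sketch below illustrates that assumption in plain Python; `build_example` and `pad_id` are hypothetical names, not taken from the repository.

```python
from typing import List, Tuple


def build_example(qn_ids: List[int], ans_ids: List[int], max_len: int,
                  pad_id: int = 0,
                  loss_on_prefix: bool = True) -> Tuple[List[int], List[int]]:
    """Hypothetical helper: concatenate question and answer ids, pad to
    max_len, and return (tokens, mask) where mask[t] == 1 means position t
    counts toward the loss."""
    pad_len = max_len - len(qn_ids) - len(ans_ids)
    if pad_len < 0:
        raise ValueError("example is longer than max_len")
    tokens = qn_ids + ans_ids + [pad_id] * pad_len
    # Question positions count only when loss_on_prefix is True;
    # answer positions always count; padding never does.
    mask = ([int(loss_on_prefix)] * len(qn_ids)
            + [1] * len(ans_ids)
            + [0] * pad_len)
    return tokens, mask


# Toy ids: 3 question tokens, 2 answer tokens, padded to length 7.
tokens, mask = build_example([5, 6, 7], [8, 9], max_len=7)
_, mask_ans = build_example([5, 6, 7], [8, 9], max_len=7, loss_on_prefix=False)
print(tokens)    # [5, 6, 7, 8, 9, 0, 0]
print(mask)      # [1, 1, 1, 1, 1, 0, 0]
print(mask_ans)  # [0, 0, 0, 1, 1, 0, 0]
```

In a training loop, a mask like this would multiply the per-token losses so that positions marked 0 are ignored.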

@lokashrinav

Correct me if I'm wrong, but I think the code computes the loss on the question so that the model also learns to predict each question token, which may help it understand the context better. In effect, it forces the model to learn the structure and content of the questions more accurately, which can sometimes lead to better answer generation. If our only goal were to generate correct answers, without any concern for how the questions themselves are modeled, then computing the loss exclusively on the answer tokens would be the more direct choice (the model still reads the question through attention either way).
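To make the trade-off concrete: with a per-token mask $m_t$, the loss is the masked average $-\frac{1}{\sum_t m_t}\sum_t m_t \log p(x_t \mid x_{<t})$. The sketch below computes it in plain Python on made-up probabilities (not outputs of any real model); `masked_nll` is a hypothetical helper.

```python
import math
from typing import List


def masked_nll(token_probs: List[float], mask: List[int]) -> float:
    """Mean negative log-likelihood over positions where mask == 1.
    token_probs[t] is the model's probability of the correct token at t."""
    count = sum(mask)
    if count == 0:
        raise ValueError("mask selects no positions")
    total = sum(-math.log(p) * m for p, m in zip(token_probs, mask))
    return total / count


# Toy probabilities: 3 question tokens (hard to predict), 2 answer tokens.
probs = [0.1, 0.2, 0.1, 0.5, 0.8]
loss_all = masked_nll(probs, [1, 1, 1, 1, 1])     # loss_on_prefix=True
loss_answer = masked_nll(probs, [0, 0, 0, 1, 1])  # answer tokens only
print(round(loss_all, 3), round(loss_answer, 3))
```

In this toy case the question tokens are much harder to predict, so including them raises the average loss and directs part of the training signal toward modeling questions rather than answers. Whether that helps or hurts answer accuracy is an empirical question.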
