Can anyone tell me why, when training a GPT-2 model to solve math problems from the GSM8K dataset, this code computes the loss on the first part of the input sequence (which contains the question)? Why should we compute the GPT-2 loss on the question? Don't we need to compute the loss only on the generated answer? Thanks!
```python
import torch as th


class GSMDataset(th.utils.data.Dataset):
    def __init__(self, tokenizer, examples, loss_on_prefix=True):
        self.examples = examples
        self.qns = [ex["question"] for ex in self.examples]
        self.ans = [ex["answer"] for ex in self.examples]
        self.qns = tokenizer(self.qns, padding=False)
        self.ans = tokenizer(self.ans, padding=False)
        self.loss_on_prefix = loss_on_prefix
        # Longest question + answer pair, in tokens, across the dataset
        self.max_len = max(
            len(self.qns["input_ids"][i]) + len(self.ans["input_ids"][i])
            for i in range(len(self.examples))
        )
        print(f"Max tokens: {self.max_len}")
```
Correct me if I'm wrong, but I think the code calculates the loss on the question so that the model also learns to predict each token of the question correctly, which may help it understand the context better. Basically, it forces the model to learn the structure and content of the question more accurately, which can sometimes lead to better answer generation. If our goal were only to generate a correct answer, with no concern for how the question itself is modeled, then computing the loss exclusively on the answer would be the better choice.
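To make the difference concrete, here is a minimal sketch of both options using the Hugging Face `transformers` convention, where label positions set to -100 are ignored by the cross-entropy loss. The question/answer strings are made-up stand-ins, and this is my own illustration rather than the repo's training code:

```python
import torch as th
from transformers import GPT2LMHeadModel, GPT2TokenizerFast

tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")

question = "Q: Natalia has 48 clips and sells half of them. How many are left? A:"
answer = " 48 / 2 = 24. #### 24"

qn_ids = tokenizer(question)["input_ids"]
ans_ids = tokenizer(answer)["input_ids"]
input_ids = th.tensor([qn_ids + ans_ids])

loss_on_prefix = False  # flip to True to also train on the question tokens
labels = input_ids.clone()
if not loss_on_prefix:
    labels[0, : len(qn_ids)] = -100  # these positions are skipped by the loss

loss = model(input_ids=input_ids, labels=labels).loss
```

With `loss_on_prefix=True` the model is trained as a plain language model over the full question + answer sequence; with `False` it only gets gradient signal from the answer tokens.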