
Commit cab4141
Fix for batched tokenization.
Davidyz committed Jun 9, 2024
1 parent 6045e0b commit cab4141
Showing 1 changed file with 11 additions and 10 deletions.
21 changes: 11 additions & 10 deletions llm_unlearn_ucl/utils.py
@@ -33,17 +33,18 @@ def create_symbolic_dataloader_from_dataset(
     def preprocess(examples):
         results = {"input_ids": [], "attention_mask": [], "start_locs": []}
 
-        prompt = examples["input"]
-        output = examples["output"]
-        text = f"### Question: {prompt} ### Answer: {output}"
+        for i in range(len(examples["input"])):
+            prompt = examples["input"][i]
+            output = examples["output"][i]
+            text = f"### Question: {prompt} ### Answer: {output}"
 
-        tokenized = tokenizer(text, truncation=True, padding="max_length")
-        results["input_ids"].append(tokenized["input_ids"])
-        results["attention_mask"].append(tokenized["attention_mask"])
-        # Calculate start idx for answer
-        test_text = f"### Question: {prompt} ### Answer: "
-        test_tokenized = tokenizer(test_text, truncation=True, padding="max_length")
-        results["start_locs"].append(len(test_tokenized["input_ids"]) - 1)
+            tokenized = tokenizer(text, truncation=True, padding="max_length")
+            results["input_ids"].append(tokenized["input_ids"])
+            results["attention_mask"].append(tokenized["attention_mask"])
+            # Calculate start idx for answer
+            test_text = f"### Question: {prompt} ### Answer: "
+            test_tokenized = tokenizer(test_text, truncation=True, padding="max_length")
+            results["start_locs"].append(len(test_tokenized["input_ids"]) - 1)
 
         return results

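For context, below is a minimal sketch of how a preprocess function like this is typically applied with Hugging Face datasets. When Dataset.map is called with batched=True, examples["input"] arrives as a list of strings rather than a single string, so each example must be handled in a loop, which is what this commit fixes. The tokenizer choice, toy dataset contents, and remove_columns argument are illustrative assumptions, not taken from the repository.

from datasets import Dataset
from transformers import AutoTokenizer

# Hypothetical tokenizer; the repository's actual model/tokenizer is not shown in this diff.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token by default

def preprocess(examples):
    # Same structure as the fixed version above: iterate over the batch.
    results = {"input_ids": [], "attention_mask": [], "start_locs": []}
    for i in range(len(examples["input"])):
        prompt = examples["input"][i]
        output = examples["output"][i]
        text = f"### Question: {prompt} ### Answer: {output}"
        tokenized = tokenizer(text, truncation=True, padding="max_length")
        results["input_ids"].append(tokenized["input_ids"])
        results["attention_mask"].append(tokenized["attention_mask"])
        # Calculate start idx for answer
        test_text = f"### Question: {prompt} ### Answer: "
        test_tokenized = tokenizer(test_text, truncation=True, padding="max_length")
        results["start_locs"].append(len(test_tokenized["input_ids"]) - 1)
    return results

# Toy dataset with the same column names as the diff above.
dataset = Dataset.from_dict(
    {
        "input": ["What is 2 + 2?", "Name a primary colour."],
        "output": ["4", "Red"],
    }
)

# batched=True passes each column as a list, which is why the per-example loop is needed.
tokenized_dataset = dataset.map(preprocess, batched=True, remove_columns=["input", "output"])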
