Skip to content

Commit

Permalink
Update data_collator.py
Browse files Browse the repository at this point in the history
  • Loading branch information
Rhui Dih Lee authored and Rhui Dih Lee committed Jun 26, 2024
1 parent 3ec5fe7 commit 7d45ef1
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion src/transformers/data/data_collator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1624,12 +1624,16 @@ def __init__(self, return_position_ids=True):
def __call__(self, features, return_tensors=None):
if return_tensors is None:
return_tensors = self.return_tensors
is_labels_provided = "labels" in features[0]
ret = dict(input_ids=[], labels=[])
if self.return_position_ids:
ret.update(dict(position_ids=[]))
for idx in range(0,len(features)):
ret["input_ids"] += features[idx]["input_ids"]
ret["labels"] += [-100] + features[idx]["labels"][1:]
if is_labels_provided:
ret["labels"] += [-100] + features[idx]["labels"][1:]
else:
ret["labels"] += [-100] + features[idx]["input_ids"][1:]
if self.return_position_ids:
ret["position_ids"] += list(range(len(features[idx]["input_ids"])))
return default_data_collator([ret], return_tensors)

0 comments on commit 7d45ef1

Please sign in to comment.