-
Notifications
You must be signed in to change notification settings - Fork 63
/
collator.py
86 lines (71 loc) · 3.07 KB
/
collator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
from dataclasses import dataclass
from typing import Any, Dict, Sequence
import torch
from transformers import DataCollatorForSeq2Seq
@dataclass
class PairwiseDataCollatorWithPadding(DataCollatorForSeq2Seq):
    r"""
    Data collator for pairwise (preference) data.

    Each incoming feature carries a chosen and a rejected variant under
    ``chosen_*`` / ``rejected_*`` keys; this collator flattens them into a
    single batch and lets ``DataCollatorForSeq2Seq`` do the padding.
    """

    def __call__(
            self, features: Sequence[Dict[str,
                                          Any]]) -> Dict[str, 'torch.Tensor']:
        r"""
        Pads batched data to the longest sequence in the batch.

        We generate 2 * n examples where the first n examples represent
        chosen examples and the last n examples represent rejected examples.
        """
        flattened = []
        for prefix in ('chosen', 'rejected'):
            for example in features:
                entry = {
                    'input_ids': example['{}_input_ids'.format(prefix)],
                    'attention_mask':
                    example['{}_attention_mask'.format(prefix)],
                    'labels': example['{}_labels'.format(prefix)],
                }
                # Visual inputs are shared by both variants of a pair, so the
                # same pixel values are duplicated into each flattened entry.
                if 'pixel_values' in example:
                    entry['pixel_values'] = example['pixel_values']
                token_type_key = '{}_token_type_ids'.format(prefix)
                if token_type_key in example:
                    entry['token_type_ids'] = example[token_type_key]
                flattened.append(entry)
        # Parent collator pads input_ids/attention_mask/labels to the
        # longest sequence and converts everything to tensors.
        return super().__call__(flattened)
@dataclass
class KTODataCollatorWithPadding(DataCollatorForSeq2Seq):
    r"""
    Data collator for KTO data.

    Splits each example into its target sequence and its KL reference
    sequence (``kl_*`` keys), pads the two groups independently via the
    parent collator, then merges the padded KL tensors back into the main
    batch under ``kl_``-prefixed keys alongside the per-example KTO tags.
    """

    def __call__(
            self, features: Sequence[Dict[str,
                                          Any]]) -> Dict[str, 'torch.Tensor']:
        targets = []
        references = []
        tags = []
        for example in features:
            target = {
                'input_ids': example['input_ids'],
                'attention_mask': example['attention_mask'],
                'labels': example['labels'],
            }
            reference = {
                'input_ids': example['kl_input_ids'],
                'attention_mask': example['kl_attention_mask'],
                'labels': example['kl_labels'],
            }
            # NOTE(review): pixel values are attached to the target only and
            # never to the KL reference — confirm this asymmetry is intended.
            if 'pixel_values' in example:
                target['pixel_values'] = example['pixel_values']
            if 'token_type_ids' in example:
                target['token_type_ids'] = example['token_type_ids']
                reference['token_type_ids'] = example['kl_token_type_ids']
            targets.append(target)
            references.append(reference)
            tags.append(example['kto_tags'])
        # Pad target and KL groups separately, then fold the KL tensors
        # into the main batch under prefixed keys.
        batch = super().__call__(targets)
        kl_batch = super().__call__(references)
        for name in ('input_ids', 'attention_mask', 'labels'):
            batch['kl_{}'.format(name)] = kl_batch[name]
        if 'token_type_ids' in batch:
            batch['kl_token_type_ids'] = kl_batch['token_type_ids']
        batch['kto_tags'] = torch.tensor(tags)
        return batch