forked from PaddlePaddle/PaddleNLP
-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
78 lines (68 loc) · 3.04 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
import numpy as np
import paddle
from paddlenlp.utils.log import logger
def read_local_dataset(data_path, data_file=None, is_test=False):
    """
    Load datasets with one JSON example per line, formatted as:
        {"text_a": X, "text_b": X, "question": X, "choices": [A, B], "labels": [0, 1]}

    Args:
        data_path (str): A JSON-lines file, or, when ``data_file`` is given, a
            directory whose entries ending with ``data_file`` are read.
        data_file (str, optional): Filename suffix used to select files inside
            ``data_path``. Defaults to None, meaning ``data_path`` itself is read.
        is_test (bool): If True, examples keep all of their original fields and
            "labels" is optional. Defaults to False, in which case "labels" is
            required and only the standard keys are kept.

    Yields:
        dict: One example per non-blank line. A missing "text_b" is filled
        with "". When labels are converted, "labels" (a single index or a list
        of indices into "choices") becomes a multi-hot float list with one
        entry per choice.

    Examples whose "choices" has fewer than two entries, or whose "text_a" is
    not a string of at least three characters, are skipped. The number of
    skipped examples is logged once all files have been consumed.
    """
    if data_file is not None:
        # os.listdir returns entries in arbitrary order; sort them so the
        # example order is reproducible across runs and machines.
        file_paths = [
            os.path.join(data_path, fname) for fname in sorted(os.listdir(data_path)) if fname.endswith(data_file)
        ]
    else:
        file_paths = [data_path]
    # Keys kept for non-test examples; any other fields are dropped.
    std_keys = ["text_a", "text_b", "question", "choices", "labels"]
    skip_count = 0
    for file_path in file_paths:
        with open(file_path, "r", encoding="utf-8") as fp:
            for line in fp:
                line = line.strip()
                # Blank lines carry no example (json.loads("") would raise),
                # e.g. an extra empty line at the end of a JSON-lines file.
                if not line:
                    continue
                example = json.loads(line)
                # Drop examples with fewer than two choices or with a text_a
                # that is not a string of at least three characters.
                if len(example["choices"]) < 2 or not isinstance(example["text_a"], str) or len(example["text_a"]) < 3:
                    skip_count += 1
                    continue
                if "text_b" not in example:
                    example["text_b"] = ""
                # Labels are required for non-test data; for test data they
                # are converted only when present.
                if not is_test or "labels" in example:
                    labels = example["labels"]
                    if not isinstance(labels, list):
                        labels = [labels]
                    # Multi-hot vector: 1.0 at every labelled choice index.
                    one_hots = np.zeros(len(example["choices"]), dtype="float32")
                    for x in labels:
                        one_hots[x] = 1
                    example["labels"] = one_hots.tolist()
                if is_test:
                    # Test examples keep every original field.
                    yield example
                    continue
                yield {k: example[k] for k in std_keys if k in example}
    logger.warning(f"Skip {skip_count} examples.")
class UTCLoss:
    """Multi-label loss computed from per-choice logits.

    Along the last axis, with score ``s`` and target ``y`` in {0, 1}, each row
    contributes::

        log(1 + sum_{i: y_i = 0} exp(s_i)) + log(1 + sum_{j: y_j = 1} exp(-s_j))

    and the result is the mean over all rows. Entries labelled -100 are
    excluded from both sums. NOTE(review): this has the form of the ZLPR /
    "multi-label categorical cross-entropy" loss; confirm before citing it.
    """

    def __call__(self, logit, label):
        # Let an instance be used directly as a loss function: loss_fn(logit, label).
        return self.forward(logit, label)

    def forward(self, logit, label):
        """Return the scalar (mean) loss for ``logit`` against ``label``.

        Args:
            logit: Score tensor; the last axis indexes choices. Presumably shape
                (..., num_choices) — TODO confirm with callers.
            label: Tensor with the same shape as ``logit``, expected to hold
                1.0 for positive choices, 0.0 for negatives and -100 for
                entries to ignore.
        """
        # Offset large enough that exp() of a shifted entry is effectively 0.
        big = 1e12
        # Negatives keep +s and positives become -s, so one expression serves both sums.
        signed = (1.0 - 2.0 * label) * logit
        # Each branch pushes the entries of the other class down to about -big.
        neg_branch = signed - label * big
        pos_branch = signed - (1.0 - label) * big
        # An extra zero column supplies the "1 +" inside each log(1 + sum exp(.)).
        pad = paddle.zeros_like(signed[..., :1])
        neg_branch, pos_branch, padded_label = (
            paddle.concat([part, pad], axis=-1) for part in (neg_branch, pos_branch, label)
        )
        # Ignored entries are forced to -big in both branches. The padded zero
        # column never matches, so the "1 +" term is always kept.
        ignored = padded_label == -100
        neg_branch[ignored] = -big
        pos_branch[ignored] = -big
        per_row = paddle.logsumexp(neg_branch, axis=-1) + paddle.logsumexp(pos_branch, axis=-1)
        return per_row.mean()