import argparse
from collections import defaultdict
from typing import Optional, Dict
import numpy as np
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
import torch
from transformers import TrainingArguments, EvalPrediction


class HFDataset(torch.utils.data.Dataset):
    """Dataset for using HuggingFace Transformers."""

    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        item["labels"] = torch.tensor(self.labels[idx])
        return item

    def __len__(self):
        return len(self.labels)
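
# A minimal usage sketch (illustrative only; assumes a HuggingFace tokenizer
# and a parallel list of integer labels, neither of which is defined here):
#
#     from transformers import AutoTokenizer
#
#     tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
#     encodings = tokenizer(["an example sentence"], truncation=True, padding=True)
#     dataset = HFDataset(encodings, labels=[0])
#     item = dataset[0]  # dict of tensors: input_ids, attention_mask, ..., labels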


def compute_metrics(
    pred: EvalPrediction, idx_to_docid: Optional[Dict[int, int]] = None
):
    labels = pred.label_ids
    # The model output is sometimes a tuple; in that case the logits come first.
    if isinstance(pred.predictions, tuple):
        predictions = pred.predictions[0]
    else:
        predictions = pred.predictions
    preds = predictions.argmax(-1)
    if idx_to_docid is not None:
        # Majority voting: take the most common prediction per document.
        assert len(idx_to_docid) == len(preds), f"{len(idx_to_docid)} vs {len(preds)}"
        docid_to_preds = defaultdict(list)
        docid_to_label = dict()
        for idx, (p, l) in enumerate(zip(preds, labels)):
            docid = idx_to_docid[idx]
            docid_to_preds[docid].append(p)
            docid_to_label[docid] = l
        preds_new = []
        for docid, doc_preds in docid_to_preds.items():
            # Take the majority prediction. This assumes binary (0/1) labels:
            # the mean of the segment predictions is the fraction of 1s.
            perc = sum(doc_preds) / len(doc_preds)
            preds_new.append(1 if perc >= 0.5 else 0)
        preds = np.array(preds_new)
        labels = np.array(list(docid_to_label.values()))
    precision, recall, f1, _ = precision_recall_fscore_support(
        labels, preds, average="macro"
    )
    acc = accuracy_score(labels, preds)
    return {"accuracy": acc, "f1": f1, "precision": precision, "recall": recall}
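
# An illustrative sketch of how compute_metrics plugs into a Trainer; the
# names `model`, `train_dataset`, `dev_dataset` and `idx_to_docid` are
# assumptions, not part of this module:
#
#     from functools import partial
#     from transformers import Trainer
#
#     trainer = Trainer(
#         model=model,
#         args=get_training_arguments(args),
#         train_dataset=train_dataset,
#         eval_dataset=dev_dataset,
#         compute_metrics=partial(compute_metrics, idx_to_docid=idx_to_docid),
#     )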


def get_training_arguments(args):
    """Build the TrainingArguments for the Trainer. Many more options exist
    than the ones set here, see:
    https://github.com/huggingface/transformers/blob/master/src/transformers/training_args.py#L72"""
    return TrainingArguments(
        output_dir=args.output_dir,
        # A single CLI flag drives the evaluation/logging/save strategies.
        # Keeping them identical also satisfies load_best_model_at_end, which
        # requires the evaluation and save strategies to match.
        evaluation_strategy=args.strategy,
        eval_steps=args.eval_steps,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size * 2,
        learning_rate=args.learning_rate,
        weight_decay=args.weight_decay,
        max_grad_norm=args.max_grad_norm,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        fp16=args.use_fp16,
        num_train_epochs=args.num_epochs,
        warmup_steps=args.warmup_steps,
        logging_strategy=args.strategy,
        logging_steps=args.logging_steps,
        save_strategy=args.strategy,
        save_steps=args.save_steps,
        seed=args.seed,
        load_best_model_at_end=True,
        label_smoothing_factor=args.label_smoothing,
        log_level="debug",
        metric_for_best_model="accuracy",
        save_total_limit=2,
    )
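
# Note: --early_stopping_patience (parsed below) is not a TrainingArguments
# field. A sketch of wiring it up via a Trainer callback (assumes a `trainer`
# object has already been constructed):
#
#     from transformers import EarlyStoppingCallback
#
#     trainer.add_callback(
#         EarlyStoppingCallback(early_stopping_patience=args.early_stopping_patience)
#     )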


def parse_args_hf():
    """
    Parse CLI arguments for the script and return them.

    :return: Namespace of parsed CLI arguments.
    """
    parser = argparse.ArgumentParser(
        description="Arguments for running the classifier."
    )
    parser.add_argument(
        "--root_dir",
        type=str,
        default="./experiments/8",
        help="Location of the root directory. By default, this is "
        "the data from WMT08-19, without Translationese.",
    )
    parser.add_argument(
        "--load_model",
        type=str,
        help="Initialize training from the model specified at this path location.",
    )
    parser.add_argument(
        "--arch",
        type=str,
        help="HuggingFace transformer architecture to use, e.g. `bert-base-cased`",
    )
    parser.add_argument("-lr", "--learning_rate", type=float, default=1e-5)
    parser.add_argument(
        "-wd", "--weight_decay", default=0, type=float, help="Weight decay"
    )
    parser.add_argument(
        "-mgn", "--max_grad_norm", default=1, type=float, help="Max grad norm"
    )
    # parser.add_argument("-wr", "--warmup_ratio", default=0.1, type=float,
    #                     help="Ratio of total training steps used for a linear warmup "
    #                          "from 0 to learning_rate.")
    parser.add_argument(
        "-wr",
        "--warmup_steps",
        default=200,
        type=int,
        help="Number of steps used for a linear warmup from 0 to learning_rate",
    )
    parser.add_argument(
        "-ls",
        "--label_smoothing",
        default=0.0,
        type=float,
        help="Label smoothing factor, between 0 and 1",
    )
    parser.add_argument(
        "-dr",
        "--dropout",
        default=0.1,
        type=float,
        help="Dropout applied to the classifier layer",
    )
    parser.add_argument(
        "-str",
        "--strategy",
        type=str,
        choices=["no", "steps", "epoch"],
        default="steps",
        help="Strategy for evaluating/saving/logging",
    )
    parser.add_argument(
        "--eval_steps",
        type=int,
        default=200,
        help="Number of update steps between two evaluations",
    )
    parser.add_argument(
        "--logging_steps",
        type=int,
        default=200,
        help="Number of update steps between two logs",
    )
    parser.add_argument(
        "--save_steps",
        type=int,
        default=200,
        help="Number of update steps between two checkpoint saves",
    )
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--num_epochs", type=int, default=10, help="Number of epochs")
    parser.add_argument(
        "--early_stopping_patience", type=int, default=3, help="Early stopping patience"
    )
    parser.add_argument(
        "--use_fp16", action="store_true", help="Use mixed 16-bit precision"
    )
    parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
    # int(1e30) effectively means "no maximum length"; argparse does not apply
    # `type` to defaults, so the explicit cast keeps the default an int.
    parser.add_argument("--max_length", type=int, default=int(1e30))
    parser.add_argument(
        "--load_sentence_pairs",
        action="store_true",
        help="Set this flag to classify HT vs. MT for "
        "source/translation pairs, rather than just "
        "translations.",
    )
    parser.add_argument(
        "--use_google_data",
        action="store_true",
        help="Use Google Translate data instead of DeepL data for train/dev/test.",
    )
    parser.add_argument(
        "--use_normalized_data",
        action="store_true",
        help="Use translations that have been post-processed by applying "
        "a Moses normalization script to them. Right now this only works "
        "for monolingual sentences.",
    )
    parser.add_argument(
        "--use_majority_classification",
        action="store_true",
        help="Make predictions by predicting each segment in a "
        "document and taking the majority prediction. This is "
        "only used for evaluating an already trained "
        "sentence-level model on documents.",
    )
    parser.add_argument(
        "--test",
        type=str,
        choices=["deepl", "google", "wmt1", "wmt2", "wmt3", "wmt4"],
        help="Test a classifier on one of the test sets. For WMT "
        "submissions there are 4 options, originating from the "
        "WMT 19 test set, along with their DA scores: "
        "wmt1: Facebook-FAIR (best, 81.6); "
        "wmt2: RWTH-Aachen (2nd best, 81.5); "
        "wmt3: PROMPT-NMT (2nd worst, 71.8); "
        "wmt4: online-X (worst, 69.7)",
    )
    parser.add_argument(
        "--eval", action="store_true", help="Evaluate on dev set using a trained model"
    )
    parser.add_argument(
        "--seed", type=int, default=1, help="Random number generator seed."
    )
    return parser.parse_args()
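

if __name__ == "__main__":
    # Minimal smoke test, an illustrative sketch rather than part of the
    # training pipeline: parse the CLI arguments and print the resulting
    # namespace. Note that get_training_arguments also expects attributes
    # (e.g. args.output_dir) that parse_args_hf does not define, so those
    # must be set elsewhere before building TrainingArguments.
    print(parse_args_hf())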