-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtrain.py
147 lines (135 loc) · 5.23 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
import argparse, json
from transformers import BertForTokenClassification
from rich import print
from ddaugner.predict import predict
from ddaugner.score import score_ner
from ddaugner.train import train_ner_model
from ddaugner.datas.conll import CoNLLDataset
from ddaugner.datas.aug import all_augmenters
from ddaugner.utils import flattened
if __name__ == "__main__":
    # CLI entry point: train a BERT token-classification NER model on CoNLL
    # data with optional data augmentation, evaluate on the test split, and
    # save the trained model plus (optionally) the test metrics.
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument("-en", "--epochs-nb", type=float, default=2.0)
    parser.add_argument(
        "-den",
        "--dynamic-epochs-nb",
        action="store_true",
        help=(
            "If specified, the number of epochs (which is a float) will be "
            "modified according to the number of augmented examples added to "
            "the model, so that the model sees as many examples in total as "
            "it would have seen without augmentation."
        ),
    )
    parser.add_argument("-bz", "--batch-size", type=int, default=4)
    parser.add_argument("-bm", "--batch-mode", action="store_true")
    parser.add_argument("-mp", "--model-path", type=str, default="./ner_model")
    parser.add_argument(
        "-cw",
        "--custom-weights",
        type=json.loads,
        help="Custom class weights, as a json dictionary. Example : python train.py -cw '{\"B-PER\": 2}'",
        default=None,
    )
    parser.add_argument("-cs", "--context-size", type=int, default=0)
    parser.add_argument(
        "-cp",
        "--conll-percentage",
        type=float,
        default=1.0,
        help="percentage of conll to use for training, between 0 and 1 (who would use 0 though ?)",
    )
    parser.add_argument(
        "-koc",
        "--keep-only-classes",
        nargs="*",
        default=None,
        help="A list of classes to keep at training time, separated by spaces. Example : 'PER ORG'",
    )
    parser.add_argument("-tmp", "--test-metrics-path", type=str, default=None)
    parser.add_argument(
        "-das",
        "--data-aug-strategies",
        default="{}",
        help=f"a json dictionary mapping a NER class to a list of replacement strategies (available strategies : {list(all_augmenters.keys())})",
    )
    parser.add_argument(
        "-daf",
        "--data-aug-frequencies",
        default="{}",
        help="a json dictionary mapping a NER class to a list of frequencies for the given replacement strategies, in order",
    )
    parser.add_argument(
        "-dam",
        "--data-aug-method",
        default="standard",
        type=str,
        help="augmentation method. One of 'standard', 'replace' or 'balance_upsample'",
    )
    args = parser.parse_args()

    print("running with config : ")
    print(vars(args))

    # * augmentation parsing
    # Instantiate one augmenter per requested strategy, per NER class.
    # NOTE: the original check was `all([strategy in all_augmenters.keys()]
    # for strategy in strategies)`, which iterates over non-empty one-element
    # lists and is therefore always true — unknown strategies slipped through
    # and crashed later with an opaque KeyError. Validate explicitly instead.
    data_aug_strategies = json.loads(args.data_aug_strategies)
    augmenters = {}
    for ner_class, strategies in data_aug_strategies.items():
        unknown = [s for s in strategies if s not in all_augmenters]
        if unknown:
            raise ValueError(
                f"unknown augmentation strategies {unknown} for class "
                f"{ner_class} (available strategies : {list(all_augmenters.keys())})"
            )
        augmenters[ner_class] = [all_augmenters[s]() for s in strategies]

    # Each class must declare exactly one frequency per configured strategy.
    data_aug_frequencies = json.loads(args.data_aug_frequencies)
    for ner_class, frequencies in data_aug_frequencies.items():
        if ner_class not in augmenters:
            raise ValueError(
                f"frequencies given for class {ner_class}, but no augmentation "
                "strategies were configured for it"
            )
        if len(frequencies) != len(augmenters[ner_class]):
            raise ValueError(
                f"class {ner_class}: got {len(frequencies)} frequencies for "
                f"{len(augmenters[ner_class])} strategies"
            )

    # * dataset loading
    # ** train dataset
    train = CoNLLDataset.train_dataset(
        augmenters,
        data_aug_frequencies,
        context_size=args.context_size,
        usage_percentage=args.conll_percentage,
        keep_only_classes=args.keep_only_classes,
        aug_method=args.data_aug_method,
    )

    # *** dynamic epochs nb
    # Scale the epochs count down by the augmentation growth factor so the
    # model sees roughly the same total number of examples as without
    # augmentation.
    if args.dynamic_epochs_nb:
        args.epochs_nb = args.epochs_nb / (
            train.augmented_sents_nb / train.original_sents_nb
        )

    # ** valid dataset (no augmentation)
    valid = CoNLLDataset.valid_dataset(
        {}, {}, context_size=args.context_size, keep_only_classes=args.keep_only_classes
    )

    # ** test dataset (no augmentation)
    test = CoNLLDataset.test_dataset(
        {}, {}, context_size=args.context_size, keep_only_classes=args.keep_only_classes
    )
    assert train.tags_nb == test.tags_nb

    # * model loading
    model = BertForTokenClassification.from_pretrained(
        "bert-base-cased",
        num_labels=train.tags_nb,
        label2id=train.tag_to_id,
        id2label={v: k for k, v in train.tag_to_id.items()},
    )

    # Optional per-class loss weights: default every tag to 1.0, then apply
    # the user-supplied overrides.  `weights` stays None when no custom
    # weights were requested.
    weights = None
    if args.custom_weights:
        weights = [1.0 for _ in train.tags]
        for tag, weight in args.custom_weights.items():
            weights[train.tag_to_id[tag]] = weight

    # * training
    model = train_ner_model(
        model,
        train,
        valid,
        epochs_nb=args.epochs_nb,
        batch_size=args.batch_size,
        quiet=args.batch_mode,
        custom_weights=weights,  # type: ignore
    )

    # * test inference
    # Score PER only: MISC/ORG/LOC are ignored when computing NER metrics.
    predictions = predict(model, test, args.batch_size, quiet=args.batch_mode)
    precision, recall, f1 = score_ner(
        test.sents, predictions, ignored_classes={"MISC", "ORG", "LOC"}
    )
    metrics_dict = {"precision": precision, "recall": recall, "f1": f1}
    print("test metrics : ")
    print(metrics_dict)
    if args.test_metrics_path is not None:
        with open(args.test_metrics_path, "w") as f:
            json.dump(metrics_dict, f, indent=4)

    # * save model
    model.save_pretrained(args.model_path)  # type: ignore