-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtransformer.jsonnet
95 lines (91 loc) · 3.58 KB
/
transformer.jsonnet
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
// AllenNLP training configuration for transformer-based binary relation
// classification on TACRED, using the "sherlock" dataset reader.
//
// Every parameter is a Jsonnet top-level argument and can be overridden from
// the command line (e.g. `jsonnet --tla-code lr=3e-5 transformer.jsonnet`).
//
// Parameters:
//   num_epochs                 -- trainer.num_epochs.
//   batch_size                 -- data_loader.batch_size.
//   lr, adam_epsilon           -- "huggingface_adamw" optimizer lr / eps.
//   weight_decay               -- when > 0, enables parameter groups: weights other
//                                 than LayerNorm weights decay by this amount; biases
//                                 and LayerNorm weights use 0.0. When <= 0, no
//                                 parameter groups are configured (null).
//   warmup_steps               -- warmup for the "linear_with_warmup" scheduler.
//   model_name                 -- pretrained transformer name, shared by the
//                                 tokenizer, token indexer and model.
//   do_lower_case              -- forwarded in tokenizer_kwargs. It also selects
//                                 lowercase vs. uppercase entity-marker tokens.
//   max_length                 -- passed to tokenizer, indexer and model, and used
//                                 as the reader's max_tokens.
//   train_data_path,
//   validation_data_path       -- dataset split paths.
//   negative_label             -- given to the reader as `negative_label_re`
//                                 (presumably a regex) and to the model as
//                                 `ignore_label`.
//   entity_handling            -- feature converter entity handling strategy.
//   fp16                       -- sets trainer.use_amp.
//   cuda_device                -- trainer.cuda_device.
//   max_instances              -- optional cap passed to the reader (null = unset).
//   additional_special_tokens  -- optional explicit list of special tokens. When
//                                 null (default), the TACRED entity-marker set
//                                 below is used, cased to match do_lower_case.
function (
  num_epochs = 5,
  batch_size = 8,
  lr = 2e-5,
  adam_epsilon = 1e-8,
  weight_decay = 0.0,
  warmup_steps = 0,
  model_name = "bert-base-uncased",
  do_lower_case = true,
  max_length = 128,
  train_data_path = "../ds/tacred/data/json/train.json",
  validation_data_path = "../ds/tacred/data/json/dev.json",
  negative_label = "no_relation",
  entity_handling = "mark_entity_append_ner",
  fp16 = false,
  cuda_device = 0,
  max_instances = null,
  additional_special_tokens = null,
) {
  // Entity-marker special tokens: head/tail start/end markers, plus one
  // head/tail marker per NER type. The list is kept once, in lowercase; the
  // uppercase variant is derived, so the two casings cannot drift apart.
  local tacred_marker_tokens = [
    '[head=organization]', '[head=person]', '[head_end]', '[head_start]',
    '[tail=cause_of_death]', '[tail=city]', '[tail=country]',
    '[tail=criminal_charge]', '[tail=date]', '[tail=duration]',
    '[tail=ideology]', '[tail=location]', '[tail=misc]', '[tail=nationality]',
    '[tail=number]', '[tail=organization]', '[tail=person]', '[tail=religion]',
    '[tail=state_or_province]', '[tail=title]', '[tail=url]', '[tail_end]',
    '[tail_start]',
  ],
  local special_tokens =
    if additional_special_tokens != null then additional_special_tokens
    else if do_lower_case then tacred_marker_tokens
    else std.map(std.asciiUpper, tacred_marker_tokens),
  // Shared by the tokenizer, the token indexer and the model, so that all three
  // see the same casing and vocabulary additions.
  local tokenizer_kwargs = {
    "do_lower_case": do_lower_case,
    "additional_special_tokens": special_tokens,
  },
  // Regexes over parameter names. The negative lookbehind keeps LayerNorm
  // weights out of the decayed group. Both patterns escape the dot, so a name
  // cannot fall into both groups through a wildcard match.
  local parameter_groups = if weight_decay > 0 then [
    [["(?<!LayerNorm\\.)weight"], {"weight_decay": weight_decay}],
    [["bias", "LayerNorm\\.weight"], {"weight_decay": 0.0}],
  ] else null,
  "dataset_reader": {
    "type": "sherlock",
    "dataset_reader_name": "tacred",
    "feature_converter_name": "binary_rc",
    "tokenizer": {
      "type": "pretrained_transformer",
      "model_name": model_name,
      "max_length": max_length,
      // Special tokens are not added here; presumably the feature converter
      // handles them (TODO confirm against the sherlock reader).
      "add_special_tokens": false,
      "tokenizer_kwargs": tokenizer_kwargs,
    },
    "token_indexers": {
      "tokens": {
        "type": "pretrained_transformer",
        "model_name": model_name,
        "max_length": max_length,
        "tokenizer_kwargs": tokenizer_kwargs,
      },
    },
    "max_tokens": max_length,
    "log_num_input_features": 3,
    "dataset_reader_kwargs": {
      "negative_label_re": negative_label,
    },
    "feature_converter_kwargs": {
      "entity_handling": entity_handling,
    },
    "max_instances": max_instances,
  },
  "train_data_path": train_data_path,
  "validation_data_path": validation_data_path,
  "model": {
    "type": "transformer_relation_classifier",
    "model_name": model_name,
    "max_length": max_length,
    // The negative label is excluded from the micro-averaged F1.
    "ignore_label": negative_label,
    "f1_average": "micro",
    "tokenizer_kwargs": tokenizer_kwargs,
  },
  "data_loader": {
    "type": "simple",
    "batch_size": batch_size,
    "shuffle": true,
  },
  "trainer": {
    "num_epochs": num_epochs,
    "optimizer": {
      "type": "huggingface_adamw",
      "parameter_groups": parameter_groups,
      "lr": lr,
      "eps": adam_epsilon,
    },
    "learning_rate_scheduler": {
      "type": "linear_with_warmup",
      "warmup_steps": warmup_steps,
    },
    "validation_metric": "+fscore",
    "cuda_device": cuda_device,
    "use_amp": fp16,
  },
}