forked from lavis-nlp/spert
-
Notifications
You must be signed in to change notification settings - Fork 0
/
args.py
110 lines (85 loc) · 5.95 KB
/
args.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import argparse
def _add_common_args(arg_parser):
    """Register the options shared by the train, eval and predict parsers.

    The options are added to ``arg_parser`` in place, in the order they are
    listed in the table below.

    Args:
        arg_parser: the ``argparse.ArgumentParser`` to extend.
    """
    option_table = (
        ('--config', dict(type=str)),

        # Input
        ('--types_path', dict(type=str, help="Path to type specifications")),

        # Preprocessing
        ('--tokenizer_path', dict(type=str, help="Path to tokenizer")),
        ('--max_span_size', dict(type=int, default=10, help="Maximum size of spans")),
        ('--lowercase', dict(
            action='store_true', default=False,
            help="If true, input is lowercased during preprocessing")),
        ('--sampling_processes', dict(
            type=int, default=4,
            help="Number of sampling processes. 0 = no multiprocessing for sampling")),

        # Model / Training / Evaluation
        ('--model_path', dict(
            type=str, help="Path to directory that contains model checkpoints")),
        ('--model_type', dict(type=str, default="spert", help="Type of model")),
        ('--cpu', dict(
            action='store_true', default=False,
            help="If true, train/evaluate on CPU even if a CUDA device is available")),
        ('--eval_batch_size', dict(
            type=int, default=1, help="Evaluation/Prediction batch size")),
        ('--max_pairs', dict(
            type=int, default=1000,
            help="Maximum entity pairs to process during training/evaluation")),
        ('--rel_filter_threshold', dict(
            type=float, default=0.4, help="Filter threshold for relations")),
        ('--size_embedding', dict(
            type=int, default=25, help="Dimensionality of size embedding")),
        ('--prop_drop', dict(
            type=float, default=0.1, help="Probability of dropout used in SpERT")),
        ('--freeze_transformer', dict(
            action='store_true', default=False, help="Freeze BERT weights")),
        ('--no_overlapping', dict(
            action='store_true', default=False,
            help="If true, do not evaluate on overlapping entities "
                 "and relations with overlapping entities")),

        # Misc
        ('--seed', dict(type=int, default=None, help="Seed")),
        ('--cache_path', dict(
            type=str, default=None,
            help="Path to cache transformer models (for HuggingFace transformers library)")),
        ('--debug', dict(
            action='store_true', default=False, help="Debugging mode on/off")),
    )

    # Registration order is preserved because it determines the --help layout.
    for flag, settings in option_table:
        arg_parser.add_argument(flag, **settings)
def _add_logging_args(arg_parser):
    """Register the logging-related options on ``arg_parser``.

    Adds ``--label``, ``--log_path``, ``--store_predictions``,
    ``--store_examples`` and ``--example_count`` to the parser in place.

    Args:
        arg_parser: the ``argparse.ArgumentParser`` to extend.
    """
    arg_parser.add_argument('--label', type=str, help="Label of run. Used as the directory name of logs/models")
    # Help text fixed: previously read "Path do directory".
    arg_parser.add_argument('--log_path', type=str, help="Path to directory where training/evaluation logs are stored")
    arg_parser.add_argument('--store_predictions', action='store_true', default=False,
                            help="If true, store predictions on disc (in log directory)")
    arg_parser.add_argument('--store_examples', action='store_true', default=False,
                            help="If true, store evaluation examples on disc (in log directory)")
    # Help text fixed: "example" -> "examples".
    arg_parser.add_argument('--example_count', type=int, default=None,
                            help="Count of evaluation examples to store (if store_examples == True)")
def train_argparser():
    """Build the command-line parser used for training.

    Training-specific options are registered first; the shared options from
    ``_add_common_args`` and ``_add_logging_args`` are appended afterwards.

    Returns:
        argparse.ArgumentParser: the configured parser.
    """
    parser = argparse.ArgumentParser()
    add = parser.add_argument

    # Input
    add('--train_path', type=str, help="Path to train dataset")
    add('--valid_path', type=str, help="Path to validation dataset")

    # Logging
    add('--save_path', type=str, help="Path to directory where model checkpoints are stored")
    add('--init_eval', action='store_true', default=False,
        help="If true, evaluate validation set before training")
    add('--save_optimizer', action='store_true', default=False,
        help="Save optimizer alongside model")
    add('--train_log_iter', type=int, default=100, help="Log training process every x iterations")
    add('--final_eval', action='store_true', default=False,
        help="Evaluate the model only after training, not at every epoch")

    # Model / Training
    add('--train_batch_size', type=int, default=2, help="Training batch size")
    add('--epochs', type=int, default=20, help="Number of epochs")
    add('--neg_entity_count', type=int, default=100,
        help="Number of negative entity samples per document (sentence)")
    add('--neg_relation_count', type=int, default=100,
        help="Number of negative relation samples per document (sentence)")
    add('--lr', type=float, default=5e-5, help="Learning rate")
    add('--lr_warmup', type=float, default=0.1,
        help="Proportion of total train iterations to warmup in linear increase/decrease schedule")
    add('--weight_decay', type=float, default=0.01, help="Weight decay to apply")
    add('--max_grad_norm', type=float, default=1.0, help="Maximum gradient norm")

    # Shared option groups, appended after the training-specific ones.
    for register in (_add_common_args, _add_logging_args):
        register(parser)

    return parser
def eval_argparser():
    """Build the command-line parser used for evaluation.

    Registers ``--dataset_path`` and then the shared options from
    ``_add_common_args`` and ``_add_logging_args``.

    Returns:
        argparse.ArgumentParser: the configured parser.
    """
    parser = argparse.ArgumentParser()

    # Input
    parser.add_argument('--dataset_path', type=str, help="Path to dataset")

    # Shared option groups, appended after the evaluation-specific one.
    for register in (_add_common_args, _add_logging_args):
        register(parser)

    return parser
def predict_argparser():
    """Build the command-line parser used for prediction.

    Registers the prediction-specific input options and then the shared
    options from ``_add_common_args``. Logging options are not added here.

    Returns:
        argparse.ArgumentParser: the configured parser.
    """
    parser = argparse.ArgumentParser()

    # Input
    prediction_options = (
        ('--dataset_path', "Path to dataset"),
        ('--predictions_path', "Path to store predictions"),
        ('--spacy_model', "Label of SpaCy model (used for tokenization)"),
    )
    for flag, description in prediction_options:
        parser.add_argument(flag, type=str, help=description)

    _add_common_args(parser)

    return parser