-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathconfigure_pretraining.py
134 lines (110 loc) · 4.7 KB
/
configure_pretraining.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
# coding=utf-8
# Adapted from ELECTRA.
"""Config controlling hyperparameters for pre-training ELECTRA."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
class PretrainingConfig(object):
    """Defines pre-training hyperparameters.

    Every hyperparameter is a plain instance attribute. Defaults are assigned
    in ``__init__`` and can be overridden with keyword arguments. Overrides are
    applied both before and after the debug-mode / model-size defaults, so an
    explicitly passed value always wins.

    Args:
        model_name: Name of the run; ``model_dir`` defaults to
            ``<data_dir>/models/<model_name>``.
        data_dir: Root directory under which the default tfrecord pattern,
            vocab file and model directory are placed.
        **kwargs: Hyperparameter overrides. Every key must name an attribute
            defined in ``__init__``.

    Raises:
        ValueError: If a keyword argument does not name a known hyperparameter.
    """

    def __init__(self, model_name, data_dir, **kwargs):
        self.model_name = model_name
        self.debug = False  # debug mode for quickly running things
        self.do_train = True  # pre-train ELECTRA
        self.do_eval = False  # evaluate generator/discriminator on unlabeled data
        self.mask_prob = 0.15  # percent of input tokens to mask out / replace

        # optimization
        self.learning_rate = 5e-4
        self.lr_decay_power = 1.0  # learning-rate decay power; linear by default
        self.weight_decay_rate = 0.01
        self.num_warmup_steps = 10000

        # training settings
        self.iterations_per_loop = 200
        self.save_checkpoints_steps = 1000
        self.num_train_steps = 1000000
        self.num_eval_steps = 100

        # model settings
        self.model_size = "small"  # one of "small", "base", or "large"
        # override the default transformer hparams for the provided model size;
        # see modeling.BertConfig for the possible hparams and
        # util.training_utils for the defaults
        self.model_hparam_overrides = kwargs.get("model_hparam_overrides", {})
        self.embedding_size = None  # bert hidden size by default
        self.vocab_size = 30522  # number of tokens in the vocabulary
        self.do_lower_case = False  # lowercase the input?

        # generator settings
        # NOTE: the ELECTRA options below are disabled in this adaptation, so
        # passing any of them as a keyword argument raises ValueError.
        # self.uniform_generator = False  # generator is uniform at random
        # self.untied_generator_embeddings = False  # tie generator/discriminator
        #                                           # token embeddings?
        self.untied_generator = True  # tie all generator/discriminator weights?
        # self.generator_layers = 1.0  # frac of discriminator layers for generator
        # self.generator_hidden_size = 0.25  # frac of discrim hidden size for gen
        self.disallow_correct = False  # force the generator to sample incorrect
                                       # tokens (so 15% of tokens are always fake)
        self.temperature = 1.0  # temperature for sampling from generator

        # batch sizes
        self.max_seq_length = 128
        # None means "derive from mask_prob and max_seq_length" (done below).
        # Declaring it here lets callers override it like any other hparam;
        # previously passing it raised "Unknown hparam".
        self.max_predictions_per_seq = None
        self.train_batch_size = 128
        self.eval_batch_size = 128

        # TPU settings
        self.use_tpu = False
        self.num_tpu_cores = 1
        self.tpu_job_name = None
        self.tpu_name = None  # cloud TPU to use for training
        self.tpu_zone = None  # GCE zone where the Cloud TPU is located in
        self.gcp_project = None  # project name for the Cloud TPU-enabled project

        # default locations of data files
        self.pretrain_tfrecords = os.path.join(
            data_dir, "pretrain_tfrecords/pretrain_data.tfrecord*")
        self.vocab_file = os.path.join(data_dir, "vocab.txt")
        self.model_dir = os.path.join(data_dir, "models", model_name)
        results_dir = os.path.join(self.model_dir, "results")
        self.results_txt = os.path.join(results_dir, "unsup_results.txt")
        self.results_pkl = os.path.join(results_dir, "unsup_results.pkl")

        # masking strategy
        self.masking_strategy = "random"
        self.init_checkpoint = None

        # config for the adversarial strategy (teacher_* settings)
        self.teacher_update_rate = 0.5
        self.teacher_rate_update_step = 1000
        self.teacher_rate_decay = 0.963
        self.teacher_learning_rate = 5e-5
        self.teacher_size = "small"
        # NOTE(review): purpose of ratio_file is not visible here; empty by default
        self.ratio_file = ""

        # update defaults with passed-in hyperparameters
        self.update(kwargs)

        if self.max_predictions_per_seq is None:
            # The +0.005 margin is added before int() truncates the product.
            self.max_predictions_per_seq = int(
                (self.mask_prob + 0.005) * self.max_seq_length)

        # debug-mode settings
        if self.debug:
            self.train_batch_size = 8
            self.num_train_steps = 20
            self.eval_batch_size = 4
            self.iterations_per_loop = 1
            self.num_eval_steps = 2

        # defaults for different-sized models
        if self.model_size == "small":
            self.embedding_size = 128
        # Here are the hyperparameters we used for larger models; see Table 6 in
        # the paper for the full hyperparameters. If re-enabled, note that
        # max_predictions_per_seq has already been derived above and would need
        # recomputing after max_seq_length / mask_prob change here.
        # else:
        #     self.max_seq_length = 512
        #     self.learning_rate = 2e-4
        #     if self.model_size == "base":
        #         self.embedding_size = 768
        #         self.generator_hidden_size = 0.33333
        #         self.train_batch_size = 256
        #     else:
        #         self.embedding_size = 1024
        #         self.mask_prob = 0.25
        #         self.train_batch_size = 2048

        # passed-in arguments override (for example) debug-mode defaults
        self.update(kwargs)

    def update(self, kwargs):
        """Override existing hyperparameters with the values in ``kwargs``.

        Args:
            kwargs: Mapping from hyperparameter name to its new value.

        Raises:
            ValueError: If a key is not already an attribute of this config;
                this stops a misspelled name from silently adding a new,
                unused setting.
        """
        for k, v in kwargs.items():
            if k not in self.__dict__:
                raise ValueError("Unknown hparam " + k)
            self.__dict__[k] = v