forked from NVIDIA/Megatron-LM
-
Notifications
You must be signed in to change notification settings - Fork 13
/
arguments.py
398 lines (351 loc) · 20.6 KB
/
arguments.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""argparser configuration"""
import argparse
import os
import torch
def add_model_config_args(parser):
    """Register the model-architecture options on ``parser``.

    Creates the ``model`` argument group and adds the transformer size,
    dropout, vocabulary and initialization flags to it.

    Args:
        parser: the ``argparse.ArgumentParser`` to extend.

    Returns:
        The same ``parser`` object that was passed in.
    """
    group = parser.add_argument_group('model', 'model configuration')

    # NOTE: adjacent string literals are concatenated verbatim, so every
    # help fragment below must carry its own separating space.
    group.add_argument('--pretrained-bert', action='store_true',
                       help='use a pretrained bert-large-uncased model instead '
                       'of initializing from scratch. See '
                       '--tokenizer-model-type to specify which pretrained '
                       'BERT model to use')
    group.add_argument('--attention-dropout', type=float, default=0.1,
                       help='dropout probability for attention weights')
    group.add_argument('--num-attention-heads', type=int, default=16,
                       help='num of transformer attention heads')
    group.add_argument('--hidden-size', type=int, default=1024,
                       help='transformer hidden size')
    group.add_argument('--intermediate-size', type=int, default=None,
                       help='transformer embedding dimension for FFN, '
                       'set to 4*`--hidden-size` if it is None')
    group.add_argument('--num-layers', type=int, default=24,
                       help='num decoder layers')
    group.add_argument('--layernorm-epsilon', type=float, default=1e-5,
                       help='layer norm epsilon')
    group.add_argument('--hidden-dropout', type=float, default=0.1,
                       help='dropout probability for hidden state transformer')
    group.add_argument('--max-position-embeddings', type=int, default=512,
                       help='maximum number of position embeddings to use')
    group.add_argument('--vocab-size', type=int, default=30522,
                       help='vocab size to use for non-character-level '
                       'tokenization. This value will only be used when '
                       'creating a tokenizer')
    group.add_argument('--deep-init', action='store_true',
                       help='initialize bert model similar to gpt2 model. '
                       'Scales initialization of projection layers by a '
                       'factor of 1/sqrt(2N). Necessary to train bert '
                       'models larger than BERT-Large.')
    group.add_argument('--make-vocab-size-divisible-by', type=int, default=128,
                       help='Pad the vocab size to be divisible by this value. '
                       'This is added for computational efficiency reasons.')

    return parser
def add_fp16_config_args(parser):
    """Register the mixed-precision (fp16) options on ``parser``.

    Creates the ``fp16`` argument group with the fp16 switch, the
    ``--fp32-*`` flags and the loss-scaling parameters.

    Args:
        parser: the ``argparse.ArgumentParser`` to extend.

    Returns:
        The same ``parser`` object that was passed in.
    """
    group = parser.add_argument_group('fp16', 'fp16 configurations')

    group.add_argument('--fp16', action='store_true',
                       help='Run model in fp16 mode')
    group.add_argument('--fp32-embedding', action='store_true',
                       help='embedding in fp32')
    group.add_argument('--fp32-layernorm', action='store_true',
                       help='layer norm in fp32')
    group.add_argument('--fp32-tokentypes', action='store_true',
                       help='embedding token types in fp32')
    group.add_argument('--fp32-allreduce', action='store_true',
                       help='all-reduce in fp32')
    group.add_argument('--hysteresis', type=int, default=2,
                       help='hysteresis for dynamic loss scaling')
    # Each fragment ends with a space: adjacent literals are joined verbatim.
    group.add_argument('--loss-scale', type=float, default=None,
                       help='Static loss scaling, positive power of 2 '
                       'values can improve fp16 convergence. If None, dynamic '
                       'loss scaling is used.')
    group.add_argument('--loss-scale-window', type=float, default=1000,
                       help='Window over which to raise/lower dynamic scale')
    group.add_argument('--min-scale', type=float, default=1,
                       help='Minimum loss scale for dynamic loss scale')

    return parser
def add_training_args(parser):
    """Register the training options on ``parser``.

    Creates the ``train`` argument group and adds the optimization,
    batch-producer, learning-rate, checkpointing, distributed and
    autoresume flags to it.

    Args:
        parser: the ``argparse.ArgumentParser`` to extend.

    Returns:
        The same ``parser`` object that was passed in.
    """
    group = parser.add_argument_group('train', 'training configurations')

    group.add_argument('--batch-size', type=int, default=4,
                       help='Data Loader batch size')
    group.add_argument('--weight-decay', type=float, default=0.01,
                       help='weight decay coefficient for L2 regularization')
    group.add_argument('--checkpoint-activations', action='store_true',
                       help='checkpoint activation to allow for training '
                       'with larger models and sequences')
    group.add_argument('--checkpoint-num-layers', type=int, default=1,
                       help='chunk size (number of layers) for checkpointing')
    group.add_argument('--clip-grad', type=float, default=1.0,
                       help='gradient clipping')
    group.add_argument('--train-iters', type=int, default=1000000,
                       help='total number of iterations to train over all '
                       'training runs')
    group.add_argument('--log-interval', type=int, default=100,
                       help='report interval')
    group.add_argument('--exit-interval', type=int, default=None,
                       help='Exit the program after this many new iterations.')
    group.add_argument('--tensorboard-dir', type=str, default=None,
                       help='Write TensorBoard logs to this directory')
    group.add_argument('--seed', type=int, default=1234,
                       help='random seed')

    # Batch producer arguments
    group.add_argument('--reset-position-ids', action='store_true',
                       help='Reset position ids after end-of-document token.')
    group.add_argument('--reset-attention-mask', action='store_true',
                       help='Reset self attention mask after '
                       'end-of-document token.')
    group.add_argument('--eod-mask-loss', action='store_true',
                       help='Mask loss for the end of document tokens')

    # Learning rate.
    group.add_argument('--lr-decay-iters', type=int, default=None,
                       help='number of iterations to decay LR over,'
                       ' If None defaults to `--train-iters`*`--epochs`')
    group.add_argument('--lr-decay-style', type=str, default='linear',
                       choices=['constant', 'linear', 'cosine', 'exponential'],
                       help='learning rate decay function')
    group.add_argument('--lr', type=float, default=1.0e-4,
                       help='initial learning rate')
    group.add_argument('--min-lr', type=float, default=0.0,
                       help='Minimum value for learning rate. The scheduler '
                       'clips values below this threshold.')
    # argparse %-formats help strings, so a literal percent sign has to be
    # written as '%%'; a bare '%' breaks rendering of the help output.
    group.add_argument('--warmup', type=float, default=0.01,
                       help='percentage of data to warmup on (.01 = 1%% of all '
                       'training iters). Default 0.01')
    group.add_argument('--override-lr-scheduler', action='store_true',
                       help='Reset the values of the scheduler (learning rate, '
                       'warmup iterations, minimum learning rate, maximum '
                       'number of iterations, and decay style) from input '
                       'arguments and ignore values from checkpoints. Note '
                       'that all the above values will be reset.')
    # The original text was copied from --override-lr-scheduler and described
    # the opposite behaviour; this flag takes the values from the checkpoint.
    group.add_argument('--use-checkpoint-lr-scheduler', action='store_true',
                       help='Use checkpoint to set the values of the scheduler '
                       '(learning rate, warmup iterations, minimum learning '
                       'rate, maximum number of iterations, and decay style) '
                       'from the checkpoint and ignore input arguments.')

    # model checkpointing
    group.add_argument('--save', type=str, default=None,
                       help='Output directory to save checkpoints to.')
    group.add_argument('--save-interval', type=int, default=5000,
                       help='number of iterations between saves')
    group.add_argument('--no-save-optim', action='store_true',
                       help='Do not save current optimizer.')
    group.add_argument('--no-save-rng', action='store_true',
                       help='Do not save current rng state.')
    group.add_argument('--load', type=str, default=None,
                       help='Path to a directory containing a model checkpoint.')
    group.add_argument('--no-load-optim', action='store_true',
                       help='Do not load optimizer when loading checkpoint.')
    group.add_argument('--no-load-rng', action='store_true',
                       help='Do not load rng state when loading checkpoint.')
    group.add_argument('--finetune', action='store_true',
                       help='Load model for finetuning. Do not load optimizer '
                       'or rng state from checkpoint and set iteration to 0. '
                       'Assumed when loading a release checkpoint.')
    group.add_argument('--resume-dataloader', action='store_true',
                       help='Resume the dataloader when resuming training. '
                       'Does not apply to tfrecords dataloader, try resuming '
                       'with a different seed in this case.')

    # distributed training args
    group.add_argument('--distributed-backend', default='nccl',
                       help='which backend to use for distributed '
                       'training. One of [gloo, nccl]')
    group.add_argument('--DDP-impl', default='local',
                       help='which DistributedDataParallel implementation '
                       'to use. One of [local, torch]')
    group.add_argument('--local_rank', type=int, default=None,
                       help='local rank passed from distributed launcher')

    # autoresume
    group.add_argument('--adlr-autoresume', action='store_true',
                       help='enable autoresume on adlr cluster.')
    group.add_argument('--adlr-autoresume-interval', type=int, default=1000,
                       help='intervals over which check for autoresume '
                       'termination signal')

    return parser
def add_evaluation_args(parser):
    """Register the evaluation/validation options on ``parser``.

    Creates the ``validation`` argument group with the evaluation batch
    size, iteration counts, sequence-length limits and the flags that
    select the evaluation task or weight source.

    Args:
        parser: the ``argparse.ArgumentParser`` to extend.

    Returns:
        The same ``parser`` object that was passed in.
    """
    group = parser.add_argument_group('validation', 'validation configurations')

    # Adjacent literals are joined verbatim, so fragments keep a trailing
    # space wherever the text continues on the next line.
    group.add_argument('--eval-batch-size', type=int, default=None,
                       help='Data Loader batch size for evaluation datasets. '
                       'Defaults to `--batch-size`')
    group.add_argument('--eval-iters', type=int, default=100,
                       help='number of iterations to run for evaluation '
                       'validation/test for')
    group.add_argument('--eval-interval', type=int, default=1000,
                       help='interval between running evaluation on '
                       'validation set')
    group.add_argument('--eval-seq-length', type=int, default=None,
                       help='Maximum sequence length to process for '
                       'evaluation. Defaults to `--seq-length`')
    group.add_argument('--eval-max-preds-per-seq', type=int, default=None,
                       help='Maximum number of predictions to use for '
                       'evaluation. Defaults to '
                       'math.ceil(`--eval-seq-length`*.15/10)*10')
    group.add_argument('--overlapping-eval', type=int, default=32,
                       help='sliding window for overlapping eval ')
    group.add_argument('--cloze-eval', action='store_true',
                       help='Evaluation dataset from `--valid-data` is a '
                       'cloze task')
    group.add_argument('--strict-lambada', action='store_true',
                       help='use more difficult formulation of lambada')
    group.add_argument('--eval-hf', action='store_true',
                       help='perform evaluation with huggingface openai model. '
                       'Use `--load` to specify weights path to be loaded')
    group.add_argument('--load-openai', action='store_true',
                       help='load openai weights into our model. Use `--load` '
                       'to specify weights path to be loaded')

    return parser
def add_text_generate_args(parser):
    """Attach the text-generation (sampling) options to ``parser``.

    The options live in an argument group titled ``Text generation`` and
    are registered in the order listed in the table below.

    Args:
        parser: the ``argparse.ArgumentParser`` to extend.

    Returns:
        The same ``parser`` object that was passed in.
    """
    generation = parser.add_argument_group('Text generation', 'configurations')

    # (flag, keyword arguments for add_argument), in registration order.
    option_table = (
        ("--temperature", dict(type=float, default=1.0)),
        ("--greedy", dict(action='store_true', default=False)),
        ("--top_p", dict(type=float, default=0.0)),
        ("--top_k", dict(type=int, default=0)),
        ("--out-seq-length", dict(type=int, default=1024)),
        ("--sample-input-file", dict(
            type=str, default="",
            help='get input from file instead of interactive mode, '
                 'each line is an input')),
        ("--sample-output-file", dict(
            type=str, default="",
            help='output file got from --sample-input-file')),
        ("--num-samples", dict(
            type=int, default=0,
            help='number of samples to generate unconditionally, '
                 'defaults to 0 and interactive conditional sampling')),
        ("--genfile", dict(
            type=str,
            help='output file when generating unconditionally')),
        ("--recompute", dict(
            action='store_true',
            help='during generation recompute all attention '
                 'instead of using previously computed keys/values.')),
    )
    for flag, options in option_table:
        generation.add_argument(flag, **options)

    return parser
def add_data_args(parser):
    """Register the train/valid/test data options on ``parser``.

    Creates the ``data`` argument group with the dataset locations,
    parsing options, tokenizer selection and sequence-length settings.
    The ``--model-parallel-size`` option is also registered in this group.

    Args:
        parser: the ``argparse.ArgumentParser`` to extend.

    Returns:
        The same ``parser`` object that was passed in.
    """
    group = parser.add_argument_group('data', 'data configurations')

    group.add_argument('--model-parallel-size', type=int, default=1,
                       help='size of the model parallel.')
    group.add_argument('--shuffle', action='store_true',
                       help='Shuffle data. Shuffling is deterministic '
                       'based on seed and current epoch.')
    group.add_argument('--train-data', nargs='+', default=None,
                       help='Whitespace separated filenames or corpora names '
                       'for training.')

    # Adjacent literals are joined verbatim, so fragments keep a trailing
    # space wherever the text continues on the next line.
    group.add_argument('--use-npy-data-loader', action='store_true',
                       help='Use the numpy data loader. If set, then '
                       'train-data-path, val-data-path, and test-data-path '
                       'should also be provided.')
    group.add_argument('--train-data-path', type=str, default='',
                       help='path to the training data')
    group.add_argument('--val-data-path', type=str, default='',
                       help='path to the validation data')
    group.add_argument('--test-data-path', type=str, default='',
                       help='path to the test data')
    group.add_argument('--input-data-sizes-file', type=str, default='sizes.txt',
                       help='the filename containing all the shards sizes')

    group.add_argument('--delim', default=',',
                       help='delimiter used to parse csv data files')
    group.add_argument('--text-key', default='sentence',
                       help='key to use to extract text from json/csv')
    group.add_argument('--eval-text-key', default=None,
                       help='key to use to extract text from '
                       'json/csv evaluation datasets')
    group.add_argument('--valid-data', nargs='*', default=None,
                       help='Filename for validation data.')
    group.add_argument('--split', default='1000,1,1',
                       help='comma-separated list of proportions for training,'
                       ' validation, and test split')
    group.add_argument('--test-data', nargs='*', default=None,
                       help='Filename for testing')

    group.add_argument('--lazy-loader', action='store_true',
                       help='whether to lazy read the data set')
    group.add_argument('--loose-json', action='store_true',
                       help='Use loose json (one json-formatted string per '
                       'newline), instead of tight json (data file is one '
                       'json string)')
    group.add_argument('--presplit-sentences', action='store_true',
                       help='Dataset content consists of documents where '
                       'each document consists of newline separated sentences')
    group.add_argument('--num-workers', type=int, default=2,
                       help='Number of workers to use for dataloading')
    # Implicit concatenation instead of backslash continuations, which would
    # embed the source indentation inside the help string.
    group.add_argument('--tokenizer-model-type', type=str,
                       default='bert-large-uncased',
                       help="Model type to use for sentencepiece tokenization "
                       "(one of ['bpe', 'char', 'unigram', 'word']) or "
                       "bert vocab to use for BertWordPieceTokenizer (one of "
                       "['bert-large-uncased', 'bert-large-cased', etc.])")
    group.add_argument('--tokenizer-path', type=str, default='tokenizer.model',
                       help='path used to save/load sentencepiece tokenization '
                       'models')
    group.add_argument('--tokenizer-type', type=str,
                       default='BertWordPieceTokenizer',
                       choices=['CharacterLevelTokenizer',
                                'SentencePieceTokenizer',
                                'BertWordPieceTokenizer',
                                'GPT2BPETokenizer'],
                       help='what type of tokenizer to use')
    group.add_argument("--cache-dir", default=None, type=str,
                       help="Where to store pre-trained BERT downloads")
    group.add_argument('--use-tfrecords', action='store_true',
                       help='load `--train-data`, `--valid-data`, '
                       '`--test-data` from BERT tf records instead of '
                       'normal data pipeline')
    group.add_argument('--seq-length', type=int, default=512,
                       help="Maximum sequence length to process")
    group.add_argument('--max-preds-per-seq', type=int, default=None,
                       help='Maximum number of predictions to use per sequence. '
                       'Defaults to math.ceil(`--seq-length`*.15/10)*10. '
                       'MUST BE SPECIFIED IF `--use-tfrecords` is True.')

    return parser
def get_args():
    """Parse the command line and derive the run-time settings.

    Builds one parser from all of the ``add_*_args`` helpers, parses the
    process arguments, and then fills in:

    * ``cuda`` from ``torch.cuda.is_available()``;
    * ``rank`` / ``world_size`` from the ``RANK`` / ``WORLD_SIZE``
      environment variables, or from the OpenMPI (and optional Slurm)
      variables when ``OMPI_COMM_WORLD_LOCAL_RANK`` is set, in which case
      ``local_rank`` is overwritten as well;
    * ``model_parallel_size`` capped at ``world_size``;
    * ``dynamic_loss_scale``, true exactly when ``--loss-scale`` was not
      given.

    When ``--fp16`` is not set, the ``fp32_embedding``,
    ``fp32_tokentypes`` and ``fp32_layernorm`` flags are forced to False.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(description='PyTorch BERT Model')
    registrars = (
        add_model_config_args,
        add_fp16_config_args,
        add_training_args,
        add_evaluation_args,
        add_text_generate_args,
        add_data_args,
    )
    for register in registrars:
        parser = register(parser)

    args = parser.parse_args()

    has_training_data = args.train_data or args.train_data_path
    if not has_training_data:
        print('WARNING: No training data specified')

    args.cuda = torch.cuda.is_available()

    env = os.environ
    args.rank = int(env.get('RANK', '0'))
    args.world_size = int(env.get("WORLD_SIZE", '1'))

    if env.get('OMPI_COMM_WORLD_LOCAL_RANK'):
        # Launched through (OpenMPI) mpirun for distributed data parallel
        # processes: take the per-node rank and size from its variables.
        rank_on_node = int(env.get('OMPI_COMM_WORLD_LOCAL_RANK'))
        procs_per_node = int(env.get('OMPI_COMM_WORLD_LOCAL_SIZE'))

        # Possibly running with Slurm; defaults describe a single node.
        node_count = int(env.get('SLURM_JOB_NUM_NODES', '1'))
        node_index = int(env.get('SLURM_NODEID', '0'))

        args.local_rank = rank_on_node
        args.rank = node_index * procs_per_node + rank_on_node
        args.world_size = node_count * procs_per_node

    args.model_parallel_size = min(args.model_parallel_size, args.world_size)
    is_root = args.rank == 0
    if is_root:
        print('using world size: {} and model-parallel size: {} '.format(
            args.world_size, args.model_parallel_size))

    # No static loss scale given means dynamic loss scaling.
    args.dynamic_loss_scale = args.loss_scale is None
    if args.dynamic_loss_scale and is_root:
        print(' > using dynamic loss scaling')

    # The fp32_* flags below are only meant to be active together with
    # fp16, so they are all switched off when fp16 is not requested.
    if not args.fp16:
        for flag in ('fp32_embedding', 'fp32_tokentypes', 'fp32_layernorm'):
            setattr(args, flag, False)

    return args