-
Notifications
You must be signed in to change notification settings - Fork 104
/
run_sequence_level_classification.py
394 lines (342 loc) · 18.4 KB
/
run_sequence_level_classification.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
# coding: utf-8
# Copyright 2019 Sinovation Ventures AI Institute
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Run sequence level classification task on ZEN model."""
from __future__ import absolute_import, division, print_function
import argparse
import sys
import logging
import os
import random
import numpy as np
import torch
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler,TensorDataset
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm, trange
import datetime
from tensorboardX import SummaryWriter
from utils_sequence_level_task import processors, convert_examples_to_features, compute_metrics
from ZEN import BertTokenizer, BertAdam, WarmupLinearSchedule
from ZEN import ZenForSequenceClassification, ZenNgramDict
from ZEN import WEIGHTS_NAME, CONFIG_NAME, NGRAM_DICT_NAME
# Module-wide logger; main() configures its handlers and level via basicConfig.
logger = logging.getLogger(__name__)

# Python 2 ships the C-accelerated pickler as cPickle; elsewhere use pickle.
# NOTE(review): `pickle` is not referenced by the rest of this script as shown.
if sys.version_info[0] != 2:
    import pickle
else:
    import cPickle as pickle
def load_examples(args, tokenizer, ngram_dict, processor, label_list, mode):
    """Load one data split and convert it into a ``TensorDataset`` of long tensors.

    Args:
        args: parsed command-line namespace; ``data_dir`` and ``max_seq_length``
            are read here.
        tokenizer: tokenizer forwarded to ``convert_examples_to_features``.
        ngram_dict: n-gram dictionary forwarded to ``convert_examples_to_features``.
        processor: task processor providing ``get_train_examples`` /
            ``get_test_examples``.
        label_list: list of task labels.
        mode: ``"train"`` or ``"test"``.

    Returns:
        TensorDataset with nine ``torch.long`` tensors, in this order:
        input_ids, input_mask, segment_ids, label_id, ngram_ids,
        ngram_positions, ngram_lengths, ngram_seg_ids, ngram_masks.

    Raises:
        ValueError: if ``mode`` is not ``"train"`` or ``"test"``.
    """
    if mode == "train":
        examples = processor.get_train_examples(args.data_dir)
    elif mode == "test":
        examples = processor.get_test_examples(args.data_dir)
    else:
        # Previously fell through and crashed later with UnboundLocalError.
        raise ValueError("Unsupported mode: {!r} (expected 'train' or 'test')".format(mode))
    features = convert_examples_to_features(examples, label_list, args.max_seq_length, tokenizer, ngram_dict)
    # The order matters: train() and evaluate() unpack each batch positionally.
    field_names = ("input_ids", "input_mask", "segment_ids", "label_id", "ngram_ids",
                   "ngram_positions", "ngram_lengths", "ngram_seg_ids", "ngram_masks")
    tensors = [torch.tensor([getattr(f, name) for f in features], dtype=torch.long)
               for name in field_names]
    return TensorDataset(*tensors)
def save_zen_model(save_zen_model_path, model, tokenizer, ngram_dict, args):
    """Persist a fine-tuned ZEN model and everything needed to reload it.

    Files written into ``save_zen_model_path`` (which must already exist;
    this function does not create it):
      * model weights under ``WEIGHTS_NAME`` and config under ``CONFIG_NAME``,
        so the directory can be reloaded with ``from_pretrained``;
      * the tokenizer vocabulary;
      * the n-gram dictionary under ``NGRAM_DICT_NAME``;
      * the command-line ``args`` as ``training_args.bin``.
    """
    # Unwrap DataParallel / DistributedDataParallel so only the bare model is saved.
    bare_model = getattr(model, 'module', model)

    def _target(file_name):
        return os.path.join(save_zen_model_path, file_name)

    torch.save(bare_model.state_dict(), _target(WEIGHTS_NAME))
    bare_model.config.to_json_file(_target(CONFIG_NAME))
    tokenizer.save_vocabulary(save_zen_model_path)
    ngram_dict.save(_target(NGRAM_DICT_NAME))
    torch.save(args, _target('training_args.bin'))
def evaluate(args, model, tokenizer, ngram_dict, processor, label_list):
    """Score ``model`` on the test split and return the task metrics.

    Args:
        args: parsed command-line namespace; ``device``, ``eval_batch_size``
            and ``task_name`` are read here (plus what ``load_examples`` reads).
        model: ZEN sequence classifier (possibly DataParallel/DDP-wrapped)
            that returns logits when called with ``labels=None``.
        tokenizer, ngram_dict, processor, label_list: forwarded to
            ``load_examples``.

    Returns:
        The result of ``compute_metrics(args.task_name, preds, labels)``,
        where ``preds`` is the argmax over the logits' class axis.

    Raises:
        ValueError: if the test split produced no examples.
    """
    eval_dataset = load_examples(args, tokenizer, ngram_dict, processor, label_list, mode="test")
    # Always walk the full test set in order. main() only calls evaluate() on
    # the rank-0 process, so a DistributedSampler here would score just a
    # shuffled, padded 1/world_size slice of the data.
    eval_sampler = SequentialSampler(eval_dataset)
    eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)

    logger.info("***** Running evaluation *****")
    logger.info(" Num examples = %d", len(eval_dataset))
    logger.info(" Batch size = %d", args.eval_batch_size)

    model.eval()
    logit_chunks = []
    label_chunks = []
    for batch in tqdm(eval_dataloader, desc="Evaluating"):
        batch = tuple(t.to(args.device) for t in batch)
        input_ids, input_mask, segment_ids, label_ids, input_ngram_ids, ngram_position_matrix, \
            ngram_lengths, ngram_seg_ids, ngram_masks = batch
        # NOTE(review): input_mask / segment_ids / ngram_seg_ids / ngram_masks are
        # loaded but not passed to the model (train() does the same), so the model
        # presumably uses its default masks and may attend to padding. Confirm
        # against ZenForSequenceClassification.forward before changing both sites.
        with torch.no_grad():
            logits = model(input_ids=input_ids,
                           input_ngram_ids=input_ngram_ids,
                           ngram_position_matrix=ngram_position_matrix,
                           labels=None, head_mask=None)
        logit_chunks.append(logits.detach().cpu().numpy())
        label_chunks.append(label_ids.detach().cpu().numpy())

    if not logit_chunks:
        raise ValueError("No evaluation examples were loaded; cannot compute metrics.")
    # Concatenate once: per-batch np.append re-copied all earlier batches (quadratic).
    preds = np.argmax(np.concatenate(logit_chunks, axis=0), axis=1)
    out_label_ids = np.concatenate(label_chunks, axis=0)
    return compute_metrics(args.task_name, preds, out_label_ids)
def _build_optimizer(args, model, num_train_optimization_steps):
    """Create the optimizer used by train().

    Parameters whose names contain 'bias', 'LayerNorm.bias' or
    'LayerNorm.weight' get weight decay 0.0; all others get 0.01.

    Returns:
        ``(optimizer, warmup_linear)``.
        Without ``--fp16``, ``optimizer`` is BertAdam, which applies the warmup
        schedule itself, and ``warmup_linear`` is None.
        With ``--fp16``, ``optimizer`` is apex's FP16_Optimizer wrapping
        FusedAdam, and ``warmup_linear`` is the WarmupLinearSchedule the caller
        must apply by hand.

    Raises:
        ImportError: ``--fp16`` was requested but apex is not installed.
    """
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
    ]
    if not args.fp16:
        optimizer = BertAdam(optimizer_grouped_parameters,
                             lr=args.learning_rate,
                             warmup=args.warmup_proportion,
                             t_total=num_train_optimization_steps)
        return optimizer, None

    try:
        from apex.optimizers import FP16_Optimizer
        from apex.optimizers import FusedAdam
    except ImportError:
        raise ImportError(
            "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")
    optimizer = FusedAdam(optimizer_grouped_parameters,
                          lr=args.learning_rate,
                          bias_correction=False,
                          max_grad_norm=1.0)
    if args.loss_scale == 0:
        optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
    else:
        optimizer = FP16_Optimizer(optimizer, static_loss_scale=args.loss_scale)
    warmup_linear = WarmupLinearSchedule(warmup=args.warmup_proportion,
                                         t_total=num_train_optimization_steps)
    return optimizer, warmup_linear


def _save_checkpoint(args, model, tokenizer, ngram_dict, global_step):
    """Save the model and its companions to <output_dir>/checkpoint-<global_step>."""
    output_dir = os.path.join(args.output_dir, "checkpoint-{}".format(global_step))
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    save_zen_model(output_dir, model, tokenizer, ngram_dict, args)


def train(args, model, tokenizer, ngram_dict, processor, label_list):
    """Fine-tune ``model`` on the training split.

    Only the main process (``args.local_rank`` in [-1, 0]) writes TensorBoard
    scalars and checkpoints. After every ``args.save_steps`` optimizer updates
    (disabled when ``save_steps <= 0``) it writes a checkpoint to
    ``<output_dir>/checkpoint-<global_step>``.

    Args:
        args: parsed command-line namespace (device, batch size, epochs,
            fp16/loss-scale, accumulation, save settings, ...).
        model: model to train; called as
            ``model(input_ids, ngram_ids, ngram_positions, labels=...)`` and
            expected to return the loss.
        tokenizer, ngram_dict, processor, label_list: forwarded to
            ``load_examples`` and to checkpoint saving.
    """
    is_main_process = args.local_rank in [-1, 0]
    global_step = 0
    tb_writer = SummaryWriter() if is_main_process else None
    try:
        train_dataset = load_examples(args, tokenizer, ngram_dict, processor, label_list, mode="train")
        if args.local_rank == -1:
            train_sampler = RandomSampler(train_dataset)
        else:
            train_sampler = DistributedSampler(train_dataset)
        train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)
        num_train_optimization_steps = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs

        optimizer, warmup_linear = _build_optimizer(args, model, num_train_optimization_steps)

        logger.info("***** Running training *****")
        logger.info(" Num examples = %d", len(train_dataset))
        logger.info(" Batch size = %d", args.train_batch_size)
        logger.info(" Num steps = %d", num_train_optimization_steps)

        for epoch_num in trange(int(args.num_train_epochs), desc="Epoch", disable=not is_main_process):
            if args.local_rank != -1:
                # DistributedSampler seeds its shuffle from the epoch number;
                # without this every epoch would replay the same order.
                train_sampler.set_epoch(epoch_num)
            model.train()
            tr_loss = 0
            nb_tr_examples, nb_tr_steps = 0, 0
            for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration", disable=not is_main_process)):
                batch = tuple(t.to(args.device) for t in batch)
                input_ids, input_mask, segment_ids, label_ids, ngram_ids, ngram_positions, \
                    ngram_lengths, ngram_seg_ids, ngram_masks = batch
                # NOTE(review): the masks / segment ids are not passed to the model
                # (evaluate() does the same); confirm this is intended.
                loss = model(input_ids,
                             ngram_ids,
                             ngram_positions,
                             labels=label_ids)
                if args.n_gpu > 1:
                    loss = loss.mean()  # mean() to average on multi-gpu.
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps
                if args.fp16:
                    optimizer.backward(loss)
                else:
                    loss.backward()
                tr_loss += loss.item()
                nb_tr_examples += input_ids.size(0)
                nb_tr_steps += 1

                if (step + 1) % args.gradient_accumulation_steps != 0:
                    continue  # still accumulating gradients; no update this micro-batch

                if args.fp16:
                    # modify learning rate with special warm up BERT uses
                    # if args.fp16 is False, BertAdam is used that handles this automatically
                    # NOTE(review): if get_lr follows pytorch-pretrained-bert's
                    # get_lr(step, nowarn=False), the 2nd argument is a warning flag,
                    # not the warmup proportion. Verify against ZEN's schedule.
                    lr_this_step = args.learning_rate * warmup_linear.get_lr(global_step, args.warmup_proportion)
                    for param_group in optimizer.param_groups:
                        param_group['lr'] = lr_this_step
                optimizer.step()
                optimizer.zero_grad()
                global_step += 1

                if is_main_process:
                    # apex's FP16_Optimizer has no BertAdam-style get_lr(); report
                    # the learning rate that was set manually above instead.
                    current_lr = lr_this_step if args.fp16 else optimizer.get_lr()[0]
                    tb_writer.add_scalar('lr', current_lr, global_step)
                    tb_writer.add_scalar('loss', loss.item(), global_step)
                    # Checked only right after an update, so each global_step is saved
                    # once (and never before the first update) under grad accumulation.
                    if args.save_steps > 0 and global_step % args.save_steps == 0:
                        _save_checkpoint(args, model, tokenizer, ngram_dict, global_step)

            if is_main_process and nb_tr_steps > 0:
                # tr_loss summed accumulation-scaled losses; undo the scaling for the mean.
                logger.info("Epoch %d: %d examples, mean loss = %.4f", epoch_num, nb_tr_examples,
                            tr_loss * args.gradient_accumulation_steps / nb_tr_steps)
    finally:
        # Flush and release the TensorBoard event file.
        if tb_writer is not None:
            tb_writer.close()
def _build_arg_parser():
    """Return the argparse parser describing this script's command line."""
    parser = argparse.ArgumentParser()
    ## Required parameters
    parser.add_argument("--data_dir",
                        default=None,
                        type=str,
                        required=True,
                        help="The input data dir. Should contain the .tsv files (or other data files) for the task.")
    parser.add_argument("--bert_model", default=None, type=str, required=True,
                        help="Bert pre-trained model selected in the list: bert-base-uncased, "
                             "bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, "
                             "bert-base-multilingual-cased, bert-base-chinese.")
    parser.add_argument("--task_name",
                        default=None,
                        type=str,
                        required=True,
                        help="The name of the task to train.")
    # Default is a timestamped directory, so separate runs do not collide.
    parser.add_argument("--output_dir",
                        default='./results/result-seqlevel-{}'.format(datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')),
                        type=str,
                        help="The output directory where the model predictions and checkpoints will be written.")
    ## Other parameters
    parser.add_argument("--multift",
                        action='store_true',
                        help="True for multi-task fine tune")
    # NOTE(review): --cache_dir is parsed but not read anywhere in this script.
    parser.add_argument("--cache_dir",
                        default="",
                        type=str,
                        help="Where do you want to store the pre-trained models downloaded from s3")
    parser.add_argument("--max_seq_length",
                        default=128,
                        type=int,
                        help="The maximum total input sequence length after WordPiece tokenization. \n"
                             "Sequences longer than this will be truncated, and sequences shorter \n"
                             "than this will be padded.")
    parser.add_argument("--do_train",
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_eval",
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument("--do_lower_case",
                        action='store_true',
                        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--train_batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size",
                        default=8,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate",
                        default=5e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument("--warmup_proportion",
                        default=0.1,
                        type=float,
                        help="Proportion of training to perform linear learning rate warmup for. "
                             "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument('--overwrite_output_dir',
                        action='store_true',
                        help="Overwrite the content of the output directory")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument('--gradient_accumulation_steps',
                        type=int,
                        default=1,
                        help="Number of updates steps to accumulate before performing a backward/update pass.")
    parser.add_argument('--fp16',
                        action='store_true',
                        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument('--loss_scale',
                        type=float, default=0,
                        help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
                             "0 (default value): dynamic loss scaling.\n"
                             "Positive power of 2: static loss scaling value.\n")
    parser.add_argument("--save_steps", type=int, default=50,
                        help="Save checkpoint every X updates steps.")
    return parser


def _setup_device(args):
    """Set ``args.device`` and ``args.n_gpu``; initialise NCCL for distributed runs."""
    if args.local_rank == -1 or args.no_cuda:
        args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
        # With --no_cuda the model stays on the CPU, so report zero GPUs. Otherwise
        # main() would wrap it in DataParallel and seed CUDA on a multi-GPU machine.
        args.n_gpu = 0 if args.no_cuda else torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        args.device = torch.device("cuda", args.local_rank)
        args.n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')


def main():
    """Parse the command line, then fine-tune and/or evaluate ZEN on --task_name."""
    args = _build_arg_parser().parse_args()
    args.task_name = args.task_name.lower()

    # Setup logging: only the main process logs at INFO level.
    logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                        datefmt='%m/%d/%Y %H:%M:%S',
                        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)

    _setup_device(args)
    logger.info("device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format(
        args.device, args.n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
            args.gradient_accumulation_steps))
    # From here on train_batch_size is the per-forward batch; the total is
    # reached by accumulating gradients over gradient_accumulation_steps.
    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not args.do_train and not args.do_eval:
        raise ValueError("At least one of `do_train` or `do_eval` must be True.")
    if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train and not args.overwrite_output_dir:
        # Deliberately only a warning (the run continues), routed through logging.
        logger.warning("Output directory already exists and is not empty.")
    if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]:
        os.makedirs(args.output_dir)

    task_name = args.task_name
    if task_name not in processors:
        raise ValueError("Task not found: %s" % (task_name))
    processor = processors[task_name]()
    label_list = processor.get_labels()
    num_labels = len(label_list)

    if args.local_rank not in [-1, 0]:
        torch.distributed.barrier()  # Make sure only the first process in distributed training will download model & vocab
    tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
    ngram_dict = ZenNgramDict(args.bert_model, tokenizer=tokenizer)
    model = ZenForSequenceClassification.from_pretrained(args.bert_model, num_labels=num_labels, multift=args.multift)
    if args.local_rank == 0:
        torch.distributed.barrier()  # release the other ranks once rank 0 has loaded

    if args.fp16:
        model.half()
    model.to(args.device)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(model,
                                                          device_ids=[args.local_rank],
                                                          output_device=args.local_rank,
                                                          find_unused_parameters=True)
    elif args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    if args.do_train:
        train(args, model, tokenizer, ngram_dict, processor, label_list)
    if args.do_eval and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
        result = evaluate(args, model, tokenizer, ngram_dict, processor, label_list)
        for key in sorted(result.keys()):
            logger.info(" %s = %s", key, str(result[key]))
# Run the command-line entry point only when executed as a script, not on import.
if __name__ == "__main__":
    main()