# Schedule a specific order for training different layers iteratively.
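# Example invocation (the task name and flag values are illustrative):
#   python hawq_classifier.py --task_name MRPC --bert_model bert-base-uncased \
#       --freeze_tune --quantize_activation --emb_bits 8 8 8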
import os
import json
import re
import argparse
import ray
from tuning_order import tuning_order
from schedule_run_classifier import schedule_run
from util import BertConfig_generic
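
# Candidate per-layer bit-widths, the per-layer bit mask (32 means the layer
# stays in full precision), and the GPU budget used when scheduling jobs.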
quantize_assigned_bits = [2, 4, 6, 8]
block_wise_bits_mask = [32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32]
TOTAL_GPUS = 4
PER_JOB_GPUS = 2


def hawq(bert_model, lambdas, task_name, bits_order, do_fullprecision,
         freeze_tune, freeze_embedding, quantize_activation, emb_bits):
    '''
    Perform Hessian-aware quantization (HAWQ).
    The 12 layers are grouped into 4 blocks of 3 layers each:
        layer[bits_order[0]:bits_order[3]]  belong to block 1,
        layer[bits_order[3]:bits_order[6]]  belong to block 2,
        layer[bits_order[6]:bits_order[9]]  belong to block 3,
        layer[bits_order[9]:bits_order[12]] belong to block 4.
    Each block is assigned one bit-width, subject to the constraint
        bits_block1 >= bits_block2 >= bits_block3 >= bits_block4
    '''
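
    # Enumerate every block-wise bit assignment that satisfies the
    # non-increasing constraint described in the docstring.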
    def get_valid_bit_assignments():
        quantize_bits_list = []
        for bits_block1 in [4, 6, 8]:
            for bits_block2 in [4, 6, 8]:
                for bits_block3 in [4, 6, 8]:
                    for bits_block4 in [2, 4, 6, 8]:
                        # skip assignments that violate the ordering constraint
                        if (bits_block1 < bits_block2
                                or bits_block2 < bits_block3
                                or bits_block3 < bits_block4):
                            continue
                        quantize_bits = ([bits_block1] * 3 + [bits_block2] * 3
                                         + [bits_block3] * 3 + [bits_block4] * 3)
                        quantize_bits_list.append(quantize_bits)
        return quantize_bits_list

    quantize_bits_list = get_valid_bit_assignments()
    ray.init(num_gpus=TOTAL_GPUS)
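
    # With TOTAL_GPUS = 4 and PER_JOB_GPUS = 2, Ray runs at most two
    # quantization jobs concurrently.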
    @ray.remote(num_gpus=PER_JOB_GPUS)
    def run(quantize_bits, index):
        # work on a local copy so tasks never mutate the shared module-level mask
        bits_mask = list(block_wise_bits_mask)
        for i in range(12):
            bits_mask[bits_order[i]] = quantize_bits[i]
        quantized_config = BertConfig_generic(block_wise_bits_mask=bits_mask)
        block_wise_order = tuning_order(bert_model, quantized_config, lambdas,
                                        task_name)
        # fine-tune the quantized model with the derived block-wise order
        return schedule_run(
            f'results/experiment-{task_name}-{index}/',
            block_wise_order,
            bits_mask,
            do_fullprecision=do_fullprecision,
            freeze_tune=freeze_tune,
            freeze_embedding=freeze_embedding,
            task_name=task_name,
            quantize_activation=quantize_activation,
            emb_bits=emb_bits)
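
    # Launch one Ray task per candidate bit assignment and block until all finish.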
    results = ray.get([
        run.remote(quantize_bits, index)
        for index, quantize_bits in enumerate(quantize_bits_list)
    ])
    with open("results/HAWQ-results.json", "w") as writer:
        # one result entry per line
        for result in results:
            writer.write(f"{result}\n")


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--task_name",
        default=None,
        type=str,
        required=True,
        help="The name of the task to train.")
    parser.add_argument(
        "--bert_model",
        default=None,
        type=str,
        required=True,
        help="Bert pre-trained model selected from the list: bert-base-uncased, "
        "bert-large-uncased, bert-base-cased, bert-large-cased, "
        "bert-base-multilingual-uncased, bert-base-multilingual-cased, "
        "bert-base-chinese.")
    parser.add_argument(
        "--lambdas",
        default=[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        nargs='+',
        type=float,
        help="per-layer lambda (the eigenvalue in the HAWQ paper)")
    parser.add_argument(
        "--bits_order",
        default=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
        nargs='+',
        type=int,
        help="the order in which bits are assigned to the layers")
    parser.add_argument(
        "--do_fullprecision",
        action='store_true',
        help="If set, train the full-precision model for 3 epochs; "
        "otherwise reuse the saved full-precision results.")
    parser.add_argument(
        '--freeze_tune',
        action='store_true',
        help="Whether to fine-tune with the other layers frozen")
    parser.add_argument(
        '--freeze_embedding',
        action='store_true',
        help="Whether to fine-tune with the embedding frozen")
    parser.add_argument(
        '--quantize_activation',
        action='store_true',
        help="Whether to quantize the activation layers")
    parser.add_argument(
        '--emb_bits',
        default=[32, 32, 32],
        nargs='+',
        type=int,
        help="bits for the embedding layers, in the order word, position, "
        "token type (15261:256:1 in parameter count)")
    args = parser.parse_args()
    hawq(args.bert_model, args.lambdas, args.task_name, args.bits_order,
         args.do_fullprecision, args.freeze_tune, args.freeze_embedding,
         args.quantize_activation, args.emb_bits)


if __name__ == "__main__":
    main()