Support for computing and exporting encodings for quantizable MHA
Signed-off-by: Ashvin Kumar <quic_ashvkuma@quicinc.com>
quic-ashvkuma committed Jul 29, 2022
1 parent d7065f1 commit 3d503f3
Showing 3 changed files with 102 additions and 7 deletions.
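At a glance, this is the flow the commit enables, condensed from the new tests at the bottom of the diff (the QuantizationSimModel import path is an assumption here, and shapes are illustrative):

import json
import numpy as np
from tensorflow import keras
# Import path assumed; adjust to wherever QuantizationSimModel lives in the AIMET Keras package.
from aimet_tensorflow.keras.quantsim import QuantizationSimModel

q_in = keras.Input(shape=(8, 16))
v_in = keras.Input(shape=(4, 16))
out = keras.layers.MultiHeadAttention(key_dim=2, num_heads=2)(q_in, v_in)
model = keras.Model(inputs=[q_in, v_in], outputs=out)

sim = QuantizationSimModel(model)   # MultiHeadAttention layers are replaced by QcQuantizableMultiHeadAttention
rng = np.random.default_rng(seed=0)
sim.compute_encodings(lambda m, _: m([rng.random([5, 8, 16]), rng.random([5, 4, 16])]), None)
sim.export('./data', 'mha')         # writes ./data/mha.encodings as JSON

with open('./data/mha.encodings') as f:
    encodings = json.load(f)        # contains 'activation_encodings' and 'param_encodings' sections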
@@ -131,14 +131,22 @@ def _build_from_signature(self, query, value, key=None):
         _, _, output_rank = _build_proj_equation(query_shape.rank - 1, bound_dims=1, output_dims=2)
         output_shape = _get_output_shape(output_rank, [self._num_heads, self._key_dim])
 
-        self._query_dense.build(query_shape)
-        self._value_dense.build(value_shape)
-        self._key_dense.build(key_shape)
-        self._output_dense.build(output_shape)
+        with tf.name_scope("query"):
+            self._query_dense.build(query_shape)
+        with tf.name_scope("value"):
+            self._value_dense.build(value_shape)
+        with tf.name_scope("key"):
+            self._key_dense.build(key_shape)
+        with tf.name_scope("attention_output"):
+            self._output_dense.build(output_shape)
 
         if self.copy_source_weights is not None:
             new_weights = self.get_weights()
+            # Weights 0-5 in QcQuantizableMultiHeadAttention correspond to the weights 0-5 in Keras MHA, and
+            # represent the weights and biases associated with the query, key, and value feedforward layers
             new_weights[0:6] = self.copy_source_weights[0:6]
+            # Weights 32-33 in QcQuantizableMultiHeadAttention correspond to the weights 6-7 in Keras MHA, and
+            # represent the output feedforward layer weights and biases
             new_weights[32:34] = self.copy_source_weights[6:8]
             self.set_weights(new_weights)

@@ -266,7 +274,7 @@ def reactivate_quantizers(self):
"""Function to reactivate quantizers during forward pass"""
self._remove_quantizers = False

def quant_layers(self):
def quant_wrappers(self):
"""Function to allow QuantizationSimModel to access local quantization wrappers"""
for layer in self._wrapped_layers:
yield layer
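The docstring says this generator exists so QuantizationSimModel can reach the MHA's local quantization wrappers; a rough sketch of how such a traversal could look (this helper and the import of QcQuantizableMultiHeadAttention are hypothetical, not part of the diff):

def iter_all_quant_wrappers(sim_model):
    """Yield every quantization wrapper, including those nested inside quantizable MHA layers."""
    for layer in sim_model.layers:
        if isinstance(layer, QcQuantizableMultiHeadAttention):
            # quant_wrappers() yields the wrappers around the MHA's internal sublayers
            yield from layer.quant_wrappers()
        elif hasattr(layer, 'input_quantizers'):
            # an ordinary wrapped layer exposing its own quantizers
            yield layer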
@@ -264,7 +264,12 @@ def get_encodings_dict(self) -> Dict[str, Union[str, Dict]]:
         for wrapper in self.quant_wrappers():
             for idx, input_quantizer in enumerate(wrapper.input_quantizers):
                 if input_quantizer.encoding is not None:
-                    tensor_name = wrapper._layer_to_wrap.inbound_nodes[0].keras_inputs[idx].name
+                    # because dense layers in quantizable MHA are not explicitly sublayers, they don't have their
+                    # inbound_nodes parameter populated, so the name of the quantizer is used instead
+                    if not wrapper._layer_to_wrap.inbound_nodes:
+                        tensor_name = wrapper.name + "/" + input_quantizer.name
+                    else:
+                        tensor_name = wrapper._layer_to_wrap.inbound_nodes[0].keras_inputs[idx].name
                     encoding_dict = self._get_encoding_dict_for_quantizer(input_quantizer)
                     activation_encodings[tensor_name] = encoding_dict
             for idx, param_quantizer in enumerate(wrapper.param_quantizers):
@@ -274,7 +279,12 @@ def get_encodings_dict(self) -> Dict[str, Union[str, Dict]]:
                     param_encodings[param_name] = encoding_dict
             for idx, output_quantizer in enumerate(wrapper.output_quantizers):
                 if output_quantizer.encoding is not None:
-                    tensor_name = wrapper._layer_to_wrap.output.name
+                    # because dense layers in quantizable MHA are not explicitly sublayers, they don't have their
+                    # inbound_nodes parameter populated, so the name of the quantizer is used instead
+                    if not wrapper._layer_to_wrap.inbound_nodes:
+                        tensor_name = wrapper.name + "/" + output_quantizer.name
+                    else:
+                        tensor_name = wrapper._layer_to_wrap.output.name
                     encoding_dict = self._get_encoding_dict_for_quantizer(output_quantizer)
                     activation_encodings[tensor_name] = encoding_dict
         encodings_dict = {'version': encoding_version,
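The practical effect on the exported file is that activations of the MHA's internal wrappers are keyed by wrapper and quantizer name rather than by a Keras tensor name. A rough sketch of the resulting layout, with placeholder keys; the per-entry fields are whatever _get_encoding_dict_for_quantizer returns:

# Illustrative structure only; real keys depend on the model's layer, quantizer, and weight names.
# {
#     "version": <encoding_version>,
#     "activation_encodings": {
#         "<keras tensor name>": {...},                 # regular wrapped layer
#         "<wrapper name>/<quantizer name>": {...}      # MHA internal wrapper (inbound_nodes empty)
#     },
#     "param_encodings": {
#         "<weight name, e.g. .../kernel:0>": {...}
#     }
# }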
@@ -494,3 +494,80 @@ def test_quantizable_mha_with_mask():

     # check that QcQuantizableMultiHeadAttention exists in QuantSim model.layers
     assert any(isinstance(layer, QcQuantizableMultiHeadAttention) for layer in quantized_model.model.layers)
+
+def test_quantizable_mha_encodings():
+    B = 5
+    T = 8
+    S = 4
+
+    q_inputs = keras.Input(shape=(T, 16))
+    v_inputs = keras.Input(shape=(S, 16))
+    k_inputs = keras.Input(shape=(S, 16))
+    m_inputs = keras.Input(shape=(T, S))
+    model_output = keras.layers.MultiHeadAttention(key_dim=2, num_heads=2)(q_inputs, v_inputs, k_inputs, m_inputs)
+    unquantized_model = keras.Model(inputs=[q_inputs, v_inputs, k_inputs, m_inputs], outputs=model_output)
+
+    quantized_model = QuantizationSimModel(unquantized_model)
+
+    rng = np.random.default_rng(seed=42)
+    query = rng.random([B, T, 16])
+    value = rng.random([B, S, 16])
+    key = rng.random([B, S, 16])
+    mask = np.zeros([B, T, S])
+
+    quantized_model.compute_encodings(lambda m, _: m([query, value, key, mask]), None)
+
+    query = query * 10
+    value = value * 10
+    key = key * 10
+
+    unquantized_model_tensor = unquantized_model([query, value, key, mask]).numpy().flatten()
+    quantized_model_tensor = quantized_model.model([query, value, key, mask]).numpy().flatten()
+
+    output_encoding_min = quantized_model.model.layers[-1]._wrapped_layers[-1].output_quantizers[0]._encoding_min
+    output_encoding_max = quantized_model.model.layers[-1]._wrapped_layers[-1].output_quantizers[0]._encoding_max
+
+    # checking to make sure all outputs fall within the limits set by the output quantizer
+    FLOAT_DELTA = 0.0001
+    assert all((quantized_model_tensor >= output_encoding_min - FLOAT_DELTA) &
+               (quantized_model_tensor <= output_encoding_max + FLOAT_DELTA))
+    assert abs(quantized_model_tensor.min() - output_encoding_min) < FLOAT_DELTA
+    assert abs(quantized_model_tensor.max() - output_encoding_max) < FLOAT_DELTA
+
+def test_quantizable_mha_export_encodings():
+    B = 5
+    T = 8
+    S = 4
+
+    # STAGE 1 MODEL - model created with layers.MultiHeadAttention
+    stage_1_q_inputs = keras.Input(shape=(T, 16))
+    stage_1_v_inputs = keras.Input(shape=(S, 16))
+    stage_1_output = keras.layers.MultiHeadAttention(key_dim=2, num_heads=2)(stage_1_q_inputs, stage_1_v_inputs)
+    stage_1_model = keras.Model(inputs=[stage_1_q_inputs, stage_1_v_inputs], outputs=stage_1_output)
+
+    # STAGE 3 MODEL - model created using QuantSim
+    stage_3_model = QuantizationSimModel(stage_1_model)
+
+    rng = np.random.default_rng(seed=42)
+    query = rng.random([B, T, 16]) * 100
+    value = rng.random([B, S, 16]) * 100
+
+    stage_3_model.compute_encodings(lambda m, _: m([query, value]), None)
+    stage_3_model.export('./data', 'mha')
+
+    with open("./data/mha.encodings", "r") as encodings_file:
+        encodings = json.load(encodings_file)
+
+    for wrapper in stage_3_model.model.layers[2]._wrapped_layers:
+        for io_quantizer in wrapper.input_quantizers + wrapper.output_quantizers:
+            if io_quantizer.encoding is not None:
+                tensor_name = wrapper.name + "/" + io_quantizer.name
+                encoding_dict = QuantizationSimModel._get_encoding_dict_for_quantizer(io_quantizer)
+                assert tensor_name in encodings['activation_encodings']
+                assert encodings['activation_encodings'][tensor_name] == encoding_dict
+        for idx, param_quantizer in enumerate(wrapper.param_quantizers):
+            if param_quantizer.encoding is not None:
+                param_name = wrapper._layer_to_wrap.weights[idx].name
+                encoding_dict = QuantizationSimModel._get_encoding_dict_for_quantizer(param_quantizer)
+                assert param_name in encodings['param_encodings']
+                assert encodings['param_encodings'][param_name] == encoding_dict
