diff --git a/TrainingExtensions/tensorflow/src/python/aimet_tensorflow/keras/quant_sim/qc_mha_wrapper.py b/TrainingExtensions/tensorflow/src/python/aimet_tensorflow/keras/quant_sim/qc_mha_wrapper.py
index b1520b2ffc1..813558eb0d8 100644
--- a/TrainingExtensions/tensorflow/src/python/aimet_tensorflow/keras/quant_sim/qc_mha_wrapper.py
+++ b/TrainingExtensions/tensorflow/src/python/aimet_tensorflow/keras/quant_sim/qc_mha_wrapper.py
@@ -131,14 +131,22 @@ def _build_from_signature(self, query, value, key=None):
         _, _, output_rank = _build_proj_equation(query_shape.rank - 1, bound_dims=1, output_dims=2)
         output_shape = _get_output_shape(output_rank, [self._num_heads, self._key_dim])
 
-        self._query_dense.build(query_shape)
-        self._value_dense.build(value_shape)
-        self._key_dense.build(key_shape)
-        self._output_dense.build(output_shape)
+        with tf.name_scope("query"):
+            self._query_dense.build(query_shape)
+        with tf.name_scope("value"):
+            self._value_dense.build(value_shape)
+        with tf.name_scope("key"):
+            self._key_dense.build(key_shape)
+        with tf.name_scope("attention_output"):
+            self._output_dense.build(output_shape)
 
         if self.copy_source_weights is not None:
             new_weights = self.get_weights()
+            # Weights 0-5 in QcQuantizableMultiHeadAttention correspond to the weights 0-5 in Keras MHA, and
+            # represent the weights and biases associated with the query, key, and value feedforward layers
             new_weights[0:6] = self.copy_source_weights[0:6]
+            # Weights 32-33 in QcQuantizableMultiHeadAttention correspond to the weights 6-7 in Keras MHA, and
+            # represent the output feedforward layer weights and biases
             new_weights[32:34] = self.copy_source_weights[6:8]
             self.set_weights(new_weights)
 
@@ -266,7 +274,7 @@ def reactivate_quantizers(self):
         """Function to reactivate quantizers during forward pass"""
         self._remove_quantizers = False
 
-    def quant_layers(self):
+    def quant_wrappers(self):
         """Function to allow QuantizationSimModel to access local quantization wrappers"""
         for layer in self._wrapped_layers:
             yield layer
diff --git a/TrainingExtensions/tensorflow/src/python/aimet_tensorflow/keras/quantsim.py b/TrainingExtensions/tensorflow/src/python/aimet_tensorflow/keras/quantsim.py
index 26bd9ab2911..f4c90d5a0f9 100644
--- a/TrainingExtensions/tensorflow/src/python/aimet_tensorflow/keras/quantsim.py
+++ b/TrainingExtensions/tensorflow/src/python/aimet_tensorflow/keras/quantsim.py
@@ -264,7 +264,12 @@ def get_encodings_dict(self) -> Dict[str, Union[str, Dict]]:
         for wrapper in self.quant_wrappers():
             for idx, input_quantizer in enumerate(wrapper.input_quantizers):
                 if input_quantizer.encoding is not None:
-                    tensor_name = wrapper._layer_to_wrap.inbound_nodes[0].keras_inputs[idx].name
+                    # because dense layers in quantizable MHA are not explicitly sublayers, they don't have their
+                    # inbound_nodes parameter populated, so the name of the quantizer is used instead
+                    if not wrapper._layer_to_wrap.inbound_nodes:
+                        tensor_name = wrapper.name + "/" + input_quantizer.name
+                    else:
+                        tensor_name = wrapper._layer_to_wrap.inbound_nodes[0].keras_inputs[idx].name
                     encoding_dict = self._get_encoding_dict_for_quantizer(input_quantizer)
                     activation_encodings[tensor_name] = encoding_dict
             for idx, param_quantizer in enumerate(wrapper.param_quantizers):
@@ -274,7 +279,12 @@ def get_encodings_dict(self) -> Dict[str, Union[str, Dict]]:
                     param_encodings[param_name] = encoding_dict
             for idx, output_quantizer in enumerate(wrapper.output_quantizers):
                 if output_quantizer.encoding is not None:
-                    tensor_name = wrapper._layer_to_wrap.output.name
+                    # because dense layers in quantizable MHA are not explicitly sublayers, they don't have their
+                    # inbound_nodes parameter populated, so the name of the quantizer is used instead
+                    if not wrapper._layer_to_wrap.inbound_nodes:
+                        tensor_name = wrapper.name + "/" + output_quantizer.name
+                    else:
+                        tensor_name = wrapper._layer_to_wrap.output.name
                     encoding_dict = self._get_encoding_dict_for_quantizer(output_quantizer)
                     activation_encodings[tensor_name] = encoding_dict
         encodings_dict = {'version': encoding_version,
diff --git a/TrainingExtensions/tensorflow/test/python/eager/test_quantsim_keras.py b/TrainingExtensions/tensorflow/test/python/eager/test_quantsim_keras.py
index 2c8cffd0259..cb0b046bb4a 100644
--- a/TrainingExtensions/tensorflow/test/python/eager/test_quantsim_keras.py
+++ b/TrainingExtensions/tensorflow/test/python/eager/test_quantsim_keras.py
@@ -494,3 +494,80 @@ def test_quantizable_mha_with_mask():
 
     # check that QcQuantizableMultiHeadAttention exists in QuantSim model.layers
     assert any(isinstance(layer, QcQuantizableMultiHeadAttention) for layer in quantized_model.model.layers)
+
+def test_quantizable_mha_encodings():
+    B = 5
+    T = 8
+    S = 4
+
+    q_inputs = keras.Input(shape=(T, 16))
+    v_inputs = keras.Input(shape=(S, 16))
+    k_inputs = keras.Input(shape=(S, 16))
+    m_inputs = keras.Input(shape=(T, S))
+    model_output = keras.layers.MultiHeadAttention(key_dim=2, num_heads=2)(q_inputs, v_inputs, k_inputs, m_inputs)
+    unquantized_model = keras.Model(inputs=[q_inputs, v_inputs, k_inputs, m_inputs], outputs=model_output)
+
+    quantized_model = QuantizationSimModel(unquantized_model)
+
+    rng = np.random.default_rng(seed=42)
+    query = rng.random([B, T, 16])
+    value = rng.random([B, S, 16])
+    key = rng.random([B, S, 16])
+    mask = np.zeros([B, T, S])
+
+    quantized_model.compute_encodings(lambda m, _: m([query, value, key, mask]), None)
+
+    query = query * 10
+    value = value * 10
+    key = key * 10
+
+    unquantized_model_tensor = unquantized_model([query, value, key, mask]).numpy().flatten()
+    quantized_model_tensor = quantized_model.model([query, value, key, mask]).numpy().flatten()
+
+    output_encoding_min = quantized_model.model.layers[-1]._wrapped_layers[-1].output_quantizers[0]._encoding_min
+    output_encoding_max = quantized_model.model.layers[-1]._wrapped_layers[-1].output_quantizers[0]._encoding_max
+
+    # checking to make sure all outputs fall within the limits set by the output quantizer
+    FLOAT_DELTA = 0.0001
+    assert all((quantized_model_tensor >= output_encoding_min - FLOAT_DELTA) &
+               (quantized_model_tensor <= output_encoding_max + FLOAT_DELTA))
+    assert abs(quantized_model_tensor.min() - output_encoding_min) < FLOAT_DELTA
+    assert abs(quantized_model_tensor.max() - output_encoding_max) < FLOAT_DELTA
+
+def test_quantizable_mha_export_encodings():
+    B = 5
+    T = 8
+    S = 4
+
+    # STAGE 1 MODEL - model created with layers.MultiHeadAttention
+    stage_1_q_inputs = keras.Input(shape=(T, 16))
+    stage_1_v_inputs = keras.Input(shape=(S, 16))
+    stage_1_output = keras.layers.MultiHeadAttention(key_dim=2, num_heads=2)(stage_1_q_inputs, stage_1_v_inputs)
+    stage_1_model = keras.Model(inputs=[stage_1_q_inputs, stage_1_v_inputs], outputs=stage_1_output)
+
+    # STAGE 3 MODEL - model created using QuantSim
+    stage_3_model = QuantizationSimModel(stage_1_model)
+
+    rng = np.random.default_rng(seed=42)
+    query = rng.random([B, T, 16]) * 100
+    value = rng.random([B, S, 16]) * 100
+
+    stage_3_model.compute_encodings(lambda m, _: m([query, value]), None)
+    stage_3_model.export('./data', 'mha')
+
+    with open("./data/mha.encodings", "r") as encodings_file:
+        encodings = json.load(encodings_file)
+
+    for wrapper in stage_3_model.model.layers[2]._wrapped_layers:
+        for io_quantizer in wrapper.input_quantizers + wrapper.output_quantizers:
+            if io_quantizer.encoding is not None:
+                tensor_name = wrapper.name + "/" + io_quantizer.name
+                encoding_dict = QuantizationSimModel._get_encoding_dict_for_quantizer(io_quantizer)
+                assert tensor_name in encodings['activation_encodings']
+                assert encodings['activation_encodings'][tensor_name] == encoding_dict
+        for idx, param_quantizer in enumerate(wrapper.param_quantizers):
+            if param_quantizer.encoding is not None:
+                param_name = wrapper._layer_to_wrap.weights[idx].name
+                encoding_dict = QuantizationSimModel._get_encoding_dict_for_quantizer(param_quantizer)
+                assert param_name in encodings['param_encodings']
+                assert encodings['param_encodings'][param_name] == encoding_dict