diff --git a/keras_hub/src/layers/modeling/transformer_encoder.py b/keras_hub/src/layers/modeling/transformer_encoder.py
index 614ba0f0f4..b4a975ff86 100644
--- a/keras_hub/src/layers/modeling/transformer_encoder.py
+++ b/keras_hub/src/layers/modeling/transformer_encoder.py
@@ -184,7 +184,12 @@ def build(self, inputs_shape):
         self.built = True
 
     def call(
-        self, inputs, padding_mask=None, attention_mask=None, training=None
+        self,
+        inputs,
+        padding_mask=None,
+        attention_mask=None,
+        training=None,
+        return_attention_scores=False,
     ):
         """Forward pass of the TransformerEncoder.
 
@@ -199,6 +204,9 @@ def call(
                 [batch_size, sequence_length, sequence_length].
             training: a boolean indicating whether the layer should behave in
                 training mode or in inference mode.
+            return_attention_scores: a boolean indicating whether the output
+                should be `(attention_output, attention_scores)` if `True` or
+                `attention_output` if `False`. Defaults to `False`.
 
         Returns:
             A Tensor of the same shape as the `inputs`.
@@ -214,12 +222,23 @@ def call(
         residual = x
         if self.normalize_first:
             x = self._self_attention_layer_norm(x)
-        x = self._self_attention_layer(
-            query=x,
-            value=x,
-            attention_mask=self_attention_mask,
-            training=training,
-        )
+
+        if return_attention_scores:
+            x, attention_scores = self._self_attention_layer(
+                query=x,
+                value=x,
+                attention_mask=self_attention_mask,
+                return_attention_scores=return_attention_scores,
+                training=training,
+            )
+        else:
+            x = self._self_attention_layer(
+                query=x,
+                value=x,
+                attention_mask=self_attention_mask,
+                training=training,
+            )
+
         x = self._self_attention_dropout(x, training=training)
         x = x + residual
         if not self.normalize_first:
@@ -236,6 +255,9 @@ def call(
         if not self.normalize_first:
             x = self._feedforward_layer_norm(x)
 
+        if return_attention_scores:
+            return x, attention_scores
+
         return x
 
     def get_config(self):
diff --git a/keras_hub/src/layers/modeling/transformer_encoder_test.py b/keras_hub/src/layers/modeling/transformer_encoder_test.py
index 9640d02a19..107d1d693a 100644
--- a/keras_hub/src/layers/modeling/transformer_encoder_test.py
+++ b/keras_hub/src/layers/modeling/transformer_encoder_test.py
@@ -109,3 +109,14 @@ def test_mask_propagation(self):
         inputs._keras_mask = mask
         outputs = encoder(inputs)
         self.assertAllEqual(outputs._keras_mask, mask)
+
+    def test_attention_scores(self):
+        encoder = TransformerEncoder(intermediate_dim=4, num_heads=2)
+        inputs = random.uniform(shape=[1, 4, 6])
+        outputs, attention_scores = encoder(
+            inputs, return_attention_scores=True
+        )
+        self.assertAllEqual(outputs.shape, inputs.shape)
+
+        # Scores shape: (batch_size, num_heads, seq_length, seq_length).
+        self.assertAllEqual(attention_scores.shape, [1, 2, 4, 4])
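
For reference, a minimal usage sketch of the new flag. It mirrors the added test and assumes the public `keras_hub.layers.TransformerEncoder` export and Keras 3's `keras.random.uniform`; it is an illustration, not part of the change itself.

import keras
import keras_hub

# Same configuration as in test_attention_scores.
encoder = keras_hub.layers.TransformerEncoder(intermediate_dim=4, num_heads=2)
inputs = keras.random.uniform(shape=(1, 4, 6))

# Default behavior is unchanged: a single tensor with the same shape as `inputs`.
outputs = encoder(inputs)

# With the new flag, the per-head attention scores are returned as well.
outputs, attention_scores = encoder(inputs, return_attention_scores=True)
print(outputs.shape)           # (1, 4, 6)
print(attention_scores.shape)  # (1, 2, 4, 4): (batch, num_heads, seq_len, seq_len)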