From 5966ae8aa47882b433448599019c37ea74b4f3d0 Mon Sep 17 00:00:00 2001 From: Anshuman Mishra Date: Sun, 26 Mar 2023 23:48:47 +0530 Subject: [PATCH 1/6] added everything from prev PR --- keras_nlp/layers/__init__.py | 2 + keras_nlp/layers/relative_attention.py | 505 +++++++++++++++++++++++ keras_nlp/models/xlnet/xlnet_backbone.py | 211 ++++++++++ keras_nlp/models/xlnet/xlnet_encoder.py | 250 +++++++++++ 4 files changed, 968 insertions(+) create mode 100644 keras_nlp/layers/relative_attention.py create mode 100644 keras_nlp/models/xlnet/xlnet_backbone.py create mode 100644 keras_nlp/models/xlnet/xlnet_encoder.py diff --git a/keras_nlp/layers/__init__.py b/keras_nlp/layers/__init__.py index 8bec4c10f8..ffe8b89bf5 100644 --- a/keras_nlp/layers/__init__.py +++ b/keras_nlp/layers/__init__.py @@ -22,6 +22,8 @@ from keras_nlp.layers.position_embedding import PositionEmbedding from keras_nlp.layers.random_deletion import RandomDeletion from keras_nlp.layers.random_swap import RandomSwap +from keras_nlp.layers.relative_attention import MultiHeadRelativeAttention +from keras_nlp.layers.relative_attention import TwoStreamRelativeAttention from keras_nlp.layers.sine_position_encoding import SinePositionEncoding from keras_nlp.layers.start_end_packer import StartEndPacker from keras_nlp.layers.token_and_position_embedding import ( diff --git a/keras_nlp/layers/relative_attention.py b/keras_nlp/layers/relative_attention.py new file mode 100644 index 0000000000..a3465fd763 --- /dev/null +++ b/keras_nlp/layers/relative_attention.py @@ -0,0 +1,505 @@ +# Copyright 2022 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Keras-based relative attention layers.""" +import math +import string + +import tensorflow as tf +from tensorflow import keras + +from keras_nlp.api_export import keras_nlp_export + +_CHR_IDX = string.ascii_lowercase + + +def _build_proj_equation(free_dims, bound_dims, output_dims): + """Builds an einsum equation for projections inside multi-head attention.""" + input_str = "" + kernel_str = "" + output_str = "" + bias_axes = "" + letter_offset = 0 + for i in range(free_dims): + char = _CHR_IDX[i + letter_offset] + input_str += char + output_str += char + + letter_offset += free_dims + for i in range(bound_dims): + char = _CHR_IDX[i + letter_offset] + input_str += char + kernel_str += char + + letter_offset += bound_dims + for i in range(output_dims): + char = _CHR_IDX[i + letter_offset] + kernel_str += char + output_str += char + bias_axes += char + equation = "%s,%s->%s" % (input_str, kernel_str, output_str) + + return equation, bias_axes, len(output_str) + + +def _get_output_shape(output_rank, known_last_dims): + return [None] * (output_rank - len(known_last_dims)) + list(known_last_dims) + + +def _rel_shift(x, klen=-1): + """Performs relative shift to form the relative attention score.""" + + x = tf.transpose(x, perm=[2, 3, 0, 1]) + x_size = tf.shape(x) + + x = tf.reshape(x, [x_size[1], x_size[0], x_size[2], x_size[3]]) + x = tf.slice(x, [1, 0, 0, 0], [-1, -1, -1, -1]) + x = tf.reshape(x, [x_size[0], x_size[1] - 1, x_size[2], x_size[3]]) + x = tf.slice(x, [0, 0, 0, 0], [-1, klen, -1, -1]) + + x = tf.transpose(x, perm=[2, 3, 0, 1]) + + return x + + +@keras_nlp_export("keras_nlp.layers.MultiHeadRelativeAttention") +class MultiHeadRelativeAttention(keras.layers.MultiHeadAttention): + """A multi-head attention layer with relative attention + position encoding. + This layer shares the same input/output projections as the common + `keras.layers.MultiHeadAttention` layer. + When it calculates attention logits, position encoding is projected to form + relative keys. The logits are composed by shifted relative logits and content + logits. + **Note: This layer is currently experimental. + Attributes: + kernel_initializer: The kernel initializer. Defaults to variance_scaling. + Call args: + query: Query `Tensor` of shape `[B, T, dim]`. + value: Value `Tensor` of shape `[B, S, dim]`. + content_attention_bias: Bias `Tensor` for content based attention of shape + `[num_heads, dim]`. + positional_attention_bias: Bias `Tensor` for position based attention of + shape `[num_heads, dim]`. + key: Optional key `Tensor` of shape `[B, S, dim]`. If not given, will use + `value` for both `key` and `value`, which is the most common case. + relative_position_encoding: Relative positional encoding `Tensor` of shape + `[B, L, dim]`. + segment_matrix: Optional `Tensor` representing segmentation IDs used in + XLNet of shape `[B, S, S + M]`. + segment_encoding: Optional `Tensor` representing the segmentation encoding + as used in XLNet of shape `[2, num_heads, dim]`. + segment_attention_bias: Optional trainable bias parameter added to the query + had when calculating the segment-based attention score used in XLNet of + shape `[num_heads, dim]`. + state: Optional `Tensor` of shape `[B, M, E]` where M is the length of the + state or memory. If passed, this is also attended over as in Transformer + XL. + attention_mask: A boolean mask of shape `[B, T, S]` that prevents attention + to certain positions. 
+ """ + + def __init__(self, kernel_initializer="variance_scaling", **kwargs): + super().__init__(kernel_initializer=kernel_initializer, **kwargs) + + def _build_from_signature(self, query, value, key=None): + super(MultiHeadRelativeAttention, self)._build_from_signature( + query=query, value=value, key=key + ) + if hasattr(value, "shape"): + value_shape = tf.TensorShape(value.shape) + else: + value_shape = value + if key is None: + key_shape = value_shape + elif hasattr(key, "shape"): + key_shape = tf.TensorShape(key.shape) + else: + key_shape = key + + common_kwargs = dict( + kernel_initializer=self._kernel_initializer, + bias_initializer=self._bias_initializer, + kernel_regularizer=self._kernel_regularizer, + bias_regularizer=self._bias_regularizer, + activity_regularizer=self._activity_regularizer, + kernel_constraint=self._kernel_constraint, + bias_constraint=self._bias_constraint, + ) + + with tf.init_scope(): + einsum_equation, _, output_rank = _build_proj_equation( + key_shape.rank - 1, bound_dims=1, output_dims=2 + ) + self._encoding_dense = keras.layers.EinsumDense( + einsum_equation, + output_shape=_get_output_shape( + output_rank - 1, [self._num_heads, self._key_dim] + ), + bias_axes=None, + name="encoding", + **common_kwargs + ) + + def compute_attention( + self, + query, + key, + value, + position, + content_attention_bias, + positional_attention_bias, + segment_matrix=None, + segment_encoding=None, + segment_attention_bias=None, + attention_mask=None, + ): + """Computes the attention. + This function defines the computation inside `call` with projected + multihead Q, K, V, R inputs. + Args: + query: Projected query `Tensor` of shape `[B, T, N, key_dim]`. + key: Projected key `Tensor` of shape `[B, S + M, N, key_dim]`. + value: Projected value `Tensor` of shape `[B, S + M, N, key_dim]`. + position: Projected position `Tensor` of shape `[B, L, N, key_dim]`. + content_attention_bias: Trainable bias parameter added to the query head + when calculating the content-based attention score. + positional_attention_bias: Trainable bias parameter added to the query + head when calculating the position-based attention score. + segment_matrix: Optional `Tensor` representing segmentation IDs used in + XLNet. + segment_encoding: Optional trainable `Tensor` representing the + segmentation encoding as used in XLNet. + segment_attention_bias: Optional trainable bias parameter added to the + query had when calculating the segment-based attention score used in + XLNet. + attention_mask: (default None) Optional mask that is added to attention + logits. If state is not None, the mask source sequence dimension should + extend M. + Returns: + attention_output: Multi-headed output of attention computation of shape + `[B, S, N, key_dim]`. 
+ """ + content_attention = tf.einsum( + self._dot_product_equation, key, query + content_attention_bias + ) + positional_attention = tf.einsum( + self._dot_product_equation, + position, + query + positional_attention_bias, + ) + positional_attention = _rel_shift( + positional_attention, klen=tf.shape(content_attention)[3] + ) + + if segment_matrix is not None: + segment_attention = tf.einsum( + "bind,snd->bnis", + query + segment_attention_bias, + segment_encoding, + ) + target_shape = tf.shape(positional_attention) + segment_attention = tf.where( + tf.broadcast_to( + tf.expand_dims(segment_matrix, 1), target_shape + ), + tf.broadcast_to(segment_attention[:, :, :, 1:], target_shape), + tf.broadcast_to(segment_attention[:, :, :, :1], target_shape), + ) + attention_sum = ( + content_attention + positional_attention + segment_attention + ) + else: + attention_sum = content_attention + positional_attention + + attention_scores = tf.multiply( + attention_sum, 1.0 / math.sqrt(float(self._key_dim)) + ) + + attention_scores = self._masked_softmax( + attention_scores, attention_mask + ) + + attention_output = self._dropout_layer(attention_scores) + + attention_output = tf.einsum( + self._combine_equation, attention_output, value + ) + return attention_output + + def call( + self, + query, + value, + content_attention_bias, + positional_attention_bias, + key=None, + relative_position_encoding=None, + segment_matrix=None, + segment_encoding=None, + segment_attention_bias=None, + state=None, + attention_mask=None, + ): + """Compute multi-head relative attention over inputs. + Size glossary: + * Number of heads (H): the number of attention heads. + * Value size (V): the size of each value embedding per head. + * Key size (K): the size of each key embedding per head. Equally, the size + of each query embedding per head. Typically K <= V. + * Batch dimensions (B). + * Query (target) attention axes shape (T). + * Value (source) attention axes shape (S), the rank must match the target. + * Encoding length (L): The relative positional encoding length. + Args: + query: attention input. + value: attention input. + content_attention_bias: A trainable bias parameter added to the query head + when calculating the content-based attention score. + positional_attention_bias: A trainable bias parameter added to the query + head when calculating the position-based attention score. + key: attention input. + relative_position_encoding: relative positional encoding for key and + value. + segment_matrix: Optional `Tensor` representing segmentation IDs used in + XLNet. + segment_encoding: Optional `Tensor` representing the segmentation encoding + as used in XLNet. + segment_attention_bias: Optional trainable bias parameter added to the + query had when calculating the segment-based attention score used in + XLNet. + state: (default None) optional state. If passed, this is also attended + over as in TransformerXL. + attention_mask: (default None) Optional mask that is added to attention + logits. If state is not None, the mask source sequence dimension should + extend M. + Returns: + attention_output: The result of the computation, of shape [B, T, E], + where `T` is for target sequence shapes and `E` is the query input last + dimension if `output_shape` is `None`. Otherwise, the multi-head outputs + are projected to the shape specified by `output_shape`. 
+ """ + if not self._built_from_signature: + self._build_from_signature(query, value, key=key) + if key is None: + key = value + if state is not None and state.shape.ndims > 1: + value = tf.concat([state, value], 1) + key = tf.concat([state, key], 1) + + # `query` = [B, T, N ,H] + query = self._query_dense(query) + + # `key` = [B, S + M, N, H] + key = self._key_dense(key) + + # `value` = [B, S + M, N, H] + value = self._value_dense(value) + + # `position` = [B, L, N, H] + position = self._encoding_dense(relative_position_encoding) + + attention_output = self.compute_attention( + query=query, + key=key, + value=value, + position=position, + content_attention_bias=content_attention_bias, + positional_attention_bias=positional_attention_bias, + segment_matrix=segment_matrix, + segment_encoding=segment_encoding, + segment_attention_bias=segment_attention_bias, + attention_mask=attention_mask, + ) + + # `attention_output` = [B, S, N, H] + attention_output = self._output_dense(attention_output) + + return attention_output + + +@keras_nlp_export("keras_nlp.layers.TwoStreamRelativeAttention") +class TwoStreamRelativeAttention(MultiHeadRelativeAttention): + """Two-stream relative self-attention for XLNet. + In XLNet, each token has two associated vectors at each self-attention layer, + the content stream (h) and the query stream (g). + The content stream is the self-attention stream as in Transformer XL and + represents the context and content (the token itself). + The query stream only has access to contextual information and the position, + but not the content. + This layer shares the same build signature as + `keras.layers.MultiHeadAttention` but has different input/output + projections. + **Note: This layer is currently experimental. + Call args: + content_stream: `Tensor` of shape `[B, T, dim]`. + content_attention_bias: Bias `Tensor` for content based attention of shape + `[num_heads, dim]`. + positional_attention_bias: Bias `Tensor` for position based attention of + shape `[num_heads, dim]`. + query_stream: `Tensor` of shape `[B, P, dim]`. + target_mapping: `Tensor` of shape `[B, P, S]`. + relative_position_encoding: Relative positional encoding `Tensor` of shape + `[B, L, dim]`. + segment_matrix: Optional `Tensor` representing segmentation IDs used in + XLNet of shape `[B, S, S + M]`. + segment_encoding: Optional `Tensor` representing the segmentation + encoding as used in XLNet of shape `[2, num_heads, dim]`. + segment_attention_bias: Optional trainable bias parameter added to the + query had when calculating the segment-based attention score used in + XLNet of shape `[num_heads, dim]`. + state: Optional `Tensor` of shape [B, M, E] where M is the length of the + state or memory. + If passed, this is also attended over as in Transformer XL. + content_attention_mask: a boolean mask of shape `[B, T, S]` that + prevents attention to certain positions for content attention computation. + query_attention_mask: a boolean mask of shape `[B, T, S]` that + prevents attention to certain position for query attention computation. + """ + + def call( + self, + content_stream, + content_attention_bias, + positional_attention_bias, + query_stream, + relative_position_encoding, + target_mapping=None, + segment_matrix=None, + segment_encoding=None, + segment_attention_bias=None, + state=None, + content_attention_mask=None, + query_attention_mask=None, + ): + """Compute multi-head relative attention over inputs. + Size glossary: + * Number of heads (H): the number of attention heads. 
+ * Value size (V): the size of each value embedding per head. + * Key size (K): the size of each key embedding per head. Equally, the size + of each query embedding per head. Typically K <= V. + * Number of predictions (P): the number of predictions. + * Batch dimensions (B). + * Query (target) attention axes shape (T). + * Value (source) attention axes shape (S), the rank must match the target. + * Encoding length (L): The relative positional encoding length. + Args: + content_stream: The content representation, commonly referred to as h. + This serves a similar role to the standard hidden states in + Transformer-XL. + content_attention_bias: A trainable bias parameter added to the query head + when calculating the content-based attention score. + positional_attention_bias: A trainable bias parameter added to the query + head when calculating the position-based attention score. + query_stream: The query representation, commonly referred to as g. This + only has access to contextual information and position, but not content. + If not provided, then this is MultiHeadRelativeAttention with + self-attention. + relative_position_encoding: relative positional encoding for key and + value. + target_mapping: Optional `Tensor` representing the target mapping used in + partial prediction. + segment_matrix: Optional `Tensor` representing segmentation IDs used in + XLNet. + segment_encoding: Optional `Tensor` representing the segmentation encoding + as used in XLNet. + segment_attention_bias: Optional trainable bias parameter added to the + query head when calculating the segment-based attention score. + state: (default None) optional state. If passed, this is also attended + over as in TransformerXL and XLNet. + content_attention_mask: (default None) Optional mask that is added to + content attention logits. If state is not None, the mask source sequence + dimension should extend M. + query_attention_mask: (default None) Optional mask that is added to query + attention logits. If state is not None, the mask source sequence + dimension should extend M. + Returns: + content_attention_output, query_attention_output: the results of the + computation, both of shape [B, T, E]. `T` is for target sequence shapes, + `E` is the query input last dimension if `output_shape` is `None`. + Otherwise, the multi-head outputs are projected to the shape specified + by `output_shape`. 
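Example:

Illustrative only: the number of predictions, the one-hot `target_mapping`,
and the `2 * seq_len` position encoding length are assumptions, and the
attention biases are zero tensors standing in for the trainable weights of
the enclosing block.

```python
import tensorflow as tf

import keras_nlp

batch_size, seq_len, num_predict = 2, 8, 3
num_heads, key_dim, hidden_dim = 4, 16, 64

layer = keras_nlp.layers.TwoStreamRelativeAttention(
    num_heads=num_heads, key_dim=key_dim
)

content_stream = tf.random.uniform((batch_size, seq_len, hidden_dim))
query_stream = tf.random.uniform((batch_size, num_predict, hidden_dim))
pos_encoding = tf.random.uniform((batch_size, 2 * seq_len, hidden_dim))
# One-hot map from each prediction slot to the position it predicts.
target_mapping = tf.one_hot([[0, 1, 2]] * batch_size, depth=seq_len)
content_bias = tf.zeros((num_heads, key_dim))
positional_bias = tf.zeros((num_heads, key_dim))

content_out, query_out = layer(
    content_stream,
    content_attention_bias=content_bias,
    positional_attention_bias=positional_bias,
    query_stream=query_stream,
    relative_position_encoding=pos_encoding,
    target_mapping=target_mapping,
)
# content_out -> (batch_size, seq_len, hidden_dim)
# query_out   -> (batch_size, num_predict, hidden_dim)
```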
+ """ + if not self._built_from_signature: + self._build_from_signature( + content_stream, content_stream, content_stream + ) + if state is not None and state.shape.ndims > 1: + content_and_memory_stream = tf.concat([state, content_stream], 1) + else: + content_and_memory_stream = content_stream + + # `query` = [B, T, N, H] + query = self._query_dense(content_stream) + + # `key` = [B, S + M, N, H] + key = self._key_dense(content_and_memory_stream) + + # `value` = [B, S + M, N, H] + value = self._value_dense(content_and_memory_stream) + + # `position` = [B, L, N, H] + position = self._encoding_dense(relative_position_encoding) + + content_attention_output = self.compute_attention( + query=query, + key=key, + value=value, + position=position, + content_attention_bias=content_attention_bias, + positional_attention_bias=positional_attention_bias, + segment_matrix=segment_matrix, + segment_encoding=segment_encoding, + segment_attention_bias=segment_attention_bias, + attention_mask=content_attention_mask, + ) + + # `content_attention_output` = [B, S, N, H] + content_attention_output = self._output_dense(content_attention_output) + + query_attention_output = None + if query_stream is not None: + query = self._query_dense(query_stream) + if target_mapping is not None: + query = tf.einsum("bmnd,bml->blnd", query, target_mapping) + query_attention_output = self.compute_attention( + query=query, + key=key, + value=value, + position=position, + content_attention_bias=content_attention_bias, + positional_attention_bias=positional_attention_bias, + segment_matrix=segment_matrix, + segment_encoding=segment_encoding, + segment_attention_bias=segment_attention_bias, + attention_mask=query_attention_mask, + ) + query_attention_output = tf.einsum( + "blnd,bml->bmnd", query_attention_output, target_mapping + ) + else: + query_attention_output = self.compute_attention( + query=query, + key=key, + value=value, + position=position, + content_attention_bias=content_attention_bias, + positional_attention_bias=positional_attention_bias, + segment_matrix=segment_matrix, + segment_encoding=segment_encoding, + segment_attention_bias=segment_attention_bias, + attention_mask=query_attention_mask, + ) + query_attention_output = self._output_dense(query_attention_output) + + return content_attention_output, query_attention_output diff --git a/keras_nlp/models/xlnet/xlnet_backbone.py b/keras_nlp/models/xlnet/xlnet_backbone.py new file mode 100644 index 0000000000..d6e61ba978 --- /dev/null +++ b/keras_nlp/models/xlnet/xlnet_backbone.py @@ -0,0 +1,211 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""XLNet backbone model.""" + + +import tensorflow as tf +from tensorflow import keras + +from keras_nlp.api_export import keras_nlp_export +from keras_nlp.layers.position_embedding import PositionEmbedding +from keras_nlp.models.backbone import Backbone +from keras_nlp.models.xlnet.xlnet_encoder import XLNetEncoder + + +def xlnet_kernel_initializer(stddev=0.02): + return keras.initializers.TruncatedNormal(stddev=stddev) + + +@keras_nlp_export("keras_nlp.models.XLNetBackbone") +class XLNetBackbone(Backbone): + """XLNet encoder network. + + This class implements a XLNet Transformer. + + The default constructor gives a fully customizable, randomly initialized XLNet + encoder with any number of layers, heads, and embedding dimensions. To load + preset architectures and weights, use the `from_preset` constructor. + + Disclaimer: Pre-trained models are provided on an "as is" basis, without + warranties or conditions of any kind. + + Args: + vocabulary_size: int. The size of the token vocabulary. + num_layers: int. The number of transformer layers. + num_heads: int. The number of attention heads for each transformer. + The hidden size must be divisible by the number of attention heads. + hidden_dim: int. The size of the transformer encoding and pooler layers. + intermediate_dim: int. The output dimension of the first Dense layer in + a two-layer feedforward network for each transformer. + dropout: float. Dropout probability for the Transformer encoder. + max_sequence_length: int. The maximum sequence length that this encoder + can consume. If None, `max_sequence_length` uses the value from + sequence length. This determines the variable shape for positional + embeddings. + num_segments: int. The number of types that the 'segment_ids' input can + take. + + Examples: + ```python + input_data = { + "token_ids": tf.ones(shape=(1, 12), dtype=tf.int64), + "segment_ids": tf.constant( + [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0], shape=(1, 12) + ), + "padding_mask": tf.constant( + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0], shape=(1, 12) + ), + } + + # Randomly initialized XLNet encoder with a custom config + model = keras_nlp.models.XLNetBackbone( + vocabulary_size=30552, + num_layers=12, + num_heads=12, + hidden_dim=768, + intermediate_dim=3072, + max_sequence_length=12, + ) + output = model(input_data) + ``` + """ + + def __init__( + self, + vocabulary_size, + num_layers, + num_heads, + hidden_dim, + intermediate_dim, + dropout=0.1, + max_sequence_length=512, + num_segments=2, + **kwargs, + ): + # Index of classification token in the vocabulary + cls_token_index = 0 + # Inputs + token_id_input = keras.Input( + shape=(None,), dtype="int32", name="token_ids" + ) + segment_id_input = keras.Input( + shape=(None,), dtype="int32", name="segment_ids" + ) + padding_mask = keras.Input( + shape=(None,), dtype="int32", name="padding_mask" + ) + + # Embed tokens, positions, and segment ids. 
+ token_embedding_layer = keras.layers.Embedding( + input_dim=vocabulary_size, + output_dim=hidden_dim, + embeddings_initializer=xlnet_kernel_initializer(), + name="token_embedding", + ) + token_embedding = token_embedding_layer(token_id_input) + position_embedding = PositionEmbedding( + initializer=xlnet_kernel_initializer(), + sequence_length=max_sequence_length, + name="position_embedding", + )(token_embedding) + segment_embedding = keras.layers.Embedding( + input_dim=num_segments, + output_dim=hidden_dim, + embeddings_initializer=xlnet_kernel_initializer(), + name="segment_embedding", + )(segment_id_input) + + # Sum, normalize and apply dropout to embeddings. + x = keras.layers.Add()( + (token_embedding, position_embedding, segment_embedding) + ) + x = keras.layers.LayerNormalization( + name="embeddings_layer_norm", + axis=-1, + epsilon=1e-12, + dtype=tf.float32, + )(x) + x = keras.layers.Dropout( + dropout, + name="embeddings_dropout", + )(x) + + # Apply successive transformer encoder blocks. + for i in range(num_layers): + x = XLNetEncoder( + num_heads=num_heads, + intermediate_dim=intermediate_dim, + activation=lambda x: keras.activations.gelu( + x, approximate=True + ), + dropout=dropout, + layer_norm_epsilon=1e-12, + kernel_initializer=xlnet_kernel_initializer(), + name=f"transformer_layer_{i}", + )(x, padding_mask=padding_mask) + + # Construct the two XLNet outputs. The pooled output is a dense layer on + # top of the [CLS] token. + sequence_output = x + pooled_output = keras.layers.Dense( + hidden_dim, + kernel_initializer=xlnet_kernel_initializer(), + activation="tanh", + name="pooled_dense", + )(x[:, cls_token_index, :]) + + # Instantiate using Functional API Model constructor + super().__init__( + inputs={ + "token_ids": token_id_input, + "segment_ids": segment_id_input, + "padding_mask": padding_mask, + }, + outputs={ + "sequence_output": sequence_output, + "pooled_output": pooled_output, + }, + **kwargs, + ) + + # All references to `self` below this line + self.vocabulary_size = vocabulary_size + self.num_layers = num_layers + self.num_heads = num_heads + self.hidden_dim = hidden_dim + self.intermediate_dim = intermediate_dim + self.dropout = dropout + self.max_sequence_length = max_sequence_length + self.num_segments = num_segments + self.cls_token_index = cls_token_index + + def get_config(self): + config = super().get_config() + config.update( + { + "vocabulary_size": self.vocabulary_size, + "num_layers": self.num_layers, + "num_heads": self.num_heads, + "hidden_dim": self.hidden_dim, + "intermediate_dim": self.intermediate_dim, + "dropout": self.dropout, + "max_sequence_length": self.max_sequence_length, + "num_segments": self.num_segments, + } + ) + return config + + @property + def token_embedding(self): + return self.get_layer("token_embedding") diff --git a/keras_nlp/models/xlnet/xlnet_encoder.py b/keras_nlp/models/xlnet/xlnet_encoder.py new file mode 100644 index 0000000000..42196115d1 --- /dev/null +++ b/keras_nlp/models/xlnet/xlnet_encoder.py @@ -0,0 +1,250 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + + +"""Transformer encoder block implementation based on `keras.layers.Layer`.""" + +import tensorflow as tf +from tensorflow import keras + +import keras_nlp +from keras_nlp.api_export import keras_nlp_export +from keras_nlp.utils.keras_utils import clone_initializer + +from keras_nlp.layers.transformer_layer_utils import ( # isort:skip + merge_padding_and_attention_mask, +) + + +@keras_nlp_export("keras_nlp.layers.XLNetEncoder") +class XLNetEncoder(keras.layers.Layer): + """XLNet encoder. + + This class follows the architecture of the transformer encoder layer in the + paper [Attention is All You Need](https://arxiv.org/abs/1706.03762). Users + can instantiate multiple instances of this class to stack up an encoder. + + This layer will correctly compute an attention mask from an implicit + Keras padding mask (for example, by passing `mask_zero=True` to a + `keras.layers.Embedding` layer). See the Masking and Padding + [guide](https://keras.io/guides/understanding_masking_and_padding/) + for more details. + + Args: + intermediate_dim: int, the hidden size of feedforward network. + num_heads: int, the number of heads in the + `keras.layers.MultiHeadRelativeAttention` layer. + dropout: float, defaults to 0. the dropout value, shared by + `keras.layers.MultiHeadRelativeAttention` and feedforward network. + activation: string or `keras.activations`, defaults to "relu". the + activation function of feedforward network. + layer_norm_epsilon: float, defaults to 1e-5. The epsilon value in layer + normalization components. + kernel_initializer: string or `keras.initializers` initializer, + defaults to "glorot_uniform". The kernel initializer for + the dense and multiheaded relative attention layers. + bias_initializer: string or `keras.initializers` initializer, + defaults to "zeros". The bias initializer for + the dense and multiheaded relative attention layers. + normalize_first: bool. Defaults to False. If True, the inputs to the + attention layer and the intermediate dense layer are normalized + (similar to GPT-2). If set to False, outputs of attention layer and + intermediate dense layer are normalized (similar to XLNet). + name: string, defaults to None. The name of the layer. + **kwargs: other keyword arguments. + + Examples: + + ```python + # Create a single transformer encoder layer. + encoder = keras_nlp.layers.XLNetEncoder( + intermediate_dim=64, num_heads=8) + + # Create a simple model containing the encoder. + input = keras.Input(shape=[10, 64]) + output = encoder(input) + model = keras.Model(inputs=input, outputs=output) + + # Call encoder on the inputs. 
+ input_data = tf.random.uniform(shape=[2, 10, 64]) + output = model(input_data) + ``` + + References: + - [](https://arxiv.org/abs/1906.08237) + """ + + def __init__( + self, + intermediate_dim, + num_heads, + dropout=0, + activation="relu", + layer_norm_epsilon=1e-05, + kernel_initializer="glorot_uniform", + bias_initializer="zeros", + normalize_first=False, + name=None, + **kwargs, + ): + # Work around for model saving + self._input_shape = kwargs.pop("build_input_shape", None) + + super().__init__(name=name, **kwargs) + self.intermediate_dim = intermediate_dim + self.num_heads = num_heads + self.dropout = dropout + self.activation = keras.activations.get(activation) + self.layer_norm_epsilon = layer_norm_epsilon + self.kernel_initializer = keras.initializers.get(kernel_initializer) + self.bias_initializer = keras.initializers.get(bias_initializer) + self.normalize_first = normalize_first + self._built = False + self.supports_masking = True + + if self._input_shape is not None: + self._build(self._input_shape) + + def _build(self, input_shape): + # Create layers based on input shape. + self._built = True + self._input_shape = input_shape + # Infer the dimension of our hidden feature size from the build shape. + hidden_dim = input_shape[-1] + # Attention head size is `hidden_dim` over the number of heads. + key_dim = int(hidden_dim // self.num_heads) + + # Relaive attention layers. + self._relative_attention_layer = ( + keras_nlp.layers.MultiHeadRelativeAttention( + num_heads=self.num_heads, + key_dim=key_dim, + dropout=self.dropout, + kernel_initializer=clone_initializer(self.kernel_initializer), + bias_initializer=clone_initializer(self.bias_initializer), + ) + ) + self._relative_attention_layer._build_from_signature( + query=input_shape, + value=input_shape, + content_attention_bias=self.bias_param1, + positional_attention_bias=self.bias_param2, + ) + + # Feedforward layers. + self._feedforward_layernorm = keras.layers.LayerNormalization( + epsilon=self.layer_norm_epsilon, + ) + self._feedforward_intermediate_dense = keras.layers.Dense( + self.intermediate_dim, + activation=self.activation, + kernel_initializer=clone_initializer(self.kernel_initializer), + bias_initializer=clone_initializer(self.bias_initializer), + ) + self._feedforward_output_dense = keras.layers.Dense( + hidden_dim, + kernel_initializer=clone_initializer(self.kernel_initializer), + bias_initializer=clone_initializer(self.bias_initializer), + ) + self._feedforward_dropout = keras.layers.Dropout( + rate=self.dropout, + ) + + self.bias_param1 = tf.Variable( + shape=input_shape, + name="bias_param1", + initializer=tf.zeros_initializer(), + ) + self.bias_param2 = tf.Variable( + shape=input_shape, + name="bias_param2", + initializer=tf.zeros_initializer(), + ) + + def call(self, inputs, padding_mask=None, attention_mask=None): + """Forward pass of the XLNetEncoder. + + Args: + inputs: a Tensor. The input data to XLNetEncoder, should be + of shape [batch_size, sequence_length, hidden_dim]. + padding_mask: a boolean Tensor. It indicates if the token should be + masked because the token is introduced due to padding. + `padding_mask` should have shape [batch_size, sequence_length]. + attention_mask: a boolean Tensor. Customized mask used to mask out + certain tokens. `attention_mask` should have shape + [batch_size, sequence_length, sequence_length]. + + Returns: + A Tensor of the same shape as the `inputs`. + """ + + if not self._built: + self._build(inputs.shape) + + x = inputs # Intermediate result. 
+ + # Compute self attention mask. + relative_attention_mask = merge_padding_and_attention_mask( + inputs, padding_mask, attention_mask + ) + + # Self attention block. + residual = x + if self.normalize_first: + x = self._relative_attention_layernorm(x) + x = self._relative_attention_layer( + query=x, + value=x, + attention_mask=relative_attention_mask, + content_attention_bias=self.bias_param1, + positional_attention_bias=self.bias_param2, + ) + x = self._relative_attention_dropout(x) + x = x + residual + if not self.normalize_first: + x = self._relative_attention_layernorm(x) + + # Feedforward block. + residual = x + if self.normalize_first: + x = self._feedforward_layernorm(x) + x = self._feedforward_intermediate_dense(x) + x = self._feedforward_output_dense(x) + x = self._feedforward_dropout(x) + x = x + residual + if not self.normalize_first: + x = self._feedforward_layernorm(x) + + return x + + def get_config(self): + config = super().get_config() + config.update( + { + "intermediate_dim": self.intermediate_dim, + "num_heads": self.num_heads, + "dropout": self.dropout, + "activation": keras.activations.serialize(self.activation), + "layer_norm_epsilon": self.layer_norm_epsilon, + "kernel_initializer": keras.initializers.serialize( + self.kernel_initializer + ), + "bias_initializer": keras.initializers.serialize( + self.bias_initializer + ), + "normalize_first": self.normalize_first, + "build_input_shape": self._input_shape, + } + ) + return config From c1235b6cbaf6de3e637f460e0658af22baf7103e Mon Sep 17 00:00:00 2001 From: Anshuman Mishra Date: Sun, 26 Mar 2023 23:52:14 +0530 Subject: [PATCH 2/6] code format --- keras_nlp/models/xlnet/xlnet_encoder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/keras_nlp/models/xlnet/xlnet_encoder.py b/keras_nlp/models/xlnet/xlnet_encoder.py index 42196115d1..eeedda3aae 100644 --- a/keras_nlp/models/xlnet/xlnet_encoder.py +++ b/keras_nlp/models/xlnet/xlnet_encoder.py @@ -141,7 +141,7 @@ def _build(self, input_shape): content_attention_bias=self.bias_param1, positional_attention_bias=self.bias_param2, ) - + # Feedforward layers. 
self._feedforward_layernorm = keras.layers.LayerNormalization( epsilon=self.layer_norm_epsilon, From e6fe7dc86c32041b7181a76d94aeb949070c5e56 Mon Sep 17 00:00:00 2001 From: Susnato Dhar Date: Wed, 7 Jun 2023 10:12:20 +0530 Subject: [PATCH 3/6] check and rebase --- keras_nlp/models/xlnet/xlnet_backbone.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/keras_nlp/models/xlnet/xlnet_backbone.py b/keras_nlp/models/xlnet/xlnet_backbone.py index d6e61ba978..aaacedcee6 100644 --- a/keras_nlp/models/xlnet/xlnet_backbone.py +++ b/keras_nlp/models/xlnet/xlnet_backbone.py @@ -209,3 +209,5 @@ def get_config(self): @property def token_embedding(self): return self.get_layer("token_embedding") + + From 4388c0d926f04d6f9ae41c11e4e3b8b9b1d65580 Mon Sep 17 00:00:00 2001 From: Susnato Dhar Date: Wed, 14 Jun 2023 03:34:26 +0530 Subject: [PATCH 4/6] encoder done + tested(w/o FF) --- keras_nlp/layers/relative_attention.py | 80 +++- keras_nlp/models/xlnet/xlnet_encoder.py | 511 ++++++++++++++---------- 2 files changed, 379 insertions(+), 212 deletions(-) diff --git a/keras_nlp/layers/relative_attention.py b/keras_nlp/layers/relative_attention.py index a3465fd763..b30e9d0f7b 100644 --- a/keras_nlp/layers/relative_attention.py +++ b/keras_nlp/layers/relative_attention.py @@ -113,19 +113,21 @@ def __init__(self, kernel_initializer="variance_scaling", **kwargs): super().__init__(kernel_initializer=kernel_initializer, **kwargs) def _build_from_signature(self, query, value, key=None): - super(MultiHeadRelativeAttention, self)._build_from_signature( - query=query, value=value, key=key - ) + self._built_from_signature = True + if hasattr(query, "shape"): + self._query_shape = tf.TensorShape(query.shape) + else: + self._query_shape = tf.TensorShape(query) if hasattr(value, "shape"): - value_shape = tf.TensorShape(value.shape) + self._value_shape = tf.TensorShape(value.shape) else: - value_shape = value + self._value_shape = value if key is None: - key_shape = value_shape + self._key_shape = value_shape elif hasattr(key, "shape"): - key_shape = tf.TensorShape(key.shape) + self._key_shape = tf.TensorShape(key.shape) else: - key_shape = key + self._key_shape = key common_kwargs = dict( kernel_initializer=self._kernel_initializer, @@ -138,8 +140,65 @@ def _build_from_signature(self, query, value, key=None): ) with tf.init_scope(): + free_dims = self._query_shape.rank - 1 + einsum_equation, _, output_rank = _build_proj_equation( + free_dims, bound_dims=1, output_dims=2 + ) + self._query_dense = keras.layers.EinsumDense( + einsum_equation, + output_shape=_get_output_shape( + output_rank - 1, [self._num_heads, self._key_dim] + ), + bias_axes=None, + name="query", + **common_kwargs, + ) + einsum_equation, _, output_rank = _build_proj_equation( + self._key_shape.rank - 1, bound_dims=1, output_dims=2 + ) + self._key_dense = keras.layers.EinsumDense( + einsum_equation, + output_shape=_get_output_shape( + output_rank - 1, [self._num_heads, self._key_dim] + ), + bias_axes=None, + name="key", + **common_kwargs, + ) einsum_equation, _, output_rank = _build_proj_equation( - key_shape.rank - 1, bound_dims=1, output_dims=2 + self._value_shape.rank - 1, bound_dims=1, output_dims=2 + ) + self._value_dense = keras.layers.EinsumDense( + einsum_equation, + output_shape=_get_output_shape( + output_rank - 1, [self._num_heads, self._value_dim] + ), + bias_axes=None, + name="value", + **common_kwargs, + ) + self._build_attention(output_rank) + + + + # self._output_dense = self._make_output_dense( + # free_dims, + # common_kwargs, + # 
"attention_output", + # ) + + einsum_equation, _, output_rank = _build_proj_equation( + free_dims, bound_dims=2, output_dims=1 + ) + self._output_dense = keras.layers.EinsumDense( + einsum_equation, + output_shape=_get_output_shape(output_rank - 1, [self._query_shape[-1]]), + bias_axes=None, + name="attention_output", + **common_kwargs, + ) + einsum_equation, _, output_rank = _build_proj_equation( + self._key_shape.rank - 1, bound_dims=1, output_dims=2 ) self._encoding_dense = keras.layers.EinsumDense( einsum_equation, @@ -148,7 +207,7 @@ def _build_from_signature(self, query, value, key=None): ), bias_axes=None, name="encoding", - **common_kwargs + **common_kwargs, ) def compute_attention( @@ -235,6 +294,7 @@ def compute_attention( attention_output = tf.einsum( self._combine_equation, attention_output, value ) + return attention_output def call( diff --git a/keras_nlp/models/xlnet/xlnet_encoder.py b/keras_nlp/models/xlnet/xlnet_encoder.py index eeedda3aae..8f3998a3fa 100644 --- a/keras_nlp/models/xlnet/xlnet_encoder.py +++ b/keras_nlp/models/xlnet/xlnet_encoder.py @@ -22,229 +22,336 @@ from keras_nlp.api_export import keras_nlp_export from keras_nlp.utils.keras_utils import clone_initializer +from tensorflow import keras +from keras_nlp.layers import TwoStreamRelativeAttention + from keras_nlp.layers.transformer_layer_utils import ( # isort:skip merge_padding_and_attention_mask, ) +# +# @keras_nlp_export("keras_nlp.layers.XLNetEncoder") +# class XLNetEncoder(keras.layers.Layer): +# """XLNet encoder. +# +# This class follows the architecture of the transformer encoder layer in the +# paper [Attention is All You Need](https://arxiv.org/abs/1706.03762). Users +# can instantiate multiple instances of this class to stack up an encoder. +# +# This layer will correctly compute an attention mask from an implicit +# Keras padding mask (for example, by passing `mask_zero=True` to a +# `keras.layers.Embedding` layer). See the Masking and Padding +# [guide](https://keras.io/guides/understanding_masking_and_padding/) +# for more details. +# +# Args: +# intermediate_dim: int, the hidden size of feedforward network. +# num_heads: int, the number of heads in the +# `keras.layers.MultiHeadRelativeAttention` layer. +# dropout: float, defaults to 0. the dropout value, shared by +# `keras.layers.MultiHeadRelativeAttention` and feedforward network. +# activation: string or `keras.activations`, defaults to "relu". the +# activation function of feedforward network. +# layer_norm_epsilon: float, defaults to 1e-5. The epsilon value in layer +# normalization components. +# kernel_initializer: string or `keras.initializers` initializer, +# defaults to "glorot_uniform". The kernel initializer for +# the dense and multiheaded relative attention layers. +# bias_initializer: string or `keras.initializers` initializer, +# defaults to "zeros". The bias initializer for +# the dense and multiheaded relative attention layers. +# normalize_first: bool. Defaults to False. If True, the inputs to the +# attention layer and the intermediate dense layer are normalized +# (similar to GPT-2). If set to False, outputs of attention layer and +# intermediate dense layer are normalized (similar to XLNet). +# name: string, defaults to None. The name of the layer. +# **kwargs: other keyword arguments. +# +# Examples: +# +# ```python +# # Create a single transformer encoder layer. +# encoder = keras_nlp.layers.XLNetEncoder( +# intermediate_dim=64, num_heads=8) +# +# # Create a simple model containing the encoder. 
+# input = keras.Input(shape=[10, 64]) +# output = encoder(input) +# model = keras.Model(inputs=input, outputs=output) +# +# # Call encoder on the inputs. +# input_data = tf.random.uniform(shape=[2, 10, 64]) +# output = model(input_data) +# ``` +# +# References: +# - [](https://arxiv.org/abs/1906.08237) +# """ +# +# def __init__( +# self, +# intermediate_dim, +# num_heads, +# dropout=0, +# activation="relu", +# layer_norm_epsilon=1e-05, +# kernel_initializer="glorot_uniform", +# bias_initializer="zeros", +# normalize_first=False, +# name=None, +# **kwargs, +# ): +# # Work around for model saving +# self._input_shape = kwargs.pop("build_input_shape", None) +# +# super().__init__(name=name, **kwargs) +# self.intermediate_dim = intermediate_dim +# self.num_heads = num_heads +# self.dropout = dropout +# self.activation = keras.activations.get(activation) +# self.layer_norm_epsilon = layer_norm_epsilon +# self.kernel_initializer = keras.initializers.get(kernel_initializer) +# self.bias_initializer = keras.initializers.get(bias_initializer) +# self.normalize_first = normalize_first +# self._built = False +# self.supports_masking = True +# +# if self._input_shape is not None: +# self._build(self._input_shape) +# +# def _build(self, input_shape): +# # Create layers based on input shape. +# self._built = True +# self._input_shape = input_shape +# # Infer the dimension of our hidden feature size from the build shape. +# hidden_dim = input_shape[-1] +# # Attention head size is `hidden_dim` over the number of heads. +# key_dim = int(hidden_dim // self.num_heads) +# +# # Relaive attention layers. +# self._relative_attention_layer = ( +# keras_nlp.layers.MultiHeadRelativeAttention( +# num_heads=self.num_heads, +# key_dim=key_dim, +# dropout=self.dropout, +# kernel_initializer=clone_initializer(self.kernel_initializer), +# bias_initializer=clone_initializer(self.bias_initializer), +# ) +# ) +# self._relative_attention_layer._build_from_signature( +# query=input_shape, +# value=input_shape, +# content_attention_bias=self.bias_param1, +# positional_attention_bias=self.bias_param2, +# ) +# +# # Feedforward layers. +# self._feedforward_layernorm = keras.layers.LayerNormalization( +# epsilon=self.layer_norm_epsilon, +# ) +# self._feedforward_intermediate_dense = keras.layers.Dense( +# self.intermediate_dim, +# activation=self.activation, +# kernel_initializer=clone_initializer(self.kernel_initializer), +# bias_initializer=clone_initializer(self.bias_initializer), +# ) +# self._feedforward_output_dense = keras.layers.Dense( +# hidden_dim, +# kernel_initializer=clone_initializer(self.kernel_initializer), +# bias_initializer=clone_initializer(self.bias_initializer), +# ) +# self._feedforward_dropout = keras.layers.Dropout( +# rate=self.dropout, +# ) +# +# self.bias_param1 = tf.Variable( +# shape=input_shape, +# name="bias_param1", +# initializer=tf.zeros_initializer(), +# ) +# self.bias_param2 = tf.Variable( +# shape=input_shape, +# name="bias_param2", +# initializer=tf.zeros_initializer(), +# ) +# +# def call(self, inputs, padding_mask=None, attention_mask=None): +# """Forward pass of the XLNetEncoder. +# +# Args: +# inputs: a Tensor. The input data to XLNetEncoder, should be +# of shape [batch_size, sequence_length, hidden_dim]. +# padding_mask: a boolean Tensor. It indicates if the token should be +# masked because the token is introduced due to padding. +# `padding_mask` should have shape [batch_size, sequence_length]. +# attention_mask: a boolean Tensor. Customized mask used to mask out +# certain tokens. 
`attention_mask` should have shape +# [batch_size, sequence_length, sequence_length]. +# +# Returns: +# A Tensor of the same shape as the `inputs`. +# """ +# +# if not self._built: +# self._build(inputs.shape) +# +# x = inputs # Intermediate result. +# +# # Compute self attention mask. +# relative_attention_mask = merge_padding_and_attention_mask( +# inputs, padding_mask, attention_mask +# ) +# +# # Self attention block. +# residual = x +# if self.normalize_first: +# x = self._relative_attention_layernorm(x) +# x = self._relative_attention_layer( +# query=x, +# value=x, +# attention_mask=relative_attention_mask, +# content_attention_bias=self.bias_param1, +# positional_attention_bias=self.bias_param2, +# ) +# x = self._relative_attention_dropout(x) +# x = x + residual +# if not self.normalize_first: +# x = self._relative_attention_layernorm(x) +# +# # Feedforward block. +# residual = x +# if self.normalize_first: +# x = self._feedforward_layernorm(x) +# x = self._feedforward_intermediate_dense(x) +# x = self._feedforward_output_dense(x) +# x = self._feedforward_dropout(x) +# x = x + residual +# if not self.normalize_first: +# x = self._feedforward_layernorm(x) +# +# return x +# +# def get_config(self): +# config = super().get_config() +# config.update( +# { +# "intermediate_dim": self.intermediate_dim, +# "num_heads": self.num_heads, +# "dropout": self.dropout, +# "activation": keras.activations.serialize(self.activation), +# "layer_norm_epsilon": self.layer_norm_epsilon, +# "kernel_initializer": keras.initializers.serialize( +# self.kernel_initializer +# ), +# "bias_initializer": keras.initializers.serialize( +# self.bias_initializer +# ), +# "normalize_first": self.normalize_first, +# "build_input_shape": self._input_shape, +# } +# ) +# return config +# +# + + -@keras_nlp_export("keras_nlp.layers.XLNetEncoder") class XLNetEncoder(keras.layers.Layer): - """XLNet encoder. - - This class follows the architecture of the transformer encoder layer in the - paper [Attention is All You Need](https://arxiv.org/abs/1706.03762). Users - can instantiate multiple instances of this class to stack up an encoder. - - This layer will correctly compute an attention mask from an implicit - Keras padding mask (for example, by passing `mask_zero=True` to a - `keras.layers.Embedding` layer). See the Masking and Padding - [guide](https://keras.io/guides/understanding_masking_and_padding/) - for more details. - - Args: - intermediate_dim: int, the hidden size of feedforward network. - num_heads: int, the number of heads in the - `keras.layers.MultiHeadRelativeAttention` layer. - dropout: float, defaults to 0. the dropout value, shared by - `keras.layers.MultiHeadRelativeAttention` and feedforward network. - activation: string or `keras.activations`, defaults to "relu". the - activation function of feedforward network. - layer_norm_epsilon: float, defaults to 1e-5. The epsilon value in layer - normalization components. - kernel_initializer: string or `keras.initializers` initializer, - defaults to "glorot_uniform". The kernel initializer for - the dense and multiheaded relative attention layers. - bias_initializer: string or `keras.initializers` initializer, - defaults to "zeros". The bias initializer for - the dense and multiheaded relative attention layers. - normalize_first: bool. Defaults to False. If True, the inputs to the - attention layer and the intermediate dense layer are normalized - (similar to GPT-2). 
If set to False, outputs of attention layer and - intermediate dense layer are normalized (similar to XLNet). - name: string, defaults to None. The name of the layer. - **kwargs: other keyword arguments. - - Examples: - - ```python - # Create a single transformer encoder layer. - encoder = keras_nlp.layers.XLNetEncoder( - intermediate_dim=64, num_heads=8) - - # Create a simple model containing the encoder. - input = keras.Input(shape=[10, 64]) - output = encoder(input) - model = keras.Model(inputs=input, outputs=output) - - # Call encoder on the inputs. - input_data = tf.random.uniform(shape=[2, 10, 64]) - output = model(input_data) - ``` - - References: - - [](https://arxiv.org/abs/1906.08237) - """ - - def __init__( - self, - intermediate_dim, - num_heads, - dropout=0, - activation="relu", - layer_norm_epsilon=1e-05, - kernel_initializer="glorot_uniform", - bias_initializer="zeros", - normalize_first=False, - name=None, - **kwargs, + def __init__(self, + intermediate_dim, + num_heads, + dim, + dropout=0, + layer_norm_epsilon=1e-12, + kernel_initializer="glorot_uniform", + bias_initializer="zeros", + name=None, + **kwargs ): - # Work around for model saving - self._input_shape = kwargs.pop("build_input_shape", None) super().__init__(name=name, **kwargs) self.intermediate_dim = intermediate_dim self.num_heads = num_heads + self.dim = dim self.dropout = dropout - self.activation = keras.activations.get(activation) self.layer_norm_epsilon = layer_norm_epsilon self.kernel_initializer = keras.initializers.get(kernel_initializer) self.bias_initializer = keras.initializers.get(bias_initializer) - self.normalize_first = normalize_first self._built = False self.supports_masking = True - if self._input_shape is not None: - self._build(self._input_shape) - - def _build(self, input_shape): - # Create layers based on input shape. - self._built = True - self._input_shape = input_shape - # Infer the dimension of our hidden feature size from the build shape. - hidden_dim = input_shape[-1] - # Attention head size is `hidden_dim` over the number of heads. - key_dim = int(hidden_dim // self.num_heads) - - # Relaive attention layers. - self._relative_attention_layer = ( - keras_nlp.layers.MultiHeadRelativeAttention( - num_heads=self.num_heads, - key_dim=key_dim, - dropout=self.dropout, - kernel_initializer=clone_initializer(self.kernel_initializer), - bias_initializer=clone_initializer(self.bias_initializer), - ) - ) - self._relative_attention_layer._build_from_signature( - query=input_shape, - value=input_shape, - content_attention_bias=self.bias_param1, - positional_attention_bias=self.bias_param2, - ) - - # Feedforward layers. 
- self._feedforward_layernorm = keras.layers.LayerNormalization( - epsilon=self.layer_norm_epsilon, + self.relative_attention = TwoStreamRelativeAttention(num_heads=self.num_heads, + key_dim=self.dim, + kernel_initializer=self.kernel_initializer, + bias_initializer=self.bias_initializer, + ) + self.layer_norm = keras.layers.LayerNormalization(epsilon=self.layer_norm_epsilon, + name="layer_norm") + self.dropout = keras.layers.Dropout(self.dropout) + + def build(self, input_shape): + self.content_attention_bias = self.add_weight( + shape=(self.num_heads, self.dim), + initializer=self.bias_initializer, + trainable=True, + name="content_attention_bias" ) - self._feedforward_intermediate_dense = keras.layers.Dense( - self.intermediate_dim, - activation=self.activation, - kernel_initializer=clone_initializer(self.kernel_initializer), - bias_initializer=clone_initializer(self.bias_initializer), + self.positional_attention_bias = self.add_weight( + shape=(self.num_heads, self.dim), + initializer=self.bias_initializer, + trainable=True, + name="positional_attention_bias" ) - self._feedforward_output_dense = keras.layers.Dense( - hidden_dim, - kernel_initializer=clone_initializer(self.kernel_initializer), - bias_initializer=clone_initializer(self.bias_initializer), + self.segment_attention_bias = self.add_weight( + shape=(self.num_heads, self.dim), + initializer=self.bias_initializer, + trainable=True, + name="segment_attention_bias" ) - self._feedforward_dropout = keras.layers.Dropout( - rate=self.dropout, + self.segment_encoding = self.add_weight( + shape=(2, self.num_heads, self.dim), + initializer=self.kernel_initializer, + trainable=True, + name="segment_encoding" ) + super().build(input_shape) + + def call(self, + output_h, + output_g, + attn_mask_h, + attn_mask_g, + pos_emb, + seg_mat=None, + mems=None, + target_mapping=None, + training=False, + ): + + attn_out_h, attn_out_g = self.relative_attention(content_stream=output_h, + query_stream=output_g, + content_attention_mask=attn_mask_h, + query_attention_mask=attn_mask_g, + relative_position_encoding=pos_emb, + content_attention_bias=self.content_attention_bias, + positional_attention_bias=self.positional_attention_bias, + segment_attention_bias=self.segment_attention_bias, + segment_matrix=seg_mat, + segment_encoding=self.segment_encoding, + target_mapping=target_mapping, + state=mems, + ) + attn_out_h = self.dropout(attn_out_h, training=training) + attn_out_h = attn_out_h + output_h + attn_out_h = self.layer_norm(attn_out_h) + + if output_g is not None: + attn_out_g = self.dropout(attn_out_g) + attn_out_g = attn_out_g + output_g + attn_out_g = self.layer_norm(attn_out_g) + + return attn_out_h, attn_out_g + + return attn_out_h, None - self.bias_param1 = tf.Variable( - shape=input_shape, - name="bias_param1", - initializer=tf.zeros_initializer(), - ) - self.bias_param2 = tf.Variable( - shape=input_shape, - name="bias_param2", - initializer=tf.zeros_initializer(), - ) - - def call(self, inputs, padding_mask=None, attention_mask=None): - """Forward pass of the XLNetEncoder. - - Args: - inputs: a Tensor. The input data to XLNetEncoder, should be - of shape [batch_size, sequence_length, hidden_dim]. - padding_mask: a boolean Tensor. It indicates if the token should be - masked because the token is introduced due to padding. - `padding_mask` should have shape [batch_size, sequence_length]. - attention_mask: a boolean Tensor. Customized mask used to mask out - certain tokens. 
`attention_mask` should have shape - [batch_size, sequence_length, sequence_length]. - - Returns: - A Tensor of the same shape as the `inputs`. - """ - - if not self._built: - self._build(inputs.shape) - x = inputs # Intermediate result. - - # Compute self attention mask. - relative_attention_mask = merge_padding_and_attention_mask( - inputs, padding_mask, attention_mask - ) - - # Self attention block. - residual = x - if self.normalize_first: - x = self._relative_attention_layernorm(x) - x = self._relative_attention_layer( - query=x, - value=x, - attention_mask=relative_attention_mask, - content_attention_bias=self.bias_param1, - positional_attention_bias=self.bias_param2, - ) - x = self._relative_attention_dropout(x) - x = x + residual - if not self.normalize_first: - x = self._relative_attention_layernorm(x) - - # Feedforward block. - residual = x - if self.normalize_first: - x = self._feedforward_layernorm(x) - x = self._feedforward_intermediate_dense(x) - x = self._feedforward_output_dense(x) - x = self._feedforward_dropout(x) - x = x + residual - if not self.normalize_first: - x = self._feedforward_layernorm(x) - - return x - - def get_config(self): - config = super().get_config() - config.update( - { - "intermediate_dim": self.intermediate_dim, - "num_heads": self.num_heads, - "dropout": self.dropout, - "activation": keras.activations.serialize(self.activation), - "layer_norm_epsilon": self.layer_norm_epsilon, - "kernel_initializer": keras.initializers.serialize( - self.kernel_initializer - ), - "bias_initializer": keras.initializers.serialize( - self.bias_initializer - ), - "normalize_first": self.normalize_first, - "build_input_shape": self._input_shape, - } - ) - return config From 1e5fdf92e5e8f55e4b4ff438e101597e64a0e57d Mon Sep 17 00:00:00 2001 From: Susnato Dhar Date: Wed, 14 Jun 2023 03:36:20 +0530 Subject: [PATCH 5/6] . 
--- keras_nlp/layers/relative_attention.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/keras_nlp/layers/relative_attention.py b/keras_nlp/layers/relative_attention.py index b30e9d0f7b..fca8d92daa 100644 --- a/keras_nlp/layers/relative_attention.py +++ b/keras_nlp/layers/relative_attention.py @@ -179,14 +179,6 @@ def _build_from_signature(self, query, value, key=None): ) self._build_attention(output_rank) - - - # self._output_dense = self._make_output_dense( - # free_dims, - # common_kwargs, - # "attention_output", - # ) - einsum_equation, _, output_rank = _build_proj_equation( free_dims, bound_dims=2, output_dims=1 ) From 948d325231fb6709602ced4df620e9ca4460fcb7 Mon Sep 17 00:00:00 2001 From: Susnato Dhar Date: Wed, 14 Jun 2023 14:02:48 +0530 Subject: [PATCH 6/6] outputs same now --- keras_nlp/layers/relative_attention.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/keras_nlp/layers/relative_attention.py b/keras_nlp/layers/relative_attention.py index fca8d92daa..8b41eb44b7 100644 --- a/keras_nlp/layers/relative_attention.py +++ b/keras_nlp/layers/relative_attention.py @@ -179,16 +179,17 @@ def _build_from_signature(self, query, value, key=None): ) self._build_attention(output_rank) - einsum_equation, _, output_rank = _build_proj_equation( + _, _, output_rank = _build_proj_equation( free_dims, bound_dims=2, output_dims=1 ) self._output_dense = keras.layers.EinsumDense( - einsum_equation, + "ibnd,hnd->ibh", output_shape=_get_output_shape(output_rank - 1, [self._query_shape[-1]]), bias_axes=None, name="attention_output", **common_kwargs, ) + einsum_equation, _, output_rank = _build_proj_equation( self._key_shape.rank - 1, bound_dims=1, output_dims=2 )
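For reference, a minimal smoke test of the reworked `XLNetEncoder` block,
assuming the whole series applies cleanly. The shapes, the `2 * seq_len`
relative position encoding length, and the content-stream-only call (no query
stream, masks, memory, or segment inputs) are illustrative choices, and `dim`
is taken to be the per-head size, so the hidden size is `num_heads * dim`.

```python
import tensorflow as tf

from keras_nlp.models.xlnet.xlnet_encoder import XLNetEncoder

batch_size, seq_len, num_heads, dim = 2, 8, 4, 16
hidden_dim = num_heads * dim

encoder = XLNetEncoder(intermediate_dim=128, num_heads=num_heads, dim=dim)

output_h = tf.random.uniform((batch_size, seq_len, hidden_dim))
pos_emb = tf.random.uniform((batch_size, 2 * seq_len, hidden_dim))

# Content stream only: the query stream, attention masks, memory, and
# segment inputs are all left unset.
attn_out_h, attn_out_g = encoder(
    output_h,
    output_g=None,
    attn_mask_h=None,
    attn_mask_g=None,
    pos_emb=pos_emb,
)
# attn_out_h -> (batch_size, seq_len, hidden_dim); attn_out_g is None.
```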