-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathTransformer
1 lines (1 loc) · 11.3 KB
/
Transformer
1
# %% [markdown]
# # Transformer building blocks (TensorFlow / Keras)
#
# Defines the positional embedding, attention, feed-forward, encoder and
# decoder layers, then builds a sample decoder on symbolic Keras inputs.

# %%
import numpy as np
import tensorflow as tf


# %% [markdown]
# ## Positional encoding and embedding

# %%
def positional_encoding(length, depth):
    """Build a sinusoidal positional-encoding table.

    The first half of the feature axis holds sines and the second half holds
    cosines (the two halves are concatenated, not interleaved).

    Args:
        length: Number of positions (rows) to generate.
        depth: Number of features (columns) per position.

    Returns:
        A `tf.float32` tensor of shape `(length, depth)`.
    """
    half_depth = depth / 2

    positions = np.arange(length)[:, np.newaxis]                    # (seq, 1)
    depths = np.arange(half_depth)[np.newaxis, :] / half_depth      # (1, depth/2)

    angle_rates = 1 / (10000 ** depths)                             # (1, depth/2)
    angle_rads = positions * angle_rates                            # (pos, depth/2)

    pos_encoding = np.concatenate(
        [np.sin(angle_rads), np.cos(angle_rads)],
        axis=-1)

    # For odd `depth`, np.arange(depth / 2) rounds up, so the concatenation is
    # one column too wide; trim so the table always has exactly `depth` columns.
    pos_encoding = pos_encoding[:, :int(depth)]

    return tf.cast(pos_encoding, dtype=tf.float32)


# %%
class PositionalEmbedding(tf.keras.layers.Layer):
    """Token embedding (with zero-masking) plus a fixed positional encoding.

    Args:
        vocab_size: Size of the token vocabulary passed to the embedding.
        d_model: Embedding width.
        max_length: Number of positions pre-computed in the encoding table.
            Defaults to 2048, the value that was previously hard-coded.
    """

    def __init__(self, vocab_size, d_model, max_length=2048):
        super().__init__()
        self.d_model = d_model
        self.max_length = max_length
        self.embedding = tf.keras.layers.Embedding(
            vocab_size, d_model, mask_zero=True)
        self.pos_encoding = positional_encoding(length=max_length, depth=d_model)

    def compute_mask(self, *args, **kwargs):
        """Delegate mask computation to the wrapped embedding layer."""
        return self.embedding.compute_mask(*args, **kwargs)

    def call(self, x):
        """Embed token IDs `x` and add the positional encoding.

        `x` is indexed on axis 1 for the sequence length, so it is expected to
        be at least 2-D (token IDs of shape `(batch, seq_len)`).
        """
        length = tf.shape(x)[1]
        x = self.embedding(x)
        # This factor sets the relative scale of the embedding and
        # positional encoding.
        x *= tf.math.sqrt(tf.cast(self.d_model, tf.float32))
        x = x + self.pos_encoding[tf.newaxis, :length, :]
        return x


# %% [markdown]
# ## Attention layers

# %%
class BaseAttention(tf.keras.layers.Layer):
    """Shared sub-layers for the attention variants below.

    All keyword arguments are forwarded to
    `tf.keras.layers.MultiHeadAttention`.
    """

    def __init__(self, **kwargs):
        super().__init__()
        self.mha = tf.keras.layers.MultiHeadAttention(**kwargs)
        self.layernorm = tf.keras.layers.LayerNormalization()
        self.add = tf.keras.layers.Add()


# %%
class CrossAttention(BaseAttention):
    """Attention from `x` (query) over `context` (key and value)."""

    def call(self, x, context):
        """Attend over `context`, then apply the residual add and layer norm."""
        attn_output, attn_scores = self.mha(
            query=x,
            key=context,
            value=context,
            return_attention_scores=True)

        # Cache the attention scores for plotting later.
        self.last_attn_scores = attn_scores

        x = self.add([x, attn_output])
        return self.layernorm(x)


# %%
class GlobalSelfAttention(BaseAttention):
    """Self-attention where `x` is used as query, key and value."""

    def call(self, x):
        """Self-attend, then apply the residual add and layer norm."""
        attn_output = self.mha(query=x, value=x, key=x)
        x = self.add([x, attn_output])
        return self.layernorm(x)


# %%
class CausalSelfAttention(BaseAttention):
    """Self-attention called with `use_causal_mask=True`."""

    def call(self, x):
        """Causally self-attend, then apply the residual add and layer norm."""
        attn_output = self.mha(
            query=x,
            value=x,
            key=x,
            use_causal_mask=True)
        x = self.add([x, attn_output])
        return self.layernorm(x)


# %% [markdown]
# ## Feed-forward network

# %%
class FeedForward(tf.keras.layers.Layer):
    """Dense(dff, relu) -> Dense(d_model) -> Dropout, with residual + layer norm.

    Args:
        d_model: Output width of the second dense layer.
        dff: Width of the hidden (ReLU) dense layer.
        dropout_rate: Rate for the dropout applied after the second dense layer.
    """

    def __init__(self, d_model, dff, dropout_rate=0.1):
        super().__init__()
        self.seq = tf.keras.Sequential([
            tf.keras.layers.Dense(dff, activation='relu'),
            tf.keras.layers.Dense(d_model),
            tf.keras.layers.Dropout(dropout_rate),
        ])
        self.add = tf.keras.layers.Add()
        self.layer_norm = tf.keras.layers.LayerNormalization()

    def call(self, x):
        """Apply the feed-forward stack, the residual add and the layer norm."""
        x = self.add([x, self.seq(x)])
        return self.layer_norm(x)


# %% [markdown]
# ## Encoder

# %%
class EncoderLayer(tf.keras.layers.Layer):
    """One encoder block: global self-attention followed by a feed-forward net."""

    def __init__(self, *, d_model, num_heads, dff, dropout_rate=0.1):
        super().__init__()

        self.self_attention = GlobalSelfAttention(
            num_heads=num_heads,
            key_dim=d_model,
            dropout=dropout_rate)

        # Forward the configured rate so the feed-forward dropout follows
        # `dropout_rate` instead of silently staying at FeedForward's default.
        self.ffn = FeedForward(d_model, dff, dropout_rate)

    def call(self, x):
        """Run self-attention and then the feed-forward network on `x`."""
        x = self.self_attention(x)
        return self.ffn(x)


# %%
class Encoder(tf.keras.layers.Layer):
    """Positional embedding, dropout, then a stack of `num_layers` encoder layers."""

    def __init__(self, *, num_layers, d_model, num_heads,
                 dff, vocab_size, dropout_rate=0.1):
        super().__init__()

        self.d_model = d_model
        self.num_layers = num_layers

        self.pos_embedding = PositionalEmbedding(
            vocab_size=vocab_size, d_model=d_model)

        self.enc_layers = [
            EncoderLayer(d_model=d_model,
                         num_heads=num_heads,
                         dff=dff,
                         dropout_rate=dropout_rate)
            for _ in range(num_layers)]
        self.dropout = tf.keras.layers.Dropout(dropout_rate)

    def call(self, x):
        """Encode token IDs.

        Args:
            x: Token IDs, shape `(batch, seq_len)`.

        Returns:
            Tensor of shape `(batch_size, seq_len, d_model)`.
        """
        x = self.pos_embedding(x)  # Shape `(batch_size, seq_len, d_model)`.

        # Add dropout.
        x = self.dropout(x)

        for enc_layer in self.enc_layers:
            x = enc_layer(x)

        return x  # Shape `(batch_size, seq_len, d_model)`.


# %% [markdown]
# ## Decoder

# %%
class DecoderLayer(tf.keras.layers.Layer):
    """One decoder block: causal self-attention, cross-attention, feed-forward."""

    def __init__(self,
                 *,
                 d_model,
                 num_heads,
                 dff,
                 dropout_rate=0.1):
        super().__init__()

        self.causal_self_attention = CausalSelfAttention(
            num_heads=num_heads,
            key_dim=d_model,
            dropout=dropout_rate)

        self.cross_attention = CrossAttention(
            num_heads=num_heads,
            key_dim=d_model,
            dropout=dropout_rate)

        # Forward the configured rate so the feed-forward dropout follows
        # `dropout_rate` instead of silently staying at FeedForward's default.
        self.ffn = FeedForward(d_model, dff, dropout_rate)

    def call(self, x, context):
        """Run the three sub-layers on `x`, attending over `context`."""
        x = self.causal_self_attention(x=x)
        x = self.cross_attention(x=x, context=context)

        # Cache the last attention scores for plotting later.
        self.last_attn_scores = self.cross_attention.last_attn_scores

        x = self.ffn(x)  # Shape `(batch_size, seq_len, d_model)`.
        return x


# %%
class Decoder(tf.keras.layers.Layer):
    """Positional embedding, dropout, then a stack of `num_layers` decoder layers."""

    def __init__(self, *, num_layers, d_model, num_heads, dff, vocab_size,
                 dropout_rate=0.1):
        super().__init__()

        self.d_model = d_model
        self.num_layers = num_layers

        self.pos_embedding = PositionalEmbedding(vocab_size=vocab_size,
                                                 d_model=d_model)
        self.dropout = tf.keras.layers.Dropout(dropout_rate)
        self.dec_layers = [
            DecoderLayer(d_model=d_model, num_heads=num_heads,
                         dff=dff, dropout_rate=dropout_rate)
            for _ in range(num_layers)]

        # Set by `call` to the attention scores of the final decoder layer.
        self.last_attn_scores = None

    def call(self, x, context):
        """Decode target token IDs against `context`.

        Args:
            x: Token IDs, shape `(batch, target_seq_len)`.
            context: Tensor used as key/value by every layer's cross-attention.

        Returns:
            Tensor of shape `(batch_size, target_seq_len, d_model)`.
        """
        x = self.pos_embedding(x)  # (batch_size, target_seq_len, d_model)

        x = self.dropout(x)

        for dec_layer in self.dec_layers:
            x = dec_layer(x, context)

        self.last_attn_scores = self.dec_layers[-1].last_attn_scores

        # The shape of x is (batch_size, target_seq_len, d_model).
        return x


# %% [markdown]
# ## Smoke test: build a decoder on symbolic inputs

# %%
# Context fed to cross-attention: one step of 256 features per example.
inputs1 = tf.keras.layers.Input(shape=(1, 256))
# Target token IDs: 34 tokens per example.
inputs2 = tf.keras.layers.Input(shape=(34,))

# %%
sample_decoder = Decoder(num_layers=4,
                         d_model=256,
                         num_heads=8,
                         dff=2048,
                         vocab_size=8485)

# %%
output = sample_decoder(
    x=inputs2,
    context=inputs1)

# The decoder must keep the target length and emit d_model features.
assert tuple(output.shape[1:]) == (34, 256)
output.shape