From ea472e2131f809aa61a03069955c031afe4576bb Mon Sep 17 00:00:00 2001 From: Piero Orderique Date: Thu, 13 Jul 2023 09:50:00 -0700 Subject: [PATCH 01/18] Add spec for multi-head attention --- .../src/layers/nlp/multihead_attention.ts | 444 ++++++++++++++++++ 1 file changed, 444 insertions(+) create mode 100644 tfjs-layers/src/layers/nlp/multihead_attention.ts diff --git a/tfjs-layers/src/layers/nlp/multihead_attention.ts b/tfjs-layers/src/layers/nlp/multihead_attention.ts new file mode 100644 index 00000000000..84720625617 --- /dev/null +++ b/tfjs-layers/src/layers/nlp/multihead_attention.ts @@ -0,0 +1,444 @@ +/** + * @license + * Copyright 2023 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +/** + * TFJS-based multi-head attention layer. + */ + +/* Original source: keras/layers/attention/multi_head_attention.py */ +import { Tensor, Tensor1D, Tensor2D, serialization } from '@tensorflow/tfjs-core'; + +import { ConstraintIdentifier } from '../../constraints'; +import { Layer, LayerArgs } from '../../engine/topology'; +import { NotImplementedError } from '../../errors'; +import { InitializerIdentifier } from '../../initializers'; +import { Shape } from '../../keras_format/common'; +import { RegularizerIdentifier } from '../../regularizers'; +import { Kwargs } from '../../types'; + +/** + * Builds einsum equations for the attention computation. + * + * Query, key, value inputs after projection are expected to have the shape as: + * `(bs, , , num_heads, channels)`. + * `bs` and `` are treated as ``. + + * The attention operations can be generalized: + * (1) Query-key dot product: + * `(, , num_heads, channels), (, + * , num_heads, channels) -> (, + * num_heads, , )` + * (2) Combination: + * `(, num_heads, , ), + * (, , num_heads, channels) -> (, , num_heads, channels)` + * + * @param rank Rank of query, key, value tensors. + * @param attnAxes Array of axes, `[-1, rank)`, + * that attention will be applied to. + * @returns Einsum equations. + */ +function buildAttentionEquation( + rank: number, attnAxes: Array +): [string, string, string] { + throw new NotImplementedError('Not implemented yet.'); +} + +/** + * Builds an einsum equation for projections inside multi-head attention. + */ +function buildProjectionEquation( + freeDims: number, boundDims: number, outputDims: number +): [string, string, number] { + throw new NotImplementedError('Not implemented yet.'); +} + +function getOutputShape( + outputRank: number, knownLastDims: Iterable +): Shape { + throw new NotImplementedError('Not implemented yet.'); +} + +export declare interface MultiHeadAttentionArgs extends LayerArgs { + /** + * Integer. Number of attention heads. + */ + numHeads: number, + + /** + * Integer. Size of each attention head for query and key. + */ + keyDim: number, + + /** + * Integer. Size of each attention head for value. + * Defaults to `keyDim`. + */ + valueDim?: number, + + /** + * Dropout probability. 
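+   * For example (an illustrative setting, not from the original spec),
+   * `dropout: 0.1` randomly zeroes 10% of the attention probabilities on
+   * each training step.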
+ * Defaults to 0.0. + */ + dropout?: number, + + /** + * Whether the dense layers use bias vectors/matrices. + * Defaults to true. + */ + useBias?: boolean, + + /** + * The expected shape of an output tensor, besides the batch + * and sequence dims. If not specified, projects back to the query + * feature dim (the query input's last dimension). + */ + outputShape?: Shape, + + /** + * Axes over which the attention is applied. `null` means attention over + * all axes, but batch, heads, and features. + */ + attentionAxes: Array, + + /** + * Initializer for dense layer kernels. + * Defaults to `"glorotUniform"`. + */ + kernelInitializer?: InitializerIdentifier, + + /** + * Initializer for dense layer biases. + * Defaults to `"zeros"`. + */ + biasInitializer?: InitializerIdentifier, + + /** + * Regularizer for dense layer kernels. + */ + kernelRegularizer?: RegularizerIdentifier, + + /** + * Regularizer for dense layer biases. + */ + biasRegularizer?: RegularizerIdentifier, + + /** + * Regularizer for dense layer activity. + */ + activityRegularizer?: RegularizerIdentifier, + + /** + * Constraint for dense layer kernels. + */ + kernelConstraint?: ConstraintIdentifier, + + /** + * Constraint for dense layer kernels. + */ + biasConstraint?: ConstraintIdentifier, +} + +export declare interface MultiHeadAttentionOptions { + /** + * Query `Tensor` of shape `(B, T, dim)`. + */ + + /** + * Value `Tensor` of shape `(B, S, dim)`. + */ + value: Tensor, + + /** + * Key `Tensor` of shape `(B, S, dim)`. If not given, will use `value` for + * both `key` and `value`, which is the most common case. + */ + key?: Tensor, + + /** + * A boolean mask of shape `(B, T, S)`, that prevents + * attention to certain positions. The boolean mask specifies which + * query elements can attend to which key elements, 1 indicates + * attention and 0 indicates no attention. Broadcasting can happen for + * the missing batch dimensions and the head dimension. + */ + attentionMask?: Tensor, + + /** + * Indicates whether the layer should behave in training mode + * (adding dropout) or in inference mode (no dropout). + * Will go with either using the training mode of the parent + * layer/model, or false (inference) if there is no parent layer. + */ + training?: boolean, + + /** + * Indicates whether to apply a causal mask to prevent tokens from attending + * to future tokens (e.g., used in a decoder Transformer). + * Defaults to false. + */ + useCausalMask?: boolean, +} + +/** + * MultiHeadAttention layer. + * + * This is an implementation of multi-headed attention as described in the + * paper "Attention is all you Need" (Vaswani et al., 2017). + * If `query`, `key,` `value` are the same, then + * this is self-attention. Each timestep in `query` attends to the + * corresponding sequence in `key`, and returns a fixed-width vector. + * + * This layer first projects `query`, `key` and `value`. These are + * (effectively) a list of tensors of length `numAttentionHeads`, where the + * corresponding shapes are `(batchSize, , keyDim)`, + * `(batchSize, , keyDim)`, + * `(batchSize, , valueDim)`. + * + * Then, the query and key tensors are dot-producted and scaled. These are + * softmaxed to obtain attention probabilities. The value tensors are then + * interpolated by these probabilities, then concatenated back to a single + * tensor. + * + * Finally, the result tensor with the last dimension as valueDim can take an + * linear projection and return. 
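+ *
+ * As an illustrative sketch of the shapes involved (not part of the API):
+ * with `numHeads: 2` and `keyDim: 3`, a query of shape `[batch, T, dim]` is
+ * projected to `[batch, T, 2, 3]`, the attention scores take shape
+ * `[batch, 2, T, S]`, and the combined heads are projected back to
+ * `[batch, T, dim]`.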
+ + * When using `MultiHeadAttention` inside a custom layer, the custom layer must + * implement its own `build()` method and call `MultiHeadAttention`'s + * `buildFromSignature()` there. + * This enables weights to be restored correctly when the model is loaded. + * + * Examples: + * + * Performs 1D cross-attention over two sequence inputs with an attention mask. + * Returns the additional attention weights over heads. + * + * const layer = new MultiHeadAttention({numHeads=2, keyDim=2}); + * const target = tf.keras.Input(shape=[8, 16]); + * const source = tf.keras.Input(shape=[4, 16]); + * const outputTensor, weights = layer.callAndReturnAttentionScores( + * target, {value: source}); + * console.log(outputTensor.shape); // [null, 8, 16] + * console.log(weights.shape); // [null, 2, 8, 4] + * + * Performs 2D self-attention over a 5D input tensor on axes 2 and 3. + * + * const layer = new MultiHeadAttention({ + * numHeads: 2, keyDim: 2, attentionAxes: [2, 3]}); + * const inputTensor = tf.input({shape: [5, 3, 4, 16]}); + * const outputTensor = layer.call(inputTensor, {value: inputTensor}); + * console.log(outputTensor.shape); // [null, 5, 3, 4, 16] + * + * Returns: + * attentionOutput: The result of the computation, of shape `(B, T, E)`, + * where `T` is for target sequence shapes and `E` is the query input + * last dimension if `outputShape` is `None`. Otherwise, the + * multi-head outputs are projected to the shape specified by + * `outputShape`. + * attentionScores: multi-head attention coefficients over attention axes. + */ +export class MultiHeadAttention extends Layer { + /** @nocollapse */ + static readonly className = 'MultiHeadAttention'; + + constructor(args: MultiHeadAttentionArgs) { + super(args); + throw new NotImplementedError('Not implemented yet.'); + } + + override getConfig(): serialization.ConfigDict { + throw new NotImplementedError('Not implemented yet.'); + } + + static override fromConfig( + cls: serialization.SerializableConstructor, + config: serialization.ConfigDict + ): T { + throw new NotImplementedError('Not implemented yet.'); + } + + /** + * Builds layers and variables. + * + * Once the method is called, this.builtFromSignature will be set to true. + */ + private buildFromSignature(query: Tensor, value: Tensor, key?: Tensor) { + throw new NotImplementedError( + `Not implemented yet. Uses ${buildProjectionEquation}, ${getOutputShape}, + ${this.getCommonKwargsForSublayer}, ${this.buildAttention}, + ${this.makeOutputDense}.`); + } + + private getCommonKwargsForSublayer(): Kwargs { + throw new NotImplementedError('Not implemented yet.'); + } + + /** + * Builds the output projection matrix. + * + * @param freeDims Number of free dimensions for einsum equation building. + * @param commonKwargs Common keyword arguments for einsum layer. + * @param name Name for the projection layer. + * @returns Projection layer. + */ + private makeOutputDense( + freeDims: number, commonKwargs: Kwargs, name?: string + ): Kwargs { + throw new NotImplementedError('Not implemented yet.'); + } + + /** + * Builds multi-head dot-product attention computations. + * + * This function builds attributes necessary for `_compute_attention` to + * customize attention computation to replace the default dot-product + * attention. + * + * @param rank The rank of query, key, value tensors. + */ + private buildAttention(rank: number) { + throw new NotImplementedError( + `Not implemented yet. 
Uses ${buildAttentionEquation}.`); + } + + private maskedSoftmax( + attentionScores: Tensor, attentionMask?: Tensor + ): Tensor|Tensor[] { + throw new NotImplementedError('Not implemented yet.'); + } + + /** + * Applies Dot-product attention with query, key, value tensors. + * + * This function defines the computation inside `call` with projected + * multi-head Q, K, V inputs. Users can override this function for + * customized attention implementation. + * + * @param query Projected query `Tensor` of shape `(B, T, N, key_dim)`. + * @param key Projected key `Tensor` of shape `(B, S, N, key_dim)`. + * @param value Projected value `Tensor` of shape `(B, S, N, value_dim)`. + * @param attentionMask A boolean mask of shape `(B, T, S)`, that prevents + * attention to certain positions. It is generally not needed if + * the `query` and `value` (and/or `key`) are masked. + * @param training Boolean indicating whether the layer should behave + * in training mode (adding dropout) or in inference mode (doing + * nothing). + * @returns attentionOutput: Multi-headed outputs of attention computation. + * @returns attentionScores: Multi-headed attention weights. + */ + private computeAttention( + query: Tensor, + key: Tensor, + value: Tensor, + attentionMask?: Tensor, + training?: boolean + ): [Tensor, Tensor] { + throw new NotImplementedError( + `Not implemented yet. Uses ${this.maskedSoftmax}.`); + } + + override call( + query: Tensor, kwargs: MultiHeadAttentionOptions + ): Tensor|Tensor2D { + return this.callAndReturnAttentionScores(query, kwargs)[0]; + } + + /** + * Exactly like `call` except also returns the attention scores. + */ + callAndReturnAttentionScores( + query: Tensor, kwargs: MultiHeadAttentionOptions + ): [Tensor1D|Tensor2D, Tensor1D|Tensor2D] { + throw new NotImplementedError( + `Not implemented yet. Uses ${this.buildFromSignature}, + ${this.computeAttentionMask}, ${this.computeAttention}.`); + } + + /** + * Computes the attention mask. + * + * * The `query`'s mask is reshaped from [B, T] to [B, T, 1]. + * * The `value`'s mask is reshaped from [B, S] to [B, 1, S]. + * * The `key`'s mask is reshaped from [B, S] to [B, 1, S]. The `key`'s + * mask is ignored if `key` is `None` or if `key is value`. + * * If `useCausalMask=true`, then the causal mask is computed. Its shape + * is [1, T, S]. + * + * All defined masks are merged using a logical AND operation (`&`). + * + * In general, if the `query` and `value` are masked, then there is no need + * to define the `attentionMask`. + * + * @param query Projected query `Tensor` of shape `(B, T, N, key_dim)`. + * @param key Projected key `Tensor` of shape `(B, S, N, key_dim)`. + * @param value Projected value `Tensor` of shape `(B, S, N, value_dim)`. + * @param attentionMask A boolean mask of shape `(B, T, S)`, that prevents + * attention to certain positions. + * @param useCausalMask A boolean to indicate whether to apply a causal + * mask to prevent tokens from attending to future tokens (e.g., + * used in a decoder Transformer). + * @returns attentionMask: A boolean mask of shape `(B, T, S)`, that prevents + * attention to certain positions, based on the Keras masks of the + * `query`, `key`, `value`, and `attentionMask` tensors, and the + * causal mask if `useCausalMask=true`. + */ + private computeAttentionMask( + query: Tensor, + value: Tensor, + key?: Tensor, + attentionMask?: Tensor, + useCausalMask?: boolean + ): Tensor { + throw new NotImplementedError( + `Not implemented yet. 
Uses ${this.computeCasualMask}`); + } + + /** + * Computes a causal mask (e.g., for masked self-attention layers). + * + * For example, if query and value both contain sequences of length 4, + * this function returns a boolean `Tensor` equal to: + * + * ``` + * [[[true, false, false, false], + * [true, true, false, false], + * [true, true, true, false], + * [true, true, true, true]]] + * ``` + * + * @param query query `Tensor` of shape `(B, T, ...)`. + * @param value value `Tensor` of shape `(B, S, ...)` (defaults to query). + * @returns mask: A boolean `Tensor` of shape [1, T, S] containing a lower + * triangular matrix of shape [T, S]. + */ + private computeCasualMask(query: Tensor, value?: Tensor): Tensor { + throw new NotImplementedError('Not implemented yet.'); + } + + /** + * + * @param inputShapes A list of [queryShape, valueShape] or + * [queryShape, valueShape, keyShape]. If no keyShape provided, valueShape + * is assumed as the keyShape. + */ + override computeOutputShape( + inputShapes: [Shape, Shape] | [Shape, Shape, Shape] + ): Shape { + throw new NotImplementedError('Not implemented yet.'); + } +} +serialization.registerClass(MultiHeadAttention); From 41a105e6a7948bf90a22c4a6d11130f3be5fcbf9 Mon Sep 17 00:00:00 2001 From: Piero Orderique Date: Thu, 13 Jul 2023 14:13:40 -0700 Subject: [PATCH 02/18] Add CachedMultiHeadAttention cache --- .../modeling/cached_multihead_attention.ts | 122 ++++++++++++++++++ 1 file changed, 122 insertions(+) create mode 100644 tfjs-layers/src/layers/nlp/modeling/cached_multihead_attention.ts diff --git a/tfjs-layers/src/layers/nlp/modeling/cached_multihead_attention.ts b/tfjs-layers/src/layers/nlp/modeling/cached_multihead_attention.ts new file mode 100644 index 00000000000..27fedbbc6a1 --- /dev/null +++ b/tfjs-layers/src/layers/nlp/modeling/cached_multihead_attention.ts @@ -0,0 +1,122 @@ +/** + * @license + * Copyright 2023 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +/** + * Cached MHA layer based on `MultiHeadAttention`. + */ + +/* Original source: keras_nlp/layers/modeling/cached_multi_head_attention.py */ +import { Tensor, Tensor1D, Tensor2D, serialization } from '@tensorflow/tfjs-core'; + +import { MultiHeadAttention } from '../multihead_attention'; +import { NotImplementedError } from '../../../errors'; + +export declare interface CachedMultiHeadAttentionOptions { + /** + * Query `Tensor` of shape `(B, T, dim)`. + */ + + /** + * Value `Tensor` of shape `(B, S*, dim)`. If `cache` is `null`, `S*` + * must equal `S` and match the shape of `attentionMask`. If `cache` is + * not `null`, `S*` can be any length less than `S`, and the computed + * value will be spliced into `cache` at `cacheUpdateIndex`. + */ + value: Tensor, + + /** + * Key `Tensor` of shape `(B, S*, dim)`. If `cache` is `null`, `S*` must + * equal `S` and match the shape of `attentionMask`. 
If `cache` is not `null`,
+   * `S*` can be any length less than `S`, and the computed value will be
+   * spliced into `cache` at `cacheUpdateIndex`.
+   */
+  key?: Tensor,
+
+  /**
+   * A boolean mask of shape `(B, T, S)`. `attentionMask` prevents
+   * attention to certain positions. The boolean mask specifies which
+   * query elements can attend to which key elements, 1 indicates
+   * attention and 0 indicates no attention. Broadcasting can happen for
+   * the missing batch dimensions and the head dimension.
+   */
+  attentionMask?: Tensor,
+
+  /**
+   * A dense float Tensor. The key/value cache, of shape
+   * `[B, 2, S, num_heads, key_dims]`, where `S` must agree with the
+   * `attentionMask` shape. This argument is intended for use during
+   * generation to avoid recomputing intermediate state.
+   */
+  cache?: Tensor
+
+  /**
+   * Integer or Integer `Tensor`. The index at which to update `cache`
+   * (usually the index of the current token being processed when running
+   * generation). If `cacheUpdateIndex=null` while `cache` is set, the cache
+   * will not be updated.
+   */
+  cacheUpdateIndex?: number|Tensor
+}
+
+/**
+ * MultiHeadAttention layer with cache support.
+ *
+ * This layer is suitable for use in autoregressive decoding. It can be used
+ * to cache decoder self-attention and cross-attention. The forward pass
+ * can happen in one of three modes:
+ * - No cache, same as regular multi-head attention.
+ * - Static cache (`cache_update_index` is None). In this case, the
+ *   cached key/value projections will be used and the input values will
+ *   be ignored.
+ * - Updated cache (`cache_update_index` is not None). In this case, new
+ *   key/value projections are computed using the input, and spliced into
+ *   the cache at the specified index.
+ *
+ * Note that caching is useful only during inference and should not be used
+ * during training.
+ *
+ * We use the notation `B`, `T`, `S` below, where `B` is the batch dimension,
+ * `T` is the target sequence length, and `S` is the source sequence length.
+ * Note that during generative decoding, `T` is usually 1 (you are
+ * generating a target sequence of length one to predict the next token).
+ *
+ * Returns:
+ *   An `(attentionOutput, cache)` tuple. `attentionOutput` is the result
+ *   of the computation, of shape `(B, T, dim)`, where `T` is for target
+ *   sequence shapes and `dim` is the query input last dimension if
+ *   `outputShape` is `null`. Otherwise, the multi-head outputs are
+ *   projected to the shape specified by `outputShape`. `cache` is the
+ *   updated cache.
+ */
+export class CachedMultiHeadAttention extends MultiHeadAttention {
+
+  override call(
+    query: Tensor, kwargs: CachedMultiHeadAttentionOptions
+  ): Tensor|Tensor2D {
+    return this.callAndReturnCache(query, kwargs)[0];
+  }
+
+  /**
+   * Exactly like `call` except also returns the updated cache.
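+   *
+   * A minimal decoding-loop sketch (illustrative only; assumes the layer is
+   * already built and that `step`, `batchSize`, `maxLen`, `numHeads`, and
+   * `keyDim` are defined by the caller):
+   *
+   * ```js
+   * let cache = tf.zeros([batchSize, 2, maxLen, numHeads, keyDim]);
+   * for (let i = 0; i < maxLen; i++) {
+   *   const [output, nextCache] = layer.callAndReturnCache(step, {
+   *     value: step, cache, cacheUpdateIndex: i});
+   *   cache = nextCache;
+   * }
+   * ```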
+ */ + callAndReturnCache( + query: Tensor, kwargs: CachedMultiHeadAttentionOptions + ): [Tensor1D|Tensor2D, Tensor1D|Tensor2D] { + throw new NotImplementedError(`Not implemented yet.`); + } +} +serialization.registerClass(CachedMultiHeadAttention); From 6e78ffcffd0f442660934a7cc06eadba06b9c415 Mon Sep 17 00:00:00 2001 From: Piero Orderique Date: Thu, 13 Jul 2023 16:11:14 -0700 Subject: [PATCH 03/18] Fix typos --- .../src/layers/nlp/modeling/cached_multihead_attention.ts | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tfjs-layers/src/layers/nlp/modeling/cached_multihead_attention.ts b/tfjs-layers/src/layers/nlp/modeling/cached_multihead_attention.ts index 27fedbbc6a1..f28c7aaa7ce 100644 --- a/tfjs-layers/src/layers/nlp/modeling/cached_multihead_attention.ts +++ b/tfjs-layers/src/layers/nlp/modeling/cached_multihead_attention.ts @@ -57,7 +57,7 @@ export declare interface CachedMultiHeadAttentionOptions { /** * A dense float Tensor. The key/value cache, of shape - * `[B, 2, S, num_heads, key_dims]`, where `S` must agree with the + * `[B, 2, S, numHeads, keyDims]`, where `S` must agree with the * `attentionMask` shape. This argument is intended for use during * generation to avoid recomputing intermediate state. */ @@ -79,10 +79,10 @@ export declare interface CachedMultiHeadAttentionOptions { * to cache decoder self-attention and cross-attention. The forward pass * can happen in one of three modes: * - No cache, same as regular multi-head attention. - * - Static cache (`cache_update_index` is None). In this case, the + * - Static cache (`cacheUpdateIndex` is None). In this case, the * cached key/value projections will be used and the input values will * be ignored. - * - Updated cache (`cache_update_index` is not None). In this case, new + * - Updated cache (`cacheUpdateIndex` is not None). In this case, new * key/value projections are computed using the input, and spliced into * the cache at the specified index. * From 01d9e2e64ed28ef9906389a72d70aa3e34828017 Mon Sep 17 00:00:00 2001 From: Piero Orderique Date: Thu, 13 Jul 2023 16:13:21 -0700 Subject: [PATCH 04/18] Lint --- .../layers/nlp/modeling/cached_multihead_attention.ts | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tfjs-layers/src/layers/nlp/modeling/cached_multihead_attention.ts b/tfjs-layers/src/layers/nlp/modeling/cached_multihead_attention.ts index f28c7aaa7ce..19a0d8890e2 100644 --- a/tfjs-layers/src/layers/nlp/modeling/cached_multihead_attention.ts +++ b/tfjs-layers/src/layers/nlp/modeling/cached_multihead_attention.ts @@ -36,7 +36,7 @@ export declare interface CachedMultiHeadAttentionOptions { * not `null`, `S*` can be any length less than `S`, and the computed * value will be spliced into `cache` at `cacheUpdateIndex`. */ - value: Tensor, + value: Tensor; /** * Key `Tensor` of shape `(B, S*, dim)`. If `cache` is `null`, `S*` must @@ -44,7 +44,7 @@ export declare interface CachedMultiHeadAttentionOptions { * `S*` can be any length less than `S`, and the computed value will be * spliced into `cache` at `cacheUpdateIndex`. */ - key?: Tensor, + key?: Tensor; /** * A boolean mask of shape `(B, T, S)`. `attentionMask` prevents @@ -53,7 +53,7 @@ export declare interface CachedMultiHeadAttentionOptions { * attention and 0 indicates no attention. Broadcasting can happen for * the missing batch dimensions and the head dimension. */ - attentionMask?: Tensor, + attentionMask?: Tensor; /** * A dense float Tensor. 
The key/value cache, of shape @@ -61,7 +61,7 @@ export declare interface CachedMultiHeadAttentionOptions { * `attentionMask` shape. This argument is intended for use during * generation to avoid recomputing intermediate state. */ - cache?: Tensor + cache?: Tensor; /** * Integer or Integer `Tensor`. The index at which to update `cache` @@ -69,7 +69,7 @@ export declare interface CachedMultiHeadAttentionOptions { * generation). If `cacheUpdateIndex=null` while `cache` is set, the cache * will not be updated. */ - cacheUpdateIndex?: number|Tensor + cacheUpdateIndex?: number|Tensor; } /** From 8f08c19c67d8b4ee48cd83bbac85afc25c027fd1 Mon Sep 17 00:00:00 2001 From: Piero Orderique Date: Thu, 13 Jul 2023 16:57:58 -0700 Subject: [PATCH 05/18] Add Transformer Decoder spec --- .../nlp/modeling/transformer_decoder.ts | 265 ++++++++++++++++++ 1 file changed, 265 insertions(+) create mode 100644 tfjs-layers/src/layers/nlp/modeling/transformer_decoder.ts diff --git a/tfjs-layers/src/layers/nlp/modeling/transformer_decoder.ts b/tfjs-layers/src/layers/nlp/modeling/transformer_decoder.ts new file mode 100644 index 00000000000..36304db5866 --- /dev/null +++ b/tfjs-layers/src/layers/nlp/modeling/transformer_decoder.ts @@ -0,0 +1,265 @@ +/** + * @license + * Copyright 2023 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +/** + * Transformer decoder block implementation based on TFJS `Layer`. + */ + +/* Original source: keras_nlp/layers/modeling/transformer_decoder.py */ +import { Tensor, Tensor1D, Tensor2D, serialization } from '@tensorflow/tfjs-core'; + +import { Layer, LayerArgs, } from '../../../engine/topology'; +import { NotImplementedError } from '../../../errors'; +import { InitializerIdentifier } from '../../../initializers'; +import { ActivationIdentifier } from '../../../keras_format/activation_config'; +import { Shape } from '../../../keras_format/common'; + +export declare interface TransformerDecoderArgs extends LayerArgs { + /** + * Integer. The hidden size of feedforward network. + */ + intermediateDim: number; + + /** + * Integer. The number of heads in MultiHeadAttention. + */ + numHeads: number; + + /** + * The dropout value, shared by MultiHeadAttention and feedforward network. + * Defaults to `0.`. + */ + dropout?: number; + + /** + * The activation function of feedforward network. + * Defaults to `"relu"`. + */ + activation?: ActivationIdentifier + + /** + * The eps value in layer normalization components. + * Defaults to `1e-5`. + */ + layerNormEpsilon?: number; + + /** + * The kernel initializer for the dense and multiheaded attention layers. + * Defaults to `"glorotUniform"`. + */ + kernelInitializer?: InitializerIdentifier; + + /** + * The bias initializer for the dense and multiheaded attention layers. + * Defaults to `"zeros"`. 
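+   * For example (illustrative), `biasInitializer: 'randomNormal'` would
+   * instead draw the initial biases from a normal distribution.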
+ */ + biasInitializer?: InitializerIdentifier; + + /** + * If true, the inputs to the attention layer(s) and the intermediate dense + * layer are normalized (similar to GPT-2). If set to false, outputs of + * attention layer and intermediate dense layer are normalized + * (similar to BERT). + * Defaults to `false`. + */ + normalizeFirst: boolean; +} + +export declare interface TransformerDecoderOptions { + /** + * decoderSequence: The decode input sequence. + */ + + /** + * The encoder input sequence. For decoder only models (like GPT2), this + * should be left `null`. Once the model is called without an encoderSequence, + * you cannot call it again with encoderSequence. + */ + encoderSequence?: Tensor; + + /** + * A boolean Tensor, the padding mask of decoder sequence, must be of shape + * `[batchSize, decoderSequenceLength]`. + */ + decoderPaddingMask: Tensor; + + /** + * A boolean Tensor. Customized decoder sequence mask, must be of shape + * `[batchSize, decoderSequenceLength, decoderSequenceLength]`. + */ + decoderAttentionMask?: Tensor; + + /** + * A boolean Tensor, the padding mask of encoder sequence, must be of shape + * `[batchSize, encoderSequenceLength]`. + */ + encoderPaddingMask?: Tensor + + /** + * A boolean Tensor. Customized encoder sequence mask, must be of shape + `[batchSize, encoderSequenceLength, encoderSequenceLength]`. + */ + encoderAttentionMask?: Tensor; + + /** + * A dense float Tensor. The cache of key/values pairs in the self-attention + * layer. Has shape `[batchSize, 2, maxSeqLen, numHeads, keyDims]`. + */ + selfAttentionCache?: Tensor; + + /** + * Integer or Integer Tensor. The index at which to update the + * `selfAttentionCache`. Usually, this is the index of the current token + * being processed during decoding. + */ + selfAttentionCacheUpdateIndex?: number|Tensor; + + /** + * A dense float Tensor. The cache of key/value pairs in the cross-attention + * layer. Has shape `[batchSize, 2, S, numHeads, keyDims]`. + */ + crossAttentionCache?: Tensor; + + /** + * Integer or Integer Tensor. The index at which to update the + * `crossAttentionCache`. Usually, this is either `0` (compute the entire + * `crossAttentionCache`), or `null` (reuse a previously computed + * `crossAttentionCache`). + */ + crossAttentionCacheUpdateIndex?: number|Tensor; + + /** + * If true, a causal mask (masking out future input) is applied on the decoder + * sequence. + * Defaults to `true`. + */ + useCausalMask?: boolean; +} + +/** + * Transformer decoder. + * + * This class follows the architecture of the transformer decoder layer in the + * paper [Attention is All You Need](https://arxiv.org/abs/1706.03762). Users + * can instantiate multiple instances of this class to stack up a decoder. + * + * By default, this layer will apply a causal mask to the decoder attention + * layer. This layer will correctly compute an attention mask from an implicit + * padding mask (for example, by passing `maskZero=true` to a + * `tf.layers.embedding` layer). See the Masking and Padding + * [guide](https://keras.io/guides/understanding_masking_and_padding/) + * for more details. + * + * This layer can be called with either one or two inputs. The number of inputs + * must be consistent across all calls. The options are as follows: + * `layer.call(decoderSequence)`: no cross-attention will be built into the + * decoder block. This is useful when building a "decoder-only" + * transformer such as GPT-2. 
+ * `layer.call(decoderSequence, {encoderSequence})`: cross-attention will be
+ *   built into the decoder block. This is useful when building an
+ *   "encoder-decoder" transformer, such as the original transformer
+ *   model described in Attention is All You Need.
+ *
+ * Examples:
+ * ```js
+ * // Create a single transformer decoder layer.
+ * const decoder = new TransformerDecoder({intermediateDim: 64, numHeads: 8});
+ *
+ * // Create a simple model containing the decoder.
+ * const decoderInput = tf.input({shape: [10, 64]});
+ * const encoderInput = tf.input({shape: [10, 64]});
+ * const output = decoder.call(decoderInput, {encoderSequence: encoderInput});
+ * const model = tf.model({
+ *   inputs: [decoderInput, encoderInput],
+ *   outputs: output,
+ * });
+ *
+ * // Call decoder on the inputs.
+ * const decoderInputData = tf.randomUniform([2, 10, 64]);
+ * const encoderInputData = tf.randomUniform([2, 10, 64]);
+ * const decoderOutput = model.predict([decoderInputData, encoderInputData]);
+ * ```
+ *
+ * References:
+ *   - [Vaswani et al., 2017](https://arxiv.org/abs/1706.03762)
+ */
+export class TransformerDecoder extends Layer {
+  /** @nocollapse */
+  static readonly className = 'TransformerDecoder';
+
+  constructor(args: TransformerDecoderArgs) {
+    super(args);
+    throw new NotImplementedError(`Not implemented yet.`);
+  }
+
+  /**
+   *
+   * @param inputShape decoderSequenceShape or
+   *   [decoderSequenceShape, encoderSequenceShape]
+   */
+  override build(inputShape: Shape|[Shape, Shape]): void {
+    throw new NotImplementedError(`Not implemented yet.`);
+  }
+
+  override apply(
+    inputs: Tensor|Tensor[], kwargs?: TransformerDecoderOptions
+  ): Tensor | Tensor[] {
+    throw new NotImplementedError(`Not implemented yet.`);
+  }
+
+  override call(
+    decoderSequence: Tensor, kwargs: TransformerDecoderOptions
+  ): Tensor|Tensor[] {
+    return this.callAndReturnCaches(decoderSequence, kwargs)[0];
+  }
+
+  /**
+   * @returns One of three things, depending on call arguments:
+   *   - `[outputs, null, null]`, if `selfAttentionCache` is `null`.
+   *   - `[outputs, selfAttentionCache, null]`, if `selfAttentionCache` is
+   *     set and the layer has no cross-attention.
+   *   - `[outputs, selfAttentionCache, crossAttentionCache]`, if
+   *     `selfAttentionCache` and `crossAttentionCache` are set and
+   *     the layer has cross-attention.
+   */
+  callAndReturnCaches(
+    decoderSequence: Tensor, kwargs: TransformerDecoderOptions
+  ): [Tensor1D|Tensor2D, Tensor1D|Tensor2D, Tensor1D|Tensor2D] {
+    throw new NotImplementedError(
+      `Not implemented yet.
Uses ${this.computeSelfAttentionMask}`); + } + + private computeSelfAttentionMask( + decoderSequence: Tensor, + decoderPaddingMask: Tensor, + decoderAttentionMask: Tensor, + useCasualMask: boolean, + selfAttentionCache: Tensor, + selfAttentionCacheUpdateIndex: number|Tensor + ): Tensor { + throw new NotImplementedError(`Not implemented yet.`); + } + + override getConfig(): serialization.ConfigDict { + throw new NotImplementedError(`Not implemented yet.`); + } + + override computeOutputShape(decoderSequenceShape: Shape): Shape { + throw new NotImplementedError(`Not implemented yet.`); + } +} +serialization.registerClass(TransformerDecoder); From 4713c4ef02a0bc0629b7187ad625caa08bbfd508 Mon Sep 17 00:00:00 2001 From: Piero Orderique Date: Fri, 14 Jul 2023 00:03:50 +0000 Subject: [PATCH 06/18] lint --- .../nlp/modeling/transformer_decoder.ts | 6 +-- .../src/layers/nlp/multihead_attention.ts | 44 +++++++++---------- 2 files changed, 25 insertions(+), 25 deletions(-) diff --git a/tfjs-layers/src/layers/nlp/modeling/transformer_decoder.ts b/tfjs-layers/src/layers/nlp/modeling/transformer_decoder.ts index 36304db5866..5582b372fcf 100644 --- a/tfjs-layers/src/layers/nlp/modeling/transformer_decoder.ts +++ b/tfjs-layers/src/layers/nlp/modeling/transformer_decoder.ts @@ -49,7 +49,7 @@ export declare interface TransformerDecoderArgs extends LayerArgs { * The activation function of feedforward network. * Defaults to `"relu"`. */ - activation?: ActivationIdentifier + activation?: ActivationIdentifier; /** * The eps value in layer normalization components. @@ -107,11 +107,11 @@ export declare interface TransformerDecoderOptions { * A boolean Tensor, the padding mask of encoder sequence, must be of shape * `[batchSize, encoderSequenceLength]`. */ - encoderPaddingMask?: Tensor + encoderPaddingMask?: Tensor; /** * A boolean Tensor. Customized encoder sequence mask, must be of shape - `[batchSize, encoderSequenceLength, encoderSequenceLength]`. + * `[batchSize, encoderSequenceLength, encoderSequenceLength]`. */ encoderAttentionMask?: Tensor; diff --git a/tfjs-layers/src/layers/nlp/multihead_attention.ts b/tfjs-layers/src/layers/nlp/multihead_attention.ts index 84720625617..628b4b0c49c 100644 --- a/tfjs-layers/src/layers/nlp/multihead_attention.ts +++ b/tfjs-layers/src/layers/nlp/multihead_attention.ts @@ -36,7 +36,7 @@ import { Kwargs } from '../../types'; * Query, key, value inputs after projection are expected to have the shape as: * `(bs, , , num_heads, channels)`. * `bs` and `` are treated as ``. - + * * The attention operations can be generalized: * (1) Query-key dot product: * `(, , num_heads, channels), (, @@ -53,7 +53,7 @@ import { Kwargs } from '../../types'; * @returns Einsum equations. */ function buildAttentionEquation( - rank: number, attnAxes: Array + rank: number, attnAxes: number[] ): [string, string, string] { throw new NotImplementedError('Not implemented yet.'); } @@ -77,80 +77,80 @@ export declare interface MultiHeadAttentionArgs extends LayerArgs { /** * Integer. Number of attention heads. */ - numHeads: number, + numHeads: number; /** * Integer. Size of each attention head for query and key. */ - keyDim: number, + keyDim: number; /** * Integer. Size of each attention head for value. * Defaults to `keyDim`. */ - valueDim?: number, + valueDim?: number; /** * Dropout probability. * Defaults to 0.0. */ - dropout?: number, + dropout?: number; /** * Whether the dense layers use bias vectors/matrices. * Defaults to true. 
*/ - useBias?: boolean, + useBias?: boolean; /** * The expected shape of an output tensor, besides the batch * and sequence dims. If not specified, projects back to the query * feature dim (the query input's last dimension). */ - outputShape?: Shape, + outputShape?: Shape; /** * Axes over which the attention is applied. `null` means attention over * all axes, but batch, heads, and features. */ - attentionAxes: Array, + attentionAxes: number[]; /** * Initializer for dense layer kernels. * Defaults to `"glorotUniform"`. */ - kernelInitializer?: InitializerIdentifier, + kernelInitializer?: InitializerIdentifier; /** * Initializer for dense layer biases. * Defaults to `"zeros"`. */ - biasInitializer?: InitializerIdentifier, + biasInitializer?: InitializerIdentifier; /** * Regularizer for dense layer kernels. */ - kernelRegularizer?: RegularizerIdentifier, + kernelRegularizer?: RegularizerIdentifier; /** * Regularizer for dense layer biases. */ - biasRegularizer?: RegularizerIdentifier, + biasRegularizer?: RegularizerIdentifier; /** * Regularizer for dense layer activity. */ - activityRegularizer?: RegularizerIdentifier, + activityRegularizer?: RegularizerIdentifier; /** * Constraint for dense layer kernels. */ - kernelConstraint?: ConstraintIdentifier, + kernelConstraint?: ConstraintIdentifier; /** * Constraint for dense layer kernels. */ - biasConstraint?: ConstraintIdentifier, + biasConstraint?: ConstraintIdentifier; } export declare interface MultiHeadAttentionOptions { @@ -161,13 +161,13 @@ export declare interface MultiHeadAttentionOptions { /** * Value `Tensor` of shape `(B, S, dim)`. */ - value: Tensor, + value: Tensor; /** * Key `Tensor` of shape `(B, S, dim)`. If not given, will use `value` for * both `key` and `value`, which is the most common case. */ - key?: Tensor, + key?: Tensor; /** * A boolean mask of shape `(B, T, S)`, that prevents @@ -176,7 +176,7 @@ export declare interface MultiHeadAttentionOptions { * attention and 0 indicates no attention. Broadcasting can happen for * the missing batch dimensions and the head dimension. */ - attentionMask?: Tensor, + attentionMask?: Tensor; /** * Indicates whether the layer should behave in training mode @@ -184,14 +184,14 @@ export declare interface MultiHeadAttentionOptions { * Will go with either using the training mode of the parent * layer/model, or false (inference) if there is no parent layer. */ - training?: boolean, + training?: boolean; /** * Indicates whether to apply a causal mask to prevent tokens from attending * to future tokens (e.g., used in a decoder Transformer). * Defaults to false. */ - useCausalMask?: boolean, + useCausalMask?: boolean; } /** @@ -216,7 +216,7 @@ export declare interface MultiHeadAttentionOptions { * * Finally, the result tensor with the last dimension as valueDim can take an * linear projection and return. - + * * When using `MultiHeadAttention` inside a custom layer, the custom layer must * implement its own `build()` method and call `MultiHeadAttention`'s * `buildFromSignature()` there. 
From 37aca1a9f9706208f53210718963509d7545ff6d Mon Sep 17 00:00:00 2001 From: Piero Orderique Date: Thu, 13 Jul 2023 18:01:24 -0700 Subject: [PATCH 07/18] Add Einsum spec --- tfjs-layers/src/layers/nlp/einsum_dense.ts | 218 +++++++++++++++++++++ 1 file changed, 218 insertions(+) create mode 100644 tfjs-layers/src/layers/nlp/einsum_dense.ts diff --git a/tfjs-layers/src/layers/nlp/einsum_dense.ts b/tfjs-layers/src/layers/nlp/einsum_dense.ts new file mode 100644 index 00000000000..05b367f610a --- /dev/null +++ b/tfjs-layers/src/layers/nlp/einsum_dense.ts @@ -0,0 +1,218 @@ +/** + * @license + * Copyright 2023 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +/** + * TFJS-based einsum dense layer. + */ + +/* Original source: keras/layers/core/einsum_dense.py */ +import { Tensor, Tensor2D, serialization } from '@tensorflow/tfjs-core'; + +import { ConstraintIdentifier } from '../../constraints'; +import { Layer, LayerArgs } from '../../engine/topology'; +import { NotImplementedError } from '../../errors'; +import { InitializerIdentifier } from '../../initializers'; +import { ActivationIdentifier } from '../../keras_format/activation_config'; +import { Shape } from '../../keras_format/common'; +import { RegularizerIdentifier } from '../../regularizers'; +import { Kwargs } from 'tfjs-layers/src/types'; + +export declare interface EinsumDenseArgs extends LayerArgs { + /** + * An equation describing the einsum to perform. This equation must be a + * valid einsum string of the form `ab,bc->ac`, `...ab,bc->...ac`, or + * `ab...,bc->ac...` where 'ab', 'bc', and 'ac' can be any valid einsum + * axis expression sequence. + */ + equation: string; + + /** + * The expected shape of the output tensor (excluding the batch dimension and + * any dimensions represented by ellipses). You can specify None for any + * dimension that is unknown or can be inferred from the input shape. + */ + outputShape: Shape; + + /** + * Activation function to use. If you don't specify anything, no activation + * is applied (that is, a "linear" activation: `a(x) = x`). + */ + activation?: ActivationIdentifier; + + /** + * A string containing the output dimension(s) to apply a bias to. Each + * character in the `biasAxes` string should correspond to a character + * in the output portion of the `equation` string. + */ + biasAxes?: string; + + /** + * Initializer for the `kernel` weights matrix. + * Defaults to `"glorotUniform"`. + */ + kernelInitializer?: InitializerIdentifier; + + /** + * Initializer for the bias vector. + * Defaults to `"zeros"`. + */ + biasInitializer?: InitializerIdentifier; + + /** + * Regularizer function applied to the `kernel` weights matrix. + */ + kernelRegularizer?: RegularizerIdentifier; + + /** + * Regularizer function applied to the bias vector. + */ + biasRegularizer?: RegularizerIdentifier; + + /** + * Regularizer function applied to the output of the layer (its "activation"). 
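+   * For example (illustrative), `activityRegularizer: 'l1l2'` adds a
+   * combined L1/L2 penalty on the layer's outputs during training.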
+ */ + activityRegularizer?: RegularizerIdentifier; + + /** + * Constraint function applied to the `kernel` weights matrix. + */ + kernelConstraint?: ConstraintIdentifier; + + /** + * Constraint function applied to the bias vector. + */ + biasConstraint?: ConstraintIdentifier; +} + +export declare interface EinsumDenseOptions { + /** + * Pass to override the configured `sequenceLength` of the layer. + */ + sequenceLength?: number; + + /** + * Pass `false` to not append a start value for this input. + * Defaults to true. + */ + addStartValue?: boolean; + + /** + * Pass `false` to not append an end value for this input. + * Defaults to true. + */ + addEndValue?: boolean; +} + +/** + * A layer that uses `tf.einsum` as the backing computation. + * + * This layer can perform einsum calculations of arbitrary dimensionality. + * + * Examples: + * + * **Biased dense layer with einsums** + * + * This example shows how to instantiate a standard Keras dense layer using + * einsum operations. This example is equivalent to + * tf.layers.Dense({units: 64, useBias: true})`. + * + * const layer = new EinsumDense({ + * equation: "ab,bc->ac", outputShape: 4, biasAxes: "c"}); + * const inputTensor = tf.input({shape: [32]}); + * const outputTensor = layer.call(inputTensor); + * console.log(outputTensor); // [null, 64] + * + * **Applying a dense layer to a sequence** + * + * This example shows how to instantiate a layer that applies the same dense + * operation to every element in a sequence. Here, the `outputShape` has two + * values (since there are two non-batch dimensions in the output); the first + * dimension in the `outputShape` is `null`, because the sequence dimension + * `b` has an unknown shape. + * + * const layer = new EinsumDense({ + * equation: "abc,cd->abd", outputShape: [null, 64], biasAxes: "d"}); + * const inputTensor = tf.input({shape: [32, 128]}); + * const outputTensor = layer.call(inputTensor); + * console.log(outputTensor); // [null, 32, 64] + * + * **Applying a dense layer to a sequence using ellipses** + * + * This example shows how to instantiate a layer that applies the same dense + * operation to every element in a sequence, but uses the ellipsis notation + * instead of specifying the batch and sequence dimensions. + * + * Because we are using ellipsis notation and have specified only one axis, the + * `outputShape` arg is a single value. When instantiated in this way, the + * layer can handle any number of sequence dimensions - including the case + * where no sequence dimension exists. + * + * const layer = new EinsumDense({ + * equation: "...x,xy->...y", outputShape: 64, biasAxes: "y"}); + * const inputTensor = tf.input({shape: [32, 128]}); + * const outputTensor = layer.call(inputTensor); + * console.log(outputTensor); // [null, 32, 64] + */ +export class EinsumDense extends Layer { + /** @nocollapse */ + static readonly className = 'EinsumDense'; + + constructor(args: EinsumDenseArgs) { + super(args); + throw new NotImplementedError(`Not implmented yet.`); + } + + override build(inputShape: Shape | Shape[]): void { + throw new NotImplementedError( + `Not implmented yet. Uses ${this.analyzeEinsumString}`); + } + + override getConfig(): serialization.ConfigDict { + throw new NotImplementedError(`Not implmented yet.`); + } + + override call(inputs: Tensor|Tensor[], kwargs: Kwargs): Tensor|Tensor2D { + throw new NotImplementedError(`Not implmented yet.`); + } + + /** + * Analyzes an einsum string to determine the required weight shape. 
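+   *
+   * For example (illustrative): `equation: 'ab,bc->ac'` with
+   * `biasAxes: 'c'`, `inputShape: [null, 32]`, and `outputShape: [64]`
+   * should resolve to a kernel shape of `[32, 64]`, a bias shape of
+   * `[64]`, and a full output shape of `[null, 64]`.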
+ */ + private analyzeEinsumString( + equation: string, + biasAxes: string, + inputShape: Shape, + outputShape: Shape + ): [Shape, Shape, Shape] { + throw new NotImplementedError( + `Not implmented yet. Uses ${this.analyzeSplitString}.`); + } + + /** + * Analyze an pre-split einsum string to find the weight shape. + */ + private analyzeSplitString( + splitString: [string, string, string], + biasAxes: string, + inputShape: Shape, + outputShape: Shape, + leftElided?: boolean + ): [Shape, Shape, Shape] { + throw new NotImplementedError(`Not implmented yet.`); + } +} +serialization.registerClass(EinsumDense); From 2a6d92931c0681427f59af92718757b52b8650c5 Mon Sep 17 00:00:00 2001 From: Piero Orderique Date: Fri, 14 Jul 2023 01:04:41 +0000 Subject: [PATCH 08/18] lint --- tfjs-layers/src/layers/nlp/einsum_dense.ts | 2 +- tfjs-layers/src/layers/nlp/multihead_attention.ts | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tfjs-layers/src/layers/nlp/einsum_dense.ts b/tfjs-layers/src/layers/nlp/einsum_dense.ts index 05b367f610a..f709e9182c7 100644 --- a/tfjs-layers/src/layers/nlp/einsum_dense.ts +++ b/tfjs-layers/src/layers/nlp/einsum_dense.ts @@ -29,7 +29,7 @@ import { InitializerIdentifier } from '../../initializers'; import { ActivationIdentifier } from '../../keras_format/activation_config'; import { Shape } from '../../keras_format/common'; import { RegularizerIdentifier } from '../../regularizers'; -import { Kwargs } from 'tfjs-layers/src/types'; +import { Kwargs } from '../../types'; export declare interface EinsumDenseArgs extends LayerArgs { /** diff --git a/tfjs-layers/src/layers/nlp/multihead_attention.ts b/tfjs-layers/src/layers/nlp/multihead_attention.ts index 628b4b0c49c..2ea80cd783f 100644 --- a/tfjs-layers/src/layers/nlp/multihead_attention.ts +++ b/tfjs-layers/src/layers/nlp/multihead_attention.ts @@ -227,9 +227,9 @@ export declare interface MultiHeadAttentionOptions { * Performs 1D cross-attention over two sequence inputs with an attention mask. * Returns the additional attention weights over heads. * - * const layer = new MultiHeadAttention({numHeads=2, keyDim=2}); - * const target = tf.keras.Input(shape=[8, 16]); - * const source = tf.keras.Input(shape=[4, 16]); + * const layer = new MultiHeadAttention({numHeads: 2, keyDim: 2}); + * const target = tf.input({shape: [8, 16]}); + * const source = tf.input({shape: [4, 16]}); * const outputTensor, weights = layer.callAndReturnAttentionScores( * target, {value: source}); * console.log(outputTensor.shape); // [null, 8, 16] From 6dcb7a0b2a7bab1214910e2f5a23ed8916be3cd4 Mon Sep 17 00:00:00 2001 From: Piero Orderique Date: Mon, 17 Jul 2023 11:24:18 -0700 Subject: [PATCH 09/18] Remove unused type declaration --- tfjs-layers/src/layers/nlp/einsum_dense.ts | 19 ------------------- 1 file changed, 19 deletions(-) diff --git a/tfjs-layers/src/layers/nlp/einsum_dense.ts b/tfjs-layers/src/layers/nlp/einsum_dense.ts index f709e9182c7..435c74cd9b0 100644 --- a/tfjs-layers/src/layers/nlp/einsum_dense.ts +++ b/tfjs-layers/src/layers/nlp/einsum_dense.ts @@ -98,25 +98,6 @@ export declare interface EinsumDenseArgs extends LayerArgs { biasConstraint?: ConstraintIdentifier; } -export declare interface EinsumDenseOptions { - /** - * Pass to override the configured `sequenceLength` of the layer. - */ - sequenceLength?: number; - - /** - * Pass `false` to not append a start value for this input. - * Defaults to true. - */ - addStartValue?: boolean; - - /** - * Pass `false` to not append an end value for this input. - * Defaults to true. 
- */ - addEndValue?: boolean; -} - /** * A layer that uses `tf.einsum` as the backing computation. * From e58981701fba167e85f0d76c25474e6ae31d0ba6 Mon Sep 17 00:00:00 2001 From: Piero Orderique Date: Mon, 17 Jul 2023 11:26:47 -0700 Subject: [PATCH 10/18] Move helper functions outside EinsumDense class --- tfjs-layers/src/layers/nlp/einsum_dense.ts | 54 +++++++++++----------- 1 file changed, 27 insertions(+), 27 deletions(-) diff --git a/tfjs-layers/src/layers/nlp/einsum_dense.ts b/tfjs-layers/src/layers/nlp/einsum_dense.ts index 435c74cd9b0..dbc53a704a2 100644 --- a/tfjs-layers/src/layers/nlp/einsum_dense.ts +++ b/tfjs-layers/src/layers/nlp/einsum_dense.ts @@ -31,6 +31,32 @@ import { Shape } from '../../keras_format/common'; import { RegularizerIdentifier } from '../../regularizers'; import { Kwargs } from '../../types'; +/** + * Analyzes an einsum string to determine the required weight shape. + */ +export function analyzeEinsumString( + equation: string, + biasAxes: string, + inputShape: Shape, + outputShape: Shape +): [Shape, Shape, Shape] { + throw new NotImplementedError( + `Not implmented yet. Uses ${analyzeSplitString}.`); +} + +/** + * Analyze an pre-split einsum string to find the weight shape. + */ +function analyzeSplitString( + splitString: [string, string, string], + biasAxes: string, + inputShape: Shape, + outputShape: Shape, + leftElided?: boolean +): [Shape, Shape, Shape] { + throw new NotImplementedError(`Not implmented yet.`); +} + export declare interface EinsumDenseArgs extends LayerArgs { /** * An equation describing the einsum to perform. This equation must be a @@ -159,7 +185,7 @@ export class EinsumDense extends Layer { override build(inputShape: Shape | Shape[]): void { throw new NotImplementedError( - `Not implmented yet. Uses ${this.analyzeEinsumString}`); + `Not implmented yet. Uses ${analyzeEinsumString}`); } override getConfig(): serialization.ConfigDict { @@ -169,31 +195,5 @@ export class EinsumDense extends Layer { override call(inputs: Tensor|Tensor[], kwargs: Kwargs): Tensor|Tensor2D { throw new NotImplementedError(`Not implmented yet.`); } - - /** - * Analyzes an einsum string to determine the required weight shape. - */ - private analyzeEinsumString( - equation: string, - biasAxes: string, - inputShape: Shape, - outputShape: Shape - ): [Shape, Shape, Shape] { - throw new NotImplementedError( - `Not implmented yet. Uses ${this.analyzeSplitString}.`); - } - - /** - * Analyze an pre-split einsum string to find the weight shape. 
- */ - private analyzeSplitString( - splitString: [string, string, string], - biasAxes: string, - inputShape: Shape, - outputShape: Shape, - leftElided?: boolean - ): [Shape, Shape, Shape] { - throw new NotImplementedError(`Not implmented yet.`); - } } serialization.registerClass(EinsumDense); From 9bafba54d8b9bed7fdd169b9271073690840ac6c Mon Sep 17 00:00:00 2001 From: Piero Orderique Date: Wed, 19 Jul 2023 10:31:12 -0700 Subject: [PATCH 11/18] Implement Einsum Dense --- tfjs-layers/src/layers/nlp/einsum_dense.ts | 285 ++++++++++++++-- .../src/layers/nlp/einsum_dense_test.ts | 317 ++++++++++++++++++ 2 files changed, 583 insertions(+), 19 deletions(-) create mode 100644 tfjs-layers/src/layers/nlp/einsum_dense_test.ts diff --git a/tfjs-layers/src/layers/nlp/einsum_dense.ts b/tfjs-layers/src/layers/nlp/einsum_dense.ts index dbc53a704a2..6515313db05 100644 --- a/tfjs-layers/src/layers/nlp/einsum_dense.ts +++ b/tfjs-layers/src/layers/nlp/einsum_dense.ts @@ -20,16 +20,18 @@ */ /* Original source: keras/layers/core/einsum_dense.py */ -import { Tensor, Tensor2D, serialization } from '@tensorflow/tfjs-core'; +import { Tensor, Tensor2D, einsum, serialization, tidy } from '@tensorflow/tfjs-core'; -import { ConstraintIdentifier } from '../../constraints'; +import { Activation, getActivation, serializeActivation } from '../../activations'; +import { Constraint, ConstraintIdentifier, getConstraint, serializeConstraint } from '../../constraints'; import { Layer, LayerArgs } from '../../engine/topology'; -import { NotImplementedError } from '../../errors'; -import { InitializerIdentifier } from '../../initializers'; +import { ValueError } from '../../errors'; +import { Initializer, InitializerIdentifier, getInitializer, serializeInitializer } from '../../initializers'; import { ActivationIdentifier } from '../../keras_format/activation_config'; import { Shape } from '../../keras_format/common'; -import { RegularizerIdentifier } from '../../regularizers'; +import { Regularizer, RegularizerIdentifier, getRegularizer, serializeRegularizer } from '../../regularizers'; import { Kwargs } from '../../types'; +import { LayerVariable } from '../../variables'; /** * Analyzes an einsum string to determine the required weight shape. @@ -40,21 +42,173 @@ export function analyzeEinsumString( inputShape: Shape, outputShape: Shape ): [Shape, Shape, Shape] { - throw new NotImplementedError( - `Not implmented yet. Uses ${analyzeSplitString}.`); + const dotReplacedString = equation.replace(/\.\.\./g, '0'); + + // This is the case where no ellipses are present in the string. + let splitString = + dotReplacedString.match(/([a-zA-Z]+),([a-zA-Z]+)->([a-zA-Z]+)/); + if (splitString) { + return analyzeSplitString( + splitString, biasAxes, inputShape, outputShape); + } + + // This is the case where ellipses are present on the left. + splitString = + dotReplacedString.match(/0([a-zA-Z]+),([a-zA-Z]+)->0([a-zA-Z]+)/); + if (splitString) { + return analyzeSplitString( + splitString, biasAxes, inputShape, outputShape, true); + } + + // This is the case where ellipses are present on the right. + splitString = + dotReplacedString.match(/([a-zA-Z]{2,})0,([a-zA-Z]+)->([a-zA-Z]+)0/); + if (splitString) { + return analyzeSplitString( + splitString, biasAxes, inputShape, outputShape); + } + + throw new ValueError( + `Invalid einsum equation '${equation}'. Equations must be in the form ` + + '[X],[Y]->[Z], ...[X],[Y]->...[Z], or [X]...,[Y]->[Z]....' + ); } /** * Analyze an pre-split einsum string to find the weight shape. 
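+ *
+ * For example (illustrative): for the equation `'abc,cd->abd'`, the split
+ * string's match groups are `['abc', 'cd', 'abd']`; with
+ * `inputShape: [null, 32, 128]` and `outputShape: [null, 64]` this
+ * resolves to a weight shape of `[128, 64]`.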
*/ -function analyzeSplitString( - splitString: [string, string, string], +export function analyzeSplitString( + splitString: RegExpMatchArray, biasAxes: string, inputShape: Shape, - outputShape: Shape, - leftElided?: boolean + outputShape: Shape|number, + leftElided = false ): [Shape, Shape, Shape] { - throw new NotImplementedError(`Not implmented yet.`); + const inputSpec = splitString[1]; + const weightSpec = splitString[2]; + const outputSpec = splitString[3]; + const elided = inputShape.length - inputSpec.length; + + const newOutputShape: Shape = Array.isArray(outputShape) ? + outputShape.slice() : [outputShape]; + newOutputShape.unshift(inputShape[0]); + + if (elided > 0 && leftElided) { + for(let i = 1; i < elided; i++) { + // We already inserted the 0th input dimension at dim 0, so we need + // to start at location 1 here. + newOutputShape.splice(1, 0, inputShape[i]); + } + } else if (elided > 0 && !leftElided) { + for(let i = inputShape.length - elided; i < inputShape.length; i++) { + newOutputShape.push(inputShape[i]); + } + } + + const inputSpecArr = Array.from(inputSpec); + const outputSpecArr = Array.from(outputSpec); + let inputDimMap, outputDimMap; + + if (leftElided) { + // If we have beginning dimensions elided, we need to use negative + // indexing to determine where in the input dimension our values are. + inputDimMap = new Map( + inputSpecArr.map((dim, i) => { + // This converts any negative indices to positive ones. + const idx = i + elided - inputShape.length; + const positiveIdx = idx < 0 ? inputShape.length + idx : idx; + return [dim, positiveIdx]; + }) + ); + + // Because we've constructed the full output shape already, we don't need + // to do negative indexing. + outputDimMap = new Map( + outputSpecArr.map((dim, i) => [dim, i + elided]) + ); + } else { + inputDimMap = new Map( + inputSpecArr.map((dim, i) => [dim, i]) + ); + outputDimMap = new Map( + outputSpecArr.map((dim, i) => [dim, i]) + ); + } + + for (const dim of inputSpec) { + const inputShapeAtDim = inputShape[inputDimMap.get(dim)]; + if (outputDimMap.has(dim)) { + const outputShapeAtDim = newOutputShape[outputDimMap.get(dim)]; + if (outputShapeAtDim !== null && outputShapeAtDim !== inputShapeAtDim) { + throw new ValueError( + `Input shape and output shape do not match at shared dimension `+ + `'${dim}'. Input shape is ${inputShapeAtDim}, and output shape ` + + `is ${outputShapeAtDim}.` + ); + } + } + } + + for (const dim of outputSpec) { + if (!inputSpec.includes(dim) && !weightSpec.includes(dim)) { + throw new ValueError( + `Dimension '${dim}' was specified in the output '${outputSpec}' ` + + `but has no corresponding dimension in the input spec ` + + `'${inputSpec}' or weight spec '${weightSpec}'` + ); + } + } + + const weightShape: Shape = []; + for (const dim of weightSpec) { + if (inputDimMap.has(dim)) { + weightShape.push(inputShape[inputDimMap.get(dim)]); + } else if (outputDimMap.has(dim)) { + weightShape.push(newOutputShape[outputDimMap.get(dim)]); + } else { + throw new ValueError( + `Weight dimension '${dim}' did not have a match in either the ` + + `input spec '${inputSpec}' or the output spec '${outputSpec}'. For ` + + `this layer, the weight must be fully specified.` + ); + } + } + + let biasShape: Shape; + if (biasAxes != null) { + const numLeftElided = leftElided ? 
elided : 0;
+    const idxMap: { [char: string]: number } = {};
+    for (let i = 0; i < outputSpec.length; i++) {
+      idxMap[outputSpec[i]] = newOutputShape[i + numLeftElided];
+    }
+
+    for (const char of biasAxes) {
+      if (!outputSpec.includes(char)) {
+        throw new ValueError(
+          `Bias dimension '${char}' was requested, but is not part of the ` +
+          `output spec '${outputSpec}'`
+        );
+      }
+    }
+
+    const firstBiasLocation = Math.min(
+      ...biasAxes.split('').map(char => outputSpec.indexOf(char))
+    );
+    const biasOutputSpec = outputSpec.slice(firstBiasLocation);
+
+    biasShape = biasOutputSpec.split('').map(char =>
+      biasAxes.includes(char) ? idxMap[char] : 1
+    );
+
+    if (!leftElided) {
+      for (let _ = 0; _ < elided; _++) {
+        biasShape.push(1);
+      }
+    }
+  } else {
+    biasShape = null;
+  }
+  return [weightShape, biasShape, newOutputShape];
 }

 export declare interface EinsumDenseArgs extends LayerArgs {
@@ -71,7 +225,7 @@ export declare interface EinsumDenseArgs extends LayerArgs {
    * any dimensions represented by ellipses). You can specify `null` for any
    * dimension that is unknown or can be inferred from the input shape.
    */
-  outputShape: Shape;
+  outputShape: Shape|number;

   /**
    * Activation function to use. If you don't specify anything, no activation
@@ -151,11 +305,13 @@
  * dimension in the `outputShape` is `null`, because the sequence dimension
  * `b` has an unknown shape.
  *
+ * ```js
  * const layer = new EinsumDense({
  *    equation: "abc,cd->abd", outputShape: [null, 64], biasAxes: "d"});
  * const inputTensor = tf.input({shape: [32, 128]});
  * const outputTensor = layer.call(inputTensor);
  * console.log(outputTensor.shape);  // [null, 32, 64]
+ * ```
  *
  * **Applying a dense layer to a sequence using ellipses**
  *
@@ -168,32 +324,123 @@
  * layer can handle any number of sequence dimensions - including the case
  * where no sequence dimension exists.
  *
+ * ```js
  * const layer = new EinsumDense({
  *    equation: "...x,xy->...y", outputShape: 64, biasAxes: "y"});
  * const inputTensor = tf.input({shape: [32, 128]});
  * const outputTensor = layer.call(inputTensor);
  * console.log(outputTensor.shape);  // [null, 32, 64]
+ * ```
  */
 export class EinsumDense extends Layer {
   /** @nocollapse */
   static readonly className = 'EinsumDense';
+  private readonly equation: string;
+  private readonly biasAxes: string;
+  private readonly partialOutputShape: Shape;
+  private readonly activation: Activation;
+  private readonly kernelInitializer: Initializer;
+  private readonly biasInitializer: Initializer;
+  private readonly kernelRegularizer: Regularizer;
+  private readonly biasRegularizer: Regularizer;
+  private readonly kernelConstraint: Constraint;
+  private readonly biasConstraint: Constraint;
+  private fullOutputShape: Shape;
+  private _kernel: LayerVariable;
+  private _bias: LayerVariable;

   constructor(args: EinsumDenseArgs) {
     super(args);
-    throw new NotImplementedError(`Not implmented yet.`);
+    this.equation = args.equation;
+    this.biasAxes = args.biasAxes;
+    this.partialOutputShape =
+      Array.isArray(args.outputShape) ? args.outputShape : [args.outputShape];
+    this.activation = getActivation(args.activation);
+    this.kernelInitializer = getInitializer(
+      args.kernelInitializer ?? 'glorotUniform');
+    this.biasInitializer = getInitializer(args.biasInitializer ??
'zeros'); + this.kernelRegularizer = getRegularizer(args.kernelRegularizer); + this.biasRegularizer = getRegularizer(args.biasRegularizer); + this.kernelConstraint = getConstraint(args.kernelConstraint); + this.biasConstraint = getConstraint(args.biasConstraint); + } + + get kernel(): LayerVariable { + return this._kernel; + } + + get bias(): LayerVariable { + return this._bias; + } + + override build(inputShape: Shape): void { + const [kernelShape, biasShape, fullOutputShape] = analyzeEinsumString( + this.equation, + this.biasAxes, + inputShape, + this.partialOutputShape + ); + this.fullOutputShape = fullOutputShape; + this._kernel = this.addWeight( + 'kernel', + kernelShape, + this.dtype, + this.kernelInitializer, + this.kernelRegularizer, + true, + this.kernelConstraint, + ); + + if (biasShape != null) { + this._bias = this.addWeight( + 'bias', + biasShape, + this.dtype, + this.biasInitializer, + this.biasRegularizer, + true, + this.biasConstraint, + ); + } else { + this._bias = null; + } + super.build(inputShape); } - override build(inputShape: Shape | Shape[]): void { - throw new NotImplementedError( - `Not implmented yet. Uses ${analyzeEinsumString}`); + override computeOutputShape(_: Shape): Shape { + return this.fullOutputShape; } override getConfig(): serialization.ConfigDict { - throw new NotImplementedError(`Not implmented yet.`); + const config = { + outputShape: this.partialOutputShape, + equation: this.equation, + activation: serializeActivation(this.activation), + biasAxes: this.biasAxes, + kernelInitializer: serializeInitializer(this.kernelInitializer), + biasInitializer: serializeInitializer(this.biasInitializer), + kernelRegularizer: serializeRegularizer(this.kernelRegularizer), + biasRegularizer: serializeRegularizer(this.biasRegularizer), + kernelConstraint: serializeConstraint(this.kernelConstraint), + biasConstraint: serializeConstraint(this.biasConstraint), + }; + const baseConfig = super.getConfig(); + Object.assign(config, baseConfig); + return config; } override call(inputs: Tensor|Tensor[], kwargs: Kwargs): Tensor|Tensor2D { - throw new NotImplementedError(`Not implmented yet.`); + return tidy(() => { + inputs = Array.isArray(inputs) ? inputs : [inputs]; + let ret = einsum(this.equation, ...inputs, this.kernel.read()); + if (this.bias != null) { + ret = ret.add(this.bias.read()); + } + if (this.activation != null) { + ret = this.activation.apply(ret); + } + return ret; + }); } } serialization.registerClass(EinsumDense); diff --git a/tfjs-layers/src/layers/nlp/einsum_dense_test.ts b/tfjs-layers/src/layers/nlp/einsum_dense_test.ts new file mode 100644 index 00000000000..f209c0c7c03 --- /dev/null +++ b/tfjs-layers/src/layers/nlp/einsum_dense_test.ts @@ -0,0 +1,317 @@ +/** + * @license + * Copyright 2023 Google LLC. + * Licensed under the Apache License, Version 2.0 (the 'License'); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an 'AS IS' BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +/** + * Unit Tests for TFJS-based EinsumDense Layer. 
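+ *
+ * Covers weight and bias shape inference (analyzeEinsumString) as well as
+ * end-to-end layer creation for equations with and without ellipses.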
+ */ + +import { memory, Tensor } from '@tensorflow/tfjs-core'; + +import { analyzeEinsumString, EinsumDense } from './einsum_dense'; +import { Shape } from '../../keras_format/common'; +import { input } from '../../exports'; + +declare interface EinsumDenseTestCaseArgs { + testcaseName: string; + equation: string; + biasAxes: string; + inputShape: Shape; + outputShape: Shape; + expectedWeightShape: Shape; + expectedBiasShape: Shape; + expectedOutputShape: Shape; +} + +describe('EinsumDense', () => { + const combinations: EinsumDenseTestCaseArgs[] = [ + { + testcaseName: '_1d_end_weight', + equation: 'ab,b->a', + biasAxes: null, + inputShape: [null, 32], + outputShape: [], + expectedWeightShape: [32], + expectedBiasShape: null, + expectedOutputShape: [null], + }, + { + testcaseName: '_2d_middle_weight', + equation: 'ab,bc->ac', + biasAxes: null, + inputShape: [null, 32], + outputShape: [64], + expectedWeightShape: [32, 64], + expectedBiasShape: null, + expectedOutputShape: [null, 64], + }, + { + testcaseName: '_3d_bert', + equation: 'abc,cde->abde', + biasAxes: null, + inputShape: [null, 1, 2], + outputShape: [1, 3, 4], + expectedWeightShape: [2, 3, 4], + expectedBiasShape: null, + expectedOutputShape: [null, 1, 3, 4], + }, + { + testcaseName: '_3d_3_bias', + equation: 'abc,cde->abde', + biasAxes: 'e', + inputShape: [null, 1, 2], + outputShape: [1, 3, 4], + expectedWeightShape: [2, 3, 4], + expectedBiasShape: [4], + expectedOutputShape: [null, 1, 3, 4], + }, + { + testcaseName: '_3d_2_bias', + equation: 'abc,cde->abde', + biasAxes: 'd', + inputShape: [null, 1, 2], + outputShape: [1, 3, 4], + expectedWeightShape: [2, 3, 4], + expectedBiasShape: [3, 1], + expectedOutputShape: [null, 1, 3, 4], + }, + { + testcaseName: '_3d_1_3_bias', + equation: 'abc,cde->abde', + biasAxes: 'be', + inputShape: [null, 7, 2], + outputShape: [7, 3, 4], + expectedWeightShape: [2, 3, 4], + expectedBiasShape: [7, 1, 4], + expectedOutputShape: [null, 7, 3, 4], + }, + { + testcaseName: '_3d_bert_projection', + equation: 'BFNH,NHD->BFD', + biasAxes: null, + inputShape: [null, 1, 2, 3], + outputShape: [1, 4], + expectedWeightShape: [2, 3, 4], + expectedBiasShape: null, + expectedOutputShape: [null, 1, 4], + }, + { + testcaseName: '_2d_bert', + equation: 'abc,cd->abd', + biasAxes: null, + inputShape: [null, 1, 2], + outputShape: [1, 4], + expectedWeightShape: [2, 4], + expectedBiasShape: null, + expectedOutputShape: [null, 1, 4], + }, + { + testcaseName: '_embedding_1d', + equation: 'i,d->id', + biasAxes: null, + inputShape: [null], + outputShape: [2], + expectedWeightShape: [2], + expectedBiasShape: null, + expectedOutputShape: [null, 2], + }, + { + testcaseName: '_xlnet_lm', + equation: 'ibd,nd->ibn', + biasAxes: null, + inputShape: [null, null, 1], + outputShape: [null, 2], + expectedWeightShape: [2, 1], + expectedBiasShape: null, + expectedOutputShape: [null, null, 2], + }, + { + testcaseName: '_2d_precast', + equation: '...b,bc->...c', + biasAxes: null, + inputShape: [null, 32], + outputShape: [64], + expectedWeightShape: [32, 64], + expectedBiasShape: null, + expectedOutputShape: [null, 64], + }, + { + testcaseName: '_2d_precast_elided_input_used_in_output', + equation: '...bc,bc->...b', + biasAxes: null, + inputShape: [null, 32, 64], + outputShape: [32], + expectedWeightShape: [32, 64], + expectedBiasShape: null, + expectedOutputShape: [null, 32], + }, + { + testcaseName: '_2d_precast_multiple_elided_dims', + equation: '...b,bc->...c', + biasAxes: null, + inputShape: [null, null, 32], + outputShape: [64], + 
expectedWeightShape: [32, 64], + expectedBiasShape: null, + expectedOutputShape: [null, null, 64], + }, + { + testcaseName: '_3d_precast', + equation: '...c,cde->...de', + biasAxes: null, + inputShape: [null, 1, 2], + outputShape: [3, 4], + expectedWeightShape: [2, 3, 4], + expectedBiasShape: null, + expectedOutputShape: [null, 1, 3, 4], + }, + { + testcaseName: '_3d_precast_3_bias', + equation: '...c,cde->...de', + biasAxes: 'e', + inputShape: [null, 1, 2], + outputShape: [3, 4], + expectedWeightShape: [2, 3, 4], + expectedBiasShape: [4], + expectedOutputShape: [null, 1, 3, 4], + }, + { + testcaseName: '_3d_precast_2_bias', + equation: '...c,cde->...de', + biasAxes: 'd', + inputShape: [null, 1, 2], + outputShape: [3, 4], + expectedWeightShape: [2, 3, 4], + expectedBiasShape: [3, 1], + expectedOutputShape: [null, 1, 3, 4], + }, + { + testcaseName: '_3d_precast_2_3_bias', + equation: '...c,cde->...de', + biasAxes: 'de', + inputShape: [null, 1, 2], + outputShape: [3, 4], + expectedWeightShape: [2, 3, 4], + expectedBiasShape: [3, 4], + expectedOutputShape: [null, 1, 3, 4], + }, + { + testcaseName: '_2d_postcast', + equation: 'bc...,cd->bd...', + biasAxes: null, + inputShape: [null, 1, 2, 3], + outputShape: [4], + expectedWeightShape: [1, 4], + expectedBiasShape: null, + expectedOutputShape: [null, 4, 2, 3], + }, + { + testcaseName: '_3d_postcast', + equation: 'bc...,cde->bde...', + biasAxes: null, + inputShape: [null, 1, 2], + outputShape: [3, 4], + expectedWeightShape: [1, 3, 4], + expectedBiasShape: null, + expectedOutputShape: [null, 3, 4, 2], + }, + { + testcaseName: '_3d_postcast_1_bias', + equation: 'bc...,cde->bde...', + biasAxes: 'd', + inputShape: [null, 1, 2], + outputShape: [3, 4], + expectedWeightShape: [1, 3, 4], + expectedBiasShape: [3, 1, 1], + expectedOutputShape: [null, 3, 4, 2], + }, + { + testcaseName: '_3d_postcast_2_bias', + equation: 'bc...,cde->bde...', + biasAxes: 'e', + inputShape: [null, 1, 2], + outputShape: [3, 4], + expectedWeightShape: [1, 3, 4], + expectedBiasShape: [4, 1], + expectedOutputShape: [null, 3, 4, 2], + }, + { + testcaseName: '_3d_postcast_1_2_bias', + equation: 'bc...,cde->bde...', + biasAxes: 'de', + inputShape: [null, 1, 2], + outputShape: [3, 4], + expectedWeightShape: [1, 3, 4], + expectedBiasShape: [3, 4, 1], + expectedOutputShape: [null, 3, 4, 2], + }, + ]; + + function testWeightShape(combo: EinsumDenseTestCaseArgs) { + it(`${combo.testcaseName} weight shape`, () => { + const [weightShape, biasShape, _] = analyzeEinsumString( + combo.equation, combo.biasAxes, combo.inputShape, combo.outputShape + ); + expect(weightShape).toEqual(combo.expectedWeightShape); + expect(biasShape).toEqual(combo.expectedBiasShape); + }); + } + + function testLayerCreation(combo: EinsumDenseTestCaseArgs) { + it(`${combo.testcaseName} layer creation`, () => { + const nonBatchInputShape = combo.inputShape.slice(1); + const inputTensor = input({shape: nonBatchInputShape}); + + const layer = new EinsumDense({ + equation: combo.equation, + biasAxes: combo.biasAxes, + outputShape: combo.outputShape, + }); + const outputTensor = layer.apply(inputTensor) as Tensor; + + expect(layer.kernel.shape).toEqual(combo.expectedWeightShape); + if (combo.expectedBiasShape === null) { + expect(layer.bias).toBeNull(); + } else { + expect(layer.bias.shape).toEqual(combo.expectedBiasShape); + } + expect(outputTensor.shape).toEqual(combo.expectedOutputShape); + }); + } + + for (const combo of combinations) { + testWeightShape(combo); + testLayerCreation(combo); + } + + it('Does not leak 
memory', () => { + const combo = combinations[0]; + const layer = new EinsumDense({ + equation: combo.equation, + biasAxes: combo.biasAxes, + outputShape: combo.outputShape, + }); + const nonBatchInputShape = combo.inputShape.slice(1); + const inputTensor = input({shape: nonBatchInputShape}); + + const numTensors = memory().numTensors; + layer.apply(inputTensor); + + expect(memory().numTensors).toEqual(numTensors + 1); + }); + + // TODO(pforderique): Test serialization. +}); From 4428cf1d4c4da2ba10e4aac8cbec915081daf77e Mon Sep 17 00:00:00 2001 From: Piero Orderique Date: Wed, 26 Jul 2023 15:08:28 -0700 Subject: [PATCH 12/18] Address comments --- tfjs-layers/src/layers/nlp/einsum_dense.ts | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tfjs-layers/src/layers/nlp/einsum_dense.ts b/tfjs-layers/src/layers/nlp/einsum_dense.ts index 6515313db05..6fd744ef0c5 100644 --- a/tfjs-layers/src/layers/nlp/einsum_dense.ts +++ b/tfjs-layers/src/layers/nlp/einsum_dense.ts @@ -116,7 +116,8 @@ export function analyzeSplitString( inputSpecArr.map((dim, i) => { // This converts any negative indices to positive ones. const idx = i + elided - inputShape.length; - const positiveIdx = idx < 0 ? inputShape.length + idx : idx; + const positiveIdx = + ((idx % inputShape.length) + inputShape.length) % inputShape.length; return [dim, positiveIdx]; }) ); @@ -201,7 +202,7 @@ export function analyzeSplitString( ); if (!leftElided) { - for (let _ = 0; _ < elided; _++) { + for (let i = 0; i < elided; i++) { biasShape.push(1); } } From 9e54a156aa27ecdde07a1cf22b2dc8e3fb4046f2 Mon Sep 17 00:00:00 2001 From: Piero Orderique Date: Wed, 26 Jul 2023 15:16:51 -0700 Subject: [PATCH 13/18] Implement MHA Layer --- .../src/layers/nlp/multihead_attention.ts | 587 +++++++++++++++--- .../layers/nlp/multihead_attention_test.ts | 509 +++++++++++++++ 2 files changed, 1026 insertions(+), 70 deletions(-) create mode 100644 tfjs-layers/src/layers/nlp/multihead_attention_test.ts diff --git a/tfjs-layers/src/layers/nlp/multihead_attention.ts b/tfjs-layers/src/layers/nlp/multihead_attention.ts index 2ea80cd783f..0b7c7c5e13b 100644 --- a/tfjs-layers/src/layers/nlp/multihead_attention.ts +++ b/tfjs-layers/src/layers/nlp/multihead_attention.ts @@ -20,32 +20,39 @@ */ /* Original source: keras/layers/attention/multi_head_attention.py */ -import { Tensor, Tensor1D, Tensor2D, serialization } from '@tensorflow/tfjs-core'; - -import { ConstraintIdentifier } from '../../constraints'; -import { Layer, LayerArgs } from '../../engine/topology'; -import { NotImplementedError } from '../../errors'; -import { InitializerIdentifier } from '../../initializers'; +import { Tensor, einsum, linalg, mul, ones, serialization, tidy } from '@tensorflow/tfjs-core'; +// tslint:disable-next-line: no-imports-from-dist +import { arraysEqual } from '@tensorflow/tfjs-core/dist/util_base'; + +import { cast, expandDims } from '../../backend/tfjs_backend'; +import { Constraint, ConstraintIdentifier, getConstraint, serializeConstraint } from '../../constraints'; +import { Layer, LayerArgs, SymbolicTensor } from '../../engine/topology'; +import { ValueError } from '../../errors'; +import { Initializer, InitializerIdentifier, getInitializer, serializeInitializer } from '../../initializers'; import { Shape } from '../../keras_format/common'; -import { RegularizerIdentifier } from '../../regularizers'; +import { Regularizer, RegularizerIdentifier, getRegularizer, serializeRegularizer } from '../../regularizers'; import { Kwargs } from '../../types'; +import { 
Softmax } from '../advanced_activations'; +import { Dropout } from '../core'; +import { EinsumDense } from './einsum_dense'; +const _CHR_IDX = 'abcdefghijklmnopqrstuvwxyz'.split(''); /** * Builds einsum equations for the attention computation. * * Query, key, value inputs after projection are expected to have the shape as: - * `(bs, , , num_heads, channels)`. + * `(bs, , , numHeads, channels)`. * `bs` and `` are treated as ``. * * The attention operations can be generalized: * (1) Query-key dot product: - * `(, , num_heads, channels), (, - * , num_heads, channels) -> (, - * num_heads, , )` + * `(, , numHeads, channels), (, + * , numHeads, channels) -> (, + * numHeads, , )` * (2) Combination: - * `(, num_heads, , ), - * (, , num_heads, channels) -> (, , num_heads, channels)` + * `(, numHeads, , ), + * (, , numHeads, channels) -> (, , numHeads, channels)` * * @param rank Rank of query, key, value tensors. * @param attnAxes Array of axes, `[-1, rank)`, @@ -54,8 +61,41 @@ import { Kwargs } from '../../types'; */ function buildAttentionEquation( rank: number, attnAxes: number[] -): [string, string, string] { - throw new NotImplementedError('Not implemented yet.'); +): [string, string, number] { + const targetNotationArr = _CHR_IDX.slice(0, rank); + // `batchDims` includes the head dim. + const excludeIndices = [...attnAxes, rank - 1]; + const batchDims = []; + for (const e of Array(rank).keys()) { + if (!excludeIndices.includes(e)) { + batchDims.push(e); + } + } + let letterOffset = rank; + let sourceNotation = ''; + for (let i = 0; i < rank; i++) { + if (batchDims.includes(i) || i === rank - 1) { + sourceNotation += targetNotationArr[i]; + } else { + sourceNotation += _CHR_IDX[letterOffset]; + letterOffset++; + } + } + + const productNotation = + batchDims.map(i => targetNotationArr[i]).concat( + attnAxes.map(i => targetNotationArr[i]), + attnAxes.map(i => sourceNotation[i]), + ).join(''); + const targetNotation = targetNotationArr.join(''); + + const dotProductEquation = + `${sourceNotation},${targetNotation}->${productNotation}`; + const attnScoresRank = productNotation.length; + const combineEquation = + `${productNotation},${sourceNotation}->${targetNotation}`; + + return [dotProductEquation, combineEquation, attnScoresRank]; } /** @@ -64,13 +104,43 @@ function buildAttentionEquation( function buildProjectionEquation( freeDims: number, boundDims: number, outputDims: number ): [string, string, number] { - throw new NotImplementedError('Not implemented yet.'); + let inputStr = ''; + let kernelStr = ''; + let outputStr = ''; + let biasAxes = ''; + let letterOffset = 0; + + for (let i = 0; i < freeDims; i++) { + const char = _CHR_IDX[i + letterOffset]; + inputStr += char; + outputStr += char; + } + + letterOffset += freeDims; + for (let i = 0; i < boundDims; i++) { + const char = _CHR_IDX[i + letterOffset]; + inputStr += char; + kernelStr += char; + } + + letterOffset += boundDims; + for (let i = 0; i < outputDims; i++) { + const char = _CHR_IDX[i + letterOffset]; + kernelStr += char; + outputStr += char; + biasAxes += char; + } + + const equation = `${inputStr},${kernelStr}->${outputStr}`; + return [equation, biasAxes, outputStr.length]; } function getOutputShape( - outputRank: number, knownLastDims: Iterable + outputRank: number, knownLastDims: number[] ): Shape { - throw new NotImplementedError('Not implemented yet.'); + const outputShape = + Array(outputRank - knownLastDims.length).fill(null).concat(knownLastDims); + return outputShape; } export declare interface MultiHeadAttentionArgs 
extends LayerArgs {
@@ -113,44 +183,44 @@ export declare interface MultiHeadAttentionArgs extends LayerArgs {
    * Axes over which the attention is applied. `null` means attention over
    * all axes, but batch, heads, and features.
    */
-  attentionAxes: number[];
+  attentionAxes?: number[]|number;

   /**
    * Initializer for dense layer kernels.
    * Defaults to `"glorotUniform"`.
    */
-  kernelInitializer?: InitializerIdentifier;
+  kernelInitializer?: Initializer|InitializerIdentifier;

   /**
    * Initializer for dense layer biases.
    * Defaults to `"zeros"`.
    */
-  biasInitializer?: InitializerIdentifier;
+  biasInitializer?: Initializer|InitializerIdentifier;

   /**
    * Regularizer for dense layer kernels.
    */
-  kernelRegularizer?: RegularizerIdentifier;
+  kernelRegularizer?: Regularizer|RegularizerIdentifier;

   /**
    * Regularizer for dense layer biases.
    */
-  biasRegularizer?: RegularizerIdentifier;
+  biasRegularizer?: Regularizer|RegularizerIdentifier;

   /**
    * Regularizer for dense layer activity.
    */
-  activityRegularizer?: RegularizerIdentifier;
+  activityRegularizer?: Regularizer|RegularizerIdentifier;

   /**
    * Constraint for dense layer kernels.
    */
-  kernelConstraint?: ConstraintIdentifier;
+  kernelConstraint?: Constraint|ConstraintIdentifier;

   /**
    * Constraint for dense layer biases.
    */
-  biasConstraint?: ConstraintIdentifier;
+  biasConstraint?: Constraint|ConstraintIdentifier;
 }

 export declare interface MultiHeadAttentionOptions {
@@ -227,6 +297,7 @@
  * Performs 1D cross-attention over two sequence inputs with an attention mask.
  * Returns the additional attention weights over heads.
  *
+ * ```js
  * const layer = new MultiHeadAttention({numHeads: 2, keyDim: 2});
  * const target = tf.input({shape: [8, 16]});
  * const source = tf.input({shape: [4, 16]});
  * const [outputTensor, weights] = layer.callAndReturnAttentionScores(
  *     target, {value: source});
  * console.log(outputTensor.shape);  // [null, 8, 16]
  * console.log(weights.shape);  // [null, 2, 8, 4]
+ * ```
  *
  * Performs 2D self-attention over a 5D input tensor on axes 2 and 3.
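 * Here axis 1 acts as an extra batch dimension, while the attention weights
 * are computed jointly over axes 2 and 3 (see `buildAttentionEquation`).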
* + * ```js * const layer = new MultiHeadAttention({ * numHeads: 2, keyDim: 2, attentionAxes: [2, 3]}); * const inputTensor = tf.input({shape: [5, 3, 4, 16]}); * const outputTensor = layer.call(inputTensor, {value: inputTensor}); * console.log(outputTensor.shape); // [null, 5, 3, 4, 16] + * ``` * * Returns: * attentionOutput: The result of the computation, of shape `(B, T, E)`, @@ -255,20 +329,138 @@ export class MultiHeadAttention extends Layer { /** @nocollapse */ static readonly className = 'MultiHeadAttention'; + protected readonly numHeads: number; + protected readonly keyDim: number; + protected readonly valueDim: number; + protected readonly dropout: number; + protected readonly useBias: boolean; + protected readonly _outputShape: Shape; + protected readonly kernelInitializer: Initializer; + protected readonly biasInitializer: Initializer; + protected readonly kernelRegularizer: Regularizer; + protected readonly biasRegularizer: Regularizer; + protected readonly kernelConstraint: Constraint; + protected readonly biasConstraint: Constraint; + protected dotProductEquation: string; + protected combineEquation: string; + protected attentionAxes: number[]; + protected builtFromSignature: boolean; + protected softmax: Softmax; + protected dropoutLayer: Dropout; + protected queryShape: Shape; + protected keyShape: Shape; + protected valueShape: Shape; + protected queryDense: EinsumDense; + protected keyDense: EinsumDense; + protected valueDense: EinsumDense; + protected outputDense: EinsumDense; + constructor(args: MultiHeadAttentionArgs) { super(args); - throw new NotImplementedError('Not implemented yet.'); + this.supportsMasking = true; + this.numHeads = args.numHeads; + this.keyDim = args.keyDim; + this.valueDim = args.valueDim ?? args.keyDim; + this.dropout = args.dropout ?? 0; + this.useBias = args.useBias ?? true; + this._outputShape = args.outputShape; + this.kernelInitializer = getInitializer( + args.kernelInitializer ?? 'glorotUniform'); + this.biasInitializer = getInitializer(args.biasInitializer ?? 'zeros'); + this.kernelRegularizer = getRegularizer(args.kernelRegularizer); + this.biasRegularizer = getRegularizer(args.biasRegularizer); + this.activityRegularizer = getRegularizer(args.activityRegularizer); + this.kernelConstraint = getConstraint(args.kernelConstraint); + this.biasConstraint = getConstraint(args.biasConstraint); + if (args.attentionAxes != null && !Array.isArray(args.attentionAxes)) { + this.attentionAxes = [args.attentionAxes]; + } else { + this.attentionAxes = args.attentionAxes as number[]; + } + this.builtFromSignature = false; + this.queryShape = null; + this.keyShape = null; + this.valueShape = null; + } + + /** + * Should be used for testing purposes only. + */ + get _queryDense() { + return this.queryDense; + } + + /** + * Should be used for testing purposes only. + */ + get _keyDense() { + return this.keyDense; + } + + /** + * Should be used for testing purposes only. + */ + get _valueDense() { + return this.valueDense; + } + + /** + * Should be used for testing purposes only. 
+ */ + get _outputDense() { + return this.outputDense; } override getConfig(): serialization.ConfigDict { - throw new NotImplementedError('Not implemented yet.'); + const config = { + numHeads: this.numHeads, + keyDim: this.keyDim, + valueDim: this.valueDim, + dropout: this.dropout, + useBias: this.useBias, + outputShape: this._outputShape, + attentionAxes: this.attentionAxes, + kernelInitializer: serializeInitializer(this.kernelInitializer), + biasInitializer: serializeInitializer(this.biasInitializer), + kernelRegularizer: serializeRegularizer(this.kernelRegularizer), + biasRegularizer: serializeRegularizer(this.biasRegularizer), + activityRegularizer: serializeRegularizer(this.activityRegularizer), + kernelConstraint: serializeConstraint(this.kernelConstraint), + biasConstraint: serializeConstraint(this.biasConstraint), + queryShape: this.queryShape, + keyShape: this.keyShape, + valueShape: this.valueShape, + }; + const baseConfig = super.getConfig(); + Object.assign(config, baseConfig); + return config; } static override fromConfig( cls: serialization.SerializableConstructor, config: serialization.ConfigDict ): T { - throw new NotImplementedError('Not implemented yet.'); + // If the layer has a different build() function from the default, + // we need to trigger the customized build to create weights. + const queryShape = config['queryShape'] as Shape; + const keyShape = config['keyShape'] as Shape; + const valueShape = config['valueShape'] as Shape; + delete config['queryShape']; + delete config['keyShape']; + delete config['valueShape']; + + const layer = new cls(config); + if ([queryShape, keyShape, valueShape].includes(null)) { + console.warn( + 'One of dimensions of the input shape is missing. It ' + + 'should have been memorized when the layer was serialized. ' + + `${cls.toString()} is created without weights.` + ); + } else { + (layer as unknown as MultiHeadAttention).buildFromSignature( + queryShape, valueShape, keyShape); + } + return layer; } /** @@ -276,15 +468,91 @@ export class MultiHeadAttention extends Layer { * * Once the method is called, this.builtFromSignature will be set to true. */ - private buildFromSignature(query: Tensor, value: Tensor, key?: Tensor) { - throw new NotImplementedError( - `Not implemented yet. Uses ${buildProjectionEquation}, ${getOutputShape}, - ${this.getCommonKwargsForSublayer}, ${this.buildAttention}, - ${this.makeOutputDense}.`); + protected buildFromSignature( + queryShape: Shape, + valueShape: Shape, + keyShape?: Shape + ) { + this.builtFromSignature = true; + + if (keyShape === null) { + keyShape = valueShape; + } + + this.queryShape = queryShape; + this.valueShape = valueShape; + this.keyShape = keyShape; + + // Not using SymbolicTensors since tf.input() adds a batch dimension to the + // given shape, therefore giving the tensor the wrong rank. + const queryRank = queryShape.length; + const valueRank = valueShape.length; + const keyRank = keyShape.length; + + const freeDims = queryRank - 1; + let [einsumEquation, biasAxes, outputRank] = + buildProjectionEquation(freeDims, 1, 2); + this.queryDense = new EinsumDense({ + equation: einsumEquation, + outputShape: getOutputShape(outputRank - 1, [this.numHeads, this.keyDim]), + biasAxes: this.useBias ? 
biasAxes : null, + name: 'query', + ...this.getCommonKwargsForSublayer(), + }); + + [einsumEquation, biasAxes, outputRank] = + buildProjectionEquation(keyRank - 1, 1, 2); + this.keyDense = new EinsumDense({ + equation: einsumEquation, + outputShape: getOutputShape(outputRank - 1, [this.numHeads, this.keyDim]), + biasAxes: this.useBias ? biasAxes : null, + name: 'key', + ...this.getCommonKwargsForSublayer(), + }); + + [einsumEquation, biasAxes, outputRank] = + buildProjectionEquation(valueRank - 1, 1, 2); + this.valueDense = new EinsumDense({ + equation: einsumEquation, + outputShape: getOutputShape( + outputRank - 1, [this.numHeads, this.valueDim]), + biasAxes: this.useBias ? biasAxes : null, + name: 'value', + ...this.getCommonKwargsForSublayer(), + }); + + // Builds the attention computations for multi-head dot product attention. + this.buildAttention(outputRank); + this.outputDense = this.makeOutputDense( + freeDims, + this.getCommonKwargsForSublayer(), + 'attentionOutput' + ); } private getCommonKwargsForSublayer(): Kwargs { - throw new NotImplementedError('Not implemented yet.'); + // Create new clone of kernel/bias initializer, so that we don't reuse + // the initializer instance, which could lead to same init value since + // initializer is stateless. + const kernelInitializer = getInitializer({ + className: this.kernelInitializer.getClassName(), + config: this.kernelInitializer.getConfig(), + }); + const biasInitializer = getInitializer({ + className: this.biasInitializer.getClassName(), + config: this.biasInitializer.getConfig(), + }); + + const commonKwargs = { + kernelInitializer, + biasInitializer, + kernelRegularizer: this.kernelRegularizer, + biasRegularizer: this.biasRegularizer, + activityRegularizer: this.activityRegularizer, + kernelConstraint: this.kernelConstraint, + biasConstraint: this.biasConstraint, + }; + return commonKwargs; } /** @@ -297,28 +565,82 @@ export class MultiHeadAttention extends Layer { */ private makeOutputDense( freeDims: number, commonKwargs: Kwargs, name?: string - ): Kwargs { - throw new NotImplementedError('Not implemented yet.'); + ): EinsumDense { + let outputShape: Shape; + if (this._outputShape) { + if (!Array.isArray(this._outputShape)) { + outputShape = [this._outputShape]; + } else { + outputShape = this._outputShape; + } + } else { + outputShape = [this.queryShape[this.queryShape.length - 1]]; + } + + const [einsumEquation, biasAxes, outputRank] = + buildProjectionEquation(freeDims, 2, outputShape.length); + + return new EinsumDense({ + equation: einsumEquation, + outputShape: getOutputShape(outputRank - 1, outputShape), + biasAxes: this.useBias ? biasAxes : null, + name, + ...commonKwargs, + }); } /** * Builds multi-head dot-product attention computations. * - * This function builds attributes necessary for `_compute_attention` to + * This function builds attributes necessary for `computeAttention` to * customize attention computation to replace the default dot-product * attention. * * @param rank The rank of query, key, value tensors. */ - private buildAttention(rank: number) { - throw new NotImplementedError( - `Not implemented yet. 
Uses ${buildAttentionEquation}.`); + protected buildAttention(rank: number) { + if (this.attentionAxes == null) { + this.attentionAxes = []; + for (let i = 1; i < rank - 2; i++) { + this.attentionAxes.push(i); + } + } else { + this.attentionAxes = [...this.attentionAxes]; + } + + const [dotProductEquation, combineEquation, attnScoresRank] = + buildAttentionEquation(rank, this.attentionAxes); + this.dotProductEquation = dotProductEquation; + this.combineEquation = combineEquation; + + const normAxes: number[] = []; + const startIdx = attnScoresRank - this.attentionAxes.length; + for (let i = startIdx; i < attnScoresRank; i++) { + normAxes.push(i); + } + this.softmax = new Softmax({axis: normAxes}); + this.dropoutLayer = new Dropout({rate: this.dropout}); } - private maskedSoftmax( + protected maskedSoftmax( attentionScores: Tensor, attentionMask?: Tensor - ): Tensor|Tensor[] { - throw new NotImplementedError('Not implemented yet.'); + ): Tensor { + return tidy(() => { + // Normalize the attention scores to probabilities. + // `attentionScores` = [B, N, T, S] + if (attentionMask != null) { + // The expand dim happens starting from the `numHeads` dimension, + // (, numHeads, ) + const maskExpansionAxis = -this.attentionAxes.length * 2 - 1; + const endIdx = + attentionScores.shape.length - attentionMask.shape.length; + for (let _ = 0; _ < endIdx; _++) { + attentionMask = expandDims(attentionMask, maskExpansionAxis); + } + } + return this.softmax.apply( + attentionScores, {mask: attentionMask}) as Tensor; + }); } /** @@ -328,9 +650,9 @@ export class MultiHeadAttention extends Layer { * multi-head Q, K, V inputs. Users can override this function for * customized attention implementation. * - * @param query Projected query `Tensor` of shape `(B, T, N, key_dim)`. - * @param key Projected key `Tensor` of shape `(B, S, N, key_dim)`. - * @param value Projected value `Tensor` of shape `(B, S, N, value_dim)`. + * @param query Projected query `Tensor` of shape `(B, T, N, keyDim)`. + * @param key Projected key `Tensor` of shape `(B, S, N, keyDim)`. + * @param value Projected value `Tensor` of shape `(B, S, N, valueDim)`. * @param attentionMask A boolean mask of shape `(B, T, S)`, that prevents * attention to certain positions. It is generally not needed if * the `query` and `value` (and/or `key`) are masked. @@ -340,32 +662,118 @@ export class MultiHeadAttention extends Layer { * @returns attentionOutput: Multi-headed outputs of attention computation. * @returns attentionScores: Multi-headed attention weights. */ - private computeAttention( + protected computeAttention( query: Tensor, key: Tensor, value: Tensor, attentionMask?: Tensor, training?: boolean ): [Tensor, Tensor] { - throw new NotImplementedError( - `Not implemented yet. Uses ${this.maskedSoftmax}.`); + return tidy(() => { + // Note: Applying scalar multiply at the smaller end of einsum improves + // XLA performance, but may introduce slight numeric differences in + // the Transformer attention head. + query = mul(query, 1.0 / Math.sqrt(this.keyDim)); + + // Take the dot product between "query" and "key" to get the raw + // attention scores. + let attentionScores = einsum(this.dotProductEquation, key, query); + + attentionScores = this.maskedSoftmax(attentionScores, attentionMask); + + // This is actually dropping out entire tokens to attend to, which might + // seem a bit unusual, but is taken from the original Transformer paper. 
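+      // Note that the pre-dropout `attentionScores` are what get returned to
+      // the caller; only the dropped-out copy below feeds the output einsum.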
+ const attentionScoresDropout = + this.dropoutLayer.apply(attentionScores, {training}) as Tensor; + + // `contextLayer` = [B, T, N, H] + const attentionOutput = + einsum(this.combineEquation, attentionScoresDropout, value); + + return [attentionOutput, attentionScores]; + }); + } + + override apply( + inputs: Tensor | SymbolicTensor, + kwargs?: Kwargs + ): Tensor | Tensor[] | SymbolicTensor | SymbolicTensor[] { + if (!kwargs || !kwargs['value']) { + throw new ValueError('Must pass in `value` argument in `kwargs.`'); + } + let newInputs: Tensor[]|SymbolicTensor[]; + + newInputs = [inputs, kwargs['value']].concat(kwargs['key'] ?? []); + + // TODO(pforderique): Support mask propogation. + return super.apply(newInputs, kwargs); } override call( query: Tensor, kwargs: MultiHeadAttentionOptions - ): Tensor|Tensor2D { - return this.callAndReturnAttentionScores(query, kwargs)[0]; + ): Tensor { + return tidy(() => { + return this.callAndReturnAttentionScores(query, kwargs)[0]; + }); } /** * Exactly like `call` except also returns the attention scores. */ callAndReturnAttentionScores( - query: Tensor, kwargs: MultiHeadAttentionOptions - ): [Tensor1D|Tensor2D, Tensor1D|Tensor2D] { - throw new NotImplementedError( - `Not implemented yet. Uses ${this.buildFromSignature}, - ${this.computeAttentionMask}, ${this.computeAttention}.`); + query: Tensor, + { + value, + key, + useCausalMask, + attentionMask, + training + }: MultiHeadAttentionOptions + ): [Tensor, Tensor] { + return tidy(() => { + if (!this.builtFromSignature) { + this.buildFromSignature( + query.shape, + value.shape, + key ? key.shape : null + ); + } + if (key == null) { + key = value; + } + + // TODO(pforderique): Support RaggedTensor inputs. + + attentionMask = this.computeAttentionMask( + query, + value, + attentionMask, + useCausalMask, + ); + + // N = `numAttentionHeads` + // H = `sizePerHead` + // `query` = [B, T, N ,H] + query = this.queryDense.apply(query) as Tensor; + + // `key` = [B, S, N, H] + key = this.keyDense.apply(key) as Tensor; + + // `value` = [B, S, N, H] + value = this.valueDense.apply(value) as Tensor; + + const [attentionOutputPreDense, attentionScores] = this.computeAttention( + query, + key, + value, + attentionMask, + training + ); + const attentionOutput = + this.outputDense.apply(attentionOutputPreDense) as Tensor; + + return [attentionOutput, attentionScores]; + }); } /** @@ -383,9 +791,9 @@ export class MultiHeadAttention extends Layer { * In general, if the `query` and `value` are masked, then there is no need * to define the `attentionMask`. * - * @param query Projected query `Tensor` of shape `(B, T, N, key_dim)`. - * @param key Projected key `Tensor` of shape `(B, S, N, key_dim)`. - * @param value Projected value `Tensor` of shape `(B, S, N, value_dim)`. + * @param query Projected query `Tensor` of shape `(B, T, N, keyDim)`. + * @param key Projected key `Tensor` of shape `(B, S, N, keyDim)`. + * @param value Projected value `Tensor` of shape `(B, S, N, valueDim)`. * @param attentionMask A boolean mask of shape `(B, T, S)`, that prevents * attention to certain positions. * @param useCausalMask A boolean to indicate whether to apply a causal @@ -399,12 +807,26 @@ export class MultiHeadAttention extends Layer { private computeAttentionMask( query: Tensor, value: Tensor, - key?: Tensor, attentionMask?: Tensor, - useCausalMask?: boolean + useCausalMask = false ): Tensor { - throw new NotImplementedError( - `Not implemented yet. 
Uses ${this.computeCasualMask}`); + return tidy(() => { + let autoMask: Tensor; + + if (useCausalMask) { + // the shape of the causal mask is [1, T, S] + const mask = this.computeCasualMask(query, value); + autoMask = mask; + } + + if (autoMask != null) { + // Merge attentionMask & automatic mask, to shape [B, T, S] + attentionMask = attentionMask ? + cast(attentionMask, 'bool').logicalAnd(autoMask) : autoMask; + } + + return attentionMask; + }); } /** @@ -426,7 +848,12 @@ export class MultiHeadAttention extends Layer { * triangular matrix of shape [T, S]. */ private computeCasualMask(query: Tensor, value?: Tensor): Tensor { - throw new NotImplementedError('Not implemented yet.'); + return tidy(() => { + const qSeqLength = query.shape[1]; + const vSeqLength = value ? value.shape[1] : qSeqLength; + // Create a lower triangular matrix. + return linalg.bandPart(ones([1, qSeqLength, vSeqLength], 'bool'), -1, 0); + }); } /** @@ -435,10 +862,30 @@ export class MultiHeadAttention extends Layer { * [queryShape, valueShape, keyShape]. If no keyShape provided, valueShape * is assumed as the keyShape. */ - override computeOutputShape( - inputShapes: [Shape, Shape] | [Shape, Shape, Shape] - ): Shape { - throw new NotImplementedError('Not implemented yet.'); + override computeOutputShape(inputShapes: [Shape, Shape, Shape|null]): Shape { + const [queryShape, valueShape, maybeKeyShape] = inputShapes; + const keyShape = maybeKeyShape ?? valueShape; + + if (queryShape.slice(-1)[0] !== valueShape.slice(-1)[0]) { + throw new ValueError( + `The last dimension of 'queryShape' and 'valueShape' must be equal, ` + + `but are ${queryShape.slice(-1)[0]}, ${valueShape.slice(-1)[0]}. ` + + `Received: queryShape=${queryShape}, valueShape=${valueShape}` + ); + } + + if (!arraysEqual(valueShape.slice(1, -1), keyShape.slice(1, -1))) { + throw new Error( + `All dimensions of 'value' and 'key', except the last one, must be ` + + `equal. Received ${valueShape} and ${keyShape}` + ); + } + + if (this._outputShape) { + return queryShape.slice(0, -1).concat(this._outputShape); + } + + return queryShape; } } serialization.registerClass(MultiHeadAttention); diff --git a/tfjs-layers/src/layers/nlp/multihead_attention_test.ts b/tfjs-layers/src/layers/nlp/multihead_attention_test.ts new file mode 100644 index 00000000000..a389d1ab52c --- /dev/null +++ b/tfjs-layers/src/layers/nlp/multihead_attention_test.ts @@ -0,0 +1,509 @@ +/** + * @license + * Copyright 2023 Google LLC. + * Licensed under the Apache License, Version 2.0 (the 'License'); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an 'AS IS' BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +/** + * Unit Tests for MultiHeadAttention layer. 
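+ *
+ * Covers masked and unmasked attention, attention scores, high-dimensional
+ * attention axes, causal masking, output shape computation, and subclassing.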
+ */ + +import { Tensor, memory, ones, randomUniform, randomUniformInt, tensor, tensor2d } from '@tensorflow/tfjs-core'; + +import { TruncatedNormal } from '../../initializers'; +import { input } from '../../exports'; +import { Shape } from '../../keras_format/common'; +import { MultiHeadAttention } from './multihead_attention'; +import { describeMathCPU, expectTensorsClose, expectTensorsNotClose } from '../../utils/test_utils'; +import { Embedding } from '../embeddings'; + +describe('MultiHeadAttention', () => { + + describe('Non Masked Attention', () => { + interface NonMaskedAttentionArgs { + testcaseName: string; + valueDim: number; + outputShape: Shape; + outputDims: Shape; + } + /** + * Test that the attention layer can be created without a mask tensor. + */ + function testNonMaskedAttention( + {testcaseName, valueDim, outputShape, outputDims}: NonMaskedAttentionArgs + ) { + it(`${testcaseName} non masked attention`, () => { + const testLayer = new MultiHeadAttention({ + numHeads: 12, + keyDim: 64, + valueDim, + outputShape, + }); + // Create a 3-dimensional input (the first dimension is implicit). + const query = input({shape: [40, 80]}); + const value = input({shape: [20, 80]}); + const output = testLayer.apply(query, {value}) as Tensor; + expect(output.shape).toEqual([null].concat(outputDims)); + }); + } + + const params: NonMaskedAttentionArgs[] = [ + { + testcaseName: 'key value same proj', + valueDim: null, + outputShape: null, + outputDims: [40, 80], + }, + { + testcaseName: 'key value different proj', + valueDim: 32, + outputShape: [60], + outputDims: [40, 60], + } + ]; + for (const param of params) { + testNonMaskedAttention(param); + } + }); + + // Test with one input (self-attention) and no mask tensor. + it('non masked self attention', () => { + const testLayer = new MultiHeadAttention({numHeads: 12, keyDim: 64}); + // Create a 3-dimensional input (the first dimension is implicit). + const query = input({shape: [40, 80]}); + const output = testLayer.apply(query, {value: query}) as Tensor; + expect(output.shape).toEqual([null, 40, 80]); + }); + + // Test attention outputs with coefficients. + it('attention scores', () => { + const testLayer = new MultiHeadAttention({numHeads: 12, keyDim: 64}); + const query = ones([1, 40, 80]); + const [output, coef] = + testLayer.callAndReturnAttentionScores(query, {value: query}); + expect(output.shape).toEqual([1, 40, 80]); + expect(coef.shape).toEqual([1, 12, 40, 40]); + }); + + // Test attention outputs with coefficients. + it('attention scores with values', () => { + const testLayer = new MultiHeadAttention({numHeads: 12, keyDim: 64}); + const query = ones([1, 40, 80]); + const value = ones([1, 60, 80]); + const [output, coef] = + testLayer.callAndReturnAttentionScores(query, {value}); + expect(output.shape).toEqual([1, 40, 80]); + expect(coef.shape).toEqual([1, 12, 40, 60]); + }); + + describe('Masked Attention', () => { + interface MaskedAttentionArgs { + testcaseName: string; + useBias: boolean; + } + /** + * Test with a mask tensor. + */ + function testMaskedAttention({testcaseName, useBias}: MaskedAttentionArgs) { + it(`${testcaseName}`, () => { + const testLayer = new MultiHeadAttention({ + numHeads: 2, + keyDim: 2, + useBias, + }); + // Create a 3-dimensional input (the first dimension is implicit). + const batchSize = 4; + const query = randomUniform([batchSize, 4, 8]); + const value = randomUniform([batchSize, 2, 8]); + + // Invoke the data with a random set of mask data. This should mask at + // least one element. 
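+        // (In the mask, 1 lets a query position attend to a key position and
+        // 0 blocks it, per the layer's attentionMask contract.)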
+ const maskData = randomUniformInt([batchSize, 4, 2], 0, 2); + const maskedOutputData = testLayer.call( + query, {value, attentionMask: maskData}); + + // Invoke the same data, but with a null mask (where no elements are + // masked). + const nullMaskData = ones([batchSize, 4, 2]); + const unmaskedOutputData = testLayer.call( + query, {value, attentionMask: nullMaskData}); + + expectTensorsNotClose(maskedOutputData, unmaskedOutputData); + + if (useBias) { + expect(testLayer._queryDense.trainableWeights.length).toEqual(2); + expect(testLayer._outputDense.trainableWeights.length).toEqual(2); + } else { + expect(testLayer._queryDense.trainableWeights.length).toEqual(1); + expect(testLayer._outputDense.trainableWeights.length).toEqual(1); + } + }); + } + const params: MaskedAttentionArgs[] = [ + { + testcaseName: 'with bias', + useBias: true, + }, + { + testcaseName: 'no bias', + useBias: false, + } + ]; + for (const param of params) { + testMaskedAttention(param); + } + }); + + // Test with a specified initializer. + it('initializers', () => { + const testLayer = new MultiHeadAttention({ + numHeads: 12, + keyDim: 64, + kernelInitializer: new TruncatedNormal({stddev: 0.02}), + }); + const query = ones([1, 40, 80]); + const output = testLayer.call(query, {value: query}); + expect(output.shape).toEqual([1, 40, 80]); + + // Make sure the sub layers have different kernel init value, and not + // reusing the initializers. + const queryKernel = testLayer._queryDense.kernel.read(); + const keyKernel = testLayer._keyDense.kernel.read(); + const valueKernel = testLayer._valueDense.kernel.read(); + const outputKernel = testLayer._outputDense.kernel.read(); + + expectTensorsNotClose(queryKernel, keyKernel, 1e-6); + expectTensorsNotClose(queryKernel, valueKernel, 1e-6); + expectTensorsNotClose(queryKernel, outputKernel, 1e-6); + }); + + describeMathCPU('High Dimensional Attention', () => { + interface HighDimAttentionArgs { + testcaseName: string; + qDims: Shape; + vDims: Shape; + maskDims: Shape; + attentionAxes: number[]; + } + /** + * Test with high dimensional inputs. + */ + function testHighDimAttention({ + testcaseName, qDims, vDims, maskDims, attentionAxes, + }: HighDimAttentionArgs) { + it(testcaseName, () => { + const testLayer = new MultiHeadAttention({ + numHeads: 2, keyDim: 2, attentionAxes, + }); + const batchSize = 3; + const hiddenSize = 8; + // Generate data for the input (non-mask) tensors. + const queryShape = [batchSize].concat(qDims).concat(hiddenSize); + const valueShape = [batchSize].concat(vDims).concat(hiddenSize); + const maskShape = [batchSize].concat(maskDims); + const query = randomUniform(queryShape, 0, 10); + const value = randomUniform(valueShape, 0, 10); + + // Invoke the data with a random set of mask data. This should mask at + // least one element. + const maskData = randomUniformInt(maskShape, 0, 2).asType('bool'); + + // Invoke the same data, but with a null mask (where no elements are + // masked). + const nullMaskData = ones(maskShape); + + // Because one data is masked and one is not, the outputs should not be + // the same. 
+
+      const outputWithMask = testLayer.call(
+        query, {value, attentionMask: maskData});
+      const outputWithNullMask = testLayer.call(
+        query, {value, attentionMask: nullMaskData});
+
+      expectTensorsNotClose(outputWithMask, outputWithNullMask);
+    });
+  }
+  const params: HighDimAttentionArgs[] = [
+    {
+      testcaseName: '4d_inputs_1freebatch_mask2',
+      qDims: [3, 4],
+      vDims: [3, 2],
+      maskDims: [4, 2],
+      attentionAxes: [2],
+    },
+    {
+      testcaseName: '4d_inputs_1freebatch_mask3',
+      qDims: [3, 4],
+      vDims: [3, 2],
+      maskDims: [3, 4, 2],
+      attentionAxes: [2],
+    },
+    {
+      testcaseName: '4d_inputs_1freebatch_mask4',
+      qDims: [3, 4],
+      vDims: [3, 2],
+      maskDims: [3, 2, 4, 2],
+      attentionAxes: [2],
+    },
+    {
+      testcaseName: '4D_inputs_2D_attention',
+      qDims: [3, 4],
+      vDims: [3, 2],
+      maskDims: [3, 4, 3, 2],
+      attentionAxes: [1, 2],
+    },
+    {
+      testcaseName: '5D_inputs_2D_attention',
+      qDims: [5, 3, 4],
+      vDims: [5, 3, 2],
+      maskDims: [3, 4, 3, 2],
+      attentionAxes: [2, 3],
+    },
+    {
+      testcaseName: '5D_inputs_2D_attention_fullmask',
+      qDims: [5, 3, 4],
+      vDims: [5, 3, 2],
+      maskDims: [5, 3, 4, 3, 2],
+      attentionAxes: [2, 3],
+    },
+  ];
+  for (const param of params) {
+    testHighDimAttention(param);
+  }
+  });
+
+  it('dropout', () => {
+    const testLayer = new MultiHeadAttention({
+      numHeads: 2,
+      keyDim: 2,
+      dropout: 0.5,
+    });
+    const fromData = ones([32, 4, 8]);
+    const toData = ones([32, 2, 8]);
+
+    const trainOut = testLayer.call(fromData, {value: toData, training: true});
+    const testOut = testLayer.call(fromData, {value: toData, training: false});
+
+    expectTensorsNotClose(trainOut, testOut);
+  });
+
+  describe('Causal Mask Value', () => {
+    /**
+     * Test that the value and causal masks are taken into account.
+     */
+    function testValueMask(testcaseName: string, useCausalMask: boolean) {
+      it(testcaseName, () => {
+        const testLayer = new MultiHeadAttention({numHeads: 2, keyDim: 2});
+        const query = tensor2d([
+          [1, 2, 3, 0, 0], [3, 3, 1, 1, 2], [1, 0, 0, 0, 0]
+        ]);
+        const maskedQuery = new Embedding(
+          {inputDim: 4, outputDim: 8, maskZero: true}).apply(query) as Tensor;
+        const value = tensor2d([[5, 4, 0], [3, 0, 0], [2, 1, 1]]);
+        const maskedValue = new Embedding(
+          {inputDim: 6, outputDim: 8, maskZero: true}).apply(value) as Tensor;
+
+        const output = testLayer.call(
+          maskedQuery, {value: maskedValue, useCausalMask});
+
+        let mask = tensor([
+          Array(3).fill([true, true, false]).concat(
+            Array(2).fill([false, false, false])),
+          Array(5).fill([true, false, false]),
+          [[true, true, true]].concat(
+            Array(4).fill([false, false, false]))
+        ]);
+        if (useCausalMask) {
+          mask = mask.logicalAnd(tensor([
+            [[true, false, false], [true, true, false]].concat(
+              [[true, true, true], [true, true, true], [true, true, true]])
+          ]));
+        }
+
+        const outputWithManualMask = testLayer.call(
+          maskedQuery, {value: maskedValue, attentionMask: mask});
+
+        expectTensorsClose(output, outputWithManualMask);
+      });
+    }
+
+    const params: Array<[string, boolean]> = [
+      ['causal', true], ['not_causal', false]
+    ];
+    for (const [testName, useMask] of params) {
+      testValueMask(testName, useMask);
+    }
+  });
+
+  describe('Compute Output Shape', () => {
+    interface ComputeOutputShapeArgs {
+      testcaseName: string;
+      queryDims: Shape;
+      valueDims: Shape;
+      keyDims?: Shape;
+      outputShape: Shape;
+    }
+    /**
+     * Test computed shape is equal to the layer output's shape.
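+     * (Shapes are passed to computeOutputShape as [queryShape, valueShape,
+     * keyShape]; a null keyShape falls back to valueShape.)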
+ */ + function testComputeOutputShape({ + testcaseName, queryDims, valueDims, keyDims, outputShape, + }: ComputeOutputShapeArgs) { + it(testcaseName, () => { + const testLayer = new MultiHeadAttention({ + numHeads: 2, + keyDim: 2, + valueDim: 2, + outputShape + }); + const batchSize = 1; + + const queryShape = [batchSize].concat(queryDims); + const valueShape = [batchSize].concat(valueDims); + const keyShape = keyDims ? [batchSize].concat(keyDims) : null; + + const query = randomUniform(queryShape); + const value = randomUniform(valueShape); + const key = keyShape ? randomUniform(keyShape) : null; + + const output = testLayer.call(query, {value, key}); + const computedOutputShape = testLayer.computeOutputShape( + [queryShape, valueShape, keyShape]); + + expect(output.shape).toEqual(computedOutputShape); + }); + } + const params: ComputeOutputShapeArgs[] = [ + { + testcaseName: 'without_key_same_proj', + queryDims: [40, 80], + valueDims: [20, 80], + keyDims: null, + outputShape: null + }, + { + testcaseName: 'with_key_same_proj', + queryDims: [40, 80], + valueDims: [20, 80], + keyDims: [20, 30], + outputShape: null + }, + { + testcaseName: 'wihtout_key_different_proj', + queryDims: [40, 80], + valueDims: [20, 80], + keyDims: null, + outputShape: [30, 40] + }, + { + testcaseName: 'with_key_different_proj', + queryDims: [40, 80], + valueDims: [20, 80], + keyDims: [20, 30], + outputShape: [15, 50] + }, + ]; + for (const param of params) { + testComputeOutputShape(param); + } + }); + + describe('Compute Output Shape Raises Error', () => { + interface ComputeOutputShapeErrorArgs { + testcaseName: string; + queryShape: Shape; + valueShape: Shape; + keyShape?: Shape; + } + /** + * Test dimension mismatches. + */ + function testComputeOutputShapeError({ + testcaseName, queryShape, valueShape, keyShape, + }: ComputeOutputShapeErrorArgs) { + it(testcaseName, () => { + const testLayer = new MultiHeadAttention({ + numHeads: 4, + keyDim: 2, + valueDim: 2, + }); + + expect(() => testLayer.computeOutputShape( + [queryShape, valueShape, keyShape])).toThrow(); + }); + } + const params: ComputeOutputShapeErrorArgs[] = [ + { + testcaseName: 'query_value_dim_mismatch', + queryShape: [null, 40, 80], + valueShape: [null, 20, 70], + keyShape: null + }, + { + testcaseName: 'key_value_dim_mismatch', + queryShape: [null, 40, 80], + valueShape: [null, 20, 80], + keyShape: [null, 10, 70], + }, + { + testcaseName:'key_value_dim_mismatch_high_dim', + queryShape: [null, 40, 20, 30, 80], + valueShape: [null, 10, 10, 50, 80], + keyShape: [null, 10, 15, 50, 20], + }, + ]; + for (const param of params) { + testComputeOutputShapeError(param); + } + }); + + it('does not leak memory', () => { + const layer = new MultiHeadAttention({numHeads: 2, keyDim: 2}); + const query = ones([1, 4, 8]); + // Initial call that builds sublayers and necessary tensors. + layer.call(query, {value: query}); + + const numTensors = memory().numTensors; + layer.call(query, {value: query}); + + expect(memory().numTensors).toEqual(numTensors + 1); + }); + // TODO(pforderique): Test serialization. +}); + +class SubclassAttention extends MultiHeadAttention { + protected override buildAttention(qkvRank: number) {} + + protected override computeAttention( + query: Tensor, + key: Tensor, + value: Tensor, + attentionMask?: Tensor, + training?: boolean + ): [Tensor, Tensor] { + return [value, null]; + } +} + +describe('AttentionSubclass', () => { + // Test with a specified initializer. 
+  it('initializer', () => {
+    const testLayer = new SubclassAttention({numHeads: 12, keyDim: 64});
+    // Create a 3-dimensional input.
+    const query = ones([1, 40, 80]);
+    const output = testLayer.call(query, {value: query});
+
+    expect(output.shape).toEqual([1, 40, 80]);
+  });
+});

From acb83e2f85e5f5eaee6f37c1b848a432cf31ae28 Mon Sep 17 00:00:00 2001
From: Piero Orderique
Date: Wed, 26 Jul 2023 15:20:08 -0700
Subject: [PATCH 14/18] Add masked softmax support

---
 .../src/layers/advanced_activations.ts | 28 ++++++++++++++++---
 1 file changed, 24 insertions(+), 4 deletions(-)

diff --git a/tfjs-layers/src/layers/advanced_activations.ts b/tfjs-layers/src/layers/advanced_activations.ts
index d03af4168b3..b03eb9f2cf5 100644
--- a/tfjs-layers/src/layers/advanced_activations.ts
+++ b/tfjs-layers/src/layers/advanced_activations.ts
@@ -12,7 +12,7 @@
  * Advanced activation layers.
  */

-import {cast, clipByValue, elu, greater, leakyRelu, mul, prelu, relu, serialization, Tensor} from '@tensorflow/tfjs-core';
+import {add, cast, clipByValue, elu, exp, greater, leakyRelu, logSumExp, mul, ones, prelu, relu, scalar, serialization, sub, Tensor} from '@tensorflow/tfjs-core';

 import {Softmax as softmaxActivation} from '../activations';
 import {Constraint, getConstraint, serializeConstraint} from '../constraints';
@@ -306,13 +306,13 @@ export declare interface SoftmaxLayerArgs extends LayerArgs {
    * Integer, axis along which the softmax normalization is applied.
    * Defaults to `-1` (i.e., the last axis).
    */
-  axis?: number;
+  axis?: number|number[];
 }

 export class Softmax extends Layer {
   /** @nocollapse */
   static className = 'Softmax';
-  readonly axis: number;
+  readonly axis: number|number[];
   readonly softmax: (t: Tensor, a?: number) => Tensor;
   readonly DEFAULT_AXIS = 1.0;
@@ -326,7 +326,27 @@ export class Softmax extends Layer {
   }

   override call(inputs: Tensor|Tensor[], kwargs: Kwargs): Tensor|Tensor[] {
-    const x = getExactlyOneTensor(inputs);
+    // TODO(pforderique): Add tests for when `this.axis` is a number[].
+    let x = getExactlyOneTensor(inputs);
+    const mask = kwargs['mask'] as Tensor;
+    if (mask != null) {
+      // Since mask is 1.0 for positions we want to keep and 0.0 for masked
+      // positions, this operation will create a tensor which is 0.0 for
+      // positions we want to attend and -1e9 for masked positions.
+      const adder =
+        mul(sub(ones(x.shape), cast(mask, x.dtype)), scalar(-1e9));

+      // Since we are adding it to the raw scores before the softmax, this
+      // is effectively the same as removing these entirely.
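+      // For example, scores [2, 3] with mask [1, 0] give adder [0, -1e9];
+      // the softmax of [2, 3 - 1e9] then puts virtually all of the weight on
+      // the first (unmasked) position.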
+      x = add(x, adder);
+    }
+    if (this.axis instanceof Array) {
+      if (this.axis.length > 1) {
+        return exp(sub(x, logSumExp(x, this.axis, true)));
+      } else {
+        return this.softmax(x, this.axis[0]);
+      }
+    }
     return this.softmax(x, this.axis);
   }

From 2c74fe41a05c184d765e8463ddd993bbd478b127 Mon Sep 17 00:00:00 2001
From: Piero Orderique
Date: Mon, 31 Jul 2023 16:41:05 +0000
Subject: [PATCH 15/18] Fix typo

---
 tfjs-layers/src/layers/nlp/multihead_attention.ts | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tfjs-layers/src/layers/nlp/multihead_attention.ts b/tfjs-layers/src/layers/nlp/multihead_attention.ts
index 0b7c7c5e13b..f410f947ab5 100644
--- a/tfjs-layers/src/layers/nlp/multihead_attention.ts
+++ b/tfjs-layers/src/layers/nlp/multihead_attention.ts
@@ -815,7 +815,7 @@ export class MultiHeadAttention extends Layer {

     if (useCausalMask) {
       // the shape of the causal mask is [1, T, S]
-      const mask = this.computeCasualMask(query, value);
+      const mask = this.computeCausalMask(query, value);
       autoMask = mask;
     }

@@ -847,7 +847,7 @@ export class MultiHeadAttention extends Layer {
    * @returns mask: A boolean `Tensor` of shape [1, T, S] containing a lower
    *     triangular matrix of shape [T, S].
    */
-  private computeCasualMask(query: Tensor, value?: Tensor): Tensor {
+  private computeCausalMask(query: Tensor, value?: Tensor): Tensor {
     return tidy(() => {
       const qSeqLength = query.shape[1];
       const vSeqLength = value ? value.shape[1] : qSeqLength;

From c3123354fcfaaeae2f694ceb2e2e1c3936fcd78d Mon Sep 17 00:00:00 2001
From: Piero Orderique
Date: Mon, 31 Jul 2023 19:33:41 +0000
Subject: [PATCH 16/18] Check for undef and null

---
 tfjs-layers/src/layers/nlp/multihead_attention.ts | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tfjs-layers/src/layers/nlp/multihead_attention.ts b/tfjs-layers/src/layers/nlp/multihead_attention.ts
index f410f947ab5..49d9d402a55 100644
--- a/tfjs-layers/src/layers/nlp/multihead_attention.ts
+++ b/tfjs-layers/src/layers/nlp/multihead_attention.ts
@@ -475,7 +475,7 @@ export class MultiHeadAttention extends Layer {
   ) {
     this.builtFromSignature = true;

-    if (keyShape === null) {
+    if (keyShape == null) {
       keyShape = valueShape;
     }

From bdc2f4a1a2d3994248b8f149d596284898a01747 Mon Sep 17 00:00:00 2001
From: Piero Orderique
Date: Fri, 4 Aug 2023 13:27:57 +0000
Subject: [PATCH 17/18] Make buildFromSignature public

---
 tfjs-layers/src/layers/nlp/multihead_attention.ts | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tfjs-layers/src/layers/nlp/multihead_attention.ts b/tfjs-layers/src/layers/nlp/multihead_attention.ts
index 49d9d402a55..b89d472f27d 100644
--- a/tfjs-layers/src/layers/nlp/multihead_attention.ts
+++ b/tfjs-layers/src/layers/nlp/multihead_attention.ts
@@ -468,7 +468,7 @@ export class MultiHeadAttention extends Layer {
    *
    * Once the method is called, this.builtFromSignature will be set to true.
   */
-  protected buildFromSignature(
+  buildFromSignature(
     queryShape: Shape,
     valueShape: Shape,
     keyShape?: Shape

From c5d65b5d66aa890ef33e9fc1e87de3b03cd8ab97 Mon Sep 17 00:00:00 2001
From: Piero Orderique
Date: Mon, 7 Aug 2023 10:20:30 -0700
Subject: [PATCH 18/18] Wrap softmax call in tf.tidy

---
 .../src/layers/advanced_activations.ts | 44 ++++++++++---------
 1 file changed, 23 insertions(+), 21 deletions(-)

diff --git a/tfjs-layers/src/layers/advanced_activations.ts b/tfjs-layers/src/layers/advanced_activations.ts
index b03eb9f2cf5..de9180c776a 100644
--- a/tfjs-layers/src/layers/advanced_activations.ts
+++ b/tfjs-layers/src/layers/advanced_activations.ts
@@ -12,7 +12,7 @@
  * Advanced activation layers.
  */

-import {add, cast, clipByValue, elu, exp, greater, leakyRelu, logSumExp, mul, ones, prelu, relu, scalar, serialization, sub, Tensor} from '@tensorflow/tfjs-core';
+import {add, cast, clipByValue, elu, exp, greater, leakyRelu, logSumExp, mul, ones, prelu, relu, scalar, serialization, sub, Tensor, tidy} from '@tensorflow/tfjs-core';

 import {Softmax as softmaxActivation} from '../activations';
 import {Constraint, getConstraint, serializeConstraint} from '../constraints';
@@ -327,27 +327,29 @@ export class Softmax extends Layer {

   override call(inputs: Tensor|Tensor[], kwargs: Kwargs): Tensor|Tensor[] {
     // TODO(pforderique): Add tests for when `this.axis` is a number[].
-    let x = getExactlyOneTensor(inputs);
-    const mask = kwargs['mask'] as Tensor;
-    if (mask != null) {
-      // Since mask is 1.0 for positions we want to keep and 0.0 for masked
-      // positions, this operation will create a tensor which is 0.0 for
-      // positions we want to attend to and -1e9 for masked positions.
-      const adder =
-          mul(sub(ones(x.shape), cast(mask, x.dtype)), scalar(-1e9));
-
-      // Since we are adding it to the raw scores before the softmax, this
-      // is effectively the same as removing these entirely.
-      x = add(x, adder);
-    }
-    if (this.axis instanceof Array) {
-      if (this.axis.length > 1) {
-        return exp(sub(x, logSumExp(x, this.axis, true)));
-      } else {
-        return this.softmax(x, this.axis[0]);
+    return tidy(() => {
+      let x = getExactlyOneTensor(inputs);
+      const mask = kwargs['mask'] as Tensor;
+      if (mask != null) {
+        // Since mask is 1.0 for positions we want to keep and 0.0 for masked
+        // positions, this operation will create a tensor which is 0.0 for
+        // positions we want to attend to and -1e9 for masked positions.
+        const adder =
+            mul(sub(ones(x.shape), cast(mask, x.dtype)), scalar(-1e9));
+
+        // Since we are adding it to the raw scores before the softmax, this
+        // is effectively the same as removing these entirely.
+        x = add(x, adder);
       }
-    }
-    return this.softmax(x, this.axis);
+      if (this.axis instanceof Array) {
+        if (this.axis.length > 1) {
+          return exp(sub(x, logSumExp(x, this.axis, true)));
+        } else {
+          return this.softmax(x, this.axis[0]);
+        }
+      }
+      return this.softmax(x, this.axis);
+    });
   }

   override computeOutputShape(inputShape: Shape|Shape[]): Shape|Shape[] {
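
For readers who want to exercise the masked softmax path added in PATCH 14 and
refactored in PATCH 18, here is a minimal usage sketch. It is illustrative
only, not part of the patches above: it assumes the patched `Softmax` layer is
reachable through the public `tf.layers.softmax()` factory and that kwargs
passed to `Layer.apply()` are forwarded to `call()` for concrete tensors; the
score and mask values are made up.

    import * as tf from '@tensorflow/tfjs';

    // Raw attention scores for one query over four key positions.
    const scores = tf.tensor2d([[1, 2, 3, 4]]);
    // mask is 1.0 for positions to keep and 0.0 for masked positions.
    const mask = tf.tensor2d([[1, 1, 0, 0]]);

    const softmaxLayer = tf.layers.softmax({axis: -1});

    // Masked positions receive -1e9 before the softmax, so their
    // probability collapses to ~0 and the kept positions renormalize.
    const probs = softmaxLayer.apply(scores, {mask}) as tf.Tensor;
    probs.print();  // approximately [[0.269, 0.731, 0, 0]]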
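
Similarly, a sketch of the lower-triangular mask that the (now correctly
named) `computeCausalMask` produces. That method is private, so the standalone
helper below is hypothetical; it only mirrors the documented `[1, T, S]` shape
contract, and `tf.linalg.bandPart` is simply one way to build such a mask, not
necessarily what the layer uses internally.

    import * as tf from '@tensorflow/tfjs';

    // Hypothetical standalone causal mask: query position i may attend
    // only to value positions j <= i.
    function causalMask(tSeq: number, sSeq: number): tf.Tensor {
      // bandPart(x, -1, 0) keeps the lower triangle, including the diagonal.
      return tf.linalg.bandPart(tf.ones([1, tSeq, sSeq]), -1, 0);
    }

    causalMask(3, 3).print();
    // [[[1, 0, 0],
    //   [1, 1, 0],
    //   [1, 1, 1]]]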