From ea472e2131f809aa61a03069955c031afe4576bb Mon Sep 17 00:00:00 2001
From: Piero Orderique
Date: Thu, 13 Jul 2023 09:50:00 -0700
Subject: [PATCH 01/12] Add spec for multi-head attention

---
 .../src/layers/nlp/multihead_attention.ts     | 444 ++++++++++++++++++
 1 file changed, 444 insertions(+)
 create mode 100644 tfjs-layers/src/layers/nlp/multihead_attention.ts

diff --git a/tfjs-layers/src/layers/nlp/multihead_attention.ts b/tfjs-layers/src/layers/nlp/multihead_attention.ts
new file mode 100644
index 00000000000..84720625617
--- /dev/null
+++ b/tfjs-layers/src/layers/nlp/multihead_attention.ts
@@ -0,0 +1,444 @@
+/**
+ * @license
+ * Copyright 2023 Google LLC.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+
+/**
+ * TFJS-based multi-head attention layer.
+ */
+
+/* Original source: keras/layers/attention/multi_head_attention.py */
+import { Tensor, Tensor1D, Tensor2D, serialization } from '@tensorflow/tfjs-core';
+
+import { ConstraintIdentifier } from '../../constraints';
+import { Layer, LayerArgs } from '../../engine/topology';
+import { NotImplementedError } from '../../errors';
+import { InitializerIdentifier } from '../../initializers';
+import { Shape } from '../../keras_format/common';
+import { RegularizerIdentifier } from '../../regularizers';
+import { Kwargs } from '../../types';
+
+/**
+ * Builds einsum equations for the attention computation.
+ *
+ * Query, key, value inputs after projection are expected to have the shape as:
+ * `(bs, <non-attention dims>, <attention dims>, num_heads, channels)`.
+ * `bs` and `<non-attention dims>` are treated as `<batch dims>`.
+
+ * The attention operations can be generalized:
+ * (1) Query-key dot product:
+ *   `(<batch dims>, <query attention dims>, num_heads, channels), (<batch dims>,
+ *   <key attention dims>, num_heads, channels) -> (<batch dims>,
+ *   num_heads, <query attention dims>, <key attention dims>)`
+ * (2) Combination:
+ *   `(<batch dims>, num_heads, <query attention dims>, <key attention dims>),
+ *   (<batch dims>, <value attention dims>, num_heads, channels) -> (<batch
+ *   dims>, <query attention dims>, num_heads, channels)`
+ *
+ * @param rank Rank of query, key, value tensors.
+ * @param attnAxes Array of axes, `[-1, rank)`,
+ *    that attention will be applied to.
+ * @returns Einsum equations.
+ */
+function buildAttentionEquation(
+  rank: number, attnAxes: Array
+): [string, string, string] {
+  throw new NotImplementedError('Not implemented yet.');
+}
+
+/**
+ * Builds an einsum equation for projections inside multi-head attention.
+ */
+function buildProjectionEquation(
+  freeDims: number, boundDims: number, outputDims: number
+): [string, string, number] {
+  throw new NotImplementedError('Not implemented yet.');
+}
+
+function getOutputShape(
+  outputRank: number, knownLastDims: Iterable<number>
+): Shape {
+  throw new NotImplementedError('Not implemented yet.');
+}
+
+export declare interface MultiHeadAttentionArgs extends LayerArgs {
+  /**
+   * Integer. Number of attention heads.
+   */
+  numHeads: number,
+
+  /**
+   * Integer. Size of each attention head for query and key.
+   */
+  keyDim: number,
+
+  /**
+   * Integer. Size of each attention head for value.
+   * Defaults to `keyDim`.
+   */
+  valueDim?: number,
+
+  /**
+   * Dropout probability.
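+   * Applied to the post-softmax attention scores.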
+   * Defaults to 0.0.
+   */
+  dropout?: number,
+
+  /**
+   * Whether the dense layers use bias vectors/matrices.
+   * Defaults to true.
+   */
+  useBias?: boolean,
+
+  /**
+   * The expected shape of an output tensor, besides the batch
+   * and sequence dims. If not specified, projects back to the query
+   * feature dim (the query input's last dimension).
+   */
+  outputShape?: Shape,
+
+  /**
+   * Axes over which the attention is applied. `null` means attention over
+   * all axes, but batch, heads, and features.
+   */
+  attentionAxes: Array,
+
+  /**
+   * Initializer for dense layer kernels.
+   * Defaults to `"glorotUniform"`.
+   */
+  kernelInitializer?: InitializerIdentifier,
+
+  /**
+   * Initializer for dense layer biases.
+   * Defaults to `"zeros"`.
+   */
+  biasInitializer?: InitializerIdentifier,
+
+  /**
+   * Regularizer for dense layer kernels.
+   */
+  kernelRegularizer?: RegularizerIdentifier,
+
+  /**
+   * Regularizer for dense layer biases.
+   */
+  biasRegularizer?: RegularizerIdentifier,
+
+  /**
+   * Regularizer for dense layer activity.
+   */
+  activityRegularizer?: RegularizerIdentifier,
+
+  /**
+   * Constraint for dense layer kernels.
+   */
+  kernelConstraint?: ConstraintIdentifier,
+
+  /**
+   * Constraint for dense layer kernels.
+   */
+  biasConstraint?: ConstraintIdentifier,
+}
+
+export declare interface MultiHeadAttentionOptions {
+  /**
+   * Query `Tensor` of shape `(B, T, dim)`.
+   */
+
+  /**
+   * Value `Tensor` of shape `(B, S, dim)`.
+   */
+  value: Tensor,
+
+  /**
+   * Key `Tensor` of shape `(B, S, dim)`. If not given, will use `value` for
+   * both `key` and `value`, which is the most common case.
+   */
+  key?: Tensor,
+
+  /**
+   * A boolean mask of shape `(B, T, S)`, that prevents
+   * attention to certain positions. The boolean mask specifies which
+   * query elements can attend to which key elements, 1 indicates
+   * attention and 0 indicates no attention. Broadcasting can happen for
+   * the missing batch dimensions and the head dimension.
+   */
+  attentionMask?: Tensor,
+
+  /**
+   * Indicates whether the layer should behave in training mode
+   * (adding dropout) or in inference mode (no dropout).
+   * Will go with either using the training mode of the parent
+   * layer/model, or false (inference) if there is no parent layer.
+   */
+  training?: boolean,
+
+  /**
+   * Indicates whether to apply a causal mask to prevent tokens from attending
+   * to future tokens (e.g., used in a decoder Transformer).
+   * Defaults to false.
+   */
+  useCausalMask?: boolean,
+}
+
+/**
+ * MultiHeadAttention layer.
+ *
+ * This is an implementation of multi-headed attention as described in the
+ * paper "Attention is all you Need" (Vaswani et al., 2017).
+ * If `query`, `key`, `value` are the same, then
+ * this is self-attention. Each timestep in `query` attends to the
+ * corresponding sequence in `key`, and returns a fixed-width vector.
+ *
+ * This layer first projects `query`, `key` and `value`. These are
+ * (effectively) a list of tensors of length `numAttentionHeads`, where the
+ * corresponding shapes are `(batchSize, <query dimensions>, keyDim)`,
+ * `(batchSize, <key/value dimensions>, keyDim)`,
+ * `(batchSize, <key/value dimensions>, valueDim)`.
+ *
+ * Then, the query and key tensors are dot-producted and scaled. These are
+ * softmaxed to obtain attention probabilities. The value tensors are then
+ * interpolated by these probabilities, then concatenated back to a single
+ * tensor.
+ *
+ * Finally, the result tensor with the last dimension as valueDim can take a
+ * linear projection and be returned.
+
+ * When using `MultiHeadAttention` inside a custom layer, the custom layer must
+ * implement its own `build()` method and call `MultiHeadAttention`'s
+ * `buildFromSignature()` there.
+ * This enables weights to be restored correctly when the model is loaded.
+ *
+ * Examples:
+ *
+ * Performs 1D cross-attention over two sequence inputs with an attention mask.
+ * Returns the additional attention weights over heads.
+ *
+ *     const layer = new MultiHeadAttention({numHeads=2, keyDim=2});
+ *     const target = tf.keras.Input(shape=[8, 16]);
+ *     const source = tf.keras.Input(shape=[4, 16]);
+ *     const outputTensor, weights = layer.callAndReturnAttentionScores(
+ *         target, {value: source});
+ *     console.log(outputTensor.shape);  // [null, 8, 16]
+ *     console.log(weights.shape);  // [null, 2, 8, 4]
+ *
+ * Performs 2D self-attention over a 5D input tensor on axes 2 and 3.
+ *
+ *     const layer = new MultiHeadAttention({
+ *       numHeads: 2, keyDim: 2, attentionAxes: [2, 3]});
+ *     const inputTensor = tf.input({shape: [5, 3, 4, 16]});
+ *     const outputTensor = layer.call(inputTensor, {value: inputTensor});
+ *     console.log(outputTensor.shape);  // [null, 5, 3, 4, 16]
+ *
+ * Returns:
+ *     attentionOutput: The result of the computation, of shape `(B, T, E)`,
+ *         where `T` is for target sequence shapes and `E` is the query input
+ *         last dimension if `outputShape` is `null`. Otherwise, the
+ *         multi-head outputs are projected to the shape specified by
+ *         `outputShape`.
+ *     attentionScores: multi-head attention coefficients over attention axes.
+ */
+export class MultiHeadAttention extends Layer {
+  /** @nocollapse */
+  static readonly className = 'MultiHeadAttention';
+
+  constructor(args: MultiHeadAttentionArgs) {
+    super(args);
+    throw new NotImplementedError('Not implemented yet.');
+  }
+
+  override getConfig(): serialization.ConfigDict {
+    throw new NotImplementedError('Not implemented yet.');
+  }
+
+  static override fromConfig<T extends serialization.Serializable>(
+    cls: serialization.SerializableConstructor<T>,
+    config: serialization.ConfigDict
+  ): T {
+    throw new NotImplementedError('Not implemented yet.');
+  }
+
+  /**
+   * Builds layers and variables.
+   *
+   * Once the method is called, this.builtFromSignature will be set to true.
+   */
+  private buildFromSignature(query: Tensor, value: Tensor, key?: Tensor) {
+    throw new NotImplementedError(
+      `Not implemented yet. Uses ${buildProjectionEquation}, ${getOutputShape},
+      ${this.getCommonKwargsForSublayer}, ${this.buildAttention},
+      ${this.makeOutputDense}.`);
+  }
+
+  private getCommonKwargsForSublayer(): Kwargs {
+    throw new NotImplementedError('Not implemented yet.');
+  }
+
+  /**
+   * Builds the output projection matrix.
+   *
+   * @param freeDims Number of free dimensions for einsum equation building.
+   * @param commonKwargs Common keyword arguments for einsum layer.
+   * @param name Name for the projection layer.
+   * @returns Projection layer.
+   */
+  private makeOutputDense(
+    freeDims: number, commonKwargs: Kwargs, name?: string
+  ): Kwargs {
+    throw new NotImplementedError('Not implemented yet.');
+  }
+
+  /**
+   * Builds multi-head dot-product attention computations.
+   *
+   * This function builds attributes necessary for `_compute_attention` to
+   * customize attention computation to replace the default dot-product
+   * attention.
+   *
+   * @param rank The rank of query, key, value tensors.
+   */
+  private buildAttention(rank: number) {
+    throw new NotImplementedError(
+      `Not implemented yet. 
Uses ${buildAttentionEquation}.`); + } + + private maskedSoftmax( + attentionScores: Tensor, attentionMask?: Tensor + ): Tensor|Tensor[] { + throw new NotImplementedError('Not implemented yet.'); + } + + /** + * Applies Dot-product attention with query, key, value tensors. + * + * This function defines the computation inside `call` with projected + * multi-head Q, K, V inputs. Users can override this function for + * customized attention implementation. + * + * @param query Projected query `Tensor` of shape `(B, T, N, key_dim)`. + * @param key Projected key `Tensor` of shape `(B, S, N, key_dim)`. + * @param value Projected value `Tensor` of shape `(B, S, N, value_dim)`. + * @param attentionMask A boolean mask of shape `(B, T, S)`, that prevents + * attention to certain positions. It is generally not needed if + * the `query` and `value` (and/or `key`) are masked. + * @param training Boolean indicating whether the layer should behave + * in training mode (adding dropout) or in inference mode (doing + * nothing). + * @returns attentionOutput: Multi-headed outputs of attention computation. + * @returns attentionScores: Multi-headed attention weights. + */ + private computeAttention( + query: Tensor, + key: Tensor, + value: Tensor, + attentionMask?: Tensor, + training?: boolean + ): [Tensor, Tensor] { + throw new NotImplementedError( + `Not implemented yet. Uses ${this.maskedSoftmax}.`); + } + + override call( + query: Tensor, kwargs: MultiHeadAttentionOptions + ): Tensor|Tensor2D { + return this.callAndReturnAttentionScores(query, kwargs)[0]; + } + + /** + * Exactly like `call` except also returns the attention scores. + */ + callAndReturnAttentionScores( + query: Tensor, kwargs: MultiHeadAttentionOptions + ): [Tensor1D|Tensor2D, Tensor1D|Tensor2D] { + throw new NotImplementedError( + `Not implemented yet. Uses ${this.buildFromSignature}, + ${this.computeAttentionMask}, ${this.computeAttention}.`); + } + + /** + * Computes the attention mask. + * + * * The `query`'s mask is reshaped from [B, T] to [B, T, 1]. + * * The `value`'s mask is reshaped from [B, S] to [B, 1, S]. + * * The `key`'s mask is reshaped from [B, S] to [B, 1, S]. The `key`'s + * mask is ignored if `key` is `None` or if `key is value`. + * * If `useCausalMask=true`, then the causal mask is computed. Its shape + * is [1, T, S]. + * + * All defined masks are merged using a logical AND operation (`&`). + * + * In general, if the `query` and `value` are masked, then there is no need + * to define the `attentionMask`. + * + * @param query Projected query `Tensor` of shape `(B, T, N, key_dim)`. + * @param key Projected key `Tensor` of shape `(B, S, N, key_dim)`. + * @param value Projected value `Tensor` of shape `(B, S, N, value_dim)`. + * @param attentionMask A boolean mask of shape `(B, T, S)`, that prevents + * attention to certain positions. + * @param useCausalMask A boolean to indicate whether to apply a causal + * mask to prevent tokens from attending to future tokens (e.g., + * used in a decoder Transformer). + * @returns attentionMask: A boolean mask of shape `(B, T, S)`, that prevents + * attention to certain positions, based on the Keras masks of the + * `query`, `key`, `value`, and `attentionMask` tensors, and the + * causal mask if `useCausalMask=true`. + */ + private computeAttentionMask( + query: Tensor, + value: Tensor, + key?: Tensor, + attentionMask?: Tensor, + useCausalMask?: boolean + ): Tensor { + throw new NotImplementedError( + `Not implemented yet. 
Uses ${this.computeCausalMask}`);
+  }
+
+  /**
+   * Computes a causal mask (e.g., for masked self-attention layers).
+   *
+   * For example, if query and value both contain sequences of length 4,
+   * this function returns a boolean `Tensor` equal to:
+   *
+   * ```
+   * [[[true,  false, false, false],
+   *   [true,  true,  false, false],
+   *   [true,  true,  true,  false],
+   *   [true,  true,  true,  true]]]
+   * ```
+   *
+   * @param query query `Tensor` of shape `(B, T, ...)`.
+   * @param value value `Tensor` of shape `(B, S, ...)` (defaults to query).
+   * @returns mask: A boolean `Tensor` of shape [1, T, S] containing a lower
+   *     triangular matrix of shape [T, S].
+   */
+  private computeCausalMask(query: Tensor, value?: Tensor): Tensor {
+    throw new NotImplementedError('Not implemented yet.');
+  }
+
+  /**
+   *
+   * @param inputShapes A list of [queryShape, valueShape] or
+   *     [queryShape, valueShape, keyShape]. If no keyShape is provided,
+   *     valueShape is assumed to be the keyShape.
+   */
+  override computeOutputShape(
+    inputShapes: [Shape, Shape] | [Shape, Shape, Shape]
+  ): Shape {
+    throw new NotImplementedError('Not implemented yet.');
+  }
+}
+serialization.registerClass(MultiHeadAttention);

From 41a105e6a7948bf90a22c4a6d11130f3be5fcbf9 Mon Sep 17 00:00:00 2001
From: Piero Orderique
Date: Thu, 13 Jul 2023 14:13:40 -0700
Subject: [PATCH 02/12] Add CachedMultiHeadAttention cache

---
 .../modeling/cached_multihead_attention.ts    | 122 ++++++++++++++++++
 1 file changed, 122 insertions(+)
 create mode 100644 tfjs-layers/src/layers/nlp/modeling/cached_multihead_attention.ts

diff --git a/tfjs-layers/src/layers/nlp/modeling/cached_multihead_attention.ts b/tfjs-layers/src/layers/nlp/modeling/cached_multihead_attention.ts
new file mode 100644
index 00000000000..27fedbbc6a1
--- /dev/null
+++ b/tfjs-layers/src/layers/nlp/modeling/cached_multihead_attention.ts
@@ -0,0 +1,122 @@
+/**
+ * @license
+ * Copyright 2023 Google LLC.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+
+/**
+ * Cached MHA layer based on `MultiHeadAttention`.
+ */
+
+/* Original source: keras_nlp/layers/modeling/cached_multi_head_attention.py */
+import { Tensor, Tensor1D, Tensor2D, serialization } from '@tensorflow/tfjs-core';
+
+import { MultiHeadAttention } from '../multihead_attention';
+import { NotImplementedError } from '../../../errors';
+
+export declare interface CachedMultiHeadAttentionOptions {
+  /**
+   * Query `Tensor` of shape `(B, T, dim)`.
+   */
+
+  /**
+   * Value `Tensor` of shape `(B, S*, dim)`. If `cache` is `null`, `S*`
+   * must equal `S` and match the shape of `attentionMask`. If `cache` is
+   * not `null`, `S*` can be any length less than `S`, and the computed
+   * value will be spliced into `cache` at `cacheUpdateIndex`.
+   */
+  value: Tensor,
+
+  /**
+   * Key `Tensor` of shape `(B, S*, dim)`. If `cache` is `null`, `S*` must
+   * equal `S` and match the shape of `attentionMask`. 
If `cache` is not `null`,
+   * `S*` can be any length less than `S`, and the computed value will be
+   * spliced into `cache` at `cacheUpdateIndex`.
+   */
+  key?: Tensor,
+
+  /**
+   * A boolean mask of shape `(B, T, S)`. `attentionMask` prevents
+   * attention to certain positions. The boolean mask specifies which
+   * query elements can attend to which key elements, 1 indicates
+   * attention and 0 indicates no attention. Broadcasting can happen for
+   * the missing batch dimensions and the head dimension.
+   */
+  attentionMask?: Tensor,
+
+  /**
+   * A dense float Tensor. The key/value cache, of shape
+   * `[B, 2, S, num_heads, key_dims]`, where `S` must agree with the
+   * `attentionMask` shape. This argument is intended for use during
+   * generation to avoid recomputing intermediate state.
+   */
+  cache?: Tensor
+
+  /**
+   * Integer or Integer `Tensor`. The index at which to update `cache`
+   * (usually the index of the current token being processed when running
+   * generation). If `cacheUpdateIndex=null` while `cache` is set, the cache
+   * will not be updated.
+   */
+  cacheUpdateIndex?: number|Tensor
+}
+
+/**
+ * MultiHeadAttention layer with cache support.
+ *
+ * This layer is suitable for use in autoregressive decoding. It can be used
+ * to cache decoder self-attention and cross-attention. The forward pass
+ * can happen in one of three modes:
+ * - No cache, same as regular multi-head attention.
+ * - Static cache (`cache_update_index` is None). In this case, the
+ *     cached key/value projections will be used and the input values will
+ *     be ignored.
+ * - Updated cache (`cache_update_index` is not None). In this case, new
+ *     key/value projections are computed using the input, and spliced into
+ *     the cache at the specified index.
+ *
+ * Note that caching is useful only during inference and should not be used
+ * during training.
+ *
+ * We use the notation `B`, `T`, `S` below, where `B` is the batch dimension,
+ * `T` is the target sequence length, and `S` is the source sequence length.
+ * Note that during generative decoding, `T` is usually 1 (you are
+ * generating a target sequence of length one to predict the next token).
+ *
+ * Returns:
+ *     An `(attentionOutput, cache)` tuple. `attentionOutput` is the result
+ *     of the computation, of shape `(B, T, dim)`, where `T` is for target
+ *     sequence shapes and `dim` is the query input last dimension if
+ *     `outputShape` is `null`. Otherwise, the multi-head outputs are
+ *     projected to the shape specified by `outputShape`. `cache` is the
+ *     updated cache.
+ */
+export class CachedMultiHeadAttention extends MultiHeadAttention {
+
+  override call(
+    query: Tensor, kwargs: CachedMultiHeadAttentionOptions
+  ): Tensor|Tensor2D {
+    return this.callAndReturnCache(query, kwargs)[0];
+  }
+
+  /**
+   * Exactly like `call` except also returns the updated cache.
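+   *
+   * A minimal generation-loop sketch (variable names are illustrative;
+   * assumes a built attention layer and the cache layout described above):
+   *
+   * ```js
+   * // tokenStep: query of shape [B, 1, dim];
+   * // cache: float tensor of shape [B, 2, S, numHeads, keyDim].
+   * const [attnOut, nextCache] = attention.callAndReturnCache(tokenStep, {
+   *   value: tokenStep,
+   *   cache,
+   *   cacheUpdateIndex: step,
+   * });
+   * ```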
+ */ + callAndReturnCache( + query: Tensor, kwargs: CachedMultiHeadAttentionOptions + ): [Tensor1D|Tensor2D, Tensor1D|Tensor2D] { + throw new NotImplementedError(`Not implemented yet.`); + } +} +serialization.registerClass(CachedMultiHeadAttention); From 6e78ffcffd0f442660934a7cc06eadba06b9c415 Mon Sep 17 00:00:00 2001 From: Piero Orderique Date: Thu, 13 Jul 2023 16:11:14 -0700 Subject: [PATCH 03/12] Fix typos --- .../src/layers/nlp/modeling/cached_multihead_attention.ts | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tfjs-layers/src/layers/nlp/modeling/cached_multihead_attention.ts b/tfjs-layers/src/layers/nlp/modeling/cached_multihead_attention.ts index 27fedbbc6a1..f28c7aaa7ce 100644 --- a/tfjs-layers/src/layers/nlp/modeling/cached_multihead_attention.ts +++ b/tfjs-layers/src/layers/nlp/modeling/cached_multihead_attention.ts @@ -57,7 +57,7 @@ export declare interface CachedMultiHeadAttentionOptions { /** * A dense float Tensor. The key/value cache, of shape - * `[B, 2, S, num_heads, key_dims]`, where `S` must agree with the + * `[B, 2, S, numHeads, keyDims]`, where `S` must agree with the * `attentionMask` shape. This argument is intended for use during * generation to avoid recomputing intermediate state. */ @@ -79,10 +79,10 @@ export declare interface CachedMultiHeadAttentionOptions { * to cache decoder self-attention and cross-attention. The forward pass * can happen in one of three modes: * - No cache, same as regular multi-head attention. - * - Static cache (`cache_update_index` is None). In this case, the + * - Static cache (`cacheUpdateIndex` is None). In this case, the * cached key/value projections will be used and the input values will * be ignored. - * - Updated cache (`cache_update_index` is not None). In this case, new + * - Updated cache (`cacheUpdateIndex` is not None). In this case, new * key/value projections are computed using the input, and spliced into * the cache at the specified index. * From 01d9e2e64ed28ef9906389a72d70aa3e34828017 Mon Sep 17 00:00:00 2001 From: Piero Orderique Date: Thu, 13 Jul 2023 16:13:21 -0700 Subject: [PATCH 04/12] Lint --- .../layers/nlp/modeling/cached_multihead_attention.ts | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tfjs-layers/src/layers/nlp/modeling/cached_multihead_attention.ts b/tfjs-layers/src/layers/nlp/modeling/cached_multihead_attention.ts index f28c7aaa7ce..19a0d8890e2 100644 --- a/tfjs-layers/src/layers/nlp/modeling/cached_multihead_attention.ts +++ b/tfjs-layers/src/layers/nlp/modeling/cached_multihead_attention.ts @@ -36,7 +36,7 @@ export declare interface CachedMultiHeadAttentionOptions { * not `null`, `S*` can be any length less than `S`, and the computed * value will be spliced into `cache` at `cacheUpdateIndex`. */ - value: Tensor, + value: Tensor; /** * Key `Tensor` of shape `(B, S*, dim)`. If `cache` is `null`, `S*` must @@ -44,7 +44,7 @@ export declare interface CachedMultiHeadAttentionOptions { * `S*` can be any length less than `S`, and the computed value will be * spliced into `cache` at `cacheUpdateIndex`. */ - key?: Tensor, + key?: Tensor; /** * A boolean mask of shape `(B, T, S)`. `attentionMask` prevents @@ -53,7 +53,7 @@ export declare interface CachedMultiHeadAttentionOptions { * attention and 0 indicates no attention. Broadcasting can happen for * the missing batch dimensions and the head dimension. */ - attentionMask?: Tensor, + attentionMask?: Tensor; /** * A dense float Tensor. 
The key/value cache, of shape @@ -61,7 +61,7 @@ export declare interface CachedMultiHeadAttentionOptions { * `attentionMask` shape. This argument is intended for use during * generation to avoid recomputing intermediate state. */ - cache?: Tensor + cache?: Tensor; /** * Integer or Integer `Tensor`. The index at which to update `cache` @@ -69,7 +69,7 @@ export declare interface CachedMultiHeadAttentionOptions { * generation). If `cacheUpdateIndex=null` while `cache` is set, the cache * will not be updated. */ - cacheUpdateIndex?: number|Tensor + cacheUpdateIndex?: number|Tensor; } /** From 8f08c19c67d8b4ee48cd83bbac85afc25c027fd1 Mon Sep 17 00:00:00 2001 From: Piero Orderique Date: Thu, 13 Jul 2023 16:57:58 -0700 Subject: [PATCH 05/12] Add Transformer Decoder spec --- .../nlp/modeling/transformer_decoder.ts | 265 ++++++++++++++++++ 1 file changed, 265 insertions(+) create mode 100644 tfjs-layers/src/layers/nlp/modeling/transformer_decoder.ts diff --git a/tfjs-layers/src/layers/nlp/modeling/transformer_decoder.ts b/tfjs-layers/src/layers/nlp/modeling/transformer_decoder.ts new file mode 100644 index 00000000000..36304db5866 --- /dev/null +++ b/tfjs-layers/src/layers/nlp/modeling/transformer_decoder.ts @@ -0,0 +1,265 @@ +/** + * @license + * Copyright 2023 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +/** + * Transformer decoder block implementation based on TFJS `Layer`. + */ + +/* Original source: keras_nlp/layers/modeling/transformer_decoder.py */ +import { Tensor, Tensor1D, Tensor2D, serialization } from '@tensorflow/tfjs-core'; + +import { Layer, LayerArgs, } from '../../../engine/topology'; +import { NotImplementedError } from '../../../errors'; +import { InitializerIdentifier } from '../../../initializers'; +import { ActivationIdentifier } from '../../../keras_format/activation_config'; +import { Shape } from '../../../keras_format/common'; + +export declare interface TransformerDecoderArgs extends LayerArgs { + /** + * Integer. The hidden size of feedforward network. + */ + intermediateDim: number; + + /** + * Integer. The number of heads in MultiHeadAttention. + */ + numHeads: number; + + /** + * The dropout value, shared by MultiHeadAttention and feedforward network. + * Defaults to `0.`. + */ + dropout?: number; + + /** + * The activation function of feedforward network. + * Defaults to `"relu"`. + */ + activation?: ActivationIdentifier + + /** + * The eps value in layer normalization components. + * Defaults to `1e-5`. + */ + layerNormEpsilon?: number; + + /** + * The kernel initializer for the dense and multiheaded attention layers. + * Defaults to `"glorotUniform"`. + */ + kernelInitializer?: InitializerIdentifier; + + /** + * The bias initializer for the dense and multiheaded attention layers. + * Defaults to `"zeros"`. 
+ */ + biasInitializer?: InitializerIdentifier; + + /** + * If true, the inputs to the attention layer(s) and the intermediate dense + * layer are normalized (similar to GPT-2). If set to false, outputs of + * attention layer and intermediate dense layer are normalized + * (similar to BERT). + * Defaults to `false`. + */ + normalizeFirst: boolean; +} + +export declare interface TransformerDecoderOptions { + /** + * decoderSequence: The decode input sequence. + */ + + /** + * The encoder input sequence. For decoder only models (like GPT2), this + * should be left `null`. Once the model is called without an encoderSequence, + * you cannot call it again with encoderSequence. + */ + encoderSequence?: Tensor; + + /** + * A boolean Tensor, the padding mask of decoder sequence, must be of shape + * `[batchSize, decoderSequenceLength]`. + */ + decoderPaddingMask: Tensor; + + /** + * A boolean Tensor. Customized decoder sequence mask, must be of shape + * `[batchSize, decoderSequenceLength, decoderSequenceLength]`. + */ + decoderAttentionMask?: Tensor; + + /** + * A boolean Tensor, the padding mask of encoder sequence, must be of shape + * `[batchSize, encoderSequenceLength]`. + */ + encoderPaddingMask?: Tensor + + /** + * A boolean Tensor. Customized encoder sequence mask, must be of shape + `[batchSize, encoderSequenceLength, encoderSequenceLength]`. + */ + encoderAttentionMask?: Tensor; + + /** + * A dense float Tensor. The cache of key/values pairs in the self-attention + * layer. Has shape `[batchSize, 2, maxSeqLen, numHeads, keyDims]`. + */ + selfAttentionCache?: Tensor; + + /** + * Integer or Integer Tensor. The index at which to update the + * `selfAttentionCache`. Usually, this is the index of the current token + * being processed during decoding. + */ + selfAttentionCacheUpdateIndex?: number|Tensor; + + /** + * A dense float Tensor. The cache of key/value pairs in the cross-attention + * layer. Has shape `[batchSize, 2, S, numHeads, keyDims]`. + */ + crossAttentionCache?: Tensor; + + /** + * Integer or Integer Tensor. The index at which to update the + * `crossAttentionCache`. Usually, this is either `0` (compute the entire + * `crossAttentionCache`), or `null` (reuse a previously computed + * `crossAttentionCache`). + */ + crossAttentionCacheUpdateIndex?: number|Tensor; + + /** + * If true, a causal mask (masking out future input) is applied on the decoder + * sequence. + * Defaults to `true`. + */ + useCausalMask?: boolean; +} + +/** + * Transformer decoder. + * + * This class follows the architecture of the transformer decoder layer in the + * paper [Attention is All You Need](https://arxiv.org/abs/1706.03762). Users + * can instantiate multiple instances of this class to stack up a decoder. + * + * By default, this layer will apply a causal mask to the decoder attention + * layer. This layer will correctly compute an attention mask from an implicit + * padding mask (for example, by passing `maskZero=true` to a + * `tf.layers.embedding` layer). See the Masking and Padding + * [guide](https://keras.io/guides/understanding_masking_and_padding/) + * for more details. + * + * This layer can be called with either one or two inputs. The number of inputs + * must be consistent across all calls. The options are as follows: + * `layer.call(decoderSequence)`: no cross-attention will be built into the + * decoder block. This is useful when building a "decoder-only" + * transformer such as GPT-2. 
+ * `layer.call(decoderSequence, {encoderSequence})`: cross-attention will be + * built into the decoder block. This is useful when building an + * "encoder-decoder" transformer, such as the original transformer + * model described in Attention is All You Need. + * + * Examples: + * ```js + * // Create a single transformer decoder layer. + * const decoder = new TransformerDecoder({intermediateDim: 64, numHeads: 8}); + * + * // Create a simple model containing the decoder. + * const decoderInput = tf.input({shape: [10, 64]}); + * const encoderInput = tf.input({shape: {[10, 64]}); + * const output = decoder.call(decoderInput, {encoderInput}); + * const model = tf.model({ + * inputs: [decoderInput, encoderInput], + * outputs: output, + * ); + * + * // Call decoder on the inputs. + * const decoderInputData = tf.randomUniform([2, 10, 64]); + * const encoderInputData = tf.randomUniform([2, 10, 64]); + * const decoderOutput = model.predict([decoderInputData, encoderInputData]); + * ``` + * + * References: + * - [Vaswani et al., 2017](https://arxiv.org/abs/1706.03762) + */ +export class TransformerDecoder extends Layer { + /** @nocollapse */ + static readonly className = 'TransformerDecoder'; + + constructor(args: TransformerDecoderArgs) { + super(args); + throw new NotImplementedError(`Not implemented yet.`); + } + + /** + * + * @param inputShape decoderSequenceShape or + * [decoderSequenceShape, encoderSequenceShape] + */ + override build(inputShape: Shape|[Shape, Shape]): void { + throw new NotImplementedError(`Not implemented yet.`); + } + + override apply( + inputs: Tensor|Tensor[], kwargs?: TransformerDecoderOptions + ): Tensor | Tensor[] { + throw new NotImplementedError(`Not implemented yet.`); + } + + override call( + decoderSequence: Tensor, kwargs: TransformerDecoderOptions + ): Tensor|Tensor[] { + return this.callAndReturnCaches(decoderSequence, kwargs)[0]; + } + + /** + * @returns One of three things, depending on call arguments: + * - `[outputs, null, null]`, if `selfAttentionCache` is `null`. + * - `[outputs, selfAttentionCache, null]`, if `selfAttentionCache` is + * set and the layer has no cross-attention. + * - `[outputs, selfAttentionCache, crossAttentionCache]`, if + * `selfAttentionCache` and `crossAttentionCache` are set and + * the layer has cross-attention. + */ + callAndReturnCaches( + decoderSequence: Tensor, kwargs: TransformerDecoderOptions + ): [Tensor1D|Tensor2D, Tensor1D|Tensor2D, Tensor1D|Tensor2D] { + throw new NotImplementedError( + `Not implemented yet. 
Uses ${this.computeSelfAttentionMask}`); + } + + private computeSelfAttentionMask( + decoderSequence: Tensor, + decoderPaddingMask: Tensor, + decoderAttentionMask: Tensor, + useCasualMask: boolean, + selfAttentionCache: Tensor, + selfAttentionCacheUpdateIndex: number|Tensor + ): Tensor { + throw new NotImplementedError(`Not implemented yet.`); + } + + override getConfig(): serialization.ConfigDict { + throw new NotImplementedError(`Not implemented yet.`); + } + + override computeOutputShape(decoderSequenceShape: Shape): Shape { + throw new NotImplementedError(`Not implemented yet.`); + } +} +serialization.registerClass(TransformerDecoder); From 4713c4ef02a0bc0629b7187ad625caa08bbfd508 Mon Sep 17 00:00:00 2001 From: Piero Orderique Date: Fri, 14 Jul 2023 00:03:50 +0000 Subject: [PATCH 06/12] lint --- .../nlp/modeling/transformer_decoder.ts | 6 +-- .../src/layers/nlp/multihead_attention.ts | 44 +++++++++---------- 2 files changed, 25 insertions(+), 25 deletions(-) diff --git a/tfjs-layers/src/layers/nlp/modeling/transformer_decoder.ts b/tfjs-layers/src/layers/nlp/modeling/transformer_decoder.ts index 36304db5866..5582b372fcf 100644 --- a/tfjs-layers/src/layers/nlp/modeling/transformer_decoder.ts +++ b/tfjs-layers/src/layers/nlp/modeling/transformer_decoder.ts @@ -49,7 +49,7 @@ export declare interface TransformerDecoderArgs extends LayerArgs { * The activation function of feedforward network. * Defaults to `"relu"`. */ - activation?: ActivationIdentifier + activation?: ActivationIdentifier; /** * The eps value in layer normalization components. @@ -107,11 +107,11 @@ export declare interface TransformerDecoderOptions { * A boolean Tensor, the padding mask of encoder sequence, must be of shape * `[batchSize, encoderSequenceLength]`. */ - encoderPaddingMask?: Tensor + encoderPaddingMask?: Tensor; /** * A boolean Tensor. Customized encoder sequence mask, must be of shape - `[batchSize, encoderSequenceLength, encoderSequenceLength]`. + * `[batchSize, encoderSequenceLength, encoderSequenceLength]`. */ encoderAttentionMask?: Tensor; diff --git a/tfjs-layers/src/layers/nlp/multihead_attention.ts b/tfjs-layers/src/layers/nlp/multihead_attention.ts index 84720625617..628b4b0c49c 100644 --- a/tfjs-layers/src/layers/nlp/multihead_attention.ts +++ b/tfjs-layers/src/layers/nlp/multihead_attention.ts @@ -36,7 +36,7 @@ import { Kwargs } from '../../types'; * Query, key, value inputs after projection are expected to have the shape as: * `(bs, , , num_heads, channels)`. * `bs` and `` are treated as ``. - + * * The attention operations can be generalized: * (1) Query-key dot product: * `(, , num_heads, channels), (, @@ -53,7 +53,7 @@ import { Kwargs } from '../../types'; * @returns Einsum equations. */ function buildAttentionEquation( - rank: number, attnAxes: Array + rank: number, attnAxes: number[] ): [string, string, string] { throw new NotImplementedError('Not implemented yet.'); } @@ -77,80 +77,80 @@ export declare interface MultiHeadAttentionArgs extends LayerArgs { /** * Integer. Number of attention heads. */ - numHeads: number, + numHeads: number; /** * Integer. Size of each attention head for query and key. */ - keyDim: number, + keyDim: number; /** * Integer. Size of each attention head for value. * Defaults to `keyDim`. */ - valueDim?: number, + valueDim?: number; /** * Dropout probability. * Defaults to 0.0. */ - dropout?: number, + dropout?: number; /** * Whether the dense layers use bias vectors/matrices. * Defaults to true. 
*/ - useBias?: boolean, + useBias?: boolean; /** * The expected shape of an output tensor, besides the batch * and sequence dims. If not specified, projects back to the query * feature dim (the query input's last dimension). */ - outputShape?: Shape, + outputShape?: Shape; /** * Axes over which the attention is applied. `null` means attention over * all axes, but batch, heads, and features. */ - attentionAxes: Array, + attentionAxes: number[]; /** * Initializer for dense layer kernels. * Defaults to `"glorotUniform"`. */ - kernelInitializer?: InitializerIdentifier, + kernelInitializer?: InitializerIdentifier; /** * Initializer for dense layer biases. * Defaults to `"zeros"`. */ - biasInitializer?: InitializerIdentifier, + biasInitializer?: InitializerIdentifier; /** * Regularizer for dense layer kernels. */ - kernelRegularizer?: RegularizerIdentifier, + kernelRegularizer?: RegularizerIdentifier; /** * Regularizer for dense layer biases. */ - biasRegularizer?: RegularizerIdentifier, + biasRegularizer?: RegularizerIdentifier; /** * Regularizer for dense layer activity. */ - activityRegularizer?: RegularizerIdentifier, + activityRegularizer?: RegularizerIdentifier; /** * Constraint for dense layer kernels. */ - kernelConstraint?: ConstraintIdentifier, + kernelConstraint?: ConstraintIdentifier; /** * Constraint for dense layer kernels. */ - biasConstraint?: ConstraintIdentifier, + biasConstraint?: ConstraintIdentifier; } export declare interface MultiHeadAttentionOptions { @@ -161,13 +161,13 @@ export declare interface MultiHeadAttentionOptions { /** * Value `Tensor` of shape `(B, S, dim)`. */ - value: Tensor, + value: Tensor; /** * Key `Tensor` of shape `(B, S, dim)`. If not given, will use `value` for * both `key` and `value`, which is the most common case. */ - key?: Tensor, + key?: Tensor; /** * A boolean mask of shape `(B, T, S)`, that prevents @@ -176,7 +176,7 @@ export declare interface MultiHeadAttentionOptions { * attention and 0 indicates no attention. Broadcasting can happen for * the missing batch dimensions and the head dimension. */ - attentionMask?: Tensor, + attentionMask?: Tensor; /** * Indicates whether the layer should behave in training mode @@ -184,14 +184,14 @@ export declare interface MultiHeadAttentionOptions { * Will go with either using the training mode of the parent * layer/model, or false (inference) if there is no parent layer. */ - training?: boolean, + training?: boolean; /** * Indicates whether to apply a causal mask to prevent tokens from attending * to future tokens (e.g., used in a decoder Transformer). * Defaults to false. */ - useCausalMask?: boolean, + useCausalMask?: boolean; } /** @@ -216,7 +216,7 @@ export declare interface MultiHeadAttentionOptions { * * Finally, the result tensor with the last dimension as valueDim can take an * linear projection and return. - + * * When using `MultiHeadAttention` inside a custom layer, the custom layer must * implement its own `build()` method and call `MultiHeadAttention`'s * `buildFromSignature()` there. 
From 37aca1a9f9706208f53210718963509d7545ff6d Mon Sep 17 00:00:00 2001 From: Piero Orderique Date: Thu, 13 Jul 2023 18:01:24 -0700 Subject: [PATCH 07/12] Add Einsum spec --- tfjs-layers/src/layers/nlp/einsum_dense.ts | 218 +++++++++++++++++++++ 1 file changed, 218 insertions(+) create mode 100644 tfjs-layers/src/layers/nlp/einsum_dense.ts diff --git a/tfjs-layers/src/layers/nlp/einsum_dense.ts b/tfjs-layers/src/layers/nlp/einsum_dense.ts new file mode 100644 index 00000000000..05b367f610a --- /dev/null +++ b/tfjs-layers/src/layers/nlp/einsum_dense.ts @@ -0,0 +1,218 @@ +/** + * @license + * Copyright 2023 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +/** + * TFJS-based einsum dense layer. + */ + +/* Original source: keras/layers/core/einsum_dense.py */ +import { Tensor, Tensor2D, serialization } from '@tensorflow/tfjs-core'; + +import { ConstraintIdentifier } from '../../constraints'; +import { Layer, LayerArgs } from '../../engine/topology'; +import { NotImplementedError } from '../../errors'; +import { InitializerIdentifier } from '../../initializers'; +import { ActivationIdentifier } from '../../keras_format/activation_config'; +import { Shape } from '../../keras_format/common'; +import { RegularizerIdentifier } from '../../regularizers'; +import { Kwargs } from 'tfjs-layers/src/types'; + +export declare interface EinsumDenseArgs extends LayerArgs { + /** + * An equation describing the einsum to perform. This equation must be a + * valid einsum string of the form `ab,bc->ac`, `...ab,bc->...ac`, or + * `ab...,bc->ac...` where 'ab', 'bc', and 'ac' can be any valid einsum + * axis expression sequence. + */ + equation: string; + + /** + * The expected shape of the output tensor (excluding the batch dimension and + * any dimensions represented by ellipses). You can specify None for any + * dimension that is unknown or can be inferred from the input shape. + */ + outputShape: Shape; + + /** + * Activation function to use. If you don't specify anything, no activation + * is applied (that is, a "linear" activation: `a(x) = x`). + */ + activation?: ActivationIdentifier; + + /** + * A string containing the output dimension(s) to apply a bias to. Each + * character in the `biasAxes` string should correspond to a character + * in the output portion of the `equation` string. + */ + biasAxes?: string; + + /** + * Initializer for the `kernel` weights matrix. + * Defaults to `"glorotUniform"`. + */ + kernelInitializer?: InitializerIdentifier; + + /** + * Initializer for the bias vector. + * Defaults to `"zeros"`. + */ + biasInitializer?: InitializerIdentifier; + + /** + * Regularizer function applied to the `kernel` weights matrix. + */ + kernelRegularizer?: RegularizerIdentifier; + + /** + * Regularizer function applied to the bias vector. + */ + biasRegularizer?: RegularizerIdentifier; + + /** + * Regularizer function applied to the output of the layer (its "activation"). 
+ */ + activityRegularizer?: RegularizerIdentifier; + + /** + * Constraint function applied to the `kernel` weights matrix. + */ + kernelConstraint?: ConstraintIdentifier; + + /** + * Constraint function applied to the bias vector. + */ + biasConstraint?: ConstraintIdentifier; +} + +export declare interface EinsumDenseOptions { + /** + * Pass to override the configured `sequenceLength` of the layer. + */ + sequenceLength?: number; + + /** + * Pass `false` to not append a start value for this input. + * Defaults to true. + */ + addStartValue?: boolean; + + /** + * Pass `false` to not append an end value for this input. + * Defaults to true. + */ + addEndValue?: boolean; +} + +/** + * A layer that uses `tf.einsum` as the backing computation. + * + * This layer can perform einsum calculations of arbitrary dimensionality. + * + * Examples: + * + * **Biased dense layer with einsums** + * + * This example shows how to instantiate a standard Keras dense layer using + * einsum operations. This example is equivalent to + * tf.layers.Dense({units: 64, useBias: true})`. + * + * const layer = new EinsumDense({ + * equation: "ab,bc->ac", outputShape: 4, biasAxes: "c"}); + * const inputTensor = tf.input({shape: [32]}); + * const outputTensor = layer.call(inputTensor); + * console.log(outputTensor); // [null, 64] + * + * **Applying a dense layer to a sequence** + * + * This example shows how to instantiate a layer that applies the same dense + * operation to every element in a sequence. Here, the `outputShape` has two + * values (since there are two non-batch dimensions in the output); the first + * dimension in the `outputShape` is `null`, because the sequence dimension + * `b` has an unknown shape. + * + * const layer = new EinsumDense({ + * equation: "abc,cd->abd", outputShape: [null, 64], biasAxes: "d"}); + * const inputTensor = tf.input({shape: [32, 128]}); + * const outputTensor = layer.call(inputTensor); + * console.log(outputTensor); // [null, 32, 64] + * + * **Applying a dense layer to a sequence using ellipses** + * + * This example shows how to instantiate a layer that applies the same dense + * operation to every element in a sequence, but uses the ellipsis notation + * instead of specifying the batch and sequence dimensions. + * + * Because we are using ellipsis notation and have specified only one axis, the + * `outputShape` arg is a single value. When instantiated in this way, the + * layer can handle any number of sequence dimensions - including the case + * where no sequence dimension exists. + * + * const layer = new EinsumDense({ + * equation: "...x,xy->...y", outputShape: 64, biasAxes: "y"}); + * const inputTensor = tf.input({shape: [32, 128]}); + * const outputTensor = layer.call(inputTensor); + * console.log(outputTensor); // [null, 32, 64] + */ +export class EinsumDense extends Layer { + /** @nocollapse */ + static readonly className = 'EinsumDense'; + + constructor(args: EinsumDenseArgs) { + super(args); + throw new NotImplementedError(`Not implmented yet.`); + } + + override build(inputShape: Shape | Shape[]): void { + throw new NotImplementedError( + `Not implmented yet. Uses ${this.analyzeEinsumString}`); + } + + override getConfig(): serialization.ConfigDict { + throw new NotImplementedError(`Not implmented yet.`); + } + + override call(inputs: Tensor|Tensor[], kwargs: Kwargs): Tensor|Tensor2D { + throw new NotImplementedError(`Not implmented yet.`); + } + + /** + * Analyzes an einsum string to determine the required weight shape. 
+ */ + private analyzeEinsumString( + equation: string, + biasAxes: string, + inputShape: Shape, + outputShape: Shape + ): [Shape, Shape, Shape] { + throw new NotImplementedError( + `Not implmented yet. Uses ${this.analyzeSplitString}.`); + } + + /** + * Analyze an pre-split einsum string to find the weight shape. + */ + private analyzeSplitString( + splitString: [string, string, string], + biasAxes: string, + inputShape: Shape, + outputShape: Shape, + leftElided?: boolean + ): [Shape, Shape, Shape] { + throw new NotImplementedError(`Not implmented yet.`); + } +} +serialization.registerClass(EinsumDense); From 2a6d92931c0681427f59af92718757b52b8650c5 Mon Sep 17 00:00:00 2001 From: Piero Orderique Date: Fri, 14 Jul 2023 01:04:41 +0000 Subject: [PATCH 08/12] lint --- tfjs-layers/src/layers/nlp/einsum_dense.ts | 2 +- tfjs-layers/src/layers/nlp/multihead_attention.ts | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tfjs-layers/src/layers/nlp/einsum_dense.ts b/tfjs-layers/src/layers/nlp/einsum_dense.ts index 05b367f610a..f709e9182c7 100644 --- a/tfjs-layers/src/layers/nlp/einsum_dense.ts +++ b/tfjs-layers/src/layers/nlp/einsum_dense.ts @@ -29,7 +29,7 @@ import { InitializerIdentifier } from '../../initializers'; import { ActivationIdentifier } from '../../keras_format/activation_config'; import { Shape } from '../../keras_format/common'; import { RegularizerIdentifier } from '../../regularizers'; -import { Kwargs } from 'tfjs-layers/src/types'; +import { Kwargs } from '../../types'; export declare interface EinsumDenseArgs extends LayerArgs { /** diff --git a/tfjs-layers/src/layers/nlp/multihead_attention.ts b/tfjs-layers/src/layers/nlp/multihead_attention.ts index 628b4b0c49c..2ea80cd783f 100644 --- a/tfjs-layers/src/layers/nlp/multihead_attention.ts +++ b/tfjs-layers/src/layers/nlp/multihead_attention.ts @@ -227,9 +227,9 @@ export declare interface MultiHeadAttentionOptions { * Performs 1D cross-attention over two sequence inputs with an attention mask. * Returns the additional attention weights over heads. * - * const layer = new MultiHeadAttention({numHeads=2, keyDim=2}); - * const target = tf.keras.Input(shape=[8, 16]); - * const source = tf.keras.Input(shape=[4, 16]); + * const layer = new MultiHeadAttention({numHeads: 2, keyDim: 2}); + * const target = tf.input({shape: [8, 16]}); + * const source = tf.input({shape: [4, 16]}); * const outputTensor, weights = layer.callAndReturnAttentionScores( * target, {value: source}); * console.log(outputTensor.shape); // [null, 8, 16] From 6dcb7a0b2a7bab1214910e2f5a23ed8916be3cd4 Mon Sep 17 00:00:00 2001 From: Piero Orderique Date: Mon, 17 Jul 2023 11:24:18 -0700 Subject: [PATCH 09/12] Remove unused type declaration --- tfjs-layers/src/layers/nlp/einsum_dense.ts | 19 ------------------- 1 file changed, 19 deletions(-) diff --git a/tfjs-layers/src/layers/nlp/einsum_dense.ts b/tfjs-layers/src/layers/nlp/einsum_dense.ts index f709e9182c7..435c74cd9b0 100644 --- a/tfjs-layers/src/layers/nlp/einsum_dense.ts +++ b/tfjs-layers/src/layers/nlp/einsum_dense.ts @@ -98,25 +98,6 @@ export declare interface EinsumDenseArgs extends LayerArgs { biasConstraint?: ConstraintIdentifier; } -export declare interface EinsumDenseOptions { - /** - * Pass to override the configured `sequenceLength` of the layer. - */ - sequenceLength?: number; - - /** - * Pass `false` to not append a start value for this input. - * Defaults to true. - */ - addStartValue?: boolean; - - /** - * Pass `false` to not append an end value for this input. - * Defaults to true. 
- */ - addEndValue?: boolean; -} - /** * A layer that uses `tf.einsum` as the backing computation. * From e58981701fba167e85f0d76c25474e6ae31d0ba6 Mon Sep 17 00:00:00 2001 From: Piero Orderique Date: Mon, 17 Jul 2023 11:26:47 -0700 Subject: [PATCH 10/12] Move helper functions outside EinsumDense class --- tfjs-layers/src/layers/nlp/einsum_dense.ts | 54 +++++++++++----------- 1 file changed, 27 insertions(+), 27 deletions(-) diff --git a/tfjs-layers/src/layers/nlp/einsum_dense.ts b/tfjs-layers/src/layers/nlp/einsum_dense.ts index 435c74cd9b0..dbc53a704a2 100644 --- a/tfjs-layers/src/layers/nlp/einsum_dense.ts +++ b/tfjs-layers/src/layers/nlp/einsum_dense.ts @@ -31,6 +31,32 @@ import { Shape } from '../../keras_format/common'; import { RegularizerIdentifier } from '../../regularizers'; import { Kwargs } from '../../types'; +/** + * Analyzes an einsum string to determine the required weight shape. + */ +export function analyzeEinsumString( + equation: string, + biasAxes: string, + inputShape: Shape, + outputShape: Shape +): [Shape, Shape, Shape] { + throw new NotImplementedError( + `Not implmented yet. Uses ${analyzeSplitString}.`); +} + +/** + * Analyze an pre-split einsum string to find the weight shape. + */ +function analyzeSplitString( + splitString: [string, string, string], + biasAxes: string, + inputShape: Shape, + outputShape: Shape, + leftElided?: boolean +): [Shape, Shape, Shape] { + throw new NotImplementedError(`Not implmented yet.`); +} + export declare interface EinsumDenseArgs extends LayerArgs { /** * An equation describing the einsum to perform. This equation must be a @@ -159,7 +185,7 @@ export class EinsumDense extends Layer { override build(inputShape: Shape | Shape[]): void { throw new NotImplementedError( - `Not implmented yet. Uses ${this.analyzeEinsumString}`); + `Not implmented yet. Uses ${analyzeEinsumString}`); } override getConfig(): serialization.ConfigDict { @@ -169,31 +195,5 @@ export class EinsumDense extends Layer { override call(inputs: Tensor|Tensor[], kwargs: Kwargs): Tensor|Tensor2D { throw new NotImplementedError(`Not implmented yet.`); } - - /** - * Analyzes an einsum string to determine the required weight shape. - */ - private analyzeEinsumString( - equation: string, - biasAxes: string, - inputShape: Shape, - outputShape: Shape - ): [Shape, Shape, Shape] { - throw new NotImplementedError( - `Not implmented yet. Uses ${this.analyzeSplitString}.`); - } - - /** - * Analyze an pre-split einsum string to find the weight shape. 
- */ - private analyzeSplitString( - splitString: [string, string, string], - biasAxes: string, - inputShape: Shape, - outputShape: Shape, - leftElided?: boolean - ): [Shape, Shape, Shape] { - throw new NotImplementedError(`Not implmented yet.`); - } } serialization.registerClass(EinsumDense); From 9bafba54d8b9bed7fdd169b9271073690840ac6c Mon Sep 17 00:00:00 2001 From: Piero Orderique Date: Wed, 19 Jul 2023 10:31:12 -0700 Subject: [PATCH 11/12] Implement Einsum Dense --- tfjs-layers/src/layers/nlp/einsum_dense.ts | 285 ++++++++++++++-- .../src/layers/nlp/einsum_dense_test.ts | 317 ++++++++++++++++++ 2 files changed, 583 insertions(+), 19 deletions(-) create mode 100644 tfjs-layers/src/layers/nlp/einsum_dense_test.ts diff --git a/tfjs-layers/src/layers/nlp/einsum_dense.ts b/tfjs-layers/src/layers/nlp/einsum_dense.ts index dbc53a704a2..6515313db05 100644 --- a/tfjs-layers/src/layers/nlp/einsum_dense.ts +++ b/tfjs-layers/src/layers/nlp/einsum_dense.ts @@ -20,16 +20,18 @@ */ /* Original source: keras/layers/core/einsum_dense.py */ -import { Tensor, Tensor2D, serialization } from '@tensorflow/tfjs-core'; +import { Tensor, Tensor2D, einsum, serialization, tidy } from '@tensorflow/tfjs-core'; -import { ConstraintIdentifier } from '../../constraints'; +import { Activation, getActivation, serializeActivation } from '../../activations'; +import { Constraint, ConstraintIdentifier, getConstraint, serializeConstraint } from '../../constraints'; import { Layer, LayerArgs } from '../../engine/topology'; -import { NotImplementedError } from '../../errors'; -import { InitializerIdentifier } from '../../initializers'; +import { ValueError } from '../../errors'; +import { Initializer, InitializerIdentifier, getInitializer, serializeInitializer } from '../../initializers'; import { ActivationIdentifier } from '../../keras_format/activation_config'; import { Shape } from '../../keras_format/common'; -import { RegularizerIdentifier } from '../../regularizers'; +import { Regularizer, RegularizerIdentifier, getRegularizer, serializeRegularizer } from '../../regularizers'; import { Kwargs } from '../../types'; +import { LayerVariable } from '../../variables'; /** * Analyzes an einsum string to determine the required weight shape. @@ -40,21 +42,173 @@ export function analyzeEinsumString( inputShape: Shape, outputShape: Shape ): [Shape, Shape, Shape] { - throw new NotImplementedError( - `Not implmented yet. Uses ${analyzeSplitString}.`); + const dotReplacedString = equation.replace(/\.\.\./g, '0'); + + // This is the case where no ellipses are present in the string. + let splitString = + dotReplacedString.match(/([a-zA-Z]+),([a-zA-Z]+)->([a-zA-Z]+)/); + if (splitString) { + return analyzeSplitString( + splitString, biasAxes, inputShape, outputShape); + } + + // This is the case where ellipses are present on the left. + splitString = + dotReplacedString.match(/0([a-zA-Z]+),([a-zA-Z]+)->0([a-zA-Z]+)/); + if (splitString) { + return analyzeSplitString( + splitString, biasAxes, inputShape, outputShape, true); + } + + // This is the case where ellipses are present on the right. + splitString = + dotReplacedString.match(/([a-zA-Z]{2,})0,([a-zA-Z]+)->([a-zA-Z]+)0/); + if (splitString) { + return analyzeSplitString( + splitString, biasAxes, inputShape, outputShape); + } + + throw new ValueError( + `Invalid einsum equation '${equation}'. Equations must be in the form ` + + '[X],[Y]->[Z], ...[X],[Y]->...[Z], or [X]...,[Y]->[Z]....' + ); } /** * Analyze an pre-split einsum string to find the weight shape. 
*/ -function analyzeSplitString( - splitString: [string, string, string], +export function analyzeSplitString( + splitString: RegExpMatchArray, biasAxes: string, inputShape: Shape, - outputShape: Shape, - leftElided?: boolean + outputShape: Shape|number, + leftElided = false ): [Shape, Shape, Shape] { - throw new NotImplementedError(`Not implmented yet.`); + const inputSpec = splitString[1]; + const weightSpec = splitString[2]; + const outputSpec = splitString[3]; + const elided = inputShape.length - inputSpec.length; + + const newOutputShape: Shape = Array.isArray(outputShape) ? + outputShape.slice() : [outputShape]; + newOutputShape.unshift(inputShape[0]); + + if (elided > 0 && leftElided) { + for(let i = 1; i < elided; i++) { + // We already inserted the 0th input dimension at dim 0, so we need + // to start at location 1 here. + newOutputShape.splice(1, 0, inputShape[i]); + } + } else if (elided > 0 && !leftElided) { + for(let i = inputShape.length - elided; i < inputShape.length; i++) { + newOutputShape.push(inputShape[i]); + } + } + + const inputSpecArr = Array.from(inputSpec); + const outputSpecArr = Array.from(outputSpec); + let inputDimMap, outputDimMap; + + if (leftElided) { + // If we have beginning dimensions elided, we need to use negative + // indexing to determine where in the input dimension our values are. + inputDimMap = new Map( + inputSpecArr.map((dim, i) => { + // This converts any negative indices to positive ones. + const idx = i + elided - inputShape.length; + const positiveIdx = idx < 0 ? inputShape.length + idx : idx; + return [dim, positiveIdx]; + }) + ); + + // Because we've constructed the full output shape already, we don't need + // to do negative indexing. + outputDimMap = new Map( + outputSpecArr.map((dim, i) => [dim, i + elided]) + ); + } else { + inputDimMap = new Map( + inputSpecArr.map((dim, i) => [dim, i]) + ); + outputDimMap = new Map( + outputSpecArr.map((dim, i) => [dim, i]) + ); + } + + for (const dim of inputSpec) { + const inputShapeAtDim = inputShape[inputDimMap.get(dim)]; + if (outputDimMap.has(dim)) { + const outputShapeAtDim = newOutputShape[outputDimMap.get(dim)]; + if (outputShapeAtDim !== null && outputShapeAtDim !== inputShapeAtDim) { + throw new ValueError( + `Input shape and output shape do not match at shared dimension `+ + `'${dim}'. Input shape is ${inputShapeAtDim}, and output shape ` + + `is ${outputShapeAtDim}.` + ); + } + } + } + + for (const dim of outputSpec) { + if (!inputSpec.includes(dim) && !weightSpec.includes(dim)) { + throw new ValueError( + `Dimension '${dim}' was specified in the output '${outputSpec}' ` + + `but has no corresponding dimension in the input spec ` + + `'${inputSpec}' or weight spec '${weightSpec}'` + ); + } + } + + const weightShape: Shape = []; + for (const dim of weightSpec) { + if (inputDimMap.has(dim)) { + weightShape.push(inputShape[inputDimMap.get(dim)]); + } else if (outputDimMap.has(dim)) { + weightShape.push(newOutputShape[outputDimMap.get(dim)]); + } else { + throw new ValueError( + `Weight dimension '${dim}' did not have a match in either the ` + + `input spec '${inputSpec}' or the output spec '${outputSpec}'. For ` + + `this layer, the weight must be fully specified.` + ); + } + } + + let biasShape: Shape; + if (biasAxes != null) { + const numLeftElided = leftElided ? 
 export declare interface EinsumDenseArgs extends LayerArgs {
@@ -71,7 +225,7 @@ export declare interface EinsumDenseArgs extends LayerArgs {
    * any dimensions represented by ellipses). You can specify `null` for any
    * dimension that is unknown or can be inferred from the input shape.
    */
-  outputShape: Shape;
+  outputShape: Shape|number;

   /**
    * Activation function to use. If you don't specify anything, no activation
@@ -151,11 +305,13 @@ export declare interface EinsumDenseArgs extends LayerArgs {
  * dimension in the `outputShape` is `null`, because the sequence dimension
  * `b` has an unknown shape.
  *
+ * ```js
+ * const layer = new EinsumDense({
+ *   equation: "abc,cd->abd", outputShape: [null, 64], biasAxes: "d"});
+ * const inputTensor = tf.input({shape: [32, 128]});
+ * const outputTensor = layer.call(inputTensor);
+ * console.log(outputTensor); // [null, 32, 64]
+ * ```
  *
  * **Applying a dense layer to a sequence using ellipses**
  *
  * This example shows how to instantiate a layer that applies the same dense
  * operation to every element in a sequence, but uses the ellipsis notation
  * instead of specifying the batch and sequence dimensions.
  *
  * Because we are using ellipsis notation and have specified only one axis, the
@@ -168,32 +324,123 @@
  * layer can handle any number of sequence dimensions - including the case
  * where no sequence dimension exists.
  *
+ * ```js
+ * const layer = new EinsumDense({
+ *   equation: "...x,xy->...y", outputShape: 64, biasAxes: "y"});
+ * const inputTensor = tf.input({shape: [32, 128]});
+ * const outputTensor = layer.call(inputTensor);
+ * console.log(outputTensor); // [null, 32, 64]
+ * ```
  */
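The JSDoc examples above drive the layer symbolically through `tf.input`. Once built, the layer should also run eagerly on concrete tensors; a minimal sketch, assuming the standard `@tensorflow/tfjs` import and the `EinsumDense` class defined below:

```ts
import * as tf from '@tensorflow/tfjs';
import { EinsumDense } from './einsum_dense';

// 'ab,bc->ac' is a plain dense projection: a [32, 64] kernel is created
// on the first call, and a [2, 32] batch comes out as [2, 64].
const layer = new EinsumDense(
    {equation: 'ab,bc->ac', outputShape: [64], biasAxes: null});
const out = layer.apply(tf.ones([2, 32])) as tf.Tensor;
console.log(out.shape);  // [2, 64]
```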
 export class EinsumDense extends Layer {
   /** @nocollapse */
   static readonly className = 'EinsumDense';
+  private readonly equation: string;
+  private readonly biasAxes: string;
+  private readonly partialOutputShape: Shape;
+  private readonly activation: Activation;
+  private readonly kernelInitializer: Initializer;
+  private readonly biasInitializer: Initializer;
+  private readonly kernelRegularizer: Regularizer;
+  private readonly biasRegularizer: Regularizer;
+  private readonly kernelConstraint: Constraint;
+  private readonly biasConstraint: Constraint;
+  private fullOutputShape: Shape;
+  private _kernel: LayerVariable;
+  private _bias: LayerVariable;

   constructor(args: EinsumDenseArgs) {
     super(args);
-    throw new NotImplementedError(`Not implemented yet.`);
+    this.equation = args.equation;
+    this.biasAxes = args.biasAxes;
+    this.partialOutputShape =
+      Array.isArray(args.outputShape) ? args.outputShape : [args.outputShape];
+    this.activation = getActivation(args.activation);
+    this.kernelInitializer = getInitializer(
+      args.kernelInitializer ?? 'glorotUniform');
+    this.biasInitializer = getInitializer(args.biasInitializer ?? 'zeros');
+    this.kernelRegularizer = getRegularizer(args.kernelRegularizer);
+    this.biasRegularizer = getRegularizer(args.biasRegularizer);
+    this.kernelConstraint = getConstraint(args.kernelConstraint);
+    this.biasConstraint = getConstraint(args.biasConstraint);
+  }
+
+  get kernel(): LayerVariable {
+    return this._kernel;
+  }
+
+  get bias(): LayerVariable {
+    return this._bias;
+  }
+
+  override build(inputShape: Shape): void {
+    const [kernelShape, biasShape, fullOutputShape] = analyzeEinsumString(
+      this.equation,
+      this.biasAxes,
+      inputShape,
+      this.partialOutputShape
+    );
+    this.fullOutputShape = fullOutputShape;
+    this._kernel = this.addWeight(
+      'kernel',
+      kernelShape,
+      this.dtype,
+      this.kernelInitializer,
+      this.kernelRegularizer,
+      true,
+      this.kernelConstraint,
+    );
+
+    if (biasShape != null) {
+      this._bias = this.addWeight(
+        'bias',
+        biasShape,
+        this.dtype,
+        this.biasInitializer,
+        this.biasRegularizer,
+        true,
+        this.biasConstraint,
+      );
+    } else {
+      this._bias = null;
+    }
+    super.build(inputShape);
   }

-  override build(inputShape: Shape | Shape[]): void {
-    throw new NotImplementedError(
-      `Not implemented yet. Uses ${analyzeEinsumString}`);
+  override computeOutputShape(_: Shape): Shape {
+    return this.fullOutputShape;
   }

   override getConfig(): serialization.ConfigDict {
-    throw new NotImplementedError(`Not implemented yet.`);
+    const config = {
+      outputShape: this.partialOutputShape,
+      equation: this.equation,
+      activation: serializeActivation(this.activation),
+      biasAxes: this.biasAxes,
+      kernelInitializer: serializeInitializer(this.kernelInitializer),
+      biasInitializer: serializeInitializer(this.biasInitializer),
+      kernelRegularizer: serializeRegularizer(this.kernelRegularizer),
+      biasRegularizer: serializeRegularizer(this.biasRegularizer),
+      kernelConstraint: serializeConstraint(this.kernelConstraint),
+      biasConstraint: serializeConstraint(this.biasConstraint),
+    };
+    const baseConfig = super.getConfig();
+    Object.assign(config, baseConfig);
+    return config;
   }

   override call(inputs: Tensor|Tensor[], kwargs: Kwargs): Tensor|Tensor2D {
-    throw new NotImplementedError(`Not implemented yet.`);
+    return tidy(() => {
+      inputs = Array.isArray(inputs) ? inputs : [inputs];
+      let ret = einsum(this.equation, ...inputs, this.kernel.read());
+      if (this.bias != null) {
+        ret = ret.add(this.bias.read());
+      }
+      if (this.activation != null) {
+        ret = this.activation.apply(ret);
+      }
+      return ret;
+    });
   }
 }
 serialization.registerClass(EinsumDense);
diff --git a/tfjs-layers/src/layers/nlp/einsum_dense_test.ts b/tfjs-layers/src/layers/nlp/einsum_dense_test.ts
new file mode 100644
index 00000000000..f209c0c7c03
--- /dev/null
+++ b/tfjs-layers/src/layers/nlp/einsum_dense_test.ts
@@ -0,0 +1,317 @@
+/**
+ * @license
+ * Copyright 2023 Google LLC.
+ * Licensed under the Apache License, Version 2.0 (the 'License');
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an 'AS IS' BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+
+/**
+ * Unit Tests for TFJS-based EinsumDense Layer.
+ */ + +import { memory, Tensor } from '@tensorflow/tfjs-core'; + +import { analyzeEinsumString, EinsumDense } from './einsum_dense'; +import { Shape } from '../../keras_format/common'; +import { input } from '../../exports'; + +declare interface EinsumDenseTestCaseArgs { + testcaseName: string; + equation: string; + biasAxes: string; + inputShape: Shape; + outputShape: Shape; + expectedWeightShape: Shape; + expectedBiasShape: Shape; + expectedOutputShape: Shape; +} + +describe('EinsumDense', () => { + const combinations: EinsumDenseTestCaseArgs[] = [ + { + testcaseName: '_1d_end_weight', + equation: 'ab,b->a', + biasAxes: null, + inputShape: [null, 32], + outputShape: [], + expectedWeightShape: [32], + expectedBiasShape: null, + expectedOutputShape: [null], + }, + { + testcaseName: '_2d_middle_weight', + equation: 'ab,bc->ac', + biasAxes: null, + inputShape: [null, 32], + outputShape: [64], + expectedWeightShape: [32, 64], + expectedBiasShape: null, + expectedOutputShape: [null, 64], + }, + { + testcaseName: '_3d_bert', + equation: 'abc,cde->abde', + biasAxes: null, + inputShape: [null, 1, 2], + outputShape: [1, 3, 4], + expectedWeightShape: [2, 3, 4], + expectedBiasShape: null, + expectedOutputShape: [null, 1, 3, 4], + }, + { + testcaseName: '_3d_3_bias', + equation: 'abc,cde->abde', + biasAxes: 'e', + inputShape: [null, 1, 2], + outputShape: [1, 3, 4], + expectedWeightShape: [2, 3, 4], + expectedBiasShape: [4], + expectedOutputShape: [null, 1, 3, 4], + }, + { + testcaseName: '_3d_2_bias', + equation: 'abc,cde->abde', + biasAxes: 'd', + inputShape: [null, 1, 2], + outputShape: [1, 3, 4], + expectedWeightShape: [2, 3, 4], + expectedBiasShape: [3, 1], + expectedOutputShape: [null, 1, 3, 4], + }, + { + testcaseName: '_3d_1_3_bias', + equation: 'abc,cde->abde', + biasAxes: 'be', + inputShape: [null, 7, 2], + outputShape: [7, 3, 4], + expectedWeightShape: [2, 3, 4], + expectedBiasShape: [7, 1, 4], + expectedOutputShape: [null, 7, 3, 4], + }, + { + testcaseName: '_3d_bert_projection', + equation: 'BFNH,NHD->BFD', + biasAxes: null, + inputShape: [null, 1, 2, 3], + outputShape: [1, 4], + expectedWeightShape: [2, 3, 4], + expectedBiasShape: null, + expectedOutputShape: [null, 1, 4], + }, + { + testcaseName: '_2d_bert', + equation: 'abc,cd->abd', + biasAxes: null, + inputShape: [null, 1, 2], + outputShape: [1, 4], + expectedWeightShape: [2, 4], + expectedBiasShape: null, + expectedOutputShape: [null, 1, 4], + }, + { + testcaseName: '_embedding_1d', + equation: 'i,d->id', + biasAxes: null, + inputShape: [null], + outputShape: [2], + expectedWeightShape: [2], + expectedBiasShape: null, + expectedOutputShape: [null, 2], + }, + { + testcaseName: '_xlnet_lm', + equation: 'ibd,nd->ibn', + biasAxes: null, + inputShape: [null, null, 1], + outputShape: [null, 2], + expectedWeightShape: [2, 1], + expectedBiasShape: null, + expectedOutputShape: [null, null, 2], + }, + { + testcaseName: '_2d_precast', + equation: '...b,bc->...c', + biasAxes: null, + inputShape: [null, 32], + outputShape: [64], + expectedWeightShape: [32, 64], + expectedBiasShape: null, + expectedOutputShape: [null, 64], + }, + { + testcaseName: '_2d_precast_elided_input_used_in_output', + equation: '...bc,bc->...b', + biasAxes: null, + inputShape: [null, 32, 64], + outputShape: [32], + expectedWeightShape: [32, 64], + expectedBiasShape: null, + expectedOutputShape: [null, 32], + }, + { + testcaseName: '_2d_precast_multiple_elided_dims', + equation: '...b,bc->...c', + biasAxes: null, + inputShape: [null, null, 32], + outputShape: [64], + 
expectedWeightShape: [32, 64], + expectedBiasShape: null, + expectedOutputShape: [null, null, 64], + }, + { + testcaseName: '_3d_precast', + equation: '...c,cde->...de', + biasAxes: null, + inputShape: [null, 1, 2], + outputShape: [3, 4], + expectedWeightShape: [2, 3, 4], + expectedBiasShape: null, + expectedOutputShape: [null, 1, 3, 4], + }, + { + testcaseName: '_3d_precast_3_bias', + equation: '...c,cde->...de', + biasAxes: 'e', + inputShape: [null, 1, 2], + outputShape: [3, 4], + expectedWeightShape: [2, 3, 4], + expectedBiasShape: [4], + expectedOutputShape: [null, 1, 3, 4], + }, + { + testcaseName: '_3d_precast_2_bias', + equation: '...c,cde->...de', + biasAxes: 'd', + inputShape: [null, 1, 2], + outputShape: [3, 4], + expectedWeightShape: [2, 3, 4], + expectedBiasShape: [3, 1], + expectedOutputShape: [null, 1, 3, 4], + }, + { + testcaseName: '_3d_precast_2_3_bias', + equation: '...c,cde->...de', + biasAxes: 'de', + inputShape: [null, 1, 2], + outputShape: [3, 4], + expectedWeightShape: [2, 3, 4], + expectedBiasShape: [3, 4], + expectedOutputShape: [null, 1, 3, 4], + }, + { + testcaseName: '_2d_postcast', + equation: 'bc...,cd->bd...', + biasAxes: null, + inputShape: [null, 1, 2, 3], + outputShape: [4], + expectedWeightShape: [1, 4], + expectedBiasShape: null, + expectedOutputShape: [null, 4, 2, 3], + }, + { + testcaseName: '_3d_postcast', + equation: 'bc...,cde->bde...', + biasAxes: null, + inputShape: [null, 1, 2], + outputShape: [3, 4], + expectedWeightShape: [1, 3, 4], + expectedBiasShape: null, + expectedOutputShape: [null, 3, 4, 2], + }, + { + testcaseName: '_3d_postcast_1_bias', + equation: 'bc...,cde->bde...', + biasAxes: 'd', + inputShape: [null, 1, 2], + outputShape: [3, 4], + expectedWeightShape: [1, 3, 4], + expectedBiasShape: [3, 1, 1], + expectedOutputShape: [null, 3, 4, 2], + }, + { + testcaseName: '_3d_postcast_2_bias', + equation: 'bc...,cde->bde...', + biasAxes: 'e', + inputShape: [null, 1, 2], + outputShape: [3, 4], + expectedWeightShape: [1, 3, 4], + expectedBiasShape: [4, 1], + expectedOutputShape: [null, 3, 4, 2], + }, + { + testcaseName: '_3d_postcast_1_2_bias', + equation: 'bc...,cde->bde...', + biasAxes: 'de', + inputShape: [null, 1, 2], + outputShape: [3, 4], + expectedWeightShape: [1, 3, 4], + expectedBiasShape: [3, 4, 1], + expectedOutputShape: [null, 3, 4, 2], + }, + ]; + + function testWeightShape(combo: EinsumDenseTestCaseArgs) { + it(`${combo.testcaseName} weight shape`, () => { + const [weightShape, biasShape, _] = analyzeEinsumString( + combo.equation, combo.biasAxes, combo.inputShape, combo.outputShape + ); + expect(weightShape).toEqual(combo.expectedWeightShape); + expect(biasShape).toEqual(combo.expectedBiasShape); + }); + } + + function testLayerCreation(combo: EinsumDenseTestCaseArgs) { + it(`${combo.testcaseName} layer creation`, () => { + const nonBatchInputShape = combo.inputShape.slice(1); + const inputTensor = input({shape: nonBatchInputShape}); + + const layer = new EinsumDense({ + equation: combo.equation, + biasAxes: combo.biasAxes, + outputShape: combo.outputShape, + }); + const outputTensor = layer.apply(inputTensor) as Tensor; + + expect(layer.kernel.shape).toEqual(combo.expectedWeightShape); + if (combo.expectedBiasShape === null) { + expect(layer.bias).toBeNull(); + } else { + expect(layer.bias.shape).toEqual(combo.expectedBiasShape); + } + expect(outputTensor.shape).toEqual(combo.expectedOutputShape); + }); + } + + for (const combo of combinations) { + testWeightShape(combo); + testLayerCreation(combo); + } + + it('Does not leak 
memory', () => { + const combo = combinations[0]; + const layer = new EinsumDense({ + equation: combo.equation, + biasAxes: combo.biasAxes, + outputShape: combo.outputShape, + }); + const nonBatchInputShape = combo.inputShape.slice(1); + const inputTensor = input({shape: nonBatchInputShape}); + + const numTensors = memory().numTensors; + layer.apply(inputTensor); + + expect(memory().numTensors).toEqual(numTensors + 1); + }); + + // TODO(pforderique): Test serialization. +}); From 4428cf1d4c4da2ba10e4aac8cbec915081daf77e Mon Sep 17 00:00:00 2001 From: Piero Orderique Date: Wed, 26 Jul 2023 15:08:28 -0700 Subject: [PATCH 12/12] Address comments --- tfjs-layers/src/layers/nlp/einsum_dense.ts | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tfjs-layers/src/layers/nlp/einsum_dense.ts b/tfjs-layers/src/layers/nlp/einsum_dense.ts index 6515313db05..6fd744ef0c5 100644 --- a/tfjs-layers/src/layers/nlp/einsum_dense.ts +++ b/tfjs-layers/src/layers/nlp/einsum_dense.ts @@ -116,7 +116,8 @@ export function analyzeSplitString( inputSpecArr.map((dim, i) => { // This converts any negative indices to positive ones. const idx = i + elided - inputShape.length; - const positiveIdx = idx < 0 ? inputShape.length + idx : idx; + const positiveIdx = + ((idx % inputShape.length) + inputShape.length) % inputShape.length; return [dim, positiveIdx]; }) ); @@ -201,7 +202,7 @@ export function analyzeSplitString( ); if (!leftElided) { - for (let _ = 0; _ < elided; _++) { + for (let i = 0; i < elided; i++) { biasShape.push(1); } }
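The double-modulo introduced in this last patch is the standard guard for JavaScript's remainder operator, which keeps the sign of the dividend. A standalone illustration with hypothetical values:

```ts
const n = 3;     // stands in for inputShape.length
const idx = -1;  // i + elided - inputShape.length can go negative

console.log(idx % n);              // -1: '%' follows the dividend's sign
console.log(((idx % n) + n) % n);  //  2: always normalized into [0, n)
```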