From a0115ea52a949197606cf4c894c5fa74a09e61be Mon Sep 17 00:00:00 2001
From: Piero F Orderique <45519489+pforderique@users.noreply.github.com>
Date: Mon, 28 Aug 2023 11:06:31 -0700
Subject: [PATCH] Add spec for GPT2CausalLM and dependencies (#7897)

Adds the specs for GPT2CausalLM, GPT2CausalLMPreprocessor, GenerativeTask,
Task, and PipelineModel.
---
 tfjs-layers/src/layers/nlp/models/backbone.ts |   4 +-
 .../src/layers/nlp/models/generative_task.ts  | 116 +++++++++
 .../layers/nlp/models/gpt2/gpt2_backbone.ts   |   4 +-
 .../layers/nlp/models/gpt2/gpt2_causal_lm.ts  | 224 ++++++++++++++++++
 .../gpt2/gpt2_causal_lm_preprocessor.ts       | 119 ++++++++++
 .../nlp/models/gpt2/gpt2_preprocessor.ts      |  36 ++-
 .../nlp/models/gpt2/gpt2_preprocessor_test.ts |  12 +-
 .../src/layers/nlp/models/preprocessor.ts     |   2 +-
 tfjs-layers/src/layers/nlp/models/task.ts     | 126 ++++++++++
 tfjs-layers/src/layers/nlp/utils.ts           | 118 ++++++++-
 10 files changed, 730 insertions(+), 31 deletions(-)
 create mode 100644 tfjs-layers/src/layers/nlp/models/generative_task.ts
 create mode 100644 tfjs-layers/src/layers/nlp/models/gpt2/gpt2_causal_lm.ts
 create mode 100644 tfjs-layers/src/layers/nlp/models/gpt2/gpt2_causal_lm_preprocessor.ts
 create mode 100644 tfjs-layers/src/layers/nlp/models/task.ts

diff --git a/tfjs-layers/src/layers/nlp/models/backbone.ts b/tfjs-layers/src/layers/nlp/models/backbone.ts
index b6d0811b89b..72ae1c54687 100644
--- a/tfjs-layers/src/layers/nlp/models/backbone.ts
+++ b/tfjs-layers/src/layers/nlp/models/backbone.ts
@@ -25,7 +25,7 @@ import { serialization } from '@tensorflow/tfjs-core';
 import { ContainerArgs } from '../../../engine/container';
 import { LayersModel } from '../../../engine/training';
 import { NotImplementedError } from '../../../errors';
-import { Layer } from '../../../exports_layers';
+import { Embedding } from '../../embeddings';
 
 export class Backbone extends LayersModel {
   /** @nocollapse */
@@ -38,7 +38,7 @@ export class Backbone extends LayersModel {
   /**
    * A `tf.layers.embedding` instance for embedding token ids.
    */
-  get tokenEmbedding(): Layer {
+  get tokenEmbedding(): Embedding {
     throw new NotImplementedError();
   }
 
diff --git a/tfjs-layers/src/layers/nlp/models/generative_task.ts b/tfjs-layers/src/layers/nlp/models/generative_task.ts
new file mode 100644
index 00000000000..f83116bc371
--- /dev/null
+++ b/tfjs-layers/src/layers/nlp/models/generative_task.ts
@@ -0,0 +1,116 @@
+/**
+ * @license
+ * Copyright 2023 Google LLC.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+
+/**
+ * Base class for Generative Task models.
+ */
+
+/* Original source: keras_nlp/models/generative_task.py */
+import { NamedTensorMap, Tensor } from '@tensorflow/tfjs-core';
+
+import { NotImplementedError } from '../../../errors';
+import { ModelCompileArgs } from '../../../engine/training';
+
+import { Task } from './task';
+
+export type GenerateFn =
+  (inputs: NamedTensorMap, endTokenId?: number) => NamedTensorMap;
+
+/**
+ * Base class for Generative Task models.
+ */
+export class GenerativeTask extends Task {
+  /** @nocollapse */
+  static override className = 'GenerativeTask';
+
+  protected generateFunction: GenerateFn;
+
+  override compile(args: ModelCompileArgs): void {
+    throw new NotImplementedError();
+  }
+
+  /**
+   * Run the generation on a single batch of input.
+   */
+  generateStep(
+    inputs: NamedTensorMap,
+    endTokenId: number
+  ): NamedTensorMap {
+    throw new NotImplementedError();
+  }
+
+  /**
+   * Create or return the compiled generation function.
+   */
+  makeGenerateFunction(): GenerateFn {
+    throw new NotImplementedError();
+  }
+
+  /**
+   * Normalize user input to the generate function.
+   *
+   * This function converts all inputs to tensors, adds a batch dimension if
+   * necessary, and returns an iterable "dataset like" object.
+   */
+  protected normalizeGenerateInputs(inputs: Tensor): [Tensor, boolean] {
+    throw new NotImplementedError();
+  }
+
+  /**
+   * Normalize user output from the generate function.
+   *
+   * This function converts all output to arrays (for integer output) or to
+   * strings (for string output). If a batch dimension was added to the
+   * input, it is removed from the output (so generate can be string in,
+   * string out).
+   */
+  protected normalizeGenerateOutputs(
+    outputs: Tensor,
+    inputIsScalar: boolean
+  ): Tensor {
+    throw new NotImplementedError();
+  }
+
+  /**
+   * Generate text given prompt `inputs`.
+   *
+   * This method generates text based on given `inputs`. The sampling method
+   * used for generation can be set via the `compile()` method.
+   *
+   * `inputs` will be handled as a single batch.
+   *
+   * If a `preprocessor` is attached to the model, `inputs` will be
+   * preprocessed inside the `generate()` function and should match the
+   * structure expected by the `preprocessor` layer (usually raw strings).
+   * If a `preprocessor` is not attached, inputs should match the structure
+   * expected by the `backbone`. See the `GPT2CausalLM` docstring for a
+   * demonstration of each.
+   *
+   * @param inputs tensor data. If a `preprocessor` is attached to the model,
+   * `inputs` should match the structure expected by the `preprocessor`
+   * layer. If a `preprocessor` is not attached, `inputs` should match the
+   * structure expected by the `backbone` model.
+   * @param maxLength Integer. The max length of the generated sequence.
+   * Will default to the max configured `sequenceLength` of the
+   * `preprocessor`. If `preprocessor` is `null`, `inputs` should be padded
+   * to the desired maximum length and this argument will be ignored.
+   */
+  generate(inputs: Tensor, maxLength?: number) {
+    throw new NotImplementedError();
+  }
+}
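A note for readers on the `compile()`/`generateStep()` contract above: a sampler consumes next-token scores produced by the model. The sketch below shows the simplest possible sampler, greedy argmax, as standalone tfjs code. It is illustrative only and not part of this patch; `greedyNextToken` is a hypothetical helper name.

```ts
import * as tf from '@tensorflow/tfjs-core';

// Greedily pick the highest-scoring next token for each sequence.
// `logits` has shape [batchSize, vocabSize].
function greedyNextToken(logits: tf.Tensor2D): tf.Tensor {
  return tf.argMax(logits, 1);  // shape [batchSize], dtype int32
}

// Example: a batch of 2 prompts over a 4-token vocabulary.
const logits = tf.tensor2d([[0.1, 2.0, 0.3, 0.4],
                            [1.5, 0.2, 0.1, 0.9]]);
greedyNextToken(logits).print();  // [1, 0]
```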
diff --git a/tfjs-layers/src/layers/nlp/models/gpt2/gpt2_backbone.ts b/tfjs-layers/src/layers/nlp/models/gpt2/gpt2_backbone.ts
index 266427f61e0..971f8868cc6 100644
--- a/tfjs-layers/src/layers/nlp/models/gpt2/gpt2_backbone.ts
+++ b/tfjs-layers/src/layers/nlp/models/gpt2/gpt2_backbone.ts
@@ -215,8 +215,8 @@ export class GPT2Backbone extends Backbone {
     return config;
   }
 
-  override get tokenEmbedding() {
-    return this.getLayer('token_embedding');
+  override get tokenEmbedding(): Embedding {
+    return this.getLayer('token_embedding') as Embedding;
   }
 }
 serialization.registerClass(GPT2Backbone);
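Typing `tokenEmbedding` as `Embedding` rather than a bare `Layer` is what enables the weight tying used by `ReverseEmbedding` in the causal LM below: hidden states are projected back onto the vocabulary with the transposed token-embedding matrix. A minimal sketch of that projection, assuming only tfjs-core ops (the actual `ReverseEmbedding.call()` is still a stub):

```ts
import * as tf from '@tensorflow/tfjs-core';

// Reverse-embedding projection (weight tying).
// hidden: [batchSize, hiddenDim]; embeddings: [vocabSize, hiddenDim].
function reverseEmbed(
    hidden: tf.Tensor2D, embeddings: tf.Tensor2D): tf.Tensor2D {
  // transposeB=true multiplies by embeddings^T: [batchSize, vocabSize].
  return tf.matMul(hidden, embeddings, false, true);
}
```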
diff --git a/tfjs-layers/src/layers/nlp/models/gpt2/gpt2_causal_lm.ts b/tfjs-layers/src/layers/nlp/models/gpt2/gpt2_causal_lm.ts
new file mode 100644
index 00000000000..b859e16e27f
--- /dev/null
+++ b/tfjs-layers/src/layers/nlp/models/gpt2/gpt2_causal_lm.ts
@@ -0,0 +1,224 @@
+/**
+ * @license
+ * Copyright 2023 Google LLC.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+
+/**
+ * GPT2 Causal LM (Language Model).
+ */
+
+/* Original source: keras-nlp/models/gpt2/gpt2_causal_lm.py */
+import { NamedTensorMap, Tensor, serialization } from '@tensorflow/tfjs-core';
+
+import { GPT2Preprocessor } from './gpt2_preprocessor';
+import { NotImplementedError } from '../../../../errors';
+import { Layer } from '../../../../exports_layers';
+import { LayerArgs } from '../../../../engine/topology';
+import { Embedding } from '../../../../layers/embeddings';
+import { Shape } from '../../../../keras_format/common';
+import { GenerativeTask } from '../generative_task';
+import { GPT2Backbone } from './gpt2_backbone';
+import { PipelineModelArgs } from '../../utils';
+import { Kwargs } from '../../../../types';
+
+declare interface ReverseEmbeddingArgs extends LayerArgs {
+  embedding: Embedding;
+}
+
+class ReverseEmbedding extends Layer {
+  protected embedding: Embedding;
+
+  constructor(args: ReverseEmbeddingArgs) {
+    super(args);
+    this.embedding = args.embedding;
+  }
+
+  override call(inputs: Tensor|Tensor[], kwargs: Kwargs): Tensor|Tensor[] {
+    throw new NotImplementedError();
+  }
+
+  override computeOutputShape(inputShape: Shape|Shape[]): Shape|Shape[] {
+    throw new NotImplementedError();
+  }
+}
+
+export declare interface GPT2CausalLMArgs extends PipelineModelArgs {
+  /**
+   * A `GPT2Backbone` instance.
+   */
+  backbone: GPT2Backbone;
+
+  /**
+   * Optional `GPT2CausalLMPreprocessor`.
+   * If `null`, this model will not apply preprocessing, and inputs should
+   * be preprocessed before calling the model.
+   */
+  preprocessor?: GPT2Preprocessor;
+}
+
+/**
+ * An end-to-end GPT2 model for causal language modeling.
+ *
+ * A causal language model (LM) predicts the next token based on previous
+ * tokens. This task setup can be used to train the model unsupervised on
+ * plain text input, or to autoregressively generate plain text similar to
+ * the data used for training. This task can be used for pre-training or
+ * fine-tuning a GPT-2 model, simply by calling `fit()`.
+ *
+ * This model has a `generate()` method, which generates text based on a
+ * prompt. The generation strategy used is controlled by an additional
+ * `sampler` argument on `compile()`. By default, the top k results will be
+ * returned.
+ *
+ * This model can optionally be configured with a `preprocessor` layer, in
+ * which case it will automatically apply preprocessing to string inputs
+ * during `fit()`, `predict()`, `evaluate()` and `generate()`. This is done
+ * by default when creating the model with `fromPreset()`.
+ *
+ * Disclaimer: Pre-trained models are provided on an "as is" basis, without
+ * warranties or conditions of any kind. The underlying model is provided by
+ * a third party and subject to a separate license, available
+ * [here](https://github.com/openai/gpt-2).
+ *
+ * Use `generate()` to do text generation.
+ * ```js
+ * const gpt2LM = GPT2CausalLM.fromPreset('gpt2_base_en');
+ * // Generate up to 30 tokens (`maxLength`).
+ * gpt2LM.generate("I want to say", 30);
+ * // Generate with batched prompts.
+ * gpt2LM.generate(["This is a", "Where are you"], 30);
+ * ```
+ *
+ * Use `generate()` without preprocessing.
+ * ```js
+ * // Prompt the model with `5338, 318` (the token ids for `"Who is"`).
+ * // Use `"paddingMask"` to indicate values that should not be overridden.
+ * const prompt = {
+ *   tokenIds: tf.tensor([[5338, 318, 0, 0, 0], [5338, 318, 0, 0, 0]]),
+ *   paddingMask: tf.tensor([[1, 1, 0, 0, 0], [1, 1, 0, 0, 0]]),
+ * };
+ * const gpt2LM = GPT2CausalLM.fromPreset('gpt2_base_en', null);
+ * gpt2LM.generate(prompt);
+ * ```
+ *
+ * Call `fit()` on a single batch.
+ * ```js
+ * const features = ['The quick brown fox jumped.', 'I forgot my homework.'];
+ * const gpt2LM = GPT2CausalLM.fromPreset('gpt2_base_en');
+ * gpt2LM.fit(features, {batchSize: 2});
+ * ```
+ *
+ * Call `fit()` without preprocessing.
+ * ```js
+ * const x = {
+ *   tokenIds: tf.tensor([[50256, 1, 2, 3, 4], [50256, 1, 2, 3, 4]]),
+ *   paddingMask: tf.tensor([[1, 1, 1, 1, 1], [1, 1, 1, 1, 1]]),
+ * };
+ * const y = tf.tensor([[1, 2, 3, 4, 50256], [1, 2, 3, 4, 50256]]);
+ * const sw = tf.tensor([[1, 1, 1, 1, 1], [1, 1, 1, 1, 1]]);
+ * const gpt2LM = GPT2CausalLM.fromPreset('gpt2_base_en', null);
+ * gpt2LM.fit(x, y, {sampleWeight: sw, batchSize: 2});
+ * ```
+ *
+ * Custom backbone and vocabulary.
+ * ```js
+ * const features = ["a quick fox.", "a fox quick."];
+ * const vocab = {"<|endoftext|>": 0, "a": 4, "Ġquick": 5, "Ġfox": 6};
+ * const merges = [
+ *   "Ġ q", "u i", "c k", "ui ck", "Ġq uick", "Ġ f", "o x", "Ġf ox"
+ * ];
+ * const tokenizer = new GPT2Tokenizer({vocabulary: vocab, merges});
+ * const preprocessor = new GPT2CausalLMPreprocessor({
+ *   tokenizer,
+ *   sequenceLength: 128,
+ * });
+ * const backbone = new GPT2Backbone({
+ *   vocabularySize: 30552,
+ *   numLayers: 4,
+ *   numHeads: 4,
+ *   hiddenDim: 256,
+ *   intermediateDim: 512,
+ *   maxSequenceLength: 128,
+ * });
+ * const gpt2LM = new GPT2CausalLM({backbone, preprocessor});
+ * gpt2LM.fit(features, {batchSize: 2});
+ * ```
+ */
+export class GPT2CausalLM extends GenerativeTask {
+  /** @nocollapse */
+  static override className = 'GPT2CausalLM';
+
+  constructor(args: GPT2CausalLMArgs) {
+    super(args);
+    throw new NotImplementedError(`Uses ${ReverseEmbedding}.`);
+  }
+
+  static override presets<T extends serialization.Serializable>(
+    cls: serialization.SerializableConstructor<T>
+  ): {} {
+    throw new NotImplementedError();
+  }
+
+  /**
+   * Forward pass of `GPT2CausalLM` with cache.
+   *
+   * `callWithCache` adds an additional forward pass for the model for
+   * autoregressive inference. Unlike calling the model directly, this
+   * method allows caching previous key/value Tensors in the multi-head
+   * attention layers, and avoids recomputing the outputs of seen tokens.
+   *
+   * @param tokenIds a dense int Tensor with shape `[batchSize, maxLength]`.
+   * @param cache a dense float Tensor, the cache of key and value.
+   * @param cacheUpdateIndex Integer. The index of current inputs in the
+   * whole sequence.
+   * @returns [logits, hiddenStates, cache], where `logits` is the
+   * language model logits for the input tokenIds, `hiddenStates` is
+   * the final hidden representation of the input tokens, and `cache` is
+   * the decoding cache.
+   */
+  callWithCache(
+    tokenIds: Tensor,
+    cache: Tensor,
+    cacheUpdateIndex: number
+  ): [Tensor, Tensor, Tensor] {
+    throw new NotImplementedError();
+  }
+
+  /**
+   * Build an empty cache for use with `callWithCache()`.
+   */
+  private buildCache(tokenIds: Tensor): [Tensor, Tensor] {
+    throw new NotImplementedError();
+  }
+
+  /**
+   * A compilable generation function for a single batch of inputs.
+   *
+   * This function represents the inner generation function for a single
+   * batch of inputs.
+   *
+   * @param inputs An object with two keys `tokenIds` and `paddingMask` and
+   * batched tensor values.
+   * @param endTokenId The id of the end token to stop on. If all
+   * sequences have produced a new `endTokenId`, generation will stop.
+   */
+  override generateStep(
+    inputs: NamedTensorMap,
+    endTokenId: number
+  ): NamedTensorMap {
+    throw new NotImplementedError(`Uses ${this.buildCache}`);
+  }
+}
+serialization.registerClass(GPT2CausalLM);
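For context on `buildCache()`/`callWithCache()`: the point of the cache is that each decoding step only computes attention keys/values for the newly generated position. The loop below sketches the intended calling convention; it is an assumption based on the docstrings above (both methods currently throw `NotImplementedError`), and the sampling/write-back step is elided.

```ts
import * as tf from '@tensorflow/tfjs-core';
import { GPT2CausalLM } from './gpt2_causal_lm';

// Hypothetical decode loop reusing the key/value cache.
function decodeSteps(
    lm: GPT2CausalLM, tokenIds: tf.Tensor, cache: tf.Tensor,
    startIndex: number, numSteps: number): tf.Tensor {
  for (let i = startIndex; i < startIndex + numSteps; i++) {
    const [logits, hidden, nextCache] = lm.callWithCache(tokenIds, cache, i);
    // A sampler would pick the next token from `logits` here and write it
    // into `tokenIds` at position i + 1 (e.g. with `sliceUpdate` in utils).
    tf.dispose([logits, hidden]);
    cache = nextCache;
  }
  return cache;
}
```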
diff --git a/tfjs-layers/src/layers/nlp/models/gpt2/gpt2_causal_lm_preprocessor.ts b/tfjs-layers/src/layers/nlp/models/gpt2/gpt2_causal_lm_preprocessor.ts
new file mode 100644
index 00000000000..5b0df64ea78
--- /dev/null
+++ b/tfjs-layers/src/layers/nlp/models/gpt2/gpt2_causal_lm_preprocessor.ts
@@ -0,0 +1,119 @@
+/**
+ * @license
+ * Copyright 2023 Google LLC.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+
+/**
+ * GPT2 Causal LM preprocessor layer.
+ */
+
+/* Original source: keras-nlp/models/gpt2/gpt2_causal_lm_preprocessor.py */
+import { NamedTensorMap, Tensor, serialization } from '@tensorflow/tfjs-core';
+
+import { GPT2Preprocessor, GPT2PreprocessorOptions, packXYSampleWeight } from './gpt2_preprocessor';
+import { NotImplementedError } from '../../../../errors';
+
+/**
+ * GPT2 Causal LM preprocessor.
+ *
+ * This preprocessing layer is meant for use with `GPT2CausalLM`. By
+ * default, it will take in batches of strings, and return outputs in a
+ * `[x, y, sampleWeight]` format, where the `y` label is the next token id
+ * in the `x` sequence.
+ *
+ * For use with generation, the layer also exposes two methods
+ * `generatePreprocess()` and `generatePostprocess()`. When this
+ * preprocessor is attached to a `GPT2CausalLM` instance, these methods
+ * will be called implicitly in `generate()`. They can also be called
+ * standalone (e.g. to precompute preprocessing inputs for generation in a
+ * separate process).
+ *
+ * Examples:
+ * ```js
+ * // Load the preprocessor from a preset.
+ * const preprocessor = GPT2CausalLMPreprocessor.fromPreset('gpt2_base_en');
+ *
+ * // Tokenize and pack a single sentence.
+ * const sentence = tf.scalar('League of legends');
+ * preprocessor.apply(sentence);
+ * // Same output.
+ * preprocessor.apply('League of legends');
+ *
+ * // Tokenize a batch of sentences.
+ * const sentences = tf.tensor(['Taco tuesday', 'Fish taco please!']);
+ * preprocessor.apply(sentences);
+ * // Same output.
+ * preprocessor.apply(['Taco tuesday', 'Fish taco please!']);
+ * ```
+ */
+export class GPT2CausalLMPreprocessor extends GPT2Preprocessor {
+  /** @nocollapse */
+  static override className = 'GPT2CausalLMPreprocessor';
+
+  override call(
+    inputs: Tensor|Tensor[],
+    kwargs: GPT2PreprocessorOptions
+  ): Tensor|Tensor[] {
+    const output = this.callAndPackArgs(inputs, kwargs);
+    if (kwargs.y) {
+      return (output as [NamedTensorMap, Tensor])[0]['tokenIds'];
+    }
+    return (output as NamedTensorMap)['tokenIds'];
+  }
+
+  /**
+   * Calls the layer and returns extra information like the paddingMask used
+   * to pack the sequence, the label data, and the sample weights used.
+   */
+  override callAndPackArgs(
+    inputs: Tensor|Tensor[],
+    kwargs: GPT2PreprocessorOptions
+  ):
+    NamedTensorMap
+    | [NamedTensorMap, Tensor]
+    | [NamedTensorMap, Tensor, Tensor] {
+    throw new NotImplementedError(`Uses ${packXYSampleWeight}`);
+  }
+
+  /**
+   * Convert strings to integer token inputs for generation.
+   *
+   * Similar to calling the layer for training, this method takes in strings
+   * or string tensors, tokenizes and packs the input, and computes a
+   * padding mask masking all inputs not filled in with a padded value.
+   *
+   * Unlike calling the layer for training, this method does not compute
+   * labels and will never append a `tokenizer.endTokenId` to the end of
+   * the sequence (as generation is expected to continue at the end of the
+   * input prompt).
+   */
+  generatePreprocess(x: Tensor, sequenceLength?: number): NamedTensorMap {
+    throw new NotImplementedError();
+  }
+
+  /**
+   * Convert integer token output to strings for generation.
+   *
+   * This method reverses `generatePreprocess()`, by first removing all
+   * padding and start/end tokens, and then converting the integer sequence
+   * back to a string.
+   */
+  generatePostprocess(x: NamedTensorMap): Tensor {
+    throw new NotImplementedError();
+  }
+}
+serialization.registerClass(GPT2CausalLMPreprocessor);
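The `[x, y, sampleWeight]` packing convention that both preprocessors share is small enough to show standalone. This mirrors the exported `packXYSampleWeight()` in `gpt2_preprocessor.ts` below; `packSamples` is an illustrative name, not a new API:

```ts
import { Tensor } from '@tensorflow/tfjs-core';

// Return x alone, [x, y], or [x, y, sampleWeight], depending on which
// optional pieces the caller provided.
function packSamples(x: Tensor, y?: Tensor, sampleWeight?: Tensor):
    Tensor|[Tensor, Tensor]|[Tensor, Tensor, Tensor] {
  if (y === undefined) {
    return x;
  }
  if (sampleWeight === undefined) {
    return [x, y];
  }
  return [x, y, sampleWeight];
}
```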
diff --git a/tfjs-layers/src/layers/nlp/models/gpt2/gpt2_preprocessor.ts b/tfjs-layers/src/layers/nlp/models/gpt2/gpt2_preprocessor.ts
index a1a8d432c24..87e1a926210 100644
--- a/tfjs-layers/src/layers/nlp/models/gpt2/gpt2_preprocessor.ts
+++ b/tfjs-layers/src/layers/nlp/models/gpt2/gpt2_preprocessor.ts
@@ -20,7 +20,7 @@
  */
 
 /* Original source: keras-nlp/models/gpt2/gpt2_preprocessor.py */
-import { Tensor, Tensor2D, serialization, tidy } from '@tensorflow/tfjs-core';
+import { NamedTensorMap, Tensor, Tensor2D, serialization, tidy } from '@tensorflow/tfjs-core';
 
 import { LayerArgs } from '../../../../engine/topology';
 import { Preprocessor } from '../preprocessor';
@@ -72,16 +72,11 @@ export declare interface GPT2PreprocessorOptions {
   sequenceLength?: number;
 }
 
-export declare interface PreprocessorOutputs {
-  tokenIds: Tensor2D;
-  paddingMask: Tensor2D;
-}
-
-function packXYSampleWeight(
-  x: PreprocessorOutputs, y?: Tensor, sampleWeight?: Tensor):
-  PreprocessorOutputs
-  | [PreprocessorOutputs, Tensor]
-  | [PreprocessorOutputs, Tensor, Tensor] {
+export function packXYSampleWeight(
+  x: NamedTensorMap, y?: Tensor, sampleWeight?: Tensor):
+  NamedTensorMap
+  | [NamedTensorMap, Tensor]
+  | [NamedTensorMap, Tensor, Tensor] {
 
   if (y === undefined) {
     return x;
@@ -129,10 +124,13 @@
  * ```
  */
 export class GPT2Preprocessor extends Preprocessor {
-  private readonly sequenceLength: number;
-  private readonly addStartToken: boolean;
-  private readonly addEndToken: boolean;
-  private readonly packer: StartEndPacker;
+  /** @nocollapse */
+  static override className = 'GPT2Preprocessor';
+
+  protected readonly sequenceLength: number;
+  protected readonly addStartToken: boolean;
+  protected readonly addEndToken: boolean;
+  protected readonly packer: StartEndPacker;
 
   constructor(args: GPT2PreprocessorArgs) {
     super(args);
@@ -169,7 +167,7 @@ export class GPT2Preprocessor extends Preprocessor {
   private callAndReturnPaddingMask(
     inputs: Tensor|Tensor[],
     kwargs: GPT2PreprocessorOptions
-  ): PreprocessorOutputs {
+  ): NamedTensorMap {
     return tidy(() => {
       if (inputs instanceof Array) {
         if (inputs.length !== 1) {
@@ -205,9 +203,9 @@
    * pack the sequence, the label data, and the sample weights used.
    */
   callAndPackArgs(inputs: Tensor|Tensor[], kwargs: GPT2PreprocessorOptions):
-    PreprocessorOutputs
-    | [PreprocessorOutputs, Tensor]
-    | [PreprocessorOutputs, Tensor, Tensor] {
+    NamedTensorMap
+    | [NamedTensorMap, Tensor]
+    | [NamedTensorMap, Tensor, Tensor] {
     const x = this.callAndReturnPaddingMask(inputs, kwargs);
     return packXYSampleWeight(x, kwargs.y, kwargs.sampleWeight);
   }
diff --git a/tfjs-layers/src/layers/nlp/models/gpt2/gpt2_preprocessor_test.ts b/tfjs-layers/src/layers/nlp/models/gpt2/gpt2_preprocessor_test.ts
index edb0d06af86..b464c386cfb 100644
--- a/tfjs-layers/src/layers/nlp/models/gpt2/gpt2_preprocessor_test.ts
+++ b/tfjs-layers/src/layers/nlp/models/gpt2/gpt2_preprocessor_test.ts
@@ -19,9 +19,9 @@
  * Unit Tests for GPT2Preprocessor.
  */
 
-import { Tensor, memory, serialization, tensor, tensor2d } from '@tensorflow/tfjs-core';
+import { NamedTensorMap, Tensor, memory, serialization, tensor, tensor2d } from '@tensorflow/tfjs-core';
 
-import { GPT2Preprocessor, PreprocessorOutputs } from './gpt2_preprocessor';
+import { GPT2Preprocessor } from './gpt2_preprocessor';
 import { GPT2Tokenizer } from './gpt2_tokenizer';
 import { expectTensorsClose } from '../../../../utils/test_utils';
 
@@ -55,7 +55,7 @@ describe('GPT2Preprocessor', () => {
     const inputData = tensor(['airplane at airport']);
 
     const output =
-      preprocessor.callAndPackArgs(inputData, {}) as PreprocessorOutputs;
+      preprocessor.callAndPackArgs(inputData, {}) as NamedTensorMap;
 
     expectTensorsClose(output.tokenIds, tensor2d([[6, 1, 3, 4, 2, 5, 6, 0]]));
     expectTensorsClose(
@@ -77,7 +77,7 @@ describe('GPT2Preprocessor', () => {
     };
 
     const output =
-      preprocessor.callAndPackArgs(inputData, {}) as PreprocessorOutputs;
+      preprocessor.callAndPackArgs(inputData, {}) as NamedTensorMap;
 
     expectTensorsClose(output.tokenIds, expectedOutput.tokenIds);
     expectTensorsClose(output.paddingMask, expectedOutput.paddingMask);
@@ -95,7 +95,7 @@ describe('GPT2Preprocessor', () => {
 
     const output = preprocessor.callAndPackArgs(
       inputData, {y: yIn, sampleWeight: swIn}
-    ) as [PreprocessorOutputs, Tensor, Tensor];
+    ) as [NamedTensorMap, Tensor, Tensor];
 
     expectTensorsClose(output[0].tokenIds, expectedX.tokenIds);
     expectTensorsClose(output[0].paddingMask, expectedX.paddingMask);
@@ -108,7 +108,7 @@ describe('GPT2Preprocessor', () => {
 
     const output = preprocessor.callAndPackArgs(
       inputData, {sequenceLength: 4}
-    ) as PreprocessorOutputs;
+    ) as NamedTensorMap;
 
     expectTensorsClose(output.tokenIds, tensor2d([[6, 1, 3, 6]]));
   });
diff --git a/tfjs-layers/src/layers/nlp/models/preprocessor.ts b/tfjs-layers/src/layers/nlp/models/preprocessor.ts
index 5ef9c0787db..7b8e0937ec1 100644
--- a/tfjs-layers/src/layers/nlp/models/preprocessor.ts
+++ b/tfjs-layers/src/layers/nlp/models/preprocessor.ts
@@ -28,7 +28,7 @@ import { deserializeKerasObject, serializeKerasObject } from '../../../utils/gen
  */
 export class Preprocessor extends Layer {
   /** @nocollapse */
-  static readonly className = 'Preprocessor';
+  static className = 'Preprocessor';
 
   private _tokenizer: Tokenizer;
 
diff --git a/tfjs-layers/src/layers/nlp/models/task.ts b/tfjs-layers/src/layers/nlp/models/task.ts
new file mode 100644
index 00000000000..d601c93290d
--- /dev/null
+++ b/tfjs-layers/src/layers/nlp/models/task.ts
@@ -0,0 +1,126 @@
+/**
+ * @license
+ * Copyright 2023 Google LLC.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+
+/**
+ * Base class for Task models.
+ */
+
+/* Original source: keras_nlp/models/task.py */
+import { Tensor, serialization } from '@tensorflow/tfjs-core';
+
+import { NotImplementedError } from '../../../errors';
+
+import { PipelineModel, PipelineModelArgs } from '../utils';
+import { Backbone } from './backbone';
+import { Preprocessor } from './preprocessor';
+import { ModelCompileArgs } from '../../../engine/training';
+import { LossOrMetricFn } from '../../../types';
+
+export class Task extends PipelineModel {
+  /** @nocollapse */
+  static override className = 'Task';
+
+  protected _backbone: Backbone;
+  protected _preprocessor: Preprocessor;
+
+  constructor(args: PipelineModelArgs) {
+    super(args);
+  }
+
+  private checkForLossMismatch(
+    loss: string|string[]|{[outputName: string]: string}|LossOrMetricFn|
+    LossOrMetricFn[]|{[outputName: string]: LossOrMetricFn}
+  ) {
+    throw new NotImplementedError();
+  }
+
+  override compile(args: ModelCompileArgs): void {
+    this.checkForLossMismatch(args.loss);
+    super.compile(args);
+  }
+
+  override preprocessSamples(x: Tensor, y?: Tensor, sampleWeight?: Tensor):
+    Tensor | [Tensor, Tensor] | [Tensor, Tensor, Tensor] {
+    throw new NotImplementedError();
+  }
+
+  /**
+   * A `LayersModel` instance providing the backbone submodel.
+   */
+  get backbone(): Backbone {
+    return this._backbone;
+  }
+
+  set backbone(value: Backbone) {
+    this._backbone = value;
+  }
+
+  /**
+   * A `Preprocessor` layer instance used to preprocess inputs.
+   */
+  get preprocessor(): Preprocessor {
+    return this._preprocessor;
+  }
+
+  set preprocessor(value: Preprocessor) {
+    this.includePreprocessing = value != null;
+    this._preprocessor = value;
+  }
+
+  override getConfig(): serialization.ConfigDict {
+    // Don't chain to super here. The default `getConfig()` for functional
+    // models is nested and cannot be passed to our Task constructors.
+    throw new NotImplementedError();
+  }
+
+  static override fromConfig<T extends serialization.Serializable>(
+    cls: serialization.SerializableConstructor<T>,
+    config: serialization.ConfigDict
+  ): T {
+    throw new NotImplementedError();
+  }
+
+  static backboneCls<T extends serialization.Serializable>(
+    cls: serialization.SerializableConstructor<T>
+  ): serialization.SerializableConstructor<Backbone> {
+    return null;
+  }
+
+  static preprocessorCls<T extends serialization.Serializable>(
+    cls: serialization.SerializableConstructor<T>
+  ): serialization.SerializableConstructor<Preprocessor> {
+    return null;
+  }
+
+  static presets<T extends serialization.Serializable>(
+    cls: serialization.SerializableConstructor<T>
+  ) {
+    return {};
+  }
+
+  getLayers() {
+    throw new NotImplementedError();
+  }
+
+  override summary(
+    lineLength?: number, positions?: number[],
+    printFn:
+      // tslint:disable-next-line:no-any
+      (message?: any, ...optionalParams: any[]) => void = console.log) {
+    throw new NotImplementedError();
+  }
+}
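One subtlety in `Task` above: assigning the `preprocessor` property also toggles the inherited `includePreprocessing` flag, so attaching `null` silently disables preprocessing. A standalone illustration of that contract (`TaskLike` is a stand-in, not part of the patch):

```ts
// Setting a preprocessor enables preprocessing; setting null disables it.
class TaskLike {
  includePreprocessing = true;
  private _preprocessor: object|null = null;

  get preprocessor(): object|null {
    return this._preprocessor;
  }

  set preprocessor(value: object|null) {
    this.includePreprocessing = value != null;
    this._preprocessor = value;
  }
}

const task = new TaskLike();
task.preprocessor = null;
console.log(task.includePreprocessing);  // false
```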
diff --git a/tfjs-layers/src/layers/nlp/utils.ts b/tfjs-layers/src/layers/nlp/utils.ts
index 3ea3ebf49ec..c6807567535 100644
--- a/tfjs-layers/src/layers/nlp/utils.ts
+++ b/tfjs-layers/src/layers/nlp/utils.ts
@@ -15,7 +15,13 @@
  * =============================================================================
  */
 
-import { Tensor, tensorScatterUpdate, tidy } from '@tensorflow/tfjs-core';
+import { ModelPredictConfig, Scalar, Tensor, tensorScatterUpdate, tidy } from '@tensorflow/tfjs-core';
+
+import { History } from '../../base_callbacks';
+import { ContainerArgs } from '../../engine/container';
+import { LayersModel, ModelEvaluateArgs } from '../../engine/training';
+import { ModelFitArgs } from '../../engine/training_tensors';
+import { NotImplementedError } from '../../errors';
 
 export function tensorToArr(input: Tensor): unknown[] {
   return Array.from(input.dataSync()) as unknown as unknown[];
@@ -62,3 +68,113 @@ export function sliceUpdate(
     return tensorScatterUpdate(inputs, indices, updates);
   });
 }
+
+function packXYSampleWeight(x: Tensor, y?: Tensor, sampleWeight?: Tensor):
+  Tensor
+  | [Tensor, Tensor]
+  | [Tensor, Tensor, Tensor] {
+  throw new NotImplementedError();
+}
+
+function unPackXYSampleWeight(
+  data: [Tensor]|[Tensor, Tensor]|[Tensor, Tensor, Tensor]
+) {
+  throw new NotImplementedError();
+}
+
+// TODO(pforderique): Figure out a workaround for `tf.data.Dataset`.
+function convertInputsToDataset(
+  x?: Tensor, y?: Tensor, sampleWeight?: Tensor, batchSize?: number
+) {
+  throw new NotImplementedError();
+}
+
+function trainValidationSplit(arrays: Tensor[], validationSplit: number) {
+  throw new NotImplementedError();
+}
+
+/**
+ * A model which allows automatically applying preprocessing.
+ */
+export interface PipelineModelArgs extends ContainerArgs {
+  /**
+   * Defaults to true.
+   */
+  includePreprocessing?: boolean;
+}
+
+export class PipelineModel extends LayersModel {
+  /** @nocollapse */
+  static override className = 'PipelineModel';
+
+  protected includePreprocessing: boolean;
+
+  constructor(args: PipelineModelArgs) {
+    super(args);
+    this.includePreprocessing = args.includePreprocessing ?? true;
+  }
+
+  /**
+   * An overridable function which preprocesses features.
+   */
+  preprocessFeatures(x: Tensor) {
+    return x;
+  }
+
+  /**
+   * An overridable function which preprocesses labels.
+   */
+  preprocessLabels(y: Tensor) {
+    return y;
+  }
+
+  /**
+   * An overridable function which preprocesses entire samples.
+   */
+  preprocessSamples(x: Tensor, y?: Tensor, sampleWeight?: Tensor):
+    Tensor
+    | [Tensor, Tensor]
+    | [Tensor, Tensor, Tensor] {
+    throw new NotImplementedError();
+  }
+
+  // ---------------------------------------------------------------------------
+  // Below are overrides to LayersModel methods to apply the functions above.
+  // ---------------------------------------------------------------------------
+  override fit(
+    x: Tensor|Tensor[]|{[inputName: string]: Tensor},
+    y: Tensor|Tensor[]|{[inputName: string]: Tensor},
+    args: ModelFitArgs = {}
+  ): Promise<History> {
+    throw new NotImplementedError(
+      `Uses ${convertInputsToDataset}, ${trainValidationSplit} ` +
+      `${packXYSampleWeight}, and ${unPackXYSampleWeight}`);
+  }
+
+  override evaluate(
+    x: Tensor|Tensor[],
+    y: Tensor|Tensor[],
+    args?: ModelEvaluateArgs
+  ): Scalar | Scalar[] {
+    throw new NotImplementedError();
+  }
+
+  override predict(
+    x: Tensor | Tensor[],
+    args?: ModelPredictConfig
+  ): Tensor | Tensor[] {
+    throw new NotImplementedError();
+  }
+
+  override trainOnBatch(
+    x: Tensor|Tensor[]|{[inputName: string]: Tensor},
+    y: Tensor|Tensor[]|{[inputName: string]: Tensor},
+    sampleWeight?: Tensor
+  ): Promise<number|number[]> {
+    throw new NotImplementedError();
+  }
+
+  override predictOnBatch(x: Tensor|Tensor[]): Tensor|Tensor[] {
+    throw new NotImplementedError();
+  }
+}
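Taken together, `PipelineModel` is meant to wrap every `LayersModel` entry point with an optional preprocessing pass, mirroring keras-nlp's `pipeline_model.py`. The dispatch sketched below is an assumption about the eventual implementation (the overrides above are still stubs); `maybePreprocess` is a hypothetical helper:

```ts
import { Tensor } from '@tensorflow/tfjs-core';

// Apply sample preprocessing only when the pipeline enables it; otherwise
// hand the inputs to the parent LayersModel method untouched.
function maybePreprocess(
    includePreprocessing: boolean,
    preprocessSamples: (x: Tensor, y: Tensor) => [Tensor, Tensor],
    x: Tensor, y: Tensor): [Tensor, Tensor] {
  return includePreprocessing ? preprocessSamples(x, y) : [x, y];
}
```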