diff --git a/e2e/scripts/verdaccio.yaml b/e2e/scripts/verdaccio.yaml index ff7ed606491..bc5bb77e41a 100644 --- a/e2e/scripts/verdaccio.yaml +++ b/e2e/scripts/verdaccio.yaml @@ -23,7 +23,6 @@ packages: '@tensorflow/**': access: $all publish: $all - proxy: npmjs '@*/*': access: $all publish: $all diff --git a/tfjs-layers/src/layers/nlp/models/preprocessor.ts b/tfjs-layers/src/layers/nlp/models/preprocessor.ts new file mode 100644 index 00000000000..5ef9c0787db --- /dev/null +++ b/tfjs-layers/src/layers/nlp/models/preprocessor.ts @@ -0,0 +1,76 @@ +/** + * @license + * Copyright 2023 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +/* Original source: keras-nlp/models/preprocessor.py */ +import { serialization } from '@tensorflow/tfjs-core'; + +import { Layer, LayerArgs } from '../../../engine/topology'; +import { Tokenizer } from '../tokenizers'; +import { Kwargs } from '../../../types'; +import { deserializeKerasObject, serializeKerasObject } from '../../../utils/generic_utils'; + +/** + * Base class for model Preprocessors. + */ +export class Preprocessor extends Layer { + /** @nocollapse */ + static readonly className = 'Preprocessor'; + + private _tokenizer: Tokenizer; + + constructor(args: LayerArgs) { + super(args); + } + + /** + * The tokenizer used to tokenize strings. 
+ */ + get tokenizer() { + return this._tokenizer; + } + + set tokenizer(value: Tokenizer) { + this._tokenizer = value; + } + + override getConfig(): serialization.ConfigDict { + const config = super.getConfig(); + config.tokenizer = serializeKerasObject(this.tokenizer); + return config; + } + + static override fromConfig<T extends serialization.Serializable>( + cls: serialization.SerializableConstructor<T>, + config: serialization.ConfigDict + ): T { + const kwargs: Kwargs = config; + + if (config.tokenizer != null && !(config.tokenizer instanceof Tokenizer)) { + const tokenizerConfigDict = config.tokenizer as serialization.ConfigDict; + + kwargs.tokenizer = deserializeKerasObject( + tokenizerConfigDict, + serialization.SerializationMap.getMap().classNameMap, + {}, 'preprocessor'); + } + return new cls(kwargs); + } + + static tokenizerCls<T extends serialization.Serializable>( + cls: serialization.SerializableConstructor<T>) {} +} +serialization.registerClass(Preprocessor); diff --git a/tfjs-layers/src/layers/nlp/models/preprocessor_test.ts b/tfjs-layers/src/layers/nlp/models/preprocessor_test.ts new file mode 100644 index 00000000000..1676c2a9803 --- /dev/null +++ b/tfjs-layers/src/layers/nlp/models/preprocessor_test.ts @@ -0,0 +1,35 @@ +/** + * @license + * Copyright 2023 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +/** + * Unit Tests for Preprocessor Layers. 
+ */ +import { Preprocessor } from './preprocessor'; + +describe('Preprocessor', () => { + let preprocessor: Preprocessor; + + beforeEach(() => { + preprocessor = new Preprocessor({}); + }); + + it('serialization round-trip with no set tokenizer', () => { + const reserialized = Preprocessor.fromConfig( + Preprocessor, preprocessor.getConfig()); + expect(reserialized.getConfig()).toEqual(preprocessor.getConfig()); + }); +}); diff --git a/tfjs-layers/src/layers/nlp/tokenizers.ts b/tfjs-layers/src/layers/nlp/tokenizers.ts index 3473ee21a96..b4f395cff5a 100644 --- a/tfjs-layers/src/layers/nlp/tokenizers.ts +++ b/tfjs-layers/src/layers/nlp/tokenizers.ts @@ -331,7 +331,7 @@ export class BytePairTokenizer extends Tokenizer { override getConfig(): serialization.ConfigDict { const config = { - vocabulary: this.vocabulary, + vocabulary: Array.from(this._vocabulary.entries()), merges: this.merges, sequenceLength: this.sequenceLength, addPrefixSpace: this.addPrefixSpace,