From 4e967885b8f5bdcf1c5bbd36ad04c93ba27f3438 Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Fri, 24 Mar 2023 16:03:56 +0200 Subject: [PATCH] Add support for MT5 (Closes #39, #47) --- README.md | 2 +- index.html | 10 ++++++++++ package.json | 2 +- src/models.js | 50 ++++++++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 62 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 6b4557e6960d1f..29f9c11774077c 100644 --- a/README.md +++ b/README.md @@ -6,7 +6,7 @@ [![license](https://img.shields.io/github/license/xenova/transformers.js)](https://github.com/xenova/transformers.js/blob/main/LICENSE) -Run 🤗 Transformers in your browser! We currently support [BERT](https://huggingface.co/docs/transformers/model_doc/bert), [ALBERT](https://huggingface.co/docs/transformers/model_doc/albert), [DistilBERT](https://huggingface.co/docs/transformers/model_doc/distilbert), [T5](https://huggingface.co/docs/transformers/model_doc/t5), [T5v1.1](https://huggingface.co/docs/transformers/model_doc/t5v1.1), [FLAN-T5](https://huggingface.co/docs/transformers/model_doc/flan-t5), [GPT2](https://huggingface.co/docs/transformers/model_doc/gpt2), [GPT Neo](https://huggingface.co/docs/transformers/model_doc/gpt_neo), [BART](https://huggingface.co/docs/transformers/model_doc/bart), [CodeGen](https://huggingface.co/docs/transformers/model_doc/codegen), [Whisper](https://huggingface.co/docs/transformers/model_doc/whisper), [CLIP](https://huggingface.co/docs/transformers/model_doc/clip), [Vision Transformer](https://huggingface.co/docs/transformers/model_doc/vit), [VisionEncoderDecoder](https://huggingface.co/docs/transformers/model_doc/vision-encoder-decoder), and [DETR](https://huggingface.co/docs/transformers/model_doc/detr) models, for a variety of tasks including: masked language modelling, text classification, text-to-text generation, translation, summarization, question answering, text generation, automatic speech recognition, image classification, zero-shot image classification, image-to-text, and object detection. +Run 🤗 Transformers in your browser! We currently support [BERT](https://huggingface.co/docs/transformers/model_doc/bert), [ALBERT](https://huggingface.co/docs/transformers/model_doc/albert), [DistilBERT](https://huggingface.co/docs/transformers/model_doc/distilbert), [T5](https://huggingface.co/docs/transformers/model_doc/t5), [T5v1.1](https://huggingface.co/docs/transformers/model_doc/t5v1.1), [FLAN-T5](https://huggingface.co/docs/transformers/model_doc/flan-t5), [mT5](https://huggingface.co/docs/transformers/model_doc/mt5), [GPT2](https://huggingface.co/docs/transformers/model_doc/gpt2), [GPT Neo](https://huggingface.co/docs/transformers/model_doc/gpt_neo), [BART](https://huggingface.co/docs/transformers/model_doc/bart), [CodeGen](https://huggingface.co/docs/transformers/model_doc/codegen), [Whisper](https://huggingface.co/docs/transformers/model_doc/whisper), [CLIP](https://huggingface.co/docs/transformers/model_doc/clip), [Vision Transformer](https://huggingface.co/docs/transformers/model_doc/vit), [VisionEncoderDecoder](https://huggingface.co/docs/transformers/model_doc/vision-encoder-decoder), and [DETR](https://huggingface.co/docs/transformers/model_doc/detr) models, for a variety of tasks including: masked language modelling, text classification, text-to-text generation, translation, summarization, question answering, text generation, automatic speech recognition, image classification, zero-shot image classification, image-to-text, and object detection. ![teaser](https://user-images.githubusercontent.com/26504141/221056008-e906614e-e6f0-4e10-b0a8-7d5c99e955b4.gif) diff --git a/index.html b/index.html index b774f35444d3db..17f9550111cd6c 100644 --- a/index.html +++ b/index.html @@ -673,6 +673,7 @@

Usage

t5-base, google/t5-v1_1-small, google/t5-v1_1-base, + google/mt5-small, facebook/bart-large-cnn, sshleifer/distilbart-cnn-6-6, sshleifer/distilbart-cnn-12-6. @@ -837,6 +838,15 @@

Usage

href="https://huggingface.co/docs/transformers/model_doc/flan-t5">FLAN-T5 docs. +
  • +
    +
    mT5
    + Tasks: Sequence-to-sequence + (AutoModelForSeq2SeqLM). + For more information, check out the mT5 docs. +
    +
  • GPT2/DistilGPT2
    diff --git a/package.json b/package.json index e376ad1268b3bc..829065e7d58138 100644 --- a/package.json +++ b/package.json @@ -1,7 +1,7 @@ { "name": "@xenova/transformers", "version": "1.3.4", - "description": "Run 🤗 Transformers in your browser! We currently support BERT, ALBERT, DistilBERT, T5, T5v1.1, FLAN-T5, GPT2, GPT Neo, BART, CodeGen, Whisper, CLIP, Vision Transformer, VisionEncoderDecoder, and DETR models, for a variety of tasks including: masked language modelling, text classification, text-to-text generation, translation, summarization, question answering, text generation, automatic speech recognition, image classification, zero-shot image classification, image-to-text, and object detection.", + "description": "Run 🤗 Transformers in your browser! We currently support BERT, ALBERT, DistilBERT, T5, T5v1.1, FLAN-T5, mT5, GPT2, GPT Neo, BART, CodeGen, Whisper, CLIP, Vision Transformer, VisionEncoderDecoder, and DETR models, for a variety of tasks including: masked language modelling, text classification, text-to-text generation, translation, summarization, question answering, text generation, automatic speech recognition, image classification, zero-shot image classification, image-to-text, and object detection.", "main": "./src/transformers.js", "directories": { "test": "tests" diff --git a/src/models.js b/src/models.js index 187515d1cecb33..313b54afaff6fb 100644 --- a/src/models.js +++ b/src/models.js @@ -782,6 +782,55 @@ class T5ForConditionalGeneration extends T5PreTrainedModel { } ////////////////////////////////////////////////// +////////////////////////////////////////////////// +// MT5 models +class MT5PreTrainedModel extends PreTrainedModel { }; + +class MT5Model extends MT5PreTrainedModel { + async generate(...args) { + throw Error( + "The current model class (MT5Model) is not compatible with `.generate()`, as it doesn't have a language model head. Please use one of the following classes instead: {'MT5ForConditionalGeneration'}" + ) + } +} + +class MT5ForConditionalGeneration extends MT5PreTrainedModel { + constructor(config, session, decoder_merged_session, generation_config) { + super(config, session); + this.decoder_merged_session = decoder_merged_session; + this.generation_config = generation_config; + + this.num_decoder_layers = this.config.num_decoder_layers; + this.num_decoder_heads = this.config.num_heads; + this.decoder_dim_kv = this.config.d_kv; + + this.num_encoder_layers = this.config.num_layers; + this.num_encoder_heads = this.config.num_heads; + this.encoder_dim_kv = this.config.d_kv; + } + + static async from_pretrained(modelPath, progressCallback = null) { + let info = await seq2seqLoadModel(modelPath, progressCallback); + return new this(...info); + } + + getStartBeams(inputs, numOutputTokens, ...args) { + return seq2seqStartBeams(this, inputs, numOutputTokens); + } + + async runBeam(beam) { + return await seq2seqRunBeam(this, beam); + } + updateBeam(beam, newTokenId) { + beam.output_token_ids = [...beam.output_token_ids, newTokenId]; + } + + async forward(model_inputs) { + return await seq2seq_forward(this, model_inputs); + } +} +////////////////////////////////////////////////// + ////////////////////////////////////////////////// // Bart models class BartPretrainedModel extends PreTrainedModel { }; @@ -1236,6 +1285,7 @@ class AutoModelForSequenceClassification { class AutoModelForSeq2SeqLM { static modelClassMapping = { 't5': T5ForConditionalGeneration, + 'mt5': MT5ForConditionalGeneration, 'bart': BartForConditionalGeneration, 'whisper': WhisperForConditionalGeneration, }