Skip to content

Commit

Permalink
Add support for MT5 (Closes huggingface#39, huggingface#47)
Browse files Browse the repository at this point in the history
  • Loading branch information
xenova committed Mar 24, 2023
1 parent 691fb39 commit 4e96788
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 2 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
[![license](https://img.shields.io/github/license/xenova/transformers.js)](https://github.com/xenova/transformers.js/blob/main/LICENSE)


Run 🤗 Transformers in your browser! We currently support [BERT](https://huggingface.co/docs/transformers/model_doc/bert), [ALBERT](https://huggingface.co/docs/transformers/model_doc/albert), [DistilBERT](https://huggingface.co/docs/transformers/model_doc/distilbert), [T5](https://huggingface.co/docs/transformers/model_doc/t5), [T5v1.1](https://huggingface.co/docs/transformers/model_doc/t5v1.1), [FLAN-T5](https://huggingface.co/docs/transformers/model_doc/flan-t5), [GPT2](https://huggingface.co/docs/transformers/model_doc/gpt2), [GPT Neo](https://huggingface.co/docs/transformers/model_doc/gpt_neo), [BART](https://huggingface.co/docs/transformers/model_doc/bart), [CodeGen](https://huggingface.co/docs/transformers/model_doc/codegen), [Whisper](https://huggingface.co/docs/transformers/model_doc/whisper), [CLIP](https://huggingface.co/docs/transformers/model_doc/clip), [Vision Transformer](https://huggingface.co/docs/transformers/model_doc/vit), [VisionEncoderDecoder](https://huggingface.co/docs/transformers/model_doc/vision-encoder-decoder), and [DETR](https://huggingface.co/docs/transformers/model_doc/detr) models, for a variety of tasks including: masked language modelling, text classification, text-to-text generation, translation, summarization, question answering, text generation, automatic speech recognition, image classification, zero-shot image classification, image-to-text, and object detection.
Run 🤗 Transformers in your browser! We currently support [BERT](https://huggingface.co/docs/transformers/model_doc/bert), [ALBERT](https://huggingface.co/docs/transformers/model_doc/albert), [DistilBERT](https://huggingface.co/docs/transformers/model_doc/distilbert), [T5](https://huggingface.co/docs/transformers/model_doc/t5), [T5v1.1](https://huggingface.co/docs/transformers/model_doc/t5v1.1), [FLAN-T5](https://huggingface.co/docs/transformers/model_doc/flan-t5), [mT5](https://huggingface.co/docs/transformers/model_doc/mt5), [GPT2](https://huggingface.co/docs/transformers/model_doc/gpt2), [GPT Neo](https://huggingface.co/docs/transformers/model_doc/gpt_neo), [BART](https://huggingface.co/docs/transformers/model_doc/bart), [CodeGen](https://huggingface.co/docs/transformers/model_doc/codegen), [Whisper](https://huggingface.co/docs/transformers/model_doc/whisper), [CLIP](https://huggingface.co/docs/transformers/model_doc/clip), [Vision Transformer](https://huggingface.co/docs/transformers/model_doc/vit), [VisionEncoderDecoder](https://huggingface.co/docs/transformers/model_doc/vision-encoder-decoder), and [DETR](https://huggingface.co/docs/transformers/model_doc/detr) models, for a variety of tasks including: masked language modelling, text classification, text-to-text generation, translation, summarization, question answering, text generation, automatic speech recognition, image classification, zero-shot image classification, image-to-text, and object detection.

![teaser](https://user-images.githubusercontent.com/26504141/221056008-e906614e-e6f0-4e10-b0a8-7d5c99e955b4.gif)

Expand Down
10 changes: 10 additions & 0 deletions index.html
Original file line number Diff line number Diff line change
Expand Up @@ -673,6 +673,7 @@ <h2 class="fw-bolder">Usage</h2>
<code>t5-base</code>,
<code>google/t5-v1_1-small</code>,
<code>google/t5-v1_1-base</code>,
<code>google/mt5-small</code>,
<code>facebook/bart-large-cnn</code>,
<code>sshleifer/distilbart-cnn-6-6</code>,
<code>sshleifer/distilbart-cnn-12-6</code>.
Expand Down Expand Up @@ -837,6 +838,15 @@ <h2 class="fw-bolder">Usage</h2>
href="https://huggingface.co/docs/transformers/model_doc/flan-t5">FLAN-T5 docs</a>.
</div>
</li>
<li class="list-group-item d-flex justify-content-between align-items-start">
<div class="ms-2 me-auto">
<div class="fw-bold">mT5</div>
Tasks: Sequence-to-sequence
<code>(AutoModelForSeq2SeqLM)</code>.
For more information, check out the <a
href="https://huggingface.co/docs/transformers/model_doc/mt5">mT5 docs</a>.
</div>
</li>
<li class="list-group-item d-flex justify-content-between align-items-start">
<div class="ms-2 me-auto">
<div class="fw-bold">GPT2/DistilGPT2</div>
Expand Down
2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
{
"name": "@xenova/transformers",
"version": "1.3.4",
"description": "Run 🤗 Transformers in your browser! We currently support BERT, ALBERT, DistilBERT, T5, T5v1.1, FLAN-T5, GPT2, GPT Neo, BART, CodeGen, Whisper, CLIP, Vision Transformer, VisionEncoderDecoder, and DETR models, for a variety of tasks including: masked language modelling, text classification, text-to-text generation, translation, summarization, question answering, text generation, automatic speech recognition, image classification, zero-shot image classification, image-to-text, and object detection.",
"description": "Run 🤗 Transformers in your browser! We currently support BERT, ALBERT, DistilBERT, T5, T5v1.1, FLAN-T5, mT5, GPT2, GPT Neo, BART, CodeGen, Whisper, CLIP, Vision Transformer, VisionEncoderDecoder, and DETR models, for a variety of tasks including: masked language modelling, text classification, text-to-text generation, translation, summarization, question answering, text generation, automatic speech recognition, image classification, zero-shot image classification, image-to-text, and object detection.",
"main": "./src/transformers.js",
"directories": {
"test": "tests"
Expand Down
50 changes: 50 additions & 0 deletions src/models.js
Original file line number Diff line number Diff line change
Expand Up @@ -782,6 +782,55 @@ class T5ForConditionalGeneration extends T5PreTrainedModel {
}
//////////////////////////////////////////////////

//////////////////////////////////////////////////
// MT5 models
/**
 * Common base class for the MT5 model classes.
 * Declares no members of its own; everything is inherited from `PreTrainedModel`.
 */
class MT5PreTrainedModel extends PreTrainedModel {}

/**
 * MT5 model class that does not support text generation.
 */
class MT5Model extends MT5PreTrainedModel {
    /**
     * Always fails: this class cannot be used with `.generate()`.
     * Because the method is `async`, the error surfaces as a rejected promise.
     *
     * @param {...any} args Ignored.
     * @returns {Promise<never>}
     * @throws {Error} Always, pointing callers at `MT5ForConditionalGeneration`.
     */
    async generate(...args) {
        const message = "The current model class (MT5Model) is not compatible with `.generate()`, as it doesn't have a language model head. Please use one of the following classes instead: {'MT5ForConditionalGeneration'}";
        throw new Error(message);
    }
}

/**
 * MT5 sequence-to-sequence model class.
 * All loading, beam handling and forward logic is delegated to the shared
 * `seq2seq*` helper functions defined elsewhere in this file.
 */
class MT5ForConditionalGeneration extends MT5PreTrainedModel {
    /**
     * @param {Object} config Model configuration; `num_decoder_layers`, `num_layers`,
     *   `num_heads` and `d_kv` are read from it (via `this.config` after `super()`).
     * @param {any} session Passed through to the parent constructor.
     * @param {any} decoder_merged_session Stored on the instance as-is.
     * @param {any} generation_config Stored on the instance as-is.
     */
    constructor(config, session, decoder_merged_session, generation_config) {
        super(config, session);
        this.decoder_merged_session = decoder_merged_session;
        this.generation_config = generation_config;

        // NOTE(review): assumes the parent constructor exposes `config` as `this.config` — confirm.
        const { num_decoder_layers, num_heads, d_kv, num_layers } = this.config;

        // Decoder dimensions.
        this.num_decoder_layers = num_decoder_layers;
        this.num_decoder_heads = num_heads;
        this.decoder_dim_kv = d_kv;

        // Encoder dimensions (head count and kv size are shared with the decoder).
        this.num_encoder_layers = num_layers;
        this.num_encoder_heads = num_heads;
        this.encoder_dim_kv = d_kv;
    }

    /**
     * Builds an instance from whatever `seq2seqLoadModel` resolves to; the resolved
     * value is spread into the constructor as its positional arguments.
     *
     * @param {string} modelPath Forwarded to `seq2seqLoadModel`.
     * @param {Function|null} [progressCallback=null] Forwarded to `seq2seqLoadModel`.
     * @returns {Promise<MT5ForConditionalGeneration>} A new instance of the class this is called on.
     */
    static async from_pretrained(modelPath, progressCallback = null) {
        const constructorArgs = await seq2seqLoadModel(modelPath, progressCallback);
        return new this(...constructorArgs);
    }

    /**
     * Returns the result of `seq2seqStartBeams` for this model.
     * Any extra arguments are accepted but not forwarded.
     */
    getStartBeams(inputs, numOutputTokens, ...args) {
        return seq2seqStartBeams(this, inputs, numOutputTokens);
    }

    /**
     * Runs a single beam through `seq2seqRunBeam` and resolves to its result.
     */
    async runBeam(beam) {
        const beamOutput = await seq2seqRunBeam(this, beam);
        return beamOutput;
    }

    /**
     * Appends `newTokenId` to the beam's output tokens by assigning a fresh array,
     * so the previous `output_token_ids` array is not mutated.
     */
    updateBeam(beam, newTokenId) {
        const extendedTokenIds = [...beam.output_token_ids, newTokenId];
        beam.output_token_ids = extendedTokenIds;
    }

    /**
     * Forward pass; resolves to whatever `seq2seq_forward` produces for `model_inputs`.
     */
    async forward(model_inputs) {
        const outputs = await seq2seq_forward(this, model_inputs);
        return outputs;
    }
}
//////////////////////////////////////////////////

//////////////////////////////////////////////////
// Bart models
/**
 * Base class for the Bart model classes.
 * Declares no members of its own; everything is inherited from `PreTrainedModel`.
 */
class BartPretrainedModel extends PreTrainedModel {}
Expand Down Expand Up @@ -1236,6 +1285,7 @@ class AutoModelForSequenceClassification {
class AutoModelForSeq2SeqLM {
static modelClassMapping = {
't5': T5ForConditionalGeneration,
'mt5': MT5ForConditionalGeneration,
'bart': BartForConditionalGeneration,
'whisper': WhisperForConditionalGeneration,
}
Expand Down

0 comments on commit 4e96788

Please sign in to comment.