Skip to content

Commit

Permalink
Implement SeamlessM4T
Browse files Browse the repository at this point in the history
  • Loading branch information
ikergarcia1996 committed Nov 30, 2023
1 parent f88323f commit 9dcafee
Show file tree
Hide file tree
Showing 9 changed files with 2,854 additions and 73 deletions.
44 changes: 32 additions & 12 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

<p align="center">
<br>
<img src="images/title.png" width="900"/>
Expand Down Expand Up @@ -29,23 +28,19 @@ We currently support:
- BF16 / FP16 / FP32 / 8 Bits / 4 Bits precision.
- Automatic batch size finder: Forget CUDA OOM errors. Set an initial batch size, if it doesn't fit, we will automatically adjust it.
- Multiple decoding strategies: Greedy Search, Beam Search, Top-K Sampling, Top-p (nucleus) sampling, etc. See [Decoding Strategies](#decodingsampling-strategies) for more information.
- :new: Load huge models in a single GPU with 8-bits / 4-bits quantization and support for splitting the model between GPU and CPU. See [Loading Huge Models](#loading-huge-models) for more information.
- :new: LoRA models support
- :new: Support for any Seq2SeqLM or CausalLM model from HuggingFace's Hub.
- :new: Prompt support! See [Prompting](#prompting) for more information.
- Load huge models in a single GPU with 8-bits / 4-bits quantization and support for splitting the model between GPU and CPU. See [Loading Huge Models](#loading-huge-models) for more information.
- LoRA models support
- Support for any Seq2SeqLM or CausalLM model from HuggingFace's Hub.
- Prompt support! See [Prompting](#prompting) for more information.
- :new: Add support for [SeamlessM4T](https://huggingface.co/docs/transformers/main/en/model_doc/seamless_m4t)!

>Test the 🔌 Online Demo here: <https://huggingface.co/spaces/Iker/Translate-100-languages>


## Supported languages

See the [Supported languages table](supported_languages.md) for a table of the supported languages and their ids.

## Supported Models

💥 EasyTranslate now supports any Seq2SeqLM (m2m100, nllb200, small100, mbart, MarianMT, T5, FlanT5, etc.) and any CausalLM (GPT2, LLaMA, Vicuna, Falcon) model from 🤗 Hugging Face's Hub!!
We still recommend you to use M2M100 or NLLB200 for the best results, but you can experiment with any other MT model, as well as prompting LLMs to generate translations (See [Prompting Section](#prompting) for more details).
We still recommend you to use M2M100, NLLB200 or SeamlessM4T for the best results, but you can experiment with any other MT model, as well as prompting LLMs to generate translations (See [Prompting Section](#prompting) for more details).
You can also see [the examples folder](examples) for examples of how to use EasyTranslate with different models.

### M2M100
Expand Down Expand Up @@ -73,13 +68,23 @@ You can also see [the examples folder](examples) for examples of how to use Easy

- **facebook/nllb-200-distilled-600M**: <https://huggingface.co/facebook/nllb-200-distilled-600M>

### SeamlessM4T

**SeamlessM4T** a collection of models designed to provide high quality translation, allowing people from different linguistic communities to communicate effortlessly through speech and text. It was introduced in this [paper](https://dl.fbaipublicfiles.com/seamless/seamless_m4t_paper.pdf) and first released in [this](https://github.com/facebookresearch/seamless_communication) repository.
>SeamlessM4T can directly translate between 196 Languages for text input/output.
- **facebook/hf-seamless-m4t-medium**: <https://huggingface.co/facebook/hf-seamless-m4t-medium> (Requires transformers 4.35.0)

- **facebook/hf-seamless-m4t-large**: <https://huggingface.co/facebook/hf-seamless-m4t-large> (Requires transformers 4.35.0)


### Other MT Models supported
We support every MT model in the 🤗 Hugging Face's Hub. If you find a model that doesn't work, please open an issue for us to fix it or a PR with the fix. This includes, among many others:
- **Small100**: <https://huggingface.co/alirezamsh/small100>
- **Mbart many-to-many / many-to-one**: <https://huggingface.co/facebook/mbart-large-50-many-to-many-mmt>
- **Opus MT**: <https://huggingface.co/Helsinki-NLP/opus-mt-es-en>


See the [Supported languages table](supported_languages.md) for a table of the supported languages and their ids.

## Citation
If you use this software please cite
Expand Down Expand Up @@ -110,6 +115,7 @@ pip install accelerate
HuggingFace Transformers
If you plan to use NLLB200, please use >= 4.28.0, as an important bug was fixed in this version.
If you plan to use SeamlessM4T, please use >= 4.35.0.
pip install --upgrade transformers
BitsAndBytes (Optional, required for 8-bits / 4-bits quantization)
Expand All @@ -135,6 +141,20 @@ python3 translate.py \
--model_name facebook/m2m100_1.2B
```

If you want to translate all the files in a directory, use the `--sentences_dir` flag instead of `--sentences_path`.
```bash
# We use --files_extension txt to translate only files with this extension.
# Use empty string to translate all files in the directory

python3 translate.py \
--sentences_dir sample_text/ \
--output_path sample_text/translations \
--files_extension txt \
--source_lang en \
--target_lang es \
--model_name facebook/m2m100_1.2B
```

#### Multi-GPU

See Accelerate documentation for more information (multi-node, TPU, Sharded model...): <https://huggingface.co/docs/accelerate/index>
Expand Down
75 changes: 35 additions & 40 deletions model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@

from typing import Optional, Tuple

import os

import torch

import json
Expand All @@ -27,6 +25,7 @@ def load_model_for_inference(
lora_weights_name_or_path: Optional[str] = None,
torch_dtype: Optional[str] = None,
force_auto_device_map: bool = False,
trust_remote_code: bool = False,
) -> Tuple[PreTrainedModel, PreTrainedTokenizerBase]:
"""
Load any Decoder model for inference.
Expand All @@ -50,6 +49,8 @@ def load_model_for_inference(
Whether to force the use of the auto device map. If set to True, the model will be split across
GPUs and CPU to fit the model in memory. If set to False, a full copy of the model will be loaded
into each GPU. Defaults to False.
trust_remote_code (`bool`, optional):
Trust the remote code from HuggingFace model hub. Defaults to False.
Returns:
`Tuple[PreTrainedModel, PreTrainedTokenizerBase]`:
Expand All @@ -64,40 +65,49 @@ def load_model_for_inference(

print(f"Loading model from {weights_path}")

MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.update(
{
"mpt": "MPTForCausalLM",
"RefinedWebModel": "RWForCausalLM",
"RefinedWeb": "RWForCausalLM",
}
) # MPT and Falcon are not in transformers yet

config = AutoConfig.from_pretrained(
weights_path,
trust_remote_code=True
if ("mpt" in weights_path or "falcon" in weights_path)
else False,
weights_path, trust_remote_code=trust_remote_code
)

torch_dtype = (
torch_dtype if torch_dtype in ["auto", None] else getattr(torch, torch_dtype)
)

if "small100" in weights_path:
import transformers

if transformers.__version__ > "4.34.0":
raise ValueError(
"Small100 tokenizer is not supported in transformers > 4.34.0. Please "
"use transformers <= 4.34.0 if you want to use small100"
)

print(f"Loading custom small100 tokenizer for utils.tokenization_small100")
from utils.tokenization_small100 import SMALL100Tokenizer as AutoTokenizer
else:
from transformers import AutoTokenizer

tokenizer: PreTrainedTokenizerBase = AutoTokenizer.from_pretrained(
weights_path,
add_eos_token=True,
trust_remote_code=True
if ("mpt" in weights_path or "falcon" in weights_path)
else False,
weights_path, add_eos_token=True, trust_remote_code=trust_remote_code
)

if tokenizer.pad_token_id is None:
if "<|padding|>" in tokenizer.get_vocab():
# StabilityLM specific fix
tokenizer.add_special_tokens({"pad_token": "<|padding|>"})
elif tokenizer.unk_token is not None:
print(
"Tokenizer does not have a pad token, we will use the unk token as pad token."
)
tokenizer.pad_token_id = tokenizer.unk_token_id
else:
print(
"Tokenizer does not have a pad token. We will use the eos token as pad token."
)
tokenizer.pad_token_id = tokenizer.eos_token_id

quant_args = {}

if quantization is not None:
quant_args = (
{"load_in_4bit": True} if quantization == 4 else {"load_in_8bit": True}
Expand All @@ -107,16 +117,17 @@ def load_model_for_inference(
load_in_4bit=True,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16,
bnb_4bit_compute_dtype=torch.bfloat16
if torch_dtype in ["auto", None]
else torch_dtype,
)
torch_dtype = torch.bfloat16

else:
bnb_config = BitsAndBytesConfig(
load_in_8bit=True,
)
print(
f"Bits and Bytes config: {json.dumps(bnb_config.to_dict(),indent=4,ensure_ascii=False)}"
f"Bits and Bytes config: {json.dumps(bnb_config.to_dict(), indent=4, ensure_ascii=False)}"
)
else:
print(f"Loading model with dtype: {torch_dtype}")
Expand All @@ -131,6 +142,7 @@ def load_model_for_inference(
device_map="auto" if force_auto_device_map else None,
torch_dtype=torch_dtype,
quantization_config=bnb_config,
trust_remote_code=trust_remote_code,
**quant_args,
)

Expand All @@ -142,9 +154,7 @@ def load_model_for_inference(
pretrained_model_name_or_path=weights_path,
device_map="auto" if force_auto_device_map else None,
torch_dtype=torch_dtype,
trust_remote_code=True
if ("mpt" in weights_path or "falcon" in weights_path)
else False,
trust_remote_code=trust_remote_code,
quantization_config=bnb_config,
**quant_args,
)
Expand All @@ -159,21 +169,6 @@ def load_model_for_inference(
f"CausalLM: {MODEL_FOR_CAUSAL_LM_MAPPING_NAMES}\n"
)

if tokenizer.pad_token_id is None:
if "<|padding|>" in tokenizer.get_vocab():
# StableLM specific fix
tokenizer.add_special_tokens({"pad_token": "<|padding|>"})
elif tokenizer.unk_token is not None:
print(
"Model does not have a pad token, we will use the unk token as pad token."
)
tokenizer.pad_token_id = tokenizer.unk_token_id
else:
print(
"Model does not have a pad token. We will use the eos token as pad token."
)
tokenizer.pad_token_id = tokenizer.eos_token_id

if lora_weights_name_or_path:
from peft import PeftModel

Expand Down
60 changes: 60 additions & 0 deletions sample_text/en2es.seamless-m4t-large.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
{
"path": "sample_text/en2es.translation.seamless-m4t-large.txt",
"sacrebleu": {
"score": 36.315142112223896,
"counts": [
20334,
12742,
8758,
6156
],
"totals": [
31021,
30021,
29021,
28021
],
"precisions": [
65.54914412817124,
42.44362279737517,
30.178146859170944,
21.969237357696013
],
"bp": 0.9854077938820913,
"sys_len": 31021,
"ref_len": 31477
},
"rouge": {
"rouge1": 0.6330701226501922,
"rouge2": 0.4284215608900075,
"rougeL": 0.5852948888167713,
"rougeLsum": 0.5852893813466102
},
"bleu": {
"bleu": 0.36315142112223897,
"precisions": [
0.6554914412817124,
0.4244362279737517,
0.30178146859170946,
0.21969237357696014
],
"brevity_penalty": 0.9854077938820913,
"length_ratio": 0.9855132318835975,
"translation_length": 31021,
"reference_length": 31477
},
"meteor": {
"meteor": 0.5988659867679048
},
"ter": {
"score": 53.42233524051706,
"num_edits": 15126,
"ref_length": 28314.0
},
"bert_score": {
"precision": 0.8355873214006424,
"recall": 0.8343284497857094,
"f1": 0.8346186644434929,
"hashcode": "microsoft/deberta-xlarge-mnli_L40_no-idf_version=0.3.12(hug_trans=4.35.2)_fast-tokenizer"
}
}
60 changes: 60 additions & 0 deletions sample_text/en2es.seamless-m4t-medium.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
{
"path": "sample_text/en2es.translation.seamless-m4t-medium.txt",
"sacrebleu": {
"score": 32.86110838375764,
"counts": [
19564,
11721,
7752,
5264
],
"totals": [
30811,
29811,
28811,
27812
],
"precisions": [
63.49680308980559,
39.31770151957331,
26.90638992051647,
18.92708183517906
],
"bp": 0.978616287348328,
"sys_len": 30811,
"ref_len": 31477
},
"rouge": {
"rouge1": 0.609193205717968,
"rouge2": 0.3944070815557623,
"rougeL": 0.558841464797821,
"rougeLsum": 0.5594046328281417
},
"bleu": {
"bleu": 0.3286110838375765,
"precisions": [
0.6349680308980559,
0.3931770151957331,
0.2690638992051647,
0.1892708183517906
],
"brevity_penalty": 0.978616287348328,
"length_ratio": 0.9788416939352543,
"translation_length": 30811,
"reference_length": 31477
},
"meteor": {
"meteor": 0.5707261528520716
},
"ter": {
"score": 55.88754679663771,
"num_edits": 15824,
"ref_length": 28314.0
},
"bert_score": {
"precision": 0.8278114783763886,
"recall": 0.824702616840601,
"f1": 0.8259151731133461,
"hashcode": "microsoft/deberta-xlarge-mnli_L40_no-idf_version=0.3.12(hug_trans=4.35.2)_fast-tokenizer"
}
}
Loading

0 comments on commit 9dcafee

Please sign in to comment.