Add new model (#32615)

* v1 - working version

* fix

* fix

* fix

* fix

* rename to correct name

* fix title

* fixup

* rename files

* fix

* add copied from on tests

* rename to `FalconMamba` everywhere and fix bugs

* fix quantization + accelerate

* fix copies

* add `torch.compile` support

* fix tests

* fix tests and add slow tests

* copies on config

* merge the latest changes

* fix tests

* add few lines about instruct

* Apply suggestions from code review

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* fix

* fix tests

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
younesbelkada and ArthurZucker authored Aug 12, 2024
1 parent 48101cf commit 7c11491
Showing 16 changed files with 1,693 additions and 1 deletion.
2 changes: 2 additions & 0 deletions docs/source/en/_toctree.yml
@@ -370,6 +370,8 @@
    title: ESM
  - local: model_doc/falcon
    title: Falcon
  - local: model_doc/falcon_mamba
    title: FalconMamba
  - local: model_doc/fastspeech2_conformer
    title: FastSpeech2Conformer
  - local: model_doc/flan-t5
1 change: 1 addition & 0 deletions docs/source/en/index.md
@@ -136,6 +136,7 @@ Flax), PyTorch, and/or TensorFlow.
| [ESM](model_doc/esm) | ✅ | ✅ | ❌ |
| [FairSeq Machine-Translation](model_doc/fsmt) | ✅ | ❌ | ❌ |
| [Falcon](model_doc/falcon) | ✅ | ❌ | ❌ |
| [FalconMamba](model_doc/falcon_mamba) | ✅ | ❌ | ❌ |
| [FastSpeech2Conformer](model_doc/fastspeech2_conformer) | ✅ | ❌ | ❌ |
| [FLAN-T5](model_doc/flan-t5) | ✅ | ✅ | ✅ |
| [FLAN-UL2](model_doc/flan-ul2) | ✅ | ✅ | ✅ |
116 changes: 116 additions & 0 deletions docs/source/en/model_doc/falcon_mamba.md
@@ -0,0 +1,116 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->

# FalconMamba

## Overview

The FalconMamba model was proposed by TII UAE (Technology Innovation Institute) in their release.

The abstract from the paper is the following:

*We present FalconMamba, a new base large language model based on the novel Mamba architecture. FalconMamba is trained on 5.8 trillion tokens with carefully selected data mixtures. As a pure Mamba-based model, FalconMamba surpasses leading open-weight models based on Transformers, such as Mistral 7B, Llama3 8B, and Falcon2 11B. It is on par with Gemma 7B and outperforms models with different architecture designs, such as RecurrentGemma 9B. Currently, FalconMamba is the best-performing Mamba model in the literature at this scale, surpassing both existing Mamba and hybrid Mamba-Transformer models.
Due to its architecture, FalconMamba is significantly faster at inference and requires substantially less memory for long sequence generation. Despite recent studies suggesting that hybrid Mamba-Transformer models outperform pure architecture designs, we argue and demonstrate that the pure Mamba design can achieve similar, even superior results compared to the hybrid design. We make the weights of our implementation of FalconMamba publicly available under a permissive license.*

Tips:

- FalconMamba is mostly based on the Mamba architecture, so the same [tips and best practices](./mamba) are relevant here; a kernel-availability check is sketched below.
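
A minimal sketch for checking whether the optimized Mamba kernels are installed, assuming the optional `mamba-ssm` and `causal-conv1d` packages from the Mamba docs and the availability helpers in `transformers.utils.import_utils`:

```python
from transformers.utils.import_utils import (
    is_causal_conv1d_available,
    is_mamba_ssm_available,
)

# Without these optional CUDA kernels, generation falls back to a slower
# pure-PyTorch selective-scan implementation.
if is_mamba_ssm_available() and is_causal_conv1d_available():
    print("Optimized Mamba kernels found - fast path enabled.")
else:
    print("Install `mamba-ssm` and `causal-conv1d` for faster inference.")
```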

The model was trained on approximately 6T tokens consisting of a mixture of many data sources, such as RefinedWeb, Cosmopedia, and math data.

For more details about the training procedure and the architecture, have a look at [the FalconMamba technical report]() (coming soon).

## Usage

Below we demonstrate how to use the model:

```python
from transformers import FalconMambaForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("tiiuae/falcon-mamba-7b")
model = FalconMambaForCausalLM.from_pretrained("tiiuae/falcon-mamba-7b")

input_ids = tokenizer("Hey how are you doing?", return_tensors="pt")["input_ids"]

out = model.generate(input_ids, max_new_tokens=10)
print(tokenizer.batch_decode(out))
```

The architecture is also compatible with `torch.compile` for faster generation:

```python
from transformers import FalconMambaForCausalLM, AutoTokenizer
import torch

tokenizer = AutoTokenizer.from_pretrained("tiiuae/falcon-mamba-7b")
model = FalconMambaForCausalLM.from_pretrained("tiiuae/falcon-mamba-7b", torch_dtype=torch.bfloat16).to(0)
model = torch.compile(model)

input_ids = tokenizer("Hey how are you doing?", return_tensors="pt")["input_ids"].to(0)

out = model.generate(input_ids, max_new_tokens=10)
print(tokenizer.batch_decode(out))
```
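
The first call after `torch.compile` pays a one-time compilation cost, and subsequent calls reuse the compiled graph. A quick, illustrative way to observe this, continuing the snippet above (the timing harness is our addition, not part of the original docs):

```python
import time

# Run the same generation twice: the first call includes compilation time,
# the second reflects steady-state speed.
for step in range(2):
    start = time.perf_counter()
    model.generate(input_ids, max_new_tokens=10)
    print(f"call {step}: {time.perf_counter() - start:.2f}s")
```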

If you have access to a GPU that is compatible with `bitsandbytes`, you can also quantize the model in 4-bit precision:

```python
from transformers import FalconMambaForCausalLM, AutoTokenizer, BitsAndBytesConfig

tokenizer = AutoTokenizer.from_pretrained("tiiuae/falcon-mamba-7b")
quantization_config = BitsAndBytesConfig(load_in_4bit=True)
model = FalconMambaForCausalLM.from_pretrained("tiiuae/falcon-mamba-7b", quantization_config=quantization_config)

input_ids = tokenizer("Hey how are you doing?", return_tensors="pt")["input_ids"].to(0)

out = model.generate(input_ids, max_new_tokens=10)
print(tokenizer.batch_decode(out))
```
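
To see the effect of 4-bit quantization on memory, you can inspect the model's footprint. A small sketch continuing the snippet above, using the standard `PreTrainedModel.get_memory_footprint` helper:

```python
# Reports the memory taken by parameters and buffers, in bytes; with 4-bit
# weights this should be roughly a quarter of the bfloat16 footprint.
print(f"Memory footprint: {model.get_memory_footprint() / 1e9:.2f} GB")
```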

You can also play with the instruction fine-tuned model:

```python
from transformers import FalconMambaForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("tiiuae/falcon-mamba-7b-instruct")
model = FalconMambaForCausalLM.from_pretrained("tiiuae/falcon-mamba-7b-instruct")

# We use the tokenizer's chat template to format each message - see https://huggingface.co/docs/transformers/main/en/chat_templating
messages = [
    {"role": "user", "content": "How many helicopters can a human eat in one sitting?"},
]
# `apply_chat_template(..., tokenize=False)` returns a string, so tokenize it separately
input_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
input_ids = tokenizer(input_text, return_tensors="pt").input_ids

outputs = model.generate(input_ids, max_new_tokens=30)
print(tokenizer.decode(outputs[0]))
```
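
For interactive use, you can also stream tokens as they are generated. A minimal sketch continuing the snippet above with `TextStreamer`:

```python
from transformers import TextStreamer

streamer = TextStreamer(tokenizer, skip_prompt=True)  # print only the newly generated text
_ = model.generate(input_ids, max_new_tokens=30, streamer=streamer)
```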

## FalconMambaConfig

[[autodoc]] FalconMambaConfig

## FalconMambaModel

[[autodoc]] FalconMambaModel
- forward

## FalconMambaForCausalLM

[[autodoc]] FalconMambaForCausalLM
- forward
14 changes: 14 additions & 0 deletions src/transformers/__init__.py
@@ -416,6 +416,7 @@
"models.ernie": ["ErnieConfig"],
"models.esm": ["EsmConfig", "EsmTokenizer"],
"models.falcon": ["FalconConfig"],
"models.falcon_mamba": ["FalconMambaConfig"],
"models.fastspeech2_conformer": [
"FastSpeech2ConformerConfig",
"FastSpeech2ConformerHifiGanConfig",
@@ -2138,6 +2139,13 @@
"FalconPreTrainedModel",
]
)
_import_structure["models.falcon_mamba"].extend(
[
"FalconMambaForCausalLM",
"FalconMambaModel",
"FalconMambaPreTrainedModel",
]
)
_import_structure["models.fastspeech2_conformer"].extend(
[
"FastSpeech2ConformerHifiGan",
@@ -5127,6 +5135,7 @@
    from .models.ernie import ErnieConfig
    from .models.esm import EsmConfig, EsmTokenizer
    from .models.falcon import FalconConfig
    from .models.falcon_mamba import FalconMambaConfig
    from .models.fastspeech2_conformer import (
        FastSpeech2ConformerConfig,
        FastSpeech2ConformerHifiGanConfig,
@@ -6739,6 +6748,11 @@
        FalconModel,
        FalconPreTrainedModel,
    )
    from .models.falcon_mamba import (
        FalconMambaForCausalLM,
        FalconMambaModel,
        FalconMambaPreTrainedModel,
    )
    from .models.fastspeech2_conformer import (
        FastSpeech2ConformerHifiGan,
        FastSpeech2ConformerModel,
1 change: 1 addition & 0 deletions src/transformers/models/__init__.py
@@ -84,6 +84,7 @@
    ernie,
    esm,
    falcon,
    falcon_mamba,
    fastspeech2_conformer,
    flaubert,
    flava,
2 changes: 2 additions & 0 deletions src/transformers/models/auto/configuration_auto.py
@@ -100,6 +100,7 @@
("ernie_m", "ErnieMConfig"),
("esm", "EsmConfig"),
("falcon", "FalconConfig"),
("falcon_mamba", "FalconMambaConfig"),
("fastspeech2_conformer", "FastSpeech2ConformerConfig"),
("flaubert", "FlaubertConfig"),
("flava", "FlavaConfig"),
@@ -384,6 +385,7 @@
("ernie_m", "ErnieM"),
("esm", "ESM"),
("falcon", "Falcon"),
("falcon_mamba", "FalconMamba"),
("fastspeech2_conformer", "FastSpeech2Conformer"),
("flan-t5", "FLAN-T5"),
("flan-ul2", "FLAN-UL2"),
4 changes: 4 additions & 0 deletions src/transformers/models/auto/modeling_auto.py
@@ -98,6 +98,7 @@
("ernie_m", "ErnieMModel"),
("esm", "EsmModel"),
("falcon", "FalconModel"),
("falcon_mamba", "FalconMambaModel"),
("fastspeech2_conformer", "FastSpeech2ConformerModel"),
("flaubert", "FlaubertModel"),
("flava", "FlavaModel"),
@@ -291,6 +292,7 @@
("distilbert", "DistilBertForMaskedLM"),
("electra", "ElectraForPreTraining"),
("ernie", "ErnieForPreTraining"),
("falcon_mamba", "FalconMambaForCausalLM"),
("flaubert", "FlaubertWithLMHeadModel"),
("flava", "FlavaForPreTraining"),
("fnet", "FNetForPreTraining"),
@@ -377,6 +379,7 @@
("encoder-decoder", "EncoderDecoderModel"),
("ernie", "ErnieForMaskedLM"),
("esm", "EsmForMaskedLM"),
("falcon_mamba", "FalconMambaForCausalLM"),
("flaubert", "FlaubertWithLMHeadModel"),
("fnet", "FNetForMaskedLM"),
("fsmt", "FSMTForConditionalGeneration"),
@@ -462,6 +465,7 @@
("electra", "ElectraForCausalLM"),
("ernie", "ErnieForCausalLM"),
("falcon", "FalconForCausalLM"),
("falcon_mamba", "FalconMambaForCausalLM"),
("fuyu", "FuyuForCausalLM"),
("gemma", "GemmaForCausalLM"),
("gemma2", "Gemma2ForCausalLM"),
1 change: 1 addition & 0 deletions src/transformers/models/auto/tokenization_auto.py
@@ -180,6 +180,7 @@
("ernie_m", ("ErnieMTokenizer" if is_sentencepiece_available() else None, None)),
("esm", ("EsmTokenizer", None)),
("falcon", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
("falcon_mamba", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)),
(
"fastspeech2_conformer",
("FastSpeech2ConformerTokenizer" if is_g2p_en_available() else None, None),
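
Taken together, the auto-mapping entries above let users load the model without naming the concrete classes. A minimal sketch of how the resolution works, reusing the checkpoint name from the docs:

```python
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

# The "falcon_mamba" model_type in the checkpoint's config.json routes each
# Auto class to the implementation registered in the mappings above.
config = AutoConfig.from_pretrained("tiiuae/falcon-mamba-7b")           # -> FalconMambaConfig
tokenizer = AutoTokenizer.from_pretrained("tiiuae/falcon-mamba-7b")     # -> GPTNeoXTokenizerFast
model = AutoModelForCausalLM.from_pretrained("tiiuae/falcon-mamba-7b")  # -> FalconMambaForCausalLM
```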
58 changes: 58 additions & 0 deletions src/transformers/models/falcon_mamba/__init__.py
@@ -0,0 +1,58 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING

from ...utils import (
    OptionalDependencyNotAvailable,
    _LazyModule,
    is_torch_available,
)


_import_structure = {
    "configuration_falcon_mamba": ["FalconMambaConfig"],
}

try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["modeling_falcon_mamba"] = [
        "FalconMambaForCausalLM",
        "FalconMambaModel",
        "FalconMambaPreTrainedModel",
    ]


if TYPE_CHECKING:
    from .configuration_falcon_mamba import FalconMambaConfig

    try:
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .modeling_falcon_mamba import (
            FalconMambaForCausalLM,
            FalconMambaModel,
            FalconMambaPreTrainedModel,
        )
else:
    import sys

    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
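
This is the library's standard lazy-import wiring: importing the submodule is cheap, and the torch-gated modeling classes are only materialized on first attribute access. A small sketch of the effect:

```python
from transformers.models import falcon_mamba

# The config class is listed unconditionally in _import_structure...
print(falcon_mamba.FalconMambaConfig)
# ...while the modeling classes resolve lazily and raise a helpful
# import error if torch is not installed.
print(falcon_mamba.FalconMambaForCausalLM)
```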