Flax mistral #26943

Merged · 98 commits · Jan 31, 2024

Changes from 83 commits

Commits (98)
3f74b1f
direct copy from llama work
kiansierra Oct 17, 2023
b194fb9
mistral modules forward pass working
kiansierra Oct 17, 2023
6126bcd
flax mistral forward pass with sliding window
kiansierra Oct 17, 2023
09717fd
added tests
kiansierra Oct 17, 2023
0e2905b
added layer collection approach
kiansierra Oct 17, 2023
fb17b61
Revert "added layer collection approach"
kiansierra Oct 17, 2023
41ed9a9
Revert "Revert "added layer collection approach""
kiansierra Oct 17, 2023
89a0fd7
fixed attention outputs
kiansierra Oct 17, 2023
beca7be
added mistral to init and auto
kiansierra Oct 17, 2023
299c07a
fixed import name
kiansierra Oct 18, 2023
d7ced4d
fixed layernorm weight dtype
kiansierra Oct 18, 2023
2985a61
freeze initialized weights
kiansierra Oct 18, 2023
f39a798
make sure conversion consideres bfloat16
kiansierra Oct 18, 2023
47a5311
added backend
kiansierra Oct 18, 2023
93cfc3a
added docstrings
kiansierra Oct 18, 2023
948cc06
added cache
kiansierra Oct 18, 2023
46d33ad
fixed sliding window causal mask
kiansierra Oct 18, 2023
6ebac9a
passes cache tests
kiansierra Oct 19, 2023
3302451
passed all tests
kiansierra Oct 19, 2023
9307ba8
applied make style
kiansierra Oct 19, 2023
4788867
removed commented out code
kiansierra Oct 19, 2023
9a55531
applied fix-copies ignored other model changes
kiansierra Oct 19, 2023
c0b4429
Merge branch 'huggingface:main' into flax-mistral
kiansierra Oct 19, 2023
a5c3fa4
applied make fix-copies
kiansierra Oct 19, 2023
e3f1078
removed unused functions
kiansierra Oct 19, 2023
69b223a
passed generation integration test
kiansierra Oct 19, 2023
3a48478
slow tests pass
kiansierra Oct 20, 2023
3ed1ab8
fixed slow tests
kiansierra Oct 20, 2023
ebf50bb
changed default dtype from jax.numpy.float32 to float32 for docstring…
kiansierra Oct 20, 2023
8d4b56a
skip cache test for FlaxMistralForSequenceClassification since if pa…
kiansierra Oct 20, 2023
acf5a96
updated checkpoint since from_pt not included
kiansierra Oct 21, 2023
876c49a
applied black style
kiansierra Oct 21, 2023
d9fdd15
removed unused args
kiansierra Oct 25, 2023
7867646
Merge branch 'main' into flax-mistral
kiansierra Nov 22, 2023
bc9345a
Merge branch 'main' into flax-mistral
kiansierra Nov 24, 2023
8d40900
Applied styling and fixup
kiansierra Nov 25, 2023
60fdad7
changed checkpoint for doc back
kiansierra Nov 25, 2023
dac618a
fixed rf after adding it to hf hub
kiansierra Nov 25, 2023
71671d6
Add dummy ckpt
kiansierra Dec 4, 2023
c569340
applied styling
kiansierra Dec 4, 2023
b0ef5a1
added tokenizer to new ckpt
kiansierra Dec 4, 2023
5e31d5d
fixed slice format
kiansierra Dec 5, 2023
d376f6a
Merge branch 'main' into flax-mistral
kiansierra Dec 6, 2023
0f87db4
fix init and slice
kiansierra Dec 6, 2023
d018f6e
changed ref for placeholder TODO
kiansierra Dec 6, 2023
c415cd9
Merge branch 'main' into flax-mistral
kiansierra Dec 7, 2023
73acb3c
added copies from Llama
kiansierra Dec 7, 2023
5d0a679
applied styling
kiansierra Dec 7, 2023
2d9211e
Merge branch 'main' into flax-mistral
kiansierra Dec 13, 2023
a3ce45c
applied fix-copies
kiansierra Dec 13, 2023
4a218a3
fixed docs
kiansierra Dec 14, 2023
dd14759
Merge branch 'main' into flax-mistral
kiansierra Dec 14, 2023
9c42f95
Merge branch 'main' into flax-mistral
kiansierra Dec 14, 2023
edd9cc6
update weight dtype reconversion for sharded weights
kiansierra Dec 14, 2023
c2950d9
removed Nullable input ids
kiansierra Dec 14, 2023
471e3e4
Removed unnecessary output attentions in Module
kiansierra Dec 14, 2023
3aaa014
added embedding weight initialziation
kiansierra Dec 14, 2023
e33327c
removed unused past_key_values
kiansierra Dec 14, 2023
1e00d30
fixed deterministic
kiansierra Dec 14, 2023
5bef1d2
Fixed RMS Norm and added copied from
kiansierra Dec 14, 2023
5b2d914
removed input_embeds
kiansierra Dec 14, 2023
adcac1c
applied make style
kiansierra Dec 14, 2023
a5a6d70
removed nullable input ids from sequence classification model
kiansierra Dec 14, 2023
85d282a
added copied from GPTJ
kiansierra Dec 14, 2023
c1758cb
added copied from Llama on FlaxMistralDecoderLayer
kiansierra Dec 14, 2023
05d62d0
added copied from to FlaxMistralPreTrainedModel methods
kiansierra Dec 14, 2023
a2c2808
fix test deprecation warning
kiansierra Dec 14, 2023
ca00fab
freeze gpt neox random_params and fix copies
kiansierra Dec 15, 2023
0ba0fea
applied make style
kiansierra Dec 15, 2023
535ef00
fixed doc issue
kiansierra Dec 15, 2023
faac78c
skipped docstring test to allign # copied from
kiansierra Dec 15, 2023
8c34572
Merge branch 'main' into flax-mistral
kiansierra Dec 15, 2023
9b028d2
applied make style
kiansierra Dec 15, 2023
212cf5d
removed FlaxMistralForSequenceClassification
kiansierra Jan 6, 2024
a1d20c8
removed unused padding_idx
kiansierra Jan 6, 2024
432db63
removed more sequence classification
kiansierra Jan 6, 2024
3b1d8c7
removed sequence classification
kiansierra Jan 6, 2024
2b11ce8
applied styling and consistency
kiansierra Jan 6, 2024
72ac552
Merge branch 'main' into flax-mistral
kiansierra Jan 6, 2024
23d1289
added copied from in tests
kiansierra Jan 7, 2024
df023d8
removed sequence classification test logic
kiansierra Jan 7, 2024
977690e
Merge branch 'main' into flax-mistral
kiansierra Jan 18, 2024
f794296
applied styling
kiansierra Jan 18, 2024
28e77c1
Merge branch 'main' into flax-mistral
kiansierra Jan 27, 2024
e572977
applied make style
kiansierra Jan 27, 2024
ff103d0
removed freeze and fixed copies
kiansierra Jan 27, 2024
80bce8d
undo test change
kiansierra Jan 27, 2024
6281c60
changed repeat_kv to tile
kiansierra Jan 27, 2024
c278516
fixed to key value groups
kiansierra Jan 27, 2024
67d71a0
updated copyright year
kiansierra Jan 30, 2024
df76af3
split casual_mask
kiansierra Jan 30, 2024
88e86c6
empty to rerun failed pt_flax_equivalence test FlaxWav2Vec2ModelTest
kiansierra Jan 30, 2024
5caed6b
went back to 2023 for tests_pr_documentation_tests
kiansierra Jan 30, 2024
7764c12
went back to 2024
kiansierra Jan 30, 2024
501cc22
changed tile to repeat
kiansierra Jan 31, 2024
9d46eeb
Merge branch 'main' into flax-mistral
kiansierra Jan 31, 2024
ed4461f
applied make style
kiansierra Jan 31, 2024
ab28806
empty for retry on Wav2Vec2
kiansierra Jan 31, 2024
2 changes: 1 addition & 1 deletion docs/source/en/index.md
@@ -189,7 +189,7 @@ Flax), PyTorch, and/or TensorFlow.
| [Megatron-BERT](model_doc/megatron-bert) | ✅ | ❌ | ❌ |
| [Megatron-GPT2](model_doc/megatron_gpt2) | ✅ | ✅ | ✅ |
| [MGP-STR](model_doc/mgp-str) | ✅ | ❌ | ❌ |
- | [Mistral](model_doc/mistral) | ✅ | ❌ | ❌ |
+ | [Mistral](model_doc/mistral) | ✅ | ❌ | ✅ |
| [Mixtral](model_doc/mixtral) | ✅ | ❌ | ❌ |
| [mLUKE](model_doc/mluke) | ✅ | ❌ | ❌ |
| [MMS](model_doc/mms) | ✅ | ✅ | ✅ |
10 changes: 10 additions & 0 deletions docs/source/en/model_doc/mistral.md
@@ -149,3 +149,13 @@ Albert Jiang, Alexandre Sablayrolles, Arthur Mensch, Chris Bamford, Devendra Sin

[[autodoc]] MistralForSequenceClassification
- forward

## FlaxMistralModel

[[autodoc]] FlaxMistralModel
- __call__

## FlaxMistralForCausalLM

[[autodoc]] FlaxMistralForCausalLM
- __call__
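The two Flax classes documented above can be used like any other Flax causal-LM in the library. A hedged usage sketch (checkpoint name illustrative; the official PyTorch weights require from_pt=True so they are converted on load):

```python
from transformers import AutoTokenizer, FlaxMistralForCausalLM

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
model = FlaxMistralForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1", from_pt=True)

inputs = tokenizer("My favourite condiment is", return_tensors="np")
outputs = model.generate(inputs["input_ids"], max_new_tokens=20)
print(tokenizer.decode(outputs.sequences[0], skip_special_tokens=True))
```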
12 changes: 12 additions & 0 deletions src/transformers/__init__.py
@@ -4669,6 +4669,13 @@
"FlaxMBartPreTrainedModel",
]
)
_import_structure["models.mistral"].extend(
[
"FlaxMistralForCausalLM",
"FlaxMistralModel",
"FlaxMistralPreTrainedModel",
]
)
_import_structure["models.mt5"].extend(["FlaxMT5EncoderModel", "FlaxMT5ForConditionalGeneration", "FlaxMT5Model"])
_import_structure["models.opt"].extend(
[
@@ -8815,6 +8822,11 @@
FlaxMBartModel,
FlaxMBartPreTrainedModel,
)
from .models.mistral import (
FlaxMistralForCausalLM,
FlaxMistralModel,
FlaxMistralPreTrainedModel,
)
from .models.mt5 import (
FlaxMT5EncoderModel,
FlaxMT5ForConditionalGeneration,
14 changes: 11 additions & 3 deletions src/transformers/modeling_flax_pytorch_utils.py
@@ -254,7 +254,10 @@ def convert_pytorch_sharded_state_dict_to_flax(shard_filenames, flax_model):
for shard_file in shard_filenames:
# load using msgpack utils
pt_state_dict = torch.load(shard_file, weights_only=is_torch_greater_or_equal_than_1_13)
- pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()}
+ weight_dtypes = {k: v.dtype for k, v in pt_state_dict.items()}
+ pt_state_dict = {
+     k: v.numpy() if v.dtype != torch.bfloat16 else v.float().numpy() for k, v in pt_state_dict.items()
+ }

model_prefix = flax_model.base_model_prefix

@@ -277,6 +280,7 @@ def convert_pytorch_sharded_state_dict_to_flax(shard_filenames, flax_model):
# Need to change some parameters name to match Flax names
for pt_key, pt_tensor in pt_state_dict.items():
pt_tuple_key = tuple(pt_key.split("."))
+ is_bfloat_16 = weight_dtypes[pt_key] == torch.bfloat16

# remove base model prefix if necessary
has_base_model_prefix = pt_tuple_key[0] == model_prefix
@@ -313,11 +317,15 @@ def convert_pytorch_sharded_state_dict_to_flax(shard_filenames, flax_model):
continue

# also add unexpected weight so that warning is thrown
- flax_state_dict[("params",) + flax_key] = jnp.asarray(flax_tensor)
+ flax_state_dict[("params",) + flax_key] = (
+     jnp.asarray(flax_tensor) if not is_bfloat_16 else jnp.asarray(flax_tensor, dtype=jnp.bfloat16)
+ )

else:
# also add unexpected weight so that warning is thrown
- flax_state_dict[flax_key] = jnp.asarray(flax_tensor)
+ flax_state_dict[flax_key] = (
+     jnp.asarray(flax_tensor) if not is_bfloat_16 else jnp.asarray(flax_tensor, dtype=jnp.bfloat16)
+ )
return unflatten_dict(flax_state_dict)


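The dtype bookkeeping above exists because NumPy has no native bfloat16: calling .numpy() on a bfloat16 torch tensor raises a TypeError. The shard conversion therefore upcasts such weights to float32 for the hand-off and uses the recorded dtype to restore bfloat16 on the JAX side. A minimal sketch of that round-trip:

```python
import torch
import jax.numpy as jnp

pt_tensor = torch.randn(2, 2, dtype=torch.bfloat16)
# pt_tensor.numpy() would raise a TypeError: NumPy has no bfloat16 dtype
np_tensor = pt_tensor.float().numpy()                     # upcast to float32 for transfer
flax_tensor = jnp.asarray(np_tensor, dtype=jnp.bfloat16)  # restore bfloat16 in JAX
assert flax_tensor.dtype == jnp.bfloat16
```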
2 changes: 2 additions & 0 deletions src/transformers/models/auto/modeling_flax_auto.py
@@ -47,6 +47,7 @@
("longt5", "FlaxLongT5Model"),
("marian", "FlaxMarianModel"),
("mbart", "FlaxMBartModel"),
("mistral", "FlaxMistralModel"),
("mt5", "FlaxMT5Model"),
("opt", "FlaxOPTModel"),
("pegasus", "FlaxPegasusModel"),
@@ -148,6 +149,7 @@
("gpt_neo", "FlaxGPTNeoForCausalLM"),
("gptj", "FlaxGPTJForCausalLM"),
("llama", "FlaxLlamaForCausalLM"),
("mistral", "FlaxMistralForCausalLM"),
("opt", "FlaxOPTForCausalLM"),
("roberta", "FlaxRobertaForCausalLM"),
("roberta-prelayernorm", "FlaxRobertaPreLayerNormForCausalLM"),
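With these two mapping entries, the Flax auto classes can dispatch on a Mistral config. A short sketch (checkpoint name illustrative):

```python
from transformers import FlaxAutoModelForCausalLM

# model_type "mistral" in the checkpoint's config resolves to FlaxMistralForCausalLM
model = FlaxAutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1", from_pt=True)
```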
2 changes: 1 addition & 1 deletion src/transformers/models/gpt_neo/modeling_flax_gpt_neo.py
@@ -377,7 +377,7 @@ def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: Froz
self._missing_keys = set()
return freeze(unflatten_dict(params))
else:
- return random_params
+ return freeze(random_params)

def init_cache(self, batch_size, max_length):
r"""
2 changes: 1 addition & 1 deletion src/transformers/models/llama/modeling_flax_llama.py
@@ -434,7 +434,7 @@ def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: Froz
self._missing_keys = set()
return freeze(unflatten_dict(params))
else:
- return random_params
+ return freeze(random_params)

def init_cache(self, batch_size, max_length):
r"""
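The identical one-line fix in GPT-Neo and Llama makes init_weights return a FrozenDict on both branches; previously the freshly initialized branch returned a plain mutable dict. A small sketch of the distinction, with an illustrative parameter tree:

```python
import jax.numpy as jnp
from flax.core.frozen_dict import freeze

random_params = {"wte": {"embedding": jnp.zeros((4, 8))}}  # plain, mutable dict
frozen = freeze(random_params)  # immutable FrozenDict, matching the other branch
try:
    frozen["wte"] = {}
except ValueError:
    print("FrozenDict rejects in-place mutation")
```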
30 changes: 25 additions & 5 deletions src/transformers/models/mistral/__init__.py
@@ -13,11 +13,7 @@
# limitations under the License.
from typing import TYPE_CHECKING

- from ...utils import (
-     OptionalDependencyNotAvailable,
-     _LazyModule,
-     is_torch_available,
- )
+ from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_flax_available, is_torch_available


_import_structure = {
@@ -38,6 +34,18 @@
"MistralForSequenceClassification",
]

try:
if not is_flax_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_flax_mistral"] = [
"FlaxMistralForCausalLM",
"FlaxMistralModel",
"FlaxMistralPreTrainedModel",
]


if TYPE_CHECKING:
from .configuration_mistral import MISTRAL_PRETRAINED_CONFIG_ARCHIVE_MAP, MistralConfig
@@ -55,6 +63,18 @@
MistralPreTrainedModel,
)

try:
if not is_flax_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_flax_mistral import (
FlaxMistralForCausalLM,
FlaxMistralModel,
FlaxMistralPreTrainedModel,
)


else:
import sys
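The guard above is the standard optional-backend pattern in transformers: when jax/flax are missing, the Flax symbols are not re-exported and the library substitutes dummy objects that raise an informative error on use. A hedged sketch of probing for the backend at runtime:

```python
from transformers.utils import OptionalDependencyNotAvailable, is_flax_available

try:
    if not is_flax_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    print("jax/flax not installed; FlaxMistral* classes are unavailable")
else:
    from transformers import FlaxMistralModel  # noqa: F401  # safe to import here
```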