Skip to content

Commit

Permalink
Flax mistral (#26943)
Browse files Browse the repository at this point in the history
* direct copy from llama work

* mistral modules forward pass working

* flax mistral forward pass with sliding window

* added tests

* added layer collection approach

* Revert "added layer collection approach"

This reverts commit 0e2905b.

* Revert "Revert "added layer collection approach""

This reverts commit fb17b61.

* fixed attention outputs

* added mistral to init and auto

* fixed import name

* fixed layernorm weight dtype

* freeze initialized weights

* make sure conversion consideres bfloat16

* added backend

* added docstrings

* added cache

* fixed sliding window causal mask

* passes cache tests

* passed all tests

* applied make style

* removed commented out code

* applied fix-copies ignored other model changes

* applied make fix-copies

* removed unused functions

* passed generation integration test

* slow tests pass

* fixed slow tests

* changed default dtype from jax.numpy.float32 to float32 for docstring check

* skip cache test  for FlaxMistralForSequenceClassification since if pad_token_id in input_ids it doesn't score previous input_ids

* updated checkpoint since from_pt not included

* applied black style

* removed unused args

* Applied styling and fixup

* changed checkpoint for doc back

* fixed rf after adding it to hf hub

* Add dummy ckpt

* applied styling

* added tokenizer to new ckpt

* fixed slice format

* fix init and slice

* changed ref for placeholder TODO

* added copies from Llama

* applied styling

* applied fix-copies

* fixed docs

* update weight dtype reconversion for sharded weights

* removed Nullable input ids

* Removed unnecessary output attentions in Module

* added embedding weight initialziation

* removed unused past_key_values

* fixed deterministic

* Fixed RMS Norm and added copied from

* removed input_embeds

* applied make style

* removed nullable input ids from sequence classification model

* added copied from GPTJ

* added copied from Llama on FlaxMistralDecoderLayer

* added copied from to FlaxMistralPreTrainedModel methods

* fix test deprecation warning

* freeze gpt neox random_params and fix copies

* applied make style

* fixed doc issue

* skipped docstring test to allign # copied from

* applied make style

* removed FlaxMistralForSequenceClassification

* removed unused padding_idx

* removed more sequence classification

* removed sequence classification

* applied styling and consistency

* added copied from in tests

* removed sequence classification test logic

* applied styling

* applied make style

* removed freeze and fixed copies

* undo test change

* changed repeat_kv to tile

* fixed to key value groups

* updated copyright year

* split casual_mask

* empty to rerun failed pt_flax_equivalence test FlaxWav2Vec2ModelTest

* went back to 2023 for tests_pr_documentation_tests

* went back to 2024

* changed tile to repeat

* applied make style

* empty for retry on Wav2Vec2
  • Loading branch information
kiansierra authored Jan 31, 2024
1 parent 7a49610 commit f7076cd
Show file tree
Hide file tree
Showing 10 changed files with 1,068 additions and 9 deletions.
2 changes: 1 addition & 1 deletion docs/source/en/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ Flax), PyTorch, and/or TensorFlow.
| [Megatron-BERT](model_doc/megatron-bert) ||||
| [Megatron-GPT2](model_doc/megatron_gpt2) ||||
| [MGP-STR](model_doc/mgp-str) ||||
| [Mistral](model_doc/mistral) ||| |
| [Mistral](model_doc/mistral) ||| |
| [Mixtral](model_doc/mixtral) ||||
| [mLUKE](model_doc/mluke) ||||
| [MMS](model_doc/mms) ||||
Expand Down
10 changes: 10 additions & 0 deletions docs/source/en/model_doc/mistral.md
Original file line number Diff line number Diff line change
Expand Up @@ -149,3 +149,13 @@ Albert Jiang, Alexandre Sablayrolles, Arthur Mensch, Chris Bamford, Devendra Sin

[[autodoc]] MistralForSequenceClassification
- forward

## FlaxMistralModel

[[autodoc]] FlaxMistralModel
- __call__

## FlaxMistralForCausalLM

[[autodoc]] FlaxMistralForCausalLM
- __call__
12 changes: 12 additions & 0 deletions src/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4678,6 +4678,13 @@
"FlaxMBartPreTrainedModel",
]
)
_import_structure["models.mistral"].extend(
[
"FlaxMistralForCausalLM",
"FlaxMistralModel",
"FlaxMistralPreTrainedModel",
]
)
_import_structure["models.mt5"].extend(["FlaxMT5EncoderModel", "FlaxMT5ForConditionalGeneration", "FlaxMT5Model"])
_import_structure["models.opt"].extend(
[
Expand Down Expand Up @@ -8830,6 +8837,11 @@
FlaxMBartModel,
FlaxMBartPreTrainedModel,
)
from .models.mistral import (
FlaxMistralForCausalLM,
FlaxMistralModel,
FlaxMistralPreTrainedModel,
)
from .models.mt5 import (
FlaxMT5EncoderModel,
FlaxMT5ForConditionalGeneration,
Expand Down
14 changes: 11 additions & 3 deletions src/transformers/modeling_flax_pytorch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,10 @@ def convert_pytorch_sharded_state_dict_to_flax(shard_filenames, flax_model):
# load using msgpack utils
weights_only_kwarg = {"weights_only": True} if is_torch_greater_or_equal_than_1_13 else {}
pt_state_dict = torch.load(shard_file, **weights_only_kwarg)
pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()}
weight_dtypes = {k: v.dtype for k, v in pt_state_dict.items()}
pt_state_dict = {
k: v.numpy() if v.dtype != torch.bfloat16 else v.float().numpy() for k, v in pt_state_dict.items()
}

model_prefix = flax_model.base_model_prefix

Expand All @@ -278,6 +281,7 @@ def convert_pytorch_sharded_state_dict_to_flax(shard_filenames, flax_model):
# Need to change some parameters name to match Flax names
for pt_key, pt_tensor in pt_state_dict.items():
pt_tuple_key = tuple(pt_key.split("."))
is_bfloat_16 = weight_dtypes[pt_key] == torch.bfloat16

# remove base model prefix if necessary
has_base_model_prefix = pt_tuple_key[0] == model_prefix
Expand Down Expand Up @@ -314,11 +318,15 @@ def convert_pytorch_sharded_state_dict_to_flax(shard_filenames, flax_model):
continue

# also add unexpected weight so that warning is thrown
flax_state_dict[("params",) + flax_key] = jnp.asarray(flax_tensor)
flax_state_dict[("params",) + flax_key] = (
jnp.asarray(flax_tensor) if not is_bfloat_16 else jnp.asarray(flax_tensor, dtype=jnp.bfloat16)
)

else:
# also add unexpected weight so that warning is thrown
flax_state_dict[flax_key] = jnp.asarray(flax_tensor)
flax_state_dict[flax_key] = (
jnp.asarray(flax_tensor) if not is_bfloat_16 else jnp.asarray(flax_tensor, dtype=jnp.bfloat16)
)
return unflatten_dict(flax_state_dict)


Expand Down
2 changes: 2 additions & 0 deletions src/transformers/models/auto/modeling_flax_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
("longt5", "FlaxLongT5Model"),
("marian", "FlaxMarianModel"),
("mbart", "FlaxMBartModel"),
("mistral", "FlaxMistralModel"),
("mt5", "FlaxMT5Model"),
("opt", "FlaxOPTModel"),
("pegasus", "FlaxPegasusModel"),
Expand Down Expand Up @@ -148,6 +149,7 @@
("gpt_neo", "FlaxGPTNeoForCausalLM"),
("gptj", "FlaxGPTJForCausalLM"),
("llama", "FlaxLlamaForCausalLM"),
("mistral", "FlaxMistralForCausalLM"),
("opt", "FlaxOPTForCausalLM"),
("roberta", "FlaxRobertaForCausalLM"),
("roberta-prelayernorm", "FlaxRobertaPreLayerNormForCausalLM"),
Expand Down
30 changes: 25 additions & 5 deletions src/transformers/models/mistral/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,7 @@
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import (
OptionalDependencyNotAvailable,
_LazyModule,
is_torch_available,
)
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_flax_available, is_torch_available


_import_structure = {
Expand All @@ -38,6 +34,18 @@
"MistralForSequenceClassification",
]

try:
if not is_flax_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_flax_mistral"] = [
"FlaxMistralForCausalLM",
"FlaxMistralModel",
"FlaxMistralPreTrainedModel",
]


if TYPE_CHECKING:
from .configuration_mistral import MISTRAL_PRETRAINED_CONFIG_ARCHIVE_MAP, MistralConfig
Expand All @@ -55,6 +63,18 @@
MistralPreTrainedModel,
)

try:
if not is_flax_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_flax_mistral import (
FlaxMistralForCausalLM,
FlaxMistralModel,
FlaxMistralPreTrainedModel,
)


else:
import sys
Expand Down
Loading

0 comments on commit f7076cd

Please sign in to comment.