Add Llama Flax Implementation #24587

Merged
merged 91 commits from `add-llama-flax` into `main` on Dec 7, 2023
Commits: 91 (changes shown from 86 commits)
415f0d2
Copies `modeling_flax_gpt_neo.py` to start
vvvm23 Jun 6, 2023
76a599c
MLP Block. WIP Attention and Block
vvvm23 Jun 6, 2023
4f5654d
Adds Flax implementation of `LlamaMLP`
vvvm23 Jun 21, 2023
36c48fa
Adds `FlaxLlamaRMSNorm` layer
vvvm23 Jun 21, 2023
a38b097
Adds FlaxLlamaAttention
vvvm23 Jun 23, 2023
5738666
Adds `FlaxLlamaDecoderLayer`
vvvm23 Jun 25, 2023
401a72f
debugging rotary mismatch
vvvm23 Jun 25, 2023
5177bfd
fixes bug with decoder layer
vvvm23 Jun 27, 2023
578e0d9
adds markers for what to implement next
vvvm23 Jun 27, 2023
9f74d83
implements `FlaxLlamaBlockCollection`
vvvm23 Jul 10, 2023
a60b00f
Adds `FlaxLlamaModule`
vvvm23 Jul 11, 2023
c7ac55b
adds `FlaxLlamaForCausalLMModule`
vvvm23 Jul 12, 2023
b6dff5a
start porting pretrained wrappers
vvvm23 Jul 12, 2023
7bf3567
cleanup, quality, style
vvvm23 Jul 14, 2023
99d40a0
readds `return_dict` and model output named tuples
vvvm23 Jul 14, 2023
4eebae9
(tentatively) pretrained wrappers work 🔥
vvvm23 Jul 14, 2023
4bb4206
fixes numerical mismatch in `FlaxLlamaRMSNorm`
vvvm23 Jul 15, 2023
b78671e
[WIP] debugging numerics
vvvm23 Jul 15, 2023
3f7bc54
numerical match
vvvm23 Jul 16, 2023
e386c85
adds in model and integration tests for Flax Llama
vvvm23 Aug 5, 2023
539d041
adds missing TYPE_CHECKING import and `make fixup`
vvvm23 Aug 5, 2023
c695d0a
adds back missing docstrings
vvvm23 Aug 5, 2023
1776e75
commenting out equivalence test as can just use common
vvvm23 Aug 6, 2023
3b4f55a
debugging
vvvm23 Aug 10, 2023
3fa9f2e
Fixes bug where mask and pos_ids were swapped in pretrained models
vvvm23 Aug 10, 2023
9d4bdad
cleanup of modeling file
vvvm23 Aug 11, 2023
ffb5e47
cleanup of test file
vvvm23 Aug 11, 2023
020bd4e
Resolving simpler review comments
vvvm23 Aug 17, 2023
e9e391f
addresses more minor review comments
vvvm23 Aug 23, 2023
ff0818f
fixing introduced pytest errors from review
vvvm23 Aug 23, 2023
d18daad
wip additional slow tests
vvvm23 Aug 23, 2023
230abeb
wip tests
vvvm23 Aug 24, 2023
b19213d
`make quality`, `make style`
vvvm23 Aug 24, 2023
2959abd
adds slow integration tests
vvvm23 Aug 24, 2023
a5b587b
`make fix-copies`
vvvm23 Aug 24, 2023
852e5e3
fix mangled function following `make fix-copies`
vvvm23 Aug 24, 2023
fd85d5a
adds missing type checking imports
vvvm23 Aug 25, 2023
fe5aed2
fixes missing parameter checkpoint warning
vvvm23 Aug 30, 2023
57b47c6
more finegrained 'Copied from' tags
vvvm23 Aug 31, 2023
b768559
swaps import guards
vvvm23 Aug 31, 2023
ac3f74f
removing `inv_freq` again as pytorch version has now removed
vvvm23 Aug 31, 2023
05cade4
attempting to get CI to pass
vvvm23 Aug 31, 2023
3bf0b8b
adds doc entries for llama flax models
vvvm23 Aug 31, 2023
211a72b
fixes typo in __init__.py imports
vvvm23 Aug 31, 2023
27a7522
adds back special equivalence tests
vvvm23 Aug 31, 2023
67f300c
overrides tests with dummy to see if CI passes
vvvm23 Aug 31, 2023
2ec5c20
adds my contribution to docs
vvvm23 Sep 1, 2023
609a113
`make style; make quality`
vvvm23 Sep 1, 2023
224f546
replaces random masking with fixed to work with flax version
vvvm23 Sep 1, 2023
7de8b58
`make quality; make style`
vvvm23 Sep 24, 2023
20b5767
Update src/transformers/models/llama/modeling_flax_llama.py
vvvm23 Sep 7, 2023
ac4183c
Update src/transformers/models/llama/modeling_flax_llama.py
vvvm23 Sep 7, 2023
bd5451a
Update src/transformers/models/llama/modeling_flax_llama.py
vvvm23 Sep 7, 2023
c997a38
Update src/transformers/models/llama/modeling_flax_llama.py
vvvm23 Sep 7, 2023
5019d4c
Update src/transformers/models/llama/modeling_flax_llama.py
vvvm23 Sep 7, 2023
4df7730
Update src/transformers/models/llama/modeling_flax_llama.py
vvvm23 Sep 7, 2023
f8ccb05
updates `x`->`tensor` in `rotate_half`
vvvm23 Sep 24, 2023
b01cb70
addresses smaller review comments
vvvm23 Sep 24, 2023
6848c63
Update docs/source/en/model_doc/llama.md
vvvm23 Sep 24, 2023
9994b91
adds integration test class
vvvm23 Sep 24, 2023
d248925
adds `dtype` to rotary embedding to cast outputs
vvvm23 Sep 24, 2023
f1fc40a
adds type to flax llama rotary layer
vvvm23 Sep 24, 2023
1f7cb9b
`make style`
vvvm23 Sep 24, 2023
be7be91
`make fix-copies`
vvvm23 Sep 24, 2023
3a7a3ae
Apply suggestions from code review
vvvm23 Sep 30, 2023
20e6b35
applies suggestions from review
vvvm23 Sep 30, 2023
3da6a6a
Update modeling_flax_llama.py
vvvm23 Oct 2, 2023
5f5ca1d
`make fix-copies`
vvvm23 Oct 2, 2023
9166130
Update tests/models/llama/test_modeling_llama.py
vvvm23 Oct 2, 2023
d8570b4
Update src/transformers/models/llama/modeling_flax_llama.py
vvvm23 Oct 2, 2023
6d7a930
fixes shape mismatch in FlaxLlamaMLP
vvvm23 Oct 3, 2023
39b55f8
applies some suggestions from reviews
vvvm23 Oct 7, 2023
a6d8c06
casts attn output logits to f32 regardless of dtype
vvvm23 Oct 16, 2023
9718fed
adds attn bias using `LlamaConfig.attention_bias`
vvvm23 Oct 16, 2023
fc6554b
adds Copied From comments to Flax Llama test
vvvm23 Oct 17, 2023
f9fd7b6
mistral and persimmon test change: copy from llama
vvvm23 Oct 17, 2023
8b1f374
updates docs index
vvvm23 Nov 5, 2023
d9c7af6
removes Copied from in tests
vvvm23 Nov 5, 2023
30fe8e1
quality and style
vvvm23 Nov 5, 2023
d21d306
ignores FlaxLlama input docstring
vvvm23 Nov 9, 2023
d14d339
Merge branch 'main' into add-llama-flax
vvvm23 Nov 9, 2023
2b9e410
Merge branch 'main' into add-llama-flax
vvvm23 Nov 24, 2023
3390651
adds revision to `_CHECKPOINT_FOR_DOC`
vvvm23 Nov 24, 2023
f303188
repo consistency and quality
vvvm23 Nov 24, 2023
aea3e03
removes unused import
vvvm23 Nov 24, 2023
842b550
removes copied from from Phi test
vvvm23 Nov 24, 2023
7e69179
Merge branch 'main' into add-llama-flax
vvvm23 Dec 4, 2023
193052f
adds `_REAL_CHECKPOINT_FOR_DOC`
vvvm23 Dec 4, 2023
b33a54d
removes refs from pr tests
vvvm23 Dec 5, 2023
b9ed34f
reformat to make ruff happy
vvvm23 Dec 5, 2023
a15b844
Merge branch 'main' into add-llama-flax
vvvm23 Dec 5, 2023
8 changes: 4 additions & 4 deletions docs/source/en/index.md
@@ -1,4 +1,4 @@
<!--Copyright 2020 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
@@ -94,7 +94,7 @@ Flax), PyTorch, and/or TensorFlow.
| [CLIPSeg](model_doc/clipseg) | ✅ | ❌ | ❌ |
| [CLVP](model_doc/clvp) | ✅ | ❌ | ❌ |
| [CodeGen](model_doc/codegen) | ✅ | ❌ | ❌ |
| [CodeLlama](model_doc/code_llama) | ✅ | ❌ | ❌ |
| [CodeLlama](model_doc/code_llama) | ✅ | ❌ | ✅ |
| [Conditional DETR](model_doc/conditional_detr) | ✅ | ❌ | ❌ |
| [ConvBERT](model_doc/convbert) | ✅ | ✅ | ❌ |
| [ConvNeXT](model_doc/convnext) | ✅ | ✅ | ❌ |
@@ -167,8 +167,8 @@ Flax), PyTorch, and/or TensorFlow.
| [LED](model_doc/led) | ✅ | ✅ | ❌ |
| [LeViT](model_doc/levit) | ✅ | ❌ | ❌ |
| [LiLT](model_doc/lilt) | ✅ | ❌ | ❌ |
| [LLaMA](model_doc/llama) | ✅ | ❌ | ❌ |
| [Llama2](model_doc/llama2) | ✅ | ❌ | ❌ |
| [LLaMA](model_doc/llama) | ✅ | ❌ | ✅ |
| [Llama2](model_doc/llama2) | ✅ | ❌ | ✅ |
| [Longformer](model_doc/longformer) | ✅ | ✅ | ❌ |
| [LongT5](model_doc/longt5) | ✅ | ❌ | ✅ |
| [LUKE](model_doc/luke) | ✅ | ❌ | ❌ |
13 changes: 13 additions & 0 deletions docs/source/en/model_doc/llama.md
@@ -50,6 +50,9 @@ come in several checkpoints they each contain a part of each weight of the model

- The LLaMA tokenizer is a BPE model based on [sentencepiece](https://github.com/google/sentencepiece). One quirk of sentencepiece is that when decoding a sequence, if the first token is the start of the word (e.g. "Banana"), the tokenizer does not prepend the prefix space to the string.

This model was contributed by [zphang](https://huggingface.co/zphang) with contributions from [BlackSamorez](https://huggingface.co/BlackSamorez). The code of the implementation in Hugging Face is based on GPT-NeoX [here](https://github.com/EleutherAI/gpt-neox). The original code of the authors can be found [here](https://github.com/facebookresearch/llama). The Flax version of the implementation was contributed by [afmck](https://huggingface.co/afmck) with the code in the implementation based on Hugging Face's Flax GPT-Neo.


Based on the original LLaMA model, Meta AI has released some follow-up works:

- **Llama2**: Llama2 is an improved version of Llama with some architectural tweaks (Grouped Query Attention), and is pre-trained on 2 trillion tokens. Refer to the documentation of Llama2 which can be found [here](llama2).
@@ -112,3 +115,13 @@ A list of official Hugging Face and community (indicated by 🌎) resources to h

[[autodoc]] LlamaForSequenceClassification
- forward

## FlaxLlamaModel

[[autodoc]] FlaxLlamaModel
- __call__

## FlaxLlamaForCausalLM

[[autodoc]] FlaxLlamaForCausalLM
- __call__
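
For reference, a minimal usage sketch of the two classes documented above, assuming the PR is installed; the checkpoint name is a placeholder, not one prescribed by this PR:

```python
# Minimal sketch: running the new Flax Llama causal LM head.
# "<llama-checkpoint>" is a placeholder -- substitute any Llama-compatible repo.
from transformers import AutoTokenizer, FlaxLlamaForCausalLM

checkpoint = "<llama-checkpoint>"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = FlaxLlamaForCausalLM.from_pretrained(checkpoint)

# Flax models consume NumPy arrays rather than PyTorch tensors.
inputs = tokenizer("Hello, my name is", return_tensors="np")
outputs = model(**inputs)
print(outputs.logits.shape)  # (batch_size, sequence_length, vocab_size)
```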
2 changes: 2 additions & 0 deletions src/transformers/__init__.py
@@ -4191,6 +4191,7 @@
["FlaxGPTNeoForCausalLM", "FlaxGPTNeoModel", "FlaxGPTNeoPreTrainedModel"]
)
_import_structure["models.gptj"].extend(["FlaxGPTJForCausalLM", "FlaxGPTJModel", "FlaxGPTJPreTrainedModel"])
_import_structure["models.llama"].extend(["FlaxLlamaForCausalLM", "FlaxLlamaModel", "FlaxLlamaPreTrainedModel"])
_import_structure["models.longt5"].extend(
["FlaxLongT5ForConditionalGeneration", "FlaxLongT5Model", "FlaxLongT5PreTrainedModel"]
)
@@ -7788,6 +7789,7 @@
from .models.gpt2 import FlaxGPT2LMHeadModel, FlaxGPT2Model, FlaxGPT2PreTrainedModel
from .models.gpt_neo import FlaxGPTNeoForCausalLM, FlaxGPTNeoModel, FlaxGPTNeoPreTrainedModel
from .models.gptj import FlaxGPTJForCausalLM, FlaxGPTJModel, FlaxGPTJPreTrainedModel
from .models.llama import FlaxLlamaForCausalLM, FlaxLlamaModel, FlaxLlamaPreTrainedModel
from .models.longt5 import FlaxLongT5ForConditionalGeneration, FlaxLongT5Model, FlaxLongT5PreTrainedModel
from .models.marian import FlaxMarianModel, FlaxMarianMTModel, FlaxMarianPreTrainedModel
from .models.mbart import (
2 changes: 2 additions & 0 deletions src/transformers/models/auto/modeling_flax_auto.py
@@ -43,6 +43,7 @@
("gpt2", "FlaxGPT2Model"),
("gpt_neo", "FlaxGPTNeoModel"),
("gptj", "FlaxGPTJModel"),
("llama", "FlaxLlamaModel"),
("longt5", "FlaxLongT5Model"),
("marian", "FlaxMarianModel"),
("mbart", "FlaxMBartModel"),
@@ -146,6 +147,7 @@
("gpt2", "FlaxGPT2LMHeadModel"),
("gpt_neo", "FlaxGPTNeoForCausalLM"),
("gptj", "FlaxGPTJForCausalLM"),
("llama", "FlaxLlamaForCausalLM"),
("opt", "FlaxOPTForCausalLM"),
("roberta", "FlaxRobertaForCausalLM"),
("roberta-prelayernorm", "FlaxRobertaPreLayerNormForCausalLM"),
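
With the two mappings above in place, the Flax auto classes can also resolve Llama checkpoints. A sketch, again with a placeholder checkpoint name:

```python
# Sketch: the auto-class route enabled by the new ("llama", ...) entries.
from transformers import FlaxAutoModelForCausalLM

model = FlaxAutoModelForCausalLM.from_pretrained("<llama-checkpoint>")  # placeholder
print(type(model).__name__)  # FlaxLlamaForCausalLM, resolved via the mapping above
```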
17 changes: 17 additions & 0 deletions src/transformers/models/llama/__init__.py
@@ -16,6 +16,7 @@
from ...utils import (
OptionalDependencyNotAvailable,
_LazyModule,
is_flax_available,
is_sentencepiece_available,
is_tokenizers_available,
is_torch_available,
@@ -55,6 +56,14 @@
"LlamaForSequenceClassification",
]

try:
if not is_flax_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_flax_llama"] = ["FlaxLlamaForCausalLM", "FlaxLlamaModel", "FlaxLlamaPreTrainedModel"]


if TYPE_CHECKING:
from .configuration_llama import LLAMA_PRETRAINED_CONFIG_ARCHIVE_MAP, LlamaConfig
@@ -83,6 +92,14 @@
else:
from .modeling_llama import LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel, LlamaPreTrainedModel

try:
if not is_flax_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_flax_llama import FlaxLlamaForCausalLM, FlaxLlamaModel, FlaxLlamaPreTrainedModel


else:
import sys
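
The guard above follows the repository's standard optional-dependency pattern: the `FlaxLlama*` symbols are registered in the lazy import structure only when Flax is installed. A sketch of how downstream code can check for this, using only public `transformers` utilities:

```python
# Sketch: gate Flax-specific code paths on Flax availability.
from transformers.utils import is_flax_available

if is_flax_available():
    from transformers import FlaxLlamaModel  # exported via the guarded block above
    model_cls = FlaxLlamaModel
else:
    model_cls = None  # JAX/Flax not installed; skip the Flax path
```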