Commit
Add support for llama3 adapter
radhikamp99 committed May 13, 2024
1 parent ae599c8 commit 060935d
Showing 3 changed files with 9 additions and 5 deletions.
6 changes: 5 additions & 1 deletion README.md
@@ -90,6 +90,10 @@ The following models from Hugging Face hub are currently supported
- [meta-llama/Llama-2-7b-hf](https://huggingface.co/meta-llama/Llama-2-7b)
- [meta-llama/Llama-2-13b-hf](https://huggingface.co/meta-llama/Llama-2-13b)
- [meta-llama/Llama-2-70b-hf](https://huggingface.co/meta-llama/Llama-2-70b)
- [meta-llama/Meta-Llama-3-8B](https://huggingface.co/meta-llama/Meta-Llama-3-8B)
- [meta-llama/Meta-Llama-3-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct)
- [meta-llama/Meta-Llama-3-70B](https://huggingface.co/meta-llama/Meta-Llama-3-70B)
- [meta-llama/Meta-Llama-3-70B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-70B-Instruct)
- [facebook/opt-125m](https://huggingface.co/facebook/opt-125m)
- [facebook/opt-1.3b](https://huggingface.co/facebook/opt-1.3b)
- [facebook/opt-2.7b](https://huggingface.co/facebook/opt-2.7b)
@@ -118,7 +122,7 @@ and update `hf_utils.get_model_and_tokenizer` before slicing the new model.
This class should also provide an adapted `forward()` method to work with the compressed model.
This method should specify how the skip connection orthogonal matrices are used, depending on
whether MLP and attention blocks are sequential ([OPT](./src/slicegpt/adapters/opt_adapter.py),
[Llama-2](./src/slicegpt/adapters/llama_adapter.py)) or parallel
[Llama-2/Llama-3](./src/slicegpt/adapters/llama_adapter.py)) or parallel
([Phi-2](./src/slicegpt/adapters/phi2_adapter.py)). The `self.*_shortcut_Q` matrices are attached to the modules during
slicing and are available in `forward()`. If the skip connection does not need modification, these matrices will be None,
and the `forward()` method can follow the original workflow. For more details on this,
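To make the description above concrete, here is a minimal sketch (not code from the repository) of how a sequential-block layer's `forward()` can apply the attached shortcut matrices; the `attn_shortcut_Q` / `mlp_shortcut_Q` names follow the `self.*_shortcut_Q` convention mentioned in the README, and everything else is illustrative:

```python
# Minimal sketch of the forward() pattern described in the README excerpt above,
# for a layer whose attention and MLP blocks run sequentially (OPT / Llama style).
# Only the *_shortcut_Q naming comes from the README; the rest is illustrative.
import torch
from torch import nn


class SketchSequentialLayer(nn.Module):
    def __init__(self, attn: nn.Module, mlp: nn.Module):
        super().__init__()
        self.attn = attn
        self.mlp = mlp
        # Attached during slicing; left as None when the skip connection is unchanged.
        self.attn_shortcut_Q: torch.Tensor | None = None
        self.mlp_shortcut_Q: torch.Tensor | None = None

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        residual = hidden_states
        if self.attn_shortcut_Q is not None:
            # Rotate the skip connection into the sliced basis before adding it back.
            residual = residual @ self.attn_shortcut_Q
        hidden_states = residual + self.attn(hidden_states)

        residual = hidden_states
        if self.mlp_shortcut_Q is not None:
            residual = residual @ self.mlp_shortcut_Q
        hidden_states = residual + self.mlp(hidden_states)
        return hidden_states
```

When both matrices are None, the method reduces to the original residual additions, which is why unmodified skip connections can keep the stock workflow.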
2 changes: 1 addition & 1 deletion experiments/bo_options.py
@@ -44,7 +44,7 @@ def lora_target_map(model: str):
'lm_head',
],
}
case 'meta-llama/Llama-2-7b-hf' | 'meta-llama/Llama-2-13b-hf' | 'meta-llama/Llama-2-70b-hf':
case 'meta-llama/Llama-2-7b-hf' | 'meta-llama/Llama-2-13b-hf' | 'meta-llama/Llama-2-70b-hf' | 'meta-llama/Meta-Llama-3-8B' | 'meta-llama/Meta-Llama-3-8B-Instruct' | 'meta-llama/Meta-Llama-3-70B' | 'meta-llama/Meta-Llama-3-70B-Instruct':
return {
'qkv_proj': ['k_proj', 'q_proj', 'v_proj'],
'attn_head': ['k_proj', 'q_proj', 'v_proj', 'o_proj'],
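For context on how this mapping might be consumed, a hypothetical caller could look up one of the named groups for a Llama-3 checkpoint and hand it to a PEFT `LoraConfig`; only the group names and module suffixes come from `bo_options.py`, the rest is an assumed usage sketch:

```python
# Hypothetical use of the mapping extended above: pick one target group for a
# Llama-3 checkpoint and build a PEFT LoRA config from it.
from peft import LoraConfig

# Import path assumes experiments/ is on PYTHONPATH.
from bo_options import lora_target_map

targets = lora_target_map('meta-llama/Meta-Llama-3-8B')['attn_head']
# -> ['k_proj', 'q_proj', 'v_proj', 'o_proj'] for the Llama-2/Llama-3 branch

lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=targets,  # modules whose weights receive LoRA adapters
    task_type="CAUSAL_LM",
)
```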
6 changes: 3 additions & 3 deletions src/slicegpt/adapters/llama_adapter.py
@@ -213,7 +213,7 @@ def get_lm_head(self) -> Linear:
return self.model.lm_head

def post_init(self, tokenizer: PreTrainedTokenizerBase) -> None:
# Llama-2 doesn't have a pad token by default
# Llama-2 and Llama-3 don't have a pad token by default
tokenizer.pad_token = tokenizer.eos_token
self.config.pad_token_id = tokenizer.pad_token_id
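
A small standalone illustration of why this `post_init` step is needed (assuming access to the gated meta-llama checkpoint): the stock Llama-2/Llama-3 tokenizers ship without a pad token, so padded batching only works after one is assigned.

```python
# Standalone illustration of the pad-token fix above; mirrors what post_init does.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")
print(tokenizer.pad_token)  # None: no pad token is defined out of the box

tokenizer.pad_token = tokenizer.eos_token  # reuse EOS for padding, as in post_init
```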

@@ -227,7 +227,7 @@ def _from_pretrained(
local_files_only: bool = False,
token: str | bool | None = None,
) -> ModelAdapter | None:
if not model_name.startswith("meta-llama/Llama-2"):
if not (model_name.startswith("meta-llama/Llama-2") or model_name.startswith("meta-llama/Meta-Llama-3")):
return None

model = LlamaForCausalLM.from_pretrained(
@@ -247,7 +247,7 @@ def _from_uninitialized(
local_files_only: bool = False,
token: str | bool | None = None,
) -> ModelAdapter | None:
if not model_name.startswith("meta-llama/Llama-2"):
if not (model_name.startswith("meta-llama/Llama-2") or model_name.startswith("meta-llama/Meta-Llama-3")):
return None

class UninitializedLlamaForCausalLM(LlamaForCausalLM):
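The widened `startswith` checks let this one adapter claim both the Llama-2 and Llama-3 families, while returning `None` leaves other model names for other adapters to handle. A sketch of that dispatch idea follows; `ADAPTERS` and `load_adapter` are illustrative names, not slicegpt's actual API:

```python
# Sketch of the dispatch pattern implied by the `return None` branches above:
# each adapter inspects the model name and declines models it does not own.
from typing import Callable, Optional


def llama_claims(model_name: str) -> bool:
    # Same prefix test as the updated _from_pretrained / _from_uninitialized.
    return model_name.startswith("meta-llama/Llama-2") or model_name.startswith(
        "meta-llama/Meta-Llama-3"
    )


ADAPTERS: list[Callable[[str], Optional[str]]] = [
    lambda name: "llama_adapter" if llama_claims(name) else None,
    lambda name: "opt_adapter" if name.startswith("facebook/opt") else None,
]


def load_adapter(model_name: str) -> str:
    for try_adapter in ADAPTERS:
        adapter = try_adapter(model_name)
        if adapter is not None:  # first adapter that does not decline wins
            return adapter
    raise ValueError(f"No adapter registered for {model_name}")


print(load_adapter("meta-llama/Meta-Llama-3-8B-Instruct"))  # -> llama_adapter
```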
