<!--Copyright 2024 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

<div style="float: right;">
    <div class="flex flex-wrap space-x-1">
        <img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
        <img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
        <img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
    </div>
</div>

# Jamba[[jamba]]

[Jamba](https://huggingface.co/papers/2403.19887) is a hybrid Transformer-Mamba mixture-of-experts (MoE) language model, with total parameter counts ranging from 52B to 398B. It aims to combine the performance of a Transformer with the efficiency and long-context handling (256K tokens) of a state space model like Mamba.

Jamba's architecture is built from blocks and layers so that the Transformer and Mamba architectures can be interleaved. Each Jamba block contains either an attention layer or a Mamba layer, followed by a multilayer perceptron (MLP). Transformer layers are placed periodically, at a ratio of one out of every eight layers, and MoE layers are mixed in to expand model capacity.
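
This interleaving is exposed through [`JambaConfig`]: `attn_layer_period`/`attn_layer_offset` place the attention layers and `expert_layer_period`/`expert_layer_offset` place the MoE layers. The snippet below is a minimal sketch, assuming the default configuration values, that prints the resulting per-layer pattern.

```py
from transformers import JambaConfig

config = JambaConfig()  # defaults: 32 layers, attn_layer_period=8, expert_layer_period=2

# attention replaces Mamba on layers where i % attn_layer_period == attn_layer_offset,
# and an MoE feed-forward replaces the plain MLP where i % expert_layer_period == expert_layer_offset
for i in range(config.num_hidden_layers):
    mixer = "attention" if i % config.attn_layer_period == config.attn_layer_offset else "mamba"
    ffn = "moe" if i % config.expert_layer_period == config.expert_layer_offset else "mlp"
    print(f"layer {i:2d}: {mixer:9} + {ffn}")
```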

You can find all the original Jamba checkpoints under the [AI21](https://huggingface.co/ai21labs) organization.

> [!TIP]
> Click on the Jamba models in the right sidebar for more examples of how to apply Jamba to different language tasks.

The examples below demonstrate how to generate text with [`Pipeline`], [`AutoModel`], and from the command line.

<hfoptions id="usage">
<hfoption id="Pipeline">

```py
# install the optimized Mamba implementations first
# !pip install mamba-ssm causal-conv1d>=1.2.0
import torch
from transformers import pipeline

pipeline = pipeline(
    task="text-generation",
    model="ai21labs/AI21-Jamba-Mini-1.6",
    torch_dtype=torch.float16,
    device=0
)
pipeline("Plants create energy through a process known as")
```

</hfoption>
<hfoption id="AutoModel">

```py
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(
    "ai21labs/AI21-Jamba-Large-1.6",
)
model = AutoModelForCausalLM.from_pretrained(
    "ai21labs/AI21-Jamba-Large-1.6",
    torch_dtype=torch.float16,
    device_map="auto",
    attn_implementation="sdpa"
)
input_ids = tokenizer("Plants create energy through a process known as", return_tensors="pt").to("cuda")

output = model.generate(**input_ids, cache_implementation="static")
print(tokenizer.decode(output[0], skip_special_tokens=True))
```

</hfoption>
<hfoption id="transformers CLI">

```bash
echo -e "Plants create energy through a process known as" | transformers run --task text-generation --model ai21labs/AI21-Jamba-Mini-1.6 --device 0
```

</hfoption>
</hfoptions>

Quantization reduces the memory burden of large models by representing the weights in a lower precision. Refer to the [Quantization](../quantization/overview) overview for the available quantization backends.

The example below uses [bitsandbytes](../quantization/bitsandbytes) to quantize only the weights to 8-bits.

```py
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

quantization_config = BitsAndBytesConfig(load_in_8bit=True,
                                         llm_int8_skip_modules=["mamba"])

# a device map to distribute the model evenly across 8 GPUs
device_map = {'model.embed_tokens': 0, 'model.layers.0': 0, 'model.layers.1': 0, 'model.layers.2': 0, 'model.layers.3': 0, 'model.layers.4': 0, 'model.layers.5': 0, 'model.layers.6': 0, 'model.layers.7': 0, 'model.layers.8': 0, 'model.layers.9': 1, 'model.layers.10': 1, 'model.layers.11': 1, 'model.layers.12': 1, 'model.layers.13': 1, 'model.layers.14': 1, 'model.layers.15': 1, 'model.layers.16': 1, 'model.layers.17': 1, 'model.layers.18': 2, 'model.layers.19': 2, 'model.layers.20': 2, 'model.layers.21': 2, 'model.layers.22': 2, 'model.layers.23': 2, 'model.layers.24': 2, 'model.layers.25': 2, 'model.layers.26': 2, 'model.layers.27': 3, 'model.layers.28': 3, 'model.layers.29': 3, 'model.layers.30': 3, 'model.layers.31': 3, 'model.layers.32': 3, 'model.layers.33': 3, 'model.layers.34': 3, 'model.layers.35': 3, 'model.layers.36': 4, 'model.layers.37': 4, 'model.layers.38': 4, 'model.layers.39': 4, 'model.layers.40': 4, 'model.layers.41': 4, 'model.layers.42': 4, 'model.layers.43': 4, 'model.layers.44': 4, 'model.layers.45': 5, 'model.layers.46': 5, 'model.layers.47': 5, 'model.layers.48': 5, 'model.layers.49': 5, 'model.layers.50': 5, 'model.layers.51': 5, 'model.layers.52': 5, 'model.layers.53': 5, 'model.layers.54': 6, 'model.layers.55': 6, 'model.layers.56': 6, 'model.layers.57': 6, 'model.layers.58': 6, 'model.layers.59': 6, 'model.layers.60': 6, 'model.layers.61': 6, 'model.layers.62': 6, 'model.layers.63': 7, 'model.layers.64': 7, 'model.layers.65': 7, 'model.layers.66': 7, 'model.layers.67': 7, 'model.layers.68': 7, 'model.layers.69': 7, 'model.layers.70': 7, 'model.layers.71': 7, 'model.final_layernorm': 7, 'lm_head': 7}
model = AutoModelForCausalLM.from_pretrained("ai21labs/AI21-Jamba-Large-1.6",
                                             torch_dtype=torch.bfloat16,
                                             attn_implementation="flash_attention_2",
                                             quantization_config=quantization_config,
                                             device_map=device_map)

tokenizer = AutoTokenizer.from_pretrained("ai21labs/AI21-Jamba-Large-1.6")

messages = [
    {"role": "system", "content": "You are an ancient oracle who speaks in cryptic but wise phrases, always hinting at deeper meanings."},
    {"role": "user", "content": "Hello!"},
]

input_ids = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors='pt').to(model.device)

outputs = model.generate(input_ids, max_new_tokens=216)

# decode the output
conversation = tokenizer.decode(outputs[0], skip_special_tokens=True)

# extract only the assistant's response
assistant_response = conversation.split(messages[-1]['content'])[1].strip()
print(assistant_response)
# Output: Seek and you shall find. The path is winding, but the journey is enlightening. What wisdom do you seek from the ancient echoes?
```
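
Writing the device map by hand, as above, gives full control over layer placement, but a map can also be derived automatically with Accelerate. The sketch below is one possible approach, not the method used by AI21; the `JambaAttentionDecoderLayer`/`JambaMambaDecoderLayer` names are assumed from the Transformers implementation (check `model._no_split_modules` for the authoritative list).

```py
from accelerate import infer_auto_device_map, init_empty_weights
from transformers import AutoConfig, AutoModelForCausalLM

# instantiate the model on the meta device so no weights are materialized
config = AutoConfig.from_pretrained("ai21labs/AI21-Jamba-Large-1.6")
with init_empty_weights():
    meta_model = AutoModelForCausalLM.from_config(config)

# keep each decoder layer on a single GPU rather than splitting it across devices
device_map = infer_auto_device_map(
    meta_model,
    no_split_module_classes=["JambaAttentionDecoderLayer", "JambaMambaDecoderLayer"],
)
```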

## Notes[[notes]]

- Don't quantize the Mamba blocks, to prevent degrading model performance (the quantization example above skips them with `llm_int8_skip_modules=["mamba"]`).
- Using Mamba without the optimized Mamba kernels results in significantly higher latency and is not recommended. If you still want to use Mamba without the kernels, set `use_mamba_kernels=False` in [`~AutoModel.from_pretrained`].

    ```py
    import torch
    from transformers import AutoModelForCausalLM
    model = AutoModelForCausalLM.from_pretrained("ai21labs/AI21-Jamba-1.5-Large",
                                                 use_mamba_kernels=False)
    ```

## JambaConfig[[transformers.JambaConfig]]

[[autodoc]] JambaConfig

## JambaModel[[transformers.JambaModel]]

[[autodoc]] JambaModel
    - forward

## JambaForCausalLM[[transformers.JambaForCausalLM]]

[[autodoc]] JambaForCausalLM
    - forward

## JambaForSequenceClassification[[transformers.JambaForSequenceClassification]]

[[autodoc]] transformers.JambaForSequenceClassification
    - forward