
Commit c9de02b

jiqing-feng authored and zaristei committed
enable static cache on TP model (huggingface#39164)
* enable static cache on TP model
* check tp size before init kv cache
* fix docstring
* add tp tests
* fix comment
* fix other cache head size

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
1 parent cf2c93e commit c9de02b

File tree

4 files changed (+84, -1 lines changed)


src/transformers/cache_utils.py

Lines changed: 36 additions & 0 deletions
@@ -1098,6 +1098,10 @@ class StaticCache(Cache):
             Mapping between the layers and its device. This is required when you are manually initializing the cache
             and the model is split between different gpus. You can know which layers mapped to which device by
             checking the associated device_map: `model.hf_device_map`.
+        tp_size (`Optional[int]`, *optional*):
+            The tensor parallel size of the model. This is used to adjust the number of key/value heads in the cache
+            if the model is using tensor parallelism. If not provided, it defaults to `None`, which means that the
+            number of key/value heads will not be adjusted.

     Example:

@@ -1130,6 +1134,7 @@ def __init__(
         device: Union[torch.device, str, None] = None,
         dtype: torch.dtype = torch.float32,
         layer_device_map: Optional[dict[int, Union[str, torch.device, int]]] = None,
+        tp_size: Optional[int] = None,
     ) -> None:
         super().__init__()
         self.max_batch_size = max_batch_size

@@ -1144,6 +1149,13 @@ def __init__(
             if getattr(config, "num_key_value_heads", None) is None
             else config.num_key_value_heads
         )
+        if tp_size is not None and tp_size > 1:
+            if self.num_key_value_heads % tp_size != 0:
+                raise ValueError(
+                    f"Number of key value heads {self.num_key_value_heads} must be divisible by tensor parallel size {tp_size}."
+                )
+            # If the model is using tensor parallelism, we need to adjust the number of heads accordingly.
+            self.num_key_value_heads //= tp_size

         self.key_cache: list[torch.Tensor] = []
         self.value_cache: list[torch.Tensor] = []

@@ -1573,6 +1585,10 @@ class HybridCache(Cache):
             Mapping between the layers and its device. This is required when you are manually initializing the cache
             and the model is split between different gpus. You can know which layers mapped to which device by
             checking the associated device_map: `model.hf_device_map`.
+        tp_size (`Optional[int]`, *optional*):
+            The tensor parallel size of the model. This is used to adjust the number of key/value heads in the cache
+            if the model is using tensor parallelism. If not provided, it defaults to `None`, which means that the
+            number of key/value heads will not be adjusted.

     Example:

@@ -1604,6 +1620,7 @@ def __init__(
         device: Union[torch.device, str, None] = None,
         dtype: torch.dtype = torch.float32,
         layer_device_map: Optional[dict[int, Union[str, torch.device, int]]] = None,
+        tp_size: Optional[int] = None,
     ) -> None:
         super().__init__()
         if not hasattr(config, "sliding_window") or config.sliding_window is None:

@@ -1627,6 +1644,13 @@ def __init__(
             if getattr(config, "num_key_value_heads", None) is None
             else config.num_key_value_heads
         )
+        if tp_size is not None and tp_size > 1:
+            if self.num_key_value_heads % tp_size != 0:
+                raise ValueError(
+                    f"Number of key value heads {self.num_key_value_heads} must be divisible by tensor parallel size {tp_size}."
+                )
+            # If the model is using tensor parallelism, we need to adjust the number of heads accordingly.
+            self.num_key_value_heads //= tp_size

         # If the attribute does not exist in the config, fallback to a simple StaticCache
         if hasattr(config, "layer_types"):

@@ -2197,6 +2221,10 @@ class OffloadedStaticCache(StaticCache):
             Mapping between the layers and its device. This is required when you are manually initializing the cache
             and the model is split between different gpus. You can know which layers mapped to which device by
             checking the associated device_map: `model.hf_device_map`.
+        tp_size (`Optional[int]`, *optional*):
+            The tensor parallel size of the model. This is used to adjust the number of key/value heads in the cache
+            if the model is using tensor parallelism. If not provided, it defaults to `None`, which means that the
+            number of key/value heads will not be adjusted.

     Example:

@@ -2228,6 +2256,7 @@ def __init__(
         dtype: Optional[torch.dtype] = None,
         offload_device: Union[str, torch.device] = torch.device("cpu"),
         layer_device_map: Optional[dict[int, Union[str, torch.device, int]]] = None,
+        tp_size: Optional[int] = None,
     ) -> None:
         super(Cache, self).__init__()

@@ -2251,6 +2280,13 @@ def __init__(
             if getattr(config, "num_key_value_heads", None) is None
             else config.num_key_value_heads
         )
+        if tp_size is not None and tp_size > 1:
+            if num_key_value_heads % tp_size != 0:
+                raise ValueError(
+                    f"Number of key value heads {num_key_value_heads} must be divisible by tensor parallel size {tp_size}."
+                )
+            # If the model is using tensor parallelism, we need to adjust the number of heads accordingly.
+            num_key_value_heads //= tp_size

         cache_shape = (max_batch_size, num_key_value_heads, self.max_cache_len, head_dim)

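The effect of `tp_size` on the cache can be seen in isolation. The sketch below is illustrative only and is not part of the commit: it builds a toy `LlamaConfig` with 8 key/value heads and relies on the `StaticCache` keyword arguments visible in the diff above (`config`, `max_batch_size`, `tp_size`); `max_cache_len` is assumed to be the existing cache-length argument.

```python
# Illustrative sketch only: toy config and values, not taken from the commit or its tests.
from transformers import LlamaConfig, StaticCache

config = LlamaConfig(
    hidden_size=512,
    num_hidden_layers=2,
    num_attention_heads=8,
    num_key_value_heads=8,  # must be divisible by tp_size
)

# On a 2-way tensor-parallel run, each rank would build its cache with tp_size=2,
# so it only tracks its own shard of the key/value heads.
cache = StaticCache(config=config, max_batch_size=1, max_cache_len=128, tp_size=2)
print(cache.num_key_value_heads)  # 4 instead of 8

# A non-divisible combination trips the new check.
try:
    StaticCache(config=config, max_batch_size=1, max_cache_len=128, tp_size=3)
except ValueError as err:
    print(err)  # Number of key value heads 8 must be divisible by tensor parallel size 3.
```
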
src/transformers/generation/utils.py

Lines changed: 3 additions & 0 deletions
@@ -1963,6 +1963,9 @@ def _get_cache(
             "device": device,
             "layer_device_map": layer_device_map,
         }
+        if cache_implementation in ["static", "hybrid", "offloaded_static"]:
+            cache_kwargs.update({"tp_size": self.tp_size})
+
         self._cache = cache_cls(**cache_kwargs)
         if requires_cross_attention_cache:
             encoder_kwargs = cache_kwargs.copy()

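In user code, this means a tensor-parallel model can now call `generate()` with a static cache directly: `tp_size` is forwarded for the `static`, `hybrid`, and `offloaded_static` implementations. A hedged end-to-end sketch, mirroring the new test added below and assuming a 2-process `torchrun` launch:

```python
# Launch with e.g.: torchrun --nproc-per-node=2 tp_generate_sketch.py  (hypothetical script name)
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "JackFram/llama-68m"
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto", tp_plan="auto")
tokenizer = AutoTokenizer.from_pretrained(model_id)

inputs = tokenizer("Can I help", return_tensors="pt").input_ids.to(model.device)
# generate() now forwards the model's tp_size into the static cache kwargs,
# so each rank allocates only its shard of the key/value heads.
outputs = model.generate(inputs, max_new_tokens=10, cache_implementation="static")
print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])

torch.distributed.destroy_process_group()
```
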
src/transformers/modeling_utils.py

Lines changed: 3 additions & 0 deletions
@@ -4494,6 +4494,9 @@ def from_pretrained(
                 raise ValueError("device_mesh must be 1 dimensional and will be used for TP")
             device_map = torch.device(device_mesh.device_type, int(os.environ["LOCAL_RANK"]))

+            if tp_size is None:
+                tp_size = torch.distributed.get_world_size()
+
         if use_auth_token is not None:
             warnings.warn(
                 "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",

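The default is resolved from the process group, so a `torchrun` launch does not need an explicit `tp_size`. A minimal standalone sketch of that fallback (the single-process gloo group here is only so the snippet runs outside a real launch; under torchrun the group already exists and the world size matches the launch):

```python
import torch.distributed as dist

# Stand-in process group so get_world_size() is callable in a plain Python run.
if not dist.is_initialized():
    dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1)

tp_size = None  # caller did not pass an explicit tensor-parallel degree
if tp_size is None:
    tp_size = dist.get_world_size()  # 1 here; 2 under torchrun --nproc-per-node=2

print(tp_size)
dist.destroy_process_group()
```
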
tests/tensor_parallel/test_tensor_parallel.py

Lines changed: 42 additions & 1 deletion
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+# Run the test: CUDA_VISIBLE_DEVICES=0,1 RUN_SLOW=1 pytest -sv tests/tensor_parallel/test_tensor_parallel.py
+
 import os
 import subprocess
 import tempfile

@@ -62,7 +64,6 @@ def size(self):
     assert torch.allclose(unpacked_weights, original_packed_weights)


-# RUN_SLOW=1 pytest -sv tests/tensor_parallel/test_tensor_parallel.py
 class TestTensorParallel(TestCasePlus):
     nproc_per_node = 2

@@ -125,6 +126,46 @@ def test_model_forward(self):
         )
         self.torchrun(script_to_run)

+    def test_model_generate(self):
+        script_to_run = textwrap.dedent(
+            """
+            import torch
+            import os
+            from transformers import AutoModelForCausalLM, AutoTokenizer
+
+            model_id = "JackFram/llama-68m"
+
+            rank = int(os.environ["RANK"])
+            world_size = int(os.environ["WORLD_SIZE"])
+
+            model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto", tp_plan="auto")
+            torch.distributed.barrier()
+
+            model.forward = torch.compile(model.forward)
+
+            has_dtensor = 0
+            for name, parameter in model.named_parameters():
+                if isinstance(parameter.data, torch.distributed.tensor.DTensor):
+                    has_dtensor = 1
+                    break
+
+            assert has_dtensor == 1, "TP model must has DTensor"
+
+            tokenizer = AutoTokenizer.from_pretrained(model_id)
+            prompt = "Can I help"
+
+            inputs = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
+            outputs = model.generate(inputs, max_new_tokens=10, cache_implementation="static")
+
+            output_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)
+            assert output_text[0].startswith(prompt), f"Expected output to start with '{prompt}', got '{output_text[0]}'"
+
+            torch.distributed.barrier()
+            torch.distributed.destroy_process_group()
+            """
+        )
+        self.torchrun(script_to_run)
+
     @require_huggingface_hub_greater_or_equal("0.31.4")
     def test_model_save(self):
         from safetensors import safe_open
