3 changes: 1 addition & 2 deletions fastdeploy/engine/sched/resource_manager_v1.py
@@ -460,8 +460,7 @@ def schedule(self):
                 # Prepare decoding task
                 scheduled_reqs.append(self._prepare_decode_task(request))
                 num_decoding_req_nums += 1
-                token_budget -= 1
-
+                token_budget -= 1
                 if (
                     request.use_extend_tables
                     and request.request_id not in self.using_extend_tables_req_id

52 changes: 40 additions & 12 deletions fastdeploy/worker/gpu_model_runner.py
@@ -1192,31 +1192,48 @@ def initialize_kv_cache(self, profile: bool = False) -> None:
 
         logger.info(f"Initializing kv cache for all layers. {cache_ready_signal.value}")
         cache_kvs_list = []
+
+        # NOTE:(changwenbin) Determine whether the attention backend is Multi-Head Latent Attention,
+        # so that the KV cache can be allocated accordingly.
+        from fastdeploy import envs
+
+        self.mla_cache = envs.FD_ATTENTION_BACKEND == "MLA_ATTN"
Collaborator:
Is this environment variable set automatically for models that use MLA, or does it have to be set manually?

Collaborator (Author):
Currently it is set manually in the launch script with export FD_ATTENTION_BACKEND="MLA_ATTN". Later the backend will be selected automatically from the model_type in config.json; that change is planned to be submitted together with enabling tensor core by default for MLA.
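To make the current manual setup and the planned automation concrete, here is a minimal sketch. It is not FastDeploy code: the helper name, the model_type set, and the "APPEND_ATTN" fallback value are assumptions for illustration only.

# Illustrative sketch only, not part of FastDeploy: prefer an explicit
# FD_ATTENTION_BACKEND export, otherwise infer the backend from model_type
# in config.json, as described in the comment above.
import json
import os

MLA_MODEL_TYPES = {"deepseek_v2", "deepseek_v3"}  # assumed set of MLA model types


def select_attention_backend(model_dir: str) -> str:
    explicit = os.getenv("FD_ATTENTION_BACKEND")
    if explicit:
        return explicit  # today's workflow: export FD_ATTENTION_BACKEND="MLA_ATTN"
    with open(os.path.join(model_dir, "config.json")) as f:
        model_type = json.load(f).get("model_type", "")
    return "MLA_ATTN" if model_type in MLA_MODEL_TYPES else "APPEND_ATTN"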

         for i in range(self.model_config.num_hidden_layers):
             key_cache_name = f"key_caches_{i}_rank{local_rank}.device{self.device_id}"
-            val_cache_name = f"value_caches_{i}_rank{local_rank}.device{self.device_id}"
+            if not self.mla_cache:
+                val_cache_name = f"value_caches_{i}_rank{local_rank}.device{self.device_id}"
             if create_cache_tensor:
                 logger.info(f"..creating kv cache for layer {i}: {kv_cache_shape}")
                 key_cache = paddle.full(shape=kv_cache_shape, fill_value=0, dtype=cache_type)
-                val_cache = paddle.full(shape=kv_cache_shape, fill_value=0, dtype=cache_type)
                 set_data_ipc(key_cache, key_cache_name)
-                set_data_ipc(val_cache, val_cache_name)
-                cache_kvs_list.extend([key_cache, val_cache])
+                if not self.mla_cache:
+                    val_cache = paddle.full(shape=kv_cache_shape, fill_value=0, dtype=cache_type)
+                    set_data_ipc(val_cache, val_cache_name)
+                    cache_kvs_list.extend([key_cache, val_cache])
+                else:
+                    cache_kvs_list.extend([key_cache])
                 if kv_cache_quant_type == "block_wise_fp8":
                     key_cache_scales = paddle.full(
                         shape=kv_cache_scale_shape, fill_value=0, dtype=paddle.get_default_dtype()
                     )
-                    val_cache_scales = paddle.full(
-                        shape=kv_cache_scale_shape, fill_value=0, dtype=paddle.get_default_dtype()
-                    )
-                    cache_kvs_list.extend([key_cache_scales, val_cache_scales])
+                    if not self.mla_cache:
+                        val_cache_scales = paddle.full(
+                            shape=kv_cache_scale_shape, fill_value=0, dtype=paddle.get_default_dtype()
+                        )
+                        cache_kvs_list.extend([key_cache_scales, val_cache_scales])
+                    else:
+                        cache_kvs_list.extend([key_cache_scales])
             else:
                 logger.info(f"..attaching kv cache for layer {i}: {kv_cache_shape}")
                 key_cache = paddle.empty(shape=[], dtype=cache_type)
-                val_cache = paddle.empty(shape=[], dtype=cache_type)
                 key_cache = share_external_data(key_cache, key_cache_name, kv_cache_shape)
-                val_cache = share_external_data(val_cache, val_cache_name, kv_cache_shape)
-                cache_kvs_list.extend([key_cache, val_cache])
+                if not self.mla_cache:
+                    val_cache = paddle.empty(shape=[], dtype=cache_type)
+                    val_cache = share_external_data(val_cache, val_cache_name, kv_cache_shape)
+                    cache_kvs_list.extend([key_cache, val_cache])
+                else:
+                    cache_kvs_list.extend([key_cache])
 
         self.share_inputs["caches"] = cache_kvs_list
 
         if not profile and create_cache_tensor:
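A side effect of the branch above worth noting is the layout of self.share_inputs["caches"]: with standard attention the flat list holds a key cache and a value cache per layer, while with MLA it holds a single compressed cache per layer. The helper below only illustrates that indexing; it is not a FastDeploy API and ignores the extra block_wise_fp8 scale tensors.

# Illustrative helper, not part of FastDeploy: how the flat caches list built in
# initialize_kv_cache is laid out per layer (block_wise_fp8 scales ignored).
def caches_for_layer(caches, layer_idx: int, mla_cache: bool):
    if mla_cache:
        # MLA: one tensor per layer holding the compressed latent (compress_kv + k_pe)
        return (caches[layer_idx],)
    # standard attention: key cache and value cache per layer
    return caches[2 * layer_idx], caches[2 * layer_idx + 1]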
@@ -1936,7 +1953,18 @@ def cal_theortical_kvcache(self):
             if self.speculative_method in ["mtp"]
             else self.model_config.num_hidden_layers
         )
-        required_memory = byte_of_dtype * 2 * (self.cache_config.block_size * hidden_dim) * num_layers  # k + v
+
+        # NOTE:(changwenbin) Determine whether the attention backend is Multi-Head Latent Attention,
+        # so that the theoretical KV cache size is computed accordingly.
+        if self.mla_cache:
+            required_memory = (
+                byte_of_dtype
+                * (self.fd_config.model_config.kv_lora_rank + self.fd_config.model_config.qk_rope_head_dim)
+                * (self.cache_config.block_size)
+                * num_layers
+            )  # compress_kv + k_pe
+        else:
+            required_memory = byte_of_dtype * 2 * (self.cache_config.block_size * hidden_dim) * num_layers  # k + v
         return required_memory

     def not_need_stop(self) -> bool:
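As a rough sense of why the MLA branch in cal_theortical_kvcache matters, the arithmetic below plugs assumed DeepSeek-V3-style dimensions (kv_lora_rank=512, qk_rope_head_dim=64, 61 layers) and an assumed 8-head x 128-dim KV layout for the non-MLA case into the two formulas; the real numbers depend on the model config, so treat this as a sketch rather than measured data.

# Back-of-the-envelope comparison of the two formulas above for one 64-token block.
# All model dimensions here are assumptions chosen for illustration.
byte_of_dtype = 2           # bfloat16
block_size = 64
num_layers = 61

# MLA: one compressed entry (compress_kv + k_pe) per token
mla_bytes = byte_of_dtype * (512 + 64) * block_size * num_layers

# standard attention: full K and V per token, with hidden_dim = kv_num_heads * head_dim
hidden_dim = 8 * 128
std_bytes = byte_of_dtype * 2 * (block_size * hidden_dim) * num_layers

print(f"MLA: {mla_bytes / 2**20:.1f} MiB, standard: {std_bytes / 2**20:.1f} MiB per block")
# MLA: 4.3 MiB, standard: 15.2 MiB -> roughly 3.6x smaller under these assumptions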