
Commit

update
shihaobai authored Dec 19, 2024
1 parent 8982092 commit 5488602
Showing 5 changed files with 11 additions and 15 deletions.
2 changes: 1 addition & 1 deletion lightllm/models/deepseek2/infer_struct.py
@@ -23,7 +23,7 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor):
local_token_num = input_ids.size(0)
all_token_num = [torch.zeros(1, dtype=torch.int32).to(input_ids.device) for _ in range(world_size)]
dist.all_gather(all_token_num, torch.tensor([local_token_num], dtype=torch.int32).to(input_ids.device))
all_token_num = torch.cat(all_token_num, dim=0)  # shape: (world_size,)
all_token_num = torch.cat(all_token_num, dim=0)
self.all_token_num = all_token_num.sum().cpu().numpy()
cumsum_token_num = torch.cumsum(all_token_num, dim=0).cpu().numpy()
self.all_start_idx = cumsum_token_num - all_token_num.cpu().numpy()
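
For reference, a minimal single-process sketch of the offset arithmetic in the hunk above: per-rank token counts are summed and cumulatively summed so each rank can locate its rows in the globally gathered activations. The concrete counts and the all_end_idx name are illustrative assumptions, not taken from the diff.

import numpy as np

per_rank_token_num = np.array([3, 5, 2, 4], dtype=np.int64)  # hypothetical world_size = 4

all_token_num = per_rank_token_num.sum()         # 14 tokens across all ranks
cumsum = np.cumsum(per_rank_token_num)           # [3, 8, 10, 14]
all_start_idx = cumsum - per_rank_token_num      # [0, 3, 8, 10]
all_end_idx = cumsum                             # [3, 8, 10, 14] (name assumed)

# rank r owns rows all_start_idx[r]:all_end_idx[r] of the gathered tensor
for r, (s, e) in enumerate(zip(all_start_idx, all_end_idx)):
    print(f"rank {r}: rows {s}:{e}")
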
4 changes: 3 additions & 1 deletion lightllm/models/llama/layer_infer/post_layer_infer.py
@@ -1,3 +1,4 @@
import os
import torch
import torch.functional as F
import torch.distributed as dist
@@ -21,6 +22,7 @@ def __init__(self, tp_rank, world_size, network_config, mode):
self.eps_ = network_config["rms_norm_eps"]
self.vocab_size_ = network_config["vocab_size"]
self.embed_dim_ = network_config["n_embed"]
self.enable_dp = os.getenv("ENABLE_DP", "0").upper() in ["ON", "TRUE", "1"]
return

def _norm(self, input, infer_state, layer_weight: LlamaPreAndPostLayerWeight) -> torch.Tensor:
@@ -89,7 +91,7 @@ def token_forward(self, input_embdings, infer_state: LlamaInferStateInfo, layer_
torch.mm(layer_weight.lm_head_weight_, last_input, out=logic_batch)

last_input = None
if self.world_size_ == 1 or layer_weight.enable_dp:
if self.world_size_ == 1 or self.enable_dp:
gather_data = logic_batch
else:
gather_data = self.alloc_tensor((self.vocab_size_, token_num), dtype=input_embdings_dtype)
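
The change in this file reads the data-parallel switch from the environment at layer construction time instead of from layer_weight. A minimal sketch of that flag check (standalone, values are only for illustration):

import os

def read_enable_dp() -> bool:
    # Any of "on"/"true"/"1" (case-insensitive) enables data parallelism;
    # unset defaults to "0", i.e. disabled.
    return os.getenv("ENABLE_DP", "0").upper() in ["ON", "TRUE", "1"]

os.environ["ENABLE_DP"] = "true"
print(read_enable_dp())   # True
os.environ["ENABLE_DP"] = "off"
print(read_enable_dp())   # False

When the flag is on (or world_size_ == 1), the hunk above skips allocating the gather buffer and uses the local logits directly.
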
6 changes: 3 additions & 3 deletions lightllm/models/llama/layer_infer/pre_layer_infer.py
@@ -16,7 +16,7 @@ class LlamaPreLayerInfer(PreLayerInferTpl):

def __init__(self, tp_rank, world_size, network_config, mode):
super().__init__(tp_rank, world_size, network_config, mode)
self.enable_dp = os.getenv("ENABLE_DP", "0").upper() in ["1", "ON"]
self.enable_dp = os.getenv("ENABLE_DP", "0").upper() in ["ON", "TRUE", "1"]
if not self.enable_dp:
tp_vob_ids = np.linspace(0, network_config["vocab_size"], self.world_size_ + 1, dtype=np.int64)
self.vob_start_id_, self.vob_end_id_ = int(tp_vob_ids[self.tp_rank_]), int(tp_vob_ids[self.tp_rank_ + 1])
@@ -30,7 +30,7 @@ def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_wei
(input_ids.shape[0], layer_weight.wte_weight_.shape[1]), dtype=layer_weight.data_type_
)
embedding(input_ids, layer_weight.wte_weight_, self.vob_start_id_, self.vob_end_id_, input_embdings)
if self.world_size_ > 1 and not layer_weight.enable_dp:
if self.world_size_ > 1 and not self.enable_dp:
dist.all_reduce(input_embdings, op=dist.ReduceOp.SUM, async_op=False)
return input_embdings

@@ -39,7 +39,7 @@ def token_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_weigh
(input_ids.shape[0], layer_weight.wte_weight_.shape[1]), dtype=layer_weight.data_type_
)
embedding(input_ids, layer_weight.wte_weight_, self.vob_start_id_, self.vob_end_id_, input_embdings)
if self.world_size_ > 1 and not layer_weight.enable_dp:
if self.world_size_ > 1 and not self.enable_dp:
dist.all_reduce(input_embdings, op=dist.ReduceOp.SUM, async_op=False)
return input_embdings

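
A sketch of the tensor-parallel vocab split that this pre-layer sets up when ENABLE_DP is off: np.linspace over [0, vocab_size] gives world_size + 1 boundaries, and each rank embeds only ids in its [start, end) slice. The numbers below are hypothetical, and the note about out-of-range ids is an inference from the all_reduce(SUM) that follows in the diff.

import numpy as np

vocab_size = 32000   # hypothetical
world_size = 4       # hypothetical

tp_vob_ids = np.linspace(0, vocab_size, world_size + 1, dtype=np.int64)
for tp_rank in range(world_size):
    start, end = int(tp_vob_ids[tp_rank]), int(tp_vob_ids[tp_rank + 1])
    print(f"rank {tp_rank}: embeds vocab ids [{start}, {end})")

# Ids outside the local range presumably contribute zeros on that rank, so the
# later dist.all_reduce(SUM) reassembles the full embedding -- which is why the
# reduce is skipped when world_size_ == 1 or enable_dp is set.
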
13 changes: 4 additions & 9 deletions lightllm/server/router/manager.py
@@ -319,18 +319,13 @@ async def _step(self):
return

def merge_rpyc_dict_ans(self, ans: List):
ans_: List[Dict] = []
if self.world_size != 1:
for d_ans in ans[0 : self.dp_size]:
d_ans = obtain(d_ans)
if len(d_ans) != 0:
ans_.append(d_ans)
else:
ans_ = ans
ans: List[Dict] = [obtain(e) for e in ans[0 : self.dp_size]]

if self.dp_size == 1:
return ans_[0]
return ans[0]
else:
return {k: v for t_ans in ans_ for k, v in t_ans.items()}
return {k: v for t_ans in ans for k, v in t_ans.items()}

async def _init_batch(self, batch: Batch):
reqs = [r.to_rpc_obj() for r in batch.reqs]
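
The refactor above collapses the world_size special case: the first dp_size answers are always obtained and merged. A plain-Python sketch of the resulting merge semantics (rpyc's obtain() is omitted; the inputs here are ordinary dicts with made-up keys):

from typing import Dict, List

def merge_dp_answers(ans: List[Dict], dp_size: int) -> Dict:
    ans = ans[0:dp_size]
    if dp_size == 1:
        return ans[0]
    return {k: v for t_ans in ans for k, v in t_ans.items()}

per_rank = [{"req_1": "ok"}, {"req_2": "ok"}]      # hypothetical per-rank results
print(merge_dp_answers(per_rank, dp_size=2))        # {'req_1': 'ok', 'req_2': 'ok'}
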
1 change: 0 additions & 1 deletion lightllm/server/router/model_infer/infer_batch.py
@@ -261,7 +261,6 @@ def init_batch(
req_manager: ReqManager,
vocab_size: int,
radix_cache: RadixCache = None,
dp_padding_batch_size: int = 0,
):

request_ids = []
