Commit
fix multinode abort
shihaobai committed Feb 20, 2025
1 parent b3de424 commit 3fd6e48
Showing 1 changed file with 40 additions and 21 deletions.
61 changes: 40 additions & 21 deletions lightllm/server/httpserver/manager.py
@@ -50,11 +50,12 @@ def __init__(
context = zmq.asyncio.Context(2)
self.send_to_router = context.socket(zmq.PUSH)
self.send_to_router.connect(f"{args.zmq_mode}127.0.0.1:{router_port}")

self.multinode_req_manager = None
self.child_node_events = {}
self.waiting_objs = []
self.child_node_lock = asyncio.Lock()
+ self.nnodes = args.nnodes
if args.nnodes > 1:
if args.node_rank == 0:
self.multinode_req_manager = []
@@ -64,12 +65,16 @@ def __init__(
context = zmq.asyncio.Context(2)
self.multinode_req_manager.append(context.socket(zmq.PUSH))
self.multinode_req_manager[-1].connect(f"tcp://{child_ip}:{args.multinode_httpmanager_port}")
logger.info(f"HttpServerManager connected to child node at {child_ip}:{args.multinode_httpmanager_port}")
logger.info(
f"HttpServerManager connected to child node at {child_ip}:{args.multinode_httpmanager_port}"
)
else:
context = zmq.asyncio.Context(2)
self.multinode_req_manager = context.socket(zmq.PULL)
self.multinode_req_manager.bind(f"tcp://*:{args.multinode_httpmanager_port}")
logger.info(f"HttpServerManager listening for child node requests on *:{args.multinode_httpmanager_port}")
logger.info(
f"HttpServerManager listening for child node requests on *:{args.multinode_httpmanager_port}"
)

self.enable_multimodal = enable_multimodal
if self.enable_multimodal:
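
Note on the hunk above: node_rank 0 opens one PUSH socket per child node, while each child binds a single PULL socket. A minimal stand-alone sketch of that fan-out, assuming pyzmq with asyncio (the hosts, port, and payload tuple below are placeholders, not lightllm's configuration):

```python
# Illustrative sketch only: the same PUSH/PULL fan-out as multinode_req_manager.
# Hosts, the port, and the payload tuple are placeholders, not lightllm's values.
import pickle

import zmq
import zmq.asyncio


async def master_node(child_ips, port):
    ctx = zmq.asyncio.Context()
    senders = []
    for ip in child_ips:
        sock = ctx.socket(zmq.PUSH)          # one PUSH socket per child node
        sock.connect(f"tcp://{ip}:{port}")
        senders.append(sock)
    payload = ("req-1", "a prompt", None, None, {"X-Request-Id": "demo"})
    for sender in senders:                   # replicate the request to every child
        await sender.send_pyobj(payload, protocol=pickle.HIGHEST_PROTOCOL)


async def child_node(port):
    ctx = zmq.asyncio.Context()
    sock = ctx.socket(zmq.PULL)              # each child binds a single PULL socket
    sock.bind(f"tcp://*:{port}")
    request = await sock.recv_pyobj()        # waits until the master pushes a request
    print("child received:", request)
```

Running master_node and child_node in separate processes (for example via asyncio.run) reproduces the hand-off that loop_for_request relies on.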
@@ -145,16 +150,26 @@ async def _release_multimodal_resources(self, multimodal_params: MultimodalParam
def tokens(self, prompt):
prompt_ids = self.tokenizer.encode(prompt)
return len(prompt_ids)

async def loop_for_request(self):
assert self.args.node_rank > 0
tasks = []
while True:
- request_id, prompt, sampling_params, multimodal_params, request_headers = await self.multinode_req_manager.recv_pyobj()
- results_generator = self.generate(prompt, sampling_params, multimodal_params, None, request_headers, request_id)
+ (
+     request_id,
+     prompt,
+     sampling_params,
+     multimodal_params,
+     request_headers,
+ ) = await self.multinode_req_manager.recv_pyobj()
+ results_generator = self.generate(
+     prompt, sampling_params, multimodal_params, None, request_headers, request_id
+ )

async def generate_wrapper(results_generator):
async for _, _, _, _ in results_generator:
pass

tasks.append(asyncio.create_task(generate_wrapper(results_generator)))
# cleanup
while len(tasks) > 0 and tasks[0].done():
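
The loop above drains each forwarded request's generator in a background task and prunes finished tasks from the head of the list. A stripped-down sketch of that bookkeeping, with illustrative names only:

```python
# Stripped-down sketch of loop_for_request's task bookkeeping; names are illustrative.
import asyncio


async def _drain(results_generator):
    # Consume the async generator; outputs are delivered elsewhere (via req_status events).
    async for _ in results_generator:
        pass


async def receive_loop(recv_next, start_generation):
    tasks = []
    while True:
        request = await recv_next()                    # e.g. recv_pyobj() on the PULL socket
        results_generator = start_generation(request)  # an async generator, like generate()
        tasks.append(asyncio.create_task(_drain(results_generator)))
        while tasks and tasks[0].done():               # opportunistic cleanup of finished tasks
            tasks.pop(0)
```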
@@ -166,7 +181,7 @@ async def generate(
sampling_params: SamplingParams,
multimodal_params: MultimodalParams,
request: Request,
- request_headers = None,
+ request_headers=None,
multinode_remote_request_id: Optional[int] = None,
) -> Tuple[int, str, dict, FinishStatus]:
start_time = time.time()
@@ -178,7 +193,10 @@ async def generate(
if multinode_remote_request_id is None:
group_request_id = self.id_gen.generate_id()
for sender in self.multinode_req_manager:
- sender.send_pyobj((group_request_id, prompt, sampling_params, multimodal_params, request_headers), protocol=pickle.HIGHEST_PROTOCOL)
+ sender.send_pyobj(
+     (group_request_id, prompt, sampling_params, multimodal_params, request_headers),
+     protocol=pickle.HIGHEST_PROTOCOL,
+ )
else:
group_request_id = multinode_remote_request_id
sampling_params.group_request_id = group_request_id
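
For the group request id, the master (the branch where multinode_remote_request_id is None) mints a fresh id and ships it to the children inside the broadcast payload, while a child adopts the id it received so every node tracks the same request. A hedged sketch of that decision:

```python
# Illustrative only; lightllm's real id generator and socket plumbing are more involved.
import itertools

_id_counter = itertools.count(1)


def resolve_group_request_id(multinode_remote_request_id=None) -> int:
    if multinode_remote_request_id is None:
        # Master path (node_rank 0): mint a fresh id; the real code then broadcasts
        # it to the children together with the prompt and sampling params.
        return next(_id_counter)
    # Child path: adopt the id chosen by the master so all nodes track the same request.
    return multinode_remote_request_id
```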
@@ -238,8 +256,12 @@ async def generate(
await self.transfer_to_next_module(req_status.group_req_objs)

results_generator = self._wait_to_token_package(
- start_time, prompt_ids, group_request_id, sampling_params, req_status, # request,
- request_headers,
+ start_time,
+ prompt_ids,
+ group_request_id,
+ sampling_params,
+ req_status,
+ request,
)
async for sub_req_id, request_output, metadata, finish_status in results_generator:
# In P/D mode, put the token data into the forwarding queue
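
In P/D (prefill/decode) mode the token packages are pushed onto self.forwarding_queue, which pd_handle_loop drains (see the handle_loop hunk below). A rough equivalent of that hand-off, using asyncio.Queue as a stand-in for lightllm's AsyncQueue:

```python
# Rough sketch of the P/D hand-off; asyncio.Queue stands in for lightllm's AsyncQueue.
import asyncio


async def produce_tokens(results_generator, forwarding_queue: asyncio.Queue):
    async for sub_req_id, request_output, metadata, finish_status in results_generator:
        # In P/D mode the token package is forwarded instead of being answered directly.
        await forwarding_queue.put((sub_req_id, request_output, metadata, finish_status))


async def forward_loop(forwarding_queue: asyncio.Queue, send):
    while True:
        token_package = await forwarding_queue.get()
        await send(token_package)   # e.g. push to the peer node over zmq
```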
@@ -368,8 +390,7 @@ async def _wait_to_token_package(
group_request_id: int,
sampling_params: SamplingParams,
req_status: "ReqStatus",
- request_headers,
- # request: Request,
+ request: Request,
):

event = req_status.event
@@ -385,10 +406,9 @@
except asyncio.TimeoutError:
pass

- # TODO: abort() for multinode
- # if request is not None and await request.is_disconnected():
- #     await self.abort(group_request_id)
- #     raise Exception(f"req_id {group_request_id} disconnected")
+ if request is not None and await request.is_disconnected() and self.nnodes == 1:
+     await self.abort(group_request_id)
+     raise Exception(f"req_id {group_request_id} disconnected")

async with req_status.lock:
event.clear()
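
This is the core of the fix: the disconnect check that was previously commented out with a TODO is back, but it only aborts directly in the single-node case; when nnodes > 1 the abort has to be coordinated across nodes rather than fired from this wait loop. A condensed sketch of one round of that wait/check, with the req_status lock and token draining omitted and the timeout value assumed:

```python
# Condensed sketch of one round of the wait loop; the req_status lock, token
# draining, and the timeout value are simplified or assumed.
import asyncio


async def wait_one_round(event: asyncio.Event, request, nnodes: int, abort, group_request_id: int):
    try:
        await asyncio.wait_for(event.wait(), timeout=5)
    except asyncio.TimeoutError:
        pass
    # Abort directly only on the single-node path; when nnodes > 1 the abort must be
    # coordinated across nodes instead of being fired from this loop.
    if request is not None and await request.is_disconnected() and nnodes == 1:
        await abort(group_request_id)
        raise Exception(f"req_id {group_request_id} disconnected")
    event.clear()
```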
@@ -416,13 +436,12 @@ async def _wait_to_token_package(
unfinished_count -= 1

# Once all sub-requests have finished, release the occupied resources
- if unfinished_count == 0:
+ if unfinished_count == 0 and request is not None:
total_cost_time_ms = (time.time() - start_time) * 1000
mean_per_token_cost_time_ms = (total_cost_time_ms - first_token_cost_ms) / out_token_counter
self.per_token_costs.add(mean_per_token_cost_time_ms)
- x_request_id = request_headers.get("X-Request-Id", "") if request_headers is not None else ""
- x_session_id = request_headers.get("X-Session-Id", "") if request_headers is not None else ""
-
+ x_request_id = request.headers.get("X-Request-Id", "") if request is not None else ""
+ x_session_id = request.headers.get("X-Session-Id", "") if request is not None else ""
prompt_cache_ratio = prompt_cache_len / prompt_tokens
self.metric_client.histogram_observe("lightllm_cache_length", prompt_cache_len)
self.metric_client.histogram_observe("lightllm_cache_ratio", prompt_cache_ratio)
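
With the switch back from a bare request_headers dict to the Request object, the tracing ids for the metrics come straight from the ASGI request's headers. A small FastAPI/Starlette sketch of that lookup (the endpoint itself is hypothetical):

```python
# Hedged sketch: reading the tracing ids from a Starlette/FastAPI Request, as the
# metrics block above does. The endpoint itself is hypothetical.
from fastapi import FastAPI, Request

app = FastAPI()


@app.post("/trace_demo")
async def trace_demo(request: Request):
    # Header lookups are case-insensitive and fall back to "" when absent.
    x_request_id = request.headers.get("X-Request-Id", "")
    x_session_id = request.headers.get("X-Session-Id", "")
    return {"x_request_id": x_request_id, "x_session_id": x_session_id}
```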
@@ -506,7 +525,7 @@ async def handle_loop(self):
if self.pd_mode.is_P_or_D():
self.forwarding_queue = AsyncQueue()
asyncio.create_task(self.pd_handle_loop())

if self.args.node_rank > 0:
asyncio.create_task(self.loop_for_request())

