Skip to content

Commit

Permalink
fix get_memory_pool_size deadlock for DP (sgl-project#1830)
Browse files Browse the repository at this point in the history
  • Loading branch information
ByronHsu authored Oct 29, 2024
1 parent 0a24eb8 commit 680cad2
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 5 deletions.
27 changes: 23 additions & 4 deletions python/sglang/srt/managers/tokenizer_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -539,9 +539,22 @@ async def get_memory_pool_size(self):
self.create_handle_loop()

req = GetMemPoolSizeReq()
self.send_to_scheduler.send_pyobj(req)
self.mem_pool_size = asyncio.Future()
return await self.mem_pool_size
ret = None

if self.server_args.dp_size == 1:
self.send_to_scheduler.send_pyobj(req)
self.mem_pool_size = asyncio.Future()
res = await self.mem_pool_size
ret = res.size

else: # self.server_args.dp_size > 1
self.send_to_scheduler.send_pyobj(req)
self.mem_pool_size = asyncio.Future()
self.mem_pool_size_tmp = []
res = await self.mem_pool_size
ret = [r.size for r in res]

return ret

async def update_weights(
self, obj: UpdateWeightReqInput, request: Optional[fastapi.Request] = None
Expand Down Expand Up @@ -634,7 +647,13 @@ async def handle_loop(self):
self.model_update_result.set_result(self.model_update_tmp)
continue
elif isinstance(recv_obj, GetMemPoolSizeReqOutput):
self.mem_pool_size.set_result(recv_obj)
if self.server_args.dp_size == 1:
self.mem_pool_size.set_result(recv_obj)
else: # self.sever_args.dp_size > 1
self.mem_pool_size_tmp.append(recv_obj)
# set future if the all results are received
if len(self.mem_pool_size_tmp) == self.server_args.dp_size:
self.mem_pool_size.set_result(self.mem_pool_size_tmp)
continue

assert isinstance(
Expand Down
3 changes: 2 additions & 1 deletion python/sglang/srt/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,8 @@ async def get_memory_pool_size():
"""Get the memory pool size in number of tokens"""
try:
ret = await tokenizer_manager.get_memory_pool_size()
return ret.size

return ret
except Exception as e:
return JSONResponse(
{"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
Expand Down
9 changes: 9 additions & 0 deletions test/srt/test_data_parallelism.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,15 @@ def test_update_weight(self):
# check if the response is 200
assert response.status_code == 200

def test_get_memory_pool_size(self):
response = requests.get(self.base_url + "/get_memory_pool_size")
assert response.status_code == 200

time.sleep(5)

response = requests.get(self.base_url + "/get_memory_pool_size")
assert response.status_code == 200


if __name__ == "__main__":
unittest.main()

0 comments on commit 680cad2

Please sign in to comment.