Skip to content

Commit

Permalink
Add LLM Engine
Browse files Browse the repository at this point in the history
  • Loading branch information
JianyuZhan committed Aug 30, 2024
1 parent f414352 commit 60efeb7
Show file tree
Hide file tree
Showing 26 changed files with 989 additions and 548 deletions.
8 changes: 8 additions & 0 deletions .github/workflows/e2e-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,14 @@ jobs:
pip install -e "python[all]"
pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/ --force-reinstall
- name: Set PYTHONPATH
run: |
echo "PYTHONPATH=$PYTHONPATH:$(pwd)/python" >> $GITHUB_ENV
- name: Verify import
run: |
python3 -c "import sglang.srt.serving"
- name: Benchmark Serving Throughput
timeout-minutes: 10
run: |
Expand Down
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ coverage.xml
.hypothesis/
.pytest_cache/
cover/
human-eval/

# Translations
*.mo
Expand Down
25 changes: 25 additions & 0 deletions examples/usage/llm_engine.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
"""Minimal offline-inference example for the sglang LLM engine.

Builds a batch of prompts, runs one `generate` call, and prints each
prompt next to its completion.
"""
from sglang import LLM, SamplingParams

# Sample prompts to complete in a single batch.
prompts = [
    "Hello, my name is",
    "The capital of China is",
    "What is the meaning of life?",
    "The future of AI is",
]

# Sampling configuration shared by all prompts.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

# Create an LLM engine (single-GPU; raise tensor_parallel_size to shard).
llm = LLM(model="deepseek-ai/deepseek-llm-7b-chat", tensor_parallel_size=1)

# One batched call; outputs may come back in any order, so each output
# carries the index of the prompt it answers.
outputs = llm.generate(prompts, sampling_params)

# Print each prompt/answer pair.
for output in outputs:
    index = output["index"]
    prompt = prompts[index]
    answer = output["text"]  # fix: was assigned but unused (re-indexed below)
    print("===============================")
    print(f"Prompt: {prompt}")
    print(f"Generated text: {answer}")
4 changes: 4 additions & 0 deletions python/sglang/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
# SGL API Components

from sglang.api import (
LLM,
Runtime,
SamplingParams,
assistant,
assistant_begin,
assistant_end,
Expand Down Expand Up @@ -30,6 +32,8 @@

# SGLang DSL APIs
__all__ = [
"LLM",
"SamplingParams",
"Runtime",
"assistant",
"assistant_begin",
Expand Down
4 changes: 3 additions & 1 deletion python/sglang/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
SglSelect,
SglVideo,
)
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.serving.engine import LLM


def function(
Expand All @@ -35,7 +37,7 @@ def decorator(func):
def Runtime(*args, **kwargs):
# Avoid importing unnecessary dependency
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
from sglang.srt.server import Runtime
from sglang.srt.serving.server import Runtime

return Runtime(*args, **kwargs)

Expand Down
2 changes: 1 addition & 1 deletion python/sglang/bench_latency.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@
from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import ServerArgs
from sglang.srt.serving.server_args import ServerArgs
from sglang.srt.utils import suppress_other_loggers


Expand Down
4 changes: 2 additions & 2 deletions python/sglang/launch_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
import argparse
import os

from sglang.srt.server import launch_server
from sglang.srt.server_args import ServerArgs
from sglang.srt.serving.server import launch_server
from sglang.srt.serving.server_args import ServerArgs
from sglang.srt.utils import kill_child_process

if __name__ == "__main__":
Expand Down
29 changes: 12 additions & 17 deletions python/sglang/srt/managers/controller_multi.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
FlushCacheReq,
TokenizedGenerateReqInput,
)
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.serving.engine_args import EngineArgs
from sglang.srt.utils import configure_logger, kill_parent_process
from sglang.utils import get_exception_traceback

Expand Down Expand Up @@ -69,22 +69,19 @@ class ControllerMulti:

def __init__(
self,
server_args: ServerArgs,
port_args: PortArgs,
model_overide_args,
engine_args: EngineArgs,
):
# Parse args
self.server_args = server_args
self.port_args = port_args
self.model_overide_args = model_overide_args
self.engine_args = engine_args
self.model_overide_args = engine_args.model_override_args
self.load_balance_method = LoadBalanceMethod.from_str(
server_args.load_balance_method
engine_args.load_balance_method
)

# Init communication
context = zmq.Context()
self.recv_from_tokenizer = context.socket(zmq.PULL)
self.recv_from_tokenizer.bind(f"tcp://127.0.0.1:{port_args.controller_port}")
self.recv_from_tokenizer.bind(f"tcp://127.0.0.1:{engine_args.controller_port}")

# Dispatch method
self.round_robin_counter = 0
Expand All @@ -96,11 +93,11 @@ def __init__(

# Start data parallel workers
self.workers = []
for i in range(server_args.dp_size):
for i in range(engine_args.dp_size):
self.start_dp_worker(i)

def start_dp_worker(self, dp_worker_id: int):
tp_size = self.server_args.tp_size
tp_size = self.engine_args.tp_size

pipe_controller_reader, pipe_controller_writer = multiprocessing.Pipe(
duplex=False
Expand All @@ -111,7 +108,7 @@ def start_dp_worker(self, dp_worker_id: int):
proc = multiprocessing.Process(
target=start_controller_process_single,
args=(
self.server_args,
self.engine_args,
self.port_args,
pipe_controller_writer,
self.model_overide_args,
Expand Down Expand Up @@ -186,17 +183,15 @@ def recv_requests(self):


def start_controller_process(
server_args: ServerArgs,
port_args: PortArgs,
engine_args: EngineArgs,
pipe_writer,
model_overide_args: dict,
):
"""Start a controller process."""

configure_logger(server_args)
configure_logger(engine_args.log_level)

try:
controller = ControllerMulti(server_args, port_args, model_overide_args)
controller = ControllerMulti(engine_args)
except Exception:
pipe_writer.send(get_exception_traceback())
raise
Expand Down
38 changes: 15 additions & 23 deletions python/sglang/srt/managers/controller_single.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
broadcast_recv_input,
launch_tp_servers,
)
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.serving.engine_args import EngineArgs
from sglang.srt.utils import configure_logger, kill_parent_process
from sglang.utils import get_exception_traceback

Expand All @@ -38,16 +38,14 @@ class ControllerSingle:

def __init__(
self,
server_args: ServerArgs,
port_args: PortArgs,
model_overide_args: dict,
engine_args: EngineArgs,
gpu_ids: List[int],
is_data_parallel_worker: bool,
dp_worker_id: int,
mp_queue: multiprocessing.Queue,
):
# Parse args
self.tp_size = server_args.tp_size
self.tp_size = engine_args.tp_size
self.is_dp_worker = is_data_parallel_worker
self.dp_worker_id = dp_worker_id
self.mp_queue = mp_queue
Expand All @@ -58,34 +56,32 @@ def __init__(
if not self.is_dp_worker:
self.recv_from_tokenizer = context.socket(zmq.PULL)
self.recv_from_tokenizer.bind(
f"tcp://127.0.0.1:{port_args.controller_port}"
f"tcp://127.0.0.1:{engine_args.controller_port}"
)

self.send_to_detokenizer = context.socket(zmq.PUSH)
self.send_to_detokenizer.connect(
f"tcp://127.0.0.1:{port_args.detokenizer_port}"
f"tcp://127.0.0.1:{engine_args.detokenizer_port}"
)

# Launch other tp ranks
tp_size_local = server_args.tp_size // server_args.nnodes
tp_size_local = engine_args.tp_size // engine_args.nnodes
self.tp_procs = []
if tp_size_local > 1:
tp_rank_range = range(1, tp_size_local)
self.tp_procs = launch_tp_servers(
gpu_ids,
tp_rank_range,
server_args,
port_args.nccl_ports[dp_worker_id],
model_overide_args,
engine_args.nccl_ports[dp_worker_id],
engine_args,
)

# Launch tp rank 0
self.tp_server = ModelTpServer(
gpu_ids[0],
0,
server_args,
port_args.nccl_ports[dp_worker_id],
model_overide_args,
engine_args.nccl_ports[dp_worker_id],
engine_args,
)
self.tp_cpu_group = self.tp_server.model_runner.tp_group.cpu_group

Expand Down Expand Up @@ -123,10 +119,8 @@ def recv_requests_from_mp_queue(self):


def start_controller_process(
server_args: ServerArgs,
port_args: PortArgs,
engine_args: EngineArgs,
pipe_writer: multiprocessing.connection.Connection,
model_overide_args: dict,
is_data_parallel_worker: bool = False,
gpu_ids: List[int] = None,
dp_worker_id: int = None,
Expand All @@ -137,19 +131,17 @@ def start_controller_process(
logger_prefix = f" DP{dp_worker_id} TP0"
else:
logger_prefix = " TP0"
configure_logger(server_args, prefix=logger_prefix)
configure_logger(engine_args.log_level, prefix=logger_prefix)

if not is_data_parallel_worker:
tp_size_local = server_args.tp_size // server_args.nnodes
gpu_ids = [i for _ in range(server_args.nnodes) for i in range(tp_size_local)]
tp_size_local = engine_args.tp_size // engine_args.nnodes
gpu_ids = [i for _ in range(engine_args.nnodes) for i in range(tp_size_local)]
dp_worker_id = 0
queue = None

try:
controller = ControllerSingle(
server_args,
port_args,
model_overide_args,
engine_args,
gpu_ids,
is_data_parallel_worker,
dp_worker_id,
Expand Down
27 changes: 14 additions & 13 deletions python/sglang/srt/managers/detokenizer_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
UpdateWeightReqOutput,
)
from sglang.srt.managers.schedule_batch import FINISH_MATCHED_STR
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.serving.engine_args import EngineArgs
from sglang.utils import find_printable_text, get_exception_traceback

asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
Expand All @@ -53,24 +53,23 @@ class DetokenizerManager:

def __init__(
self,
server_args: ServerArgs,
port_args: PortArgs,
engine_args: EngineArgs,
):
# Init inter-process communication
context = zmq.asyncio.Context(2)
self.recv_from_router = context.socket(zmq.PULL)
self.recv_from_router.bind(f"tcp://127.0.0.1:{port_args.detokenizer_port}")
self.recv_from_router.bind(f"tcp://127.0.0.1:{engine_args.detokenizer_port}")

self.send_to_tokenizer = context.socket(zmq.PUSH)
self.send_to_tokenizer.connect(f"tcp://127.0.0.1:{port_args.tokenizer_port}")
self.send_to_tokenizer.connect(f"tcp://127.0.0.1:{engine_args.tokenizer_port}")

if server_args.skip_tokenizer_init:
if engine_args.skip_tokenizer_init:
self.tokenizer = None
else:
self.tokenizer = get_tokenizer(
server_args.tokenizer_path,
tokenizer_mode=server_args.tokenizer_mode,
trust_remote_code=server_args.trust_remote_code,
engine_args.tokenizer_path,
tokenizer_mode=engine_args.tokenizer_mode,
trust_remote_code=engine_args.trust_remote_code,
)

self.decode_status = {}
Expand Down Expand Up @@ -171,15 +170,17 @@ async def handle_loop(self):


def start_detokenizer_process(
server_args: ServerArgs,
port_args: PortArgs,
engine_args: EngineArgs,
pipe_writer,
):
try:
manager = DetokenizerManager(server_args, port_args)
manager = DetokenizerManager(engine_args)
except Exception:
pipe_writer.send(get_exception_traceback())
raise
pipe_writer.send("init ok")
loop = asyncio.get_event_loop()
# Create a new event loop for this process because asyncio.get_event_loop()
# does not return a loop in a new thread or process in Python 3.10+.
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop.run_until_complete(manager.handle_loop())
Loading

0 comments on commit 60efeb7

Please sign in to comment.