
Commit 744fcbd

[REFACTOR] Migrate JSONFFIEngine to formal namespace
This PR migrates JSONFFIEngine to a formal namespace. It also lists TODOs to further simplify the JSONFFIEngine.
1 parent fd65973 commit 744fcbd

3 files changed, +318 −295 lines changed

python/mlc_llm/json_ffi/__init__.py

Lines changed: 8 additions & 0 deletions

@@ -0,0 +1,8 @@
"""JSON FFI is a pure string based interface of MLC LLM Engine.

We build interfacing with JSON FFI for both testing purposes
and internal use. For most python API usage, please use MLCEngine
and MLCAsyncEngine
"""

from .engine import JSONFFIEngine
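For reference, a minimal sketch of what the namespace move means for callers; the import path below comes directly from this __init__.py, while everything else is illustrative:

# JSONFFIEngine is re-exported by the new package __init__.py, so callers can
# import it from the formal mlc_llm.json_ffi namespace introduced in this commit.
from mlc_llm.json_ffi import JSONFFIEngine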

python/mlc_llm/json_ffi/engine.py

Lines changed: 309 additions & 0 deletions
@@ -0,0 +1,309 @@
# pylint: disable=chained-comparison,line-too-long,missing-docstring,
# pylint: disable=too-many-arguments,too-many-locals,unused-argument,unused-variable
import json
import queue
import threading
from typing import Any, Callable, Dict, Iterator, List, Literal, Optional, Union

import tvm

from mlc_llm.protocol import openai_api_protocol
from mlc_llm.serve import engine_utils
from mlc_llm.serve.engine_base import (
    EngineConfig,
    SpeculativeMode,
    _infer_kv_cache_config,
    _parse_models,
    _process_model_args,
    detect_device,
)
from mlc_llm.tokenizer import Tokenizer


# TODO(mlc-team): further minimize the JSONFFIEngine
# construction to not depend on any config and directly pass in JSON
# model defined generation config should be read from the JSONFFIEngine via Reload
def create_model_defined_generation_config(
    temperature: float, top_p: float, frequency_penalty: float, presence_penalty: float
) -> tvm.runtime.Object:
    return tvm.get_global_func("mlc.json_ffi.ModelDefinedGenerationConfig")(
        temperature,
        top_p,
        frequency_penalty,
        presence_penalty,
    )

# TODO(mlc-team): further minimize the JSONFFIEngine
# Engine config should be passed as json str
# and backend should have good default
# only model and model_lib should be mandatory
def create_json_ffi_engine_config(
    conv_template: str, model_generation_cfgs: Dict[str, tvm.runtime.Object]
) -> tvm.runtime.Object:
    return tvm.get_global_func("mlc.json_ffi.JSONFFIEngineConfig")(
        conv_template, model_generation_cfgs
    )


class EngineState:
    sync_queue: queue.Queue

    def get_request_stream_callback(self) -> Callable[[List[str]], None]:
        # ChatCompletionStreamResponse

        def _callback(chat_completion_stream_responses_json_str: List[str]) -> None:
            self._sync_request_stream_callback(chat_completion_stream_responses_json_str)

        return _callback

    def _sync_request_stream_callback(
        self, chat_completion_stream_responses_json_str: List[str]
    ) -> None:
        # Put the delta outputs to the queue in the unblocking way.
        self.sync_queue.put_nowait(chat_completion_stream_responses_json_str)


class JSONFFIEngine:
    def __init__(  # pylint: disable=too-many-arguments,too-many-locals
        self,
        model: str,
        device: Union[str, tvm.runtime.Device] = "auto",
        *,
        model_lib_path: Optional[str] = None,
        mode: Literal["local", "interactive", "server"] = "local",
        additional_models: Optional[List[str]] = None,
        max_batch_size: Optional[int] = None,
        max_total_sequence_length: Optional[int] = None,
        max_history_size: Optional[int] = None,
        prefill_chunk_size: Optional[int] = None,
        speculative_mode: SpeculativeMode = SpeculativeMode.DISABLE,
        spec_draft_length: int = 4,
        gpu_memory_utilization: Optional[float] = None,
    ) -> None:
        # - Initialize model loading info.
        models = _parse_models(model, model_lib_path, additional_models)
        if isinstance(device, str):
            device = detect_device(device)
        assert isinstance(device, tvm.runtime.Device)
        (
            model_args,
            model_config_paths,
            self.conv_template,
        ) = _process_model_args(models, device)

        # TODO(mlc-team) Remove the model config parsing, estimation below
        # in favor of a simple direct passing of parameters into backend.
        # JSONFFIEngine do not have to support automatic mode
        #
        # Instead, its config should default to interactive mode always
        # and allow overrides of parameters through json config via reload
        #
        # This is to simplify the logic of users of JSONFFI
        # since we won't have similar logics in android/iOS
        #
        # - Load the raw model config into dict
        self.model_config_dicts = []
        for i, model_info in enumerate(models):
            model_info.model_lib_path = model_args[i][1]
            with open(model_config_paths[i], "r", encoding="utf-8") as file:
                self.model_config_dicts.append(json.load(file))

        # - Decide the KV cache config based on mode and user input.
        (
            max_batch_size,
            max_total_sequence_length,
            prefill_chunk_size,
            max_single_sequence_length,
            max_history_size,
            kv_state_kind,
        ) = _infer_kv_cache_config(
            mode,
            max_batch_size,
            max_total_sequence_length,
            prefill_chunk_size,
            max_history_size,
            gpu_memory_utilization,
            models,
            device,
            self.model_config_dicts,
            model_config_paths,
        )
        self.max_input_sequence_length = min(max_single_sequence_length, max_total_sequence_length)

        # - Initialize engine state and engine.
        self.state = EngineState()
        module = tvm.get_global_func("mlc.json_ffi.CreateJSONFFIEngine", allow_missing=False)()
        self._ffi = {
            key: module[key]
            for key in [
                "init_background_engine",
                "reload",
                "unload",
                "reset",
                "chat_completion",
                "abort",
                "get_last_error",
                "run_background_loop",
                "run_background_stream_back_loop",
                "exit_background_loop",
            ]
        }
        self.tokenizer = Tokenizer(model_args[0][0])

        self.engine_config = EngineConfig(
            model=model_args[0][0],
            model_lib_path=model_args[0][1],
            additional_models=[model_arg[0] for model_arg in model_args[1:]],
            additional_model_lib_paths=[model_arg[1] for model_arg in model_args[1:]],
            kv_cache_page_size=16,
            max_num_sequence=max_batch_size,
            max_total_sequence_length=max_total_sequence_length,
            max_single_sequence_length=max_single_sequence_length,
            prefill_chunk_size=prefill_chunk_size,
            max_history_size=max_history_size,
            kv_state_kind=kv_state_kind,
            speculative_mode=speculative_mode,
            spec_draft_length=spec_draft_length,
        )

        self.json_ffi_engine_config = create_json_ffi_engine_config(
            conv_template=self.conv_template.model_dump_json(),
            model_generation_cfgs={
                model.model: create_model_defined_generation_config(
                    temperature=model_config["temperature"],
                    top_p=model_config["top_p"],
                    frequency_penalty=model_config["frequency_penalty"],
                    presence_penalty=model_config["presence_penalty"],
                )
                for model, model_config in zip(models, self.model_config_dicts)
            },
        )

        self._ffi["init_background_engine"](
            self.json_ffi_engine_config,
            self.engine_config,
            device,
            self.state.get_request_stream_callback(),
            None,
        )

        def _background_loop():
            self._ffi["run_background_loop"]()

        def _background_stream_back_loop():
            self._ffi["run_background_stream_back_loop"]()

        # Create the background engine-driving thread and start the loop.
        self._background_loop_thread: threading.Thread = threading.Thread(target=_background_loop)
        self._background_stream_back_loop_thread: threading.Thread = threading.Thread(
            target=_background_stream_back_loop
        )
        self._background_loop_thread.start()
        self._background_stream_back_loop_thread.start()
        self._terminated = False

    def terminate(self):
        self._terminated = True
        self._ffi["exit_background_loop"]()
        self._background_loop_thread.join()
        self._background_stream_back_loop_thread.join()

    def chat_completion(  # pylint: disable=too-many-arguments,too-many-locals
        self,
        *,
        messages: List[Dict[str, Any]],
        model: str,
        frequency_penalty: Optional[float] = None,
        presence_penalty: Optional[float] = None,
        logprobs: bool = False,
        top_logprobs: int = 0,
        logit_bias: Optional[Dict[int, float]] = None,
        max_tokens: Optional[int] = None,
        n: int = 1,
        seed: Optional[int] = None,
        stop: Optional[Union[str, List[str]]] = None,
        stream: bool = False,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        tools: Optional[List[Dict[str, Any]]] = None,
        tool_choice: Optional[Union[Literal["none", "auto"], Dict]] = None,
        user: Optional[str] = None,
        ignore_eos: bool = False,
        response_format: Optional[Dict[str, Any]] = None,
        request_id: Optional[str] = None,
    ) -> Iterator[openai_api_protocol.ChatCompletionStreamResponse]:
        if request_id is None:
            request_id = f"chatcmpl-{engine_utils.random_uuid()}"

        chatcmpl_generator = self._handle_chat_completion(
            openai_api_protocol.ChatCompletionRequest(
                messages=[
                    openai_api_protocol.ChatCompletionMessage.model_validate(message)
                    for message in messages
                ],
                model=model,
                frequency_penalty=frequency_penalty,
                presence_penalty=presence_penalty,
                logprobs=logprobs,
                top_logprobs=top_logprobs,
                logit_bias=logit_bias,
                max_tokens=max_tokens,
                n=n,
                seed=seed,
                stop=stop,
                stream=stream,
                temperature=temperature,
                top_p=top_p,
                tools=(
                    [openai_api_protocol.ChatTool.model_validate(tool) for tool in tools]
                    if tools is not None
                    else None
                ),
                tool_choice=tool_choice,
                user=user,
                ignore_eos=ignore_eos,
                response_format=(
                    openai_api_protocol.RequestResponseFormat.model_validate(response_format)
                    if response_format is not None
                    else None
                ),
            ).model_dump_json(),
            n=n,
            request_id=request_id,
        )
        for response in chatcmpl_generator:
            yield response

    def _handle_chat_completion(
        self, request_json_str: str, n: int, request_id: str
    ) -> Iterator[openai_api_protocol.ChatCompletionStreamResponse]:
        self.state.sync_queue = queue.Queue()
        num_unfinished_requests = n

        success = bool(self._ffi["chat_completion"](request_json_str, request_id))

        try:
            while num_unfinished_requests > 0:
                chat_completion_stream_responses_json_str = self.state.sync_queue.get()
                for chat_completion_response_json_str in chat_completion_stream_responses_json_str:
                    chat_completion_response = (
                        openai_api_protocol.ChatCompletionStreamResponse.model_validate_json(
                            chat_completion_response_json_str
                        )
                    )
                    for choice in chat_completion_response.choices:
                        if choice.finish_reason is not None:
                            num_unfinished_requests -= 1
                    yield chat_completion_response
        except Exception as exception:  # pylint: disable=broad-exception-caught
            self._ffi["abort"](request_id)
            raise exception

    def _test_reload(self):
        self._ffi["reload"](self.engine_config)

    def _test_reset(self):
        self._ffi["reset"]()

    def _test_unload(self):
        self._ffi["unload"]()

0 commit comments
