Skip to content

Commit bfb6281

Browse files
committed
WIP
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
1 parent 807fb03 commit bfb6281

File tree

14 files changed

+1166
-0
lines changed

14 files changed

+1166
-0
lines changed

backend/python/mlx-vlm/Makefile

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
# Build helpers for the mlx-vlm Python backend.

# Generated gRPC/protobuf stubs, shared by the protogen rules below.
PROTO_STUBS := backend_pb2_grpc.py backend_pb2.py

.PHONY: mlx run test protogen protogen-clean clean

# Default goal: generate the stubs, then run the install script.
mlx: protogen
	bash install.sh

# Start the backend through run.sh.
run: protogen
	@echo "Running mlx..."
	bash run.sh
	@echo "mlx run."

# Run the backend tests through test.sh.
test: protogen
	@echo "Testing mlx..."
	bash test.sh
	@echo "mlx tested."

# Make sure both generated stubs exist.
protogen: $(PROTO_STUBS)

# Delete the generated stubs.
protogen-clean:
	$(RM) $(PROTO_STUBS)

# Generate the Python stubs from backend.proto.
$(PROTO_STUBS):
	python3 -m grpc_tools.protoc -I../.. -I./ --python_out=. --grpc_python_out=. backend.proto

# Remove the generated stubs, the virtualenv and the bytecode cache.
clean: protogen-clean
	rm -rf venv __pycache__

backend/python/mlx-vlm/backend.py

Lines changed: 367 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,367 @@
1+
#!/usr/bin/env python3
2+
import asyncio
3+
from concurrent import futures
4+
import argparse
5+
import signal
6+
import sys
7+
import os
8+
from typing import List
9+
from PIL import Image
10+
11+
import backend_pb2
12+
import backend_pb2_grpc
13+
14+
import grpc
15+
from vllm.engine.arg_utils import AsyncEngineArgs
16+
from vllm.engine.async_llm_engine import AsyncLLMEngine
17+
from vllm.sampling_params import SamplingParams
18+
from vllm.utils import random_uuid
19+
from vllm.transformers_utils.tokenizer import get_tokenizer
20+
from vllm.multimodal.utils import fetch_image
21+
from vllm.assets.video import VideoAsset
22+
import base64
23+
import io
24+
25+
# Number of seconds in one day.
_ONE_DAY_IN_SECONDS = 24 * 60 * 60

# Thread-pool size for the gRPC server: taken from PYTHON_GRPC_MAX_WORKERS
# when that environment variable is set, otherwise 1.
MAX_WORKERS = int(os.getenv('PYTHON_GRPC_MAX_WORKERS', '1'))
29+
30+
# Implement the BackendServicer class with the service methods
31+
class BackendServicer(backend_pb2_grpc.BackendServicer):
    """
    A gRPC servicer that implements the Backend service defined in backend.proto.

    ``LoadModel`` creates a vLLM ``AsyncLLMEngine`` (``self.llm``) and a
    tokenizer (``self.tokenizer``); ``Predict`` and ``PredictStream`` generate
    text with that engine. Images and videos are received as base64-encoded
    payloads and decoded by ``load_image`` / ``load_video``.
    """

    def generate(self, prompt, max_new_tokens):
        """
        Generates text based on the given prompt and maximum number of new tokens.

        NOTE(review): this method relies on ``self.generator``, which nothing in
        this class assigns. It looks like a leftover from another backend;
        confirm before calling it.

        Args:
            prompt (str): The prompt to generate text from.
            max_new_tokens (int): The maximum number of new tokens to generate.

        Returns:
            str: The generated text.
        """
        generator = self.generator
        generator.end_beam_search()

        # Tokenizing the input
        ids = generator.tokenizer.encode(prompt)

        generator.gen_begin_reuse(ids)
        # Length of the sequence before generation, so that only the newly
        # generated tokens are decoded below.
        initial_len = generator.sequence[0].shape[0]
        has_leading_space = False
        decoded_text = ''
        for i in range(max_new_tokens):
            token = generator.gen_single_token()
            # Remember whether the first generated piece starts with the '▁'
            # marker so a leading space can be re-added to the decoded text.
            if i == 0 and generator.tokenizer.tokenizer.IdToPiece(int(token)).startswith('▁'):
                has_leading_space = True

            decoded_text = generator.tokenizer.decode(generator.sequence[0][initial_len:])
            if has_leading_space:
                decoded_text = ' ' + decoded_text

            if token.item() == generator.tokenizer.eos_token_id:
                break
        return decoded_text

    def Health(self, request, context):
        """
        Returns a health check message.

        Args:
            request: The health check request.
            context: The gRPC context.

        Returns:
            backend_pb2.Reply: The health check reply.
        """
        return backend_pb2.Reply(message=bytes("OK", 'utf-8'))

    async def LoadModel(self, request, context):
        """
        Loads a language model.

        Builds the vLLM engine arguments from the request, creates the async
        engine and then the tokenizer that matches the engine's model config.

        Args:
            request: The load model request.
            context: The gRPC context.

        Returns:
            backend_pb2.Result: The load model result. ``success`` is False and
            ``message`` describes the error when the engine or the tokenizer
            cannot be created.
        """
        engine_args = AsyncEngineArgs(
            model=request.Model,
        )

        # Request field -> engine argument. A field is copied only when it
        # carries a non-default value (non-empty string, non-zero number, True).
        request_to_engine_args = {
            "Quantization": "quantization",
            "LoadFormat": "load_format",
            "GPUMemoryUtilization": "gpu_memory_utilization",
            "TrustRemoteCode": "trust_remote_code",
            "EnforceEager": "enforce_eager",
            "TensorParallelSize": "tensor_parallel_size",
            "SwapSpace": "swap_space",
            "MaxModelLen": "max_model_len",
            # vLLM names this engine argument ``disable_log_stats``; assigning
            # ``disable_log_status`` only created an unused attribute.
            "DisableLogStatus": "disable_log_stats",
            "DType": "dtype",
        }
        for request_field, engine_field in request_to_engine_args.items():
            value = getattr(request, request_field)
            if value:
                setattr(engine_args, engine_field, value)

        if request.LimitImagePerPrompt != 0 or request.LimitVideoPerPrompt != 0 or request.LimitAudioPerPrompt != 0:
            # limit-mm-per-prompt defaults to 1 per modality, based on vLLM docs
            engine_args.limit_mm_per_prompt = {
                "image": max(request.LimitImagePerPrompt, 1),
                "video": max(request.LimitVideoPerPrompt, 1),
                "audio": max(request.LimitAudioPerPrompt, 1)
            }

        try:
            self.llm = AsyncLLMEngine.from_engine_args(engine_args)
        except Exception as err:
            print(f"Unexpected {err=}, {type(err)=}", file=sys.stderr)
            return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")

        try:
            engine_model_config = await self.llm.get_model_config()
            self.tokenizer = get_tokenizer(
                engine_model_config.tokenizer,
                tokenizer_mode=engine_model_config.tokenizer_mode,
                trust_remote_code=engine_model_config.trust_remote_code,
                truncation_side="left",
            )
        except Exception as err:
            # Log tokenizer failures too, like engine failures above.
            print(f"Unexpected {err=}, {type(err)=}", file=sys.stderr)
            return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
        print("Model loaded successfully", file=sys.stderr)
        return backend_pb2.Result(message="Model loaded successfully", success=True)

    async def Predict(self, request, context):
        """
        Generates text based on the given prompt and sampling parameters.

        Args:
            request: The predict request.
            context: The gRPC context.

        Returns:
            backend_pb2.Reply: The predict result.
        """
        gen = self._predict(request, context, streaming=False)
        try:
            # In non-streaming mode the generator yields one final reply.
            return await gen.__anext__()
        finally:
            # Close the generator explicitly instead of leaving it suspended.
            await gen.aclose()

    def Embedding(self, request, context):
        """
        A gRPC method that calculates embeddings for a given sentence.

        NOTE(review): nothing in this class assigns ``self.model``; until a
        model exposing ``encode`` is attached, the call is answered with
        FAILED_PRECONDITION instead of raising AttributeError.

        Args:
            request: An EmbeddingRequest object that contains the request parameters.
            context: A grpc.ServicerContext object that provides information about the RPC.

        Returns:
            An EmbeddingResult object that contains the calculated embeddings.
        """
        model = getattr(self, "model", None)
        if model is None:
            context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
            context.set_details("No embedding model is loaded.")
            return backend_pb2.EmbeddingResult()

        print("Calculated embeddings for: " + request.Embeddings, file=sys.stderr)
        outputs = model.encode(request.Embeddings)
        # Check if we have one result at least
        if len(outputs) == 0:
            context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
            context.set_details("No embeddings were calculated.")
            return backend_pb2.EmbeddingResult()
        return backend_pb2.EmbeddingResult(embeddings=outputs[0].outputs.embedding)

    async def PredictStream(self, request, context):
        """
        Generates text based on the given prompt and sampling parameters, and streams the results.

        Args:
            request: The predict stream request.
            context: The gRPC context.

        Yields:
            backend_pb2.Reply: One reply per chunk of newly generated text.
        """
        iterations = self._predict(request, context, streaming=True)
        try:
            async for iteration in iterations:
                yield iteration
        finally:
            await iterations.aclose()

    async def _predict(self, request, context, streaming=False):
        """
        Shared implementation of ``Predict`` and ``PredictStream``.

        Args:
            request: The predict request.
            context: The gRPC context.
            streaming (bool): When True, yield one reply per text delta; when
                False, yield a single reply holding the complete text.

        Yields:
            backend_pb2.Reply: The generated text, encoded as UTF-8.
        """
        # Build the sampling parameters
        # NOTE: this must stay in sync with the vllm backend
        request_to_sampling_params = {
            "N": "n",
            "PresencePenalty": "presence_penalty",
            "FrequencyPenalty": "frequency_penalty",
            "RepetitionPenalty": "repetition_penalty",
            "Temperature": "temperature",
            "TopP": "top_p",
            "TopK": "top_k",
            "MinP": "min_p",
            "Seed": "seed",
            "StopPrompts": "stop",
            "StopTokenIds": "stop_token_ids",
            "BadWords": "bad_words",
            "IncludeStopStrInOutput": "include_stop_str_in_output",
            "IgnoreEOS": "ignore_eos",
            "Tokens": "max_tokens",
            "MinTokens": "min_tokens",
            "Logprobs": "logprobs",
            "PromptLogprobs": "prompt_logprobs",
            "SkipSpecialTokens": "skip_special_tokens",
            "SpacesBetweenSpecialTokens": "spaces_between_special_tokens",
            "TruncatePromptTokens": "truncate_prompt_tokens",
            "GuidedDecoding": "guided_decoding",
        }

        sampling_params = SamplingParams(top_p=0.9, max_tokens=200)

        # Override the defaults only with fields the request actually sets.
        for request_field, param_field in request_to_sampling_params.items():
            if hasattr(request, request_field):
                value = getattr(request, request_field)
                if value not in (None, 0, [], False):
                    setattr(sampling_params, param_field, value)

        prompt = request.Prompt

        # Decode the attached media. Payloads that fail to decode come back as
        # None and are dropped so they are never handed to the engine.
        image_data = [
            image
            for image in (self.load_image(payload) for payload in request.Images)
            if image is not None
        ]
        video_data = [
            video
            for video in (self.load_video(payload) for payload in request.Videos)
            if video is not None
        ]

        # If tokenizer template is enabled and messages are provided instead of prompt, apply the tokenizer template
        if not request.Prompt and request.UseTokenizerTemplate and request.Messages:
            prompt = self.tokenizer.apply_chat_template(request.Messages, tokenize=False, add_generation_prompt=True)

        # Generate text using the LLM engine
        request_id = random_uuid()
        print(f"Generating text with request_id: {request_id}", file=sys.stderr)
        multi_modal_data = {}
        if image_data:
            multi_modal_data["image"] = image_data
        if video_data:
            multi_modal_data["video"] = video_data
        outputs = self.llm.generate(
            {
                "prompt": prompt,
                "multi_modal_data": multi_modal_data or None,
            },
            sampling_params=sampling_params,
            request_id=request_id,
        )

        # Stream the results
        generated_text = ""
        try:
            async for request_output in outputs:
                iteration_text = request_output.outputs[0].text

                if streaming:
                    # Remove text already sent as vllm concatenates the text from previous yields
                    delta_iteration_text = iteration_text.removeprefix(generated_text)
                    # Send the partial result
                    yield backend_pb2.Reply(message=bytes(delta_iteration_text, encoding='utf-8'))

                # Keep track of text generated
                generated_text = iteration_text
        finally:
            await outputs.aclose()

        # The entries of request.Images are base64 payloads decoded in memory,
        # not files, so nothing is deleted here: calling os.remove() on a
        # client-supplied string could never clean anything up and would let a
        # client name an arbitrary file to delete.

        # If streaming, we already sent everything; otherwise send the final text.
        if not streaming:
            yield backend_pb2.Reply(message=bytes(generated_text, encoding='utf-8'))

    def load_image(self, image_path: str):
        """
        Load an image from base64 encoded data.

        Args:
            image_path (str): The base64 encoded image data.

        Returns:
            Image: The loaded image, or None if the data cannot be decoded or opened.
        """
        try:
            image_data = base64.b64decode(image_path)
            return Image.open(io.BytesIO(image_data))
        except Exception as e:
            # Report only the payload size: the payload itself can be megabytes.
            print(f"Error loading image ({len(image_path)} base64 chars): {e}", file=sys.stderr)
            return None

    def load_video(self, video_path: str):
        """
        Load a video from base64 encoded data.

        The decoded bytes are written to a temporary file, which is read
        through ``VideoAsset`` and always removed afterwards.

        Args:
            video_path (str): The base64 encoded video data.

        Returns:
            Video: The loaded video frames, or None if loading fails.
        """
        # Local import so this method does not depend on the module header.
        import tempfile

        tmp_path = None
        try:
            # A unique temporary file replaces the timestamp-based name, which
            # needed the never-imported ``time`` module and could collide.
            with tempfile.NamedTemporaryFile(prefix="vl-", suffix=".data", delete=False) as f:
                tmp_path = f.name
                f.write(base64.b64decode(video_path))
            return VideoAsset(name=tmp_path).np_ndarrays
        except Exception as e:
            # Report only the payload size: the payload itself can be megabytes.
            print(f"Error loading video ({len(video_path)} base64 chars): {e}", file=sys.stderr)
            return None
        finally:
            # Remove the temporary file on success and on failure.
            if tmp_path is not None:
                try:
                    os.remove(tmp_path)
                except OSError as e:
                    print(f"Error removing temporary video file: {tmp_path}, {e}", file=sys.stderr)
333+
334+
async def serve(address):
    """
    Run the asyncio gRPC server until it is terminated.

    Args:
        address (str): The address to bind the server to.
    """
    # Start asyncio gRPC server
    server = grpc.aio.server(migration_thread_pool=futures.ThreadPoolExecutor(max_workers=MAX_WORKERS),
        options=[
            ('grpc.max_message_length', 50 * 1024 * 1024),  # 50MB
            ('grpc.max_send_message_length', 50 * 1024 * 1024),  # 50MB
            ('grpc.max_receive_message_length', 50 * 1024 * 1024),  # 50MB
        ])
    # Add the servicer to the server
    backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
    # Bind the server to the address
    server.add_insecure_port(address)

    # Gracefully shutdown the server on SIGTERM or SIGINT.
    # get_running_loop() is the supported way to obtain the loop from inside
    # a coroutine; get_event_loop() is deprecated for this use.
    loop = asyncio.get_running_loop()
    for sig in (signal.SIGINT, signal.SIGTERM):
        loop.add_signal_handler(
            sig, lambda: asyncio.ensure_future(server.stop(5))
        )

    # Start the server
    await server.start()
    print("Server started. Listening on: " + address, file=sys.stderr)
    # Wait for the server to be terminated
    await server.wait_for_termination()
359+
360+
if __name__ == "__main__":
    # Read the bind address from the command line, then run the server
    # until it terminates.
    cli = argparse.ArgumentParser(description="Run the gRPC server.")
    cli.add_argument(
        "--addr",
        default="localhost:50051",
        help="The address to bind the server to.",
    )
    bind_address = cli.parse_args().addr

    asyncio.run(serve(bind_address))

0 commit comments

Comments
 (0)