Skip to content

Commit

Permalink
[Python] Extract common device str parse function in ChatModule (mlc-…
Browse files Browse the repository at this point in the history
…ai#1074)

This PR lifts the device string parsing (just a few of lines)
to a standalone function, so that on the serving side the serving
can make use of this function as well.

Tested Python API and it does not seem to incur regression.
  • Loading branch information
MasterJH5574 authored Oct 16, 2023
1 parent d202077 commit 9872c48
Showing 1 changed file with 33 additions and 7 deletions.
40 changes: 33 additions & 7 deletions python/mlc_chat/chat_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -548,6 +548,38 @@ def _convert_generation_config_to_json_str(generation_config: Optional[Generatio
return json.dumps(asdict(generation_config))


def _parse_device_str(device: str):
"""Parse the input device identifier into device name and id.
Parameters
----------
device : str
The device identifier to parse.
It can be "device_name" (e.g., "cuda") or
"device_name:device_id" (e.g., "cuda:1").
Returns
-------
device_name : str
The name of the device.
device_id : int
The id of the device, or 0 if not specified in the input.
"""
device_err_msg = (
f"Invalid device name: {device}. Please enter the device in the form "
"'device_name:device_id' or 'device_name', where 'device_name' needs to be "
"one of 'cuda', 'metal', 'vulkan', 'rocm', 'opencl', 'auto'."
)
device_args = device.split(":")
if len(device_args) == 1:
return device_args[0], 0
elif len(device_args) == 2:
return device_args[0], int(device_args[1])
elif len(device_args) > 2:
raise ValueError(device_err_msg)


def _detect_local_device(device_id: int = 0):
"""Automatically detect the local device if user does not specify.
Expand Down Expand Up @@ -647,13 +679,7 @@ def __init__(
)

# 0. Retrieve device_name and device_id (if any, default 0) from device arg
device_args = device.split(":")
if len(device_args) == 1:
device_name, device_id = device_args[0], 0
elif len(device_args) == 2:
device_name, device_id = device_args[0], int(device_args[1])
elif len(device_args) > 2:
raise ValueError(device_err_msg)
device_name, device_id = _parse_device_str(device)

# 1. Get self.device
if device_name == "cuda":
Expand Down

0 comments on commit 9872c48

Please sign in to comment.