ENH: Added the parameter 'worker_ip' to the 'register' model. (#1773)
Co-authored-by: wuzhaoxin <15667065080@162.com>
hainaweiben and wuzhaoxin authored Jul 12, 2024
1 parent e916d05 commit 5e3f254
Showing 6 changed files with 226 additions and 25 deletions.
4 changes: 3 additions & 1 deletion xinference/api/restful_api.py
@@ -133,6 +133,7 @@ class SpeechRequest(BaseModel):

class RegisterModelRequest(BaseModel):
model: str
worker_ip: Optional[str]
persist: bool


@@ -1639,11 +1640,12 @@ async def query_engines_by_model_name(self, model_name: str) -> JSONResponse:
async def register_model(self, model_type: str, request: Request) -> JSONResponse:
body = RegisterModelRequest.parse_obj(await request.json())
model = body.model
worker_ip = body.worker_ip
persist = body.persist

try:
await (await self._get_supervisor_ref()).register_model(
model_type, model, persist
model_type, model, persist, worker_ip
)
except ValueError as re:
logger.error(re, exc_info=True)
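
With this change, POST /v1/model_registrations/{model_type} accepts an optional worker_ip field alongside model and persist. Below is a minimal sketch of calling the updated endpoint over HTTP; the endpoint path and body keys come from this diff, while the supervisor address, model type, model definition, and IP are placeholder assumptions:

import json
import requests

base_url = "http://127.0.0.1:9997"   # placeholder supervisor address
model_type = "LLM"                   # placeholder model type

custom_model = {
    "model_name": "my-custom-llm",   # placeholder custom model definition
    # ... remaining fields of the custom model JSON ...
}

response = requests.post(
    f"{base_url}/v1/model_registrations/{model_type}",
    json={
        "model": json.dumps(custom_model),  # model definition serialized as a string
        "worker_ip": "192.168.1.10",        # new field: pin the registration to one worker
        "persist": True,
    },
)
response.raise_for_status()
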
12 changes: 10 additions & 2 deletions xinference/client/restful/restful_client.py
@@ -1101,7 +1101,13 @@ def describe_model(self, model_uid: str):
)
return response.json()

def register_model(self, model_type: str, model: str, persist: bool):
def register_model(
self,
model_type: str,
model: str,
persist: bool,
worker_ip: Optional[str] = None,
):
"""
Register a custom model.
@@ -1111,6 +1117,8 @@ def register_model(self, model_type: str, model: str, persist: bool):
The type of model.
model: str
The model definition. (refer to: https://inference.readthedocs.io/en/latest/models/custom.html)
worker_ip: Optional[str]
The IP address of the worker on which the model is running.
persist: bool
@@ -1120,7 +1128,7 @@ def register_model(self, model_type: str, model: str, persist: bool):
Report failure to register the custom model. Provide details of failure through error message.
"""
url = f"{self.base_url}/v1/model_registrations/{model_type}"
request_body = {"model": model, "persist": persist}
request_body = {"model": model, "worker_ip": worker_ip, "persist": persist}
response = requests.post(url, json=request_body, headers=self._headers)
if response.status_code != 200:
raise RuntimeError(
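
Equivalently, through the Python client. A usage sketch assuming the RESTful client is constructed as xinference.client.RESTfulClient, with placeholder address, definition file, and worker IP:

from xinference.client import RESTfulClient

client = RESTfulClient("http://127.0.0.1:9997")   # placeholder supervisor address

with open("my_custom_llm.json") as f:             # placeholder custom model definition
    model_definition = f.read()

client.register_model(
    model_type="LLM",
    model=model_definition,
    persist=True,
    worker_ip="192.168.1.10",   # new keyword: register only on this worker
)
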
73 changes: 59 additions & 14 deletions xinference/core/supervisor.py
@@ -513,10 +513,15 @@ def sort_helper(item):
assert isinstance(item["model_name"], str)
return item.get("model_name").lower()

ret = []
if not self.is_local_deployment():
workers = list(self._worker_address_to_worker.values())
for worker in workers:
ret.extend(await worker.list_model_registrations(model_type, detailed))

if model_type == "LLM":
from ..model.llm import BUILTIN_LLM_FAMILIES, get_user_defined_llm_families

ret = []
for family in BUILTIN_LLM_FAMILIES:
if detailed:
ret.append(await self._to_llm_reg(family, True))
@@ -535,7 +540,6 @@ def sort_helper(item):
from ..model.embedding import BUILTIN_EMBEDDING_MODELS
from ..model.embedding.custom import get_user_defined_embeddings

ret = []
for model_name, family in BUILTIN_EMBEDDING_MODELS.items():
if detailed:
ret.append(
@@ -560,7 +564,6 @@ def sort_helper(item):
from ..model.image import BUILTIN_IMAGE_MODELS
from ..model.image.custom import get_user_defined_images

ret = []
for model_name, family in BUILTIN_IMAGE_MODELS.items():
if detailed:
ret.append(await self._to_image_model_reg(family, is_builtin=True))
@@ -583,7 +586,6 @@ def sort_helper(item):
from ..model.audio import BUILTIN_AUDIO_MODELS
from ..model.audio.custom import get_user_defined_audios

ret = []
for model_name, family in BUILTIN_AUDIO_MODELS.items():
if detailed:
ret.append(await self._to_audio_model_reg(family, is_builtin=True))
@@ -606,7 +608,6 @@ def sort_helper(item):
from ..model.rerank import BUILTIN_RERANK_MODELS
from ..model.rerank.custom import get_user_defined_reranks

ret = []
for model_name, family in BUILTIN_RERANK_MODELS.items():
if detailed:
ret.append(await self._to_rerank_model_reg(family, is_builtin=True))
@@ -646,7 +647,15 @@ def sort_helper(item):
raise ValueError(f"Unsupported model type: {model_type}")

@log_sync(logger=logger)
def get_model_registration(self, model_type: str, model_name: str) -> Any:
async def get_model_registration(self, model_type: str, model_name: str) -> Any:
# search in worker first
if not self.is_local_deployment():
workers = list(self._worker_address_to_worker.values())
for worker in workers:
f = await worker.get_model_registration(model_type, model_name)
if f is not None:
return f

if model_type == "LLM":
from ..model.llm import BUILTIN_LLM_FAMILIES, get_user_defined_llm_families

@@ -705,6 +714,13 @@ async def query_engines_by_model_name(self, model_name: str):

from ..model.llm.llm_family import LLM_ENGINES

# search in worker first
workers = list(self._worker_address_to_worker.values())
for worker in workers:
res = await worker.query_engines_by_model_name(model_name)
if res is not None:
return res

if model_name not in LLM_ENGINES:
raise ValueError(f"Model {model_name} not found")

@@ -718,7 +734,13 @@ async def query_engines_by_model_name(self, model_name: str):
return engine_params

@log_async(logger=logger)
async def register_model(self, model_type: str, model: str, persist: bool):
async def register_model(
self,
model_type: str,
model: str,
persist: bool,
worker_ip: Optional[str] = None,
):
if model_type in self._custom_register_type_to_cls:
(
model_spec_cls,
@@ -727,17 +749,30 @@ async def register_model(self, model_type: str, model: str, persist: bool):
generate_fn,
) = self._custom_register_type_to_cls[model_type]

if not self.is_local_deployment():
workers = list(self._worker_address_to_worker.values())
for worker in workers:
await worker.register_model(model_type, model, persist)
target_ip_worker_ref = (
self._get_worker_ref_by_ip(worker_ip) if worker_ip is not None else None
)
if (
worker_ip is not None
and not self.is_local_deployment()
and target_ip_worker_ref is None
):
raise ValueError(
f"Worker ip address {worker_ip} is not in the cluster."
)

if target_ip_worker_ref:
await target_ip_worker_ref.register_model(model_type, model, persist)
return

model_spec = model_spec_cls.parse_raw(model)
try:
register_fn(model_spec, persist)
await self._cache_tracker_ref.record_model_version(
generate_fn(model_spec), self.address
)
except ValueError as e:
raise e
except Exception as e:
unregister_fn(model_spec.model_name, raise_error=False)
raise e
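
The routing above depends on self._get_worker_ref_by_ip, which is not shown in this diff. Here is a hypothetical, standalone sketch of that lookup, assuming worker addresses are stored as "ip:port" keys in the supervisor's _worker_address_to_worker mapping (the mapping and the address.split(":")[0] convention both appear elsewhere in this commit):

from typing import Any, Dict, Optional

def get_worker_ref_by_ip(workers: Dict[str, Any], ip: str) -> Optional[Any]:
    # Hypothetical lookup: compare the requested IP against the "ip:port"
    # keys of the worker registry and return the matching worker reference.
    for address, worker_ref in workers.items():
        if address.split(":")[0] == ip:
            return worker_ref
    return None

# Toy usage: only the first worker matches the requested IP.
workers = {"192.168.1.10:34567": "worker-a", "192.168.1.11:34567": "worker-b"}
assert get_worker_ref_by_ip(workers, "192.168.1.10") == "worker-a"
assert get_worker_ref_by_ip(workers, "10.0.0.1") is None
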
@@ -748,13 +783,14 @@ async def register_model(self, model_type: str, model: str, persist: bool):
async def unregister_model(self, model_type: str, model_name: str):
if model_type in self._custom_register_type_to_cls:
_, _, unregister_fn, _ = self._custom_register_type_to_cls[model_type]
unregister_fn(model_name)
await self._cache_tracker_ref.unregister_model_version(model_name)
unregister_fn(model_name, False)

if not self.is_local_deployment():
workers = list(self._worker_address_to_worker.values())
for worker in workers:
await worker.unregister_model(model_name)
await worker.unregister_model(model_type, model_name)

await self._cache_tracker_ref.unregister_model_version(model_name)
else:
raise ValueError(f"Unsupported model type: {model_type}")

@@ -825,6 +861,14 @@ async def launch_builtin_model(
download_hub: Optional[Literal["huggingface", "modelscope", "csghub"]] = None,
**kwargs,
) -> str:
# search in worker first
if not self.is_local_deployment():
workers = list(self._worker_address_to_worker.values())
for worker in workers:
res = await worker.get_model_registration(model_type, model_name)
if res is not None:
worker_ip = worker.address.split(":")[0]

target_ip_worker_ref = (
self._get_worker_ref_by_ip(worker_ip) if worker_ip is not None else None
)
@@ -877,6 +921,7 @@ async def _launch_one_model(_replica_model_uid):
)
replica_gpu_idx = assign_replica_gpu(_replica_model_uid, gpu_idx)
nonlocal model_type

worker_ref = (
target_ip_worker_ref
if target_ip_worker_ref is not None
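
With the launch_builtin_model change above, a custom model registered on one worker is also launched there: the supervisor scans each worker's registrations and, on a match, derives worker_ip from that worker's address. A short follow-up sketch, assuming the model was registered with worker_ip as in the earlier example and that launch_model behaves as in the public client:

from xinference.client import RESTfulClient

client = RESTfulClient("http://127.0.0.1:9997")   # placeholder supervisor address

# No worker_ip is passed here; the supervisor resolves the worker that
# holds the "my-custom-llm" registration and routes the launch to it.
model_uid = client.launch_model(model_name="my-custom-llm", model_type="LLM")
print(model_uid)
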
(Diffs for the remaining three changed files were not loaded on this page.)
