ENH: [worker] Allow init supervisor_ref lazy (xorbitsai#1958)
frostyplanet authored Aug 9, 2024
1 parent 9a02b40 commit 93e6604
Showing 1 changed file with 124 additions and 84 deletions.
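Summary of the change: __post_create__ no longer resolves the supervisor, status-guard, event-collector, and cache-tracker references eagerly. A new get_supervisor_ref(add_worker=True) helper connects on first use, calls add_worker only when the worker holds no models (a fresh or restarted worker), registers model version info with the cache tracker, and caches the reference. Startup wraps the first connection in a try/except so a supervisor that is down no longer crashes the worker, and call sites that previously used self._supervisor_ref directly now go through the helper, with refs that may still be None guarded. Below is a minimal sketch of the get-or-create core of that pattern, stripped of the xinference specifics; the connect callable is a placeholder for xo.actor_ref(...), not the real xoscar API.

    from typing import Any, Awaitable, Callable, Optional


    class LazySupervisorRef:
        """Minimal sketch of the lazy get-or-create pattern; `connect` is a placeholder."""

        def __init__(
            self,
            supervisor_address: str,
            connect: Callable[[str], Awaitable[Any]],
        ) -> None:
            self._supervisor_address = supervisor_address
            self._connect = connect
            self._supervisor_ref: Optional[Any] = None  # resolved on first use, not at startup

        async def get_supervisor_ref(self) -> Any:
            # Reuse the cached reference after the first successful connect.
            if self._supervisor_ref is not None:
                return self._supervisor_ref
            # First use: connect now; exceptions propagate to the caller, which can
            # log them and retry on the next call instead of failing at startup.
            self._supervisor_ref = await self._connect(self._supervisor_address)
            return self._supervisor_ref
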
208 changes: 124 additions & 84 deletions xinference/core/worker.py
@@ -68,7 +68,7 @@ def __init__(
# static attrs.
self._total_gpu_devices = gpu_devices
self._supervisor_address = supervisor_address
self._supervisor_ref = None
self._supervisor_ref: Optional[xo.ActorRefType] = None
self._main_pool = main_pool
self._main_pool.recover_sub_pool = self.recover_sub_pool

@@ -147,14 +147,15 @@ async def recover_sub_pool(self, address):
)
event_model_uid, _, __ = parse_replica_model_uid(model_uid)
try:
await self._event_collector_ref.report_event(
event_model_uid,
Event(
event_type=EventType.WARNING,
event_ts=int(time.time()),
event_content="Recreate model",
),
)
if self._event_collector_ref is not None:
await self._event_collector_ref.report_event(
event_model_uid,
Event(
event_type=EventType.WARNING,
event_ts=int(time.time()),
event_content="Recreate model",
),
)
except Exception as e:
# Report callback error can be log and ignore, should not interrupt the Process
logger.error("report_event error: %s" % (e))
@@ -177,80 +178,39 @@ def uid(cls) -> str:
return "worker"

async def __post_create__(self):
from ..isolation import Isolation
from .cache_tracker import CacheTrackerActor
from .status_guard import StatusGuardActor
from .supervisor import SupervisorActor

self._status_guard_ref: xo.ActorRefType[ # type: ignore
"StatusGuardActor"
] = await xo.actor_ref(
address=self._supervisor_address, uid=StatusGuardActor.uid()
)
self._event_collector_ref: xo.ActorRefType[ # type: ignore
EventCollectorActor
] = await xo.actor_ref(
address=self._supervisor_address, uid=EventCollectorActor.uid()
)
self._cache_tracker_ref: xo.ActorRefType[ # type: ignore
"CacheTrackerActor"
] = await xo.actor_ref(
address=self._supervisor_address, uid=CacheTrackerActor.uid()
)
self._supervisor_ref: xo.ActorRefType["SupervisorActor"] = await xo.actor_ref( # type: ignore
address=self._supervisor_address, uid=SupervisorActor.uid()
)
await self._supervisor_ref.add_worker(self.address)
if not XINFERENCE_DISABLE_HEALTH_CHECK:
# Run _periodical_report_status() in a dedicated thread.
self._isolation = Isolation(asyncio.new_event_loop(), threaded=True)
self._isolation.start()
asyncio.run_coroutine_threadsafe(
self._periodical_report_status(), loop=self._isolation.loop
)
logger.info(f"Xinference worker {self.address} started")
logger.info("Purge cache directory: %s", XINFERENCE_CACHE_DIR)
purge_dir(XINFERENCE_CACHE_DIR)

from ..model.audio import (
CustomAudioModelFamilyV1,
generate_audio_description,
get_audio_model_descriptions,
register_audio,
unregister_audio,
)
from ..model.embedding import (
CustomEmbeddingModelSpec,
generate_embedding_description,
get_embedding_model_descriptions,
register_embedding,
unregister_embedding,
)
from ..model.flexible import (
FlexibleModelSpec,
generate_flexible_model_description,
get_flexible_model_descriptions,
register_flexible_model,
unregister_flexible_model,
)
from ..model.image import (
CustomImageModelFamilyV1,
generate_image_description,
get_image_model_descriptions,
register_image,
unregister_image,
)
from ..model.llm import (
CustomLLMFamilyV1,
generate_llm_description,
get_llm_model_descriptions,
register_llm,
unregister_llm,
)
from ..model.rerank import (
CustomRerankModelSpec,
generate_rerank_description,
get_rerank_model_descriptions,
register_rerank,
unregister_rerank,
)
@@ -294,24 +254,33 @@ async def __post_create__(self):
),
}

# record model version
model_version_infos: Dict[str, List[Dict]] = {} # type: ignore
model_version_infos.update(get_llm_model_descriptions())
model_version_infos.update(get_embedding_model_descriptions())
model_version_infos.update(get_rerank_model_descriptions())
model_version_infos.update(get_image_model_descriptions())
model_version_infos.update(get_audio_model_descriptions())
model_version_infos.update(get_flexible_model_descriptions())
await self._cache_tracker_ref.record_model_version(
model_version_infos, self.address
)
logger.info("Purge cache directory: %s", XINFERENCE_CACHE_DIR)
purge_dir(XINFERENCE_CACHE_DIR)

try:
await self.get_supervisor_ref(add_worker=True)
except Exception as e:
# Do not crash the worker if supervisor is down, auto re-connect later
logger.error(f"cannot connect to supervisor {e}")

if not XINFERENCE_DISABLE_HEALTH_CHECK:
from ..isolation import Isolation

# Run _periodical_report_status() in a dedicated thread.
self._isolation = Isolation(asyncio.new_event_loop(), threaded=True)
self._isolation.start()
asyncio.run_coroutine_threadsafe(
self._periodical_report_status(), loop=self._isolation.loop
)
logger.info(f"Xinference worker {self.address} started")

# Windows does not have signal handler
if os.name != "nt":

async def signal_handler():
try:
await self._supervisor_ref.remove_worker(self.address)
supervisor_ref = await self.get_supervisor_ref(add_worker=False)
await supervisor_ref.remove_worker(self.address)
except Exception as e:
# Ignore the error of rpc, anyway we are exiting
logger.exception("remove worker rpc error: %s", e)
@@ -333,6 +302,64 @@ async def trigger_exit(self) -> bool:
return False
return True

async def get_supervisor_ref(self, add_worker: bool = True) -> xo.ActorRefType:
"""
Try connect to supervisor and return ActorRef. Raise exception on error
Params:
add_worker: By default will call supervisor.add_worker after first connect
"""
from .status_guard import StatusGuardActor
from .supervisor import SupervisorActor

if self._supervisor_ref is not None:
return self._supervisor_ref
self._supervisor_ref: xo.ActorRefType["SupervisorActor"] = await xo.actor_ref( # type: ignore
address=self._supervisor_address, uid=SupervisorActor.uid()
)
if add_worker and len(self._model_uid_to_model) == 0:
# Newly started (or restarted), has no model, notify supervisor
await self._supervisor_ref.add_worker(self.address)
logger.info("Connected to supervisor as a fresh worker")

self._status_guard_ref: xo.ActorRefType[ # type: ignore
"StatusGuardActor"
] = await xo.actor_ref(
address=self._supervisor_address, uid=StatusGuardActor.uid()
)

self._event_collector_ref: xo.ActorRefType[ # type: ignore
EventCollectorActor
] = await xo.actor_ref(
address=self._supervisor_address, uid=EventCollectorActor.uid()
)
from .cache_tracker import CacheTrackerActor

self._cache_tracker_ref: xo.ActorRefType[ # type: ignore
"CacheTrackerActor"
] = await xo.actor_ref(
address=self._supervisor_address, uid=CacheTrackerActor.uid()
)
# cache_tracker is on supervisor
from ..model.audio import get_audio_model_descriptions
from ..model.embedding import get_embedding_model_descriptions
from ..model.flexible import get_flexible_model_descriptions
from ..model.image import get_image_model_descriptions
from ..model.llm import get_llm_model_descriptions
from ..model.rerank import get_rerank_model_descriptions

# record model version
model_version_infos: Dict[str, List[Dict]] = {} # type: ignore
model_version_infos.update(get_llm_model_descriptions())
model_version_infos.update(get_embedding_model_descriptions())
model_version_infos.update(get_rerank_model_descriptions())
model_version_infos.update(get_image_model_descriptions())
model_version_infos.update(get_audio_model_descriptions())
model_version_infos.update(get_flexible_model_descriptions())
await self._cache_tracker_ref.record_model_version(
model_version_infos, self.address
)
return self._supervisor_ref

@staticmethod
def get_devices_count():
from ..device_utils import gpu_count
@@ -344,9 +371,9 @@ def get_model_count(self) -> int:
return len(self._model_uid_to_model)

async def is_model_vllm_backend(self, model_uid: str) -> bool:
assert self._supervisor_ref is not None
_model_uid, _, _ = parse_replica_model_uid(model_uid)
model_ref = await self._supervisor_ref.get_model(_model_uid)
supervisor_ref = await self.get_supervisor_ref()
model_ref = await supervisor_ref.get_model(_model_uid)
return await model_ref.is_vllm_backend()

async def allocate_devices_for_embedding(self, model_uid: str) -> int:
@@ -764,14 +791,15 @@ async def launch_builtin_model(
logger.exception(e)
raise
try:
await self._event_collector_ref.report_event(
origin_uid,
Event(
event_type=EventType.INFO,
event_ts=int(time.time()),
event_content="Launch model",
),
)
if self._event_collector_ref is not None:
await self._event_collector_ref.report_event(
origin_uid,
Event(
event_type=EventType.INFO,
event_ts=int(time.time()),
event_content="Launch model",
),
)
except Exception as e:
# Report callback error can be log and ignore, should not interrupt the Process
logger.error("report_event error: %s" % (e))
@@ -867,6 +895,11 @@ async def launch_builtin_model(

# update status to READY
abilities = await self._get_model_ability(model, model_type)
_ = await self.get_supervisor_ref(add_worker=False)

if self._status_guard_ref is None:
_ = await self.get_supervisor_ref()
assert self._status_guard_ref is not None
await self._status_guard_ref.update_instance_info(
origin_uid,
{"model_ability": abilities, "status": LaunchStatus.READY.name},
@@ -879,21 +912,23 @@ async def terminate_model(self, model_uid: str):
raise ValueError(f"{model_uid} is launching")
origin_uid, _, __ = parse_replica_model_uid(model_uid)
try:
await self._event_collector_ref.report_event(
origin_uid,
Event(
event_type=EventType.INFO,
event_ts=int(time.time()),
event_content="Terminate model",
),
)
if self._event_collector_ref is not None:
await self._event_collector_ref.report_event(
origin_uid,
Event(
event_type=EventType.INFO,
event_ts=int(time.time()),
event_content="Terminate model",
),
)
except Exception as e:
# Report callback error can be log and ignore, should not interrupt the Process
logger.error("report_event error: %s" % (e))

await self._status_guard_ref.update_instance_info(
origin_uid, {"status": LaunchStatus.TERMINATING.name}
)
if self._status_guard_ref is not None:
await self._status_guard_ref.update_instance_info(
origin_uid, {"status": LaunchStatus.TERMINATING.name}
)
model_ref = self._model_uid_to_model.get(model_uid, None)
if model_ref is None:
logger.debug("Model not found, uid: %s", model_uid)
@@ -918,6 +953,10 @@ async def terminate_model(self, model_uid: str):
self._model_uid_to_addr.pop(model_uid, None)
self._model_uid_to_recover_count.pop(model_uid, None)
self._model_uid_to_launch_args.pop(model_uid, None)

if self._status_guard_ref is None:
_ = await self.get_supervisor_ref()
assert self._status_guard_ref is not None
await self._status_guard_ref.update_instance_info(
origin_uid, {"status": LaunchStatus.TERMINATED.name}
)
@@ -970,7 +1009,8 @@ async def report_status(self):
raise
except Exception:
logger.exception("Report status got error.")
await self._supervisor_ref.report_worker_status(self.address, status)
supervisor_ref = await self.get_supervisor_ref()
await supervisor_ref.report_worker_status(self.address, status)

async def _periodical_report_status(self):
while True:

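To make the new behavior concrete, here is a self-contained, hypothetical toy (not the real WorkerActor, SupervisorActor, or xoscar API) that mirrors the flow added above: the first connection attempt at startup fails and is only logged, and the next RPC call resolves the supervisor reference lazily.

    import asyncio
    import logging

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger("worker-sketch")


    class FlakySupervisor:
        """Hypothetical stand-in for the remote supervisor; the first connect fails."""

        def __init__(self):
            self._attempts = 0

        async def connect(self, address: str):
            self._attempts += 1
            if self._attempts == 1:
                raise ConnectionError("supervisor not up yet")
            return self

        async def report_worker_status(self, address: str, status: dict):
            logger.info("status from %s: %s", address, status)


    class WorkerSketch:
        """Toy worker mirroring the commit's startup/report flow (not the real WorkerActor)."""

        def __init__(self, supervisor: FlakySupervisor):
            self.address = "127.0.0.1:30001"
            self._supervisor = supervisor
            self._supervisor_ref = None

        async def get_supervisor_ref(self):
            # Get-or-create: reuse the cached ref, otherwise connect now.
            if self._supervisor_ref is None:
                self._supervisor_ref = await self._supervisor.connect("supervisor:9999")
            return self._supervisor_ref

        async def post_create(self):
            # Startup tolerates a down supervisor (mirrors the try/except added to __post_create__).
            try:
                await self.get_supervisor_ref()
            except Exception as e:
                logger.error(f"cannot connect to supervisor {e}")

        async def report_status(self):
            # Call sites resolve (or reuse) the ref right before the RPC, as in report_status().
            supervisor_ref = await self.get_supervisor_ref()
            await supervisor_ref.report_worker_status(self.address, {"status": "ok"})


    async def main():
        worker = WorkerSketch(FlakySupervisor())
        await worker.post_create()    # first connect fails; worker keeps running
        await worker.report_status()  # lazy reconnect succeeds on the next call


    asyncio.run(main())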