Skip to content

Commit

Permalink
BUG: Fix concurrent ops in worker initialization (#2125)
Browse files Browse the repository at this point in the history
  • Loading branch information
frostyplanet authored Aug 23, 2024
1 parent 4d35b6f commit b66c1d4
Showing 1 changed file with 46 additions and 39 deletions.
85 changes: 46 additions & 39 deletions xinference/core/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,11 @@
from ..device_utils import get_available_device_env_name, gpu_count
from ..model.core import ModelDescription, create_model_instance
from ..types import PeftModelConfig
from .cache_tracker import CacheTrackerActor
from .event import Event, EventCollectorActor, EventType
from .metrics import launch_metrics_export_server, record_metrics
from .resource import gather_node_info
from .status_guard import StatusGuardActor
from .utils import log_async, log_sync, parse_replica_model_uid, purge_dir

logger = getLogger(__name__)
Expand Down Expand Up @@ -71,6 +73,15 @@ def __init__(
self._supervisor_ref: Optional[xo.ActorRefType] = None
self._main_pool = main_pool
self._main_pool.recover_sub_pool = self.recover_sub_pool
self._status_guard_ref: xo.ActorRefType[ # type: ignore
"StatusGuardActor"
] = None
self._event_collector_ref: xo.ActorRefType[ # type: ignore
EventCollectorActor
] = None
self._cache_tracker_ref: xo.ActorRefType[ # type: ignore
CacheTrackerActor
] = None

# internal states.
# temporary placeholder during model launch process:
Expand Down Expand Up @@ -308,56 +319,50 @@ async def get_supervisor_ref(self, add_worker: bool = True) -> xo.ActorRefType:
Params:
add_worker: When True (the default), call supervisor.add_worker after the first connection, provided this worker currently holds no models.
"""
from .status_guard import StatusGuardActor
from .supervisor import SupervisorActor

if self._supervisor_ref is not None:
return self._supervisor_ref
self._supervisor_ref: xo.ActorRefType["SupervisorActor"] = await xo.actor_ref( # type: ignore
supervisor_ref = await xo.actor_ref( # type: ignore
address=self._supervisor_address, uid=SupervisorActor.uid()
)
# A concurrent caller may have set the ref while we awaited above; check again to avoid double initialization.
if self._supervisor_ref is not None:
return self._supervisor_ref
self._supervisor_ref = supervisor_ref
if add_worker and len(self._model_uid_to_model) == 0:
# Newly started (or restarted), has no model, notify supervisor
await self._supervisor_ref.add_worker(self.address)
logger.info("Connected to supervisor as a fresh worker")

self._status_guard_ref: xo.ActorRefType[ # type: ignore
"StatusGuardActor"
] = await xo.actor_ref(
address=self._supervisor_address, uid=StatusGuardActor.uid()
)

self._event_collector_ref: xo.ActorRefType[ # type: ignore
EventCollectorActor
] = await xo.actor_ref(
address=self._supervisor_address, uid=EventCollectorActor.uid()
)
from .cache_tracker import CacheTrackerActor

self._cache_tracker_ref: xo.ActorRefType[ # type: ignore
"CacheTrackerActor"
] = await xo.actor_ref(
address=self._supervisor_address, uid=CacheTrackerActor.uid()
)
# cache_tracker is on supervisor
from ..model.audio import get_audio_model_descriptions
from ..model.embedding import get_embedding_model_descriptions
from ..model.flexible import get_flexible_model_descriptions
from ..model.image import get_image_model_descriptions
from ..model.llm import get_llm_model_descriptions
from ..model.rerank import get_rerank_model_descriptions

# record model version
model_version_infos: Dict[str, List[Dict]] = {} # type: ignore
model_version_infos.update(get_llm_model_descriptions())
model_version_infos.update(get_embedding_model_descriptions())
model_version_infos.update(get_rerank_model_descriptions())
model_version_infos.update(get_image_model_descriptions())
model_version_infos.update(get_audio_model_descriptions())
model_version_infos.update(get_flexible_model_descriptions())
await self._cache_tracker_ref.record_model_version(
model_version_infos, self.address
)
self._status_guard_ref = await xo.actor_ref(
address=self._supervisor_address, uid=StatusGuardActor.uid()
)
self._event_collector_ref = await xo.actor_ref(
address=self._supervisor_address, uid=EventCollectorActor.uid()
)
self._cache_tracker_ref = await xo.actor_ref(
address=self._supervisor_address, uid=CacheTrackerActor.uid()
)
# cache_tracker is on supervisor
from ..model.audio import get_audio_model_descriptions
from ..model.embedding import get_embedding_model_descriptions
from ..model.flexible import get_flexible_model_descriptions
from ..model.image import get_image_model_descriptions
from ..model.llm import get_llm_model_descriptions
from ..model.rerank import get_rerank_model_descriptions

# record model version
model_version_infos: Dict[str, List[Dict]] = {} # type: ignore
model_version_infos.update(get_llm_model_descriptions())
model_version_infos.update(get_embedding_model_descriptions())
model_version_infos.update(get_rerank_model_descriptions())
model_version_infos.update(get_image_model_descriptions())
model_version_infos.update(get_audio_model_descriptions())
model_version_infos.update(get_flexible_model_descriptions())
await self._cache_tracker_ref.record_model_version(
model_version_infos, self.address
)
return self._supervisor_ref

@staticmethod
Expand Down Expand Up @@ -793,6 +798,7 @@ async def launch_builtin_model(
logger.exception(e)
raise
try:
_ = await self.get_supervisor_ref()
if self._event_collector_ref is not None:
await self._event_collector_ref.report_event(
origin_uid,
Expand Down Expand Up @@ -914,6 +920,7 @@ async def terminate_model(self, model_uid: str):
raise ValueError(f"{model_uid} is launching")
origin_uid, _, __ = parse_replica_model_uid(model_uid)
try:
_ = await self.get_supervisor_ref()
if self._event_collector_ref is not None:
await self._event_collector_ref.report_event(
origin_uid,
Expand Down

0 comments on commit b66c1d4

Please sign in to comment.