Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Bug]: Fix concurrent ops in worker initialization #2125

Merged
merged 1 commit into from
Aug 23, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 46 additions & 39 deletions xinference/core/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,11 @@
from ..device_utils import get_available_device_env_name, gpu_count
from ..model.core import ModelDescription, create_model_instance
from ..types import PeftModelConfig
from .cache_tracker import CacheTrackerActor
from .event import Event, EventCollectorActor, EventType
from .metrics import launch_metrics_export_server, record_metrics
from .resource import gather_node_info
from .status_guard import StatusGuardActor
from .utils import log_async, log_sync, parse_replica_model_uid, purge_dir

logger = getLogger(__name__)
Expand Down Expand Up @@ -71,6 +73,15 @@ def __init__(
self._supervisor_ref: Optional[xo.ActorRefType] = None
self._main_pool = main_pool
self._main_pool.recover_sub_pool = self.recover_sub_pool
self._status_guard_ref: xo.ActorRefType[ # type: ignore
"StatusGuardActor"
] = None
self._event_collector_ref: xo.ActorRefType[ # type: ignore
EventCollectorActor
] = None
self._cache_tracker_ref: xo.ActorRefType[ # type: ignore
CacheTrackerActor
] = None

# internal states.
# temporary placeholder during model launch process:
Expand Down Expand Up @@ -308,56 +319,50 @@ async def get_supervisor_ref(self, add_worker: bool = True) -> xo.ActorRefType:
Params:
add_worker: If True (the default), call supervisor.add_worker after the first connection, provided this worker holds no models
"""
from .status_guard import StatusGuardActor
from .supervisor import SupervisorActor

if self._supervisor_ref is not None:
return self._supervisor_ref
self._supervisor_ref: xo.ActorRefType["SupervisorActor"] = await xo.actor_ref( # type: ignore
supervisor_ref = await xo.actor_ref( # type: ignore
address=self._supervisor_address, uid=SupervisorActor.uid()
)
# A concurrent caller may have set the reference while we awaited above; check again to avoid double initialization.
if self._supervisor_ref is not None:
return self._supervisor_ref
self._supervisor_ref = supervisor_ref
if add_worker and len(self._model_uid_to_model) == 0:
# Newly started (or restarted), has no model, notify supervisor
await self._supervisor_ref.add_worker(self.address)
logger.info("Connected to supervisor as a fresh worker")

self._status_guard_ref: xo.ActorRefType[ # type: ignore
"StatusGuardActor"
] = await xo.actor_ref(
address=self._supervisor_address, uid=StatusGuardActor.uid()
)

self._event_collector_ref: xo.ActorRefType[ # type: ignore
EventCollectorActor
] = await xo.actor_ref(
address=self._supervisor_address, uid=EventCollectorActor.uid()
)
from .cache_tracker import CacheTrackerActor

self._cache_tracker_ref: xo.ActorRefType[ # type: ignore
"CacheTrackerActor"
] = await xo.actor_ref(
address=self._supervisor_address, uid=CacheTrackerActor.uid()
)
# cache_tracker is on supervisor
from ..model.audio import get_audio_model_descriptions
from ..model.embedding import get_embedding_model_descriptions
from ..model.flexible import get_flexible_model_descriptions
from ..model.image import get_image_model_descriptions
from ..model.llm import get_llm_model_descriptions
from ..model.rerank import get_rerank_model_descriptions

# record model version
model_version_infos: Dict[str, List[Dict]] = {} # type: ignore
model_version_infos.update(get_llm_model_descriptions())
model_version_infos.update(get_embedding_model_descriptions())
model_version_infos.update(get_rerank_model_descriptions())
model_version_infos.update(get_image_model_descriptions())
model_version_infos.update(get_audio_model_descriptions())
model_version_infos.update(get_flexible_model_descriptions())
await self._cache_tracker_ref.record_model_version(
model_version_infos, self.address
)
self._status_guard_ref = await xo.actor_ref(
address=self._supervisor_address, uid=StatusGuardActor.uid()
)
self._event_collector_ref = await xo.actor_ref(
address=self._supervisor_address, uid=EventCollectorActor.uid()
)
self._cache_tracker_ref = await xo.actor_ref(
address=self._supervisor_address, uid=CacheTrackerActor.uid()
)
# cache_tracker is on supervisor
from ..model.audio import get_audio_model_descriptions
from ..model.embedding import get_embedding_model_descriptions
from ..model.flexible import get_flexible_model_descriptions
from ..model.image import get_image_model_descriptions
from ..model.llm import get_llm_model_descriptions
from ..model.rerank import get_rerank_model_descriptions

# record model version
model_version_infos: Dict[str, List[Dict]] = {} # type: ignore
model_version_infos.update(get_llm_model_descriptions())
model_version_infos.update(get_embedding_model_descriptions())
model_version_infos.update(get_rerank_model_descriptions())
model_version_infos.update(get_image_model_descriptions())
model_version_infos.update(get_audio_model_descriptions())
model_version_infos.update(get_flexible_model_descriptions())
await self._cache_tracker_ref.record_model_version(
model_version_infos, self.address
)
return self._supervisor_ref

@staticmethod
Expand Down Expand Up @@ -793,6 +798,7 @@ async def launch_builtin_model(
logger.exception(e)
raise
try:
_ = await self.get_supervisor_ref()
qinxuye marked this conversation as resolved.
Show resolved Hide resolved
if self._event_collector_ref is not None:
await self._event_collector_ref.report_event(
origin_uid,
Expand Down Expand Up @@ -914,6 +920,7 @@ async def terminate_model(self, model_uid: str):
raise ValueError(f"{model_uid} is launching")
origin_uid, _, __ = parse_replica_model_uid(model_uid)
try:
_ = await self.get_supervisor_ref()
if self._event_collector_ref is not None:
await self._event_collector_ref.report_event(
origin_uid,
Expand Down
Loading