[ux] cache cluster status of autostop or spot clusters for 2s #4332

Merged · 5 commits · Nov 15, 2024
Changes from all commits
169 changes: 102 additions & 67 deletions sky/backends/backend_utils.py
@@ -100,6 +100,10 @@
CLUSTER_STATUS_LOCK_PATH = os.path.expanduser('~/.sky/.{}.lock')
CLUSTER_STATUS_LOCK_TIMEOUT_SECONDS = 20

# Time that must elapse since the last status check before we should re-check if
# the cluster has been terminated or autostopped.
_CLUSTER_STATUS_CACHE_DURATION_SECONDS = 2

# Filelocks for updating cluster's file_mounts.
CLUSTER_FILE_MOUNTS_LOCK_PATH = os.path.expanduser(
'~/.sky/.{}_file_mounts.lock')
@@ -1668,11 +1672,27 @@ def check_can_clone_disk_and_override_task(

def _update_cluster_status_no_lock(
cluster_name: str) -> Optional[Dict[str, Any]]:
"""Updates the status of the cluster.
"""Update the cluster status.

The cluster status is updated by checking ray cluster and real status from
cloud.

The function will update the cached cluster status in the global state. For
the design of the cluster status and transition, please refer to the
sky/design_docs/cluster_status.md

Returns:
If the cluster is terminated or does not exist, return None. Otherwise
returns the input record with status and handle potentially updated.

Raises:
exceptions.ClusterOwnerIdentityMismatchError: if the current user is
not the same as the user who created the cluster.
exceptions.CloudUserIdentityError: if we fail to get the current user
identity.
exceptions.ClusterStatusFetchingError: the cluster status cannot be
fetched from the cloud provider.
fetched from the cloud provider, or there are leaked nodes causing
the number of nodes to be larger than expected.
"""
record = global_user_state.get_cluster_from_name(cluster_name)
if record is None:
@@ -1892,52 +1912,22 @@ def run_ray_status_to_check_ray_cluster_healthy() -> bool:
return global_user_state.get_cluster_from_name(cluster_name)


def _update_cluster_status(
cluster_name: str,
acquire_per_cluster_status_lock: bool,
cluster_status_lock_timeout: int = CLUSTER_STATUS_LOCK_TIMEOUT_SECONDS
) -> Optional[Dict[str, Any]]:
"""Update the cluster status.
def _must_refresh_cluster_status(
record: Dict[str, Any],
force_refresh_statuses: Optional[Set[status_lib.ClusterStatus]]
) -> bool:
force_refresh_for_cluster = (force_refresh_statuses is not None and
record['status'] in force_refresh_statuses)

The cluster status is updated by checking ray cluster and real status from
cloud.
use_spot = record['handle'].launched_resources.use_spot
has_autostop = (record['status'] != status_lib.ClusterStatus.STOPPED and
record['autostop'] >= 0)
recently_refreshed = (record['status_updated_at'] is not None and
time.time() - record['status_updated_at'] <
_CLUSTER_STATUS_CACHE_DURATION_SECONDS)
is_stale = (use_spot or has_autostop) and not recently_refreshed

The function will update the cached cluster status in the global state. For
the design of the cluster status and transition, please refer to the
sky/design_docs/cluster_status.md

Args:
cluster_name: The name of the cluster.
acquire_per_cluster_status_lock: Whether to acquire the per-cluster lock
before updating the status.
cluster_status_lock_timeout: The timeout to acquire the per-cluster
lock.

Returns:
If the cluster is terminated or does not exist, return None. Otherwise
returns the input record with status and handle potentially updated.

Raises:
exceptions.ClusterOwnerIdentityMismatchError: if the current user is
not the same as the user who created the cluster.
exceptions.CloudUserIdentityError: if we fail to get the current user
identity.
exceptions.ClusterStatusFetchingError: the cluster status cannot be
fetched from the cloud provider, or there are leaked nodes causing
the number of nodes to be larger than expected.
"""
if not acquire_per_cluster_status_lock:
return _update_cluster_status_no_lock(cluster_name)

try:
with filelock.FileLock(CLUSTER_STATUS_LOCK_PATH.format(cluster_name),
timeout=cluster_status_lock_timeout):
return _update_cluster_status_no_lock(cluster_name)
except filelock.Timeout:
logger.debug('Refreshing status: Failed to get the lock for cluster '
f'{cluster_name!r}. Using the cached status.')
record = global_user_state.get_cluster_from_name(cluster_name)
return record
return force_refresh_for_cluster or is_stale


def refresh_cluster_record(
@@ -1955,16 +1945,22 @@ def refresh_cluster_record(

Args:
cluster_name: The name of the cluster.
force_refresh_statuses: if specified, refresh the cluster if it has one of
the specified statuses. Additionally, clusters satisfying the
following conditions will always be refreshed no matter the
argument is specified or not:
1. is a spot cluster, or
2. is a non-spot cluster, is not STOPPED, and autostop is set.
force_refresh_statuses: if specified, refresh the cluster if it has one
of the specified statuses. Additionally, clusters satisfying the
following conditions will be refreshed whether or not this
argument is specified:
- the latest available status update is more than
_CLUSTER_STATUS_CACHE_DURATION_SECONDS old, and one of:
1. the cluster is a spot cluster, or
2. cluster autostop is set and the cluster is not STOPPED.
acquire_per_cluster_status_lock: Whether to acquire the per-cluster lock
before updating the status.
before updating the status. Even if this is True, the lock may not be
acquired if the status does not need to be refreshed.
cluster_status_lock_timeout: The timeout to acquire the per-cluster
lock. On timeout, the function will use the cached status.
lock. On timeout, the function will use the cached status. If the
value is <0, do not timeout (wait for the lock indefinitely). By
default, this is set to CLUSTER_STATUS_LOCK_TIMEOUT_SECONDS. Warning:
if correctness is required, you must set this to -1.

Returns:
If the cluster is terminated or does not exist, return None.
@@ -1985,19 +1981,58 @@
return None
check_owner_identity(cluster_name)

handle = record['handle']
if isinstance(handle, backends.CloudVmRayResourceHandle):
use_spot = handle.launched_resources.use_spot
has_autostop = (record['status'] != status_lib.ClusterStatus.STOPPED and
record['autostop'] >= 0)
force_refresh_for_cluster = (force_refresh_statuses is not None and
record['status'] in force_refresh_statuses)
if force_refresh_for_cluster or has_autostop or use_spot:
record = _update_cluster_status(
cluster_name,
acquire_per_cluster_status_lock=acquire_per_cluster_status_lock,
cluster_status_lock_timeout=cluster_status_lock_timeout)
return record
if not isinstance(record['handle'], backends.CloudVmRayResourceHandle):
return record

# The loop logic allows us to notice if the status was updated in the
# global_user_state by another process and stop trying to get the lock.
# The core loop logic is adapted from FileLock's implementation.
lock = filelock.FileLock(CLUSTER_STATUS_LOCK_PATH.format(cluster_name))
start_time = time.perf_counter()

# Loop until we have an up-to-date status or until we acquire the lock.
while True:
# Check to see if we can return the cached status.
if not _must_refresh_cluster_status(record, force_refresh_statuses):
return record

if not acquire_per_cluster_status_lock:
return _update_cluster_status_no_lock(cluster_name)

# Try to acquire the lock so we can fetch the status.
try:
with lock.acquire(blocking=False):
# Lock acquired.

# Check the cluster status again, since it could have been
# updated between our last check and acquiring the lock.
record = global_user_state.get_cluster_from_name(cluster_name)
if record is None or not _must_refresh_cluster_status(
record, force_refresh_statuses):
return record

# Update and return the cluster status.
return _update_cluster_status_no_lock(cluster_name)
except filelock.Timeout:
# lock.acquire() will throw a Timeout exception if the lock is not
# available and we have blocking=False.
pass

# Logic adapted from FileLock.acquire().
# If cluster_status_lock_timeout is <0, we will never hit this. No timeout.
# Otherwise, if we have timed out, return the cached status. This has
# the potential to cause correctness issues, but if so it is the
# caller's responsibility to set the timeout to -1.
if 0 <= cluster_status_lock_timeout < time.perf_counter() - start_time:
logger.debug('Refreshing status: Failed to get the lock for cluster '
f'{cluster_name!r}. Using the cached status.')
return record
time.sleep(0.05)

# Refresh for next loop iteration.
record = global_user_state.get_cluster_from_name(cluster_name)
if record is None:
return None


@timeline.event
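For a concrete picture of the cache-then-lock flow that `refresh_cluster_record` now implements, here is a minimal standalone sketch. It is illustrative only: `fetch_record`, `update_status`, and `must_refresh` are hypothetical stand-ins for `global_user_state.get_cluster_from_name`, `_update_cluster_status_no_lock`, and `_must_refresh_cluster_status`, and only the loop structure mirrors the diff above.

```python
import time
from typing import Any, Callable, Dict, Optional

import filelock

CACHE_DURATION_SECONDS = 2  # mirrors _CLUSTER_STATUS_CACHE_DURATION_SECONDS


def refresh_with_cache(
    lock_path: str,
    fetch_record: Callable[[], Optional[Dict[str, Any]]],
    update_status: Callable[[], Optional[Dict[str, Any]]],
    must_refresh: Callable[[Dict[str, Any]], bool],
    timeout: float = 20.0,
) -> Optional[Dict[str, Any]]:
    """Return a fresh-enough record, taking the lock only when needed."""
    record = fetch_record()
    lock = filelock.FileLock(lock_path)
    start = time.perf_counter()
    while True:
        # Fast path: the cached status is fresh enough; no lock taken.
        if record is None or not must_refresh(record):
            return record
        try:
            # Non-blocking acquire raises filelock.Timeout if the lock is
            # currently held by another process.
            with lock.acquire(blocking=False):
                # Re-read: another process may have refreshed the status
                # between our staleness check and acquiring the lock.
                record = fetch_record()
                if record is None or not must_refresh(record):
                    return record
                # We hold the lock, so it is safe to refresh for real.
                return update_status()
        except filelock.Timeout:
            pass  # Lock held elsewhere; poll again below.
        # timeout < 0 means wait indefinitely; otherwise fall back to the
        # cached (possibly stale) record once the deadline passes.
        if 0 <= timeout < time.perf_counter() - start:
            return record
        time.sleep(0.05)
        record = fetch_record()  # Pick up any refresh by the lock holder.
```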
2 changes: 1 addition & 1 deletion sky/backends/cloud_vm_ray_backend.py
@@ -3554,7 +3554,7 @@ def _teardown(self,
backend_utils.CLUSTER_STATUS_LOCK_PATH.format(cluster_name))

try:
with filelock.FileLock(
with timeline.FileLockEvent(
lock_path,
backend_utils.CLUSTER_STATUS_LOCK_TIMEOUT_SECONDS):
self.teardown_no_lock(
2 changes: 2 additions & 0 deletions sky/clouds/service_catalog/aws_catalog.py
@@ -20,6 +20,7 @@
from sky.utils import common_utils
from sky.utils import resources_utils
from sky.utils import rich_utils
from sky.utils import timeline
from sky.utils import ux_utils

if typing.TYPE_CHECKING:
@@ -100,6 +101,7 @@ def _get_az_mappings(aws_user_hash: str) -> Optional['pd.DataFrame']:
return az_mappings


@timeline.event
def _fetch_and_apply_az_mapping(df: common.LazyDataFrame) -> 'pd.DataFrame':
"""Maps zone IDs (use1-az1) to zone names (us-east-1x).

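`timeline.event` is applied here so that the AZ-mapping fetch shows up as a timed span in SkyPilot's timeline traces (the decorator is defined in `sky/utils/timeline.py`, whose diff appears further below). As a rough illustration of the idea only — the real decorator records into a trace file and also accepts a string name — a timing decorator of this general shape behaves like:

```python
import functools
import time
from typing import Any, Callable


def event(fn: Callable[..., Any]) -> Callable[..., Any]:
    """Illustrative stand-in: report how long the wrapped call takes."""
    @functools.wraps(fn)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        start = time.perf_counter()
        try:
            return fn(*args, **kwargs)
        finally:
            elapsed = time.perf_counter() - start
            print(f'[timeline] {fn.__qualname__}: {elapsed:.3f}s')
    return wrapper
```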
52 changes: 36 additions & 16 deletions sky/global_user_state.py
@@ -60,7 +60,8 @@ def create_table(cursor, conn):
owner TEXT DEFAULT null,
cluster_hash TEXT DEFAULT null,
storage_mounts_metadata BLOB DEFAULT null,
cluster_ever_up INTEGER DEFAULT 0)""")
cluster_ever_up INTEGER DEFAULT 0,
status_updated_at INTEGER DEFAULT null)""")

# Table for Cluster History
# usage_intervals: List[Tuple[int, int]]
@@ -130,6 +131,10 @@ def create_table(cursor, conn):
# clusters were never really UP, setting it to 1 means they won't be
# auto-deleted during any failover.
value_to_replace_existing_entries=1)

db_utils.add_column_to_table(cursor, conn, 'clusters', 'status_updated_at',
'INTEGER DEFAULT null')

conn.commit()


@@ -159,6 +164,7 @@ def add_or_update_cluster(cluster_name: str,
status = status_lib.ClusterStatus.INIT
if ready:
status = status_lib.ClusterStatus.UP
status_updated_at = int(time.time())

# TODO (sumanth): Cluster history table will have multiple entries
# when the cluster failover through multiple regions (one entry per region).
@@ -191,7 +197,7 @@
# specified.
'(name, launched_at, handle, last_use, status, '
'autostop, to_down, metadata, owner, cluster_hash, '
'storage_mounts_metadata, cluster_ever_up) '
'storage_mounts_metadata, cluster_ever_up, status_updated_at) '
'VALUES ('
# name
'?, '
@@ -228,7 +234,9 @@
'COALESCE('
'(SELECT storage_mounts_metadata FROM clusters WHERE name=?), null), '
# cluster_ever_up
'((SELECT cluster_ever_up FROM clusters WHERE name=?) OR ?)'
'((SELECT cluster_ever_up FROM clusters WHERE name=?) OR ?),'
# status_updated_at
'?'
')',
(
# name
@@ -260,6 +268,8 @@
# cluster_ever_up
cluster_name,
int(ready),
# status_updated_at
status_updated_at,
))

launched_nodes = getattr(cluster_handle, 'launched_nodes', None)
@@ -330,11 +340,13 @@ def remove_cluster(cluster_name: str, terminate: bool) -> None:
# stopped VM, which leads to timeout.
if hasattr(handle, 'stable_internal_external_ips'):
handle.stable_internal_external_ips = None
current_time = int(time.time())
_DB.cursor.execute(
'UPDATE clusters SET handle=(?), status=(?) '
'WHERE name=(?)', (
'UPDATE clusters SET handle=(?), status=(?), '
'status_updated_at=(?) WHERE name=(?)', (
pickle.dumps(handle),
status_lib.ClusterStatus.STOPPED.value,
current_time,
cluster_name,
))
_DB.conn.commit()
@@ -359,10 +371,10 @@ def get_glob_cluster_names(cluster_name: str) -> List[str]:

def set_cluster_status(cluster_name: str,
status: status_lib.ClusterStatus) -> None:
_DB.cursor.execute('UPDATE clusters SET status=(?) WHERE name=(?)', (
status.value,
cluster_name,
))
current_time = int(time.time())
_DB.cursor.execute(
'UPDATE clusters SET status=(?), status_updated_at=(?) WHERE name=(?)',
(status.value, current_time, cluster_name))
count = _DB.cursor.rowcount
_DB.conn.commit()
assert count <= 1, count
@@ -570,15 +582,18 @@ def _load_storage_mounts_metadata(

def get_cluster_from_name(
cluster_name: Optional[str]) -> Optional[Dict[str, Any]]:
rows = _DB.cursor.execute('SELECT * FROM clusters WHERE name=(?)',
(cluster_name,)).fetchall()
rows = _DB.cursor.execute(
'SELECT name, launched_at, handle, last_use, status, autostop, '
'metadata, to_down, owner, cluster_hash, storage_mounts_metadata, '
'cluster_ever_up, status_updated_at FROM clusters WHERE name=(?)',
(cluster_name,)).fetchall()
for row in rows:
# Explicitly specify the number of fields to unpack, so that
# we can add new fields to the database in the future without
# breaking the previous code.
(name, launched_at, handle, last_use, status, autostop, metadata,
to_down, owner, cluster_hash, storage_mounts_metadata,
cluster_ever_up) = row[:12]
to_down, owner, cluster_hash, storage_mounts_metadata, cluster_ever_up,
status_updated_at) = row[:13]
# TODO: use namedtuple instead of dict
record = {
'name': name,
@@ -594,19 +609,23 @@ def get_cluster_from_name(
'storage_mounts_metadata':
_load_storage_mounts_metadata(storage_mounts_metadata),
'cluster_ever_up': bool(cluster_ever_up),
'status_updated_at': status_updated_at,
}
return record
return None


def get_clusters() -> List[Dict[str, Any]]:
rows = _DB.cursor.execute(
'select * from clusters order by launched_at desc').fetchall()
'select name, launched_at, handle, last_use, status, autostop, '
'metadata, to_down, owner, cluster_hash, storage_mounts_metadata, '
'cluster_ever_up, status_updated_at from clusters '
'order by launched_at desc').fetchall()
Comment on lines +620 to +623
Collaborator: Any reason to move away from `*`?

cg505 (author), Nov 15, 2024: It is robust against column order. To expand on this: there's no guarantee on the order of columns that `select *` will give us. For instance, if you and I are both developing features that add a column, and we then merge both changes, this will break: each of us will have added our own new column before the other's, so the order of the columns in our global state DBs will differ. I've already hit this bug a few times between this change, #4289, and the stuff we were testing yesterday. See also #4211, same class of issue. (A standalone sketch of this failure mode follows this thread.)

Collaborator: I see. Is backward compatibility still preserved? E.g., if a cluster was launched before this PR and state.db doesn't contain status_updated_at, but on upgrading to this branch this line tries to select status_updated_at, will that work? (I think it should still work because create_table is called at module initialization, but just want to double-check.)

cg505 (author): Yes, I believe that should be fine. In fact, even with `select *`, this method would crash if status_updated_at were missing, because we would not have enough columns to unpack.
records = []
for row in rows:
(name, launched_at, handle, last_use, status, autostop, metadata,
to_down, owner, cluster_hash, storage_mounts_metadata,
cluster_ever_up) = row[:12]
to_down, owner, cluster_hash, storage_mounts_metadata, cluster_ever_up,
status_updated_at) = row[:13]
# TODO: use namedtuple instead of dict
record = {
'name': name,
Expand All @@ -622,6 +641,7 @@ def get_clusters() -> List[Dict[str, Any]]:
'storage_mounts_metadata':
_load_storage_mounts_metadata(storage_mounts_metadata),
'cluster_ever_up': bool(cluster_ever_up),
'status_updated_at': status_updated_at,
}

records.append(record)
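The backward-compatibility question in the thread above is answered by the `db_utils.add_column_to_table` call inside `create_table`: it runs at module initialization, so a pre-upgrade `state.db` gains the `status_updated_at` column before any query selects it. Below is a minimal sketch of such an idempotent helper; this is an assumption about its behavior, and the real implementation lives in `sky/utils/db_utils.py`.

```python
import sqlite3


def add_column_to_table(cursor: sqlite3.Cursor, conn: sqlite3.Connection,
                        table: str, column: str, column_type: str) -> None:
    """Add `column` to `table` only if it does not exist yet."""
    existing = [row[1] for row in
                cursor.execute(f'PRAGMA table_info({table})')]
    if column not in existing:
        cursor.execute(
            f'ALTER TABLE {table} ADD COLUMN {column} {column_type}')
    conn.commit()


conn = sqlite3.connect(':memory:')
cursor = conn.cursor()
cursor.execute('CREATE TABLE clusters (name TEXT)')
# Safe to run on every startup, whether or not the column already exists.
add_column_to_table(cursor, conn, 'clusters', 'status_updated_at',
                    'INTEGER DEFAULT null')
add_column_to_table(cursor, conn, 'clusters', 'status_updated_at',
                    'INTEGER DEFAULT null')  # No-op the second time.
```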
6 changes: 2 additions & 4 deletions sky/utils/timeline.py
@@ -77,11 +77,9 @@ def event(name_or_fn: Union[str, Callable], message: Optional[str] = None):
class FileLockEvent:
"""Serve both as a file lock and event for the lock."""

def __init__(self, lockfile: Union[str, os.PathLike]):
def __init__(self, lockfile: Union[str, os.PathLike], timeout: float = -1):
self._lockfile = lockfile
# TODO(mraheja): remove pylint disabling when filelock version updated
# pylint: disable=abstract-class-instantiated
self._lock = filelock.FileLock(self._lockfile)
self._lock = filelock.FileLock(self._lockfile, timeout)
self._hold_lock_event = Event(f'[FileLock.hold]:{self._lockfile}')

def acquire(self):
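With the new `timeout` parameter, `FileLockEvent` can stand in for `filelock.FileLock` at call sites that need a bounded wait (as in the `_teardown` change above) while still recording the lock as a timeline event. A small usage sketch follows; since the timeout is passed straight through to the underlying `filelock.FileLock`, expiry should raise `filelock.Timeout` just as before (the lock path here is illustrative).

```python
import filelock

from sky.utils import timeline

LOCK_PATH = '/tmp/.sky_example.lock'  # illustrative path

try:
    # Waits up to 20 seconds for the lock, recording the hold duration
    # as a timeline event.
    with timeline.FileLockEvent(LOCK_PATH, timeout=20):
        pass  # ... do the work that must hold the lock ...
except filelock.Timeout:
    print('Could not acquire the lock within 20s.')
```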