diff --git a/sky/backends/backend_utils.py b/sky/backends/backend_utils.py index e4633ef0671..cde47726295 100644 --- a/sky/backends/backend_utils.py +++ b/sky/backends/backend_utils.py @@ -100,6 +100,10 @@ CLUSTER_STATUS_LOCK_PATH = os.path.expanduser('~/.sky/.{}.lock') CLUSTER_STATUS_LOCK_TIMEOUT_SECONDS = 20 +# Time that must elapse since the last status check before we should re-check if +# the cluster has been terminated or autostopped. +_CLUSTER_STATUS_CACHE_DURATION_SECONDS = 2 + # Filelocks for updating cluster's file_mounts. CLUSTER_FILE_MOUNTS_LOCK_PATH = os.path.expanduser( '~/.sky/.{}_file_mounts.lock') @@ -1668,11 +1672,27 @@ def check_can_clone_disk_and_override_task( def _update_cluster_status_no_lock( cluster_name: str) -> Optional[Dict[str, Any]]: - """Updates the status of the cluster. + """Update the cluster status. + + The cluster status is updated by checking ray cluster and real status from + cloud. + + The function will update the cached cluster status in the global state. For + the design of the cluster status and transition, please refer to the + sky/design_docs/cluster_status.md + + Returns: + If the cluster is terminated or does not exist, return None. Otherwise + returns the input record with status and handle potentially updated. Raises: + exceptions.ClusterOwnerIdentityMismatchError: if the current user is + not the same as the user who created the cluster. + exceptions.CloudUserIdentityError: if we fail to get the current user + identity. exceptions.ClusterStatusFetchingError: the cluster status cannot be - fetched from the cloud provider. + fetched from the cloud provider or there are leaked nodes causing + the node number larger than expected. """ record = global_user_state.get_cluster_from_name(cluster_name) if record is None: @@ -1892,52 +1912,22 @@ def run_ray_status_to_check_ray_cluster_healthy() -> bool: return global_user_state.get_cluster_from_name(cluster_name) -def _update_cluster_status( - cluster_name: str, - acquire_per_cluster_status_lock: bool, - cluster_status_lock_timeout: int = CLUSTER_STATUS_LOCK_TIMEOUT_SECONDS -) -> Optional[Dict[str, Any]]: - """Update the cluster status. +def _must_refresh_cluster_status( + record: Dict[str, Any], + force_refresh_statuses: Optional[Set[status_lib.ClusterStatus]] +) -> bool: + force_refresh_for_cluster = (force_refresh_statuses is not None and + record['status'] in force_refresh_statuses) - The cluster status is updated by checking ray cluster and real status from - cloud. + use_spot = record['handle'].launched_resources.use_spot + has_autostop = (record['status'] != status_lib.ClusterStatus.STOPPED and + record['autostop'] >= 0) + recently_refreshed = (record['status_updated_at'] is not None and + time.time() - record['status_updated_at'] < + _CLUSTER_STATUS_CACHE_DURATION_SECONDS) + is_stale = (use_spot or has_autostop) and not recently_refreshed - The function will update the cached cluster status in the global state. For - the design of the cluster status and transition, please refer to the - sky/design_docs/cluster_status.md - - Args: - cluster_name: The name of the cluster. - acquire_per_cluster_status_lock: Whether to acquire the per-cluster lock - before updating the status. - cluster_status_lock_timeout: The timeout to acquire the per-cluster - lock. - - Returns: - If the cluster is terminated or does not exist, return None. Otherwise - returns the input record with status and handle potentially updated. - - Raises: - exceptions.ClusterOwnerIdentityMismatchError: if the current user is - not the same as the user who created the cluster. - exceptions.CloudUserIdentityError: if we fail to get the current user - identity. - exceptions.ClusterStatusFetchingError: the cluster status cannot be - fetched from the cloud provider or there are leaked nodes causing - the node number larger than expected. - """ - if not acquire_per_cluster_status_lock: - return _update_cluster_status_no_lock(cluster_name) - - try: - with filelock.FileLock(CLUSTER_STATUS_LOCK_PATH.format(cluster_name), - timeout=cluster_status_lock_timeout): - return _update_cluster_status_no_lock(cluster_name) - except filelock.Timeout: - logger.debug('Refreshing status: Failed get the lock for cluster ' - f'{cluster_name!r}. Using the cached status.') - record = global_user_state.get_cluster_from_name(cluster_name) - return record + return force_refresh_for_cluster or is_stale def refresh_cluster_record( @@ -1955,16 +1945,22 @@ def refresh_cluster_record( Args: cluster_name: The name of the cluster. - force_refresh_statuses: if specified, refresh the cluster if it has one of - the specified statuses. Additionally, clusters satisfying the - following conditions will always be refreshed no matter the - argument is specified or not: - 1. is a spot cluster, or - 2. is a non-spot cluster, is not STOPPED, and autostop is set. + force_refresh_statuses: if specified, refresh the cluster if it has one + of the specified statuses. Additionally, clusters satisfying the + following conditions will be refreshed no matter the argument is + specified or not: + - the most latest available status update is more than + _CLUSTER_STATUS_CACHE_DURATION_SECONDS old, and one of: + 1. the cluster is a spot cluster, or + 2. cluster autostop is set and the cluster is not STOPPED. acquire_per_cluster_status_lock: Whether to acquire the per-cluster lock - before updating the status. + before updating the status. Even if this is True, the lock may not be + acquired if the status does not need to be refreshed. cluster_status_lock_timeout: The timeout to acquire the per-cluster - lock. If timeout, the function will use the cached status. + lock. If timeout, the function will use the cached status. If the + value is <0, do not timeout (wait for the lock indefinitely). By + default, this is set to CLUSTER_STATUS_LOCK_TIMEOUT_SECONDS. Warning: + if correctness is required, you must set this to -1. Returns: If the cluster is terminated or does not exist, return None. @@ -1985,19 +1981,58 @@ def refresh_cluster_record( return None check_owner_identity(cluster_name) - handle = record['handle'] - if isinstance(handle, backends.CloudVmRayResourceHandle): - use_spot = handle.launched_resources.use_spot - has_autostop = (record['status'] != status_lib.ClusterStatus.STOPPED and - record['autostop'] >= 0) - force_refresh_for_cluster = (force_refresh_statuses is not None and - record['status'] in force_refresh_statuses) - if force_refresh_for_cluster or has_autostop or use_spot: - record = _update_cluster_status( - cluster_name, - acquire_per_cluster_status_lock=acquire_per_cluster_status_lock, - cluster_status_lock_timeout=cluster_status_lock_timeout) - return record + if not isinstance(record['handle'], backends.CloudVmRayResourceHandle): + return record + + # The loop logic allows us to notice if the status was updated in the + # global_user_state by another process and stop trying to get the lock. + # The core loop logic is adapted from FileLock's implementation. + lock = filelock.FileLock(CLUSTER_STATUS_LOCK_PATH.format(cluster_name)) + start_time = time.perf_counter() + + # Loop until we have an up-to-date status or until we acquire the lock. + while True: + # Check to see if we can return the cached status. + if not _must_refresh_cluster_status(record, force_refresh_statuses): + return record + + if not acquire_per_cluster_status_lock: + return _update_cluster_status_no_lock(cluster_name) + + # Try to acquire the lock so we can fetch the status. + try: + with lock.acquire(blocking=False): + # Lock acquired. + + # Check the cluster status again, since it could have been + # updated between our last check and acquiring the lock. + record = global_user_state.get_cluster_from_name(cluster_name) + if record is None or not _must_refresh_cluster_status( + record, force_refresh_statuses): + return record + + # Update and return the cluster status. + return _update_cluster_status_no_lock(cluster_name) + except filelock.Timeout: + # lock.acquire() will throw a Timeout exception if the lock is not + # available and we have blocking=False. + pass + + # Logic adapted from FileLock.acquire(). + # If cluster_status_lock_time is <0, we will never hit this. No timeout. + # Otherwise, if we have timed out, return the cached status. This has + # the potential to cause correctness issues, but if so it is the + # caller's responsibility to set the timeout to -1. + if 0 <= cluster_status_lock_timeout < time.perf_counter() - start_time: + logger.debug('Refreshing status: Failed get the lock for cluster ' + f'{cluster_name!r}. Using the cached status.') + return record + time.sleep(0.05) + + # Refresh for next loop iteration. + record = global_user_state.get_cluster_from_name(cluster_name) + if record is None: + return None @timeline.event diff --git a/sky/backends/cloud_vm_ray_backend.py b/sky/backends/cloud_vm_ray_backend.py index e338eecb744..b4a1268f174 100644 --- a/sky/backends/cloud_vm_ray_backend.py +++ b/sky/backends/cloud_vm_ray_backend.py @@ -3554,7 +3554,7 @@ def _teardown(self, backend_utils.CLUSTER_STATUS_LOCK_PATH.format(cluster_name)) try: - with filelock.FileLock( + with timeline.FileLockEvent( lock_path, backend_utils.CLUSTER_STATUS_LOCK_TIMEOUT_SECONDS): self.teardown_no_lock( diff --git a/sky/clouds/service_catalog/aws_catalog.py b/sky/clouds/service_catalog/aws_catalog.py index 918a4070414..bbd48863755 100644 --- a/sky/clouds/service_catalog/aws_catalog.py +++ b/sky/clouds/service_catalog/aws_catalog.py @@ -20,6 +20,7 @@ from sky.utils import common_utils from sky.utils import resources_utils from sky.utils import rich_utils +from sky.utils import timeline from sky.utils import ux_utils if typing.TYPE_CHECKING: @@ -100,6 +101,7 @@ def _get_az_mappings(aws_user_hash: str) -> Optional['pd.DataFrame']: return az_mappings +@timeline.event def _fetch_and_apply_az_mapping(df: common.LazyDataFrame) -> 'pd.DataFrame': """Maps zone IDs (use1-az1) to zone names (us-east-1x). diff --git a/sky/global_user_state.py b/sky/global_user_state.py index 7c040ea55fc..e9f15df4f52 100644 --- a/sky/global_user_state.py +++ b/sky/global_user_state.py @@ -60,7 +60,8 @@ def create_table(cursor, conn): owner TEXT DEFAULT null, cluster_hash TEXT DEFAULT null, storage_mounts_metadata BLOB DEFAULT null, - cluster_ever_up INTEGER DEFAULT 0)""") + cluster_ever_up INTEGER DEFAULT 0, + status_updated_at INTEGER DEFAULT null)""") # Table for Cluster History # usage_intervals: List[Tuple[int, int]] @@ -130,6 +131,10 @@ def create_table(cursor, conn): # clusters were never really UP, setting it to 1 means they won't be # auto-deleted during any failover. value_to_replace_existing_entries=1) + + db_utils.add_column_to_table(cursor, conn, 'clusters', 'status_updated_at', + 'INTEGER DEFAULT null') + conn.commit() @@ -159,6 +164,7 @@ def add_or_update_cluster(cluster_name: str, status = status_lib.ClusterStatus.INIT if ready: status = status_lib.ClusterStatus.UP + status_updated_at = int(time.time()) # TODO (sumanth): Cluster history table will have multiple entries # when the cluster failover through multiple regions (one entry per region). @@ -191,7 +197,7 @@ def add_or_update_cluster(cluster_name: str, # specified. '(name, launched_at, handle, last_use, status, ' 'autostop, to_down, metadata, owner, cluster_hash, ' - 'storage_mounts_metadata, cluster_ever_up) ' + 'storage_mounts_metadata, cluster_ever_up, status_updated_at) ' 'VALUES (' # name '?, ' @@ -228,7 +234,9 @@ def add_or_update_cluster(cluster_name: str, 'COALESCE(' '(SELECT storage_mounts_metadata FROM clusters WHERE name=?), null), ' # cluster_ever_up - '((SELECT cluster_ever_up FROM clusters WHERE name=?) OR ?)' + '((SELECT cluster_ever_up FROM clusters WHERE name=?) OR ?),' + # status_updated_at + '?' ')', ( # name @@ -260,6 +268,8 @@ def add_or_update_cluster(cluster_name: str, # cluster_ever_up cluster_name, int(ready), + # status_updated_at + status_updated_at, )) launched_nodes = getattr(cluster_handle, 'launched_nodes', None) @@ -330,11 +340,13 @@ def remove_cluster(cluster_name: str, terminate: bool) -> None: # stopped VM, which leads to timeout. if hasattr(handle, 'stable_internal_external_ips'): handle.stable_internal_external_ips = None + current_time = int(time.time()) _DB.cursor.execute( - 'UPDATE clusters SET handle=(?), status=(?) ' - 'WHERE name=(?)', ( + 'UPDATE clusters SET handle=(?), status=(?), ' + 'status_updated_at=(?) WHERE name=(?)', ( pickle.dumps(handle), status_lib.ClusterStatus.STOPPED.value, + current_time, cluster_name, )) _DB.conn.commit() @@ -359,10 +371,10 @@ def get_glob_cluster_names(cluster_name: str) -> List[str]: def set_cluster_status(cluster_name: str, status: status_lib.ClusterStatus) -> None: - _DB.cursor.execute('UPDATE clusters SET status=(?) WHERE name=(?)', ( - status.value, - cluster_name, - )) + current_time = int(time.time()) + _DB.cursor.execute( + 'UPDATE clusters SET status=(?), status_updated_at=(?) WHERE name=(?)', + (status.value, current_time, cluster_name)) count = _DB.cursor.rowcount _DB.conn.commit() assert count <= 1, count @@ -570,15 +582,18 @@ def _load_storage_mounts_metadata( def get_cluster_from_name( cluster_name: Optional[str]) -> Optional[Dict[str, Any]]: - rows = _DB.cursor.execute('SELECT * FROM clusters WHERE name=(?)', - (cluster_name,)).fetchall() + rows = _DB.cursor.execute( + 'SELECT name, launched_at, handle, last_use, status, autostop, ' + 'metadata, to_down, owner, cluster_hash, storage_mounts_metadata, ' + 'cluster_ever_up, status_updated_at FROM clusters WHERE name=(?)', + (cluster_name,)).fetchall() for row in rows: # Explicitly specify the number of fields to unpack, so that # we can add new fields to the database in the future without # breaking the previous code. (name, launched_at, handle, last_use, status, autostop, metadata, - to_down, owner, cluster_hash, storage_mounts_metadata, - cluster_ever_up) = row[:12] + to_down, owner, cluster_hash, storage_mounts_metadata, cluster_ever_up, + status_updated_at) = row[:13] # TODO: use namedtuple instead of dict record = { 'name': name, @@ -594,6 +609,7 @@ def get_cluster_from_name( 'storage_mounts_metadata': _load_storage_mounts_metadata(storage_mounts_metadata), 'cluster_ever_up': bool(cluster_ever_up), + 'status_updated_at': status_updated_at, } return record return None @@ -601,12 +617,15 @@ def get_cluster_from_name( def get_clusters() -> List[Dict[str, Any]]: rows = _DB.cursor.execute( - 'select * from clusters order by launched_at desc').fetchall() + 'select name, launched_at, handle, last_use, status, autostop, ' + 'metadata, to_down, owner, cluster_hash, storage_mounts_metadata, ' + 'cluster_ever_up, status_updated_at from clusters ' + 'order by launched_at desc').fetchall() records = [] for row in rows: (name, launched_at, handle, last_use, status, autostop, metadata, - to_down, owner, cluster_hash, storage_mounts_metadata, - cluster_ever_up) = row[:12] + to_down, owner, cluster_hash, storage_mounts_metadata, cluster_ever_up, + status_updated_at) = row[:13] # TODO: use namedtuple instead of dict record = { 'name': name, @@ -622,6 +641,7 @@ def get_clusters() -> List[Dict[str, Any]]: 'storage_mounts_metadata': _load_storage_mounts_metadata(storage_mounts_metadata), 'cluster_ever_up': bool(cluster_ever_up), + 'status_updated_at': status_updated_at, } records.append(record) diff --git a/sky/utils/timeline.py b/sky/utils/timeline.py index 29c6c3d94ee..e1e0984f748 100644 --- a/sky/utils/timeline.py +++ b/sky/utils/timeline.py @@ -77,11 +77,9 @@ def event(name_or_fn: Union[str, Callable], message: Optional[str] = None): class FileLockEvent: """Serve both as a file lock and event for the lock.""" - def __init__(self, lockfile: Union[str, os.PathLike]): + def __init__(self, lockfile: Union[str, os.PathLike], timeout: float = -1): self._lockfile = lockfile - # TODO(mraheja): remove pylint disabling when filelock version updated - # pylint: disable=abstract-class-instantiated - self._lock = filelock.FileLock(self._lockfile) + self._lock = filelock.FileLock(self._lockfile, timeout) self._hold_lock_event = Event(f'[FileLock.hold]:{self._lockfile}') def acquire(self):