diff --git a/runhouse/main.py b/runhouse/main.py index f9a65e603..483c01862 100644 --- a/runhouse/main.py +++ b/runhouse/main.py @@ -39,7 +39,6 @@ kill_actors, ) - # create an explicit Typer application app = typer.Typer(add_completion=False) @@ -399,6 +398,26 @@ def _print_envs_info( ) +def _print_cloud_properties(cluster_config: dict): + cloud_properties = cluster_config.get("launched_properties", None) + if not cloud_properties: + return + cloud = cloud_properties.get("cloud") + instance_type = cloud_properties.get("instance_type") + region = cloud_properties.get("region") + cost_per_hour = cloud_properties.get("cost_per_hour") + + has_cuda = cluster_config.get("has_cuda", False) + cost_emoji = "💰" if has_cuda else "💸" + + num_of_instances = len(cluster_config.get("ips")) + num_of_instances_str = f"{num_of_instances}x " if num_of_instances > 1 else "" + + print( + f"🤖 {num_of_instances_str}{cloud} {instance_type} cluster | 🌍 {region} | {cost_emoji} ${cost_per_hour}/hr" + ) + + def _print_status(status_data: dict, current_cluster: Cluster) -> None: """Prints the status of the cluster to the console""" cluster_config = status_data.get("cluster_config") @@ -416,6 +435,7 @@ def _print_status(status_data: dict, current_cluster: Cluster) -> None: console.print(daemon_headline_txt, style="bold royal_blue1") console.print(f"Runhouse v{status_data.get('runhouse_version')}") + _print_cloud_properties(cluster_config) console.print(f"server pid: {status_data.get('server_pid')}") # Print relevant info from cluster config. diff --git a/tests/test_resources/test_clusters/test_cluster.py b/tests/test_resources/test_clusters/test_cluster.py index fd63d2805..81d024d12 100644 --- a/tests/test_resources/test_clusters/test_cluster.py +++ b/tests/test_resources/test_clusters/test_cluster.py @@ -591,7 +591,12 @@ def status_cli_test_logic(self, cluster, status_cli_command: str): assert "node: " in status_output_string assert status_output_string.count("node: ") >= 1 - # if it is a GPU cluster, check GPU print as well + cloud_properties = cluster.config().get("launched_properties", None) + if cloud_properties: + properties_to_check = ["cloud", "instance_type", "region", "cost_per_hour"] + for p in properties_to_check: + property_value = cloud_properties.get(p) + assert property_value in status_output_string @pytest.mark.level("local") @pytest.mark.clustertest