Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
suquark committed Oct 20, 2023
1 parent b02931b commit 91c00ad
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 7 deletions.
9 changes: 5 additions & 4 deletions sky/provision/gcp/instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,9 +295,10 @@ def get_cluster_info(
label_filters,
lambda _: ['RUNNING'],
)
all_instances = [
i for instances in handler_to_instances.values() for i in instances
]
instances = {}
for res, insts in handler_to_instances.items():
for inst in insts:
instances[inst] = res.get_instance_info(project_id, zone, inst)

head_instances = _filter_instances(
handlers,
Expand All @@ -315,7 +316,7 @@ def get_cluster_info(
break

return common.ClusterInfo(
instances=all_instances,
instances=instances,
head_instance_id=head_instance_id,
)

Expand Down
39 changes: 36 additions & 3 deletions sky/provision/gcp/instance_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from sky import sky_logging
from sky.adaptors import gcp
from sky.provision import common
from sky.provision.gcp.constants import MAX_POLLS
from sky.provision.gcp.constants import POLL_INTERVAL
from sky.utils import ux_utils
Expand Down Expand Up @@ -269,6 +270,15 @@ def create_node_tag(cls,
wait_for_operation: bool = True) -> str:
raise NotImplementedError

@classmethod
def get_instance_info(
cls,
project_id: str,
availability_zone: str,
instance_id: str,
wait_for_operation: bool = True) -> common.InstanceInfo:
raise NotImplementedError


class GCPComputeInstance(GCPInstance):
"""Instance handler for GCP compute instances."""
Expand Down Expand Up @@ -525,12 +535,11 @@ def set_labels(cls,
node_id: str,
labels: dict,
wait_for_operation: bool = True) -> dict:
response = (cls.load_resource().instances().get(
node = cls.load_resource().instances().get(
project=project_id,
instance=node_id,
zone=availability_zone,
).execute())
node = response.get('items', [])[0]
).execute()
body = {
"labels": dict(node["labels"], **labels),
"labelFingerprint": node["labelFingerprint"],
Expand Down Expand Up @@ -766,6 +775,30 @@ def start_instance(cls,

return result

@classmethod
def get_instance_info(
cls,
project_id: str,
availability_zone: str,
instance_id: str,
wait_for_operation: bool = True) -> common.InstanceInfo:
result = cls.load_resource().instances().get(
project=project_id,
zone=availability_zone,
instance=instance_id,
).execute()
external_ip = (result.get("networkInterfaces",
[{}])[0].get("accessConfigs",
[{}])[0].get("natIP", None))
internal_ip = result.get("networkInterfaces", [{}])[0].get("networkIP")

return common.InstanceInfo(
instance_id=instance_id,
internal_ip=internal_ip,
external_ip=external_ip,
tags=result.get('labels', {}),
)


class GCPTPUVMInstance(GCPInstance):
"""Instance handler for GCP TPU node."""
Expand Down

0 comments on commit 91c00ad

Please sign in to comment.