"""Util constants/functions for the backends."""
from datetime import datetime
import difflib
import enum
import getpass
import json
import os
import pathlib
import re
import subprocess
import tempfile
import textwrap
import time
import typing
from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union
from typing_extensions import Literal
import uuid
import colorama
import filelock
import jinja2
import jsonschema
from packaging import version
import requests
from requests import adapters
from requests.packages.urllib3.util import retry as retry_lib
import rich.progress as rich_progress
import yaml
import sky
from sky import authentication as auth
from sky import backends
from sky import check as sky_check
from sky import clouds
from sky import exceptions
from sky import global_user_state
from sky import skypilot_config
from sky import sky_logging
from sky import spot as spot_lib
from sky.backends import onprem_utils
from sky.skylet import constants
from sky.skylet import log_lib
from sky.skylet.providers.lambda_cloud import lambda_utils
from sky.utils import common_utils
from sky.utils import command_runner
from sky.utils import env_options
from sky.utils import log_utils
from sky.utils import subprocess_utils
from sky.utils import timeline
from sky.utils import tpu_utils
from sky.utils import ux_utils
from sky.utils import validator
from sky.usage import usage_lib

if typing.TYPE_CHECKING:
from sky import resources
from sky import task as task_lib
from sky.backends import cloud_vm_ray_backend
from sky.backends import local_docker_backend
logger = sky_logging.init_logger(__name__)
# NOTE: keep in sync with the cluster template 'file_mounts'.
SKY_REMOTE_APP_DIR = '~/.sky/sky_app'
SKY_RAY_YAML_REMOTE_PATH = '~/.sky/sky_ray.yml'
IP_ADDR_REGEX = r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}'
SKY_REMOTE_PATH = '~/.sky/wheels'
SKY_USER_FILE_PATH = '~/.sky/generated'
BOLD = '\033[1m'
RESET_BOLD = '\033[0m'
# Do not use /tmp because it gets cleared on VM restart.
_SKY_REMOTE_FILE_MOUNTS_DIR = '~/.sky/file_mounts/'
_LAUNCHED_HEAD_PATTERN = re.compile(r'(\d+) ray[._]head[._]default')
_LAUNCHED_LOCAL_WORKER_PATTERN = re.compile(r'(\d+) node_')
_LAUNCHED_WORKER_PATTERN = re.compile(r'(\d+) ray[._]worker[._]default')
# Intentionally not using prefix 'rf' for the string format because yapf has a
# bug with python=3.6.
# Example line this pattern matches: '10.133.0.5: ray.worker.default,'
_LAUNCHING_IP_PATTERN = re.compile(
r'({}): ray[._]worker[._]default'.format(IP_ADDR_REGEX))
WAIT_HEAD_NODE_IP_MAX_ATTEMPTS = 3
# We use a fixed IP address to avoid a DNS lookup blocking the check on
# machines with no internet connection.
# Refer to: https://stackoverflow.com/questions/3764291/how-can-i-see-if-theres-an-available-and-active-network-connection-in-python # pylint: disable=line-too-long
_TEST_IP = 'https://8.8.8.8'
# Allow each CPU thread to take 2 tasks.
# Note: This value cannot be too small; otherwise an OOM issue may occur.
DEFAULT_TASK_CPU_DEMAND = 0.5
# Mapping from reserved cluster names to the corresponding group name (for
# logging purposes).
# NOTE: each group can only have one reserved cluster name for now.
SKY_RESERVED_CLUSTER_NAMES: Dict[str, str] = {
spot_lib.SPOT_CONTROLLER_NAME: 'Managed spot controller'
}
# File locks for cluster status changes.
CLUSTER_STATUS_LOCK_PATH = os.path.expanduser('~/.sky/.{}.lock')
CLUSTER_STATUS_LOCK_TIMEOUT_SECONDS = 20
# Remote dir that holds our runtime files.
_REMOTE_RUNTIME_FILES_DIR = '~/.sky/.runtime_files'
# Include the fields that will be used for generating tags that distinguish
# the cluster in ray, to avoid a stopped cluster being discarded due to
# updates in the yaml template.
# Some notes on the fields:
# - 'provider' fields will be used for bootstrapping and to insert new items
#   into 'node_config'.
# - keeping the auth is not enough because the content of the key file will be
#   used for calculating the hash.
# TODO(zhwu): Keep in sync with the fields used in https://github.com/ray-project/ray/blob/e4ce38d001dbbe09cd21c497fedd03d692b2be3e/python/ray/autoscaler/_private/commands.py#L687-L701
_RAY_YAML_KEYS_TO_RESTORE_FOR_BACK_COMPATIBILITY = {
'cluster_name', 'provider', 'auth', 'node_config'
}
# For these keys, don't use the old yaml's version; instead use the new yaml's.
# - zone: The zone field of the old yaml may be '1a,1b,1c' (AWS) while the
#   actual zone of the launched cluster is '1a'. If we restore it, then on
#   capacity errors it is possible to fail over to 1b, which leaves a leaked
#   instance in 1a. Here, we use the new yaml's zone field, which is
#   guaranteed to be the existing zone '1a'.
_RAY_YAML_KEYS_TO_RESTORE_EXCEPTIONS = [
('provider', 'availability_zone'),
]


def is_ip(s: str) -> bool:
"""Returns whether this string matches IP_ADDR_REGEX."""
return len(re.findall(IP_ADDR_REGEX, s)) == 1
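
# Illustrative examples (added for clarity; the regex is unanchored, so any
# string containing exactly one IPv4-like substring matches):
#   is_ip('127.0.0.1')            -> True
#   is_ip('worker-1')             -> False
#   is_ip('head 10.0.0.1 alive')  -> True
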
def _get_yaml_path_from_cluster_name(cluster_name: str,
prefix: str = SKY_USER_FILE_PATH) -> str:
output_path = pathlib.Path(
prefix).expanduser().resolve() / f'{cluster_name}.yml'
os.makedirs(output_path.parents[0], exist_ok=True)
return str(output_path)


def fill_template(template_name: str,
variables: Dict,
output_path: Optional[str] = None,
output_prefix: str = SKY_USER_FILE_PATH) -> str:
"""Create a file from a Jinja template and return the filename."""
assert template_name.endswith('.j2'), template_name
template_path = os.path.join(sky.__root_dir__, 'templates', template_name)
if not os.path.exists(template_path):
raise FileNotFoundError(f'Template "{template_name}" does not exist.')
with open(template_path) as fin:
template = fin.read()
if output_path is None:
cluster_name = variables.get('cluster_name')
assert isinstance(cluster_name, str), cluster_name
output_path = _get_yaml_path_from_cluster_name(cluster_name,
output_prefix)
output_path = os.path.abspath(output_path)
# Write out yaml config.
j2_template = jinja2.Template(template)
content = j2_template.render(**variables)
with open(output_path, 'w') as fout:
fout.write(content)
return output_path
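
# A minimal usage sketch (hypothetical template name and variables; real
# callers pass the full set of variables expected by the chosen cluster
# template):
#   yaml_path = fill_template('aws-ray.yml.j2',
#                             {'cluster_name': 'my-cluster', 'num_nodes': 1})
#   # -> absolute path such as ~/.sky/generated/my-cluster.yml
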
def _optimize_file_mounts(yaml_path: str) -> None:
"""Optimize file mounts in the given ray yaml file.
Runtime files handling:
List of runtime files to be uploaded to cluster:
- yaml config (for autostopping)
- wheel
- credentials
Format is {dst: src}.
"""
yaml_config = common_utils.read_yaml(yaml_path)
file_mounts = yaml_config.get('file_mounts', {})
# Remove the empty file mount entry added by a trailing newline in the
# template.
if '' in file_mounts:
assert file_mounts[''] == '', file_mounts['']
file_mounts.pop('')
# Putting these in file_mounts hurts provisioning speed, as each file
# opens/closes an SSH connection. Instead, we:
# - locally cp them into a directory
# - upload that directory as a single file mount (1 connection)
# - use a remote command to copy all runtime files to their right places.
# Local tmp dir holding runtime files.
local_runtime_files_dir = tempfile.mkdtemp()
new_file_mounts = {_REMOTE_RUNTIME_FILES_DIR: local_runtime_files_dir}
# (For remote) Build a command that copies runtime files to their right
# destinations.
# NOTE: we copy rather than move, because when launching >1 node, the head
# node is fully set up first, and if we moved, the head node's files would
# already have moved out of _REMOTE_RUNTIME_FILES_DIR, causing the setup of
# workers (from the head's files) to fail. An alternative is symlinking
# (but then we need to make sure all usage of the runtime files follows
# links).
commands = []
basenames = set()
for dst, src in file_mounts.items():
src_basename = os.path.basename(src)
dst_basename = os.path.basename(dst)
dst_parent_dir = os.path.dirname(dst)
# Validate with asserts here, as these files are added by our backend.
# Our runtime files (wheel, yaml, credentials) do not end with a slash.
assert not src.endswith('/'), src
assert not dst.endswith('/'), dst
assert src_basename not in basenames, (
f'Duplicated src basename: {src_basename}; mounts: {file_mounts}')
basenames.add(src_basename)
# Our runtime files (wheel, yaml, credentials) are not relative paths.
assert dst_parent_dir, f'Found relative destination path: {dst}'
mkdir_parent = f'mkdir -p {dst_parent_dir}'
if os.path.isdir(os.path.expanduser(src)):
# Special case for directories: if dst already exists as a
# folder, directly copying the folder would create a subfolder
# under dst.
mkdir_parent = f'mkdir -p {dst}'
src_basename = f'{src_basename}/*'
mv = (f'cp -r {_REMOTE_RUNTIME_FILES_DIR}/{src_basename} '
f'{dst_parent_dir}/{dst_basename}')
fragment = f'({mkdir_parent} && {mv})'
commands.append(fragment)
postprocess_runtime_files_command = ' && '.join(commands)
setup_commands = yaml_config.get('setup_commands', [])
if setup_commands:
setup_commands[
0] = f'{postprocess_runtime_files_command}; {setup_commands[0]}'
else:
setup_commands = [postprocess_runtime_files_command]
yaml_config['file_mounts'] = new_file_mounts
yaml_config['setup_commands'] = setup_commands
# (For local) Move all runtime files, including the just-written yaml, to
# local_runtime_files_dir/.
all_local_sources = ''
for local_src in file_mounts.values():
full_local_src = str(pathlib.Path(local_src).expanduser())
# Add quotes for paths containing spaces.
all_local_sources += f'{full_local_src!r} '
# Takes 10-20 ms on a laptop, including 3 clouds' credentials.
subprocess.run(f'cp -r {all_local_sources} {local_runtime_files_dir}/',
shell=True,
check=True)
common_utils.dump_yaml(yaml_path, yaml_config)
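
# Sketch of the rewrite this function performs (hypothetical mounts,
# abridged):
#   before: file_mounts: {'~/.sky/sky_ray.yml': '/tmp/abc/sky_ray.yml', ...}
#   after:  file_mounts: {'~/.sky/.runtime_files': '<mkdtemp dir>'}
#           setup_commands[0] is prefixed with commands like
#           '(mkdir -p ~/.sky && cp -r ~/.sky/.runtime_files/sky_ray.yml
#             ~/.sky/sky_ray.yml) && (...); <original first setup command>'
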
def path_size_megabytes(path: str) -> int:
"""Returns the size of 'path' (directory or file) in megabytes."""
resolved_path = pathlib.Path(path).expanduser().resolve()
git_exclude_filter = ''
if (resolved_path / command_runner.GIT_EXCLUDE).exists():
# Ensure file exists; otherwise, rsync will error out.
git_exclude_filter = command_runner.RSYNC_EXCLUDE_OPTION.format(
str(resolved_path / command_runner.GIT_EXCLUDE))
rsync_output = str(
subprocess.check_output(
f'rsync {command_runner.RSYNC_DISPLAY_OPTION} '
f'{command_runner.RSYNC_FILTER_OPTION} '
f'{git_exclude_filter} --dry-run {path!r}',
shell=True).splitlines()[-1])
total_bytes = rsync_output.split(' ')[3].replace(',', '')
return int(total_bytes) // 10**6
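
# Parsing sketch (assuming rsync's usual dry-run summary format; the exact
# flags come from command_runner): the last line of the dry-run output looks
# like
#   total size is 1,234,567  speedup is 354.54
# so split(' ')[3] extracts '1,234,567'; stripping commas gives the byte
# count, which is then floored to (decimal) megabytes via // 10**6.
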
class FileMountHelper(object):
"""Helper for handling file mounts."""
@classmethod
def wrap_file_mount(cls, path: str) -> str:
"""Prepends ~/<opaque dir>/ to a path to work around permission issues.
Examples:
/root/hello.txt -> ~/<opaque dir>/root/hello.txt
local.txt -> ~/<opaque dir>/local.txt
After the path is synced, we can later create a symlink to this wrapped
path from the original path, e.g., in the initialization_commands of the
ray autoscaler YAML.
"""
return os.path.join(_SKY_REMOTE_FILE_MOUNTS_DIR, path.lstrip('/'))
@classmethod
def make_safe_symlink_command(cls, *, source: str, target: str) -> str:
"""Returns a command that safely symlinks 'source' to 'target'.
All intermediate directories of 'source' will be owned by $USER,
excluding the root directory (/).
'source' must be an absolute path; both 'source' and 'target' must not
end with a slash (/).
This function is needed because a simple 'ln -s target source' may
fail: 'source' can have multiple levels (/a/b/c), its parent dirs may
or may not exist, can end with a slash, or may need sudo access, etc.
Cases of <target: local> file mounts and their behaviors:
/existing_dir: ~/local/dir
- error out saying this cannot be done as LHS already exists
/existing_file: ~/local/file
- error out saying this cannot be done as LHS already exists
/existing_symlink: ~/local/file
- overwrite the existing symlink; this is important because `sky
launch` can be run multiple times
Paths that start with ~/ and /tmp/ do not have the above
restrictions; they are delegated to rsync behaviors.
"""
assert os.path.isabs(source), source
assert not source.endswith('/') and not target.endswith('/'), (source,
target)
# Below, use sudo in case the symlink needs sudo access to create.
# Prepare to create the symlink:
# 1. make sure its dir(s) exist & are owned by $USER.
dir_of_symlink = os.path.dirname(source)
commands = [
# mkdir, then loop over '/a/b/c' as /a, /a/b, /a/b/c. For each,
# chown $USER on it so user can use these intermediate dirs
# (excluding /).
f'sudo mkdir -p {dir_of_symlink}',
# p: path so far
('(p=""; '
f'for w in $(echo {dir_of_symlink} | tr "/" " "); do '
'p=${p}/${w}; sudo chown $USER $p; done)')
]
# 2. remove any existing symlink (ln -f may throw 'cannot
# overwrite directory', if the link exists and points to a
# directory).
commands += [
# Error out if source is an existing, non-symlink directory/file.
f'((test -L {source} && sudo rm {source} &>/dev/null) || '
f'(test ! -e {source} || '
f'(echo "!!! Failed mounting because path exists ({source})"; '
'exit 1)))',
]
commands += [
# Link.
f'sudo ln -s {target} {source}',
# chown. -h to affect symlinks only.
f'sudo chown -h $USER {source}',
]
return ' && '.join(commands)
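
# Illustrative output of make_safe_symlink_command (abridged; the real return
# value is a single '&&'-joined string) for
# source='/data/dir', target='~/.sky/file_mounts/data/dir':
#   sudo mkdir -p /data
#   && (p=""; for w in $(echo /data | tr "/" " "); do
#       p=${p}/${w}; sudo chown $USER $p; done)
#   && ((test -L /data/dir && sudo rm /data/dir &>/dev/null) || ...)
#   && sudo ln -s ~/.sky/file_mounts/data/dir /data/dir
#   && sudo chown -h $USER /data/dir
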
class SSHConfigHelper(object):
"""Helper for handling local SSH configuration."""
ssh_conf_path = '~/.ssh/config'
ssh_conf_lock_path = os.path.expanduser('~/.sky/ssh_config.lock')
ssh_multinode_path = SKY_USER_FILE_PATH + '/ssh/{}'
@classmethod
def _get_generated_config(cls, autogen_comment: str, host_name: str,
ip: str, username: str, ssh_key_path: str,
proxy_command: Optional[str]):
if proxy_command is not None:
proxy = f'ProxyCommand {proxy_command}'
else:
proxy = ''
# StrictHostKeyChecking=no skips the host key check on the first
# connection. UserKnownHostsFile=/dev/null and
# GlobalKnownHostsFile=/dev/null prevent the host key from being added
# to the known_hosts file and always return an empty file for known
# hosts, making ssh treat every connection as a first-time connection
# and thus skip the host key check.
codegen = textwrap.dedent(f"""\
{autogen_comment}
Host {host_name}
HostName {ip}
User {username}
IdentityFile {ssh_key_path}
IdentitiesOnly yes
ForwardAgent yes
StrictHostKeyChecking no
UserKnownHostsFile=/dev/null
GlobalKnownHostsFile=/dev/null
Port 22
{proxy}
""".rstrip())
codegen = codegen + '\n'
return codegen
@classmethod
@timeline.FileLockEvent(ssh_conf_lock_path)
def add_cluster(
cls,
cluster_name: str,
ips: List[str],
auth_config: Dict[str, str],
):
"""Add authentication information for cluster to local SSH config file.
If a host with `cluster_name` already exists and the configuration was
not added by sky, then `ip` is used to identify the host instead in the
file.
If a host with `cluster_name` already exists and the configuration was
added by sky (e.g. a spot instance), then the configuration is
overwritten.
Args:
cluster_name: Cluster name (see `sky status`)
ips: List of public IP addresses in the cluster. First IP is head
node.
auth_config: read_yaml(handle.cluster_yaml)['auth']
"""
username = auth_config['ssh_user']
key_path = os.path.expanduser(auth_config['ssh_private_key'])
host_name = cluster_name
sky_autogen_comment = ('# Added by sky (use `sky stop/down '
f'{cluster_name}` to remove)')
overwrite = False
overwrite_begin_idx = None
ip = ips[0]
config_path = os.path.expanduser(cls.ssh_conf_path)
if os.path.exists(config_path):
with open(config_path) as f:
config = f.readlines()
# If an existing config with `cluster_name` exists, log a warning.
for i, line in enumerate(config):
if line.strip() == f'Host {cluster_name}':
prev_line = config[i - 1] if i - 1 >= 0 else ''
if prev_line.strip().startswith(sky_autogen_comment):
overwrite = True
overwrite_begin_idx = i - 1
else:
logger.warning(f'{cls.ssh_conf_path} contains '
f'host named {cluster_name}.')
host_name = ip
logger.warning(f'Using {ip} to identify host instead.')
if line.strip() == f'Host {ip}':
prev_line = config[i - 1] if i - 1 >= 0 else ''
if prev_line.strip().startswith(sky_autogen_comment):
overwrite = True
overwrite_begin_idx = i - 1
else:
config = ['\n']
with open(config_path, 'w') as f:
f.writelines(config)
os.chmod(config_path, 0o644)
proxy_command = auth_config.get('ssh_proxy_command', None)
codegen = cls._get_generated_config(sky_autogen_comment, host_name, ip,
username, key_path, proxy_command)
# Add (or overwrite) the new config.
if overwrite:
assert overwrite_begin_idx is not None
updated_lines = codegen.splitlines(keepends=True) + ['\n']
config[overwrite_begin_idx:overwrite_begin_idx +
len(updated_lines)] = updated_lines
with open(config_path, 'w') as f:
f.write(''.join(config).strip())
f.write('\n' * 2)
else:
with open(config_path, 'a') as f:
if len(config) > 0 and config[-1] != '\n':
f.write('\n')
f.write(codegen)
f.write('\n')
with open(config_path, 'r+') as f:
config = f.readlines()
if config[-1] != '\n':
f.write('\n')
if len(ips) > 1:
SSHConfigHelper._add_multinode_config(cluster_name, ips[1:],
auth_config)
@classmethod
def _add_multinode_config(
cls,
cluster_name: str,
external_worker_ips: List[str],
auth_config: Dict[str, str],
):
username = auth_config['ssh_user']
key_path = os.path.expanduser(auth_config['ssh_private_key'])
host_name = cluster_name
sky_autogen_comment = ('# Added by sky (use `sky stop/down '
f'{cluster_name}` to remove)')
# Ensure the worker-<i> aliases are stable by sorting based on
# public IPs.
external_worker_ips = list(sorted(external_worker_ips))
overwrites = [False] * len(external_worker_ips)
overwrite_begin_idxs: List[Optional[int]] = [None
] * len(external_worker_ips)
codegens: List[Optional[str]] = [None] * len(external_worker_ips)
worker_names = []
extra_path_name = cls.ssh_multinode_path.format(cluster_name)
for idx in range(len(external_worker_ips)):
worker_names.append(cluster_name + f'-worker{idx+1}')
config_path = os.path.expanduser(cls.ssh_conf_path)
with open(config_path) as f:
config = f.readlines()
extra_config_path = os.path.expanduser(extra_path_name)
os.makedirs(os.path.dirname(extra_config_path), exist_ok=True)
if not os.path.exists(extra_config_path):
extra_config = ['\n']
with open(extra_config_path, 'w') as f:
f.writelines(extra_config)
else:
with open(extra_config_path) as f:
extra_config = f.readlines()
# Handle the Include directive at the top of the config file.
include_str = f'Include {extra_config_path}'
for i, line in enumerate(config):
config_str = line.strip()
if config_str == include_str:
break
# Did not find the Include string yet; insert it before the first
# Host entry.
if 'Host' in config_str:
with open(config_path, 'w') as f:
config.insert(0, '\n')
config.insert(0, include_str + '\n')
config.insert(0, sky_autogen_comment + '\n')
f.write(''.join(config).strip())
f.write('\n' * 2)
break
with open(config_path) as f:
config = f.readlines()
proxy_command = auth_config.get('ssh_proxy_command', None)
# Check if ~/.ssh/config contains existing names
host_lines = [f'Host {c_name}' for c_name in worker_names]
for i, line in enumerate(config):
if line.strip() in host_lines:
idx = host_lines.index(line.strip())
prev_line = config[i - 1] if i > 0 else ''
logger.warning(f'{cls.ssh_conf_path} contains '
f'host named {worker_names[idx]}.')
host_name = external_worker_ips[idx]
logger.warning(f'Using {host_name} to identify host instead.')
codegens[idx] = cls._get_generated_config(
sky_autogen_comment, host_name, external_worker_ips[idx],
username, key_path, proxy_command)
# All workers go to SKY_USER_FILE_PATH/ssh/{cluster_name}
for i, line in enumerate(extra_config):
if line.strip() in host_lines:
idx = host_lines.index(line.strip())
prev_line = extra_config[i - 1] if i > 0 else ''
if prev_line.strip().startswith(sky_autogen_comment):
host_name = worker_names[idx]
overwrites[idx] = True
overwrite_begin_idxs[idx] = i - 1
codegens[idx] = cls._get_generated_config(
sky_autogen_comment, host_name, external_worker_ips[idx],
username, key_path, proxy_command)
# Generate configs for any workers that were not matched above.
for idx, ip in enumerate(external_worker_ips):
if not codegens[idx]:
codegens[idx] = cls._get_generated_config(
sky_autogen_comment, worker_names[idx], ip, username,
key_path, proxy_command)
for idx in range(len(external_worker_ips)):
# Add (or overwrite) the new config.
overwrite = overwrites[idx]
overwrite_begin_idx = overwrite_begin_idxs[idx]
codegen = codegens[idx]
assert codegen is not None, (codegens, idx)
if overwrite:
assert overwrite_begin_idx is not None
updated_lines = codegen.splitlines(keepends=True) + ['\n']
extra_config[overwrite_begin_idx:overwrite_begin_idx +
len(updated_lines)] = updated_lines
with open(extra_config_path, 'w') as f:
f.write(''.join(extra_config).strip())
f.write('\n' * 2)
else:
with open(extra_config_path, 'a') as f:
f.write(codegen)
f.write('\n')
# Add a trailing newline at the end of the file if it doesn't exist.
with open(extra_config_path, 'r+') as f:
extra_config = f.readlines()
if extra_config[-1] != '\n':
f.write('\n')
@classmethod
@timeline.FileLockEvent(ssh_conf_lock_path)
def remove_cluster(
cls,
cluster_name: str,
ip: str,
auth_config: Dict[str, str],
):
"""Remove authentication information for cluster from local SSH config.
If no existing host matching the provided specification is found, then
nothing is removed.
Args:
ip: Head node's IP address.
auth_config: read_yaml(handle.cluster_yaml)['auth']
"""
username = auth_config['ssh_user']
config_path = os.path.expanduser(cls.ssh_conf_path)
if not os.path.exists(config_path):
return
with open(config_path) as f:
config = f.readlines()
start_line_idx = None
# Scan the config for the cluster name.
for i, line in enumerate(config):
next_line = config[i + 1] if i + 1 < len(config) else ''
if (line.strip() == f'HostName {ip}' and
next_line.strip() == f'User {username}'):
start_line_idx = i - 1
break
if start_line_idx is None: # No config to remove.
return
# Scan for end of previous config.
cursor = start_line_idx
while cursor > 0 and len(config[cursor].strip()) > 0:
cursor -= 1
prev_end_line_idx = cursor
# Scan for end of the cluster config.
end_line_idx = None
cursor = start_line_idx + 1
start_line_idx -= 1 # remove auto-generated comment
while cursor < len(config):
if config[cursor].strip().startswith(
'# ') or config[cursor].strip().startswith('Host '):
end_line_idx = cursor
break
cursor += 1
# Remove sky-generated config and update the file.
config[prev_end_line_idx:end_line_idx] = [
'\n'
] if end_line_idx is not None else []
with open(config_path, 'w') as f:
f.write(''.join(config).strip())
f.write('\n' * 2)
SSHConfigHelper._remove_multinode_config(cluster_name)
@classmethod
def _remove_multinode_config(
cls,
cluster_name: str,
):
config_path = os.path.expanduser(cls.ssh_conf_path)
if not os.path.exists(config_path):
return
extra_path_name = cls.ssh_multinode_path.format(cluster_name)
extra_config_path = os.path.expanduser(extra_path_name)
common_utils.remove_file_if_exists(extra_config_path)
# Delete include statement
sky_autogen_comment = ('# Added by sky (use `sky stop/down '
f'{cluster_name}` to remove)')
with open(config_path) as f:
config = f.readlines()
for i, line in enumerate(config):
config_str = line.strip()
if f'Include {extra_config_path}' in config_str:
with open(config_path, 'w') as f:
if i < len(config) - 1 and config[i + 1] == '\n':
del config[i + 1]
# Delete Include string
del config[i]
# Delete Sky Autogen Comment
if i > 0 and sky_autogen_comment in config[i - 1].strip():
del config[i - 1]
f.write(''.join(config))
break
if 'Host' in config_str:
break


def _replace_yaml_dicts(
new_yaml: str, old_yaml: str, restore_key_names: Set[str],
restore_key_names_exceptions: Sequence[Sequence[str]]) -> str:
"""Replaces 'new' with 'old' for all keys in restore_key_names.
The replacement will be applied recursively and only for the blocks
with the key in key_names, and have the same ancestors in both 'new'
and 'old' YAML tree.
The restore_key_names_exceptions is a list of key names that should not
be restored, i.e. those keys will be reset to the value in 'new' YAML
tree after the replacement.
"""
def _restore_block(new_block: Dict[str, Any], old_block: Dict[str, Any]):
for key, value in new_block.items():
if key in restore_key_names:
if key in old_block:
new_block[key] = old_block[key]
else:
del new_block[key]
elif isinstance(value, dict):
if key in old_block:
_restore_block(value, old_block[key])
new_config = yaml.safe_load(new_yaml)
old_config = yaml.safe_load(old_yaml)
excluded_results = {}
# Find all key values excluded from restore
for exclude_restore_key_name_list in restore_key_names_exceptions:
excluded_result = new_config
found_excluded_key = True
for key in exclude_restore_key_name_list:
if (not isinstance(excluded_result, dict) or
key not in excluded_result):
found_excluded_key = False
break
excluded_result = excluded_result[key]
if found_excluded_key:
excluded_results[exclude_restore_key_name_list] = excluded_result
# Restore from old config
_restore_block(new_config, old_config)
# Revert the changes for the excluded key values
for exclude_restore_key_name, value in excluded_results.items():
curr = new_config
for key in exclude_restore_key_name[:-1]:
curr = curr[key]
curr[exclude_restore_key_name[-1]] = value
return common_utils.dump_yaml_str(new_config)
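
# A small illustration (hypothetical configs), with
# restore_key_names={'provider'} and
# restore_key_names_exceptions=[('provider', 'availability_zone')]:
#   new: {provider: {availability_zone: 1b, module: aws}}
#   old: {provider: {availability_zone: 1a, module: aws_v2}}
# The whole 'provider' block is first restored from 'old'; the excluded
# 'availability_zone' is then reverted to the 'new' value:
#   result: {provider: {availability_zone: 1b, module: aws_v2}}
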
# TODO: too many things happening here - leaky abstraction. Refactor.
@timeline.event
def write_cluster_config(
to_provision: 'resources.Resources',
num_nodes: int,
cluster_config_template: str,
cluster_name: str,
local_wheel_path: pathlib.Path,
wheel_hash: str,
region: Optional[clouds.Region] = None,
zones: Optional[List[clouds.Zone]] = None,
dryrun: bool = False,
keep_launch_fields_in_existing_config: bool = True) -> Dict[str, str]:
"""Fills in cluster configuration templates and writes them out.
Returns: {provisioner: path to yaml, the provisioning spec}.
'provisioner' can be
- 'ray'
- 'tpu-create-script' (if TPU is requested)
- 'tpu-delete-script' (if TPU is requested)
Raises:
exceptions.ResourcesUnavailableError: if the region/zones requested does not appear
in the catalog, or an ssh_proxy_command is specified but not for the given region.
"""
# task.best_resources may not be equal to to_provision if the user
# is running a job with less resources than the cluster has.
cloud = to_provision.cloud
# This can raise a ResourcesUnavailableError when the requested region/zones
# do not appear in the catalog. It can be triggered when the user has changed
# the catalog file while a cluster exists in the removed region/zone.
# TODO(zhwu): We should change the exception type to a more specific one,
# as the ResourcesUnavailableError is overly used. Also, it would be better
# to move the check out of this function, i.e. the caller should be
# responsible for the validation.
resources_vars = cloud.make_deploy_resources_variables(
to_provision, region, zones)
config_dict = {}
azure_subscription_id = None
if isinstance(cloud, clouds.Azure):
azure_subscription_id = cloud.get_project_id(dryrun=dryrun)
gcp_project_id = None
if isinstance(cloud, clouds.GCP):
gcp_project_id = cloud.get_project_id(dryrun=dryrun)
assert cluster_name is not None
credentials = sky_check.get_cloud_credential_file_mounts()
ip_list = None
auth_config = {'ssh_private_key': auth.PRIVATE_SSH_KEY_PATH}
if isinstance(cloud, clouds.Local):
ip_list = onprem_utils.get_local_ips(cluster_name)
auth_config = onprem_utils.get_local_auth_config(cluster_name)
region_name = resources_vars.get('region')
yaml_path = _get_yaml_path_from_cluster_name(cluster_name)
# Retrieve the ssh_proxy_command for the given cloud / region.
ssh_proxy_command_config = skypilot_config.get_nested(
(str(cloud).lower(), 'ssh_proxy_command'), None)
if (isinstance(ssh_proxy_command_config, str) or
ssh_proxy_command_config is None):
ssh_proxy_command = ssh_proxy_command_config
else:
# ssh_proxy_command_config: Dict[str, str], region_name -> command
# This type check is done by skypilot_config at config load time.
if region_name not in ssh_proxy_command_config:
# Skip this region. The upper layer will handle the failover to
# other regions.
raise exceptions.ResourcesUnavailableError(
f'No ssh_proxy_command provided for region {region_name}. Skipped.'
)
ssh_proxy_command = ssh_proxy_command_config[region_name]
logger.debug(f'Using ssh_proxy_command: {ssh_proxy_command!r}')
# Use a tmp file path to avoid incomplete YAML file being re-used in the
# future.
tmp_yaml_path = yaml_path + '.tmp'
tmp_yaml_path = fill_template(
cluster_config_template,
dict(
resources_vars,
**{
'cluster_name': cluster_name,
'num_nodes': num_nodes,
'disk_size': to_provision.disk_size,
# If the current code is run by controller, propagate the real
# calling user which should've been passed in as the
# SKYPILOT_USER env var (see spot-controller.yaml.j2).
'user': os.environ.get('SKYPILOT_USER', getpass.getuser()),
# AWS only:
# Temporary measure, as deleting per-cluster SGs is too slow.
# See https://github.com/skypilot-org/skypilot/pull/742.
# Generate the name of the security group we're looking for.
# (username, last 4 chars of hash of hostname): for uniquifying
# users in shared-account scenarios.
'security_group': skypilot_config.get_nested(
('aws', 'security_group_name'),
f'sky-sg-{common_utils.user_and_hostname_hash()}'),
'vpc_name': skypilot_config.get_nested(('aws', 'vpc_name'),
None),
'use_internal_ips': skypilot_config.get_nested(
('aws', 'use_internal_ips'), False),
# Not exactly AWS only, but we have only tested that it is
# supported on AWS for now:
'ssh_proxy_command': ssh_proxy_command,
# Azure only:
'azure_subscription_id': azure_subscription_id,
'resource_group': f'{cluster_name}-{region_name}',
# GCP only:
'gcp_project_id': gcp_project_id,
# Ray version.
'ray_version': constants.SKY_REMOTE_RAY_VERSION,
# Cloud credentials for cloud storage.
'credentials': credentials,
# Sky remote utils.
'sky_remote_path': SKY_REMOTE_PATH,
'sky_local_path': str(local_wheel_path),
# Add yaml file path to the template variables.
'sky_ray_yaml_remote_path': SKY_RAY_YAML_REMOTE_PATH,
'sky_ray_yaml_local_path':
tmp_yaml_path
if not isinstance(cloud, clouds.Local) else yaml_path,
'sky_version': str(version.parse(sky.__version__)),
'sky_wheel_hash': wheel_hash,
# Local IP handling (optional).
'head_ip': None if ip_list is None else ip_list[0],
'worker_ips': None if ip_list is None else ip_list[1:],
# Authentication (optional).
**auth_config,
}),
output_path=tmp_yaml_path)
config_dict['cluster_name'] = cluster_name
config_dict['ray'] = yaml_path
if dryrun:
# If dryrun, return the unfinished tmp yaml path.
config_dict['ray'] = tmp_yaml_path
return config_dict
_add_auth_to_cluster_config(cloud, tmp_yaml_path)
# Restore the old yaml content for backward compatibility.
if os.path.exists(yaml_path) and keep_launch_fields_in_existing_config:
with open(yaml_path, 'r') as f:
old_yaml_content = f.read()
with open(tmp_yaml_path, 'r') as f:
new_yaml_content = f.read()
restored_yaml_content = _replace_yaml_dicts(
new_yaml_content, old_yaml_content,
_RAY_YAML_KEYS_TO_RESTORE_FOR_BACK_COMPATIBILITY,
_RAY_YAML_KEYS_TO_RESTORE_EXCEPTIONS)
with open(tmp_yaml_path, 'w') as f:
f.write(restored_yaml_content)
# Optimization: copy the contents of source files in file_mounts to a
# special dir, and upload that as the only file_mount instead. Delay
# calling this optimization until now, when all source files have been
# written and their contents finalized.
#
# Note that the ray yaml file will be copied into that special dir (i.e.,
# uploaded as part of the file_mounts), so the restore for backward
# compatibility should go before this call.
if not isinstance(cloud, clouds.Local):
# Only optimize the file mounts for public clouds now, as local has not
# been fully tested yet.
_optimize_file_mounts(tmp_yaml_path)
# Rename the tmp file to the final YAML path.
os.rename(tmp_yaml_path, yaml_path)
usage_lib.messages.usage.update_ray_yaml(yaml_path)
# For TPU nodes. TPU VMs do not need TPU_NAME.
if (resources_vars.get('tpu_type') is not None and
resources_vars.get('tpu_vm') is None):
tpu_name = resources_vars.get('tpu_name')
if tpu_name is None:
tpu_name = cluster_name
user_file_dir = os.path.expanduser(f'{SKY_USER_FILE_PATH}/')
from sky.skylet.providers.gcp import config as gcp_config # pylint: disable=import-outside-toplevel
config = common_utils.read_yaml(os.path.expanduser(config_dict['ray']))
vpc_name = gcp_config.get_usable_vpc(config)
scripts = tuple(
fill_template(
template_name,
dict(
resources_vars, **{
'tpu_name': tpu_name,
'gcp_project_id': gcp_project_id,
'vpc_name': vpc_name,
}),
# Use new names for TPU scripts so that different runs can use
# different TPUs. Put in SKY_USER_FILE_PATH to be consistent
# with cluster yamls.
output_path=os.path.join(user_file_dir, template_name).replace(
'.sh.j2', f'.{cluster_name}.sh'),
) for template_name in
['gcp-tpu-create.sh.j2', 'gcp-tpu-delete.sh.j2'])
config_dict['tpu-create-script'] = scripts[0]
config_dict['tpu-delete-script'] = scripts[1]
config_dict['tpu_name'] = tpu_name
return config_dict
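
# Shape of the returned config_dict (illustrative):
#   {'cluster_name': <name>,
#    'ray': <path to the generated cluster yaml>,
#    # plus, only when a TPU node (not a TPU VM) is requested:
#    'tpu-create-script': <path>, 'tpu-delete-script': <path>,
#    'tpu_name': <name>}
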
def _add_auth_to_cluster_config(cloud: clouds.Cloud, cluster_config_file: str):
"""Adds SSH key info to the cluster config.
This function's output removes comments included in the jinja2 template.
"""
config = common_utils.read_yaml(cluster_config_file)
# Check the availability of the cloud type.
if isinstance(cloud, clouds.AWS):
config = auth.setup_aws_authentication(config)
elif isinstance(cloud, clouds.GCP):
config = auth.setup_gcp_authentication(config)
elif isinstance(cloud, clouds.Azure):
config = auth.setup_azure_authentication(config)
elif isinstance(cloud, clouds.Lambda):
config = auth.setup_lambda_authentication(config)
else:
assert isinstance(cloud, clouds.Local), cloud
# Local cluster case, authentication is already filled by the user
# in the local cluster config (in ~/.sky/local/...). There is no need
# for Sky to generate authentication.
pass
common_utils.dump_yaml(cluster_config_file, config)


def get_run_timestamp() -> str:
return 'sky-' + datetime.now().strftime('%Y-%m-%d-%H-%M-%S-%f')
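
# Example output (format fixed by the strftime pattern above):
#   'sky-2023-04-01-12-30-45-123456'
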
def get_timestamp_from_run_timestamp(run_timestamp: str) -> float:
return datetime.strptime(