Skip to content

Commit

Permalink
Merge pull request #2880 from harshthakkar01/slurm-maintenance
Browse files Browse the repository at this point in the history
Add reservation support in slurm sync for scheduled maintenance
  • Loading branch information
harshthakkar01 authored Aug 15, 2024
2 parents fd6cfb7 + 356b488 commit cdb3ab5
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
from itertools import chain
from pathlib import Path
import yaml
from datetime import datetime
from typing import List, Dict, Tuple

import util
from util import (
Expand Down Expand Up @@ -439,7 +441,7 @@ def reconfigure_slurm():
save_config(cfg_new, CONFIG_FILE)
cfg_new = load_config_file(CONFIG_FILE)
util._lkp = Lookup(cfg_new)

if lookup().is_controller:
conf.gen_controller_configs(lookup())
log.info("Restarting slurmctld to make changes take effect.")
Expand All @@ -466,6 +468,82 @@ def update_topology(lkp: util.Lookup) -> None:
log.debug("Topology configuration updated. Reconfiguring Slurm.")
util.scontrol_reconfigure(lkp)


def delete_reservation(lkp: util.Lookup, reservation_name: str) -> None:
util.run(f"{lkp.scontrol} delete reservation {reservation_name}")


def create_reservation(lkp: util.Lookup, reservation_name: str, node: str, start_time: datetime) -> None:
# Format time to be compatible with slurm reservation.
formatted_start_time = start_time.strftime('%Y-%m-%dT%H:%M:%S')
util.run(f"{lkp.scontrol} create reservation user=slurm starttime={formatted_start_time} duration=180 nodes={node} reservationname={reservation_name}")


def get_slurm_reservation_maintenance(lkp: util.Lookup) -> Dict[str, datetime]:
res = util.run(f"{lkp.scontrol} show reservation --json")
all_reservations = json.loads(res.stdout)
reservation_map = {}

for reservation in all_reservations['reservations']:
name = reservation.get('name')
nodes = reservation.get('node_list')
time_epoch = reservation.get('start_time', {}).get('number')

if name is None or nodes is None or time_epoch is None:
continue

if reservation.get('node_count') != 1:
continue

if name != f"{nodes}_maintenance":
continue

reservation_map[name] = datetime.fromtimestamp(time_epoch)

return reservation_map


def get_upcoming_maintenance(lkp: util.Lookup) -> Dict[str, Tuple[str, datetime]]:
upc_maint_map = {}

for node, properties in lkp.instances().items():
if 'upcomingMaintenance' in properties:
start_time = datetime.strptime(properties['upcomingMaintenance']['startTimeWindow']['earliest'], '%Y-%m-%dT%H:%M:%S%z')
upc_maint_map[node + "_maintenance"] = (node, start_time)

return upc_maint_map


def sync_maintenance_reservation(lkp: util.Lookup) -> None:
upc_maint_map = get_upcoming_maintenance(lkp) # map reservation_name -> (node_name, time)
log.debug(f"upcoming-maintenance-vms: {upc_maint_map}")

curr_reservation_map = get_slurm_reservation_maintenance(lkp) # map reservation_name -> time
log.debug(f"curr-reservation-map: {curr_reservation_map}")

del_reservation = set(curr_reservation_map.keys() - upc_maint_map.keys())
create_reservation_map = {}

for res_name, (node, start_time) in upc_maint_map.items():
if res_name in curr_reservation_map:
diff = curr_reservation_map[res_name] - start_time
if abs(diff) <= datetime.timedelta(seconds=1):
continue
else:
del_reservation.add(res_name)
create_reservation_map[res_name] = (node, start_time)
else:
create_reservation_map[res_name] = (node, start_time)

log.debug(f"del-reservation: {del_reservation}")
for res_name in del_reservation:
delete_reservation(lkp, res_name)

log.debug(f"create-reservation-map: {create_reservation_map}")
for res_name, (node, start_time) in create_reservation_map.items():
create_reservation(lkp, res_name, node, start_time)


def main():
try:
reconfigure_slurm()
Expand All @@ -477,15 +555,23 @@ def main():
sync_slurm()
except Exception:
log.exception("failed to sync instances")

try:
sync_placement_groups()
except Exception:
log.exception("failed to sync placement groups")

try:
update_topology(lookup())
except Exception:
log.exception("failed to update topology")

## TODO: Enable reservation for scheduled maintenance.
# try:
# sync_maintenance_reservation(lookup())
# except Exception:
# log.exception("failed to sync slurm reservation for scheduled maintenance")

try:
install_custom_scripts(check_hash=True)
except Exception:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,7 @@ def install_custom_scripts(check_hash=False):
chown_slurm(dirs.custom_scripts / par)
need_update = True
if check_hash and fullpath.exists():
# TODO: MD5 reported by gcloud may differ from the one calculated here (e.g. if blob got gzipped),
# TODO: MD5 reported by gcloud may differ from the one calculated here (e.g. if blob got gzipped),
# consider using gCRC32C
need_update = hash_file(fullpath) != blob.md5_hash
if need_update:
Expand Down Expand Up @@ -501,7 +501,6 @@ def init_log_and_parse(parser: argparse.ArgumentParser) -> argparse.Namespace:
help="Enable detailed api request output",
)
args = parser.parse_args()

loglevel = args.loglevel
if lookup().cfg.enable_debug_logging:
loglevel = logging.DEBUG
Expand Down Expand Up @@ -549,7 +548,6 @@ def log_api_request(request):
"""log.trace info about a compute API request"""
if not lookup().cfg.extra_logging_flags.get("trace_api"):
return

# output the whole request object as pretty yaml
# the body is nested json, so load it as well
rep = json.loads(request.to_json())
Expand Down Expand Up @@ -1632,6 +1630,12 @@ def instances(self, project=None, slurm_cluster_name=None):
slurm_cluster_name=slurm_cluster_name,
instance_information_fields=instance_information_fields,
)

# TODO: Merge this with all fields when upcoming maintenance is
# supported in beta.
if endpoint_version(ApiEndpoint.COMPUTE) == 'alpha':
instance_information_fields.append("upcomingMaintenance")

instance_information_fields = sorted(set(instance_information_fields))
instance_fields = ",".join(instance_information_fields)
fields = f"items.zones.instances({instance_fields}),nextPageToken"
Expand Down Expand Up @@ -1659,7 +1663,7 @@ def properties(inst):
instance_iter = (
(inst["name"], properties(inst))
for inst in chain.from_iterable(
m["instances"] for m in result.get("items", {}).values()
zone.get("instances", []) for zone in result.get("items", {}).values()
)
)
instances.update(
Expand Down

0 comments on commit cdb3ab5

Please sign in to comment.