diff --git a/CHANGELOG.md b/CHANGELOG.md index d7f10341..df972286 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,6 +18,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - failed_node - Now possible to initialize a pyslurm.db.Jobs collection with existing job ids or pyslurm.db.Job objects +- Added `as_dict` function to all Collections ### Fixed @@ -26,6 +27,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - no start/end time was specified - the Job was older than a day +### Changed + +- All Collections (like [pyslurm.Jobs](https://pyslurm.github.io/23.2/reference/job/#pyslurm.Jobs)) inherit from `list` now instead of `dict` +- `JobSearchFilter` has been renamed to `JobFilter` + ## [23.2.1](https://github.com/PySlurm/pyslurm/releases/tag/v23.2.1) - 2023-05-18 ### Added @@ -40,7 +46,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - [pyslurm.db.Job](https://pyslurm.github.io/23.2/reference/db/job/#pyslurm.db.Job) - [pyslurm.db.Jobs](https://pyslurm.github.io/23.2/reference/db/job/#pyslurm.db.Jobs) - [pyslurm.db.JobStep](https://pyslurm.github.io/23.2/reference/db/jobstep/#pyslurm.db.JobStep) - - [pyslurm.db.JobSearchFilter](https://pyslurm.github.io/23.2/reference/db/jobsearchfilter/#pyslurm.db.JobSearchFilter) + - [pyslurm.db.JobFilter](https://pyslurm.github.io/23.2/reference/db/jobsearchfilter/#pyslurm.db.JobFilter) - Classes to interact with the Node API - [pyslurm.Node](https://pyslurm.github.io/23.2/reference/node/#pyslurm.Node) - [pyslurm.Nodes](https://pyslurm.github.io/23.2/reference/node/#pyslurm.Nodes) @@ -49,7 +55,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - [pyslurm.RPCError](https://pyslurm.github.io/23.2/reference/exceptions/#pyslurm.RPCError) - [Utility Functions](https://pyslurm.github.io/23.2/reference/utilities/#pyslurm.utils) -### Changes +### Changed - Completely overhaul the documentation, 
switch to mkdocs - Rework the tests: Split them into unit and integration tests diff --git a/docs/reference/db/jobfilter.md b/docs/reference/db/jobfilter.md new file mode 100644 index 00000000..21aa55d1 --- /dev/null +++ b/docs/reference/db/jobfilter.md @@ -0,0 +1,6 @@ +--- +title: JobFilter +--- + +::: pyslurm.db.JobFilter + handler: python diff --git a/docs/reference/db/jobsearchfilter.md b/docs/reference/db/jobsearchfilter.md deleted file mode 100644 index fa3864c5..00000000 --- a/docs/reference/db/jobsearchfilter.md +++ /dev/null @@ -1,6 +0,0 @@ ---- -title: JobSearchFilter ---- - -::: pyslurm.db.JobSearchFilter - handler: python diff --git a/docs/reference/index.md b/docs/reference/index.md index 5f66d339..35a6c678 100644 --- a/docs/reference/index.md +++ b/docs/reference/index.md @@ -37,7 +37,7 @@ The `pyslurm` package is a wrapper around the Slurm C-API * [pyslurm.db.Job][] * [pyslurm.db.JobStep][] * [pyslurm.db.Jobs][] - * [pyslurm.db.JobSearchFilter][] + * [pyslurm.db.JobFilter][] * Node API * [pyslurm.Node][] * [pyslurm.Nodes][] diff --git a/pyslurm/__init__.py b/pyslurm/__init__.py index 06bd804b..4d3a5101 100644 --- a/pyslurm/__init__.py +++ b/pyslurm/__init__.py @@ -9,11 +9,15 @@ sys.setdlopenflags(sys.getdlopenflags() | ctypes.RTLD_GLOBAL) +# Initialize slurm api +from pyslurm.api import slurm_init, slurm_fini +slurm_init() + from .pyslurm import * from .__version__ import __version__ -from pyslurm import utils from pyslurm import db +from pyslurm import utils from pyslurm import constants from pyslurm.core.job import ( @@ -32,10 +36,6 @@ ) from pyslurm.core import slurmctld -# Initialize slurm api -from pyslurm.api import slurm_init, slurm_fini -slurm_init() - def version(): return __version__ diff --git a/pyslurm/core/job/job.pxd b/pyslurm/core/job/job.pxd index d1c8ddf8..bee4f9ec 100644 --- a/pyslurm/core/job/job.pxd +++ b/pyslurm/core/job/job.pxd @@ -67,7 +67,7 @@ from pyslurm.slurm cimport ( ) -cdef class Jobs(dict): +cdef class Jobs(list): """A 
collection of [pyslurm.Job][] objects. Args: diff --git a/pyslurm/core/job/job.pyx b/pyslurm/core/job/job.pyx index 521a42a9..2c33d581 100644 --- a/pyslurm/core/job/job.pyx +++ b/pyslurm/core/job/job.pyx @@ -34,6 +34,7 @@ from typing import Union from pyslurm.utils import cstr, ctime from pyslurm.utils.uint import * from pyslurm.core.job.util import * +from pyslurm.db.cluster import LOCAL_CLUSTER from pyslurm.core.error import ( RPCError, verify_rpc, @@ -47,12 +48,14 @@ from pyslurm.utils.helpers import ( _getgrall_to_dict, _getpwall_to_dict, instance_to_dict, + collection_to_dict, + group_collection_by_cluster, _sum_prop, _get_exit_code, ) -cdef class Jobs(dict): +cdef class Jobs(list): def __cinit__(self): self.info = NULL @@ -63,14 +66,37 @@ cdef class Jobs(dict): def __init__(self, jobs=None, frozen=False): self.frozen = frozen - if isinstance(jobs, dict): - self.update(jobs) - elif jobs is not None: + if isinstance(jobs, list): for job in jobs: if isinstance(job, int): - self[job] = Job(job) + self.append(Job(job)) else: - self[job.id] = job + self.append(job) + elif isinstance(jobs, str): + joblist = jobs.split(",") + self.extend([Job(int(job)) for job in joblist]) + elif isinstance(jobs, dict): + self.extend([job for job in jobs.values()]) + elif jobs is not None: + raise TypeError(f"Invalid Type: {type(jobs)}") + + def as_dict(self, recursive=False): + """Convert the collection data to a dict. + + Args: + recursive (bool, optional): + By default, the objects will not be converted to a dict. If + this is set to `True`, then additionally all objects are + converted to dicts. + + Returns: + (dict): Collection as a dict. 
+ """ + col = collection_to_dict(self, identifier=Job.id, recursive=recursive) + return col.get(LOCAL_CLUSTER, {}) + + def group_by_cluster(self): + return group_collection_by_cluster(self) @staticmethod def load(preload_passwd_info=False, frozen=False): @@ -124,7 +150,7 @@ cdef class Jobs(dict): job.passwd = passwd job.groups = groups - jobs[job.id] = job + jobs.append(job) # At this point we memcpy'd all the memory for the Jobs. Setting this # to 0 will prevent the slurm job free function to deallocate the @@ -143,28 +169,34 @@ cdef class Jobs(dict): Raises: RPCError: When getting the Jobs from the slurmctld failed. """ - cdef Jobs reloaded_jobs = Jobs.load() + cdef: + dict reloaded_jobs + Jobs new_jobs = Jobs() + dict self_dict - for jid in list(self.keys()): + if not self: + return self + + reloaded_jobs = Jobs.load().as_dict() + for jid in [job.id for job in self]: if jid in reloaded_jobs: # Put the new data in. - self[jid] = reloaded_jobs[jid] - elif not self.frozen: - # Remove this instance from the current collection, as the Job - # doesn't exist anymore. - del self[jid] + new_jobs.append(reloaded_jobs[jid]) if not self.frozen: + self_dict = self.as_dict() for jid in reloaded_jobs: - if jid not in self: - self[jid] = reloaded_jobs[jid] + if jid not in self_dict: + new_jobs.append(reloaded_jobs[jid]) + self.clear() + self.extend(new_jobs) return self def load_steps(self): """Load all Job steps for this collection of Jobs. - This function fills in the "steps" attribute for all Jobs in the + This function fills in the `steps` attribute for all Jobs in the collection. !!! note @@ -175,21 +207,16 @@ cdef class Jobs(dict): Raises: RPCError: When retrieving the Job information for all the Steps failed. """ - cdef dict step_info = JobSteps.load_all() + cdef dict steps = JobSteps.load().as_dict() - for jid in self: + for idx, job in enumerate(self): # Ignore any Steps from Jobs which do not exist in this # collection. 
- if jid in step_info: - self[jid].steps = step_info[jid] - - def as_list(self): - """Format the information as list of Job objects. - - Returns: - (list[pyslurm.Job]): List of Job objects - """ - return list(self.values()) + jid = job.id + if jid in steps: + job_steps = self[idx].steps + job_steps.clear() + job_steps.extend(steps[jid].values()) @property def memory(self): @@ -218,6 +245,7 @@ cdef class Job: self.ptr.job_id = job_id self.passwd = {} self.groups = {} + cstr.fmalloc(&self.ptr.cluster, LOCAL_CLUSTER) self.steps = JobSteps.__new__(JobSteps) def _alloc_impl(self): @@ -234,7 +262,9 @@ cdef class Job: self._dealloc_impl() def __eq__(self, other): - return isinstance(other, Job) and self.id == other.id + if isinstance(other, Job): + return self.id == other.id and self.cluster == other.cluster + return NotImplemented @staticmethod def load(job_id): @@ -278,7 +308,7 @@ cdef class Job: if not slurm.IS_JOB_PENDING(wrap.ptr): # Just ignore if the steps couldn't be loaded here. try: - wrap.steps = JobSteps._load(wrap) + wrap.steps = JobSteps._load_single(wrap) except RPCError: pass else: diff --git a/pyslurm/core/job/step.pxd b/pyslurm/core/job/step.pxd index 087742d6..458ee506 100644 --- a/pyslurm/core/job/step.pxd +++ b/pyslurm/core/job/step.pxd @@ -49,7 +49,7 @@ from pyslurm.utils.ctime cimport time_t from pyslurm.core.job.task_dist cimport TaskDistribution -cdef class JobSteps(dict): +cdef class JobSteps(list): """A collection of [pyslurm.JobStep][] objects for a given Job. 
Args: @@ -64,11 +64,12 @@ cdef class JobSteps(dict): cdef: job_step_info_response_msg_t *info job_step_info_t tmp_info + _job_id @staticmethod - cdef JobSteps _load(Job job) + cdef JobSteps _load_single(Job job) - cdef dict _get_info(self, uint32_t job_id, int flags) + cdef _load_data(self, uint32_t job_id, int flags) cdef class JobStep: diff --git a/pyslurm/core/job/step.pyx b/pyslurm/core/job/step.pyx index f6b60d9c..d4038f54 100644 --- a/pyslurm/core/job/step.pyx +++ b/pyslurm/core/job/step.pyx @@ -26,10 +26,15 @@ from typing import Union from pyslurm.utils import cstr, ctime from pyslurm.utils.uint import * from pyslurm.core.error import RPCError, verify_rpc +from pyslurm.db.cluster import LOCAL_CLUSTER from pyslurm.utils.helpers import ( signal_to_num, instance_to_dict, uid_to_name, + collection_to_dict, + group_collection_by_cluster, + humanize_step_id, + dehumanize_step_id, ) from pyslurm.core.job.util import cpu_freq_int_to_str from pyslurm.utils.ctime import ( @@ -41,7 +46,7 @@ from pyslurm.utils.ctime import ( ) -cdef class JobSteps(dict): +cdef class JobSteps(list): def __dealloc__(self): slurm_free_job_step_info_response_msg(self.info) @@ -49,44 +54,74 @@ cdef class JobSteps(dict): def __cinit__(self): self.info = NULL - def __init__(self): - pass + def __init__(self, steps=None): + if isinstance(steps, list): + self.extend(steps) + elif steps is not None: + raise TypeError("Invalid Type: {type(steps)}") + + def as_dict(self, recursive=False): + """Convert the collection data to a dict. + + Args: + recursive (bool, optional): + By default, the objects will not be converted to a dict. If + this is set to `True`, then additionally all objects are + converted to dicts. + + Returns: + (dict): Collection as a dict. 
+ """ + col = collection_to_dict(self, identifier=JobStep.id, + recursive=recursive, group_id=JobStep.job_id) + col = col.get(LOCAL_CLUSTER, {}) + if self._job_id: + return col.get(self._job_id, {}) + + return col + + def group_by_cluster(self): + return group_collection_by_cluster(self) @staticmethod - def load(job): - """Load the Steps for a specific Job + def load(job_id=0): + """Load the Job Steps from the system. Args: - job (Union[Job, int]): - The Job for which the Steps should be loaded + job_id (Union[Job, int]): + The Job for which the Steps should be loaded. Returns: (pyslurm.JobSteps): JobSteps of the Job """ - cdef Job _job - _job = Job.load(job.id) if isinstance(job, Job) else Job.load(job) - return JobSteps._load(_job) + cdef: + Job job + JobSteps steps + + if job_id: + job = Job.load(job_id.id if isinstance(job_id, Job) else job_id) + steps = JobSteps._load_single(job) + steps._job_id = job.id + return steps + else: + steps = JobSteps() + return steps._load_data(slurm.NO_VAL, slurm.SHOW_ALL) @staticmethod - cdef JobSteps _load(Job job): - cdef JobSteps steps = JobSteps.__new__(JobSteps) + cdef JobSteps _load_single(Job job): + cdef JobSteps steps = JobSteps() - step_info = steps._get_info(job.id, slurm.SHOW_ALL) - if not step_info and not slurm.IS_JOB_PENDING(job.ptr): + steps._load_data(job.id, slurm.SHOW_ALL) + if not steps and not slurm.IS_JOB_PENDING(job.ptr): msg = f"Failed to load step info for Job {job.id}." raise RPCError(msg=msg) - # No super().__init__() needed? Cython probably already initialized - # the dict automatically. 
- steps.update(step_info[job.id]) return steps - cdef dict _get_info(self, uint32_t job_id, int flags): + cdef _load_data(self, uint32_t job_id, int flags): cdef: JobStep step - JobSteps steps uint32_t cnt = 0 - dict out = {} rc = slurm_get_job_steps(0, job_id, slurm.NO_VAL, &self.info, flags) @@ -102,12 +137,7 @@ cdef class JobSteps(dict): # Prevent double free if xmalloc fails mid-loop and a MemoryError # is raised by replacing it with a zeroed-out job_step_info_t. self.info.job_steps[cnt] = self.tmp_info - - if not step.job_id in out: - steps = JobSteps.__new__(JobSteps) - out[step.job_id] = steps - - out[step.job_id].update({step.id: step}) + self.append(step) # At this point we memcpy'd all the memory for the Steps. Setting this # to 0 will prevent the slurm step free function to deallocate the @@ -117,18 +147,7 @@ cdef class JobSteps(dict): # instance. self.info.job_step_count = 0 - return out - - @staticmethod - def load_all(): - """Loads all the steps in the system. - - Returns: - (dict): A dict where every JobID (key) is mapped with an instance - of its JobSteps (value). 
- """ - cdef JobSteps steps = JobSteps.__new__(JobSteps) - return steps._get_info(slurm.NO_VAL, slurm.SHOW_ALL) + return self cdef class JobStep: @@ -425,29 +444,3 @@ cdef class JobStep: @property def slurm_protocol_version(self): return u32_parse(self.ptr.start_protocol_ver) - - -def humanize_step_id(sid): - if sid == slurm.SLURM_BATCH_SCRIPT: - return "batch" - elif sid == slurm.SLURM_EXTERN_CONT: - return "extern" - elif sid == slurm.SLURM_INTERACTIVE_STEP: - return "interactive" - elif sid == slurm.SLURM_PENDING_STEP: - return "pending" - else: - return sid - - -def dehumanize_step_id(sid): - if sid == "batch": - return slurm.SLURM_BATCH_SCRIPT - elif sid == "extern": - return slurm.SLURM_EXTERN_CONT - elif sid == "interactive": - return slurm.SLURM_INTERACTIVE_STEP - elif sid == "pending": - return slurm.SLURM_PENDING_STEP - else: - return int(sid) diff --git a/pyslurm/core/node.pxd b/pyslurm/core/node.pxd index 19684612..ea59e6ff 100644 --- a/pyslurm/core/node.pxd +++ b/pyslurm/core/node.pxd @@ -57,7 +57,7 @@ from pyslurm.utils.ctime cimport time_t from pyslurm.utils.uint cimport * -cdef class Nodes(dict): +cdef class Nodes(list): """A collection of [pyslurm.Node][] objects. 
Args: @@ -233,6 +233,8 @@ cdef class Node: dict passwd dict groups + cdef readonly cluster + @staticmethod cdef _swap_data(Node dst, Node src) diff --git a/pyslurm/core/node.pyx b/pyslurm/core/node.pyx index 9c1ecf30..609016fe 100644 --- a/pyslurm/core/node.pyx +++ b/pyslurm/core/node.pyx @@ -28,6 +28,7 @@ from pyslurm.utils import ctime from pyslurm.utils.uint import * from pyslurm.core.error import RPCError, verify_rpc from pyslurm.utils.ctime import timestamp_to_date, _raw_time +from pyslurm.db.cluster import LOCAL_CLUSTER from pyslurm.utils.helpers import ( uid_to_name, gid_to_name, @@ -36,13 +37,15 @@ _getpwall_to_dict, cpubind_to_num, instance_to_dict, + collection_to_dict, + group_collection_by_cluster, _sum_prop, nodelist_from_range_str, nodelist_to_range_str, ) -cdef class Nodes(dict): +cdef class Nodes(list): def __dealloc__(self): slurm_free_node_info_msg(self.info) @@ -53,17 +56,38 @@ self.part_info = NULL def __init__(self, nodes=None): - if isinstance(nodes, dict): - self.update(nodes) - elif isinstance(nodes, str): - nodelist = nodelist_from_range_str(nodes) - self.update({node: Node(node) for node in nodelist}) - elif nodes is not None: + if isinstance(nodes, list): for node in nodes: if isinstance(node, str): - self[node] = Node(node) + self.append(Node(node)) else: - self[node.name] = node + self.append(node) + elif isinstance(nodes, str): + nodelist = nodelist_from_range_str(nodes) + self.extend([Node(node) for node in nodelist]) + elif isinstance(nodes, dict): + self.extend([node for node in nodes.values()]) + elif nodes is not None: + raise TypeError(f"Invalid Type: {type(nodes)}") + + def as_dict(self, recursive=False): + """Convert the collection data to a dict. + + Args: + recursive (bool, optional): + By default, the objects will not be converted to a dict. If + this is set to `True`, then additionally all objects are + converted to dicts. + + Returns: + (dict): Collection as a dict. 
+ """ + col = collection_to_dict(self, identifier=Node.name, + recursive=recursive) + return col.get(LOCAL_CLUSTER, {}) + + def group_by_cluster(self): + return group_collection_by_cluster(self) @staticmethod def load(preload_passwd_info=False): @@ -117,7 +141,7 @@ node.passwd = passwd node.groups = groups - nodes[node.name] = node + nodes.append(node) # At this point we memcpy'd all the memory for the Nodes. Setting this # to 0 will prevent the slurm node free function to deallocate the @@ -141,27 +165,19 @@ RPCError: When getting the Nodes from the slurmctld failed. """ - cdef Nodes reloaded_nodes + cdef dict reloaded_nodes - our_nodes = list(self.keys()) - if not our_nodes: - return None + if not self: + return self - reloaded_nodes = Nodes.load() - for node in list(self.keys()): + reloaded_nodes = Nodes.load().as_dict() + for idx, node_obj in enumerate(self): + node = node_obj.name if node in reloaded_nodes: # Put the new data in. - self[node] = reloaded_nodes[node] + self[idx] = reloaded_nodes[node] return self - def as_list(self): - """Format the information as list of Node objects. - - Returns: - (list[pyslurm.Node]): List of Node objects - """ - return list(self.values()) - def modify(self, Node changes): """Modify all Nodes in a collection. 
@@ -183,8 +199,11 @@ cdef class Nodes(dict): >>> # Apply the changes to all the nodes >>> nodes.modify(changes) """ - cdef Node n = changes - node_str = nodelist_to_range_str(list(self.keys())) + cdef: + Node n = changes + list node_names = [node.name for node in self] + + node_str = nodelist_to_range_str(node_names) n._alloc_umsg() cstr.fmalloc(&n.umsg.node_names, node_str) verify_rpc(slurm_update_node(n.umsg)) @@ -235,6 +254,7 @@ cdef class Node: def __init__(self, name=None, **kwargs): self._alloc_impl() self.name = name + self.cluster = LOCAL_CLUSTER for k, v in kwargs.items(): setattr(self, k, v) @@ -282,6 +302,7 @@ cdef class Node: wrap._alloc_info() wrap.passwd = {} wrap.groups = {} + wrap.cluster = LOCAL_CLUSTER memcpy(wrap.info, in_ptr, sizeof(node_info_t)) return wrap diff --git a/pyslurm/core/partition.pxd b/pyslurm/core/partition.pxd index 37d6a37c..b10366b8 100644 --- a/pyslurm/core/partition.pxd +++ b/pyslurm/core/partition.pxd @@ -58,7 +58,7 @@ from pyslurm.utils.uint cimport * from pyslurm.core cimport slurmctld -cdef class Partitions(dict): +cdef class Partitions(list): """A collection of [pyslurm.Partition][] objects. 
Args: @@ -216,5 +216,7 @@ cdef class Partition: int power_save_enabled slurmctld.Config slurm_conf + cdef readonly cluster + @staticmethod cdef Partition from_ptr(partition_info_t *in_ptr) diff --git a/pyslurm/core/partition.pyx b/pyslurm/core/partition.pyx index 99aaf5e8..56375d33 100644 --- a/pyslurm/core/partition.pyx +++ b/pyslurm/core/partition.pyx @@ -30,6 +30,7 @@ from pyslurm.utils.uint import * from pyslurm.core.error import RPCError, verify_rpc from pyslurm.utils.ctime import timestamp_to_date, _raw_time from pyslurm.constants import UNLIMITED +from pyslurm.db.cluster import LOCAL_CLUSTER from pyslurm.utils.helpers import ( uid_to_name, gid_to_name, @@ -37,6 +38,8 @@ from pyslurm.utils.helpers import ( _getpwall_to_dict, cpubind_to_num, instance_to_dict, + collection_to_dict, + group_collection_by_cluster, _sum_prop, dehumanize, ) @@ -46,7 +49,8 @@ from pyslurm.utils.ctime import ( ) -cdef class Partitions(dict): +cdef class Partitions(list): + def __dealloc__(self): slurm_free_partition_info_msg(self.info) @@ -54,17 +58,38 @@ cdef class Partitions(dict): self.info = NULL def __init__(self, partitions=None): - if isinstance(partitions, dict): - self.update(partitions) - elif isinstance(partitions, str): - partlist = partitions.split(",") - self.update({part: Partition(part) for part in partlist}) - elif partitions is not None: + if isinstance(partitions, list): for part in partitions: if isinstance(part, str): - self[part] = Partition(part) + self.append(Partition(part)) else: - self[part.name] = part + self.append(part) + elif isinstance(partitions, str): + partlist = partitions.split(",") + self.extend([Partition(part) for part in partlist]) + elif isinstance(partitions, dict): + self.extend([part for part in partitions.values()]) + elif partitions is not None: + raise TypeError("Invalid Type: {type(partitions)}") + + def as_dict(self, recursive=False): + """Convert the collection data to a dict. 
+ + Args: + recursive (bool, optional): + By default, the objects will not be converted to a dict. If + this is set to `True`, then additionally all objects are + converted to dicts. + + Returns: + (dict): Collection as a dict. + """ + col = collection_to_dict(self, identifier=Partition.name, + recursive=recursive) + return col.get(LOCAL_CLUSTER, {}) + + def group_by_cluster(self): + return group_collection_by_cluster(self) @staticmethod def load(): @@ -103,7 +128,7 @@ cdef class Partitions(dict): partition.power_save_enabled = power_save_enabled partition.slurm_conf = slurm_conf - partitions[partition.name] = partition + partitions.append(partition) # At this point we memcpy'd all the memory for the Partitions. Setting # this to 0 will prevent the slurm partition free function to @@ -129,17 +154,17 @@ cdef class Partitions(dict): Raises: RPCError: When getting the Partitions from the slurmctld failed. """ - cdef Partitions reloaded_parts - our_parts = list(self.keys()) + cdef dict reloaded_parts - if not our_parts: + if not self: return self - reloaded_parts = Partitions.load() - for part in our_parts: - if part in reloaded_parts: + reloaded_parts = Partitions.load().as_dict() + for idx, part in enumerate(self): + part_name = part.name + if part_name in reloaded_parts: # Put the new data in. - self[part] = reloaded_parts[part] + self[idx] = reloaded_parts[part_name] return self @@ -164,17 +189,9 @@ cdef class Partitions(dict): >>> # Apply the changes to all the partitions >>> parts.modify(changes) """ - for part in self.values(): + for part in self: part.modify(changes) - def as_list(self): - """Format the information as list of Partition objects. 
- - Returns: - (list): List of Partition objects - """ - return list(self.values()) - @property def total_cpus(self): return _sum_prop(self, Partition.total_cpus) @@ -192,6 +209,7 @@ cdef class Partition: def __init__(self, name=None, **kwargs): self._alloc_impl() self.name = name + self.cluster = LOCAL_CLUSTER for k, v in kwargs.items(): setattr(self, k, v) @@ -214,6 +232,7 @@ cdef class Partition: cdef Partition from_ptr(partition_info_t *in_ptr): cdef Partition wrap = Partition.__new__(Partition) wrap._alloc_impl() + wrap.cluster = LOCAL_CLUSTER memcpy(wrap.ptr, in_ptr, sizeof(partition_info_t)) return wrap @@ -255,7 +274,7 @@ cdef class Partition: >>> import pyslurm >>> part = pyslurm.Partition.load("normal") """ - partitions = Partitions.load() + partitions = Partitions.load().as_dict() if name not in partitions: raise RPCError(msg=f"Partition '{name}' doesn't exist") diff --git a/pyslurm/db/__init__.py b/pyslurm/db/__init__.py index bb34e232..0e78a734 100644 --- a/pyslurm/db/__init__.py +++ b/pyslurm/db/__init__.py @@ -25,6 +25,7 @@ from .job import ( Job, Jobs, + JobFilter, JobSearchFilter, ) from .tres import ( @@ -34,5 +35,11 @@ from .qos import ( QualitiesOfService, QualityOfService, - QualityOfServiceSearchFilter, + QualityOfServiceFilter, ) +from .assoc import ( + Associations, + Association, + AssociationFilter, +) +from . 
import cluster diff --git a/pyslurm/db/assoc.pxd b/pyslurm/db/assoc.pxd new file mode 100644 index 00000000..12a0cde1 --- /dev/null +++ b/pyslurm/db/assoc.pxd @@ -0,0 +1,87 @@ +######################################################################### +# assoc.pxd - pyslurm slurmdbd association api +######################################################################### +# Copyright (C) 2023 Toni Harzendorf +# +# This file is part of PySlurm +# +# PySlurm is free software; you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation; either version 2 of the License, or +# (at your option) any later version. + +# PySlurm is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License along +# with PySlurm; if not, write to the Free Software Foundation, Inc., +# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. 
+# +# cython: c_string_type=unicode, c_string_encoding=default +# cython: language_level=3 + +from pyslurm cimport slurm +from pyslurm.slurm cimport ( + slurmdb_assoc_rec_t, + slurmdb_assoc_cond_t, + slurmdb_associations_get, + slurmdb_destroy_assoc_rec, + slurmdb_destroy_assoc_cond, + slurmdb_init_assoc_rec, + slurmdb_associations_modify, + try_xmalloc, +) +from pyslurm.db.util cimport ( + SlurmList, + SlurmListItem, + make_char_list, + slurm_list_to_pylist, + qos_list_to_pylist, +) +from pyslurm.db.tres cimport ( + _set_tres_limits, + TrackableResources, + TrackableResourceLimits, +) +from pyslurm.db.connection cimport Connection +from pyslurm.utils cimport cstr +from pyslurm.utils.uint cimport * +from pyslurm.db.qos cimport QualitiesOfService, _set_qos_list + +cdef _parse_assoc_ptr(Association ass) +cdef _create_assoc_ptr(Association ass, conn=*) + + +cdef class Associations(list): + pass + + +cdef class AssociationFilter: + cdef slurmdb_assoc_cond_t *ptr + + cdef public: + users + ids + + +cdef class Association: + cdef: + slurmdb_assoc_rec_t *ptr + dict qos_data + dict tres_data + + cdef public: + group_tres + group_tres_mins + group_tres_run_mins + max_tres_mins_per_job + max_tres_run_mins_per_user + max_tres_per_job + max_tres_per_node + qos + + @staticmethod + cdef Association from_ptr(slurmdb_assoc_rec_t *in_ptr) + diff --git a/pyslurm/db/assoc.pyx b/pyslurm/db/assoc.pyx new file mode 100644 index 00000000..d1ac4789 --- /dev/null +++ b/pyslurm/db/assoc.pyx @@ -0,0 +1,455 @@ +######################################################################### +# assoc.pyx - pyslurm slurmdbd association api +######################################################################### +# Copyright (C) 2023 Toni Harzendorf +# +# This file is part of PySlurm +# +# PySlurm is free software; you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation; either version 2 of the License, or +# (at 
your option) any later version. + +# PySlurm is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License along +# with PySlurm; if not, write to the Free Software Foundation, Inc., +# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. +# +# cython: c_string_type=unicode, c_string_encoding=default +# cython: language_level=3 + +from pyslurm.core.error import RPCError +from pyslurm.utils.helpers import ( + instance_to_dict, + collection_to_dict, + group_collection_by_cluster, + user_to_uid, +) +from pyslurm.utils.uint import * +from pyslurm.db.connection import _open_conn_or_error +from pyslurm.db.cluster import LOCAL_CLUSTER + + +cdef class Associations(list): + + def __init__(self): + pass + + def as_dict(self, recursive=False, group_by_cluster=False): + """Convert the collection data to a dict. + + Args: + recursive (bool, optional): + By default, the objects will not be converted to a dict. If + this is set to `True`, then additionally all objects are + converted to dicts. + group_by_cluster (bool, optional): + By default, only the Jobs from your local Cluster are + returned. If this is set to `True`, then all the Jobs in the + collection will be grouped by the Cluster - with the name of + the cluster as the key and the value being the collection as + another dict. + + Returns: + (dict): Collection as a dict. 
+ """ + col = collection_to_dict(self, identifier=Association.id, + recursive=recursive) + if not group_by_cluster: + return col.get(LOCAL_CLUSTER, {}) + + return col + + def group_by_cluster(self): + return group_collection_by_cluster(self) + + @staticmethod + def load(AssociationFilter db_filter=None, Connection db_connection=None): + cdef: + Associations out = Associations() + Association assoc + AssociationFilter cond = db_filter + SlurmList assoc_data + SlurmListItem assoc_ptr + Connection conn + dict qos_data + dict tres_data + + # Prepare SQL Filter + if not db_filter: + cond = AssociationFilter() + cond._create() + + # Setup DB Conn + conn = _open_conn_or_error(db_connection) + + # Fetch Assoc Data + assoc_data = SlurmList.wrap(slurmdb_associations_get( + conn.ptr, cond.ptr)) + + if assoc_data.is_null: + raise RPCError(msg="Failed to get Association data from slurmdbd") + + # Fetch other necessary dependencies needed for translating some + # attributes (i.e QoS IDs to its name) + qos_data = QualitiesOfService.load(db_connection=conn).as_dict( + name_is_key=False) + tres_data = TrackableResources.load(db_connection=conn).as_dict( + name_is_key=False) + + # Setup Association objects + for assoc_ptr in SlurmList.iter_and_pop(assoc_data): + assoc = Association.from_ptr(assoc_ptr.data) + assoc.qos_data = qos_data + assoc.tres_data = tres_data + _parse_assoc_ptr(assoc) + out.append(assoc) + + return out + + @staticmethod + def modify(db_filter, Association changes, Connection db_connection=None): + cdef: + AssociationFilter afilter + Connection conn + SlurmList response + SlurmListItem response_ptr + list out = [] + + # Prepare SQL Filter + if isinstance(db_filter, Associations): + assoc_ids = [ass.id for ass in db_filter] + afilter = AssociationFilter(ids=assoc_ids) + else: + afilter = db_filter + afilter._create() + + # Setup DB conn + conn = _open_conn_or_error(db_connection) + + # Any data that isn't parsed yet or needs validation is done in this + # 
function. + _create_assoc_ptr(changes, conn) + + # Modify associations, get the result + # This returns a List of char* with the associations that were + # modified + response = SlurmList.wrap(slurmdb_associations_modify( + conn.ptr, afilter.ptr, changes.ptr)) + + if not response.is_null and response.cnt: + for response_ptr in response: + response_str = cstr.to_unicode(response_ptr.data) + if not response_str: + continue + + # TODO: Better format + out.append(response_str) + + elif not response.is_null: + # There was no real error, but simply nothing has been modified + raise RPCError(msg="Nothing was modified") + else: + # Autodetects the last slurm error + raise RPCError() + + if not db_connection: + # Autocommit if no connection was explicitly specified. + conn.commit() + + return out + + +cdef class AssociationFilter: + + def __cinit__(self): + self.ptr = NULL + + def __init__(self, **kwargs): + for k, v in kwargs.items(): + setattr(self, k, v) + + def __dealloc__(self): + self._dealloc() + + def _dealloc(self): + slurmdb_destroy_assoc_cond(self.ptr) + self.ptr = NULL + + def _alloc(self): + self._dealloc() + self.ptr = try_xmalloc(sizeof(slurmdb_assoc_cond_t)) + if not self.ptr: + raise MemoryError("xmalloc failed for slurmdb_assoc_cond_t") + + def _parse_users(self): + if not self.users: + return None + return list({user_to_uid(user) for user in self.users}) + + def _create(self): + self._alloc() + cdef slurmdb_assoc_cond_t *ptr = self.ptr + + make_char_list(&ptr.user_list, self.users) + + +cdef class Association: + + def __cinit__(self): + self.ptr = NULL + + def __init__(self, **kwargs): + self._alloc_impl() + self.id = 0 + self.cluster = LOCAL_CLUSTER + for k, v in kwargs.items(): + setattr(self, k, v) + + def __dealloc__(self): + self._dealloc_impl() + + def _dealloc_impl(self): + slurmdb_destroy_assoc_rec(self.ptr) + self.ptr = NULL + + def _alloc_impl(self): + if not self.ptr: + self.ptr = try_xmalloc( + sizeof(slurmdb_assoc_rec_t)) + if not self.ptr: + 
raise MemoryError("xmalloc failed for slurmdb_assoc_rec_t") + + slurmdb_init_assoc_rec(self.ptr, 0) + + @staticmethod + cdef Association from_ptr(slurmdb_assoc_rec_t *in_ptr): + cdef Association wrap = Association.__new__(Association) + wrap.ptr = in_ptr + return wrap + + def as_dict(self): + """Database Association information formatted as a dictionary. + + Returns: + (dict): Database Association information as dict + """ + return instance_to_dict(self) + + def __eq__(self, other): + if isinstance(other, Association): + return self.id == other.id and self.cluster == other.cluster + return NotImplemented + + @property + def account(self): + return cstr.to_unicode(self.ptr.acct) + + @account.setter + def account(self, val): + cstr.fmalloc(&self.ptr.acct, val) + + @property + def cluster(self): + return cstr.to_unicode(self.ptr.cluster) + + @cluster.setter + def cluster(self, val): + cstr.fmalloc(&self.ptr.cluster, val) + + @property + def comment(self): + return cstr.to_unicode(self.ptr.comment) + + @comment.setter + def comment(self, val): + cstr.fmalloc(&self.ptr.comment, val) + + # uint32_t def_qos_id + + # uint16_t flags (ASSOC_FLAG_*) + + @property + def group_jobs(self): + return u32_parse(self.ptr.grp_jobs, zero_is_noval=False) + + @group_jobs.setter + def group_jobs(self, val): + self.ptr.grp_jobs = u32(val, zero_is_noval=False) + + @property + def group_jobs_accrue(self): + return u32_parse(self.ptr.grp_jobs_accrue, zero_is_noval=False) + + @group_jobs_accrue.setter + def group_jobs_accrue(self, val): + self.ptr.grp_jobs_accrue = u32(val, zero_is_noval=False) + + @property + def group_submit_jobs(self): + return u32_parse(self.ptr.grp_submit_jobs, zero_is_noval=False) + + @group_submit_jobs.setter + def group_submit_jobs(self, val): + self.ptr.grp_submit_jobs = u32(val, zero_is_noval=False) + + @property + def group_wall_time(self): + return u32_parse(self.ptr.grp_wall, zero_is_noval=False) + + @group_wall_time.setter + def group_wall_time(self, val): + 
self.ptr.grp_wall = u32(val, zero_is_noval=False) + + @property + def id(self): + return u32_parse(self.ptr.id) + + @id.setter + def id(self, val): + self.ptr.id = val + + @property + def is_default(self): + return u16_parse_bool(self.ptr.is_def) + + @property + def lft(self): + return u32_parse(self.ptr.lft) + + @property + def max_jobs(self): + return u32_parse(self.ptr.max_jobs, zero_is_noval=False) + + @max_jobs.setter + def max_jobs(self, val): + self.ptr.max_jobs = u32(val, zero_is_noval=False) + + @property + def max_jobs_accrue(self): + return u32_parse(self.ptr.max_jobs_accrue, zero_is_noval=False) + + @max_jobs_accrue.setter + def max_jobs_accrue(self, val): + self.ptr.max_jobs_accrue = u32(val, zero_is_noval=False) + + @property + def max_submit_jobs(self): + return u32_parse(self.ptr.max_submit_jobs, zero_is_noval=False) + + @max_submit_jobs.setter + def max_submit_jobs(self, val): + self.ptr.max_submit_jobs = u32(val, zero_is_noval=False) + + @property + def max_wall_time_per_job(self): + return u32_parse(self.ptr.max_wall_pj, zero_is_noval=False) + + @max_wall_time_per_job.setter + def max_wall_time_per_job(self, val): + self.ptr.max_wall_pj = u32(val, zero_is_noval=False) + + @property + def min_priority_threshold(self): + return u32_parse(self.ptr.min_prio_thresh, zero_is_noval=False) + + @min_priority_threshold.setter + def min_priority_threshold(self, val): + self.ptr.min_prio_thresh = u32(val, zero_is_noval=False) + + @property + def parent_account(self): + return cstr.to_unicode(self.ptr.parent_acct) + + @property + def parent_account_id(self): + return u32_parse(self.ptr.parent_id, zero_is_noval=False) + + @property + def partition(self): + return cstr.to_unicode(self.ptr.partition) + + @partition.setter + def partition(self, val): + cstr.fmalloc(&self.ptr.partition, val) + + @property + def priority(self): + return u32_parse(self.ptr.priority, zero_is_noval=False) + + @priority.setter + def priority(self, val): + self.ptr.priority = u32(val) + 
+ @property + def rgt(self): + return u32_parse(self.ptr.rgt) + + @property + def shares(self): + return u32_parse(self.ptr.shares_raw, zero_is_noval=False) + + @shares.setter + def shares(self, val): + self.ptr.shares_raw = u32(val) + + @property + def user(self): + return cstr.to_unicode(self.ptr.user) + + @user.setter + def user(self, val): + cstr.fmalloc(&self.ptr.user, val) + + +cdef _parse_assoc_ptr(Association ass): + cdef: + dict tres = ass.tres_data + dict qos = ass.qos_data + + ass.group_tres = TrackableResourceLimits.from_ids( + ass.ptr.grp_tres, tres) + ass.group_tres_mins = TrackableResourceLimits.from_ids( + ass.ptr.grp_tres_mins, tres) + ass.group_tres_run_mins = TrackableResourceLimits.from_ids( + ass.ptr.grp_tres_run_mins, tres) + ass.max_tres_mins_per_job = TrackableResourceLimits.from_ids( + ass.ptr.max_tres_mins_pj, tres) + ass.max_tres_run_mins_per_user = TrackableResourceLimits.from_ids( + ass.ptr.max_tres_run_mins, tres) + ass.max_tres_per_job = TrackableResourceLimits.from_ids( + ass.ptr.max_tres_pj, tres) + ass.max_tres_per_node = TrackableResourceLimits.from_ids( + ass.ptr.max_tres_pn, tres) + ass.qos = qos_list_to_pylist(ass.ptr.qos_list, qos) + + +cdef _create_assoc_ptr(Association ass, conn=None): + # _set_tres_limits will also check if specified TRES are valid and + # translate them to its ID which is why we need to load the current TRES + # available in the system.
+ ass.tres_data = TrackableResources.load(db_connection=conn) + _set_tres_limits(&ass.ptr.grp_tres, ass.group_tres, ass.tres_data) + _set_tres_limits(&ass.ptr.grp_tres_mins, ass.group_tres_mins, + ass.tres_data) + _set_tres_limits(&ass.ptr.grp_tres_run_mins, ass.group_tres_run_mins, + ass.tres_data) + _set_tres_limits(&ass.ptr.max_tres_mins_pj, ass.max_tres_mins_per_job, + ass.tres_data) + _set_tres_limits(&ass.ptr.max_tres_run_mins, ass.max_tres_run_mins_per_user, + ass.tres_data) + _set_tres_limits(&ass.ptr.max_tres_pj, ass.max_tres_per_job, + ass.tres_data) + _set_tres_limits(&ass.ptr.max_tres_pn, ass.max_tres_per_node, + ass.tres_data) + + # _set_qos_list will also check if specified QoS are valid and translate + # them to its ID, which is why we need to load the current QOS available + # in the system. + ass.qos_data = QualitiesOfService.load(db_connection=conn) + _set_qos_list(&ass.ptr.qos_list, ass.qos, ass.qos_data) + diff --git a/pyslurm/db/cluster.pxd b/pyslurm/db/cluster.pxd new file mode 100644 index 00000000..30acdbde --- /dev/null +++ b/pyslurm/db/cluster.pxd @@ -0,0 +1,27 @@ +######################################################################### +# cluster.pxd - pyslurm slurmdbd cluster api +######################################################################### +# Copyright (C) 2023 Toni Harzendorf +# +# This file is part of PySlurm +# +# PySlurm is free software; you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation; either version 2 of the License, or +# (at your option) any later version. + +# PySlurm is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details.
+# +# You should have received a copy of the GNU General Public License along +# with PySlurm; if not, write to the Free Software Foundation, Inc., +# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. +# +# cython: c_string_type=unicode, c_string_encoding=default +# cython: language_level=3 + + +from pyslurm cimport slurm +from pyslurm.utils cimport cstr diff --git a/pyslurm/db/cluster.pyx b/pyslurm/db/cluster.pyx new file mode 100644 index 00000000..436183a8 --- /dev/null +++ b/pyslurm/db/cluster.pyx @@ -0,0 +1,31 @@ +######################################################################### +# cluster.pyx - pyslurm slurmdbd cluster api +######################################################################### +# Copyright (C) 2023 Toni Harzendorf +# +# This file is part of PySlurm +# +# PySlurm is free software; you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation; either version 2 of the License, or +# (at your option) any later version. + +# PySlurm is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License along +# with PySlurm; if not, write to the Free Software Foundation, Inc., +# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. 
+# +# cython: c_string_type=unicode, c_string_encoding=default +# cython: language_level=3 + +from pyslurm.core import slurmctld + + +LOCAL_CLUSTER = cstr.to_unicode(slurm.slurm_conf.cluster_name) +if not LOCAL_CLUSTER: + slurm_conf = slurmctld.Config.load() + LOCAL_CLUSTER = slurm_conf.cluster diff --git a/pyslurm/db/connection.pyx b/pyslurm/db/connection.pyx index eab6572d..67ef7603 100644 --- a/pyslurm/db/connection.pyx +++ b/pyslurm/db/connection.pyx @@ -25,6 +25,16 @@ from pyslurm.core.error import RPCError +def _open_conn_or_error(conn): + if not conn: + conn = Connection.open() + + if not conn.is_open: + raise ValueError("Database connection is not open") + + return conn + + cdef class Connection: def __cinit__(self): diff --git a/pyslurm/db/job.pxd b/pyslurm/db/job.pxd index 0faac2fd..fc395943 100644 --- a/pyslurm/db/job.pxd +++ b/pyslurm/db/job.pxd @@ -55,7 +55,7 @@ from pyslurm.db.qos cimport QualitiesOfService from pyslurm.db.tres cimport TrackableResources, TrackableResource -cdef class JobSearchFilter: +cdef class JobFilter: """Query-Conditions for Jobs in the Slurm Database. 
Args: @@ -150,11 +150,9 @@ cdef class JobSearchFilter: with_env -cdef class Jobs(dict): +cdef class Jobs(list): """A collection of [pyslurm.db.Job][] objects.""" - cdef: - SlurmList info - Connection db_conn + pass cdef class Job: @@ -285,7 +283,7 @@ cdef class Job: """ cdef: slurmdb_job_rec_t *ptr - QualitiesOfService qos_data + dict qos_data cdef public: JobSteps steps diff --git a/pyslurm/db/job.pyx b/pyslurm/db/job.pyx index af86f704..636e1137 100644 --- a/pyslurm/db/job.pyx +++ b/pyslurm/db/job.pyx @@ -27,6 +27,7 @@ from pyslurm.core.error import RPCError, PyslurmError from pyslurm.core import slurmctld from typing import Any from pyslurm.utils.uint import * +from pyslurm.db.cluster import LOCAL_CLUSTER from pyslurm.utils.ctime import ( date_to_timestamp, timestr_to_mins, @@ -39,11 +40,14 @@ from pyslurm.utils.helpers import ( uid_to_name, nodelist_to_range_str, instance_to_dict, + collection_to_dict, + group_collection_by_cluster, _get_exit_code, ) +from pyslurm.db.connection import _open_conn_or_error -cdef class JobSearchFilter: +cdef class JobFilter: def __cinit__(self): self.ptr = NULL @@ -73,14 +77,19 @@ cdef class JobSearchFilter: return None qos_id_list = [] - qos = QualitiesOfService.load() - for q in self.qos: - if isinstance(q, int): - qos_id_list.append(q) - elif q in qos: - qos_id_list.append(str(qos[q].id)) - else: - raise ValueError(f"QoS {q} does not exist") + qos_data = QualitiesOfService.load() + for user_input in self.qos: + found = False + for qos in qos_data: + if (qos.id == user_input + or qos.name == user_input + or qos == user_input): + qos_id_list.append(str(qos.id)) + found = True + break + + if not found: + raise ValueError(f"QoS '{user_input}' does not exist") return qos_id_list @@ -96,11 +105,9 @@ cdef class JobSearchFilter: def _parse_clusters(self): if not self.clusters: - # Get the local cluster name # This is a requirement for some other parameters to function # correctly, like self.nodelist - slurm_conf = 
slurmctld.Config.load() - return [slurm_conf.cluster] + return [LOCAL_CLUSTER] elif self.clusters == "all": return None else: @@ -178,31 +185,71 @@ cdef class JobSearchFilter: slurmdb_job_cond_def_start_end(ptr) -cdef class Jobs(dict): +# Alias +JobSearchFilter = JobFilter + + +cdef class Jobs(list): def __init__(self, jobs=None): - if isinstance(jobs, dict): - self.update(jobs) - elif isinstance(jobs, str): - joblist = jobs.split(",") - self.update({int(job): Job(job) for job in joblist}) - elif jobs is not None: + if isinstance(jobs, list): for job in jobs: if isinstance(job, int): - self[job] = Job(job) + self.append(Job(job)) else: - self[job.name] = job + self.append(job) + elif isinstance(jobs, str): + joblist = jobs.split(",") + self.extend([Job(job) for job in joblist]) + elif isinstance(jobs, dict): + self.extend([job for job in jobs.values()]) + elif jobs is not None: + raise TypeError(f"Invalid Type: {type(jobs)}") + + def as_dict(self, recursive=False, group_by_cluster=False): + """Convert the collection data to a dict. + + Args: + recursive (bool, optional): + By default, the objects will not be converted to a dict. If + this is set to `True`, then additionally all objects are + converted to dicts. + group_by_cluster (bool, optional): + By default, only the Jobs from your local Cluster are + returned. If this is set to `True`, then all the Jobs in the + collection will be grouped by the Cluster - with the name of + the cluster as the key and the value being the collection as + another dict. + + Returns: + (dict): Collection as a dict. + """ + col = collection_to_dict(self, identifier=Job.id, recursive=recursive) + if not group_by_cluster: + return col.get(LOCAL_CLUSTER, {}) + + return col + + def group_by_cluster(self): + """Group Jobs by cluster name + + Returns: + (dict[str, Jobs]): Jobs grouped by cluster.
+ """ + return group_collection_by_cluster(self) @staticmethod - def load(search_filter=None): + def load(JobFilter db_filter=None, Connection db_connection=None): """Load Jobs from the Slurm Database Implements the slurmdb_jobs_get RPC. Args: - search_filter (pyslurm.db.JobSearchFilter): + db_filter (pyslurm.db.JobFilter): A search filter that the slurmdbd will apply when retrieving Jobs from the database. + db_connection (pyslurm.db.Connection): + An open database connection. Returns: (pyslurm.db.Jobs): A Collection of database Jobs. @@ -223,30 +270,35 @@ cdef class Jobs(dict): >>> import pyslurm >>> accounts = ["acc1", "acc2"] - >>> search_filter = pyslurm.db.JobSearchFilter(accounts=accounts) - >>> db_jobs = pyslurm.db.Jobs.load(search_filter) + >>> db_filter = pyslurm.db.JobFilter(accounts=accounts) + >>> db_jobs = pyslurm.db.Jobs.load(db_filter) """ cdef: - Jobs jobs = Jobs() + Jobs out = Jobs() Job job - JobSearchFilter cond + JobFilter cond = db_filter + SlurmList job_data SlurmListItem job_ptr - QualitiesOfService qos_data - - if search_filter: - cond = search_filter - else: - cond = JobSearchFilter() + Connection conn + dict qos_data + # Prepare SQL Filter + if not db_filter: + cond = JobFilter() cond._create() - jobs.db_conn = Connection.open() - jobs.info = SlurmList.wrap(slurmdb_jobs_get(jobs.db_conn.ptr, - cond.ptr)) - if jobs.info.is_null: + + # Setup DB Conn + conn = _open_conn_or_error(db_connection) + + # Fetch Job data + job_data = SlurmList.wrap(slurmdb_jobs_get(conn.ptr, cond.ptr)) + if job_data.is_null: raise RPCError(msg="Failed to get Jobs from slurmdbd") - qos_data = QualitiesOfService.load(name_is_key=False, - db_connection=jobs.db_conn) + # Fetch other necessary dependencies needed for translating some + # attributes (i.e QoS IDs to its name) + qos_data = QualitiesOfService.load(db_connection=conn).as_dict( + name_is_key=False) # TODO: also get trackable resources with slurmdb_tres_get and store # it in each job instance. 
tres_alloc_str and tres_req_str only @@ -256,23 +308,23 @@ cdef class Jobs(dict): # TODO: For multi-cluster support, remove duplicate federation jobs # TODO: How to handle the possibility of duplicate job ids that could # appear if IDs on a cluster are resetted? - for job_ptr in SlurmList.iter_and_pop(jobs.info): + for job_ptr in SlurmList.iter_and_pop(job_data): job = Job.from_ptr(job_ptr.data) job.qos_data = qos_data job._create_steps() JobStatistics._sum_step_stats_for_job(job, job.steps) - jobs[job.id] = job + out.append(job) - return jobs + return out @staticmethod - def modify(search_filter, Job changes, db_connection=None): + def modify(db_filter, Job changes, db_connection=None): """Modify Slurm database Jobs. Implements the slurm_job_modify RPC. Args: - search_filter (Union[pyslurm.db.JobSearchFilter, pyslurm.db.Jobs]): + db_filter (Union[pyslurm.db.JobFilter, pyslurm.db.Jobs]): A filter to decide which Jobs should be modified. changes (pyslurm.db.Job): Another [pyslurm.db.Job][] object that contains all the @@ -307,9 +359,9 @@ cdef class Jobs(dict): >>> import pyslurm >>> - >>> search_filter = pyslurm.db.JobSearchFilter(ids=[9999]) + >>> db_filter = pyslurm.db.JobFilter(ids=[9999]) >>> changes = pyslurm.db.Job(comment="A comment for the job") - >>> modified_jobs = pyslurm.db.Jobs.modify(search_filter, changes) + >>> modified_jobs = pyslurm.db.Jobs.modify(db_filter, changes) >>> print(modified_jobs) >>> [9999] @@ -321,10 +373,10 @@ cdef class Jobs(dict): >>> import pyslurm >>> >>> db_conn = pyslurm.db.Connection.open() - >>> search_filter = pyslurm.db.JobSearchFilter(ids=[9999]) + >>> db_filter = pyslurm.db.JobFilter(ids=[9999]) >>> changes = pyslurm.db.Job(comment="A comment for the job") >>> modified_jobs = pyslurm.db.Jobs.modify( - ... search_filter, changes, db_conn) + ... 
db_filter, changes, db_conn) >>> >>> # Now you can first examine which Jobs have been modified >>> print(modified_jobs) @@ -333,28 +385,29 @@ cdef class Jobs(dict): >>> # changes >>> db_conn.commit() """ - cdef: - Job job = changes - JobSearchFilter jfilter - Connection conn = db_connection + JobFilter cond + Connection conn SlurmList response SlurmListItem response_ptr list out = [] - conn = Connection.open() if not conn else conn - if not conn.is_open: - raise ValueError("Database connection is not open") - - if isinstance(search_filter, Jobs): - job_ids = list(search_filter.keys()) - jfilter = JobSearchFilter(ids=job_ids) + # Prepare SQL Filter + if isinstance(db_filter, Jobs): + job_ids = [job.id for job in db_filter] + cond = JobFilter(ids=job_ids) else: - jfilter = search_filter + cond = db_filter + cond._create() + + # Setup DB Conn + conn = _open_conn_or_error(db_connection) - jfilter._create() + # Modify Jobs, get the result + # This returns a List of char* with the Jobs ids that were + # modified response = SlurmList.wrap( - slurmdb_job_modify(conn.ptr, jfilter.ptr, job.ptr)) + slurmdb_job_modify(conn.ptr, cond.ptr, changes.ptr)) if not response.is_null and response.cnt: for response_ptr in response: @@ -391,9 +444,10 @@ cdef class Job: def __cinit__(self): self.ptr = NULL - def __init__(self, job_id=0, **kwargs): + def __init__(self, job_id=0, cluster=LOCAL_CLUSTER, **kwargs): self._alloc_impl() self.ptr.jobid = int(job_id) + cstr.fmalloc(&self.ptr.cluster, cluster) for k, v in kwargs.items(): setattr(self, k, v) @@ -417,7 +471,7 @@ cdef class Job: return wrap @staticmethod - def load(job_id, with_script=False, with_env=False): + def load(job_id, cluster=LOCAL_CLUSTER, with_script=False, with_env=False): """Load the information for a specific Job from the Database.
Args: @@ -444,13 +498,15 @@ cdef class Job: >>> print(db_job.script) """ - jfilter = JobSearchFilter(ids=[int(job_id)], - with_script=with_script, with_env=with_env) + jfilter = JobFilter(ids=[int(job_id)], clusters=[cluster], + with_script=with_script, with_env=with_env) jobs = Jobs.load(jfilter) - if not jobs or job_id not in jobs: - raise RPCError(msg=f"Job {job_id} does not exist") + if not jobs: + raise RPCError(msg=f"Job {job_id} does not exist on " + f"Cluster {cluster}") - return jobs[job_id] + # TODO: There might be multiple entries when job ids were reset. + return jobs[0] def _create_steps(self): cdef: @@ -503,7 +559,7 @@ cdef class Job: Raises: RPCError: When modifying the Job failed. """ - cdef JobSearchFilter jfilter = JobSearchFilter(ids=[self.id]) + cdef JobFilter jfilter = JobFilter(ids=[self.id]) Jobs.modify(jfilter, changes, db_connection) @property diff --git a/pyslurm/db/qos.pxd b/pyslurm/db/qos.pxd index b2b0bcf9..9cb3df86 100644 --- a/pyslurm/db/qos.pxd +++ b/pyslurm/db/qos.pxd @@ -30,6 +30,7 @@ from pyslurm.slurm cimport ( slurmdb_destroy_qos_cond, slurmdb_qos_get, slurm_preempt_mode_num, + List, try_xmalloc, ) from pyslurm.db.util cimport ( @@ -40,14 +41,14 @@ from pyslurm.db.util cimport ( from pyslurm.db.connection cimport Connection from pyslurm.utils cimport cstr +cdef _set_qos_list(List *in_list, vals, QualitiesOfService data) -cdef class QualitiesOfService(dict): - cdef: - SlurmList info - Connection db_conn +cdef class QualitiesOfService(list): + pass -cdef class QualityOfServiceSearchFilter: + +cdef class QualityOfServiceFilter: cdef slurmdb_qos_cond_t *ptr cdef public: diff --git a/pyslurm/db/qos.pyx b/pyslurm/db/qos.pyx index 2851587e..a01ef9b0 100644 --- a/pyslurm/db/qos.pyx +++ b/pyslurm/db/qos.pyx @@ -23,46 +23,72 @@ # cython: language_level=3 from pyslurm.core.error import RPCError -from pyslurm.utils.helpers import instance_to_dict +from pyslurm.utils.helpers import instance_to_dict, collection_to_dict_global +from 
pyslurm.db.connection import _open_conn_or_error -cdef class QualitiesOfService(dict): +cdef class QualitiesOfService(list): def __init__(self): pass + def as_dict(self, recursive=False, name_is_key=True): + """Convert the collection data to a dict. + + Args: + recursive (bool, optional): + By default, the objects will not be converted to a dict. If + this is set to `True`, then additionally all objects are + converted to dicts. + name_is_key (bool, optional): + By default, the keys in this dict are the names of each QoS. + If this is set to `False`, then the unique ID of the QoS will + be used as dict keys. + + Returns: + (dict): Collection as a dict. + """ + identifier = QualityOfService.name + if not name_is_key: + identifier = QualityOfService.id + + return collection_to_dict_global(self, identifier=identifier, + recursive=recursive) + @staticmethod - def load(search_filter=None, name_is_key=True, db_connection=None): + def load(QualityOfServiceFilter db_filter=None, + Connection db_connection=None): cdef: - QualitiesOfService qos_dict = QualitiesOfService() + QualitiesOfService out = QualitiesOfService() QualityOfService qos - QualityOfServiceSearchFilter cond + QualityOfServiceFilter cond = db_filter + SlurmList qos_data SlurmListItem qos_ptr - Connection conn = db_connection - - if search_filter: - cond = search_filter - else: - cond = QualityOfServiceSearchFilter() + Connection conn + # Prepare SQL Filter + if not db_filter: + cond = QualityOfServiceFilter() cond._create() - qos_dict.db_conn = Connection.open() if not conn else conn - qos_dict.info = SlurmList.wrap(slurmdb_qos_get(qos_dict.db_conn.ptr, - cond.ptr)) - if qos_dict.info.is_null: + + # Setup DB Conn + conn = _open_conn_or_error(db_connection) + + # Fetch QoS Data + qos_data = SlurmList.wrap(slurmdb_qos_get(conn.ptr, cond.ptr)) + + if qos_data.is_null: raise RPCError(msg="Failed to get QoS data from slurmdbd") - for qos_ptr in SlurmList.iter_and_pop(qos_dict.info): + # Setup QOS objects + for 
qos_ptr in SlurmList.iter_and_pop(qos_data): qos = QualityOfService.from_ptr(qos_ptr.data) - if name_is_key: - qos_dict[qos.name] = qos - else: - qos_dict[qos.id] = qos + out.append(qos) - return qos_dict + return out -cdef class QualityOfServiceSearchFilter: +cdef class QualityOfServiceFilter: def __cinit__(self): self.ptr = NULL @@ -168,12 +194,12 @@ cdef class QualityOfService: RPCError: If requesting the information from the database was not sucessful. """ - qfilter = QualityOfServiceSearchFilter(names=[name]) + qfilter = QualityOfServiceFilter(names=[name]) qos_data = QualitiesOfService.load(qfilter) - if not qos_data or name not in qos_data: + if not qos_data: raise RPCError(msg=f"QualityOfService {name} does not exist") - return qos_data[name] + return qos_data[0] @property def name(self): @@ -190,3 +216,24 @@ cdef class QualityOfService: @property def id(self): return self.ptr.id + + +def _qos_names_to_ids(qos_list, QualitiesOfService data): + cdef list out = [] + if not qos_list: + return None + + return [_validate_qos_single(qid, data) for qid in qos_list] + + +def _validate_qos_single(qid, QualitiesOfService data): + for item in data: + if qid == item.id or qid == item.name: + return item.id + + raise ValueError(f"Invalid QOS specified: {qid}") + + +cdef _set_qos_list(List *in_list, vals, QualitiesOfService data): + qos_ids = _qos_names_to_ids(vals, data) + make_char_list(in_list, qos_ids) diff --git a/pyslurm/db/step.pyx b/pyslurm/db/step.pyx index 22a46fa8..fa4ab8bb 100644 --- a/pyslurm/db/step.pyx +++ b/pyslurm/db/step.pyx @@ -32,9 +32,9 @@ from pyslurm.utils.helpers import ( uid_to_name, instance_to_dict, _get_exit_code, + humanize_step_id, ) from pyslurm.core.job.util import cpu_freq_int_to_str -from pyslurm.core.job.step import humanize_step_id cdef class JobStep: diff --git a/pyslurm/db/tres.pxd b/pyslurm/db/tres.pxd index 40d28799..41ed1b4d 100644 --- a/pyslurm/db/tres.pxd +++ b/pyslurm/db/tres.pxd @@ -25,18 +25,59 @@ from pyslurm.utils cimport 
cstr from libc.stdint cimport uint64_t from pyslurm.slurm cimport ( slurmdb_tres_rec_t, + slurmdb_tres_cond_t, + slurmdb_destroy_tres_cond, + slurmdb_init_tres_cond, slurmdb_destroy_tres_rec, slurmdb_find_tres_count_in_string, + slurmdb_tres_get, try_xmalloc, ) +from pyslurm.db.util cimport ( + SlurmList, + SlurmListItem, +) +from pyslurm.db.connection cimport Connection + +cdef find_tres_count(char *tres_str, typ, on_noval=*, on_inf=*) +cdef find_tres_limit(char *tres_str, typ) +cdef merge_tres_str(char **tres_str, typ, val) +cdef _tres_ids_to_names(char *tres_str, dict tres_data) +cdef _set_tres_limits(char **dest, TrackableResourceLimits src, + TrackableResources tres_data) + +cdef class TrackableResourceLimits: + + cdef public: + cpu + mem + energy + node + billing + fs + vmem + pages + gres + license + + @staticmethod + cdef from_ids(char *tres_id_str, dict tres_data) -cdef class TrackableResources(dict): + +cdef class TrackableResourceFilter: + cdef slurmdb_tres_cond_t *ptr + + +cdef class TrackableResources(list): cdef public raw_str @staticmethod cdef TrackableResources from_str(char *tres_str) + @staticmethod + cdef find_count_in_str(char *tres_str, typ, on_noval=*, on_inf=*) + cdef class TrackableResource: cdef slurmdb_tres_rec_t *ptr diff --git a/pyslurm/db/tres.pyx b/pyslurm/db/tres.pyx index f4e84130..df93dda0 100644 --- a/pyslurm/db/tres.pyx +++ b/pyslurm/db/tres.pyx @@ -23,13 +23,175 @@ # cython: language_level=3 from pyslurm.utils.uint import * +from pyslurm.constants import UNLIMITED +from pyslurm.core.error import RPCError +from pyslurm.utils.helpers import instance_to_dict, collection_to_dict_global +from pyslurm.utils import cstr +from pyslurm.db.connection import _open_conn_or_error +import json -cdef class TrackableResources(dict): +TRES_TYPE_DELIM = "/" + + +cdef class TrackableResourceLimits: + + def __init__(self, **kwargs): + self.fs = {} + self.gres = {} + self.license = {} + + for k, v in kwargs.items(): + if TRES_TYPE_DELIM in k: + typ, 
name = self._unflatten_tres(k) + cur_val = getattr(self, typ) + + if not isinstance(cur_val, dict): + raise ValueError(f"TRES Type {typ} cannot have a name " + f"({name}). Invalid Value: {typ}/{name}") + + cur_val.update({name : int(v)}) + setattr(self, typ, cur_val) + else: + setattr(self, k, v) + + @staticmethod + cdef from_ids(char *tres_id_str, dict tres_data): + tres_list = _tres_ids_to_names(tres_id_str, tres_data) + if not tres_list: + return None + + cdef TrackableResourceLimits out = TrackableResourceLimits() + + for tres in tres_list: + typ, name, cnt = tres + cur_val = getattr(out, typ, slurm.NO_VAL64) + if cur_val != slurm.NO_VAL64: + if isinstance(cur_val, dict): + cur_val.update({name : cnt}) + setattr(out, typ, cur_val) + else: + setattr(out, typ, cnt) + + return out + + def _validate(self, TrackableResources tres_data): + id_dict = _tres_names_to_ids(self.as_dict(flatten_limits=True), + tres_data) + return id_dict + + def _unflatten_tres(self, type_and_name): + typ, name = type_and_name.split(TRES_TYPE_DELIM, 1) + return typ, name + + def _flatten_tres(self, typ, vals): + cdef dict out = {} + for name, cnt in vals.items(): + out[f"{typ}{TRES_TYPE_DELIM}{name}"] = cnt + + return out + + def as_dict(self, flatten_limits=False): + cdef dict inst_dict = instance_to_dict(self) + + if flatten_limits: + vals = inst_dict.pop("fs") + inst_dict.update(self._flatten_tres("fs", vals)) + + vals = inst_dict.pop("license") + inst_dict.update(self._flatten_tres("license", vals)) + + vals = inst_dict.pop("gres") + inst_dict.update(self._flatten_tres("gres", vals)) + + return inst_dict + + +cdef class TrackableResourceFilter: + + def __cinit__(self): + self.ptr = NULL + + def __init__(self, **kwargs): + for k, v in kwargs.items(): + setattr(self, k, v) + + def __dealloc__(self): + self._dealloc() + + def _dealloc(self): + slurmdb_destroy_tres_cond(self.ptr) + self.ptr = NULL + + def _alloc(self): + self._dealloc() + self.ptr = try_xmalloc(sizeof(slurmdb_tres_cond_t)) 
+ if not self.ptr: + raise MemoryError("xmalloc failed for slurmdb_tres_cond_t") + slurmdb_init_tres_cond(self.ptr, 0) + + def _create(self): + self._alloc() + + +cdef class TrackableResources(list): def __init__(self): pass + def as_dict(self, recursive=False, name_is_key=True): + """Convert the collection data to a dict. + + Args: + recursive (bool, optional): + By default, the objects will not be converted to a dict. If + this is set to `True`, then additionally all objects are + converted to dicts. + name_is_key (bool, optional): + By default, the keys in this dict are the names of each TRES. + If this is set to `False`, then the unique ID of the TRES will + be used as dict keys. + + Returns: + (dict): Collection as a dict. + """ + identifier = TrackableResource.type_and_name + if not name_is_key: + identifier = TrackableResource.id + + return collection_to_dict_global(self, identifier=identifier, + recursive=recursive) + + @staticmethod + def load(Connection db_connection=None): + cdef: + TrackableResources out = TrackableResources() + TrackableResource tres + Connection conn + SlurmList tres_data + SlurmListItem tres_ptr + TrackableResourceFilter db_filter = TrackableResourceFilter() + + # Prepare SQL Filter + db_filter._create() + + # Setup DB Conn + conn = _open_conn_or_error(db_connection) + + # Fetch TRES data + tres_data = SlurmList.wrap(slurmdb_tres_get(conn.ptr, db_filter.ptr)) + + if tres_data.is_null: + raise RPCError(msg="Failed to get TRES data from slurmdbd") + + # Setup TRES objects + for tres_ptr in SlurmList.iter_and_pop(tres_data): + tres = TrackableResource.from_ptr( + tres_ptr.data) + out.append(tres) + + return out + @staticmethod cdef TrackableResources from_str(char *tres_str): cdef: @@ -51,16 +213,8 @@ cdef class TrackableResources(dict): return tres @staticmethod - def find_count_in_str(tres_str, typ): - if not tres_str: - return 0 - - cdef uint64_t tmp - tmp = slurmdb_find_tres_count_in_string(tres_str, typ) - if tmp == 
slurm.INFINITE64 or tmp == slurm.NO_VAL64: - return 0 - else: - return tmp + cdef find_count_in_str(char *tres_str, typ, on_noval=0, on_inf=0): + return find_tres_count(tres_str, typ, on_noval, on_inf) cdef class TrackableResource: @@ -92,6 +246,9 @@ cdef class TrackableResource: wrap.ptr = in_ptr return wrap + def as_dict(self): + return instance_to_dict(self) + @property def id(self): return self.ptr.id @@ -104,9 +261,94 @@ cdef class TrackableResource: def type(self): return cstr.to_unicode(self.ptr.type) + @property + def type_and_name(self): + type_and_name = self.type + if self.name: + type_and_name = f"{type_and_name}{TRES_TYPE_DELIM}{self.name}" + + return type_and_name + @property def count(self): return u64_parse(self.ptr.count) # rec_count # alloc_secs + + +cdef find_tres_count(char *tres_str, typ, on_noval=0, on_inf=0): + if not tres_str: + return on_noval + + cdef uint64_t tmp + tmp = slurmdb_find_tres_count_in_string(tres_str, typ) + if tmp == slurm.INFINITE64: + return on_inf + elif tmp == slurm.NO_VAL64: + return on_noval + else: + return tmp + + +cdef find_tres_limit(char *tres_str, typ): + return find_tres_count(tres_str, typ, on_noval=None, on_inf=UNLIMITED) + + +cdef merge_tres_str(char **tres_str, typ, val): + cdef uint64_t _val = u64(dehumanize(val)) + + current = cstr.to_dict(tres_str[0]) + if _val == slurm.NO_VAL64: + current.pop(typ, None) + else: + current.update({typ : _val}) + + cstr.from_dict(tres_str, current) + + +cdef _tres_ids_to_names(char *tres_str, dict tres_data): + if not tres_str: + return None + + cdef: + dict tdict = cstr.to_dict(tres_str) + list out = [] + + if not tres_data: + return None + + for tid, cnt in tdict.items(): + if isinstance(tid, str) and tid.isdigit(): + _tid = int(tid) + if _tid in tres_data: + out.append( + (tres_data[_tid].type, tres_data[_tid].name, int(cnt)) + ) + + return out + + +def _tres_names_to_ids(dict tres_dict, TrackableResources tres_data): + cdef dict out = {} + if not tres_dict: + return out 
+ + for tid, cnt in tres_dict.items(): + real_id = _validate_tres_single(tid, tres_data) + out[real_id] = cnt + + return out + + +def _validate_tres_single(tid, TrackableResources tres_data): + for tres in tres_data: + if tid == tres.id or tid == tres.type_and_name: + return tres.id + + raise ValueError(f"Invalid TRES specified: {tid}") + + +cdef _set_tres_limits(char **dest, TrackableResourceLimits src, + TrackableResources tres_data): + cstr.from_dict(dest, src._validate(tres_data)) diff --git a/pyslurm/db/util.pxd b/pyslurm/db/util.pxd index 2e9498a6..01951de8 100644 --- a/pyslurm/db/util.pxd +++ b/pyslurm/db/util.pxd @@ -39,6 +39,7 @@ from pyslurm.slurm cimport ( cdef slurm_list_to_pylist(List in_list) cdef make_char_list(List *in_list, vals) +cdef qos_list_to_pylist(List in_list, qos_data) cdef class SlurmListItem: diff --git a/pyslurm/db/util.pyx b/pyslurm/db/util.pyx index 2560c4b0..672886c2 100644 --- a/pyslurm/db/util.pyx +++ b/pyslurm/db/util.pyx @@ -43,6 +43,15 @@ cdef slurm_list_to_pylist(List in_list): return SlurmList.wrap(in_list, owned=False).to_pylist() +cdef qos_list_to_pylist(List in_list, qos_data): + if not in_list: + return [] + + cdef list qos_nums = SlurmList.wrap(in_list, owned=False).to_pylist() + return [qos.name for qos_id, qos in qos_data.items() + if qos_id in qos_nums] + + cdef class SlurmListItem: def __cinit__(self): diff --git a/pyslurm/slurm/extra.pxi b/pyslurm/slurm/extra.pxi index fb922ac5..3557b0b9 100644 --- a/pyslurm/slurm/extra.pxi +++ b/pyslurm/slurm/extra.pxi @@ -165,6 +165,9 @@ ctypedef enum tres_types_t: # Global Environment cdef extern char **environ +# Local slurm config +cdef extern slurm_conf_t slurm_conf + # # Slurm Memory routines # We simply use the macros from xmalloc.h - more convenient @@ -272,6 +275,8 @@ cdef extern char *slurm_hostlist_deranged_string_malloc(hostlist_t hl) cdef extern void slurmdb_job_cond_def_start_end(slurmdb_job_cond_t *job_cond) cdef extern uint64_t slurmdb_find_tres_count_in_string(char 
*tres_str_in, int id) cdef extern slurmdb_job_rec_t *slurmdb_create_job_rec() +cdef extern void slurmdb_init_assoc_rec(slurmdb_assoc_rec_t *assoc, bool free_it) +cdef extern void slurmdb_init_tres_cond(slurmdb_tres_cond_t *tres, bool free_it) # # Slurm Partition functions diff --git a/pyslurm/utils/cstr.pyx b/pyslurm/utils/cstr.pyx index 489d80e8..13795544 100644 --- a/pyslurm/utils/cstr.pyx +++ b/pyslurm/utils/cstr.pyx @@ -186,7 +186,7 @@ def dict_to_str(vals, prepend=None, delim1=",", delim2="="): tmp_dict = validate_str_key_value_format(vals, delim1, delim2) for k, v in tmp_dict.items(): - if ((delim1 in k or delim2 in k) or + if ((delim1 in str(k) or delim2 in str(k)) or delim1 in str(v) or delim2 in str(v)): raise ValueError( f"Key or Value cannot contain either {delim1} or {delim2}. " diff --git a/pyslurm/utils/helpers.pyx b/pyslurm/utils/helpers.pyx index fcfe9965..fb1d2201 100644 --- a/pyslurm/utils/helpers.pyx +++ b/pyslurm/utils/helpers.pyx @@ -341,6 +341,50 @@ def instance_to_dict(inst): return out +def collection_to_dict(collection, identifier, recursive=False, group_id=None): + cdef dict out = {} + + for item in collection: + cluster = item.cluster + if cluster not in out: + out[cluster] = {} + + _id = identifier.__get__(item) + data = item if not recursive else item.as_dict() + + if group_id: + grp_id = group_id.__get__(item) + if grp_id not in out[cluster]: + out[cluster][grp_id] = {} + out[cluster][grp_id].update({_id: data}) + else: + out[cluster][_id] = data + + return out + + +def collection_to_dict_global(collection, identifier, recursive=False): + cdef dict out = {} + for item in collection: + _id = identifier.__get__(item) + out[_id] = item if not recursive else item.as_dict() + return out + + +def group_collection_by_cluster(collection): + cdef dict out = {} + collection_type = type(collection) + + for item in collection: + cluster = item.cluster + if cluster not in out: + out[cluster] = collection_type() + + out[cluster].append(item) + + 
return out + + def _sum_prop(obj, name, startval=0): val = startval for n in obj.values(): @@ -362,3 +406,29 @@ def _get_exit_code(exit_code): exit_state -= 128 return exit_state, sig + + +def humanize_step_id(sid): + if sid == slurm.SLURM_BATCH_SCRIPT: + return "batch" + elif sid == slurm.SLURM_EXTERN_CONT: + return "extern" + elif sid == slurm.SLURM_INTERACTIVE_STEP: + return "interactive" + elif sid == slurm.SLURM_PENDING_STEP: + return "pending" + else: + return sid + + +def dehumanize_step_id(sid): + if sid == "batch": + return slurm.SLURM_BATCH_SCRIPT + elif sid == "extern": + return slurm.SLURM_EXTERN_CONT + elif sid == "interactive": + return slurm.SLURM_INTERACTIVE_STEP + elif sid == "pending": + return slurm.SLURM_PENDING_STEP + else: + return int(sid) diff --git a/tests/integration/test_db_job.py b/tests/integration/test_db_job.py index 36005935..571ec0d2 100644 --- a/tests/integration/test_db_job.py +++ b/tests/integration/test_db_job.py @@ -42,7 +42,7 @@ def test_load_single(submit_job): assert db_job.id == job.id with pytest.raises(pyslurm.RPCError): - pyslurm.db.Job.load(1000) + pyslurm.db.Job.load(0) def test_parse_all(submit_job): @@ -59,7 +59,7 @@ def test_modify(submit_job): job = submit_job() util.wait(5) - jfilter = pyslurm.db.JobSearchFilter(ids=[job.id]) + jfilter = pyslurm.db.JobFilter(ids=[job.id]) changes = pyslurm.db.Job(comment="test comment") pyslurm.db.Jobs.modify(jfilter, changes) @@ -72,7 +72,7 @@ def test_modify_with_existing_conn(submit_job): util.wait(5) conn = pyslurm.db.Connection.open() - jfilter = pyslurm.db.JobSearchFilter(ids=[job.id]) + jfilter = pyslurm.db.JobFilter(ids=[job.id]) changes = pyslurm.db.Job(comment="test comment") pyslurm.db.Jobs.modify(jfilter, changes, conn) diff --git a/tests/integration/test_db_qos.py b/tests/integration/test_db_qos.py index 5bbd69e4..11d9e870 100644 --- a/tests/integration/test_db_qos.py +++ b/tests/integration/test_db_qos.py @@ -50,6 +50,6 @@ def test_load_all(): def 
test_load_with_filter_name(): - qfilter = pyslurm.db.QualityOfServiceSearchFilter(names=["non_existent"]) + qfilter = pyslurm.db.QualityOfServiceFilter(names=["non_existent"]) qos = pyslurm.db.QualitiesOfService.load(qfilter) assert not qos diff --git a/tests/integration/test_job.py b/tests/integration/test_job.py index 15c4bdef..cef42daf 100644 --- a/tests/integration/test_job.py +++ b/tests/integration/test_job.py @@ -150,7 +150,7 @@ def test_get_job_queue(submit_job): # Submit 10 jobs, gather the job_ids in a list job_list = [submit_job() for i in range(10)] - jobs = Jobs.load() + jobs = Jobs.load().as_dict() for job in job_list: # Check to see if all the Jobs we submitted exist assert job.id in jobs diff --git a/tests/integration/test_job_steps.py b/tests/integration/test_job_steps.py index bd17a188..b24409f5 100644 --- a/tests/integration/test_job_steps.py +++ b/tests/integration/test_job_steps.py @@ -102,9 +102,9 @@ def test_collection(submit_job): job = submit_job(script=create_job_script_multi_step()) time.sleep(util.WAIT_SECS_SLURMCTLD) - steps = JobSteps.load(job) + steps = JobSteps.load(job).as_dict() - assert steps != {} + assert steps # We have 3 Steps: batch, 0 and 1 assert len(steps) == 3 assert ("batch" in steps and @@ -116,7 +116,7 @@ def test_cancel(submit_job): job = submit_job(script=create_job_script_multi_step()) time.sleep(util.WAIT_SECS_SLURMCTLD) - steps = JobSteps.load(job) + steps = JobSteps.load(job).as_dict() assert len(steps) == 3 assert ("batch" in steps and 0 in steps and @@ -125,7 +125,7 @@ def test_cancel(submit_job): steps[0].cancel() time.sleep(util.WAIT_SECS_SLURMCTLD) - steps = JobSteps.load(job) + steps = JobSteps.load(job).as_dict() assert len(steps) == 2 assert ("batch" in steps and 1 in steps) diff --git a/tests/integration/test_node.py b/tests/integration/test_node.py index fb6f5197..49a69db2 100644 --- a/tests/integration/test_node.py +++ b/tests/integration/test_node.py @@ -29,7 +29,7 @@ def test_load(): - name = 
Nodes.load().as_list()[0].name + name = Nodes.load()[0].name # Now load the node info node = Node.load(name) @@ -56,7 +56,7 @@ def test_create(): def test_modify(): - node = Node(Nodes.load().as_list()[0].name) + node = Node(Nodes.load()[0].name) node.modify(Node(weight=10000)) assert Node.load(node.name).weight == 10000 @@ -69,4 +69,4 @@ def test_modify(): def test_parse_all(): - Node.load(Nodes.load().as_list()[0].name).as_dict() + Node.load(Nodes.load()[0].name).as_dict() diff --git a/tests/integration/test_partition.py b/tests/integration/test_partition.py index fcfcf4af..8d7a4de4 100644 --- a/tests/integration/test_partition.py +++ b/tests/integration/test_partition.py @@ -28,7 +28,7 @@ def test_load(): - part = Partitions.load().as_list()[0] + part = Partitions.load()[0] assert part.name assert part.state @@ -49,7 +49,7 @@ def test_create_delete(): def test_modify(): - part = Partitions.load().as_list()[0] + part = Partitions.load()[0] part.modify(Partition(default_time=120)) assert Partition.load(part.name).default_time == 120 @@ -68,22 +68,23 @@ def test_modify(): def test_parse_all(): - Partitions.load().as_list()[0].as_dict() + Partitions.load()[0].as_dict() def test_reload(): _partnames = [util.randstr() for i in range(3)] _tmp_parts = Partitions(_partnames) - for part in _tmp_parts.values(): + for part in _tmp_parts: part.create() all_parts = Partitions.load() assert len(all_parts) >= 3 my_parts = Partitions(_partnames[1:]).reload() + print(my_parts) assert len(my_parts) == 2 - for part in my_parts.as_list(): + for part in my_parts: assert part.state != "UNKNOWN" - for part in _tmp_parts.values(): + for part in _tmp_parts: part.delete() diff --git a/tests/unit/test_common.py b/tests/unit/test_common.py index 48f4fecf..1598d191 100644 --- a/tests/unit/test_common.py +++ b/tests/unit/test_common.py @@ -54,6 +54,10 @@ cpubind_to_num, nodelist_from_range_str, nodelist_to_range_str, + instance_to_dict, + collection_to_dict, + collection_to_dict_global, + 
group_collection_by_cluster, _sum_prop, ) from pyslurm.utils import cstr @@ -426,4 +430,59 @@ def cpus(self): assert _sum_prop(object_dict, TestObject.memory) == expected expected = 0 - assert _sum_prop(object_dict, TestObject.cpus) == 0 + assert _sum_prop(object_dict, TestObject.cpus) == expected + + def test_collection_to_dict(self): + class TestObject: + + def __init__(self, _id, _grp_id, cluster): + self._id = _id + self._grp_id = _grp_id + self.cluster = cluster + + @property + def id(self): + return self._id + + @property + def group_id(self): + return self._grp_id + + def as_dict(self): + return instance_to_dict(self) + + class TestCollection(list): + + def __init__(self, data): + super().__init__() + self.extend(data) + + OFFSET = 100 + RANGE = 10 + + data = [TestObject(x, x+OFFSET, "TestCluster") for x in range(RANGE)] + collection = TestCollection(data) + + coldict = collection_to_dict(collection, identifier=TestObject.id) + coldict = coldict.get("TestCluster", {}) + + assert len(coldict) == RANGE + for i in range(RANGE): + assert i in coldict + assert isinstance(coldict[i], TestObject) + + coldict = collection_to_dict(collection, identifier=TestObject.id, + group_id=TestObject.group_id) + coldict = coldict.get("TestCluster", {}) + + assert len(coldict) == RANGE + for i in range(RANGE): + assert i+OFFSET in coldict + assert i in coldict[i+OFFSET] + + coldict = collection_to_dict(collection, identifier=TestObject.id, + recursive=True) + coldict = coldict.get("TestCluster", {}) + + for item in coldict.values(): + assert isinstance(item, dict) diff --git a/tests/unit/test_db_job.py b/tests/unit/test_db_job.py index 9391f04a..7b77671f 100644 --- a/tests/unit/test_db_job.py +++ b/tests/unit/test_db_job.py @@ -25,7 +25,7 @@ def test_filter(): - job_filter = pyslurm.db.JobSearchFilter() + job_filter = pyslurm.db.JobFilter() job_filter.clusters = ["test1"] job_filter.partitions = ["partition1", "partition2"] @@ -45,6 +45,7 @@ def test_filter(): def 
test_create_collection(): jobs = pyslurm.db.Jobs("101,102") assert len(jobs) == 2 + jobs = jobs.as_dict() assert 101 in jobs assert 102 in jobs assert jobs[101].id == 101 @@ -52,6 +53,7 @@ def test_create_collection(): jobs = pyslurm.db.Jobs([101, 102]) assert len(jobs) == 2 + jobs = jobs.as_dict() assert 101 in jobs assert 102 in jobs assert jobs[101].id == 101 @@ -64,6 +66,7 @@ def test_create_collection(): } ) assert len(jobs) == 2 + jobs = jobs.as_dict() assert 101 in jobs assert 102 in jobs assert jobs[101].id == 101 diff --git a/tests/unit/test_db_qos.py b/tests/unit/test_db_qos.py index acf12fea..0d2fd538 100644 --- a/tests/unit/test_db_qos.py +++ b/tests/unit/test_db_qos.py @@ -25,7 +25,7 @@ def test_search_filter(): - qos_filter = pyslurm.db.QualityOfServiceSearchFilter() + qos_filter = pyslurm.db.QualityOfServiceFilter() qos_filter._create() qos_filter.ids = [1, 2] diff --git a/tests/unit/test_job_steps.py b/tests/unit/test_job_steps.py index c222ef34..fcd0d012 100644 --- a/tests/unit/test_job_steps.py +++ b/tests/unit/test_job_steps.py @@ -22,7 +22,7 @@ import pytest from pyslurm import JobStep, Job -from pyslurm.core.job.step import ( +from pyslurm.utils.helpers import ( humanize_step_id, dehumanize_step_id, ) diff --git a/tests/unit/test_node.py b/tests/unit/test_node.py index f2b5594a..755e85d9 100644 --- a/tests/unit/test_node.py +++ b/tests/unit/test_node.py @@ -36,14 +36,14 @@ def test_parse_all(): def test_create_nodes_collection(): - nodes = Nodes("node1,node2") + nodes = Nodes("node1,node2").as_dict() assert len(nodes) == 2 assert "node1" in nodes assert "node2" in nodes assert nodes["node1"].name == "node1" assert nodes["node2"].name == "node2" - nodes = Nodes(["node1", "node2"]) + nodes = Nodes(["node1", "node2"]).as_dict() assert len(nodes) == 2 assert "node1" in nodes assert "node2" in nodes @@ -55,7 +55,7 @@ def test_create_nodes_collection(): "node1": Node("node1"), "node2": Node("node2"), } - ) + ).as_dict() assert len(nodes) == 2 assert 
"node1" in nodes assert "node2" in nodes diff --git a/tests/unit/test_partition.py b/tests/unit/test_partition.py index 141a6e51..89403ae2 100644 --- a/tests/unit/test_partition.py +++ b/tests/unit/test_partition.py @@ -32,14 +32,14 @@ def test_create_instance(): def test_create_collection(): - parts = Partitions("part1,part2") + parts = Partitions("part1,part2").as_dict() assert len(parts) == 2 assert "part1" in parts assert "part2" in parts assert parts["part1"].name == "part1" assert parts["part2"].name == "part2" - parts = Partitions(["part1", "part2"]) + parts = Partitions(["part1", "part2"]).as_dict() assert len(parts) == 2 assert "part1" in parts assert "part2" in parts @@ -51,7 +51,7 @@ def test_create_collection(): "part1": Partition("part1"), "part2": Partition("part2"), } - ) + ).as_dict() assert len(parts) == 2 assert "part1" in parts assert "part2" in parts