From 80388dc7c0e707934edbb765b46587da276389c4 Mon Sep 17 00:00:00 2001 From: Ryunosuke O'Neil Date: Tue, 17 Dec 2024 15:20:06 +0100 Subject: [PATCH] Moved non-DB related job submission logic out of diracx-db to diracx-routers. --- diracx-db/src/diracx/db/sql/job/db.py | 229 +++++------------- .../src/diracx/routers/jobs/submission.py | 109 ++++++++- 2 files changed, 161 insertions(+), 177 deletions(-) diff --git a/diracx-db/src/diracx/db/sql/job/db.py b/diracx-db/src/diracx/db/sql/job/db.py index 31f1ea98..4db8af65 100644 --- a/diracx-db/src/diracx/db/sql/job/db.py +++ b/diracx-db/src/diracx/db/sql/job/db.py @@ -1,7 +1,5 @@ from __future__ import annotations -from asyncio import TaskGroup -from copy import deepcopy from datetime import datetime, timezone from typing import TYPE_CHECKING, Any @@ -109,13 +107,18 @@ async def search( dict(row._mapping) async for row in (await self.conn.stream(stmt)) ] - async def _insertJob(self, jobData: dict[str, Any]): - stmt = insert(Jobs).values(jobData) - await self.conn.execute(stmt) - - async def _insertInputData(self, job_id: int, lfns: list[str]): - stmt = insert(InputData).values([{"JobID": job_id, "LFN": lfn} for lfn in lfns]) - await self.conn.execute(stmt) + async def insert_input_data(self, lfns: dict[int, list[str]]): + await self.conn.execute( + InputData.__table__.insert(), + [ + { + "JobID": job_id, + "LFN": lfn, + } + for job_id, lfns_ in lfns.items() + for lfn in lfns_ + ], + ) async def setJobAttributes(self, job_id, jobData): """TODO: add myDate and force parameters.""" @@ -124,6 +127,48 @@ async def setJobAttributes(self, job_id, jobData): stmt = update(Jobs).where(Jobs.JobID == job_id).values(jobData) await self.conn.execute(stmt) + async def create_job(self, original_jdl): + """Used to insert a new job with original JDL. Returns inserted job id.""" + from DIRAC.WorkloadManagementSystem.DB.JobDBUtils import compressJDL + + result = await self.conn.execute( + JobJDLs.__table__.insert().values( + JDL="", + JobRequirements="", + OriginalJDL=compressJDL(original_jdl), + ) + ) + return result.lastrowid + + async def insert_job_attributes(self, jobs_to_update: dict[int, dict]): + await self.conn.execute( + Jobs.__table__.insert(), + [ + { + "JobID": job_id, + **attrs, + } + for job_id, attrs in jobs_to_update.items() + ], + ) + + async def update_job_jdls(self, jdls_to_update: dict[int, str]): + """Used to update the JDL, typically just after inserting the original JDL, or rescheduling, for example.""" + from DIRAC.WorkloadManagementSystem.DB.JobDBUtils import compressJDL + + await self.conn.execute( + JobJDLs.__table__.update().where( + JobJDLs.__table__.c.JobID == bindparam("b_JobID") + ), + [ + { + "b_JobID": job_id, + "JDL": compressJDL(jdl), + } + for job_id, jdl in jdls_to_update.items() + ], + ) + async def checkAndPrepareJob( self, jobID, @@ -222,172 +267,6 @@ async def getJobJDLs(self, job_ids, original: bool = False) -> dict[int | str, s if jdl } - async def insert_bulk( - self, - jobs: list[JobSubmissionSpec], - ): - from DIRAC.Core.Utilities.ClassAd.ClassAdLight import ClassAd - from DIRAC.Core.Utilities.ReturnValues import returnValueOrRaise - from DIRAC.WorkloadManagementSystem.DB.JobDBUtils import ( - checkAndAddOwner, - compressJDL, - createJDLWithInitialStatus, - ) - - jobs_to_insert = [] - jdls_to_update = [] - inputdata_to_insert = [] - original_jdls = [] - - # generate the jobIDs first - async with TaskGroup() as tg: - for job in jobs: - original_jdl = deepcopy(job.jdl) - jobManifest = returnValueOrRaise( - checkAndAddOwner(original_jdl, job.owner, job.owner_group) - ) - - # Fix possible lack of brackets - if original_jdl.strip()[0] != "[": - original_jdl = f"[{original_jdl}]" - - original_jdls.append( - ( - original_jdl, - jobManifest, - tg.create_task( - self.conn.execute( - JobJDLs.__table__.insert().values( - JDL="", - JobRequirements="", - OriginalJDL=compressJDL(original_jdl), - ) - ) - ), - ) - ) - - job_ids = [] - - async with TaskGroup() as tg: - for job, (original_jdl, jobManifest_, job_id_task) in zip( - jobs, original_jdls - ): - job_id = job_id_task.result().lastrowid - job_attrs = { - "JobID": job_id, - "LastUpdateTime": datetime.now(tz=timezone.utc), - "SubmissionTime": datetime.now(tz=timezone.utc), - "Owner": job.owner, - "OwnerGroup": job.owner_group, - "VO": job.vo, - } - - jobManifest_.setOption("JobID", job_id) - - # 2.- Check JDL and Prepare DIRAC JDL - jobJDL = jobManifest_.dumpAsJDL() - - # Replace the JobID placeholder if any - if jobJDL.find("%j") != -1: - jobJDL = jobJDL.replace("%j", str(job_id)) - - class_ad_job = ClassAd(jobJDL) - - class_ad_req = ClassAd("[]") - if not class_ad_job.isOK(): - # Rollback the entire transaction - raise ValueError(f"Error in JDL syntax for job JDL: {original_jdl}") - # TODO: check if that is actually true - if class_ad_job.lookupAttribute("Parameters"): - raise NotImplementedError("Parameters in the JDL are not supported") - - # TODO is this even needed? - class_ad_job.insertAttributeInt("JobID", job_id) - - await self.checkAndPrepareJob( - job_id, - class_ad_job, - class_ad_req, - job.owner, - job.owner_group, - job_attrs, - job.vo, - ) - jobJDL = createJDLWithInitialStatus( - class_ad_job, - class_ad_req, - self.jdl2DBParameters, - job_attrs, - job.initial_status, - job.initial_minor_status, - modern=True, - ) - # assert "JobType" in job_attrs, job_attrs - job_ids.append(job_id) - jobs_to_insert.append(job_attrs) - jdls_to_update.append( - { - "b_JobID": job_id, - "JDL": compressJDL(jobJDL), - } - ) - - if class_ad_job.lookupAttribute("InputData"): - inputData = class_ad_job.getListFromExpression("InputData") - inputdata_to_insert += [ - {"JobID": job_id, "LFN": lfn} for lfn in inputData if lfn - ] - - tg.create_task( - self.conn.execute( - JobJDLs.__table__.update().where( - JobJDLs.__table__.c.JobID == bindparam("b_JobID") - ), - jdls_to_update, - ) - ) - tg.create_task( - self.conn.execute( - Jobs.__table__.insert(), - jobs_to_insert, - ) - ) - - if inputdata_to_insert: - tg.create_task( - self.conn.execute( - InputData.__table__.insert(), - inputdata_to_insert, - ) - ) - - return job_ids - - async def insert( - self, - jdl, - owner, - owner_group, - initial_status, - initial_minor_status, - vo, - ): - submitted_job_ids = await self.insert_bulk( - [ - JobSubmissionSpec( - jdl=jdl, - owner=owner, - owner_group=owner_group, - initial_status=initial_status, - initial_minor_status=initial_minor_status, - vo=vo, - ) - ] - ) - - return submitted_job_ids[0] - async def get_job_status(self, job_id: int) -> LimitedJobStatusReturn: try: stmt = select(Jobs.Status, Jobs.MinorStatus, Jobs.ApplicationStatus).where( diff --git a/diracx-routers/src/diracx/routers/jobs/submission.py b/diracx-routers/src/diracx/routers/jobs/submission.py index 853c5684..3e5cc141 100644 --- a/diracx-routers/src/diracx/routers/jobs/submission.py +++ b/diracx-routers/src/diracx/routers/jobs/submission.py @@ -1,6 +1,8 @@ from __future__ import annotations import logging +from asyncio import TaskGroup +from copy import deepcopy from datetime import datetime, timezone from http import HTTPStatus from typing import Annotated @@ -68,6 +70,108 @@ class JobID(BaseModel): } +async def _submit_jobs_jdl(jobs: list[JobSubmissionSpec], job_db: JobDB): + from DIRAC.Core.Utilities.ClassAd.ClassAdLight import ClassAd + from DIRAC.Core.Utilities.ReturnValues import returnValueOrRaise + from DIRAC.WorkloadManagementSystem.DB.JobDBUtils import ( + checkAndAddOwner, + createJDLWithInitialStatus, + ) + + jobs_to_insert = {} + jdls_to_update = {} + inputdata_to_insert = {} + original_jdls = [] + + # generate the jobIDs first + async with TaskGroup() as tg: + for job in jobs: + original_jdl = deepcopy(job.jdl) + jobManifest = returnValueOrRaise( + checkAndAddOwner(original_jdl, job.owner, job.owner_group) + ) + + # Fix possible lack of brackets + if original_jdl.strip()[0] != "[": + original_jdl = f"[{original_jdl}]" + + original_jdls.append( + ( + original_jdl, + jobManifest, + tg.create_task(job_db.create_job(original_jdl)), + ) + ) + + async with TaskGroup() as tg: + for job, (original_jdl, jobManifest_, job_id_task) in zip(jobs, original_jdls): + job_id = job_id_task.result() + job_attrs = { + "JobID": job_id, + "LastUpdateTime": datetime.now(tz=timezone.utc), + "SubmissionTime": datetime.now(tz=timezone.utc), + "Owner": job.owner, + "OwnerGroup": job.owner_group, + "VO": job.vo, + } + + jobManifest_.setOption("JobID", job_id) + + # 2.- Check JDL and Prepare DIRAC JDL + jobJDL = jobManifest_.dumpAsJDL() + + # Replace the JobID placeholder if any + if jobJDL.find("%j") != -1: + jobJDL = jobJDL.replace("%j", str(job_id)) + + class_ad_job = ClassAd(jobJDL) + + class_ad_req = ClassAd("[]") + if not class_ad_job.isOK(): + # Rollback the entire transaction + raise ValueError(f"Error in JDL syntax for job JDL: {original_jdl}") + # TODO: check if that is actually true + if class_ad_job.lookupAttribute("Parameters"): + raise NotImplementedError("Parameters in the JDL are not supported") + + # TODO is this even needed? + class_ad_job.insertAttributeInt("JobID", job_id) + + await job_db.checkAndPrepareJob( + job_id, + class_ad_job, + class_ad_req, + job.owner, + job.owner_group, + job_attrs, + job.vo, + ) + jobJDL = createJDLWithInitialStatus( + class_ad_job, + class_ad_req, + job_db.jdl2DBParameters, + job_attrs, + job.initial_status, + job.initial_minor_status, + modern=True, + ) + + jobs_to_insert[job_id] = job_attrs + jdls_to_update[job_id] = jobJDL + + if class_ad_job.lookupAttribute("InputData"): + inputData = class_ad_job.getListFromExpression("InputData") + inputdata_to_insert[job_id] = [lfn for lfn in inputData if lfn] + + tg.create_task(job_db.update_job_jdls(jdls_to_update)) + tg.create_task(job_db.insert_job_attributes(jobs_to_insert)) + + if inputdata_to_insert: + tg.create_task(job_db.insert_input_data(inputdata_to_insert)) + + return jobs_to_insert.keys() + + @router.post("/jdl") async def submit_bulk_jdl_jobs( job_definitions: Annotated[list[str], Body(openapi_examples=EXAMPLE_JDLS)], @@ -148,7 +252,7 @@ async def submit_bulk_jdl_jobs( initialStatus = JobStatus.RECEIVED initialMinorStatus = "Job accepted" - submitted_job_ids = await job_db.insert_bulk( + submitted_job_ids = await _submit_jobs_jdl( [ JobSubmissionSpec( jdl=jdl, @@ -159,7 +263,8 @@ async def submit_bulk_jdl_jobs( vo=user_info.vo, ) for jdl in jobDescList - ] + ], + job_db=job_db, ) logging.debug(