Skip to content

Commit

Permalink
Moved non-DB related job submission logic out of diracx-db to diracx-routers.
Browse files. Browse the repository at this point in the history.
  • Loading branch information
ryuwd committed Dec 17, 2024
1 parent 4c0346d commit 80388dc
Show file tree
Hide file tree
Showing 2 changed files with 161 additions and 177 deletions.
229 changes: 54 additions & 175 deletions diracx-db/src/diracx/db/sql/job/db.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from __future__ import annotations

from asyncio import TaskGroup
from copy import deepcopy
from datetime import datetime, timezone
from typing import TYPE_CHECKING, Any

Expand Down Expand Up @@ -109,13 +107,18 @@ async def search(
dict(row._mapping) async for row in (await self.conn.stream(stmt))
]

async def _insertJob(self, jobData: dict[str, Any]):
    """Insert a single row into the Jobs table from a column-to-value mapping."""
    await self.conn.execute(insert(Jobs).values(jobData))

async def _insertInputData(self, job_id: int, lfns: list[str]):
    """Insert one InputData row per LFN, all attached to ``job_id``."""
    rows = []
    for lfn in lfns:
        rows.append({"JobID": job_id, "LFN": lfn})
    await self.conn.execute(insert(InputData).values(rows))
async def insert_input_data(self, lfns: dict[int, list[str]]):
    """Bulk-insert input-data LFNs for one or more jobs.

    :param lfns: mapping of job ID to the list of LFNs to register for it.

    Every (job ID, LFN) pair becomes one InputData row. When the mapping is
    empty, or every job maps to an empty list, nothing is sent to the database.
    """
    rows = [
        {"JobID": job_id, "LFN": lfn}
        for job_id, job_lfns in lfns.items()
        for lfn in job_lfns
    ]
    if not rows:
        # An empty parameter list is not an executemany no-op: SQLAlchemy
        # would run a single INSERT without any values instead.
        return

    await self.conn.execute(InputData.__table__.insert(), rows)

async def setJobAttributes(self, job_id, jobData):
"""TODO: add myDate and force parameters."""
Expand All @@ -124,6 +127,48 @@ async def setJobAttributes(self, job_id, jobData):
stmt = update(Jobs).where(Jobs.JobID == job_id).values(jobData)
await self.conn.execute(stmt)

async def create_job(self, original_jdl):
    """Insert a new JobJDLs row for ``original_jdl`` and return its row id.

    The original JDL is stored compressed; the JDL and JobRequirements
    columns are written as empty strings.
    """
    from DIRAC.WorkloadManagementSystem.DB.JobDBUtils import compressJDL

    stmt = JobJDLs.__table__.insert().values(
        JDL="",
        JobRequirements="",
        OriginalJDL=compressJDL(original_jdl),
    )
    inserted = await self.conn.execute(stmt)
    return inserted.lastrowid

async def insert_job_attributes(self, jobs_to_update: dict[int, dict]):
    """Bulk-insert rows into the Jobs table.

    :param jobs_to_update: mapping of job ID to the column values of that job.

    The mapping key is written as ``JobID``; a ``JobID`` entry inside the
    attribute dict, if present, takes precedence because it is unpacked last.
    An empty mapping is a no-op.
    """
    if not jobs_to_update:
        # An empty parameter list would be executed as a single INSERT
        # without values rather than being skipped.
        return

    rows = [
        {"JobID": job_id, **attrs} for job_id, attrs in jobs_to_update.items()
    ]
    await self.conn.execute(Jobs.__table__.insert(), rows)

async def update_job_jdls(self, jdls_to_update: dict[int, str]):
    """Store the (compressed) JDL of several jobs in the JobJDLs table.

    Typically used just after inserting the original JDL, or when
    rescheduling, for example.

    :param jdls_to_update: mapping of job ID to the uncompressed JDL text.

    An empty mapping is a no-op.
    """
    if not jdls_to_update:
        # Without parameter sets the ``b_JobID`` bind parameter would have no
        # value, so there is nothing valid to execute.
        return

    from DIRAC.WorkloadManagementSystem.DB.JobDBUtils import compressJDL

    jdl_table = JobJDLs.__table__
    await self.conn.execute(
        jdl_table.update().where(jdl_table.c.JobID == bindparam("b_JobID")),
        [
            {"b_JobID": job_id, "JDL": compressJDL(jdl)}
            for job_id, jdl in jdls_to_update.items()
        ],
    )

async def checkAndPrepareJob(
self,
jobID,
Expand Down Expand Up @@ -222,172 +267,6 @@ async def getJobJDLs(self, job_ids, original: bool = False) -> dict[int | str, s
if jdl
}

async def insert_bulk(
self,
jobs: list[JobSubmissionSpec],
):
from DIRAC.Core.Utilities.ClassAd.ClassAdLight import ClassAd
from DIRAC.Core.Utilities.ReturnValues import returnValueOrRaise
from DIRAC.WorkloadManagementSystem.DB.JobDBUtils import (
checkAndAddOwner,
compressJDL,
createJDLWithInitialStatus,
)

jobs_to_insert = []
jdls_to_update = []
inputdata_to_insert = []
original_jdls = []

# generate the jobIDs first
async with TaskGroup() as tg:
for job in jobs:
original_jdl = deepcopy(job.jdl)
jobManifest = returnValueOrRaise(
checkAndAddOwner(original_jdl, job.owner, job.owner_group)
)

# Fix possible lack of brackets
if original_jdl.strip()[0] != "[":
original_jdl = f"[{original_jdl}]"

original_jdls.append(
(
original_jdl,
jobManifest,
tg.create_task(
self.conn.execute(
JobJDLs.__table__.insert().values(
JDL="",
JobRequirements="",
OriginalJDL=compressJDL(original_jdl),
)
)
),
)
)

job_ids = []

async with TaskGroup() as tg:
for job, (original_jdl, jobManifest_, job_id_task) in zip(
jobs, original_jdls
):
job_id = job_id_task.result().lastrowid
job_attrs = {
"JobID": job_id,
"LastUpdateTime": datetime.now(tz=timezone.utc),
"SubmissionTime": datetime.now(tz=timezone.utc),
"Owner": job.owner,
"OwnerGroup": job.owner_group,
"VO": job.vo,
}

jobManifest_.setOption("JobID", job_id)

# 2.- Check JDL and Prepare DIRAC JDL
jobJDL = jobManifest_.dumpAsJDL()

# Replace the JobID placeholder if any
if jobJDL.find("%j") != -1:
jobJDL = jobJDL.replace("%j", str(job_id))

class_ad_job = ClassAd(jobJDL)

class_ad_req = ClassAd("[]")
if not class_ad_job.isOK():
# Rollback the entire transaction
raise ValueError(f"Error in JDL syntax for job JDL: {original_jdl}")
# TODO: check if that is actually true
if class_ad_job.lookupAttribute("Parameters"):
raise NotImplementedError("Parameters in the JDL are not supported")

# TODO is this even needed?
class_ad_job.insertAttributeInt("JobID", job_id)

await self.checkAndPrepareJob(
job_id,
class_ad_job,
class_ad_req,
job.owner,
job.owner_group,
job_attrs,
job.vo,
)
jobJDL = createJDLWithInitialStatus(
class_ad_job,
class_ad_req,
self.jdl2DBParameters,
job_attrs,
job.initial_status,
job.initial_minor_status,
modern=True,
)
# assert "JobType" in job_attrs, job_attrs
job_ids.append(job_id)
jobs_to_insert.append(job_attrs)
jdls_to_update.append(
{
"b_JobID": job_id,
"JDL": compressJDL(jobJDL),
}
)

if class_ad_job.lookupAttribute("InputData"):
inputData = class_ad_job.getListFromExpression("InputData")
inputdata_to_insert += [
{"JobID": job_id, "LFN": lfn} for lfn in inputData if lfn
]

tg.create_task(
self.conn.execute(
JobJDLs.__table__.update().where(
JobJDLs.__table__.c.JobID == bindparam("b_JobID")
),
jdls_to_update,
)
)
tg.create_task(
self.conn.execute(
Jobs.__table__.insert(),
jobs_to_insert,
)
)

if inputdata_to_insert:
tg.create_task(
self.conn.execute(
InputData.__table__.insert(),
inputdata_to_insert,
)
)

return job_ids

async def insert(
    self,
    jdl,
    owner,
    owner_group,
    initial_status,
    initial_minor_status,
    vo,
):
    """Submit a single job through ``insert_bulk`` and return its job ID."""
    spec = JobSubmissionSpec(
        jdl=jdl,
        owner=owner,
        owner_group=owner_group,
        initial_status=initial_status,
        initial_minor_status=initial_minor_status,
        vo=vo,
    )
    job_ids = await self.insert_bulk([spec])
    return job_ids[0]

async def get_job_status(self, job_id: int) -> LimitedJobStatusReturn:
try:
stmt = select(Jobs.Status, Jobs.MinorStatus, Jobs.ApplicationStatus).where(
Expand Down
109 changes: 107 additions & 2 deletions diracx-routers/src/diracx/routers/jobs/submission.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from __future__ import annotations

import logging
from asyncio import TaskGroup
from copy import deepcopy
from datetime import datetime, timezone
from http import HTTPStatus
from typing import Annotated
Expand Down Expand Up @@ -68,6 +70,108 @@ class JobID(BaseModel):
}


async def _submit_jobs_jdl(jobs: list[JobSubmissionSpec], job_db: JobDB):
    """Validate and store a batch of JDL job submissions.

    For every job the original JDL is first stored through
    ``job_db.create_job`` to obtain a job ID; the JDL is then validated and
    completed, and finally the job attributes, the final JDLs and the input
    data are written in bulk.

    :param jobs: submission specs (JDL, owner, group, VO, initial status).
    :param job_db: job database used for all reads and writes.
    :return: list of the inserted job IDs, in submission order; empty when
        ``jobs`` is empty.
    :raises ValueError: when a JDL cannot be parsed as a ClassAd.
    :raises NotImplementedError: when a JDL carries a ``Parameters`` attribute.
    """
    if not jobs:
        # Nothing to submit; also avoids calling the bulk DB methods below
        # with empty mappings.
        return []

    from DIRAC.Core.Utilities.ClassAd.ClassAdLight import ClassAd
    from DIRAC.Core.Utilities.ReturnValues import returnValueOrRaise
    from DIRAC.WorkloadManagementSystem.DB.JobDBUtils import (
        checkAndAddOwner,
        createJDLWithInitialStatus,
    )

    jobs_to_insert = {}
    jdls_to_update = {}
    inputdata_to_insert = {}
    original_jdls = []

    # generate the jobIDs first
    # NOTE(review): these tasks run concurrently against job_db; confirm the
    # underlying connection tolerates concurrent statements.
    async with TaskGroup() as tg:
        for job in jobs:
            original_jdl = deepcopy(job.jdl)
            jobManifest = returnValueOrRaise(
                checkAndAddOwner(original_jdl, job.owner, job.owner_group)
            )

            # Fix possible lack of brackets (startswith is safe on a blank JDL)
            if not original_jdl.strip().startswith("["):
                original_jdl = f"[{original_jdl}]"

            original_jdls.append(
                (
                    original_jdl,
                    jobManifest,
                    tg.create_task(job_db.create_job(original_jdl)),
                )
            )

    async with TaskGroup() as tg:
        for job, (original_jdl, jobManifest_, job_id_task) in zip(jobs, original_jdls):
            job_id = job_id_task.result()
            job_attrs = {
                "JobID": job_id,
                "LastUpdateTime": datetime.now(tz=timezone.utc),
                "SubmissionTime": datetime.now(tz=timezone.utc),
                "Owner": job.owner,
                "OwnerGroup": job.owner_group,
                "VO": job.vo,
            }

            jobManifest_.setOption("JobID", job_id)

            # 2.- Check JDL and Prepare DIRAC JDL
            # Replace the JobID placeholder if any
            jobJDL = jobManifest_.dumpAsJDL().replace("%j", str(job_id))

            class_ad_job = ClassAd(jobJDL)

            class_ad_req = ClassAd("[]")
            if not class_ad_job.isOK():
                # Rollback the entire transaction
                raise ValueError(f"Error in JDL syntax for job JDL: {original_jdl}")
            # TODO: check if that is actually true
            if class_ad_job.lookupAttribute("Parameters"):
                raise NotImplementedError("Parameters in the JDL are not supported")

            # TODO is this even needed?
            class_ad_job.insertAttributeInt("JobID", job_id)

            await job_db.checkAndPrepareJob(
                job_id,
                class_ad_job,
                class_ad_req,
                job.owner,
                job.owner_group,
                job_attrs,
                job.vo,
            )
            jobJDL = createJDLWithInitialStatus(
                class_ad_job,
                class_ad_req,
                job_db.jdl2DBParameters,
                job_attrs,
                job.initial_status,
                job.initial_minor_status,
                modern=True,
            )

            jobs_to_insert[job_id] = job_attrs
            jdls_to_update[job_id] = jobJDL

            if class_ad_job.lookupAttribute("InputData"):
                inputData = class_ad_job.getListFromExpression("InputData")
                job_lfns = [lfn for lfn in inputData if lfn]
                # Only record jobs that really have LFNs, so the truthiness
                # check below reflects whether there are rows to insert.
                if job_lfns:
                    inputdata_to_insert[job_id] = job_lfns

        tg.create_task(job_db.update_job_jdls(jdls_to_update))
        tg.create_task(job_db.insert_job_attributes(jobs_to_insert))

        if inputdata_to_insert:
            tg.create_task(job_db.insert_input_data(inputdata_to_insert))

    # Return a real list (indexable, serialisable) rather than a dict view.
    return list(jobs_to_insert)


@router.post("/jdl")
async def submit_bulk_jdl_jobs(
job_definitions: Annotated[list[str], Body(openapi_examples=EXAMPLE_JDLS)],
Expand Down Expand Up @@ -148,7 +252,7 @@ async def submit_bulk_jdl_jobs(
initialStatus = JobStatus.RECEIVED
initialMinorStatus = "Job accepted"

submitted_job_ids = await job_db.insert_bulk(
submitted_job_ids = await _submit_jobs_jdl(
[
JobSubmissionSpec(
jdl=jdl,
Expand All @@ -159,7 +263,8 @@ async def submit_bulk_jdl_jobs(
vo=user_info.vo,
)
for jdl in jobDescList
]
],
job_db=job_db,
)

logging.debug(
Expand Down

0 comments on commit 80388dc

Please sign in to comment.