From fb55adc81be7c25a93a261b06a80ad0976895dd9 Mon Sep 17 00:00:00 2001 From: Ryunosuke O'Neil Date: Thu, 19 Dec 2024 15:05:03 +0100 Subject: [PATCH] Fix job attribute update to account for mismatching columns between rows to be updated --- diracx-db/src/diracx/db/sql/job/db.py | 21 +++++++++++++++------ diracx-db/src/diracx/db/sql/utils/job.py | 4 ++-- 2 files changed, 17 insertions(+), 8 deletions(-) diff --git a/diracx-db/src/diracx/db/sql/job/db.py b/diracx-db/src/diracx/db/sql/job/db.py index 145b4eb6..fb677fe9 100644 --- a/diracx-db/src/diracx/db/sql/job/db.py +++ b/diracx-db/src/diracx/db/sql/job/db.py @@ -3,7 +3,7 @@ from datetime import datetime, timezone from typing import TYPE_CHECKING, Any -from sqlalchemy import bindparam, delete, func, insert, select, update +from sqlalchemy import bindparam, delete, func, insert, select, update, case from sqlalchemy.exc import IntegrityError, NoResultFound if TYPE_CHECKING: @@ -219,13 +219,22 @@ async def setJobAttributesBulk(self, jobData): jobData[job_id].update( {"LastUpdateTime": datetime.now(tz=timezone.utc)} ) + columns = set(key for attrs in jobData.values() for key in attrs.keys()) + case_expressions = { + column: case( + *[ + (Jobs.__table__.c.JobID == job_id, attrs[column]) + for job_id, attrs in jobData.items() if column in attrs + ], + else_=getattr(Jobs.__table__.c, column) # Retain original value + ) + for column in columns + } - await self.conn.execute( - Jobs.__table__.update().where( - Jobs.__table__.c.JobID == bindparam("b_JobID") - ), - [{"b_JobID": job_id, **attrs} for job_id, attrs in jobData.items()], + stmt = Jobs.__table__.update().values(**case_expressions).where( + Jobs.__table__.c.JobID.in_(jobData.keys()) ) + await self.conn.execute(stmt) async def getJobJDL(self, job_id: int, original: bool = False) -> str: from DIRAC.WorkloadManagementSystem.DB.JobDBUtils import extractJDL diff --git a/diracx-db/src/diracx/db/sql/utils/job.py b/diracx-db/src/diracx/db/sql/utils/job.py index 0032ff1d..544e3773 100644 --- a/diracx-db/src/diracx/db/sql/utils/job.py +++ b/diracx-db/src/diracx/db/sql/utils/job.py @@ -325,11 +325,11 @@ def parse_jdl(job_id, job_jdl): "failed": failed, "success": { job_id: { - "InputData": job_jdls[job_id], + "InputData": job_jdls.get(job_id, None), **attribute_changes[job_id], **set_status_result.model_dump(), } - for job_id, set_status_result in set_job_status_result.success.items() + for job_id, set_status_result in set_job_status_result.success.items() if job_id not in failed }, }