diff --git a/src/DIRAC/WorkloadManagementSystem/Client/Limiter.py b/src/DIRAC/WorkloadManagementSystem/Client/Limiter.py
index 9037a4aa669..ab8fe53f2c2 100644
--- a/src/DIRAC/WorkloadManagementSystem/Client/Limiter.py
+++ b/src/DIRAC/WorkloadManagementSystem/Client/Limiter.py
@@ -2,6 +2,13 @@
 Utilities and classes here are used by the Matcher
 """
+import threading
+from collections import defaultdict
+from concurrent.futures import ThreadPoolExecutor, wait
+from functools import partial
+
+from cachetools import TTLCache
+
 from DIRAC import S_OK, S_ERROR
 from DIRAC import gLogger
@@ -12,10 +19,40 @@
 from DIRAC.WorkloadManagementSystem.Client import JobStatus
 
 
+class TwoLevelCache:
+    def __init__(self, soft_ttl: int, hard_ttl: int):
+        self.soft_cache = TTLCache(1_000_000, soft_ttl)
+        self.hard_cache = TTLCache(1_000_000, hard_ttl)
+        self.locks = defaultdict(threading.Lock)
+        self.futures = {}
+        self.pool = ThreadPoolExecutor(max_workers=10)
+
+    def get(self, key, populate_func):
+        if result := self.soft_cache.get(key):
+            return result
+        with self.locks[key]:
+            if key not in self.futures:
+                self.futures[key] = self.pool.submit(self._work, key, populate_func)
+            if result := self.hard_cache.get(key):
+                self.soft_cache[key] = result
+                return result
+            future = self.futures[key]
+        wait([future])
+        return self.hard_cache[key]
+
+    def _work(self, key, func):
+        result = func()
+        with self.locks[key]:
+            self.futures.pop(key)
+            self.hard_cache[key] = result
+            self.soft_cache[key] = result
+
+
 class Limiter:
     # static variables shared between all instances of this class
     csDictCache = DictCache()
     condCache = DictCache()
+    newCache = TwoLevelCache(10, 300)  # assumed TTLs: 10 s soft (the old condCache TTL), 300 s hard
     delayMem = {}
 
     def __init__(self, jobDB=None, opsHelper=None, pilotRef=None):
@@ -177,19 +214,7 @@ def __getRunningCondition(self, siteName, gridCE=None):
             if attName not in self.jobDB.jobAttributeNames:
                 self.log.error("Attribute does not exist", f"({attName}). Check the job limits")
                 continue
-            cK = f"Running:{siteName}:{attName}"
-            data = self.condCache.get(cK)
-            if not data:
-                result = self.jobDB.getCounters(
-                    "Jobs",
-                    [attName],
-                    {"Site": siteName, "Status": [JobStatus.RUNNING, JobStatus.MATCHED, JobStatus.STALLED]},
-                )
-                if not result["OK"]:
-                    return result
-                data = result["Value"]
-                data = {k[0][attName]: k[1] for k in data}
-                self.condCache.add(cK, 10, data)
+            data = self.newCache.get(f"Running:{siteName}:{attName}", partial(countsByJobType, siteName, attName))
             for attValue in limitsDict[attName]:
                 limit = limitsDict[attName][attValue]
                 running = data.get(attValue, 0)
@@ -249,3 +274,16 @@ def __getDelayCondition(self, siteName):
                 negCond[attName] = []
             negCond[attName].append(attValue)
         return S_OK(negCond)
+
+
+def countsByJobType(siteName, attName):
+    result = JobDB().getCounters(
+        "Jobs",
+        [attName],
+        {"Site": siteName, "Status": [JobStatus.RUNNING, JobStatus.MATCHED, JobStatus.STALLED]},
+    )
+    if not result["OK"]:
+        return result
+    data = result["Value"]
+    data = {k[0][attName]: k[1] for k in data}
+    return data
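
For reference, a minimal sketch of how the TwoLevelCache added above behaves when used on its own. The `slow_query` helper, the key name, and the TTL values below are made up for illustration and are not part of the DIRAC code; the sketch assumes the `TwoLevelCache` class from the diff is in scope.

```python
import time
from functools import partial


def slow_query(site_name, att_name):
    """Hypothetical stand-in for an expensive lookup such as JobDB().getCounters()."""
    time.sleep(1)
    return {"User": 5, "MCSimulation": 12}


cache = TwoLevelCache(soft_ttl=10, hard_ttl=300)
key = "Running:LCG.Example.org:JobType"

# First call: nothing is cached, so get() blocks until the populate function
# has run in the thread pool and filled both caches.
print(cache.get(key, partial(slow_query, "LCG.Example.org", "JobType")))

# Calls within the next 10 s are served straight from soft_cache.
# Between 10 s and 300 s, get() returns the (possibly stale) hard_cache value
# immediately while a background refresh is submitted to the pool.
print(cache.get(key, partial(slow_query, "LCG.Example.org", "JobType")))
```

The point of the two TTLs is that a caller only ever blocks on the very first lookup for a key (or after the hard TTL has expired); otherwise it gets a cached value at once and the database query runs in a background thread.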