Skip to content

Commit

Permalink
feat (WMS): Improve caching performance of Limiter
Browse files Browse the repository at this point in the history
  • Loading branch information
chrisburr committed Nov 22, 2024
1 parent cb84f65 commit 4ba5244
Showing 1 changed file with 50 additions and 13 deletions.
63 changes: 50 additions & 13 deletions src/DIRAC/WorkloadManagementSystem/Client/Limiter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,13 @@
Utilities and classes here are used by the Matcher
"""
import threading
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor, wait
from functools import partial

from cachetools import TTLCache

from DIRAC import S_OK, S_ERROR
from DIRAC import gLogger

Expand All @@ -12,10 +19,40 @@
from DIRAC.WorkloadManagementSystem.Client import JobStatus


class TwoLevelCache:
    """Two-level cache with a soft and a hard time-to-live (TTL).

    Values younger than ``soft_ttl`` are served immediately. Once the soft
    entry expires, a background refresh is submitted to a thread pool while
    callers keep being served the stale value from the hard cache (up to
    ``hard_ttl`` old), so the expensive ``populate_func`` rarely blocks a
    caller. Only when neither level holds the key does ``get`` block on the
    in-flight refresh.

    :param soft_ttl: seconds a value is served without triggering a refresh.
        Defaults to 10, matching the 10-second TTL previously used by
        ``condCache`` in this module.
    :param hard_ttl: seconds a stale value may still be served while a
        refresh is in flight. Defaults to 300.
    """

    def __init__(self, soft_ttl: int = 10, hard_ttl: int = 300):
        # 1_000_000 entries is effectively unbounded for the expected key space
        self.soft_cache = TTLCache(1_000_000, soft_ttl)
        self.hard_cache = TTLCache(1_000_000, hard_ttl)
        # Per-key locks guarding future bookkeeping and cache writes.
        # NOTE(review): defaultdict item creation is not guaranteed atomic
        # across threads; safe in practice under the CPython GIL — confirm.
        self.locks = defaultdict(threading.Lock)
        # In-flight refreshes keyed like the caches; guarded by self.locks[key]
        self.futures = {}
        self.pool = ThreadPoolExecutor(max_workers=10)

    def get(self, key, populate_func):
        """Return the cached value for ``key``, computing it with ``populate_func`` if needed.

        :param key: hashable cache key
        :param populate_func: zero-argument callable producing the value;
            exceptions it raises are propagated to the blocked caller
        :return: the cached or freshly computed value
        """
        try:
            # try/except instead of .get() so legitimately falsy cached
            # values (e.g. an empty counts dict) still count as hits
            return self.soft_cache[key]
        except KeyError:
            pass
        with self.locks[key]:
            if key not in self.futures:
                # Soft TTL expired and no refresh is running yet: start one
                self.futures[key] = self.pool.submit(self._work, key, populate_func)
            future = self.futures[key]
            try:
                # Serve the stale-but-acceptable value while the refresh runs
                result = self.hard_cache[key]
            except KeyError:
                pass
            else:
                self.soft_cache[key] = result
                return result
        # No usable cached value: block until the in-flight refresh finishes.
        # result() re-raises any exception from populate_func instead of
        # failing later with an unrelated KeyError on the hard cache.
        return future.result()

    def _work(self, key, populate_func):
        """Run ``populate_func`` and store its result in both cache levels.

        Executed in the thread pool. The bookkeeping entry in ``self.futures``
        is always removed — even when ``populate_func`` raises — so a failed
        refresh cannot permanently wedge the key.
        """
        try:
            result = populate_func()
            with self.locks[key]:
                self.hard_cache[key] = result
                self.soft_cache[key] = result
            return result
        finally:
            with self.locks[key]:
                self.futures.pop(key, None)


class Limiter:
# static variables shared between all instances of this class
csDictCache = DictCache()
condCache = DictCache()
newCache = TwoLevelCache()
delayMem = {}

def __init__(self, jobDB=None, opsHelper=None, pilotRef=None):
Expand Down Expand Up @@ -177,19 +214,7 @@ def __getRunningCondition(self, siteName, gridCE=None):
if attName not in self.jobDB.jobAttributeNames:
self.log.error("Attribute does not exist", f"({attName}). Check the job limits")
continue
cK = f"Running:{siteName}:{attName}"
data = self.condCache.get(cK)
if not data:
result = self.jobDB.getCounters(
"Jobs",
[attName],
{"Site": siteName, "Status": [JobStatus.RUNNING, JobStatus.MATCHED, JobStatus.STALLED]},
)
if not result["OK"]:
return result
data = result["Value"]
data = {k[0][attName]: k[1] for k in data}
self.condCache.add(cK, 10, data)
data = self.newCache.get(f"Running:{siteName}:{attName}", partial(self._countsByJobType, siteName, attName))
for attValue in limitsDict[attName]:
limit = limitsDict[attName][attValue]
running = data.get(attValue, 0)
Expand Down Expand Up @@ -249,3 +274,15 @@ def __getDelayCondition(self, siteName):
negCond[attName] = []
negCond[attName].append(attValue)
return S_OK(negCond)

def _countsByJobType(self, siteName, attName):
    """Count non-finished jobs at a site, grouped by the values of one job attribute.

    Queries the JobDB for jobs in RUNNING, MATCHED or STALLED state at
    ``siteName`` and aggregates the counters by ``attName``.

    :param siteName: name of the site to count jobs for
    :param attName: job attribute to group the counts by (e.g. JobType)
    :return: dict mapping attribute value -> number of jobs, or the raw
        error structure when the DB query fails

    NOTE(review): on a failed query the S_ERROR dict is returned as-is and
    will be cached by TwoLevelCache, after which the caller treats it like
    an (empty) counts mapping — confirm this silent-degradation is intended.
    """
    counters = self.jobDB.getCounters(
        "Jobs",
        [attName],
        {"Site": siteName, "Status": [JobStatus.RUNNING, JobStatus.MATCHED, JobStatus.STALLED]},
    )
    if not counters["OK"]:
        return counters
    # Each counter row is a (grouping-dict, count) pair
    return {grouping[attName]: count for grouping, count in counters["Value"]}

0 comments on commit 4ba5244

Please sign in to comment.