diff --git a/strider/caching.py b/strider/caching.py index 42710f17..d7928809 100644 --- a/strider/caching.py +++ b/strider/caching.py @@ -207,3 +207,9 @@ async def get_post_response(url, request): if response is not None: response = json.loads(gzip.decompress(response)) return response + + +async def clear_cache(): + """Clear one-hop redis cache.""" + client = await aioredis.Redis(connection_pool=onehop_redis_pool) + await client.flushdb() diff --git a/strider/fetcher.py b/strider/fetcher.py index 078f834d..381c7af3 100644 --- a/strider/fetcher.py +++ b/strider/fetcher.py @@ -48,11 +48,12 @@ class Fetcher: a full result can be returned for merging. """ - def __init__(self, logger): + def __init__(self, logger, parameters): """Initialize.""" self.logger: logging.Logger = logger self.normalizer = Normalizer(self.logger) self.kps = dict() + self.parameters = parameters self.preferred_prefixes = WBMT.entity_prefix_mapping @@ -344,5 +345,6 @@ async def setup( kp_id, kp, self.logger, + self.parameters, information_content_threshold=information_content_threshold, ) diff --git a/strider/knowledge_provider.py b/strider/knowledge_provider.py index e8ec1f0a..32b2302e 100644 --- a/strider/knowledge_provider.py +++ b/strider/knowledge_provider.py @@ -34,6 +34,7 @@ def __init__( kp_id, kp, logger, + parameters: dict = {}, information_content_threshold: int = settings.information_content_threshold, *args, **kwargs, @@ -47,6 +48,7 @@ def __init__( logger=logger, preproc=self.get_preprocessor(kp["details"]["preferred_prefixes"]), postproc=self.get_postprocessor(WBMT.entity_prefix_mapping), + parameters=parameters, *args, *kwargs, ) diff --git a/strider/server.py b/strider/server.py index 8e86cc49..fa2468a7 100644 --- a/strider/server.py +++ b/strider/server.py @@ -23,6 +23,7 @@ from fastapi.responses import JSONResponse from fastapi.staticfiles import StaticFiles import httpx +from pydantic import BaseModel from reasoner_pydantic.kgraph import KnowledgeGraph from reasoner_pydantic.qgraph import QueryGraph from reasoner_pydantic.utils import HashableMapping @@ -45,6 +46,7 @@ save_kp_registry, get_registry_lock, remove_registry_lock, + clear_cache, ) from .fetcher import Fetcher from .node_sets import collapse_sets @@ -70,7 +72,7 @@ title="Strider", description=DESCRIPTION, docs_url=None, - version="4.4.5", + version="4.4.6", terms_of_service=( "http://robokop.renci.org:7055/tos" "?service_long=Strider" @@ -328,9 +330,11 @@ async def sync_query( qid = str(uuid.uuid4())[:8] try: LOGGER.info(f"[{qid}] Starting sync query") - query_results = await asyncio.wait_for( - lookup(query_dict, qid), timeout=max_process_time - ) + # get max timeout + timeout_seconds = (query_dict.get("parameters") or {}).get("timeout_seconds") + timeout_seconds = timeout_seconds if type(timeout_seconds) is int else 0 + timeout = max(max_process_time, timeout_seconds) + query_results = await asyncio.wait_for(lookup(query_dict, qid), timeout=timeout) except asyncio.TimeoutError: LOGGER.error(f"[{qid}] Sync query cancelled due to timeout.") query_results = { @@ -338,7 +342,9 @@ async def sync_query( "status_communication": {"strider_process_status": "timeout"}, } except Exception as e: - LOGGER.error(f"[{qid}] Sync query failed unexpectedly: {e}") + LOGGER.error( + f"[{qid}] Sync query failed unexpectedly: {traceback.format_exc()}" + ) qid = "Exception" query_results = { "message": {}, @@ -442,7 +448,9 @@ async def lookup( logger.setLevel(level_number) logger.addHandler(log_handler) - fetcher = Fetcher(logger) + parameters = query_dict.get("parameters") or {} + + fetcher = Fetcher(logger, parameters) logger.info(f"Doing lookup for qgraph: {qgraph}") try: @@ -518,9 +526,11 @@ async def async_lookup( qid = str(uuid.uuid4())[:8] query_results = {} try: - query_results = await asyncio.wait_for( - lookup(query_dict, qid), timeout=max_process_time - ) + # get max timeout + timeout_seconds = (query_dict.get("parameters") or {}).get("timeout_seconds") + timeout_seconds = timeout_seconds if type(timeout_seconds) is int else 0 + timeout = max(max_process_time, timeout_seconds) + query_results = await asyncio.wait_for(lookup(query_dict, qid), timeout=timeout) except asyncio.TimeoutError: LOGGER.error(f"[{qid}]: Process cancelled due to timeout.") query_results = { @@ -549,8 +559,14 @@ async def single_lookup(query_key): qid = f"{multiqid}.{str(uuid.uuid4())[:8]}" query_result = {} try: + # get max timeout + timeout_seconds = (queries[query_key].get("parameters") or {}).get( + "timeout_seconds" + ) + timeout_seconds = timeout_seconds if type(timeout_seconds) is int else 0 + timeout = max(max_process_time, timeout_seconds) query_result = await asyncio.wait_for( - lookup(queries[query_key], qid), timeout=max_process_time + lookup(queries[query_key], qid), timeout=timeout ) except asyncio.TimeoutError: LOGGER.error(f"[{qid}]: Process cancelled due to timeout.") @@ -615,3 +631,17 @@ async def get_kps(): registry = await get_kp_registry() # print(registry) return list(registry.keys()) + + +class ClearCacheRequest(BaseModel): + pswd: str + + +@APP.post("/clear_cache", status_code=200, include_in_schema=False) +async def clear_redis_cache(request: ClearCacheRequest) -> dict: + """Clear the redis cache.""" + if request.pswd == settings.redis_password: + await clear_cache() + return {"status": "success"} + else: + raise HTTPException(status_code=401, detail="Invalid Password") diff --git a/strider/throttle.py b/strider/throttle.py index 1d8a582e..cc510987 100644 --- a/strider/throttle.py +++ b/strider/throttle.py @@ -64,6 +64,7 @@ def __init__( preproc: Callable = anull, postproc: Callable = anull, logger: logging.Logger = None, + parameters: dict = {}, **kwargs, ): """Initialize.""" @@ -75,6 +76,7 @@ def __init__( self.preproc = preproc self.postproc = postproc self.use_cache = settings.use_cache + self.parameters = parameters if logger is None: logger = logging.getLogger(__name__) self.logger = logger @@ -211,7 +213,11 @@ async def process_batch( ), ) ) - async with httpx.AsyncClient(timeout=settings.kp_timeout) as client: + kp_timeout = self.parameters.get("timeout_seconds") + kp_timeout = ( + kp_timeout if type(kp_timeout) is int else settings.kp_timeout + ) + async with httpx.AsyncClient(timeout=kp_timeout) as client: response = await client.post( self.url, json=merged_request_value,