Fix duplicate curies issue #432

Merged: 10 commits, Sep 11, 2023
22 changes: 16 additions & 6 deletions strider/fetcher.py
@@ -164,6 +164,11 @@ async def generate_from_kp(
         )
         for batch_results in batch(onehop_results, 1_000_000):
             result_map = defaultdict(list)
+            # copy subqgraph between each batch
+            # before we fill it with result curies
+            # this keeps the sub query graph from being modified and passing
+            # extra curies into subsequent batches
+            populated_subqgraph = copy.deepcopy(subqgraph)

Collaborator comment:
I think another comment or two describing why we need to do this would be helpful. Just describe the problem we are solving right now.

             for result in batch_results:
                 # add edge to results and kgraph
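As the reviewer requested, the problem this deepcopy solves is worth spelling out. A minimal, hypothetical sketch of the failure mode (the qgraph shape and curies are illustrative, not taken from Strider's schema):

    import copy

    # hypothetical subquery graph: one unpinned node (n0) and one pinned node (n1)
    subqgraph = {"nodes": {"n0": {"ids": None}, "n1": {"ids": ["MONDO:0005148"]}}}
    batches = [["CHEBI:1"], ["CHEBI:2"]]

    # buggy: every batch mutates the shared graph, so batch 2 also sees CHEBI:1
    for batch_ids in batches:
        subqgraph["nodes"]["n0"]["ids"] = (
            subqgraph["nodes"]["n0"].get("ids") or []
        ) + batch_ids
    print(subqgraph["nodes"]["n0"]["ids"])  # ['CHEBI:1', 'CHEBI:2'] - leaked across batches

    # fixed: deep-copy per batch so each batch pins ids onto a pristine graph
    subqgraph["nodes"]["n0"]["ids"] = None
    for batch_ids in batches:
        populated_subqgraph = copy.deepcopy(subqgraph)
        populated_subqgraph["nodes"]["n0"]["ids"] = (
            populated_subqgraph["nodes"]["n0"].get("ids") or []
        ) + batch_ids
        print(populated_subqgraph["nodes"]["n0"]["ids"])  # ['CHEBI:1'], then ['CHEBI:2']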

@@ -227,12 +232,17 @@ async def generate_from_kp(

                 # pin nodes
                 for qnode_id, bindings in result.node_bindings.items():
-                    if qnode_id not in subqgraph["nodes"]:
+                    if qnode_id not in populated_subqgraph["nodes"]:
                         continue
-                    subqgraph["nodes"][qnode_id]["ids"] = (
-                        subqgraph["nodes"][qnode_id].get("ids") or []
-                    ) + [binding.id for binding in bindings]
-                qnode_ids = set(subqgraph["nodes"].keys()) & set(
+                    # add curies from result into the qgraph
+                    # need to call set() to remove any duplicates
+                    populated_subqgraph["nodes"][qnode_id]["ids"] = list(
+                        set(
+                            (populated_subqgraph["nodes"][qnode_id].get("ids") or [])
+                            + [binding.id for binding in bindings]
+                        )
+                    )
+                qnode_ids = set(populated_subqgraph["nodes"].keys()) & set(
                     result.node_bindings.keys()
                 )
                 key_fcn = lambda res: tuple(
@@ -249,7 +259,7 @@ async def generate_from_kp(

             generators.append(
                 self.generate_from_result(
-                    copy.deepcopy(subqgraph),
+                    populated_subqgraph,
                     lambda result: result_map[key_fcn(result)],
                     call_stack,
                 )
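The other half of the fix is the list(set(...)) wrapper when pinning ids. A standalone sketch of that dedup step (values illustrative); note that set() removes duplicates but does not preserve order, so this assumes downstream consumers treat ids as an unordered collection:

    existing = ["CHEBI:1", "CHEBI:2"]
    new_bindings = ["CHEBI:2", "CHEBI:3"]  # CHEBI:2 arrives again from another result

    # need to call set() to remove any duplicates; order is not preserved
    merged = list(set(existing + new_bindings))
    assert sorted(merged) == ["CHEBI:1", "CHEBI:2", "CHEBI:3"]

    # an order-preserving alternative, if id order ever matters:
    merged_ordered = list(dict.fromkeys(existing + new_bindings))
    assert merged_ordered == ["CHEBI:1", "CHEBI:2", "CHEBI:3"]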
21 changes: 20 additions & 1 deletion strider/server.py
@@ -9,10 +9,12 @@
 """
 import copy
 import datetime
+import json
 import os
 import uuid
 import logging
 import warnings
+import time
 import traceback
 import asyncio

@@ -434,6 +436,7 @@ async def lookup(
     qid: str = None,
 ) -> dict:
     """Perform lookup operation."""
+    lookup_start_time = time.time()
     qgraph = query_dict["message"]["query_graph"]

     log_level = query_dict.get("log_level") or "INFO"
@@ -454,7 +457,7 @@

     fetcher = Fetcher(logger, parameters)

-    logger.info(f"Doing lookup for qgraph: {qgraph}")
+    logger.info(f"Doing lookup for qgraph: {json.dumps(qgraph)}")
     try:
         await fetcher.setup(qgraph, registry, information_content_threshold)
     except NoAnswersError:
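Switching the log line to json.dumps(qgraph) emits valid JSON rather than a Python repr (single quotes, None instead of null), which log aggregators can parse. A quick illustration:

    import json

    qgraph = {"nodes": {"n0": {"ids": None}}}
    print(f"Doing lookup for qgraph: {qgraph}")
    # Doing lookup for qgraph: {'nodes': {'n0': {'ids': None}}}
    print(f"Doing lookup for qgraph: {json.dumps(qgraph)}")
    # Doing lookup for qgraph: {"nodes": {"n0": {"ids": null}}}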
@@ -477,9 +480,12 @@

     output_auxgraphs = AuxiliaryGraphs.parse_obj({})

+    message_merging_time = 0
+
     async with fetcher:
         async for result_kgraph, result, result_auxgraph in fetcher.lookup(None):
             # Update the kgraph
+            start_merging = time.time()
             output_kgraph.update(result_kgraph)

             # Update the aux graphs
@@ -496,6 +502,9 @@
                 # add new result to hashmap
                 output_results[sub_result_hash] = result

+            stop_merging = time.time()
+            message_merging_time += stop_merging - start_merging
+
     results = Results.parse_obj([])
     for result in output_results.values():
         # copy so result analyses don't get combined somehow
@@ -517,6 +526,13 @@
     collapse_sets(output_query, logger)

     output_query.logs = list(log_handler.contents())
+    lookup_end_time = time.time()
+    logger.info(
+        {
+            "total_lookup_time": (lookup_end_time - lookup_start_time),
+            "total_merging": message_merging_time,
+        }
+    )
     return output_query.dict(exclude_none=True)
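The new timers accumulate merging time across loop iterations and log it next to the total lookup time. If this instrumentation grows, a small context-manager timer (a sketch, not part of this PR) would keep the start/stop pairs from drifting apart:

    import time
    from contextlib import contextmanager

    @contextmanager
    def accumulate(totals: dict, key: str):
        """Add the elapsed wall-clock time of the with-block to totals[key]."""
        start = time.time()
        try:
            yield
        finally:
            totals[key] = totals.get(key, 0.0) + (time.time() - start)

    totals = {}
    for _ in range(3):
        with accumulate(totals, "total_merging"):
            time.sleep(0.01)  # stand-in for kgraph/result merging work
    print(totals)  # e.g. {'total_merging': 0.03...}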


@@ -560,6 +576,7 @@ async def async_lookup(

 async def multi_lookup(multiqid, callback, queries: dict, query_keys: list):
     "Performs lookup for multiple queries and sends all results to callback url"
+    start_time = time.time()

     async def single_lookup(query_key):
         qid = f"{multiqid}.{str(uuid.uuid4())[:8]}"
@@ -623,6 +640,8 @@ async def single_lookup(query_key):
             LOGGER.error(
                 f"[{multiqid}] Failed to send 'completed' response back to {callback} with error: {e}"
             )
+    end_time = time.time()
+    LOGGER.info(f"[{multiqid}] took {(end_time - start_time)} seconds")


@APP.post("/plan", response_model=dict[str, list[str]], include_in_schema=False)
5 changes: 3 additions & 2 deletions strider/throttle.py
@@ -3,6 +3,7 @@
 from asyncio.queues import QueueEmpty
 from asyncio.tasks import Task
 import copy
+import json
 import datetime
 from functools import wraps
 import itertools
@@ -351,8 +352,8 @@ async def process_batch(
                 self.logger.warning(
                     {
                         "message": f"Received bad JSON data from {self.id}",
-                        "request": e.request,
-                        "response": e.response.text,
+                        "request": json.dumps(merged_request_value),
+                        "response": response.text,
                         "error": str(e),
                     }
                 )
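This throttle change looks like a latent bug fix: an exception raised while JSON-decoding a KP response presumably carries no .request or .response attributes (those live on HTTP errors), so the old handler could itself raise AttributeError. The fix logs the request Strider actually sent and the response object already in scope. A hedged sketch of the corrected pattern (the function, URL handling, and logger wiring are illustrative, not Strider's actual code):

    import json
    import logging

    import httpx

    logger = logging.getLogger(__name__)

    async def call_kp(url: str, merged_request_value: dict) -> dict:
        """Sketch: post a TRAPI-style request and log bad-JSON responses usefully."""
        async with httpx.AsyncClient() as client:
            response = await client.post(url, json=merged_request_value)
        try:
            return response.json()
        except json.JSONDecodeError as e:
            # JSONDecodeError has no .request/.response, so log what we do have:
            # the payload we sent and the raw response text already in scope
            logger.warning(
                {
                    "message": f"Received bad JSON data from {url}",
                    "request": json.dumps(merged_request_value),
                    "response": response.text,
                    "error": str(e),
                }
            )
            return {}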