Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle is_set #303

Merged
merged 5 commits into from
Oct 5, 2021
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 61 additions & 0 deletions strider/node_sets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
"""Node sets."""
from collections import defaultdict


def collapse_sets(message: dict) -> None:
    """Collapse results according to is_set qnode notations.

    Results that agree on the bindings of every non-set qnode, and of every
    qedge whose subject and object are both non-set qnodes, are merged into
    a single result. The merged result binds each qnode/qedge to the
    de-duplicated union of the ids bound by the results that were merged.

    The message is modified in place: ``message["results"]`` is replaced by
    the merged results, in order of first appearance. If no qnode has a
    truthy ``is_set`` the message is left untouched.

    Only the ``"id"`` of each binding is carried over to the merged results.

    Args:
        message: dict with a ``"query_graph"`` (holding ``"nodes"`` and
            ``"edges"``) and, when any qnode is a set, ``"results"`` whose
            items hold ``"node_bindings"`` and ``"edge_bindings"``.

    Raises:
        KeyError: if a result has no bindings entry for a qnode or qedge
            of the query graph.
    """
    qgraph = message["query_graph"]
    # A list keeps a fixed qnode order for building bucket keys.
    unique_qnodes = [
        qnode_id
        for qnode_id, qnode in qgraph["nodes"].items()
        if not qnode.get("is_set", False)
    ]
    if len(unique_qnodes) == len(qgraph["nodes"]):
        # No qnode is a set, so there is nothing to collapse.
        return
    unique_lookup = frozenset(unique_qnodes)
    unique_qedges = [
        qedge_id
        for qedge_id, qedge in qgraph["edges"].items()
        if (
            qedge["subject"] in unique_lookup
            and qedge["object"] in unique_lookup
        )
    ]
    # Binding ids are collected as dict keys rather than in a set so that
    # duplicates are dropped while first-seen order is preserved, which
    # keeps the output deterministic across runs.
    result_buckets = defaultdict(lambda: {
        "node_bindings": defaultdict(dict),
        "edge_bindings": defaultdict(dict),
    })
    for result in message["results"]:
        # Ids are grouped per qnode/qedge instead of being flattened into
        # one tuple; a flat tuple cannot tell n0=[x, y], n1=[z] apart from
        # n0=[x], n1=[y, z] and would merge unrelated results.
        bucket_key = (
            tuple(
                tuple(
                    binding["id"]
                    for binding in result["node_bindings"][qnode_id]
                )
                for qnode_id in unique_qnodes
            ),
            tuple(
                tuple(
                    binding["id"]
                    for binding in result["edge_bindings"][qedge_id]
                )
                for qedge_id in unique_qedges
            ),
        )
        bucket = result_buckets[bucket_key]
        for qnode_id in qgraph["nodes"]:
            seen_ids = bucket["node_bindings"][qnode_id]
            for binding in result["node_bindings"][qnode_id]:
                seen_ids.setdefault(binding["id"], None)
        for qedge_id in qgraph["edges"]:
            seen_ids = bucket["edge_bindings"][qedge_id]
            for binding in result["edge_bindings"][qedge_id]:
                seen_ids.setdefault(binding["id"], None)
    # Convert the id collections back into lists of binding dicts.
    message["results"] = [
        {
            "node_bindings": {
                qnode_id: [{"id": binding_id} for binding_id in binding_ids]
                for qnode_id, binding_ids in bucket["node_bindings"].items()
            },
            "edge_bindings": {
                qedge_id: [{"id": binding_id} for binding_id in binding_ids]
                for qedge_id, binding_ids in bucket["edge_bindings"].items()
            },
        }
        for bucket in result_buckets.values()
    ]
13 changes: 8 additions & 5 deletions strider/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from reasoner_pydantic import Query, AsyncQuery, Message, Response as ReasonerResponse

from .fetcher import Binder
from .node_sets import collapse_sets
from .query_planner import NoAnswersError, generate_plan
from .scoring import score_graph
from .storage import RedisGraph, RedisList, get_client as get_redis_client
Expand Down Expand Up @@ -243,13 +244,15 @@ async def lookup(
kgraph,
])
results.append(result)
message = {
"query_graph": qgraph,
"knowledge_graph": kgraph,
"results": results,
}
collapse_sets(message)
logs = list(RedisList(f"{qid}:log", redis_client).get())
return {
"message": {
"query_graph": qgraph,
"knowledge_graph": kgraph,
"results": results,
},
"message": message,
"logs": logs,
}

Expand Down
54 changes: 54 additions & 0 deletions tests/test_node_sets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
"""Test node set handling."""
from strider.node_sets import collapse_sets


def test_node_sets():
    """Check that results are merged over the is_set node of a one-hop query.

    n1 is marked is_set, so the two results that bind n0 to "a0" should be
    merged into one result carrying both n1 bindings, leaving two results.
    """
    # One (n0, n1, e01) binding triple per input result.
    binding_rows = [
        ("a0", "b0", "c0"),
        ("a1", "b0", "c1"),
        ("a0", "b1", "c2"),
    ]
    message = {
        "query_graph": {
            "nodes": {
                "n0": {},
                "n1": {"is_set": True},
            },
            "edges": {
                "e01": {
                    "subject": "n0",
                    "object": "n1",
                },
            },
        },
        "results": [
            {
                "node_bindings": {
                    "n0": [{"id": n0_id}],
                    "n1": [{"id": n1_id}],
                },
                "edge_bindings": {
                    "e01": [{"id": e01_id}],
                },
            }
            for n0_id, n1_id, e01_id in binding_rows
        ],
    }

    collapse_sets(message)

    collapsed = message["results"]
    assert len(collapsed) == 2
    assert len(collapsed[0]["node_bindings"]["n1"]) == 2
57 changes: 56 additions & 1 deletion tests/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,8 @@ async def test_trivial_unbatching(redis):
},
output["message"]
)



@pytest.mark.asyncio
@with_translator_overlay(
settings.kpregistry_url,
Expand Down Expand Up @@ -464,3 +465,57 @@ async def test_gene_protein_conflation(redis):
},
output["message"]
)


@pytest.mark.asyncio
@with_translator_overlay(
    settings.kpregistry_url,
    settings.normalizer_url,
    {
        "ctd": """
CHEBI:6801(( category biolink:ChemicalSubstance ))
MONDO:0005148(( category biolink:Disease ))
CHEBI:6801-- predicate biolink:treats -->MONDO:0005148
MONDO:0003757(( category biolink:Disease ))
CHEBI:6801-- predicate biolink:treats -->MONDO:0003757
""",
        "pharos": """
MONDO:0003757(( category biolink:Disease ))
MONDO:0005148(( category biolink:Disease ))
HP:XXX(( category biolink:PhenotypicFeature ))
HP:YYY(( category biolink:PhenotypicFeature ))
MONDO:0003757-- predicate biolink:has_phenotype -->HP:XXX
MONDO:0003757-- predicate biolink:has_phenotype -->HP:YYY
MONDO:0005148-- predicate biolink:has_phenotype -->HP:YYY
""",
    }
)
async def test_node_set(redis):
    """Run a two-hop lookup whose middle node is a set and check the results.

    The n1 (Disease) node is flagged is_set before the query is run; the
    lookup is expected to return two results, one binding a single n1 node
    and the other binding two.
    """
    query_graph = query_graph_from_string(
        """
n0(( ids[] CHEBI:6801 ))
n0(( categories[] biolink:ChemicalSubstance ))
n1(( categories[] biolink:Disease ))
n2(( categories[] biolink:PhenotypicFeature ))
n0-- biolink:treats -->n1
n1-- biolink:has_phenotype -->n2
"""
    )
    # Mark the intermediate node as a set.
    query_graph["nodes"]["n1"]["is_set"] = True

    # Assemble the request.
    request = {
        "message": {"query_graph": query_graph},
        "log_level": "WARNING",
    }

    # Execute the lookup and inspect the returned results.
    output = await lookup(request, redis)
    returned = output["message"]["results"]
    assert len(returned) == 2
    n1_binding_counts = {
        len(result["node_bindings"]["n1"])
        for result in returned
    }
    assert n1_binding_counts == {1, 2}