Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add pinnedness planning #300

Merged
merged 5 commits into from
Oct 5, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 2 additions & 12 deletions strider/fetcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from .trapi import canonicalize_qgraph, filter_by_qgraph, get_curies, map_qgraph_curies, merge_messages, merge_results, \
fill_categories_predicates
from .caching import async_locking_cache
from .query_planner import generate_plan
from .query_planner import generate_plan, get_next_qedge
from .storage import RedisGraph, RedisList, RedisLogHandler
from .kp_registry import Registry
from .config import settings
Expand Down Expand Up @@ -119,17 +119,7 @@ async def lookup(
self.logger.debug(f"Lookup for qgraph: {qgraph}")

try:
qedge_id, qedge = next(
(qedge_id, qedge)
for qedge_id, qedge in qgraph["edges"].items()
if any(
qnode.get("ids", [])
for qnode in (
qgraph["nodes"][qedge["subject"]],
qgraph["nodes"][qedge["object"]],
)
)
)
qedge_id, qedge = get_next_qedge(qgraph)
except StopIteration:
raise RuntimeError("Cannot find qedge with pinned endpoint")

Expand Down
81 changes: 80 additions & 1 deletion strider/query_planner.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
"""Query planner."""
from collections import namedtuple
from collections import defaultdict, namedtuple
import logging
import copy
import math
from typing import Generator

from reasoner_pydantic import QueryGraph
Expand Down Expand Up @@ -209,3 +210,81 @@ def get_query_graph_edge_kps(
"reverse": og_edge["edge_reverse"],
})
return kps


# Hard-coded graph-size figures consumed by the planning heuristics below:
# N is used as the id count for any qnode that has no "ids", and log(R / N)
# is the per-edge term added in compute_log_expected_n().
N = 1_000_000 # total number of nodes
R = 25 # number of edges per node


def get_next_qedge(qgraph):
    """Get next qedge to solve.

    Every qedge is scored as the summed pinnedness of its two endpoints
    minus an effort term (the sum of the logs of the endpoints' id counts).
    The highest-scoring qedge is returned; on a tie the qedge that appears
    first in qgraph["edges"] wins.

    Args:
        qgraph: Query graph dict with "nodes" and "edges" mappings. The
            argument itself is not modified; all work is done on a deep copy.

    Returns:
        A (qedge_id, qedge) tuple, where qedge is taken from the working copy.

    Raises:
        StopIteration: If the query graph has no edges. This mirrors the
            next()-based selection this function replaced, so call sites
            that catch StopIteration keep working instead of seeing the
            ValueError that max() raises on an empty mapping.
    """
    if not qgraph["edges"]:
        raise StopIteration("Query graph has no edges")

    qgraph = copy.deepcopy(qgraph)
    for qnode in qgraph["nodes"].values():
        ids = qnode.get("ids")
        # A missing *or empty* id list means the node is unpinned, so it is
        # given the full node count. An empty list must not become 0 here:
        # math.log(0) raises "math domain error" further down.
        qnode["ids"] = len(ids) if ids else N

    pinnednesses = {
        qnode_id: get_pinnedness(qgraph, qnode_id)
        for qnode_id in qgraph["nodes"]
    }

    # Effort grows with the number of candidate ids at both ends of the edge.
    efforts = {
        qedge_id: math.log(
            qgraph["nodes"][qedge["subject"]]["ids"]
        ) + math.log(
            qgraph["nodes"][qedge["object"]]["ids"]
        )
        for qedge_id, qedge in qgraph["edges"].items()
    }

    edge_priorities = {
        qedge_id: (
            pinnednesses[qedge["subject"]]
            + pinnednesses[qedge["object"]]
            - efforts[qedge_id]
        )
        for qedge_id, qedge in qgraph["edges"].items()
    }

    qedge_id = max(edge_priorities, key=edge_priorities.get)
    return qedge_id, qgraph["edges"][qedge_id]


def get_pinnedness(qgraph, qnode_id):
    """Return the pinnedness of a single qnode.

    Pinnedness is the negated result of compute_log_expected_n() for the
    given qnode, so a smaller expected number of bound knodes yields a
    larger pinnedness.

    Args:
        qgraph: Query graph whose nodes carry a numeric "ids" value (see
            get_num_ids()).
        qnode_id: Key of the qnode to evaluate.
    """
    log_expected_n = compute_log_expected_n(
        get_adjacency_matrix(qgraph),
        get_num_ids(qgraph),
        qnode_id,
    )
    return -log_expected_n


def compute_log_expected_n(adjacency_mat, num_ids, qnode_id, last=None, level=0):
    """Compute the log of the expected number of unique knodes bound to a qnode.

    Starts from log(num_ids[qnode_id]) and adds one term per neighboring
    qnode, weighted by the number of edges to that neighbor. Each term is
    computed recursively from the neighbor's own estimate.

    Args:
        adjacency_mat: Mapping qnode id -> {neighbor id: edge count}.
        num_ids: Mapping qnode id -> number of ids for that qnode.
        qnode_id: The qnode to estimate.
        last: The qnode the recursion arrived from; it is not revisited.
        level: Current recursion depth. Neighbors are only followed while
            this is below 10 (presumably to bound recursion on cyclic
            query graphs).

    Returns:
        The estimated log count as a float.
    """
    total = math.log(num_ids[qnode_id])
    if level >= 10:
        return total

    for neighbor, num_edges in adjacency_mat[qnode_id].items():
        if neighbor == last:
            continue
        neighbor_log_n = compute_log_expected_n(
            adjacency_mat,
            num_ids,
            neighbor,
            qnode_id,
            level + 1,
        )
        # The neighbor's estimate is floored at 0 before the per-edge
        # log(R / N) term is added.
        contribution = max(neighbor_log_n, 0) + math.log(R / N)
        # ignore contributions >0 - edges should only _further_ constrain n
        total += num_edges * min(contribution, 0)
    return total


def get_adjacency_matrix(qgraph):
    """Build a symmetric edge-count table for the query graph.

    Returns a nested defaultdict where result[a][b] is the number of qedges
    joining qnodes a and b, counted in both directions. A self edge
    therefore adds 2 to result[a][a].
    """
    edge_counts = defaultdict(lambda: defaultdict(int))
    for qedge in qgraph["edges"].values():
        source, target = qedge["subject"], qedge["object"]
        edge_counts[source][target] += 1
        edge_counts[target][source] += 1
    return edge_counts


def get_num_ids(qgraph):
    """Map each qnode id to the value stored under that qnode's "ids" key.

    The value is passed through unchanged; get_next_qedge() replaces each
    id list with a count before this is called.
    """
    qnodes = qgraph["nodes"]
    return {qnode_id: qnodes[qnode_id]["ids"] for qnode_id in qnodes}
74 changes: 73 additions & 1 deletion tests/test_query_planner.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from tests.helpers.logger import assert_no_level


from strider.query_planner import generate_plan
from strider.query_planner import generate_plan, get_next_qedge

from strider.trapi import fill_categories_predicates

Expand Down Expand Up @@ -514,3 +514,75 @@ async def test_double_sided(caplog):
plan, kps = await generate_plan(qg, logger=logging.getLogger())
assert plan == {"n0n1": ["kp0"]}
assert "kp0" in kps


def test_get_next_qedge():
    """Check edge choice on a three-node chain with both ends pinned.

    n0 has two ids, n2 has one and n1 has none; the edge touching the
    single-id node (e12) is expected.
    """
    nodes = {
        "n0": {"ids": ["01", "02"]},
        "n1": {},
        "n2": {"ids": ["03"]},
    }
    edges = {
        "e01": {"subject": "n0", "object": "n1"},
        "e12": {"subject": "n1", "object": "n2"},
    }
    chosen_id, _ = get_next_qedge({"nodes": nodes, "edges": edges})
    assert chosen_id == "e12"


def test_get_next_qedge_with_self_edge():
    """Check edge choice when a pinned node also has an edge to itself.

    The self edge on the pinned node n0 (e00) is expected over the edge to
    the unpinned node n1.
    """
    nodes = {
        "n0": {"ids": ["01", "02"]},
        "n1": {},
    }
    edges = {
        "e01": {"subject": "n0", "object": "n1"},
        "e00": {"subject": "n0", "object": "n0"},
    }
    chosen_id, _ = get_next_qedge({"nodes": nodes, "edges": edges})
    assert chosen_id == "e00"


def test_get_next_qedge_multi_edges():
    """Check edge choice when two qedges join the same pair of nodes.

    n0 and n1 are linked by both e01 and e012; the first-listed of that
    pair (e01) is expected.
    """
    nodes = {
        "n0": {"ids": ["01", "02"]},
        "n1": {},
        "n2": {"ids": ["03"]},
    }
    edges = {
        "e01": {"subject": "n0", "object": "n1"},
        "e12": {"subject": "n1", "object": "n2"},
        "e012": {"subject": "n0", "object": "n1"},
    }
    chosen_id, _ = get_next_qedge({"nodes": nodes, "edges": edges})
    assert chosen_id == "e01"