Skip to content

Commit

Permalink
Modified probe algorithm to use pinned nodes and longest paths
Browse files Browse the repository at this point in the history
  • Loading branch information
kennethmorton committed Jul 10, 2024
1 parent a5f02c7 commit 71f7fc3
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 14 deletions.
51 changes: 37 additions & 14 deletions ranker/shared/ranker_obj.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from collections import defaultdict
from itertools import combinations, product

import scipy.sparse.csgraph
import numpy as np

from ranker.shared.sources import get_profile, get_source_sigmoid, get_source_weight, get_base_weight
Expand Down Expand Up @@ -72,11 +73,14 @@ def rank(self, answers, jaccard_like=False):
return answers

def probes(self):

# Identify Probes
#################
# Q Graph Connectivity Matrix
q_node_ids = list(self.qgraph["nodes"].keys()) # Need to preserve order!
q_node_ids = list(self.qgraph["nodes"].keys()) # Preserve order later!
bound_q_node_ids = [k for k, v in self.qgraph["nodes"].items() if "ids" in v and v["ids"]]

# Calculate connectivity matrix
n_q_nodes = len(q_node_ids)
q_conn = np.full((n_q_nodes, n_q_nodes), 0)
for e in self.qgraph["edges"].values():
Expand All @@ -85,18 +89,37 @@ def probes(self):
if e_sub is not None and e_obj is not None:
q_conn[e_sub, e_obj] = 1

# Determine probes based on connectivity
node_conn = np.sum(q_conn, 0) + np.sum(q_conn, 1).T
probe_nodes = []
for conn in range(np.max(node_conn)):
is_this_conn = node_conn == (conn + 1)
probe_nodes += list(np.where(is_this_conn)[0])
if len(probe_nodes) > 1:
break
q_probes = list(combinations(probe_nodes, 2))

# Convert probes back to q_node_ids
probes = [(q_node_ids[p[0]],q_node_ids[p[1]]) for p in q_probes]
if len(bound_q_node_ids) > 1:
return list(combinations(bound_q_node_ids, 2))

elif len(bound_q_node_ids) == 1:
the_bound_q_node_id = bound_q_node_ids[0]

the_bound_ind = q_node_ids.index(the_bound_q_node_id)
check_inds = list(range(n_q_nodes))
check_inds.remove(the_bound_ind)

shortest_dist_mat = scipy.sparse.csgraph.dijkstra(q_conn)
max_dist = max(shortest_dist_mat[the_bound_ind, :])

other_probe_inds = np.where(shortest_dist_mat[the_bound_ind, :] == max_dist)

return [(the_bound_q_node_id, q_node_ids[op]) for op in other_probe_inds[0]]

else: # No bound q_nodes

# Determine probes from the least-connected nodes (lowest non-zero degree first)
node_conn = np.sum(q_conn, 0) + np.sum(q_conn, 1).T
probe_nodes = []
for conn in range(np.max(node_conn)):
is_this_conn = node_conn == (conn + 1)
probe_nodes += list(np.where(is_this_conn)[0])
if len(probe_nodes) > 1:
break
q_probes = list(combinations(probe_nodes, 2))

# Convert probes back to q_node_ids
return [(q_node_ids[p[0]], q_node_ids[p[1]]) for p in q_probes]

return probes

Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ lru-dict==1.3.0
pydantic>=1.8.1
pyyaml==5.3.1
redis==4.5.4
scipy==1.13.1
uvicorn==0.17.6
uvloop==0.19.0
numpy==1.26.4
Expand Down

0 comments on commit 71f7fc3

Please sign in to comment.