From 71f7fc38516178c4c5281e04759f34f8b10be96b Mon Sep 17 00:00:00 2001 From: kmorton Date: Wed, 10 Jul 2024 16:47:43 -0400 Subject: [PATCH] Modified probe algorithm to use pinned nodes and longest paths --- ranker/shared/ranker_obj.py | 51 +++++++++++++++++++++++++++---------- requirements.txt | 1 + 2 files changed, 38 insertions(+), 14 deletions(-) diff --git a/ranker/shared/ranker_obj.py b/ranker/shared/ranker_obj.py index ac5b5a5..47e478f 100644 --- a/ranker/shared/ranker_obj.py +++ b/ranker/shared/ranker_obj.py @@ -5,6 +5,7 @@ from collections import defaultdict from itertools import combinations, product +import scipy.sparse.csgraph import numpy as np from ranker.shared.sources import get_profile, get_source_sigmoid, get_source_weight, get_base_weight @@ -72,11 +73,14 @@ def rank(self, answers, jaccard_like=False): return answers def probes(self): - + # Identify Probes ################# # Q Graph Connectivity Matrix - q_node_ids = list(self.qgraph["nodes"].keys()) # Need to preserve order! + q_node_ids = list(self.qgraph["nodes"].keys()) # Preserve order later! + bound_q_node_ids = [k for k, v in self.qgraph["nodes"].items() if "ids" in v and v["ids"]] + + # Calculate connectivity matrix n_q_nodes = len(q_node_ids) q_conn = np.full((n_q_nodes, n_q_nodes), 0) for e in self.qgraph["edges"].values(): @@ -85,18 +89,37 @@ def probes(self): if e_sub is not None and e_obj is not None: q_conn[e_sub, e_obj] = 1 - # Determine probes based on connectivity - node_conn = np.sum(q_conn, 0) + np.sum(q_conn, 1).T - probe_nodes = [] - for conn in range(np.max(node_conn)): - is_this_conn = node_conn == (conn + 1) - probe_nodes += list(np.where(is_this_conn)[0]) - if len(probe_nodes) > 1: - break - q_probes = list(combinations(probe_nodes, 2)) - - # Convert probes back to q_node_ids - probes = [(q_node_ids[p[0]],q_node_ids[p[1]]) for p in q_probes] + if len(bound_q_node_ids) > 1: + return list(combinations(bound_q_node_ids, 2)) + + elif len(bound_q_node_ids) == 1: + the_bound_q_node_id = bound_q_node_ids[0] + + the_bound_ind = q_node_ids.index(the_bound_q_node_id) + check_inds = list(range(n_q_nodes)) + check_inds.remove(the_bound_ind) + + shortest_dist_mat = scipy.sparse.csgraph.dijkstra(q_conn) + max_dist = max(shortest_dist_mat[the_bound_ind, :]) + + other_probe_inds = np.where(shortest_dist_mat[the_bound_ind, :] == max_dist) + + return [(the_bound_q_node_id, q_node_ids[op]) for op in other_probe_inds[0]] + + else: # No bound q_nodes + + # Determine probes based on maximum node connectivity + node_conn = np.sum(q_conn, 0) + np.sum(q_conn, 1).T + probe_nodes = [] + for conn in range(np.max(node_conn)): + is_this_conn = node_conn == (conn + 1) + probe_nodes += list(np.where(is_this_conn)[0]) + if len(probe_nodes) > 1: + break + q_probes = list(combinations(probe_nodes, 2)) + + # Convert probes back to q_node_ids + return [(q_node_ids[p[0]], q_node_ids[p[1]]) for p in q_probes] return probes diff --git a/requirements.txt b/requirements.txt index b3c269b..503f018 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,6 +7,7 @@ lru-dict==1.3.0 pydantic>=1.8.1 pyyaml==5.3.1 redis==4.5.4 +scipy==1.13.1 uvicorn==0.17.6 uvloop==0.19.0 numpy==1.26.4