diff --git a/strider/query_planner.py b/strider/query_planner.py index 3ad6763d..ce2bdb62 100644 --- a/strider/query_planner.py +++ b/strider/query_planner.py @@ -247,41 +247,44 @@ def get_next_qedge(qgraph): def get_pinnedness(qgraph, qnode_id): """Get pinnedness of each node.""" adjacency_mat = get_adjacency_matrix(qgraph) - return -compute_expected_n( + num_ids = get_num_ids(qgraph) + return -compute_log_expected_n( adjacency_mat, + num_ids, qnode_id, ) -def compute_expected_n(adjacency_mat, qnode_id, last=None, level=0): - """Compute the expected number of unique knodes bound to the specified qnode in the final results.""" - pinnedness = math.log(adjacency_mat[qnode_id][qnode_id]) +def compute_log_expected_n(adjacency_mat, num_ids, qnode_id, last=None, level=0): + """Compute the log of the expected number of unique knodes bound to the specified qnode in the final results.""" + log_expected_n = math.log(num_ids[qnode_id]) if level < 10: - for neighbor in adjacency_mat[qnode_id]: - if neighbor in (qnode_id, last): + for neighbor, num_edges in adjacency_mat[qnode_id].items(): + if neighbor == last: continue - pinnedness += min(max(compute_expected_n( + # ignore contributions >0 - edges should only _further_ constrain n + log_expected_n += num_edges * min(max(compute_log_expected_n( adjacency_mat, + num_ids, neighbor, qnode_id, level + 1, ), 0) + math.log(R / N), 0) - return pinnedness + return log_expected_n def get_adjacency_matrix(qgraph): """Get adjacency matrix.""" - A = defaultdict(lambda: defaultdict(bool)) - for qnode_id, qnode in qgraph["nodes"].items(): - ids = qnode.get("ids") - if ids is None: - num_ids = N - elif isinstance(ids, list): - num_ids = len(ids) - else: - num_ids = ids - A[qnode_id][qnode_id] = num_ids + A = defaultdict(lambda: defaultdict(int)) for qedge in qgraph["edges"].values(): - A[qedge["subject"]][qedge["object"]] = True - A[qedge["object"]][qedge["subject"]] = True + A[qedge["subject"]][qedge["object"]] += 1 + A[qedge["object"]][qedge["subject"]] += 1 return A + + +def get_num_ids(qgraph): + """Get the number of ids for each node.""" + return { + qnode_id: qnode["ids"] + for qnode_id, qnode in qgraph["nodes"].items() + }