Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handles both MultiDiGraph and DiGraph without up-conversion #55

Merged
merged 1 commit into from
Dec 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 43 additions & 17 deletions grandcypher/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,10 @@ def _is_edge_attr_match(
motif_edges = _get_edge_attributes(motif, motif_u, motif_v)
host_edges = _get_edge_attributes(host, host_u, host_v)

if not motif_edges or not host_edges:
# if there are no edges, they don't match
return False

# Aggregate all __labels__ into one set
motif_edges = _aggregate_edge_labels(motif_edges)
host_edges = _aggregate_edge_labels(host_edges)
Expand All @@ -262,9 +266,11 @@ def _get_edge_attributes(graph: Union[nx.Graph, nx.MultiDiGraph], u, v) -> Dict:
"""
Retrieve edge attributes from a graph, handling both Graph and MultiDiGraph.
"""
if isinstance(graph, nx.MultiDiGraph):
return graph[u][v]
return {0: graph[u][v]} # Mock single edge for DiGraph
if graph.is_multigraph():
return graph.get_edge_data(u, v)
else:
data = graph.get_edge_data(u, v)
return {0: data} # Wrap in dict to mimic MultiDiGraph structure


def _aggregate_edge_labels(edges: Dict) -> Dict:
Expand Down Expand Up @@ -294,24 +300,27 @@ def _get_entity_from_host(
return entity_name
else:
# looking for an edge:
edge_data = host.get_edge_data(*entity_name)
u, v = entity_name
edge_data = _get_edge_attributes(host, u, v)
if not edge_data:
return None # print(f"Nothing found for {entity_name} {entity_attribute}")

if entity_attribute:
# looking for edge attribute:
if isinstance(host, nx.MultiDiGraph):
return [r.get(entity_attribute, None) for r in edge_data.values()]
if host.is_multigraph():
# return a list of attribute values for all edges between u and v
return [attrs.get(entity_attribute) for attrs in edge_data.values()]
else:
return edge_data.get(entity_attribute, None)
# return the attribute value for the single edge
return edge_data[0].get(entity_attribute)
else:
return host.get_edge_data(*entity_name)
return edge_data


def _get_edge(host: nx.DiGraph, mapping, match_path, u, v):
def _get_edge(host: Union[nx.DiGraph, nx.MultiDiGraph], mapping, match_path, u, v):
edge_path = match_path[(u, v)]
return [
host.get_edge_data(mapping[u], mapping[v])
_get_edge_attributes(host, mapping[u], mapping[v])
for u, v in zip(edge_path[:-1], edge_path[1:])
]

Expand Down Expand Up @@ -353,11 +362,11 @@ def inner(
else:
raise IndexError(f"Entity {host_entity_id} not in graph.")

operator_results = []
if isinstance(host, nx.MultiDiGraph):
# if any of the relations between nodes satisfies condition, return True
r_vals = _get_entity_from_host(host, *host_entity_id)
r_vals = [r_vals] if not isinstance(r_vals, list) else r_vals
operator_results = []
for r_val in r_vals:
try:
operator_results.append(operator(r_val, value))
Expand All @@ -369,6 +378,7 @@ def inner(
val = operator(_get_entity_from_host(host, *host_entity_id), value)
except:
val = False
operator_results.append(val)

if val != should_be:
return False, operator_results
Expand Down Expand Up @@ -398,9 +408,6 @@ def _data_path_to_entity_name_attribute(data_path):
class _GrandCypherTransformer(Transformer):
def __init__(self, target_graph: nx.Graph, limit=None):
self._target_graph = target_graph
if not isinstance(self._target_graph, nx.MultiDiGraph):
self._target_graph = nx.MultiDiGraph(target_graph)
logger.warning("Converting graph to MultiDiGraph")
self._entity2alias = dict()
self._alias2entity = dict()
self._paths = []
Expand Down Expand Up @@ -754,6 +761,15 @@ def returns(self, ignore_limit=False):
for key, values in results.items()
if self._alias2entity.get(key, key) in self._return_requests
}
# HACK: convert to [None] if edge is None
for key, values in results.items():
parsed_values = []
for v in values:
if v == [{0: None}]: # edge is None
parsed_values.append([None])
else:
parsed_values.append(v)
results[key] = parsed_values

return results

Expand Down Expand Up @@ -973,8 +989,18 @@ def _edge_hop_motifs(self, motif: nx.MultiDiGraph) -> List[Tuple[nx.Graph, dict]
if motif.out_degree(n) == 0 and motif.in_degree(n) == 0:
new_motif.add_node(n, **motif.nodes[n])
motifs: List[Tuple[nx.DiGraph, dict]] = [(new_motif, {})]

if motif.is_multigraph():
edge_iter = motif.edges(keys=True)
else:
edge_iter = motif.edges(keys=False)

for u, v, k in motif.edges: # OutMultiEdgeView([('a', 'b', 0)])
for edge in edge_iter:
if motif.is_multigraph():
u, v, k = edge
else:
u, v = edge
k = 0 # Dummy key for DiGraph
new_motifs = []
min_hop = motif.edges[u, v, k]["__min_hop__"]
max_hop = motif.edges[u, v, k]["__max_hop__"]
Expand Down Expand Up @@ -1002,8 +1028,8 @@ def _edge_hop_motifs(self, motif: nx.MultiDiGraph) -> List[Tuple[nx.Graph, dict]

def _product_motifs(
self,
motifs_1: List[Tuple[nx.DiGraph, dict]],
motifs_2: List[Tuple[nx.DiGraph, dict]],
motifs_1: List[Tuple[nx.Graph, dict]],
motifs_2: List[Tuple[nx.Graph, dict]],
):
new_motifs = []
for motif_1, mapping_1 in motifs_1:
Expand Down
14 changes: 0 additions & 14 deletions grandcypher/test_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,20 +59,6 @@ def test_simple_structural_match_returns_node_attributes(self, graph_type):
assert len(returns["A.dinnertime"]) == 2


@pytest.mark.parametrize("graph_type", ACCEPTED_GRAPH_TYPES)
def test_warning_for_non_multidigraph(self, graph_type, caplog):
host = graph_type()

with caplog.at_level(logging.WARNING):
gct = GrandCypher(host)

if isinstance(host, nx.MultiDiGraph):
assert len(caplog.records) == 0
elif isinstance(host, nx.DiGraph):
assert len(caplog.records) == 1
assert caplog.records[0].levelname == "WARNING"


class TestSimpleAPI:
@pytest.mark.parametrize("graph_type", ACCEPTED_GRAPH_TYPES)
def test_simple_api(self, graph_type):
Expand Down
Loading