Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Filtering within multiedges #48

Merged
merged 10 commits into from
Jun 20, 2024
88 changes: 57 additions & 31 deletions grandcypher/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,11 +84,11 @@

return_clause : "return"i distinct_return? return_item ("," return_item)*
return_item : (entity_id | aggregation_function | entity_id "." attribute_id) ( "AS"i alias )?
alias : CNAME

aggregation_function : AGGREGATE_FUNC "(" entity_id ( "." attribute_id )? ")"
AGGREGATE_FUNC : "COUNT" | "SUM" | "AVG" | "MAX" | "MIN"
attribute_id : CNAME
alias : CNAME

distinct_return : "DISTINCT"i
limit_clause : "limit"i NUMBER
Expand Down Expand Up @@ -314,14 +314,20 @@ def _get_edge(host: nx.DiGraph, mapping, match_path, u, v):

def and_(cond_a, cond_b) -> CONDITION:
    """Combine two WHERE conditions with a logical AND.

    Both ``cond_a`` and ``cond_b`` are callables invoked as
    ``cond(match, host, return_endges)`` that return a 2-tuple
    ``(satisfied, where_results)``, where ``where_results`` is a list.

    Returns:
        A callable with the same calling convention. Its result is
        ``(satisfied_a and satisfied_b, combined)``, with ``combined`` the
        element-wise ``and`` of the two ``where_results`` lists.
    """

    def inner(match: dict, host: nx.DiGraph, return_endges: list) -> tuple:
        # Both sides are always evaluated (no short-circuit) because the
        # per-element result lists of both conditions are needed below.
        condition_a, where_a = cond_a(match, host, return_endges)
        condition_b, where_b = cond_b(match, host, return_endges)
        # NOTE: zip() stops at the shorter list, so the combined list is only
        # as long as the shorter of the two inputs.
        where_result = [a and b for a, b in zip(where_a, where_b)]
        return (condition_a and condition_b), where_result

    return inner


def or_(cond_a, cond_b):
    """Combine two WHERE conditions with a logical OR.

    Both ``cond_a`` and ``cond_b`` are callables invoked as
    ``cond(match, host, return_endges)`` that return a 2-tuple
    ``(satisfied, where_results)``, where ``where_results`` is a list.

    Returns:
        A callable with the same calling convention. Its result is
        ``(satisfied_a or satisfied_b, combined)``, with ``combined`` the
        element-wise ``or`` of the two ``where_results`` lists.
    """

    def inner(match: dict, host: nx.DiGraph, return_endges: list) -> tuple:
        # Both sides are always evaluated (no short-circuit) because the
        # per-element result lists of both conditions are needed below.
        condition_a, where_a = cond_a(match, host, return_endges)
        condition_b, where_b = cond_b(match, host, return_endges)
        # NOTE: zip() stops at the shorter list, so the combined list is only
        # as long as the shorter of the two inputs.
        where_result = [a or b for a, b in zip(where_a, where_b)]
        return (condition_a or condition_b), where_result

    return inner

Expand All @@ -339,20 +345,27 @@ def inner(
host_entity_id[0] = (match[edge_mapping[0]], match[edge_mapping[1]])
else:
raise IndexError(f"Entity {host_entity_id} not in graph.")
try:
if isinstance(host, nx.MultiDiGraph):
# if any of the relations between nodes satisfies condition, return True
r_vals = _get_entity_from_host(host, *host_entity_id)
r_vals = [r_vals] if not isinstance(r_vals, list) else r_vals
val = any(operator(r_val, value) for r_val in r_vals)
else:

if isinstance(host, nx.MultiDiGraph):
# if any of the relations between nodes satisfies condition, return True
r_vals = _get_entity_from_host(host, *host_entity_id)
r_vals = [r_vals] if not isinstance(r_vals, list) else r_vals
operator_results = []
for r_val in r_vals:
try:
operator_results.append(operator(r_val, value))
except:
operator_results.append(False)
val = any(operator_results)
else:
try:
val = operator(_get_entity_from_host(host, *host_entity_id), value)
except:
val = False

except:
val = False
if val != should_be:
return False
return True
return False, operator_results
return True, operator_results

return inner

Expand Down Expand Up @@ -397,6 +410,16 @@ def __init__(self, target_graph: nx.Graph, limit=None):
self._max_hop = 100

def _lookup(self, data_paths: List[str], offset_limit) -> Dict[str, List]:

def _filter_edge(edge, where_results):
    """Drop the entries of a multiedge that failed the WHERE condition.

    Args:
        edge: When ``where_results`` is non-empty, a sequence whose first
            element is a dict; each key of that dict is used to index
            ``where_results``.
        where_results: Indexable collection of per-entry WHERE outcomes.
            An empty (or falsy) value means no WHERE filtering applies.

    Returns:
        ``edge`` unchanged when there is nothing to filter on; otherwise a
        one-element list holding the dict of entries whose outcome is truthy.
    """
    # no where condition -> return edge
    if not where_results:
        return edge
    # exclude edge(s) from multiedge that don't satisfy the where condition;
    # truthiness (not ``is True``) so non-bool truthy results such as
    # numpy.bool_ are not silently discarded
    kept = {k: v for k, v in edge[0].items() if where_results[k]}
    return [kept]

if not data_paths:
return {}

Expand All @@ -422,7 +445,7 @@ def _lookup(self, data_paths: List[str], offset_limit) -> Dict[str, List]:
if entity_name in motif_nodes:
# We are looking for a node mapping in the target graph:

ret = (mapping[entity_name] for mapping, _ in true_matches)
ret = (mapping[0][entity_name] for mapping, _ in true_matches)
# by default, just return the node from the host graph

if entity_attribute:
Expand All @@ -436,6 +459,7 @@ def _lookup(self, data_paths: List[str], offset_limit) -> Dict[str, List]:
elif entity_name in self._paths:
ret = []
for mapping, _ in true_matches:
mapping = mapping[0]
path, nodes = [], list(mapping.values())
for x, node in enumerate(nodes):
# Edge
Expand All @@ -454,8 +478,10 @@ def _lookup(self, data_paths: List[str], offset_limit) -> Dict[str, List]:
# We are looking for an edge mapping in the target graph:
is_hop = self._motif.edges[(mapping_u, mapping_v, 0)]["__is_hop__"]
ret = (
_get_edge(
self._target_graph, mapping, match_path, mapping_u, mapping_v
_filter_edge(
_get_edge(
self._target_graph, mapping[0], match_path, mapping_u, mapping_v
), mapping[1]
)
for mapping, match_path in true_matches
)
Expand All @@ -479,14 +505,10 @@ def _lookup(self, data_paths: List[str], offset_limit) -> Dict[str, List]:
filtered_ret = []
for r in ret:

if any(
[
i.get("__labels__", None).issubset(
motif_edge_labels
)
for i in r.values()
]
):
r = {
k: v for k, v in r.items() if v.get("__labels__", None).intersection(motif_edge_labels)
}
if len(r) > 0:
filtered_ret.append(r)

ret = filtered_ret
Expand All @@ -498,7 +520,7 @@ def _lookup(self, data_paths: List[str], offset_limit) -> Dict[str, List]:
if isinstance(r, dict):
r = [r]
for el in r:
for i, v in el.items():
for i, v in enumerate(el.values()):
r_attr[(i, list(v.get("__labels__", [i]))[0])] = v.get(
entity_attribute, None
)
Expand Down Expand Up @@ -837,10 +859,14 @@ def _get_true_matches(self):
match[b] = match[a]
else: # For/else loop
# Check if match matches where condition and add
if not self._where_condition or self._where_condition(
match, self._target_graph, self._return_edges
):
self_matches.append(match)
if self._where_condition:
satisfies_where, where_results = self._where_condition(
match, self._target_graph, self._return_edges
)
else:
where_results = []
if not self._where_condition or satisfies_where:
self_matches.append((match, where_results))
self_matche_paths.append(edge_hop_map)

# Check if limit reached; stop ONLY IF we are not ordering
Expand Down
95 changes: 86 additions & 9 deletions grandcypher/test_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -1027,7 +1027,7 @@ def test_multiple_edges_specific_attribute(self):
res = GrandCypher(host).run(qry)
assert res["a.name"] == ["Alice"]
assert res["b.name"] == ["Bob"]
assert res["r.years"] == [{(0, 'colleague'): 3, (1, 'friend'): 5, (2, 'enemy'): None}] # should return None when attr is missing
assert res["r.years"] == [{(0, 'friend'): 5}]

def test_edge_directionality(self):
host = nx.MultiDiGraph()
Expand Down Expand Up @@ -1064,7 +1064,7 @@ def test_query_with_missing_edge_attribute(self):
res = GrandCypher(host).run(qry)
assert res["a.name"] == ["Alice", "Bob"]
assert res["b.name"] == ["Charlie", "Charlie"]
assert res["r.duration"] == [{(0, 'colleague'): None}, {(0, 'colleague'): 10, (1, 'mentor'): None}]
assert res["r.duration"] == [{(0, 'colleague'): None}, {(0, 'colleague'): 10}]

qry = """
MATCH (a)-[r:colleague]->(b)
Expand All @@ -1073,7 +1073,7 @@ def test_query_with_missing_edge_attribute(self):
res = GrandCypher(host).run(qry)
assert res["a.name"] == ["Alice", "Bob"]
assert res["b.name"] == ["Charlie", "Charlie"]
assert res["r.years"] == [{(0, 'colleague'): 10}, {(0, 'colleague'): None, (1, 'mentor'): 2}]
assert res["r.years"] == [{(0, 'colleague'): 10}, {(0, 'colleague'): None}]

qry = """
MATCH (a)-[r]->(b)
Expand All @@ -1085,7 +1085,7 @@ def test_query_with_missing_edge_attribute(self):
assert res["r.__labels__"] == [{(0, 'friend'): {'friend'}}, {(0, 'colleague'): {'colleague'}}, {(0, 'colleague'): {'colleague'}, (1, 'mentor'): {'mentor'}}]
assert res["r.duration"] == [{(0, 'friend'): None}, {(0, 'colleague'): None}, {(0, 'colleague'): 10, (1, 'mentor'): None}]

def test_multigraph_single_edge_where(self):
def test_multigraph_single_edge_where1(self):
host = nx.MultiDiGraph()
host.add_node("a", name="Alice", age=25)
host.add_node("b", name="Bob", age=30)
Expand All @@ -1103,9 +1103,30 @@ def test_multigraph_single_edge_where(self):
res = GrandCypher(host).run(qry)
assert res["a.name"] == ["Alice", "Bob"]
assert res["b.name"] == ["Bob", "Alice"]
assert res["r.__labels__"] == [{(0, 'friend'): {'friend'}}, {(0, 'colleague'): {'colleague'}, (1, 'mentor'): {'mentor'}}]
assert res["r.years"] == [{(0, 'friend'): 1}, {(0, 'colleague'): 2, (1, 'mentor'): 4}]
assert res["r.friendly"] == [{(0, 'friend'): 'very'}, {(0, 'colleague'): None, (1, 'mentor'): None}]
assert res["r.__labels__"] == [{(0, 'friend'): {'friend'}}, {(0, 'colleague'): {'colleague'}}]
assert res["r.years"] == [{(0, 'friend'): 1}, {(0, 'colleague'): 2}]
assert res["r.friendly"] == [{(0, 'friend'): 'very'}, {(0, 'colleague'): None}]

def test_multigraph_single_edge_where2(self):
    """A WHERE on an edge attribute returns only the parallel edge that satisfies it."""
    graph = nx.MultiDiGraph()
    graph.add_node("a", name="Alice", age=25)
    graph.add_node("b", name="Bob", age=30)
    # Insertion order matters: it determines the multiedge keys.
    edge_specs = [
        ("a", "b", {"__labels__": {"paid"}, "value": 20}),
        ("a", "b", {"__labels__": {"paid"}, "amount": 12, "date": "12th June"}),
        ("b", "a", {"__labels__": {"paid"}, "amount": 6}),
        ("b", "a", {"__labels__": {"paid"}, "value": 14}),
        ("a", "b", {"__labels__": {"friends"}, "years": 9}),
        ("a", "b", {"__labels__": {"paid"}, "amount": 40}),
    ]
    for source, target, attrs in edge_specs:
        graph.add_edge(source, target, **attrs)

    query = """
        MATCH (n)-[r:paid]->(m)
        WHERE r.amount > 12
        RETURN n.name, m.name, r.amount
        """
    result = GrandCypher(graph).run(query)
    assert result['n.name'] == ['Alice']
    assert result['m.name'] == ['Bob']
    assert result['r.amount'] == [{(0, 'paid'): 40}]

def test_multigraph_where_node_attribute(self):
host = nx.MultiDiGraph()
Expand Down Expand Up @@ -1147,7 +1168,63 @@ def test_multigraph_multiple_same_edge_labels(self):
assert res["n.name"] == ["Alice", "Bob"]
assert res["m.name"] == ["Bob", "Alice"]
# the second "paid" edge between Bob -> Alice has no "amount" attribute, so it should be None
assert res["r.amount"] == [{(0, 'paid'): 12, (1, 'friends'): None, (2, 'paid'): 40}, {(0, 'paid'): 6, (1, 'paid'): None}]
assert res["r.amount"] == [{(0, 'paid'): 12, (1, 'paid'): 40}, {(0, 'paid'): 6, (1, 'paid'): None}]

def test_order_by_edge_attribute1(self):
    """ORDER BY on an edge attribute, ascending and then descending."""
    graph = nx.MultiDiGraph()
    for node_id, name, age in (("a", "Alice", 25), ("b", "Bob", 30), ("c", "Carol", 20)):
        graph.add_node(node_id, name=name, age=age)
    for source, target, value in (("b", "a", 14), ("a", "b", 9), ("a", "b", 40)):
        graph.add_edge(source, target, __labels__={"paid"}, value=value)

    ascending_query = """
        MATCH (n)-[r]->()
        RETURN n.name, r.value
        ORDER BY r.value ASC
        """
    ascending = GrandCypher(graph).run(ascending_query)
    assert ascending['n.name'] == ['Alice', 'Bob']
    assert ascending['r.value'] == [{(0, 'paid'): 9, (1, 'paid'): 40}, {(0, 'paid'): 14}]

    descending_query = """
        MATCH (n)-[r]->()
        RETURN n.name, r.value
        ORDER BY r.value DESC
        """
    descending = GrandCypher(graph).run(descending_query)
    assert descending['n.name'] == ['Alice', 'Bob']
    assert descending['r.value'] == [{(1, 'paid'): 40, (0, 'paid'): 9}, {(0, 'paid'): 14}]

def test_order_by_edge_attribute2(self):
    """ORDER BY on an edge attribute with a missing attribute and duplicate edges."""
    graph = nx.MultiDiGraph()
    for node_id, name, age in (("a", "Alice", 25), ("b", "Bob", 30), ("c", "Carol", 20)):
        graph.add_node(node_id, name=name, age=age)

    paid = {"paid"}
    graph.add_edge("b", "a", __labels__=paid, amount=14)  # different attribute name
    # Insertion order matters: it determines the multiedge keys asserted below.
    valued_edges = [
        ("a", "b", 9),
        ("c", "b", 980),
        ("c", "b", 4),
        ("b", "c", 11),
        ("a", "b", 40),
        ("b", "a", 14),  # duplicate edge
        ("a", "b", 9),  # duplicate edge
        ("a", "b", 40),  # duplicate edge
    ]
    for source, target, value in valued_edges:
        # a fresh set per edge, as each original add_edge call built its own
        graph.add_edge(source, target, __labels__=set(paid), value=value)

    query = """
        MATCH (n)-[r]->(m)
        RETURN n.name, r.value, m.name
        ORDER BY r.value ASC
        """
    result = GrandCypher(graph).run(query)
    expected_values = [
        {(0, 'paid'): None, (1, 'paid'): 14},  # None for the different attribute edge
        {(1, 'paid'): 4, (0, 'paid'): 980},  # within edges, the attributes are ordered
        {(0, 'paid'): 9, (2, 'paid'): 9, (1, 'paid'): 40, (3, 'paid'): 40},
        {(0, 'paid'): 11},
    ]
    assert result['r.value'] == expected_values
    assert result['m.name'] == ['Alice', 'Bob', 'Bob', 'Carol']

def test_order_by_edge_attribute1(self):
host = nx.MultiDiGraph()
Expand Down Expand Up @@ -1220,7 +1297,7 @@ def test_multigraph_aggregation_function_sum(self):
RETURN n.name, m.name, SUM(r.amount)
"""
res = GrandCypher(host).run(qry)
assert res['SUM(r.amount)'] == [{'friends': 0, 'paid': 52}, {'paid': 6}]
assert res['SUM(r.amount)'] == [{'paid': 52}, {'paid': 6}]

def test_multigraph_aggregation_function_avg(self):
host = nx.MultiDiGraph()
Expand Down
Loading