diff --git a/redisgraph/node.py b/redisgraph/node.py index 856d55e..3f583c8 100644 --- a/redisgraph/node.py +++ b/redisgraph/node.py @@ -11,7 +11,21 @@ def __init__(self, node_id=None, alias=None, label=None, properties=None): """ self.id = node_id self.alias = alias - self.label = label + + if isinstance(label, list): + label = [inner_label for inner_label in label if inner_label != ""] + + if label is None or label == "" or (isinstance(label, list) and len(label) == 0): + self.label = None + self.labels = None + elif isinstance(label, str): + self.label = label + self.labels = [label] + elif isinstance(label, list) and all([isinstance(inner_label, str) for inner_label in label]): + self.label = label[0] + self.labels = label + else: + raise AssertionError("label should be either None, string or a list of strings") self.properties = properties or {} def toString(self): @@ -26,8 +40,8 @@ def __str__(self): res = '(' if self.alias: res += self.alias - if self.label: - res += ':' + self.label + if self.labels: + res += ":" + ":".join(self.labels) if self.properties: props = ','.join(key+':'+str(quote_string(val)) for key, val in sorted(self.properties.items())) res += '{' + props + '}' diff --git a/redisgraph/query_result.py b/redisgraph/query_result.py index 7f9b4de..8474287 100644 --- a/redisgraph/query_result.py +++ b/redisgraph/query_result.py @@ -142,11 +142,13 @@ def parse_node(self, cell): # [[name, value type, value] X N] node_id = int(cell[0]) - label = None - if len(cell[1]) != 0: - label = self.graph.get_label(cell[1][0]) + labels = None + if len(cell[1]) > 0: + labels = [] + for inner_label in cell[1]: + labels.append(self.graph.get_label(inner_label)) properties = self.parse_entity_properties(cell[2]) - return Node(node_id=node_id, label=label, properties=properties) + return Node(node_id=node_id, label=labels, properties=properties) def parse_edge(self, cell): # Edge ID (integer), diff --git a/tests/functional/test_all.py b/tests/functional/test_all.py index ab30512..30e8040 100644 --- a/tests/functional/test_all.py +++ b/tests/functional/test_all.py @@ -359,6 +359,34 @@ def test_cache_sync(self): assert(A._relationshipTypes[0] == 'S') assert(A._relationshipTypes[1] == 'R') + def test_multi_label(self): + redis_graph = Graph('g', self.r) + + node = Node(label=['l', 'll']) + redis_graph.add_node(node) + + redis_graph.commit() + + query = 'MATCH (n) RETURN n' + + result = redis_graph.query(query) + + result_node = result.result_set[0][0] + + self.assertEqual(result_node, node) + + try: + Node(label=1) + self.assertTrue(False) + except AssertionError: + self.assertTrue(True) + + try: + Node(label=['l', 1]) + self.assertTrue(False) + except AssertionError: + self.assertTrue(True) + if __name__ == '__main__': unittest.main() diff --git a/tests/unit/test_node.py b/tests/unit/test_node.py index 6b46064..d398f93 100644 --- a/tests/unit/test_node.py +++ b/tests/unit/test_node.py @@ -11,10 +11,12 @@ def setUp(self): self.props_only = node.Node(properties={"a": "a", "b": 10}) self.no_label = node.Node(node_id=1, alias="alias", properties={"a": "a"}) + self.multi_label = node.Node(node_id=1, alias="alias", label=["l", "ll"]) def test_toString(self): self.assertEqual(self.no_args.toString(), "") self.assertEqual(self.no_props.toString(), "") + self.assertEqual(self.multi_label.toString(), "") self.assertEqual(self.props_only.toString(), '{a:"a",b:10}') self.assertEqual(self.no_label.toString(), '{a:"a"}') @@ -23,6 +25,7 @@ def test_stringify(self): self.assertEqual(str(self.no_props), "(alias:l)") self.assertEqual(str(self.props_only), '({a:"a",b:10})') self.assertEqual(str(self.no_label), '(alias{a:"a"})') + self.assertEqual(str(self.multi_label), "(alias:l:ll)") def test_comparision(self):