From 389f7b49358bf85270fc0b0856b0159859273d61 Mon Sep 17 00:00:00 2001 From: James Fletcher Date: Mon, 16 Dec 2019 17:43:29 +0000 Subject: [PATCH 1/3] Make 'solution' class mandatory when building querygraphs, make methods chainable --- kglib/kgcn/examples/diagnosis/diagnosis.py | 95 +++++++++------------ kglib/utils/graph/query/query_graph.py | 29 ++++--- kglib/utils/graph/query/query_graph_test.py | 15 ++-- 3 files changed, 66 insertions(+), 73 deletions(-) diff --git a/kglib/kgcn/examples/diagnosis/diagnosis.py b/kglib/kgcn/examples/diagnosis/diagnosis.py index a5b08eae..13e24b71 100644 --- a/kglib/kgcn/examples/diagnosis/diagnosis.py +++ b/kglib/kgcn/examples/diagnosis/diagnosis.py @@ -34,13 +34,13 @@ URI = "localhost:48555" # Existing elements in the graph are those that pre-exist in the graph, and should be predicted to continue to exist -PREEXISTS = dict(solution=0) +PREEXISTS = 0 # Candidates are neither present in the input nor in the solution, they are negative samples -CANDIDATE = dict(solution=1) +CANDIDATE = 1 # Elements to infer are the graph elements whose existence we want to predict to be true, they are positive samples -TO_INFER = dict(solution=2) +TO_INFER = 2 # Categorical Attribute types and the values of their categories CATEGORICAL_ATTRIBUTES = {'name': ['Diabetes Type II', 'Multiple Sclerosis', 'Blurred vision', 'Fatigue', 'Cigarettes', @@ -174,15 +174,13 @@ def get_query_handles(example_id): get;''') vars = p, par, ps, d, diag, n = 'p', 'par', 'ps', 'd', 'diag', 'n' - g = QueryGraph() - g.add_vars(*vars, **PREEXISTS) - g.add_role_edge(ps, p, 'child', **PREEXISTS) - g.add_role_edge(ps, par, 'parent', **PREEXISTS) - g.add_role_edge(diag, par, 'patient', **PREEXISTS) - g.add_role_edge(diag, d, 'diagnosed-disease', **PREEXISTS) - g.add_has_edge(d, n, **PREEXISTS) - - hereditary_query_graph = g + hereditary_query_graph = (QueryGraph() + .add_vars(vars, PREEXISTS) + .add_role_edge(ps, p, 'child', PREEXISTS) + .add_role_edge(ps, par, 'parent', PREEXISTS) + .add_role_edge(diag, par, 'patient', PREEXISTS) + .add_role_edge(diag, d, 'diagnosed-disease', PREEXISTS) + .add_has_edge(d, n, PREEXISTS)) # === Consumption Feature === consumption_query = inspect.cleandoc(f'''match @@ -192,14 +190,12 @@ def get_query_handles(example_id): has units-per-week $u; get;''') vars = p, s, n, c, u = 'p', 's', 'n', 'c', 'u' - g = QueryGraph() - g.add_vars(*vars, **PREEXISTS) - g.add_has_edge(s, n, **PREEXISTS) - g.add_role_edge(c, p, 'consumer', **PREEXISTS) - g.add_role_edge(c, s, 'consumed-substance', **PREEXISTS) - g.add_has_edge(c, u, **PREEXISTS) - - consumption_query_graph = g + consumption_query_graph = (QueryGraph() + .add_vars(vars, PREEXISTS) + .add_has_edge(s, n, PREEXISTS) + .add_role_edge(c, p, 'consumer', PREEXISTS) + .add_role_edge(c, s, 'consumed-substance', PREEXISTS) + .add_has_edge(c, u, PREEXISTS)) # === Age Feature === person_age_query = inspect.cleandoc(f'''match @@ -207,11 +203,9 @@ def get_query_handles(example_id): get;''') vars = p, a = 'p', 'a' - g = QueryGraph() - g.add_vars(*vars, **PREEXISTS) - g.add_has_edge(p, a, **PREEXISTS) - - person_age_query_graph = g + person_age_query_graph = (QueryGraph() + .add_vars(vars, PREEXISTS) + .add_has_edge(p, a, PREEXISTS)) # === Risk Factors Feature === risk_factor_query = inspect.cleandoc(f'''match @@ -221,12 +215,10 @@ def get_query_handles(example_id): get;''') vars = p, d, r = 'p', 'd', 'r' - g = QueryGraph() - g.add_vars(*vars, **PREEXISTS) - g.add_role_edge(r, p, 'person-at-risk', **PREEXISTS) - g.add_role_edge(r, d, 'risked-disease', **PREEXISTS) - - risk_factor_query_graph = g + risk_factor_query_graph = (QueryGraph() + .add_vars(vars, PREEXISTS) + .add_role_edge(r, p, 'person-at-risk', PREEXISTS) + .add_role_edge(r, d, 'risked-disease', PREEXISTS)) # === Diagnosis === diagnosis_query = inspect.cleandoc(f'''match @@ -239,26 +231,22 @@ def get_query_handles(example_id): get;''') vars = p, s, sn, d, dn, sp, sev, c = 'p', 's', 'sn', 'd', 'dn', 'sp', 'sev', 'c' - g = QueryGraph() - g.add_vars(*vars, **PREEXISTS) - g.add_has_edge(s, sn, **PREEXISTS) - g.add_has_edge(d, dn, **PREEXISTS) - g.add_role_edge(sp, s, 'presented-symptom', **PREEXISTS) - g.add_has_edge(sp, sev, **PREEXISTS) - g.add_role_edge(sp, p, 'symptomatic-patient', **PREEXISTS) - g.add_role_edge(c, s, 'effect', **PREEXISTS) - g.add_role_edge(c, d, 'cause', **PREEXISTS) - - base_query_graph = g - - g = copy.copy(base_query_graph) + base_query_graph = (QueryGraph() + .add_vars(vars, PREEXISTS) + .add_has_edge(s, sn, PREEXISTS) + .add_has_edge(d, dn, PREEXISTS) + .add_role_edge(sp, s, 'presented-symptom', PREEXISTS) + .add_has_edge(sp, sev, PREEXISTS) + .add_role_edge(sp, p, 'symptomatic-patient', PREEXISTS) + .add_role_edge(c, s, 'effect', PREEXISTS) + .add_role_edge(c, d, 'cause', PREEXISTS)) diag, d, p = 'diag', 'd', 'p' - g.add_vars(diag, **TO_INFER) - g.add_role_edge(diag, d, 'diagnosed-disease', **TO_INFER) - g.add_role_edge(diag, p, 'patient', **TO_INFER) - diagnosis_query_graph = g + diagnosis_query_graph = (copy.copy(base_query_graph) + .add_vars([diag], TO_INFER) + .add_role_edge(diag, d, 'diagnosed-disease', TO_INFER) + .add_role_edge(diag, p, 'patient', TO_INFER)) # === Candidate Diagnosis === candidate_diagnosis_query = inspect.cleandoc(f'''match @@ -270,13 +258,10 @@ def get_query_handles(example_id): $diag(candidate-patient: $p, candidate-diagnosed-disease: $d) isa candidate-diagnosis; get;''') - g = copy.copy(base_query_graph) - - diag, d, p = 'diag', 'd', 'p' - g.add_vars(diag, **CANDIDATE) - g.add_role_edge(diag, d, 'candidate-diagnosed-disease', **CANDIDATE) - g.add_role_edge(diag, p, 'candidate-patient', **CANDIDATE) - candidate_diagnosis_query_graph = g + candidate_diagnosis_query_graph = (copy.copy(base_query_graph) + .add_vars([diag], CANDIDATE) + .add_role_edge(diag, d, 'candidate-diagnosed-disease', CANDIDATE) + .add_role_edge(diag, p, 'candidate-patient', CANDIDATE)) return [ (diagnosis_query, lambda x: x, diagnosis_query_graph), diff --git a/kglib/utils/graph/query/query_graph.py b/kglib/utils/graph/query/query_graph.py index 0db781a0..455602a7 100644 --- a/kglib/utils/graph/query/query_graph.py +++ b/kglib/utils/graph/query/query_graph.py @@ -25,42 +25,45 @@ class QueryGraph(nx.MultiDiGraph): A custom graph to represent a query. Has additional helper methods specific to adding Graql patterns. """ - def add_vars(self, *vars, **attr): + def add_vars(self, vars, solution): """ Add Graql variables, stored as nodes in the graph Args: - *vars: String variables - **attr: Properties to be added to the data stored on each variable node + vars: String variables + solution: Indicator of the ground truth class that the variables belongs to Returns: - None + self """ for var in vars: - self.add_node(var, **attr) + self.add_node(var, solution=solution) + return self - def add_has_edge(self, owner_var, attribute_var, **attr): + def add_has_edge(self, owner_var, attribute_var, solution): """ Add a "has" edge to represent ownership of an attribute Args: owner_var: The variable of the owner attribute_var: The variable of the owned attribute - **attr: Properties to be added to the data stored on the "has" edge added + solution: Indicator of the ground truth class that the edge belongs to Returns: - None + self """ - self.add_edge(owner_var, attribute_var, type='has', **attr) + self.add_edge(owner_var, attribute_var, type='has', solution=solution) + return self - def add_role_edge(self, relation_var, roleplayer_var, role_label, **attr): + def add_role_edge(self, relation_var, roleplayer_var, role_label, solution): """ Add an edge to represent the role a variable plays in a relation Args: relation_var: The variable of the relation roleplayer_var: The variable of the roleplayer in the relation role_label: The role the roleplayer plays in the relation - **attr: Properties to be added to the data stored on the role edge added + solution: Indicator of the ground truth class that the edge belongs to Returns: - None + self """ - self.add_edge(relation_var, roleplayer_var, type=role_label, **attr) + self.add_edge(relation_var, roleplayer_var, type=role_label, solution=solution) + return self diff --git a/kglib/utils/graph/query/query_graph_test.py b/kglib/utils/graph/query/query_graph_test.py index 1d44479a..4affd1f4 100644 --- a/kglib/utils/graph/query/query_graph_test.py +++ b/kglib/utils/graph/query/query_graph_test.py @@ -24,27 +24,32 @@ class TestQueryGraph(unittest.TestCase): + def test_add_single_var_adds_variable_node_as_expected(self): + g = QueryGraph() + g.add_vars(['a'], 0) + self.assertDictEqual({'solution': 0}, g.nodes['a']) + def test_add_vars_adds_variable_nodes_as_expected(self): g = QueryGraph() - g.add_vars('a', 'b') + g.add_vars(['a', 'b'], 0) nodes = {node for node in g.nodes} self.assertSetEqual({'a', 'b'}, nodes) def test_add_has_edge_adds_edge_as_expected(self): g = QueryGraph() g.add_vars('a', 'b') - g.add_has_edge('a', 'b') + g.add_has_edge('a', 'b', 0) edges = [edge for edge in g.edges] self.assertEqual(1, len(edges)) - self.assertEqual('has', g.edges['a', 'b', 0]['type']) + self.assertDictEqual({'type': 'has', 'solution': 0}, g.edges['a', 'b', 0]) def test_add_role_edge_adds_role_as_expected(self): g = QueryGraph() g.add_vars('a', 'b') - g.add_role_edge('a', 'b', 'role') + g.add_role_edge('a', 'b', 'role_label', 1) edges = [edge for edge in g.edges] self.assertEqual(1, len(edges)) - self.assertEqual('role', g.edges['a', 'b', 0]['type']) + self.assertDictEqual({'type': 'role_label', 'solution': 1}, g.edges['a', 'b', 0]) if __name__ == "__main__": From 7081568f85c899341e7ccf54404a78151f38bf0c Mon Sep 17 00:00:00 2001 From: James Fletcher Date: Mon, 16 Dec 2019 18:34:27 +0000 Subject: [PATCH 2/3] Separate out function to obfuscate labels --- kglib/kgcn/examples/diagnosis/diagnosis.py | 16 ++++++---- .../kgcn/examples/diagnosis/diagnosis_test.py | 32 ++++++++++++++++++- 2 files changed, 41 insertions(+), 7 deletions(-) diff --git a/kglib/kgcn/examples/diagnosis/diagnosis.py b/kglib/kgcn/examples/diagnosis/diagnosis.py index 13e24b71..077f6561 100644 --- a/kglib/kgcn/examples/diagnosis/diagnosis.py +++ b/kglib/kgcn/examples/diagnosis/diagnosis.py @@ -138,12 +138,7 @@ def create_concept_graphs(example_indices, grakn_session): # Build a graph from the queries, samplers, and query graphs graph = build_graph_from_queries(graph_query_handles, tx, infer=infer) - # Remove label leakage - change type labels that indicate candidates into non-candidates - for data in multidigraph_data_iterator(graph): - for label_to_obfuscate, with_label in TYPES_AND_ROLES_TO_OBFUSCATE.items(): - if data['type'] == label_to_obfuscate: - data.update(type=with_label) - break + obfuscate_labels(graph, TYPES_AND_ROLES_TO_OBFUSCATE) graph.name = example_id graphs.append(graph) @@ -151,6 +146,15 @@ def create_concept_graphs(example_indices, grakn_session): return graphs +def obfuscate_labels(graph, types_and_roles_to_obfuscate): + # Remove label leakage - change type labels that indicate candidates into non-candidates + for data in multidigraph_data_iterator(graph): + for label_to_obfuscate, with_label in types_and_roles_to_obfuscate.items(): + if data['type'] == label_to_obfuscate: + data.update(type=with_label) + break + + def get_query_handles(example_id): """ Creates an iterable, each element containing a Graql query, a function to sample the answers, and a QueryGraph diff --git a/kglib/kgcn/examples/diagnosis/diagnosis_test.py b/kglib/kgcn/examples/diagnosis/diagnosis_test.py index 5ab91d4e..998c7b1b 100644 --- a/kglib/kgcn/examples/diagnosis/diagnosis_test.py +++ b/kglib/kgcn/examples/diagnosis/diagnosis_test.py @@ -24,8 +24,9 @@ import networkx as nx import numpy as np -from kglib.kgcn.examples.diagnosis.diagnosis import write_predictions_to_grakn +from kglib.kgcn.examples.diagnosis.diagnosis import write_predictions_to_grakn, obfuscate_labels from kglib.utils.grakn.object.thing import Thing +from kglib.utils.graph.test.case import GraphTestCase class TestWritePredictionsToGrakn(unittest.TestCase): @@ -90,5 +91,34 @@ def test_query_made_only_if_relation_wins(self): tx.commit.assert_called() +class TestObfuscateLabels(GraphTestCase): + + def test_labels_obfuscated_as_expected(self): + + graph = nx.MultiDiGraph() + + graph.add_node(0, type='person') + graph.add_node(1, type='disease') + graph.add_node(2, type='candidate-diagnosis') + + graph.add_edge(2, 0, type='candidate-patient') + graph.add_edge(2, 1, type='candidate-diagnosed-disease') + + obfuscate_labels(graph, {'candidate-diagnosis': 'diagnosis', + 'candidate-patient': 'patient', + 'candidate-diagnosed-disease': 'diagnosed-disease'}) + + expected_graph = nx.MultiDiGraph() + expected_graph.add_node(0, type='person') + expected_graph.add_node(1, type='disease') + expected_graph.add_node(2, type='diagnosis') + + expected_graph.add_edge(2, 0, type='patient') + expected_graph.add_edge(2, 1, type='diagnosed-disease') + + self.assertGraphsEqual(graph, expected_graph) + + + if __name__ == "__main__": unittest.main() From 91363777f45c2834d8cda6a5863be6aa2a730670 Mon Sep 17 00:00:00 2001 From: James Fletcher Date: Mon, 16 Dec 2019 19:53:22 +0000 Subject: [PATCH 3/3] Add dependency on GraphTestCase --- kglib/kgcn/examples/diagnosis/BUILD | 1 + 1 file changed, 1 insertion(+) diff --git a/kglib/kgcn/examples/diagnosis/BUILD b/kglib/kgcn/examples/diagnosis/BUILD index 4a075430..9c5915cf 100644 --- a/kglib/kgcn/examples/diagnosis/BUILD +++ b/kglib/kgcn/examples/diagnosis/BUILD @@ -8,6 +8,7 @@ py_test( ], deps = [ "diagnosis", + "//kglib/utils/graph/test", requirement('numpy'), requirement('networkx'), requirement('decorator'),