Skip to content
This repository has been archived by the owner on Nov 18, 2023. It is now read-only.

Simplify QueryGraph interface #118

Merged
merged 3 commits into from
Dec 16, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions kglib/kgcn/examples/diagnosis/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ py_test(
],
deps = [
"diagnosis",
"//kglib/utils/graph/test",
requirement('numpy'),
requirement('networkx'),
requirement('decorator'),
Expand Down
111 changes: 50 additions & 61 deletions kglib/kgcn/examples/diagnosis/diagnosis.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,13 @@
URI = "localhost:48555"

# Existing elements in the graph are those that pre-exist in the graph, and should be predicted to continue to exist
PREEXISTS = dict(solution=0)
PREEXISTS = 0

# Candidates are neither present in the input nor in the solution, they are negative samples
CANDIDATE = dict(solution=1)
CANDIDATE = 1

# Elements to infer are the graph elements whose existence we want to predict to be true, they are positive samples
TO_INFER = dict(solution=2)
TO_INFER = 2

# Categorical Attribute types and the values of their categories
CATEGORICAL_ATTRIBUTES = {'name': ['Diabetes Type II', 'Multiple Sclerosis', 'Blurred vision', 'Fatigue', 'Cigarettes',
Expand Down Expand Up @@ -138,19 +138,23 @@ def create_concept_graphs(example_indices, grakn_session):
# Build a graph from the queries, samplers, and query graphs
graph = build_graph_from_queries(graph_query_handles, tx, infer=infer)

# Remove label leakage - change type labels that indicate candidates into non-candidates
for data in multidigraph_data_iterator(graph):
for label_to_obfuscate, with_label in TYPES_AND_ROLES_TO_OBFUSCATE.items():
if data['type'] == label_to_obfuscate:
data.update(type=with_label)
break
obfuscate_labels(graph, TYPES_AND_ROLES_TO_OBFUSCATE)

graph.name = example_id
graphs.append(graph)

return graphs


def obfuscate_labels(graph, types_and_roles_to_obfuscate):
# Remove label leakage - change type labels that indicate candidates into non-candidates
for data in multidigraph_data_iterator(graph):
for label_to_obfuscate, with_label in types_and_roles_to_obfuscate.items():
if data['type'] == label_to_obfuscate:
data.update(type=with_label)
break


def get_query_handles(example_id):
"""
Creates an iterable, each element containing a Graql query, a function to sample the answers, and a QueryGraph
Expand All @@ -174,15 +178,13 @@ def get_query_handles(example_id):
get;''')

vars = p, par, ps, d, diag, n = 'p', 'par', 'ps', 'd', 'diag', 'n'
g = QueryGraph()
g.add_vars(*vars, **PREEXISTS)
g.add_role_edge(ps, p, 'child', **PREEXISTS)
g.add_role_edge(ps, par, 'parent', **PREEXISTS)
g.add_role_edge(diag, par, 'patient', **PREEXISTS)
g.add_role_edge(diag, d, 'diagnosed-disease', **PREEXISTS)
g.add_has_edge(d, n, **PREEXISTS)

hereditary_query_graph = g
hereditary_query_graph = (QueryGraph()
.add_vars(vars, PREEXISTS)
.add_role_edge(ps, p, 'child', PREEXISTS)
.add_role_edge(ps, par, 'parent', PREEXISTS)
.add_role_edge(diag, par, 'patient', PREEXISTS)
.add_role_edge(diag, d, 'diagnosed-disease', PREEXISTS)
.add_has_edge(d, n, PREEXISTS))

# === Consumption Feature ===
consumption_query = inspect.cleandoc(f'''match
Expand All @@ -192,26 +194,22 @@ def get_query_handles(example_id):
has units-per-week $u; get;''')

vars = p, s, n, c, u = 'p', 's', 'n', 'c', 'u'
g = QueryGraph()
g.add_vars(*vars, **PREEXISTS)
g.add_has_edge(s, n, **PREEXISTS)
g.add_role_edge(c, p, 'consumer', **PREEXISTS)
g.add_role_edge(c, s, 'consumed-substance', **PREEXISTS)
g.add_has_edge(c, u, **PREEXISTS)

consumption_query_graph = g
consumption_query_graph = (QueryGraph()
.add_vars(vars, PREEXISTS)
.add_has_edge(s, n, PREEXISTS)
.add_role_edge(c, p, 'consumer', PREEXISTS)
.add_role_edge(c, s, 'consumed-substance', PREEXISTS)
.add_has_edge(c, u, PREEXISTS))

# === Age Feature ===
person_age_query = inspect.cleandoc(f'''match
$p isa person, has example-id {example_id}, has age $a;
get;''')

vars = p, a = 'p', 'a'
g = QueryGraph()
g.add_vars(*vars, **PREEXISTS)
g.add_has_edge(p, a, **PREEXISTS)

person_age_query_graph = g
person_age_query_graph = (QueryGraph()
.add_vars(vars, PREEXISTS)
.add_has_edge(p, a, PREEXISTS))

# === Risk Factors Feature ===
risk_factor_query = inspect.cleandoc(f'''match
Expand All @@ -221,12 +219,10 @@ def get_query_handles(example_id):
get;''')

vars = p, d, r = 'p', 'd', 'r'
g = QueryGraph()
g.add_vars(*vars, **PREEXISTS)
g.add_role_edge(r, p, 'person-at-risk', **PREEXISTS)
g.add_role_edge(r, d, 'risked-disease', **PREEXISTS)

risk_factor_query_graph = g
risk_factor_query_graph = (QueryGraph()
.add_vars(vars, PREEXISTS)
.add_role_edge(r, p, 'person-at-risk', PREEXISTS)
.add_role_edge(r, d, 'risked-disease', PREEXISTS))

# === Diagnosis ===
diagnosis_query = inspect.cleandoc(f'''match
Expand All @@ -239,26 +235,22 @@ def get_query_handles(example_id):
get;''')

vars = p, s, sn, d, dn, sp, sev, c = 'p', 's', 'sn', 'd', 'dn', 'sp', 'sev', 'c'
g = QueryGraph()
g.add_vars(*vars, **PREEXISTS)
g.add_has_edge(s, sn, **PREEXISTS)
g.add_has_edge(d, dn, **PREEXISTS)
g.add_role_edge(sp, s, 'presented-symptom', **PREEXISTS)
g.add_has_edge(sp, sev, **PREEXISTS)
g.add_role_edge(sp, p, 'symptomatic-patient', **PREEXISTS)
g.add_role_edge(c, s, 'effect', **PREEXISTS)
g.add_role_edge(c, d, 'cause', **PREEXISTS)

base_query_graph = g

g = copy.copy(base_query_graph)
base_query_graph = (QueryGraph()
.add_vars(vars, PREEXISTS)
.add_has_edge(s, sn, PREEXISTS)
.add_has_edge(d, dn, PREEXISTS)
.add_role_edge(sp, s, 'presented-symptom', PREEXISTS)
.add_has_edge(sp, sev, PREEXISTS)
.add_role_edge(sp, p, 'symptomatic-patient', PREEXISTS)
.add_role_edge(c, s, 'effect', PREEXISTS)
.add_role_edge(c, d, 'cause', PREEXISTS))

diag, d, p = 'diag', 'd', 'p'
g.add_vars(diag, **TO_INFER)
g.add_role_edge(diag, d, 'diagnosed-disease', **TO_INFER)
g.add_role_edge(diag, p, 'patient', **TO_INFER)

diagnosis_query_graph = g
diagnosis_query_graph = (copy.copy(base_query_graph)
.add_vars([diag], TO_INFER)
.add_role_edge(diag, d, 'diagnosed-disease', TO_INFER)
.add_role_edge(diag, p, 'patient', TO_INFER))

# === Candidate Diagnosis ===
candidate_diagnosis_query = inspect.cleandoc(f'''match
Expand All @@ -270,13 +262,10 @@ def get_query_handles(example_id):
$diag(candidate-patient: $p, candidate-diagnosed-disease: $d) isa candidate-diagnosis;
get;''')

g = copy.copy(base_query_graph)

diag, d, p = 'diag', 'd', 'p'
g.add_vars(diag, **CANDIDATE)
g.add_role_edge(diag, d, 'candidate-diagnosed-disease', **CANDIDATE)
g.add_role_edge(diag, p, 'candidate-patient', **CANDIDATE)
candidate_diagnosis_query_graph = g
candidate_diagnosis_query_graph = (copy.copy(base_query_graph)
.add_vars([diag], CANDIDATE)
.add_role_edge(diag, d, 'candidate-diagnosed-disease', CANDIDATE)
.add_role_edge(diag, p, 'candidate-patient', CANDIDATE))

return [
(diagnosis_query, lambda x: x, diagnosis_query_graph),
Expand Down
32 changes: 31 additions & 1 deletion kglib/kgcn/examples/diagnosis/diagnosis_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,9 @@
import networkx as nx
import numpy as np

from kglib.kgcn.examples.diagnosis.diagnosis import write_predictions_to_grakn
from kglib.kgcn.examples.diagnosis.diagnosis import write_predictions_to_grakn, obfuscate_labels
from kglib.utils.grakn.object.thing import Thing
from kglib.utils.graph.test.case import GraphTestCase


class TestWritePredictionsToGrakn(unittest.TestCase):
Expand Down Expand Up @@ -90,5 +91,34 @@ def test_query_made_only_if_relation_wins(self):
tx.commit.assert_called()


class TestObfuscateLabels(GraphTestCase):

def test_labels_obfuscated_as_expected(self):

graph = nx.MultiDiGraph()

graph.add_node(0, type='person')
graph.add_node(1, type='disease')
graph.add_node(2, type='candidate-diagnosis')

graph.add_edge(2, 0, type='candidate-patient')
graph.add_edge(2, 1, type='candidate-diagnosed-disease')

obfuscate_labels(graph, {'candidate-diagnosis': 'diagnosis',
'candidate-patient': 'patient',
'candidate-diagnosed-disease': 'diagnosed-disease'})

expected_graph = nx.MultiDiGraph()
expected_graph.add_node(0, type='person')
expected_graph.add_node(1, type='disease')
expected_graph.add_node(2, type='diagnosis')

expected_graph.add_edge(2, 0, type='patient')
expected_graph.add_edge(2, 1, type='diagnosed-disease')

self.assertGraphsEqual(graph, expected_graph)



if __name__ == "__main__":
unittest.main()
29 changes: 16 additions & 13 deletions kglib/utils/graph/query/query_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,42 +25,45 @@ class QueryGraph(nx.MultiDiGraph):
A custom graph to represent a query. Has additional helper methods specific to adding Graql patterns.
"""

def add_vars(self, *vars, **attr):
def add_vars(self, vars, solution):
"""
Add Graql variables, stored as nodes in the graph
Args:
*vars: String variables
**attr: Properties to be added to the data stored on each variable node
vars: String variables
solution: Indicator of the ground truth class that the variables belongs to

Returns:
None
self
"""
for var in vars:
self.add_node(var, **attr)
self.add_node(var, solution=solution)
return self

def add_has_edge(self, owner_var, attribute_var, **attr):
def add_has_edge(self, owner_var, attribute_var, solution):
"""
Add a "has" edge to represent ownership of an attribute
Args:
owner_var: The variable of the owner
attribute_var: The variable of the owned attribute
**attr: Properties to be added to the data stored on the "has" edge added
solution: Indicator of the ground truth class that the edge belongs to

Returns:
None
self
"""
self.add_edge(owner_var, attribute_var, type='has', **attr)
self.add_edge(owner_var, attribute_var, type='has', solution=solution)
return self

def add_role_edge(self, relation_var, roleplayer_var, role_label, **attr):
def add_role_edge(self, relation_var, roleplayer_var, role_label, solution):
"""
Add an edge to represent the role a variable plays in a relation
Args:
relation_var: The variable of the relation
roleplayer_var: The variable of the roleplayer in the relation
role_label: The role the roleplayer plays in the relation
**attr: Properties to be added to the data stored on the role edge added
solution: Indicator of the ground truth class that the edge belongs to

Returns:
None
self
"""
self.add_edge(relation_var, roleplayer_var, type=role_label, **attr)
self.add_edge(relation_var, roleplayer_var, type=role_label, solution=solution)
return self
15 changes: 10 additions & 5 deletions kglib/utils/graph/query/query_graph_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,27 +24,32 @@

class TestQueryGraph(unittest.TestCase):

def test_add_single_var_adds_variable_node_as_expected(self):
g = QueryGraph()
g.add_vars(['a'], 0)
self.assertDictEqual({'solution': 0}, g.nodes['a'])

def test_add_vars_adds_variable_nodes_as_expected(self):
g = QueryGraph()
g.add_vars('a', 'b')
g.add_vars(['a', 'b'], 0)
nodes = {node for node in g.nodes}
self.assertSetEqual({'a', 'b'}, nodes)

def test_add_has_edge_adds_edge_as_expected(self):
g = QueryGraph()
g.add_vars('a', 'b')
g.add_has_edge('a', 'b')
g.add_has_edge('a', 'b', 0)
edges = [edge for edge in g.edges]
self.assertEqual(1, len(edges))
self.assertEqual('has', g.edges['a', 'b', 0]['type'])
self.assertDictEqual({'type': 'has', 'solution': 0}, g.edges['a', 'b', 0])

def test_add_role_edge_adds_role_as_expected(self):
g = QueryGraph()
g.add_vars('a', 'b')
g.add_role_edge('a', 'b', 'role')
g.add_role_edge('a', 'b', 'role_label', 1)
edges = [edge for edge in g.edges]
self.assertEqual(1, len(edges))
self.assertEqual('role', g.edges['a', 'b', 0]['type'])
self.assertDictEqual({'type': 'role_label', 'solution': 1}, g.edges['a', 'b', 0])


if __name__ == "__main__":
Expand Down