Skip to content
This repository has been archived by the owner on Nov 18, 2023. It is now read-only.

Commit

Permalink
Don't Plot Preexisting Graph Elements (#108)
Browse files Browse the repository at this point in the history
## What is the goal of this PR?

Update graph plots in line with the loss function being used (where preexisting graph elements are not penalised)

## What are the changes implemented in this PR?

- Preexisting graph elements are only drawn for the ground truth plot
- Use consistent colours and change only the opacity to indicate certainty
- Simplify the plotting code
  • Loading branch information
jmsfltchr authored Dec 3, 2019
1 parent 62ed323 commit 7724fe7
Show file tree
Hide file tree
Showing 9 changed files with 60 additions and 55 deletions.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
8 changes: 4 additions & 4 deletions kglib/kgcn/examples/diagnosis/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ This example is entirely fabricated as a demonstration for how to construct a KG

Studying the schema for this example, we have people who present symptoms, with some severity. Separately, we may know that certain symptoms can be caused by a disease. Lastly, people can be diagnosed with a disease.

![Diagnosis Schema](images/diagnosis_schema.png)
![Diagnosis Schema](.images/diagnosis_schema.png)

## Running the Example

Expand Down Expand Up @@ -93,13 +93,13 @@ You will see plots of metrics for the training process (training iteration on th
- The fraction of all graph elements predicted correctly across the dataset
- The fraction of completely solved examples (subgraphs extracted from Grakn that are solved in full)

![learning metrics](images/learning.png)
![learning metrics](.images/learning.png)

#### Visualise the Predictions

We also receive a plot of some of the predictions made on the test set.

![predictions made on test set](images/graph_snippet.png)
![predictions made on test set](.images/graph_snippet.png)

**Blue box:** Ground Truth

Expand Down Expand Up @@ -159,7 +159,7 @@ A single subgraph is extracted from Grakn by making these queries and combining

We can visualise such a subgraph by running these two queries one after the other in Grakn Workbase:

![queried subgraph](images/queried_subgraph.png)
![queried subgraph](.images/queried_subgraph.png)

You can get the relevant version of Grakn Workbase from the Assets of the [latest Workbase release](https://github.com/graknlabs/workbase/releases/latest).

Expand Down
Binary file not shown.
Binary file removed kglib/kgcn/examples/diagnosis/images/learning.png
Binary file not shown.
66 changes: 30 additions & 36 deletions kglib/kgcn/plot/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,9 @@ def plot_predictions(raw_graphs, test_values, num_processing_steps_ge, solution_
ground_truth_node_prob = target["nodes"][:, -1]
ground_truth_edge_prob = target["edges"][:, -1]

non_preexist_node_mask = mask_preexists(target["nodes"])
non_preexist_edge_mask = mask_preexists(target["edges"])

# Ground truth.
iax = j * (2 + num_steps_to_plot) + 1
ax = draw_subplot(graph, fig, pos, node_size, h, w, iax, ground_truth_node_prob, ground_truth_edge_prob, True)
Expand All @@ -118,16 +121,16 @@ def plot_predictions(raw_graphs, test_values, num_processing_steps_ge, solution_
# Prediction.
for k, outp in enumerate(output):
iax = j * (2 + num_steps_to_plot) + 2 + k
node_prob = softmax_prob_last_dim(outp["nodes"])
edge_prob = softmax_prob_last_dim(outp["edges"])
node_prob = softmax_prob_last_dim(outp["nodes"]) * non_preexist_node_mask
edge_prob = softmax_prob_last_dim(outp["edges"]) * non_preexist_edge_mask
ax = draw_subplot(graph, fig, pos, node_size, h, w, iax, node_prob, edge_prob, False)
ax.set_title("Model-predicted\nStep {:02d} / {:02d}".format(
step_indices[k] + 1, step_indices[-1] + 1))

# Class Winners
# Displays whether the class represented by the last dimension was the winner
node_prob = last_dim_was_class_winner(output[-1]["nodes"])
edge_prob = last_dim_was_class_winner(output[-1]["edges"])
node_prob = last_dim_was_class_winner(output[-1]["nodes"]) * non_preexist_node_mask
edge_prob = last_dim_was_class_winner(output[-1]["edges"]) * non_preexist_edge_mask

iax = j * (2 + num_steps_to_plot) + 2 + len(output)
ax = draw_subplot(graph, fig, pos, node_size, h, w, iax, node_prob, edge_prob, False)
Expand All @@ -146,6 +149,10 @@ def plot_predictions(raw_graphs, test_values, num_processing_steps_ge, solution_
plt.savefig(output_file, bbox_inches='tight')


def mask_preexists(arr):
return (arr[:, 0] == 0) * 1


def softmax_prob_last_dim(x):
e = np.exp(x)
return e[:, -1] / np.sum(e, axis=-1)
Expand All @@ -155,52 +162,39 @@ def last_dim_was_class_winner(x):
return (np.argmax(x, axis=-1) == 2) * 1


def above_base(val, base=0.0):
return val * (1.0 - base) + base


def element_color(gt_plot, probability, element_props):
"""
Determine the color values to use for a node/edge and its label
gt plot:
blue for existing elements, green for those to infer, red and transparent for candidates (as there could be many)
blue for existing elements, green for those to infer, red candidates
output plot:
blue for existing elements, green for those to infer, red for candidates, all with transparency
"""

existing = dict(input=1, solution=0)
to_infer = dict(input=0, solution=2)
candidate = dict(input=0, solution=1)
existing = 0
candidate = 1
to_infer = 2

solution = element_props.get('solution')

to_infer = all([element_props.get(key) == value for key, value in to_infer.items()])
candidate = all([element_props.get(key) == value for key, value in candidate.items()])
existing = all([element_props.get(key) == value for key, value in existing.items()])
color_config = {
to_infer: {'color': [0.0, 1.0, 0.0], 'gt_opacity': 1.0},
candidate: {'color': [1.0, 0.0, 0.0], 'gt_opacity': 1.0},
existing: {'color': [0.0, 0.0, 1.0], 'gt_opacity': 0.2}
}

default_gt_label_color = np.array([0.0, 0.0, 0.0, 1.0])
output_label_color = np.array([0.0, 0.0, 0.0, above_base(probability, base=0.0)])
chosen_config = color_config[solution]

if gt_plot:
if to_infer:
return dict(element=np.array([0.0, 1.0, 0.0, 1.0]), label=default_gt_label_color)
elif existing:
return dict(element=np.array([0.0, 0.0, 1.0, 1.0]), label=default_gt_label_color)
elif candidate:
return dict(element=np.array([1.0, 0.0, 0.0, 0.1]), label=np.array([0.0, 0.0, 0.0, 0.1]))
else:
raise ValueError('Node to colour did not fit any category')
opacity = chosen_config['gt_opacity']
else:
if to_infer:
return dict(element=np.array([0.0, above_base(probability), 0.0, above_base(probability, base=0.0)]),
label=output_label_color)
elif existing:
return dict(element=np.array([0.0, 0.0, above_base(probability), above_base(probability, base=0.0)]),
label=output_label_color)
elif candidate:
return dict(element=np.array([above_base(probability), 0.0, 0.0, above_base(probability, base=0.0)]),
label=output_label_color)
else:
raise ValueError('Node to colour did not fit any category')
opacity = probability

label = np.array([0.0, 0.0, 0.0] + [opacity])
color = np.array(chosen_config['color'] + [opacity])

return dict(element=color, label=label)


def draw_subplot(graph, fig, pos, node_size, h, w, iax, node_prob, edge_prob, gt_plot):
Expand Down
41 changes: 26 additions & 15 deletions kglib/kgcn/plot/plotting_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,9 @@ def test_plot_is_created(self):

graph = nx.MultiDiGraph(name=0)

existing = dict(input=1, solution=0)
to_infer = dict(input=0, solution=2)
candidate = dict(input=0, solution=1)
existing = dict(solution=0)
to_infer = dict(solution=2)
candidate = dict(solution=1)

# people
graph.add_node(0, type='person', **existing)
Expand All @@ -47,18 +47,29 @@ def test_plot_is_created(self):
graph.add_edge(2, 0, type='parent', **to_infer)
graph.add_edge(2, 1, type='child', **candidate)

graph_tuple = GraphsTuple(nodes=np.array([[1., 0., 0.],
[1., 1., 0.],
[1., 0., 1.]]),
edges=np.array([[1., 0., 0.],
[1., 1., 0.]]),
receivers=np.array([1, 2], dtype=np.int32),
senders=np.array([0, 1], dtype=np.int32),
globals=np.array([[0., 0., 0., 0., 0.]], dtype=np.float32),
n_node=np.array([3], dtype=np.int32),
n_edge=np.array([2], dtype=np.int32))

test_values = {"target": graph_tuple, "outputs": [graph_tuple for _ in range(6)]}
graph_tuple_target = GraphsTuple(nodes=np.array([[1., 0., 0.],
[0., 1., 0.],
[0., 0., 1.]]),
edges=np.array([[0., 0., 1.],
[0., 1., 0.]]),
receivers=np.array([1, 2], dtype=np.int32),
senders=np.array([0, 1], dtype=np.int32),
globals=np.array([[0., 0., 0., 0., 0.]], dtype=np.float32),
n_node=np.array([3], dtype=np.int32),
n_edge=np.array([2], dtype=np.int32))

graph_tuple_output = GraphsTuple(nodes=np.array([[1., 0., 0.],
[1., 1., 0.],
[1., 0., 1.]]),
edges=np.array([[1., 0., 0.],
[1., 1., 0.]]),
receivers=np.array([1, 2], dtype=np.int32),
senders=np.array([0, 1], dtype=np.int32),
globals=np.array([[0., 0., 0., 0., 0.]], dtype=np.float32),
n_node=np.array([3], dtype=np.int32),
n_edge=np.array([2], dtype=np.int32))

test_values = {"target": graph_tuple_target, "outputs": [graph_tuple_output for _ in range(6)]}

filename = f'./graph_{datetime.datetime.now()}.png'

Expand Down

0 comments on commit 7724fe7

Please sign in to comment.