Skip to content

Commit

Permalink
feat(moler_visualiser): render initial atom choice; improve motif/ato…
Browse files Browse the repository at this point in the history
…m choice rendering
  • Loading branch information
mmjb authored and kmaziarz committed May 16, 2022
1 parent 69ae903 commit 520ac89
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 61 deletions.
43 changes: 32 additions & 11 deletions molecule_generation/utils/moler_visualisation_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,12 @@ def render_property_data(self, prop_infos: Dict[str, PropertyPredictionInformati
pass

@abstractmethod
def render_atom_data(self, atom_infos: List[MoleculeGenerationAtomChoiceInfo]) -> None:
def render_atom_data(
self,
atom_info: MoleculeGenerationAtomChoiceInfo,
choice_descr: str,
prob_threshold: float = 0.001,
) -> None:
pass

@abstractmethod
Expand Down Expand Up @@ -132,6 +137,23 @@ def visualise_from_smiles(self, smiles: str):
node_idx
]

# First, render the initial atom choice:
first_node_choice_one_hot_labels = batch_labels["correct_first_node_type_choices"][
0
].numpy()
first_node_choice_true_type_idxs = first_node_choice_one_hot_labels.nonzero()[0]
first_node_choice_probs = tf.nn.softmax(predictions.first_node_type_logits[0, :]).numpy()

self.render_atom_data(
MoleculeGenerationAtomChoiceInfo(
node_idx=0,
true_type_idx=first_node_choice_true_type_idxs,
type_idx_to_prob=first_node_choice_probs,
),
choice_descr="initial starting point",
)

# Now loop over each step:
for step, focus_node_idx in enumerate(batch_features["focus_nodes"].numpy()):
focus_node_orig_idx = partial_node_to_orig_node_id[focus_node_idx]

Expand Down Expand Up @@ -212,15 +234,14 @@ def visualise_from_smiles(self, smiles: str):
one_hot_labels = batch_labels["correct_node_type_choices"][node_choice_idx].numpy()
true_type_idx = one_hot_labels.nonzero()[0]
self.render_atom_data(
[
MoleculeGenerationAtomChoiceInfo(
node_idx=focus_node_orig_idx + 1,
true_type_idx=true_type_idx,
type_idx_to_prob=tf.nn.softmax(
predictions.node_type_logits[node_choice_idx, :]
).numpy(),
)
]
MoleculeGenerationAtomChoiceInfo(
node_idx=focus_node_orig_idx + 1,
true_type_idx=true_type_idx,
type_idx_to_prob=tf.nn.softmax(
predictions.node_type_logits[node_choice_idx, :]
).numpy(),
),
choice_descr="next addition to partial molecule",
)

def visualise_from_samples(self, molecule_representation: np.ndarray):
Expand Down Expand Up @@ -268,4 +289,4 @@ def visualise_from_samples(self, molecule_representation: np.ndarray):
step += 1

if atom is not None:
self.render_atom_data([atom])
self.render_atom_data(atom)
41 changes: 24 additions & 17 deletions molecule_generation/visualisation/moler_visualiser_cli.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import argparse
import os
import pickle
from typing import List, Dict
from typing import Dict

import numpy as np
from rdkit import Chem
Expand All @@ -27,24 +27,31 @@ def render_property_data(self, prop_infos: Dict[str, PropertyPredictionInformati
else:
print(f" (True = {prop_info.ground_truth:7.3f})")

def render_atom_data(self, atom_infos: List[MoleculeGenerationAtomChoiceInfo]) -> None:
print(" - Node choices")
def render_atom_data(
self,
atom_info: MoleculeGenerationAtomChoiceInfo,
choice_descr: str,
prob_threshold: float = 0.001,
) -> None:
print(f" - Atom/motif choices for {choice_descr}")
if atom_info.true_type_idx is not None:
correct_choices_str = ", ".join(
[self.dataset._node_type_index_to_string[idx] for idx in atom_info.true_type_idx]
)
print(f" Correct: {correct_choices_str}")

# Skip the first type, "UNK":
num_atom_types = len(self.dataset._node_type_index_to_string)
for atom_info in atom_infos:
node_type_strs = [
f"{self.dataset._node_type_index_to_string[node_typ_idx]}:"
f" {atom_info.type_idx_to_prob[node_typ_idx]:.3f}"
# Skip the first type, "UNK":
for node_typ_idx in range(1, num_atom_types)
]
for atom_type_idx in range(1, num_atom_types):
atom_type_prob = atom_info.type_idx_to_prob[atom_type_idx]
if atom_info.true_type_idx is not None:
true_node_type_info = f" (True: {[self.dataset._node_type_index_to_string[idx] for idx in atom_info.true_type_idx]})"
else:
true_node_type_info = ""
atom_type_is_correct = atom_type_idx in atom_info.true_type_idx

print(
f" Node {atom_info.node_idx:2}{true_node_type_info}:\t{', '.join(node_type_strs)}"
)
if atom_type_prob < prob_threshold and not atom_type_is_correct:
continue

atom_type_str = self.dataset._node_type_index_to_string[atom_type_idx]
print(f" {atom_type_prob:.3f} - {atom_type_str}")

def render_attachment_point_selection_step(
self, step: int, attachment_point_info: MoleculeGenerationAttachmentPointChoiceInfo
Expand All @@ -66,7 +73,7 @@ def render_attachment_point_selection_step(
print(f" Node {candidate:2}: prob {prob:.3f} {correctness_info}")

def render_molecule_gen_start(self, mol: Chem.Mol) -> None:
print(f"= Edge Steps to create {Chem.MolToSmiles(mol)}")
print(f"= Steps to create {Chem.MolToSmiles(mol)}")

def render_molecule_gen_edge_step(
self, step: int, step_info: MoleculeGenerationEdgeChoiceInfo
Expand Down
82 changes: 49 additions & 33 deletions molecule_generation/visualisation/moler_visualiser_html.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,46 +50,62 @@ def render_property_data(self, prop_infos: Dict[str, PropertyPredictionInformati
print(f" </tr>", file=self.__out_fh)
print(f"</table>", file=self.__out_fh)

def render_atom_data(self, atom_infos: List[MoleculeGenerationAtomChoiceInfo]) -> None:
print(f"<h3>Next Atom Type</h3>", file=self.__out_fh)
print(f"<table>", file=self.__out_fh)
print(f" <tr>", file=self.__out_fh)
print(f" <th>Node Id</th>", file=self.__out_fh)
print(f" <th>True Type</th>", file=self.__out_fh)
num_atom_types = len(self.dataset._node_type_index_to_string)
# Skip the first type, "UNK":
for atom_type_idx in range(1, num_atom_types):
def render_atom_data(
self,
atom_info: MoleculeGenerationAtomChoiceInfo,
choice_descr: str,
prob_threshold: float = 0.001,
) -> None:
print(f"<h3>New Atom/Motif Choice</h3>", file=self.__out_fh)
print(
f"Selecting atom/motif for {choice_descr}, based on molecule generated so far.",
file=self.__out_fh,
)
if prob_threshold > 0:
print(
f"Only showing choices with assigned probability >= {prob_threshold}.",
file=self.__out_fh,
)

if atom_info.true_type_idx is not None:
print(f"</br>Correct choices: ", file=self.__out_fh)
print(
f" <th>{self.dataset._node_type_index_to_string[atom_type_idx]}</th>",
f", ".join(
[
self.dataset._node_type_index_to_string[idx]
for idx in atom_info.true_type_idx
]
),
file=self.__out_fh,
)
print(f"</br>", file=self.__out_fh)

# Show all possible choices, limited to those over threshold:
print(f"<table>", file=self.__out_fh)
print(f" <tr>", file=self.__out_fh)
print(f" <th>Atom/Motif</th>", file=self.__out_fh)
print(f" <th>Predicted Probability</th>", file=self.__out_fh)
print(f" </tr>", file=self.__out_fh)

for atom_info in atom_infos:
print(f" <tr>", file=self.__out_fh)
print(f" <td>{atom_info.node_idx}</td>", file=self.__out_fh)
# Skip the first type, "UNK":
num_atom_types = len(self.dataset._node_type_index_to_string)
max_prob_atom_idx = np.argmax(atom_info.type_idx_to_prob)
for atom_type_idx in range(1, num_atom_types):
atom_type_prob = atom_info.type_idx_to_prob[atom_type_idx]
if atom_info.true_type_idx is not None:
print(
f" <td><b>{[self.dataset._node_type_index_to_string[idx] for idx in atom_info.true_type_idx]}</b></td>",
file=self.__out_fh,
)
atom_type_is_correct = atom_type_idx in atom_info.true_type_idx

if atom_type_prob < prob_threshold and not atom_type_is_correct:
continue

print(f" <tr>", file=self.__out_fh)
atom_type_str = self.dataset._node_type_index_to_string[atom_type_idx]
if atom_type_idx == max_prob_atom_idx:
print(f" <td><b>{atom_type_str}</b></td>", file=self.__out_fh)
print(f" <td><b>{atom_type_prob:.3f}</b></td>", file=self.__out_fh)
else:
print(
f" <td>n/a</td>",
file=self.__out_fh,
)
max_prob_atom_idx = np.argmax(atom_info.type_idx_to_prob)
for atom_type_idx in range(1, num_atom_types):
if atom_type_idx == max_prob_atom_idx:
print(
f" <td><b>{atom_info.type_idx_to_prob[atom_type_idx]:.3f}</b></td>",
file=self.__out_fh,
)
else:
print(
f" <td>{atom_info.type_idx_to_prob[atom_type_idx]:.3f}</td>",
file=self.__out_fh,
)
print(f" <td>{atom_type_str}</td>", file=self.__out_fh)
print(f" <td>{atom_type_prob:.3f}</td>", file=self.__out_fh)
print(f" </tr>", file=self.__out_fh)
print(f"</table>", file=self.__out_fh)

Expand Down

0 comments on commit 520ac89

Please sign in to comment.