Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add force distribution plots #46

Merged
merged 5 commits into from
Apr 13, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 69 additions & 2 deletions ipsuite/analysis/model/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,16 +230,26 @@ class ForceDecomposition(base.AnalyseProcessAtoms):

The implementation follows the method described in
https://doi.org/10.26434/chemrxiv-2022-l4tb9


Attributes
----------
wasserstein_distance: float
Compute the wasserstein distance between the distributions of the
prediced and true forces for each trans, rot, vib component.
"""

trans_forces: dict = zntrack.zn.metrics()
rot_forces: dict = zntrack.zn.metrics()
vib_forces: dict = zntrack.zn.metrics()
wasserstein_distance = zntrack.zn.metrics()

rot_force_plt = zntrack.dvc.outs(zntrack.nwd / "rot_force.png")
trans_force_plt = zntrack.dvc.outs(zntrack.nwd / "trans_force.png")
vib_force_plt = zntrack.dvc.outs(zntrack.nwd / "vib_force.png")

histogram_plt = zntrack.dvc.outs(zntrack.nwd / "histogram.png")

def get_plots(self):
fig = get_figure(
np.linalg.norm(self.true_forces["trans"], axis=-1),
Expand Down Expand Up @@ -283,21 +293,75 @@ def get_metrics(self):
np.array(self.true_forces["vib"]), np.array(self.pred_forces["vib"])
)

def get_histogram(self):
import matplotlib.pyplot as plt
from scipy.stats import wasserstein_distance

def get_rel_scalar_prod(main, relative) -> np.ndarray:
x = np.einsum("ij,ij->i", main, relative)
x /= np.linalg.norm(main, axis=-1)
return x

fig, axes = plt.subplots(4, 3, figsize=(4 * 5, 3 * 3))
fig.suptitle(
(
r"A fraction $\dfrac{\vec{a} \cdot"
r" \vec{b}}{\left|\left|\vec{a}\right|\right|_{2}} $ of $\vec{b}$ that"
r" contributes to $\vec{a}$"
),
fontsize=16,
)

self.wasserstein_distance = {}

for label, ax_ in zip(self.true_forces.keys(), axes):
self.wasserstein_distance[label] = {}
for part, ax in zip(["vib", "rot", "trans"], ax_):
data = get_rel_scalar_prod(
self.true_forces[label], self.true_forces[part]
)
true_bins = ax.hist(
data, bins=50, density=True, label=f"true {label} {part}"
)

data = get_rel_scalar_prod(
self.pred_forces[label], self.pred_forces[part]
)
pred_bins = ax.hist(
data,
bins=true_bins[1],
density=True,
alpha=0.5,
label=f"pred {label} {part}",
)
ax.legend()
self.wasserstein_distance[label][part] = wasserstein_distance(
true_bins[0], pred_bins[0]
)

fig.savefig(self.histogram_plt, bbox_inches="tight")

def run(self):
true_atoms, pred_atoms = self.get_data()
mapping = BarycenterMapping(data=None)
# TODO make the force_decomposition return full forces
# TODO check if you sum the forces they yield the full forces
# TODO make mapping a 'zn.nodes' with Mapping(species="BF4")
# maybe allow smiles and enumeration 0, 1, ...

self.true_forces = {"trans": [], "rot": [], "vib": []}
self.pred_forces = {"trans": [], "rot": [], "vib": []}
self.true_forces = {"all": [], "trans": [], "rot": [], "vib": []}
self.pred_forces = {"all": [], "trans": [], "rot": [], "vib": []}

for atom in tqdm.tqdm(true_atoms):
atom_trans_forces, atom_rot_forces, atom_vib_forces = force_decomposition(
atom, mapping
)
self.true_forces["all"].append(atom.get_forces())
self.true_forces["trans"].append(atom_trans_forces)
self.true_forces["rot"].append(atom_rot_forces)
self.true_forces["vib"].append(atom_vib_forces)

self.true_forces["all"] = np.concatenate(self.true_forces["all"])
self.true_forces["trans"] = np.concatenate(self.true_forces["trans"])
self.true_forces["rot"] = np.concatenate(self.true_forces["rot"])
self.true_forces["vib"] = np.concatenate(self.true_forces["vib"])
Expand All @@ -306,13 +370,16 @@ def run(self):
atom_trans_forces, atom_rot_forces, atom_vib_forces = force_decomposition(
atom, mapping
)
self.pred_forces["all"].append(atom.get_forces())
self.pred_forces["trans"].append(atom_trans_forces)
self.pred_forces["rot"].append(atom_rot_forces)
self.pred_forces["vib"].append(atom_vib_forces)

self.pred_forces["all"] = np.concatenate(self.pred_forces["all"])
self.pred_forces["trans"] = np.concatenate(self.pred_forces["trans"])
self.pred_forces["rot"] = np.concatenate(self.pred_forces["rot"])
self.pred_forces["vib"] = np.concatenate(self.pred_forces["vib"])

self.get_metrics()
self.get_plots()
self.get_histogram()