From 862df64a54cf98341924bea1f47c4d463336031c Mon Sep 17 00:00:00 2001
From: PythonFZ <fabian.zills@web.de>
Date: Sun, 9 Apr 2023 21:57:54 +0200
Subject: [PATCH] mrebase & force push

---
 ipsuite/analysis/model/predict.py | 74 ++++++++++++++++++++++++++++++-
 1 file changed, 72 insertions(+), 2 deletions(-)

diff --git a/ipsuite/analysis/model/predict.py b/ipsuite/analysis/model/predict.py
index b0acb86a..17cdeb17 100644
--- a/ipsuite/analysis/model/predict.py
+++ b/ipsuite/analysis/model/predict.py
@@ -241,16 +241,26 @@ class ForceDecomposition(base.AnalyseProcessAtoms):
 
     The implementation follows the method described in
     https://doi.org/10.26434/chemrxiv-2022-l4tb9
+
+
+    Attributes
+    ----------
+    wasserstein_distance: dict
+        Wasserstein distance between the true and predicted projection
+        distributions, stored as a nested dict ``{label: {part: distance}}``.
     """
 
     trans_forces: dict = zntrack.zn.metrics()
     rot_forces: dict = zntrack.zn.metrics()
     vib_forces: dict = zntrack.zn.metrics()
+    wasserstein_distance: dict = zntrack.zn.metrics()
 
     rot_force_plt = zntrack.dvc.outs(zntrack.nwd / "rot_force.png")
     trans_force_plt = zntrack.dvc.outs(zntrack.nwd / "trans_force.png")
     vib_force_plt = zntrack.dvc.outs(zntrack.nwd / "vib_force.png")
 
+    histogram_plt = zntrack.dvc.outs(zntrack.nwd / "histogram.png")
+
     def get_plots(self):
         fig = get_figure(
             np.linalg.norm(self.true_forces["trans"], axis=-1),
@@ -294,21 +304,78 @@ def get_metrics(self):
             np.array(self.true_forces["vib"]), np.array(self.pred_forces["vib"])
         )
 
+    def get_histogram(self):
+        """Plot projection histograms and store their Wasserstein distances."""
+        import matplotlib.pyplot as plt
+        from scipy.stats import wasserstein_distance
+
+        def get_rel_scalar_prod(main, relative) -> np.ndarray:
+            """Project each ``relative`` row onto the direction of ``main``."""
+            x = np.einsum("ij,ij->i", main, relative)
+            x /= np.linalg.norm(main, axis=-1)
+            return x
+
+        fig, axes = plt.subplots(4, 3, figsize=(4 * 5, 3 * 3))
+        fig.suptitle(
+            (
+                r"A fraction $\dfrac{\vec{a} \cdot"
+                r" \vec{b}}{\left|\left|\vec{a}\right|\right|_{2}} $ of $\vec{b}$ that"
+                r" contributes to $\vec{a}$"
+            ),
+            fontsize=16,
+        )
+
+        self.wasserstein_distance = {}
+
+        for label, ax_ in zip(self.true_forces.keys(), axes):
+            self.wasserstein_distance[label] = {}
+            for part, ax in zip(["vib", "rot", "trans"], ax_):
+                true_data = get_rel_scalar_prod(
+                    self.true_forces[label], self.true_forces[part]
+                )
+                true_bins = ax.hist(
+                    true_data, bins=50, density=True, label=f"true {label} {part}"
+                )
+
+                pred_data = get_rel_scalar_prod(
+                    self.pred_forces[label], self.pred_forces[part]
+                )
+                pred_bins = ax.hist(
+                    pred_data,
+                    bins=true_bins[1],
+                    density=True,
+                    alpha=0.5,
+                    label=f"pred {label} {part}",
+                )
+                ax.legend()
+                # Compare the raw samples: the bin heights are not observations.
+                self.wasserstein_distance[label][part] = wasserstein_distance(
+                    true_data, pred_data
+                )
+
+        fig.savefig(self.histogram_plt, bbox_inches="tight")
+
     def run(self):
         true_atoms, pred_atoms = self.get_data()
         mapping = BarycenterMapping(data=None)
+        # TODO make the force_decomposition return full forces
+        # TODO check if you sum the forces they yield the full forces
+        # TODO make mapping a 'zn.nodes' with Mapping(species="BF4")
+        # maybe allow smiles and enumeration 0, 1, ...
 
-        self.true_forces = {"trans": [], "rot": [], "vib": []}
-        self.pred_forces = {"trans": [], "rot": [], "vib": []}
+        self.true_forces = {"all": [], "trans": [], "rot": [], "vib": []}
+        self.pred_forces = {"all": [], "trans": [], "rot": [], "vib": []}
 
         for atom in tqdm.tqdm(true_atoms):
             atom_trans_forces, atom_rot_forces, atom_vib_forces = force_decomposition(
                 atom, mapping
             )
+            self.true_forces["all"].append(atom.get_forces())
             self.true_forces["trans"].append(atom_trans_forces)
             self.true_forces["rot"].append(atom_rot_forces)
             self.true_forces["vib"].append(atom_vib_forces)
 
+        self.true_forces["all"] = np.concatenate(self.true_forces["all"])
         self.true_forces["trans"] = np.concatenate(self.true_forces["trans"])
         self.true_forces["rot"] = np.concatenate(self.true_forces["rot"])
         self.true_forces["vib"] = np.concatenate(self.true_forces["vib"])
@@ -317,13 +384,16 @@ def run(self):
             atom_trans_forces, atom_rot_forces, atom_vib_forces = force_decomposition(
                 atom, mapping
             )
+            self.pred_forces["all"].append(atom.get_forces())
             self.pred_forces["trans"].append(atom_trans_forces)
             self.pred_forces["rot"].append(atom_rot_forces)
             self.pred_forces["vib"].append(atom_vib_forces)
 
+        self.pred_forces["all"] = np.concatenate(self.pred_forces["all"])
         self.pred_forces["trans"] = np.concatenate(self.pred_forces["trans"])
         self.pred_forces["rot"] = np.concatenate(self.pred_forces["rot"])
         self.pred_forces["vib"] = np.concatenate(self.pred_forces["vib"])
 
         self.get_metrics()
         self.get_plots()
+        self.get_histogram()