From 862df64a54cf98341924bea1f47c4d463336031c Mon Sep 17 00:00:00 2001
From: PythonFZ <fabian.zills@web.de>
Date: Sun, 9 Apr 2023 21:57:54 +0200
Subject: [PATCH] mrebase & force push

---
 ipsuite/analysis/model/predict.py | 71 ++++++++++++++++++++++++++++++-
 1 file changed, 69 insertions(+), 2 deletions(-)

diff --git a/ipsuite/analysis/model/predict.py b/ipsuite/analysis/model/predict.py
index b0acb86a..17cdeb17 100644
--- a/ipsuite/analysis/model/predict.py
+++ b/ipsuite/analysis/model/predict.py
@@ -241,16 +241,26 @@ class ForceDecomposition(base.AnalyseProcessAtoms):
 
     The implementation follows the method described in
     https://doi.org/10.26434/chemrxiv-2022-l4tb9
+
+
+    Attributes
+    ----------
+    wasserstein_distance: dict
+        Wasserstein distance between the distributions of the
+        predicted and true forces for each trans, rot, vib component.
     """
 
     trans_forces: dict = zntrack.zn.metrics()
     rot_forces: dict = zntrack.zn.metrics()
     vib_forces: dict = zntrack.zn.metrics()
+    wasserstein_distance = zntrack.zn.metrics()
 
     rot_force_plt = zntrack.dvc.outs(zntrack.nwd / "rot_force.png")
     trans_force_plt = zntrack.dvc.outs(zntrack.nwd / "trans_force.png")
     vib_force_plt = zntrack.dvc.outs(zntrack.nwd / "vib_force.png")
 
+    histogram_plt = zntrack.dvc.outs(zntrack.nwd / "histogram.png")
+
     def get_plots(self):
         fig = get_figure(
             np.linalg.norm(self.true_forces["trans"], axis=-1),
@@ -294,21 +304,75 @@ def get_metrics(self):
             np.array(self.true_forces["vib"]), np.array(self.pred_forces["vib"])
         )
 
+    def get_histogram(self):
+        """Plot projection histograms and store their Wasserstein distances.
+
+        For every pair of force components the scalar projection
+        ``a . b / ||a||`` is histogrammed for the true and predicted forces.
+        The figure is saved to ``histogram_plt`` and the distance between the
+        two projection samples is stored in ``wasserstein_distance``.
+        """
+        import matplotlib.pyplot as plt
+        from scipy.stats import wasserstein_distance
+
+        def projection(main, relative) -> np.ndarray:
+            # Signed length of ``relative`` along the direction of ``main``.
+            return np.einsum("ij,ij->i", main, relative) / np.linalg.norm(main, axis=-1)
+
+        fig, axes = plt.subplots(4, 3, figsize=(4 * 5, 3 * 3))
+        fig.suptitle(
+            r"A fraction $\dfrac{\vec{a} \cdot"
+            r" \vec{b}}{\left|\left|\vec{a}\right|\right|_{2}} $ of $\vec{b}$ that"
+            r" contributes to $\vec{a}$",
+            fontsize=16,
+        )
+
+        self.wasserstein_distance = {}
+        # One row of axes per key of ``true_forces``, one column per component.
+        for label, ax_ in zip(self.true_forces, axes):
+            self.wasserstein_distance[label] = {}
+            for part, ax in zip(["vib", "rot", "trans"], ax_):
+                true = projection(self.true_forces[label], self.true_forces[part])
+                pred = projection(self.pred_forces[label], self.pred_forces[part])
+                _, edges, _ = ax.hist(
+                    true, bins=50, density=True, label=f"true {label} {part}"
+                )
+                # Reuse the edges of the true histogram so both share one grid.
+                ax.hist(
+                    pred,
+                    bins=edges,
+                    density=True,
+                    alpha=0.5,
+                    label=f"pred {label} {part}",
+                )
+                ax.legend()
+                # Compare the projected samples themselves, not the bin heights.
+                self.wasserstein_distance[label][part] = wasserstein_distance(true, pred)
+
+        fig.savefig(self.histogram_plt, bbox_inches="tight")
+        plt.close(fig)  # pyplot keeps figures registered until they are closed
+
     def run(self):
         true_atoms, pred_atoms = self.get_data()
         mapping = BarycenterMapping(data=None)
+        # TODO make the force_decomposition return full forces
+        # TODO check if you sum the forces they yield the full forces
+        # TODO make mapping a 'zn.nodes' with Mapping(species="BF4")
+        #  maybe allow smiles and enumeration 0, 1, ...
 
-        self.true_forces = {"trans": [], "rot": [], "vib": []}
-        self.pred_forces = {"trans": [], "rot": [], "vib": []}
+        self.true_forces = {"all": [], "trans": [], "rot": [], "vib": []}
+        self.pred_forces = {"all": [], "trans": [], "rot": [], "vib": []}
 
         for atom in tqdm.tqdm(true_atoms):
             atom_trans_forces, atom_rot_forces, atom_vib_forces = force_decomposition(
                 atom, mapping
             )
+            self.true_forces["all"].append(atom.get_forces())
             self.true_forces["trans"].append(atom_trans_forces)
             self.true_forces["rot"].append(atom_rot_forces)
             self.true_forces["vib"].append(atom_vib_forces)
 
+        self.true_forces["all"] = np.concatenate(self.true_forces["all"])
         self.true_forces["trans"] = np.concatenate(self.true_forces["trans"])
         self.true_forces["rot"] = np.concatenate(self.true_forces["rot"])
         self.true_forces["vib"] = np.concatenate(self.true_forces["vib"])
@@ -317,13 +381,16 @@ def run(self):
             atom_trans_forces, atom_rot_forces, atom_vib_forces = force_decomposition(
                 atom, mapping
             )
+            self.pred_forces["all"].append(atom.get_forces())
             self.pred_forces["trans"].append(atom_trans_forces)
             self.pred_forces["rot"].append(atom_rot_forces)
             self.pred_forces["vib"].append(atom_vib_forces)
 
+        self.pred_forces["all"] = np.concatenate(self.pred_forces["all"])
         self.pred_forces["trans"] = np.concatenate(self.pred_forces["trans"])
         self.pred_forces["rot"] = np.concatenate(self.pred_forces["rot"])
         self.pred_forces["vib"] = np.concatenate(self.pred_forces["vib"])
 
         self.get_metrics()
         self.get_plots()
+        self.get_histogram()