From a7fecf8db6251f3caace5c8e482f1c09229688cc Mon Sep 17 00:00:00 2001 From: Stefan Doerr Date: Mon, 12 Feb 2024 17:23:19 +0200 Subject: [PATCH] reduce precision error in energies by subtracting reference energies in float64 before converting to float32 --- torchmdnet/datasets/memdataset.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/torchmdnet/datasets/memdataset.py b/torchmdnet/datasets/memdataset.py index 329d076fe..76549dfee 100644 --- a/torchmdnet/datasets/memdataset.py +++ b/torchmdnet/datasets/memdataset.py @@ -240,11 +240,10 @@ def get(self, idx): props = {} if "y" in self.properties: - props["y"] = pt.tensor(self.y_mm[idx], dtype=pt.float32).view( - 1, 1 - ) # It would be better to use float64, but the trainer complains + y = self.y_mm[idx] if self.remove_ref_energy: - props["y"] -= self.compute_reference_energy(z) + y -= self.compute_reference_energy(z) + props["y"] = pt.tensor(y, dtype=pt.float32).view(1, 1) if "neg_dy" in self.properties: props["neg_dy"] = pt.tensor(self.neg_dy_mm[atoms], dtype=pt.float32) if "q" in self.properties: