Skip to content

Commit

Permalink
Use same dispersion parameters for reference energies (#32)
Browse files Browse the repository at this point in the history
Minor cleanup for and re-calculation of reference energies with
TPSS0-ATM parameters for consistency.

`disp3` now also only contains the ATM energy (before it was `disp2` +
ATM).
  • Loading branch information
marvinfriede authored Oct 20, 2023
1 parent cefa86b commit e22cc0a
Show file tree
Hide file tree
Showing 7 changed files with 253 additions and 204 deletions.
15 changes: 6 additions & 9 deletions src/tad_dftd3/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,17 +139,14 @@ def dftd3(
if ref is None:
ref = reference.Reference(**dd)
if rcov is None:
rcov = data.covalent_rad_d3[numbers].type(positions.dtype).to(positions.device)
rcov = data.covalent_rad_d3[numbers].to(**dd)
if rvdw is None:
rvdw = (
data.vdw_rad_d3[numbers.unsqueeze(-1), numbers.unsqueeze(-2)]
.type(positions.dtype)
.to(positions.device)
)
rvdw = data.vdw_rad_d3[
numbers.unsqueeze(-1),
numbers.unsqueeze(-2),
].to(**dd)
if r4r2 is None:
r4r2 = (
data.sqrt_z_r4_over_r2[numbers].type(positions.dtype).to(positions.device)
)
r4r2 = data.sqrt_z_r4_over_r2[numbers].to(**dd)

cn = ncoord.coordination_number(numbers, positions, rcov, counting_function)
weights = model.weight_references(numbers, cn, ref, weighting_function)
Expand Down
13 changes: 5 additions & 8 deletions src/tad_dftd3/disp.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,7 @@ def dispersion(
if cutoff is None:
cutoff = torch.tensor(50.0, **dd)
if r4r2 is None:
r4r2 = (
data.sqrt_z_r4_over_r2[numbers].type(positions.dtype).to(positions.device)
)
r4r2 = data.sqrt_z_r4_over_r2[numbers].to(**dd)
if numbers.shape != positions.shape[:-1]:
raise ValueError(
"Shape of positions is not consistent with atomic numbers.",
Expand All @@ -122,11 +120,10 @@ def dispersion(
# three-body dispersion
if "s9" in param and param["s9"] != 0.0:
if rvdw is None:
rvdw = (
data.vdw_rad_d3[numbers.unsqueeze(-1), numbers.unsqueeze(-2)]
.type(positions.dtype)
.to(positions.device)
)
rvdw = data.vdw_rad_d3[
numbers.unsqueeze(-1),
numbers.unsqueeze(-2),
].to(**dd)

energy += dispersion3(numbers, positions, param, c6, rvdw, cutoff)

Expand Down
231 changes: 148 additions & 83 deletions tests/samples.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,10 @@


class Refs(TypedDict):
"""Format of reference values."""
"""
Format of reference values. Note that energies and gradients are calculated
with different parameters.
"""

cn: Tensor
"""Coordination number."""
Expand Down Expand Up @@ -57,8 +60,20 @@ class Record(Molecule, Refs):
"c6": torch.tensor([], dtype=torch.double),
"cn": torch.tensor([], dtype=torch.double),
"weights": torch.tensor([], dtype=torch.double),
"disp2": torch.tensor([], dtype=torch.double),
"disp3": torch.tensor([], dtype=torch.double),
"disp2": torch.tensor(
[
-1.5918418587455960e-04,
-1.5918418587455960e-04,
],
dtype=torch.double,
),
"disp3": torch.tensor(
[
+0.0000000000000000e00,
+0.0000000000000000e00,
],
dtype=torch.double,
),
"grad": torch.tensor(
[
[
Expand Down Expand Up @@ -219,21 +234,21 @@ class Record(Molecule, Refs):
),
"disp2": torch.tensor(
[
-9.2481575005393872e-004,
-3.6494949521315417e-004,
-3.6494949521315417e-004,
-3.6494949521315417e-004,
-3.6494949521315417e-004,
-9.2481575005393872e-04,
-3.6494949521315417e-04,
-3.6494949521315417e-04,
-3.6494949521315417e-04,
-3.6494949521315417e-04,
],
dtype=torch.double,
),
"disp3": torch.tensor(
[
-9.2481446570860746e-004,
-3.6487414688948653e-004,
-3.6487414688948653e-004,
-3.6487414688948653e-004,
-3.6487414688948653e-004,
+1.2843453312590819e-09,
+7.5348323667640688e-08,
+7.5348323667640688e-08,
+7.5348323667640688e-08,
+7.5348323667640688e-08,
],
dtype=torch.double,
),
Expand Down Expand Up @@ -964,43 +979,43 @@ class Record(Molecule, Refs):
),
"disp2": torch.tensor(
[
-2.8788632548321321e-003,
-6.3435979775151754e-004,
-9.6167619562274962e-004,
-7.9723260613915258e-004,
-7.9238263177385578e-004,
-7.4485995467369389e-004,
-1.0311812354479540e-003,
-1.0804678845482093e-003,
-2.1424517331896948e-003,
-5.3905710617330410e-004,
-7.3549132878459982e-004,
-2.9718856310496566e-003,
-1.9053629060228276e-003,
-1.8362475794413465e-003,
-1.7182276597931356e-003,
-4.2417715940356341e-003,
-2.8788632548321321e-03,
-6.3435979775151754e-04,
-9.6167619562274962e-04,
-7.9723260613915258e-04,
-7.9238263177385578e-04,
-7.4485995467369389e-04,
-1.0311812354479540e-03,
-1.0804678845482093e-03,
-2.1424517331896948e-03,
-5.3905710617330410e-04,
-7.3549132878459982e-04,
-2.9718856310496566e-03,
-1.9053629060228276e-03,
-1.8362475794413465e-03,
-1.7182276597931356e-03,
-4.2417715940356341e-03,
],
dtype=torch.double,
),
"disp3": torch.tensor(
[
-2.8718125389999259e-003,
-6.3328090446635918e-004,
-9.5663711211542641e-004,
-7.9370460692262154e-004,
-7.9033697856002835e-004,
-7.4037167668508294e-004,
-1.0277787758263043e-003,
-1.0733979636313967e-003,
-2.1410728848939844e-003,
-5.3484648487498051e-004,
-7.3184554571681479e-004,
-2.9622995709883419e-003,
-1.9025657858451914e-003,
-1.8324762672052280e-003,
-1.7135582283322110e-003,
-4.2406598201600847e-003,
+7.0507158322062093e-06,
+1.0788932851583596e-06,
+5.0390835073232118e-06,
+3.5279992165310452e-06,
+2.0456532138274277e-06,
+4.4882779886109463e-06,
+3.4024596216497734e-06,
+7.0699209168125984e-06,
+1.3788482957103818e-06,
+4.2106212983235953e-06,
+3.6457830677850229e-06,
+9.5860600613146586e-06,
+2.7971201776362010e-06,
+3.7713122361185403e-06,
+4.6694314609246109e-06,
+1.1117738755494003e-06,
],
dtype=torch.double,
),
Expand Down Expand Up @@ -3606,20 +3621,30 @@ class Record(Molecule, Refs):
),
"disp2": torch.tensor(
[
-3.5479912602e-04,
-8.9124281989e-05,
-8.9124287363e-05,
-8.9124287363e-05,
-1.3686794039e-04,
-3.8805575850e-04,
-8.7387460069e-05,
-8.7387464149e-05,
-8.7387460069e-05,
-1.7789829941052290e-03,
-4.2874420697641040e-04,
-4.2874425413182740e-04,
-4.2874425413182740e-04,
-6.4605081235581219e-04,
-1.8277741525957012e-03,
-4.4890931954739776e-04,
-4.4890934300120941e-04,
-4.4890931954739776e-04,
],
dtype=torch.double,
),
"disp3": torch.tensor(
[],
[
+1.5164457178775542e-07,
+3.1871289285333041e-07,
+3.1871279049093017e-07,
+3.1871279049093017e-07,
-5.9772721699589883e-07,
-3.5376082968855901e-07,
+1.4591177238904105e-07,
+1.4591163155676249e-07,
+1.4591177238904105e-07,
],
dtype=torch.double,
),
"grad": torch.tensor(
Expand Down Expand Up @@ -4965,29 +4990,48 @@ class Record(Molecule, Refs):
),
"disp2": torch.tensor(
[
-4.1551151549e-04,
-3.9770287009e-04,
-4.1552470565e-04,
-4.4246829733e-04,
-4.7527776799e-04,
-4.4258484762e-04,
-1.0637547378e-03,
-1.5452322970e-04,
-1.9695663808e-04,
-1.6184434935e-04,
-1.9703176496e-04,
-1.6183339573e-04,
-4.6648977616e-04,
-1.3764556692e-04,
-2.4555353368e-04,
-1.3535967638e-04,
-1.5719227870e-04,
-1.1675684940e-04,
-1.9420461943405458e-03,
-1.8659072210258116e-03,
-1.9421688758887014e-03,
-2.2256063318899419e-03,
-2.3963299472900094e-03,
-2.2258129538456762e-03,
-4.5810403655531691e-03,
-6.0279450821464173e-04,
-7.9994791096430059e-04,
-6.1485615934089312e-04,
-7.9989323817241818e-04,
-6.1484107713457887e-04,
-2.2996378209045958e-03,
-5.6155104045316131e-04,
-1.1544788441618554e-03,
-5.5259186314968840e-04,
-6.8597888322421800e-04,
-5.0103989808744046e-04,
],
dtype=torch.double,
),
"disp3": torch.tensor(
[],
[
-1.2978866706459067e-06,
-6.8327757407160399e-07,
-1.2942593535913288e-06,
-5.7304824129487952e-07,
-8.9195765730180898e-07,
-4.8897672215875848e-07,
-5.9620837808702434e-06,
-5.1712490636531602e-07,
+2.1379354562450553e-06,
+7.7699432620597416e-07,
+2.1956704534880581e-06,
+7.6716763665232290e-07,
-9.5275400116253198e-07,
+6.0068639199219523e-07,
-2.6385604432973588e-07,
+1.1560414358817309e-06,
-2.6528734005501400e-07,
-1.3951746669187961e-06,
],
dtype=torch.double,
),
"grad": torch.tensor([], dtype=torch.double),
Expand Down Expand Up @@ -5075,24 +5119,45 @@ class Record(Molecule, Refs):
),
"disp2": torch.tensor(
[
-1.5994749264791608e-04,
-8.4440915088634938e-05,
-8.4437982877716422e-05,
-8.4441242506727576e-05,
-1.048180025875288e-03,
-4.430683267237130e-04,
-4.430435696703567e-04,
-4.430709410870264e-04,
],
dtype=torch.double,
),
"disp3": torch.tensor(
[
-1.5993273700587451e-04,
-8.4211184002924711e-05,
-8.4208259067963809e-05,
-8.4211511421017349e-05,
1.475402588166237e-08,
2.297333064597274e-07,
2.297265476250950e-07,
2.297346486316179e-07,
],
dtype=torch.double,
),
"grad": torch.tensor(
[],
[
[
-3.091609121445480e-10,
2.958185285646392e-12,
3.762196005417977e-07,
],
[
-1.982582438864074e-05,
-5.360338422731795e-06,
-1.431825059800939e-07,
],
[
5.276022219053056e-06,
1.985418497606805e-05,
-1.170419972562382e-07,
],
[
1.455011133049982e-05,
-1.449384951152156e-05,
-1.159950973054663e-07,
],
],
dtype=torch.double,
),
"hessian": torch.tensor(
Expand Down
Loading

0 comments on commit e22cc0a

Please sign in to comment.