From fb0a3c63f0021b1a2d48ac7dcfe86a06e98fa2ba Mon Sep 17 00:00:00 2001 From: marvinfriede <51965259+marvinfriede@users.noreply.github.com> Date: Tue, 16 May 2023 18:11:03 +0200 Subject: [PATCH] Replace `new_tensor` --- src/tad_dftd3/damping/atm.py | 6 ++++-- src/tad_dftd3/damping/rational.py | 6 ++++-- src/tad_dftd3/disp.py | 22 ++++++++++++++-------- src/tad_dftd3/model.py | 2 +- src/tad_dftd3/ncoord.py | 9 +++++---- 5 files changed, 28 insertions(+), 17 deletions(-) diff --git a/src/tad_dftd3/damping/atm.py b/src/tad_dftd3/damping/atm.py index 0a906b0..de1bb41 100644 --- a/src/tad_dftd3/damping/atm.py +++ b/src/tad_dftd3/damping/atm.py @@ -60,6 +60,8 @@ def dispersion_atm( Tensor Atom-resolved ATM dispersion energy. """ + dd = {"device": positions.device, "dtype": positions.dtype} + s9 = s9.type(positions.dtype).to(positions.device) rs9 = rs9.type(positions.dtype).to(positions.device) alp = alp.type(positions.dtype).to(positions.device) @@ -85,7 +87,7 @@ def dispersion_atm( torch.cdist( positions, positions, p=2, compute_mode="use_mm_for_euclid_dist" ), - positions.new_tensor(torch.finfo(positions.dtype).eps), + torch.tensor(torch.finfo(positions.dtype).eps, **dd), ), 2.0, ) @@ -107,7 +109,7 @@ def dispersion_atm( * (r2jk <= cutoff2) * (r2jk <= cutoff2), 0.375 * s / r5 + 1.0 / r3, - positions.new_tensor(0.0), + torch.tensor(0.0, **dd), ) energy = ang * fdamp * c9 diff --git a/src/tad_dftd3/damping/rational.py b/src/tad_dftd3/damping/rational.py index b9a9d26..36b2ac5 100644 --- a/src/tad_dftd3/damping/rational.py +++ b/src/tad_dftd3/damping/rational.py @@ -43,6 +43,8 @@ def rational_damping( Tensor Values of the damping function. """ - a1 = param.get("a1", distances.new_tensor(defaults.A1)) - a2 = param.get("a2", distances.new_tensor(defaults.A1)) + dd = {"device": distances.device, "dtype": distances.dtype} + + a1 = param.get("a1", torch.tensor(defaults.A1, **dd)) + a2 = param.get("a2", torch.tensor(defaults.A2, **dd)) return 1.0 / (distances.pow(order) + (a1 * torch.sqrt(qq) + a2).pow(order)) diff --git a/src/tad_dftd3/disp.py b/src/tad_dftd3/disp.py index 4f4ad3e..c945b5b 100644 --- a/src/tad_dftd3/disp.py +++ b/src/tad_dftd3/disp.py @@ -92,8 +92,10 @@ def dispersion( Damping function evaluate distance dependent contributions. Additional arguments are passed through to the function. """ + dd = {"device": positions.device, "dtype": positions.dtype} + if cutoff is None: - cutoff = positions.new_tensor(50.0) + cutoff = torch.tensor(50.0, **dd) if r4r2 is None: r4r2 = ( data.sqrt_z_r4_over_r2[numbers].type(positions.dtype).to(positions.device) @@ -155,11 +157,13 @@ def dispersion2( Damping function evaluate distance dependent contributions. Additional arguments are passed through to the function. """ + dd = {"device": positions.device, "dtype": positions.dtype} + mask = real_pairs(numbers, diagonal=False) distances = torch.where( mask, torch.cdist(positions, positions, p=2, compute_mode="use_mm_for_euclid_dist"), - positions.new_tensor(torch.finfo(positions.dtype).eps), + torch.tensor(torch.finfo(positions.dtype).eps, **dd), ) qq = 3 * r4r2.unsqueeze(-1) * r4r2.unsqueeze(-2) @@ -168,19 +172,19 @@ def dispersion2( t6 = torch.where( mask * (distances <= cutoff), damping_function(6, distances, qq, param, **kwargs), - positions.new_tensor(0.0), + torch.tensor(0.0, **dd), ) t8 = torch.where( mask * (distances <= cutoff), damping_function(8, distances, qq, param, **kwargs), - positions.new_tensor(0.0), + torch.tensor(0.0, **dd), ) e6 = -0.5 * torch.sum(c6 * t6, dim=-1) e8 = -0.5 * torch.sum(c8 * t8, dim=-1) - s6 = param.get("s6", positions.new_tensor(defaults.S6)) - s8 = param.get("s8", positions.new_tensor(defaults.S8)) + s6 = param.get("s6", torch.tensor(defaults.S6, **dd)) + s8 = param.get("s8", torch.tensor(defaults.S8, **dd)) return s6 * e6 + s8 * e8 @@ -220,8 +224,10 @@ def dispersion3( Tensor Atom-resolved three-body dispersion energy. """ - alp = param.get("alp", positions.new_tensor(14.0)) - s9 = param.get("s9", positions.new_tensor(14.0)) + dd = {"device": positions.device, "dtype": positions.dtype} + + alp = param.get("alp", torch.tensor(14.0, **dd)) + s9 = param.get("s9", torch.tensor(1.0, **dd)) rs9 = rs9.type(positions.dtype).to(positions.device) return dispersion_atm(numbers, positions, c6, rvdw, cutoff, s9, rs9, alp) diff --git a/src/tad_dftd3/model.py b/src/tad_dftd3/model.py index 453c37d..dccd9ae 100644 --- a/src/tad_dftd3/model.py +++ b/src/tad_dftd3/model.py @@ -129,7 +129,7 @@ def weight_references( weights = torch.where( mask, weighting_function(reference.cn[numbers] - cn.unsqueeze(-1), **kwargs), - cn.new_tensor(0.0), + torch.tensor(0.0, device=cn.device, dtype=cn.dtype), ) norms = torch.add(torch.sum(weights, dim=-1), epsilon) diff --git a/src/tad_dftd3/ncoord.py b/src/tad_dftd3/ncoord.py index a84aa79..64f796a 100644 --- a/src/tad_dftd3/ncoord.py +++ b/src/tad_dftd3/ncoord.py @@ -116,8 +116,10 @@ def coordination_number( ------- Tensor: The coordination number of each atom in the system. """ + dd = {"device": positions.device, "dtype": positions.dtype} + if cutoff is None: - cutoff = positions.new_tensor(25.0) + cutoff = torch.tensor(25.0, **dd) if rcov is None: rcov = data.covalent_rad_d3[numbers].type(positions.dtype).to(positions.device) if numbers.shape != rcov.shape: @@ -127,18 +129,17 @@ def coordination_number( if numbers.shape != positions.shape[:-1]: raise ValueError("Shape of positions is not consistent with atomic numbers") - eps = positions.new_tensor(torch.finfo(positions.dtype).eps) mask = real_pairs(numbers, diagonal=False) distances = torch.where( mask, torch.cdist(positions, positions, p=2, compute_mode="use_mm_for_euclid_dist"), - eps, + torch.tensor(torch.finfo(positions.dtype).eps, **dd), ) rc = rcov.unsqueeze(-2) + rcov.unsqueeze(-1) cf = torch.where( mask * (distances <= cutoff), counting_function(distances, rc.type(distances.dtype), **kwargs), - positions.new_tensor(0.0), + torch.tensor(0.0, **dd), ) return torch.sum(cf, dim=-1)