diff --git a/deepmd/pt/model/descriptor/se_atten.py b/deepmd/pt/model/descriptor/se_atten.py index 07d36e595a..099fbc42a4 100644 --- a/deepmd/pt/model/descriptor/se_atten.py +++ b/deepmd/pt/model/descriptor/se_atten.py @@ -464,8 +464,7 @@ def forward( protection=self.env_protection, ) nlist_mask = nlist != -1 - nlist_copy = nlist.detach().clone() - nlist_copy[nlist == -1] = 0 + nlist = torch.where(nlist == -1, 0, nlist) sw = torch.squeeze(sw, -1) # beyond the cutoff sw should be 0.0 sw = sw.masked_fill(~nlist_mask, 0.0)