Skip to content

Commit

Permalink
Fix non-alchemical learnable basis bug
Browse files Browse the repository at this point in the history
  • Loading branch information
frostedoyster committed Feb 15, 2024
1 parent 187d6d0 commit 12d5907
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion torch_spex/radial_basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def forward(self, r, samples_metadata: Labels):
split_l_aj = l_aj.split("_")
l = int(split_l_aj[0])
aj = int(split_l_aj[1])
where_aj = torch.nonzero(neighbor_species == aj)[0]
where_aj = torch.nonzero(neighbor_species == aj)[:, 0]
radial_basis_after_mlp[l][where_aj, :] = radial_mlp_l_aj(torch.index_select(radial_basis[l], 0, where_aj))
return radial_basis_after_mlp
else:
Expand Down

0 comments on commit 12d5907

Please sign in to comment.