Skip to content

Commit

Permalink
Fallback to initialising electrons about the origin if a per-atom ini…
Browse files Browse the repository at this point in the history
…tialisation is not found

PiperOrigin-RevId: 713722201
Change-Id: Ib6b2e491218ca4d5e9c65ce7518377f669607031
  • Loading branch information
jsspencer committed Jan 9, 2025
1 parent 57f784d commit 266c025
Showing 1 changed file with 30 additions and 9 deletions.
39 changes: 30 additions & 9 deletions ferminet/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ def init_electrons( # pylint: disable=dangerous-default-value
batch_size: int,
init_width: float,
core_electrons: Mapping[str, int] = {},
max_iter: int = 10_000,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
"""Initializes electron positions around each atom.
Expand All @@ -75,6 +76,9 @@ def init_electrons( # pylint: disable=dangerous-default-value
electron configurations.
core_electrons: mapping of element symbol to number of core electrons
included in the pseudopotential.
max_iter: maximum number of iterations to try to find a valid initial
electron configuration for each atom. If reached, all electrons are
initialised from a Gaussian distribution centred on the origin.
Returns:
array of (batch_size, (nalpha+nbeta)*ndim) of initial (random) electron
Expand All @@ -83,6 +87,7 @@ def init_electrons( # pylint: disable=dangerous-default-value
of spin configurations, where 1 and -1 indicate alpha and beta electrons
respectively.
"""
niter = 0
total_electrons = sum(atom.charge - core_electrons.get(atom.symbol, 0)
for atom in molecule)
if total_electrons != sum(electrons):
Expand All @@ -98,19 +103,35 @@ def init_electrons( # pylint: disable=dangerous-default-value
for atom in molecule
]
assert sum(sum(x) for x in atomic_spin_configs) == sum(electrons)
while tuple(sum(x) for x in zip(*atomic_spin_configs)) != electrons:
while (
tuple(sum(x) for x in zip(*atomic_spin_configs)) != electrons
and niter < max_iter
):
i = np.random.randint(len(atomic_spin_configs))
nalpha, nbeta = atomic_spin_configs[i]
atomic_spin_configs[i] = nbeta, nalpha
niter += 1

if tuple(sum(x) for x in zip(*atomic_spin_configs)) == electrons:
# Assign each electron to an atom initially.
electron_positions = []
for i in range(2):
for j in range(len(molecule)):
atom_position = jnp.asarray(molecule[j].coords)
electron_positions.append(
jnp.tile(atom_position, atomic_spin_configs[j][i]))
electron_positions = jnp.concatenate(electron_positions)
else:
logging.warning(
'Failed to find a valid initial electron configuration after %i'
' iterations. Initializing all electrons from a Gaussian distribution'
' centred on the origin. This might require increasing the number of'
' iterations used for pretraining and MCMC burn-in. Consider'
' implementing a custom initialisation.',
niter,
)
electron_positions = jnp.zeros(shape=(3*sum(electrons),))

# Assign each electron to an atom initially.
electron_positions = []
for i in range(2):
for j in range(len(molecule)):
atom_position = jnp.asarray(molecule[j].coords)
electron_positions.append(
jnp.tile(atom_position, atomic_spin_configs[j][i]))
electron_positions = jnp.concatenate(electron_positions)
# Create a batch of configurations with a Gaussian distribution about each
# atom.
key, subkey = jax.random.split(key)
Expand Down

0 comments on commit 266c025

Please sign in to comment.