Skip to content

Commit

Permalink
Merge pull request #41 from XanaduAI/40-make-linear-mixing-scf-code-j…
Browse files Browse the repository at this point in the history
…ittable

Jitted version of linear mixing implemented
  • Loading branch information
jackbaker1001 authored Sep 10, 2023
2 parents d398a53 + 5e9ac16 commit a26a36b
Showing 1 changed file with 89 additions and 0 deletions.
89 changes: 89 additions & 0 deletions grad_dft/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,95 @@ def scf_iterator(params: PyTree, molecule: Molecule, *args) -> Tuple[Scalar, Sca

return scf_iterator

def make_jitted_simple_scf_loop(functional: Functional, cycles: int = 25, mixing_factor: float = 0.4, **kwargs) -> Callable:
r"""
Creates an scf_iterator object that can be called to implement a self-consistent loop using linear mixing.
intented to be jax.jit compatible (fully self-differentiable).
If you are looking for a more flexible but not differentiable scf loop, see evaluate.py make_scf_loop.
Main parameters
---------------
functional: Functional
max_cycles: int, default to 25
Returns
---------
float
"""

predict_molecule = molecule_predictor(functional, chunk_size=None, **kwargs)

@jit
def scf_jitted_iterator(params: PyTree, molecule: Molecule, *args) -> Tuple[Scalar, Scalar]:
r"""
Implements a scf loop intented for use in a jax.jit compiled function (training loop).
If you are looking for a more flexible but not differentiable scf loop, see evaluate.py make_scf_loop.
It asks for a Molecule and a functional implicitly defined predict_molecule with
parameters params
Parameters
----------
params: PyTree
molecule: Molecule
*args: Arguments to be passed to predict_molecule function
"""

if molecule.omegas:
raise NotImplementedError(
"SCF training loop not implemented for (range-separated) exact-exchange functionals. \
Doing so would require a differentiable way of recomputing the chi tensor."
)

old_e = jnp.inf
norm_gorb = jnp.inf

predicted_e, fock = predict_molecule(params, molecule, *args)

old_e = jnp.inf
norm_gorb = jnp.inf

predicted_e, fock = predict_molecule(params, molecule, *args)

state = (molecule, fock, predicted_e, old_e, norm_gorb)

def loop_body(cycle, state):
old_state = state
molecule, fock, predicted_e, old_e, norm_gorb = old_state
old_e = predicted_e
old_rdm1 = molecule.rdm1

# Diagonalize Fock matrix
mo_energy, mo_coeff = safe_fock_solver(fock, molecule.s1e)
molecule = molecule.replace(mo_coeff=mo_coeff)
molecule = molecule.replace(mo_energy=mo_energy)

# Update the molecular occupation
mo_occ = molecule.get_occ()
molecule = molecule.replace(mo_occ=mo_occ)

# Update the density matrix with linear mixing
unmixed_new_rdm1 = molecule.make_rdm1()
rdm1 = (1 - mixing_factor)*old_rdm1 + mixing_factor*unmixed_new_rdm1
molecule = molecule.replace(rdm1=rdm1)

# Compute the new energy and Fock matrix
predicted_e, fock = predict_molecule(params, molecule, *args)

# Compute the norm of the gradient
norm_gorb = jnp.linalg.norm(orbital_grad(mo_coeff, mo_occ, fock))

state = (molecule, fock, predicted_e, old_e, norm_gorb)

return state

# Compute the scf loop
final_state = fori_loop(0, cycles, body_fun=loop_body, init_val=state)
molecule, fock, predicted_e, old_e, norm_gorb = final_state

return predicted_e, fock, molecule.rdm1

return scf_jitted_iterator


def make_scf_loop(
functional: Functional,
Expand Down

0 comments on commit a26a36b

Please sign in to comment.