Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
PabloAMC committed Sep 9, 2023
2 parents 4117b10 + e13f0c9 commit c9f642b
Show file tree
Hide file tree
Showing 8 changed files with 276 additions and 142 deletions.
3 changes: 3 additions & 0 deletions .github/workflows/install_and_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ jobs:
- name: Install extra example dependencies
run: |
pip install -e ".[examples]"
- name: Run unit tests
run: |
pytest -v tests/unit/test_eigenproblem.py
- name: Run integration tests
run: |
pytest -v tests/integration/test_non_xc_energy.py
Expand Down
12 changes: 6 additions & 6 deletions grad_dft/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@
from grad_dft.utils import PyTree, Array, Scalar, Optimizer
from grad_dft.functional import Functional

from grad_dft.molecule import Molecule, abs_clip, make_rdm1, orbital_grad, general_eigh
from grad_dft.molecule import Molecule, abs_clip, make_rdm1, orbital_grad
from grad_dft.train import molecule_predictor
from grad_dft.utils import PyTree, Array, Scalar
from grad_dft.utils import PyTree, Array, Scalar, safe_fock_solver
from grad_dft.interface.pyscf import (
generate_chi_tensor,
mol_from_Molecule,
Expand Down Expand Up @@ -117,7 +117,7 @@ def scf_iterator(params: PyTree, molecule: Molecule, *args) -> Tuple[Scalar, Sca
else:
# Diagonalize Fock matrix
overlap = abs_clip(molecule.s1e, 1e-20)
mo_energy, mo_coeff = general_eigh(fock, overlap)
mo_energy, mo_coeff = safe_fock_solver(fock, overlap)
molecule = molecule.replace(mo_coeff=mo_coeff)
molecule = molecule.replace(mo_energy=mo_energy)

Expand Down Expand Up @@ -275,7 +275,7 @@ def scf_iterator(params: PyTree, molecule: Molecule, *args) -> Tuple[Scalar, Sca
)

# Diagonalize Fock matrix
mo_energy, mo_coeff = general_eigh(fock, molecule.s1e)
mo_energy, mo_coeff = safe_fock_solver(fock, molecule.s1e)
molecule = molecule.replace(mo_coeff=mo_coeff)
molecule = molecule.replace(mo_energy=mo_energy)

Expand Down Expand Up @@ -373,7 +373,7 @@ def nelec_cost_fn(m, mo_es, sigma, _nelectron):
if abs(predicted_e - old_e) * Hartree2kcalmol < e_conv and norm_gorb < g_conv:
# We perform an extra diagonalization to remove the level shift
# Solve eigenvalue problem
mo_energy, mo_coeff = general_eigh(fock, molecule.s1e)
mo_energy, mo_coeff = safe_fock_solver(fock, molecule.s1e)
molecule = molecule.replace(mo_coeff=mo_coeff)
molecule = molecule.replace(mo_energy=mo_energy)

Expand Down Expand Up @@ -663,7 +663,7 @@ def loop_body(cycle, state):
fock, diis_data = diis.run(new_data, diis_data, cycle)

# Diagonalize Fock matrix
mo_energy, mo_coeff = general_eigh(fock, molecule.s1e)
mo_energy, mo_coeff = safe_fock_solver(fock, molecule.s1e)
molecule = molecule.replace(mo_coeff=mo_coeff)
molecule = molecule.replace(mo_energy=mo_energy)

Expand Down
1 change: 0 additions & 1 deletion grad_dft/external/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,3 @@
from grad_dft.external.density_functional_approximation_dm21.density_functional_approximation_dm21.neural_numint import (
_SystemState,
)
from grad_dft.external.eigh_impl import eigh2d
116 changes: 0 additions & 116 deletions grad_dft/external/eigh_impl.py

This file was deleted.

19 changes: 0 additions & 19 deletions grad_dft/molecule.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
from dataclasses import fields
from grad_dft.utils import vmap_chunked
from functools import partial
from grad_dft.external.eigh_impl import eigh2d

from jax import numpy as jnp
from jax import scipy as jsp
Expand Down Expand Up @@ -615,27 +614,9 @@ def chunked_jvp(chi_tensor, gr_tensor, ao_tensor):

return (jax.jit(chunked_jvp)(chi.transpose(3, 0, 1, 2), gr, ao)).transpose(1, 2, 3, 0)


def eig(h, x):
e0, c0 = eigh2d(h[0], x)
e1, c1 = eigh2d(h[1], x)
return jnp.stack((e0, e1), axis=0), jnp.stack((c0, c1), axis=0)

def abs_clip(arr, threshold):
    """Zero out the entries of ``arr`` whose magnitude does not exceed ``threshold``.

    Args:
        arr: array to filter.
        threshold: entries with ``abs(entry) > threshold`` are kept unchanged;
            every other entry is replaced by 0.
    Returns:
        An array shaped like ``arr`` with the small-magnitude entries set to 0.
    """
    keep = jnp.abs(arr) > threshold
    return jnp.where(keep, arr, 0)

def general_eigh(A, B):
L = jnp.linalg.cholesky(B)
L_inv = jnp.linalg.inv(L)
C = L_inv @ A @ L_inv.T
C = abs_clip(C, 1e-20)
eigenvalues, eigenvectors_transformed = jnp.linalg.eigh(C)
# eigenvalues, eigenvectors_transformed = jsp.linalg.eigh(C)
eigenvectors_original = L_inv.T @ eigenvectors_transformed
eigenvectors_original = abs_clip(eigenvectors_original, 1e-20)
eigenvalues = abs_clip(eigenvalues, 1e-20)
return eigenvalues, eigenvectors_original


######################################################################

Expand Down
1 change: 1 addition & 0 deletions grad_dft/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,4 @@
from .tree import tree_size, tree_isfinite, tree_randn_like, tree_func, tree_shape
from .utils import to_device_arrays, Utils
from .chunk import vmap_chunked
from .eigenproblem import safe_fock_solver
144 changes: 144 additions & 0 deletions grad_dft/utils/eigenproblem.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
# Copyright 2023 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import jax
import jax.numpy as jnp
from jax import custom_vjp
from .types import Array, Scalar

# Numerical thresholds used by safe_eigh_rev.
# Probably don't alter these unless you know what you're doing.
# Two eigenvalues whose absolute difference is below DEGEN_TOL are treated as degenerate.
DEGEN_TOL = 1e-6
# Constant added to the squared eigenvalue gap of degenerate pairs (Lorentzian broadening).
BROADENING = 1e-7


@custom_vjp
def safe_eigh(A: Array) -> tuple[Array, Array]:
    r"""Get the eigenvalues and eigenvectors for an input real symmetric matrix.
    A safe reverse mode gradient is implemented in safe_eigh_rev below.
    Args:
        A (Array): a 2D Jax array representing a real symmetric matrix.
    Returns:
        tuple[Array, Array]: the eigenvalues and eigenvectors (in that order) of the
        input real symmetric matrix.
    """
    # jnp.linalg.eigh returns (eigenvalues, eigenvectors), in that order.
    evals, evecs = jnp.linalg.eigh(A)
    return evals, evecs


def safe_eigh_fwd(A: Array) -> tuple[tuple[Array, Array], tuple[tuple[Array, Array], Array]]:
    r"""Forward pass of the custom VJP of safe_eigh. Saves the eigenvalues, the
    eigenvectors and the input matrix as residuals for the reverse pass.
    Args:
        A (Array): a 2D Jax array representing a real symmetric matrix.
    Returns:
        tuple[tuple[Array, Array], tuple[tuple[Array, Array], Array]]: the primal output
        (eigenvalues, eigenvectors) followed by the residuals ((eigenvalues, eigenvectors), A).
    """
    evals, evecs = safe_eigh(A)
    return (evals, evecs), ((evals, evecs), A)


def safe_eigh_rev(
    res: tuple[tuple[Array, Array], Array], g: tuple[Array, Array]
) -> tuple[Array]:
    r"""Use the Lorentzian broadening approach suggested in https://doi.org/10.1038/s42005-021-00568-6
    to calculate stable backward mode gradients for degenerate eigenvectors. We only apply this
    technique if eigenvalues are detected to be degenerate according to the constant DEGEN_TOL
    in this module. When degeneracies are detected, they are broadened according to the constant
    BROADENING also defined in this module.
    Args:
        res (tuple[tuple[Array, Array], Array]): eigenvalues, eigenvectors and the input real
            symmetric matrix A saved from the forward pass. A is not needed by this rule.
        g (tuple[Array, Array]): the cotangents of the eigenvalues and of the eigenvectors.
    Returns:
        tuple[Array]: one-element tuple holding the symmetrized cotangent with respect to A.
    """
    (evals, evecs), _ = res
    grad_evals, grad_evecs = g

    # Eigenvalue difference matrix: eval_diff[i, j] = e_j - e_i
    eval_diff = evals.reshape((1, -1)) - evals.reshape((-1, 1))
    # Pairs (including the diagonal) for which the DEGEN_TOL condition is met
    is_degen = jnp.abs(eval_diff) < DEGEN_TOL

    # Regular gap for non-degenerate terms => 1/(e_j - e_i).
    # Degenerate entries get a dummy denominator of 1 so that no inf is ever produced;
    # those entries are discarded by the selection below.
    regular_gap = 1.0 / jnp.where(is_degen, 1.0, eval_diff)

    # Lorentzian broadened gap for degenerate terms => (e_j - e_i)/((e_j - e_i)^2 + eps)
    broadened_gap = eval_diff / (eval_diff * eval_diff + BROADENING)

    # Full F matrix: broadened gap where degenerate, regular gap elsewhere.
    # The factor 0.5 is compensated by the symmetrization at the end.
    F = 0.5 * jnp.where(is_degen, broadened_gap, regular_gap)

    # Set diagonals to 0
    F = F.at[jnp.diag_indices_from(F)].set(0)

    # Calculate the gradient. The eigenvectors returned by eigh are orthonormal, so the
    # inverse of evecs.T is evecs itself and no explicit matrix inversion is needed.
    inner = 0.5 * jnp.diag(grad_evals) + jnp.multiply(F, evecs.T @ grad_evecs)
    grad = evecs @ inner @ evecs.T

    # Symmetrize
    grad_sym = grad + grad.T
    return (grad_sym,)


safe_eigh.defvjp(safe_eigh_fwd, safe_eigh_rev)


def safe_general_eigh(A: Array, B: Array) -> tuple[Array, Array]:
    r"""Solve the generalized symmetric eigenproblem for the eigenvalues and eigenvectors. I.e.,

    .. math::
        AC = BCE

    for matrix of eigenvectors C and diagonal matrix of eigenvalues E. The problem is reduced
    to a standard symmetric eigenproblem through the Cholesky factorization B = L L^T, which
    is then solved with safe_eigh. This requires both input matrices to be real and symmetric
    and the matrix B to be positive definite (so that its Cholesky factor exists).
    Args:
        A (Array): a real symmetric matrix
        B (Array): a real symmetric positive definite matrix
    Returns:
        tuple[Array, Array]: the eigenvalues and matrix of eigenvectors
    """
    # Function-scope import: jax.scipy is not imported at the top of this module.
    from jax.scipy.linalg import solve_triangular

    # B = L L^T with L lower triangular
    L = jnp.linalg.cholesky(B)

    # Build C = L^-1 A L^-T with two triangular solves instead of forming L^-1 explicitly,
    # which is cheaper and numerically more accurate for ill-conditioned B.
    L_inv_A = solve_triangular(L, A, lower=True)  # L^-1 A
    C = solve_triangular(L, L_inv_A.T, lower=True).T  # (L^-1 (L^-1 A)^T)^T = L^-1 A L^-T

    eigenvalues, eigenvectors_transformed = safe_eigh(C)

    # Back-transform the eigenvectors of C: solve L^T X = Y, i.e. X = L^-T Y
    eigenvectors_original = solve_triangular(
        L, eigenvectors_transformed, lower=True, trans="T"
    )
    return eigenvalues, eigenvectors_original


def safe_fock_solver(fock: tuple[Array, Array], overlap: Array) -> tuple[Array, Array]:
    """Diagonalize the up and down spin Fock matrices against a common overlap matrix.

    Each spin channel is solved independently with safe_general_eigh and the
    results are stacked along a new leading axis (index 0 = up, index 1 = down).
    Args:
        fock (tuple[Array, Array]): the up and down fock spin matrices, read as
            ``fock[0]`` and ``fock[1]``
        overlap (Array): the overlap matrix shared by both spin channels
    Returns:
        tuple[Array, Array]: the stacked eigenenergies and the stacked matrices of
        molecular orbital coefficients.
    """
    # Solve spin up (index 0) first, then spin down (index 1).
    per_spin = [safe_general_eigh(fock[spin], overlap) for spin in (0, 1)]
    mo_energies, mo_coeffs = zip(*per_spin)
    stacked_energies = jnp.stack(mo_energies, axis=0)
    stacked_coeffs = jnp.stack(mo_coeffs, axis=0)
    return stacked_energies, stacked_coeffs
Loading

0 comments on commit c9f642b

Please sign in to comment.