Skip to content

Commit

Permalink
Add hessian to C-api
Browse files Browse the repository at this point in the history
Signed-off-by: Ty Balduf <ty.balduf@schrodinger.com>
  • Loading branch information
TyBalduf committed Sep 9, 2024
1 parent bdc305f commit f372e79
Show file tree
Hide file tree
Showing 3 changed files with 176 additions and 0 deletions.
12 changes: 12 additions & 0 deletions include/xtb.h
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,18 @@ xtb_singlepoint(xtb_TEnvironment /* env */,
xtb_TCalculator /* calc */,
xtb_TResults /* res */) XTB_API_SUFFIX__VERSION_2_0_0;

/// Perform hessian calculation
extern XTB_API_ENTRY void XTB_API_CALL
xtb_hessian(xtb_TEnvironment /* env */,
xtb_TMolecule /* mol */,
xtb_TCalculator /* calc */,
xtb_TResults /* res */,
double* /* hessian */,
int* /* atom_index_list */,
int* /* step_size */,
double* /* dipole_gradient */,
double* /* polarizability_gradient */) XTB_API_SUFFIX__VERSION_1_0_0;

/*
* Calculation results
**/
Expand Down
146 changes: 146 additions & 0 deletions src/api/interface.f90
Original file line number Diff line number Diff line change
Expand Up @@ -307,4 +307,150 @@ subroutine cpcmx_calc_api(venv, vmol, vcalc, vres) &
end subroutine cpcmx_calc_api


subroutine hessian_api(venv, vmol, vcalc, vres, c_hess, &
& c_step, c_list, c_dipgrad, c_polgrad) &
& bind(C, name="xtb_hessian")
!DEC$ ATTRIBUTES DLLEXPORT :: hessian_api
character(len=*), parameter :: source = 'xtb_api_hessian'
type(c_ptr), value :: venv
type(VEnvironment), pointer :: env
type(c_ptr), value :: vmol
type(VMolecule), pointer :: mol
type(c_ptr), value :: vcalc
type(VCalculator), pointer :: calc
type(c_ptr), value :: vres
type(VResults), pointer :: res

!> Array to add Hessian to
real(c_double), intent(inout) :: c_hess(*)
real(wp), allocatable :: hess(:, :)
!> List of atoms to displace
integer(c_int), intent(in), optional :: c_list(:)
integer, allocatable :: list(:)
!> Step size for numerical differentiation
real(c_double), intent(in), optional :: c_step
real(wp) :: step
!> Array to add dipole gradient to
real(c_double), intent(inout), optional :: c_dipgrad(*)
real(wp), allocatable :: dipgrad(:, :)
!> Array to add polarizability gradient to
real(c_double), intent(inout), optional :: c_polgrad(*)
real(wp), allocatable :: polgrad(:, :)

integer :: natom, natsq, i, j
logical :: has_polgrad, has_dipgrad

if (c_associated(venv)) then
call c_f_pointer(venv, env)
call checkGlobalEnv

if (.not.c_associated(vmol)) then
call env%ptr%error("Molecular structure data is not allocated", source)
return
end if
call c_f_pointer(vmol, mol)
natom = mol%ptr%n
natsq = natom * natom

if (.not.c_associated(vcalc)) then
call env%ptr%error("Singlepoint calculator is not allocated", source)
return
end if
call c_f_pointer(vcalc, calc)

if (.not.allocated(calc%ptr)) then
call env%ptr%error("No calculator loaded for single point", &
& source)
return
end if

if (.not.c_associated(vres)) then
call env%ptr%error("Calculation results are not allocated", source)
return
end if
call c_f_pointer(vres, res)

! check cache, automatically invalidate missmatched data
if (allocated(res%chk)) then
select type(xtb => calc%ptr)
type is(TxTBCalculator)
if (res%chk%wfn%n /= mol%ptr%n .or. res%chk%wfn%n /= xtb%basis%n .or. &
& res%chk%wfn%nao /= xtb%basis%nao .or. &
& res%chk%wfn%nshell /= xtb%basis%nshell) then
deallocate(res%chk)
end if
end select
end if

if (.not.allocated(res%chk)) then
allocate(res%chk)
! in case of a new wavefunction cache we have to perform an initial guess
select type(xtb => calc%ptr)
type is(TxTBCalculator)
call newWavefunction(env%ptr, mol%ptr, xtb, res%chk)
end select
end if

hess = reshape(c_hess(:9*natsq), &
&(/3*natom, 3*natom/))
! Need to initialize, as the subroutine increments the values
hess = 0.0_wp

if (.not.present(c_step)) then
step = 0.005_wp
else
step = c_step
end if

if (.not.present(c_list)) then
list = [(i, i=1, natom)]
else
list = c_list
end if

! Dipole gradient is required by the hessian method,
! so we have to allocate it
has_dipgrad = present(c_dipgrad)
if (.not. has_dipgrad) then
allocate(dipgrad(3, 3*natom))
else
dipgrad = reshape(c_dipgrad(:9*natom), &
&(/3, 3*natom/))
end if

has_polgrad = present(c_polgrad)
if (has_polgrad) then
polgrad = reshape(c_polgrad(:18*natom), &
&(/6, 3*natom/))
end if

! hessian calculation
if (has_polgrad) then
call calc%ptr%hessian(env%ptr, mol%ptr, res%chk, list, step, &
& hess, dipgrad, polgrad)
else
call calc%ptr%hessian(env%ptr, mol%ptr, res%chk, list, step, &
& hess, dipgrad)
end if

! Symmetrize the hessian
do i = 1, 3*natom
do j = i+1, 3*natom
hess(i, j) = 0.5_wp * (hess(i, j) + hess(j, i))
hess(j, i) = hess(i, j)
end do
end do

! copy back the results
c_hess(:9*natsq) = reshape(hess, (/9*natsq/))
if (has_dipgrad) then
c_dipgrad(:9*natom) = reshape(dipgrad, (/9*natom/))
end if
if (has_polgrad) then
c_polgrad(:18*natom) = reshape(polgrad, (/18*natom/))
end if
end if

end subroutine hessian_api

end module xtb_api_interface
18 changes: 18 additions & 0 deletions test/api/c_api_example.c
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ int testFirst() {
double* q;
char* buffer;
double* wbo;
double* hess;

int buffersize = 512;
int tester = 0;
Expand All @@ -56,6 +57,7 @@ int testFirst() {
q = (double*) malloc(natoms * sizeof(double));
wbo = (double*) malloc(natsq * sizeof(double));
buffer = (char*) malloc(buffersize *sizeof(char));
hess = (double*) malloc(9*natsq * sizeof(double));
char solvent[] = "h2o";
char gbsa[] = "gbsa";
char alpb[] = "alpb";
Expand Down Expand Up @@ -109,6 +111,21 @@ int testFirst() {
if (!check(wbo[9], 2.89823984265213, 1.0e-8, "Bond order does not match"))
goto error;

// Compute Hessian
xtb_hessian(env, mol, calc, res, hess, NULL, NULL, NULL, NULL);
if (xtb_checkEnvironment(env))
goto error;

if (!check(hess[0], 0.4790088649, 1.0e-9, "Hessian[0,0] does not match"))
goto error;
if (!check(hess[3], -0.0463290233, 1.0e-9, "Hessian[0,3] does not match"))
goto error;
if (!check(hess[3], hess[63], 1.0e-9, "Hessian[0,3] != Hessian[3,0]"))
goto error;
if (!check(hess[(9*natsq)-1], 0.3636571159, 1.0e-9, "Hessian[21,21] does not match"))
goto error;

// GBSA
xtb_setSolvent(env, calc, solvent, NULL, NULL, NULL, gbsa);
if (xtb_checkEnvironment(env))
goto error;
Expand Down Expand Up @@ -245,6 +262,7 @@ int testFirst() {
free(q);
free(wbo);
free(buffer);
free(hess);

tester = !res;
if (!check(tester, 1, "Results not deleted"))
Expand Down

0 comments on commit f372e79

Please sign in to comment.