Skip to content

Commit

Permalink
Introduced qmckl_compute_forces_jastrow_single_een_hpc
Browse files Browse the repository at this point in the history
  • Loading branch information
scemama committed Feb 19, 2025
1 parent ea28f4c commit c04ce98
Show file tree
Hide file tree
Showing 2 changed files with 157 additions and 37 deletions.
6 changes: 3 additions & 3 deletions org/qmckl_context.org
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ int main() {
#include "qmckl_mo_private_type.h"
#include "qmckl_jastrow_champ_private_type.h"
#include "qmckl_jastrow_champ_single_private_type.h"
#include "qmckl_jastrow_champ_quad_private_type.h"
// #include "qmckl_jastrow_champ_quad_private_type.h"
#include "qmckl_forces_private_type.h"
#include "qmckl_determinant_private_type.h"
#include "qmckl_local_energy_private_type.h"
Expand All @@ -46,7 +46,7 @@ int main() {
#include "qmckl_mo_private_func.h"
#include "qmckl_jastrow_champ_private_func.h"
#include "qmckl_jastrow_champ_single_private_func.h"
#include "qmckl_jastrow_champ_quad_private_func.h"
// #include "qmckl_jastrow_champ_quad_private_func.h"
#include "qmckl_forces_private_func.h"
#include "qmckl_determinant_private_func.h"
#include "qmckl_local_energy_private_func.h"
Expand Down Expand Up @@ -136,7 +136,7 @@ typedef struct qmckl_context_struct {
/* Points */
qmckl_point_struct point;
qmckl_jastrow_champ_single_struct single_point;
qmckl_jastrow_champ_quad_struct quad_point;
// qmckl_jastrow_champ_quad_struct quad_point;

/* -- Molecular system -- */
qmckl_nucleus_struct nucleus;
Expand Down
188 changes: 154 additions & 34 deletions org/qmckl_forces.org
Original file line number Diff line number Diff line change
Expand Up @@ -392,7 +392,7 @@ qmckl_exit_code qmckl_finite_difference_deriv_n(

rc = qmckl_context_touch(context);
assert(rc == QMCKL_SUCCESS);

rc = qmckl_single_touch(context);
assert(rc == QMCKL_SUCCESS);

Expand Down Expand Up @@ -2557,7 +2557,7 @@ integer(qmckl_exit_code) function qmckl_compute_forces_jastrow_een( &
if(cn == 0.d0) cycle
do j = 1, elec_num
do ii = 1, 3
accu = een_rescaled_n(j,a,m,nw) * forces_tmp_c(j,ii,a,m+l,k,nw) &
accu = een_rescaled_n(j,a,m,nw) * forces_tmp_c(j,ii,a,m+l,k,nw) &
- een_rescaled_n_gl(j,ii,a,m,nw) * tmp_c(j,a,m+l,k,nw)

forces_jastrow_een(ii, a, nw) = forces_jastrow_een(ii, a, nw) + accu * cn
Expand Down Expand Up @@ -2843,8 +2843,8 @@ integer(qmckl_exit_code) function qmckl_compute_forces_een_rescaled_n_gl( &
! temp * ria_inv / kappa_l - temp
! end do
! forces_een_n(i,4,a,n,l,nw) = forces_een_n(i,4,a,n,l,nw) + &
! (2.0d0 * ria_inv * ria_inv &
! - een_rescaled_n_gl(i,4,a,l,nw)/een_rescaled_n(i,a,l,nw)) * &
! (2.0d0 * ria_inv * ria_inv &
! - een_rescaled_n_gl(i,4,a,l,nw)/een_rescaled_n(i,a,l,nw)) * &
! een_rescaled_n_gl(i,n,a,l,nw)
! end do
! end do
Expand Down Expand Up @@ -4640,7 +4640,7 @@ qmckl_get_forces_jastrow_single_een(qmckl_context context,
real(c_double), intent(out) :: forces_jastrow_single_een(size_max)
end function qmckl_get_forces_jastrow_single_een
end interface
#+end_src
#+end_src

** Provide :noexport:

Expand Down Expand Up @@ -4760,7 +4760,7 @@ qmckl_exit_code qmckl_provide_forces_jastrow_single_een(qmckl_context context)

return QMCKL_SUCCESS;
}
#+end_src
#+end_src

** Compute
:PROPERTIES:
Expand All @@ -4770,27 +4770,27 @@ qmckl_exit_code qmckl_provide_forces_jastrow_single_een(qmckl_context context)
:END:

#+NAME: qmckl_forces_single_een_args
| Variable | Type | In/Out | Description |
|---------------------------+-------------------------------------------+--------+---------------------------------------------------------|
| ~context~ | ~qmckl_context~ | in | Global state |
| ~num~ | ~int64_t~ | in | Index of single electron |
| ~walk_num~ | ~int64_t~ | in | Number of walkers |
| ~elec_num~ | ~int64_t~ | in | Number of electrons |
| ~nucl_num~ | ~int64_t~ | in | Number of nuclei |
| ~cord_num~ | ~int64_t~ | in | order of polynomials |
| ~dim_c_vector~ | ~int64_t~ | in | dimension of full coefficient vector |
| ~c_vector_full~ | ~double[dim_c_vector][nucl_num]~ | in | full coefficient vector |
| ~lkpm_combined_index~ | ~int64_t[4][dim_c_vector]~ | in | combined indices |
| ~delta_p~ | ~double[walk_num][0:cord_num-1][0:cord_num][nucl_num][elec_num]~ | in | Single electron P matrix |
| ~forces_delta_p~ | ~double[walk_num][0:cord_num-1][0:cord_num][nucl_num][3][elec_num]~ | in | Single electron P matrix |
| ~tmp_c~ | ~double[walk_num][0:cord_num-1][0:cord_num][nucl_num][elec_num]~ | in | Single electron P matrix |
| ~forces_tmp_c~ | ~double[walk_num][0:cord_num-1][0:cord_num][nucl_num][4][elec_num]~ | in | Single electron P matrix |
| ~een_rescaled_n~ | ~double[walk_num][0:cord_num][nucl_num][elec_num]~ | in | Electron-nucleus rescaled distances |
| ~een_rescaled_single_n~ | ~double[walk_num][0:cord_num][nucl_num]~ | in | Electron-nucleus single rescaled distances |
| ~een_rescaled_n_gl~ | ~double[walk_num][0:cord_num][nucl_num][4][elec_num]~ | in | Electron-nucleus rescaled distances derivatives |
| ~een_rescaled_single_n_gl~ | ~double[walk_num][0:cord_num][nucl_num][4]~ | in | Electron-nucleus single rescaled distances derivatives |
| ~forces_jastrow_single_een~ | ~double[walk_num][nucl_num][3]~ | out | Single electron-nucleus forces |
|---------------------------+-------------------------------------------+--------+---------------------------------------------------------|
| Variable | Type | In/Out | Description |
|-----------------------------+---------------------------------------------------------------------+--------+--------------------------------------------------------|
| ~context~ | ~qmckl_context~ | in | Global state |
| ~num~ | ~int64_t~ | in | Index of single electron |
| ~walk_num~ | ~int64_t~ | in | Number of walkers |
| ~elec_num~ | ~int64_t~ | in | Number of electrons |
| ~nucl_num~ | ~int64_t~ | in | Number of nuclei |
| ~cord_num~ | ~int64_t~ | in | order of polynomials |
| ~dim_c_vector~ | ~int64_t~ | in | dimension of full coefficient vector |
| ~c_vector_full~ | ~double[dim_c_vector][nucl_num]~ | in | full coefficient vector |
| ~lkpm_combined_index~ | ~int64_t[4][dim_c_vector]~ | in | combined indices |
| ~delta_p~ | ~double[walk_num][0:cord_num-1][0:cord_num][nucl_num][elec_num]~ | in | Single electron P matrix |
| ~forces_delta_p~ | ~double[walk_num][0:cord_num-1][0:cord_num][nucl_num][3][elec_num]~ | in | Single electron P matrix |
| ~tmp_c~ | ~double[walk_num][0:cord_num-1][0:cord_num][nucl_num][elec_num]~ | in | Single electron P matrix |
| ~forces_tmp_c~ | ~double[walk_num][0:cord_num-1][0:cord_num][nucl_num][4][elec_num]~ | in | Single electron P matrix |
| ~een_rescaled_n~ | ~double[walk_num][0:cord_num][nucl_num][elec_num]~ | in | Electron-nucleus rescaled distances |
| ~een_rescaled_single_n~ | ~double[walk_num][0:cord_num][nucl_num]~ | in | Electron-nucleus single rescaled distances |
| ~een_rescaled_n_gl~ | ~double[walk_num][0:cord_num][nucl_num][4][elec_num]~ | in | Electron-nucleus rescaled distances derivatives |
| ~een_rescaled_single_n_gl~ | ~double[walk_num][0:cord_num][nucl_num][4]~ | in | Electron-nucleus single rescaled distances derivatives |
| ~forces_jastrow_single_een~ | ~double[walk_num][nucl_num][3]~ | out | Single electron-nucleus forces |
|-----------------------------+---------------------------------------------------------------------+--------+--------------------------------------------------------|

#+begin_src f90 :comments org :tangle (eval f) :noweb yes
function qmckl_compute_forces_jastrow_single_een_doc( &
Expand Down Expand Up @@ -4876,6 +4876,105 @@ function qmckl_compute_forces_jastrow_single_een_doc( &
end function qmckl_compute_forces_jastrow_single_een_doc
#+end_src

#+begin_src f90 :comments org :tangle (eval f) :noweb yes
function qmckl_compute_forces_jastrow_single_een_hpc( &
context, num_in, walk_num, elec_num, nucl_num, cord_num, &
dim_c_vector, c_vector_full, lkpm_combined_index, &
delta_p, forces_delta_p, tmp_c, forces_tmp_c, &
een_rescaled_n, een_rescaled_single_n, een_rescaled_n_gl, een_rescaled_single_n_gl, forces_jastrow_single_een) &
bind(C) result(info)
use qmckl
implicit none

integer (qmckl_context), intent(in), value :: context
integer (c_int64_t) , intent(in) , value :: num_in
integer (c_int64_t) , intent(in) , value :: walk_num
integer (c_int64_t) , intent(in) , value :: elec_num
integer (c_int64_t) , intent(in) , value :: nucl_num
integer (c_int64_t) , intent(in) , value :: dim_c_vector
integer (c_int64_t) , intent(in) , value :: cord_num
integer(c_int64_t) , intent(in) :: lkpm_combined_index(dim_c_vector,4)
real(c_double) , intent(in) :: c_vector_full(nucl_num, dim_c_vector)
real (c_double ) , intent(in) :: delta_p(elec_num, nucl_num,0:cord_num, 0:cord_num-1, walk_num)
real (c_double ) , intent(in) :: forces_delta_p(elec_num, 3, nucl_num,0:cord_num, 0:cord_num-1, walk_num)
real (c_double ) , intent(in) :: tmp_c(elec_num, nucl_num,0:cord_num, 0:cord_num-1, walk_num)
real (c_double ) , intent(in) :: forces_tmp_c(elec_num, 4, nucl_num,0:cord_num, 0:cord_num-1, walk_num)
real (c_double ) , intent(in) :: een_rescaled_n(elec_num, nucl_num, 0:cord_num, walk_num)
real (c_double ) , intent(in) :: een_rescaled_single_n(nucl_num, 0:cord_num, walk_num)
real (c_double ) , intent(in) :: een_rescaled_n_gl(elec_num, 4, nucl_num, 0:cord_num, walk_num)
real (c_double ) , intent(in) :: een_rescaled_single_n_gl(4, nucl_num, 0:cord_num, walk_num)
real (c_double ) , intent(out) :: forces_jastrow_single_een(3,nucl_num,walk_num)
integer(qmckl_exit_code) :: info

double precision, allocatable :: een_rescaled_delta_n(:, :), een_rescaled_delta_n_gl(:,:,:)

integer*8 :: i, a, j, l, k, p, m, n, nw, num, kk
double precision :: accu2, cn
double precision, allocatable :: accu(:,:), tmp(:)
integer*8 :: LDA, LDB, LDC
double precision, external :: ddot

num = num_in + 1

info = QMCKL_SUCCESS

if (context == QMCKL_NULL_CONTEXT) info = QMCKL_INVALID_CONTEXT
if (walk_num <= 0) info = QMCKL_INVALID_ARG_3
if (elec_num <= 0) info = QMCKL_INVALID_ARG_4
if (nucl_num <= 0) info = QMCKL_INVALID_ARG_5
if (cord_num < 0) info = QMCKL_INVALID_ARG_6
if (info /= QMCKL_SUCCESS) return

forces_jastrow_single_een = 0.0d0

if (cord_num == 0) return

allocate(een_rescaled_delta_n(nucl_num, 0:cord_num), een_rescaled_delta_n_gl(3,nucl_num, 0:cord_num), &
accu(3,nucl_num), tmp(nucl_num))

do nw =1, walk_num
een_rescaled_delta_n(:,:) = een_rescaled_single_n(:,:,nw) - een_rescaled_n(num,:,:,nw)
een_rescaled_delta_n_gl(1:3,:,:) = een_rescaled_single_n_gl(1:3,:,:,nw) - een_rescaled_n_gl(num,1:3,:,:,nw)
do n = 1, dim_c_vector
l = lkpm_combined_index(n, 1)
k = lkpm_combined_index(n, 2)
p = lkpm_combined_index(n, 3)
m = lkpm_combined_index(n, 4)

do a = 1, nucl_num
cn = c_vector_full(a, n)
accu(1:3,a) = 0.0d0
if(cn == 0.d0) cycle
tmp(a) = tmp_c(num,a,m+l,k,nw) + delta_p(num,a,m+l,k,nw)
call dgemv('T', elec_num, 3, -1.d0, een_rescaled_n_gl(1,1,a,m,nw), elec_num, &
delta_p(1,a,m+l,k,nw), 1, 0.d0, accu(1,a), 1)
call dgemv('T', elec_num, 3, 1.d0, forces_delta_p(1,1,a,m+l,k,nw), elec_num, &
een_rescaled_n(1,a,m,nw), 1, 1.d0, accu(1,a), 1)

enddo

accu(1,:) = accu(1,:) - een_rescaled_delta_n_gl(1,:,m)*tmp(:)
accu(2,:) = accu(2,:) - een_rescaled_delta_n_gl(2,:,m)*tmp(:)
accu(3,:) = accu(3,:) - een_rescaled_delta_n_gl(3,:,m)*tmp(:)

accu(1,:) = accu(1,:) + een_rescaled_delta_n(:,m)*forces_tmp_c(num,1,:,m+l,k,nw)
accu(2,:) = accu(2,:) + een_rescaled_delta_n(:,m)*forces_tmp_c(num,2,:,m+l,k,nw)
accu(3,:) = accu(3,:) + een_rescaled_delta_n(:,m)*forces_tmp_c(num,3,:,m+l,k,nw)

accu(1,:) = accu(1,:) + een_rescaled_delta_n(:,m)*forces_delta_p(num,1,:,m+l,k,nw)
accu(2,:) = accu(2,:) + een_rescaled_delta_n(:,m)*forces_delta_p(num,2,:,m+l,k,nw)
accu(3,:) = accu(3,:) + een_rescaled_delta_n(:,m)*forces_delta_p(num,3,:,m+l,k,nw)

forces_jastrow_single_een(1,:,nw) = forces_jastrow_single_een(1,:,nw) + accu(1,:) * c_vector_full(:,n)
forces_jastrow_single_een(2,:,nw) = forces_jastrow_single_een(2,:,nw) + accu(2,:) * c_vector_full(:,n)
forces_jastrow_single_een(3,:,nw) = forces_jastrow_single_een(3,:,nw) + accu(3,:) * c_vector_full(:,n)
end do
end do


end function qmckl_compute_forces_jastrow_single_een_hpc
#+end_src

#+begin_src c :comments org :tangle (eval h_private_func) :noweb yes :exports none
qmckl_exit_code qmckl_compute_forces_jastrow_single_een_doc (
const qmckl_context context,
Expand Down Expand Up @@ -4916,6 +5015,27 @@ qmckl_exit_code qmckl_compute_forces_jastrow_single_een (
const double* een_rescaled_n_gl,
const double* een_rescaled_single_n_gl,
double* const forces_jastrow_single_een );

qmckl_exit_code qmckl_compute_forces_jastrow_single_een_hpc (
const qmckl_context context,
const int64_t num,
const int64_t walk_num,
const int64_t elec_num,
const int64_t nucl_num,
const int64_t cord_num,
const int64_t dim_c_vector,
const double* c_vector_full,
const int64_t* lkpm_combined_index,
const double* delta_p,
const double* forces_delta_p,
const double* tmp_c,
const double* forces_tmp_c,
const double* een_rescaled_n,
const double* een_rescaled_single_n,
const double* een_rescaled_n_gl,
const double* een_rescaled_single_n_gl,
double* const forces_jastrow_single_een );

#+end_src


Expand All @@ -4941,7 +5061,7 @@ qmckl_compute_forces_jastrow_single_een (const qmckl_context context,
double* const forces_jastrow_single_een )
{
#ifdef HAVE_HPC
return qmckl_compute_forces_jastrow_single_een_doc
return qmckl_compute_forces_jastrow_single_een_hpc
#else
return qmckl_compute_forces_jastrow_single_een_doc
#endif
Expand Down Expand Up @@ -5323,7 +5443,7 @@ qmckl_exit_code qmckl_provide_forces_ao_value(qmckl_context context)
"Unable to free ctx->forces.forces_ao_value");
}
ctx->forces.forces_ao_value = NULL;

}
}
/* Allocate array */
Expand All @@ -5340,7 +5460,7 @@ qmckl_exit_code qmckl_provide_forces_ao_value(qmckl_context context)
}
ctx->forces.forces_ao_value = forces_ao_value;
}

rc = qmckl_provide_ao_basis_ao_vgl(context);
if (rc != QMCKL_SUCCESS) {
return qmckl_failwith( context, rc, "qmckl_provide_ao_basis_ao_vgl", NULL);
Expand Down Expand Up @@ -5884,7 +6004,7 @@ integer(qmckl_exit_code) function qmckl_compute_forces_mo_value_doc(context, &
forces_mo_value(i,j,3,a) = forces_mo_value(i,j,3,a) + coefficient_t(i,k) * c3
end do
k = k + 1
end do
end do
end do
end do
end do
Expand Down Expand Up @@ -6205,7 +6325,7 @@ integer function qmckl_compute_forces_mo_g_doc(context, &
ao_num, mo_num, point_num, nucl_num, &
shell_num, nucleus_index, nucleus_shell_num, shell_ang_mom, &
coefficient_t, ao_hessian, forces_mo_g) &
bind(C) result(info)
bind(C) result(info)
use qmckl
implicit none
integer(qmckl_context), intent(in), value :: context
Expand Down Expand Up @@ -6249,7 +6369,7 @@ integer function qmckl_compute_forces_mo_g_doc(context, &
ishell_end = nucleus_index(a) + nucleus_shell_num(a)
do n = 1, 3
do j=1,point_num
do m = 1, 3
do m = 1, 3
do ishell = ishell_start, ishell_end
k = ao_index(ishell)
l = shell_ang_mom(ishell)
Expand Down Expand Up @@ -6603,7 +6723,7 @@ integer function qmckl_compute_forces_mo_l_doc(context, &
ao_num, mo_num, point_num, nucl_num, &
shell_num, nucleus_index, nucleus_shell_num, shell_ang_mom, &
coefficient_t, ao_hessian, forces_mo_l) &
bind(C) result(info)
bind(C) result(info)
use qmckl
implicit none
integer(qmckl_context), intent(in), value :: context
Expand Down

0 comments on commit c04ce98

Please sign in to comment.