Skip to content

Commit

Permalink
Adding 4-index transform and minor fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
hjjvandam committed Oct 30, 2024
1 parent 58630de commit 295e990
Show file tree
Hide file tree
Showing 6 changed files with 275 additions and 17 deletions.
10 changes: 5 additions & 5 deletions src/noft/noft_2index.fi
Original file line number Diff line number Diff line change
Expand Up @@ -35,25 +35,25 @@
call ga_inquire(g_vec,itp_vec,n_in,n_out)
!
if (itp_opi.ne.itp_vec)
& call errquit(pname//" mismatch data types")
& call errquit(pname//" mismatch data types",10,UERR)
itp_opo = itp_opi
if (n_opi.ne.n_in)
& call errquit(pname//" mismatch dimensions")
& call errquit(pname//" mismatch dimensions",20,UERR)
n_opo = n_out
!
aa = 1.0_dp
bb = 0.0_dp
chunk = 10
if (.not.ga_create(itp_opo,1,n_opi,1,n_opo,"g_tmp",
& chunk,chunk,g_tmp))
& call errquit(pname//" failed to create g_tmp")
& call errquit(pname//" failed to create g_tmp",30,GA_ERR)
if (.not.ga_create(itp_opo,1,n_opo,1,n_opo,"g_opo",
& chunk,chunk,g_opo))
& call errquit(pname//" failed to create g_opo")
& call errquit(pname//" failed to create g_opo",40,GA_ERR)
call ga_dgemm(tn,tn,n_opi,n_opo,n_opi,aa,g_opi,g_vec,bb,g_tmp)
call ga_dgemm(ty,tn,n_opo,n_opo,n_opi,aa,g_vec,g_tmp,bb,g_opo)
if (.not.ga_destroy(g_tmp))
& call errquit(pname//" failed to destroy g_tmp")
& call errquit(pname//" failed to destroy g_tmp",50,GA_ERR)
!
end subroutine noft_2index
!
Expand Down
257 changes: 257 additions & 0 deletions src/noft/noft_4index.fi
Original file line number Diff line number Diff line change
@@ -0,0 +1,257 @@
!-----------------------------------------------------------------------
!>
!> \brief Transform a 4D tensor
!>
!> We assume a 4D tensor where the first 2 dimensions correspond to
!> electron 1, and the second 2 dimensions correspond to electron 2.
!> The transformation to apply to the indeces of electron 1 might be
!> different from the transformation for electron 2. Therefore we'll
!> have two sets of transformations. The dimensions of the
!> transformations will be the same, i.e. if g_vec1 is an NxM
!> transformation then g_vec2 has to be an NxM transformation as well.
!>
!> A Global Array limitation is that the toolkit only implements
!> standard matrix-matrix multiplications. Therefore, to perform the
!> transformation of two indeces in a 4-index transformation we need
!> to do this through N**2 matrix transformations. Obviously this
!> is a disaster in terms of scalability and raw performance. Therefore
!> we'll have to reimplement this downstream to get good performance,
!> but now we are more concerned with correctness than performance. So
!> the simpler the algorithm the better.
!>
subroutine noft_4index(g_opi,g_vec1,g_vec2,g_opo)
implicit none
#include "errquit.fh"
#include "global.fh"
#include "mafdecls.fh"
!> Input operator
integer, intent(in) :: g_opi
!> Transformation vectors for electron 1
integer, intent(in) :: g_vec1
!> Transformation vectors for electron 2
integer, intent(in) :: g_vec2
!> Output operator
integer, intent(out) :: g_opo
!
! Local
!
character(len=11), parameter :: pname = "noft_4index"
character(len=1), parameter :: tn = "N"
character(len=1), parameter :: ty = "Y"
integer :: g_i1i1i2i2
integer :: g_o1o1i2i2
integer :: g_o1o1o2o2
integer :: g_oi
integer :: n_in1, n_in2
integer :: n_out1, n_out2
integer :: itp_opi
integer :: itp_vec1
integer :: itp_vec2
integer :: ii, jj, kk, ll
real(kind=dp) :: aa, bb
integer, parameter :: ndim = 4
integer :: n_opi(ndim)
integer :: nn(ndim)
integer :: chnk(ndim)
integer :: alo(ndim), ahi(ndim)
integer :: blo(ndim), bhi(ndim)
integer :: clo(ndim), chi(ndim)
!
call nga_inquire(g_opi,itp_opi,n_opi,nn)
call ga_inquire(g_vec1,itp_vec1,n_in1,n_out1)
call ga_inquire(g_vec2,itp_vec2,n_in2,n_out2)
!
! Check types
!
if (itp_opi.ne.itp_vec1)
& call errquit(pname//" mismatch type opi and vec1",10,UERR)
if (itp_opi.ne.itp_vec2)
& call errquit(pname//" mismatch type opi and vec2",20,UERR)
!
! Check dimensions
!
if (n_in1.ne.n_in2)
& call errquit(pname//" mismatch n_in1 and n_in2",30,UERR)
if (n_out1.ne.n_out2)
& call errquit(pname//" mismatch n_out1 and n_out2",40,UERR)
do ii = 1, ndim
if (n_opi(ii).ne.n_in1)
& call errquit(pname//" mismatch g_opi and n_in1",50,UERR)
enddo
!
aa = 1.0_dp
bb = 0.0_dp
chnk = 10
g_i1i1i2i2 = g_opi
nn(3) = n_out1
nn(4) = n_out1
if (.not.nga_create(itp_opi,ndim,nn,"g_ooii",chnk,g_o1o1i2i2))
& call errquit(pname//" nga_create failed for g_iioo",60,GA_ERR)
nn(1) = n_out1
if (.not.nga_create(itp_opi,2,nn,"g_oi",chnk,g_oi))
& call errquit(pname//" nga_create failed for g_oi",70,GA_ERR)
nn(2) = n_out1
if (.not.nga_create(itp_opi,ndim,nn,"g_oooo",chnk,g_o1o1o2o2))
& call errquit(pname//" nga_create failed for g_oooo",80,GA_ERR)
!
! Transform the indeces of electron 1
!
do ll = 1, n_in1
do kk = 1, n_in1
!
alo(1) = 1
alo(2) = 1
alo(3) = -1
alo(4) = -1
ahi(1) = n_out1
ahi(2) = n_in1
ahi(3) = -2
ahi(4) = -2
!
blo(1) = 1
blo(2) = 1
blo(3) = kk
blo(4) = ll
bhi(1) = n_in1
bhi(2) = n_in1
bhi(3) = kk
bhi(4) = ll
!
clo(1) = 1
clo(2) = 1
clo(3) = -1
clo(4) = -1
chi(1) = n_out1
chi(2) = n_in1
chi(3) = -2
chi(4) = -2
!
! Transform index 1
!
call nga_matmul_patch('t','n',aa,bb,
& g_vec1, alo,ahi,
& g_i1i1i2i2,blo,bhi,
& g_oi, clo,chi)
!
alo(1) = 1
alo(2) = 1
alo(3) = -1
alo(4) = -1
ahi(1) = n_out1
ahi(2) = n_in1
ahi(3) = -2
ahi(4) = -2
!
blo(1) = 1
blo(2) = 1
blo(3) = -1
blo(4) = -1
bhi(1) = n_in1
bhi(2) = n_out1
bhi(3) = -2
bhi(4) = -2
!
clo(1) = 1
clo(2) = 1
clo(3) = kk
clo(4) = ll
chi(1) = n_out1
chi(2) = n_out1
chi(3) = kk
chi(4) = ll
!
! Transform index 2
!
call nga_matmul_patch('n','n',aa,bb,
& g_oi, alo,ahi,
& g_vec1, blo,bhi,
& g_o1o1i2i2,clo,chi)
!
enddo ! kk
enddo ! ll
!
! Transform the indeces of electron 2
!
do jj = 1, n_out1
do ii = 1, n_out1
!
alo(1) = 1
alo(2) = 1
alo(3) = -1
alo(4) = -1
ahi(1) = n_out1
ahi(2) = n_in1
ahi(3) = -2
ahi(4) = -2
!
blo(1) = ii
blo(2) = jj
blo(3) = 1
blo(4) = 1
bhi(1) = ii
bhi(2) = jj
bhi(3) = n_in1
bhi(4) = n_in1
!
clo(1) = 1
clo(2) = 1
clo(3) = -1
clo(4) = -1
chi(1) = n_out1
chi(2) = n_in1
chi(3) = -2
chi(4) = -2
!
! Transform index 3
!
call nga_matmul_patch('t','n',aa,bb,
& g_vec2, alo,ahi,
& g_o1o1i2i2,blo,bhi,
& g_oi, clo,chi)
!
alo(1) = 1
alo(2) = 1
alo(3) = -1
alo(4) = -1
ahi(1) = n_out1
ahi(2) = n_in1
ahi(3) = -2
ahi(4) = -2
!
blo(1) = 1
blo(2) = 1
blo(3) = -1
blo(4) = -1
bhi(1) = n_in1
bhi(2) = n_out1
bhi(3) = -2
bhi(4) = -2
!
clo(1) = ii
clo(2) = jj
clo(3) = 1
clo(4) = 1
chi(1) = ii
chi(2) = jj
chi(3) = n_out1
chi(4) = n_out1
!
! Transform index 4
!
call nga_matmul_patch('n','n',aa,bb,
& g_oi, alo,ahi,
& g_vec2, blo,bhi,
& g_o1o1o2o2,clo,chi)
!
enddo ! ii
enddo ! jj
!
g_opo = g_o1o1o2o2
if (.not.ga_destroy(g_oi))
& call errquit(pname//" failed to destroy g_oi",90,GA_ERR)
if (.not.ga_destroy(g_o1o1i2i2))
& call errquit(pname//" failed to destroy g_ooii",100,GA_ERR)
!
end subroutine noft_4index
!
!-----------------------------------------------------------------------
6 changes: 3 additions & 3 deletions src/noft/noft_lindep.fi
Original file line number Diff line number Diff line change
Expand Up @@ -108,11 +108,11 @@
call ga_copy(g_vec,g_s12)
else
if (.not.ga_destroy(g_s12))
& call errquit(pname//"destroy failed")
& call errquit(pname//"destroy failed",10,GA_ERR)
idims(1) = nao
idims(2) = nindep
if (.not.nga_create(MT_DBL,2,idims,"svecs",chunk,g_s12))
& call errquit(pname//"failed to allocate s12")
& call errquit(pname//"failed to allocate s12",20,GA_ERR)
vlo(1) = 1
vlo(2) = nao-nindep+1
vhi(1) = nao
Expand All @@ -124,7 +124,7 @@
call nga_copy_patch("N",g_vec,vlo,vhi,g_s12,slo,shi)
endif
if (.not.ga_destroy(g_vec))
& call errquit(pname//"failed to destroy g_vec",0,GA_ERR)
& call errquit(pname//"failed to destroy g_vec",30,GA_ERR)
end subroutine noft_lindep
!
!-----------------------------------------------------------------------
16 changes: 8 additions & 8 deletions src/noft/noft_load_vectors.fi
Original file line number Diff line number Diff line change
Expand Up @@ -132,35 +132,35 @@
!
tag = "noft:input vectors"
if (.not. rtdb_cput(rtdb, tag, 1, movecs_in)) then
call errquit("writing movecs to RTDB failed")
call errquit("writing movecs to RTDB failed",10,RTDB_ERR)
endif
tag = "noft:output vectors"
if (.not. rtdb_cput(rtdb, tag, 1, movecs_out)) then
call errquit("writing movecs to RTDB failed")
call errquit("writing movecs to RTDB failed",20,RTDB_ERR)
endif
tag = "noft:input s-vectors"
if (.not. rtdb_cput(rtdb, tag, 1, svecs_in)) then
call errquit("writing movecs to RTDB failed")
call errquit("writing movecs to RTDB failed",30,RTDB_ERR)
endif
tag = "noft:output s-vectors"
if (.not. rtdb_cput(rtdb, tag, 1, svecs_out)) then
call errquit("writing movecs to RTDB failed")
call errquit("writing movecs to RTDB failed",40,RTDB_ERR)
endif
tag = "noft:input m-vectors"
if (.not. rtdb_cput(rtdb, tag, 1, mvecs_in)) then
call errquit("writing movecs to RTDB failed")
call errquit("writing movecs to RTDB failed",50,RTDB_ERR)
endif
tag = "noft:output m-vectors"
if (.not. rtdb_cput(rtdb, tag, 1, mvecs_out)) then
call errquit("writing movecs to RTDB failed")
call errquit("writing movecs to RTDB failed",60,RTDB_ERR)
endif
tag = "noft:input t-vectors"
if (.not. rtdb_cput(rtdb, tag, 1, tvecs_in)) then
call errquit("writing movecs to RTDB failed")
call errquit("writing movecs to RTDB failed",70,RTDB_ERR)
endif
tag = "noft:output t-vectors"
if (.not. rtdb_cput(rtdb, tag, 1, tvecs_out)) then
call errquit("writing movecs to RTDB failed")
call errquit("writing movecs to RTDB failed",80,RTDB_ERR)
endif
!
end subroutine noft_load_vectors
Expand Down
1 change: 1 addition & 0 deletions src/noft/noft_module.F
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ module noft
#include "noft_wavefunction_tp.fh"
contains
#include "noft_2index.fi"
#include "noft_4index.fi"
#include "noft_adjust_wavefunction.fi"
#include "noft_compute_1electron.fi"
#include "noft_compute_2electron.fi"
Expand Down
2 changes: 1 addition & 1 deletion src/noft/noft_ortho_operators.fi
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
g_vec = noft_wavefunction%so
call noft_2index(g_el_1,g_vec,g_out)
if (.not.ga_destroy(g_el_1)) then
call errquit(pname//" failed to destroy g_el_1")
call errquit(pname//" failed to destroy g_el_1",10,GA_ERR)
else
noft_oper%ao_el_1 = 0
endif
Expand Down

0 comments on commit 295e990

Please sign in to comment.