Skip to content

Commit

Permalink
Add oogs (#129)
Browse files Browse the repository at this point in the history
  • Loading branch information
stgeke authored Aug 8, 2020
1 parent 59ca31c commit 3e27736
Show file tree
Hide file tree
Showing 44 changed files with 622 additions and 528 deletions.
6 changes: 2 additions & 4 deletions 3rd_party/gslib/ogs/include/ogsKernels.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -107,13 +107,11 @@ namespace ogs {


extern occa::kernel gatherKernel_floatAdd;
extern occa::kernel gatherKernel_floatAddSelf;
extern occa::kernel gatherKernel_floatMul;
extern occa::kernel gatherKernel_floatMin;
extern occa::kernel gatherKernel_floatMax;

extern occa::kernel gatherKernel_doubleAdd;
extern occa::kernel gatherKernel_doubleAddSelf;
extern occa::kernel gatherKernel_doubleMul;
extern occa::kernel gatherKernel_doubleMin;
extern occa::kernel gatherKernel_doubleMax;
Expand All @@ -129,7 +127,6 @@ namespace ogs {
extern occa::kernel gatherKernel_longMax;



extern occa::kernel gatherVecKernel_floatAdd;
extern occa::kernel gatherVecKernel_floatMul;
extern occa::kernel gatherVecKernel_floatMin;
Expand All @@ -151,7 +148,6 @@ namespace ogs {
extern occa::kernel gatherVecKernel_longMax;



extern occa::kernel gatherManyKernel_floatAdd;
extern occa::kernel gatherManyKernel_floatMul;
extern occa::kernel gatherManyKernel_floatMin;
Expand Down Expand Up @@ -193,6 +189,8 @@ namespace ogs {

void initKernels(MPI_Comm comm, occa::device device);

extern occa::properties kernelInfo;

void freeKernels();
}

Expand Down
20 changes: 13 additions & 7 deletions 3rd_party/gslib/ogs/ogs.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -112,14 +112,15 @@ SOFTWARE.
#ifndef OGS_HPP
#define OGS_HPP 1

#include <functional>
#include <math.h>
#include <stdlib.h>
#include <occa.hpp>

#include "mpi.h"
#include "types.h"

#define OGS_ENABLE_TIMER
//#define OGS_ENABLE_TIMER
#ifdef OGS_ENABLE_TIMER
#include "timer.hpp"
#endif
Expand Down Expand Up @@ -246,18 +247,23 @@ typedef struct {
occa::memory o_scatterOffsets, o_gatherOffsets;
occa::memory o_scatterIds, o_gatherIds;

occa::kernel packBufDoubleKernel, unpackBufDoubleKernel;
occa::kernel packBufFloatKernel, unpackBufFloatKernel;

oogs_mode mode;

} oogs_t;

// Declarations of the oogs entry points. Only prototypes are visible here, so
// the notes below describe the signatures; semantic remarks are hedged and
// should be confirmed against the definitions.
// NOTE(review): both 4-argument and (k, stride) forms of start/finish/setup are
// declared in this span — confirm which overloads are actually defined.
namespace oogs{

// gatherScatter is overloaded on the data argument: a raw pointer or an
// occa::memory handle. `type` and `op` are C strings; `h` is the oogs handle.
void gatherScatter(void *v, const char *type, const char *op, oogs_t *h);
void gatherScatter(occa::memory o_v, const char *type, const char *op, oogs_t *h);
// start/finish pair operating on an occa::memory handle
// (presumably a split-phase exchange — TODO confirm).
void start(occa::memory o_v, const char *type, const char *op, oogs_t *h);
void finish(occa::memory o_v, const char *type, const char *op, oogs_t *h);
// Builds an oogs_t from N ids on the given communicator and device.
oogs_t *setup(dlong N, hlong *ids, const char *type, MPI_Comm &comm,
int verbose, occa::device device, oogs_mode mode);
// Variants taking an additional count `k` and a `stride`
// (presumably the number of fields and the offset between them — TODO confirm).
void start(occa::memory o_v, const int k, const dlong stride, const char *type, const char *op, oogs_t *h);
void finish(occa::memory o_v, const int k, const dlong stride, const char *type, const char *op, oogs_t *h);
// startFinish is overloaded on a raw pointer or an occa::memory handle.
void startFinish(void *v, const int k, const dlong stride, const char *type, const char *op, oogs_t *h);
void startFinish(occa::memory o_v, const int k, const dlong stride, const char *type, const char *op, oogs_t *h);
// setup from an existing ogs_t handle; takes a std::function<void()> callback
// (when the callback is invoked is not visible here — verify in the definition).
oogs_t *setup(ogs_t *ogs, int nVec, dlong stride, const char *type, std::function<void()> callback, oogs_mode gsMode);
// setup from raw ids, with the same callback and mode arguments.
oogs_t *setup(dlong N, hlong *ids, const int k, const dlong stride, const char *type, MPI_Comm &comm,
int verbose, occa::device device, std::function<void()> callback, oogs_mode mode);
// Releases a handle returned by setup (presumably — TODO confirm what is freed).
void destroy(oogs_t *h);

}

Expand Down
44 changes: 0 additions & 44 deletions 3rd_party/gslib/ogs/okl/gather.okl
Original file line number Diff line number Diff line change
Expand Up @@ -47,28 +47,6 @@ SOFTWARE.
}
}

// For every gather row g, sums q[gatherIds[n]] over n in
// [gatherStarts[g], gatherStarts[g+1]) and adds that sum onto the value
// already stored in gatherq[g].
@kernel void gather_floatAddSelf(const dlong Ngather,
                                 @restrict const dlong * gatherStarts,
                                 @restrict const dlong * gatherIds,
                                 @restrict const float * q,
                                 @restrict float * gatherq){

  for(dlong row=0;row<Ngather;++row;@tile(256,@outer,@inner)){

    // half-open range of gatherIds entries belonging to this row
    const dlong first = gatherStarts[row];
    const dlong last  = gatherStarts[row+1];

    float sum = 0.f;
    for(dlong j=first;j<last;++j){
      sum += q[gatherIds[j]];
    }

    // output is contiguously packed; accumulate onto the existing entry
    gatherq[row] = gatherq[row] + sum;
  }
}

@kernel void gather_doubleAdd(const dlong Ngather,
@restrict const dlong * gatherStarts,
@restrict const dlong * gatherIds,
Expand All @@ -91,28 +69,6 @@ SOFTWARE.
}
}

// For every gather row g, sums q[gatherIds[n]] over n in
// [gatherStarts[g], gatherStarts[g+1]) and adds that sum onto the value
// already stored in gatherq[g].
@kernel void gather_doubleAddSelf(const dlong Ngather,
                                  @restrict const dlong * gatherStarts,
                                  @restrict const dlong * gatherIds,
                                  @restrict const double * q,
                                  @restrict double * gatherq){

  for(dlong row=0;row<Ngather;++row;@tile(256,@outer,@inner)){

    // half-open range of gatherIds entries belonging to this row
    const dlong first = gatherStarts[row];
    const dlong last  = gatherStarts[row+1];

    double sum = 0.0;
    for(dlong j=first;j<last;++j){
      sum += q[gatherIds[j]];
    }

    // output is contiguously packed; accumulate onto the existing entry
    gatherq[row] = gatherq[row] + sum;
  }
}

@kernel void gather_intAdd(const dlong Ngather,
@restrict const dlong * gatherStarts,
@restrict const dlong * gatherIds,
Expand Down
97 changes: 97 additions & 0 deletions 3rd_party/gslib/ogs/okl/oogs.okl
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
// Spreads each entry of q to all of its listed destinations in scatterq.
// Flat index s is read as node (s % Nscatter) and component (s / Nscatter);
// q[s] is written to scatterq[id*Nentries + component] for every id in
// scatterIds[scatterStarts[node] .. scatterStarts[node+1]).
@kernel void packBuf_float(const dlong Nscatter,
                           const int Nentries,
                           @restrict const dlong * scatterStarts,
                           @restrict const dlong * scatterIds,
                           @restrict const float * q,
                           @restrict float * scatterq){

  for(dlong idx=0;idx<Nscatter*Nentries;++idx;@tile(256,@outer,@inner)){

    // split the flat index into component and node
    const int comp = idx/Nscatter;
    const dlong node = idx%Nscatter;

    // half-open range of scatterIds entries belonging to this node
    const dlong first = scatterStarts[node];
    const dlong last  = scatterStarts[node+1];

    const float value = q[idx];

    for(dlong j=first;j<last;++j){
      // destination keeps the Nentries components of one id adjacent
      scatterq[scatterIds[j]*Nentries+comp] = value;
    }
  }
}

// Spreads each entry of q to all of its listed destinations in scatterq.
// Flat index s is read as node (s % Nscatter) and component (s / Nscatter);
// q[s] is written to scatterq[id*Nentries + component] for every id in
// scatterIds[scatterStarts[node] .. scatterStarts[node+1]).
@kernel void packBuf_double(const dlong Nscatter,
                            const int Nentries,
                            @restrict const dlong * scatterStarts,
                            @restrict const dlong * scatterIds,
                            @restrict const double * q,
                            @restrict double * scatterq){

  for(dlong idx=0;idx<Nscatter*Nentries;++idx;@tile(256,@outer,@inner)){

    // split the flat index into component and node
    const int comp = idx/Nscatter;
    const dlong node = idx%Nscatter;

    // half-open range of scatterIds entries belonging to this node
    const dlong first = scatterStarts[node];
    const dlong last  = scatterStarts[node+1];

    const double value = q[idx];

    for(dlong j=first;j<last;++j){
      // destination keeps the Nentries components of one id adjacent
      scatterq[scatterIds[j]*Nentries+comp] = value;
    }
  }
}

// Reduces entries of q into gatherq by summation.
// Flat index g is read as node (g % Ngather) and component (g / Ngather);
// the sum of q[id*Nentries + component] over every id in
// gatherIds[gatherStarts[node] .. gatherStarts[node+1]) is added onto the
// value already stored in gatherq[g].
@kernel void unpackBuf_float(const dlong Ngather,
                             const int Nentries,
                             @restrict const dlong * gatherStarts,
                             @restrict const dlong * gatherIds,
                             @restrict const float * q,
                             @restrict float * gatherq){

  for(dlong idx=0;idx<Ngather*Nentries;++idx;@tile(256,@outer,@inner)){

    // split the flat index into component and node
    const int comp = idx/Ngather;
    const dlong node = idx%Ngather;

    // half-open range of gatherIds entries belonging to this node
    const dlong first = gatherStarts[node];
    const dlong last  = gatherStarts[node+1];

    float sum = 0.f;
    for(dlong j=first;j<last;++j){
      sum += q[gatherIds[j]*Nentries+comp];
    }

    // output is contiguously packed; accumulate onto the existing entry
    gatherq[idx] = gatherq[idx] + sum;
  }
}

// Reduces entries of q into gatherq by summation.
// Flat index g is read as node (g % Ngather) and component (g / Ngather);
// the sum of q[id*Nentries + component] over every id in
// gatherIds[gatherStarts[node] .. gatherStarts[node+1]) is added onto the
// value already stored in gatherq[g].
@kernel void unpackBuf_double(const dlong Ngather,
                              const int Nentries,
                              @restrict const dlong * gatherStarts,
                              @restrict const dlong * gatherIds,
                              @restrict const double * q,
                              @restrict double * gatherq){

  for(dlong idx=0;idx<Ngather*Nentries;++idx;@tile(256,@outer,@inner)){

    // split the flat index into component and node
    const int comp = idx/Ngather;
    const dlong node = idx%Ngather;

    // half-open range of gatherIds entries belonging to this node
    const dlong first = gatherStarts[node];
    const dlong last  = gatherStarts[node+1];

    double sum = 0.0;
    for(dlong j=first;j<last;++j){
      sum += q[gatherIds[j]*Nentries+comp];
    }

    // output is contiguously packed; accumulate onto the existing entry
    gatherq[idx] = gatherq[idx] + sum;
  }
}


2 changes: 1 addition & 1 deletion 3rd_party/gslib/ogs/okl/scatterMany.okl
Original file line number Diff line number Diff line change
Expand Up @@ -127,4 +127,4 @@ SOFTWARE.
scatterq[id+k*sstride] = qs;
}
}
}
}
4 changes: 0 additions & 4 deletions 3rd_party/gslib/ogs/src/ogsGather.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -313,8 +313,6 @@ void occaGather(const dlong Ngather,

if ((!strcmp(type, "float"))&&(!strcmp(op, "add")))
ogs::gatherKernel_floatAdd(Ngather, o_gatherStarts, o_gatherIds, o_v, o_gv);
else if ((!strcmp(type, "float"))&&(!strcmp(op, "add+self")))
ogs::gatherKernel_floatAddSelf(Ngather, o_gatherStarts, o_gatherIds, o_v, o_gv);
else if ((!strcmp(type, "float"))&&(!strcmp(op, "mul")))
ogs::gatherKernel_floatMul(Ngather, o_gatherStarts, o_gatherIds, o_v, o_gv);
else if ((!strcmp(type, "float"))&&(!strcmp(op, "min")))
Expand All @@ -323,8 +321,6 @@ void occaGather(const dlong Ngather,
ogs::gatherKernel_floatMax(Ngather, o_gatherStarts, o_gatherIds, o_v, o_gv);
else if ((!strcmp(type, "double"))&&(!strcmp(op, "add")))
ogs::gatherKernel_doubleAdd(Ngather, o_gatherStarts, o_gatherIds, o_v, o_gv);
else if ((!strcmp(type, "double"))&&(!strcmp(op, "add+self")))
ogs::gatherKernel_doubleAddSelf(Ngather, o_gatherStarts, o_gatherIds, o_v, o_gv);
else if ((!strcmp(type, "double"))&&(!strcmp(op, "mul")))
ogs::gatherKernel_doubleMul(Ngather, o_gatherStarts, o_gatherIds, o_v, o_gv);
else if ((!strcmp(type, "double"))&&(!strcmp(op, "min")))
Expand Down
7 changes: 0 additions & 7 deletions 3rd_party/gslib/ogs/src/ogsGatherMany.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@ void ogsGatherManyStart(occa::memory o_gv,
const char *type,
const char *op,
ogs_t *ogs){

size_t Nbytes;
if (!strcmp(type, "float"))
Nbytes = sizeof(float);
Expand Down Expand Up @@ -116,14 +115,8 @@ void ogsGatherManyFinish(occa::memory o_gv,
void* H[k];
for (int i=0;i<k;i++) H[i] = (char*)ogs::haloBuf + i*ogs->NhaloGather*Nbytes;

#ifdef OGS_ENABLE_TIMER
timer::tic("gsMPI",1);
#endif
// MPI based gather using libgs
ogsHostGatherMany(H, k, type, op, ogs->haloGshNonSym);
#ifdef OGS_ENABLE_TIMER
timer::toc("gsMPI");
#endif

// copy the fully gathered halo data back from HOST to DEVICE
if (ogs->NownedHalo)
Expand Down
7 changes: 0 additions & 7 deletions 3rd_party/gslib/ogs/src/ogsGatherScatter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ void ogsGatherScatterStart(occa::memory o_v,
const char *type,
const char *op,
ogs_t *ogs){

size_t Nbytes;

if (!strcmp(type, "double"))
Expand Down Expand Up @@ -104,13 +103,7 @@ void ogsGatherScatterFinish(occa::memory o_v,
ogs->device.finish();

// MPI based gather scatter using libgs
#ifdef OGS_ENABLE_TIMER
timer::tic("gsMPI",1);
#endif
ogsHostGatherScatter(ogs::haloBuf, type, op, ogs->haloGshSym);
#ifdef OGS_ENABLE_TIMER
timer::toc("gsMPI");
#endif

// copy the fully gathered halo data back from HOST to DEVICE
ogs::o_haloBuf.copyFrom(ogs::haloBuf, ogs->NhaloGather*Nbytes, 0, "async: true");
Expand Down
7 changes: 0 additions & 7 deletions 3rd_party/gslib/ogs/src/ogsGatherScatterMany.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@ void ogsGatherScatterManyStart(occa::memory o_v,
const char *type,
const char *op,
ogs_t *ogs){

size_t Nbytes;
if (!strcmp(type, "float"))
Nbytes = sizeof(float);
Expand Down Expand Up @@ -127,14 +126,8 @@ void ogsGatherScatterManyFinish(occa::memory o_v,
void* H[k];
for (int i=0;i<k;i++) H[i] = (char*)ogs::haloBuf + i*ogs->NhaloGather*Nbytes;

#ifdef OGS_ENABLE_TIMER
timer::tic("gsMPI",1);
#endif
// MPI based gather scatter using libgs
ogsHostGatherScatterMany(H, k, type, op, ogs->haloGshSym);
#ifdef OGS_ENABLE_TIMER
timer::toc("gsMPI");
#endif

// copy the fully gathered halo data back from HOST to DEVICE
ogs::o_haloBuf.copyFrom(ogs::haloBuf, ogs->NhaloGather*Nbytes*k, 0, "async: true");
Expand Down
7 changes: 0 additions & 7 deletions 3rd_party/gslib/ogs/src/ogsGatherScatterVec.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@ void ogsGatherScatterVecStart(occa::memory o_v,
const char *type,
const char *op,
ogs_t *ogs){

size_t Nbytes;
if (!strcmp(type, "float"))
Nbytes = sizeof(float);
Expand Down Expand Up @@ -106,14 +105,8 @@ void ogsGatherScatterVecFinish(occa::memory o_v,
ogs->device.setStream(ogs::dataStream);
ogs->device.finish();

#ifdef OGS_ENABLE_TIMER
timer::tic("gsMPI",1);
#endif
// MPI based gather scatter using libgs
ogsHostGatherScatterVec(ogs::haloBuf, k, type, op, ogs->haloGshSym);
#ifdef OGS_ENABLE_TIMER
timer::toc("gsMPI");
#endif

// copy the fully gathered halo data back from HOST to DEVICE
ogs::o_haloBuf.copyFrom(ogs::haloBuf, ogs->NhaloGather*Nbytes*k, 0, "async: true");
Expand Down
8 changes: 0 additions & 8 deletions 3rd_party/gslib/ogs/src/ogsGatherVec.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@ void ogsGatherVecStart(occa::memory o_gv,
const char *type,
const char *op,
ogs_t *ogs){

size_t Nbytes;
if (!strcmp(type, "float"))
Nbytes = sizeof(float);
Expand Down Expand Up @@ -80,7 +79,6 @@ void ogsGatherVecStart(occa::memory o_gv,
ogs::o_haloBuf.copyTo(ogs::haloBuf, ogs->NhaloGather*Nbytes*k, 0, "async: true");
ogs->device.setStream(ogs::defaultStream);
}

}


Expand Down Expand Up @@ -108,14 +106,8 @@ void ogsGatherVecFinish(occa::memory o_gv,
ogs->device.setStream(ogs::dataStream);
ogs->device.finish();

#ifdef OGS_ENABLE_TIMER
timer::tic("gsMPI",1);
#endif
// MPI based gather using libgs
ogsHostGatherVec(ogs::haloBuf, k, type, op, ogs->haloGshNonSym);
#ifdef OGS_ENABLE_TIMER
timer::toc("gsMPI");
#endif

// copy the fully gathered halo data back from HOST to DEVICE
if (ogs->NownedHalo)
Expand Down
2 changes: 1 addition & 1 deletion 3rd_party/gslib/ogs/src/ogsHostSetup.c
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ void *ogsHostSetup(MPI_Comm meshComm,
id[n] = (slong) gatherGlobalNodes[n];
}

struct gs_data *gsh = gs_setup(id, NuniqueBases, &com, nonsymm, gs_pairwise, verbose); // gs_auto, gs_crystal_router, gs_pw
struct gs_data *gsh = gs_setup(id, NuniqueBases, &com, nonsymm, gs_pairwise, 0); // gs_auto, gs_crystal_router, gs_pw

free(id);

Expand Down
Loading

0 comments on commit 3e27736

Please sign in to comment.