Skip to content

Commit

Permalink
parallel code blocks for stencil (for multi rhs and block decomposition)
Browse files Browse the repository at this point in the history
  • Loading branch information
lehner committed Oct 6, 2023
1 parent 05f2ac1 commit 2a3bc02
Show file tree
Hide file tree
Showing 8 changed files with 114 additions and 46 deletions.
4 changes: 2 additions & 2 deletions lib/cgpt/lib/lattice/base.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,8 @@ class cgpt_Lattice_base {
virtual void invert_matrix(std::vector<cgpt_Lattice_base*>& matrix_inv, std::vector<cgpt_Lattice_base*>& matrix, long n_virtual) = 0;
virtual void determinant(cgpt_Lattice_base* det, std::vector<cgpt_Lattice_base*>& matrix, long n_virtual) = 0; // this determines type of matrix[0]
virtual GridBase* get_grid() = 0;
virtual cgpt_stencil_matrix_base* stencil_matrix(GridBase* grid, PyObject* shifts, PyObject* code) = 0;
virtual cgpt_stencil_matrix_vector_base* stencil_matrix_vector(cgpt_Lattice_base* matrix, GridBase* grid, PyObject* shifts, PyObject* code) = 0;
virtual cgpt_stencil_matrix_base* stencil_matrix(GridBase* grid, PyObject* shifts, PyObject* code, long code_parallel_block_size) = 0;
virtual cgpt_stencil_matrix_vector_base* stencil_matrix_vector(cgpt_Lattice_base* matrix, GridBase* grid, PyObject* shifts, PyObject* code, long code_parallel_block_size) = 0;

};

Expand Down
8 changes: 4 additions & 4 deletions lib/cgpt/lib/lattice/implementation.h
Original file line number Diff line number Diff line change
Expand Up @@ -273,12 +273,12 @@ class cgpt_Lattice : public cgpt_Lattice_base {
return l.Grid();
}

virtual cgpt_stencil_matrix_base* stencil_matrix(GridBase* grid, PyObject* shifts, PyObject* code) {
return cgpt_stencil_matrix_create<T>(grid, shifts, code);
virtual cgpt_stencil_matrix_base* stencil_matrix(GridBase* grid, PyObject* shifts, PyObject* code, long code_parallel_block_size) {
return cgpt_stencil_matrix_create<T>(grid, shifts, code, code_parallel_block_size);
}

virtual cgpt_stencil_matrix_vector_base* stencil_matrix_vector(cgpt_Lattice_base* matrix, GridBase* grid, PyObject* shifts, PyObject* code) {
return cgpt_stencil_matrix_vector_create<T>(matrix, grid, shifts, code);
virtual cgpt_stencil_matrix_vector_base* stencil_matrix_vector(cgpt_Lattice_base* matrix, GridBase* grid, PyObject* shifts, PyObject* code, long code_parallel_block_size) {
return cgpt_stencil_matrix_vector_create<T>(matrix, grid, shifts, code, code_parallel_block_size);
}

};
14 changes: 10 additions & 4 deletions lib/cgpt/lib/stencil.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,17 @@ EXPORT(stencil_matrix_create,{
void* _grid;
void* _lattice;
PyObject* _shifts, * _code;
if (!PyArg_ParseTuple(args, "llOO", &_lattice, &_grid, &_shifts, &_code)) {
long _code_parallel_block_size;
if (!PyArg_ParseTuple(args, "llOOl", &_lattice, &_grid, &_shifts, &_code,
&_code_parallel_block_size)) {
return NULL;
}

GridBase* grid = (GridBase*)_grid;
cgpt_Lattice_base* lattice = (cgpt_Lattice_base*)_lattice;

return PyLong_FromVoidPtr(lattice->stencil_matrix(grid, _shifts, _code));
return PyLong_FromVoidPtr(lattice->stencil_matrix(grid, _shifts, _code,
_code_parallel_block_size));
});

EXPORT(stencil_matrix_vector_create,{
Expand All @@ -39,15 +42,18 @@ EXPORT(stencil_matrix_vector_create,{
void* _lattice_matrix;
void* _lattice_vector;
PyObject* _shifts, * _code;
if (!PyArg_ParseTuple(args, "lllOO", &_lattice_matrix, &_lattice_vector, &_grid, &_shifts, &_code)) {
long _code_parallel_block_size;
if (!PyArg_ParseTuple(args, "lllOOl", &_lattice_matrix, &_lattice_vector, &_grid, &_shifts, &_code,
&_code_parallel_block_size)) {
return NULL;
}

GridBase* grid = (GridBase*)_grid;
cgpt_Lattice_base* lattice_matrix = (cgpt_Lattice_base*)_lattice_matrix;
cgpt_Lattice_base* lattice_vector = (cgpt_Lattice_base*)_lattice_vector;

return PyLong_FromVoidPtr(lattice_vector->stencil_matrix_vector(lattice_matrix, grid, _shifts, _code));
return PyLong_FromVoidPtr(lattice_vector->stencil_matrix_vector(lattice_matrix, grid, _shifts, _code,
_code_parallel_block_size));
});

EXPORT(stencil_matrix_execute,{
Expand Down
33 changes: 23 additions & 10 deletions lib/cgpt/lib/stencil/matrix.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,18 @@ class cgpt_stencil_matrix : public cgpt_stencil_matrix_base {
GeneralLocalStencil stencil;
Vector<cgpt_stencil_matrix_code_offload_t> code;
Vector<cgpt_stencil_matrix_factor_t> factors;

int n_code_parallel_block_size, n_code_parallel_blocks;

cgpt_stencil_matrix(GridBase* grid,
const std::vector<Coordinate>& shifts,
const std::vector<cgpt_stencil_matrix_code_t>& _code) :
stencil(grid,shifts), code(_code.size()) {
const std::vector<cgpt_stencil_matrix_code_t>& _code,
int _n_code_parallel_block_size) :
stencil(grid,shifts), code(_code.size()),
n_code_parallel_block_size(_n_code_parallel_block_size) {

ASSERT(_code.size() % n_code_parallel_block_size == 0);
n_code_parallel_blocks = (int)_code.size() / n_code_parallel_block_size;

// total number of factors
int nfactors = 0;
for (int i=0;i<_code.size();i++)
Expand All @@ -83,25 +90,31 @@ class cgpt_stencil_matrix : public cgpt_stencil_matrix_base {
VECTOR_VIEW_OPEN(fields,fields_v,AcceleratorWrite);

int n_code = code.size();
cgpt_stencil_matrix_code_offload_t* p_code = &code[0];
const cgpt_stencil_matrix_code_offload_t* p_code = &code[0];

typedef decltype(coalescedRead(fields_v[0][0])) obj_t;

int nd = fields[0].Grid()->Nd();

auto sview = stencil.View();

accelerator_for(ss,fields[0].Grid()->oSites(),T::Nsimd(),{
accelerator_for(ss_block,fields[0].Grid()->oSites() * n_code_parallel_blocks,T::Nsimd(),{

auto ss = ss_block / n_code_parallel_blocks;
auto oblock = ss_block % n_code_parallel_blocks;

for (int iblock=0;iblock<n_code_parallel_block_size;iblock++) {

int i = oblock * n_code_parallel_block_size + iblock;

for (int i=0;i<n_code;i++) {
obj_t t;

auto _f0 = &p_code[i].factor[0];
fetch(t, _f0->point, ss, fields_v[_f0->index], _f0->adj);

for (int j=1;j<p_code[i].size;j++) {
obj_t f;
auto _f = &p_code[i].factor[j];
const auto _f = &p_code[i].factor[j];
fetch(f, _f->point, ss, fields_v[_f->index], _f->adj);
t = t * f;
}
Expand Down Expand Up @@ -145,20 +158,20 @@ static void cgpt_convert(PyObject* in, cgpt_stencil_matrix_code_t& out) {
// not implemented message
template<typename T>
NotEnableIf<isEndomorphism<T>,cgpt_stencil_matrix_base*>
cgpt_stencil_matrix_create(GridBase* grid, PyObject* _shifts, PyObject* _code) {
cgpt_stencil_matrix_create(GridBase* grid, PyObject* _shifts, PyObject* _code, long code_parallel_block_size) {
ERR("cgpt_stencil_matrix not implemented for type %s",typeid(T).name());
}

// implemented for endomorphisms
template<typename T>
EnableIf<isEndomorphism<T>,cgpt_stencil_matrix_base*>
cgpt_stencil_matrix_create(GridBase* grid, PyObject* _shifts, PyObject* _code) {
cgpt_stencil_matrix_create(GridBase* grid, PyObject* _shifts, PyObject* _code, long code_parallel_block_size) {

std::vector<Coordinate> shifts;
cgpt_convert(_shifts,shifts);

std::vector<cgpt_stencil_matrix_code_t> code;
cgpt_convert(_code,code);

return new cgpt_stencil_matrix<T>(grid,shifts,code);
return new cgpt_stencil_matrix<T>(grid,shifts,code,code_parallel_block_size);
}
37 changes: 25 additions & 12 deletions lib/cgpt/lib/stencil/matrix_vector.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,18 @@ class cgpt_stencil_matrix_vector : public cgpt_stencil_matrix_vector_base {
GeneralLocalStencil stencil;
Vector<cgpt_stencil_matrix_vector_code_offload_t> code;
Vector<cgpt_stencil_matrix_vector_factor_t> factors;
int n_code_parallel_block_size, n_code_parallel_blocks;

cgpt_stencil_matrix_vector(GridBase* grid,
const std::vector<Coordinate>& shifts,
const std::vector<cgpt_stencil_matrix_vector_code_t>& _code) :
stencil(grid,shifts), code(_code.size()) {
const std::vector<Coordinate>& shifts,
const std::vector<cgpt_stencil_matrix_vector_code_t>& _code,
int _n_code_parallel_block_size) :
stencil(grid,shifts), code(_code.size()),
n_code_parallel_block_size(_n_code_parallel_block_size) {

ASSERT(_code.size() % n_code_parallel_block_size == 0);
n_code_parallel_blocks = (int)_code.size() / n_code_parallel_block_size;

// total number of factors
int nfactors = 0;
for (int i=0;i<_code.size();i++)
Expand Down Expand Up @@ -89,7 +96,7 @@ class cgpt_stencil_matrix_vector : public cgpt_stencil_matrix_vector_base {
VECTOR_VIEW_OPEN(vector_fields,fields_v_v,AcceleratorWrite);

int n_code = code.size();
cgpt_stencil_matrix_vector_code_offload_t* p_code = &code[0];
const cgpt_stencil_matrix_vector_code_offload_t* p_code = &code[0];

typedef decltype(coalescedRead(fields_v_v[0][0])) obj_v_t;
typedef decltype(coalescedRead(fields_m_v[0][0])) obj_m_t;
Expand All @@ -98,16 +105,21 @@ class cgpt_stencil_matrix_vector : public cgpt_stencil_matrix_vector_base {

auto sview = stencil.View();

accelerator_for(ss,matrix_fields[0].Grid()->oSites(),M::Nsimd(),{
accelerator_for(ss_block,matrix_fields[0].Grid()->oSites() * n_code_parallel_blocks,M::Nsimd(),{

auto ss = ss_block / n_code_parallel_blocks;
auto oblock = ss_block % n_code_parallel_blocks;

for (int iblock=0;iblock<n_code_parallel_block_size;iblock++) {

for (int i=0;i<n_code;i++) {
int i = oblock * n_code_parallel_block_size + iblock;
obj_v_t t;

fetch(t, p_code[i].source_point, ss, fields_v_v[p_code[i].source], 0);

for (int j=p_code[i].size-1;j>=0;j--) {
obj_m_t f;
auto _f = &p_code[i].factor[j];
const auto _f = &p_code[i].factor[j];
fetch(f, _f->point, ss, fields_m_v[_f->index], _f->adj);
t = f * t;
}
Expand Down Expand Up @@ -155,7 +167,8 @@ static void cgpt_convert(PyObject* in, cgpt_stencil_matrix_vector_code_t& out) {

template<typename V>
cgpt_stencil_matrix_vector_base*
cgpt_stencil_matrix_vector_create(cgpt_Lattice_base* __matrix, GridBase* grid, PyObject* _shifts, PyObject* _code) {
cgpt_stencil_matrix_vector_create(cgpt_Lattice_base* __matrix, GridBase* grid, PyObject* _shifts, PyObject* _code,
long code_parallel_block_size) {

std::vector<Coordinate> shifts;
cgpt_convert(_shifts,shifts);
Expand All @@ -166,13 +179,13 @@ cgpt_stencil_matrix_vector_create(cgpt_Lattice_base* __matrix, GridBase* grid, P
// test __matrix type against matrix in spin space,
// color space spin+color space, and singlet space
if (is_compatible<typename matrixFromTypeAtLevel<V,2>::type>(__matrix)) {
return new cgpt_stencil_matrix_vector<typename matrixFromTypeAtLevel<V,2>::type,V>(grid,shifts,code);
return new cgpt_stencil_matrix_vector<typename matrixFromTypeAtLevel<V,2>::type,V>(grid,shifts,code,code_parallel_block_size);
} else if (is_compatible<typename matrixFromTypeAtLevel<V,1>::type>(__matrix)) {
return new cgpt_stencil_matrix_vector<typename matrixFromTypeAtLevel<V,1>::type,V>(grid,shifts,code);
return new cgpt_stencil_matrix_vector<typename matrixFromTypeAtLevel<V,1>::type,V>(grid,shifts,code,code_parallel_block_size);
} else if (is_compatible<typename matrixFromTypeAtLevel<V,0>::type>(__matrix)) {
return new cgpt_stencil_matrix_vector<typename matrixFromTypeAtLevel<V,0>::type,V>(grid,shifts,code);
return new cgpt_stencil_matrix_vector<typename matrixFromTypeAtLevel<V,0>::type,V>(grid,shifts,code,code_parallel_block_size);
} else if (is_compatible<typename matrixFromTypeAtLevel2<V,1,2>::type>(__matrix)) {
return new cgpt_stencil_matrix_vector<typename matrixFromTypeAtLevel2<V,1,2>::type,V>(grid,shifts,code);
return new cgpt_stencil_matrix_vector<typename matrixFromTypeAtLevel2<V,1,2>::type,V>(grid,shifts,code,code_parallel_block_size);
} else {
ERR("Unknown matrix type for matrix_vector stencil with vector type %s",typeid(V).name());
}
Expand Down
9 changes: 7 additions & 2 deletions lib/gpt/core/stencil/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,15 @@ def parse(c):


class matrix:
def __init__(self, lat, points, code):
def __init__(self, lat, points, code, code_parallel_block_size=None):
self.points = points
self.code = [parse(c) for c in code]
self.obj = cgpt.stencil_matrix_create(lat.v_obj[0], lat.grid.obj, points, self.code)
self.code_parallel_block_size = code_parallel_block_size
if code_parallel_block_size is None:
code_parallel_block_size = len(code)
self.obj = cgpt.stencil_matrix_create(
lat.v_obj[0], lat.grid.obj, points, self.code, code_parallel_block_size
)

def __call__(self, *fields):
cgpt.stencil_matrix_execute(self.obj, list(fields))
Expand Down
23 changes: 19 additions & 4 deletions lib/gpt/core/stencil/matrix_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,33 @@

def parse(c):
if isinstance(c, tuple):
assert len(c) == 4
return {"target": c[0], "accumulate": c[1], "weight": c[2], "factor": c[3]}
assert len(c) == 6
return {
"target": c[0],
"source": c[1],
"source_point": c[2],
"accumulate": c[3],
"weight": c[4],
"factor": c[5],
}
return c


class matrix_vector:
def __init__(self, lat_matrix, lat_vector, points, code):
def __init__(self, lat_matrix, lat_vector, points, code, code_parallel_block_size=None):
self.points = points
self.code = [parse(c) for c in code]
self.code_parallel_block_size = code_parallel_block_size
if code_parallel_block_size is None:
code_parallel_block_size = len(code)
assert lat_matrix.grid == lat_vector.grid
self.obj = cgpt.stencil_matrix_vector_create(
lat_matrix.v_obj[0], lat_vector.v_obj[0], lat_matrix.grid.obj, points, self.code
lat_matrix.v_obj[0],
lat_vector.v_obj[0],
lat_matrix.grid.obj,
points,
self.code,
code_parallel_block_size,
)

def __call__(self, matrix_fields, vector_fields):
Expand Down
32 changes: 24 additions & 8 deletions lib/gpt/qcd/fermion/register.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,24 +11,40 @@ def register(reg, op):
reg.Mdiag = lambda dst, src: op.apply_unary_operator(2009, dst, src)
reg.Dminus = lambda dst, src: op.apply_unary_operator(2010, dst, src)
reg.DminusDag = lambda dst, src: op.apply_unary_operator(2011, dst, src)
reg.ImportPhysicalFermionSource = lambda dst, src: op.apply_unary_operator(2012, dst, src)
reg.ImportUnphysicalFermion = lambda dst, src: op.apply_unary_operator(2013, dst, src)
reg.ExportPhysicalFermionSolution = lambda dst, src: op.apply_unary_operator(2014, dst, src)
reg.ExportPhysicalFermionSource = lambda dst, src: op.apply_unary_operator(2015, dst, src)
reg.ImportPhysicalFermionSource = lambda dst, src: op.apply_unary_operator(
2012, dst, src
)
reg.ImportUnphysicalFermion = lambda dst, src: op.apply_unary_operator(
2013, dst, src
)
reg.ExportPhysicalFermionSolution = lambda dst, src: op.apply_unary_operator(
2014, dst, src
)
reg.ExportPhysicalFermionSource = lambda dst, src: op.apply_unary_operator(
2015, dst, src
)
reg.Dhop = lambda dst, src: op.apply_unary_operator(3001, dst, src)
reg.DhopDag = lambda dst, src: op.apply_unary_operator(4001, dst, src)
reg.DhopEO = lambda dst, src: op.apply_unary_operator(3002, dst, src)
reg.DhopEODag = lambda dst, src: op.apply_unary_operator(4002, dst, src)
reg.Mdir = lambda dst, src, dir, disp: op.apply_dirdisp_operator(5001, dst, src, dir, disp)
reg.Mdir = lambda dst, src, dir, disp: op.apply_dirdisp_operator(
5001, dst, src, dir, disp
)
reg.MDeriv = lambda mat, dst, src: op.apply_deriv_operator(6001, mat, dst, src)
reg.MDerivDag = lambda mat, dst, src: op.apply_deriv_operator(7001, mat, dst, src)
reg.MoeDeriv = lambda mat, dst, src: op.apply_deriv_operator(6002, mat, dst, src)
reg.MoeDerivDag = lambda mat, dst, src: op.apply_deriv_operator(7002, mat, dst, src)
reg.MeoDeriv = lambda mat, dst, src: op.apply_deriv_operator(6003, mat, dst, src)
reg.MeoDerivDag = lambda mat, dst, src: op.apply_deriv_operator(7003, mat, dst, src)
reg.DhopDeriv = lambda mat, dst, src: op.apply_deriv_operator(6004, mat, dst, src)
reg.DhopDerivDag = lambda mat, dst, src: op.apply_deriv_operator(7004, mat, dst, src)
reg.DhopDerivDag = lambda mat, dst, src: op.apply_deriv_operator(
7004, mat, dst, src
)
reg.DhopDerivEO = lambda mat, dst, src: op.apply_deriv_operator(6005, mat, dst, src)
reg.DhopDerivEODag = lambda mat, dst, src: op.apply_deriv_operator(7005, mat, dst, src)
reg.DhopDerivEODag = lambda mat, dst, src: op.apply_deriv_operator(
7005, mat, dst, src
)
reg.DhopDerivOE = lambda mat, dst, src: op.apply_deriv_operator(6006, mat, dst, src)
reg.DhopDerivOEDag = lambda mat, dst, src: op.apply_deriv_operator(7006, mat, dst, src)
reg.DhopDerivOEDag = lambda mat, dst, src: op.apply_deriv_operator(
7006, mat, dst, src
)

0 comments on commit 2a3bc02

Please sign in to comment.