Skip to content

Commit

Permalink
Optimize OCL kernels (note: make clean, cmake .., qrack_cl_precompile)
Browse files: browse the repository at this point in the history
  • Loading branch information
WrathfulSpatula committed Aug 27, 2024
1 parent d642a58 commit 0aef1ce
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 13 deletions.
21 changes: 9 additions & 12 deletions src/common/qengine.cl
Original file line number Diff line number Diff line change
Expand Up @@ -290,16 +290,15 @@ void kernel xmask(global cmplx* stateVec, constant bitCapIntOcl* bitCapIntOclPtr
}
}

void kernel phaseparity(global cmplx* stateVec, constant bitCapIntOcl* bitCapIntOclPtr, constant cmplx* cmplxPtr)
void kernel phaseparity(global cmplx* stateVec, constant bitCapIntOcl* bitCapIntOclPtr, constant cmplx2* cmplxPtr)
{
const bitCapIntOcl parityStartSize = 4U * sizeof(bitCapIntOcl);
const bitCapIntOcl Nthreads = get_global_size(0);
const bitCapIntOcl2 args = vload2(0, bitCapIntOclPtr);
const bitCapIntOcl maxI = args.x;
const bitCapIntOcl mask = args.y;
const bitCapIntOcl otherMask = bitCapIntOclPtr[2];
const cmplx phaseFac = cmplxPtr[0];
const cmplx iPhaseFac = cmplxPtr[1];
const cmplx2 phaseFac = cmplxPtr[0];

for (bitCapIntOcl lcv = ID; lcv < maxI; lcv += Nthreads) {
bitCapIntOcl setInt = lcv & mask;
Expand All @@ -312,7 +311,7 @@ void kernel phaseparity(global cmplx* stateVec, constant bitCapIntOcl* bitCapInt

setInt |= lcv & otherMask;

stateVec[setInt] = zmul(v ? phaseFac : iPhaseFac, stateVec[setInt]);
stateVec[setInt] = zmul(v ? phaseFac.lo : phaseFac.hi, stateVec[setInt]);
}
}

Expand Down Expand Up @@ -450,22 +449,21 @@ void kernel uniformlycontrolled(global cmplx* stateVec, constant bitCapIntOcl* b
}
}

void kernel uniformparityrz(global cmplx* stateVec, constant bitCapIntOcl2* bitCapIntOclPtr, constant cmplx* cmplx_ptr)
void kernel uniformparityrz(global cmplx* stateVec, constant bitCapIntOcl2* bitCapIntOclPtr, constant cmplx2* cmplx_ptr)
{
const bitCapIntOcl Nthreads = get_global_size(0);
const bitCapIntOcl2 args = bitCapIntOclPtr[0];
const bitCapIntOcl maxI = args.x;
const bitCapIntOcl qMask = args.y;
const cmplx phaseFac = cmplx_ptr[0];
const cmplx phaseFacAdj = cmplx_ptr[1];
const cmplx2 phaseFac = cmplx_ptr[0];
for (bitCapIntOcl lcv = ID; lcv < maxI; lcv += Nthreads) {
bitCapIntOcl perm = lcv & qMask;
bitLenInt c;
for (c = 0; perm; c++) {
// clear the least significant bit set
perm &= perm - ONE_BCI;
}
stateVec[lcv] = zmul(stateVec[lcv], ((c & 1U) ? phaseFac : phaseFacAdj));
stateVec[lcv] = zmul(stateVec[lcv], ((c & 1U) ? phaseFac.lo : phaseFac.hi));
}
}

Expand All @@ -490,16 +488,15 @@ void kernel uniformparityrznorm(global cmplx* stateVec, constant bitCapIntOcl2*
}
}

void kernel cuniformparityrz(global cmplx* stateVec, constant bitCapIntOcl4* bitCapIntOclPtr, constant cmplx* cmplx_ptr, constant bitCapIntOcl* qPowers)
void kernel cuniformparityrz(global cmplx* stateVec, constant bitCapIntOcl4* bitCapIntOclPtr, constant cmplx2* cmplx_ptr, constant bitCapIntOcl* qPowers)
{
const bitCapIntOcl Nthreads = get_global_size(0);
const bitCapIntOcl4 args = bitCapIntOclPtr[0];
const bitCapIntOcl maxI = args.x;
const bitCapIntOcl qMask = args.y;
const bitCapIntOcl cMask = args.z;
const bitLenInt cLen = (bitLenInt)args.w;
const cmplx phaseFac = cmplx_ptr[0];
const cmplx phaseFacAdj = cmplx_ptr[1];
const cmplx2 phaseFac = cmplx_ptr[0];

for (bitCapIntOcl lcv = ID; lcv < maxI; lcv += Nthreads) {
bitCapIntOcl iHigh = lcv;
Expand All @@ -517,7 +514,7 @@ void kernel cuniformparityrz(global cmplx* stateVec, constant bitCapIntOcl4* bit
// clear the least significant bit set
perm &= perm - ONE_BCI;
}
stateVec[i] = zmul(stateVec[i], ((c & 1U) ? phaseFac : phaseFacAdj));
stateVec[i] = zmul(stateVec[i], ((c & 1U) ? phaseFac.lo : phaseFac.hi));
}
}

Expand Down
3 changes: 2 additions & 1 deletion src/qengine/opencl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1405,7 +1405,8 @@ bitLenInt QEngineOCL::Compose(QEngineOCLPtr toCopy, bitLenInt start)
const bitCapIntOcl startMask = pow2Ocl(start) - 1U;
const bitCapIntOcl midMask = bitRegMaskOcl(start, oQubitCount);
const bitCapIntOcl endMask = pow2MaskOcl(qubitCount + oQubitCount) & ~(startMask | midMask);
const bitCapIntOcl bciArgs[BCI_ARG_LEN]{ nMaxQPower, oQubitCount, startMask, midMask, endMask, start, 0U, 0U, 0U, 0U };
const bitCapIntOcl bciArgs[BCI_ARG_LEN]{ nMaxQPower, oQubitCount, startMask, midMask, endMask, start, 0U, 0U, 0U,
0U };

Compose(OCL_API_COMPOSE_MID, bciArgs, toCopy);

Expand Down

0 comments on commit 0aef1ce

Please sign in to comment.