diff --git a/include/common/oclapi.hpp b/include/common/oclapi.hpp index 83e494dfc..763fa0d13 100644 --- a/include/common/oclapi.hpp +++ b/include/common/oclapi.hpp @@ -56,6 +56,7 @@ enum OCLAPI { OCL_API_Z_SINGLE, OCL_API_Z_SINGLE_WIDE, OCL_API_PHASE_PARITY, + OCL_API_PHASE_MASK, OCL_API_ROL, OCL_API_APPROXCOMPARE, OCL_API_NORMALIZE, diff --git a/include/qengine_cpu.hpp b/include/qengine_cpu.hpp index a122628af..2cf422ac6 100644 --- a/include/qengine_cpu.hpp +++ b/include/qengine_cpu.hpp @@ -143,6 +143,7 @@ class QEngineCPU : public QEngine { void XMask(bitCapInt mask); void PhaseParity(real1_f radians, bitCapInt mask); + void PhaseRootNMask(bitLenInt n, bitCapInt mask); /** * \defgroup ArithGate Arithmetic and other opcode-like gate implemenations. diff --git a/include/qengine_opencl.hpp b/include/qengine_opencl.hpp index 0cdef5104..9b1009fe5 100644 --- a/include/qengine_opencl.hpp +++ b/include/qengine_opencl.hpp @@ -379,6 +379,7 @@ class QEngineOCL : public QEngine { void Phase(complex topLeft, complex bottomRight, bitLenInt qubitIndex); void XMask(bitCapInt mask); void PhaseParity(real1_f radians, bitCapInt mask); + void PhaseRootNMask(bitLenInt n, bitCapInt mask); using QEngine::Compose; bitLenInt Compose(QEngineOCLPtr toCopy); diff --git a/include/qinterface.hpp b/include/qinterface.hpp index fd2463e4f..53e71ad81 100644 --- a/include/qinterface.hpp +++ b/include/qinterface.hpp @@ -1037,6 +1037,13 @@ class QInterface : public ParallelFor { Phase(ONE_CMPLX, pow(-ONE_CMPLX, (real1)(ONE_R1 / pow2Ocl(n - 1U))), qubit); } + /** + * Masked PhaseRootN gate + * + * Applies a 1/(2^N) phase rotation to all qubits in the mask. 
/* Unit-magnitude complex value whose argument is `theta`: (cos(theta), sin(theta)). */
inline cmplx polar_unit(real1 theta)
{
    return (cmplx)(cos(theta), sin(theta));
}

/**
 * Multiply each amplitude by a phase chosen by how many set bits its index shares with a mask.
 *
 * bitCapIntOclPtr[0]: exclusive upper bound on the amplitude indices visited
 * bitCapIntOclPtr[1]: bit mask applied to each index before its set bits are counted
 * bitCapIntOclPtr[2]: modulus applied to the set-bit count
 * angle[0]:           angle of a single phase step
 *
 * Every work item starts at ID and advances by the global work size.
 */
void kernel phasemask(global cmplx* stateVec, constant bitCapIntOcl* bitCapIntOclPtr, constant real1* angle)
{
    const bitCapIntOcl stride = get_global_size(0);
    const bitCapIntOcl indexEnd = bitCapIntOclPtr[0];
    const bitCapIntOcl qubitMask = bitCapIntOclPtr[1];
    const bitCapIntOcl phaseCount = bitCapIntOclPtr[2];
    const real1 stepAngle = angle[0];

    for (bitCapIntOcl i = ID; i < indexEnd; i += stride) {
        // Count the set bits of the masked index; each pass clears the lowest bit that is still set.
        bitCapIntOcl setBits = 0;
        for (bitCapIntOcl rest = i & qubitMask; rest; rest &= rest - 1) {
            ++setBits;
        }

        const bitCapIntOcl steps = setBits % phaseCount;
        if (steps == 0) {
            // Zero steps would multiply by (1, 0), so the amplitude is left untouched.
            continue;
        }

        stateVec[i] = zmul(polar_unit(steps * stepAngle), stateVec[i]);
    }
}
stateVec, constant bitCapIntOcl* bitCapIntOclPtr) { const bitCapIntOcl Nthreads = get_global_size(0); diff --git a/src/qengine/opencl.cpp b/src/qengine/opencl.cpp index 8fbaef8c9..ccffca060 100644 --- a/src/qengine/opencl.cpp +++ b/src/qengine/opencl.cpp @@ -759,6 +759,50 @@ void QEngineOCL::PhaseParity(real1_f radians, bitCapInt mask) BitMask((bitCapIntOcl)mask, OCL_API_PHASE_PARITY, radians); } +void QEngineOCL::PhaseRootNMask(bitLenInt n, bitCapInt mask) +{ + if (bi_compare_0(mask) == 0) { + return; + } + + const bitCapIntOcl oclMask = (bitCapIntOcl)mask; + if (oclMask >= maxQPowerOcl) { + throw std::invalid_argument("QEngineOCL::BitMask mask out-of-bounds!"); + } + + CHECK_ZERO_SKIP(); + + const bitCapIntOcl nPhases = pow2Ocl(n); + const real1_f radians[1] = { -PI_R1 / pow2Ocl(n - 1U) }; + + if (isPowerOfTwo(mask)) { + const complex phaseFac = std::polar(ONE_R1, radians[0]); + Phase(ONE_CMPLX, phaseFac, log2(mask)); + return; + } + + const bitCapIntOcl bciArgs[BCI_ARG_LEN]{ maxQPowerOcl, oclMask, nPhases, 0U, 0U, 0U, 0U, 0U, 0U, 0U }; + PoolItemPtr poolItem = GetFreePoolItem(); + + { + EventVecPtr waitVec = ResetWaitEvents(); + + cl::Event writeIntArgsEvent; + DISPATCH_TEMP_WRITE(waitVec, *(poolItem->ulongBuffer), sizeof(bitCapIntOcl) * 3, bciArgs, writeIntArgsEvent); + + cl::Event writeRealArgsEvent; + DISPATCH_LOC_WRITE(*(poolItem->realBuffer), sizeof(real1), radians, writeRealArgsEvent); + + writeIntArgsEvent.wait(); + writeRealArgsEvent.wait(); + wait_refs.clear(); + } + + const size_t ngc = FixWorkItemCount(bciArgs[0], nrmGroupCount); + const size_t ngs = FixGroupSize(ngc, nrmGroupSize); + QueueCall(OCL_API_PHASE_MASK, ngc, ngs, { stateBuffer, poolItem->ulongBuffer, poolItem->realBuffer }); +} + void QEngineOCL::Apply2x2(bitCapIntOcl offset1, bitCapIntOcl offset2, const complex* mtrx, bitLenInt bitCount, const bitCapIntOcl* qPowersSorted, bool doCalcNorm, SPECIAL_2X2 special, real1_f norm_thresh) { diff --git a/src/qengine/state.cpp b/src/qengine/state.cpp 
index a057e3472..6490e0b87 100644 --- a/src/qengine/state.cpp +++ b/src/qengine/state.cpp @@ -773,6 +773,66 @@ void QEngineCPU::PhaseParity(real1_f radians, bitCapInt mask) }); } +void QEngineCPU::PhaseRootNMask(bitLenInt n, bitCapInt mask) +{ + if (bi_compare(mask, maxQPower) >= 0) { + throw std::invalid_argument("QEngineCPU::PhaseRootNMask mask out-of-bounds!"); + } + if (n > sizeof(bitCapIntOcl)) { + throw std::invalid_argument("QEngineCPU::PhaseRootNMask: power of 2 out-of-bounds"); + } + + CHECK_ZERO_SKIP(); + + if (bi_compare_0(mask) == 0) { + return; + } + + if (n == 0) { + return; + } + if (n == 1) { + ZMask(mask); + return; + } + + const real1_f radians = -PI_R1 / pow2Ocl(n - 1U); + + if (isPowerOfTwo(mask)) { + const complex phaseFac = std::polar(ONE_R1, radians); + Phase(ONE_CMPLX, phaseFac, log2(mask)); + return; + } + + if (stateVec->is_sparse()) { + QInterface::PhaseRootNMask(n, mask); + return; + } + + Dispatch(maxQPowerOcl, [this, n, mask, radians] { + const bitCapIntOcl maskOcl = (bitCapIntOcl)mask; + const bitCapIntOcl nPhases = pow2Ocl(n); + ParallelFunc fn = [&](const bitCapIntOcl& lcv, const unsigned& cpu) { + bitCapIntOcl popCount = 0; + { + bitCapIntOcl v = lcv & maskOcl; + while (v) { + popCount += v & 1; + v >>= 1; + } + } + + const bitCapIntOcl nPhaseSteps = popCount % nPhases; + if (nPhaseSteps != 0) { + const complex phaseFac = std::polar(ONE_R1, radians * nPhaseSteps); + stateVec->write(lcv, phaseFac * stateVec->read(lcv)); + } + }; + + par_for(0U, maxQPowerOcl, fn); + }); +} + void QEngineCPU::UniformlyControlledSingleBit(const std::vector& controls, bitLenInt qubitIndex, const complex* mtrxs, const std::vector& mtrxSkipPowers, bitCapInt mtrxSkipValueMask) { diff --git a/src/qinterface/gates.cpp b/src/qinterface/gates.cpp index 0cd7e9f69..f5ff8e0dc 100644 --- a/src/qinterface/gates.cpp +++ b/src/qinterface/gates.cpp @@ -150,6 +150,16 @@ void QInterface::ZMask(bitCapInt mask) } } +void QInterface::PhaseRootNMask(bitLenInt n, bitCapInt 
mask) +{ + bitCapInt v = mask; + while (bi_compare_0(mask) != 0) { + v = v & (v - ONE_BCI); + PhaseRootN(n, log2(mask ^ v)); + mask = v; + } +} + void QInterface::Swap(bitLenInt q1, bitLenInt q2) { if (q1 == q2) { diff --git a/test/tests.cpp b/test/tests.cpp index 5aa147281..52fbe3c38 100644 --- a/test/tests.cpp +++ b/test/tests.cpp @@ -1229,6 +1229,30 @@ TEST_CASE_METHOD(QInterfaceTestFixture, "test_zmask") REQUIRE_THAT(qftReg, HasProbability(0, 20, 0x80001)); } +TEST_CASE_METHOD(QInterfaceTestFixture, "test_phaserootnmask") +{ + constexpr BIG_INTEGER_WORD ket = 14062; + constexpr BIG_INTEGER_WORD masks[6] = { 8, 3097, 22225, 16051, 62894, 49134 }; + constexpr uint16_t n = 3; + const uint16_t modulus = pow2Ocl(n); + // phaseCounts[ii] = popcount(ket & masks[ii]) + constexpr uint16_t phaseCounts[6] = { 1, 2, 5, 7, 8, 10 }; + + qftReg->SetPermutation(ket); + REQUIRE_THAT(qftReg, HasProbability(0, 20, ket)); + + for (int ii = 0; ii < 6; ii++) { + const real1_f angle = -PI_R1 * (phaseCounts[ii] % modulus) / pow2Ocl(n - 1U); + const complex expectedPhaseFactor = std::polar(ONE_R1, angle); + const complex amp_before = qftReg->GetAmplitude(ket); + + qftReg->PhaseRootNMask(n, masks[ii]); + const complex amp_after = qftReg->GetAmplitude(ket); + + REQUIRE_CMPLX(amp_after / amp_before, expectedPhaseFactor); + } +} + TEST_CASE_METHOD(QInterfaceTestFixture, "test_approxcompare") { qftReg =