Skip to content

Commit

Permalink
Use a single scratchpad for all tensornet operations (#1865)
Browse files Browse the repository at this point in the history
  • Loading branch information
1tnguyen authored Jul 1, 2024
1 parent 215e229 commit cee02b1
Show file tree
Hide file tree
Showing 10 changed files with 107 additions and 73 deletions.
40 changes: 22 additions & 18 deletions runtime/nvqir/cutensornet/mps_simulation_state.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,10 @@ std::size_t MPSSimulationState::getNumQubits() const {

MPSSimulationState::MPSSimulationState(std::unique_ptr<TensorNetState> inState,
const std::vector<MPSTensor> &mpsTensors,
ScratchDeviceMem &inScratchPad,
cutensornetHandle_t cutnHandle)
: m_cutnHandle(cutnHandle), state(std::move(inState)),
m_mpsTensors(mpsTensors) {}
m_mpsTensors(mpsTensors), scratchPad(inScratchPad) {}

MPSSimulationState::~MPSSimulationState() { deallocate(); }

Expand Down Expand Up @@ -141,6 +142,7 @@ std::complex<double> MPSSimulationState::computeOverlap(
computeType, &m_tnDescr));

cutensornetContractionOptimizerConfig_t m_tnConfig;

// Determine the tensor network contraction path and create the contraction
// plan
HANDLE_CUTN_ERROR(
Expand All @@ -149,9 +151,9 @@ std::complex<double> MPSSimulationState::computeOverlap(
cutensornetContractionOptimizerInfo_t m_tnPath;
HANDLE_CUTN_ERROR(cutensornetCreateContractionOptimizerInfo(
cutnHandle, m_tnDescr, &m_tnPath));
assert(m_scratchPad.scratchSize > 0);
assert(scratchPad.scratchSize > 0);
HANDLE_CUTN_ERROR(cutensornetContractionOptimize(
cutnHandle, m_tnDescr, m_tnConfig, m_scratchPad.scratchSize, m_tnPath));
cutnHandle, m_tnDescr, m_tnConfig, scratchPad.scratchSize, m_tnPath));
cutensornetWorkspaceDescriptor_t workDesc;
HANDLE_CUTN_ERROR(
cutensornetCreateWorkspaceDescriptor(cutnHandle, &workDesc));
Expand All @@ -164,10 +166,10 @@ std::complex<double> MPSSimulationState::computeOverlap(
&requiredWorkspaceSize));
assert(requiredWorkspaceSize > 0);
assert(static_cast<std::size_t>(requiredWorkspaceSize) <=
m_scratchPad.scratchSize);
scratchPad.scratchSize);
HANDLE_CUTN_ERROR(cutensornetWorkspaceSetMemory(
cutnHandle, workDesc, CUTENSORNET_MEMSPACE_DEVICE,
CUTENSORNET_WORKSPACE_SCRATCH, m_scratchPad.d_scratch,
CUTENSORNET_WORKSPACE_SCRATCH, scratchPad.d_scratch,
requiredWorkspaceSize));
cutensornetContractionPlan_t m_tnPlan;
HANDLE_CUTN_ERROR(cutensornetCreateContractionPlan(
Expand Down Expand Up @@ -251,7 +253,8 @@ MPSSimulationState::getAmplitude(const std::vector<int> &basisState) {
}

if (getNumQubits() > 1) {
TensorNetState basisTensorNetState(basisState, state->getInternalContext());
TensorNetState basisTensorNetState(basisState, scratchPad,
state->getInternalContext());
// Note: this is a basis state, hence bond dim == 1
std::vector<MPSTensor> basisStateTensors =
basisTensorNetState.factorizeMPS(1, std::numeric_limits<double>::min(),
Expand Down Expand Up @@ -351,10 +354,9 @@ static Eigen::MatrixXcd reshapeStateVec(const Eigen::VectorXcd &stateVec) {
return reshapeMatrix(A);
}

MPSSimulationState::MpsStateData
MPSSimulationState::createFromStateVec(cutensornetHandle_t cutnHandle,
std::size_t size,
std::complex<double> *ptr, int bondDim) {
MPSSimulationState::MpsStateData MPSSimulationState::createFromStateVec(
cutensornetHandle_t cutnHandle, ScratchDeviceMem &inScratchPad,
std::size_t size, std::complex<double> *ptr, int bondDim) {
const std::size_t numQubits = std::log2(size);
// Reverse the qubit order to match cutensornet convention
auto newStateVec = TensorNetState::reverseQubitOrder(
Expand All @@ -371,8 +373,8 @@ MPSSimulationState::createFromStateVec(cutensornetHandle_t cutnHandle,
MPSTensor stateTensor;
stateTensor.deviceData = d_tensor;
stateTensor.extents = std::vector<int64_t>{2};
auto state =
TensorNetState::createFromMpsTensors({stateTensor}, cutnHandle);
auto state = TensorNetState::createFromMpsTensors({stateTensor},
inScratchPad, cutnHandle);
return {std::move(state), std::vector<MPSTensor>{stateTensor}};
}

Expand Down Expand Up @@ -444,7 +446,8 @@ MPSSimulationState::createFromStateVec(cutensornetHandle_t cutnHandle,
stateTensor.extents = std::vector<int64_t>{numSingularValues.back(), 2};
mpsTensors.emplace_back(stateTensor);
assert(mpsTensors.size() == numQubits);
auto state = TensorNetState::createFromMpsTensors(mpsTensors, cutnHandle);
auto state = TensorNetState::createFromMpsTensors(mpsTensors, inScratchPad,
cutnHandle);
return {std::move(state), mpsTensors};
}

Expand All @@ -470,15 +473,16 @@ MPSSimulationState::createFromSizeAndPtr(std::size_t size, void *ptr,
MPSTensor stateTensor{d_tensor, mpsExtents};
mpsTensors.emplace_back(stateTensor);
}
auto state = TensorNetState::createFromMpsTensors(mpsTensors, m_cutnHandle);
auto state = TensorNetState::createFromMpsTensors(mpsTensors, scratchPad,
m_cutnHandle);
return std::make_unique<MPSSimulationState>(std::move(state), mpsTensors,
m_cutnHandle);
scratchPad, m_cutnHandle);
}
auto [state, mpsTensors] = createFromStateVec(
m_cutnHandle, size, reinterpret_cast<std::complex<double> *>(ptr),
MPSSettings().maxBond);
m_cutnHandle, scratchPad, size,
reinterpret_cast<std::complex<double> *>(ptr), MPSSettings().maxBond);
return std::make_unique<MPSSimulationState>(std::move(state), mpsTensors,
m_cutnHandle);
scratchPad, m_cutnHandle);
}

MPSSettings::MPSSettings() {
Expand Down
4 changes: 3 additions & 1 deletion runtime/nvqir/cutensornet/mps_simulation_state.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ class MPSSimulationState : public cudaq::SimulationState {
public:
MPSSimulationState(std::unique_ptr<TensorNetState> inState,
const std::vector<MPSTensor> &mpsTensors,
ScratchDeviceMem &inScratchPad,
cutensornetHandle_t cutnHandle);

MPSSimulationState(const MPSSimulationState &) = delete;
Expand Down Expand Up @@ -78,6 +79,7 @@ class MPSSimulationState : public cudaq::SimulationState {
/// Util method to create an MPS state from an input state vector.
// For example, state vector from the user's input.
static MpsStateData createFromStateVec(cutensornetHandle_t cutnHandle,
ScratchDeviceMem &inScratchPad,
std::size_t size,
std::complex<double> *data,
int bondDim);
Expand All @@ -95,7 +97,7 @@ class MPSSimulationState : public cudaq::SimulationState {
cutensornetHandle_t m_cutnHandle;
std::unique_ptr<TensorNetState> state;
std::vector<MPSTensor> m_mpsTensors;
ScratchDeviceMem m_scratchPad;
ScratchDeviceMem &scratchPad;
// Max number of qubits whereby the tensor network state should be contracted
// and cached into a state vector.
// This speeds up sequential state amplitude accessors for small states.
Expand Down
3 changes: 2 additions & 1 deletion runtime/nvqir/cutensornet/simulator_cutensornet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,8 @@ void SimulatorTensorNetBase::setToZeroState() {
const auto numQubits = m_state->getNumQubits();
m_state.reset();
// Re-create a zero state of the same size
m_state = std::make_unique<TensorNetState>(numQubits, m_cutnHandle);
m_state =
std::make_unique<TensorNetState>(numQubits, scratchPad, m_cutnHandle);
}

void SimulatorTensorNetBase::swap(const std::vector<std::size_t> &ctrlBits,
Expand Down
1 change: 1 addition & 0 deletions runtime/nvqir/cutensornet/simulator_cutensornet.h
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ class SimulatorTensorNetBase : public nvqir::CircuitSimulatorBase<double> {
cutensornetHandle_t m_cutnHandle;
std::unique_ptr<TensorNetState> m_state;
std::unordered_map<std::string, void *> m_gateDeviceMemCache;
ScratchDeviceMem scratchPad;
};

} // end namespace nvqir
28 changes: 17 additions & 11 deletions runtime/nvqir/cutensornet/simulator_mps_register.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ class SimulatorMPS : public SimulatorTensorNetBase {
"[SimulatorMPS simulator] Incompatible state input");
if (!m_state) {
m_state = TensorNetState::createFromMpsTensors(casted->getMpsTensors(),
m_cutnHandle);
scratchPad, m_cutnHandle);
} else {
// Expand an existing state: Append MPS tensors
// Factor the existing state
Expand All @@ -72,7 +72,8 @@ class SimulatorMPS : public SimulatorTensorNetBase {
tensorSizeBytes, cudaMemcpyDefault));
tensors.emplace_back(MPSTensor(mpsTensor, extents));
}
m_state = TensorNetState::createFromMpsTensors(tensors, m_cutnHandle);
m_state = TensorNetState::createFromMpsTensors(tensors, scratchPad,
m_cutnHandle);
}
}

Expand Down Expand Up @@ -159,10 +160,11 @@ class SimulatorMPS : public SimulatorTensorNetBase {
LOG_API_TIME();
if (!m_state) {
if (!ptr) {
m_state = std::make_unique<TensorNetState>(numQubits, m_cutnHandle);
m_state = std::make_unique<TensorNetState>(numQubits, scratchPad,
m_cutnHandle);
} else {
auto [state, mpsTensors] = MPSSimulationState::createFromStateVec(
m_cutnHandle, 1ULL << numQubits,
m_cutnHandle, scratchPad, 1ULL << numQubits,
reinterpret_cast<std::complex<double> *>(const_cast<void *>(ptr)),
m_settings.maxBond);
m_state = std::move(state);
Expand All @@ -188,11 +190,12 @@ class SimulatorMPS : public SimulatorTensorNetBase {
cudaMemcpyHostToDevice));
tensors.emplace_back(MPSTensor(mpsTensor, extents));
}
m_state = TensorNetState::createFromMpsTensors(tensors, m_cutnHandle);
m_state = TensorNetState::createFromMpsTensors(tensors, scratchPad,
m_cutnHandle);
} else {
// Non-zero state needs to be factorized and appended.
auto [state, mpsTensors] = MPSSimulationState::createFromStateVec(
m_cutnHandle, 1ULL << numQubits,
m_cutnHandle, scratchPad, 1ULL << numQubits,
reinterpret_cast<std::complex<double> *>(const_cast<void *>(ptr)),
m_settings.maxBond);
auto tensors = m_state->factorizeMPS(
Expand All @@ -206,7 +209,8 @@ class SimulatorMPS : public SimulatorTensorNetBase {
mpsTensors.front().extents = extents;
// Combine the list
tensors.insert(tensors.end(), mpsTensors.begin(), mpsTensors.end());
m_state = TensorNetState::createFromMpsTensors(tensors, m_cutnHandle);
m_state = TensorNetState::createFromMpsTensors(tensors, scratchPad,
m_cutnHandle);
}
}
}
Expand All @@ -215,14 +219,15 @@ class SimulatorMPS : public SimulatorTensorNetBase {
LOG_API_TIME();

if (!m_state || m_state->getNumQubits() == 0)
return std::make_unique<MPSSimulationState>(
std::move(m_state), std::vector<MPSTensor>{}, m_cutnHandle);
return std::make_unique<MPSSimulationState>(std::move(m_state),
std::vector<MPSTensor>{},
scratchPad, m_cutnHandle);

if (m_state->getNumQubits() > 1) {
std::vector<MPSTensor> tensors = m_state->factorizeMPS(
m_settings.maxBond, m_settings.absCutoff, m_settings.relCutoff);
return std::make_unique<MPSSimulationState>(std::move(m_state), tensors,
m_cutnHandle);
scratchPad, m_cutnHandle);
}

auto [d_tensor, numElements] = m_state->contractStateVectorInternal({});
Expand All @@ -232,7 +237,8 @@ class SimulatorMPS : public SimulatorTensorNetBase {
stateTensor.extents = {static_cast<int64_t>(numElements)};

return std::make_unique<MPSSimulationState>(
std::move(m_state), std::vector<MPSTensor>{stateTensor}, m_cutnHandle);
std::move(m_state), std::vector<MPSTensor>{stateTensor}, scratchPad,
m_cutnHandle);
}

virtual ~SimulatorMPS() noexcept {
Expand Down
13 changes: 8 additions & 5 deletions runtime/nvqir/cutensornet/simulator_tensornet_register.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,19 +48,21 @@ class SimulatorTensorNet : public SimulatorTensorNetBase {
std::unique_ptr<cudaq::SimulationState> getSimulationState() override {
LOG_API_TIME();
return std::make_unique<TensorNetSimulationState>(std::move(m_state),
m_cutnHandle);
scratchPad, m_cutnHandle);
}

void addQubitsToState(std::size_t numQubits, const void *ptr) override {
LOG_API_TIME();
if (!m_state) {
if (!ptr) {
m_state = std::make_unique<TensorNetState>(numQubits, m_cutnHandle);
m_state = std::make_unique<TensorNetState>(numQubits, scratchPad,
m_cutnHandle);
} else {
auto *casted =
reinterpret_cast<std::complex<double> *>(const_cast<void *>(ptr));
std::span<std::complex<double>> stateVec(casted, 1ULL << numQubits);
m_state = TensorNetState::createFromStateVector(stateVec, m_cutnHandle);
m_state = TensorNetState::createFromStateVector(stateVec, scratchPad,
m_cutnHandle);
}
} else {
if (!ptr) {
Expand All @@ -83,8 +85,9 @@ class SimulatorTensorNet : public SimulatorTensorNetBase {
throw std::invalid_argument(
"[Tensornet simulator] Incompatible state input");
if (!m_state) {
m_state = TensorNetState::createFromOpTensors(
in_state.getNumQubits(), casted->getAppliedTensors(), m_cutnHandle);
m_state = TensorNetState::createFromOpTensors(in_state.getNumQubits(),
casted->getAppliedTensors(),
scratchPad, m_cutnHandle);
} else {
// Expand an existing state:
// (1) Create a blank tensor network with combined number of qubits
Expand Down
Loading

0 comments on commit cee02b1

Please sign in to comment.