Skip to content

Commit

Permalink
Update ClassicalSimulator to confirm to simulation abstraction (#6432)
Browse files Browse the repository at this point in the history
  • Loading branch information
shef4 authored and jselig-rigetti committed May 28, 2024
1 parent 2f92166 commit 0b93e27
Show file tree
Hide file tree
Showing 2 changed files with 308 additions and 87 deletions.
298 changes: 211 additions & 87 deletions cirq-core/cirq/sim/classical_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,96 +12,220 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict
from collections import defaultdict
from cirq.sim.simulator import SimulatesSamples
from cirq import ops, protocols
from cirq.study.resolver import ParamResolver
from cirq.circuits.circuit import AbstractCircuit
from cirq.ops.raw_types import Qid

from typing import Dict, Generic, Any, Sequence, List, Optional, Union, TYPE_CHECKING
from copy import deepcopy, copy
from cirq import ops, qis
from cirq.value import big_endian_int_to_bits
from cirq import sim
from cirq.sim.simulation_state import TSimulationState, SimulationState
import numpy as np

if TYPE_CHECKING:
import cirq


def _is_identity(op: ops.Operation) -> bool:
if isinstance(op.gate, (ops.XPowGate, ops.CXPowGate, ops.CCXPowGate, ops.SwapPowGate)):
return op.gate.exponent % 2 == 0
def _is_identity(action) -> bool:
"""Check if the given action is equivalent to an identity."""
gate = action.gate if isinstance(action, ops.Operation) else action
if isinstance(gate, (ops.XPowGate, ops.CXPowGate, ops.CCXPowGate, ops.SwapPowGate)):
return gate.exponent % 2 == 0
return False


class ClassicalStateSimulator(SimulatesSamples):
"""A simulator that accepts only gates with classical counterparts.
This simulator evolves a single state, using only gates that output a single state for each
input state. The simulator runs in linear time, at the cost of not supporting superposition.
It can be used to estimate costs and simulate circuits for simple non-quantum algorithms using
many more qubits than fully capable quantum simulators.
The supported gates are:
- cirq.X
- cirq.CNOT
- cirq.SWAP
- cirq.TOFFOLI
- cirq.measure
Args:
circuit: The circuit to simulate.
param_resolver: Parameters to run with the program.
repetitions: Number of times to repeat the run. It is expected that
this is validated greater than zero before calling this method.
Returns:
A dictionary mapping measurement keys to measurement results.
Raises:
ValueError: If
- one of the gates is not an X, CNOT, SWAP, TOFFOLI or a measurement.
- A measurement key is used for measurements on different numbers of qubits.
"""

def _run(
self, circuit: AbstractCircuit, param_resolver: ParamResolver, repetitions: int
) -> Dict[str, np.ndarray]:
results_dict: Dict[str, np.ndarray] = {}
values_dict: Dict[Qid, int] = defaultdict(int)
param_resolver = param_resolver or ParamResolver({})
resolved_circuit = protocols.resolve_parameters(circuit, param_resolver)

for moment in resolved_circuit:
for op in moment:
if _is_identity(op):
continue
if op.gate == ops.X:
(q,) = op.qubits
values_dict[q] ^= 1
elif op.gate == ops.CNOT:
c, q = op.qubits
values_dict[q] ^= values_dict[c]
elif op.gate == ops.SWAP:
a, b = op.qubits
values_dict[a], values_dict[b] = values_dict[b], values_dict[a]
elif op.gate == ops.TOFFOLI:
c1, c2, q = op.qubits
values_dict[q] ^= values_dict[c1] & values_dict[c2]
elif protocols.is_measurement(op):
measurement_values = np.array(
[[[values_dict[q] for q in op.qubits]]] * repetitions, dtype=np.uint8
)
key = op.gate.key # type: ignore
if key in results_dict:
if op._num_qubits_() != results_dict[key].shape[-1]:
raise ValueError(
f'Measurement shape {len(measurement_values)} does not match '
f'{results_dict[key].shape[-1]} in {key}.'
)
results_dict[key] = np.concatenate(
(results_dict[key], measurement_values), axis=1
)
else:
results_dict[key] = measurement_values
else:
raise ValueError(
f'{op} is not one of cirq.X, cirq.CNOT, cirq.SWAP, '
'cirq.CCNOT, or a measurement'
)

return results_dict
class ClassicalBasisState(qis.QuantumStateRepresentation):
"""Represents a classical basis state for efficient state evolution."""

def __init__(self, initial_state: Union[List[int], np.ndarray]):
"""Initializes the ClassicalBasisState object.
Args:
initial_state: The initial state in the computational basis.
"""
self.basis = initial_state

def copy(self, deep_copy_buffers: bool = True) -> 'ClassicalBasisState':
"""Creates a copy of the ClassicalBasisState object.
Args:
deep_copy_buffers: Whether to deep copy the internal buffers.
Returns:
A copy of the ClassicalBasisState object.
"""
return ClassicalBasisState(
initial_state=deepcopy(self.basis) if deep_copy_buffers else copy(self.basis)
)

def measure(
self, axes: Sequence[int], seed: 'cirq.RANDOM_STATE_OR_SEED_LIKE' = None
) -> List[int]:
"""Measures the density matrix.
Args:
axes: The axes to measure.
seed: The random number seed to use.
Returns:
The measurements in order.
"""
return [self.basis[i] for i in axes]


class ClassicalBasisSimState(SimulationState[ClassicalBasisState]):
"""Represents the state of a quantum simulation using classical basis states."""

def __init__(
self,
initial_state: Union[int, List[int]] = 0,
qubits: Optional[Sequence['cirq.Qid']] = None,
classical_data: Optional['cirq.ClassicalDataStore'] = None,
):
"""Initializes the ClassicalBasisSimState object.
Args:
qubits: The qubits to simulate.
initial_state: The initial state for the simulation.
classical_data: The classical data container for the simulation.
Raises:
ValueError: If qubits not provided and initial_state is int.
If initial_state is not an int, List[int], or np.ndarray.
An initial_state value of type integer is parsed in big endian order.
"""
if isinstance(initial_state, int):
if qubits is None:
raise ValueError('qubits must be provided if initial_state is not List[int]')
state = ClassicalBasisState(
big_endian_int_to_bits(initial_state, bit_count=len(qubits))
)
elif isinstance(initial_state, (list, np.ndarray)):
state = ClassicalBasisState(initial_state)
else:
raise ValueError('initial_state must be an int or List[int] or np.ndarray')
super().__init__(state=state, qubits=qubits, classical_data=classical_data)

def _act_on_fallback_(self, action, qubits: Sequence['cirq.Qid'], allow_decompose: bool = True):
"""Acts on the state with a given operation.
Args:
action: The operation to apply.
qubits: The qubits to apply the operation to.
allow_decompose: Whether to allow decomposition of the operation.
Returns:
True if the operation was applied successfully.
Raises:
ValueError: If initial_state shape for type np.ndarray is not equal to 1.
If gate is not one of X, CNOT, SWAP, CCNOT, or a measurement.
"""
if isinstance(self._state.basis, np.ndarray) and len(self._state.basis.shape) != 1:
raise ValueError('initial_state shape for type np.ndarray is not equal to 1')
gate = action.gate if isinstance(action, ops.Operation) else action
mapped_qubits = [self.qubit_map[i] for i in qubits]
if _is_identity(gate):
pass
elif gate == ops.X:
(q,) = mapped_qubits
self._state.basis[q] ^= 1
elif gate == ops.CNOT:
c, q = mapped_qubits
self._state.basis[q] ^= self._state.basis[c]
elif gate == ops.SWAP:
a, b = mapped_qubits
self._state.basis[a], self._state.basis[b] = self._state.basis[b], self._state.basis[a]
elif gate == ops.TOFFOLI:
c1, c2, q = mapped_qubits
self._state.basis[q] ^= self._state.basis[c1] & self._state.basis[c2]
else:
raise ValueError(f'{gate} is not one of X, CNOT, SWAP, CCNOT, or a measurement')
return True


class ClassicalStateStepResult(
sim.StepResultBase['ClassicalBasisSimState'], Generic[TSimulationState]
):
"""The step result provided by `ClassicalStateSimulator.simulate_moment_steps`."""


class ClassicalStateTrialResult(
sim.SimulationTrialResultBase['ClassicalBasisSimState'], Generic[TSimulationState]
):
"""The trial result provided by `ClassicalStateSimulator.simulate`."""


class ClassicalStateSimulator(
sim.SimulatorBase[
ClassicalStateStepResult['ClassicalBasisSimState'],
ClassicalStateTrialResult['ClassicalBasisSimState'],
'ClassicalBasisSimState',
],
Generic[TSimulationState],
):
"""A simulator that accepts only gates with classical counterparts."""

def __init__(
self, *, noise: 'cirq.NOISE_MODEL_LIKE' = None, split_untangled_states: bool = False
):
"""Initializes a ClassicalStateSimulator.
Args:
noise: The noise model used by the simulator.
split_untangled_states: Whether to run the simulation as a product state.
Raises:
ValueError: If noise_model is not None.
"""
if noise is not None:
raise ValueError(f'{noise=} is not supported')
super().__init__(noise=noise, split_untangled_states=split_untangled_states)

def _create_simulator_trial_result(
self,
params: 'cirq.ParamResolver',
measurements: Dict[str, np.ndarray],
final_simulator_state: 'cirq.SimulationStateBase[ClassicalBasisSimState]',
) -> 'ClassicalStateTrialResult[ClassicalBasisSimState]':
"""Creates a trial result for the simulator.
Args:
params: The parameter resolver for the simulation.
measurements: The measurement results.
final_simulator_state: The final state of the simulator.
Returns:
A trial result for the simulator.
"""
return ClassicalStateTrialResult(
params, measurements, final_simulator_state=final_simulator_state
)

def _create_step_result(
self, sim_state: 'cirq.SimulationStateBase[ClassicalBasisSimState]'
) -> 'ClassicalStateStepResult[ClassicalBasisSimState]':
"""Creates a step result for the simulator.
Args:
sim_state: The current state of the simulator.
Returns:
A step result for the simulator.
"""
return ClassicalStateStepResult(sim_state)

def _create_partial_simulation_state(
self,
initial_state: Any,
qubits: Sequence['cirq.Qid'],
classical_data: 'cirq.ClassicalDataStore',
) -> 'ClassicalBasisSimState':
"""Creates a partial simulation state for the simulator.
Args:
initial_state: The initial state for the simulation.
qubits: The qubits associated with the state.
classical_data: The shared classical data container for this simulation.
Returns:
A partial simulation state.
"""
return ClassicalBasisSimState(
initial_state=initial_state, qubits=qubits, classical_data=classical_data
)
Loading

0 comments on commit 0b93e27

Please sign in to comment.