diff --git a/cirq-core/cirq/transformers/gauge_compiling/gauge_compiling.py b/cirq-core/cirq/transformers/gauge_compiling/gauge_compiling.py index 4341e82bc10..17c3c620a08 100644 --- a/cirq-core/cirq/transformers/gauge_compiling/gauge_compiling.py +++ b/cirq-core/cirq/transformers/gauge_compiling/gauge_compiling.py @@ -71,6 +71,7 @@ class ConstantGauge(Gauge): post_q1: Tuple[ops.Gate, ...] = field( default=(), converter=lambda g: (g,) if isinstance(g, ops.Gate) else tuple(g) ) + swap_qubits: bool = False def sample(self, gate: ops.Gate, prng: np.random.Generator) -> "ConstantGauge": return self @@ -85,6 +86,41 @@ def post(self) -> Tuple[Tuple[ops.Gate, ...], Tuple[ops.Gate, ...]]: """A tuple (ops to apply to q0, ops to apply to q1).""" return self.post_q0, self.post_q1 + def on(self, q0: ops.Qid, q1: ops.Qid) -> ops.Operation: + """Returns the operation that replaces the two qubit gate.""" + if self.swap_qubits: + return self.two_qubit_gate(q1, q0) + return self.two_qubit_gate(q0, q1) + + +@frozen +class SameGateGauge(Gauge): + """Same as ConstantGauge but the new two-qubit gate equals the old gate.""" + + pre_q0: Tuple[ops.Gate, ...] = field( + default=(), converter=lambda g: (g,) if isinstance(g, ops.Gate) else tuple(g) + ) + pre_q1: Tuple[ops.Gate, ...] = field( + default=(), converter=lambda g: (g,) if isinstance(g, ops.Gate) else tuple(g) + ) + post_q0: Tuple[ops.Gate, ...] = field( + default=(), converter=lambda g: (g,) if isinstance(g, ops.Gate) else tuple(g) + ) + post_q1: Tuple[ops.Gate, ...] = field( + default=(), converter=lambda g: (g,) if isinstance(g, ops.Gate) else tuple(g) + ) + swap_qubits: bool = False + + def sample(self, gate: ops.Gate, prng: np.random.Generator) -> ConstantGauge: + return ConstantGauge( + two_qubit_gate=gate, + pre_q0=self.pre_q0, + pre_q1=self.pre_q1, + post_q0=self.post_q0, + post_q1=self.post_q1, + swap_qubits=self.swap_qubits, + ) + def _select(choices: Sequence[Gauge], probabilites: np.ndarray, prng: np.random.Generator) -> Gauge: return choices[prng.choice(len(choices), p=probabilites)] @@ -154,7 +190,7 @@ def __call__( gauge = self.gauge_selector(rng).sample(op.gate, rng) q0, q1 = op.qubits left.extend([g(q) for g in gs] for q, gs in zip(op.qubits, gauge.pre)) - center.append(gauge.two_qubit_gate(q0, q1)) + center.append(gauge.on(q0, q1)) right.extend([g(q) for g in gs] for q, gs in zip(op.qubits, gauge.post)) else: center.append(op) diff --git a/cirq-core/cirq/transformers/gauge_compiling/spin_inversion_gauge.py b/cirq-core/cirq/transformers/gauge_compiling/spin_inversion_gauge.py index 96643eb3a90..88713ebba80 100644 --- a/cirq-core/cirq/transformers/gauge_compiling/spin_inversion_gauge.py +++ b/cirq-core/cirq/transformers/gauge_compiling/spin_inversion_gauge.py @@ -17,17 +17,17 @@ from cirq.transformers.gauge_compiling.gauge_compiling import ( GaugeTransformer, GaugeSelector, - ConstantGauge, + SameGateGauge, ) from cirq import ops SpinInversionGaugeSelector = GaugeSelector( gauges=[ - ConstantGauge(two_qubit_gate=ops.ZZ, pre_q0=ops.X, post_q0=ops.X), - ConstantGauge(two_qubit_gate=ops.ZZ, pre_q1=ops.X, post_q1=ops.X), + SameGateGauge(pre_q0=ops.X, post_q0=ops.X, pre_q1=ops.X, post_q1=ops.X), + SameGateGauge(), ] ) SpinInversionGaugeTransformer = GaugeTransformer( - target=ops.ZZ, gauge_selector=SpinInversionGaugeSelector + target=ops.GateFamily(ops.ZZPowGate), gauge_selector=SpinInversionGaugeSelector ) diff --git a/cirq-core/cirq/transformers/gauge_compiling/spin_inversion_gauge_test.py b/cirq-core/cirq/transformers/gauge_compiling/spin_inversion_gauge_test.py index 6e38ff451f1..c599b328316 100644 --- a/cirq-core/cirq/transformers/gauge_compiling/spin_inversion_gauge_test.py +++ b/cirq-core/cirq/transformers/gauge_compiling/spin_inversion_gauge_test.py @@ -12,12 +12,26 @@ # See the License for the specific language governing permissions and # limitations under the License. - import cirq from cirq.transformers.gauge_compiling import SpinInversionGaugeTransformer from cirq.transformers.gauge_compiling.gauge_compiling_test_utils import GaugeTester -class TestSpinInversionGauge(GaugeTester): +class TestSpinInversionGauge_0(GaugeTester): two_qubit_gate = cirq.ZZ gauge_transformer = SpinInversionGaugeTransformer + + +class TestSpinInversionGauge_1(GaugeTester): + two_qubit_gate = cirq.ZZ**0.1 + gauge_transformer = SpinInversionGaugeTransformer + + +class TestSpinInversionGauge_2(GaugeTester): + two_qubit_gate = cirq.ZZ**-1 + gauge_transformer = SpinInversionGaugeTransformer + + +class TestSpinInversionGauge_3(GaugeTester): + two_qubit_gate = cirq.ZZ**0.3 + gauge_transformer = SpinInversionGaugeTransformer diff --git a/cirq-core/cirq/transformers/gauge_compiling/sqrt_cz_gauge.py b/cirq-core/cirq/transformers/gauge_compiling/sqrt_cz_gauge.py index 2e6e96f9220..60252d6ad56 100644 --- a/cirq-core/cirq/transformers/gauge_compiling/sqrt_cz_gauge.py +++ b/cirq-core/cirq/transformers/gauge_compiling/sqrt_cz_gauge.py @@ -12,18 +12,53 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""A Gauge transformer for CZ**0.5 gate.""" +"""A Gauge transformer for CZ**0.5 and CZ**-0.5 gates.""" + + +from typing import TYPE_CHECKING +import numpy as np from cirq.transformers.gauge_compiling.gauge_compiling import ( GaugeTransformer, GaugeSelector, ConstantGauge, + Gauge, ) -from cirq.ops.common_gates import CZ -from cirq import ops +from cirq.ops import CZ, S, X, Gateset + +if TYPE_CHECKING: + import cirq + +_SQRT_CZ = CZ**0.5 +_ADJ_S = S**-1 -SqrtCZGaugeSelector = GaugeSelector( - gauges=[ConstantGauge(pre_q0=ops.X, post_q0=ops.X, post_q1=ops.Z**0.5, two_qubit_gate=CZ**-0.5)] -) -SqrtCZGaugeTransformer = GaugeTransformer(target=CZ**0.5, gauge_selector=SqrtCZGaugeSelector) +class SqrtCZGauge(Gauge): + + def weight(self) -> float: + return 3.0 + + def sample(self, gate: 'cirq.Gate', prng: np.random.Generator) -> ConstantGauge: + if prng.choice([True, False]): + return ConstantGauge(two_qubit_gate=gate) + swap_qubits = prng.choice([True, False]) + if swap_qubits: + return ConstantGauge( + pre_q1=X, + post_q1=X, + post_q0=S if gate == _SQRT_CZ else _ADJ_S, + two_qubit_gate=gate**-1, + swap_qubits=True, + ) + else: + return ConstantGauge( + pre_q0=X, + post_q0=X, + post_q1=S if gate == _SQRT_CZ else _ADJ_S, + two_qubit_gate=gate**-1, + ) + + +SqrtCZGaugeTransformer = GaugeTransformer( + target=Gateset(_SQRT_CZ, _SQRT_CZ**-1), gauge_selector=GaugeSelector(gauges=[SqrtCZGauge()]) +) diff --git a/cirq-core/cirq/transformers/gauge_compiling/sqrt_cz_gauge_test.py b/cirq-core/cirq/transformers/gauge_compiling/sqrt_cz_gauge_test.py index 8d18c208e1d..e6e871cf938 100644 --- a/cirq-core/cirq/transformers/gauge_compiling/sqrt_cz_gauge_test.py +++ b/cirq-core/cirq/transformers/gauge_compiling/sqrt_cz_gauge_test.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. - import cirq from cirq.transformers.gauge_compiling import SqrtCZGaugeTransformer from cirq.transformers.gauge_compiling.gauge_compiling_test_utils import GaugeTester @@ -21,3 +20,8 @@ class TestSqrtCZGauge(GaugeTester): two_qubit_gate = cirq.CZ**0.5 gauge_transformer = SqrtCZGaugeTransformer + + +class TestAdjointSqrtCZGauge(GaugeTester): + two_qubit_gate = cirq.CZ**-0.5 + gauge_transformer = SqrtCZGaugeTransformer