Skip to content

Commit

Permalink
Add EigenGate._equal_up_to_global_phase_ (#1840)
Browse files Browse the repository at this point in the history
  • Loading branch information
bryano authored and CirqBot committed Nov 8, 2019
1 parent 5207900 commit 8f1c30a
Show file tree
Hide file tree
Showing 4 changed files with 100 additions and 0 deletions.
26 changes: 26 additions & 0 deletions cirq/ops/eigen_gate.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,32 @@ def _resolve_parameters_(self: TSelf, param_resolver) -> TSelf:
return self._with_exponent(
exponent=param_resolver.value_of(self._exponent))

def _equal_up_to_global_phase_(self, other, atol):
if not isinstance(other, EigenGate):
return NotImplemented

exponents = (self.exponent, other.exponent)
exponents_is_parameterized = tuple(
protocols.is_parameterized(e) for e in exponents)
if (all(exponents_is_parameterized) and exponents[0] != exponents[1]):
return False
if any(exponents_is_parameterized):
return False
self_without_phase = self._with_exponent(self.exponent)
self_without_phase._global_shift = 0
self_without_exp_or_phase = self_without_phase._with_exponent(0)
other_without_phase = other._with_exponent(other.exponent)
other_without_phase._global_shift = 0
other_without_exp_or_phase = other_without_phase._with_exponent(0)
if not protocols.approx_eq(self_without_exp_or_phase,
other_without_exp_or_phase,
atol=atol):
return False

period = self_without_phase._period()
canonical_diff = (exponents[0] - exponents[1]) % period
return np.isclose(canonical_diff, 0, atol=atol)

def _json_dict_(self):
return protocols.obj_to_dict_helper(self, ['exponent', 'global_shift'])

Expand Down
41 changes: 41 additions & 0 deletions cirq/ops/eigen_gate_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import numpy as np
import pytest
import sympy

import cirq
Expand Down Expand Up @@ -370,3 +371,43 @@ def _eigen_shifts(self):

# Unknown period.
assert ShiftyGate(505.2, 0, np.pi, np.e)._diagram_exponent(args) == 505.2


class WeightedZPowGate(cirq.EigenGate, cirq.SingleQubitGate):

def __init__(self, weight, **kwargs):
self.weight = weight
super().__init__(**kwargs)

def _value_equality_values_(self):
return self.weight, self._canonical_exponent, self._global_shift

_value_equality_approximate_values_ = _value_equality_values_

def _eigen_components(self):
return [
(0, np.diag([1, 0])),
(self.weight, np.diag([0, 1])),
]

def _with_exponent(self, exponent):
return type(self)(self.weight,
exponent=exponent,
global_shift=self._global_shift)


@pytest.mark.parametrize('gate1,gate2,eq_up_to_global_phase', [
(cirq.Rz(0.3 * np.pi), cirq.Z**0.3, True),
(cirq.Z, cirq.Gate, False),
(cirq.Rz(0.3), cirq.Z**0.3, False),
(cirq.ZZPowGate(global_shift=0.5), cirq.ZZ, True),
(cirq.ZPowGate(global_shift=0.5)**sympy.Symbol('e'), cirq.Z, False),
(cirq.Z**sympy.Symbol('e'), cirq.Z**sympy.Symbol('f'), False),
(cirq.ZZ**1.9, cirq.ZZ**-0.1, True),
(WeightedZPowGate(0), WeightedZPowGate(0.1), False),
(WeightedZPowGate(0.3), WeightedZPowGate(0.3, global_shift=0.1), True),
(cirq.X, cirq.Z, False),
(cirq.X**0.3, cirq.Z**0.3, False),
])
def test_equal_up_to_global_phase(gate1, gate2, eq_up_to_global_phase):
assert cirq.equal_up_to_global_phase(gate1, gate2) == eq_up_to_global_phase
12 changes: 12 additions & 0 deletions cirq/ops/gate_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,18 @@ def _qasm_(self, args: 'protocols.QasmArgs') -> Optional[str]:
qubits=self.qubits,
default=None)

def _equal_up_to_global_phase_(self,
other: Any,
atol: Union[int, float] = 1e-8
) -> Union[NotImplementedType, bool]:
if not isinstance(other, type(self)):
return NotImplemented
if self.qubits != other.qubits:
return False
return protocols.equal_up_to_global_phase(self.gate,
other.gate,
atol=atol)


TV = TypeVar('TV', bound=raw_types.Gate)

Expand Down
21 changes: 21 additions & 0 deletions cirq/ops/gate_operation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,27 @@ def with_qubits(self, *new_qubits):
assert not cirq.op_gate_isinstance(NonGateOperation(), NonGateOperation)


@pytest.mark.parametrize('gate1,gate2,eq_up_to_global_phase', [
(cirq.Rz(0.3 * np.pi), cirq.Z**0.3, True),
(cirq.Rz(0.3), cirq.Z**0.3, False),
(cirq.ZZPowGate(global_shift=0.5), cirq.ZZ, True),
(cirq.ZPowGate(global_shift=0.5)**sympy.Symbol('e'), cirq.Z, False),
(cirq.Z**sympy.Symbol('e'), cirq.Z**sympy.Symbol('f'), False),
])
def test_equal_up_to_global_phase_on_gates(gate1, gate2, eq_up_to_global_phase):
num_qubits1, num_qubits2 = (cirq.num_qubits(g) for g in (gate1, gate2))
qubits = cirq.LineQubit.range(max(num_qubits1, num_qubits2) + 1)
op1, op2 = gate1(*qubits[:num_qubits1]), gate2(*qubits[:num_qubits2])
assert cirq.equal_up_to_global_phase(op1, op2) == eq_up_to_global_phase
op2_on_diff_qubits = gate2(*qubits[1:num_qubits2 + 1])
assert not cirq.equal_up_to_global_phase(op1, op2_on_diff_qubits)


def test_equal_up_to_global_phase_on_diff_types():
op = cirq.X(cirq.LineQubit(0))
assert not cirq.equal_up_to_global_phase(op, 3)


def test_gate_on_operation_besides_gate_operation():
a, b = cirq.LineQubit.range(2)

Expand Down

0 comments on commit 8f1c30a

Please sign in to comment.