Skip to content

Commit

Permalink
Add CircuitDiagramInfoArgs.format_real (#2569)
Browse files Browse the repository at this point in the history
- Also fix the qasm UGate not having a decomposition or a correct repr
  • Loading branch information
Strilanc authored and CirqBot committed Nov 19, 2019
1 parent ad98557 commit 907f990
Show file tree
Hide file tree
Showing 10 changed files with 62 additions and 35 deletions.
2 changes: 1 addition & 1 deletion cirq/circuits/circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -1797,7 +1797,7 @@ def _formatted_exponent(info: 'cirq.CircuitDiagramInfo',
- info.exponent) < 10**-args.precision:
return '({})'.format(approx_frac)

return '{{:.{}}}'.format(args.precision).format(info.exponent)
return args.format_real(info.exponent)
return repr(info.exponent)

# If the exponent is any other object, use its string representation.
Expand Down
24 changes: 13 additions & 11 deletions cirq/circuits/qasm_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,24 +54,26 @@ def from_matrix(mat: np.array) -> 'QasmUGate':
pre_phase / np.pi,
)

def _qasm_(self, qubits: Tuple[ops.Qid, ...], args: 'cirq.QasmArgs') -> str:
def _qasm_(self, qubits: Tuple['cirq.Qid', ...],
args: 'cirq.QasmArgs') -> str:
args.validate_version('2.0')
return args.format(
'u3({0:half_turns},{1:half_turns},{2:half_turns}) {3};\n',
self.theta, self.phi, self.lmda, qubits[0])

def __repr__(self) -> str:
return 'cirq.QasmUGate({}, {}, {})'.format(self.theta, self.phi,
self.lmda)

def _unitary_(self) -> np.ndarray:
# Source: https://arxiv.org/abs/1707.03429 (equation 2)
operations = [
ops.Rz(self.phi * np.pi),
ops.Ry(self.theta * np.pi),
ops.Rz(self.lmda * np.pi),
return (f'cirq.circuits.qasm_output.QasmUGate('
f'theta={self.theta!r}, '
f'phi={self.phi!r}, '
f'lmda={self.lmda})')

def _decompose_(self, qubits):
q = qubits[0]
return [
ops.Rz(self.lmda * np.pi).on(q),
ops.Ry(self.theta * np.pi).on(q),
ops.Rz(self.phi * np.pi).on(q),
]
return linalg.dot(*map(protocols.unitary, operations))

def _value_equality_values_(self):
return self.lmda, self.theta, self.phi
Expand Down
5 changes: 4 additions & 1 deletion cirq/circuits/qasm_output_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ def _make_qubits(n):

def test_u_gate_repr():
gate = QasmUGate(0.1, 0.2, 0.3)
assert repr(gate) == 'cirq.QasmUGate(0.1, 0.2, 0.3)'
assert repr(gate) == ('cirq.circuits.qasm_output.QasmUGate('
'theta=0.1, phi=0.2, lmda=0.3)')


def test_u_gate_eq():
Expand All @@ -51,6 +52,8 @@ def test_qasm_u_qubit_gate_unitary():
u,
atol=1e-7)

cirq.testing.assert_implements_consistent_protocols(g)


def test_qasm_two_qubit_gate_unitary():
u = cirq.testing.random_unitary(4)
Expand Down
2 changes: 1 addition & 1 deletion cirq/google/optimizers/convert_to_xmon_gates_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def test_avoids_infinite_cycle_when_matrix_available():
q = cirq.GridQubit(0, 0)
c = cirq.Circuit(OtherX().on(q), OtherOtherX().on(q))
cirq.google.ConvertToXmonGates().optimize_circuit(c)
cirq.testing.assert_has_diagram(c, '(0, 0): ───PhX(1.0)───PhX(1.0)───')
cirq.testing.assert_has_diagram(c, '(0, 0): ───PhX(1)───PhX(1)───')
cirq.protocols.decompose(c)


Expand Down
10 changes: 7 additions & 3 deletions cirq/neutral_atoms/convert_to_neutral_atom_gates_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def _decompose_(self, qubits):
q = cirq.GridQubit(0, 0)
c = cirq.Circuit(OtherX().on(q), OtherOtherX().on(q))
cirq.neutral_atoms.ConvertToNeutralAtomGates().optimize_circuit(c)
cirq.testing.assert_has_diagram(c, '(0, 0): ───PhX(1.0)───PhX(1.0)───')
cirq.testing.assert_has_diagram(c, '(0, 0): ───PhX(1)───PhX(1)───')


def test_avoids_decompose_fallback_when_matrix_available_two_qubit():
Expand All @@ -78,5 +78,9 @@ def _decompose_(self, qubits):
q01 = cirq.GridQubit(0, 1)
c = cirq.Circuit(OtherCZ().on(q00, q01), OtherOtherCZ().on(q00, q01))
cirq.neutral_atoms.ConvertToNeutralAtomGates().optimize_circuit(c)
cirq.testing.assert_has_diagram(c, "(0, 0): ───@───@───\n"
" │ │\n(0, 1): ───@───@───")
cirq.testing.assert_has_diagram(
c, """
(0, 0): ───@───@───
│ │
(0, 1): ───@───@───
""")
2 changes: 1 addition & 1 deletion cirq/ops/fsim_gate.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,6 @@ def _format_rads(args: 'cirq.CircuitDiagramInfoArgs', radians: float) -> str:
if radians == -np.pi:
return '-' + unit
if args.precision is not None:
quantity = '{{:.{}}}'.format(args.precision).format(radians / np.pi)
quantity = args.format_real(radians / np.pi)
return quantity + unit
return repr(radians)
11 changes: 3 additions & 8 deletions cirq/ops/phased_iswap_gate.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,14 +156,9 @@ def _pauli_expansion_(self) -> value.LinearDict[str]:
'ZZ': expansion['ZZ'],
})

def _circuit_diagram_info_(self, args: 'protocols.CircuitDiagramInfoArgs'
) -> 'protocols.CircuitDiagramInfo':
if (isinstance(self._phase_exponent, (sympy.Basic, int)) or
args.precision is None):
s = 'PhISwap({})'.format(self._phase_exponent)
else:
s = 'PhISwap({{:.{}}})'.format(args.precision).format(
self._phase_exponent)
def _circuit_diagram_info_(self, args: 'cirq.CircuitDiagramInfoArgs'
) -> 'cirq.CircuitDiagramInfo':
s = f'PhISwap({args.format_real(self._phase_exponent)})'
return protocols.CircuitDiagramInfo(
wire_symbols=(s, s), exponent=self._diagram_exponent(args))

Expand Down
12 changes: 3 additions & 9 deletions cirq/ops/phased_x_gate.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,18 +146,12 @@ def _phase_by_(self, phase_turns, qubit_index):
phase_exponent=self._phase_exponent + phase_turns * 2,
global_shift=self._global_shift)

def _circuit_diagram_info_(self, args: 'protocols.CircuitDiagramInfoArgs'
) -> 'protocols.CircuitDiagramInfo':
def _circuit_diagram_info_(self, args: 'cirq.CircuitDiagramInfoArgs'
) -> 'cirq.CircuitDiagramInfo':
"""See `cirq.SupportsCircuitDiagramInfo`."""

if (isinstance(self.phase_exponent, (sympy.Basic, int)) or
args.precision is None):
s = 'PhX({})'.format(self.phase_exponent)
else:
s = 'PhX({{:.{}}})'.format(args.precision).format(
self.phase_exponent)
return protocols.CircuitDiagramInfo(
wire_symbols=(s,),
wire_symbols=(f'PhX({args.format_real(self.phase_exponent)})',),
exponent=value.canonicalize_half_turns(self._exponent))

def __str__(self):
Expand Down
10 changes: 10 additions & 0 deletions cirq/protocols/circuit_diagram_info_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from typing import (Any, TYPE_CHECKING, Optional, Union, TypeVar, Dict,
overload, Iterable)

import sympy
from typing_extensions import Protocol

from cirq import value
Expand Down Expand Up @@ -133,6 +134,15 @@ def __repr__(self):
self.use_unicode_characters,
self.precision, self.qubit_map))

def format_real(self, val: Union[sympy.Basic, int, float]) -> str:
if isinstance(val, sympy.Basic):
return str(val)
if val == int(val):
return str(int(val))
if self.precision is None:
return str(val)
return f'{float(val):.{self.precision}}'

def copy(self):
return self.__class__(
known_qubits=self.known_qubits,
Expand Down
19 changes: 19 additions & 0 deletions cirq/protocols/circuit_diagram_info_protocol_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
import sympy

import cirq

Expand Down Expand Up @@ -170,3 +171,21 @@ def test_circuit_diagram_info_args_repr():
cirq.LineQubit(0): 5,
cirq.LineQubit(1): 7
}))


def test_formal_real():
args = cirq.CircuitDiagramInfoArgs.UNINFORMED_DEFAULT
assert args.format_real(1) == '1'
assert args.format_real(1.1) == '1.1'
assert args.format_real(1.234567) == '1.23'
assert args.format_real(1 / 7) == '0.143'
assert args.format_real(sympy.Symbol('t')) == 't'
assert args.format_real(sympy.Symbol('t') * 2 + 1) == '2*t + 1'

args.precision = None
assert args.format_real(1) == '1'
assert args.format_real(1.1) == '1.1'
assert args.format_real(1.234567) == '1.234567'
assert args.format_real(1 / 7) == repr(1 / 7)
assert args.format_real(sympy.Symbol('t')) == 't'
assert args.format_real(sympy.Symbol('t') * 2 + 1) == '2*t + 1'

0 comments on commit 907f990

Please sign in to comment.