diff --git a/cirq/circuits/circuit.py b/cirq/circuits/circuit.py index d94d0c0b17f..f3753efb4a2 100644 --- a/cirq/circuits/circuit.py +++ b/cirq/circuits/circuit.py @@ -1797,7 +1797,7 @@ def _formatted_exponent(info: 'cirq.CircuitDiagramInfo', - info.exponent) < 10**-args.precision: return '({})'.format(approx_frac) - return '{{:.{}}}'.format(args.precision).format(info.exponent) + return args.format_real(info.exponent) return repr(info.exponent) # If the exponent is any other object, use its string representation. diff --git a/cirq/circuits/qasm_output.py b/cirq/circuits/qasm_output.py index 430ea14315b..d3b4e91c7ce 100644 --- a/cirq/circuits/qasm_output.py +++ b/cirq/circuits/qasm_output.py @@ -54,24 +54,26 @@ def from_matrix(mat: np.array) -> 'QasmUGate': pre_phase / np.pi, ) - def _qasm_(self, qubits: Tuple[ops.Qid, ...], args: 'cirq.QasmArgs') -> str: + def _qasm_(self, qubits: Tuple['cirq.Qid', ...], + args: 'cirq.QasmArgs') -> str: args.validate_version('2.0') return args.format( 'u3({0:half_turns},{1:half_turns},{2:half_turns}) {3};\n', self.theta, self.phi, self.lmda, qubits[0]) def __repr__(self) -> str: - return 'cirq.QasmUGate({}, {}, {})'.format(self.theta, self.phi, - self.lmda) - - def _unitary_(self) -> np.ndarray: - # Source: https://arxiv.org/abs/1707.03429 (equation 2) - operations = [ - ops.Rz(self.phi * np.pi), - ops.Ry(self.theta * np.pi), - ops.Rz(self.lmda * np.pi), + return (f'cirq.circuits.qasm_output.QasmUGate(' + f'theta={self.theta!r}, ' + f'phi={self.phi!r}, ' + f'lmda={self.lmda})') + + def _decompose_(self, qubits): + q = qubits[0] + return [ + ops.Rz(self.lmda * np.pi).on(q), + ops.Ry(self.theta * np.pi).on(q), + ops.Rz(self.phi * np.pi).on(q), ] - return linalg.dot(*map(protocols.unitary, operations)) def _value_equality_values_(self): return self.lmda, self.theta, self.phi diff --git a/cirq/circuits/qasm_output_test.py b/cirq/circuits/qasm_output_test.py index 8f7fc63af9a..bb794af588a 100644 --- a/cirq/circuits/qasm_output_test.py +++ b/cirq/circuits/qasm_output_test.py @@ -27,7 +27,8 @@ def _make_qubits(n): def test_u_gate_repr(): gate = QasmUGate(0.1, 0.2, 0.3) - assert repr(gate) == 'cirq.QasmUGate(0.1, 0.2, 0.3)' + assert repr(gate) == ('cirq.circuits.qasm_output.QasmUGate(' + 'theta=0.1, phi=0.2, lmda=0.3)') def test_u_gate_eq(): @@ -51,6 +52,8 @@ def test_qasm_u_qubit_gate_unitary(): u, atol=1e-7) + cirq.testing.assert_implements_consistent_protocols(g) + def test_qasm_two_qubit_gate_unitary(): u = cirq.testing.random_unitary(4) diff --git a/cirq/google/optimizers/convert_to_xmon_gates_test.py b/cirq/google/optimizers/convert_to_xmon_gates_test.py index 733935902a0..76d1cc5d829 100644 --- a/cirq/google/optimizers/convert_to_xmon_gates_test.py +++ b/cirq/google/optimizers/convert_to_xmon_gates_test.py @@ -42,7 +42,7 @@ def test_avoids_infinite_cycle_when_matrix_available(): q = cirq.GridQubit(0, 0) c = cirq.Circuit(OtherX().on(q), OtherOtherX().on(q)) cirq.google.ConvertToXmonGates().optimize_circuit(c) - cirq.testing.assert_has_diagram(c, '(0, 0): ───PhX(1.0)───PhX(1.0)───') + cirq.testing.assert_has_diagram(c, '(0, 0): ───PhX(1)───PhX(1)───') cirq.protocols.decompose(c) diff --git a/cirq/neutral_atoms/convert_to_neutral_atom_gates_test.py b/cirq/neutral_atoms/convert_to_neutral_atom_gates_test.py index 0e67b66dc56..37b02890478 100644 --- a/cirq/neutral_atoms/convert_to_neutral_atom_gates_test.py +++ b/cirq/neutral_atoms/convert_to_neutral_atom_gates_test.py @@ -59,7 +59,7 @@ def _decompose_(self, qubits): q = cirq.GridQubit(0, 0) c = cirq.Circuit(OtherX().on(q), OtherOtherX().on(q)) cirq.neutral_atoms.ConvertToNeutralAtomGates().optimize_circuit(c) - cirq.testing.assert_has_diagram(c, '(0, 0): ───PhX(1.0)───PhX(1.0)───') + cirq.testing.assert_has_diagram(c, '(0, 0): ───PhX(1)───PhX(1)───') def test_avoids_decompose_fallback_when_matrix_available_two_qubit(): @@ -78,5 +78,9 @@ def _decompose_(self, qubits): q01 = cirq.GridQubit(0, 1) c = cirq.Circuit(OtherCZ().on(q00, q01), OtherOtherCZ().on(q00, q01)) cirq.neutral_atoms.ConvertToNeutralAtomGates().optimize_circuit(c) - cirq.testing.assert_has_diagram(c, "(0, 0): ───@───@───\n" - " │ │\n(0, 1): ───@───@───") + cirq.testing.assert_has_diagram( + c, """ +(0, 0): ───@───@─── + │ │ +(0, 1): ───@───@─── +""") diff --git a/cirq/ops/fsim_gate.py b/cirq/ops/fsim_gate.py index e2e9e1aecc9..84d1f74e827 100644 --- a/cirq/ops/fsim_gate.py +++ b/cirq/ops/fsim_gate.py @@ -169,6 +169,6 @@ def _format_rads(args: 'cirq.CircuitDiagramInfoArgs', radians: float) -> str: if radians == -np.pi: return '-' + unit if args.precision is not None: - quantity = '{{:.{}}}'.format(args.precision).format(radians / np.pi) + quantity = args.format_real(radians / np.pi) return quantity + unit return repr(radians) diff --git a/cirq/ops/phased_iswap_gate.py b/cirq/ops/phased_iswap_gate.py index 9ba7fe60512..2a49d6436eb 100644 --- a/cirq/ops/phased_iswap_gate.py +++ b/cirq/ops/phased_iswap_gate.py @@ -156,14 +156,9 @@ def _pauli_expansion_(self) -> value.LinearDict[str]: 'ZZ': expansion['ZZ'], }) - def _circuit_diagram_info_(self, args: 'protocols.CircuitDiagramInfoArgs' - ) -> 'protocols.CircuitDiagramInfo': - if (isinstance(self._phase_exponent, (sympy.Basic, int)) or - args.precision is None): - s = 'PhISwap({})'.format(self._phase_exponent) - else: - s = 'PhISwap({{:.{}}})'.format(args.precision).format( - self._phase_exponent) + def _circuit_diagram_info_(self, args: 'cirq.CircuitDiagramInfoArgs' + ) -> 'cirq.CircuitDiagramInfo': + s = f'PhISwap({args.format_real(self._phase_exponent)})' return protocols.CircuitDiagramInfo( wire_symbols=(s, s), exponent=self._diagram_exponent(args)) diff --git a/cirq/ops/phased_x_gate.py b/cirq/ops/phased_x_gate.py index 143a7756ed5..72d77b4ceba 100644 --- a/cirq/ops/phased_x_gate.py +++ b/cirq/ops/phased_x_gate.py @@ -146,18 +146,12 @@ def _phase_by_(self, phase_turns, qubit_index): phase_exponent=self._phase_exponent + phase_turns * 2, global_shift=self._global_shift) - def _circuit_diagram_info_(self, args: 'protocols.CircuitDiagramInfoArgs' - ) -> 'protocols.CircuitDiagramInfo': + def _circuit_diagram_info_(self, args: 'cirq.CircuitDiagramInfoArgs' + ) -> 'cirq.CircuitDiagramInfo': """See `cirq.SupportsCircuitDiagramInfo`.""" - if (isinstance(self.phase_exponent, (sympy.Basic, int)) or - args.precision is None): - s = 'PhX({})'.format(self.phase_exponent) - else: - s = 'PhX({{:.{}}})'.format(args.precision).format( - self.phase_exponent) return protocols.CircuitDiagramInfo( - wire_symbols=(s,), + wire_symbols=(f'PhX({args.format_real(self.phase_exponent)})',), exponent=value.canonicalize_half_turns(self._exponent)) def __str__(self): diff --git a/cirq/protocols/circuit_diagram_info_protocol.py b/cirq/protocols/circuit_diagram_info_protocol.py index 3326dded566..c981c7f80b2 100644 --- a/cirq/protocols/circuit_diagram_info_protocol.py +++ b/cirq/protocols/circuit_diagram_info_protocol.py @@ -15,6 +15,7 @@ from typing import (Any, TYPE_CHECKING, Optional, Union, TypeVar, Dict, overload, Iterable) +import sympy from typing_extensions import Protocol from cirq import value @@ -133,6 +134,15 @@ def __repr__(self): self.use_unicode_characters, self.precision, self.qubit_map)) + def format_real(self, val: Union[sympy.Basic, int, float]) -> str: + if isinstance(val, sympy.Basic): + return str(val) + if val == int(val): + return str(int(val)) + if self.precision is None: + return str(val) + return f'{float(val):.{self.precision}}' + def copy(self): return self.__class__( known_qubits=self.known_qubits, diff --git a/cirq/protocols/circuit_diagram_info_protocol_test.py b/cirq/protocols/circuit_diagram_info_protocol_test.py index f4bf36b4e78..bdec2962b2d 100644 --- a/cirq/protocols/circuit_diagram_info_protocol_test.py +++ b/cirq/protocols/circuit_diagram_info_protocol_test.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import pytest +import sympy import cirq @@ -170,3 +171,21 @@ def test_circuit_diagram_info_args_repr(): cirq.LineQubit(0): 5, cirq.LineQubit(1): 7 })) + + +def test_formal_real(): + args = cirq.CircuitDiagramInfoArgs.UNINFORMED_DEFAULT + assert args.format_real(1) == '1' + assert args.format_real(1.1) == '1.1' + assert args.format_real(1.234567) == '1.23' + assert args.format_real(1 / 7) == '0.143' + assert args.format_real(sympy.Symbol('t')) == 't' + assert args.format_real(sympy.Symbol('t') * 2 + 1) == '2*t + 1' + + args.precision = None + assert args.format_real(1) == '1' + assert args.format_real(1.1) == '1.1' + assert args.format_real(1.234567) == '1.234567' + assert args.format_real(1 / 7) == repr(1 / 7) + assert args.format_real(sympy.Symbol('t')) == 't' + assert args.format_real(sympy.Symbol('t') * 2 + 1) == '2*t + 1'