diff --git a/cirq-core/cirq/transformers/transformer_api.py b/cirq-core/cirq/transformers/transformer_api.py index f1897a1936e..484e731ef42 100644 --- a/cirq-core/cirq/transformers/transformer_api.py +++ b/cirq-core/cirq/transformers/transformer_api.py @@ -14,21 +14,21 @@ """Defines the API for circuit transformers in Cirq.""" -import textwrap +import dataclasses +import enum import functools +import textwrap from typing import ( Any, - Callable, Tuple, Hashable, List, - Type, overload, + Type, TYPE_CHECKING, + TypeVar, ) -import dataclasses -import enum -from cirq.circuits.circuit import CIRCUIT_TYPE +from typing_extensions import Protocol if TYPE_CHECKING: import cirq @@ -218,96 +218,95 @@ class TransformerContext: ignore_tags: Tuple[Hashable, ...] = () -TRANSFORMER = Callable[['cirq.AbstractCircuit', TransformerContext], 'cirq.AbstractCircuit'] -_TRANSFORMER_TYPE = Callable[['cirq.AbstractCircuit', TransformerContext], CIRCUIT_TYPE] - - -def _transform_and_log( - func: _TRANSFORMER_TYPE[CIRCUIT_TYPE], - transformer_name: str, - circuit: 'cirq.AbstractCircuit', - context: TransformerContext, -) -> CIRCUIT_TYPE: - """Helper to log initial and final circuits before and after calling the transformer.""" - - context.logger.register_initial(circuit, transformer_name) - transformed_circuit = func(circuit, context) - context.logger.register_final(transformed_circuit, transformer_name) - return transformed_circuit - +class TRANSFORMER(Protocol): + def __call__( + self, circuit: 'cirq.AbstractCircuit', context: TransformerContext + ) -> 'cirq.AbstractCircuit': + ... -def _transformer_class( - cls: Type[_TRANSFORMER_TYPE[CIRCUIT_TYPE]], -) -> Type[_TRANSFORMER_TYPE[CIRCUIT_TYPE]]: - old_func = cls.__call__ - def transformer_with_logging_cls( - self: Type[_TRANSFORMER_TYPE[CIRCUIT_TYPE]], - circuit: 'cirq.AbstractCircuit', - context: TransformerContext, - ) -> CIRCUIT_TYPE: - def call_old_func(c: 'cirq.AbstractCircuit', ct: TransformerContext) -> CIRCUIT_TYPE: - return old_func(self, c, ct) - - return _transform_and_log(call_old_func, cls.__name__, circuit, context) - - setattr(cls, '__call__', transformer_with_logging_cls) - return cls - - -def _transformer_func(func: _TRANSFORMER_TYPE[CIRCUIT_TYPE]) -> _TRANSFORMER_TYPE[CIRCUIT_TYPE]: - @functools.wraps(func) - def transformer_with_logging_func( - circuit: 'cirq.AbstractCircuit', - context: TransformerContext, - ) -> CIRCUIT_TYPE: - return _transform_and_log(func, func.__name__, circuit, context) - - return transformer_with_logging_func +_TRANSFORMER_T = TypeVar('_TRANSFORMER_T', bound=TRANSFORMER) +_TRANSFORMER_CLS_T = TypeVar('_TRANSFORMER_CLS_T', bound=Type[TRANSFORMER]) @overload -def transformer(cls_or_func: _TRANSFORMER_TYPE[CIRCUIT_TYPE]) -> _TRANSFORMER_TYPE[CIRCUIT_TYPE]: +def transformer(cls_or_func: _TRANSFORMER_T) -> _TRANSFORMER_T: pass @overload -def transformer( - cls_or_func: Type[_TRANSFORMER_TYPE[CIRCUIT_TYPE]], -) -> Type[_TRANSFORMER_TYPE[CIRCUIT_TYPE]]: +def transformer(cls_or_func: _TRANSFORMER_CLS_T) -> _TRANSFORMER_CLS_T: pass def transformer(cls_or_func: Any) -> Any: """Decorator to verify API and append logging functionality to transformer functions & classes. - The decorated function or class must satisfy - `Callable[[cirq.Circuit, cirq.TransformerContext], cirq.Circuit]` API. For Example: + A transformer is a callable that takes as inputs a `cirq.AbstractCircuit` and + `cirq.TransformerContext`, and returns another `cirq.AbstractCircuit` without + modifying the input circuit. A transformer could be a function, for example: >>> @cirq.transformer - >>> def convert_to_cz(circuit: cirq.Circuit, context: cirq.TransformerContext) -> cirq.Circuit: + >>> def convert_to_cz( + >>> circuit: cirq.AbstractCircuit, context: cirq.TransformerContext + >>> ) -> cirq.Circuit: >>> ... - The decorated class must implement the `__call__` method to satisfy the above API. + Or it could be a class that implements `__call__` with the same API, for example: >>> @cirq.transformer >>> class ConvertToSqrtISwaps: >>> def __init__(self): >>> ... >>> def __call__( - >>> self, circuit: cirq.Circuit, context: cirq.TransformerContext + >>> self, circuit: cirq.AbstractCircuit, context: cirq.TransformerContext >>> ) -> cirq.Circuit: >>> ... Args: - cls_or_func: The callable class or method to be decorated. + cls_or_func: The callable class or function to be decorated. Returns: - Decorated class / method which includes additional logging boilerplate. The decorated - callable always receives a copy of the input circuit so that the input is never mutated. + Decorated class / function which includes additional logging boilerplate. """ if isinstance(cls_or_func, type): - return _transformer_class(cls_or_func) + cls = cls_or_func + method = cls.__call__ + + @functools.wraps(method) + def method_with_logging( + self, circuit: 'cirq.AbstractCircuit', context: TransformerContext + ) -> 'cirq.AbstractCircuit': + return _transform_and_log( + lambda circuit, context: method(self, circuit, context), + cls.__name__, + circuit, + context, + ) + + setattr(cls, '__call__', method_with_logging) + return cls else: assert callable(cls_or_func) - return _transformer_func(cls_or_func) + func = cls_or_func + + @functools.wraps(func) + def func_with_logging( + circuit: 'cirq.AbstractCircuit', context: TransformerContext + ) -> 'cirq.AbstractCircuit': + return _transform_and_log(func, func.__name__, circuit, context) + + return func_with_logging + + +def _transform_and_log( + func: TRANSFORMER, + transformer_name: str, + circuit: 'cirq.AbstractCircuit', + context: TransformerContext, +) -> 'cirq.AbstractCircuit': + """Helper to log initial and final circuits before and after calling the transformer.""" + context.logger.register_initial(circuit, transformer_name) + transformed_circuit = func(circuit, context) + context.logger.register_final(transformed_circuit, transformer_name) + return transformed_circuit