Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add typing.Result and typing.ResultBatch type variables #4108

Merged
merged 15 commits into from
May 16, 2023
3 changes: 3 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@
[(#4079)](https://github.com/PennyLaneAI/pennylane/pull/4079)
[(#4095)](https://github.com/PennyLaneAI/pennylane/pull/4095)

* Adds the Type variables `pennylane.typing.Result` and `pennylane.typing.ResultBatch` for type hinting the result of
an execution.
albi3ro marked this conversation as resolved.
Show resolved Hide resolved

<h3>Breaking changes 💔</h3>

* The experimental Device interface `qml.devices.experimental.Device` now requires that the `preprocess` method
Expand Down
8 changes: 5 additions & 3 deletions pennylane/devices/experimental/default_qubit_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,15 @@
from typing import Union, Callable, Tuple, Optional, Sequence

from pennylane.tape import QuantumTape, QuantumScript
from pennylane.typing import Result, ResultBatch

from . import Device
from .execution_config import ExecutionConfig, DefaultExecutionConfig
from ..qubit.simulate import simulate
from ..qubit.preprocess import preprocess, validate_and_expand_adjoint
from ..qubit.adjoint_jacobian import adjoint_jacobian

Result_or_Batch = Union[Result, ResultBatch]
albi3ro marked this conversation as resolved.
Show resolved Hide resolved
QuantumTapeBatch = Sequence[QuantumTape]
QuantumTape_or_Batch = Union[QuantumTape, QuantumTapeBatch]

Expand Down Expand Up @@ -126,7 +128,7 @@ def preprocess(
self,
circuits: QuantumTape_or_Batch,
execution_config: ExecutionConfig = DefaultExecutionConfig,
) -> Tuple[QuantumTapeBatch, Callable]:
) -> Tuple[QuantumTapeBatch, Callable[[ResultBatch], Result_or_Batch]]:
albi3ro marked this conversation as resolved.
Show resolved Hide resolved
"""Converts an arbitrary circuit or batch of circuits into a batch natively executable by the :meth:`~.execute` method.

Args:
Expand Down Expand Up @@ -155,7 +157,7 @@ def preprocess(

if is_single_circuit:

def convert_batch_to_single_output(results):
def convert_batch_to_single_output(results: ResultBatch) -> Result:
"""Unwraps a dimension so that executing the batch of circuits looks like executing a single circuit."""
return post_processing_fn(results)[0]

Expand All @@ -167,7 +169,7 @@ def execute(
self,
circuits: QuantumTape_or_Batch,
execution_config: ExecutionConfig = DefaultExecutionConfig,
):
) -> Result_or_Batch:
is_single_circuit = False
if isinstance(circuits, QuantumScript):
is_single_circuit = True
Expand Down
7 changes: 4 additions & 3 deletions pennylane/devices/experimental/device_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from typing import Callable, Union, Sequence, Tuple, Optional

from pennylane.tape import QuantumTape
from pennylane.typing import Result, ResultBatch
from pennylane import Tracker

from .execution_config import ExecutionConfig, DefaultExecutionConfig
Expand Down Expand Up @@ -139,7 +140,7 @@ def preprocess(
self,
circuits: QuantumTape_or_Batch,
execution_config: ExecutionConfig = DefaultExecutionConfig,
) -> Tuple[QuantumTapeBatch, Callable, ExecutionConfig]:
) -> Tuple[QuantumTapeBatch, Callable[[ResultBatch], ResultBatch], ExecutionConfig]:
albi3ro marked this conversation as resolved.
Show resolved Hide resolved
"""Device preprocessing function.

.. warning::
Expand Down Expand Up @@ -173,7 +174,7 @@ def preprocess(

"""

def blank_postprocessing_fn(res):
def blank_postprocessing_fn(res: ResultBatch) -> ResultBatch:
albi3ro marked this conversation as resolved.
Show resolved Hide resolved
"""Identity postprocessing function created in Device preprocessing.

Args:
Expand All @@ -193,7 +194,7 @@ def execute(
self,
circuits: QuantumTape_or_Batch,
execution_config: ExecutionConfig = DefaultExecutionConfig,
):
) -> Union[Result, ResultBatch]:
"""Execute a circuit or a batch of circuits and turn it into results.

Args:
Expand Down
7 changes: 4 additions & 3 deletions pennylane/devices/qubit/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

from pennylane.operation import Tensor
from pennylane.measurements import MidMeasureMP, StateMeasurement, ExpectationMP
from pennylane.typing import ResultBatch, Result
from pennylane import DeviceError

from ..experimental import ExecutionConfig, DefaultExecutionConfig
Expand Down Expand Up @@ -194,7 +195,7 @@ def expand_fn(circuit: qml.tape.QuantumScript) -> qml.tape.QuantumScript:

def batch_transform(
circuit: qml.tape.QuantumScript,
) -> Tuple[Tuple[qml.tape.QuantumScript], Callable]:
) -> Tuple[Tuple[qml.tape.QuantumScript], Callable[[ResultBatch], Union[Result, ResultBatch]]]:
albi3ro marked this conversation as resolved.
Show resolved Hide resolved
"""Apply a differentiable batch transform for preprocessing a circuit
prior to execution.

Expand All @@ -221,7 +222,7 @@ def batch_transform(
# If the circuit wasn't broadcasted, no action required
circuits = [circuit]

def batch_fn(res):
def batch_fn(res: ResultBatch) -> Result:
"""A post-processing function to convert the results of a batch of
executions into the result of a single executiion."""
return res[0]
Expand All @@ -237,7 +238,7 @@ def batch_fn(res):
def preprocess(
circuits: Tuple[qml.tape.QuantumScript],
execution_config: ExecutionConfig = DefaultExecutionConfig,
) -> Tuple[Tuple[qml.tape.QuantumScript], Callable]:
) -> Tuple[Tuple[qml.tape.QuantumScript], Callable[[ResultBatch], ResultBatch]]:
albi3ro marked this conversation as resolved.
Show resolved Hide resolved
"""Preprocess a batch of :class:`~.QuantumTape` objects to make them ready for execution.

This function validates a batch of :class:`~.QuantumTape` objects by transforming and expanding
Expand Down
6 changes: 2 additions & 4 deletions pennylane/devices/qubit/simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Simulate a quantum script."""
from typing import Union

# pylint: disable=protected-access
import pennylane as qml
from pennylane.typing import TensorLike
from pennylane.typing import Result

from .initialize_state import create_initial_state
from .apply_operation import apply_operation
from .measure import measure


def simulate(circuit: qml.tape.QuantumScript) -> Union[tuple, TensorLike]:
def simulate(circuit: qml.tape.QuantumScript) -> Result:
albi3ro marked this conversation as resolved.
Show resolved Hide resolved
"""Simulate a single quantum script.

This is an internal function that will be called by the successor to ``default.qubit``.
Expand Down
2 changes: 1 addition & 1 deletion pennylane/interfaces/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ def execute(
expand_fn="device",
max_expansion=10,
device_batch_transform=True,
):
) -> qml.typing.ResultBatch:
albi3ro marked this conversation as resolved.
Show resolved Hide resolved
"""New function to execute a batch of tapes on a device in an autodifferentiable-compatible manner. More cases will be added,
during the project. The current version is supporting forward execution for Numpy and does not support shot vectors.

Expand Down
6 changes: 4 additions & 2 deletions pennylane/qnode.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
"""
This module contains the QNode class and qnode decorator.
"""
# pylint: disable=too-many-instance-attributes,too-many-arguments,protected-access,unnecessary-lambda-assignment
# pylint: disable=too-many-instance-attributes,too-many-arguments,protected-access,unnecessary-lambda-assignment, too-many-branches, too-many-statements
import functools
import inspect
import warnings
Expand Down Expand Up @@ -824,7 +824,9 @@ def construct(self, args, kwargs): # pylint: disable=too-many-branches
if old_interface == "auto":
self.interface = "auto"

def __call__(self, *args, **kwargs): # pylint: disable=too-many-branches, too-many-statements
def __call__(
self, *args, **kwargs
) -> qml.typing.Result: # pylint: disable=too-many-branches, too-many-statements
albi3ro marked this conversation as resolved.
Show resolved Hide resolved
override_shots = False
old_interface = self.interface

Expand Down
10 changes: 8 additions & 2 deletions pennylane/transforms/batch_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
import types
import warnings

from typing import Callable, Tuple

import pennylane as qml


Expand Down Expand Up @@ -449,7 +451,11 @@ def _tape_wrapper(self, *targs, **tkwargs):
return lambda tape: self.construct(tape, *targs, **tkwargs)


def map_batch_transform(transform, tapes):
def map_batch_transform(
transform, tapes: Tuple[qml.tape.QuantumScript]
) -> Tuple[
Tuple[qml.tape.QuantumScript], Callable[[qml.typing.ResultBatch], qml.typing.ResultBatch]
albi3ro marked this conversation as resolved.
Show resolved Hide resolved
]:
"""Map a batch transform over multiple tapes.

Args:
Expand Down Expand Up @@ -501,7 +507,7 @@ def map_batch_transform(transform, tapes):
batch_fns.append(fn)
tape_counts.append(len(new_tapes))

def processing_fn(res):
def processing_fn(res: qml.typing.ResultBatch) -> qml.typing.ResultBatch:
count = 0
final_results = []

Expand Down
2 changes: 1 addition & 1 deletion pennylane/transforms/broadcast_expand.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def broadcast_expand(tape):
new_tape.set_parameters(p, trainable_only=False)
output_tapes.append(new_tape)

def processing_fn(results):
def processing_fn(results: qml.typing.ResultBatch) -> qml.typing.Result:
if len(tape.measurements) > 1 and qml.active_return():
processed_results = [
qml.math.squeeze(qml.math.stack([results[b][i] for b in range(tape.batch_size)]))
Expand Down
11 changes: 9 additions & 2 deletions pennylane/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

# pylint: disable=import-outside-toplevel, too-few-public-methods
import sys
from typing import Union
from typing import Union, TypeVar, Tuple

import numpy as np
from autograd.numpy.numpy_boxes import ArrayBox
Expand Down Expand Up @@ -83,7 +83,9 @@ def _is_jax(other, subclass=False):
ndarray,
jax.Array # TODO: keep this after jax>=0.4 is required
if hasattr(jax, "Array")
else Union[jaxlib.xla_extension.DeviceArray, jax.core.Tracer],
else Union[
jaxlib.xla_extension.DeviceArray, jax.core.Tracer
], # pylint: disable=c-extension-no-member
]
check = issubclass if subclass else isinstance

Expand Down Expand Up @@ -114,3 +116,8 @@ def _is_torch(other, subclass=False):

return check(other, torchTensor)
return False


Result = TypeVar("Result", Tuple, TensorLike)

ResultBatch = Tuple[Result]