Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add client-side validations #1730

Merged
merged 13 commits into from
Jun 18, 2024
3 changes: 2 additions & 1 deletion qiskit_ibm_runtime/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
from .utils.deprecation import deprecate_arguments, issue_deprecation_msg
from .utils.qctrl import validate as qctrl_validate
from .utils.qctrl import validate_v2 as qctrl_validate_v2

from .utils.validations import validate_estimator_pubs

# pylint: disable=unused-import,cyclic-import
from .session import Session
Expand Down Expand Up @@ -184,6 +184,7 @@ def run(

"""
coerced_pubs = [EstimatorPub.coerce(pub, precision) for pub in pubs]
validate_estimator_pubs(coerced_pubs)
return self._run(coerced_pubs) # type: ignore[arg-type]

def _validate_options(self, options: dict) -> None:
Expand Down
9 changes: 2 additions & 7 deletions qiskit_ibm_runtime/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
import os
from typing import Dict, Optional, Sequence, Any, Union, Iterable
import logging
import warnings

from qiskit.circuit import QuantumCircuit
from qiskit.primitives import BaseSampler
Expand All @@ -36,6 +35,7 @@
from .utils.deprecation import deprecate_arguments, issue_deprecation_msg
from .utils.qctrl import validate as qctrl_validate
from .utils.qctrl import validate_v2 as qctrl_validate_v2
from .utils.validations import validate_classical_registers
from .options import SamplerOptions

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -146,12 +146,7 @@ def run(self, pubs: Iterable[SamplerPubLike], *, shots: int | None = None) -> Ru
"""
coerced_pubs = [SamplerPub.coerce(pub, shots) for pub in pubs]

if any(len(pub.circuit.cregs) == 0 for pub in coerced_pubs):
warnings.warn(
"One of your circuits has no output classical registers and so the result "
"will be empty. Did you mean to add measurement instructions?",
UserWarning,
)
validate_classical_registers(coerced_pubs)

return self._run(coerced_pubs) # type: ignore[arg-type]

Expand Down
72 changes: 72 additions & 0 deletions qiskit_ibm_runtime/utils/validations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
# This code is part of Qiskit.
#
# (C) Copyright IBM 2024.
#
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
#
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.

"""Utilities for data validation."""
from typing import List
import warnings
import keyword
from qiskit.primitives.containers.sampler_pub import SamplerPub
from qiskit.primitives.containers.estimator_pub import EstimatorPub


def validate_classical_registers(pubs: List[SamplerPub]) -> None:
"""Validates the classical registers in the pub won't cause problems that can be caught client-side.

Args:
pubs: The list of pubs to validate

Raises:
ValueError: If any circuit has a size-0 creg.
ValueError: If any circuit has a creg whose name is not a valid identifier.
ValueError: If any circuit has a creg whose name is a Python keyword.
"""

for index, pub in enumerate(pubs):
if len(pub.circuit.cregs) == 0:
warnings.warn(
f"The {index}-th circuit has no output classical registers so the result "
"will be empty. Did you mean to add measurement instructions?",
UserWarning,
)

for reg in pub.circuit.cregs:
# size 0 classical register will crash the server-side sampler
if reg.size == 0:
raise ValueError(
f"Classical register {reg.name} is of size 0, which is not allowed"
)
if not reg.name.isidentifier():
raise ValueError(
f"Classical register names must be valid identifiers, but {reg.name} "
f"is not. Valid identifiers contain only alphanumeric letters "
f"(a-z and A-Z), decimal digits (0-9), or underscores (_)"
)
if keyword.iskeyword(reg.name):
raise ValueError(
f"Classical register names cannot be Python keywords, but {reg.name} "
f"is such a keyword. You can see the Python keyword list here: "
f"https://docs.python.org/3/reference/lexical_analysis.html#keywords"
)


def validate_estimator_pubs(pubs: List[EstimatorPub]) -> None:
"""Validates the estimator pubs won't cause problems that can be caught client-side.

Args:
pubs: The list of pubs to validate

Raises:
ValueError: If any observable array is of size 0
"""
for pub in pubs:
if pub.observables.shape == (0,):
raise ValueError("Empty observables array is not allowed")
kt474 marked this conversation as resolved.
Show resolved Hide resolved
9 changes: 9 additions & 0 deletions test/unit/test_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,3 +277,12 @@ def test_unsupported_dynamical_decoupling_with_dynamic_circuits(self):
"Dynamical decoupling currently cannot be used with dynamic circuits",
):
inst.run(in_pubs)

def test_estimator_validations(self):
"""Test exceptions when failing client-side validations."""
backend = get_mocked_backend()
inst = EstimatorV2(backend=backend)
circ = QuantumCircuit(2)
obs = []
with self.assertRaisesRegex(ValueError, "Empty observables array is not allowed"):
inst.run(pubs=[(circ, obs)])
27 changes: 26 additions & 1 deletion test/unit/test_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from ddt import data, ddt, named_data
import numpy as np

from qiskit import QuantumCircuit, transpile
from qiskit import QuantumCircuit, transpile, QuantumRegister, ClassicalRegister
from qiskit.primitives.containers.sampler_pub import SamplerPub
from qiskit.circuit.library import RealAmplitudes
from qiskit_ibm_runtime import Sampler, Session, SamplerV2, SamplerOptions, IBMInputValueError
Expand Down Expand Up @@ -152,6 +152,31 @@ def test_run_default_options(self):
f"{inputs} and {expected} not partially equal.",
)

def test_sampler_validations(self):
"""Test exceptions when failing client-side validations."""
with Session(
service=FakeRuntimeService(channel="ibm_quantum", token="abc"),
backend="common_backend",
) as session:
inst = SamplerV2(session=session)
circ = QuantumCircuit(QuantumRegister(2), ClassicalRegister(0))
with self.assertRaisesRegex(ValueError, "Classical register .* is of size 0"):
inst.run([(circ,)])

creg = ClassicalRegister(2, "not-an-identifier")
circ = QuantumCircuit(QuantumRegister(2), creg)
with self.assertRaisesRegex(
ValueError, "Classical register names must be valid identifiers"
):
inst.run([(circ,)])

creg = ClassicalRegister(2, "lambda")
circ = QuantumCircuit(QuantumRegister(2), creg)
with self.assertRaisesRegex(
ValueError, "Classical register names cannot be Python keywords"
):
inst.run([(circ,)])

def test_run_dynamic_circuit_with_fractional_opted(self):
"""Fractional opted backend cannot run dynamic circuits."""
model_backend = FakeFractionalBackend()
Expand Down