diff --git a/qiskit_ibm_runtime/base_primitive.py b/qiskit_ibm_runtime/base_primitive.py index 4dd596e2a..13c759584 100644 --- a/qiskit_ibm_runtime/base_primitive.py +++ b/qiskit_ibm_runtime/base_primitive.py @@ -18,14 +18,17 @@ import copy import logging from dataclasses import asdict +import warnings from qiskit.providers.options import Options as TerraOptions +from qiskit_ibm_provider.session import get_cm_session as get_cm_provider_session + from .options import Options from .options.utils import set_default_error_levels from .runtime_job import RuntimeJob from .ibm_backend import IBMBackend -from .session import get_cm_session +from .utils.default_session import get_cm_session from .constants import DEFAULT_DECODERS from .qiskit_runtime_service import QiskitRuntimeService @@ -118,6 +121,11 @@ def __init__( raise ValueError( "A backend or session must be specified when not using ibm_cloud channel." ) + # Check if initialized within a IBMBackend session. If so, issue a warning. + if get_cm_provider_session(): + warnings.warn( + "A Backend.run() session is open but Primitives will not be run within this session" + ) def _run_primitive(self, primitive_inputs: Dict, user_kwargs: Dict) -> RuntimeJob: """Run the primitive. diff --git a/qiskit_ibm_runtime/ibm_backend.py b/qiskit_ibm_runtime/ibm_backend.py index b90f5789b..1fb60dbd6 100644 --- a/qiskit_ibm_runtime/ibm_backend.py +++ b/qiskit_ibm_runtime/ibm_backend.py @@ -65,6 +65,7 @@ from .utils.backend_converter import ( convert_to_target, ) +from .utils.default_session import get_cm_session as get_cm_primitive_session logger = logging.getLogger(__name__) @@ -751,6 +752,11 @@ def _runtime_run( if self._service._channel == "ibm_quantum": hgp_name = self._instance or self._service._get_hgp().name + # Check if initialized within a Primitive session. If so, issue a warning. + if get_cm_primitive_session(): + warnings.warn( + "A Primitive session is open but Backend.run() jobs will not be run within this session" + ) if self._session: if not self._session.active: raise RuntimeError(f"The session {self._session.session_id} is closed.") diff --git a/qiskit_ibm_runtime/session.py b/qiskit_ibm_runtime/session.py index 36f7a0f62..b5f322639 100644 --- a/qiskit_ibm_runtime/session.py +++ b/qiskit_ibm_runtime/session.py @@ -15,7 +15,6 @@ from typing import Dict, Optional, Type, Union, Callable, Any from types import TracebackType from functools import wraps -from contextvars import ContextVar from qiskit_ibm_provider.utils.converters import hms_to_seconds @@ -24,6 +23,7 @@ from .runtime_program import ParameterNamespace from .program.result_decoder import ResultDecoder from .ibm_backend import IBMBackend +from .utils.default_session import set_cm_session def _active_session(func): # type: ignore @@ -314,19 +314,3 @@ def __exit__( ) -> None: set_cm_session(None) self.close() - - -# Default session -_DEFAULT_SESSION: ContextVar[Optional[Session]] = ContextVar("_DEFAULT_SESSION", default=None) -_IN_SESSION_CM: ContextVar[bool] = ContextVar("_IN_SESSION_CM", default=False) - - -def set_cm_session(session: Optional[Session]) -> None: - """Set the context manager session.""" - _DEFAULT_SESSION.set(session) - _IN_SESSION_CM.set(session is not None) - - -def get_cm_session() -> Session: - """Return the context managed session.""" - return _DEFAULT_SESSION.get() diff --git a/qiskit_ibm_runtime/utils/default_session.py b/qiskit_ibm_runtime/utils/default_session.py new file mode 100644 index 000000000..5049a65fb --- /dev/null +++ b/qiskit_ibm_runtime/utils/default_session.py @@ -0,0 +1,34 @@ +# This code is part of Qiskit. +# +# (C) Copyright IBM 2023. +# +# This code is licensed under the Apache License, Version 2.0. You may +# obtain a copy of this license in the LICENSE.txt file in the root directory +# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0. +# +# Any modifications or derivative works of this code must retain this +# copyright notice, and modified files need to carry a notice indicating +# that they have been altered from the originals. + +"""Methods for checking if we are inside a Session context manager""" + +from contextvars import ContextVar +from typing import Optional, TYPE_CHECKING + +if TYPE_CHECKING: + from .session import Session + +# Default session +_DEFAULT_SESSION: ContextVar[Optional["Session"]] = ContextVar("_DEFAULT_SESSION", default=None) +_IN_SESSION_CM: ContextVar[bool] = ContextVar("_IN_SESSION_CM", default=False) + + +def set_cm_session(session: Optional["Session"]) -> None: + """Set the context manager session.""" + _DEFAULT_SESSION.set(session) + _IN_SESSION_CM.set(session is not None) + + +def get_cm_session() -> "Session": + """Return the context managed session.""" + return _DEFAULT_SESSION.get() diff --git a/test/integration/test_session.py b/test/integration/test_session.py index 28238a6a6..b8c086a59 100644 --- a/test/integration/test_session.py +++ b/test/integration/test_session.py @@ -12,6 +12,8 @@ """Integration tests for Session.""" +import warnings + from qiskit.circuit.library import RealAmplitudes from qiskit.quantum_info import SparsePauliOp from qiskit.test.reference_circuits import ReferenceCircuits @@ -127,6 +129,25 @@ def test_backend_run_with_session(self): result.get_counts()["00"], result.get_counts()["11"], delta=shots / 10 ) + def test_backend_and_primitive_in_session(self): + """Test Sampler.run and backend.run in the same session.""" + backend = self.service.get_backend("ibmq_qasm_simulator") + with Session(backend=backend) as session: + sampler = Sampler(session=session) + job1 = sampler.run(circuits=ReferenceCircuits.bell()) + with warnings.catch_warnings(record=True): + job2 = backend.run(circuits=ReferenceCircuits.bell()) + self.assertEqual(job1.session_id, job1.job_id()) + self.assertIsNone(job2.session_id) + with backend.open_session() as session: + with warnings.catch_warnings(record=True): + sampler = Sampler(backend=backend) + job1 = backend.run(ReferenceCircuits.bell()) + job2 = sampler.run(circuits=ReferenceCircuits.bell()) + session_id = session.session_id + self.assertEqual(session_id, job1.job_id()) + self.assertIsNone(job2.session_id) + def test_session_cancel(self): """Test closing a session""" backend = self.service.backend("ibmq_qasm_simulator") diff --git a/test/unit/test_batch.py b/test/unit/test_batch.py index a27eca479..3b54b1094 100644 --- a/test/unit/test_batch.py +++ b/test/unit/test_batch.py @@ -15,7 +15,8 @@ from unittest.mock import patch from qiskit_ibm_runtime import Batch -import qiskit_ibm_runtime.session as session_pkg +from qiskit_ibm_runtime.utils.default_session import _DEFAULT_SESSION + from ..ibm_test_case import IBMTestCase @@ -24,7 +25,7 @@ class TestBatch(IBMTestCase): def tearDown(self) -> None: super().tearDown() - session_pkg._DEFAULT_SESSION.set(None) + _DEFAULT_SESSION.set(None) @patch("qiskit_ibm_runtime.session.QiskitRuntimeService", autospec=True) def test_default_batch(self, mock_service): diff --git a/test/unit/test_ibm_primitives.py b/test/unit/test_ibm_primitives.py index 9314e2a15..5e1e643c3 100644 --- a/test/unit/test_ibm_primitives.py +++ b/test/unit/test_ibm_primitives.py @@ -32,7 +32,7 @@ Session, ) from qiskit_ibm_runtime.ibm_backend import IBMBackend -import qiskit_ibm_runtime.session as session_pkg +from qiskit_ibm_runtime.utils.default_session import _DEFAULT_SESSION from ..ibm_test_case import IBMTestCase from ..utils import ( @@ -61,7 +61,7 @@ def setUpClass(cls): def tearDown(self) -> None: super().tearDown() - session_pkg._DEFAULT_SESSION.set(None) + _DEFAULT_SESSION.set(None) def test_dict_options(self): """Test passing a dictionary as options.""" diff --git a/test/unit/test_session.py b/test/unit/test_session.py index 9ac349c3f..f796f4b40 100644 --- a/test/unit/test_session.py +++ b/test/unit/test_session.py @@ -16,7 +16,7 @@ from qiskit_ibm_runtime import Session from qiskit_ibm_runtime.ibm_backend import IBMBackend -import qiskit_ibm_runtime.session as session_pkg +from qiskit_ibm_runtime.utils.default_session import _DEFAULT_SESSION from .mock.fake_runtime_service import FakeRuntimeService from ..ibm_test_case import IBMTestCase @@ -26,7 +26,7 @@ class TestSession(IBMTestCase): def tearDown(self) -> None: super().tearDown() - session_pkg._DEFAULT_SESSION.set(None) + _DEFAULT_SESSION.set(None) @patch("qiskit_ibm_runtime.session.QiskitRuntimeService", autospec=True) def test_default_service(self, mock_service):