Skip to content

Commit

Permalink
Make sure backend.run and Sampler.run can run in a single session (#1203
Browse files Browse the repository at this point in the history
)

* Added test for backend.run and Sampler.run in a single session

* Issue warning when Primitive is run within a backend session

* Fixed comment

* black

* Added warning for when a backend is run in a primitive session and vice versa.

* Fixed imports from default_session

* minor wording change

---------

Co-authored-by: Kevin Tian <kevin.tian@ibm.com>
  • Loading branch information
merav-aharoni and kt474 authored Nov 14, 2023
1 parent 030c54e commit 6cc8197
Show file tree
Hide file tree
Showing 8 changed files with 78 additions and 24 deletions.
10 changes: 9 additions & 1 deletion qiskit_ibm_runtime/base_primitive.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,17 @@
import copy
import logging
from dataclasses import asdict
import warnings

from qiskit.providers.options import Options as TerraOptions

from qiskit_ibm_provider.session import get_cm_session as get_cm_provider_session

from .options import Options
from .options.utils import set_default_error_levels
from .runtime_job import RuntimeJob
from .ibm_backend import IBMBackend
from .session import get_cm_session
from .utils.default_session import get_cm_session
from .constants import DEFAULT_DECODERS
from .qiskit_runtime_service import QiskitRuntimeService

Expand Down Expand Up @@ -118,6 +121,11 @@ def __init__(
raise ValueError(
"A backend or session must be specified when not using ibm_cloud channel."
)
# Check if initialized within a IBMBackend session. If so, issue a warning.
if get_cm_provider_session():
warnings.warn(
"A Backend.run() session is open but Primitives will not be run within this session"
)

def _run_primitive(self, primitive_inputs: Dict, user_kwargs: Dict) -> RuntimeJob:
"""Run the primitive.
Expand Down
6 changes: 6 additions & 0 deletions qiskit_ibm_runtime/ibm_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
from .utils.backend_converter import (
convert_to_target,
)
from .utils.default_session import get_cm_session as get_cm_primitive_session

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -751,6 +752,11 @@ def _runtime_run(
if self._service._channel == "ibm_quantum":
hgp_name = self._instance or self._service._get_hgp().name

# Check if initialized within a Primitive session. If so, issue a warning.
if get_cm_primitive_session():
warnings.warn(
"A Primitive session is open but Backend.run() jobs will not be run within this session"
)
if self._session:
if not self._session.active:
raise RuntimeError(f"The session {self._session.session_id} is closed.")
Expand Down
18 changes: 1 addition & 17 deletions qiskit_ibm_runtime/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from typing import Dict, Optional, Type, Union, Callable, Any
from types import TracebackType
from functools import wraps
from contextvars import ContextVar

from qiskit_ibm_provider.utils.converters import hms_to_seconds

Expand All @@ -24,6 +23,7 @@
from .runtime_program import ParameterNamespace
from .program.result_decoder import ResultDecoder
from .ibm_backend import IBMBackend
from .utils.default_session import set_cm_session


def _active_session(func): # type: ignore
Expand Down Expand Up @@ -314,19 +314,3 @@ def __exit__(
) -> None:
set_cm_session(None)
self.close()


# Default session
_DEFAULT_SESSION: ContextVar[Optional[Session]] = ContextVar("_DEFAULT_SESSION", default=None)
_IN_SESSION_CM: ContextVar[bool] = ContextVar("_IN_SESSION_CM", default=False)


def set_cm_session(session: Optional[Session]) -> None:
"""Set the context manager session."""
_DEFAULT_SESSION.set(session)
_IN_SESSION_CM.set(session is not None)


def get_cm_session() -> Session:
"""Return the context managed session."""
return _DEFAULT_SESSION.get()
34 changes: 34 additions & 0 deletions qiskit_ibm_runtime/utils/default_session.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# This code is part of Qiskit.
#
# (C) Copyright IBM 2023.
#
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
#
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.

"""Methods for checking if we are inside a Session context manager"""

from contextvars import ContextVar
from typing import Optional, TYPE_CHECKING

if TYPE_CHECKING:
from .session import Session

# Default session
_DEFAULT_SESSION: ContextVar[Optional["Session"]] = ContextVar("_DEFAULT_SESSION", default=None)
_IN_SESSION_CM: ContextVar[bool] = ContextVar("_IN_SESSION_CM", default=False)


def set_cm_session(session: Optional["Session"]) -> None:
"""Set the context manager session."""
_DEFAULT_SESSION.set(session)
_IN_SESSION_CM.set(session is not None)


def get_cm_session() -> "Session":
"""Return the context managed session."""
return _DEFAULT_SESSION.get()
21 changes: 21 additions & 0 deletions test/integration/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@

"""Integration tests for Session."""

import warnings

from qiskit.circuit.library import RealAmplitudes
from qiskit.quantum_info import SparsePauliOp
from qiskit.test.reference_circuits import ReferenceCircuits
Expand Down Expand Up @@ -127,6 +129,25 @@ def test_backend_run_with_session(self):
result.get_counts()["00"], result.get_counts()["11"], delta=shots / 10
)

def test_backend_and_primitive_in_session(self):
"""Test Sampler.run and backend.run in the same session."""
backend = self.service.get_backend("ibmq_qasm_simulator")
with Session(backend=backend) as session:
sampler = Sampler(session=session)
job1 = sampler.run(circuits=ReferenceCircuits.bell())
with warnings.catch_warnings(record=True):
job2 = backend.run(circuits=ReferenceCircuits.bell())
self.assertEqual(job1.session_id, job1.job_id())
self.assertIsNone(job2.session_id)
with backend.open_session() as session:
with warnings.catch_warnings(record=True):
sampler = Sampler(backend=backend)
job1 = backend.run(ReferenceCircuits.bell())
job2 = sampler.run(circuits=ReferenceCircuits.bell())
session_id = session.session_id
self.assertEqual(session_id, job1.job_id())
self.assertIsNone(job2.session_id)

def test_session_cancel(self):
"""Test closing a session"""
backend = self.service.backend("ibmq_qasm_simulator")
Expand Down
5 changes: 3 additions & 2 deletions test/unit/test_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
from unittest.mock import patch

from qiskit_ibm_runtime import Batch
import qiskit_ibm_runtime.session as session_pkg
from qiskit_ibm_runtime.utils.default_session import _DEFAULT_SESSION

from ..ibm_test_case import IBMTestCase


Expand All @@ -24,7 +25,7 @@ class TestBatch(IBMTestCase):

def tearDown(self) -> None:
super().tearDown()
session_pkg._DEFAULT_SESSION.set(None)
_DEFAULT_SESSION.set(None)

@patch("qiskit_ibm_runtime.session.QiskitRuntimeService", autospec=True)
def test_default_batch(self, mock_service):
Expand Down
4 changes: 2 additions & 2 deletions test/unit/test_ibm_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
Session,
)
from qiskit_ibm_runtime.ibm_backend import IBMBackend
import qiskit_ibm_runtime.session as session_pkg
from qiskit_ibm_runtime.utils.default_session import _DEFAULT_SESSION

from ..ibm_test_case import IBMTestCase
from ..utils import (
Expand Down Expand Up @@ -61,7 +61,7 @@ def setUpClass(cls):

def tearDown(self) -> None:
super().tearDown()
session_pkg._DEFAULT_SESSION.set(None)
_DEFAULT_SESSION.set(None)

def test_dict_options(self):
"""Test passing a dictionary as options."""
Expand Down
4 changes: 2 additions & 2 deletions test/unit/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from qiskit_ibm_runtime import Session
from qiskit_ibm_runtime.ibm_backend import IBMBackend
import qiskit_ibm_runtime.session as session_pkg
from qiskit_ibm_runtime.utils.default_session import _DEFAULT_SESSION
from .mock.fake_runtime_service import FakeRuntimeService
from ..ibm_test_case import IBMTestCase

Expand All @@ -26,7 +26,7 @@ class TestSession(IBMTestCase):

def tearDown(self) -> None:
super().tearDown()
session_pkg._DEFAULT_SESSION.set(None)
_DEFAULT_SESSION.set(None)

@patch("qiskit_ibm_runtime.session.QiskitRuntimeService", autospec=True)
def test_default_service(self, mock_service):
Expand Down

0 comments on commit 6cc8197

Please sign in to comment.