Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make sure backend.run and Sampler.run can run in a single session #1203

Merged
merged 14 commits into from
Nov 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion qiskit_ibm_runtime/base_primitive.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,17 @@
import copy
import logging
from dataclasses import asdict
import warnings

from qiskit.providers.options import Options as TerraOptions

from qiskit_ibm_provider.session import get_cm_session as get_cm_provider_session

from .options import Options
from .options.utils import set_default_error_levels
from .runtime_job import RuntimeJob
from .ibm_backend import IBMBackend
from .session import get_cm_session
from .utils.default_session import get_cm_session
from .constants import DEFAULT_DECODERS
from .qiskit_runtime_service import QiskitRuntimeService

Expand Down Expand Up @@ -118,6 +121,11 @@ def __init__(
raise ValueError(
"A backend or session must be specified when not using ibm_cloud channel."
)
# Check if initialized within a IBMBackend session. If so, issue a warning.
if get_cm_provider_session():
warnings.warn(
"A Backend.run() session is open but Primitives will not be run within this session"
)

def _run_primitive(self, primitive_inputs: Dict, user_kwargs: Dict) -> RuntimeJob:
"""Run the primitive.
Expand Down
6 changes: 6 additions & 0 deletions qiskit_ibm_runtime/ibm_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
from .utils.backend_converter import (
convert_to_target,
)
from .utils.default_session import get_cm_session as get_cm_primitive_session

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -751,6 +752,11 @@ def _runtime_run(
if self._service._channel == "ibm_quantum":
hgp_name = self._instance or self._service._get_hgp().name

# Check if initialized within a Primitive session. If so, issue a warning.
if get_cm_primitive_session():
warnings.warn(
"A Primitive session is open but Backend.run() jobs will not be run within this session"
)
if self._session:
if not self._session.active:
raise RuntimeError(f"The session {self._session.session_id} is closed.")
Expand Down
18 changes: 1 addition & 17 deletions qiskit_ibm_runtime/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from typing import Dict, Optional, Type, Union, Callable, Any
from types import TracebackType
from functools import wraps
from contextvars import ContextVar

from qiskit_ibm_provider.utils.converters import hms_to_seconds

Expand All @@ -24,6 +23,7 @@
from .runtime_program import ParameterNamespace
from .program.result_decoder import ResultDecoder
from .ibm_backend import IBMBackend
from .utils.default_session import set_cm_session


def _active_session(func): # type: ignore
Expand Down Expand Up @@ -314,19 +314,3 @@ def __exit__(
) -> None:
set_cm_session(None)
self.close()


# Default session
_DEFAULT_SESSION: ContextVar[Optional[Session]] = ContextVar("_DEFAULT_SESSION", default=None)
_IN_SESSION_CM: ContextVar[bool] = ContextVar("_IN_SESSION_CM", default=False)


def set_cm_session(session: Optional[Session]) -> None:
"""Set the context manager session."""
_DEFAULT_SESSION.set(session)
_IN_SESSION_CM.set(session is not None)


def get_cm_session() -> Session:
"""Return the context managed session."""
return _DEFAULT_SESSION.get()
34 changes: 34 additions & 0 deletions qiskit_ibm_runtime/utils/default_session.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# This code is part of Qiskit.
#
# (C) Copyright IBM 2023.
#
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
#
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.

"""Methods for checking if we are inside a Session context manager"""

from contextvars import ContextVar
from typing import Optional, TYPE_CHECKING

if TYPE_CHECKING:
from .session import Session

# Default session
_DEFAULT_SESSION: ContextVar[Optional["Session"]] = ContextVar("_DEFAULT_SESSION", default=None)
_IN_SESSION_CM: ContextVar[bool] = ContextVar("_IN_SESSION_CM", default=False)


def set_cm_session(session: Optional["Session"]) -> None:
"""Set the context manager session."""
_DEFAULT_SESSION.set(session)
_IN_SESSION_CM.set(session is not None)


def get_cm_session() -> "Session":
"""Return the context managed session."""
return _DEFAULT_SESSION.get()
21 changes: 21 additions & 0 deletions test/integration/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@

"""Integration tests for Session."""

import warnings

from qiskit.circuit.library import RealAmplitudes
from qiskit.quantum_info import SparsePauliOp
from qiskit.test.reference_circuits import ReferenceCircuits
Expand Down Expand Up @@ -127,6 +129,25 @@ def test_backend_run_with_session(self):
result.get_counts()["00"], result.get_counts()["11"], delta=shots / 10
)

def test_backend_and_primitive_in_session(self):
"""Test Sampler.run and backend.run in the same session."""
backend = self.service.get_backend("ibmq_qasm_simulator")
with Session(backend=backend) as session:
sampler = Sampler(session=session)
job1 = sampler.run(circuits=ReferenceCircuits.bell())
with warnings.catch_warnings(record=True):
job2 = backend.run(circuits=ReferenceCircuits.bell())
self.assertEqual(job1.session_id, job1.job_id())
self.assertIsNone(job2.session_id)
with backend.open_session() as session:
with warnings.catch_warnings(record=True):
sampler = Sampler(backend=backend)
job1 = backend.run(ReferenceCircuits.bell())
job2 = sampler.run(circuits=ReferenceCircuits.bell())
session_id = session.session_id
self.assertEqual(session_id, job1.job_id())
self.assertIsNone(job2.session_id)

def test_session_cancel(self):
"""Test closing a session"""
backend = self.service.backend("ibmq_qasm_simulator")
Expand Down
5 changes: 3 additions & 2 deletions test/unit/test_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
from unittest.mock import patch

from qiskit_ibm_runtime import Batch
import qiskit_ibm_runtime.session as session_pkg
from qiskit_ibm_runtime.utils.default_session import _DEFAULT_SESSION

from ..ibm_test_case import IBMTestCase


Expand All @@ -24,7 +25,7 @@ class TestBatch(IBMTestCase):

def tearDown(self) -> None:
super().tearDown()
session_pkg._DEFAULT_SESSION.set(None)
_DEFAULT_SESSION.set(None)

@patch("qiskit_ibm_runtime.session.QiskitRuntimeService", autospec=True)
def test_default_batch(self, mock_service):
Expand Down
4 changes: 2 additions & 2 deletions test/unit/test_ibm_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
Session,
)
from qiskit_ibm_runtime.ibm_backend import IBMBackend
import qiskit_ibm_runtime.session as session_pkg
from qiskit_ibm_runtime.utils.default_session import _DEFAULT_SESSION

from ..ibm_test_case import IBMTestCase
from ..utils import (
Expand Down Expand Up @@ -61,7 +61,7 @@ def setUpClass(cls):

def tearDown(self) -> None:
super().tearDown()
session_pkg._DEFAULT_SESSION.set(None)
_DEFAULT_SESSION.set(None)

def test_dict_options(self):
"""Test passing a dictionary as options."""
Expand Down
4 changes: 2 additions & 2 deletions test/unit/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from qiskit_ibm_runtime import Session
from qiskit_ibm_runtime.ibm_backend import IBMBackend
import qiskit_ibm_runtime.session as session_pkg
from qiskit_ibm_runtime.utils.default_session import _DEFAULT_SESSION
from .mock.fake_runtime_service import FakeRuntimeService
from ..ibm_test_case import IBMTestCase

Expand All @@ -26,7 +26,7 @@ class TestSession(IBMTestCase):

def tearDown(self) -> None:
super().tearDown()
session_pkg._DEFAULT_SESSION.set(None)
_DEFAULT_SESSION.set(None)

@patch("qiskit_ibm_runtime.session.QiskitRuntimeService", autospec=True)
def test_default_service(self, mock_service):
Expand Down
Loading