Skip to content

PYTHON-4542 Improved sessions API #2335

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 12 commits into
base: master
Choose a base branch
from
Open
12 changes: 12 additions & 0 deletions pymongo/asynchronous/client_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@
import time
import uuid
from collections.abc import Mapping as _Mapping
from contextvars import ContextVar
from typing import (
TYPE_CHECKING,
Any,
Expand Down Expand Up @@ -204,6 +205,7 @@ def __init__(
causal_consistency: Optional[bool] = None,
default_transaction_options: Optional[TransactionOptions] = None,
snapshot: Optional[bool] = False,
bind: Optional[bool] = False,
) -> None:
if snapshot:
if causal_consistency:
Expand All @@ -222,6 +224,7 @@ def __init__(
)
self._default_transaction_options = default_transaction_options
self._snapshot = snapshot
self._bind = bind
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We have to bind/unbind the session in ClientSession.__enter__/__exit__. That way the stack of sessions is managed correctly (ie we call _SESSION.reset(token)). Think about how nested cases will work:

session1 = client.start_session(bind=True)
with session1:
    session2 = client.start_session(bind=True)
    with session2:
        coll.find_one() # uses session2
    coll.find_one() # uses session1
coll.find_one() # uses implicit session


@property
def causal_consistency(self) -> bool:
Expand Down Expand Up @@ -514,6 +517,7 @@ def __init__(
# Is this an implicitly created session?
self._implicit = implicit
self._transaction = _Transaction(None, client)
self._token = None

async def end_session(self) -> None:
"""Finish this session. If a transaction has started, abort it.
Expand Down Expand Up @@ -545,9 +549,14 @@ def _check_ended(self) -> None:
raise InvalidOperation("Cannot use ended session")

async def __aenter__(self) -> AsyncClientSession:
if self._options._bind:
self._token = _SESSION.set(self)
return self

async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
if self._token:
_SESSION.reset(self._token)
self._token = None
await self._end_session(lock=True)

@property
Expand Down Expand Up @@ -1065,6 +1074,9 @@ def __copy__(self) -> NoReturn:
raise TypeError("A AsyncClientSession cannot be copied, create a new session instead")


_SESSION: ContextVar[Optional[AsyncClientSession]] = ContextVar("SESSION", default=None)


class _EmptyServerSession:
__slots__ = "dirty", "started_retryable_write"

Expand Down
7 changes: 7 additions & 0 deletions pymongo/asynchronous/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,9 +136,16 @@ def __init__(
self._killed = False
self._session: Optional[AsyncClientSession]

from .client_session import _SESSION

bound_session = _SESSION.get()

if session:
self._session = session
self._explicit_session = True
elif bound_session:
self._session = bound_session
self._explicit_session = True
else:
self._session = None
self._explicit_session = False
Expand Down
2 changes: 2 additions & 0 deletions pymongo/asynchronous/mongo_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1362,6 +1362,7 @@ def start_session(
causal_consistency: Optional[bool] = None,
default_transaction_options: Optional[client_session.TransactionOptions] = None,
snapshot: Optional[bool] = False,
bind: Optional[bool] = False,
) -> client_session.AsyncClientSession:
"""Start a logical session.

Expand All @@ -1384,6 +1385,7 @@ def start_session(
causal_consistency=causal_consistency,
default_transaction_options=default_transaction_options,
snapshot=snapshot,
bind=bind,
)

def _ensure_session(
Expand Down
12 changes: 12 additions & 0 deletions pymongo/synchronous/client_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@
import time
import uuid
from collections.abc import Mapping as _Mapping
from contextvars import ContextVar
from typing import (
TYPE_CHECKING,
Any,
Expand Down Expand Up @@ -203,6 +204,7 @@ def __init__(
causal_consistency: Optional[bool] = None,
default_transaction_options: Optional[TransactionOptions] = None,
snapshot: Optional[bool] = False,
bind: Optional[bool] = False,
) -> None:
if snapshot:
if causal_consistency:
Expand All @@ -221,6 +223,7 @@ def __init__(
)
self._default_transaction_options = default_transaction_options
self._snapshot = snapshot
self._bind = bind

@property
def causal_consistency(self) -> bool:
Expand Down Expand Up @@ -513,6 +516,7 @@ def __init__(
# Is this an implicitly created session?
self._implicit = implicit
self._transaction = _Transaction(None, client)
self._token = None

def end_session(self) -> None:
"""Finish this session. If a transaction has started, abort it.
Expand Down Expand Up @@ -544,9 +548,14 @@ def _check_ended(self) -> None:
raise InvalidOperation("Cannot use ended session")

def __enter__(self) -> ClientSession:
if self._options._bind:
self._token = _SESSION.set(self)
return self

def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
if self._token:
_SESSION.reset(self._token)
self._token = None
self._end_session(lock=True)

@property
Expand Down Expand Up @@ -1060,6 +1069,9 @@ def __copy__(self) -> NoReturn:
raise TypeError("A ClientSession cannot be copied, create a new session instead")


_SESSION: ContextVar[Optional[ClientSession]] = ContextVar("SESSION", default=None)


class _EmptyServerSession:
__slots__ = "dirty", "started_retryable_write"

Expand Down
7 changes: 7 additions & 0 deletions pymongo/synchronous/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,9 +136,16 @@ def __init__(
self._killed = False
self._session: Optional[ClientSession]

from .client_session import _SESSION

bound_session = _SESSION.get()

if session:
self._session = session
self._explicit_session = True
elif bound_session:
self._session = bound_session
self._explicit_session = True
else:
self._session = None
self._explicit_session = False
Expand Down
2 changes: 2 additions & 0 deletions pymongo/synchronous/mongo_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1360,6 +1360,7 @@ def start_session(
causal_consistency: Optional[bool] = None,
default_transaction_options: Optional[client_session.TransactionOptions] = None,
snapshot: Optional[bool] = False,
bind: Optional[bool] = False,
) -> client_session.ClientSession:
"""Start a logical session.

Expand All @@ -1382,6 +1383,7 @@ def start_session(
causal_consistency=causal_consistency,
default_transaction_options=default_transaction_options,
snapshot=snapshot,
bind=bind,
)

def _ensure_session(self, session: Optional[ClientSession] = None) -> Optional[ClientSession]:
Expand Down
17 changes: 17 additions & 0 deletions test/asynchronous/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,6 +394,23 @@ async def test_cursor_clone(self):
await cursor.close()
await clone.close()

async def test_bind_session(self):
coll = self.client.pymongo_test.collection

# Explicit session via context variable.
async with self.client.start_session(bind=True) as s:
cursor = coll.find()
self.assertTrue(cursor.session is s)

# Nested sessions.
session1 = self.client.start_session(bind=True)
async with session1:
session2 = self.client.start_session(bind=True)
async with session2:
await coll.find_one() # uses session2
await coll.find_one() # uses session1
await coll.find_one() # uses implicit session

async def test_cursor(self):
listener = self.listener
client = self.client
Expand Down
17 changes: 17 additions & 0 deletions test/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,6 +394,23 @@ def test_cursor_clone(self):
cursor.close()
clone.close()

def test_bind_session(self):
coll = self.client.pymongo_test.collection

# Explicit session via context variable.
with self.client.start_session(bind=True) as s:
cursor = coll.find()
self.assertTrue(cursor.session is s)

# Nested sessions.
session1 = self.client.start_session(bind=True)
with session1:
session2 = self.client.start_session(bind=True)
with session2:
coll.find_one() # uses session2
coll.find_one() # uses session1
coll.find_one() # uses implicit session

def test_cursor(self):
listener = self.listener
client = self.client
Expand Down
Loading