Skip to content

gh-76785: Minor Improvements to "interpreters" Module #116328

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions Lib/test/support/interpreters/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,14 @@ def __hash__(self):
def __del__(self):
self._decref()

# for pickling:
def __getnewargs__(self):
return (self._id,)

# for pickling:
def __getstate__(self):
return None

def _decref(self):
if not self._ownsref:
return
Expand Down
12 changes: 11 additions & 1 deletion Lib/test/support/interpreters/channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,16 @@ class _ChannelEnd:

_end = None

def __init__(self, cid):
def __new__(cls, cid):
self = super().__new__(cls)
if self._end == 'send':
cid = _channels._channel_id(cid, send=True, force=True)
elif self._end == 'recv':
cid = _channels._channel_id(cid, recv=True, force=True)
else:
raise NotImplementedError(self._end)
self._id = cid
return self

def __repr__(self):
return f'{type(self).__name__}(id={int(self._id)})'
Expand All @@ -61,6 +63,14 @@ def __eq__(self, other):
return NotImplemented
return other._id == self._id

# for pickling:
def __getnewargs__(self):
return (int(self._id),)

# for pickling:
def __getstate__(self):
return None

@property
def id(self):
return self._id
Expand Down
31 changes: 16 additions & 15 deletions Lib/test/support/interpreters/queues.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,14 @@
]


class QueueEmpty(_queues.QueueEmpty, queue.Empty):
class QueueEmpty(QueueError, queue.Empty):
"""Raised from get_nowait() when the queue is empty.

It is also raised from get() if it times out.
"""


class QueueFull(_queues.QueueFull, queue.Full):
class QueueFull(QueueError, queue.Full):
"""Raised from put_nowait() when the queue is full.

It is also raised from put() if it times out.
Expand Down Expand Up @@ -66,7 +66,7 @@ def __new__(cls, id, /, *, _fmt=None):
else:
raise TypeError(f'id must be an int, got {id!r}')
if _fmt is None:
_fmt = _queues.get_default_fmt(id)
_fmt, = _queues.get_queue_defaults(id)
try:
self = _known_queues[id]
except KeyError:
Expand All @@ -93,6 +93,14 @@ def __repr__(self):
def __hash__(self):
return hash(self._id)

# for pickling:
def __getnewargs__(self):
return (self._id,)

# for pickling:
def __getstate__(self):
return None

@property
def id(self):
return self._id
Expand Down Expand Up @@ -159,9 +167,8 @@ def put(self, obj, timeout=None, *,
while True:
try:
_queues.put(self._id, obj, fmt)
except _queues.QueueFull as exc:
except QueueFull as exc:
if timeout is not None and time.time() >= end:
exc.__class__ = QueueFull
raise # re-raise
time.sleep(_delay)
else:
Expand All @@ -174,11 +181,7 @@ def put_nowait(self, obj, *, syncobj=None):
fmt = _SHARED_ONLY if syncobj else _PICKLED
if fmt is _PICKLED:
obj = pickle.dumps(obj)
try:
_queues.put(self._id, obj, fmt)
except _queues.QueueFull as exc:
exc.__class__ = QueueFull
raise # re-raise
_queues.put(self._id, obj, fmt)

def get(self, timeout=None, *,
_delay=10 / 1000, # 10 milliseconds
Expand All @@ -195,9 +198,8 @@ def get(self, timeout=None, *,
while True:
try:
obj, fmt = _queues.get(self._id)
except _queues.QueueEmpty as exc:
except QueueEmpty as exc:
if timeout is not None and time.time() >= end:
exc.__class__ = QueueEmpty
raise # re-raise
time.sleep(_delay)
else:
Expand All @@ -216,8 +218,7 @@ def get_nowait(self):
"""
try:
obj, fmt = _queues.get(self._id)
except _queues.QueueEmpty as exc:
exc.__class__ = QueueEmpty
except QueueEmpty as exc:
raise # re-raise
if fmt == _PICKLED:
obj = pickle.loads(obj)
Expand All @@ -226,4 +227,4 @@ def get_nowait(self):
return obj


_queues._register_queue_type(Queue)
_queues._register_heap_types(Queue, QueueEmpty, QueueFull)
7 changes: 7 additions & 0 deletions Lib/test/test_interpreters/test_api.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import pickle
import threading
from textwrap import dedent
import unittest
Expand Down Expand Up @@ -261,6 +262,12 @@ def test_equality(self):
self.assertEqual(interp1, interp1)
self.assertNotEqual(interp1, interp2)

def test_pickle(self):
interp = interpreters.create()
data = pickle.dumps(interp)
unpickled = pickle.loads(data)
self.assertEqual(unpickled, interp)


class TestInterpreterIsRunning(TestBase):

Expand Down
13 changes: 13 additions & 0 deletions Lib/test/test_interpreters/test_channels.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import importlib
import pickle
import threading
from textwrap import dedent
import unittest
Expand Down Expand Up @@ -100,6 +101,12 @@ def test_equality(self):
self.assertEqual(ch1, ch1)
self.assertNotEqual(ch1, ch2)

def test_pickle(self):
ch, _ = channels.create()
data = pickle.dumps(ch)
unpickled = pickle.loads(data)
self.assertEqual(unpickled, ch)


class TestSendChannelAttrs(TestBase):

Expand All @@ -125,6 +132,12 @@ def test_equality(self):
self.assertEqual(ch1, ch1)
self.assertNotEqual(ch1, ch2)

def test_pickle(self):
_, ch = channels.create()
data = pickle.dumps(ch)
unpickled = pickle.loads(data)
self.assertEqual(unpickled, ch)


class TestSendRecv(TestBase):

Expand Down
71 changes: 67 additions & 4 deletions Lib/test/test_interpreters/test_queues.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,25 @@
import importlib
import pickle
import threading
from textwrap import dedent
import unittest
import time

from test.support import import_helper
from test.support import import_helper, Py_DEBUG
# Raise SkipTest if subinterpreters not supported.
_queues = import_helper.import_module('_xxinterpqueues')
from test.support import interpreters
from test.support.interpreters import queues
from .utils import _run_output, TestBase
from .utils import _run_output, TestBase as _TestBase


class TestBase(TestBase):
def get_num_queues():
return len(_queues.list_all())


class TestBase(_TestBase):
def tearDown(self):
for qid in _queues.list_all():
for qid, _ in _queues.list_all():
try:
_queues.destroy(qid)
except Exception:
Expand All @@ -34,6 +39,58 @@ def test_highlevel_reloaded(self):
# See gh-115490 (https://github.com/python/cpython/issues/115490).
importlib.reload(queues)

def test_create_destroy(self):
qid = _queues.create(2, 0)
_queues.destroy(qid)
self.assertEqual(get_num_queues(), 0)
with self.assertRaises(queues.QueueNotFoundError):
_queues.get(qid)
with self.assertRaises(queues.QueueNotFoundError):
_queues.destroy(qid)

def test_not_destroyed(self):
# It should have cleaned up any remaining queues.
stdout, stderr = self.assert_python_ok(
'-c',
dedent(f"""
import {_queues.__name__} as _queues
_queues.create(2, 0)
"""),
)
self.assertEqual(stdout, '')
if Py_DEBUG:
self.assertNotEqual(stderr, '')
else:
self.assertEqual(stderr, '')

def test_bind_release(self):
with self.subTest('typical'):
qid = _queues.create(2, 0)
_queues.bind(qid)
_queues.release(qid)
self.assertEqual(get_num_queues(), 0)

with self.subTest('bind too much'):
qid = _queues.create(2, 0)
_queues.bind(qid)
_queues.bind(qid)
_queues.release(qid)
_queues.destroy(qid)
self.assertEqual(get_num_queues(), 0)

with self.subTest('nested'):
qid = _queues.create(2, 0)
_queues.bind(qid)
_queues.bind(qid)
_queues.release(qid)
_queues.release(qid)
self.assertEqual(get_num_queues(), 0)

with self.subTest('release without binding'):
qid = _queues.create(2, 0)
with self.assertRaises(queues.QueueError):
_queues.release(qid)


class QueueTests(TestBase):

Expand Down Expand Up @@ -127,6 +184,12 @@ def test_equality(self):
self.assertEqual(queue1, queue1)
self.assertNotEqual(queue1, queue2)

def test_pickle(self):
queue = queues.create()
data = pickle.dumps(queue)
unpickled = pickle.loads(data)
self.assertEqual(unpickled, queue)


class TestQueueOps(TestBase):

Expand Down
8 changes: 8 additions & 0 deletions Modules/_interpreters_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,11 @@ ensure_xid_class(PyTypeObject *cls, crossinterpdatafunc getdata)
//assert(cls->tp_flags & Py_TPFLAGS_HEAPTYPE);
return _PyCrossInterpreterData_RegisterClass(cls, getdata);
}

#ifdef REGISTERS_HEAP_TYPES
static int
clear_xid_class(PyTypeObject *cls)
{
return _PyCrossInterpreterData_UnregisterClass(cls);
}
#endif
14 changes: 8 additions & 6 deletions Modules/_xxinterpchannelsmodule.c
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@
#include <sched.h> // sched_yield()
#endif

#define REGISTERS_HEAP_TYPES
#include "_interpreters_common.h"
#undef REGISTERS_HEAP_TYPES


/*
Expand Down Expand Up @@ -281,17 +283,17 @@ clear_xid_types(module_state *state)
{
/* external types */
if (state->send_channel_type != NULL) {
(void)_PyCrossInterpreterData_UnregisterClass(state->send_channel_type);
(void)clear_xid_class(state->send_channel_type);
Py_CLEAR(state->send_channel_type);
}
if (state->recv_channel_type != NULL) {
(void)_PyCrossInterpreterData_UnregisterClass(state->recv_channel_type);
(void)clear_xid_class(state->recv_channel_type);
Py_CLEAR(state->recv_channel_type);
}

/* heap types */
if (state->ChannelIDType != NULL) {
(void)_PyCrossInterpreterData_UnregisterClass(state->ChannelIDType);
(void)clear_xid_class(state->ChannelIDType);
Py_CLEAR(state->ChannelIDType);
}
}
Expand Down Expand Up @@ -2677,11 +2679,11 @@ set_channelend_types(PyObject *mod, PyTypeObject *send, PyTypeObject *recv)

// Clear the old values if the .py module was reloaded.
if (state->send_channel_type != NULL) {
(void)_PyCrossInterpreterData_UnregisterClass(state->send_channel_type);
(void)clear_xid_class(state->send_channel_type);
Py_CLEAR(state->send_channel_type);
}
if (state->recv_channel_type != NULL) {
(void)_PyCrossInterpreterData_UnregisterClass(state->recv_channel_type);
(void)clear_xid_class(state->recv_channel_type);
Py_CLEAR(state->recv_channel_type);
}

Expand All @@ -2694,7 +2696,7 @@ set_channelend_types(PyObject *mod, PyTypeObject *send, PyTypeObject *recv)
return -1;
}
if (ensure_xid_class(recv, _channelend_shared) < 0) {
(void)_PyCrossInterpreterData_UnregisterClass(state->send_channel_type);
(void)clear_xid_class(state->send_channel_type);
Py_CLEAR(state->send_channel_type);
Py_CLEAR(state->recv_channel_type);
return -1;
Expand Down
Loading