GH-37254: [Python] Parametrize all pickling tests to use both the pickle and cloudpickle modules (apache#37255)

### Rationale for this change

Cloudpickle was not tested in most parts of the pyarrow test suite. Improving this coverage will make the Cython 3.0.0 upgrade cleaner, since cloudpickle was failing in a few places where the default pickle module was not. This has been verified using Cython 0.29.36.

### What changes are included in this PR?

* The `_reconstruct` helpers that `__reduce__` uses to pass kwargs have been changed from classmethods to staticmethods (decorated with `@binding(True)` for Cython < 3); see the sketch below
* All pytests that pickle objects are parametrized to run with both the `pickle` and `cloudpickle` modules
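
For illustration, a minimal pure-Python sketch of the reconstructor pattern the diffs below adopt. The `Options` class and its fields are hypothetical; in the Cython sources the staticmethod also carries `@binding(True)` so it stays picklable on Cython < 3.

```python
import pickle


class Options:
    """Hypothetical stand-in for classes such as ParquetFragmentScanOptions."""

    def __init__(self, *, limit=None, enabled=True):
        self.limit = limit
        self.enabled = enabled

    @staticmethod
    def _reconstruct(kwargs):
        # __reduce__ can't pass keyword arguments directly to the
        # reconstructor, hence this small wrapper.
        return Options(**kwargs)

    def __reduce__(self):
        kwargs = dict(limit=self.limit, enabled=self.enabled)
        return Options._reconstruct, (kwargs,)


# Round-tripping through pickle (or cloudpickle) preserves the kwargs.
opts = pickle.loads(pickle.dumps(Options(limit=10, enabled=False)))
assert opts.limit == 10 and not opts.enabled
```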

### Are these changes tested?

Yes, pytests run successfully with Cython 0.29.36

### Are there any user-facing changes?

No.
* Closes: apache#37254

Authored-by: Dane Pitkin <dane@voltrondata.com>
Signed-off-by: Sutou Kouhei <kou@clear-code.com>
danepitkin authored Aug 24, 2023
1 parent 0be17e6 commit 175b2a2
Showing 21 changed files with 372 additions and 308 deletions.
12 changes: 8 additions & 4 deletions python/pyarrow/_dataset_parquet.pyx
@@ -19,6 +19,7 @@

"""Dataset support for Parquest file format."""

from cython cimport binding
from cython.operator cimport dereference as deref

import os
@@ -770,9 +771,12 @@ cdef class ParquetFragmentScanOptions(FragmentScanOptions):
other.thrift_container_size_limit)
return attrs == other_attrs

@classmethod
def _reconstruct(cls, kwargs):
return cls(**kwargs)
@staticmethod
@binding(True) # Required for Cython < 3
def _reconstruct(kwargs):
# __reduce__ doesn't allow passing named arguments directly to the
# reconstructor, hence this wrapper.
return ParquetFragmentScanOptions(**kwargs)

def __reduce__(self):
kwargs = dict(
@@ -782,7 +786,7 @@ cdef class ParquetFragmentScanOptions(FragmentScanOptions):
thrift_string_size_limit=self.thrift_string_size_limit,
thrift_container_size_limit=self.thrift_container_size_limit,
)
return type(self)._reconstruct, (kwargs,)
return ParquetFragmentScanOptions._reconstruct, (kwargs,)


cdef class ParquetFactoryOptions(_Weakrefable):
8 changes: 5 additions & 3 deletions python/pyarrow/_fs.pyx
@@ -18,6 +18,7 @@
# cython: language_level = 3

from cpython.datetime cimport datetime, PyDateTime_DateTime
from cython cimport binding

from pyarrow.includes.common cimport *
from pyarrow.includes.libarrow_python cimport PyDateTime_to_TimePoint
@@ -1106,11 +1107,12 @@ cdef class LocalFileSystem(FileSystem):
FileSystem.init(self, c_fs)
self.localfs = <CLocalFileSystem*> c_fs.get()

@classmethod
def _reconstruct(cls, kwargs):
@staticmethod
@binding(True) # Required for cython < 3
def _reconstruct(kwargs):
# __reduce__ doesn't allow passing named arguments directly to the
# reconstructor, hence this wrapper.
return cls(**kwargs)
return LocalFileSystem(**kwargs)

def __reduce__(self):
cdef CLocalFileSystemOptions opts = self.localfs.options()
13 changes: 9 additions & 4 deletions python/pyarrow/_gcsfs.pyx
@@ -17,6 +17,8 @@

# cython: language_level = 3

from cython cimport binding

from pyarrow.lib cimport (pyarrow_wrap_metadata,
pyarrow_unwrap_metadata)
from pyarrow.lib import frombytes, tobytes, ensure_metadata
@@ -154,17 +156,20 @@ cdef class GcsFileSystem(FileSystem):
FileSystem.init(self, wrapped)
self.gcsfs = <CGcsFileSystem*> wrapped.get()

@classmethod
def _reconstruct(cls, kwargs):
return cls(**kwargs)

def _expiration_datetime_from_options(self):
expiration_ns = TimePoint_to_ns(
self.gcsfs.options().credentials.expiration())
if expiration_ns == 0:
return None
return datetime.fromtimestamp(expiration_ns / 1.0e9, timezone.utc)

@staticmethod
@binding(True) # Required for cython < 3
def _reconstruct(kwargs):
# __reduce__ doesn't allow passing named arguments directly to the
# reconstructor, hence this wrapper.
return GcsFileSystem(**kwargs)

def __reduce__(self):
cdef CGcsOptions opts = self.gcsfs.options()
service_account = frombytes(opts.credentials.target_service_account())
11 changes: 8 additions & 3 deletions python/pyarrow/_hdfs.pyx
@@ -17,6 +17,8 @@

# cython: language_level = 3

from cython cimport binding

from pyarrow.includes.common cimport *
from pyarrow.includes.libarrow cimport *
from pyarrow.includes.libarrow_fs cimport *
@@ -134,9 +136,12 @@ replication=1)``
self.init(<shared_ptr[CFileSystem]> wrapped)
return self

@classmethod
def _reconstruct(cls, kwargs):
return cls(**kwargs)
@staticmethod
@binding(True) # Required for cython < 3
def _reconstruct(kwargs):
# __reduce__ doesn't allow passing named arguments directly to the
# reconstructor, hence this wrapper.
return HadoopFileSystem(**kwargs)

def __reduce__(self):
cdef CHdfsOptions opts = self.hdfs.options()
11 changes: 8 additions & 3 deletions python/pyarrow/_s3fs.pyx
@@ -17,6 +17,8 @@

# cython: language_level = 3

from cython cimport binding

from pyarrow.lib cimport (check_status, pyarrow_wrap_metadata,
pyarrow_unwrap_metadata)
from pyarrow.lib import frombytes, tobytes, KeyValueMetadata
@@ -388,9 +390,12 @@ cdef class S3FileSystem(FileSystem):
FileSystem.init(self, wrapped)
self.s3fs = <CS3FileSystem*> wrapped.get()

@classmethod
def _reconstruct(cls, kwargs):
return cls(**kwargs)
@staticmethod
@binding(True) # Required for cython < 3
def _reconstruct(kwargs):
# __reduce__ doesn't allow passing named arguments directly to the
# reconstructor, hence this wrapper.
return S3FileSystem(**kwargs)

def __reduce__(self):
cdef CS3Options opts = self.s3fs.options()
6 changes: 4 additions & 2 deletions python/pyarrow/scalar.pxi
@@ -16,6 +16,7 @@
# under the License.

import collections
from cython cimport binding


cdef class Scalar(_Weakrefable):
@@ -836,8 +837,9 @@ cdef class DictionaryScalar(Scalar):
Concrete class for dictionary-encoded scalars.
"""

@classmethod
def _reconstruct(cls, type, is_valid, index, dictionary):
@staticmethod
@binding(True) # Required for cython < 3
def _reconstruct(type, is_valid, index, dictionary):
cdef:
CDictionaryScalarIndexAndDictionary value
shared_ptr[CDictionaryScalar] wrapped
111 changes: 91 additions & 20 deletions python/pyarrow/tests/conftest.py
@@ -15,13 +15,16 @@
# specific language governing permissions and limitations
# under the License.

import functools
import os
import pathlib
import subprocess
import sys
from tempfile import TemporaryDirectory
import time
import urllib.request

import pytest
from pytest_lazyfixture import lazy_fixture
import hypothesis as h
from ..conftest import groups, defaults

@@ -146,8 +149,49 @@ def s3_connection():
return host, port, access_key, secret_key


def retry(attempts=3, delay=1.0, max_delay=None, backoff=1):
"""
Retry decorator
Parameters
----------
attempts : int, default 3
The number of attempts.
delay : float, default 1
Initial delay in seconds.
max_delay : float, optional
The max delay between attempts.
backoff : float, default 1
The multiplier to delay after each attempt.
"""
def decorate(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
remaining_attempts = attempts
curr_delay = delay
while remaining_attempts > 0:
try:
return func(*args, **kwargs)
except Exception as err:
remaining_attempts -= 1
last_exception = err
curr_delay *= backoff
if max_delay:
curr_delay = min(curr_delay, max_delay)
time.sleep(curr_delay)
raise last_exception
return wrapper
return decorate


@pytest.fixture(scope='session')
def s3_server(s3_connection):
def s3_server(s3_connection, tmpdir_factory):
@retry(attempts=5, delay=0.1, backoff=2)
def minio_server_health_check(address):
resp = urllib.request.urlopen(f"http://{address}/minio/health/cluster")
assert resp.getcode() == 200

tmpdir = tmpdir_factory.getbasetemp()
host, port, access_key, secret_key = s3_connection

address = '{}:{}'.format(host, port)
@@ -157,24 +201,26 @@ def s3_server(s3_connection):
'MINIO_SECRET_KEY': secret_key
})

with TemporaryDirectory() as tempdir:
args = ['minio', '--compat', 'server', '--quiet', '--address',
address, tempdir]
proc = None
try:
proc = subprocess.Popen(args, env=env)
except OSError:
pytest.skip('`minio` command cannot be located')
else:
yield {
'connection': s3_connection,
'process': proc,
'tempdir': tempdir
}
finally:
if proc is not None:
proc.kill()
proc.wait()
args = ['minio', '--compat', 'server', '--quiet', '--address',
address, tmpdir]
proc = None
try:
proc = subprocess.Popen(args, env=env)
except OSError:
pytest.skip('`minio` command cannot be located')
else:
# Wait for the server to startup before yielding
minio_server_health_check(address)

yield {
'connection': s3_connection,
'process': proc,
'tempdir': tmpdir
}
finally:
if proc is not None:
proc.kill()
proc.wait()


@pytest.fixture(scope='session')
Expand Down Expand Up @@ -202,3 +248,28 @@ def gcs_server():
if proc is not None:
proc.kill()
proc.wait()


@pytest.fixture(
params=[
lazy_fixture('builtin_pickle'),
lazy_fixture('cloudpickle')
],
scope='session'
)
def pickle_module(request):
return request.param


@pytest.fixture(scope='session')
def builtin_pickle():
import pickle
return pickle


@pytest.fixture(scope='session')
def cloudpickle():
cp = pytest.importorskip('cloudpickle')
if 'HIGHEST_PROTOCOL' not in cp.__dict__:
cp.HIGHEST_PROTOCOL = cp.DEFAULT_PROTOCOL
return cp
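
For reference, a test that requests the `pickle_module` fixture above runs once with the builtin `pickle` module and once with `cloudpickle`. The test below is only an illustrative sketch; its name and schema are not part of this commit.

```python
import pyarrow as pa


def test_schema_pickle_roundtrip(pickle_module):
    # pytest injects either the builtin pickle module or cloudpickle here,
    # so the same test body exercises both picklers.
    schema = pa.schema([("x", pa.int64()), ("y", pa.string())])
    restored = pickle_module.loads(pickle_module.dumps(schema))
    assert restored.equals(schema)
```
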
25 changes: 5 additions & 20 deletions python/pyarrow/tests/parquet/test_dataset.py
@@ -1533,10 +1533,13 @@ def _make_dataset_for_pickling(tempdir, use_legacy_dataset=False, N=100):
return dataset


def _assert_dataset_is_picklable(dataset, pickler, use_legacy_dataset=False):
@pytest.mark.pandas
@parametrize_legacy_dataset
def test_pickle_dataset(tempdir, datadir, use_legacy_dataset, pickle_module):
def is_pickleable(obj):
return obj == pickler.loads(pickler.dumps(obj))
return obj == pickle_module.loads(pickle_module.dumps(obj))

dataset = _make_dataset_for_pickling(tempdir, use_legacy_dataset)
assert is_pickleable(dataset)
if use_legacy_dataset:
with pytest.warns(FutureWarning):
@@ -1555,24 +1558,6 @@ def is_pickleable(obj):
assert is_pickleable(metadata.row_group(i))


@pytest.mark.pandas
@parametrize_legacy_dataset
def test_builtin_pickle_dataset(tempdir, datadir, use_legacy_dataset):
import pickle
dataset = _make_dataset_for_pickling(tempdir, use_legacy_dataset)
_assert_dataset_is_picklable(
dataset, pickler=pickle, use_legacy_dataset=use_legacy_dataset)


@pytest.mark.pandas
@parametrize_legacy_dataset
def test_cloudpickle_dataset(tempdir, datadir, use_legacy_dataset):
cp = pytest.importorskip('cloudpickle')
dataset = _make_dataset_for_pickling(tempdir, use_legacy_dataset)
_assert_dataset_is_picklable(
dataset, pickler=cp, use_legacy_dataset=use_legacy_dataset)


@pytest.mark.pandas
@parametrize_legacy_dataset
def test_partitioned_dataset(tempdir, use_legacy_dataset):