From 46fed4bf1b344b2a04d6659dd5832bddaab64034 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Tue, 20 Dec 2022 13:36:33 -0600 Subject: [PATCH 1/6] ARROW-134 Cannot encode pandas NA objects --- bindings/python/pymongoarrow/api.py | 30 +++++++++++++++++++--- bindings/python/test/test_pandas.py | 40 +++++++++++++++++++---------- 2 files changed, 53 insertions(+), 17 deletions(-) diff --git a/bindings/python/pymongoarrow/api.py b/bindings/python/pymongoarrow/api.py index 633823ac..a5b8d983 100644 --- a/bindings/python/pymongoarrow/api.py +++ b/bindings/python/pymongoarrow/api.py @@ -16,6 +16,7 @@ import numpy as np import pymongo.errors from bson import encode +from bson.codec_options import TypeCodec, TypeRegistry from bson.raw_bson import RawBSONDocument from pyarrow import Schema as ArrowSchema from pyarrow import Table @@ -26,9 +27,10 @@ ndarray = None try: - from pandas import DataFrame + from pandas import NA, DataFrame except ImportError: DataFrame = None + NA = None from pymongo.bulk import BulkWriteError from pymongo.common import MAX_WRITE_BATCH_SIZE @@ -316,6 +318,21 @@ def _tabular_generator(tabular): return +class _PandasNACodec(TypeCodec): + """A custom type codec for Pandas NA objects.""" + + python_type = NA.__class__ # type:ignore[assignment] + bson_type = None # type:ignore[assignment] + + def transform_python(self, _): + """Transform an NA object into 'None'""" + return None + + def transform_bson(self, _): + """Transform a 'None' object into NA""" + return NA + + def write(collection, tabular): """Write data from `tabular` into the given MongoDB `collection`. @@ -352,6 +369,13 @@ def write(collection, tabular): ) tabular_gen = _tabular_generator(tabular) + + # Handle Pandas NA objects. + codec_options = collection.codec_options + if DataFrame is not None: + type_registry = TypeRegistry([_PandasNACodec()]) + codec_options = codec_options.with_options(type_registry=type_registry) + while cur_offset < tab_size: cur_size = 0 cur_batch = [] @@ -361,9 +385,7 @@ def write(collection, tabular): and len(cur_batch) <= _MAX_WRITE_BATCH_SIZE and cur_offset + i < tab_size ): - enc_tab = RawBSONDocument( - encode(next(tabular_gen), codec_options=collection.codec_options) - ) + enc_tab = RawBSONDocument(encode(next(tabular_gen), codec_options=codec_options)) cur_batch.append(enc_tab) cur_size += len(enc_tab.raw) i += 1 diff --git a/bindings/python/test/test_pandas.py b/bindings/python/test/test_pandas.py index 22314d7d..b0d44201 100644 --- a/bindings/python/test/test_pandas.py +++ b/bindings/python/test/test_pandas.py @@ -98,13 +98,21 @@ def test_aggregate_simple(self): self.assertEqual(agg_cmd.command["pipeline"][0]["$project"], projection) self.assertEqual(agg_cmd.command["pipeline"][1]["$project"], {"_id": True, "data": True}) + def _assert_frames_equal(self, incoming, outgoing): + for name in incoming.columns: + col = incoming[name] + val = outgoing[name] + if str(val.dtype) in ["object", "float64"]: + val = val.astype(col.dtype) + pd.testing.assert_series_equal(col, val) + def round_trip(self, data, schema, coll=None): if coll is None: coll = self.coll coll.drop() res = write(self.coll, data) self.assertEqual(len(data), res.raw_result["insertedCount"]) - pd.testing.assert_frame_equal(data, find_pandas_all(coll, {}, schema=schema)) + self._assert_frames_equal(data, find_pandas_all(coll, {}, schema=schema)) return res def test_write_error(self): @@ -129,23 +137,34 @@ def _create_data(self): if k.__name__ not in ("ObjectId", "Decimal128") } schema = {k: v.to_pandas_dtype() for k, v in arrow_schema.items()} + schema["Int64"] = pd.Int64Dtype() + schema["int"] = pd.Int32Dtype() schema["str"] = "U8" schema["datetime"] = "datetime64[ns]" data = pd.DataFrame( data={ - "Int64": [i for i in range(2)], - "float": [i for i in range(2)], - "int": [i for i in range(2)], - "datetime": [datetime.datetime(1970 + i, 1, 1) for i in range(2)], - "str": [f"a{i}" for i in range(2)], - "bool": [True, False], + "Int64": [i for i in range(2)] + [None], + "float": [i for i in range(2)] + [None], + "int": [i for i in range(2)] + [None], + "datetime": [datetime.datetime(1970 + i, 1, 1) for i in range(2)] + [None], + "str": [f"a{i}" for i in range(2)] + [None], + "bool": [True, False, None], } ).astype(schema) return arrow_schema, data def test_write_schema_validation(self): arrow_schema, data = self._create_data() + + # Work around https://github.com/pandas-dev/pandas/issues/11453. + def new_replace(k): + if k.value < 1: + return datetime.datetime(1970, 1, 1) + return k.replace(tzinfo=None) + + data["datetime"] = data.apply(lambda row: new_replace(row["datetime"]), axis=1) + self.round_trip( data, Schema(arrow_schema), @@ -282,12 +301,7 @@ def test_csv(self): f.close() data.to_csv(f.name, index=False) out = pd.read_csv(f.name) - for name in data.columns: - col = data[name] - val = out[name] - if str(val.dtype) == "object": - val = val.astype(col.dtype) - pd.testing.assert_series_equal(col, val) + self._assert_frames_equal(data, out) class TestBSONTypes(PandasTestBase): From 1116dc4c89d01f2e0406ddc7bbb7412cd258197e Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Tue, 20 Dec 2022 15:25:19 -0600 Subject: [PATCH 2/6] address review --- bindings/python/pymongoarrow/api.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/bindings/python/pymongoarrow/api.py b/bindings/python/pymongoarrow/api.py index a5b8d983..39ddc1a9 100644 --- a/bindings/python/pymongoarrow/api.py +++ b/bindings/python/pymongoarrow/api.py @@ -16,7 +16,7 @@ import numpy as np import pymongo.errors from bson import encode -from bson.codec_options import TypeCodec, TypeRegistry +from bson.codec_options import TypeEncoder, TypeRegistry from bson.raw_bson import RawBSONDocument from pyarrow import Schema as ArrowSchema from pyarrow import Table @@ -318,20 +318,15 @@ def _tabular_generator(tabular): return -class _PandasNACodec(TypeCodec): +class _PandasNACodec(TypeEncoder): """A custom type codec for Pandas NA objects.""" python_type = NA.__class__ # type:ignore[assignment] - bson_type = None # type:ignore[assignment] def transform_python(self, _): """Transform an NA object into 'None'""" return None - def transform_bson(self, _): - """Transform a 'None' object into NA""" - return NA - def write(collection, tabular): """Write data from `tabular` into the given MongoDB `collection`. From 774653a732d6cca778956c0b39cca2b7a5517dfe Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Tue, 20 Dec 2022 15:27:49 -0600 Subject: [PATCH 3/6] try to fix csv test --- bindings/python/test/test_pandas.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bindings/python/test/test_pandas.py b/bindings/python/test/test_pandas.py index b0d44201..e30339fd 100644 --- a/bindings/python/test/test_pandas.py +++ b/bindings/python/test/test_pandas.py @@ -299,7 +299,7 @@ def test_csv(self): _, data = self._create_data() with tempfile.NamedTemporaryFile(suffix=".csv") as f: f.close() - data.to_csv(f.name, index=False) + data.to_csv(f.name, index=False, na_rep="") out = pd.read_csv(f.name) self._assert_frames_equal(data, out) From 3e73b0c1d3e47e5c267c946b83663cef2cf17b34 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Tue, 20 Dec 2022 15:37:13 -0600 Subject: [PATCH 4/6] handle runtimewarning --- bindings/python/test/test_pandas.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/bindings/python/test/test_pandas.py b/bindings/python/test/test_pandas.py index e30339fd..9578c7f5 100644 --- a/bindings/python/test/test_pandas.py +++ b/bindings/python/test/test_pandas.py @@ -16,6 +16,7 @@ import tempfile import unittest import unittest.mock as mock +import warnings from test import client_context from test.utils import AllowListEventListener, TestNullsBase @@ -299,7 +300,10 @@ def test_csv(self): _, data = self._create_data() with tempfile.NamedTemporaryFile(suffix=".csv") as f: f.close() - data.to_csv(f.name, index=False, na_rep="") + # May give RuntimeWarning due to the nulls. + with warnings.catch_warnings(): + warnings.simplefilter("ignore", RuntimeWarning) + data.to_csv(f.name, index=False, na_rep="") out = pd.read_csv(f.name) self._assert_frames_equal(data, out) From a5e35d158a916052c1a944cdb8bc07c0bef79fb2 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Tue, 20 Dec 2022 17:14:25 -0600 Subject: [PATCH 5/6] address review --- bindings/python/test/test_pandas.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/bindings/python/test/test_pandas.py b/bindings/python/test/test_pandas.py index 9578c7f5..21624cff 100644 --- a/bindings/python/test/test_pandas.py +++ b/bindings/python/test/test_pandas.py @@ -101,11 +101,14 @@ def test_aggregate_simple(self): def _assert_frames_equal(self, incoming, outgoing): for name in incoming.columns: - col = incoming[name] - val = outgoing[name] - if str(val.dtype) in ["object", "float64"]: - val = val.astype(col.dtype) - pd.testing.assert_series_equal(col, val) + in_col = incoming[name] + out_col = outgoing[name] + # Object types may lose type information in a round trip. + # Integer types with missing values are converted to floating + # point in a round trip. + if str(out_col.dtype) in ["object", "float64"]: + out_col = out_col.astype(in_col.dtype) + pd.testing.assert_series_equal(in_col, out_col) def round_trip(self, data, schema, coll=None): if coll is None: @@ -158,7 +161,8 @@ def _create_data(self): def test_write_schema_validation(self): arrow_schema, data = self._create_data() - # Work around https://github.com/pandas-dev/pandas/issues/11453. + # Work around https://github.com/pandas-dev/pandas/issues/16248, + # Where pandas does not implement utcoffset for null timestamps. def new_replace(k): if k.value < 1: return datetime.datetime(1970, 1, 1) From f9450f46eca0f50e990eb556be60bd8d88b2952f Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Wed, 4 Jan 2023 16:00:30 -0600 Subject: [PATCH 6/6] cleanup --- bindings/python/pymongoarrow/api.py | 4 +++- bindings/python/test/test_pandas.py | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/bindings/python/pymongoarrow/api.py b/bindings/python/pymongoarrow/api.py index 39ddc1a9..d480b85d 100644 --- a/bindings/python/pymongoarrow/api.py +++ b/bindings/python/pymongoarrow/api.py @@ -321,7 +321,9 @@ def _tabular_generator(tabular): class _PandasNACodec(TypeEncoder): """A custom type codec for Pandas NA objects.""" - python_type = NA.__class__ # type:ignore[assignment] + @property + def python_type(self): + return NA.__class__ def transform_python(self, _): """Transform an NA object into 'None'""" diff --git a/bindings/python/test/test_pandas.py b/bindings/python/test/test_pandas.py index 21624cff..33326751 100644 --- a/bindings/python/test/test_pandas.py +++ b/bindings/python/test/test_pandas.py @@ -164,7 +164,7 @@ def test_write_schema_validation(self): # Work around https://github.com/pandas-dev/pandas/issues/16248, # Where pandas does not implement utcoffset for null timestamps. def new_replace(k): - if k.value < 1: + if isinstance(k, pd.NaT.__class__): return datetime.datetime(1970, 1, 1) return k.replace(tzinfo=None)