diff --git a/bindings/python/pymongoarrow/api.py b/bindings/python/pymongoarrow/api.py
index 633823ac..d480b85d 100644
--- a/bindings/python/pymongoarrow/api.py
+++ b/bindings/python/pymongoarrow/api.py
@@ -16,6 +16,7 @@
 import numpy as np
 import pymongo.errors
 from bson import encode
+from bson.codec_options import TypeEncoder, TypeRegistry
 from bson.raw_bson import RawBSONDocument
 from pyarrow import Schema as ArrowSchema
 from pyarrow import Table
@@ -26,9 +27,10 @@
     ndarray = None
 
 try:
-    from pandas import DataFrame
+    from pandas import NA, DataFrame
 except ImportError:
     DataFrame = None
+    NA = None
 
 from pymongo.bulk import BulkWriteError
 from pymongo.common import MAX_WRITE_BATCH_SIZE
@@ -316,6 +318,18 @@ def _tabular_generator(tabular):
         return
 
 
+class _PandasNACodec(TypeEncoder):
+    """A custom type codec for Pandas NA objects."""
+
+    @property
+    def python_type(self):
+        return NA.__class__
+
+    def transform_python(self, _):
+        """Transform an NA object into 'None'"""
+        return None
+
+
 def write(collection, tabular):
     """Write data from `tabular` into the given MongoDB `collection`.
@@ -352,6 +366,13 @@ def write(collection, tabular):
     )
     tabular_gen = _tabular_generator(tabular)
+
+    # Handle Pandas NA objects.
+    codec_options = collection.codec_options
+    if DataFrame is not None:
+        type_registry = TypeRegistry([_PandasNACodec()])
+        codec_options = codec_options.with_options(type_registry=type_registry)
+
     while cur_offset < tab_size:
         cur_size = 0
         cur_batch = []
@@ -361,9 +382,7 @@
             and len(cur_batch) <= _MAX_WRITE_BATCH_SIZE
             and cur_offset + i < tab_size
         ):
-            enc_tab = RawBSONDocument(
-                encode(next(tabular_gen), codec_options=collection.codec_options)
-            )
+            enc_tab = RawBSONDocument(encode(next(tabular_gen), codec_options=codec_options))
             cur_batch.append(enc_tab)
             cur_size += len(enc_tab.raw)
             i += 1
diff --git a/bindings/python/test/test_pandas.py b/bindings/python/test/test_pandas.py
index 22314d7d..33326751 100644
--- a/bindings/python/test/test_pandas.py
+++ b/bindings/python/test/test_pandas.py
@@ -16,6 +16,7 @@
 import tempfile
 import unittest
 import unittest.mock as mock
+import warnings
 
 from test import client_context
 from test.utils import AllowListEventListener, TestNullsBase
@@ -98,13 +99,24 @@ def test_aggregate_simple(self):
         self.assertEqual(agg_cmd.command["pipeline"][0]["$project"], projection)
         self.assertEqual(agg_cmd.command["pipeline"][1]["$project"], {"_id": True, "data": True})
 
+    def _assert_frames_equal(self, incoming, outgoing):
+        for name in incoming.columns:
+            in_col = incoming[name]
+            out_col = outgoing[name]
+            # Object types may lose type information in a round trip.
+            # Integer types with missing values are converted to floating
+            # point in a round trip.
+            if str(out_col.dtype) in ["object", "float64"]:
+                out_col = out_col.astype(in_col.dtype)
+            pd.testing.assert_series_equal(in_col, out_col)
+
     def round_trip(self, data, schema, coll=None):
         if coll is None:
             coll = self.coll
         coll.drop()
         res = write(self.coll, data)
         self.assertEqual(len(data), res.raw_result["insertedCount"])
-        pd.testing.assert_frame_equal(data, find_pandas_all(coll, {}, schema=schema))
+        self._assert_frames_equal(data, find_pandas_all(coll, {}, schema=schema))
         return res
 
     def test_write_error(self):
@@ -129,23 +141,35 @@ def _create_data(self):
             if k.__name__ not in ("ObjectId", "Decimal128")
         }
         schema = {k: v.to_pandas_dtype() for k, v in arrow_schema.items()}
+        schema["Int64"] = pd.Int64Dtype()
+        schema["int"] = pd.Int32Dtype()
         schema["str"] = "U8"
         schema["datetime"] = "datetime64[ns]"
         data = pd.DataFrame(
             data={
-                "Int64": [i for i in range(2)],
-                "float": [i for i in range(2)],
-                "int": [i for i in range(2)],
-                "datetime": [datetime.datetime(1970 + i, 1, 1) for i in range(2)],
-                "str": [f"a{i}" for i in range(2)],
-                "bool": [True, False],
+                "Int64": [i for i in range(2)] + [None],
+                "float": [i for i in range(2)] + [None],
+                "int": [i for i in range(2)] + [None],
+                "datetime": [datetime.datetime(1970 + i, 1, 1) for i in range(2)] + [None],
+                "str": [f"a{i}" for i in range(2)] + [None],
+                "bool": [True, False, None],
             }
         ).astype(schema)
         return arrow_schema, data
 
     def test_write_schema_validation(self):
         arrow_schema, data = self._create_data()
+
+        # Work around https://github.com/pandas-dev/pandas/issues/16248,
+        # where pandas does not implement utcoffset for null timestamps.
+        def new_replace(k):
+            if isinstance(k, pd.NaT.__class__):
+                return datetime.datetime(1970, 1, 1)
+            return k.replace(tzinfo=None)
+
+        data["datetime"] = data.apply(lambda row: new_replace(row["datetime"]), axis=1)
+
         self.round_trip(
             data,
             Schema(arrow_schema),
@@ -280,14 +304,12 @@ def test_csv(self):
         _, data = self._create_data()
         with tempfile.NamedTemporaryFile(suffix=".csv") as f:
             f.close()
-            data.to_csv(f.name, index=False)
+            # May give a RuntimeWarning due to the nulls.
+            with warnings.catch_warnings():
+                warnings.simplefilter("ignore", RuntimeWarning)
+                data.to_csv(f.name, index=False, na_rep="")
             out = pd.read_csv(f.name)
-            for name in data.columns:
-                col = data[name]
-                val = out[name]
-                if str(val.dtype) == "object":
-                    val = val.astype(col.dtype)
-                pd.testing.assert_series_equal(col, val)
+            self._assert_frames_equal(data, out)
 
 
 class TestBSONTypes(PandasTestBase):
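
Note: the snippet below is a standalone sketch, not part of the patch, showing how the TypeEncoder/TypeRegistry pattern used in the api.py change behaves. pd.NA has no BSON representation, so bson.encode rejects it unless a codec transforms it to None (BSON null). The class name PandasNACodec is illustrative; it mirrors the patch's _PandasNACodec.

import pandas as pd
from bson import decode, encode
from bson.codec_options import CodecOptions, TypeEncoder, TypeRegistry


class PandasNACodec(TypeEncoder):
    """Encode pandas.NA as BSON null (mirrors _PandasNACodec above)."""

    @property
    def python_type(self):
        # pd.NA is a singleton, so registering its class covers
        # every occurrence of a missing value in a document.
        return pd.NA.__class__

    def transform_python(self, _):
        return None


# Without the type registry, encode() would raise
# bson.errors.InvalidDocument on the pd.NA value.
opts = CodecOptions(type_registry=TypeRegistry([PandasNACodec()]))
raw = encode({"x": 1, "y": pd.NA}, codec_options=opts)
print(decode(raw))  # {'x': 1, 'y': None}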
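
A second illustrative sketch (again, not part of the patch) of the dtype behavior the tests account for: a plain int column silently becomes float64 when it contains a missing value, which is why _assert_frames_equal coerces object and float64 output columns back to the input dtype, and why _create_data switches the "Int64" and "int" columns to pandas nullable extension dtypes.

import pandas as pd

# A default int column cannot hold missing values; pandas
# promotes it to float64 and stores the gap as NaN.
s = pd.Series([0, 1, None])
print(s.dtype)  # float64

# The nullable extension dtype keeps integer storage and
# represents the missing entry as pd.NA, matching the
# schema["Int64"] = pd.Int64Dtype() line in _create_data.
s = pd.Series([0, 1, None], dtype=pd.Int64Dtype())
print(s.dtype)        # Int64
print(s[2] is pd.NA)  # True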