ARROW-134 Cannot encode pandas NA objects #118
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,6 +16,7 @@ | |
import numpy as np | ||
import pymongo.errors | ||
from bson import encode | ||
from bson.codec_options import TypeCodec, TypeRegistry | ||
from bson.raw_bson import RawBSONDocument | ||
from pyarrow import Schema as ArrowSchema | ||
from pyarrow import Table | ||
|
@@ -26,9 +27,10 @@ | |
ndarray = None | ||
|
||
try: | ||
from pandas import DataFrame | ||
from pandas import NA, DataFrame | ||
except ImportError: | ||
DataFrame = None | ||
NA = None | ||
|
||
from pymongo.bulk import BulkWriteError | ||
from pymongo.common import MAX_WRITE_BATCH_SIZE | ||
|
@@ -316,6 +318,21 @@ def _tabular_generator(tabular): | |
return | ||
|
||
|
||
class _PandasNACodec(TypeCodec): | ||
"""A custom type codec for Pandas NA objects.""" | ||
|
||
python_type = NA.__class__ # type:ignore[assignment] | ||
bson_type = None # type:ignore[assignment] | ||
|
||
def transform_python(self, _): | ||
"""Transform an NA object into 'None'""" | ||
return None | ||
|
||
def transform_bson(self, _): | ||
"""Transform a 'None' object into NA""" | ||
return NA | ||
ShaneHarvey marked this conversation as resolved.
|
||
|
||
|
||
def write(collection, tabular): | ||
"""Write data from `tabular` into the given MongoDB `collection`. | ||
|
||
|
@@ -352,6 +369,13 @@ def write(collection, tabular): | |
) | ||
|
||
tabular_gen = _tabular_generator(tabular) | ||
|
||
# Handle Pandas NA objects. | ||
codec_options = collection.codec_options | ||
if DataFrame is not None: | ||
type_registry = TypeRegistry([_PandasNACodec()]) | ||
codec_options = codec_options.with_options(type_registry=type_registry) | ||
Review comment: I'm not sure it's ideal to completely replace the collection's type_registry. What if the app already configured a type_registry for other types they want to encode? Would it be possible to keep the existing registry but add the _PandasNACodec?
Reply: We could, but we'd have to use private APIs from
Reply: Ah, my mistake for assuming TypeRegistry would have the ability to add/edit a type. You know, like a registry. Could you open an ARROW ticket for this and backlog it? I'd say it's low priority.
||
|
||
while cur_offset < tab_size: | ||
cur_size = 0 | ||
cur_batch = [] | ||
|
@@ -361,9 +385,7 @@ def write(collection, tabular): | |
and len(cur_batch) <= _MAX_WRITE_BATCH_SIZE | ||
and cur_offset + i < tab_size | ||
): | ||
enc_tab = RawBSONDocument( | ||
encode(next(tabular_gen), codec_options=collection.codec_options) | ||
) | ||
enc_tab = RawBSONDocument(encode(next(tabular_gen), codec_options=codec_options)) | ||
cur_batch.append(enc_tab) | ||
cur_size += len(enc_tab.raw) | ||
i += 1 | ||
|
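The hunk above shows only part of the batching loop in `write()`. As a rough illustration of the idea (closing a batch once a size or count limit is reached), here is a standard-library sketch. `batch_encoded` and its parameters are hypothetical, `json.dumps` stands in for `bson.encode`, and the limit checks are simplified rather than copied from the diff.

```python
import json
from typing import Iterable, Iterator, List


def batch_encoded(
    docs: Iterable[dict], max_bytes: int, max_count: int
) -> Iterator[List[bytes]]:
    """Yield lists of encoded documents, closing a batch when a limit is hit."""
    batch: List[bytes] = []
    size = 0
    for doc in docs:
        raw = json.dumps(doc).encode("utf-8")  # stand-in for bson.encode(doc)
        # Start a new batch if adding this document would exceed either limit.
        if batch and (size + len(raw) > max_bytes or len(batch) >= max_count):
            yield batch
            batch, size = [], 0
        batch.append(raw)
        size += len(raw)
    if batch:
        yield batch


docs = [{"i": i} for i in range(5)]
by_count = [len(b) for b in batch_encoded(docs, max_bytes=1024, max_count=2)]
by_size = [len(b) for b in batch_encoded(docs, max_bytes=8, max_count=10)]
```

Each sample document encodes to 8 bytes here, so the count limit produces batches of two and the 8-byte size limit produces one document per batch.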
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -98,13 +98,21 @@ def test_aggregate_simple(self): | |
self.assertEqual(agg_cmd.command["pipeline"][0]["$project"], projection) | ||
self.assertEqual(agg_cmd.command["pipeline"][1]["$project"], {"_id": True, "data": True}) | ||
|
||
def _assert_frames_equal(self, incoming, outgoing): | ||
for name in incoming.columns: | ||
col = incoming[name] | ||
val = outgoing[name] | ||
if str(val.dtype) in ["object", "float64"]: | ||
Review comment: A comment here or a method docstring would be helpful.
Reply: Done
||
val = val.astype(col.dtype) | ||
pd.testing.assert_series_equal(col, val) | ||
Review comment: Is using "val" as the name of a column idiomatic?
Reply: Fixed
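One plausible reason for the `astype` call in `_assert_frames_equal` is that a column with missing values can come back with a different dtype than it was written with. The sketch below assumes such a read-back (a nullable `Int64` column returned as `float64` with `NaN`); the data is invented for illustration and is not taken from the test suite.

```python
import pandas as pd

# Column as written: nullable integers with one missing value.
expected = pd.Series([0, 1, None], dtype=pd.Int64Dtype(), name="Int64")

# Hypothetical read-back: same values, but the missing entry forced float64.
actual = pd.Series([0.0, 1.0, float("nan")], name="Int64")

try:
    pd.testing.assert_series_equal(expected, actual)
    dtype_mismatch = False
except AssertionError:
    dtype_mismatch = True

# Casting back to the original dtype turns NaN into pd.NA,
# so the comparison below passes.
restored = actual.astype(expected.dtype)
pd.testing.assert_series_equal(expected, restored)
```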
||
|
||
def round_trip(self, data, schema, coll=None): | ||
if coll is None: | ||
coll = self.coll | ||
coll.drop() | ||
res = write(self.coll, data) | ||
self.assertEqual(len(data), res.raw_result["insertedCount"]) | ||
pd.testing.assert_frame_equal(data, find_pandas_all(coll, {}, schema=schema)) | ||
self._assert_frames_equal(data, find_pandas_all(coll, {}, schema=schema)) | ||
return res | ||
|
||
def test_write_error(self): | ||
|
@@ -129,23 +137,34 @@ def _create_data(self): | |
if k.__name__ not in ("ObjectId", "Decimal128") | ||
} | ||
schema = {k: v.to_pandas_dtype() for k, v in arrow_schema.items()} | ||
schema["Int64"] = pd.Int64Dtype() | ||
schema["int"] = pd.Int32Dtype() | ||
schema["str"] = "U8" | ||
schema["datetime"] = "datetime64[ns]" | ||
|
||
data = pd.DataFrame( | ||
data={ | ||
"Int64": [i for i in range(2)], | ||
"float": [i for i in range(2)], | ||
"int": [i for i in range(2)], | ||
"datetime": [datetime.datetime(1970 + i, 1, 1) for i in range(2)], | ||
"str": [f"a{i}" for i in range(2)], | ||
"bool": [True, False], | ||
"Int64": [i for i in range(2)] + [None], | ||
"float": [i for i in range(2)] + [None], | ||
"int": [i for i in range(2)] + [None], | ||
"datetime": [datetime.datetime(1970 + i, 1, 1) for i in range(2)] + [None], | ||
"str": [f"a{i}" for i in range(2)] + [None], | ||
"bool": [True, False, None], | ||
} | ||
).astype(schema) | ||
return arrow_schema, data | ||
|
||
def test_write_schema_validation(self): | ||
arrow_schema, data = self._create_data() | ||
|
||
# Work around https://github.com/pandas-dev/pandas/issues/11453. | ||
def new_replace(k): | ||
if k.value < 1: | ||
return datetime.datetime(1970, 1, 1) | ||
return k.replace(tzinfo=None) | ||
|
||
data["datetime"] = data.apply(lambda row: new_replace(row["datetime"]), axis=1) | ||
Review comment: That pandas ticket seems to be about .time() not working on NA types but we don't use .time() anywhere. Can you explain?
Reply: Ah, meant to link to pandas-dev/pandas#16248. Also added context.
Reply: I'm still confused since this code is changing the data before passing it to any pymongoarrow methods. What happens if the user actually calls write() with a NaT datetime? And why are we clamping the datetime to
Reply: They would see the same error as https://jira.mongodb.org/browse/FREE-165786, for which this is a workaround. I just chose a random valid date.
Reply: Thanks, 2 follow up questions.
if isinstance(k, Nat):
return None
return k
Reply: If we return
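The reviewer's snippet reads as pseudocode. pandas exposes its missing-datetime marker as the singleton `pd.NaT`, so an identity check is one runnable way to express the same idea. This is a minimal sketch only: `nat_to_none` is a hypothetical helper, and whether `write()` should apply such a conversion is the open question in this thread.

```python
import pandas as pd


def nat_to_none(value):
    """Return None for pandas' missing-datetime marker, else the value unchanged."""
    # pd.NaT is a singleton, so an identity check is enough here.
    return None if value is pd.NaT else value


column = pd.Series(pd.to_datetime(["1970-01-01", None]))
cleaned = [nat_to_none(v) for v in column]
```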
||
|
||
self.round_trip( | ||
data, | ||
Schema(arrow_schema), | ||
|
@@ -282,12 +301,7 @@ def test_csv(self): | |
f.close() | ||
data.to_csv(f.name, index=False) | ||
out = pd.read_csv(f.name) | ||
for name in data.columns: | ||
col = data[name] | ||
val = out[name] | ||
if str(val.dtype) == "object": | ||
val = val.astype(col.dtype) | ||
pd.testing.assert_series_equal(col, val) | ||
self._assert_frames_equal(data, out) | ||
|
||
|
||
class TestBSONTypes(PandasTestBase): | ||
|
Review comment: Interesting, why the type ignores?
Reply: My editor had flagged them because the base class uses read-only properties.
Reply: I see mypy does not raise this issue. This does look like a bug in pyright though:
Looks like microsoft/pyright#2601 which they closed as won't fix.
Reply: Actually that pyright issue has @classmethod whereas we just use @property so I think it would be worth opening a new pyright issue for it.
Reply: microsoft/pyright#4364
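The pattern under discussion can be shown with the standard library alone. In this sketch `Codec` and `IntCodec` are hypothetical classes, not the `bson` ones: the base class declares a read-only abstract property and the subclass satisfies it with a plain class attribute. This works at runtime, and the `type: ignore` comment is only there for a static checker that reports the override as incompatible.

```python
import abc


class Codec(abc.ABC):
    """Hypothetical base class exposing a read-only, abstract property."""

    @property
    @abc.abstractmethod
    def python_type(self) -> type:
        ...


class IntCodec(Codec):
    # A plain class attribute satisfies the abstract property at runtime,
    # but a static checker may still flag the assignment.
    python_type = int  # type: ignore[assignment]


codec = IntCodec()

# The base class itself stays abstract and cannot be instantiated.
try:
    Codec()
    base_instantiable = True
except TypeError:
    base_instantiable = False
```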