diff --git a/bindings/python/benchmarks/benchmarks.py b/bindings/python/benchmarks/benchmarks.py index 0a6ef033..892c78ae 100644 --- a/bindings/python/benchmarks/benchmarks.py +++ b/bindings/python/benchmarks/benchmarks.py @@ -18,6 +18,7 @@ import numpy as np import pandas as pd +import polars as pl import pyarrow as pa import pymongo from bson import BSON, Binary, Decimal128 @@ -27,6 +28,7 @@ find_arrow_all, find_numpy_all, find_pandas_all, + find_polars_all, write, ) from pymongoarrow.types import BinaryType, Decimal128Type @@ -74,6 +76,9 @@ def time_insert_pandas(self): def time_insert_numpy(self): write(db.benchmark, self.numpy_arrays) + def time_insert_polars(self): + write(db.benchmark, self.polars_table) + def peakmem_insert_arrow(self): self.time_insert_arrow() @@ -86,6 +91,9 @@ def peakmem_insert_pandas(self): def peakmem_insert_numpy(self): self.time_insert_numpy() + def peakmem_insert_polars(self): + self.time_insert_polars() + class Read(ABC): """ @@ -136,16 +144,25 @@ def time_to_pandas(self): c = db.benchmark find_pandas_all(c, {}, schema=self.schema, projection={"_id": 0}) + def time_conventional_arrow(self): + c = db.benchmark + f = list(c.find({}, projection={"_id": 0})) + table = pa.Table.from_pylist(f) + self.exercise_table(table) + def time_to_arrow(self): c = db.benchmark table = find_arrow_all(c, {}, schema=self.schema, projection={"_id": 0}) self.exercise_table(table) - def time_conventional_arrow(self): + def time_conventional_polars(self): + collection = db.benchmark + cursor = collection.find(projection={"_id": 0}) + _ = pl.DataFrame(list(cursor)) + + def time_to_polars(self): c = db.benchmark - f = list(c.find({}, projection={"_id": 0})) - table = pa.Table.from_pylist(f) - self.exercise_table(table) + find_polars_all(c, {}, schema=self.schema, projection={"_id": 0}) def peakmem_to_numpy(self): self.time_to_numpy() @@ -162,6 +179,12 @@ def peakmem_to_arrow(self): def peakmem_conventional_arrow(self): self.time_conventional_arrow() + def peakmem_to_polars(self): + self.time_to_polars() + + def peakmem_conventional_polars(self): + self.time_conventional_polars() + class ProfileReadArray(Read): schema = Schema( @@ -364,6 +387,7 @@ def setup(self): self.arrow_table = find_arrow_all(db.benchmark, {}, schema=self.schema) self.pandas_table = find_pandas_all(db.benchmark, {}, schema=self.schema) self.numpy_arrays = find_numpy_all(db.benchmark, {}, schema=self.schema) + self.polars_table = find_polars_all(db.benchmark, {}, schema=self.schema) class ProfileInsertLarge(Insert): @@ -383,3 +407,4 @@ def setup(self): self.arrow_table = find_arrow_all(db.benchmark, {}, schema=self.schema) self.pandas_table = find_pandas_all(db.benchmark, {}, schema=self.schema) self.numpy_arrays = find_numpy_all(db.benchmark, {}, schema=self.schema) + self.polars_table = find_polars_all(db.benchmark, {}, schema=self.schema) diff --git a/bindings/python/docs/source/changelog.rst b/bindings/python/docs/source/changelog.rst index 7e030045..8f3a12f5 100644 --- a/bindings/python/docs/source/changelog.rst +++ b/bindings/python/docs/source/changelog.rst @@ -1,6 +1,11 @@ Changelog ========= +Changes in Version 1.3.0 +------------------------ +- Support for Polars +- Support for PyArrow.DataTypes: large_list, large_string + Changes in Version 1.2.0 ------------------------ - Support for PyArrow 14.0. 
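To make the new surface area concrete, the snippet below sketches a round trip through the Polars API this changeset adds (``write`` plus ``find_polars_all``). It is a minimal sketch, not part of the diff: it assumes a reachable local ``mongod``, and the ``test.example`` collection name is illustrative only::

    import polars as pl
    import pymongo
    from pymongoarrow.api import find_polars_all, write

    client = pymongo.MongoClient()  # assumes mongod on localhost:27017
    coll = client.test.example

    # write() accepts a polars DataFrame and converts it to Arrow internally
    df_in = pl.DataFrame({"_id": [1, 2, 3], "amount": [1.5, 2.5, 3.5]})
    write(coll, df_in)

    # Read back as a polars DataFrame; with no schema given, one is
    # inferred from the first document in the result set.
    df_out = find_polars_all(coll, {}, projection={"_id": 0})
    print(df_out)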
diff --git a/bindings/python/docs/source/comparison.rst b/bindings/python/docs/source/comparison.rst index d4390cac..e827ecd6 100644 --- a/bindings/python/docs/source/comparison.rst +++ b/bindings/python/docs/source/comparison.rst @@ -1,8 +1,8 @@ -Quick Start -=========== +Comparing to PyMongo +==================== -This tutorial is intended as a comparison between using just PyMongo, versus -with **PyMongoArrow**. The reader is assumed to be familiar with basic +This tutorial is intended as a comparison between using **PyMongoArrow** +and using just PyMongo. The reader is assumed to be familiar with basic `PyMongo `_ and `MongoDB `_ concepts. diff --git a/bindings/python/docs/source/data_types.rst b/bindings/python/docs/source/data_types.rst index 16ff2dde..2a666f9e 100644 --- a/bindings/python/docs/source/data_types.rst +++ b/bindings/python/docs/source/data_types.rst @@ -4,8 +4,12 @@ Data Types ========== PyMongoArrow supports a majority of the BSON types. +As Arrow and Polars provide first-class support for Lists and Structs, +this includes embedded arrays and documents. + Support for additional types will be added in subsequent releases. + .. note:: For more information about BSON types, see the `BSON specification `_. @@ -131,11 +135,12 @@ dataframe will be the appropriate ``bson`` type. >>> df["_id"][0] ObjectId('64408bf65ac9e208af220144') +As of this writing, Polars does not support Extension Types. Null Values and Conversion to Pandas DataFrames ----------------------------------------------- -In Arrow, all Arrays are always nullable. +In Arrow (and Polars), all Arrays are nullable. Pandas has experimental nullable data types as, e.g., "Int64" (note the capital "I"). You can instruct Arrow to create a pandas DataFrame using nullable dtypes with the code below (taken from `here `_) diff --git a/bindings/python/docs/source/faq.rst b/bindings/python/docs/source/faq.rst index ec6d15a7..0373ccda 100644 --- a/bindings/python/docs/source/faq.rst +++ b/bindings/python/docs/source/faq.rst @@ -3,13 +3,13 @@ Frequently Asked Questions .. contents:: -Why do I get ``ModuleNotFoundError: No module named 'pandas'`` when using PyMongoArrow -------------------------------------------------------------------------------------- +Why do I get ``ModuleNotFoundError: No module named 'polars'`` when using PyMongoArrow? +--------------------------------------------------------------------------------------- This error is raised when an application attempts to use a PyMongoArrow API -that returns query result sets as a :class:`pandas.DataFrame` instance without -having ``pandas`` installed in the Python environment. Since ``pandas`` is not +that returns query result sets as a :class:`polars.DataFrame` instance without +having ``polars`` installed in the Python environment. Since ``polars`` is not a direct dependency of PyMongoArrow, it is not automatically installed when you install ``pymongoarrow`` and must be installed separately:: - $ python -m pip install pandas + $ python -m pip install polars diff --git a/bindings/python/docs/source/index.rst b/bindings/python/docs/source/index.rst index 2f6f5475..ebd64a57 100644 --- a/bindings/python/docs/source/index.rst +++ b/bindings/python/docs/source/index.rst @@ -6,7 +6,8 @@ Overview **PyMongoArrow** is a `PyMongo `_ extension containing tools for loading `MongoDB `_ query result sets as `Apache Arrow `_ tables, -`Pandas `_ and `NumPy `_ arrays. +`NumPy `_ arrays, and `Pandas `_ +or `Polars `_ DataFrames.
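As the FAQ entry above implies, ``polars``, like ``pandas``, remains an optional dependency. The api.py hunk below guards the import at module scope and raises only when a Polars-returning entry point is actually called. A minimal sketch of that pattern, assuming nothing beyond a standard ``try``/``except`` import guard::

    try:
        import polars as pl
    except ImportError:
        pl = None  # checked later; Polars-returning APIs raise ValueError if missing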
PyMongoArrow is the recommended way to materialize MongoDB query result sets as contiguous-in-memory, typed arrays suited for in-memory analytical processing applications. This documentation attempts to explain everything you need to diff --git a/bindings/python/docs/source/quickstart.rst b/bindings/python/docs/source/quickstart.rst index 8549b8d8..fb2cf94f 100644 --- a/bindings/python/docs/source/quickstart.rst +++ b/bindings/python/docs/source/quickstart.rst @@ -68,10 +68,16 @@ to type-specifiers, e.g.:: schema = Schema({'_id': int, 'amount': float, 'last_updated': datetime}) -Nested data (embedded documents) are also supported:: +PyMongoArrow offers first-class support for nested data (embedded documents):: schema = Schema({'_id': int, 'amount': float, 'account': { 'name': str, 'account_number': int}}) +Lists (and nested lists) are also supported:: + + from pyarrow import list_, string + schema = Schema({'txns': list_(string())}) + polars_df = client.db.data.find_polars_all({'amount': {'$gt': 0}}, schema=schema) + There are multiple permissible type-identifiers for each supported BSON type. For a full list of data types and associated type-identifiers see :doc:`data_types`. @@ -89,18 +95,16 @@ We can also load the same result set as a :class:`pyarrow.Table` instance:: arrow_table = client.db.data.find_arrow_all({'amount': {'$gt': 0}}, schema=schema) -In the NumPy case, the return value is a dictionary where the keys are field -names and values are corresponding :class:`numpy.ndarray` instances:: +a :class:`polars.DataFrame`:: - ndarrays = client.db.data.find_numpy_all({'amount': {'$gt': 0}}, schema=schema) + df = client.db.data.find_polars_all({'amount': {'$gt': 0}}, schema=schema) +or as **NumPy** arrays:: -Arrays (and nested arrays) are also supported:: - - from pyarrow import list_, string - schema = Schema({'_id': int, 'amount': float, 'txns': list_(string())}) - arrow_table = client.db.data.find_arrow_all({'amount': {'$gt': 0}}, schema=schema) + ndarrays = client.db.data.find_numpy_all({'amount': {'$gt': 0}}, schema=schema) +In the NumPy case, the return value is a dictionary where the keys are field +names and values are corresponding :class:`numpy.ndarray` instances. .. note:: For all of the examples above, the schema can be omitted like so:: @@ -130,16 +134,18 @@ More information on aggregation pipelines can be found `here `_. diff --git a/bindings/python/pymongoarrow/api.py b/bindings/python/pymongoarrow/api.py +def _cast_away_extension_types_on_array(array: pa.Array) -> pa.Array: + """Return an Array where ExtensionTypes have been cast to their base pyarrow types""" + if isinstance(array.type, pa.ExtensionType): + return array.cast(array.type.storage_type) + # elif pa.types.is_struct(array.type): + # ... + # elif pa.types.is_list(array.type): + # ... + return array + + +def _cast_away_extension_types_on_table(table: pa.Table) -> pa.Table: + """Given an Arrow Table that may contain ExtensionTypes, cast these to the base pyarrow types""" + # Convert all fields in the Arrow table + converted_fields = [ + _cast_away_extension_types_on_array(table.column(i)) for i in range(table.num_columns) + ] + # Reconstruct the Arrow table + return pa.Table.from_arrays(converted_fields, names=table.column_names) + + +def _arrow_to_polars(arrow_table): + """Helper function that converts an Arrow Table to a Polars DataFrame. + + Note: Polars lacks ExtensionTypes. We cast them to their base arrow classes. + """ + if pl is None: + msg = "polars is not installed. Try pip install polars."
+ raise ValueError(msg) + arrow_table_without_extensions = _cast_away_extension_types_on_table(arrow_table) + return pl.from_arrow(arrow_table_without_extensions) + + +def find_polars_all(collection, query, *, schema=None, **kwargs): + """Method that returns the results of a find query as a + :class:`polars.DataFrame` instance. + + :Parameters: + - `collection`: Instance of :class:`~pymongo.collection.Collection` + against which to run the ``find`` operation. + - `query`: A mapping containing the query to use for the find operation. + - `schema` (optional): Instance of :class:`~pymongoarrow.schema.Schema`. + If the schema is not given, it will be inferred using the first + document in the result set. + + Additional keyword-arguments passed to this method will be passed + directly to the underlying ``find`` operation. + + :Returns: + An instance of :class:`polars.DataFrame`. + + .. versionadded:: 1.3 + """ + return _arrow_to_polars(find_arrow_all(collection, query, schema=schema, **kwargs)) + + +def aggregate_polars_all(collection, pipeline, *, schema=None, **kwargs): + """Method that returns the results of an aggregation pipeline as a + :class:`polars.DataFrame` instance. + + :Parameters: + - `collection`: Instance of :class:`~pymongo.collection.Collection` + against which to run the ``aggregate`` operation. + - `pipeline`: A list of aggregation pipeline stages. + - `schema` (optional): Instance of :class:`~pymongoarrow.schema.Schema`. + If the schema is not given, it will be inferred using the first + document in the result set. + + Additional keyword-arguments passed to this method will be passed + directly to the underlying ``aggregate`` operation. + + :Returns: + An instance of :class:`polars.DataFrame`. + + .. versionadded:: 1.3 + """ + return _arrow_to_polars(aggregate_arrow_all(collection, pipeline, schema=schema, **kwargs)) + + def _transform_bwe(bwe, offset): bwe["nInserted"] += offset for i in bwe["writeErrors"]: @@ -299,9 +387,11 @@ def _tabular_generator(tabular): for i in tabular.to_batches(): for row in i.to_pylist(): yield row - elif DataFrame is not None and isinstance(tabular, DataFrame): + elif isinstance(tabular, pd.DataFrame): for row in tabular.to_dict("records"): yield row + elif pl is not None and isinstance(tabular, pl.DataFrame): + yield from _tabular_generator(tabular.to_arrow()) elif isinstance(tabular, dict): iter_dict = {k: np.nditer(v) for k, v in tabular.items()} try: @@ -316,7 +406,7 @@ class _PandasNACodec(TypeEncoder): @property def python_type(self): - return NA.__class__ + return pd.NA.__class__ def transform_python(self, _): """Transform an NA object into 'None'""" @@ -341,8 +431,11 @@ def write(collection, tabular): tab_size = len(tabular) if isinstance(tabular, Table): _validate_schema(tabular.schema.types) - elif isinstance(tabular, DataFrame): + elif isinstance(tabular, pd.DataFrame): _validate_schema(ArrowSchema.from_pandas(tabular).types) + elif pl is not None and isinstance(tabular, pl.DataFrame): + tabular = tabular.to_arrow() # zero-copy in most cases; _tabular_generator converts to Arrow anyway + _validate_schema(tabular.schema.types) elif ( isinstance(tabular, dict) and len(tabular.values()) >= 1 diff --git a/bindings/python/pyproject.toml b/bindings/python/pyproject.toml index 5fc253a7..76c30989 100644 --- a/bindings/python/pyproject.toml +++ b/bindings/python/pyproject.toml @@ -9,11 +9,11 @@ requires = [ [project] name = "pymongoarrow" -description = '"Tools for using NumPy, Pandas and PyArrow with MongoDB"' +description = '"Tools for using NumPy, Pandas, Polars, and PyArrow with
MongoDB"' license = {text = "Apache License, Version 2.0"} authors = [{name = "Prashant Mital"}] maintainers = [{name = "MongoDB"}, {name = "Inc."}] -keywords = ["mongo", "mongodb", "pymongo", "arrow", "bson", "numpy", "pandas"] +keywords = ["mongo", "mongodb", "pymongo", "arrow", "bson", "numpy", "pandas", "polars"] classifiers = [ "Development Status :: 3 - Alpha", "Intended Audience :: Developers", @@ -38,7 +38,7 @@ dependencies = [ "pyarrow >=15.0,<15.1", "pymongo >=4.4,<5", "pandas >=1.3.5,<3", - "packaging >=23.2,<24" + "packaging >=23.2,<24", ] dynamic = ["version"] @@ -49,7 +49,7 @@ Source = "https://github.com/mongodb-labs/mongo-arrow/tree/main/bindings/python" Tracker = "https://jira.mongodb.org/projects/ARROW/issues" [project.optional-dependencies] -test = ["pytz", "pytest"] +test = ["pytz", "pytest", "polars"] [tool.setuptools] zip-safe = false @@ -83,6 +83,7 @@ LIBBSON_INSTALL_DIR = "./libbson" [tool.cibuildwheel.linux] archs = "x86_64 aarch64" manylinux-x86_64-image = "manylinux_2_28" +manylinux-aarch64-image = "manylinux_2_28" repair-wheel-command = [ "pip install \"auditwheel>=5,<6\"", "python addtags.py {wheel} {dest_dir}" diff --git a/bindings/python/test/test_arrow.py b/bindings/python/test/test_arrow.py index 628c93af..88973041 100644 --- a/bindings/python/test/test_arrow.py +++ b/bindings/python/test/test_arrow.py @@ -749,6 +749,23 @@ def test_large_list_type(self): schema, data = self._create_nested_data((large_list(int32()), list(range(3)))) self.round_trip(data, Schema(schema)) + def test_binary_types(self): + """Demonstrates that binary data is not yet supported. TODO [ARROW-214] + + Will demonstrate roundtrip behavior of Arrow DataType binary and large_binary. + """ + for btype in [pa.binary(), pa.large_binary()]: + with self.assertRaises(ValueError): + self.coll.drop() + aschema = pa.schema([("binary", btype)]) + table_in = pa.Table.from_pydict({"binary": [b"1", b"one"]}, schema=aschema) + write(self.coll, table_in) + table_out_none = find_arrow_all(self.coll, {}, schema=None) + mschema = Schema.from_arrow(aschema) + table_out_schema = find_arrow_all(self.coll, {}, schema=mschema) + self.assertTrue(table_out_schema.schema == table_in.schema) + self.assertTrue(table_out_none.equals(table_out_schema)) + class TestArrowExplicitApi(ArrowApiTestMixin, unittest.TestCase): def run_find(self, *args, **kwargs): diff --git a/bindings/python/test/test_polars.py b/bindings/python/test/test_polars.py new file mode 100644 index 00000000..6526869f --- /dev/null +++ b/bindings/python/test/test_polars.py @@ -0,0 +1,395 @@ +# Copyright 2021-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest +import unittest.mock as mock +import uuid +from datetime import datetime +from test import client_context +from test.utils import AllowListEventListener + +import bson +import polars as pl +import pyarrow as pa +from polars.testing import assert_frame_equal +from pyarrow import int32, int64 +from pymongo import DESCENDING, WriteConcern +from pymongo.collection import Collection + +from pymongoarrow import api +from pymongoarrow.api import Schema, aggregate_polars_all, find_arrow_all, find_polars_all, write +from pymongoarrow.errors import ArrowWriteError +from pymongoarrow.types import ( + _TYPE_NORMALIZER_FACTORY, + BinaryType, + CodeType, + Decimal128Type, + ObjectIdType, +) + + +class PolarsTestBase(unittest.TestCase): + @classmethod + def setUpClass(cls): + if not client_context.connected: + raise unittest.SkipTest("cannot connect to MongoDB") + cls.cmd_listener = AllowListEventListener("find", "aggregate") + cls.getmore_listener = AllowListEventListener("getMore") + cls.client = client_context.get_client( + event_listeners=[cls.getmore_listener, cls.cmd_listener], + uuidRepresentation="standard", + ) + + +class TestExplicitPolarsApi(PolarsTestBase): + @classmethod + def setUpClass(cls): + PolarsTestBase.setUpClass() + cls.schema = Schema({"_id": int32(), "data": int64()}) + cls.coll = cls.client.pymongoarrow_test.get_collection( + "test", write_concern=WriteConcern(w="majority") + ) + + def setUp(self): + """Insert simple use case data.""" + self.coll.drop() + self.coll.insert_many( + [ + {"_id": 1, "data": 10}, + {"_id": 2, "data": 20}, + {"_id": 3, "data": 30}, + {"_id": 4}, + ] + ) + self.cmd_listener.reset() + self.getmore_listener.reset() + + def round_trip(self, df_in, schema=None, **kwargs): + """Helper that asserts a written pl.DataFrame matches the one read back.""" + self.coll.drop() + res = write(self.coll, df_in) + self.assertEqual(len(df_in), res.raw_result["insertedCount"]) + df_out = find_polars_all(self.coll, {}, schema=schema, **kwargs) + pl.testing.assert_frame_equal(df_in, df_out) + return res + + def test_find_simple(self): + expected = pl.DataFrame( + data={ + "_id": pl.Series(values=[1, 2, 3, 4], dtype=pl.Int32), + "data": pl.Series(values=[10, 20, 30, None], dtype=pl.Int64), + } + ) + table = find_polars_all(self.coll, {}, schema=self.schema) + self.assertEqual(expected.dtypes, table.dtypes) + self.assertTrue(table.equals(expected)) + + expected = pl.DataFrame( + data={ + "_id": pl.Series(values=[4, 3], dtype=pl.Int32), + "data": pl.Series(values=[None, 30], dtype=pl.Int64), + } + ) + table = find_polars_all( + self.coll, + {"_id": {"$gt": 2}}, + schema=self.schema, + sort=[("_id", DESCENDING)], + ) + self.assertEqual(expected.dtypes, table.dtypes) + self.assertTrue(table.equals(expected)) + + find_cmd = self.cmd_listener.results["started"][-1] + self.assertEqual(find_cmd.command_name, "find") + self.assertEqual(find_cmd.command["projection"], {"_id": True, "data": True}) + + def test_aggregate_simple(self): + expected = pl.DataFrame( + data={ + "_id": pl.Series(values=[1, 2, 3, 4], dtype=pl.Int32), + "data": pl.Series(values=[20, 40, 60, None], dtype=pl.Int64), + } + ) + projection = {"_id": True, "data": {"$multiply": [2, "$data"]}} + table = aggregate_polars_all(self.coll, [{"$project": projection}], schema=self.schema) + self.assertTrue(table.equals(expected)) + + agg_cmd = self.cmd_listener.results["started"][-1] + self.assertEqual(agg_cmd.command_name, "aggregate") + assert len(agg_cmd.command["pipeline"]) == 2 +
self.assertEqual(agg_cmd.command["pipeline"][0]["$project"], projection) + self.assertEqual(agg_cmd.command["pipeline"][1]["$project"], {"_id": True, "data": True}) + + @mock.patch.object(Collection, "insert_many", side_effect=Collection.insert_many, autospec=True) + def test_write_batching(self, mock_insert_many): + # 100,040 rows should exceed the 100,000-document write batch limit, + # so two insert_many calls are expected. + data = pl.DataFrame(data={"_id": pl.Series(values=range(100040), dtype=pl.Int64)}) + self.round_trip(data, Schema(dict(_id=int64()))) + self.assertEqual(mock_insert_many.call_count, 2) + + def test_duplicate_key_error(self): + """Confirm expected error is raised, simple duplicate key case.""" + n = 3 + data = pl.DataFrame( + data={ + "_id": pl.Series(values=list(range(n)) * 2, dtype=pl.Int32), + "data": pl.Series(values=range(n * 2), dtype=pl.Int64), + } + ) + with self.assertRaises(ArrowWriteError): + try: + self.round_trip(data, Schema({"_id": int32(), "data": int64()})) + except ArrowWriteError as awe: + self.assertEqual(n, awe.details["writeErrors"][0]["index"]) + self.assertEqual(n, awe.details["nInserted"]) + raise awe + + def test_polars_types(self): + """Test round-trip of a DataFrame consisting of polars DataTypes. + + This does NOT include ExtensionTypes as Polars doesn't support them (yet). + """ + pl_typenames = ["Int64", "Int32", "Float64", "Datetime", "String", "Boolean"] + pl_types = [pl.Int64, pl.Int32, pl.Float64, pl.Datetime("ms"), pl.String, pl.Boolean] + pa_types = [ + pa.int64(), + pa.int32(), + pa.float64(), + pa.timestamp("ms"), + pa.string(), + pa.bool_(), + ] + + pl_schema = dict(zip(pl_typenames, pl_types)) + pa_schema = dict(zip(pl_typenames, pa_types)) + + data = { + "Int64": pl.Series([i for i in range(2)] + [None]), + "Int32": pl.Series([i for i in range(2)] + [None]), + "Float64": pl.Series([i for i in range(2)] + [None]), + "Datetime": pl.Series([datetime(1970 + i, 1, 1) for i in range(2)] + [None]), + "String": pl.Series([f"a{i}" for i in range(2)] + [None]), + "Boolean": pl.Series([True, False, None]), + } + + df_in = pl.from_dict(data=data, schema=pl_schema) + self.coll.drop() + write(self.coll, df_in) + df_out = find_polars_all(self.coll, {}, schema=Schema(pa_schema)) + pl.testing.assert_frame_equal(df_in, df_out.drop("_id")) + + def test_extension_types_fail(self): + """Confirm failure on ExtensionTypes for pl.from_arrow.""" + + for ext_type, data in ( + (ObjectIdType(), [bson.ObjectId().binary, bson.ObjectId().binary]), + (Decimal128Type(), [bson.Decimal128(str(i)).bid for i in range(2)]), + (CodeType(), [str(i) for i in range(2)]), + ): + table = pa.Table.from_pydict({"foo": data}, pa.schema({"foo": ext_type})) + with self.assertRaises(pl.exceptions.ComputeError): + pl.from_arrow(table) + + def test_auto_schema_succeeds_on_find(self): + """Confirms Polars can read the ObjectId Extension type.
+ + This is inserted automatically by Collection.insert_many. + Note that the inferred dtype of column "a" is Int32, and the + auto-generated ObjectId comes back as Binary. + """ + vals = [1, "2", True, 4] + data = [{"a": v} for v in vals] + + self.coll.drop() + self.coll.insert_many(data) # ObjectId autogenerated here + + df_out = find_polars_all(self.coll, {}) + self.assertEqual(df_out.columns, ["_id", "a"]) + self.assertEqual(df_out.shape, (4, 2)) + self.assertEqual(df_out.dtypes, [pl.Binary, pl.Int32]) + + def test_arrow_to_polars(self): + """Test reading Polars data from written Arrow data.""" + arrow_schema = {k.__name__: v(True) for k, v in _TYPE_NORMALIZER_FACTORY.items()} + arrow_table_in = pa.Table.from_pydict( + { + "Int64": [i for i in range(2)], + "float": [i for i in range(2)], + "datetime": [i for i in range(2)], + "str": [str(i) for i in range(2)], + "int": [i for i in range(2)], + "bool": [True, False], + "Binary": [b"1", b"23"], + "ObjectId": [bson.ObjectId().binary, bson.ObjectId().binary], + "Decimal128": [bson.Decimal128(str(i)).bid for i in range(2)], + "Code": [str(i) for i in range(2)], + }, + pa.schema(arrow_schema), + ) + + self.coll.drop() + res = write(self.coll, arrow_table_in) + self.assertEqual(len(arrow_table_in), res.raw_result["insertedCount"]) + df_out = find_polars_all(self.coll, query={}, schema=Schema(arrow_schema)) + + # Sanity check: compare with _cast_away_extension_types_on_table + arrow_cast = api._cast_away_extension_types_on_table(arrow_table_in) + assert_frame_equal(df_out, pl.from_arrow(arrow_cast)) + + def test_exceptions_for_unsupported_polar_types(self): + """Confirm the exceptions thrown are as expected. + + Currently this covers pl.Series and pl.Object. + It also tracks future changes in either package. + """ + + # pl.Series is not valid tabular data; write() raises ValueError + with self.assertRaises(ValueError) as exc: + pls = pl.Series(values=range(2)) + write(self.coll, pls) + self.assertTrue("Invalid tabular data object" in exc.exception.args[0]) + + # Polars has an Object type, similar in concept to the Pandas object dtype + class MyObject: + pass + + with self.assertRaises(pl.PolarsPanicError) as exc: + df_in = pl.DataFrame(data=[MyObject()] * 2) + write(self.coll, df_in) + self.assertTrue("not implemented" in exc.exception.args[0]) + + def test_polars_binary_type(self): + """Demonstrates that binary data is not yet supported. TODO [ARROW-214] + + Once supported, this will demonstrate round-trip behavior of the polars Binary type. + """ + # 1. _id added by MongoDB + self.coll.drop() + with self.assertRaises(ValueError): + df_in = pl.DataFrame({"Binary": [b"1", b"one"]}, schema={"Binary": pl.Binary}) + write(self.coll, df_in) + df_out = find_polars_all(self.coll, {}) + self.assertTrue(df_out.columns == ["_id", "Binary"]) + self.assertTrue(all([isinstance(c, pl.Binary) for c in df_out.dtypes])) + self.assertIsNone(assert_frame_equal(df_in, df_out.select("Binary"))) + # 2. Explicit Binary _id + self.coll.drop() + df_in = pl.DataFrame( + data=dict(_id=[b"0", b"1"], Binary=[b"1", b"one"]), + schema=dict(_id=pl.Binary, Binary=pl.Binary), + ) + write(self.coll, df_in) + df_out = find_polars_all(self.coll, {}) + self.assertEqual(df_out.columns, ["_id", "Binary"]) + self.assertTrue(all([isinstance(c, pl.Binary) for c in df_out.dtypes])) + self.assertIsNone(assert_frame_equal(df_in, df_out)) + # 3.
Explicit Int32 _id + self.coll.drop() + df_in = pl.DataFrame( + data={"_id": [0, 1], "Binary": [b"1", b"one"]}, + schema={"_id": pl.Int32, "Binary": pl.Binary}, + ) + write(self.coll, df_in) + df_out = find_polars_all(self.coll, {}) + self.assertEqual(df_out.columns, ["_id", "Binary"]) + out_types = df_out.dtypes + self.assertTrue(isinstance(out_types[0], pl.Int32)) + self.assertTrue(isinstance(out_types[1], pl.Binary)) + self.assertIsNone(assert_frame_equal(df_in, df_out)) + + def test_bson_types(self): + """Test reading Polars and Arrow data from written BSON data. + + This is meant to capture the use case of reading data in Python + that has been written to a DB from another language. + + Note that this tests only types currently supported by Arrow. + bson.Regex is not included, for example. + """ + + # 1. Use the pymongo / bson packages to create and write tabular data + self.coll.drop() + collection = self.coll + + data_type_map = [ + {"type": "int", "value": 42, "atype": pa.int32(), "ptype": pl.Int32}, + {"type": "long", "value": 1234567890123456789, "atype": pa.int64(), "ptype": pl.Int64}, + {"type": "double", "value": 10.5, "atype": pa.float64(), "ptype": pl.Float64}, + {"type": "string", "value": "hello world", "atype": pa.string(), "ptype": pl.String}, + {"type": "boolean", "value": True, "atype": pa.bool_(), "ptype": pl.Boolean}, + { + "type": "date", + "value": datetime(2025, 1, 21), + "atype": pa.timestamp("ms"), + "ptype": pl.Datetime, + }, + { + "type": "object", + "value": {"a": 1, "b": 2}, + "atype": pa.struct({"a": pa.int32(), "b": pa.int32()}), + "ptype": pl.Struct({"a": pl.Int32, "b": pl.Int32}), + }, + { + "type": "array", + "value": [1, 2, 3], + "atype": pa.list_(pa.int32()), + "ptype": pl.List(pl.Int32), + }, + { + "type": "bytes", + "value": b"\x00\x01\x02\x03\x04", + "atype": BinaryType(pa.binary()), + "ptype": pl.Binary, + }, + { + "type": "binary data", + "value": bson.Binary(b"\x00\x01\x02\x03\x04"), + "atype": BinaryType(pa.binary()), + "ptype": pl.Binary, + }, + { + "type": "object id", + "value": bson.ObjectId(), + "atype": ObjectIdType(), + "ptype": pl.Object, + }, + { + "type": "javascript", + "value": bson.Code("function() { return x; }"), + "atype": CodeType(), + "ptype": pl.String, + }, + { + "type": "decimal128", + "value": bson.Decimal128("10.99"), + "atype": Decimal128Type(), + "ptype": pl.Decimal, + }, + { + "type": "uuid", + "value": uuid.uuid4(), + "atype": BinaryType(pa.binary()), + "ptype": pl.Binary, + }, + ] + + # Iterate over types + for data_type in data_type_map: + collection.insert_one({"data_type": data_type["type"], "value": data_type["value"]}) + table = find_arrow_all(collection=collection, query={"data_type": data_type["type"]}) + assert table.shape == (1, 3) + assert table["value"].type == data_type["atype"] + try: + dfpl = pl.from_arrow(table.drop("_id")) + assert dfpl["value"].dtype == data_type["ptype"] + except pl.ComputeError: + assert isinstance(table["value"].type, pa.ExtensionType) diff --git a/bindings/python/tox.ini b/bindings/python/tox.ini index 6d1482eb..ea0e52ce 100644 --- a/bindings/python/tox.ini +++ b/bindings/python/tox.ini @@ -71,6 +71,7 @@ commands = python -c "from pymongoarrow.lib import libbson_version" [testenv:benchmark] +extras = test deps = asv commands =
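For completeness, here is a short sketch of the aggregation entry point exercised by ``test_aggregate_simple`` above. It is illustrative rather than part of the change: it assumes a reachable ``mongod``, and the collection name and pipeline are placeholders::

    import pymongo
    from pyarrow import int32, int64
    from pymongoarrow.api import Schema, aggregate_polars_all

    client = pymongo.MongoClient()  # assumes a reachable mongod
    coll = client.pymongoarrow_test.test

    schema = Schema({"_id": int32(), "data": int64()})
    pipeline = [{"$project": {"_id": True, "data": {"$multiply": [2, "$data"]}}}]

    # Returns a polars.DataFrame built from the Arrow result, with any
    # ExtensionTypes cast to their storage types first.
    df = aggregate_polars_all(coll, pipeline, schema=schema)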