
Commit e2da620

feat(python/adbc_driver_manager): support more sans PyArrow
Fixes #2827.
1 parent 0000e8d commit e2da620

5 files changed: +293, -36 lines
ci/scripts/python_util.sh

Lines changed: 2 additions & 1 deletion
@@ -189,6 +189,7 @@ import $component.dbapi
         fi
 
         # --import-mode required, else tries to import from the source dir instead of installed package
-        python -m pytest -vvx --import-mode append "${test_files[@]}"
+        # set env var so that we don't skip tests if we somehow accidentally installed pyarrow
+        env ADBC_NO_SKIP_TESTS=1 python -m pytest -vvx --import-mode append "${test_files[@]}"
     done
 }
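The ADBC_NO_SKIP_TESTS=1 guard makes skips fail loudly in CI: if PyArrow somehow ends up installed in a job that is supposed to exclude it (or vice versa), the affected tests should error instead of silently skipping. As a sketch of the pattern on the test side (the helper below is hypothetical, not the actual ADBC test code):

    import os

    import pytest


    def maybe_skip(reason: str) -> None:
        # Hypothetical helper: skip the calling test, unless the environment
        # forbids skips because CI built an exact dependency set.
        if os.environ.get("ADBC_NO_SKIP_TESTS"):
            pytest.fail(f"would have skipped, but ADBC_NO_SKIP_TESTS is set: {reason}")
        pytest.skip(reason)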

docs/source/conf.py

Lines changed: 8 additions & 5 deletions
@@ -56,15 +56,18 @@
 
 
 def on_missing_reference(app, env, node, contnode):
-    if str(contnode) == "polars.DataFrame":
+    if str(contnode) in {
         # Polars does something odd with Sphinx such that polars.DataFrame
         # isn't xrefable; suppress the warning.
-        return contnode
-    elif str(contnode) == "CapsuleType":
+        "polars.DataFrame",
         # CapsuleType is only in 3.13+
+        "CapsuleType",
+        # Internal API
+        "DbapiBackend",
+    }:
         return contnode
-    else:
-        return None
+
+    return None
 
 
 def setup(app):
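For reference, on_missing_reference is a standard Sphinx "missing-reference" event handler; a minimal sketch of how such a handler gets registered (the actual body of this conf.py's setup() is not shown in the diff):

    def setup(app):
        # "missing-reference" fires when Sphinx cannot resolve a cross-reference;
        # returning contnode from the handler suppresses the nitpick warning.
        app.connect("missing-reference", on_missing_reference)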
python/adbc_driver_manager/adbc_driver_manager/_dbapi_backend.py (new file)

Lines changed: 186 additions & 0 deletions
@@ -0,0 +1,186 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

"""
Backend-specific operations for the DB-API layer.

These are mostly functions that convert Python types to/from Arrow types.
They are abstracted so that we can support multiple backends like PyArrow,
polars, and nanoarrow.
"""

import abc
import typing

from . import _lib


class DbapiBackend(abc.ABC):
    """
    Python/Arrow type conversions that the DB-API layer needs.

    The return types can and should vary based on the backend.
    """

    @abc.abstractmethod
    def convert_bind_parameters(self, parameters: typing.Any) -> typing.Any:
        """Convert an arbitrary Python object into bind parameters.

        Parameters
        ----------
        parameters
            A sequence of bind parameters. For instance: a tuple, where each
            item is a bind parameter in sequence.

        Returns
        -------
        parameters : CapsuleType
            This should be an Arrow stream capsule or an object implementing
            the Arrow PyCapsule interface.

        See Also
        --------
        https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html

        """
        ...

    @abc.abstractmethod
    def convert_executemany_parameters(self, parameters: typing.Any) -> typing.Any:
        """Convert an arbitrary Python sequence into bind parameters.

        Parameters
        ----------
        parameters
            A sequence of bind parameters. For instance: an iterable of
            tuples, where each tuple is a row of bind parameters.

        Returns
        -------
        parameters : CapsuleType
            This should be an Arrow stream capsule or an object implementing
            the Arrow PyCapsule interface.

        See Also
        --------
        https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html

        """
        ...

    @abc.abstractmethod
    def import_array_stream(self, handle: _lib.ArrowArrayStreamHandle) -> typing.Any:
        """Import an Arrow stream."""
        ...

    @abc.abstractmethod
    def import_schema(self, handle: _lib.ArrowSchemaHandle) -> typing.Any:
        """Import an Arrow schema."""
        ...


_ALL_BACKENDS: list[DbapiBackend] = []


def default_backend() -> DbapiBackend:
    return _ALL_BACKENDS[-1]


class _NoOpBackend(DbapiBackend):
    def convert_bind_parameters(self, parameters: typing.Any) -> typing.Any:
        raise _lib.ProgrammingError(
            "This API requires PyArrow or another suitable backend to be installed",
            status_code=_lib.AdbcStatusCode.INVALID_STATE,
        )

    def convert_executemany_parameters(self, parameters: typing.Any) -> typing.Any:
        raise _lib.ProgrammingError(
            "This API requires PyArrow or another suitable backend to be installed",
            status_code=_lib.AdbcStatusCode.INVALID_STATE,
        )

    def import_array_stream(
        self, handle: _lib.ArrowArrayStreamHandle
    ) -> _lib.ArrowArrayStreamHandle:
        return handle

    def import_schema(self, handle: _lib.ArrowSchemaHandle) -> _lib.ArrowSchemaHandle:
        return handle


_ALL_BACKENDS.append(_NoOpBackend())

try:
    import polars

    class _PolarsBackend(DbapiBackend):
        def convert_bind_parameters(self, parameters: typing.Any) -> polars.DataFrame:
            return polars.DataFrame(
                {str(col_idx): x for col_idx, x in enumerate(parameters)},
            )

        def convert_executemany_parameters(self, parameters: typing.Any) -> typing.Any:
            return polars.DataFrame(
                {
                    str(col_idx): x
                    for col_idx, x in enumerate(map(list, zip(*parameters)))
                },
            )

        def import_array_stream(
            self, handle: _lib.ArrowArrayStreamHandle
        ) -> typing.Any:
            return polars.from_arrow(handle)

        def import_schema(self, handle: _lib.ArrowSchemaHandle) -> typing.Any:
            raise _lib.NotSupportedError("Polars does not support __arrow_c_schema__")

    _ALL_BACKENDS.append(_PolarsBackend())
except ImportError:
    pass

# Keep PyArrow at the end so it stays default
try:
    import pyarrow

    class _PyArrowBackend(DbapiBackend):
        def convert_bind_parameters(self, parameters: typing.Any) -> typing.Any:
            return pyarrow.record_batch(
                [[param_value] for param_value in parameters],
                names=[str(i) for i in range(len(parameters))],
            )

        def convert_executemany_parameters(self, parameters: typing.Any) -> typing.Any:
            return pyarrow.RecordBatch.from_pydict(
                {
                    str(col_idx): pyarrow.array(x)
                    for col_idx, x in enumerate(map(list, zip(*parameters)))
                },
            )

        def import_array_stream(
            self, handle: _lib.ArrowArrayStreamHandle
        ) -> pyarrow.RecordBatchReader:
            return pyarrow.RecordBatchReader._import_from_c(handle.address)

        def import_schema(self, handle: _lib.ArrowSchemaHandle) -> pyarrow.Schema:
            return pyarrow.schema(handle)

    _ALL_BACKENDS.append(_PyArrowBackend())

except ImportError:
    pass
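Backends register themselves at import time, and default_backend() returns the last entry in _ALL_BACKENDS, so PyArrow wins when installed, polars comes next, and the no-op backend (which hands back raw ADBC handles and refuses parameter conversion) is the fallback. A small runnable sketch of the two conversion entry points, assuming this commit plus PyArrow or polars is installed (under the no-op fallback both calls raise ProgrammingError):

    from adbc_driver_manager import _dbapi_backend

    backend = _dbapi_backend.default_backend()

    # One parameter row -> a one-row Arrow-compatible table with columns "0", "1", ...
    one_row = backend.convert_bind_parameters((42, "hello"))

    # Many rows -> columns, via the zip(*rows) pivot both backends use.
    many_rows = backend.convert_executemany_parameters([(1, "a"), (2, "b")])

    print(type(one_row), type(many_rows))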

python/adbc_driver_manager/adbc_driver_manager/dbapi.py

Lines changed: 23 additions & 30 deletions
@@ -58,7 +58,7 @@
 
 import adbc_driver_manager
 
-from . import _lib
+from . import _dbapi_backend, _lib
 from ._lib import _blocking_call
 
 if typing.TYPE_CHECKING:
@@ -303,7 +303,12 @@ def __init__(
         conn_kwargs: Optional[Dict[str, str]] = None,
         *,
         autocommit=False,
+        backend: Optional[_dbapi_backend.DbapiBackend] = None,
     ) -> None:
+        if backend is None:
+            backend = _dbapi_backend.default_backend()
+
+        self._backend = backend
         self._closed = False
         if isinstance(db, _SharedDatabase):
             self._db = db.clone()
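The resolved backend is stored once on the connection, and every later conversion or import dispatches through self._backend. A sketch of injecting a backend explicitly (hedged: `db` below is a placeholder for an already-opened database handle, and the public connect() helpers remain the normal entry point):

    from adbc_driver_manager import _dbapi_backend, dbapi

    # Placeholder call site, not runnable as-is: `db` stands for an opened
    # database handle; shown only to illustrate the new keyword-only argument.
    conn = dbapi.Connection(db, backend=_dbapi_backend.default_backend())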
@@ -455,8 +460,6 @@ def adbc_get_objects(
         -----
         This is an extension and not part of the DBAPI standard.
         """
-        _requires_pyarrow()
-
         if depth in ("all", "columns"):
             c_depth = _lib.GetObjectsDepth.ALL
         elif depth == "catalogs":
@@ -479,7 +482,7 @@
             ),
             self._conn.cancel,
         )
-        return pyarrow.RecordBatchReader._import_from_c(handle.address)
+        return self._backend.import_array_stream(handle)
 
     def adbc_get_table_schema(
         self,
@@ -504,8 +507,6 @@ def adbc_get_table_schema(
         -----
         This is an extension and not part of the DBAPI standard.
         """
-        _requires_pyarrow()
-
         handle = _blocking_call(
             self._conn.get_table_schema,
             (
@@ -516,7 +517,7 @@
             {},
             self._conn.cancel,
         )
-        return pyarrow.Schema._import_from_c(handle.address)
+        return self._backend.import_schema(handle)
 
     def adbc_get_table_types(self) -> List[str]:
         """
@@ -706,11 +707,7 @@ def _prepare_execute(self, operation, parameters=None) -> None:
         if _is_arrow_data(parameters):
             self._bind(parameters)
         elif parameters:
-            _requires_pyarrow()
-            rb = pyarrow.record_batch(
-                [[param_value] for param_value in parameters],
-                names=[str(i) for i in range(len(parameters))],
-            )
+            rb = self._conn._backend.convert_bind_parameters(parameters)
             self._bind(rb)
 
     def execute(self, operation: Union[bytes, str], parameters=None) -> None:
@@ -762,18 +759,14 @@ def executemany(self, operation: Union[bytes, str], seq_of_parameters) -> None:
         if _is_arrow_data(seq_of_parameters):
             arrow_parameters = seq_of_parameters
         elif seq_of_parameters:
-            _requires_pyarrow()
-            arrow_parameters = pyarrow.RecordBatch.from_pydict(
-                {
-                    str(col_idx): pyarrow.array(x)
-                    for col_idx, x in enumerate(map(list, zip(*seq_of_parameters)))
-                },
+            arrow_parameters = self._conn._backend.convert_executemany_parameters(
+                seq_of_parameters
             )
         else:
-            _requires_pyarrow()
-            arrow_parameters = pyarrow.record_batch([])
+            arrow_parameters = None
 
-        self._bind(arrow_parameters)
+        if arrow_parameters is not None:
+            self._bind(arrow_parameters)
         self._rowcount = _blocking_call(
             self._stmt.execute_update, (), {}, self._stmt.cancel
         )
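The row-to-column pivot that used to live inline here is now each backend's job; for reference, the transformation both the polars and PyArrow backends apply to a sequence of parameter rows:

    rows = [(1, "a"), (2, "b"), (3, "c")]
    columns = list(map(list, zip(*rows)))
    # columns == [[1, 2, 3], ["a", "b", "c"]], then keyed as {"0": ..., "1": ...}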
@@ -958,8 +951,7 @@ def adbc_ingest(
             self._stmt.bind_stream(data)
         elif _lib.is_pycapsule(data, b"arrow_array_stream"):
             self._stmt.bind_stream(data)
-        else:
-            _requires_pyarrow()
+        elif _has_pyarrow:
             if isinstance(data, pyarrow.dataset.Dataset):
                 data = typing.cast(pyarrow.dataset.Dataset, data).scanner().to_reader()
             elif isinstance(data, pyarrow.dataset.Scanner):
@@ -974,6 +966,8 @@
             else:
                 # Should be impossible from above but let's be explicit
                 raise TypeError(f"Cannot bind {type(data)}")
+        else:
+            raise TypeError(f"Cannot bind {type(data)}")
 
         self._last_query = None
         return _blocking_call(self._stmt.execute_update, (), {}, self._stmt.cancel)
@@ -999,14 +993,13 @@ def adbc_execute_partitions(
         -----
         This is an extension and not part of the DBAPI standard.
         """
-        _requires_pyarrow()
         self._clear()
         self._prepare_execute(operation, parameters)
         partitions, schema_handle, self._rowcount = _blocking_call(
             self._stmt.execute_partitions, (), {}, self._stmt.cancel
         )
         if schema_handle and schema_handle.address:
-            schema = pyarrow.Schema._import_from_c(schema_handle.address)
+            schema = self._conn._backend.import_schema(schema_handle)
         else:
             schema = None
         return partitions, schema
@@ -1024,11 +1017,10 @@ def adbc_execute_schema(self, operation, parameters=None) -> "pyarrow.Schema":
         -----
         This is an extension and not part of the DBAPI standard.
         """
-        _requires_pyarrow()
         self._clear()
         self._prepare_execute(operation, parameters)
         schema = _blocking_call(self._stmt.execute_schema, (), {}, self._stmt.cancel)
-        return pyarrow.Schema._import_from_c(schema.address)
+        return self._conn._backend.import_schema(schema)
 
     def adbc_prepare(self, operation: Union[bytes, str]) -> Optional["pyarrow.Schema"]:
         """
@@ -1048,7 +1040,6 @@ def adbc_prepare(self, operation: Union[bytes, str]) -> Optional["pyarrow.Schema
         -----
         This is an extension and not part of the DBAPI standard.
         """
-        _requires_pyarrow()
         self._clear()
         self._prepare_execute(operation)
 
@@ -1058,7 +1049,7 @@
             )
         except NotSupportedError:
             return None
-        return pyarrow.Schema._import_from_c(handle.address)
+        return self._conn._backend.import_schema(handle)
 
     def adbc_read_partition(self, partition: bytes) -> None:
         """
@@ -1218,7 +1209,9 @@ def fetch_arrow(self) -> _lib.ArrowArrayStreamHandle:
 class _RowIterator(_Closeable):
     """Track state needed to iterate over the result set."""
 
-    def __init__(self, stmt, handle: _lib.ArrowArrayStreamHandle) -> None:
+    def __init__(
+        self, stmt: _lib.AdbcStatement, handle: _lib.ArrowArrayStreamHandle
+    ) -> None:
         self._stmt = stmt
         self._handle: Optional[_lib.ArrowArrayStreamHandle] = handle
         self._reader: Optional["_reader.AdbcRecordBatchReader"] = None
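Net effect of the dbapi.py changes: entry points that used to hard-require PyArrow now return whatever the active backend produces, whether a pyarrow.RecordBatchReader, a polars.DataFrame, or a raw Arrow stream handle that still speaks the PyCapsule interface. A hedged end-to-end sketch (adbc_driver_sqlite is a separate package, used here only as a convenient in-memory driver):

    import adbc_driver_sqlite.dbapi

    conn = adbc_driver_sqlite.dbapi.connect()
    try:
        # The result type depends on the installed backend: RecordBatchReader
        # with PyArrow, DataFrame with polars, a stream handle with neither.
        objects = conn.adbc_get_objects(depth="catalogs")
        print(type(objects))
    finally:
        conn.close()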
