3 changes: 2 additions & 1 deletion ci/scripts/python_util.sh
@@ -189,6 +189,7 @@ import $component.dbapi
 fi
 
 # --import-mode required, else tries to import from the source dir instead of installed package
-python -m pytest -vvx --import-mode append "${test_files[@]}"
+# set env var so that we don't skip tests if we somehow accidentally installed pyarrow
+env ADBC_NO_SKIP_TESTS=1 python -m pytest -vvx --import-mode append "${test_files[@]}"
 done
 }
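
The extra environment variable exists so that, if pyarrow somehow ends up installed in the verification environment, tests that would otherwise be skipped still run and can fail loudly. A minimal sketch of how a test helper might honor such a flag (hypothetical; the test code itself is not part of this diff):

import os

import pytest


def maybe_skip(reason: str) -> None:
    # When CI exports ADBC_NO_SKIP_TESTS=1, turn would-be skips into hard
    # failures so an unexpectedly installed or missing optional dependency
    # is noticed instead of silently shrinking the test run.
    if os.environ.get("ADBC_NO_SKIP_TESTS") == "1":
        pytest.fail(f"Skips are disabled but test wanted to skip: {reason}")
    pytest.skip(reason)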
13 changes: 8 additions & 5 deletions docs/source/conf.py
@@ -56,15 +56,18 @@
 
 
 def on_missing_reference(app, env, node, contnode):
-    if str(contnode) == "polars.DataFrame":
+    if str(contnode) in {
         # Polars does something odd with Sphinx such that polars.DataFrame
         # isn't xrefable; suppress the warning.
-        return contnode
-    elif str(contnode) == "CapsuleType":
+        "polars.DataFrame",
         # CapsuleType is only in 3.13+
+        "CapsuleType",
+        # Internal API
+        "DbapiBackend",
+    }:
         return contnode
-    else:
-        return None
+
+    return None
 
 
 def setup(app):
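
The hook now folds the special cases into one set: references Sphinx cannot resolve (polars.DataFrame, the 3.13-only CapsuleType, and the internal DbapiBackend) are suppressed instead of emitted as warnings. The body of setup() is collapsed in this diff; a handler like this is normally registered against the missing-reference event, roughly as follows (an assumption, not shown in the diff):

def setup(app):
    app.connect("missing-reference", on_missing_reference)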
186 changes: 186 additions & 0 deletions python/adbc_driver_manager/adbc_driver_manager/_dbapi_backend.py
@@ -0,0 +1,186 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""
Backend-specific operations for the DB-API layer.

These are mostly functions that convert Python types to/from Arrow types.
They are abstracted so that we can support multiple backends like PyArrow,
polars, and nanoarrow.
"""

import abc
import typing

from . import _lib


class DbapiBackend(abc.ABC):
"""
Python/Arrow type conversions that the DB-API layer needs.

The return types can and should vary based on the backend.
"""

@abc.abstractmethod
def convert_bind_parameters(self, parameters: typing.Any) -> typing.Any:
"""Convert an arbitrary Python object into bind parameters.

Parameters
----------
parameters
A sequence of bind parameters. For instance: a tuple, where each
item is a bind parameter in sequence.

Returns
-------
parameters : CapsuleType
This should be an Arrow stream capsule or an object implementing
the Arrow PyCapsule interface.

See Also
--------
https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html

"""
...

@abc.abstractmethod
def convert_executemany_parameters(self, parameters: typing.Any) -> typing.Any:
"""Convert an arbitrary Python sequence into bind parameters.

Parameters
----------
parameters
A sequence of bind parameters. For instance: an iterable of
tuples, where each tuple is a row of bind parameters.

Returns
-------
parameters : CapsuleType
This should be an Arrow stream capsule or an object implementing
the Arrow PyCapsule interface.

See Also
--------
https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html

"""
...

@abc.abstractmethod
def import_array_stream(self, handle: _lib.ArrowArrayStreamHandle) -> typing.Any:
"""Import an Arrow stream."""
...

@abc.abstractmethod
def import_schema(self, handle: _lib.ArrowSchemaHandle) -> typing.Any:
"""Import an Arrow schema."""
...


_ALL_BACKENDS: list[DbapiBackend] = []


def default_backend() -> DbapiBackend:
    # Backends register themselves in _ALL_BACKENDS in import order; the last
    # one wins, and PyArrow is appended last so it is the default if installed.
    return _ALL_BACKENDS[-1]


class _NoOpBackend(DbapiBackend):
    def convert_bind_parameters(self, parameters: typing.Any) -> typing.Any:
        raise _lib.ProgrammingError(
            "This API requires PyArrow or another suitable backend to be installed",
            status_code=_lib.AdbcStatusCode.INVALID_STATE,
        )

    def convert_executemany_parameters(self, parameters: typing.Any) -> typing.Any:
        raise _lib.ProgrammingError(
            "This API requires PyArrow or another suitable backend to be installed",
            status_code=_lib.AdbcStatusCode.INVALID_STATE,
        )

    def import_array_stream(
        self, handle: _lib.ArrowArrayStreamHandle
    ) -> _lib.ArrowArrayStreamHandle:
        return handle

    def import_schema(self, handle: _lib.ArrowSchemaHandle) -> _lib.ArrowSchemaHandle:
        return handle


_ALL_BACKENDS.append(_NoOpBackend())

try:
    import polars

    class _PolarsBackend(DbapiBackend):
        def convert_bind_parameters(self, parameters: typing.Any) -> polars.DataFrame:
            return polars.DataFrame(
                {str(col_idx): x for col_idx, x in enumerate(parameters)},
            )

        def convert_executemany_parameters(self, parameters: typing.Any) -> typing.Any:
            return polars.DataFrame(
                {
                    str(col_idx): x
                    for col_idx, x in enumerate(map(list, zip(*parameters)))
                },
            )

        def import_array_stream(
            self, handle: _lib.ArrowArrayStreamHandle
        ) -> typing.Any:
            return polars.from_arrow(handle)

        def import_schema(self, handle: _lib.ArrowSchemaHandle) -> typing.Any:
            raise _lib.NotSupportedError("Polars does not support __arrow_c_schema__")

    _ALL_BACKENDS.append(_PolarsBackend())
except ImportError:
    pass

# Keep PyArrow at the end so it stays default
try:
    import pyarrow

    class _PyArrowBackend(DbapiBackend):
        def convert_bind_parameters(self, parameters: typing.Any) -> typing.Any:
            return pyarrow.record_batch(
                [[param_value] for param_value in parameters],
                names=[str(i) for i in range(len(parameters))],
            )

        def convert_executemany_parameters(self, parameters: typing.Any) -> typing.Any:
            return pyarrow.RecordBatch.from_pydict(
                {
                    str(col_idx): pyarrow.array(x)
                    for col_idx, x in enumerate(map(list, zip(*parameters)))
                },
            )

        def import_array_stream(
            self, handle: _lib.ArrowArrayStreamHandle
        ) -> pyarrow.RecordBatchReader:
            return pyarrow.RecordBatchReader._import_from_c(handle.address)

        def import_schema(self, handle: _lib.ArrowSchemaHandle) -> pyarrow.Schema:
            return pyarrow.schema(handle)

    _ALL_BACKENDS.append(_PyArrowBackend())

except ImportError:
    pass
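
Backends register themselves in import order and default_backend() returns the last entry, so PyArrow wins when it is importable, polars is used otherwise, and the no-op backend only raises once a conversion is actually required. A third-party backend only has to implement the four abstract methods; the sketch below is illustrative (the class name and the eager read_all() behavior are not part of this diff) and reuses only calls that appear elsewhere in the change:

import typing

import pyarrow

from adbc_driver_manager import _lib
from adbc_driver_manager._dbapi_backend import DbapiBackend


class EagerPyArrowBackend(DbapiBackend):
    """Like the built-in PyArrow backend, but materializes streams as Tables."""

    def convert_bind_parameters(self, parameters: typing.Any) -> typing.Any:
        # One row; one column per positional parameter, named "0", "1", ...
        return pyarrow.record_batch(
            [[value] for value in parameters],
            names=[str(i) for i in range(len(parameters))],
        )

    def convert_executemany_parameters(self, parameters: typing.Any) -> typing.Any:
        # Transpose the row-oriented parameter sequence into columns.
        return pyarrow.RecordBatch.from_pydict(
            {str(i): pyarrow.array(list(col)) for i, col in enumerate(zip(*parameters))}
        )

    def import_array_stream(self, handle: _lib.ArrowArrayStreamHandle) -> pyarrow.Table:
        return pyarrow.RecordBatchReader._import_from_c(handle.address).read_all()

    def import_schema(self, handle: _lib.ArrowSchemaHandle) -> pyarrow.Schema:
        return pyarrow.schema(handle)

Per the Connection.__init__ change in dbapi.py below, an instance of such a class can be passed explicitly instead of relying on the registration order.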
53 changes: 23 additions & 30 deletions python/adbc_driver_manager/adbc_driver_manager/dbapi.py
@@ -58,7 +58,7 @@
 
 import adbc_driver_manager
 
-from . import _lib
+from . import _dbapi_backend, _lib
 from ._lib import _blocking_call
 
 if typing.TYPE_CHECKING:
@@ -303,7 +303,12 @@ def __init__(
         conn_kwargs: Optional[Dict[str, str]] = None,
         *,
         autocommit=False,
+        backend: Optional[_dbapi_backend.DbapiBackend] = None,
     ) -> None:
+        if backend is None:
+            backend = _dbapi_backend.default_backend()
+
+        self._backend = backend
         self._closed = False
         if isinstance(db, _SharedDatabase):
             self._db = db.clone()
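
The new keyword lets callers pin a specific backend instead of relying on default_backend(). A rough sketch (the leading parameters of __init__ are collapsed in this hunk, so their order here is an assumption, and "libexample_driver.so" is a placeholder):

import adbc_driver_manager
from adbc_driver_manager import _dbapi_backend, dbapi

db = adbc_driver_manager.AdbcDatabase(driver="libexample_driver.so")
conn = adbc_driver_manager.AdbcConnection(db)
# Explicitly pass a backend (here the same one default resolution would pick).
connection = dbapi.Connection(db, conn, backend=_dbapi_backend.default_backend())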
@@ -455,8 +460,6 @@ def adbc_get_objects(
         -----
         This is an extension and not part of the DBAPI standard.
         """
-        _requires_pyarrow()
-
         if depth in ("all", "columns"):
             c_depth = _lib.GetObjectsDepth.ALL
         elif depth == "catalogs":
@@ -479,7 +482,7 @@
             ),
             self._conn.cancel,
         )
-        return pyarrow.RecordBatchReader._import_from_c(handle.address)
+        return self._backend.import_array_stream(handle)

     def adbc_get_table_schema(
         self,
@@ -504,8 +507,6 @@ def adbc_get_table_schema(
         -----
         This is an extension and not part of the DBAPI standard.
         """
-        _requires_pyarrow()
-
         handle = _blocking_call(
             self._conn.get_table_schema,
             (
@@ -516,7 +517,7 @@
             {},
             self._conn.cancel,
         )
-        return pyarrow.Schema._import_from_c(handle.address)
+        return self._backend.import_schema(handle)
 
     def adbc_get_table_types(self) -> List[str]:
         """
@@ -706,11 +707,7 @@ def _prepare_execute(self, operation, parameters=None) -> None:
         if _is_arrow_data(parameters):
             self._bind(parameters)
         elif parameters:
-            _requires_pyarrow()
-            rb = pyarrow.record_batch(
-                [[param_value] for param_value in parameters],
-                names=[str(i) for i in range(len(parameters))],
-            )
+            rb = self._conn._backend.convert_bind_parameters(parameters)
             self._bind(rb)
 
     def execute(self, operation: Union[bytes, str], parameters=None) -> None:
@@ -762,18 +759,14 @@ def executemany(self, operation: Union[bytes, str], seq_of_parameters) -> None:
         if _is_arrow_data(seq_of_parameters):
             arrow_parameters = seq_of_parameters
         elif seq_of_parameters:
-            _requires_pyarrow()
-            arrow_parameters = pyarrow.RecordBatch.from_pydict(
-                {
-                    str(col_idx): pyarrow.array(x)
-                    for col_idx, x in enumerate(map(list, zip(*seq_of_parameters)))
-                },
+            arrow_parameters = self._conn._backend.convert_executemany_parameters(
+                seq_of_parameters
             )
         else:
-            _requires_pyarrow()
-            arrow_parameters = pyarrow.record_batch([])
+            arrow_parameters = None
 
-        self._bind(arrow_parameters)
+        if arrow_parameters is not None:
+            self._bind(arrow_parameters)
         self._rowcount = _blocking_call(
             self._stmt.execute_update, (), {}, self._stmt.cancel
         )
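
With the conversion delegated to the backend, what executemany() binds depends on which backend is active. A small illustration, assuming PyArrow is installed so default_backend() resolves to the PyArrow backend:

from adbc_driver_manager._dbapi_backend import default_backend

backend = default_backend()  # _PyArrowBackend when pyarrow is importable

rows = [(1, "a"), (2, "b"), (3, "c")]
batch = backend.convert_executemany_parameters(rows)
# Rows are transposed into columns named "0", "1", ...; here: 3 rows, 2 columns.
print(batch.num_rows, batch.num_columns)

single = backend.convert_bind_parameters((42, "x"))
# A single bind produces a one-row batch with one column per parameter.
print(single.num_rows, single.num_columns)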
@@ -958,8 +951,7 @@ def adbc_ingest(
             self._stmt.bind_stream(data)
         elif _lib.is_pycapsule(data, b"arrow_array_stream"):
             self._stmt.bind_stream(data)
-        else:
-            _requires_pyarrow()
+        elif _has_pyarrow:
             if isinstance(data, pyarrow.dataset.Dataset):
                 data = typing.cast(pyarrow.dataset.Dataset, data).scanner().to_reader()
             elif isinstance(data, pyarrow.dataset.Scanner):
@@ -974,6 +966,8 @@
             else:
                 # Should be impossible from above but let's be explicit
                 raise TypeError(f"Cannot bind {type(data)}")
+        else:
+            raise TypeError(f"Cannot bind {type(data)}")
 
         self._last_query = None
         return _blocking_call(self._stmt.execute_update, (), {}, self._stmt.cancel)
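
The rewritten dispatch means adbc_ingest() no longer demands PyArrow for data that already speaks the Arrow C data interface: capsules and capsule-producing objects are bound directly, the PyArrow-specific conversions only run when pyarrow is importable, and anything else now fails with a TypeError. For example, a polars DataFrame implements __arrow_c_stream__ and should be handled by the Arrow-protocol branches above (partly collapsed in this hunk) rather than the pyarrow-only branch. A sketch, where cursor stands in for an open DB-API cursor and the table name is arbitrary:

import polars

df = polars.DataFrame({"id": [1, 2, 3], "name": ["a", "b", "c"]})
# Bound via the Arrow PyCapsule interface, without pyarrow being involved.
cursor.adbc_ingest("example_table", df)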
@@ -999,14 +993,13 @@
         -----
         This is an extension and not part of the DBAPI standard.
         """
-        _requires_pyarrow()
         self._clear()
         self._prepare_execute(operation, parameters)
         partitions, schema_handle, self._rowcount = _blocking_call(
             self._stmt.execute_partitions, (), {}, self._stmt.cancel
         )
         if schema_handle and schema_handle.address:
-            schema = pyarrow.Schema._import_from_c(schema_handle.address)
+            schema = self._conn._backend.import_schema(schema_handle)
         else:
             schema = None
         return partitions, schema
Expand All @@ -1024,11 +1017,10 @@ def adbc_execute_schema(self, operation, parameters=None) -> "pyarrow.Schema":
-----
This is an extension and not part of the DBAPI standard.
"""
_requires_pyarrow()
self._clear()
self._prepare_execute(operation, parameters)
schema = _blocking_call(self._stmt.execute_schema, (), {}, self._stmt.cancel)
return pyarrow.Schema._import_from_c(schema.address)
return self._conn._backend.import_schema(schema)

def adbc_prepare(self, operation: Union[bytes, str]) -> Optional["pyarrow.Schema"]:
"""
Expand All @@ -1048,7 +1040,6 @@ def adbc_prepare(self, operation: Union[bytes, str]) -> Optional["pyarrow.Schema
-----
This is an extension and not part of the DBAPI standard.
"""
_requires_pyarrow()
self._clear()
self._prepare_execute(operation)

Expand All @@ -1058,7 +1049,7 @@ def adbc_prepare(self, operation: Union[bytes, str]) -> Optional["pyarrow.Schema
)
except NotSupportedError:
return None
return pyarrow.Schema._import_from_c(handle.address)
return self._conn._backend.import_schema(handle)

def adbc_read_partition(self, partition: bytes) -> None:
"""
@@ -1218,7 +1209,9 @@ def fetch_arrow(self) -> _lib.ArrowArrayStreamHandle:
 class _RowIterator(_Closeable):
     """Track state needed to iterate over the result set."""
 
-    def __init__(self, stmt, handle: _lib.ArrowArrayStreamHandle) -> None:
+    def __init__(
+        self, stmt: _lib.AdbcStatement, handle: _lib.ArrowArrayStreamHandle
+    ) -> None:
         self._stmt = stmt
         self._handle: Optional[_lib.ArrowArrayStreamHandle] = handle
         self._reader: Optional["_reader.AdbcRecordBatchReader"] = None