@@ -58,7 +58,7 @@
 
 import adbc_driver_manager
 
-from . import _lib
+from . import _dbapi_backend, _lib
 from ._lib import _blocking_call
 
 if typing.TYPE_CHECKING:
@@ -303,7 +303,12 @@ def __init__(
         conn_kwargs: Optional[Dict[str, str]] = None,
         *,
         autocommit=False,
+        backend: Optional[_dbapi_backend.DbapiBackend] = None,
     ) -> None:
+        if backend is None:
+            backend = _dbapi_backend.default_backend()
+
+        self._backend = backend
         self._closed = False
         if isinstance(db, _SharedDatabase):
             self._db = db.clone()
@@ -455,8 +460,6 @@ def adbc_get_objects(
         -----
         This is an extension and not part of the DBAPI standard.
         """
-        _requires_pyarrow()
-
         if depth in ("all", "columns"):
             c_depth = _lib.GetObjectsDepth.ALL
         elif depth == "catalogs":
@@ -479,7 +482,7 @@ def adbc_get_objects(
             ),
             self._conn.cancel,
         )
-        return pyarrow.RecordBatchReader._import_from_c(handle.address)
+        return self._backend.import_array_stream(handle)
 
     def adbc_get_table_schema(
         self,
@@ -504,8 +507,6 @@ def adbc_get_table_schema(
         -----
         This is an extension and not part of the DBAPI standard.
         """
-        _requires_pyarrow()
-
         handle = _blocking_call(
             self._conn.get_table_schema,
             (
@@ -516,7 +517,7 @@ def adbc_get_table_schema(
             {},
             self._conn.cancel,
         )
-        return pyarrow.Schema._import_from_c(handle.address)
+        return self._backend.import_schema(handle)
 
     def adbc_get_table_types(self) -> List[str]:
         """
@@ -706,11 +707,7 @@ def _prepare_execute(self, operation, parameters=None) -> None:
         if _is_arrow_data(parameters):
             self._bind(parameters)
         elif parameters:
-            _requires_pyarrow()
-            rb = pyarrow.record_batch(
-                [[param_value] for param_value in parameters],
-                names=[str(i) for i in range(len(parameters))],
-            )
+            rb = self._conn._backend.convert_bind_parameters(parameters)
             self._bind(rb)
 
     def execute(self, operation: Union[bytes, str], parameters=None) -> None:
@@ -762,18 +759,14 @@ def executemany(self, operation: Union[bytes, str], seq_of_parameters) -> None:
         if _is_arrow_data(seq_of_parameters):
             arrow_parameters = seq_of_parameters
         elif seq_of_parameters:
-            _requires_pyarrow()
-            arrow_parameters = pyarrow.RecordBatch.from_pydict(
-                {
-                    str(col_idx): pyarrow.array(x)
-                    for col_idx, x in enumerate(map(list, zip(*seq_of_parameters)))
-                },
+            arrow_parameters = self._conn._backend.convert_executemany_parameters(
+                seq_of_parameters
             )
         else:
-            _requires_pyarrow()
-            arrow_parameters = pyarrow.record_batch([])
+            arrow_parameters = None
 
-        self._bind(arrow_parameters)
+        if arrow_parameters is not None:
+            self._bind(arrow_parameters)
         self._rowcount = _blocking_call(
             self._stmt.execute_update, (), {}, self._stmt.cancel
         )
@@ -958,8 +951,7 @@ def adbc_ingest(
             self._stmt.bind_stream(data)
         elif _lib.is_pycapsule(data, b"arrow_array_stream"):
             self._stmt.bind_stream(data)
-        else:
-            _requires_pyarrow()
+        elif _has_pyarrow:
             if isinstance(data, pyarrow.dataset.Dataset):
                 data = typing.cast(pyarrow.dataset.Dataset, data).scanner().to_reader()
             elif isinstance(data, pyarrow.dataset.Scanner):
@@ -974,6 +966,8 @@ def adbc_ingest(
             else:
                 # Should be impossible from above but let's be explicit
                 raise TypeError(f"Cannot bind {type(data)}")
+        else:
+            raise TypeError(f"Cannot bind {type(data)}")
 
         self._last_query = None
         return _blocking_call(self._stmt.execute_update, (), {}, self._stmt.cancel)
@@ -999,14 +993,13 @@ def adbc_execute_partitions(
         -----
         This is an extension and not part of the DBAPI standard.
         """
-        _requires_pyarrow()
         self._clear()
         self._prepare_execute(operation, parameters)
         partitions, schema_handle, self._rowcount = _blocking_call(
             self._stmt.execute_partitions, (), {}, self._stmt.cancel
         )
         if schema_handle and schema_handle.address:
-            schema = pyarrow.Schema._import_from_c(schema_handle.address)
+            schema = self._conn._backend.import_schema(schema_handle)
         else:
             schema = None
         return partitions, schema
@@ -1024,11 +1017,10 @@ def adbc_execute_schema(self, operation, parameters=None) -> "pyarrow.Schema":
         -----
         This is an extension and not part of the DBAPI standard.
         """
-        _requires_pyarrow()
         self._clear()
         self._prepare_execute(operation, parameters)
         schema = _blocking_call(self._stmt.execute_schema, (), {}, self._stmt.cancel)
-        return pyarrow.Schema._import_from_c(schema.address)
+        return self._conn._backend.import_schema(schema)
 
     def adbc_prepare(self, operation: Union[bytes, str]) -> Optional["pyarrow.Schema"]:
         """
@@ -1048,7 +1040,6 @@ def adbc_prepare(self, operation: Union[bytes, str]) -> Optional["pyarrow.Schema"]:
         -----
         This is an extension and not part of the DBAPI standard.
         """
-        _requires_pyarrow()
         self._clear()
         self._prepare_execute(operation)
 
@@ -1058,7 +1049,7 @@ def adbc_prepare(self, operation: Union[bytes, str]) -> Optional["pyarrow.Schema"]:
             )
         except NotSupportedError:
             return None
-        return pyarrow.Schema._import_from_c(handle.address)
+        return self._conn._backend.import_schema(handle)
 
     def adbc_read_partition(self, partition: bytes) -> None:
         """
@@ -1218,7 +1209,9 @@ def fetch_arrow(self) -> _lib.ArrowArrayStreamHandle:
 class _RowIterator(_Closeable):
     """Track state needed to iterate over the result set."""
 
-    def __init__(self, stmt, handle: _lib.ArrowArrayStreamHandle) -> None:
+    def __init__(
+        self, stmt: _lib.AdbcStatement, handle: _lib.ArrowArrayStreamHandle
+    ) -> None:
         self._stmt = stmt
         self._handle: Optional[_lib.ArrowArrayStreamHandle] = handle
         self._reader: Optional["_reader.AdbcRecordBatchReader"] = None
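
Taken together, these hunks route every pyarrow-specific conversion through a backend object that only needs to provide the four hooks the changed code calls: import_array_stream, import_schema, convert_bind_parameters, and convert_executemany_parameters. A minimal sketch of a pyarrow-backed implementation, reconstructed from the code this commit removes (the class name here is hypothetical; the real interface is _dbapi_backend.DbapiBackend and may differ in detail):

# Hypothetical sketch, not part of this commit: a pyarrow-based backend
# providing the four hooks used above, mirroring the removed code paths.
import pyarrow


class PyArrowBackend:  # assumed name; the actual interface lives in _dbapi_backend
    def import_array_stream(self, handle):
        # Wrap an ArrowArrayStreamHandle in a RecordBatchReader.
        return pyarrow.RecordBatchReader._import_from_c(handle.address)

    def import_schema(self, handle):
        # Wrap an ArrowSchemaHandle in a pyarrow.Schema.
        return pyarrow.Schema._import_from_c(handle.address)

    def convert_bind_parameters(self, parameters):
        # One row; one column per positional parameter, as execute() expects.
        return pyarrow.record_batch(
            [[param_value] for param_value in parameters],
            names=[str(i) for i in range(len(parameters))],
        )

    def convert_executemany_parameters(self, seq_of_parameters):
        # One column per parameter position, one row per parameter set.
        return pyarrow.RecordBatch.from_pydict(
            {
                str(col_idx): pyarrow.array(x)
                for col_idx, x in enumerate(map(list, zip(*seq_of_parameters)))
            },
        )

A caller could then supply an alternative implementation through the new backend keyword of Connection.__init__, while the default backend=None falls back to _dbapi_backend.default_backend().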