feat(python): react to SIGINT in more places (#1579)
Fixes #1573.
lidavidm authored Mar 1, 2024
1 parent d9d66b2 commit 2353563
Showing 1 changed file with 47 additions and 22 deletions.
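
Every change below follows the same pattern: a direct call into the driver is replaced by `_blocking_call(func, args, kwargs, cancel)` so that a SIGINT delivered while the call is pending surfaces as `KeyboardInterrupt` and cancels the operation. As a rough mental model only, here is a minimal sketch of that pattern; the real helper lives in `adbc_driver_manager._lib` and may differ, and names such as `poll_interval` and the executor are illustrative assumptions, not the actual implementation:

```python
# Hedged sketch only: not the actual _lib implementation.
import concurrent.futures
from typing import Any, Callable, Dict, Tuple


def blocking_call_sketch(
    func: Callable[..., Any],
    args: Tuple[Any, ...],
    kwargs: Dict[str, Any],
    cancel: Callable[[], None],
    poll_interval: float = 0.1,  # illustrative assumption
) -> Any:
    """Run a blocking driver call while staying responsive to Ctrl+C."""
    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
        future = pool.submit(func, *args, **kwargs)
        while True:
            try:
                # Wait in short slices so KeyboardInterrupt can be delivered
                # to the main thread between waits.
                return future.result(timeout=poll_interval)
            except concurrent.futures.TimeoutError:
                continue
            except KeyboardInterrupt:
                cancel()  # ask the driver to cancel the in-flight operation
                raise
```

As seen throughout the diff, the cancel callback is `self._conn.cancel` for connection-level calls and `self._stmt.cancel` for statement-level calls.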
python/adbc_driver_manager/adbc_driver_manager/dbapi.py (47 additions, 22 deletions)
@@ -386,9 +386,10 @@ def adbc_get_info(self) -> Dict[Union[str, int], Any]:
         -----
         This is an extension and not part of the DBAPI standard.
         """
-        handle = self._conn.get_info()
+        handle = _blocking_call(self._conn.get_info, (), {}, self._conn.cancel)
         reader = pyarrow.RecordBatchReader._import_from_c(handle.address)
-        info = reader.read_all().to_pylist()
+        table = _blocking_call(reader.read_all, (), {}, self._conn.cancel)
+        info = table.to_pylist()
         return dict(
             {
                 _KNOWN_INFO_VALUES.get(row["info_name"], row["info_name"]): row[
@@ -440,13 +441,17 @@ def adbc_get_objects(
             c_depth = _lib.GetObjectsDepth.TABLES
         else:
             raise ValueError(f"Invalid value for 'depth': {depth}")
-        handle = self._conn.get_objects(
-            c_depth,
-            catalog=catalog_filter,
-            db_schema=db_schema_filter,
-            table_name=table_name_filter,
-            table_types=table_types_filter,
-            column_name=column_name_filter,
+        handle = _blocking_call(
+            self._conn.get_objects,
+            (c_depth,),
+            dict(
+                catalog=catalog_filter,
+                db_schema=db_schema_filter,
+                table_name=table_name_filter,
+                table_types=table_types_filter,
+                column_name=column_name_filter,
+            ),
+            self._conn.cancel,
         )
         return pyarrow.RecordBatchReader._import_from_c(handle.address)

@@ -473,8 +478,15 @@ def adbc_get_table_schema(
         -----
         This is an extension and not part of the DBAPI standard.
         """
-        handle = self._conn.get_table_schema(
-            catalog_filter, db_schema_filter, table_name
+        handle = _blocking_call(
+            self._conn.get_table_schema,
+            (
+                catalog_filter,
+                db_schema_filter,
+                table_name,
+            ),
+            {},
+            self._conn.cancel,
         )
         return pyarrow.Schema._import_from_c(handle.address)

@@ -486,7 +498,12 @@ def adbc_get_table_types(self) -> List[str]:
         -----
         This is an extension and not part of the DBAPI standard.
         """
-        handle = self._conn.get_table_types()
+        handle = _blocking_call(
+            self._conn.get_table_types,
+            (),
+            {},
+            self._conn.cancel,
+        )
         reader = pyarrow.RecordBatchReader._import_from_c(handle.address)
         table = reader.read_all()
         return table[0].to_pylist()
@@ -648,7 +665,7 @@ def _prepare_execute(self, operation, parameters=None) -> None:
         else:
             self._stmt.set_sql_query(operation)
         try:
-            self._stmt.prepare()
+            _blocking_call(self._stmt.prepare, (), {}, self._stmt.cancel)
         except NotSupportedError:
             # Not all drivers support it
             pass
@@ -722,7 +739,9 @@ def executemany(self, operation: Union[bytes, str], seq_of_parameters) -> None:
             arrow_parameters = pyarrow.record_batch([])

         self._bind(arrow_parameters)
-        self._rowcount = self._stmt.execute_update()
+        self._rowcount = _blocking_call(
+            self._stmt.execute_update, (), {}, self._stmt.cancel
+        )

     def fetchone(self) -> Optional[tuple]:
         """Fetch one row of the result."""
@@ -916,7 +935,7 @@ def adbc_ingest(
             self._stmt.bind_stream(handle)

         self._last_query = None
-        return self._stmt.execute_update()
+        return _blocking_call(self._stmt.execute_update, (), {}, self._stmt.cancel)

     def adbc_execute_partitions(
         self,
@@ -940,7 +959,9 @@ def adbc_execute_partitions(
         This is an extension and not part of the DBAPI standard.
         """
         self._prepare_execute(operation, parameters)
-        partitions, schema_handle, self._rowcount = self._stmt.execute_partitions()
+        partitions, schema_handle, self._rowcount = _blocking_call(
+            self._stmt.execute_partitions, (), {}, self._stmt.cancel
+        )
         if schema_handle and schema_handle.address:
             schema = pyarrow.Schema._import_from_c(schema_handle.address)
         else:
@@ -961,7 +982,7 @@ def adbc_execute_schema(self, operation, parameters=None) -> pyarrow.Schema:
         This is an extension and not part of the DBAPI standard.
         """
         self._prepare_execute(operation, parameters)
-        schema = self._stmt.execute_schema()
+        schema = _blocking_call(self._stmt.execute_schema, (), {}, self._stmt.cancel)
         return pyarrow.Schema._import_from_c(schema.address)

     def adbc_prepare(self, operation: Union[bytes, str]) -> Optional[pyarrow.Schema]:
@@ -985,7 +1006,9 @@ def adbc_prepare(self, operation: Union[bytes, str]) -> Optional[pyarrow.Schema]
         self._prepare_execute(operation)

         try:
-            handle = self._stmt.get_parameter_schema()
+            handle = _blocking_call(
+                self._stmt.get_parameter_schema, (), {}, self._stmt.cancel
+            )
         except NotSupportedError:
             return None
         return pyarrow.Schema._import_from_c(handle.address)
@@ -999,7 +1022,9 @@ def adbc_read_partition(self, partition: bytes) -> None:
         This is an extension and not part of the DBAPI standard.
         """
         self._results = None
-        handle = self._conn._conn.read_partition(partition)
+        handle = _blocking_call(
+            self._conn._conn.read_partition, (partition,), {}, self._stmt.cancel
+        )
         self._rowcount = -1
         self._results = _RowIterator(
             self._stmt, pyarrow.RecordBatchReader._import_from_c(handle.address)
@@ -1032,7 +1057,7 @@ def executescript(self, operation: str) -> None:
         self._last_query = None
         self._results = None
         self._stmt.set_sql_query(operation)
-        self._stmt.execute_update()
+        _blocking_call(self._stmt.execute_update, (), {}, self._stmt.cancel)

     def fetchallarrow(self) -> pyarrow.Table:
         """
@@ -1170,10 +1195,10 @@ def fetchall(self) -> List[tuple]:
         return rows

     def fetch_arrow_table(self) -> pyarrow.Table:
-        return self._reader.read_all()
+        return _blocking_call(self._reader.read_all, (), {}, self._stmt.cancel)

     def fetch_df(self) -> "pandas.DataFrame":
-        return self._reader.read_pandas()
+        return _blocking_call(self._reader.read_pandas, (), {}, self._stmt.cancel)


 _PYTEST_ENV_VAR = "PYTEST_CURRENT_TEST"
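
With these wrappers in place, long-running DB-API calls wrapped above can be interrupted from the keyboard. A hedged usage sketch; the driver, URI, and table names below are illustrative only:

```python
import adbc_driver_postgresql.dbapi as dbapi  # any ADBC driver works the same way

conn = dbapi.connect("postgresql://localhost:5432/mydb")  # illustrative URI
cur = conn.cursor()
try:
    # executescript() is one of the calls this commit wraps in _blocking_call,
    # so Ctrl+C during a long-running statement now raises KeyboardInterrupt
    # and cancels the statement instead of blocking until the driver returns.
    cur.executescript("CREATE TABLE copy_of_big AS SELECT * FROM big_source")
except KeyboardInterrupt:
    print("Interrupted; the statement was cancelled.")
finally:
    cur.close()
    conn.close()
```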
