diff --git a/setup.py b/setup.py index e2ee0db7f2090..4318ea8860cf0 100644 --- a/setup.py +++ b/setup.py @@ -154,7 +154,7 @@ def get_git_sha() -> str: ], "db2": ["ibm-db-sa>0.3.8, <=0.4.0"], "dremio": ["sqlalchemy-dremio>=1.1.5, <1.3"], - "drill": ["sqlalchemy-drill==0.1.dev"], + "drill": ["sqlalchemy-drill>=1.1.4, <2"], "druid": ["pydruid>=0.6.5,<0.7"], "duckdb": ["duckdb-engine>=0.9.5, <0.10"], "dynamodb": ["pydynamodb>=0.4.2"], diff --git a/superset/db_engine_specs/drill.py b/superset/db_engine_specs/drill.py index fb42409b4e952..276ff5b185448 100644 --- a/superset/db_engine_specs/drill.py +++ b/superset/db_engine_specs/drill.py @@ -14,8 +14,11 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. + +from __future__ import annotations + from datetime import datetime -from typing import Any, Optional +from typing import Any from urllib import parse from sqlalchemy import types @@ -60,8 +63,8 @@ def epoch_ms_to_dttm(cls) -> str: @classmethod def convert_dttm( - cls, target_type: str, dttm: datetime, db_extra: Optional[dict[str, Any]] = None - ) -> Optional[str]: + cls, target_type: str, dttm: datetime, db_extra: dict[str, Any] | None = None + ) -> str | None: sqla_type = cls.get_sqla_column_type(target_type) if isinstance(sqla_type, types.Date): @@ -76,8 +79,8 @@ def adjust_engine_params( cls, uri: URL, connect_args: dict[str, Any], - catalog: Optional[str] = None, - schema: Optional[str] = None, + catalog: str | None = None, + schema: str | None = None, ) -> tuple[URL, dict[str, Any]]: if schema: uri = uri.set(database=parse.quote(schema.replace(".", "/"), safe="")) @@ -89,7 +92,7 @@ def get_schema_from_engine_params( cls, sqlalchemy_uri: URL, connect_args: dict[str, Any], - ) -> Optional[str]: + ) -> str | None: """ Return the configured schema. """ @@ -97,7 +100,7 @@ def get_schema_from_engine_params( @classmethod def get_url_for_impersonation( - cls, url: URL, impersonate_user: bool, username: Optional[str] + cls, url: URL, impersonate_user: bool, username: str | None ) -> URL: """ Return a modified URL with the username set. @@ -117,3 +120,23 @@ def get_url_for_impersonation( ) return url + + @classmethod + def fetch_data( + cls, + cursor: Any, + limit: int | None = None, + ) -> list[tuple[Any, ...]]: + """ + Custom `fetch_data` for Drill. + + When no rows are returned, Drill raises a `RuntimeError` with the message + "generator raised StopIteration". This method catches the exception and + returns an empty list instead. + """ + try: + return super().fetch_data(cursor, limit) + except RuntimeError as ex: + if str(ex) == "generator raised StopIteration": + return [] + raise