From feb27eef87d2be220a2f843f53258cdb1a9328bd Mon Sep 17 00:00:00 2001 From: Pavan Kumar Date: Sat, 12 Jul 2025 21:54:06 +0100 Subject: [PATCH] Cleanup mypy ignores in databricks provider where possible --- .../airflow/providers/databricks/hooks/databricks_sql.py | 6 +++--- .../airflow/providers/databricks/operators/databricks.py | 6 ++---- .../providers/databricks/plugins/databricks_workflow.py | 2 +- .../src/airflow/providers/databricks/utils/openlineage.py | 4 ++-- 4 files changed, 8 insertions(+), 10 deletions(-) diff --git a/providers/databricks/src/airflow/providers/databricks/hooks/databricks_sql.py b/providers/databricks/src/airflow/providers/databricks/hooks/databricks_sql.py index 0ace06abc998b..ce60f0e4bdeb1 100644 --- a/providers/databricks/src/airflow/providers/databricks/hooks/databricks_sql.py +++ b/providers/databricks/src/airflow/providers/databricks/hooks/databricks_sql.py @@ -30,7 +30,7 @@ overload, ) -from databricks import sql # type: ignore[attr-defined] +from databricks import sql from databricks.sql.types import Row from airflow.exceptions import AirflowException @@ -173,7 +173,7 @@ def get_conn(self) -> AirflowConnection: raise AirflowException("SQL connection is not initialized") return cast("AirflowConnection", self._sql_conn) - @overload # type: ignore[override] + @overload def run( self, sql: str | Iterable[str], @@ -258,7 +258,7 @@ def run( # TODO: adjust this to make testing easier try: - self._run_command(cur, sql_statement, parameters) # type: ignore[attr-defined] + self._run_command(cur, sql_statement, parameters) except Exception as e: if t is None or t.is_alive(): raise DatabricksSqlExecutionError( diff --git a/providers/databricks/src/airflow/providers/databricks/operators/databricks.py b/providers/databricks/src/airflow/providers/databricks/operators/databricks.py index e341daaa8a096..c9cf3c0045ccf 100644 --- a/providers/databricks/src/airflow/providers/databricks/operators/databricks.py +++ b/providers/databricks/src/airflow/providers/databricks/operators/databricks.py @@ -60,7 +60,7 @@ from airflow.sdk import BaseOperatorLink from airflow.sdk.execution_time.xcom import XCom else: - from airflow.models import XCom # type: ignore[no-redef] + from airflow.models import XCom from airflow.models.baseoperatorlink import BaseOperatorLink # type: ignore[no-redef] DEFER_METHOD_NAME = "execute_complete" @@ -1428,9 +1428,7 @@ def execute(self, context: Context) -> None: if not self.workflow_run_metadata: launch_task_id = next(task for task in self.upstream_task_ids if task.endswith(".launch")) self.workflow_run_metadata = context["ti"].xcom_pull(task_ids=launch_task_id) - workflow_run_metadata = WorkflowRunMetadata( # type: ignore[arg-type] - **self.workflow_run_metadata - ) + workflow_run_metadata = WorkflowRunMetadata(**self.workflow_run_metadata) self.databricks_run_id = workflow_run_metadata.run_id self.databricks_conn_id = workflow_run_metadata.conn_id diff --git a/providers/databricks/src/airflow/providers/databricks/plugins/databricks_workflow.py b/providers/databricks/src/airflow/providers/databricks/plugins/databricks_workflow.py index ab927fb0c1722..d504099a6d97f 100644 --- a/providers/databricks/src/airflow/providers/databricks/plugins/databricks_workflow.py +++ b/providers/databricks/src/airflow/providers/databricks/plugins/databricks_workflow.py @@ -54,7 +54,7 @@ from airflow.sdk import BaseOperatorLink from airflow.sdk.execution_time.xcom import XCom else: - from airflow.models import XCom # type: ignore[no-redef] + from airflow.models import XCom from airflow.models.baseoperatorlink import BaseOperatorLink # type: ignore[no-redef] diff --git a/providers/databricks/src/airflow/providers/databricks/utils/openlineage.py b/providers/databricks/src/airflow/providers/databricks/utils/openlineage.py index 50df7efd0e889..971e59f291549 100644 --- a/providers/databricks/src/airflow/providers/databricks/utils/openlineage.py +++ b/providers/databricks/src/airflow/providers/databricks/utils/openlineage.py @@ -342,8 +342,8 @@ def emit_openlineage_events_for_databricks_queries( event_batch = _create_ol_event_pair( job_namespace=namespace(), job_name=f"{task_instance.dag_id}.{task_instance.task_id}.query.{counter}", - start_time=query_metadata.get("start_time") or default_event_time, # type: ignore[arg-type] - end_time=query_metadata.get("end_time") or default_event_time, # type: ignore[arg-type] + start_time=query_metadata.get("start_time") or default_event_time, + end_time=query_metadata.get("end_time") or default_event_time, # Only finished status means it completed without failures is_successful=(query_metadata.get("status") or default_state).lower() == "finished", run_facets={**query_specific_run_facets, **common_run_facets, **additional_run_facets},