diff --git a/providers/teradata/pyproject.toml b/providers/teradata/pyproject.toml index 553acc1f033f6..a6ea026c00adb 100644 --- a/providers/teradata/pyproject.toml +++ b/providers/teradata/pyproject.toml @@ -77,6 +77,9 @@ dependencies = [ "ssh" = [ "apache-airflow-providers-ssh" ] +"sqlalchemy" = [ + "sqlalchemy>=1.4.49", +] [dependency-groups] dev = [ diff --git a/providers/teradata/src/airflow/providers/teradata/hooks/teradata.py b/providers/teradata/src/airflow/providers/teradata/hooks/teradata.py index 4429a1c093262..6131c4c3c317c 100644 --- a/providers/teradata/src/airflow/providers/teradata/hooks/teradata.py +++ b/providers/teradata/src/airflow/providers/teradata/hooks/teradata.py @@ -23,9 +23,9 @@ from typing import TYPE_CHECKING, Any import teradatasql -from sqlalchemy.engine import URL from teradatasql import TeradataConnection +from airflow.exceptions import AirflowOptionalProviderFeatureException from airflow.providers.common.sql.hooks.sql import DbApiHook if TYPE_CHECKING: @@ -38,6 +38,18 @@ PARAM_TYPES = {bool, float, int, str} +def _get_sqlalchemy_url(): + try: + from sqlalchemy.engine import URL + + return URL + except ImportError: + raise AirflowOptionalProviderFeatureException( + "SQLAlchemy is required for this Teradata feature." + "Install it with: pip install 'apache-airflow-providers-teradata[sqlalchemy]'" + ) + + def _map_param(value): if value in PARAM_TYPES: # In this branch, value is a Python type; calling it produces @@ -197,12 +209,13 @@ def _get_conn_config_teradatasql(self) -> dict[str, Any]: return conn_config @property - def sqlalchemy_url(self) -> URL: + def sqlalchemy_url(self): """ Override to return a Sqlalchemy.engine.URL object from the Teradata connection. :return: the extracted sqlalchemy.engine.URL object. """ + URL = _get_sqlalchemy_url() connection = self.get_connection(self.get_conn_id()) # Adding only teradatasqlalchemy supported connection parameters. # https://pypi.org/project/teradatasqlalchemy/#ConnectionParameters