diff --git a/providers/openai/pyproject.toml b/providers/openai/pyproject.toml index e8e757224d519..c57fb1a1647b7 100644 --- a/providers/openai/pyproject.toml +++ b/providers/openai/pyproject.toml @@ -58,6 +58,7 @@ requires-python = ">=3.10" # After you modify the dependencies, and rebuild your Breeze CI image with ``breeze ci-image build`` dependencies = [ "apache-airflow>=2.10.0", + "apache-airflow-providers-common-compat>=1.7.4", # + TODO: bump to next version "openai[datalib]>=1.66.0", ] @@ -66,6 +67,7 @@ dev = [ "apache-airflow", "apache-airflow-task-sdk", "apache-airflow-devel-common", + "apache-airflow-providers-common-compat", # Additional devel dependencies (do not remove this line and add extra development dependencies) ] diff --git a/providers/openai/src/airflow/providers/openai/hooks/openai.py b/providers/openai/src/airflow/providers/openai/hooks/openai.py index 5b0566acb0b07..c245bd10ca06a 100644 --- a/providers/openai/src/airflow/providers/openai/hooks/openai.py +++ b/providers/openai/src/airflow/providers/openai/hooks/openai.py @@ -43,8 +43,8 @@ ChatCompletionUserMessageParam, ) from openai.types.vector_stores import VectorStoreFile, VectorStoreFileBatch, VectorStoreFileDeleted +from airflow.providers.common.compat.sdk import BaseHook from airflow.providers.openai.exceptions import OpenAIBatchJobException, OpenAIBatchTimeout -from airflow.providers.openai.version_compat import BaseHook class BatchStatus(str, Enum): diff --git a/providers/openai/src/airflow/providers/openai/operators/openai.py b/providers/openai/src/airflow/providers/openai/operators/openai.py index 0b085ee205351..9479335b7d7c0 100644 --- a/providers/openai/src/airflow/providers/openai/operators/openai.py +++ b/providers/openai/src/airflow/providers/openai/operators/openai.py @@ -23,17 +23,13 @@ from typing import TYPE_CHECKING, Any, Literal from airflow.configuration import conf +from airflow.providers.common.compat.sdk import BaseOperator from airflow.providers.openai.exceptions import OpenAIBatchJobException from airflow.providers.openai.hooks.openai import OpenAIHook from airflow.providers.openai.triggers.openai import OpenAIBatchTrigger -from airflow.providers.openai.version_compat import BaseOperator if TYPE_CHECKING: - try: - from airflow.sdk.definitions.context import Context - except ImportError: - # TODO: Remove once provider drops support for Airflow 2 - from airflow.utils.context import Context + from airflow.providers.common.compat.sdk import Context class OpenAIEmbeddingOperator(BaseOperator): diff --git a/providers/openai/src/airflow/providers/openai/version_compat.py b/providers/openai/src/airflow/providers/openai/version_compat.py index f885c7b73931c..f5bb3ae555c1f 100644 --- a/providers/openai/src/airflow/providers/openai/version_compat.py +++ b/providers/openai/src/airflow/providers/openai/version_compat.py @@ -35,20 +35,4 @@ def get_base_airflow_version_tuple() -> tuple[int, int, int]: AIRFLOW_V_3_0_PLUS = get_base_airflow_version_tuple() >= (3, 0, 0) AIRFLOW_V_3_1_PLUS: bool = get_base_airflow_version_tuple() >= (3, 1, 0) - -if AIRFLOW_V_3_1_PLUS: - from airflow.sdk import BaseHook -else: - from airflow.hooks.base import BaseHook # type: ignore[attr-defined,no-redef] - -if AIRFLOW_V_3_0_PLUS: - from airflow.sdk import BaseOperator -else: - from airflow.models import BaseOperator - -__all__ = [ - "AIRFLOW_V_3_0_PLUS", - "AIRFLOW_V_3_1_PLUS", - "BaseHook", - "BaseOperator", -] +__all__ = ["AIRFLOW_V_3_0_PLUS", "AIRFLOW_V_3_1_PLUS"] diff --git a/providers/openai/tests/system/openai/example_openai.py b/providers/openai/tests/system/openai/example_openai.py index dd35125bd254b..159f5a587c2da 100644 --- a/providers/openai/tests/system/openai/example_openai.py +++ b/providers/openai/tests/system/openai/example_openai.py @@ -18,11 +18,9 @@ import pendulum -try: - from airflow.sdk import dag, task -except ImportError: - # Airflow 2 path - from airflow.decorators import dag, task # type: ignore[attr-defined,no-redef] +# This example uses common.compat for Airflow 2.x/3.x compatibility. +# If you only need Airflow 3+, you can use: from airflow.sdk import dag, task +from airflow.providers.common.compat.sdk import dag, task from airflow.providers.openai.hooks.openai import OpenAIHook from airflow.providers.openai.operators.openai import OpenAIEmbeddingOperator diff --git a/providers/openai/tests/unit/openai/operators/test_openai.py b/providers/openai/tests/unit/openai/operators/test_openai.py index 6b582b6d9feec..8542c0fc2bbab 100644 --- a/providers/openai/tests/unit/openai/operators/test_openai.py +++ b/providers/openai/tests/unit/openai/operators/test_openai.py @@ -22,15 +22,10 @@ from openai.types.batch import Batch from airflow.exceptions import TaskDeferred +from airflow.providers.common.compat.sdk import Context from airflow.providers.openai.operators.openai import OpenAIEmbeddingOperator, OpenAITriggerBatchOperator from airflow.providers.openai.triggers.openai import OpenAIBatchTrigger -try: - from airflow.sdk.definitions.context import Context -except ImportError: - # TODO: Remove once provider drops support for Airflow 2 - from airflow.utils.context import Context - openai = pytest.importorskip("openai") TASK_ID = "TaskId" CONN_ID = "test_conn_id"