Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,7 @@
)
from openai.types.vector_stores import VectorStoreFile, VectorStoreFileBatch, VectorStoreFileDeleted
from airflow.providers.openai.exceptions import OpenAIBatchJobException, OpenAIBatchTimeout

try:
from airflow.sdk import BaseHook
except ImportError:
from airflow.hooks.base import BaseHook # type: ignore[attr-defined,no-redef]
from airflow.providers.openai.version_compat import BaseHook


class BatchStatus(str, Enum):
Expand Down
14 changes: 13 additions & 1 deletion providers/openai/src/airflow/providers/openai/version_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,22 @@ def get_base_airflow_version_tuple() -> tuple[int, int, int]:


AIRFLOW_V_3_0_PLUS = get_base_airflow_version_tuple() >= (3, 0, 0)
AIRFLOW_V_3_1_PLUS: bool = get_base_airflow_version_tuple() >= (3, 1, 0)


if AIRFLOW_V_3_1_PLUS:
from airflow.sdk import BaseHook
else:
from airflow.hooks.base import BaseHook # type: ignore[attr-defined,no-redef]

if AIRFLOW_V_3_0_PLUS:
from airflow.sdk import BaseOperator
else:
from airflow.models import BaseOperator # type: ignore[no-redef]

__all__ = ["AIRFLOW_V_3_0_PLUS", "BaseOperator"]
__all__ = [
"AIRFLOW_V_3_0_PLUS",
"AIRFLOW_V_3_1_PLUS",
"BaseHook",
"BaseOperator",
]