diff --git a/providers/github/src/airflow/providers/github/operators/github.py b/providers/github/src/airflow/providers/github/operators/github.py index fb368b9280e32..c7420e0680882 100644 --- a/providers/github/src/airflow/providers/github/operators/github.py +++ b/providers/github/src/airflow/providers/github/operators/github.py @@ -23,15 +23,11 @@ from github import GithubException from airflow.exceptions import AirflowException -from airflow.models import BaseOperator from airflow.providers.github.hooks.github import GithubHook +from airflow.providers.github.version_compat import BaseOperator if TYPE_CHECKING: - try: - from airflow.sdk.definitions.context import Context - except ImportError: - # TODO: Remove once provider drops support for Airflow 2 - from airflow.utils.context import Context + from airflow.providers.github.version_compat import Context class GithubOperator(BaseOperator): diff --git a/providers/github/src/airflow/providers/github/sensors/github.py b/providers/github/src/airflow/providers/github/sensors/github.py index b234168494836..446b66a792b95 100644 --- a/providers/github/src/airflow/providers/github/sensors/github.py +++ b/providers/github/src/airflow/providers/github/sensors/github.py @@ -24,19 +24,10 @@ from airflow.exceptions import AirflowException from airflow.providers.github.hooks.github import GithubHook -from airflow.providers.github.version_compat import AIRFLOW_V_3_0_PLUS - -if AIRFLOW_V_3_0_PLUS: - from airflow.sdk import BaseSensorOperator -else: - from airflow.sensors.base import BaseSensorOperator # type: ignore[no-redef] +from airflow.providers.github.version_compat import BaseSensorOperator if TYPE_CHECKING: - try: - from airflow.sdk.definitions.context import Context - except ImportError: - # TODO: Remove once provider drops support for Airflow 2 - from airflow.utils.context import Context + from airflow.providers.github.version_compat import Context class GithubSensor(BaseSensorOperator): diff --git a/providers/github/src/airflow/providers/github/version_compat.py b/providers/github/src/airflow/providers/github/version_compat.py index 48d122b669696..682d19a5f49be 100644 --- a/providers/github/src/airflow/providers/github/version_compat.py +++ b/providers/github/src/airflow/providers/github/version_compat.py @@ -33,3 +33,19 @@ def get_base_airflow_version_tuple() -> tuple[int, int, int]: AIRFLOW_V_3_0_PLUS = get_base_airflow_version_tuple() >= (3, 0, 0) + +if AIRFLOW_V_3_0_PLUS: + from airflow.sdk import BaseOperator, BaseSensorOperator + from airflow.sdk.definitions.context import Context +else: + from airflow.models import BaseOperator + from airflow.sensors.base import BaseSensorOperator # type: ignore[no-redef] + from airflow.utils.context import Context + + +__all__ = [ + "AIRFLOW_V_3_0_PLUS", + "BaseOperator", + "BaseSensorOperator", + "Context", +] diff --git a/providers/github/tests/unit/github/operators/test_github.py b/providers/github/tests/unit/github/operators/test_github.py index cdfc72177e5a2..b6c6cf681fbb4 100644 --- a/providers/github/tests/unit/github/operators/test_github.py +++ b/providers/github/tests/unit/github/operators/test_github.py @@ -61,7 +61,7 @@ def test_operator_init_with_optional_args(self): @patch( "airflow.providers.github.hooks.github.GithubClient", autospec=True, return_value=github_client_mock ) - def test_find_repos(self, github_mock): + def test_find_repos(self, github_mock, dag_maker): class MockRepository: pass @@ -69,16 +69,15 @@ class MockRepository: repo.full_name = "apache/airflow" github_mock.return_value.get_repo.return_value = repo - - github_operator = GithubOperator( - task_id="github-test", - github_method="get_repo", - github_method_args={"full_name_or_id": "apache/airflow"}, - result_processor=lambda r: r.full_name, - dag=self.dag, - ) - - github_operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) + with dag_maker(): + GithubOperator( + task_id="github-test", + github_method="get_repo", + github_method_args={"full_name_or_id": "apache/airflow"}, + result_processor=lambda r: r.full_name, + ) + dr = dag_maker.create_dagrun() + dag_maker.run_ti("github-test", dr) assert github_mock.called assert github_mock.return_value.get_repo.called