Skip to content

Commit

Permalink
Fix version check for CLI Imports
Browse files Browse the repository at this point in the history
  • Loading branch information
bugraoz93 committed Dec 27, 2024
1 parent 9be5971 commit b89f1f4
Show file tree
Hide file tree
Showing 8 changed files with 277 additions and 98 deletions.
2 changes: 1 addition & 1 deletion dev/breeze/doc/images/output_setup.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
ba00ab3fb2ed5a777684878c28b3ce65
08c78d9dddd037a2ade6b751c5a22ff9
Original file line number Diff line number Diff line change
@@ -1 +1 @@
2e42a9de8b8ed2ce83b5a1fcbdaa0158
3c1ddd562c1325ab655b84dfba3ac805
7 changes: 5 additions & 2 deletions providers/src/airflow/providers/celery/cli/celery_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@

from airflow import settings
from airflow.configuration import conf
from airflow.providers.celery.version_compat import AIRFLOW_V_3_0_PLUS
from airflow.utils import cli as cli_utils
from airflow.utils.cli import setup_locations
from airflow.utils.serve_logs import serve_logs
Expand All @@ -42,8 +43,10 @@

def _run_command_with_daemon_option(*args, **kwargs):
try:
from airflow.cli.commands.local_commands.daemon_utils import run_command_with_daemon_option

if AIRFLOW_V_3_0_PLUS:
from airflow.cli.commands.local_commands.daemon_utils import run_command_with_daemon_option
else:
from airflow.cli.commands.daemon_utils import run_command_with_daemon_option
run_command_with_daemon_option(*args, **kwargs)
except ImportError:
from airflow.exceptions import AirflowOptionalProviderFeatureException
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,7 @@

from celery import states as celery_states
from deprecated import deprecated
from packaging.version import Version

from airflow import __version__ as airflow_version
from airflow.cli.cli_config import (
ARG_DAEMON,
ARG_LOG_FILE,
Expand All @@ -56,6 +54,7 @@
from airflow.configuration import conf
from airflow.exceptions import AirflowProviderDeprecationWarning, AirflowTaskTimeout
from airflow.executors.base_executor import BaseExecutor
from airflow.providers.celery.version_compat import AIRFLOW_V_2_8_PLUS, AIRFLOW_V_3_0_PLUS
from airflow.stats import Stats
from airflow.utils.state import TaskInstanceState

Expand Down Expand Up @@ -163,8 +162,12 @@ def __getattr__(name):

CELERY_CLI_COMMAND_PATH = (
"airflow.providers.celery.cli.celery_command"
if Version(airflow_version) >= Version("2.8.0")
else "airflow.cli.commands.local_commands.celery_command"
if AIRFLOW_V_2_8_PLUS
else (
"airflow.cli.commands.local_commands.celery_command"
if AIRFLOW_V_3_0_PLUS
else "airflow.cli.commands.celery_command"
)
)

CELERY_COMMANDS = (
Expand Down
30 changes: 30 additions & 0 deletions providers/src/airflow/providers/celery/version_compat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations


def get_base_airflow_version_tuple() -> tuple[int, int, int]:
from packaging.version import Version

from airflow import __version__

airflow_version = Version(__version__)
return airflow_version.major, airflow_version.minor, airflow_version.micro


AIRFLOW_V_2_8_PLUS = get_base_airflow_version_tuple() >= (2, 8, 0)
AIRFLOW_V_3_0_PLUS = get_base_airflow_version_tuple() >= (3, 0, 0)
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,26 @@
from __future__ import annotations

from airflow import settings
from airflow.cli.commands.local_commands.db_command import run_db_downgrade_command, run_db_migrate_command
from airflow.providers.fab.auth_manager.models.db import _REVISION_HEADS_MAP, FABDBManager
from airflow.providers.fab.version_compat import AIRFLOW_V_3_0_PLUS
from airflow.utils import cli as cli_utils
from airflow.utils.providers_configuration_loader import providers_configuration_loaded


def get_db_command():
try:
if AIRFLOW_V_3_0_PLUS:
import airflow.cli.commands.local_commands.db_command as db_command
else:
import airflow.cli.commands.db_command as db_command
except ImportError:
from airflow.exceptions import AirflowOptionalProviderFeatureException

raise AirflowOptionalProviderFeatureException("Failed to import db_command from Airflow CLI.")

return db_command


@providers_configuration_loaded
def resetdb(args):
"""Reset the metadata database."""
Expand All @@ -38,7 +52,7 @@ def migratedb(args):
"""Migrates the metadata database."""
session = settings.Session()
upgrade_command = FABDBManager(session).upgradedb
run_db_migrate_command(
get_db_command().run_db_migrate_command(
args, upgrade_command, revision_heads_map=_REVISION_HEADS_MAP, reserialize_dags=False
)

Expand All @@ -49,4 +63,4 @@ def downgrade(args):
"""Downgrades the metadata database."""
session = settings.Session()
dwongrade_command = FABDBManager(session).downgrade
run_db_downgrade_command(args, dwongrade_command, revision_heads_map=_REVISION_HEADS_MAP)
get_db_command().run_db_downgrade_command(args, dwongrade_command, revision_heads_map=_REVISION_HEADS_MAP)
29 changes: 29 additions & 0 deletions providers/src/airflow/providers/fab/version_compat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations


def get_base_airflow_version_tuple() -> tuple[int, int, int]:
from packaging.version import Version

from airflow import __version__

airflow_version = Version(__version__)
return airflow_version.major, airflow_version.minor, airflow_version.micro


AIRFLOW_V_3_0_PLUS = get_base_airflow_version_tuple() >= (3, 0, 0)
Loading

0 comments on commit b89f1f4

Please sign in to comment.