Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# under the License.
from __future__ import annotations

import warnings
from collections import defaultdict
from collections.abc import Sequence
from functools import cached_property
Expand All @@ -27,6 +28,7 @@
from airflow.api_fastapi.app import AUTH_MANAGER_FASTAPI_APP_PREFIX
from airflow.api_fastapi.auth.managers.base_auth_manager import BaseAuthManager
from airflow.cli.cli_config import CLICommand
from airflow.exceptions import AirflowProviderDeprecationWarning
from airflow.providers.amazon.aws.auth_manager.avp.entities import AvpEntities
from airflow.providers.amazon.aws.auth_manager.avp.facade import (
AwsAuthManagerAmazonVerifiedPermissionsFacade,
Expand Down Expand Up @@ -158,6 +160,13 @@ def is_authorized_dag(
def is_authorized_backfill(
self, *, method: ResourceMethod, user: AwsAuthManagerUser, details: BackfillDetails | None = None
) -> bool:
# Method can be removed once the min Airflow version is >= 3.2.0.
warnings.warn(
"Use ``is_authorized_dag`` on ``DagAccessEntity.RUN`` instead for a dag level access control.",
AirflowProviderDeprecationWarning,
stacklevel=2,
)

backfill_id = details.id if details else None
return self.avp_facade.is_authorized(
method=method, entity_type=AvpEntities.BACKFILL, user=user, entity_id=backfill_id
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,14 @@
# under the License.
from __future__ import annotations

from contextlib import ExitStack
from typing import TYPE_CHECKING
from unittest.mock import ANY, Mock, patch

import pytest

from airflow.exceptions import AirflowProviderDeprecationWarning

from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS

if not AIRFLOW_V_3_0_PLUS:
Expand Down Expand Up @@ -226,8 +229,16 @@ def test_is_authorized_backfill(
is_authorized = Mock(return_value=True)
mock_avp_facade.is_authorized = is_authorized

method: ResourceMethod = "GET"
result = auth_manager.is_authorized_backfill(method=method, details=details, user=user)
with ExitStack() as stack:
stack.enter_context(
pytest.warns(
AirflowProviderDeprecationWarning,
match="Use ``is_authorized_dag`` on ``DagAccessEntity.RUN`` instead for a dag level access control.",
)
)

method: ResourceMethod = "GET"
result = auth_manager.is_authorized_backfill(method=method, details=details, user=user)

is_authorized.assert_called_once_with(
method=method, entity_type=AvpEntities.BACKFILL, user=expected_user, entity_id=expected_entity_id
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
# under the License.
from __future__ import annotations

import warnings
from functools import cached_property
from pathlib import Path
from typing import TYPE_CHECKING, Any
Expand Down Expand Up @@ -51,7 +52,7 @@
)
from airflow.api_fastapi.common.types import ExtraMenuItem, MenuItem
from airflow.configuration import conf
from airflow.exceptions import AirflowConfigException
from airflow.exceptions import AirflowConfigException, AirflowProviderDeprecationWarning
from airflow.models import Connection, DagModel, Pool, Variable
from airflow.providers.common.compat.sdk import AirflowException
from airflow.providers.fab.auth_manager.models import Permission, Role, User
Expand Down Expand Up @@ -389,6 +390,12 @@ def is_authorized_backfill(
user: User,
details: BackfillDetails | None = None,
) -> bool:
# Method can be removed once the min Airflow version is >= 3.2.0.
warnings.warn(
"Use ``is_authorized_dag`` on ``DagAccessEntity.RUN`` instead for a dag level access control.",
AirflowProviderDeprecationWarning,
stacklevel=2,
)
return self._is_authorized(method=method, resource_type=RESOURCE_BACKFILL, user=user)

def is_authorized_asset(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from __future__ import annotations

import time
from contextlib import contextmanager, suppress
from contextlib import ExitStack, contextmanager, suppress
from itertools import chain
from typing import TYPE_CHECKING
from unittest import mock
Expand All @@ -29,7 +29,7 @@

from airflow.api_fastapi.app import AUTH_MANAGER_FASTAPI_APP_PREFIX
from airflow.api_fastapi.common.types import MenuItem
from airflow.exceptions import AirflowConfigException
from airflow.exceptions import AirflowConfigException, AirflowProviderDeprecationWarning
from airflow.providers.fab.www.app import create_app
from airflow.providers.fab.www.utils import get_fab_auth_manager
from airflow.providers.standard.operators.empty import EmptyOperator
Expand Down Expand Up @@ -328,10 +328,19 @@ def test_create_token_wrong_values(self, username, password, auth_manager_with_a
def test_is_authorized(self, api_name, method, user_permissions, expected_result, auth_manager):
user = Mock()
user.perms = user_permissions
result = getattr(auth_manager, api_name)(
method=method,
user=user,
)

with ExitStack() as stack:
if api_name == "is_authorized_backfill":
stack.enter_context(
pytest.warns(
AirflowProviderDeprecationWarning,
match="Use ``is_authorized_dag`` on ``DagAccessEntity.RUN`` instead for a dag level access control.",
)
)
result = getattr(auth_manager, api_name)(
method=method,
user=user,
)
assert result == expected_result

@pytest.mark.parametrize(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import json
import logging
import time
import warnings
from base64 import urlsafe_b64decode
from typing import TYPE_CHECKING, Any
from urllib.parse import urljoin
Expand All @@ -32,6 +33,7 @@

from airflow.api_fastapi.app import AUTH_MANAGER_FASTAPI_APP_PREFIX
from airflow.api_fastapi.auth.managers.base_auth_manager import BaseAuthManager
from airflow.exceptions import AirflowProviderDeprecationWarning

try:
from airflow.api_fastapi.auth.managers.base_auth_manager import ExtendedResourceMethod
Expand Down Expand Up @@ -220,6 +222,13 @@ def is_authorized_dag(
def is_authorized_backfill(
self, *, method: ResourceMethod, user: KeycloakAuthManagerUser, details: BackfillDetails | None = None
) -> bool:
# Method can be removed once the min Airflow version is >= 3.2.0.
warnings.warn(
"Use ``is_authorized_dag`` on ``DagAccessEntity.RUN`` instead for a dag level access control.",
AirflowProviderDeprecationWarning,
stacklevel=2,
)

backfill_id = str(details.id) if details else None
return self._is_authorized(
method=method, resource_type=KeycloakResource.BACKFILL, user=user, resource_id=backfill_id
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from __future__ import annotations

import json
from contextlib import ExitStack
from unittest.mock import Mock, patch

import pytest
Expand All @@ -36,6 +37,7 @@
VariableDetails,
)
from airflow.api_fastapi.common.types import MenuItem
from airflow.exceptions import AirflowProviderDeprecationWarning
from airflow.providers.common.compat.sdk import AirflowException
from airflow.providers.keycloak.auth_manager.constants import (
CONF_CLIENT_ID_KEY,
Expand Down Expand Up @@ -263,7 +265,16 @@ def test_is_authorized(
mock_response.status_code = status_code
auth_manager.http_session.post = Mock(return_value=mock_response)

result = getattr(auth_manager, function)(method=method, user=user, details=details)
with ExitStack() as stack:
if function == "is_authorized_backfill":
stack.enter_context(
pytest.warns(
AirflowProviderDeprecationWarning,
match="Use ``is_authorized_dag`` on ``DagAccessEntity.RUN`` instead for a dag level access control.",
)
)

result = getattr(auth_manager, function)(method=method, user=user, details=details)

token_url = auth_manager._get_token_url("server_url", "realm")
payload = auth_manager._get_payload("client_id", permission, attributes)
Expand Down Expand Up @@ -291,10 +302,19 @@ def test_is_authorized_failure(self, function, auth_manager, user):
resp.status_code = 500
auth_manager.http_session.post = Mock(return_value=resp)

with pytest.raises(AirflowException) as e:
getattr(auth_manager, function)(method="GET", user=user)
with ExitStack() as stack:
if function == "is_authorized_backfill":
stack.enter_context(
pytest.warns(
AirflowProviderDeprecationWarning,
match="Use ``is_authorized_dag`` on ``DagAccessEntity.RUN`` instead for a dag level access control.",
)
)

with pytest.raises(AirflowException) as e:
getattr(auth_manager, function)(method="GET", user=user)

assert "Unexpected error" in str(e.value)
assert "Unexpected error" in str(e.value)

@pytest.mark.parametrize(
"function",
Expand All @@ -315,10 +335,19 @@ def test_is_authorized_invalid_request(self, function, auth_manager, user):
resp.text = json.dumps({"error": "invalid_scope", "error_description": "Invalid scopes: GET"})
auth_manager.http_session.post = Mock(return_value=resp)

with pytest.raises(AirflowException) as e:
getattr(auth_manager, function)(method="GET", user=user)
with ExitStack() as stack:
if function == "is_authorized_backfill":
stack.enter_context(
pytest.warns(
AirflowProviderDeprecationWarning,
match="Use ``is_authorized_dag`` on ``DagAccessEntity.RUN`` instead for a dag level access control.",
)
)

with pytest.raises(AirflowException) as e:
getattr(auth_manager, function)(method="GET", user=user)

assert "Request not recognized by Keycloak. invalid_scope. Invalid scopes: GET" in str(e.value)
assert "Request not recognized by Keycloak. invalid_scope. Invalid scopes: GET" in str(e.value)

@pytest.mark.parametrize(
("method", "access_entity", "details", "permission", "attributes"),
Expand Down