Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Check DAG read permission before accessing DAG code #36257

Merged
merged 1 commit into from
Dec 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 16 additions & 3 deletions airflow/api_connexion/endpoints/dag_source_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,25 +17,38 @@
from __future__ import annotations

from http import HTTPStatus
from typing import TYPE_CHECKING

from flask import Response, current_app, request
from itsdangerous import BadSignature, URLSafeSerializer

from airflow.api_connexion import security
from airflow.api_connexion.exceptions import NotFound
from airflow.api_connexion.exceptions import NotFound, PermissionDenied
from airflow.api_connexion.schemas.dag_source_schema import dag_source_schema
from airflow.api_connexion.security import get_readable_dags
from airflow.auth.managers.models.resource_details import DagAccessEntity
from airflow.models.dag import DagModel
from airflow.models.dagcode import DagCode
from airflow.utils.session import NEW_SESSION, provide_session

if TYPE_CHECKING:
from sqlalchemy.orm import Session


@security.requires_access_dag("GET", DagAccessEntity.CODE)
def get_dag_source(*, file_token: str) -> Response:
@provide_session
def get_dag_source(*, file_token: str, session: Session = NEW_SESSION) -> Response:
"""Get source code using file token."""
secret_key = current_app.config["SECRET_KEY"]
auth_s = URLSafeSerializer(secret_key)
try:
path = auth_s.loads(file_token)
dag_source = DagCode.code(path)
dag_ids = session.query(DagModel.dag_id).filter(DagModel.fileloc == path).all()
readable_dags = get_readable_dags()
# Check if user has read access to all the DAGs defined in the file
if any(dag_id[0] not in readable_dags for dag_id in dag_ids):
raise PermissionDenied()
dag_source = DagCode.code(path, session=session)
except (BadSignature, FileNotFoundError):
raise NotFound("Dag source not found")

Expand Down
5 changes: 3 additions & 2 deletions airflow/models/dagcode.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,12 +177,13 @@ def get_code_by_fileloc(cls, fileloc: str) -> str:
return cls.code(fileloc)

@classmethod
def code(cls, fileloc) -> str:
@provide_session
def code(cls, fileloc, session: Session = NEW_SESSION) -> str:
"""Return source code for this DagCode object.

:return: source code as string
"""
return cls._get_code_from_db(fileloc)
return cls._get_code_from_db(fileloc, session)
Comment on lines -180 to +186
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change avoids creating two different sessions: the endpoint calls the public `code` method rather than the protected `_get_code_from_db`, so `code` now accepts a session and passes it through.


@staticmethod
def _get_code_from_file(fileloc):
Expand Down
67 changes: 59 additions & 8 deletions tests/api_connexion/endpoints/test_dag_source_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@

ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir, os.pardir))
EXAMPLE_DAG_FILE = os.path.join("airflow", "example_dags", "example_bash_operator.py")
EXAMPLE_DAG_ID = "example_bash_operator"
TEST_DAG_ID = "latest_only"
NOT_READABLE_DAG_ID = "latest_only_with_trigger"
TEST_MULTIPLE_DAGS_ID = "dataset_produces_1"


@pytest.fixture(scope="module")
Expand All @@ -45,6 +49,18 @@ def configured_app(minimal_app_for_api):
role_name="Test",
permissions=[(permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_CODE)], # type: ignore
)
app.appbuilder.sm.sync_perm_for_dag( # type: ignore
TEST_DAG_ID,
access_control={"Test": [permissions.ACTION_CAN_READ]},
)
app.appbuilder.sm.sync_perm_for_dag( # type: ignore
EXAMPLE_DAG_ID,
access_control={"Test": [permissions.ACTION_CAN_READ]},
)
app.appbuilder.sm.sync_perm_for_dag( # type: ignore
TEST_MULTIPLE_DAGS_ID,
access_control={"Test": [permissions.ACTION_CAN_READ]},
)
create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore

yield app
Expand Down Expand Up @@ -80,10 +96,10 @@ def _get_dag_file_docstring(fileloc: str) -> str | None:
def test_should_respond_200_text(self, url_safe_serializer):
dagbag = DagBag(dag_folder=EXAMPLE_DAG_FILE)
dagbag.sync_to_db()
first_dag: DAG = next(iter(dagbag.dags.values()))
dag_docstring = self._get_dag_file_docstring(first_dag.fileloc)
test_dag: DAG = dagbag.dags[TEST_DAG_ID]
dag_docstring = self._get_dag_file_docstring(test_dag.fileloc)

url = f"/api/v1/dagSources/{url_safe_serializer.dumps(first_dag.fileloc)}"
url = f"/api/v1/dagSources/{url_safe_serializer.dumps(test_dag.fileloc)}"
response = self.client.get(
url, headers={"Accept": "text/plain"}, environ_overrides={"REMOTE_USER": "test"}
)
Expand All @@ -95,10 +111,10 @@ def test_should_respond_200_text(self, url_safe_serializer):
def test_should_respond_200_json(self, url_safe_serializer):
dagbag = DagBag(dag_folder=EXAMPLE_DAG_FILE)
dagbag.sync_to_db()
first_dag: DAG = next(iter(dagbag.dags.values()))
dag_docstring = self._get_dag_file_docstring(first_dag.fileloc)
test_dag: DAG = dagbag.dags[TEST_DAG_ID]
dag_docstring = self._get_dag_file_docstring(test_dag.fileloc)

url = f"/api/v1/dagSources/{url_safe_serializer.dumps(first_dag.fileloc)}"
url = f"/api/v1/dagSources/{url_safe_serializer.dumps(test_dag.fileloc)}"
response = self.client.get(
url, headers={"Accept": "application/json"}, environ_overrides={"REMOTE_USER": "test"}
)
Expand All @@ -110,9 +126,9 @@ def test_should_respond_200_json(self, url_safe_serializer):
def test_should_respond_406(self, url_safe_serializer):
dagbag = DagBag(dag_folder=EXAMPLE_DAG_FILE)
dagbag.sync_to_db()
first_dag: DAG = next(iter(dagbag.dags.values()))
test_dag: DAG = dagbag.dags[TEST_DAG_ID]

url = f"/api/v1/dagSources/{url_safe_serializer.dumps(first_dag.fileloc)}"
url = f"/api/v1/dagSources/{url_safe_serializer.dumps(test_dag.fileloc)}"
response = self.client.get(
url, headers={"Accept": "image/webp"}, environ_overrides={"REMOTE_USER": "test"}
)
Expand Down Expand Up @@ -151,3 +167,38 @@ def test_should_raise_403_forbidden(self, url_safe_serializer):
environ_overrides={"REMOTE_USER": "test_no_permissions"},
)
assert response.status_code == 403

def test_should_respond_403_not_readable(self, url_safe_serializer):
dagbag = DagBag(dag_folder=EXAMPLE_DAG_FILE)
dagbag.sync_to_db()
dag: DAG = dagbag.dags[NOT_READABLE_DAG_ID]

response = self.client.get(
f"/api/v1/dagSources/{url_safe_serializer.dumps(dag.fileloc)}",
headers={"Accept": "text/plain"},
environ_overrides={"REMOTE_USER": "test"},
)
read_dag = self.client.get(
f"/api/v1/dags/{NOT_READABLE_DAG_ID}",
environ_overrides={"REMOTE_USER": "test"},
)
assert response.status_code == 403
assert read_dag.status_code == 403

def test_should_respond_403_some_dags_not_readable_in_the_file(self, url_safe_serializer):
dagbag = DagBag(dag_folder=EXAMPLE_DAG_FILE)
dagbag.sync_to_db()
dag: DAG = dagbag.dags[TEST_MULTIPLE_DAGS_ID]

response = self.client.get(
f"/api/v1/dagSources/{url_safe_serializer.dumps(dag.fileloc)}",
headers={"Accept": "text/plain"},
environ_overrides={"REMOTE_USER": "test"},
)

read_dag = self.client.get(
f"/api/v1/dags/{TEST_MULTIPLE_DAGS_ID}",
environ_overrides={"REMOTE_USER": "test"},
)
assert response.status_code == 403
assert read_dag.status_code == 200