Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions airflow/api_connexion/endpoints/asset_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,9 @@ def get_asset_queued_events(
*, uri: str, before: str | None = None, session: Session = NEW_SESSION
) -> APIResponse:
"""Get queued asset events for an asset."""
permitted_dag_ids = get_auth_manager().get_permitted_dag_ids(methods=["GET"])
permitted_dag_ids = get_auth_manager().get_permitted_dag_ids(
user=get_auth_manager().get_user(), methods=["GET"]
)
where_clause = _generate_queued_event_where_clause(
uri=uri, before=before, permitted_dag_ids=permitted_dag_ids
)
Expand Down Expand Up @@ -313,7 +315,9 @@ def delete_asset_queued_events(
*, uri: str, before: str | None = None, session: Session = NEW_SESSION
) -> APIResponse:
"""Delete queued asset events for an asset."""
permitted_dag_ids = get_auth_manager().get_permitted_dag_ids(methods=["GET"])
permitted_dag_ids = get_auth_manager().get_permitted_dag_ids(
user=get_auth_manager().get_user(), methods=["GET"]
)
where_clause = _generate_queued_event_where_clause(
uri=uri, before=before, permitted_dag_ids=permitted_dag_ids
)
Expand Down
2 changes: 1 addition & 1 deletion airflow/api_connexion/endpoints/dag_parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def reparse_dag_file(*, file_token: str, session: Session = NEW_SESSION) -> Resp
raise NotFound("File not found")

# Check if user has read access to all the DAGs defined in the file
if not get_auth_manager().batch_is_authorized_dag(requests):
if not get_auth_manager().batch_is_authorized_dag(requests, user=get_auth_manager().get_user()):
raise PermissionDenied()

parsing_request = DagPriorityParsingRequest(fileloc=path)
Expand Down
2 changes: 1 addition & 1 deletion airflow/api_connexion/endpoints/dag_source_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def get_dag_source(
]

# Check if user has read access to all the DAGs defined in the file
if not get_auth_manager().batch_is_authorized_dag(requests):
if not get_auth_manager().batch_is_authorized_dag(requests, user=get_auth_manager().get_user()):
raise PermissionDenied()
dag_source = dag_version.dag_code.source_code
version_number = dag_version.version_number
Expand Down
6 changes: 3 additions & 3 deletions airflow/api_connexion/endpoints/import_error_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def get_import_error(*, import_error_id: int, session: Session = NEW_SESSION) ->
)
session.expunge(error)

can_read_all_dags = get_auth_manager().is_authorized_dag(method="GET")
can_read_all_dags = get_auth_manager().is_authorized_dag(method="GET", user=get_auth_manager().get_user())
if not can_read_all_dags:
readable_dag_ids = security.get_readable_dags()
file_dag_ids = {
Expand Down Expand Up @@ -95,7 +95,7 @@ def get_import_errors(
query = select(ParseImportError)
query = apply_sorting(query, order_by, to_replace, allowed_sort_attrs)

can_read_all_dags = get_auth_manager().is_authorized_dag(method="GET")
can_read_all_dags = get_auth_manager().is_authorized_dag(method="GET", user=get_auth_manager().get_user())

if not can_read_all_dags:
# if the user doesn't have access to all DAGs, only display errors from visible DAGs
Expand Down Expand Up @@ -133,7 +133,7 @@ def get_import_errors(
}
for dag_id in file_dag_ids
]
if not get_auth_manager().batch_is_authorized_dag(requests):
if not get_auth_manager().batch_is_authorized_dag(requests, user=get_auth_manager().get_user()):
session.expunge(import_error)
import_error.stacktrace = "REDACTED - you do not have read permission on all DAGs in the file"

Expand Down
2 changes: 1 addition & 1 deletion airflow/api_connexion/endpoints/task_instance_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,7 +387,7 @@ def get_task_instances_batch(session: Session = NEW_SESSION) -> APIResponse:
}
for id in dag_ids
]
if not get_auth_manager().batch_is_authorized_dag(requests):
if not get_auth_manager().batch_is_authorized_dag(requests, user=get_auth_manager().get_user()):
raise PermissionDenied(detail=f"User not allowed to access some of these DAGs: {list(dag_ids)}")
else:
dag_ids = get_auth_manager().get_permitted_dag_ids(user=g.user)
Expand Down
34 changes: 26 additions & 8 deletions airflow/api_connexion/security.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,9 @@ def decorated(*args, **kwargs):
section: str | None = kwargs.get("section")
return _requires_access(
is_authorized_callback=lambda: get_auth_manager().is_authorized_configuration(
method=method, details=ConfigurationDetails(section=section)
method=method,
details=ConfigurationDetails(section=section),
user=get_auth_manager().get_user(),
),
func=func,
args=args,
Expand All @@ -97,7 +99,9 @@ def decorated(*args, **kwargs):
connection_id: str | None = kwargs.get("connection_id")
return _requires_access(
is_authorized_callback=lambda: get_auth_manager().is_authorized_connection(
method=method, details=ConnectionDetails(conn_id=connection_id)
method=method,
details=ConnectionDetails(conn_id=connection_id),
user=get_auth_manager().get_user(),
),
func=func,
args=args,
Expand All @@ -120,13 +124,15 @@ def callback() -> bool | DagAccessEntity:
method=method,
access_entity=access_entity,
details=DagDetails(id=dag_id),
user=get_auth_manager().get_user(),
)
else:
# here we know dag_id is not provided.
# check is the user authorized to access all DAGs?
if get_auth_manager().is_authorized_dag(
method=method,
access_entity=access_entity,
user=get_auth_manager().get_user(),
):
return True
elif access_entity:
Expand All @@ -138,7 +144,9 @@ def callback() -> bool | DagAccessEntity:
# but we leave it to the endpoint function to properly restrict access beyond that
if method not in ("GET", "PUT"):
return False
return any(get_auth_manager().get_permitted_dag_ids(methods=[method]))
return any(
get_auth_manager().get_permitted_dag_ids(user=get_auth_manager().get_user(), methods=[method])
)

return callback

Expand All @@ -165,7 +173,9 @@ def decorated(*args, **kwargs):
uri: str | None = kwargs.get("uri")
return _requires_access(
is_authorized_callback=lambda: get_auth_manager().is_authorized_asset(
method=method, details=AssetDetails(uri=uri)
method=method,
details=AssetDetails(uri=uri),
user=get_auth_manager().get_user(),
),
func=func,
args=args,
Expand All @@ -184,7 +194,9 @@ def decorated(*args, **kwargs):
pool_name: str | None = kwargs.get("pool_name")
return _requires_access(
is_authorized_callback=lambda: get_auth_manager().is_authorized_pool(
method=method, details=PoolDetails(name=pool_name)
method=method,
details=PoolDetails(name=pool_name),
user=get_auth_manager().get_user(),
),
func=func,
args=args,
Expand All @@ -203,7 +215,9 @@ def decorated(*args, **kwargs):
variable_key: str | None = kwargs.get("variable_key")
return _requires_access(
is_authorized_callback=lambda: get_auth_manager().is_authorized_variable(
method=method, details=VariableDetails(key=variable_key)
method=method,
details=VariableDetails(key=variable_key),
user=get_auth_manager().get_user(),
),
func=func,
args=args,
Expand All @@ -220,7 +234,9 @@ def requires_access_decorator(func: T):
@wraps(func)
def decorated(*args, **kwargs):
return _requires_access(
is_authorized_callback=lambda: get_auth_manager().is_authorized_view(access_view=access_view),
is_authorized_callback=lambda: get_auth_manager().is_authorized_view(
access_view=access_view, user=get_auth_manager().get_user()
),
func=func,
args=args,
kwargs=kwargs,
Expand All @@ -240,7 +256,9 @@ def requires_access_decorator(func: T):
def decorated(*args, **kwargs):
return _requires_access(
is_authorized_callback=lambda: get_auth_manager().is_authorized_custom_view(
method=method, resource_name=resource_name
method=method,
resource_name=resource_name,
user=get_auth_manager().get_user(),
),
func=func,
args=args,
Expand Down
Loading