Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -534,6 +534,7 @@ repos:
files: >
(?x)
^airflow-ctl.*\.py$|
^providers/fab.*\.py$|
^task_sdk.*\.py$
pass_filenames: true
- id: update-supported-versions
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@
from flask_login import LoginManager
from itsdangerous import want_bytes
from markupsafe import Markup, escape
from sqlalchemy import func, inspect, or_, select
from sqlalchemy import delete, func, inspect, or_, select
from sqlalchemy.exc import MultipleResultsFound
from sqlalchemy.orm import joinedload
from werkzeug.security import check_password_hash, generate_password_hash
Expand Down Expand Up @@ -565,7 +565,7 @@ def reset_user_sessions(self, user: User) -> None:
interface = self.appbuilder.get_app.session_interface
session = interface.db.session
user_session_model = interface.sql_session_model
num_sessions = session.query(user_session_model).count()
num_sessions = session.scalars(select(func.count()).select_from(user_session_model)).one()
if num_sessions > MAX_NUM_DATABASE_USER_SESSIONS:
safe_username = escape(user.username)
self._cli_safe_flash(
Expand All @@ -580,7 +580,7 @@ def reset_user_sessions(self, user: User) -> None:
"warning",
)
else:
for s in session.query(user_session_model):
for s in session.scalars(select(user_session_model)).all():
session_details = interface.serializer.loads(want_bytes(s.data))
if session_details.get("_user_id") == user.id:
session.delete(s)
Expand Down Expand Up @@ -1209,12 +1209,14 @@ def clean_perms(self) -> None:
"""FAB leaves faulty permissions that need to be cleaned up."""
self.log.debug("Cleaning faulty perms")
sesh = self.appbuilder.get_session
perms = sesh.query(Permission).filter(
or_(
Permission.action == None, # noqa: E711
Permission.resource == None, # noqa: E711
perms = sesh.scalars(
select(Permission).where(
or_(
Permission.action == None, # noqa: E711
Permission.resource == None, # noqa: E711
)
)
)
).all()
# Since FAB doesn't define ON DELETE CASCADE on these tables, we need
# to delete the _object_ so that SQLA knows to delete the many-to-many
# relationship object too. :(
Expand Down Expand Up @@ -1292,10 +1294,10 @@ def find_role(self, name):

:param name: the role name
"""
return self.get_session.query(self.role_model).filter_by(name=name).one_or_none()
return self.get_session.scalars(select(self.role_model).filter_by(name=name)).unique().one_or_none()

def get_all_roles(self):
return self.get_session.query(self.role_model).all()
return self.get_session.scalars(select(self.role_model)).unique().all()

def delete_role(self, role_name: str) -> None:
"""
Expand All @@ -1304,10 +1306,10 @@ def delete_role(self, role_name: str) -> None:
:param role_name: the name of a role in the ab_role table
"""
session = self.get_session
role = session.query(Role).filter(Role.name == role_name).first()
role = session.scalars(select(Role).where(Role.name == role_name)).first()
if role:
log.info("Deleting role '%s'", role_name)
session.delete(role)
session.execute(delete(Role).where(Role.name == role_name))
session.commit()
else:
raise AirflowException(f"Role named '{role_name}' does not exist")
Expand Down Expand Up @@ -1338,7 +1340,11 @@ def get_roles_from_keys(self, role_keys: list[str]) -> set[Role]:
return _roles

def get_public_role(self):
return self.get_session.query(self.role_model).filter_by(name=self.auth_role_public).one_or_none()
return (
self.get_session.scalars(select(self.role_model).filter_by(name=self.auth_role_public))
.unique()
.one_or_none()
)

"""
-----------
Expand Down Expand Up @@ -1395,7 +1401,7 @@ def get_user_by_id(self, pk):

def count_users(self):
"""Return the number of users in the database."""
return self.get_session.query(func.count(self.user_model.id)).scalar()
return self.get_session.scalar(select(func.count(self.user_model.id)))

def add_register_user(self, username, first_name, last_name, email, password="", hashed_password=""):
"""
Expand Down Expand Up @@ -1427,22 +1433,22 @@ def find_user(self, username=None, email=None):
if username:
try:
if self.auth_username_ci:
return (
self.get_session.query(self.user_model)
.filter(func.lower(self.user_model.username) == func.lower(username))
.one_or_none()
return self.get_session.scalars(
select(self.user_model).where(
func.lower(self.user_model.username) == func.lower(username)
)
).one_or_none()
return self.get_session.scalars(
select(self.user_model).where(
func.lower(self.user_model.username) == func.lower(username)
)
return (
self.get_session.query(self.user_model)
.filter(func.lower(self.user_model.username) == func.lower(username))
.one_or_none()
)
).one_or_none()
except MultipleResultsFound:
log.error("Multiple results found for user %s", username)
return None
elif email:
try:
return self.get_session.query(self.user_model).filter_by(email=email).one_or_none()
return self.get_session.scalars(select(self.user_model).filter_by(email=email)).one_or_none()
except MultipleResultsFound:
log.error("Multiple results found for user with email %s", email)
return None
Expand Down Expand Up @@ -1474,7 +1480,7 @@ def del_register_user(self, register_user):
return False

def get_all_users(self):
return self.get_session.query(self.user_model).all()
return self.get_session.scalars(select(self.user_model)).all()

def update_user_auth_stat(self, user, success=True):
"""
Expand Down Expand Up @@ -1514,7 +1520,7 @@ def get_action(self, name: str) -> Action:

:param name: name
"""
return self.get_session.query(self.action_model).filter_by(name=name).one_or_none()
return self.get_session.scalars(select(self.action_model).filter_by(name=name)).one_or_none()

def create_action(self, name):
"""
Expand Down Expand Up @@ -1547,11 +1553,9 @@ def delete_action(self, name: str) -> bool:
log.warning(const.LOGMSG_WAR_SEC_DEL_PERMISSION, name)
return False
try:
perms = (
self.get_session.query(self.permission_model)
.filter(self.permission_model.action == action)
.all()
)
perms = self.get_session.scalars(
select(self.permission_model).where(self.permission_model.action == action)
).all()
if perms:
log.warning(const.LOGMSG_WAR_SEC_DEL_PERM_PVM, action, perms)
return False
Expand All @@ -1575,7 +1579,7 @@ def get_resource(self, name: str) -> Resource | None:

:param name: Name of resource
"""
return self.get_session.query(self.resource_model).filter_by(name=name).one_or_none()
return self.get_session.scalars(select(self.resource_model).filter_by(name=name)).one_or_none()

def create_resource(self, name) -> Resource | None:
"""
Expand Down Expand Up @@ -1617,10 +1621,13 @@ def get_permission(
resource = self.get_resource(resource_name)
if action and resource:
return (
self.get_session.query(self.permission_model)
.filter_by(action=action, resource=resource)
self.get_session.scalars(
select(self.permission_model).filter_by(action=action, resource=resource)
)
.unique()
.one_or_none()
)

return None

def get_resource_permissions(self, resource: Resource) -> Permission:
Expand All @@ -1629,7 +1636,9 @@ def get_resource_permissions(self, resource: Resource) -> Permission:

:param resource: Object representing a single resource.
"""
return self.get_session.query(self.permission_model).filter_by(resource_id=resource.id).all()
return self.get_session.scalars(
select(self.permission_model).filter_by(resource_id=resource.id)
).all()

def create_permission(self, action_name, resource_name) -> Permission | None:
"""
Expand Down Expand Up @@ -1676,9 +1685,9 @@ def delete_permission(self, action_name: str, resource_name: str) -> None:
perm = self.get_permission(action_name, resource_name)
if not perm:
return
roles = (
self.get_session.query(self.role_model).filter(self.role_model.permissions.contains(perm)).first()
)
roles = self.get_session.scalars(
select(self.role_model).where(self.role_model.permissions.contains(perm))
).first()
if roles:
log.warning(const.LOGMSG_WAR_SEC_DEL_PERMVIEW, resource_name, action_name, roles)
return
Expand All @@ -1687,7 +1696,9 @@ def delete_permission(self, action_name: str, resource_name: str) -> None:
self.get_session.delete(perm)
self.get_session.commit()
# if no more permission on permission view, delete permission
if not self.get_session.query(self.permission_model).filter_by(action=perm.action).all():
if not self.get_session.scalars(
select(self.permission_model).filter_by(action=perm.action)
).all():
self.delete_action(perm.action.name)
log.info(const.LOGMSG_INF_SEC_DEL_PERMVIEW, action_name, resource_name)
except Exception as e:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from __future__ import annotations

import pytest
from sqlalchemy import select

from airflow.providers.fab.www.api_connexion.exceptions import EXCEPTIONS_LINK_MAP
from airflow.providers.fab.www.security import permissions
Expand Down Expand Up @@ -73,7 +74,7 @@ def teardown_method(self):
session = self.app.appbuilder.get_session
existing_roles = set(EXISTING_ROLES)
existing_roles.update(["Test", "TestNoPermissions"])
roles = session.query(Role).filter(~Role.name.in_(existing_roles)).all()
roles = session.scalars(select(Role).where(~Role.name.in_(existing_roles))).unique().all()
for role in roles:
delete_role(self.app, role.name)

Expand Down Expand Up @@ -353,7 +354,7 @@ def test_delete_should_respond_204(self, session):
role = create_role(self.app, "mytestrole")
response = self.client.delete(f"/fab/v1/roles/{role.name}", environ_overrides={"REMOTE_USER": "test"})
assert response.status_code == 204
role_obj = session.query(Role).filter(Role.name == role.name).all()
role_obj = session.scalars(select(Role).where(Role.name == role.name)).all()
assert len(role_obj) == 0

def test_delete_should_respond_404(self):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,10 @@
import unittest.mock

import pytest
from sqlalchemy.sql.functions import count
from sqlalchemy import delete, func, select

from airflow.providers.fab.www.api_connexion.exceptions import EXCEPTIONS_LINK_MAP
from airflow.providers.fab.www.security import permissions
from airflow.utils import timezone
from airflow.utils.session import create_session

from tests_common.test_utils.compat import ignore_provider_compatibility_error
Expand All @@ -35,6 +34,12 @@
delete_user,
)

try:
from airflow.utils import timezone # type: ignore[attr-defined]
except AttributeError:
from airflow.sdk import timezone


with ignore_provider_compatibility_error("2.9.0+", __file__):
from airflow.providers.fab.auth_manager.models import User

Expand Down Expand Up @@ -85,8 +90,7 @@ def setup_attrs(self, configured_app) -> None:

def teardown_method(self) -> None:
# Delete users that have our custom default time
users = self.session.query(User).filter(User.changed_on == timezone.parse(DEFAULT_TIME))
users.delete(synchronize_session=False)
self.session.execute(delete(User).where(User.changed_on == timezone.parse(DEFAULT_TIME)))
self.session.commit()

def _create_users(self, count, roles=None):
Expand Down Expand Up @@ -354,11 +358,11 @@ def test_should_return_conf_max_if_req_max_above_conf(self):

def _delete_user(**filters):
with create_session() as session:
user = session.query(User).filter_by(**filters).first()
user = session.scalars(select(User).filter_by(**filters)).first()
if user is None:
return
user.roles = []
session.delete(user)
session.execute(delete(User).filter_by(**filters))


@pytest.fixture
Expand Down Expand Up @@ -676,10 +680,7 @@ def test_password_hashed(
assert "password" not in response.json

mock_generate_password_hash.assert_called_once_with("new-pass")

password_in_db = (
self.session.query(User.password).filter(User.username == autoclean_username).scalar()
)
password_in_db = self.session.scalar(select(User.password).where(User.username == autoclean_username))
assert password_in_db == "fake-hashed-pass"

@pytest.mark.usefixtures("autoclean_admin_user")
Expand Down Expand Up @@ -788,15 +789,19 @@ def test_delete(self, autoclean_username):
environ_overrides={"REMOTE_USER": "test"},
)
assert response.status_code == 204, response.json # NO CONTENT.
assert self.session.query(count(User.id)).filter(User.username == autoclean_username).scalar() == 0
assert (
self.session.scalar(select(func.count(User.id)).where(User.username == autoclean_username)) == 0
)

@pytest.mark.usefixtures("autoclean_admin_user")
def test_unauthenticated(self, autoclean_username):
response = self.client.delete(
f"/fab/v1/users/{autoclean_username}",
)
assert response.status_code == 401, response.json
assert self.session.query(count(User.id)).filter(User.username == autoclean_username).scalar() == 1
assert (
self.session.scalar(select(func.count(User.id)).where(User.username == autoclean_username)) == 1
)

@pytest.mark.usefixtures("autoclean_admin_user")
def test_forbidden(self, autoclean_username):
Expand All @@ -805,7 +810,9 @@ def test_forbidden(self, autoclean_username):
environ_overrides={"REMOTE_USER": "test_no_permissions"},
)
assert response.status_code == 403, response.json
assert self.session.query(count(User.id)).filter(User.username == autoclean_username).scalar() == 1
assert (
self.session.scalar(select(func.count(User.id)).where(User.username == autoclean_username)) == 1
)

def test_not_found(self, autoclean_username):
# This test does not populate autoclean_admin_user into the database.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from __future__ import annotations

import pytest
from sqlalchemy import select

from airflow.utils import timezone

Expand Down Expand Up @@ -60,7 +61,7 @@ def setup_attrs(self, configured_app) -> None:
self.session = self.app.appbuilder.get_session

def teardown_method(self):
user = self.session.query(User).filter(User.email == TEST_EMAIL).first()
user = self.session.scalars(select(User).where(User.email == TEST_EMAIL)).first()
if user:
self.session.delete(user)
self.session.commit()
Expand All @@ -80,7 +81,7 @@ def test_serialize(self):
self.session.add(user_model)
user_model.roles = [self.role]
self.session.commit()
user = self.session.query(User).filter(User.email == TEST_EMAIL).first()
user = self.session.scalars(select(User).where(User.email == TEST_EMAIL)).first()
deserialized_user = user_collection_item_schema.dump(user)
# No user_id and password in dump
assert deserialized_user == {
Expand Down Expand Up @@ -111,7 +112,7 @@ def test_serialize(self):
)
self.session.add(user_model)
self.session.commit()
user = self.session.query(User).filter(User.email == TEST_EMAIL).first()
user = self.session.scalars(select(User).where(User.email == TEST_EMAIL)).first()
deserialized_user = user_schema.dump(user)
# No user_id and password in dump
assert deserialized_user == {
Expand Down
Loading