diff --git a/airflow/dag_processing/collection.py b/airflow/dag_processing/collection.py index a924186a36be9..7820b7339dcce 100644 --- a/airflow/dag_processing/collection.py +++ b/airflow/dag_processing/collection.py @@ -217,7 +217,7 @@ def _sync_dag_perms(dag: MaybeSerializedDAG, session: Session): dag_id = dag.dag_id log.debug("Syncing DAG permissions: %s to the DB", dag_id) - from airflow.www.security_appless import ApplessAirflowSecurityManager + from airflow.providers.fab.www.security_appless import ApplessAirflowSecurityManager security_manager = ApplessAirflowSecurityManager(session=session) security_manager.sync_perm_for_dag(dag_id, dag.access_control) diff --git a/airflow/utils/db.py b/airflow/utils/db.py index 5e370aeb1b7a7..314095fea2ad1 100644 --- a/airflow/utils/db.py +++ b/airflow/utils/db.py @@ -735,7 +735,7 @@ def _get_flask_db(sql_database_uri): from flask import Flask from flask_sqlalchemy import SQLAlchemy - from airflow.www.session import AirflowDatabaseSessionInterface + from airflow.providers.fab.www.session import AirflowDatabaseSessionInterface flask_app = Flask(__name__) flask_app.config["SQLALCHEMY_DATABASE_URI"] = sql_database_uri diff --git a/airflow/www/__init__.py b/airflow/www/__init__.py deleted file mode 100644 index 217e5db960782..0000000000000 --- a/airflow/www/__init__.py +++ /dev/null @@ -1,17 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. diff --git a/airflow/www/session.py b/airflow/www/session.py deleted file mode 100644 index 763b909ae0d94..0000000000000 --- a/airflow/www/session.py +++ /dev/null @@ -1,41 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -from __future__ import annotations - -from flask import request -from flask.sessions import SecureCookieSessionInterface -from flask_session.sessions import SqlAlchemySessionInterface - - -class SessionExemptMixin: - """Exempt certain blueprints/paths from autogenerated sessions.""" - - def save_session(self, *args, **kwargs): - """Prevent creating session from REST API and health requests.""" - if request.blueprint == "/api/v1": - return None - if request.path == "/health": - return None - return super().save_session(*args, **kwargs) - - -class AirflowDatabaseSessionInterface(SessionExemptMixin, SqlAlchemySessionInterface): - """Session interface that exempts some routes and stores session data in the database.""" - - -class AirflowSecureCookieSessionInterface(SessionExemptMixin, SecureCookieSessionInterface): - """Session interface that exempts some routes and stores session data in a signed cookie.""" diff --git a/airflow/www/validators.py b/airflow/www/validators.py deleted file mode 100644 index 21d156a8dd178..0000000000000 --- a/airflow/www/validators.py +++ /dev/null @@ -1,139 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -from __future__ import annotations - -import json -from json import JSONDecodeError - -from wtforms.validators import EqualTo, ValidationError - -from airflow.models.connection import CONN_ID_MAX_LEN, sanitize_conn_id -from airflow.utils import helpers - - -class GreaterEqualThan(EqualTo): - """ - Compares the values of two fields. - - :param fieldname: - The name of the other field to compare to. - :param message: - Error message to raise in case of a validation error. Can be - interpolated with `%(other_label)s` and `%(other_name)s` to provide a - more helpful error. - """ - - def __call__(self, form, field): - try: - other = form[self.fieldname] - except KeyError: - raise ValidationError(field.gettext(f"Invalid field name '{self.fieldname}'.")) - - if field.data is None or other.data is None: - return - - if field.data < other.data: - message_args = { - "other_label": hasattr(other, "label") and other.label.text or self.fieldname, - "other_name": self.fieldname, - } - message = self.message - if message is None: - message = field.gettext( - f"Field must be greater than or equal to {message_args['other_label']}." - ) - else: - message = message % message_args - - raise ValidationError(message) - - -class ValidJson: - """ - Validates data is valid JSON. - - :param message: - Error message to raise in case of a validation error. - """ - - def __init__(self, message=None): - self.message = message - - def __call__(self, form, field): - if field.data: - try: - json.loads(field.data) - except JSONDecodeError as ex: - message = self.message or f"JSON Validation Error: {ex}" - raise ValidationError(message=field.gettext(message.format(field.data))) - - -class ValidKey: - """ - Validates values that will be used as keys. - - :param max_length: - The maximum allowed length of the given key - """ - - def __init__(self, max_length=200): - self.max_length = max_length - - def __call__(self, form, field): - if field.data: - try: - helpers.validate_key(field.data, self.max_length) - except Exception as e: - raise ValidationError(str(e)) - - -class ReadOnly: - """ - Adds readonly flag to a field. - - When using this you normally will need to override the form's populate_obj method, - so field.populate_obj is not called for read-only fields. - """ - - def __call__(self, form, field): - field.flags.readonly = True - - -class ValidConnID: - """ - Validates the connection ID adheres to the desired format. - - :param max_length: - The maximum allowed length of the given Connection ID. - """ - - message = ( - "Connection ID must be alphanumeric characters plus dashes, dots, hashes, colons, semicolons, " - "underscores, exclamation marks, and parentheses" - ) - - def __init__( - self, - max_length: int = CONN_ID_MAX_LEN, - ): - self.max_length = max_length - - def __call__(self, form, field): - if field.data: - if sanitize_conn_id(field.data, self.max_length) is None: - raise ValidationError(f"{self.message} for 1 and up to {self.max_length} matches") diff --git a/providers/edge/src/airflow/providers/edge/plugins/edge_executor_plugin.py b/providers/edge/src/airflow/providers/edge/plugins/edge_executor_plugin.py index 3382249e4bad9..f896a0506b87a 100644 --- a/providers/edge/src/airflow/providers/edge/plugins/edge_executor_plugin.py +++ b/providers/edge/src/airflow/providers/edge/plugins/edge_executor_plugin.py @@ -24,6 +24,7 @@ from flask import Blueprint, redirect, request, url_for from flask_appbuilder import BaseView, expose +from markupsafe import Markup from sqlalchemy import select from airflow.auth.managers.models.resource_details import AccessView @@ -32,6 +33,7 @@ from airflow.models.taskinstance import TaskInstanceState from airflow.plugins_manager import AirflowPlugin from airflow.providers.edge.version_compat import AIRFLOW_V_3_0_PLUS +from airflow.utils.state import State if AIRFLOW_V_3_0_PLUS: from airflow.providers.fab.www.auth import has_access_view @@ -39,7 +41,6 @@ from airflow.www.auth import has_access_view # type: ignore from airflow.utils.session import NEW_SESSION, provide_session from airflow.utils.yaml import safe_load -from airflow.www import utils as wwwutils if TYPE_CHECKING: from sqlalchemy.orm import Session @@ -80,6 +81,18 @@ def _get_api_endpoint() -> dict[str, Any]: } +def _state_token(state): + """Return a formatted string with HTML for a given State.""" + color = State.color(state) + fg_color = State.color_fg(state) + return Markup( + """ + {state} + """ + ).format(color=color, state=state, fg_color=fg_color) + + def modify_maintenance_comment_on_update(maintenance_comment: str | None, username: str) -> str: if maintenance_comment: if re.search( @@ -121,7 +134,7 @@ def jobs(self, session: Session = NEW_SESSION): jobs = session.scalars(select(EdgeJobModel).order_by(EdgeJobModel.queued_dttm)).all() html_states = { - str(state): wwwutils.state_token(str(state)) for state in TaskInstanceState.__members__.values() + str(state): _state_token(str(state)) for state in TaskInstanceState.__members__.values() } return self.render_template("edge_worker_jobs.html", jobs=jobs, html_states=html_states) diff --git a/airflow/www/security_appless.py b/providers/fab/src/airflow/providers/fab/www/security_appless.py similarity index 100% rename from airflow/www/security_appless.py rename to providers/fab/src/airflow/providers/fab/www/security_appless.py diff --git a/providers/fab/tests/unit/fab/utils.py b/providers/fab/tests/unit/fab/utils.py index 01b43bb7513a8..ca78a5a4ab4b4 100644 --- a/providers/fab/tests/unit/fab/utils.py +++ b/providers/fab/tests/unit/fab/utils.py @@ -211,7 +211,6 @@ def local_context(self): # flask_appbuilder.baseviews.BaseView.render_template "appbuilder", "base_template", - # airflow.www.app.py.create_app (inner method - jinja_globals) "server_timezone", "default_ui_timezone", "hostname", @@ -227,11 +226,8 @@ def local_context(self): "airflow_version", "git_version", "k8s_or_k8scelery_executor", - # airflow.www.static_config.configure_manifest_files "url_for_asset", - # airflow.www.views.AirflowBaseView.render_template "scheduler_job", - # airflow.www.views.AirflowBaseView.extra_args "macros", "auth_manager", "triggerer_job", diff --git a/providers/google/src/airflow/providers/google/cloud/hooks/bigquery.py b/providers/google/src/airflow/providers/google/cloud/hooks/bigquery.py index ff2ee2eb1d8f2..f4beb819a2c4f 100644 --- a/providers/google/src/airflow/providers/google/cloud/hooks/bigquery.py +++ b/providers/google/src/airflow/providers/google/cloud/hooks/bigquery.py @@ -120,7 +120,7 @@ def get_connection_form_widgets(cls) -> dict[str, Any]: from wtforms import validators from wtforms.fields.simple import BooleanField, StringField - from airflow.www.validators import ValidJson + from airflow.providers.google.cloud.utils.validators import ValidJson connection_form_widgets = super().get_connection_form_widgets() connection_form_widgets["use_legacy_sql"] = BooleanField(lazy_gettext("Use Legacy SQL"), default=True) diff --git a/airflow/www/utils.py b/providers/google/src/airflow/providers/google/cloud/utils/validators.py similarity index 56% rename from airflow/www/utils.py rename to providers/google/src/airflow/providers/google/cloud/utils/validators.py index ca4e98645145f..d0a8157edfb49 100644 --- a/airflow/www/utils.py +++ b/providers/google/src/airflow/providers/google/cloud/utils/validators.py @@ -17,18 +17,27 @@ # under the License. from __future__ import annotations -from markupsafe import Markup +import json +from json import JSONDecodeError -from airflow.utils.state import State +from wtforms.validators import ValidationError -def state_token(state): - """Return a formatted string with HTML for a given State.""" - color = State.color(state) - fg_color = State.color_fg(state) - return Markup( - """ - {state} - """ - ).format(color=color, state=state, fg_color=fg_color) +class ValidJson: + """ + Validates data is valid JSON. + + :param message: + Error message to raise in case of a validation error. + """ + + def __init__(self, message=None): + self.message = message + + def __call__(self, form, field): + if field.data: + try: + json.loads(field.data) + except JSONDecodeError as ex: + message = self.message or f"JSON Validation Error: {ex}" + raise ValidationError(message=field.gettext(message.format(field.data))) diff --git a/providers/google/tests/unit/google/cloud/utils/test_validators.py b/providers/google/tests/unit/google/cloud/utils/test_validators.py new file mode 100644 index 0000000000000..1b191f9c21f46 --- /dev/null +++ b/providers/google/tests/unit/google/cloud/utils/test_validators.py @@ -0,0 +1,59 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from unittest import mock + +import pytest +from wtforms.validators import ValidationError + +from airflow.providers.google.cloud.utils.validators import ValidJson + + +class TestValidJson: + def setup_method(self): + self.form_field_mock = mock.MagicMock(data='{"valid":"True"}') + self.form_field_mock.gettext.side_effect = lambda msg: msg + self.form_mock = mock.MagicMock(spec_set=dict) + + def _validate(self, message=None): + validator = ValidJson(message=message) + + return validator(self.form_mock, self.form_field_mock) + + def test_form_field_is_none(self): + self.form_field_mock.data = None + + assert self._validate() is None + + def test_validation_pass(self): + assert self._validate() is None + + def test_validation_raises_default_message(self): + self.form_field_mock.data = "2017-05-04" + + with pytest.raises(ValidationError, match="JSON Validation Error:.*"): + self._validate() + + def test_validation_raises_custom_message(self): + self.form_field_mock.data = "2017-05-04" + + with pytest.raises(ValidationError, match="Invalid JSON"): + self._validate( + message="Invalid JSON: {}", + ) diff --git a/tests/always/test_project_structure.py b/tests/always/test_project_structure.py index ed888981ecde4..7b006e36ffd57 100644 --- a/tests/always/test_project_structure.py +++ b/tests/always/test_project_structure.py @@ -133,6 +133,7 @@ def test_providers_modules_should_have_tests(self): "providers/fab/tests/unit/fab/www/test_airflow_flask_app.py", "providers/fab/tests/unit/fab/www/test_app.py", "providers/fab/tests/unit/fab/www/test_constants.py", + "providers/fab/tests/unit/fab/www/test_security_appless.py", "providers/fab/tests/unit/fab/www/test_security_manager.py", "providers/fab/tests/unit/fab/www/test_session.py", "providers/fab/tests/unit/fab/www/test_utils.py", diff --git a/tests/www/__init__.py b/tests/www/__init__.py deleted file mode 100644 index 217e5db960782..0000000000000 --- a/tests/www/__init__.py +++ /dev/null @@ -1,17 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. diff --git a/tests/www/test_validators.py b/tests/www/test_validators.py deleted file mode 100644 index 02923f9e2afd8..0000000000000 --- a/tests/www/test_validators.py +++ /dev/null @@ -1,167 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -from __future__ import annotations - -from unittest import mock - -import pytest - -from airflow.www import validators - - -class TestGreaterEqualThan: - def setup_method(self): - self.form_field_mock = mock.MagicMock(data="2017-05-06") - self.form_field_mock.gettext.side_effect = lambda msg: msg - self.other_field_mock = mock.MagicMock(data="2017-05-05") - self.other_field_mock.gettext.side_effect = lambda msg: msg - self.other_field_mock.label.text = "other field" - self.form_stub = {"other_field": self.other_field_mock} - self.form_mock = mock.MagicMock(spec_set=dict) - self.form_mock.__getitem__.side_effect = self.form_stub.__getitem__ - - def _validate(self, fieldname=None, message=None): - if fieldname is None: - fieldname = "other_field" - - validator = validators.GreaterEqualThan(fieldname=fieldname, message=message) - - return validator(self.form_mock, self.form_field_mock) - - def test_field_not_found(self): - with pytest.raises(validators.ValidationError, match="^Invalid field name 'some'.$"): - self._validate( - fieldname="some", - ) - - def test_form_field_is_none(self): - self.form_field_mock.data = None - - assert self._validate() is None - - def test_other_field_is_none(self): - self.other_field_mock.data = None - - assert self._validate() is None - - def test_both_fields_are_none(self): - self.form_field_mock.data = None - self.other_field_mock.data = None - - assert self._validate() is None - - def test_validation_pass(self): - assert self._validate() is None - - def test_validation_raises(self): - self.form_field_mock.data = "2017-05-04" - - with pytest.raises( - validators.ValidationError, match="^Field must be greater than or equal to other field.$" - ): - self._validate() - - def test_validation_raises_custom_message(self): - self.form_field_mock.data = "2017-05-04" - - with pytest.raises( - validators.ValidationError, match="^This field must be greater than or equal to MyField.$" - ): - self._validate( - message="This field must be greater than or equal to MyField.", - ) - - -class TestValidJson: - def setup_method(self): - self.form_field_mock = mock.MagicMock(data='{"valid":"True"}') - self.form_field_mock.gettext.side_effect = lambda msg: msg - self.form_mock = mock.MagicMock(spec_set=dict) - - def _validate(self, message=None): - validator = validators.ValidJson(message=message) - - return validator(self.form_mock, self.form_field_mock) - - def test_form_field_is_none(self): - self.form_field_mock.data = None - - assert self._validate() is None - - def test_validation_pass(self): - assert self._validate() is None - - def test_validation_raises_default_message(self): - self.form_field_mock.data = "2017-05-04" - - with pytest.raises(validators.ValidationError, match="JSON Validation Error:.*"): - self._validate() - - def test_validation_raises_custom_message(self): - self.form_field_mock.data = "2017-05-04" - - with pytest.raises(validators.ValidationError, match="Invalid JSON"): - self._validate( - message="Invalid JSON: {}", - ) - - -class TestValidKey: - def setup_method(self): - self.form_field_mock = mock.MagicMock(data="valid_key") - self.form_field_mock.gettext.side_effect = lambda msg: msg - self.form_mock = mock.MagicMock(spec_set=dict) - - def _validate(self): - validator = validators.ValidKey() - - return validator(self.form_mock, self.form_field_mock) - - def test_form_field_is_none(self): - self.form_field_mock.data = None - - assert self._validate() is None - - def test_validation_pass(self): - assert self._validate() is None - - def test_validation_fails_with_trailing_whitespace(self): - self.form_field_mock.data = "invalid key " - - with pytest.raises(validators.ValidationError): - self._validate() - - def test_validation_fails_with_too_many_characters(self): - self.form_field_mock.data = "".join("x" for _ in range(1000)) - - with pytest.raises( - validators.ValidationError, - match=r"The key: [x]+ has to be less than [0-9]+ characters", - ): - self._validate() - - -class TestReadOnly: - def setup_method(self): - self.form_read_only_field_mock = mock.MagicMock(data="readOnlyField") - self.form_mock = mock.MagicMock(spec_set=dict) - - def test_read_only_validator(self): - validator = validators.ReadOnly() - assert validator(self.form_mock, self.form_read_only_field_mock) is None - assert self.form_read_only_field_mock.flags.readonly is True diff --git a/tests_common/pytest_plugin.py b/tests_common/pytest_plugin.py index f4bb4da213ad0..13e9b3138af30 100644 --- a/tests_common/pytest_plugin.py +++ b/tests_common/pytest_plugin.py @@ -848,7 +848,10 @@ def __exit__(self, type, value, traceback): dag.sync_to_db(session=self.session) if dag.access_control: - from airflow.www.security_appless import ApplessAirflowSecurityManager + if AIRFLOW_V_3_0_PLUS: + from airflow.providers.fab.www.security_appless import ApplessAirflowSecurityManager + else: + from airflow.www.security_appless import ApplessAirflowSecurityManager security_manager = ApplessAirflowSecurityManager(session=self.session) security_manager.sync_perm_for_dag(dag.dag_id, dag.access_control)