diff --git a/airflow-core/src/airflow/api_fastapi/core_api/datamodels/variables.py b/airflow-core/src/airflow/api_fastapi/core_api/datamodels/variables.py index 722be8e4684e3..ec2292f6ef32b 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/datamodels/variables.py +++ b/airflow-core/src/airflow/api_fastapi/core_api/datamodels/variables.py @@ -19,6 +19,7 @@ import json from collections.abc import Iterable +from uuid import UUID from pydantic import Field, JsonValue, model_validator @@ -35,6 +36,7 @@ class VariableResponse(BaseModel): val: str = Field(alias="value") description: str | None is_encrypted: bool + team_id: UUID | None @model_validator(mode="after") def redact_val(self) -> Self: @@ -57,6 +59,7 @@ class VariableBody(StrictBaseModel): key: str = Field(max_length=ID_LEN) value: JsonValue = Field(serialization_alias="val") description: str | None = Field(default=None) + team_id: UUID | None = Field(default=None) class VariableCollectionResponse(BaseModel): diff --git a/airflow-core/src/airflow/api_fastapi/core_api/openapi/v2-rest-api-generated.yaml b/airflow-core/src/airflow/api_fastapi/core_api/openapi/v2-rest-api-generated.yaml index 2e064fa35a6a9..0310e745712e0 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/openapi/v2-rest-api-generated.yaml +++ b/airflow-core/src/airflow/api_fastapi/core_api/openapi/v2-rest-api-generated.yaml @@ -13068,6 +13068,12 @@ components: - type: string - type: 'null' title: Description + team_id: + anyOf: + - type: string + format: uuid + - type: 'null' + title: Team Id additionalProperties: false type: object required: @@ -13107,12 +13113,19 @@ components: is_encrypted: type: boolean title: Is Encrypted + team_id: + anyOf: + - type: string + format: uuid + - type: 'null' + title: Team Id type: object required: - key - value - description - is_encrypted + - team_id title: VariableResponse description: Variable serializer for responses. VersionInfo: diff --git a/airflow-core/src/airflow/models/variable.py b/airflow-core/src/airflow/models/variable.py index 684da9ac5b84c..87d1445a06597 100644 --- a/airflow-core/src/airflow/models/variable.py +++ b/airflow-core/src/airflow/models/variable.py @@ -30,7 +30,7 @@ from sqlalchemy_utils import UUIDType from airflow._shared.secrets_masker import mask_secret -from airflow.configuration import ensure_secrets_loaded +from airflow.configuration import conf, ensure_secrets_loaded from airflow.models.base import ID_LEN, Base from airflow.models.crypto import get_fernet from airflow.models.team import Team @@ -149,7 +149,7 @@ def get( # means SQLA etc is loaded, but we can't avoid that unless/until we add import shims as a big # back-compat layer - # If this is set it means are in some kind of execution context (Task, Dag Parse or Triggerer perhaps) + # If this is set it means we are in some kind of execution context (Task, Dag Parse or Triggerer perhaps) # and should use the Task SDK API server path if hasattr(sys.modules.get("airflow.sdk.execution_time.task_runner"), "SUPERVISOR_COMMS"): warnings.warn( @@ -185,6 +185,7 @@ def set( value: Any, description: str | None = None, serialize_json: bool = False, + team_id: str | None = None, session: Session | None = None, ) -> None: """ @@ -196,13 +197,14 @@ def set( :param value: Value to set for the Variable :param description: Description of the Variable :param serialize_json: Serialize the value to a JSON string + :param team_id: ID of the team associated to the variable (if any) :param session: optional session, use if provided or create a new one """ # TODO: This is not the best way of having compat, but it's "better than erroring" for now. This still # means SQLA etc is loaded, but we can't avoid that unless/until we add import shims as a big # back-compat layer - # If this is set it means are in some kind of execution context (Task, Dag Parse or Triggerer perhaps) + # If this is set it means we are in some kind of execution context (Task, Dag Parse or Triggerer perhaps) # and should use the Task SDK API server path if hasattr(sys.modules.get("airflow.sdk.execution_time.task_runner"), "SUPERVISOR_COMMS"): warnings.warn( @@ -221,6 +223,11 @@ def set( ) return + if team_id and not conf.getboolean("core", "multi_team"): + raise ValueError( + "Multi-team mode is not configured in the Airflow environment. To assign a team to a variable, multi-mode must be enabled." + ) + # check if the secret exists in the custom secrets' backend. Variable.check_for_write_conflict(key=key) if serialize_json: @@ -235,7 +242,7 @@ def set( ctx = create_session() with ctx as session: - new_variable = Variable(key=key, val=stored_value, description=description) + new_variable = Variable(key=key, val=stored_value, description=description, team_id=team_id) val = new_variable._val is_encrypted = new_variable.is_encrypted @@ -252,6 +259,7 @@ def set( val=val, description=description, is_encrypted=is_encrypted, + team_id=team_id, ) stmt = pg_stmt.on_conflict_do_update( index_elements=["key"], @@ -259,6 +267,7 @@ def set( val=val, description=description, is_encrypted=is_encrypted, + team_id=team_id, ), ) elif dialect_name == "mysql": @@ -269,11 +278,13 @@ def set( val=val, description=description, is_encrypted=is_encrypted, + team_id=team_id, ) stmt = mysql_stmt.on_duplicate_key_update( val=val, description=description, is_encrypted=is_encrypted, + team_id=team_id, ) else: from sqlalchemy.dialects.sqlite import insert as sqlite_insert @@ -283,6 +294,7 @@ def set( val=val, description=description, is_encrypted=is_encrypted, + team_id=team_id, ) stmt = sqlite_stmt.on_conflict_do_update( index_elements=["key"], @@ -290,6 +302,7 @@ def set( val=val, description=description, is_encrypted=is_encrypted, + team_id=team_id, ), ) diff --git a/airflow-core/src/airflow/ui/openapi-gen/requests/schemas.gen.ts b/airflow-core/src/airflow/ui/openapi-gen/requests/schemas.gen.ts index b2502da763811..5142b744a77e3 100644 --- a/airflow-core/src/airflow/ui/openapi-gen/requests/schemas.gen.ts +++ b/airflow-core/src/airflow/ui/openapi-gen/requests/schemas.gen.ts @@ -6559,6 +6559,18 @@ export const $VariableBody = { } ], title: 'Description' + }, + team_id: { + anyOf: [ + { + type: 'string', + format: 'uuid' + }, + { + type: 'null' + } + ], + title: 'Team Id' } }, additionalProperties: false, @@ -6612,10 +6624,22 @@ export const $VariableResponse = { is_encrypted: { type: 'boolean', title: 'Is Encrypted' + }, + team_id: { + anyOf: [ + { + type: 'string', + format: 'uuid' + }, + { + type: 'null' + } + ], + title: 'Team Id' } }, type: 'object', - required: ['key', 'value', 'description', 'is_encrypted'], + required: ['key', 'value', 'description', 'is_encrypted', 'team_id'], title: 'VariableResponse', description: 'Variable serializer for responses.' } as const; diff --git a/airflow-core/src/airflow/ui/openapi-gen/requests/types.gen.ts b/airflow-core/src/airflow/ui/openapi-gen/requests/types.gen.ts index 57c5d50b3db5c..5eca441aa8ef3 100644 --- a/airflow-core/src/airflow/ui/openapi-gen/requests/types.gen.ts +++ b/airflow-core/src/airflow/ui/openapi-gen/requests/types.gen.ts @@ -1594,6 +1594,7 @@ export type VariableBody = { key: string; value: JsonValue; description?: string | null; + team_id?: string | null; }; /** @@ -1612,6 +1613,7 @@ export type VariableResponse = { value: string; description: string | null; is_encrypted: boolean; + team_id: string | null; }; /** diff --git a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_variables.py b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_variables.py index 6e6282508a9c4..b235db785b39d 100644 --- a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_variables.py +++ b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_variables.py @@ -19,14 +19,17 @@ import json from io import BytesIO from unittest import mock +from unittest.mock import ANY import pytest +from airflow.models.team import Team from airflow.models.variable import Variable from airflow.utils.session import provide_session from tests_common.test_utils.asserts import assert_queries_count -from tests_common.test_utils.db import clear_db_variables +from tests_common.test_utils.config import conf_vars +from tests_common.test_utils.db import clear_db_teams, clear_db_variables from tests_common.test_utils.logs import check_last_log pytestmark = pytest.mark.db_test @@ -62,6 +65,8 @@ def create_file_upload(content: dict) -> BytesIO: @provide_session def _create_variables(session) -> None: + team = session.query(Team).where(Team.name == "test").one() + Variable.set( key=TEST_VARIABLE_KEY, value=TEST_VARIABLE_VALUE, @@ -87,6 +92,7 @@ def _create_variables(session) -> None: key=TEST_VARIABLE_KEY4, value=TEST_VARIABLE_VALUE4, description=TEST_VARIABLE_DESCRIPTION4, + team_id=team.id, session=session, ) @@ -98,15 +104,31 @@ def _create_variables(session) -> None: ) +@provide_session +def _create_team(session) -> None: + session.add(Team(name="test")) + session.commit() + + +@pytest.fixture(scope="session") +def team_id(session): + return str(session.query(Team.id).filter_by(name="test").one()[0]) + + class TestVariableEndpoint: @pytest.fixture(autouse=True) - def setup(self) -> None: + def setup(self): clear_db_variables() + clear_db_teams() + with conf_vars({("core", "multi_team"): "True"}): + yield - def teardown_method(self) -> None: + def teardown_method(self): clear_db_variables() + clear_db_teams() def create_variables(self): + _create_team() _create_variables() @@ -150,6 +172,7 @@ class TestGetVariable(TestVariableEndpoint): "value": TEST_VARIABLE_VALUE, "description": TEST_VARIABLE_DESCRIPTION, "is_encrypted": True, + "team_id": None, }, ), ( @@ -159,6 +182,7 @@ class TestGetVariable(TestVariableEndpoint): "value": "***", "description": TEST_VARIABLE_DESCRIPTION2, "is_encrypted": True, + "team_id": None, }, ), ( @@ -168,6 +192,7 @@ class TestGetVariable(TestVariableEndpoint): "value": '{"password": "***"}', "description": TEST_VARIABLE_DESCRIPTION3, "is_encrypted": True, + "team_id": None, }, ), ( @@ -177,6 +202,7 @@ class TestGetVariable(TestVariableEndpoint): "value": TEST_VARIABLE_VALUE4, "description": TEST_VARIABLE_DESCRIPTION4, "is_encrypted": True, + "team_id": ANY, }, ), ( @@ -186,6 +212,7 @@ class TestGetVariable(TestVariableEndpoint): "value": TEST_VARIABLE_SEARCH_VALUE, "description": TEST_VARIABLE_SEARCH_DESCRIPTION, "is_encrypted": True, + "team_id": None, }, ), ], @@ -343,6 +370,7 @@ class TestPatchVariable(TestVariableEndpoint): "value": "The new value", "description": "The new description", "is_encrypted": True, + "team_id": None, }, ), ( @@ -351,6 +379,7 @@ class TestPatchVariable(TestVariableEndpoint): "key": TEST_VARIABLE_KEY, "value": "The new value", "description": "The new description", + "team_id": None, }, {"update_mask": ["value"]}, { @@ -358,6 +387,7 @@ class TestPatchVariable(TestVariableEndpoint): "value": "The new value", "description": TEST_VARIABLE_DESCRIPTION, "is_encrypted": True, + "team_id": None, }, ), ( @@ -373,6 +403,7 @@ class TestPatchVariable(TestVariableEndpoint): "value": "The new value", "description": TEST_VARIABLE_DESCRIPTION4, "is_encrypted": True, + "team_id": ANY, }, ), ( @@ -388,6 +419,7 @@ class TestPatchVariable(TestVariableEndpoint): "value": "***", "description": TEST_VARIABLE_DESCRIPTION2, "is_encrypted": True, + "team_id": None, }, ), ( @@ -403,6 +435,7 @@ class TestPatchVariable(TestVariableEndpoint): "value": '{"password": "***"}', "description": "new description", "is_encrypted": True, + "team_id": None, }, ), ], @@ -414,6 +447,25 @@ def test_patch_should_respond_200(self, test_client, session, key, body, params, assert response.json() == expected_response check_last_log(session, dag_id=None, event="patch_variable", logical_date=None) + def test_patch_with_team_should_respond_200(self, test_client, session, testing_team): + self.create_variables() + body = { + "key": TEST_VARIABLE_KEY, + "value": "The new value", + "description": "The new description", + "team_id": str(testing_team.id), + } + response = test_client.patch(f"/variables/{TEST_VARIABLE_KEY}", json=body) + assert response.status_code == 200 + assert response.json() == { + "key": TEST_VARIABLE_KEY, + "value": "The new value", + "description": "The new description", + "is_encrypted": True, + "team_id": str(testing_team.id), + } + check_last_log(session, dag_id=None, event="patch_variable", logical_date=None) + def test_patch_should_respond_400(self, test_client): response = test_client.patch( f"/variables/{TEST_VARIABLE_KEY}", @@ -463,6 +515,7 @@ class TestPostVariable(TestVariableEndpoint): "value": "new variable value", "description": "new variable description", "is_encrypted": True, + "team_id": None, }, ), ( @@ -476,6 +529,7 @@ class TestPostVariable(TestVariableEndpoint): "value": "***", "description": "another password", "is_encrypted": True, + "team_id": None, }, ), ( @@ -489,6 +543,7 @@ class TestPostVariable(TestVariableEndpoint): "value": '{"password": "***"}', "description": "some description", "is_encrypted": True, + "team_id": None, }, ), ( @@ -502,6 +557,7 @@ class TestPostVariable(TestVariableEndpoint): "value": "", "description": "some description", "is_encrypted": True, + "team_id": None, }, ), ], @@ -513,6 +569,25 @@ def test_post_should_respond_201(self, test_client, session, body, expected_resp assert response.json() == expected_response check_last_log(session, dag_id=None, event="post_variable", logical_date=None) + def test_post_with_team_should_respond_201(self, test_client, testing_team, session): + self.create_variables() + body = { + "key": "new variable key", + "value": "new variable value", + "description": "new variable description", + "team_id": str(testing_team.id), + } + response = test_client.post("/variables", json=body) + assert response.status_code == 201 + assert response.json() == { + "key": "new variable key", + "value": "new variable value", + "description": "new variable description", + "is_encrypted": True, + "team_id": str(testing_team.id), + } + check_last_log(session, dag_id=None, event="post_variable", logical_date=None) + def test_post_should_respond_401(self, unauthenticated_test_client): response = unauthenticated_test_client.post( "/variables", diff --git a/airflow-core/tests/unit/models/test_variable.py b/airflow-core/tests/unit/models/test_variable.py index b02a760a665dd..722b6a155a346 100644 --- a/airflow-core/tests/unit/models/test_variable.py +++ b/airflow-core/tests/unit/models/test_variable.py @@ -198,6 +198,17 @@ def test_set_variable_sets_description(self, session): assert test_var.description == "a test variable" assert test_var.val == "value" + @conf_vars({("core", "multi_team"): "True"}) + def test_set_variable_sets_team(self, testing_team, session): + Variable.set(key="key", value="value", team_id=testing_team.id, session=session) + test_var = session.query(Variable).filter(Variable.key == "key").one() + assert test_var.team_id == testing_team.id + assert test_var.val == "value" + + def test_set_variable_sets_team_multi_team_off(self, testing_team, session): + with pytest.raises(ValueError, match=r"Multi-team mode is not configured in the Airflow environment"): + Variable.set(key="key", value="value", team_id=testing_team.id, session=session) + def test_variable_set_existing_value_to_blank(self, session): test_value = "Some value" test_key = "test_key" diff --git a/airflow-ctl/src/airflowctl/api/datamodels/generated.py b/airflow-ctl/src/airflowctl/api/datamodels/generated.py index 5621a9950cf3e..1cb5ac77dde1f 100644 --- a/airflow-ctl/src/airflowctl/api/datamodels/generated.py +++ b/airflow-ctl/src/airflowctl/api/datamodels/generated.py @@ -925,6 +925,7 @@ class VariableBody(BaseModel): key: Annotated[str, Field(max_length=250, title="Key")] value: JsonValue description: Annotated[str | None, Field(title="Description")] = None + team_id: Annotated[UUID | None, Field(title="Team Id")] = None class VariableResponse(BaseModel): @@ -936,6 +937,7 @@ class VariableResponse(BaseModel): value: Annotated[str, Field(title="Value")] description: Annotated[str | None, Field(title="Description")] = None is_encrypted: Annotated[bool, Field(title="Is Encrypted")] + team_id: Annotated[UUID | None, Field(title="Team Id")] = None class VersionInfo(BaseModel): diff --git a/task-sdk/tests/task_sdk/api/test_client.py b/task-sdk/tests/task_sdk/api/test_client.py index 0355c3abc876b..be75d6ee1cd4f 100644 --- a/task-sdk/tests/task_sdk/api/test_client.py +++ b/task-sdk/tests/task_sdk/api/test_client.py @@ -264,7 +264,7 @@ def test_server_response_error_invalid_status_codes(self, status_code, descripti class TestTaskInstanceOperations: """ - Test that the TestVariableOperations class works as expected. While the operations are simple, it + Test that the TestTaskInstanceOperations class works as expected. While the operations are simple, it still catches the basic functionality of the client for task instances including endpoint and response parsing. """