Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -55,14 +55,17 @@ async def has_variable_access(


@router.get(
"/{variable_key}",
"/{variable_key:path}",
responses={
status.HTTP_401_UNAUTHORIZED: {"description": "Unauthorized"},
status.HTTP_403_FORBIDDEN: {"description": "Task does not have access to the variable"},
},
)
def get_variable(variable_key: str) -> VariableResponse:
"""Get an Airflow Variable."""
if not variable_key:
raise HTTPException(status.HTTP_404_NOT_FOUND, detail="Not Found")

try:
variable_value = Variable.get(variable_key)
except KeyError:
Expand All @@ -78,7 +81,7 @@ def get_variable(variable_key: str) -> VariableResponse:


@router.put(
"/{variable_key}",
"/{variable_key:path}",
status_code=status.HTTP_201_CREATED,
responses={
status.HTTP_401_UNAUTHORIZED: {"description": "Unauthorized"},
Expand All @@ -87,12 +90,15 @@ def get_variable(variable_key: str) -> VariableResponse:
)
def put_variable(variable_key: str, body: VariablePostBody):
"""Set an Airflow Variable."""
if not variable_key:
raise HTTPException(status.HTTP_404_NOT_FOUND, detail="Not Found")

Variable.set(key=variable_key, value=body.value, description=body.description)
return {"message": "Variable successfully set"}


@router.delete(
"/{variable_key}",
"/{variable_key:path}",
status_code=status.HTTP_204_NO_CONTENT,
responses={
status.HTTP_401_UNAUTHORIZED: {"description": "Unauthorized"},
Expand All @@ -101,4 +107,7 @@ def put_variable(variable_key: str, body: VariablePostBody):
)
def delete_variable(variable_key: str):
"""Delete an Airflow Variable."""
if not variable_key:
raise HTTPException(status.HTTP_404_NOT_FOUND, detail="Not Found")

Variable.delete(key=variable_key)
Original file line number Diff line number Diff line change
Expand Up @@ -65,17 +65,24 @@ async def _(request: Request, variable_key: str, token=JWTBearerDep):


class TestGetVariable:
def test_variable_get_from_db(self, client, session):
Variable.set(key="var1", value="value", session=session)
@pytest.mark.parametrize(
"key, value",
[
("var1", "value"),
("var2/with_slash", "slash_value"),
],
)
def test_variable_get_from_db(self, client, session, key, value):
Variable.set(key=key, value=value, session=session)
session.commit()

response = client.get("/execution/variables/var1")
response = client.get(f"/execution/variables/{key}")

assert response.status_code == 200
assert response.json() == {"key": "var1", "value": "value"}
assert response.json() == {"key": key, "value": value}

# Remove connection
Variable.delete(key="var1", session=session)
Variable.delete(key=key, session=session)
session.commit()

@mock.patch.dict(
Expand All @@ -88,13 +95,20 @@ def test_variable_get_from_env_var(self, client, session):
assert response.status_code == 200
assert response.json() == {"key": "key1", "value": "VALUE"}

def test_variable_get_not_found(self, client):
response = client.get("/execution/variables/non_existent_var")
@pytest.mark.parametrize(
"key",
[
"non_existent_var",
"non/existent/slash/var",
],
)
def test_variable_get_not_found(self, client, key):
response = client.get(f"/execution/variables/{key}")

assert response.status_code == 404
assert response.json() == {
"detail": {
"message": "Variable with key 'non_existent_var' not found",
"message": f"Variable with key '{key}' not found",
"reason": "not_found",
}
}
Expand All @@ -117,22 +131,26 @@ def test_variable_get_access_denied(self, client, caplog):

class TestPutVariable:
@pytest.mark.parametrize(
"payload",
"key, payload",
[
pytest.param({"value": "{}", "description": "description"}, id="valid-payload"),
pytest.param({"value": "{}"}, id="missing-description"),
pytest.param("var_create", {"value": "{}", "description": "description"}, id="valid-payload"),
pytest.param("var_create", {"value": "{}"}, id="missing-description"),
pytest.param(
"var_create/with_slash",
{"value": "slash_value", "description": "Variable with slash"},
id="slash-key",
),
],
)
def test_should_create_variable(self, client, payload, session):
key = "var_create"
def test_should_create_variable(self, client, key, payload, session):
response = client.put(
f"/execution/variables/{key}",
json=payload,
)
assert response.status_code == 201, response.json()
assert response.json()["message"] == "Variable successfully set"

var_from_db = session.query(Variable).where(Variable.key == "var_create").first()
var_from_db = session.query(Variable).where(Variable.key == key).first()
assert var_from_db is not None
assert var_from_db.key == key
assert var_from_db.val == payload["value"]
Expand Down Expand Up @@ -179,8 +197,14 @@ def test_variable_adding_extra_fields(self, client, key, payload, session):
assert response.json()["detail"][0]["type"] == "extra_forbidden"
assert response.json()["detail"][0]["msg"] == "Extra inputs are not permitted"

def test_overwriting_existing_variable(self, client, session):
key = "var_create"
@pytest.mark.parametrize(
"key",
[
"var_create",
"var_create/with_slash",
],
)
def test_overwriting_existing_variable(self, client, session, key):
Variable.set(key=key, value="value", session=session)
session.commit()

Expand Down Expand Up @@ -218,19 +242,26 @@ def test_post_variable_access_denied(self, client, caplog):


class TestDeleteVariable:
def test_should_delete_variable(self, client, session):
for i in range(1, 3):
Variable.set(key=f"key{i}", value=i)
@pytest.mark.parametrize(
"keys_to_create, key_to_delete",
[
(["key1", "key2"], "key1"),
(["key3/with_slash", "key4"], "key3/with_slash"),
],
)
def test_should_delete_variable(self, client, session, keys_to_create, key_to_delete):
for i, key in enumerate(keys_to_create, 1):
Variable.set(key=key, value=str(i))

vars = session.query(Variable).all()
assert len(vars) == 2
assert len(vars) == len(keys_to_create)

response = client.delete("/execution/variables/key1")
response = client.delete(f"/execution/variables/{key_to_delete}")

assert response.status_code == 204

vars = session.query(Variable).all()
assert len(vars) == 1
assert len(vars) == len(keys_to_create) - 1

def test_should_not_delete_variable(self, client, session):
Variable.set(key="key", value="value")
Expand Down