Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: cache-warmup fails #31173

Merged
merged 12 commits into from
Dec 7, 2024
25 changes: 24 additions & 1 deletion superset/tasks/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from typing import Any, Optional, Union
from urllib import request
from urllib.error import URLError
from urllib.parse import urlparse

from celery.beat import SchedulingError
from celery.utils.log import get_task_logger
Expand Down Expand Up @@ -47,6 +48,23 @@
return payload


def is_secure_url(url: str) -> bool:
nsivarajan marked this conversation as resolved.
Show resolved Hide resolved
"""
Validates if a URL is secure (uses HTTPS).

:param url: The URL to validate.
:return: True if the URL uses HTTPS (secure), False if it uses HTTP (non-secure).
"""
try:
parsed_url = urlparse(url)

Check warning on line 59 in superset/tasks/cache.py

View check run for this annotation

Codecov / codecov/patch

superset/tasks/cache.py#L58-L59

Added lines #L58 - L59 were not covered by tests
# Return True for HTTPS, False for HTTP
return parsed_url.scheme == "https"
except ValueError as exception:

Check warning on line 62 in superset/tasks/cache.py

View check run for this annotation

Codecov / codecov/patch

superset/tasks/cache.py#L61-L62

Added lines #L61 - L62 were not covered by tests
# Log a warning with detailed context
logger.warning("Failed to parse URL '%s': %s", url, str(exception))
return False

Check warning on line 65 in superset/tasks/cache.py

View check run for this annotation

Codecov / codecov/patch

superset/tasks/cache.py#L64-L65

Added lines #L64 - L65 were not covered by tests


class Strategy: # pylint: disable=too-few-public-methods
"""
A cache warm up strategy.
Expand Down Expand Up @@ -220,10 +238,15 @@
"""
result = {}
try:
url = get_url_path("ChartRestApi.warm_up_cache")

if is_secure_url(url):
logger.info("URL '%s' is secure. Adding Referer header.", url)
headers.update({"Referer": url})

# Fetch CSRF token for API request
headers.update(fetch_csrf_token(headers))

url = get_url_path("ChartRestApi.warm_up_cache")
logger.info("Fetching %s with payload %s", url, data)
req = request.Request(
url, data=bytes(data, "utf-8"), headers=headers, method="PUT"
Expand Down
2 changes: 1 addition & 1 deletion superset/tasks/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def fetch_csrf_token(
data = json.loads(body)
res = {"X-CSRF-Token": data["result"]}
if session_cookie is not None:
res["Cookie"] = session_cookie
res["Cookie"] = f"{session_cookie_name}={session_cookie}"
mistercrunch marked this conversation as resolved.
Show resolved Hide resolved
return res

logger.error("Error fetching CSRF token, status code: %s", response.status)
Expand Down
46 changes: 39 additions & 7 deletions tests/integration_tests/tasks/test_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,17 +22,32 @@


@pytest.mark.parametrize(
"base_url",
"base_url, expected_referer",
[
"http://base-url",
"http://base-url/",
("http://base-url", None),
("http://base-url/", None),
("https://base-url", "https://base-url/api/v1/chart/warm_up_cache"),
("https://base-url/", "https://base-url/api/v1/chart/warm_up_cache"),
],
ids=[
"Without trailing slash (HTTP)",
"With trailing slash (HTTP)",
"Without trailing slash (HTTPS)",
"With trailing slash (HTTPS)",
],
ids=["Without trailing slash", "With trailing slash"],
)
@mock.patch("superset.tasks.cache.fetch_csrf_token")
@mock.patch("superset.tasks.cache.request.Request")
@mock.patch("superset.tasks.cache.request.urlopen")
def test_fetch_url(mock_urlopen, mock_request_cls, mock_fetch_csrf_token, base_url):
@mock.patch("superset.tasks.cache.is_secure_url")
def test_fetch_url(
mock_is_secure_url,
mock_urlopen,
mock_request_cls,
mock_fetch_csrf_token,
base_url,
expected_referer,
):
from superset.tasks.cache import fetch_url

mock_request = mock.MagicMock()
Expand All @@ -41,8 +56,17 @@ def test_fetch_url(mock_urlopen, mock_request_cls, mock_fetch_csrf_token, base_u
mock_urlopen.return_value = mock.MagicMock()
mock_urlopen.return_value.code = 200

# Mock the URL validation to return True for HTTPS and False for HTTP
mock_is_secure_url.return_value = base_url.startswith("https")

initial_headers = {"Cookie": "cookie", "key": "value"}
csrf_headers = initial_headers | {"X-CSRF-Token": "csrf_token"}

# Conditionally add the Referer header and assert its presence
if expected_referer:
csrf_headers = csrf_headers | {"Referer": expected_referer}
assert csrf_headers["Referer"] == expected_referer

mock_fetch_csrf_token.return_value = csrf_headers

app.config["WEBDRIVER_BASEURL"] = base_url
Expand All @@ -51,13 +75,21 @@ def test_fetch_url(mock_urlopen, mock_request_cls, mock_fetch_csrf_token, base_u

result = fetch_url(data, initial_headers)

assert data == result["success"]
expected_url = (
f"{base_url}/api/v1/chart/warm_up_cache"
if not base_url.endswith("/")
else f"{base_url}api/v1/chart/warm_up_cache"
)

mock_fetch_csrf_token.assert_called_once_with(initial_headers)

mock_request_cls.assert_called_once_with(
"http://base-url/api/v1/chart/warm_up_cache",
expected_url, # Use the dynamic URL based on base_url
data=data_encoded,
headers=csrf_headers,
method="PUT",
)
# assert the same Request object is used
mock_urlopen.assert_called_once_with(mock_request, timeout=mock.ANY)

assert data == result["success"]
19 changes: 16 additions & 3 deletions tests/integration_tests/tasks/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,15 @@
[
"http://base-url",
"http://base-url/",
"https://base-url",
"https://base-url/",
],
ids=[
"Without trailing slash (HTTP)",
"With trailing slash (HTTP)",
"Without trailing slash (HTTPS)",
"With trailing slash (HTTPS)",
],
ids=["Without trailing slash", "With trailing slash"],
)
@mock.patch("superset.tasks.cache.request.Request")
@mock.patch("superset.tasks.cache.request.urlopen")
Expand All @@ -52,13 +59,19 @@ def test_fetch_csrf_token(mock_urlopen, mock_request_cls, base_url, app_context)

result_headers = fetch_csrf_token(headers)

expected_url = (
f"{base_url}/api/v1/security/csrf_token/"
if not base_url.endswith("/")
else f"{base_url}api/v1/security/csrf_token/"
)

mock_request_cls.assert_called_with(
"http://base-url/api/v1/security/csrf_token/",
expected_url,
headers=headers,
method="GET",
)

assert result_headers["X-CSRF-Token"] == "csrf_token"
assert result_headers["Cookie"] == "new_session_cookie"
assert result_headers["Cookie"] == "session=new_session_cookie" # Updated assertion
# assert the same Request object is used
mock_urlopen.assert_called_once_with(mock_request, timeout=mock.ANY)
Loading