fix(embedded+async queries): support async queries to work with embedded guest user #26332

Merged
26 changes: 23 additions & 3 deletions superset/async_events/async_query_manager.py
@@ -191,20 +191,40 @@ def submit_explore_json_job(
force: Optional[bool] = False,
user_id: Optional[int] = None,
) -> dict[str, Any]:
# pylint: disable=import-outside-toplevel
from superset import security_manager

job_metadata = self.init_job(channel_id, user_id)
self._load_explore_json_into_cache_job.delay(
job_metadata,
{**job_metadata, "guest_token": guest_user.guest_token}
if (guest_user := security_manager.get_current_guest_user_if_guest())
else job_metadata,
Comment on lines +199 to +201 (Member):

As we have the same logic in two places, could we DRY this out into a simple local util _get_job_metadata(job_metadata) to avoid divergence down the line?
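
A rough sketch of what that helper could look like, based on the logic already present in this diff (the name _get_job_metadata is the suggestion above, not code from this PR):

def _get_job_metadata(job_metadata: dict[str, Any]) -> dict[str, Any]:
    # pylint: disable=import-outside-toplevel
    from superset import security_manager

    # Attach the guest token only when the current user is an embedded guest,
    # so both submit_* methods share a single enrichment path.
    if guest_user := security_manager.get_current_guest_user_if_guest():
        return {**job_metadata, "guest_token": guest_user.guest_token}
    return job_metadata
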

form_data,
response_type,
force,
)
return job_metadata

def submit_chart_data_job(
self, channel_id: str, form_data: dict[str, Any], user_id: Optional[int]
self,
channel_id: str,
form_data: dict[str, Any],
user_id: Optional[int] = None,
) -> dict[str, Any]:
# pylint: disable=import-outside-toplevel
from superset import security_manager

# if it's guest user, we want to pass the guest token to the celery task
# chart data cache key is calculated based on the current user
# this way we can keep the cache key consistent between sync and async command
# so that it can be looked up consistently
job_metadata = self.init_job(channel_id, user_id)
self._load_chart_data_into_cache_job.delay(job_metadata, form_data)
self._load_chart_data_into_cache_job.delay(
{**job_metadata, "guest_token": guest_user.guest_token}
if (guest_user := security_manager.get_current_guest_user_if_guest())
else job_metadata,
form_data,
)
return job_metadata

def read_events(
10 changes: 9 additions & 1 deletion superset/common/query_context_processor.py
@@ -600,7 +600,15 @@ def get_payload(
set_and_log_cache(
cache_manager.cache,
cache_key,
{"data": self._query_context.cache_values},
{
"data": {
# setting form_data into query context cache value as well
# so that it can be used to reconstruct form_data field
# for query context object when reading from cache
"form_data": self._query_context.form_data,
Comment on lines +605 to +608 (Member):

I have concerns here... We've discussed this at length previously, and I feel we should always consider form_data to be chart-specific internal state. IIRC, the original reason that form data is being passed along to the celery task is that the legacy charts require it to function. However, the v1 charts shouldn't need it. I know many fields within form_data are sort of standardized, but the whole point of QueryContext/QueryObject was to standardize the chart data request contract, thus decoupling the viz plugins from the backend. Therefore, I feel any field that is required from form_data should ideally be allocated a dedicated property on either the query context or the query object.

**self._query_context.cache_values,

Member:

@villebro do you have any thoughts on this?

Member:

@zephyring from where does raise_for_access get called when the query context is passed in from cache?

Author:

The first call to the /chart/data endpoint sets the query context into the cache and returns the async job metadata.
Once the job is done, the frontend receives a job-ready event with a result_url, something like /api/v1/chart/data/qc-2772209e0a4b1ca513375118855b27af.
It then uses that URL to call the chart/data/<cache_key> endpoint, which invokes ChartDataCommand; the command's validation calls raise_for_access.
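
In other words, the request/response flow is roughly as follows (a sketch; the client object and the wait_for_job_done_event helper are hypothetical, the endpoints are the ones named above):

# 1. Submit the chart data request; with async queries enabled the response
#    only carries the job metadata, and the query context is written to cache.
job = client.post("/api/v1/chart/data/", json=query_context_payload).json()

# 2. Wait for the "job done" event on the async events channel; it carries a
#    result_url such as /api/v1/chart/data/qc-<cache-key>.
result_url = wait_for_job_done_event(job)  # hypothetical helper

# 3. Fetch the cached result. This hits the chart data endpoint with the cache
#    key, which runs ChartDataCommand; its validation calls raise_for_access.
payload = client.get(result_url).json()
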

},
},
self.get_cache_timeout(),
)
return_value["cache_key"] = cache_key # type: ignore
37 changes: 24 additions & 13 deletions superset/tasks/async_queries.py
@@ -22,6 +22,7 @@

from celery.exceptions import SoftTimeLimitExceeded
from flask import current_app, g
from flask_appbuilder.security.sqla.models import User
from marshmallow import ValidationError

from superset.charts.schemas import ChartDataQueryContextSchema
@@ -58,6 +59,20 @@
raise error


def _load_user_from_job_metadata(job_metadata: dict[str, Any]) -> User:
if user_id := job_metadata.get("user_id"):
# logged in user
user = security_manager.get_user_by_id(user_id)
elif guest_token := job_metadata.get("guest_token"):

Codecov / codecov/patch: added line #L66 was not covered by tests.
# embedded guest user
user = security_manager.get_guest_user_from_token(guest_token)
del job_metadata["guest_token"]

Codecov / codecov/patch: added lines #L68-L69 were not covered by tests.
else:
# default to anonymous user if no user is found
user = security_manager.get_anonymous_user()

Codecov / codecov/patch: added line #L72 was not covered by tests.
return user


@celery_app.task(name="load_chart_data_into_cache", soft_time_limit=query_timeout)
def load_chart_data_into_cache(
job_metadata: dict[str, Any],
@@ -66,12 +81,7 @@
# pylint: disable=import-outside-toplevel
from superset.commands.chart.data.get_data_command import ChartDataCommand

user = (
security_manager.get_user_by_id(job_metadata.get("user_id"))
or security_manager.get_anonymous_user()
)

with override_user(user, force=False):
with override_user(_load_user_from_job_metadata(job_metadata), force=False):
try:
set_form_data(form_data)
query_context = _create_query_context_from_form(form_data)
@@ -106,12 +116,7 @@
) -> None:
cache_key_prefix = "ejr-" # ejr: explore_json request

user = (
security_manager.get_user_by_id(job_metadata.get("user_id"))
Author:

if user_id is None, get_user_by_id() will crash the celery task
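
To illustrate, this is how the new _load_user_from_job_metadata helper resolves the three possible job_metadata shapes (illustrative values only; the guest token mirrors the one used in the unit tests below):

# logged-in user: user_id is present, so get_user_by_id is safe to call
_load_user_from_job_metadata({"user_id": 1})

# embedded guest: no user_id, but the submit_* methods attached a guest_token
_load_user_from_job_metadata(
    {
        "user_id": None,
        "guest_token": {
            "user": {},
            "resources": [{"type": "dashboard", "id": "some-uuid"}],
        },
    }
)

# neither present: fall back to the anonymous user instead of calling
# get_user_by_id(None), which is what previously crashed the celery task
_load_user_from_job_metadata({"user_id": None})
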

or security_manager.get_anonymous_user()
)

with override_user(user, force=False):
with override_user(_load_user_from_job_metadata(job_metadata), force=False):
try:
set_form_data(form_data)
datasource_id, datasource_type = get_datasource_info(None, None, form_data)
@@ -140,7 +145,13 @@
"response_type": response_type,
}
cache_key = generate_cache_key(cache_value, cache_key_prefix)
set_and_log_cache(cache_manager.cache, cache_key, cache_value)
cache_instance = cache_manager.cache
cache_timeout = (
cache_instance.cache.default_timeout if cache_instance.cache else None
Author:

Using the cache_instance default_timeout if found; otherwise it will fall back to config["CACHE_DEFAULT_TIMEOUT"].
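
Spelled out, that resolution order looks roughly like this (a sketch; the helper name is hypothetical, and the CACHE_DEFAULT_TIMEOUT fallback happens inside set_and_log_cache, per the note above):

from typing import Any, Optional

def resolve_explore_json_cache_timeout(cache_instance: Any) -> Optional[int]:
    # Prefer the default_timeout of the underlying cache backend when one is
    # configured; returning None lets set_and_log_cache fall back to
    # config["CACHE_DEFAULT_TIMEOUT"].
    backend = getattr(cache_instance, "cache", None)
    return backend.default_timeout if backend else None
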

)
set_and_log_cache(
cache_instance, cache_key, cache_value, cache_timeout=cache_timeout
)
result_url = f"/superset/explore_json/data/{cache_key}"
async_query_manager.update_job(
job_metadata,
1 change: 1 addition & 0 deletions tests/integration_tests/query_context_tests.py
@@ -121,6 +121,7 @@ def test_cache(self):

cached = cache_manager.cache.get(cache_key)
assert cached is not None
assert "form_data" in cached["data"]

rehydrated_qc = ChartDataQueryContextSchema().load(cached["data"])
rehydrated_qo = rehydrated_qc.queries[0]
79 changes: 77 additions & 2 deletions tests/unit_tests/async_events/async_query_manager_tests.py
@@ -14,12 +14,14 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from unittest import mock
from unittest.mock import ANY, Mock

from unittest.mock import Mock

from flask import g
from jwt import encode
from pytest import fixture, raises

from superset import security_manager
from superset.async_events.async_query_manager import (
AsyncQueryManager,
AsyncQueryTokenException,
@@ -38,6 +40,12 @@ def async_query_manager():
return query_manager


def set_current_as_guest_user():
g.user = security_manager.get_guest_user_from_token(
{"user": {}, "resources": [{"type": "dashboard", "id": "some-uuid"}]}
)


def test_parse_channel_id_from_request(async_query_manager):
encoded_token = encode(
{"channel": "test_channel_id"}, JWT_TOKEN_SECRET, algorithm="HS256"
@@ -65,3 +73,70 @@ def test_parse_channel_id_from_request_bad_jwt(async_query_manager):

with raises(AsyncQueryTokenException):
async_query_manager.parse_channel_id_from_request(request)


@mock.patch("superset.is_feature_enabled")
def test_submit_chart_data_job_as_guest_user(
is_feature_enabled_mock, async_query_manager
):
is_feature_enabled_mock.return_value = True
set_current_as_guest_user()
job_mock = Mock()
async_query_manager._load_chart_data_into_cache_job = job_mock
job_meta = async_query_manager.submit_chart_data_job(
channel_id="test_channel_id",
form_data={},
)

job_mock.delay.assert_called_once_with(
{
"channel_id": "test_channel_id",
"errors": [],
"guest_token": {
"resources": [{"id": "some-uuid", "type": "dashboard"}],
"user": {},
},
"job_id": ANY,
"result_url": None,
"status": "pending",
"user_id": None,
},
{},
)

assert "guest_token" not in job_meta


@mock.patch("superset.is_feature_enabled")
def test_submit_explore_json_job_as_guest_user(
is_feature_enabled_mock, async_query_manager
):
is_feature_enabled_mock.return_value = True
set_current_as_guest_user()
job_mock = Mock()
async_query_manager._load_explore_json_into_cache_job = job_mock
job_meta = async_query_manager.submit_explore_json_job(
channel_id="test_channel_id",
form_data={},
response_type="json",
)

job_mock.delay.assert_called_once_with(
{
"channel_id": "test_channel_id",
"errors": [],
"guest_token": {
"resources": [{"id": "some-uuid", "type": "dashboard"}],
"user": {},
},
"job_id": ANY,
"result_url": None,
"status": "pending",
"user_id": None,
},
{},
"json",
False,
)

assert "guest_token" not in job_meta