Skip to content

Commit

Permalink
feat(embedded+async queries): support async queries to work with embe…
Browse files Browse the repository at this point in the history
…dded guest user (apache#26332)
  • Loading branch information
Zef Lin authored Jan 9, 2024
1 parent 4c2e818 commit efdeb9d
Show file tree
Hide file tree
Showing 5 changed files with 134 additions and 19 deletions.
26 changes: 23 additions & 3 deletions superset/async_events/async_query_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,20 +191,40 @@ def submit_explore_json_job(
force: Optional[bool] = False,
user_id: Optional[int] = None,
) -> dict[str, Any]:
# pylint: disable=import-outside-toplevel
from superset import security_manager

job_metadata = self.init_job(channel_id, user_id)
self._load_explore_json_into_cache_job.delay(
job_metadata,
{**job_metadata, "guest_token": guest_user.guest_token}
if (guest_user := security_manager.get_current_guest_user_if_guest())
else job_metadata,
form_data,
response_type,
force,
)
return job_metadata

def submit_chart_data_job(
self, channel_id: str, form_data: dict[str, Any], user_id: Optional[int]
self,
channel_id: str,
form_data: dict[str, Any],
user_id: Optional[int] = None,
) -> dict[str, Any]:
# pylint: disable=import-outside-toplevel
from superset import security_manager

# if it's guest user, we want to pass the guest token to the celery task
# chart data cache key is calculated based on the current user
# this way we can keep the cache key consistent between sync and async command
# so that it can be looked up consistently
job_metadata = self.init_job(channel_id, user_id)
self._load_chart_data_into_cache_job.delay(job_metadata, form_data)
self._load_chart_data_into_cache_job.delay(
{**job_metadata, "guest_token": guest_user.guest_token}
if (guest_user := security_manager.get_current_guest_user_if_guest())
else job_metadata,
form_data,
)
return job_metadata

def read_events(
Expand Down
10 changes: 9 additions & 1 deletion superset/common/query_context_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -600,7 +600,15 @@ def get_payload(
set_and_log_cache(
cache_manager.cache,
cache_key,
{"data": self._query_context.cache_values},
{
"data": {
# setting form_data into query context cache value as well
# so that it can be used to reconstruct form_data field
# for query context object when reading from cache
"form_data": self._query_context.form_data,
**self._query_context.cache_values,
},
},
self.get_cache_timeout(),
)
return_value["cache_key"] = cache_key # type: ignore
Expand Down
37 changes: 24 additions & 13 deletions superset/tasks/async_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

from celery.exceptions import SoftTimeLimitExceeded
from flask import current_app, g
from flask_appbuilder.security.sqla.models import User
from marshmallow import ValidationError

from superset.charts.schemas import ChartDataQueryContextSchema
Expand Down Expand Up @@ -58,6 +59,20 @@ def _create_query_context_from_form(form_data: dict[str, Any]) -> QueryContext:
raise error


def _load_user_from_job_metadata(job_metadata: dict[str, Any]) -> User:
if user_id := job_metadata.get("user_id"):
# logged in user
user = security_manager.get_user_by_id(user_id)
elif guest_token := job_metadata.get("guest_token"):
# embedded guest user
user = security_manager.get_guest_user_from_token(guest_token)
del job_metadata["guest_token"]
else:
# default to anonymous user if no user is found
user = security_manager.get_anonymous_user()
return user


@celery_app.task(name="load_chart_data_into_cache", soft_time_limit=query_timeout)
def load_chart_data_into_cache(
job_metadata: dict[str, Any],
Expand All @@ -66,12 +81,7 @@ def load_chart_data_into_cache(
# pylint: disable=import-outside-toplevel
from superset.commands.chart.data.get_data_command import ChartDataCommand

user = (
security_manager.get_user_by_id(job_metadata.get("user_id"))
or security_manager.get_anonymous_user()
)

with override_user(user, force=False):
with override_user(_load_user_from_job_metadata(job_metadata), force=False):
try:
set_form_data(form_data)
query_context = _create_query_context_from_form(form_data)
Expand Down Expand Up @@ -106,12 +116,7 @@ def load_explore_json_into_cache( # pylint: disable=too-many-locals
) -> None:
cache_key_prefix = "ejr-" # ejr: explore_json request

user = (
security_manager.get_user_by_id(job_metadata.get("user_id"))
or security_manager.get_anonymous_user()
)

with override_user(user, force=False):
with override_user(_load_user_from_job_metadata(job_metadata), force=False):
try:
set_form_data(form_data)
datasource_id, datasource_type = get_datasource_info(None, None, form_data)
Expand Down Expand Up @@ -140,7 +145,13 @@ def load_explore_json_into_cache( # pylint: disable=too-many-locals
"response_type": response_type,
}
cache_key = generate_cache_key(cache_value, cache_key_prefix)
set_and_log_cache(cache_manager.cache, cache_key, cache_value)
cache_instance = cache_manager.cache
cache_timeout = (
cache_instance.cache.default_timeout if cache_instance.cache else None
)
set_and_log_cache(
cache_instance, cache_key, cache_value, cache_timeout=cache_timeout
)
result_url = f"/superset/explore_json/data/{cache_key}"
async_query_manager.update_job(
job_metadata,
Expand Down
1 change: 1 addition & 0 deletions tests/integration_tests/query_context_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ def test_cache(self):

cached = cache_manager.cache.get(cache_key)
assert cached is not None
assert "form_data" in cached["data"]

rehydrated_qc = ChartDataQueryContextSchema().load(cached["data"])
rehydrated_qo = rehydrated_qc.queries[0]
Expand Down
79 changes: 77 additions & 2 deletions tests/unit_tests/async_events/async_query_manager_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,14 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from unittest import mock
from unittest.mock import ANY, Mock

from unittest.mock import Mock

from flask import g
from jwt import encode
from pytest import fixture, raises

from superset import security_manager
from superset.async_events.async_query_manager import (
AsyncQueryManager,
AsyncQueryTokenException,
Expand All @@ -38,6 +40,12 @@ def async_query_manager():
return query_manager


def set_current_as_guest_user():
g.user = security_manager.get_guest_user_from_token(
{"user": {}, "resources": [{"type": "dashboard", "id": "some-uuid"}]}
)


def test_parse_channel_id_from_request(async_query_manager):
encoded_token = encode(
{"channel": "test_channel_id"}, JWT_TOKEN_SECRET, algorithm="HS256"
Expand Down Expand Up @@ -65,3 +73,70 @@ def test_parse_channel_id_from_request_bad_jwt(async_query_manager):

with raises(AsyncQueryTokenException):
async_query_manager.parse_channel_id_from_request(request)


@mock.patch("superset.is_feature_enabled")
def test_submit_chart_data_job_as_guest_user(
is_feature_enabled_mock, async_query_manager
):
is_feature_enabled_mock.return_value = True
set_current_as_guest_user()
job_mock = Mock()
async_query_manager._load_chart_data_into_cache_job = job_mock
job_meta = async_query_manager.submit_chart_data_job(
channel_id="test_channel_id",
form_data={},
)

job_mock.delay.assert_called_once_with(
{
"channel_id": "test_channel_id",
"errors": [],
"guest_token": {
"resources": [{"id": "some-uuid", "type": "dashboard"}],
"user": {},
},
"job_id": ANY,
"result_url": None,
"status": "pending",
"user_id": None,
},
{},
)

assert "guest_token" not in job_meta


@mock.patch("superset.is_feature_enabled")
def test_submit_explore_json_job_as_guest_user(
is_feature_enabled_mock, async_query_manager
):
is_feature_enabled_mock.return_value = True
set_current_as_guest_user()
job_mock = Mock()
async_query_manager._load_explore_json_into_cache_job = job_mock
job_meta = async_query_manager.submit_explore_json_job(
channel_id="test_channel_id",
form_data={},
response_type="json",
)

job_mock.delay.assert_called_once_with(
{
"channel_id": "test_channel_id",
"errors": [],
"guest_token": {
"resources": [{"id": "some-uuid", "type": "dashboard"}],
"user": {},
},
"job_id": ANY,
"result_url": None,
"status": "pending",
"user_id": None,
},
{},
"json",
False,
)

assert "guest_token" not in job_meta

0 comments on commit efdeb9d

Please sign in to comment.