Skip to content

Commit

Permalink
feat: Jinja2 macro for querying datasets
Browse files Browse the repository at this point in the history
  • Loading branch information
betodealmeida committed May 17, 2022
1 parent 6244728 commit 570e7c0
Show file tree
Hide file tree
Showing 3 changed files with 135 additions and 2 deletions.
2 changes: 1 addition & 1 deletion superset/connectors/base/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,7 +449,7 @@ def external_metadata(self) -> List[Dict[str, str]]:
def get_query_str(self, query_obj: QueryObjectDict) -> str:
"""Returns a query as a string
This is used to be displayed to the user so that she/he can
This is used to be displayed to the user so that they can
understand what is taking place behind the scene"""
raise NotImplementedError()

Expand Down
34 changes: 34 additions & 0 deletions superset/jinja_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from sqlalchemy.types import String
from typing_extensions import TypedDict

from superset.datasets.commands.exceptions import DatasetNotFoundError
from superset.exceptions import SupersetTemplateException
from superset.extensions import feature_flag_manager
from superset.utils.core import convert_legacy_filters_into_adhoc, merge_extra_filters
Expand Down Expand Up @@ -490,6 +491,7 @@ def set_context(self, **kwargs: Any) -> None:
"cache_key_wrapper": partial(safe_proxy, extra_cache.cache_key_wrapper),
"filter_values": partial(safe_proxy, extra_cache.filter_values),
"get_filters": partial(safe_proxy, extra_cache.get_filters),
"dataset": partial(safe_proxy, dataset_macro),
}
)

Expand Down Expand Up @@ -602,3 +604,35 @@ def get_template_processor(
else:
template_processor = NoOpTemplateProcessor
return template_processor(database=database, table=table, query=query, **kwargs)


def dataset_macro(
    dataset_id: int,
    include_metrics: bool = False,
    groupby: Optional[List[str]] = None,
) -> str:
    """
    Render a dataset as an inline SQL subquery.

    Looks up the dataset by ``dataset_id`` and returns its SQL wrapped in
    parentheses and aliased as ``dataset_<id>``, so it can be dropped into a
    ``FROM`` clause. All columns (including computed ones) are selected by
    default; pass ``include_metrics=True`` to also select the dataset's
    metrics, and ``groupby`` to aggregate over the given column names.

    :param dataset_id: ID of the dataset to render
    :param include_metrics: whether to include the dataset's metrics
    :param groupby: optional list of column names to group by
    :raises DatasetNotFoundError: if no dataset exists with ``dataset_id``
    :returns: a parenthesized, aliased SQL subquery string
    """
    # Imported here to avoid a circular import at module load time.
    # pylint: disable=import-outside-toplevel
    from superset.datasets.dao import DatasetDAO

    dataset = DatasetDAO.find_by_id(dataset_id)
    if not dataset:
        raise DatasetNotFoundError(f"Dataset {dataset_id} not found!")

    column_names = [column.column_name for column in dataset.columns]
    metric_names = [metric.metric_name for metric in dataset.metrics]
    query_obj = {
        "is_timeseries": False,
        "filter": [],
        # A ``None`` metrics entry keeps the query purely column-based.
        "metrics": metric_names if include_metrics else None,
        "columns": column_names,
        "groupby": groupby,
    }
    rendered = dataset.get_query_str_extended(query_obj)
    return f"({rendered.sql}) AS dataset_{dataset_id}"
101 changes: 100 additions & 1 deletion tests/unit_tests/jinja_context_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,15 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, unused-argument

from superset.jinja_context import where_in
import json

import pytest
from pytest_mock import MockFixture

from superset.datasets.commands.exceptions import DatasetNotFoundError
from superset.jinja_context import dataset_macro, where_in


def test_where_in() -> None:
Expand All @@ -25,3 +32,95 @@ def test_where_in() -> None:
assert where_in([1, "b", 3]) == "(1, 'b', 3)"
assert where_in([1, "b", 3], '"') == '(1, "b", 3)'
assert where_in(["O'Malley's"]) == "('O''Malley''s')"


def test_dataset_macro(mocker: MockFixture, app_context: None) -> None:
    """
    Test the ``dataset_macro`` macro.
    """
    # Imported here so the app context fixture is active first.
    # pylint: disable=import-outside-toplevel
    from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn
    from superset.models.core import Database

    # Build a dataset with physical columns, one computed column, and one
    # metric, so all three code paths of the macro are exercised.
    dataset = SqlaTable(
        table_name="old_dataset",
        columns=[
            TableColumn(column_name="ds", is_dttm=1, type="TIMESTAMP"),
            TableColumn(column_name="num_boys", type="INTEGER"),
            TableColumn(column_name="revenue", type="INTEGER"),
            TableColumn(column_name="expenses", type="INTEGER"),
            TableColumn(
                column_name="profit", type="INTEGER", expression="revenue-expenses"
            ),
        ],
        metrics=[
            SqlMetric(metric_name="cnt", expression="COUNT(*)"),
        ],
        main_dttm_col="ds",
        default_endpoint="https://www.youtube.com/watch?v=dQw4w9WgXcQ",  # not used
        database=Database(database_name="my_database", sqlalchemy_uri="sqlite://"),
        offset=-8,
        description="This is the description",
        is_featured=1,
        cache_timeout=3600,
        schema="my_schema",
        sql=None,
        params=json.dumps(
            {
                "remote_id": 64,
                "database_name": "examples",
                "import_time": 1606677834,
            }
        ),
        perm=None,
        filter_select_enabled=1,
        fetch_values_predicate="foo IN (1, 2)",
        is_sqllab_view=0,  # no longer used?
        template_params=json.dumps({"answer": "42"}),
        schema_perm=None,
        extra=json.dumps({"warning_markdown": "*WARNING*"}),
    )
    # ``dataset_macro`` imports the DAO lazily from this path, so patching it
    # here intercepts the lookup.
    mock_dao = mocker.patch("superset.datasets.dao.DatasetDAO")
    mock_dao.find_by_id.return_value = dataset

    # Default call: all columns (including the computed one), no metrics.
    assert (
        dataset_macro(1)
        == """(SELECT ds AS ds,
       num_boys AS num_boys,
       revenue AS revenue,
       expenses AS expenses,
       revenue-expenses AS profit
FROM my_schema.old_dataset) AS dataset_1"""
    )

    # Metrics included: the metric is selected and every column is grouped.
    assert (
        dataset_macro(1, include_metrics=True)
        == """(SELECT ds AS ds,
       num_boys AS num_boys,
       revenue AS revenue,
       expenses AS expenses,
       revenue-expenses AS profit,
       COUNT(*) AS cnt
FROM my_schema.old_dataset
GROUP BY ds,
         num_boys,
         revenue,
         expenses,
         revenue-expenses) AS dataset_1"""
    )

    # Explicit ``groupby`` restricts both the projection and the grouping.
    assert (
        dataset_macro(1, include_metrics=True, groupby=["ds"])
        == """(SELECT ds AS ds,
       COUNT(*) AS cnt
FROM my_schema.old_dataset
GROUP BY ds) AS dataset_1"""
    )

    # Unknown dataset IDs raise a descriptive error.
    mock_dao.find_by_id.return_value = None
    with pytest.raises(DatasetNotFoundError) as excinfo:
        dataset_macro(1)
    assert str(excinfo.value) == "Dataset 1 not found!"

0 comments on commit 570e7c0

Please sign in to comment.