Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(presto): Handle ROW data stored as string #10456

Merged
merged 5 commits into from
Jul 28, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions superset/db_engine_specs/presto.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from contextlib import closing
from datetime import datetime
from distutils.version import StrictVersion
from typing import Any, cast, Dict, List, Optional, Tuple, TYPE_CHECKING
from typing import Any, cast, Dict, List, Optional, Tuple, TYPE_CHECKING, Union
from urllib import parse

import pandas as pd
Expand All @@ -40,6 +40,7 @@
from superset.exceptions import SupersetTemplateException
from superset.models.sql_lab import Query
from superset.models.sql_types.presto_sql_types import type_map as presto_type_map
from superset.result_set import destringify
from superset.sql_parse import ParsedQuery
from superset.utils import core as utils

Expand Down Expand Up @@ -568,7 +569,7 @@ def get_all_datasource_names(
return datasource_names

@classmethod
def expand_data( # pylint: disable=too-many-locals
def expand_data( # pylint: disable=too-many-locals,too-many-branches
cls, columns: List[Dict[Any, Any]], data: List[Dict[Any, Any]]
) -> Tuple[List[Dict[Any, Any]], List[Dict[Any, Any]], List[Dict[Any, Any]]]:
"""
Expand Down Expand Up @@ -616,6 +617,7 @@ def expand_data( # pylint: disable=too-many-locals
current_array_level = level

name = column["name"]
values: Optional[Union[str, List[Any]]]

if column["type"].startswith("ARRAY("):
# keep processing array children; we append to the right so that
Expand All @@ -627,6 +629,8 @@ def expand_data( # pylint: disable=too-many-locals
while i < len(data):
row = data[i]
values = row.get(name)
if isinstance(values, str):
row[name] = values = destringify(values)
if values:
# how many extra rows we need to unnest the data?
extra_rows = len(values) - 1
Expand All @@ -653,12 +657,15 @@ def expand_data( # pylint: disable=too-many-locals
# expand columns; we append them to the left so they are added
# immediately after the parent
expanded = get_children(column)
to_process.extendleft((column, level) for column in expanded)
to_process.extendleft((column, level) for column in expanded[::-1])
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I reversed `expanded` so that sub-columns are appended to the deque in the correct order; without the reversal, `extendleft` (which prepends items one at a time) reverses them, so:

a ROW(b, c) => a, a.c, a.b  (instead of the expected a, a.b, a.c)

expanded_columns.extend(expanded)

# expand row objects into new columns
for row in data:
for value, col in zip(row.get(name) or [], expanded):
values = row.get(name) or []
if isinstance(values, str):
row[name] = values = cast(List[Any], destringify(values))
for value, col in zip(values, expanded):
row[col["name"]] = value

data = [
Expand Down
4 changes: 4 additions & 0 deletions superset/result_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,10 @@ def stringify_values(array: np.ndarray) -> np.ndarray:
return vstringify(array)


def destringify(obj: str) -> Any:
    """
    Parse a JSON-encoded string back into the Python value it represents.

    Some engines (e.g. Presto) return complex column values such as ROW or
    ARRAY data serialized as JSON strings; this reverses that serialization.

    :param obj: a JSON document as a string
    :return: the deserialized value (list, dict, str, number, bool, or None)
    :raises json.JSONDecodeError: if ``obj`` is not valid JSON
    """
    return json.loads(obj)


class SupersetResultSet:
def __init__( # pylint: disable=too-many-locals,too-many-branches
self,
Expand Down
59 changes: 58 additions & 1 deletion tests/db_engine_specs/presto_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,9 +214,9 @@ def test_presto_expand_data_with_complex_row_columns(self):
"name": "row_column",
"type": "ROW(NESTED_OBJ1 VARCHAR, NESTED_ROW ROW(NESTED_OBJ2 VARCHAR))",
},
{"name": "row_column.nested_obj1", "type": "VARCHAR"},
{"name": "row_column.nested_row", "type": "ROW(NESTED_OBJ2 VARCHAR)"},
{"name": "row_column.nested_row.nested_obj2", "type": "VARCHAR"},
{"name": "row_column.nested_obj1", "type": "VARCHAR"},
]
expected_data = [
{
Expand Down Expand Up @@ -433,3 +433,60 @@ def test_query_cost_formatter(self):
}
]
self.assertEqual(formatted_cost, expected)

@mock.patch.dict(
    "superset.extensions.feature_flag_manager._feature_flags",
    {"PRESTO_EXPAND_DATA": True},
    clear=True,
)
def test_presto_expand_data_array(self):
    """ROW values arriving as JSON strings are parsed and expanded."""
    scalar_cols = [
        {"name": "event_id", "type": "VARCHAR", "is_date": False},
        {"name": "timestamp", "type": "BIGINT", "is_date": False},
    ]
    row_col = {
        "name": "user",
        "type": "ROW(ID BIGINT, FIRST_NAME VARCHAR, LAST_NAME VARCHAR)",
        "is_date": False,
    }
    cols = scalar_cols + [row_col]
    # The ROW value is stored as a JSON string rather than a native list.
    data = [
        {
            "event_id": "abcdef01-2345-6789-abcd-ef0123456789",
            "timestamp": "1595895506219",
            "user": '[1, "JOHN", "DOE"]',
        }
    ]

    actual_cols, actual_data, actual_expanded_cols = PrestoEngineSpec.expand_data(
        cols, data
    )

    child_cols = [
        {"name": "user.id", "type": "BIGINT"},
        {"name": "user.first_name", "type": "VARCHAR"},
        {"name": "user.last_name", "type": "VARCHAR"},
    ]
    expected_cols = scalar_cols + [row_col] + child_cols
    expected_data = [
        {
            "event_id": "abcdef01-2345-6789-abcd-ef0123456789",
            "timestamp": "1595895506219",
            "user": [1, "JOHN", "DOE"],
            "user.id": 1,
            "user.first_name": "JOHN",
            "user.last_name": "DOE",
        }
    ]

    self.assertEqual(actual_cols, expected_cols)
    self.assertEqual(actual_data, expected_data)
    self.assertEqual(actual_expanded_cols, child_cols)