Skip to content

Commit

Permalink
fix(presto): Handle ROW data stored as string (apache#10456)
Browse files Browse the repository at this point in the history
* Handle ROW data stored as string

* Use destringify

* Fix mypy

* Fix mypy with cast

* Bypass pylint
  • Loading branch information
betodealmeida authored and auxten committed Nov 20, 2020
1 parent 4119ae4 commit 37ce1ad
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 5 deletions.
15 changes: 11 additions & 4 deletions superset/db_engine_specs/presto.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from contextlib import closing
from datetime import datetime
from distutils.version import StrictVersion
from typing import Any, cast, Dict, List, Optional, Tuple, TYPE_CHECKING
from typing import Any, cast, Dict, List, Optional, Tuple, TYPE_CHECKING, Union
from urllib import parse

import pandas as pd
Expand All @@ -40,6 +40,7 @@
from superset.exceptions import SupersetTemplateException
from superset.models.sql_lab import Query
from superset.models.sql_types.presto_sql_types import type_map as presto_type_map
from superset.result_set import destringify
from superset.sql_parse import ParsedQuery
from superset.utils import core as utils

Expand Down Expand Up @@ -568,7 +569,7 @@ def get_all_datasource_names(
return datasource_names

@classmethod
def expand_data( # pylint: disable=too-many-locals
def expand_data( # pylint: disable=too-many-locals,too-many-branches
cls, columns: List[Dict[Any, Any]], data: List[Dict[Any, Any]]
) -> Tuple[List[Dict[Any, Any]], List[Dict[Any, Any]], List[Dict[Any, Any]]]:
"""
Expand Down Expand Up @@ -616,6 +617,7 @@ def expand_data( # pylint: disable=too-many-locals
current_array_level = level

name = column["name"]
values: Optional[Union[str, List[Any]]]

if column["type"].startswith("ARRAY("):
# keep processing array children; we append to the right so that
Expand All @@ -627,6 +629,8 @@ def expand_data( # pylint: disable=too-many-locals
while i < len(data):
row = data[i]
values = row.get(name)
if isinstance(values, str):
row[name] = values = destringify(values)
if values:
# how many extra rows we need to unnest the data?
extra_rows = len(values) - 1
Expand All @@ -653,12 +657,15 @@ def expand_data( # pylint: disable=too-many-locals
# expand columns; we append them to the left so they are added
# immediately after the parent
expanded = get_children(column)
to_process.extendleft((column, level) for column in expanded)
to_process.extendleft((column, level) for column in expanded[::-1])
expanded_columns.extend(expanded)

# expand row objects into new columns
for row in data:
for value, col in zip(row.get(name) or [], expanded):
values = row.get(name) or []
if isinstance(values, str):
row[name] = values = cast(List[Any], destringify(values))
for value, col in zip(values, expanded):
row[col["name"]] = value

data = [
Expand Down
4 changes: 4 additions & 0 deletions superset/result_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,10 @@ def stringify_values(array: np.ndarray) -> np.ndarray:
return vstringify(array)


def destringify(obj: str) -> Any:
    """Decode a JSON-encoded string back into the Python value it encodes.

    Counterpart to the stringification applied when results are serialized
    (e.g. ROW/ARRAY values stored as JSON text).
    """
    parsed: Any = json.loads(obj)
    return parsed


class SupersetResultSet:
def __init__( # pylint: disable=too-many-locals,too-many-branches
self,
Expand Down
59 changes: 58 additions & 1 deletion tests/db_engine_specs/presto_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,9 +214,9 @@ def test_presto_expand_data_with_complex_row_columns(self):
"name": "row_column",
"type": "ROW(NESTED_OBJ1 VARCHAR, NESTED_ROW ROW(NESTED_OBJ2 VARCHAR))",
},
{"name": "row_column.nested_obj1", "type": "VARCHAR"},
{"name": "row_column.nested_row", "type": "ROW(NESTED_OBJ2 VARCHAR)"},
{"name": "row_column.nested_row.nested_obj2", "type": "VARCHAR"},
{"name": "row_column.nested_obj1", "type": "VARCHAR"},
]
expected_data = [
{
Expand Down Expand Up @@ -433,3 +433,60 @@ def test_query_cost_formatter(self):
}
]
self.assertEqual(formatted_cost, expected)

@mock.patch.dict(
"superset.extensions.feature_flag_manager._feature_flags",
{"PRESTO_EXPAND_DATA": True},
clear=True,
)
def test_presto_expand_data_array(self):
cols = [
{"name": "event_id", "type": "VARCHAR", "is_date": False},
{"name": "timestamp", "type": "BIGINT", "is_date": False},
{
"name": "user",
"type": "ROW(ID BIGINT, FIRST_NAME VARCHAR, LAST_NAME VARCHAR)",
"is_date": False,
},
]
data = [
{
"event_id": "abcdef01-2345-6789-abcd-ef0123456789",
"timestamp": "1595895506219",
"user": '[1, "JOHN", "DOE"]',
}
]
actual_cols, actual_data, actual_expanded_cols = PrestoEngineSpec.expand_data(
cols, data
)
expected_cols = [
{"name": "event_id", "type": "VARCHAR", "is_date": False},
{"name": "timestamp", "type": "BIGINT", "is_date": False},
{
"name": "user",
"type": "ROW(ID BIGINT, FIRST_NAME VARCHAR, LAST_NAME VARCHAR)",
"is_date": False,
},
{"name": "user.id", "type": "BIGINT"},
{"name": "user.first_name", "type": "VARCHAR"},
{"name": "user.last_name", "type": "VARCHAR"},
]
expected_data = [
{
"event_id": "abcdef01-2345-6789-abcd-ef0123456789",
"timestamp": "1595895506219",
"user": [1, "JOHN", "DOE"],
"user.id": 1,
"user.first_name": "JOHN",
"user.last_name": "DOE",
}
]
expected_expanded_cols = [
{"name": "user.id", "type": "BIGINT"},
{"name": "user.first_name", "type": "VARCHAR"},
{"name": "user.last_name", "type": "VARCHAR"},
]

self.assertEqual(actual_cols, expected_cols)
self.assertEqual(actual_data, expected_data)
self.assertEqual(actual_expanded_cols, expected_expanded_cols)

0 comments on commit 37ce1ad

Please sign in to comment.