
Add ndarray type #623

Merged (14 commits) on May 1, 2024
10 changes: 7 additions & 3 deletions streaming/base/converters/__init__.py
@@ -3,7 +3,11 @@

"""Utility function for converting spark dataframe to MDS dataset."""

-from streaming.base.converters.dataframe_to_mds import (MAPPING_SPARK_TO_MDS, dataframe_to_mds,
-                                                         dataframeToMDS)
+from streaming.base.converters.dataframe_to_mds import (SPARK_TO_MDS, dataframe_to_mds,
+                                                         dataframeToMDS, infer_dataframe_schema,
+                                                         is_json_compatible)

-__all__ = ['dataframeToMDS', 'dataframe_to_mds', 'MAPPING_SPARK_TO_MDS']
+__all__ = [
+    'dataframeToMDS', 'dataframe_to_mds', 'SPARK_TO_MDS', 'infer_dataframe_schema',
+    'is_json_compatible'
+]
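For orientation (commentary, not part of the diff): dataframe_to_mds is the module's entry point, and the newly exported helpers are what it uses to map Spark column types to MDS encodings. A minimal end-to-end sketch, assuming a local SparkSession, an arbitrary output path, and the existing dataframe_to_mds keyword arguments (merge_index, mds_kwargs); the column names and data are hypothetical:

# Illustrative sketch only; not part of this PR.
from pyspark.sql import SparkSession
from pyspark.sql.types import ArrayType, FloatType, IntegerType, StructField, StructType

from streaming.base.converters import dataframe_to_mds

spark = SparkSession.builder.getOrCreate()
schema = StructType([
    StructField('id', IntegerType(), False),
    StructField('emb', ArrayType(FloatType()), False),
])
df = spark.createDataFrame([(1, [0.1, 0.2]), (2, [0.3, 0.4])], schema)

# With no user-defined columns, the MDS schema is inferred from the Spark types:
# 'id' -> 'int32', 'emb' -> 'ndarray:float32'.
dataframe_to_mds(df, merge_index=True, mds_kwargs={'out': '/tmp/mds_out'})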
101 changes: 69 additions & 32 deletions streaming/base/converters/dataframe_to_mds.py
@@ -19,7 +19,7 @@
from pyspark.sql.dataframe import DataFrame
from pyspark.sql.types import (ArrayType, BinaryType, BooleanType, ByteType, DateType,
DayTimeIntervalType, DecimalType, DoubleType, FloatType,
-                               IntegerType, LongType, MapType, ShortType, StringType,
+                               IntegerType, LongType, NullType, ShortType, StringType,
StructField, StructType, TimestampNTZType, TimestampType)
except ImportError as e:
e.msg = get_import_exception_message(e.name, extra_deps='spark') # pyright: ignore
@@ -32,42 +32,64 @@

logger = logging.getLogger(__name__)

-MAPPING_SPARK_TO_MDS = {
-    ByteType: 'uint8',
-    ShortType: 'uint16',
-    IntegerType: 'int',
-    LongType: 'int64',
-    FloatType: 'float32',
-    DoubleType: 'float64',
-    DecimalType: 'str_decimal',
-    StringType: 'str',
-    BinaryType: 'bytes',
-    BooleanType: None,
-    TimestampType: None,
-    TimestampNTZType: None,
-    DateType: None,
-    DayTimeIntervalType: None,
-    ArrayType: None,
-    MapType: None,
-    StructType: None,
-    StructField: None
+SPARK_TO_MDS = {
+    ByteType(): 'uint8',
+    ShortType(): 'uint16',
+    IntegerType(): 'int32',
+    LongType(): 'int64',
+    FloatType(): 'float32',
+    DoubleType(): 'float64',
+    DecimalType(): 'str_decimal',
+    StringType(): 'str',
+    BinaryType(): 'bytes',
+    BooleanType(): None,
+    TimestampType(): None,
+    TimestampNTZType(): None,
+    DateType(): None,
+    DayTimeIntervalType(): None,
+    ArrayType(IntegerType()): 'ndarray:int32',
+    ArrayType(ShortType()): 'ndarray:int16',
+    ArrayType(LongType()): 'ndarray:int64',
+    ArrayType(FloatType()): 'ndarray:float32',
+    ArrayType(DoubleType()): 'ndarray:float64',
}
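A note on the mapping (commentary, not part of the diff): the keys are now DataType instances rather than DataType classes, so a column's dataType can be looked up directly by value equality, which is what lets a parameterized type like ArrayType(FloatType()) select an ndarray encoding. DecimalType is the one case that still needs the issubclass special-casing in map_spark_dtype below, since a concrete DecimalType(10, 2) does not compare equal to DecimalType(). A tiny sketch of those semantics:

# Illustrative only: PySpark DataType instances compare (and hash) by value.
from pyspark.sql.types import ArrayType, DecimalType, FloatType, IntegerType

from streaming.base.converters import SPARK_TO_MDS

assert ArrayType(FloatType()) == ArrayType(FloatType())
assert SPARK_TO_MDS[ArrayType(FloatType())] == 'ndarray:float32'
assert SPARK_TO_MDS[IntegerType()] == 'int32'
assert DecimalType(10, 2) != DecimalType()  # hence the DecimalType special case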


+def is_json_compatible(data_type: Any):
+    """Recursively check if a given PySpark DataType is JSON compatible.
+
+    JSON = Union[Dict[str, 'JSON'], List['JSON'], str, float, int, bool, None]
+
+    Args:
+        data_type (Any): A pyspark schema for a column of the input spark dataframe.
+
+    Returns:
+        (bool): True if data_type is JSON compatible.
+    """
+    if isinstance(data_type, StructType):
+        return all(is_json_compatible(field.dataType) for field in data_type.fields)
+    elif isinstance(data_type, ArrayType):
+        return is_json_compatible(data_type.elementType)
+    elif isinstance(data_type, (StringType, IntegerType, FloatType, BooleanType, NullType)):
+        return True
+    else:
+        return False
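A quick illustration of the recursion above (not part of the diff): structs and arrays are accepted when every leaf is a string, integer, float, boolean, or null type; anything else, such as binary, falls through to False.

# Illustrative only.
from pyspark.sql.types import (ArrayType, BinaryType, IntegerType, StringType, StructField,
                               StructType)

from streaming.base.converters import is_json_compatible

nested_ok = StructType([
    StructField('name', StringType()),
    StructField('scores', ArrayType(IntegerType())),
])
nested_bad = StructType([StructField('blob', BinaryType())])

assert is_json_compatible(nested_ok)
assert not is_json_compatible(nested_bad)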


def infer_dataframe_schema(dataframe: DataFrame,
user_defined_cols: Optional[Dict[str, Any]] = None) -> Optional[Dict]:
"""Retrieve schema to construct a dictionary or do sanity check for MDSWriter.
"""Retrieve schema to construct a dictionary or do sanity check for dataframe_to_mds.

Args:
dataframe (spark dataframe): dataframe to inspect schema
-        user_defined_cols (Optional[Dict[str, Any]]): user specified schema for MDSWriter
+        user_defined_cols (Optional[Dict[str, Any]]): user specified schema for dataframe_to_mds

Returns:
If user_defined_cols is None, return schema_dict (dict): column name and dtypes that are
supported by MDSWriter, else None

Raises:
-        ValueError if any of the datatypes are unsupported by MDSWriter.
+        ValueError if any of the datatypes are unsupported by dataframe_to_mds.
"""

def map_spark_dtype(spark_data_type: Any) -> str:
@@ -82,24 +104,39 @@ def map_spark_dtype(spark_data_type: Any) -> str:
Raises:
raise ValueError if no mds datatype is found for input type
"""
-        mds_type = MAPPING_SPARK_TO_MDS.get(type(spark_data_type), None)
+        if issubclass(type(spark_data_type), DecimalType):
+            mds_type = SPARK_TO_MDS.get(DecimalType(), None)
+        else:
+            mds_type = SPARK_TO_MDS.get(spark_data_type, None)

if mds_type is None:
-            raise ValueError(f'{spark_data_type} is not supported by MDSWriter')
+            raise ValueError(f'{spark_data_type} is not supported by dataframe_to_mds')
return mds_type

# user has provided schema, we just check if mds supports the dtype
if user_defined_cols is not None:
-        mds_supported_dtypes = {
-            mds_type for mds_type in MAPPING_SPARK_TO_MDS.values() if mds_type is not None
-        }
+        mds_supported_dtypes = set(filter(bool, SPARK_TO_MDS.values()))

for col_name, user_dtype in user_defined_cols.items():
if col_name not in dataframe.columns:
raise ValueError(
f'{col_name} is not a column of input dataframe: {dataframe.columns}')
-            if user_dtype not in mds_supported_dtypes:
-                raise ValueError(f'{user_dtype} is not supported by MDSWriter')

+            if user_dtype.startswith('ndarray:'):
+                parts = user_dtype.split(':')
+                if len(parts) == 3:
+                    user_dtype = ':'.join(parts[:-1])

+            actual_spark_dtype = dataframe.schema[col_name].dataType

+            if user_dtype not in mds_supported_dtypes:
+                if user_dtype == 'json':
+                    if is_json_compatible(actual_spark_dtype):
+                        continue
+                    else:
+                        raise ValueError(f'{col_name} can not be encoded by MDS JSON.')
+                raise ValueError(f'{user_dtype} is not supported by dataframe_to_mds')

mapped_mds_dtype = map_spark_dtype(actual_spark_dtype)
if user_dtype != mapped_mds_dtype:
raise ValueError(
@@ -112,10 +149,10 @@ def map_spark_dtype(spark_data_type: Any) -> str:

for field in schema:
dtype = map_spark_dtype(field.dataType)
-        if dtype in _encodings:
+        if dtype.split(':')[0] in _encodings:
schema_dict[field.name] = dtype
else:
-            raise ValueError(f'{dtype} is not supported by MDSWriter')
+            raise ValueError(f'{dtype} is not supported by dataframe_to_mds')
return schema_dict
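Putting the user-defined path together (commentary, not part of the diff): a user-supplied dtype may now carry a shape suffix, e.g. 'ndarray:float32:2', which is stripped before the support check, and 'json' is accepted whenever the column's Spark type passes is_json_compatible. A sketch that reuses the hypothetical `df` from the earlier example:

# Illustrative only; `df` has columns 'id' (int) and 'emb' (array<float>) as above.
from streaming.base.converters import infer_dataframe_schema

user_cols = {
    'id': 'int32',
    'emb': 'ndarray:float32:2',  # trailing shape suffix is stripped before validation
}
# With user_defined_cols, the function only sanity-checks and returns None.
assert infer_dataframe_schema(df, user_defined_cols=user_cols) is None

# Without user_defined_cols, the dtypes are inferred from the Spark schema.
assert infer_dataframe_schema(df) == {'id': 'int32', 'emb': 'ndarray:float32'}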

