diff --git a/python/raydp/spark/dataset.py b/python/raydp/spark/dataset.py index 3b09b114..1e976da4 100644 --- a/python/raydp/spark/dataset.py +++ b/python/raydp/spark/dataset.py @@ -259,14 +259,22 @@ def get_locations(blocks): ] def ray_dataset_to_spark_dataframe(spark: sql.SparkSession, - arrow_schema, + dataset_schema, blocks: List[ObjectRef], locations = None) -> DataFrame: locations = get_locations(blocks) - if not isinstance(arrow_schema, pa.lib.Schema): - if hasattr(arrow_schema, "base_schema") and \ - not isinstance(arrow_schema.base_schema, pa.lib.Schema): - raise RuntimeError(f"Schema is {type(arrow_schema)}, required pyarrow.lib.Schema. \n" \ + arrow_schema = dataset_schema + if not isinstance(dataset_schema, pa.lib.Schema): + if hasattr(dataset_schema, "base_schema"): + if isinstance(dataset_schema.base_schema, pa.lib.Schema): + arrow_schema = dataset_schema.base_schema + else: + raise RuntimeError(f"Schema is {type(dataset_schema.base_schema)}, " \ + f"required pyarrow.lib.Schema. \n" \ + f"to_spark does not support converting non-arrow ray datasets.") + else: + raise RuntimeError(f"Schema is {type(dataset_schema)}, " \ + f"required pyarrow.lib.Schema. \n" \ f"to_spark does not support converting non-arrow ray datasets.") schema = StructType() for field in arrow_schema: