[Data] - Improve performance for unify_schemas
#55880
Conversation
Code Review
This pull request significantly improves the performance of unify_schemas by refactoring it to use a single pass for gathering column statistics. The new implementation is not only faster but also more readable and maintainable. The use of a ColAgg dataclass to hold column statistics is a clean approach. I've found one potential issue with override precedence that could lead to incorrect type unification in some cases. Otherwise, this is an excellent improvement.
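For readers, here is a hedged sketch of the single-pass pattern this review describes; the `ColAgg` field names and the helper below are illustrative, not the PR's exact code:

```python
from dataclasses import dataclass, field
from typing import Dict, List

import pyarrow as pa


@dataclass
class ColAgg:
    """Aggregated statistics for one column across all input schemas."""

    # Distinct types observed for this column name.
    types: List[pa.DataType] = field(default_factory=list)
    # How many schemas contain this column (a column missing from some
    # schemas may need to be treated as nullable).
    seen_count: int = 0


def gather_column_stats(schemas: List[pa.Schema]) -> Dict[str, ColAgg]:
    """Collect per-column type candidates in a single pass over all schemas."""
    cols: Dict[str, ColAgg] = {}
    for schema in schemas:
        for fld in schema:
            agg = cols.setdefault(fld.name, ColAgg())
            if fld.type not in agg.types:
                agg.types.append(fld.type)
            agg.seen_count += 1
    return cols
```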
srinathk10 left a comment:
It may be good to add a unify_schemas test case covering many schemas (10) and wide schemas (10k columns) in CI, asserting it all gets done in < 1 sec.
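For illustration, such a test might look like the sketch below; the import path matches the stack trace later in this thread, and the 1-second bound is the reviewer's suggestion rather than a measured threshold:

```python
import time

import pyarrow as pa

from ray.data._internal.arrow_ops.transform_pyarrow import unify_schemas


def test_unify_schemas_many_wide_schemas_perf():
    # 10 schemas x 10k columns, per the review suggestion.
    num_schemas, num_cols = 10, 10_000
    schemas = [
        pa.schema([(f"col_{i}", pa.int64()) for i in range(num_cols)])
        for _ in range(num_schemas)
    ]
    start = time.perf_counter()
    unify_schemas(schemas, promote_types=True)
    assert time.perf_counter() - start < 1.0
```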
```python
schemas[0].remove_metadata()
schemas_to_unify = [schemas[0]]
for schema in schemas[1:]:
    schema.remove_metadata()
    if not schema.equals(schemas[0]):
```
nit: Let's actually use a set here (later we can raise a PR in PyArrow to start caching the hashes).
I'll use dict.fromkeys() instead to preserve ordering.
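A minimal sketch of that idea, assuming the schemas are hashable (an assumption the rest of this thread shows does not always hold):

```python
import pyarrow as pa

s1 = pa.schema([("a", pa.int64())])
s2 = pa.schema([("a", pa.int64())])  # equal to s1
s3 = pa.schema([("b", pa.string())])

# dict keys deduplicate like a set but preserve insertion order.
unique = list(dict.fromkeys([s1, s2, s3]))
assert unique == [s1, s3]
```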
Actually, Spark schemas are dicts, and dicts are unhashable. This fails test_raydp at `df = ds.to_spark(spark)`.
Input has to be PA schema, right?
If you look at this stack trace:
```
_____________________________ test_raydp_roundtrip _____________________________

spark = <pyspark.sql.session.SparkSession object at 0x7f086c7c2190>

    def test_raydp_roundtrip(spark):
        spark_df = spark.createDataFrame([(1, "a"), (2, "b"), (3, "c")], ["one", "two"])
        rows = [(r.one, r.two) for r in spark_df.take(3)]
        ds = ray.data.from_spark(spark_df)
        values = [(r["one"], r["two"]) for r in ds.take(6)]
        assert values == rows
>       df = ds.to_spark(spark)

python/ray/data/tests/test_raydp.py:30:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
/rayci/python/ray/data/dataset.py:5594: in to_spark
    schema = self.schema()
/rayci/python/ray/data/dataset.py:3459: in schema
    base_schema = self._plan.schema(fetch_if_missing=False)
/rayci/python/ray/data/_internal/plan.py:395: in schema
    schema = self._logical_plan.dag.infer_schema()
/rayci/python/ray/data/_internal/logical/operators/from_operators.py:77: in infer_schema
    return unify_ref_bundles_schema(self._input_data)
/rayci/python/ray/data/_internal/util.py:791: in unify_ref_bundles_schema
    return unify_schemas_with_validation(schemas_to_unify)
/rayci/python/ray/data/_internal/util.py:775: in unify_schemas_with_validation
    return unify_schemas(schemas_to_unify, promote_types=True)
/rayci/python/ray/data/_internal/arrow_ops/transform_pyarrow.py:325: in unify_schemas
    schemas_to_unify = list(dict.fromkeys(schemas))
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

>   ???
E   TypeError: unhashable type: 'dict'
```
It seems that the schema becomes a dict; infer_schema() seems to be the one that converts it.
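One hedged way to keep the fast path while tolerating unhashable inputs, sketched here for context (not necessarily what the PR landed on):

```python
def dedupe_schemas(schemas):
    """Drop duplicate schemas while preserving order.

    Falls back to a pairwise-equality scan when the schemas are
    unhashable (e.g. the dict schemas coming from the Spark path).
    """
    try:
        # Fast path: dict keys are deduplicated and insertion-ordered.
        return list(dict.fromkeys(schemas))
    except TypeError:
        # Slow path: O(n^2) scan using equality only.
        unique = []
        for schema in schemas:
            if not any(schema == seen for seen in unique):
                unique.append(schema)
        return unique
```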
```python
# If we raise only on non-tensor errors, it fails to unify PythonObjectType
# and pyarrow primitives. Look at test_pyarrow_conversion_error_handling for
# an example.
```
@alexeykudinkin just fyi
Ack. What do the exceptions look like in these cases?
I want to limit the scope of this as much as possible.
```
pyarrow.lib.ArrowTypeError: Unable to merge: Field my_data has incompatible types: string vs extension<ray.data.arrow_pickled_object<ArrowPythonObjectType>>
```
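For reference, a minimal reproduction of that exception class with plain types; the extension type from the message above is Ray-internal and omitted here:

```python
import pyarrow as pa

s1 = pa.schema([("my_data", pa.string())])
s2 = pa.schema([("my_data", pa.int64())])

try:
    # Incompatible field types make PyArrow raise ArrowTypeError.
    pa.unify_schemas([s1, s2])
except pa.ArrowTypeError as e:
    # e.g. "Unable to merge: Field my_data has incompatible types:
    # string vs int64"
    print(e)
```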
alexeykudinkin left a comment:
LGTM, minor comments
```python
if not (pyarrow.types.is_list(t) and pyarrow.types.is_null(t.value_type)):
    return t
# Let PyArrow handle other cases
return None
```
What does this mean?
At this point, it will error out because Arrow can't handle the case and we can't reconcile it either. I'll clarify the comment.
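For context, a hedged sketch of the special case under discussion: a column inferred from all-empty lists gets type `list<null>`, and a concretely typed list from another schema should win; the helper name below is illustrative:

```python
import pyarrow as pa


def _reconcile_list_type(types):
    """Prefer a concrete list type over list<null>.

    Illustrative helper, not the PR's code; returning None means
    "let PyArrow unify the types".
    """
    for t in types:
        # list<null> carries no element type (e.g. inferred from [[]]),
        # so skip it in favor of a typed list.
        if not (pa.types.is_list(t) and pa.types.is_null(t.value_type)):
            return t
    # Let PyArrow handle other cases
    return None


empty = pa.list_(pa.null())    # inferred from [[]]
typed = pa.list_(pa.int64())   # inferred from [[1, 2]]
assert _reconcile_list_type([empty, typed]) == typed
```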
Why are these changes needed?
- Find all diverging schemas, coalesce them if possible, and do so recursively in the presence of structs.
- Perform a single pass to gather stats for all columns across all schemas.
Related issue number
Checks
- I've signed off every commit (git commit -s) in this PR.
- I've run scripts/format.sh to lint the changes in this PR.
- If I added a method in Tune, I've added it in doc/source/tune/api/ under the corresponding .rst file.