diff --git a/python/ray/data/block.py b/python/ray/data/block.py
index 56a3c6dbc40d2..fcab3feb67eb5 100644
--- a/python/ray/data/block.py
+++ b/python/ray/data/block.py
@@ -72,6 +72,12 @@ class BlockType(Enum):
 # returned from batch UDFs.
 DataBatch = Union["pyarrow.Table", "pandas.DataFrame", Dict[str, np.ndarray]]
 
+# User-facing data column type. This is the data type for data that is supplied to and
+# returned from column UDFs.
+DataBatchColumn = Union[
+    "pyarrow.ChunkedArray", "pyarrow.Array", "pandas.Series", np.ndarray
+]
+
 # A class type that implements __call__.
 CallableClass = type
 
diff --git a/python/ray/data/dataset.py b/python/ray/data/dataset.py
index d576b8eb2ea76..779a5bd3295e7 100644
--- a/python/ray/data/dataset.py
+++ b/python/ray/data/dataset.py
@@ -87,6 +87,7 @@
     Block,
     BlockAccessor,
     DataBatch,
+    DataBatchColumn,
     T,
     U,
     UserDefinedFunction,
@@ -529,7 +530,8 @@ def __call__(self, batch: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
             compute: This argument is deprecated. Use ``concurrency`` argument.
             batch_format: If ``"default"`` or ``"numpy"``, batches are
                 ``Dict[str, numpy.ndarray]``. If ``"pandas"``, batches are
-                ``pandas.DataFrame``.
+                ``pandas.DataFrame``. If ``"pyarrow"``, batches are
+                ``pyarrow.Table``.
             zero_copy_batch: Whether ``fn`` should be provided zero-copy, read-only
                 batches. If this is ``True`` and no copy is required for the
                 ``batch_format`` conversion, the batch is a zero-copy, read-only
@@ -700,16 +702,21 @@ def _map_batches_without_batch_size_validation(
     def add_column(
         self,
         col: str,
-        fn: Callable[["pandas.DataFrame"], "pandas.Series"],
+        fn: Callable[
+            [DataBatch],
+            DataBatchColumn,
+        ],
         *,
+        batch_format: Optional[str] = "pandas",
         compute: Optional[str] = None,
         concurrency: Optional[Union[int, Tuple[int, int]]] = None,
         **ray_remote_args,
     ) -> "Dataset":
         """Add the given column to the dataset.
 
-        A function generating the new column values given the batch in pandas
-        format must be specified.
+        A function generating the new column values given the batch in pandas,
+        pyarrow, or numpy format must be specified. The function must operate on
+        batches in the given ``batch_format``.
 
         Examples:
 
@@ -729,11 +736,6 @@ def add_column(
             id      int64
             new_id  int64
 
-        Overwrite the existing values with zeros.
-
-        >>> ds.add_column("id", lambda df: 0).take(3)
-        [{'id': 0}, {'id': 0}, {'id': 0}]
-
         Time complexity: O(dataset size / parallelism)
 
         Args:
@@ -741,6 +743,11 @@ def add_column(
-                column is overwritten.
+                column is not overwritten; a ``ValueError`` is raised instead.
             fn: Map function generating the column values given a batch of
-                records in pandas format.
+                records in the specified ``batch_format``.
+            batch_format: If ``"default"`` or ``"numpy"``, batches are
+                ``Dict[str, numpy.ndarray]``. If ``"pandas"``, batches are
+                ``pandas.DataFrame``. If ``"pyarrow"``, batches are
+                ``pyarrow.Table``. The column returned by ``fn`` must match the
+                chosen batch format.
             compute: This argument is deprecated. Use ``concurrency`` argument.
             concurrency: The number of Ray workers to use concurrently. For a
                 fixed-sized worker pool of size ``n``, specify ``concurrency=n``. For
@@ -749,17 +756,72 @@ def add_column(
             ray_remote_args: Additional resource requirements to request from
                 ray (e.g., num_gpus=1 to request GPUs for the map tasks).
         """
+        # Check that batch_format is valid.
+        accepted_batch_formats = ["pandas", "pyarrow", "numpy"]
+        if batch_format not in accepted_batch_formats:
+            raise ValueError(
+                f"batch_format argument must be one of {accepted_batch_formats}, "
+                f"got: {batch_format}"
+            )
 
-        def add_column(batch: "pandas.DataFrame") -> "pandas.DataFrame":
-            batch.loc[:, col] = fn(batch)
-            return batch
+        def add_column(batch: DataBatch) -> DataBatch:
+            column = fn(batch)
+            if batch_format == "pandas":
+                import pandas as pd
+
+                assert isinstance(column, pd.Series), (
+                    f"For pandas batch format, the function must return a pandas "
+                    f"Series, got: {type(column)}"
+                )
+                if col in batch:
+                    raise ValueError(
+                        f"Trying to add an existing column with name {col}"
+                    )
+                batch.loc[:, col] = column
+                return batch
+            elif batch_format == "pyarrow":
+                import pyarrow as pa
+
+                assert isinstance(column, (pa.Array, pa.ChunkedArray)), (
+                    f"For pyarrow batch format, the function must return a pyarrow "
+                    f"Array or ChunkedArray, got: {type(column)}"
+                )
+                # Historically, this method was written for the pandas batch format.
+                # To resolve https://github.com/ray-project/ray/issues/48090,
+                # we also allow the pyarrow batch format, which is preferred but
+                # would be a breaking change to enforce.
+
+                # For pyarrow, get_field_index returns -1 if the column is missing,
+                # in which case we append it.
+                column_idx = batch.schema.get_field_index(col)
+                if column_idx == -1:
+                    # Append the column to the table.
+                    return batch.append_column(col, column)
+                else:
+                    raise ValueError(
+                        f"Trying to add an existing column with name {col}"
+                    )
+
+            else:
+                # The batch format is assumed to be numpy because batch_format was
+                # validated at the beginning of add_column.
+                assert isinstance(column, np.ndarray), (
+                    f"For numpy batch format, the function must return a "
+                    f"numpy.ndarray, got: {type(column)}"
+                )
+                if col in batch:
+                    raise ValueError(
+                        f"Trying to add an existing column with name {col}"
+                    )
+                batch[col] = column
+                return batch
 
         if not callable(fn):
             raise ValueError("`fn` must be callable, got {}".format(fn))
 
         return self.map_batches(
             add_column,
-            batch_format="pandas",  # TODO(ekl) we should make this configurable.
+            batch_format=batch_format,
             compute=compute,
             concurrency=concurrency,
             zero_copy_batch=False,
@@ -801,7 +863,7 @@ def drop_columns(
         Args:
             cols: Names of the columns to drop. If any name does not exist,
-                an exception is raised.
+                an exception is raised. Column names must be unique.
             compute: This argument is deprecated. Use ``concurrency`` argument.
             concurrency: The number of Ray workers to use concurrently. For a fixed-sized worker pool of size ``n``, specify ``concurrency=n``. For an autoscaling
@@ -810,12 +872,15 @@ def drop_columns(
                 ray (e.g., num_gpus=1 to request GPUs for the map tasks).
         """  # noqa: E501
 
+        if len(cols) != len(set(cols)):
+            raise ValueError(f"drop_columns expects unique column names, got: {cols}")
+
         def drop_columns(batch):
-            return batch.drop(columns=cols)
+            return batch.drop(cols)
 
         return self.map_batches(
             drop_columns,
-            batch_format="pandas",
+            batch_format="pyarrow",
             zero_copy_batch=True,
             compute=compute,
             concurrency=concurrency,
@@ -4316,7 +4381,8 @@ def to_tf(
         If your model accepts additional metadata aside from features and label, specify a single additional column or a list of additional columns. A common use case is to include sample weights in the data samples and train a ``tf.keras.Model`` with ``tf.keras.Model.fit``.
-        >>> ds = ds.add_column("sample weights", lambda df: 1)
+        >>> import pandas as pd
+        >>> ds = ds.add_column("sample weights", lambda df: pd.Series([1] * len(df)))
         >>> ds.to_tf(feature_columns="features", label_columns="target", additional_columns="sample weights")
         <_OptionsDataset element_spec=(TensorSpec(shape=(None, 4), dtype=tf.float64, name='features'), TensorSpec(shape=(None,), dtype=tf.int64, name='target'), TensorSpec(shape=(None,), dtype=tf.int64, name='sample weights'))>
diff --git a/python/ray/data/iterator.py b/python/ray/data/iterator.py
index 58e9a1b7355eb..2f19111af80f7 100644
--- a/python/ray/data/iterator.py
+++ b/python/ray/data/iterator.py
@@ -734,7 +734,8 @@ def to_tf(
         If your model accepts additional metadata aside from features and label, specify a single additional column or a list of additional columns. A common use case is to include sample weights in the data samples and train a ``tf.keras.Model`` with ``tf.keras.Model.fit``.
 
-        >>> ds = ds.add_column("sample weights", lambda df: 1)
+        >>> import pandas as pd
+        >>> ds = ds.add_column("sample weights", lambda df: pd.Series([1] * len(df)))
         >>> it = ds.iterator()
         >>> it.to_tf(feature_columns="sepal length (cm)", label_columns="target", additional_columns="sample weights")
         <_OptionsDataset element_spec=(TensorSpec(shape=(None,), dtype=tf.float64, name='sepal length (cm)'), TensorSpec(shape=(None,), dtype=tf.int64, name='target'), TensorSpec(shape=(None,), dtype=tf.int64, name='sample weights'))>
diff --git a/python/ray/data/tests/test_map.py b/python/ray/data/tests/test_map.py
index 9b1a4f8d4575c..d4e7e2c374de9 100644
--- a/python/ray/data/tests/test_map.py
+++ b/python/ray/data/tests/test_map.py
@@ -9,6 +9,7 @@
 import numpy as np
 import pandas as pd
 import pyarrow as pa
+import pyarrow.compute as pc
 import pyarrow.parquet as pq
 import pytest
 
@@ -330,18 +331,99 @@ def map_generator(item: dict) -> Iterator[int]:
 
 def test_add_column(ray_start_regular_shared):
-    ds = ray.data.range(5).add_column("foo", lambda x: 1)
+    """Tests the add column API."""
+
+    # Test with pyarrow batch format
+    ds = ray.data.range(5).add_column(
+        "foo", lambda x: pa.array([1] * x.num_rows), batch_format="pyarrow"
+    )
+    assert ds.take(1) == [{"id": 0, "foo": 1}]
+
+    # Test with a pyarrow chunked array
+    ds = ray.data.range(5).add_column(
+        "foo", lambda x: pa.chunked_array([[1] * x.num_rows]), batch_format="pyarrow"
+    )
+    assert ds.take(1) == [{"id": 0, "foo": 1}]
+
+    ds = ray.data.range(5).add_column(
+        "foo", lambda x: pc.add(x["id"], 1), batch_format="pyarrow"
+    )
+    assert ds.take(1) == [{"id": 0, "foo": 1}]
+
+    # Adding a column that is already there should result in an error
+    with pytest.raises(
+        ray.exceptions.UserCodeException,
+        match="Trying to add an existing column with name id",
+    ):
+        ds = ray.data.range(5).add_column(
+            "id", lambda x: pc.add(x["id"], 1), batch_format="pyarrow"
+        )
+        assert ds.take(2) == [{"id": 1}, {"id": 2}]
+
+    # Adding a column in the wrong format should result in an error
+    with pytest.raises(
+        ray.exceptions.UserCodeException, match="For pyarrow batch format"
+    ):
+        ds = ray.data.range(5).add_column("id", lambda x: [1], batch_format="pyarrow")
+        assert ds.take(2) == [{"id": 1}, {"id": 2}]
+
+    # Test with numpy batch format
+    ds = ray.data.range(5).add_column(
+        "foo", lambda x: np.array([1] * len(x["id"])), batch_format="numpy"
+    )
+    assert ds.take(1) == [{"id": 0, "foo": 1}]
+
+    ds = ray.data.range(5).add_column(
+        "foo", lambda x: np.add(x["id"], 1), batch_format="numpy"
+    )
+    assert ds.take(1) == [{"id": 0, "foo": 1}]
+
+    # Adding a column that is already there should result in an error
+    with pytest.raises(
+        ray.exceptions.UserCodeException,
+        match="Trying to add an existing column with name id",
+    ):
+        ds = ray.data.range(5).add_column(
+            "id", lambda x: np.add(x["id"], 1), batch_format="numpy"
+        )
+        assert ds.take(2) == [{"id": 1}, {"id": 2}]
+
+    # Adding a column in the wrong format should result in an error
+    with pytest.raises(
+        ray.exceptions.UserCodeException, match="For numpy batch format"
+    ):
+        ds = ray.data.range(5).add_column("id", lambda x: [1], batch_format="numpy")
+        assert ds.take(2) == [{"id": 1}, {"id": 2}]
+
+    # Test with pandas batch format (the default)
+    ds = ray.data.range(5).add_column("foo", lambda x: pd.Series([1] * x.shape[0]))
     assert ds.take(1) == [{"id": 0, "foo": 1}]
 
     ds = ray.data.range(5).add_column("foo", lambda x: x["id"] + 1)
     assert ds.take(1) == [{"id": 0, "foo": 1}]
 
-    ds = ray.data.range(5).add_column("id", lambda x: x["id"] + 1)
-    assert ds.take(2) == [{"id": 1}, {"id": 2}]
+    # Adding a column that is already there should result in an error
+    with pytest.raises(
+        ray.exceptions.UserCodeException,
+        match="Trying to add an existing column with name id",
+    ):
+        ds = ray.data.range(5).add_column("id", lambda x: x["id"] + 1)
+        assert ds.take(2) == [{"id": 1}, {"id": 2}]
+
+    # Adding a column in the wrong format should result in an error
+    with pytest.raises(
+        ray.exceptions.UserCodeException, match="For pandas batch format"
+    ):
+        ds = ray.data.range(5).add_column("id", lambda x: [1], batch_format="pandas")
+        assert ds.take(2) == [{"id": 1}, {"id": 2}]
 
     with pytest.raises(ValueError):
         ds = ray.data.range(5).add_column("id", 0)
 
+    # Test that an invalid batch_format raises an error
+    with pytest.raises(ValueError):
+        ray.data.range(5).add_column("foo", lambda x: x["id"] + 1, batch_format="foo")
+
 
 @pytest.mark.parametrize("names", (["foo", "bar"], {"spam": "foo", "ham": "bar"}))
 def test_rename_columns(ray_start_regular_shared, names):
@@ -362,14 +444,15 @@ def test_drop_columns(ray_start_regular_shared, tmp_path):
         assert ds.drop_columns(["col2"]).take(1) == [{"col1": 1, "col3": 3}]
         assert ds.drop_columns(["col1", "col3"]).take(1) == [{"col2": 2}]
         assert ds.drop_columns([]).take(1) == [{"col1": 1, "col2": 2, "col3": 3}]
-        assert ds.drop_columns(["col1", "col2", "col3"]).take(1) == [{}]
-        assert ds.drop_columns(["col1", "col1", "col2", "col1"]).take(1) == [
-            {"col3": 3}
-        ]
+        assert ds.drop_columns(["col1", "col2", "col3"]).take(1) == []
+        assert ds.drop_columns(["col1", "col2"]).take(1) == [{"col3": 3}]
         # Test dropping non-existent column
         with pytest.raises((UserCodeException, KeyError)):
             ds.drop_columns(["dummy_col", "col1", "col2"]).materialize()
 
+    with pytest.raises(ValueError, match="drop_columns expects unique column names"):
+        ds1.drop_columns(["col1", "col2", "col2"])
+
 
 def test_select_columns(ray_start_regular_shared):
     # Test pandas and arrow
diff --git a/python/ray/data/tests/test_mongo.py b/python/ray/data/tests/test_mongo.py
index 97828aae6bea6..eb03aab39f806 100644
--- a/python/ray/data/tests/test_mongo.py
+++ b/python/ray/data/tests/test_mongo.py
@@ -93,13 +93,13 @@ def test_read_write_mongo(ray_start_regular_shared, start_mongo):
         override_num_blocks=2,
     )
     assert ds._block_num_rows() == [3, 2]
-    assert str(ds) == (
-        "Dataset(\n"
-        "   num_rows=5,\n"
-        "   schema={_id: fixed_size_binary[12], float_field: double, "
-        "int_field: int32}\n"
-        ")"
-    )
+    assert ds.count() == 5
+    assert ds.schema().names == ["_id", "float_field", "int_field"]
+    # We are not testing the datatype of _id here, because it varies per platform.
+    assert ds.schema().types[1:] == [
+        pa.float64(),
+        pa.int32(),
+    ]
     assert df.equals(ds.drop_columns(["_id"]).to_pandas())
 
     # Read a subset of the collection.
@@ -111,13 +111,8 @@ def test_read_write_mongo(ray_start_regular_shared, start_mongo):
         override_num_blocks=2,
     )
     assert ds._block_num_rows() == [2, 1]
-    assert str(ds) == (
-        "Dataset(\n"
-        "   num_rows=3,\n"
-        "   schema={_id: fixed_size_binary[12], float_field: double, "
-        "int_field: int32}\n"
-        ")"
-    )
+    assert ds.count() == 3
+    assert ds.schema().names == ["_id", "float_field", "int_field"]
     df[df["int_field"] < 3].equals(ds.drop_columns(["_id"]).to_pandas())
 
     # Read with auto-tuned parallelism.
@@ -126,13 +121,14 @@ def test_read_write_mongo(ray_start_regular_shared, start_mongo):
         database=foo_db,
         collection=foo_collection,
     )
-    assert str(ds) == (
-        "Dataset(\n"
-        "   num_rows=5,\n"
-        "   schema={_id: fixed_size_binary[12], float_field: double, "
-        "int_field: int32}\n"
-        ")"
-    )
+
+    assert ds.count() == 5
+    assert ds.schema().names == ["_id", "float_field", "int_field"]
+    # We are not testing the datatype of _id here, because it varies per platform.
+    assert ds.schema().types[1:] == [
+        pa.float64(),
+        pa.int32(),
+    ]
     assert df.equals(ds.drop_columns(["_id"]).to_pandas())
 
     # Read with a parallelism larger than number of rows.
@@ -142,13 +138,14 @@ def test_read_write_mongo(ray_start_regular_shared, start_mongo):
         collection=foo_collection,
         override_num_blocks=1000,
    )
-    assert str(ds) == (
-        "Dataset(\n"
-        "   num_rows=5,\n"
-        "   schema={_id: fixed_size_binary[12], float_field: double, "
-        "int_field: int32}\n"
-        ")"
-    )
+
+    assert ds.count() == 5
+    assert ds.schema().names == ["_id", "float_field", "int_field"]
+    # We are not testing the datatype of _id here, because it varies per platform.
+    assert ds.schema().types[1:] == [
+        pa.float64(),
+        pa.int32(),
+    ]
     assert df.equals(ds.drop_columns(["_id"]).to_pandas())
 
     # Add a column and then write back to MongoDB.
diff --git a/python/ray/data/tests/test_object_gc.py b/python/ray/data/tests/test_object_gc.py
index b56c4542618d0..2b1947e0498d6 100644
--- a/python/ray/data/tests/test_object_gc.py
+++ b/python/ray/data/tests/test_object_gc.py
@@ -1,6 +1,7 @@
 import sys
 import threading
 
+import pandas as pd
 import pytest
 
 import ray
@@ -107,7 +108,7 @@ def test_tf_iteration(shutdown_only):
     # The size of dataset is 500*(80*80*4)*8B, about 100MB.
     ds = ray.data.range_tensor(
         500, shape=(80, 80, 4), override_num_blocks=100
-    ).add_column("label", lambda x: 1)
+    ).add_column("label", lambda df: pd.Series([1] * len(df)))
     # to_tf
     check_to_tf_no_spill(ctx, ds.map(lambda x: x))