diff --git a/python/ray/data/_internal/numpy_support.py b/python/ray/data/_internal/numpy_support.py
index fbb54de7555bd..9e6a7c305dfb7 100644
--- a/python/ray/data/_internal/numpy_support.py
+++ b/python/ray/data/_internal/numpy_support.py
@@ -1,4 +1,5 @@
 import collections
+from datetime import datetime
 from typing import Any, Dict, List, Union
 
 import numpy as np
@@ -40,6 +41,43 @@ def validate_numpy_batch(batch: Union[Dict[str, np.ndarray], Dict[str, list]]) -
         )
 
 
+def _detect_highest_datetime_precision(datetime_list: List[datetime]) -> str:
+    """Return the coarsest numpy datetime64 unit that losslessly represents
+    every datetime in the list.
+
+    The unit only ever widens while scanning: "D" (date only) -> "s" ->
+    "ms" -> "us".
+    """
+    highest_precision = "D"
+
+    for dt in datetime_list:
+        if dt.microsecond != 0 and dt.microsecond % 1000 != 0:
+            # Sub-millisecond component: nothing finer is possible, stop.
+            highest_precision = "us"
+            break
+        elif dt.microsecond != 0 and dt.microsecond % 1000 == 0:
+            highest_precision = "ms"
+        elif highest_precision == "D" and (
+            dt.hour != 0 or dt.minute != 0 or dt.second != 0
+        ):
+            # pyarrow does not support h or m, use s for those cases too.
+            # Only widen from "D" here: a seconds-only value appearing after
+            # a millisecond-precision one must not downgrade "ms" to "s".
+            highest_precision = "s"
+
+    return highest_precision
+
+
+def _convert_datetime_list_to_array(datetime_list: List[datetime]) -> np.ndarray:
+    """Convert a list of stdlib datetimes to a datetime64 ndarray using the
+    minimal precision that preserves every element."""
+    precision = _detect_highest_datetime_precision(datetime_list)
+
+    return np.array(
+        [np.datetime64(dt, precision) for dt in datetime_list],
+        dtype=f"datetime64[{precision}]",
+    )
+
+
 def convert_udf_returns_to_numpy(udf_return_col: Any) -> Any:
     """Convert UDF columns (output of map_batches) to numpy, if possible.
 
@@ -64,6 +102,14 @@ def convert_udf_returns_to_numpy(udf_return_col: Any) -> Any:
             udf_return_col = np.expand_dims(udf_return_col[0], axis=0)
             return udf_return_col
 
+        # Lists of stdlib datetimes become a datetime64 array at the minimal
+        # lossless precision. The non-empty guard matters: all() is vacuously
+        # True for [], which would otherwise change empty lists' dtype.
+        if udf_return_col and all(
+            isinstance(elem, datetime) for elem in udf_return_col
+        ):
+            return _convert_datetime_list_to_array(udf_return_col)
+
         # Try to convert list values into an numpy array via
         # np.array(), so users don't need to manually cast.
         # NOTE: we don't cast generic iterables, since types like
diff --git a/python/ray/data/tests/test_numpy_support.py b/python/ray/data/tests/test_numpy_support.py
index 3bf29f60e05b9..c14038918c0af 100644
--- a/python/ray/data/tests/test_numpy_support.py
+++ b/python/ray/data/tests/test_numpy_support.py
@@ -1,3 +1,5 @@
+from datetime import datetime
+
 import numpy as np
 import pytest
 import torch
@@ -46,6 +48,58 @@ def test_list_of_objects(ray_start_regular_shared):
     assert_structure_equals(output, np.array([1, 2, 3, UserObj()]))
 
 
+DATETIME_DAY_PRECISION = datetime(year=2024, month=1, day=1)
+DATETIME_HOUR_PRECISION = datetime(year=2024, month=1, day=1, hour=1)
+DATETIME_MIN_PRECISION = datetime(year=2024, month=1, day=1, minute=1)
+DATETIME_SEC_PRECISION = datetime(year=2024, month=1, day=1, second=1)
+DATETIME_MILLISEC_PRECISION = datetime(year=2024, month=1, day=1, microsecond=1000)
+DATETIME_MICROSEC_PRECISION = datetime(year=2024, month=1, day=1, microsecond=1)
+
+DATETIME64_DAY_PRECISION = np.datetime64("2024-01-01")
+DATETIME64_HOUR_PRECISION = np.datetime64("2024-01-01T01:00", "s")
+DATETIME64_MIN_PRECISION = np.datetime64("2024-01-01T00:01", "s")
+DATETIME64_SEC_PRECISION = np.datetime64("2024-01-01T00:00:01")
+DATETIME64_MILLISEC_PRECISION = np.datetime64("2024-01-01T00:00:00.001")
+DATETIME64_MICROSEC_PRECISION = np.datetime64("2024-01-01T00:00:00.000001")
+
+
+@pytest.mark.parametrize(
+    "data,expected_output",
+    [
+        ([DATETIME_DAY_PRECISION], np.array([DATETIME64_DAY_PRECISION])),
+        ([DATETIME_HOUR_PRECISION], np.array([DATETIME64_HOUR_PRECISION])),
+        ([DATETIME_MIN_PRECISION], np.array([DATETIME64_MIN_PRECISION])),
+        ([DATETIME_SEC_PRECISION], np.array([DATETIME64_SEC_PRECISION])),
+        ([DATETIME_MILLISEC_PRECISION], np.array([DATETIME64_MILLISEC_PRECISION])),
+        ([DATETIME_MICROSEC_PRECISION], np.array([DATETIME64_MICROSEC_PRECISION])),
+        (
+            [DATETIME_MICROSEC_PRECISION, DATETIME_MILLISEC_PRECISION],
+            np.array(
+                [DATETIME64_MICROSEC_PRECISION, DATETIME64_MILLISEC_PRECISION],
+                dtype="datetime64[us]",
+            ),
+        ),
+        (
+            [DATETIME_SEC_PRECISION, DATETIME_MILLISEC_PRECISION],
+            np.array(
+                [DATETIME64_SEC_PRECISION, DATETIME64_MILLISEC_PRECISION],
+                dtype="datetime64[ms]",
+            ),
+        ),
+        (
+            [DATETIME_MILLISEC_PRECISION, DATETIME_SEC_PRECISION],
+            np.array(
+                [DATETIME64_MILLISEC_PRECISION, DATETIME64_SEC_PRECISION],
+                dtype="datetime64[ms]",
+            ),
+        ),
+    ],
+)
+def test_list_of_datetimes(data, expected_output, ray_start_regular_shared):
+    output = do_map_batches(data)
+    assert_structure_equals(output, expected_output)
+
+
 def test_array_like(ray_start_regular_shared):
     data = torch.Tensor([1, 2, 3])
     output = do_map_batches(data)