From f9e9b442c2666f7d4b4713b59547a2c365e58a16 Mon Sep 17 00:00:00 2001 From: Matthew Owen Date: Thu, 25 Jul 2024 12:21:54 -0700 Subject: [PATCH] add tests, debug and tweak code Signed-off-by: Matthew Owen --- python/ray/data/_internal/numpy_support.py | 24 +++++++--------- python/ray/data/tests/test_numpy_support.py | 32 +++++++++++++++++++++ 2 files changed, 43 insertions(+), 13 deletions(-) diff --git a/python/ray/data/_internal/numpy_support.py b/python/ray/data/_internal/numpy_support.py index 4e8f2abc78ae..e61f77577dac 100644 --- a/python/ray/data/_internal/numpy_support.py +++ b/python/ray/data/_internal/numpy_support.py @@ -40,27 +40,25 @@ def validate_numpy_batch(batch: Union[Dict[str, np.ndarray], Dict[str, list]]) - f"({_truncated_repr(batch)})" ) -def _detect_highest_datetime_precision_dtype(datetime_list: List[datetime]) -> str: - highest_precision = 'datetime64[D]' # Start with day precision +def _detect_highest_datetime_precision(datetime_list: List[datetime]) -> str: + highest_precision = 'D' for dt in datetime_list: - if dt.microsecond != 0: - highest_precision = 'datetime64[ns]' + if dt.microsecond != 0 and dt.microsecond % 1000 != 0: + highest_precision = 'us' break - elif dt.second != 0: - highest_precision = 'datetime64[s]' - elif dt.minute != 0: - highest_precision = 'datetime64[m]' - elif dt.hour != 0: - highest_precision = 'datetime64[h]' + elif dt.microsecond != 0 and dt.microsecond % 1000 == 0: + highest_precision = 'ms' + elif dt.hour != 0 or dt.minute != 0 or dt.second != 0: + # pyarrow does not support h or m, use s for those cases too + highest_precision = 's' return highest_precision def _convert_datetime_list_to_array(datetime_list: List[datetime]) -> np.ndarray: - # Detect highest precision - dtype_with_precision = _detect_highest_datetime_precision_dtype(datetime_list) + precision = _detect_highest_datetime_precision(datetime_list) - return np.array([np.datetime64(dt) for dt in datetime_list], dtype=dtype_with_precision) + return np.array([np.datetime64(dt, precision) for dt in datetime_list], dtype=f"datetime64[{precision}]") def convert_udf_returns_to_numpy(udf_return_col: Any) -> Any: diff --git a/python/ray/data/tests/test_numpy_support.py b/python/ray/data/tests/test_numpy_support.py index 3bf29f60e05b..7f7cec997642 100644 --- a/python/ray/data/tests/test_numpy_support.py +++ b/python/ray/data/tests/test_numpy_support.py @@ -6,6 +6,7 @@ from ray.air.util.tensor_extensions.utils import create_ragged_ndarray from ray.data.tests.conftest import * # noqa from ray.tests.conftest import * # noqa +from datetime import datetime class UserObj: @@ -45,6 +46,37 @@ def test_list_of_objects(ray_start_regular_shared): output = do_map_batches(data) assert_structure_equals(output, np.array([1, 2, 3, UserObj()])) +DATETIME_DAY_PRECISION = datetime(year=2024, month=1, day=1) +DATETIME_HOUR_PRECISION = datetime(year=2024, month=1, day=1, hour=1) +DATETIME_MIN_PRECISION = datetime(year=2024, month=1, day=1, minute=1) +DATETIME_SEC_PRECISION = datetime(year=2024, month=1, day=1, second=1) +DATETIME_MILLISEC_PRECISION = datetime(year=2024, month=1, day=1, microsecond=1000) +DATETIME_MICROSEC_PRECISION = datetime(year=2024, month=1, day=1, microsecond=1) + +DATETIME64_DAY_PRECISION = np.datetime64("2024-01-01") +DATETIME64_HOUR_PRECISION = np.datetime64("2024-01-01T01:00", 's') +DATETIME64_MIN_PRECISION = np.datetime64("2024-01-01T00:01", 's') +DATETIME64_SEC_PRECISION = np.datetime64("2024-01-01T00:00:01") +DATETIME64_MILLISEC_PRECISION = np.datetime64("2024-01-01T00:00:00.001") +DATETIME64_MICROSEC_PRECISION = np.datetime64("2024-01-01T00:00:00.000001") + +@pytest.mark.parametrize( + "data,expected_output", + [ + ([DATETIME_DAY_PRECISION], np.array([DATETIME64_DAY_PRECISION])), + ([DATETIME_HOUR_PRECISION], np.array([DATETIME64_HOUR_PRECISION])), + ([DATETIME_MIN_PRECISION], np.array([DATETIME64_MIN_PRECISION])), + ([DATETIME_SEC_PRECISION], np.array([DATETIME64_SEC_PRECISION])), + ([DATETIME_MILLISEC_PRECISION], np.array([DATETIME64_MILLISEC_PRECISION])), + ([DATETIME_MICROSEC_PRECISION], np.array([DATETIME64_MICROSEC_PRECISION])), + ([DATETIME_MICROSEC_PRECISION, DATETIME_MILLISEC_PRECISION], np.array([DATETIME64_MICROSEC_PRECISION, DATETIME_MILLISEC_PRECISION], dtype="datetime64[us]")), + ([DATETIME_SEC_PRECISION, DATETIME_MILLISEC_PRECISION], np.array([DATETIME64_SEC_PRECISION, DATETIME_MILLISEC_PRECISION], dtype="datetime64[ms]")), + ] +) +def test_list_of_datetimes(data, expected_output, ray_start_regular_shared): + output = do_map_batches(data) + assert_structure_equals(output, expected_output) + def test_array_like(ray_start_regular_shared): data = torch.Tensor([1, 2, 3])