Skip to content

Commit

Permalink
add tests, debug and tweak code
Browse files Browse the repository at this point in the history
Signed-off-by: Matthew Owen <mowen@anyscale.com>
  • Loading branch information
omatthew98 committed Jul 25, 2024
1 parent 6c59540 commit f9e9b44
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 13 deletions.
24 changes: 11 additions & 13 deletions python/ray/data/_internal/numpy_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,27 +40,25 @@ def validate_numpy_batch(batch: Union[Dict[str, np.ndarray], Dict[str, list]]) -
f"({_truncated_repr(batch)})"
)

def _detect_highest_datetime_precision(datetime_list: List[datetime]) -> str:
    """Return the finest datetime64 unit needed to represent every element.

    The result is one of ``'D'``, ``'s'``, ``'ms'`` or ``'us'``. Precision is
    tracked monotonically: once a finer unit has been seen, a later element
    that only needs a coarser unit can never downgrade it, so the result does
    not depend on the order of ``datetime_list``.

    Args:
        datetime_list: The ``datetime`` objects to inspect. May be empty.

    Returns:
        The datetime64 unit code. ``'D'`` is returned for an empty list or
        when every element falls exactly on midnight.
    """
    # Units ordered from coarsest to finest; the index doubles as a rank.
    precision_order = ("D", "s", "ms", "us")
    highest_rank = 0

    for dt in datetime_list:
        if dt.microsecond % 1000 != 0:
            # Sub-millisecond component: 'us' is the finest unit a datetime
            # can carry, so there is no need to look at the remaining items.
            return "us"

        if dt.microsecond != 0:
            rank = 2  # whole milliseconds
        elif dt.hour != 0 or dt.minute != 0 or dt.second != 0:
            # pyarrow does not support h or m, use s for those cases too
            rank = 1
        else:
            rank = 0  # exactly midnight, day precision is enough

        highest_rank = max(highest_rank, rank)

    return precision_order[highest_rank]

def _convert_datetime_list_to_array(datetime_list: List[datetime]) -> np.ndarray:
    """Convert a list of ``datetime`` objects into a ``datetime64`` array.

    The array's unit is the one reported by
    ``_detect_highest_datetime_precision`` for the list, and each element is
    converted with that same unit.

    Args:
        datetime_list: The ``datetime`` objects to convert.

    Returns:
        A NumPy array with dtype ``datetime64[<unit>]``.
    """
    # Detect the unit once so that every element and the array dtype agree.
    precision = _detect_highest_datetime_precision(datetime_list)

    return np.array(
        [np.datetime64(dt, precision) for dt in datetime_list],
        dtype=f"datetime64[{precision}]",
    )


def convert_udf_returns_to_numpy(udf_return_col: Any) -> Any:
Expand Down
32 changes: 32 additions & 0 deletions python/ray/data/tests/test_numpy_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from ray.air.util.tensor_extensions.utils import create_ragged_ndarray
from ray.data.tests.conftest import * # noqa
from ray.tests.conftest import * # noqa
from datetime import datetime


class UserObj:
Expand Down Expand Up @@ -45,6 +46,37 @@ def test_list_of_objects(ray_start_regular_shared):
output = do_map_batches(data)
assert_structure_equals(output, np.array([1, 2, 3, UserObj()]))

# Python datetime fixtures, one per precision level. Each differs from
# midnight on 2024-01-01 only in the field named by the constant.
DATETIME_DAY_PRECISION = datetime(2024, 1, 1)
DATETIME_HOUR_PRECISION = datetime(2024, 1, 1, 1)
DATETIME_MIN_PRECISION = datetime(2024, 1, 1, 0, 1)
DATETIME_SEC_PRECISION = datetime(2024, 1, 1, 0, 0, 1)
DATETIME_MILLISEC_PRECISION = datetime(2024, 1, 1, 0, 0, 0, 1000)
DATETIME_MICROSEC_PRECISION = datetime(2024, 1, 1, 0, 0, 0, 1)

# Matching numpy.datetime64 fixtures. The hour and minute values are created
# with second resolution ("s") rather than "h" / "m".
DATETIME64_DAY_PRECISION = np.datetime64("2024-01-01", "D")
DATETIME64_HOUR_PRECISION = np.datetime64("2024-01-01T01:00:00", "s")
DATETIME64_MIN_PRECISION = np.datetime64("2024-01-01T00:01:00", "s")
DATETIME64_SEC_PRECISION = np.datetime64("2024-01-01T00:00:01", "s")
DATETIME64_MILLISEC_PRECISION = np.datetime64("2024-01-01T00:00:00.001", "ms")
DATETIME64_MICROSEC_PRECISION = np.datetime64("2024-01-01T00:00:00.000001", "us")

@pytest.mark.parametrize(
    "data,expected_output",
    [
        # Single-element lists: one case per precision fixture.
        ([DATETIME_DAY_PRECISION], np.array([DATETIME64_DAY_PRECISION])),
        ([DATETIME_HOUR_PRECISION], np.array([DATETIME64_HOUR_PRECISION])),
        ([DATETIME_MIN_PRECISION], np.array([DATETIME64_MIN_PRECISION])),
        ([DATETIME_SEC_PRECISION], np.array([DATETIME64_SEC_PRECISION])),
        ([DATETIME_MILLISEC_PRECISION], np.array([DATETIME64_MILLISEC_PRECISION])),
        ([DATETIME_MICROSEC_PRECISION], np.array([DATETIME64_MICROSEC_PRECISION])),
        # Mixed precisions: the expected array uses the finest unit present.
        # Expected values are built from the DATETIME64_* fixtures so they do
        # not rely on NumPy implicitly converting a Python datetime.
        (
            [DATETIME_MICROSEC_PRECISION, DATETIME_MILLISEC_PRECISION],
            np.array(
                [DATETIME64_MICROSEC_PRECISION, DATETIME64_MILLISEC_PRECISION],
                dtype="datetime64[us]",
            ),
        ),
        (
            [DATETIME_SEC_PRECISION, DATETIME_MILLISEC_PRECISION],
            np.array(
                [DATETIME64_SEC_PRECISION, DATETIME64_MILLISEC_PRECISION],
                dtype="datetime64[ms]",
            ),
        ),
    ],
)
def test_list_of_datetimes(data, expected_output, ray_start_regular_shared):
    """Check that lists of datetimes come back as the expected datetime64 array."""
    output = do_map_batches(data)
    assert_structure_equals(output, expected_output)


def test_array_like(ray_start_regular_shared):
data = torch.Tensor([1, 2, 3])
Expand Down

0 comments on commit f9e9b44

Please sign in to comment.