diff --git a/ludwig/features/date_feature.py b/ludwig/features/date_feature.py index 68cd644c573..8d6957a18dc 100644 --- a/ludwig/features/date_feature.py +++ b/ludwig/features/date_feature.py @@ -14,7 +14,7 @@ # limitations under the License. # ============================================================================== import logging -from datetime import datetime +from datetime import date, datetime from typing import Dict, List import numpy as np @@ -66,6 +66,8 @@ def date_to_list(date_value, datetime_format, preprocessing_parameters): try: if isinstance(date_value, datetime): datetime_obj = date_value + elif isinstance(date_value, date): + datetime_obj = datetime.combine(date=date_value, time=datetime.min.time()) elif isinstance(date_value, str) and datetime_format is not None: try: datetime_obj = datetime.strptime(date_value, datetime_format) diff --git a/tests/ludwig/features/test_date_feature.py b/tests/ludwig/features/test_date_feature.py index 2c7515a8f70..f527379b92b 100644 --- a/tests/ludwig/features/test_date_feature.py +++ b/tests/ludwig/features/test_date_feature.py @@ -1,5 +1,5 @@ from copy import deepcopy -from datetime import datetime +from datetime import date, datetime from typing import Any, List import pytest @@ -157,3 +157,26 @@ def test_date_to_list__UsesFillValueOnInvalidDate(): 0, 0, ] + + +@pytest.fixture(scope="module") +def date_obj(): + return date.fromisoformat("2022-06-25") + + +@pytest.fixture(scope="module") +def date_obj_vec(): + return create_vector_from_datetime_obj(datetime.fromisoformat("2022-06-25")) + + +def test_date_object_to_list(date_obj, date_obj_vec, fill_value): + """Test support for datetime.date object conversion. + + Args: + date_obj: Date object to convert into a vector + date_obj_vector: Expected vector version of `date_obj` + """ + computed_date_vec = date_feature.DateInputFeature.date_to_list( + date_obj, None, preprocessing_parameters={MISSING_VALUE_STRATEGY: FILL_WITH_CONST, "fill_value": fill_value} + ) + assert computed_date_vec == date_obj_vec