diff --git a/data-serving/reusable-data-service/reusable_data_service/model/case.py b/data-serving/reusable-data-service/reusable_data_service/model/case.py index 133cec4fb..b562681db 100644 --- a/data-serving/reusable-data-service/reusable_data_service/model/case.py +++ b/data-serving/reusable-data-service/reusable_data_service/model/case.py @@ -30,15 +30,26 @@ def from_json(cls, obj: str) -> type: def from_dict(cls, dictionary: dict[str, Any]) -> type: case = cls() for key in dictionary: - if key in ["confirmation_date"]: - # parse as an ISO 8601 date - date_dt = datetime.datetime.strptime( - dictionary[key], "%Y-%m-%dT%H:%M:%S.%fZ" - ) - date = date_dt.date() - setattr(case, key, date) + if key in cls.date_fields(): + # handle a few different ways dates get represented in dictionaries + maybe_date = dictionary[key] + if isinstance(maybe_date, datetime.datetime): + value = maybe_date.date() + elif isinstance(maybe_date, datetime.date): + value = maybe_date + elif isinstance(maybe_date, str): + value = datetime.datetime.strptime( + maybe_date, "%Y-%m-%dT%H:%M:%S.%fZ" + ) + elif isinstance(maybe_date, dict) and "$date" in maybe_date: + value = datetime.datetime.strptime( + maybe_date["$date"], "%Y-%m-%dT%H:%M:%SZ" + ).date() + else: + raise ValueError(f"Cannot interpret date {maybe_date}") else: - setattr(case, key, dictionary[key]) + value = dictionary[key] + setattr(case, key, value) case.validate() return case @@ -49,6 +60,12 @@ def validate(self): elif self.confirmation_date is None: raise ValueError("Confirmation Date must have a value") + @classmethod + def date_fields(cls) -> list[str]: + """Record where dates are kept because they sometimes need special treatment. + A subclass could override this method to indicate it stores additional date fields.""" + return ["confirmation_date"] + # Actually we want to capture extra fields which can be specified dynamically: # so Case is the class that you should use. diff --git a/data-serving/reusable-data-service/tests/test_case_end_to_end.py b/data-serving/reusable-data-service/tests/test_case_end_to_end.py index aed5f041a..7bd340bdd 100644 --- a/data-serving/reusable-data-service/tests/test_case_end_to_end.py +++ b/data-serving/reusable-data-service/tests/test_case_end_to_end.py @@ -2,6 +2,8 @@ import mongomock import pymongo +from datetime import datetime + from reusable_data_service import app, set_up_controllers @@ -31,7 +33,7 @@ def test_get_case_with_known_id(client_with_patched_mongo): db = pymongo.MongoClient("mongodb://localhost:27017/outbreak") case_id = ( db["outbreak"]["cases"] - .insert_one({"confirmation_date": "2021-12-31T01:23:45.678Z"}) + .insert_one({"confirmation_date": datetime(2021, 12, 31, 1, 23, 45, 678)}) .inserted_id ) response = client_with_patched_mongo.get(f"/api/cases/{str(case_id)}")