diff --git a/ax/storage/json_store/encoder.py b/ax/storage/json_store/encoder.py index f745c3accee..025d46a735c 100644 --- a/ax/storage/json_store/encoder.py +++ b/ax/storage/json_store/encoder.py @@ -135,6 +135,7 @@ def object_to_json( # noqa C901 }, } elif dataclasses.is_dataclass(obj): + field_names = [f.name for f in dataclasses.fields(obj)] return { "__type": _type.__name__, **{ @@ -144,6 +145,7 @@ def object_to_json( # noqa C901 class_encoder_registry=class_encoder_registry, ) for k, v in obj.__dict__.items() + if k in field_names }, } diff --git a/ax/storage/json_store/tests/test_json_store.py b/ax/storage/json_store/tests/test_json_store.py index 1bc82ca8738..4e116cf5765 100644 --- a/ax/storage/json_store/tests/test_json_store.py +++ b/ax/storage/json_store/tests/test_json_store.py @@ -383,6 +383,25 @@ def test_EncodeDecode(self) -> None: else: raise e + def test_EncodeDecode_dataclass_with_initvar(self) -> None: + @dataclasses.dataclass + class TestDataclass: + a_field: int + not_a_field: dataclasses.InitVar[int | None] = None + + def __post_init__(self, doesnt_serialize: None) -> None: + self.not_a_field = 1 + + obj = TestDataclass(a_field=-1) + as_json = object_to_json(obj=obj) + self.assertEqual(as_json, {"__type": "TestDataclass", "a_field": -1}) + recovered = object_from_json( + object_json=as_json, decoder_registry={"TestDataclass": TestDataclass} + ) + self.assertEqual(recovered.a_field, -1) + self.assertEqual(recovered.not_a_field, 1) + self.assertEqual(obj, recovered) + def test_EncodeDecodeTorchTensor(self) -> None: x = torch.tensor( [[1.0, 2.0], [3.0, 4.0]], dtype=torch.float64, device=torch.device("cpu")