Serialize data classes based on their fields only (#2697)
Summary:
Pull Request resolved: #2697

# Context:

**I think the unit test is the easiest way to understand this!**

There are several cases in Ax where an instance has an attribute that
1) can't or shouldn't be serialized; these are usually instances of `torch.nn.Module` subclasses, including BoTorch models, but can be other large and complex objects,
2) isn't typically passed at initialization, but rather constructed afterwards, and
3) should be an attribute rather than a property because it's too expensive to construct more than once.

These classes have thus required custom serialization logic so that such attributes are not serialized. Classes that have this issue include many benchmarking classes that work with surrogates or neural nets, such as `SurrogateRunner`, `PyTorchCNNTorchvisionRunner`, and `PyTorchCNNTorchvisionBenchmarkProblem`, as well as MBM classes.

A simpler solution is to use dataclasses, by
- declaring attributes that satisfy (1)-(3) as `InitVar`s and, if they are needed immediately, constructing them in `__post_init__`
- serializing only fields; `InitVar`s are not fields

This gives more flexibility in what we serialize without taking any away: If an attribute is constructed in the post-init and *should* be serialized, that is still supported by marking it as a `field` and not an `InitVar`. Attributes that are constructed in the init will be serialized, even if they are modified elsewhere.
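
A minimal sketch of the pattern described above, using only the standard library. `ExampleRunner` and its `surrogate` attribute are hypothetical stand-ins and not Ax classes; the dictionary built in `__post_init__` stands in for an expensive object such as a model.

```python
import dataclasses
from typing import Optional


@dataclasses.dataclass
class ExampleRunner:  # hypothetical class, not part of Ax
    name: str
    # Declared as an InitVar, so dataclasses does not treat it as a field.
    surrogate: dataclasses.InitVar[Optional[dict]] = None

    def __post_init__(self, surrogate: Optional[dict]) -> None:
        # Stand-in for an expensive object that should be built only once.
        if surrogate is None:
            surrogate = {"weights": [0.0, 0.0, 0.0]}
        # Persist it as a plain instance attribute (type checkers may object).
        self.surrogate = surrogate


runner = ExampleRunner(name="demo")
# Only `name` is a field; `surrogate` lives on the instance but is not listed.
field_names = [f.name for f in dataclasses.fields(runner)]
print(field_names)
print(runner.surrogate)
```

Because `surrogate` is absent from `dataclasses.fields`, a serializer that walks only the fields never sees it, while the instance can still use it after construction.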

## Downside
Using an `InitVar` to define a persistent non-field attribute is not its normal usage; `InitVar` is intended for values that are needed only to initialize fields. As a result, the usage pattern outlined in the unit test causes Pyre errors, and someone who uses an `InitVar` that way might be surprised that its value can't be recovered after deserialization. However, it might not need to be, since the other fields would be recovered. Also, `InitVar` is a rarely used feature.
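
The surprise mentioned above can be illustrated with a small standard-library sketch. `Config` and `label` are hypothetical names, and rebuilding from `dataclasses.fields` here only approximates a field-based round trip; it is not Ax's decoder.

```python
import dataclasses
from typing import Optional


@dataclasses.dataclass
class Config:  # hypothetical class for illustration
    size: int
    label: dataclasses.InitVar[Optional[str]] = None

    def __post_init__(self, label: Optional[str]) -> None:
        # Keep the InitVar value around as a non-field attribute.
        self.label = label


original = Config(size=2, label="custom")

# Rebuild from fields only, as a field-based serializer would.
field_values = {f.name: getattr(original, f.name) for f in dataclasses.fields(original)}
rebuilt = Config(**field_values)

print(field_values)      # the explicitly passed label is not among the fields
print(rebuilt.label)     # falls back to the InitVar default
print(original == rebuilt)  # generated __eq__ compares fields only
```

The explicitly passed `label` is lost on the rebuilt object, yet the two instances still compare equal because the generated `__eq__` looks only at fields.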

# This PR:

* Changes Ax's JSON serialization for dataclasses to exclude non-fields
* Adds a unit test
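
The encoder change can be sketched in isolation as follows. `to_json_dict`, `from_json_dict`, and `Problem` are hypothetical simplifications written for this example; Ax's real `object_to_json` and `object_from_json` recurse into values and use encoder and decoder registries that are omitted here.

```python
import dataclasses
from typing import Any, Dict, Optional


def to_json_dict(obj: Any) -> Dict[str, Any]:
    """Hypothetical simplified encoder: keep only dataclass fields."""
    field_names = {f.name for f in dataclasses.fields(obj)}
    return {
        "__type": type(obj).__name__,
        **{k: v for k, v in obj.__dict__.items() if k in field_names},
    }


def from_json_dict(payload: Dict[str, Any], registry: Dict[str, type]) -> Any:
    """Hypothetical simplified decoder: call the class with the stored fields."""
    kwargs = {k: v for k, v in payload.items() if k != "__type"}
    return registry[payload["__type"]](**kwargs)


@dataclasses.dataclass
class Problem:  # hypothetical class, not part of Ax
    num_trials: int
    cache: dataclasses.InitVar[Optional[list]] = None

    def __post_init__(self, cache: Optional[list]) -> None:
        # Rebuilt on every construction, so it need not be serialized.
        self.cache = list(range(self.num_trials))


payload = to_json_dict(Problem(num_trials=3))
recovered = from_json_dict(payload, {"Problem": Problem})
print(payload)
print(recovered.cache)
```

Decoding goes through the normal constructor, so `__post_init__` runs again and rebuilds `cache` from the recovered field.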

Reviewed By: danielcohenlive

Differential Revision: D61665461
esantorella authored and facebook-github-bot committed Aug 23, 2024
1 parent 9c2fa96 commit 2345564
Showing 2 changed files with 21 additions and 0 deletions.
2 changes: 2 additions & 0 deletions ax/storage/json_store/encoder.py
@@ -135,6 +135,7 @@ def object_to_json( # noqa C901
            },
        }
    elif dataclasses.is_dataclass(obj):
        field_names = [f.name for f in dataclasses.fields(obj)]
        return {
            "__type": _type.__name__,
            **{
@@ -144,6 +145,7 @@ def object_to_json( # noqa C901
                    class_encoder_registry=class_encoder_registry,
                )
                for k, v in obj.__dict__.items()
                if k in field_names
            },
        }

19 changes: 19 additions & 0 deletions ax/storage/json_store/tests/test_json_store.py
@@ -383,6 +383,25 @@ def test_EncodeDecode(self) -> None:
                else:
                    raise e

    def test_EncodeDecode_dataclass_with_initvar(self) -> None:
        @dataclasses.dataclass
        class TestDataclass:
            a_field: int
            not_a_field: dataclasses.InitVar[int | None] = None

            def __post_init__(self, doesnt_serialize: None) -> None:
                self.not_a_field = 1

        obj = TestDataclass(a_field=-1)
        as_json = object_to_json(obj=obj)
        self.assertEqual(as_json, {"__type": "TestDataclass", "a_field": -1})
        recovered = object_from_json(
            object_json=as_json, decoder_registry={"TestDataclass": TestDataclass}
        )
        self.assertEqual(recovered.a_field, -1)
        self.assertEqual(recovered.not_a_field, 1)
        self.assertEqual(obj, recovered)

    def test_EncodeDecodeTorchTensor(self) -> None:
        x = torch.tensor(
            [[1.0, 2.0], [3.0, 4.0]], dtype=torch.float64, device=torch.device("cpu")
