Skip to content

Commit

Permalink
Subclass garbage collection (#397)
Browse files Browse the repository at this point in the history
Fixes the issue discussed
[here](python-attrs/cattrs#589).
  • Loading branch information
AdrianSosic authored Oct 9, 2024
2 parents 0d20f62 + a1b4767 commit c1373c0
Show file tree
Hide file tree
Showing 47 changed files with 222 additions and 0 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [Unreleased]
### Fixed
- Leftover attrs-decorated classes are garbage collected before the subclass tree is
traversed, avoiding sporadic serialization problems

## [0.11.1] - 2024-10-01
### Added
- Continuous linear constraints have been consolidated in the new
Expand Down
5 changes: 5 additions & 0 deletions baybe/acquisition/acqfs.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Available acquisition functions."""

import gc
import math
from typing import ClassVar

Expand Down Expand Up @@ -302,3 +303,7 @@ class qThompsonSampling(qSimpleRegret):
def _non_botorch_attrs(cls) -> tuple[str, ...]:
flds = fields(qThompsonSampling)
return (flds.n_mc_samples.name,)


# Collect leftover original slotted classes processed by `attrs.define`
gc.collect()
4 changes: 4 additions & 0 deletions baybe/acquisition/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

import gc
import warnings
from abc import ABC
from inspect import signature
Expand Down Expand Up @@ -195,3 +196,6 @@ def added_deprecation_hook(val: dict | str, cls: type):
_add_deprecation_hook(get_base_structure_hook(AcquisitionFunction)),
)
converter.register_unstructure_hook(AcquisitionFunction, unstructure_base)

# Collect leftover original slotted classes processed by `attrs.define`
gc.collect()
6 changes: 6 additions & 0 deletions baybe/acquisition/partial.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Wrapper functionality for hybrid spaces."""

import gc

import torch
from attr import define
from botorch.acquisition import AcquisitionFunction as BotorchAcquisitionFunction
Expand Down Expand Up @@ -89,3 +91,7 @@ def set_X_pending(self, X_pending: Tensor | None):
X_pending = torch.squeeze(X_pending, -2)
# Now use the original set_X_pending function
self.botorch_acqf.set_X_pending(X_pending)


# Collect leftover original slotted classes processed by `attrs.define`
gc.collect()
4 changes: 4 additions & 0 deletions baybe/campaign.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

import gc
import json
from typing import TYPE_CHECKING

Expand Down Expand Up @@ -380,3 +381,6 @@ def _drop_version(dict_: dict) -> dict:
_validation_converter.register_structure_hook(
SearchSpace, validate_searchspace_from_config
)

# Collect leftover original slotted classes processed by `attrs.define`
gc.collect()
4 changes: 4 additions & 0 deletions baybe/constraints/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

import gc
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, ClassVar

Expand Down Expand Up @@ -201,3 +202,6 @@ class ContinuousNonlinearConstraint(ContinuousConstraint, ABC):
# Currently affected by a deprecation
# converter.register_structure_hook(Constraint, get_base_structure_hook(Constraint))
converter.register_structure_hook(Constraint, structure_constraints)

# Collect leftover original slotted classes processed by `attrs.define`
gc.collect()
4 changes: 4 additions & 0 deletions baybe/constraints/conditions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

import gc
import operator as ops
from abc import ABC, abstractmethod
from collections.abc import Callable
Expand Down Expand Up @@ -231,3 +232,6 @@ def to_polars(self, expr: pl.Expr, /) -> pl.Expr: # noqa: D102
converter.register_unstructure_hook(
Condition, partial(unstructure_base, overrides=_overrides)
)

# Collect leftover original slotted classes processed by `attrs.define`
gc.collect()
5 changes: 5 additions & 0 deletions baybe/constraints/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

import gc
import math
from collections.abc import Collection, Sequence
from typing import TYPE_CHECKING, Any
Expand Down Expand Up @@ -174,3 +175,7 @@ def sample_inactive_parameters(self, batch_size: int = 1) -> list[set[str]]:
]

return inactive_params


# Collect leftover original slotted classes processed by `attrs.define`
gc.collect()
4 changes: 4 additions & 0 deletions baybe/constraints/discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

import gc
from collections.abc import Callable
from functools import reduce
from typing import TYPE_CHECKING, Any, ClassVar, cast
Expand Down Expand Up @@ -387,3 +388,6 @@ def get_invalid(self, data: pd.DataFrame) -> pd.Index: # noqa: D102
# Prevent (de-)serialization of custom constraints
converter.register_unstructure_hook(DiscreteCustomConstraint, block_serialization_hook)
converter.register_structure_hook(DiscreteCustomConstraint, block_deserialization_hook)

# Collect leftover original slotted classes processed by `attrs.define`
gc.collect()
4 changes: 4 additions & 0 deletions baybe/kernels/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

import gc
from abc import ABC
from typing import TYPE_CHECKING, Any

Expand Down Expand Up @@ -127,3 +128,6 @@ class CompositeKernel(Kernel, ABC):
# Register de-/serialization hooks
converter.register_structure_hook(Kernel, get_base_structure_hook(Kernel))
converter.register_unstructure_hook(Kernel, unstructure_base)

# Collect leftover original slotted classes processed by `attrs.define`
gc.collect()
6 changes: 6 additions & 0 deletions baybe/kernels/basic.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Collection of basic kernels."""

import gc

from attrs import define, field
from attrs.converters import optional as optional_c
from attrs.validators import ge, gt, in_, instance_of
Expand Down Expand Up @@ -213,3 +215,7 @@ class RQKernel(BasicKernel):
validator=optional_v([finite_float, gt(0.0)]),
)
"""An optional initial value for the kernel lengthscale."""


# Collect leftover original slotted classes processed by `attrs.define`
gc.collect()
5 changes: 5 additions & 0 deletions baybe/kernels/composite.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Composite kernels (that is, kernels composed of other kernels)."""

import gc
from functools import reduce
from operator import add, mul

Expand Down Expand Up @@ -80,3 +81,7 @@ def to_gpytorch(self, *args, **kwargs): # noqa: D102
# See base class.

return reduce(mul, (k.to_gpytorch(*args, **kwargs) for k in self.base_kernels))


# Collect leftover original slotted classes processed by `attrs.define`
gc.collect()
4 changes: 4 additions & 0 deletions baybe/objectives/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Base classes for all objectives."""

import gc
from abc import ABC, abstractmethod

import pandas as pd
Expand Down Expand Up @@ -57,3 +58,6 @@ def to_objective(x: Target | Objective, /) -> Objective:
),
),
)

# Collect leftover original slotted classes processed by `attrs.define`
gc.collect()
5 changes: 5 additions & 0 deletions baybe/objectives/desirability.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Functionality for desirability objectives."""

import gc
from collections.abc import Callable
from functools import cached_property, partial
from typing import TypeGuard
Expand Down Expand Up @@ -151,3 +152,7 @@ def transform(self, data: pd.DataFrame) -> pd.DataFrame: # noqa: D102
transformed = pd.DataFrame({"Desirability": vals}, index=transformed.index)

return transformed


# Collect leftover original slotted classes processed by `attrs.define`
gc.collect()
6 changes: 6 additions & 0 deletions baybe/objectives/single.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Functionality for single-target objectives."""

import gc

import pandas as pd
from attr import define, field
from attr.validators import instance_of
Expand Down Expand Up @@ -37,3 +39,7 @@ def transform(self, data: pd.DataFrame) -> pd.DataFrame: # noqa: D102
# See base class.
target_data = data[[self._target.name]].copy()
return self._target.transform(target_data)


# Collect leftover original slotted classes processed by `attrs.define`
gc.collect()
4 changes: 4 additions & 0 deletions baybe/parameters/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

import gc
from abc import ABC, abstractmethod
from functools import cached_property, partial
from typing import TYPE_CHECKING, Any, ClassVar
Expand Down Expand Up @@ -175,3 +176,6 @@ def to_subspace(self) -> SubspaceContinuous:
converter.register_unstructure_hook(
Parameter, partial(unstructure_base, overrides=_overrides)
)

# Collect leftover original slotted classes processed by `attrs.define`
gc.collect()
5 changes: 5 additions & 0 deletions baybe/parameters/categorical.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Categorical parameters."""

import gc
from functools import cached_property
from typing import Any, ClassVar

Expand Down Expand Up @@ -109,3 +110,7 @@ def _validate_active_values( # noqa: DOC101, DOC103
raise ValueError("The active parameter values must be unique.")
if not all(v in self.values for v in values):
raise ValueError("All active values must be valid parameter choices.")


# Collect leftover original slotted classes processed by `attrs.define`
gc.collect()
5 changes: 5 additions & 0 deletions baybe/parameters/custom.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Custom parameters."""

import gc
from functools import cached_property
from typing import Any, ClassVar

Expand Down Expand Up @@ -111,3 +112,7 @@ def comp_df(self) -> pd.DataFrame: # noqa: D102
comp_df = df_uncorrelated_features(comp_df, threshold=self.decorrelate)

return comp_df


# Collect leftover original slotted classes processed by `attrs.define`
gc.collect()
5 changes: 5 additions & 0 deletions baybe/parameters/numerical.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Numerical parameters."""

import gc
from functools import cached_property
from typing import Any, ClassVar

Expand Down Expand Up @@ -147,3 +148,7 @@ def summary(self) -> dict: # noqa: D102
Upper_Bound=self.bounds.upper,
)
return param_dict


# Collect leftover original slotted classes processed by `attrs.define`
gc.collect()
5 changes: 5 additions & 0 deletions baybe/parameters/substance.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Substance parameters."""

import gc
from functools import cached_property
from typing import Any, ClassVar

Expand Down Expand Up @@ -149,3 +150,7 @@ def comp_df(self) -> pd.DataFrame: # noqa: D102
comp_df = df_uncorrelated_features(comp_df, threshold=self.decorrelate)

return comp_df


# Collect leftover original slotted classes processed by `attrs.define`
gc.collect()
5 changes: 5 additions & 0 deletions baybe/priors/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Base class for all priors."""

import gc
from abc import ABC

from attrs import define
Expand Down Expand Up @@ -41,3 +42,7 @@ def to_gpytorch(self, *args, **kwargs):
# Register de-/serialization hooks
converter.register_structure_hook(Prior, get_base_structure_hook(Prior))
converter.register_unstructure_hook(Prior, unstructure_base)


# Collect leftover original slotted classes processed by `attrs.define`
gc.collect()
5 changes: 5 additions & 0 deletions baybe/priors/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

import gc
from typing import Any

from attrs import define, field
Expand Down Expand Up @@ -103,3 +104,7 @@ def to_gpytorch(self, *args, **kwargs): # noqa: D102
raise NotImplementedError(
f"'{self.__class__.__name__}' does not have a gpytorch analog."
)


# Collect leftover original slotted classes processed by `attrs.define`
gc.collect()
4 changes: 4 additions & 0 deletions baybe/recommenders/meta/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Base classes for all meta recommenders."""

import gc
from abc import ABC, abstractmethod
from typing import Any

Expand Down Expand Up @@ -146,3 +147,6 @@ def recommend(
converter.register_structure_hook(
MetaRecommender, get_base_structure_hook(MetaRecommender)
)

# Collect leftover original slotted classes processed by `attrs.define`
gc.collect()
4 changes: 4 additions & 0 deletions baybe/recommenders/meta/sequential.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# this file will resolve type errors
# mypy: disable-error-code="arg-type"

import gc
from collections.abc import Iterable, Iterator
from typing import Literal

Expand Down Expand Up @@ -275,3 +276,6 @@ def __str__(self) -> str:
converter.register_structure_hook(
StreamingSequentialMetaRecommender, block_deserialization_hook
)

# Collect leftover original slotted classes processed by `attrs.define`
gc.collect()
5 changes: 5 additions & 0 deletions baybe/recommenders/naive.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Naive recommender for hybrid spaces."""

import gc
import warnings
from typing import ClassVar

Expand Down Expand Up @@ -180,3 +181,7 @@ def recommend( # noqa: D102
rec_cont.index = rec_disc_exp.index
rec_exp = pd.concat([rec_disc_exp, rec_cont], axis=1)
return rec_exp


# Collect leftover original slotted classes processed by `attrs.define`
gc.collect()
5 changes: 5 additions & 0 deletions baybe/recommenders/pure/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Base classes for all pure recommenders."""

import gc
from abc import ABC
from typing import ClassVar

Expand Down Expand Up @@ -232,3 +233,7 @@ def _recommend_with_discrete_parts(

# Return recommendations
return rec


# Collect leftover original slotted classes processed by `attrs.define`
gc.collect()
Loading

0 comments on commit c1373c0

Please sign in to comment.