Lazy load DTypeFloatTorch (#173)
Lazy load `DTypeFloatTorch` from a newly created `baybe.utils.torch` module, to avoid eager loading of `torch`.
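In practice the pattern looks as follows (a minimal sketch distilled from the diffs below, not verbatim repository code): the `torch`-dependent constant lives in its own module, and call sites import it inside the function body, so `torch` is loaded only on first use.

import numpy as np

DTypeFloatNumpy = np.float64  # usable without loading torch


def to_tensor(arr: np.ndarray):
    # Deferred imports: torch (and the torch dtype constant) are only
    # loaded the first time a tensor is actually requested.
    import torch

    from baybe.utils.torch import DTypeFloatTorch

    return torch.from_numpy(arr.astype(DTypeFloatNumpy)).to(DTypeFloatTorch)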
AdrianSosic authored Apr 3, 2024
2 parents 1b6c576 + abdfe38 commit 82bed24
Showing 8 changed files with 24 additions and 10 deletions.
CHANGELOG.md: 4 additions & 0 deletions
@@ -4,6 +4,10 @@ All notable changes to this project will be documented in this file.
 The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
 and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
 
+## Unreleased
+### Changed
+- `torch` numeric types are now loaded lazily
+
 ## [0.8.2] - 2024-03-27
 ### Added
 - Simulation user guide
baybe/constraints/base.py: 2 additions & 1 deletion
@@ -17,7 +17,6 @@
     get_base_structure_hook,
     unstructure_base,
 )
-from baybe.utils.numerical import DTypeFloatTorch
 
 if TYPE_CHECKING:
     from torch import Tensor
@@ -164,6 +163,8 @@ def to_botorch(
         """
         import torch
 
+        from baybe.utils.torch import DTypeFloatTorch
+
         param_names = [p.name for p in parameters]
         param_indices = [
             param_names.index(p) + idx_offset
baybe/objective.py: 4 additions & 2 deletions
@@ -3,7 +3,7 @@
 from __future__ import annotations
 
 from functools import partial
-from typing import Any, Literal
+from typing import TYPE_CHECKING, Any, Literal
 
 import numpy as np
 import pandas as pd
@@ -12,9 +12,11 @@

 from baybe.serialization import SerialMixin
 from baybe.targets.base import Target
-from baybe.targets.numerical import NumericalTarget
 from baybe.utils.numerical import geom_mean
 
+if TYPE_CHECKING:
+    from baybe.targets.numerical import NumericalTarget
+
 
 def _normalize_weights(weights: list[float]) -> list[float]:
     """Normalize a collection of weights such that they sum to 100.
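The `TYPE_CHECKING` guard used here complements the lazy-loading theme: `NumericalTarget` is needed only for type annotations, so the import never executes at runtime. A minimal sketch of the mechanism (the function below is hypothetical, not from this file):

from __future__ import annotations  # annotations are kept as plain strings

from typing import TYPE_CHECKING

if TYPE_CHECKING:  # True for static type checkers, False when the program runs
    from baybe.targets.numerical import NumericalTarget


def first_target_name(targets: list[NumericalTarget]) -> str:
    # The annotation is never evaluated at runtime, so the guarded
    # import above is enough for this to type-check.
    return targets[0].name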
baybe/surrogates/custom.py: 2 additions & 1 deletion
@@ -27,7 +27,8 @@
 from baybe.surrogates.base import Surrogate
 from baybe.surrogates.utils import batchify, catch_constant_targets
 from baybe.surrogates.validation import validate_custom_architecture_cls
-from baybe.utils.numerical import DTypeFloatONNX, DTypeFloatTorch
+from baybe.utils.numerical import DTypeFloatONNX
+from baybe.utils.torch import DTypeFloatTorch
 
 try:
     import onnxruntime as ort
baybe/utils/dataframe.py: 3 additions & 1 deletion
@@ -11,7 +11,7 @@

 from baybe.parameters.base import ContinuousParameter, DiscreteParameter
 from baybe.targets.enum import TargetMode
-from baybe.utils.numerical import DTypeFloatNumpy, DTypeFloatTorch
+from baybe.utils.numerical import DTypeFloatNumpy
 
 if TYPE_CHECKING:
     from torch import Tensor
@@ -41,6 +41,8 @@ def to_tensor(*dfs: pd.DataFrame) -> Union[Tensor, Iterable[Tensor]]:
     # even though this seems like double casting here.
     import torch
 
+    from baybe.utils.torch import DTypeFloatTorch
+
     out = (
         torch.from_numpy(df.values.astype(DTypeFloatNumpy)).to(DTypeFloatTorch)
         for df in dfs
baybe/utils/interval.py: 3 additions & 1 deletion
@@ -11,7 +11,7 @@
 from packaging import version
 
 from baybe.serialization import SerialMixin, converter
-from baybe.utils.numerical import DTypeFloatNumpy, DTypeFloatTorch
+from baybe.utils.numerical import DTypeFloatNumpy
 
 if TYPE_CHECKING:
     from torch import Tensor
@@ -132,6 +132,8 @@ def to_tensor(self) -> "Tensor":
         """Transform the interval to a :class:`torch.Tensor`."""
         import torch
 
+        from baybe.utils.torch import DTypeFloatTorch
+
         return torch.tensor([self.lower, self.upper], dtype=DTypeFloatTorch)
 
     def contains(self, number: float) -> bool:
baybe/utils/numerical.py: 0 additions & 4 deletions
@@ -1,14 +1,10 @@
 """Utilities for numeric operations."""
 
 import numpy as np
-import torch
 
 DTypeFloatNumpy = np.float64
 """Floating point data type used for numpy arrays."""
 
-DTypeFloatTorch = torch.float64
-"""Floating point data type used for torch tensors."""
-
 DTypeFloatONNX = np.float32
 """Floating point data type used for ONNX models.
baybe/utils/torch.py: 6 additions & 0 deletions (new file)
@@ -0,0 +1,6 @@
+"""Torch utilities shipped as separate module for lazy-loading."""
+
+import torch
+
+DTypeFloatTorch = torch.float64
+"""Floating point data type used for torch tensors."""

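Taken together, the last two diffs mean that importing the numerical utilities no longer drags in `torch`. A quick check in a fresh interpreter (assuming `torch` is installed and nothing else on the import path loads it first):

import sys

import baybe.utils.numerical  # noqa: F401

# After this commit, the numerical utilities no longer pull in torch.
print("torch" in sys.modules)  # expected: False

import baybe.utils.torch  # noqa: F401

# The dedicated module loads torch eagerly, by design.
print("torch" in sys.modules)  # True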