Skip to content

Commit

Permalink
Delayed import of torch and tensorflow. This should massivly increase…
Browse files Browse the repository at this point in the history
… import speed of tpcp, when torch and tensorflow are installed, but not used in the context of tpcp
  • Loading branch information
AKuederle committed Aug 23, 2024
1 parent 3e320be commit 41ecc3b
Showing 1 changed file with 13 additions and 9 deletions.
22 changes: 13 additions & 9 deletions tpcp/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,15 +49,6 @@
ValidationError,
)

try:
import tensorflow as tf
except ImportError:
tf = None

try:
import torch
except ImportError:
torch = None

if TYPE_CHECKING:
from collections.abc import Iterable
Expand Down Expand Up @@ -757,6 +748,19 @@ def clone(algorithm: T, *, safe: bool = False) -> T: # noqa: C901, PLR0911
"it does not seem to be a compatible algorithm/pipline class or general `tpcp` object as it does not "
"inherit from `BaseTpcpObject` or `Algorithm` or `Pipeline`."
)
# We delay an potential tensorflow and torch input until here, because the import is expensive.
# We only import, when the modules have been used before.
# Let's hope this does not explode in a multi-processing context, but let's see.
if "tensorflow" in sys.modules:
import tensorflow as tf
else:
tf = None

if "torch" in sys.modules:
import torch
else:
torch = None

# We have one special case for torch here, as apparently torch objects can not be deepcopied.
# https://github.com/pytorch/tutorials/issues/2177
if torch is not None and isinstance(algorithm, torch.nn.Module):
Expand Down

0 comments on commit 41ecc3b

Please sign in to comment.