
Commit

Make some imports late to allow a more minimal set of inference requirements
akx committed Jul 19, 2023
1 parent e5dc966 commit d843656
Showing 3 changed files with 7 additions and 5 deletions.
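
All three changes below apply the same deferred-import idiom: a dependency that the minimal inference path does not need (taming, scipy) is moved from module import time into the function that actually uses it. A minimal sketch of the pattern, with a hypothetical package heavy_training_dep standing in for those dependencies (not part of this commit):

def training_only_loss(x):
    # Late import: heavy_training_dep is only required when this code path runs.
    from heavy_training_dep import SomeLoss
    return SomeLoss()(x)

Importing the enclosing module no longer requires heavy_training_dep to be installed; a missing package only surfaces as an ImportError when the deferred function is first called.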
7 changes: 4 additions & 3 deletions sgm/modules/autoencoding/losses/__init__.py
@@ -3,9 +3,6 @@
 import torch
 import torch.nn as nn
 from einops import rearrange
-from taming.modules.discriminator.model import NLayerDiscriminator, weights_init
-from taming.modules.losses.lpips import LPIPS
-from taming.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss
 
 from ....util import default, instantiate_from_config
 
@@ -26,6 +23,7 @@ def __init__(
         scale_tgt_to_input_size=False,
         perceptual_weight_on_inputs=0.0,
     ):
+        from taming.modules.losses.lpips import LPIPS  # late import to avoid extra dependency
         super().__init__()
         self.scale_input_to_tgt_size = scale_input_to_tgt_size
         self.scale_tgt_to_input_size = scale_tgt_to_input_size
@@ -101,6 +99,9 @@ def __init__(
         learn_logvar: bool = False,
         regularization_weights: Union[None, dict] = None,
     ):
+        from taming.modules.losses.lpips import LPIPS  # late import to avoid extra dependency
+        from taming.modules.discriminator.model import NLayerDiscriminator, weights_init  # late import to avoid extra dependency
+        from taming.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss  # late import to avoid extra dependency
         super().__init__()
         self.dims = dims
         if self.dims > 2:
3 changes: 2 additions & 1 deletion sgm/modules/diffusionmodules/loss.py
@@ -3,7 +3,6 @@
 import torch
 import torch.nn as nn
 from omegaconf import ListConfig
-from taming.modules.losses.lpips import LPIPS
 
 from ...util import append_dims, instantiate_from_config
 
@@ -26,6 +25,8 @@ def __init__(
         self.offset_noise_level = offset_noise_level
 
         if type == "lpips":
+            from taming.modules.losses.lpips import LPIPS  # late import to avoid extra dependency
+
             self.lpips = LPIPS().eval()
 
         if not batch2model_keys:
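
The loss.py change above is the conditional variant of the idiom: the taming import is additionally gated on the configured loss type, so configurations that never request an LPIPS loss never touch the dependency. A rough sketch of that shape (the class and argument names are illustrative, not the repository's):

class ExampleLoss:
    def __init__(self, loss_type: str = "l2"):
        self.loss_type = loss_type
        if loss_type == "lpips":
            # Only pulled in when the config actually asks for LPIPS.
            from taming.modules.losses.lpips import LPIPS
            self.lpips = LPIPS().eval()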
2 changes: 1 addition & 1 deletion sgm/modules/diffusionmodules/sampling_utils.py
@@ -1,5 +1,4 @@
 import torch
-from scipy import integrate
 
 from ...util import append_dims
 
@@ -10,6 +9,7 @@ def __call__(self, uncond, cond, scale):
 
 
 def linear_multistep_coeff(order, t, i, j, epsrel=1e-4):
+    from scipy import integrate  # late import to avoid extra dependency
     if order - 1 > i:
         raise ValueError(f"Order {order} too high for step {i}")
 
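
A quick way to see the effect of the sampling_utils.py change, assuming the sgm package is installed but scipy is not:

# Importing the module no longer pulls in scipy at import time.
import sgm.modules.diffusionmodules.sampling_utils as su

# scipy is only needed if linear_multistep_coeff is actually called, i.e. when
# the linear multistep sampler is used; otherwise the import above succeeds
# in a scipy-free environment.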