diff --git a/k_diffusion/sampling.py b/k_diffusion/sampling.py index 6656e80..ea33633 100644 --- a/k_diffusion/sampling.py +++ b/k_diffusion/sampling.py @@ -104,14 +104,15 @@ class BrownianTreeNoiseSampler: internal timestep. """ - def __init__(self, x, sigma_min, sigma_max, seed=None, transform=lambda x: x): + def __init__(self, x, sigma_min, sigma_max, seed=None, transform=lambda x: x, transform_last=lambda x: x): self.transform = transform + self.transform_last = transform_last t0, t1 = self.transform(torch.as_tensor(sigma_min)), self.transform(torch.as_tensor(sigma_max)) self.tree = BatchedBrownianTree(x, t0, t1, seed) def __call__(self, sigma, sigma_next): t0, t1 = self.transform(torch.as_tensor(sigma)), self.transform(torch.as_tensor(sigma_next)) - return self.tree(t0, t1) / (t1 - t0).abs().sqrt() + return self.transform_last(self.tree(t0, t1) / (t1 - t0).abs().sqrt()) @torch.no_grad()