From f47a2573c02ef0dad6eefefb90c76ad953547422 Mon Sep 17 00:00:00 2001 From: Disty0 Date: Wed, 31 May 2023 11:49:43 +0300 Subject: [PATCH] Add transform_last for BrownianTreeNoiseSampler --- k_diffusion/sampling.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/k_diffusion/sampling.py b/k_diffusion/sampling.py index 6656e80..ea33633 100644 --- a/k_diffusion/sampling.py +++ b/k_diffusion/sampling.py @@ -104,14 +104,15 @@ class BrownianTreeNoiseSampler: internal timestep. """ - def __init__(self, x, sigma_min, sigma_max, seed=None, transform=lambda x: x): + def __init__(self, x, sigma_min, sigma_max, seed=None, transform=lambda x: x, transform_last=lambda x: x): self.transform = transform + self.transform_last = transform_last t0, t1 = self.transform(torch.as_tensor(sigma_min)), self.transform(torch.as_tensor(sigma_max)) self.tree = BatchedBrownianTree(x, t0, t1, seed) def __call__(self, sigma, sigma_next): t0, t1 = self.transform(torch.as_tensor(sigma)), self.transform(torch.as_tensor(sigma_next)) - return self.tree(t0, t1) / (t1 - t0).abs().sqrt() + return self.transform_last(self.tree(t0, t1) / (t1 - t0).abs().sqrt()) @torch.no_grad()