Commit 2345481

[Flax] Fix unet and ddim scheduler (open-mmlab#594)
* [Flax] Fix unet and ddim scheduler
* correct
* finish
1 parent d934d3d commit 2345481

7 files changed: +22 -14 lines changed

src/diffusers/models/embeddings_flax.py (+4 -3)

@@ -19,7 +19,7 @@
 
 # This is like models.embeddings.get_timestep_embedding (PyTorch) but
 # less general (only handles the case we currently need).
-def get_sinusoidal_embeddings(timesteps, embedding_dim):
+def get_sinusoidal_embeddings(timesteps, embedding_dim, freq_shift: float = 1):
     """
     This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
@@ -29,7 +29,7 @@ def get_sinusoidal_embeddings(timesteps, embedding_dim):
     embeddings. :return: an [N x dim] tensor of positional embeddings.
     """
     half_dim = embedding_dim // 2
-    emb = math.log(10000) / (half_dim - 1)
+    emb = math.log(10000) / (half_dim - freq_shift)
     emb = jnp.exp(jnp.arange(half_dim) * -emb)
     emb = timesteps[:, None] * emb[None, :]
     emb = jnp.concatenate([jnp.cos(emb), jnp.sin(emb)], -1)
@@ -50,7 +50,8 @@ def __call__(self, temb):
 
 class FlaxTimesteps(nn.Module):
     dim: int = 32
+    freq_shift: float = 1
 
     @nn.compact
     def __call__(self, timesteps):
-        return get_sinusoidal_embeddings(timesteps, self.dim)
+        return get_sinusoidal_embeddings(timesteps, self.dim, freq_shift=self.freq_shift)
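
To see what the new argument changes outside the diff, here is a minimal standalone sketch of the same computation; the function name and the toy timesteps are illustrative and not part of the commit.

import math

import jax.numpy as jnp


def sinusoidal_embeddings_sketch(timesteps, embedding_dim, freq_shift=1.0):
    # freq_shift only changes the divisor that spaces the sinusoid frequencies:
    # 1 keeps the original DDPM-style spacing, 0 matches the value the Flax UNet
    # below now passes in through its config.
    half_dim = embedding_dim // 2
    emb = math.log(10000) / (half_dim - freq_shift)
    emb = jnp.exp(jnp.arange(half_dim) * -emb)
    emb = timesteps[:, None] * emb[None, :]
    return jnp.concatenate([jnp.cos(emb), jnp.sin(emb)], -1)


t = jnp.array([0, 10, 999])
print(sinusoidal_embeddings_sketch(t, 32, freq_shift=1.0).shape)  # (3, 32)
print(sinusoidal_embeddings_sketch(t, 32, freq_shift=0.0).shape)  # (3, 32)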

src/diffusers/models/unet_2d_condition_flax.py (+2 -1)

@@ -73,6 +73,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
     cross_attention_dim: int = 1280
     dropout: float = 0.0
     dtype: jnp.dtype = jnp.float32
+    freq_shift: int = 0
 
     def init_weights(self, rng: jax.random.PRNGKey) -> FrozenDict:
         # init input tensors
@@ -100,7 +101,7 @@ def setup(self):
         )
 
         # time
-        self.time_proj = FlaxTimesteps(block_out_channels[0])
+        self.time_proj = FlaxTimesteps(block_out_channels[0], freq_shift=self.config.freq_shift)
         self.time_embedding = FlaxTimestepEmbedding(time_embed_dim, dtype=self.dtype)
 
         # down
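
Because freq_shift is a config field, it can be overridden when the module is constructed. The snippet below is a hedged usage sketch, assuming a diffusers build with the JAX/Flax extras installed; it only constructs the flax.linen module object and does not initialize any weights.

import jax.numpy as jnp

from diffusers.models.unet_2d_condition_flax import FlaxUNet2DConditionModel

# Constructing the module just records its dataclass/config fields; parameters
# are only created later through init_weights / init.
unet = FlaxUNet2DConditionModel(freq_shift=0, dtype=jnp.float32)
print(unet.freq_shift)  # 0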

src/diffusers/pipeline_flax_utils.py (+6 -8)

@@ -354,7 +354,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
             # TODO(Patrick, Suraj) - delete later
             if class_name == "DummyChecker":
                 library_name = "stable_diffusion"
-                class_name = "StableDiffusionSafetyChecker"
+                class_name = "FlaxStableDiffusionSafetyChecker"
 
             is_pipeline_module = hasattr(pipelines, library_name)
             loaded_sub_model = None
@@ -421,16 +421,14 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                    loaded_sub_model = cached_folder
 
                if issubclass(class_obj, FlaxModelMixin):
-                   # TODO(Patrick, Suraj) - Fix this as soon as Safety checker is fixed here
+                   loaded_sub_model, loaded_params = load_method(loadable_folder, from_pt=from_pt, dtype=dtype)
+                   params[name] = loaded_params
+               elif is_transformers_available() and issubclass(class_obj, FlaxPreTrainedModel):
+                   # make sure we don't initialize the weights to save time
                    if name == "safety_checker":
                        loaded_sub_model = DummyChecker()
                        loaded_params = DummyChecker()
-                   else:
-                       loaded_sub_model, loaded_params = load_method(loadable_folder, from_pt=from_pt, dtype=dtype)
-                       params[name] = loaded_params
-               elif is_transformers_available() and issubclass(class_obj, FlaxPreTrainedModel):
-                   # make sure we don't initialize the weights to save time
-                   if from_pt:
+                   elif from_pt:
                        # TODO(Suraj): Fix this in Transformers. We should be able to use `_do_init=False` here
                        loaded_sub_model = load_method(loadable_folder, from_pt=from_pt)
                        loaded_params = loaded_sub_model.params
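
Since the reordered branches above are easy to misread, here is a simplified, self-contained paraphrase of the control flow after this change; the stub classes and the loader callable are stand-ins, not the real diffusers objects.

class FlaxModelMixinStub: ...          # stands in for diffusers' FlaxModelMixin
class FlaxPreTrainedModelStub: ...     # stands in for transformers' FlaxPreTrainedModel
class DummyChecker: ...                # placeholder still used for the safety checker


def load_sub_model_sketch(name, class_obj, load_method, params, loadable_folder, from_pt=False, dtype=None):
    loaded_sub_model = loaded_params = None
    if issubclass(class_obj, FlaxModelMixinStub):
        # Flax diffusers modules are now always loaded and their params collected,
        # instead of being special-cased behind the safety-checker TODO.
        loaded_sub_model, loaded_params = load_method(loadable_folder, from_pt=from_pt, dtype=dtype)
        params[name] = loaded_params
    elif issubclass(class_obj, FlaxPreTrainedModelStub):
        if name == "safety_checker":
            # only the safety checker is still replaced by a DummyChecker placeholder
            loaded_sub_model = loaded_params = DummyChecker()
        elif from_pt:
            loaded_sub_model = load_method(loadable_folder, from_pt=from_pt)
            loaded_params = loaded_sub_model.params
    return loaded_sub_model, loaded_params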

src/diffusers/pipeline_utils.py (+4)

@@ -341,6 +341,10 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
 
         # 3. Load each module in the pipeline
         for name, (library_name, class_name) in init_dict.items():
+            # 3.1 - now that JAX/Flax is an official framework of the library, we might load from Flax names
+            if class_name.startswith("Flax"):
+                class_name = class_name[4:]
+
             is_pipeline_module = hasattr(pipelines, library_name)
             loaded_sub_model = None
 
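
The prefix handling is a plain string operation. A minimal illustration (the class names below are just examples) of how a Flax-exported name maps back to its PyTorch counterpart:

for class_name in ["FlaxUNet2DConditionModel", "UNet2DConditionModel", "FlaxPNDMScheduler"]:
    if class_name.startswith("Flax"):
        class_name = class_name[4:]
    print(class_name)
# UNet2DConditionModel
# UNet2DConditionModel
# PNDMScheduler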

src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py (-1)

@@ -178,7 +178,6 @@ def loop_body(step, args):
                jnp.array(latents_input),
                jnp.array(timestep, dtype=jnp.int32),
                encoder_hidden_states=context,
-                rngs={},
            ).sample
            # perform guidance
            noise_pred_uncond, noise_prediction_text = jnp.split(noise_pred, 2, axis=0)
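
For context on the guidance step that follows the (now rngs-free) UNet call, this is a hedged sketch of the standard classifier-free guidance combination; the guidance_scale value and the toy tensor shapes are illustrative, not taken from the diff.

import jax.numpy as jnp

# the batch holds the unconditional and the text-conditioned predictions stacked together
noise_pred = jnp.ones((2, 4, 64, 64))
noise_pred_uncond, noise_prediction_text = jnp.split(noise_pred, 2, axis=0)

guidance_scale = 7.5
noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)
print(noise_pred.shape)  # (1, 4, 64, 64)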

src/diffusers/schedulers/scheduling_ddim.py (+1)

@@ -222,6 +222,7 @@ def step(
         # 2. compute alphas, betas
         alpha_prod_t = self.alphas_cumprod[timestep]
         alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
+
         beta_prod_t = 1 - alpha_prod_t
 
         # 3. compute predicted original sample from predicted noise also called

src/diffusers/schedulers/scheduling_ddim_flax.py (+5 -1)

@@ -216,6 +216,9 @@ def step(
         # - pred_sample_direction -> "direction pointing to x_t"
         # - pred_prev_sample -> "x_t-1"
 
+        # TODO(Patrick) - eta is always 0.0 for now, allow to be set in step function
+        eta = 0.0
+
         # 1. get previous step value (=t-1)
         prev_timestep = timestep - self.config.num_train_timesteps // state.num_inference_steps
 
@@ -224,6 +227,7 @@
         # 2. compute alphas, betas
         alpha_prod_t = alphas_cumprod[timestep]
         alpha_prod_t_prev = jnp.where(prev_timestep >= 0, alphas_cumprod[prev_timestep], self.final_alpha_cumprod)
+
         beta_prod_t = 1 - alpha_prod_t
 
         # 3. compute predicted original sample from predicted noise also called
@@ -233,7 +237,7 @@
         # 4. compute variance: "sigma_t(η)" -> see formula (16)
         # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
         variance = self._get_variance(timestep, prev_timestep, alphas_cumprod)
-        std_dev_t = variance ** (0.5)
+        std_dev_t = eta * variance ** (0.5)
 
         # 5. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
         pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * model_output
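
To make the effect of pinning eta to 0 concrete, here is a minimal numeric sketch of one deterministic DDIM update using the same formulas; the alpha values and tensor shapes are toy inputs, not the scheduler's real state.

import jax.numpy as jnp

eta = 0.0
alpha_prod_t, alpha_prod_t_prev = 0.5, 0.7
beta_prod_t = 1 - alpha_prod_t

sample = jnp.ones((1, 4, 8, 8))              # x_t
model_output = 0.1 * jnp.ones((1, 4, 8, 8))  # predicted noise

# predicted x_0, from formula (12) of https://arxiv.org/pdf/2010.02502.pdf
pred_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5

# sigma_t(eta), formula (16); with eta = 0 the step is fully deterministic
variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
std_dev_t = eta * variance ** 0.5

# "direction pointing to x_t" and the x_{t-1} estimate
pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t ** 2) ** 0.5 * model_output
prev_sample = alpha_prod_t_prev ** 0.5 * pred_original_sample + pred_sample_direction
print(prev_sample.shape)  # (1, 4, 8, 8)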
