Skip to content

Commit

Permalink
Fix height and width argument names
Browse files Browse the repository at this point in the history
  • Loading branch information
filipstrand committed Aug 18, 2024
1 parent 826b58d commit c50a1fb
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 9 deletions.
6 changes: 3 additions & 3 deletions src/flux_1_schnell/flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,10 @@ def generate_image(self, seed: int, prompt: str, config: Config = Config()) -> P
return ImageUtil.to_image(decoded)

@staticmethod
def _unpack_latents(latents: mx.array, width: int, height: int) -> mx.array:
latents = mx.reshape(latents, (1, height//16, width//16, 16, 2, 2))
def _unpack_latents(latents: mx.array, height: int, width: int) -> mx.array:
latents = mx.reshape(latents, (1, width // 16, height // 16, 16, 2, 2))
latents = mx.transpose(latents, (0, 3, 1, 4, 2, 5))
latents = mx.reshape(latents, (1, 16, height//16 *2, width//16 * 2))
latents = mx.reshape(latents, (1, 16, width // 16 * 2, height // 16 * 2))
return latents

def encode(self, path: str) -> mx.array:
Expand Down
12 changes: 6 additions & 6 deletions src/flux_1_schnell/models/transformer/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,14 +67,14 @@ def predict(
return noise

@staticmethod
def _prepare_latent_image_ids(width: int, height: int) -> mx.array:
latent_height = height // 16
def _prepare_latent_image_ids(height: int, width: int) -> mx.array:
latent_width = width // 16
latent_image_ids = mx.zeros((latent_height, latent_width, 3))
latent_image_ids = latent_image_ids.at[:, :, 1].add(mx.arange(0, latent_height)[:, None])
latent_image_ids = latent_image_ids.at[:, :, 2].add(mx.arange(0, latent_width)[None, :])
latent_height = height // 16
latent_image_ids = mx.zeros((latent_width, latent_height, 3))
latent_image_ids = latent_image_ids.at[:, :, 1].add(mx.arange(0, latent_width)[:, None])
latent_image_ids = latent_image_ids.at[:, :, 2].add(mx.arange(0, latent_height)[None, :])
latent_image_ids = mx.repeat(latent_image_ids[None, :], 1, axis=0)
latent_image_ids = mx.reshape(latent_image_ids, (1, latent_height * latent_width, 3))
latent_image_ids = mx.reshape(latent_image_ids, (1, latent_width * latent_height, 3))
return latent_image_ids

@staticmethod
Expand Down

0 comments on commit c50a1fb

Please sign in to comment.