diff --git a/src/flux_1_schnell/flux.py b/src/flux_1_schnell/flux.py index 50dade0f..a9feeb07 100644 --- a/src/flux_1_schnell/flux.py +++ b/src/flux_1_schnell/flux.py @@ -61,10 +61,10 @@ def generate_image(self, seed: int, prompt: str, config: Config = Config()) -> P return ImageUtil.to_image(decoded) @staticmethod - def _unpack_latents(latents: mx.array, width: int, height: int) -> mx.array: - latents = mx.reshape(latents, (1, height//16, width//16, 16, 2, 2)) + def _unpack_latents(latents: mx.array, height: int, width: int) -> mx.array: + latents = mx.reshape(latents, (1, width // 16, height // 16, 16, 2, 2)) latents = mx.transpose(latents, (0, 3, 1, 4, 2, 5)) - latents = mx.reshape(latents, (1, 16, height//16 *2, width//16 * 2)) + latents = mx.reshape(latents, (1, 16, width // 16 * 2, height // 16 * 2)) return latents def encode(self, path: str) -> mx.array: diff --git a/src/flux_1_schnell/models/transformer/transformer.py b/src/flux_1_schnell/models/transformer/transformer.py index 76768857..1d132c9c 100644 --- a/src/flux_1_schnell/models/transformer/transformer.py +++ b/src/flux_1_schnell/models/transformer/transformer.py @@ -67,14 +67,14 @@ def predict( return noise @staticmethod - def _prepare_latent_image_ids(width: int, height: int) -> mx.array: - latent_height = height // 16 + def _prepare_latent_image_ids(height: int, width: int) -> mx.array: latent_width = width // 16 - latent_image_ids = mx.zeros((latent_height, latent_width, 3)) - latent_image_ids = latent_image_ids.at[:, :, 1].add(mx.arange(0, latent_height)[:, None]) - latent_image_ids = latent_image_ids.at[:, :, 2].add(mx.arange(0, latent_width)[None, :]) + latent_height = height // 16 + latent_image_ids = mx.zeros((latent_width, latent_height, 3)) + latent_image_ids = latent_image_ids.at[:, :, 1].add(mx.arange(0, latent_width)[:, None]) + latent_image_ids = latent_image_ids.at[:, :, 2].add(mx.arange(0, latent_height)[None, :]) latent_image_ids = mx.repeat(latent_image_ids[None, :], 1, axis=0) - latent_image_ids = mx.reshape(latent_image_ids, (1, latent_height * latent_width, 3)) + latent_image_ids = mx.reshape(latent_image_ids, (1, latent_width * latent_height, 3)) return latent_image_ids @staticmethod