JAX Implementation of Black Forest Labs' FLUX Family of Models. JFLUX is a port of the official PyTorch implementation to JAX, built on the flax.nnx framework.
- Full JAX Implementation: The FLUX.1 family of models, originally built in PyTorch, has been fully re-implemented in JAX.
- Support for FLUX Variants: Includes support for the FLUX.1 schnell and dev variants (tested with schnell on an A100).
- Open-Source and Community-Driven: We encourage contributions and collaboration to further improve the performance and capabilities of this repository. (Refer to the Roadmap Section for more details)
To get started, follow these simple steps:
- Clone the repository:
git clone https://github.com/ml-gde/jflux.git
cd jflux
- With Python >= 3.10, run the model with a prompt (install `uv` beforehand):
uv run jflux
Note
This might help 😃
uv python install 3.10
uv python pin 3.10
If you don't have `uv`, you can run one of the following commands to get up and running.
For GPU:
cd jflux && pip install -U pip && pip install -e .[gpu]
For TPU:
cd jflux && pip install -U pip && pip install -e .[tpu]
Once installed, JFLUX can be used to generate high-quality images from text prompts via a CLI. For example:
uv run jflux --prompt "A cute dog"
Additional options are available to adjust resolution, sampling steps, and more.
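For instance, several of the flags documented in the CLI overview below can be combined; the values here are only illustrative:

```bash
uv run jflux \
  --prompt "a misty forest at sunrise" \
  --width 1024 --height 768 \
  --seed 42 \
  --num_steps 4 \
  --output_dir output
```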
CLI Overview
NAME
jflux - Sample the flux model. Either interactively (set `--loop`) or run for a single image.
SYNOPSIS
jflux <flags>
DESCRIPTION
Sample the flux model. Either interactively (set `--loop`) or run for a single image.
FLAGS
--name=NAME
Type: str
Default: 'flux-schnell'
Name of the model to load
-w, --width=WIDTH
Type: int
Default: 1360
width of the sample in pixels (should be a multiple of 16)
-h, --height=HEIGHT
Type: int
Default: 768
height of the sample in pixels (should be a multiple of 16)
-s, --seed=SEED
Type: Optional[int | None]
Default: None
Set a seed for sampling
-p, --prompt=PROMPT
Type: str
Default: 'a photo of a forest with mist swirling around the tree trun...'
Prompt used for sampling
--num_steps=NUM_STEPS
Type: Optional[int | None]
Default: None
number of sampling steps (default 4 for schnell, 50 for guidance distilled)
-l, --loop=LOOP
Type: bool
Default: False
start an interactive session and sample multiple times
-g, --guidance=GUIDANCE
Type: float
Default: 3.5
guidance value used for guidance distillation
--offload=OFFLOAD
Type: bool
Default: True
--output_dir=OUTPUT_DIR
Type: str
Default: 'output'
-a, --add_sampling_metadata=ADD_SAMPLING_METADATA
Type: bool
Default: True
Add the prompt to the image Exif metadata
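For repeated sampling, the `--loop` flag documented above starts an interactive session; a short example (the output directory name is illustrative):

```bash
uv run jflux --loop --output_dir samples
```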
There are a few challenges we’re actively working to resolve. We encourage contributions from the community to help improve the project:
- No TPU Support: As of now, the code does not support TPUs.
- High VRAM Requirements: Running `flux-schnell` currently requires 40 GB of VRAM (A100 or higher).
- Image Generation Time: Image generation takes longer than expected, especially for high-resolution outputs (~51 s on an A100).
- bfloat16 Upcasting: Weights are upcast from bfloat16 to fp32 because NumPy does not handle bf16 natively, leading to some inefficiencies (see the sketch after this list).
- Porting Weights: Currently, we load PyTorch weights and convert them to JAX arrays during inference; ideally, after the first run, the converted `nnx` weights would be serialised to avoid redundant conversions on every run.
- Pure JAX Inference: As of now, we load the `t5` and `clip` models as torch `Module`s and wrap their outputs as JAX arrays. This prevents OOMs during inference; a pure JAX inference pipeline would be ideal.
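To illustrate the bf16 round-trip mentioned in the upcasting item above, here is a minimal sketch; the helper name is hypothetical and this is not the exact code used in the repository:

```python
import torch
import jax.numpy as jnp

def torch_bf16_to_jax(t: torch.Tensor) -> jnp.ndarray:
    """Hypothetical helper: move a bf16 torch tensor into JAX via an fp32 hop."""
    # NumPy has no native bfloat16, so the tensor is upcast to fp32 first...
    as_fp32 = t.detach().to(torch.float32).cpu().numpy()
    # ...and cast back down to bf16 once it lives in a JAX array.
    return jnp.asarray(as_fp32, dtype=jnp.bfloat16)

weight = torch.randn(8, 8, dtype=torch.bfloat16)
jax_weight = torch_bf16_to_jax(weight)
print(jax_weight.dtype)  # bfloat16
```

A possible alternative is to reinterpret the raw bf16 buffer with ml_dtypes.bfloat16 (the library discussed in the acknowledgements below), which would avoid the intermediate fp32 copy.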
Roadmap
- Add TPU support (#19) (fixed in #25)
- Optimize VRAM usage with gradient checkpointing (#20)
- Explore further optimizations for image generation time
- Improve the handling of bfloat16 tensors with JAX
- Serialise `nnx` weights during CLI run (#22) (a sketch follows this list)
- Port `t5` and `clip` models to `nnx.Module` (#21)
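For the weight-serialisation item (#22), a minimal sketch of how `nnx` state could be saved and restored with Orbax, using a stand-in model and path rather than the actual jflux modules:

```python
from flax import nnx
import orbax.checkpoint as ocp

class TinyModel(nnx.Module):  # stand-in for the real flux nnx module
    def __init__(self, rngs: nnx.Rngs):
        self.proj = nnx.Linear(4, 4, rngs=rngs)

    def __call__(self, x):
        return self.proj(x)

# First run: split the module into structure + weights and save the weights once.
model = TinyModel(rngs=nnx.Rngs(0))
graphdef, state = nnx.split(model)

ckpt_dir = ocp.test_utils.erase_and_create_empty('/tmp/jflux-ckpt')
checkpointer = ocp.StandardCheckpointer()
checkpointer.save(ckpt_dir / 'state', state)

# Later runs: rebuild an abstract (shape-only) model, restore into it, and merge.
abstract_model = nnx.eval_shape(lambda: TinyModel(rngs=nnx.Rngs(0)))
graphdef, abstract_state = nnx.split(abstract_model)
restored_state = checkpointer.restore(ckpt_dir / 'state', abstract_state)
model = nnx.merge(graphdef, restored_state)
```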
We’d like to extend our gratitude to those who have supported and guided this project:
- Christian Garcia (google/flax#4273 (comment))
- Jake VanderPlas (jax-ml/ml_dtypes#81 (comment))
- Sayak Paul for major brainstorming sessions
- Aakash Kumar Nain for helping us with JAX issues
- Thanks to the ML Developer Programs team at Google for providing GCP credits.
Special thanks to Black Forest Labs for the original FLUX implementation.
This project is licensed under the same license as the black-forest-labs/flux repository.
- Original Implementation: black-forest-labs/flux