Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 73 additions & 0 deletions atom/entrypoints/image_server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
# SPDX-License-Identifier: Apache-2.0
"""OpenAI Images API compatible server for Flux."""

import argparse
import base64
import io
import time

import torch
import uvicorn
from fastapi import FastAPI
from fastapi.responses import JSONResponse
from PIL import Image
from pydantic import BaseModel

from atom.model_engine.arg_utils import EngineArgs
from atom.model_engine.diffusion_runner import DiffusionModelRunner

app = FastAPI()
runner = None


class ImageRequest(BaseModel):
prompt: str
n: int = 1
size: str = "1024x1024"
response_format: str = "b64_json"
num_inference_steps: int = 50
guidance_scale: float = 3.5


def tensor_to_b64(tensor: torch.Tensor) -> str:
img = Image.fromarray((tensor.permute(1, 2, 0).cpu().numpy() * 255).astype("uint8"))
buf = io.BytesIO()
img.save(buf, format="PNG")
return base64.b64encode(buf.getvalue()).decode()


@app.post("/v1/images/generations")
async def create_image(req: ImageRequest):
w, h = map(int, req.size.split("x"))
images = runner.generate(
[req.prompt] * req.n, h, w, req.num_inference_steps, req.guidance_scale
)
return JSONResponse(
{
"created": int(time.time()),
"data": [{"b64_json": tensor_to_b64(img)} for img in images],
}
)


@app.get("/health")
async def health():
return {"status": "healthy"}


def main():
parser = argparse.ArgumentParser()
parser.add_argument("--model", required=True)
parser.add_argument("--host", default="0.0.0.0")
parser.add_argument("--port", type=int, default=8000)
EngineArgs.add_cli_args(parser)
args = parser.parse_args()

global runner
config = EngineArgs.from_cli_args(args).create_atom_config()
runner = DiffusionModelRunner(config)
uvicorn.run(app, host=args.host, port=args.port)


if __name__ == "__main__":
main()
36 changes: 36 additions & 0 deletions atom/examples/simple_image_gen.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# SPDX-License-Identifier: Apache-2.0
"""Example: Generate images with Flux."""

import argparse
from PIL import Image
from atom.model_engine.arg_utils import EngineArgs
from atom.model_engine.diffusion_runner import DiffusionModelRunner


def main():
parser = argparse.ArgumentParser()
parser.add_argument("--model", required=True)
parser.add_argument("--prompt", default="A beautiful sunset over mountains")
parser.add_argument("--output", default="output.png")
parser.add_argument("--height", type=int, default=1024)
parser.add_argument("--width", type=int, default=1024)
parser.add_argument("--steps", type=int, default=50)
parser.add_argument("--guidance", type=float, default=3.5)
EngineArgs.add_cli_args(parser)
args = parser.parse_args()

config = EngineArgs.from_cli_args(args).create_atom_config()
runner = DiffusionModelRunner(config)
images = runner.generate(
[args.prompt], args.height, args.width, args.steps, args.guidance
)

img = Image.fromarray(
(images[0].permute(1, 2, 0).cpu().numpy() * 255).astype("uint8")
)
img.save(args.output)
print(f"Saved to {args.output}")


if __name__ == "__main__":
main()
43 changes: 43 additions & 0 deletions atom/model_engine/diffusion_runner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# SPDX-License-Identifier: Apache-2.0
"""Diffusion model runner for Flux."""

import torch
from atom.model_loader.loader import load_model
from atom.model_ops.diffusion_sampler import FlowMatchingSampler
from atom.models.flux_vae import AutoencoderKL
from atom.models.flux_text_encoder import FluxTextEncoder


class DiffusionModelRunner:
def __init__(self, config, device: str = "cuda"):
self.config = config
self.device = device
self.dtype = torch.bfloat16

self.model = load_model(config).to(self.device, self.dtype)
self.sampler = FlowMatchingSampler(num_steps=50)
self.vae = AutoencoderKL().to(self.device, self.dtype)
self.text_encoder = FluxTextEncoder(device=device)

@torch.no_grad()
def generate(
self,
prompts: list,
height: int = 1024,
width: int = 1024,
num_steps: int = 50,
guidance_scale: float = 3.5,
) -> torch.Tensor:
B = len(prompts)
H, W = height // 8, width // 8
latents = torch.randn(B, 16, H, W, device=self.device, dtype=self.dtype)

text_emb = self.text_encoder.encode(prompts)
self.sampler.set_timesteps(num_steps, self.device)

for i, t in enumerate(self.sampler.timesteps):
timesteps = t.expand(B)
noise_pred = self.model(latents, timesteps, text_emb, guidance_scale)
latents = self.sampler.step(noise_pred, timesteps, latents, i)

return self.vae.decode(latents).clamp(-1, 1) * 0.5 + 0.5
2 changes: 2 additions & 0 deletions atom/model_engine/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@
"DeepseekV3ForCausalLM": "atom.models.deepseek_v2.DeepseekV2ForCausalLM",
"DeepseekV32ForCausalLM": "atom.models.deepseek_v2.DeepseekV2ForCausalLM",
"GptOssForCausalLM": "atom.models.gpt_oss.GptOssForCausalLM",
# Diffusion Transformer models (NOTE: requires diffusion-specific inference pipeline)
"FluxTransformer2DModel": "atom.models.flux.FluxForImageGeneration",
}
# seed = 34567
# np.random.seed(seed)
Expand Down
28 changes: 28 additions & 0 deletions atom/model_ops/diffusion_sampler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# SPDX-License-Identifier: Apache-2.0
"""Flow matching sampler for diffusion models."""

import torch


class FlowMatchingSampler:
def __init__(self, num_steps: int = 50, shift: float = 1.0):
self.num_steps = num_steps
self.shift = shift
self.timesteps = None

def set_timesteps(self, num_steps: int = None, device: str = "cuda"):
n = num_steps or self.num_steps
sigmas = torch.linspace(1, 0, n + 1, device=device)
sigmas = self.shift * sigmas / (1 + (self.shift - 1) * sigmas)
self.timesteps = sigmas[:-1]
self.sigmas = sigmas

def step(
self,
model_output: torch.Tensor,
timestep: torch.Tensor,
sample: torch.Tensor,
step_idx: int,
) -> torch.Tensor:
dt = self.sigmas[step_idx + 1] - self.sigmas[step_idx]
return sample + model_output * dt
Loading
Loading