forked from xdit-project/xDiT
-
Notifications
You must be signed in to change notification settings - Fork 0
/
sdxl_example.py
35 lines (27 loc) · 1 KB
/
sdxl_example.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import torch
from legacy.pipefuser.pipelines import DistriSDXLPipeline
from legacy.pipefuser.utils import DistriConfig
import argparse
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--model_id",
default="stabilityai/stable-diffusion-xl-base-1.0",
type=str,
help="Path or Id to the pretrained model.",
)
args = parser.parse_args()
distri_config = DistriConfig(height=1024, width=1024, warmup_steps=4)
pipeline = DistriSDXLPipeline.from_pretrained(
distri_config=distri_config,
pretrained_model_name_or_path=args.model_id,
variant="fp16",
use_safetensors=True,
)
pipeline.set_progress_bar_config(disable=distri_config.rank != 0)
image = pipeline(
prompt="Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
generator=torch.Generator(device="cuda").manual_seed(233),
).images[0]
if distri_config.rank == 0:
image.save("astronaut.png")