Here we provide an efficient MindSpore implementation of OpenSora, an open-source project that aims to foster innovation, creativity, and inclusivity within the field of content creation.
This repository is built on the models and code released by HPC-AI Tech. We are grateful for their exceptional work and generous contribution to open source.
Official News from HPC-AI Tech | MindSpore Support |
---|---|
[2024.06.17] 🔥 HPC-AI released Open-Sora 1.2, which includes 3D-VAE, rectified flow, and score condition. The video quality is greatly improved. [checkpoints] [report] [blog] | Text-to-Video |
[2024.04.25] 🤗 HPC-AI Tech released the Gradio demo for Open-Sora on Hugging Face Spaces. | N.A. |
[2024.04.25] 🔥 HPC-AI Tech released Open-Sora 1.1, which supports 2s~15s, 144p to 720p, any aspect ratio text-to-image, text-to-video, image-to-video, video-to-video, infinite time generation. In addition, a full video processing pipeline is released. [checkpoints] [report] | Image/Video-to-Video; Infinite time generation; Variable resolutions, aspect ratios, durations |
[2024.03.18] HPC-AI Tech released Open-Sora 1.0, a fully open-source project for video generation. | ✅ VAE + STDiT training and inference |
[2024.03.04] HPC-AI Tech's Open-Sora provides training with a 46% cost reduction [blog] | ✅ Parallel training on Ascend devices |
MindSpore | Ascend driver | Firmware | CANN toolkit/kernel |
---|---|---|---|
2.3.1 | 23.0.3 | 7.1.0.9.220 | 8.0.RC2.beta1 |
The following videos are generated with MindSpore on Ascend 910*.
4s 720×1280 | 4s 720×1280 | 4s 720×1280 |
---|---|---|
000-A-Japanese-tram-glides-through-the-snowy-streets-of-a.mp4 |
006-a-close-up-shot-of-a-woman-standing-in-a-dimly.mp4 |
015-a-cozy-living-room-scene-with-a-christmas-tree-in.mp4 |
Tip
To generate better-looking videos, you can try generating in two stages: Text-to-Image and then Image-to-Video.
Demo
Image/video-conditioned generation demos (input/output media omitted).
Video connection demo with a start frame, an end frame, and the caption "A breathtaking sunrise scene." (output omitted).
Text-to-video demos (outputs omitted) for the captions:
- "Bright scene, aerial view, ancient city, fantasy, gorgeous light, mirror reflection, high detail, wide angle lens."
- "A small cactus with a happy face in the Sahara desert."
Demo
Videos are downsampled to .gif for display. Click for original videos. Prompts are trimmed for display; see here for full prompts.
- 📍 Open-Sora 1.2 released. Model weights are available here. See report 1.2 for more details.
- ✅ Support rectified flow scheduling.
- ✅ Support more conditioning including fps, aesthetic score, motion strength and camera motion.
- ✅ Trained our 3D-VAE for temporal dimension compression.
- 📍 Open-Sora 1.1 with the following features:
- ✅ Improved ST-DiT architecture includes Rotary Position Embedding (RoPE), QK Normalization, longer text length, etc.
- ✅ Support image and video conditioning and video editing, and thus support animating images, connecting videos, etc.
- ✅ Support training with any resolution, aspect ratio, and duration.
- 📍 Open-Sora 1.0 with the following features:
- ✅ Text-to-video generation in 256x256 or 512x512 resolution and up to 64 frames.
- ✅ Three-stage training: i) 16x256x256 video pretraining, ii) 16x512x512 video fine-tuning, and iii) 64x512x512 video fine-tuning.
- ✅ Optimized training recipes for the MindSpore+Ascend framework (see `configs/opensora/train/xxx_ms.yaml`).
- ✅ Acceleration methods: flash attention, recompute (gradient checkpointing), data sink, mixed precision, and graph compilation.
- ✅ Data parallelism + optimizer parallelism, allowing training on 300x512x512 videos.
- ✅ Following the findings in OpenSora, we also adopt the VAE from Stable Diffusion for video latent encoding.
- ✅ We pick the STDiT model as our video diffusion transformer following the best practice in OpenSora.
- ✅ Support T5 text conditioning.
- Evaluation pipeline.
- Complete the data processing pipeline (including dense optical flow, aesthetic scores, text-image similarity, etc.).
- Installation
- Model Weights
- Inference
- Data Processing
- Training
- Evaluation
- VAE Training & Evaluation
- Contribution
- Acknowledgement
Other useful documents and links are listed below.
- Repo structure: structure.md
- Please install MindSpore 2.3.1 according to the MindSpore official website and install CANN 8.0.RC2.beta1 as recommended by the official installation website.
- Install requirements:
pip install -r requirements.txt
In case the `decord` package is not available, try `pip install eva-decord`.
For EulerOS, instructions on ffmpeg and decord installation are as follows.
1. install ffmpeg 4, referring to https://ffmpeg.org/releases
wget https://ffmpeg.org/releases/ffmpeg-4.0.1.tar.bz2 --no-check-certificate
tar -xvf ffmpeg-4.0.1.tar.bz2
mv ffmpeg-4.0.1 ffmpeg
cd ffmpeg
./configure --enable-shared # --enable-shared is needed for sharing libavcodec with decord
make -j 64
make install
2. install decord, referring to https://github.com/dmlc/decord?tab=readme-ov-file#install-from-source
git clone --recursive https://github.com/dmlc/decord
cd decord
rm -rf build && mkdir build && cd build
cmake .. -DUSE_CUDA=0 -DCMAKE_BUILD_TYPE=Release
make -j 64
make install
cd ../python
python3 setup.py install --user
Model | Model size | Data | URL |
---|---|---|---|
STDiT3 (Diffusion) | 1.1B | 30M | Download |
VAE | 384M | 3M | Download |
- Convert STDiT3 to MS checkpoint:
python tools/convert_pt2ms.py --src /path/to/OpenSora-STDiT-v3/model.safetensors --target models/opensora_stdit_v3.ckpt
- Convert VAE to MS checkpoint:
python convert_vae_3d.py --src /path/to/OpenSora-VAE-v1.2/model.safetensors --target models/OpenSora-VAE-v1.2/model.ckpt
- The T5 model is identical to the one used in OpenSora 1.0 and can be downloaded from the links below.
Instructions
- STDiT:
Stage | Resolution | Model Size | Data | #iterations | URL |
---|---|---|---|---|---|
2 | mainly 144p & 240p | 700M | 10M videos + 2M images | 100k | Download |
3 | 144p to 720p | 700M | 500K HQ videos + 1M images | 4k | Download |
Convert to MS checkpoint:
python tools/convert_pt2ms.py --src /path/to/OpenSora-STDiT-v2-stage3/model.safetensors --target models/opensora_v1.1_stage3.ckpt
- T5 and VAE models are identical to those used in OpenSora 1.0 and can be downloaded from the links below.
Instructions
Please prepare the model checkpoints of T5, VAE, and STDiT and put them under the `models/` folder as follows.
- T5: Download the DeepFloyd/t5-v1_1-xxl folder and put it under `models/`. Convert to a MindSpore checkpoint:
python tools/convert_t5.py --src models/t5-v1_1-xxl/pytorch_model-00001-of-00002.bin models/t5-v1_1-xxl/pytorch_model-00002-of-00002.bin --target models/t5-v1_1-xxl/model.ckpt
- VAE: Download the safetensors checkpoint from here. Convert to a MindSpore checkpoint:
python tools/convert_vae.py --src /path/to/sd-vae-ft-ema/diffusion_pytorch_model.safetensors --target models/sd-vae-ft-ema.ckpt
- STDiT: Download `OpenSora-v1-16x256x256.pth` / `OpenSora-v1-HQ-16x256x256.pth` / `OpenSora-v1-HQ-16x512x512.pth` from here. Convert to a MindSpore checkpoint:
python tools/convert_pt2ms.py --src /path/to/OpenSora-v1-16x256x256.pth --target models/OpenSora-v1-16x256x256.ckpt
Training order: 16x256x256 → 16x256x256 HQ → 16x512x512 HQ. These model weights are partially initialized from PixArt-α. The number of parameters is 724M. More information about training can be found in HPC-AI Tech's report. More about the dataset can be found in datasets.md from HPC-AI Tech. HQ means high quality.
- PixArt-α: Download the pth checkpoint from here (for training only). Convert to a MindSpore checkpoint:
python tools/convert_pt2ms.py --src /path/to/PixArt-XL-2-512x512.pth --target models/PixArt-XL-2-512x512.ckpt
# OSv1.2
python scripts/inference.py --config configs/opensora-v1-2/inference/sample_iv2v.yaml --ckpt_path models/opensora_stdit_v3.ckpt
# OSv1.1
python scripts/inference.py --config configs/opensora-v1-1/inference/sample_iv2v.yaml --ckpt_path /path/to/your/opensora-v1-1.ckpt
For parallel inference, please use `mpirun` or `msrun`, and append `--use_parallel=True` to the inference command, referring to `scripts/run/run_infer_os_v1.1_t2v_parallel.sh`.
In `sample_iv2v.yaml`, provide such information as `loop`, `condition_frame_length`, `captions`, `mask_strategy`, and `reference_path`.
See here for more details.
For inference with sequence parallelism using multiple NPUs in Open-Sora 1.2, please use `msrun` and append `--use_parallel True` and `--enable_sequence_parallelism True` to the inference command, referring to `scripts/run/run_infer_sequence_parallel.sh`.
To generate a video from text, you can use `sample_t2v.yaml` or set `--reference_path` to an empty string `''` when using `sample_iv2v.yaml`.
python scripts/inference.py --config configs/opensora-v1-1/inference/sample_t2v.yaml --ckpt_path /path/to/your/opensora-v1-1.ckpt
We evaluate the inference performance of text-to-video generation by measuring the average sampling time per step and the total sampling time of a video.
All experiments are tested on Ascend 910* with MindSpore 2.3.1 in graph mode.
model name | cards | batch size | resolution | jit level | precision | scheduler | step | graph compile | s/step | s/video | recipe |
---|---|---|---|---|---|---|---|---|---|---|---|
STDiT2-XL/2 | 1 | 1 | 16x640x360 | O0 | bf16 | DDPM | 100 | 1~2 mins | 1.56 | 156.00 | yaml |
STDiT3-XL/2 | 1 | 1 | 51x720x1280 | O0 | bf16 | RFlow | 30 | 1~2 mins | 5.88 | 176.40 | yaml |
STDiT3-XL/2 | 1 | 1 | 102x720x1280 | O0 | bf16 | RFlow | 30 | 1~2 mins | 13.71 | 411.30 | yaml |
Instructions
You can run text-to-video inference via the script `scripts/inference.py` as follows.
# Sample 16x256x256 videos
python scripts/inference.py --config configs/opensora/inference/stdit_256x256x16.yaml --ckpt_path models/OpenSora-v1-HQ-16x256x256.ckpt --prompt_path /path/to/prompt.txt
# Sample 16x512x512 videos
python scripts/inference.py --config configs/opensora/inference/stdit_512x512x16.yaml --ckpt_path models/OpenSora-v1-HQ-16x512x512.ckpt --prompt_path /path/to/prompt.txt
# Sample 64x512x512 videos
python scripts/inference.py --config configs/opensora/inference/stdit_512x512x64.yaml --ckpt_path /path/to/your/opensora-v1.ckpt --prompt_path /path/to/prompt.txt
For parallel inference, please use `mpirun` or `msrun`, and append `--use_parallel=True` to the inference command, referring to `scripts/run/run_infer_t2v_parallel.sh`.
We also provide a three-stage sampling script, `run_sole_3stages.sh`, which reduces memory usage by decomposing the whole pipeline into text embedding, text-to-video latent sampling, and VAE decoding.
For more usage information on the inference script, please run `python scripts/inference.py -h`.
We evaluate the inference performance of text-to-video generation by measuring the average sampling time per step and the total sampling time of a video.
All experiments are tested on Ascend 910* with MindSpore 2.3.1 in graph mode.
model name | cards | batch size | resolution | jit level | precision | scheduler | step | graph compile | s/step | s/video | recipe |
---|---|---|---|---|---|---|---|---|---|---|---|
STDiT-XL/2 | 1 | 4 | 16x256x256 | O0 | fp32 | DDPM | 100 | 2~3 mins | 0.39 | 39.22 | yaml |
STDiT-XL/2 | 1 | 1 | 16x512x512 | O0 | fp32 | DDPM | 100 | 2~3 mins | 1.85 | 185.00 | yaml |
STDiT-XL/2 | 1 | 1 | 64x512x512 | O0 | bf16 | DDPM | 100 | 2~3 mins | 2.78 | 278.45 | yaml |
⚠️ Note: When running parallel inference scripts under `scripts/run/` on ModelArts, please `unset RANK_TABLE_FILE` before the inference starts.
Currently, we are developing the complete pipeline for data processing from raw videos to high-quality text-video pairs. We provide the data processing tools as follows.
The text-video pair data should be organized as in the following example.
.
├── video_caption.csv
├── video_folder
│ ├── part01
│ │ ├── vid001.mp4
│ │ ├── vid002.mp4
│ │ └── ...
│ └── part02
│ ├── vid001.mp4
│ ├── vid002.mp4
│ └── ...
The `video_folder` contains all the video files. The csv file `video_caption.csv` records the relative video path and its text caption on each line, as follows.
video,caption
video_folder/part01/vid001.mp4,a cartoon character is walking through
video_folder/part01/vid002.mp4,a red and white ball with an angry look on its face
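As a quick sanity check of this layout, the following minimal sketch (assuming pandas is installed; the paths are illustrative and not part of the repository) verifies that every annotated clip exists on disk:

```python
import os
import pandas as pd

csv_path = "video_caption.csv"   # hypothetical annotation file, formatted as above
data_root = "."                  # directory that contains video_folder/

df = pd.read_csv(csv_path)
missing = [p for p in df["video"] if not os.path.isfile(os.path.join(data_root, p))]
print(f"{len(df)} clips annotated, {len(missing)} video files missing")
```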
For acceleration, we pre-compute the T5 embeddings before training STDiT.
python scripts/infer_t5.py \
--csv_path /path/to/video_caption.csv \
--output_path /path/to/text_embed_folder \
--model_max_length 300 # 300 for OpenSora v1.2, 200 for OpenSora v1.1, 120 for OpenSora 1.0
OpenSora v1.0 uses a text embedding sequence length of 120 by default. If you want to generate text embeddings for OpenSora v1.1, please change `model_max_length` to 200.
After running, the text embeddings (saved as an npz file per caption) will be in `output_path`. Please change `csv_path` to your video-caption annotation file accordingly.
If the storage budget is sufficient, you may also cache the video embeddings by running
python scripts/infer_vae.py \
--csv_path /path/to/video_caption.csv \
--video_folder /path/to/video_folder \
--output_path /path/to/video_embed_folder \
--vae_checkpoint models/sd-vae-ft-ema.ckpt \
--image_size 512
For parallel running, please refer to `scripts/run/run_infer_vae_parallel.sh`.
For more usage information, please check `python scripts/infer_vae.py -h`.
After running, the VAE latents (saved as an npz file per video) will be in `output_path`.
Finally, the training data should be organized as follows.
.
├── video_caption.csv
├── video_folder
│ ├── part01
│ │ ├── vid001.mp4
│ │ ├── vid002.mp4
│ │ └── ...
│ └── part02
│ ├── vid001.mp4
│ ├── vid002.mp4
│ └── ...
├── text_embed_folder
│ ├── part01
│ │ ├── vid001.npz
│ │ ├── vid002.npz
│ │ └── ...
│ └── part02
│ ├── vid001.npz
│ ├── vid002.npz
│ └── ...
├── video_embed_folder # optional
│ ├── part01
│ │ ├── vid001.npz
│ │ ├── vid002.npz
│ │ └── ...
│ └── part02
│ ├── vid001.npz
│ ├── vid002.npz
│ └── ...
Each npz file contains data for the following keys:
- `latent_mean`: mean of the VAE latent distribution
- `latent_std`: std of the VAE latent distribution
- `fps`: video fps
- `ori_size`: original size (h, w) of the video
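As an illustration of how these cached statistics can be used, the sketch below re-samples a latent from the stored distribution (a minimal example assuming the npz layout above; the file name and shapes are illustrative, and the training script handles this internally):

```python
import numpy as np

# Hypothetical cached file; keys follow the layout described above.
cache = np.load("video_embed_folder/part01/vid001.npz")
latent_mean, latent_std = cache["latent_mean"], cache["latent_std"]

# Draw a latent sample from the stored distribution (reparameterization trick).
noise = np.random.randn(*latent_mean.shape).astype(latent_mean.dtype)
latent = latent_mean + latent_std * noise
print("latent shape:", latent.shape, "| fps:", cache["fps"], "| original size:", cache["ori_size"])
```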
After caching the VAE latents, you can use them for STDiT training by passing `--vae_latent_folder=/path/to/video_embed_folder` to the training script `scripts/train.py`.
If there are multiple folders named in the `latent_{h}x{w}` format under the `--vae_latent_folder` directory, one of the resolutions will be selected randomly during training. For example:
video_embed_folder
├── latent_576x1024
│ ├── vid001.npz
│ ├── vid002.npz
│ └── ...
└── latent_1024x576
├── vid001.npz
├── vid002.npz
└── ...
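The random selection can be pictured with the small sketch below (an illustration of the described behavior, not the repository's actual data-loading code; folder names are illustrative):

```python
import os
import random
import re

vae_latent_folder = "video_embed_folder"  # hypothetical root passed via --vae_latent_folder

# Collect candidate resolution folders named latent_{h}x{w}.
candidates = [d for d in os.listdir(vae_latent_folder) if re.fullmatch(r"latent_\d+x\d+", d)]

# One resolution is picked at random for each training sample.
chosen = random.choice(candidates)
print("using cached latents from:", os.path.join(vae_latent_folder, chosen))
```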
Once you prepare the data in a csv file, you may run the following commands to launch training on a single card.
# standalone training for stage 2
export MS_DEV_ENABLE_KERNEL_PACKET=on
python scripts/train.py --config configs/opensora-v1-2/train/train_stage2.yaml \
--csv_path /path/to/video_caption.csv \
--video_folder /path/to/video_folder \
--text_embed_folder /path/to/text_embed_folder
`text_embed_folder` is required and is used to speed up the training. You can find instructions on how to generate the T5 embeddings here.
For parallel training, use `msrun` along with `--use_parallel=True`:
# distributed training for stage 2
export MS_DEV_ENABLE_KERNEL_PACKET=on
msrun --worker_num=8 --local_worker_num=8 --log_dir=$output_dir \
python scripts/train.py --config configs/opensora-v1-2/train/train_stage2.yaml \
--csv_path /path/to/video_caption.csv \
--video_folder /path/to/video_folder \
--text_embed_folder /path/to/text_embed_folder \
--use_parallel True
You can modify the training configuration, including hyper-parameters and data settings, in the YAML file specified by the `--config` argument.
OpenSora v1.2 supports training with multiple resolutions, aspect ratios, and numbers of frames based on the bucket method.
To enable dynamic training for STDiT3, first set the `bucket_config` to fit your datasets and tasks. An example (from `configs/opensora-v1-2/train/train_stage2.yaml`) is
bucket_config:
# Structure: "resolution": { num_frames: [ keep_prob, batch_size ] }
"144p": { 1: [ 1.0, 475 ], 51: [ 1.0, 51 ], 102: [ [ 1.0, 0.33 ], 27 ], 204: [ [ 1.0, 0.1 ], 13 ], 408: [ [ 1.0, 0.1 ], 6 ] }
"256": { 1: [ 0.4, 297 ], 51: [ 0.5, 20 ], 102: [ [ 0.5, 0.33 ], 10 ], 204: [ [ 0.5, 1.0 ], 5 ], 408: [ [ 0.5, 1.0 ], 2 ] }
"240p": { 1: [ 0.3, 297 ], 51: [ 0.4, 20 ], 102: [ [ 0.4, 0.33 ], 10 ], 204: [ [ 0.4, 1.0 ], 5 ], 408: [ [ 0.4, 1.0 ], 2 ] }
"360p": { 1: [ 0.5, 141 ], 51: [ 0.15, 8 ], 102: [ [ 0.3, 0.5 ], 4 ], 204: [ [ 0.3, 1.0 ], 2 ], 408: [ [ 0.5, 0.5 ], 1 ] }
"512": { 1: [ 0.4, 141 ], 51: [ 0.15, 8 ], 102: [ [ 0.2, 0.4 ], 4 ], 204: [ [ 0.2, 1.0 ], 2 ], 408: [ [ 0.4, 0.5 ], 1 ] }
"480p": { 1: [ 0.5, 89 ], 51: [ 0.2, 5 ], 102: [ 0.2, 2 ], 204: [ 0.1, 1 ] }
"720p": { 1: [ 0.1, 36 ], 51: [ 0.03, 1 ] }
"1024": { 1: [ 0.1, 36 ], 51: [ 0.02, 1 ] }
"1080p": { 1: [ 0.01, 5 ] }
"2048": { 1: [ 0.01, 5 ] }
Since the optimal bucket config can vary from device to device, we have tuned and provided bucket configs that are more balanced on Ascend + MindSpore in `configs/opensora-v1-2/train/{stage}_ms.yaml`. You may use them for better training performance.
More details on the bucket configuration can be found in Multi-resolution Training with Buckets.
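To make the `keep_prob`/`batch_size` semantics concrete, here is a minimal sketch of how a bucket entry could be interpreted (an illustration of the configuration structure only, not the repository's actual bucketing logic; the nested-list keep_prob entries of the real config are omitted):

```python
import random

# Simplified mirror of the structure above: resolution -> {num_frames: (keep_prob, batch_size)}.
bucket_config = {
    "240p": {1: (0.3, 297), 51: (0.4, 20), 204: (0.4, 5)},
    "720p": {1: (0.1, 36), 51: (0.03, 1)},
}

def lookup(resolution: str, num_frames: int):
    """Probabilistically keep a (resolution, num_frames) sample and return its per-card batch size."""
    keep_prob, batch_size = bucket_config[resolution][num_frames]
    kept = random.random() < keep_prob
    return kept, batch_size

kept, bs = lookup("240p", 51)
print(f"keep sample: {kept}, per-card batch size if kept: {bs}")
```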
The instructions for launching the dynamic training task are similar to the previous section. An example running script is `scripts/run/run_train_os1.2_stage2.sh`.
Instructions
Once you prepare the data in a csv file, you may run the following commands to launch training on a single card.
# standalone training for stage 1
python scripts/train.py --config configs/opensora-v1-1/train/train_stage1.yaml \
--csv_path /path/to/video_caption.csv \
--video_folder /path/to/video_folder \
--text_embed_folder /path/to/text_embed_folder \
--vae_latent_folder /path/to/video_embed_folder
`text_embed_folder` and `vae_latent_folder` are optional and are used to speed up the training. You can find more in T5 text embeddings and VAE video embeddings.
For parallel training, use `msrun` along with `--use_parallel=True`:
# distributed training for stage 1
msrun --master_port=8200 --worker_num=8 --local_worker_num=8 --log_dir=$output_dir \
python scripts/train.py --config configs/opensora-v1-1/train/train_stage1.yaml \
--csv_path /path/to/video_caption.csv \
--video_folder /path/to/video_folder \
--text_embed_folder /path/to/text_embed_folder \
--vae_latent_folder /path/to/video_embed_folder \
--use_parallel True
OpenSora v1.1 supports training with multiple resolutions, aspect ratios, and a variable number of frames. This can be enabled in one of two ways:
- Provide variable-sized VAE embeddings with the `--vae_latent_folder` option.
- Use `bucket_config` for training with videos in their original format. More on the bucket configuration can be found in Multi-resolution Training with Buckets.
A detailed running command can be found in `scripts/run/run_train_os_v1.1_stage2.sh`.
Instructions
Once the training data including the T5 text embeddings is prepared, you can run the following commands to launch training.
# standalone training, 16x256x256
python scripts/train.py --config configs/opensora/train/stdit_256x256x16_ms.yaml \
--csv_path /path/to/video_caption.csv \
--video_folder /path/to/video_folder \
--text_embed_folder /path/to/text_embed_folder
To use the cached video embeddings, please replace `--video_folder` with `--video_embed_folder` and pass the path to the video embedding folder.
For parallel training, please use `msrun` and pass `--use_parallel=True`:
# 8 NPUs, 64x512x512
msrun --master_port=8200 --worker_num=8 --local_worker_num=8 --log_dir=$output_dir \
python scripts/train.py --config configs/opensora/train/stdit_512x512x64_ms.yaml \
--csv_path /path/to/video_caption.csv \
--video_folder /path/to/video_folder \
--text_embed_folder /path/to/text_embed_folder \
--use_parallel True
To train in bfloat16 precision, please pass `--global_bf16=True`.
For more usage information, please check `python scripts/train.py -h`.
You may also see the example shell scripts in `scripts/run` for quick reference.
Open-Sora 1.2 based on MindSpore and Ascend 910* supports video generation with durations of 0s~16s, resolutions from 144p to 720p, and various aspect ratios. The supported configurations are listed below.
image | 2s | 4s | 8s | 16s | |
---|---|---|---|---|---|
240p | ✅ | ✅ | ✅ | ✅ | ✅ |
360p | ✅ | ✅ | ✅ | ✅ | ✅ |
480p | ✅ | ✅ | ✅ | ✅ | 🆗 |
720p | ✅ | ✅ | ✅ | 🆗 | 🆗 |
Here ✅ means that the configuration is seen during training, and 🆗 means that, although not trained on it, the model can run inference at that configuration. Inference for 🆗 requires sequence parallelism.
We evaluate the training performance of Open-Sora v1.2 on the MixKit dataset with high-resolution videos (1080P, duration 12s to 100s).
All experiments are tested on Ascend 910* with MindSpore 2.3.1 in graph mode.
model name | cards | batch size | resolution | precision | sink | jit level | graph compile | s/step | recipe |
---|---|---|---|---|---|---|---|---|---|
STDiT3-XL/2 | 8 | 1 | 51x720x1280 | bf16 | ON | O1 | 12 mins | 14.23 | yaml |
STDiT3-XL/2 | 8 | dynamic | stage 1 | bf16 | OFF | O1 | 22 mins | 13.17 | yaml |
STDiT3-XL/2 | 8 | dynamic | stage 2 | bf16 | OFF | O1 | 22 mins | 31.04 | yaml |
STDiT3-XL/2 | 8 | dynamic | stage 3 | bf16 | OFF | O1 | 22 mins | 31.17 | yaml |
Note that the step time of dynamic training can be influenced by the resolution and duration distribution of the source videos.
To reproduce the above performance, you may refer to `scripts/run/run_train_os1.2_720x1280x51.sh` and `scripts/run/run_train_os1.2_stage2.sh`.
Below are some generation results after fine-tuning STDiT3 with the Stage 2 bucket config on a MixKit subset containing 100 text-video pairs. The training set contains 80 1080P videos of natural scenes, flowers, and pets. Here we show the text-to-video generation results on the test set.
480x854x204 | 480x854x204 |
019-The-video-begins-with-a-completely-black-screen.-which-quickly.mp4 |
009-The-video-features-a-person-in-a-white-lace-wedding.mp4 |
480x854x204 | 480x854x204 |
005-The-video-showcases-a-small-dog-with-a-light-brown.mp4 |
001-The-video-showcases-a-black-and-white-dog-engaging-in.mp4 |
We evaluate the training performance of Open-Sora v1.1 on a subset of the MixKit dataset.
All experiments are tested on Ascend 910* with MindSpore 2.3.1 in graph mode.
model name | cards | batch size | resolution | vae cache | precision | sink | jit level | graph compile | s/step | recipe |
---|---|---|---|---|---|---|---|---|---|---|
STDiT2-XL/2 | 8 | 1 | 16x512x512 | OFF | bf16 | OFF | O1 | 13 mins | 2.28 | yaml |
STDiT2-XL/2 | 8 | 1 | 64x512x512 | OFF | bf16 | OFF | O1 | 13 mins | 8.57 | yaml |
STDiT2-XL/2 | 8 | 1 | 24x576x1024 | OFF | bf16 | OFF | O1 | 13 mins | 8.55 | yaml |
STDiT2-XL/2 | 8 | 1 | 64x576x1024 | ON | bf16 | OFF | O1 | 13 mins | 18.94 | yaml |
VAE cache: whether the VAE embeddings are pre-computed and cached before training.
Note that the T5 text embeddings are pre-computed before training.
Here are some generation results after fine-tuning STDiT2 on a MixKit subset.
576x1024x48 | 576x1024x48 |
000-a-breathtaking-aerial-view-of-a-vast-landscape.-The-foreground.mp4 |
001-a-close-up-view-of-a-tree-branch-adorned-with-vibrant.mp4 |
576x1024x48 | 576x1024x48 |
005-a-serene-landscape.-bathed-in-the-soft-glow-of-daylight.mp4 |
003-a-vibrant-scene-dominated-by-a-cluster-of-pink-bougainvillea.mp4 |
All experiments are tested on Ascend 910* with MindSpore 2.3.1 in graph mode.
model name | cards | batch size | resolution | stage | precision | sink | jit level | graph compile | s/step | recipe |
---|---|---|---|---|---|---|---|---|---|---|
STDiT-XL/2 | 8 | 3 | 16x256x256 | 1 | fp16 | ON | O1 | 5~6 mins | 1.53 | yaml |
STDiT-XL/2 | 8 | 1 | 16x512x512 | 2 | fp16 | ON | O1 | 5~6 mins | 2.47 | yaml |
STDiT-XL/2 | 8 | 1 | 64x512x512 | 3 | bf16 | ON | O1 | 5~6 mins | 8.52 | yaml |
Here are some generation results after fine-tuning STDiT on a subset of the WebVid dataset.
512x512x64 | 512x512x64 | 512x512x64 |
001-Cloudy-moscow-kremlin-time-lapse.mp4 |
003-The-girl-received-flowers-as-a-gift.-a-gift-for.mp4 |
004-A-baker-turns-freshly-baked-loaves-of-sourdough-bread.mp4 |
For video generation quality evaluation, please refer to the original HPC-AI Tech evaluation doc.
A 3D-VAE pipeline consisting of a spatial VAE followed by a temporal VAE is trained in OpenSora v1.2. For more details, refer to the VAE Documentation.
- Download the pretrained VAE-2D checkpoint from PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers if you aim to train VAE-3D from the spatial VAE initialization. Convert to a MindSpore checkpoint:
python tools/convert_vae1.2.py --src /path/to/pixart_sigma_sdxlvae_T5_diffusers/vae/diffusion_pytorch_model.safetensors --target models/sdxl_vae.ckpt --from_vae2d
- Download the pretrained VAE-3D checkpoint from hpcai-tech/OpenSora-VAE-v1.2 if you aim to train VAE-3D from the VAE-3D model pre-trained with 3 stages. Convert to a MindSpore checkpoint:
python tools/convert_vae1.2.py --src /path/OpenSora-VAE-v1.2/models.safetensors --target models/OpenSora-VAE-v1.2/sdxl_vae.ckpt
- Download the LPIPS MindSpore checkpoint from here and put it under `models/`.
Before VAE-3D training, we need to prepare a csv annotation file for the training videos. The csv file lists the path to each video relative to the root `video_folder`. An example is
video
dance/vid001.mp4
dance/vid002.mp4
...
Taking UCF-101 as an example, please download the UCF-101 dataset and extract it to the `datasets/UCF-101` folder. You can generate the csv annotation by running `python tools/annotate_vae_ucf101.py`. It will produce two csv files, `datasets/ucf101_train.csv` and `datasets/ucf101_test.csv`, for training and testing respectively.
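For your own videos, a minimal sketch of producing such an annotation csv could look like the following (the folder and output file names are illustrative, not part of the repository):

```python
import csv
import os

video_folder = "datasets/my_videos"  # hypothetical root folder of your videos

rows = []
for root, _, files in os.walk(video_folder):
    for name in sorted(files):
        if name.lower().endswith((".mp4", ".avi", ".mov")):
            # Store paths relative to video_folder, matching the csv format shown above.
            rows.append(os.path.relpath(os.path.join(root, name), video_folder))

with open("datasets/my_videos_train.csv", "w", newline="") as f:
    writer = csv.writer(f)
    writer.writerow(["video"])
    writer.writerows([[r] for r in rows])
print(f"wrote {len(rows)} entries")
```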
# stage 1 training, 8 NPUs
msrun --worker_num=8 --local_worker_num=8 \
python scripts/train_vae.py --config configs/vae/train/stage1.yaml --use_parallel=True --csv_path datasets/ucf101_train.csv --video_folder datasets/UCF-101
# stage 2 training, 8 NPUs
msrun --worker_num=8 --local_worker_num=8 \
python scripts/train_vae.py --config configs/vae/train/stage2.yaml --use_parallel=True --csv_path datasets/ucf101_train.csv --video_folder datasets/UCF-101
# stage 3 training, 8 NPUs
msrun --worker_num=8 --local_worker_num=8 \
python scripts/train_vae.py --config configs/vae/train/stage3.yaml --use_parallel=True --csv_path datasets/ucf101_train.csv --video_folder datasets/UCF-101
You can change `csv_path` and `video_folder` to train on your own data.
To evaluate the VAE performance, you need to run VAE inference first to generate the videos, then calculate scores on the generated videos:
# video generation and evaluation
python scripts/inference_vae.py --ckpt_path /path/to/you_vae_ckpt --image_size 256 --num_frames=17 --csv_path datasets/ucf101_test.csv --video_folder datasets/UCF-101
You can change `csv_path` and `video_folder` to evaluate on your own data.
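If you want to compute reconstruction metrics such as PSNR and SSIM on your own frame pairs, a minimal sketch using scikit-image is shown below (an assumption: scikit-image >= 0.19 is installed; the evaluation script may compute these metrics differently):

```python
import numpy as np
from skimage.metrics import peak_signal_noise_ratio, structural_similarity

# Hypothetical frame pair: a ground-truth frame and its VAE reconstruction, both HxWx3 uint8.
# In practice, load matching frames from the original and generated videos.
gt = np.random.randint(0, 256, (256, 256, 3), dtype=np.uint8)
recon = np.random.randint(0, 256, (256, 256, 3), dtype=np.uint8)

psnr = peak_signal_noise_ratio(gt, recon, data_range=255)
ssim = structural_similarity(gt, recon, channel_axis=-1, data_range=255)
print(f"PSNR: {psnr:.2f} dB, SSIM: {ssim:.4f}")
```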
Here, we report the training performance and evaluation results on the UCF-101 dataset.
All experiments are tested on Ascend 910* with MindSpore 2.3.1 in graph mode.
model name | cards | batch size | resolution | precision | jit level | graph compile | s/step | PSNR | SSIM | recipe |
---|---|---|---|---|---|---|---|---|---|---|
VAE-3D | 8 | 1 | 17x256x256 | bf16 | O1 | 5 mins | 1.09 | 29.02 | 0.87 | yaml |
Note that for stage 3 we train with a mixed video and image strategy, i.e. `--mixed_strategy=mixed_video_image`, instead of a random number of frames (`mixed_video_random`). Random-frame training will be supported in the future.
Thanks go to the MindSpore team for their support and to the OpenSora project for their open-source contributions.
If you wish to contribute to this project, you can refer to the Contribution Guideline.
- ColossalAI: A powerful large model parallel acceleration and optimization system.
- DiT: Scalable Diffusion Models with Transformers.
- OpenDiT: An acceleration framework for DiT training. We adopt valuable acceleration strategies from OpenDiT for the training process.
- PixArt: An open-source DiT-based text-to-image model.
- Latte: An attempt to efficiently train DiT for video.
- StabilityAI VAE: A powerful image VAE model.
- CLIP: A powerful text-image embedding model.
- T5: A powerful text encoder.
- LLaVA: A powerful image captioning model based on Mistral-7B and Yi-34B.
We are grateful for their exceptional work and generous contribution to open source.