This repository contains the implementation of the paper
MagicDriveDiT: High-Resolution Long Video Generation for Autonomous Driving with Adaptive Control
Ruiyuan Gao¹, Kai Chen², Bo Xiao³, Lanqing Hong⁴, Zhenguo Li⁴, Qiang Xu¹
¹CUHK ²HKUST ³Huawei Cloud ⁴Huawei Noah's Ark Lab
(Demo video: MagicDriveDiT-2c.mp4)
TL;DR: MagicDriveDiT generates high-resolution, long street-view videos with diverse 3D geometry control and multiview consistency.
The rapid advancement of diffusion models has greatly improved video synthesis, especially in controllable video generation, which is essential for applications like autonomous driving. However, existing methods are limited by scalability and by how control conditions are integrated, failing to meet the needs of high-resolution, long video generation for autonomous driving applications. In this paper, we introduce MagicDriveDiT, a novel approach based on the DiT architecture, to tackle these challenges. Our method enhances scalability through flow matching and employs a progressive training strategy to manage complex scenarios. By incorporating spatial-temporal conditional encoding, MagicDriveDiT achieves precise control over spatial-temporal latents. Comprehensive experiments show its superior performance in generating realistic street-scene videos with higher resolution and more frames. MagicDriveDiT significantly improves video generation quality and spatial-temporal control, expanding its potential applications across various tasks in autonomous driving.
- [2024/12/07] Stage-3 checkpoint and nuScenes metadata for training & inference released!
- [2024/12/03] Train & inference code released! We will update the links in the README later.
- [2024/11/22] Paper and project page released! Check https://gaoruiyuan.com/magicdrivedit/
- train & inference code
- pretrained weights for stage 3 & metadata for nuScenes
- pretrained weights for stages 1 & 2 (will be released later)
Clone this repo
```bash
git clone https://github.com/flymin/MagicDriveDiT.git
```
The code is tested on A800/H20/Ascend 910b servers. To set up the Python environment, follow the steps below.
Note
Please use `pip` to set up your environment. We DO NOT recommend using `conda` + `yaml` directly for environment configuration.
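For example, a minimal pip-based setup might start from a fresh virtual environment (the environment name below is arbitrary):

```bash
# Create and activate a fresh virtual environment, then install everything into it with pip.
python -m venv magicdrivedit-env
source magicdrivedit-env/bin/activate
pip install --upgrade pip
```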
NVIDIA Servers step-by-step guide:
- Make sure you have an environment with the following packages:
  ```
  torch==2.4.0
  torchvision==0.19.0
  # may need to build from source
  apex (https://github.com/NVIDIA/apex)
  # choose the correct wheel packages or build from the source
  xformers>=0.0.27
  flash-attn>=2.6.3
  ```
- Install ColossalAI

  ```bash
  git clone https://github.com/flymin/ColossalAI.git
  cd ColossalAI
  git checkout pt2.4 && git pull
  BUILD_EXT=1 pip install .
  ```
- Install other dependencies
  ```bash
  pip install -r requirements/requirements.txt
  ```
Please refer to the following yaml files for further details:
- A800: `requirements/a800_cu118.yaml`
- H20: `requirements/h20_cu124.yaml`
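As an optional sanity check (assuming the packages above installed without errors), you can verify the imports and CUDA availability:

```bash
# Should print the torch version and `True` on a working CUDA setup.
python -c "import torch, torchvision, xformers, flash_attn, apex, colossalai; print(torch.__version__, torch.cuda.is_available())"
```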
Ascend Servers step-by-step guide:
- Make sure you have an environment with the following packages (please refer to this page to set up the PyTorch environment):
  ```
  # based on CANN 8.0RC2
  torch==2.3.1
  torchvision==0.18.1
  torch-npu==2.3.1
  apex (https://gitee.com/ascend/apex)
  # choose the correct wheel packages or build from the source
  xformers==0.0.27
  ```
- Install ColossalAI

  ```bash
  # We remove dependency on `bitsandbytes`.
  git clone https://github.com/flymin/ColossalAI.git
  cd ColossalAI
  git checkout ascend && git pull
  BUILD_EXT=1 pip install .
  ```
- Install other dependencies
  ```bash
  pip install -r requirements/requirements.txt
  ```
Please refer to `requirements/910b_cann8.0.RC2_aarch64.yaml` for further details.
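Similarly, on Ascend servers a quick import check (assuming the CANN toolkit and `torch-npu` are installed) looks like:

```bash
# Should print the torch/torch_npu versions and `True` when the NPU is visible.
python -c "import torch, torch_npu; print(torch.__version__, torch_npu.__version__, torch.npu.is_available())"
```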
VAE: We use the 3DVAE from THUDM/CogVideoX-2b. It is OK if you only download the `vae` sub-folder.
Text Encoder: We use T5 Encoder from google/t5-v1_1-xxl.
You should organize them as follows:
```
${CODE_ROOT}/pretrained/
├── CogVideoX-2b
│   └── vae
└── t5-v1_1-xxl
```
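One way to fetch both models is with `huggingface-cli` (a sketch assuming a recent `huggingface_hub`; you can also download them manually from the Hugging Face pages):

```bash
mkdir -p pretrained
# Only the `vae` sub-folder of CogVideoX-2b is needed.
huggingface-cli download THUDM/CogVideoX-2b --include "vae/*" --local-dir pretrained/CogVideoX-2b
huggingface-cli download google/t5-v1_1-xxl --local-dir pretrained/t5-v1_1-xxl
```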
Please download the stage-3 checkpoint from flymin/MagicDriveDiT-stage3-40k-ft and put it in `${CODE_ROOT}/ckpts/` as:

```
${CODE_ROOT}/ckpts/
└── MagicDriveDiT-stage3-40k-ft
```
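The same pattern works for the stage-3 checkpoint (again a sketch assuming `huggingface-cli` is available):

```bash
mkdir -p ckpts
huggingface-cli download flymin/MagicDriveDiT-stage3-40k-ft --local-dir ckpts/MagicDriveDiT-stage3-40k-ft
```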
We prepare the nuScenes dataset following bevfusion's instructions. Specifically:
- Download the nuScenes dataset from the website and put them in `./data/`. You should have these files:

  ```
  ${CODE_ROOT}/data/nuscenes
  ├── can_bus
  ├── maps
  ├── mini
  ├── samples
  ├── sweeps
  ├── v1.0-mini
  └── v1.0-trainval
  ```
- Download the metadata for `mmdet` from flymin/MagicDriveDiT-nuScenes-metadata.

  Otherwise, please interpolate the annotations to 12Hz as in MagicDrive-t and generate the metadata yourself with the commands in `tools/prepare_data/prepare_dataset.sh`. If you already have the metadata files from MagicDrive-t, you can use `tools/prepare_data/add_box_id.py` to add the keys for instance IDs; see the commands in `tools/prepare_data/prepare_dataset.sh`.

  Your data folder should look like:

  ```
  ${CODE_ROOT}/data
  ├── nuscenes
  │   ├── ...
  │   └── interp_12Hz_trainval
  └── nuscenes_mmdet3d-12Hz
      ├── nuscenes_interp_12Hz_infos_train_with_bid.pkl
      └── nuscenes_interp_12Hz_infos_val_with_bid.pkl
  ```
- (Optional) To accelerate data loading, we prepared cache files in h5 format for the BEV maps. They can be generated through `tools/prepare_data/prepare_map_aux.py` with different configs in `configs/cache_gen`. For example:

  ```bash
  python tools/prepare_data/prepare_map_aux.py +cache_gen=map_cache_gen_interp \
      +process=val +subfix=8x200x200_12Hz
  ```

  Please find the full commands in `tools/prepare_data/prepare_dataset.sh`. Please make sure you move the generated cache files to the right path. Our defaults are:

  ```
  ${CODE_ROOT}/data/nuscenes_map_aux_12Hz
  ├── train_8x200x200_12Hz.h5 (25G)
  ├── train_8x400x400_12Hz.h5 (99G)
  ├── val_8x200x200_12Hz.h5 (5.3G)
  └── val_8x400x400_12Hz.h5 (22G)
  ```

  Optional sanity checks for the prepared metadata and cache files are sketched after this list.
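The following optional checks only print whatever is in the prepared files, so they make no assumptions about their exact contents; the paths follow the defaults above, and the cache check assumes `h5py` is installed:

```bash
# Inspect the top-level structure of the 12Hz metadata.
python -c "import pickle; d = pickle.load(open('data/nuscenes_mmdet3d-12Hz/nuscenes_interp_12Hz_infos_val_with_bid.pkl', 'rb')); print(type(d)); print(list(d) if isinstance(d, dict) else len(d))"
# List the top-level datasets in one of the BEV map caches.
python -c "import h5py; f = h5py.File('data/nuscenes_map_aux_12Hz/val_8x200x200_12Hz.h5', 'r'); print(list(f.keys())); f.close()"
```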
In most cases, you can use the same commands on both GPU servers and Ascend servers.
```bash
# ${GPUS} can be 1/2/4/8 for sequence parallel.
# ${CFG} can be any file located in `configs/magicdrive/inference/`.
# ${PATH_TO_MODEL} can be path to `ema.pt` or path to `model` from the checkpoint.
# ${FRAME} can be 1/9/17/33/65/129/full...(8n+1). 1 for image; full for the full-length of nuScenes.
# `cpu_offload=true` and `scheduler.type=rflow-slice` can be omitted if you have enough GPU memory.
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
torchrun --standalone --nproc_per_node ${GPUS} scripts/inference_magicdrive.py ${CFG} \
    --cfg-options model.from_pretrained=${PATH_TO_MODEL} num_frames=${FRAME} \
    cpu_offload=true scheduler.type=rflow-slice
```
Please check the FAQ for more information about GPU memory requirements.
For example, to generate the full-length video (20s@12fps) at the highest resolution (848x1600) with 8×H20/A800:
```bash
PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True torchrun --standalone --nproc_per_node 8 \
    scripts/inference_magicdrive.py \
    configs/magicdrive/inference/fullx848x1600_stdit3_CogVAE_boxTDS_wCT_xCE_wSST.py \
    --cfg-options model.from_pretrained=./ckpts/MagicDriveDiT-stage3-40k-ft/ema.pt \
    num_frames=full cpu_offload=true scheduler.type=rflow-slice
```
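For a quicker smoke test, the same config can be run with fewer frames and fewer GPUs (whether a single GPU fits this resolution depends on your hardware; see the FAQ), for example:

```bash
# Generate a short 17-frame clip on one GPU, keeping memory offloading enabled.
PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True torchrun --standalone --nproc_per_node 1 \
    scripts/inference_magicdrive.py \
    configs/magicdrive/inference/fullx848x1600_stdit3_CogVAE_boxTDS_wCT_xCE_wSST.py \
    --cfg-options model.from_pretrained=./ckpts/MagicDriveDiT-stage3-40k-ft/ema.pt \
    num_frames=17 cpu_offload=true scheduler.type=rflow-slice
```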
Other options for generation:
- `force_daytime`: (bool) force to generate daytime scenes.
- `force_rainy`: (bool) force to generate rainy scenes.
- `force_night`: (bool) force to generate night scenes.
- `allow_class`: (list) limit the classes for generation.
- `del_box_ratio`: (float) randomly drop boxes for generation.
- `drop_nearest_car`: (int) drop the N nearest vehicles during generation.

We generate the videos in the format of W-CODA2024 Track2 and test with the established benchmark. Before generation, please make sure the metadata for evaluation is prepared as follows:

```
${CODE_ROOT}/data/nuscenes_mmdet3d-12Hz
├── nuscenes_interp_12Hz_infos_track2_eval.pkl            # this can be downloaded from the page for track2
└── nuscenes_interp_12Hz_infos_track2_eval_with_bid.pkl   # this can be generated or downloaded from this project.
```
To generate the videos (with 8 GPUs/NPUs):
```bash
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True  # for GPU
export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True   # for NPU
torchrun --standalone --nproc_per_node 8 scripts/test_magicdrive.py \
    configs/magicdrive/test/17-16x848x1600_stdit3_CogVAE_boxTDS_wCT_xCE_wSST_map0_fsp8_cfg2.0.py \
    --cfg-options model.from_pretrained=${PATH_TO_MODEL} tag=${TAG}
```
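For example, with the stage-3 checkpoint and a run tag of your choice (the tag value here is just an illustrative label):

```bash
torchrun --standalone --nproc_per_node 8 scripts/test_magicdrive.py \
    configs/magicdrive/test/17-16x848x1600_stdit3_CogVAE_boxTDS_wCT_xCE_wSST_map0_fsp8_cfg2.0.py \
    --cfg-options model.from_pretrained=./ckpts/MagicDriveDiT-stage3-40k-ft/ema.pt tag=track2_eval
```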
Launch training (with 32xA800/H20) with:
```bash
# please change "xx" to real rank and ip
# ${config} can be any file in `configs/magicdrive/train`.
# For example: configs/magicdrive/train/stage3_higher-b-v3.1-12Hz_stdit3_CogVAE_boxTDS_wCT_xCE_wSST_bs4_lr1e-5_sp4simu8.py
torchrun --nproc-per-node=8 --nnode=4 --node_rank=xx --master_addr xx --master_port 18836 \
    scripts/train_magicdrive.py ${config} --cfg-options num_workers=2 prefetch_factor=2
```
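As a concrete illustration, the placeholders on the first of the four nodes might be filled in as follows (the IP address is hypothetical; each node sets its own `--node_rank`):

```bash
torchrun --nproc-per-node=8 --nnode=4 --node_rank=0 --master_addr 10.0.0.1 --master_port 18836 \
    scripts/train_magicdrive.py \
    configs/magicdrive/train/stage3_higher-b-v3.1-12Hz_stdit3_CogVAE_boxTDS_wCT_xCE_wSST_bs4_lr1e-5_sp4simu8.py \
    --cfg-options num_workers=2 prefetch_factor=2
```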
We also use 64 Ascend 910b NPUs to train stage 2; please see the configs in `configs/magicdrive/npu_64g`.
Besides, we provide a debug config to test your environment and the data loading process:
```bash
# for example (with 4xA800)
# ${config} can be any file in `configs/magicdrive/train`.
# For example: configs/magicdrive/train/stage3_higher-b-v3.1-12Hz_stdit3_CogVAE_boxTDS_wCT_xCE_wSST_bs4_lr1e-5_sp4simu8.py
bash scripts/launch_1node.sh 4 ${config} --cfg-options debug=true
# by setting `vsdebug=true` with 1 process, you can use the 'attach mode' from vscode to debug.
```
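For the attach-mode workflow mentioned in the last comment, a single-process run might look like this (assuming `vsdebug` is passed through `--cfg-options` like the other overrides):

```bash
bash scripts/launch_1node.sh 1 ${config} --cfg-options debug=true vsdebug=true
```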
Note: `sp=4` (stage 3) needs at least 4 GPUs to run.
```bibtex
@misc{gao2024magicdrivedit,
  title={{MagicDriveDiT}: High-Resolution Long Video Generation for Autonomous Driving with Adaptive Control},
  author={Gao, Ruiyuan and Chen, Kai and Xiao, Bo and Hong, Lanqing and Li, Zhenguo and Xu, Qiang},
  year={2024},
  eprint={2411.13807},
  archivePrefix={arXiv},
}
```
We adopt the following open-source projects:
- BEVFusion: dataloader for 3D bounding boxes and BEV maps
- Open-Sora: STDiT3 and the training framework
- ColossalAI: framework for parallelism and ZeRO-2
- CogVideoX: we use their CogVAE