
LTX 0.9.5 #10968

Open · wants to merge 6 commits into main

Conversation

@a-r-r-o-w requested a review from @yiyixuxu on March 5, 2025 at 00:37

@a-r-r-o-w changed the title from "Fix documentation" to "LTX 0.9.5" on Mar 5, 2025
@a-r-r-o-w (Member, Author) commented on Mar 5, 2025:

Code for matching the VAE:

import sys
sys.path.append("/raid/aryan/ltx-code")

import json
from typing import Any, Dict

import torch
from safetensors.torch import load_file
from safetensors import safe_open

from ltx_video.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder


def remove_keys_(key: str, state_dict: Dict[str, Any]):
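    # Handler for the *_SPECIAL_KEYS_REMAP tables below: drops keys that have
    # no counterpart in the diffusers implementation.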
    state_dict.pop(key)

VAE_KEYS_RENAME_DICT = {
    # decoder
    "up_blocks.0": "mid_block",
    "up_blocks.1": "up_blocks.0",
    "up_blocks.2": "up_blocks.1.upsamplers.0",
    "up_blocks.3": "up_blocks.1",
    "up_blocks.4": "up_blocks.2.conv_in",
    "up_blocks.5": "up_blocks.2.upsamplers.0",
    "up_blocks.6": "up_blocks.2",
    "up_blocks.7": "up_blocks.3.conv_in",
    "up_blocks.8": "up_blocks.3.upsamplers.0",
    "up_blocks.9": "up_blocks.3",
    # encoder
    "down_blocks.0": "down_blocks.0",
    "down_blocks.1": "down_blocks.0.downsamplers.0",
    "down_blocks.2": "down_blocks.0.conv_out",
    "down_blocks.3": "down_blocks.1",
    "down_blocks.4": "down_blocks.1.downsamplers.0",
    "down_blocks.5": "down_blocks.1.conv_out",
    "down_blocks.6": "down_blocks.2",
    "down_blocks.7": "down_blocks.2.downsamplers.0",
    "down_blocks.8": "down_blocks.3",
    "down_blocks.9": "mid_block",
    # common
    "conv_shortcut": "conv_shortcut.conv",
    "res_blocks": "resnets",
    "norm3.norm": "norm3",
    "per_channel_statistics.mean-of-means": "latents_mean",
    "per_channel_statistics.std-of-means": "latents_std",
}

VAE_091_RENAME_DICT = {
    # decoder
    "up_blocks.0": "mid_block",
    "up_blocks.1": "up_blocks.0.upsamplers.0",
    "up_blocks.2": "up_blocks.0",
    "up_blocks.3": "up_blocks.1.upsamplers.0",
    "up_blocks.4": "up_blocks.1",
    "up_blocks.5": "up_blocks.2.upsamplers.0",
    "up_blocks.6": "up_blocks.2",
    "up_blocks.7": "up_blocks.3.upsamplers.0",
    "up_blocks.8": "up_blocks.3",
    # common
    "last_time_embedder": "time_embedder",
    "last_scale_shift_table": "scale_shift_table",
}

VAE_095_RENAME_DICT = {
    # decoder
    "up_blocks.0": "mid_block",
    "up_blocks.1": "up_blocks.0.upsamplers.0",
    "up_blocks.2": "up_blocks.0",
    "up_blocks.3": "up_blocks.1.upsamplers.0",
    "up_blocks.4": "up_blocks.1",
    "up_blocks.5": "up_blocks.2.upsamplers.0",
    "up_blocks.6": "up_blocks.2",
    "up_blocks.7": "up_blocks.3.upsamplers.0",
    "up_blocks.8": "up_blocks.3",
    # encoder
    "down_blocks.0": "down_blocks.0",
    "down_blocks.1": "down_blocks.0.downsamplers.0",
    "down_blocks.2": "down_blocks.1",
    "down_blocks.3": "down_blocks.1.downsamplers.0",
    "down_blocks.4": "down_blocks.2",
    "down_blocks.5": "down_blocks.2.downsamplers.0",
    "down_blocks.6": "down_blocks.3",
    "down_blocks.7": "down_blocks.3.downsamplers.0",
    "down_blocks.8": "mid_block",
    # common
    "last_time_embedder": "time_embedder",
    "last_scale_shift_table": "scale_shift_table",
}

VAE_SPECIAL_KEYS_REMAP = {
    "per_channel_statistics.channel": remove_keys_,
    "per_channel_statistics.mean-of-means": remove_keys_,
    "per_channel_statistics.mean-of-stds": remove_keys_,
    "model.diffusion_model": remove_keys_,
}
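
# Note: the rename pass in convert_vae runs before this removal pass, so
# mean-of-means/std-of-means survive as latents_mean/latents_std, while the
# remaining per-channel statistics and non-VAE weights under
# "model.diffusion_model" are dropped.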

VAE_091_SPECIAL_KEYS_REMAP = {
    "timestep_scale_multiplier": remove_keys_,
}

VAE_095_SPECIAL_KEYS_REMAP = {}


def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key: str) -> None:
    state_dict[new_key] = state_dict.pop(old_key)


def convert_vae(original_state_dict):
    PREFIX_KEY = "vae."

    for key in list(original_state_dict.keys()):
        new_key = key[:]
        if new_key.startswith(PREFIX_KEY):
            new_key = key[len(PREFIX_KEY) :]
        for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items():
            new_key = new_key.replace(replace_key, rename_key)
        update_state_dict_inplace(original_state_dict, key, new_key)

    for key in list(original_state_dict.keys()):
        for special_key, handler_fn_inplace in VAE_SPECIAL_KEYS_REMAP.items():
            if special_key not in key:
                continue
            handler_fn_inplace(key, original_state_dict)

    return original_state_dict
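
# Minimal sanity check of the mapping (hypothetical key, not taken from a real
# checkpoint): after VAE_KEYS_RENAME_DICT.update(VAE_095_RENAME_DICT),
#
#   toy = {"vae.up_blocks.7.res_blocks.0.norm3.norm.weight": torch.zeros(1)}
#   list(convert_vae(toy).keys())
#   # -> ["up_blocks.3.upsamplers.0.resnets.0.norm3.weight"]
#
# i.e. the "vae." prefix is stripped, then the substring replacements are
# applied in dict insertion order.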


@torch.no_grad()
def match_vae():
    from diffusers import AutoencoderKLLTXVideo

    original_model_path = "/raid/aryan/ltx-new/ltx-video-2b-v0.9.5rc1.safetensors"
    theirs_config = json.loads(safe_open(original_model_path, "pt").metadata()["config"])
    theirs_model = CausalVideoAutoencoder.from_config(theirs_config["vae"])
    theirs_state_dict = load_file(original_model_path)
    theirs_model.load_state_dict(theirs_state_dict)

    # Diffusers-side config mirroring the 0.9.5 checkpoint: new
    # LTXVideo095DownBlock3D encoder blocks, residual upsampling in the
    # decoder, and a non-causal decoder (the encoder stays causal).
    ours_config = {
        "in_channels": 3,
        "out_channels": 3,
        "latent_channels": 128,
        "block_out_channels": (128, 256, 512, 1024, 2048),
        "down_block_types": (
            "LTXVideo095DownBlock3D",
            "LTXVideo095DownBlock3D",
            "LTXVideo095DownBlock3D",
            "LTXVideo095DownBlock3D",
        ),
        "decoder_block_out_channels": (256, 512, 1024),
        "layers_per_block": (4, 6, 6, 2, 2),
        "decoder_layers_per_block": (5, 5, 5, 5),
        "spatio_temporal_scaling": (True, True, True, True),
        "decoder_spatio_temporal_scaling": (True, True, True),
        "decoder_inject_noise": (False, False, False, False),
        "downsample_type": ("spatial", "temporal", "spatiotemporal", "spatiotemporal"),
        "upsample_residual": (True, True, True),
        "upsample_factor": (2, 2, 2),
        "timestep_conditioning": True,
        "patch_size": 4,
        "patch_size_t": 1,
        "resnet_norm_eps": 1e-6,
        "scaling_factor": 1.0,
        "encoder_causal": True,
        "decoder_causal": False,
    }
    ours_model = AutoencoderKLLTXVideo.from_config(ours_config)

    VAE_KEYS_RENAME_DICT.update(VAE_095_RENAME_DICT)
    VAE_SPECIAL_KEYS_REMAP.update(VAE_095_SPECIAL_KEYS_REMAP)
    ours_state_dict = convert_vae(theirs_state_dict)
    ours_model.load_state_dict(ours_state_dict)

    state_dict_params = sum(p.numel() for p in ours_state_dict.values())
    print(f"State dict params: {state_dict_params}")

    device = torch.device("cuda")
    dtype = torch.float32

    theirs_model.to(device=device, dtype=dtype)
    ours_model.to(device=device, dtype=dtype)

    theirs_model.disable_z_tiling()
    theirs_model.disable_hw_tiling()

    print(sum(p.numel() for p in theirs_model.parameters()))
    print(sum(p.numel() for p in ours_model.parameters()))

    batch_size = 1
    num_channels = 3
    num_frames = 49
    height = 128
    width = 128

    torch.manual_seed(0)
    input = torch.randn(batch_size, num_channels, num_frames, height, width, device=device, dtype=dtype)
    decode_timestep = 0.025
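
    # The 0.9.5 decoder is timestep-conditioned (timestep_conditioning=True in
    # the config above), so both implementations receive a small decode timestep.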

    print("theirs_encoding")
    theirs_encoder_output = theirs_model.encode(input).latent_dist.mode()
    print("theirs_decoding")
    theirs_decoder_output = theirs_model.decode(theirs_encoder_output, timestep=decode_timestep, target_shape=(batch_size, num_channels, num_frames, height, width)).sample
    print("theirs:", theirs_encoder_output.shape, theirs_decoder_output.shape)

    print("ours_encoding")
    ours_encoder_output = ours_model.encode(input).latent_dist.mode()
    print("ours_decoding")
    ours_decoder_output = ours_model.decode(ours_encoder_output, temb=decode_timestep).sample
    print("ours:", ours_encoder_output.shape, ours_decoder_output.shape)

    diff_encoder = theirs_encoder_output - ours_encoder_output
    diff_decoder = theirs_decoder_output - ours_decoder_output

    absmax_encoder, absmean_encoder = torch.max(diff_encoder.abs()), torch.mean(diff_encoder.abs())
    absmax_decoder, absmean_decoder = torch.max(diff_decoder.abs()), torch.mean(diff_decoder.abs())

    print(f"Encoder: absmax={absmax_encoder}, absmean={absmean_encoder}")
    print(f"Decoder: absmax={absmax_decoder}, absmean={absmean_decoder}")


match_vae()
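
A near-zero absmax/absmean for both encoder and decoder indicates the key
mapping is correct; a large value usually points to a mis-mapped or missing
key. Once the outputs match, the converted module can be saved in diffusers
format with the standard API (a sketch, run inside match_vae; the output path
is hypothetical):

    ours_model.save_pretrained("/raid/aryan/ltx-095-vae-diffusers")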
