Skip to content

Commit

Permalink
add onnx support
Browse files Browse the repository at this point in the history
  • Loading branch information
warmshao committed Jul 16, 2024
1 parent d103fe7 commit a78cccb
Showing 1 changed file with 50 additions and 46 deletions.
96 changes: 50 additions & 46 deletions export_onnx.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,13 @@
# -*- coding: utf-8 -*-
# @Time : 2024/7/13 12:03
# @Author : shaoguowen
# @Email : wenshaoguo1026@gmail.com
# @Project : LivePortrait
# @FileName: export_onnx.py

"""
reference speed.py
"""
import os
import pdb
import onnx
from onnxsim import simplify
import yaml
import torch
from torch import nn
import time
import numpy as np
from src.utils.helper import load_model, concat_feat
Expand Down Expand Up @@ -48,6 +43,26 @@ def initialize_inputs(batch_size=1):
return inputs


def simplify_onnx_model(onnx_model_path):
# load your predefined ONNX model
model = onnx.load(onnx_model_path)

# convert model
model_simp, check = simplify(model)
onnx.save(model_simp, onnx_model_path)

class WarpingSpadeModel(nn.Module):
def __init__(self, warping_module, spade_generator):
super(WarpingSpadeModel, self).__init__()
self.warping_module = warping_module
self.spade_generator = spade_generator

def forward(self, feature_3d, kp_driving, kp_source):
occlusion_map, deformation, out = self.warping_module.forward_onnx(feature_3d, kp_driving, kp_source)
out = self.spade_generator(out)
return out


def load_torch_models(cfg, model_config):
"""
Load and compile models for inference
Expand All @@ -61,11 +76,13 @@ def load_torch_models(cfg, model_config):
spade_generator = load_model(cfg.checkpoint_G, model_config, "cpu", 'spade_generator').eval()
stitching_retargeting_module = load_model(cfg.checkpoint_S, model_config, "cpu",
'stitching_retargeting_module')

return appearance_feature_extractor, motion_extractor, warping_module, spade_generator, stitching_retargeting_module
warping_spade_model = WarpingSpadeModel(warping_module, spade_generator).eval().cpu()
return appearance_feature_extractor, motion_extractor, warping_spade_model, spade_generator, stitching_retargeting_module


if __name__ == '__main__':
onnx_save_dir = "pretrained_weights/liveportrait_onnx"
os.makedirs(onnx_save_dir, exist_ok=True)
inputs = initialize_inputs()
"""
input feature_3d shape: torch.Size([1, 32, 16, 64, 64])
Expand Down Expand Up @@ -100,14 +117,15 @@ def load_torch_models(cfg, model_config):
torch.onnx.export(
appearance_feature_extractor,
(inputs['source_image'],),
os.path.join("pretrained_weights/liveportrait_onnx", "appearance_feature_extractor.onnx"),
os.path.join(onnx_save_dir, "appearance_feature_extractor.onnx"),
export_params=True,
opset_version=20,
opset_version=16,
do_constant_folding=True,
input_names=['img'],
output_names=['output'],
dynamic_axes=None
)
simplify_onnx_model(os.path.join(onnx_save_dir, "appearance_feature_extractor.onnx"))

# export appearance_feature_extractor
print("export motion_extractor >>> ")
Expand All @@ -126,14 +144,15 @@ def load_torch_models(cfg, model_config):
torch.onnx.export(
motion_extractor,
(inputs['source_image'],),
os.path.join("pretrained_weights/liveportrait_onnx", "motion_extractor.onnx"),
os.path.join(onnx_save_dir, "motion_extractor.onnx"),
export_params=True,
opset_version=20,
opset_version=16,
do_constant_folding=True,
input_names=['img'],
output_names=['pitch', 'yaw', 'roll', 't', 'exp', 'scale', 'kp'],
dynamic_axes=None
)
simplify_onnx_model(os.path.join(onnx_save_dir, "motion_extractor.onnx"))

# export appearance_feature_extractor
print("export warping_module >>> ")
Expand All @@ -143,21 +162,22 @@ def load_torch_models(cfg, model_config):
warping_module->deformation shape: torch.Size([1, 16, 64, 64, 3])
warping_module->out shape: torch.Size([1, 256, 64, 64])
"""
for i, key in enumerate(['occlusion_map', 'deformation', 'out']):
print(f"warping_module->{key} shape:", warping_outputs[i].shape)
# for i, key in enumerate(['occlusion_map', 'deformation', 'out']):
# print(f"warping_module->{key} shape:", warping_outputs[i].shape)
print(f"warping_module output shape:", warping_outputs.shape)
# use pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu121
torch.onnx.export(
warping_module,
(inputs['feature_3d'], inputs['kp_driving'], inputs['kp_source']),
os.path.join("pretrained_weights/liveportrait_onnx", "warping.onnx"),
os.path.join(onnx_save_dir, "warping_spade.onnx"),
export_params=True,
opset_version=20,
opset_version=16,
do_constant_folding=True,
input_names=['feature_3d', 'kp_driving', 'kp_source'],
output_names=['occlusion_map', 'deformation', 'out'],
output_names=['out'],
dynamic_axes=None
)

simplify_onnx_model(os.path.join(onnx_save_dir, "warping_spade.onnx"))

def modify_onnx_model(onnx_model_path, onnx_save_path, custom_op_name="GridSample3D"):
model = onnx.load(onnx_model_path)
Expand All @@ -167,27 +187,8 @@ def modify_onnx_model(onnx_model_path, onnx_save_path, custom_op_name="GridSampl
onnx.save(model, onnx_save_path)


modify_onnx_model(os.path.join("pretrained_weights/liveportrait_onnx", "warping.onnx"),
os.path.join("pretrained_weights/liveportrait_onnx", "warping-fix.onnx"))

# spade_generator export
print("export spade_generator >>> ")
spade_outputs = spade_generator(inputs['generator_input'])
"""
spade_generator output shape: torch.Size([1, 3, 512, 512])
"""
print(f"spade_generator output shape:", spade_outputs.shape)
torch.onnx.export(
spade_generator,
(inputs['generator_input'],),
os.path.join("pretrained_weights/liveportrait_onnx", "spade_generator.onnx"),
export_params=True,
opset_version=20,
do_constant_folding=True,
input_names=['input'],
output_names=['output'],
dynamic_axes=None
)
modify_onnx_model(os.path.join("pretrained_weights/liveportrait_onnx", "warping_spade.onnx"),
os.path.join("pretrained_weights/liveportrait_onnx", "warping_spade-fix.onnx"))

# stitching export
print("export stitching >>> ")
Expand All @@ -200,14 +201,15 @@ def modify_onnx_model(onnx_model_path, onnx_save_path, custom_op_name="GridSampl
torch.onnx.export(
stitching_model,
(inputs['feat_stitching'],),
os.path.join("pretrained_weights/liveportrait_onnx", "stitching.onnx"),
os.path.join(onnx_save_dir, "stitching.onnx"),
export_params=True,
opset_version=20,
opset_version=16,
do_constant_folding=True,
input_names=['input'],
output_names=['output'],
dynamic_axes=None
)
simplify_onnx_model(os.path.join(onnx_save_dir, "stitching.onnx"))

# eye_stitching_model export
print("export eye_stitching_model >>> ")
Expand All @@ -220,14 +222,15 @@ def modify_onnx_model(onnx_model_path, onnx_save_path, custom_op_name="GridSampl
torch.onnx.export(
eye_stitching_model,
(inputs['feat_eye'],),
os.path.join("pretrained_weights/liveportrait_onnx", "stitching_eye.onnx"),
os.path.join(onnx_save_dir, "stitching_eye.onnx"),
export_params=True,
opset_version=20,
opset_version=16,
do_constant_folding=True,
input_names=['input'],
output_names=['output'],
dynamic_axes=None
)
simplify_onnx_model(os.path.join(onnx_save_dir, "stitching_eye.onnx"))

# eye_stitching_model export
print("export lip_stitching_model >>> ")
Expand All @@ -240,11 +243,12 @@ def modify_onnx_model(onnx_model_path, onnx_save_path, custom_op_name="GridSampl
torch.onnx.export(
lip_stitching_model,
(inputs['feat_lip'],),
os.path.join("pretrained_weights/liveportrait_onnx", "stitching_lip.onnx"),
os.path.join(onnx_save_dir, "stitching_lip.onnx"),
export_params=True,
opset_version=20,
opset_version=16,
do_constant_folding=True,
input_names=['input'],
output_names=['output'],
dynamic_axes=None
)
simplify_onnx_model(os.path.join(onnx_save_dir, "stitching_lip.onnx"))

0 comments on commit a78cccb

Please sign in to comment.