diff --git a/tools/model_converters/pytorch2onnx.py b/tools/model_converters/pytorch2onnx.py
index 09919c2b6f..f0dabb00ad 100644
--- a/tools/model_converters/pytorch2onnx.py
+++ b/tools/model_converters/pytorch2onnx.py
@@ -1,18 +1,15 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import argparse
-import glob
 import os.path as osp
-import re
 import warnings
-from functools import reduce
 
 import cv2
-import mmcv
 import numpy as np
 import onnx
 import onnxruntime as rt
 import torch
 from mmcv.onnx import register_extra_symbolics
+from mmengine import Config
 from mmengine.dataset import Compose
 from mmengine.runner import load_checkpoint
 
@@ -63,7 +60,7 @@ def pytorch2onnx(model,
         img = input['inputs'].unsqueeze(0)
         data = torch.cat((img, masks), dim=1)
     elif model_type == 'video_restorer':
-        data = input['inputs'].unsqueeze(0)
+        data = input['inputs'].unsqueeze(0).float()
     data = data.to(device)
 
     # pytorch has some bug in pytorch1.3, we have to fix it
@@ -158,6 +155,8 @@ def parse_args():
         '--mask-path',
         default=None,
         help='path to input mask file, used in inpainting model')
+    parser.add_argument('--num-frames', type=int, default=None)
+    parser.add_argument('--sequence-length', type=int, default=None)
     parser.add_argument('--device', type=int, default=0, help='CUDA device id')
     parser.add_argument('--show', action='store_true', help='show onnx graph')
     parser.add_argument('--output-file', type=str, default='tmp.onnx')
@@ -188,7 +187,7 @@ def parse_args():
     else:
         device = torch.device('cuda', args.device)
 
-    config = mmcv.Config.fromfile(args.config)
+    config = Config.fromfile(args.config)
     delete_cfg(config, key='init_cfg')
 
     # ONNX does not support spectral norm
@@ -217,6 +216,8 @@ def parse_args():
         keys_to_remove = ['alpha', 'ori_alpha']
     elif model_type == 'image_restorer':
         keys_to_remove = ['gt', 'gt_path']
+    elif model_type == 'video_restorer':
+        keys_to_remove = ['gt', 'gt_path']
     else:
         keys_to_remove = []
     for key in keys_to_remove:
@@ -244,17 +245,15 @@ def parse_args():
             f'"GenerateSegmentIndices", but got '
             f'"{test_pipeline[0]["type"]}".')
         # prepare data
-        sequence_length = len(glob.glob(osp.join(args.img_path, '*')))
-        img_dir_split = re.split(r'[\\/]', args.img_path)
-        if img_dir_split[0] == '':
-            img_dir_split[0] = '/'
-        key = img_dir_split[-1]
-        lq_folder = reduce(osp.join, img_dir_split[:-1])
+        # sequence_length = len(glob.glob(osp.join(args.img_path, '*')))
+        lq_folder = osp.dirname(args.img_path)
+        key = osp.basename(args.img_path)
         data = dict(
             img_path=lq_folder,
             gt_path='',
             key=key,
-            sequence_length=sequence_length)
+            num_frames=args.num_frames,
+            sequence_length=args.sequence_length)
         # build the data pipeline
         test_pipeline = Compose(test_pipeline)
 
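
Why the new .float() cast in pytorch2onnx() matters: a video-restorer test pipeline that performs no normalization hands the model uint8 frame tensors, and tracing a float32 network with a uint8 input fails before the ONNX graph is ever built. Below is a minimal sketch of that failure mode; the shapes and the toy convolution are illustrative assumptions, not code from this patch.

    import torch

    # Five uint8 frames shaped (t, c, h, w), as an un-normalized
    # loading pipeline would emit them.
    frames = torch.randint(0, 256, (5, 3, 32, 32), dtype=torch.uint8)
    conv = torch.nn.Conv2d(3, 3, 3, padding=1)  # float32 weights

    out = conv(frames.float())  # fine once the input is cast, as the patch does
    try:
        conv(frames)  # uint8 input against float32 weights
    except RuntimeError as err:
        print('tracing would fail here:', err)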
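
The switch from mmcv.Config to mmengine's Config is a pure import migration; fromfile() keeps the same call shape. A quick sketch, assuming a placeholder config path (any MMagic config file would do):

    from mmengine import Config

    cfg = Config.fromfile('configs/example.py')  # placeholder path, assumption
    print(cfg.pretty_text)  # same inspection helper as the old mmcv.Config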
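
Finally, the dirname/basename pair reproduces what the removed glob/re/reduce block computed by hand. A minimal sketch under an assumed frame-folder layout (the path is hypothetical); the caveat at the end means args.img_path must arrive without a trailing separator, which the patch does not enforce.

    import os.path as osp

    img_path = 'data/demo_clip/lq'  # hypothetical low-quality frame folder

    # New logic: two stdlib calls instead of re.split + functools.reduce.
    lq_folder = osp.dirname(img_path)  # -> 'data/demo_clip'
    key = osp.basename(img_path)       # -> 'lq'
    assert (lq_folder, key) == ('data/demo_clip', 'lq')

    # Caveat: a trailing separator makes the key empty.
    assert osp.basename('data/demo_clip/lq/') == ''

Note also that sequence_length is no longer counted from the files on disk (the old glob call survives only as a comment), so video-restorer exports now depend on the caller supplying --num-frames and --sequence-length explicitly; both default to None.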