Skip to content

Commit

Permalink
vs: k7sfunc 更新 0.4.6
Browse files Browse the repository at this point in the history
检查部分模块所需的模型是否存在。

模块:
- 修复 MVT_STD 的内部参数错误 #397
- 限制 RIFE_STD 的参数 stat_th 超过 60.0
- UAI_DML 和 UAI_NV_TRT 现在支持调用外部目录的模型文件
  • Loading branch information
hooke007 committed Dec 28, 2023
1 parent 50bc4d3 commit f0f2192
Showing 1 changed file with 72 additions and 8 deletions.
80 changes: 72 additions & 8 deletions k7sfunc.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
### 文档: https://github.com/hooke007/MPV_lazy/wiki/3_K7sfunc
##################################################

__version__ = "0.4.3"
__version__ = "0.4.6"

__all__ = [
"FMT_CHANGE", "FMT_CTRL", "FPS_CHANGE", "FPS_CTRL",
Expand Down Expand Up @@ -537,6 +537,12 @@ def CUGAN_NV(
if not hasattr(core, "trt") :
raise ModuleNotFoundError(f"模块 {func_name} 依赖错误:缺失插件,检查项目 trt")

plg_dir = os.path.dirname(core.trt.Version()["path"]).decode()
mdl_fname = ["pro-no-denoise3x-up2x", "pro-conservative-up2x", "pro-denoise3x-up2x"][[-1, 0, 3].index(nr_lv)]
mdl_pth = plg_dir + "/models/cugan/" + mdl_fname + ".onnx"
if not os.path.exists(mdl_pth) :
raise vs.Error(f"模块 {func_name} 所请求的模型缺失")

global vsmlrt
if vsmlrt is None :
try :
Expand Down Expand Up @@ -669,6 +675,12 @@ def ESRGAN_DML(
if not hasattr(core, "ort") :
raise ModuleNotFoundError(f"模块 {func_name} 依赖错误:缺失插件,检查项目 ort")

plg_dir = os.path.dirname(core.ort.Version()["path"]).decode()
mdl_fname = ["RealESRGANv2-animevideo-xsx2", "realesr-animevideov3", "animejanaiV2L1", "animejanaiV2L2", "animejanaiV2L3"][[0, 2, 5005, 5006, 5007].index(model)]
mdl_pth = plg_dir + "/models/RealESRGANv2/" + mdl_fname + ".onnx"
if not os.path.exists(mdl_pth) :
raise vs.Error(f"模块 {func_name} 所请求的模型缺失")

global vsmlrt
if vsmlrt is None :
try :
Expand Down Expand Up @@ -730,6 +742,12 @@ def ESRGAN_NV(
if not hasattr(core, "trt") :
raise ModuleNotFoundError(f"模块 {func_name} 依赖错误:缺失插件,检查项目 trt")

plg_dir = os.path.dirname(core.trt.Version()["path"]).decode()
mdl_fname = ["RealESRGANv2-animevideo-xsx2", "realesr-animevideov3", "animejanaiV2L1", "animejanaiV2L2", "animejanaiV2L3"][[0, 2, 5005, 5006, 5007].index(model)]
mdl_pth = plg_dir + "/models/RealESRGANv2/" + mdl_fname + ".onnx"
if not os.path.exists(mdl_pth) :
raise vs.Error(f"模块 {func_name} 所请求的模型缺失")

global vsmlrt
if vsmlrt is None :
try :
Expand Down Expand Up @@ -830,6 +848,13 @@ def WAIFU_DML(
if not hasattr(core, "ort") :
raise ModuleNotFoundError(f"模块 {func_name} 依赖错误:缺失插件,检查项目 ort")

plg_dir = os.path.dirname(core.ort.Version()["path"]).decode()
mdl_pname = {3:"upconv_7_anime_style_art_rgb/", 5:"upresnet10/", 6:"cunet/"}.get(model)
mdl_fname = ["scale2.0x_model", "noise0_scale2.0x_model", "noise1_scale2.0x_model", "noise2_scale2.0x_model", "noise3_scale2.0x_model"][[-1, 0, 1, 2, 3].index(nr_lv)]
mdl_pth = plg_dir + "/models/waifu2x/" + mdl_pname + mdl_fname + ".onnx"
if not os.path.exists(mdl_pth) :
raise vs.Error(f"模块 {func_name} 所请求的模型缺失")

global vsmlrt
if vsmlrt is None :
try :
Expand Down Expand Up @@ -894,6 +919,13 @@ def WAIFU_NV(
if not hasattr(core, "trt") :
raise ModuleNotFoundError(f"模块 {func_name} 依赖错误:缺失插件,检查项目 trt")

plg_dir = os.path.dirname(core.trt.Version()["path"]).decode()
mdl_pname = {3:"upconv_7_anime_style_art_rgb/", 5:"upresnet10/", 6:"cunet/"}.get(model)
mdl_fname = ["scale2.0x_model", "noise0_scale2.0x_model", "noise1_scale2.0x_model", "noise2_scale2.0x_model", "noise3_scale2.0x_model"][[-1, 0, 1, 2, 3].index(nr_lv)]
mdl_pth = plg_dir + "/models/waifu2x/" + mdl_pname + mdl_fname + ".onnx"
if not os.path.exists(mdl_pth) :
raise vs.Error(f"模块 {func_name} 所请求的模型缺失")

global vsmlrt
if vsmlrt is None :
try :
Expand Down Expand Up @@ -1032,8 +1064,8 @@ def _ffps(fps) :

cut1 = core.std.AssumeFPS(clip=cut0, fpsnum=int(vfps), fpsden=vden)
cut_s = core.mv.Super(clip=cut1, sharp=1, rfilter=4)
cut_b = core.mv.Analyse(super=cut_s, blksize=64, searchparam=0, pelsearch=3, isb=True, _lambda=0, lsad=10000, overlapv=16, badrange=0, search_coarse=4)
cut_f = core.mv.Analyse(super=cut_s, blksize=64, searchparam=0, pelsearch=3, _lambda=0, lsad=10000, overlapv=16, badrange=0, search_coarse=4)
cut_b = core.mv.Analyse(super=cut_s, blksize=64, searchparam=0, pelsearch=3, isb=True, lambda_=0, lsad=10000, overlapv=16, badrange=0, search_coarse=4)
cut_f = core.mv.Analyse(super=cut_s, blksize=64, searchparam=0, pelsearch=3, lambda_=0, lsad=10000, overlapv=16, badrange=0, search_coarse=4)

output = core.mv.BlockFPS(clip=cut1, super=cut_s, mvbw=cut_b, mvfw=cut_f, num=fps_out * 1000, den=vden, mode=2, thscd1=970, thscd2=255, blend=False)
if w_tmp + h_tmp > 0 :
Expand Down Expand Up @@ -1185,7 +1217,7 @@ def RIFE_STD(
raise vs.Error(f"模块 {func_name} 的子参数 sc_mode 的值无效")
if not isinstance(skip, bool) :
raise vs.Error(f"模块 {func_name} 的子参数 skip 的值无效")
if not isinstance(stat_th, (int, float)) or stat_th <= 0.0 :
if not isinstance(stat_th, (int, float)) or stat_th <= 0.0 or stat_th > 60.0 :
raise vs.Error(f"模块 {func_name} 的子参数 stat_th 的值无效")
if gpu not in [0, 1, 2] :
raise vs.Error(f"模块 {func_name} 的子参数 gpu 的值无效")
Expand Down Expand Up @@ -1288,6 +1320,16 @@ def RIFE_NV(
if not hasattr(core, "akarin") :
raise ModuleNotFoundError(f"模块 {func_name} 依赖错误:缺失插件,检查项目 akarin")

plg_dir = os.path.dirname(core.trt.Version()["path"]).decode()
mdl_pname = "rife/" if ext_proc else "rife_v2/"
if t_tta :
mdl_fname = ["rife_v4.6_ensemble", "rife_v4.13_ensemble", "rife_v4.13_lite_ensemble"][[46, 413, 4131].index(model)]
else :
mdl_fname = ["rife_v4.6", "rife_v4.13", "rife_v4.13_lite"][[46, 413, 4131].index(model)]
mdl_pth = plg_dir + "/models/" + mdl_pname + mdl_fname + ".onnx"
if not os.path.exists(mdl_pth) :
raise vs.Error(f"模块 {func_name} 所请求的模型缺失")

global vsmlrt
if vsmlrt is None :
try :
Expand Down Expand Up @@ -1911,6 +1953,12 @@ def DPIR_NR_NV(
if not hasattr(core, "trt") :
raise ModuleNotFoundError(f"模块 {func_name} 依赖错误:缺失插件,检查项目 trt")

plg_dir = os.path.dirname(core.trt.Version()["path"]).decode()
mdl_fname = ["drunet_gray", "drunet_color"][[0, 1].index(model)]
mdl_pth = plg_dir + "/models/dpir/" + mdl_fname + ".onnx"
if not os.path.exists(mdl_pth) :
raise vs.Error(f"模块 {func_name} 所请求的模型缺失")

global vsmlrt
if vsmlrt is None :
try :
Expand Down Expand Up @@ -2508,6 +2556,12 @@ def DPIR_DBLK_NV(
if not hasattr(core, "trt") :
raise ModuleNotFoundError(f"模块 {func_name} 依赖错误:缺失插件,检查项目 trt")

plg_dir = os.path.dirname(core.trt.Version()["path"]).decode()
mdl_fname = ["drunet_deblocking_grayscale", "drunet_deblocking_color"][[2, 3].index(model)]
mdl_pth = plg_dir + "/models/dpir/" + mdl_fname + ".onnx"
if not os.path.exists(mdl_pth) :
raise vs.Error(f"模块 {func_name} 所请求的模型缺失")

global vsmlrt
if vsmlrt is None :
try :
Expand Down Expand Up @@ -2801,7 +2855,7 @@ def UAI_DML(
raise vs.Error(f"模块 {func_name} 的子参数 input 的值无效")
if not isinstance(clamp, bool) :
raise vs.Error(f"模块 {func_name} 的子参数 clamp 的值无效")
if len(model_pth) == 0 :
if len(model_pth) <= 5 :
raise vs.Error(f"模块 {func_name} 的子参数 model_pth 的值无效")
if gpu not in [0, 1, 2] :
raise vs.Error(f"模块 {func_name} 的子参数 gpu 的值无效")
Expand All @@ -2816,6 +2870,11 @@ def UAI_DML(
if not hasattr(core, "akarin") :
raise ModuleNotFoundError(f"模块 {func_name} 依赖错误:缺失插件,检查项目 akarin")

mdl_pth_rel = os.path.join(vsmlrt.models_path, model_pth)
if not os.path.exists(mdl_pth_rel) and not os.path.exists(model_pth) :
raise vs.Error(f"模块 {func_name} 所请求的模型缺失")
mdl_pth = mdl_pth_rel if os.path.exists(mdl_pth_rel) else model_pth

global vsmlrt
if vsmlrt is None :
try :
Expand All @@ -2833,7 +2892,7 @@ def UAI_DML(
if clamp :
clip = core.akarin.Expr(clips=clip, expr="x 0 1 clamp")
be_param = vsmlrt.BackendV2.ORT_DML(device_id=gpu, num_streams=gpu_t, fp16=True)
infer = vsmlrt.inference(clips=clip, network_path=os.path.join(vsmlrt.models_path, model_pth), backend=be_param)
infer = vsmlrt.inference(clips=clip, network_path=mdl_pth, backend=be_param)
output = core.resize.Bilinear(clip=infer, format=fmt_in, matrix_s="709", range=1 if colorlv==0 else None)

return output
Expand Down Expand Up @@ -2863,7 +2922,7 @@ def UAI_NV_TRT(
raise vs.Error(f"模块 {func_name} 的子参数 input 的值无效")
if not isinstance(clamp, bool) :
raise vs.Error(f"模块 {func_name} 的子参数 clamp 的值无效")
if len(model_pth) == 0 :
if len(model_pth) <= 5 :
raise vs.Error(f"模块 {func_name} 的子参数 model_pth 的值无效")
if opt_lv not in [0, 1, 2, 3, 4, 5] :
raise vs.Error(f"模块 {func_name} 的子参数 opt_lv 的值无效")
Expand Down Expand Up @@ -2896,6 +2955,11 @@ def UAI_NV_TRT(
if not hasattr(core, "akarin") :
raise ModuleNotFoundError(f"模块 {func_name} 依赖错误:缺失插件,检查项目 akarin")

mdl_pth_rel = os.path.join(vsmlrt.models_path, model_pth)
if not os.path.exists(mdl_pth_rel) and not os.path.exists(model_pth) :
raise vs.Error(f"模块 {func_name} 所请求的模型缺失")
mdl_pth = mdl_pth_rel if os.path.exists(mdl_pth_rel) else model_pth

global vsmlrt
if vsmlrt is None :
try :
Expand All @@ -2918,7 +2982,7 @@ def UAI_NV_TRT(
num_streams=gpu_t, use_cuda_graph=nv1, use_cublas=nv2, use_cudnn=nv3,
fp16=fp16, force_fp16=False, tf32=True, output_format=1 if fp16 else 0, workspace=None if ws_size < 128 else (ws_size if st_eng else ws_size * 2),
static_shape=st_eng, min_shapes=[0, 0] if st_eng else [64, 64], opt_shapes=None if st_eng else res_opt, max_shapes=None if st_eng else res_max)
infer = vsmlrt.inference(clips=clip, network_path=os.path.join(vsmlrt.models_path, model_pth), backend=be_param)
infer = vsmlrt.inference(clips=clip, network_path=mdl_pth, backend=be_param)
output = core.resize.Bilinear(clip=infer, format=fmt_in, matrix_s="709", range=1 if colorlv==0 else None)

return output
Expand Down

0 comments on commit f0f2192

Please sign in to comment.