-
Notifications
You must be signed in to change notification settings - Fork 1
/
run_demo.py
39 lines (31 loc) · 1012 Bytes
/
run_demo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
# @File: run_demo.py
# @Project: SceneTracker
# @Author : wangbo
# @Time : 2024.07.12
import numpy as np
import cv2
import torch
from model.model_scenetracker import SceneTracker
import run_test
def read_mp4(name_path):
    """Decode every frame of a video file into a list of RGB images.

    Args:
        name_path: Path to the video file; passed straight to
            ``cv2.VideoCapture``.

    Returns:
        list: One image per decoded frame, in file order, converted from
        OpenCV's BGR channel order to RGB. The list is empty if the file
        could not be opened or yields no frames; no error is raised in that
        case, so callers should check for an empty result.
    """
    vidcap = cv2.VideoCapture(name_path)
    frames = []
    try:
        while vidcap.isOpened():
            ret, frame = vidcap.read()
            if not ret:
                # read() reports False once no further frame can be grabbed.
                break
            # OpenCV decodes to BGR; downstream code expects RGB.
            frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    finally:
        # Always free the capture handle, even if decoding/conversion raises.
        vidcap.release()
    return frames
print('SceneTracker demo start...')
model = SceneTracker()

# Key renames applied to the checkpoint before loading: each entry is
# [old_prefix, new_prefix]. The 'module.' prefix is presumably added by a
# DataParallel/DistributedDataParallel wrapper at save time -- confirm.
pre_replace_list = [['module.', '']]

# Load onto CPU first; the model is moved to the GPU below, so the saved
# device of the checkpoint tensors does not matter.
checkpoint = torch.load('exp/0-pretrain/scenetracker-odyssey-200k.pth', map_location='cpu')
for old_prefix, new_prefix in pre_replace_list:
    # Only rewrite a *leading* prefix: str.replace would also mangle keys that
    # merely contain the text (e.g. 'submodule.weight' -> 'subweight').
    checkpoint = {
        (new_prefix + k[len(old_prefix):] if k.startswith(old_prefix) else k): v
        for k, v in checkpoint.items()
    }
model.load_state_dict(checkpoint, strict=True)
print(f"Parameter Count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")

model.eval().cuda()
run_test.validate_odyssey(model, split='demo')
print('Success!!!')