forked from ant-research/VCSL
-
Notifications
You must be signed in to change notification settings - Fork 0
/
run_video_sim.py
113 lines (91 loc) · 5.23 KB
/
run_video_sim.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
#!/usr/bin/env python3
# Copyright (c) Ant Group, Inc.
"""
Codes for [CVPR2022] VCSL paper [https://github.com/alipay/VCSL].
This is the script for obtaining frame-to-frame similarity map between input video feature pairs.
The similarity map is the common input for different temporal alignment methods.
"input-root" in this script is the frame feature extracted from videos in data/videos_url_uuid.csv.
We also provide the extracted features link in data/vcsl_features.txt
Please cite the following publications if you plan to use our codes or the results for your research:
{
1. He S, Yang X, Jiang C, et al. A Large-scale Comprehensive Dataset and Copy-overlap Aware Evaluation
Protocol for Segment-level Video Copy Detection[C]//Proceedings of the IEEE/CVF Conference on Computer
Vision and Pattern Recognition. 2022: 21086-21095.
2. Jiang C, Huang K, He S, et al. Learning segment similarity and alignment in large-scale content based
video retrieval[C]//Proceedings of the 29th ACM International Conference on Multimedia. 2021: 1618-1626.
}
@author: Sifeng He and Xudong Yang
@email [sifeng.hsf@antgroup.com, jiegang.yxd@antgroup.com]
"""
import argparse
import os
from itertools import islice

import pandas as pd
from loguru import logger
from torch.utils.data import DataLoader

from vcsl import *
if __name__ == '__main__':
    # CLI driver: compute a frame-to-frame similarity map for every
    # (query, reference) video-feature pair and persist each map as a .npy file.
    parser = argparse.ArgumentParser()
    parser.add_argument("--query-file", "-Q", type=str, help="data file")
    parser.add_argument("--reference-file", "-G", type=str, help="data file")
    parser.add_argument("--pair-file", type=str, help="data file")
    parser.add_argument("--data-file", type=str, help="data file")
    parser.add_argument("--input-store", type=str, help="store of input feature data: oss|local", default="oss")
    parser.add_argument("--input-root", type=str, help="root path of input feature data", default="")
    parser.add_argument("--oss-config", type=str, default='~/ossutilconfig-copyright', help="url path")
    parser.add_argument("--batch-size", "-b", type=int, default=32, help="batch size")
    parser.add_argument("--data-workers", type=int, default=16, help="data workers")
    parser.add_argument("--request-workers", type=int, default=4, help="data workers")
    parser.add_argument("--output-workers", type=int, default=4, help="oss upload workers")
    parser.add_argument("--output-store", type=str, help="store of output data: oss|local")
    parser.add_argument("--output-root", type=str, help="output root")
    parser.add_argument("--similarity-type", default='cos', type=str, help="cos or chamfer")
    # NOTE(review): store_false means the value defaults to True and passing
    # --consume turns it OFF -- counterintuitive, but preserved so existing
    # command lines keep their current behavior. It is also unused below;
    # presumably consumed by downstream code -- TODO confirm.
    parser.add_argument('--consume', action='store_false', help="find exist npy file and consume")
    parser.add_argument('--device', type=int, default=0, help="cuda device, available for gpu")
    args = parser.parse_args()

    pairs, files_dict, query, reference = None, None, None, None
    # Created eagerly regardless of store type; presumably acts as an OSS
    # connectivity/credentials check -- TODO confirm it is safe to skip for
    # fully local runs.
    bucket = create_oss_bucket(args.oss_config)

    if args.pair_file and args.data_file:
        # Pair mode: explicit (query_id, reference_id) pairs plus a
        # uuid -> feature-path lookup table.
        df = pd.read_csv(args.pair_file)
        pairs = df[['query_id', 'reference_id']].values.tolist()
        files_dict = pd.read_csv(args.data_file, usecols=['uuid', 'path'], index_col='uuid')
        files_dict = {idx: r['path'] for idx, r in files_dict.iterrows()}
    else:
        # Cross mode: every query against every reference.
        query = pd.read_csv(args.query_file)
        query = query[['uuid', 'path']].values.tolist()
        reference = pd.read_csv(args.reference_file)
        reference = reference[['uuid', 'path']].values.tolist()

    config = dict()
    if args.input_store == 'oss':
        config['oss_config'] = args.oss_config

    # Features are stored one .npy per video, keyed by uuid.
    dataset = PairDataset(query_list=query,
                          gallery_list=reference,
                          pair_list=pairs,
                          file_dict=files_dict,
                          root=args.input_root,
                          store_type=args.input_store,
                          trans_key_func=lambda x: x + ".npy",
                          data_type="numpy",
                          **config)
    logger.info(f"Data to run {len(dataset)}")

    # Identity collate_fn: keep the raw per-pair tuples instead of stacking
    # variable-length feature arrays into tensors.
    loader = DataLoader(dataset, collate_fn=lambda x: x,
                        batch_size=args.batch_size,
                        num_workers=args.data_workers)

    model = VideoSimMapModel(concurrency=args.request_workers)

    # Output store falls back to the input store when not given explicitly.
    output_store = args.input_store if args.output_store is None else args.output_store
    output_config = dict(oss_config=args.oss_config) if output_store == 'oss' else dict()
    writer_pool = AsyncWriter(pool_size=args.output_workers,
                              store_type=output_store,
                              data_type=DataType.NUMPY.type_name,
                              **output_config)

    if output_store == 'local':
        # exist_ok=True makes the previous os.path.exists() pre-check
        # redundant (and avoids a TOCTOU race between check and create).
        os.makedirs(args.output_root, exist_ok=True)

    # The original wrapped the loader in islice(loader, 0, None), which is a
    # no-op; plain iteration is equivalent.
    for batch_data in loader:
        logger.info("data cnt: {}", len(batch_data))
        batch_result = model.forward(batch_data,
                                     normalize_input=False,
                                     similarity_type=args.similarity_type,
                                     device=args.device)
        logger.info("result cnt: {}", len(batch_result))

        # Persist each similarity map as "<reference_id>-<query_id>.npy".
        for r_id, q_id, result in batch_result:
            key = os.path.join(args.output_root, f"{r_id}-{q_id}.npy")
            writer_pool.consume((key, result))

    # Flush and join the async writers before exiting.
    writer_pool.stop()