e2e_thumos_videomaev2_g_768x2_224_adapter.py
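
# AdaTAD end-to-end config: a VideoMAEv2-giant backbone fine-tuned with
# adapters for temporal action detection on THUMOS-14, using 768-frame
# windows at 2x temporal scale and 224x224 crops.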
_base_ = ["e2e_thumos_videomae_s_768x1_160_adapter.py"]
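
# each detection window covers 768 frames; with scale_factor=2 the backbone
# receives 768 * 2 = 1536 raw frames per window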
window_size = 768
scale_factor = 2
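# the backbone processes 16 frames per chunk: 768 * 2 / 16 = 96 chunks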
chunk_num = window_size * scale_factor // 16

dataset = dict(
    train=dict(
        pipeline=[
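            # training: randomly truncate a window_size-frame clip from the
            # video and apply flips, crops, and color augmentations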
dict(type="PrepareVideoInfo", format="mp4"),
dict(type="mmaction.DecordInit", num_threads=4),
dict(
type="LoadFrames",
num_clips=1,
method="random_trunc",
trunc_len=window_size,
trunc_thresh=0.75,
crop_ratio=[0.9, 1.0],
scale_factor=scale_factor,
),
dict(type="mmaction.DecordDecode"),
dict(type="mmaction.Resize", scale=(-1, 256)),
dict(type="mmaction.RandomResizedCrop"),
dict(type="mmaction.Resize", scale=(224, 224), keep_ratio=False),
dict(type="mmaction.Flip", flip_ratio=0.5),
dict(type="mmaction.ImgAug", transforms="default"),
dict(type="mmaction.ColorJitter"),
dict(type="mmaction.FormatShape", input_format="NCTHW"),
dict(type="ConvertToTensor", keys=["imgs", "gt_segments", "gt_labels"]),
dict(type="Collect", inputs="imgs", keys=["masks", "gt_segments", "gt_labels"]),
],
),
    val=dict(
        window_size=window_size,
        pipeline=[
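            # validation: sliding windows over the full video, center-cropped,
            # no augmentation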
dict(type="PrepareVideoInfo", format="mp4"),
dict(type="mmaction.DecordInit", num_threads=4),
dict(type="LoadFrames", num_clips=1, method="sliding_window", scale_factor=scale_factor),
dict(type="mmaction.DecordDecode"),
dict(type="mmaction.Resize", scale=(-1, 224)),
dict(type="mmaction.CenterCrop", crop_size=224),
dict(type="mmaction.FormatShape", input_format="NCTHW"),
dict(type="ConvertToTensor", keys=["imgs", "gt_segments", "gt_labels"]),
dict(type="Collect", inputs="imgs", keys=["masks", "gt_segments", "gt_labels"]),
],
),
    test=dict(
        window_size=window_size,
        pipeline=[
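            # testing: same sliding-window sampling as validation, but no
            # ground-truth keys are collected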
dict(type="PrepareVideoInfo", format="mp4"),
dict(type="mmaction.DecordInit", num_threads=4),
dict(type="LoadFrames", num_clips=1, method="sliding_window", scale_factor=scale_factor),
dict(type="mmaction.DecordDecode"),
dict(type="mmaction.Resize", scale=(-1, 224)),
dict(type="mmaction.CenterCrop", crop_size=224),
dict(type="mmaction.FormatShape", input_format="NCTHW"),
dict(type="ConvertToTensor", keys=["imgs"]),
dict(type="Collect", inputs="imgs", keys=["masks"]),
],
),
)

model = dict(
    backbone=dict(
        backbone=dict(
            patch_size=14,
            embed_dims=1408,
            depth=40,
            num_heads=16,
            mlp_ratio=48 / 11,
            total_frames=window_size * scale_factor,
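            # adapters are inserted only into the last 20 of the 40 transformer blocks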
            adapter_index=list(range(20, 40)),
        ),
        custom=dict(
            pretrain="pretrained/vit-giant-p14_videomaev2-hybrid_pt_1200e_k710_ft_my.pth",
            pre_processing_pipeline=[
                dict(type="Rearrange", keys=["frames"], ops="b n c (t1 t) h w -> (b t1) n c t h w", t1=chunk_num),
            ],
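            # average features over space, stitch the chunks back along the
            # time axis, and interpolate to window_size (768) feature steps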
            post_processing_pipeline=[
                dict(type="Reduce", keys=["feats"], ops="b n c t h w -> b c t", reduction="mean"),
                dict(type="Rearrange", keys=["feats"], ops="(b t1) c t -> b c (t1 t)", t1=chunk_num),
                dict(type="Interpolate", keys=["feats"], size=window_size),
            ],
        ),
    ),
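    # the detector's input projection matches the giant backbone's 1408-dim features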
    projection=dict(in_channels=1408),
)
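
# train for 50 epochs, checkpoint every 2 epochs, and evaluate every 2 epochs
# starting from epoch 37; val_loss_interval=-1 skips validation-loss computation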
workflow = dict(
    logging_interval=50,
    checkpoint_interval=2,
    val_loss_interval=-1,
    val_eval_interval=2,
    val_start_epoch=37,
    end_epoch=50,
)
work_dir = "exps/thumos/adatad/e2e_actionformer_videomaev2_g_768x2_224_adapter"
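
# Usage sketch (an assumption, not part of the original file): this config uses
# mmengine-style `_base_` inheritance, so with mmengine installed and the base
# file in the same directory, the merged config can be inspected with:
#
#   from mmengine.config import Config
#   cfg = Config.fromfile("e2e_thumos_videomaev2_g_768x2_224_adapter.py")
#   print(cfg.chunk_num)  # 96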