forked from feichtenhofer/st-resnet
-
Notifications
You must be signed in to change notification settings - Fork 0
/
STMuNet_test.m
80 lines (65 loc) · 2.62 KB
/
STMuNet_test.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
function STMuNet_test(varargin)
%STMUNET_TEST  Run a stored ST-MulNet model through cnn_train_dag.
%
%   STMUNET_TEST('option', value, ...) performs the following steps:
%     1. Builds the options struct (dataset, split, model and output paths).
%     2. Loads the imdb from opts.imdbPath, or creates it with
%        cnn_setup_data and caches it there when the file is missing.
%     3. Loads the network from opts.model, downloading the file first if
%        it does not exist locally.
%     4. Calls cnn_train_dag with opts.train.train = NaN (presumably this
%        disables the training pass so only validation runs -- confirm
%        against cnn_train_dag).
%
%   Name/value arguments are merged into the options with vl_argparse.
%   Fields of opts.train assigned after the last vl_argparse call are set
%   unconditionally and therefore take precedence over caller-supplied
%   values.

% Shut down a parallel pool if one is already open; never create one here.
pool = gcp('nocreate') ;
if ~isempty(pool)
  delete(pool) ;
end

opts.train.gpus = [ 1 ] ;
opts = cnn_setup_environment(opts);
opts.nSplit = 1 ;
% opts.dataSet = 'hmdb51';
opts.dataSet = 'ucf101';

% Model identifier used for both the model file name and the output folder.
% Alternative: ['ST-MulNet-img50-flow50-final-split=' num2str(opts.nSplit)]
model = ['ST-MulNet-img50-flow152-split=' num2str(opts.nSplit)];

opts.train.memoryMapFile = fullfile(tempdir, 'ramdisk', ['matconvnet' num2str(opts.nSplit ) '.bin']) ;
opts.dataDir = fullfile(opts.dataPath, opts.dataSet) ;
% NOTE(review): the split folder is hard-coded for UCF101; it has to be
% changed by hand when opts.dataSet is switched to another dataset.
opts.splitDir = 'ucf101_splits';
opts.imdbPath = fullfile(opts.dataDir, [opts.dataSet '_split' num2str(opts.nSplit) 'imdb.mat']);
opts.model = fullfile(opts.modelPath, [opts.dataSet model '.mat']) ;
opts.expDir = fullfile(opts.dataDir, [opts.dataSet '-' model]) ;

% First pass: apply the arguments that match fields defined so far and keep
% the remaining ones for the second pass below.
[opts, varargin] = vl_argparse(opts, varargin) ;

% Defaults for the opts.train fields; they must exist before the second
% parse so that callers are able to pass them.
opts.train.saveAllPredScores = 1;
opts.train.denseEval = 1;
opts.train.plotDiagnostics = 0 ;
opts.train.continue = 1 ;
opts.train.prefetch = 1 ;
opts.train.expDir = opts.expDir ;
opts.train.numAugments = 1;
opts.train.frameSample = 'random';
opts = vl_argparse(opts, varargin) ;

% -------------------------------------------------------------------------
%                                                   Database initialization
% -------------------------------------------------------------------------
% Restrict the lookup to files/folders so that variables, builtins or
% classes are not searched.
if exist(opts.imdbPath, 'file')
  imdb = load(opts.imdbPath) ;
  % Overwrite the cached flow directory with the one from the current
  % environment (opts.flowDir is not set in this function; presumably it
  % comes from cnn_setup_environment -- confirm).
  imdb.flowDir = opts.flowDir;
else
  imdb = cnn_setup_data(opts) ;
  save(opts.imdbPath, '-struct', 'imdb', '-v6') ;
end

% -------------------------------------------------------------------------
%                                                    Network initialization
% -------------------------------------------------------------------------
if ~exist(opts.model,'file')
  [modelDir, baseModel] = fileparts(opts.model);
  fprintf('Downloading base model file: %s ...\n', baseModel);
  % Only create the folder when needed; mkdir warns on an existing folder.
  if ~exist(modelDir, 'dir')
    mkdir(modelDir) ;
  end
  urlwrite(...
    ['http://ftp.tugraz.at/pub/feichtenhofer/st-mul/final/' baseModel '.mat'], ...
    opts.model) ;
end

net = load(opts.model) ;
% The file may store the network either directly or wrapped in a 'net' field.
if isfield(net, 'net')
  net = net.net ;
end
net = dagnn.DagNN.loadobj(net);

% Settings below are forced regardless of what the caller passed.
opts.train.augmentation = 'f25noCtr';
opts.train.frameSample = 'temporalStrideRandom';
opts.train.nFramesPerVid = 1;
opts.train.temporalStride = 1:15;
opts.train.valmode = 'temporalStrideRandom';
opts.train.numValFrames = 25 ;
opts.train.saveAllPredScores = 1 ;
opts.train.denseEval = 1;
opts.train.temporalFullConvTest = 1;
opts.train.train = NaN;

net.rebuild() ;
net.conserveMemory = 1 ;

fn = getBatchWrapper_rgbflow(net.meta.normalization, opts.numFetchThreads, opts.train) ;
% One output is requested on purpose so that nargout inside cnn_train_dag
% stays the same as before; the result itself is not used here.
info = cnn_train_dag(net, imdb, fn, opts.train) ; %#ok<NASGU>