-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathTwo_Stream_Net.py
44 lines (29 loc) · 1.34 KB
/
Two_Stream_Net.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import torch.nn as nn
import torchvision.models as models
import LoadUCF101Data
class OpticalFlowStreamNet(nn.Module):
    """Optical-flow stream of the two-stream network, built on ResNet-50.

    The ResNet-50 backbone is created without pretrained weights
    (``models.resnet50()``). Its first convolution is replaced so that it
    accepts ``in_channels`` input planes, and its ``fc`` head is replaced by
    a freshly initialised linear layer with ``num_classes`` outputs.

    Args:
        num_classes: Number of output scores. Defaults to 101.
        in_channels: Number of input channels of the replacement ``conv1``.
            When ``None`` (the default), ``LoadUCF101Data.SAMPLE_FRAME_NUM * 2``
            is used, read at construction time. The factor of 2 presumably
            covers the horizontal and vertical flow components per sampled
            frame -- TODO confirm against the data loader.
    """

    def __init__(self, num_classes=101, in_channels=None):
        super(OpticalFlowStreamNet, self).__init__()
        if in_channels is None:
            in_channels = LoadUCF101Data.SAMPLE_FRAME_NUM * 2
        # Attribute name kept unchanged so state_dict keys stay compatible
        # with checkpoints saved by earlier versions of this class.
        self.OpticalFlow_stream = models.resnet50()
        # 64 filters, 7x7 kernel, stride 2, padding 3, no bias; only the
        # number of input channels is configurable.
        self.OpticalFlow_stream.conv1 = nn.Conv2d(
            in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False
        )
        # Take the head width from the backbone rather than hard-coding 2048.
        self.OpticalFlow_stream.fc = nn.Linear(
            in_features=self.OpticalFlow_stream.fc.in_features,
            out_features=num_classes,
        )

    def forward(self, x):
        """Run the flow backbone on ``x`` and return its raw output scores.

        ``x`` presumably has shape (batch, in_channels, H, W) -- TODO confirm.
        """
        return self.OpticalFlow_stream(x)
class RGBStreamNet(nn.Module):
    """RGB stream of the two-stream network, built on ResNet-50.

    The backbone optionally loads ImageNet weights, and its ``fc`` head is
    replaced by a freshly initialised linear layer with ``num_classes``
    outputs.

    Args:
        num_classes: Number of output scores. Defaults to 101.
        pretrained: If True (the default), the backbone is created with
            ImageNet weights (torchvision may download them on first use).
    """

    def __init__(self, num_classes=101, pretrained=True):
        super(RGBStreamNet, self).__init__()
        # torchvision >= 0.13 deprecates ``pretrained=`` in favour of the
        # ``weights=`` enum. IMAGENET1K_V1 is the checkpoint the legacy
        # ``pretrained=True`` flag selects. Older torchvision versions have
        # no enum, so fall back to the legacy keyword there.
        weights_enum = getattr(models, "ResNet50_Weights", None)
        if weights_enum is not None:
            weights = weights_enum.IMAGENET1K_V1 if pretrained else None
            self.RGB_stream = models.resnet50(weights=weights)
        else:
            self.RGB_stream = models.resnet50(pretrained=pretrained)
        # Take the head width from the backbone rather than hard-coding 2048.
        self.RGB_stream.fc = nn.Linear(
            in_features=self.RGB_stream.fc.in_features,
            out_features=num_classes,
        )

    def forward(self, x):
        """Run the RGB backbone on ``x`` and return its raw output scores.

        ``x`` presumably has shape (batch, 3, H, W) -- TODO confirm.
        """
        return self.RGB_stream(x)
class TwoStreamNet(nn.Module):
    """Two-stream classifier that fuses an RGB branch and an optical-flow branch.

    Each branch is applied to its own input. The two outputs are combined by
    element-wise addition; no softmax or averaging is applied here.
    """

    def __init__(self):
        super().__init__()
        # Sub-module names form the state_dict key prefixes; keep them stable.
        self.rgb_branch = RGBStreamNet()
        self.opticalFlow_branch = OpticalFlowStreamNet()

    def forward(self, x_rgb, x_opticalFlow):
        """Return the sum of both branch outputs.

        Args:
            x_rgb: Input passed to the RGB branch.
            x_opticalFlow: Input passed to the optical-flow branch.
        """
        spatial_scores = self.rgb_branch(x_rgb)
        temporal_scores = self.opticalFlow_branch(x_opticalFlow)
        fused_scores = spatial_scores + temporal_scores
        return fused_scores