Skip to content

Commit

Permalink
Merge pull request FedML-AI#66 from chenwanqq/patch-1
Browse files Browse the repository at this point in the history
Update cnn.py

Former-commit-id: d64409c
  • Loading branch information
chaoyanghe authored Nov 19, 2020
2 parents dddce53 + 77d7d06 commit 7b663ee
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions fedml_api/model/cv/cnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def __init__(self, only_digits=True):
self.linear_1 = nn.Linear(3136, 512)
self.linear_2 = nn.Linear(512, 10 if only_digits else 62)
self.relu = nn.ReLU()
self.softmax = nn.Softmax(dim=1)
#self.softmax = nn.Softmax(dim=1)

def forward(self, x):
x = torch.unsqueeze(x, 1)
Expand All @@ -64,7 +64,8 @@ def forward(self, x):
x = self.max_pooling(x)
x = self.flatten(x)
x = self.relu(self.linear_1(x))
x = self.softmax(self.linear_2(x))
x = self.linear_2(x)
#x = self.softmax(self.linear_2(x))
return x


Expand Down Expand Up @@ -120,7 +121,7 @@ def __init__(self, only_digits=True):
self.dropout_2 = nn.Dropout(0.5)
self.linear_2 = nn.Linear(128, 10 if only_digits else 62)
self.relu = nn.ReLU()
self.softmax = nn.Softmax(dim=1)
#self.softmax = nn.Softmax(dim=1)

def forward(self, x):
x = torch.unsqueeze(x, 1)
Expand All @@ -131,5 +132,6 @@ def forward(self, x):
x = self.flatten(x)
x = self.relu(self.linear_1(x))
x = self.dropout_2(x)
x = self.softmax(self.linear_2(x))
x = self.linear_2(x)
#x = self.softmax(self.linear_2(x))
return x

0 comments on commit 7b663ee

Please sign in to comment.