Fix discrete inference, fixes #1
Thomas Hehn committed May 19, 2022
1 parent be4cf69 commit 5fa0255
Showing 5 changed files with 29 additions and 14 deletions.
e2edt/core.py (4 changes: 2 additions & 2 deletions)

@@ -322,8 +322,8 @@ def forward(self, X, discrete=False):
             gating = self.gating(X)
             if discrete:
                 gating = gating.round()
-            return gating * self.right_child(X)\
-                   + (1 - gating) * self.left_child(X)
+            return gating * self.right_child(X, discrete=discrete)\
+                   + (1 - gating) * self.left_child(X, discrete=discrete)
         return self.leaf_predictions

     def split(self, initial_steepness):
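This hunk is the fix named in the commit title: before it, only the node on which forward() was first called rounded its gating value, while all descendants were still evaluated with soft gates because the discrete flag was never passed down the recursion. A self-contained toy sketch of the fixed behavior (the ToyNode class and its sigmoid gate are illustrative, not the e2edt implementation):

    import torch

    class ToyNode:
        """Minimal soft decision tree node mirroring the recursion in e2edt/core.py."""
        def __init__(self, w, left=None, right=None):
            self.w, self.left, self.right = w, left, right  # leaf iff left is None

        def forward(self, X, discrete=False):
            if self.left is None:
                return self.w                       # leaf prediction
            gating = torch.sigmoid(X @ self.w)      # soft gate in (0, 1)
            if discrete:
                gating = gating.round()             # hard 0/1 routing
            # The commit's fix: forward `discrete` so children round their gates too.
            return gating * self.right.forward(X, discrete=discrete) \
                 + (1 - gating) * self.left.forward(X, discrete=discrete)

    root = ToyNode(torch.randn(4), ToyNode(torch.tensor(0.)), ToyNode(torch.tensor(1.)))
    X = torch.randn(8, 4)
    print(root.forward(X))                  # soft: values strictly between 0 and 1
    print(root.forward(X, discrete=True))   # hard: exactly 0.0 or 1.0 per sample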
e2edt/forest.py (4 changes: 4 additions & 0 deletions)

@@ -160,6 +160,10 @@ def refine(self, train_set, epochs=100, algo='EM', weight_decay=0.0,
             self.new_trees = []
             yield total_f_loss

+    def flush_trees(self):
+        self.trees += self.new_trees
+        self.new_trees = []
+
     def refine_noopt(self, train_set, epochs=100, algo='EM', weight_decay=0.0):
         refiners = [tree.refine(train_set, epochs, algo, weight_decay)\
                     for tree in self.new_trees]
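The new flush_trees() method merges trees still pending in new_trees into the main trees list. A hedged usage sketch (method and variable names follow e2edt/forest.py and run_forest_run.py; the surrounding script setup is assumed):

    model.add_trees(n_tree, args)      # freshly added trees land in model.new_trees
    model.fit_greedy(train_set, args)
    model.flush_trees()                # move new_trees into model.trees
    print(len(model.trees))            # now counts every fitted tree, so the
                                       # per-tree statistics divide correctly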
run.py (3 changes: 3 additions & 0 deletions)

@@ -108,6 +108,9 @@
 non_linear_module = None
 if args.NN == 'TinyCNN':
     non_linear_module = TinyCNN
+    if args.data != 'MNIST':
+        raise ValueError("The layer dimensions are hardcoded for MNIST.")
+    n_features = 50

 regularizer = None
 if args.reg > 0:
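The new guard fails fast because TinyCNN's fully connected layer assumes 28x28 MNIST inputs; any other image size would only crash later with an opaque shape error. The arithmetic behind the hardcoded dimensions, as a sketch (assuming 1x28x28 inputs as in MNIST):

    h = 28                  # MNIST spatial size
    h = (h - 5 + 1) // 2    # conv1: 5x5 kernel, no padding, then 2x2 max-pool -> 12
    h = (h - 5 + 1) // 2    # conv2: same pattern -> 4
    print(6 * h * h)        # 96 == 6*16, the hardcoded fc1 input size in TinyCNN
    # n_features = 50 matches fc1's output size (TinyCNN's output_size attribute).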
run_forest_run.py (5 changes: 5 additions & 0 deletions)

@@ -143,8 +143,10 @@
 for n_tree in n_trees:
     # --- setup forest for greedy training
     model.add_trees(n_tree, args)
+    print("Number of trees: {}".format(len(model.trees)))
     model.fit_greedy(train_set, args)
     print("Done fitting.")
+    print("Number of trees: {}".format(len(model.trees)))

     #~# Nasty hack to count leaves without implementing a function.
     #~count = [0]

@@ -200,9 +202,12 @@
     model_filename = "{}_{}.pth".format(TIMESTAMP, argstr)
     print("Saving model to {}".format(model_filename))
     model.save(model_filename)
+else:
+    model.flush_trees()

 total_nodes = model.count_nodes()
 leaf_nodes = model.count_leaf_nodes()
+print("Number of trees: {}".format(len(model.trees)))
 print("Total nodes: {}, {} per tree".
       format(total_nodes, float(total_nodes)/len(model.trees)))
 print("Leaf nodes: {}, {} per tree".
utils/TinyCNN.py (27 changes: 15 additions & 12 deletions)

@@ -3,24 +3,27 @@
 import torch.nn.functional as F

 class TinyCNN(nn.Module):
+    """
+    A small neural network architecture to extract features from MNIST samples.
+    """
     def __init__(self):
         super(TinyCNN, self).__init__()
-        self.conv1 = nn.Conv2d(1, 3, padding=2, kernel_size=5)
-        self.conv2 = nn.Conv2d(3, 3, padding=2, kernel_size=5)
+        self.conv1 = nn.Conv2d(1, 3, kernel_size=5)
+        self.conv2 = nn.Conv2d(3, 6, kernel_size=5)
         #self.conv2_drop = nn.Dropout2d()
-        #self.fc1 = nn.Linear(6*16, 50) # input size = 50
-        #self.output_size = 6*16 #50
-        self.output_size = 784*3 #50
+        self.fc1 = nn.Linear(6*16, 50) # input size = 50
+        self.output_size = 50
+        #self.output_size = 784*3 #50

     def forward(self, X):
-        X = F.relu(self.conv1(X))
-        X = F.relu(self.conv2(X))
-        #X = F.relu(F.max_pool2d(self.conv1(X), 2))
-        #X = F.relu(F.max_pool2d(self.conv2(X), 2))
+        #X = F.relu(self.conv1(X))
+        #X = F.relu(self.conv2(X))
+        X = F.relu(F.max_pool2d(self.conv1(X), 2))
+        X = F.relu(F.max_pool2d(self.conv2(X), 2))
         #X = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(X)), 2))
-        #X = X.view(-1, 6*16)
-        X = X.view(-1, 784*3)
-        #X = F.relu(self.fc1(X))
+        X = X.view(X.shape[0], -1)
+        #X = X.view(-1, 784*3)
+        X = F.relu(self.fc1(X))
         return X
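The rewrite swaps the padded, pool-free variant (which flattened to 784*3 features) for the pooled variant with a fully connected head: conv1 takes 28x28 to 24x24, pooling halves it to 12x12, conv2 and pooling reduce it to 6 channels of 4x4, and fc1 maps the 96 flattened features to 50. A quick shape check of the revised architecture (a sanity sketch; the import path is assumed to match the repository layout):

    import torch
    from utils.TinyCNN import TinyCNN   # assumed import path

    net = TinyCNN()
    X = torch.zeros(8, 1, 28, 28)       # a batch of 8 MNIST-shaped images
    features = net(X)
    print(features.shape)               # torch.Size([8, 50])
    assert features.shape == (8, net.output_size)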

