Skip to content

Commit

Permalink
naming and description
Browse files Browse the repository at this point in the history
  • Loading branch information
mgrankin authored and Optimox committed Feb 14, 2020
1 parent 014f27a commit 939f01c
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions pytorch_tabnet/tab_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,14 +42,14 @@ def forward(self, x):
return res


class DryTabNet(torch.nn.Module):
class TabNetNoEmbeddings(torch.nn.Module):
def __init__(self, input_dim, output_dim,
n_d=8, n_a=8,
n_steps=3, gamma=1.3,
n_independent=2, n_shared=2, epsilon=1e-15,
virtual_batch_size=128, momentum=0.02):
"""
Defines the essence of TabNet network
Defines main part of the TabNet network without the embedding layers.
Parameters
----------
Expand All @@ -75,7 +75,7 @@ def __init__(self, input_dim, output_dim,
- epsilon: float
Avoid log(0), this should be kept very low
"""
super(DryTabNet, self).__init__()
super(TabNetNoEmbeddings, self).__init__()
self.input_dim = input_dim
self.output_dim = output_dim
self.n_d = n_d
Expand Down Expand Up @@ -228,8 +228,9 @@ def __init__(self, input_dim, output_dim, n_d=8, n_a=8,
else:
self.post_embed_dim = self.input_dim + np.sum(cat_emb_dim) - len(cat_emb_dim)
self.post_embed_dim = np.int(self.post_embed_dim)
self.tabnet = DryTabNet(self.post_embed_dim, output_dim, n_d, n_a, n_steps, gamma,
n_independent, n_shared, epsilon, virtual_batch_size, momentum)
self.tabnet = TabNetNoEmbeddings(self.post_embed_dim, output_dim, n_d, n_a, n_steps,
gamma, n_independent, n_shared, epsilon,
virtual_batch_size, momentum)
self.initial_bn = BatchNorm1d(self.post_embed_dim, momentum=0.01)

# Defining device
Expand Down

0 comments on commit 939f01c

Please sign in to comment.