Skip to content

Commit

Permalink
extend args accpeted by Embedding
Browse files Browse the repository at this point in the history
  • Loading branch information
donglihe-hub committed Aug 9, 2023
1 parent 2c52c3b commit f05f883
Show file tree
Hide file tree
Showing 4 changed files with 7 additions and 5 deletions.
2 changes: 1 addition & 1 deletion libmultilabel/nn/networks/kim_cnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def __init__(
activation="relu",
):
super(KimCNN, self).__init__()
self.embedding = Embedding(embed_vecs, embed_dropout)
self.embedding = Embedding(embed_vecs, dropout=embed_dropout)
self.encoder = CNNEncoder(
embed_vecs.shape[1], filter_sizes, num_filter_per_size, activation, post_encoder_dropout, num_pool=1
)
Expand Down
2 changes: 1 addition & 1 deletion libmultilabel/nn/networks/labelwise_attention_networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class LabelwiseAttentionNetwork(ABC, nn.Module):

def __init__(self, embed_vecs, num_classes, embed_dropout, encoder_dropout, post_encoder_dropout, hidden_dim):
super(LabelwiseAttentionNetwork, self).__init__()
self.embedding = Embedding(embed_vecs, embed_dropout)
self.embedding = Embedding(embed_vecs, dropout=embed_dropout)
self.encoder = self._get_encoder(embed_vecs.shape[1], encoder_dropout, post_encoder_dropout)
self.attention = self._get_attention()
self.output = LabelwiseLinearOutput(hidden_dim, num_classes)
Expand Down
6 changes: 4 additions & 2 deletions libmultilabel/nn/networks/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,14 @@ class Embedding(nn.Module):
Args:
embed_vecs (torch.Tensor): The pre-trained word vectors of shape (vocab_size, embed_dim).
freeze (bool): If True, the tensor does not get updated in the learning process.
Equivalent to embedding.weight.requires_grad = False. Default: False.
dropout (float): The dropout rate of the word embedding. Defaults to 0.2.
"""

def __init__(self, embed_vecs, dropout=0.2):
def __init__(self, embed_vecs, freeze=False, dropout=0.2):
super(Embedding, self).__init__()
self.embedding = nn.Embedding.from_pretrained(embed_vecs, freeze=False, padding_idx=0)
self.embedding = nn.Embedding.from_pretrained(embed_vecs, freeze=freeze, padding_idx=0)
self.dropout = nn.Dropout(dropout)

def forward(self, input):
Expand Down
2 changes: 1 addition & 1 deletion libmultilabel/nn/networks/xml_cnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def __init__(
activation="relu",
):
super(XMLCNN, self).__init__()
self.embedding = Embedding(embed_vecs, embed_dropout)
self.embedding = Embedding(embed_vecs, dropout=embed_dropout)
self.encoder = CNNEncoder(embed_vecs.shape[1], filter_sizes, num_filter_per_size, activation, num_pool=num_pool)
total_output_size = len(filter_sizes) * num_filter_per_size * num_pool
self.linear1 = nn.Linear(total_output_size, hidden_dim)
Expand Down

0 comments on commit f05f883

Please sign in to comment.