-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsentence_model.py
36 lines (24 loc) · 1.04 KB
/
sentence_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
# tips from https://towardsdatascience.com/optimize-pytorch-performance-for-speed-and-memory-efficiency-2022-84f453916ea6
import torch
import torch.nn as nn
import torch.nn.functional as F
class SentenceModel(nn.Module):
    """Two-layer MLP classifier over fixed-size sentence embeddings.

    Architecture: Linear -> GELU -> Dropout -> Linear -> log-softmax.
    The log-softmax output pairs with ``nn.NLLLoss`` during training.
    """

    def __init__(self, embedding_size: int, num_classes: int,
                 hidden_size: int = 32, drop_rate: float = 0.1):
        """Build the classifier head.

        Args:
            embedding_size: Dimensionality of the input sentence embedding.
            num_classes: Number of output classes.
            hidden_size: Width of the hidden layer.
            drop_rate: Dropout probability applied after the activation.
        """
        # Zero-argument super() is the Python 3 idiom; behavior is identical
        # to the old super(SentenceModel, self) form.
        super().__init__()
        self.linear1 = nn.Linear(embedding_size, hidden_size)
        self.dropout = nn.Dropout(drop_rate)
        self.linear2 = nn.Linear(hidden_size, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return per-class log-probabilities for a batch of embeddings.

        Args:
            x: Tensor whose last dimension has size ``embedding_size``.

        Returns:
            Tensor with the same leading dimensions and last dimension
            ``num_classes``, containing log-probabilities (rows of
            ``exp(output)`` sum to 1 along the last axis).
        """
        x = self.linear1(x)
        x = self.linear2(self.dropout(F.gelu(x)))
        return F.log_softmax(x, dim=-1)
class SentenceEmbeddedDataset(torch.utils.data.Dataset):
    """Map-style dataset pairing precomputed sentence embeddings with labels.

    Indexing yields ``(embedding, label)`` tuples taken positionally from the
    two parallel sequences supplied at construction.
    """

    def __init__(self, embeddings, classes, device):
        super().__init__()
        # Parallel sequences: embeddings[i] corresponds to classes[i].
        # `device` is stored for callers' convenience; this class never
        # moves data itself.
        self.embeddings, self.classes, self.device = embeddings, classes, device

    def __len__(self) -> int:
        """Number of (embedding, label) pairs available."""
        return len(self.embeddings)

    def __getitem__(self, idx):
        """Return the ``(embedding, label)`` pair at position *idx*."""
        return (self.embeddings[idx], self.classes[idx])