-
Notifications
You must be signed in to change notification settings - Fork 7
/
mlm.py
69 lines (49 loc) · 2.12 KB
/
mlm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import torch
from torch.utils.data import DataLoader
from transformers import BertForMaskedLM, BertTokenizer
import data_prep
# Select GPU when available; every model/tensor below is moved to this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Pretrained uncased BERT tokenizer and masked-LM model.
# NOTE: loaded at import time — triggers a network download on first run.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertForMaskedLM.from_pretrained("bert-base-uncased")
def train(dataset, model, batch_size=32, epochs=2, lr=2e-5):
    """Fine-tune a masked-language model on the given dataset.

    Args:
        dataset: torch Dataset yielding dicts with "input_ids",
            "attention_mask" and "mlm_labels" tensors (see data_prep).
        model: a BertForMaskedLM-compatible model to fine-tune in place.
        batch_size: mini-batch size for the DataLoader (default 32).
        epochs: number of full passes over the dataset (default 2).
        lr: AdamW learning rate (default 2e-5).

    Side effects:
        Moves `model` to the module-level `device` and saves a checkpoint
        to models/URLTran-BERT-{epoch} after every epoch.
    """
    # stage data for training
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    # set model to train
    model.to(device)
    model.train()
    # initialize optimizer
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    for epoch in range(epochs):
        total_loss = 0.0
        num_batches = 0
        for batch in loader:
            optimizer.zero_grad()
            # prep data for predict step: randomly mask tokens; mlm_labels
            # carry the original ids the model must recover.
            masked_inputs = data_prep.masking_step(batch["input_ids"]).to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["mlm_labels"].to(device)
            outputs = model(masked_inputs, attention_mask=attention_mask, labels=labels)
            loss = outputs.loss
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            num_batches += 1
        # Report the mean loss over the whole epoch; the original printed only
        # the final batch's loss, a noisy and misleading progress signal.
        avg_loss = total_loss / max(num_batches, 1)
        print(f"Epoch: {epoch} Loss: {avg_loss}")
        model.save_pretrained(f"models/URLTran-BERT-{epoch}")
def predict_mask(url, tokenizer, model):
    """Mask random tokens of a URL and predict them with the model.

    Args:
        url: the URL string to tokenize, mask, and reconstruct.
        tokenizer: tokenizer matching the model (e.g. BertTokenizer).
        model: a BertForMaskedLM-compatible model.

    Returns:
        A tuple (masked_inputs, output_ids): the masked input-id tensor and
        the list of token ids the model predicts for every position.
    """
    inputs = data_prep.preprocess(url, tokenizer)
    masked_inputs = data_prep.masking_step(inputs["input_ids"]).to(device)
    # Fix: the inputs were moved to `device` but the model was not, which
    # fails on CUDA when predict_mask is called without train() first.
    model.to(device)
    # Fix: switch to eval mode so dropout is disabled during inference.
    model.eval()
    with torch.no_grad():
        predictions = model(masked_inputs)
    # argmax directly over the logits; the previous softmax was redundant
    # since softmax is monotonic and cannot change the argmax.
    output_ids = torch.argmax(predictions.logits[0], dim=-1).tolist()
    return masked_inputs, output_ids
if __name__ == "__main__":
    # Fine-tune BERT on the URL corpus, then run a small inference demo.
    data_path = "data/final_data.csv"
    url_dataset = data_prep.URLTranDataset(data_path, tokenizer)
    train(url_dataset, model)

    # Example Inference
    url = "huggingface.co/docs/transformers/task_summary"
    masked_ids, predicted_ids = predict_mask(url, tokenizer, model)
    # Decode both sequences and strip the spaces BERT inserts between subwords.
    masked_input = tokenizer.decode(masked_ids[0].tolist()).replace(" ", "")
    prediction = tokenizer.decode(predicted_ids).replace(" ", "")
    print(f"Masked Input: {masked_input}")
    print(f"Predicted Output: {prediction}")