fresh_support_model.py
# Obtain and save a subset (see the code) of the attentions and logits for later use.
# $ python fresh_support_model.py [name] [artifact_path] [tuning_type]
# example: $ python fresh_support_model.py FT3 artifacts/finetuned_BERT:v0/misty-dust-2.pth FT
import argparse
import os
import torch
import transformers
from datasets import Dataset
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import (BertForSequenceClassification, BertTokenizer,
                          DataCollatorWithPadding)
# 0. argparse
parser = argparse.ArgumentParser()
parser.add_argument("name", help='path to save attentions and logits')
parser.add_argument("artifact_path", help="provide path to the artifact to load")
parser.add_argument("tuning_type", help="provide which tuning method was used for the artifact. Should be one of 'FT', 'Adapter', 'LoRA'")
args = parser.parse_args()
# 1. load model and load weights from artifact
## load base model
print('Cuda is available:', torch.cuda.is_available())
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model = BertForSequenceClassification.from_pretrained('bert-base-multilingual-cased', num_labels=2).to(device)
## function to check env (which flavor of the transformers library is installed)
def is_adapter_transformer():
    """Check whether the current env uses 'adapter-transformers' rather than plain 'transformers'."""
    return 'AdapterConfig' in dir(transformers)
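## Note (assumption behind the check above): the 'adapter-transformers' fork exposes
## AdapterConfig at the top level of the transformers namespace, while the plain
## 'transformers' package this script targets does not, so its presence is used as a proxy.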
## load weights from the provided artifact
if args.tuning_type == 'FT':
    assert not is_adapter_transformer()
    model.load_state_dict(torch.load(args.artifact_path, map_location=device))
elif args.tuning_type == 'Adapter':
    assert is_adapter_transformer()
    model.load_adapter(args.artifact_path, set_active=True)
elif args.tuning_type == 'LoRA':
    assert not is_adapter_transformer()
    model.load_state_dict(torch.load(args.artifact_path, map_location=device), strict=False)
else:
    parser.error(f"tuning_type must be one of 'FT', 'Adapter', 'LoRA' (got {args.tuning_type!r})")
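## Note on the branches above (as read from the calls): 'FT' expects a full state_dict
## checkpoint (.pth), 'Adapter' expects an adapter checkpoint loadable via load_adapter(),
## and 'LoRA' loads a partial state_dict with strict=False so that only the saved
## (LoRA-updated) parameters overwrite the base weights.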
# 2. model inference
## load dataset and build dataloader with dynamic padding
kr3_tokenized = Dataset.load_from_disk('kr3_tokenized')
kr3_tokenized.set_format(type='torch')
tokenizer = BertTokenizer.from_pretrained('bert-base-multilingual-cased')
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
dloader = DataLoader(kr3_tokenized, batch_size=32, collate_fn=data_collator)  # Important not to shuffle.
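## Shuffling stays off so that the batch_{i}.pt files written below remain aligned with the
## row order of kr3_tokenized when they are reloaded later.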
## inference
model = model.to(device)  # ensure all (possibly newly loaded) weights are on the device
model.eval()
with torch.no_grad():
    for i, batch in enumerate(tqdm(dloader)):
        # forward
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**batch, output_attentions=True)
        # full attn weights
        # attentions.size() = [batch_size, num_layers, num_heads, seq_len, seq_len]
        attentions = torch.stack([attn.detach() for attn in outputs.attentions], dim=1).to(device='cpu')
        # process attn: mean over heads, attn w.r.t. the first token ([CLS]). I'd love to store all, but time and memory...
        # attentions.size() = [batch_size, num_layers, seq_len]; seq_len is the max seq_len within the batch.
        attentions = attentions.mean(dim=2)[:, :, 0, :]
        # logits
        # logits.size() = [batch_size, 2]
        logits = outputs.logits.detach().to(device='cpu')
        # save attentions and logits (one file per batch)
        dir_name = f'outputs/{args.name}'
        if not os.path.isdir(dir_name):
            os.makedirs(dir_name)
            os.makedirs(f'{dir_name}/attentions')
            os.makedirs(f'{dir_name}/logits')
        torch.save(attentions, f'{dir_name}/attentions/batch_{i}.pt')
        torch.save(logits, f'{dir_name}/logits/batch_{i}.pt')
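# Later use (a minimal sketch, not executed by this script): the per-batch files can be
# reloaded in order; logits concatenate back into one [num_examples, 2] tensor, while
# attentions are padded per batch (seq_len differs across batches), so keeping them as a
# list is simpler. 'FT3' below is just the example run name from the header comment.
#
#   import glob
#   files = sorted(glob.glob('outputs/FT3/logits/batch_*.pt'),
#                  key=lambda p: int(p.split('_')[-1].split('.')[0]))
#   logits = torch.cat([torch.load(f) for f in files], dim=0)
#   attentions = [torch.load(f.replace('logits', 'attentions')) for f in files]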