-
Notifications
You must be signed in to change notification settings - Fork 132
/
example_inference.py
56 lines (40 loc) · 1.16 KB
/
example_inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
"""
Usage: python -m scripts.example_inference
Computes logits for a single sequence or with a batch of sequences.
"""
import torch
from evo import Evo
def main(model_name: str = 'evo-1-131k-base', device: str = 'cuda:0') -> None:
    """Run example single-sequence and batched inference with an Evo model.

    Loads the checkpoint, computes logits for one sequence, then for a
    batch of variable-length sequences, printing logits and shapes.

    Args:
        model_name: Evo checkpoint name to load. Defaults to the value the
            original script hard-coded ('evo-1-131k-base').
        device: torch device string. The example assumes a CUDA device is
            available; pass 'cpu' to run without a GPU (slow).
    """
    # Local import kept inside the function, mirroring the original script's
    # placement (prepare_batch is only needed for the batched example).
    from evo.scoring import prepare_batch

    # Load model and tokenizer, move to the target device, and switch to
    # eval mode (disables dropout etc. for deterministic inference).
    evo_model = Evo(model_name)
    model, tokenizer = evo_model.model, evo_model.tokenizer
    model.to(device)
    model.eval()

    # --- Single-sequence inference ---
    sequence = 'ACGT'
    # tokenizer.tokenize returns a flat list of token ids; unsqueeze adds
    # the batch dimension so the tensor is (1, length).
    input_ids = torch.tensor(
        tokenizer.tokenize(sequence),
        dtype=torch.int,
    ).to(device).unsqueeze(0)
    logits, _ = model(input_ids)  # (batch, length, vocab)

    print('Logits: ', logits)
    print('Shape (batch, length, vocab): ', logits.shape)

    # --- Batched inference over variable-length sequences ---
    sequences = [
        'ACGT',
        'A',
        'AAAAACCCCCGGGGGTTTTT',
    ]
    # prepare_batch tokenizes and pads the sequences into one tensor;
    # seq_lengths records each sequence's true (unpadded) length.
    input_ids, seq_lengths = prepare_batch(
        sequences,
        tokenizer,
        prepend_bos=False,
        device=device,
    )
    logits, _ = model(input_ids)  # (batch, length, vocab)

    print('Batch logits: ', logits)
    print('Batch shape (batch, length, vocab): ', logits.shape)
# Script entry point: run the example only when executed directly
# (e.g. `python -m scripts.example_inference`), not when imported.
if __name__ == '__main__':
    main()