-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathz_inference.py
30 lines (24 loc) · 847 Bytes
/
z_inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import torch
from z_modelops import NameToLanguages, load_labels
from z_dataops import transform
import json
from torch import nn
def load_model(location="model/rnn.pth", map_location=None):
    """Load a whole pickled model object (architecture and weights) from disk.

    The file is presumably written with ``torch.save(model, ...)`` (the full
    module, not only a ``state_dict``), since ``weights_only=False`` is used
    to restore it -- confirm against the saving script.

    Args:
        location: Path of the saved model file.
        map_location: Forwarded unchanged to ``torch.load``. For example,
            ``"cpu"`` lets a model saved on a GPU be loaded on a CPU-only
            machine. The default ``None`` keeps ``torch.load``'s own behaviour.

    Returns:
        The deserialized model object.
    """
    # SECURITY: weights_only=False is needed to unpickle a full module, but it
    # also allows the file to execute arbitrary code while loading. Only load
    # model files that come from a trusted source.
    model = torch.load(location, map_location=map_location, weights_only=False)
    return model
def infer_lang(name: str, model, label: dict, k=3) -> "list[str]":
    """Return the ``k`` most probable labels for ``name``, most probable first.

    Args:
        name: Raw input string. It is encoded with ``transform`` before being
            passed to the model.
        model: Callable model. It is called once on the encoded name with a
            leading batch dimension of size 1 added. Its output is treated as
            ``(batch, num_classes)`` logits.
        label: Mapping from a class index written *as a string* (e.g. ``"0"``)
            to its label. This is presumably the dict returned by
            ``load_labels``; verify against the caller.
        k: Number of labels to return. Values larger than the number of
            classes are clamped. ``k <= 0`` yields an empty list.

    Returns:
        A list of up to ``k`` labels ordered by decreasing probability.
    """
    if k <= 0:
        # A negative k would otherwise act as a "drop the last |k|" slice.
        return []
    name_tensor = transform(name)
    with torch.no_grad():
        logits = model(name_tensor.unsqueeze(0))
        probs = torch.softmax(logits, dim=1)
        # Clamp so topk does not raise when k exceeds the number of classes.
        top_k = min(k, probs.size(1))
        # .tolist() works for CPU and CUDA tensors alike; .numpy() is CPU-only.
        top_k_idx = probs.topk(top_k, dim=1).indices[0].tolist()
    return [label[str(idx)] for idx in top_k_idx]
def setup_inference():
    """Load the model and its label mapping, ready for ``infer_lang``.

    Returns:
        A ``(model, labels)`` tuple. ``model`` is the result of ``load_model()``
        switched to evaluation mode. ``labels`` is the result of
        ``load_labels()``.
    """
    model = load_model()
    # A model pickled right after training may still be in training mode.
    # eval() disables dropout and makes batch-norm use its running statistics,
    # so inference is deterministic.
    model.eval()
    labels = load_labels()
    return model, labels
if __name__ == "__main__":
    import sys

    model, labels = setup_inference()
    # Usage: python z_inference.py NAME [NAME ...]
    # Prints the top-3 predicted labels for each name given on the command
    # line. With no arguments the script only loads the model and labels.
    for arg in sys.argv[1:]:
        preds = infer_lang(arg, model, labels)
        print(f"{arg}: {', '.join(map(str, preds))}")