-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathinfer.py
155 lines (130 loc) · 4.51 KB
/
infer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
"""TBNet inference."""
import os
import argparse
import math
from mindspore import load_checkpoint, load_param_into_net, context
import mindspore.common.dtype as mstype
from src.config import TBNetConfig
from src.tbnet import TBNet
from src.aggregator import InferenceAggregator
from src import dataset
from src import steam
def get_args():
    """Parse the TBNet inference command line from ``sys.argv``.

    Only ``--checkpoint_id`` and ``--user`` are mandatory; every other option
    falls back to the default listed in its spec below.

    Returns:
        argparse.Namespace: the parsed options ``dataset``, ``csv``,
        ``checkpoint_id``, ``user``, ``items``, ``explanations``,
        ``device_id``, ``device_target`` and ``run_mode``.
    """
    # (flag, add_argument keyword options), registered in this order so the
    # ``--help`` listing keeps the same sequence.
    specs = [
        ('--dataset', {
            'type': str, 'required': False, 'default': 'steam',
            'help': "'steam' dataset is supported currently",
        }),
        ('--csv', {
            'type': str, 'required': False, 'default': 'infer.csv',
            'help': "the csv datafile inside the dataset folder (e.g. infer.csv)",
        }),
        ('--checkpoint_id', {
            'type': int, 'required': True,
            'help': "use which checkpoint(.ckpt) file to infer",
        }),
        ('--user', {
            'type': int, 'required': True,
            'help': "id of the user to be recommended to",
        }),
        ('--items', {
            'type': int, 'required': False, 'default': 1,
            'help': "no. of items to be recommended",
        }),
        ('--explanations', {
            'type': int, 'required': False, 'default': 3,
            'help': "no. of recommendation explanations to be shown",
        }),
        ('--device_id', {
            'type': int, 'required': False, 'default': 0,
            'help': "device id",
        }),
        ('--device_target', {
            'type': str, 'required': False, 'default': 'Ascend',
            'choices': ['GPU', 'Ascend'],
            'help': "run code on GPU or Ascend NPU",
        }),
        ('--run_mode', {
            'type': str, 'required': False, 'default': 'graph',
            'choices': ['graph', 'pynative'],
            'help': "run code by GRAPH mode or PYNATIVE mode",
        }),
    ]
    cli = argparse.ArgumentParser(description='Infer TBNet.')
    for flag, options in specs:
        cli.add_argument(flag, **options)
    return cli.parse_args()
def infer_tbnet():
    """Run TBNet inference for one user and print explained recommendations.

    Settings come from the command line (see ``get_args``). The function:

    * restores the network from ``checkpoints/tbnet_epoch<checkpoint_id>.ckpt``
      located next to this script;
    * scores the rows of ``data/<dataset>/<csv>`` selected for ``--user``;
    * prints the top ``--items`` recommendations, each followed by at most
      ``--explanations`` explanation lines (none when the value is <= 0).
    """
    args = get_args()

    # Device id is configured first, then execution mode and target device.
    context.set_context(device_id=args.device_id)
    mode = context.GRAPH_MODE if args.run_mode == 'graph' else context.PYNATIVE_MODE
    context.set_context(mode=mode, device_target=args.device_target)

    # Data and checkpoint files are resolved relative to this script's
    # directory, not the current working directory.
    home = os.path.dirname(os.path.realpath(__file__))
    data_dir = os.path.join(home, 'data', args.dataset)
    config_path = os.path.join(data_dir, 'config.json')
    translate_path = os.path.join(data_dir, 'translate.json')
    data_path = os.path.join(data_dir, args.csv)
    ckpt_path = os.path.join(home, 'checkpoints')

    print(f"creating TBNet from checkpoint {args.checkpoint_id}...")
    config = TBNetConfig(config_path)
    on_ascend = args.device_target == 'Ascend'
    if on_ascend:
        # On Ascend, round these two sizes up to multiples of 16 and run the
        # network in float16.
        # NOTE(review): presumably this mirrors the settings the Ascend
        # checkpoint was trained with -- confirm against the training script.
        config.per_item_paths = math.ceil(config.per_item_paths / 16) * 16
        config.embedding_dim = math.ceil(config.embedding_dim / 16) * 16
    network = TBNet(config)
    if on_ascend:
        network.to_float(mstype.float16)
    param_dict = load_checkpoint(os.path.join(ckpt_path, f'tbnet_epoch{args.checkpoint_id}.ckpt'))
    load_param_into_net(network, param_dict)

    print(f"creating dataset from {data_path}...")
    # NOTE(review): `users=args.user` presumably restricts the rows to that
    # user -- confirm in src/dataset.py.
    infer_ds = dataset.create(data_path, config.per_item_paths, train=False, users=args.user)
    infer_ds = infer_ds.batch(config.batch_size)

    print("inferring...")
    # Score every batch and feed the results to the aggregator, which keeps
    # the top `args.items` recommendations.
    aggregator = InferenceAggregator(top_k=args.items)
    for user, item, relation1, entity, relation2, hist_item, _rating in infer_ds:
        # The rating column is not an input to the network; it is ignored.
        result = network(item, relation1, entity, relation2, hist_item)
        item_score = result[0]
        path_importance = result[1]
        aggregator.aggregate(user, item, relation1, entity, relation2, hist_item,
                             item_score, path_importance)

    # Show each recommendation followed by a bounded number of explanations.
    explainer = steam.TextExplainer(translate_path)
    recomms = aggregator.recommend()
    for user, recomm in recomms.items():
        for item_rec in recomm.item_records:
            item_name = explainer.translate_item(item_rec.item)
            print(f"Recommend <{item_name}> to user:{user}, because:")
            # Check the limit before printing, so that `--explanations 0`
            # (or a negative value) prints no explanation lines at all.
            for shown, path in enumerate(item_rec.paths):
                if shown >= args.explanations:
                    break
                print(" - " + explainer.explain(path))
            print("")
if __name__ == "__main__":
    # Script entry point: parse sys.argv and run the inference pipeline.
    infer_tbnet()