-
Notifications
You must be signed in to change notification settings - Fork 0
/
pipeline.py
60 lines (47 loc) · 1.72 KB
/
pipeline.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import os
import argparse
import yaml
from Dataset import Dataset
from Greedy import Greedy
from Optimal import Optimal
from Reinforce import Reinforce
def create_and_print_dataset(config):
    """Build a ``Dataset`` from *config*, echo it to stdout, and hand it back.

    Args:
        config: Parsed configuration mapping, forwarded unchanged to ``Dataset``.

    Returns:
        The freshly constructed ``Dataset`` instance.
    """
    built_dataset = Dataset(config)
    # Show the dataset summary (whatever Dataset's str() produces) before use.
    print(built_dataset)
    return built_dataset
def _parse_args(argv=None):
    """Parse command-line arguments; ``--config`` is mandatory."""
    parser = argparse.ArgumentParser(description="Run recommender models.")
    # required=True: previously an omitted --config left args.config as None and
    # the later open(None) crashed with an unhelpful TypeError.
    parser.add_argument(
        "--config", required=True, help="Path to the configuration file"
    )
    return parser.parse_args(argv)


def _load_config(path):
    """Read the YAML configuration file at *path* and return it as a mapping.

    Raises:
        ValueError: if the file does not contain a YAML mapping at top level
            (e.g. it is empty), since every later step indexes it by key.
    """
    with open(path, "r", encoding="utf-8") as f:
        # NOTE(review): FullLoader can still construct some Python objects from
        # tagged YAML; if config files may come from untrusted sources, consider
        # switching to yaml.safe_load.
        config = yaml.load(f, Loader=yaml.FullLoader)
    if not isinstance(config, dict):
        raise ValueError(
            f"Configuration file {path!r} must contain a YAML mapping, "
            f"got {type(config).__name__}"
        )
    return config


def _run_once(config, run):
    """Build a fresh dataset and execute the configured model for one run."""
    # A new Dataset is constructed on every run, exactly as before.
    dataset = create_and_print_dataset(config)
    model_name = config["model"]
    # Greedy and Optimal share a constructor shape (dataset, threshold); each is
    # expected to expose a method named "<model>_recommendation".
    heuristic_classes = {"greedy": Greedy, "optimal": Optimal}
    if model_name in heuristic_classes:
        recommender = heuristic_classes[model_name](dataset, config["threshold"])
        recommendation_method = getattr(recommender, f"{model_name}_recommendation")
        recommendation_method(config["k"], run)
    else:
        # Every other model name is delegated to Reinforce (Reinforce.py), which
        # receives the model name itself as an argument.
        recommender = Reinforce(
            dataset,
            model_name,
            config["k"],
            config["threshold"],
            run,
            config["total_steps"],
            config["eval_freq"],
        )
        recommender.reinforce_recommendation()


def main(argv=None):
    """Run the recommender system described by a YAML configuration file.

    The configuration must provide ``nb_runs``, ``model``, ``threshold`` and
    ``k``; models other than ``"greedy"``/``"optimal"`` additionally need
    ``total_steps`` and ``eval_freq``.

    Args:
        argv: Optional list of command-line arguments; when ``None``, argparse
            falls back to ``sys.argv[1:]`` (the previous behavior).
    """
    args = _parse_args(argv)
    config = _load_config(args.config)
    for run in range(config["nb_runs"]):
        _run_once(config, run)


if __name__ == "__main__":
    main()