-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrun_ALS.py
37 lines (27 loc) · 1005 Bytes
/
run_ALS.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
# This file generates the submission ALS.csv from the pickled feature files
# data/item_features_ALS.obj and data/user_features_ALS.obj (the model is
# trained first if either file is missing).
from helpers_MF import *
from train_ALS import train_ALS
import os
def run_ALS():
    """Create the submission with the predictions of the ALS model.

    Steps:
      1. Load the positions to predict from ``data/data_test.csv``.
      2. If either pickled feature file is missing, call ``train_ALS()``
         (presumably this writes both feature files -- verify in train_ALS).
      3. Unpickle the item and user features.
      4. Compute the predictions as ``item_features.T . user_features``.
      5. Hand predictions and positions to ``create_submission`` under the
         name ``"ALS"``.

    Relies on ``load_data``, ``create_submission``, ``pickle`` and ``np``
    being in scope (presumably via ``from helpers_MF import *``).
    """
    print("Running ALS...")

    # load the positions of the predictions to generate
    path_dataset = "data/data_test.csv"
    positions = load_data(path_dataset)

    item_features_path = "data/item_features_ALS.obj"
    user_features_path = "data/user_features_ALS.obj"

    # if the features do not exist, train the model
    if not os.path.isfile(item_features_path) or not os.path.isfile(user_features_path):
        train_ALS()

    # NOTE: pickle.load can execute arbitrary code; only load feature files
    # produced by this project.
    # `with` guarantees the handles are closed even if unpickling fails.

    # load the item features
    with open(item_features_path, 'rb') as file:
        item_features = pickle.load(file)

    # load the user features
    with open(user_features_path, 'rb') as file:
        user_features = pickle.load(file)

    # get the predictions based on the features
    predictions = np.dot(item_features.T, user_features)

    # create submission
    create_submission(predictions, positions, "ALS")
# Script entry point: generate the ALS submission only when this file is
# executed directly, not when it is imported as a module.
if __name__ == "__main__":
    run_ALS()