-
Notifications
You must be signed in to change notification settings - Fork 121
/
Copy pathtf_validate.js
62 lines (54 loc) · 2.47 KB
/
tf_validate.js
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
/**
* Validate the TensorFlow.js model's performance on held-out data and
* store each prediction, next to the real value, in the database.
**/
const config = require('./config');
const db = require('./InstantiateDB');
const pre_process = require('./pre_process');
const model = require('./tf_model');
// Settings taken from the shared ./config module.
const { trainConfig: train, dbConfig: configDB, type } = config;
// Number of preceding rows that form the model's input window for each prediction.
const timesteps = 7;
/**
 * Read every document of the stream collection, projecting Date and the
 * Open/High/Low/Close/Volume fields (no _id). The client is always closed.
 *
 * @returns {Promise<Object[]>} the documents, in the order returned by `find`.
 **/
async function fetchStreamRows(){
  const client = new db.MongoClient(db.uri, { useNewUrlParser: true, useUnifiedTopology: true });
  try {
    await client.connect();
    const collection = client.db(configDB.database).collection(configDB.collectionStream);
    return await collection
      .find({}, { projection: { _id: 0, Date: 1, Open: 1, High: 1, Low: 1, Close: 1, Volume: 1 } })
      .toArray();
  } finally {
    // Release the connection even when connect/find fails.
    await client.close();
  }
}

/**
 * Build the model input windows: window j is the `windowSize` scaled rows
 * that immediately precede test row j.
 *
 * @param {Array} scaled - `windowSize` history rows followed by the test rows, already scaled.
 * @param {number} count - number of test rows, i.e. windows to build.
 * @param {number} windowSize - rows per window.
 * @returns {Array[]} `count` windows, each a slice of `scaled`.
 **/
function buildTestWindows(scaled, count, windowSize){
  return Array.from({ length: count }, (_, j) => scaled.slice(j, j + windowSize));
}

/**
 * Validate the ml model performance on the held-out test split of the
 * database and record every prediction in the ML collection.
 *
 * The first `train.trainSize` percent of the documents is the training split.
 * The remaining documents form the test split (e.g. 20% when trainSize is 80).
 * Each test row is predicted from the `timesteps` rows right before it, so the
 * scaler input starts `timesteps` rows before the split.
 *
 * @returns {Promise<void>} Resolves once every update has been issued (an
 *   empty test split is a no-op). Rejects on connection, query, prediction or
 *   update errors, on a non-numeric trainSize, or when the training split has
 *   fewer than `timesteps` rows.
 **/
async function val(){
  const rows = await fetchStreamRows();

  const splitIndex = Math.floor((train.trainSize / 100) * rows.length);
  if (!Number.isFinite(splitIndex)) {
    throw new Error(`Invalid trainSize: ${train.trainSize}`);
  }
  const testRows = rows.slice(splitIndex);
  if (testRows.length === 0) {
    // Nothing held out, so there is nothing to validate.
    return;
  }
  if (splitIndex < timesteps) {
    // A negative slice start would silently wrap around to the end of the data.
    throw new Error(
      `Not enough history before the test split: need ${timesteps} rows, have ${splitIndex}.`
    );
  }

  // Prepend `timesteps` rows of history so the first test row has a full window.
  const inputRows = rows.slice(splitIndex - timesteps);
  // Scale inputs (original note: MinMax scaler) before building the windows.
  const scaled = pre_process.transform(inputRows, type);
  const windows = buildTestWindows(scaled, testRows.length, timesteps);

  // Get model prediction, then inverse scale it back to original values.
  const predictionTensor = await model.processModel(windows);
  const predicted = pre_process.inverse_transform(predictionTensor.arraySync());

  // Update database with model performance. `await` also accepts a
  // non-promise return, and Promise.all surfaces any rejected update.
  await Promise.all(testRows.map((row, idx) => db.updateMongoDB({
    date: row.Date,
    Prediction: predicted[idx],
    real: row[type],
    type: type,
  }, configDB.collectionML, false)));
}
// Run the validation once when the script is executed. Report failures and
// set a non-zero exit code instead of leaving the returned promise unhandled.
val().catch((err) => {
  console.error(err);
  process.exitCode = 1;
});