@@ -8,6 +8,17 @@ const classifierBuilder = require('./classifier')
8
8
const categories = require ( './categories' )
9
9
const ConfusionMatrix = require ( './confusionMatrix' )
10
10
11
+ const spinner = new Spinner ( 'Loading...' , [
12
+ '⣾' ,
13
+ '⣽' ,
14
+ '⣻' ,
15
+ '⢿' ,
16
+ '⡿' ,
17
+ '⣟' ,
18
+ '⣯' ,
19
+ '⣷' ,
20
+ ] )
21
+
11
22
/**
12
23
* NodeJS Classification-based learner.
13
24
* @class Learner
@@ -60,20 +71,11 @@ class Learner {
60
71
*/
61
72
train ( trainSet = this . trainSet ) {
62
73
//@todo Move this so it could be used for any potentially lengthy ops
63
- const training = new Spinner ( 'Training...' , [
64
- '⣾' ,
65
- '⣽' ,
66
- '⣻' ,
67
- '⢿' ,
68
- '⡿' ,
69
- '⣟' ,
70
- '⣯' ,
71
- '⣷' ,
72
- ] )
73
- training . start ( )
74
+ spinner . message ( 'Training...' )
75
+ spinner . start ( )
74
76
this . classifier . trainBatch ( trainSet )
75
- training . message ( 'Training complete' )
76
- training . stop ( )
77
+ // spinner .message('Training complete')
78
+ spinner . stop ( )
77
79
}
78
80
79
81
/**
@@ -82,18 +84,27 @@ class Learner {
82
84
* @public
83
85
*/
84
86
eval ( ) {
87
+ spinner . message ( 'Evaluating...' )
88
+ spinner . start ( )
85
89
const actual = [ ]
86
90
const predicted = [ ]
91
+ const len = this . testSet . length
92
+ let idx = 0
87
93
for ( const data of this . testSet ) {
88
94
const predictions = this . classify ( data . input )
89
95
actual . push ( data . output )
90
96
predicted . push ( predictions . length ? predictions [ 0 ] : 'null' ) //Ignores the rest (as it only wants one guess)
97
+ spinner . message (
98
+ `Evaluating instances (${ Math . round ( ( idx ++ / len ) * 10000 ) / 100 } %)` ,
99
+ )
91
100
}
92
101
this . confusionMatrix = ConfusionMatrix . fromData (
93
102
actual ,
94
103
predicted ,
95
104
categories ,
96
105
)
106
+ // spinner.message('Evaluation complete')
107
+ spinner . stop ( )
97
108
return this . confusionMatrix . getStats ( )
98
109
}
99
110
@@ -182,29 +193,30 @@ class Learner {
182
193
F_1 (or effectiveness) = 2 * (Pr * R) / (Pr + R)
183
194
...
184
195
*/
196
+ spinner . message ( 'Cross-validating...' )
197
+ spinner . start ( )
185
198
this . macroAvg = new PrecisionRecall ( )
186
199
this . microAvg = new PrecisionRecall ( )
200
+ const set = [ ...this . trainSet , ...this . validationSet ]
187
201
188
- partitions . partitions (
189
- [ ...this . trainSet , ...this . validationSet ] ,
190
- numOfFolds ,
191
- ( trainSet , validationSet ) => {
192
- if ( log )
193
- process . stdout . write (
194
- `Training on ${ trainSet . length } samples, testing ${ validationSet . length } samples` ,
195
- )
196
- this . train ( trainSet )
197
- test (
198
- this . classifier ,
199
- validationSet ,
200
- verboseLevel ,
201
- this . microAvg ,
202
- this . macroAvg ,
203
- )
204
- } ,
205
- )
202
+ partitions . partitions ( set , numOfFolds , ( trainSet , validationSet ) => {
203
+ const status = `Training on ${ trainSet . length } samples, testing ${ validationSet . length } samples`
204
+ //eslint-disable-next-line babel/no-unused-expressions
205
+ log ? process . stdout . write ( status ) : spinner . message ( status )
206
+ this . train ( trainSet )
207
+ test (
208
+ this . classifier ,
209
+ validationSet ,
210
+ verboseLevel ,
211
+ this . microAvg ,
212
+ this . macroAvg ,
213
+ )
214
+ } )
215
+ spinner . message ( 'Calculating stats' )
206
216
this . macroAvg . calculateMacroAverageStats ( numOfFolds )
207
217
this . microAvg . calculateStats ( )
218
+ // spinner.message('Cross-validation complete')
219
+ spinner . stop ( )
208
220
return {
209
221
macroAvg : this . macroAvg . fullStats ( ) , //preferable in 2-class settings or in balanced multi-class settings
210
222
microAvg : this . microAvg . fullStats ( ) , //preferable in multi-class settings (in case of class imbalance)
@@ -278,6 +290,8 @@ class Learner {
278
290
* @public
279
291
*/
280
292
getCategoryPartition ( ) {
293
+ spinner . message ( 'Generating category partitions...' )
294
+ spinner . start ( )
281
295
const res = { }
282
296
categories . forEach ( cat => {
283
297
res [ cat ] = {
@@ -288,11 +302,14 @@ class Learner {
288
302
}
289
303
} )
290
304
this . dataset . forEach ( data => {
305
+ spinner . message ( `Adding ${ data . output } data` )
291
306
++ res [ data . output ] . overall
292
307
if ( this . trainSet . includes ( data ) ) ++ res [ data . output ] . train
293
308
if ( this . validationSet . includes ( data ) ) ++ res [ data . output ] . validation
294
309
if ( this . testSet . includes ( data ) ) ++ res [ data . output ] . test
295
310
} )
311
+ // spinner.message('Category partitions complete')
312
+ spinner . stop ( )
296
313
return res
297
314
}
298
315
0 commit comments