Skip to content

Commit c1d5d15

Browse files
authored
feat(index): added status update (#21)
* feat(index): added status update ... For a better DX/UX * chore(playground): update
1 parent 95c8e76 commit c1d5d15

File tree

2 files changed

+50
-31
lines changed

2 files changed

+50
-31
lines changed

playground/playground.js

+2
Original file line numberDiff line numberDiff line change
@@ -47,3 +47,5 @@ writeFileSync(
4747
'playground-fullStats.json',
4848
JSON.stringify(longStats, null, 2),
4949
) && console.log('Saved learner to "playground-fullStats.json"')
50+
51+
process.exit(0)

src/index.js

+48-31
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,17 @@ const classifierBuilder = require('./classifier')
88
const categories = require('./categories')
99
const ConfusionMatrix = require('./confusionMatrix')
1010

11+
const spinner = new Spinner('Loading...', [
12+
'⣾',
13+
'⣽',
14+
'⣻',
15+
'⢿',
16+
'⡿',
17+
'⣟',
18+
'⣯',
19+
'⣷',
20+
])
21+
1122
/**
1223
* NodeJS Classification-based learner.
1324
* @class Learner
@@ -60,20 +71,11 @@ class Learner {
6071
*/
6172
train(trainSet = this.trainSet) {
6273
//@todo Move this so it could be used for any potentially lengthy ops
63-
const training = new Spinner('Training...', [
64-
'⣾',
65-
'⣽',
66-
'⣻',
67-
'⢿',
68-
'⡿',
69-
'⣟',
70-
'⣯',
71-
'⣷',
72-
])
73-
training.start()
74+
spinner.message('Training...')
75+
spinner.start()
7476
this.classifier.trainBatch(trainSet)
75-
training.message('Training complete')
76-
training.stop()
77+
// spinner.message('Training complete')
78+
spinner.stop()
7779
}
7880

7981
/**
@@ -82,18 +84,27 @@ class Learner {
8284
* @public
8385
*/
8486
eval() {
87+
spinner.message('Evaluating...')
88+
spinner.start()
8589
const actual = []
8690
const predicted = []
91+
const len = this.testSet.length
92+
let idx = 0
8793
for (const data of this.testSet) {
8894
const predictions = this.classify(data.input)
8995
actual.push(data.output)
9096
predicted.push(predictions.length ? predictions[0] : 'null') //Ignores the rest (as it only wants one guess)
97+
spinner.message(
98+
`Evaluating instances (${Math.round((idx++ / len) * 10000) / 100}%)`,
99+
)
91100
}
92101
this.confusionMatrix = ConfusionMatrix.fromData(
93102
actual,
94103
predicted,
95104
categories,
96105
)
106+
// spinner.message('Evaluation complete')
107+
spinner.stop()
97108
return this.confusionMatrix.getStats()
98109
}
99110

@@ -182,29 +193,30 @@ class Learner {
182193
F_1 (or effectiveness) = 2 * (Pr * R) / (Pr + R)
183194
...
184195
*/
196+
spinner.message('Cross-validating...')
197+
spinner.start()
185198
this.macroAvg = new PrecisionRecall()
186199
this.microAvg = new PrecisionRecall()
200+
const set = [...this.trainSet, ...this.validationSet]
187201

188-
partitions.partitions(
189-
[...this.trainSet, ...this.validationSet],
190-
numOfFolds,
191-
(trainSet, validationSet) => {
192-
if (log)
193-
process.stdout.write(
194-
`Training on ${trainSet.length} samples, testing ${validationSet.length} samples`,
195-
)
196-
this.train(trainSet)
197-
test(
198-
this.classifier,
199-
validationSet,
200-
verboseLevel,
201-
this.microAvg,
202-
this.macroAvg,
203-
)
204-
},
205-
)
202+
partitions.partitions(set, numOfFolds, (trainSet, validationSet) => {
203+
const status = `Training on ${trainSet.length} samples, testing ${validationSet.length} samples`
204+
//eslint-disable-next-line babel/no-unused-expressions
205+
log ? process.stdout.write(status) : spinner.message(status)
206+
this.train(trainSet)
207+
test(
208+
this.classifier,
209+
validationSet,
210+
verboseLevel,
211+
this.microAvg,
212+
this.macroAvg,
213+
)
214+
})
215+
spinner.message('Calculating stats')
206216
this.macroAvg.calculateMacroAverageStats(numOfFolds)
207217
this.microAvg.calculateStats()
218+
// spinner.message('Cross-validation complete')
219+
spinner.stop()
208220
return {
209221
macroAvg: this.macroAvg.fullStats(), //preferable in 2-class settings or in balanced multi-class settings
210222
microAvg: this.microAvg.fullStats(), //preferable in multi-class settings (in case of class imbalance)
@@ -278,6 +290,8 @@ class Learner {
278290
* @public
279291
*/
280292
getCategoryPartition() {
293+
spinner.message('Generating category partitions...')
294+
spinner.start()
281295
const res = {}
282296
categories.forEach(cat => {
283297
res[cat] = {
@@ -288,11 +302,14 @@ class Learner {
288302
}
289303
})
290304
this.dataset.forEach(data => {
305+
spinner.message(`Adding ${data.output} data`)
291306
++res[data.output].overall
292307
if (this.trainSet.includes(data)) ++res[data.output].train
293308
if (this.validationSet.includes(data)) ++res[data.output].validation
294309
if (this.testSet.includes(data)) ++res[data.output].test
295310
})
311+
// spinner.message('Category partitions complete')
312+
spinner.stop()
296313
return res
297314
}
298315

0 commit comments

Comments
 (0)