Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: better graphs #25

Merged
merged 3 commits into from
Jul 24, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
170 changes: 170 additions & 0 deletions playground/categoryPartitions.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
{
"blog": {
"overall": 3,
"test": 1,
"validation": 0,
"train": 2
},
"bug": {
"overall": 20,
"test": 3,
"validation": 1,
"train": 16
},
"business": {
"overall": 3,
"test": 1,
"validation": 0,
"train": 2
},
"code": {
"overall": 57,
"test": 6,
"validation": 8,
"train": 43
},
"content": {
"overall": 2,
"test": 1,
"validation": 0,
"train": 1
},
"design": {
"overall": 11,
"test": 0,
"validation": 2,
"train": 9
},
"doc": {
"overall": 23,
"test": 4,
"validation": 4,
"train": 15
},
"eventOrganizing": {
"overall": 4,
"test": 0,
"validation": 1,
"train": 3
},
"example": {
"overall": 2,
"test": 1,
"validation": 1,
"train": 0
},
"financial": {
"overall": 4,
"test": 1,
"validation": 1,
"train": 2
},
"fundingFinding": {
"overall": 2,
"test": 0,
"validation": 0,
"train": 2
},
"ideas": {
"overall": 26,
"test": 1,
"validation": 4,
"train": 21
},
"infra": {
"overall": 14,
"test": 6,
"validation": 2,
"train": 6
},
"maintenance": {
"overall": 30,
"test": 4,
"validation": 2,
"train": 24
},
"null": {
"overall": 197,
"test": 27,
"validation": 34,
"train": 136
},
"platform": {
"overall": 23,
"test": 7,
"validation": 5,
"train": 11
},
"plugin": {
"overall": 4,
"test": 1,
"validation": 0,
"train": 3
},
"projectManagement": {
"overall": 6,
"test": 0,
"validation": 1,
"train": 5
},
"question": {
"overall": 5,
"test": 0,
"validation": 1,
"train": 4
},
"review": {
"overall": 3,
"test": 0,
"validation": 1,
"train": 2
},
"security": {
"overall": 10,
"test": 2,
"validation": 1,
"train": 7
},
"talk": {
"overall": 2,
"test": 0,
"validation": 0,
"train": 2
},
"test": {
"overall": 12,
"test": 3,
"validation": 0,
"train": 9
},
"tool": {
"overall": 13,
"test": 2,
"validation": 3,
"train": 8
},
"translation": {
"overall": 1,
"test": 1,
"validation": 0,
"train": 0
},
"tutorial": {
"overall": 3,
"test": 1,
"validation": 1,
"train": 1
},
"userTesting": {
"overall": 3,
"test": 0,
"validation": 0,
"train": 3
},
"video": {
"overall": 3,
"test": 0,
"validation": 0,
"train": 3
}
}
2 changes: 2 additions & 0 deletions playground/playground.js
Original file line number Diff line number Diff line change
Expand Up @@ -48,4 +48,6 @@ writeFileSync(
JSON.stringify(longStats, null, 2),
) && console.log('Saved learner to "playground-fullStats.json"')

console.log('More Stats:', learner.getStats(true))

process.exit(0)
48 changes: 38 additions & 10 deletions public/chart.js
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,14 @@ const loadData = async () => {
try {
const data = await Promise.all([
fetch('../src/categories.json').then(res => res.json(), console.error),
fetch('../src/labels.json').then(res => res.json(), console.error),
// fetch('../src/labels.json').then(res => res.json(), console.error),
fetch('../playground/categoryPartitions.json').then(
res => res.json(),
console.error,
),
])

return data //[categories, dataset]
return data //[categories, dataset, categoryPartitions]
} catch (error) {
console.log('Error downloading one or more files:', error)
}
Expand All @@ -49,26 +53,44 @@ const loadData = async () => {
/**
* Organise a dataset for ChartJS.
* @param {[string[], object[]]} data Dataset of categories and instances
* @param {string} caption Caption of the chart
* @returns {Object} configuration for ChartJS
*/
const buildConfig = (data, caption = 'Categories') => {
const buildConfig = data => {
const res = {
labels: data[0], //categories
datasets: [
{
label: caption,
label: 'Training',
data: new Array(data[0].length).fill(0),
backgroundColor: '#00f',
},
{
label: 'Validation',
data: new Array(data[0].length).fill(0),
backgroundColor: '#0f0',
},
{
label: 'Test',
data: new Array(data[0].length).fill(0),
backgroundColor: '#f00',
},
],
}

data[1].forEach(instance => {
const idx = data[0].indexOf(instance.category)
res.datasets[0].data[idx]++
})
// console.log('dataset=', res.dataset.backgroundColor);
// data[1].forEach(instance => { //For labels.json
// const idx = data[0].indexOf(instance.category)
// res.datasets[0].data[idx]++
// })
for (const cat in data[1]) {
if (data[1].hasOwnProperty(cat)) {
const inf = data[1][cat]
const idx = data[0].indexOf(cat)
// res.datasets[0].data[idx] = inf.overall
res.datasets[0].data[idx] = inf.train
res.datasets[1].data[idx] = inf.validation
res.datasets[2].data[idx] = inf.test
}
}
return res
}

Expand All @@ -87,6 +109,12 @@ loadData().then(
ticks: {
beginAtZero: true,
},
stacked: true,
},
],
xAxes: [
{
stacked: true,
},
],
},
Expand Down
41 changes: 27 additions & 14 deletions src/index.js
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
const {writeFile, readFile} = require('fs')
const {writeFile, readFile, writeFileSync} = require('fs')
const serialize = require('serialization')
const tvts = require('tvt-split')
const {Spinner} = require('clui')
Expand Down Expand Up @@ -37,7 +37,7 @@ class Learner {
* @example <caption>Using a custom dataset</caption>
* const learner = new Learner({
* dataset: [{input: 'something bad', output: 'bad'}, {input: 'a good thing', output: 'good'}]
* })
* })
* @example <caption>Using a specified classifier function</caption>
* const learner = new Learner({
* classifier: myClassifierBuilderFn //see {@link module:./classifier} for an example (or checkout `limdu`'s examples)
Expand Down Expand Up @@ -72,11 +72,11 @@ class Learner {
*/
train(trainSet = this.trainSet) {
//@todo Move this so it could be used for any potentially lengthy ops
spinner.message('Training...')
spinner.start()
// spinner.start()
// spinner.message('Training...')
this.classifier.trainBatch(trainSet)
// spinner.message('Training complete')
spinner.stop()
// spinner.stop()
}

/**
Expand Down Expand Up @@ -189,7 +189,7 @@ class Learner {
/**
* @param {number} [numOfFolds=5] Cross-validation folds
* @param {number} [verboseLevel=0] Verbosity level on limdu's explainations
* @param {boolean} [log=false] Pre-training logging
* @param {boolean} [log=false] Cross-validation logging
* @returns {{microAvg: Object, macroAvg: Object}} Averages
* @memberof Learner
* @public
Expand Down Expand Up @@ -232,7 +232,7 @@ class Learner {
this.macroAvg,
)
})
spinner.message('Calculating stats...')
if (!log) spinner.message('Calculating stats...')
this.macroAvg.calculateMacroAverageStats(numOfFolds)
this.microAvg.calculateStats()
const completeMsg = 'Cross-validation complete'
Expand Down Expand Up @@ -308,10 +308,14 @@ class Learner {

/**
* @memberof Learner
* @param {boolean} [log=false] Log events
* @param {string} [outputFile='categoryPartitions.json'] Filename for the output (to be used by chart.html)
* @returns {Object<string, {overall: number, test: number, validation: number, train: number}>} Partitions
* @public
*/
getCategoryPartition() {
getCategoryPartition(log = false, outputFile = 'categoryPartitions.json') {
const hasInput = (set, input) => set.find(o => o.input === input)

spinner.message('Generating category partitions...')
spinner.start()
const res = {}
Expand All @@ -326,21 +330,30 @@ class Learner {
this.dataset.forEach(data => {
spinner.message(`Adding ${data.output} data`)
++res[data.output].overall
if (this.trainSet.includes(data)) ++res[data.output].train
if (this.validationSet.includes(data)) ++res[data.output].validation
if (this.testSet.includes(data)) ++res[data.output].test
if (hasInput(this.trainSet, data.input)) ++res[data.output].train
if (hasInput(this.validationSet, data.input))
++res[data.output].validation
if (hasInput(this.testSet, data.input)) ++res[data.output].test
})
// spinner.message('Category partitions complete')

const completeMsg = 'Category partitions complete'
//eslint-disable-next-line babel/no-unused-expressions
log ? succ(completeMsg) : spinner.message(completeMsg)
spinner.stop()
if (outputFile.length) {
writeFileSync(outputFile, JSON.stringify(res, null, 2))
if (log) succ(`Saved the partitions to "${outputFile}"`)
}
return res
}

/**
* @memberof Learner
* @param {boolean} [log=false] Log events
* @returns {Object} Statistics
* @public
*/
getStats() {
getStats(log = false) {
const {
TP,
TN,
Expand Down Expand Up @@ -368,7 +381,7 @@ class Learner {
trainCount: this.trainSet.length,
validationCount: this.validationSet.length,
testCount: this.testSet.length,
categoryPartition: this.getCategoryPartition(),
categoryPartition: this.getCategoryPartition(log),
//ROC, AUC
}
}
Expand Down