From 6a4a6f61e8133a8567d30321ad9bd36209304338 Mon Sep 17 00:00:00 2001 From: emily-ejag Date: Mon, 26 Aug 2024 14:44:54 -0700 Subject: [PATCH 01/47] adding clowder class and tests --- package-lock.json | 22 +++++++-- src/__tests__/clowder.test.ts | 91 ++++++++++++++++++++++++++++++++++ src/clowder.ts | 93 +++++++++++++++++++++++++++++++++++ 3 files changed, 201 insertions(+), 5 deletions(-) create mode 100644 src/__tests__/clowder.test.ts create mode 100644 src/clowder.ts diff --git a/package-lock.json b/package-lock.json index e11a89c..2100f05 100644 --- a/package-lock.json +++ b/package-lock.json @@ -10297,6 +10297,7 @@ }, "node_modules/npm/node_modules/lodash._baseindexof": { "version": "3.1.0", + "extraneous": true, "inBundle": true, "license": "MIT" }, @@ -10321,16 +10322,19 @@ }, "node_modules/npm/node_modules/lodash._bindcallback": { "version": "3.0.1", + "extraneous": true, "inBundle": true, "license": "MIT" }, "node_modules/npm/node_modules/lodash._cacheindexof": { "version": "3.0.2", + "extraneous": true, "inBundle": true, "license": "MIT" }, "node_modules/npm/node_modules/lodash._createcache": { "version": "3.1.2", + "extraneous": true, "inBundle": true, "license": "MIT", "dependencies": { @@ -10339,6 +10343,7 @@ }, "node_modules/npm/node_modules/lodash._getnative": { "version": "3.9.1", + "extraneous": true, "inBundle": true, "license": "MIT" }, @@ -10349,6 +10354,7 @@ }, "node_modules/npm/node_modules/lodash.restparam": { "version": "3.6.1", + "extraneous": true, "inBundle": true, "license": "MIT" }, @@ -24664,7 +24670,8 @@ }, "lodash._baseindexof": { "version": "3.1.0", - "bundled": true + "bundled": true, + "extraneous": true }, "lodash._baseuniq": { "version": "4.6.0", @@ -24686,22 +24693,26 @@ }, "lodash._bindcallback": { "version": "3.0.1", - "bundled": true + "bundled": true, + "extraneous": true }, "lodash._cacheindexof": { "version": "3.0.2", - "bundled": true + "bundled": true, + "extraneous": true }, "lodash._createcache": { "version": 
"3.1.2", "bundled": true, + "extraneous": true, "requires": { "lodash._getnative": "^3.0.0" } }, "lodash._getnative": { "version": "3.9.1", - "bundled": true + "bundled": true, + "extraneous": true }, "lodash.clonedeep": { "version": "4.5.0", @@ -24709,7 +24720,8 @@ }, "lodash.restparam": { "version": "3.6.1", - "bundled": true + "bundled": true, + "extraneous": true }, "lodash.union": { "version": "4.6.0", diff --git a/src/__tests__/clowder.test.ts b/src/__tests__/clowder.test.ts new file mode 100644 index 0000000..a75d862 --- /dev/null +++ b/src/__tests__/clowder.test.ts @@ -0,0 +1,91 @@ +import { Clowder } from '../Clowder'; +import { Stimulus } from '../type'; +import { CatInput } from '../index'; + +describe('Clowder', () => { + let clowder: Clowder; + + const cat1Input: CatInput = { + method: 'MLE', + itemSelect: 'MFI', + }; + + const cat2Input: CatInput = { + method: 'EAP', + itemSelect: 'closest', + }; + + const stimuli1: Stimulus[] = [ + { difficulty: 0.5, c: 0.5, word: 'looking' }, + { difficulty: 3.5, c: 0.5, word: 'opaque' }, + { difficulty: 2, c: 0.5, word: 'right' }, + { difficulty: -2.5, c: 0.5, word: 'yes' }, + { difficulty: -1.8, c: 0.5, word: 'mom' }, + ]; + + const stimuli2: Stimulus[] = [ + { difficulty: 1.0, c: 0.5, word: 'cat' }, + { difficulty: -1.0, c: 0.5, word: 'dog' }, + { difficulty: 2.0, c: 0.5, word: 'fish' }, + ]; + + beforeEach(() => { + clowder = new Clowder({ + cats: [cat1Input, cat2Input], + corpora: [stimuli1, stimuli2], + }); + }); + + it('correctly suggests the next stimulus for each Cat', () => { + const nextStimuli = clowder.getNextStimuli(); + expect(nextStimuli.length).toBe(2); + expect(nextStimuli[0]).toEqual(stimuli1[0]); // Expect first stimulus for cat1 + expect(nextStimuli[1]).toEqual(stimuli2[0]); // Expect first stimulus for cat2 + }); + + it('correctly manages remaining stimuli after selection', () => { + clowder.getNextStimulus(0); + const expectedStimulus = { difficulty: -1.8, c: 0.5, word: 'mom' }; + 
expect(clowder.getNextStimulus(0)).toEqual(expectedStimulus); // Adjusted expectation + }); + + it('throws an error if trying to access an invalid Cat index', () => { + expect(() => clowder.getNextStimulus(2)).toThrow(Error); + expect(() => clowder.getNextStimulus(-1)).toThrow(Error); + }); + + it('allows adding a new Cat and correctly suggests the next stimulus', () => { + const newCatInput: CatInput = { + method: 'MLE', + itemSelect: 'random', + }; + const newStimuli: Stimulus[] = [ + { difficulty: 1.5, c: 0.5, word: 'lion' }, + { difficulty: -0.5, c: 0.5, word: 'tiger' }, + ]; + + clowder.addCat(newCatInput, newStimuli); + + const nextStimulus = clowder.getNextStimulus(2); + expect(nextStimulus).toBeDefined(); + expect(newStimuli).toContainEqual(nextStimulus); // Use toContainEqual + }); + + it('allows removing a Cat and handles stimuli accordingly', () => { + clowder.removeCat(1); + expect(() => clowder.getNextStimulus(1)).toThrow(Error); // Cat2 should be removed + expect(clowder.getNextStimuli().length).toBe(1); // Only one Cat remains + }); + + it('correctly suggests the next item (random method)', () => { + const randomCatInput: CatInput = { + itemSelect: 'random', + randomSeed: 'test-seed', // Ensure seed is correctly set + }; + const randomStimuli: Stimulus[] = stimuli1.slice(); // Copy of stimuli1 for testing + clowder.addCat(randomCatInput, randomStimuli); + + const nextStimulus = clowder.getNextStimulus(2); // New Cat at index 2 + expect(randomStimuli).toContainEqual(nextStimulus); // Check if nextStimulus is one of the randomStimuli + }); +}); diff --git a/src/clowder.ts b/src/clowder.ts new file mode 100644 index 0000000..dfcac44 --- /dev/null +++ b/src/clowder.ts @@ -0,0 +1,93 @@ +import { Cat, CatInput } from './index'; +import { Stimulus } from './type'; + +/** + * Explanation: + * ClowderInput: Defines the input parameters for the Clowder class, including configurations for multiple Cat instances and corresponding corpora. 
+ * Clowder class: + * constructor: Initializes the Clowder with multiple Cat instances and their respective corpora. + * getNextStimulus: Retrieves the next stimulus for a specific Cat instance. + * getNextStimuli: Retrieves the next stimuli for all Cat instances. + * addCat: Adds a new Cat instance to the Clowder. + * removeCat: Removes a Cat instance from the Clowder. + */ + +export interface ClowderInput { + cats: CatInput[]; // Array of Cat configurations + corpora: Stimulus[][]; // Array of stimuli arrays, one for each Cat +} + +export class Clowder { + private cats: Cat[]; + private corpora: Stimulus[][]; + + /** + * Create a Clowder object. + * @param {ClowderInput} input - An object containing arrays of Cat configurations and corpora. + */ + constructor({ cats, corpora }: ClowderInput) { + if (cats.length !== corpora.length) { + throw new Error('The number of Cat instances must match the number of corpora'); + } + + // Initialize Cats and corresponding corpora + this.cats = cats.map((catInput) => new Cat(catInput)); + this.corpora = corpora; + } + + /** + * Get the next stimulus for a specific Cat instance. + * @param {number} catIndex - The index of the Cat instance to select the next stimulus for. + * @returns {Stimulus | null} The next stimulus, or null if no more stimuli are available. + */ + public getNextStimulus(catIndex: number): Stimulus | null { + if (catIndex < 0 || catIndex >= this.cats.length) { + throw new Error('Invalid Cat index'); + } + + const cat = this.cats[catIndex]; + const stimuli = this.corpora[catIndex]; + + if (stimuli.length === 0) { + return null; // No more stimuli available for this Cat + } + + const { nextStimulus, remainingStimuli } = cat.findNextItem(stimuli); + this.corpora[catIndex] = remainingStimuli; // Update the corpus for this Cat + + return nextStimulus; + } + + /** + * Get the next stimuli for all Cat instances. + * @returns {Stimulus[]} An array of next stimuli for each Cat. 
+ */ + public getNextStimuli(): Stimulus[] { + return this.cats + .map((_, index) => this.getNextStimulus(index)) + .filter((stimulus): stimulus is Stimulus => stimulus !== null); + } + + /** + * Add a new Cat instance to the Clowder. + * @param {CatInput} catInput - Configuration for the new Cat instance. + * @param {Stimulus[]} stimuli - The corpus for the new Cat. + */ + public addCat(catInput: CatInput, stimuli: Stimulus[]) { + this.cats.push(new Cat(catInput)); + this.corpora.push(stimuli); + } + + /** + * Remove a Cat instance from the Clowder. + * @param {number} catIndex - The index of the Cat instance to remove. + */ + public removeCat(catIndex: number) { + if (catIndex < 0 || catIndex >= this.cats.length) { + throw new Error('Invalid Cat index'); + } + + this.cats.splice(catIndex, 1); + this.corpora.splice(catIndex, 1); + } +} From 21b2f9710eb94d38c5e3dcec851eeb48b5e6552f Mon Sep 17 00:00:00 2001 From: emily-ejag Date: Mon, 26 Aug 2024 14:48:04 -0700 Subject: [PATCH 02/47] changing Clowder for clowder --- src/__tests__/clowder.test.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/__tests__/clowder.test.ts b/src/__tests__/clowder.test.ts index a75d862..5f11d3e 100644 --- a/src/__tests__/clowder.test.ts +++ b/src/__tests__/clowder.test.ts @@ -1,4 +1,4 @@ -import { Clowder } from '../Clowder'; +import { Clowder } from '../clowder'; import { Stimulus } from '../type'; import { CatInput } from '../index'; From 759dae5bd5593f8c53731bd829caf9eb8181b060 Mon Sep 17 00:00:00 2001 From: emily-ejag Date: Mon, 9 Sep 2024 17:02:38 -0700 Subject: [PATCH 03/47] Adding updateCatAndGetNextItem function Co-authored-by: Adam Richie-Halford --- package-lock.json | 34 ++--- package.json | 4 +- src/__tests__/clowder.test.ts | 154 +++++++++++++---------- src/__tests__/index.test.ts | 98 +++++++++------ src/__tests__/utils.test.ts | 50 +++++--- src/clowder.ts | 225 ++++++++++++++++++++++++++-------- src/index.ts | 5 +- src/type.ts | 24 +++- src/utils.ts | 
23 +++- 9 files changed, 414 insertions(+), 203 deletions(-) diff --git a/package-lock.json b/package-lock.json index 2100f05..b933c10 100644 --- a/package-lock.json +++ b/package-lock.json @@ -16,7 +16,7 @@ "seedrandom": "^3.0.5" }, "devDependencies": { - "@types/jest": "^28.1.6", + "@types/jest": "^28.1.8", "@types/lodash": "^4.14.182", "@types/seedrandom": "^3.0.2", "@typescript-eslint/eslint-plugin": "^5.30.7", @@ -25,7 +25,7 @@ "eslint-config-prettier": "^8.5.0", "jest": "^28.1.3", "prettier": "^2.7.1", - "ts-jest": "^28.0.7", + "ts-jest": "^28.0.8", "tsdoc": "^0.0.4", "typescript": "^4.7.4" } @@ -1863,12 +1863,13 @@ } }, "node_modules/@types/jest": { - "version": "28.1.6", - "resolved": "https://registry.npmjs.org/@types/jest/-/jest-28.1.6.tgz", - "integrity": "sha512-0RbGAFMfcBJKOmqRazM8L98uokwuwD5F8rHrv/ZMbrZBwVOWZUyPG6VFNscjYr/vjM3Vu4fRrCPbOs42AfemaQ==", + "version": "28.1.8", + "resolved": "https://registry.npmjs.org/@types/jest/-/jest-28.1.8.tgz", + "integrity": "sha512-8TJkV++s7B6XqnDrzR1m/TT0A0h948Pnl/097veySPN67VRAgQ4gZ7n2KfJo2rVq6njQjdxU3GCCyDvAeuHoiw==", "dev": true, + "license": "MIT", "dependencies": { - "jest-matcher-utils": "^28.0.0", + "expect": "^28.0.0", "pretty-format": "^28.0.0" } }, @@ -15950,10 +15951,11 @@ } }, "node_modules/ts-jest": { - "version": "28.0.7", - "resolved": "https://registry.npmjs.org/ts-jest/-/ts-jest-28.0.7.tgz", - "integrity": "sha512-wWXCSmTwBVmdvWrOpYhal79bDpioDy4rTT+0vyUnE3ZzM7LOAAGG9NXwzkEL/a516rQEgnMmS/WKP9jBPCVJyA==", + "version": "28.0.8", + "resolved": "https://registry.npmjs.org/ts-jest/-/ts-jest-28.0.8.tgz", + "integrity": "sha512-5FaG0lXmRPzApix8oFG8RKjAz4ehtm8yMKOTy5HX3fY6W8kmvOrmcY0hKDElW52FJov+clhUbrKAqofnj4mXTg==", "dev": true, + "license": "MIT", "dependencies": { "bs-logger": "0.x", "fast-json-stable-stringify": "2.x", @@ -18113,12 +18115,12 @@ } }, "@types/jest": { - "version": "28.1.6", - "resolved": "https://registry.npmjs.org/@types/jest/-/jest-28.1.6.tgz", - "integrity": 
"sha512-0RbGAFMfcBJKOmqRazM8L98uokwuwD5F8rHrv/ZMbrZBwVOWZUyPG6VFNscjYr/vjM3Vu4fRrCPbOs42AfemaQ==", + "version": "28.1.8", + "resolved": "https://registry.npmjs.org/@types/jest/-/jest-28.1.8.tgz", + "integrity": "sha512-8TJkV++s7B6XqnDrzR1m/TT0A0h948Pnl/097veySPN67VRAgQ4gZ7n2KfJo2rVq6njQjdxU3GCCyDvAeuHoiw==", "dev": true, "requires": { - "jest-matcher-utils": "^28.0.0", + "expect": "^28.0.0", "pretty-format": "^28.0.0" } }, @@ -29294,9 +29296,9 @@ "integrity": "sha512-WZGXGstmCWgeevgTL54hrCuw1dyMQIzWy7ZfqRJfSmJZBwklI15egmQytFP6bPidmw3M8d5yEowl1niq4vmqZw==" }, "ts-jest": { - "version": "28.0.7", - "resolved": "https://registry.npmjs.org/ts-jest/-/ts-jest-28.0.7.tgz", - "integrity": "sha512-wWXCSmTwBVmdvWrOpYhal79bDpioDy4rTT+0vyUnE3ZzM7LOAAGG9NXwzkEL/a516rQEgnMmS/WKP9jBPCVJyA==", + "version": "28.0.8", + "resolved": "https://registry.npmjs.org/ts-jest/-/ts-jest-28.0.8.tgz", + "integrity": "sha512-5FaG0lXmRPzApix8oFG8RKjAz4ehtm8yMKOTy5HX3fY6W8kmvOrmcY0hKDElW52FJov+clhUbrKAqofnj4mXTg==", "dev": true, "requires": { "bs-logger": "0.x", diff --git a/package.json b/package.json index 8671448..b5f68ce 100644 --- a/package.json +++ b/package.json @@ -33,7 +33,7 @@ }, "homepage": "https://github.com/yeatmanlab/jsCAT#readme", "devDependencies": { - "@types/jest": "^28.1.6", + "@types/jest": "^28.1.8", "@types/lodash": "^4.14.182", "@types/seedrandom": "^3.0.2", "@typescript-eslint/eslint-plugin": "^5.30.7", @@ -42,7 +42,7 @@ "eslint-config-prettier": "^8.5.0", "jest": "^28.1.3", "prettier": "^2.7.1", - "ts-jest": "^28.0.7", + "ts-jest": "^28.0.8", "tsdoc": "^0.0.4", "typescript": "^4.7.4" }, diff --git a/src/__tests__/clowder.test.ts b/src/__tests__/clowder.test.ts index 5f11d3e..af1daf9 100644 --- a/src/__tests__/clowder.test.ts +++ b/src/__tests__/clowder.test.ts @@ -1,91 +1,109 @@ -import { Clowder } from '../clowder'; +import { Clowder, ClowderInput } from '../Clowder'; import { Stimulus } from '../type'; -import { CatInput } from '../index'; -describe('Clowder', () => { 
+// Mocking Stimulus +const createStimulus = (id: string): Stimulus => ({ + id, + difficulty: 1, + discrimination: 1, + guessing: 0, + slipping: 0, + content: `Stimulus content ${id}`, +}); + +describe('Clowder Class', () => { let clowder: Clowder; - const cat1Input: CatInput = { - method: 'MLE', - itemSelect: 'MFI', - }; + beforeEach(() => { + const clowderInput: ClowderInput = { + cats: { + cat1: { method: 'MLE', theta: 0.5 }, + cat2: { method: 'EAP', theta: -1.0 }, + }, + corpora: { + validated: [createStimulus('1'), createStimulus('2')], + unvalidated: [createStimulus('1')], + }, + }; + clowder = new Clowder(clowderInput); + }); - const cat2Input: CatInput = { - method: 'EAP', - itemSelect: 'closest', - }; + test('should initialize with provided cats and corpora', () => { + expect(Object.keys(clowder['cats'])).toContain('cat1'); + expect(clowder.remainingItems.validated).toHaveLength(2); + expect(clowder.remainingItems.unvalidated).toHaveLength(1); + }); - const stimuli1: Stimulus[] = [ - { difficulty: 0.5, c: 0.5, word: 'looking' }, - { difficulty: 3.5, c: 0.5, word: 'opaque' }, - { difficulty: 2, c: 0.5, word: 'right' }, - { difficulty: -2.5, c: 0.5, word: 'yes' }, - { difficulty: -1.8, c: 0.5, word: 'mom' }, - ]; + test('should validate cat names', () => { + expect(() => { + clowder.updateCatAndGetNextItem({ + catToSelect: 'invalidCat', + previousItems: [], + previousAnswers: [], + }); + }).toThrow('Invalid Cat name'); + }); - const stimuli2: Stimulus[] = [ - { difficulty: 1.0, c: 0.5, word: 'cat' }, - { difficulty: -1.0, c: 0.5, word: 'dog' }, - { difficulty: 2.0, c: 0.5, word: 'fish' }, - ]; + test('should update ability estimates', () => { + clowder.updateAbilityEstimates(['cat1'], createStimulus('1'), [0]); + const cat1 = clowder['cats']['cat1']; + expect(cat1.theta).toBeGreaterThanOrEqual(0); // Since we mock, assume the result is logical. 
+ }); - beforeEach(() => { - clowder = new Clowder({ - cats: [cat1Input, cat2Input], - corpora: [stimuli1, stimuli2], + test('should select next stimulus from validated stimuli', () => { + const nextItem = clowder.updateCatAndGetNextItem({ + catToSelect: 'cat1', + catsToUpdate: ['cat1'], + previousItems: [createStimulus('1')], + previousAnswers: [1], }); + expect(nextItem).toEqual(createStimulus('1')); // Second validated stimulus }); - it('correctly suggests the next stimulus for each Cat', () => { - const nextStimuli = clowder.getNextStimuli(); - expect(nextStimuli.length).toBe(2); - expect(nextStimuli[0]).toEqual(stimuli1[0]); // Expect first stimulus for cat1 - expect(nextStimuli[1]).toEqual(stimuli2[0]); // Expect first stimulus for cat2 - }); + test('should return unvalidated stimulus when no validated stimuli remain', () => { + clowder.updateCatAndGetNextItem({ + catToSelect: 'cat1', + previousItems: [createStimulus('1'), createStimulus('2')], + previousAnswers: [1, 0], + }); - it('correctly manages remaining stimuli after selection', () => { - clowder.getNextStimulus(0); - const expectedStimulus = { difficulty: -1.8, c: 0.5, word: 'mom' }; - expect(clowder.getNextStimulus(0)).toEqual(expectedStimulus); // Adjusted expectation + const nextItem = clowder.updateCatAndGetNextItem({ + catToSelect: 'cat1', + previousItems: [], + previousAnswers: [], + }); + expect(nextItem).toEqual(createStimulus('1')); // Unvalidated item }); - it('throws an error if trying to access an invalid Cat index', () => { - expect(() => clowder.getNextStimulus(2)).toThrow(Error); - expect(() => clowder.getNextStimulus(-1)).toThrow(Error); + test('should add a new Cat instance', () => { + clowder.addCat('cat3', { method: 'MLE', theta: 0 }); + expect(Object.keys(clowder['cats'])).toContain('cat3'); }); - it('allows adding a new Cat and correctly suggests the next stimulus', () => { - const newCatInput: CatInput = { - method: 'MLE', - itemSelect: 'random', - }; - const newStimuli: 
Stimulus[] = [ - { difficulty: 1.5, c: 0.5, word: 'lion' }, - { difficulty: -0.5, c: 0.5, word: 'tiger' }, - ]; - - clowder.addCat(newCatInput, newStimuli); - - const nextStimulus = clowder.getNextStimulus(2); - expect(nextStimulus).toBeDefined(); - expect(newStimuli).toContainEqual(nextStimulus); // Use toContainEqual + test('should throw error if adding duplicate Cat instance', () => { + expect(() => { + clowder.addCat('cat1', { method: 'MLE', theta: 0 }); + }).toThrow('Cat with the name "cat1" already exists.'); }); - it('allows removing a Cat and handles stimuli accordingly', () => { - clowder.removeCat(1); - expect(() => clowder.getNextStimulus(1)).toThrow(Error); // Cat2 should be removed - expect(clowder.getNextStimuli().length).toBe(1); // Only one Cat remains + test('should remove a Cat instance', () => { + clowder.removeCat('cat1'); + expect(Object.keys(clowder['cats'])).not.toContain('cat1'); }); - it('correctly suggests the next item (random method)', () => { - const randomCatInput: CatInput = { - itemSelect: 'random', - randomSeed: 'test-seed', // Ensure seed is correctly set - }; - const randomStimuli: Stimulus[] = stimuli1.slice(); // Copy of stimuli1 for testing - clowder.addCat(randomCatInput, randomStimuli); + test('should throw error when trying to remove non-existent Cat instance', () => { + expect(() => { + clowder.removeCat('nonExistentCat'); + }).toThrow('Invalid Cat name'); + }); - const nextStimulus = clowder.getNextStimulus(2); // New Cat at index 2 - expect(randomStimuli).toContainEqual(nextStimulus); // Check if nextStimulus is one of the randomStimuli + test('should throw error if previousItems and previousAnswers have mismatched lengths', () => { + expect(() => { + clowder.updateCatAndGetNextItem({ + catToSelect: 'cat1', + previousItems: [createStimulus('1')], + previousAnswers: [1, 0], // Mismatched length + }); + }).toThrow('Previous items and answers must have the same length.'); }); }); diff --git a/src/__tests__/index.test.ts 
b/src/__tests__/index.test.ts index 1194090..37bcaec 100644 --- a/src/__tests__/index.test.ts +++ b/src/__tests__/index.test.ts @@ -1,17 +1,33 @@ import { Cat } from '../index'; -import { Stimulus } from '../type'; +import { zetaKeyMap, Stimulus, ZetaImplicit, ZetaExplicit } from '../type'; import seedrandom from 'seedrandom'; +import _mapKeys from 'lodash/mapKeys'; + +// Convert ZetaImplicit to ZetaExplicit +const convertZetaImplicitToExplicit = (zeta: ZetaImplicit): ZetaExplicit => { + const explicitZeta = _mapKeys(zeta, (value, key) => { + return zetaKeyMap[key as keyof typeof zetaKeyMap]; + }) as ZetaExplicit; + + return { + discrimination: explicitZeta.discrimination, + difficulty: explicitZeta.difficulty, + guessing: explicitZeta.guessing, + slipping: explicitZeta.slipping, + }; +}; describe('Cat', () => { let cat1: Cat, cat2: Cat, cat3: Cat, cat4: Cat, cat5: Cat, cat6: Cat, cat7: Cat, cat8: Cat; let rng = seedrandom(); + beforeEach(() => { cat1 = new Cat(); cat1.updateAbilityEstimate( [ - { a: 2.225, b: -1.885, c: 0.21, d: 1 }, - { a: 1.174, b: -2.411, c: 0.212, d: 1 }, - { a: 2.104, b: -2.439, c: 0.192, d: 1 }, + convertZetaImplicitToExplicit({ a: 2.225, b: -1.885, c: 0.21, d: 1 }), + convertZetaImplicitToExplicit({ a: 1.174, b: -2.411, c: 0.212, d: 1 }), + convertZetaImplicitToExplicit({ a: 2.104, b: -2.439, c: 0.192, d: 1 }), ], [1, 0, 1], ); @@ -19,13 +35,13 @@ describe('Cat', () => { cat2 = new Cat(); cat2.updateAbilityEstimate( [ - { a: 1, b: -0.447, c: 0.5, d: 1 }, - { a: 1, b: 2.869, c: 0.5, d: 1 }, - { a: 1, b: -0.469, c: 0.5, d: 1 }, - { a: 1, b: -0.576, c: 0.5, d: 1 }, - { a: 1, b: -1.43, c: 0.5, d: 1 }, - { a: 1, b: -1.607, c: 0.5, d: 1 }, - { a: 1, b: 0.529, c: 0.5, d: 1 }, + convertZetaImplicitToExplicit({ a: 1, b: -0.447, c: 0.5, d: 1 }), + convertZetaImplicitToExplicit({ a: 1, b: 2.869, c: 0.5, d: 1 }), + convertZetaImplicitToExplicit({ a: 1, b: -0.469, c: 0.5, d: 1 }), + convertZetaImplicitToExplicit({ a: 1, b: -0.576, c: 0.5, d: 1 }), + 
convertZetaImplicitToExplicit({ a: 1, b: -1.43, c: 0.5, d: 1 }), + convertZetaImplicitToExplicit({ a: 1, b: -1.607, c: 0.5, d: 1 }), + convertZetaImplicitToExplicit({ a: 1, b: 0.529, c: 0.5, d: 1 }), ], [0, 1, 0, 1, 1, 1, 1], ); @@ -33,13 +49,13 @@ describe('Cat', () => { const randomSeed = 'test'; rng = seedrandom(randomSeed); cat4 = new Cat({ nStartItems: 0, itemSelect: 'RANDOM', randomSeed }); - cat5 = new Cat({ nStartItems: 1, startSelect: 'miDdle' }); + cat5 = new Cat({ nStartItems: 1, startSelect: 'miDdle' }); // ask cat6 = new Cat(); cat6.updateAbilityEstimate( [ - { a: 1, b: -4.0, c: 0.5, d: 1 }, - { a: 1, b: -3.0, c: 0.5, d: 1 }, + convertZetaImplicitToExplicit({ a: 1, b: -4.0, c: 0.5, d: 1 }), + convertZetaImplicitToExplicit({ a: 1, b: -3.0, c: 0.5, d: 1 }), ], [0, 0], ); @@ -47,8 +63,8 @@ describe('Cat', () => { cat7 = new Cat({ method: 'eap' }); cat7.updateAbilityEstimate( [ - { a: 1, b: -4.0, c: 0.5, d: 1 }, - { a: 1, b: -3.0, c: 0.5, d: 1 }, + convertZetaImplicitToExplicit({ a: 1, b: -4.0, c: 0.5, d: 1 }), + convertZetaImplicitToExplicit({ a: 1, b: -3.0, c: 0.5, d: 1 }), ], [0, 0], ); @@ -56,11 +72,11 @@ describe('Cat', () => { cat8 = new Cat({ nStartItems: 0, itemSelect: 'FIXED' }); }); - const s1: Stimulus = { difficulty: 0.5, c: 0.5, word: 'looking' }; - const s2: Stimulus = { difficulty: 3.5, c: 0.5, word: 'opaque' }; - const s3: Stimulus = { difficulty: 2, c: 0.5, word: 'right' }; - const s4: Stimulus = { difficulty: -2.5, c: 0.5, word: 'yes' }; - const s5: Stimulus = { difficulty: -1.8, c: 0.5, word: 'mom' }; + const s1: Stimulus = { difficulty: 0.5, guessing: 0.5, discrimination: 1, slipping: 1, word: 'looking' }; + const s2: Stimulus = { difficulty: 3.5, guessing: 0.5, discrimination: 1, slipping: 1, word: 'opaque' }; + const s3: Stimulus = { difficulty: 2, guessing: 0.5, discrimination: 1, slipping: 1, word: 'right' }; + const s4: Stimulus = { difficulty: -2.5, guessing: 0.5, discrimination: 1, slipping: 1, word: 'yes' }; + const s5: Stimulus 
= { difficulty: -1.8, guessing: 0.5, discrimination: 1, slipping: 1, word: 'mom' }; const stimuli = [s1, s2, s3, s4, s5]; it('constructs an adaptive test', () => { @@ -88,15 +104,15 @@ describe('Cat', () => { expect(cat2.resps).toEqual([0, 1, 0, 1, 1, 1, 1]); }); - it('correctly updates zatas', () => { + it('correctly updates zetas', () => { expect(cat2.zetas).toEqual([ - { a: 1, b: -0.447, c: 0.5, d: 1 }, - { a: 1, b: 2.869, c: 0.5, d: 1 }, - { a: 1, b: -0.469, c: 0.5, d: 1 }, - { a: 1, b: -0.576, c: 0.5, d: 1 }, - { a: 1, b: -1.43, c: 0.5, d: 1 }, - { a: 1, b: -1.607, c: 0.5, d: 1 }, - { a: 1, b: 0.529, c: 0.5, d: 1 }, + convertZetaImplicitToExplicit({ a: 1, b: -0.447, c: 0.5, d: 1 }), + convertZetaImplicitToExplicit({ a: 1, b: 2.869, c: 0.5, d: 1 }), + convertZetaImplicitToExplicit({ a: 1, b: -0.469, c: 0.5, d: 1 }), + convertZetaImplicitToExplicit({ a: 1, b: -0.576, c: 0.5, d: 1 }), + convertZetaImplicitToExplicit({ a: 1, b: -1.43, c: 0.5, d: 1 }), + convertZetaImplicitToExplicit({ a: 1, b: -1.607, c: 0.5, d: 1 }), + convertZetaImplicitToExplicit({ a: 1, b: 0.529, c: 0.5, d: 1 }), ]); }); @@ -127,7 +143,7 @@ describe('Cat', () => { it('correctly suggests the next item (random method)', () => { let received; - const stimuliSorted = stimuli.sort((a: Stimulus, b: Stimulus) => a.difficulty - b.difficulty); + const stimuliSorted = stimuli.sort((a: Stimulus, b: Stimulus) => a.difficulty - b.difficulty); // ask let index = Math.floor(rng() * stimuliSorted.length); received = cat4.findNextItem(stimuliSorted); expect(received.nextStimulus).toEqual(stimuliSorted[index]); @@ -148,12 +164,12 @@ describe('Cat', () => { expect(cat7.theta).toBeCloseTo(0.25, 1); }); - it('should throw a error if zeta and answers do not have matching length', () => { + it('should throw an error if zeta and answers do not have matching length', () => { try { cat7.updateAbilityEstimate( [ - { a: 1, b: -4.0, c: 0.5, d: 1 }, - { a: 1, b: -3.0, c: 0.5, d: 1 }, + convertZetaImplicitToExplicit({ a: 1, 
b: -4.0, c: 0.5, d: 1 }), + convertZetaImplicitToExplicit({ a: 1, b: -3.0, c: 0.5, d: 1 }), ], [0, 0, 0], ); @@ -162,7 +178,7 @@ describe('Cat', () => { } }); - it('should throw a error if method is invalid', () => { + it('should throw an error if method is invalid', () => { try { new Cat({ method: 'coolMethod' }); } catch (error) { @@ -172,8 +188,8 @@ describe('Cat', () => { try { cat7.updateAbilityEstimate( [ - { a: 1, b: -4.0, c: 0.5, d: 1 }, - { a: 1, b: -3.0, c: 0.5, d: 1 }, + convertZetaImplicitToExplicit({ a: 1, b: -4.0, c: 0.5, d: 1 }), + convertZetaImplicitToExplicit({ a: 1, b: -3.0, c: 0.5, d: 1 }), ], [0, 0], 'coolMethod', @@ -183,7 +199,7 @@ describe('Cat', () => { } }); - it('should throw a error if itemSelect is invalid', () => { + it('should throw an error if itemSelect is invalid', () => { try { new Cat({ itemSelect: 'coolMethod' }); } catch (error) { @@ -197,11 +213,17 @@ describe('Cat', () => { } }); - it('should throw a error if startSelect is invalid', () => { + it('should throw an error if startSelect is invalid', () => { try { new Cat({ startSelect: 'coolMethod' }); } catch (error) { expect(error).toBeInstanceOf(Error); } }); + + it('should return undefined if there are no input items', () => { + const cat10 = new Cat(); + const { nextStimulus } = cat10.findNextItem([]); + expect(nextStimulus).toBeUndefined(); + }); }); diff --git a/src/__tests__/utils.test.ts b/src/__tests__/utils.test.ts index d8392b5..8895e8e 100644 --- a/src/__tests__/utils.test.ts +++ b/src/__tests__/utils.test.ts @@ -2,40 +2,56 @@ import { itemResponseFunction, fisherInformation, findClosest } from '../utils'; describe('itemResponseFunction', () => { it('correctly calculates the probability', () => { - expect(0.7234).toBeCloseTo(itemResponseFunction(0, { a: 1, b: -0.3, c: 0.35, d: 1 }), 2); - - expect(0.5).toBeCloseTo(itemResponseFunction(0, { a: 1, b: 0, c: 0, d: 1 }), 2); - - expect(0.625).toBeCloseTo(itemResponseFunction(0, { a: 0.5, b: 0, c: 0.25, d: 1 }), 2); + 
expect(itemResponseFunction(0, { a: 1, b: -0.3, c: 0.35, d: 1 })).toBeCloseTo(0.7234, 2); + expect(itemResponseFunction(0, { a: 1, b: 0, c: 0, d: 1 })).toBeCloseTo(0.5, 2); + expect(itemResponseFunction(0, { a: 0.5, b: 0, c: 0.25, d: 1 })).toBeCloseTo(0.625, 2); }); }); describe('fisherInformation', () => { it('correctly calculates the information', () => { - expect(0.206).toBeCloseTo(fisherInformation(0, { a: 1.53, b: -0.5, c: 0.5, d: 1 }), 2); - - expect(0.1401).toBeCloseTo(fisherInformation(2.35, { a: 1, b: 2, c: 0.3, d: 1 }), 2); + expect(fisherInformation(0, { a: 1.53, b: -0.5, c: 0.5, d: 1 })).toBeCloseTo(0.206, 2); + expect(fisherInformation(2.35, { a: 1, b: 2, c: 0.3, d: 1 })).toBeCloseTo(0.1401, 2); }); }); describe('findClosest', () => { + const stimuli = [ + { difficulty: 1, discrimination: 1, guessing: 0.25, slipping: 0.75 }, + { difficulty: 4, discrimination: 1, guessing: 0.25, slipping: 0.75 }, + { difficulty: 10, discrimination: 1, guessing: 0.25, slipping: 0.75 }, + { difficulty: 11, discrimination: 1, guessing: 0.25, slipping: 0.75 }, + ]; + it('correctly selects the first item if appropriate', () => { - expect(0).toBe(findClosest([{ difficulty: 1 }, { difficulty: 4 }, { difficulty: 10 }, { difficulty: 11 }], 0)); + expect(findClosest(stimuli, 0)).toBe(0); }); + it('correctly selects the last item if appropriate', () => { - expect(3).toBe(findClosest([{ difficulty: 1 }, { difficulty: 4 }, { difficulty: 10 }, { difficulty: 11 }], 1000)); + expect(findClosest(stimuli, 1000)).toBe(3); }); + it('correctly selects a middle item if it equals exactly', () => { - expect(2).toBe(findClosest([{ difficulty: 1 }, { difficulty: 4 }, { difficulty: 10 }, { difficulty: 11 }], 10)); + expect(findClosest(stimuli, 10)).toBe(2); }); + it('correctly selects the one item closest to the target if less than', () => { - expect(1).toBe( - findClosest([{ difficulty: 1.1 }, { difficulty: 4.2 }, { difficulty: 10.3 }, { difficulty: 11.4 }], 5.1), - ); + const stimuliWithDecimal 
= [ + { difficulty: 1.1, discrimination: 1, guessing: 0.25, slipping: 0.75 }, + { difficulty: 4.2, discrimination: 1, guessing: 0.25, slipping: 0.75 }, + { difficulty: 10.3, discrimination: 1, guessing: 0.25, slipping: 0.75 }, + { difficulty: 11.4, discrimination: 1, guessing: 0.25, slipping: 0.75 }, + ]; + expect(findClosest(stimuliWithDecimal, 5.1)).toBe(1); }); + it('correctly selects the one item closest to the target if greater than', () => { - expect(2).toBe( - findClosest([{ difficulty: 1.1 }, { difficulty: 4.2 }, { difficulty: 10.3 }, { difficulty: 11.4 }], 9.1), - ); + const stimuliWithDecimal = [ + { difficulty: 1.1, discrimination: 1, guessing: 0.25, slipping: 0.75 }, + { difficulty: 4.2, discrimination: 1, guessing: 0.25, slipping: 0.75 }, + { difficulty: 10.3, discrimination: 1, guessing: 0.25, slipping: 0.75 }, + { difficulty: 11.4, discrimination: 1, guessing: 0.25, slipping: 0.75 }, + ]; + expect(findClosest(stimuliWithDecimal, 9.1)).toBe(2); }); }); diff --git a/src/clowder.ts b/src/clowder.ts index dfcac44..d0ca653 100644 --- a/src/clowder.ts +++ b/src/clowder.ts @@ -1,93 +1,210 @@ import { Cat, CatInput } from './index'; -import { Stimulus } from './type'; - -/** - * Explanation: - * ClowderInput: Defines the input parameters for the Clowder class, including configurations for multiple Cat instances and corresponding corpora. - * Clowder class: - * constructor: Initializes the Clowder with multiple Cat instances and their respective corpora. - * getNextStimulus: Retrieves the next stimulus for a specific Cat instance. - * getNextStimuli: Retrieves the next stimuli for all Cat instances. - * addCat: Adds a new Cat instance to the Clowder. - * removeCat: Removes a Cat instance from the Clowder. 
- */ +import { Stimulus, Zeta } from './type'; +import _cloneDeep from 'lodash/cloneDeep'; +import _mapValues from 'lodash/mapValues'; +import _zip from 'lodash/zip'; + +interface Corpora { + validated: Stimulus[]; + unvalidated: Stimulus[]; +} export interface ClowderInput { - cats: CatInput[]; // Array of Cat configurations - corpora: Stimulus[][]; // Array of stimuli arrays, one for each Cat + // An object containing Cat configurations for each Cat instance. + cats: { + [name: string]: CatInput; + }; + // An object containing arrays of stimuli for each corpus. + corpora: Corpora; } export class Clowder { - private cats: Cat[]; - private corpora: Stimulus[][]; + private cats: { [name: string]: Cat }; + private corpora: Corpora; + public remainingItems: Corpora; + public seenItems: Stimulus[]; /** * Create a Clowder object. * @param {ClowderInput} input - An object containing arrays of Cat configurations and corpora. */ constructor({ cats, corpora }: ClowderInput) { - if (cats.length !== corpora.length) { - throw new Error('The number of Cat instances must match the number of corpora'); + this.cats = _mapValues(cats, (catInput) => new Cat(catInput)); + this.seenItems = []; + this.corpora = corpora; + this.remainingItems = _cloneDeep(corpora); + } + + private _validateCatName(catName: string): void { + if (!Object.prototype.hasOwnProperty.call(this.cats, catName)) { + throw new Error(`Invalid Cat name. Expected one of ${Object.keys(this.cats).join(', ')}. 
Received ${catName}.`); } + } - // Initialize Cats and corresponding corpora - this.cats = cats.map((catInput) => new Cat(catInput)); - this.corpora = corpora; + public get theta() { + return _mapValues(this.cats, (cat) => cat.theta); + } + + public get seMeasurement() { + return _mapValues(this.cats, (cat) => cat.seMeasurement); + } + + public get nItems() { + return _mapValues(this.cats, (cat) => cat.nItems); + } + + public get resps() { + return _mapValues(this.cats, (cat) => cat.resps); + } + + public get zetas() { + return _mapValues(this.cats, (cat) => cat.zetas); + } + + public updateAbilityEstimates(catNames: string[], zeta: Zeta | Zeta[], answer: (0 | 1) | (0 | 1)[], method?: string) { + catNames.forEach((catName) => { + this._validateCatName(catName); + }); + for (const catName of catNames) { + this.cats[catName].updateAbilityEstimate(zeta, answer, method); + } } /** - * Get the next stimulus for a specific Cat instance. - * @param {number} catIndex - The index of the Cat instance to select the next stimulus for. - * @returns {Stimulus | null} The next stimulus, or null if no more stimuli are available. + * Updates the ability estimates for the specified `catsToUpdate` and selects the next stimulus for the `catToSelect`. + * This function processes previous items and answers, updates internal state, and selects the next stimulus + * based on the current state of validated and unvalidated stimuli. + * + * @param {Object} params - The parameters for updating the Cat instance and selecting the next stimulus. + * @param {string} params.catToSelect - The Cat instance to use for selecting the next stimulus. + * @param {string | string[]} [params.catsToUpdate=[]] - A single Cat or array of Cats for which to update ability estimates. + * @param {Stimulus[]} [params.previousItems=[]] - An array of previously presented stimuli. + * @param {(0 | 1) | (0 | 1)[]} [params.previousAnswers=[]] - An array of answers (0 or 1) corresponding to `previousItems`. 
+ * @param {string} [params.method] - Optional method for updating ability estimates (if applicable). + * + * @returns {Stimulus | undefined} - The next stimulus to present, or `undefined` if no further validated stimuli are available. + * + * @throws {Error} If `previousItems` and `previousAnswers` lengths do not match. + * @throws {Error} If any `previousItems` are not found in the Clowder's corpora (validated or unvalidated). + * + * The function operates in several steps: + * 1. Validates the `catToSelect` and `catsToUpdate`. + * 2. Ensures `previousItems` and `previousAnswers` arrays are properly formatted. + * 3. Updates the internal list of seen items. + * 4. Updates the ability estimates for the `catsToUpdate`. + * 5. Selects the next stimulus for `catToSelect`, considering validated and unvalidated stimuli. */ - public getNextStimulus(catIndex: number): Stimulus | null { - if (catIndex < 0 || catIndex >= this.cats.length) { - throw new Error('Invalid Cat index'); + public updateCatAndGetNextItem({ + catToSelect, + catsToUpdate = [], + previousItems = [], + previousAnswers = [], + method, + }: { + catToSelect: string; + catsToUpdate?: string | string[]; + previousItems: Stimulus[]; + previousAnswers: (0 | 1) | (0 | 1)[]; + method?: string; + }): Stimulus | undefined { + this._validateCatName(catToSelect); + + catsToUpdate = Array.isArray(catsToUpdate) ? catsToUpdate : [catsToUpdate]; + catsToUpdate.forEach((cat) => { + this._validateCatName(cat); + }); + + previousItems = Array.isArray(previousItems) ? previousItems : [previousItems]; + previousAnswers = Array.isArray(previousAnswers) ? 
previousAnswers : [previousAnswers]; + + if (previousItems.length !== previousAnswers.length) { + throw new Error('Previous items and answers must have the same length.'); } - const cat = this.cats[catIndex]; - const stimuli = this.corpora[catIndex]; + // Update the seenItems with the provided previous items + this.seenItems.push(...previousItems); - if (stimuli.length === 0) { - return null; // No more stimuli available for this Cat + const itemsAndAnswers = _zip(previousItems, previousAnswers) as [Stimulus, 0 | 1][]; + const validatedItemsAndAnswers = itemsAndAnswers.filter(([item, _answer]) => this.corpora.validated.includes(item)); + const unvalidatedItemsAndAnswers = itemsAndAnswers.filter(([item, _answer]) => + this.corpora.unvalidated.includes(item), + ); + + const invalidItems = itemsAndAnswers.filter(([item, _answer]) => { + return !this.corpora.validated.includes(item) && !this.corpora.unvalidated.includes(item); + }); + + if (!invalidItems) { + throw new Error( + `The following previous items provided are not in this Clowder's corpora:\n${JSON.stringify( + invalidItems, + null, + 2, + )} ${invalidItems}`, + ); } - const { nextStimulus, remainingStimuli } = cat.findNextItem(stimuli); - this.corpora[catIndex] = remainingStimuli; // Update the corpus for this Cat + const validatedStimuli = validatedItemsAndAnswers.map(([stim, _]) => stim); + const unvalidatedStimuli = unvalidatedItemsAndAnswers.map(([stim, _]) => stim); + const validatedAnswers = validatedItemsAndAnswers.map(([_, answer]) => answer); - return nextStimulus; - } + // Remove previous items from the remainingItems + this.remainingItems.validated = this.remainingItems.validated.filter((item) => !validatedStimuli.includes(item)); + this.remainingItems.unvalidated = this.remainingItems.unvalidated.filter( + (item) => !unvalidatedStimuli.includes(item), + ); - /** - * Get the next stimuli for all Cat instances. - * @returns {Stimulus[]} An array of next stimuli for each Cat. 
- */ - public getNextStimuli(): Stimulus[] { - return this.cats - .map((_, index) => this.getNextStimulus(index)) - .filter((stimulus): stimulus is Stimulus => stimulus !== null); + // Update the ability estimates for the requested Cats + this.updateAbilityEstimates(catsToUpdate, validatedStimuli, validatedAnswers, method); + + // Use the catForSelect to determine the next stimulus + const cat = this.cats[catToSelect]; + const { nextStimulus } = cat.findNextItem(this.remainingItems.validated); + + // Added some logic to mix in the unvalidated stimuli if needed. + if (this.remainingItems.unvalidated.length === 0) { + // If there are no more unvalidated stimuli, we only have validated items left. + // Use the Cat to find the next item. The Cat may return undefined if all validated items have been seen. + return nextStimulus; + } else if (this.remainingItems.validated.length === 0) { + // In this case, there are no more validated items left. Choose an unvalidated item at random. + return this.remainingItems.unvalidated[Math.floor(Math.random() * this.remainingItems.unvalidated.length)]; + } else { + // In this case, there are both validated and unvalidated items left. + // We need to randomly insert unvalidated items + const numRemaining = { + validated: this.remainingItems.validated.length, + unvalidated: this.remainingItems.unvalidated.length, + }; + const random = Math.random(); + + if (random < numRemaining.unvalidated / (numRemaining.validated + numRemaining.unvalidated)) { + return this.remainingItems.unvalidated[Math.floor(Math.random() * this.remainingItems.unvalidated.length)]; + } else { + return nextStimulus; + } + } } /** * Add a new Cat instance to the Clowder. + * @param {string} catName - Name of the new Cat. * @param {CatInput} catInput - Configuration for the new Cat instance. * @param {Stimulus[]} stimuli - The corpus for the new Cat. 
*/ - public addCat(catInput: CatInput, stimuli: Stimulus[]) { - this.cats.push(new Cat(catInput)); - this.corpora.push(stimuli); + public addCat(catName: string, catInput: CatInput) { + if (Object.prototype.hasOwnProperty.call(this.cats, catName)) { + throw new Error(`Cat with the name "${catName}" already exists.`); + } + this.cats[catName] = new Cat(catInput); } /** * Remove a Cat instance from the Clowder. - * @param {number} catIndex - The index of the Cat instance to remove. + * @param {string} catName - The name of the Cat instance to remove. */ - public removeCat(catIndex: number) { - if (catIndex < 0 || catIndex >= this.cats.length) { - throw new Error('Invalid Cat index'); - } - - this.cats.splice(catIndex, 1); - this.corpora.splice(catIndex, 1); + public removeCat(catName: string) { + this._validateCatName(catName); + delete this.cats[catName]; } } diff --git a/src/index.ts b/src/index.ts index 83424c1..5c79286 100644 --- a/src/index.ts +++ b/src/index.ts @@ -199,8 +199,7 @@ export class Cat { * @param stimuli - an array of stimulus * @param itemSelect - the item selection method * @param deepCopy - default deepCopy = true - * @returns {nextStimulus: Stimulus, - remainingStimuli: Array} + * @returns {nextStimulus: Stimulus, remainingStimuli: Array} */ public findNextItem(stimuli: Stimulus[], itemSelect: string = this.itemSelect, deepCopy = true) { let arr: Array; @@ -302,7 +301,7 @@ export class Cat { * @returns {Stimulus[]} return.remainingStimuli - The list of what's left after picking the item. */ private selectorFixed(arr: Stimulus[]) { - const nextItem = arr.shift() ?? 
null; + const nextItem = arr.shift(); return { nextStimulus: nextItem, remainingStimuli: arr, diff --git a/src/type.ts b/src/type.ts index 36b4793..46e4813 100644 --- a/src/type.ts +++ b/src/type.ts @@ -1,7 +1,27 @@ -export type Zeta = { a: number; b: number; c: number; d: number }; +export const zetaKeyMap = { + a: 'discrimination', + b: 'difficulty', + c: 'guessing', + d: 'slipping', +}; -export interface Stimulus { +export type ZetaImplicit = { + a: number; // Discrimination (slope of the curve) + b: number; // Difficulty (location of the curve) + c: number; // Guessing (lower asymptote) + d: number; // Slipping (upper asymptote) +}; + +export type ZetaExplicit = { + discrimination: number; difficulty: number; + guessing: number; + slipping: number; +}; + +export type Zeta = ZetaImplicit | ZetaExplicit; + +export interface Stimulus extends ZetaExplicit { // eslint-disable-next-line @typescript-eslint/no-explicit-any [key: string]: any; } diff --git a/src/utils.ts b/src/utils.ts index f7a8d8c..6982fab 100644 --- a/src/utils.ts +++ b/src/utils.ts @@ -1,5 +1,5 @@ import bs from 'binary-search'; -import { Stimulus, Zeta } from './type'; +import { Stimulus, Zeta, ZetaExplicit, ZetaImplicit } from './type'; /** * calculates the probability that someone with a given ability level theta will answer correctly an item. 
Uses the 4 parameters logistic model @@ -8,7 +8,16 @@ import { Stimulus, Zeta } from './type'; * @returns {number} the probability */ export const itemResponseFunction = (theta: number, zeta: Zeta) => { - return zeta.c + (zeta.d - zeta.c) / (1 + Math.exp(-zeta.a * (theta - zeta.b))); + if ((zeta as ZetaImplicit).a) { + const _zeta = zeta as ZetaImplicit; + return _zeta.c + (_zeta.d - _zeta.c) / (1 + Math.exp(-_zeta.a * (theta - _zeta.b))); + } else { + const _zeta = zeta as ZetaExplicit; + return ( + _zeta.guessing + + (_zeta.slipping - _zeta.guessing) / (1 + Math.exp(-_zeta.discrimination * (theta - _zeta.difficulty))) + ); + } }; /** @@ -20,7 +29,15 @@ export const itemResponseFunction = (theta: number, zeta: Zeta) => { export const fisherInformation = (theta: number, zeta: Zeta) => { const p = itemResponseFunction(theta, zeta); const q = 1 - p; - return Math.pow(zeta.a, 2) * (q / p) * (Math.pow(p - zeta.c, 2) / Math.pow(1 - zeta.c, 2)); + if ((zeta as ZetaImplicit).a) { + const _zeta = zeta as ZetaImplicit; + return Math.pow(_zeta.a, 2) * (q / p) * (Math.pow(p - _zeta.c, 2) / Math.pow(1 - _zeta.c, 2)); + } else { + const _zeta = zeta as ZetaExplicit; + return ( + Math.pow(_zeta.discrimination, 2) * (q / p) * (Math.pow(p - _zeta.guessing, 2) / Math.pow(1 - _zeta.guessing, 2)) + ); + } }; /** From 3a0cdcb2e47fb9fae6d1136a34acf2de2e677dd7 Mon Sep 17 00:00:00 2001 From: emily-ejag Date: Mon, 9 Sep 2024 17:06:36 -0700 Subject: [PATCH 04/47] eslint for unused -- used variables --- src/clowder.ts | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/clowder.ts b/src/clowder.ts index d0ca653..97ee537 100644 --- a/src/clowder.ts +++ b/src/clowder.ts @@ -125,11 +125,13 @@ export class Clowder { this.seenItems.push(...previousItems); const itemsAndAnswers = _zip(previousItems, previousAnswers) as [Stimulus, 0 | 1][]; + // eslint-disable-next-line @typescript-eslint/no-unused-vars const validatedItemsAndAnswers = itemsAndAnswers.filter(([item, 
_answer]) => this.corpora.validated.includes(item)); + // eslint-disable-next-line @typescript-eslint/no-unused-vars const unvalidatedItemsAndAnswers = itemsAndAnswers.filter(([item, _answer]) => this.corpora.unvalidated.includes(item), ); - + // eslint-disable-next-line @typescript-eslint/no-unused-vars const invalidItems = itemsAndAnswers.filter(([item, _answer]) => { return !this.corpora.validated.includes(item) && !this.corpora.unvalidated.includes(item); }); @@ -143,9 +145,11 @@ export class Clowder { )} ${invalidItems}`, ); } - + // eslint-disable-next-line @typescript-eslint/no-unused-vars const validatedStimuli = validatedItemsAndAnswers.map(([stim, _]) => stim); + // eslint-disable-next-line @typescript-eslint/no-unused-vars const unvalidatedStimuli = unvalidatedItemsAndAnswers.map(([stim, _]) => stim); + // eslint-disable-next-line @typescript-eslint/no-unused-vars const validatedAnswers = validatedItemsAndAnswers.map(([_, answer]) => answer); // Remove previous items from the remainingItems From 2b570a148cc66d6ddb3c59a9e8a21b616735becd Mon Sep 17 00:00:00 2001 From: emily-ejag Date: Mon, 9 Sep 2024 17:09:29 -0700 Subject: [PATCH 05/47] clowder import --- src/__tests__/clowder.test.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/__tests__/clowder.test.ts b/src/__tests__/clowder.test.ts index af1daf9..ea35133 100644 --- a/src/__tests__/clowder.test.ts +++ b/src/__tests__/clowder.test.ts @@ -1,4 +1,4 @@ -import { Clowder, ClowderInput } from '../Clowder'; +import { Clowder, ClowderInput } from '../clowder'; import { Stimulus } from '../type'; // Mocking Stimulus From a1f608abaccf5f8c0e7fb9b2bc79868bade60f94 Mon Sep 17 00:00:00 2001 From: Adam Richie-Halford Date: Tue, 17 Sep 2024 17:33:58 -0700 Subject: [PATCH 06/47] Add zetas for multiple cats to the corpus --- src/__tests__/index.test.ts | 407 +++++++++++++++++------------------- src/clowder.ts | 140 +++++-------- src/index.ts | 36 ++-- src/type.ts | 41 ++-- src/utils.ts | 146 
++++++++++--- 5 files changed, 409 insertions(+), 361 deletions(-) diff --git a/src/__tests__/index.test.ts b/src/__tests__/index.test.ts index 37bcaec..bcbbd7c 100644 --- a/src/__tests__/index.test.ts +++ b/src/__tests__/index.test.ts @@ -1,229 +1,206 @@ +/* eslint-disable @typescript-eslint/no-non-null-assertion */ import { Cat } from '../index'; -import { zetaKeyMap, Stimulus, ZetaImplicit, ZetaExplicit } from '../type'; +import { Stimulus } from '../type'; import seedrandom from 'seedrandom'; -import _mapKeys from 'lodash/mapKeys'; - -// Convert ZetaImplicit to ZetaExplicit -const convertZetaImplicitToExplicit = (zeta: ZetaImplicit): ZetaExplicit => { - const explicitZeta = _mapKeys(zeta, (value, key) => { - return zetaKeyMap[key as keyof typeof zetaKeyMap]; - }) as ZetaExplicit; - - return { - discrimination: explicitZeta.discrimination, - difficulty: explicitZeta.difficulty, - guessing: explicitZeta.guessing, - slipping: explicitZeta.slipping, - }; -}; - -describe('Cat', () => { - let cat1: Cat, cat2: Cat, cat3: Cat, cat4: Cat, cat5: Cat, cat6: Cat, cat7: Cat, cat8: Cat; - let rng = seedrandom(); - - beforeEach(() => { - cat1 = new Cat(); - cat1.updateAbilityEstimate( - [ - convertZetaImplicitToExplicit({ a: 2.225, b: -1.885, c: 0.21, d: 1 }), - convertZetaImplicitToExplicit({ a: 1.174, b: -2.411, c: 0.212, d: 1 }), - convertZetaImplicitToExplicit({ a: 2.104, b: -2.439, c: 0.192, d: 1 }), - ], - [1, 0, 1], - ); - - cat2 = new Cat(); - cat2.updateAbilityEstimate( - [ - convertZetaImplicitToExplicit({ a: 1, b: -0.447, c: 0.5, d: 1 }), - convertZetaImplicitToExplicit({ a: 1, b: 2.869, c: 0.5, d: 1 }), - convertZetaImplicitToExplicit({ a: 1, b: -0.469, c: 0.5, d: 1 }), - convertZetaImplicitToExplicit({ a: 1, b: -0.576, c: 0.5, d: 1 }), - convertZetaImplicitToExplicit({ a: 1, b: -1.43, c: 0.5, d: 1 }), - convertZetaImplicitToExplicit({ a: 1, b: -1.607, c: 0.5, d: 1 }), - convertZetaImplicitToExplicit({ a: 1, b: 0.529, c: 0.5, d: 1 }), - ], - [0, 1, 0, 1, 1, 1, 1], 
- ); - cat3 = new Cat({ nStartItems: 0 }); - const randomSeed = 'test'; - rng = seedrandom(randomSeed); - cat4 = new Cat({ nStartItems: 0, itemSelect: 'RANDOM', randomSeed }); - cat5 = new Cat({ nStartItems: 1, startSelect: 'miDdle' }); // ask - - cat6 = new Cat(); - cat6.updateAbilityEstimate( - [ - convertZetaImplicitToExplicit({ a: 1, b: -4.0, c: 0.5, d: 1 }), - convertZetaImplicitToExplicit({ a: 1, b: -3.0, c: 0.5, d: 1 }), - ], - [0, 0], - ); - - cat7 = new Cat({ method: 'eap' }); - cat7.updateAbilityEstimate( - [ - convertZetaImplicitToExplicit({ a: 1, b: -4.0, c: 0.5, d: 1 }), - convertZetaImplicitToExplicit({ a: 1, b: -3.0, c: 0.5, d: 1 }), - ], - [0, 0], - ); - - cat8 = new Cat({ nStartItems: 0, itemSelect: 'FIXED' }); - }); - - const s1: Stimulus = { difficulty: 0.5, guessing: 0.5, discrimination: 1, slipping: 1, word: 'looking' }; - const s2: Stimulus = { difficulty: 3.5, guessing: 0.5, discrimination: 1, slipping: 1, word: 'opaque' }; - const s3: Stimulus = { difficulty: 2, guessing: 0.5, discrimination: 1, slipping: 1, word: 'right' }; - const s4: Stimulus = { difficulty: -2.5, guessing: 0.5, discrimination: 1, slipping: 1, word: 'yes' }; - const s5: Stimulus = { difficulty: -1.8, guessing: 0.5, discrimination: 1, slipping: 1, word: 'mom' }; - const stimuli = [s1, s2, s3, s4, s5]; - - it('constructs an adaptive test', () => { - expect(cat1.method).toBe('mle'); - expect(cat1.itemSelect).toBe('mfi'); - }); - - it('correctly updates ability estimate', () => { - expect(cat1.theta).toBeCloseTo(-1.642307, 1); - }); - - it('correctly updates ability estimate', () => { - expect(cat2.theta).toBeCloseTo(-1.272, 1); - }); - - it('correctly updates standard error of mean of ability estimate', () => { - expect(cat2.seMeasurement).toBeCloseTo(1.71, 1); - }); - - it('correctly counts number of items', () => { - expect(cat2.nItems).toEqual(7); - }); - - it('correctly updates answers', () => { - expect(cat2.resps).toEqual([0, 1, 0, 1, 1, 1, 1]); - }); - - it('correctly 
updates zetas', () => { - expect(cat2.zetas).toEqual([ - convertZetaImplicitToExplicit({ a: 1, b: -0.447, c: 0.5, d: 1 }), - convertZetaImplicitToExplicit({ a: 1, b: 2.869, c: 0.5, d: 1 }), - convertZetaImplicitToExplicit({ a: 1, b: -0.469, c: 0.5, d: 1 }), - convertZetaImplicitToExplicit({ a: 1, b: -0.576, c: 0.5, d: 1 }), - convertZetaImplicitToExplicit({ a: 1, b: -1.43, c: 0.5, d: 1 }), - convertZetaImplicitToExplicit({ a: 1, b: -1.607, c: 0.5, d: 1 }), - convertZetaImplicitToExplicit({ a: 1, b: 0.529, c: 0.5, d: 1 }), - ]); - }); - - it('correctly suggests the next item (closest method)', () => { - const expected = { nextStimulus: s5, remainingStimuli: [s4, s1, s3, s2] }; - const received = cat1.findNextItem(stimuli, 'closest'); - expect(received).toEqual(expected); - }); - - it('correctly suggests the next item (mfi method)', () => { - const expected = { nextStimulus: s1, remainingStimuli: [s4, s5, s3, s2] }; - const received = cat3.findNextItem(stimuli, 'MFI'); - expect(received).toEqual(expected); - }); - - it('correctly suggests the next item (middle method)', () => { - const expected = { nextStimulus: s1, remainingStimuli: [s4, s5, s3, s2] }; - const received = cat5.findNextItem(stimuli); - expect(received).toEqual(expected); - }); - - it('correctly suggests the next item (fixed method)', () => { - expect(cat8.itemSelect).toBe('fixed'); - const expected = { nextStimulus: s1, remainingStimuli: [s2, s3, s4, s5] }; - const received = cat8.findNextItem(stimuli); - expect(received).toEqual(expected); - }); +import { convertZeta } from '../utils'; - it('correctly suggests the next item (random method)', () => { - let received; - const stimuliSorted = stimuli.sort((a: Stimulus, b: Stimulus) => a.difficulty - b.difficulty); // ask - let index = Math.floor(rng() * stimuliSorted.length); - received = cat4.findNextItem(stimuliSorted); - expect(received.nextStimulus).toEqual(stimuliSorted[index]); - - for (let i = 0; i < 3; i++) { - const remainingStimuli = 
received.remainingStimuli; - index = Math.floor(rng() * remainingStimuli.length); - received = cat4.findNextItem(remainingStimuli); - expect(received.nextStimulus).toEqual(remainingStimuli[index]); - } - }); - - it('correctly updates ability estimate through MLE', () => { - expect(cat6.theta).toBeCloseTo(-6.0, 1); - }); +for (const format of ['symbolic', 'semantic'] as Array<'symbolic' | 'semantic'>) { + describe('Cat with explicit zeta', () => { + let cat1: Cat, cat2: Cat, cat3: Cat, cat4: Cat, cat5: Cat, cat6: Cat, cat7: Cat, cat8: Cat; + let rng = seedrandom(); - it('correctly updates ability estimate through EAP', () => { - expect(cat7.theta).toBeCloseTo(0.25, 1); - }); - - it('should throw an error if zeta and answers do not have matching length', () => { - try { - cat7.updateAbilityEstimate( + beforeEach(() => { + cat1 = new Cat(); + cat1.updateAbilityEstimate( [ - convertZetaImplicitToExplicit({ a: 1, b: -4.0, c: 0.5, d: 1 }), - convertZetaImplicitToExplicit({ a: 1, b: -3.0, c: 0.5, d: 1 }), + convertZeta({ a: 2.225, b: -1.885, c: 0.21, d: 1 }, format), + convertZeta({ a: 1.174, b: -2.411, c: 0.212, d: 1 }, format), + convertZeta({ a: 2.104, b: -2.439, c: 0.192, d: 1 }, format), ], - [0, 0, 0], + [1, 0, 1], ); - } catch (error) { - expect(error).toBeInstanceOf(Error); - } - }); - - it('should throw an error if method is invalid', () => { - try { - new Cat({ method: 'coolMethod' }); - } catch (error) { - expect(error).toBeInstanceOf(Error); - } - try { - cat7.updateAbilityEstimate( + cat2 = new Cat(); + cat2.updateAbilityEstimate( [ - convertZetaImplicitToExplicit({ a: 1, b: -4.0, c: 0.5, d: 1 }), - convertZetaImplicitToExplicit({ a: 1, b: -3.0, c: 0.5, d: 1 }), + convertZeta({ a: 1, b: -0.447, c: 0.5, d: 1 }, format), + convertZeta({ a: 1, b: 2.869, c: 0.5, d: 1 }, format), + convertZeta({ a: 1, b: -0.469, c: 0.5, d: 1 }, format), + convertZeta({ a: 1, b: -0.576, c: 0.5, d: 1 }, format), + convertZeta({ a: 1, b: -1.43, c: 0.5, d: 1 }, format), + 
convertZeta({ a: 1, b: -1.607, c: 0.5, d: 1 }, format), + convertZeta({ a: 1, b: 0.529, c: 0.5, d: 1 }, format), ], + [0, 1, 0, 1, 1, 1, 1], + ); + cat3 = new Cat({ nStartItems: 0 }); + const randomSeed = 'test'; + rng = seedrandom(randomSeed); + cat4 = new Cat({ nStartItems: 0, itemSelect: 'RANDOM', randomSeed }); + cat5 = new Cat({ nStartItems: 1, startSelect: 'miDdle' }); // ask + + cat6 = new Cat(); + cat6.updateAbilityEstimate( + [convertZeta({ a: 1, b: -4.0, c: 0.5, d: 1 }, format), convertZeta({ a: 1, b: -3.0, c: 0.5, d: 1 }, format)], [0, 0], - 'coolMethod', ); - } catch (error) { - expect(error).toBeInstanceOf(Error); - } - }); - - it('should throw an error if itemSelect is invalid', () => { - try { - new Cat({ itemSelect: 'coolMethod' }); - } catch (error) { - expect(error).toBeInstanceOf(Error); - } - - try { - cat7.findNextItem(stimuli, 'coolMethod'); - } catch (error) { - expect(error).toBeInstanceOf(Error); - } - }); - it('should throw an error if startSelect is invalid', () => { - try { - new Cat({ startSelect: 'coolMethod' }); - } catch (error) { - expect(error).toBeInstanceOf(Error); - } - }); + cat7 = new Cat({ method: 'eap' }); + cat7.updateAbilityEstimate( + [convertZeta({ a: 1, b: -4.0, c: 0.5, d: 1 }, format), convertZeta({ a: 1, b: -3.0, c: 0.5, d: 1 }, format)], + [0, 0], + ); - it('should return undefined if there are no input items', () => { - const cat10 = new Cat(); - const { nextStimulus } = cat10.findNextItem([]); - expect(nextStimulus).toBeUndefined(); - }); -}); + cat8 = new Cat({ nStartItems: 0, itemSelect: 'FIXED' }); + }); + + const s1: Stimulus = { difficulty: 0.5, guessing: 0.5, discrimination: 1, slipping: 1, word: 'looking' }; + const s2: Stimulus = { difficulty: 3.5, guessing: 0.5, discrimination: 1, slipping: 1, word: 'opaque' }; + const s3: Stimulus = { difficulty: 2, guessing: 0.5, discrimination: 1, slipping: 1, word: 'right' }; + const s4: Stimulus = { difficulty: -2.5, guessing: 0.5, discrimination: 1, slipping: 1, 
word: 'yes' }; + const s5: Stimulus = { difficulty: -1.8, guessing: 0.5, discrimination: 1, slipping: 1, word: 'mom' }; + const stimuli = [s1, s2, s3, s4, s5]; + + it('constructs an adaptive test', () => { + expect(cat1.method).toBe('mle'); + expect(cat1.itemSelect).toBe('mfi'); + }); + + it('correctly updates ability estimate', () => { + expect(cat1.theta).toBeCloseTo(-1.642307, 1); + }); + + it('correctly updates ability estimate', () => { + expect(cat2.theta).toBeCloseTo(-1.272, 1); + }); + + it('correctly updates standard error of mean of ability estimate', () => { + expect(cat2.seMeasurement).toBeCloseTo(1.71, 1); + }); + + it('correctly counts number of items', () => { + expect(cat2.nItems).toEqual(7); + }); + + it('correctly updates answers', () => { + expect(cat2.resps).toEqual([0, 1, 0, 1, 1, 1, 1]); + }); + + it('correctly updates zetas', () => { + expect(cat2.zetas).toEqual([ + convertZeta({ a: 1, b: -0.447, c: 0.5, d: 1 }, format), + convertZeta({ a: 1, b: 2.869, c: 0.5, d: 1 }, format), + convertZeta({ a: 1, b: -0.469, c: 0.5, d: 1 }, format), + convertZeta({ a: 1, b: -0.576, c: 0.5, d: 1 }, format), + convertZeta({ a: 1, b: -1.43, c: 0.5, d: 1 }, format), + convertZeta({ a: 1, b: -1.607, c: 0.5, d: 1 }, format), + convertZeta({ a: 1, b: 0.529, c: 0.5, d: 1 }, format), + ]); + }); + + it('correctly suggests the next item (closest method)', () => { + const expected = { nextStimulus: s5, remainingStimuli: [s4, s1, s3, s2] }; + const received = cat1.findNextItem(stimuli, 'closest'); + expect(received).toEqual(expected); + }); + + it('correctly suggests the next item (mfi method)', () => { + const expected = { nextStimulus: s1, remainingStimuli: [s4, s5, s3, s2] }; + const received = cat3.findNextItem(stimuli, 'MFI'); + expect(received).toEqual(expected); + }); + + it('correctly suggests the next item (middle method)', () => { + const expected = { nextStimulus: s1, remainingStimuli: [s4, s5, s3, s2] }; + const received = cat5.findNextItem(stimuli); + 
expect(received).toEqual(expected); + }); + + it('correctly suggests the next item (fixed method)', () => { + expect(cat8.itemSelect).toBe('fixed'); + const expected = { nextStimulus: s1, remainingStimuli: [s2, s3, s4, s5] }; + const received = cat8.findNextItem(stimuli); + expect(received).toEqual(expected); + }); + + it('correctly suggests the next item (random method)', () => { + let received; + const stimuliSorted = stimuli.sort((a: Stimulus, b: Stimulus) => a.difficulty! - b.difficulty!); // ask + let index = Math.floor(rng() * stimuliSorted.length); + received = cat4.findNextItem(stimuliSorted); + expect(received.nextStimulus).toEqual(stimuliSorted[index]); + + for (let i = 0; i < 3; i++) { + const remainingStimuli = received.remainingStimuli; + index = Math.floor(rng() * remainingStimuli.length); + received = cat4.findNextItem(remainingStimuli); + expect(received.nextStimulus).toEqual(remainingStimuli[index]); + } + }); + + it('correctly updates ability estimate through MLE', () => { + expect(cat6.theta).toBeCloseTo(-6.0, 1); + }); + + it('correctly updates ability estimate through EAP', () => { + expect(cat7.theta).toBeCloseTo(0.25, 1); + }); + + it('should throw an error if zeta and answers do not have matching length', () => { + try { + cat7.updateAbilityEstimate( + [convertZeta({ a: 1, b: -4.0, c: 0.5, d: 1 }, format), convertZeta({ a: 1, b: -3.0, c: 0.5, d: 1 }, format)], + [0, 0, 0], + ); + } catch (error) { + expect(error).toBeInstanceOf(Error); + } + }); + + it('should throw an error if method is invalid', () => { + try { + new Cat({ method: 'coolMethod' }); + } catch (error) { + expect(error).toBeInstanceOf(Error); + } + + try { + cat7.updateAbilityEstimate( + [convertZeta({ a: 1, b: -4.0, c: 0.5, d: 1 }, format), convertZeta({ a: 1, b: -3.0, c: 0.5, d: 1 }, format)], + [0, 0], + 'coolMethod', + ); + } catch (error) { + expect(error).toBeInstanceOf(Error); + } + }); + + it('should throw an error if itemSelect is invalid', () => { + try { + new Cat({ 
itemSelect: 'coolMethod' }); + } catch (error) { + expect(error).toBeInstanceOf(Error); + } + + try { + cat7.findNextItem(stimuli, 'coolMethod'); + } catch (error) { + expect(error).toBeInstanceOf(Error); + } + }); + + it('should throw an error if startSelect is invalid', () => { + try { + new Cat({ startSelect: 'coolMethod' }); + } catch (error) { + expect(error).toBeInstanceOf(Error); + } + }); + + it('should return undefined if there are no input items', () => { + const cat10 = new Cat(); + const { nextStimulus } = cat10.findNextItem([]); + expect(nextStimulus).toBeUndefined(); + }); + }); +} diff --git a/src/clowder.ts b/src/clowder.ts index 97ee537..60367c2 100644 --- a/src/clowder.ts +++ b/src/clowder.ts @@ -1,13 +1,10 @@ import { Cat, CatInput } from './index'; -import { Stimulus, Zeta } from './type'; +import { MultiZetaStimulus, Stimulus, Zeta, ZetaCatMap } from './type'; import _cloneDeep from 'lodash/cloneDeep'; import _mapValues from 'lodash/mapValues'; +import _unzip from 'lodash/unzip'; import _zip from 'lodash/zip'; - -interface Corpora { - validated: Stimulus[]; - unvalidated: Stimulus[]; -} +import { validateCorpora } from './utils'; export interface ClowderInput { // An object containing Cat configurations for each Cat instance. @@ -15,13 +12,13 @@ export interface ClowderInput { [name: string]: CatInput; }; // An object containing arrays of stimuli for each corpus. - corpora: Corpora; + corpora: MultiZetaStimulus[]; } export class Clowder { private cats: { [name: string]: Cat }; - private corpora: Corpora; - public remainingItems: Corpora; + private corpora: MultiZetaStimulus[]; + public remainingItems: MultiZetaStimulus[]; public seenItems: Stimulus[]; /** @@ -29,8 +26,10 @@ export class Clowder { * @param {ClowderInput} input - An object containing arrays of Cat configurations and corpora. */ constructor({ cats, corpora }: ClowderInput) { + // TODO: Need to pass in numItemsRequired so that we know when to stop providing new items. 
this.cats = _mapValues(cats, (catInput) => new Cat(catInput)); this.seenItems = []; + validateCorpora(corpora); this.corpora = corpora; this.remainingItems = _cloneDeep(corpora); } @@ -73,23 +72,23 @@ export class Clowder { /** * Updates the ability estimates for the specified `catsToUpdate` and selects the next stimulus for the `catToSelect`. * This function processes previous items and answers, updates internal state, and selects the next stimulus - * based on the current state of validated and unvalidated stimuli. + * based on the remaining stimuli and `catToSelect`. * - * @param {Object} params - The parameters for updating the Cat instance and selecting the next stimulus. - * @param {string} params.catToSelect - The Cat instance to use for selecting the next stimulus. - * @param {string | string[]} [params.catsToUpdate=[]] - A single Cat or array of Cats for which to update ability estimates. - * @param {Stimulus[]} [params.previousItems=[]] - An array of previously presented stimuli. - * @param {(0 | 1) | (0 | 1)[]} [params.previousAnswers=[]] - An array of answers (0 or 1) corresponding to `previousItems`. - * @param {string} [params.method] - Optional method for updating ability estimates (if applicable). + * @param {Object} input - The parameters for updating the Cat instance and selecting the next stimulus. + * @param {string} input.catToSelect - The Cat instance to use for selecting the next stimulus. + * @param {string | string[]} [input.catsToUpdate=[]] - A single Cat or array of Cats for which to update ability estimates. + * @param {Stimulus[]} [input.items=[]] - An array of previously presented stimuli. + * @param {(0 | 1) | (0 | 1)[]} [input.answers=[]] - An array of answers (0 or 1) corresponding to `items`. + * @param {string} [input.method] - Optional method for updating ability estimates (if applicable). * * @returns {Stimulus | undefined} - The next stimulus to present, or `undefined` if no further validated stimuli are available. 
* - * @throws {Error} If `previousItems` and `previousAnswers` lengths do not match. - * @throws {Error} If any `previousItems` are not found in the Clowder's corpora (validated or unvalidated). + * @throws {Error} If `items` and `answers` lengths do not match. + * @throws {Error} If any `items` are not found in the Clowder's corpora (validated or unvalidated). * * The function operates in several steps: * 1. Validates the `catToSelect` and `catsToUpdate`. - * 2. Ensures `previousItems` and `previousAnswers` arrays are properly formatted. + * 2. Ensures `items` and `answers` arrays are properly formatted. * 3. Updates the internal list of seen items. * 4. Updates the ability estimates for the `catsToUpdate`. * 5. Selects the next stimulus for `catToSelect`, considering validated and unvalidated stimuli. @@ -97,69 +96,60 @@ export class Clowder { public updateCatAndGetNextItem({ catToSelect, catsToUpdate = [], - previousItems = [], - previousAnswers = [], + items = [], + answers = [], method, }: { catToSelect: string; catsToUpdate?: string | string[]; - previousItems: Stimulus[]; - previousAnswers: (0 | 1) | (0 | 1)[]; + items: MultiZetaStimulus[]; + answers: (0 | 1) | (0 | 1)[]; method?: string; }): Stimulus | undefined { + // Validate all cat names this._validateCatName(catToSelect); - catsToUpdate = Array.isArray(catsToUpdate) ? catsToUpdate : [catsToUpdate]; catsToUpdate.forEach((cat) => { this._validateCatName(cat); }); - previousItems = Array.isArray(previousItems) ? previousItems : [previousItems]; - previousAnswers = Array.isArray(previousAnswers) ? previousAnswers : [previousAnswers]; + // Convert items and answers to arrays + items = Array.isArray(items) ? items : [items]; + answers = Array.isArray(answers) ? 
answers : [answers]; - if (previousItems.length !== previousAnswers.length) { + // Ensure that the lengths of items and answers match + if (items.length !== answers.length) { throw new Error('Previous items and answers must have the same length.'); } // Update the seenItems with the provided previous items - this.seenItems.push(...previousItems); - - const itemsAndAnswers = _zip(previousItems, previousAnswers) as [Stimulus, 0 | 1][]; - // eslint-disable-next-line @typescript-eslint/no-unused-vars - const validatedItemsAndAnswers = itemsAndAnswers.filter(([item, _answer]) => this.corpora.validated.includes(item)); - // eslint-disable-next-line @typescript-eslint/no-unused-vars - const unvalidatedItemsAndAnswers = itemsAndAnswers.filter(([item, _answer]) => - this.corpora.unvalidated.includes(item), - ); - // eslint-disable-next-line @typescript-eslint/no-unused-vars - const invalidItems = itemsAndAnswers.filter(([item, _answer]) => { - return !this.corpora.validated.includes(item) && !this.corpora.unvalidated.includes(item); - }); - - if (!invalidItems) { - throw new Error( - `The following previous items provided are not in this Clowder's corpora:\n${JSON.stringify( - invalidItems, - null, - 2, - )} ${invalidItems}`, - ); + this.seenItems.push(...items); + + // Remove the seenItems from the remainingItems + this.remainingItems = this.remainingItems.filter((stim) => !items.includes(stim)); + + const itemsAndAnswers = _zip(items, answers) as [Stimulus, 0 | 1][]; + + // Update the ability estimate for all cats + for (const catName of catsToUpdate) { + // eslint-disable-next-line @typescript-eslint/no-unused-vars + const itemsAndAnswersForCat = itemsAndAnswers.filter(([stim, _answer]) => { + const allCats = stim.zetas.reduce((acc: string[], { cats }: { cats: string }) => { + return [...acc, ...cats]; + }, []); + return allCats.includes(catName); + }); + + const zetasAndAnswersForCat = itemsAndAnswersForCat.map(([stim, _answer]) => { + const { zetas } = stim; + const 
zetaForCat = zetas.find((zeta: ZetaCatMap) => zeta.cats.includes(catName)); + return [zetaForCat.zeta, _answer]; + }); + + // Extract the cat to update ability estimate + const [zetas, answers] = _unzip(zetasAndAnswersForCat); + this.cats[catName].updateAbilityEstimate(zetas, answers, method); } - // eslint-disable-next-line @typescript-eslint/no-unused-vars - const validatedStimuli = validatedItemsAndAnswers.map(([stim, _]) => stim); - // eslint-disable-next-line @typescript-eslint/no-unused-vars - const unvalidatedStimuli = unvalidatedItemsAndAnswers.map(([stim, _]) => stim); - // eslint-disable-next-line @typescript-eslint/no-unused-vars - const validatedAnswers = validatedItemsAndAnswers.map(([_, answer]) => answer); - - // Remove previous items from the remainingItems - this.remainingItems.validated = this.remainingItems.validated.filter((item) => !validatedStimuli.includes(item)); - this.remainingItems.unvalidated = this.remainingItems.unvalidated.filter( - (item) => !unvalidatedStimuli.includes(item), - ); - - // Update the ability estimates for the requested Cats - this.updateAbilityEstimates(catsToUpdate, validatedStimuli, validatedAnswers, method); // Use the catForSelect to determine the next stimulus const cat = this.cats[catToSelect]; @@ -189,26 +179,4 @@ export class Clowder { } } } - - /** - * Add a new Cat instance to the Clowder. - * @param {string} catName - Name of the new Cat. - * @param {CatInput} catInput - Configuration for the new Cat instance. - * @param {Stimulus[]} stimuli - The corpus for the new Cat. - */ - public addCat(catName: string, catInput: CatInput) { - if (Object.prototype.hasOwnProperty.call(this.cats, catName)) { - throw new Error(`Cat with the name "${catName}" already exists.`); - } - this.cats[catName] = new Cat(catInput); - } - - /** - * Remove a Cat instance from the Clowder. - * @param {string} catName - The name of the Cat instance to remove. 
- */ - public removeCat(catName: string) { - this._validateCatName(catName); - delete this.cats[catName]; - } } diff --git a/src/index.ts b/src/index.ts index 5c79286..b0cf123 100644 --- a/src/index.ts +++ b/src/index.ts @@ -1,7 +1,15 @@ +/* eslint-disable @typescript-eslint/no-non-null-assertion */ import { minimize_Powell } from 'optimization-js'; import { cloneDeep } from 'lodash'; import { Stimulus, Zeta } from './type'; -import { itemResponseFunction, fisherInformation, normal, findClosest } from './utils'; +import { + itemResponseFunction, + fisherInformation, + normal, + findClosest, + validateZetaParams, + fillZetaDefaults, +} from './utils'; import seedrandom from 'seedrandom'; export const abilityPrior = normal(); @@ -26,7 +34,6 @@ export class Cat { public prior: number[][]; private readonly _zetas: Zeta[]; private readonly _resps: (0 | 1)[]; - private _nItems: number; private _theta: number; private _seMeasurement: number; public nStartItems: number; @@ -70,7 +77,6 @@ export class Cat { this._zetas = []; this._resps = []; this._theta = theta; - this._nItems = 0; this._seMeasurement = Number.MAX_VALUE; this.nStartItems = nStartItems; this._rng = randomSeed === null ? seedrandom() : seedrandom(randomSeed); @@ -84,6 +90,9 @@ export class Cat { return this._seMeasurement; } + /** + * Return the number of items that have been observed so far. + */ public get nItems() { return this._resps.length; } @@ -135,6 +144,8 @@ export class Cat { zeta = Array.isArray(zeta) ? zeta : [zeta]; answer = Array.isArray(answer) ? 
answer : [answer]; + zeta.forEach((z) => validateZetaParams(z, true)); + if (zeta.length !== answer.length) { throw new Error('Unmatched length between answers and item params'); } @@ -209,6 +220,9 @@ export class Cat { } else { arr = stimuli; } + + arr = arr.map((stim) => fillZetaDefaults(stim, 'semantic')); + if (this.nItems < this.nStartItems) { selector = this.startSelect; } @@ -216,7 +230,7 @@ export class Cat { // for mfi, we sort the arr by fisher information in the private function to select the best item, // and then sort by difficulty to return the remainingStimuli // for fixed, we want to keep the corpus order as input - arr.sort((a: Stimulus, b: Stimulus) => a.difficulty - b.difficulty); + arr.sort((a: Stimulus, b: Stimulus) => a.difficulty! - b.difficulty!); } if (selector === 'middle') { @@ -233,14 +247,10 @@ export class Cat { } } - private selectorMFI(arr: Stimulus[]) { - const stimuliAddFisher = arr.map((element: Stimulus) => ({ - fisherInformation: fisherInformation(this._theta, { - a: element.a || 1, - b: element.difficulty || 0, - c: element.c || 0, - d: element.d || 1, - }), + private selectorMFI(inputStimuli: Stimulus[]) { + const stimuli = inputStimuli.map((stim) => fillZetaDefaults(stim, 'semantic')); + const stimuliAddFisher = stimuli.map((element: Stimulus) => ({ + fisherInformation: fisherInformation(this._theta, fillZetaDefaults(element, 'symbolic')), ...element, })); @@ -250,7 +260,7 @@ export class Cat { }); return { nextStimulus: stimuliAddFisher[0], - remainingStimuli: stimuliAddFisher.slice(1).sort((a: Stimulus, b: Stimulus) => a.difficulty - b.difficulty), + remainingStimuli: stimuliAddFisher.slice(1).sort((a: Stimulus, b: Stimulus) => a.difficulty! 
- b.difficulty!), }; } diff --git a/src/type.ts b/src/type.ts index 46e4813..522ccfd 100644 --- a/src/type.ts +++ b/src/type.ts @@ -1,27 +1,36 @@ -export const zetaKeyMap = { - a: 'discrimination', - b: 'difficulty', - c: 'guessing', - d: 'slipping', -}; - -export type ZetaImplicit = { +export type ZetaSymbolic = { + // Symbolic parameter names a: number; // Discrimination (slope of the curve) b: number; // Difficulty (location of the curve) c: number; // Guessing (lower asymptote) d: number; // Slipping (upper asymptote) }; -export type ZetaExplicit = { - discrimination: number; - difficulty: number; - guessing: number; - slipping: number; -}; +export interface Zeta { + // Symbolic parameter names + a?: number; // Discrimination (slope of the curve) + b?: number; // Difficulty (location of the curve) + c?: number; // Guessing (lower asymptote) + d?: number; // Slipping (upper asymptote) + // Semantic parameter names + discrimination?: number; + difficulty?: number; + guessing?: number; + slipping?: number; +} -export type Zeta = ZetaImplicit | ZetaExplicit; +export interface Stimulus extends Zeta { + // eslint-disable-next-line @typescript-eslint/no-explicit-any + [key: string]: any; +} + +export type ZetaCatMap = { + cats: string[]; + zeta: Zeta; +}; -export interface Stimulus extends ZetaExplicit { +export interface MultiZetaStimulus { + zetas: ZetaCatMap[]; // eslint-disable-next-line @typescript-eslint/no-explicit-any [key: string]: any; } diff --git a/src/utils.ts b/src/utils.ts index 6982fab..f5eba90 100644 --- a/src/utils.ts +++ b/src/utils.ts @@ -1,5 +1,93 @@ +/* eslint-disable @typescript-eslint/no-non-null-assertion */ import bs from 'binary-search'; -import { Stimulus, Zeta, ZetaExplicit, ZetaImplicit } from './type'; +import { MultiZetaStimulus, Stimulus, Zeta, ZetaSymbolic } from './type'; +import _intersection from 'lodash/intersection'; +import _invert from 'lodash/invert'; +import _mapKeys from 'lodash/mapKeys'; + +export const zetaKeyMap = { + a: 
'discrimination', + b: 'difficulty', + c: 'guessing', + d: 'slipping', +}; + +export const defaultZeta = (desiredFormat: 'symbolic' | 'semantic' = 'symbolic'): Zeta => { + const defaultZeta: Zeta = { + a: 1, + b: 0, + c: 0, + d: 1, + }; + + return convertZeta(defaultZeta, desiredFormat); +}; + +export const validateZetaParams = (zeta: Zeta, requireAll = false): void => { + if (zeta.a !== undefined && zeta.discrimination !== undefined) { + throw new Error('This item has both an `a` key and `discrimination` key. Please provide only one.'); + } + + if (zeta.b !== undefined && zeta.difficulty !== undefined) { + throw new Error('This item has both a `b` key and `difficulty` key. Please provide only one.'); + } + + if (zeta.c !== undefined && zeta.guessing !== undefined) { + throw new Error('This item has both a `c` key and `guessing` key. Please provide only one.'); + } + + if (zeta.d !== undefined && zeta.slipping !== undefined) { + throw new Error('This item has both a `d` key and `slipping` key. 
Please provide only one.'); + } + + if (requireAll) { + if (zeta.a === undefined && zeta.discrimination === undefined) { + throw new Error('This item is missing an `a` or `discrimination` key.'); + } + + if (zeta.b === undefined && zeta.difficulty === undefined) { + throw new Error('This item is missing a `b` or `difficulty` key.'); + } + + if (zeta.c === undefined && zeta.guessing === undefined) { + throw new Error('This item is missing a `c` or `guessing` key.'); + } + + if (zeta.d === undefined && zeta.slipping === undefined) { + throw new Error('This item is missing a `d` or `slipping` key.'); + } + } +}; + +export const fillZetaDefaults = (zeta: Zeta, desiredFormat: 'symbolic' | 'semantic' = 'symbolic'): Zeta => { + return { + ...defaultZeta(desiredFormat), + ...convertZeta(zeta, desiredFormat), + }; +}; + +export const convertZeta = (zeta: Zeta, desiredFormat: 'symbolic' | 'semantic'): Zeta => { + if (!['symbolic', 'semantic'].includes(desiredFormat)) { + throw new Error(`Invalid desired format. Expected 'symbolic' or'semantic'. Received ${desiredFormat} instead.`); + } + + return _mapKeys(zeta, (value, key) => { + if (desiredFormat === 'symbolic') { + const inverseMap = _invert(zetaKeyMap); + if (key in inverseMap) { + return inverseMap[key]; + } else { + return key; + } + } else { + if (key in zetaKeyMap) { + return zetaKeyMap[key as keyof typeof zetaKeyMap]; + } else { + return key; + } + } + }); +}; /** * calculates the probability that someone with a given ability level theta will answer correctly an item. 
Uses the 4 parameters logistic model @@ -8,16 +96,8 @@ import { Stimulus, Zeta, ZetaExplicit, ZetaImplicit } from './type'; * @returns {number} the probability */ export const itemResponseFunction = (theta: number, zeta: Zeta) => { - if ((zeta as ZetaImplicit).a) { - const _zeta = zeta as ZetaImplicit; - return _zeta.c + (_zeta.d - _zeta.c) / (1 + Math.exp(-_zeta.a * (theta - _zeta.b))); - } else { - const _zeta = zeta as ZetaExplicit; - return ( - _zeta.guessing + - (_zeta.slipping - _zeta.guessing) / (1 + Math.exp(-_zeta.discrimination * (theta - _zeta.difficulty))) - ); - } + const _zeta = fillZetaDefaults(zeta, 'symbolic') as ZetaSymbolic; + return _zeta.c + (_zeta.d - _zeta.c) / (1 + Math.exp(-_zeta.a * (theta - _zeta.b))); }; /** @@ -27,17 +107,10 @@ export const itemResponseFunction = (theta: number, zeta: Zeta) => { * @returns {number} - the expected value of the observed information */ export const fisherInformation = (theta: number, zeta: Zeta) => { - const p = itemResponseFunction(theta, zeta); + const _zeta = fillZetaDefaults(zeta, 'symbolic') as ZetaSymbolic; + const p = itemResponseFunction(theta, _zeta); const q = 1 - p; - if ((zeta as ZetaImplicit).a) { - const _zeta = zeta as ZetaImplicit; - return Math.pow(_zeta.a, 2) * (q / p) * (Math.pow(p - _zeta.c, 2) / Math.pow(1 - _zeta.c, 2)); - } else { - const _zeta = zeta as ZetaExplicit; - return ( - Math.pow(_zeta.discrimination, 2) * (q / p) * (Math.pow(p - _zeta.guessing, 2) / Math.pow(1 - _zeta.guessing, 2)) - ); - } + return Math.pow(_zeta.a, 2) * (q / p) * (Math.pow(p - _zeta.c, 2) / Math.pow(1 - _zeta.c, 2)); }; /** @@ -67,22 +140,23 @@ export const normal = (mean = 0, stdDev = 1, min = -4, max = 4, stepSize = 0.1) * @remarks * The input array of stimuli must be sorted by difficulty. 
* - * @param arr Array - an array of stimuli sorted by difficulty + * @param stimuli Array - an array of stimuli sorted by difficulty * @param target number - ability estimate - * @returns {number} the index of arr + * @returns {number} the index of stimuli */ -export const findClosest = (arr: Array, target: number) => { +export const findClosest = (inputStimuli: Array, target: number) => { + const stimuli = inputStimuli.map((stim) => fillZetaDefaults(stim, 'semantic')); // Let's consider the edge cases first - if (target <= arr[0].difficulty) { + if (target <= stimuli[0].difficulty!) { return 0; - } else if (target >= arr[arr.length - 1].difficulty) { - return arr.length - 1; + } else if (target >= stimuli[stimuli.length - 1].difficulty!) { + return stimuli.length - 1; } const comparitor = (element: Stimulus, needle: number) => { - return element.difficulty - needle; + return element.difficulty! - needle; }; - const indexOfTarget = bs(arr, target, comparitor); + const indexOfTarget = bs(stimuli, target, comparitor); if (indexOfTarget >= 0) { // `bs` returns a positive integer index if it found an exact match. @@ -96,8 +170,8 @@ export const findClosest = (arr: Array, target: number) => { // So we simply compare the differences between the target and the high and // low values, respectively - const lowDiff = Math.abs(arr[lowIndex].difficulty - target); - const highDiff = Math.abs(arr[highIndex].difficulty - target); + const lowDiff = Math.abs(stimuli[lowIndex].difficulty! - target); + const highDiff = Math.abs(stimuli[highIndex].difficulty! 
- target); if (lowDiff < highDiff) { return lowIndex; @@ -106,3 +180,13 @@ export const findClosest = (arr: Array, target: number) => { } } }; + +export const validateCorpora = (corpus: MultiZetaStimulus[]): void => { + const zetaCatMapsArray = corpus.map((item) => item.zetas); + for (const zetaCatMaps of zetaCatMapsArray) { + const intersection = _intersection(zetaCatMaps); + if (intersection.length > 0) { + throw new Error(`The cat names ${intersection.join(', ')} are present in multiple corpora.`); + } + } +}; From 423954ffb8d8714facc5fdb852c0bbd4a605fe27 Mon Sep 17 00:00:00 2001 From: Adam Richie-Halford Date: Tue, 17 Sep 2024 17:42:37 -0700 Subject: [PATCH 07/47] Add TODO comments --- src/clowder.ts | 26 +++++++++++++++++++------- 1 file changed, 19 insertions(+), 7 deletions(-) diff --git a/src/clowder.ts b/src/clowder.ts index 60367c2..8012246 100644 --- a/src/clowder.ts +++ b/src/clowder.ts @@ -80,6 +80,7 @@ export class Clowder { * @param {Stimulus[]} [input.items=[]] - An array of previously presented stimuli. * @param {(0 | 1) | (0 | 1)[]} [input.answers=[]] - An array of answers (0 or 1) corresponding to `items`. * @param {string} [input.method] - Optional method for updating ability estimates (if applicable). + * @param {string} [input.itemSelect] - Optional item selection method (if applicable). * * @returns {Stimulus | undefined} - The next stimulus to present, or `undefined` if no further validated stimuli are available. 
* @@ -99,12 +100,14 @@ export class Clowder { items = [], answers = [], method, + itemSelect, }: { catToSelect: string; catsToUpdate?: string | string[]; items: MultiZetaStimulus[]; answers: (0 | 1) | (0 | 1)[]; method?: string; + itemSelect?: string; }): Stimulus | undefined { // Validate all cat names this._validateCatName(catToSelect); @@ -151,29 +154,38 @@ export class Clowder { this.cats[catName].updateAbilityEstimate(zetas, answers, method); } + // TODO: Before we explicityly differentiated between validated and unvalidated stimuli. + // Now, we need to dynamically calculate the unvalidated stimuli by looking at the remaining items + // that do not have a zeta associated with the catToSelect. + + // TODO: These functions do not exist. + const validatedRemainingItems = filterRemainingItemsForThisCat('validated'); + const unvalidatedRemainingItems = filterRemainingItemsForThisCat('unvalidated'); + const validatedCatInput = validatedRemainingItems.map((stim) => putStimuliInExpectedFormat); + // Use the catForSelect to determine the next stimulus const cat = this.cats[catToSelect]; - const { nextStimulus } = cat.findNextItem(this.remainingItems.validated); + const { nextStimulus } = cat.findNextItem(validatedCatInput, itemSelect); // Added some logic to mix in the unvalidated stimuli if needed. - if (this.remainingItems.unvalidated.length === 0) { + if (unvalidatedRemainingItems.length === 0) { // If there are no more unvalidated stimuli, we only have validated items left. // Use the Cat to find the next item. The Cat may return undefined if all validated items have been seen. return nextStimulus; - } else if (this.remainingItems.validated.length === 0) { + } else if (validatedRemainingItems.length === 0) { // In this case, there are no more validated items left. Choose an unvalidated item at random. 
- return this.remainingItems.unvalidated[Math.floor(Math.random() * this.remainingItems.unvalidated.length)]; + return unvalidatedRemainingItems[Math.floor(Math.random() * unvalidatedRemainingItems.length)]; } else { // In this case, there are both validated and unvalidated items left. // We need to randomly insert unvalidated items const numRemaining = { - validated: this.remainingItems.validated.length, - unvalidated: this.remainingItems.unvalidated.length, + validated: validatedRemainingItems.length, + unvalidated: unvalidatedRemainingItems.length, }; const random = Math.random(); if (random < numRemaining.unvalidated / (numRemaining.validated + numRemaining.unvalidated)) { - return this.remainingItems.unvalidated[Math.floor(Math.random() * this.remainingItems.unvalidated.length)]; + return unvalidatedRemainingItems[Math.floor(Math.random() * unvalidatedRemainingItems.length)]; } else { return nextStimulus; } From 22b85febcd8c1113d90bc2b39abc20d446204c62 Mon Sep 17 00:00:00 2001 From: Adam Richie-Halford Date: Wed, 18 Sep 2024 09:26:51 -0700 Subject: [PATCH 08/47] Add util tests --- src/__tests__/clowder.test.ts | 115 ++++++++----------- src/__tests__/utils.test.ts | 206 +++++++++++++++++++++++++++++++++- src/clowder.ts | 73 +++++++----- src/utils.ts | 35 ++++-- 4 files changed, 326 insertions(+), 103 deletions(-) diff --git a/src/__tests__/clowder.test.ts b/src/__tests__/clowder.test.ts index ea35133..eca29f1 100644 --- a/src/__tests__/clowder.test.ts +++ b/src/__tests__/clowder.test.ts @@ -1,14 +1,16 @@ import { Clowder, ClowderInput } from '../clowder'; -import { Stimulus } from '../type'; +import { MultiZetaStimulus, Zeta, ZetaCatMap } from '../type'; +import { defaultZeta } from '../utils'; -// Mocking Stimulus -const createStimulus = (id: string): Stimulus => ({ +const createMultiZetaStimulus = (id: string, zetas: ZetaCatMap[]): MultiZetaStimulus => ({ id, - difficulty: 1, - discrimination: 1, - guessing: 0, - slipping: 0, - content: `Stimulus content 
${id}`, + content: `Multi-Zeta Stimulus content ${id}`, + zetas, +}); + +const createZetaCatMap = (catNames: string[], zeta: Zeta = defaultZeta()): ZetaCatMap => ({ + cats: catNames, + zeta, }); describe('Clowder Class', () => { @@ -20,89 +22,68 @@ describe('Clowder Class', () => { cat1: { method: 'MLE', theta: 0.5 }, cat2: { method: 'EAP', theta: -1.0 }, }, - corpora: { - validated: [createStimulus('1'), createStimulus('2')], - unvalidated: [createStimulus('1')], - }, + corpus: [ + createMultiZetaStimulus('0', [createZetaCatMap(['cat1']), createZetaCatMap(['cat2'])]), + createMultiZetaStimulus('1', [createZetaCatMap(['cat1']), createZetaCatMap(['cat2'])]), + createMultiZetaStimulus('2', [createZetaCatMap(['cat1'])]), + createMultiZetaStimulus('3', [createZetaCatMap(['cat2'])]), + createMultiZetaStimulus('4', []), + ], }; clowder = new Clowder(clowderInput); }); test('should initialize with provided cats and corpora', () => { expect(Object.keys(clowder['cats'])).toContain('cat1'); - expect(clowder.remainingItems.validated).toHaveLength(2); - expect(clowder.remainingItems.unvalidated).toHaveLength(1); + expect(clowder.remainingItems).toHaveLength(2); + expect(clowder.corpus).toHaveLength(1); }); test('should validate cat names', () => { expect(() => { clowder.updateCatAndGetNextItem({ catToSelect: 'invalidCat', - previousItems: [], - previousAnswers: [], }); }).toThrow('Invalid Cat name'); }); - test('should update ability estimates', () => { - clowder.updateAbilityEstimates(['cat1'], createStimulus('1'), [0]); - const cat1 = clowder['cats']['cat1']; - expect(cat1.theta).toBeGreaterThanOrEqual(0); // Since we mock, assume the result is logical. - }); + // test('should update ability estimates', () => { + // clowder.updateAbilityEstimates(['cat1'], createStimulus('1'), [0]); + // const cat1 = clowder['cats']['cat1']; + // expect(cat1.theta).toBeGreaterThanOrEqual(0); // Since we mock, assume the result is logical. 
+ // }); - test('should select next stimulus from validated stimuli', () => { - const nextItem = clowder.updateCatAndGetNextItem({ - catToSelect: 'cat1', - catsToUpdate: ['cat1'], - previousItems: [createStimulus('1')], - previousAnswers: [1], - }); - expect(nextItem).toEqual(createStimulus('1')); // Second validated stimulus - }); + // test('should select next stimulus from validated stimuli', () => { + // const nextItem = clowder.updateCatAndGetNextItem({ + // catToSelect: 'cat1', + // catsToUpdate: ['cat1'], + // previousItems: [createStimulus('1')], + // previousAnswers: [1], + // }); + // expect(nextItem).toEqual(createStimulus('1')); // Second validated stimulus + // }); - test('should return unvalidated stimulus when no validated stimuli remain', () => { - clowder.updateCatAndGetNextItem({ - catToSelect: 'cat1', - previousItems: [createStimulus('1'), createStimulus('2')], - previousAnswers: [1, 0], - }); - - const nextItem = clowder.updateCatAndGetNextItem({ - catToSelect: 'cat1', - previousItems: [], - previousAnswers: [], - }); - expect(nextItem).toEqual(createStimulus('1')); // Unvalidated item - }); - - test('should add a new Cat instance', () => { - clowder.addCat('cat3', { method: 'MLE', theta: 0 }); - expect(Object.keys(clowder['cats'])).toContain('cat3'); - }); + // test('should return unvalidated stimulus when no validated stimuli remain', () => { + // clowder.updateCatAndGetNextItem({ + // catToSelect: 'cat1', + // previousItems: [createStimulus('1'), createStimulus('2')], + // previousAnswers: [1, 0], + // }); - test('should throw error if adding duplicate Cat instance', () => { - expect(() => { - clowder.addCat('cat1', { method: 'MLE', theta: 0 }); - }).toThrow('Cat with the name "cat1" already exists.'); - }); - - test('should remove a Cat instance', () => { - clowder.removeCat('cat1'); - expect(Object.keys(clowder['cats'])).not.toContain('cat1'); - }); - - test('should throw error when trying to remove non-existent Cat instance', () => { - 
expect(() => { - clowder.removeCat('nonExistentCat'); - }).toThrow('Invalid Cat name'); - }); + // const nextItem = clowder.updateCatAndGetNextItem({ + // catToSelect: 'cat1', + // previousItems: [], + // previousAnswers: [], + // }); + // expect(nextItem).toEqual(createStimulus('1')); // Unvalidated item + // }); test('should throw error if previousItems and previousAnswers have mismatched lengths', () => { expect(() => { clowder.updateCatAndGetNextItem({ catToSelect: 'cat1', - previousItems: [createStimulus('1')], - previousAnswers: [1, 0], // Mismatched length + items: createMultiZetaStimulus('1', [createZetaCatMap(['cat1']), createZetaCatMap(['cat2'])]), + answers: [1, 0], // Mismatched length }); }).toThrow('Previous items and answers must have the same length.'); }); diff --git a/src/__tests__/utils.test.ts b/src/__tests__/utils.test.ts index 8895e8e..31f915a 100644 --- a/src/__tests__/utils.test.ts +++ b/src/__tests__/utils.test.ts @@ -1,4 +1,15 @@ -import { itemResponseFunction, fisherInformation, findClosest } from '../utils'; +import { Stimulus, Zeta } from '../type'; +import { + itemResponseFunction, + fisherInformation, + findClosest, + validateZetaParams, + ZETA_KEY_MAP, + defaultZeta, + fillZetaDefaults, + convertZeta, +} from '../utils'; +import _omit from 'lodash/omit'; describe('itemResponseFunction', () => { it('correctly calculates the probability', () => { @@ -55,3 +66,196 @@ describe('findClosest', () => { expect(findClosest(stimuliWithDecimal, 9.1)).toBe(2); }); }); + +describe('validateZetaParams', () => { + it('throws an error when providing both a and discrimination', () => { + expect(() => validateZetaParams({ a: 1, discrimination: 1 })).toThrow( + 'This item has both an `a` key and `discrimination` key. Please provide only one.', + ); + }); + + it('throws an error when providing both b and difficulty', () => { + expect(() => validateZetaParams({ b: 1, difficulty: 1 })).toThrow( + 'This item has both a `b` key and `difficulty` key. 
Please provide only one.', + ); + }); + + it('throws an error when providing both c and guessing', () => { + expect(() => validateZetaParams({ c: 1, guessing: 1 })).toThrow( + 'This item has both a `c` key and `guessing` key. Please provide only one.', + ); + }); + + it('throws an error when providing both d and slipping', () => { + expect(() => validateZetaParams({ d: 1, slipping: 1 })).toThrow( + 'This item has both a `d` key and `slipping` key. Please provide only one.', + ); + }); + + it('throws an error when requiring all keys and missing one', () => { + for (const key of ['a', 'b', 'c', 'd'] as (keyof typeof ZETA_KEY_MAP)[]) { + const semanticKey = ZETA_KEY_MAP[key]; + const zeta = _omit(defaultZeta('symbolic'), [key]); + + expect(() => validateZetaParams(zeta, true)).toThrow( + `This item is missing the key \`${String(key)}\` or \`${semanticKey}\`.`, + ); + } + }); +}); + +describe('fillZetaDefaults', () => { + it('fills in default values for missing keys', () => { + const zeta: Zeta = { + difficulty: 1, + guessing: 0.5, + }; + + const filledZeta = fillZetaDefaults(zeta, 'semantic'); + + expect(filledZeta).toEqual({ + discrimination: 1, + difficulty: 1, + guessing: 0.5, + slipping: 1, + }); + }); + + it('does not modify the input object when no missing keys', () => { + const zeta: Zeta = { + a: 5, + b: 5, + c: 5, + d: 5, + }; + + const filledZeta = fillZetaDefaults(zeta, 'symbolic'); + + expect(filledZeta).toEqual(zeta); + }); + + it('converts to semantic format when desired', () => { + const zeta: Zeta = { + a: 5, + b: 5, + }; + + const filledZeta = fillZetaDefaults(zeta, 'semantic'); + + expect(filledZeta).toEqual({ + difficulty: 5, + discrimination: 5, + guessing: 0, + slipping: 1, + }); + }); + + it('converts to symbolic format when desired', () => { + const zeta: Zeta = { + difficulty: 5, + discrimination: 5, + }; + + const filledZeta = fillZetaDefaults(zeta, 'symbolic'); + + expect(filledZeta).toEqual({ + a: 5, + b: 5, + c: 0, + d: 1, + }); + }); +}); 
+ +describe('convertZeta', () => { + it('converts from symbolic format to semantic format', () => { + const zeta: Zeta = { + a: 1, + b: 2, + c: 3, + d: 4, + }; + + const convertedZeta = convertZeta(zeta, 'semantic'); + + expect(convertedZeta).toEqual({ + discrimination: 1, + difficulty: 2, + guessing: 3, + slipping: 4, + }); + }); + + it('converts from semantic format to symbolic format', () => { + const zeta: Zeta = { + discrimination: 1, + difficulty: 2, + guessing: 3, + slipping: 4, + }; + + const convertedZeta = convertZeta(zeta, 'symbolic'); + + expect(convertedZeta).toEqual({ + a: 1, + b: 2, + c: 3, + d: 4, + }); + }); + + it('throws an error when converting from an unsupported format', () => { + const zeta: Zeta = { + a: 1, + b: 2, + c: 3, + d: 4, + }; + + expect(() => convertZeta(zeta, 'unsupported' as 'symbolic')).toThrow( + "Invalid desired format. Expected 'symbolic' or'semantic'. Received unsupported instead.", + ); + }); + + it('does not modify other keys when converting', () => { + const zeta: Stimulus = { + a: 1, + b: 2, + c: 3, + d: 4, + key1: 5, + key2: 6, + key3: 7, + key4: 8, + }; + + const convertedZeta = convertZeta(zeta, 'semantic'); + + expect(convertedZeta).toEqual({ + discrimination: 1, + difficulty: 2, + guessing: 3, + slipping: 4, + key1: 5, + key2: 6, + key3: 7, + key4: 8, + }); + }); + + it('converts only existing keys', () => { + const zeta: Zeta = { + a: 1, + b: 2, + }; + + const convertedZeta = convertZeta(zeta, 'semantic'); + + expect(convertedZeta).toEqual({ + discrimination: 1, + difficulty: 2, + }); + }); +}); + +// TODO: Write tests for validateCorpus and filterItemsByCatParameterAvailability diff --git a/src/clowder.ts b/src/clowder.ts index 8012246..f08815b 100644 --- a/src/clowder.ts +++ b/src/clowder.ts @@ -1,10 +1,11 @@ import { Cat, CatInput } from './index'; import { MultiZetaStimulus, Stimulus, Zeta, ZetaCatMap } from './type'; import _cloneDeep from 'lodash/cloneDeep'; +import _isEqual from 'lodash/isEqual'; import 
_mapValues from 'lodash/mapValues'; import _unzip from 'lodash/unzip'; import _zip from 'lodash/zip'; -import { validateCorpora } from './utils'; +import { filterItemsByCatParameterAvailability, validateCorpus } from './utils'; export interface ClowderInput { // An object containing Cat configurations for each Cat instance. @@ -12,12 +13,12 @@ export interface ClowderInput { [name: string]: CatInput; }; // An object containing arrays of stimuli for each corpus. - corpora: MultiZetaStimulus[]; + corpus: MultiZetaStimulus[]; } export class Clowder { private cats: { [name: string]: Cat }; - private corpora: MultiZetaStimulus[]; + private _corpus: MultiZetaStimulus[]; public remainingItems: MultiZetaStimulus[]; public seenItems: Stimulus[]; @@ -25,13 +26,13 @@ export class Clowder { * Create a Clowder object. * @param {ClowderInput} input - An object containing arrays of Cat configurations and corpora. */ - constructor({ cats, corpora }: ClowderInput) { + constructor({ cats, corpus }: ClowderInput) { // TODO: Need to pass in numItemsRequired so that we know when to stop providing new items. 
this.cats = _mapValues(cats, (catInput) => new Cat(catInput)); this.seenItems = []; - validateCorpora(corpora); - this.corpora = corpora; - this.remainingItems = _cloneDeep(corpora); + validateCorpus(corpus); + this._corpus = corpus; + this.remainingItems = _cloneDeep(corpus); } private _validateCatName(catName: string): void { @@ -40,6 +41,10 @@ export class Clowder { } } + public get corpus() { + return this._corpus; + } + public get theta() { return _mapValues(this.cats, (cat) => cat.theta); } @@ -104,8 +109,8 @@ export class Clowder { }: { catToSelect: string; catsToUpdate?: string | string[]; - items: MultiZetaStimulus[]; - answers: (0 | 1) | (0 | 1)[]; + items?: MultiZetaStimulus | MultiZetaStimulus[]; + answers?: (0 | 1) | (0 | 1)[]; method?: string; itemSelect?: string; }): Stimulus | undefined { @@ -154,40 +159,56 @@ export class Clowder { this.cats[catName].updateAbilityEstimate(zetas, answers, method); } - // TODO: Before we explicityly differentiated between validated and unvalidated stimuli. - // Now, we need to dynamically calculate the unvalidated stimuli by looking at the remaining items - // that do not have a zeta associated with the catToSelect. + // Now, we need to dynamically calculate the stimuli available for selection by `catToSelect`. + // We inspect the remaining items and find ones that have zeta parameters for `catToSelect` - // TODO: These functions do not exist. - const validatedRemainingItems = filterRemainingItemsForThisCat('validated'); - const unvalidatedRemainingItems = filterRemainingItemsForThisCat('unvalidated'); - const validatedCatInput = validatedRemainingItems.map((stim) => putStimuliInExpectedFormat); + const { available, missing } = filterItemsByCatParameterAvailability(this.remainingItems, catToSelect); + + // The cat expects an array of Stimulus objects, with the zeta parameters + // spread at the top-level of each Stimulus object. So we need to convert + // the MultiZetaStimulus array to an array of Stimulus objects. 
+ const availableCatInput = available.map((item) => { + const { zetas, ...rest } = item; + const zetasForCat = zetas.find((zeta: ZetaCatMap) => zeta.cats.includes(catToSelect)); + return { + ...(zetasForCat?.zeta ?? {}), + ...rest, + }; + }); // Use the catForSelect to determine the next stimulus const cat = this.cats[catToSelect]; - const { nextStimulus } = cat.findNextItem(validatedCatInput, itemSelect); + const { nextStimulus } = cat.findNextItem(availableCatInput, itemSelect); + + // Again `nextStimulus` will be a Stimulus object, or `undefined` if no further validated stimuli are available. + // We need to convert the Stimulus object back to a MultiZetaStimulus object to return to the user. + const returnStimulus: MultiZetaStimulus | undefined = available.find((stim) => { + // eslint-disable-next-line @typescript-eslint/no-unused-vars + const { zetas, ...rest } = stim; + return _isEqual(rest, nextStimulus); + }); // Added some logic to mix in the unvalidated stimuli if needed. - if (unvalidatedRemainingItems.length === 0) { + if (missing.length === 0) { // If there are no more unvalidated stimuli, we only have validated items left. // Use the Cat to find the next item. The Cat may return undefined if all validated items have been seen. - return nextStimulus; - } else if (validatedRemainingItems.length === 0) { + return returnStimulus; + } else if (available.length === 0) { // In this case, there are no more validated items left. Choose an unvalidated item at random. - return unvalidatedRemainingItems[Math.floor(Math.random() * unvalidatedRemainingItems.length)]; + return missing[Math.floor(Math.random() * missing.length)]; } else { // In this case, there are both validated and unvalidated items left. 
// We need to randomly insert unvalidated items const numRemaining = { - validated: validatedRemainingItems.length, - unvalidated: unvalidatedRemainingItems.length, + available: available.length, + missing: missing.length, }; const random = Math.random(); - if (random < numRemaining.unvalidated / (numRemaining.validated + numRemaining.unvalidated)) { - return unvalidatedRemainingItems[Math.floor(Math.random() * unvalidatedRemainingItems.length)]; + if (random < numRemaining.missing / (numRemaining.available + numRemaining.missing)) { + return missing[Math.floor(Math.random() * missing.length)]; } else { - return nextStimulus; + return returnStimulus; } } } diff --git a/src/utils.ts b/src/utils.ts index f5eba90..a29e957 100644 --- a/src/utils.ts +++ b/src/utils.ts @@ -5,13 +5,15 @@ import _intersection from 'lodash/intersection'; import _invert from 'lodash/invert'; import _mapKeys from 'lodash/mapKeys'; -export const zetaKeyMap = { +// TODO: Document this +export const ZETA_KEY_MAP = { a: 'discrimination', b: 'difficulty', c: 'guessing', d: 'slipping', }; +// TODO: Document this export const defaultZeta = (desiredFormat: 'symbolic' | 'semantic' = 'symbolic'): Zeta => { const defaultZeta: Zeta = { a: 1, @@ -23,6 +25,7 @@ export const defaultZeta = (desiredFormat: 'symbolic' | 'semantic' = 'symbolic') return convertZeta(defaultZeta, desiredFormat); }; +// TODO: Document this export const validateZetaParams = (zeta: Zeta, requireAll = false): void => { if (zeta.a !== undefined && zeta.discrimination !== undefined) { throw new Error('This item has both an `a` key and `discrimination` key. 
Please provide only one.'); @@ -42,23 +45,24 @@ export const validateZetaParams = (zeta: Zeta, requireAll = false): void => { if (requireAll) { if (zeta.a === undefined && zeta.discrimination === undefined) { - throw new Error('This item is missing an `a` or `discrimination` key.'); + throw new Error('This item is missing the key `a` or `discrimination`.'); } if (zeta.b === undefined && zeta.difficulty === undefined) { - throw new Error('This item is missing a `b` or `difficulty` key.'); + throw new Error('This item is missing the key `b` or `difficulty`.'); } if (zeta.c === undefined && zeta.guessing === undefined) { - throw new Error('This item is missing a `c` or `guessing` key.'); + throw new Error('This item is missing the key `c` or `guessing`.'); } if (zeta.d === undefined && zeta.slipping === undefined) { - throw new Error('This item is missing a `d` or `slipping` key.'); + throw new Error('This item is missing the key `d` or `slipping`.'); } } }; +// TODO: Document this export const fillZetaDefaults = (zeta: Zeta, desiredFormat: 'symbolic' | 'semantic' = 'symbolic'): Zeta => { return { ...defaultZeta(desiredFormat), @@ -66,6 +70,7 @@ export const fillZetaDefaults = (zeta: Zeta, desiredFormat: 'symbolic' | 'semant }; }; +// TODO: Document this export const convertZeta = (zeta: Zeta, desiredFormat: 'symbolic' | 'semantic'): Zeta => { if (!['symbolic', 'semantic'].includes(desiredFormat)) { throw new Error(`Invalid desired format. Expected 'symbolic' or'semantic'. 
Received ${desiredFormat} instead.`); @@ -73,15 +78,15 @@ export const convertZeta = (zeta: Zeta, desiredFormat: 'symbolic' | 'semantic'): return _mapKeys(zeta, (value, key) => { if (desiredFormat === 'symbolic') { - const inverseMap = _invert(zetaKeyMap); + const inverseMap = _invert(ZETA_KEY_MAP); if (key in inverseMap) { return inverseMap[key]; } else { return key; } } else { - if (key in zetaKeyMap) { - return zetaKeyMap[key as keyof typeof zetaKeyMap]; + if (key in ZETA_KEY_MAP) { + return ZETA_KEY_MAP[key as keyof typeof ZETA_KEY_MAP]; } else { return key; } @@ -181,7 +186,8 @@ export const findClosest = (inputStimuli: Array, target: number) => { } }; -export const validateCorpora = (corpus: MultiZetaStimulus[]): void => { +// TODO: Document this +export const validateCorpus = (corpus: MultiZetaStimulus[]): void => { const zetaCatMapsArray = corpus.map((item) => item.zetas); for (const zetaCatMaps of zetaCatMapsArray) { const intersection = _intersection(zetaCatMaps); @@ -190,3 +196,14 @@ export const validateCorpora = (corpus: MultiZetaStimulus[]): void => { } } }; + +// TODO: Document this +export const filterItemsByCatParameterAvailability = (items: MultiZetaStimulus[], catName: string) => { + const paramsExist = items.filter((item) => item.zetas.some((zetaCatMap) => zetaCatMap.cats.includes(catName))); + const paramsMissing = items.filter((item) => !item.zetas.some((zetaCatMap) => zetaCatMap.cats.includes(catName))); + + return { + available: paramsExist, + missing: paramsMissing, + }; +}; From 242f02f2911f77c6efe0a0fbfde37323a8cd8799 Mon Sep 17 00:00:00 2001 From: Adam Richie-Halford Date: Wed, 18 Sep 2024 12:56:46 -0700 Subject: [PATCH 09/47] Add documentation --- src/utils.ts | 36 +++++++++++++++++++++++++++++++++--- 1 file changed, 33 insertions(+), 3 deletions(-) diff --git a/src/utils.ts b/src/utils.ts index a29e957..2f62560 100644 --- a/src/utils.ts +++ b/src/utils.ts @@ -6,6 +6,10 @@ import _invert from 'lodash/invert'; import _mapKeys from 
'lodash/mapKeys'; // TODO: Document this +/** + * A constant map from the symbolic item parameter names to their semantic + * counterparts. + */ export const ZETA_KEY_MAP = { a: 'discrimination', b: 'difficulty', @@ -13,7 +17,12 @@ export const ZETA_KEY_MAP = { d: 'slipping', }; -// TODO: Document this +/** + * Return default item parameters (i.e., zeta) + * + * @param {'symbolic' | 'semantic'} desiredFormat - The desired format for the output zeta object. + * @returns {Zeta} the default zeta object in the specified format. + */ export const defaultZeta = (desiredFormat: 'symbolic' | 'semantic' = 'symbolic'): Zeta => { const defaultZeta: Zeta = { a: 1, @@ -25,7 +34,15 @@ export const defaultZeta = (desiredFormat: 'symbolic' | 'semantic' = 'symbolic') return convertZeta(defaultZeta, desiredFormat); }; -// TODO: Document this +/** + * Validates the item (a.k.a. zeta) parameters, prohibiting redundant keys and + * optionally requiring all parameters. + * + * @param {Zeta} zeta - The zeta parameters to validate. + * @param {boolean} requireAll - If `true`, ensures that all required keys are present. Default is `false`. + * + * @throws {Error} Will throw an error if any of the validation rules are violated. + */ export const validateZetaParams = (zeta: Zeta, requireAll = false): void => { if (zeta.a !== undefined && zeta.discrimination !== undefined) { throw new Error('This item has both an `a` key and `discrimination` key. Please provide only one.'); @@ -62,7 +79,20 @@ export const validateZetaParams = (zeta: Zeta, requireAll = false): void => { } }; -// TODO: Document this +/** + * Fills in default zeta parameters for any missing keys in the provided zeta object. + * + * @remarks + * This function merges the provided zeta object with the default zeta object, converting + * the keys to the desired format if specified. If no desired format is provided, the + * keys will remain in their original format. 
+ * + * @param {Zeta} zeta - The zeta parameters to fill in defaults for. + * @param {'symbolic' | 'semantic'} desiredFormat - The desired format for the output zeta object. Default is 'symbolic'. + * + * @returns A new zeta object with default values filled in for any missing keys, + * and converted to the desired format if specified. + */ export const fillZetaDefaults = (zeta: Zeta, desiredFormat: 'symbolic' | 'semantic' = 'symbolic'): Zeta => { return { ...defaultZeta(desiredFormat), From 9fb94361da3a38a539601b1279eccfb18b2339d6 Mon Sep 17 00:00:00 2001 From: Adam Richie-Halford Date: Thu, 19 Sep 2024 06:07:53 -0700 Subject: [PATCH 10/47] Document and test utils --- src/__tests__/index.test.ts | 53 +++++++++---- src/__tests__/utils.test.ts | 124 ++++++++++++++++++++++++++++- src/clowder.ts | 4 +- src/index.ts | 25 +++--- src/utils.ts | 151 ++++++++++++++++++++++++++++++------ 5 files changed, 302 insertions(+), 55 deletions(-) diff --git a/src/__tests__/index.test.ts b/src/__tests__/index.test.ts index bcbbd7c..61a761d 100644 --- a/src/__tests__/index.test.ts +++ b/src/__tests__/index.test.ts @@ -5,7 +5,7 @@ import seedrandom from 'seedrandom'; import { convertZeta } from '../utils'; for (const format of ['symbolic', 'semantic'] as Array<'symbolic' | 'semantic'>) { - describe('Cat with explicit zeta', () => { + describe(`Cat with ${format} zeta`, () => { let cat1: Cat, cat2: Cat, cat3: Cat, cat4: Cat, cat5: Cat, cat6: Cat, cat7: Cat, cat8: Cat; let rng = seedrandom(); @@ -37,7 +37,7 @@ for (const format of ['symbolic', 'semantic'] as Array<'symbolic' | 'semantic'>) const randomSeed = 'test'; rng = seedrandom(randomSeed); cat4 = new Cat({ nStartItems: 0, itemSelect: 'RANDOM', randomSeed }); - cat5 = new Cat({ nStartItems: 1, startSelect: 'miDdle' }); // ask + cat5 = new Cat({ nStartItems: 1, startSelect: 'miDdle' }); cat6 = new Cat(); cat6.updateAbilityEstimate( @@ -61,6 +61,13 @@ for (const format of ['symbolic', 'semantic'] as Array<'symbolic' | 'semantic'>) 
const s5: Stimulus = { difficulty: -1.8, guessing: 0.5, discrimination: 1, slipping: 1, word: 'mom' }; const stimuli = [s1, s2, s3, s4, s5]; + it('can update an ability estimate using only a single item and answer', () => { + const cat = new Cat(); + cat.updateAbilityEstimate(s1, 1); + expect(cat.nItems).toEqual(1); + expect(cat.theta).toBeCloseTo(4.572, 1); + }); + it('constructs an adaptive test', () => { expect(cat1.method).toBe('mle'); expect(cat1.itemSelect).toBe('mfi'); @@ -98,42 +105,62 @@ for (const format of ['symbolic', 'semantic'] as Array<'symbolic' | 'semantic'>) ]); }); - it('correctly suggests the next item (closest method)', () => { + it.each` + deepCopy + ${true} + ${false} + `("correctly suggests the next item (closest method) with deepCopy='$deepCopy'", ({ deepCopy }) => { const expected = { nextStimulus: s5, remainingStimuli: [s4, s1, s3, s2] }; - const received = cat1.findNextItem(stimuli, 'closest'); + const received = cat1.findNextItem(stimuli, 'closest', deepCopy); expect(received).toEqual(expected); }); - it('correctly suggests the next item (mfi method)', () => { + it.each` + deepCopy + ${true} + ${false} + `("correctly suggests the next item (mfi method) with deepCopy='$deepCopy'", ({ deepCopy }) => { const expected = { nextStimulus: s1, remainingStimuli: [s4, s5, s3, s2] }; - const received = cat3.findNextItem(stimuli, 'MFI'); + const received = cat3.findNextItem(stimuli, 'MFI', deepCopy); expect(received).toEqual(expected); }); - it('correctly suggests the next item (middle method)', () => { + it.each` + deepCopy + ${true} + ${false} + `("correctly suggests the next item (middle method) with deepCopy='$deepCopy'", ({ deepCopy }) => { const expected = { nextStimulus: s1, remainingStimuli: [s4, s5, s3, s2] }; - const received = cat5.findNextItem(stimuli); + const received = cat5.findNextItem(stimuli, undefined, deepCopy); expect(received).toEqual(expected); }); - it('correctly suggests the next item (fixed method)', () => { + it.each` + 
deepCopy + ${true} + ${false} + `("correctly suggests the next item (fixed method) with deepCopy='$deepCopy'", ({ deepCopy }) => { expect(cat8.itemSelect).toBe('fixed'); const expected = { nextStimulus: s1, remainingStimuli: [s2, s3, s4, s5] }; - const received = cat8.findNextItem(stimuli); + const received = cat8.findNextItem(stimuli, undefined, deepCopy); expect(received).toEqual(expected); }); - it('correctly suggests the next item (random method)', () => { + it.each` + deepCopy + ${true} + ${false} + `("correctly suggests the next item (random method) with deepCopy='$deepCopy'", ({ deepCopy }) => { let received; const stimuliSorted = stimuli.sort((a: Stimulus, b: Stimulus) => a.difficulty! - b.difficulty!); // ask let index = Math.floor(rng() * stimuliSorted.length); - received = cat4.findNextItem(stimuliSorted); + received = cat4.findNextItem(stimuliSorted, undefined, deepCopy); expect(received.nextStimulus).toEqual(stimuliSorted[index]); for (let i = 0; i < 3; i++) { const remainingStimuli = received.remainingStimuli; index = Math.floor(rng() * remainingStimuli.length); - received = cat4.findNextItem(remainingStimuli); + received = cat4.findNextItem(remainingStimuli, undefined, deepCopy); expect(received.nextStimulus).toEqual(remainingStimuli[index]); } }); diff --git a/src/__tests__/utils.test.ts b/src/__tests__/utils.test.ts index 31f915a..ba72977 100644 --- a/src/__tests__/utils.test.ts +++ b/src/__tests__/utils.test.ts @@ -1,4 +1,4 @@ -import { Stimulus, Zeta } from '../type'; +import { MultiZetaStimulus, Stimulus, Zeta } from '../type'; import { itemResponseFunction, fisherInformation, @@ -8,6 +8,8 @@ import { defaultZeta, fillZetaDefaults, convertZeta, + checkNoDuplicateCatNames, + filterItemsByCatParameterAvailability, } from '../utils'; import _omit from 'lodash/omit'; @@ -258,4 +260,122 @@ describe('convertZeta', () => { }); }); -// TODO: Write tests for validateCorpus and filterItemsByCatParameterAvailability +describe('checkNoDuplicateCatNames', () 
=> { + it('should throw an error when a cat name is present in multiple zetas', () => { + const corpus: MultiZetaStimulus[] = [ + { + stimulus: 'Item 1', + zetas: [ + { cats: ['Model A', 'Model B'], zeta: { a: 1, b: 0.5, c: 0.2, d: 0.8 } }, + { cats: ['Model C'], zeta: { a: 2, b: 0.7, c: 0.3, d: 0.9 } }, + { cats: ['Model C'], zeta: { a: 1, b: 2, c: 0.3, d: 0.9 } }, + ], + }, + { + stimulus: 'Item 2', + zetas: [{ cats: ['Model A', 'Model C'], zeta: { a: 2.5, b: 0.8, c: 0.35, d: 0.95 } }], + }, + ]; + expect(() => checkNoDuplicateCatNames(corpus)).toThrowError('The cat names Model C are present in multiple corpora.'); + }); + + it('should not throw an error when a cat name is not present in multiple corpora', () => { + const items: MultiZetaStimulus[] = [ + { + stimulus: 'Item 1', + zetas: [ + { cats: ['Model A', 'Model B'], zeta: { a: 1, b: 0.5, c: 0.2, d: 0.8 } }, + { cats: ['Model C'], zeta: { a: 2, b: 0.7, c: 0.3, d: 0.9 } }, + ], + }, + { + stimulus: 'Item 2', + zetas: [{ cats: ['Model B', 'Model C'], zeta: { a: 2.5, b: 0.8, c: 0.35, d: 0.95 } }], + }, + ]; + + expect(() => checkNoDuplicateCatNames(items)).not.toThrowError(); + }); + + it('should handle an empty corpus without throwing an error', () => { + const emptyCorpus: MultiZetaStimulus[] = []; + + expect(() => checkNoDuplicateCatNames(emptyCorpus)).not.toThrowError(); + }); +}); + +describe('filterItemsByCatParameterAvailability', () => { + it('returns an empty "available" array when no items match the catname', () => { + const items: MultiZetaStimulus[] = [ + { + stimulus: 'Item 1', + zetas: [ + { cats: ['Model A', 'Model B'], zeta: { a: 1, b: 0.5, c: 0.2, d: 0.8 } }, + { cats: ['Model C'], zeta: { a: 2, b: 0.7, c: 0.3, d: 0.9 } }, + ], + }, + { + stimulus: 'Item 2', + zetas: [{ cats: ['Model B', 'Model C'], zeta: { a: 2.5, b: 0.8, c: 0.35, d: 0.95 } }], + }, + ]; + + const result = filterItemsByCatParameterAvailability(items, 'Model D'); + + expect(result.available).toEqual([]); + 
expect(result.missing).toEqual(items); + }); + + it('returns empty missing array when all items match the catname', () => { + const items: MultiZetaStimulus[] = [ + { + stimulus: 'Item 1', + zetas: [ + { cats: ['Model A', 'Model B'], zeta: { a: 1, b: 0.5, c: 0.2, d: 0.8 } }, + { cats: ['Model A'], zeta: { a: 2, b: 0.7, c: 0.3, d: 0.9 } }, + ], + }, + { + stimulus: 'Item 2', + zetas: [ + { cats: ['Model A', 'Model C'], zeta: { a: 2.5, b: 0.8, c: 0.35, d: 0.95 } }, + { cats: ['Model A'], zeta: { a: 3, b: 0.9, c: 0.4, d: 0.99 } }, + ], + }, + ]; + + const result = filterItemsByCatParameterAvailability(items, 'Model A'); + + expect(result.missing).toEqual([]); + expect(result.available).toEqual(items); + }); + + it('separates items based on matching catnames', () => { + const items: MultiZetaStimulus[] = [ + { + stimulus: 'Item 1', + zetas: [ + { cats: ['Model A', 'Model B'], zeta: { a: 1, b: 0.5, c: 0.2, d: 0.8 } }, + { cats: ['Model C'], zeta: { a: 2, b: 0.7, c: 0.3, d: 0.9 } }, + ], + }, + { + stimulus: 'Item 2', + zetas: [{ cats: ['Model B', 'Model C'], zeta: { a: 2.5, b: 0.8, c: 0.35, d: 0.95 } }], + }, + { + stimulus: 'Item 3', + zetas: [{ cats: ['Model A'], zeta: { a: 3, b: 0.9, c: 0.4, d: 0.99 } }], + }, + ]; + + const result = filterItemsByCatParameterAvailability(items, 'Model A'); + + // Assert + expect(result.available.length).toBe(2); + expect(result.available[0].stimulus).toBe('Item 1'); + expect(result.available[1].stimulus).toBe('Item 3'); + expect(result.missing.length).toBe(1); + expect(result.missing[0].stimulus).toBe('Item 2'); + }); +}); diff --git a/src/clowder.ts b/src/clowder.ts index f08815b..6a7844e 100644 --- a/src/clowder.ts +++ b/src/clowder.ts @@ -5,7 +5,7 @@ import _isEqual from 'lodash/isEqual'; import _mapValues from 'lodash/mapValues'; import _unzip from 'lodash/unzip'; import _zip from 'lodash/zip'; -import { filterItemsByCatParameterAvailability, validateCorpus } from './utils'; +import { filterItemsByCatParameterAvailability, 
checkNoDuplicateCatNames } from './utils'; export interface ClowderInput { // An object containing Cat configurations for each Cat instance. @@ -30,7 +30,7 @@ export class Clowder { // TODO: Need to pass in numItemsRequired so that we know when to stop providing new items. this.cats = _mapValues(cats, (catInput) => new Cat(catInput)); this.seenItems = []; - validateCorpus(corpus); + checkNoDuplicateCatNames(corpus); this._corpus = corpus; this.remainingItems = _cloneDeep(corpus); } diff --git a/src/index.ts b/src/index.ts index b0cf123..91a69e3 100644 --- a/src/index.ts +++ b/src/index.ts @@ -1,6 +1,5 @@ /* eslint-disable @typescript-eslint/no-non-null-assertion */ import { minimize_Powell } from 'optimization-js'; -import { cloneDeep } from 'lodash'; import { Stimulus, Zeta } from './type'; import { itemResponseFunction, @@ -11,6 +10,8 @@ import { fillZetaDefaults, } from './utils'; import seedrandom from 'seedrandom'; +import _clamp from 'lodash/clamp'; +import _cloneDeep from 'lodash/cloneDeep'; export const abilityPrior = normal(); @@ -175,13 +176,8 @@ export class Cat { private estimateAbilityMLE() { const theta0 = [0]; const solution = minimize_Powell(this.negLikelihood.bind(this), theta0); - let theta = solution.argument[0]; - if (theta > this.maxTheta) { - theta = this.maxTheta; - } else if (theta < this.minTheta) { - theta = this.minTheta; - } - return theta; + const theta = solution.argument[0]; + return _clamp(theta, this.minTheta, this.maxTheta); } private negLikelihood(thetaArray: Array) { @@ -216,7 +212,7 @@ export class Cat { let arr: Array; let selector = Cat.validateItemSelect(itemSelect); if (deepCopy) { - arr = cloneDeep(stimuli); + arr = _cloneDeep(stimuli); } else { arr = stimuli; } @@ -266,13 +262,12 @@ export class Cat { private selectorMiddle(arr: Stimulus[]) { let index: number; - if (arr.length < this.nStartItems) { - index = Math.floor(arr.length / 2); - } else { - index = - Math.floor(arr.length / 2) + - 
this.randomInteger(-Math.floor(this.nStartItems / 2), Math.floor(this.nStartItems / 2)); + index = Math.floor(arr.length / 2); + + if (arr.length >= this.nStartItems) { + index += this.randomInteger(-Math.floor(this.nStartItems / 2), Math.floor(this.nStartItems / 2)); } + const nextItem = arr[index]; arr.splice(index, 1); return { diff --git a/src/utils.ts b/src/utils.ts index 2f62560..bf70927 100644 --- a/src/utils.ts +++ b/src/utils.ts @@ -1,9 +1,11 @@ /* eslint-disable @typescript-eslint/no-non-null-assertion */ import bs from 'binary-search'; import { MultiZetaStimulus, Stimulus, Zeta, ZetaSymbolic } from './type'; -import _intersection from 'lodash/intersection'; +import _flatten from 'lodash/flatten'; import _invert from 'lodash/invert'; import _mapKeys from 'lodash/mapKeys'; +import _union from 'lodash/union'; +import _uniq from 'lodash/uniq'; // TODO: Document this /** @@ -100,7 +102,24 @@ export const fillZetaDefaults = (zeta: Zeta, desiredFormat: 'symbolic' | 'semant }; }; -// TODO: Document this +/** + * Converts zeta parameters between symbolic and semantic formats. + * + * @remarks + * This function takes a zeta object and a desired format as input. It converts + * the keys of the zeta object from their current format to the desired format. + * If the desired format is 'symbolic', the function maps the keys to their + * symbolic counterparts using the `ZETA_KEY_MAP`. If the desired format is + * 'semantic', the function maps the keys to their semantic counterparts using + * the inverse of `ZETA_KEY_MAP`. + * + * @param {Zeta} zeta - The zeta parameters to convert. + * @param {'symbolic' | 'semantic'} desiredFormat - The desired format for the output zeta object. Must be either 'symbolic' or 'semantic'. + * + * @throws {Error} - Will throw an error if the desired format is not 'symbolic' or 'semantic'. + * + * @returns {Zeta} A new zeta object with keys converted to the desired format. 
+ */ export const convertZeta = (zeta: Zeta, desiredFormat: 'symbolic' | 'semantic'): Zeta => { if (!['symbolic', 'semantic'].includes(desiredFormat)) { throw new Error(`Invalid desired format. Expected 'symbolic' or'semantic'. Received ${desiredFormat} instead.`); @@ -125,9 +144,11 @@ export const convertZeta = (zeta: Zeta, desiredFormat: 'symbolic' | 'semantic'): }; /** - * calculates the probability that someone with a given ability level theta will answer correctly an item. Uses the 4 parameters logistic model - * @param theta - ability estimate - * @param zeta - item params + * Calculates the probability that someone with a given ability level theta will + * answer correctly an item. Uses the 4 parameters logistic model + * + * @param {number} theta - ability estimate + * @param {Zeta} zeta - item params * @returns {number} the probability */ export const itemResponseFunction = (theta: number, zeta: Zeta) => { @@ -136,9 +157,10 @@ export const itemResponseFunction = (theta: number, zeta: Zeta) => { }; /** - * a 3PL Fisher information function - * @param theta - ability estimate - * @param zeta - item params + * A 3PL Fisher information function + * + * @param {number} theta - ability estimate + * @param {Zeta} zeta - item params * @returns {number} - the expected value of the observed information */ export const fisherInformation = (theta: number, zeta: Zeta) => { @@ -149,12 +171,13 @@ export const fisherInformation = (theta: number, zeta: Zeta) => { }; /** - * return a Gaussian distribution within a given range - * @param mean - * @param stdDev - * @param min - * @param max - * @param stepSize - the quantization (step size) of the internal table, default = 0.1 + * Return a Gaussian distribution within a given range + * + * @param {number} mean + * @param {number} stdDev + * @param {number} min + * @param {number} max + * @param {number} stepSize - the quantization (step size) of the internal table, default = 0.1 * @returns {Array<[number, number]>} - a normal 
distribution */ export const normal = (mean = 0, stdDev = 1, min = -4, max = 4, stepSize = 0.1) => { @@ -170,13 +193,13 @@ export const normal = (mean = 0, stdDev = 1, min = -4, max = 4, stepSize = 0.1) }; /** - * find the item in a given array that has the difficulty closest to the target value + * Find the item in a given array that has the difficulty closest to the target value * * @remarks * The input array of stimuli must be sorted by difficulty. * - * @param stimuli Array - an array of stimuli sorted by difficulty - * @param target number - ability estimate + * @param {Stimulus[]} inputStimuli - an array of stimuli sorted by difficulty + * @param {number} target - ability estimate * @returns {number} the index of stimuli */ export const findClosest = (inputStimuli: Array, target: number) => { @@ -216,18 +239,100 @@ export const findClosest = (inputStimuli: Array, target: number) => { } }; -// TODO: Document this -export const validateCorpus = (corpus: MultiZetaStimulus[]): void => { +/** + * Validates a corpus of multi-zeta stimuli to ensure that no cat names are + * duplicated. + * + * @remarks + * This function takes an array of `MultiZetaStimulus` objects, where each + * object represents an item containing item parameters (zetas) associated with + * different CAT models. The function checks for any duplicate cat names across + * each item's array of zeta values. It throws an error if any are found. + * + * @param {MultiZetaStimulus[]} corpus - An array of `MultiZetaStimulus` objects representing the corpora to validate. + * + * @throws {Error} - Throws an error if any duplicate cat names are found across the corpora. 
+ */ +export const checkNoDuplicateCatNames = (corpus: MultiZetaStimulus[]): void => { const zetaCatMapsArray = corpus.map((item) => item.zetas); for (const zetaCatMaps of zetaCatMapsArray) { - const intersection = _intersection(zetaCatMaps); - if (intersection.length > 0) { - throw new Error(`The cat names ${intersection.join(', ')} are present in multiple corpora.`); + const cats = zetaCatMaps.map(({ cats }) => cats); + + // Check to see if there are any duplicate names by comparing the union + // (which removed duplicates) to the flattened array. + const union = _union(...cats); + const flattened = _flatten(cats); + + if (union.length !== flattened.length) { + // If there are duplicates, remove the first occurence of each cat name in + // the union array from the flattened array. The remaining items in the + // flattened array should contain the duplicated cat names. + for (const cat of union) { + const idx = flattened.findIndex((c) => c === cat); + if (idx >= 0) { + flattened.splice(idx, 1); + } + } + + throw new Error(`The cat names ${_uniq(flattened).join(', ')} are present in multiple corpora.`); } } }; -// TODO: Document this +/** + * Filters a list of multi-zeta stimuli based on the availability of model parameters for a specific CAT. + * + * This function takes an array of `MultiZetaStimulus` objects and a `catName` as input. It then filters + * the items based on whether the specified CAT model parameter is present in the item's zeta values. + * The function returns an object containing two arrays: `available` and `missing`. The `available` array + * contains items where the specified CAT model parameter is present, while the `missing` array contains + * items where the parameter is not present. + * + * @param {MultiZetaStimulus[]} items - An array of `MultiZetaStimulus` objects representing the stimuli to filter. + * @param {string} catName - The name of the CAT model parameter to check for. 
+ * + * @returns An object with two arrays: `available` and `missing`. + * + * @example + * ```typescript + * const items: MultiZetaStimulus[] = [ + * { + * stimulus: 'Item 1', + * zetas: [ + * { cats: ['Model A', 'Model B'], zeta: { a: 1, b: 0.5, c: 0.2, d: 0.8 } }, + * { cats: ['Model C'], zeta: { a: 2, b: 0.7, c: 0.3, d: 0.9 } }, + * ], + * }, + * { + * stimulus: 'Item 2', + * zetas: [ + * { cats: ['Model B', 'Model C'], zeta: { a: 2.5, b: 0.8, c: 0.35, d: 0.95 } }, + * ], + * }, + * ]; + * + * const result = filterItemsByCatParameterAvailability(items, 'Model A'); + * console.log(result.available); + * // Output: [ + * // { + * // stimulus: 'Item 1', + * // zetas: [ + * // { cats: ['Model A', 'Model B'], zeta: { a: 1, b: 0.5, c: 0.2, d: 0.8 } }, + * // { cats: ['Model C'], zeta: { a: 2, b: 0.7, c: 0.3, d: 0.9 } }, + * // ], + * // }, + * // ] + * console.log(result.missing); + * // Output: [ + * // { + * // stimulus: 'Item 2', + * // zetas: [ + * // { cats: ['Model B', 'Model C'], zeta: { a: 2.5, b: 0.8, c: 0.35, d: 0.95 } }, + * // ], + * // }, + * // ] + * ``` + */ export const filterItemsByCatParameterAvailability = (items: MultiZetaStimulus[], catName: string) => { const paramsExist = items.filter((item) => item.zetas.some((zetaCatMap) => zetaCatMap.cats.includes(catName))); const paramsMissing = items.filter((item) => !item.zetas.some((zetaCatMap) => zetaCatMap.cats.includes(catName))); From fc72b04fff5e3a49c5fd7d47285b60a5d4670a05 Mon Sep 17 00:00:00 2001 From: Adam Richie-Halford Date: Thu, 19 Sep 2024 06:36:52 -0700 Subject: [PATCH 11/47] Start adding clowder tests --- src/__tests__/clowder.test.ts | 41 +++++++++++++++++++++++++---------- src/clowder.ts | 22 +++++++++++++------ 2 files changed, 44 insertions(+), 19 deletions(-) diff --git a/src/__tests__/clowder.test.ts b/src/__tests__/clowder.test.ts index eca29f1..e3962e4 100644 --- a/src/__tests__/clowder.test.ts +++ b/src/__tests__/clowder.test.ts @@ -2,6 +2,12 @@ import { Clowder, ClowderInput } 
from '../clowder'; import { MultiZetaStimulus, Zeta, ZetaCatMap } from '../type'; import { defaultZeta } from '../utils'; +const createStimulus = (id: string) => ({ + ...defaultZeta(), + id, + content: `Stimulus content ${id}`, +}); + const createMultiZetaStimulus = (id: string, zetas: ZetaCatMap[]): MultiZetaStimulus => ({ id, content: `Multi-Zeta Stimulus content ${id}`, @@ -33,25 +39,36 @@ describe('Clowder Class', () => { clowder = new Clowder(clowderInput); }); - test('should initialize with provided cats and corpora', () => { - expect(Object.keys(clowder['cats'])).toContain('cat1'); - expect(clowder.remainingItems).toHaveLength(2); - expect(clowder.corpus).toHaveLength(1); + it('initializes with provided cats and corpora', () => { + expect(Object.keys(clowder.cats)).toContain('cat1'); + expect(clowder.remainingItems).toHaveLength(5); + expect(clowder.corpus).toHaveLength(5); + expect(clowder.seenItems).toHaveLength(0); }); - test('should validate cat names', () => { + it('validates cat names', () => { expect(() => { clowder.updateCatAndGetNextItem({ catToSelect: 'invalidCat', }); - }).toThrow('Invalid Cat name'); + }).toThrowError('Invalid Cat name'); }); - // test('should update ability estimates', () => { - // clowder.updateAbilityEstimates(['cat1'], createStimulus('1'), [0]); - // const cat1 = clowder['cats']['cat1']; - // expect(cat1.theta).toBeGreaterThanOrEqual(0); // Since we mock, assume the result is logical. 
- // }); + it('updates ability estimates only for the named cats', () => { + const origTheta1 = clowder.cats.cat1.theta; + const origTheta2 = clowder.cats.cat2.theta; + + clowder.updateAbilityEstimates(['cat1'], createStimulus('1'), [0]); + + expect(clowder.cats.cat1.theta).not.toBe(origTheta1); + expect(clowder.cats.cat2.theta).toBe(origTheta2); + }); + + it('throws an error when updating ability estimates for an invalid cat', () => { + expect(() => clowder.updateAbilityEstimates(['invalidCatName'], createStimulus('1'), [0])).toThrowError( + 'Invalid Cat name. Expected one of cat1, cat2. Received invalidCatName.', + ); + }); // test('should select next stimulus from validated stimuli', () => { // const nextItem = clowder.updateCatAndGetNextItem({ @@ -78,7 +95,7 @@ describe('Clowder Class', () => { // expect(nextItem).toEqual(createStimulus('1')); // Unvalidated item // }); - test('should throw error if previousItems and previousAnswers have mismatched lengths', () => { + test('should throw error if items and answers have mismatched lengths', () => { expect(() => { clowder.updateCatAndGetNextItem({ catToSelect: 'cat1', diff --git a/src/clowder.ts b/src/clowder.ts index 6a7844e..21d2610 100644 --- a/src/clowder.ts +++ b/src/clowder.ts @@ -17,10 +17,10 @@ export interface ClowderInput { } export class Clowder { - private cats: { [name: string]: Cat }; + private _cats: { [name: string]: Cat }; private _corpus: MultiZetaStimulus[]; public remainingItems: MultiZetaStimulus[]; - public seenItems: Stimulus[]; + private _seenItems: Stimulus[]; /** * Create a Clowder object. @@ -28,16 +28,16 @@ export class Clowder { */ constructor({ cats, corpus }: ClowderInput) { // TODO: Need to pass in numItemsRequired so that we know when to stop providing new items. 
- this.cats = _mapValues(cats, (catInput) => new Cat(catInput)); - this.seenItems = []; + this._cats = _mapValues(cats, (catInput) => new Cat(catInput)); + this._seenItems = []; checkNoDuplicateCatNames(corpus); this._corpus = corpus; this.remainingItems = _cloneDeep(corpus); } private _validateCatName(catName: string): void { - if (!Object.prototype.hasOwnProperty.call(this.cats, catName)) { - throw new Error(`Invalid Cat name. Expected one of ${Object.keys(this.cats).join(', ')}. Received ${catName}.`); + if (!Object.prototype.hasOwnProperty.call(this._cats, catName)) { + throw new Error(`Invalid Cat name. Expected one of ${Object.keys(this._cats).join(', ')}. Received ${catName}.`); } } @@ -45,6 +45,14 @@ export class Clowder { return this._corpus; } + public get cats() { + return this._cats; + } + + public get seenItems() { + return this._seenItems; + } + public get theta() { return _mapValues(this.cats, (cat) => cat.theta); } @@ -131,7 +139,7 @@ export class Clowder { } // Update the seenItems with the provided previous items - this.seenItems.push(...items); + this._seenItems.push(...items); // Remove the seenItems from the remainingItems this.remainingItems = this.remainingItems.filter((stim) => !items.includes(stim)); From 92a4d742828c4705c84136767b505f7908b3ba1f Mon Sep 17 00:00:00 2001 From: Adam Richie-Halford Date: Thu, 19 Sep 2024 06:51:17 -0700 Subject: [PATCH 12/47] Add more clowder tests --- src/__tests__/clowder.test.ts | 37 ++++++++++++++++++++++++++++++++++- src/clowder.ts | 1 + 2 files changed, 37 insertions(+), 1 deletion(-) diff --git a/src/__tests__/clowder.test.ts b/src/__tests__/clowder.test.ts index e3962e4..b902e95 100644 --- a/src/__tests__/clowder.test.ts +++ b/src/__tests__/clowder.test.ts @@ -1,3 +1,4 @@ +import { Cat } from '..'; import { Clowder, ClowderInput } from '../clowder'; import { MultiZetaStimulus, Zeta, ZetaCatMap } from '../type'; import { defaultZeta } from '../utils'; @@ -70,6 +71,23 @@ describe('Clowder Class', () => { 
); }); + it.each` + property + ${'theta'} + ${'seMeasurement'} + ${'nItems'} + ${'resps'} + ${'zetas'} + `("accesses the '$property' property of each cat", ({ property }) => { + clowder.updateAbilityEstimates(['cat1'], createStimulus('1'), [0]); + clowder.updateAbilityEstimates(['cat2'], createStimulus('1'), [1]); + const expected = { + cat1: clowder.cats['cat1'][property as keyof Cat], + cat2: clowder.cats['cat2'][property as keyof Cat], + }; + expect(clowder[property as keyof Clowder]).toEqual(expected); + }); + // test('should select next stimulus from validated stimuli', () => { // const nextItem = clowder.updateCatAndGetNextItem({ // catToSelect: 'cat1', @@ -95,7 +113,7 @@ describe('Clowder Class', () => { // expect(nextItem).toEqual(createStimulus('1')); // Unvalidated item // }); - test('should throw error if items and answers have mismatched lengths', () => { + it('throws an error if items and answers have mismatched lengths', () => { expect(() => { clowder.updateCatAndGetNextItem({ catToSelect: 'cat1', @@ -104,4 +122,21 @@ describe('Clowder Class', () => { }); }).toThrow('Previous items and answers must have the same length.'); }); + + it('throws an error if catToSelect is invalid', () => { + expect(() => { + clowder.updateCatAndGetNextItem({ + catToSelect: 'invalidCatName', + }); + }).toThrow('Invalid Cat name. Expected one of cat1, cat2. Received invalidCatName.'); + }); + + it('throws an error if any of catsToUpdate is invalid', () => { + expect(() => { + clowder.updateCatAndGetNextItem({ + catToSelect: 'cat1', + catsToUpdate: ['invalidCatName', 'cat2'], + }); + }).toThrow('Invalid Cat name. Expected one of cat1, cat2. Received invalidCatName.'); + }); }); diff --git a/src/clowder.ts b/src/clowder.ts index 21d2610..fed43cf 100644 --- a/src/clowder.ts +++ b/src/clowder.ts @@ -125,6 +125,7 @@ export class Clowder { // Validate all cat names this._validateCatName(catToSelect); catsToUpdate = Array.isArray(catsToUpdate) ? 
catsToUpdate : [catsToUpdate]; + catsToUpdate.forEach((cat) => { this._validateCatName(cat); }); From 5082d5bf2f188f58e81ffa9f11623c5a3a9298f3 Mon Sep 17 00:00:00 2001 From: Adam Richie-Halford Date: Fri, 20 Sep 2024 04:58:47 -0700 Subject: [PATCH 13/47] Add documentation and randomlySelectUnvalidated parameter --- src/__tests__/utils.test.ts | 4 +- src/clowder.ts | 128 +++++++++++++++++++++++++++++++----- 2 files changed, 115 insertions(+), 17 deletions(-) diff --git a/src/__tests__/utils.test.ts b/src/__tests__/utils.test.ts index ba72977..91d237e 100644 --- a/src/__tests__/utils.test.ts +++ b/src/__tests__/utils.test.ts @@ -276,7 +276,9 @@ describe('checkNoDuplicateCatNames', () => { zetas: [{ cats: ['Model A', 'Model C'], zeta: { a: 2.5, b: 0.8, c: 0.35, d: 0.95 } }], }, ]; - expect(() => checkNoDuplicateCatNames(corpus)).toThrowError('The cat names Model C are present in multiple corpora.'); + expect(() => checkNoDuplicateCatNames(corpus)).toThrowError( + 'The cat names Model C are present in multiple corpora.', + ); }); it('should not throw an error when a cat name is not present in multiple corpora', () => { diff --git a/src/clowder.ts b/src/clowder.ts index fed43cf..3cbf315 100644 --- a/src/clowder.ts +++ b/src/clowder.ts @@ -8,14 +8,27 @@ import _zip from 'lodash/zip'; import { filterItemsByCatParameterAvailability, checkNoDuplicateCatNames } from './utils'; export interface ClowderInput { - // An object containing Cat configurations for each Cat instance. + /** + * An object containing Cat configurations for each Cat instance. + * Keys correspond to Cat names, while values correspond to Cat configurations. + */ cats: { [name: string]: CatInput; }; - // An object containing arrays of stimuli for each corpus. + /** + * An object containing arrays of stimuli for each corpus. + */ corpus: MultiZetaStimulus[]; } +/** + * The Clowder class is responsible for managing a collection of Cat instances + * along with a corpus of stimuli. 
It maintains a list of named Cat instances + * and a corpus where each item in the coprpus may have IRT parameters + * corresponding to each named Cat. Clowder provides methods for updating the + * ability estimates of each of its Cats, and selecting the next item to present + * to the participant. + */ export class Clowder { private _cats: { [name: string]: Cat }; private _corpus: MultiZetaStimulus[]; @@ -24,10 +37,18 @@ export class Clowder { /** * Create a Clowder object. + * * @param {ClowderInput} input - An object containing arrays of Cat configurations and corpora. + * @param {CatInput[]} input.cats - An object containing Cat configurations for each Cat instance. + * @param {MultiZetaStimulus[]} input.corpus - An array of stimuli representing each corpus. + * + * @throws {Error} - Throws an error if any item in the corpus has duplicated IRT parameters for any Cat name. */ constructor({ cats, corpus }: ClowderInput) { - // TODO: Need to pass in numItemsRequired so that we know when to stop providing new items. + // TODO: Need to pass in numItemsRequired so that we know when to stop + // providing new items. This may depend on the cat name. For instance, + // perhaps numItemsRequired should be an object with cat names as keys and + // numItemsRequired as values. this._cats = _mapValues(cats, (catInput) => new Cat(catInput)); this._seenItems = []; checkNoDuplicateCatNames(corpus); @@ -35,44 +56,86 @@ export class Clowder { this.remainingItems = _cloneDeep(corpus); } + /** + * Validate the provided Cat name against the existing Cat instances. + * Throw an error if the Cat name is not found. + * + * @param {string} catName - The name of the Cat instance to validate. + * + * @throws {Error} - Throws an error if the provided Cat name is not found among the existing Cat instances. + */ private _validateCatName(catName: string): void { if (!Object.prototype.hasOwnProperty.call(this._cats, catName)) { throw new Error(`Invalid Cat name. 
Expected one of ${Object.keys(this._cats).join(', ')}. Received ${catName}.`); } } + /** + * The corpus that was provided to this Clowder when it was created. + */ public get corpus() { return this._corpus; } + /** + * The named Cat instances that this Clowder manages. + */ public get cats() { return this._cats; } + /** + * The subset of the input corpus that this Clowder has "seen" so far. + */ public get seenItems() { return this._seenItems; } + /** + * The theta estimates for each Cat instance. + */ public get theta() { return _mapValues(this.cats, (cat) => cat.theta); } + /** + * The standard error of measurement estimates for each Cat instance. + */ public get seMeasurement() { return _mapValues(this.cats, (cat) => cat.seMeasurement); } + /** + * The number of items presented to each Cat instance. + */ public get nItems() { return _mapValues(this.cats, (cat) => cat.nItems); } + /** + * The responses received by each Cat instance. + */ public get resps() { return _mapValues(this.cats, (cat) => cat.resps); } + /** + * The zeta (item parameters) received by each Cat instance. + */ public get zetas() { return _mapValues(this.cats, (cat) => cat.zetas); } + /** + * Updates the ability estimates for the specified Cat instances. + * + * @param {string[]} catNames - The names of the Cat instances to update. + * @param {Zeta | Zeta[]} zeta - The item parameter(s) (zeta) for the given stimuli. + * @param {(0 | 1) | (0 | 1)[]} answer - The corresponding answer(s) (0 or 1) for the given stimuli. + * @param {string} [method] - Optional method for updating ability estimates. If none is provided, it will use the default method for each Cat instance. + * + * @throws {Error} If any `catName` is not found among the existing Cat instances. 
+ */ public updateAbilityEstimates(catNames: string[], zeta: Zeta | Zeta[], answer: (0 | 1) | (0 | 1)[], method?: string) { catNames.forEach((catName) => { this._validateCatName(catName); @@ -83,7 +146,7 @@ export class Clowder { } /** - * Updates the ability estimates for the specified `catsToUpdate` and selects the next stimulus for the `catToSelect`. + * Update the ability estimates for the specified `catsToUpdate` and select the next stimulus for the `catToSelect`. * This function processes previous items and answers, updates internal state, and selects the next stimulus * based on the remaining stimuli and `catToSelect`. * @@ -94,6 +157,7 @@ export class Clowder { * @param {(0 | 1) | (0 | 1)[]} [input.answers=[]] - An array of answers (0 or 1) corresponding to `items`. * @param {string} [input.method] - Optional method for updating ability estimates (if applicable). * @param {string} [input.itemSelect] - Optional item selection method (if applicable). + * @param {boolean} [input.randomlySelectUnvalidated=false] - Optional flag indicating whether to randomly select an unvalidated item for `catToSelect`. * * @returns {Stimulus | undefined} - The next stimulus to present, or `undefined` if no further validated stimuli are available. * @@ -101,11 +165,15 @@ export class Clowder { * @throws {Error} If any `items` are not found in the Clowder's corpora (validated or unvalidated). * * The function operates in several steps: - * 1. Validates the `catToSelect` and `catsToUpdate`. - * 2. Ensures `items` and `answers` arrays are properly formatted. - * 3. Updates the internal list of seen items. - * 4. Updates the ability estimates for the `catsToUpdate`. - * 5. Selects the next stimulus for `catToSelect`, considering validated and unvalidated stimuli. + * 1. Validate: + * a. Validates the `catToSelect` and `catsToUpdate`. + * b. Ensures `items` and `answers` arrays are properly formatted. + * 2. Update: + * a. Updates the internal list of seen items. + * b. 
Updates the ability estimates for the `catsToUpdate`. + * 3. Select: + * a. Selects the next item using `catToSelect`, considering only remaining items that are valid for that cat. + * b. If desired, randomly selects an unvalidated item for catToSelect. */ public updateCatAndGetNextItem({ catToSelect, @@ -114,6 +182,7 @@ export class Clowder { answers = [], method, itemSelect, + randomlySelectUnvalidated = false, }: { catToSelect: string; catsToUpdate?: string | string[]; @@ -121,11 +190,17 @@ export class Clowder { answers?: (0 | 1) | (0 | 1)[]; method?: string; itemSelect?: string; + randomlySelectUnvalidated?: boolean; }): Stimulus | undefined { - // Validate all cat names + // +----------+ + // ----------| Validate |----------| + // +----------+ + + // Validate catToSelect this._validateCatName(catToSelect); - catsToUpdate = Array.isArray(catsToUpdate) ? catsToUpdate : [catsToUpdate]; + // Convert catsToUpdate to array and validate each name + catsToUpdate = Array.isArray(catsToUpdate) ? catsToUpdate : [catsToUpdate]; catsToUpdate.forEach((cat) => { this._validateCatName(cat); }); @@ -139,35 +214,53 @@ export class Clowder { throw new Error('Previous items and answers must have the same length.'); } + // +----------+ + // ----------| Update |----------| + // +----------+ + // Update the seenItems with the provided previous items this._seenItems.push(...items); // Remove the seenItems from the remainingItems this.remainingItems = this.remainingItems.filter((stim) => !items.includes(stim)); + // Create a new zip array of items and answers. This will be useful in + // filtering operations below. It ensures that items and their corresponding + // answers "stay together." 
const itemsAndAnswers = _zip(items, answers) as [Stimulus, 0 | 1][]; // Update the ability estimate for all cats for (const catName of catsToUpdate) { - // eslint-disable-next-line @typescript-eslint/no-unused-vars - const itemsAndAnswersForCat = itemsAndAnswers.filter(([stim, _answer]) => { + const itemsAndAnswersForCat = itemsAndAnswers.filter(([stim]) => { + // We are dealing with a single item in this function. This single item + // has an array of zeta parameters for a bunch of different Cats. We + // need to determine if `catName` is present in that list. So we first + // reduce the zetas to get all of the applicabe cat names. const allCats = stim.zetas.reduce((acc: string[], { cats }: { cats: string }) => { return [...acc, ...cats]; }, []); + + // Then we simply check if `catName` is present in this reduction. return allCats.includes(catName); }); + // Now that we have the subset of items that can apply to this cat, + // retrieve only the item parameters that apply to this cat. const zetasAndAnswersForCat = itemsAndAnswersForCat.map(([stim, _answer]) => { const { zetas } = stim; const zetaForCat = zetas.find((zeta: ZetaCatMap) => zeta.cats.includes(catName)); return [zetaForCat.zeta, _answer]; }); - // Extract the cat to update ability estimate + // Finally, unzip the zetas and answers and feed them into the cat's updateAbilityEstimate method. const [zetas, answers] = _unzip(zetasAndAnswersForCat); this.cats[catName].updateAbilityEstimate(zetas, answers, method); } + // +----------+ + // ----------| Select |----------| + // +----------+ + // Now, we need to dynamically calculate the stimuli available for selection by `catToSelect`. // We inspect the remaining items and find ones that have zeta parameters for `catToSelect` @@ -197,7 +290,6 @@ export class Clowder { return _isEqual(rest, nextStimulus); }); - // Added some logic to mix in the unvalidated stimuli if needed. 
if (missing.length === 0) { // If there are no more unvalidated stimuli, we only have validated items left. // Use the Cat to find the next item. The Cat may return undefined if all validated items have been seen. @@ -207,7 +299,11 @@ export class Clowder { return missing[Math.floor(Math.random() * missing.length)]; } else { // In this case, there are both validated and unvalidated items left. - // We need to randomly insert unvalidated items + // We randomly insert unvalidated items + if (!randomlySelectUnvalidated) { + return returnStimulus; + } + const numRemaining = { available: available.length, missing: missing.length, From f7cf2f779decb4e0d90db76bb1b0dd4a11f4aea6 Mon Sep 17 00:00:00 2001 From: Adam Richie-Halford Date: Fri, 20 Sep 2024 06:15:13 -0700 Subject: [PATCH 14/47] Add more tests and a random seed --- src/__tests__/clowder.test.ts | 95 ++++++++++++++++++++++++++--------- src/clowder.ts | 52 ++++++++++++++----- src/index.ts | 2 +- 3 files changed, 110 insertions(+), 39 deletions(-) diff --git a/src/__tests__/clowder.test.ts b/src/__tests__/clowder.test.ts index b902e95..1e011e4 100644 --- a/src/__tests__/clowder.test.ts +++ b/src/__tests__/clowder.test.ts @@ -47,6 +47,27 @@ describe('Clowder Class', () => { expect(clowder.seenItems).toHaveLength(0); }); + it('throws an error when given an invalid corpus', () => { + expect(() => { + const corpus: MultiZetaStimulus[] = [ + { + stimulus: 'Item 1', + zetas: [ + { cats: ['Model A', 'Model B'], zeta: { a: 1, b: 0.5, c: 0.2, d: 0.8 } }, + { cats: ['Model C'], zeta: { a: 2, b: 0.7, c: 0.3, d: 0.9 } }, + { cats: ['Model C'], zeta: { a: 1, b: 2, c: 0.3, d: 0.9 } }, + ], + }, + { + stimulus: 'Item 2', + zetas: [{ cats: ['Model A', 'Model C'], zeta: { a: 2.5, b: 0.8, c: 0.35, d: 0.95 } }], + }, + ]; + + new Clowder({ cats: { cat1: {} }, corpus }); + }).toThrowError('The cat names Model C are present in multiple corpora.'); + }); + it('validates cat names', () => { expect(() => { clowder.updateCatAndGetNextItem({ 
@@ -88,31 +109,6 @@ describe('Clowder Class', () => { expect(clowder[property as keyof Clowder]).toEqual(expected); }); - // test('should select next stimulus from validated stimuli', () => { - // const nextItem = clowder.updateCatAndGetNextItem({ - // catToSelect: 'cat1', - // catsToUpdate: ['cat1'], - // previousItems: [createStimulus('1')], - // previousAnswers: [1], - // }); - // expect(nextItem).toEqual(createStimulus('1')); // Second validated stimulus - // }); - - // test('should return unvalidated stimulus when no validated stimuli remain', () => { - // clowder.updateCatAndGetNextItem({ - // catToSelect: 'cat1', - // previousItems: [createStimulus('1'), createStimulus('2')], - // previousAnswers: [1, 0], - // }); - - // const nextItem = clowder.updateCatAndGetNextItem({ - // catToSelect: 'cat1', - // previousItems: [], - // previousAnswers: [], - // }); - // expect(nextItem).toEqual(createStimulus('1')); // Unvalidated item - // }); - it('throws an error if items and answers have mismatched lengths', () => { expect(() => { clowder.updateCatAndGetNextItem({ @@ -139,4 +135,53 @@ describe('Clowder Class', () => { }); }).toThrow('Invalid Cat name. Expected one of cat1, cat2. 
Received invalidCatName.'); }); + + it('updates seen and remaining items', () => { + clowder.updateCatAndGetNextItem({ + catToSelect: 'cat2', + catsToUpdate: ['cat1', 'cat2'], + items: [clowder.corpus[0], clowder.corpus[1], clowder.corpus[2]], + answers: [1, 1, 1], + }); + + expect(clowder.seenItems).toHaveLength(3); + expect(clowder.remainingItems).toHaveLength(2); + }); + + it('should select an item that has not yet been seen', () => { + const nextItem = clowder.updateCatAndGetNextItem({ + catToSelect: 'cat2', + catsToUpdate: ['cat1', 'cat2'], + items: [clowder.corpus[0], clowder.corpus[1], clowder.corpus[2]], + answers: [1, 1, 1], + }); + + expect([clowder.corpus[3], clowder.corpus[4]]).toContainEqual(nextItem); // Third validated stimulus + }); + + it('should select a validated item if validated items are present and randomlySelectUnvalidated is false', () => { + // TODO: Implement this test + expect(1).toBe(0); + }); + + it('should select an unvalidated item if no validated items remain', () => { + // TODO: Implement this test + expect(1).toBe(0); + }); + + it('should correctly update ability estimates during the updateCatAndGetNextItem method', () => { + // TODO: Implement this test + expect(1).toBe(0); + }); + + it('should randomly choose between validated and unvalidated items if randomlySelectUnvalidated is true', () => { + // TODO: Implement this test + // Pass in a random seed for reproducibility + expect(1).toBe(0); + }); + + it('should return undefined if no more items remain', () => { + // TODO: Implement this test + expect(1).toBe(0); + }); }); diff --git a/src/clowder.ts b/src/clowder.ts index 3cbf315..64c79f7 100644 --- a/src/clowder.ts +++ b/src/clowder.ts @@ -1,11 +1,14 @@ import { Cat, CatInput } from './index'; import { MultiZetaStimulus, Stimulus, Zeta, ZetaCatMap } from './type'; +import { filterItemsByCatParameterAvailability, checkNoDuplicateCatNames } from './utils'; import _cloneDeep from 'lodash/cloneDeep'; +import _differenceWith from 
'lodash/differenceWith'; import _isEqual from 'lodash/isEqual'; import _mapValues from 'lodash/mapValues'; +import _omit from 'lodash/omit'; import _unzip from 'lodash/unzip'; import _zip from 'lodash/zip'; -import { filterItemsByCatParameterAvailability, checkNoDuplicateCatNames } from './utils'; +import seedrandom from 'seedrandom'; export interface ClowderInput { /** @@ -19,6 +22,10 @@ export interface ClowderInput { * An object containing arrays of stimuli for each corpus. */ corpus: MultiZetaStimulus[]; + /** + * A random seed for reproducibility. If not provided, a random seed will be generated. + */ + randomSeed?: string | null; } /** @@ -32,8 +39,9 @@ export interface ClowderInput { export class Clowder { private _cats: { [name: string]: Cat }; private _corpus: MultiZetaStimulus[]; - public remainingItems: MultiZetaStimulus[]; + private _remainingItems: MultiZetaStimulus[]; private _seenItems: Stimulus[]; + private readonly _rng: ReturnType; /** * Create a Clowder object. @@ -44,7 +52,7 @@ export class Clowder { * * @throws {Error} - Throws an error if any item in the corpus has duplicated IRT parameters for any Cat name. */ - constructor({ cats, corpus }: ClowderInput) { + constructor({ cats, corpus, randomSeed = null }: ClowderInput) { // TODO: Need to pass in numItemsRequired so that we know when to stop // providing new items. This may depend on the cat name. For instance, // perhaps numItemsRequired should be an object with cat names as keys and @@ -53,7 +61,8 @@ export class Clowder { this._seenItems = []; checkNoDuplicateCatNames(corpus); this._corpus = corpus; - this.remainingItems = _cloneDeep(corpus); + this._remainingItems = _cloneDeep(corpus); + this._rng = randomSeed === null ? seedrandom() : seedrandom(randomSeed); } /** @@ -70,6 +79,13 @@ export class Clowder { } } + /** + * The named Cat instances that this Clowder manages. 
+ */ + public get cats() { + return this._cats; + } + /** * The corpus that was provided to this Clowder when it was created. */ @@ -78,10 +94,10 @@ export class Clowder { } /** - * The named Cat instances that this Clowder manages. + * The subset of the input corpus that this Clowder has not yet "seen". */ - public get cats() { - return this._cats; + public get remainingItems() { + return this._remainingItems; } /** @@ -221,8 +237,8 @@ export class Clowder { // Update the seenItems with the provided previous items this._seenItems.push(...items); - // Remove the seenItems from the remainingItems - this.remainingItems = this.remainingItems.filter((stim) => !items.includes(stim)); + // Remove the provided previous items from the remainingItems + this._remainingItems = _differenceWith(this._remainingItems, items, _isEqual); // Create a new zip array of items and answers. This will be useful in // filtering operations below. It ensures that items and their corresponding @@ -264,7 +280,7 @@ export class Clowder { // Now, we need to dynamically calculate the stimuli available for selection by `catToSelect`. // We inspect the remaining items and find ones that have zeta parameters for `catToSelect` - const { available, missing } = filterItemsByCatParameterAvailability(this.remainingItems, catToSelect); + const { available, missing } = filterItemsByCatParameterAvailability(this._remainingItems, catToSelect); // The cat expects an array of Stimulus objects, with the zeta parameters // spread at the top-level of each Stimulus object. 
So we need to convert @@ -281,13 +297,23 @@ export class Clowder { // Use the catForSelect to determine the next stimulus const cat = this.cats[catToSelect]; const { nextStimulus } = cat.findNextItem(availableCatInput, itemSelect); + const nextStimulusWithoutZeta = _omit(nextStimulus, [ + 'a', + 'b', + 'c', + 'd', + 'discrimination', + 'difficulty', + 'guessing', + 'slipping', + ]); // Again `nextStimulus` will be a Stimulus object, or `undefined` if no further validated stimuli are available. // We need to convert the Stimulus object back to a MultiZetaStimulus object to return to the user. const returnStimulus: MultiZetaStimulus | undefined = available.find((stim) => { // eslint-disable-next-line @typescript-eslint/no-unused-vars const { zetas, ...rest } = stim; - return _isEqual(rest, nextStimulus); + return _isEqual(rest, nextStimulusWithoutZeta); }); if (missing.length === 0) { @@ -296,7 +322,7 @@ export class Clowder { return returnStimulus; } else if (available.length === 0) { // In this case, there are no more validated items left. Choose an unvalidated item at random. - return missing[Math.floor(Math.random() * missing.length)]; + return missing[Math.floor(this._rng() * missing.length)]; } else { // In this case, there are both validated and unvalidated items left. 
// We randomly insert unvalidated items @@ -311,7 +337,7 @@ export class Clowder { const random = Math.random(); if (random < numRemaining.missing / (numRemaining.available + numRemaining.missing)) { - return missing[Math.floor(Math.random() * missing.length)]; + return missing[Math.floor(this._rng() * missing.length)]; } else { return returnStimulus; } diff --git a/src/index.ts b/src/index.ts index 91a69e3..692a33a 100644 --- a/src/index.ts +++ b/src/index.ts @@ -288,7 +288,7 @@ export class Cat { } private selectorRandom(arr: Stimulus[]) { - const index = Math.floor(this._rng() * arr.length); + const index = this.randomInteger(0, arr.length - 1); const nextItem = arr.splice(index, 1)[0]; return { nextStimulus: nextItem, From 3d1a614b20a644dc424d35a0f3249d2920c55f0e Mon Sep 17 00:00:00 2001 From: Adam Richie-Halford Date: Fri, 20 Sep 2024 06:24:10 -0700 Subject: [PATCH 15/47] Reorganize files --- src/__tests__/{index.test.ts => cat.test.ts} | 2 +- src/__tests__/clowder.test.ts | 4 +- src/__tests__/corpus.test.ts | 324 ++++++++++++++++++ src/__tests__/utils.test.ts | 328 +------------------ src/cat.ts | 319 ++++++++++++++++++ src/clowder.ts | 4 +- src/corpus.ts | 246 ++++++++++++++ src/index.ts | 327 +----------------- src/utils.ts | 248 +------------- 9 files changed, 899 insertions(+), 903 deletions(-) rename src/__tests__/{index.test.ts => cat.test.ts} (99%) create mode 100644 src/__tests__/corpus.test.ts create mode 100644 src/cat.ts create mode 100644 src/corpus.ts diff --git a/src/__tests__/index.test.ts b/src/__tests__/cat.test.ts similarity index 99% rename from src/__tests__/index.test.ts rename to src/__tests__/cat.test.ts index 61a761d..ddeb0a3 100644 --- a/src/__tests__/index.test.ts +++ b/src/__tests__/cat.test.ts @@ -2,7 +2,7 @@ import { Cat } from '../index'; import { Stimulus } from '../type'; import seedrandom from 'seedrandom'; -import { convertZeta } from '../utils'; +import { convertZeta } from '../corpus'; for (const format of ['symbolic', 
'semantic'] as Array<'symbolic' | 'semantic'>) { describe(`Cat with ${format} zeta`, () => { diff --git a/src/__tests__/clowder.test.ts b/src/__tests__/clowder.test.ts index 1e011e4..eacaa50 100644 --- a/src/__tests__/clowder.test.ts +++ b/src/__tests__/clowder.test.ts @@ -1,7 +1,7 @@ -import { Cat } from '..'; +import { Cat } from '../cat'; import { Clowder, ClowderInput } from '../clowder'; import { MultiZetaStimulus, Zeta, ZetaCatMap } from '../type'; -import { defaultZeta } from '../utils'; +import { defaultZeta } from '../corpus'; const createStimulus = (id: string) => ({ ...defaultZeta(), diff --git a/src/__tests__/corpus.test.ts b/src/__tests__/corpus.test.ts new file mode 100644 index 0000000..e057937 --- /dev/null +++ b/src/__tests__/corpus.test.ts @@ -0,0 +1,324 @@ +import { MultiZetaStimulus, Stimulus, Zeta } from '../type'; +import { + validateZetaParams, + ZETA_KEY_MAP, + defaultZeta, + fillZetaDefaults, + convertZeta, + checkNoDuplicateCatNames, + filterItemsByCatParameterAvailability, +} from '../corpus'; +import _omit from 'lodash/omit'; + +describe('validateZetaParams', () => { + it('throws an error when providing both a and discrimination', () => { + expect(() => validateZetaParams({ a: 1, discrimination: 1 })).toThrow( + 'This item has both an `a` key and `discrimination` key. Please provide only one.', + ); + }); + + it('throws an error when providing both b and difficulty', () => { + expect(() => validateZetaParams({ b: 1, difficulty: 1 })).toThrow( + 'This item has both a `b` key and `difficulty` key. Please provide only one.', + ); + }); + + it('throws an error when providing both c and guessing', () => { + expect(() => validateZetaParams({ c: 1, guessing: 1 })).toThrow( + 'This item has both a `c` key and `guessing` key. 
Please provide only one.', + ); + }); + + it('throws an error when providing both d and slipping', () => { + expect(() => validateZetaParams({ d: 1, slipping: 1 })).toThrow( + 'This item has both a `d` key and `slipping` key. Please provide only one.', + ); + }); + + it('throws an error when requiring all keys and missing one', () => { + for (const key of ['a', 'b', 'c', 'd'] as (keyof typeof ZETA_KEY_MAP)[]) { + const semanticKey = ZETA_KEY_MAP[key]; + const zeta = _omit(defaultZeta('symbolic'), [key]); + + expect(() => validateZetaParams(zeta, true)).toThrow( + `This item is missing the key \`${String(key)}\` or \`${semanticKey}\`.`, + ); + } + }); +}); + +describe('fillZetaDefaults', () => { + it('fills in default values for missing keys', () => { + const zeta: Zeta = { + difficulty: 1, + guessing: 0.5, + }; + + const filledZeta = fillZetaDefaults(zeta, 'semantic'); + + expect(filledZeta).toEqual({ + discrimination: 1, + difficulty: 1, + guessing: 0.5, + slipping: 1, + }); + }); + + it('does not modify the input object when no missing keys', () => { + const zeta: Zeta = { + a: 5, + b: 5, + c: 5, + d: 5, + }; + + const filledZeta = fillZetaDefaults(zeta, 'symbolic'); + + expect(filledZeta).toEqual(zeta); + }); + + it('converts to semantic format when desired', () => { + const zeta: Zeta = { + a: 5, + b: 5, + }; + + const filledZeta = fillZetaDefaults(zeta, 'semantic'); + + expect(filledZeta).toEqual({ + difficulty: 5, + discrimination: 5, + guessing: 0, + slipping: 1, + }); + }); + + it('converts to symbolic format when desired', () => { + const zeta: Zeta = { + difficulty: 5, + discrimination: 5, + }; + + const filledZeta = fillZetaDefaults(zeta, 'symbolic'); + + expect(filledZeta).toEqual({ + a: 5, + b: 5, + c: 0, + d: 1, + }); + }); +}); + +describe('convertZeta', () => { + it('converts from symbolic format to semantic format', () => { + const zeta: Zeta = { + a: 1, + b: 2, + c: 3, + d: 4, + }; + + const convertedZeta = convertZeta(zeta, 'semantic'); + + 
expect(convertedZeta).toEqual({ + discrimination: 1, + difficulty: 2, + guessing: 3, + slipping: 4, + }); + }); + + it('converts from semantic format to symbolic format', () => { + const zeta: Zeta = { + discrimination: 1, + difficulty: 2, + guessing: 3, + slipping: 4, + }; + + const convertedZeta = convertZeta(zeta, 'symbolic'); + + expect(convertedZeta).toEqual({ + a: 1, + b: 2, + c: 3, + d: 4, + }); + }); + + it('throws an error when converting from an unsupported format', () => { + const zeta: Zeta = { + a: 1, + b: 2, + c: 3, + d: 4, + }; + + expect(() => convertZeta(zeta, 'unsupported' as 'symbolic')).toThrow( + "Invalid desired format. Expected 'symbolic' or'semantic'. Received unsupported instead.", + ); + }); + + it('does not modify other keys when converting', () => { + const zeta: Stimulus = { + a: 1, + b: 2, + c: 3, + d: 4, + key1: 5, + key2: 6, + key3: 7, + key4: 8, + }; + + const convertedZeta = convertZeta(zeta, 'semantic'); + + expect(convertedZeta).toEqual({ + discrimination: 1, + difficulty: 2, + guessing: 3, + slipping: 4, + key1: 5, + key2: 6, + key3: 7, + key4: 8, + }); + }); + + it('converts only existing keys', () => { + const zeta: Zeta = { + a: 1, + b: 2, + }; + + const convertedZeta = convertZeta(zeta, 'semantic'); + + expect(convertedZeta).toEqual({ + discrimination: 1, + difficulty: 2, + }); + }); +}); + +describe('checkNoDuplicateCatNames', () => { + it('should throw an error when a cat name is present in multiple zetas', () => { + const corpus: MultiZetaStimulus[] = [ + { + stimulus: 'Item 1', + zetas: [ + { cats: ['Model A', 'Model B'], zeta: { a: 1, b: 0.5, c: 0.2, d: 0.8 } }, + { cats: ['Model C'], zeta: { a: 2, b: 0.7, c: 0.3, d: 0.9 } }, + { cats: ['Model C'], zeta: { a: 1, b: 2, c: 0.3, d: 0.9 } }, + ], + }, + { + stimulus: 'Item 2', + zetas: [{ cats: ['Model A', 'Model C'], zeta: { a: 2.5, b: 0.8, c: 0.35, d: 0.95 } }], + }, + ]; + expect(() => checkNoDuplicateCatNames(corpus)).toThrowError( + 'The cat names Model C are present 
in multiple corpora.', + ); + }); + + it('should not throw an error when a cat name is not present in multiple corpora', () => { + const items: MultiZetaStimulus[] = [ + { + stimulus: 'Item 1', + zetas: [ + { cats: ['Model A', 'Model B'], zeta: { a: 1, b: 0.5, c: 0.2, d: 0.8 } }, + { cats: ['Model C'], zeta: { a: 2, b: 0.7, c: 0.3, d: 0.9 } }, + ], + }, + { + stimulus: 'Item 2', + zetas: [{ cats: ['Model B', 'Model C'], zeta: { a: 2.5, b: 0.8, c: 0.35, d: 0.95 } }], + }, + ]; + + expect(() => checkNoDuplicateCatNames(items)).not.toThrowError(); + }); + + it('should handle an empty corpus without throwing an error', () => { + const emptyCorpus: MultiZetaStimulus[] = []; + + expect(() => checkNoDuplicateCatNames(emptyCorpus)).not.toThrowError(); + }); +}); + +describe('filterItemsByCatParameterAvailability', () => { + it('returns an empty "available" array when no items match the catname', () => { + const items: MultiZetaStimulus[] = [ + { + stimulus: 'Item 1', + zetas: [ + { cats: ['Model A', 'Model B'], zeta: { a: 1, b: 0.5, c: 0.2, d: 0.8 } }, + { cats: ['Model C'], zeta: { a: 2, b: 0.7, c: 0.3, d: 0.9 } }, + ], + }, + { + stimulus: 'Item 2', + zetas: [{ cats: ['Model B', 'Model C'], zeta: { a: 2.5, b: 0.8, c: 0.35, d: 0.95 } }], + }, + ]; + + const result = filterItemsByCatParameterAvailability(items, 'Model D'); + + expect(result.available).toEqual([]); + expect(result.missing).toEqual(items); + }); + + it('returns empty missing array when all items match the catname', () => { + const items: MultiZetaStimulus[] = [ + { + stimulus: 'Item 1', + zetas: [ + { cats: ['Model A', 'Model B'], zeta: { a: 1, b: 0.5, c: 0.2, d: 0.8 } }, + { cats: ['Model A'], zeta: { a: 2, b: 0.7, c: 0.3, d: 0.9 } }, + ], + }, + { + stimulus: 'Item 2', + zetas: [ + { cats: ['Model A', 'Model C'], zeta: { a: 2.5, b: 0.8, c: 0.35, d: 0.95 } }, + { cats: ['Model A'], zeta: { a: 3, b: 0.9, c: 0.4, d: 0.99 } }, + ], + }, + ]; + + const result = filterItemsByCatParameterAvailability(items, 
'Model A'); + + expect(result.missing).toEqual([]); + expect(result.available).toEqual(items); + }); + + it('separates items based on matching catnames', () => { + const items: MultiZetaStimulus[] = [ + { + stimulus: 'Item 1', + zetas: [ + { cats: ['Model A', 'Model B'], zeta: { a: 1, b: 0.5, c: 0.2, d: 0.8 } }, + { cats: ['Model C'], zeta: { a: 2, b: 0.7, c: 0.3, d: 0.9 } }, + ], + }, + { + stimulus: 'Item 2', + zetas: [{ cats: ['Model B', 'Model C'], zeta: { a: 2.5, b: 0.8, c: 0.35, d: 0.95 } }], + }, + { + stimulus: 'Item 3', + zetas: [{ cats: ['Model A'], zeta: { a: 3, b: 0.9, c: 0.4, d: 0.99 } }], + }, + ]; + + const result = filterItemsByCatParameterAvailability(items, 'Model A'); + + // Assert + expect(result.available.length).toBe(2); + expect(result.available[0].stimulus).toBe('Item 1'); + expect(result.available[1].stimulus).toBe('Item 3'); + expect(result.missing.length).toBe(1); + expect(result.missing[0].stimulus).toBe('Item 2'); + }); +}); diff --git a/src/__tests__/utils.test.ts b/src/__tests__/utils.test.ts index 91d237e..8895e8e 100644 --- a/src/__tests__/utils.test.ts +++ b/src/__tests__/utils.test.ts @@ -1,17 +1,4 @@ -import { MultiZetaStimulus, Stimulus, Zeta } from '../type'; -import { - itemResponseFunction, - fisherInformation, - findClosest, - validateZetaParams, - ZETA_KEY_MAP, - defaultZeta, - fillZetaDefaults, - convertZeta, - checkNoDuplicateCatNames, - filterItemsByCatParameterAvailability, -} from '../utils'; -import _omit from 'lodash/omit'; +import { itemResponseFunction, fisherInformation, findClosest } from '../utils'; describe('itemResponseFunction', () => { it('correctly calculates the probability', () => { @@ -68,316 +55,3 @@ describe('findClosest', () => { expect(findClosest(stimuliWithDecimal, 9.1)).toBe(2); }); }); - -describe('validateZetaParams', () => { - it('throws an error when providing both a and discrimination', () => { - expect(() => validateZetaParams({ a: 1, discrimination: 1 })).toThrow( - 'This item has both an 
`a` key and `discrimination` key. Please provide only one.', - ); - }); - - it('throws an error when providing both b and difficulty', () => { - expect(() => validateZetaParams({ b: 1, difficulty: 1 })).toThrow( - 'This item has both a `b` key and `difficulty` key. Please provide only one.', - ); - }); - - it('throws an error when providing both c and guessing', () => { - expect(() => validateZetaParams({ c: 1, guessing: 1 })).toThrow( - 'This item has both a `c` key and `guessing` key. Please provide only one.', - ); - }); - - it('throws an error when providing both d and slipping', () => { - expect(() => validateZetaParams({ d: 1, slipping: 1 })).toThrow( - 'This item has both a `d` key and `slipping` key. Please provide only one.', - ); - }); - - it('throws an error when requiring all keys and missing one', () => { - for (const key of ['a', 'b', 'c', 'd'] as (keyof typeof ZETA_KEY_MAP)[]) { - const semanticKey = ZETA_KEY_MAP[key]; - const zeta = _omit(defaultZeta('symbolic'), [key]); - - expect(() => validateZetaParams(zeta, true)).toThrow( - `This item is missing the key \`${String(key)}\` or \`${semanticKey}\`.`, - ); - } - }); -}); - -describe('fillZetaDefaults', () => { - it('fills in default values for missing keys', () => { - const zeta: Zeta = { - difficulty: 1, - guessing: 0.5, - }; - - const filledZeta = fillZetaDefaults(zeta, 'semantic'); - - expect(filledZeta).toEqual({ - discrimination: 1, - difficulty: 1, - guessing: 0.5, - slipping: 1, - }); - }); - - it('does not modify the input object when no missing keys', () => { - const zeta: Zeta = { - a: 5, - b: 5, - c: 5, - d: 5, - }; - - const filledZeta = fillZetaDefaults(zeta, 'symbolic'); - - expect(filledZeta).toEqual(zeta); - }); - - it('converts to semantic format when desired', () => { - const zeta: Zeta = { - a: 5, - b: 5, - }; - - const filledZeta = fillZetaDefaults(zeta, 'semantic'); - - expect(filledZeta).toEqual({ - difficulty: 5, - discrimination: 5, - guessing: 0, - slipping: 1, - }); - }); 
- - it('converts to symbolic format when desired', () => { - const zeta: Zeta = { - difficulty: 5, - discrimination: 5, - }; - - const filledZeta = fillZetaDefaults(zeta, 'symbolic'); - - expect(filledZeta).toEqual({ - a: 5, - b: 5, - c: 0, - d: 1, - }); - }); -}); - -describe('convertZeta', () => { - it('converts from symbolic format to semantic format', () => { - const zeta: Zeta = { - a: 1, - b: 2, - c: 3, - d: 4, - }; - - const convertedZeta = convertZeta(zeta, 'semantic'); - - expect(convertedZeta).toEqual({ - discrimination: 1, - difficulty: 2, - guessing: 3, - slipping: 4, - }); - }); - - it('converts from semantic format to symbolic format', () => { - const zeta: Zeta = { - discrimination: 1, - difficulty: 2, - guessing: 3, - slipping: 4, - }; - - const convertedZeta = convertZeta(zeta, 'symbolic'); - - expect(convertedZeta).toEqual({ - a: 1, - b: 2, - c: 3, - d: 4, - }); - }); - - it('throws an error when converting from an unsupported format', () => { - const zeta: Zeta = { - a: 1, - b: 2, - c: 3, - d: 4, - }; - - expect(() => convertZeta(zeta, 'unsupported' as 'symbolic')).toThrow( - "Invalid desired format. Expected 'symbolic' or'semantic'. 
Received unsupported instead.", - ); - }); - - it('does not modify other keys when converting', () => { - const zeta: Stimulus = { - a: 1, - b: 2, - c: 3, - d: 4, - key1: 5, - key2: 6, - key3: 7, - key4: 8, - }; - - const convertedZeta = convertZeta(zeta, 'semantic'); - - expect(convertedZeta).toEqual({ - discrimination: 1, - difficulty: 2, - guessing: 3, - slipping: 4, - key1: 5, - key2: 6, - key3: 7, - key4: 8, - }); - }); - - it('converts only existing keys', () => { - const zeta: Zeta = { - a: 1, - b: 2, - }; - - const convertedZeta = convertZeta(zeta, 'semantic'); - - expect(convertedZeta).toEqual({ - discrimination: 1, - difficulty: 2, - }); - }); -}); - -describe('checkNoDuplicateCatNames', () => { - it('should throw an error when a cat name is present in multiple zetas', () => { - const corpus: MultiZetaStimulus[] = [ - { - stimulus: 'Item 1', - zetas: [ - { cats: ['Model A', 'Model B'], zeta: { a: 1, b: 0.5, c: 0.2, d: 0.8 } }, - { cats: ['Model C'], zeta: { a: 2, b: 0.7, c: 0.3, d: 0.9 } }, - { cats: ['Model C'], zeta: { a: 1, b: 2, c: 0.3, d: 0.9 } }, - ], - }, - { - stimulus: 'Item 2', - zetas: [{ cats: ['Model A', 'Model C'], zeta: { a: 2.5, b: 0.8, c: 0.35, d: 0.95 } }], - }, - ]; - expect(() => checkNoDuplicateCatNames(corpus)).toThrowError( - 'The cat names Model C are present in multiple corpora.', - ); - }); - - it('should not throw an error when a cat name is not present in multiple corpora', () => { - const items: MultiZetaStimulus[] = [ - { - stimulus: 'Item 1', - zetas: [ - { cats: ['Model A', 'Model B'], zeta: { a: 1, b: 0.5, c: 0.2, d: 0.8 } }, - { cats: ['Model C'], zeta: { a: 2, b: 0.7, c: 0.3, d: 0.9 } }, - ], - }, - { - stimulus: 'Item 2', - zetas: [{ cats: ['Model B', 'Model C'], zeta: { a: 2.5, b: 0.8, c: 0.35, d: 0.95 } }], - }, - ]; - - expect(() => checkNoDuplicateCatNames(items)).not.toThrowError(); - }); - - it('should handle an empty corpus without throwing an error', () => { - const emptyCorpus: MultiZetaStimulus[] = []; - - 
expect(() => checkNoDuplicateCatNames(emptyCorpus)).not.toThrowError(); - }); -}); - -describe('filterItemsByCatParameterAvailability', () => { - it('returns an empty "available" array when no items match the catname', () => { - const items: MultiZetaStimulus[] = [ - { - stimulus: 'Item 1', - zetas: [ - { cats: ['Model A', 'Model B'], zeta: { a: 1, b: 0.5, c: 0.2, d: 0.8 } }, - { cats: ['Model C'], zeta: { a: 2, b: 0.7, c: 0.3, d: 0.9 } }, - ], - }, - { - stimulus: 'Item 2', - zetas: [{ cats: ['Model B', 'Model C'], zeta: { a: 2.5, b: 0.8, c: 0.35, d: 0.95 } }], - }, - ]; - - const result = filterItemsByCatParameterAvailability(items, 'Model D'); - - expect(result.available).toEqual([]); - expect(result.missing).toEqual(items); - }); - - it('returns empty missing array when all items match the catname', () => { - const items: MultiZetaStimulus[] = [ - { - stimulus: 'Item 1', - zetas: [ - { cats: ['Model A', 'Model B'], zeta: { a: 1, b: 0.5, c: 0.2, d: 0.8 } }, - { cats: ['Model A'], zeta: { a: 2, b: 0.7, c: 0.3, d: 0.9 } }, - ], - }, - { - stimulus: 'Item 2', - zetas: [ - { cats: ['Model A', 'Model C'], zeta: { a: 2.5, b: 0.8, c: 0.35, d: 0.95 } }, - { cats: ['Model A'], zeta: { a: 3, b: 0.9, c: 0.4, d: 0.99 } }, - ], - }, - ]; - - const result = filterItemsByCatParameterAvailability(items, 'Model A'); - - expect(result.missing).toEqual([]); - expect(result.available).toEqual(items); - }); - - it('separates items based on matching catnames', () => { - const items: MultiZetaStimulus[] = [ - { - stimulus: 'Item 1', - zetas: [ - { cats: ['Model A', 'Model B'], zeta: { a: 1, b: 0.5, c: 0.2, d: 0.8 } }, - { cats: ['Model C'], zeta: { a: 2, b: 0.7, c: 0.3, d: 0.9 } }, - ], - }, - { - stimulus: 'Item 2', - zetas: [{ cats: ['Model B', 'Model C'], zeta: { a: 2.5, b: 0.8, c: 0.35, d: 0.95 } }], - }, - { - stimulus: 'Item 3', - zetas: [{ cats: ['Model A'], zeta: { a: 3, b: 0.9, c: 0.4, d: 0.99 } }], - }, - ]; - - const result = filterItemsByCatParameterAvailability(items, 
'Model A'); - - // Assert - expect(result.available.length).toBe(2); - expect(result.available[0].stimulus).toBe('Item 1'); - expect(result.available[1].stimulus).toBe('Item 3'); - expect(result.missing.length).toBe(1); - expect(result.missing[0].stimulus).toBe('Item 2'); - }); -}); diff --git a/src/cat.ts b/src/cat.ts new file mode 100644 index 0000000..47a86f9 --- /dev/null +++ b/src/cat.ts @@ -0,0 +1,319 @@ +/* eslint-disable @typescript-eslint/no-non-null-assertion */ +import { minimize_Powell } from 'optimization-js'; +import { Stimulus, Zeta } from './type'; +import { itemResponseFunction, fisherInformation, normal, findClosest } from './utils'; +import { validateZetaParams, fillZetaDefaults } from './corpus'; +import seedrandom from 'seedrandom'; +import _clamp from 'lodash/clamp'; +import _cloneDeep from 'lodash/cloneDeep'; + +export const abilityPrior = normal(); + +export interface CatInput { + method?: string; + itemSelect?: string; + nStartItems?: number; + startSelect?: string; + theta?: number; + minTheta?: number; + maxTheta?: number; + prior?: number[][]; + randomSeed?: string | null; +} + +export class Cat { + public method: string; + public itemSelect: string; + public minTheta: number; + public maxTheta: number; + public prior: number[][]; + private readonly _zetas: Zeta[]; + private readonly _resps: (0 | 1)[]; + private _theta: number; + private _seMeasurement: number; + public nStartItems: number; + public startSelect: string; + private readonly _rng: ReturnType; + + /** + * Create a Cat object. This expects an single object parameter with the following keys + * @param {{method: string, itemSelect: string, nStartItems: number, startSelect:string, theta: number, minTheta: number, maxTheta: number, prior: number[][]}=} destructuredParam + * method: ability estimator, e.g. MLE or EAP, default = 'MLE' + * itemSelect: the method of item selection, e.g. 
"MFI", "random", "closest", default method = 'MFI' + * nStartItems: first n trials to keep non-adaptive selection + * startSelect: rule to select first n trials + * theta: initial theta estimate + * minTheta: lower bound of theta + * maxTheta: higher bound of theta + * prior: the prior distribution + * randomSeed: set a random seed to trace the simulation + */ + + constructor({ + method = 'MLE', + itemSelect = 'MFI', + nStartItems = 0, + startSelect = 'middle', + theta = 0, + minTheta = -6, + maxTheta = 6, + prior = abilityPrior, + randomSeed = null, + }: CatInput = {}) { + this.method = Cat.validateMethod(method); + + this.itemSelect = Cat.validateItemSelect(itemSelect); + + this.startSelect = Cat.validateStartSelect(startSelect); + + this.minTheta = minTheta; + this.maxTheta = maxTheta; + this.prior = prior; + this._zetas = []; + this._resps = []; + this._theta = theta; + this._seMeasurement = Number.MAX_VALUE; + this.nStartItems = nStartItems; + this._rng = randomSeed === null ? seedrandom() : seedrandom(randomSeed); + } + + public get theta() { + return this._theta; + } + + public get seMeasurement() { + return this._seMeasurement; + } + + /** + * Return the number of items that have been observed so far. 
+ */ + public get nItems() { + return this._resps.length; + } + + public get resps() { + return this._resps; + } + + public get zetas() { + return this._zetas; + } + + private static validateMethod(method: string) { + const lowerMethod = method.toLowerCase(); + const validMethods: Array = ['mle', 'eap']; // TO DO: add staircase + if (!validMethods.includes(lowerMethod)) { + throw new Error('The abilityEstimator you provided is not in the list of valid methods'); + } + return lowerMethod; + } + + private static validateItemSelect(itemSelect: string) { + const lowerItemSelect = itemSelect.toLowerCase(); + const validItemSelect: Array = ['mfi', 'random', 'closest', 'fixed']; + if (!validItemSelect.includes(lowerItemSelect)) { + throw new Error('The itemSelector you provided is not in the list of valid methods'); + } + return lowerItemSelect; + } + + private static validateStartSelect(startSelect: string) { + const lowerStartSelect = startSelect.toLowerCase(); + const validStartSelect: Array = ['random', 'middle', 'fixed']; // TO DO: add staircase + if (!validStartSelect.includes(lowerStartSelect)) { + throw new Error('The startSelect you provided is not in the list of valid methods'); + } + return lowerStartSelect; + } + + /** + * use previous response patterns and item params to calculate the estimate ability based on a defined method + * @param zeta - last item param + * @param answer - last response pattern + * @param method + */ + public updateAbilityEstimate(zeta: Zeta | Zeta[], answer: (0 | 1) | (0 | 1)[], method: string = this.method) { + method = Cat.validateMethod(method); + + zeta = Array.isArray(zeta) ? zeta : [zeta]; + answer = Array.isArray(answer) ? 
answer : [answer]; + + zeta.forEach((z) => validateZetaParams(z, true)); + + if (zeta.length !== answer.length) { + throw new Error('Unmatched length between answers and item params'); + } + this._zetas.push(...zeta); + this._resps.push(...answer); + + if (method === 'eap') { + this._theta = this.estimateAbilityEAP(); + } else if (method === 'mle') { + this._theta = this.estimateAbilityMLE(); + } + this.calculateSE(); + } + + private estimateAbilityEAP() { + let num = 0; + let nf = 0; + this.prior.forEach(([theta, probability]) => { + const like = this.likelihood(theta); + num += theta * like * probability; + nf += like * probability; + }); + + return num / nf; + } + + private estimateAbilityMLE() { + const theta0 = [0]; + const solution = minimize_Powell(this.negLikelihood.bind(this), theta0); + const theta = solution.argument[0]; + return _clamp(theta, this.minTheta, this.maxTheta); + } + + private negLikelihood(thetaArray: Array) { + return -this.likelihood(thetaArray[0]); + } + + private likelihood(theta: number) { + return this._zetas.reduce((acc, zeta, i) => { + const irf = itemResponseFunction(theta, zeta); + return this._resps[i] === 1 ? 
acc + Math.log(irf) : acc + Math.log(1 - irf); + }, 1); + } + + /** + * calculate the standard error of ability estimation + */ + private calculateSE() { + const sum = this._zetas.reduce((previousValue, zeta) => previousValue + fisherInformation(this._theta, zeta), 0); + this._seMeasurement = 1 / Math.sqrt(sum); + } + + /** + * find the next available item from an input array of stimuli based on a selection method + * + * remainingStimuli is sorted by fisher information to reduce the computation complexity for future item selection + * @param stimuli - an array of stimulus + * @param itemSelect - the item selection method + * @param deepCopy - default deepCopy = true + * @returns {nextStimulus: Stimulus, remainingStimuli: Array} + */ + public findNextItem(stimuli: Stimulus[], itemSelect: string = this.itemSelect, deepCopy = true) { + let arr: Array; + let selector = Cat.validateItemSelect(itemSelect); + if (deepCopy) { + arr = _cloneDeep(stimuli); + } else { + arr = stimuli; + } + + arr = arr.map((stim) => fillZetaDefaults(stim, 'semantic')); + + if (this.nItems < this.nStartItems) { + selector = this.startSelect; + } + if (selector !== 'mfi' && selector !== 'fixed') { + // for mfi, we sort the arr by fisher information in the private function to select the best item, + // and then sort by difficulty to return the remainingStimuli + // for fixed, we want to keep the corpus order as input + arr.sort((a: Stimulus, b: Stimulus) => a.difficulty! 
- b.difficulty!); + } + + if (selector === 'middle') { + // middle will only be used in startSelect + return this.selectorMiddle(arr); + } else if (selector === 'closest') { + return this.selectorClosest(arr); + } else if (selector === 'random') { + return this.selectorRandom(arr); + } else if (selector === 'fixed') { + return this.selectorFixed(arr); + } else { + return this.selectorMFI(arr); + } + } + + private selectorMFI(inputStimuli: Stimulus[]) { + const stimuli = inputStimuli.map((stim) => fillZetaDefaults(stim, 'semantic')); + const stimuliAddFisher = stimuli.map((element: Stimulus) => ({ + fisherInformation: fisherInformation(this._theta, fillZetaDefaults(element, 'symbolic')), + ...element, + })); + + stimuliAddFisher.sort((a, b) => b.fisherInformation - a.fisherInformation); + stimuliAddFisher.forEach((stimulus: Stimulus) => { + delete stimulus['fisherInformation']; + }); + return { + nextStimulus: stimuliAddFisher[0], + remainingStimuli: stimuliAddFisher.slice(1).sort((a: Stimulus, b: Stimulus) => a.difficulty! 
- b.difficulty!), + }; + } + + private selectorMiddle(arr: Stimulus[]) { + let index: number; + index = Math.floor(arr.length / 2); + + if (arr.length >= this.nStartItems) { + index += this.randomInteger(-Math.floor(this.nStartItems / 2), Math.floor(this.nStartItems / 2)); + } + + const nextItem = arr[index]; + arr.splice(index, 1); + return { + nextStimulus: nextItem, + remainingStimuli: arr, + }; + } + + private selectorClosest(arr: Stimulus[]) { + //findClosest requires arr is sorted by difficulty + const index = findClosest(arr, this._theta + 0.481); + const nextItem = arr[index]; + arr.splice(index, 1); + return { + nextStimulus: nextItem, + remainingStimuli: arr, + }; + } + + private selectorRandom(arr: Stimulus[]) { + const index = this.randomInteger(0, arr.length - 1); + const nextItem = arr.splice(index, 1)[0]; + return { + nextStimulus: nextItem, + remainingStimuli: arr, + }; + } + + /** + * Picks the next item in line from the given list of stimuli. + * It grabs the first item from the list, removes it, and then returns it along with the rest of the list. + * + * @param arr - The list of stimuli to choose from. + * @returns {Object} - An object with the next item and the updated list. + * @returns {Stimulus} return.nextStimulus - The item that was picked from the list. + * @returns {Stimulus[]} return.remainingStimuli - The list of what's left after picking the item. 
+ */ + private selectorFixed(arr: Stimulus[]) { + const nextItem = arr.shift(); + return { + nextStimulus: nextItem, + remainingStimuli: arr, + }; + } + + /** + * Return a random integer between min and max, both inclusive + * @param min - The minimum of the random number range (inclusive) + * @param max - The maximum of the random number range (inclusive) + * @returns {number} - random integer within the range + */ + private randomInteger(min: number, max: number) { + return Math.floor(this._rng() * (max - min + 1)) + min; + } +} diff --git a/src/clowder.ts b/src/clowder.ts index 64c79f7..fc1b4a5 100644 --- a/src/clowder.ts +++ b/src/clowder.ts @@ -1,6 +1,6 @@ -import { Cat, CatInput } from './index'; +import { Cat, CatInput } from './cat'; import { MultiZetaStimulus, Stimulus, Zeta, ZetaCatMap } from './type'; -import { filterItemsByCatParameterAvailability, checkNoDuplicateCatNames } from './utils'; +import { filterItemsByCatParameterAvailability, checkNoDuplicateCatNames } from './corpus'; import _cloneDeep from 'lodash/cloneDeep'; import _differenceWith from 'lodash/differenceWith'; import _isEqual from 'lodash/isEqual'; diff --git a/src/corpus.ts b/src/corpus.ts new file mode 100644 index 0000000..2df7aa7 --- /dev/null +++ b/src/corpus.ts @@ -0,0 +1,246 @@ +/* eslint-disable @typescript-eslint/no-non-null-assertion */ +import { MultiZetaStimulus, Zeta } from './type'; +import _flatten from 'lodash/flatten'; +import _invert from 'lodash/invert'; +import _mapKeys from 'lodash/mapKeys'; +import _union from 'lodash/union'; +import _uniq from 'lodash/uniq'; + +/** + * A constant map from the symbolic item parameter names to their semantic + * counterparts. + */ +export const ZETA_KEY_MAP = { + a: 'discrimination', + b: 'difficulty', + c: 'guessing', + d: 'slipping', +}; + +/** + * Return default item parameters (i.e., zeta) + * + * @param {'symbolic' | 'semantic'} desiredFormat - The desired format for the output zeta object.
+ * @returns {Zeta} the default zeta object in the specified format. + */ +export const defaultZeta = (desiredFormat: 'symbolic' | 'semantic' = 'symbolic'): Zeta => { + const defaultZeta: Zeta = { + a: 1, + b: 0, + c: 0, + d: 1, + }; + + return convertZeta(defaultZeta, desiredFormat); +}; + +/** + * Validates the item (a.k.a. zeta) parameters, prohibiting redundant keys and + * optionally requiring all parameters. + * + * @param {Zeta} zeta - The zeta parameters to validate. + * @param {boolean} requireAll - If `true`, ensures that all required keys are present. Default is `false`. + * + * @throws {Error} Will throw an error if any of the validation rules are violated. + */ +export const validateZetaParams = (zeta: Zeta, requireAll = false): void => { + if (zeta.a !== undefined && zeta.discrimination !== undefined) { + throw new Error('This item has both an `a` key and `discrimination` key. Please provide only one.'); + } + + if (zeta.b !== undefined && zeta.difficulty !== undefined) { + throw new Error('This item has both a `b` key and `difficulty` key. Please provide only one.'); + } + + if (zeta.c !== undefined && zeta.guessing !== undefined) { + throw new Error('This item has both a `c` key and `guessing` key. Please provide only one.'); + } + + if (zeta.d !== undefined && zeta.slipping !== undefined) { + throw new Error('This item has both a `d` key and `slipping` key. 
Please provide only one.'); + } + + if (requireAll) { + if (zeta.a === undefined && zeta.discrimination === undefined) { + throw new Error('This item is missing the key `a` or `discrimination`.'); + } + + if (zeta.b === undefined && zeta.difficulty === undefined) { + throw new Error('This item is missing the key `b` or `difficulty`.'); + } + + if (zeta.c === undefined && zeta.guessing === undefined) { + throw new Error('This item is missing the key `c` or `guessing`.'); + } + + if (zeta.d === undefined && zeta.slipping === undefined) { + throw new Error('This item is missing the key `d` or `slipping`.'); + } + } +}; + +/** + * Fills in default zeta parameters for any missing keys in the provided zeta object. + * + * @remarks + * This function merges the provided zeta object with the default zeta object, converting + * the keys to the desired format if specified. If no desired format is provided, the + * keys will remain in their original format. + * + * @param {Zeta} zeta - The zeta parameters to fill in defaults for. + * @param {'symbolic' | 'semantic'} desiredFormat - The desired format for the output zeta object. Default is 'symbolic'. + * + * @returns A new zeta object with default values filled in for any missing keys, + * and converted to the desired format if specified. + */ +export const fillZetaDefaults = (zeta: Zeta, desiredFormat: 'symbolic' | 'semantic' = 'symbolic'): Zeta => { + return { + ...defaultZeta(desiredFormat), + ...convertZeta(zeta, desiredFormat), + }; +}; + +/** + * Converts zeta parameters between symbolic and semantic formats. + * + * @remarks + * This function takes a zeta object and a desired format as input. It converts + * the keys of the zeta object from their current format to the desired format. + * If the desired format is 'symbolic', the function maps the keys to their + * symbolic counterparts using the `ZETA_KEY_MAP`. 
If the desired format is + * 'semantic', the function maps the keys to their semantic counterparts using + * the inverse of `ZETA_KEY_MAP`. + * + * @param {Zeta} zeta - The zeta parameters to convert. + * @param {'symbolic' | 'semantic'} desiredFormat - The desired format for the output zeta object. Must be either 'symbolic' or 'semantic'. + * + * @throws {Error} - Will throw an error if the desired format is not 'symbolic' or 'semantic'. + * + * @returns {Zeta} A new zeta object with keys converted to the desired format. + */ +export const convertZeta = (zeta: Zeta, desiredFormat: 'symbolic' | 'semantic'): Zeta => { + if (!['symbolic', 'semantic'].includes(desiredFormat)) { + throw new Error(`Invalid desired format. Expected 'symbolic' or'semantic'. Received ${desiredFormat} instead.`); + } + + return _mapKeys(zeta, (value, key) => { + if (desiredFormat === 'symbolic') { + const inverseMap = _invert(ZETA_KEY_MAP); + if (key in inverseMap) { + return inverseMap[key]; + } else { + return key; + } + } else { + if (key in ZETA_KEY_MAP) { + return ZETA_KEY_MAP[key as keyof typeof ZETA_KEY_MAP]; + } else { + return key; + } + } + }); +}; + +/** + * Validates a corpus of multi-zeta stimuli to ensure that no cat names are + * duplicated. + * + * @remarks + * This function takes an array of `MultiZetaStimulus` objects, where each + * object represents an item containing item parameters (zetas) associated with + * different CAT models. The function checks for any duplicate cat names across + * each item's array of zeta values. It throws an error if any are found. + * + * @param {MultiZetaStimulus[]} corpus - An array of `MultiZetaStimulus` objects representing the corpora to validate. + * + * @throws {Error} - Throws an error if any duplicate cat names are found across the corpora. 
+ */ +export const checkNoDuplicateCatNames = (corpus: MultiZetaStimulus[]): void => { + const zetaCatMapsArray = corpus.map((item) => item.zetas); + for (const zetaCatMaps of zetaCatMapsArray) { + const cats = zetaCatMaps.map(({ cats }) => cats); + + // Check to see if there are any duplicate names by comparing the union + // (which removed duplicates) to the flattened array. + const union = _union(...cats); + const flattened = _flatten(cats); + + if (union.length !== flattened.length) { + // If there are duplicates, remove the first occurence of each cat name in + // the union array from the flattened array. The remaining items in the + // flattened array should contain the duplicated cat names. + for (const cat of union) { + const idx = flattened.findIndex((c) => c === cat); + if (idx >= 0) { + flattened.splice(idx, 1); + } + } + + throw new Error(`The cat names ${_uniq(flattened).join(', ')} are present in multiple corpora.`); + } + } +}; + +/** + * Filters a list of multi-zeta stimuli based on the availability of model parameters for a specific CAT. + * + * This function takes an array of `MultiZetaStimulus` objects and a `catName` as input. It then filters + * the items based on whether the specified CAT model parameter is present in the item's zeta values. + * The function returns an object containing two arrays: `available` and `missing`. The `available` array + * contains items where the specified CAT model parameter is present, while the `missing` array contains + * items where the parameter is not present. + * + * @param {MultiZetaStimulus[]} items - An array of `MultiZetaStimulus` objects representing the stimuli to filter. + * @param {string} catName - The name of the CAT model parameter to check for. + * + * @returns An object with two arrays: `available` and `missing`. 
+ * + * @example + * ```typescript + * const items: MultiZetaStimulus[] = [ + * { + * stimulus: 'Item 1', + * zetas: [ + * { cats: ['Model A', 'Model B'], zeta: { a: 1, b: 0.5, c: 0.2, d: 0.8 } }, + * { cats: ['Model C'], zeta: { a: 2, b: 0.7, c: 0.3, d: 0.9 } }, + * ], + * }, + * { + * stimulus: 'Item 2', + * zetas: [ + * { cats: ['Model B', 'Model C'], zeta: { a: 2.5, b: 0.8, c: 0.35, d: 0.95 } }, + * ], + * }, + * ]; + * + * const result = filterItemsByCatParameterAvailability(items, 'Model A'); + * console.log(result.available); + * // Output: [ + * // { + * // stimulus: 'Item 1', + * // zetas: [ + * // { cats: ['Model A', 'Model B'], zeta: { a: 1, b: 0.5, c: 0.2, d: 0.8 } }, + * // { cats: ['Model C'], zeta: { a: 2, b: 0.7, c: 0.3, d: 0.9 } }, + * // ], + * // }, + * // ] + * console.log(result.missing); + * // Output: [ + * // { + * // stimulus: 'Item 2', + * // zetas: [ + * // { cats: ['Model B', 'Model C'], zeta: { a: 2.5, b: 0.8, c: 0.35, d: 0.95 } }, + * // ], + * // }, + * // ] + * ``` + */ +export const filterItemsByCatParameterAvailability = (items: MultiZetaStimulus[], catName: string) => { + const paramsExist = items.filter((item) => item.zetas.some((zetaCatMap) => zetaCatMap.cats.includes(catName))); + const paramsMissing = items.filter((item) => !item.zetas.some((zetaCatMap) => zetaCatMap.cats.includes(catName))); + + return { + available: paramsExist, + missing: paramsMissing, + }; +}; diff --git a/src/index.ts b/src/index.ts index 692a33a..00c4cbe 100644 --- a/src/index.ts +++ b/src/index.ts @@ -1,325 +1,2 @@ -/* eslint-disable @typescript-eslint/no-non-null-assertion */ -import { minimize_Powell } from 'optimization-js'; -import { Stimulus, Zeta } from './type'; -import { - itemResponseFunction, - fisherInformation, - normal, - findClosest, - validateZetaParams, - fillZetaDefaults, -} from './utils'; -import seedrandom from 'seedrandom'; -import _clamp from 'lodash/clamp'; -import _cloneDeep from 'lodash/cloneDeep'; - -export const abilityPrior 
= normal(); - -export interface CatInput { - method?: string; - itemSelect?: string; - nStartItems?: number; - startSelect?: string; - theta?: number; - minTheta?: number; - maxTheta?: number; - prior?: number[][]; - randomSeed?: string | null; -} - -export class Cat { - public method: string; - public itemSelect: string; - public minTheta: number; - public maxTheta: number; - public prior: number[][]; - private readonly _zetas: Zeta[]; - private readonly _resps: (0 | 1)[]; - private _theta: number; - private _seMeasurement: number; - public nStartItems: number; - public startSelect: string; - private readonly _rng: ReturnType; - - /** - * Create a Cat object. This expects an single object parameter with the following keys - * @param {{method: string, itemSelect: string, nStartItems: number, startSelect:string, theta: number, minTheta: number, maxTheta: number, prior: number[][]}=} destructuredParam - * method: ability estimator, e.g. MLE or EAP, default = 'MLE' - * itemSelect: the method of item selection, e.g. 
"MFI", "random", "closest", default method = 'MFI' - * nStartItems: first n trials to keep non-adaptive selection - * startSelect: rule to select first n trials - * theta: initial theta estimate - * minTheta: lower bound of theta - * maxTheta: higher bound of theta - * prior: the prior distribution - * randomSeed: set a random seed to trace the simulation - */ - - constructor({ - method = 'MLE', - itemSelect = 'MFI', - nStartItems = 0, - startSelect = 'middle', - theta = 0, - minTheta = -6, - maxTheta = 6, - prior = abilityPrior, - randomSeed = null, - }: CatInput = {}) { - this.method = Cat.validateMethod(method); - - this.itemSelect = Cat.validateItemSelect(itemSelect); - - this.startSelect = Cat.validateStartSelect(startSelect); - - this.minTheta = minTheta; - this.maxTheta = maxTheta; - this.prior = prior; - this._zetas = []; - this._resps = []; - this._theta = theta; - this._seMeasurement = Number.MAX_VALUE; - this.nStartItems = nStartItems; - this._rng = randomSeed === null ? seedrandom() : seedrandom(randomSeed); - } - - public get theta() { - return this._theta; - } - - public get seMeasurement() { - return this._seMeasurement; - } - - /** - * Return the number of items that have been observed so far. 
- */ - public get nItems() { - return this._resps.length; - } - - public get resps() { - return this._resps; - } - - public get zetas() { - return this._zetas; - } - - private static validateMethod(method: string) { - const lowerMethod = method.toLowerCase(); - const validMethods: Array = ['mle', 'eap']; // TO DO: add staircase - if (!validMethods.includes(lowerMethod)) { - throw new Error('The abilityEstimator you provided is not in the list of valid methods'); - } - return lowerMethod; - } - - private static validateItemSelect(itemSelect: string) { - const lowerItemSelect = itemSelect.toLowerCase(); - const validItemSelect: Array = ['mfi', 'random', 'closest', 'fixed']; - if (!validItemSelect.includes(lowerItemSelect)) { - throw new Error('The itemSelector you provided is not in the list of valid methods'); - } - return lowerItemSelect; - } - - private static validateStartSelect(startSelect: string) { - const lowerStartSelect = startSelect.toLowerCase(); - const validStartSelect: Array = ['random', 'middle', 'fixed']; // TO DO: add staircase - if (!validStartSelect.includes(lowerStartSelect)) { - throw new Error('The startSelect you provided is not in the list of valid methods'); - } - return lowerStartSelect; - } - - /** - * use previous response patterns and item params to calculate the estimate ability based on a defined method - * @param zeta - last item param - * @param answer - last response pattern - * @param method - */ - public updateAbilityEstimate(zeta: Zeta | Zeta[], answer: (0 | 1) | (0 | 1)[], method: string = this.method) { - method = Cat.validateMethod(method); - - zeta = Array.isArray(zeta) ? zeta : [zeta]; - answer = Array.isArray(answer) ? 
answer : [answer]; - - zeta.forEach((z) => validateZetaParams(z, true)); - - if (zeta.length !== answer.length) { - throw new Error('Unmatched length between answers and item params'); - } - this._zetas.push(...zeta); - this._resps.push(...answer); - - if (method === 'eap') { - this._theta = this.estimateAbilityEAP(); - } else if (method === 'mle') { - this._theta = this.estimateAbilityMLE(); - } - this.calculateSE(); - } - - private estimateAbilityEAP() { - let num = 0; - let nf = 0; - this.prior.forEach(([theta, probability]) => { - const like = this.likelihood(theta); - num += theta * like * probability; - nf += like * probability; - }); - - return num / nf; - } - - private estimateAbilityMLE() { - const theta0 = [0]; - const solution = minimize_Powell(this.negLikelihood.bind(this), theta0); - const theta = solution.argument[0]; - return _clamp(theta, this.minTheta, this.maxTheta); - } - - private negLikelihood(thetaArray: Array) { - return -this.likelihood(thetaArray[0]); - } - - private likelihood(theta: number) { - return this._zetas.reduce((acc, zeta, i) => { - const irf = itemResponseFunction(theta, zeta); - return this._resps[i] === 1 ? 
acc + Math.log(irf) : acc + Math.log(1 - irf); - }, 1); - } - - /** - * calculate the standard error of ability estimation - */ - private calculateSE() { - const sum = this._zetas.reduce((previousValue, zeta) => previousValue + fisherInformation(this._theta, zeta), 0); - this._seMeasurement = 1 / Math.sqrt(sum); - } - - /** - * find the next available item from an input array of stimuli based on a selection method - * - * remainingStimuli is sorted by fisher information to reduce the computation complexity for future item selection - * @param stimuli - an array of stimulus - * @param itemSelect - the item selection method - * @param deepCopy - default deepCopy = true - * @returns {nextStimulus: Stimulus, remainingStimuli: Array} - */ - public findNextItem(stimuli: Stimulus[], itemSelect: string = this.itemSelect, deepCopy = true) { - let arr: Array; - let selector = Cat.validateItemSelect(itemSelect); - if (deepCopy) { - arr = _cloneDeep(stimuli); - } else { - arr = stimuli; - } - - arr = arr.map((stim) => fillZetaDefaults(stim, 'semantic')); - - if (this.nItems < this.nStartItems) { - selector = this.startSelect; - } - if (selector !== 'mfi' && selector !== 'fixed') { - // for mfi, we sort the arr by fisher information in the private function to select the best item, - // and then sort by difficulty to return the remainingStimuli - // for fixed, we want to keep the corpus order as input - arr.sort((a: Stimulus, b: Stimulus) => a.difficulty! 
- b.difficulty!); - } - - if (selector === 'middle') { - // middle will only be used in startSelect - return this.selectorMiddle(arr); - } else if (selector === 'closest') { - return this.selectorClosest(arr); - } else if (selector === 'random') { - return this.selectorRandom(arr); - } else if (selector === 'fixed') { - return this.selectorFixed(arr); - } else { - return this.selectorMFI(arr); - } - } - - private selectorMFI(inputStimuli: Stimulus[]) { - const stimuli = inputStimuli.map((stim) => fillZetaDefaults(stim, 'semantic')); - const stimuliAddFisher = stimuli.map((element: Stimulus) => ({ - fisherInformation: fisherInformation(this._theta, fillZetaDefaults(element, 'symbolic')), - ...element, - })); - - stimuliAddFisher.sort((a, b) => b.fisherInformation - a.fisherInformation); - stimuliAddFisher.forEach((stimulus: Stimulus) => { - delete stimulus['fisherInformation']; - }); - return { - nextStimulus: stimuliAddFisher[0], - remainingStimuli: stimuliAddFisher.slice(1).sort((a: Stimulus, b: Stimulus) => a.difficulty! 
- b.difficulty!), - }; - } - - private selectorMiddle(arr: Stimulus[]) { - let index: number; - index = Math.floor(arr.length / 2); - - if (arr.length >= this.nStartItems) { - index += this.randomInteger(-Math.floor(this.nStartItems / 2), Math.floor(this.nStartItems / 2)); - } - - const nextItem = arr[index]; - arr.splice(index, 1); - return { - nextStimulus: nextItem, - remainingStimuli: arr, - }; - } - - private selectorClosest(arr: Stimulus[]) { - //findClosest requires arr is sorted by difficulty - const index = findClosest(arr, this._theta + 0.481); - const nextItem = arr[index]; - arr.splice(index, 1); - return { - nextStimulus: nextItem, - remainingStimuli: arr, - }; - } - - private selectorRandom(arr: Stimulus[]) { - const index = this.randomInteger(0, arr.length - 1); - const nextItem = arr.splice(index, 1)[0]; - return { - nextStimulus: nextItem, - remainingStimuli: arr, - }; - } - - /** - * Picks the next item in line from the given list of stimuli. - * It grabs the first item from the list, removes it, and then returns it along with the rest of the list. - * - * @param arr - The list of stimuli to choose from. - * @returns {Object} - An object with the next item and the updated list. - * @returns {Stimulus} return.nextStimulus - The item that was picked from the list. - * @returns {Stimulus[]} return.remainingStimuli - The list of what's left after picking the item. 
- */ - private selectorFixed(arr: Stimulus[]) { - const nextItem = arr.shift(); - return { - nextStimulus: nextItem, - remainingStimuli: arr, - }; - } - - /** - * return a random integer between min and max - * @param min - The minimum of the random number range (include) - * @param max - The maximum of the random number range (include) - * @returns {number} - random integer within the range - */ - private randomInteger(min: number, max: number) { - return Math.floor(this._rng() * (max - min + 1)) + min; - } -} +export { Cat, CatInput } from './cat'; +export { Clowder, ClowderInput } from './clowder'; diff --git a/src/utils.ts b/src/utils.ts index bf70927..b2c276f 100644 --- a/src/utils.ts +++ b/src/utils.ts @@ -1,147 +1,7 @@ /* eslint-disable @typescript-eslint/no-non-null-assertion */ import bs from 'binary-search'; -import { MultiZetaStimulus, Stimulus, Zeta, ZetaSymbolic } from './type'; -import _flatten from 'lodash/flatten'; -import _invert from 'lodash/invert'; -import _mapKeys from 'lodash/mapKeys'; -import _union from 'lodash/union'; -import _uniq from 'lodash/uniq'; - -// TODO: Document this -/** - * A constant map from the symbolic item parameter names to their semantic - * counterparts. - */ -export const ZETA_KEY_MAP = { - a: 'discrimination', - b: 'difficulty', - c: 'guessing', - d: 'slipping', -}; - -/** - * Return default item parameters (i.e., zeta) - * - * @param {'symbolic' | 'semantic'} desiredFormat - The desired format for the output zeta object. - * @returns {Zeta} the default zeta object in the specified format. - */ -export const defaultZeta = (desiredFormat: 'symbolic' | 'semantic' = 'symbolic'): Zeta => { - const defaultZeta: Zeta = { - a: 1, - b: 0, - c: 0, - d: 1, - }; - - return convertZeta(defaultZeta, desiredFormat); -}; - -/** - * Validates the item (a.k.a. zeta) parameters, prohibiting redundant keys and - * optionally requiring all parameters. - * - * @param {Zeta} zeta - The zeta parameters to validate. 
- * @param {boolean} requireAll - If `true`, ensures that all required keys are present. Default is `false`. - * - * @throws {Error} Will throw an error if any of the validation rules are violated. - */ -export const validateZetaParams = (zeta: Zeta, requireAll = false): void => { - if (zeta.a !== undefined && zeta.discrimination !== undefined) { - throw new Error('This item has both an `a` key and `discrimination` key. Please provide only one.'); - } - - if (zeta.b !== undefined && zeta.difficulty !== undefined) { - throw new Error('This item has both a `b` key and `difficulty` key. Please provide only one.'); - } - - if (zeta.c !== undefined && zeta.guessing !== undefined) { - throw new Error('This item has both a `c` key and `guessing` key. Please provide only one.'); - } - - if (zeta.d !== undefined && zeta.slipping !== undefined) { - throw new Error('This item has both a `d` key and `slipping` key. Please provide only one.'); - } - - if (requireAll) { - if (zeta.a === undefined && zeta.discrimination === undefined) { - throw new Error('This item is missing the key `a` or `discrimination`.'); - } - - if (zeta.b === undefined && zeta.difficulty === undefined) { - throw new Error('This item is missing the key `b` or `difficulty`.'); - } - - if (zeta.c === undefined && zeta.guessing === undefined) { - throw new Error('This item is missing the key `c` or `guessing`.'); - } - - if (zeta.d === undefined && zeta.slipping === undefined) { - throw new Error('This item is missing the key `d` or `slipping`.'); - } - } -}; - -/** - * Fills in default zeta parameters for any missing keys in the provided zeta object. - * - * @remarks - * This function merges the provided zeta object with the default zeta object, converting - * the keys to the desired format if specified. If no desired format is provided, the - * keys will remain in their original format. - * - * @param {Zeta} zeta - The zeta parameters to fill in defaults for. 
- * @param {'symbolic' | 'semantic'} desiredFormat - The desired format for the output zeta object. Default is 'symbolic'. - * - * @returns A new zeta object with default values filled in for any missing keys, - * and converted to the desired format if specified. - */ -export const fillZetaDefaults = (zeta: Zeta, desiredFormat: 'symbolic' | 'semantic' = 'symbolic'): Zeta => { - return { - ...defaultZeta(desiredFormat), - ...convertZeta(zeta, desiredFormat), - }; -}; - -/** - * Converts zeta parameters between symbolic and semantic formats. - * - * @remarks - * This function takes a zeta object and a desired format as input. It converts - * the keys of the zeta object from their current format to the desired format. - * If the desired format is 'symbolic', the function maps the keys to their - * symbolic counterparts using the `ZETA_KEY_MAP`. If the desired format is - * 'semantic', the function maps the keys to their semantic counterparts using - * the inverse of `ZETA_KEY_MAP`. - * - * @param {Zeta} zeta - The zeta parameters to convert. - * @param {'symbolic' | 'semantic'} desiredFormat - The desired format for the output zeta object. Must be either 'symbolic' or 'semantic'. - * - * @throws {Error} - Will throw an error if the desired format is not 'symbolic' or 'semantic'. - * - * @returns {Zeta} A new zeta object with keys converted to the desired format. - */ -export const convertZeta = (zeta: Zeta, desiredFormat: 'symbolic' | 'semantic'): Zeta => { - if (!['symbolic', 'semantic'].includes(desiredFormat)) { - throw new Error(`Invalid desired format. Expected 'symbolic' or'semantic'. 
Received ${desiredFormat} instead.`); - } - - return _mapKeys(zeta, (value, key) => { - if (desiredFormat === 'symbolic') { - const inverseMap = _invert(ZETA_KEY_MAP); - if (key in inverseMap) { - return inverseMap[key]; - } else { - return key; - } - } else { - if (key in ZETA_KEY_MAP) { - return ZETA_KEY_MAP[key as keyof typeof ZETA_KEY_MAP]; - } else { - return key; - } - } - }); -}; +import { Stimulus, Zeta, ZetaSymbolic } from './type'; +import { fillZetaDefaults } from './corpus'; /** * Calculates the probability that someone with a given ability level theta will @@ -238,107 +98,3 @@ export const findClosest = (inputStimuli: Array, target: number) => { } } }; - -/** - * Validates a corpus of multi-zeta stimuli to ensure that no cat names are - * duplicated. - * - * @remarks - * This function takes an array of `MultiZetaStimulus` objects, where each - * object represents an item containing item parameters (zetas) associated with - * different CAT models. The function checks for any duplicate cat names across - * each item's array of zeta values. It throws an error if any are found. - * - * @param {MultiZetaStimulus[]} corpus - An array of `MultiZetaStimulus` objects representing the corpora to validate. - * - * @throws {Error} - Throws an error if any duplicate cat names are found across the corpora. - */ -export const checkNoDuplicateCatNames = (corpus: MultiZetaStimulus[]): void => { - const zetaCatMapsArray = corpus.map((item) => item.zetas); - for (const zetaCatMaps of zetaCatMapsArray) { - const cats = zetaCatMaps.map(({ cats }) => cats); - - // Check to see if there are any duplicate names by comparing the union - // (which removed duplicates) to the flattened array. - const union = _union(...cats); - const flattened = _flatten(cats); - - if (union.length !== flattened.length) { - // If there are duplicates, remove the first occurence of each cat name in - // the union array from the flattened array. 
The remaining items in the - // flattened array should contain the duplicated cat names. - for (const cat of union) { - const idx = flattened.findIndex((c) => c === cat); - if (idx >= 0) { - flattened.splice(idx, 1); - } - } - - throw new Error(`The cat names ${_uniq(flattened).join(', ')} are present in multiple corpora.`); - } - } -}; - -/** - * Filters a list of multi-zeta stimuli based on the availability of model parameters for a specific CAT. - * - * This function takes an array of `MultiZetaStimulus` objects and a `catName` as input. It then filters - * the items based on whether the specified CAT model parameter is present in the item's zeta values. - * The function returns an object containing two arrays: `available` and `missing`. The `available` array - * contains items where the specified CAT model parameter is present, while the `missing` array contains - * items where the parameter is not present. - * - * @param {MultiZetaStimulus[]} items - An array of `MultiZetaStimulus` objects representing the stimuli to filter. - * @param {string} catName - The name of the CAT model parameter to check for. - * - * @returns An object with two arrays: `available` and `missing`. 
- * - * @example - * ```typescript - * const items: MultiZetaStimulus[] = [ - * { - * stimulus: 'Item 1', - * zetas: [ - * { cats: ['Model A', 'Model B'], zeta: { a: 1, b: 0.5, c: 0.2, d: 0.8 } }, - * { cats: ['Model C'], zeta: { a: 2, b: 0.7, c: 0.3, d: 0.9 } }, - * ], - * }, - * { - * stimulus: 'Item 2', - * zetas: [ - * { cats: ['Model B', 'Model C'], zeta: { a: 2.5, b: 0.8, c: 0.35, d: 0.95 } }, - * ], - * }, - * ]; - * - * const result = filterItemsByCatParameterAvailability(items, 'Model A'); - * console.log(result.available); - * // Output: [ - * // { - * // stimulus: 'Item 1', - * // zetas: [ - * // { cats: ['Model A', 'Model B'], zeta: { a: 1, b: 0.5, c: 0.2, d: 0.8 } }, - * // { cats: ['Model C'], zeta: { a: 2, b: 0.7, c: 0.3, d: 0.9 } }, - * // ], - * // }, - * // ] - * console.log(result.missing); - * // Output: [ - * // { - * // stimulus: 'Item 2', - * // zetas: [ - * // { cats: ['Model B', 'Model C'], zeta: { a: 2.5, b: 0.8, c: 0.35, d: 0.95 } }, - * // ], - * // }, - * // ] - * ``` - */ -export const filterItemsByCatParameterAvailability = (items: MultiZetaStimulus[], catName: string) => { - const paramsExist = items.filter((item) => item.zetas.some((zetaCatMap) => zetaCatMap.cats.includes(catName))); - const paramsMissing = items.filter((item) => !item.zetas.some((zetaCatMap) => zetaCatMap.cats.includes(catName))); - - return { - available: paramsExist, - missing: paramsMissing, - }; -}; From 8ac9f444660f2c1f9299f45620c8de55df4bb2e4 Mon Sep 17 00:00:00 2001 From: Adam Richie-Halford Date: Fri, 20 Sep 2024 06:27:08 -0700 Subject: [PATCH 16/47] Don't export abilityPrior --- src/cat.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/cat.ts b/src/cat.ts index 47a86f9..5344f26 100644 --- a/src/cat.ts +++ b/src/cat.ts @@ -7,7 +7,7 @@ import seedrandom from 'seedrandom'; import _clamp from 'lodash/clamp'; import _cloneDeep from 'lodash/cloneDeep'; -export const abilityPrior = normal(); +const abilityPrior = normal(); export interface 
CatInput { method?: string; From 668d68b2453034b9c76820bd7fbc597a3894d653 Mon Sep 17 00:00:00 2001 From: Adam Richie-Halford Date: Fri, 20 Sep 2024 06:31:06 -0700 Subject: [PATCH 17/47] Update readme --- README.md | 14 +++++++++----- package-lock.json | 22 +++++----------------- 2 files changed, 14 insertions(+), 22 deletions(-) diff --git a/README.md b/README.md index a3d13af..41c17f7 100644 --- a/README.md +++ b/README.md @@ -42,22 +42,26 @@ const stimuli = [{difficulty: -3, item: 'item1'}, {difficulty: -2, item: 'item2 const nextItem = cat.findNextItem(stimuli, 'MFI'); ``` -## Validations +## Validation + ### Validation of theta estimate and theta standard error + Reference software: mirt (Chalmers, 2012) ![img.png](validation/plots/jsCAT_validation_1.png) ### Validation of MFI algorithm + Reference software: catR (Magis et al., 2017) ![img_1.png](validation/plots/jsCAT_validation_2.png) - ## References -Chalmers, R. P. (2012). mirt: A multidimensional item response theory package for the R environment. Journal of Statistical Software. -Magis, D., & Barrada, J. R. (2017). Computerized adaptive testing with R: Recent updates of the package catR. Journal of Statistical Software, 76, 1-19. +- Chalmers, R. P. (2012). mirt: A multidimensional item response theory package for the R environment. Journal of Statistical Software. + +- Magis, D., & Barrada, J. R. (2017). Computerized adaptive testing with R: Recent updates of the package catR. Journal of Statistical Software, 76, 1-19. -Lucas Duailibe, irt-js, (2019), GitHub repository, https://github.com/geekie/irt-js +- Lucas Duailibe, irt-js, (2019), GitHub repository, https://github.com/geekie/irt-js ## License + jsCAT is distributed under the [ISC license](LICENSE). 
diff --git a/package-lock.json b/package-lock.json index b933c10..0c6e760 100644 --- a/package-lock.json +++ b/package-lock.json @@ -10298,7 +10298,6 @@ }, "node_modules/npm/node_modules/lodash._baseindexof": { "version": "3.1.0", - "extraneous": true, "inBundle": true, "license": "MIT" }, @@ -10323,19 +10322,16 @@ }, "node_modules/npm/node_modules/lodash._bindcallback": { "version": "3.0.1", - "extraneous": true, "inBundle": true, "license": "MIT" }, "node_modules/npm/node_modules/lodash._cacheindexof": { "version": "3.0.2", - "extraneous": true, "inBundle": true, "license": "MIT" }, "node_modules/npm/node_modules/lodash._createcache": { "version": "3.1.2", - "extraneous": true, "inBundle": true, "license": "MIT", "dependencies": { @@ -10344,7 +10340,6 @@ }, "node_modules/npm/node_modules/lodash._getnative": { "version": "3.9.1", - "extraneous": true, "inBundle": true, "license": "MIT" }, @@ -10355,7 +10350,6 @@ }, "node_modules/npm/node_modules/lodash.restparam": { "version": "3.6.1", - "extraneous": true, "inBundle": true, "license": "MIT" }, @@ -24672,8 +24666,7 @@ }, "lodash._baseindexof": { "version": "3.1.0", - "bundled": true, - "extraneous": true + "bundled": true }, "lodash._baseuniq": { "version": "4.6.0", @@ -24695,26 +24688,22 @@ }, "lodash._bindcallback": { "version": "3.0.1", - "bundled": true, - "extraneous": true + "bundled": true }, "lodash._cacheindexof": { "version": "3.0.2", - "bundled": true, - "extraneous": true + "bundled": true }, "lodash._createcache": { "version": "3.1.2", "bundled": true, - "extraneous": true, "requires": { "lodash._getnative": "^3.0.0" } }, "lodash._getnative": { "version": "3.9.1", - "bundled": true, - "extraneous": true + "bundled": true }, "lodash.clonedeep": { "version": "4.5.0", @@ -24722,8 +24711,7 @@ }, "lodash.restparam": { "version": "3.6.1", - "bundled": true, - "extraneous": true + "bundled": true }, "lodash.union": { "version": "4.6.0", From 3550d96e119d8d628f86eca44eea6dea49728795 Mon Sep 17 00:00:00 2001 
From: emily-ejag Date: Fri, 20 Sep 2024 11:28:01 -0700 Subject: [PATCH 18/47] adding missing tests to clowder --- src/__tests__/clowder.test.ts | 115 ++++++++++++++++++++++++++++++---- src/clowder.ts | 3 +- 2 files changed, 106 insertions(+), 12 deletions(-) diff --git a/src/__tests__/clowder.test.ts b/src/__tests__/clowder.test.ts index eacaa50..67c6611 100644 --- a/src/__tests__/clowder.test.ts +++ b/src/__tests__/clowder.test.ts @@ -2,6 +2,7 @@ import { Cat } from '../cat'; import { Clowder, ClowderInput } from '../clowder'; import { MultiZetaStimulus, Zeta, ZetaCatMap } from '../type'; import { defaultZeta } from '../corpus'; +import _uniq from 'lodash/uniq'; const createStimulus = (id: string) => ({ ...defaultZeta(), @@ -160,28 +161,120 @@ describe('Clowder Class', () => { }); it('should select a validated item if validated items are present and randomlySelectUnvalidated is false', () => { - // TODO: Implement this test - expect(1).toBe(0); + const clowderInput: ClowderInput = { + cats: { + cat1: { method: 'MLE', theta: 0.5 }, + cat2: { method: 'EAP', theta: -1.0 }, + }, + corpus: [ + createMultiZetaStimulus('0', [createZetaCatMap(['cat1'])]), + createMultiZetaStimulus('1', [createZetaCatMap(['cat2'])]), + ], + }; + const clowder = new Clowder(clowderInput); + + const nextItem = clowder.updateCatAndGetNextItem({ + catToSelect: 'cat1', + randomlySelectUnvalidated: false, + }); + expect(nextItem?.id).toBe('0'); }); it('should select an unvalidated item if no validated items remain', () => { - // TODO: Implement this test - expect(1).toBe(0); + const clowderInput: ClowderInput = { + cats: { + cat1: { method: 'MLE', theta: 0.5 }, + }, + corpus: [ + createMultiZetaStimulus('0', [createZetaCatMap([])]), + createMultiZetaStimulus('1', [createZetaCatMap(['cat1'])]), + createMultiZetaStimulus('2', [createZetaCatMap([])]), + ], + }; + const clowder = new Clowder(clowderInput); + + const nextItem = clowder.updateCatAndGetNextItem({ + catToSelect: 'cat1', + catsToUpdate: 
['cat1'], + items: [clowder.corpus[1]], + answers: [1], + }); + expect(nextItem).toBeDefined(); + expect(['0', '2']).toContain(nextItem?.id); }); it('should correctly update ability estimates during the updateCatAndGetNextItem method', () => { - // TODO: Implement this test - expect(1).toBe(0); + const originalTheta = clowder.cats.cat1.theta; + clowder.updateCatAndGetNextItem({ + catToSelect: 'cat1', + catsToUpdate: ['cat1'], + items: [clowder.corpus[0]], + answers: [1], + }); + expect(clowder.cats.cat1.theta).not.toBe(originalTheta); }); it('should randomly choose between validated and unvalidated items if randomlySelectUnvalidated is true', () => { - // TODO: Implement this test - // Pass in a random seed for reproducibility - expect(1).toBe(0); + const clowderInput: ClowderInput = { + cats: { + cat1: { method: 'MLE', theta: 0.5 }, + }, + corpus: [ + createMultiZetaStimulus('0', [createZetaCatMap(['cat1'])]), // Validated item + createMultiZetaStimulus('1', [createZetaCatMap([])]), // Unvalidated item + createMultiZetaStimulus('2', [createZetaCatMap([])]), // Unvalidated item + createMultiZetaStimulus('3', [createZetaCatMap([])]), // Validated item + ], + randomSeed: 'randomSeed', + }; + const clowder = new Clowder(clowderInput); + + const nextItems = Array(20) + .fill('-1') + .map(() => { + return clowder.updateCatAndGetNextItem({ + catToSelect: 'cat1', + randomlySelectUnvalidated: true, + }); + }); + + const itemsId = nextItems.map((item) => item?.id); + + expect(nextItems).toBeDefined(); + expect(_uniq(itemsId)).toEqual(expect.arrayContaining(['0', '1', '2', '3'])); // Could be validated or unvalidated }); it('should return undefined if no more items remain', () => { - // TODO: Implement this test - expect(1).toBe(0); + clowder.updateCatAndGetNextItem({ + catToSelect: 'cat1', + items: clowder.remainingItems, + answers: [1, 0, 1, 1, 0], // Exhaust all items + }); + + const nextItem = clowder.updateCatAndGetNextItem({ + catToSelect: 'cat1', + }); + 
expect(nextItem).toBeUndefined(); + }); + + it('can receive one item and answer as an input', () => { + const nextItem = clowder.updateCatAndGetNextItem({ + catToSelect: 'cat1', + items: clowder.corpus[0], + answers: 1, + }); + expect(nextItem).toBeDefined(); + }); + + it('can receive only one catToUpdate', () => { + const originalTheta = clowder.cats.cat1.theta; + const nextItem = clowder.updateCatAndGetNextItem({ + catToSelect: 'cat1', + catsToUpdate: 'cat1', + items: clowder.corpus[0], + answers: 1, + }); + expect(nextItem).toBeDefined(); + expect(clowder.cats.cat1.theta).not.toBe(originalTheta); }); }); diff --git a/src/clowder.ts b/src/clowder.ts index fc1b4a5..742961b 100644 --- a/src/clowder.ts +++ b/src/clowder.ts @@ -289,7 +289,8 @@ export class Clowder { const { zetas, ...rest } = item; const zetasForCat = zetas.find((zeta: ZetaCatMap) => zeta.cats.includes(catToSelect)); return { - ...(zetasForCat?.zeta ?? {}), + // eslint-disable-next-line @typescript-eslint/no-non-null-assertion + ...zetasForCat!.zeta, ...rest, }; }); From a1d2bf9aa86a7b8c28a322ea3a51a52731e11906 Mon Sep 17 00:00:00 2001 From: Adam Richie-Halford Date: Sun, 22 Sep 2024 17:30:16 -0700 Subject: [PATCH 19/47] Import Cat, Clowder, and ClowderInput from index --- src/__tests__/cat.test.ts | 2 +- src/__tests__/clowder.test.ts | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/src/__tests__/cat.test.ts b/src/__tests__/cat.test.ts index ddeb0a3..fb9cdf4 100644 --- a/src/__tests__/cat.test.ts +++ b/src/__tests__/cat.test.ts @@ -1,5 +1,5 @@ /* eslint-disable @typescript-eslint/no-non-null-assertion */ -import { Cat } from '../index'; +import { Cat } from '..'; import { Stimulus } from '../type'; import seedrandom from 'seedrandom'; import { convertZeta } from '../corpus'; diff --git a/src/__tests__/clowder.test.ts b/src/__tests__/clowder.test.ts index 67c6611..5e24e5e 100644 --- a/src/__tests__/clowder.test.ts +++ b/src/__tests__/clowder.test.ts @@ -1,5 +1,4 @@ -import { Cat } 
from '../cat'; -import { Clowder, ClowderInput } from '../clowder'; +import { Cat, Clowder, ClowderInput } from '..'; import { MultiZetaStimulus, Zeta, ZetaCatMap } from '../type'; import { defaultZeta } from '../corpus'; import _uniq from 'lodash/uniq'; From 8ebae678ae4c02e5dd03b1ecfde68922f3f26fc8 Mon Sep 17 00:00:00 2001 From: Adam Richie-Halford Date: Sun, 22 Sep 2024 19:19:55 -0700 Subject: [PATCH 20/47] Add src/stopping.ts --- src/clowder.ts | 10 ++-- src/stopping.ts | 143 ++++++++++++++++++++++++++++++++++++++++++++++++ src/type.ts | 4 ++ 3 files changed, 151 insertions(+), 6 deletions(-) create mode 100644 src/stopping.ts diff --git a/src/clowder.ts b/src/clowder.ts index 742961b..464092c 100644 --- a/src/clowder.ts +++ b/src/clowder.ts @@ -1,5 +1,5 @@ import { Cat, CatInput } from './cat'; -import { MultiZetaStimulus, Stimulus, Zeta, ZetaCatMap } from './type'; +import { CatMap, MultiZetaStimulus, Stimulus, Zeta, ZetaCatMap } from './type'; import { filterItemsByCatParameterAvailability, checkNoDuplicateCatNames } from './corpus'; import _cloneDeep from 'lodash/cloneDeep'; import _differenceWith from 'lodash/differenceWith'; @@ -15,9 +15,7 @@ export interface ClowderInput { * An object containing Cat configurations for each Cat instance. * Keys correspond to Cat names, while values correspond to Cat configurations. */ - cats: { - [name: string]: CatInput; - }; + cats: CatMap; /** * An object containing arrays of stimuli for each corpus. */ @@ -37,7 +35,7 @@ export interface ClowderInput { * to the participant. */ export class Clowder { - private _cats: { [name: string]: Cat }; + private _cats: CatMap; private _corpus: MultiZetaStimulus[]; private _remainingItems: MultiZetaStimulus[]; private _seenItems: Stimulus[]; @@ -47,7 +45,7 @@ export class Clowder { * Create a Clowder object. * * @param {ClowderInput} input - An object containing arrays of Cat configurations and corpora. 
- * @param {CatInput[]} input.cats - An object containing Cat configurations for each Cat instance. + * @param {CatMap} input.cats - An object containing Cat configurations for each Cat instance. * @param {MultiZetaStimulus[]} input.corpus - An array of stimuli representing each corpus. * * @throws {Error} - Throws an error if any item in the corpus has duplicated IRT parameters for any Cat name. diff --git a/src/stopping.ts b/src/stopping.ts new file mode 100644 index 0000000..254748d --- /dev/null +++ b/src/stopping.ts @@ -0,0 +1,143 @@ +import { Cat } from './cat'; +import { CatMap } from './type'; + +/** + * Interface for input parameters to EarlyStopping classes. + */ +export interface EarlyStoppingInput { + /** Number of items to wait for before triggering early stopping */ + patience: CatMap; + /** Tolerance for standard error of measurement drop */ + tolerance: CatMap; + /** Number of items to require before stopping */ + requiredItems: CatMap; + /** Stop if the standard error of measurement drops below this level */ + seMeasurementThreshold: CatMap; +} + +/** + * Abstract class for early stopping strategies. 
+ */ +export abstract class EarlyStopping { + protected _earlyStop: boolean; + protected _patience: CatMap; + protected _tolerance: CatMap; + protected _requiredItems: CatMap; + protected _seMeasurementThreshold: CatMap; + protected _nItems: CatMap; + protected _seMeasurements: CatMap; + + constructor({ patience, tolerance, requiredItems, seMeasurementThreshold }: EarlyStoppingInput) { + this._patience = patience; + this._tolerance = tolerance; + this._requiredItems = requiredItems; + this._seMeasurementThreshold = seMeasurementThreshold; + this._seMeasurements = {}; + this._nItems = {}; + this._earlyStop = false; + } + + public get patience() { + return this._patience; + } + + public get tolerance() { + return this._tolerance; + } + + public get requiredItems() { + return this._requiredItems; + } + + public get seMeasurementThreshold() { + return this._seMeasurementThreshold; + } + + public get earlyStop() { + return this._earlyStop; + } + + /** + * Update the internal state of the early stopping strategy based on the provided cats. + * @param {CatMap}cats - A map of cats to update. + */ + protected _updateCats(cats: CatMap) { + for (const catName in cats) { + const cat = cats[catName]; + const nItems = cat.nItems; + const seMeasurement = cat.seMeasurement; + + if (nItems > (this._nItems[catName] ?? 0)) { + this._nItems[catName] = nItems; + this._seMeasurements[catName] = [...(this._seMeasurements[catName] ?? []), seMeasurement]; + } + } + } + + /** + * Abstract method to be implemented by subclasses to update the early stopping strategy. + * @param {CatMap} cats - A map of cats to update. + * @param {string} catToEvaluate - The name of the cat to evaluate for early stopping. + */ + public abstract update(cats: CatMap, catToEvaluate: string): void; +} + +/** + * Class implementing early stopping based on a plateau in standard error of measurement. 
+ */ +export class StopOnSEMeasurementPlateau extends EarlyStopping { + public update(cats: CatMap, catToEvaluate: string) { + super._updateCats(cats); + + const seMeasurements = this._seMeasurements[catToEvaluate]; + const patience = this._patience[catToEvaluate]; + const tolerance = this._tolerance[catToEvaluate]; + + if (seMeasurements.length >= patience) { + const mean = seMeasurements.slice(-patience).reduce((sum, se) => sum + se, 0) / patience; + const withinTolerance = seMeasurements.slice(-patience).every((se) => Math.abs(se - mean) <= tolerance); + + if (withinTolerance) { + this._earlyStop = true; + } + } + } +} + +/** + * Class implementing early stopping after a certain number of items. + */ +export class StopAfterNItems extends EarlyStopping { + public update(cats: CatMap, catToEvaluate: string) { + super._updateCats(cats); + + const requiredItems = this._requiredItems[catToEvaluate]; + const nItems = this._nItems[catToEvaluate]; + + if (nItems >= requiredItems) { + this._earlyStop = true; + } + } +} + +/** + * Class implementing early stopping if the standard error of measurement drops below a certain threshold. 
+ */ +export class StopIfSEMeasurementBelowThreshold extends EarlyStopping { + public update(cats: CatMap, catToEvaluate: string) { + super._updateCats(cats); + + const seMeasurements = this._seMeasurements[catToEvaluate]; + const seThreshold = this._seMeasurementThreshold[catToEvaluate]; + const patience = this._patience[catToEvaluate]; + const tolerance = this._tolerance[catToEvaluate]; + + if (seMeasurements.length >= patience) { + const withinTolerance = seMeasurements.slice(-patience).every((se) => Math.abs(se - seThreshold) <= tolerance); + + if (withinTolerance) { + this._earlyStop = true; + } + } + } +} diff --git a/src/type.ts b/src/type.ts index 522ccfd..739c65d 100644 --- a/src/type.ts +++ b/src/type.ts @@ -34,3 +34,7 @@ export interface MultiZetaStimulus { // eslint-disable-next-line @typescript-eslint/no-explicit-any [key: string]: any; } + +export type CatMap = { + [name: string]: T; +}; From 7d4a1a9c299b9aa55f620da759a2b3c41a789a63 Mon Sep 17 00:00:00 2001 From: Adam Richie-Halford Date: Mon, 23 Sep 2024 06:28:40 -0700 Subject: [PATCH 21/47] Add tests for stopping.ts --- package-lock.json | 302 +++++++++++++++++++ package.json | 1 + src/__tests__/stopping.test.ts | 526 +++++++++++++++++++++++++++++++++ src/stopping.ts | 38 ++- 4 files changed, 853 insertions(+), 14 deletions(-) create mode 100644 src/__tests__/stopping.test.ts diff --git a/package-lock.json b/package-lock.json index 0c6e760..144274e 100644 --- a/package-lock.json +++ b/package-lock.json @@ -24,6 +24,7 @@ "eslint": "^8.20.0", "eslint-config-prettier": "^8.5.0", "jest": "^28.1.3", + "jest-extended": "^4.0.2", "prettier": "^2.7.1", "ts-jest": "^28.0.8", "tsdoc": "^0.0.4", @@ -6944,6 +6945,188 @@ "node": "^12.13.0 || ^14.15.0 || ^16.10.0 || >=17.0.0" } }, + "node_modules/jest-extended": { + "version": "4.0.2", + "resolved": "https://registry.npmjs.org/jest-extended/-/jest-extended-4.0.2.tgz", + "integrity": 
"sha512-FH7aaPgtGYHc9mRjriS0ZEHYM5/W69tLrFTIdzm+yJgeoCmmrSB/luSfMSqWP9O29QWHPEmJ4qmU6EwsZideog==", + "dev": true, + "license": "MIT", + "dependencies": { + "jest-diff": "^29.0.0", + "jest-get-type": "^29.0.0" + }, + "engines": { + "node": "^14.15.0 || ^16.10.0 || >=18.0.0" + }, + "peerDependencies": { + "jest": ">=27.2.5" + }, + "peerDependenciesMeta": { + "jest": { + "optional": true + } + } + }, + "node_modules/jest-extended/node_modules/@jest/schemas": { + "version": "29.6.3", + "resolved": "https://registry.npmjs.org/@jest/schemas/-/schemas-29.6.3.tgz", + "integrity": "sha512-mo5j5X+jIZmJQveBKeS/clAueipV7KgiX1vMgCxam1RNYiqE1w62n0/tJJnHtjW8ZHcQco5gY85jA3mi0L+nSA==", + "dev": true, + "license": "MIT", + "dependencies": { + "@sinclair/typebox": "^0.27.8" + }, + "engines": { + "node": "^14.15.0 || ^16.10.0 || >=18.0.0" + } + }, + "node_modules/jest-extended/node_modules/@sinclair/typebox": { + "version": "0.27.8", + "resolved": "https://registry.npmjs.org/@sinclair/typebox/-/typebox-0.27.8.tgz", + "integrity": "sha512-+Fj43pSMwJs4KRrH/938Uf+uAELIgVBmQzg/q1YG10djyfA3TnrU8N8XzqCh/okZdszqBQTZf96idMfE5lnwTA==", + "dev": true, + "license": "MIT" + }, + "node_modules/jest-extended/node_modules/ansi-styles": { + "version": "4.3.0", + "resolved": "https://registry.npmjs.org/ansi-styles/-/ansi-styles-4.3.0.tgz", + "integrity": "sha512-zbB9rCJAT1rbjiVDb2hqKFHNYLxgtk8NURxZ3IZwD3F6NtxbXZQCnnSi1Lkx+IDohdPlFp222wVALIheZJQSEg==", + "dev": true, + "license": "MIT", + "dependencies": { + "color-convert": "^2.0.1" + }, + "engines": { + "node": ">=8" + }, + "funding": { + "url": "https://github.com/chalk/ansi-styles?sponsor=1" + } + }, + "node_modules/jest-extended/node_modules/chalk": { + "version": "4.1.2", + "resolved": "https://registry.npmjs.org/chalk/-/chalk-4.1.2.tgz", + "integrity": "sha512-oKnbhFyRIXpUuez8iBMmyEa4nbj4IOQyuhc/wy9kY7/WVPcwIO9VA668Pu8RkO7+0G76SLROeyw9CpQ061i4mA==", + "dev": true, + "license": "MIT", + "dependencies": { + "ansi-styles": "^4.1.0", + 
"supports-color": "^7.1.0" + }, + "engines": { + "node": ">=10" + }, + "funding": { + "url": "https://github.com/chalk/chalk?sponsor=1" + } + }, + "node_modules/jest-extended/node_modules/color-convert": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/color-convert/-/color-convert-2.0.1.tgz", + "integrity": "sha512-RRECPsj7iu/xb5oKYcsFHSppFNnsj/52OVTRKb4zP5onXwVF3zVmmToNcOfGC+CRDpfK/U584fMg38ZHCaElKQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "color-name": "~1.1.4" + }, + "engines": { + "node": ">=7.0.0" + } + }, + "node_modules/jest-extended/node_modules/color-name": { + "version": "1.1.4", + "resolved": "https://registry.npmjs.org/color-name/-/color-name-1.1.4.tgz", + "integrity": "sha512-dOy+3AuW3a2wNbZHIuMZpTcgjGuLU/uBL/ubcZF9OXbDo8ff4O8yVp5Bf0efS8uEoYo5q4Fx7dY9OgQGXgAsQA==", + "dev": true, + "license": "MIT" + }, + "node_modules/jest-extended/node_modules/diff-sequences": { + "version": "29.6.3", + "resolved": "https://registry.npmjs.org/diff-sequences/-/diff-sequences-29.6.3.tgz", + "integrity": "sha512-EjePK1srD3P08o2j4f0ExnylqRs5B9tJjcp9t1krH2qRi8CCdsYfwe9JgSLurFBWwq4uOlipzfk5fHNvwFKr8Q==", + "dev": true, + "license": "MIT", + "engines": { + "node": "^14.15.0 || ^16.10.0 || >=18.0.0" + } + }, + "node_modules/jest-extended/node_modules/has-flag": { + "version": "4.0.0", + "resolved": "https://registry.npmjs.org/has-flag/-/has-flag-4.0.0.tgz", + "integrity": "sha512-EykJT/Q1KjTWctppgIAgfSO0tKVuZUjhgMr17kqTumMl6Afv3EISleU7qZUzoXDFTAHTDC4NOoG/ZxU3EvlMPQ==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=8" + } + }, + "node_modules/jest-extended/node_modules/jest-diff": { + "version": "29.7.0", + "resolved": "https://registry.npmjs.org/jest-diff/-/jest-diff-29.7.0.tgz", + "integrity": "sha512-LMIgiIrhigmPrs03JHpxUh2yISK3vLFPkAodPeo0+BuF7wA2FoQbkEg1u8gBYBThncu7e1oEDUfIXVuTqLRUjw==", + "dev": true, + "license": "MIT", + "dependencies": { + "chalk": "^4.0.0", + "diff-sequences": "^29.6.3", + "jest-get-type": 
"^29.6.3", + "pretty-format": "^29.7.0" + }, + "engines": { + "node": "^14.15.0 || ^16.10.0 || >=18.0.0" + } + }, + "node_modules/jest-extended/node_modules/jest-get-type": { + "version": "29.6.3", + "resolved": "https://registry.npmjs.org/jest-get-type/-/jest-get-type-29.6.3.tgz", + "integrity": "sha512-zrteXnqYxfQh7l5FHyL38jL39di8H8rHoecLH3JNxH3BwOrBsNeabdap5e0I23lD4HHI8W5VFBZqG4Eaq5LNcw==", + "dev": true, + "license": "MIT", + "engines": { + "node": "^14.15.0 || ^16.10.0 || >=18.0.0" + } + }, + "node_modules/jest-extended/node_modules/pretty-format": { + "version": "29.7.0", + "resolved": "https://registry.npmjs.org/pretty-format/-/pretty-format-29.7.0.tgz", + "integrity": "sha512-Pdlw/oPxN+aXdmM9R00JVC9WVFoCLTKJvDVLgmJ+qAffBMxsV85l/Lu7sNx4zSzPyoL2euImuEwHhOXdEgNFZQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "@jest/schemas": "^29.6.3", + "ansi-styles": "^5.0.0", + "react-is": "^18.0.0" + }, + "engines": { + "node": "^14.15.0 || ^16.10.0 || >=18.0.0" + } + }, + "node_modules/jest-extended/node_modules/pretty-format/node_modules/ansi-styles": { + "version": "5.2.0", + "resolved": "https://registry.npmjs.org/ansi-styles/-/ansi-styles-5.2.0.tgz", + "integrity": "sha512-Cxwpt2SfTzTtXcfOlzGEee8O+c+MmUgGrNiBcXnuWxuFJHe6a5Hz7qwhwe5OgaSYI0IJvkLqWX1ASG+cJOkEiA==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=10" + }, + "funding": { + "url": "https://github.com/chalk/ansi-styles?sponsor=1" + } + }, + "node_modules/jest-extended/node_modules/supports-color": { + "version": "7.2.0", + "resolved": "https://registry.npmjs.org/supports-color/-/supports-color-7.2.0.tgz", + "integrity": "sha512-qpCAvRl9stuOHveKsn7HncJRvv501qIacKzQlO/+Lwxc9+0q2wLyv4Dfvt80/DPn2pqOBsJdDiogXGR9+OvwRw==", + "dev": true, + "license": "MIT", + "dependencies": { + "has-flag": "^4.0.0" + }, + "engines": { + "node": ">=8" + } + }, "node_modules/jest-get-type": { "version": "28.0.2", "resolved": "https://registry.npmjs.org/jest-get-type/-/jest-get-type-28.0.2.tgz", @@ 
-22119,6 +22302,125 @@ "jest-util": "^28.1.3" } }, + "jest-extended": { + "version": "4.0.2", + "resolved": "https://registry.npmjs.org/jest-extended/-/jest-extended-4.0.2.tgz", + "integrity": "sha512-FH7aaPgtGYHc9mRjriS0ZEHYM5/W69tLrFTIdzm+yJgeoCmmrSB/luSfMSqWP9O29QWHPEmJ4qmU6EwsZideog==", + "dev": true, + "requires": { + "jest-diff": "^29.0.0", + "jest-get-type": "^29.0.0" + }, + "dependencies": { + "@jest/schemas": { + "version": "29.6.3", + "resolved": "https://registry.npmjs.org/@jest/schemas/-/schemas-29.6.3.tgz", + "integrity": "sha512-mo5j5X+jIZmJQveBKeS/clAueipV7KgiX1vMgCxam1RNYiqE1w62n0/tJJnHtjW8ZHcQco5gY85jA3mi0L+nSA==", + "dev": true, + "requires": { + "@sinclair/typebox": "^0.27.8" + } + }, + "@sinclair/typebox": { + "version": "0.27.8", + "resolved": "https://registry.npmjs.org/@sinclair/typebox/-/typebox-0.27.8.tgz", + "integrity": "sha512-+Fj43pSMwJs4KRrH/938Uf+uAELIgVBmQzg/q1YG10djyfA3TnrU8N8XzqCh/okZdszqBQTZf96idMfE5lnwTA==", + "dev": true + }, + "ansi-styles": { + "version": "4.3.0", + "resolved": "https://registry.npmjs.org/ansi-styles/-/ansi-styles-4.3.0.tgz", + "integrity": "sha512-zbB9rCJAT1rbjiVDb2hqKFHNYLxgtk8NURxZ3IZwD3F6NtxbXZQCnnSi1Lkx+IDohdPlFp222wVALIheZJQSEg==", + "dev": true, + "requires": { + "color-convert": "^2.0.1" + } + }, + "chalk": { + "version": "4.1.2", + "resolved": "https://registry.npmjs.org/chalk/-/chalk-4.1.2.tgz", + "integrity": "sha512-oKnbhFyRIXpUuez8iBMmyEa4nbj4IOQyuhc/wy9kY7/WVPcwIO9VA668Pu8RkO7+0G76SLROeyw9CpQ061i4mA==", + "dev": true, + "requires": { + "ansi-styles": "^4.1.0", + "supports-color": "^7.1.0" + } + }, + "color-convert": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/color-convert/-/color-convert-2.0.1.tgz", + "integrity": "sha512-RRECPsj7iu/xb5oKYcsFHSppFNnsj/52OVTRKb4zP5onXwVF3zVmmToNcOfGC+CRDpfK/U584fMg38ZHCaElKQ==", + "dev": true, + "requires": { + "color-name": "~1.1.4" + } + }, + "color-name": { + "version": "1.1.4", + "resolved": 
"https://registry.npmjs.org/color-name/-/color-name-1.1.4.tgz", + "integrity": "sha512-dOy+3AuW3a2wNbZHIuMZpTcgjGuLU/uBL/ubcZF9OXbDo8ff4O8yVp5Bf0efS8uEoYo5q4Fx7dY9OgQGXgAsQA==", + "dev": true + }, + "diff-sequences": { + "version": "29.6.3", + "resolved": "https://registry.npmjs.org/diff-sequences/-/diff-sequences-29.6.3.tgz", + "integrity": "sha512-EjePK1srD3P08o2j4f0ExnylqRs5B9tJjcp9t1krH2qRi8CCdsYfwe9JgSLurFBWwq4uOlipzfk5fHNvwFKr8Q==", + "dev": true + }, + "has-flag": { + "version": "4.0.0", + "resolved": "https://registry.npmjs.org/has-flag/-/has-flag-4.0.0.tgz", + "integrity": "sha512-EykJT/Q1KjTWctppgIAgfSO0tKVuZUjhgMr17kqTumMl6Afv3EISleU7qZUzoXDFTAHTDC4NOoG/ZxU3EvlMPQ==", + "dev": true + }, + "jest-diff": { + "version": "29.7.0", + "resolved": "https://registry.npmjs.org/jest-diff/-/jest-diff-29.7.0.tgz", + "integrity": "sha512-LMIgiIrhigmPrs03JHpxUh2yISK3vLFPkAodPeo0+BuF7wA2FoQbkEg1u8gBYBThncu7e1oEDUfIXVuTqLRUjw==", + "dev": true, + "requires": { + "chalk": "^4.0.0", + "diff-sequences": "^29.6.3", + "jest-get-type": "^29.6.3", + "pretty-format": "^29.7.0" + } + }, + "jest-get-type": { + "version": "29.6.3", + "resolved": "https://registry.npmjs.org/jest-get-type/-/jest-get-type-29.6.3.tgz", + "integrity": "sha512-zrteXnqYxfQh7l5FHyL38jL39di8H8rHoecLH3JNxH3BwOrBsNeabdap5e0I23lD4HHI8W5VFBZqG4Eaq5LNcw==", + "dev": true + }, + "pretty-format": { + "version": "29.7.0", + "resolved": "https://registry.npmjs.org/pretty-format/-/pretty-format-29.7.0.tgz", + "integrity": "sha512-Pdlw/oPxN+aXdmM9R00JVC9WVFoCLTKJvDVLgmJ+qAffBMxsV85l/Lu7sNx4zSzPyoL2euImuEwHhOXdEgNFZQ==", + "dev": true, + "requires": { + "@jest/schemas": "^29.6.3", + "ansi-styles": "^5.0.0", + "react-is": "^18.0.0" + }, + "dependencies": { + "ansi-styles": { + "version": "5.2.0", + "resolved": "https://registry.npmjs.org/ansi-styles/-/ansi-styles-5.2.0.tgz", + "integrity": "sha512-Cxwpt2SfTzTtXcfOlzGEee8O+c+MmUgGrNiBcXnuWxuFJHe6a5Hz7qwhwe5OgaSYI0IJvkLqWX1ASG+cJOkEiA==", + "dev": true + } + } + }, + 
"supports-color": { + "version": "7.2.0", + "resolved": "https://registry.npmjs.org/supports-color/-/supports-color-7.2.0.tgz", + "integrity": "sha512-qpCAvRl9stuOHveKsn7HncJRvv501qIacKzQlO/+Lwxc9+0q2wLyv4Dfvt80/DPn2pqOBsJdDiogXGR9+OvwRw==", + "dev": true, + "requires": { + "has-flag": "^4.0.0" + } + } + } + }, "jest-get-type": { "version": "28.0.2", "resolved": "https://registry.npmjs.org/jest-get-type/-/jest-get-type-28.0.2.tgz", diff --git a/package.json b/package.json index b5f68ce..9c10826 100644 --- a/package.json +++ b/package.json @@ -41,6 +41,7 @@ "eslint": "^8.20.0", "eslint-config-prettier": "^8.5.0", "jest": "^28.1.3", + "jest-extended": "^4.0.2", "prettier": "^2.7.1", "ts-jest": "^28.0.8", "tsdoc": "^0.0.4", diff --git a/src/__tests__/stopping.test.ts b/src/__tests__/stopping.test.ts new file mode 100644 index 0000000..abfdd69 --- /dev/null +++ b/src/__tests__/stopping.test.ts @@ -0,0 +1,526 @@ +import { Cat } from '..'; +import { CatMap } from '../type'; +import { + EarlyStopping, + EarlyStoppingInput, + StopAfterNItems, + StopIfSEMeasurementBelowThreshold, + StopOnSEMeasurementPlateau, +} from '../stopping'; +import { toBeBoolean } from 'jest-extended'; +expect.extend({ toBeBoolean }); + +const testInstantiation = (earlyStopping: EarlyStopping, input: EarlyStoppingInput) => { + expect(earlyStopping.patience).toEqual(input.patience ?? {}); + expect(earlyStopping.tolerance).toEqual(input.tolerance ?? {}); + expect(earlyStopping.requiredItems).toEqual(input.requiredItems ?? {}); + expect(earlyStopping.seMeasurementThreshold).toEqual(input.seMeasurementThreshold ?? 
{}); + expect(earlyStopping.earlyStop).toBeBoolean(); +}; + +const testInternalState = (earlyStopping: EarlyStopping) => { + const cats1: CatMap = { + cat1: { + nItems: 1, + seMeasurement: 0.5, + } as Cat, + cat2: { + nItems: 1, + seMeasurement: 0.3, + } as Cat, + }; + + earlyStopping.update(cats1, 'cat1'); + expect(earlyStopping.nItems.cat1).toBe(1); + expect(earlyStopping.seMeasurements.cat1).toEqual([0.5]); + expect(earlyStopping.nItems.cat2).toBe(1); + expect(earlyStopping.seMeasurements.cat2).toEqual([0.3]); + + const cats2: CatMap = { + cat1: { + nItems: 2, + seMeasurement: 0.5, + } as Cat, + cat2: { + nItems: 2, + seMeasurement: 0.3, + } as Cat, + }; + + earlyStopping.update(cats2, 'cat1'); + expect(earlyStopping.nItems.cat1).toBe(2); + expect(earlyStopping.seMeasurements.cat1).toEqual([0.5, 0.5]); + expect(earlyStopping.nItems.cat2).toBe(2); + expect(earlyStopping.seMeasurements.cat2).toEqual([0.3, 0.3]); +}; + +const testNoStoppingOnInvalidCatName = (earlyStopping: EarlyStopping) => { + const cats1: CatMap = { + cat1: { + nItems: 1, + seMeasurement: 0.5, + } as Cat, + cat2: { + nItems: 1, + seMeasurement: 0.3, + } as Cat, + }; + + earlyStopping.update(cats1, 'invalidCatName'); + expect(earlyStopping.earlyStop).toBe(false); +}; + +describe('StopOnSEMeasurementPlateau', () => { + let earlyStopping: StopOnSEMeasurementPlateau; + let input: EarlyStoppingInput; + + beforeEach(() => { + input = { + patience: { cat1: 2, cat2: 3 }, + tolerance: { cat1: 0.01, cat2: 0.02 }, + }; + earlyStopping = new StopOnSEMeasurementPlateau(input); + }); + + it('instantiates with input parameters', () => testInstantiation(earlyStopping, input)); + + it('updates internal state when new measurements are added', () => testInternalState(earlyStopping)); + + it('does not stop on invalid cat name', () => testNoStoppingOnInvalidCatName(earlyStopping)); + + it('stops when the seMeasurement has plateaued', () => { + const cats1: CatMap = { + cat1: { + nItems: 1, + seMeasurement: 0.5, + } 
as Cat, + cat2: { + nItems: 1, + seMeasurement: 0.3, + } as Cat, + }; + + earlyStopping.update(cats1, 'cat1'); + expect(earlyStopping.earlyStop).toBe(false); + + const cats2: CatMap = { + cat1: { + nItems: 2, + seMeasurement: 0.5, + } as Cat, + cat2: { + nItems: 2, + seMeasurement: 0.3, + } as Cat, + }; + + earlyStopping.update(cats2, 'cat1'); + expect(earlyStopping.earlyStop).toBe(true); + }); + + it('does not stop when the seMeasurement has not plateaued', () => { + const cats1: CatMap = { + cat1: { + nItems: 1, + seMeasurement: 0.5, + } as Cat, + cat2: { + nItems: 1, + seMeasurement: 0.3, + } as Cat, + }; + + const cats2: CatMap = { + cat1: { + nItems: 2, + seMeasurement: 0.1, + } as Cat, + cat2: { + nItems: 2, + seMeasurement: 0.3, + } as Cat, + }; + + earlyStopping.update(cats1, 'cat1'); + expect(earlyStopping.earlyStop).toBe(false); + + earlyStopping.update(cats2, 'cat1'); + expect(earlyStopping.earlyStop).toBe(false); + }); + + it('waits for `patience` items to monitor the seMeasurement plateau', () => { + const cats1: CatMap = { + cat1: { + nItems: 1, + seMeasurement: 0.5, + } as Cat, + cat2: { + nItems: 1, + seMeasurement: 0.3, + } as Cat, + }; + + const cats2: CatMap = { + cat1: { + nItems: 2, + seMeasurement: 0.5, + } as Cat, + cat2: { + nItems: 2, + seMeasurement: 0.3, + } as Cat, + }; + + const cats3: CatMap = { + cat1: { + nItems: 3, + seMeasurement: 0.5, + } as Cat, + cat2: { + nItems: 3, + seMeasurement: 0.3, + } as Cat, + }; + + earlyStopping.update(cats1, 'cat2'); + expect(earlyStopping.earlyStop).toBe(false); + + earlyStopping.update(cats2, 'cat2'); + expect(earlyStopping.earlyStop).toBe(false); + + earlyStopping.update(cats3, 'cat2'); + expect(earlyStopping.earlyStop).toBe(true); + }); + + it('triggers early stopping when within tolerance', () => { + const cats1: CatMap = { + cat1: { + nItems: 1, + seMeasurement: 10, + } as Cat, + cat2: { + nItems: 1, + seMeasurement: 0.4, + } as Cat, + }; + + const cats2: CatMap = { + cat1: { + nItems: 2, + 
seMeasurement: 1, + } as Cat, + cat2: { + nItems: 2, + seMeasurement: 0.395, + } as Cat, + }; + + const cats3: CatMap = { + cat1: { + nItems: 3, + seMeasurement: 0.0001, + } as Cat, + cat2: { + nItems: 3, + seMeasurement: 0.39, + } as Cat, + }; + + earlyStopping.update(cats1, 'cat2'); + expect(earlyStopping.earlyStop).toBe(false); + + earlyStopping.update(cats2, 'cat2'); + expect(earlyStopping.earlyStop).toBe(false); + + earlyStopping.update(cats3, 'cat2'); + expect(earlyStopping.earlyStop).toBe(true); + }); +}); + +describe('StopAfterNItems', () => { + let earlyStopping: StopAfterNItems; + let input: EarlyStoppingInput; + + beforeEach(() => { + input = { + requiredItems: { cat1: 2, cat2: 3 }, + }; + earlyStopping = new StopAfterNItems(input); + }); + + it('instantiates with input parameters', () => testInstantiation(earlyStopping, input)); + + it('updates internal state when new measurements are added', () => testInternalState(earlyStopping)); + + it('does not stop on invalid cat name', () => testNoStoppingOnInvalidCatName(earlyStopping)); + + it('does not step when it has not seen required items', () => { + const cats1: CatMap = { + cat1: { + nItems: 1, + seMeasurement: 0.5, + } as Cat, + cat2: { + nItems: 1, + seMeasurement: 0.3, + } as Cat, + }; + + earlyStopping.update(cats1, 'cat2'); + expect(earlyStopping.earlyStop).toBe(false); + + const cats2: CatMap = { + cat1: { + nItems: 2, + seMeasurement: 0.5, + } as Cat, + cat2: { + nItems: 2, + seMeasurement: 0.3, + } as Cat, + }; + + earlyStopping.update(cats2, 'cat2'); + expect(earlyStopping.earlyStop).toBe(false); + }); + + it('stops when it has seen required items', () => { + const cats1: CatMap = { + cat1: { + nItems: 1, + seMeasurement: 0.5, + } as Cat, + cat2: { + nItems: 1, + seMeasurement: 0.3, + } as Cat, + }; + + const cats2: CatMap = { + cat1: { + nItems: 2, + seMeasurement: 0.5, + } as Cat, + cat2: { + nItems: 2, + seMeasurement: 0.3, + } as Cat, + }; + + const cats3: CatMap = { + cat1: { + nItems: 3, + 
seMeasurement: 0.5, + } as Cat, + cat2: { + nItems: 3, + seMeasurement: 0.3, + } as Cat, + }; + + earlyStopping.update(cats1, 'cat2'); + expect(earlyStopping.earlyStop).toBe(false); + + earlyStopping.update(cats2, 'cat2'); + expect(earlyStopping.earlyStop).toBe(false); + + earlyStopping.update(cats3, 'cat2'); + expect(earlyStopping.earlyStop).toBe(true); + }); +}); + +describe('StopIfSEMeasurementBelowThreshold', () => { + let earlyStopping: StopIfSEMeasurementBelowThreshold; + let input: EarlyStoppingInput; + + beforeEach(() => { + input = { + patience: { cat1: 1, cat2: 3 }, + tolerance: { cat1: 0.01, cat2: 0.02 }, + seMeasurementThreshold: { cat1: 0.03, cat2: 0.02 }, + }; + earlyStopping = new StopIfSEMeasurementBelowThreshold(input); + }); + + it('instantiates with input parameters', () => testInstantiation(earlyStopping, input)); + + it('updates internal state when new measurements are added', () => testInternalState(earlyStopping)); + + it('does not stop on invalid cat name', () => testNoStoppingOnInvalidCatName(earlyStopping)); + + it('stops when the seMeasurement has fallen below a threshold', () => { + const cats1: CatMap = { + cat1: { + nItems: 1, + seMeasurement: 0.5, + } as Cat, + cat2: { + nItems: 1, + seMeasurement: 0.3, + } as Cat, + }; + + earlyStopping.update(cats1, 'cat1'); + expect(earlyStopping.earlyStop).toBe(false); + + const cats2: CatMap = { + cat1: { + nItems: 2, + seMeasurement: 0.02, + } as Cat, + cat2: { + nItems: 2, + seMeasurement: 0.3, + } as Cat, + }; + + earlyStopping.update(cats2, 'cat1'); + expect(earlyStopping.earlyStop).toBe(true); + }); + + it('does not stop when the seMeasurement is above threshold', () => { + const cats1: CatMap = { + cat1: { + nItems: 1, + seMeasurement: 0.5, + } as Cat, + cat2: { + nItems: 1, + seMeasurement: 0.3, + } as Cat, + }; + + const cats2: CatMap = { + cat1: { + nItems: 2, + seMeasurement: 0.1, + } as Cat, + cat2: { + nItems: 2, + seMeasurement: 0.3, + } as Cat, + }; + + earlyStopping.update(cats1, 
'cat1'); + expect(earlyStopping.earlyStop).toBe(false); + + earlyStopping.update(cats2, 'cat1'); + expect(earlyStopping.earlyStop).toBe(false); + }); + + it('waits for `patience` items to monitor the seMeasurement plateau', () => { + const cats1: CatMap = { + cat1: { + nItems: 1, + seMeasurement: 0.5, + } as Cat, + cat2: { + nItems: 1, + seMeasurement: 0.3, + } as Cat, + }; + + const cats2: CatMap = { + cat1: { + nItems: 2, + seMeasurement: 0.5, + } as Cat, + cat2: { + nItems: 2, + seMeasurement: 0.01, + } as Cat, + }; + + const cats3: CatMap = { + cat1: { + nItems: 3, + seMeasurement: 0.5, + } as Cat, + cat2: { + nItems: 3, + seMeasurement: 0.01, + } as Cat, + }; + + const cats4: CatMap = { + cat1: { + nItems: 4, + seMeasurement: 0.5, + } as Cat, + cat2: { + nItems: 4, + seMeasurement: 0.01, + } as Cat, + }; + + earlyStopping.update(cats1, 'cat2'); + expect(earlyStopping.earlyStop).toBe(false); + + earlyStopping.update(cats2, 'cat2'); + expect(earlyStopping.earlyStop).toBe(false); + + earlyStopping.update(cats3, 'cat2'); + expect(earlyStopping.earlyStop).toBe(false); + + earlyStopping.update(cats4, 'cat2'); + expect(earlyStopping.earlyStop).toBe(true); + }); + + it('triggers early stopping when within tolerance', () => { + const cats1: CatMap = { + cat1: { + nItems: 1, + seMeasurement: 10, + } as Cat, + cat2: { + nItems: 1, + seMeasurement: 0.4, + } as Cat, + }; + + const cats2: CatMap = { + cat1: { + nItems: 2, + seMeasurement: 1, + } as Cat, + cat2: { + nItems: 2, + seMeasurement: 0.02, + } as Cat, + }; + + const cats3: CatMap = { + cat1: { + nItems: 3, + seMeasurement: 0.0001, + } as Cat, + cat2: { + nItems: 3, + seMeasurement: 0.04, + } as Cat, + }; + + const cats4: CatMap = { + cat1: { + nItems: 4, + seMeasurement: 0.5, + } as Cat, + cat2: { + nItems: 4, + seMeasurement: 0.01, + } as Cat, + }; + + earlyStopping.update(cats1, 'cat2'); + expect(earlyStopping.earlyStop).toBe(false); + + earlyStopping.update(cats2, 'cat2'); + 
expect(earlyStopping.earlyStop).toBe(false); + + earlyStopping.update(cats3, 'cat2'); + expect(earlyStopping.earlyStop).toBe(false); + + earlyStopping.update(cats4, 'cat2'); + expect(earlyStopping.earlyStop).toBe(true); + }); +}); diff --git a/src/stopping.ts b/src/stopping.ts index 254748d..14aaeb0 100644 --- a/src/stopping.ts +++ b/src/stopping.ts @@ -6,13 +6,13 @@ import { CatMap } from './type'; */ export interface EarlyStoppingInput { /** Number of items to wait for before triggering early stopping */ - patience: CatMap; + patience?: CatMap; /** Tolerance for standard error of measurement drop */ - tolerance: CatMap; + tolerance?: CatMap; /** Number of items to require before stopping */ - requiredItems: CatMap; + requiredItems?: CatMap; /** Stop if the standard error of measurement drops below this level */ - seMeasurementThreshold: CatMap; + seMeasurementThreshold?: CatMap; } /** @@ -27,7 +27,7 @@ export abstract class EarlyStopping { protected _nItems: CatMap; protected _seMeasurements: CatMap; - constructor({ patience, tolerance, requiredItems, seMeasurementThreshold }: EarlyStoppingInput) { + constructor({ patience = {}, tolerance = {}, requiredItems = {}, seMeasurementThreshold = {} }: EarlyStoppingInput) { this._patience = patience; this._tolerance = tolerance; this._requiredItems = requiredItems; @@ -57,6 +57,14 @@ export abstract class EarlyStopping { return this._earlyStop; } + public get nItems() { + return this._nItems; + } + + public get seMeasurements() { + return this._seMeasurements; + } + /** * Update the internal state of the early stopping strategy based on the provided cats. * @param {CatMap}cats - A map of cats to update. 
@@ -89,9 +97,11 @@ export class StopOnSEMeasurementPlateau extends EarlyStopping { public update(cats: CatMap, catToEvaluate: string) { super._updateCats(cats); - const seMeasurements = this._seMeasurements[catToEvaluate]; - const patience = this._patience[catToEvaluate]; - const tolerance = this._tolerance[catToEvaluate]; + const seMeasurements = this._seMeasurements[catToEvaluate] ?? []; + + // Use MAX_SAFE_INTEGER and MAX_VALUE to prevent early stopping if the `catToEvaluate` is missing from the cats map. + const patience = this._patience[catToEvaluate] ?? Number.MAX_SAFE_INTEGER; + const tolerance = this._tolerance[catToEvaluate] ?? Number.MAX_VALUE; if (seMeasurements.length >= patience) { const mean = seMeasurements.slice(-patience).reduce((sum, se) => sum + se, 0) / patience; @@ -111,8 +121,8 @@ export class StopAfterNItems extends EarlyStopping { public update(cats: CatMap, catToEvaluate: string) { super._updateCats(cats); - const requiredItems = this._requiredItems[catToEvaluate]; - const nItems = this._nItems[catToEvaluate]; + const requiredItems = this._requiredItems[catToEvaluate] ?? Number.MAX_SAFE_INTEGER; + const nItems = this._nItems[catToEvaluate] ?? 0; if (nItems >= requiredItems) { this._earlyStop = true; @@ -127,10 +137,10 @@ export class StopIfSEMeasurementBelowThreshold extends EarlyStopping { public update(cats: CatMap, catToEvaluate: string) { super._updateCats(cats); - const seMeasurements = this._seMeasurements[catToEvaluate]; - const seThreshold = this._seMeasurementThreshold[catToEvaluate]; - const patience = this._patience[catToEvaluate]; - const tolerance = this._tolerance[catToEvaluate]; + const seMeasurements = this._seMeasurements[catToEvaluate] ?? []; + const seThreshold = this._seMeasurementThreshold[catToEvaluate] ?? 0; + const patience = this._patience[catToEvaluate] ?? 1; + const tolerance = this._tolerance[catToEvaluate] ?? 
0; if (seMeasurements.length >= patience) { const withinTolerance = seMeasurements.slice(-patience).every((se) => Math.abs(se - seThreshold) <= tolerance); From fb30ac5fc1c3e008ebe8abd1540d8f40972b3caf Mon Sep 17 00:00:00 2001 From: Adam Richie-Halford Date: Mon, 23 Sep 2024 06:31:20 -0700 Subject: [PATCH 22/47] Add earlyStopping input to Clowder --- src/clowder.ts | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/clowder.ts b/src/clowder.ts index 464092c..7c38af9 100644 --- a/src/clowder.ts +++ b/src/clowder.ts @@ -9,6 +9,7 @@ import _omit from 'lodash/omit'; import _unzip from 'lodash/unzip'; import _zip from 'lodash/zip'; import seedrandom from 'seedrandom'; +import { EarlyStopping } from './stopping'; export interface ClowderInput { /** @@ -24,6 +25,10 @@ export interface ClowderInput { * A random seed for reproducibility. If not provided, a random seed will be generated. */ randomSeed?: string | null; + /** + * An optional EarlyStopping instance to use for early stopping. + */ + earlyStopping?: EarlyStopping; } /** @@ -39,6 +44,7 @@ export class Clowder { private _corpus: MultiZetaStimulus[]; private _remainingItems: MultiZetaStimulus[]; private _seenItems: Stimulus[]; + private _earlyStopping?: EarlyStopping; private readonly _rng: ReturnType; /** @@ -50,7 +56,7 @@ export class Clowder { * * @throws {Error} - Throws an error if any item in the corpus has duplicated IRT parameters for any Cat name. */ - constructor({ cats, corpus, randomSeed = null }: ClowderInput) { + constructor({ cats, corpus, randomSeed = null, earlyStopping }: ClowderInput) { // TODO: Need to pass in numItemsRequired so that we know when to stop // providing new items. This may depend on the cat name. For instance, // perhaps numItemsRequired should be an object with cat names as keys and @@ -61,6 +67,7 @@ export class Clowder { this._corpus = corpus; this._remainingItems = _cloneDeep(corpus); this._rng = randomSeed === null ? 
seedrandom() : seedrandom(randomSeed); + this._earlyStopping = earlyStopping; } /** From 03d08241dc38da4f3d750e50c2ea8898e03379f4 Mon Sep 17 00:00:00 2001 From: emily-ejag Date: Mon, 23 Sep 2024 10:59:49 -0700 Subject: [PATCH 23/47] prepareClowderCorpus --- src/__tests__/corpus.test.ts | 107 +++++++++++++++++++++++++++++++++++ src/corpus.ts | 58 ++++++++++++++++++- 2 files changed, 164 insertions(+), 1 deletion(-) diff --git a/src/__tests__/corpus.test.ts b/src/__tests__/corpus.test.ts index e057937..e01f6a6 100644 --- a/src/__tests__/corpus.test.ts +++ b/src/__tests__/corpus.test.ts @@ -7,6 +7,7 @@ import { convertZeta, checkNoDuplicateCatNames, filterItemsByCatParameterAvailability, + prepareClowderCorpus, } from '../corpus'; import _omit from 'lodash/omit'; @@ -321,4 +322,110 @@ describe('filterItemsByCatParameterAvailability', () => { expect(result.missing.length).toBe(1); expect(result.missing[0].stimulus).toBe('Item 2'); }); + + describe('prepareClowderCorpus', () => { + it('converts a Stimulus array to a MultiZetaStimulus array with symbolic format', () => { + const items: Stimulus[] = [ + { + 'cat1.a': 1, + 'cat1.b': 2, + 'cat1.c': 3, + 'cat1.d': 4, + 'foo.a': 5, + 'foo.b': 6, + 'foo.c': 7, + 'foo.d': 8, + stimulus: 'stim0', + type: 'jspsychHtmlMultiResponse', + }, + ]; + + const result = prepareClowderCorpus(items, ['cat1', 'foo'], '.'); + + expect(result).toEqual([ + { + stimulus: 'stim0', + type: 'jspsychHtmlMultiResponse', + zetas: [ + { + cats: ['cat1'], + zeta: { a: 1, b: 2, c: 3, d: 4 }, + }, + { + cats: ['foo'], + zeta: { a: 5, b: 6, c: 7, d: 8 }, + }, + ], + }, + ]); + }); + + it('converts a Stimulus array to a MultiZetaStimulus array with semantic format', () => { + const items: Stimulus[] = [ + { + 'cat1.a': 1, + 'cat1.b': 2, + 'cat1.c': 3, + 'cat1.d': 4, + 'foo.a': 5, + 'foo.b': 6, + 'foo.c': 7, + 'foo.d': 8, + stimulus: 'stim0', + type: 'jspsychHtmlMultiResponse', + }, + ]; + + const result = prepareClowderCorpus(items, ['cat1', 'foo'], '.', 
'semantic'); + + expect(result).toEqual([ + { + stimulus: 'stim0', + type: 'jspsychHtmlMultiResponse', + zetas: [ + { + cats: ['cat1'], + zeta: { discrimination: 1, difficulty: 2, guessing: 3, slipping: 4 }, + }, + { + cats: ['foo'], + zeta: { discrimination: 5, difficulty: 6, guessing: 7, slipping: 8 }, + }, + ], + }, + ]); + }); + + it('handles cases with different delimiters', () => { + const items: Stimulus[] = [ + { + cat1_a: 1, + cat1_b: 2, + foo_a: 5, + foo_b: 6, + stimulus: 'stim1', + type: 'jspsychHtmlMultiResponse', + }, + ]; + + const result = prepareClowderCorpus(items, ['cat1', 'foo'], '_', 'symbolic'); + + expect(result).toEqual([ + { + stimulus: 'stim1', + type: 'jspsychHtmlMultiResponse', + zetas: [ + { + cats: ['cat1'], + zeta: { a: 1, b: 2 }, + }, + { + cats: ['foo'], + zeta: { a: 5, b: 6 }, + }, + ], + }, + ]); + }); + }); }); diff --git a/src/corpus.ts b/src/corpus.ts index 2df7aa7..b1ac0d3 100644 --- a/src/corpus.ts +++ b/src/corpus.ts @@ -1,10 +1,11 @@ /* eslint-disable @typescript-eslint/no-non-null-assertion */ -import { MultiZetaStimulus, Zeta } from './type'; +import { MultiZetaStimulus, Stimulus, Zeta } from './type'; import _flatten from 'lodash/flatten'; import _invert from 'lodash/invert'; import _mapKeys from 'lodash/mapKeys'; import _union from 'lodash/union'; import _uniq from 'lodash/uniq'; +import _omit from 'lodash/omit'; /** * A constant map from the symbolic item parameter names to their semantic @@ -244,3 +245,58 @@ export const filterItemsByCatParameterAvailability = (items: MultiZetaStimulus[] missing: paramsMissing, }; }; + +/** + * Converts an array of Stimulus objects into an array of MultiZetaStimulus objects. + * The user specifies cat names and a delimiter to identify and group parameters. + * + * @param {Stimulus[]} items - An array of stimuli, where each stimulus contains parameters + * for different CAT instances. + * @param {string[]} catNames - A list of CAT names to be mapped to their corresponding zeta values. 
+ * @param {string} delimiter - A delimiter used to separate CAT instance names from the parameter keys in the stimulus object. + * @param {'symbolic' | 'semantic'} itemParameterFormat - Defines the format to convert zeta values ('symbolic' or 'semantic'). + * @returns {MultiZetaStimulus[]} - An array of MultiZetaStimulus objects, each containing + * the cleaned stimulus and associated zeta values for each CAT instance. + * + * This function iterates through each stimulus, extracts parameters relevant to the specified + * CAT instances, converts them to the desired format, and returns a cleaned structure of stimuli + * with the associated zeta values. + */ +export const prepareClowderCorpus = ( + items: Stimulus[], + catNames: string[], + delimiter: '.' | string, + itemParameterFormat: 'symbolic' | 'semantic' = 'symbolic', +): MultiZetaStimulus[] => { + return items.map((item) => { + const zetas = catNames + .map((cat) => { + const zeta: Zeta = {}; + + // Extract parameters that match the category + Object.keys(item).forEach((key) => { + if (key.startsWith(cat + delimiter)) { + const paramKey = key.split(delimiter)[1]; + zeta[paramKey as keyof Zeta] = item[key]; + } + }); + + return { + cats: [cat], + zeta: convertZeta(zeta, itemParameterFormat), + }; + }) + .filter((zeta) => zeta !== null); // ask if --- Filter null values + + // Create the MultiZetaStimulus structure without the category keys + const cleanItem = _omit( + item, + Object.keys(item).filter((key) => catNames.some((cat) => key.startsWith(cat + delimiter))), + ); + + return { + ...cleanItem, + zetas, + }; + }); +}; From 7a527cae8d22e1c7be5df0d7a4a90a24e0ba52d7 Mon Sep 17 00:00:00 2001 From: Adam Richie-Halford Date: Mon, 23 Sep 2024 11:10:56 -0700 Subject: [PATCH 24/47] WIP: Add logicalOperation and tests --- src/__tests__/stopping.test.ts | 912 +++++++++++++++++++-------------- src/stopping.ts | 81 ++- 2 files changed, 605 insertions(+), 388 deletions(-) diff --git a/src/__tests__/stopping.test.ts 
b/src/__tests__/stopping.test.ts index abfdd69..b108bc3 100644 --- a/src/__tests__/stopping.test.ts +++ b/src/__tests__/stopping.test.ts @@ -15,62 +15,52 @@ const testInstantiation = (earlyStopping: EarlyStopping, input: EarlyStoppingInp expect(earlyStopping.tolerance).toEqual(input.tolerance ?? {}); expect(earlyStopping.requiredItems).toEqual(input.requiredItems ?? {}); expect(earlyStopping.seMeasurementThreshold).toEqual(input.seMeasurementThreshold ?? {}); + expect(earlyStopping.logicalOperation).toBe(input.logicalOperation?.toLowerCase() ?? 'or'); expect(earlyStopping.earlyStop).toBeBoolean(); }; const testInternalState = (earlyStopping: EarlyStopping) => { - const cats1: CatMap = { - cat1: { - nItems: 1, - seMeasurement: 0.5, - } as Cat, - cat2: { - nItems: 1, - seMeasurement: 0.3, - } as Cat, - }; - - earlyStopping.update(cats1, 'cat1'); + const updates: CatMap[] = [ + { + cat1: { + nItems: 1, + seMeasurement: 0.5, + } as Cat, + cat2: { + nItems: 1, + seMeasurement: 0.3, + } as Cat, + }, + { + cat1: { + nItems: 2, + seMeasurement: 0.5, + } as Cat, + cat2: { + nItems: 2, + seMeasurement: 0.3, + } as Cat, + }, + ]; + + earlyStopping.update(updates[0]); expect(earlyStopping.nItems.cat1).toBe(1); expect(earlyStopping.seMeasurements.cat1).toEqual([0.5]); expect(earlyStopping.nItems.cat2).toBe(1); expect(earlyStopping.seMeasurements.cat2).toEqual([0.3]); - const cats2: CatMap = { - cat1: { - nItems: 2, - seMeasurement: 0.5, - } as Cat, - cat2: { - nItems: 2, - seMeasurement: 0.3, - } as Cat, - }; - - earlyStopping.update(cats2, 'cat1'); + earlyStopping.update(updates[1]); expect(earlyStopping.nItems.cat1).toBe(2); expect(earlyStopping.seMeasurements.cat1).toEqual([0.5, 0.5]); expect(earlyStopping.nItems.cat2).toBe(2); expect(earlyStopping.seMeasurements.cat2).toEqual([0.3, 0.3]); }; -const testNoStoppingOnInvalidCatName = (earlyStopping: EarlyStopping) => { - const cats1: CatMap = { - cat1: { - nItems: 1, - seMeasurement: 0.5, - } as Cat, - cat2: { - nItems: 1, - 
seMeasurement: 0.3, - } as Cat, - }; - - earlyStopping.update(cats1, 'invalidCatName'); - expect(earlyStopping.earlyStop).toBe(false); -}; - -describe('StopOnSEMeasurementPlateau', () => { +describe.each` + logicalOperation + ${'and'} + ${'or'} +`("StopOnSEMeasurementPlateau (with logicalOperation='$logicalOperation'", ({ logicalOperation }) => { let earlyStopping: StopOnSEMeasurementPlateau; let input: EarlyStoppingInput; @@ -78,6 +68,7 @@ describe('StopOnSEMeasurementPlateau', () => { input = { patience: { cat1: 2, cat2: 3 }, tolerance: { cat1: 0.01, cat2: 0.02 }, + logicalOperation, }; earlyStopping = new StopOnSEMeasurementPlateau(input); }); @@ -86,164 +77,205 @@ describe('StopOnSEMeasurementPlateau', () => { it('updates internal state when new measurements are added', () => testInternalState(earlyStopping)); - it('does not stop on invalid cat name', () => testNoStoppingOnInvalidCatName(earlyStopping)); - it('stops when the seMeasurement has plateaued', () => { - const cats1: CatMap = { - cat1: { - nItems: 1, - seMeasurement: 0.5, - } as Cat, - cat2: { - nItems: 1, - seMeasurement: 0.3, - } as Cat, - }; - - earlyStopping.update(cats1, 'cat1'); + const updates: CatMap[] = [ + { + cat1: { + nItems: 1, + seMeasurement: 0.5, + } as Cat, + cat2: { + nItems: 1, + seMeasurement: 0.3, + } as Cat, + }, + { + // cat1 should trigger stopping if logicalOperator === 'or', because + // seMeasurement plateaued over the patience period of 2 items + cat1: { + nItems: 2, + seMeasurement: 0.5, + } as Cat, + cat2: { + nItems: 2, + seMeasurement: 0.3, + } as Cat, + }, + { + cat1: { + nItems: 3, + seMeasurement: 0.5, + } as Cat, + // cat2 should trigger stopping if logicalOperator === 'and', because + // seMeasurement plateaued over the patience period of 3 items, and the + // cat1 criterion passed last update + cat2: { + nItems: 3, + seMeasurement: 0.3, + } as Cat, + }, + ]; + + earlyStopping.update(updates[0]); expect(earlyStopping.earlyStop).toBe(false); - const cats2: CatMap = 
{ - cat1: { - nItems: 2, - seMeasurement: 0.5, - } as Cat, - cat2: { - nItems: 2, - seMeasurement: 0.3, - } as Cat, - }; + if (earlyStopping.logicalOperation === 'or') { + earlyStopping.update(updates[1]); + expect(earlyStopping.earlyStop).toBe(true); + } else { + earlyStopping.update(updates[1]); + expect(earlyStopping.earlyStop).toBe(false); - earlyStopping.update(cats2, 'cat1'); - expect(earlyStopping.earlyStop).toBe(true); + earlyStopping.update(updates[2]); + expect(earlyStopping.earlyStop).toBe(true); + } }); it('does not stop when the seMeasurement has not plateaued', () => { - const cats1: CatMap = { - cat1: { - nItems: 1, - seMeasurement: 0.5, - } as Cat, - cat2: { - nItems: 1, - seMeasurement: 0.3, - } as Cat, - }; - - const cats2: CatMap = { - cat1: { - nItems: 2, - seMeasurement: 0.1, - } as Cat, - cat2: { - nItems: 2, - seMeasurement: 0.3, - } as Cat, - }; + const updates: CatMap[] = [ + { + cat1: { + nItems: 1, + seMeasurement: 100, + } as Cat, + cat2: { + nItems: 1, + seMeasurement: 100, + } as Cat, + }, + { + cat1: { + nItems: 2, + seMeasurement: 10, + } as Cat, + cat2: { + nItems: 2, + seMeasurement: 10, + } as Cat, + }, + { + cat1: { + nItems: 3, + seMeasurement: 1, + } as Cat, + cat2: { + nItems: 3, + seMeasurement: 1, + } as Cat, + }, + ]; + + earlyStopping.update(updates[0]); + expect(earlyStopping.earlyStop).toBe(false); - earlyStopping.update(cats1, 'cat1'); + earlyStopping.update(updates[1]); expect(earlyStopping.earlyStop).toBe(false); - earlyStopping.update(cats2, 'cat1'); + earlyStopping.update(updates[2]); expect(earlyStopping.earlyStop).toBe(false); }); it('waits for `patience` items to monitor the seMeasurement plateau', () => { - const cats1: CatMap = { - cat1: { - nItems: 1, - seMeasurement: 0.5, - } as Cat, - cat2: { - nItems: 1, - seMeasurement: 0.3, - } as Cat, - }; - - const cats2: CatMap = { - cat1: { - nItems: 2, - seMeasurement: 0.5, - } as Cat, - cat2: { - nItems: 2, - seMeasurement: 0.3, - } as Cat, - }; - - const cats3: 
CatMap = { - cat1: { - nItems: 3, - seMeasurement: 0.5, - } as Cat, - cat2: { - nItems: 3, - seMeasurement: 0.3, - } as Cat, - }; - - earlyStopping.update(cats1, 'cat2'); + const updates: CatMap[] = [ + { + cat1: { + nItems: 1, + seMeasurement: 100, + } as Cat, + cat2: { + nItems: 1, + seMeasurement: 0.3, + } as Cat, + }, + { + cat1: { + nItems: 2, + seMeasurement: 10, + } as Cat, + cat2: { + nItems: 2, + seMeasurement: 0.3, + } as Cat, + }, + { + cat1: { + nItems: 3, + seMeasurement: 10, + } as Cat, + cat2: { + nItems: 3, + seMeasurement: 0.3, + } as Cat, + }, + ]; + + earlyStopping.update(updates[0]); expect(earlyStopping.earlyStop).toBe(false); - earlyStopping.update(cats2, 'cat2'); + earlyStopping.update(updates[1]); expect(earlyStopping.earlyStop).toBe(false); - earlyStopping.update(cats3, 'cat2'); + earlyStopping.update(updates[2]); expect(earlyStopping.earlyStop).toBe(true); }); it('triggers early stopping when within tolerance', () => { - const cats1: CatMap = { - cat1: { - nItems: 1, - seMeasurement: 10, - } as Cat, - cat2: { - nItems: 1, - seMeasurement: 0.4, - } as Cat, - }; - - const cats2: CatMap = { - cat1: { - nItems: 2, - seMeasurement: 1, - } as Cat, - cat2: { - nItems: 2, - seMeasurement: 0.395, - } as Cat, - }; - - const cats3: CatMap = { - cat1: { - nItems: 3, - seMeasurement: 0.0001, - } as Cat, - cat2: { - nItems: 3, - seMeasurement: 0.39, - } as Cat, - }; - - earlyStopping.update(cats1, 'cat2'); + const updates: CatMap[] = [ + { + cat1: { + nItems: 1, + seMeasurement: 10, + } as Cat, + cat2: { + nItems: 1, + seMeasurement: 0.4, + } as Cat, + }, + { + cat1: { + nItems: 2, + seMeasurement: 1, + } as Cat, + cat2: { + nItems: 2, + seMeasurement: 0.395, + } as Cat, + }, + { + cat1: { + nItems: 3, + seMeasurement: 0.99, + } as Cat, + cat2: { + nItems: 3, + seMeasurement: 0.39, + } as Cat, + }, + ]; + + earlyStopping.update(updates[0]); expect(earlyStopping.earlyStop).toBe(false); - earlyStopping.update(cats2, 'cat2'); + 
earlyStopping.update(updates[1]); expect(earlyStopping.earlyStop).toBe(false); - earlyStopping.update(cats3, 'cat2'); + earlyStopping.update(updates[2]); expect(earlyStopping.earlyStop).toBe(true); }); }); -describe('StopAfterNItems', () => { +describe.each` + logicalOperation + ${'and'} + ${'or'} +`("StopAfterNItems (with logicalOperation='$logicalOperation'", ({ logicalOperation }) => { let earlyStopping: StopAfterNItems; let input: EarlyStoppingInput; beforeEach(() => { input = { requiredItems: { cat1: 2, cat2: 3 }, + logicalOperation, }; earlyStopping = new StopAfterNItems(input); }); @@ -252,84 +284,145 @@ describe('StopAfterNItems', () => { it('updates internal state when new measurements are added', () => testInternalState(earlyStopping)); - it('does not stop on invalid cat name', () => testNoStoppingOnInvalidCatName(earlyStopping)); - it('does not step when it has not seen required items', () => { - const cats1: CatMap = { - cat1: { - nItems: 1, - seMeasurement: 0.5, - } as Cat, - cat2: { - nItems: 1, - seMeasurement: 0.3, - } as Cat, - }; - - earlyStopping.update(cats1, 'cat2'); + const updates: CatMap[] = [ + { + cat1: { + nItems: 1, + seMeasurement: 0.5, + } as Cat, + cat2: { + nItems: 1, + seMeasurement: 0.3, + } as Cat, + }, + { + cat1: { + // Do not increment nItems for cat1 + nItems: 1, + seMeasurement: 0.5, + } as Cat, + cat2: { + nItems: 2, + seMeasurement: 0.3, + } as Cat, + }, + { + cat1: { + // Do not increment nItems for cat1 + nItems: 1, + seMeasurement: 0.5, + } as Cat, + cat2: { + // Do not increment nItems for cat2 + nItems: 2, + seMeasurement: 0.3, + } as Cat, + }, + { + cat1: { + // Increment nItems for cat1, but only use this update if + // logicalOperation is 'and'. Early stopping should still not be + // triggered. 
+ nItems: 2, + seMeasurement: 0.5, + } as Cat, + cat2: { + // Do not increment nItems for cat2 + nItems: 2, + seMeasurement: 0.3, + } as Cat, + }, + ]; + + earlyStopping.update(updates[0]); expect(earlyStopping.earlyStop).toBe(false); - const cats2: CatMap = { - cat1: { - nItems: 2, - seMeasurement: 0.5, - } as Cat, - cat2: { - nItems: 2, - seMeasurement: 0.3, - } as Cat, - }; + earlyStopping.update(updates[1]); + expect(earlyStopping.earlyStop).toBe(false); - earlyStopping.update(cats2, 'cat2'); + earlyStopping.update(updates[2]); expect(earlyStopping.earlyStop).toBe(false); + + if (earlyStopping.logicalOperation === 'and') { + earlyStopping.update(updates[3]); + expect(earlyStopping.earlyStop).toBe(false); + } }); it('stops when it has seen required items', () => { - const cats1: CatMap = { - cat1: { - nItems: 1, - seMeasurement: 0.5, - } as Cat, - cat2: { - nItems: 1, - seMeasurement: 0.3, - } as Cat, - }; - - const cats2: CatMap = { - cat1: { - nItems: 2, - seMeasurement: 0.5, - } as Cat, - cat2: { - nItems: 2, - seMeasurement: 0.3, - } as Cat, - }; - - const cats3: CatMap = { - cat1: { - nItems: 3, - seMeasurement: 0.5, - } as Cat, - cat2: { - nItems: 3, - seMeasurement: 0.3, - } as Cat, - }; - - earlyStopping.update(cats1, 'cat2'); + const updates: CatMap[] = [ + { + cat1: { + nItems: 1, + seMeasurement: 0.5, + } as Cat, + cat2: { + nItems: 1, + seMeasurement: 0.3, + } as Cat, + }, + { + cat1: { + // Do not increment nItems for cat1 + nItems: 1, + seMeasurement: 0.5, + } as Cat, + cat2: { + nItems: 2, + seMeasurement: 0.3, + } as Cat, + }, + { + cat1: { + // Do not increment nItems for cat1 + nItems: 1, + seMeasurement: 0.5, + } as Cat, + cat2: { + // Cat2 reaches required items + nItems: 3, + seMeasurement: 0.3, + } as Cat, + }, + { + cat1: { + // Cat1 reaches required items + nItems: 2, + seMeasurement: 0.5, + } as Cat, + cat2: { + // Cat2 reaches required items + nItems: 3, + seMeasurement: 0.3, + } as Cat, + }, + ]; + + earlyStopping.update(updates[0]); 
expect(earlyStopping.earlyStop).toBe(false); - earlyStopping.update(cats2, 'cat2'); + earlyStopping.update(updates[1]); expect(earlyStopping.earlyStop).toBe(false); - earlyStopping.update(cats3, 'cat2'); - expect(earlyStopping.earlyStop).toBe(true); + if (earlyStopping.logicalOperation === 'or') { + earlyStopping.update(updates[2]); + expect(earlyStopping.earlyStop).toBe(true); + } else { + earlyStopping.update(updates[2]); + expect(earlyStopping.earlyStop).toBe(false); + + earlyStopping.update(updates[3]); + expect(earlyStopping.earlyStop).toBe(true); + } }); }); -describe('StopIfSEMeasurementBelowThreshold', () => { +describe.each` + logicalOperation + ${'and'} + ${'or'} +`("StopIfSEMeasurementBelowThreshold (with logicalOperation='$logicalOperation'", ({ logicalOperation }) => { let earlyStopping: StopIfSEMeasurementBelowThreshold; let input: EarlyStoppingInput; @@ -338,6 +431,7 @@ describe('StopIfSEMeasurementBelowThreshold', () => { patience: { cat1: 1, cat2: 3 }, tolerance: { cat1: 0.01, cat2: 0.02 }, seMeasurementThreshold: { cat1: 0.03, cat2: 0.02 }, + logicalOperation, }; earlyStopping = new StopIfSEMeasurementBelowThreshold(input); }); @@ -346,181 +440,257 @@ describe('StopIfSEMeasurementBelowThreshold', () => { it('updates internal state when new measurements are added', () => testInternalState(earlyStopping)); - it('does not stop on invalid cat name', () => testNoStoppingOnInvalidCatName(earlyStopping)); - it('stops when the seMeasurement has fallen below a threshold', () => { - const cats1: CatMap = { - cat1: { - nItems: 1, - seMeasurement: 0.5, - } as Cat, - cat2: { - nItems: 1, - seMeasurement: 0.3, - } as Cat, - }; - - earlyStopping.update(cats1, 'cat1'); + const updates: CatMap[] = [ + { + cat1: { + nItems: 1, + seMeasurement: 0.5, + } as Cat, + cat2: { + nItems: 1, + seMeasurement: 0.3, + } as Cat, + }, + { + cat1: { + nItems: 2, + seMeasurement: 0.02, + } as Cat, + cat2: { + nItems: 2, + seMeasurement: 0.02, + } as Cat, + }, + { + cat1: { + 
nItems: 3, + seMeasurement: 0.02, + } as Cat, + cat2: { + nItems: 3, + seMeasurement: 0.02, + } as Cat, + }, + { + cat1: { + nItems: 4, + seMeasurement: 0.02, + } as Cat, + cat2: { + nItems: 4, + seMeasurement: 0.02, + } as Cat, + }, + ]; + + earlyStopping.update(updates[0]); expect(earlyStopping.earlyStop).toBe(false); - const cats2: CatMap = { - cat1: { - nItems: 2, - seMeasurement: 0.02, - } as Cat, - cat2: { - nItems: 2, - seMeasurement: 0.3, - } as Cat, - }; + if (earlyStopping.logicalOperation === 'or') { + earlyStopping.update(updates[1]); + expect(earlyStopping.earlyStop).toBe(true); + } else { + earlyStopping.update(updates[1]); + expect(earlyStopping.earlyStop).toBe(false); - earlyStopping.update(cats2, 'cat1'); - expect(earlyStopping.earlyStop).toBe(true); + earlyStopping.update(updates[2]); + expect(earlyStopping.earlyStop).toBe(false); + + earlyStopping.update(updates[3]); + expect(earlyStopping.earlyStop).toBe(true); + } }); it('does not stop when the seMeasurement is above threshold', () => { - const cats1: CatMap = { - cat1: { - nItems: 1, - seMeasurement: 0.5, - } as Cat, - cat2: { - nItems: 1, - seMeasurement: 0.3, - } as Cat, - }; + const updates: CatMap[] = [ + { + cat1: { + nItems: 1, + seMeasurement: 0.5, + } as Cat, + cat2: { + nItems: 1, + seMeasurement: 0.3, + } as Cat, + }, + { + cat1: { + nItems: 2, + seMeasurement: 0.1, + } as Cat, + cat2: { + nItems: 2, + seMeasurement: 0.3, + } as Cat, + }, + { + cat1: { + nItems: 3, + seMeasurement: 0.1, + } as Cat, + cat2: { + nItems: 3, + seMeasurement: 0.3, + } as Cat, + }, + { + cat1: { + nItems: 4, + seMeasurement: 0.1, + } as Cat, + cat2: { + nItems: 4, + seMeasurement: 0.3, + } as Cat, + }, + ]; + + earlyStopping.update(updates[0]); + expect(earlyStopping.earlyStop).toBe(false); - const cats2: CatMap = { - cat1: { - nItems: 2, - seMeasurement: 0.1, - } as Cat, - cat2: { - nItems: 2, - seMeasurement: 0.3, - } as Cat, - }; + earlyStopping.update(updates[1]); + 
expect(earlyStopping.earlyStop).toBe(false); - earlyStopping.update(cats1, 'cat1'); + earlyStopping.update(updates[2]); expect(earlyStopping.earlyStop).toBe(false); - earlyStopping.update(cats2, 'cat1'); + earlyStopping.update(updates[3]); expect(earlyStopping.earlyStop).toBe(false); }); it('waits for `patience` items to monitor the seMeasurement plateau', () => { - const cats1: CatMap = { - cat1: { - nItems: 1, - seMeasurement: 0.5, - } as Cat, - cat2: { - nItems: 1, - seMeasurement: 0.3, - } as Cat, - }; - - const cats2: CatMap = { - cat1: { - nItems: 2, - seMeasurement: 0.5, - } as Cat, - cat2: { - nItems: 2, - seMeasurement: 0.01, - } as Cat, - }; - - const cats3: CatMap = { - cat1: { - nItems: 3, - seMeasurement: 0.5, - } as Cat, - cat2: { - nItems: 3, - seMeasurement: 0.01, - } as Cat, - }; - - const cats4: CatMap = { - cat1: { - nItems: 4, - seMeasurement: 0.5, - } as Cat, - cat2: { - nItems: 4, - seMeasurement: 0.01, - } as Cat, - }; - - earlyStopping.update(cats1, 'cat2'); + const updates: CatMap[] = [ + { + cat1: { + nItems: 1, + seMeasurement: 0.5, + } as Cat, + cat2: { + nItems: 1, + seMeasurement: 0.3, + } as Cat, + }, + { + cat1: { + nItems: 2, + seMeasurement: 0.5, + } as Cat, + cat2: { + nItems: 2, + seMeasurement: 0.01, + } as Cat, + }, + { + cat1: { + nItems: 3, + seMeasurement: 0.5, + } as Cat, + cat2: { + nItems: 3, + seMeasurement: 0.01, + } as Cat, + }, + { + cat1: { + nItems: 4, + seMeasurement: 0.5, + } as Cat, + // Cat2 should trigger when logicalOperation is 'or' + cat2: { + nItems: 4, + seMeasurement: 0.01, + } as Cat, + }, + { + // Cat1 should trigger when logicalOperation is 'and' + // Cat2 criterion was satisfied after last update + cat1: { + nItems: 5, + seMeasurement: 0.01, + } as Cat, + cat2: { + nItems: 5, + seMeasurement: 0.01, + } as Cat, + }, + ]; + + earlyStopping.update(updates[0]); expect(earlyStopping.earlyStop).toBe(false); - earlyStopping.update(cats2, 'cat2'); + earlyStopping.update(updates[1]); 
expect(earlyStopping.earlyStop).toBe(false); - earlyStopping.update(cats3, 'cat2'); + earlyStopping.update(updates[2]); expect(earlyStopping.earlyStop).toBe(false); - earlyStopping.update(cats4, 'cat2'); - expect(earlyStopping.earlyStop).toBe(true); + if (earlyStopping.logicalOperation === 'or') { + earlyStopping.update(updates[3]); + expect(earlyStopping.earlyStop).toBe(true); + } else { + earlyStopping.update(updates[3]); + expect(earlyStopping.earlyStop).toBe(false); + + earlyStopping.update(updates[4]); + expect(earlyStopping.earlyStop).toBe(true); + } }); it('triggers early stopping when within tolerance', () => { - const cats1: CatMap = { - cat1: { - nItems: 1, - seMeasurement: 10, - } as Cat, - cat2: { - nItems: 1, - seMeasurement: 0.4, - } as Cat, - }; - - const cats2: CatMap = { - cat1: { - nItems: 2, - seMeasurement: 1, - } as Cat, - cat2: { - nItems: 2, - seMeasurement: 0.02, - } as Cat, - }; - - const cats3: CatMap = { - cat1: { - nItems: 3, - seMeasurement: 0.0001, - } as Cat, - cat2: { - nItems: 3, - seMeasurement: 0.04, - } as Cat, - }; - - const cats4: CatMap = { - cat1: { - nItems: 4, - seMeasurement: 0.5, - } as Cat, - cat2: { - nItems: 4, - seMeasurement: 0.01, - } as Cat, - }; - - earlyStopping.update(cats1, 'cat2'); + const updates: CatMap[] = [ + { + cat1: { + nItems: 1, + seMeasurement: 10, + } as Cat, + cat2: { + nItems: 1, + seMeasurement: 0.4, + } as Cat, + }, + { + cat1: { + nItems: 2, + seMeasurement: 1, + } as Cat, + cat2: { + nItems: 2, + seMeasurement: 0.02, + } as Cat, + }, + { + cat1: { + nItems: 3, + seMeasurement: 0.0001, + } as Cat, + cat2: { + nItems: 3, + seMeasurement: 0.04, + } as Cat, + }, + { + cat1: { + nItems: 4, + seMeasurement: 0.5, + } as Cat, + cat2: { + nItems: 4, + seMeasurement: 0.01, + } as Cat, + }, + ]; + + earlyStopping.update(updates[0]); expect(earlyStopping.earlyStop).toBe(false); - earlyStopping.update(cats2, 'cat2'); + earlyStopping.update(updates[1]); expect(earlyStopping.earlyStop).toBe(false); - 
earlyStopping.update(cats3, 'cat2'); + earlyStopping.update(updates[2]); expect(earlyStopping.earlyStop).toBe(false); - earlyStopping.update(cats4, 'cat2'); + earlyStopping.update(updates[3]); expect(earlyStopping.earlyStop).toBe(true); }); }); diff --git a/src/stopping.ts b/src/stopping.ts index 14aaeb0..27cf804 100644 --- a/src/stopping.ts +++ b/src/stopping.ts @@ -1,5 +1,6 @@ import { Cat } from './cat'; import { CatMap } from './type'; +import _uniq from 'lodash/uniq'; /** * Interface for input parameters to EarlyStopping classes. @@ -13,6 +14,8 @@ export interface EarlyStoppingInput { requiredItems?: CatMap; /** Stop if the standard error of measurement drops below this level */ seMeasurementThreshold?: CatMap; + /** The logical operation to use to evaluate multiple stopping conditions */ + logicalOperation?: 'and' | 'or' | 'AND' | 'OR'; } /** @@ -26,8 +29,17 @@ export abstract class EarlyStopping { protected _seMeasurementThreshold: CatMap; protected _nItems: CatMap; protected _seMeasurements: CatMap; - - constructor({ patience = {}, tolerance = {}, requiredItems = {}, seMeasurementThreshold = {} }: EarlyStoppingInput) { + protected _logicalOperation: 'and' | 'or'; + + constructor({ + patience = {}, + tolerance = {}, + requiredItems = {}, + seMeasurementThreshold = {}, + logicalOperation = 'or', + }: EarlyStoppingInput) { + // TODO: Add some input validation here + // logicalOperation.toLowerCase() should be 'and' or 'or' this._patience = patience; this._tolerance = tolerance; this._requiredItems = requiredItems; @@ -35,6 +47,16 @@ export abstract class EarlyStopping { this._seMeasurements = {}; this._nItems = {}; this._earlyStop = false; + this._logicalOperation = logicalOperation.toLowerCase() as 'and' | 'or'; + } + + public get evaluationCats() { + return _uniq([ + ...Object.keys(this._patience), + ...Object.keys(this._tolerance), + ...Object.keys(this._requiredItems), + ...Object.keys(this._seMeasurementThreshold), + ]); } public get patience() { @@ -65,6 
+87,10 @@ export abstract class EarlyStopping { return this._seMeasurements; } + public get logicalOperation() { + return this._logicalOperation; + } + /** * Update the internal state of the early stopping strategy based on the provided cats. * @param {CatMap}cats - A map of cats to update. @@ -82,35 +108,52 @@ export abstract class EarlyStopping { } } + /** + * Abstract method to be implemented by subclasses to evaluate a single stopping condition. + * @param {string} catToEvaluate - The name of the cat to evaluate for early stopping. + */ + protected abstract _evaluateStoppingCondition(catToEvaluate: string): boolean; + /** * Abstract method to be implemented by subclasses to update the early stopping strategy. * @param {CatMap} cats - A map of cats to update. - * @param {string} catToEvaluate - The name of the cat to evaluate for early stopping. */ - public abstract update(cats: CatMap, catToEvaluate: string): void; + public update(cats: CatMap): void { + this._updateCats(cats); + + const conditions: boolean[] = this.evaluationCats.map((catName) => this._evaluateStoppingCondition(catName)); + + if (this._logicalOperation === 'and') { + this._earlyStop = conditions.every(Boolean); + } else { + this._earlyStop = conditions.some(Boolean); + } + } } /** * Class implementing early stopping based on a plateau in standard error of measurement. */ export class StopOnSEMeasurementPlateau extends EarlyStopping { - public update(cats: CatMap, catToEvaluate: string) { - super._updateCats(cats); - + protected _evaluateStoppingCondition(catToEvaluate: string) { const seMeasurements = this._seMeasurements[catToEvaluate] ?? []; // Use MAX_SAFE_INTEGER and MAX_VALUE to prevent early stopping if the `catToEvaluate` is missing from the cats map. const patience = this._patience[catToEvaluate] ?? Number.MAX_SAFE_INTEGER; - const tolerance = this._tolerance[catToEvaluate] ?? Number.MAX_VALUE; + const tolerance = this._tolerance[catToEvaluate] ?? 
0; + + let earlyStop = false; if (seMeasurements.length >= patience) { const mean = seMeasurements.slice(-patience).reduce((sum, se) => sum + se, 0) / patience; const withinTolerance = seMeasurements.slice(-patience).every((se) => Math.abs(se - mean) <= tolerance); if (withinTolerance) { - this._earlyStop = true; + earlyStop = true; } } + + return earlyStop; } } @@ -118,15 +161,17 @@ export class StopOnSEMeasurementPlateau extends EarlyStopping { * Class implementing early stopping after a certain number of items. */ export class StopAfterNItems extends EarlyStopping { - public update(cats: CatMap, catToEvaluate: string) { - super._updateCats(cats); - + protected _evaluateStoppingCondition(catToEvaluate: string) { const requiredItems = this._requiredItems[catToEvaluate] ?? Number.MAX_SAFE_INTEGER; const nItems = this._nItems[catToEvaluate] ?? 0; + let earlyStop = false; + if (nItems >= requiredItems) { - this._earlyStop = true; + earlyStop = true; } + + return earlyStop; } } @@ -134,20 +179,22 @@ export class StopAfterNItems extends EarlyStopping { * Class implementing early stopping if the standard error of measurement drops below a certain threshold. */ export class StopIfSEMeasurementBelowThreshold extends EarlyStopping { - public update(cats: CatMap, catToEvaluate: string) { - super._updateCats(cats); - + protected _evaluateStoppingCondition(catToEvaluate: string) { const seMeasurements = this._seMeasurements[catToEvaluate] ?? []; const seThreshold = this._seMeasurementThreshold[catToEvaluate] ?? 0; const patience = this._patience[catToEvaluate] ?? 1; const tolerance = this._tolerance[catToEvaluate] ?? 
0; + let earlyStop = false; + if (seMeasurements.length >= patience) { const withinTolerance = seMeasurements.slice(-patience).every((se) => Math.abs(se - seThreshold) <= tolerance); if (withinTolerance) { - this._earlyStop = true; + earlyStop = true; } } + + return earlyStop; } } From 0334ba0ba6f6439074c670706fa7e8d19de62206 Mon Sep 17 00:00:00 2001 From: Adam Richie-Halford Date: Mon, 23 Sep 2024 11:19:49 -0700 Subject: [PATCH 25/47] Use _isEmpty instead of null test in prepareClowderCorpus --- src/__tests__/corpus.test.ts | 204 +++++++++++++++++------------------ src/corpus.ts | 3 +- 2 files changed, 104 insertions(+), 103 deletions(-) diff --git a/src/__tests__/corpus.test.ts b/src/__tests__/corpus.test.ts index e01f6a6..8d5bd16 100644 --- a/src/__tests__/corpus.test.ts +++ b/src/__tests__/corpus.test.ts @@ -322,110 +322,110 @@ describe('filterItemsByCatParameterAvailability', () => { expect(result.missing.length).toBe(1); expect(result.missing[0].stimulus).toBe('Item 2'); }); +}); - describe('prepareClowderCorpus', () => { - it('converts a Stimulus array to a MultiZetaStimulus array with symbolic format', () => { - const items: Stimulus[] = [ - { - 'cat1.a': 1, - 'cat1.b': 2, - 'cat1.c': 3, - 'cat1.d': 4, - 'foo.a': 5, - 'foo.b': 6, - 'foo.c': 7, - 'foo.d': 8, - stimulus: 'stim0', - type: 'jspsychHtmlMultiResponse', - }, - ]; - - const result = prepareClowderCorpus(items, ['cat1', 'foo'], '.'); - - expect(result).toEqual([ - { - stimulus: 'stim0', - type: 'jspsychHtmlMultiResponse', - zetas: [ - { - cats: ['cat1'], - zeta: { a: 1, b: 2, c: 3, d: 4 }, - }, - { - cats: ['foo'], - zeta: { a: 5, b: 6, c: 7, d: 8 }, - }, - ], - }, - ]); - }); +describe('prepareClowderCorpus', () => { + it('converts a Stimulus array to a MultiZetaStimulus array with symbolic format', () => { + const items: Stimulus[] = [ + { + 'cat1.a': 1, + 'cat1.b': 2, + 'cat1.c': 3, + 'cat1.d': 4, + 'foo.a': 5, + 'foo.b': 6, + 'foo.c': 7, + 'foo.d': 8, + stimulus: 'stim0', + type: 
'jspsychHtmlMultiResponse', + }, + ]; - it('converts a Stimulus array to a MultiZetaStimulus array with semantic format', () => { - const items: Stimulus[] = [ - { - 'cat1.a': 1, - 'cat1.b': 2, - 'cat1.c': 3, - 'cat1.d': 4, - 'foo.a': 5, - 'foo.b': 6, - 'foo.c': 7, - 'foo.d': 8, - stimulus: 'stim0', - type: 'jspsychHtmlMultiResponse', - }, - ]; - - const result = prepareClowderCorpus(items, ['cat1', 'foo'], '.', 'semantic'); - - expect(result).toEqual([ - { - stimulus: 'stim0', - type: 'jspsychHtmlMultiResponse', - zetas: [ - { - cats: ['cat1'], - zeta: { discrimination: 1, difficulty: 2, guessing: 3, slipping: 4 }, - }, - { - cats: ['foo'], - zeta: { discrimination: 5, difficulty: 6, guessing: 7, slipping: 8 }, - }, - ], - }, - ]); - }); + const result = prepareClowderCorpus(items, ['cat1', 'foo'], '.'); - it('handles cases with different delimiters', () => { - const items: Stimulus[] = [ - { - cat1_a: 1, - cat1_b: 2, - foo_a: 5, - foo_b: 6, - stimulus: 'stim1', - type: 'jspsychHtmlMultiResponse', - }, - ]; - - const result = prepareClowderCorpus(items, ['cat1', 'foo'], '_', 'symbolic'); - - expect(result).toEqual([ - { - stimulus: 'stim1', - type: 'jspsychHtmlMultiResponse', - zetas: [ - { - cats: ['cat1'], - zeta: { a: 1, b: 2 }, - }, - { - cats: ['foo'], - zeta: { a: 5, b: 6 }, - }, - ], - }, - ]); - }); + expect(result).toEqual([ + { + stimulus: 'stim0', + type: 'jspsychHtmlMultiResponse', + zetas: [ + { + cats: ['cat1'], + zeta: { a: 1, b: 2, c: 3, d: 4 }, + }, + { + cats: ['foo'], + zeta: { a: 5, b: 6, c: 7, d: 8 }, + }, + ], + }, + ]); + }); + + it('converts a Stimulus array to a MultiZetaStimulus array with semantic format', () => { + const items: Stimulus[] = [ + { + 'cat1.a': 1, + 'cat1.b': 2, + 'cat1.c': 3, + 'cat1.d': 4, + 'foo.a': 5, + 'foo.b': 6, + 'foo.c': 7, + 'foo.d': 8, + stimulus: 'stim0', + type: 'jspsychHtmlMultiResponse', + }, + ]; + + const result = prepareClowderCorpus(items, ['cat1', 'foo'], '.', 'semantic'); + + expect(result).toEqual([ + 
{ + stimulus: 'stim0', + type: 'jspsychHtmlMultiResponse', + zetas: [ + { + cats: ['cat1'], + zeta: { discrimination: 1, difficulty: 2, guessing: 3, slipping: 4 }, + }, + { + cats: ['foo'], + zeta: { discrimination: 5, difficulty: 6, guessing: 7, slipping: 8 }, + }, + ], + }, + ]); + }); + + it('handles cases with different delimiters', () => { + const items: Stimulus[] = [ + { + cat1_a: 1, + cat1_b: 2, + foo_a: 5, + foo_b: 6, + stimulus: 'stim1', + type: 'jspsychHtmlMultiResponse', + }, + ]; + + const result = prepareClowderCorpus(items, ['cat1', 'foo'], '_', 'symbolic'); + + expect(result).toEqual([ + { + stimulus: 'stim1', + type: 'jspsychHtmlMultiResponse', + zetas: [ + { + cats: ['cat1'], + zeta: { a: 1, b: 2 }, + }, + { + cats: ['foo'], + zeta: { a: 5, b: 6 }, + }, + ], + }, + ]); }); }); diff --git a/src/corpus.ts b/src/corpus.ts index b1ac0d3..c5231da 100644 --- a/src/corpus.ts +++ b/src/corpus.ts @@ -2,6 +2,7 @@ import { MultiZetaStimulus, Stimulus, Zeta } from './type'; import _flatten from 'lodash/flatten'; import _invert from 'lodash/invert'; +import _isEmpty from 'lodash/isEmpty'; import _mapKeys from 'lodash/mapKeys'; import _union from 'lodash/union'; import _uniq from 'lodash/uniq'; @@ -286,7 +287,7 @@ export const prepareClowderCorpus = ( zeta: convertZeta(zeta, itemParameterFormat), }; }) - .filter((zeta) => zeta !== null); // ask if --- Filter null values + .filter((zeta) => !_isEmpty(zeta)); // ask if --- Filter null values // Create the MultiZetaStimulus structure without the category keys const cleanItem = _omit( From fd254a45d602dab2b6be5ce389f64a42740b6a90 Mon Sep 17 00:00:00 2001 From: Adam Richie-Halford Date: Mon, 23 Sep 2024 11:32:09 -0700 Subject: [PATCH 26/47] Increase test coverage --- src/__tests__/stopping.test.ts | 43 +++++++++++++++++++++++++++++----- src/stopping.ts | 2 +- 2 files changed, 38 insertions(+), 7 deletions(-) diff --git a/src/__tests__/stopping.test.ts b/src/__tests__/stopping.test.ts index b108bc3..2bdc2ae 100644 
--- a/src/__tests__/stopping.test.ts +++ b/src/__tests__/stopping.test.ts @@ -638,8 +638,12 @@ describe.each` }); it('triggers early stopping when within tolerance', () => { + // patience: { cat1: 1, cat2: 3 }, + // tolerance: { cat1: 0.01, cat2: 0.02 }, + // seMeasurementThreshold: { cat1: 0.03, cat2: 0.02 }, const updates: CatMap[] = [ { + // Update 1 should not trigger cat1: { nItems: 1, seMeasurement: 10, @@ -650,31 +654,38 @@ describe.each` } as Cat, }, { + // Update 2 should not trigger cat1: { nItems: 2, seMeasurement: 1, } as Cat, cat2: { nItems: 2, + // Cat 2 is low enough but not enough items to satisfy patience seMeasurement: 0.02, } as Cat, }, { + // Update 3 should trigger for logicalOperation === 'or', but not for 'and' cat1: { nItems: 3, - seMeasurement: 0.0001, + // Cat 1 is low enough and the patience is only 1 + seMeasurement: 0.0399, } as Cat, cat2: { nItems: 3, + // Cat 2 patience is still not satisfied seMeasurement: 0.04, } as Cat, }, { + // Update 4 should trigger for logicalOperation === 'and' cat1: { nItems: 4, - seMeasurement: 0.5, + seMeasurement: 0.001, } as Cat, cat2: { + // SE is low enough and patience is satisfied nItems: 4, seMeasurement: 0.01, } as Cat, @@ -688,9 +699,29 @@ describe.each` expect(earlyStopping.earlyStop).toBe(false); earlyStopping.update(updates[2]); - expect(earlyStopping.earlyStop).toBe(false); - - earlyStopping.update(updates[3]); - expect(earlyStopping.earlyStop).toBe(true); + if (earlyStopping.logicalOperation === 'or') { + expect(earlyStopping.earlyStop).toBe(true); + } else { + expect(earlyStopping.earlyStop).toBe(false); + earlyStopping.update(updates[3]); + expect(earlyStopping.earlyStop).toBe(true); + } }); }); + +// TODO: We need to write some tests where not all cats are in the input for the early stopping instance. 
+// Right now, we have input like +// input = { +// patience: { cat1: 2, cat2: 3 }, +// tolerance: { cat1: 0.01, cat2: 0.02 }, +// logicalOperation, +// }; +// +// But we want input like +// input = { +// patience: { cat1: 2, cat2: 3 }, +// tolerance: { cat2: 0.02, cat3: 0.01 }, +// logicalOperation, +// }; +// +// In these situations, we need good default values to make sure that the tests pass. diff --git a/src/stopping.ts b/src/stopping.ts index 27cf804..88fe1dc 100644 --- a/src/stopping.ts +++ b/src/stopping.ts @@ -188,7 +188,7 @@ export class StopIfSEMeasurementBelowThreshold extends EarlyStopping { let earlyStop = false; if (seMeasurements.length >= patience) { - const withinTolerance = seMeasurements.slice(-patience).every((se) => Math.abs(se - seThreshold) <= tolerance); + const withinTolerance = seMeasurements.slice(-patience).every((se) => se - seThreshold <= tolerance); if (withinTolerance) { earlyStop = true; From 95da412613ce3149d18f0a4ef215576ccdd7ad7d Mon Sep 17 00:00:00 2001 From: Adam Richie-Halford Date: Mon, 23 Sep 2024 11:34:48 -0700 Subject: [PATCH 27/47] Add comments on using early stopping in the Clowder class --- src/clowder.ts | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/clowder.ts b/src/clowder.ts index 7c38af9..114fde5 100644 --- a/src/clowder.ts +++ b/src/clowder.ts @@ -278,6 +278,15 @@ export class Clowder { this.cats[catName].updateAbilityEstimate(zetas, answers, method); } + // TODO: These next two if clauses were not very well thought through by Adam. We should scrutinize and add tests. 
+ if (this._earlyStopping) { + this._earlyStopping.update(this.cats); + } + + if (this._earlyStopping?.earlyStop) { + return undefined; + } + // +----------+ // ----------| Select |----------| // +----------+ From 58ccf4a9696c35c129a9a30892edf2a1500bcd69 Mon Sep 17 00:00:00 2001 From: emily-ejag Date: Mon, 23 Sep 2024 14:32:59 -0700 Subject: [PATCH 28/47] solving TODOS and adding stopping tests --- src/__tests__/clowder.test.ts | 90 +++++++++++++++++- src/__tests__/stopping.test.ts | 167 +++++++++++---------------------- src/clowder.ts | 91 ++++++++---------- src/corpus.ts | 2 +- 4 files changed, 186 insertions(+), 164 deletions(-) diff --git a/src/__tests__/clowder.test.ts b/src/__tests__/clowder.test.ts index 5e24e5e..addc218 100644 --- a/src/__tests__/clowder.test.ts +++ b/src/__tests__/clowder.test.ts @@ -2,6 +2,7 @@ import { Cat, Clowder, ClowderInput } from '..'; import { MultiZetaStimulus, Zeta, ZetaCatMap } from '../type'; import { defaultZeta } from '../corpus'; import _uniq from 'lodash/uniq'; +import { StopAfterNItems, StopIfSEMeasurementBelowThreshold, StopOnSEMeasurementPlateau } from '../stopping'; const createStimulus = (id: string) => ({ ...defaultZeta(), @@ -176,7 +177,7 @@ describe('Clowder Class', () => { catToSelect: 'cat1', randomlySelectUnvalidated: false, }); - expect(nextItem?.id).toBe('0'); + expect(nextItem?.id).toMatch(/^(0|1)$/); }); it('should select an unvalidated item if no validated items remain', () => { @@ -276,4 +277,91 @@ describe('Clowder Class', () => { expect(nextItem).toBeDefined(); expect(clowder.cats.cat1.theta).not.toBe(originalTheta); }); + + it('should update early stopping conditions based on number of items presented', () => { + const earlyStopping = new StopOnSEMeasurementPlateau({ + patience: { cat1: 2 }, // Requires 2 items to check for plateau + tolerance: { cat1: 0.05 }, // SE change tolerance + }); + + const clowder = new Clowder({ + cats: { cat1: { method: 'MLE', theta: 0.5 } }, + corpus: [ + 
createMultiZetaStimulus('0', [createZetaCatMap(['cat1'])]), + createMultiZetaStimulus('1', [createZetaCatMap(['cat1'])]), + ], + earlyStopping, + }); + + clowder.updateCatAndGetNextItem({ + catToSelect: 'cat1', + catsToUpdate: ['cat1'], + items: [clowder.corpus[0]], + answers: [1], + }); + + clowder.updateCatAndGetNextItem({ + catToSelect: 'cat1', + catsToUpdate: ['cat1'], + items: [clowder.corpus[1]], + answers: [1], + }); + + expect(clowder.earlyStopping?.earlyStop).toBe(true); // Should stop after 2 items + }); + + it('should trigger early stopping after required number of items', () => { + const earlyStopping = new StopAfterNItems({ + requiredItems: { cat2: 3 }, // Stop after 3 items for cat2 + }); + + const clowder = new Clowder({ + cats: { cat2: { method: 'EAP', theta: -1.0 } }, + corpus: [ + createMultiZetaStimulus('0', [createZetaCatMap(['cat2'])]), + createMultiZetaStimulus('1', [createZetaCatMap(['cat2'])]), + createMultiZetaStimulus('2', [createZetaCatMap(['cat2'])]), + ], + earlyStopping, + }); + + clowder.updateCatAndGetNextItem({ + catToSelect: 'cat2', + items: [clowder.corpus[0], clowder.corpus[1], clowder.corpus[2]], + answers: [1, 1, 1], + }); + + expect(clowder.earlyStopping?.earlyStop).toBe(false); + }); + + it('should handle StopIfSEMeasurementBelowThreshold condition', () => { + const earlyStopping = new StopIfSEMeasurementBelowThreshold({ + seMeasurementThreshold: { cat1: 0.05 }, // Threshold for SE + patience: { cat1: 2 }, + tolerance: { cat1: 0.01 }, + }); + + const clowder = new Clowder({ + cats: { cat1: { method: 'MLE', theta: 0.5 } }, + corpus: [ + createMultiZetaStimulus('0', [createZetaCatMap(['cat1'])]), + createMultiZetaStimulus('1', [createZetaCatMap(['cat1'])]), + ], + earlyStopping, + }); + + clowder.updateCatAndGetNextItem({ + catToSelect: 'cat1', + items: [clowder.corpus[0]], + answers: [1], + }); + + clowder.updateCatAndGetNextItem({ + catToSelect: 'cat1', + items: [clowder.corpus[1]], + answers: [1], + }); + + 
expect(clowder.earlyStopping?.earlyStop).toBe(false); + }); }); diff --git a/src/__tests__/stopping.test.ts b/src/__tests__/stopping.test.ts index 2bdc2ae..d730737 100644 --- a/src/__tests__/stopping.test.ts +++ b/src/__tests__/stopping.test.ts @@ -90,8 +90,6 @@ describe.each` } as Cat, }, { - // cat1 should trigger stopping if logicalOperator === 'or', because - // seMeasurement plateaued over the patience period of 2 items cat1: { nItems: 2, seMeasurement: 0.5, @@ -106,9 +104,6 @@ describe.each` nItems: 3, seMeasurement: 0.5, } as Cat, - // cat2 should trigger stopping if logicalOperator === 'and', because - // seMeasurement plateaued over the patience period of 3 items, and the - // cat1 criterion passed last update cat2: { nItems: 3, seMeasurement: 0.3, @@ -298,7 +293,6 @@ describe.each` }, { cat1: { - // Do not increment nItems for cat1 nItems: 1, seMeasurement: 0.5, } as Cat, @@ -309,26 +303,20 @@ describe.each` }, { cat1: { - // Do not increment nItems for cat1 nItems: 1, seMeasurement: 0.5, } as Cat, cat2: { - // Do not increment nItems for cat2 nItems: 2, seMeasurement: 0.3, } as Cat, }, { cat1: { - // Increment nItems for cat1, but only use this update if - // logicalOperation is 'and'. Early stopping should still not be - // triggered. 
nItems: 2, seMeasurement: 0.5, } as Cat, cat2: { - // Do not increment nItems for cat2 nItems: 2, seMeasurement: 0.3, } as Cat, @@ -364,7 +352,6 @@ describe.each` }, { cat1: { - // Do not increment nItems for cat1 nItems: 1, seMeasurement: 0.5, } as Cat, @@ -375,24 +362,20 @@ describe.each` }, { cat1: { - // Do not increment nItems for cat1 nItems: 1, seMeasurement: 0.5, } as Cat, cat2: { - // Cat2 reaches required items nItems: 3, seMeasurement: 0.3, } as Cat, }, { cat1: { - // Cat1 reaches required items nItems: 2, seMeasurement: 0.5, } as Cat, cat2: { - // Cat2 reaches required items nItems: 3, seMeasurement: 0.3, } as Cat, @@ -559,7 +542,15 @@ describe.each` expect(earlyStopping.earlyStop).toBe(false); }); - it('waits for `patience` items to monitor the seMeasurement plateau', () => { + it('handles missing input for some cats', () => { + const input = { + patience: { cat1: 2 }, + tolerance: { cat2: 0.02 }, + seMeasurementThreshold: { cat3: 0.01 }, + logicalOperation, + }; + const earlyStopping = new StopIfSEMeasurementBelowThreshold(input); + const updates: CatMap[] = [ { cat1: { @@ -574,7 +565,7 @@ describe.each` { cat1: { nItems: 2, - seMeasurement: 0.5, + seMeasurement: 0.02, } as Cat, cat2: { nItems: 2, @@ -584,34 +575,11 @@ describe.each` { cat1: { nItems: 3, - seMeasurement: 0.5, + seMeasurement: 0.02, } as Cat, cat2: { nItems: 3, - seMeasurement: 0.01, - } as Cat, - }, - { - cat1: { - nItems: 4, - seMeasurement: 0.5, - } as Cat, - // Cat2 should trigger when logicalOperation is 'or' - cat2: { - nItems: 4, - seMeasurement: 0.01, - } as Cat, - }, - { - // Cat1 should trigger when logicalOperation is 'and' - // Cat2 criterion was satisfied after last update - cat1: { - nItems: 5, - seMeasurement: 0.01, - } as Cat, - cat2: { - nItems: 5, - seMeasurement: 0.01, + seMeasurement: 0.02, } as Cat, }, ]; @@ -620,74 +588,79 @@ describe.each` expect(earlyStopping.earlyStop).toBe(false); earlyStopping.update(updates[1]); - expect(earlyStopping.earlyStop).toBe(false); 
- - earlyStopping.update(updates[2]); - expect(earlyStopping.earlyStop).toBe(false); - if (earlyStopping.logicalOperation === 'or') { - earlyStopping.update(updates[3]); expect(earlyStopping.earlyStop).toBe(true); } else { - earlyStopping.update(updates[3]); expect(earlyStopping.earlyStop).toBe(false); + } - earlyStopping.update(updates[4]); + earlyStopping.update(updates[2]); + if (earlyStopping.logicalOperation === 'or') { expect(earlyStopping.earlyStop).toBe(true); + } else { + expect(earlyStopping.earlyStop).toBe(false); } }); - it('triggers early stopping when within tolerance', () => { - // patience: { cat1: 1, cat2: 3 }, - // tolerance: { cat1: 0.01, cat2: 0.02 }, - // seMeasurementThreshold: { cat1: 0.03, cat2: 0.02 }, + it('does not stop when the seMeasurement has not plateaued enough over patience', () => { + const input = { + patience: { cat1: 2 }, + tolerance: { cat1: 0.05 }, + logicalOperation, + }; + const earlyStopping = new StopOnSEMeasurementPlateau(input); + const updates: CatMap[] = [ { - // Update 1 should not trigger cat1: { nItems: 1, - seMeasurement: 10, - } as Cat, - cat2: { - nItems: 1, - seMeasurement: 0.4, + seMeasurement: 0.5, } as Cat, }, { - // Update 2 should not trigger cat1: { nItems: 2, - seMeasurement: 1, - } as Cat, - cat2: { - nItems: 2, - // Cat 2 is low enough but not enough items to satisfy patience - seMeasurement: 0.02, + seMeasurement: 0.49, } as Cat, }, { - // Update 3 should trigger for logicalOperation === 'or', but not for 'and' cat1: { nItems: 3, - // Cat 1 is low enough and the patience is only 1 - seMeasurement: 0.0399, - } as Cat, - cat2: { - nItems: 3, - // Cat 2 patience is still not satisfied - seMeasurement: 0.04, + seMeasurement: 0.48, } as Cat, }, + ]; + + earlyStopping.update(updates[0]); + expect(earlyStopping.earlyStop).toBe(false); + + earlyStopping.update(updates[1]); + if (earlyStopping.logicalOperation === 'and') { + expect(earlyStopping.earlyStop).toBe(true); + } else { + 
expect(earlyStopping.earlyStop).toBe(true); + } + + earlyStopping.update(updates[2]); + expect(earlyStopping.earlyStop).toBe(true); + }); + + it('does not stop if required items have not been reached', () => { + const earlyStopping = new StopAfterNItems({ + requiredItems: { cat1: 5 }, + logicalOperation: 'or', + }); + const updates: CatMap[] = [ { - // Update 4 should trigger for logicalOperation === 'and' cat1: { - nItems: 4, - seMeasurement: 0.001, + nItems: 2, + seMeasurement: 0.5, } as Cat, - cat2: { - // SE is low enough and patience is satisfied + }, + { + cat1: { nItems: 4, - seMeasurement: 0.01, + seMeasurement: 0.5, } as Cat, }, ]; @@ -696,32 +669,6 @@ describe.each` expect(earlyStopping.earlyStop).toBe(false); earlyStopping.update(updates[1]); - expect(earlyStopping.earlyStop).toBe(false); - - earlyStopping.update(updates[2]); - if (earlyStopping.logicalOperation === 'or') { - expect(earlyStopping.earlyStop).toBe(true); - } else { - expect(earlyStopping.earlyStop).toBe(false); - earlyStopping.update(updates[3]); - expect(earlyStopping.earlyStop).toBe(true); - } + expect(earlyStopping.earlyStop).toBe(false); // still not reached requiredItems }); }); - -// TODO: We need to write some tests where not all cats are in the input for the early stopping instance. -// Right now, we have input like -// input = { -// patience: { cat1: 2, cat2: 3 }, -// tolerance: { cat1: 0.01, cat2: 0.02 }, -// logicalOperation, -// }; -// -// But we want input like -// input = { -// patience: { cat1: 2, cat2: 3 }, -// tolerance: { cat2: 0.02, cat3: 0.01 }, -// logicalOperation, -// }; -// -// In these situations, we need good default values to make sure that the tests pass. diff --git a/src/clowder.ts b/src/clowder.ts index 114fde5..dab07de 100644 --- a/src/clowder.ts +++ b/src/clowder.ts @@ -29,6 +29,10 @@ export interface ClowderInput { * An optional EarlyStopping instance to use for early stopping. 
*/ earlyStopping?: EarlyStopping; + /** + * An optional number of items required for each Cat to be considered for early stopping. + */ + numItemsRequired?: CatMap; } /** @@ -46,6 +50,7 @@ export class Clowder { private _seenItems: Stimulus[]; private _earlyStopping?: EarlyStopping; private readonly _rng: ReturnType; + _numItemsRequired?: CatMap; /** * Create a Clowder object. @@ -56,11 +61,7 @@ export class Clowder { * * @throws {Error} - Throws an error if any item in the corpus has duplicated IRT parameters for any Cat name. */ - constructor({ cats, corpus, randomSeed = null, earlyStopping }: ClowderInput) { - // TODO: Need to pass in numItemsRequired so that we know when to stop - // providing new items. This may depend on the cat name. For instance, - // perhaps numItemsRequired should be an object with cat names as keys and - // numItemsRequired as values. + constructor({ cats, corpus, randomSeed = null, earlyStopping, numItemsRequired = {} }: ClowderInput) { this._cats = _mapValues(cats, (catInput) => new Cat(catInput)); this._seenItems = []; checkNoDuplicateCatNames(corpus); @@ -68,6 +69,7 @@ export class Clowder { this._remainingItems = _cloneDeep(corpus); this._rng = randomSeed === null ? seedrandom() : seedrandom(randomSeed); this._earlyStopping = earlyStopping; + this._numItemsRequired = numItemsRequired; } /** @@ -147,6 +149,10 @@ export class Clowder { return _mapValues(this.cats, (cat) => cat.zetas); } + public get earlyStopping() { + return this._earlyStopping; + } + /** * Updates the ability estimates for the specified Cat instances. * @@ -213,14 +219,7 @@ export class Clowder { itemSelect?: string; randomlySelectUnvalidated?: boolean; }): Stimulus | undefined { - // +----------+ - // ----------| Validate |----------| - // +----------+ - - // Validate catToSelect this._validateCatName(catToSelect); - - // Convert catsToUpdate to array and validate each name catsToUpdate = Array.isArray(catsToUpdate) ? 
catsToUpdate : [catsToUpdate]; catsToUpdate.forEach((cat) => { this._validateCatName(cat); @@ -252,29 +251,27 @@ export class Clowder { // Update the ability estimate for all cats for (const catName of catsToUpdate) { - const itemsAndAnswersForCat = itemsAndAnswers.filter(([stim]) => { - // We are dealing with a single item in this function. This single item - // has an array of zeta parameters for a bunch of different Cats. We - // need to determine if `catName` is present in that list. So we first - // reduce the zetas to get all of the applicabe cat names. - const allCats = stim.zetas.reduce((acc: string[], { cats }: { cats: string }) => { - return [...acc, ...cats]; - }, []); - - // Then we simply check if `catName` is present in this reduction. - return allCats.includes(catName); - }); - + const itemsAndAnswersForCat = itemsAndAnswers.filter(([stim]) => + stim.zetas.some((zeta: ZetaCatMap) => zeta.cats.includes(catName)), + ); + // We are dealing with a single item in this function. This single item + // has an array of zeta parameters for a bunch of different Cats. We + // need to determine if `catName` is present in that list. So we first + // reduce the zetas to get all of the applicabe cat names. // Now that we have the subset of items that can apply to this cat, // retrieve only the item parameters that apply to this cat. - const zetasAndAnswersForCat = itemsAndAnswersForCat.map(([stim, _answer]) => { - const { zetas } = stim; - const zetaForCat = zetas.find((zeta: ZetaCatMap) => zeta.cats.includes(catName)); - return [zetaForCat.zeta, _answer]; - }); - - // Finally, unzip the zetas and answers and feed them into the cat's updateAbilityEstimate method. 
- const [zetas, answers] = _unzip(zetasAndAnswersForCat); + const zetasAndAnswersForCat = itemsAndAnswersForCat + .map(([stim, _answer]) => { + const zetaForCat: ZetaCatMap | undefined = stim.zetas.find((zeta: ZetaCatMap) => zeta.cats.includes(catName)); + // eslint-disable-next-line @typescript-eslint/no-non-null-assertion + return [zetaForCat!.zeta, _answer]; // Optional chaining in case zetaForCat is undefined + }) + .filter(([zeta]) => zeta !== undefined); // Filter out undefined zeta values + + // Unzip the zetas and answers, making sure the zetas array contains only Zeta types + const [zetas, answers] = _unzip(zetasAndAnswersForCat) as [Zeta[], (0 | 1)[]]; + + // Now, pass the filtered zetas and answers to the cat's updateAbilityEstimate method this.cats[catName].updateAbilityEstimate(zetas, answers, method); } @@ -300,12 +297,11 @@ export class Clowder { // spread at the top-level of each Stimulus object. So we need to convert // the MultiZetaStimulus array to an array of Stimulus objects. const availableCatInput = available.map((item) => { - const { zetas, ...rest } = item; - const zetasForCat = zetas.find((zeta: ZetaCatMap) => zeta.cats.includes(catToSelect)); + const zetasForCat = item.zetas.find((zeta) => zeta.cats.includes(catToSelect)); return { // eslint-disable-next-line @typescript-eslint/no-non-null-assertion ...zetasForCat!.zeta, - ...rest, + ...item, }; }); @@ -325,11 +321,9 @@ export class Clowder { // Again `nextStimulus` will be a Stimulus object, or `undefined` if no further validated stimuli are available. // We need to convert the Stimulus object back to a MultiZetaStimulus object to return to the user. 
- const returnStimulus: MultiZetaStimulus | undefined = available.find((stim) => { - // eslint-disable-next-line @typescript-eslint/no-unused-vars - const { zetas, ...rest } = stim; - return _isEqual(rest, nextStimulusWithoutZeta); - }); + const returnStimulus: MultiZetaStimulus | undefined = available.find((stim) => + _isEqual(stim, nextStimulusWithoutZeta), + ); if (missing.length === 0) { // If there are no more unvalidated stimuli, we only have validated items left. @@ -344,18 +338,11 @@ export class Clowder { if (!randomlySelectUnvalidated) { return returnStimulus; } - - const numRemaining = { - available: available.length, - missing: missing.length, - }; const random = Math.random(); - - if (random < numRemaining.missing / (numRemaining.available + numRemaining.missing)) { - return missing[Math.floor(this._rng() * missing.length)]; - } else { - return returnStimulus; - } + const numRemaining = { available: available.length, missing: missing.length }; + return random < numRemaining.missing / (numRemaining.available + numRemaining.missing) + ? 
missing[Math.floor(this._rng() * missing.length)] + : returnStimulus; } } } diff --git a/src/corpus.ts b/src/corpus.ts index c5231da..687fc32 100644 --- a/src/corpus.ts +++ b/src/corpus.ts @@ -287,7 +287,7 @@ export const prepareClowderCorpus = ( zeta: convertZeta(zeta, itemParameterFormat), }; }) - .filter((zeta) => !_isEmpty(zeta)); // ask if --- Filter null values + .filter((zeta) => !_isEmpty(zeta)); // filter empty values // Create the MultiZetaStimulus structure without the category keys const cleanItem = _omit( From 3b1c56b88e77f30bf9e849cf118b2377a0bfad1d Mon Sep 17 00:00:00 2001 From: emily-ejag Date: Mon, 23 Sep 2024 14:56:12 -0700 Subject: [PATCH 29/47] we don't need default --- src/stopping.ts | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/stopping.ts b/src/stopping.ts index 88fe1dc..ccf41ba 100644 --- a/src/stopping.ts +++ b/src/stopping.ts @@ -136,11 +136,11 @@ export abstract class EarlyStopping { */ export class StopOnSEMeasurementPlateau extends EarlyStopping { protected _evaluateStoppingCondition(catToEvaluate: string) { - const seMeasurements = this._seMeasurements[catToEvaluate] ?? []; + const seMeasurements = this._seMeasurements[catToEvaluate]; // Use MAX_SAFE_INTEGER and MAX_VALUE to prevent early stopping if the `catToEvaluate` is missing from the cats map. - const patience = this._patience[catToEvaluate] ?? Number.MAX_SAFE_INTEGER; - const tolerance = this._tolerance[catToEvaluate] ?? 0; + const patience = this._patience[catToEvaluate]; + const tolerance = this._tolerance[catToEvaluate]; let earlyStop = false; @@ -162,7 +162,7 @@ export class StopOnSEMeasurementPlateau extends EarlyStopping { */ export class StopAfterNItems extends EarlyStopping { protected _evaluateStoppingCondition(catToEvaluate: string) { - const requiredItems = this._requiredItems[catToEvaluate] ?? Number.MAX_SAFE_INTEGER; + const requiredItems = this._requiredItems[catToEvaluate]; const nItems = this._nItems[catToEvaluate] ?? 
0; let earlyStop = false; From 2f232deb6409cdf552d928becf0dd82414bd2429 Mon Sep 17 00:00:00 2001 From: emily-ejag Date: Tue, 24 Sep 2024 11:51:56 -0700 Subject: [PATCH 30/47] resolve some comments --- src/clowder.ts | 22 ++++++++-------------- 1 file changed, 8 insertions(+), 14 deletions(-) diff --git a/src/clowder.ts b/src/clowder.ts index dab07de..05303d8 100644 --- a/src/clowder.ts +++ b/src/clowder.ts @@ -29,10 +29,6 @@ export interface ClowderInput { * An optional EarlyStopping instance to use for early stopping. */ earlyStopping?: EarlyStopping; - /** - * An optional number of items required for each Cat to be considered for early stopping. - */ - numItemsRequired?: CatMap; } /** @@ -50,7 +46,6 @@ export class Clowder { private _seenItems: Stimulus[]; private _earlyStopping?: EarlyStopping; private readonly _rng: ReturnType; - _numItemsRequired?: CatMap; /** * Create a Clowder object. @@ -61,15 +56,15 @@ export class Clowder { * * @throws {Error} - Throws an error if any item in the corpus has duplicated IRT parameters for any Cat name. */ - constructor({ cats, corpus, randomSeed = null, earlyStopping, numItemsRequired = {} }: ClowderInput) { + constructor({ cats, corpus, randomSeed = null, earlyStopping }: ClowderInput) { this._cats = _mapValues(cats, (catInput) => new Cat(catInput)); + console.log('Initialized cats:', this._cats); this._seenItems = []; checkNoDuplicateCatNames(corpus); this._corpus = corpus; this._remainingItems = _cloneDeep(corpus); this._rng = randomSeed === null ? seedrandom() : seedrandom(randomSeed); this._earlyStopping = earlyStopping; - this._numItemsRequired = numItemsRequired; } /** @@ -252,14 +247,14 @@ export class Clowder { // Update the ability estimate for all cats for (const catName of catsToUpdate) { const itemsAndAnswersForCat = itemsAndAnswers.filter(([stim]) => + // We are dealing with a single item in this function. This single item + // has an array of zeta parameters for a bunch of different Cats. 
We + // need to determine if `catName` is present in that list. So we first + // reduce the zetas to get all of the applicabe cat names. + // Now that we have the subset of items that can apply to this cat, + // retrieve only the item parameters that apply to this cat. stim.zetas.some((zeta: ZetaCatMap) => zeta.cats.includes(catName)), ); - // We are dealing with a single item in this function. This single item - // has an array of zeta parameters for a bunch of different Cats. We - // need to determine if `catName` is present in that list. So we first - // reduce the zetas to get all of the applicabe cat names. - // Now that we have the subset of items that can apply to this cat, - // retrieve only the item parameters that apply to this cat. const zetasAndAnswersForCat = itemsAndAnswersForCat .map(([stim, _answer]) => { const zetaForCat: ZetaCatMap | undefined = stim.zetas.find((zeta: ZetaCatMap) => zeta.cats.includes(catName)); @@ -275,7 +270,6 @@ export class Clowder { this.cats[catName].updateAbilityEstimate(zetas, answers, method); } - // TODO: These next two if clauses were not very well thought through by Adam. We should scrutinize and add tests. 
if (this._earlyStopping) { this._earlyStopping.update(this.cats); } From 5300ce6f5ed81d96cf47fd18e7cdaa55b6a44410 Mon Sep 17 00:00:00 2001 From: emily-ejag Date: Tue, 24 Sep 2024 12:43:18 -0700 Subject: [PATCH 31/47] correcting some tests --- src/__tests__/clowder.test.ts | 71 +++++++++++++---- src/__tests__/stopping.test.ts | 140 ++++++++++++++++++++------------- src/clowder.ts | 1 - 3 files changed, 140 insertions(+), 72 deletions(-) diff --git a/src/__tests__/clowder.test.ts b/src/__tests__/clowder.test.ts index addc218..99e9f97 100644 --- a/src/__tests__/clowder.test.ts +++ b/src/__tests__/clowder.test.ts @@ -300,7 +300,7 @@ describe('Clowder Class', () => { answers: [1], }); - clowder.updateCatAndGetNextItem({ + const nextItem = clowder.updateCatAndGetNextItem({ catToSelect: 'cat1', catsToUpdate: ['cat1'], items: [clowder.corpus[1]], @@ -308,40 +308,73 @@ describe('Clowder Class', () => { }); expect(clowder.earlyStopping?.earlyStop).toBe(true); // Should stop after 2 items + expect(nextItem).toBe(undefined); // Expect undefined after early stopping + }); +}); + +describe('Clowder Early Stopping', () => { + let clowder: Clowder; + + beforeEach(() => { + const clowderInput: ClowderInput = { + cats: { + cat1: { method: 'MLE', theta: 0.5 }, + }, + corpus: [ + createMultiZetaStimulus('0', [createZetaCatMap(['cat1'])]), + createMultiZetaStimulus('1', [createZetaCatMap(['cat1'])]), + ], + }; + clowder = new Clowder(clowderInput); }); it('should trigger early stopping after required number of items', () => { const earlyStopping = new StopAfterNItems({ - requiredItems: { cat2: 3 }, // Stop after 3 items for cat2 + requiredItems: { cat1: 2 }, // Stop after 2 items }); - const clowder = new Clowder({ - cats: { cat2: { method: 'EAP', theta: -1.0 } }, + clowder = new Clowder({ + cats: { cat1: { method: 'MLE', theta: 0.5 } }, corpus: [ - createMultiZetaStimulus('0', [createZetaCatMap(['cat2'])]), - createMultiZetaStimulus('1', [createZetaCatMap(['cat2'])]), - 
createMultiZetaStimulus('2', [createZetaCatMap(['cat2'])]), + createMultiZetaStimulus('0', [createZetaCatMap(['cat1'])]), + createMultiZetaStimulus('1', [createZetaCatMap(['cat1'])]), + createMultiZetaStimulus('2', [createZetaCatMap(['cat1'])]), // This item should trigger early stopping ], earlyStopping, }); clowder.updateCatAndGetNextItem({ - catToSelect: 'cat2', - items: [clowder.corpus[0], clowder.corpus[1], clowder.corpus[2]], - answers: [1, 1, 1], + catToSelect: 'cat1', + catsToUpdate: ['cat1'], + items: [clowder.corpus[0]], + answers: [1], + }); + clowder.updateCatAndGetNextItem({ + catToSelect: 'cat1', + catsToUpdate: ['cat1'], + items: [clowder.corpus[1]], + answers: [1], }); - expect(clowder.earlyStopping?.earlyStop).toBe(false); + const nextItem = clowder.updateCatAndGetNextItem({ + catToSelect: 'cat1', + catsToUpdate: ['cat1'], + items: [clowder.corpus[2]], + answers: [1], + }); + + expect(clowder.earlyStopping?.earlyStop).toBe(true); // Early stop should be triggered after 2 items + expect(nextItem).toBe(undefined); // No further items should be selected }); it('should handle StopIfSEMeasurementBelowThreshold condition', () => { const earlyStopping = new StopIfSEMeasurementBelowThreshold({ - seMeasurementThreshold: { cat1: 0.05 }, // Threshold for SE + seMeasurementThreshold: { cat1: 0.2 }, // Stop if SE drops below 0.2 patience: { cat1: 2 }, tolerance: { cat1: 0.01 }, }); - const clowder = new Clowder({ + clowder = new Clowder({ cats: { cat1: { method: 'MLE', theta: 0.5 } }, corpus: [ createMultiZetaStimulus('0', [createZetaCatMap(['cat1'])]), @@ -350,18 +383,26 @@ describe('Clowder Class', () => { earlyStopping, }); + // First update clowder.updateCatAndGetNextItem({ catToSelect: 'cat1', + catsToUpdate: ['cat1'], items: [clowder.corpus[0]], answers: [1], }); + // pringing results + console.log('SE Measurements:', clowder.earlyStopping?.seMeasurementThreshold, clowder.cats.cat1); - clowder.updateCatAndGetNextItem({ + const nextItem = 
clowder.updateCatAndGetNextItem({ catToSelect: 'cat1', + catsToUpdate: ['cat1'], items: [clowder.corpus[1]], answers: [1], }); - expect(clowder.earlyStopping?.earlyStop).toBe(false); + console.log('Early Stop Triggered:', clowder.earlyStopping?.earlyStop); + + expect(clowder.earlyStopping?.earlyStop).toBe(true); // Should stop after SE drops below threshold + expect(nextItem).toBe(undefined); // No further items should be selected }); }); diff --git a/src/__tests__/stopping.test.ts b/src/__tests__/stopping.test.ts index d730737..f34f469 100644 --- a/src/__tests__/stopping.test.ts +++ b/src/__tests__/stopping.test.ts @@ -90,6 +90,8 @@ describe.each` } as Cat, }, { + // cat1 should trigger stopping if logicalOperator === 'or', because + // seMeasurement plateaued over the patience period of 2 items cat1: { nItems: 2, seMeasurement: 0.5, @@ -104,6 +106,9 @@ describe.each` nItems: 3, seMeasurement: 0.5, } as Cat, + // cat2 should trigger stopping if logicalOperator === 'and', because + // seMeasurement plateaued over the patience period of 3 items, and the + // cat1 criterion passed last update cat2: { nItems: 3, seMeasurement: 0.3, @@ -293,6 +298,7 @@ describe.each` }, { cat1: { + // Do not increment nItems for cat1 nItems: 1, seMeasurement: 0.5, } as Cat, @@ -303,20 +309,26 @@ describe.each` }, { cat1: { + // Do not increment nItems for cat1 nItems: 1, seMeasurement: 0.5, } as Cat, cat2: { + // Do not increment nItems for cat2 nItems: 2, seMeasurement: 0.3, } as Cat, }, { cat1: { + // Increment nItems for cat1, but only use this update if + // logicalOperation is 'and'. Early stopping should still not be + // triggered. 
nItems: 2, seMeasurement: 0.5, } as Cat, cat2: { + // Do not increment nItems for cat2 nItems: 2, seMeasurement: 0.3, } as Cat, @@ -352,6 +364,7 @@ describe.each` }, { cat1: { + // Do not increment nItems for cat1 nItems: 1, seMeasurement: 0.5, } as Cat, @@ -362,20 +375,24 @@ describe.each` }, { cat1: { + // Do not increment nItems for cat1 nItems: 1, seMeasurement: 0.5, } as Cat, cat2: { + // Cat2 reaches required items nItems: 3, seMeasurement: 0.3, } as Cat, }, { cat1: { + // Cat1 reaches required items nItems: 2, seMeasurement: 0.5, } as Cat, cat2: { + // Cat2 reaches required items nItems: 3, seMeasurement: 0.3, } as Cat, @@ -542,15 +559,7 @@ describe.each` expect(earlyStopping.earlyStop).toBe(false); }); - it('handles missing input for some cats', () => { - const input = { - patience: { cat1: 2 }, - tolerance: { cat2: 0.02 }, - seMeasurementThreshold: { cat3: 0.01 }, - logicalOperation, - }; - const earlyStopping = new StopIfSEMeasurementBelowThreshold(input); - + it('waits for `patience` items to monitor the seMeasurement plateau', () => { const updates: CatMap[] = [ { cat1: { @@ -565,7 +574,7 @@ describe.each` { cat1: { nItems: 2, - seMeasurement: 0.02, + seMeasurement: 0.5, } as Cat, cat2: { nItems: 2, @@ -575,11 +584,34 @@ describe.each` { cat1: { nItems: 3, - seMeasurement: 0.02, + seMeasurement: 0.5, } as Cat, cat2: { nItems: 3, - seMeasurement: 0.02, + seMeasurement: 0.01, + } as Cat, + }, + { + cat1: { + nItems: 4, + seMeasurement: 0.5, + } as Cat, + // Cat2 should trigger when logicalOperation is 'or' + cat2: { + nItems: 4, + seMeasurement: 0.01, + } as Cat, + }, + { + // Cat1 should trigger when logicalOperation is 'and' + // Cat2 criterion was satisfied after last update + cat1: { + nItems: 5, + seMeasurement: 0.01, + } as Cat, + cat2: { + nItems: 5, + seMeasurement: 0.01, } as Cat, }, ]; @@ -588,73 +620,53 @@ describe.each` expect(earlyStopping.earlyStop).toBe(false); earlyStopping.update(updates[1]); - if (earlyStopping.logicalOperation === 'or') 
{ - expect(earlyStopping.earlyStop).toBe(true); - } else { - expect(earlyStopping.earlyStop).toBe(false); - } + expect(earlyStopping.earlyStop).toBe(false); earlyStopping.update(updates[2]); + expect(earlyStopping.earlyStop).toBe(false); + if (earlyStopping.logicalOperation === 'or') { + earlyStopping.update(updates[3]); expect(earlyStopping.earlyStop).toBe(true); } else { + earlyStopping.update(updates[3]); expect(earlyStopping.earlyStop).toBe(false); + + earlyStopping.update(updates[4]); + expect(earlyStopping.earlyStop).toBe(true); } }); - it('does not stop when the seMeasurement has not plateaued enough over patience', () => { - const input = { - patience: { cat1: 2 }, - tolerance: { cat1: 0.05 }, - logicalOperation, - }; - const earlyStopping = new StopOnSEMeasurementPlateau(input); - + it('triggers early stopping when within tolerance', () => { const updates: CatMap[] = [ { cat1: { nItems: 1, - seMeasurement: 0.5, + seMeasurement: 10, + } as Cat, + cat2: { + nItems: 1, + seMeasurement: 0.4, } as Cat, }, { cat1: { nItems: 2, - seMeasurement: 0.49, + seMeasurement: 1, + } as Cat, + cat2: { + nItems: 2, + seMeasurement: 0.02, } as Cat, }, { cat1: { nItems: 3, - seMeasurement: 0.48, + seMeasurement: 0.0001, } as Cat, - }, - ]; - - earlyStopping.update(updates[0]); - expect(earlyStopping.earlyStop).toBe(false); - - earlyStopping.update(updates[1]); - if (earlyStopping.logicalOperation === 'and') { - expect(earlyStopping.earlyStop).toBe(true); - } else { - expect(earlyStopping.earlyStop).toBe(true); - } - - earlyStopping.update(updates[2]); - expect(earlyStopping.earlyStop).toBe(true); - }); - - it('does not stop if required items have not been reached', () => { - const earlyStopping = new StopAfterNItems({ - requiredItems: { cat1: 5 }, - logicalOperation: 'or', - }); - const updates: CatMap[] = [ - { - cat1: { - nItems: 2, - seMeasurement: 0.5, + cat2: { + nItems: 3, + seMeasurement: 0.04, } as Cat, }, { @@ -662,6 +674,10 @@ describe.each` nItems: 4, 
seMeasurement: 0.5, } as Cat, + cat2: { + nItems: 4, + seMeasurement: 0.01, + } as Cat, }, ]; @@ -669,6 +685,18 @@ describe.each` expect(earlyStopping.earlyStop).toBe(false); earlyStopping.update(updates[1]); - expect(earlyStopping.earlyStop).toBe(false); // still not reached requiredItems + expect(earlyStopping.earlyStop).toBe(false); + + if (earlyStopping.logicalOperation === 'or') { + earlyStopping.update(updates[2]); + expect(earlyStopping.earlyStop).toBe(true); + earlyStopping.update(updates[3]); + expect(earlyStopping.earlyStop).toBe(true); + } else { + earlyStopping.update(updates[2]); + expect(earlyStopping.earlyStop).toBe(false); + earlyStopping.update(updates[3]); + expect(earlyStopping.earlyStop).toBe(false); + } }); }); diff --git a/src/clowder.ts b/src/clowder.ts index 05303d8..c9db1f7 100644 --- a/src/clowder.ts +++ b/src/clowder.ts @@ -58,7 +58,6 @@ export class Clowder { */ constructor({ cats, corpus, randomSeed = null, earlyStopping }: ClowderInput) { this._cats = _mapValues(cats, (catInput) => new Cat(catInput)); - console.log('Initialized cats:', this._cats); this._seenItems = []; checkNoDuplicateCatNames(corpus); this._corpus = corpus; From 5c684a82e4e2cf9c496824f0935f1112dad2b879 Mon Sep 17 00:00:00 2001 From: emily-ejag Date: Wed, 25 Sep 2024 17:18:18 -0700 Subject: [PATCH 32/47] adding the unvalidated cat -- and test --- src/__tests__/clowder.test.ts | 16 ++++++++++------ src/clowder.ts | 23 +++++++++++++++++------ 2 files changed, 27 insertions(+), 12 deletions(-) diff --git a/src/__tests__/clowder.test.ts b/src/__tests__/clowder.test.ts index 99e9f97..8b42f6c 100644 --- a/src/__tests__/clowder.test.ts +++ b/src/__tests__/clowder.test.ts @@ -35,7 +35,7 @@ describe('Clowder Class', () => { createMultiZetaStimulus('1', [createZetaCatMap(['cat1']), createZetaCatMap(['cat2'])]), createMultiZetaStimulus('2', [createZetaCatMap(['cat1'])]), createMultiZetaStimulus('3', [createZetaCatMap(['cat2'])]), - createMultiZetaStimulus('4', []), + 
createMultiZetaStimulus('4', []), // Unvalidated item ], }; clowder = new Clowder(clowderInput); @@ -43,6 +43,7 @@ describe('Clowder Class', () => { it('initializes with provided cats and corpora', () => { expect(Object.keys(clowder.cats)).toContain('cat1'); + expect(Object.keys(clowder.cats)).toContain('unvalidated'); // Ensure 'unvalidated' cat is present expect(clowder.remainingItems).toHaveLength(5); expect(clowder.corpus).toHaveLength(5); expect(clowder.seenItems).toHaveLength(0); @@ -52,7 +53,8 @@ describe('Clowder Class', () => { expect(() => { const corpus: MultiZetaStimulus[] = [ { - stimulus: 'Item 1', + id: 'item1', + content: 'Item 1', zetas: [ { cats: ['Model A', 'Model B'], zeta: { a: 1, b: 0.5, c: 0.2, d: 0.8 } }, { cats: ['Model C'], zeta: { a: 2, b: 0.7, c: 0.3, d: 0.9 } }, @@ -60,7 +62,8 @@ describe('Clowder Class', () => { ], }, { - stimulus: 'Item 2', + id: 'item2', + content: 'Item 2', zetas: [{ cats: ['Model A', 'Model C'], zeta: { a: 2.5, b: 0.8, c: 0.35, d: 0.95 } }], }, ]; @@ -89,7 +92,7 @@ describe('Clowder Class', () => { it('throws an error when updating ability estimates for an invalid cat', () => { expect(() => clowder.updateAbilityEstimates(['invalidCatName'], createStimulus('1'), [0])).toThrowError( - 'Invalid Cat name. Expected one of cat1, cat2. Received invalidCatName.', + 'Invalid Cat name. Expected one of cat1, cat2, unvalidated. Received invalidCatName.', ); }); @@ -106,6 +109,7 @@ describe('Clowder Class', () => { const expected = { cat1: clowder.cats['cat1'][property as keyof Cat], cat2: clowder.cats['cat2'][property as keyof Cat], + unvalidated: clowder.cats['unvalidated'][property as keyof Cat], }; expect(clowder[property as keyof Clowder]).toEqual(expected); }); @@ -125,7 +129,7 @@ describe('Clowder Class', () => { clowder.updateCatAndGetNextItem({ catToSelect: 'invalidCatName', }); - }).toThrow('Invalid Cat name. Expected one of cat1, cat2. Received invalidCatName.'); + }).toThrow('Invalid Cat name. 
Expected one of cat1, cat2, unvalidated. Received invalidCatName.'); }); it('throws an error if any of catsToUpdate is invalid', () => { @@ -134,7 +138,7 @@ describe('Clowder Class', () => { catToSelect: 'cat1', catsToUpdate: ['invalidCatName', 'cat2'], }); - }).toThrow('Invalid Cat name. Expected one of cat1, cat2. Received invalidCatName.'); + }).toThrow('Invalid Cat name. Expected one of cat1, cat2, unvalidated. Received invalidCatName.'); }); it('updates seen and remaining items', () => { diff --git a/src/clowder.ts b/src/clowder.ts index c9db1f7..49c4ee7 100644 --- a/src/clowder.ts +++ b/src/clowder.ts @@ -57,7 +57,10 @@ export class Clowder { * @throws {Error} - Throws an error if any item in the corpus has duplicated IRT parameters for any Cat name. */ constructor({ cats, corpus, randomSeed = null, earlyStopping }: ClowderInput) { - this._cats = _mapValues(cats, (catInput) => new Cat(catInput)); + this._cats = { + ..._mapValues(cats, (catInput) => new Cat(catInput)), + unvalidated: new Cat(), // Add 'unvalidated' cat + }; this._seenItems = []; checkNoDuplicateCatNames(corpus); this._corpus = corpus; @@ -243,7 +246,7 @@ export class Clowder { // answers "stay together." const itemsAndAnswers = _zip(items, answers) as [Stimulus, 0 | 1][]; - // Update the ability estimate for all cats + // Update the ability estimate for all validated cats for (const catName of catsToUpdate) { const itemsAndAnswersForCat = itemsAndAnswers.filter(([stim]) => // We are dealing with a single item in this function. 
This single item @@ -269,12 +272,20 @@ export class Clowder { this.cats[catName].updateAbilityEstimate(zetas, answers, method); } - if (this._earlyStopping) { - this._earlyStopping.update(this.cats); + // Assign items with no valid parameters to the 'unvalidated' cat + const unvalidatedItemsAndAnswers = itemsAndAnswers.filter( + ([stim]) => !stim.zetas.some((zeta: ZetaCatMap) => zeta.cats.length > 0), + ); + if (unvalidatedItemsAndAnswers.length > 0) { + const [zetas, answers] = _unzip(unvalidatedItemsAndAnswers) as [Zeta[], (0 | 1)[]]; + this.cats['unvalidated'].updateAbilityEstimate(zetas, answers, method); } - if (this._earlyStopping?.earlyStop) { - return undefined; + if (this._earlyStopping) { + this._earlyStopping.update(this.cats); + if (this._earlyStopping.earlyStop) { + return undefined; + } } // +----------+ From 5d0ff5068bea87504c21c47e09bb58170ac37b46 Mon Sep 17 00:00:00 2001 From: Adam Richie-Halford Date: Mon, 30 Sep 2024 06:46:36 -0700 Subject: [PATCH 33/47] Fix tests. Don't update ability estimate for the unvalidated Cat. 
Handle unvalidated remaining items separately --- src/__tests__/clowder.test.ts | 183 ++++++++++++++++++++++++++-------- src/clowder.ts | 78 +++++++++------ 2 files changed, 190 insertions(+), 71 deletions(-) diff --git a/src/__tests__/clowder.test.ts b/src/__tests__/clowder.test.ts index 8b42f6c..380a739 100644 --- a/src/__tests__/clowder.test.ts +++ b/src/__tests__/clowder.test.ts @@ -43,7 +43,6 @@ describe('Clowder Class', () => { it('initializes with provided cats and corpora', () => { expect(Object.keys(clowder.cats)).toContain('cat1'); - expect(Object.keys(clowder.cats)).toContain('unvalidated'); // Ensure 'unvalidated' cat is present expect(clowder.remainingItems).toHaveLength(5); expect(clowder.corpus).toHaveLength(5); expect(clowder.seenItems).toHaveLength(0); @@ -92,7 +91,7 @@ describe('Clowder Class', () => { it('throws an error when updating ability estimates for an invalid cat', () => { expect(() => clowder.updateAbilityEstimates(['invalidCatName'], createStimulus('1'), [0])).toThrowError( - 'Invalid Cat name. Expected one of cat1, cat2, unvalidated. Received invalidCatName.', + 'Invalid Cat name. Expected one of cat1, cat2. Received invalidCatName.', ); }); @@ -109,7 +108,6 @@ describe('Clowder Class', () => { const expected = { cat1: clowder.cats['cat1'][property as keyof Cat], cat2: clowder.cats['cat2'][property as keyof Cat], - unvalidated: clowder.cats['unvalidated'][property as keyof Cat], }; expect(clowder[property as keyof Clowder]).toEqual(expected); }); @@ -138,7 +136,7 @@ describe('Clowder Class', () => { catToSelect: 'cat1', catsToUpdate: ['invalidCatName', 'cat2'], }); - }).toThrow('Invalid Cat name. Expected one of cat1, cat2, unvalidated. Received invalidCatName.'); + }).toThrow('Invalid Cat name. Expected one of cat1, cat2. 
Received invalidCatName.'); }); it('updates seen and remaining items', () => { @@ -207,6 +205,110 @@ describe('Clowder Class', () => { expect(['0', '2']).toContain(nextItem?.id); }); + it('should select an unvalidated item if catToSelect is "unvalidated"', () => { + const clowderInput: ClowderInput = { + cats: { + cat1: { method: 'MLE', theta: 0.5 }, + }, + corpus: [ + createMultiZetaStimulus('0', [createZetaCatMap([])]), + createMultiZetaStimulus('1', [createZetaCatMap(['cat1'])]), + createMultiZetaStimulus('2', [createZetaCatMap([])]), + createMultiZetaStimulus('3', [createZetaCatMap(['cat1'])]), + ], + }; + + const clowder = new Clowder(clowderInput); + + const nDraws = 50; + // Simulate sDraws unvalidated items being selected + // eslint-disable-next-line @typescript-eslint/no-unused-vars + for (const _ of Array(nDraws).fill(0)) { + const nextItem = clowder.updateCatAndGetNextItem({ + catToSelect: 'unvalidated', + }); + + expect(['0', '2']).toContain(nextItem?.id); + } + }); + + it('should not update cats with items that do not have parameters for that cat', () => { + const clowderInput: ClowderInput = { + cats: { + cat1: { method: 'MLE', theta: 0.5 }, + cat2: { method: 'MLE', theta: 0.5 }, + }, + corpus: [ + createMultiZetaStimulus('0', [createZetaCatMap(['cat1'])]), + createMultiZetaStimulus('1', [createZetaCatMap(['cat1'])]), + createMultiZetaStimulus('2', [createZetaCatMap(['cat2'])]), + createMultiZetaStimulus('3', [createZetaCatMap(['cat2'])]), + ], + }; + + const clowder = new Clowder(clowderInput); + + clowder.updateCatAndGetNextItem({ + catsToUpdate: ['cat1', 'cat2'], + items: clowder.corpus, + answers: [1, 1, 1, 1], + catToSelect: 'unvalidated', + }); + + expect(clowder.nItems.cat1).toBe(2); + expect(clowder.nItems.cat2).toBe(2); + }); + + it('should not update any cats if only unvalidated items have been seen', () => { + const clowderInput: ClowderInput = { + cats: { + cat1: { method: 'MLE', theta: 0.5 }, + }, + corpus: [ + 
createMultiZetaStimulus('0', [createZetaCatMap([])]), + createMultiZetaStimulus('1', [createZetaCatMap(['cat1'])]), + createMultiZetaStimulus('2', [createZetaCatMap([])]), + createMultiZetaStimulus('3', [createZetaCatMap(['cat1'])]), + ], + }; + + const clowder = new Clowder(clowderInput); + + clowder.updateCatAndGetNextItem({ + catsToUpdate: ['cat1'], + items: [clowder.corpus[0], clowder.corpus[2]], + answers: [1, 1], + catToSelect: 'unvalidated', + }); + + expect(clowder.nItems.cat1).toBe(0); + }); + + it('should return undefined for next item if catToSelect = "unvalidated" and no unvalidated items remain', () => { + const clowderInput: ClowderInput = { + cats: { + cat1: { method: 'MLE', theta: 0.5 }, + }, + corpus: [ + createMultiZetaStimulus('0', [createZetaCatMap([])]), + createMultiZetaStimulus('1', [createZetaCatMap(['cat1'])]), + createMultiZetaStimulus('2', [createZetaCatMap([])]), + createMultiZetaStimulus('3', [createZetaCatMap(['cat1'])]), + ], + }; + + const clowder = new Clowder(clowderInput); + + const nextItem = clowder.updateCatAndGetNextItem({ + catsToUpdate: ['cat1'], + items: [clowder.corpus[0], clowder.corpus[2]], + answers: [1, 1], + catToSelect: 'unvalidated', + }); + + expect(nextItem).toBeUndefined(); + }); + it('should correctly update ability estimates during the updateCatAndGetNextItem method', () => { const originalTheta = clowder.cats.cat1.theta; clowder.updateCatAndGetNextItem({ @@ -249,15 +351,12 @@ describe('Clowder Class', () => { }); it('should return undefined if no more items remain', () => { - clowder.updateCatAndGetNextItem({ + const nextItem = clowder.updateCatAndGetNextItem({ catToSelect: 'cat1', items: clowder.remainingItems, answers: [1, 0, 1, 1, 0], // Exhaust all items }); - const nextItem = clowder.updateCatAndGetNextItem({ - catToSelect: 'cat1', - }); expect(nextItem).toBeUndefined(); }); @@ -341,8 +440,8 @@ describe('Clowder Early Stopping', () => { cats: { cat1: { method: 'MLE', theta: 0.5 } }, corpus: [ 
createMultiZetaStimulus('0', [createZetaCatMap(['cat1'])]), - createMultiZetaStimulus('1', [createZetaCatMap(['cat1'])]), - createMultiZetaStimulus('2', [createZetaCatMap(['cat1'])]), // This item should trigger early stopping + createMultiZetaStimulus('1', [createZetaCatMap(['cat1'])]), // This item should trigger early stopping + createMultiZetaStimulus('2', [createZetaCatMap(['cat1'])]), ], earlyStopping, }); @@ -353,17 +452,13 @@ describe('Clowder Early Stopping', () => { items: [clowder.corpus[0]], answers: [1], }); - clowder.updateCatAndGetNextItem({ - catToSelect: 'cat1', - catsToUpdate: ['cat1'], - items: [clowder.corpus[1]], - answers: [1], - }); + + expect(clowder.earlyStopping?.earlyStop).toBe(false); const nextItem = clowder.updateCatAndGetNextItem({ catToSelect: 'cat1', catsToUpdate: ['cat1'], - items: [clowder.corpus[2]], + items: [clowder.corpus[1]], answers: [1], }); @@ -378,35 +473,41 @@ describe('Clowder Early Stopping', () => { tolerance: { cat1: 0.01 }, }); - clowder = new Clowder({ - cats: { cat1: { method: 'MLE', theta: 0.5 } }, - corpus: [ - createMultiZetaStimulus('0', [createZetaCatMap(['cat1'])]), - createMultiZetaStimulus('1', [createZetaCatMap(['cat1'])]), - ], - earlyStopping, + const zetaMap = createZetaCatMap(['cat1'], { + a: 6, + b: 6, + c: 0, + d: 1, }); - // First update - clowder.updateCatAndGetNextItem({ - catToSelect: 'cat1', - catsToUpdate: ['cat1'], - items: [clowder.corpus[0]], - answers: [1], - }); - // pringing results - console.log('SE Measurements:', clowder.earlyStopping?.seMeasurementThreshold, clowder.cats.cat1); + const corpus = [ + createMultiZetaStimulus('0', [zetaMap]), + createMultiZetaStimulus('1', [zetaMap]), + createMultiZetaStimulus('2', [zetaMap]), // Here the SE measurement drops below threshold + createMultiZetaStimulus('3', [zetaMap]), // And here, early stopping should be triggered because it has been below threshold for 2 items + ]; - const nextItem = clowder.updateCatAndGetNextItem({ - catToSelect: 
'cat1', - catsToUpdate: ['cat1'], - items: [clowder.corpus[1]], - answers: [1], + clowder = new Clowder({ + cats: { cat1: { method: 'MLE', theta: 0.5 } }, + corpus, + earlyStopping, }); - console.log('Early Stop Triggered:', clowder.earlyStopping?.earlyStop); + for (const item of corpus) { + const nextItem = clowder.updateCatAndGetNextItem({ + catToSelect: 'cat1', + catsToUpdate: ['cat1'], + items: [item], + answers: [1], + }); - expect(clowder.earlyStopping?.earlyStop).toBe(true); // Should stop after SE drops below threshold - expect(nextItem).toBe(undefined); // No further items should be selected + if (item.id === '3') { + expect(clowder.earlyStopping?.earlyStop).toBe(true); // Should stop after SE drops below threshold + expect(nextItem).toBe(undefined); // No further items should be selected + } else { + expect(clowder.earlyStopping?.earlyStop).toBe(false); + expect(nextItem).toBeDefined(); + } + } }); }); diff --git a/src/clowder.ts b/src/clowder.ts index 49c4ee7..c39cb46 100644 --- a/src/clowder.ts +++ b/src/clowder.ts @@ -57,9 +57,12 @@ export class Clowder { * @throws {Error} - Throws an error if any item in the corpus has duplicated IRT parameters for any Cat name. */ constructor({ cats, corpus, randomSeed = null, earlyStopping }: ClowderInput) { + // TODO: Add some imput validation to both the cats and the corpus to make sure that "unvalidated" is not used as a cat name. + // If so, throw an error saying that "unvalidated" is a reserved name and may not be used. + // TODO: Also add a test of this behavior. this._cats = { ..._mapValues(cats, (catInput) => new Cat(catInput)), - unvalidated: new Cat(), // Add 'unvalidated' cat + unvalidated: new Cat({ itemSelect: 'random', randomSeed }), // Add 'unvalidated' cat }; this._seenItems = []; checkNoDuplicateCatNames(corpus); @@ -74,12 +77,14 @@ export class Clowder { * Throw an error if the Cat name is not found. * * @param {string} catName - The name of the Cat instance to validate. 
+ * @param {boolean} allowUnvalidated - Whether to allow the reserved 'unvalidated' name. * * @throws {Error} - Throws an error if the provided Cat name is not found among the existing Cat instances. */ - private _validateCatName(catName: string): void { - if (!Object.prototype.hasOwnProperty.call(this._cats, catName)) { - throw new Error(`Invalid Cat name. Expected one of ${Object.keys(this._cats).join(', ')}. Received ${catName}.`); + private _validateCatName(catName: string, allowUnvalidated = false): void { + const allowedCats = allowUnvalidated ? this._cats : this.cats; + if (!Object.prototype.hasOwnProperty.call(allowedCats, catName)) { + throw new Error(`Invalid Cat name. Expected one of ${Object.keys(allowedCats).join(', ')}. Received ${catName}.`); } } @@ -87,7 +92,7 @@ export class Clowder { * The named Cat instances that this Clowder manages. */ public get cats() { - return this._cats; + return _omit(this._cats, ['unvalidated']); } /** @@ -162,7 +167,7 @@ export class Clowder { */ public updateAbilityEstimates(catNames: string[], zeta: Zeta | Zeta[], answer: (0 | 1) | (0 | 1)[], method?: string) { catNames.forEach((catName) => { - this._validateCatName(catName); + this._validateCatName(catName, false); }); for (const catName of catNames) { this.cats[catName].updateAbilityEstimate(zeta, answer, method); @@ -216,10 +221,14 @@ export class Clowder { itemSelect?: string; randomlySelectUnvalidated?: boolean; }): Stimulus | undefined { - this._validateCatName(catToSelect); + // +----------+ + // ----------| Update |----------| + // +----------+ + + this._validateCatName(catToSelect, true); catsToUpdate = Array.isArray(catsToUpdate) ? catsToUpdate : [catsToUpdate]; catsToUpdate.forEach((cat) => { - this._validateCatName(cat); + this._validateCatName(cat, false); }); // Convert items and answers to arrays @@ -257,28 +266,24 @@ export class Clowder { // retrieve only the item parameters that apply to this cat. 
stim.zetas.some((zeta: ZetaCatMap) => zeta.cats.includes(catName)), ); - const zetasAndAnswersForCat = itemsAndAnswersForCat - .map(([stim, _answer]) => { - const zetaForCat: ZetaCatMap | undefined = stim.zetas.find((zeta: ZetaCatMap) => zeta.cats.includes(catName)); - // eslint-disable-next-line @typescript-eslint/no-non-null-assertion - return [zetaForCat!.zeta, _answer]; // Optional chaining in case zetaForCat is undefined - }) - .filter(([zeta]) => zeta !== undefined); // Filter out undefined zeta values - - // Unzip the zetas and answers, making sure the zetas array contains only Zeta types - const [zetas, answers] = _unzip(zetasAndAnswersForCat) as [Zeta[], (0 | 1)[]]; - - // Now, pass the filtered zetas and answers to the cat's updateAbilityEstimate method - this.cats[catName].updateAbilityEstimate(zetas, answers, method); - } - // Assign items with no valid parameters to the 'unvalidated' cat - const unvalidatedItemsAndAnswers = itemsAndAnswers.filter( - ([stim]) => !stim.zetas.some((zeta: ZetaCatMap) => zeta.cats.length > 0), - ); - if (unvalidatedItemsAndAnswers.length > 0) { - const [zetas, answers] = _unzip(unvalidatedItemsAndAnswers) as [Zeta[], (0 | 1)[]]; - this.cats['unvalidated'].updateAbilityEstimate(zetas, answers, method); + if (itemsAndAnswersForCat.length > 0) { + const zetasAndAnswersForCat = itemsAndAnswersForCat + .map(([stim, _answer]) => { + const zetaForCat: ZetaCatMap | undefined = stim.zetas.find((zeta: ZetaCatMap) => + zeta.cats.includes(catName), + ); + // eslint-disable-next-line @typescript-eslint/no-non-null-assertion + return [zetaForCat!.zeta, _answer]; // Optional chaining in case zetaForCat is undefined + }) + .filter(([zeta]) => zeta !== undefined); // Filter out undefined zeta values + + // Unzip the zetas and answers, making sure the zetas array contains only Zeta types + const [zetas, answers] = _unzip(zetasAndAnswersForCat) as [Zeta[], (0 | 1)[]]; + + // Now, pass the filtered zetas and answers to the cat's 
updateAbilityEstimate method + this.cats[catName].updateAbilityEstimate(zetas, answers, method); + } } if (this._earlyStopping) { @@ -292,9 +297,22 @@ export class Clowder { // ----------| Select |----------| // +----------+ + if (catToSelect === 'unvalidated') { + // Assign items with no valid parameters to the 'unvalidated' cat + const unvalidatedRemainingItems = this._remainingItems.filter( + (stim) => !stim.zetas.some((zeta: ZetaCatMap) => zeta.cats.length > 0), + ); + + if (unvalidatedRemainingItems.length === 0) { + return undefined; + } else { + const randInt = Math.floor(this._rng() * unvalidatedRemainingItems.length); + return unvalidatedRemainingItems[randInt]; + } + } + // Now, we need to dynamically calculate the stimuli available for selection by `catToSelect`. // We inspect the remaining items and find ones that have zeta parameters for `catToSelect` - const { available, missing } = filterItemsByCatParameterAvailability(this._remainingItems, catToSelect); // The cat expects an array of Stimulus objects, with the zeta parameters From 77a3bfca26517f693348bf8a5ba969bf83f40d52 Mon Sep 17 00:00:00 2001 From: emily-ejag Date: Tue, 1 Oct 2024 09:19:48 -0700 Subject: [PATCH 34/47] adding returnUndefinedOnExhaustion parameter and test --- src/__tests__/clowder.test.ts | 30 ++++++++++++++++++++++++++++++ src/clowder.ts | 30 +++++++++++++++--------------- 2 files changed, 45 insertions(+), 15 deletions(-) diff --git a/src/__tests__/clowder.test.ts b/src/__tests__/clowder.test.ts index 380a739..5faee14 100644 --- a/src/__tests__/clowder.test.ts +++ b/src/__tests__/clowder.test.ts @@ -95,6 +95,36 @@ describe('Clowder Class', () => { ); }); + it('should return undefined if no validated items remain and returnUndefinedOnExhaustion is true', () => { + const clowderInput: ClowderInput = { + cats: { + cat1: { method: 'MLE', theta: 0.5 }, + }, + corpus: [ + createMultiZetaStimulus('0', [createZetaCatMap(['cat1'])]), + createMultiZetaStimulus('1', 
[createZetaCatMap(['cat1'])]), + ], + }; + + const clowder = new Clowder(clowderInput); + + // Use all the validated items for cat1 + clowder.updateCatAndGetNextItem({ + catToSelect: 'cat1', + catsToUpdate: ['cat1'], + items: [clowder.corpus[0], clowder.corpus[1]], + answers: [1, 1], + }); + + // Try to get another validated item for cat1 with returnUndefinedOnExhaustion set to true + const nextItem = clowder.updateCatAndGetNextItem({ + catToSelect: 'cat1', + returnUndefinedOnExhaustion: true, + }); + + expect(nextItem).toBeUndefined(); + }); + it.each` property ${'theta'} diff --git a/src/clowder.ts b/src/clowder.ts index c39cb46..04f44e1 100644 --- a/src/clowder.ts +++ b/src/clowder.ts @@ -212,6 +212,7 @@ export class Clowder { method, itemSelect, randomlySelectUnvalidated = false, + returnUndefinedOnExhaustion = false, // New parameter }: { catToSelect: string; catsToUpdate?: string | string[]; @@ -220,11 +221,11 @@ export class Clowder { method?: string; itemSelect?: string; randomlySelectUnvalidated?: boolean; + returnUndefinedOnExhaustion?: boolean; // New parameter type }): Stimulus | undefined { // +----------+ // ----------| Update |----------| // +----------+ - this._validateCatName(catToSelect, true); catsToUpdate = Array.isArray(catsToUpdate) ? catsToUpdate : [catsToUpdate]; catsToUpdate.forEach((cat) => { @@ -293,6 +294,7 @@ export class Clowder { } } + // Handle the 'unvalidated' cat selection // +----------+ // ----------| Select |----------| // +----------+ @@ -310,7 +312,6 @@ export class Clowder { return unvalidatedRemainingItems[randInt]; } } - // Now, we need to dynamically calculate the stimuli available for selection by `catToSelect`. 
// We inspect the remaining items and find ones that have zeta parameters for `catToSelect` const { available, missing } = filterItemsByCatParameterAvailability(this._remainingItems, catToSelect); @@ -340,26 +341,25 @@ export class Clowder { 'guessing', 'slipping', ]); - // Again `nextStimulus` will be a Stimulus object, or `undefined` if no further validated stimuli are available. // We need to convert the Stimulus object back to a MultiZetaStimulus object to return to the user. const returnStimulus: MultiZetaStimulus | undefined = available.find((stim) => _isEqual(stim, nextStimulusWithoutZeta), ); - if (missing.length === 0) { - // If there are no more unvalidated stimuli, we only have validated items left. - // Use the Cat to find the next item. The Cat may return undefined if all validated items have been seen. - return returnStimulus; - } else if (available.length === 0) { - // In this case, there are no more validated items left. Choose an unvalidated item at random. - return missing[Math.floor(this._rng() * missing.length)]; - } else { - // In this case, there are both validated and unvalidated items left. - // We randomly insert unvalidated items - if (!randomlySelectUnvalidated) { - return returnStimulus; + // Determine behavior based on returnUndefinedOnExhaustion + if (available.length === 0) { + // If returnUndefinedOnExhaustion is true and no validated items remain for the specified catToSelect, return undefined. + if (returnUndefinedOnExhaustion) { + return undefined; // Return undefined if no validated items remain + } else { + // If returnUndefinedOnExhaustion is false, proceed with the fallback mechanism to select an item from other available categories. 
+ return missing[Math.floor(this._rng() * missing.length)]; } + } else if (missing.length === 0 || !randomlySelectUnvalidated) { + return returnStimulus; // Return validated item if available + } else { + // Randomly decide whether to return a validated or unvalidated item const random = Math.random(); const numRemaining = { available: available.length, missing: missing.length }; return random < numRemaining.missing / (numRemaining.available + numRemaining.missing) From 06d170971aac58d941a1370b285f5d79244f0ee5 Mon Sep 17 00:00:00 2001 From: emily-ejag Date: Wed, 2 Oct 2024 11:42:53 -0700 Subject: [PATCH 35/47] implementing suggestions --- src/__tests__/clowder.test.ts | 49 +++++++++++++++++++++++++++++------ src/clowder.ts | 16 ++++++++---- 2 files changed, 52 insertions(+), 13 deletions(-) diff --git a/src/__tests__/clowder.test.ts b/src/__tests__/clowder.test.ts index 5faee14..3c6f915 100644 --- a/src/__tests__/clowder.test.ts +++ b/src/__tests__/clowder.test.ts @@ -125,6 +125,36 @@ describe('Clowder Class', () => { expect(nextItem).toBeUndefined(); }); + it('should return an item from missing if catToSelect is "unvalidated", no unvalidated items remain, and returnUndefinedOnExhaustion is false', () => { + const clowderInput: ClowderInput = { + cats: { + cat1: { method: 'MLE', theta: 0.5 }, + }, + corpus: [ + createMultiZetaStimulus('0', [createZetaCatMap(['cat1'])]), // Validated item + createMultiZetaStimulus('1', [createZetaCatMap([])]), // Unvalidated item + ], + }; + + const clowder = new Clowder(clowderInput); + + // Exhaust the unvalidated items + clowder.updateCatAndGetNextItem({ + catToSelect: 'unvalidated', + items: [clowder.corpus[1]], + answers: [1], + }); + + // Attempt to get another unvalidated item with returnUndefinedOnExhaustion set to false + const nextItem = clowder.updateCatAndGetNextItem({ + catToSelect: 'unvalidated', + returnUndefinedOnExhaustion: false, + }); + + // Should return the validated item since no unvalidated items remain + 
expect(nextItem?.id).toBe('0'); + }); + it.each` property ${'theta'} @@ -212,27 +242,30 @@ describe('Clowder Class', () => { expect(nextItem?.id).toMatch(/^(0|1)$/); }); - it('should select an unvalidated item if no validated items remain', () => { + it('should return an item from missing if no validated items remain and returnUndefinedOnExhaustion is false', () => { const clowderInput: ClowderInput = { cats: { cat1: { method: 'MLE', theta: 0.5 }, + cat2: { method: 'EAP', theta: -1.0 }, }, corpus: [ - createMultiZetaStimulus('0', [createZetaCatMap([])]), - createMultiZetaStimulus('1', [createZetaCatMap(['cat1'])]), - createMultiZetaStimulus('2', [createZetaCatMap([])]), + createMultiZetaStimulus('0', [createZetaCatMap(['cat2'])]), // Validated for cat2 + createMultiZetaStimulus('1', [createZetaCatMap(['cat2'])]), // Validated for cat2 + createMultiZetaStimulus('2', [createZetaCatMap(['cat2'])]), // Validated for cat2 ], }; + const clowder = new Clowder(clowderInput); + // Attempt to select an item for cat1, which has no validated items in the corpus const nextItem = clowder.updateCatAndGetNextItem({ catToSelect: 'cat1', - catsToUpdate: ['cat1'], - items: [clowder.corpus[1]], - answers: [1], + returnUndefinedOnExhaustion: false, // Ensure fallback is enabled }); + + // Should return an item from `missing`, which are items validated for cat2 expect(nextItem).toBeDefined(); - expect(['0', '2']).toContain(nextItem?.id); + expect(['0', '1', '2']).toContain(nextItem?.id); // Item ID should match any of the items for cat2 }); it('should select an unvalidated item if catToSelect is "unvalidated"', () => { diff --git a/src/clowder.ts b/src/clowder.ts index 04f44e1..6c9ef4a 100644 --- a/src/clowder.ts +++ b/src/clowder.ts @@ -212,7 +212,7 @@ export class Clowder { method, itemSelect, randomlySelectUnvalidated = false, - returnUndefinedOnExhaustion = false, // New parameter + returnUndefinedOnExhaustion = true, // New parameter }: { catToSelect: string; catsToUpdate?: string 
| string[]; @@ -299,22 +299,28 @@ export class Clowder { // ----------| Select |----------| // +----------+ + // We inspect the remaining items and find ones that have zeta parameters for `catToSelect` + const { available, missing } = filterItemsByCatParameterAvailability(this._remainingItems, catToSelect); + + // Handle the 'unvalidated' cat selection if (catToSelect === 'unvalidated') { - // Assign items with no valid parameters to the 'unvalidated' cat const unvalidatedRemainingItems = this._remainingItems.filter( (stim) => !stim.zetas.some((zeta: ZetaCatMap) => zeta.cats.length > 0), ); if (unvalidatedRemainingItems.length === 0) { + // If returnUndefinedOnExhaustion is false, return an item from 'missing' + if (!returnUndefinedOnExhaustion && missing.length > 0) { + const randInt = Math.floor(this._rng() * missing.length); + return missing[randInt]; + } + return undefined; } else { const randInt = Math.floor(this._rng() * unvalidatedRemainingItems.length); return unvalidatedRemainingItems[randInt]; } } - // Now, we need to dynamically calculate the stimuli available for selection by `catToSelect`. - // We inspect the remaining items and find ones that have zeta parameters for `catToSelect` - const { available, missing } = filterItemsByCatParameterAvailability(this._remainingItems, catToSelect); // The cat expects an array of Stimulus objects, with the zeta parameters // spread at the top-level of each Stimulus object. 
So we need to convert From d641960907b03867c008567feed65c3e812726f4 Mon Sep 17 00:00:00 2001 From: Adam Richie-Halford Date: Sat, 5 Oct 2024 03:25:57 -0700 Subject: [PATCH 36/47] Separate the stopping classes so that they don't share the same input --- src/__tests__/clowder.test.ts | 43 ++++++---- src/__tests__/stopping.test.ts | 60 +++++++++++--- src/clowder.ts | 2 +- src/stopping.ts | 141 ++++++++++++++++++++++----------- 4 files changed, 168 insertions(+), 78 deletions(-) diff --git a/src/__tests__/clowder.test.ts b/src/__tests__/clowder.test.ts index 3c6f915..1dcb723 100644 --- a/src/__tests__/clowder.test.ts +++ b/src/__tests__/clowder.test.ts @@ -133,6 +133,7 @@ describe('Clowder Class', () => { corpus: [ createMultiZetaStimulus('0', [createZetaCatMap(['cat1'])]), // Validated item createMultiZetaStimulus('1', [createZetaCatMap([])]), // Unvalidated item + createMultiZetaStimulus('2', [createZetaCatMap(['cat1'])]), // Unvalidated item ], }; @@ -145,14 +146,19 @@ describe('Clowder Class', () => { answers: [1], }); - // Attempt to get another unvalidated item with returnUndefinedOnExhaustion set to false - const nextItem = clowder.updateCatAndGetNextItem({ - catToSelect: 'unvalidated', - returnUndefinedOnExhaustion: false, - }); + const nDraws = 50; + // Simulate sDraws unvalidated items being selected + // eslint-disable-next-line @typescript-eslint/no-unused-vars + for (const _ of Array(nDraws).fill(0)) { + // Attempt to get another unvalidated item with returnUndefinedOnExhaustion set to false + const nextItem = clowder.updateCatAndGetNextItem({ + catToSelect: 'unvalidated', + returnUndefinedOnExhaustion: false, + }); - // Should return the validated item since no unvalidated items remain - expect(nextItem?.id).toBe('0'); + // Should return a validated item since no unvalidated items remain + expect(['0', '2']).toContain(nextItem?.id); // Item ID should match any of the items for cat2 + } }); it.each` @@ -251,21 +257,24 @@ describe('Clowder Class', () => { 
corpus: [ createMultiZetaStimulus('0', [createZetaCatMap(['cat2'])]), // Validated for cat2 createMultiZetaStimulus('1', [createZetaCatMap(['cat2'])]), // Validated for cat2 - createMultiZetaStimulus('2', [createZetaCatMap(['cat2'])]), // Validated for cat2 + createMultiZetaStimulus('2', [createZetaCatMap([])]), // Unvalidated ], }; const clowder = new Clowder(clowderInput); - // Attempt to select an item for cat1, which has no validated items in the corpus - const nextItem = clowder.updateCatAndGetNextItem({ - catToSelect: 'cat1', - returnUndefinedOnExhaustion: false, // Ensure fallback is enabled - }); - - // Should return an item from `missing`, which are items validated for cat2 - expect(nextItem).toBeDefined(); - expect(['0', '1', '2']).toContain(nextItem?.id); // Item ID should match any of the items for cat2 + // Should return an item from `missing`, which are items validated for cat2 or unvalidated + const nDraws = 50; + // Simulate sDraws unvalidated items being selected + // eslint-disable-next-line @typescript-eslint/no-unused-vars + for (const _ of Array(nDraws).fill(0)) { + // Attempt to select an item for cat1, which has no validated items in the corpus + const nextItem = clowder.updateCatAndGetNextItem({ + catToSelect: 'cat1', + returnUndefinedOnExhaustion: false, // Ensure fallback is enabled + }); + expect(['0', '1', '2']).toContain(nextItem?.id); // Item ID should match any of the items for cat2 + } }); it('should select an unvalidated item if catToSelect is "unvalidated"', () => { diff --git a/src/__tests__/stopping.test.ts b/src/__tests__/stopping.test.ts index f34f469..d8aeb5c 100644 --- a/src/__tests__/stopping.test.ts +++ b/src/__tests__/stopping.test.ts @@ -2,19 +2,49 @@ import { Cat } from '..'; import { CatMap } from '../type'; import { EarlyStopping, - EarlyStoppingInput, StopAfterNItems, + StopAfterNItemsInput, StopIfSEMeasurementBelowThreshold, + StopIfSEMeasurementBelowThresholdInput, StopOnSEMeasurementPlateau, + 
StopOnSEMeasurementPlateauInput, } from '../stopping'; import { toBeBoolean } from 'jest-extended'; expect.extend({ toBeBoolean }); -const testInstantiation = (earlyStopping: EarlyStopping, input: EarlyStoppingInput) => { - expect(earlyStopping.patience).toEqual(input.patience ?? {}); - expect(earlyStopping.tolerance).toEqual(input.tolerance ?? {}); - expect(earlyStopping.requiredItems).toEqual(input.requiredItems ?? {}); - expect(earlyStopping.seMeasurementThreshold).toEqual(input.seMeasurementThreshold ?? {}); +type Class = new (...args: any[]) => T; + +const testLogicalOperationValidation = ( + stoppingClass: Class, + input: StopAfterNItemsInput | StopIfSEMeasurementBelowThresholdInput | StopOnSEMeasurementPlateauInput, +) => { + expect(() => new stoppingClass(input)).toThrowError( + `Invalid logical operation. Expected "and" or "or". Received "${input.logicalOperation}"`, + ); +}; + +const testInstantiation = ( + earlyStopping: EarlyStopping, + input: StopAfterNItemsInput | StopIfSEMeasurementBelowThresholdInput | StopOnSEMeasurementPlateauInput, +) => { + if (earlyStopping instanceof StopAfterNItems) { + expect(earlyStopping.requiredItems).toEqual((input as StopAfterNItems).requiredItems ?? {}); + } + + if ( + earlyStopping instanceof StopOnSEMeasurementPlateau || + earlyStopping instanceof StopIfSEMeasurementBelowThreshold + ) { + expect(earlyStopping.patience).toEqual((input as StopOnSEMeasurementPlateauInput).patience ?? {}); + expect(earlyStopping.tolerance).toEqual((input as StopOnSEMeasurementPlateauInput).tolerance ?? {}); + } + + if (earlyStopping instanceof StopIfSEMeasurementBelowThreshold) { + expect(earlyStopping.seMeasurementThreshold).toEqual( + (input as StopIfSEMeasurementBelowThresholdInput).seMeasurementThreshold ?? {}, + ); + } + expect(earlyStopping.logicalOperation).toBe(input.logicalOperation?.toLowerCase() ?? 
'or'); expect(earlyStopping.earlyStop).toBeBoolean(); }; @@ -62,7 +92,7 @@ describe.each` ${'or'} `("StopOnSEMeasurementPlateau (with logicalOperation='$logicalOperation'", ({ logicalOperation }) => { let earlyStopping: StopOnSEMeasurementPlateau; - let input: EarlyStoppingInput; + let input: StopOnSEMeasurementPlateauInput; beforeEach(() => { input = { @@ -74,7 +104,8 @@ describe.each` }); it('instantiates with input parameters', () => testInstantiation(earlyStopping, input)); - + it('validates input', () => + testLogicalOperationValidation(StopOnSEMeasurementPlateau, { ...input, logicalOperation: 'invalid' as 'and' })); it('updates internal state when new measurements are added', () => testInternalState(earlyStopping)); it('stops when the seMeasurement has plateaued', () => { @@ -270,7 +301,7 @@ describe.each` ${'or'} `("StopAfterNItems (with logicalOperation='$logicalOperation'", ({ logicalOperation }) => { let earlyStopping: StopAfterNItems; - let input: EarlyStoppingInput; + let input: StopAfterNItemsInput; beforeEach(() => { input = { @@ -281,7 +312,8 @@ describe.each` }); it('instantiates with input parameters', () => testInstantiation(earlyStopping, input)); - + it('validates input', () => + testLogicalOperationValidation(StopAfterNItems, { ...input, logicalOperation: 'invalid' as 'and' })); it('updates internal state when new measurements are added', () => testInternalState(earlyStopping)); it('does not step when it has not seen required items', () => { @@ -424,7 +456,7 @@ describe.each` ${'or'} `("StopIfSEMeasurementBelowThreshold (with logicalOperation='$logicalOperation'", ({ logicalOperation }) => { let earlyStopping: StopIfSEMeasurementBelowThreshold; - let input: EarlyStoppingInput; + let input: StopIfSEMeasurementBelowThresholdInput; beforeEach(() => { input = { @@ -437,7 +469,11 @@ describe.each` }); it('instantiates with input parameters', () => testInstantiation(earlyStopping, input)); - + it('validates input', () => + 
testLogicalOperationValidation(StopIfSEMeasurementBelowThreshold, { + ...input, + logicalOperation: 'invalid' as 'and', + })); it('updates internal state when new measurements are added', () => testInternalState(earlyStopping)); it('stops when the seMeasurement has fallen below a threshold', () => { diff --git a/src/clowder.ts b/src/clowder.ts index 6c9ef4a..dd7c43d 100644 --- a/src/clowder.ts +++ b/src/clowder.ts @@ -212,7 +212,7 @@ export class Clowder { method, itemSelect, randomlySelectUnvalidated = false, - returnUndefinedOnExhaustion = true, // New parameter + returnUndefinedOnExhaustion = true, }: { catToSelect: string; catsToUpdate?: string | string[]; diff --git a/src/stopping.ts b/src/stopping.ts index ccf41ba..9311b6d 100644 --- a/src/stopping.ts +++ b/src/stopping.ts @@ -6,16 +6,29 @@ import _uniq from 'lodash/uniq'; * Interface for input parameters to EarlyStopping classes. */ export interface EarlyStoppingInput { + /** The logical operation to use to evaluate multiple stopping conditions */ + logicalOperation?: 'and' | 'or' | 'AND' | 'OR'; +} + +export interface StopAfterNItemsInput extends EarlyStoppingInput { + /** Number of items to require before stopping */ + requiredItems: CatMap; +} + +export interface StopOnSEMeasurementPlateauInput extends EarlyStoppingInput { /** Number of items to wait for before triggering early stopping */ - patience?: CatMap; + patience: CatMap; /** Tolerance for standard error of measurement drop */ tolerance?: CatMap; - /** Number of items to require before stopping */ - requiredItems?: CatMap; +} + +export interface StopIfSEMeasurementBelowThresholdInput extends EarlyStoppingInput { /** Stop if the standard error of measurement drops below this level */ - seMeasurementThreshold?: CatMap; - /** The logical operation to use to evaluate multiple stopping conditions */ - logicalOperation?: 'and' | 'or' | 'AND' | 'OR'; + seMeasurementThreshold: CatMap; + /** Number of items to wait for before triggering early stopping */ + 
patience?: CatMap; + /** Tolerance for standard error of measurement drop */ + tolerance?: CatMap; } /** @@ -23,57 +36,22 @@ export interface EarlyStoppingInput { */ export abstract class EarlyStopping { protected _earlyStop: boolean; - protected _patience: CatMap; - protected _tolerance: CatMap; - protected _requiredItems: CatMap; - protected _seMeasurementThreshold: CatMap; protected _nItems: CatMap; protected _seMeasurements: CatMap; protected _logicalOperation: 'and' | 'or'; - constructor({ - patience = {}, - tolerance = {}, - requiredItems = {}, - seMeasurementThreshold = {}, - logicalOperation = 'or', - }: EarlyStoppingInput) { - // TODO: Add some input validation here - // logicalOperation.toLowerCase() should be 'and' or 'or' - this._patience = patience; - this._tolerance = tolerance; - this._requiredItems = requiredItems; - this._seMeasurementThreshold = seMeasurementThreshold; + constructor({ logicalOperation = 'or' }: EarlyStoppingInput) { this._seMeasurements = {}; this._nItems = {}; this._earlyStop = false; - this._logicalOperation = logicalOperation.toLowerCase() as 'and' | 'or'; - } - - public get evaluationCats() { - return _uniq([ - ...Object.keys(this._patience), - ...Object.keys(this._tolerance), - ...Object.keys(this._requiredItems), - ...Object.keys(this._seMeasurementThreshold), - ]); - } - public get patience() { - return this._patience; - } - - public get tolerance() { - return this._tolerance; - } - - public get requiredItems() { - return this._requiredItems; + if (!['and', 'or'].includes(logicalOperation.toLowerCase())) { + throw new Error(`Invalid logical operation. Expected "and" or "or". 
Received "${logicalOperation}"`); + } + this._logicalOperation = logicalOperation.toLowerCase() as 'and' | 'or'; } - public get seMeasurementThreshold() { - return this._seMeasurementThreshold; - } + public abstract get evaluationCats(): string[]; public get earlyStop() { return this._earlyStop; @@ -135,6 +113,27 @@ export abstract class EarlyStopping { * Class implementing early stopping based on a plateau in standard error of measurement. */ export class StopOnSEMeasurementPlateau extends EarlyStopping { + protected _patience: CatMap; + protected _tolerance: CatMap; + + constructor(input: StopOnSEMeasurementPlateauInput) { + super(input); + this._patience = input.patience; + this._tolerance = input.tolerance ?? {}; + } + + public get evaluationCats() { + return _uniq([...Object.keys(this._patience), ...Object.keys(this._tolerance)]); + } + + public get patience() { + return this._patience; + } + + public get tolerance() { + return this._tolerance; + } + protected _evaluateStoppingCondition(catToEvaluate: string) { const seMeasurements = this._seMeasurements[catToEvaluate]; @@ -161,9 +160,24 @@ export class StopOnSEMeasurementPlateau extends EarlyStopping { * Class implementing early stopping after a certain number of items. */ export class StopAfterNItems extends EarlyStopping { + protected _requiredItems: CatMap; + + constructor(input: StopAfterNItemsInput) { + super(input); + this._requiredItems = input.requiredItems; + } + + public get requiredItems() { + return this._requiredItems; + } + + public get evaluationCats() { + return Object.keys(this._requiredItems); + } + protected _evaluateStoppingCondition(catToEvaluate: string) { const requiredItems = this._requiredItems[catToEvaluate]; - const nItems = this._nItems[catToEvaluate] ?? 
0; + const nItems = this._nItems[catToEvaluate]; let earlyStop = false; @@ -179,6 +193,37 @@ export class StopAfterNItems extends EarlyStopping { * Class implementing early stopping if the standard error of measurement drops below a certain threshold. */ export class StopIfSEMeasurementBelowThreshold extends EarlyStopping { + protected _patience: CatMap; + protected _tolerance: CatMap; + protected _seMeasurementThreshold: CatMap; + + constructor(input: StopIfSEMeasurementBelowThresholdInput) { + super(input); + this._seMeasurementThreshold = input.seMeasurementThreshold; + this._patience = input.patience ?? {}; + this._tolerance = input.tolerance ?? {}; + } + + public get patience() { + return this._patience; + } + + public get tolerance() { + return this._tolerance; + } + + public get seMeasurementThreshold() { + return this._seMeasurementThreshold; + } + + public get evaluationCats() { + return _uniq([ + ...Object.keys(this._patience), + ...Object.keys(this._tolerance), + ...Object.keys(this._seMeasurementThreshold), + ]); + } + protected _evaluateStoppingCondition(catToEvaluate: string) { const seMeasurements = this._seMeasurements[catToEvaluate] ?? []; const seThreshold = this._seMeasurementThreshold[catToEvaluate] ?? 
0; From 07e1f58b070bf960d623160ab6bf71a049fbd898 Mon Sep 17 00:00:00 2001 From: emily-ejag Date: Mon, 7 Oct 2024 12:53:53 -0700 Subject: [PATCH 37/47] updating cats for clowder --- src/__tests__/stopping.test.ts | 1 + 1 file changed, 1 insertion(+) diff --git a/src/__tests__/stopping.test.ts b/src/__tests__/stopping.test.ts index d8aeb5c..7da8222 100644 --- a/src/__tests__/stopping.test.ts +++ b/src/__tests__/stopping.test.ts @@ -12,6 +12,7 @@ import { import { toBeBoolean } from 'jest-extended'; expect.extend({ toBeBoolean }); +/* eslint-disable @typescript-eslint/no-explicit-any */ type Class = new (...args: any[]) => T; const testLogicalOperationValidation = ( From a0ac2660fab96949a0aaff1dd97962b554460bb6 Mon Sep 17 00:00:00 2001 From: emily-ejag Date: Thu, 17 Oct 2024 16:18:35 -0700 Subject: [PATCH 38/47] clowder changes based on letter implementation --- src/clowder.ts | 56 +++++++++++++++++++++++++------------------------ src/index.ts | 6 ++++++ src/stopping.ts | 42 +++++++++++++++++++++++++++++-------- 3 files changed, 68 insertions(+), 36 deletions(-) diff --git a/src/clowder.ts b/src/clowder.ts index dd7c43d..f037067 100644 --- a/src/clowder.ts +++ b/src/clowder.ts @@ -257,33 +257,35 @@ export class Clowder { const itemsAndAnswers = _zip(items, answers) as [Stimulus, 0 | 1][]; // Update the ability estimate for all validated cats - for (const catName of catsToUpdate) { - const itemsAndAnswersForCat = itemsAndAnswers.filter(([stim]) => - // We are dealing with a single item in this function. This single item - // has an array of zeta parameters for a bunch of different Cats. We - // need to determine if `catName` is present in that list. So we first - // reduce the zetas to get all of the applicabe cat names. - // Now that we have the subset of items that can apply to this cat, - // retrieve only the item parameters that apply to this cat. 
- stim.zetas.some((zeta: ZetaCatMap) => zeta.cats.includes(catName)), - ); - - if (itemsAndAnswersForCat.length > 0) { - const zetasAndAnswersForCat = itemsAndAnswersForCat - .map(([stim, _answer]) => { - const zetaForCat: ZetaCatMap | undefined = stim.zetas.find((zeta: ZetaCatMap) => - zeta.cats.includes(catName), - ); - // eslint-disable-next-line @typescript-eslint/no-non-null-assertion - return [zetaForCat!.zeta, _answer]; // Optional chaining in case zetaForCat is undefined - }) - .filter(([zeta]) => zeta !== undefined); // Filter out undefined zeta values - - // Unzip the zetas and answers, making sure the zetas array contains only Zeta types - const [zetas, answers] = _unzip(zetasAndAnswersForCat) as [Zeta[], (0 | 1)[]]; - - // Now, pass the filtered zetas and answers to the cat's updateAbilityEstimate method - this.cats[catName].updateAbilityEstimate(zetas, answers, method); + if (catsToUpdate.includes(catToSelect)) { + for (const catName of catsToUpdate) { + const itemsAndAnswersForCat = itemsAndAnswers?.filter(([stim]) => + // We are dealing with a single item in this function. This single item + // has an array of zeta parameters for a bunch of different Cats. We + // need to determine if `catName` is present in that list. So we first + // reduce the zetas to get all of the applicabe cat names. + // Now that we have the subset of items that can apply to this cat, + // retrieve only the item parameters that apply to this cat. 
+ stim.zetas.some((zeta: ZetaCatMap) => zeta.cats.includes(catName)), + ); + + if (itemsAndAnswersForCat.length > 0) { + const zetasAndAnswersForCat = itemsAndAnswersForCat + .map(([stim, _answer]) => { + const zetaForCat: ZetaCatMap | undefined = stim.zetas.find((zeta: ZetaCatMap) => + zeta.cats.includes(catName), + ); + // eslint-disable-next-line @typescript-eslint/no-non-null-assertion + return [zetaForCat!.zeta, _answer]; // Optional chaining in case zetaForCat is undefined + }) + .filter(([zeta]) => zeta !== undefined); // Filter out undefined zeta values + + // Unzip the zetas and answers, making sure the zetas array contains only Zeta types + const [zetas, answers] = _unzip(zetasAndAnswersForCat) as [Zeta[], (0 | 1)[]]; + + // Now call updateAbilityEstimates for this cat + this.updateAbilityEstimates([catName], zetas, answers, method); + } } } diff --git a/src/index.ts b/src/index.ts index 00c4cbe..4529537 100644 --- a/src/index.ts +++ b/src/index.ts @@ -1,2 +1,8 @@ export { Cat, CatInput } from './cat'; export { Clowder, ClowderInput } from './clowder'; +export { + EarlyStopping, + StopAfterNItems, + StopOnSEMeasurementPlateau, + StopIfSEMeasurementBelowThreshold, +} from './stopping'; diff --git a/src/stopping.ts b/src/stopping.ts index 9311b6d..a66707c 100644 --- a/src/stopping.ts +++ b/src/stopping.ts @@ -38,17 +38,17 @@ export abstract class EarlyStopping { protected _earlyStop: boolean; protected _nItems: CatMap; protected _seMeasurements: CatMap; - protected _logicalOperation: 'and' | 'or'; + protected _logicalOperation: 'and' | 'or' | 'only'; constructor({ logicalOperation = 'or' }: EarlyStoppingInput) { this._seMeasurements = {}; this._nItems = {}; this._earlyStop = false; - if (!['and', 'or'].includes(logicalOperation.toLowerCase())) { - throw new Error(`Invalid logical operation. Expected "and" or "or". 
Received "${logicalOperation}"`); + if (!['and', 'or', 'only'].includes(logicalOperation.toLowerCase())) { + throw new Error(`Invalid logical operation. Expected "and", "or", or "only". Received "${logicalOperation}"`); } - this._logicalOperation = logicalOperation.toLowerCase() as 'and' | 'or'; + this._logicalOperation = logicalOperation.toLowerCase() as 'and' | 'or' | 'only'; } public abstract get evaluationCats(): string[]; @@ -96,15 +96,39 @@ export abstract class EarlyStopping { * Abstract method to be implemented by subclasses to update the early stopping strategy. * @param {CatMap} cats - A map of cats to update. */ - public update(cats: CatMap): void { - this._updateCats(cats); + public update(cats: CatMap, catToSelect?: string): void { + this._updateCats(cats); // This updates internal state with current cat data + // Iterate over each cat and update the _nItems map + for (const catName in cats) { + const cat = cats[catName]; + const nItems = cat.nItems; // Get the current number of items for this cat + + // Update the _nItems map with the current nItems value + if (nItems !== undefined) { + this._nItems[catName] = nItems; // Make sure nItems is set for this cat + } + } + + // Collect the stopping conditions for all cats const conditions: boolean[] = this.evaluationCats.map((catName) => this._evaluateStoppingCondition(catName)); + // Evaluate the stopping condition based on the logical operation if (this._logicalOperation === 'and') { - this._earlyStop = conditions.every(Boolean); - } else { - this._earlyStop = conditions.some(Boolean); + this._earlyStop = conditions.every(Boolean); // All conditions must be true for 'and' + } else if (this._logicalOperation === 'or') { + this._earlyStop = conditions.some(Boolean); // Any condition can be true for 'or' + } else if (this._logicalOperation === 'only') { + if (catToSelect === undefined) { + throw new Error('Must provide a cat to select for "only" stopping condition'); + } + + // Evaluate the stopping 
condition for the selected cat + if (this.evaluationCats.includes(catToSelect)) { + this._earlyStop = this._evaluateStoppingCondition(catToSelect); + } else { + this._earlyStop = false; // Default to false if the selected cat is not in evaluationCats + } } } } From 3674a1f3d77acfce03e405113e0266d5061b768f Mon Sep 17 00:00:00 2001 From: emily-ejag Date: Fri, 18 Oct 2024 14:08:40 -0700 Subject: [PATCH 39/47] addressing all lines of code for testing --- src/__tests__/clowder.test.ts | 4 +- src/__tests__/stopping.test.ts | 142 ++++++++++++++++++++++++++++++++- src/clowder.ts | 2 +- src/stopping.ts | 4 +- 4 files changed, 146 insertions(+), 6 deletions(-) diff --git a/src/__tests__/clowder.test.ts b/src/__tests__/clowder.test.ts index 1dcb723..6f1a576 100644 --- a/src/__tests__/clowder.test.ts +++ b/src/__tests__/clowder.test.ts @@ -327,8 +327,8 @@ describe('Clowder Class', () => { catToSelect: 'unvalidated', }); - expect(clowder.nItems.cat1).toBe(2); - expect(clowder.nItems.cat2).toBe(2); + expect(clowder.nItems.cat1).toBe(0); + expect(clowder.nItems.cat2).toBe(0); }); it('should not update any cats if only unvalidated items have been seen', () => { diff --git a/src/__tests__/stopping.test.ts b/src/__tests__/stopping.test.ts index 7da8222..4a9d9c7 100644 --- a/src/__tests__/stopping.test.ts +++ b/src/__tests__/stopping.test.ts @@ -20,7 +20,7 @@ const testLogicalOperationValidation = ( input: StopAfterNItemsInput | StopIfSEMeasurementBelowThresholdInput | StopOnSEMeasurementPlateauInput, ) => { expect(() => new stoppingClass(input)).toThrowError( - `Invalid logical operation. Expected "and" or "or". Received "${input.logicalOperation}"`, + `Invalid logical operation. Expected "and", "or", or "only". 
Received "${input.logicalOperation}"`, ); }; @@ -451,6 +451,146 @@ describe.each` }); }); +describe('EarlyStopping with logicalOperation "only"', () => { + let earlyStopping: StopOnSEMeasurementPlateau; + let input: StopOnSEMeasurementPlateauInput; + + beforeEach(() => { + input = { + patience: { cat1: 2, cat2: 3 }, + tolerance: { cat1: 0.01, cat2: 0.02 }, + logicalOperation: 'only', + }; + earlyStopping = new StopOnSEMeasurementPlateau(input); + }); + + it('throws an error if catToSelect is not provided when logicalOperation is "only"', () => { + expect(() => { + earlyStopping.update({ cat1: { nItems: 1, seMeasurement: 0.5 } as any }, undefined); + }).toThrowError('Must provide a cat to select for "only" stopping condition'); + }); +}); + +describe('EarlyStopping with logicalOperation "only"', () => { + let earlyStopping: StopOnSEMeasurementPlateau; + let input: StopOnSEMeasurementPlateauInput; + + beforeEach(() => { + input = { + patience: { cat1: 2, cat2: 3 }, + tolerance: { cat1: 0.01, cat2: 0.02 }, + logicalOperation: 'only', + }; + earlyStopping = new StopOnSEMeasurementPlateau(input); + }); + + it('evaluates the stopping condition when catToSelect is in evaluationCats', () => { + // Add updates to make sure cat1 is included in evaluationCats and has some measurements + earlyStopping.update({ cat1: { nItems: 1, seMeasurement: 0.5 } as any }, 'cat1'); + earlyStopping.update({ cat1: { nItems: 2, seMeasurement: 0.5 } as any }, 'cat1'); + + // Since 'cat1' is in evaluationCats, _earlyStop should be evaluated based on the stopping condition + expect(earlyStopping.earlyStop).toBe(true); // Should be true because seMeasurement has plateaued + }); +}); +describe('EarlyStopping with logicalOperation "only"', () => { + let earlyStopping: StopOnSEMeasurementPlateau; + let input: StopOnSEMeasurementPlateauInput; + + beforeEach(() => { + input = { + patience: { cat1: 2, cat2: 3 }, + tolerance: { cat1: 0.01, cat2: 0.02 }, + logicalOperation: 'only', + }; + earlyStopping = 
new StopOnSEMeasurementPlateau(input); + }); + + it('sets _earlyStop to false when catToSelect is not in evaluationCats', () => { + // Use 'cat3', which is not in the patience or tolerance maps (and thus not in evaluationCats) + earlyStopping.update({ cat3: { nItems: 1, seMeasurement: 0.5 } as any }, 'cat3'); + + // Since 'cat3' is not in evaluationCats, _earlyStop should be false + expect(earlyStopping.earlyStop).toBe(false); + }); +}); + +describe('StopIfSEMeasurementBelowThreshold with empty patience and tolerance', () => { + let earlyStopping: StopIfSEMeasurementBelowThreshold; + let input: StopIfSEMeasurementBelowThresholdInput; + + beforeEach(() => { + input = { + seMeasurementThreshold: { cat1: 0.03, cat2: 0.02 }, + logicalOperation: 'only', + }; + earlyStopping = new StopIfSEMeasurementBelowThreshold(input); + }); + + it('should handle updates correctly even with empty patience and tolerance', () => { + // Update the state with some measurements for cat2, where seMeasurement is below the threshold + earlyStopping.update({ cat2: { nItems: 1, seMeasurement: 0.01 } as any }, 'cat2'); + + // Since patience defaults to 1 and tolerance defaults to 0, early stopping should be triggered + expect(earlyStopping.earlyStop).toBe(true); + }); + + it('should not trigger early stopping when seMeasurement does not fall below the threshold', () => { + // Update the state with some measurements for cat1, where seMeasurement is above the threshold + earlyStopping.update({ cat1: { nItems: 1, seMeasurement: 0.05 } as any }, 'cat1'); + + // Early stopping should not be triggered because the seMeasurement is above the threshold + expect(earlyStopping.earlyStop).toBe(false); + }); +}); + +describe('StopIfSEMeasurementBelowThreshold with undefined seMeasurementThreshold for a category', () => { + let earlyStopping: StopIfSEMeasurementBelowThreshold; + let input: StopIfSEMeasurementBelowThresholdInput; + + beforeEach(() => { + input = { + seMeasurementThreshold: {}, // Empty object, 
meaning no thresholds are defined + patience: { cat1: 2 }, // Setting patience to 2 for cat1 + tolerance: { cat1: 0.01 }, // Small tolerance for cat1 + logicalOperation: 'only', + }; + earlyStopping = new StopIfSEMeasurementBelowThreshold(input); + }); + + it('should use a default seThreshold of 0 when seMeasurementThreshold is not defined for the category', () => { + // Update the state with measurements for cat1, ensuring to meet the patience requirement + earlyStopping.update({ cat1: { nItems: 1, seMeasurement: -0.005 } as any }, 'cat1'); + earlyStopping.update({ cat1: { nItems: 2, seMeasurement: -0.01 } as any }, 'cat1'); + + // Early stopping should now be triggered because the seMeasurement has been below the default threshold of 0 for the patience period + expect(earlyStopping.earlyStop).toBe(true); + }); +}); + +describe('StopOnSEMeasurementPlateau without tolerance provided', () => { + let earlyStopping: StopOnSEMeasurementPlateau; + let input: StopOnSEMeasurementPlateauInput; + + beforeEach(() => { + input = { + patience: { cat1: 2 }, + // No tolerance is provided, it should default to an empty object + logicalOperation: 'only', + }; + earlyStopping = new StopOnSEMeasurementPlateau(input); + }); + + it('should handle updates without triggering early stopping when no tolerance is provided', () => { + // Update with measurements for cat1 that are not exactly equal, simulating tolerance as undefined + earlyStopping.update({ cat1: { nItems: 1, seMeasurement: 0.5 } as any }, 'cat1'); + earlyStopping.update({ cat1: { nItems: 2, seMeasurement: 0.55 } as any }, 'cat1'); + + // Since tolerance is undefined, early stopping should not be triggered even if seMeasurements are slightly different + expect(earlyStopping.earlyStop).toBe(false); + }); +}); + describe.each` logicalOperation ${'and'} diff --git a/src/clowder.ts b/src/clowder.ts index f037067..19b20a9 100644 --- a/src/clowder.ts +++ b/src/clowder.ts @@ -259,7 +259,7 @@ export class Clowder { // Update the 
ability estimate for all validated cats if (catsToUpdate.includes(catToSelect)) { for (const catName of catsToUpdate) { - const itemsAndAnswersForCat = itemsAndAnswers?.filter(([stim]) => + const itemsAndAnswersForCat = itemsAndAnswers.filter(([stim]) => // We are dealing with a single item in this function. This single item // has an array of zeta parameters for a bunch of different Cats. We // need to determine if `catName` is present in that list. So we first diff --git a/src/stopping.ts b/src/stopping.ts index a66707c..36442e4 100644 --- a/src/stopping.ts +++ b/src/stopping.ts @@ -7,7 +7,7 @@ import _uniq from 'lodash/uniq'; */ export interface EarlyStoppingInput { /** The logical operation to use to evaluate multiple stopping conditions */ - logicalOperation?: 'and' | 'or' | 'AND' | 'OR'; + logicalOperation?: 'and' | 'or' | 'only' | 'AND' | 'OR' | 'ONLY'; } export interface StopAfterNItemsInput extends EarlyStoppingInput { @@ -167,7 +167,7 @@ export class StopOnSEMeasurementPlateau extends EarlyStopping { let earlyStop = false; - if (seMeasurements.length >= patience) { + if (seMeasurements?.length >= patience) { const mean = seMeasurements.slice(-patience).reduce((sum, se) => sum + se, 0) / patience; const withinTolerance = seMeasurements.slice(-patience).every((se) => Math.abs(se - mean) <= tolerance); From 74bf57702478dbb1430ab3a52c21e14901c59b63 Mon Sep 17 00:00:00 2001 From: emily-ejag Date: Fri, 18 Oct 2024 14:24:29 -0700 Subject: [PATCH 40/47] adding documentation about early stopping --- README.md | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/README.md b/README.md index 41c17f7..7d7c56e 100644 --- a/README.md +++ b/README.md @@ -42,6 +42,36 @@ const stimuli = [{difficulty: -3, item: 'item1'}, {difficulty: -2, item: 'item2 const nextItem = cat.findNextItem(stimuli, 'MFI'); ``` +## Early Stopping Criteria Combinations + +To clarify the available combinations for early stopping, here’s a breakdown of the options you can 
use: + +### 1. Logical Operations + +You can combine multiple stopping criteria using one of the following logical operations: + +- **`and`**: All conditions need to be met to trigger early stopping. +- **`or`**: Any one condition being met will trigger early stopping. +- **`only`**: Only the stopping condition for one specific cat is evaluated; you must pass that cat's name (`catToSelect`) when updating, otherwise an error is thrown. + +### 2. Stopping Criteria Classes + +There are different types of stopping criteria you can configure: + +- **`StopAfterNItems`**: Stops the process after a specified number of items. +- **`StopOnSEMeasurementPlateau`**: Stops if the standard error (SE) of measurement remains stable (within a defined tolerance) for a specified number of items. +- **`StopIfSEMeasurementBelowThreshold`**: Stops if the SE measurement drops below a set threshold. + +### How Combinations Work + +You can mix and match these criteria with different logical operations, giving you a range of configurations for early stopping. For example: + +- Using **`and`** with both `StopAfterNItems` and `StopIfSEMeasurementBelowThreshold` means stopping will only occur if both conditions are satisfied. +- Using **`or`** with `StopOnSEMeasurementPlateau` and `StopAfterNItems` allows early stopping if either condition is met. + +For worked examples of each configuration, see the tests in `src/__tests__/stopping.test.ts`. 
+ + ## Validation ### Validation of theta estimate and theta standard error From 508bc51c59c9387b7347572325dba0f480ab40ad Mon Sep 17 00:00:00 2001 From: emily-ejag Date: Wed, 23 Oct 2024 13:51:13 -0700 Subject: [PATCH 41/47] since we added only, we need to add catToSelect --- src/clowder.ts | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/clowder.ts b/src/clowder.ts index 19b20a9..b0d8ab2 100644 --- a/src/clowder.ts +++ b/src/clowder.ts @@ -287,12 +287,12 @@ export class Clowder { this.updateAbilityEstimates([catName], zetas, answers, method); } } - } - if (this._earlyStopping) { - this._earlyStopping.update(this.cats); - if (this._earlyStopping.earlyStop) { - return undefined; + if (this._earlyStopping) { + this._earlyStopping.update(this.cats, catToSelect); + if (this._earlyStopping.earlyStop) { + return undefined; + } } } From efbd55bbb7cb0717e7e7e829f763f01cff81f409 Mon Sep 17 00:00:00 2001 From: emily-ejag Date: Wed, 23 Oct 2024 14:43:23 -0700 Subject: [PATCH 42/47] solving adams comments --- src/__tests__/clowder.test.ts | 4 +-- src/__tests__/stopping.test.ts | 33 +---------------- src/clowder.ts | 66 +++++++++++++++++----------------- 3 files changed, 35 insertions(+), 68 deletions(-) diff --git a/src/__tests__/clowder.test.ts b/src/__tests__/clowder.test.ts index 6f1a576..1dcb723 100644 --- a/src/__tests__/clowder.test.ts +++ b/src/__tests__/clowder.test.ts @@ -327,8 +327,8 @@ describe('Clowder Class', () => { catToSelect: 'unvalidated', }); - expect(clowder.nItems.cat1).toBe(0); - expect(clowder.nItems.cat2).toBe(0); + expect(clowder.nItems.cat1).toBe(2); + expect(clowder.nItems.cat2).toBe(2); }); it('should not update any cats if only unvalidated items have been seen', () => { diff --git a/src/__tests__/stopping.test.ts b/src/__tests__/stopping.test.ts index 4a9d9c7..0676d33 100644 --- a/src/__tests__/stopping.test.ts +++ b/src/__tests__/stopping.test.ts @@ -1,12 +1,9 @@ import { Cat } from '..'; import { CatMap } from 
'../type'; +import { EarlyStopping, StopAfterNItems, StopIfSEMeasurementBelowThreshold, StopOnSEMeasurementPlateau } from '../'; import { - EarlyStopping, - StopAfterNItems, StopAfterNItemsInput, - StopIfSEMeasurementBelowThreshold, StopIfSEMeasurementBelowThresholdInput, - StopOnSEMeasurementPlateau, StopOnSEMeasurementPlateauInput, } from '../stopping'; import { toBeBoolean } from 'jest-extended'; @@ -469,20 +466,6 @@ describe('EarlyStopping with logicalOperation "only"', () => { earlyStopping.update({ cat1: { nItems: 1, seMeasurement: 0.5 } as any }, undefined); }).toThrowError('Must provide a cat to select for "only" stopping condition'); }); -}); - -describe('EarlyStopping with logicalOperation "only"', () => { - let earlyStopping: StopOnSEMeasurementPlateau; - let input: StopOnSEMeasurementPlateauInput; - - beforeEach(() => { - input = { - patience: { cat1: 2, cat2: 3 }, - tolerance: { cat1: 0.01, cat2: 0.02 }, - logicalOperation: 'only', - }; - earlyStopping = new StopOnSEMeasurementPlateau(input); - }); it('evaluates the stopping condition when catToSelect is in evaluationCats', () => { // Add updates to make sure cat1 is included in evaluationCats and has some measurements @@ -492,20 +475,6 @@ describe('EarlyStopping with logicalOperation "only"', () => { // Since 'cat1' is in evaluationCats, _earlyStop should be evaluated based on the stopping condition expect(earlyStopping.earlyStop).toBe(true); // Should be true because seMeasurement has plateaued }); -}); -describe('EarlyStopping with logicalOperation "only"', () => { - let earlyStopping: StopOnSEMeasurementPlateau; - let input: StopOnSEMeasurementPlateauInput; - - beforeEach(() => { - input = { - patience: { cat1: 2, cat2: 3 }, - tolerance: { cat1: 0.01, cat2: 0.02 }, - logicalOperation: 'only', - }; - earlyStopping = new StopOnSEMeasurementPlateau(input); - }); - it('sets _earlyStop to false when catToSelect is not in evaluationCats', () => { // Use 'cat3', which is not in the patience or tolerance 
maps (and thus not in evaluationCats) earlyStopping.update({ cat3: { nItems: 1, seMeasurement: 0.5 } as any }, 'cat3'); diff --git a/src/clowder.ts b/src/clowder.ts index b0d8ab2..793c33c 100644 --- a/src/clowder.ts +++ b/src/clowder.ts @@ -257,42 +257,40 @@ export class Clowder { const itemsAndAnswers = _zip(items, answers) as [Stimulus, 0 | 1][]; // Update the ability estimate for all validated cats - if (catsToUpdate.includes(catToSelect)) { - for (const catName of catsToUpdate) { - const itemsAndAnswersForCat = itemsAndAnswers.filter(([stim]) => - // We are dealing with a single item in this function. This single item - // has an array of zeta parameters for a bunch of different Cats. We - // need to determine if `catName` is present in that list. So we first - // reduce the zetas to get all of the applicabe cat names. - // Now that we have the subset of items that can apply to this cat, - // retrieve only the item parameters that apply to this cat. - stim.zetas.some((zeta: ZetaCatMap) => zeta.cats.includes(catName)), - ); - - if (itemsAndAnswersForCat.length > 0) { - const zetasAndAnswersForCat = itemsAndAnswersForCat - .map(([stim, _answer]) => { - const zetaForCat: ZetaCatMap | undefined = stim.zetas.find((zeta: ZetaCatMap) => - zeta.cats.includes(catName), - ); - // eslint-disable-next-line @typescript-eslint/no-non-null-assertion - return [zetaForCat!.zeta, _answer]; // Optional chaining in case zetaForCat is undefined - }) - .filter(([zeta]) => zeta !== undefined); // Filter out undefined zeta values - - // Unzip the zetas and answers, making sure the zetas array contains only Zeta types - const [zetas, answers] = _unzip(zetasAndAnswersForCat) as [Zeta[], (0 | 1)[]]; - - // Now call updateAbilityEstimates for this cat - this.updateAbilityEstimates([catName], zetas, answers, method); - } + for (const catName of catsToUpdate) { + const itemsAndAnswersForCat = itemsAndAnswers.filter(([stim]) => + // We are dealing with a single item in this function. 
This single item + // has an array of zeta parameters for a bunch of different Cats. We + // need to determine if `catName` is present in that list. So we first + // reduce the zetas to get all of the applicable cat names. + // Now that we have the subset of items that can apply to this cat, + // retrieve only the item parameters that apply to this cat. + stim.zetas.some((zeta: ZetaCatMap) => zeta.cats.includes(catName)), + ); + + if (itemsAndAnswersForCat.length > 0) { + const zetasAndAnswersForCat = itemsAndAnswersForCat + .map(([stim, _answer]) => { + const zetaForCat: ZetaCatMap | undefined = stim.zetas.find((zeta: ZetaCatMap) => + zeta.cats.includes(catName), + ); + // eslint-disable-next-line @typescript-eslint/no-non-null-assertion + return [zetaForCat!.zeta, _answer]; // Non-null assertion is safe: the filter above kept only items with a zeta for this cat + }) + .filter(([zeta]) => zeta !== undefined); // Filter out undefined zeta values + + // Unzip the zetas and answers, making sure the zetas array contains only Zeta types + const [zetas, answers] = _unzip(zetasAndAnswersForCat) as [Zeta[], (0 | 1)[]]; + + // Now call updateAbilityEstimates for this cat + this.updateAbilityEstimates([catName], zetas, answers, method); } + } - if (this._earlyStopping) { - this._earlyStopping.update(this.cats, catToSelect); - if (this._earlyStopping.earlyStop) { - return undefined; - } + if (this._earlyStopping) { + this._earlyStopping.update(this.cats, catToSelect); + if (this._earlyStopping.earlyStop) { + return undefined; } } From c26f57c8590e7e3e9e2873b1e14d1dce18f1069d Mon Sep 17 00:00:00 2001 From: emily-ejag Date: Wed, 23 Oct 2024 14:51:02 -0700 Subject: [PATCH 43/47] deleting for loop --- src/stopping.ts | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/src/stopping.ts b/src/stopping.ts index 36442e4..937a01f 100644 --- a/src/stopping.ts +++ b/src/stopping.ts @@ -99,17 +99,6 @@ export abstract class EarlyStopping { public update(cats: CatMap, catToSelect?: string): void { 
this._updateCats(cats); // This updates internal state with current cat data - // Iterate over each cat and update the _nItems map - for (const catName in cats) { - const cat = cats[catName]; - const nItems = cat.nItems; // Get the current number of items for this cat - - // Update the _nItems map with the current nItems value - if (nItems !== undefined) { - this._nItems[catName] = nItems; // Make sure nItems is set for this cat - } - } - // Collect the stopping conditions for all cats const conditions: boolean[] = this.evaluationCats.map((catName) => this._evaluateStoppingCondition(catName)); From 7abca472eece52f9ae1362e68ccc68f19bb38e41 Mon Sep 17 00:00:00 2001 From: emily-ejag Date: Tue, 29 Oct 2024 14:02:30 -0700 Subject: [PATCH 44/47] adding stopping reason --- src/__tests__/clowder.test.ts | 1 + src/clowder.ts | 16 +++++++++++++++- 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/src/__tests__/clowder.test.ts b/src/__tests__/clowder.test.ts index 1dcb723..b01c1bf 100644 --- a/src/__tests__/clowder.test.ts +++ b/src/__tests__/clowder.test.ts @@ -536,6 +536,7 @@ describe('Clowder Early Stopping', () => { expect(clowder.earlyStopping?.earlyStop).toBe(true); // Early stop should be triggered after 2 items expect(nextItem).toBe(undefined); // No further items should be selected + expect(clowder.stoppingReason).toBe('Early stopping'); }); it('should handle StopIfSEMeasurementBelowThreshold condition', () => { diff --git a/src/clowder.ts b/src/clowder.ts index 793c33c..57a6d22 100644 --- a/src/clowder.ts +++ b/src/clowder.ts @@ -46,6 +46,7 @@ export class Clowder { private _seenItems: Stimulus[]; private _earlyStopping?: EarlyStopping; private readonly _rng: ReturnType; + private _stoppingReason: string | null; /** * Create a Clowder object. @@ -70,6 +71,7 @@ export class Clowder { this._remainingItems = _cloneDeep(corpus); this._rng = randomSeed === null ? 
seedrandom() : seedrandom(randomSeed); this._earlyStopping = earlyStopping; + this._stoppingReason = null; } /** @@ -151,10 +153,20 @@ export class Clowder { return _mapValues(this.cats, (cat) => cat.zetas); } + /** + * The early stopping condition in the Clowder configuration. + */ public get earlyStopping() { return this._earlyStopping; } + /** + * The stopping reason in the Clowder configuration. + */ + public get stoppingReason() { + return this._stoppingReason; + } + /** * Updates the ability estimates for the specified Cat instances. * @@ -290,6 +302,7 @@ export class Clowder { if (this._earlyStopping) { this._earlyStopping.update(this.cats, catToSelect); if (this._earlyStopping.earlyStop) { + this._stoppingReason = 'Early stopping'; return undefined; } } @@ -314,7 +327,7 @@ export class Clowder { const randInt = Math.floor(this._rng() * missing.length); return missing[randInt]; } - + this._stoppingReason = 'No unvalidated items remaining'; return undefined; } else { const randInt = Math.floor(this._rng() * unvalidatedRemainingItems.length); @@ -357,6 +370,7 @@ export class Clowder { if (available.length === 0) { // If returnUndefinedOnExhaustion is true and no validated items remain for the specified catToSelect, return undefined. if (returnUndefinedOnExhaustion) { + this._stoppingReason = 'No validated items remaining for specified catToSelect'; return undefined; // Return undefined if no validated items remain } else { // If returnUndefinedOnExhaustion is false, proceed with the fallback mechanism to select an item from other available categories. 
From fb5886c49e617661cab071df7ac93a903c778d41 Mon Sep 17 00:00:00 2001 From: emily-ejag Date: Tue, 29 Oct 2024 14:09:21 -0700 Subject: [PATCH 45/47] adding more stoppingReasons to the tests --- src/__tests__/clowder.test.ts | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/__tests__/clowder.test.ts b/src/__tests__/clowder.test.ts index b01c1bf..998499d 100644 --- a/src/__tests__/clowder.test.ts +++ b/src/__tests__/clowder.test.ts @@ -121,7 +121,7 @@ describe('Clowder Class', () => { catToSelect: 'cat1', returnUndefinedOnExhaustion: true, }); - + expect(clowder.stoppingReason).toBe('No validated items remaining for specified catToSelect'); expect(nextItem).toBeUndefined(); }); @@ -377,7 +377,7 @@ describe('Clowder Class', () => { answers: [1, 1], catToSelect: 'unvalidated', }); - + expect(clowder.stoppingReason).toBe('No unvalidated items remaining'); expect(nextItem).toBeUndefined(); }); @@ -483,6 +483,7 @@ describe('Clowder Class', () => { }); expect(clowder.earlyStopping?.earlyStop).toBe(true); // Should stop after 2 items + expect(clowder.stoppingReason).toBe('Early stopping'); expect(nextItem).toBe(undefined); // Expect undefined after early stopping }); }); @@ -576,6 +577,7 @@ describe('Clowder Early Stopping', () => { if (item.id === '3') { expect(clowder.earlyStopping?.earlyStop).toBe(true); // Should stop after SE drops below threshold + expect(clowder.stoppingReason).toBe('Early stopping'); expect(nextItem).toBe(undefined); // No further items should be selected } else { expect(clowder.earlyStopping?.earlyStop).toBe(false); From 2f4285be410f54368930b5db985ffdeb961b87ca Mon Sep 17 00:00:00 2001 From: AnyaWMa <76414800+AnyaWMa@users.noreply.github.com> Date: Wed, 30 Oct 2024 09:46:25 -0700 Subject: [PATCH 46/47] Update README.md --- README.md | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 7d7c56e..7c59254 100644 --- a/README.md +++ b/README.md @@ -26,6 +26,14 @@ const 
currentPrior = normal(); // create a Cat object const cat = new CAT({method: 'MLE', itemSelect: 'MFI', nStartItems: 0, theta: 0, minTheta: -6, maxTheta: 6, prior: currentPrior}) +// option 1 to input stimuli: +const zeta = [{discrimination: 1, difficulty: 0, guessing: 0, slipping: 1}, {discrimination: 1, difficulty: 0.5, guessing: 0, slipping: 1}] + +// option 2 to input stimuli: +const zeta = [{a: 1, b: 0, c: 0, d: 1}, {a: 1, b: 0.5, c: 0, d: 1}] + +const answer = [1, 0] + // update the abilitiy estimate by adding test items cat.updateAbilityEstimate(zeta, answer); @@ -37,7 +45,7 @@ const numItems = cat.nItems; // find the next available item from an input array of stimuli based on a selection method -const stimuli = [{difficulty: -3, item: 'item1'}, {difficulty: -2, item: 'item2'}]; +const stimuli = [{ discrimination: 1, difficulty: -2, guessing: 0, slipping: 1, item: "item1" },{ discrimination: 1, difficulty: 3, guessing: 0, slipping: 1, item: "item2" }]; const nextItem = cat.findNextItem(stimuli, 'MFI'); ``` From e5a352e8e4afba356c2c3e5405fd011c89e488f5 Mon Sep 17 00:00:00 2001 From: emily-ejag Date: Tue, 12 Nov 2024 16:10:04 -0800 Subject: [PATCH 47/47] filterin NA from overall corpus --- src/__tests__/corpus.test.ts | 2 +- src/corpus.ts | 5 ++++- src/index.ts | 1 + 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/src/__tests__/corpus.test.ts b/src/__tests__/corpus.test.ts index 8d5bd16..d9d79d8 100644 --- a/src/__tests__/corpus.test.ts +++ b/src/__tests__/corpus.test.ts @@ -7,8 +7,8 @@ import { convertZeta, checkNoDuplicateCatNames, filterItemsByCatParameterAvailability, - prepareClowderCorpus, } from '../corpus'; +import { prepareClowderCorpus } from '..'; import _omit from 'lodash/omit'; describe('validateZetaParams', () => { diff --git a/src/corpus.ts b/src/corpus.ts index 687fc32..94102f5 100644 --- a/src/corpus.ts +++ b/src/corpus.ts @@ -287,7 +287,10 @@ export const prepareClowderCorpus = ( zeta: convertZeta(zeta, 
itemParameterFormat), }; }) - .filter((zeta) => !_isEmpty(zeta)); // filter empty values + .filter((zeta) => { + // Check if zeta has no `NA` values and is not empty + return !_isEmpty(zeta.zeta) && Object.values(zeta.zeta).every((value) => value !== 'NA'); + }); // Create the MultiZetaStimulus structure without the category keys const cleanItem = _omit( diff --git a/src/index.ts b/src/index.ts index 4529537..2571bd7 100644 --- a/src/index.ts +++ b/src/index.ts @@ -1,5 +1,6 @@ export { Cat, CatInput } from './cat'; export { Clowder, ClowderInput } from './clowder'; +export { prepareClowderCorpus } from './corpus'; export { EarlyStopping, StopAfterNItems,