Skip to content

Commit

Permalink
add more methods
Browse files Browse the repository at this point in the history
  • Loading branch information
AnyaWMa committed Jul 20, 2022
1 parent a4b3481 commit 8c76113
Show file tree
Hide file tree
Showing 4 changed files with 100 additions and 23 deletions.
20 changes: 18 additions & 2 deletions package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,11 @@
},
"homepage": "https://github.com/yeatmanlab/jsCAT#readme",
"devDependencies": {
"@types/lodash": "^4.14.182",
"typescript": "^4.7.4"
},
"dependencies": {
"lodash": "^4.17.21",
"optimization-js": "^1.5.0"
}
}
99 changes: 79 additions & 20 deletions src/index.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
import {minimize_Powell} from "optimization-js";
import { cloneDeep } from "lodash";

export type Zeta = { a: number, b: number, c: number, d: number };

export interface Stimulus {
difficulty: number;
[key: string]: any;
}

export const fisherInformation = (theta: number, zeta: Zeta) => {
const p = itemResponseFunction(zeta, theta)
const q = 1-p
Expand All @@ -12,7 +18,7 @@ export const itemResponseFunction = (zeta: Zeta, theta: number) => {
return zeta.c + (zeta.d - zeta.c) / (1 + Math.exp(-zeta.a * (theta - zeta.b)));
}

export const findClosest = (arr: Array<object>, target: number) => {
export const findClosest = (arr: Array<Stimulus>, target: number) => {
let n = arr.length;
// Corner cases
if (target <= arr[0].difficulty)
Expand Down Expand Up @@ -46,25 +52,16 @@ export const findClosest = (arr: Array<object>, target: number) => {
return mid;
}

export const getClosest = (arr: Array<object>, val1:number, val2: number, target: number) => {
export const getClosest = (arr: Array<Stimulus>, val1:number, val2: number, target: number) => {
if (target - arr[val1].difficulty >= arr[val2].difficulty - target)
return val2;
else
return val1;
}

export const estimateAbilityMLE = (answers: Array<number>, zetas: Array<Zeta>, min_theta: number, max_theta: number) => {
let max_like = -Infinity;
let theta0 = [0];

const solution = minimize_Powell(negLikelihood, theta0)
// for (let theta = min_theta; theta <= max_theta; theta += learning_rate) {
// let like = likelihood(theta);
// if (like > max_like){
// max_like = like;
// res_theta = theta;
// }
// }
const theta0 = [0];
const solution = minimize_Powell(negLikelihood, theta0);

let theta = solution.argument[0];
if (theta > max_theta) {
Expand All @@ -76,24 +73,24 @@ export const estimateAbilityMLE = (answers: Array<number>, zetas: Array<Zeta>, m

function likelihood(theta: number) {
return zetas.reduce((acc: number, zeta: Zeta, i: number) => {
let irf = itemResponseFunction(theta, zeta);
let irf = itemResponseFunction(zeta, theta);
return answers[i] === 1 ? acc + Math.log(irf) : acc + Math.log(1 - irf);
}, 0);
}

function negLikelihood(thetaArray) {
function negLikelihood(thetaArray: Array<number>) {
return -likelihood(thetaArray[0])
}
}

export const normal = (mean: number, stdDev: number) => {
let distr = [];
let distribution = [];
for (let i = -4; i <= 4; i += 0.1) {
distr.push([i, y(i)]);
distribution.push([i, y(i)]);
}
return distr;
return distribution;

function y(x) {
function y(x: number) {
return (
(1 / (Math.sqrt(2 * Math.PI) * stdDev)) *
Math.exp(-Math.pow(x - mean, 2) / (2 * Math.pow(stdDev, 2)))
Expand All @@ -115,8 +112,70 @@ export const estimateAbilityEAP = (answers: Array<number>, zetas: Array<Zeta>)
return num / nf;
function likelihood(theta: number) {
return zetas.reduce((acc, zeta, i) => {
let irf = itemResponseFunction(theta, zeta);
let irf = itemResponseFunction(zeta, theta);
return answers[i] === 1 ? acc * irf : acc * (1 - irf);
}, 1);
}
}
/**
* return a random integer between min and max
* @param min - The minimum of the random number range (include)
* @param max - The maximum of the random number range (include)
* @returns {number} - random integer within the range
*/
export const randomInteger = (min: number, max: number) => {
return Math.floor(Math.random() * (max - min + 1)) + min;
}

/**
* find the next available item from an input array of stimuli based on a selection method
* @param stimuli {Array<Stimulus>} - an array of stimulus
* @param theta {number} - the theta estimate, default theta = 0
* @param method {string} - the method of item selection, e.g. "MI", "random", "closest", default method = 'MI'
* @param deepCopy {boolean} - default deepCopy = true
* @returns {nextStimulus: Stimulus,
remainingStimuli: Array<Stimulus> }
*/
export const findNextItem = (stimuli: Array<Stimulus>, theta = 0, method = 'MI', deepCopy = true) => {
method = method.toLowerCase();
const validMethod: Array<string> = ['mi', 'random', 'closest'];
if (!validMethod.includes(method)){
throw new Error('The method you provided is not in the list of valid methods');
}
let arr: Array<Stimulus>;
if (deepCopy) {
arr = cloneDeep(stimuli);
} else {
arr = stimuli;
}
arr.sort((a, b) => a.difficulty - b.difficulty);

method = method.toLowerCase();
if (method === 'mi'){
const stimuliAddFisher = arr.map((element) => ({fisherInformation: fisherInformation(theta,
{a: 1, b: element.difficulty, c: 0.5, d: 1}), ...element}));
stimuliAddFisher.sort((a,b) => b.fisherInformation - a.fisherInformation);
return {
nextStimulus: stimuliAddFisher[0],
remainingStimuli: stimuliAddFisher.slice(1)
};
} else if (method == 'random') {
let index: number;
if (arr.length < 5){
index = Math.floor(arr.length / 2);
} else {
index = Math.floor(arr.length / 2) + randomInteger(-2, 2);
}
return {
nextStimulus: arr[index],
remainingStimuli: arr.splice(index, 1)
};
} else if (method == 'closest') {
//findClosest requires arr is sorted by difficulty
const index = findClosest(arr, theta + 0.481);
return {
nextStimulus: arr[index],
remainingStimuli: arr.splice(index, 1)
};
}
}
2 changes: 1 addition & 1 deletion tsconfig.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"compilerOptions": {
"target": "es5",
"target": "es2016",
"module": "commonjs",
"declaration": true,
"outDir": "./lib",
Expand Down

0 comments on commit 8c76113

Please sign in to comment.