diff --git a/src/unary.js b/src/unary.js index c6976ea..54e4172 100644 --- a/src/unary.js +++ b/src/unary.js @@ -30,3 +30,20 @@ export const tan = (input) => unary(input, Math.tan); export const copy = (input) => unary(input, (x) => x); export const reciprocal = (input) => unary(input, (x) => 1 / x); export const sqrt = (input) => unary(input, Math.sqrt); +export const erf = (input) => unary(input, (x) => { + // reference 1: https://en.wikipedia.org/wiki/Error_function + // reference 2: https://github.com/tensorflow/tfjs/blob/master/tfjs-backend-cpu/src/kernels/Erf.ts + const a1 = 0.254829592; + const a2 = -0.284496736; + const a3 = 1.421413741; + const a4 =-1.453152027; + const a5 = 1.061405429; + const p = 0.3275911; + const sign = Math.sign(x); + const v = Math.abs(x); + const t = 1.0 / (1.0 + p * v); + return sign * + (1.0 - + (((((a5 * t + a4) * t) + a3) * t + a2) * t + a1) * t * + Math.exp(-v * v)); +}); diff --git a/test/unary_test.js b/test/unary_test.js index bef9d4e..596905f 100644 --- a/test/unary_test.js +++ b/test/unary_test.js @@ -141,6 +141,33 @@ describe('test unary', function() { [3, 2, 2, 1]); }); + it('erf', function() { + // erf 1D + testUnary( + 'erf', + [-0.25, 0.25, 0.5, 0.75, -0.4], + [ + -0.2763262613535272, + 0.2763262613535272, + 0.5205000163047472, + 0.7111555696366565, + -0.42839242346728446, + ], + [5]); + // erf 2D + testUnary( + 'erf', + [ + 0.2, 0.3, + 0.4, 0.5, + ], + [ + 0.22270245785831588, 0.32862668272901174, + 0.42839242346728446, 0.5205000163047472, + ], + [2, 2]); + }); + it('exp', function() { testUnary('exp', [-1, 0, 1], [0.36787944117144233, 1, 2.718281828459045], [3]); testUnary(