diff --git a/src/function/arithmetic/multiply.js b/src/function/arithmetic/multiply.js index 4ea04a1da7..801b9e23f3 100644 --- a/src/function/arithmetic/multiply.js +++ b/src/function/arithmetic/multiply.js @@ -11,10 +11,11 @@ const dependencies = [ 'matrix', 'addScalar', 'multiplyScalar', - 'equalScalar' + 'equalScalar', + 'dot' ] -export const createMultiply = /* #__PURE__ */ factory(name, dependencies, ({ typed, matrix, addScalar, multiplyScalar, equalScalar }) => { +export const createMultiply = /* #__PURE__ */ factory(name, dependencies, ({ typed, matrix, addScalar, multiplyScalar, equalScalar, dot }) => { const algorithm11 = createAlgorithm11({ typed, equalScalar }) const algorithm14 = createAlgorithm14({ typed }) @@ -201,38 +202,7 @@ export const createMultiply = /* #__PURE__ */ factory(name, dependencies, ({ typ function _multiplyVectorVector (a, b, n) { // check empty vector if (n === 0) { throw new Error('Cannot multiply two empty vectors') } - - // a dense - const adata = a._data - const adt = a._datatype - // b dense - const bdata = b._data - const bdt = b._datatype - - // datatype - let dt - // addScalar signature to use - let af = addScalar - // multiplyScalar signature to use - let mf = multiplyScalar - - // process data types - if (adt && bdt && adt === bdt && typeof adt === 'string') { - // datatype - dt = adt - // find signatures that matches (dt, dt) - af = typed.find(addScalar, [dt, dt]) - mf = typed.find(multiplyScalar, [dt, dt]) - } - - // result (do not initialize it with zero) - let c = mf(adata[0], bdata[0]) - // loop data - for (let i = 1; i < n; i++) { - // multiply and accumulate - c = af(c, mf(adata[i], bdata[i])) - } - return c + return dot(a, b) } /** diff --git a/src/function/matrix/dot.js b/src/function/matrix/dot.js index 7e7a7daf98..a03f39c7c8 100644 --- a/src/function/matrix/dot.js +++ b/src/function/matrix/dot.js @@ -1,15 +1,15 @@ -import { arraySize as size } from '../../utils/array' import { factory } from '../../utils/factory' +import { isMatrix } from '../../utils/is' const name = 'dot' -const dependencies = ['typed', 'add', 'multiply'] +const dependencies = ['typed', 'addScalar', 'multiplyScalar', 'conj', 'size'] -export const createDot = /* #__PURE__ */ factory(name, dependencies, ({ typed, add, multiply }) => { +export const createDot = /* #__PURE__ */ factory(name, dependencies, ({ typed, addScalar, multiplyScalar, conj, size }) => { /** * Calculate the dot product of two vectors. The dot product of - * `A = [a1, a2, a3, ..., an]` and `B = [b1, b2, b3, ..., bn]` is defined as: + * `A = [a1, a2, ..., an]` and `B = [b1, b2, ..., bn]` is defined as: * - * dot(A, B) = a1 * b1 + a2 * b2 + a3 * b3 + ... + an * bn + * dot(A, B) = conj(a1) * b1 + conj(a2) * b2 + ... + conj(an) * bn * * Syntax: * @@ -29,43 +29,138 @@ export const createDot = /* #__PURE__ */ factory(name, dependencies, ({ typed, a * @return {number} Returns the dot product of `x` and `y` */ return typed(name, { - 'Matrix, Matrix': function (x, y) { - return _dot(x.toArray(), y.toArray()) - }, + 'Array | DenseMatrix, Array | DenseMatrix': _denseDot, + 'SparseMatrix, SparseMatrix': _sparseDot + }) - 'Matrix, Array': function (x, y) { - return _dot(x.toArray(), y) - }, + function _validateDim (x, y) { + const xSize = _size(x) + const ySize = _size(y) + let xLen, yLen - 'Array, Matrix': function (x, y) { - return _dot(x, y.toArray()) - }, + if (xSize.length === 1) { + xLen = xSize[0] + } else if (xSize.length === 2 && xSize[1] === 1) { + xLen = xSize[0] + } else { + throw new RangeError('Expected a column vector, instead got a matrix of size (' + xSize.join(', ') + ')') + } - 'Array, Array': _dot - }) + if (ySize.length === 1) { + yLen = ySize[0] + } else if (ySize.length === 2 && ySize[1] === 1) { + yLen = ySize[0] + } else { + throw new RangeError('Expected a column vector, instead got a matrix of size (' + ySize.join(', ') + ')') + } - /** - * Calculate the dot product for two arrays - * @param {Array} x First vector - * @param {Array} y Second vector - * @returns {number} Returns the dot product of x and y - * @private - */ - // TODO: double code with math.multiply - function _dot (x, y) { - const xSize = size(x) - const ySize = size(y) - const len = xSize[0] - - if (xSize.length !== 1 || ySize.length !== 1) throw new RangeError('Vector expected') // TODO: better error message - if (xSize[0] !== ySize[0]) throw new RangeError('Vectors must have equal length (' + xSize[0] + ' != ' + ySize[0] + ')') - if (len === 0) throw new RangeError('Cannot calculate the dot product of empty vectors') - - let prod = 0 - for (let i = 0; i < len; i++) { - prod = add(prod, multiply(x[i], y[i])) + if (xLen !== yLen) throw new RangeError('Vectors must have equal length (' + xLen + ' != ' + yLen + ')') + if (xLen === 0) throw new RangeError('Cannot calculate the dot product of empty vectors') + + return xLen + } + + function _denseDot (a, b) { + const N = _validateDim(a, b) + + const adata = isMatrix(a) ? a._data : a + const adt = isMatrix(a) ? a._datatype : undefined + + const bdata = isMatrix(b) ? b._data : b + const bdt = isMatrix(b) ? b._datatype : undefined + + // are these 2-dimensional column vectors? (as opposed to 1-dimensional vectors) + const aIsColumn = _size(a).length === 2 + const bIsColumn = _size(b).length === 2 + + let add = addScalar + let mul = multiplyScalar + + // process data types + if (adt && bdt && adt === bdt && typeof adt === 'string') { + const dt = adt + // find signatures that matches (dt, dt) + add = typed.find(addScalar, [dt, dt]) + mul = typed.find(multiplyScalar, [dt, dt]) + } + + // both vectors 1-dimensional + if (!aIsColumn && !bIsColumn) { + let c = mul(conj(adata[0]), bdata[0]) + for (let i = 1; i < N; i++) { + c = add(c, mul(conj(adata[i]), bdata[i])) + } + return c } - return prod + // a is 1-dim, b is column + if (!aIsColumn && bIsColumn) { + let c = mul(conj(adata[0]), bdata[0][0]) + for (let i = 1; i < N; i++) { + c = add(c, mul(conj(adata[i]), bdata[i][0])) + } + return c + } + + // a is column, b is 1-dim + if (aIsColumn && !bIsColumn) { + let c = mul(conj(adata[0][0]), bdata[0]) + for (let i = 1; i < N; i++) { + c = add(c, mul(conj(adata[i][0]), bdata[i])) + } + return c + } + + // both vectors are column + if (aIsColumn && bIsColumn) { + let c = mul(conj(adata[0][0]), bdata[0][0]) + for (let i = 1; i < N; i++) { + c = add(c, mul(conj(adata[i][0]), bdata[i][0])) + } + return c + } + } + + function _sparseDot (x, y) { + _validateDim(x, y) + + const xindex = x._index + const xvalues = x._values + + const yindex = y._index + const yvalues = y._values + + // TODO optimize add & mul using datatype + let c = 0 + const add = addScalar + const mul = multiplyScalar + + let i = 0 + let j = 0 + while (i < xindex.length && j < yindex.length) { + const I = xindex[i] + const J = yindex[j] + + if (I < J) { + i++ + continue + } + if (I > J) { + j++ + continue + } + if (I === J) { + c = add(c, mul(xvalues[i], yvalues[j])) + i++ + j++ + } + } + + return c + } + + // TODO remove this once #1771 is fixed + function _size (x) { + return isMatrix(x) ? x.size() : size(x) } }) diff --git a/test/unit-tests/function/arithmetic/multiply.test.js b/test/unit-tests/function/arithmetic/multiply.test.js index 7a12087c44..3c3d3b699e 100644 --- a/test/unit-tests/function/arithmetic/multiply.test.js +++ b/test/unit-tests/function/arithmetic/multiply.test.js @@ -285,6 +285,14 @@ describe('multiply', function () { approx.deepEqual(multiply(matrix(a), matrix(b)), 32) }) + it('should conjugate the first argument in dot product', function () { + const a = [complex(1, 2), complex(3, 4)] + const b = [complex(5, 6), complex(7, 8)] + + approx.deepEqual(multiply(a, b), complex(70, -8)) + approx.deepEqual(multiply(matrix(a), matrix(b)), complex(70, -8)) + }) + it('should multiply row vector x column vector', function () { const v = [[1, 2, 3, 0, 0, 5, 6]] diff --git a/test/unit-tests/function/matrix/dot.test.js b/test/unit-tests/function/matrix/dot.test.js index e2c7d6704b..a3de75795c 100644 --- a/test/unit-tests/function/matrix/dot.test.js +++ b/test/unit-tests/function/matrix/dot.test.js @@ -1,40 +1,93 @@ import assert from 'assert' import math from '../../../../src/bundleAny' +const dot = math.dot +const matrix = math.matrix +const sparse = math.sparse +const complex = math.complex + describe('dot', function () { - it('should calculate dot product for two arrays', function () { - assert.strictEqual(math.dot([2, 4, 1], [2, 2, 3]), 15) - assert.strictEqual(math.dot([7, 3], [2, 4]), 26) + it('should calculate dot product for two 1-dim arrays', function () { + assert.strictEqual(dot([2, 4, 1], [2, 2, 3]), 15) + assert.strictEqual(dot([7, 3], [2, 4]), 26) + }) + + it('should calculate dot product for two column arrays', function () { + assert.strictEqual(dot([[2], [4], [1]], [[2], [2], [3]]), 15) + assert.strictEqual(dot([[7], [3]], [[2], [4]]), 26) + }) + + it('should calculate dot product for two 1-dim vectors', function () { + assert.strictEqual(dot(matrix([2, 4, 1]), matrix([2, 2, 3])), 15) + assert.strictEqual(dot(matrix([7, 3]), matrix([2, 4])), 26) + }) + + it('should calculate dot product for two column vectors', function () { + assert.strictEqual(dot(matrix([[2], [4], [1]]), matrix([[2], [2], [3]])), 15) + assert.strictEqual(dot(matrix([[7], [3]]), matrix([[2], [4]])), 26) + }) + + it('should calculate dot product for mixed 1-dim arrays and column arrays', function () { + assert.strictEqual(dot([2, 4, 1], [[2], [2], [3]]), 15) + assert.strictEqual(dot([[7], [3]], [2, 4]), 26) + }) + + it('should calculate dot product for mixed 1-dim arrays and 1-dim vectors', function () { + assert.strictEqual(dot([2, 4, 1], matrix([2, 2, 3])), 15) + assert.strictEqual(dot(matrix([7, 3]), [2, 4]), 26) + }) + + it('should calculate dot product for mixed 1-dim arrays and column vectors', function () { + assert.strictEqual(dot([2, 4, 1], matrix([[2], [2], [3]])), 15) + assert.strictEqual(dot(matrix([[7], [3]]), [2, 4]), 26) + }) + + it('should calculate dot product for mixed column arrays and 1-dim vectors', function () { + assert.strictEqual(dot([[2], [4], [1]], matrix([2, 2, 3])), 15) + assert.strictEqual(dot(matrix([7, 3]), [[2], [4]]), 26) }) - it('should calculate dot product for two matrices', function () { - assert.strictEqual(math.dot(math.matrix([2, 4, 1]), math.matrix([2, 2, 3])), 15) - assert.strictEqual(math.dot(math.matrix([7, 3]), math.matrix([2, 4])), 26) + it('should calculate dot product for mixed column arrays and column vectors', function () { + assert.strictEqual(dot([[2], [4], [1]], matrix([[2], [2], [3]])), 15) + assert.strictEqual(dot(matrix([[7], [3]]), [[2], [4]]), 26) }) - it('should calculate dot product for mixed arrays and matrices', function () { - assert.strictEqual(math.dot([2, 4, 1], math.matrix([2, 2, 3])), 15) - assert.strictEqual(math.dot(math.matrix([7, 3]), [2, 4]), 26) + it('should calculate dot product for mixed 1-dim vectors and column vectors', function () { + assert.strictEqual(dot(matrix([2, 4, 1]), matrix([[2], [2], [3]])), 15) + assert.strictEqual(dot(matrix([[7], [3]]), matrix([2, 4])), 26) + }) + + it('should calculate dot product for sparse vectors', function () { + assert.strictEqual(dot(sparse([0, 0, 2, 4, 4, 1]), sparse([1, 0, 2, 2, 0, 3])), 15) + assert.strictEqual(dot(sparse([7, 1, 2, 3]), sparse([2, 0, 0, 4])), 26) }) it('should throw an error for unsupported types of arguments', function () { - assert.throws(function () { math.dot([2, 4, 1], 2) }, TypeError) + assert.throws(function () { dot([2, 4, 1], 2) }, TypeError) }) it('should throw an error for multi dimensional matrix input', function () { - assert.throws(function () { math.dot([[1, 2], [3, 4]], [[1, 2], [3, 4]]) }, /Vector expected/) + assert.throws(function () { dot([[1, 2], [3, 4]], [[1, 2], [3, 4]]) }, /Expected a column vector, instead got a matrix of size \(2, 2\)/) }) it('should throw an error in case of vectors with unequal length', function () { - assert.throws(function () { math.dot([2, 3], [1, 2, 3]) }, /Vectors must have equal length \(2 != 3\)/) + assert.throws(function () { dot([2, 3], [1, 2, 3]) }, /Vectors must have equal length \(2 != 3\)/) }) it('should throw an error in case of empty vectors', function () { - assert.throws(function () { math.dot([], []) }, /Cannot calculate the dot product of empty vectors/) + assert.throws(function () { dot([], []) }, /Cannot calculate the dot product of empty vectors/) }) it('should LaTeX dot', function () { const expression = math.parse('dot([1,2],[3,4])') assert.strictEqual(expression.toTex(), '\\left(\\begin{bmatrix}1\\\\2\\\\\\end{bmatrix}\\cdot\\begin{bmatrix}3\\\\4\\\\\\end{bmatrix}\\right)') }) + + it('should be antilinear in the first argument', function () { + const I = complex(0, 1) + assert.deepStrictEqual(dot([I, 2], [1, I]), I) + + const v = matrix([2, I, 1]) + assert.deepStrictEqual(dot(v, v).sqrt(), complex(math.norm(v))) + }) })