Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve dot product #1773

Merged
merged 10 commits into from
Mar 29, 2020
38 changes: 4 additions & 34 deletions src/function/arithmetic/multiply.js
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,11 @@ const dependencies = [
'matrix',
'addScalar',
'multiplyScalar',
'equalScalar'
'equalScalar',
'dot'
]

export const createMultiply = /* #__PURE__ */ factory(name, dependencies, ({ typed, matrix, addScalar, multiplyScalar, equalScalar }) => {
export const createMultiply = /* #__PURE__ */ factory(name, dependencies, ({ typed, matrix, addScalar, multiplyScalar, equalScalar, dot }) => {
const algorithm11 = createAlgorithm11({ typed, equalScalar })
const algorithm14 = createAlgorithm14({ typed })

Expand Down Expand Up @@ -201,38 +202,7 @@ export const createMultiply = /* #__PURE__ */ factory(name, dependencies, ({ typ
function _multiplyVectorVector (a, b, n) {
// check empty vector
if (n === 0) { throw new Error('Cannot multiply two empty vectors') }

// a dense
const adata = a._data
const adt = a._datatype
// b dense
const bdata = b._data
const bdt = b._datatype

// datatype
let dt
// addScalar signature to use
let af = addScalar
// multiplyScalar signature to use
let mf = multiplyScalar

// process data types
if (adt && bdt && adt === bdt && typeof adt === 'string') {
// datatype
dt = adt
// find signatures that matches (dt, dt)
af = typed.find(addScalar, [dt, dt])
mf = typed.find(multiplyScalar, [dt, dt])
}

// result (do not initialize it with zero)
let c = mf(adata[0], bdata[0])
// loop data
for (let i = 1; i < n; i++) {
// multiply and accumulate
c = af(c, mf(adata[i], bdata[i]))
}
return c
return dot(a, b)
}

/**
Expand Down
169 changes: 132 additions & 37 deletions src/function/matrix/dot.js
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
import { arraySize as size } from '../../utils/array'
import { factory } from '../../utils/factory'
import { isMatrix } from '../../utils/is'

const name = 'dot'
const dependencies = ['typed', 'add', 'multiply']
const dependencies = ['typed', 'addScalar', 'multiplyScalar', 'conj', 'size']

export const createDot = /* #__PURE__ */ factory(name, dependencies, ({ typed, add, multiply }) => {
export const createDot = /* #__PURE__ */ factory(name, dependencies, ({ typed, addScalar, multiplyScalar, conj, size }) => {
/**
* Calculate the dot product of two vectors. The dot product of
* `A = [a1, a2, a3, ..., an]` and `B = [b1, b2, b3, ..., bn]` is defined as:
* `A = [a1, a2, ..., an]` and `B = [b1, b2, ..., bn]` is defined as:
*
* dot(A, B) = a1 * b1 + a2 * b2 + a3 * b3 + ... + an * bn
* dot(A, B) = conj(a1) * b1 + conj(a2) * b2 + ... + conj(an) * bn
*
* Syntax:
*
Expand All @@ -29,43 +29,138 @@ export const createDot = /* #__PURE__ */ factory(name, dependencies, ({ typed, a
* @return {number} Returns the dot product of `x` and `y`
*/
return typed(name, {
'Matrix, Matrix': function (x, y) {
return _dot(x.toArray(), y.toArray())
},
'Array | DenseMatrix, Array | DenseMatrix': _denseDot,
'SparseMatrix, SparseMatrix': _sparseDot
})

'Matrix, Array': function (x, y) {
return _dot(x.toArray(), y)
},
function _validateDim (x, y) {
const xSize = _size(x)
const ySize = _size(y)
let xLen, yLen

'Array, Matrix': function (x, y) {
return _dot(x, y.toArray())
},
if (xSize.length === 1) {
xLen = xSize[0]
} else if (xSize.length === 2 && xSize[1] === 1) {
xLen = xSize[0]
} else {
throw new RangeError('Expected a column vector, instead got a matrix of size (' + xSize.join(', ') + ')')
}

'Array, Array': _dot
})
if (ySize.length === 1) {
yLen = ySize[0]
} else if (ySize.length === 2 && ySize[1] === 1) {
yLen = ySize[0]
} else {
throw new RangeError('Expected a column vector, instead got a matrix of size (' + ySize.join(', ') + ')')
}

/**
* Calculate the dot product for two arrays
* @param {Array} x First vector
* @param {Array} y Second vector
* @returns {number} Returns the dot product of x and y
* @private
*/
// TODO: double code with math.multiply
function _dot (x, y) {
const xSize = size(x)
const ySize = size(y)
const len = xSize[0]

if (xSize.length !== 1 || ySize.length !== 1) throw new RangeError('Vector expected') // TODO: better error message
if (xSize[0] !== ySize[0]) throw new RangeError('Vectors must have equal length (' + xSize[0] + ' != ' + ySize[0] + ')')
if (len === 0) throw new RangeError('Cannot calculate the dot product of empty vectors')

let prod = 0
for (let i = 0; i < len; i++) {
prod = add(prod, multiply(x[i], y[i]))
if (xLen !== yLen) throw new RangeError('Vectors must have equal length (' + xLen + ' != ' + yLen + ')')
if (xLen === 0) throw new RangeError('Cannot calculate the dot product of empty vectors')

return xLen
}

function _denseDot (a, b) {
const N = _validateDim(a, b)

const adata = isMatrix(a) ? a._data : a
const adt = isMatrix(a) ? a._datatype : undefined

const bdata = isMatrix(b) ? b._data : b
const bdt = isMatrix(b) ? b._datatype : undefined

// are these 2-dimensional column vectors? (as opposed to 1-dimensional vectors)
const aIsColumn = _size(a).length === 2
const bIsColumn = _size(b).length === 2

let add = addScalar
let mul = multiplyScalar

// process data types
if (adt && bdt && adt === bdt && typeof adt === 'string') {
const dt = adt
// find signatures that matches (dt, dt)
add = typed.find(addScalar, [dt, dt])
mul = typed.find(multiplyScalar, [dt, dt])
}

// both vectors 1-dimensional
if (!aIsColumn && !bIsColumn) {
let c = mul(conj(adata[0]), bdata[0])
for (let i = 1; i < N; i++) {
c = add(c, mul(conj(adata[i]), bdata[i]))
}
return c
}

return prod
// a is 1-dim, b is column
if (!aIsColumn && bIsColumn) {
let c = mul(conj(adata[0]), bdata[0][0])
for (let i = 1; i < N; i++) {
c = add(c, mul(conj(adata[i]), bdata[i][0]))
}
return c
}

// a is column, b is 1-dim
if (aIsColumn && !bIsColumn) {
let c = mul(conj(adata[0][0]), bdata[0])
for (let i = 1; i < N; i++) {
c = add(c, mul(conj(adata[i][0]), bdata[i]))
}
return c
}

// both vectors are column
if (aIsColumn && bIsColumn) {
let c = mul(conj(adata[0][0]), bdata[0][0])
for (let i = 1; i < N; i++) {
c = add(c, mul(conj(adata[i][0]), bdata[i][0]))
}
return c
}
}

function _sparseDot (x, y) {
_validateDim(x, y)

const xindex = x._index
const xvalues = x._values

const yindex = y._index
const yvalues = y._values

// TODO optimize add & mul using datatype
let c = 0
const add = addScalar
const mul = multiplyScalar

let i = 0
let j = 0
while (i < xindex.length && j < yindex.length) {
const I = xindex[i]
const J = yindex[j]

if (I < J) {
i++
continue
}
if (I > J) {
j++
continue
}
if (I === J) {
c = add(c, mul(xvalues[i], yvalues[j]))
i++
j++
}
}

return c
}

// TODO remove this once #1771 is fixed
function _size (x) {
return isMatrix(x) ? x.size() : size(x)
}
})
8 changes: 8 additions & 0 deletions test/unit-tests/function/arithmetic/multiply.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,14 @@ describe('multiply', function () {
approx.deepEqual(multiply(matrix(a), matrix(b)), 32)
})

it('should conjugate the first argument in dot product', function () {
const a = [complex(1, 2), complex(3, 4)]
const b = [complex(5, 6), complex(7, 8)]

approx.deepEqual(multiply(a, b), complex(70, -8))
approx.deepEqual(multiply(matrix(a), matrix(b)), complex(70, -8))
})

it('should multiply row vector x column vector', function () {
const v = [[1, 2, 3, 0, 0, 5, 6]]

Expand Down
79 changes: 66 additions & 13 deletions test/unit-tests/function/matrix/dot.test.js
Original file line number Diff line number Diff line change
@@ -1,40 +1,93 @@
import assert from 'assert'
import math from '../../../../src/bundleAny'

const dot = math.dot
const matrix = math.matrix
const sparse = math.sparse
const complex = math.complex

describe('dot', function () {
it('should calculate dot product for two arrays', function () {
assert.strictEqual(math.dot([2, 4, 1], [2, 2, 3]), 15)
assert.strictEqual(math.dot([7, 3], [2, 4]), 26)
it('should calculate dot product for two 1-dim arrays', function () {
assert.strictEqual(dot([2, 4, 1], [2, 2, 3]), 15)
assert.strictEqual(dot([7, 3], [2, 4]), 26)
})

it('should calculate dot product for two column arrays', function () {
assert.strictEqual(dot([[2], [4], [1]], [[2], [2], [3]]), 15)
assert.strictEqual(dot([[7], [3]], [[2], [4]]), 26)
})

it('should calculate dot product for two 1-dim vectors', function () {
assert.strictEqual(dot(matrix([2, 4, 1]), matrix([2, 2, 3])), 15)
assert.strictEqual(dot(matrix([7, 3]), matrix([2, 4])), 26)
})

it('should calculate dot product for two column vectors', function () {
assert.strictEqual(dot(matrix([[2], [4], [1]]), matrix([[2], [2], [3]])), 15)
assert.strictEqual(dot(matrix([[7], [3]]), matrix([[2], [4]])), 26)
})

it('should calculate dot product for mixed 1-dim arrays and column arrays', function () {
assert.strictEqual(dot([2, 4, 1], [[2], [2], [3]]), 15)
assert.strictEqual(dot([[7], [3]], [2, 4]), 26)
})

it('should calculate dot product for mixed 1-dim arrays and 1-dim vectors', function () {
assert.strictEqual(dot([2, 4, 1], matrix([2, 2, 3])), 15)
assert.strictEqual(dot(matrix([7, 3]), [2, 4]), 26)
})

it('should calculate dot product for mixed 1-dim arrays and column vectors', function () {
assert.strictEqual(dot([2, 4, 1], matrix([[2], [2], [3]])), 15)
assert.strictEqual(dot(matrix([[7], [3]]), [2, 4]), 26)
})

it('should calculate dot product for mixed column arrays and 1-dim vectors', function () {
assert.strictEqual(dot([[2], [4], [1]], matrix([2, 2, 3])), 15)
assert.strictEqual(dot(matrix([7, 3]), [[2], [4]]), 26)
})

it('should calculate dot product for two matrices', function () {
assert.strictEqual(math.dot(math.matrix([2, 4, 1]), math.matrix([2, 2, 3])), 15)
assert.strictEqual(math.dot(math.matrix([7, 3]), math.matrix([2, 4])), 26)
it('should calculate dot product for mixed column arrays and column vectors', function () {
assert.strictEqual(dot([[2], [4], [1]], matrix([[2], [2], [3]])), 15)
assert.strictEqual(dot(matrix([[7], [3]]), [[2], [4]]), 26)
})

it('should calculate dot product for mixed arrays and matrices', function () {
assert.strictEqual(math.dot([2, 4, 1], math.matrix([2, 2, 3])), 15)
assert.strictEqual(math.dot(math.matrix([7, 3]), [2, 4]), 26)
it('should calculate dot product for mixed 1-dim vectors and column vectors', function () {
assert.strictEqual(dot(matrix([2, 4, 1]), matrix([[2], [2], [3]])), 15)
assert.strictEqual(dot(matrix([[7], [3]]), matrix([2, 4])), 26)
})

it('should calculate dot product for sparse vectors', function () {
assert.strictEqual(dot(sparse([0, 0, 2, 4, 4, 1]), sparse([1, 0, 2, 2, 0, 3])), 15)
assert.strictEqual(dot(sparse([7, 1, 2, 3]), sparse([2, 0, 0, 4])), 26)
})

it('should throw an error for unsupported types of arguments', function () {
assert.throws(function () { math.dot([2, 4, 1], 2) }, TypeError)
assert.throws(function () { dot([2, 4, 1], 2) }, TypeError)
})

it('should throw an error for multi dimensional matrix input', function () {
assert.throws(function () { math.dot([[1, 2], [3, 4]], [[1, 2], [3, 4]]) }, /Vector expected/)
assert.throws(function () { dot([[1, 2], [3, 4]], [[1, 2], [3, 4]]) }, /Expected a column vector, instead got a matrix of size \(2, 2\)/)
})

it('should throw an error in case of vectors with unequal length', function () {
assert.throws(function () { math.dot([2, 3], [1, 2, 3]) }, /Vectors must have equal length \(2 != 3\)/)
assert.throws(function () { dot([2, 3], [1, 2, 3]) }, /Vectors must have equal length \(2 != 3\)/)
})

it('should throw an error in case of empty vectors', function () {
assert.throws(function () { math.dot([], []) }, /Cannot calculate the dot product of empty vectors/)
assert.throws(function () { dot([], []) }, /Cannot calculate the dot product of empty vectors/)
})

it('should LaTeX dot', function () {
const expression = math.parse('dot([1,2],[3,4])')
assert.strictEqual(expression.toTex(), '\\left(\\begin{bmatrix}1\\\\2\\\\\\end{bmatrix}\\cdot\\begin{bmatrix}3\\\\4\\\\\\end{bmatrix}\\right)')
})

it('should be antilinear in the first argument', function () {
const I = complex(0, 1)
assert.deepStrictEqual(dot([I, 2], [1, I]), I)

const v = matrix([2, I, 1])
assert.deepStrictEqual(dot(v, v).sqrt(), complex(math.norm(v)))
})
})