Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added type inferencing for Vector & Matrix operations in multiply.js (+performance boost!) #3149

Merged
merged 8 commits into from
Feb 22, 2024
1 change: 1 addition & 0 deletions AUTHORS
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,7 @@ Brooks Smith <brooks.smith@clearcalcs.com>
Alex Edgcomb <aedgcomb@gmail.com>
S.Y. Lee <sylee957@gmail.com>
Hudsxn <143907857+Hudsxn@users.noreply.github.com>
RandomGamingDev <randomgamingdev@gmail.com>
Rich Martinez <richard.i.martinez.jr@gmail.com>

# Generated by tools/update-authors.js
60 changes: 30 additions & 30 deletions src/function/arithmetic/multiply.js
Original file line number Diff line number Diff line change
Expand Up @@ -110,11 +110,11 @@ export const createMultiply = /* #__PURE__ */ factory(name, dependencies, ({ typ
// a dense
const adata = a._data
const asize = a._size
const adt = a._datatype
const adt = a._datatype || a.getDataType()
// b dense
const bdata = b._data
const bsize = b._size
const bdt = b._datatype
const bdt = b._datatype || b.getDataType()
// rows & columns
const alength = asize[0]
const bcolumns = bsize[1]
Expand All @@ -127,7 +127,7 @@ export const createMultiply = /* #__PURE__ */ factory(name, dependencies, ({ typ
let mf = multiplyScalar

// process data types
if (adt && bdt && adt === bdt && typeof adt === 'string') {
if (adt && bdt && adt === bdt && typeof adt === 'string' && adt !== 'mixed') {
josdejong marked this conversation as resolved.
Show resolved Hide resolved
// datatype
dt = adt
// find signatures that matches (dt, dt)
Expand All @@ -154,7 +154,7 @@ export const createMultiply = /* #__PURE__ */ factory(name, dependencies, ({ typ
return a.createDenseMatrix({
data: c,
size: [bcolumns],
datatype: dt
datatype: adt === a._datatype && bdt === b._datatype ? dt : undefined
})
}

Expand Down Expand Up @@ -198,10 +198,10 @@ export const createMultiply = /* #__PURE__ */ factory(name, dependencies, ({ typ
// a dense
const adata = a._data
const asize = a._size
const adt = a._datatype
const adt = a._datatype || a.getDataType()
// b dense
const bdata = b._data
const bdt = b._datatype
const bdt = b._datatype || b.getDataType()
// rows & columns
const arows = asize[0]
const acolumns = asize[1]
Expand All @@ -214,7 +214,7 @@ export const createMultiply = /* #__PURE__ */ factory(name, dependencies, ({ typ
let mf = multiplyScalar

// process data types
if (adt && bdt && adt === bdt && typeof adt === 'string') {
if (adt && bdt && adt === bdt && typeof adt === 'string' && adt !== 'mixed') {
// datatype
dt = adt
// find signatures that matches (dt, dt)
Expand Down Expand Up @@ -243,7 +243,7 @@ export const createMultiply = /* #__PURE__ */ factory(name, dependencies, ({ typ
return a.createDenseMatrix({
data: c,
size: [arows],
datatype: dt
datatype: adt === a._datatype && bdt === b._datatype ? dt : undefined
})
}

Expand All @@ -255,15 +255,15 @@ export const createMultiply = /* #__PURE__ */ factory(name, dependencies, ({ typ
*
* @return {Matrix} DenseMatrix (MxC)
*/
function _multiplyDenseMatrixDenseMatrix (a, b) {
function _multiplyDenseMatrixDenseMatrix (a, b) { // getDataType()
// a dense
const adata = a._data
const asize = a._size
const adt = a._datatype
const adt = a._datatype || a.getDataType()
// b dense
const bdata = b._data
const bsize = b._size
const bdt = b._datatype
const bdt = b._datatype || b.getDataType()
// rows & columns
const arows = asize[0]
const acolumns = asize[1]
Expand All @@ -277,7 +277,7 @@ export const createMultiply = /* #__PURE__ */ factory(name, dependencies, ({ typ
let mf = multiplyScalar

// process data types
if (adt && bdt && adt === bdt && typeof adt === 'string') {
if (adt && bdt && adt === bdt && typeof adt === 'string' && adt !== 'mixed' && adt !== 'mixed') {
// datatype
dt = adt
// find signatures that matches (dt, dt)
Expand Down Expand Up @@ -311,7 +311,7 @@ export const createMultiply = /* #__PURE__ */ factory(name, dependencies, ({ typ
return a.createDenseMatrix({
data: c,
size: [arows, bcolumns],
datatype: dt
datatype: adt === a._datatype && bdt === b._datatype ? dt : undefined
})
}

Expand All @@ -327,13 +327,13 @@ export const createMultiply = /* #__PURE__ */ factory(name, dependencies, ({ typ
// a dense
const adata = a._data
const asize = a._size
const adt = a._datatype
const adt = a._datatype || a.getDataType()
// b sparse
const bvalues = b._values
const bindex = b._index
const bptr = b._ptr
const bsize = b._size
const bdt = b._datatype
const bdt = b._datatype || b._data === undefined ? b._datatype : b.getDataType()
// validate b matrix
if (!bvalues) { throw new Error('Cannot multiply Dense Matrix times Pattern only Matrix') }
// rows & columns
Expand All @@ -352,7 +352,7 @@ export const createMultiply = /* #__PURE__ */ factory(name, dependencies, ({ typ
let zero = 0

// process data types
if (adt && bdt && adt === bdt && typeof adt === 'string') {
if (adt && bdt && adt === bdt && typeof adt === 'string' && adt !== 'mixed') {
// datatype
dt = adt
// find signatures that matches (dt, dt)
Expand All @@ -373,7 +373,7 @@ export const createMultiply = /* #__PURE__ */ factory(name, dependencies, ({ typ
index: cindex,
ptr: cptr,
size: [arows, bcolumns],
datatype: dt
datatype: adt === a._datatype && bdt === b._datatype ? dt : undefined
})

// loop b columns
Expand Down Expand Up @@ -437,12 +437,12 @@ export const createMultiply = /* #__PURE__ */ factory(name, dependencies, ({ typ
const avalues = a._values
const aindex = a._index
const aptr = a._ptr
const adt = a._datatype
const adt = a._datatype || a._data === undefined ? a._datatype : a.getDataType()
// validate a matrix
if (!avalues) { throw new Error('Cannot multiply Pattern only Matrix times Dense Matrix') }
// b dense
const bdata = b._data
const bdt = b._datatype
const bdt = b._datatype || b.getDataType()
// rows & columns
const arows = a._size[0]
const brows = b._size[0]
Expand All @@ -463,7 +463,7 @@ export const createMultiply = /* #__PURE__ */ factory(name, dependencies, ({ typ
let zero = 0

// process data types
if (adt && bdt && adt === bdt && typeof adt === 'string') {
if (adt && bdt && adt === bdt && typeof adt === 'string' && adt !== 'mixed') {
// datatype
dt = adt
// find signatures that matches (dt, dt)
Expand Down Expand Up @@ -516,13 +516,13 @@ export const createMultiply = /* #__PURE__ */ factory(name, dependencies, ({ typ
// update ptr
cptr[1] = cindex.length

// return sparse matrix
// matrix to return
return a.createSparseMatrix({
values: cvalues,
index: cindex,
ptr: cptr,
size: [arows, 1],
datatype: dt
datatype: adt === a._datatype && bdt === b._datatype ? dt : undefined
})
}

Expand All @@ -539,12 +539,12 @@ export const createMultiply = /* #__PURE__ */ factory(name, dependencies, ({ typ
const avalues = a._values
const aindex = a._index
const aptr = a._ptr
const adt = a._datatype
const adt = a._datatype || a._data === undefined ? a._datatype : a.getDataType()
// validate a matrix
if (!avalues) { throw new Error('Cannot multiply Pattern only Matrix times Dense Matrix') }
// b dense
const bdata = b._data
const bdt = b._datatype
const bdt = b._datatype || b.getDataType()
// rows & columns
const arows = a._size[0]
const brows = b._size[0]
Expand All @@ -562,7 +562,7 @@ export const createMultiply = /* #__PURE__ */ factory(name, dependencies, ({ typ
let zero = 0

// process data types
if (adt && bdt && adt === bdt && typeof adt === 'string') {
if (adt && bdt && adt === bdt && typeof adt === 'string' && adt !== 'mixed') {
// datatype
dt = adt
// find signatures that matches (dt, dt)
Expand All @@ -583,7 +583,7 @@ export const createMultiply = /* #__PURE__ */ factory(name, dependencies, ({ typ
index: cindex,
ptr: cptr,
size: [arows, bcolumns],
datatype: dt
datatype: adt === a._datatype && bdt === b._datatype ? dt : undefined
})

// workspace
Expand Down Expand Up @@ -650,12 +650,12 @@ export const createMultiply = /* #__PURE__ */ factory(name, dependencies, ({ typ
const avalues = a._values
const aindex = a._index
const aptr = a._ptr
const adt = a._datatype
const adt = a._datatype || a._data === undefined ? a._datatype : a.getDataType()
// b sparse
const bvalues = b._values
const bindex = b._index
const bptr = b._ptr
const bdt = b._datatype
const bdt = b._datatype || b._data === undefined ? b._datatype : b.getDataType()

// rows & columns
const arows = a._size[0]
Expand All @@ -671,7 +671,7 @@ export const createMultiply = /* #__PURE__ */ factory(name, dependencies, ({ typ
let mf = multiplyScalar

// process data types
if (adt && bdt && adt === bdt && typeof adt === 'string') {
if (adt && bdt && adt === bdt && typeof adt === 'string' && adt !== 'mixed') {
// datatype
dt = adt
// find signatures that matches (dt, dt)
Expand All @@ -689,7 +689,7 @@ export const createMultiply = /* #__PURE__ */ factory(name, dependencies, ({ typ
index: cindex,
ptr: cptr,
size: [arows, bcolumns],
datatype: dt
datatype: adt === a._datatype && bdt === b._datatype ? dt : undefined
})

// workspace
Expand Down
6 changes: 3 additions & 3 deletions src/function/matrix/dot.js
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,10 @@ export const createDot = /* #__PURE__ */ factory(name, dependencies, ({ typed, a
const N = _validateDim(a, b)

const adata = isMatrix(a) ? a._data : a
const adt = isMatrix(a) ? a._datatype : undefined
const adt = isMatrix(a) ? a._datatype || a.getDataType() : undefined

const bdata = isMatrix(b) ? b._data : b
const bdt = isMatrix(b) ? b._datatype : undefined
const bdt = isMatrix(b) ? b._datatype || b.getDataType() : undefined

// are these 2-dimensional column vectors? (as opposed to 1-dimensional vectors)
const aIsColumn = _size(a).length === 2
Expand All @@ -77,7 +77,7 @@ export const createDot = /* #__PURE__ */ factory(name, dependencies, ({ typed, a
let mul = multiplyScalar

// process data types
if (adt && bdt && adt === bdt && typeof adt === 'string') {
if (adt && bdt && adt === bdt && typeof adt === 'string' && adt !== 'mixed') {
const dt = adt
// find signatures that matches (dt, dt)
add = typed.find(addScalar, [dt, dt])
Expand Down
8 changes: 4 additions & 4 deletions src/type/matrix/utils/matAlgo01xDSid.js
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,13 @@ export const createMatAlgo01xDSid = /* #__PURE__ */ factory(name, dependencies,
// dense matrix arrays
const adata = denseMatrix._data
const asize = denseMatrix._size
const adt = denseMatrix._datatype
const adt = denseMatrix._datatype || denseMatrix.getDataType()
// sparse matrix arrays
const bvalues = sparseMatrix._values
const bindex = sparseMatrix._index
const bptr = sparseMatrix._ptr
const bsize = sparseMatrix._size
const bdt = sparseMatrix._datatype
const bdt = sparseMatrix._datatype || sparseMatrix._data === undefined ? sparseMatrix._datatype : sparseMatrix.getDataType()

// validate dimensions
if (asize.length !== bsize.length) { throw new DimensionError(asize.length, bsize.length) }
Expand All @@ -50,7 +50,7 @@ export const createMatAlgo01xDSid = /* #__PURE__ */ factory(name, dependencies,
const columns = asize[1]

// process data types
const dt = typeof adt === 'string' && adt === bdt ? adt : undefined
const dt = typeof adt === 'string' && adt !== 'mixed' && adt === bdt ? adt : undefined
// callback function
const cf = dt ? typed.find(callback, [dt, dt]) : callback

Expand Down Expand Up @@ -97,7 +97,7 @@ export const createMatAlgo01xDSid = /* #__PURE__ */ factory(name, dependencies,
return denseMatrix.createDenseMatrix({
data: cdata,
size: [rows, columns],
datatype: dt
datatype: adt === denseMatrix._datatype && bdt === sparseMatrix._datatype ? dt : undefined
})
}
})
8 changes: 4 additions & 4 deletions src/type/matrix/utils/matAlgo02xDS0.js
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,13 @@ export const createMatAlgo02xDS0 = /* #__PURE__ */ factory(name, dependencies, (
// dense matrix arrays
const adata = denseMatrix._data
const asize = denseMatrix._size
const adt = denseMatrix._datatype
const adt = denseMatrix._datatype || denseMatrix.getDataType()
// sparse matrix arrays
const bvalues = sparseMatrix._values
const bindex = sparseMatrix._index
const bptr = sparseMatrix._ptr
const bsize = sparseMatrix._size
const bdt = sparseMatrix._datatype
const bdt = sparseMatrix._datatype || sparseMatrix._data === undefined ? sparseMatrix._datatype : sparseMatrix.getDataType()

// validate dimensions
if (asize.length !== bsize.length) { throw new DimensionError(asize.length, bsize.length) }
Expand All @@ -59,7 +59,7 @@ export const createMatAlgo02xDS0 = /* #__PURE__ */ factory(name, dependencies, (
let cf = callback

// process data types
if (typeof adt === 'string' && adt === bdt) {
if (typeof adt === 'string' && adt === bdt && adt !== 'mixed') {
// datatype
dt = adt
// find signature that matches (dt, dt)
Expand Down Expand Up @@ -102,7 +102,7 @@ export const createMatAlgo02xDS0 = /* #__PURE__ */ factory(name, dependencies, (
index: cindex,
ptr: cptr,
size: [rows, columns],
datatype: dt
datatype: adt === denseMatrix._datatype && bdt === sparseMatrix._datatype ? dt : undefined
})
}
})
8 changes: 4 additions & 4 deletions src/type/matrix/utils/matAlgo03xDSf.js
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,13 @@ export const createMatAlgo03xDSf = /* #__PURE__ */ factory(name, dependencies, (
// dense matrix arrays
const adata = denseMatrix._data
const asize = denseMatrix._size
const adt = denseMatrix._datatype
const adt = denseMatrix._datatype || denseMatrix.getDataType()
// sparse matrix arrays
const bvalues = sparseMatrix._values
const bindex = sparseMatrix._index
const bptr = sparseMatrix._ptr
const bsize = sparseMatrix._size
const bdt = sparseMatrix._datatype
const bdt = sparseMatrix._datatype || sparseMatrix._data === undefined ? sparseMatrix._datatype : sparseMatrix.getDataType()

// validate dimensions
if (asize.length !== bsize.length) { throw new DimensionError(asize.length, bsize.length) }
Expand All @@ -57,7 +57,7 @@ export const createMatAlgo03xDSf = /* #__PURE__ */ factory(name, dependencies, (
let cf = callback

// process data types
if (typeof adt === 'string' && adt === bdt) {
if (typeof adt === 'string' && adt === bdt && adt !== 'mixed') {
// datatype
dt = adt
// convert 0 to the same datatype
Expand Down Expand Up @@ -109,7 +109,7 @@ export const createMatAlgo03xDSf = /* #__PURE__ */ factory(name, dependencies, (
return denseMatrix.createDenseMatrix({
data: cdata,
size: [rows, columns],
datatype: dt
datatype: adt === denseMatrix._datatype && bdt === sparseMatrix._datatype ? dt : undefined
})
}
})
Loading