-
Notifications
You must be signed in to change notification settings - Fork 57
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #136 from mipt-npm/ejml
Drop koma module, implement kmath-ejml module copying it, but for EJML SimpleMatrix
- Loading branch information
Showing
13 changed files
with
463 additions
and
11 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
50 changes: 50 additions & 0 deletions
50
examples/src/main/kotlin/kscience/kmath/linear/LinearAlgebraBenchmark.kt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
package kscience.kmath.linear | ||
|
||
import kscience.kmath.commons.linear.CMMatrixContext | ||
import kscience.kmath.commons.linear.inverse | ||
import kscience.kmath.commons.linear.toCM | ||
import kscience.kmath.ejml.EjmlMatrixContext | ||
import kscience.kmath.ejml.inverse | ||
import kscience.kmath.operations.RealField | ||
import kscience.kmath.operations.invoke | ||
import kscience.kmath.structures.Matrix | ||
import kotlin.random.Random | ||
import kotlin.system.measureTimeMillis | ||
|
||
fun main() { | ||
val random = Random(1224) | ||
val dim = 100 | ||
//creating invertible matrix | ||
val u = Matrix.real(dim, dim) { i, j -> if (i <= j) random.nextDouble() else 0.0 } | ||
val l = Matrix.real(dim, dim) { i, j -> if (i >= j) random.nextDouble() else 0.0 } | ||
val matrix = l dot u | ||
|
||
val n = 5000 // iterations | ||
|
||
MatrixContext.real { | ||
repeat(50) { inverse(matrix) } | ||
val inverseTime = measureTimeMillis { repeat(n) { inverse(matrix) } } | ||
println("[kmath] Inversion of $n matrices $dim x $dim finished in $inverseTime millis") | ||
} | ||
|
||
//commons-math | ||
|
||
val commonsTime = measureTimeMillis { | ||
CMMatrixContext { | ||
val cm = matrix.toCM() //avoid overhead on conversion | ||
repeat(n) { inverse(cm) } | ||
} | ||
} | ||
|
||
|
||
println("[commons-math] Inversion of $n matrices $dim x $dim finished in $commonsTime millis") | ||
|
||
val ejmlTime = measureTimeMillis { | ||
(EjmlMatrixContext(RealField)) { | ||
val km = matrix.toEjml() //avoid overhead on conversion | ||
repeat(n) { inverse(km) } | ||
} | ||
} | ||
|
||
println("[ejml] Inversion of $n matrices $dim x $dim finished in $ejmlTime millis") | ||
} |
38 changes: 38 additions & 0 deletions
38
examples/src/main/kotlin/kscience/kmath/linear/MultiplicationBenchmark.kt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
package kscience.kmath.linear | ||
|
||
import kscience.kmath.commons.linear.CMMatrixContext | ||
import kscience.kmath.commons.linear.toCM | ||
import kscience.kmath.ejml.EjmlMatrixContext | ||
import kscience.kmath.operations.RealField | ||
import kscience.kmath.operations.invoke | ||
import kscience.kmath.structures.Matrix | ||
import kotlin.random.Random | ||
import kotlin.system.measureTimeMillis | ||
|
||
fun main() { | ||
val random = Random(12224) | ||
val dim = 1000 | ||
//creating invertible matrix | ||
val matrix1 = Matrix.real(dim, dim) { i, j -> if (i <= j) random.nextDouble() else 0.0 } | ||
val matrix2 = Matrix.real(dim, dim) { i, j -> if (i <= j) random.nextDouble() else 0.0 } | ||
|
||
// //warmup | ||
// matrix1 dot matrix2 | ||
|
||
CMMatrixContext { | ||
val cmMatrix1 = matrix1.toCM() | ||
val cmMatrix2 = matrix2.toCM() | ||
val cmTime = measureTimeMillis { cmMatrix1 dot cmMatrix2 } | ||
println("CM implementation time: $cmTime") | ||
} | ||
|
||
(EjmlMatrixContext(RealField)) { | ||
val ejmlMatrix1 = matrix1.toEjml() | ||
val ejmlMatrix2 = matrix2.toEjml() | ||
val ejmlTime = measureTimeMillis { ejmlMatrix1 dot ejmlMatrix2 } | ||
println("EJML implementation time: $ejmlTime") | ||
} | ||
|
||
val genericTime = measureTimeMillis { val res = matrix1 dot matrix2 } | ||
println("Generic implementation time: $genericTime") | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
plugins { | ||
id("ru.mipt.npm.jvm") | ||
} | ||
|
||
dependencies { | ||
implementation("org.ejml:ejml-simple:0.39") | ||
implementation(project(":kmath-core")) | ||
} |
71 changes: 71 additions & 0 deletions
71
kmath-ejml/src/main/kotlin/kscience/kmath/ejml/EjmlMatrix.kt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
package kscience.kmath.ejml | ||
|
||
import org.ejml.dense.row.factory.DecompositionFactory_DDRM | ||
import org.ejml.simple.SimpleMatrix | ||
import kscience.kmath.linear.DeterminantFeature | ||
import kscience.kmath.linear.FeaturedMatrix | ||
import kscience.kmath.linear.LUPDecompositionFeature | ||
import kscience.kmath.linear.MatrixFeature | ||
import kscience.kmath.structures.NDStructure | ||
|
||
/** | ||
* Represents featured matrix over EJML [SimpleMatrix]. | ||
* | ||
* @property origin the underlying [SimpleMatrix]. | ||
* @author Iaroslav Postovalov | ||
*/ | ||
public class EjmlMatrix(public val origin: SimpleMatrix, features: Set<MatrixFeature>? = null) : FeaturedMatrix<Double> { | ||
public override val rowNum: Int | ||
get() = origin.numRows() | ||
|
||
public override val colNum: Int | ||
get() = origin.numCols() | ||
|
||
public override val shape: IntArray | ||
get() = intArrayOf(origin.numRows(), origin.numCols()) | ||
|
||
public override val features: Set<MatrixFeature> = setOf( | ||
object : LUPDecompositionFeature<Double>, DeterminantFeature<Double> { | ||
override val determinant: Double | ||
get() = origin.determinant() | ||
|
||
private val lup by lazy { | ||
val ludecompositionF64 = DecompositionFactory_DDRM.lu(origin.numRows(), origin.numCols()) | ||
.also { it.decompose(origin.ddrm.copy()) } | ||
|
||
Triple( | ||
EjmlMatrix(SimpleMatrix(ludecompositionF64.getRowPivot(null))), | ||
EjmlMatrix(SimpleMatrix(ludecompositionF64.getLower(null))), | ||
EjmlMatrix(SimpleMatrix(ludecompositionF64.getUpper(null))), | ||
) | ||
} | ||
|
||
override val l: FeaturedMatrix<Double> | ||
get() = lup.second | ||
|
||
override val u: FeaturedMatrix<Double> | ||
get() = lup.third | ||
|
||
override val p: FeaturedMatrix<Double> | ||
get() = lup.first | ||
} | ||
) union features.orEmpty() | ||
|
||
public override fun suggestFeature(vararg features: MatrixFeature): EjmlMatrix = | ||
EjmlMatrix(origin, this.features + features) | ||
|
||
public override operator fun get(i: Int, j: Int): Double = origin[i, j] | ||
|
||
public override fun equals(other: Any?): Boolean { | ||
if (other is EjmlMatrix) return origin.isIdentical(other.origin, 0.0) | ||
return NDStructure.equals(this, other as? NDStructure<*> ?: return false) | ||
} | ||
|
||
public override fun hashCode(): Int { | ||
var result = origin.hashCode() | ||
result = 31 * result + features.hashCode() | ||
return result | ||
} | ||
|
||
public override fun toString(): String = "EjmlMatrix(origin=$origin, features=$features)" | ||
} |
86 changes: 86 additions & 0 deletions
86
kmath-ejml/src/main/kotlin/kscience/kmath/ejml/EjmlMatrixContext.kt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
package kscience.kmath.ejml | ||
|
||
import org.ejml.simple.SimpleMatrix | ||
import kscience.kmath.linear.MatrixContext | ||
import kscience.kmath.linear.Point | ||
import kscience.kmath.operations.Space | ||
import kscience.kmath.operations.invoke | ||
import kscience.kmath.structures.Matrix | ||
|
||
/** | ||
* Represents context of basic operations operating with [EjmlMatrix]. | ||
* | ||
* @author Iaroslav Postovalov | ||
*/ | ||
public class EjmlMatrixContext(private val space: Space<Double>) : MatrixContext<Double> { | ||
/** | ||
* Converts this matrix to EJML one. | ||
*/ | ||
public fun Matrix<Double>.toEjml(): EjmlMatrix = | ||
if (this is EjmlMatrix) this else produce(rowNum, colNum) { i, j -> get(i, j) } | ||
|
||
/** | ||
* Converts this vector to EJML one. | ||
*/ | ||
public fun Point<Double>.toEjml(): EjmlVector = | ||
if (this is EjmlVector) this else EjmlVector(SimpleMatrix(size, 1).also { | ||
(0 until it.numRows()).forEach { row -> it[row, 0] = get(row) } | ||
}) | ||
|
||
override fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> Double): EjmlMatrix = | ||
EjmlMatrix(SimpleMatrix(rows, columns).also { | ||
(0 until it.numRows()).forEach { row -> | ||
(0 until it.numCols()).forEach { col -> it[row, col] = initializer(row, col) } | ||
} | ||
}) | ||
|
||
public override fun Matrix<Double>.dot(other: Matrix<Double>): EjmlMatrix = | ||
EjmlMatrix(toEjml().origin.mult(other.toEjml().origin)) | ||
|
||
public override fun Matrix<Double>.dot(vector: Point<Double>): EjmlVector = | ||
EjmlVector(toEjml().origin.mult(vector.toEjml().origin)) | ||
|
||
public override fun add(a: Matrix<Double>, b: Matrix<Double>): EjmlMatrix = | ||
EjmlMatrix(a.toEjml().origin + b.toEjml().origin) | ||
|
||
public override operator fun Matrix<Double>.minus(b: Matrix<Double>): EjmlMatrix = | ||
EjmlMatrix(toEjml().origin - b.toEjml().origin) | ||
|
||
public override fun multiply(a: Matrix<Double>, k: Number): EjmlMatrix = | ||
produce(a.rowNum, a.colNum) { i, j -> space { a[i, j] * k } } | ||
|
||
public override operator fun Matrix<Double>.times(value: Double): EjmlMatrix = EjmlMatrix(toEjml().origin.scale(value)) | ||
|
||
public companion object | ||
} | ||
|
||
/** | ||
* Solves for X in the following equation: x = a^-1*b, where 'a' is base matrix and 'b' is an n by p matrix. | ||
* | ||
* @param a the base matrix. | ||
* @param b n by p matrix. | ||
* @return the solution for 'x' that is n by p. | ||
* @author Iaroslav Postovalov | ||
*/ | ||
public fun EjmlMatrixContext.solve(a: Matrix<Double>, b: Matrix<Double>): EjmlMatrix = | ||
EjmlMatrix(a.toEjml().origin.solve(b.toEjml().origin)) | ||
|
||
/** | ||
* Solves for X in the following equation: x = a^(-1)*b, where 'a' is base matrix and 'b' is an n by p matrix. | ||
* | ||
* @param a the base matrix. | ||
* @param b n by p vector. | ||
* @return the solution for 'x' that is n by p. | ||
* @author Iaroslav Postovalov | ||
*/ | ||
public fun EjmlMatrixContext.solve(a: Matrix<Double>, b: Point<Double>): EjmlVector = | ||
EjmlVector(a.toEjml().origin.solve(b.toEjml().origin)) | ||
|
||
/** | ||
* Returns the inverse of given matrix: b = a^(-1). | ||
* | ||
* @param a the matrix. | ||
* @return the inverse of this matrix. | ||
* @author Iaroslav Postovalov | ||
*/ | ||
public fun EjmlMatrixContext.inverse(a: Matrix<Double>): EjmlMatrix = EjmlMatrix(a.toEjml().origin.invert()) |
Oops, something went wrong.