From 3330c8c6c727e5111e76b15896e36d69bd9ccb5d Mon Sep 17 00:00:00 2001 From: T45K Date: Sun, 18 Aug 2024 16:53:26 +0900 Subject: [PATCH 1/3] add norm API to Vector --- .../kotlinx/multik/api/linalg/norm.kt | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/multik-core/src/commonMain/kotlin/org/jetbrains/kotlinx/multik/api/linalg/norm.kt b/multik-core/src/commonMain/kotlin/org/jetbrains/kotlinx/multik/api/linalg/norm.kt index 835a99ed..72334415 100644 --- a/multik-core/src/commonMain/kotlin/org/jetbrains/kotlinx/multik/api/linalg/norm.kt +++ b/multik-core/src/commonMain/kotlin/org/jetbrains/kotlinx/multik/api/linalg/norm.kt @@ -4,16 +4,34 @@ package org.jetbrains.kotlinx.multik.api.linalg +import org.jetbrains.kotlinx.multik.api.mk +import org.jetbrains.kotlinx.multik.api.zeros +import org.jetbrains.kotlinx.multik.ndarray.data.D1 import org.jetbrains.kotlinx.multik.ndarray.data.D2 import org.jetbrains.kotlinx.multik.ndarray.data.MultiArray +import org.jetbrains.kotlinx.multik.ndarray.operations.stack import kotlin.jvm.JvmName +/** + * Returns norm of float vector + */ +@JvmName("normFV") +public fun LinAlg.norm(mat: MultiArray, norm: Norm = Norm.Fro): Float = + this.linAlgEx.normF(mk.stack(mat, mk.zeros(mat.size)), norm) + /** * Returns norm of float matrix */ @JvmName("normF") public fun LinAlg.norm(mat: MultiArray, norm: Norm = Norm.Fro): Float = this.linAlgEx.normF(mat, norm) +/** + * Returns norm of double vector + */ +@JvmName("normDV") +public fun LinAlg.norm(mat: MultiArray, norm: Norm = Norm.Fro): Double = + this.linAlgEx.norm(mk.stack(mat, mk.zeros(mat.size)), norm) + /** * Returns norm of double matrix */ From fcb3ad977ddd8e422c2929e1eea2f276fcbf6ba2 Mon Sep 17 00:00:00 2001 From: T45K Date: Mon, 19 Aug 2024 20:13:25 +0900 Subject: [PATCH 2/3] add tests --- .../multik_kotlin/linAlg/KELinAlgTest.kt | 49 ++++++++++++++++--- .../openblas/linalg/NativeLinAlgTest.kt | 10 ++++ 2 files changed, 51 insertions(+), 8 deletions(-) diff --git a/multik-kotlin/src/commonTest/kotlin/org/jetbrains/kotlinx/multik_kotlin/linAlg/KELinAlgTest.kt b/multik-kotlin/src/commonTest/kotlin/org/jetbrains/kotlinx/multik_kotlin/linAlg/KELinAlgTest.kt index 6bb913b9..a26ec311 100644 --- a/multik-kotlin/src/commonTest/kotlin/org/jetbrains/kotlinx/multik_kotlin/linAlg/KELinAlgTest.kt +++ b/multik-kotlin/src/commonTest/kotlin/org/jetbrains/kotlinx/multik_kotlin/linAlg/KELinAlgTest.kt @@ -5,26 +5,49 @@ package org.jetbrains.kotlinx.multik_kotlin.linAlg -import org.jetbrains.kotlinx.multik.api.* +import kotlin.math.abs +import kotlin.math.max +import kotlin.math.min +import kotlin.random.Random +import kotlin.test.Test +import kotlin.test.assertContentEquals +import kotlin.test.assertEquals +import kotlin.test.assertFailsWith +import kotlin.test.assertTrue +import org.jetbrains.kotlinx.multik.api.d1array +import org.jetbrains.kotlinx.multik.api.d2array import org.jetbrains.kotlinx.multik.api.linalg.Norm import org.jetbrains.kotlinx.multik.api.linalg.dot import org.jetbrains.kotlinx.multik.api.linalg.norm -import org.jetbrains.kotlinx.multik.kotlin.linalg.* +import org.jetbrains.kotlinx.multik.api.mk +import org.jetbrains.kotlinx.multik.api.ndarray +import org.jetbrains.kotlinx.multik.api.zeros +import org.jetbrains.kotlinx.multik.kotlin.linalg.KELinAlg +import org.jetbrains.kotlinx.multik.kotlin.linalg.KELinAlgEx import org.jetbrains.kotlinx.multik.kotlin.linalg.KELinAlgEx.solve import org.jetbrains.kotlinx.multik.kotlin.linalg.KELinAlgEx.solveC +import org.jetbrains.kotlinx.multik.kotlin.linalg.conjTranspose +import org.jetbrains.kotlinx.multik.kotlin.linalg.dotMatrixComplex +import org.jetbrains.kotlinx.multik.kotlin.linalg.gramShmidtComplexDouble +import org.jetbrains.kotlinx.multik.kotlin.linalg.qrComplexDouble +import org.jetbrains.kotlinx.multik.kotlin.linalg.schurDecomposition +import org.jetbrains.kotlinx.multik.kotlin.linalg.upperHessenbergDouble import org.jetbrains.kotlinx.multik.ndarray.complex.Complex import org.jetbrains.kotlinx.multik.ndarray.complex.ComplexDouble import org.jetbrains.kotlinx.multik.ndarray.complex.ComplexFloat import org.jetbrains.kotlinx.multik.ndarray.complex.toComplexDouble -import org.jetbrains.kotlinx.multik.ndarray.data.* +import org.jetbrains.kotlinx.multik.ndarray.data.D1Array +import org.jetbrains.kotlinx.multik.ndarray.data.D2 +import org.jetbrains.kotlinx.multik.ndarray.data.D2Array +import org.jetbrains.kotlinx.multik.ndarray.data.DataType +import org.jetbrains.kotlinx.multik.ndarray.data.Dim2 +import org.jetbrains.kotlinx.multik.ndarray.data.MultiArray +import org.jetbrains.kotlinx.multik.ndarray.data.NDArray +import org.jetbrains.kotlinx.multik.ndarray.data.get +import org.jetbrains.kotlinx.multik.ndarray.data.set import org.jetbrains.kotlinx.multik.ndarray.operations.map import org.jetbrains.kotlinx.multik.ndarray.operations.minus import org.jetbrains.kotlinx.multik.ndarray.operations.plus -import kotlin.math.abs -import kotlin.math.max -import kotlin.math.min -import kotlin.random.Random -import kotlin.test.* class KELinAlgTest { @@ -492,6 +515,16 @@ class KELinAlgTest { } } + + @Test + fun compute_norm_for_vector() { + val vector = mk.ndarray(mk[1.1, 0.0, 3.2, 2.3, 5.0]) + + assertEquals(6.460650122085238, mk.linalg.norm(vector, Norm.Fro)) + assertEquals(11.600000000000001, mk.linalg.norm(vector, Norm.Inf)) + assertEquals(5.0, mk.linalg.norm(vector, Norm.N1)) + assertEquals(5.0, mk.linalg.norm(vector, Norm.Max)) + } } diff --git a/multik-openblas/src/commonTest/kotlin/org/jetbrains/kotlinx/multik/openblas/linalg/NativeLinAlgTest.kt b/multik-openblas/src/commonTest/kotlin/org/jetbrains/kotlinx/multik/openblas/linalg/NativeLinAlgTest.kt index 3aa362d6..a16bdb90 100644 --- a/multik-openblas/src/commonTest/kotlin/org/jetbrains/kotlinx/multik/openblas/linalg/NativeLinAlgTest.kt +++ b/multik-openblas/src/commonTest/kotlin/org/jetbrains/kotlinx/multik/openblas/linalg/NativeLinAlgTest.kt @@ -324,4 +324,14 @@ class NativeLinAlgTest { assertFloatingNumber(7.0, NativeLinAlg.norm(b, Norm.N1)) assertFloatingNumber(4.0, NativeLinAlg.norm(b, Norm.Max)) } + + @Test + fun `compute norm for vector`() { + val vector = mk.ndarray(mk[1.1, 0.0, 3.2, 2.3, 5.0]) + + assertEquals(6.460650122085238, mk.linalg.norm(vector, Norm.Fro)) + assertEquals(11.600000000000001, mk.linalg.norm(vector, Norm.Inf)) + assertEquals(5.0, mk.linalg.norm(vector, Norm.N1)) + assertEquals(5.0, mk.linalg.norm(vector, Norm.Max)) + } } \ No newline at end of file From 6b3071f08313d6a78687cdf49a8c263bd873054c Mon Sep 17 00:00:00 2001 From: T45K Date: Mon, 19 Aug 2024 20:58:37 +0900 Subject: [PATCH 3/3] revert import order --- .../multik_kotlin/linAlg/KELinAlgTest.kt | 39 ++++--------------- 1 file changed, 8 insertions(+), 31 deletions(-) diff --git a/multik-kotlin/src/commonTest/kotlin/org/jetbrains/kotlinx/multik_kotlin/linAlg/KELinAlgTest.kt b/multik-kotlin/src/commonTest/kotlin/org/jetbrains/kotlinx/multik_kotlin/linAlg/KELinAlgTest.kt index a26ec311..ef753f14 100644 --- a/multik-kotlin/src/commonTest/kotlin/org/jetbrains/kotlinx/multik_kotlin/linAlg/KELinAlgTest.kt +++ b/multik-kotlin/src/commonTest/kotlin/org/jetbrains/kotlinx/multik_kotlin/linAlg/KELinAlgTest.kt @@ -5,49 +5,26 @@ package org.jetbrains.kotlinx.multik_kotlin.linAlg -import kotlin.math.abs -import kotlin.math.max -import kotlin.math.min -import kotlin.random.Random -import kotlin.test.Test -import kotlin.test.assertContentEquals -import kotlin.test.assertEquals -import kotlin.test.assertFailsWith -import kotlin.test.assertTrue -import org.jetbrains.kotlinx.multik.api.d1array -import org.jetbrains.kotlinx.multik.api.d2array +import org.jetbrains.kotlinx.multik.api.* import org.jetbrains.kotlinx.multik.api.linalg.Norm import org.jetbrains.kotlinx.multik.api.linalg.dot import org.jetbrains.kotlinx.multik.api.linalg.norm -import org.jetbrains.kotlinx.multik.api.mk -import org.jetbrains.kotlinx.multik.api.ndarray -import org.jetbrains.kotlinx.multik.api.zeros -import org.jetbrains.kotlinx.multik.kotlin.linalg.KELinAlg -import org.jetbrains.kotlinx.multik.kotlin.linalg.KELinAlgEx +import org.jetbrains.kotlinx.multik.kotlin.linalg.* import org.jetbrains.kotlinx.multik.kotlin.linalg.KELinAlgEx.solve import org.jetbrains.kotlinx.multik.kotlin.linalg.KELinAlgEx.solveC -import org.jetbrains.kotlinx.multik.kotlin.linalg.conjTranspose -import org.jetbrains.kotlinx.multik.kotlin.linalg.dotMatrixComplex -import org.jetbrains.kotlinx.multik.kotlin.linalg.gramShmidtComplexDouble -import org.jetbrains.kotlinx.multik.kotlin.linalg.qrComplexDouble -import org.jetbrains.kotlinx.multik.kotlin.linalg.schurDecomposition -import org.jetbrains.kotlinx.multik.kotlin.linalg.upperHessenbergDouble import org.jetbrains.kotlinx.multik.ndarray.complex.Complex import org.jetbrains.kotlinx.multik.ndarray.complex.ComplexDouble import org.jetbrains.kotlinx.multik.ndarray.complex.ComplexFloat import org.jetbrains.kotlinx.multik.ndarray.complex.toComplexDouble -import org.jetbrains.kotlinx.multik.ndarray.data.D1Array -import org.jetbrains.kotlinx.multik.ndarray.data.D2 -import org.jetbrains.kotlinx.multik.ndarray.data.D2Array -import org.jetbrains.kotlinx.multik.ndarray.data.DataType -import org.jetbrains.kotlinx.multik.ndarray.data.Dim2 -import org.jetbrains.kotlinx.multik.ndarray.data.MultiArray -import org.jetbrains.kotlinx.multik.ndarray.data.NDArray -import org.jetbrains.kotlinx.multik.ndarray.data.get -import org.jetbrains.kotlinx.multik.ndarray.data.set +import org.jetbrains.kotlinx.multik.ndarray.data.* import org.jetbrains.kotlinx.multik.ndarray.operations.map import org.jetbrains.kotlinx.multik.ndarray.operations.minus import org.jetbrains.kotlinx.multik.ndarray.operations.plus +import kotlin.math.abs +import kotlin.math.max +import kotlin.math.min +import kotlin.random.Random +import kotlin.test.* class KELinAlgTest {