Skip to content

Commit

Permalink
add native statistics impl #62
Browse files Browse the repository at this point in the history
  • Loading branch information
devcrocod committed Jul 16, 2022
1 parent f53fc40 commit e5d38b8
Show file tree
Hide file tree
Showing 28 changed files with 397 additions and 43 deletions.
4 changes: 2 additions & 2 deletions multik-core/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -70,13 +70,13 @@ kotlin {
implementation("org.apache.commons:commons-csv:$common_csv_version")
}
}
val nativeCommonMain by creating {
val nativeMain by creating {
dependsOn(commonMain)
}
names.forEach { n ->
if (n.contains("X64Main") || n.contains("Arm64Main")){
this@sourceSets.getByName(n).apply{
dependsOn(nativeCommonMain)
dependsOn(nativeMain)
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package org.jetbrains.kotlinx.multik.api

import org.jetbrains.kotlinx.multik.api.linalg.LinAlg
import org.jetbrains.kotlinx.multik.api.math.Math
import org.jetbrains.kotlinx.multik.api.stat.Statistics

public sealed class EngineType(public val name: String)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import org.jetbrains.kotlinx.multik.api.Multik.math
import org.jetbrains.kotlinx.multik.api.Multik.stat
import org.jetbrains.kotlinx.multik.api.linalg.LinAlg
import org.jetbrains.kotlinx.multik.api.math.Math
import org.jetbrains.kotlinx.multik.api.stat.Statistics

/**
* Abbreviated name for [Multik].
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
/*
* Copyright 2020-2021 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
* Copyright 2020-2022 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
*/

package org.jetbrains.kotlinx.multik.api
package org.jetbrains.kotlinx.multik.api.stat

import org.jetbrains.kotlinx.multik.ndarray.complex.ComplexDouble
import org.jetbrains.kotlinx.multik.ndarray.complex.ComplexFloat
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,15 @@

package org.jetbrains.kotlinx.multik.default

import org.jetbrains.kotlinx.multik.api.*
import org.jetbrains.kotlinx.multik.api.DefaultEngineType
import org.jetbrains.kotlinx.multik.api.Engine
import org.jetbrains.kotlinx.multik.api.EngineType
import org.jetbrains.kotlinx.multik.api.linalg.LinAlg
import org.jetbrains.kotlinx.multik.api.math.Math
import org.jetbrains.kotlinx.multik.api.stat.Statistics
import org.jetbrains.kotlinx.multik.default.linalg.DefaultLinAlg
import org.jetbrains.kotlinx.multik.default.math.DefaultMath
import org.jetbrains.kotlinx.multik.default.stat.DefaultStatistics

public class DefaultEngine : Engine() {
override val name: String
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,30 +2,26 @@
* Copyright 2020-2022 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
*/

package org.jetbrains.kotlinx.multik.default
package org.jetbrains.kotlinx.multik.default.stat

import org.jetbrains.kotlinx.multik.api.Statistics
import org.jetbrains.kotlinx.multik.kotlin.KEStatistics
import org.jetbrains.kotlinx.multik.api.stat.Statistics
import org.jetbrains.kotlinx.multik.ndarray.data.*

public object DefaultStatistics : Statistics {
public expect object DefaultStatistics : Statistics {

override fun <T : Number, D : Dimension> median(a: MultiArray<T, D>): Double? = KEStatistics.median(a)
override fun <T : Number, D : Dimension> median(a: MultiArray<T, D>): Double?

override fun <T : Number, D : Dimension> average(a: MultiArray<T, D>, weights: MultiArray<T, D>?): Double =
KEStatistics.average(a, weights)
override fun <T : Number, D : Dimension> average(a: MultiArray<T, D>, weights: MultiArray<T, D>?): Double

override fun <T : Number, D : Dimension> mean(a: MultiArray<T, D>): Double = KEStatistics.mean(a)
override fun <T : Number, D : Dimension> mean(a: MultiArray<T, D>): Double

override fun <T : Number, D : Dimension, O : Dimension> mean(a: MultiArray<T, D>, axis: Int): NDArray<Double, O> =
KEStatistics.mean(a, axis)
override fun <T : Number, D : Dimension, O : Dimension> mean(a: MultiArray<T, D>, axis: Int): NDArray<Double, O>

override fun <T : Number> meanD2(a: MultiArray<T, D2>, axis: Int): NDArray<Double, D1> = mean(a, axis)
override fun <T : Number> meanD2(a: MultiArray<T, D2>, axis: Int): NDArray<Double, D1>

override fun <T : Number> meanD3(a: MultiArray<T, D3>, axis: Int): NDArray<Double, D2> = mean(a, axis)
override fun <T : Number> meanD3(a: MultiArray<T, D3>, axis: Int): NDArray<Double, D2>

override fun <T : Number> meanD4(a: MultiArray<T, D4>, axis: Int): NDArray<Double, D3> = mean(a, axis)

override fun <T : Number> meanDN(a: MultiArray<T, DN>, axis: Int): NDArray<Double, D4> = mean(a, axis)
override fun <T : Number> meanD4(a: MultiArray<T, D4>, axis: Int): NDArray<Double, D3>

override fun <T : Number> meanDN(a: MultiArray<T, DN>, axis: Int): NDArray<Double, D4>
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
/*
* Copyright 2020-2022 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
*/

package org.jetbrains.kotlinx.multik.default.stat

import org.jetbrains.kotlinx.multik.api.stat.Statistics
import org.jetbrains.kotlinx.multik.kotlin.stat.KEStatistics
import org.jetbrains.kotlinx.multik.ndarray.data.*

public actual object DefaultStatistics : Statistics {

actual override fun <T : Number, D : Dimension> median(a: MultiArray<T, D>): Double? = KEStatistics.median(a)

actual override fun <T : Number, D : Dimension> average(a: MultiArray<T, D>, weights: MultiArray<T, D>?): Double =
KEStatistics.average(a, weights)

actual override fun <T : Number, D : Dimension> mean(a: MultiArray<T, D>): Double = KEStatistics.mean(a)

actual override fun <T : Number, D : Dimension, O : Dimension> mean(a: MultiArray<T, D>, axis: Int): NDArray<Double, O> =
KEStatistics.mean(a, axis)

actual override fun <T : Number> meanD2(a: MultiArray<T, D2>, axis: Int): NDArray<Double, D1> = mean(a, axis)

actual override fun <T : Number> meanD3(a: MultiArray<T, D3>, axis: Int): NDArray<Double, D2> = mean(a, axis)

actual override fun <T : Number> meanD4(a: MultiArray<T, D4>, axis: Int): NDArray<Double, D3> = mean(a, axis)

actual override fun <T : Number> meanDN(a: MultiArray<T, DN>, axis: Int): NDArray<Double, D4> = mean(a, axis)
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
/*
* Copyright 2020-2022 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
*/

package org.jetbrains.kotlinx.multik.default.stat

import org.jetbrains.kotlinx.multik.api.stat.Statistics
import org.jetbrains.kotlinx.multik.kotlin.stat.KEStatistics
import org.jetbrains.kotlinx.multik.ndarray.data.*

public actual object DefaultStatistics : Statistics {

actual override fun <T : Number, D : Dimension> median(a: MultiArray<T, D>): Double? = KEStatistics.median(a)

actual override fun <T : Number, D : Dimension> average(a: MultiArray<T, D>, weights: MultiArray<T, D>?): Double =
KEStatistics.average(a, weights)

actual override fun <T : Number, D : Dimension> mean(a: MultiArray<T, D>): Double = KEStatistics.mean(a)

actual override fun <T : Number, D : Dimension, O : Dimension> mean(a: MultiArray<T, D>, axis: Int): NDArray<Double, O> =
KEStatistics.mean(a, axis)

actual override fun <T : Number> meanD2(a: MultiArray<T, D2>, axis: Int): NDArray<Double, D1> = mean(a, axis)

actual override fun <T : Number> meanD3(a: MultiArray<T, D3>, axis: Int): NDArray<Double, D2> = mean(a, axis)

actual override fun <T : Number> meanD4(a: MultiArray<T, D4>, axis: Int): NDArray<Double, D3> = mean(a, axis)

actual override fun <T : Number> meanDN(a: MultiArray<T, DN>, axis: Int): NDArray<Double, D4> = mean(a, axis)
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
/*
* Copyright 2020-2022 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
*/

package org.jetbrains.kotlinx.multik.default.stat

import org.jetbrains.kotlinx.multik.api.stat.Statistics
import org.jetbrains.kotlinx.multik.kotlin.stat.KEStatistics
import org.jetbrains.kotlinx.multik.ndarray.data.*
import org.jetbrains.kotlinx.multik.openblas.stat.NativeStatistics

public actual object DefaultStatistics : Statistics {

actual override fun <T : Number, D : Dimension> median(a: MultiArray<T, D>): Double? = NativeStatistics.median(a)

actual override fun <T : Number, D : Dimension> average(a: MultiArray<T, D>, weights: MultiArray<T, D>?): Double =
KEStatistics.average(a, weights)

actual override fun <T : Number, D : Dimension> mean(a: MultiArray<T, D>): Double = KEStatistics.mean(a)

actual override fun <T : Number, D : Dimension, O : Dimension> mean(a: MultiArray<T, D>, axis: Int): NDArray<Double, O> =
KEStatistics.mean(a, axis)

actual override fun <T : Number> meanD2(a: MultiArray<T, D2>, axis: Int): NDArray<Double, D1> = mean(a, axis)

actual override fun <T : Number> meanD3(a: MultiArray<T, D3>, axis: Int): NDArray<Double, D2> = mean(a, axis)

actual override fun <T : Number> meanD4(a: MultiArray<T, D4>, axis: Int): NDArray<Double, D3> = mean(a, axis)

actual override fun <T : Number> meanDN(a: MultiArray<T, DN>, axis: Int): NDArray<Double, D4> = mean(a, axis)
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
/*
* Copyright 2020-2022 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
*/

package org.jetbrains.kotlinx.multik.default.stat

import org.jetbrains.kotlinx.multik.api.stat.Statistics
import org.jetbrains.kotlinx.multik.kotlin.stat.KEStatistics
import org.jetbrains.kotlinx.multik.ndarray.data.*
import org.jetbrains.kotlinx.multik.openblas.stat.NativeStatistics

public actual object DefaultStatistics : Statistics {

actual override fun <T : Number, D : Dimension> median(a: MultiArray<T, D>): Double? = NativeStatistics.median(a)

actual override fun <T : Number, D : Dimension> average(a: MultiArray<T, D>, weights: MultiArray<T, D>?): Double =
KEStatistics.average(a, weights)

actual override fun <T : Number, D : Dimension> mean(a: MultiArray<T, D>): Double = KEStatistics.mean(a)

actual override fun <T : Number, D : Dimension, O : Dimension> mean(a: MultiArray<T, D>, axis: Int): NDArray<Double, O> =
KEStatistics.mean(a, axis)

actual override fun <T : Number> meanD2(a: MultiArray<T, D2>, axis: Int): NDArray<Double, D1> = mean(a, axis)

actual override fun <T : Number> meanD3(a: MultiArray<T, D3>, axis: Int): NDArray<Double, D2> = mean(a, axis)

actual override fun <T : Number> meanD4(a: MultiArray<T, D4>, axis: Int): NDArray<Double, D3> = mean(a, axis)

actual override fun <T : Number> meanDN(a: MultiArray<T, DN>, axis: Int): NDArray<Double, D4> = mean(a, axis)
}
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
/*
* Copyright 2020-2021 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
* Copyright 2020-2022 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
*/

package org.jetbrains.kotlinx.multik.kotlin

import org.jetbrains.kotlinx.multik.api.Engine
import org.jetbrains.kotlinx.multik.api.EngineType
import org.jetbrains.kotlinx.multik.api.KEEngineType
import org.jetbrains.kotlinx.multik.api.Statistics
import org.jetbrains.kotlinx.multik.api.linalg.LinAlg
import org.jetbrains.kotlinx.multik.api.math.Math
import org.jetbrains.kotlinx.multik.api.stat.Statistics
import org.jetbrains.kotlinx.multik.kotlin.linalg.KELinAlg
import org.jetbrains.kotlinx.multik.kotlin.math.KEMath
import org.jetbrains.kotlinx.multik.kotlin.stat.KEStatistics


public class KEEngine : Engine() {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
/*
* Copyright 2020-2021 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
* Copyright 2020-2022 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
*/

package org.jetbrains.kotlinx.multik.kotlin
package org.jetbrains.kotlinx.multik.kotlin.stat

import org.jetbrains.kotlinx.multik.api.Statistics
import org.jetbrains.kotlinx.multik.api.mk
import org.jetbrains.kotlinx.multik.api.stat.Statistics
import org.jetbrains.kotlinx.multik.kotlin.math.KEMath
import org.jetbrains.kotlinx.multik.kotlin.math.remove
import org.jetbrains.kotlinx.multik.ndarray.data.*
Expand Down Expand Up @@ -34,7 +33,7 @@ public object KEStatistics : Statistics {

override fun <T : Number, D : Dimension> average(a: MultiArray<T, D>, weights: MultiArray<T, D>?): Double {
if (weights == null) return mean(a)
return mk.math.sum(a * weights).toDouble() / mk.math.sum(weights).toDouble()
return KEMath.sum(a * weights).toDouble() / KEMath.sum(weights).toDouble()
}

override fun <T : Number, D : Dimension> mean(a: MultiArray<T, D>): Double {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
/*
* Copyright 2020-2021 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
* Copyright 2020-2022 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
*/

package org.jetbrains.kotlinx.multik_kotlin.statistics

import org.jetbrains.kotlinx.multik.api.arange
import org.jetbrains.kotlinx.multik.api.mk
import org.jetbrains.kotlinx.multik.api.ndarray
import org.jetbrains.kotlinx.multik.kotlin.KEStatistics
import org.jetbrains.kotlinx.multik.kotlin.stat.KEStatistics
import kotlin.test.Test
import kotlin.test.assertEquals

Expand All @@ -16,7 +16,7 @@ class KEStatisticsTest {
@Test
fun test_median() {
val a = mk.ndarray(mk[mk[10, 7, 4], mk[3, 2, 1]])
println(mk.stat.median(a))
assertEquals(3.5, mk.stat.median(a))
}

@Test
Expand Down
3 changes: 2 additions & 1 deletion multik-openblas/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ kotlin {
val cinteropDir = "${projectDir}/cinterop"
val headersDir = "${projectDir}/multik_jni/src/main/headers/"
val cppDir = "${projectDir}/multik_jni/src/main/cpp"
headers("$headersDir/mk_math.h", "$headersDir/mk_linalg.h")
headers("$headersDir/mk_math.h", "$headersDir/mk_linalg.h", "$headersDir/mk_stat.h")
defFile(project.file(("$cinteropDir/libmultik.def")))

when (konanTarget.family) {
Expand All @@ -87,6 +87,7 @@ kotlin {
extraOpts("-Xsource-compiler-option", "-I${buildDir}/cmake-build/openblas-install/include")
extraOpts("-Xcompile-source", "$cppDir/mk_math.cpp")
extraOpts("-Xcompile-source", "$cppDir/mk_linalg.cpp")
extraOpts("-Xcompile-source", "$cppDir/mk_stat.cpp")
}
}
}
Expand Down
5 changes: 4 additions & 1 deletion multik-openblas/multik_jni/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -114,14 +114,17 @@ include_directories("src/main/headers")
set(SRC_FILES_FOR_SHARED
src/main/cpp/jni_Linalg.cpp
src/main/cpp/jni_JniMath.cpp
src/main/cpp/JniStat.cpp
src/main/cpp/ComplexDouble.cpp
src/main/cpp/ComplexFloat.cpp
src/main/cpp/mk_math.cpp
src/main/cpp/mk_linalg.cpp
src/main/cpp/mk_stat.cpp
)
set(SRC_FILES_FOR_STATIC
src/main/cpp/mk_math.cpp
src/main/cpp/mk_linalg.cpp)
src/main/cpp/mk_linalg.cpp
src/main/cpp/mk_stat.cpp)
########################

### FIND GFORTRAN AND QUADMATH LIBRARIES ###
Expand Down
23 changes: 23 additions & 0 deletions multik-openblas/multik_jni/src/main/cpp/JniStat.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
/*
* Copyright 2020-2022 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
*/

#include "JniStat.h"
#include "mk_stat.h"

/*
* Class: org_jetbrains_kotlinx_multik_openblas_stat_JniStat
* Method: median
* Signature: (Ljava/lang/Object;II)D
*/
JNIEXPORT jdouble JNICALL Java_org_jetbrains_kotlinx_multik_openblas_stat_JniStat_median
(JNIEnv *env, jobject jobj, jobject jarr, jint size, jint type) {

void *varr = env->GetPrimitiveArrayCritical((jarray)jarr, nullptr);

double ret = array_median(varr, size, type);

env->ReleasePrimitiveArrayCritical((jarray)jarr, varr, 0);

return ret;
}
3 changes: 1 addition & 2 deletions multik-openblas/multik_jni/src/main/cpp/jni_JniMath.cpp
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
/*
* Copyright 2020-2021 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
* Copyright 2020-2022 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
*/

#include <map>
#include "jni_JniMath.h"
#include "mk_math.h"

Expand Down
Loading

0 comments on commit e5d38b8

Please sign in to comment.