Skip to content

Commit

Permalink
Named buffers and named module
Browse files Browse the repository at this point in the history
  • Loading branch information
altavir committed Dec 14, 2024
1 parent ab16bd1 commit c475c43
Show file tree
Hide file tree
Showing 19 changed files with 84 additions and 19 deletions.
Empty file removed .space/CODEOWNERS
Empty file.
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import kotlinx.html.br
import kotlinx.html.h3
import space.kscience.kmath.commons.optimization.CMOptimizer
import space.kscience.kmath.distributions.NormalDistribution
import space.kscience.kmath.expressions.DifferentiableExpression
import space.kscience.kmath.expressions.autodiff
import space.kscience.kmath.expressions.symbol
import space.kscience.kmath.operations.asIterable
Expand All @@ -19,6 +20,7 @@ import space.kscience.kmath.real.DoubleVector
import space.kscience.kmath.real.map
import space.kscience.kmath.real.step
import space.kscience.kmath.stat.chiSquaredExpression
import space.kscience.kmath.structures.Float64
import space.kscience.plotly.*
import space.kscience.plotly.models.ScatterMode
import space.kscience.plotly.models.TraceValues
Expand Down Expand Up @@ -64,7 +66,7 @@ suspend fun main() {
val yErr = y.map { sqrt(it) }//RealVector.same(x.size, sigma)

// compute differentiable chi^2 sum for given model ax^2 + bx + c
val chi2 = Double.autodiff.chiSquaredExpression(x, y, yErr) { arg ->
val chi2: DifferentiableExpression<Float64> = Double.autodiff.chiSquaredExpression(x, y, yErr) { arg ->
//bind variables to autodiff context
val a = bindSymbol(a)
val b = bindSymbol(b)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,13 @@ package space.kscience.kmath.wasm.internal

import space.kscience.kmath.UnstableKMathAPI
import space.kscience.kmath.ast.TypedMst
import space.kscience.kmath.expressions.*
import space.kscience.kmath.expressions.DoubleExpression
import space.kscience.kmath.expressions.Expression
import space.kscience.kmath.expressions.IntExpression
import space.kscience.kmath.expressions.Symbol
import space.kscience.kmath.internal.binaryen.*
import space.kscience.kmath.internal.webassembly.Instance
import space.kscience.kmath.named.SimpleSymbolIndexer
import space.kscience.kmath.operations.*
import space.kscience.kmath.structures.Float64
import space.kscience.kmath.internal.binaryen.Module as BinaryenModule
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import space.kscience.kmath.UnstableKMathAPI
import space.kscience.kmath.ast.TypedMst
import space.kscience.kmath.ast.evaluateConstants
import space.kscience.kmath.expressions.*
import space.kscience.kmath.named.SimpleSymbolIndexer
import space.kscience.kmath.operations.Float64Field
import space.kscience.kmath.operations.Int32Ring
import space.kscience.kmath.structures.Float64
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import space.kscience.kmath.asm.internal.*
import space.kscience.kmath.ast.TypedMst
import space.kscience.kmath.ast.evaluateConstants
import space.kscience.kmath.expressions.*
import space.kscience.kmath.named.SimpleSymbolIndexer
import space.kscience.kmath.operations.Algebra
import space.kscience.kmath.operations.Float64Field
import space.kscience.kmath.operations.Int32Ring
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ import org.objectweb.asm.commons.InstructionAdapter
import space.kscience.kmath.UnstableKMathAPI
import space.kscience.kmath.ast.TypedMst
import space.kscience.kmath.expressions.*
import space.kscience.kmath.named.SimpleSymbolIndexer
import space.kscience.kmath.named.SymbolIndexer
import space.kscience.kmath.operations.*
import space.kscience.kmath.structures.Float64
import java.lang.invoke.MethodHandles
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@ import space.kscience.attributes.AttributesBuilder
import space.kscience.attributes.SetAttribute
import space.kscience.kmath.UnstableKMathAPI
import space.kscience.kmath.expressions.Symbol
import space.kscience.kmath.expressions.SymbolIndexer
import space.kscience.kmath.expressions.derivative
import space.kscience.kmath.expressions.withSymbols
import space.kscience.kmath.named.SymbolIndexer
import space.kscience.kmath.named.withSymbols
import space.kscience.kmath.optimization.*
import space.kscience.kmath.structures.Float64
import kotlin.collections.set
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ public abstract class Domain1D<T : Comparable<T>>(public val range: ClosedRange<

@UnstableKMathAPI
public class DoubleDomain1D(
@Suppress("CanBeParameter") public val doubleRange: ClosedFloatingPointRange<Float64>,
public val doubleRange: ClosedFloatingPointRange<Float64>,
) : Domain1D<Float64>(doubleRange), Float64Domain {
override fun getLowerBound(num: Int): Double {
require(num == 0)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ package space.kscience.kmath.expressions

import space.kscience.attributes.SafeType
import space.kscience.kmath.UnstableKMathAPI
import space.kscience.kmath.named.SymbolIndexer
import space.kscience.kmath.operations.*
import space.kscience.kmath.structures.Buffer
import space.kscience.kmath.structures.MutableBuffer
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ public class DSCompiler<T, out A : Algebra<T>> internal constructor(
* If all orders are set to 0, then the 0<sup>th</sup> order derivative is returned, which is the value of the
* function.
*
* The indices of derivatives are between 0 and [size] &minus; 1. Their specific order is fixed for a given compiler, but
* The indices of derivatives are between 0 and [getSize] &minus; 1. Their specific order is fixed for a given compiler, but
* otherwise not publicly specified. There are however some simple cases which have guaranteed indices:
*
* * the index of 0<sup>th</sup> order derivative is always 0
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import space.kscience.attributes.SafeType
import space.kscience.attributes.WithType
import space.kscience.attributes.safeTypeOf
import space.kscience.kmath.UnstableKMathAPI
import space.kscience.kmath.named.SymbolIndexer
import space.kscience.kmath.operations.Algebra
import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.IntRing
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
/*
* Copyright 2018-2024 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/

@file:OptIn(UnstableKMathAPI::class)

package space.kscience.kmath.named

import space.kscience.kmath.UnstableKMathAPI
import space.kscience.kmath.expressions.Symbol
import space.kscience.kmath.structures.Buffer
import space.kscience.kmath.structures.MutableBuffer

public class NamedBuffer<T>(public val values: Buffer<T>, public val indexer: SymbolIndexer): Buffer<T> by values{
public operator fun get(symbol: Symbol): T = values[indexer.indexOf(symbol)]
}

public class NamedMutableBuffer<T>(public val values: MutableBuffer<T>, public val indexer: SymbolIndexer): MutableBuffer<T> by values{
public operator fun get(symbol: Symbol): T = values[indexer.indexOf(symbol)]
public operator fun set(symbol: Symbol, value: T) { values[indexer.indexOf(symbol)] = value }
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,24 @@

@file:OptIn(UnstableKMathAPI::class)

package space.kscience.kmath.expressions
package space.kscience.kmath.named

import space.kscience.kmath.PerformancePitfall
import space.kscience.kmath.UnstableKMathAPI
import space.kscience.kmath.expressions.Symbol
import space.kscience.kmath.linear.Matrix
import space.kscience.kmath.structures.getOrNull

/**
* A square matrix that could be accessed via column and row names.
*
* Multiple symbols could in theory reference the same columns or rows. Some columns could be not references at all.
*/
public class NamedMatrix<T>(public val values: Matrix<T>, public val indexer: SymbolIndexer) : Matrix<T> by values {
init {
require(values.rows.size == values.columns.size) { "Only square matrices could be named" }
}

public operator fun get(i: Symbol, j: Symbol): T = get(indexer.indexOf(i), indexer.indexOf(j))

public companion object {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/

package space.kscience.kmath.expressions
package space.kscience.kmath.named

import space.kscience.kmath.UnstableKMathAPI
import space.kscience.kmath.expressions.Symbol
import space.kscience.kmath.linear.Point
import space.kscience.kmath.nd.Structure2D
import space.kscience.kmath.structures.BufferFactory
Expand Down Expand Up @@ -69,6 +70,9 @@ public interface SymbolIndexer {
public fun Map<Symbol, Double>.toDoubleArray(): DoubleArray = DoubleArray(symbols.size) { getValue(symbols[it]) }
}

@UnstableKMathAPI
public val SymbolIndexer.size: Int get() = symbols.size

@UnstableKMathAPI
@JvmInline
public value class SimpleSymbolIndexer(override val symbols: List<Symbol>) : SymbolIndexer
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
package space.kscience.kmath.optimization

import space.kscience.attributes.*
import space.kscience.kmath.expressions.NamedMatrix
import space.kscience.kmath.expressions.Symbol
import space.kscience.kmath.misc.Loggable
import space.kscience.kmath.named.NamedMatrix

public interface OptimizationAttribute<T> : Attribute<T>

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,15 @@
package space.kscience.kmath.optimization

import space.kscience.kmath.UnstableKMathAPI
import space.kscience.kmath.expressions.*
import space.kscience.kmath.expressions.DifferentiableExpression
import space.kscience.kmath.expressions.Symbol
import space.kscience.kmath.expressions.derivative
import space.kscience.kmath.expressions.withDefaultArgs
import space.kscience.kmath.linear.*
import space.kscience.kmath.misc.log
import space.kscience.kmath.named.NamedMatrix
import space.kscience.kmath.named.SymbolIndexer
import space.kscience.kmath.named.named
import space.kscience.kmath.operations.Float64Field
import space.kscience.kmath.operations.Float64L2Norm
import space.kscience.kmath.operations.algebra
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
/*
* Copyright 2018-2024 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/

package space.kscience.kmath.distributions

public interface MultivariateNormalDistribution: NamedDistribution<Double> {
public companion object{

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,7 @@
package space.kscience.kmath.stat

import space.kscience.kmath.operations.*
import space.kscience.kmath.structures.Buffer
import space.kscience.kmath.structures.Float64
import space.kscience.kmath.structures.indices
import space.kscience.kmath.structures.*

/**
* Arithmetic mean
Expand Down Expand Up @@ -44,8 +42,8 @@ public class Mean<T>(

public companion object {
public fun evaluate(buffer: Buffer<Float64>): Double = Float64Field.mean.evaluateBlocking(buffer)
public fun evaluate(buffer: Buffer<Int>): Int = Int32Ring.mean.evaluateBlocking(buffer)
public fun evaluate(buffer: Buffer<Long>): Long = Int64Ring.mean.evaluateBlocking(buffer)
public fun evaluate(buffer: Buffer<Int32>): Int = Int32Ring.mean.evaluateBlocking(buffer)
public fun evaluate(buffer: Buffer<Int64>): Long = Int64Ring.mean.evaluateBlocking(buffer)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import kotlinx.coroutines.flow.first
import kotlinx.coroutines.flow.last
import kotlinx.coroutines.flow.take
import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.test.runTest
import space.kscience.kmath.operations.Float64Field
import space.kscience.kmath.random.RandomGenerator
import space.kscience.kmath.random.chain
Expand All @@ -27,21 +28,21 @@ internal class StatisticTest {
val chunked = data.chunked(1000)

@Test
fun singleBlockingMean() {
val first = runBlocking { chunked.first() }
fun singleBlockingMean() = runTest {
val first = chunked.first()
val res = Float64Field.mean(first)
assertEquals(0.5, res, 1e-1)
}

@Test
fun singleSuspendMean() = runBlocking {
fun singleSuspendMean() = runTest {
val first = runBlocking { chunked.first() }
val res = Float64Field.mean(first)
assertEquals(0.5, res, 1e-1)
}

@Test
fun parallelMean() = runBlocking {
fun parallelMean() = runTest {
val average = Float64Field.mean
.flow(chunked) //create a flow from evaluated results
.take(100) // Take 100 data chunks from the source and accumulate them
Expand Down

0 comments on commit c475c43

Please sign in to comment.