Skip to content

Commit

Permalink
check shapes in arithmetic operations
Browse files Browse the repository at this point in the history
fixes #59
  • Loading branch information
plastic-karma authored and devcrocod committed Nov 9, 2021
1 parent e05ec25 commit 91b9aa3
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,11 @@ internal inline fun requireArraySizes(rightSize: Int, otherSize: Int) {
require(rightSize == otherSize) { "Array sizes don't match: (right operand size) $rightSize != $otherSize (left operand size)" }
}

@Suppress("NOTHING_TO_INLINE")
internal inline fun requireEqualShape(left: IntArray, right: IntArray) {
require(left.contentEquals(right)) { "Array shapes don't match: ${left.contentToString()} != ${right.contentToString()}" }
}

@Suppress("NOTHING_TO_INLINE")
internal inline fun requirePositiveShape(dim: Int) {
require(dim > 0) { "Shape must be positive but was $dim." }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ public fun <T, D : Dimension> MultiArray<T, D>.isTransposed(): Boolean {

// TODO: boolean array
public infix fun <T : Number, D : Dimension> MultiArray<T, D>.and(other: MultiArray<T, D>): NDArray<Int, D> {
requireArraySizes(this.size, other.size)
requireEqualShape(this.shape, other.shape)

val ret = mk.zeros<Int, D>(this.shape, DataType.IntDataType)
val lIter = this.iterator()
Expand Down Expand Up @@ -61,7 +61,7 @@ public infix fun <T : Number, D : Dimension> MultiArray<T, D>.and(other: MultiAr
}

public infix fun <T : Number, D : Dimension> MultiArray<T, D>.or(other: MultiArray<T, D>): NDArray<Int, D> {
requireArraySizes(this.size, other.size)
requireEqualShape(this.shape, other.shape)

val ret = mk.zeros<Int, D>(this.shape, DataType.IntDataType)
val lIter = this.iterator()
Expand Down Expand Up @@ -796,7 +796,7 @@ public inline fun <T, D : Dimension, reified R : Any> MultiArray<T, D>.map(trans
* Returns the element-wise minimum of array elements for [this] and [other].
*/
public fun <T: Number, D : Dimension> MultiArray<T, D>.minimum(other: MultiArray<T, D>): NDArray<T, D> {
requireArraySizes(this.size, other.size)
requireEqualShape(this.shape, other.shape)
val ret = (this as NDArray).deepCopy()
when (dtype) {
DataType.DoubleDataType -> (ret as NDArray<Double, D>).commonAssignOp(other.iterator() as Iterator<Double>) { a, b -> min(a, b) }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ public operator fun <T, D : Dimension> MultiArray<T, D>.unaryMinus(): NDArray<T,
* Create a new array as the sum of [this] and [other].
*/
public operator fun <T, D : Dimension> MultiArray<T, D>.plus(other: MultiArray<T, D>): NDArray<T, D> {
requireArraySizes(this.size, other.size)
requireEqualShape(this.shape, other.shape)
val ret = if (this.consistent) (this as NDArray).copy() else (this as NDArray).deepCopy()
ret += other
return ret
Expand All @@ -40,7 +40,7 @@ public operator fun <T, D : Dimension> MultiArray<T, D>.plus(other: T): NDArray<
* Add [other] to [this]. Inplace operator.
*/
public operator fun <T, D : Dimension> MutableMultiArray<T, D>.plusAssign(other: MultiArray<T, D>) {
requireArraySizes(this.size, other.size)
requireEqualShape(this.shape, other.shape)
if (this.consistent && other.consistent) {
this.data += (other.data as MemoryView)
} else {
Expand Down Expand Up @@ -82,7 +82,7 @@ public operator fun <T, D : Dimension> MutableMultiArray<T, D>.plusAssign(other:
* Create a new array as difference between [this] and [other].
*/
public operator fun <T, D : Dimension> MultiArray<T, D>.minus(other: MultiArray<T, D>): NDArray<T, D> {
requireArraySizes(this.size, other.size)
requireEqualShape(this.shape, other.shape)
val ret = if (this.consistent) (this as NDArray).copy() else (this as NDArray).deepCopy()
ret -= other
return ret
Expand All @@ -98,7 +98,7 @@ public operator fun <T, D : Dimension> MultiArray<T, D>.minus(other: T): NDArray
* Subtract [other] from [this]. Inplace operator.
*/
public operator fun <T, D : Dimension> MutableMultiArray<T, D>.minusAssign(other: MultiArray<T, D>) {
requireArraySizes(this.size, other.size)
requireEqualShape(this.shape, other.shape)
if (this.consistent && other.consistent) {
this.data -= (other.data as MemoryView)
} else {
Expand Down Expand Up @@ -139,7 +139,7 @@ public operator fun <T, D : Dimension> MutableMultiArray<T, D>.minusAssign(other
* Create a new array as product of [this] and [other].
*/
public operator fun <T, D : Dimension> MultiArray<T, D>.times(other: MultiArray<T, D>): NDArray<T, D> {
requireArraySizes(this.size, other.size)
requireEqualShape(this.shape, other.shape)
val ret = if (this.consistent) (this as NDArray).copy() else (this as NDArray).deepCopy()
ret *= other
return ret
Expand All @@ -155,7 +155,7 @@ public operator fun <T, D : Dimension> MultiArray<T, D>.times(other: T): NDArray
* Multiply [this] by [other]. Inplace operator.
*/
public operator fun <T, D : Dimension> MutableMultiArray<T, D>.timesAssign(other: MultiArray<T, D>) {
requireArraySizes(this.size, other.size)
requireEqualShape(this.shape, other.shape)
if (this.consistent && other.consistent) {
this.data *= (other.data as MemoryView)
} else {
Expand Down Expand Up @@ -196,7 +196,7 @@ public operator fun <T, D : Dimension> MutableMultiArray<T, D>.timesAssign(other
* Create a new array as division of [this] by [other].
*/
public operator fun <T, D : Dimension> MultiArray<T, D>.div(other: MultiArray<T, D>): NDArray<T, D> {
requireArraySizes(this.size, other.size)
requireEqualShape(this.shape, other.shape)
val ret = if (this.consistent) (this as NDArray).copy() else (this as NDArray).deepCopy()
ret /= other
return ret
Expand All @@ -212,7 +212,7 @@ public operator fun <T, D : Dimension> MultiArray<T, D>.div(other: T): NDArray<T
* Divide [this] by [other]. Inplace operator.
*/
public operator fun <T, D : Dimension> MutableMultiArray<T, D>.divAssign(other: MultiArray<T, D>) {
requireArraySizes(this.size, other.size)
requireEqualShape(this.shape, other.shape)
if (this.consistent && other.consistent) {
this.data /= (other.data as MemoryView)
} else {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
package org.jetbrains.kotlinx.multik.ndarray.data

import org.jetbrains.kotlinx.multik.api.empty
import org.jetbrains.kotlinx.multik.api.mk
import org.jetbrains.kotlinx.multik.api.zeros
import kotlin.test.Test
import kotlin.test.assertTrue
import kotlin.test.fail

class InternalsTest {

@Test
fun `require equal shape throws exception for unequal shape`() {
val left = mk.zeros<Double>(0, 1, 2, 3)
val right = mk.zeros<Double>(0, 1, 2, 4)
expectUnEqualShape(left, right)
}

@Test
fun `require equal shape throws exception for different no of dim`() {
val left = mk.zeros<Double>(0, 1, 2, 3)
val right = mk.zeros<Double>(0, 1, 2)
expectUnEqualShape(left, right)
}

@Test
fun `require equal shape succeeds for arrays with equal shapes`() {
val left = mk.zeros<Double>(0, 1, 2, 3)
val right = mk.zeros<Double>(0, 1, 2, 3)
requireEqualShape(left.shape, right.shape)
}

@Test
fun `require equal shape succeeds empty arrays`() {
val left = mk.zeros<Double>(0)
val right = mk.zeros<Double>(0)
assertTrue(left.isEmpty())
assertTrue(right.isEmpty())
requireEqualShape(left.shape, right.shape)
}

private fun expectUnEqualShape(left: NDArray<Double, *>, right: NDArray<Double, *>) {
try {
requireEqualShape(left.shape, right.shape)
fail("Exception expected")
} catch (e: IllegalArgumentException) { }
}
}

0 comments on commit 91b9aa3

Please sign in to comment.