Skip to content

Commit 48db146

Browse files
committed
added more sum tests
1 parent 73ceaa8 commit 48db146

File tree

1 file changed

+76
-0
lines changed
  • core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/statistics

1 file changed

+76
-0
lines changed

core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/statistics/sum.kt

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,18 @@
11
package org.jetbrains.kotlinx.dataframe.statistics
22

33
import io.kotest.assertions.throwables.shouldThrow
4+
import io.kotest.matchers.booleans.shouldBeTrue
5+
import io.kotest.matchers.doubles.shouldBeNaN
6+
import io.kotest.matchers.floats.shouldBeNaN
47
import io.kotest.matchers.shouldBe
58
import io.kotest.matchers.string.shouldContain
69
import org.jetbrains.kotlinx.dataframe.api.columnOf
710
import org.jetbrains.kotlinx.dataframe.api.dataFrameOf
811
import org.jetbrains.kotlinx.dataframe.api.isEmpty
912
import org.jetbrains.kotlinx.dataframe.api.rowSum
13+
import org.jetbrains.kotlinx.dataframe.api.rowSumOf
1014
import org.jetbrains.kotlinx.dataframe.api.sum
15+
import org.jetbrains.kotlinx.dataframe.api.sumFor
1116
import org.jetbrains.kotlinx.dataframe.api.sumOf
1217
import org.jetbrains.kotlinx.dataframe.api.toDataFrame
1318
import org.junit.Test
@@ -106,4 +111,75 @@ class SumTests {
106111
columnOf<Number>(1.0, 2, 3.0.toBigDecimal()).toDataFrame().sum()[0]
107112
}.message?.lowercase() shouldContain "primitive"
108113
}
114+
115+
@Test
116+
fun `test skipNaN with float column`() {
117+
val value by columnOf(1.0f, 2.0f, Float.NaN, 3.0f)
118+
val df = dataFrameOf(value)
119+
120+
// With skipNaN = true (default is false)
121+
value.sum(skipNaN = true) shouldBe 6.0f
122+
df[value].sum(skipNaN = true) shouldBe 6.0f
123+
df.sum(skipNaN = true)[value] shouldBe 6.0f
124+
df.sumOf(skipNaN = true) { value().toInt() } shouldBe 6
125+
126+
// With skipNaN = false (default)
127+
value.sum().shouldBeNaN()
128+
df[value].sum().shouldBeNaN()
129+
df.sum()[value].shouldBeNaN()
130+
df.sumOf { value().toDouble() }.shouldBeNaN()
131+
}
132+
133+
@Test
134+
fun `test skipNaN with double column`() {
135+
val value by columnOf(1.0, 2.0, Double.NaN, 3.0)
136+
val df = dataFrameOf(value)
137+
138+
// With skipNaN = true (default is false)
139+
value.sum(skipNaN = true) shouldBe 6.0
140+
df[value].sum(skipNaN = true) shouldBe 6.0
141+
df.sum(skipNaN = true)[value] shouldBe 6.0
142+
df.sumOf(skipNaN = true) { value().toLong() } shouldBe 6L
143+
144+
// With skipNaN = false (default)
145+
value.sum().shouldBeNaN()
146+
df[value].sum().shouldBeNaN()
147+
df.sum()[value].shouldBeNaN()
148+
df.sumOf { value().toFloat() }.shouldBeNaN()
149+
}
150+
151+
@Test
152+
fun `test rowSum with skipNaN`() {
153+
val row1 = dataFrameOf("a", "b", "c")(1.0, 2.0, 3.0)[0]
154+
val row2 = dataFrameOf("a", "b", "c")(1.0, Double.NaN, 3)[0]
155+
156+
// With skipNaN = true
157+
row1.rowSum(skipNaN = true) shouldBe 6.0
158+
row2.rowSum(skipNaN = true) shouldBe 4.0
159+
160+
// With skipNaN = false (default)
161+
row1.rowSum() shouldBe 6.0
162+
(row2.rowSum() as Double).isNaN().shouldBeTrue()
163+
164+
// Test rowSumOf
165+
row1.rowSumOf<Double?>(skipNaN = true) shouldBe 6.0
166+
row2.rowSumOf<Double?>(skipNaN = true) shouldBe 1.0
167+
row1.rowSumOf<Double?>() shouldBe 6.0
168+
row2.rowSumOf<Double?>().shouldBeNaN()
169+
}
170+
171+
@Test
172+
fun `test sumFor with skipNaN`() {
173+
val value1 by columnOf(1.0, 2.0, 3.0)
174+
val value2 by columnOf<Number>(4.0, Float.NaN, 6)
175+
val df = dataFrameOf(value1, value2)
176+
177+
// With skipNaN = true
178+
df.sumFor(skipNaN = true) { value1 and value2 }[value1] shouldBe 6.0
179+
df.sumFor(skipNaN = true) { value1 and value2 }[value2] shouldBe 10.0
180+
181+
// With skipNaN = false (default)
182+
df.sumFor { value1 and value2 }[value1] shouldBe 6.0
183+
(df.sumFor { value1 and value2 }[value2] as Double).shouldBeNaN()
184+
}
109185
}

0 commit comments

Comments
 (0)