|
1 | 1 | package org.jetbrains.kotlinx.dataframe.statistics
|
2 | 2 |
|
3 | 3 | import io.kotest.assertions.throwables.shouldThrow
|
| 4 | +import io.kotest.matchers.booleans.shouldBeTrue |
| 5 | +import io.kotest.matchers.doubles.shouldBeNaN |
| 6 | +import io.kotest.matchers.floats.shouldBeNaN |
4 | 7 | import io.kotest.matchers.shouldBe
|
5 | 8 | import io.kotest.matchers.string.shouldContain
|
6 | 9 | import org.jetbrains.kotlinx.dataframe.api.columnOf
|
7 | 10 | import org.jetbrains.kotlinx.dataframe.api.dataFrameOf
|
8 | 11 | import org.jetbrains.kotlinx.dataframe.api.isEmpty
|
9 | 12 | import org.jetbrains.kotlinx.dataframe.api.rowSum
|
| 13 | +import org.jetbrains.kotlinx.dataframe.api.rowSumOf |
10 | 14 | import org.jetbrains.kotlinx.dataframe.api.sum
|
| 15 | +import org.jetbrains.kotlinx.dataframe.api.sumFor |
11 | 16 | import org.jetbrains.kotlinx.dataframe.api.sumOf
|
12 | 17 | import org.jetbrains.kotlinx.dataframe.api.toDataFrame
|
13 | 18 | import org.junit.Test
|
@@ -106,4 +111,75 @@ class SumTests {
|
106 | 111 | columnOf<Number>(1.0, 2, 3.0.toBigDecimal()).toDataFrame().sum()[0]
|
107 | 112 | }.message?.lowercase() shouldContain "primitive"
|
108 | 113 | }
|
| 114 | + |
| 115 | + @Test |
| 116 | + fun `test skipNaN with float column`() { |
| 117 | + val value by columnOf(1.0f, 2.0f, Float.NaN, 3.0f) |
| 118 | + val df = dataFrameOf(value) |
| 119 | + |
| 120 | + // With skipNaN = true (default is false) |
| 121 | + value.sum(skipNaN = true) shouldBe 6.0f |
| 122 | + df[value].sum(skipNaN = true) shouldBe 6.0f |
| 123 | + df.sum(skipNaN = true)[value] shouldBe 6.0f |
| 124 | + df.sumOf(skipNaN = true) { value().toInt() } shouldBe 6 |
| 125 | + |
| 126 | + // With skipNaN = false (default) |
| 127 | + value.sum().shouldBeNaN() |
| 128 | + df[value].sum().shouldBeNaN() |
| 129 | + df.sum()[value].shouldBeNaN() |
| 130 | + df.sumOf { value().toDouble() }.shouldBeNaN() |
| 131 | + } |
| 132 | + |
| 133 | + @Test |
| 134 | + fun `test skipNaN with double column`() { |
| 135 | + val value by columnOf(1.0, 2.0, Double.NaN, 3.0) |
| 136 | + val df = dataFrameOf(value) |
| 137 | + |
| 138 | + // With skipNaN = true (default is false) |
| 139 | + value.sum(skipNaN = true) shouldBe 6.0 |
| 140 | + df[value].sum(skipNaN = true) shouldBe 6.0 |
| 141 | + df.sum(skipNaN = true)[value] shouldBe 6.0 |
| 142 | + df.sumOf(skipNaN = true) { value().toLong() } shouldBe 6L |
| 143 | + |
| 144 | + // With skipNaN = false (default) |
| 145 | + value.sum().shouldBeNaN() |
| 146 | + df[value].sum().shouldBeNaN() |
| 147 | + df.sum()[value].shouldBeNaN() |
| 148 | + df.sumOf { value().toFloat() }.shouldBeNaN() |
| 149 | + } |
| 150 | + |
| 151 | + @Test |
| 152 | + fun `test rowSum with skipNaN`() { |
| 153 | + val row1 = dataFrameOf("a", "b", "c")(1.0, 2.0, 3.0)[0] |
| 154 | + val row2 = dataFrameOf("a", "b", "c")(1.0, Double.NaN, 3)[0] |
| 155 | + |
| 156 | + // With skipNaN = true |
| 157 | + row1.rowSum(skipNaN = true) shouldBe 6.0 |
| 158 | + row2.rowSum(skipNaN = true) shouldBe 4.0 |
| 159 | + |
| 160 | + // With skipNaN = false (default) |
| 161 | + row1.rowSum() shouldBe 6.0 |
| 162 | + (row2.rowSum() as Double).isNaN().shouldBeTrue() |
| 163 | + |
| 164 | + // Test rowSumOf |
| 165 | + row1.rowSumOf<Double?>(skipNaN = true) shouldBe 6.0 |
| 166 | + row2.rowSumOf<Double?>(skipNaN = true) shouldBe 1.0 |
| 167 | + row1.rowSumOf<Double?>() shouldBe 6.0 |
| 168 | + row2.rowSumOf<Double?>().shouldBeNaN() |
| 169 | + } |
| 170 | + |
| 171 | + @Test |
| 172 | + fun `test sumFor with skipNaN`() { |
| 173 | + val value1 by columnOf(1.0, 2.0, 3.0) |
| 174 | + val value2 by columnOf<Number>(4.0, Float.NaN, 6) |
| 175 | + val df = dataFrameOf(value1, value2) |
| 176 | + |
| 177 | + // With skipNaN = true |
| 178 | + df.sumFor(skipNaN = true) { value1 and value2 }[value1] shouldBe 6.0 |
| 179 | + df.sumFor(skipNaN = true) { value1 and value2 }[value2] shouldBe 10.0 |
| 180 | + |
| 181 | + // With skipNaN = false (default) |
| 182 | + df.sumFor { value1 and value2 }[value1] shouldBe 6.0 |
| 183 | + (df.sumFor { value1 and value2 }[value2] as Double).shouldBeNaN() |
| 184 | + } |
109 | 185 | }
|
0 commit comments