Skip to content

Commit 80de4f2

Browse files
committed
typelevel#787 - attempt covar_pop and kurtosis through tolerances
1 parent 66b31e9 commit 80de4f2

File tree

2 files changed

+84
-6
lines changed

2 files changed

+84
-6
lines changed

dataset/src/test/scala/frameless/functions/AggregateFunctionsTests.scala

+24-6
Original file line numberDiff line numberDiff line change
@@ -466,7 +466,10 @@ class AggregateFunctionsTests extends TypedDatasetSuite {
466466
TypedColumn[X3[Int, A, B], A],
467467
TypedColumn[X3[Int, A, B], B]
468468
) => TypedAggregate[X3[Int, A, B], Option[Double]],
469-
sparkFun: (Column, Column) => Column
469+
sparkFun: (Column, Column) => Column,
470+
fudger: Tuple2[Option[BigDecimal], Option[BigDecimal]] => Tuple2[Option[
471+
BigDecimal
472+
], Option[BigDecimal]] = identity
470473
)(implicit
471474
encEv: Encoder[(Int, A, B)],
472475
encEv2: Encoder[(Int, Option[Double])],
@@ -496,7 +499,12 @@ class AggregateFunctionsTests extends TypedDatasetSuite {
496499
})
497500

498501
// Should be the same
499-
tdBivar.toMap ?= compBivar.collect().toMap
502+
// tdBivar.toMap ?= compBivar.collect().toMap
503+
DoubleBehaviourUtils.compareMaps(
504+
tdBivar.toMap,
505+
compBivar.collect().toMap,
506+
fudger
507+
)
500508
}
501509

502510
def univariatePropTemplate[A: TypedEncoder](
@@ -505,7 +513,10 @@ class AggregateFunctionsTests extends TypedDatasetSuite {
505513
X2[Int, A],
506514
Option[Double]
507515
],
508-
sparkFun: (Column) => Column
516+
sparkFun: (Column) => Column,
517+
fudger: Tuple2[Option[BigDecimal], Option[BigDecimal]] => Tuple2[Option[
518+
BigDecimal
519+
], Option[BigDecimal]] = identity
509520
)(implicit
510521
encEv: Encoder[(Int, A)],
511522
encEv2: Encoder[(Int, Option[Double])],
@@ -534,7 +545,12 @@ class AggregateFunctionsTests extends TypedDatasetSuite {
534545
})
535546

536547
// Should be the same
537-
tdUnivar.toMap ?= compUnivar.collect().toMap
548+
// tdUnivar.toMap ?= compUnivar.collect().toMap
549+
DoubleBehaviourUtils.compareMaps(
550+
tdUnivar.toMap,
551+
compUnivar.collect().toMap,
552+
fudger
553+
)
538554
}
539555

540556
test("corr") {
@@ -571,7 +587,8 @@ class AggregateFunctionsTests extends TypedDatasetSuite {
571587
evCanBeDoubleB: CatalystCast[B, Double]
572588
): Prop = bivariatePropTemplate(xs)(
573589
covarPop[A, B, X3[Int, A, B]],
574-
org.apache.spark.sql.functions.covar_pop
590+
org.apache.spark.sql.functions.covar_pop,
591+
fudger = DoubleBehaviourUtils.tolerance(_, BigDecimal("100"))
575592
)
576593

577594
check(forAll(prop[Double, Double] _))
@@ -614,7 +631,8 @@ class AggregateFunctionsTests extends TypedDatasetSuite {
614631
evCanBeDoubleA: CatalystCast[A, Double]
615632
): Prop = univariatePropTemplate(xs)(
616633
kurtosis[A, X2[Int, A]],
617-
org.apache.spark.sql.functions.kurtosis
634+
org.apache.spark.sql.functions.kurtosis,
635+
fudger = DoubleBehaviourUtils.tolerance(_, BigDecimal("0.1"))
618636
)
619637

620638
check(forAll(prop[Double] _))

dataset/src/test/scala/frameless/functions/DoubleBehaviourUtils.scala

+60
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
package frameless
22
package functions
33

4+
import org.scalacheck.Prop
5+
import org.scalacheck.util.Pretty
6+
47
/**
58
* Some statistical functions in Spark can result in Double, Double.NaN or Null.
69
* This tends to break ?= of the property based testing. Use the nanNullHandler function
@@ -37,6 +40,63 @@ object DoubleBehaviourUtils {
3740
else
3841
BigDecimal.RoundingMode.CEILING
3942
)
43+
44+
/**
 * Builds a property that checks two maps agree entry by entry, after each
 * pair of values has been passed through `fudger`.
 *
 * The property is falsified (with a descriptive label) when:
 *  - the maps have different sizes,
 *  - a key of `m1` is absent from `m2`, or
 *  - the fudged values for a key are not equal.
 *
 * @param m1     the left-hand ("expected") map
 * @param m2     the right-hand ("actual") map
 * @param fudger adjustment applied to every (left, right) value pair before
 *               the equality check, e.g. to absorb small numeric differences
 * @tparam K     the key type shared by both maps
 * @return a `Prop` that holds only when every key of `m1` matches in `m2`
 */
def compareMaps[K](
    m1: Map[K, Option[BigDecimal]],
    m2: Map[K, Option[BigDecimal]],
    fudger: Tuple2[Option[BigDecimal], Option[BigDecimal]] => Tuple2[Option[
      BigDecimal
    ], Option[BigDecimal]]
): Prop = {
  // Renders a value for use in a failure label.
  def render(value: Option[BigDecimal]): String =
    Pretty.pretty[Option[BigDecimal]](value, Pretty.Params(0))

  // Checks a single key taken from m1 against m2.
  def compareKey(k: K): Prop =
    m2.get(k) match {
      case None =>
        val expKey = Pretty.pretty[K](k, Pretty.Params(0))
        Prop.falsified :| ("Expected key of " + expKey + " in right side map")

      case Some(rightRaw) =>
        // k always comes from m1.keys, so the lookup in m1 cannot miss.
        val (left, right) = fudger((m1(k), rightRaw))
        if (left != right) {
          val expKey = Pretty.pretty[K](k, Pretty.Params(0))
          val leftVal = render(left)
          val rightVal = render(right)
          Prop.falsified :| ("For key of " + expKey + " expected " + leftVal + " got " + rightVal)
        } else
          Prop.proved
    }

  if (m1.size == m2.size)
    // Same size, so checking every key of m1 covers every key of m2 as well.
    m1.keys.foldLeft(Prop.passed)((acc, key) => acc && compareKey(key))
  else
    Prop.falsified :| ("Expected map of size " + m1.size + " but got " + m2.size)
}
82+
83+
/**
 * Running covar_pop and kurtosis multiple times gives slightly different
 * results, so this treats two values that are "close enough" as equal.
 *
 * When both sides are defined and differ by strictly less than `of`, the
 * right-hand value is replaced by the left-hand one so that a plain equality
 * check succeeds. In every other case (either side empty, or the difference
 * is `of` or more) the pair is returned untouched.
 *
 * @param p  the (left, right) pair of optional values to compare
 * @param of the exclusive upper bound on the absolute difference to tolerate
 * @return `(left, left)` when the values are within tolerance, otherwise `p`
 */
def tolerance(
    p: Tuple2[Option[BigDecimal], Option[BigDecimal]],
    of: BigDecimal
): Tuple2[Option[BigDecimal], Option[BigDecimal]] =
  p match {
    // larger minus smaller is never negative, so it already is the absolute gap
    case (Some(left), Some(right)) if (left.max(right) - left.min(right)) < of =>
      (Some(left), Some(left))
    case _ =>
      p
  }
}
40100
}
41101

42102
/** drop in conversion for doubles to handle serialization on cluster */

0 commit comments

Comments
 (0)