Skip to content

Commit a89542e

Browse files
committed
typelevel#787 - tolerance on map members and on vectors for cluster runs
1 parent 80de4f2 commit a89542e

File tree

4 files changed

+131
-23
lines changed

4 files changed

+131
-23
lines changed

dataset/src/test/scala/frameless/functions/AggregateFunctionsTests.scala

+2-1
Original file line numberDiff line numberDiff line change
@@ -610,7 +610,8 @@ class AggregateFunctionsTests extends TypedDatasetSuite {
610610
evCanBeDoubleB: CatalystCast[B, Double]
611611
): Prop = bivariatePropTemplate(xs)(
612612
covarSamp[A, B, X3[Int, A, B]],
613-
org.apache.spark.sql.functions.covar_samp
613+
org.apache.spark.sql.functions.covar_samp,
614+
fudger = DoubleBehaviourUtils.tolerance(_, BigDecimal("10"))
614615
)
615616

616617
check(forAll(prop[Double, Double] _))

dataset/src/test/scala/frameless/functions/DoubleBehaviourUtils.scala

+57
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@ package frameless
22
package functions
33

44
import org.scalacheck.Prop
5+
import org.scalacheck.Prop.AnyOperators
56
import org.scalacheck.util.Pretty
7+
import shapeless.{ Lens, OpticDefns }
68

79
/**
810
* Some statistical functions in Spark can result in Double, Double.NaN or Null.
@@ -14,6 +16,8 @@ import org.scalacheck.util.Pretty
1416
*/
1517
object DoubleBehaviourUtils {
1618

19+
val dp5 = BigDecimal(0.00001)
20+
1721
// Mapping with this function is needed because spark uses Double.NaN for some semantics in the
1822
// correlation function. ?= for prop testing will use == underlying and will break because Double.NaN != Double.NaN
1923
private val nanHandler: Double => Option[Double] = value =>
@@ -41,6 +45,45 @@ object DoubleBehaviourUtils {
4145
BigDecimal.RoundingMode.CEILING
4246
)
4347

48+
import shapeless._
49+
50+
def tolerantCompareVectors[K, CC[X] <: Seq[X]](
51+
v1: CC[K],
52+
v2: CC[K],
53+
of: BigDecimal
54+
)(fudgers: Seq[OpticDefns.RootLens[K] => Lens[K, Option[BigDecimal]]]
55+
): Prop = compareVectors(v1, v2)(fudgers.map(f => (f, tolerance(_, of))))
56+
57+
def compareVectors[K, CC[X] <: Seq[X]](
58+
v1: CC[K],
59+
v2: CC[K]
60+
)(fudgers: Seq[
61+
(OpticDefns.RootLens[K] => Lens[K, Option[BigDecimal]],
62+
Tuple2[Option[BigDecimal], Option[BigDecimal]] => Tuple2[Option[
63+
BigDecimal
64+
], Option[BigDecimal]]
65+
)
66+
]
67+
): Prop =
68+
if (v1.size != v2.size)
69+
Prop.falsified :| {
70+
"Expected Seq of size " + v1.size + " but got " + v2.size
71+
}
72+
else {
73+
val together = v1.zip(v2)
74+
val m =
75+
together.map { p =>
76+
fudgers.foldLeft(p) { (curr, nf) =>
77+
val theLens = nf._1(lens[K])
78+
val p = (theLens.get(curr._1), theLens.get(curr._2))
79+
val (nl, nr) = nf._2(p)
80+
(theLens.set(curr._1)(nl), theLens.set(curr._2)(nr))
81+
}
82+
}.toMap
83+
84+
m.keys.toVector ?= m.values.toVector
85+
}
86+
4487
def compareMaps[K](
4588
m1: Map[K, Option[BigDecimal]],
4689
m2: Map[K, Option[BigDecimal]],
@@ -97,11 +140,25 @@ object DoubleBehaviourUtils {
97140
p
98141
}
99142
}
143+
144+
import shapeless._
145+
146+
def tl[X](
147+
lensf: OpticDefns.RootLens[X] => Lens[X, Option[BigDecimal]],
148+
of: BigDecimal
149+
): (X, X) => (X, X) =
150+
(l: X, r: X) => {
151+
val theLens = lensf(lens[X])
152+
val (nl, rl) = tolerance((theLens.get(l), theLens.get(r)), of)
153+
(theLens.set(l)(nl), theLens.set(r)(rl))
154+
}
155+
100156
}
101157

102158
/** drop in conversion for doubles to handle serialization on cluster */
103159
trait ToDecimal[A] {
104160
def truncate(a: A): Option[BigDecimal]
161+
105162
}
106163

107164
object ToDecimal {

dataset/src/test/scala/frameless/ops/CubeTests.scala

+38-12
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package frameless
22
package ops
33

4+
import frameless.functions.DoubleBehaviourUtils.{ dp5, tolerantCompareVectors }
45
import frameless.functions.ToDecimal
56
import frameless.functions.aggregate._
67
import org.scalacheck.Prop
@@ -249,10 +250,22 @@ class CubeTests extends TypedDatasetSuite {
249250
)
250251
.sortBy(t => (t._1, t._2, t._3))
251252

252-
(framelessSumBC ?= sparkSumBC)
253-
.&&(framelessSumBCB ?= sparkSumBCB)
254-
.&&(framelessSumBCBC ?= sparkSumBCBC)
255-
.&&(framelessSumBCBCB ?= sparkSumBCBCB)
253+
(tolerantCompareVectors(framelessSumBC, sparkSumBC, dp5)(Seq(l => l._3)))
254+
.&&(
255+
tolerantCompareVectors(framelessSumBCB, sparkSumBCB, dp5)(
256+
Seq(l => l._3)
257+
)
258+
)
259+
.&&(
260+
tolerantCompareVectors(framelessSumBCBC, sparkSumBCBC, dp5)(
261+
Seq(l => l._3, l => l._5)
262+
)
263+
)
264+
.&&(
265+
tolerantCompareVectors(framelessSumBCBCB, sparkSumBCBCB, dp5)(
266+
Seq(l => l._3, l => l._5)
267+
)
268+
)
256269
}
257270

258271
check(forAll(prop[String, Long, Double, Long, Double] _))
@@ -265,7 +278,7 @@ class CubeTests extends TypedDatasetSuite {
265278
C: TypedEncoder,
266279
D: TypedEncoder,
267280
OutC: TypedEncoder: Numeric,
268-
OutD: TypedEncoder: Numeric
281+
OutD: TypedEncoder: Numeric: ToDecimal
269282
](data: List[X4[A, B, C, D]]
270283
)(implicit
271284
summableC: CatalystSummable[C, OutC],
@@ -277,12 +290,15 @@ class CubeTests extends TypedDatasetSuite {
277290
val C = dataset.col[C]('c)
278291
val D = dataset.col[D]('d)
279292

293+
val toDecOpt = implicitly[ToDecimal[OutD]].truncate _
294+
280295
val framelessSumByAB = dataset
281296
.cube(A, B)
282297
.agg(sum(C), sum(D))
283298
.collect()
284299
.run()
285300
.toVector
301+
.map(row => row.copy(_4 = toDecOpt(row._4)))
286302
.sortBy(x => (x._1, x._2))
287303

288304
val sparkSumByAB = dataset.dataset
@@ -295,12 +311,14 @@ class CubeTests extends TypedDatasetSuite {
295311
Option(row.getAs[A](0)),
296312
Option(row.getAs[B](1)),
297313
row.getAs[OutC](2),
298-
row.getAs[OutD](3)
314+
toDecOpt(row.getAs[OutD](3))
299315
)
300316
)
301317
.sortBy(x => (x._1, x._2))
302318

303-
framelessSumByAB ?= sparkSumByAB
319+
tolerantCompareVectors(framelessSumByAB, sparkSumByAB, dp5)(
320+
Seq(l => l._4)
321+
)
304322
}
305323

306324
check(forAll(prop[Byte, Int, Long, Double, Long, Double] _))
@@ -470,11 +488,19 @@ class CubeTests extends TypedDatasetSuite {
470488
)
471489
.sortBy(t => (t._2, t._1, t._3))
472490

473-
(framelessSumC ?= sparkSumC) &&
474-
(framelessSumCC ?= sparkSumCC) &&
475-
(framelessSumCCC ?= sparkSumCCC) &&
476-
(framelessSumCCCC ?= sparkSumCCCC) &&
477-
(framelessSumCCCCC ?= sparkSumCCCCC)
491+
(tolerantCompareVectors(framelessSumC, sparkSumC, dp5)(Seq(l => l._3))) &&
492+
(tolerantCompareVectors(framelessSumCC, sparkSumCC, dp5)(
493+
Seq(l => l._3, l => l._4)
494+
)) &&
495+
(tolerantCompareVectors(framelessSumCCC, sparkSumCCC, dp5)(
496+
Seq(l => l._3, l => l._4, l => l._5)
497+
)) &&
498+
(tolerantCompareVectors(framelessSumCCCC, sparkSumCCCC, dp5)(
499+
Seq(l => l._3, l => l._4, l => l._5, l => l._6)
500+
)) &&
501+
(tolerantCompareVectors(framelessSumCCCCC, sparkSumCCCCC, dp5)(
502+
Seq(l => l._3, l => l._4, l => l._5, l => l._6, l => l._7)
503+
))
478504
}
479505

480506
check(forAll(prop[String, Long, Double, Double] _))

dataset/src/test/scala/frameless/ops/RollupTests.scala

+34-10
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package frameless
22
package ops
33

4+
import frameless.functions.DoubleBehaviourUtils.{ dp5, tolerantCompareVectors }
45
import frameless.functions.ToDecimal
56
import frameless.functions.aggregate._
67
import org.scalacheck.Prop
@@ -239,10 +240,23 @@ class RollupTests extends TypedDatasetSuite {
239240
)
240241
.sortBy(identity)
241242

242-
(framelessSumBC ?= sparkSumBC)
243-
.&&(framelessSumBCB ?= sparkSumBCB)
244-
.&&(framelessSumBCBC ?= sparkSumBCBC)
245-
.&&(framelessSumBCBCB ?= sparkSumBCBCB)
243+
(tolerantCompareVectors(framelessSumBC, sparkSumBC, dp5)(Seq(l => l._3)))
244+
.&&(
245+
tolerantCompareVectors(framelessSumBCB, sparkSumBCB, dp5)(
246+
Seq(l => l._3)
247+
)
248+
)
249+
.&&(
250+
tolerantCompareVectors(framelessSumBCBC, sparkSumBCBC, dp5)(
251+
Seq(l => l._3, l => l._5)
252+
)
253+
)
254+
.&&(
255+
tolerantCompareVectors(framelessSumBCBCB, sparkSumBCBCB, dp5)(
256+
Seq(l => l._3, l => l._5)
257+
)
258+
)
259+
246260
}
247261

248262
check(forAll(prop[String, Long, Double, Long, Double] _))
@@ -293,7 +307,9 @@ class RollupTests extends TypedDatasetSuite {
293307
)
294308
.sortBy(t => (t._2, t._1, t._3, t._4))
295309

296-
framelessSumByAB ?= sparkSumByAB
310+
tolerantCompareVectors(framelessSumByAB, sparkSumByAB, dp5)(
311+
Seq(l => l._4)
312+
)
297313
}
298314

299315
check(forAll(prop[Byte, Int, Long, Double, Long, Double] _))
@@ -462,11 +478,19 @@ class RollupTests extends TypedDatasetSuite {
462478
)
463479
.sortBy(t => (t._2, t._1, t._3))
464480

465-
(framelessSumC ?= sparkSumC) &&
466-
(framelessSumCC ?= sparkSumCC) &&
467-
(framelessSumCCC ?= sparkSumCCC) &&
468-
(framelessSumCCCC ?= sparkSumCCCC) &&
469-
(framelessSumCCCCC ?= sparkSumCCCCC)
481+
(tolerantCompareVectors(framelessSumC, sparkSumC, dp5)(Seq(l => l._3))) &&
482+
(tolerantCompareVectors(framelessSumCC, sparkSumCC, dp5)(
483+
Seq(l => l._3, l => l._4)
484+
)) &&
485+
(tolerantCompareVectors(framelessSumCCC, sparkSumCCC, dp5)(
486+
Seq(l => l._3, l => l._4, l => l._5)
487+
)) &&
488+
(tolerantCompareVectors(framelessSumCCCC, sparkSumCCCC, dp5)(
489+
Seq(l => l._3, l => l._4, l => l._5, l => l._6)
490+
)) &&
491+
(tolerantCompareVectors(framelessSumCCCCC, sparkSumCCCCC, dp5)(
492+
Seq(l => l._3, l => l._4, l => l._5, l => l._6, l => l._7)
493+
))
470494
}
471495

472496
check(forAll(prop[String, Long, Double, Double] _))

0 commit comments

Comments (0)