From dc5718cc543bd166ee58c1367acc6ed7ccd92788 Mon Sep 17 00:00:00 2001 From: lewuathe Date: Mon, 6 Apr 2015 21:36:28 +0900 Subject: [PATCH 1/2] [SPARK-6720] PySpark MultivariateStatisticalSummary unit test for normL1 and normL2 --- python/pyspark/mllib/tests.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index 47dad7d12e4e..b77bfeac6ed2 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -363,6 +363,13 @@ def test_col_norms(self): self.assertEqual(10, len(summary.normL1())) self.assertEqual(10, len(summary.normL2())) + data2 = self.sc.parallelize(xrange(10)).map(lambda x: Vectors.dense(x)) + summary2 = Statistics.colStats(data2) + self.assertEqual(array([45.0]), summary2.normL1()) + # Confirm normL2 is among this span because it is a float value. + self.assertTrue(summary2.normL2()[0] > 16.5) + self.assertTrue(summary2.normL2()[0] < 17.0) + class VectorUDTTests(PySparkTestCase): From 5541b24ffbb4933b2131b43aae2a6c59dbc6a25b Mon Sep 17 00:00:00 2001 From: lewuathe Date: Tue, 7 Apr 2015 20:55:07 +0900 Subject: [PATCH 2/2] More accurate tests --- python/pyspark/mllib/tests.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index b77bfeac6ed2..61ef398487c0 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -366,9 +366,9 @@ def test_col_norms(self): data2 = self.sc.parallelize(xrange(10)).map(lambda x: Vectors.dense(x)) summary2 = Statistics.colStats(data2) self.assertEqual(array([45.0]), summary2.normL1()) - # Confirm normL2 is among this span because it is a float value. - self.assertTrue(summary2.normL2()[0] > 16.5) - self.assertTrue(summary2.normL2()[0] < 17.0) + import math + expectedNormL2 = math.sqrt(sum(map(lambda x: x*x, xrange(10)))) + self.assertTrue(math.fabs(summary2.normL2()[0] - expectedNormL2) < 1e-14) class VectorUDTTests(PySparkTestCase):