Skip to content

Commit 320203e

Browse files
committed
update test
1 parent 3510e24 commit 320203e

File tree

1 file changed

+8
-8
lines changed

1 file changed

+8
-8
lines changed

python/pyspark/tests.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -61,9 +61,9 @@
6161
from pyspark import keyword_only
6262
from pyspark.conf import SparkConf
6363
from pyspark.context import SparkContext
64-
from pyspark.rdd import RDD
6564
from pyspark.files import SparkFiles
6665
from pyspark.ml.feature import RFormula
66+
from pyspark.rdd import RDD
6767
from pyspark.serializers import read_int, BatchedSerializer, MarshalSerializer, PickleSerializer, \
6868
CloudPickleSerializer, CompressedSerializer, UTF8Deserializer, NoOpSerializer, \
6969
PairDeserializer, CartesianDeserializer, AutoBatchedSerializer, AutoSerializer, \
@@ -2207,20 +2207,20 @@ def set(self, x=None, other=None, other_x=None):
22072207
self.assertEqual(b._x, 2)
22082208

22092209

2210-
class SparkMLTests(unittest.TestCase):
2210+
class SparkMLTests(ReusedPySparkTestCase):
22112211

22122212
def test_rformula(self):
2213-
df = spark.createDataFrame([
2214-
(1.0, 1.0, "a"),
2215-
(0.0, 2.0, "b"),
2216-
(0.0, 0.0, "a")
2217-
], ["y", "x", "s"])
2213+
df = self.sc.parallelize([
2214+
(1.0, 1.0, "a"),
2215+
(0.0, 2.0, "b"),
2216+
(0.0, 0.0, "a")
2217+
]).toDF(["y", "x", "s"])
22182218
rf = RFormula(formula="y ~ x + s", stringIndexerOrderType="alphabetDesc")
22192219
self.assertEqual(rf.getStringIndexerOrderType(), 'alphabetDesc')
22202220

22212221
result = rf.fit(df).transform(df)
22222222
observed = result.select("features").collect()
2223-
expected = [[1.0, 0.0], [2.0, 1.0], [0.0,0.0]]
2223+
expected = [[1.0, 0.0], [2.0, 1.0], [0.0, 0.0]]
22242224
for i in range(0, len(expected)):
22252225
self.assertEqual(observed[i]["features"].toArray(), expected[i])
22262226

0 commit comments

Comments
 (0)