|
61 | 61 | from pyspark import keyword_only |
62 | 62 | from pyspark.conf import SparkConf |
63 | 63 | from pyspark.context import SparkContext |
64 | | -from pyspark.rdd import RDD |
65 | 64 | from pyspark.files import SparkFiles |
66 | 65 | from pyspark.ml.feature import RFormula |
| 66 | +from pyspark.rdd import RDD |
67 | 67 | from pyspark.serializers import read_int, BatchedSerializer, MarshalSerializer, PickleSerializer, \ |
68 | 68 | CloudPickleSerializer, CompressedSerializer, UTF8Deserializer, NoOpSerializer, \ |
69 | 69 | PairDeserializer, CartesianDeserializer, AutoBatchedSerializer, AutoSerializer, \ |
@@ -2207,20 +2207,20 @@ def set(self, x=None, other=None, other_x=None): |
2207 | 2207 | self.assertEqual(b._x, 2) |
2208 | 2208 |
|
2209 | 2209 |
|
class SparkMLTests(ReusedPySparkTestCase):
    """Smoke tests for pyspark.ml features, reusing the shared SparkContext
    provided by ReusedPySparkTestCase (available as ``self.sc``)."""

    def test_rformula(self):
        """RFormula with stringIndexerOrderType='alphabetDesc' should index the
        string column 's' in descending alphabetical order ('b' -> 0.0,
        'a' -> 1.0), producing the expected feature vectors."""
        # NOTE(review): RDD.toDF requires an active SparkSession — presumably
        # one is created alongside the reused SparkContext; confirm in the
        # test-case base class.
        df = self.sc.parallelize([
            (1.0, 1.0, "a"),
            (0.0, 2.0, "b"),
            (0.0, 0.0, "a"),
        ]).toDF(["y", "x", "s"])
        rf = RFormula(formula="y ~ x + s", stringIndexerOrderType="alphabetDesc")
        self.assertEqual(rf.getStringIndexerOrderType(), 'alphabetDesc')

        result = rf.fit(df).transform(df)
        observed = result.select("features").collect()
        expected = [[1.0, 0.0], [2.0, 1.0], [0.0, 0.0]]
        # toArray() returns a numpy array; comparing it to a list with
        # assertEqual would evaluate an element-wise boolean array and raise
        # "truth value of an array is ambiguous" on mismatch checks, so
        # convert to a plain list before comparing element-wise.
        for row, want in zip(observed, expected):
            self.assertEqual(list(row["features"].toArray()), want)
2226 | 2226 |
|
|
0 commit comments