diff --git a/docs/ml-guide.md b/docs/ml-guide.md index be178d7689fd..4bf14fba34ee 100644 --- a/docs/ml-guide.md +++ b/docs/ml-guide.md @@ -260,7 +260,7 @@ List localTraining = Lists.newArrayList( new LabeledPoint(0.0, Vectors.dense(2.0, 1.0, -1.0)), new LabeledPoint(0.0, Vectors.dense(2.0, 1.3, 1.0)), new LabeledPoint(1.0, Vectors.dense(0.0, 1.2, -0.5))); -JavaSchemaRDD training = jsql.applySchema(jsc.parallelize(localTraining), LabeledPoint.class); +JavaSchemaRDD training = jsql.createDataFrame(jsc.parallelize(localTraining), LabeledPoint.class); // Create a LogisticRegression instance. This instance is an Estimator. LogisticRegression lr = new LogisticRegression(); @@ -300,7 +300,7 @@ List localTest = Lists.newArrayList( new LabeledPoint(1.0, Vectors.dense(-1.0, 1.5, 1.3)), new LabeledPoint(0.0, Vectors.dense(3.0, 2.0, -0.1)), new LabeledPoint(1.0, Vectors.dense(0.0, 2.2, -1.5))); -JavaSchemaRDD test = jsql.applySchema(jsc.parallelize(localTest), LabeledPoint.class); +JavaSchemaRDD test = jsql.createDataFrame(jsc.parallelize(localTest), LabeledPoint.class); // Make predictions on test documents using the Transformer.transform() method. // LogisticRegression.transform will only use the 'features' column. @@ -443,7 +443,7 @@ List localTraining = Lists.newArrayList( new LabeledDocument(2L, "spark f g h", 1.0), new LabeledDocument(3L, "hadoop mapreduce", 0.0)); JavaSchemaRDD training = - jsql.applySchema(jsc.parallelize(localTraining), LabeledDocument.class); + jsql.createDataFrame(jsc.parallelize(localTraining), LabeledDocument.class); // Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr. Tokenizer tokenizer = new Tokenizer() @@ -469,7 +469,7 @@ List localTest = Lists.newArrayList( new Document(6L, "mapreduce spark"), new Document(7L, "apache hadoop")); JavaSchemaRDD test = - jsql.applySchema(jsc.parallelize(localTest), Document.class); + jsql.createDataFrame(jsc.parallelize(localTest), Document.class); // Make predictions on test documents. model.transform(test).registerAsTable("prediction"); @@ -626,7 +626,7 @@ List localTraining = Lists.newArrayList( new LabeledDocument(10L, "spark compile", 1.0), new LabeledDocument(11L, "hadoop software", 0.0)); JavaSchemaRDD training = - jsql.applySchema(jsc.parallelize(localTraining), LabeledDocument.class); + jsql.createDataFrame(jsc.parallelize(localTraining), LabeledDocument.class); // Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr. Tokenizer tokenizer = new Tokenizer() @@ -669,7 +669,7 @@ List localTest = Lists.newArrayList( new Document(5L, "l m n"), new Document(6L, "mapreduce spark"), new Document(7L, "apache hadoop")); -JavaSchemaRDD test = jsql.applySchema(jsc.parallelize(localTest), Document.class); +JavaSchemaRDD test = jsql.createDataFrame(jsc.parallelize(localTest), Document.class); // Make predictions on test documents. cvModel uses the best model found (lrModel). cvModel.transform(test).registerAsTable("prediction"); diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 38f617d0c836..b2b007509c73 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -225,7 +225,7 @@ public static class Person implements Serializable { {% endhighlight %} -A schema can be applied to an existing RDD by calling `applySchema` and providing the Class object +A schema can be applied to an existing RDD by calling `createDataFrame` and providing the Class object for the JavaBean. {% highlight java %} @@ -247,7 +247,7 @@ JavaRDD people = sc.textFile("examples/src/main/resources/people.txt").m }); // Apply a schema to an RDD of JavaBeans and register it as a table. -JavaSchemaRDD schemaPeople = sqlContext.applySchema(people, Person.class); +JavaSchemaRDD schemaPeople = sqlContext.createDataFrame(people, Person.class); schemaPeople.registerTempTable("people"); // SQL can be run over RDDs that have been registered as tables. @@ -315,7 +315,7 @@ a `SchemaRDD` can be created programmatically with three steps. 1. Create an RDD of `Row`s from the original RDD; 2. Create the schema represented by a `StructType` matching the structure of `Row`s in the RDD created in Step 1. -3. Apply the schema to the RDD of `Row`s via `applySchema` method provided +3. Apply the schema to the RDD of `Row`s via `createDataFrame` method provided by `SQLContext`. For example: @@ -341,7 +341,7 @@ val schema = val rowRDD = people.map(_.split(",")).map(p => Row(p(0), p(1).trim)) // Apply the schema to the RDD. -val peopleSchemaRDD = sqlContext.applySchema(rowRDD, schema) +val peopleSchemaRDD = sqlContext.createDataFrame(rowRDD, schema) // Register the SchemaRDD as a table. peopleSchemaRDD.registerTempTable("people") @@ -367,7 +367,7 @@ a `SchemaRDD` can be created programmatically with three steps. 1. Create an RDD of `Row`s from the original RDD; 2. Create the schema represented by a `StructType` matching the structure of `Row`s in the RDD created in Step 1. -3. Apply the schema to the RDD of `Row`s via `applySchema` method provided +3. Apply the schema to the RDD of `Row`s via `createDataFrame` method provided by `JavaSQLContext`. For example: @@ -406,7 +406,7 @@ JavaRDD rowRDD = people.map( }); // Apply the schema to the RDD. -JavaSchemaRDD peopleSchemaRDD = sqlContext.applySchema(rowRDD, schema); +JavaSchemaRDD peopleSchemaRDD = sqlContext.createDataFrame(rowRDD, schema); // Register the SchemaRDD as a table. peopleSchemaRDD.registerTempTable("people"); @@ -436,7 +436,7 @@ a `SchemaRDD` can be created programmatically with three steps. 1. Create an RDD of tuples or lists from the original RDD; 2. Create the schema represented by a `StructType` matching the structure of tuples or lists in the RDD created in the step 1. -3. Apply the schema to the RDD via `applySchema` method provided by `SQLContext`. +3. Apply the schema to the RDD via `createDataFrame` method provided by `SQLContext`. For example: {% highlight python %} @@ -458,7 +458,7 @@ fields = [StructField(field_name, StringType(), True) for field_name in schemaSt schema = StructType(fields) # Apply the schema to the RDD. -schemaPeople = sqlContext.applySchema(people, schema) +schemaPeople = sqlContext.createDataFrame(people, schema) # Register the SchemaRDD as a table. schemaPeople.registerTempTable("people") diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaCrossValidatorExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaCrossValidatorExample.java index 5041e0b6d34b..5d8c5d0a92da 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaCrossValidatorExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaCrossValidatorExample.java @@ -71,7 +71,7 @@ public static void main(String[] args) { new LabeledDocument(9L, "a e c l", 0.0), new LabeledDocument(10L, "spark compile", 1.0), new LabeledDocument(11L, "hadoop software", 0.0)); - DataFrame training = jsql.applySchema(jsc.parallelize(localTraining), LabeledDocument.class); + DataFrame training = jsql.createDataFrame(jsc.parallelize(localTraining), LabeledDocument.class); // Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr. Tokenizer tokenizer = new Tokenizer() @@ -112,7 +112,7 @@ public static void main(String[] args) { new Document(5L, "l m n"), new Document(6L, "mapreduce spark"), new Document(7L, "apache hadoop")); - DataFrame test = jsql.applySchema(jsc.parallelize(localTest), Document.class); + DataFrame test = jsql.createDataFrame(jsc.parallelize(localTest), Document.class); // Make predictions on test documents. cvModel uses the best model found (lrModel). cvModel.transform(test).registerTempTable("prediction"); diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java index 4d9dad9f2303..19d0eb216848 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java @@ -62,7 +62,7 @@ public static void main(String[] args) throws Exception { new LabeledPoint(0.0, Vectors.dense(2.0, 1.0, -1.0)), new LabeledPoint(0.0, Vectors.dense(2.0, 1.3, 1.0)), new LabeledPoint(1.0, Vectors.dense(0.0, 1.2, -0.5))); - DataFrame training = jsql.applySchema(jsc.parallelize(localTraining), LabeledPoint.class); + DataFrame training = jsql.createDataFrame(jsc.parallelize(localTraining), LabeledPoint.class); // Create a LogisticRegression instance. This instance is an Estimator. MyJavaLogisticRegression lr = new MyJavaLogisticRegression(); @@ -80,7 +80,7 @@ public static void main(String[] args) throws Exception { new LabeledPoint(1.0, Vectors.dense(-1.0, 1.5, 1.3)), new LabeledPoint(0.0, Vectors.dense(3.0, 2.0, -0.1)), new LabeledPoint(1.0, Vectors.dense(0.0, 2.2, -1.5))); - DataFrame test = jsql.applySchema(jsc.parallelize(localTest), LabeledPoint.class); + DataFrame test = jsql.createDataFrame(jsc.parallelize(localTest), LabeledPoint.class); // Make predictions on test documents. cvModel uses the best model found (lrModel). DataFrame results = model.transform(test); diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java index cc69e6315fdd..4c4d53238878 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java @@ -54,7 +54,7 @@ public static void main(String[] args) { new LabeledPoint(0.0, Vectors.dense(2.0, 1.0, -1.0)), new LabeledPoint(0.0, Vectors.dense(2.0, 1.3, 1.0)), new LabeledPoint(1.0, Vectors.dense(0.0, 1.2, -0.5))); - DataFrame training = jsql.applySchema(jsc.parallelize(localTraining), LabeledPoint.class); + DataFrame training = jsql.createDataFrame(jsc.parallelize(localTraining), LabeledPoint.class); // Create a LogisticRegression instance. This instance is an Estimator. LogisticRegression lr = new LogisticRegression(); @@ -94,7 +94,7 @@ public static void main(String[] args) { new LabeledPoint(1.0, Vectors.dense(-1.0, 1.5, 1.3)), new LabeledPoint(0.0, Vectors.dense(3.0, 2.0, -0.1)), new LabeledPoint(1.0, Vectors.dense(0.0, 2.2, -1.5))); - DataFrame test = jsql.applySchema(jsc.parallelize(localTest), LabeledPoint.class); + DataFrame test = jsql.createDataFrame(jsc.parallelize(localTest), LabeledPoint.class); // Make predictions on test documents using the Transformer.transform() method. // LogisticRegression.transform will only use the 'features' column. diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java index d929f1ad2014..fdcfc888c235 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java @@ -54,7 +54,7 @@ public static void main(String[] args) { new LabeledDocument(1L, "b d", 0.0), new LabeledDocument(2L, "spark f g h", 1.0), new LabeledDocument(3L, "hadoop mapreduce", 0.0)); - DataFrame training = jsql.applySchema(jsc.parallelize(localTraining), LabeledDocument.class); + DataFrame training = jsql.createDataFrame(jsc.parallelize(localTraining), LabeledDocument.class); // Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr. Tokenizer tokenizer = new Tokenizer() @@ -79,7 +79,7 @@ public static void main(String[] args) { new Document(5L, "l m n"), new Document(6L, "mapreduce spark"), new Document(7L, "apache hadoop")); - DataFrame test = jsql.applySchema(jsc.parallelize(localTest), Document.class); + DataFrame test = jsql.createDataFrame(jsc.parallelize(localTest), Document.class); // Make predictions on test documents. model.transform(test).registerTempTable("prediction"); diff --git a/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java b/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java index 8defb769ffaa..dee794840a3e 100644 --- a/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java +++ b/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java @@ -74,7 +74,7 @@ public Person call(String line) { }); // Apply a schema to an RDD of Java Beans and register it as a table. - DataFrame schemaPeople = sqlCtx.applySchema(people, Person.class); + DataFrame schemaPeople = sqlCtx.createDataFrame(people, Person.class); schemaPeople.registerTempTable("people"); // SQL can be run over RDDs that have been registered as tables. diff --git a/examples/src/main/python/sql.py b/examples/src/main/python/sql.py index 7f5c68e3d0fe..47202fde7510 100644 --- a/examples/src/main/python/sql.py +++ b/examples/src/main/python/sql.py @@ -31,7 +31,7 @@ Row(name="Smith", age=23), Row(name="Sarah", age=18)]) # Infer schema from the first row, create a DataFrame and print the schema - some_df = sqlContext.inferSchema(some_rdd) + some_df = sqlContext.createDataFrame(some_rdd) some_df.printSchema() # Another RDD is created from a list of tuples @@ -40,7 +40,7 @@ schema = StructType([StructField("person_name", StringType(), False), StructField("person_age", IntegerType(), False)]) # Create a DataFrame by applying the schema to the RDD and print the schema - another_df = sqlContext.applySchema(another_rdd, schema) + another_df = sqlContext.createDataFrame(another_rdd, schema) another_df.printSchema() # root # |-- age: integer (nullable = true) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index 5d51c5134666..324b1ba78438 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -76,8 +76,8 @@ class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorP val metrics = new Array[Double](epm.size) val splits = MLUtils.kFold(dataset.rdd, map(numFolds), 0) splits.zipWithIndex.foreach { case ((training, validation), splitIndex) => - val trainingDataset = sqlCtx.applySchema(training, schema).cache() - val validationDataset = sqlCtx.applySchema(validation, schema).cache() + val trainingDataset = sqlCtx.createDataFrame(training, schema).cache() + val validationDataset = sqlCtx.createDataFrame(validation, schema).cache() // multi-model training logDebug(s"Train split $splitIndex with multiple sets of parameters.") val models = est.fit(trainingDataset, epm).asInstanceOf[Seq[Model[_]]] diff --git a/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java b/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java index 50995ffef9ad..0a8c9e595467 100644 --- a/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java @@ -45,7 +45,7 @@ public void setUp() { jsql = new SQLContext(jsc); JavaRDD points = jsc.parallelize(generateLogisticInputAsList(1.0, 1.0, 100, 42), 2); - dataset = jsql.applySchema(points, LabeledPoint.class); + dataset = jsql.createDataFrame(points, LabeledPoint.class); } @After diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java index d4b664479255..3f8e59de0f05 100644 --- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java @@ -50,7 +50,7 @@ public void setUp() { jsql = new SQLContext(jsc); List points = generateLogisticInputAsList(1.0, 1.0, 100, 42); datasetRDD = jsc.parallelize(points, 2); - dataset = jsql.applySchema(datasetRDD, LabeledPoint.class); + dataset = jsql.createDataFrame(datasetRDD, LabeledPoint.class); dataset.registerTempTable("dataset"); } diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java index 40d5a92bb32a..0cc36c8d56d7 100644 --- a/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java @@ -46,7 +46,7 @@ public void setUp() { jsql = new SQLContext(jsc); List points = generateLogisticInputAsList(1.0, 1.0, 100, 42); datasetRDD = jsc.parallelize(points, 2); - dataset = jsql.applySchema(datasetRDD, LabeledPoint.class); + dataset = jsql.createDataFrame(datasetRDD, LabeledPoint.class); dataset.registerTempTable("dataset"); } diff --git a/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java b/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java index 074b58c07df7..0bb6b489f275 100644 --- a/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java @@ -45,7 +45,7 @@ public void setUp() { jsc = new JavaSparkContext("local", "JavaCrossValidatorSuite"); jsql = new SQLContext(jsc); List points = generateLogisticInputAsList(1.0, 1.0, 100, 42); - dataset = jsql.applySchema(jsc.parallelize(points, 2), LabeledPoint.class); + dataset = jsql.createDataFrame(jsc.parallelize(points, 2), LabeledPoint.class); } @After diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index 49f016a9cf2e..08d72186ab0f 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -24,7 +24,7 @@ from pyspark.rdd import _prepare_for_python_RDD from pyspark.serializers import AutoBatchedSerializer, PickleSerializer -from pyspark.sql.types import StringType, StructType, _verify_type, \ +from pyspark.sql.types import StringType, StructType, _infer_type, _verify_type, \ _infer_schema, _has_nulltype, _merge_type, _create_converter, _python_to_sql_converter from pyspark.sql.dataframe import DataFrame @@ -46,23 +46,11 @@ def __init__(self, sparkContext, sqlContext=None): :param sqlContext: An optional JVM Scala SQLContext. If set, we do not instatiate a new SQLContext in the JVM, instead we make all calls to this object. - >>> df = sqlCtx.inferSchema(rdd) - >>> sqlCtx.inferSchema(df) # doctest: +IGNORE_EXCEPTION_DETAIL - Traceback (most recent call last): - ... - TypeError:... - - >>> bad_rdd = sc.parallelize([1,2,3]) - >>> sqlCtx.inferSchema(bad_rdd) # doctest: +IGNORE_EXCEPTION_DETAIL - Traceback (most recent call last): - ... - ValueError:... - >>> from datetime import datetime >>> allTypes = sc.parallelize([Row(i=1, s="string", d=1.0, l=1L, ... b=True, list=[1, 2, 3], dict={"s": 0}, row=Row(a=1), ... time=datetime(2014, 8, 1, 14, 1, 5))]) - >>> df = sqlCtx.inferSchema(allTypes) + >>> df = sqlCtx.createDataFrame(allTypes) >>> df.registerTempTable("allTypes") >>> sqlCtx.sql('select i+1, d+1, not b, list[1], dict["s"], time, row.a ' ... 'from allTypes where b and i > 0').collect() @@ -118,6 +106,9 @@ def registerFunction(self, name, f, returnType=StringType()): def inferSchema(self, rdd, samplingRatio=None): """Infer and apply a schema to an RDD of L{Row}. + ::note: + Deprecated in 1.3, use :func:`createDataFrame` instead + When samplingRatio is specified, the schema is inferred by looking at the types of each row in the sampled dataset. Otherwise, the first 100 rows of the RDD are inspected. Nested collections are @@ -186,7 +177,7 @@ def inferSchema(self, rdd, samplingRatio=None): warnings.warn("Some of types cannot be determined by the " "first 100 rows, please try again with sampling") else: - if samplingRatio > 0.99: + if samplingRatio < 0.99: rdd = rdd.sample(False, float(samplingRatio)) schema = rdd.map(_infer_schema).reduce(_merge_type) @@ -198,6 +189,9 @@ def applySchema(self, rdd, schema): """ Applies the given schema to the given RDD of L{tuple} or L{list}. + ::note: + Deprecated in 1.3, use :func:`createDataFrame` instead + These tuples or lists can contain complex nested structures like lists, maps or nested rows. @@ -287,13 +281,68 @@ def applySchema(self, rdd, schema): df = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.json()) return DataFrame(df, self) + def createDataFrame(self, rdd, schema=None, samplingRatio=None): + """ + Create a DataFrame from an RDD of tuple/list and an optional `schema`. + + `schema` could be :class:`StructType` or a list of column names. + + When `schema` is a list of column names, the type of each column + will be inferred from `rdd`. + + When `schema` is None, it will try to infer the column name and type + from `rdd`, which should be an RDD of :class:`Row`, or namedtuple, + or dict. + + If referring needed, `samplingRatio` is used to determined how many + rows will be used to do referring. The first row will be used if + `samplingRatio` is None. + + :param rdd: an RDD of Row or tuple or list or dict + :param schema: a StructType or list of names of columns + :param samplingRatio: the sample ratio of rows used for inferring + :return: a DataFrame + + >>> rdd = sc.parallelize([('Alice', 1)]) + >>> df = sqlCtx.createDataFrame(rdd, ['name', 'age']) + >>> df.collect() + [Row(name=u'Alice', age=1)] + + >>> from pyspark.sql import Row + >>> Person = Row('name', 'age') + >>> person = rdd.map(lambda r: Person(*r)) + >>> df2 = sqlCtx.createDataFrame(person) + >>> df2.collect() + [Row(name=u'Alice', age=1)] + + >>> from pyspark.sql.types import * + >>> schema = StructType([ + ... StructField("name", StringType(), True), + ... StructField("age", IntegerType(), True)]) + >>> df3 = sqlCtx.createDataFrame(rdd, schema) + >>> df3.collect() + [Row(name=u'Alice', age=1)] + """ + if isinstance(rdd, DataFrame): + raise TypeError("rdd is already a DataFrame") + + if isinstance(schema, StructType): + return self.applySchema(rdd, schema) + else: + if isinstance(schema, (list, tuple)): + first = rdd.first() + if not isinstance(first, (list, tuple)): + raise ValueError("each row in `rdd` should be list or tuple") + row_cls = Row(*schema) + rdd = rdd.map(lambda r: row_cls(*r)) + return self.inferSchema(rdd, samplingRatio) + def registerRDDAsTable(self, rdd, tableName): """Registers the given RDD as a temporary table in the catalog. Temporary tables exist only during the lifetime of this instance of SQLContext. - >>> df = sqlCtx.inferSchema(rdd) >>> sqlCtx.registerRDDAsTable(df, "table1") """ if (rdd.__class__ is DataFrame): @@ -308,7 +357,6 @@ def parquetFile(self, *paths): >>> import tempfile, shutil >>> parquetFile = tempfile.mkdtemp() >>> shutil.rmtree(parquetFile) - >>> df = sqlCtx.inferSchema(rdd) >>> df.saveAsParquetFile(parquetFile) >>> df2 = sqlCtx.parquetFile(parquetFile) >>> sorted(df.collect()) == sorted(df2.collect()) @@ -458,7 +506,6 @@ def func(iterator): def sql(self, sqlQuery): """Return a L{DataFrame} representing the result of the given query. - >>> df = sqlCtx.inferSchema(rdd) >>> sqlCtx.registerRDDAsTable(df, "table1") >>> df2 = sqlCtx.sql("SELECT field1 AS f1, field2 as f2 from table1") >>> df2.collect() @@ -469,7 +516,6 @@ def sql(self, sqlQuery): def table(self, tableName): """Returns the specified table as a L{DataFrame}. - >>> df = sqlCtx.inferSchema(rdd) >>> sqlCtx.registerRDDAsTable(df, "table1") >>> df2 = sqlCtx.table("table1") >>> sorted(df.collect()) == sorted(df2.collect()) @@ -617,11 +663,12 @@ def _test(): sc = SparkContext('local[4]', 'PythonTest') globs['sc'] = sc globs['sqlCtx'] = sqlCtx = SQLContext(sc) - globs['rdd'] = sc.parallelize( + globs['rdd'] = rdd = sc.parallelize( [Row(field1=1, field2="row1"), Row(field1=2, field2="row2"), Row(field1=3, field2="row3")] ) + globs['df'] = sqlCtx.createDataFrame(rdd) jsonStrings = [ '{"field1": 1, "field2": "row1", "field3":{"field4":11}}', '{"field1" : 2, "field3":{"field4":22, "field5": [10, 11]},' diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index d25c6365ed06..70c0ff2d1c16 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -97,7 +97,7 @@ def setUpClass(cls): cls.sqlCtx = SQLContext(cls.sc) cls.testData = [Row(key=i, value=str(i)) for i in range(100)] rdd = cls.sc.parallelize(cls.testData) - cls.df = cls.sqlCtx.inferSchema(rdd) + cls.df = cls.sqlCtx.createDataFrame(rdd) @classmethod def tearDownClass(cls): @@ -111,14 +111,14 @@ def test_udf(self): def test_udf2(self): self.sqlCtx.registerFunction("strlen", lambda string: len(string), IntegerType()) - self.sqlCtx.inferSchema(self.sc.parallelize([Row(a="test")])).registerTempTable("test") + self.sqlCtx.createDataFrame(self.sc.parallelize([Row(a="test")])).registerTempTable("test") [res] = self.sqlCtx.sql("SELECT strlen(a) FROM test WHERE strlen(a) > 1").collect() self.assertEqual(4, res[0]) def test_udf_with_array_type(self): d = [Row(l=range(3), d={"key": range(5)})] rdd = self.sc.parallelize(d) - self.sqlCtx.inferSchema(rdd).registerTempTable("test") + self.sqlCtx.createDataFrame(rdd).registerTempTable("test") self.sqlCtx.registerFunction("copylist", lambda l: list(l), ArrayType(IntegerType())) self.sqlCtx.registerFunction("maplen", lambda d: len(d), IntegerType()) [(l1, l2)] = self.sqlCtx.sql("select copylist(l), maplen(d) from test").collect() @@ -156,17 +156,17 @@ def test_basic_functions(self): def test_apply_schema_to_row(self): df = self.sqlCtx.jsonRDD(self.sc.parallelize(["""{"a":2}"""])) - df2 = self.sqlCtx.applySchema(df.map(lambda x: x), df.schema()) + df2 = self.sqlCtx.createDataFrame(df.map(lambda x: x), df.schema()) self.assertEqual(df.collect(), df2.collect()) rdd = self.sc.parallelize(range(10)).map(lambda x: Row(a=x)) - df3 = self.sqlCtx.applySchema(rdd, df.schema()) + df3 = self.sqlCtx.createDataFrame(rdd, df.schema()) self.assertEqual(10, df3.count()) def test_serialize_nested_array_and_map(self): d = [Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")})] rdd = self.sc.parallelize(d) - df = self.sqlCtx.inferSchema(rdd) + df = self.sqlCtx.createDataFrame(rdd) row = df.head() self.assertEqual(1, len(row.l)) self.assertEqual(1, row.l[0].a) @@ -188,14 +188,14 @@ def test_infer_schema(self): d = [Row(l=[], d={}), Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")}, s="")] rdd = self.sc.parallelize(d) - df = self.sqlCtx.inferSchema(rdd) + df = self.sqlCtx.createDataFrame(rdd) self.assertEqual([], df.map(lambda r: r.l).first()) self.assertEqual([None, ""], df.map(lambda r: r.s).collect()) df.registerTempTable("test") result = self.sqlCtx.sql("SELECT l[0].a from test where d['key'].d = '2'") self.assertEqual(1, result.head()[0]) - df2 = self.sqlCtx.inferSchema(rdd, 1.0) + df2 = self.sqlCtx.createDataFrame(rdd, 1.0) self.assertEqual(df.schema(), df2.schema()) self.assertEqual({}, df2.map(lambda r: r.d).first()) self.assertEqual([None, ""], df2.map(lambda r: r.s).collect()) @@ -206,7 +206,7 @@ def test_infer_schema(self): def test_struct_in_map(self): d = [Row(m={Row(i=1): Row(s="")})] rdd = self.sc.parallelize(d) - df = self.sqlCtx.inferSchema(rdd) + df = self.sqlCtx.createDataFrame(rdd) k, v = df.head().m.items()[0] self.assertEqual(1, k.i) self.assertEqual("", v.s) @@ -215,7 +215,7 @@ def test_convert_row_to_dict(self): row = Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")}) self.assertEqual(1, row.asDict()['l'][0].a) rdd = self.sc.parallelize([row]) - df = self.sqlCtx.inferSchema(rdd) + df = self.sqlCtx.createDataFrame(rdd) df.registerTempTable("test") row = self.sqlCtx.sql("select l, d from test").head() self.assertEqual(1, row.asDict()["l"][0].a) @@ -225,7 +225,7 @@ def test_infer_schema_with_udt(self): from pyspark.sql.tests import ExamplePoint, ExamplePointUDT row = Row(label=1.0, point=ExamplePoint(1.0, 2.0)) rdd = self.sc.parallelize([row]) - df = self.sqlCtx.inferSchema(rdd) + df = self.sqlCtx.createDataFrame(rdd) schema = df.schema() field = [f for f in schema.fields if f.name == "point"][0] self.assertEqual(type(field.dataType), ExamplePointUDT) @@ -239,7 +239,7 @@ def test_apply_schema_with_udt(self): rdd = self.sc.parallelize([row]) schema = StructType([StructField("label", DoubleType(), False), StructField("point", ExamplePointUDT(), False)]) - df = self.sqlCtx.applySchema(rdd, schema) + df = self.sqlCtx.createDataFrame(rdd, schema) point = df.head().point self.assertEquals(point, ExamplePoint(1.0, 2.0)) @@ -247,7 +247,7 @@ def test_parquet_with_udt(self): from pyspark.sql.tests import ExamplePoint row = Row(label=1.0, point=ExamplePoint(1.0, 2.0)) rdd = self.sc.parallelize([row]) - df0 = self.sqlCtx.inferSchema(rdd) + df0 = self.sqlCtx.createDataFrame(rdd) output_dir = os.path.join(self.tempdir.name, "labeled_point") df0.saveAsParquetFile(output_dir) df1 = self.sqlCtx.parquetFile(output_dir) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 97e3777f933e..90b42d03a320 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -243,7 +243,7 @@ class SQLContext(@transient val sparkContext: SparkContext) * val people = * sc.textFile("examples/src/main/resources/people.txt").map( * _.split(",")).map(p => Row(p(0), p(1).trim.toInt)) - * val dataFrame = sqlContext. applySchema(people, schema) + * val dataFrame = sqlContext.createDataFrame(people, schema) * dataFrame.printSchema * // root * // |-- name: string (nullable = false) @@ -252,11 +252,9 @@ class SQLContext(@transient val sparkContext: SparkContext) * dataFrame.registerTempTable("people") * sqlContext.sql("select name from people").collect.foreach(println) * }}} - * - * @group userf */ @DeveloperApi - def applySchema(rowRDD: RDD[Row], schema: StructType): DataFrame = { + def createDataFrame(rowRDD: RDD[Row], schema: StructType): DataFrame = { // TODO: use MutableProjection when rowRDD is another DataFrame and the applied // schema differs from the existing schema on any field data type. val logicalPlan = LogicalRDD(schema.toAttributes, rowRDD)(self) @@ -264,8 +262,21 @@ class SQLContext(@transient val sparkContext: SparkContext) } @DeveloperApi - def applySchema(rowRDD: JavaRDD[Row], schema: StructType): DataFrame = { - applySchema(rowRDD.rdd, schema); + def createDataFrame(rowRDD: JavaRDD[Row], schema: StructType): DataFrame = { + createDataFrame(rowRDD.rdd, schema) + } + + /** + * Creates a [[DataFrame]] from an [[JavaRDD]] containing [[Row]]s by applying + * a seq of names of columns to this RDD, the data type for each column will + * be inferred by the first row. + * + * @param rowRDD an JavaRDD of Row + * @param columns names for each column + * @return DataFrame + */ + def createDataFrame(rowRDD: JavaRDD[Row], columns: java.util.List[String]): DataFrame = { + createDataFrame(rowRDD.rdd, columns.toSeq) } /** @@ -274,7 +285,7 @@ class SQLContext(@transient val sparkContext: SparkContext) * WARNING: Since there is no guaranteed ordering for fields in a Java Bean, * SELECT * queries will return the columns in an undefined order. */ - def applySchema(rdd: RDD[_], beanClass: Class[_]): DataFrame = { + def createDataFrame(rdd: RDD[_], beanClass: Class[_]): DataFrame = { val attributeSeq = getSchema(beanClass) val className = beanClass.getName val rowRdd = rdd.mapPartitions { iter => @@ -301,8 +312,72 @@ class SQLContext(@transient val sparkContext: SparkContext) * WARNING: Since there is no guaranteed ordering for fields in a Java Bean, * SELECT * queries will return the columns in an undefined order. */ + def createDataFrame(rdd: JavaRDD[_], beanClass: Class[_]): DataFrame = { + createDataFrame(rdd.rdd, beanClass) + } + + /** + * :: DeveloperApi :: + * Creates a [[DataFrame]] from an [[RDD]] containing [[Row]]s by applying a schema to this RDD. + * It is important to make sure that the structure of every [[Row]] of the provided RDD matches + * the provided schema. Otherwise, there will be runtime exception. + * Example: + * {{{ + * import org.apache.spark.sql._ + * val sqlContext = new org.apache.spark.sql.SQLContext(sc) + * + * val schema = + * StructType( + * StructField("name", StringType, false) :: + * StructField("age", IntegerType, true) :: Nil) + * + * val people = + * sc.textFile("examples/src/main/resources/people.txt").map( + * _.split(",")).map(p => Row(p(0), p(1).trim.toInt)) + * val dataFrame = sqlContext. applySchema(people, schema) + * dataFrame.printSchema + * // root + * // |-- name: string (nullable = false) + * // |-- age: integer (nullable = true) + * + * dataFrame.registerTempTable("people") + * sqlContext.sql("select name from people").collect.foreach(println) + * }}} + * + * @group userf + */ + @DeveloperApi + @deprecated("use createDataFrame", "1.3.0") + def applySchema(rowRDD: RDD[Row], schema: StructType): DataFrame = { + createDataFrame(rowRDD, schema) + } + + @DeveloperApi + @deprecated("use createDataFrame", "1.3.0") + def applySchema(rowRDD: JavaRDD[Row], schema: StructType): DataFrame = { + createDataFrame(rowRDD, schema) + } + + /** + * Applies a schema to an RDD of Java Beans. + * + * WARNING: Since there is no guaranteed ordering for fields in a Java Bean, + * SELECT * queries will return the columns in an undefined order. + */ + @deprecated("use createDataFrame", "1.3.0") + def applySchema(rdd: RDD[_], beanClass: Class[_]): DataFrame = { + createDataFrame(rdd, beanClass) + } + + /** + * Applies a schema to an RDD of Java Beans. + * + * WARNING: Since there is no guaranteed ordering for fields in a Java Bean, + * SELECT * queries will return the columns in an undefined order. + */ + @deprecated("use createDataFrame", "1.3.0") def applySchema(rdd: JavaRDD[_], beanClass: Class[_]): DataFrame = { - applySchema(rdd.rdd, beanClass) + createDataFrame(rdd, beanClass) } /** @@ -375,7 +450,7 @@ class SQLContext(@transient val sparkContext: SparkContext) JsonRDD.nullTypeToStringType( JsonRDD.inferSchema(json, 1.0, columnNameOfCorruptJsonRecord))) val rowRDD = JsonRDD.jsonStringToRow(json, appliedSchema, columnNameOfCorruptJsonRecord) - applySchema(rowRDD, appliedSchema) + createDataFrame(rowRDD, appliedSchema) } @Experimental @@ -393,7 +468,7 @@ class SQLContext(@transient val sparkContext: SparkContext) JsonRDD.nullTypeToStringType( JsonRDD.inferSchema(json, samplingRatio, columnNameOfCorruptJsonRecord)) val rowRDD = JsonRDD.jsonStringToRow(json, appliedSchema, columnNameOfCorruptJsonRecord) - applySchema(rowRDD, appliedSchema) + createDataFrame(rowRDD, appliedSchema) } @Experimental diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index fa4cdecbcb34..1d7103987243 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -180,7 +180,7 @@ class ColumnExpressionSuite extends QueryTest { } test("!==") { - val nullData = TestSQLContext.applySchema(TestSQLContext.sparkContext.parallelize( + val nullData = TestSQLContext.createDataFrame(TestSQLContext.sparkContext.parallelize( Row(1, 1) :: Row(1, 2) :: Row(1, null) :: @@ -240,7 +240,7 @@ class ColumnExpressionSuite extends QueryTest { testData2.collect().toSeq.filter(r => r.getInt(0) <= r.getInt(1))) } - val booleanData = TestSQLContext.applySchema(TestSQLContext.sparkContext.parallelize( + val booleanData = TestSQLContext.createDataFrame(TestSQLContext.sparkContext.parallelize( Row(false, false) :: Row(false, true) :: Row(true, false) :: diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 11502edf972e..575e4e2f4ef5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -34,6 +34,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { TestData import org.apache.spark.sql.test.TestSQLContext.implicits._ + val sqlCtx = TestSQLContext test("SPARK-4625 support SORT BY in SimpleSQLParser & DSL") { checkAnswer( @@ -669,7 +670,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { Row(values(0).toInt, values(1), values(2).toBoolean, v4) } - val df1 = applySchema(rowRDD1, schema1) + val df1 = sqlCtx.createDataFrame(rowRDD1, schema1) df1.registerTempTable("applySchema1") checkAnswer( sql("SELECT * FROM applySchema1"), @@ -699,7 +700,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { Row(Row(values(0).toInt, values(2).toBoolean), Map(values(1) -> v4)) } - val df2 = applySchema(rowRDD2, schema2) + val df2 = sqlCtx.createDataFrame(rowRDD2, schema2) df2.registerTempTable("applySchema2") checkAnswer( sql("SELECT * FROM applySchema2"), @@ -724,7 +725,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { Row(Row(values(0).toInt, values(2).toBoolean), scala.collection.mutable.Map(values(1) -> v4)) } - val df3 = applySchema(rowRDD3, schema2) + val df3 = sqlCtx.createDataFrame(rowRDD3, schema2) df3.registerTempTable("applySchema3") checkAnswer( @@ -769,7 +770,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { .build() val schemaWithMeta = new StructType(Array( schema("id"), schema("name").copy(metadata = metadata), schema("age"))) - val personWithMeta = applySchema(person.rdd, schemaWithMeta) + val personWithMeta = sqlCtx.createDataFrame(person.rdd, schemaWithMeta) def validateMetadata(rdd: DataFrame): Unit = { assert(rdd.schema("name").metadata.getString(docKey) == docValue) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index df108a9d262b..c3210733f1d4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -71,7 +71,7 @@ class PlannerSuite extends FunSuite { val schema = StructType(fields) val row = Row.fromSeq(Seq.fill(fields.size)(null)) val rowRDD = org.apache.spark.sql.test.TestSQLContext.sparkContext.parallelize(row :: Nil) - applySchema(rowRDD, schema).registerTempTable("testLimit") + createDataFrame(rowRDD, schema).registerTempTable("testLimit") val planned = sql( """ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala index e581ac9b50c2..21e70936102f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala @@ -54,7 +54,7 @@ class JDBCWriteSuite extends FunSuite with BeforeAndAfter { StructField("seq", IntegerType) :: Nil) test("Basic CREATE") { - val srdd = TestSQLContext.applySchema(sc.parallelize(arr2x2), schema2) + val srdd = TestSQLContext.createDataFrame(sc.parallelize(arr2x2), schema2) srdd.createJDBCTable(url, "TEST.BASICCREATETEST", false) assert(2 == TestSQLContext.jdbcRDD(url, "TEST.BASICCREATETEST").count) @@ -62,8 +62,8 @@ class JDBCWriteSuite extends FunSuite with BeforeAndAfter { } test("CREATE with overwrite") { - val srdd = TestSQLContext.applySchema(sc.parallelize(arr2x3), schema3) - val srdd2 = TestSQLContext.applySchema(sc.parallelize(arr1x2), schema2) + val srdd = TestSQLContext.createDataFrame(sc.parallelize(arr2x3), schema3) + val srdd2 = TestSQLContext.createDataFrame(sc.parallelize(arr1x2), schema2) srdd.createJDBCTable(url, "TEST.DROPTEST", false) assert(2 == TestSQLContext.jdbcRDD(url, "TEST.DROPTEST").count) @@ -75,8 +75,8 @@ class JDBCWriteSuite extends FunSuite with BeforeAndAfter { } test("CREATE then INSERT to append") { - val srdd = TestSQLContext.applySchema(sc.parallelize(arr2x2), schema2) - val srdd2 = TestSQLContext.applySchema(sc.parallelize(arr1x2), schema2) + val srdd = TestSQLContext.createDataFrame(sc.parallelize(arr2x2), schema2) + val srdd2 = TestSQLContext.createDataFrame(sc.parallelize(arr1x2), schema2) srdd.createJDBCTable(url, "TEST.APPENDTEST", false) srdd2.insertIntoJDBC(url, "TEST.APPENDTEST", false) @@ -85,8 +85,8 @@ class JDBCWriteSuite extends FunSuite with BeforeAndAfter { } test("CREATE then INSERT to truncate") { - val srdd = TestSQLContext.applySchema(sc.parallelize(arr2x2), schema2) - val srdd2 = TestSQLContext.applySchema(sc.parallelize(arr1x2), schema2) + val srdd = TestSQLContext.createDataFrame(sc.parallelize(arr2x2), schema2) + val srdd2 = TestSQLContext.createDataFrame(sc.parallelize(arr1x2), schema2) srdd.createJDBCTable(url, "TEST.TRUNCATETEST", false) srdd2.insertIntoJDBC(url, "TEST.TRUNCATETEST", true) @@ -95,8 +95,8 @@ class JDBCWriteSuite extends FunSuite with BeforeAndAfter { } test("Incompatible INSERT to append") { - val srdd = TestSQLContext.applySchema(sc.parallelize(arr2x2), schema2) - val srdd2 = TestSQLContext.applySchema(sc.parallelize(arr2x3), schema3) + val srdd = TestSQLContext.createDataFrame(sc.parallelize(arr2x2), schema2) + val srdd2 = TestSQLContext.createDataFrame(sc.parallelize(arr2x3), schema3) srdd.createJDBCTable(url, "TEST.INCOMPATIBLETEST", false) intercept[org.apache.spark.SparkException] { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala index 7870cf9b0a86..87975ea979e0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala @@ -820,7 +820,7 @@ class JsonSuite extends QueryTest { Row(values(0).toInt, values(1), values(2).toBoolean, r.split(",").toList, v5) } - val df1 = applySchema(rowRDD1, schema1) + val df1 = createDataFrame(rowRDD1, schema1) df1.registerTempTable("applySchema1") val df2 = df1.toDataFrame val result = df2.toJSON.collect() @@ -841,7 +841,7 @@ class JsonSuite extends QueryTest { Row(Row(values(0).toInt, values(2).toBoolean), Map(values(1) -> v4)) } - val df3 = applySchema(rowRDD2, schema2) + val df3 = createDataFrame(rowRDD2, schema2) df3.registerTempTable("applySchema2") val df4 = df3.toDataFrame val result2 = df4.toJSON.collect() diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala index 869d01eb398c..b630704feea7 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala @@ -84,7 +84,7 @@ class InsertIntoHiveTableSuite extends QueryTest { val schema = StructType(StructField("m", MapType(StringType, StringType), true) :: Nil) val rowRDD = TestHive.sparkContext.parallelize( (1 to 100).map(i => Row(scala.collection.mutable.HashMap(s"key$i" -> s"value$i")))) - val df = applySchema(rowRDD, schema) + val df = TestHive.createDataFrame(rowRDD, schema) df.registerTempTable("tableWithMapValue") sql("CREATE TABLE hiveTableWithMapValue(m MAP )") sql("INSERT OVERWRITE TABLE hiveTableWithMapValue SELECT m FROM tableWithMapValue") @@ -129,7 +129,7 @@ class InsertIntoHiveTableSuite extends QueryTest { val schema = StructType(Seq( StructField("a", ArrayType(StringType, containsNull = false)))) val rowRDD = TestHive.sparkContext.parallelize((1 to 100).map(i => Row(Seq(s"value$i")))) - val df = applySchema(rowRDD, schema) + val df = TestHive.createDataFrame(rowRDD, schema) df.registerTempTable("tableWithArrayValue") sql("CREATE TABLE hiveTableWithArrayValue(a Array )") sql("INSERT OVERWRITE TABLE hiveTableWithArrayValue SELECT a FROM tableWithArrayValue") @@ -146,7 +146,7 @@ class InsertIntoHiveTableSuite extends QueryTest { StructField("m", MapType(StringType, StringType, valueContainsNull = false)))) val rowRDD = TestHive.sparkContext.parallelize( (1 to 100).map(i => Row(Map(s"key$i" -> s"value$i")))) - val df = applySchema(rowRDD, schema) + val df = TestHive.createDataFrame(rowRDD, schema) df.registerTempTable("tableWithMapValue") sql("CREATE TABLE hiveTableWithMapValue(m Map )") sql("INSERT OVERWRITE TABLE hiveTableWithMapValue SELECT m FROM tableWithMapValue") @@ -163,7 +163,7 @@ class InsertIntoHiveTableSuite extends QueryTest { StructField("s", StructType(Seq(StructField("f", StringType, nullable = false)))))) val rowRDD = TestHive.sparkContext.parallelize( (1 to 100).map(i => Row(Row(s"value$i")))) - val df = applySchema(rowRDD, schema) + val df = TestHive.createDataFrame(rowRDD, schema) df.registerTempTable("tableWithStructValue") sql("CREATE TABLE hiveTableWithStructValue(s Struct )") sql("INSERT OVERWRITE TABLE hiveTableWithStructValue SELECT s FROM tableWithStructValue") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 49fe79d98925..9a6e8650a0ec 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.hive.execution import org.apache.spark.sql.hive.HiveShim +import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ import org.apache.spark.sql.types._ import org.apache.spark.sql.{QueryTest, Row, SQLConf} @@ -34,6 +35,7 @@ case class Nested3(f3: Int) class SQLQuerySuite extends QueryTest { import org.apache.spark.sql.hive.test.TestHive.implicits._ + val sqlCtx = TestHive test("SPARK-4512 Fix attribute reference resolution error when using SORT BY") { checkAnswer( @@ -277,7 +279,7 @@ class SQLQuerySuite extends QueryTest { val rowRdd = sparkContext.parallelize(row :: Nil) - applySchema(rowRdd, schema).registerTempTable("testTable") + sqlCtx.createDataFrame(rowRdd, schema).registerTempTable("testTable") sql( """CREATE TABLE nullValuesInInnerComplexTypes