diff --git a/sql/core/src/main/scala/org/apache/spark/sql/schema/SchemaUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/schema/SchemaUtils.scala
new file mode 100644
index 000000000000..324cf668eaa0
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/schema/SchemaUtils.scala
@@ -0,0 +1,53 @@
+package org.apache.spark.sql.schema
+
+import org.apache.spark.sql._
+import org.apache.spark.sql.types._
+
+object SchemaUtils {
+
+  /**
+   * Replaces every NullType with StringType in the schema of the input DataFrame.
+   * @param dataFrame the DataFrame whose schema should be rewritten
+   * @param sqlContext the SQLContext used to rebuild the DataFrame
+   * @return a DataFrame backed by the same rows, with the rewritten schema
+   */
+  def replaceNullTypeToStringType(dataFrame: DataFrame, sqlContext: SQLContext): DataFrame = {
+    require(dataFrame != null, "dataFrame cannot be null")
+    require(sqlContext != null, "sqlContext cannot be null")
+
+    val schema = getStructType(dataFrame.schema)
+    sqlContext.createDataFrame(dataFrame.rdd, schema)
+  }
+
+  /**
+   * Replaces every NullType with StringType in the input schema.
+   * @param st the schema to rewrite
+   * @return a copy of the schema with all NullType fields replaced by StringType
+   */
+  def getStructType(st: StructType): StructType = {
+    require(st != null, "StructType cannot be null")
+
+    StructType(st.fields.map { field =>
+      StructField(field.name, getDataType(field.dataType), field.nullable, field.metadata)
+    })
+  }
+
+  /**
+   * Recursively rewrites a single data type, descending into maps, arrays
+   * and structs so that nested NullType fields are also replaced.
+   * @param dataType the data type to rewrite
+   * @return the rewritten data type
+   */
+  private def getDataType(dataType: DataType): DataType = dataType match {
+    case mt: MapType =>
+      MapType(getDataType(mt.keyType), getDataType(mt.valueType), mt.valueContainsNull)
+    case at: ArrayType =>
+      ArrayType(getDataType(at.elementType), at.containsNull)
+    case st: StructType =>
+      getStructType(st)
+    case NullType =>
+      StringType
+    case other =>
+      other
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/schema/TestSchemaUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/schema/TestSchemaUtils.scala
new file mode 100644
index 000000000000..4bd0204e1111
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/schema/TestSchemaUtils.scala
@@ -0,0 +1,35 @@
+package org.apache.spark.sql.schema
+
+import org.apache.spark.sql.types._
+import org.scalatest.{FlatSpec, Matchers}
+
+class TestSchemaUtils extends FlatSpec with Matchers {
+
+  "getStructType" should "replace NullType with StringType inside an array of structs" in {
+    val input = StructType(Array(
+      StructField("value", ArrayType(StructType(Array(
+        StructField("seqId", IntegerType, true),
+        StructField("value", NullType, true))), false), true)))
+    val expected = StructType(Array(
+      StructField("value", ArrayType(StructType(Array(
+        StructField("seqId", IntegerType, true),
+        StructField("value", StringType, true))), false), true)))
+    SchemaUtils.getStructType(input) should be (expected)
+  }
+
+  it should "replace NullType with StringType inside a nested struct" in {
+    val input = StructType(Array(
+      StructField("additionalStrap", StructType(Array(
+        StructField("seqId", IntegerType, true),
+        StructField("isGlobal", BooleanType, true),
+        StructField("label", NullType, true),
+        StructField("name", StringType, true))), true)))
+    val expected = StructType(Array(
+      StructField("additionalStrap", StructType(Array(
+        StructField("seqId", IntegerType, true),
+        StructField("isGlobal", BooleanType, true),
+        StructField("label", StringType, true),
+        StructField("name", StringType, true))), true)))
+    SchemaUtils.getStructType(input) should be (expected)
+  }
+}