@@ -181,7 +181,7 @@ object DataType {
181181 /**
182182 * Compares two types, ignoring nullability of ArrayType, MapType, StructType.
183183 */
184- private [spark ] def equalsIgnoreNullability (left : DataType , right : DataType ): Boolean = {
184+ private [types ] def equalsIgnoreNullability (left : DataType , right : DataType ): Boolean = {
185185 (left, right) match {
186186 case (ArrayType (leftElementType, _), ArrayType (rightElementType, _)) =>
187187 equalsIgnoreNullability(leftElementType, rightElementType)
@@ -213,7 +213,7 @@ object DataType {
213213 * if and only if for all every pair of fields, `to.nullable` is true, or both
214214 * of `fromField.nullable` and `toField.nullable` are false.
215215 */
216- private [spark ] def equalsIgnoreCompatibleNullability (from : DataType , to : DataType ): Boolean = {
216+ private [sql ] def equalsIgnoreCompatibleNullability (from : DataType , to : DataType ): Boolean = {
217217 (from, to) match {
218218 case (ArrayType (fromElement, fn), ArrayType (toElement, tn)) =>
219219 (tn || ! fn) && equalsIgnoreCompatibleNullability(fromElement, toElement)
@@ -235,20 +235,6 @@ object DataType {
235235 case (fromDataType, toDataType) => fromDataType == toDataType
236236 }
237237 }
238-
239- /** Sets all nullable/containsNull/valueContainsNull to true. */
240- private [spark] def alwaysNullable (dataType : DataType ): DataType = dataType match {
241- case ArrayType (elementType, _) =>
242- ArrayType (alwaysNullable(elementType), containsNull = true )
243- case MapType (keyType, valueType, _) =>
244- MapType (alwaysNullable(keyType), alwaysNullable(valueType), valueContainsNull = true )
245- case StructType (fields) =>
246- val newFields = fields.map { field =>
247- StructField (field.name, alwaysNullable(field.dataType), nullable = true )
248- }
249- StructType (newFields)
250- case other => other
251- }
252238}
253239
254240
@@ -281,6 +267,16 @@ abstract class DataType {
281267 def prettyJson : String = pretty(render(jsonValue))
282268
283269 def simpleString : String = typeName
270+
271+ /** Check if `this` and `other` are the same data type when ignoring nullability
272+ * (`StructField.nullable`, `ArrayType.containsNull`, and `MapType.valueContainsNull`).
273+ */
274+ def sameType (other : DataType ): Boolean = DataType .equalsIgnoreNullability(this , other)
275+
276+ /** Returns the same data type but set all nullability fields are true
277+ * (`StructField.nullable`, `ArrayType.containsNull`, and `MapType.valueContainsNull`).
278+ */
279+ def asNullable : DataType
284280}
285281
286282/**
@@ -296,6 +292,8 @@ class NullType private() extends DataType {
296292 // this type. Otherwise, the companion object would be of type "NullType$" in byte code.
297293 // Defined with a private constructor so the companion object is the only possible instantiation.
298294 override def defaultSize : Int = 1
295+
296+ override def asNullable : NullType = this
299297}
300298
301299case object NullType extends NullType
@@ -361,6 +359,8 @@ class StringType private() extends NativeType with PrimitiveType {
361359 * The default size of a value of the StringType is 4096 bytes.
362360 */
363361 override def defaultSize : Int = 4096
362+
363+ override def asNullable : StringType = this
364364}
365365
366366case object StringType extends StringType
@@ -395,6 +395,8 @@ class BinaryType private() extends NativeType with PrimitiveType {
395395 * The default size of a value of the BinaryType is 4096 bytes.
396396 */
397397 override def defaultSize : Int = 4096
398+
399+ override def asNullable : BinaryType = this
398400}
399401
400402case object BinaryType extends BinaryType
@@ -420,6 +422,8 @@ class BooleanType private() extends NativeType with PrimitiveType {
420422 * The default size of a value of the BooleanType is 1 byte.
421423 */
422424 override def defaultSize : Int = 1
425+
426+ override def asNullable : BooleanType = this
423427}
424428
425429case object BooleanType extends BooleanType
@@ -450,6 +454,8 @@ class TimestampType private() extends NativeType {
450454 * The default size of a value of the TimestampType is 12 bytes.
451455 */
452456 override def defaultSize : Int = 12
457+
458+ override def asNullable : TimestampType = this
453459}
454460
455461case object TimestampType extends TimestampType
@@ -478,6 +484,8 @@ class DateType private() extends NativeType {
478484 * The default size of a value of the DateType is 4 bytes.
479485 */
480486 override def defaultSize : Int = 4
487+
488+ override def asNullable : DateType = this
481489}
482490
483491case object DateType extends DateType
@@ -536,6 +544,8 @@ class LongType private() extends IntegralType {
536544 override def defaultSize : Int = 8
537545
538546 override def simpleString = " bigint"
547+
548+ override def asNullable : LongType = this
539549}
540550
541551case object LongType extends LongType
@@ -565,6 +575,8 @@ class IntegerType private() extends IntegralType {
565575 override def defaultSize : Int = 4
566576
567577 override def simpleString = " int"
578+
579+ override def asNullable : IntegerType = this
568580}
569581
570582case object IntegerType extends IntegerType
@@ -594,6 +606,8 @@ class ShortType private() extends IntegralType {
594606 override def defaultSize : Int = 2
595607
596608 override def simpleString = " smallint"
609+
610+ override def asNullable : ShortType = this
597611}
598612
599613case object ShortType extends ShortType
@@ -623,6 +637,8 @@ class ByteType private() extends IntegralType {
623637 override def defaultSize : Int = 1
624638
625639 override def simpleString = " tinyint"
640+
641+ override def asNullable : ByteType = this
626642}
627643
628644case object ByteType extends ByteType
@@ -689,6 +705,8 @@ case class DecimalType(precisionInfo: Option[PrecisionInfo]) extends FractionalT
689705 case Some (PrecisionInfo (precision, scale)) => s " decimal( $precision, $scale) "
690706 case None => " decimal(10,0)"
691707 }
708+
709+ override def asNullable : DecimalType = this
692710}
693711
694712
@@ -747,6 +765,8 @@ class DoubleType private() extends FractionalType {
747765 * The default size of a value of the DoubleType is 8 bytes.
748766 */
749767 override def defaultSize : Int = 8
768+
769+ override def asNullable : DoubleType = this
750770}
751771
752772case object DoubleType extends DoubleType
@@ -775,6 +795,8 @@ class FloatType private() extends FractionalType {
775795 * The default size of a value of the FloatType is 4 bytes.
776796 */
777797 override def defaultSize : Int = 4
798+
799+ override def asNullable : FloatType = this
778800}
779801
780802case object FloatType extends FloatType
@@ -823,6 +845,8 @@ case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataT
823845 override def defaultSize : Int = 100 * elementType.defaultSize
824846
825847 override def simpleString = s " array< ${elementType.simpleString}> "
848+
849+ override def asNullable : ArrayType = ArrayType (elementType.asNullable, containsNull = true )
826850}
827851
828852
@@ -1068,6 +1092,15 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru
10681092 */
10691093 private [sql] def merge (that : StructType ): StructType =
10701094 StructType .merge(this , that).asInstanceOf [StructType ]
1095+
1096+ override def asNullable : StructType = {
1097+ val newFields = fields.map {
1098+ case StructField (name, dataType, nullable, metadata) =>
1099+ StructField (name, dataType.asNullable, nullable = true , metadata)
1100+ }
1101+
1102+ StructType (newFields)
1103+ }
10711104}
10721105
10731106
@@ -1120,6 +1153,9 @@ case class MapType(
11201153 override def defaultSize : Int = 100 * (keyType.defaultSize + valueType.defaultSize)
11211154
11221155 override def simpleString = s " map< ${keyType.simpleString}, ${valueType.simpleString}> "
1156+
1157+ override def asNullable : MapType =
1158+ MapType (keyType.asNullable, valueType.asNullable, valueContainsNull = true )
11231159}
11241160
11251161
@@ -1173,4 +1209,8 @@ abstract class UserDefinedType[UserType] extends DataType with Serializable {
11731209 * The default size of a value of the UserDefinedType is 4096 bytes.
11741210 */
11751211 override def defaultSize : Int = 4096
1212+
1213+ override def sameType (other : DataType ): Boolean = ???
1214+
1215+ override def asNullable : DataType = ???
11761216}
0 commit comments