
Commit 69adca5

more fixes
1 parent 671471f commit 69adca5

9 files changed: +25 -26 lines changed


docs/sql-ref-datatypes.md

Lines changed: 1 addition & 1 deletion
@@ -38,7 +38,7 @@ Spark SQL and DataFrames support the following data types:
 * String type
   - `StringType`: Represents character string values.
   - `VarcharType(length)`: A variant of `StringType` which has a length limitation. Data writing will fail if the input string exceeds the length limitation. Note: this type can only be used in table schema, not functions/operators.
-  - `CharType(length)`: A variant of `VarcharType(length)` which is fixed length. Reading column of type `VarcharType(n)` always returns string values of length `n`. Char type column comparison will pad the short one to the longer length.
+  - `CharType(length)`: A variant of `VarcharType(length)` which is fixed length. Reading column of type `CharType(n)` always returns string values of length `n`. Char type column comparison will pad the short one to the longer length.
 * Binary type
   - `BinaryType`: Represents byte sequence values.
 * Boolean type
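
For reference, the read-side semantics described above can be exercised directly in SQL; a minimal, illustrative sketch (assuming the `spark` session provided by `spark-shell` and a throwaway table name `char_demo`):

```scala
// CHAR(5)/VARCHAR(5) may only appear in a table schema, per the note above.
spark.sql("CREATE TABLE char_demo (c CHAR(5), v VARCHAR(5)) USING parquet")
spark.sql("INSERT INTO char_demo VALUES ('abc', 'abc')")

// Reading a CHAR(5) column always yields strings of length 5 ('abc' comes back padded),
// while the VARCHAR(5) value keeps its original length.
spark.sql("SELECT c, length(c), v, length(v) FROM char_demo").show()
```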

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala

Lines changed: 3 additions & 1 deletion
@@ -99,7 +99,9 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg
   }
 
   override def visitSingleTableSchema(ctx: SingleTableSchemaContext): StructType = {
-    withOrigin(ctx)(StructType(visitColTypeList(ctx.colTypeList)))
+    val schema = CharVarcharUtils.replaceCharVarcharWithStringInSchema(
+      StructType(visitColTypeList(ctx.colTypeList)))
+    withOrigin(ctx)(schema)
   }
 
   def parseRawDataType(ctx: SingleDataTypeContext): DataType = withOrigin(ctx) {

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CharVarcharUtils.scala

Lines changed: 3 additions & 2 deletions
@@ -137,8 +137,9 @@ object CharVarcharUtils {
   }
 
   /**
-   * Returns an expression to apply write-side char type padding for the given expression. A string
-   * value can not exceed N characters if it's written into a CHAR(N)/VARCHAR(N) column/field.
+   * Returns an expression to apply write-side string length check for the given expression. A
+   * string value can not exceed N characters if it's written into a CHAR(N)/VARCHAR(N)
+   * column/field.
    */
   def stringLengthCheck(expr: Expression, targetAttr: Attribute): Expression = {
     getRawType(targetAttr.metadata).map { rawType =>
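
The updated scaladoc describes a write-side length check rather than padding; what that check means for users can be sketched as follows (illustrative only, assuming the `spark` session from `spark-shell` and a throwaway table name `len_demo`):

```scala
spark.sql("CREATE TABLE len_demo (v VARCHAR(5)) USING parquet")

// A value of at most 5 characters passes the write-side length check.
spark.sql("INSERT INTO len_demo VALUES ('hello')")

// A longer value is rejected at write time; uncomment to observe the error.
// spark.sql("INSERT INTO len_demo VALUES ('too long')")
```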

sql/catalyst/src/main/scala/org/apache/spark/sql/types/CharType.scala

Lines changed: 1 addition & 1 deletion
@@ -25,7 +25,7 @@ import org.apache.spark.unsafe.types.UTF8String
 
 @Experimental
 case class CharType(length: Int) extends AtomicType {
-  require(length >= 0, "The length if char type cannot be negative.")
+  require(length >= 0, "The length of char type cannot be negative.")
 
   private[sql] type InternalType = UTF8String
   @transient private[sql] lazy val tag = typeTag[InternalType]

sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala

Lines changed: 1 addition & 3 deletions
@@ -32,7 +32,6 @@ import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.analysis.Resolver
 import org.apache.spark.sql.catalyst.expressions.{Cast, Expression}
 import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
-import org.apache.spark.sql.catalyst.util.CharVarcharUtils
 import org.apache.spark.sql.catalyst.util.DataTypeJsonUtils.{DataTypeJsonDeserializer, DataTypeJsonSerializer}
 import org.apache.spark.sql.catalyst.util.StringUtils.StringConcat
 import org.apache.spark.sql.internal.SQLConf
@@ -133,8 +132,7 @@ object DataType {
       ddl,
       CatalystSqlParser.parseDataType,
       "Cannot parse the data type: ",
-      fallbackParser = str => CharVarcharUtils.replaceCharVarcharWithStringInSchema(
-        CatalystSqlParser.parseTableSchema(str)))
+      fallbackParser = str => CatalystSqlParser.parseTableSchema(str))
   }
 
   /**

sql/catalyst/src/main/scala/org/apache/spark/sql/types/VarcharType.scala

Lines changed: 1 addition & 1 deletion
@@ -24,7 +24,7 @@ import org.apache.spark.unsafe.types.UTF8String
 
 @Experimental
 case class VarcharType(length: Int) extends AtomicType {
-  require(length >= 0, "The length if varchar type cannot be negative.")
+  require(length >= 0, "The length of varchar type cannot be negative.")
 
   private[sql] type InternalType = UTF8String
   @transient private[sql] lazy val tag = typeTag[InternalType]

sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala

Lines changed: 4 additions & 10 deletions
@@ -73,7 +73,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
    * @since 1.4.0
    */
   def schema(schema: StructType): DataFrameReader = {
-    this.userSpecifiedSchema = Option(schema)
+    this.userSpecifiedSchema = Option(CharVarcharUtils.replaceCharVarcharWithStringInSchema(schema))
     this
   }
 
@@ -274,14 +274,11 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
       extraOptions + ("paths" -> objectMapper.writeValueAsString(paths.toArray))
     }
 
-    val cleanedUserSpecifiedSchema = userSpecifiedSchema
-      .map(CharVarcharUtils.replaceCharVarcharWithStringInSchema)
-
     val finalOptions = sessionOptions.filterKeys(!optionsWithPath.contains(_)).toMap ++
       optionsWithPath.originalMap
     val dsOptions = new CaseInsensitiveStringMap(finalOptions.asJava)
     val (table, catalog, ident) = provider match {
-      case _: SupportsCatalogOptions if cleanedUserSpecifiedSchema.nonEmpty =>
+      case _: SupportsCatalogOptions if userSpecifiedSchema.nonEmpty =>
         throw new IllegalArgumentException(
           s"$source does not support user specified schema. Please don't specify the schema.")
       case hasCatalog: SupportsCatalogOptions =>
@@ -293,8 +290,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
         (catalog.loadTable(ident), Some(catalog), Some(ident))
       case _ =>
         // TODO: Non-catalog paths for DSV2 are currently not well defined.
-        val tbl = DataSourceV2Utils.getTableFromProvider(
-          provider, dsOptions, cleanedUserSpecifiedSchema)
+        val tbl = DataSourceV2Utils.getTableFromProvider(provider, dsOptions, userSpecifiedSchema)
         (tbl, None, None)
     }
     import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Implicits._
@@ -316,15 +312,13 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
     } else {
       (paths, extraOptions)
     }
-    val cleanedUserSpecifiedSchema = userSpecifiedSchema
-      .map(CharVarcharUtils.replaceCharVarcharWithStringInSchema)
 
     // Code path for data source v1.
     sparkSession.baseRelationToDataFrame(
       DataSource.apply(
         sparkSession,
         paths = finalPaths,
-        userSpecifiedSchema = cleanedUserSpecifiedSchema,
+        userSpecifiedSchema = userSpecifiedSchema,
         className = source,
         options = finalOptions.originalMap).resolveRelation())
   }
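
With the replacement now applied inside `schema()` itself, a user-specified CHAR/VARCHAR schema surfaces as plain strings regardless of the read path. A minimal sketch (assuming the `spark` session from `spark-shell`; the column name and input Dataset are illustrative):

```scala
import org.apache.spark.sql.types.StringType
import spark.implicits._

// An in-memory Dataset[String] stands in for a CSV file, mirroring the new test below.
val input = spark.range(3).map(_.toString)

// The DDL-string and StructType forms of schema() behave the same way for CHAR/VARCHAR.
val df = spark.read.schema("name CHAR(10)").csv(input)

// The column surfaces as StringType in the DataFrame schema.
assert(df.schema.map(_.dataType) == Seq(StringType))
```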

sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala

Lines changed: 3 additions & 7 deletions
@@ -64,7 +64,7 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo
    * @since 2.0.0
    */
   def schema(schema: StructType): DataStreamReader = {
-    this.userSpecifiedSchema = Option(schema)
+    this.userSpecifiedSchema = Option(CharVarcharUtils.replaceCharVarcharWithStringInSchema(schema))
     this
   }
 
@@ -203,17 +203,14 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo
       extraOptions + ("path" -> path.get)
     }
 
-    val cleanedUserSpecifiedSchema = userSpecifiedSchema
-      .map(CharVarcharUtils.replaceCharVarcharWithStringInSchema)
-
     val ds = DataSource.lookupDataSource(source, sparkSession.sqlContext.conf).
       getConstructor().newInstance()
     // We need to generate the V1 data source so we can pass it to the V2 relation as a shim.
     // We can't be sure at this point whether we'll actually want to use V2, since we don't know the
     // writer or whether the query is continuous.
     val v1DataSource = DataSource(
       sparkSession,
-      userSpecifiedSchema = cleanedUserSpecifiedSchema,
+      userSpecifiedSchema = userSpecifiedSchema,
       className = source,
       options = optionsWithPath.originalMap)
     val v1Relation = ds match {
@@ -228,8 +225,7 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo
         val finalOptions = sessionOptions.filterKeys(!optionsWithPath.contains(_)).toMap ++
           optionsWithPath.originalMap
         val dsOptions = new CaseInsensitiveStringMap(finalOptions.asJava)
-        val table = DataSourceV2Utils.getTableFromProvider(
-          provider, dsOptions, cleanedUserSpecifiedSchema)
+        val table = DataSourceV2Utils.getTableFromProvider(provider, dsOptions, userSpecifiedSchema)
         import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Implicits._
         table match {
           case _: SupportsRead if table.supportsAny(MICRO_BATCH_READ, CONTINUOUS_READ) =>

sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala

Lines changed: 8 additions & 0 deletions
@@ -448,6 +448,14 @@ class BasicCharVarcharTestSuite extends QueryTest with SharedSparkSession {
     assert(schema.map(_.dataType) == Seq(StringType))
   }
 
+  test("user-specified schema in DataFrameReader: file source from Dataset") {
+    val ds = spark.range(10).map(_.toString)
+    val df1 = spark.read.schema(new StructType().add("id", CharType(5))).csv(ds)
+    assert(df1.schema.map(_.dataType) == Seq(StringType))
+    val df2 = spark.read.schema("id char(5)").csv(ds)
+    assert(df2.schema.map(_.dataType) == Seq(StringType))
+  }
+
   test("user-specified schema in DataFrameReader: DSV1") {
     def checkSchema(df: DataFrame): Unit = {
       val relations = df.queryExecution.analyzed.collect {
