10 changes: 10 additions & 0 deletions python/pyspark/sql/readwriter.py
@@ -153,6 +153,16 @@ def json(self, path, schema=None):
or RDD of Strings storing JSON objects.
:param schema: an optional :class:`StructType` for the input schema.

You can set the following JSON-specific options to deal with non-standard JSON files:
* ``primitivesAsString`` (default ``false``): infers all primitive values as a string \
type
* ``allowComments`` (default ``false``): ignores Java/C++ style comments in JSON records
* ``allowUnquotedFieldNames`` (default ``false``): allows unquoted JSON field names
* ``allowSingleQuotes`` (default ``true``): allows single quotes in addition to double \
quotes
* ``allowNumericLeadingZeros`` (default ``false``): allows leading zeros in numbers \
(e.g. 00012)

>>> df1 = sqlContext.read.json('python/test_support/sql/people.json')
>>> df1.dtypes
[('age', 'bigint'), ('name', 'string')]
22 changes: 14 additions & 8 deletions sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -29,7 +29,7 @@ import org.apache.spark.api.java.JavaRDD
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.execution.datasources.jdbc.{JDBCPartition, JDBCPartitioningInfo, JDBCRelation}
import org.apache.spark.sql.execution.datasources.json.JSONRelation
import org.apache.spark.sql.execution.datasources.json.{JSONOptions, JSONRelation}
import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation
import org.apache.spark.sql.execution.datasources.{LogicalRelation, ResolvedDataSource}
import org.apache.spark.sql.types.StructType
@@ -227,6 +227,15 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging {
* This function goes through the input once to determine the input schema. If you know the
* schema in advance, use the version that specifies the schema to avoid the extra scan.
*
* You can set the following JSON-specific options to deal with non-standard JSON files:
* <li>`primitivesAsString` (default `false`): infers all primitive values as a string type</li>
* <li>`allowComments` (default `false`): ignores Java/C++ style comments in JSON records</li>
* <li>`allowUnquotedFieldNames` (default `false`): allows unquoted JSON field names</li>
* <li>`allowSingleQuotes` (default `true`): allows single quotes in addition to double quotes
* </li>
* <li>`allowNumericLeadingZeros` (default `false`): allows leading zeros in numbers
* (e.g. 00012)</li>
Contributor: Add samplingRatio?

Contributor Author: I think we skipped it in the past because it had very little impact on performance, so in most cases it is better to just use 1.0... Maybe we should even deprecate that option.

*
* @param path input path
* @since 1.4.0
*/
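A minimal usage sketch of these options (option names taken from the list above; the `sqlContext` and input path are assumed for illustration):

val df = sqlContext.read
  .option("primitivesAsString", "true")
  .option("allowComments", "true")
  .json("path/to/people.json")  // hypothetical path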
@@ -255,16 +264,13 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging {
* @since 1.4.0
*/
def json(jsonRDD: RDD[String]): DataFrame = {
val samplingRatio = extraOptions.getOrElse("samplingRatio", "1.0").toDouble
val primitivesAsString = extraOptions.getOrElse("primitivesAsString", "false").toBoolean
sqlContext.baseRelationToDataFrame(
new JSONRelation(
Some(jsonRDD),
samplingRatio,
primitivesAsString,
userSpecifiedSchema,
None,
None)(sqlContext)
maybeDataSchema = userSpecifiedSchema,
maybePartitionSpec = None,
userDefinedPartitionColumns = None,
parameters = extraOptions.toMap)(sqlContext)
)
}
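Because the reader's extraOptions map is now forwarded wholesale as `parameters`, the same options apply when the input is an RDD of JSON strings. A hedged sketch with made-up, non-standard records (unquoted field names and leading zeros are rejected by default):

val lines = sqlContext.sparkContext.parallelize(Seq(
  "{name: 'Alice', age: 007}"))  // unquoted field name, single quotes, leading zeros
val df = sqlContext.read
  .option("allowUnquotedFieldNames", "true")
  .option("allowNumericLeadingZeros", "true")
  .json(lines)  // allowSingleQuotes already defaults to true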

@@ -221,22 +221,6 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ

private[this] def isTesting: Boolean = sys.props.contains("spark.testing")

protected def newProjection(
expressions: Seq[Expression], inputSchema: Seq[Attribute]): Projection = {
log.debug(s"Creating Projection: $expressions, inputSchema: $inputSchema")
try {
GenerateProjection.generate(expressions, inputSchema)
} catch {
case e: Exception =>
if (isTesting) {
throw e
} else {
log.error("Failed to generate projection, fallback to interpret", e)
new InterpretedProjection(expressions, inputSchema)
}
}
}

protected def newMutableProjection(
Contributor Author: this is now unused.

expressions: Seq[Expression], inputSchema: Seq[Attribute]): () => MutableProjection = {
log.debug(s"Creating MutableProj: $expressions, inputSchema: $inputSchema")
@@ -282,6 +266,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
}
}
}

/**
* Creates a row ordering for the given schema, in natural ascending order.
*/
@@ -25,33 +25,36 @@ import org.apache.spark.sql.execution.datasources.json.JacksonUtils.nextUntil
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils

private[sql] object InferSchema {

private[json] object InferSchema {

/**
* Infer the type of a collection of json records in three stages:
* 1. Infer the type of each record
* 2. Merge types by choosing the lowest type necessary to cover equal keys
* 3. Replace any remaining null fields with string, the top type
*/
def apply(
def infer(
Contributor Author: it was super confusing for apply to return RDD (i.e. it is not a factory method).

json: RDD[String],
samplingRatio: Double = 1.0,
columnNameOfCorruptRecords: String,
primitivesAsString: Boolean = false): StructType = {
require(samplingRatio > 0, s"samplingRatio ($samplingRatio) should be greater than 0")
val schemaData = if (samplingRatio > 0.99) {
configOptions: JSONOptions): StructType = {
require(configOptions.samplingRatio > 0,
s"samplingRatio (${configOptions.samplingRatio}) should be greater than 0")
val schemaData = if (configOptions.samplingRatio > 0.99) {
json
} else {
json.sample(withReplacement = false, samplingRatio, 1)
json.sample(withReplacement = false, configOptions.samplingRatio, 1)
}

// perform schema inference on each row and merge afterwards
val rootType = schemaData.mapPartitions { iter =>
val factory = new JsonFactory()
configOptions.setJacksonOptions(factory)
iter.map { row =>
try {
Utils.tryWithResource(factory.createParser(row)) { parser =>
parser.nextToken()
inferField(parser, primitivesAsString)
inferField(parser, configOptions)
}
} catch {
case _: JsonParseException =>
@@ -71,14 +74,14 @@ private[sql] object InferSchema {
/**
* Infer the type of a json document from the parser's token stream
*/
private def inferField(parser: JsonParser, primitivesAsString: Boolean): DataType = {
private def inferField(parser: JsonParser, configOptions: JSONOptions): DataType = {
import com.fasterxml.jackson.core.JsonToken._
parser.getCurrentToken match {
case null | VALUE_NULL => NullType

case FIELD_NAME =>
parser.nextToken()
inferField(parser, primitivesAsString)
inferField(parser, configOptions)

case VALUE_STRING if parser.getTextLength < 1 =>
// Zero length strings and nulls have special handling to deal
@@ -95,7 +98,7 @@
while (nextUntil(parser, END_OBJECT)) {
builder += StructField(
parser.getCurrentName,
inferField(parser, primitivesAsString),
inferField(parser, configOptions),
nullable = true)
}

@@ -107,14 +110,15 @@
// the type as we pass through all JSON objects.
var elementType: DataType = NullType
while (nextUntil(parser, END_ARRAY)) {
elementType = compatibleType(elementType, inferField(parser, primitivesAsString))
elementType = compatibleType(
elementType, inferField(parser, configOptions))
}

ArrayType(elementType)

case (VALUE_NUMBER_INT | VALUE_NUMBER_FLOAT) if primitivesAsString => StringType
case (VALUE_NUMBER_INT | VALUE_NUMBER_FLOAT) if configOptions.primitivesAsString => StringType

case (VALUE_TRUE | VALUE_FALSE) if primitivesAsString => StringType
case (VALUE_TRUE | VALUE_FALSE) if configOptions.primitivesAsString => StringType

case VALUE_NUMBER_INT | VALUE_NUMBER_FLOAT =>
import JsonParser.NumberType._
@@ -178,7 +182,7 @@ private[sql] object InferSchema {
/**
* Returns the most general data type for two given data types.
*/
private[json] def compatibleType(t1: DataType, t2: DataType): DataType = {
def compatibleType(t1: DataType, t2: DataType): DataType = {
HiveTypeCoercion.findTightestCommonTypeOfTwo(t1, t2).getOrElse {
// t1 or t2 is a StructType, ArrayType, or an unexpected type.
(t1, t2) match {
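To make the merge step above concrete: when two records disagree on a primitive type and `findTightestCommonTypeOfTwo` has no answer, `compatibleType` falls back to `StringType`, the top type. A small sketch (the expected schema is my reading of the code, not a captured output):

val mixed = sqlContext.sparkContext.parallelize(Seq(
  """{"a": 1}""",
  """{"a": "one"}"""))
sqlContext.read.json(mixed).printSchema()
// expected: `a` is inferred as string, since bigint and string widen to the top type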
@@ -0,0 +1,64 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.datasources.json

import com.fasterxml.jackson.core.{JsonParser, JsonFactory}

/**
* Options for the JSON data source.
*
* Most of these map directly to Jackson's internal options, specified in [[JsonParser.Feature]].
*/
case class JSONOptions(
samplingRatio: Double = 1.0,
primitivesAsString: Boolean = false,
allowComments: Boolean = false,
allowUnquotedFieldNames: Boolean = false,
allowSingleQuotes: Boolean = true,
allowNumericLeadingZeros: Boolean = false,
allowNonNumericNumbers: Boolean = false) {
Contributor Author: allowNonNumericNumbers is undocumented for now, since I can't figure out how it works.


/** Sets config options on a Jackson [[JsonFactory]]. */
def setJacksonOptions(factory: JsonFactory): Unit = {
factory.configure(JsonParser.Feature.ALLOW_COMMENTS, allowComments)
factory.configure(JsonParser.Feature.ALLOW_UNQUOTED_FIELD_NAMES, allowUnquotedFieldNames)
factory.configure(JsonParser.Feature.ALLOW_SINGLE_QUOTES, allowSingleQuotes)
factory.configure(JsonParser.Feature.ALLOW_NUMERIC_LEADING_ZEROS, allowNumericLeadingZeros)
factory.configure(JsonParser.Feature.ALLOW_NON_NUMERIC_NUMBERS, allowNonNumericNumbers)
}
}


object JSONOptions {
def createFromConfigMap(parameters: Map[String, String]): JSONOptions = JSONOptions(
samplingRatio =
parameters.get("samplingRatio").map(_.toDouble).getOrElse(1.0),
primitivesAsString =
parameters.get("primitivesAsString").map(_.toBoolean).getOrElse(false),
allowComments =
parameters.get("allowComments").map(_.toBoolean).getOrElse(false),
allowUnquotedFieldNames =
parameters.get("allowUnquotedFieldNames").map(_.toBoolean).getOrElse(false),
allowSingleQuotes =
parameters.get("allowSingleQuotes").map(_.toBoolean).getOrElse(true),
allowNumericLeadingZeros =
parameters.get("allowNumericLeadingZeros").map(_.toBoolean).getOrElse(false),
allowNonNumericNumbers =
parameters.get("allowNonNumericNumbers").map(_.toBoolean).getOrElse(true)
)
}
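A sketch of the intended flow inside the json package: relation code builds a single `JSONOptions` from the user-supplied parameter map and then configures each Jackson factory from it (the map keys mirror the user-facing option names; the values here are illustrative):

import com.fasterxml.jackson.core.JsonFactory

val opts = JSONOptions.createFromConfigMap(Map(
  "samplingRatio" -> "0.5",
  "allowComments" -> "true"))
// absent keys fall back to the getOrElse defaults above

val factory = new JsonFactory()
opts.setJacksonOptions(factory)  // enables ALLOW_COMMENTS on parsers created from this factory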
@@ -52,13 +52,9 @@ class DefaultSource extends HadoopFsRelationProvider with DataSourceRegister {
dataSchema: Option[StructType],
partitionColumns: Option[StructType],
parameters: Map[String, String]): HadoopFsRelation = {
val samplingRatio = parameters.get("samplingRatio").map(_.toDouble).getOrElse(1.0)
val primitivesAsString = parameters.get("primitivesAsString").map(_.toBoolean).getOrElse(false)

new JSONRelation(
inputRDD = None,
samplingRatio = samplingRatio,
primitivesAsString = primitivesAsString,
maybeDataSchema = dataSchema,
maybePartitionSpec = None,
userDefinedPartitionColumns = partitionColumns,
@@ -69,8 +65,6 @@ class DefaultSource extends HadoopFsRelationProvider with DataSourceRegister {

private[sql] class JSONRelation(
val inputRDD: Option[RDD[String]],
val samplingRatio: Double,
val primitivesAsString: Boolean,
val maybeDataSchema: Option[StructType],
val maybePartitionSpec: Option[PartitionSpec],
override val userDefinedPartitionColumns: Option[StructType],
@@ -79,6 +73,8 @@ private[sql] class JSONRelation(
(@transient val sqlContext: SQLContext)
extends HadoopFsRelation(maybePartitionSpec, parameters) {

val options: JSONOptions = JSONOptions.createFromConfigMap(parameters)

/** Constraints to be imposed on schema to be stored. */
private def checkConstraints(schema: StructType): Unit = {
if (schema.fieldNames.length != schema.fieldNames.distinct.length) {
@@ -109,17 +105,16 @@ private[sql] class JSONRelation(
classOf[Text]).map(_._2.toString) // get the text line
}

override lazy val dataSchema = {
override lazy val dataSchema: StructType = {
val jsonSchema = maybeDataSchema.getOrElse {
val files = cachedLeafStatuses().filterNot { status =>
val name = status.getPath.getName
name.startsWith("_") || name.startsWith(".")
}.toArray
InferSchema(
InferSchema.infer(
inputRDD.getOrElse(createBaseRdd(files)),
samplingRatio,
sqlContext.conf.columnNameOfCorruptRecord,
primitivesAsString)
options)
}
checkConstraints(jsonSchema)

@@ -132,10 +127,11 @@ private[sql] class JSONRelation(
inputPaths: Array[FileStatus],
broadcastedConf: Broadcast[SerializableConfiguration]): RDD[InternalRow] = {
val requiredDataSchema = StructType(requiredColumns.map(dataSchema(_)))
val rows = JacksonParser(
val rows = JacksonParser.parse(
inputRDD.getOrElse(createBaseRdd(inputPaths)),
requiredDataSchema,
sqlContext.conf.columnNameOfCorruptRecord)
sqlContext.conf.columnNameOfCorruptRecord,
options)

rows.mapPartitions { iterator =>
val unsafeProjection = UnsafeProjection.create(requiredDataSchema)
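The same `parameters` map reaches `DefaultSource.createRelation` through the generic data source API, so the options also work with `format("json")`/`load`; a minimal sketch (the path is hypothetical):

val df = sqlContext.read
  .format("json")
  .option("allowNumericLeadingZeros", "true")
  .load("path/to/events.json")  // hypothetical path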