[SPARK-51065][SQL] Disallowing non-nullable schema when Avro encoding is used for TransformWithState #49751

Open · wants to merge 8 commits into master
7 changes: 7 additions & 0 deletions common/utils/src/main/resources/error/error-conditions.json
@@ -4669,6 +4669,13 @@
      ],
      "sqlState" : "42K06"
    },
+    "STATE_STORE_SCHEMA_MUST_BE_NULLABLE" : {
+      "message" : [
+        "If schema evolution is enabled, all the fields in the schema for column family <columnFamilyName> must be nullable.",
+        "Please make the schema nullable. Current schema: <schema>"
+      ],
+      "sqlState" : "XXKST"
+    },
    "STATE_STORE_STATE_SCHEMA_FILES_THRESHOLD_EXCEEDED" : {
      "message" : [
        "The number of state schema files <numStateSchemaFiles> exceeds the maximum number of state schema files for this query: <maxStateSchemaFiles>.",
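For illustration, a minimal Python sketch of how this message template renders once its placeholders are substituted; the column family name and schema string below are hypothetical examples, not output captured from Spark:

template = (
    "If schema evolution is enabled, all the fields in the schema for column family "
    "<columnFamilyName> must be nullable. Please make the schema nullable. "
    "Current schema: <schema>"
)
# Substitute the placeholders the way Spark's error framework does.
message = template.replace("<columnFamilyName>", "state").replace(
    "<schema>",
    "StructType(StructField(id,IntegerType,false),StructField(name,StringType,false))",
)
print(f"[STATE_STORE_SCHEMA_MUST_BE_NULLABLE] {message}")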
@@ -1469,7 +1469,37 @@ def check_exception(error):
            df,
            check_exception=check_exception,
        )

    def test_not_nullable_fails(self):
        with self.sql_conf({"spark.sql.streaming.stateStore.encodingFormat": "avro"}):
            with tempfile.TemporaryDirectory() as checkpoint_dir:
                input_path = tempfile.mkdtemp()
                self._prepare_test_resource1(input_path)

                df = self._build_test_df(input_path)

                def check_basic_state(batch_df, batch_id):
                    result = batch_df.collect()[0]
                    assert result.value["id"] == 0  # First ID from test data
                    assert result.value["name"] == "name-0"

                def check_exception(error):
                    from pyspark.errors.exceptions.captured import StreamingQueryException

                    if not isinstance(error, StreamingQueryException):
                        return False

                    error_msg = str(error)
                    return (
                        "[STATE_STORE_SCHEMA_MUST_BE_NULLABLE]" in error_msg
                        and "column family state must be nullable" in error_msg
                    )

                self._run_evolution_test(
                    BasicProcessorNotNullable(),
                    checkpoint_dir,
                    check_basic_state,
                    df,
                    check_exception=check_exception,
                )

class SimpleStatefulProcessorWithInitialState(StatefulProcessor):
    # this dict is the same as input initial state dataframe
@@ -1892,6 +1922,26 @@ def handleInputRows(self, key, rows, timerValues) -> Iterator[pd.DataFrame]:
    def close(self) -> None:
        pass

class BasicProcessorNotNullable(StatefulProcessor):
    # Schema definition: both fields are deliberately declared non-nullable
    state_schema = StructType(
        [StructField("id", IntegerType(), False), StructField("name", StringType(), False)]
    )

    def init(self, handle):
        self.state = handle.getValueState("state", self.state_schema)

    def handleInputRows(self, key, rows, timerValues) -> Iterator[pd.DataFrame]:
        for pdf in rows:
            pass
        id_val = int(key[0])
        name = f"name-{id_val}"
        self.state.update((id_val, name))
        yield pd.DataFrame({"id": [key[0]], "value": [{"id": id_val, "name": name}]})

    def close(self) -> None:
        pass


class AddFieldsProcessor(StatefulProcessor):
    state_schema = StructType(
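For contrast with BasicProcessorNotNullable above, a minimal sketch of the schema a user would declare instead to satisfy the new check when the Avro encoding format is enabled; the variable name is illustrative:

from pyspark.sql.types import IntegerType, StringType, StructField, StructType

# Nullable variant of BasicProcessorNotNullable's state_schema: the third
# StructField argument declares each field nullable, which is what the
# STATE_STORE_SCHEMA_MUST_BE_NULLABLE error asks for.
nullable_state_schema = StructType(
    [StructField("id", IntegerType(), True), StructField("name", StringType(), True)]
)
assert all(field.nullable for field in nullable_state_schema.fields)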
@@ -117,7 +117,7 @@ case class TransformWithStateInPandasExec(
  override def getColFamilySchemas(
      setNullableFields: Boolean
  ): Map[String, StateStoreColFamilySchema] = {
-    driverProcessorHandle.getColumnFamilySchemas(setNullableFields)
+    driverProcessorHandle.getColumnFamilySchemas(true)
  }

  override def getStateVariableInfos(): Map[String, TransformWithStateVariableInfo] = {
@@ -363,18 +363,25 @@ class DriverStatefulProcessorHandleImpl(timeMode: TimeMode, keyExprEnc: Expressi
    addTimerColFamily()
  }

+  private def isInternal(columnFamilyName: String): Boolean = {
+    columnFamilyName.startsWith("_") || columnFamilyName.startsWith("$")
+  }
+
  def getColumnFamilySchemas(
-      setNullableFields: Boolean
+      shouldCheckNullable: Boolean
  ): Map[String, StateStoreColFamilySchema] = {
    val schemas = columnFamilySchemas.toMap
-    if (setNullableFields) {
-      schemas.map { case (colFamilyName, stateStoreColFamilySchema) =>
-        colFamilyName -> stateStoreColFamilySchema.copy(
-          valueSchema = stateStoreColFamilySchema.valueSchema.toNullable
-        )
-      }
-    } else {
-      schemas
-    }
+    schemas.map { case (colFamilyName, schema) =>
+      // assert that each field is nullable if schema evolution is enabled
+      schema.valueSchema.fields.foreach { field =>
+        if (!field.nullable && shouldCheckNullable && !isInternal(colFamilyName)) {
+          throw StateStoreErrors.stateStoreSchemaMustBeNullable(
+            schema.colFamilyName, schema.valueSchema.toString())
+        }
+      }
+      colFamilyName -> schema.copy(
+        valueSchema = schema.valueSchema.toNullable
+      )
+    }
  }

Contributor, inline review comment on the isInternal check: Why do we need to treat internal col families differently?
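To make the rule above concrete, here is a standalone Python paraphrase (not Spark API; both function names are invented for illustration): non-nullable fields are rejected only for user-defined column families, while internal ones, whose names start with "_" or "$", are exempt:

from pyspark.sql.types import LongType, StructType

def is_internal(column_family_name: str) -> bool:
    # Internal column families are prefixed with "_" or "$".
    return column_family_name.startswith("_") or column_family_name.startswith("$")

def check_nullable(column_family_name: str, value_schema: StructType) -> None:
    # Mirrors the Scala check: reject any non-nullable field
    # belonging to a user-defined column family.
    for field in value_schema.fields:
        if not field.nullable and not is_internal(column_family_name):
            raise ValueError(
                f"[STATE_STORE_SCHEMA_MUST_BE_NULLABLE] {column_family_name}: {value_schema}"
            )

# Internal family: a non-nullable "count" field is allowed.
check_nullable("$rowCounter_listState", StructType().add("count", LongType(), False))

# User-defined family: the same non-nullable field raises.
try:
    check_nullable("state", StructType().add("count", LongType(), False))
except ValueError as err:
    print(err)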

@@ -150,7 +150,7 @@ case class TransformWithStateExec(
      Some(NoPrefixKeyStateEncoderSpec(keySchema)))

    val columnFamilySchemas = getDriverProcessorHandle()
-      .getColumnFamilySchemas(setNullableFields) ++
+      .getColumnFamilySchemas(false) ++
      Map(StateStore.DEFAULT_COL_FAMILY_NAME -> defaultSchema)
    closeProcessorHandle()
    columnFamilySchemas
@@ -145,6 +145,12 @@ object StateStoreErrors {
    new StateStoreValueSchemaNotCompatible(storedValueSchema, newValueSchema)
  }

+  def stateStoreSchemaMustBeNullable(
+      columnFamilyName: String,
+      schema: String): StateStoreSchemaMustBeNullable = {
+    new StateStoreSchemaMustBeNullable(columnFamilyName, schema)
+  }
+
  def stateStoreInvalidValueSchemaEvolution(
      oldValueSchema: String,
      newValueSchema: String): StateStoreInvalidValueSchemaEvolution = {
@@ -346,6 +352,15 @@ class StateStoreValueSchemaNotCompatible(
      "storedValueSchema" -> storedValueSchema,
      "newValueSchema" -> newValueSchema))

+class StateStoreSchemaMustBeNullable(
+    columnFamilyName: String,
+    schema: String)
+  extends SparkUnsupportedOperationException(
+    errorClass = "STATE_STORE_SCHEMA_MUST_BE_NULLABLE",
+    messageParameters = Map(
+      "columnFamilyName" -> columnFamilyName,
+      "schema" -> schema))

class StateStoreInvalidValueSchemaEvolution(
    oldValueSchema: String,
    newValueSchema: String)
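Note that the Python test above matches on the rendered message rather than the error class itself; a small sketch of that predicate in isolation, with a hand-written sample message that follows the error-conditions.json template (the sample is assumed, not captured output):

def matches_nullability_error(error_msg: str) -> bool:
    # Same checks as check_exception in the Python test, applied to a plain string.
    return (
        "[STATE_STORE_SCHEMA_MUST_BE_NULLABLE]" in error_msg
        and "column family state must be nullable" in error_msg
    )

sample = (
    "[STATE_STORE_SCHEMA_MUST_BE_NULLABLE] If schema evolution is enabled, all the fields "
    "in the schema for column family state must be nullable. Please make the schema "
    "nullable. Current schema: StructType(...)"
)
assert matches_nullability_error(sample)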
@@ -1815,7 +1815,6 @@ class TransformWithStateSuite extends StateStoreMetricsTest
      TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString) {
      withTempDir { checkpointDir =>
        // When Avro is used, we want to set the StructFields to nullable
-        val shouldBeNullable = usingAvroEncoding()
        val metadataPathPostfix = "state/0/_stateSchema/default"
        val stateSchemaPath = new Path(checkpointDir.toString,
          s"$metadataPathPostfix")
@@ -1826,15 +1825,15 @@
        val schema0 = StateStoreColFamilySchema(
          "countState", 0,
          keySchema, 0,
-          new StructType().add("value", LongType, nullable = shouldBeNullable),
+          new StructType().add("value", LongType, nullable = true),
          Some(NoPrefixKeyStateEncoderSpec(keySchema)),
          None
        )
        val schema1 = StateStoreColFamilySchema(
          "listState", 0,
          keySchema, 0,
          new StructType()
-            .add("id", LongType, nullable = shouldBeNullable)
+            .add("id", LongType, nullable = true)
            .add("name", StringType),
          Some(NoPrefixKeyStateEncoderSpec(keySchema)),
          None
@@ -1857,7 +1856,7 @@
        val schema3 = StateStoreColFamilySchema(
          "$rowCounter_listState", 0,
          keySchema, 0,
-          new StructType().add("count", LongType, nullable = shouldBeNullable),
+          new StructType().add("count", LongType, nullable = true),
          Some(NoPrefixKeyStateEncoderSpec(keySchema)),
          None
        )
@@ -273,7 +273,6 @@ class TransformWithValueStateTTLSuite extends TransformWithStateTTLTest {
    ) {
      withTempDir { checkpointDir =>
        // When Avro is used, we want to set the StructFields to nullable
-        val shouldBeNullable = usingAvroEncoding()
        val metadataPathPostfix = "state/0/_stateSchema/default"
        val stateSchemaPath = new Path(checkpointDir.toString, s"$metadataPathPostfix")
        val hadoopConf = spark.sessionState.newHadoopConf()
@@ -317,15 +316,15 @@ class TransformWithValueStateTTLSuite extends TransformWithStateTTLTest {
        val schema2 = StateStoreColFamilySchema(
          "$count_listState", 0,
          keySchema, 0,
-          new StructType().add("count", LongType, nullable = shouldBeNullable),
+          new StructType().add("count", LongType, nullable = true),
          Some(NoPrefixKeyStateEncoderSpec(keySchema)),
          None
        )

        val schema3 = StateStoreColFamilySchema(
          "$rowCounter_listState", 0,
          keySchema, 0,
-          new StructType().add("count", LongType, nullable = shouldBeNullable),
+          new StructType().add("count", LongType, nullable = true),
          Some(NoPrefixKeyStateEncoderSpec(keySchema)),
          None
        )
@@ -409,7 +408,7 @@ class TransformWithValueStateTTLSuite extends TransformWithStateTTLTest {
"valueStateTTL", 0,
keySchema, 0,
new StructType()
.add("value", new StructType().add("value", IntegerType, nullable = shouldBeNullable))
.add("value", new StructType().add("value", IntegerType, nullable = true))
.add("ttlExpirationMs", LongType),
Some(NoPrefixKeyStateEncoderSpec(keySchema)),
None
@@ -418,7 +417,7 @@
        val schema10 = StateStoreColFamilySchema(
          "valueState", 0,
          keySchema, 0,
-          new StructType().add("value", IntegerType, nullable = shouldBeNullable),
+          new StructType().add("value", IntegerType, nullable = true),
          Some(NoPrefixKeyStateEncoderSpec(keySchema)),
          None
        )
@@ -428,7 +427,7 @@
          keySchema, 0,
          new StructType()
            .add("value", new StructType()
-              .add("id", LongType, nullable = shouldBeNullable)
+              .add("id", LongType, nullable = true)
              .add("name", StringType))
            .add("ttlExpirationMs", LongType),
          Some(NoPrefixKeyStateEncoderSpec(keySchema)),