Skip to content

Commit bd5abb6

Browse files
committed
temp
1 parent 7bfd0e6 commit bd5abb6

File tree

7 files changed

+337
-36
lines changed

7 files changed

+337
-36
lines changed

kernel/kernel-api/src/main/java/io/delta/kernel/internal/skipping/DataSkippingPredicate.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -89,12 +89,12 @@ public Set<CollationIdentifier> getReferencedCollations() {
8989
} else if (child instanceof Predicate) {
9090
throw new IllegalStateException(
9191
String.format(
92-
"Expected child Predicate of DataSkippingPredicate to be an instance of" +
93-
" DataSkippingPredicate but found: %s",
92+
"Expected child Predicate of DataSkippingPredicate to be an instance of"
93+
+ " DataSkippingPredicate but found: %s",
9494
child, this));
9595
}
9696
}
97-
return referencedCollations;
97+
return Collections.unmodifiableSet(referencedCollations);
9898
}
9999

100100
/** @return an unmodifiable set containing all elements from both sets. */

kernel/kernel-api/src/main/java/io/delta/kernel/internal/skipping/DataSkippingUtils.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -262,8 +262,8 @@ private static Optional<DataSkippingPredicate> constructDataSkippingFilter(
262262
Expression right = getRight(dataFilters);
263263
Optional<CollationIdentifier> collationIdentifier = dataFilters.getCollationIdentifier();
264264
if (collationIdentifier
265-
.filter(ci -> !ci.isSparkUTF8BinaryCollation() && ci.getVersion().isEmpty())
266-
.isPresent()) {
265+
.filter(ci -> !ci.isSparkUTF8BinaryCollation() && ci.getVersion().isEmpty())
266+
.isPresent()) {
267267
// Each collated statistics is stored with a specific version, so collation
268268
// must specify a version to be used for data skipping.
269269
return Optional.empty();

kernel/kernel-api/src/main/java/io/delta/kernel/internal/skipping/StatsSchemaHelper.java

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -301,27 +301,44 @@ private static StructType getMinMaxStatsSchema(StructType dataSchema) {
301301

302302
/**
303303
* Given a data schema and a set of collation identifiers returns the expected schema for
304-
* collation-aware statistics columns. This means 1) replace logical names with physical names 2)
305-
* set nullable=true 3) only keep collated-stats eligible fields (`StringType` fields)
304+
* collation-aware statistics columns.
306305
*/
307306
private static StructType getCollatedStatsSchema(
308307
StructType dataSchema, Set<CollationIdentifier> collationIdentifiers) {
309308
StructType statsWithCollation = new StructType();
310-
StructType collatedMinMaxStatsSchema = getMinMaxStatsSchema(dataSchema, true);
309+
StructType collationAwareFields = getCollationAwareFields(dataSchema);
311310
for (CollationIdentifier collationIdentifier : collationIdentifiers) {
312-
if (collatedMinMaxStatsSchema.length() > 0) {
311+
if (collationAwareFields.length() > 0) {
313312
statsWithCollation =
314313
statsWithCollation.add(
315314
collationIdentifier.toString(),
316315
new StructType()
317-
.add(MIN, collatedMinMaxStatsSchema, true)
318-
.add(MAX, collatedMinMaxStatsSchema, true),
316+
.add(MIN, collationAwareFields, true)
317+
.add(MAX, collationAwareFields, true),
319318
true);
320319
}
321320
}
322321
return statsWithCollation;
323322
}
324323

324+
/** Given a data schema returns its collation aware fields. */
325+
private static StructType getCollationAwareFields(StructType dataSchema) {
326+
StructType collationAwareFields = new StructType();
327+
for (StructField field : dataSchema.fields()) {
328+
DataType dataType = field.getDataType();
329+
if (dataType instanceof StructType) {
330+
StructType nestedCollationAwareFields = getCollationAwareFields((StructType) dataType);
331+
if (nestedCollationAwareFields.length() > 0) {
332+
collationAwareFields =
333+
collationAwareFields.add(getPhysicalName(field), nestedCollationAwareFields, true);
334+
}
335+
} else if (dataType instanceof StringType) {
336+
collationAwareFields = collationAwareFields.add(getPhysicalName(field), dataType, true);
337+
}
338+
}
339+
return collationAwareFields;
340+
}
341+
325342
/**
326343
* Given a data schema returns the expected schema for a null_count statistics column. This means
327344
* 1) replace logical names with physical names 2) set nullable=true 3) use LongType for all

kernel/kernel-api/src/test/scala/io/delta/kernel/internal/skipping/DataSkippingUtilsSuite.scala

Lines changed: 77 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -53,13 +53,6 @@ class DataSkippingUtilsSuite extends AnyFunSuite with TestUtils {
5353
new DataSkippingPredicate(operator, children.asJava, collation, referencedColumns.asJava)
5454
}
5555

56-
private def collatedStatsCol(
57-
collation: CollationIdentifier,
58-
statName: String,
59-
fieldName: String): Column = {
60-
new Column(Array(STATS_WITH_COLLATION, collation.toString, statName, fieldName))
61-
}
62-
6356
/* For struct type checks for equality based on field names & data type only */
6457
def compareDataTypeUnordered(type1: DataType, type2: DataType): Boolean = (type1, type2) match {
6558
case (schema1: StructType, schema2: StructType) =>
@@ -188,6 +181,59 @@ class DataSkippingUtilsSuite extends AnyFunSuite with TestUtils {
188181
new StructType())
189182
}
190183

184+
test("pruneStatsSchema - collated statistics") {
185+
val utf8Lcase = CollationIdentifier.fromString("SPARK.UTF8_LCASE.75")
186+
val unicode = CollationIdentifier.fromString("ICU.UNICODE.74.1")
187+
val unicodeString = new StringType(unicode)
188+
189+
val statsSchema = new StructType()
190+
.add(
191+
MIN,
192+
new StructType()
193+
.add("a", StringType.STRING)
194+
.add("b", unicodeString))
195+
.add(
196+
MAX,
197+
new StructType()
198+
.add("a", StringType.STRING)
199+
.add("b", unicodeString))
200+
.add(
201+
STATS_WITH_COLLATION,
202+
new StructType()
203+
.add(
204+
utf8Lcase.toString,
205+
new StructType()
206+
.add(MIN, new StructType().add("a", StringType.STRING).add("b", unicodeString))
207+
.add(MAX, new StructType().add("a", StringType.STRING).add("b", unicodeString)))
208+
.add(
209+
unicode.toString,
210+
new StructType()
211+
.add(MIN, new StructType().add("a", StringType.STRING).add("b", unicodeString))
212+
.add(MAX, new StructType().add("a", StringType.STRING).add("b", unicodeString))))
213+
214+
// Keep only: binary MAX.b and collated(utf8Lcase) MIN.a
215+
val referenced = Set(
216+
nestedCol(s"$MAX.b"),
217+
collatedStatsCol(utf8Lcase, MIN, "a"),
218+
collatedStatsCol(unicode, MAX, "b"))
219+
220+
val expected = new StructType()
221+
.add(MAX, new StructType().add("b", unicodeString))
222+
.add(
223+
STATS_WITH_COLLATION,
224+
new StructType()
225+
.add(
226+
utf8Lcase.toString,
227+
new StructType()
228+
.add(MIN, new StructType().add("a", StringType.STRING)))
229+
.add(
230+
unicode.toString,
231+
new StructType()
232+
.add(MAX, new StructType().add("b", unicodeString))))
233+
234+
checkPruneStatsSchema(statsSchema, referenced, expected)
235+
}
236+
191237
// TODO: add tests for remaining operators
192238
test("check constructDataSkippingFilter") {
193239
val testCases = Seq(
@@ -372,8 +418,8 @@ class DataSkippingUtilsSuite extends AnyFunSuite with TestUtils {
372418
}
373419

374420
test("check constructDataSkippingFilter with collations") {
375-
val utf8Lcase = CollationIdentifier.fromString("SPARK.UTF8_LCASE")
376-
val unicode = CollationIdentifier.fromString("ICU.UNICODE")
421+
val utf8Lcase = CollationIdentifier.fromString("SPARK.UTF8_LCASE.75")
422+
val unicode = CollationIdentifier.fromString("ICU.UNICODE.74.1")
377423

378424
val testCases = Seq(
379425
// (schema, predicate, expectedDataSkippingPredicateOpt)
@@ -519,4 +565,26 @@ class DataSkippingUtilsSuite extends AnyFunSuite with TestUtils {
519565
}
520566
}
521567
}
568+
569+
test("check constructDataSkippingFilter with collations (no version in collation)") {
570+
val utf8Lcase = CollationIdentifier.fromString("SPARK.UTF8_LCASE")
571+
val unicode = CollationIdentifier.fromString("ICU.UNICODE")
572+
573+
val testCases = Seq(
574+
(
575+
new StructType()
576+
.add("a", StringType.STRING)
577+
.add("b", StringType.STRING),
578+
createPredicate("<", col("a"), literal("m"), Optional.of(unicode))),
579+
(
580+
new StructType()
581+
.add("a", StringType.STRING),
582+
createPredicate("<", literal("m"), col("a"), Optional.of(utf8Lcase))))
583+
584+
testCases.foreach { case (schema, predicate) =>
585+
val dataSkippingPredicateOpt =
586+
JavaOptionalOps(constructDataSkippingFilter(predicate, schema)).toScala
587+
assert(dataSkippingPredicateOpt.isEmpty)
588+
}
589+
}
522590
}

kernel/kernel-api/src/test/scala/io/delta/kernel/internal/skipping/StatsSchemaHelperSuite.scala

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,10 @@ import io.delta.kernel.types.{ArrayType, BinaryType, BooleanType, ByteType, Coll
2222
import org.scalatest.funsuite.AnyFunSuite
2323

2424
class StatsSchemaHelperSuite extends AnyFunSuite {
25-
val utf8Lcase = CollationIdentifier.fromString("SPARK.UTF8_LCASE")
26-
val unicode = CollationIdentifier.fromString("ICU.UNICODE")
25+
val utf8Lcase = CollationIdentifier.fromString("SPARK.UTF8_LCASE.74")
26+
val unicode = CollationIdentifier.fromString("ICU.UNICODE.75.1")
27+
val utf8LcaseString = new StringType(utf8Lcase)
28+
val unicodeString = new StringType(unicode)
2729

2830
test("check getStatsSchema for supported data types") {
2931
val testCases = Seq(
@@ -126,15 +128,21 @@ class StatsSchemaHelperSuite extends AnyFunSuite {
126128
.add(StatsSchemaHelper.MIN, new StructType().add("k", new DecimalType(20, 5), true), true)
127129
.add(StatsSchemaHelper.MAX, new StructType().add("k", new DecimalType(20, 5), true), true)
128130
.add(StatsSchemaHelper.NULL_COUNT, new StructType().add("k", LongType.LONG, true), true)
131+
.add(StatsSchemaHelper.TIGHT_BOUNDS, BooleanType.BOOLEAN, true)),
132+
(
133+
new StructType().add("b", utf8LcaseString),
134+
new StructType()
135+
.add(StatsSchemaHelper.NUM_RECORDS, LongType.LONG, true)
136+
.add(StatsSchemaHelper.MIN, new StructType().add("b", utf8LcaseString, true), true)
137+
.add(StatsSchemaHelper.MAX, new StructType().add("b", utf8LcaseString, true), true)
138+
.add(StatsSchemaHelper.NULL_COUNT, new StructType().add("b", LongType.LONG, true), true)
129139
.add(StatsSchemaHelper.TIGHT_BOUNDS, BooleanType.BOOLEAN, true)))
130140

131141
testCases.foreach { case (dataSchema, expectedStatsSchema) =>
132142
val statsSchema = StatsSchemaHelper.getStatsSchema(
133143
dataSchema,
134144
Set.empty[CollationIdentifier].asJava)
135-
assert(
136-
statsSchema == expectedStatsSchema,
137-
s"Stats schema mismatch for data schema: $dataSchema")
145+
assert(statsSchema == expectedStatsSchema)
138146
}
139147
}
140148

@@ -262,31 +270,36 @@ class StatsSchemaHelperSuite extends AnyFunSuite {
262270
.add("a", StringType.STRING)
263271
.add("b", IntegerType.INTEGER)
264272
.add("c", BinaryType.BINARY)
273+
.add("d", unicodeString)
265274

266275
val collations = Set(utf8Lcase)
267276

268-
val expectedCollatedMinMax = new StructType().add("a", StringType.STRING, true)
277+
val expectedCollatedMinMax = new StructType()
278+
.add("a", StringType.STRING, true).add("d", unicodeString, true)
269279

270280
val expectedStatsSchema = new StructType()
271281
.add(StatsSchemaHelper.NUM_RECORDS, LongType.LONG, true)
272282
.add(
273283
StatsSchemaHelper.MIN,
274284
new StructType()
275285
.add("a", StringType.STRING, true)
276-
.add("b", IntegerType.INTEGER, true),
286+
.add("b", IntegerType.INTEGER, true)
287+
.add("d", unicodeString, true),
277288
true)
278289
.add(
279290
StatsSchemaHelper.MAX,
280291
new StructType()
281292
.add("a", StringType.STRING, true)
282-
.add("b", IntegerType.INTEGER, true),
293+
.add("b", IntegerType.INTEGER, true)
294+
.add("d", unicodeString, true),
283295
true)
284296
.add(
285297
StatsSchemaHelper.NULL_COUNT,
286298
new StructType()
287299
.add("a", LongType.LONG, true)
288300
.add("b", LongType.LONG, true)
289-
.add("c", LongType.LONG, true),
301+
.add("c", LongType.LONG, true)
302+
.add("d", LongType.LONG, true),
290303
true)
291304
.add(StatsSchemaHelper.TIGHT_BOUNDS, BooleanType.BOOLEAN, true)
292305
.add(

kernel/kernel-api/src/test/scala/io/delta/kernel/test/TestUtils.scala

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,11 @@
1414
* limitations under the License.
1515
*/
1616
package io.delta.kernel.test
17-
1817
import java.util.Optional
1918

2019
import io.delta.kernel.expressions.{Column, Literal}
20+
import io.delta.kernel.internal.skipping.StatsSchemaHelper.STATS_WITH_COLLATION
21+
import io.delta.kernel.types.CollationIdentifier
2122

2223
/** Utility functions for tests. */
2324
trait TestUtils {
@@ -27,6 +28,15 @@ trait TestUtils {
2728
new Column(name.split("\\."))
2829
}
2930

31+
def collatedStatsCol(
32+
collation: CollationIdentifier,
33+
statName: String,
34+
fieldName: String): Column = {
35+
val columnPath =
36+
Array(STATS_WITH_COLLATION, collation.toString, statName) ++ fieldName.split('.')
37+
new Column(columnPath)
38+
}
39+
3040
def literal(value: Any): Literal = {
3141
value match {
3242
case v: String => Literal.ofString(v)

0 commit comments

Comments
 (0)