diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala index 01e41b3c5df3..9cfc943cd2b3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala @@ -524,41 +524,40 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex } test("filter pushdown - decimal") { - Seq(true, false).foreach { legacyFormat => + Seq( + (false, Decimal.MAX_INT_DIGITS), // int32Writer + (false, Decimal.MAX_LONG_DIGITS), // int64Writer + (true, Decimal.MAX_LONG_DIGITS), // binaryWriterUsingUnscaledLong + (false, DecimalType.MAX_PRECISION) // binaryWriterUsingUnscaledBytes + ).foreach { case (legacyFormat, precision) => withSQLConf(SQLConf.PARQUET_WRITE_LEGACY_FORMAT.key -> legacyFormat.toString) { - Seq( - s"a decimal(${Decimal.MAX_INT_DIGITS}, 2)", // 32BitDecimalType - s"a decimal(${Decimal.MAX_LONG_DIGITS}, 2)", // 64BitDecimalType - "a decimal(38, 18)" // ByteArrayDecimalType - ).foreach { schemaDDL => - val schema = StructType.fromDDL(schemaDDL) - val rdd = - spark.sparkContext.parallelize((1 to 4).map(i => Row(new java.math.BigDecimal(i)))) - val dataFrame = spark.createDataFrame(rdd, schema) - testDecimalPushDown(dataFrame) { implicit df => - assert(df.schema === schema) - checkFilterPredicate('a.isNull, classOf[Eq[_]], Seq.empty[Row]) - checkFilterPredicate('a.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_))) - - checkFilterPredicate('a === 1, classOf[Eq[_]], 1) - checkFilterPredicate('a <=> 1, classOf[Eq[_]], 1) - checkFilterPredicate('a =!= 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) - - checkFilterPredicate('a < 2, classOf[Lt[_]], 1) - checkFilterPredicate('a > 3, classOf[Gt[_]], 4) - checkFilterPredicate('a <= 1, classOf[LtEq[_]], 1) - checkFilterPredicate('a >= 4, classOf[GtEq[_]], 4) - - checkFilterPredicate(Literal(1) === 'a, classOf[Eq[_]], 1) - checkFilterPredicate(Literal(1) <=> 'a, classOf[Eq[_]], 1) - checkFilterPredicate(Literal(2) > 'a, classOf[Lt[_]], 1) - checkFilterPredicate(Literal(3) < 'a, classOf[Gt[_]], 4) - checkFilterPredicate(Literal(1) >= 'a, classOf[LtEq[_]], 1) - checkFilterPredicate(Literal(4) <= 'a, classOf[GtEq[_]], 4) - - checkFilterPredicate(!('a < 4), classOf[GtEq[_]], 4) - checkFilterPredicate('a < 2 || 'a > 3, classOf[Operators.Or], Seq(Row(1), Row(4))) - } + val schema = StructType.fromDDL(s"a decimal($precision, 2)") + val rdd = + spark.sparkContext.parallelize((1 to 4).map(i => Row(new java.math.BigDecimal(i)))) + val dataFrame = spark.createDataFrame(rdd, schema) + testDecimalPushDown(dataFrame) { implicit df => + assert(df.schema === schema) + checkFilterPredicate('a.isNull, classOf[Eq[_]], Seq.empty[Row]) + checkFilterPredicate('a.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_))) + + checkFilterPredicate('a === 1, classOf[Eq[_]], 1) + checkFilterPredicate('a <=> 1, classOf[Eq[_]], 1) + checkFilterPredicate('a =!= 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) + + checkFilterPredicate('a < 2, classOf[Lt[_]], 1) + checkFilterPredicate('a > 3, classOf[Gt[_]], 4) + checkFilterPredicate('a <= 1, classOf[LtEq[_]], 1) + checkFilterPredicate('a >= 4, classOf[GtEq[_]], 4) + + checkFilterPredicate(Literal(1) === 'a, classOf[Eq[_]], 1) + checkFilterPredicate(Literal(1) <=> 'a, classOf[Eq[_]], 1) + checkFilterPredicate(Literal(2) > 'a, classOf[Lt[_]], 1) + checkFilterPredicate(Literal(3) < 'a, classOf[Gt[_]], 4) + checkFilterPredicate(Literal(1) >= 'a, classOf[LtEq[_]], 1) + checkFilterPredicate(Literal(4) <= 'a, classOf[GtEq[_]], 4) + + checkFilterPredicate(!('a < 4), classOf[GtEq[_]], 4) + checkFilterPredicate('a < 2 || 'a > 3, classOf[Operators.Or], Seq(Row(1), Row(4))) } } }