@@ -34,7 +34,9 @@
import org.apache.spark.sql.execution.vectorized.WritableColumnVector;
import org.apache.spark.sql.types.*;

import java.math.BigDecimal;
import java.math.BigInteger;
import java.math.RoundingMode;
import java.time.ZoneId;
import java.time.ZoneOffset;
import java.util.Arrays;
@@ -108,6 +110,8 @@ public ParquetVectorUpdater getUpdater(ColumnDescriptor descriptor, DataType spa
}
} else if (sparkType instanceof YearMonthIntervalType) {
return new IntegerUpdater();
} else if (canReadAsDecimal(descriptor, sparkType)) {
return new IntegerToDecimalUpdater(descriptor, (DecimalType) sparkType);
}
}
case INT64 -> {
@@ -153,6 +157,8 @@ public ParquetVectorUpdater getUpdater(ColumnDescriptor descriptor, DataType spa
return new LongAsMicrosUpdater();
} else if (sparkType instanceof DayTimeIntervalType) {
return new LongUpdater();
} else if (canReadAsDecimal(descriptor, sparkType)) {
return new LongToDecimalUpdater(descriptor, (DecimalType) sparkType);
}
}
case FLOAT -> {
@@ -194,6 +200,8 @@ public ParquetVectorUpdater getUpdater(ColumnDescriptor descriptor, DataType spa
if (sparkType == DataTypes.StringType || sparkType == DataTypes.BinaryType ||
canReadAsBinaryDecimal(descriptor, sparkType)) {
return new BinaryUpdater();
} else if (canReadAsDecimal(descriptor, sparkType)) {
return new BinaryToDecimalUpdater(descriptor, (DecimalType) sparkType);
}
}
case FIXED_LEN_BYTE_ARRAY -> {
@@ -206,6 +214,8 @@ public ParquetVectorUpdater getUpdater(ColumnDescriptor descriptor, DataType spa
return new FixedLenByteArrayUpdater(arrayLen);
} else if (sparkType == DataTypes.BinaryType) {
return new FixedLenByteArrayUpdater(arrayLen);
} else if (canReadAsDecimal(descriptor, sparkType)) {
return new FixedLenByteArrayToDecimalUpdater(descriptor, (DecimalType) sparkType);
}
}
default -> {}
@@ -1358,6 +1368,188 @@ public void decodeSingleDictionaryId(
}
}

private abstract static class DecimalUpdater implements ParquetVectorUpdater {

private final DecimalType sparkType;

DecimalUpdater(DecimalType sparkType) {
this.sparkType = sparkType;
}

@Override
public void readValues(
int total,
int offset,
WritableColumnVector values,
VectorizedValuesReader valuesReader) {
for (int i = 0; i < total; i++) {
readValue(offset + i, values, valuesReader);
}
}

protected void writeDecimal(int offset, WritableColumnVector values, BigDecimal decimal) {
BigDecimal scaledDecimal = decimal.setScale(sparkType.scale(), RoundingMode.UNNECESSARY);
if (DecimalType.is32BitDecimalType(sparkType)) {
values.putInt(offset, scaledDecimal.unscaledValue().intValue());
} else if (DecimalType.is64BitDecimalType(sparkType)) {
values.putLong(offset, scaledDecimal.unscaledValue().longValue());
} else {
values.putByteArray(offset, scaledDecimal.unscaledValue().toByteArray());
}
}
}
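
As a reference for the writeDecimal helper above, here is a minimal standalone sketch (not Spark code) of the rescale-then-store step; the concrete value 123.4 and the DECIMAL(6, 3) target type are illustrative assumptions.

```java
import java.math.BigDecimal;
import java.math.BigInteger;
import java.math.RoundingMode;

public class WriteDecimalSketch {
  public static void main(String[] args) {
    // A Parquet DECIMAL(4, 1) value 123.4, rebuilt from its unscaled integer 1234.
    BigDecimal parquetValue = new BigDecimal(BigInteger.valueOf(1234), 1);
    // Rescale to the requested Spark scale 3. The compatibility check only admits
    // upscaling, so RoundingMode.UNNECESSARY can never throw on that path.
    BigDecimal rescaled = parquetValue.setScale(3, RoundingMode.UNNECESSARY); // 123.400
    // DECIMAL(6, 3) has precision <= 9, so Spark keeps it as a 32-bit unscaled int.
    int unscaled = rescaled.unscaledValue().intValue();                       // 123400
    System.out.println(rescaled + " -> unscaled " + unscaled);
  }
}
```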

private static class IntegerToDecimalUpdater extends DecimalUpdater {
private final int parquetScale;

IntegerToDecimalUpdater(ColumnDescriptor descriptor, DecimalType sparkType) {
super(sparkType);
LogicalTypeAnnotation typeAnnotation =
descriptor.getPrimitiveType().getLogicalTypeAnnotation();
this.parquetScale = ((DecimalLogicalTypeAnnotation) typeAnnotation).getScale();
}

@Override
public void skipValues(int total, VectorizedValuesReader valuesReader) {
valuesReader.skipIntegers(total);
}

@Override
public void readValue(
int offset,
WritableColumnVector values,
VectorizedValuesReader valuesReader) {
BigDecimal decimal = BigDecimal.valueOf(valuesReader.readInteger(), parquetScale);
writeDecimal(offset, values, decimal);
}

@Override
public void decodeSingleDictionaryId(
int offset,
WritableColumnVector values,
WritableColumnVector dictionaryIds,
Dictionary dictionary) {
BigDecimal decimal =
BigDecimal.valueOf(dictionary.decodeToInt(dictionaryIds.getDictId(offset)), parquetScale);
writeDecimal(offset, values, decimal);
}
}
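
To make the INT32 path concrete: a small hedged example of how BigDecimal.valueOf(unscaledValue, scale) reinterprets a raw Parquet integer; the literal 12345 and scale 2 are made up for illustration.

```java
import java.math.BigDecimal;

public class Int32AsDecimalSketch {
  public static void main(String[] args) {
    int raw = 12345;      // what valuesReader.readInteger() would hand back
    int parquetScale = 2; // taken from the column's DECIMAL logical type annotation
    // The raw int is the unscaled value; the annotation supplies the scale.
    System.out.println(BigDecimal.valueOf(raw, parquetScale)); // prints 123.45
  }
}
```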

private static class LongToDecimalUpdater extends DecimalUpdater {
private final int parquetScale;

LongToDecimalUpdater(ColumnDescriptor descriptor, DecimalType sparkType) {
super(sparkType);
LogicalTypeAnnotation typeAnnotation =
descriptor.getPrimitiveType().getLogicalTypeAnnotation();
this.parquetScale = ((DecimalLogicalTypeAnnotation) typeAnnotation).getScale();
}

@Override
public void skipValues(int total, VectorizedValuesReader valuesReader) {
valuesReader.skipLongs(total);
}

@Override
public void readValue(
int offset,
WritableColumnVector values,
VectorizedValuesReader valuesReader) {
BigDecimal decimal = BigDecimal.valueOf(valuesReader.readLong(), parquetScale);
writeDecimal(offset, values, decimal);
}

@Override
public void decodeSingleDictionaryId(
int offset,
WritableColumnVector values,
WritableColumnVector dictionaryIds,
Dictionary dictionary) {
BigDecimal decimal =
BigDecimal.valueOf(dictionary.decodeToLong(dictionaryIds.getDictId(offset)), parquetScale);
writeDecimal(offset, values, decimal);
}
}

private static class BinaryToDecimalUpdater extends DecimalUpdater {
private final int parquetScale;

BinaryToDecimalUpdater(ColumnDescriptor descriptor, DecimalType sparkType) {
super(sparkType);
LogicalTypeAnnotation typeAnnotation =
descriptor.getPrimitiveType().getLogicalTypeAnnotation();
this.parquetScale = ((DecimalLogicalTypeAnnotation) typeAnnotation).getScale();
}

@Override
public void skipValues(int total, VectorizedValuesReader valuesReader) {
valuesReader.skipBinary(total);
}

@Override
public void readValue(
int offset,
WritableColumnVector values,
VectorizedValuesReader valuesReader) {
valuesReader.readBinary(1, values, offset);
BigInteger value = new BigInteger(values.getBinary(offset));
BigDecimal decimal = new BigDecimal(value, parquetScale);
writeDecimal(offset, values, decimal);
}

@Override
public void decodeSingleDictionaryId(
int offset,
WritableColumnVector values,
WritableColumnVector dictionaryIds,
Dictionary dictionary) {
BigInteger value =
new BigInteger(dictionary.decodeToBinary(dictionaryIds.getDictId(offset)).getBytes());
BigDecimal decimal = new BigDecimal(value, parquetScale);
writeDecimal(offset, values, decimal);
}
}
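
The binary and fixed-length paths both interpret the stored bytes as a big-endian two's-complement unscaled value, which is what the BigInteger(byte[]) constructor expects. A hedged sketch with made-up byte values:

```java
import java.math.BigDecimal;
import java.math.BigInteger;

public class BinaryAsDecimalSketch {
  public static void main(String[] args) {
    // Big-endian two's-complement bytes for 12345 (0x30 0x39), read with scale 2.
    BigInteger unscaled = new BigInteger(new byte[]{0x30, 0x39});
    System.out.println(new BigDecimal(unscaled, 2));  // 123.45

    // Negative values work the same way: 0xFF 0x85 is -123 in two's complement.
    BigInteger negative = new BigInteger(new byte[]{(byte) 0xFF, (byte) 0x85});
    System.out.println(new BigDecimal(negative, 2));  // -1.23
  }
}
```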

private static class FixedLenByteArrayToDecimalUpdater extends DecimalUpdater {
private final int parquetScale;
private final int arrayLen;

FixedLenByteArrayToDecimalUpdater(ColumnDescriptor descriptor, DecimalType sparkType) {
super(sparkType);
LogicalTypeAnnotation typeAnnotation =
descriptor.getPrimitiveType().getLogicalTypeAnnotation();
this.parquetScale = ((DecimalLogicalTypeAnnotation) typeAnnotation).getScale();
this.arrayLen = descriptor.getPrimitiveType().getTypeLength();
}

@Override
public void skipValues(int total, VectorizedValuesReader valuesReader) {
valuesReader.skipFixedLenByteArray(total, arrayLen);
}

@Override
public void readValue(
int offset,
WritableColumnVector values,
VectorizedValuesReader valuesReader) {
BigInteger value = new BigInteger(valuesReader.readBinary(arrayLen).getBytes());
BigDecimal decimal = new BigDecimal(value, this.parquetScale);
writeDecimal(offset, values, decimal);
}

@Override
public void decodeSingleDictionaryId(
int offset,
WritableColumnVector values,
WritableColumnVector dictionaryIds,
Dictionary dictionary) {
BigInteger value =
new BigInteger(dictionary.decodeToBinary(dictionaryIds.getDictId(offset)).getBytes());
BigDecimal decimal = new BigDecimal(value, this.parquetScale);
writeDecimal(offset, values, decimal);
}
}

private static int rebaseDays(int julianDays, final boolean failIfRebase) {
if (failIfRebase) {
if (julianDays < RebaseDateTime.lastSwitchJulianDay()) {
@@ -1418,16 +1610,21 @@ private SchemaColumnConvertNotSupportedException constructConvertNotSupportedExc

private static boolean canReadAsIntDecimal(ColumnDescriptor descriptor, DataType dt) {
if (!DecimalType.is32BitDecimalType(dt)) return false;
return isDecimalTypeMatched(descriptor, dt);
return isDecimalTypeMatched(descriptor, dt) && isSameDecimalScale(descriptor, dt);
}

private static boolean canReadAsLongDecimal(ColumnDescriptor descriptor, DataType dt) {
if (!DecimalType.is64BitDecimalType(dt)) return false;
return isDecimalTypeMatched(descriptor, dt);
return isDecimalTypeMatched(descriptor, dt) && isSameDecimalScale(descriptor, dt);
}

private static boolean canReadAsBinaryDecimal(ColumnDescriptor descriptor, DataType dt) {
if (!DecimalType.isByteArrayDecimalType(dt)) return false;
return isDecimalTypeMatched(descriptor, dt) && isSameDecimalScale(descriptor, dt);
}

private static boolean canReadAsDecimal(ColumnDescriptor descriptor, DataType dt) {
if (!(dt instanceof DecimalType)) return false;
return isDecimalTypeMatched(descriptor, dt);
}

@@ -1444,14 +1641,29 @@ private static boolean isDateTypeMatched(ColumnDescriptor descriptor) {
}

private static boolean isDecimalTypeMatched(ColumnDescriptor descriptor, DataType dt) {
Contributor review comment: I think the name should be isDecimalTypeCompatible?

DecimalType requestedType = (DecimalType) dt;
LogicalTypeAnnotation typeAnnotation = descriptor.getPrimitiveType().getLogicalTypeAnnotation();
if (typeAnnotation instanceof DecimalLogicalTypeAnnotation) {
DecimalLogicalTypeAnnotation parquetType = (DecimalLogicalTypeAnnotation) typeAnnotation;
// If the required scale is larger than or equal to the physical decimal scale in the Parquet
// metadata, we can upscale the value as long as the precision also increases by as much so
// that there is no loss of precision.
int scaleIncrease = requestedType.scale() - parquetType.getScale();
int precisionIncrease = requestedType.precision() - parquetType.getPrecision();
return scaleIncrease >= 0 && precisionIncrease >= scaleIncrease;
}
return false;
}
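
A worked restatement of the rule above, using a hypothetical helper and made-up precision/scale pairs: upscaling Parquet DECIMAL(p1, s1) to Spark DECIMAL(p2, s2) is lossless when the scale does not shrink and the precision grows by at least as much as the scale.

```java
public class DecimalCompatSketch {
  // Hypothetical helper mirroring the arithmetic in isDecimalTypeMatched.
  static boolean compatible(int parquetPrecision, int parquetScale,
                            int sparkPrecision, int sparkScale) {
    int scaleIncrease = sparkScale - parquetScale;
    int precisionIncrease = sparkPrecision - parquetPrecision;
    return scaleIncrease >= 0 && precisionIncrease >= scaleIncrease;
  }

  public static void main(String[] args) {
    System.out.println(compatible(10, 2, 12, 4)); // true: scale +2, precision +2
    System.out.println(compatible(10, 2, 11, 4)); // false: scale +2 but precision only +1
    System.out.println(compatible(10, 2, 12, 1)); // false: downscaling could lose digits
  }
}
```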

private static boolean isSameDecimalScale(ColumnDescriptor descriptor, DataType dt) {
DecimalType d = (DecimalType) dt;
LogicalTypeAnnotation typeAnnotation = descriptor.getPrimitiveType().getLogicalTypeAnnotation();
if (typeAnnotation instanceof DecimalLogicalTypeAnnotation decimalType) {
// It's OK if the required decimal precision is larger than or equal to the physical decimal
// precision in the Parquet metadata, as long as the decimal scale is the same.
return decimalType.getPrecision() <= d.precision() && decimalType.getScale() == d.scale();
if (typeAnnotation instanceof DecimalLogicalTypeAnnotation) {
DecimalLogicalTypeAnnotation decimalType = (DecimalLogicalTypeAnnotation) typeAnnotation;
return decimalType.getScale() == d.scale();
}
return false;
}

}

@@ -152,32 +152,51 @@ private boolean isLazyDecodingSupported(
switch (typeName) {
case INT32: {
boolean isDate = logicalTypeAnnotation instanceof DateLogicalTypeAnnotation;
boolean needsUpcast = sparkType == LongType || (isDate && sparkType == TimestampNTZType) ||
!DecimalType.is32BitDecimalType(sparkType);
boolean isDecimal = logicalTypeAnnotation instanceof DecimalLogicalTypeAnnotation;
boolean needsUpcast = sparkType == LongType || sparkType == DoubleType ||
(isDate && sparkType == TimestampNTZType) ||
(isDecimal && !DecimalType.is32BitDecimalType(sparkType));
Contributor Author comment on lines +156 to +158: This fixes an issue from #44368; we were incorrectly disabling lazy dictionary decoding for any non-decimal INT32 type.

boolean needsRebase = logicalTypeAnnotation instanceof DateLogicalTypeAnnotation &&
!"CORRECTED".equals(datetimeRebaseMode);
isSupported = !needsUpcast && !needsRebase;
isSupported = !needsUpcast && !needsRebase && !needsDecimalScaleRebase(sparkType);
break;
}
case INT64: {
boolean needsUpcast = !DecimalType.is64BitDecimalType(sparkType) ||
boolean isDecimal = logicalTypeAnnotation instanceof DecimalLogicalTypeAnnotation;
boolean needsUpcast = (isDecimal && !DecimalType.is64BitDecimalType(sparkType)) ||
updaterFactory.isTimestampTypeMatched(TimeUnit.MILLIS);
boolean needsRebase = updaterFactory.isTimestampTypeMatched(TimeUnit.MICROS) &&
!"CORRECTED".equals(datetimeRebaseMode);
isSupported = !needsUpcast && !needsRebase;
isSupported = !needsUpcast && !needsRebase && !needsDecimalScaleRebase(sparkType);
break;
}
case FLOAT:
isSupported = sparkType == FloatType;
break;
case DOUBLE:
case BINARY:
isSupported = true;
break;
case BINARY:
isSupported = !needsDecimalScaleRebase(sparkType);
break;
}
return isSupported;
}

/**
* Returns whether the Parquet type of this column and the given spark type are two decimal types
* with different scale.
*/
private boolean needsDecimalScaleRebase(DataType sparkType) {
LogicalTypeAnnotation typeAnnotation =
descriptor.getPrimitiveType().getLogicalTypeAnnotation();
if (!(typeAnnotation instanceof DecimalLogicalTypeAnnotation)) return false;
if (!(sparkType instanceof DecimalType)) return false;
DecimalLogicalTypeAnnotation parquetDecimal = (DecimalLogicalTypeAnnotation) typeAnnotation;
DecimalType sparkDecimal = (DecimalType) sparkType;
return parquetDecimal.getScale() != sparkDecimal.scale();
}
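
For illustration, with made-up types: reading a Parquet DECIMAL(9, 2) column as Spark DECIMAL(11, 2) keeps the scale, so needsDecimalScaleRebase returns false and the column stays eligible for lazy dictionary decoding; reading the same column as DECIMAL(11, 4) changes the scale, so every value must be rescaled and the reader falls back to eager decoding.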

/**
* Reads `total` rows from this columnReader into column.
*/
@@ -1049,7 +1049,9 @@ abstract class ParquetQuerySuite extends QueryTest with ParquetTest with SharedS
}

withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "true") {
Seq("a DECIMAL(3, 2)", "b DECIMAL(18, 1)", "c DECIMAL(37, 1)").foreach { schema =>
val schema1 = "a DECIMAL(3, 2), b DECIMAL(18, 3), c DECIMAL(37, 3)"
checkAnswer(readParquet(schema1, path), df)
Seq("a DECIMAL(3, 0)", "b DECIMAL(18, 1)", "c DECIMAL(37, 1)").foreach { schema =>
val e = intercept[SparkException] {
readParquet(schema, path).collect()
}.getCause.getCause