Skip to content

Commit

Permalink
[HUDI-8998] Improve handling of zero scale decimals in MercifulJsonCo…
Browse files Browse the repository at this point in the history
…nverter (#12822)
  • Loading branch information
the-other-tim-brown authored Feb 11, 2025
1 parent 861fe11 commit 7013912
Show file tree
Hide file tree
Showing 6 changed files with 149 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@
import java.io.IOException;
import java.math.BigDecimal;
import java.math.BigInteger;
import java.math.MathContext;
import java.math.RoundingMode;
import java.nio.ByteBuffer;
import java.sql.Date;
Expand Down Expand Up @@ -1156,6 +1157,17 @@ private static Object rewritePrimaryTypeWithDiffSchemaType(Object oldValue, Sche
throw new HoodieAvroSchemaException(String.format("cannot support rewrite value for schema type: %s since the old schema type is: %s", newSchema, oldSchema));
}

/**
 * Builds a {@link BigDecimal} from the bytes of an unscaled value, applying both the scale and the
 * precision declared by the given decimal logical type.
 *
 * <p>Use this instead of DECIMAL_CONVERSION.fromBytes() because that method does not add in precision.
 *
 * @param value   the result of BigDecimal.unscaledValue().toByteArray(); this is also what
 *                Conversions.DecimalConversion.toBytes() outputs inside a byte buffer
 * @param decimal the logical type that supplies the scale and the precision
 * @return the decimal value, created with a HALF_UP math context of the declared precision
 */
public static BigDecimal convertBytesToBigDecimal(byte[] value, LogicalTypes.Decimal decimal) {
  BigInteger unscaledValue = new BigInteger(value);
  int scale = decimal.getScale();
  MathContext precisionContext = new MathContext(decimal.getPrecision(), RoundingMode.HALF_UP);
  return new BigDecimal(unscaledValue, scale, precisionContext);
}

/**
* Checks whether the provided schema contains a decimal with a precision less than or equal to 18,
* which allows the decimal to be stored as int/long instead of a fixed size byte array in
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import org.apache.hudi.exception.HoodieException;
import org.apache.hudi.exception.HoodieJsonToAvroConversionException;

import com.fasterxml.jackson.databind.DeserializationFeature;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.apache.avro.Conversions;
import org.apache.avro.LogicalType;
Expand Down Expand Up @@ -88,7 +89,7 @@ public MercifulJsonConverter() {
* Allows enabling sanitization and allows choice of invalidCharMask for sanitization
*/
public MercifulJsonConverter(boolean shouldSanitize, String invalidCharMask) {
this(new ObjectMapper(), shouldSanitize, invalidCharMask);
// Delegates with an ObjectMapper that has Jackson's USE_BIG_DECIMAL_FOR_FLOATS feature enabled.
this(new ObjectMapper().enable(DeserializationFeature.USE_BIG_DECIMAL_FOR_FLOATS), shouldSanitize, invalidCharMask);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,16 @@

package org.apache.hudi.avro.processors;

import org.apache.hudi.avro.HoodieAvroUtils;
import org.apache.hudi.common.util.collection.Pair;

import org.apache.avro.LogicalTypes;
import org.apache.avro.Schema;

import java.math.BigDecimal;
import java.math.MathContext;
import java.math.RoundingMode;
import java.util.Base64;

public abstract class DecimalLogicalTypeProcessor extends JsonFieldProcessor {

Expand Down Expand Up @@ -51,17 +55,26 @@ protected static boolean isValidDecimalTypeConfig(Schema schema) {
*/
protected static Pair<Boolean, BigDecimal> parseObjectToBigDecimal(Object obj, Schema schema) {
BigDecimal bigDecimal = null;
if (obj instanceof Number) {
bigDecimal = BigDecimal.valueOf(((Number) obj).doubleValue());
}

// Case 2: Object is a number in String format.
if (obj instanceof String) {
try {
bigDecimal = new BigDecimal(((String) obj));
} catch (java.lang.NumberFormatException ignored) {
/* ignore */
LogicalTypes.Decimal logicalType = (LogicalTypes.Decimal) schema.getLogicalType();
try {
if (obj instanceof BigDecimal) {
bigDecimal = ((BigDecimal) obj).setScale(logicalType.getScale(), RoundingMode.UNNECESSARY);
} else if (obj instanceof String) {
// Case 2: Object is a number in String format.
try {
//encoded big decimal
bigDecimal = HoodieAvroUtils.convertBytesToBigDecimal(decodeStringToBigDecimalBytes(obj),
(LogicalTypes.Decimal) schema.getLogicalType());
} catch (IllegalArgumentException e) {
//no-op
}
}
// None fixed byte or fixed byte conversion failure would end up here.
if (bigDecimal == null) {
bigDecimal = new BigDecimal(obj.toString(), new MathContext(logicalType.getPrecision(), RoundingMode.UNNECESSARY)).setScale(logicalType.getScale(), RoundingMode.UNNECESSARY);
}
} catch (java.lang.NumberFormatException | ArithmeticException ignored) {
/* ignore */
}

if (bigDecimal == null) {
Expand All @@ -82,4 +95,8 @@ protected static Pair<Boolean, BigDecimal> parseObjectToBigDecimal(Object obj, S
}
return Pair.of(true, bigDecimal);
}

/**
 * Decodes a Base64-encoded string into its raw bytes.
 *
 * @param value the Base64 text; must be a {@link String}
 * @return the decoded bytes
 * @throws IllegalArgumentException if the text is not valid Base64
 * @throws ClassCastException if {@code value} is not a String
 */
protected static byte[] decodeStringToBigDecimalBytes(Object value) {
  // Decode the String directly: getBytes() without a charset uses the platform default encoding,
  // which would make the decoded result depend on the JVM's locale/charset settings.
  return Base64.getDecoder().decode((String) value);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ public class MercifulJsonConverterTestBase {
private static final String DECIMAL_AVRO_FILE_INVALID_PATH = "/decimal-logical-type-invalid.avsc";
private static final String DECIMAL_AVRO_FILE_PATH = "/decimal-logical-type.avsc";
private static final String DECIMAL_FIXED_AVRO_FILE_PATH = "/decimal-logical-type-fixed-type.avsc";
protected static final String DECIMAL_ZERO_SCALE_AVRO_FILE_PATH = "/decimal-logical-type-zero-scale.avsc";
private static final String LOCAL_TIMESTAMP_MICRO_AVRO_FILE_PATH = "/local-timestamp-micros-logical-type.avsc";
private static final String LOCAL_TIMESTAMP_MILLI_AVRO_FILE_PATH = "/local-timestamp-millis-logical-type.avsc";
private static final String DURATION_AVRO_FILE_PATH_INVALID = "/duration-logical-type-invalid.avsc";
Expand Down Expand Up @@ -70,7 +71,7 @@ static Stream<Object> decimalGoodCases() {
Arguments.of(DECIMAL_AVRO_FILE_PATH, "123.45", null, 123.45, false),
// Test MIN/MAX allowed by the schema.
Arguments.of(DECIMAL_AVRO_FILE_PATH, "-999.99", "-999.99", null, false),
Arguments.of(DECIMAL_AVRO_FILE_PATH, "999.99",null, 999.99, false),
Arguments.of(DECIMAL_AVRO_FILE_PATH, "999.99", null, 999.99, false),
// Test 0.
Arguments.of(DECIMAL_AVRO_FILE_PATH, "0", null, 0D, false),
Arguments.of(DECIMAL_AVRO_FILE_PATH, "0", "0", null, false),
Expand All @@ -79,15 +80,26 @@ static Stream<Object> decimalGoodCases() {
Arguments.of(DECIMAL_FIXED_AVRO_FILE_PATH, "123.45", "123.45", null, false),
Arguments.of(DECIMAL_FIXED_AVRO_FILE_PATH, "123.45", null, 123.45, false),
Arguments.of(DECIMAL_FIXED_AVRO_FILE_PATH, "-999.99", "-999.99", null, false),
Arguments.of(DECIMAL_FIXED_AVRO_FILE_PATH, "999.99",null, 999.99, false),
Arguments.of(DECIMAL_FIXED_AVRO_FILE_PATH, "999.99", null, 999.99, false),
Arguments.of(DECIMAL_FIXED_AVRO_FILE_PATH, "999", null, 999, false),
Arguments.of(DECIMAL_FIXED_AVRO_FILE_PATH, "999", null, 999L, false),
Arguments.of(DECIMAL_FIXED_AVRO_FILE_PATH, "999", null, (short) 999, false),
Arguments.of(DECIMAL_FIXED_AVRO_FILE_PATH, "100", null, (byte) 100, false),
Arguments.of(DECIMAL_FIXED_AVRO_FILE_PATH, "0", null, 0D, false),
Arguments.of(DECIMAL_FIXED_AVRO_FILE_PATH, "0", null, 0, false),
Arguments.of(DECIMAL_FIXED_AVRO_FILE_PATH, "0", "0", null, true),
Arguments.of(DECIMAL_FIXED_AVRO_FILE_PATH, "0", "000.00", null, true),
Arguments.of(DECIMAL_FIXED_AVRO_FILE_PATH, "123.45", null, null, true),
Arguments.of(DECIMAL_FIXED_AVRO_FILE_PATH, "123.45", null, 123.45, true),
Arguments.of(DECIMAL_FIXED_AVRO_FILE_PATH, "-999.99", null, null, true),
Arguments.of(DECIMAL_FIXED_AVRO_FILE_PATH, "999.99", null, 999.99, true),
Arguments.of(DECIMAL_FIXED_AVRO_FILE_PATH, "0", null, null, true)
Arguments.of(DECIMAL_FIXED_AVRO_FILE_PATH, "0", null, null, true),
Arguments.of(DECIMAL_FIXED_AVRO_FILE_PATH, "0", null, null, true),
Arguments.of(DECIMAL_ZERO_SCALE_AVRO_FILE_PATH, "12345", "12345.0", null, false),
Arguments.of(DECIMAL_ZERO_SCALE_AVRO_FILE_PATH, "12345", null, 12345.0, false),
Arguments.of(DECIMAL_ZERO_SCALE_AVRO_FILE_PATH, "12345", null, 12345, false),
Arguments.of(DECIMAL_ZERO_SCALE_AVRO_FILE_PATH, "1230", null, 1.23e+3, false),
Arguments.of(DECIMAL_ZERO_SCALE_AVRO_FILE_PATH, "1230", "1.23e+3", null, false)
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@
import org.apache.avro.generic.GenericData;
import org.apache.avro.generic.GenericFixed;
import org.apache.avro.generic.GenericRecord;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;

import java.io.IOException;
Expand All @@ -40,7 +40,9 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Stream;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;

public class TestMercifulJsonConverter extends MercifulJsonConverterTestBase {
Expand All @@ -64,7 +66,7 @@ public void basicConversion() throws IOException {
rec.put("favorite_number", number);
rec.put("favorite_color", color);

Assertions.assertEquals(rec, CONVERTER.convert(json, simpleSchema));
assertEquals(rec, CONVERTER.convert(json, simpleSchema));
}

@ParameterizedTest
Expand All @@ -78,11 +80,11 @@ void nestedJsonAsString(String nameInput) throws IOException {
rec.put("favorite_number", 1337);
rec.put("favorite_color", "10");

Assertions.assertEquals(rec, CONVERTER.convert(json, simpleSchema));
assertEquals(rec, CONVERTER.convert(json, simpleSchema));
}

private static final String DECIMAL_AVRO_FILE_PATH = "/decimal-logical-type.avsc";
private static final String DECIMAL_FIXED_AVRO_FILE_PATH = "/decimal-logical-type-fixed-type.avsc";

/**
* Covered case:
* Avro Logical Type: Decimal
Expand Down Expand Up @@ -122,7 +124,7 @@ void decimalLogicalTypeInvalidCaseTest(String avroFile, String strInput, Double
@ParameterizedTest
@MethodSource("decimalGoodCases")
void decimalLogicalTypeTest(String avroFilePath, String groundTruth, String strInput,
Double numInput, boolean testFixedByteArray) throws IOException {
Number numInput, boolean testFixedByteArray) throws IOException {
BigDecimal bigDecimal = new BigDecimal(groundTruth);
Map<String, Object> data = new HashMap<>();

Expand Down Expand Up @@ -153,7 +155,7 @@ void decimalLogicalTypeTest(String avroFilePath, String groundTruth, String strI
}

// Decide the decimal field expected output according to the test dimension.
if (avroFilePath.equals(DECIMAL_AVRO_FILE_PATH)) {
if (avroFilePath.equals(DECIMAL_AVRO_FILE_PATH) || avroFilePath.equals(DECIMAL_ZERO_SCALE_AVRO_FILE_PATH)) {
record.put("decimalField", conv.toBytes(bigDecimal, decimalFieldSchema, decimalFieldSchema.getLogicalType()));
} else {
record.put("decimalField", conv.toFixed(bigDecimal, decimalFieldSchema, decimalFieldSchema.getLogicalType()));
Expand All @@ -162,7 +164,58 @@ void decimalLogicalTypeTest(String avroFilePath, String groundTruth, String strI
String json = MAPPER.writeValueAsString(data);

GenericRecord real = CONVERTER.convert(json, schema);
Assertions.assertEquals(record, real);
assertEquals(record, real);
}

/**
 * Tests cases where decimals with fraction `.0` can be interpreted as having scale > 0.
 *
 * <p>The input is embedded as a raw JSON number into a record whose only field is a
 * decimal(38, 0) stored as bytes. Depending on {@code shouldConvert}, the conversion must either
 * produce the expected decimal or fail with a {@link HoodieJsonToAvroConversionException}.
 */
@ParameterizedTest
@MethodSource("zeroScaleDecimalCases")
void zeroScaleDecimalConversion(String inputValue, String expected, boolean shouldConvert) {
  Schema schema = new Schema.Parser().parse("{\"namespace\": \"example.avro\",\"type\": \"record\",\"name\": \"decimalLogicalType\",\"fields\": [{\"name\": \"decimalField\", "
      + "\"type\": {\"type\": \"bytes\", \"logicalType\": \"decimal\", \"precision\": 38, \"scale\": 0}}]}");
  String json = String.format("{\"decimalField\":%s}", inputValue);

  // Rejection path first: nothing else to build when the value must not convert.
  if (!shouldConvert) {
    assertThrows(HoodieJsonToAvroConversionException.class, () -> CONVERTER.convert(json, schema));
    return;
  }

  // Build the record we expect the converter to produce for this input.
  GenericRecord expectedRecord = new GenericData.Record(schema);
  Conversions.DecimalConversion decimalConversion = new Conversions.DecimalConversion();
  Schema fieldSchema = schema.getField("decimalField").schema();
  BigDecimal expectedDecimal = new BigDecimal(expected);
  expectedRecord.put("decimalField", decimalConversion.toBytes(expectedDecimal, fieldSchema, fieldSchema.getLogicalType()));

  GenericRecord actualRecord = CONVERTER.convert(json, schema);
  assertEquals(expectedRecord, actualRecord);
}

/**
 * Supplies the arguments for {@code zeroScaleDecimalConversion}.
 *
 * <p>Each entry is: input value as written in the JSON, expected decimal, and whether the
 * conversion should be successful.
 */
static Stream<Object> zeroScaleDecimalCases() {
  // Values that can be converted
  Stream<Object> convertible = Stream.of(
      Arguments.of("0.0", "0", true),
      Arguments.of("20.0", "20", true),
      Arguments.of("320", "320", true),
      Arguments.of("320.00", "320", true),
      Arguments.of("-1320.00", "-1320", true),
      Arguments.of("1520423524459", "1520423524459", true),
      Arguments.of("1520423524459.0", "1520423524459", true),
      Arguments.of("1000000000000000.0", "1000000000000000", true),
      // Values that are big enough and out of range of int or long types
      // Note that we can have at most 17 significant decimal digits in double values
      Arguments.of("1.2684037455962608e+16", "12684037455962608", true),
      Arguments.of("4.0100001e+16", "40100001000000000", true),
      Arguments.of("3.52838e+17", "352838000000000000", true),
      Arguments.of("9223372036853999600.0000", "9223372036853999600", true),
      Arguments.of("999998887654321000000000000000.0000", "999998887654321000000000000000", true),
      Arguments.of("-999998887654321000000000000000.0000", "-999998887654321000000000000000", true),
      // Values covering high precision decimals that lose precision when converting to a double
      Arguments.of("3.781239258857277e+16", "37812392588572770", true),
      Arguments.of("1.6585135379127473e+18", "1658513537912747300", true)
  );

  // Values that should not be converted
  Stream<Object> rejected = Stream.of(
      Arguments.of("0.0001", null, false),
      Arguments.of("300.9999", null, false),
      Arguments.of("1928943043.0001", null, false)
  );

  // Keep the original ordering: accepted values first, rejected values last.
  return Stream.concat(convertible, rejected);
}

private static final String DURATION_AVRO_FILE_PATH = "/duration-logical-type.avsc";
Expand Down Expand Up @@ -196,7 +249,7 @@ void durationLogicalTypeTest(int months, int days, int milliseconds) throws IOEx
durationRecord.put("duration", new GenericData.Fixed(schema.getField("duration").schema(), buffer.array()));

GenericRecord real = CONVERTER.convert(json, schema);
Assertions.assertEquals(durationRecord, real);
assertEquals(durationRecord, real);
}

@ParameterizedTest
Expand Down Expand Up @@ -235,7 +288,7 @@ void dateLogicalTypeTest(int groundTruth, Object dateInput) throws IOException {
data.put("dateField", dateInput);
String json = MAPPER.writeValueAsString(data);
GenericRecord real = CONVERTER.convert(json, schema);
Assertions.assertEquals(record, real);
assertEquals(record, real);
}

/**
Expand Down Expand Up @@ -285,7 +338,7 @@ void localTimestampLogicalTypeGoodCaseTest(
data.put("localTimestampMicrosField", timeMicro);
String json = MAPPER.writeValueAsString(data);
GenericRecord real = CONVERTER.convert(json, schema);
Assertions.assertEquals(record, real);
assertEquals(record, real);
}

private static final String LOCAL_TIMESTAMP_MILLI_AVRO_FILE_PATH = "/local-timestamp-millis-logical-type.avsc";
Expand Down Expand Up @@ -332,7 +385,7 @@ void timestampLogicalTypeGoodCaseTest(
data.put("timestampMicrosField", timeMicro);
String json = MAPPER.writeValueAsString(data);
GenericRecord real = CONVERTER.convert(json, schema);
Assertions.assertEquals(record, real);
assertEquals(record, real);
}

@ParameterizedTest
Expand Down Expand Up @@ -386,7 +439,7 @@ void timeLogicalTypeTest(Long expectedMicroSecOfDay, Object timeMilli, Object ti
data.put("timeMillisField", timeMilli);
String json = MAPPER.writeValueAsString(data);
GenericRecord real = CONVERTER.convert(json, schema);
Assertions.assertEquals(record, real);
assertEquals(record, real);
}

@ParameterizedTest
Expand Down Expand Up @@ -434,7 +487,7 @@ void uuidLogicalTypeTest(String uuid) throws IOException {
data.put("uuidField", uuid);
String json = MAPPER.writeValueAsString(data);
GenericRecord real = CONVERTER.convert(json, schema);
Assertions.assertEquals(record, real);
assertEquals(record, real);
}

@Test
Expand All @@ -456,7 +509,7 @@ public void conversionWithFieldNameSanitization() throws IOException {
rec.put("favorite__number", number);
rec.put("favorite__color__", color);

Assertions.assertEquals(rec, CONVERTER.convert(json, sanitizedSchema));
assertEquals(rec, CONVERTER.convert(json, sanitizedSchema));
}

@Test
Expand All @@ -479,6 +532,6 @@ public void conversionWithFieldNameAliases() throws IOException {
rec.put("favorite_number", number);
rec.put("favorite_color", color);

Assertions.assertEquals(rec, CONVERTER.convert(json, sanitizedSchema));
assertEquals(rec, CONVERTER.convert(json, sanitizedSchema));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
{
"namespace": "example.avro",
"type": "record",
"name": "decimalLogicalType",
"fields": [
{"name": "decimalField", "type": {"type": "bytes", "logicalType": "decimal", "precision": 5, "scale": 0}}
]
}

0 comments on commit 7013912

Please sign in to comment.