Skip to content

Commit

Permalink
Archery C++ round trip working. Java disabled. Fix c-bridge (apache#8268)
Browse files Browse the repository at this point in the history

The Archery lint issue still needs to be fixed; I'll do that in a follow-up.
  • Loading branch information
emkornfield committed Oct 17, 2020
1 parent 68b13c2 commit 6279531
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 16 deletions.
22 changes: 18 additions & 4 deletions cpp/src/arrow/c/bridge.cc
Original file line number Diff line number Diff line change
Expand Up @@ -305,8 +305,15 @@ struct SchemaExporter {
}

Status Visit(const DecimalType& type) {
return SetFormat("d:" + std::to_string(type.precision()) + "," +
std::to_string(type.scale()));
if (type.bit_width() == 128) {
// 128 is the default bit-width
return SetFormat("d:" + std::to_string(type.precision()) + "," +
std::to_string(type.scale()));
} else {
return SetFormat("d:" + std::to_string(type.precision()) + "," +
std::to_string(type.scale()) + "," +
std::to_string(type.bit_width()));
}
}

// Exports a binary type by setting the "z" format string.
Status Visit(const BinaryType& type) {
  return SetFormat("z");
}
Expand Down Expand Up @@ -973,13 +980,20 @@ struct SchemaImporter {
Status ProcessDecimal() {
RETURN_NOT_OK(f_parser_.CheckNext(':'));
ARROW_ASSIGN_OR_RAISE(auto prec_scale, f_parser_.ParseInts(f_parser_.Rest()));
if (prec_scale.size() != 2) {
// 3 elements indicates bit width was communicated as well.
if (prec_scale.size() != 2 && prec_scale.size() != 3) {
return f_parser_.Invalid();
}
if (prec_scale[0] <= 0 || prec_scale[1] <= 0) {
return f_parser_.Invalid();
}
type_ = decimal(prec_scale[0], prec_scale[1]);
if (prec_scale.size() == 2 || prec_scale[2] == 128) {
type_ = decimal(prec_scale[0], prec_scale[1]);
} else if (prec_scale[2] == 256) {
type_ = decimal256(prec_scale[0], prec_scale[1]);
} else {
return f_parser_.Invalid();
}
return Status::OK();
}

Expand Down
2 changes: 2 additions & 0 deletions cpp/src/arrow/c/bridge_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,7 @@ TEST_F(TestSchemaExport, Primitive) {
TestPrimitive(large_utf8(), "U");

TestPrimitive(decimal(16, 4), "d:16,4");
TestPrimitive(decimal256(16, 4), "d:16,4,256");
}

TEST_F(TestSchemaExport, Temporal) {
Expand Down Expand Up @@ -740,6 +741,7 @@ TEST_F(TestArrayExport, Primitive) {
TestPrimitive(large_utf8(), R"(["foo", "bar", null])");

TestPrimitive(decimal(16, 4), R"(["1234.5670", null])");
TestPrimitive(decimal256(16, 4), R"(["1234.5670", null])");
}

TEST_F(TestArrayExport, PrimitiveSliced) {
Expand Down
28 changes: 21 additions & 7 deletions cpp/src/arrow/ipc/metadata_internal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,8 @@ static inline TimeUnit::type FromFlatbufferUnit(flatbuf::TimeUnit unit) {
return TimeUnit::SECOND;
}

constexpr int32_t kDecimalBitWidth = 128;
constexpr int32_t kDecimalBitWidth128 = 128;
constexpr int32_t kDecimalBitWidth256 = 256;

Status ConcreteTypeFromFlatbuffer(flatbuf::Type type, const void* type_data,
const std::vector<std::shared_ptr<Field>>& children,
Expand Down Expand Up @@ -273,10 +274,13 @@ Status ConcreteTypeFromFlatbuffer(flatbuf::Type type, const void* type_data,
return Status::OK();
case flatbuf::Type::Decimal: {
auto dec_type = static_cast<const flatbuf::Decimal*>(type_data);
if (dec_type->bitWidth() != kDecimalBitWidth) {
return Status::Invalid("Library only supports 128-bit decimal values");
if (dec_type->bitWidth() == kDecimalBitWidth128) {
return Decimal128Type::Make(dec_type->precision(), dec_type->scale()).Value(out);
} else if (dec_type->bitWidth() == kDecimalBitWidth256) {
return Decimal256Type::Make(dec_type->precision(), dec_type->scale()).Value(out);
} else {
return Status::Invalid("Library only supports 128-bit or 256-bit decimal values");
}
return Decimal128Type::Make(dec_type->precision(), dec_type->scale()).Value(out);
}
case flatbuf::Type::Date: {
auto date_type = static_cast<const flatbuf::Date*>(type_data);
Expand Down Expand Up @@ -594,11 +598,21 @@ class FieldToFlatbufferVisitor {
return Status::OK();
}

Status Visit(const DecimalType& type) {
Status Visit(const Decimal128Type& type) {
const auto& dec_type = checked_cast<const Decimal128Type&>(type);
fb_type_ = flatbuf::Type::Decimal;
type_offset_ =
flatbuf::CreateDecimal(fbb_, dec_type.precision(), dec_type.scale()).Union();
type_offset_ = flatbuf::CreateDecimal(fbb_, dec_type.precision(), dec_type.scale(),
/*bitWidth=*/128)
.Union();
return Status::OK();
}

// Serializes a 256-bit decimal type as a flatbuf Decimal, recording the
// precision and scale of `type` together with an explicit bit width of 256.
Status Visit(const Decimal256Type& type) {
  fb_type_ = flatbuf::Type::Decimal;
  // `type` is already a Decimal256Type, so it is used directly rather than
  // being cast to its own type first.
  type_offset_ = flatbuf::CreateDecimal(fbb_, type.precision(), type.scale(),
                                        /*bitWidth=*/256)
                     .Union();
  return Status::OK();
}

Expand Down
19 changes: 14 additions & 5 deletions dev/archery/archery/integration/datagen.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,14 +400,15 @@ def generate_column(self, size, name=None):

DECIMAL_PRECISION_TO_VALUE = {
key: (1 << (8 * i - 1)) - 1 for i, key in enumerate(
[1, 3, 5, 7, 10, 12, 15, 17, 19, 22, 24, 27, 29, 32, 34, 36],
[1, 3, 5, 7, 10, 12, 15, 17, 19, 22, 24, 27, 29, 32, 34, 36,
38, 40, 42, 44, 50, 60, 70],
start=1,
)
}


def decimal_range_from_precision(precision):
assert 1 <= precision <= 38
assert 1 <= precision <= 76
try:
max_value = DECIMAL_PRECISION_TO_VALUE[precision]
except KeyError:
Expand All @@ -417,7 +418,7 @@ def decimal_range_from_precision(precision):


class DecimalField(PrimitiveField):
def __init__(self, name, precision, scale, bit_width=128, *,
def __init__(self, name, precision, scale, bit_width, *,
nullable=True, metadata=None):
super().__init__(name, nullable=True,
metadata=metadata)
Expand All @@ -434,6 +435,7 @@ def _get_type(self):
('name', 'decimal'),
('precision', self.precision),
('scale', self.scale),
('bitWidth', self.bit_width),
])

def generate_column(self, size, name=None):
Expand All @@ -448,7 +450,7 @@ def generate_column(self, size, name=None):

class DecimalColumn(PrimitiveColumn):

def __init__(self, name, count, is_valid, values, bit_width=128):
def __init__(self, name, count, is_valid, values, bit_width):
super().__init__(name, count, is_valid, values)
self.bit_width = bit_width

Expand Down Expand Up @@ -1274,8 +1276,13 @@ def generate_null_trivial_case(batch_sizes):

def generate_decimal_case():
fields = [
DecimalField(name='f{}'.format(i), precision=precision, scale=2)
DecimalField(name='f{}'.format(i), precision=precision, scale=2,
bit_width=128)
for i, precision in enumerate(range(3, 39))
] + [
DecimalField(name='f{}'.format(i), precision=precision, scale=5,
bit_width=256)
for i, precision in enumerate(range(37, 70))
]

possible_batch_sizes = 7, 10
Expand Down Expand Up @@ -1511,6 +1518,8 @@ def _temp_path():
generate_decimal_case()
.skip_category('Go') # TODO(ARROW-7948): Decimal + Go
.skip_category('Rust'),
.skip_category('Java'),


generate_datetime_case(),

Expand Down

0 comments on commit 6279531

Please sign in to comment.