Skip to content

Commit

Permalink
apacheGH-41407: [C++] Use static method to fill scalar scratch space …
Browse files Browse the repository at this point in the history
…to prevent ub (apache#41421)

### Rationale for this change

In apache#40237, I introduced scalar scratch space filling in the concrete scalar sub-class constructors, which involves a static down-cast of `this` to a sub-class pointer. Though this is common in CRTP, here it happens in the base class constructor, and it is reported in apache#41407 to be UB by UBSAN's "vptr" check.

I'm not enough of a language lawyer to tell whether this is a true or a false positive, so I propose two approaches:
1. The easy way: add a suppression in [1], as we already did for `shared_ptr`. However, this won't be acceptable if this is a true positive (a language lawyer's help is needed to confirm).
2. The hard way: avoid this purported UB entirely, at the cost of introducing more boilerplate code. This PR takes the hard way.

[1] https://github.com/apache/arrow/blob/main/r/tools/ubsan.supp

### What changes are included in this PR?

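Make `FillScratchSpace` static: it now receives the scratch space pointer and the data it needs (the value, type, or type code) as explicit arguments instead of reading them from members of `this`.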

### Are these changes tested?

The existing unit tests should cover it well.

### Are there any user-facing changes?

None.

* GitHub Issue: apache#41407

Lead-authored-by: Ruoxi Sun <zanmato1984@gmail.com>
Co-authored-by: Rossi Sun <zanmato1984@gmail.com>
Co-authored-by: Benjamin Kietzman <bengilgit@gmail.com>
Signed-off-by: Benjamin Kietzman <bengilgit@gmail.com>
  • Loading branch information
2 people authored and tolleybot committed May 2, 2024
1 parent f046a15 commit 391efe7
Show file tree
Hide file tree
Showing 2 changed files with 130 additions and 55 deletions.
73 changes: 42 additions & 31 deletions cpp/src/arrow/scalar.cc
Original file line number Diff line number Diff line change
Expand Up @@ -563,25 +563,28 @@ Status Scalar::ValidateFull() const {
BaseBinaryScalar::BaseBinaryScalar(std::string s, std::shared_ptr<DataType> type)
: BaseBinaryScalar(Buffer::FromString(std::move(s)), std::move(type)) {}

void BinaryScalar::FillScratchSpace() {
void BinaryScalar::FillScratchSpace(uint8_t* scratch_space,
const std::shared_ptr<Buffer>& value) {
FillScalarScratchSpace(
scratch_space_,
scratch_space,
{int32_t(0), value ? static_cast<int32_t>(value->size()) : int32_t(0)});
}

void BinaryViewScalar::FillScratchSpace() {
void BinaryViewScalar::FillScratchSpace(uint8_t* scratch_space,
const std::shared_ptr<Buffer>& value) {
static_assert(sizeof(BinaryViewType::c_type) <= internal::kScalarScratchSpaceSize);
auto* view = new (&scratch_space_) BinaryViewType::c_type;
auto* view = new (scratch_space) BinaryViewType::c_type;
if (value) {
*view = util::ToBinaryView(std::string_view{*value}, 0, 0);
} else {
*view = {};
}
}

void LargeBinaryScalar::FillScratchSpace() {
void LargeBinaryScalar::FillScratchSpace(uint8_t* scratch_space,
const std::shared_ptr<Buffer>& value) {
FillScalarScratchSpace(
scratch_space_,
scratch_space,
{int64_t(0), value ? static_cast<int64_t>(value->size()) : int64_t(0)});
}

Expand Down Expand Up @@ -612,36 +615,40 @@ BaseListScalar::BaseListScalar(std::shared_ptr<Array> value,
}

ListScalar::ListScalar(std::shared_ptr<Array> value, bool is_valid)
: BaseListScalar(value, list(value->type()), is_valid) {}
: ListScalar(value, list(value->type()), is_valid) {}

void ListScalar::FillScratchSpace() {
void ListScalar::FillScratchSpace(uint8_t* scratch_space,
const std::shared_ptr<Array>& value) {
FillScalarScratchSpace(
scratch_space_,
scratch_space,
{int32_t(0), value ? static_cast<int32_t>(value->length()) : int32_t(0)});
}

LargeListScalar::LargeListScalar(std::shared_ptr<Array> value, bool is_valid)
: BaseListScalar(value, large_list(value->type()), is_valid) {}
: LargeListScalar(value, large_list(value->type()), is_valid) {}

void LargeListScalar::FillScratchSpace() {
FillScalarScratchSpace(scratch_space_,
void LargeListScalar::FillScratchSpace(uint8_t* scratch_space,
const std::shared_ptr<Array>& value) {
FillScalarScratchSpace(scratch_space,
{int64_t(0), value ? value->length() : int64_t(0)});
}

ListViewScalar::ListViewScalar(std::shared_ptr<Array> value, bool is_valid)
: BaseListScalar(value, list_view(value->type()), is_valid) {}
: ListViewScalar(value, list_view(value->type()), is_valid) {}

void ListViewScalar::FillScratchSpace() {
void ListViewScalar::FillScratchSpace(uint8_t* scratch_space,
const std::shared_ptr<Array>& value) {
FillScalarScratchSpace(
scratch_space_,
scratch_space,
{int32_t(0), value ? static_cast<int32_t>(value->length()) : int32_t(0)});
}

LargeListViewScalar::LargeListViewScalar(std::shared_ptr<Array> value, bool is_valid)
: BaseListScalar(value, large_list_view(value->type()), is_valid) {}
: LargeListViewScalar(value, large_list_view(value->type()), is_valid) {}

void LargeListViewScalar::FillScratchSpace() {
FillScalarScratchSpace(scratch_space_,
void LargeListViewScalar::FillScratchSpace(uint8_t* scratch_space,
const std::shared_ptr<Array>& value) {
FillScalarScratchSpace(scratch_space,
{int64_t(0), value ? value->length() : int64_t(0)});
}

Expand All @@ -652,11 +659,12 @@ inline std::shared_ptr<DataType> MakeMapType(const std::shared_ptr<DataType>& pa
}

MapScalar::MapScalar(std::shared_ptr<Array> value, bool is_valid)
: BaseListScalar(value, MakeMapType(value->type()), is_valid) {}
: MapScalar(value, MakeMapType(value->type()), is_valid) {}

void MapScalar::FillScratchSpace() {
void MapScalar::FillScratchSpace(uint8_t* scratch_space,
const std::shared_ptr<Array>& value) {
FillScalarScratchSpace(
scratch_space_,
scratch_space,
{int32_t(0), value ? static_cast<int32_t>(value->length()) : int32_t(0)});
}

Expand Down Expand Up @@ -705,7 +713,9 @@ Result<std::shared_ptr<Scalar>> StructScalar::field(FieldRef ref) const {

RunEndEncodedScalar::RunEndEncodedScalar(std::shared_ptr<Scalar> value,
std::shared_ptr<DataType> type)
: Scalar{std::move(type), value->is_valid}, value{std::move(value)} {
: Scalar{std::move(type), value->is_valid},
ArraySpanFillFromScalarScratchSpace(*this->type),
value{std::move(value)} {
ARROW_CHECK_EQ(this->type->id(), Type::RUN_END_ENCODED);
}

Expand All @@ -716,18 +726,18 @@ RunEndEncodedScalar::RunEndEncodedScalar(const std::shared_ptr<DataType>& type)

RunEndEncodedScalar::~RunEndEncodedScalar() = default;

void RunEndEncodedScalar::FillScratchSpace() {
auto run_end = run_end_type()->id();
void RunEndEncodedScalar::FillScratchSpace(uint8_t* scratch_space, const DataType& type) {
Type::type run_end = checked_cast<const RunEndEncodedType&>(type).run_end_type()->id();
switch (run_end) {
case Type::INT16:
FillScalarScratchSpace(scratch_space_, {int16_t(1)});
FillScalarScratchSpace(scratch_space, {int16_t(1)});
break;
case Type::INT32:
FillScalarScratchSpace(scratch_space_, {int32_t(1)});
FillScalarScratchSpace(scratch_space, {int32_t(1)});
break;
default:
DCHECK_EQ(run_end, Type::INT64);
FillScalarScratchSpace(scratch_space_, {int64_t(1)});
FillScalarScratchSpace(scratch_space, {int64_t(1)});
}
}

Expand Down Expand Up @@ -806,6 +816,7 @@ Result<TimestampScalar> TimestampScalar::FromISO8601(std::string_view iso8601,
SparseUnionScalar::SparseUnionScalar(ValueType value, int8_t type_code,
std::shared_ptr<DataType> type)
: UnionScalar(std::move(type), type_code, /*is_valid=*/true),
ArraySpanFillFromScalarScratchSpace(type_code),
value(std::move(value)) {
const auto child_ids = checked_cast<const SparseUnionType&>(*this->type).child_ids();
if (type_code >= 0 && static_cast<size_t>(type_code) < child_ids.size() &&
Expand Down Expand Up @@ -833,13 +844,13 @@ std::shared_ptr<Scalar> SparseUnionScalar::FromValue(std::shared_ptr<Scalar> val
return std::make_shared<SparseUnionScalar>(field_values, type_code, std::move(type));
}

void SparseUnionScalar::FillScratchSpace() {
auto* union_scratch_space = reinterpret_cast<UnionScratchSpace*>(&scratch_space_);
void SparseUnionScalar::FillScratchSpace(uint8_t* scratch_space, int8_t type_code) {
auto* union_scratch_space = reinterpret_cast<UnionScratchSpace*>(scratch_space);
union_scratch_space->type_code = type_code;
}

void DenseUnionScalar::FillScratchSpace() {
auto* union_scratch_space = reinterpret_cast<UnionScratchSpace*>(&scratch_space_);
void DenseUnionScalar::FillScratchSpace(uint8_t* scratch_space, int8_t type_code) {
auto* union_scratch_space = reinterpret_cast<UnionScratchSpace*>(scratch_space);
union_scratch_space->type_code = type_code;
FillScalarScratchSpace(union_scratch_space->offsets, {int32_t(0), int32_t(1)});
}
Expand Down
Loading

0 comments on commit 391efe7

Please sign in to comment.