Skip to content

Commit

Permalink
apacheGH-41407: [C++] Use static method to fill scalar scratch space …
Browse files Browse the repository at this point in the history
…to prevent ub (apache#41421)

### Rationale for this change

In apache#40237, I introduced scalar scratch space filling in concrete scalar sub-class constructor, in which there is a static down-casting of `this` to sub-class pointer. Though this is common in CRTP, it happens in base cast constructor. And this is reported in apache#41407 to be UB by UBSAN's "vptr" sanitizing.

I'm not a language lawyer to tell if this is a true/false-positive. So I proposed two approaches:
1. The easy way: add suppression in [1], like we already did for `shared_ptr`. But apparently this won't be feasible if this is a true-positive (need some language lawyer's help to confirm).
2. The hard way: totally avoid this so-to-speak UB but may introduce more boilerplate code. This PR is the hard way.

[1] https://github.com/apache/arrow/blob/main/r/tools/ubsan.supp

### What changes are included in this PR?

Make `FillScratchSpace` static.

### Are these changes tested?

The existing UT should cover it well.

### Are there any user-facing changes?

None.

* GitHub Issue: apache#41407

Lead-authored-by: Ruoxi Sun <zanmato1984@gmail.com>
Co-authored-by: Rossi Sun <zanmato1984@gmail.com>
Co-authored-by: Benjamin Kietzman <bengilgit@gmail.com>
Signed-off-by: Benjamin Kietzman <bengilgit@gmail.com>
  • Loading branch information
2 people authored and tolleybot committed May 4, 2024
1 parent 7c0c3fd commit 8b02cac
Show file tree
Hide file tree
Showing 2 changed files with 96 additions and 63 deletions.
61 changes: 18 additions & 43 deletions cpp/src/arrow/scalar.cc
Original file line number Diff line number Diff line change
Expand Up @@ -563,25 +563,28 @@ Status Scalar::ValidateFull() const {
BaseBinaryScalar::BaseBinaryScalar(std::string s, std::shared_ptr<DataType> type)
: BaseBinaryScalar(Buffer::FromString(std::move(s)), std::move(type)) {}

void BinaryScalar::FillScratchSpace() {
void BinaryScalar::FillScratchSpace(uint8_t* scratch_space,
const std::shared_ptr<Buffer>& value) {
FillScalarScratchSpace(
scratch_space_,
scratch_space,
{int32_t(0), value ? static_cast<int32_t>(value->size()) : int32_t(0)});
}

void BinaryViewScalar::FillScratchSpace() {
void BinaryViewScalar::FillScratchSpace(uint8_t* scratch_space,
const std::shared_ptr<Buffer>& value) {
static_assert(sizeof(BinaryViewType::c_type) <= internal::kScalarScratchSpaceSize);
auto* view = new (&scratch_space_) BinaryViewType::c_type;
auto* view = new (scratch_space) BinaryViewType::c_type;
if (value) {
*view = util::ToBinaryView(std::string_view{*value}, 0, 0);
} else {
*view = {};
}
}

void LargeBinaryScalar::FillScratchSpace() {
void LargeBinaryScalar::FillScratchSpace(uint8_t* scratch_space,
const std::shared_ptr<Buffer>& value) {
FillScalarScratchSpace(
scratch_space_,
scratch_space,
{int64_t(0), value ? static_cast<int64_t>(value->size()) : int64_t(0)});
}

Expand Down Expand Up @@ -621,12 +624,6 @@ void ListScalar::FillScratchSpace(uint8_t* scratch_space,
{int32_t(0), value ? static_cast<int32_t>(value->length()) : int32_t(0)});
}

void ListScalar::FillScratchSpace() {
FillScalarScratchSpace(
scratch_space_,
{int32_t(0), value ? static_cast<int32_t>(value->length()) : int32_t(0)});
}

LargeListScalar::LargeListScalar(std::shared_ptr<Array> value, bool is_valid)
: LargeListScalar(value, large_list(value->type()), is_valid) {}

Expand All @@ -636,11 +633,6 @@ void LargeListScalar::FillScratchSpace(uint8_t* scratch_space,
{int64_t(0), value ? value->length() : int64_t(0)});
}

void LargeListScalar::FillScratchSpace() {
FillScalarScratchSpace(scratch_space_,
{int64_t(0), value ? value->length() : int64_t(0)});
}

ListViewScalar::ListViewScalar(std::shared_ptr<Array> value, bool is_valid)
: ListViewScalar(value, list_view(value->type()), is_valid) {}

Expand All @@ -651,12 +643,6 @@ void ListViewScalar::FillScratchSpace(uint8_t* scratch_space,
{int32_t(0), value ? static_cast<int32_t>(value->length()) : int32_t(0)});
}

void ListViewScalar::FillScratchSpace() {
FillScalarScratchSpace(
scratch_space_,
{int32_t(0), value ? static_cast<int32_t>(value->length()) : int32_t(0)});
}

LargeListViewScalar::LargeListViewScalar(std::shared_ptr<Array> value, bool is_valid)
: LargeListViewScalar(value, large_list_view(value->type()), is_valid) {}

Expand All @@ -666,11 +652,6 @@ void LargeListViewScalar::FillScratchSpace(uint8_t* scratch_space,
{int64_t(0), value ? value->length() : int64_t(0)});
}

void LargeListViewScalar::FillScratchSpace() {
FillScalarScratchSpace(scratch_space_,
{int64_t(0), value ? value->length() : int64_t(0)});
}

inline std::shared_ptr<DataType> MakeMapType(const std::shared_ptr<DataType>& pair_type) {
ARROW_CHECK_EQ(pair_type->id(), Type::STRUCT);
ARROW_CHECK_EQ(pair_type->num_fields(), 2);
Expand All @@ -687,12 +668,6 @@ void MapScalar::FillScratchSpace(uint8_t* scratch_space,
{int32_t(0), value ? static_cast<int32_t>(value->length()) : int32_t(0)});
}

void MapScalar::FillScratchSpace() {
FillScalarScratchSpace(
scratch_space_,
{int32_t(0), value ? static_cast<int32_t>(value->length()) : int32_t(0)});
}

FixedSizeListScalar::FixedSizeListScalar(std::shared_ptr<Array> value,
std::shared_ptr<DataType> type, bool is_valid)
: BaseListScalar(std::move(value), std::move(type), is_valid) {
Expand Down Expand Up @@ -751,18 +726,18 @@ RunEndEncodedScalar::RunEndEncodedScalar(const std::shared_ptr<DataType>& type)

RunEndEncodedScalar::~RunEndEncodedScalar() = default;

void RunEndEncodedScalar::FillScratchSpace() {
auto run_end = run_end_type()->id();
void RunEndEncodedScalar::FillScratchSpace(uint8_t* scratch_space, const DataType& type) {
Type::type run_end = checked_cast<const RunEndEncodedType&>(type).run_end_type()->id();
switch (run_end) {
case Type::INT16:
FillScalarScratchSpace(scratch_space_, {int16_t(1)});
FillScalarScratchSpace(scratch_space, {int16_t(1)});
break;
case Type::INT32:
FillScalarScratchSpace(scratch_space_, {int32_t(1)});
FillScalarScratchSpace(scratch_space, {int32_t(1)});
break;
default:
DCHECK_EQ(run_end, Type::INT64);
FillScalarScratchSpace(scratch_space_, {int64_t(1)});
FillScalarScratchSpace(scratch_space, {int64_t(1)});
}
}

Expand Down Expand Up @@ -869,13 +844,13 @@ std::shared_ptr<Scalar> SparseUnionScalar::FromValue(std::shared_ptr<Scalar> val
return std::make_shared<SparseUnionScalar>(field_values, type_code, std::move(type));
}

void SparseUnionScalar::FillScratchSpace() {
auto* union_scratch_space = reinterpret_cast<UnionScratchSpace*>(&scratch_space_);
void SparseUnionScalar::FillScratchSpace(uint8_t* scratch_space, int8_t type_code) {
auto* union_scratch_space = reinterpret_cast<UnionScratchSpace*>(scratch_space);
union_scratch_space->type_code = type_code;
}

void DenseUnionScalar::FillScratchSpace() {
auto* union_scratch_space = reinterpret_cast<UnionScratchSpace*>(&scratch_space_);
void DenseUnionScalar::FillScratchSpace(uint8_t* scratch_space, int8_t type_code) {
auto* union_scratch_space = reinterpret_cast<UnionScratchSpace*>(scratch_space);
union_scratch_space->type_code = type_code;
FillScalarScratchSpace(union_scratch_space->offsets, {int32_t(0), int32_t(1)});
}
Expand Down
98 changes: 78 additions & 20 deletions cpp/src/arrow/scalar.h
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,12 @@ struct ARROW_EXPORT ArraySpanFillFromScalarScratchSpace {
alignas(int64_t) mutable uint8_t scratch_space_[kScalarScratchSpaceSize];

private:
ArraySpanFillFromScalarScratchSpace() { static_cast<Impl*>(this)->FillScratchSpace(); }
template <typename... Args>
explicit ArraySpanFillFromScalarScratchSpace(Args&&... args) {
Impl::FillScratchSpace(scratch_space_, std::forward<Args>(args)...);
}

ArraySpanFillFromScalarScratchSpace() = delete;

friend Impl;
};
Expand Down Expand Up @@ -278,11 +283,22 @@ struct ARROW_EXPORT BaseBinaryScalar : public internal::PrimitiveScalarBase {
struct ARROW_EXPORT BinaryScalar
: public BaseBinaryScalar,
private internal::ArraySpanFillFromScalarScratchSpace<BinaryScalar> {
using BaseBinaryScalar::BaseBinaryScalar;
using TypeClass = BinaryType;
using ArraySpanFillFromScalarScratchSpace =
internal::ArraySpanFillFromScalarScratchSpace<BinaryScalar>;

explicit BinaryScalar(std::shared_ptr<DataType> type)
: BaseBinaryScalar(std::move(type)),
ArraySpanFillFromScalarScratchSpace(this->value) {}

BinaryScalar(std::shared_ptr<Buffer> value, std::shared_ptr<DataType> type)
: BaseBinaryScalar(std::move(value), std::move(type)),
ArraySpanFillFromScalarScratchSpace(this->value) {}

BinaryScalar(std::string s, std::shared_ptr<DataType> type)
: BaseBinaryScalar(std::move(s), std::move(type)),
ArraySpanFillFromScalarScratchSpace(this->value) {}

explicit BinaryScalar(std::shared_ptr<Buffer> value)
: BinaryScalar(std::move(value), binary()) {}

Expand All @@ -291,7 +307,8 @@ struct ARROW_EXPORT BinaryScalar
BinaryScalar() : BinaryScalar(binary()) {}

private:
void FillScratchSpace();
static void FillScratchSpace(uint8_t* scratch_space,
const std::shared_ptr<Buffer>& value);

friend ArraySpan;
friend ArraySpanFillFromScalarScratchSpace;
Expand All @@ -312,11 +329,22 @@ struct ARROW_EXPORT StringScalar : public BinaryScalar {
struct ARROW_EXPORT BinaryViewScalar
: public BaseBinaryScalar,
private internal::ArraySpanFillFromScalarScratchSpace<BinaryViewScalar> {
using BaseBinaryScalar::BaseBinaryScalar;
using TypeClass = BinaryViewType;
using ArraySpanFillFromScalarScratchSpace =
internal::ArraySpanFillFromScalarScratchSpace<BinaryViewScalar>;

explicit BinaryViewScalar(std::shared_ptr<DataType> type)
: BaseBinaryScalar(std::move(type)),
ArraySpanFillFromScalarScratchSpace(this->value) {}

BinaryViewScalar(std::shared_ptr<Buffer> value, std::shared_ptr<DataType> type)
: BaseBinaryScalar(std::move(value), std::move(type)),
ArraySpanFillFromScalarScratchSpace(this->value) {}

BinaryViewScalar(std::string s, std::shared_ptr<DataType> type)
: BaseBinaryScalar(std::move(s), std::move(type)),
ArraySpanFillFromScalarScratchSpace(this->value) {}

explicit BinaryViewScalar(std::shared_ptr<Buffer> value)
: BinaryViewScalar(std::move(value), binary_view()) {}

Expand All @@ -328,7 +356,8 @@ struct ARROW_EXPORT BinaryViewScalar
std::string_view view() const override { return std::string_view(*this->value); }

private:
void FillScratchSpace();
static void FillScratchSpace(uint8_t* scratch_space,
const std::shared_ptr<Buffer>& value);

friend ArraySpan;
friend ArraySpanFillFromScalarScratchSpace;
Expand All @@ -350,11 +379,14 @@ struct ARROW_EXPORT StringViewScalar : public BinaryViewScalar {
struct ARROW_EXPORT LargeBinaryScalar
: public BaseBinaryScalar,
private internal::ArraySpanFillFromScalarScratchSpace<LargeBinaryScalar> {
using BaseBinaryScalar::BaseBinaryScalar;
using TypeClass = LargeBinaryType;
using ArraySpanFillFromScalarScratchSpace =
internal::ArraySpanFillFromScalarScratchSpace<LargeBinaryScalar>;

explicit LargeBinaryScalar(std::shared_ptr<DataType> type)
: BaseBinaryScalar(std::move(type)),
ArraySpanFillFromScalarScratchSpace(this->value) {}

LargeBinaryScalar(std::shared_ptr<Buffer> value, std::shared_ptr<DataType> type)
: BaseBinaryScalar(std::move(value), std::move(type)),
ArraySpanFillFromScalarScratchSpace(this->value) {}
Expand All @@ -372,7 +404,8 @@ struct ARROW_EXPORT LargeBinaryScalar
LargeBinaryScalar() : LargeBinaryScalar(large_binary()) {}

private:
void FillScratchSpace();
static void FillScratchSpace(uint8_t* scratch_space,
const std::shared_ptr<Buffer>& value);

friend ArraySpan;
friend ArraySpanFillFromScalarScratchSpace;
Expand Down Expand Up @@ -555,14 +588,19 @@ struct ARROW_EXPORT ListScalar
: public BaseListScalar,
private internal::ArraySpanFillFromScalarScratchSpace<ListScalar> {
using TypeClass = ListType;
using BaseListScalar::BaseListScalar;
using ArraySpanFillFromScalarScratchSpace =
internal::ArraySpanFillFromScalarScratchSpace<ListScalar>;

ListScalar(std::shared_ptr<Array> value, std::shared_ptr<DataType> type,
bool is_valid = true)
: BaseListScalar(std::move(value), std::move(type), is_valid),
ArraySpanFillFromScalarScratchSpace(this->value) {}

explicit ListScalar(std::shared_ptr<Array> value, bool is_valid = true);

private:
void FillScratchSpace();
static void FillScratchSpace(uint8_t* scratch_space,
const std::shared_ptr<Array>& value);

friend ArraySpan;
friend ArraySpanFillFromScalarScratchSpace;
Expand All @@ -572,14 +610,19 @@ struct ARROW_EXPORT LargeListScalar
: public BaseListScalar,
private internal::ArraySpanFillFromScalarScratchSpace<LargeListScalar> {
using TypeClass = LargeListType;
using BaseListScalar::BaseListScalar;
using ArraySpanFillFromScalarScratchSpace =
internal::ArraySpanFillFromScalarScratchSpace<LargeListScalar>;

LargeListScalar(std::shared_ptr<Array> value, std::shared_ptr<DataType> type,
bool is_valid = true)
: BaseListScalar(std::move(value), std::move(type), is_valid),
ArraySpanFillFromScalarScratchSpace(this->value) {}

explicit LargeListScalar(std::shared_ptr<Array> value, bool is_valid = true);

private:
void FillScratchSpace();
static void FillScratchSpace(uint8_t* scratch_space,
const std::shared_ptr<Array>& value);

friend ArraySpan;
friend ArraySpanFillFromScalarScratchSpace;
Expand All @@ -589,14 +632,19 @@ struct ARROW_EXPORT ListViewScalar
: public BaseListScalar,
private internal::ArraySpanFillFromScalarScratchSpace<ListViewScalar> {
using TypeClass = ListViewType;
using BaseListScalar::BaseListScalar;
using ArraySpanFillFromScalarScratchSpace =
internal::ArraySpanFillFromScalarScratchSpace<ListViewScalar>;

ListViewScalar(std::shared_ptr<Array> value, std::shared_ptr<DataType> type,
bool is_valid = true)
: BaseListScalar(std::move(value), std::move(type), is_valid),
ArraySpanFillFromScalarScratchSpace(this->value) {}

explicit ListViewScalar(std::shared_ptr<Array> value, bool is_valid = true);

private:
void FillScratchSpace();
static void FillScratchSpace(uint8_t* scratch_space,
const std::shared_ptr<Array>& value);

friend ArraySpan;
friend ArraySpanFillFromScalarScratchSpace;
Expand All @@ -606,14 +654,19 @@ struct ARROW_EXPORT LargeListViewScalar
: public BaseListScalar,
private internal::ArraySpanFillFromScalarScratchSpace<LargeListViewScalar> {
using TypeClass = LargeListViewType;
using BaseListScalar::BaseListScalar;
using ArraySpanFillFromScalarScratchSpace =
internal::ArraySpanFillFromScalarScratchSpace<LargeListViewScalar>;

LargeListViewScalar(std::shared_ptr<Array> value, std::shared_ptr<DataType> type,
bool is_valid = true)
: BaseListScalar(std::move(value), std::move(type), is_valid),
ArraySpanFillFromScalarScratchSpace(this->value) {}

explicit LargeListViewScalar(std::shared_ptr<Array> value, bool is_valid = true);

private:
void FillScratchSpace();
static void FillScratchSpace(uint8_t* scratch_space,
const std::shared_ptr<Array>& value);

friend ArraySpan;
friend ArraySpanFillFromScalarScratchSpace;
Expand All @@ -623,14 +676,19 @@ struct ARROW_EXPORT MapScalar
: public BaseListScalar,
private internal::ArraySpanFillFromScalarScratchSpace<MapScalar> {
using TypeClass = MapType;
using BaseListScalar::BaseListScalar;
using ArraySpanFillFromScalarScratchSpace =
internal::ArraySpanFillFromScalarScratchSpace<MapScalar>;

MapScalar(std::shared_ptr<Array> value, std::shared_ptr<DataType> type,
bool is_valid = true)
: BaseListScalar(std::move(value), std::move(type), is_valid),
ArraySpanFillFromScalarScratchSpace(this->value) {}

explicit MapScalar(std::shared_ptr<Array> value, bool is_valid = true);

private:
void FillScratchSpace();
static void FillScratchSpace(uint8_t* scratch_space,
const std::shared_ptr<Array>& value);

friend ArraySpan;
friend ArraySpanFillFromScalarScratchSpace;
Expand Down Expand Up @@ -712,7 +770,7 @@ struct ARROW_EXPORT SparseUnionScalar
std::shared_ptr<DataType> type);

private:
void FillScratchSpace();
static void FillScratchSpace(uint8_t* scratch_space, int8_t type_code);

friend ArraySpan;
friend ArraySpanFillFromScalarScratchSpace;
Expand Down Expand Up @@ -742,7 +800,7 @@ struct ARROW_EXPORT DenseUnionScalar
value(std::move(value)) {}

private:
void FillScratchSpace();
static void FillScratchSpace(uint8_t* scratch_space, int8_t type_code);

friend ArraySpan;
friend ArraySpanFillFromScalarScratchSpace;
Expand Down Expand Up @@ -778,7 +836,7 @@ struct ARROW_EXPORT RunEndEncodedScalar
private:
const TypeClass& ree_type() const { return internal::checked_cast<TypeClass&>(*type); }

void FillScratchSpace();
static void FillScratchSpace(uint8_t* scratch_space, const DataType& type);

friend ArraySpan;
friend ArraySpanFillFromScalarScratchSpace;
Expand Down

0 comments on commit 8b02cac

Please sign in to comment.