Skip to content

Commit

Permalink
Make value immutable for all scalars with scratch space (#6)
Browse files Browse the repository at this point in the history
* Make value immutable for all scalars with scratch space

* Glib change for list scalar

* Fix

* Fix glib build
  • Loading branch information
zanmato1984 authored Apr 8, 2024
1 parent c6bb2ec commit 1510e20
Show file tree
Hide file tree
Showing 4 changed files with 105 additions and 59 deletions.
3 changes: 2 additions & 1 deletion c_glib/arrow-glib/scalar.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1984,7 +1984,8 @@ garrow_base_list_scalar_get_value(GArrowBaseListScalar *scalar)
if (!priv->value) {
const auto arrow_scalar = std::static_pointer_cast<arrow::BaseListScalar>(
garrow_scalar_get_raw(GARROW_SCALAR(scalar)));
priv->value = garrow_array_new_raw(&(arrow_scalar->value));
priv->value = garrow_array_new_raw(
const_cast<std::shared_ptr<arrow::Array> *>(&(arrow_scalar->value)));
}
return priv->value;
}
Expand Down
21 changes: 14 additions & 7 deletions cpp/src/arrow/scalar.cc
Original file line number Diff line number Diff line change
Expand Up @@ -617,7 +617,9 @@ FixedSizeBinaryScalar::FixedSizeBinaryScalar(std::string s, bool is_valid)
BaseListScalar::BaseListScalar(std::shared_ptr<Array> value,
std::shared_ptr<DataType> type, bool is_valid)
: Scalar{std::move(type), is_valid}, value(std::move(value)) {
ARROW_CHECK(this->type->field(0)->type()->Equals(this->value->type()));
if (this->value) {
ARROW_CHECK(this->type->field(0)->type()->Equals(this->value->type()));
}
}

ListScalar::ListScalar(std::shared_ptr<Array> value, bool is_valid)
Expand Down Expand Up @@ -669,8 +671,10 @@ void MapScalar::FillScratchSpace() {
FixedSizeListScalar::FixedSizeListScalar(std::shared_ptr<Array> value,
std::shared_ptr<DataType> type, bool is_valid)
: BaseListScalar(std::move(value), std::move(type), is_valid) {
ARROW_CHECK_EQ(this->value->length(),
checked_cast<const FixedSizeListType&>(*this->type).list_size());
if (value) {
ARROW_CHECK_EQ(this->value->length(),
checked_cast<const FixedSizeListType&>(*this->type).list_size());
}
}

FixedSizeListScalar::FixedSizeListScalar(std::shared_ptr<Array> value, bool is_valid)
Expand Down Expand Up @@ -811,11 +815,14 @@ SparseUnionScalar::SparseUnionScalar(ValueType value, int8_t type_code,
std::shared_ptr<DataType> type)
: UnionScalar(std::move(type), type_code, /*is_valid=*/true),
value(std::move(value)) {
this->child_id =
checked_cast<const SparseUnionType&>(*this->type).child_ids()[type_code];
const auto child_ids = checked_cast<const SparseUnionType&>(*this->type).child_ids();
if (type_code >= 0 && static_cast<size_t>(type_code) < child_ids.size() &&
child_ids[type_code] != UnionType::kInvalidChildId) {
this->child_id = child_ids[type_code];

// Fix nullness based on whether the selected child is null
this->is_valid = this->value[this->child_id]->is_valid;
// Fix nullness based on whether the selected child is null
this->is_valid = this->value[this->child_id]->is_valid;
}
}

std::shared_ptr<Scalar> SparseUnionScalar::FromValue(std::shared_ptr<Scalar> value,
Expand Down
11 changes: 5 additions & 6 deletions cpp/src/arrow/scalar.h
Original file line number Diff line number Diff line change
Expand Up @@ -540,13 +540,12 @@ struct ARROW_EXPORT Decimal256Scalar : public DecimalScalar<Decimal256Type, Deci
};

struct ARROW_EXPORT BaseListScalar : public Scalar {
using Scalar::Scalar;
using ValueType = std::shared_ptr<Array>;

BaseListScalar(std::shared_ptr<Array> value, std::shared_ptr<DataType> type,
bool is_valid = true);

std::shared_ptr<Array> value;
const std::shared_ptr<Array> value;
};

struct ARROW_EXPORT ListScalar
Expand Down Expand Up @@ -659,7 +658,7 @@ struct ARROW_EXPORT StructScalar : public Scalar {
};

struct ARROW_EXPORT UnionScalar : public Scalar {
int8_t type_code;
const int8_t type_code;

virtual const std::shared_ptr<Scalar>& child_value() const = 0;

Expand Down Expand Up @@ -687,7 +686,7 @@ struct ARROW_EXPORT SparseUnionScalar
// nonetheless construct a vector of scalars, one per union value, to have
// enough data to reconstruct a valid ArraySpan of length 1 from this scalar
using ValueType = std::vector<std::shared_ptr<Scalar>>;
ValueType value;
const ValueType value;

// The value index corresponding to the active type code
int child_id;
Expand Down Expand Up @@ -720,7 +719,7 @@ struct ARROW_EXPORT DenseUnionScalar
// For DenseUnionScalar, we can make a valid ArraySpan of length 1 from this
// scalar
using ValueType = std::shared_ptr<Scalar>;
ValueType value;
const ValueType value;

const std::shared_ptr<Scalar>& child_value() const override { return this->value; }

Expand All @@ -743,7 +742,7 @@ struct ARROW_EXPORT RunEndEncodedScalar
using ArraySpanFillFromScalarScratchSpace =
internal::ArraySpanFillFromScalarScratchSpace<RunEndEncodedScalar>;

ValueType value;
const ValueType value;

RunEndEncodedScalar(std::shared_ptr<Scalar> value, std::shared_ptr<DataType> type);

Expand Down
129 changes: 84 additions & 45 deletions cpp/src/arrow/scalar_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1167,24 +1167,25 @@ class TestListLikeScalar : public ::testing::Test {
}

void TestValidateErrors() {
ScalarType scalar(value_);
scalar.is_valid = false;
ASSERT_OK(scalar.ValidateFull());

// Value must be defined
scalar = ScalarType(value_);
scalar.value = nullptr;
AssertValidationFails(scalar);
{
ScalarType scalar(value_);
scalar.is_valid = false;
ASSERT_OK(scalar.ValidateFull());
}

// Inconsistent child type
scalar = ScalarType(value_);
scalar.value = ArrayFromJSON(int32(), "[1, 2, null]");
AssertValidationFails(scalar);
{
// Value must be defined
ScalarType scalar(nullptr, type_);
scalar.is_valid = true;
AssertValidationFails(scalar);
}

// Invalid UTF8 in child data
scalar = ScalarType(ArrayFromJSON(utf8(), "[null, null, \"\xff\"]"));
ASSERT_OK(scalar.Validate());
ASSERT_RAISES(Invalid, scalar.ValidateFull());
{
// Invalid UTF8 in child data
ScalarType scalar(ArrayFromJSON(utf8(), "[null, null, \"\xff\"]"));
ASSERT_OK(scalar.Validate());
ASSERT_RAISES(Invalid, scalar.ValidateFull());
}
}

void TestHashing() {
Expand Down Expand Up @@ -1576,17 +1577,41 @@ void CheckGetNullUnionScalar(const Array& arr, int64_t index) {
ASSERT_FALSE(checked_cast<const UnionScalar&>(*scalar).child_value()->is_valid);
}

std::shared_ptr<Scalar> MakeUnionScalar(const SparseUnionType& type, int8_t type_code,
std::shared_ptr<Scalar> field_value,
int field_index) {
ScalarVector field_values;
for (int i = 0; i < type.num_fields(); ++i) {
if (i == field_index) {
field_values.emplace_back(std::move(field_value));
} else {
field_values.emplace_back(MakeNullScalar(type.field(i)->type()));
}
}
return std::make_shared<SparseUnionScalar>(std::move(field_values), type_code,
type.GetSharedPtr());
}

std::shared_ptr<Scalar> MakeUnionScalar(const SparseUnionType& type,
std::shared_ptr<Scalar> field_value,
int field_index) {
return SparseUnionScalar::FromValue(field_value, field_index, type.GetSharedPtr());
return SparseUnionScalar::FromValue(std::move(field_value), field_index,
type.GetSharedPtr());
}

std::shared_ptr<Scalar> MakeUnionScalar(const DenseUnionType& type, int8_t type_code,
std::shared_ptr<Scalar> field_value,
int field_index) {
return std::make_shared<DenseUnionScalar>(std::move(field_value), type_code,
type.GetSharedPtr());
}

std::shared_ptr<Scalar> MakeUnionScalar(const DenseUnionType& type,
std::shared_ptr<Scalar> field_value,
int field_index) {
int8_t type_code = type.type_codes()[field_index];
return std::make_shared<DenseUnionScalar>(field_value, type_code, type.GetSharedPtr());
return std::make_shared<DenseUnionScalar>(std::move(field_value), type_code,
type.GetSharedPtr());
}

std::shared_ptr<Scalar> MakeSpecificNullScalar(const DenseUnionType& type,
Expand Down Expand Up @@ -1634,7 +1659,13 @@ class TestUnionScalar : public ::testing::Test {

std::shared_ptr<Scalar> ScalarFromValue(int field_index,
std::shared_ptr<Scalar> field_value) {
return MakeUnionScalar(*union_type_, field_value, field_index);
return MakeUnionScalar(*union_type_, std::move(field_value), field_index);
}

std::shared_ptr<Scalar> ScalarFromTypeCodeAndValue(int8_t type_code,
std::shared_ptr<Scalar> field_value,
int field_index) {
return MakeUnionScalar(*union_type_, type_code, std::move(field_value), field_index);
}

std::shared_ptr<Scalar> SpecificNull(int field_index) {
Expand All @@ -1652,40 +1683,48 @@ class TestUnionScalar : public ::testing::Test {
}

void TestValidateErrors() {
// Type code doesn't exist
auto scalar = ScalarFromValue(0, alpha_);
UnionScalar* union_scalar = static_cast<UnionScalar*>(scalar.get());

// Invalid type code
union_scalar->type_code = 0;
AssertValidationFails(*union_scalar);
{
// Invalid type code
auto scalar = ScalarFromTypeCodeAndValue(0, alpha_, 0);
AssertValidationFails(*scalar);
}

union_scalar->is_valid = false;
AssertValidationFails(*union_scalar);
{
auto scalar = ScalarFromTypeCodeAndValue(0, alpha_, 0);
scalar->is_valid = false;
AssertValidationFails(*scalar);
}

union_scalar->type_code = -42;
union_scalar->is_valid = true;
AssertValidationFails(*union_scalar);
{
auto scalar = ScalarFromTypeCodeAndValue(-42, alpha_, 0);
AssertValidationFails(*scalar);
}

union_scalar->is_valid = false;
AssertValidationFails(*union_scalar);
{
auto scalar = ScalarFromTypeCodeAndValue(-42, alpha_, 0);
scalar->is_valid = false;
AssertValidationFails(*scalar);
}

// Type code doesn't correspond to child type
if (type_->id() == ::arrow::Type::DENSE_UNION) {
union_scalar->type_code = 42;
union_scalar->is_valid = true;
AssertValidationFails(*union_scalar);

scalar = ScalarFromValue(2, two_);
union_scalar = static_cast<UnionScalar*>(scalar.get());
union_scalar->type_code = 3;
AssertValidationFails(*union_scalar);
{
auto scalar = ScalarFromTypeCodeAndValue(42, alpha_, 0);
AssertValidationFails(*scalar);
}

{
auto scalar = ScalarFromTypeCodeAndValue(3, two_, 2);
AssertValidationFails(*scalar);
}
}

// underlying value has invalid UTF8
scalar = ScalarFromValue(0, std::make_shared<StringScalar>("\xff"));
ASSERT_OK(scalar->Validate());
ASSERT_RAISES(Invalid, scalar->ValidateFull());
{
// underlying value has invalid UTF8
auto scalar = ScalarFromValue(0, std::make_shared<StringScalar>("\xff"));
ASSERT_OK(scalar->Validate());
ASSERT_RAISES(Invalid, scalar->ValidateFull());
}
}

void TestEquals() {
Expand Down

0 comments on commit 1510e20

Please sign in to comment.