Skip to content

Commit

Permalink
Merge branch 'apache:main' into employee-bootcamp
Browse files Browse the repository at this point in the history
  • Loading branch information
ArgusLi authored Sep 10, 2024
2 parents 8bd2ef2 + 44b72d5 commit e82a70e
Show file tree
Hide file tree
Showing 25 changed files with 343 additions and 67 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/r.yml
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,7 @@ jobs:
echo "$HOME/.local/bin" >> $GITHUB_PATH
- run: mkdir r/windows
- name: Download artifacts
uses: actions/download-artifact@v4.1.7
uses: actions/download-artifact@v4.1.8
with:
name: libarrow-rtools40-ucrt64.zip
path: r/windows
Expand Down
73 changes: 73 additions & 0 deletions cpp/src/arrow/compute/expression.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include <unordered_set>

#include "arrow/chunked_array.h"
#include "arrow/compute/api_aggregate.h"
#include "arrow/compute/api_vector.h"
#include "arrow/compute/exec_internal.h"
#include "arrow/compute/expression_internal.h"
Expand Down Expand Up @@ -1242,6 +1243,72 @@ struct Inequality {
/*insert_implicit_casts=*/false, &exec_context);
}

/// Simplify an `is_in` call against an inequality guarantee.
///
/// We avoid the complexity of fully simplifying EQUAL comparisons to true
/// literals (e.g., 'x is_in [1, 2, 3]' given the guarantee 'x = 2') due to
/// potential complications with null matching behavior. This is ok for the
/// predicate pushdown use case because the overall aim is to simplify to an
/// unsatisfiable expression.
///
/// \pre `is_in_call` is a call to the `is_in` function
/// \return a simplified expression, or nullopt if no simplification occurred
static Result<std::optional<Expression>> SimplifyIsIn(
const Inequality& guarantee, const Expression::Call* is_in_call) {
DCHECK_EQ(is_in_call->function_name, "is_in");

auto options = checked_pointer_cast<SetLookupOptions>(is_in_call->options);

const auto& lhs = Comparison::StripOrderPreservingCasts(is_in_call->arguments[0]);
if (!lhs.field_ref()) return std::nullopt;
if (*lhs.field_ref() != guarantee.target) return std::nullopt;

FilterOptions::NullSelectionBehavior null_selection;
switch (options->null_matching_behavior) {
case SetLookupOptions::MATCH:
null_selection =
guarantee.nullable ? FilterOptions::EMIT_NULL : FilterOptions::DROP;
break;
case SetLookupOptions::SKIP:
null_selection = FilterOptions::DROP;
break;
case SetLookupOptions::EMIT_NULL:
if (guarantee.nullable) return std::nullopt;
null_selection = FilterOptions::DROP;
break;
case SetLookupOptions::INCONCLUSIVE:
if (guarantee.nullable) return std::nullopt;
ARROW_ASSIGN_OR_RAISE(Datum is_null, IsNull(options->value_set));
ARROW_ASSIGN_OR_RAISE(Datum any_null, Any(is_null));
if (any_null.scalar_as<BooleanScalar>().value) return std::nullopt;
null_selection = FilterOptions::DROP;
break;
}

std::string func_name = Comparison::GetName(guarantee.cmp);
DCHECK_NE(func_name, "na");
std::vector<Datum> args{options->value_set, guarantee.bound};
ARROW_ASSIGN_OR_RAISE(Datum filter_mask, CallFunction(func_name, args));
FilterOptions filter_options(null_selection);
ARROW_ASSIGN_OR_RAISE(Datum simplified_value_set,
Filter(options->value_set, filter_mask, filter_options));

if (simplified_value_set.length() == 0) return literal(false);
if (simplified_value_set.length() == options->value_set.length()) return std::nullopt;

ExecContext exec_context;
Expression::Call simplified_call;
simplified_call.function_name = "is_in";
simplified_call.arguments = is_in_call->arguments;
simplified_call.options = std::make_shared<SetLookupOptions>(
simplified_value_set, options->null_matching_behavior);
ARROW_ASSIGN_OR_RAISE(
Expression simplified_expr,
BindNonRecursive(std::move(simplified_call),
/*insert_implicit_casts=*/false, &exec_context));
return simplified_expr;
}

/// \brief Simplify the given expression given this inequality as a guarantee.
Result<Expression> Simplify(Expression expr) {
const auto& guarantee = *this;
Expand All @@ -1258,6 +1325,12 @@ struct Inequality {
return call->function_name == "is_valid" ? literal(true) : literal(false);
}

if (call->function_name == "is_in") {
ARROW_ASSIGN_OR_RAISE(std::optional<Expression> result,
SimplifyIsIn(guarantee, call));
return result.value_or(expr);
}

auto cmp = Comparison::Get(expr);
if (!cmp) return expr;

Expand Down
173 changes: 173 additions & 0 deletions cpp/src/arrow/compute/expression_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include <gmock/gmock.h>
#include <gtest/gtest.h>

#include "arrow/array/builder_primitive.h"
#include "arrow/compute/expression_internal.h"
#include "arrow/compute/function_internal.h"
#include "arrow/compute/registry.h"
Expand Down Expand Up @@ -1616,6 +1617,144 @@ TEST(Expression, SimplifyWithComparisonAndNullableCaveat) {
true_unless_null(field_ref("i32")))); // not satisfiable, will drop row group
}

TEST(Expression, SimplifyIsIn) {
auto is_in = [](Expression field, std::shared_ptr<DataType> value_set_type,
std::string json_array,
SetLookupOptions::NullMatchingBehavior null_matching_behavior) {
SetLookupOptions options{ArrayFromJSON(value_set_type, json_array),
null_matching_behavior};
return call("is_in", {field}, options);
};

for (SetLookupOptions::NullMatchingBehavior null_matching : {
SetLookupOptions::MATCH,
SetLookupOptions::SKIP,
SetLookupOptions::EMIT_NULL,
SetLookupOptions::INCONCLUSIVE,
}) {
Simplify{is_in(field_ref("i32"), int32(), "[]", null_matching)}
.WithGuarantee(greater(field_ref("i32"), literal(2)))
.Expect(false);

Simplify{is_in(field_ref("i32"), int32(), "[1,3,5,7,9]", null_matching)}
.WithGuarantee(equal(field_ref("i32"), literal(6)))
.Expect(false);

Simplify{is_in(field_ref("i32"), int32(), "[1,3,5,7,9]", null_matching)}
.WithGuarantee(greater(field_ref("i32"), literal(3)))
.Expect(is_in(field_ref("i32"), int32(), "[5,7,9]", null_matching));

Simplify{is_in(field_ref("i32"), int32(), "[1,3,5,7,9]", null_matching)}
.WithGuarantee(greater(field_ref("i32"), literal(9)))
.Expect(false);

Simplify{is_in(field_ref("i32"), int32(), "[1,3,5,7,9]", null_matching)}
.WithGuarantee(less_equal(field_ref("i32"), literal(0)))
.Expect(false);

Simplify{is_in(field_ref("i32"), int32(), "[1,3,5,7,9]", null_matching)}
.WithGuarantee(greater(field_ref("i32"), literal(0)))
.ExpectUnchanged();

Simplify{is_in(field_ref("i32"), int32(), "[1,3,5,7,9]", null_matching)}
.WithGuarantee(less_equal(field_ref("i32"), literal(9)))
.ExpectUnchanged();

Simplify{is_in(field_ref("i32"), int32(), "[1,3,5,7,9]", null_matching)}
.WithGuarantee(and_(less_equal(field_ref("i32"), literal(7)),
greater(field_ref("i32"), literal(4))))
.Expect(is_in(field_ref("i32"), int32(), "[5,7]", null_matching));

Simplify{is_in(field_ref("u32"), int8(), "[1,3,5,7,9]", null_matching)}
.WithGuarantee(greater(field_ref("u32"), literal(3)))
.Expect(is_in(field_ref("u32"), int8(), "[5,7,9]", null_matching));

Simplify{is_in(field_ref("u32"), int64(), "[1,3,5,7,9]", null_matching)}
.WithGuarantee(greater(field_ref("u32"), literal(3)))
.Expect(is_in(field_ref("u32"), int64(), "[5,7,9]", null_matching));
}

Simplify{
is_in(field_ref("i32"), int32(), "[1,2,3]", SetLookupOptions::MATCH),
}
.WithGuarantee(
or_(greater(field_ref("i32"), literal(2)), is_null(field_ref("i32"))))
.Expect(is_in(field_ref("i32"), int32(), "[3]", SetLookupOptions::MATCH));

Simplify{
is_in(field_ref("i32"), int32(), "[1,2,3,null]", SetLookupOptions::MATCH),
}
.WithGuarantee(greater(field_ref("i32"), literal(2)))
.Expect(is_in(field_ref("i32"), int32(), "[3]", SetLookupOptions::MATCH));

Simplify{
is_in(field_ref("i32"), int32(), "[1,2,3,null]", SetLookupOptions::MATCH),
}
.WithGuarantee(
or_(greater(field_ref("i32"), literal(2)), is_null(field_ref("i32"))))
.Expect(is_in(field_ref("i32"), int32(), "[3,null]", SetLookupOptions::MATCH));

Simplify{
is_in(field_ref("i32"), int32(), "[1,2,3]", SetLookupOptions::SKIP),
}
.WithGuarantee(
or_(greater(field_ref("i32"), literal(2)), is_null(field_ref("i32"))))
.Expect(is_in(field_ref("i32"), int32(), "[3]", SetLookupOptions::SKIP));

Simplify{
is_in(field_ref("i32"), int32(), "[1,2,3,null]", SetLookupOptions::SKIP),
}
.WithGuarantee(greater(field_ref("i32"), literal(2)))
.Expect(is_in(field_ref("i32"), int32(), "[3]", SetLookupOptions::SKIP));

Simplify{
is_in(field_ref("i32"), int32(), "[1,2,3,null]", SetLookupOptions::SKIP),
}
.WithGuarantee(
or_(greater(field_ref("i32"), literal(2)), is_null(field_ref("i32"))))
.Expect(is_in(field_ref("i32"), int32(), "[3]", SetLookupOptions::SKIP));

Simplify{
is_in(field_ref("i32"), int32(), "[1,2,3]", SetLookupOptions::EMIT_NULL),
}
.WithGuarantee(
or_(greater(field_ref("i32"), literal(2)), is_null(field_ref("i32"))))
.ExpectUnchanged();

Simplify{
is_in(field_ref("i32"), int32(), "[1,2,3,null]", SetLookupOptions::EMIT_NULL),
}
.WithGuarantee(greater(field_ref("i32"), literal(2)))
.Expect(is_in(field_ref("i32"), int32(), "[3]", SetLookupOptions::EMIT_NULL));

Simplify{
is_in(field_ref("i32"), int32(), "[1,2,3,null]", SetLookupOptions::EMIT_NULL),
}
.WithGuarantee(
or_(greater(field_ref("i32"), literal(2)), is_null(field_ref("i32"))))
.ExpectUnchanged();

Simplify{
is_in(field_ref("i32"), int32(), "[1,2,3]", SetLookupOptions::INCONCLUSIVE),
}
.WithGuarantee(
or_(greater(field_ref("i32"), literal(2)), is_null(field_ref("i32"))))
.ExpectUnchanged();

Simplify{
is_in(field_ref("i32"), int32(), "[1,2,3,null]", SetLookupOptions::INCONCLUSIVE),
}
.WithGuarantee(greater(field_ref("i32"), literal(2)))
.ExpectUnchanged();

Simplify{
is_in(field_ref("i32"), int32(), "[1,2,3,null]", SetLookupOptions::INCONCLUSIVE),
}
.WithGuarantee(
or_(greater(field_ref("i32"), literal(2)), is_null(field_ref("i32"))))
.ExpectUnchanged();
}

TEST(Expression, SimplifyThenExecute) {
auto filter =
or_({equal(field_ref("f32"), literal(0)),
Expand Down Expand Up @@ -1643,6 +1782,40 @@ TEST(Expression, SimplifyThenExecute) {
AssertDatumsEqual(evaluated, simplified_evaluated, /*verbose=*/true);
}

TEST(Expression, SimplifyIsInThenExecute) {
auto input = RecordBatchFromJSON(kBoringSchema, R"([
{"i64": 2, "i32": 5},
{"i64": 5, "i32": 6},
{"i64": 3, "i32": 6},
{"i64": 3, "i32": 5},
{"i64": 4, "i32": 5},
{"i64": 2, "i32": 7},
{"i64": 5, "i32": 5}
])");

std::vector<Expression> guarantees{greater(field_ref("i64"), literal(1)),
greater_equal(field_ref("i32"), literal(5)),
less_equal(field_ref("i64"), literal(5))};

for (const Expression& guarantee : guarantees) {
auto filter =
call("is_in", {guarantee.call()->arguments[0]},
compute::SetLookupOptions{ArrayFromJSON(int32(), "[1,2,3]"), true});
ASSERT_OK_AND_ASSIGN(filter, filter.Bind(*kBoringSchema));
ASSERT_OK_AND_ASSIGN(auto simplified, SimplifyWithGuarantee(filter, guarantee));

Datum evaluated, simplified_evaluated;
ExpectExecute(filter, input, &evaluated);
ExpectExecute(simplified, input, &simplified_evaluated);
if (simplified_evaluated.is_scalar()) {
ASSERT_OK_AND_ASSIGN(
simplified_evaluated,
MakeArrayFromScalar(*simplified_evaluated.scalar(), evaluated.length()));
}
AssertDatumsEqual(evaluated, simplified_evaluated, /*verbose=*/true);
}
}

TEST(Expression, Filter) {
auto ExpectFilter = [](Expression filter, std::string batch_json) {
ASSERT_OK_AND_ASSIGN(auto s, kBoringSchema->AddField(0, field("in", boolean())));
Expand Down
12 changes: 6 additions & 6 deletions cpp/src/arrow/ipc/reader.cc
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,7 @@ class ArrayLoader {
RETURN_NOT_OK(GetBuffer(buffer_index_++, &out_->buffers[1]));
} else {
buffer_index_++;
out_->buffers[1].reset(new Buffer(nullptr, 0));
out_->buffers[1] = std::make_shared<Buffer>(nullptr, 0);
}
return Status::OK();
}
Expand Down Expand Up @@ -644,11 +644,11 @@ Result<std::shared_ptr<RecordBatch>> LoadRecordBatch(
const flatbuf::RecordBatch* metadata, const std::shared_ptr<Schema>& schema,
const std::vector<bool>& inclusion_mask, const IpcReadContext& context,
io::RandomAccessFile* file) {
if (inclusion_mask.size() > 0) {
return LoadRecordBatchSubset(metadata, schema, &inclusion_mask, context, file);
} else {
if (inclusion_mask.empty()) {
return LoadRecordBatchSubset(metadata, schema, /*inclusion_mask=*/nullptr, context,
file);
} else {
return LoadRecordBatchSubset(metadata, schema, &inclusion_mask, context, file);
}
}

Expand Down Expand Up @@ -1447,7 +1447,7 @@ class RecordBatchFileReaderImpl : public RecordBatchFileReader {
// Prebuffering's read patterns are also slightly worse than the alternative
// when doing whole-file reads because the logic is not in place to recognize
// we can just read the entire file up-front
if (options_.included_fields.size() != 0 &&
if (!options_.included_fields.empty() &&
options_.included_fields.size() != schema_->fields().size() &&
!file_->supports_zero_copy()) {
RETURN_NOT_OK(state->PreBufferMetadata({}));
Expand Down Expand Up @@ -1907,7 +1907,7 @@ Result<std::shared_ptr<RecordBatchFileReader>> RecordBatchFileReader::Open(
Future<std::shared_ptr<RecordBatchFileReader>> RecordBatchFileReader::OpenAsync(
const std::shared_ptr<io::RandomAccessFile>& file, const IpcReadOptions& options) {
ARROW_ASSIGN_OR_RAISE(int64_t footer_offset, file->GetSize());
return OpenAsync(std::move(file), footer_offset, options);
return OpenAsync(file, footer_offset, options);
}

Future<std::shared_ptr<RecordBatchFileReader>> RecordBatchFileReader::OpenAsync(
Expand Down
Loading

0 comments on commit e82a70e

Please sign in to comment.