Skip to content

Commit

Permalink
[oap-native-sql] Gandiva: Add shift_left, shift_right (apache#85)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhztheplayer authored Aug 7, 2020
1 parent fba5422 commit 9f71a5e
Show file tree
Hide file tree
Showing 3 changed files with 115 additions and 0 deletions.
6 changes: 6 additions & 0 deletions cpp/src/gandiva/function_registry_arithmetic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,12 @@ std::vector<NativeFunction> GetArithmeticFunctionRegistry() {
BINARY_SYMMETRIC_UNSAFE_NULL_IF_NULL(div, {}, float32),
BINARY_SYMMETRIC_UNSAFE_NULL_IF_NULL(div, {}, float64),

// bitwise functions
BINARY_GENERIC_SAFE_NULL_IF_NULL(shift_left, {}, int32, int32, int32),
BINARY_GENERIC_SAFE_NULL_IF_NULL(shift_left, {}, int64, int32, int64),
BINARY_GENERIC_SAFE_NULL_IF_NULL(shift_right, {}, int32, int32, int32),
BINARY_GENERIC_SAFE_NULL_IF_NULL(shift_right, {}, int64, int32, int64),

// compare functions
BINARY_RELATIONAL_BOOL_FN(equal, ({"eq", "same"})),
BINARY_RELATIONAL_BOOL_FN(not_equal, {}),
Expand Down
22 changes: 22 additions & 0 deletions cpp/src/gandiva/precompiled/arithmetic_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,28 @@ DIV_FLOAT(float64)

#undef DIV_FLOAT

#define SHIFT_LEFT_INT(LTYPE, RTYPE) \
FORCE_INLINE \
gdv_##LTYPE shift_left_##LTYPE##_##RTYPE(gdv_##LTYPE in1, gdv_##RTYPE in2) { \
return static_cast<gdv_##LTYPE>(in1 << in2); \
}

SHIFT_LEFT_INT(int32, int32)
SHIFT_LEFT_INT(int64, int32)

#undef SHIFT_RIGHT_INT

#define SHIFT_RIGHT_INT(LTYPE, RTYPE) \
FORCE_INLINE \
gdv_##LTYPE shift_right_##LTYPE##_##RTYPE(gdv_##LTYPE in1, gdv_##RTYPE in2) { \
return static_cast<gdv_##LTYPE>(in1 >> in2); \
}

SHIFT_RIGHT_INT(int32, int32)
SHIFT_RIGHT_INT(int64, int32)

#undef SHIFT_RIGHT_INT

#undef DATE_FUNCTION
#undef DATE_TYPES
#undef NUMERIC_BOOL_DATE_TYPES
Expand Down
87 changes: 87 additions & 0 deletions cpp/src/gandiva/tests/projector_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,93 @@ TEST_F(TestProjector, TestProjectCacheDecimalCast) {
EXPECT_EQ(projector0.get(), projector2.get());
}

TEST_F(TestProjector, TestShiftRight) {
// schema for input fields
auto field0 = field("f0", int32());
auto field1 = field("f1", int32());
auto schema = arrow::schema({field0, field1});

// output fields
auto field_shift_right = field("shift_right", int32());

// Build expression
auto shift_right_expr = TreeExprBuilder::MakeExpression("shift_right", {field0, field1},
field_shift_right);

std::shared_ptr<Projector> projector;

auto status =
Projector::Make(schema, {shift_right_expr}, TestConfiguration(), &projector);
EXPECT_TRUE(status.ok());

if (!status.ok()) {
std::cout << status.message() << std::endl;
}

// Create a row-batch with some sample data
int num_records = 4;
auto array0 = MakeArrowArrayInt32({4, 8, 16, 32}, {true, true, true, true});
auto array1 = MakeArrowArrayInt32({1, 2, 3, 4}, {true, true, true, true});

// expected output
auto exp_shift_right = MakeArrowArrayInt32({2, 2, 2, 2}, {true, true, true, true});

// prepare input record batch
auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0, array1});

// Evaluate expression
arrow::ArrayVector outputs;
status = projector->Evaluate(*in_batch, pool_, &outputs);
EXPECT_TRUE(status.ok());

// Validate results
EXPECT_ARROW_ARRAY_EQUALS(exp_shift_right, outputs.at(0));
}


TEST_F(TestProjector, TestShiftLeft) {
// schema for input fields
auto field0 = field("f0", int32());
auto field1 = field("f1", int32());
auto schema = arrow::schema({field0, field1});

// output fields
auto field_shift_left = field("shift_left", int32());

// Build expression
auto shift_left_expr = TreeExprBuilder::MakeExpression("shift_left", {field0, field1},
field_shift_left);

std::shared_ptr<Projector> projector;

auto status =
Projector::Make(schema, {shift_left_expr}, TestConfiguration(), &projector);
EXPECT_TRUE(status.ok());

if (!status.ok()) {
std::cout << status.message() << std::endl;
}

// Create a row-batch with some sample data
int num_records = 4;
auto array0 = MakeArrowArrayInt32({4, 8, 16, 32}, {true, true, true, true});
auto array1 = MakeArrowArrayInt32({4, 3, 2, 1}, {true, true, true, true});

// expected output
auto exp_shift_left = MakeArrowArrayInt32({64, 64, 64, 64}, {true, true, true, true});

// prepare input record batch
auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0, array1});

// Evaluate expression
arrow::ArrayVector outputs;
status = projector->Evaluate(*in_batch, pool_, &outputs);
EXPECT_TRUE(status.ok());

// Validate results
EXPECT_ARROW_ARRAY_EQUALS(exp_shift_left, outputs.at(0));
}

TEST_F(TestProjector, TestIntSumSub) {
// schema for input fields
auto field0 = field("f0", int32());
Expand Down

0 comments on commit 9f71a5e

Please sign in to comment.