diff --git a/cpp/src/gandiva/jni/jni_common.cc b/cpp/src/gandiva/jni/jni_common.cc
index 72061c085e679..b4b9ffedceda7 100644
--- a/cpp/src/gandiva/jni/jni_common.cc
+++ b/cpp/src/gandiva/jni/jni_common.cc
@@ -313,6 +313,45 @@ NodePtr ProtoTypeToOrNode(const types::OrNode& node) {
   return TreeExprBuilder::MakeOr(children);
 }
 
+NodePtr ProtoTypeToInNode(const types::InNode& node) {
+  NodePtr field = ProtoTypeToFieldNode(node.field());
+
+  if (node.has_intvalues()) {
+    std::unordered_set<int32_t> int_values;
+    for (int i = 0; i < node.intvalues().intvalues_size(); i++) {
+      int_values.insert(node.intvalues().intvalues(i).value());
+    }
+    return TreeExprBuilder::MakeInExpressionInt32(field, int_values);
+  }
+
+  if (node.has_longvalues()) {
+    std::unordered_set<int64_t> long_values;
+    for (int i = 0; i < node.longvalues().longvalues_size(); i++) {
+      long_values.insert(node.longvalues().longvalues(i).value());
+    }
+    return TreeExprBuilder::MakeInExpressionInt64(field, long_values);
+  }
+
+  if (node.has_stringvalues()) {
+    std::unordered_set<std::string> stringvalues;
+    for (int i = 0; i < node.stringvalues().stringvalues_size(); i++) {
+      stringvalues.insert(node.stringvalues().stringvalues(i).value());
+    }
+    return TreeExprBuilder::MakeInExpressionString(field, stringvalues);
+  }
+
+  if (node.has_binaryvalues()) {
+    std::unordered_set<std::string> stringvalues;
+    for (int i = 0; i < node.binaryvalues().binaryvalues_size(); i++) {
+      stringvalues.insert(node.binaryvalues().binaryvalues(i).value());
+    }
+    return TreeExprBuilder::MakeInExpressionBinary(field, stringvalues);
+  }
+  // No int/long/string/binary constant set was present; other types are not supported yet.
+  std::cerr << "Unknown constant type for in expression.\n";
+  return nullptr;
+}
+
 NodePtr ProtoTypeToNullNode(const types::NullNode& node) {
   DataTypePtr data_type = ProtoTypeToDataType(node.type());
   if (data_type == nullptr) {
@@ -344,6 +383,10 @@ NodePtr ProtoTypeToNode(const types::TreeNode& node) {
     return ProtoTypeToOrNode(node.ornode());
   }
 
+  if (node.has_innode()) {
+    return ProtoTypeToInNode(node.innode());
+  }
+
   if (node.has_nullnode()) {
     return ProtoTypeToNullNode(node.nullnode());
   }
diff --git a/cpp/src/gandiva/proto/Types.proto b/cpp/src/gandiva/proto/Types.proto
index 9efa80f67604f..d264450cb0a19 100644
--- a/cpp/src/gandiva/proto/Types.proto
+++ b/cpp/src/gandiva/proto/Types.proto
@@ -173,6 +173,9 @@ message TreeNode {
   optional StringNode stringNode = 17;
   optional BinaryNode binaryNode = 18;
   optional DecimalNode decimalNode = 19;
+
+  // in expr
+  optional InNode inNode = 21;
 }
 
 message ExpressionRoot {
@@ -205,3 +208,27 @@ message FunctionSignature {
   optional ExtGandivaType returnType = 2;
   repeated ExtGandivaType paramTypes = 3;
 }
+
+message InNode {
+  optional FieldNode field = 1;
+  optional IntConstants intValues = 2;
+  optional LongConstants longValues = 3;
+  optional StringConstants stringValues = 4;
+  optional BinaryConstants binaryValues = 5;
+}
+
+message IntConstants {
+  repeated IntNode intValues = 1;
+}
+
+message LongConstants {
+  repeated LongNode longValues = 1;
+}
+
+message StringConstants {
+  repeated StringNode stringValues = 1;
+}
+
+message BinaryConstants {
+  repeated BinaryNode binaryValues = 1;
+}
diff --git a/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/InNode.java b/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/InNode.java
new file mode 100644
index 0000000000000..0420ffb9758ef
--- /dev/null
+++ b/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/InNode.java
@@ -0,0 +1,102 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.gandiva.expression;
+
+import java.nio.charset.Charset;
+import java.util.Set;
+
+import org.apache.arrow.gandiva.exceptions.GandivaException;
+import org.apache.arrow.gandiva.ipc.GandivaTypes;
+import org.apache.arrow.vector.types.pojo.Field;
+
+import com.google.protobuf.ByteString;
+
+/**
+ * Java representation of an IN expression node: a field plus one set of constant values.
+ */
+public class InNode implements TreeNode {
+  private static final Charset charset = Charset.forName("UTF-8");
+
+  private final Set<Integer> intValues;
+  private final Set<Long> longValues;
+  private final Set<String> stringValues;
+  private final Set<byte[]> binaryValues;
+  private final Field field;
+
+  private InNode(Set<Integer> values, Set<Long> longValues, Set<String> stringValues, Set<byte[]>
+          binaryValues, Field field) {
+    this.intValues = values;
+    this.longValues = longValues;
+    this.stringValues = stringValues;
+    this.binaryValues = binaryValues;
+    this.field = field;
+  }
+
+  public static InNode makeIntInExpr(Field field, Set<Integer> intValues) {
+    return new InNode(intValues, null, null, null, field);
+  }
+
+  public static InNode makeLongInExpr(Field field, Set<Long> longValues) {
+    return new InNode(null, longValues, null, null, field);
+  }
+
+  public static InNode makeStringInExpr(Field field, Set<String> stringValues) {
+    return new InNode(null, null, stringValues, null, field);
+  }
+
+  public static InNode makeBinaryInExpr(Field field, Set<byte[]> binaryValues) {
+    return new InNode(null, null, null, binaryValues, field);
+  }
+
+  @Override
+  public GandivaTypes.TreeNode toProtobuf() throws GandivaException {
+    GandivaTypes.InNode.Builder inNode = GandivaTypes.InNode.newBuilder();
+
+    GandivaTypes.FieldNode.Builder fieldNode = GandivaTypes.FieldNode.newBuilder();
+    fieldNode.setField(ArrowTypeHelper.arrowFieldToProtobuf(field));
+    inNode.setField(fieldNode);
+
+    if (intValues != null) {
+      GandivaTypes.IntConstants.Builder intConstants = GandivaTypes.IntConstants.newBuilder();
+      intValues.stream().forEach(val -> intConstants.addIntValues(GandivaTypes.IntNode.newBuilder()
+          .setValue(val).build()));
+      inNode.setIntValues(intConstants.build());
+    } else if (longValues != null) {
+      GandivaTypes.LongConstants.Builder longConstants = GandivaTypes.LongConstants.newBuilder();
+      longValues.stream().forEach(val -> longConstants.addLongValues(GandivaTypes.LongNode.newBuilder()
+          .setValue(val).build()));
+      inNode.setLongValues(longConstants.build());
+    } else if (stringValues != null) {
+      GandivaTypes.StringConstants.Builder stringConstants = GandivaTypes.StringConstants
+          .newBuilder();
+      stringValues.stream().forEach(val -> stringConstants.addStringValues(GandivaTypes.StringNode
+          .newBuilder().setValue(ByteString.copyFrom(val.getBytes(charset))).build()));
+      inNode.setStringValues(stringConstants.build());
+    } else if (binaryValues != null) {
+      GandivaTypes.BinaryConstants.Builder binaryConstants = GandivaTypes.BinaryConstants
+          .newBuilder();
+      binaryValues.stream().forEach(val -> binaryConstants.addBinaryValues(GandivaTypes.BinaryNode
+          .newBuilder().setValue(ByteString.copyFrom(val)).build()));
+      inNode.setBinaryValues(binaryConstants.build());
+    }
+    GandivaTypes.TreeNode.Builder builder = GandivaTypes.TreeNode.newBuilder();
+    builder.setInNode(inNode.build());
+    return builder.build();
+
+  }
+}
diff --git a/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/TreeBuilder.java b/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/TreeBuilder.java
index a220c547e44a6..78e662a4d35e0 100644
--- a/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/TreeBuilder.java
+++ b/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/TreeBuilder.java
@@ -19,6 +19,7 @@
 
 import java.util.ArrayList;
 import java.util.List;
+import java.util.Set;
 
 import org.apache.arrow.vector.types.pojo.ArrowType;
 import org.apache.arrow.vector.types.pojo.Field;
@@ -185,4 +186,24 @@ public static Condition makeCondition(String function,
     TreeNode root = makeFunction(function, children, new ArrowType.Bool());
     return makeCondition(root);
   }
+
+  public static TreeNode makeInExpressionInt32(Field resultField,
+                                               Set<Integer> intValues) {
+    return InNode.makeIntInExpr(resultField, intValues);
+  }
+
+  public static TreeNode makeInExpressionBigInt(Field resultField,
+                                                Set<Long> longValues) {
+    return InNode.makeLongInExpr(resultField, longValues);
+  }
+
+  public static TreeNode makeInExpressionString(Field resultField,
+                                                Set<String> stringValues) {
+    return InNode.makeStringInExpr(resultField, stringValues);
+  }
+
+  public static TreeNode makeInExpressionBinary(Field resultField,
+                                                Set<byte[]> binaryValues) {
+    return InNode.makeBinaryInExpr(resultField, binaryValues);
+  }
 }
diff --git a/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/ProjectorTest.java b/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/ProjectorTest.java
index e5f9fe6ab92b8..62a12710cc757 100644
--- a/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/ProjectorTest.java
+++ b/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/ProjectorTest.java
@@ -51,6 +51,7 @@
 import org.junit.Test;
 
 import com.google.common.collect.Lists;
+import com.google.common.collect.Sets;
 
 import io.netty.buffer.ArrowBuf;
 
@@ -1047,6 +1048,96 @@ public void testEquals() throws GandivaException, Exception {
     eval.close();
   }
 
+  @Test
+  public void testInExpr() throws GandivaException, Exception {
+    Field c1 = Field.nullable("c1", int32);
+
+    TreeNode inExpr =
+            TreeBuilder.makeInExpressionInt32(c1, Sets.newHashSet(1,2,3,4,5,15,16));
+    ExpressionTree expr = TreeBuilder.makeExpression(inExpr, Field.nullable("result", boolType));
+    Schema schema = new Schema(Lists.newArrayList(c1));
+    Projector eval = Projector.make(schema, Lists.newArrayList(expr));
+
+    int numRows = 16;
+    byte[] validity = new byte[]{(byte) 255, 0};
+    int[] c1Values = new int[]{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
+
+    ArrowBuf c1Validity = buf(validity);
+    ArrowBuf c1Data = intBuf(c1Values);
+    ArrowBuf c2Validity = buf(validity);
+
+    ArrowFieldNode fieldNode = new ArrowFieldNode(numRows, 0);
+    ArrowRecordBatch batch =
+            new ArrowRecordBatch(
+                    numRows,
+                    Lists.newArrayList(fieldNode, fieldNode),
+                    Lists.newArrayList(c1Validity, c1Data, c2Validity));
+
+    BitVector bitVector = new BitVector(EMPTY_SCHEMA_PATH, allocator);
+    bitVector.allocateNew(numRows);
+
+    List output = new ArrayList();
+    output.add(bitVector);
+    eval.evaluate(batch, output);
+
+    for (int i = 0; i < 5; i++) {
+      assertTrue(bitVector.getObject(i).booleanValue());
+    }
+    for (int i = 5; i < 16; i++) {
+      assertFalse(bitVector.getObject(i).booleanValue());
+    }
+
+    releaseRecordBatch(batch);
+    releaseValueVectors(output);
+    eval.close();
+  }
+
+  @Test
+  public void testInExprStrings() throws GandivaException, Exception {
+    Field c1 = Field.nullable("c1", new ArrowType.Utf8());
+
+    TreeNode inExpr =
+            TreeBuilder.makeInExpressionString(c1, Sets.newHashSet("one", "two", "three", "four"));
+    ExpressionTree expr = TreeBuilder.makeExpression(inExpr, Field.nullable("result", boolType));
+    Schema schema = new Schema(Lists.newArrayList(c1));
+    Projector eval = Projector.make(schema, Lists.newArrayList(expr));
+
+    int numRows = 16;
+    byte[] validity = new byte[]{(byte) 255, 0};
+    String[] c1Values = new String[]{"one", "two", "three", "four", "five", "six", "seven",
+        "eight", "nine", "ten", "eleven", "twelve", "thirteen", "fourteen", "fifteen",
+        "sixteen"};
+
+    ArrowBuf c1Validity = buf(validity);
+    List<ArrowBuf> dataBufsX = stringBufs(c1Values);
+    ArrowBuf c2Validity = buf(validity);
+
+    ArrowFieldNode fieldNode = new ArrowFieldNode(numRows, 0);
+    ArrowRecordBatch batch =
+            new ArrowRecordBatch(
+                    numRows,
+                    Lists.newArrayList(fieldNode, fieldNode),
+                    Lists.newArrayList(c1Validity, dataBufsX.get(0), dataBufsX.get(1), c2Validity));
+
+    BitVector bitVector = new BitVector(EMPTY_SCHEMA_PATH, allocator);
+    bitVector.allocateNew(numRows);
+
+    List output = new ArrayList();
+    output.add(bitVector);
+    eval.evaluate(batch, output);
+
+    for (int i = 0; i < 4; i++) {
+      assertTrue(bitVector.getObject(i).booleanValue());
+    }
+    for (int i = 4; i < 16; i++) {
+      assertFalse(bitVector.getObject(i).booleanValue());
+    }
+
+    releaseRecordBatch(batch);
+    releaseValueVectors(output);
+    eval.close();
+  }
+
   @Test
   public void testSmallOutputVectors() throws GandivaException, Exception {
     Field a = Field.nullable("a", int32);