Skip to content

Commit

Permalink
In Expr bindings. (apache#10)
Browse files Browse the repository at this point in the history
* Complete In Expression Support.

* Fixed lint issues.

* Address Review comments.

* Fix review comments.

e Please enter the commit message for your changes. Lines starting

* Fix lint issues.
  • Loading branch information
praveenbingo authored Jun 23, 2019
1 parent 4bef481 commit f14f45e
Show file tree
Hide file tree
Showing 5 changed files with 284 additions and 0 deletions.
43 changes: 43 additions & 0 deletions cpp/src/gandiva/jni/jni_common.cc
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,45 @@ NodePtr ProtoTypeToOrNode(const types::OrNode& node) {
return TreeExprBuilder::MakeOr(children);
}

NodePtr ProtoTypeToInNode(const types::InNode& node) {
NodePtr field = ProtoTypeToFieldNode(node.field());

if (node.has_intvalues()) {
std::unordered_set<int32_t> int_values;
for (int i = 0; i < node.intvalues().intvalues_size(); i++) {
int_values.insert(node.intvalues().intvalues(i).value());
}
return TreeExprBuilder::MakeInExpressionInt32(field, int_values);
}

if (node.has_longvalues()) {
std::unordered_set<int64_t> long_values;
for (int i = 0; i < node.longvalues().longvalues_size(); i++) {
long_values.insert(node.longvalues().longvalues(i).value());
}
return TreeExprBuilder::MakeInExpressionInt64(field, long_values);
}

if (node.has_stringvalues()) {
std::unordered_set<std::string> stringvalues;
for (int i = 0; i < node.stringvalues().stringvalues_size(); i++) {
stringvalues.insert(node.stringvalues().stringvalues(i).value());
}
return TreeExprBuilder::MakeInExpressionString(field, stringvalues);
}

if (node.has_binaryvalues()) {
std::unordered_set<std::string> stringvalues;
for (int i = 0; i < node.binaryvalues().binaryvalues_size(); i++) {
stringvalues.insert(node.binaryvalues().binaryvalues(i).value());
}
return TreeExprBuilder::MakeInExpressionBinary(field, stringvalues);
}
// not supported yet.
std::cerr << "Unknown constant type for in expression.\n";
return nullptr;
}

NodePtr ProtoTypeToNullNode(const types::NullNode& node) {
DataTypePtr data_type = ProtoTypeToDataType(node.type());
if (data_type == nullptr) {
Expand Down Expand Up @@ -344,6 +383,10 @@ NodePtr ProtoTypeToNode(const types::TreeNode& node) {
return ProtoTypeToOrNode(node.ornode());
}

if (node.has_innode()) {
return ProtoTypeToInNode(node.innode());
}

if (node.has_nullnode()) {
return ProtoTypeToNullNode(node.nullnode());
}
Expand Down
27 changes: 27 additions & 0 deletions cpp/src/gandiva/proto/Types.proto
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,9 @@ message TreeNode {
optional StringNode stringNode = 17;
optional BinaryNode binaryNode = 18;
optional DecimalNode decimalNode = 19;

// in expr
optional InNode inNode = 21;
}

message ExpressionRoot {
Expand Down Expand Up @@ -205,3 +208,27 @@ message FunctionSignature {
optional ExtGandivaType returnType = 2;
repeated ExtGandivaType paramTypes = 3;
}

message InNode {
optional FieldNode field = 1;
optional IntConstants intValues = 2;
optional LongConstants longValues = 3;
optional StringConstants stringValues = 4;
optional BinaryConstants binaryValues = 5;
}

message IntConstants {
repeated IntNode intValues = 1;
}

message LongConstants {
repeated LongNode longValues = 1;
}

message StringConstants {
repeated StringNode stringValues = 1;
}

message BinaryConstants {
repeated BinaryNode binaryValues = 1;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.arrow.gandiva.expression;

import java.nio.charset.Charset;
import java.util.Set;

import org.apache.arrow.gandiva.exceptions.GandivaException;
import org.apache.arrow.gandiva.ipc.GandivaTypes;
import org.apache.arrow.vector.types.pojo.Field;

import com.google.protobuf.ByteString;

/**
* In Node representation in java.
*/
public class InNode implements TreeNode {
private static final Charset charset = Charset.forName("UTF-8");

private final Set<Integer> intValues;
private final Set<Long> longValues;
private final Set<String> stringValues;
private final Set<byte[]> binaryValues;
private final Field field;

private InNode(Set<Integer> values, Set<Long> longValues, Set<String> stringValues, Set<byte[]>
binaryValues, Field field) {
this.intValues = values;
this.longValues = longValues;
this.stringValues = stringValues;
this.binaryValues = binaryValues;
this.field = field;
}

public static InNode makeIntInExpr(Field field, Set<Integer> intValues) {
return new InNode(intValues, null, null, null ,field);
}

public static InNode makeLongInExpr(Field field, Set<Long> longValues) {
return new InNode(null, longValues, null, null ,field);
}

public static InNode makeStringInExpr(Field field, Set<String> stringValues) {
return new InNode(null, null, stringValues, null ,field);
}

public static InNode makeBinaryInExpr(Field field, Set<byte[]> binaryValues) {
return new InNode(null, null, null, binaryValues ,field);
}

@Override
public GandivaTypes.TreeNode toProtobuf() throws GandivaException {
GandivaTypes.InNode.Builder inNode = GandivaTypes.InNode.newBuilder();

GandivaTypes.FieldNode.Builder fieldNode = GandivaTypes.FieldNode.newBuilder();
fieldNode.setField(ArrowTypeHelper.arrowFieldToProtobuf(field));
inNode.setField(fieldNode);

if (intValues != null) {
GandivaTypes.IntConstants.Builder intConstants = GandivaTypes.IntConstants.newBuilder();
intValues.stream().forEach(val -> intConstants.addIntValues(GandivaTypes.IntNode.newBuilder()
.setValue(val).build()));
inNode.setIntValues(intConstants.build());
} else if (longValues != null) {
GandivaTypes.LongConstants.Builder longConstants = GandivaTypes.LongConstants.newBuilder();
longValues.stream().forEach(val -> longConstants.addLongValues(GandivaTypes.LongNode.newBuilder()
.setValue(val).build()));
inNode.setLongValues(longConstants.build());
} else if (stringValues != null) {
GandivaTypes.StringConstants.Builder stringConstants = GandivaTypes.StringConstants
.newBuilder();
stringValues.stream().forEach(val -> stringConstants.addStringValues(GandivaTypes.StringNode
.newBuilder().setValue(ByteString.copyFrom(val.getBytes(charset))).build()));
inNode.setStringValues(stringConstants.build());
} else if (binaryValues != null) {
GandivaTypes.BinaryConstants.Builder binaryConstants = GandivaTypes.BinaryConstants
.newBuilder();
binaryValues.stream().forEach(val -> binaryConstants.addBinaryValues(GandivaTypes.BinaryNode
.newBuilder().setValue(ByteString.copyFrom(val)).build()));
inNode.setBinaryValues(binaryConstants.build());
}
GandivaTypes.TreeNode.Builder builder = GandivaTypes.TreeNode.newBuilder();
builder.setInNode(inNode.build());
return builder.build();

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import java.util.ArrayList;
import java.util.List;
import java.util.Set;

import org.apache.arrow.vector.types.pojo.ArrowType;
import org.apache.arrow.vector.types.pojo.Field;
Expand Down Expand Up @@ -185,4 +186,24 @@ public static Condition makeCondition(String function,
TreeNode root = makeFunction(function, children, new ArrowType.Bool());
return makeCondition(root);
}

public static TreeNode makeInExpressionInt32(Field resultField,
Set<Integer> intValues) {
return InNode.makeIntInExpr(resultField, intValues);
}

public static TreeNode makeInExpressionBigInt(Field resultField,
Set<Long> longValues) {
return InNode.makeLongInExpr(resultField, longValues);
}

public static TreeNode makeInExpressionString(Field resultField,
Set<String> stringValues) {
return InNode.makeStringInExpr(resultField, stringValues);
}

public static TreeNode makeInExpressionBinary(Field resultField,
Set<byte[]> binaryValues) {
return InNode.makeBinaryInExpr(resultField, binaryValues);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
import org.junit.Test;

import com.google.common.collect.Lists;
import com.google.common.collect.Sets;

import io.netty.buffer.ArrowBuf;

Expand Down Expand Up @@ -1047,6 +1048,96 @@ public void testEquals() throws GandivaException, Exception {
eval.close();
}

@Test
public void testInExpr() throws GandivaException, Exception {
Field c1 = Field.nullable("c1", int32);

TreeNode inExpr =
TreeBuilder.makeInExpressionInt32(c1, Sets.newHashSet(1,2,3,4,5,15,16));
ExpressionTree expr = TreeBuilder.makeExpression(inExpr, Field.nullable("result", boolType));
Schema schema = new Schema(Lists.newArrayList(c1));
Projector eval = Projector.make(schema, Lists.newArrayList(expr));

int numRows = 16;
byte[] validity = new byte[]{(byte) 255, 0};
int[] c1Values = new int[]{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};

ArrowBuf c1Validity = buf(validity);
ArrowBuf c1Data = intBuf(c1Values);
ArrowBuf c2Validity = buf(validity);

ArrowFieldNode fieldNode = new ArrowFieldNode(numRows, 0);
ArrowRecordBatch batch =
new ArrowRecordBatch(
numRows,
Lists.newArrayList(fieldNode, fieldNode),
Lists.newArrayList(c1Validity, c1Data, c2Validity));

BitVector bitVector = new BitVector(EMPTY_SCHEMA_PATH, allocator);
bitVector.allocateNew(numRows);

List<ValueVector> output = new ArrayList<ValueVector>();
output.add(bitVector);
eval.evaluate(batch, output);

for (int i = 0; i < 5; i++) {
assertTrue(bitVector.getObject(i).booleanValue());
}
for (int i = 5; i < 16; i++) {
assertFalse(bitVector.getObject(i).booleanValue());
}

releaseRecordBatch(batch);
releaseValueVectors(output);
eval.close();
}

@Test
public void testInExprStrings() throws GandivaException, Exception {
Field c1 = Field.nullable("c1", new ArrowType.Utf8());

TreeNode inExpr =
TreeBuilder.makeInExpressionString(c1, Sets.newHashSet("one", "two", "three", "four"));
ExpressionTree expr = TreeBuilder.makeExpression(inExpr, Field.nullable("result", boolType));
Schema schema = new Schema(Lists.newArrayList(c1));
Projector eval = Projector.make(schema, Lists.newArrayList(expr));

int numRows = 16;
byte[] validity = new byte[]{(byte) 255, 0};
String[] c1Values = new String[]{"one", "two", "three", "four", "five", "six", "seven",
"eight", "nine", "ten", "eleven", "twelve", "thirteen", "fourteen", "fifteen",
"sixteen"};

ArrowBuf c1Validity = buf(validity);
List<ArrowBuf> dataBufsX = stringBufs(c1Values);
ArrowBuf c2Validity = buf(validity);

ArrowFieldNode fieldNode = new ArrowFieldNode(numRows, 0);
ArrowRecordBatch batch =
new ArrowRecordBatch(
numRows,
Lists.newArrayList(fieldNode, fieldNode),
Lists.newArrayList(c1Validity, dataBufsX.get(0),dataBufsX.get(1), c2Validity));

BitVector bitVector = new BitVector(EMPTY_SCHEMA_PATH, allocator);
bitVector.allocateNew(numRows);

List<ValueVector> output = new ArrayList<ValueVector>();
output.add(bitVector);
eval.evaluate(batch, output);

for (int i = 0; i < 4; i++) {
assertTrue(bitVector.getObject(i).booleanValue());
}
for (int i = 5; i < 16; i++) {
assertFalse(bitVector.getObject(i).booleanValue());
}

releaseRecordBatch(batch);
releaseValueVectors(output);
eval.close();
}

@Test
public void testSmallOutputVectors() throws GandivaException, Exception {
Field a = Field.nullable("a", int32);
Expand Down

0 comments on commit f14f45e

Please sign in to comment.