diff --git a/native-sql-engine/core/src/main/scala/com/intel/oap/expression/ColumnarBinaryExpression.scala b/native-sql-engine/core/src/main/scala/com/intel/oap/expression/ColumnarBinaryExpression.scala index 897df4e03..1ad6b4bed 100644 --- a/native-sql-engine/core/src/main/scala/com/intel/oap/expression/ColumnarBinaryExpression.scala +++ b/native-sql-engine/core/src/main/scala/com/intel/oap/expression/ColumnarBinaryExpression.scala @@ -95,6 +95,22 @@ class ColumnarGetJsonObject(left: Expression, right: Expression, original: GetJs } } +class ColumnarStringInstr(left: Expression, right: Expression, original: StringInstr) + extends StringInstr(original.str, original.substr) with ColumnarExpression with Logging { + + override def doColumnarCodeGen(args: Object): (TreeNode, ArrowType) = { + val (left_node, _): (TreeNode, ArrowType) = + left.asInstanceOf[ColumnarExpression].doColumnarCodeGen(args) + val (right_node, _): (TreeNode, ArrowType) = + right.asInstanceOf[ColumnarExpression].doColumnarCodeGen(args) + val resultType = CodeGeneration.getResultType(dataType) + // Be careful about the argument order. + val funcNode = TreeBuilder.makeFunction("locate", + Lists.newArrayList(right_node, left_node, + TreeBuilder.makeLiteral(1.asInstanceOf[java.lang.Integer])), resultType) + (funcNode, resultType) + } +} object ColumnarBinaryExpression { @@ -116,6 +132,8 @@ object ColumnarBinaryExpression { new ColumnarDateSub(left, right) case g: GetJsonObject => new ColumnarGetJsonObject(left, right, g) + case instr: StringInstr => + new ColumnarStringInstr(left, right, instr) case other => throw new UnsupportedOperationException(s"not currently supported: $other.") } diff --git a/native-sql-engine/cpp/src/codegen/arrow_compute/ext/expression_codegen_visitor.cc b/native-sql-engine/cpp/src/codegen/arrow_compute/ext/expression_codegen_visitor.cc index 9cf33e04c..20d63da00 100644 --- a/native-sql-engine/cpp/src/codegen/arrow_compute/ext/expression_codegen_visitor.cc +++ b/native-sql-engine/cpp/src/codegen/arrow_compute/ext/expression_codegen_visitor.cc @@ -303,6 +303,31 @@ arrow::Status ExpressionCodegenVisitor::Visit(const gandiva::FunctionNode& node) prepare_ss << "if (" << check_str_ << ")" << std::endl; prepare_ss << codes_str_ << " = " << ss.str() << ";" << std::endl; prepare_str_ += prepare_ss.str(); + } else if (func_name.compare("instr") == 0) { + codes_str_ = func_name + "_" + std::to_string(cur_func_id); + auto validity = codes_str_ + "_validity"; + real_codes_str_ = codes_str_; + real_validity_str_ = validity; + std::stringstream prepare_ss; + prepare_ss << GetCTypeString(node.return_type()) << " " << codes_str_ << ";" + << std::endl; + prepare_ss << "bool " << validity << " = " << child_visitor_list[0]->GetPreCheck() + << ";" << std::endl; + prepare_ss << "if (" << validity << ") {" << std::endl; + prepare_ss << "auto ind = " << child_visitor_list[0]->GetResult() << ".find(" + << child_visitor_list[1]->GetResult() << ");" << std::endl; + prepare_ss << "if (ind == std::string::npos) {" << std::endl; + prepare_ss << codes_str_ << " = 0;" << std::endl; + prepare_ss << "}" << std::endl; + prepare_ss << "else {" << std::endl; + prepare_ss << codes_str_ << " = ind + 1;" << std::endl; + prepare_ss << "}" << std::endl; + prepare_ss << "}" << std::endl; + for (int i = 0; i < 1; i++) { + prepare_str_ += child_visitor_list[i]->GetPrepare(); + } + prepare_str_ += prepare_ss.str(); + check_str_ = validity; } else if (func_name.compare("btrim") == 0) { codes_str_ = func_name + "_" + std::to_string(cur_func_id); auto validity = codes_str_ + "_validity";