Skip to content

Commit

Permalink
[feature](function)support count_substrings functions (apache#42055)
Browse files Browse the repository at this point in the history
## Proposed changes
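Add the `count_substrings(str, pattern)` string function, which returns the number of non-overlapping occurrences of `pattern` in `str`.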
doc: apache/doris-website#1211

<!--Describe your changes.-->
  • Loading branch information
zhangstar333 authored Oct 24, 2024
1 parent 9f10825 commit 5155919
Show file tree
Hide file tree
Showing 8 changed files with 418 additions and 0 deletions.
1 change: 1 addition & 0 deletions be/src/vec/functions/function_string.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1196,6 +1196,7 @@ void register_function_string(SimpleFunctionFactory& factory) {
factory.register_function<FunctionFromBase64>();
factory.register_function<FunctionSplitPart>();
factory.register_function<FunctionSplitByString>();
factory.register_function<FunctionCountSubString>();
factory.register_function<FunctionSubstringIndex>();
factory.register_function<FunctionExtractURLParameter>();
factory.register_function<FunctionStringParseUrl>();
Expand Down
116 changes: 116 additions & 0 deletions be/src/vec/functions/function_string.h
Original file line number Diff line number Diff line change
Expand Up @@ -2370,6 +2370,122 @@ class FunctionSplitByString : public IFunction {
}
};

class FunctionCountSubString : public IFunction {
public:
static constexpr auto name = "count_substrings";

static FunctionPtr create() { return std::make_shared<FunctionCountSubString>(); }
using NullMapType = PaddedPODArray<UInt8>;

String get_name() const override { return name; }

size_t get_number_of_arguments() const override { return 2; }

DataTypePtr get_return_type_impl(const DataTypes& arguments) const override {
DCHECK(is_string(arguments[0]))
<< "first argument for function: " << name << " should be string"
<< " and arguments[0] is " << arguments[0]->get_name();
DCHECK(is_string(arguments[1]))
<< "second argument for function: " << name << " should be string"
<< " and arguments[1] is " << arguments[1]->get_name();
return std::make_shared<DataTypeInt32>();
}

Status execute_impl(FunctionContext* /*context*/, Block& block, const ColumnNumbers& arguments,
size_t result, size_t input_rows_count) const override {
DCHECK_EQ(arguments.size(), 2);
const auto& [src_column, left_const] =
unpack_if_const(block.get_by_position(arguments[0]).column);
const auto& [right_column, right_const] =
unpack_if_const(block.get_by_position(arguments[1]).column);

const auto* col_left = check_and_get_column<ColumnString>(src_column.get());
if (!col_left) {
return Status::InternalError("Left operator of function {} can not be {}", get_name(),
block.get_by_position(arguments[0]).type->get_name());
}

const auto* col_right = check_and_get_column<ColumnString>(right_column.get());
if (!col_right) {
return Status::InternalError("Right operator of function {} can not be {}", get_name(),
block.get_by_position(arguments[1]).type->get_name());
}

auto dest_column_ptr = ColumnInt32::create(input_rows_count, 0);
// count_substring(ColumnString, "xxx")
if (right_const) {
_execute_constant_pattern(*col_left, col_right->get_data_at(0),
dest_column_ptr->get_data(), input_rows_count);
} else if (left_const) {
// count_substring("xxx", ColumnString)
_execute_constant_src_string(col_left->get_data_at(0), *col_right,
dest_column_ptr->get_data(), input_rows_count);
} else {
// count_substring(ColumnString, ColumnString)
_execute_vector(*col_left, *col_right, dest_column_ptr->get_data(), input_rows_count);
}

block.replace_by_position(result, std::move(dest_column_ptr));
return Status::OK();
}

private:
void _execute_constant_pattern(const ColumnString& src_column_string,
const StringRef& pattern_ref,
ColumnInt32::Container& dest_column_data,
size_t input_rows_count) const {
for (size_t i = 0; i < input_rows_count; i++) {
const StringRef str_ref = src_column_string.get_data_at(i);
dest_column_data[i] = find_str_count(str_ref, pattern_ref);
}
}

void _execute_vector(const ColumnString& src_column_string, const ColumnString& pattern_column,
ColumnInt32::Container& dest_column_data, size_t input_rows_count) const {
for (size_t i = 0; i < input_rows_count; i++) {
const StringRef pattern_ref = pattern_column.get_data_at(i);
const StringRef str_ref = src_column_string.get_data_at(i);
dest_column_data[i] = find_str_count(str_ref, pattern_ref);
}
}

void _execute_constant_src_string(const StringRef& str_ref, const ColumnString& pattern_col,
ColumnInt32::Container& dest_column_data,
size_t input_rows_count) const {
for (size_t i = 0; i < input_rows_count; ++i) {
const StringRef pattern_ref = pattern_col.get_data_at(i);
dest_column_data[i] = find_str_count(str_ref, pattern_ref);
}
}

size_t find_pos(size_t pos, const StringRef str_ref, const StringRef pattern_ref) const {
size_t old_size = pos;
size_t str_size = str_ref.size;
while (pos < str_size && memcmp_small_allow_overflow15(str_ref.data + pos, pattern_ref.data,
pattern_ref.size)) {
pos++;
}
return pos - old_size;
}

int find_str_count(const StringRef str_ref, StringRef pattern_ref) const {
int count = 0;
if (str_ref.size == 0 || pattern_ref.size == 0) {
return 0;
} else {
for (size_t str_pos = 0; str_pos <= str_ref.size;) {
const size_t res_pos = find_pos(str_pos, str_ref, pattern_ref);
if (res_pos == (str_ref.size - str_pos)) {
break; // not find
}
count++;
str_pos = str_pos + res_pos + pattern_ref.size;
}
}
return count;
}
};

struct SM3Sum {
static constexpr auto name = "sm3sum";
using ObjectData = SM3Digest;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@
import org.apache.doris.nereids.trees.expressions.functions.scalar.Cosh;
import org.apache.doris.nereids.trees.expressions.functions.scalar.CosineDistance;
import org.apache.doris.nereids.trees.expressions.functions.scalar.CountEqual;
import org.apache.doris.nereids.trees.expressions.functions.scalar.CountSubstring;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Crc32;
import org.apache.doris.nereids.trees.expressions.functions.scalar.CreateMap;
import org.apache.doris.nereids.trees.expressions.functions.scalar.CreateNamedStruct;
Expand Down Expand Up @@ -596,6 +597,7 @@ public class BuiltinScalarFunctions implements FunctionHelper {
scalar(Cosh.class, "cosh"),
scalar(CosineDistance.class, "cosine_distance"),
scalar(CountEqual.class, "countequal"),
scalar(CountSubstring.class, "count_substrings"),
scalar(CreateMap.class, "map"),
scalar(CreateStruct.class, "struct"),
scalar(CreateNamedStruct.class, "named_struct"),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.trees.expressions.functions.scalar;

import org.apache.doris.catalog.FunctionSignature;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable;
import org.apache.doris.nereids.trees.expressions.shape.BinaryExpression;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.IntegerType;
import org.apache.doris.nereids.types.StringType;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;

import java.util.List;

/**
 * Scalar function {@code count_substrings}.
 *
 * <p>A binary expression whose only declared signature takes two {@link StringType}
 * arguments and returns an {@link IntegerType}.
 */
public class CountSubstring extends ScalarFunction
        implements BinaryExpression, ExplicitlyCastableSignature, PropagateNullable {

    public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
            FunctionSignature.ret(IntegerType.INSTANCE).args(StringType.INSTANCE, StringType.INSTANCE)
    );

    /**
     * Creates the function call from its two operands.
     *
     * @param arg0 first operand
     * @param arg1 second operand
     */
    public CountSubstring(Expression arg0, Expression arg1) {
        super("count_substrings", arg0, arg1);
    }

    /**
     * Rebuilds this function over a new pair of children.
     *
     * @param children replacement operands; exactly two are required
     */
    @Override
    public CountSubstring withChildren(List<Expression> children) {
        Preconditions.checkArgument(children.size() == 2);
        Expression first = children.get(0);
        Expression second = children.get(1);
        return new CountSubstring(first, second);
    }

    @Override
    public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
        return visitor.visitCountSubstring(this, context);
    }

    @Override
    public List<FunctionSignature> getSignatures() {
        return SIGNATURES;
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@
import org.apache.doris.nereids.trees.expressions.functions.scalar.Cosh;
import org.apache.doris.nereids.trees.expressions.functions.scalar.CosineDistance;
import org.apache.doris.nereids.trees.expressions.functions.scalar.CountEqual;
import org.apache.doris.nereids.trees.expressions.functions.scalar.CountSubstring;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Crc32;
import org.apache.doris.nereids.trees.expressions.functions.scalar.CreateMap;
import org.apache.doris.nereids.trees.expressions.functions.scalar.CreateNamedStruct;
Expand Down Expand Up @@ -954,6 +955,10 @@ default R visitCountEqual(CountEqual countequal, C context) {
return visitScalarFunction(countequal, context);
}

/** Visits a {@code CountSubstring} node; by default it is handled as a generic scalar function. */
default R visitCountSubstring(CountSubstring function, C context) {
    return visitScalarFunction(function, context);
}

/** Visits a {@code CurrentCatalog} node; by default it is handled as a generic scalar function. */
default R visitCurrentCatalog(CurrentCatalog function, C context) {
    return visitScalarFunction(function, context);
}
Expand Down
1 change: 1 addition & 0 deletions gensrc/script/doris_builtins_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1681,6 +1681,7 @@

[['overlay'], 'VARCHAR', ['VARCHAR', 'INT', 'INT', 'VARCHAR'], ''],

[['count_substrings'], 'INT', ['STRING', 'STRING'], 'DEPEND_ON_ARGUMENT'],
[['substr', 'substring'], 'STRING', ['STRING', 'INT'], 'DEPEND_ON_ARGUMENT'],
[['substr', 'substring'], 'STRING', ['STRING', 'INT', 'INT'], 'DEPEND_ON_ARGUMENT'],
[['strleft', 'left'], 'STRING', ['STRING', 'INT'], 'DEPEND_ON_ARGUMENT'],
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
-- This file is automatically generated. You should know what you did if you want to edit this
-- !select1 --
\N

-- !select2 --
\N

-- !select3 --
\N

-- !select4 --
2

-- !select5 --
6

-- !select6 --
1

-- !select4_empty --

-- !select5_empty --

-- !select6_empty --

-- !select7_empty --

-- !select5_null_null --
abcde 0
0
a 0
\N \N \N
asdasd a 2
a1b1c1d 1 3
,,, # 0
a,b,c v 0
a,b,c, \N \N
\N asd \N
a,b,c,12345 5 1
a,b,c,12345 a 1
a,你,你,1我2你4我5 你 3

-- !select6_null_not --
abcde 0
0
a 0
\N \N
asdasd a 2
a1b1c1d 1 3
,,, # 0
a,b,c v 0
a,b,c, 0
\N asd \N
a,b,c,12345 5 1
a,b,c,12345 a 1
a,你,你,1我2你4我5 我 2

-- !select7_not_null --
abcde 0
0
a 0
\N \N
asdasd a 2
a1b1c1d 1 3
,,, # 0
a,b,c v 0
a,b,c \N \N
asd 0
a,b,c,12345 5 1
a,b,c,12345 a 1
a你,你,1我2你4我5 你 3

-- !select8_not_not --
abcde 0
0
a 0
0
asdasd a 2
a1b1c1d 1 3
,,, # 0
a,b,c v 0
a,b,c 0
asd 0
a,b,c,12345 5 1
a,b,c,12345 a 1
a你,你,1我2你4我5 我 2

-- !select9_null_const --
abcde a 1
a 0
a 0
\N a \N
asdasd a 2
a1b1c1d a 1
,,, a 0
a,b,c a 1
a,b,c, a 1
\N a \N
a,b,c,12345 a 1
a,b,c,12345 a 1
a,你,你,1我2你4我5 a 1

-- !select10_not_null_const --
abcde a 1
a 0
a 0
a 0
asdasd a 2
a1b1c1d a 1
,,, a 0
a,b,c a 1
a,b,c a 1
a 0
a,b,c,12345 a 1
a,b,c,12345 a 1
a你,你,1我2你4我5 a 1

-- !select11_const_null --
a 0
a 0
a a 1
a \N \N
a a 1
a 1 0
a # 0
a v 0
a \N \N
a asd 0
a 5 0
a a 1
a 你 0

-- !select12_const_not_null --
a 0
a 0
a a 1
a 0
a a 1
a 1 0
a # 0
a v 0
a 0
a asd 0
a 5 0
a a 1
a 我 0

Loading

0 comments on commit 5155919

Please sign in to comment.