Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions be/src/vec/functions/math.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,24 @@ struct AtanhName {
using FunctionAtanh =
FunctionMathUnaryAlwayNullable<UnaryFunctionPlainAlwayNullable<AtanhName, std::atanh>>;

struct CotName {
static constexpr auto name = "cot";
static constexpr bool is_invalid_input(Float64 x) {
constexpr double epsilon = 1e-10;
double remainder = std::fmod(std::abs(x), M_PI);

return std::abs(x) < epsilon || std::abs(remainder) < epsilon ||
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just std::abs(remainder) < epsilon maybe is enough? this check seems could cover other two conditions. could you confirm this?

std::abs(remainder - M_PI) < epsilon;
}
};

static inline double cot_impl(double x) {
return 1 / std::tan(x);
}

using FunctionCot =
FunctionMathUnaryAlwayNullable<UnaryFunctionPlainAlwayNullable<CotName, cot_impl>>;

template <PrimitiveType AType, PrimitiveType BType>
struct Atan2Impl {
using A = typename PrimitiveTypeTraits<AType>::ColumnItemType;
Expand Down Expand Up @@ -540,6 +558,7 @@ void register_function_math(SimpleFunctionFactory& factory) {
factory.register_function<FunctionAtan2>();
factory.register_function<FunctionCos>();
factory.register_function<FunctionCosh>();
factory.register_function<FunctionCot>();
factory.register_function<FunctionE>();
factory.register_alias("ln", "dlog1");
factory.register_function<FunctionLog>();
Expand Down
12 changes: 12 additions & 0 deletions be/test/vec/function/function_math_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,18 @@ TEST(MathFunctionTest, cos_test) {
static_cast<void>(check_function<DataTypeFloat64, true>(func_name, input_types, data_set));
}

TEST(MathFunctionTest, cot_test) {
std::string func_name = "cot";

InputTypeSet input_types = {PrimitiveType::TYPE_DOUBLE};

DataSet data_set = {{{-1.0}, -0.6420926159343306}, {{0.5}, 1.830487721712452},
{{1.0}, 0.6420926159343306}, {{0.0}, Null()},
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add a testcase return Null which input is not zero

{{M_PI / 4}, 1.0000000000000002}, {{M_PI / 2}, 6.123233995736766e-17}};

static_cast<void>(check_function<DataTypeFloat64, true>(func_name, input_types, data_set));
}

TEST(MathFunctionTest, sin_test) {
std::string func_name = "sin";

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@
import org.apache.doris.nereids.trees.expressions.functions.scalar.Cos;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Cosh;
import org.apache.doris.nereids.trees.expressions.functions.scalar.CosineDistance;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Cot;
import org.apache.doris.nereids.trees.expressions.functions.scalar.CountEqual;
import org.apache.doris.nereids.trees.expressions.functions.scalar.CountSubstring;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Crc32;
Expand Down Expand Up @@ -621,6 +622,7 @@ public class BuiltinScalarFunctions implements FunctionHelper {
scalar(Cos.class, "cos"),
scalar(Cosh.class, "cosh"),
scalar(CosineDistance.class, "cosine_distance"),
scalar(Cot.class, "cot"),
scalar(CountEqual.class, "countequal"),
scalar(CountSubstring.class, "count_substrings"),
scalar(CreateMap.class, "map"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -917,6 +917,25 @@ public static Expression cos(DoubleLiteral first) {
return checkOutputBoundary(new DoubleLiteral(Math.cos(first.getValue())));
}

/**
* cot
*/
@ExecFunction(name = "cot")
public static Expression cot(DoubleLiteral first) {
if (inputOutOfBound(first, Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY, false, false)) {
return new NullLiteral(DoubleType.INSTANCE);
} else {
final double epsilon = 1e-10;
double value = first.getValue();
double remainder = Math.abs(value) % Math.PI;

if (Math.abs(value) < epsilon || Math.abs(remainder) < epsilon || Math.abs(remainder - Math.PI) < epsilon) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

seem with comment in BE

return new NullLiteral(DoubleType.INSTANCE);
}
return checkOutputBoundary(new DoubleLiteral(1.0 / Math.tan(first.getValue())));
}
}

/**
* tan
*/
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.trees.expressions.functions.scalar;

import org.apache.doris.catalog.FunctionSignature;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.functions.AlwaysNullable;
import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
import org.apache.doris.nereids.trees.expressions.functions.PropagateNullLiteral;
import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.DoubleType;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;

import java.util.List;

/**
* Cot Scala Function
*/
public class Cot extends ScalarFunction
implements UnaryExpression, ExplicitlyCastableSignature, AlwaysNullable, PropagateNullLiteral {
public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
FunctionSignature.ret(DoubleType.INSTANCE).args(DoubleType.INSTANCE)
);

/**
* constructor with 1 argument.
*/
public Cot(Expression arg) {
super("cot", arg);
}

/**
* withChildren.
*/
@Override
public Cot withChildren(List<Expression> children) {
Preconditions.checkArgument(children.size() == 1);
return new Cot(children.get(0));
}

@Override
public List<FunctionSignature> getSignatures() {
return SIGNATURES;
}

@Override
public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
return visitor.visitCot(this, context);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@
import org.apache.doris.nereids.trees.expressions.functions.scalar.Cos;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Cosh;
import org.apache.doris.nereids.trees.expressions.functions.scalar.CosineDistance;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Cot;
import org.apache.doris.nereids.trees.expressions.functions.scalar.CountEqual;
import org.apache.doris.nereids.trees.expressions.functions.scalar.CountSubstring;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Crc32;
Expand Down Expand Up @@ -938,6 +939,10 @@ default R visitConcat(Concat concat, C context) {
return visitScalarFunction(concat, context);
}

default R visitCot(Cot cot, C context) {
return visitScalarFunction(cot, context);
}

default R visitChar(Char charFunc, C context) {
return visitScalarFunction(charFunc, context);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,24 @@
0.0 false 8
0.0 false 9

-- !cot_1 --
-0.6420926159343306 false

-- !cot_2 --
0.6420926159343306 false

-- !cot_3 --
\N true 0
\N true 1
\N true 2
\N true 3
\N true 4
\N true 5
\N true 6
\N true 7
\N true 8
\N true 9

-- !sqrt_1 --
\N true

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,17 @@ suite("fold_constant_numeric_arithmatic") {
testFoldConst("SELECT COSH(709.782712893384)") // Near overflow boundary
testFoldConst("SELECT COSH(-709.782712893384)") // Near negative overflow boundary

//Cot function cases
testFoldConst("SELECT COT(PI()) AS cot_case_1") //cot(π) = null
testFoldConst("SELECT COT(0) AS cot_case_2") //cot(0) = null
testFoldConst("SELECT COT(PI()/2) AS cot_case_3") //cot(π/2)
testFoldConst("SELECT COT(PI()/4)")
testFoldConst("SELECT COT(PI()/6)")
testFoldConst("SELECT COT(-PI())") // Negative PI
testFoldConst("SELECT COT(1E-308)") // Very small number
testFoldConst("SELECT COT(-1E-308)") // Very small negative number
testFoldConst("SELECT COT(NULL)") // NULL handling

//CountEqual function cases
testFoldConst("SELECT COUNT(CASE WHEN 5 = 5 THEN 1 END) AS countequal_case_1") //1 (true)
testFoldConst("SELECT COUNT(CASE WHEN 5 = 3 THEN 1 END) AS countequal_case_2") //0 (false)
Expand Down
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

also add some test with data in table

Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,16 @@ suite("test_math_unary_alway_nullable") {
select atanh(0.0), atanh(0.0) is NULL, number from numbers("number"="10")
"""

qt_cot_1 """
select cot(-1.0), cot(-1.0) is null;
"""
qt_cot_2 """
select cot(1.0), cot(1.0) is null;
"""
qt_cot_3 """
select cot(0.0), cot(0.0) is NULL, number from numbers("number"="10")
"""

qt_sqrt_1 """
select sqrt(-1), sqrt(-1) is null;
"""
Expand Down