Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add CBRT to the V2 engine #1081

Merged
merged 2 commits into from
Nov 21, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions core/src/main/java/org/opensearch/sql/expression/DSL.java
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,10 @@ public static FunctionExpression sqrt(Expression... expressions) {
return compile(BuiltinFunctionName.SQRT, expressions);
}

public FunctionExpression cbrt(Expression... expressions) {
return compile(BuiltinFunctionName.CBRT, expressions);
}

public static FunctionExpression truncate(Expression... expressions) {
return compile(BuiltinFunctionName.TRUNCATE, expressions);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ public enum BuiltinFunctionName {
ROUND(FunctionName.of("round")),
SIGN(FunctionName.of("sign")),
SQRT(FunctionName.of("sqrt")),
CBRT(FunctionName.of("cbrt")),
TRUNCATE(FunctionName.of("truncate")),

ACOS(FunctionName.of("acos")),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ public class MathematicalFunction {
*/
public static void register(BuiltinFunctionRepository repository) {
repository.register(abs());
repository.register(cbrt());
repository.register(ceil());
repository.register(ceiling());
repository.register(conv());
Expand Down Expand Up @@ -471,6 +472,20 @@ private static DefaultFunctionResolver sqrt() {
DOUBLE, type)).collect(Collectors.toList()));
}

/**
* Definition of cbrt(x) function.
* Calculate the cube root of a number x
* The supported signature is
* INTEGER/LONG/FLOAT/DOUBLE -> DOUBLE
*/
private static DefaultFunctionResolver cbrt() {
return FunctionDSL.define(BuiltinFunctionName.CBRT.getName(),
ExprCoreType.numberTypes().stream()
.map(type -> FunctionDSL.impl(FunctionDSL.nullMissingHandling(
v -> new ExprDoubleValue(Math.cbrt(v.doubleValue()))),
DOUBLE, type)).collect(Collectors.toList()));
}

/**
* Definition of truncate(x, d) function.
* Returns the number x, truncated to d decimal places
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2327,4 +2327,79 @@ public void tan_missing_value() {
assertEquals(DOUBLE, tan.type());
assertTrue(tan.valueOf(valueEnv()).isMissing());
}

/**
* Test cbrt with int value.
*/
@ParameterizedTest(name = "cbrt({0})")
@ValueSource(ints = {1, 2})
public void cbrt_int_value(Integer value) {
FunctionExpression cbrt = dsl.cbrt(DSL.literal(value));
assertThat(cbrt.valueOf(valueEnv()), allOf(hasType(DOUBLE), hasValue(Math.cbrt(value))));
assertEquals(String.format("cbrt(%s)", value), cbrt.toString());
}

/**
* Test cbrt with long value.
*/
@ParameterizedTest(name = "cbrt({0})")
@ValueSource(longs = {1L, 2L})
public void cbrt_long_value(Long value) {
FunctionExpression cbrt = dsl.cbrt(DSL.literal(value));
assertThat(cbrt.valueOf(valueEnv()), allOf(hasType(DOUBLE), hasValue(Math.cbrt(value))));
assertEquals(String.format("cbrt(%s)", value), cbrt.toString());
}

/**
* Test cbrt with float value.
*/
@ParameterizedTest(name = "cbrt({0})")
@ValueSource(floats = {1F, 2F})
public void cbrt_float_value(Float value) {
FunctionExpression cbrt = dsl.cbrt(DSL.literal(value));
assertThat(cbrt.valueOf(valueEnv()), allOf(hasType(DOUBLE), hasValue(Math.cbrt(value))));
assertEquals(String.format("cbrt(%s)", value), cbrt.toString());
}

/**
* Test cbrt with double value.
*/
@ParameterizedTest(name = "cbrt({0})")
@ValueSource(doubles = {1D, 2D, Double.MAX_VALUE, Double.MIN_VALUE})
public void cbrt_double_value(Double value) {
FunctionExpression cbrt = dsl.cbrt(DSL.literal(value));
assertThat(cbrt.valueOf(valueEnv()), allOf(hasType(DOUBLE), hasValue(Math.cbrt(value))));
assertEquals(String.format("cbrt(%s)", value), cbrt.toString());
}

/**
* Test cbrt with negative value.
*/
@ParameterizedTest(name = "cbrt({0})")
@ValueSource(doubles = {-1D, -2D})
public void cbrt_negative_value(Double value) {
FunctionExpression cbrt = dsl.cbrt(DSL.literal(value));
assertThat(cbrt.valueOf(valueEnv()), allOf(hasType(DOUBLE), hasValue(Math.cbrt(value))));
assertEquals(String.format("cbrt(%s)", value), cbrt.toString());
}

/**
* Test cbrt with null value.
*/
@Test
public void cbrt_null_value() {
FunctionExpression cbrt = dsl.cbrt(DSL.ref(INT_TYPE_NULL_VALUE_FIELD, INTEGER));
assertEquals(DOUBLE, cbrt.type());
assertTrue(cbrt.valueOf(valueEnv()).isNull());
}

/**
* Test cbrt with missing value.
*/
@Test
public void cbrt_missing_value() {
FunctionExpression cbrt = dsl.cbrt(DSL.ref(INT_TYPE_MISSING_VALUE_FIELD, INTEGER));
assertEquals(DOUBLE, cbrt.type());
assertTrue(cbrt.valueOf(valueEnv()).isMissing());
}
}
19 changes: 17 additions & 2 deletions docs/user/dql/functions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -209,9 +209,24 @@ CBRT
Description
>>>>>>>>>>>

Specifications:
Usage: CBRT(number) calculates the cube root of a number

Argument type: INTEGER/LONG/FLOAT/DOUBLE

Return type: DOUBLE

1. CBRT(NUMBER T) -> T
(Non-negative) INTEGER/LONG/FLOAT/DOUBLE -> DOUBLE
(Negative) INTEGER/LONG/FLOAT/DOUBLE -> DOUBLE
dai-chen marked this conversation as resolved.
Show resolved Hide resolved

Example::

opensearchsql> SELECT CBRT(8), CBRT(9.261), CBRT(-27);
fetched rows / total rows = 1/1
+-----------+---------------+-------------+
| CBRT(8) | CBRT(9.261) | CBRT(-27) |
|-----------+---------------+-------------|
| 2.0 | 2.1 | -3.0 |
+-----------+---------------+-------------+


CEIL
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -162,4 +162,20 @@ protected JSONObject executeQuery(String query) throws IOException {
Response response = client().performRequest(request);
return new JSONObject(getResponseBody(response));
}


@Test
public void testCbrt() throws IOException {
JSONObject result = executeQuery("select cbrt(8)");
verifySchema(result, schema("cbrt(8)", "double"));
verifyDataRows(result, rows(2.0));

result = executeQuery("select cbrt(9.261)");
verifySchema(result, schema("cbrt(9.261)", "double"));
verifyDataRows(result, rows(2.1));

result = executeQuery("select cbrt(-27)");
verifySchema(result, schema("cbrt(-27)", "double"));
verifyDataRows(result, rows(-3.0));
}
}
2 changes: 1 addition & 1 deletion sql/src/main/antlr/OpenSearchSQLParser.g4
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,7 @@ aggregationFunctionName
;

mathematicalFunctionName
: ABS | CEIL | CEILING | CONV | CRC32 | E | EXP | FLOOR | LN | LOG | LOG10 | LOG2 | MOD | PI | POW | POWER
: ABS | CBRT | CEIL | CEILING | CONV | CRC32 | E | EXP | FLOOR | LN | LOG | LOG10 | LOG2 | MOD | PI | POW | POWER
| RAND | ROUND | SIGN | SQRT | TRUNCATE
| trigonometricFunctionName
;
Expand Down