From 43af3713fa6c586786e2a2ce17b095d3c1b9812a Mon Sep 17 00:00:00 2001
From: Rui Mo
Date: Mon, 12 Apr 2021 02:26:54 +0000
Subject: [PATCH] support normalize function in WSCG

Generate code for the "normalize" function in ExpressionCodegenVisitor and
add the normalize_nan_zero helper to precompile/gandiva.h. The helper maps
every NaN to a quiet NaN and -0.0 to +0.0; all other values pass through
unchanged. The FileScanRDD input_file_name test now sets
spark.oap.sql.columnar.batchscan instead of spark.oap.sql.columnar.testing.

---
 .../spark/sql/ColumnExpressionSuite.scala     |  2 +-
 .../ext/expression_codegen_visitor.cc         | 21 +++++++++++++++++++
 .../cpp/src/precompile/gandiva.h              | 15 +++++++++++++
 3 files changed, 37 insertions(+), 1 deletion(-)

diff --git a/native-sql-engine/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/native-sql-engine/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
index d84338c71..f094e48d0 100644
--- a/native-sql-engine/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
+++ b/native-sql-engine/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
@@ -745,7 +745,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession {
   }
 
   test("input_file_name, input_file_block_start, input_file_block_length - FileScanRDD") {
-    withSQLConf(("spark.oap.sql.columnar.testing", "true")) {
+    withSQLConf(("spark.oap.sql.columnar.batchscan", "true")) {
       withTempPath { dir =>
         val data = sparkContext.parallelize(0 to 10).toDF("id")
         data.write.parquet(dir.getCanonicalPath)
diff --git a/native-sql-engine/cpp/src/codegen/arrow_compute/ext/expression_codegen_visitor.cc b/native-sql-engine/cpp/src/codegen/arrow_compute/ext/expression_codegen_visitor.cc
index 4502ea474..1285522d0 100644
--- a/native-sql-engine/cpp/src/codegen/arrow_compute/ext/expression_codegen_visitor.cc
+++ b/native-sql-engine/cpp/src/codegen/arrow_compute/ext/expression_codegen_visitor.cc
@@ -802,6 +802,27 @@ arrow::Status ExpressionCodegenVisitor::Visit(const gandiva::FunctionNode& node)
     }
     prepare_str_ += prepare_ss.str();
     check_str_ = validity;
+  } else if (func_name.compare("normalize") == 0) {
+    codes_str_ = "normalize_" + std::to_string(cur_func_id);
+    auto validity = codes_str_ + "_validity";
+    std::stringstream fix_ss;
+    fix_ss << "normalize_nan_zero(" << child_visitor_list[0]->GetResult() << ")";
+    std::stringstream prepare_ss;
+    prepare_ss << GetCTypeString(node.return_type()) << " " << codes_str_ << ";"
+               << std::endl;
+    prepare_ss << "bool " << validity << " = " << child_visitor_list[0]->GetPreCheck()
+               << ";" << std::endl;
+    prepare_ss << "if (" << validity << ") {" << std::endl;
+    prepare_ss << codes_str_ << " = (" << GetCTypeString(node.return_type()) << ")"
+               << fix_ss.str() << ";" << std::endl;
+    prepare_ss << "}" << std::endl;
+
+    for (int i = 0; i < 1; i++) {
+      prepare_str_ += child_visitor_list[i]->GetPrepare();
+    }
+    prepare_str_ += prepare_ss.str();
+    check_str_ = validity;
+    header_list_.push_back(R"(#include "precompile/gandiva.h")");
   } else {
     return arrow::Status::NotImplemented(func_name, " is currently not supported.");
   }
diff --git a/native-sql-engine/cpp/src/precompile/gandiva.h b/native-sql-engine/cpp/src/precompile/gandiva.h
index 6d1614684..b7500bee4 100644
--- a/native-sql-engine/cpp/src/precompile/gandiva.h
+++ b/native-sql-engine/cpp/src/precompile/gandiva.h
@@ -205,6 +205,21 @@ bool equal_with_nan(double left, double right) {
   return left == right;
 }
 
+// Canonicalizes a double: any NaN becomes a quiet NaN and -0.0 becomes +0.0.
+// Every other value, including tiny non-zero negatives, is returned unchanged.
+// Intended to mirror Spark's NormalizeNaNAndZero semantics.
+double normalize_nan_zero(double in) {
+  if (std::isnan(in)) {
+    // 0.0 / 0.0 may produce a NaN with the sign bit set; ask for a quiet NaN.
+    return std::nan("");
+  }
+  if (in == 0.0) {
+    // -0.0 == 0.0 is true, so both zeros collapse to +0.0 here.
+    return 0.0;
+  }
+  return in;
+}
+
 arrow::Decimal128 round(arrow::Decimal128 in, int32_t original_precision,
                         int32_t original_scale, bool* overflow_, int32_t res_scale = 2) {
   bool overflow = false;