diff --git a/be/src/vec/data_types/data_type_date_or_datetime_v2.h b/be/src/vec/data_types/data_type_date_or_datetime_v2.h index a2e652f9243329..8066cda6b0b9fe 100644 --- a/be/src/vec/data_types/data_type_date_or_datetime_v2.h +++ b/be/src/vec/data_types/data_type_date_or_datetime_v2.h @@ -190,5 +190,11 @@ template constexpr bool IsDataTypeDateTimeV2 = false; template <> inline constexpr bool IsDataTypeDateTimeV2 = true; + +template +constexpr bool IsDataTypeMap = false; +template <> +inline constexpr bool IsDataTypeMap = true; + #include "common/compile_check_end.h" } // namespace doris::vectorized diff --git a/be/src/vec/functions/function_map.cpp b/be/src/vec/functions/function_map.cpp index e0b20ec42aeb1e..6235841158e786 100644 --- a/be/src/vec/functions/function_map.cpp +++ b/be/src/vec/functions/function_map.cpp @@ -20,6 +20,7 @@ #include #include +#include #include #include #include @@ -36,6 +37,7 @@ #include "vec/columns/column_const.h" #include "vec/columns/column_map.h" #include "vec/columns/column_nullable.h" +#include "vec/columns/column_variant.h" #include "vec/columns/column_vector.h" #include "vec/common/assert_cast.h" #include "vec/common/typeid_cast.h" @@ -88,7 +90,6 @@ class FunctionMap : public IFunction { uint32_t result, size_t input_rows_count) const override { DCHECK(arguments.size() % 2 == 0) << "function: " << get_name() << ", arguments should not be even number"; - size_t num_element = arguments.size(); auto result_col = block.get_by_position(result).type->create_column(); @@ -792,6 +793,78 @@ class FunctionDeduplicateMap : public IFunction { private: }; +class FunctionMapConcat : public IFunction { +public: + static constexpr auto name = "map_concat"; + static FunctionPtr create() { return std::make_shared(); } + String get_name() const override { return name; } + bool is_variadic() const override { return true; } + size_t get_number_of_arguments() const override { return 0; } + DataTypePtr get_return_type_impl(const DataTypes& arguments) const override { + if (arguments.empty()) { + return std::make_shared( + make_nullable(std::make_shared()), + make_nullable(std::make_shared())); + } + DCHECK(arguments.size() > 0) + << "function: " << get_name() << ", arguments should not be empty"; + return arguments[0]; + } + Status execute_impl(FunctionContext* context, Block& block, const ColumnNumbers& arguments, + const uint32_t result, size_t input_rows_count) const override { + auto result_col = block.get_by_position(result).type->create_column(); + ColumnMap* result_map_column = nullptr; + ColumnNullable* result_nullable_column = nullptr; + if (result_col->is_nullable()) { + result_nullable_column = assert_cast(result_col.get()); + result_map_column = + assert_cast(result_nullable_column->get_nested_column_ptr().get()); + } else { + result_map_column = assert_cast(result_col.get()); + } + auto& result_col_map_keys_data = result_map_column->get_keys(); + auto& result_col_map_vals_data = result_map_column->get_values(); + ColumnArray::Offsets64& column_offsets = result_map_column->get_offsets(); + column_offsets.resize(input_rows_count); + + if (result_nullable_column) { + auto& null_map_data = result_nullable_column->get_null_map_data(); + null_map_data.resize_fill(input_rows_count, 0); + } + + size_t off = 0; + for (size_t row = 0; row < input_rows_count; row++) { + for (size_t col : arguments) { + const ColumnMap* map_column = nullptr; + auto src_column = + block.get_by_position(col).column->convert_to_full_column_if_const(); + if (src_column->is_nullable()) { + auto nullable_column = assert_cast(src_column.get()); + map_column = assert_cast( + nullable_column->get_nested_column_ptr().get()); + } else { + map_column = assert_cast(src_column.get()); + } + if (!map_column) { + return Status::RuntimeError("unsupported types for function {}({})", get_name(), + block.get_by_position(col).type->get_name()); + } + const auto& src_column_offsets = map_column->get_offsets(); + const size_t length = src_column_offsets[row] - src_column_offsets[row - 1]; + off += length; + for (size_t i = src_column_offsets[row - 1]; i < src_column_offsets[row]; i++) { + result_col_map_keys_data.insert_from(map_column->get_keys(), i); + result_col_map_vals_data.insert_from(map_column->get_values(), i); + } + } + column_offsets[row] = off; + } + RETURN_IF_ERROR(result_map_column->deduplicate_keys()); + block.replace_by_position(result, std::move(result_col)); + return Status::OK(); + } +}; + void register_function_map(SimpleFunctionFactory& factory) { factory.register_function(); factory.register_function>(); @@ -801,6 +874,7 @@ void register_function_map(SimpleFunctionFactory& factory) { factory.register_function(); factory.register_function(); factory.register_function(); + factory.register_function(); factory.register_function(); } diff --git a/be/test/vec/function/function_map_concat_test.cpp b/be/test/vec/function/function_map_concat_test.cpp new file mode 100644 index 00000000000000..6243aa799458e3 --- /dev/null +++ b/be/test/vec/function/function_map_concat_test.cpp @@ -0,0 +1,282 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include + +#include +#include + +#include "function_test_util.h" +#include "vec/core/types.h" + +namespace doris::vectorized { +TEST(FunctionMapConcatTest, TestBase) { + const std::string func_name = "map_concat"; + { + // simple case - two maps + InputTypeSet input_types = {PrimitiveType::TYPE_MAP, PrimitiveType::TYPE_INT, + PrimitiveType::TYPE_STRING, PrimitiveType::TYPE_MAP, + PrimitiveType::TYPE_INT, PrimitiveType::TYPE_STRING}; + DataSet data_set = {{TestArray({TestArray({std::int32_t(1), std::string("A")}), + TestArray({std::int32_t(2), std::string("B")})}), + TestArray({std::int32_t(1), std::string("A"), std::int32_t(2), + std::string("B")})}}; + + check_function_all_arg_comb(func_name, input_types, data_set); + // EXPECT_TRUE(status.ok()) << "Function test failed: " << status.to_string(); + } + { + // two maps concatenation with string keys + InputTypeSet input_types = {PrimitiveType::TYPE_MAP, PrimitiveType::TYPE_STRING, + PrimitiveType::TYPE_STRING, PrimitiveType::TYPE_MAP, + PrimitiveType::TYPE_STRING, PrimitiveType::TYPE_STRING}; + DataSet data_set = {{TestArray({TestArray({std::string("A"), std::string("a")}), + TestArray({std::string("B"), std::string("b")})}), + TestArray({std::string("A"), std::string("a"), std::string("B"), + std::string("b")})}}; + check_function_all_arg_comb(func_name, input_types, data_set); + } + { + // two maps concatenation with string keys + InputTypeSet input_types = {PrimitiveType::TYPE_MAP, PrimitiveType::TYPE_STRING, + PrimitiveType::TYPE_STRING, PrimitiveType::TYPE_MAP, + PrimitiveType::TYPE_STRING, PrimitiveType::TYPE_STRING}; + DataSet data_set = {{TestArray({TestArray({std::string("A"), std::string("a")}), + TestArray({std::string("B"), std::string("b")})}), + TestArray({std::string("A"), std::string("a"), std::string("B"), + std::string("b")})}, + {TestArray({TestArray({std::string("C"), std::string("c")}), + TestArray({std::string("D"), std::string("d")})}), + TestArray({std::string("C"), std::string("c"), std::string("D"), + std::string("d")})}}; + check_function_all_arg_comb(func_name, input_types, data_set); + } +} + +TEST(FunctionMapConcatTest, TestEdgeCases) { + const std::string func_name = "map_concat"; + + // Test empty maps + { + InputTypeSet input_types = {PrimitiveType::TYPE_MAP, PrimitiveType::TYPE_INT, + PrimitiveType::TYPE_STRING, PrimitiveType::TYPE_MAP, + PrimitiveType::TYPE_INT, PrimitiveType::TYPE_STRING}; + DataSet data_set = { + {TestArray({TestArray({}), TestArray({})}), TestArray({})}, + {TestArray({TestArray({std::int32_t(1), std::string("A")}), TestArray({})}), + TestArray({std::int32_t(1), std::string("A")})}, + {TestArray({TestArray({}), TestArray({std::int32_t(2), std::string("B")})}), + TestArray({std::int32_t(2), std::string("B")})}}; + + check_function_all_arg_comb(func_name, input_types, data_set); + } + + // Test key conflicts (later map should override earlier) + { + InputTypeSet input_types = {PrimitiveType::TYPE_MAP, PrimitiveType::TYPE_INT, + PrimitiveType::TYPE_STRING, PrimitiveType::TYPE_MAP, + PrimitiveType::TYPE_INT, PrimitiveType::TYPE_STRING}; + DataSet data_set = {{TestArray({TestArray({std::int32_t(1), std::string("A"), + std::int32_t(2), std::string("B")}), + TestArray({std::int32_t(1), std::string("C"), + std::int32_t(3), std::string("D")})}), + TestArray({std::int32_t(2), std::string("B"), std::int32_t(1), + std::string("C"), std::int32_t(3), std::string("D")})}}; + + check_function_all_arg_comb(func_name, input_types, data_set); + } + + // Test multiple maps (more than 2) + { + InputTypeSet input_types = { + PrimitiveType::TYPE_MAP, PrimitiveType::TYPE_INT, PrimitiveType::TYPE_STRING, + PrimitiveType::TYPE_MAP, PrimitiveType::TYPE_INT, PrimitiveType::TYPE_STRING, + PrimitiveType::TYPE_MAP, PrimitiveType::TYPE_INT, PrimitiveType::TYPE_STRING}; + DataSet data_set = {{TestArray({TestArray({std::int32_t(1), std::string("A")}), + TestArray({std::int32_t(2), std::string("B")}), + TestArray({std::int32_t(3), std::string("C")})}), + TestArray({std::int32_t(1), std::string("A"), std::int32_t(2), + std::string("B"), std::int32_t(3), std::string("C")})}}; + + check_function_all_arg_comb(func_name, input_types, data_set); + } + + // Test different value types + { + InputTypeSet input_types = {PrimitiveType::TYPE_MAP, PrimitiveType::TYPE_STRING, + PrimitiveType::TYPE_INT, PrimitiveType::TYPE_MAP, + PrimitiveType::TYPE_STRING, PrimitiveType::TYPE_INT}; + DataSet data_set = {{TestArray({TestArray({std::string("A"), std::int32_t(1), + std::string("B"), std::int32_t(2)}), + TestArray({std::string("C"), std::int32_t(3), + std::string("D"), std::int32_t(4)})}), + TestArray({std::string("A"), std::int32_t(1), std::string("B"), + std::int32_t(2), std::string("C"), std::int32_t(3), + std::string("D"), std::int32_t(4)})}}; + + check_function_all_arg_comb(func_name, input_types, data_set); + } + + // Test single map (should return the same map) + { + InputTypeSet input_types = {PrimitiveType::TYPE_MAP, PrimitiveType::TYPE_INT, + PrimitiveType::TYPE_STRING}; + TestArray mp_src_array = + TestArray({std::int32_t(1), std::string("A"), std::int32_t(2), std::string("B")}); + TestArray src; + src.push_back(mp_src_array); + TestArray mp_dest_array = + TestArray({std::int32_t(1), std::string("A"), std::int32_t(2), std::string("B")}); + DataSet data_set = {{src, mp_dest_array}}; + + check_function_all_arg_comb(func_name, input_types, data_set); + } +} + +TEST(FunctionMapConcatTest, TestWithNULL) { + const std::string func_name = "map_concat"; + // Test with null map (one map is null) + { + InputTypeSet input_types = {PrimitiveType::TYPE_MAP, PrimitiveType::TYPE_INT, + PrimitiveType::TYPE_STRING, PrimitiveType::TYPE_MAP, + PrimitiveType::TYPE_INT, PrimitiveType::TYPE_STRING}; + DataSet data_set = {{TestArray({Null(), TestArray({std::int32_t(1), std::string("A"), + std::int32_t(2), std::string("B")})}), + Null()}, + {TestArray({TestArray({std::int32_t(1), std::string("A"), + std::int32_t(2), std::string("B")}), + Null()}), + Null()}, + {TestArray({Null(), Null()}), Null()}}; + + check_function_all_arg_comb(func_name, input_types, data_set); + } + + // Test with null values + { + InputTypeSet input_types = {PrimitiveType::TYPE_MAP, PrimitiveType::TYPE_INT, + PrimitiveType::TYPE_STRING, PrimitiveType::TYPE_MAP, + PrimitiveType::TYPE_INT, PrimitiveType::TYPE_STRING}; + DataSet data_set = { + {TestArray({TestArray({std::int32_t(1), std::string("A"), std::int32_t(2), Null()}), + TestArray({std::int32_t(2), std::string("B"), std::int32_t(3), + std::string("C")})}), + TestArray({std::int32_t(1), std::string("A"), std::int32_t(2), std::string("B"), + std::int32_t(3), std::string("C")})}}; + + check_function_all_arg_comb(func_name, input_types, data_set); + } +} + +TEST(FunctionMapConcatTest, TestComplexTypes) { + const std::string func_name = "map_concat"; + + // Test with nested types (array as value) + { + InputTypeSet input_types = {PrimitiveType::TYPE_MAP, PrimitiveType::TYPE_STRING, + PrimitiveType::TYPE_ARRAY, PrimitiveType::TYPE_INT, + PrimitiveType::TYPE_MAP, PrimitiveType::TYPE_STRING, + PrimitiveType::TYPE_ARRAY, PrimitiveType::TYPE_INT}; + DataSet data_set = { + {TestArray({TestArray({std::string("A"), + TestArray({std::int32_t(1), std::int32_t(2)}), + std::string("B"), + TestArray({std::int32_t(3), std::int32_t(5)})}), + TestArray({std::string("C"), + TestArray({std::int32_t(4), std::int32_t(5)})})}), + TestArray({std::string("A"), TestArray({std::int32_t(1), std::int32_t(2)}), + std::string("B"), TestArray({std::int32_t(3), std::int32_t(5)}), + std::string("C"), TestArray({std::int32_t(4), std::int32_t(5)})})}}; + + check_function_all_arg_comb(func_name, input_types, data_set); + } + + // Test with map as value + { + InputTypeSet input_types = {PrimitiveType::TYPE_MAP, PrimitiveType::TYPE_STRING, + PrimitiveType::TYPE_MAP, PrimitiveType::TYPE_STRING, + PrimitiveType::TYPE_INT, PrimitiveType::TYPE_MAP, + PrimitiveType::TYPE_STRING, PrimitiveType::TYPE_MAP, + PrimitiveType::TYPE_STRING, PrimitiveType::TYPE_INT}; + DataSet data_set = { + {TestArray( + {TestArray( + {std::string("outer1"), + TestArray({TestArray({std::string("inner1"), std::int32_t(1)})}), + std::string("outer2"), + TestArray( + {TestArray({std::string("inner2"), std::int32_t(2)})})}), + TestArray({std::string("outer3"), + TestArray({TestArray( + {std::string("inner3"), std::int32_t(3)})})})}), + TestArray({std::string("outer1"), + TestArray({TestArray({std::string("inner1"), std::int32_t(1)})}), + std::string("outer2"), + TestArray({TestArray({std::string("inner2"), std::int32_t(2)})}), + std::string("outer3"), + TestArray({TestArray({std::string("inner3"), std::int32_t(3)})})})}}; + + check_function_all_arg_comb(func_name, input_types, data_set); + } + + // Test with double as value + { + InputTypeSet input_types = {PrimitiveType::TYPE_MAP, PrimitiveType::TYPE_STRING, + PrimitiveType::TYPE_DOUBLE, PrimitiveType::TYPE_MAP, + PrimitiveType::TYPE_STRING, PrimitiveType::TYPE_DOUBLE}; + DataSet data_set = { + {TestArray({TestArray({std::string("key1"), 1.5, std::string("key2"), 2.7}), + TestArray({std::string("key3"), 3.9})}), + TestArray({std::string("key1"), 1.5, std::string("key2"), 2.7, std::string("key3"), + 3.9})}}; + + check_function_all_arg_comb(func_name, input_types, data_set); + } + + // Test with float as value + { + InputTypeSet input_types = {PrimitiveType::TYPE_MAP, PrimitiveType::TYPE_STRING, + PrimitiveType::TYPE_FLOAT, PrimitiveType::TYPE_MAP, + PrimitiveType::TYPE_STRING, PrimitiveType::TYPE_FLOAT}; + DataSet data_set = { + {TestArray({TestArray({std::string("key1"), 1.5f, std::string("key2"), 2.7f}), + TestArray({std::string("key3"), 3.9f})}), + TestArray({std::string("key1"), 1.5f, std::string("key2"), 2.7f, + std::string("key3"), 3.9f})}}; + + check_function_all_arg_comb(func_name, input_types, data_set); + } + + // Test with decimalv2 as value + { + InputTypeSet input_types = {PrimitiveType::TYPE_MAP, PrimitiveType::TYPE_STRING, + PrimitiveType::TYPE_DECIMALV2, PrimitiveType::TYPE_MAP, + PrimitiveType::TYPE_STRING, PrimitiveType::TYPE_DECIMALV2}; + DataSet data_set = { + {TestArray( + {TestArray({std::string("key1"), ut_type::DECIMALV2VALUEFROMDOUBLE(1.5), + std::string("key2"), ut_type::DECIMALV2VALUEFROMDOUBLE(2.7)}), + TestArray( + {std::string("key3"), ut_type::DECIMALV2VALUEFROMDOUBLE(3.9)})}), + TestArray({std::string("key1"), ut_type::DECIMALV2VALUEFROMDOUBLE(1.5), + std::string("key2"), ut_type::DECIMALV2VALUEFROMDOUBLE(2.7), + std::string("key3"), ut_type::DECIMALV2VALUEFROMDOUBLE(3.9)})}}; + + check_function_all_arg_comb(func_name, input_types, data_set); + } +} +} // namespace doris::vectorized diff --git a/be/test/vec/function/function_test_util.cpp b/be/test/vec/function/function_test_util.cpp index 091a3faae4880e..2ab0db5d26527a 100644 --- a/be/test/vec/function/function_test_util.cpp +++ b/be/test/vec/function/function_test_util.cpp @@ -219,14 +219,14 @@ static size_t type_index_to_data_type(const std::vector& input_types, s ut_type::UTDataTypeDesc value_desc; DataTypePtr value_type = nullptr; ++index; - size_t ret = type_index_to_data_type(input_types, index, key_desc, key_type); - if (ret <= 0) { - return ret; + size_t ret_key = type_index_to_data_type(input_types, index, key_desc, key_type); + if (ret_key <= 0) { + return ret_key; } ++index; - ret = type_index_to_data_type(input_types, index, value_desc, value_type); - if (ret <= 0) { - return ret; + size_t ret_value = type_index_to_data_type(input_types, index, value_desc, value_type); + if (ret_value <= 0) { + return ret_value; } if (key_desc.is_nullable) { key_type = make_nullable(key_type); @@ -236,7 +236,7 @@ static size_t type_index_to_data_type(const std::vector& input_types, s } type = std::make_shared(key_type, value_type); desc = type; - return ret + 1; + return ret_key + ret_value + 1; } case PrimitiveType::TYPE_STRUCT: { ++index; @@ -358,6 +358,31 @@ bool insert_array_cell(MutableColumnPtr& column, DataTypePtr type_ptr, const Any return true; } +bool insert_map_cell(MutableColumnPtr& column, DataTypePtr type_ptr, const AnyType& cell, + bool datetime_is_string_format) { + auto origin_input_array = any_cast(cell); + DataTypePtr key_type = assert_cast(type_ptr.get())->get_key_type(); + DataTypePtr value_type = assert_cast(type_ptr.get())->get_value_type(); + MutableColumnPtr key_column = key_type->create_column(); + MutableColumnPtr value_column = value_type->create_column(); + for (size_t i = 0; i < origin_input_array.size(); i += 2) { + const auto& key = origin_input_array[i]; + const auto& value = origin_input_array[i + 1]; + insert_cell(key_column, key_type, key, datetime_is_string_format); + insert_cell(value_column, value_type, value, datetime_is_string_format); + } + Array key_array, value_array; + for (size_t i = 0; i < origin_input_array.size() / 2; i++) { + key_array.push_back((*key_column)[i]); + value_array.push_back((*value_column)[i]); + } + Map map; + map.push_back(Field::create_field(key_array)); + map.push_back(Field::create_field(value_array)); + column->insert(Field::create_field(map)); + return true; +} + // NOLINTBEGIN(readability-function-size) bool insert_cell(MutableColumnPtr& column, DataTypePtr type_ptr, const AnyType& cell, bool datetime_is_string_format) { @@ -547,6 +572,10 @@ bool insert_cell(MutableColumnPtr& column, DataTypePtr type_ptr, const AnyType& } break; } + case PrimitiveType::TYPE_MAP: { + RETURN_IF_FALSE((insert_map_cell(column, type_ptr, cell, datetime_is_string_format))); + break; + } default: { std::cerr << "dataset not supported for type:" << type_to_string(type); return false; diff --git a/be/test/vec/function/function_test_util.h b/be/test/vec/function/function_test_util.h index abe896572d23e8..1c34471a1ec365 100644 --- a/be/test/vec/function/function_test_util.h +++ b/be/test/vec/function/function_test_util.h @@ -60,6 +60,7 @@ #include "vec/data_types/data_type_time.h" #include "vec/data_types/data_type_varbinary.h" #include "vec/exprs/function_context.h" +#include "vec/functions/function_helpers.h" #include "vec/functions/simple_function_factory.h" namespace doris::vectorized { @@ -371,6 +372,19 @@ Status check_function(const std::string& func_name, const InputTypeSet& input_ty } // 2. execute function + auto get_key_value_type = [&]() { + const DataTypeMap* map_type = nullptr; + if (descs[0].data_type->is_nullable()) { + auto* data_type_nullable = + assert_cast(descs[0].data_type.get()); + map_type = check_and_get_data_type( + data_type_nullable->get_nested_type().get()); + } else { + map_type = check_and_get_data_type(descs[0].data_type.get()); + } + assert(map_type); + return std::make_pair(map_type->get_key_type(), map_type->get_value_type()); + }; auto return_type = [&]() { if constexpr (IsDataTypeDecimal) { // decimal return ResultNullable ? make_nullable(std::make_shared(result_precision, @@ -384,6 +398,11 @@ Status check_function(const std::string& func_name, const InputTypeSet& input_ty } return ResultNullable ? make_nullable(std::make_shared(real_scale)) : std::make_shared(real_scale); + } else if constexpr (IsDataTypeMap) { + auto [key_type, value_type] = get_key_value_type(); + return ResultNullable + ? make_nullable(std::make_shared(key_type, value_type)) + : std::make_shared(key_type, value_type); } else { return ResultNullable ? make_nullable(std::make_shared()) : std::make_shared(); @@ -436,6 +455,11 @@ Status check_function(const std::string& func_name, const InputTypeSet& input_ty } result_type_ptr = ResultNullable ? make_nullable(std::make_shared(real_scale)) : std::make_shared(real_scale); + } else if constexpr (IsDataTypeMap) { + auto [key_type, value_type] = get_key_value_type(); + result_type_ptr = + ResultNullable ? make_nullable(std::make_shared(key_type, value_type)) + : std::make_shared(key_type, value_type); } else { result_type_ptr = ResultNullable ? make_nullable(std::make_shared()) : std::make_shared(); diff --git a/fe/fe-core/src/main/java/org/apache/doris/catalog/BuiltinScalarFunctions.java b/fe/fe-core/src/main/java/org/apache/doris/catalog/BuiltinScalarFunctions.java index 19f806b964d6cb..b6a72d9eb8cd68 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/catalog/BuiltinScalarFunctions.java +++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/BuiltinScalarFunctions.java @@ -320,6 +320,7 @@ import org.apache.doris.nereids.trees.expressions.functions.scalar.MakeDate; import org.apache.doris.nereids.trees.expressions.functions.scalar.MakeSet; import org.apache.doris.nereids.trees.expressions.functions.scalar.MakeTime; +import org.apache.doris.nereids.trees.expressions.functions.scalar.MapConcat; import org.apache.doris.nereids.trees.expressions.functions.scalar.MapContainsEntry; import org.apache.doris.nereids.trees.expressions.functions.scalar.MapContainsKey; import org.apache.doris.nereids.trees.expressions.functions.scalar.MapContainsValue; @@ -881,6 +882,7 @@ public class BuiltinScalarFunctions implements FunctionHelper { scalar(MapKeys.class, "map_keys"), scalar(MapSize.class, "map_size"), scalar(MapValues.class, "map_values"), + scalar(MapConcat.class, "map_concat"), scalar(Mask.class, "mask"), scalar(MaskFirstN.class, "mask_first_n"), scalar(MaskLastN.class, "mask_last_n"), diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/MapConcat.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/MapConcat.java new file mode 100644 index 00000000000000..0a497f259b9265 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/MapConcat.java @@ -0,0 +1,115 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.trees.expressions.functions.scalar; + +import org.apache.doris.catalog.FunctionSignature; +import org.apache.doris.nereids.exceptions.AnalysisException; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature; +import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable; +import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; +import org.apache.doris.nereids.types.DataType; +import org.apache.doris.nereids.types.MapType; +import org.apache.doris.nereids.types.NullType; +import org.apache.doris.nereids.util.TypeCoercionUtils; + +import com.google.common.collect.ImmutableList; + +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; + +/** + * ScalarFunction 'map_concat' + */ +public class MapConcat extends ScalarFunction + implements ExplicitlyCastableSignature, PropagateNullable { + + /** + * constructor with more than 0 arguments. + */ + public MapConcat(Expression... varArgs) { + super("map_concat", varArgs); + } + + /** + * private constructor + */ + private MapConcat(ScalarFunctionParams functionParams) { + super(functionParams); + } + + @Override + public MapConcat withChildren(List children) { + return new MapConcat(getFunctionParams(children)); + } + + @Override + public R accept(ExpressionVisitor visitor, C context) { + return visitor.visitMapConcat(this, context); + } + + @Override + public List getSignatures() { + if (arity() == 0) { + return ImmutableList.of(FunctionSignature.ret(MapType.SYSTEM_DEFAULT).args()); + } + + List children = children(); + + List keyTypes = new ArrayList<>(); + List valueTypes = new ArrayList<>(); + + for (Expression child : children) { + DataType argType = child.getDataType(); + if (argType instanceof MapType) { + MapType mapType = (MapType) argType; + keyTypes.add(mapType.getKeyType()); + valueTypes.add(mapType.getValueType()); + } else if (!(argType instanceof NullType)) { + throw new AnalysisException("mapconcat function cannot process" + + "non-map and non-null child elements. " + + "Invalid SQL: " + toSql()); + } + } + Optional commonKeyType = TypeCoercionUtils.findWiderCommonType( + keyTypes, true, true); + Optional commonValueType = TypeCoercionUtils.findWiderCommonType( + valueTypes, true, true); + + if (!commonKeyType.isPresent()) { + throw new AnalysisException("mapconcat cannot find the common key type of " + toSql()); + } + if (!commonValueType.isPresent()) { + throw new AnalysisException("mapconcat cannot find the common value type of " + toSql()); + } + + DataType retMapType = MapType.of(commonKeyType.get(), commonValueType.get()); + ImmutableList.Builder retArgTypes = ImmutableList.builder(); + for (int i = 0; i < children.size(); i++) { + DataType argType = children.get(i).getDataType(); + if (argType instanceof MapType) { + retArgTypes.add(retMapType); + } else { + retArgTypes.add(argType); + } + } + + return ImmutableList.of(FunctionSignature.of(retMapType, retArgTypes.build())); + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/ScalarFunctionVisitor.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/ScalarFunctionVisitor.java index f09bd61f8c36f3..e871904788971e 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/ScalarFunctionVisitor.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/ScalarFunctionVisitor.java @@ -328,6 +328,7 @@ import org.apache.doris.nereids.trees.expressions.functions.scalar.MakeDate; import org.apache.doris.nereids.trees.expressions.functions.scalar.MakeSet; import org.apache.doris.nereids.trees.expressions.functions.scalar.MakeTime; +import org.apache.doris.nereids.trees.expressions.functions.scalar.MapConcat; import org.apache.doris.nereids.trees.expressions.functions.scalar.MapContainsEntry; import org.apache.doris.nereids.trees.expressions.functions.scalar.MapContainsKey; import org.apache.doris.nereids.trees.expressions.functions.scalar.MapContainsValue; @@ -2762,6 +2763,10 @@ default R visitPeriodDiff(PeriodDiff periodDiff, C context) { return visitScalarFunction(periodDiff, context); } + default R visitMapConcat(MapConcat mapConcat, C context) { + return visitScalarFunction(mapConcat, context); + } + default R visitUnicodeNormalize(UnicodeNormalize func, C context) { return visitScalarFunction(func, context); } diff --git a/regression-test/data/function_p0/test_map_concat.out b/regression-test/data/function_p0/test_map_concat.out new file mode 100644 index 00000000000000..6b79850984676f --- /dev/null +++ b/regression-test/data/function_p0/test_map_concat.out @@ -0,0 +1,41 @@ +-- This file is automatically generated. You should know what you did if you want to edit this +-- !sql -- +{} + +-- !sql -- +{"single":"argument"} + +-- !sql -- +{"a":"apple", "b":"banana", "c":"cherry"} + +-- !sql -- +1 {"a":"apple", "b":"banana", "x":"10", "y":"20", "1":"one", "2":"two"} +2 {"c":"cherry", "z":"30", "3":"three"} +3 {} +4 \N + +-- !sql -- +{"a":"apple", "b":"banana"} \N \N + +-- !sql -- +{"a":"apple", "b":"blueberry", "c":"cherry"} + +-- !sql -- +4 ["x", "y"] + +-- !sql -- +{1:"one", 2:"two", 3:"three", 4:"four"} {"bigint1":10000000000, "bigint2":20000000000, "bigint3":30000000000} {"double1":1.2345678901, "double2":2.3456789012, "double3":3.4567890123} {"flag1":1, "flag2":0} {"arr1":[1, 2, 3], "arr2":[4, 5, 6]} {"decimal1":123.456, "decimal2":789.012, "decimal3":345.678} {"date1":"2023-01-01 00:00:00", "date2":"2023-12-31 00:00:00", "timestamp1":"2023-01-01 12:00:00"} {"int_val":100, "bigint_val":200} {"decimal_val":123.456, "double_val":789.0119999999999} {"中文键":"中文值", "key with emoji 🔥":"value with emoji 🚀", "key with accents café":"value with accents naïve"} + +-- !sql -- +{"a":1, "b":2, "c":3, "d":4} {"a":"1", "b":"2"} + +-- !sql -- +1 {"1":"one", "2":"two", "a":"apple", "b":"banana"} +2 {"3":"three", "c":"cherry"} +3 {} +4 \N + +-- !sql -- +1 {"a":"apple", "b":"banana", "x":"10", "y":"20"} +2 {"c":"cherry", "z":"30"} + diff --git a/regression-test/suites/function_p0/test_map_concat.groovy b/regression-test/suites/function_p0/test_map_concat.groovy new file mode 100644 index 00000000000000..43586da6d75aae --- /dev/null +++ b/regression-test/suites/function_p0/test_map_concat.groovy @@ -0,0 +1,140 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +suite("test_map_concat") { + sql """ set enable_nereids_planner=true; """ + sql """ set enable_fallback_to_original_planner=false; """ + + def testTable = "test_map_concat_table" + + sql """ + drop table if exists ${testTable}; + """ + + sql """ + CREATE TABLE ${testTable} ( + id INT, + map1 MAP, + map2 MAP, + map3 MAP + ) ENGINE=OLAP + DUPLICATE KEY(id) + DISTRIBUTED BY HASH(id) BUCKETS 1 + PROPERTIES ( + "replication_allocation" = "tag.location.default: 1" + ); + """ + + sql """ + insert into ${testTable} values + (1, {'a': 'apple', 'b': 'banana'}, {'x': 10, 'y': 20}, {1: 'one', 2: 'two'}), + (2, {'c': 'cherry'}, {'z': 30}, {3: 'three'}), + (3, {}, {}, {}), + (4, NULL, NULL, NULL); + """ + + qt_sql """ + select map_concat() as empty_map; + """ + + qt_sql """ + select map_concat(map('single', 'argument')) as single_argument; + """ + + qt_sql """ + select map_concat({'a': 'apple'}, {'b': 'banana'}, {'c': 'cherry'}) as literal_maps_merged; + """ + + qt_sql """ + select id, map_concat(map1, map2, map3) as all_maps_merged from ${testTable} order by id; + """ + + qt_sql """ + select + map_concat(map1, {}) as merged_with_empty, + map_concat(map1, NULL) as map_with_null, + map_concat(NULL, NULL) as null_with_null + from ${testTable} where id = 1; + """ + + qt_sql """ + select map_concat({'a': 'apple', 'b': 'banana'}, {'b': 'blueberry', 'c': 'cherry'}) as conflict_resolution; + """ + + qt_sql """ + select + map_size(map_concat(map1, map2)) as two_maps_size, + map_keys(map_concat({'x': 10}, {'y': 20})) as keys_result + from ${testTable} where id = 1; + """ + + qt_sql """ + select + map_concat({1: 'one', 2: 'two'}, {3: 'three', 4: 'four'}) as int_key_maps, + map_concat( + CAST({'bigint1': 10000000000, 'bigint2': 20000000000} AS MAP), + CAST({'bigint3': 30000000000} AS MAP) + ) as bigint_values, + map_concat( + CAST({'double1': 1.2345678901, 'double2': 2.3456789012} AS MAP), + CAST({'double3': 3.4567890123} AS MAP) + ) as double_values, + map_concat({'flag1': true}, {'flag2': false}) as boolean_values, + map_concat({'arr1': [1, 2, 3]}, {'arr2': [4, 5, 6]}) as array_values, + map_concat( + CAST({'decimal1': 123.456, 'decimal2': 789.012} AS MAP), + CAST({'decimal3': 345.678} AS MAP) + ) as decimal_values, + map_concat({'date1': DATE '2023-01-01', 'date2': DATE '2023-12-31'}, + {'timestamp1': TIMESTAMP '2023-01-01 12:00:00'}) as timestamp_values, + map_concat( + CAST({'int_val': 100} AS MAP), + CAST({'bigint_val': 200} AS MAP) + ) as int_bigint_mixed, + map_concat( + CAST({'decimal_val': 123.456} AS MAP), + CAST({'double_val': 789.012} AS MAP) + ) as decimal_double_mixed, + map_concat({'中文键': '中文值', 'key with emoji 🔥': 'value with emoji 🚀'}, + {'key with accents café': 'value with accents naïve'}) as utf8_charset_test; + """ + + qt_sql """ + select + map_concat( + map_concat({'a': 1}, {'b': 2}), + map_concat({'c': 3}, {'d': 4}) + ) as nested_concat, + map_concat( + CAST({'a': 1} AS MAP), + CAST({'b': '2'} AS MAP) + ) as type_mismatch; + """ + + qt_sql """ + select id, map_concat(map3, map1) as merged_different_key_types + from ${testTable} + order by id; + """ + + qt_sql """ + select id, map_concat(map1, map2) as merged + from ${testTable} + where map_size(map_concat(map1, map2)) > 0 + order by id; + """ +}