From 82d0e0b0efd979896999e8686dffeef2bdae6ae3 Mon Sep 17 00:00:00 2001 From: amory Date: Thu, 9 Nov 2023 18:54:23 +0800 Subject: [PATCH] [Improve](map)Map impli cast #26126 (#26654) --- be/src/vec/functions/function_cast.h | 58 ++++++++++++++++-- .../org/apache/doris/catalog/MapType.java | 6 +- .../java/org/apache/doris/catalog/Type.java | 11 ++++ .../org/apache/doris/analysis/CastExpr.java | 2 +- .../java/org/apache/doris/analysis/Expr.java | 7 +++ .../cast_function/test_cast_map_function.out | 48 +++++++++++++++ .../cast_function/test_cast_map_function.out | 48 +++++++++++++++ .../test_cast_map_function.groovy | 60 +++++++++++++++++++ .../test_cast_map_function.groovy | 60 +++++++++++++++++++ 9 files changed, 291 insertions(+), 9 deletions(-) create mode 100644 regression-test/data/nereids_function_p0/cast_function/test_cast_map_function.out create mode 100644 regression-test/data/query_p0/sql_functions/cast_function/test_cast_map_function.out create mode 100644 regression-test/suites/nereids_function_p0/cast_function/test_cast_map_function.groovy create mode 100644 regression-test/suites/query_p0/sql_functions/cast_function/test_cast_map_function.groovy diff --git a/be/src/vec/functions/function_cast.h b/be/src/vec/functions/function_cast.h index 1cfa86d540e268..449e60d13e2ab9 100644 --- a/be/src/vec/functions/function_cast.h +++ b/be/src/vec/functions/function_cast.h @@ -51,6 +51,7 @@ #include "vec/aggregate_functions/aggregate_function.h" #include "vec/columns/column.h" #include "vec/columns/column_array.h" +#include "vec/columns/column_map.h" #include "vec/columns/column_nullable.h" #include "vec/columns/column_string.h" #include "vec/columns/column_struct.h" @@ -1852,13 +1853,57 @@ class FunctionCast final : public IFunctionBase { } //TODO(Amory) . Need support more cast for key , value for map - WrapperType create_map_wrapper(const DataTypePtr& from_type, const DataTypeMap& to_type) const { - switch (from_type->get_type_id()) { - case TypeIndex::String: + WrapperType create_map_wrapper(FunctionContext* context, const DataTypePtr& from_type, + const DataTypeMap& to_type) const { + if (from_type->get_type_id() == TypeIndex::String) { return &ConvertImplGenericFromString::execute; - default: - return create_unsupport_wrapper(from_type->get_name(), to_type.get_name()); } + auto from = check_and_get_data_type(from_type.get()); + if (!from) { + return create_unsupport_wrapper( + fmt::format("CAST AS Map can only be performed between Map types or from " + "String. from type: {}, to type: {}", + from_type->get_name(), to_type.get_name())); + } + DataTypes from_kv_types; + DataTypes to_kv_types; + from_kv_types.reserve(2); + to_kv_types.reserve(2); + from_kv_types.push_back(from->get_key_type()); + from_kv_types.push_back(from->get_value_type()); + to_kv_types.push_back(to_type.get_key_type()); + to_kv_types.push_back(to_type.get_value_type()); + + auto kv_wrappers = get_element_wrappers(context, from_kv_types, to_kv_types); + return [kv_wrappers, from_kv_types, to_kv_types]( + FunctionContext* context, Block& block, const ColumnNumbers& arguments, + const size_t result, size_t /*input_rows_count*/) -> Status { + auto& from_column = block.get_by_position(arguments.front()).column; + auto from_col_map = check_and_get_column(from_column.get()); + if (!from_col_map) { + return Status::RuntimeError("Illegal column {} for function CAST AS MAP", + from_column->get_name()); + } + + Columns converted_columns(2); + ColumnsWithTypeAndName columnsWithTypeAndName(2); + columnsWithTypeAndName[0] = {from_col_map->get_keys_ptr(), from_kv_types[0], ""}; + columnsWithTypeAndName[1] = {from_col_map->get_values_ptr(), from_kv_types[1], ""}; + + for (size_t i = 0; i < 2; ++i) { + ColumnNumbers element_arguments {block.columns()}; + block.insert(columnsWithTypeAndName[i]); + size_t element_result = block.columns(); + block.insert({to_kv_types[i], ""}); + RETURN_IF_ERROR(kv_wrappers[i](context, block, element_arguments, element_result, + columnsWithTypeAndName[i].column->size())); + converted_columns[i] = block.get_by_position(element_result).column; + } + + block.get_by_position(result).column = ColumnMap::create( + converted_columns[0], converted_columns[1], from_col_map->get_offsets_ptr()); + return Status::OK(); + }; } ElementWrappers get_element_wrappers(FunctionContext* context, @@ -2115,7 +2160,8 @@ class FunctionCast final : public IFunctionBase { return create_struct_wrapper(context, from_type, static_cast(*to_type)); case TypeIndex::Map: - return create_map_wrapper(from_type, static_cast(*to_type)); + return create_map_wrapper(context, from_type, + static_cast(*to_type)); case TypeIndex::HLL: return create_hll_wrapper(context, from_type, static_cast(*to_type)); diff --git a/fe/fe-common/src/main/java/org/apache/doris/catalog/MapType.java b/fe/fe-common/src/main/java/org/apache/doris/catalog/MapType.java index e9efc83a8fcbde..691c4d7e04db68 100644 --- a/fe/fe-common/src/main/java/org/apache/doris/catalog/MapType.java +++ b/fe/fe-common/src/main/java/org/apache/doris/catalog/MapType.java @@ -192,8 +192,10 @@ protected String prettyPrint(int lpad) { } public static boolean canCastTo(MapType type, MapType targetType) { - return Type.canCastTo(type.getKeyType(), targetType.getKeyType()) - && Type.canCastTo(type.getValueType(), targetType.getValueType()); + return (targetType.getKeyType().isStringType() && type.getKeyType().isStringType() + || Type.canCastTo(type.getKeyType(), targetType.getKeyType())) + && (Type.canCastTo(type.getValueType(), targetType.getValueType()) + || targetType.getValueType().isStringType() && type.getValueType().isStringType()); } @Override diff --git a/fe/fe-common/src/main/java/org/apache/doris/catalog/Type.java b/fe/fe-common/src/main/java/org/apache/doris/catalog/Type.java index 77ac485c75d4e0..fe2c132b9df0dc 100644 --- a/fe/fe-common/src/main/java/org/apache/doris/catalog/Type.java +++ b/fe/fe-common/src/main/java/org/apache/doris/catalog/Type.java @@ -1967,6 +1967,17 @@ public static boolean matchExactType(Type type1, Type type2, boolean ignorePreci return false; } return matchExactType(((ArrayType) type2).getItemType(), ((ArrayType) type1).getItemType()); + } else if (type2.isMapType()) { + // For types array, we also need to check contains null for case like + // cast(array as array) + if (!((MapType) type2).getIsKeyContainsNull() == ((MapType) type1).getIsKeyContainsNull()) { + return false; + } + if (!((MapType) type2).getIsValueContainsNull() == ((MapType) type1).getIsValueContainsNull()) { + return false; + } + return matchExactType(((MapType) type2).getKeyType(), ((MapType) type1).getKeyType()) + && matchExactType(((MapType) type2).getValueType(), ((MapType) type1).getValueType()); } else { return true; } diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/CastExpr.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/CastExpr.java index 30333953c4b446..cdb55cec3ca144 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/analysis/CastExpr.java +++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/CastExpr.java @@ -268,7 +268,7 @@ private void createComplexTypeCastFunction() { } else if (type.isMapType()) { fn = ScalarFunction.createBuiltin(getFnName(Type.MAP), type, Function.NullableMode.ALWAYS_NULLABLE, - Lists.newArrayList(Type.VARCHAR), false, + Lists.newArrayList(getActualArgTypes(collectChildReturnTypes())[0]), false, "doris::CastFunctions::cast_to_map_val", null, null, true); } else if (type.isStructType()) { fn = ScalarFunction.createBuiltin(getFnName(Type.STRUCT), diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/Expr.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/Expr.java index b9ee7ce15f2b97..d73fa1ee9a2cb2 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/analysis/Expr.java +++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/Expr.java @@ -28,6 +28,7 @@ import org.apache.doris.catalog.Function; import org.apache.doris.catalog.Function.NullableMode; import org.apache.doris.catalog.FunctionSet; +import org.apache.doris.catalog.MapType; import org.apache.doris.catalog.MaterializedIndexMeta; import org.apache.doris.catalog.PrimitiveType; import org.apache.doris.catalog.ScalarFunction; @@ -2489,6 +2490,8 @@ protected Type getActualType(Type originType) { return getActualScalarType(originType); } else if (originType.getPrimitiveType() == PrimitiveType.ARRAY) { return getActualArrayType((ArrayType) originType); + } else if (originType.getPrimitiveType().isMapType()) { + return getActualMapType((MapType) originType); } else { return originType; } @@ -2521,6 +2524,10 @@ protected Type[] getActualArgTypes(Type[] originType) { return Arrays.stream(originType).map(this::getActualType).toArray(Type[]::new); } + private MapType getActualMapType(MapType originMapType) { + return new MapType(getActualType(originMapType.getKeyType()), getActualType(originMapType.getValueType())); + } + private ArrayType getActualArrayType(ArrayType originArrayType) { return new ArrayType(getActualType(originArrayType.getItemType())); } diff --git a/regression-test/data/nereids_function_p0/cast_function/test_cast_map_function.out b/regression-test/data/nereids_function_p0/cast_function/test_cast_map_function.out new file mode 100644 index 00000000000000..c9da5a1c286ce9 --- /dev/null +++ b/regression-test/data/nereids_function_p0/cast_function/test_cast_map_function.out @@ -0,0 +1,48 @@ +-- This file is automatically generated. You should know what you did if you want to edit this +-- !select -- +1 {"aa":1, "b":2, "1234567":77} +2 {"b":12, "123":7777} + +-- !select -- +{} + +-- !select -- +{} + +-- !sql1 -- +\N + +-- !sql2 -- +{"":NULL} + +-- !sql3 -- +{"1":2} + +-- !sql4 -- +{"aa":1, "b":2, "1234567":77} +{"b":12, "123":7777} + +-- !sql5 -- +{"aa":1, "b":2, "1234567":77} +{"b":12, "123":7777} + +-- !sql6 -- +{"aa":1, "b":2, "1234567":77} +{"b":12, "123":97} + +-- !sql7 -- +{"aa":"1", "b":"2", "1234567":"77"} +{"b":"12", "123":"7777"} + +-- !sql8 -- +{NULL:"1", NULL:"2", 1234567:"77"} +{NULL:"12", 123:"7777"} + +-- !sql9 -- +{NULL:1, NULL:2, 1234567:77} +{NULL:12, 123:7777} + +-- !sql10 -- +{NULL:NULL, NULL:NULL, 1234567:NULL} +{NULL:NULL, 123:NULL} + diff --git a/regression-test/data/query_p0/sql_functions/cast_function/test_cast_map_function.out b/regression-test/data/query_p0/sql_functions/cast_function/test_cast_map_function.out new file mode 100644 index 00000000000000..c9da5a1c286ce9 --- /dev/null +++ b/regression-test/data/query_p0/sql_functions/cast_function/test_cast_map_function.out @@ -0,0 +1,48 @@ +-- This file is automatically generated. You should know what you did if you want to edit this +-- !select -- +1 {"aa":1, "b":2, "1234567":77} +2 {"b":12, "123":7777} + +-- !select -- +{} + +-- !select -- +{} + +-- !sql1 -- +\N + +-- !sql2 -- +{"":NULL} + +-- !sql3 -- +{"1":2} + +-- !sql4 -- +{"aa":1, "b":2, "1234567":77} +{"b":12, "123":7777} + +-- !sql5 -- +{"aa":1, "b":2, "1234567":77} +{"b":12, "123":7777} + +-- !sql6 -- +{"aa":1, "b":2, "1234567":77} +{"b":12, "123":97} + +-- !sql7 -- +{"aa":"1", "b":"2", "1234567":"77"} +{"b":"12", "123":"7777"} + +-- !sql8 -- +{NULL:"1", NULL:"2", 1234567:"77"} +{NULL:"12", 123:"7777"} + +-- !sql9 -- +{NULL:1, NULL:2, 1234567:77} +{NULL:12, 123:7777} + +-- !sql10 -- +{NULL:NULL, NULL:NULL, 1234567:NULL} +{NULL:NULL, 123:NULL} + diff --git a/regression-test/suites/nereids_function_p0/cast_function/test_cast_map_function.groovy b/regression-test/suites/nereids_function_p0/cast_function/test_cast_map_function.groovy new file mode 100644 index 00000000000000..e3a4ccdaecc22f --- /dev/null +++ b/regression-test/suites/nereids_function_p0/cast_function/test_cast_map_function.groovy @@ -0,0 +1,60 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +suite("test_cast_map_function", "query") { + sql """ set enable_nereids_planner = true; """ + sql """ set enable_fallback_to_original_planner=false; """ + def tableName = "tbl_test_cast_map_function_nereids" + + sql """DROP TABLE IF EXISTS ${tableName}""" + sql """ + CREATE TABLE IF NOT EXISTS ${tableName} ( + `k1` int(11) NULL COMMENT "", + `k2` Map NOT NULL COMMENT "", + ) ENGINE=OLAP + DUPLICATE KEY(`k1`) + DISTRIBUTED BY HASH(`k1`) BUCKETS 1 + PROPERTIES ( + "replication_allocation" = "tag.location.default: 1", + "storage_format" = "V2" + ) + """ + // insert into with implicit cast + sql """ INSERT INTO ${tableName} VALUES(1, {"aa": 1, "b": 2, "1234567": 77}) """ + sql """ INSERT INTO ${tableName} VALUES(2, {"b":12, "123":7777}) """ + + qt_select """ select * from ${tableName} order by k1; """ + + qt_select " select cast({} as MAP);" + qt_select " select cast(map() as MAP); " + qt_sql1 "select cast(NULL as MAP)" + + // literal NONSTRICT_SUPERTYPE_OF cast + qt_sql2 "select cast({'':''} as MAP);" + qt_sql3 "select cast({1:2} as MAP);" + + // select SUPERTYPE_OF cast + qt_sql4 "select cast(k2 as map) from ${tableName} order by k1;" + + // select NONSTRICT_SUPERTYPE_OF cast , this behavior is same with nested scala type + qt_sql5 "select cast(k2 as map) from ${tableName} order by k1;" + qt_sql6 "select cast(k2 as map) from ${tableName} order by k1;" + qt_sql7 "select cast(k2 as map) from ${tableName} order by k1;" + qt_sql8 "select cast(k2 as map) from ${tableName} order by k1;" + qt_sql9 "select cast(k2 as map) from ${tableName} order by k1;" + qt_sql10 "select cast(k2 as map) from ${tableName} order by k1;" +} diff --git a/regression-test/suites/query_p0/sql_functions/cast_function/test_cast_map_function.groovy b/regression-test/suites/query_p0/sql_functions/cast_function/test_cast_map_function.groovy new file mode 100644 index 00000000000000..021f8096b041bd --- /dev/null +++ b/regression-test/suites/query_p0/sql_functions/cast_function/test_cast_map_function.groovy @@ -0,0 +1,60 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +suite("test_cast_map_function", "query") { + sql """set enable_nereids_planner = false """ + def tableName = "tbl_test_cast_map_function" + // array functions only supported in vectorized engine + + sql """DROP TABLE IF EXISTS ${tableName}""" + sql """ + CREATE TABLE IF NOT EXISTS ${tableName} ( + `k1` int(11) NULL COMMENT "", + `k2` Map NOT NULL COMMENT "", + ) ENGINE=OLAP + DUPLICATE KEY(`k1`) + DISTRIBUTED BY HASH(`k1`) BUCKETS 1 + PROPERTIES ( + "replication_allocation" = "tag.location.default: 1", + "storage_format" = "V2" + ) + """ + // insert into with implicit cast + sql """ INSERT INTO ${tableName} VALUES(1, {"aa": 1, "b": 2, "1234567": 77}) """ + sql """ INSERT INTO ${tableName} VALUES(2, {"b":12, "123":7777}) """ + + qt_select """ select * from ${tableName} order by k1; """ + + qt_select " select cast({} as MAP);" + qt_select " select cast(map() as MAP); " + qt_sql1 "select cast(NULL as MAP)" + + // literal NONSTRICT_SUPERTYPE_OF cast + qt_sql2 "select cast({'':''} as MAP);" + qt_sql3 "select cast({1:2} as MAP);" + + // select SUPERTYPE_OF cast + qt_sql4 "select cast(k2 as map) from ${tableName} order by k1;" + + // select NONSTRICT_SUPERTYPE_OF cast , this behavior is same with nested scala type + qt_sql5 "select cast(k2 as map) from ${tableName} order by k1;" + qt_sql6 "select cast(k2 as map) from ${tableName} order by k1;" + qt_sql7 "select cast(k2 as map) from ${tableName} order by k1;" + qt_sql8 "select cast(k2 as map) from ${tableName} order by k1;" + qt_sql9 "select cast(k2 as map) from ${tableName} order by k1;" + qt_sql10 "select cast(k2 as map) from ${tableName} order by k1;" +}