diff --git a/core/src/main/java/org/apache/druid/math/expr/ExprEval.java b/core/src/main/java/org/apache/druid/math/expr/ExprEval.java index 1c046abce26d..158eb3d9c8de 100644 --- a/core/src/main/java/org/apache/druid/math/expr/ExprEval.java +++ b/core/src/main/java/org/apache/druid/math/expr/ExprEval.java @@ -405,11 +405,26 @@ public static ExprEval ofType(@Nullable ExpressionType type, @Nullable Object va if (value instanceof List) { return bestEffortOf(value); } - return of((String) value); + if (value == null) { + return of(null); + } + return of(String.valueOf(value)); case LONG: - return ofLong((Number) value); + if (value instanceof Number) { + return ofLong((Number) value); + } + if (value instanceof String) { + return ofLong(ExprEval.computeNumber((String) value)); + } + return ofLong(null); case DOUBLE: - return ofDouble((Number) value); + if (value instanceof Number) { + return ofDouble((Number) value); + } + if (value instanceof String) { + return ofDouble(ExprEval.computeNumber((String) value)); + } + return ofDouble(null); case COMPLEX: byte[] bytes = null; if (value instanceof String) { diff --git a/core/src/main/java/org/apache/druid/math/expr/ExpressionProcessing.java b/core/src/main/java/org/apache/druid/math/expr/ExpressionProcessing.java index 73fcf588e445..c62f7f4b95b7 100644 --- a/core/src/main/java/org/apache/druid/math/expr/ExpressionProcessing.java +++ b/core/src/main/java/org/apache/druid/math/expr/ExpressionProcessing.java @@ -48,13 +48,13 @@ public class ExpressionProcessing @VisibleForTesting public static void initializeForTests(@Nullable Boolean allowNestedArrays) { - INSTANCE = new ExpressionProcessingConfig(allowNestedArrays, null); + INSTANCE = new ExpressionProcessingConfig(allowNestedArrays, null, null); } @VisibleForTesting public static void initializeForStrictBooleansTests(boolean useStrict) { - INSTANCE = new ExpressionProcessingConfig(null, useStrict); + INSTANCE = new ExpressionProcessingConfig(null, useStrict, null); } /** 
@@ -81,4 +81,16 @@ public static boolean useStrictBooleans() } return INSTANCE.isUseStrictBooleans(); } + + + public static boolean processArraysAsMultiValueStrings() + { + // this should only be null in a unit test context, in production this will be injected by the null handling module + if (INSTANCE == null) { + throw new IllegalStateException( + "ExpressionProcessing module not initialized, call ExpressionProcessing.initializeForTests()" + ); + } + return INSTANCE.processArraysAsMultiValueStrings(); + } } diff --git a/core/src/main/java/org/apache/druid/math/expr/ExpressionProcessingConfig.java b/core/src/main/java/org/apache/druid/math/expr/ExpressionProcessingConfig.java index f933f8ca6543..1009a9d0c299 100644 --- a/core/src/main/java/org/apache/druid/math/expr/ExpressionProcessingConfig.java +++ b/core/src/main/java/org/apache/druid/math/expr/ExpressionProcessingConfig.java @@ -28,6 +28,9 @@ public class ExpressionProcessingConfig { public static final String NESTED_ARRAYS_CONFIG_STRING = "druid.expressions.allowNestedArrays"; public static final String NULL_HANDLING_LEGACY_LOGICAL_OPS_STRING = "druid.expressions.useStrictBooleans"; + // Coerce arrays to multi value strings + public static final String + PROCESS_ARRAYS_AS_MULTIVALUE_STRINGS_CONFIG_STRING = "druid.expressions.processArraysAsMultiValueStrings"; @JsonProperty("allowNestedArrays") private final boolean allowNestedArrays; @@ -35,10 +38,14 @@ public class ExpressionProcessingConfig @JsonProperty("useStrictBooleans") private final boolean useStrictBooleans; + @JsonProperty("processArraysAsMultiValueStrings") + private final boolean processArraysAsMultiValueStrings; + @JsonCreator public ExpressionProcessingConfig( @JsonProperty("allowNestedArrays") @Nullable Boolean allowNestedArrays, - @JsonProperty("useStrictBooleans") @Nullable Boolean useStrictBooleans + @JsonProperty("useStrictBooleans") @Nullable Boolean useStrictBooleans, + @JsonProperty("processArraysAsMultiValueStrings") @Nullable Boolean 
processArraysAsMultiValueStrings ) { this.allowNestedArrays = allowNestedArrays == null @@ -51,6 +58,10 @@ public ExpressionProcessingConfig( } else { this.useStrictBooleans = useStrictBooleans; } + this.processArraysAsMultiValueStrings + = processArraysAsMultiValueStrings == null + ? Boolean.valueOf(System.getProperty(PROCESS_ARRAYS_AS_MULTIVALUE_STRINGS_CONFIG_STRING, "false")) + : processArraysAsMultiValueStrings; } public boolean allowNestedArrays() @@ -62,4 +73,9 @@ public boolean isUseStrictBooleans() { return useStrictBooleans; } + + public boolean processArraysAsMultiValueStrings() + { + return processArraysAsMultiValueStrings; + } } diff --git a/core/src/main/java/org/apache/druid/math/expr/Function.java b/core/src/main/java/org/apache/druid/math/expr/Function.java index d351f3dd985a..1879d663d411 100644 --- a/core/src/main/java/org/apache/druid/math/expr/Function.java +++ b/core/src/main/java/org/apache/druid/math/expr/Function.java @@ -2976,6 +2976,67 @@ public ExpressionType getOutputType(Expr.InputBindingInspector inspector, List args, Expr.ObjectBinding bindings) + { + return args.get(0).eval(bindings).castTo(ExpressionType.STRING_ARRAY); + } + + @Override + public void validateArguments(List args) + { + if (args.size() != 1) { + throw new IAE("Function[%s] needs exactly 1 argument of type String", name()); + } + IdentifierExpr expr = args.get(0).getIdentifierExprIfIdentifierExpr(); + + if (expr == null) { + throw new IAE( + "Arg %s should be an identifier expression ie refer to columns directaly. 
Use array() instead", + args.get(0).toString() + ); + } + } + + @Nullable + @Override + public ExpressionType getOutputType(Expr.InputBindingInspector inspector, List args) + { + return ExpressionType.STRING_ARRAY; + } + + @Override + public boolean hasArrayInputs() + { + return true; + } + + @Override + public boolean hasArrayOutput() + { + return true; + } + + @Override + public Set getScalarInputs(List args) + { + return Collections.emptySet(); + } + + @Override + public Set getArrayInputs(List args) + { + return ImmutableSet.copyOf(args); + } + } class ArrayConstructorFunction implements Function { @Override @@ -2993,6 +3054,7 @@ public ExprEval apply(List args, Expr.ObjectBinding bindings) Object[] out = new Object[length]; ExpressionType arrayType = null; + for (int i = 0; i < length; i++) { ExprEval evaluated = args.get(i).eval(bindings); arrayType = setArrayOutput(arrayType, out, i, evaluated); diff --git a/core/src/test/java/org/apache/druid/math/expr/FunctionTest.java b/core/src/test/java/org/apache/druid/math/expr/FunctionTest.java index 40680e026509..dfc399b395c9 100644 --- a/core/src/test/java/org/apache/druid/math/expr/FunctionTest.java +++ b/core/src/test/java/org/apache/druid/math/expr/FunctionTest.java @@ -934,6 +934,45 @@ public void testComplexDecodeBaseArg0Unknown() ); } + @Test + public void testMVToArrayWithValidInputs() + { + assertArrayExpr("mv_to_array(x)", new String[]{"foo"}); + assertArrayExpr("mv_to_array(a)", new String[]{"foo", "bar", "baz", "foobar"}); + } + + @Test + public void testMVToArrayWithConstantLiteral() + { + expectedException.expect(IAE.class); + expectedException.expectMessage("should be an identifier expression"); + assertArrayExpr("mv_to_array('1')", null); + } + + @Test + public void testMVToArrayWithFunction() + { + expectedException.expect(IAE.class); + expectedException.expectMessage("should be an identifier expression"); + assertArrayExpr("mv_to_array(repeat('hello', 2))", null); + } + + @Test + public void 
testMVToArrayWithMoreArgs() + { + expectedException.expect(IAE.class); + expectedException.expectMessage("needs exactly 1 argument of type String"); + assertArrayExpr("mv_to_array(x,y)", null); + } + + @Test + public void testMVToArrayWithNoArgs() + { + expectedException.expect(IAE.class); + expectedException.expectMessage("needs exactly 1 argument of type String"); + assertArrayExpr("mv_to_array()", null); + } + private void assertExpr(final String expression, @Nullable final Object expectedResult) { final Expr expr = Parser.parse(expression, ExprMacroTable.nil()); diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/ExpressionLambdaAggregatorFactory.java b/processing/src/main/java/org/apache/druid/query/aggregation/ExpressionLambdaAggregatorFactory.java index 83239839f40c..313264926a62 100644 --- a/processing/src/main/java/org/apache/druid/query/aggregation/ExpressionLambdaAggregatorFactory.java +++ b/processing/src/main/java/org/apache/druid/query/aggregation/ExpressionLambdaAggregatorFactory.java @@ -486,22 +486,18 @@ private class FactorizePlan FactorizePlan(ColumnSelectorFactory metricFactory) { - final List columns; - if (fields != null) { // if fields are set, we are accumulating from raw inputs, use fold expression plan = ExpressionPlanner.plan(inspectorWithAccumulator(metricFactory), foldExpression.get()); seed = initialValue.get(); - columns = plan.getAnalysis().getRequiredBindingsList(); } else { // else we are merging intermediary results, use combine expression plan = ExpressionPlanner.plan(inspectorWithAccumulator(metricFactory), combineExpression.get()); seed = initialCombineValue.get(); - columns = plan.getAnalysis().getRequiredBindingsList(); } bindings = new ExpressionLambdaAggregatorInputBindings( - ExpressionSelectors.createBindings(metricFactory, columns), + ExpressionSelectors.createBindings(metricFactory, plan), accumulatorId, seed ); diff --git a/processing/src/main/java/org/apache/druid/query/groupby/GroupByQuery.java 
b/processing/src/main/java/org/apache/druid/query/groupby/GroupByQuery.java index 938fa46e00cd..f02d3596dea9 100644 --- a/processing/src/main/java/org/apache/druid/query/groupby/GroupByQuery.java +++ b/processing/src/main/java/org/apache/druid/query/groupby/GroupByQuery.java @@ -67,6 +67,8 @@ import org.apache.druid.segment.column.ColumnHolder; import org.apache.druid.segment.column.ColumnType; import org.apache.druid.segment.column.RowSignature; +import org.apache.druid.segment.data.ComparableList; +import org.apache.druid.segment.data.ComparableStringArray; import org.joda.time.DateTime; import org.joda.time.Interval; @@ -775,6 +777,15 @@ private static int compareDimsForLimitPushDown( } else { dimCompare = comparator.compare(String.valueOf(lhsObj), String.valueOf(rhsObj)); } + } else if (dimensionType.equals(ColumnType.STRING_ARRAY)) { + final ComparableStringArray lhsArr = DimensionHandlerUtils.convertToComparableStringArray(lhsObj); + final ComparableStringArray rhsArr = DimensionHandlerUtils.convertToComparableStringArray(rhsObj); + dimCompare = Comparators.naturalNullsFirst().compare(lhsArr, rhsArr); + } else if (dimensionType.equals(ColumnType.LONG_ARRAY) + || dimensionType.equals(ColumnType.DOUBLE_ARRAY)) { + final ComparableList lhsArr = DimensionHandlerUtils.convertToList(lhsObj); + final ComparableList rhsArr = DimensionHandlerUtils.convertToList(rhsObj); + dimCompare = Comparators.naturalNullsFirst().compare(lhsArr, rhsArr); } else { dimCompare = comparator.compare((String) lhsObj, (String) rhsObj); } diff --git a/processing/src/main/java/org/apache/druid/query/groupby/epinephelinae/GroupByQueryEngineV2.java b/processing/src/main/java/org/apache/druid/query/groupby/epinephelinae/GroupByQueryEngineV2.java index f0d834948199..292931c49370 100644 --- a/processing/src/main/java/org/apache/druid/query/groupby/epinephelinae/GroupByQueryEngineV2.java +++ b/processing/src/main/java/org/apache/druid/query/groupby/epinephelinae/GroupByQueryEngineV2.java @@ -42,6 
+42,9 @@ import org.apache.druid.query.groupby.GroupByQuery; import org.apache.druid.query.groupby.GroupByQueryConfig; import org.apache.druid.query.groupby.ResultRow; +import org.apache.druid.query.groupby.epinephelinae.column.ArrayDoubleGroupByColumnSelectorStrategy; +import org.apache.druid.query.groupby.epinephelinae.column.ArrayLongGroupByColumnSelectorStrategy; +import org.apache.druid.query.groupby.epinephelinae.column.ArrayStringGroupByColumnSelectorStrategy; import org.apache.druid.query.groupby.epinephelinae.column.DictionaryBuildingStringGroupByColumnSelectorStrategy; import org.apache.druid.query.groupby.epinephelinae.column.DoubleGroupByColumnSelectorStrategy; import org.apache.druid.query.groupby.epinephelinae.column.FloatGroupByColumnSelectorStrategy; @@ -233,7 +236,7 @@ public GroupByEngineIterator make() processingBuffer, fudgeTimestamp, dims, - isAllSingleValueDims(columnSelectorFactory, query.getDimensions()), + hasNoExplodingDimensions(columnSelectorFactory, query.getDimensions()), cardinalityForArrayAggregation ); } else { @@ -244,7 +247,7 @@ public GroupByEngineIterator make() processingBuffer, fudgeTimestamp, dims, - isAllSingleValueDims(columnSelectorFactory, query.getDimensions()) + hasNoExplodingDimensions(columnSelectorFactory, query.getDimensions()) ); } } @@ -290,6 +293,11 @@ public static int getCardinalityForArrayAggregation( if (query.getVirtualColumns().exists(Iterables.getOnlyElement(dimensions).getDimension())) { return -1; } + // We cannot support array-based aggregation on array based grouping as we do not have all the indexes up front + // to allocate appropriate values + if (dimensions.get(0).getOutputType().equals(ColumnType.STRING_ARRAY)) { + return -1; + } final String columnName = Iterables.getOnlyElement(dimensions).getDimension(); columnCapabilities = storageAdapter.getColumnCapabilities(columnName); @@ -319,11 +327,12 @@ public static int getCardinalityForArrayAggregation( } /** - * Checks whether all "dimensions" are 
either single-valued, or if allowed, nonexistent. Since non-existent column - * selectors will show up as full of nulls they are effectively single valued, however they can also be null during - * broker merge, for example with an 'inline' datasource subquery. + * Checks whether all "dimensions" are either single-valued, + * or STRING_ARRAY, in case we don't want to explode the underlying multi value column, + * or if allowed, nonexistent. Since non-existent column selectors will show up as full of nulls they are effectively + * single valued, however they can also be null during broker merge, for example with an 'inline' datasource subquery. */ - public static boolean isAllSingleValueDims( + public static boolean hasNoExplodingDimensions( final ColumnInspector inspector, final List dimensions ) @@ -340,7 +349,8 @@ public static boolean isAllSingleValueDims( // Now check column capabilities, which must be present and explicitly not multi-valued final ColumnCapabilities columnCapabilities = inspector.getColumnCapabilities(dimension.getDimension()); - return columnCapabilities != null && columnCapabilities.hasMultipleValues().isFalse(); + return dimension.getOutputType().equals(ColumnType.STRING_ARRAY) + || (columnCapabilities != null && columnCapabilities.hasMultipleValues().isFalse()); }); } @@ -403,6 +413,20 @@ public GroupByColumnSelectorStrategy makeColumnSelectorStrategy( return makeNullableNumericStrategy(new FloatGroupByColumnSelectorStrategy()); case DOUBLE: return makeNullableNumericStrategy(new DoubleGroupByColumnSelectorStrategy()); + case ARRAY: + switch (capabilities.getElementType().getType()) { + case LONG: + return new ArrayLongGroupByColumnSelectorStrategy(); + case STRING: + return new ArrayStringGroupByColumnSelectorStrategy(); + case DOUBLE: + return new ArrayDoubleGroupByColumnSelectorStrategy(); + case FLOAT: + // Array not supported in expressions, ingestion + default: + throw new IAE("Cannot create query type helper from invalid type [%s]", 
capabilities.asTypeString()); + + } default: throw new IAE("Cannot create query type helper from invalid type [%s]", capabilities.asTypeString()); } diff --git a/processing/src/main/java/org/apache/druid/query/groupby/epinephelinae/RowBasedGrouperHelper.java b/processing/src/main/java/org/apache/druid/query/groupby/epinephelinae/RowBasedGrouperHelper.java index c185dddd07ef..1f825f503eed 100644 --- a/processing/src/main/java/org/apache/druid/query/groupby/epinephelinae/RowBasedGrouperHelper.java +++ b/processing/src/main/java/org/apache/druid/query/groupby/epinephelinae/RowBasedGrouperHelper.java @@ -72,6 +72,8 @@ import org.apache.druid.segment.column.ColumnType; import org.apache.druid.segment.column.RowSignature; import org.apache.druid.segment.column.ValueType; +import org.apache.druid.segment.data.ComparableList; +import org.apache.druid.segment.data.ComparableStringArray; import org.apache.druid.segment.data.IndexedInts; import org.apache.druid.segment.filter.BooleanValueMatcher; import org.apache.druid.segment.filter.Filters; @@ -750,6 +752,24 @@ public InputRawSupplierColumnSelectorStrategy makeColumnSelectorStrategy( case DOUBLE: return (InputRawSupplierColumnSelectorStrategy) columnSelector -> () -> columnSelector.isNull() ? 
null : columnSelector.getDouble(); + case ARRAY: + switch (capabilities.getElementType().getType()) { + case STRING: + return (InputRawSupplierColumnSelectorStrategy) + columnSelector -> + () -> DimensionHandlerUtils.convertToComparableStringArray(columnSelector.getObject()); + case FLOAT: + case LONG: + case DOUBLE: + return (InputRawSupplierColumnSelectorStrategy) + columnSelector -> + () -> DimensionHandlerUtils.convertToList(columnSelector.getObject()); + default: + throw new IAE( + "Cannot create query type helper from invalid type [%s]", + capabilities.asTypeString() + ); + } default: throw new IAE("Cannot create query type helper from invalid type [%s]", capabilities.asTypeString()); } @@ -1017,6 +1037,15 @@ private static int compareDimsInRows( lhs != null ? ((Number) lhs).doubleValue() : null, rhs != null ? ((Number) rhs).doubleValue() : null ); + } else if (fieldTypes.get(i - dimStart).equals(ColumnType.STRING_ARRAY)) { + final ComparableStringArray lhs = DimensionHandlerUtils.convertToComparableStringArray(key1.getKey()[i]); + final ComparableStringArray rhs = DimensionHandlerUtils.convertToComparableStringArray(key2.getKey()[i]); + cmp = Comparators.naturalNullsFirst().compare(lhs, rhs); + } else if (fieldTypes.get(i - dimStart).equals(ColumnType.LONG_ARRAY) + || fieldTypes.get(i - dimStart).equals(ColumnType.DOUBLE_ARRAY)) { + final ComparableList lhs = DimensionHandlerUtils.convertToList(key1.getKey()[i]); + final ComparableList rhs = DimensionHandlerUtils.convertToList(key2.getKey()[i]); + cmp = Comparators.naturalNullsFirst().compare(lhs, rhs); } else { cmp = Comparators.naturalNullsFirst().compare( (Comparable) key1.getKey()[i], @@ -1046,24 +1075,24 @@ private static int compareDimsInRowsWithAggs( final int fieldIndex = fieldIndices.get(i); final boolean needsReverse = needsReverses.get(i); final int cmp; - final Comparable lhs; - final Comparable rhs; + final Object lhs; + final Object rhs; if (aggFlags.get(i)) { if (needsReverse) { - lhs = 
(Comparable) entry2.getValues()[fieldIndex]; - rhs = (Comparable) entry1.getValues()[fieldIndex]; + lhs = entry2.getValues()[fieldIndex]; + rhs = entry1.getValues()[fieldIndex]; } else { - lhs = (Comparable) entry1.getValues()[fieldIndex]; - rhs = (Comparable) entry2.getValues()[fieldIndex]; + lhs = entry1.getValues()[fieldIndex]; + rhs = entry2.getValues()[fieldIndex]; } } else { if (needsReverse) { - lhs = (Comparable) entry2.getKey().getKey()[fieldIndex + dimStart]; - rhs = (Comparable) entry1.getKey().getKey()[fieldIndex + dimStart]; + lhs = entry2.getKey().getKey()[fieldIndex + dimStart]; + rhs = entry1.getKey().getKey()[fieldIndex + dimStart]; } else { - lhs = (Comparable) entry1.getKey().getKey()[fieldIndex + dimStart]; - rhs = (Comparable) entry2.getKey().getKey()[fieldIndex + dimStart]; + lhs = entry1.getKey().getKey()[fieldIndex + dimStart]; + rhs = entry2.getKey().getKey()[fieldIndex + dimStart]; } } @@ -1080,8 +1109,23 @@ private static int compareDimsInRowsWithAggs( rhs != null ? 
((Number) rhs).doubleValue() : null ); } else { - cmp = Comparators.naturalNullsFirst().compare(lhs, rhs); + cmp = Comparators.naturalNullsFirst().compare((Comparable) lhs, (Comparable) rhs); } + } else if (fieldType.equals(ColumnType.STRING_ARRAY)) { + cmp = ComparableStringArray.compareWithComparator( + comparator, + DimensionHandlerUtils.convertToComparableStringArray(lhs), + DimensionHandlerUtils.convertToComparableStringArray(rhs) + ); + } else if (fieldType.equals(ColumnType.LONG_ARRAY) + || fieldType.equals(ColumnType.DOUBLE_ARRAY)) { + + cmp = ComparableList.compareWithComparator( + comparator, + DimensionHandlerUtils.convertToList(lhs), + DimensionHandlerUtils.convertToList(rhs) + ); + } else { cmp = comparator.compare( DimensionHandlerUtils.convertObjectToString(lhs), @@ -1124,6 +1168,13 @@ private static class RowBasedKeySerde implements Grouper.KeySerde dictionary; private final Object2IntMap reverseDictionary; + private final List arrayDictionary; + private final Object2IntMap reverseArrayDictionary; + + private final List listDictionary; + private final Object2IntMap reverseListDictionary; + + // Size limiting for the dictionary, in (roughly estimated) bytes. 
private final long maxDictionarySize; @@ -1156,7 +1207,17 @@ private static class RowBasedKeySerde implements Grouper.KeySerde() : new Object2IntOpenHashMap<>(dictionary.size()); + + this.arrayDictionary = new ArrayList<>(); + this.reverseArrayDictionary = new Object2IntOpenHashMap<>(); + + this.listDictionary = new ArrayList<>(); + this.reverseListDictionary = new Object2IntOpenHashMap<>(); + this.reverseDictionary.defaultReturnValue(UNKNOWN_DICTIONARY_ID); + this.reverseArrayDictionary.defaultReturnValue(UNKNOWN_DICTIONARY_ID); + this.reverseListDictionary.defaultReturnValue(UNKNOWN_DICTIONARY_ID); + this.maxDictionarySize = maxDictionarySize; this.serdeHelpers = makeSerdeHelpers(limitSpec != null, enableRuntimeDictionaryGeneration); this.serdeHelperComparators = new BufferComparator[serdeHelpers.length]; @@ -1357,6 +1418,21 @@ private RowBasedKeySerdeHelper makeSerdeHelper( ) { switch (valueType.getType()) { + case ARRAY: + switch (valueType.getElementType().getType()) { + case STRING: + return new ArrayStringRowBasedKeySerdeHelper( + keyBufferPosition, + stringComparator + ); + case LONG: + case FLOAT: + case DOUBLE: + return new ArrayNumericRowBasedKeySerdeHelper(keyBufferPosition, stringComparator); + default: + throw new IAE("invalid type: %s", valueType); + } + case STRING: if (enableRuntimeDictionaryGeneration) { return new DynamicDictionaryStringRowBasedKeySerdeHelper( @@ -1426,6 +1502,123 @@ private RowBasedKeySerdeHelper makeNumericSerdeHelper( } } + private class ArrayNumericRowBasedKeySerdeHelper implements RowBasedKeySerdeHelper + { + final int keyBufferPosition; + final BufferComparator bufferComparator; + + public ArrayNumericRowBasedKeySerdeHelper( + int keyBufferPosition, + @Nullable StringComparator stringComparator + ) + { + this.keyBufferPosition = keyBufferPosition; + this.bufferComparator = (lhsBuffer, rhsBuffer, lhsPosition, rhsPosition) -> + ComparableList.compareWithComparator( + stringComparator, + 
listDictionary.get(lhsBuffer.getInt(lhsPosition + + keyBufferPosition)), + listDictionary.get(rhsBuffer.getInt(rhsPosition + + keyBufferPosition)) + ); + } + + @Override + public int getKeyBufferValueSize() + { + return Integer.BYTES; + } + + @Override + public boolean putToKeyBuffer(RowBasedKey key, int idx) + { + final ComparableList comparableList = (ComparableList) key.getKey()[idx]; + int id = reverseListDictionary.getInt(comparableList); + if (id == UNKNOWN_DICTIONARY_ID) { + id = listDictionary.size(); + reverseListDictionary.put(comparableList, id); + listDictionary.add(comparableList); + } + keyBuffer.putInt(id); + return true; + } + + @Override + public void getFromByteBuffer(ByteBuffer buffer, int initialOffset, int dimValIdx, Comparable[] dimValues) + { + dimValues[dimValIdx] = listDictionary.get(buffer.getInt(initialOffset + keyBufferPosition)); + } + + @Override + public BufferComparator getBufferComparator() + { + return bufferComparator; + } + } + + private class ArrayStringRowBasedKeySerdeHelper implements RowBasedKeySerdeHelper + { + final int keyBufferPosition; + final BufferComparator bufferComparator; + + ArrayStringRowBasedKeySerdeHelper( + int keyBufferPosition, + @Nullable StringComparator stringComparator + ) + { + this.keyBufferPosition = keyBufferPosition; + bufferComparator = (lhsBuffer, rhsBuffer, lhsPosition, rhsPosition) -> + ComparableStringArray.compareWithComparator( + stringComparator, + arrayDictionary.get(lhsBuffer.getInt(lhsPosition + + keyBufferPosition)), + arrayDictionary.get(rhsBuffer.getInt(rhsPosition + + keyBufferPosition)) + ); + } + + @Override + public int getKeyBufferValueSize() + { + return Integer.BYTES; + } + + @Override + public boolean putToKeyBuffer(RowBasedKey key, int idx) + { + ComparableStringArray comparableStringArray = (ComparableStringArray) key.getKey()[idx]; + final int id = addToArrayDictionary(comparableStringArray); + if (id < 0) { + return false; + } + keyBuffer.putInt(id); + return true; + } + + 
@Override + public void getFromByteBuffer(ByteBuffer buffer, int initialOffset, int dimValIdx, Comparable[] dimValues) + { + dimValues[dimValIdx] = arrayDictionary.get(buffer.getInt(initialOffset + keyBufferPosition)); + } + + @Override + public BufferComparator getBufferComparator() + { + return bufferComparator; + } + + private int addToArrayDictionary(final ComparableStringArray s) + { + int idx = reverseArrayDictionary.getInt(s); + if (idx == UNKNOWN_DICTIONARY_ID) { + idx = arrayDictionary.size(); + reverseArrayDictionary.put(s, idx); + arrayDictionary.add(s); + } + return idx; + } + } + private abstract class AbstractStringRowBasedKeySerdeHelper implements RowBasedKeySerdeHelper { final int keyBufferPosition; @@ -1502,7 +1695,6 @@ public boolean putToKeyBuffer(RowBasedKey key, int idx) * this returns -1. * * @param s a string - * * @return id for this string, or -1 */ private int addToDictionary(final String s) diff --git a/processing/src/main/java/org/apache/druid/query/groupby/epinephelinae/column/ArrayDoubleGroupByColumnSelectorStrategy.java b/processing/src/main/java/org/apache/druid/query/groupby/epinephelinae/column/ArrayDoubleGroupByColumnSelectorStrategy.java new file mode 100644 index 000000000000..46cfe0d9c09e --- /dev/null +++ b/processing/src/main/java/org/apache/druid/query/groupby/epinephelinae/column/ArrayDoubleGroupByColumnSelectorStrategy.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.query.groupby.epinephelinae.column; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.ImmutableList; +import it.unimi.dsi.fastutil.objects.Object2IntOpenHashMap; +import org.apache.druid.java.util.common.ISE; +import org.apache.druid.segment.ColumnValueSelector; + +import java.util.Arrays; +import java.util.List; +import java.util.stream.Collectors; + +public class ArrayDoubleGroupByColumnSelectorStrategy extends ArrayNumericGroupByColumnSelectorStrategy +{ + public ArrayDoubleGroupByColumnSelectorStrategy() + { + + } + + @VisibleForTesting + ArrayDoubleGroupByColumnSelectorStrategy( + List> dictionary, + Object2IntOpenHashMap> reverseDictionary + ) + { + super(dictionary, reverseDictionary); + } + + @Override + public Object getOnlyValue(ColumnValueSelector selector) + { + Object object = selector.getObject(); + if (object == null) { + return GROUP_BY_MISSING_VALUE; + } else if (object instanceof Double) { + return addToIndexedDictionary(ImmutableList.of((Double) object)); + } else if (object instanceof List) { + return addToIndexedDictionary((List) object); + } else if (object instanceof Double[]) { + return addToIndexedDictionary(Arrays.asList((Double[]) object)); + } else if (object instanceof Object[]) { + return addToIndexedDictionary(Arrays.stream(((Object[]) (object))) + .map(a -> (Double) a) + .collect(Collectors.toList())); + } else { + throw new ISE("Found unknowm type %s in ColumnValueSelector.", object.getClass().toString()); + } + } +} + + diff --git 
a/processing/src/main/java/org/apache/druid/query/groupby/epinephelinae/column/ArrayLongGroupByColumnSelectorStrategy.java b/processing/src/main/java/org/apache/druid/query/groupby/epinephelinae/column/ArrayLongGroupByColumnSelectorStrategy.java new file mode 100644 index 000000000000..ff137c68264c --- /dev/null +++ b/processing/src/main/java/org/apache/druid/query/groupby/epinephelinae/column/ArrayLongGroupByColumnSelectorStrategy.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.druid.query.groupby.epinephelinae.column; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.ImmutableList; +import it.unimi.dsi.fastutil.objects.Object2IntOpenHashMap; +import org.apache.druid.java.util.common.ISE; +import org.apache.druid.segment.ColumnValueSelector; + +import java.util.Arrays; +import java.util.List; +import java.util.stream.Collectors; + +public class ArrayLongGroupByColumnSelectorStrategy extends ArrayNumericGroupByColumnSelectorStrategy +{ + + public ArrayLongGroupByColumnSelectorStrategy() + { + + } + + @VisibleForTesting + ArrayLongGroupByColumnSelectorStrategy( + List> dictionary, + Object2IntOpenHashMap> reverseDictionary + ) + { + super(dictionary, reverseDictionary); + } + + + @Override + public Object getOnlyValue(ColumnValueSelector selector) + { + Object object = selector.getObject(); + if (object == null) { + return GROUP_BY_MISSING_VALUE; + } else if (object instanceof Long) { + return addToIndexedDictionary(ImmutableList.of((Long) object)); + } else if (object instanceof List) { + return addToIndexedDictionary((List) object); + } else if (object instanceof Long[]) { + return addToIndexedDictionary(Arrays.asList((Long[]) object)); + } else if (object instanceof Object[]) { + return addToIndexedDictionary(Arrays.stream(((Object[]) (object))) + .map(a -> (Long) a) + .collect(Collectors.toList())); + } else { + throw new ISE("Found unknowm type %s in ColumnValueSelector.", object.getClass().toString()); + } + } +} diff --git a/processing/src/main/java/org/apache/druid/query/groupby/epinephelinae/column/ArrayNumericGroupByColumnSelectorStrategy.java b/processing/src/main/java/org/apache/druid/query/groupby/epinephelinae/column/ArrayNumericGroupByColumnSelectorStrategy.java new file mode 100644 index 000000000000..55dc29feda7e --- /dev/null +++ b/processing/src/main/java/org/apache/druid/query/groupby/epinephelinae/column/ArrayNumericGroupByColumnSelectorStrategy.java 
@@ -0,0 +1,181 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.query.groupby.epinephelinae.column;
+
+import com.google.common.annotations.VisibleForTesting;
+import it.unimi.dsi.fastutil.objects.Object2IntOpenHashMap;
+import org.apache.druid.query.groupby.ResultRow;
+import org.apache.druid.query.groupby.epinephelinae.Grouper;
+import org.apache.druid.query.ordering.StringComparator;
+import org.apache.druid.query.ordering.StringComparators;
+import org.apache.druid.segment.ColumnValueSelector;
+import org.apache.druid.segment.data.ComparableList;
+
+import javax.annotation.Nullable;
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.List;
+
+/**
+ * Base {@link GroupByColumnSelectorStrategy} for grouping on array-valued columns.
+ *
+ * Every distinct array is assigned an int id: {@link #dictionary} maps id to array (the id is the list index) and
+ * {@link #reverseDictionary} maps array to id. Only the id is written into the grouping key, so the key occupies
+ * {@link Integer#BYTES} bytes. A row whose selector returned null is encoded as {@link #GROUP_BY_MISSING_VALUE}.
+ *
+ * @param <T> element type of the grouped arrays
+ */
+public abstract class ArrayNumericGroupByColumnSelectorStrategy<T extends Comparable> implements GroupByColumnSelectorStrategy
+{
+  protected static final int GROUP_BY_MISSING_VALUE = -1;
+
+  protected final List<List<T>> dictionary;
+  protected final Object2IntOpenHashMap<List<T>> reverseDictionary;
+
+  public ArrayNumericGroupByColumnSelectorStrategy()
+  {
+    dictionary = new ArrayList<>();
+    reverseDictionary = new Object2IntOpenHashMap<>();
+    // a negative lookup result is how addToIndexedDictionary recognizes an array it has not seen yet
+    reverseDictionary.defaultReturnValue(-1);
+  }
+
+  /**
+   * Test-only constructor that adopts the supplied collections as-is. {@code reverseDictionary} must already report
+   * a negative default return value for absent keys, because {@link #addToIndexedDictionary} relies on it.
+   */
+  @VisibleForTesting
+  ArrayNumericGroupByColumnSelectorStrategy(
+      List<List<T>> dictionary,
+      Object2IntOpenHashMap<List<T>> reverseDictionary
+  )
+  {
+    this.dictionary = dictionary;
+    this.reverseDictionary = reverseDictionary;
+  }
+
+  @Override
+  public int getGroupingKeySize()
+  {
+    return Integer.BYTES;
+  }
+
+  /**
+   * Decodes the dictionary id stored at {@code keyBufferPosition} and writes the array it stands for, wrapped in a
+   * {@link ComparableList}, into the result row. A missing-value id is written as null.
+   */
+  @Override
+  public void processValueFromGroupingKey(
+      GroupByColumnSelectorPlus selectorPlus,
+      ByteBuffer key,
+      ResultRow resultRow,
+      int keyBufferPosition
+  )
+  {
+    final int id = key.getInt(keyBufferPosition);
+
+    // GROUP_BY_MISSING_VALUE marks a row that had no array; it surfaces as a null dimension value.
+    if (id != GROUP_BY_MISSING_VALUE) {
+      final List<T> value = dictionary.get(id);
+      resultRow.set(selectorPlus.getResultRowPosition(), new ComparableList(value));
+    } else {
+      resultRow.set(selectorPlus.getResultRowPosition(), null);
+    }
+  }
+
+  @Override
+  public void initColumnValues(ColumnValueSelector selector, int columnIndex, Object[] valuess)
+  {
+    final int groupingKey = (int) getOnlyValue(selector);
+    valuess[columnIndex] = groupingKey;
+  }
+
+  /**
+   * Writes the dictionary id held in {@code rowObj} into the key buffer and records in {@code stack} whether the row
+   * carries a value for this column (1) or not (0).
+   */
+  @Override
+  public void initGroupingKeyColumnValue(
+      int keyBufferPosition,
+      int columnIndex,
+      Object rowObj,
+      ByteBuffer keyBuffer,
+      int[] stack
+  )
+  {
+    final int groupingKey = (int) rowObj;
+    writeToKeyBuffer(keyBufferPosition, groupingKey, keyBuffer);
+    stack[columnIndex] = groupingKey == GROUP_BY_MISSING_VALUE ? 0 : 1;
+  }
+
+  @Override
+  public boolean checkRowIndexAndAddValueToGroupingKey(
+      int keyBufferPosition,
+      Object rowObj,
+      int rowValIdx,
+      ByteBuffer keyBuffer
+  )
+  {
+    // always false: this strategy never adds a further value for the same row
+    return false;
+  }
+
+  @Override
+  public abstract Object getOnlyValue(ColumnValueSelector selector);
+
+
+  @Override
+  public void writeToKeyBuffer(int keyBufferPosition, Object obj, ByteBuffer keyBuffer)
+  {
+    keyBuffer.putInt(keyBufferPosition, (int) obj);
+  }
+
+  /**
+   * Returns the id of {@code t}, registering it under the next free id (the current dictionary size) if it has not
+   * been seen before.
+   */
+  int addToIndexedDictionary(List<T> t)
+  {
+    final int dictId = reverseDictionary.getInt(t);
+    if (dictId < 0) {
+      final int size = dictionary.size();
+      dictionary.add(t);
+      reverseDictionary.put(t, size);
+      return size;
+    }
+    return dictId;
+  }
+
+  /**
+   * Builds a comparator over grouping keys that orders them by the arrays their ids refer to.
+   *
+   * Arrays are compared element by element through {@code stringComparator} applied to the elements' string forms
+   * ({@link StringComparators#NUMERIC} when null is passed); if one array is a prefix of the other, the shorter one
+   * sorts first. Null elements sort before non-null elements, and a key holding {@link #GROUP_BY_MISSING_VALUE}
+   * sorts before every key that holds an array.
+   */
+  @Override
+  public Grouper.BufferComparator bufferComparator(int keyBufferPosition, @Nullable StringComparator stringComparator)
+  {
+    final StringComparator comparator = stringComparator == null ? StringComparators.NUMERIC : stringComparator;
+    return (lhsBuffer, rhsBuffer, lhsPosition, rhsPosition) -> {
+      final int lhsId = lhsBuffer.getInt(lhsPosition + keyBufferPosition);
+      final int rhsId = rhsBuffer.getInt(rhsPosition + keyBufferPosition);
+      if (lhsId == rhsId) {
+        return 0;
+      }
+      // The missing-value id is not a dictionary index; looking it up would throw, so order it explicitly.
+      if (lhsId == GROUP_BY_MISSING_VALUE) {
+        return -1;
+      }
+      if (rhsId == GROUP_BY_MISSING_VALUE) {
+        return 1;
+      }
+
+      final List<T> lhs = dictionary.get(lhsId);
+      final List<T> rhs = dictionary.get(rhsId);
+      final int minLength = Math.min(lhs.size(), rhs.size());
+      for (int i = 0; i < minLength; i++) {
+        final int cmp = compareElements(comparator, lhs.get(i), rhs.get(i));
+        if (cmp != 0) {
+          return cmp;
+        }
+      }
+      return Integer.compare(lhs.size(), rhs.size());
+    };
+  }
+
+  /**
+   * Compares two array elements with nulls first; non-null elements are compared by their string forms.
+   */
+  private static int compareElements(StringComparator comparator, @Nullable Object left, @Nullable Object right)
+  {
+    if (left == null) {
+      return right == null ? 0 : -1;
+    }
+    if (right == null) {
+      return 1;
+    }
+    return comparator.compare(String.valueOf(left), String.valueOf(right));
+  }
+}
diff --git a/processing/src/main/java/org/apache/druid/query/groupby/epinephelinae/column/ArrayStringGroupByColumnSelectorStrategy.java b/processing/src/main/java/org/apache/druid/query/groupby/epinephelinae/column/ArrayStringGroupByColumnSelectorStrategy.java
new file mode 100644
index 000000000000..3d42d5a1bea0
--- /dev/null
+++ b/processing/src/main/java/org/apache/druid/query/groupby/epinephelinae/column/ArrayStringGroupByColumnSelectorStrategy.java
@@ -0,0 +1,234 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.
The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.query.groupby.epinephelinae.column;
+
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.collect.BiMap;
+import com.google.common.collect.HashBiMap;
+import org.apache.druid.java.util.common.ISE;
+import org.apache.druid.query.groupby.ResultRow;
+import org.apache.druid.query.groupby.epinephelinae.Grouper;
+import org.apache.druid.query.ordering.StringComparator;
+import org.apache.druid.query.ordering.StringComparators;
+import org.apache.druid.segment.ColumnValueSelector;
+import org.apache.druid.segment.data.ComparableIntArray;
+import org.apache.druid.segment.data.ComparableStringArray;
+
+import javax.annotation.Nullable;
+import java.nio.ByteBuffer;
+import java.util.List;
+
+/**
+ * {@link GroupByColumnSelectorStrategy} for grouping on arrays of strings.
+ *
+ * Two dictionaries are kept: one assigns an int id to every distinct string element, the other assigns an int id to
+ * every distinct array of element ids. Only the array id is written into the grouping key, so the key occupies
+ * {@link Integer#BYTES} bytes. A row whose selector returned null is encoded as {@link #GROUP_BY_MISSING_VALUE}.
+ */
+public class ArrayStringGroupByColumnSelectorStrategy
+    implements GroupByColumnSelectorStrategy
+{
+  private static final int GROUP_BY_MISSING_VALUE = -1;
+
+
+  // string <-> id for each element of the grouped arrays; ids are handed out in order of first appearance,
+  // starting at 0. E.g. after seeing the column value [a,b,c], dictionaryToInt contains { a <-> 0, b <-> 1, c <-> 2 }
+  private final BiMap<String, Integer> dictionaryToInt;
+
+  // array of element ids <-> id of the array, also starting at 0.
+  // E.g. [a,b,c] is first converted to [0,1,2], which is then stored as [0,1,2] <-> 0
+  private final BiMap<ComparableIntArray, Integer> intListToInt;
+
+  @Override
+  public int getGroupingKeySize()
+  {
+    return Integer.BYTES;
+  }
+
+  public ArrayStringGroupByColumnSelectorStrategy()
+  {
+    dictionaryToInt = HashBiMap.create();
+    intListToInt = HashBiMap.create();
+  }
+
+  @VisibleForTesting
+  ArrayStringGroupByColumnSelectorStrategy(
+      BiMap<String, Integer> dictionaryToInt,
+      BiMap<ComparableIntArray, Integer> intArrayToInt
+  )
+  {
+    this.dictionaryToInt = dictionaryToInt;
+    this.intListToInt = intArrayToInt;
+  }
+
+  /**
+   * Decodes the array id stored at {@code keyBufferPosition}, translates it back into its strings and writes the
+   * result into the result row as a {@link ComparableStringArray}. A missing-value id is written as null.
+   */
+  @Override
+  public void processValueFromGroupingKey(
+      GroupByColumnSelectorPlus selectorPlus,
+      ByteBuffer key,
+      ResultRow resultRow,
+      int keyBufferPosition
+  )
+  {
+    final int id = key.getInt(keyBufferPosition);
+
+    // GROUP_BY_MISSING_VALUE is used to indicate empty rows
+    if (id != GROUP_BY_MISSING_VALUE) {
+      final int[] intRepresentation = intListToInt.inverse().get(id).getDelegate();
+      final BiMap<Integer, String> intToDictionary = dictionaryToInt.inverse();
+      final String[] stringRepresentation = new String[intRepresentation.length];
+      for (int i = 0; i < intRepresentation.length; i++) {
+        stringRepresentation[i] = intToDictionary.get(intRepresentation[i]);
+      }
+      resultRow.set(selectorPlus.getResultRowPosition(), ComparableStringArray.of(stringRepresentation));
+    } else {
+      resultRow.set(selectorPlus.getResultRowPosition(), null);
+    }
+
+  }
+
+  @Override
+  public void initColumnValues(
+      ColumnValueSelector selector,
+      int columnIndex,
+      Object[] valuess
+  )
+  {
+    final int groupingKey = (int) getOnlyValue(selector);
+    valuess[columnIndex] = groupingKey;
+  }
+
+  /**
+   * Writes the array id held in {@code rowObj} into the key buffer and records in {@code stack} whether the row
+   * carries a value for this column (1) or not (0).
+   */
+  @Override
+  public void initGroupingKeyColumnValue(
+      int keyBufferPosition,
+      int columnIndex,
+      Object rowObj,
+      ByteBuffer keyBuffer,
+      int[] stack
+  )
+  {
+    final int groupingKey = (int) rowObj;
+    writeToKeyBuffer(keyBufferPosition, groupingKey, keyBuffer);
+    stack[columnIndex] = groupingKey == GROUP_BY_MISSING_VALUE ? 0 : 1;
+  }
+
+  @Override
+  public boolean checkRowIndexAndAddValueToGroupingKey(
+      int keyBufferPosition,
+      Object rowObj,
+      int rowValIdx,
+      ByteBuffer keyBuffer
+  )
+  {
+    // always false: this strategy never adds a further value for the same row
+    return false;
+  }
+
+  /**
+   * Reads the current value of {@code selector} and returns the id of the corresponding array, registering any
+   * strings and arrays not seen before.
+   *
+   * Accepted shapes are a single {@link String} (treated as a one-element array), a {@link List} and an
+   * {@code Object[]} (which includes {@code String[]}); list and array elements are cast to String.
+   *
+   * @return {@link #GROUP_BY_MISSING_VALUE} when the selector returns null, otherwise the (boxed) int array id
+   *
+   * @throws ISE if the selector returns an object of any other type
+   */
+  @Override
+  public Object getOnlyValue(ColumnValueSelector selector)
+  {
+    final Object object = selector.getObject();
+    if (object == null) {
+      return GROUP_BY_MISSING_VALUE;
+    }
+
+    final int[] intRepresentation;
+    if (object instanceof String) {
+      intRepresentation = new int[]{addToIndexedDictionary((String) object)};
+    } else if (object instanceof List) {
+      final List<?> list = (List<?>) object;
+      final int size = list.size();
+      intRepresentation = new int[size];
+      for (int i = 0; i < size; i++) {
+        intRepresentation[i] = addToIndexedDictionary((String) list.get(i));
+      }
+    } else if (object instanceof Object[]) {
+      // a String[] is an Object[] too, so one branch serves both
+      final Object[] array = (Object[]) object;
+      intRepresentation = new int[array.length];
+      for (int i = 0; i < array.length; i++) {
+        intRepresentation[i] = addToIndexedDictionary((String) array[i]);
+      }
+    } else {
+      throw new ISE("Found unknown type %s in ColumnValueSelector.", object.getClass().toString());
+    }
+
+    final ComparableIntArray comparableIntArray = ComparableIntArray.of(intRepresentation);
+    final int dictId = intListToInt.getOrDefault(comparableIntArray, GROUP_BY_MISSING_VALUE);
+    if (dictId == GROUP_BY_MISSING_VALUE) {
+      final int dictionarySize = intListToInt.size();
+      intListToInt.put(comparableIntArray, dictionarySize);
+      return dictionarySize;
+    } else {
+      return dictId;
+    }
+  }
+
+  /**
+   * Returns the id of {@code value}, registering it under the next free id (the current dictionary size) if it has
+   * not been seen before.
+   */
+  private int addToIndexedDictionary(String value)
+  {
+    final Integer dictId = dictionaryToInt.get(value);
+    if (dictId == null) {
+      final int size = dictionaryToInt.size();
+      dictionaryToInt.put(value, size);
+      return size;
+    } else {
+      return dictId;
+    }
+  }
+
+  @Override
+  public void writeToKeyBuffer(int keyBufferPosition, Object obj, ByteBuffer keyBuffer)
+  {
+    keyBuffer.putInt(keyBufferPosition, (int) obj);
+  }
+
+  /**
+   * Builds a comparator over grouping keys that orders them by the string arrays their ids refer to.
+   *
+   * Arrays are compared element by element with {@code stringComparator}
+   * ({@link StringComparators#LEXICOGRAPHIC} when null is passed); if one array is a prefix of the other, the
+   * shorter one sorts first. A key holding {@link #GROUP_BY_MISSING_VALUE} sorts before every key that holds an
+   * array.
+   */
+  @Override
+  public Grouper.BufferComparator bufferComparator(int keyBufferPosition, @Nullable StringComparator stringComparator)
+  {
+    final StringComparator comparator = stringComparator == null ? StringComparators.LEXICOGRAPHIC : stringComparator;
+    final BiMap<Integer, ComparableIntArray> intToIntList = intListToInt.inverse();
+    final BiMap<Integer, String> intToDictionary = dictionaryToInt.inverse();
+    return (lhsBuffer, rhsBuffer, lhsPosition, rhsPosition) -> {
+      final int lhsId = lhsBuffer.getInt(lhsPosition + keyBufferPosition);
+      final int rhsId = rhsBuffer.getInt(rhsPosition + keyBufferPosition);
+      if (lhsId == rhsId) {
+        return 0;
+      }
+      // The missing-value id has no dictionary entry; dereferencing the lookup result would throw an NPE.
+      if (lhsId == GROUP_BY_MISSING_VALUE) {
+        return -1;
+      }
+      if (rhsId == GROUP_BY_MISSING_VALUE) {
+        return 1;
+      }
+
+      final int[] lhs = intToIntList.get(lhsId).getDelegate();
+      final int[] rhs = intToIntList.get(rhsId).getDelegate();
+      final int minLength = Math.min(lhs.length, rhs.length);
+      for (int i = 0; i < minLength; i++) {
+        final int cmp = comparator.compare(intToDictionary.get(lhs[i]), intToDictionary.get(rhs[i]));
+        if (cmp != 0) {
+          return cmp;
+        }
+      }
+      return Integer.compare(lhs.length, rhs.length);
+    };
+  }
+}
+
diff --git a/processing/src/main/java/org/apache/druid/query/groupby/epinephelinae/vector/VectorGroupByEngine.java b/processing/src/main/java/org/apache/druid/query/groupby/epinephelinae/vector/VectorGroupByEngine.java
index f66d51ac3cc0..e2cb33ca4c2e 100644
--- a/processing/src/main/java/org/apache/druid/query/groupby/epinephelinae/vector/VectorGroupByEngine.java
+++ b/processing/src/main/java/org/apache/druid/query/groupby/epinephelinae/vector/VectorGroupByEngine.java
@@ -47,6 +47,7 @@
 import org.apache.druid.segment.StorageAdapter;
 import org.apache.druid.segment.VirtualColumns;
 import org.apache.druid.segment.column.ColumnCapabilities;
+import org.apache.druid.segment.column.ValueType;
 import org.apache.druid.segment.filter.Filters;
 import org.apache.druid.segment.vector.VectorColumnSelectorFactory;
 import org.apache.druid.segment.vector.VectorCursor;
@@ -105,6
+106,11 @@ public static boolean canVectorizeDimensions( return false; } + if (dimension.getOutputType().getType().equals(ValueType.ARRAY)) { + // group by on arrays is not currently supported in the vector processing engine + return false; + } + // Now check column capabilities. final ColumnCapabilities columnCapabilities = inspector.getColumnCapabilities(dimension.getDimension()); // null here currently means the column does not exist, nil columns can be vectorized diff --git a/processing/src/main/java/org/apache/druid/query/groupby/orderby/DefaultLimitSpec.java b/processing/src/main/java/org/apache/druid/query/groupby/orderby/DefaultLimitSpec.java index 2ca0b19b399b..5597dcf597d7 100644 --- a/processing/src/main/java/org/apache/druid/query/groupby/orderby/DefaultLimitSpec.java +++ b/processing/src/main/java/org/apache/druid/query/groupby/orderby/DefaultLimitSpec.java @@ -44,9 +44,12 @@ import org.apache.druid.query.groupby.ResultRow; import org.apache.druid.query.ordering.StringComparator; import org.apache.druid.query.ordering.StringComparators; +import org.apache.druid.segment.DimensionHandlerUtils; import org.apache.druid.segment.column.ColumnType; import org.apache.druid.segment.column.RowSignature; import org.apache.druid.segment.column.ValueType; +import org.apache.druid.segment.data.ComparableList; +import org.apache.druid.segment.data.ComparableStringArray; import javax.annotation.Nullable; import java.nio.ByteBuffer; @@ -58,6 +61,7 @@ import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.Optional; import java.util.Set; import java.util.stream.Collectors; @@ -202,6 +206,12 @@ public Function, Sequence> build(final GroupByQue naturalComparator = StringComparators.LEXICOGRAPHIC; } else if (columnType.isNumeric()) { naturalComparator = StringComparators.NUMERIC; + } else if (columnType.isArray()) { + if (columnType.getElementType().isNumeric()) { + naturalComparator = StringComparators.NUMERIC; + } else { + 
naturalComparator = StringComparators.LEXICOGRAPHIC; + } } else { sortingNeeded = true; break; @@ -375,7 +385,17 @@ public int compare(ResultRow left, ResultRow right) //noinspection unchecked nextOrdering = metricOrdering(columnIndex, aggregatorsMap.get(columnName).getComparator()); } else if (dimensionsMap.containsKey(columnName)) { - nextOrdering = dimensionOrdering(columnIndex, columnSpec.getDimensionComparator()); + Optional dimensionSpec = dimensions.stream() + .filter(ds -> ds.getOutputName().equals(columnName)) + .findFirst(); + if (!dimensionSpec.isPresent()) { + throw new ISE("Could not find the dimension spec for ordering column %s", columnName); + } + nextOrdering = dimensionOrdering( + columnIndex, + dimensionSpec.get().getOutputType(), + columnSpec.getDimensionComparator() + ); } } @@ -412,10 +432,41 @@ private Ordering metricOrdering(final int column, final Comparato return Ordering.from(Comparator.comparing(row -> (T) row.get(column), nullFriendlyComparator)); } - private Ordering dimensionOrdering(final int column, final StringComparator comparator) + private Ordering dimensionOrdering( + final int column, + final ColumnType columnType, + final StringComparator comparator + ) { + Comparator arrayComparator = null; + if (columnType.isArray()) { + if (columnType.getElementType().isNumeric()) { + arrayComparator = (Comparator) (o1, o2) -> ComparableList.compareWithComparator( + comparator, + DimensionHandlerUtils.convertToList(o1), + DimensionHandlerUtils.convertToList(o2) + ); + } else if (columnType.getElementType().equals(ColumnType.STRING)) { + arrayComparator = (Comparator) (o1, o2) -> ComparableStringArray.compareWithComparator( + comparator, + DimensionHandlerUtils.convertToComparableStringArray(o1), + DimensionHandlerUtils.convertToComparableStringArray(o2) + ); + } else { + throw new ISE("Cannot create comparator for array type %s.", columnType.toString()); + } + } return Ordering.from( - Comparator.comparing((ResultRow row) -> 
getDimensionValue(row, column), Comparator.nullsFirst(comparator)) + Comparator.comparing( + (ResultRow row) -> { + if (columnType.isArray()) { + return row.get(column); + } else { + return getDimensionValue(row, column); + } + }, + Comparator.nullsFirst(arrayComparator == null ? comparator : arrayComparator) + ) ); } diff --git a/processing/src/main/java/org/apache/druid/segment/DimensionHandlerUtils.java b/processing/src/main/java/org/apache/druid/segment/DimensionHandlerUtils.java index 1a652f3daf3c..3e2ce241b796 100644 --- a/processing/src/main/java/org/apache/druid/segment/DimensionHandlerUtils.java +++ b/processing/src/main/java/org/apache/druid/segment/DimensionHandlerUtils.java @@ -39,6 +39,8 @@ import org.apache.druid.segment.column.ColumnType; import org.apache.druid.segment.column.TypeSignature; import org.apache.druid.segment.column.ValueType; +import org.apache.druid.segment.data.ComparableList; +import org.apache.druid.segment.data.ComparableStringArray; import javax.annotation.Nullable; import java.math.BigDecimal; @@ -156,9 +158,7 @@ public static List getValueTypesFromDimensionSpecs(List ColumnSelectorPlus createColumnSelectorPlus( @@ -174,10 +174,10 @@ public static ColumnSelectorPlus * The ColumnSelectorPlus provides access to a type strategy (e.g., how to group on a float column) * and a value selector for a single column. - * + *

* A caller should define a strategy factory that provides an interface for type-specific operations * in a query engine. See GroupByStrategyFactory for a reference. * @@ -185,9 +185,7 @@ public static ColumnSelectorPlus ColumnSelectorPlus[] createColumnSelectorPluses( @@ -376,11 +374,59 @@ public static Comparable convertObjectToType( return convertObjectToDouble(obj, reportParseExceptions); case STRING: return convertObjectToString(obj); + case ARRAY: + switch (type.getElementType().getType()) { + case STRING: + return convertToComparableStringArray(obj); + default: + return convertToList(obj); + } + default: throw new IAE("Type[%s] is not supported for dimensions!", type); } } + @Nullable + public static ComparableList convertToList(Object obj) + { + if (obj == null) { + return null; + } + if (obj instanceof List) { + return new ComparableList((List) obj); + } + if (obj instanceof ComparableList) { + return (ComparableList) obj; + } + throw new ISE("Unable to convert type %s to %s", obj.getClass().getName(), ComparableList.class.getName()); + } + + + @Nullable + public static ComparableStringArray convertToComparableStringArray(Object obj) + { + if (obj == null) { + return null; + } + if (obj instanceof ComparableStringArray) { + return (ComparableStringArray) obj; + } + if (obj instanceof String[]) { + return ComparableStringArray.of((String[]) obj); + } + // Jackson converts the serialized array into a list. 
Converting it back to a string array + if (obj instanceof List) { + return ComparableStringArray.of((String[]) ((List) obj).toArray(new String[0])); + } + Objects[] objects = (Objects[]) obj; + String[] delegate = new String[objects.length]; + for (int i = 0; i < objects.length; i++) { + delegate[i] = convertObjectToString(objects[i]); + } + return ComparableStringArray.of(delegate); + } + public static int compareObjectsAsType( @Nullable final Object lhs, @Nullable final Object rhs, @@ -443,12 +489,11 @@ public static Double convertObjectToDouble(@Nullable Object valObj, boolean repo /** * Convert a string representing a decimal value to a long. - * + *

* If the decimal value is not an exact integral value (e.g. 42.0), or if the decimal value * is too large to be contained within a long, this function returns null. * * @param decimalStr string representing a decimal value - * * @return long equivalent of decimalStr, returns null for non-integral decimals and integral decimal values outside * of the values representable by longs */ diff --git a/processing/src/main/java/org/apache/druid/segment/data/ComparableIntArray.java b/processing/src/main/java/org/apache/druid/segment/data/ComparableIntArray.java new file mode 100644 index 000000000000..7769e98fda92 --- /dev/null +++ b/processing/src/main/java/org/apache/druid/segment/data/ComparableIntArray.java @@ -0,0 +1,118 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */
+
+package org.apache.druid.segment.data;
+
+import com.fasterxml.jackson.annotation.JsonCreator;
+import com.fasterxml.jackson.annotation.JsonValue;
+
+import java.util.Arrays;
+
+/**
+ * An {@code int[]} wrapper with value-based {@link #equals}, {@link #hashCode} and element-wise ordering, so that int
+ * arrays can be used as map keys and sorted. The wrapped array is not copied on the way in or out; it must not be
+ * modified while the wrapper is in use as a key.
+ */
+public class ComparableIntArray implements Comparable<ComparableIntArray>
+{
+  public static final ComparableIntArray EMPTY_ARRAY = new ComparableIntArray(new int[0]);
+
+  final int[] delegate;
+  // Lazily cached hash. 0 means "not computed yet" unless hashIsZero is set.
+  private int hashCode;
+  private boolean hashIsZero;
+
+  private ComparableIntArray(int[] array)
+  {
+    delegate = array;
+  }
+
+  /**
+   * Wraps {@code array}; an empty array yields the shared {@link #EMPTY_ARRAY} instance.
+   */
+  @JsonCreator
+  public static ComparableIntArray of(int... array)
+  {
+    if (array.length == 0) {
+      return EMPTY_ARRAY;
+    } else {
+      return new ComparableIntArray(array);
+    }
+  }
+
+  @JsonValue
+  public int[] getDelegate()
+  {
+    return delegate;
+  }
+
+  @Override
+  public int hashCode()
+  {
+    // ComparableIntArray is used in hot loops, so the hash is cached. The cache is deliberately unsynchronized:
+    // each call reads the cached value once into a local and writes at most one of the two fields, so a racing
+    // thread either sees a finished result or recomputes the same value. (A "computed" flag written after the hash
+    // would not be safe: another thread could observe the flag without the hash.)
+    int h = hashCode;
+    if (h == 0 && !hashIsZero) {
+      h = Arrays.hashCode(delegate);
+      if (h == 0) {
+        hashIsZero = true;
+      } else {
+        hashCode = h;
+      }
+    }
+    return h;
+  }
+
+  @Override
+  public boolean equals(Object obj)
+  {
+    if (this == obj) {
+      return true;
+    }
+    if (obj == null || getClass() != obj.getClass()) {
+      return false;
+    }
+
+    return Arrays.equals(delegate, ((ComparableIntArray) obj).getDelegate());
+  }
+
+  /**
+   * Orders arrays element by element; if one array is a prefix of the other, the shorter one sorts first. Any
+   * instance sorts after null.
+   */
+  @Override
+  public int compareTo(ComparableIntArray rhs)
+  {
+    // rhs.getDelegate() cannot be null
+    if (rhs == null) {
+      return 1;
+    }
+    final int[] other = rhs.getDelegate();
+    //noinspection ArrayEquality
+    if (delegate == other) {
+      return 0;
+    }
+    final int minSize = Math.min(delegate.length, other.length);
+    for (int i = 0; i < minSize; i++) {
+      //int's cant be null
+      final int cmp = Integer.compare(delegate[i], other[i]);
+      if (cmp != 0) {
+        return cmp;
+      }
+    }
+    return Integer.compare(delegate.length, other.length);
+  }
+
+  @Override
+  public String toString()
+  {
+    return Arrays.toString(delegate);
+  }
+}
diff --git a/processing/src/main/java/org/apache/druid/segment/data/ComparableList.java b/processing/src/main/java/org/apache/druid/segment/data/ComparableList.java
new file mode 100644
index 000000000000..442d4b5a3a78
--- /dev/null
+++ b/processing/src/main/java/org/apache/druid/segment/data/ComparableList.java
@@ -0,0 +1,158 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.segment.data;
+
+import com.fasterxml.jackson.annotation.JsonValue;
+import com.google.common.base.Preconditions;
+import org.apache.druid.query.ordering.StringComparator;
+import org.apache.druid.query.ordering.StringComparators;
+
+import java.util.List;
+
+
+/**
+ * A {@link List} wrapper that adds element-wise ordering, so that lists can be used as sortable dimension values.
+ * Equality and hash code are those of the wrapped list, which is not copied.
+ *
+ * @param <T> element type of the wrapped list
+ */
+public class ComparableList<T extends Comparable> implements Comparable<ComparableList>
+{
+
+  private final List<T> delegate;
+
+  /**
+   * @throws IllegalArgumentException if {@code input} is null
+   */
+  public ComparableList(List<T> input)
+  {
+    Preconditions.checkArgument(
+        input != null,
+        "Input cannot be null for %s",
+        ComparableList.class.getName()
+    );
+    this.delegate = input;
+  }
+
+  @JsonValue
+  public List<T> getDelegate()
+  {
+    return delegate;
+  }
+
+  @Override
+  public int hashCode()
+  {
+    return delegate.hashCode();
+  }
+
+
+  @Override
+  public boolean equals(Object obj)
+  {
+    if (this == obj) {
+      return true;
+    }
+    if (obj == null || getClass() != obj.getClass()) {
+      return false;
+    }
+
+    return this.delegate.equals(((ComparableList) obj).getDelegate());
+  }
+
+  /**
+   * Orders lists element by element using the elements' natural ordering, with null elements first; if one list is
+   * a prefix of the other, the shorter one sorts first. Any instance sorts after null.
+   */
+  @Override
+  public int compareTo(ComparableList rhs)
+  {
+    if (rhs == null) {
+      return 1;
+    }
+
+    final List<?> other = rhs.getDelegate();
+    if (this.delegate == other) {
+      return 0;
+    }
+
+    final int minSize = Math.min(this.delegate.size(), other.size());
+    for (int i = 0; i < minSize; i++) {
+      final int cmp;
+      final T first = this.delegate.get(i);
+      final Object second = other.get(i);
+      if (first == null && second == null) {
+        cmp = 0;
+      } else if (first == null) {
+        cmp = -1;
+      } else if (second == null) {
+        cmp = 1;
+      } else {
+        //noinspection unchecked
+        cmp = first.compareTo(second);
+      }
+      if (cmp != 0) {
+        return cmp;
+      }
+    }
+    return Integer.compare(this.delegate.size(), other.size());
+  }
+
+  @Override
+  public String toString()
+  {
+    return delegate.toString();
+  }
+
+  /**
+   * Compares two lists element by element, applying {@code stringComparator} to the elements' string forms
+   * ({@link StringComparators#NUMERIC} when null is passed). A null list sorts before any list and a null element
+   * sorts before any non-null element, matching {@link #compareTo}; if one list is a prefix of the other, the
+   * shorter one sorts first.
+   */
+  public static int compareWithComparator(
+      StringComparator stringComparator,
+      ComparableList lhsComparableArray,
+      ComparableList rhsComparableArray
+  )
+  {
+    final StringComparator comparator = stringComparator == null
+                                        ? StringComparators.NUMERIC
+                                        : stringComparator;
+
+    if (lhsComparableArray == null && rhsComparableArray == null) {
+      return 0;
+    } else if (lhsComparableArray == null) {
+      return -1;
+    } else if (rhsComparableArray == null) {
+      return 1;
+    }
+
+    final List<?> lhs = lhsComparableArray.getDelegate();
+    final List<?> rhs = rhsComparableArray.getDelegate();
+
+    if (lhs == rhs) {
+      return 0;
+    }
+
+    final int minLength = Math.min(lhs.size(), rhs.size());
+    for (int i = 0; i < minLength; i++) {
+      final Object left = lhs.get(i);
+      final Object right = rhs.get(i);
+      final int cmp;
+      // nulls are ordered here rather than being handed to the comparator as the string "null"
+      if (left == null && right == null) {
+        cmp = 0;
+      } else if (left == null) {
+        cmp = -1;
+      } else if (right == null) {
+        cmp = 1;
+      } else {
+        cmp = comparator.compare(String.valueOf(left), String.valueOf(right));
+      }
+      if (cmp != 0) {
+        return cmp;
+      }
+    }
+    return Integer.compare(lhs.size(), rhs.size());
+  }
+}
diff --git a/processing/src/main/java/org/apache/druid/segment/data/ComparableStringArray.java b/processing/src/main/java/org/apache/druid/segment/data/ComparableStringArray.java
new file mode 100644
index 000000000000..0f66483a47ec
--- /dev/null
+++ b/processing/src/main/java/org/apache/druid/segment/data/ComparableStringArray.java
@@ -0,0 +1,172 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.
See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.segment.data;
+
+import com.fasterxml.jackson.annotation.JsonCreator;
+import com.fasterxml.jackson.annotation.JsonValue;
+import org.apache.druid.query.ordering.StringComparator;
+import org.apache.druid.query.ordering.StringComparators;
+
+import java.util.Arrays;
+
+/**
+ * A {@code String[]} wrapper with value-based {@link #equals}, {@link #hashCode} and element-wise ordering, so that
+ * string arrays can be used as map keys and sorted. The wrapped array is not copied on the way in or out; it must not
+ * be modified while the wrapper is in use as a key.
+ */
+public class ComparableStringArray implements Comparable<ComparableStringArray>
+{
+  public static final ComparableStringArray EMPTY_ARRAY = new ComparableStringArray(new String[0]);
+
+  final String[] delegate;
+  // Lazily cached hash. 0 means "not computed yet" unless hashIsZero is set.
+  private int hashCode;
+  private boolean hashIsZero;
+
+  private ComparableStringArray(String[] array)
+  {
+    delegate = array;
+  }
+
+  /**
+   * Wraps {@code array}; an empty array yields the shared {@link #EMPTY_ARRAY} instance.
+   */
+  @JsonCreator
+  public static ComparableStringArray of(String... array)
+  {
+    if (array.length == 0) {
+      return EMPTY_ARRAY;
+    } else {
+      return new ComparableStringArray(array);
+    }
+  }
+
+  @JsonValue
+  public String[] getDelegate()
+  {
+    return delegate;
+  }
+
+  @Override
+  public int hashCode()
+  {
+    // ComparableStringArray is used in hot loops, so the hash is cached. The cache is deliberately unsynchronized:
+    // each call reads the cached value once into a local and writes at most one of the two fields, so a racing
+    // thread either sees a finished result or recomputes the same value. (A "computed" flag written after the hash
+    // would not be safe: another thread could observe the flag without the hash.)
+    int h = hashCode;
+    if (h == 0 && !hashIsZero) {
+      h = Arrays.hashCode(delegate);
+      if (h == 0) {
+        hashIsZero = true;
+      } else {
+        hashCode = h;
+      }
+    }
+    return h;
+  }
+
+  @Override
+  public boolean equals(Object obj)
+  {
+    if (this == obj) {
+      return true;
+    }
+    if (obj == null || getClass() != obj.getClass()) {
+      return false;
+    }
+
+    return Arrays.equals(delegate, ((ComparableStringArray) obj).getDelegate());
+  }
+
+  /**
+   * Orders arrays element by element using {@link String#compareTo}, with null elements first; if one array is a
+   * prefix of the other, the shorter one sorts first. Any instance sorts after null.
+   */
+  @Override
+  public int compareTo(ComparableStringArray rhs)
+  {
+    // rhs.getDelegate() cannot be null
+    if (rhs == null) {
+      return 1;
+    }
+    final String[] other = rhs.getDelegate();
+    //noinspection ArrayEquality
+    if (delegate == other) {
+      return 0;
+    }
+    final int minSize = Math.min(delegate.length, other.length);
+    for (int i = 0; i < minSize; i++) {
+      final int cmp;
+      final String first = delegate[i];
+      final String second = other[i];
+      if (first == null && second == null) {
+        cmp = 0;
+      } else if (first == null) {
+        cmp = -1;
+      } else if (second == null) {
+        cmp = 1;
+      } else {
+        cmp = first.compareTo(second);
+      }
+      if (cmp != 0) {
+        return cmp;
+      }
+    }
+    return Integer.compare(delegate.length, other.length);
+  }
+
+  @Override
+  public String toString()
+  {
+    return Arrays.toString(delegate);
+  }
+
+
+  /**
+   * Compares two arrays element by element with {@code stringComparator}
+   * ({@link StringComparators#LEXICOGRAPHIC} when null is passed). A null array sorts before any array; if one
+   * array is a prefix of the other, the shorter one sorts first. Elements, including null elements, are handed to
+   * the comparator as they are.
+   */
+  public static int compareWithComparator(
+      StringComparator stringComparator,
+      ComparableStringArray lhsComparableArray,
+      ComparableStringArray rhsComparableArray
+  )
+  {
+    final StringComparator comparator = stringComparator == null
+                                        ? StringComparators.LEXICOGRAPHIC
+                                        : stringComparator;
+    if (lhsComparableArray == null && rhsComparableArray == null) {
+      return 0;
+    } else if (lhsComparableArray == null) {
+      return -1;
+    } else if (rhsComparableArray == null) {
+      return 1;
+    }
+
+    final String[] lhs = lhsComparableArray.getDelegate();
+    final String[] rhs = rhsComparableArray.getDelegate();
+
+    //noinspection ArrayEquality
+    if (lhs == rhs) {
+      return 0;
+    }
+
+    final int minLength = Math.min(lhs.length, rhs.length);
+    for (int i = 0; i < minLength; i++) {
+      final int cmp = comparator.compare(lhs[i], rhs[i]);
+      if (cmp != 0) {
+        return cmp;
+      }
+    }
+    return Integer.compare(lhs.length, rhs.length);
+  }
+}
diff --git a/processing/src/main/java/org/apache/druid/segment/virtual/ExpressionPlanner.java b/processing/src/main/java/org/apache/druid/segment/virtual/ExpressionPlanner.java
index 44210be3eed4..13fbe3b3fe48 100644
--- a/processing/src/main/java/org/apache/druid/segment/virtual/ExpressionPlanner.java
+++ b/processing/src/main/java/org/apache/druid/segment/virtual/ExpressionPlanner.java
@@ -159,7 +159,7 @@ public static ExpressionPlan plan(ColumnInspector inspector, Expr expression)
       outputType = expression.getOutputType(inspector);
     }
 
-    // if analysis predicts output, or inferred output type is array, output will be multi-valued
+    // if analysis predicts output, or inferred output type, is array, output will be arrays
     if (analysis.isOutputArray() || (outputType != null && outputType.isArray())) {
       traits.add(ExpressionPlan.Trait.NON_SCALAR_OUTPUT);
 
diff --git a/processing/src/main/java/org/apache/druid/segment/virtual/ExpressionSelectors.java b/processing/src/main/java/org/apache/druid/segment/virtual/ExpressionSelectors.java
index 97c66ed6d04a..f9ff05c40ea8 100644
--- a/processing/src/main/java/org/apache/druid/segment/virtual/ExpressionSelectors.java
+++
b/processing/src/main/java/org/apache/druid/segment/virtual/ExpressionSelectors.java @@ -158,7 +158,7 @@ public static ColumnValueSelector makeExprEvalSelector( ); } } - final Expr.ObjectBinding bindings = createBindings(plan.getAnalysis(), columnSelectorFactory); + final Expr.ObjectBinding bindings = createBindings(columnSelectorFactory, plan); // Optimization for constant expressions if (bindings.equals(InputBindings.nilBindings())) { @@ -244,29 +244,16 @@ public static boolean canMapOverDictionary( } /** - * Create {@link Expr.ObjectBinding} given a {@link ColumnSelectorFactory} and {@link Expr.BindingAnalysis} which - * provides the set of identifiers which need a binding (list of required columns), and context of whether or not they - * are used as array or scalar inputs - */ - public static Expr.ObjectBinding createBindings( - Expr.BindingAnalysis bindingAnalysis, - ColumnSelectorFactory columnSelectorFactory - ) - { - final List columns = bindingAnalysis.getRequiredBindingsList(); - return createBindings(columnSelectorFactory, columns); - } - - /** - * Create {@link Expr.ObjectBinding} given a {@link ColumnSelectorFactory} and {@link Expr.BindingAnalysis} which + * Create {@link Expr.ObjectBinding} given a {@link ColumnSelectorFactory} and {@link ExpressionPlan} which * provides the set of identifiers which need a binding (list of required columns), and context of whether or not they * are used as array or scalar inputs */ public static Expr.ObjectBinding createBindings( ColumnSelectorFactory columnSelectorFactory, - List columns + ExpressionPlan plan ) { + final List columns = plan.getAnalysis().getRequiredBindingsList(); final Map>> suppliers = new HashMap<>(); for (String columnName : columns) { final ColumnCapabilities columnCapabilities = columnSelectorFactory.getColumnCapabilities(columnName); @@ -274,9 +261,15 @@ public static Expr.ObjectBinding createBindings( final Supplier supplier; final ExpressionType expressionType = 
ExpressionType.fromColumnType(columnCapabilities); - if (columnCapabilities == null || columnCapabilities.isArray()) { + if (columnCapabilities == null || + columnCapabilities.isArray() || + (plan.is(ExpressionPlan.Trait.NON_SCALAR_OUTPUT) && !plan.is(ExpressionPlan.Trait.NEEDS_APPLIED)) + ) { // Unknown ValueType or array type. Try making an Object selector and see if that gives us anything useful. - supplier = supplierFromObjectSelector(columnSelectorFactory.makeColumnValueSelector(columnName)); + supplier = supplierFromObjectSelector( + columnSelectorFactory.makeColumnValueSelector(columnName), + plan.is(ExpressionPlan.Trait.NEEDS_APPLIED) + ); } else if (columnCapabilities.is(ValueType.FLOAT)) { ColumnValueSelector selector = columnSelectorFactory.makeColumnValueSelector(columnName); supplier = makeNullableNumericSupplier(selector, selector::getFloat); @@ -393,7 +386,10 @@ static Supplier supplierFromDimensionSelector(final DimensionSelector se * detected as a primitive type */ @Nullable - static Supplier supplierFromObjectSelector(final BaseObjectColumnValueSelector selector) + static Supplier supplierFromObjectSelector( + final BaseObjectColumnValueSelector selector, + boolean homogenizeMultiValue + ) { if (selector instanceof NilColumnValueSelector) { return null; @@ -408,7 +404,7 @@ static Supplier supplierFromObjectSelector(final BaseObjectColumnValueSe return () -> { final Object val = selector.getObject(); if (val instanceof List) { - NonnullPair coerced = ExprEval.coerceListToArray((List) val, true); + NonnullPair coerced = ExprEval.coerceListToArray((List) val, homogenizeMultiValue); if (coerced == null) { return null; } @@ -421,7 +417,7 @@ static Supplier supplierFromObjectSelector(final BaseObjectColumnValueSe return () -> { final Object val = selector.getObject(); if (val != null) { - NonnullPair coerced = ExprEval.coerceListToArray((List) val, true); + NonnullPair coerced = ExprEval.coerceListToArray((List) val, homogenizeMultiValue); if (coerced 
== null) { return null; } @@ -443,7 +439,10 @@ static Supplier supplierFromObjectSelector(final BaseObjectColumnValueSe public static Object coerceEvalToSelectorObject(ExprEval eval) { if (eval.type().isArray()) { - return Arrays.stream(eval.asArray()).collect(Collectors.toList()); + final Object[] asArray = eval.asArray(); + return asArray == null + ? null + : Arrays.stream(asArray).collect(Collectors.toList()); } return eval.value(); } diff --git a/processing/src/test/java/org/apache/druid/query/MultiValuedDimensionTest.java b/processing/src/test/java/org/apache/druid/query/MultiValuedDimensionTest.java index ba3eafab48bf..80b5768dd92d 100644 --- a/processing/src/test/java/org/apache/druid/query/MultiValuedDimensionTest.java +++ b/processing/src/test/java/org/apache/druid/query/MultiValuedDimensionTest.java @@ -85,6 +85,7 @@ import java.util.Arrays; import java.util.Collection; import java.util.Collections; +import java.util.HashMap; import java.util.List; import java.util.Map; @@ -426,7 +427,14 @@ public void testGroupByExpression() ); List expectedResults = Arrays.asList( - GroupByQueryRunnerTestHelper.createExpectedRow(query, "1970", "texpr", "foo", "count", 2L), + GroupByQueryRunnerTestHelper.createExpectedRow( + query, + "1970", + "texpr", + NullHandling.sqlCompatible() ? 
"foo" : null, + "count", + 2L + ), GroupByQueryRunnerTestHelper.createExpectedRow(query, "1970", "texpr", "t1foo", "count", 2L), GroupByQueryRunnerTestHelper.createExpectedRow(query, "1970", "texpr", "t2foo", "count", 2L), GroupByQueryRunnerTestHelper.createExpectedRow(query, "1970", "texpr", "t3foo", "count", 4L), @@ -473,13 +481,23 @@ public void testGroupByExpressionMultiMulti() query ); - List expectedResults = Arrays.asList( - GroupByQueryRunnerTestHelper.createExpectedRow(query, "1970", "texpr", "t1u1", "count", 2L), - GroupByQueryRunnerTestHelper.createExpectedRow(query, "1970", "texpr", "t1u2", "count", 2L), - GroupByQueryRunnerTestHelper.createExpectedRow(query, "1970", "texpr", "t2u1", "count", 2L), - GroupByQueryRunnerTestHelper.createExpectedRow(query, "1970", "texpr", "t2u2", "count", 2L), - GroupByQueryRunnerTestHelper.createExpectedRow(query, "1970", "texpr", "t3u1", "count", 2L) - ); + List + expectedResults = + NullHandling.sqlCompatible() ? + Arrays.asList( + GroupByQueryRunnerTestHelper.createExpectedRow(query, "1970", "texpr", "t1u1", "count", 2L), + GroupByQueryRunnerTestHelper.createExpectedRow(query, "1970", "texpr", "t1u2", "count", 2L), + GroupByQueryRunnerTestHelper.createExpectedRow(query, "1970", "texpr", "t2u1", "count", 2L), + GroupByQueryRunnerTestHelper.createExpectedRow(query, "1970", "texpr", "t2u2", "count", 2L), + GroupByQueryRunnerTestHelper.createExpectedRow(query, "1970", "texpr", "t3u1", "count", 2L) + ) : + Arrays.asList( + GroupByQueryRunnerTestHelper.createExpectedRow(query, "1970", "texpr", null, "count", 2L), + GroupByQueryRunnerTestHelper.createExpectedRow(query, "1970", "texpr", "t1u1", "count", 2L), + GroupByQueryRunnerTestHelper.createExpectedRow(query, "1970", "texpr", "t1u2", "count", 2L), + GroupByQueryRunnerTestHelper.createExpectedRow(query, "1970", "texpr", "t2u1", "count", 2L), + GroupByQueryRunnerTestHelper.createExpectedRow(query, "1970", "texpr", "t2u2", "count", 2L) + ); 
TestHelper.assertExpectedObjects(expectedResults, result.toList(), "expr-multi-multi"); } @@ -1089,7 +1107,15 @@ public void testTopNExpression() ImmutableList.>builder() .add(ImmutableMap.of("texpr", "t3foo", "count", 2L)) .add(ImmutableMap.of("texpr", "t5foo", "count", 2L)) - .add(ImmutableMap.of("texpr", "foo", "count", 1L)) + .add( + new HashMap() + { + { + put("texpr", NullHandling.sqlCompatible() ? "foo" : null); + put("count", 1L); + } + } + ) .add(ImmutableMap.of("texpr", "t1foo", "count", 1L)) .add(ImmutableMap.of("texpr", "t2foo", "count", 1L)) .add(ImmutableMap.of("texpr", "t4foo", "count", 1L)) diff --git a/processing/src/test/java/org/apache/druid/query/groupby/GroupByQueryQueryToolChestTest.java b/processing/src/test/java/org/apache/druid/query/groupby/GroupByQueryQueryToolChestTest.java index b444218d6ff3..6fbfa083acea 100644 --- a/processing/src/test/java/org/apache/druid/query/groupby/GroupByQueryQueryToolChestTest.java +++ b/processing/src/test/java/org/apache/druid/query/groupby/GroupByQueryQueryToolChestTest.java @@ -589,25 +589,26 @@ public void testResultSerde() throws Exception query.withOverriddenContext(ImmutableMap.of(GroupByQueryConfig.CTX_KEY_ARRAY_RESULT_ROWS, false)) ); - final Object[] rowObjects = {DateTimes.of("2000").getMillis(), "foo", 100, 10}; + final Object[] rowObjects = {DateTimes.of("2000").getMillis(), "foo", 100, 10.0}; final ResultRow resultRow = ResultRow.of(rowObjects); + Assert.assertEquals( resultRow, arraysObjectMapper.readValue( - StringUtils.format("[%s, \"foo\", 100, 10]", DateTimes.of("2000").getMillis()), + StringUtils.format("[%s, \"foo\", 100, 10.0]", DateTimes.of("2000").getMillis()), ResultRow.class ) ); - Assert.assertEquals( + TestHelper.assertRow("", resultRow, - arraysObjectMapper.readValue( + arraysObjectMapper.readValue( StringUtils.format( "{\"version\":\"v1\"," + "\"timestamp\":\"%s\"," + "\"event\":" - + " {\"test\":\"foo\", \"rows\":100, \"post\":10}" + + " {\"test\":\"foo\", \"rows\":100, 
\"post\":10.0}" + "}", DateTimes.of("2000") ), diff --git a/processing/src/test/java/org/apache/druid/query/groupby/GroupByQueryRunnerTest.java b/processing/src/test/java/org/apache/druid/query/groupby/GroupByQueryRunnerTest.java index a2b5fed3f3e7..b64d11c2fc5e 100644 --- a/processing/src/test/java/org/apache/druid/query/groupby/GroupByQueryRunnerTest.java +++ b/processing/src/test/java/org/apache/druid/query/groupby/GroupByQueryRunnerTest.java @@ -49,6 +49,7 @@ import org.apache.druid.java.util.common.guava.Sequences; import org.apache.druid.java.util.common.io.Closer; import org.apache.druid.js.JavaScriptConfig; +import org.apache.druid.math.expr.ExprMacroTable; import org.apache.druid.query.BySegmentResultValue; import org.apache.druid.query.BySegmentResultValueClass; import org.apache.druid.query.ChainedExecutionQueryRunner; @@ -98,6 +99,7 @@ import org.apache.druid.query.extraction.SearchQuerySpecDimExtractionFn; import org.apache.druid.query.extraction.StringFormatExtractionFn; import org.apache.druid.query.extraction.StrlenExtractionFn; +import org.apache.druid.query.extraction.SubstringDimExtractionFn; import org.apache.druid.query.extraction.TimeFormatExtractionFn; import org.apache.druid.query.filter.AndDimFilter; import org.apache.druid.query.filter.BoundDimFilter; @@ -129,6 +131,8 @@ import org.apache.druid.segment.TestHelper; import org.apache.druid.segment.column.ColumnHolder; import org.apache.druid.segment.column.ColumnType; +import org.apache.druid.segment.data.ComparableList; +import org.apache.druid.segment.data.ComparableStringArray; import org.apache.druid.segment.virtual.ExpressionVirtualColumn; import org.apache.druid.testing.InitializedNullHandlingTest; import org.joda.time.DateTime; @@ -1306,6 +1310,769 @@ public void testMultiValueDimension() TestHelper.assertExpectedObjects(expectedResults, results, "multi-value-dim"); } + @Test + public void testMultiValueDimensionAsArray() + { + // array types don't work with group by v1 + if 
(config.getDefaultStrategy().equals(GroupByStrategySelector.STRATEGY_V1)) { + expectedException.expect(UnsupportedOperationException.class); + expectedException.expectMessage( + "GroupBy v1 only supports dimensions with an outputType of STRING."); + } + + // Cannot vectorize due to multi-value dimensions. + cannotVectorize(); + + GroupByQuery query = makeQueryBuilder() + .setDataSource(QueryRunnerTestHelper.DATA_SOURCE) + .setQuerySegmentSpec(QueryRunnerTestHelper.FIRST_TO_THIRD) + .setVirtualColumns(new ExpressionVirtualColumn( + "v0", + "mv_to_array(placementish)", + ColumnType.STRING_ARRAY, + ExprMacroTable.nil() + )) + .setDimensions( + new DefaultDimensionSpec("v0", "alias", ColumnType.STRING_ARRAY)) + .setAggregatorSpecs(QueryRunnerTestHelper.ROWS_COUNT, new LongSumAggregatorFactory("idx", "index")) + .setGranularity(QueryRunnerTestHelper.ALL_GRAN) + .build(); + + List expectedResults = Arrays.asList( + makeRow(query, "2011-04-01", "alias", ComparableStringArray.of("a", "preferred"), "rows", 2L, "idx", 282L), + makeRow(query, "2011-04-01", "alias", ComparableStringArray.of("b", "preferred"), "rows", 2L, "idx", 230L), + makeRow(query, "2011-04-01", "alias", ComparableStringArray.of("e", "preferred"), "rows", 2L, "idx", 324L), + makeRow(query, "2011-04-01", "alias", ComparableStringArray.of("h", "preferred"), "rows", 2L, "idx", 233L), + makeRow(query, "2011-04-01", "alias", ComparableStringArray.of("m", "preferred"), "rows", 6L, "idx", 5317L), + makeRow(query, "2011-04-01", "alias", ComparableStringArray.of("n", "preferred"), "rows", 2L, "idx", 235L), + makeRow(query, "2011-04-01", "alias", ComparableStringArray.of("p", "preferred"), "rows", 6L, "idx", 5405L), + makeRow(query, "2011-04-01", "alias", ComparableStringArray.of("preferred", "t"), "rows", 4L, "idx", 420L) + ); + + Iterable results = GroupByQueryRunnerTestHelper.runQuery(factory, runner, query); + TestHelper.assertExpectedObjects(expectedResults, results, "multi-value-dim-groupby-arrays"); + } + + 
@Test + public void testSingleValueDimensionAsArray() + { + // array types don't work with group by v1 + if (config.getDefaultStrategy().equals(GroupByStrategySelector.STRATEGY_V1)) { + expectedException.expect(UnsupportedOperationException.class); + expectedException.expectMessage( + "GroupBy v1 only supports dimensions with an outputType of STRING"); + } + + // Cannot vectorize due to multi-value dimensions. + cannotVectorize(); + + GroupByQuery query = makeQueryBuilder() + .setDataSource(QueryRunnerTestHelper.DATA_SOURCE) + .setQuerySegmentSpec(QueryRunnerTestHelper.FIRST_TO_THIRD) + .setVirtualColumns(new ExpressionVirtualColumn( + "v0", + "mv_to_array(placement)", + ColumnType.STRING_ARRAY, + ExprMacroTable.nil() + )) + .setDimensions( + new DefaultDimensionSpec("v0", "alias", ColumnType.STRING_ARRAY)) + .setAggregatorSpecs( + QueryRunnerTestHelper.ROWS_COUNT, + new LongSumAggregatorFactory("idx", "index") + ) + .setGranularity(QueryRunnerTestHelper.ALL_GRAN) + .build(); + + List expectedResults = ImmutableList.of( + makeRow(query, "2011-04-01", "alias", + ComparableStringArray.of("preferred"), "rows", 26L, "idx", 12446L + ) + ); + + Iterable results = GroupByQueryRunnerTestHelper.runQuery(factory, runner, query); + TestHelper.assertExpectedObjects(expectedResults, results, "multi-value-dim-groupby-arrays"); + } + + @Test + public void testMultiValueDimensionAsArrayWithOtherDims() + { + // array types don't work with group by v1 + if (config.getDefaultStrategy().equals(GroupByStrategySelector.STRATEGY_V1)) { + expectedException.expect(UnsupportedOperationException.class); + expectedException.expectMessage( + "GroupBy v1 only supports dimensions with an outputType of STRING"); + } + + + // Cannot vectorize due to multi-value dimensions. 
+ cannotVectorize(); + + GroupByQuery query = makeQueryBuilder() + .setDataSource(QueryRunnerTestHelper.DATA_SOURCE) + .setQuerySegmentSpec(QueryRunnerTestHelper.FIRST_TO_THIRD) + .setVirtualColumns(new ExpressionVirtualColumn( + "v0", + "mv_to_array(placementish)", + ColumnType.STRING_ARRAY, + ExprMacroTable.nil() + )) + .setDimensions( + new DefaultDimensionSpec("v0", "alias", ColumnType.STRING_ARRAY), + new DefaultDimensionSpec("quality", "quality") + ) + .setLimitSpec(new DefaultLimitSpec( + ImmutableList.of(new OrderByColumnSpec( + "alias", + OrderByColumnSpec.Direction.ASCENDING, + StringComparators.LEXICOGRAPHIC + ), new OrderByColumnSpec( + "quality", + OrderByColumnSpec.Direction.ASCENDING, + StringComparators.LEXICOGRAPHIC + )), + Integer.MAX_VALUE - 1 + )) + .setAggregatorSpecs(QueryRunnerTestHelper.ROWS_COUNT, new LongSumAggregatorFactory("idx", "index")) + .setGranularity(QueryRunnerTestHelper.ALL_GRAN) + .build(); + + List expectedResults = Arrays.asList( + makeRow( + query, + "2011-04-01", + "alias", + ComparableStringArray.of("a", "preferred"), + "quality", + "automotive", + "rows", + 2L, + "idx", + 282L + ), + makeRow( + query, + "2011-04-01", + "alias", + ComparableStringArray.of("b", "preferred"), + "quality", + "business", + "rows", + 2L, + "idx", + 230L + ), + makeRow( + query, + "2011-04-01", + "alias", + ComparableStringArray.of("e", "preferred"), + "quality", + "entertainment", + "rows", + 2L, + "idx", + 324L + ), + makeRow( + query, + "2011-04-01", + "alias", + ComparableStringArray.of("h", "preferred"), + "quality", + "health", + "rows", + 2L, + "idx", + 233L + ), + makeRow( + query, + "2011-04-01", + "alias", + ComparableStringArray.of("m", "preferred"), + "quality", + "mezzanine", + "rows", + 6L, + "idx", + 5317L + ), + makeRow( + query, + "2011-04-01", + "alias", + ComparableStringArray.of("n", "preferred"), + "quality", + "news", + "rows", + 2L, + "idx", + 235L + ), + makeRow( + query, + "2011-04-01", + "alias", + 
ComparableStringArray.of("p", "preferred"), + "quality", + "premium", + "rows", + 6L, + "idx", + 5405L + ), + makeRow( + query, + "2011-04-01", + "alias", + ComparableStringArray.of("preferred", "t"), + "quality", + "technology", + "rows", + 2L, + "idx", + 175L + ), + makeRow( + query, + "2011-04-01", + "alias", + ComparableStringArray.of("preferred", "t"), + "quality", + "travel", + "rows", + 2L, + "idx", + 245L + ) + ); + + Iterable results = GroupByQueryRunnerTestHelper.runQuery(factory, runner, query); + TestHelper.assertExpectedObjects(expectedResults, results, "multi-value-dims-groupby-arrays"); + + query = makeQueryBuilder() + .setDataSource(QueryRunnerTestHelper.DATA_SOURCE) + .setQuerySegmentSpec(QueryRunnerTestHelper.FIRST_TO_THIRD) + .setVirtualColumns(new ExpressionVirtualColumn( + "v0", + "mv_to_array(placementish)", + ColumnType.STRING_ARRAY, + ExprMacroTable.nil() + )) + .setDimensions( + new DefaultDimensionSpec("v0", "alias", ColumnType.STRING_ARRAY), + new DefaultDimensionSpec("quality", "quality") + ) + .setLimitSpec(new DefaultLimitSpec( + ImmutableList.of( + new OrderByColumnSpec( + "alias", + OrderByColumnSpec.Direction.DESCENDING, + StringComparators.LEXICOGRAPHIC + ), + new OrderByColumnSpec( + "quality", + OrderByColumnSpec.Direction.DESCENDING, + StringComparators.LEXICOGRAPHIC + ) + ), + Integer.MAX_VALUE - 1 + )) + .setAggregatorSpecs(QueryRunnerTestHelper.ROWS_COUNT, new LongSumAggregatorFactory("idx", "index")) + .setGranularity(QueryRunnerTestHelper.ALL_GRAN) + .build(); + + Collections.reverse(expectedResults); + + results = GroupByQueryRunnerTestHelper.runQuery(factory, runner, query); + TestHelper.assertExpectedObjects(expectedResults, results, "multi-value-dims-groupby-arrays-descending"); + } + + @Test + public void testMultiValueDimensionAsStringArrayWithoutExpression() + { + if (config.getDefaultStrategy().equals(GroupByStrategySelector.STRATEGY_V1)) { + expectedException.expect(UnsupportedOperationException.class); + 
expectedException.expectMessage("GroupBy v1 only supports dimensions with an outputType of STRING"); + } else if (!vectorize) { + expectedException.expect(RuntimeException.class); + expectedException.expectMessage("Not supported for multi-value dimensions"); + } + + cannotVectorize(); + GroupByQuery query = makeQueryBuilder() + .setDataSource(QueryRunnerTestHelper.DATA_SOURCE) + .setQuerySegmentSpec(QueryRunnerTestHelper.FIRST_TO_THIRD) + .setDimensions( + new DefaultDimensionSpec("placementish", "alias", ColumnType.STRING_ARRAY) + ) + .setAggregatorSpecs(QueryRunnerTestHelper.ROWS_COUNT, new LongSumAggregatorFactory("idx", "index")) + .setGranularity(QueryRunnerTestHelper.ALL_GRAN) + .build(); + + GroupByQueryRunnerTestHelper.runQuery(factory, runner, query); + } + + @Test + public void testSingleValueDimensionAsStringArrayWithoutExpression() + { + if (config.getDefaultStrategy().equals(GroupByStrategySelector.STRATEGY_V1)) { + expectedException.expect(UnsupportedOperationException.class); + expectedException.expectMessage("GroupBy v1 only supports dimensions with an outputType of STRING"); + } else if (!vectorize) { + // cannot add exact class cast message due to discrepancies between various JDK versions + expectedException.expect(RuntimeException.class); + } + cannotVectorize(); + + GroupByQuery query = makeQueryBuilder() + .setDataSource(QueryRunnerTestHelper.DATA_SOURCE) + .setQuerySegmentSpec(QueryRunnerTestHelper.FIRST_TO_THIRD) + .setDimensions( + new DefaultDimensionSpec("placement", "alias", ColumnType.STRING_ARRAY) + ) + .setAggregatorSpecs(QueryRunnerTestHelper.ROWS_COUNT, new LongSumAggregatorFactory("idx", "index")) + .setGranularity(QueryRunnerTestHelper.ALL_GRAN) + .build(); + + Iterable results = GroupByQueryRunnerTestHelper.runQuery(factory, runner, query); + List expectedResults = ImmutableList.of( + makeRow( + query, + "2011-04-01", + "alias", + ComparableStringArray.of("preferred"), + "rows", + 26L, + "idx", + 12446L + )); + 
TestHelper.assertExpectedObjects( + expectedResults, + results, + "single-value-dims-groupby-arrays-as-string-arrays" + ); + } + + + @Test + public void testNumericDimAsStringArrayWithoutExpression() + { + if (config.getDefaultStrategy().equals(GroupByStrategySelector.STRATEGY_V1)) { + expectedException.expect(UnsupportedOperationException.class); + expectedException.expectMessage("GroupBy v1 only supports dimensions with an outputType of STRING"); + } else if (!vectorize) { + // cannot add exact class cast message due to discrepancies between various JDK versions + expectedException.expect(RuntimeException.class); + } + + cannotVectorize(); + GroupByQuery query = makeQueryBuilder() + .setDataSource(QueryRunnerTestHelper.DATA_SOURCE) + .setQuerySegmentSpec(QueryRunnerTestHelper.FIRST_TO_THIRD) + .setDimensions( + new DefaultDimensionSpec("index", "alias", ColumnType.STRING_ARRAY) + ) + .setAggregatorSpecs(QueryRunnerTestHelper.ROWS_COUNT, new LongSumAggregatorFactory("idx", "index")) + .setGranularity(QueryRunnerTestHelper.ALL_GRAN) + .build(); + + GroupByQueryRunnerTestHelper.runQuery(factory, runner, query); + } + + + @Test + public void testMultiValueVirtualDimAsString() + { + if (config.getDefaultStrategy().equals(GroupByStrategySelector.STRATEGY_V1)) { + expectedException.expect(UnsupportedOperationException.class); + expectedException.expectMessage("GroupBy v1 only supports dimensions with an outputType of STRING"); + } else if (!vectorize) { + // cannot add exact class cast message due to discrepancies between various JDK versions + expectedException.expect(RuntimeException.class); + } + + cannotVectorize(); + GroupByQuery query = makeQueryBuilder() + .setDataSource(QueryRunnerTestHelper.DATA_SOURCE) + .setQuerySegmentSpec(QueryRunnerTestHelper.FIRST_TO_THIRD) + .setVirtualColumns(new ExpressionVirtualColumn( + "v0", + "mv_to_array(placementish)", + ColumnType.STRING_ARRAY, + ExprMacroTable.nil() + )) + .setDimensions( + new DefaultDimensionSpec("v0", 
"alias", ColumnType.STRING) + ) + .setDimensions( + new DefaultDimensionSpec("index", "alias", ColumnType.STRING_ARRAY) + ) + .setAggregatorSpecs(QueryRunnerTestHelper.ROWS_COUNT, new LongSumAggregatorFactory("idx", "index")) + .setGranularity(QueryRunnerTestHelper.ALL_GRAN) + .build(); + + GroupByQueryRunnerTestHelper.runQuery(factory, runner, query); + } + + @Test + public void testExtractionStringSpecWithMultiValueVirtualDimAsInput() + { + if (config.getDefaultStrategy().equals(GroupByStrategySelector.STRATEGY_V1)) { + expectedException.expect(UnsupportedOperationException.class); + expectedException.expectMessage("GroupBy v1 does not support dimension selectors with unknown cardinality"); + } + cannotVectorize(); + GroupByQuery query = makeQueryBuilder() + .setDataSource(QueryRunnerTestHelper.DATA_SOURCE) + .setQuerySegmentSpec(QueryRunnerTestHelper.FIRST_TO_THIRD) + .setVirtualColumns(new ExpressionVirtualColumn( + "v0", + "mv_to_array(placementish)", + ColumnType.STRING_ARRAY, + ExprMacroTable.nil() + )) + .setDimensions( + new ExtractionDimensionSpec("v0", "alias", ColumnType.STRING, + new SubstringDimExtractionFn(1, 1) + ) + ) + + .setAggregatorSpecs(QueryRunnerTestHelper.ROWS_COUNT, new LongSumAggregatorFactory("idx", "index")) + .setGranularity(QueryRunnerTestHelper.ALL_GRAN) + .build(); + + List expectedResults = Arrays.asList( + makeRow( + query, + "2011-04-01", + "alias", + null, + "rows", + 26L, + "idx", + 12446L + ), + makeRow( + query, + "2011-04-01", + "alias", + "r", + "rows", + 26L, + "idx", + 12446L + ) + ); + + Iterable results = GroupByQueryRunnerTestHelper.runQuery(factory, runner, query); + TestHelper.assertExpectedObjects( + expectedResults, + results, + "multi-value-extraction-spec-as-string-dim-groupby-arrays" + ); + } + + + @Test + public void testExtractionStringArraySpecWithMultiValueVirtualDimAsInput() + { + if (config.getDefaultStrategy().equals(GroupByStrategySelector.STRATEGY_V1)) { + 
expectedException.expect(UnsupportedOperationException.class); + expectedException.expectMessage("GroupBy v1 only supports dimensions with an outputType of STRING"); + } else if (!vectorize) { + expectedException.expect(RuntimeException.class); + expectedException.expectMessage("Not supported for multi-value dimensions"); + } + + cannotVectorize(); + GroupByQuery query = makeQueryBuilder() + .setDataSource(QueryRunnerTestHelper.DATA_SOURCE) + .setQuerySegmentSpec(QueryRunnerTestHelper.FIRST_TO_THIRD) + .setVirtualColumns(new ExpressionVirtualColumn( + "v0", + "mv_to_array(placementish)", + ColumnType.STRING_ARRAY, + ExprMacroTable.nil() + )) + .setDimensions( + new ExtractionDimensionSpec("v0", "alias", ColumnType.STRING_ARRAY, + new SubstringDimExtractionFn(1, 1) + ) + ) + + .setAggregatorSpecs(QueryRunnerTestHelper.ROWS_COUNT, new LongSumAggregatorFactory("idx", "index")) + .setGranularity(QueryRunnerTestHelper.ALL_GRAN) + .build(); + + GroupByQueryRunnerTestHelper.runQuery(factory, runner, query); + } + + @Test + public void testVirtualColumnNumericTypeAsStringArray() + { + if (config.getDefaultStrategy().equals(GroupByStrategySelector.STRATEGY_V1)) { + expectedException.expect(UnsupportedOperationException.class); + expectedException.expectMessage("GroupBy v1 only supports dimensions with an outputType of STRING"); + } else if (!vectorize) { + // cannot add exact class cast message due to discrepancies between various JDK versions + expectedException.expect(RuntimeException.class); + } + + cannotVectorize(); + GroupByQuery query = makeQueryBuilder() + .setDataSource(QueryRunnerTestHelper.DATA_SOURCE) + .setQuerySegmentSpec(QueryRunnerTestHelper.FIRST_TO_THIRD) + .setVirtualColumns(new ExpressionVirtualColumn( + "v0", + "array(index)", + ColumnType.STRING_ARRAY, + ExprMacroTable.nil() + )) + .setDimensions( + new DefaultDimensionSpec("v0", "alias", ColumnType.STRING_ARRAY + ) + ) + + .setAggregatorSpecs(QueryRunnerTestHelper.ROWS_COUNT) + 
.setGranularity(QueryRunnerTestHelper.ALL_GRAN) + .build(); + + GroupByQueryRunnerTestHelper.runQuery(factory, runner, query); + } + + @Test + public void testNestedGroupByWithStringArray() + { + if (config.getDefaultStrategy().equals(GroupByStrategySelector.STRATEGY_V1)) { + expectedException.expect(UnsupportedOperationException.class); + expectedException.expectMessage("GroupBy v1 only supports dimensions with an outputType of STRING"); + } + cannotVectorize(); + GroupByQuery inner = makeQueryBuilder() + .setDataSource(QueryRunnerTestHelper.DATA_SOURCE) + .setQuerySegmentSpec(QueryRunnerTestHelper.FIRST_TO_THIRD) + .setVirtualColumns(new ExpressionVirtualColumn( + "v0", + "mv_to_array(placementish)", + ColumnType.STRING_ARRAY, + ExprMacroTable.nil() + )) + .setDimensions( + new DefaultDimensionSpec("v0", "alias", ColumnType.STRING_ARRAY) + ) + .setAggregatorSpecs(QueryRunnerTestHelper.ROWS_COUNT, new LongSumAggregatorFactory("idx", "index")) + .setGranularity(QueryRunnerTestHelper.ALL_GRAN) + .build(); + + GroupByQuery outer = makeQueryBuilder() + .setDataSource(new QueryDataSource(inner)) + .setQuerySegmentSpec(QueryRunnerTestHelper.FIRST_TO_THIRD) + .setDimensions( + new DefaultDimensionSpec("alias", "alias_outer", ColumnType.STRING_ARRAY + ) + ) + .setAggregatorSpecs(QueryRunnerTestHelper.ROWS_COUNT) + .setGranularity(QueryRunnerTestHelper.ALL_GRAN) + .build(); + + List expectedResults = Arrays.asList( + makeRow(outer, "2011-04-01", "alias_outer", ComparableStringArray.of("a", "preferred"), "rows", 1L), + makeRow(outer, "2011-04-01", "alias_outer", ComparableStringArray.of("b", "preferred"), "rows", 1L), + makeRow(outer, "2011-04-01", "alias_outer", ComparableStringArray.of("e", "preferred"), "rows", 1L), + makeRow(outer, "2011-04-01", "alias_outer", ComparableStringArray.of("h", "preferred"), "rows", 1L), + makeRow(outer, "2011-04-01", "alias_outer", ComparableStringArray.of("m", "preferred"), "rows", 1L), + makeRow(outer, "2011-04-01", "alias_outer", 
ComparableStringArray.of("n", "preferred"), "rows", 1L), + makeRow(outer, "2011-04-01", "alias_outer", ComparableStringArray.of("p", "preferred"), "rows", 1L), + makeRow(outer, "2011-04-01", "alias_outer", ComparableStringArray.of("preferred", "t"), "rows", 1L) + ); + + Iterable results = GroupByQueryRunnerTestHelper.runQuery(factory, runner, outer); + TestHelper.assertExpectedObjects(expectedResults, results, "multi-value-dim-nested-groupby-arrays"); + } + + @Test + public void testNestedGroupByWithLongArrays() + { + if (config.getDefaultStrategy().equals(GroupByStrategySelector.STRATEGY_V1)) { + expectedException.expect(UnsupportedOperationException.class); + expectedException.expectMessage("GroupBy v1 only supports dimensions with an outputType of STRING"); + } + cannotVectorize(); + GroupByQuery inner = makeQueryBuilder() + .setDataSource(QueryRunnerTestHelper.DATA_SOURCE) + .setQuerySegmentSpec(QueryRunnerTestHelper.FIRST_TO_THIRD) + .setVirtualColumns(new ExpressionVirtualColumn( + "v0", + "array(1,2)", + ColumnType.LONG_ARRAY, + ExprMacroTable.nil() + )) + .setDimensions( + new DefaultDimensionSpec("v0", "alias", ColumnType.LONG_ARRAY) + ) + .setAggregatorSpecs(QueryRunnerTestHelper.ROWS_COUNT) + .setGranularity(QueryRunnerTestHelper.ALL_GRAN) + .build(); + + GroupByQuery outer = makeQueryBuilder() + .setDataSource(new QueryDataSource(inner)) + .setQuerySegmentSpec(QueryRunnerTestHelper.FIRST_TO_THIRD) + .setDimensions( + new DefaultDimensionSpec("alias", "alias_outer", ColumnType.LONG_ARRAY + ) + ) + .setAggregatorSpecs(QueryRunnerTestHelper.ROWS_COUNT) + .setGranularity(QueryRunnerTestHelper.ALL_GRAN) + .build(); + + List expectedResults = ImmutableList.of( + makeRow(outer, "2011-04-01", "alias_outer", new ComparableList(ImmutableList.of(1L, 2L)), + "rows", 1L + )); + + Iterable results = GroupByQueryRunnerTestHelper.runQuery(factory, runner, outer); + TestHelper.assertExpectedObjects(expectedResults, results, "long-nested-groupby-arrays"); + } + + @Test + 
public void testGroupByWithLongArrays() + { + if (config.getDefaultStrategy().equals(GroupByStrategySelector.STRATEGY_V1)) { + expectedException.expect(UnsupportedOperationException.class); + expectedException.expectMessage("GroupBy v1 only supports dimensions with an outputType of STRING"); + } + cannotVectorize(); + GroupByQuery outer = makeQueryBuilder() + .setDataSource(QueryRunnerTestHelper.DATA_SOURCE) + .setQuerySegmentSpec(QueryRunnerTestHelper.FIRST_TO_THIRD) + .setVirtualColumns(new ExpressionVirtualColumn( + "v0", + "array(index)", + ColumnType.LONG_ARRAY, + ExprMacroTable.nil() + )) + .setDimensions( + new DefaultDimensionSpec("v0", "alias_outer", ColumnType.LONG_ARRAY) + ) + .setLimitSpec(new DefaultLimitSpec( + ImmutableList.of(new OrderByColumnSpec( + "alias_outer", + OrderByColumnSpec.Direction.ASCENDING, + StringComparators.NUMERIC + )), + Integer.MAX_VALUE + )) + .setAggregatorSpecs( + QueryRunnerTestHelper.ROWS_COUNT + ) + .setGranularity(QueryRunnerTestHelper.ALL_GRAN) + .build(); + + + List expectedResults = Arrays.asList( + makeRow(outer, "2011-04-01", "alias_outer", new ComparableList(ImmutableList.of(78.622547)), "rows", 1L), + makeRow(outer, "2011-04-01", "alias_outer", new ComparableList(ImmutableList.of(97.387433)), "rows", 1L), + makeRow(outer, "2011-04-01", "alias_outer", new ComparableList(ImmutableList.of(109.705815)), "rows", 1L), + makeRow(outer, "2011-04-01", "alias_outer", new ComparableList(ImmutableList.of(110.931934)), "rows", 1L), + makeRow(outer, "2011-04-01", "alias_outer", new ComparableList(ImmutableList.of(112.987027)), "rows", 1L), + makeRow(outer, "2011-04-01", "alias_outer", new ComparableList(ImmutableList.of(113.446008)), "rows", 1L), + makeRow(outer, "2011-04-01", "alias_outer", new ComparableList(ImmutableList.of(114.290141)), "rows", 1L), + makeRow(outer, "2011-04-01", "alias_outer", new ComparableList(ImmutableList.of(118.57034)), "rows", 1L), + makeRow(outer, "2011-04-01", "alias_outer", new 
ComparableList(ImmutableList.of(119.922742)), "rows", 1L), + makeRow(outer, "2011-04-01", "alias_outer", new ComparableList(ImmutableList.of(120.134704)), "rows", 1L), + makeRow(outer, "2011-04-01", "alias_outer", new ComparableList(ImmutableList.of(121.583581)), "rows", 1L), + makeRow(outer, "2011-04-01", "alias_outer", new ComparableList(ImmutableList.of(126.411364)), "rows", 1L), + makeRow(outer, "2011-04-01", "alias_outer", new ComparableList(ImmutableList.of(135.301506)), "rows", 1L), + makeRow(outer, "2011-04-01", "alias_outer", new ComparableList(ImmutableList.of(135.885094)), "rows", 1L), + makeRow(outer, "2011-04-01", "alias_outer", new ComparableList(ImmutableList.of(144.507368)), "rows", 1L), + makeRow(outer, "2011-04-01", "alias_outer", new ComparableList(ImmutableList.of(147.425935)), "rows", 1L), + makeRow(outer, "2011-04-01", "alias_outer", new ComparableList(ImmutableList.of(158.747224)), "rows", 1L), + makeRow(outer, "2011-04-01", "alias_outer", new ComparableList(ImmutableList.of(166.016049)), "rows", 1L), + makeRow(outer, "2011-04-01", "alias_outer", new ComparableList(ImmutableList.of(1049.738585)), "rows", 1L), + makeRow(outer, "2011-04-01", "alias_outer", new ComparableList(ImmutableList.of(1144.342401)), "rows", 1L), + makeRow(outer, "2011-04-01", "alias_outer", new ComparableList(ImmutableList.of(1193.556278)), "rows", 1L), + makeRow(outer, "2011-04-01", "alias_outer", new ComparableList(ImmutableList.of(1234.247546)), "rows", 1L), + makeRow(outer, "2011-04-01", "alias_outer", new ComparableList(ImmutableList.of(1314.839715)), "rows", 1L), + makeRow(outer, "2011-04-01", "alias_outer", new ComparableList(ImmutableList.of(1321.375057)), "rows", 1L), + makeRow(outer, "2011-04-01", "alias_outer", new ComparableList(ImmutableList.of(1447.34116)), "rows", 1L), + makeRow(outer, "2011-04-01", "alias_outer", new ComparableList(ImmutableList.of(1522.043733)), "rows", 1L) + ); + + Iterable results = GroupByQueryRunnerTestHelper.runQuery(factory, 
runner, outer); + TestHelper.assertExpectedObjects(expectedResults, results, "long-groupby-arrays"); + } + + @Test + public void testGroupByWithLongArraysDesc() + { + if (config.getDefaultStrategy().equals(GroupByStrategySelector.STRATEGY_V1)) { + expectedException.expect(UnsupportedOperationException.class); + expectedException.expectMessage("GroupBy v1 only supports dimensions with an outputType of STRING"); + } + cannotVectorize(); + GroupByQuery outer = makeQueryBuilder() + .setDataSource(QueryRunnerTestHelper.DATA_SOURCE) + .setQuerySegmentSpec(QueryRunnerTestHelper.FIRST_TO_THIRD) + .setVirtualColumns(new ExpressionVirtualColumn( + "v0", + "array(index)", + ColumnType.LONG_ARRAY, + ExprMacroTable.nil() + )) + .setDimensions( + new DefaultDimensionSpec("v0", "alias_outer", ColumnType.LONG_ARRAY) + ) + .setLimitSpec(new DefaultLimitSpec( + ImmutableList.of(new OrderByColumnSpec( + "alias_outer", + OrderByColumnSpec.Direction.DESCENDING, + StringComparators.NUMERIC + )), + Integer.MAX_VALUE - 1 + )) + .setAggregatorSpecs( + QueryRunnerTestHelper.ROWS_COUNT + ) + .setGranularity(QueryRunnerTestHelper.ALL_GRAN) + .build(); + + + List expectedResults = Arrays.asList( + makeRow(outer, "2011-04-01", "alias_outer", new ComparableList(ImmutableList.of(78.622547)), "rows", 1L), + makeRow(outer, "2011-04-01", "alias_outer", new ComparableList(ImmutableList.of(97.387433)), "rows", 1L), + makeRow(outer, "2011-04-01", "alias_outer", new ComparableList(ImmutableList.of(109.705815)), "rows", 1L), + makeRow(outer, "2011-04-01", "alias_outer", new ComparableList(ImmutableList.of(110.931934)), "rows", 1L), + makeRow(outer, "2011-04-01", "alias_outer", new ComparableList(ImmutableList.of(112.987027)), "rows", 1L), + makeRow(outer, "2011-04-01", "alias_outer", new ComparableList(ImmutableList.of(113.446008)), "rows", 1L), + makeRow(outer, "2011-04-01", "alias_outer", new ComparableList(ImmutableList.of(114.290141)), "rows", 1L), + makeRow(outer, "2011-04-01", "alias_outer", new 
ComparableList(ImmutableList.of(118.57034)), "rows", 1L), + makeRow(outer, "2011-04-01", "alias_outer", new ComparableList(ImmutableList.of(119.922742)), "rows", 1L), + makeRow(outer, "2011-04-01", "alias_outer", new ComparableList(ImmutableList.of(120.134704)), "rows", 1L), + makeRow(outer, "2011-04-01", "alias_outer", new ComparableList(ImmutableList.of(121.583581)), "rows", 1L), + makeRow(outer, "2011-04-01", "alias_outer", new ComparableList(ImmutableList.of(126.411364)), "rows", 1L), + makeRow(outer, "2011-04-01", "alias_outer", new ComparableList(ImmutableList.of(135.301506)), "rows", 1L), + makeRow(outer, "2011-04-01", "alias_outer", new ComparableList(ImmutableList.of(135.885094)), "rows", 1L), + makeRow(outer, "2011-04-01", "alias_outer", new ComparableList(ImmutableList.of(144.507368)), "rows", 1L), + makeRow(outer, "2011-04-01", "alias_outer", new ComparableList(ImmutableList.of(147.425935)), "rows", 1L), + makeRow(outer, "2011-04-01", "alias_outer", new ComparableList(ImmutableList.of(158.747224)), "rows", 1L), + makeRow(outer, "2011-04-01", "alias_outer", new ComparableList(ImmutableList.of(166.016049)), "rows", 1L), + makeRow(outer, "2011-04-01", "alias_outer", new ComparableList(ImmutableList.of(1049.738585)), "rows", 1L), + makeRow(outer, "2011-04-01", "alias_outer", new ComparableList(ImmutableList.of(1144.342401)), "rows", 1L), + makeRow(outer, "2011-04-01", "alias_outer", new ComparableList(ImmutableList.of(1193.556278)), "rows", 1L), + makeRow(outer, "2011-04-01", "alias_outer", new ComparableList(ImmutableList.of(1234.247546)), "rows", 1L), + makeRow(outer, "2011-04-01", "alias_outer", new ComparableList(ImmutableList.of(1314.839715)), "rows", 1L), + makeRow(outer, "2011-04-01", "alias_outer", new ComparableList(ImmutableList.of(1321.375057)), "rows", 1L), + makeRow(outer, "2011-04-01", "alias_outer", new ComparableList(ImmutableList.of(1447.34116)), "rows", 1L), + makeRow(outer, "2011-04-01", "alias_outer", new 
ComparableList(ImmutableList.of(1522.043733)), "rows", 1L) + ); + // reversing list + Collections.reverse(expectedResults); + Iterable results = GroupByQueryRunnerTestHelper.runQuery(factory, runner, outer); + TestHelper.assertExpectedObjects(expectedResults, results, "long-groupby-arrays"); + } + @Test public void testTwoMultiValueDimensions() { diff --git a/processing/src/test/java/org/apache/druid/query/groupby/GroupByQueryTest.java b/processing/src/test/java/org/apache/druid/query/groupby/GroupByQueryTest.java index 186a43c27af8..b80d715d0041 100644 --- a/processing/src/test/java/org/apache/druid/query/groupby/GroupByQueryTest.java +++ b/processing/src/test/java/org/apache/druid/query/groupby/GroupByQueryTest.java @@ -42,6 +42,7 @@ import org.apache.druid.query.spec.QuerySegmentSpec; import org.apache.druid.segment.TestHelper; import org.apache.druid.segment.column.ColumnType; +import org.apache.druid.segment.data.ComparableStringArray; import org.apache.druid.segment.virtual.ExpressionVirtualColumn; import org.junit.Assert; import org.junit.Test; @@ -61,7 +62,13 @@ public void testQuerySerialization() throws IOException .builder() .setDataSource(QueryRunnerTestHelper.DATA_SOURCE) .setQuerySegmentSpec(QueryRunnerTestHelper.FIRST_TO_THIRD) - .setDimensions(new DefaultDimensionSpec("quality", "alias")) + .setDimensions(new DefaultDimensionSpec(QueryRunnerTestHelper.QUALITY_DIMENSION, "alias"), + new DefaultDimensionSpec( + QueryRunnerTestHelper.MARKET_DIMENSION, + "market", + ColumnType.STRING_ARRAY + ) + ) .setAggregatorSpecs(QueryRunnerTestHelper.ROWS_COUNT, new LongSumAggregatorFactory("idx", "index")) .setGranularity(QueryRunnerTestHelper.DAY_GRAN) .setPostAggregatorSpecs(ImmutableList.of(new FieldAccessPostAggregator("x", "idx"))) @@ -120,12 +127,13 @@ public void testRowOrderingMixTypes() .addDimension(new DefaultDimensionSpec("foo", "foo", ColumnType.LONG)) .addDimension(new DefaultDimensionSpec("bar", "bar", ColumnType.FLOAT)) .addDimension(new 
DefaultDimensionSpec("baz", "baz", ColumnType.STRING)) + .addDimension(new DefaultDimensionSpec("bat", "bat", ColumnType.STRING_ARRAY)) .build(); final Ordering rowOrdering = query.getRowOrdering(false); final int compare = rowOrdering.compare( - ResultRow.of(1, 1f, "a"), - ResultRow.of(1L, 1d, "b") + ResultRow.of(1, 1f, "a", ComparableStringArray.of("1", "2")), + ResultRow.of(1L, 1d, "b", ComparableStringArray.of("3")) ); Assert.assertEquals(-1, compare); } diff --git a/processing/src/test/java/org/apache/druid/query/groupby/epinephelinae/column/ArrayDoubleGroupByColumnSelectorStrategyTest.java b/processing/src/test/java/org/apache/druid/query/groupby/epinephelinae/column/ArrayDoubleGroupByColumnSelectorStrategyTest.java new file mode 100644 index 000000000000..7b66225f0049 --- /dev/null +++ b/processing/src/test/java/org/apache/druid/query/groupby/epinephelinae/column/ArrayDoubleGroupByColumnSelectorStrategyTest.java @@ -0,0 +1,167 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.druid.query.groupby.epinephelinae.column; + +import com.google.common.collect.ImmutableList; +import it.unimi.dsi.fastutil.objects.Object2IntOpenHashMap; +import org.apache.druid.query.groupby.ResultRow; +import org.apache.druid.query.groupby.epinephelinae.Grouper; +import org.apache.druid.query.ordering.StringComparators; +import org.apache.druid.segment.ColumnValueSelector; +import org.apache.druid.segment.data.ComparableList; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.mockito.Mockito; + +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.List; + +public class ArrayDoubleGroupByColumnSelectorStrategyTest +{ + protected final List> dictionary = new ArrayList>() + { + { + add(ImmutableList.of(1.0, 2.0)); + add(ImmutableList.of(2.0, 3.0)); + add(ImmutableList.of(1.0)); + } + }; + + protected final Object2IntOpenHashMap> reverseDictionary = new Object2IntOpenHashMap>() + { + { + put(ImmutableList.of(1.0, 2.0), 0); + put(ImmutableList.of(2.0, 3.0), 1); + put(ImmutableList.of(1.0), 2); + } + }; + + private final ByteBuffer buffer1 = ByteBuffer.allocate(4); + private final ByteBuffer buffer2 = ByteBuffer.allocate(4); + + private ArrayNumericGroupByColumnSelectorStrategy strategy; + + @Before + public void setup() + { + reverseDictionary.defaultReturnValue(-1); + strategy = new ArrayDoubleGroupByColumnSelectorStrategy(dictionary, reverseDictionary); + } + + @Test + public void testKeySize() + { + Assert.assertEquals(Integer.BYTES, strategy.getGroupingKeySize()); + } + + @Test + public void testWriteKey() + { + strategy.writeToKeyBuffer(0, 1, buffer1); + Assert.assertEquals(1, buffer1.getInt(0)); + } + + @Test + public void testBufferComparatorsWithNullAndNonNullStringComprators() + { + buffer1.putInt(1); + buffer2.putInt(2); + Grouper.BufferComparator comparator = strategy.bufferComparator(0, null); + Assert.assertTrue(comparator.compare(buffer1, 
buffer2, 0, 0) > 0); + Assert.assertTrue(comparator.compare(buffer2, buffer1, 0, 0) < 0); + + comparator = strategy.bufferComparator(0, StringComparators.LEXICOGRAPHIC); + Assert.assertTrue(comparator.compare(buffer1, buffer2, 0, 0) > 0); + Assert.assertTrue(comparator.compare(buffer2, buffer1, 0, 0) < 0); + + comparator = strategy.bufferComparator(0, StringComparators.STRLEN); + Assert.assertTrue(comparator.compare(buffer1, buffer2, 0, 0) > 0); + Assert.assertTrue(comparator.compare(buffer2, buffer1, 0, 0) < 0); + } + + @Test + public void testBufferComparator() + { + buffer1.putInt(0); + buffer2.putInt(2); + Grouper.BufferComparator comparator = strategy.bufferComparator(0, null); + Assert.assertTrue(comparator.compare(buffer1, buffer2, 0, 0) > 0); + + } + + + @Test + public void testSanity() + { + ColumnValueSelector columnValueSelector = Mockito.mock(ColumnValueSelector.class); + Mockito.when(columnValueSelector.getObject()).thenReturn(ImmutableList.of(1.0, 2.0)); + Assert.assertEquals(0, strategy.getOnlyValue(columnValueSelector)); + + GroupByColumnSelectorPlus groupByColumnSelectorPlus = Mockito.mock(GroupByColumnSelectorPlus.class); + Mockito.when(groupByColumnSelectorPlus.getResultRowPosition()).thenReturn(0); + ResultRow row = ResultRow.create(1); + + buffer1.putInt(0); + strategy.processValueFromGroupingKey(groupByColumnSelectorPlus, buffer1, row, 0); + Assert.assertEquals(new ComparableList(ImmutableList.of(1.0, 2.0)), row.get(0)); + } + + + @Test + public void testAddingInDictionary() + { + ColumnValueSelector columnValueSelector = Mockito.mock(ColumnValueSelector.class); + Mockito.when(columnValueSelector.getObject()).thenReturn(ImmutableList.of(4.0, 2.0)); + Assert.assertEquals(3, strategy.getOnlyValue(columnValueSelector)); + + GroupByColumnSelectorPlus groupByColumnSelectorPlus = Mockito.mock(GroupByColumnSelectorPlus.class); + Mockito.when(groupByColumnSelectorPlus.getResultRowPosition()).thenReturn(0); + ResultRow row = ResultRow.create(1); + + 
buffer1.putInt(3); + strategy.processValueFromGroupingKey(groupByColumnSelectorPlus, buffer1, row, 0); + Assert.assertEquals(new ComparableList(ImmutableList.of(4.0, 2.0)), row.get(0)); + } + + @Test + public void testAddingInDictionaryWithObjects() + { + ColumnValueSelector columnValueSelector = Mockito.mock(ColumnValueSelector.class); + Mockito.when(columnValueSelector.getObject()).thenReturn(new Object[]{4.0D, 2.0D}); + Assert.assertEquals(3, strategy.getOnlyValue(columnValueSelector)); + + GroupByColumnSelectorPlus groupByColumnSelectorPlus = Mockito.mock(GroupByColumnSelectorPlus.class); + Mockito.when(groupByColumnSelectorPlus.getResultRowPosition()).thenReturn(0); + ResultRow row = ResultRow.create(1); + buffer1.putInt(3); + strategy.processValueFromGroupingKey(groupByColumnSelectorPlus, buffer1, row, 0); + Assert.assertEquals(new ComparableList(ImmutableList.of(4.0, 2.0)), row.get(0)); + } + + @After + public void tearDown() + { + buffer1.clear(); + buffer2.clear(); + } +} diff --git a/processing/src/test/java/org/apache/druid/query/groupby/epinephelinae/column/ArrayLongGroupByColumnSelectorStrategyTest.java b/processing/src/test/java/org/apache/druid/query/groupby/epinephelinae/column/ArrayLongGroupByColumnSelectorStrategyTest.java new file mode 100644 index 000000000000..51325b500311 --- /dev/null +++ b/processing/src/test/java/org/apache/druid/query/groupby/epinephelinae/column/ArrayLongGroupByColumnSelectorStrategyTest.java @@ -0,0 +1,170 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.query.groupby.epinephelinae.column; + +import com.google.common.collect.ImmutableList; +import it.unimi.dsi.fastutil.objects.Object2IntOpenHashMap; +import org.apache.druid.query.groupby.ResultRow; +import org.apache.druid.query.groupby.epinephelinae.Grouper; +import org.apache.druid.query.ordering.StringComparators; +import org.apache.druid.segment.ColumnValueSelector; +import org.apache.druid.segment.data.ComparableList; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.Mockito; +import org.mockito.junit.MockitoJUnitRunner; + +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.List; + +@RunWith(MockitoJUnitRunner.class) +public class ArrayLongGroupByColumnSelectorStrategyTest +{ + protected final List> dictionary = new ArrayList>() + { + { + add(ImmutableList.of(1L, 2L)); + add(ImmutableList.of(2L, 3L)); + add(ImmutableList.of(1L)); + } + }; + + protected final Object2IntOpenHashMap> reverseDictionary = new Object2IntOpenHashMap>() + { + { + put(ImmutableList.of(1L, 2L), 0); + put(ImmutableList.of(2L, 3L), 1); + put(ImmutableList.of(1L), 2); + } + }; + + private final ByteBuffer buffer1 = ByteBuffer.allocate(4); + private final ByteBuffer buffer2 = ByteBuffer.allocate(4); + + private ArrayNumericGroupByColumnSelectorStrategy strategy; + + @Before + public void setup() + { + reverseDictionary.defaultReturnValue(-1); + strategy = new ArrayLongGroupByColumnSelectorStrategy(dictionary, 
reverseDictionary); + } + + @Test + public void testKeySize() + { + Assert.assertEquals(Integer.BYTES, strategy.getGroupingKeySize()); + } + + @Test + public void testWriteKey() + { + strategy.writeToKeyBuffer(0, 1, buffer1); + Assert.assertEquals(1, buffer1.getInt(0)); + } + + @Test + public void testBufferComparatorsWithNullAndNonNullStringComprators() + { + buffer1.putInt(1); + buffer2.putInt(2); + Grouper.BufferComparator comparator = strategy.bufferComparator(0, null); + Assert.assertTrue(comparator.compare(buffer1, buffer2, 0, 0) > 0); + Assert.assertTrue(comparator.compare(buffer2, buffer1, 0, 0) < 0); + + comparator = strategy.bufferComparator(0, StringComparators.LEXICOGRAPHIC); + Assert.assertTrue(comparator.compare(buffer1, buffer2, 0, 0) > 0); + Assert.assertTrue(comparator.compare(buffer2, buffer1, 0, 0) < 0); + + comparator = strategy.bufferComparator(0, StringComparators.STRLEN); + Assert.assertTrue(comparator.compare(buffer1, buffer2, 0, 0) > 0); + Assert.assertTrue(comparator.compare(buffer2, buffer1, 0, 0) < 0); + } + + @Test + public void testBufferComparator() + { + buffer1.putInt(0); + buffer2.putInt(2); + Grouper.BufferComparator comparator = strategy.bufferComparator(0, null); + Assert.assertTrue(comparator.compare(buffer1, buffer2, 0, 0) > 0); + + } + + + @Test + public void testSanity() + { + ColumnValueSelector columnValueSelector = Mockito.mock(ColumnValueSelector.class); + Mockito.when(columnValueSelector.getObject()).thenReturn(ImmutableList.of(1L, 2L)); + Assert.assertEquals(0, strategy.getOnlyValue(columnValueSelector)); + + GroupByColumnSelectorPlus groupByColumnSelectorPlus = Mockito.mock(GroupByColumnSelectorPlus.class); + Mockito.when(groupByColumnSelectorPlus.getResultRowPosition()).thenReturn(0); + ResultRow row = ResultRow.create(1); + + buffer1.putInt(0); + strategy.processValueFromGroupingKey(groupByColumnSelectorPlus, buffer1, row, 0); + Assert.assertEquals(new ComparableList(ImmutableList.of(1L, 2L)), row.get(0)); + } + + + 
@Test + public void testAddingInDictionary() + { + ColumnValueSelector columnValueSelector = Mockito.mock(ColumnValueSelector.class); + Mockito.when(columnValueSelector.getObject()).thenReturn(ImmutableList.of(4L, 2L)); + Assert.assertEquals(3, strategy.getOnlyValue(columnValueSelector)); + + GroupByColumnSelectorPlus groupByColumnSelectorPlus = Mockito.mock(GroupByColumnSelectorPlus.class); + Mockito.when(groupByColumnSelectorPlus.getResultRowPosition()).thenReturn(0); + ResultRow row = ResultRow.create(1); + + buffer1.putInt(3); + strategy.processValueFromGroupingKey(groupByColumnSelectorPlus, buffer1, row, 0); + Assert.assertEquals(new ComparableList(ImmutableList.of(4L, 2L)), row.get(0)); + } + + @Test + public void testAddingInDictionaryWithObjects() + { + ColumnValueSelector columnValueSelector = Mockito.mock(ColumnValueSelector.class); + Mockito.when(columnValueSelector.getObject()).thenReturn(new Object[]{4L, 2L}); + Assert.assertEquals(3, strategy.getOnlyValue(columnValueSelector)); + + GroupByColumnSelectorPlus groupByColumnSelectorPlus = Mockito.mock(GroupByColumnSelectorPlus.class); + Mockito.when(groupByColumnSelectorPlus.getResultRowPosition()).thenReturn(0); + ResultRow row = ResultRow.create(1); + buffer1.putInt(3); + strategy.processValueFromGroupingKey(groupByColumnSelectorPlus, buffer1, row, 0); + Assert.assertEquals(new ComparableList(ImmutableList.of(4L, 2L)), row.get(0)); + } + + @After + public void tearDown() + { + buffer1.clear(); + buffer2.clear(); + } +} diff --git a/processing/src/test/java/org/apache/druid/query/groupby/epinephelinae/column/ArrayStringGroupByColumnSelectorStrategyTest.java b/processing/src/test/java/org/apache/druid/query/groupby/epinephelinae/column/ArrayStringGroupByColumnSelectorStrategyTest.java new file mode 100644 index 000000000000..a1d70b54fc37 --- /dev/null +++ b/processing/src/test/java/org/apache/druid/query/groupby/epinephelinae/column/ArrayStringGroupByColumnSelectorStrategyTest.java @@ -0,0 +1,180 @@ +/* + 
* Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.query.groupby.epinephelinae.column; + +import com.google.common.collect.BiMap; +import com.google.common.collect.HashBiMap; +import com.google.common.collect.ImmutableList; +import org.apache.druid.query.groupby.ResultRow; +import org.apache.druid.query.groupby.epinephelinae.Grouper; +import org.apache.druid.query.ordering.StringComparators; +import org.apache.druid.segment.ColumnValueSelector; +import org.apache.druid.segment.data.ComparableIntArray; +import org.apache.druid.segment.data.ComparableStringArray; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.Mockito; +import org.mockito.junit.MockitoJUnitRunner; + +import java.nio.ByteBuffer; +import java.util.HashMap; + +@RunWith(MockitoJUnitRunner.class) +public class ArrayStringGroupByColumnSelectorStrategyTest +{ + + private final BiMap DICTIONARY_INT = HashBiMap.create(new HashMap() + { + { + put("a", 0); + put("b", 1); + put("bd", 2); + put("d", 3); + put("e", 4); + } + }); + + // The dictionary has been constructed such that the values are not sorted lexicographically + // so 
we can tell when the comparator uses a lexicographic comparison and when it uses the indexes. + private final BiMap INDEXED_INTARRAYS = HashBiMap.create( + new HashMap() + { + { + put(ComparableIntArray.of(0, 1), 0); + put(ComparableIntArray.of(2, 4), 1); + put(ComparableIntArray.of(0, 2), 2); + } + } + ); + + + private final ByteBuffer buffer1 = ByteBuffer.allocate(4); + private final ByteBuffer buffer2 = ByteBuffer.allocate(4); + + private ArrayStringGroupByColumnSelectorStrategy strategy; + + @Before + public void setup() + { + strategy = new ArrayStringGroupByColumnSelectorStrategy(DICTIONARY_INT, INDEXED_INTARRAYS); + } + + @Test + public void testKeySize() + { + Assert.assertEquals(Integer.BYTES, strategy.getGroupingKeySize()); + } + + @Test + public void testWriteKey() + { + strategy.writeToKeyBuffer(0, 1, buffer1); + Assert.assertEquals(1, buffer1.getInt(0)); + } + + @Test + public void testBufferComparatorCanCompareIntsAndNullStringComparatorShouldUseLexicographicComparator() + { + buffer1.putInt(1); + buffer2.putInt(2); + Grouper.BufferComparator comparator = strategy.bufferComparator(0, null); + Assert.assertTrue(comparator.compare(buffer1, buffer2, 0, 0) > 0); + Assert.assertTrue(comparator.compare(buffer2, buffer1, 0, 0) < 0); + } + + @Test + public void testBufferComparatorCanCompareIntsAndLexicographicStringComparatorShouldUseLexicographicComparator() + { + buffer1.putInt(1); + buffer2.putInt(2); + Grouper.BufferComparator comparator = strategy.bufferComparator(0, StringComparators.LEXICOGRAPHIC); + Assert.assertTrue(comparator.compare(buffer1, buffer2, 0, 0) > 0); + Assert.assertTrue(comparator.compare(buffer2, buffer1, 0, 0) < 0); + } + + @Test + public void testBufferComparatorCanCompareIntsAndStrLenStringComparatorShouldUseLexicographicComparator() + { + buffer1.putInt(1); + buffer2.putInt(2); + Grouper.BufferComparator comparator = strategy.bufferComparator(0, StringComparators.STRLEN); + Assert.assertTrue(comparator.compare(buffer1, buffer2, 0, 
0) > 0); + Assert.assertTrue(comparator.compare(buffer2, buffer1, 0, 0) < 0); + } + + @Test + public void testSanity() + { + ColumnValueSelector columnValueSelector = Mockito.mock(ColumnValueSelector.class); + Mockito.when(columnValueSelector.getObject()).thenReturn(ImmutableList.of("a", "b")); + Assert.assertEquals(0, strategy.getOnlyValue(columnValueSelector)); + + GroupByColumnSelectorPlus groupByColumnSelectorPlus = Mockito.mock(GroupByColumnSelectorPlus.class); + Mockito.when(groupByColumnSelectorPlus.getResultRowPosition()).thenReturn(0); + ResultRow row = ResultRow.create(1); + + buffer1.putInt(0); + strategy.processValueFromGroupingKey(groupByColumnSelectorPlus, buffer1, row, 0); + Assert.assertEquals(ComparableStringArray.of("a", "b"), row.get(0)); + } + + + @Test + public void testAddingInDictionary() + { + ColumnValueSelector columnValueSelector = Mockito.mock(ColumnValueSelector.class); + Mockito.when(columnValueSelector.getObject()).thenReturn(ImmutableList.of("f", "a")); + Assert.assertEquals(3, strategy.getOnlyValue(columnValueSelector)); + + GroupByColumnSelectorPlus groupByColumnSelectorPlus = Mockito.mock(GroupByColumnSelectorPlus.class); + Mockito.when(groupByColumnSelectorPlus.getResultRowPosition()).thenReturn(0); + ResultRow row = ResultRow.create(1); + + buffer1.putInt(3); + strategy.processValueFromGroupingKey(groupByColumnSelectorPlus, buffer1, row, 0); + Assert.assertEquals(ComparableStringArray.of("f", "a"), row.get(0)); + } + + @Test + public void testAddingInDictionaryWithObjects() + { + ColumnValueSelector columnValueSelector = Mockito.mock(ColumnValueSelector.class); + Mockito.when(columnValueSelector.getObject()).thenReturn(new Object[]{"f", "a"}); + Assert.assertEquals(3, strategy.getOnlyValue(columnValueSelector)); + + GroupByColumnSelectorPlus groupByColumnSelectorPlus = Mockito.mock(GroupByColumnSelectorPlus.class); + Mockito.when(groupByColumnSelectorPlus.getResultRowPosition()).thenReturn(0); + ResultRow row = 
ResultRow.create(1); + + buffer1.putInt(3); + strategy.processValueFromGroupingKey(groupByColumnSelectorPlus, buffer1, row, 0); + Assert.assertEquals(ComparableStringArray.of("f", "a"), row.get(0)); + } + + @After + public void tearDown() + { + buffer1.clear(); + buffer2.clear(); + } +} diff --git a/processing/src/test/java/org/apache/druid/segment/TestHelper.java b/processing/src/test/java/org/apache/druid/segment/TestHelper.java index 40dac735d679..0d6345cbd8cc 100644 --- a/processing/src/test/java/org/apache/druid/segment/TestHelper.java +++ b/processing/src/test/java/org/apache/druid/segment/TestHelper.java @@ -40,6 +40,8 @@ import org.apache.druid.query.timeseries.TimeseriesResultValue; import org.apache.druid.query.topn.TopNResultValue; import org.apache.druid.segment.column.ColumnConfig; +import org.apache.druid.segment.data.ComparableList; +import org.apache.druid.segment.data.ComparableStringArray; import org.apache.druid.segment.writeout.SegmentWriteOutMediumFactory; import org.apache.druid.timeline.DataSegment.PruneSpecsHolder; import org.junit.Assert; @@ -372,7 +374,7 @@ private static void assertRow(String msg, Row expected, Row actual) } } - private static void assertRow(String msg, ResultRow expected, ResultRow actual) + public static void assertRow(String msg, ResultRow expected, ResultRow actual) { Assert.assertEquals( StringUtils.format("%s: row length", msg), @@ -408,6 +410,16 @@ private static void assertRow(String msg, ResultRow expected, ResultRow actual) ((Number) actualValue).doubleValue(), Math.abs(((Number) expectedValue).doubleValue() * 1e-6) ); + } else if (expectedValue instanceof ComparableStringArray && actualValue instanceof List) { + Assert.assertArrayEquals( + ((ComparableStringArray) expectedValue).getDelegate(), + ExprEval.coerceListToArray((List) actualValue, true).rhs + ); + } else if (expectedValue instanceof ComparableList && actualValue instanceof List) { + Assert.assertEquals( + ((ComparableList) 
expectedValue).getDelegate(), + (List) actualValue + ); } else { Assert.assertEquals( message, diff --git a/processing/src/test/java/org/apache/druid/segment/data/ComparableIntArrayTest.java b/processing/src/test/java/org/apache/druid/segment/data/ComparableIntArrayTest.java new file mode 100644 index 000000000000..cfc4e34440f2 --- /dev/null +++ b/processing/src/test/java/org/apache/druid/segment/data/ComparableIntArrayTest.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.druid.segment.data; + +import org.junit.Assert; +import org.junit.Test; + +import java.util.Arrays; +import java.util.HashSet; +import java.util.Set; + +public class ComparableIntArrayTest +{ + private final int[] array = new int[]{1, 2, 3}; + private final ComparableIntArray comparableIntArray = ComparableIntArray.of(1, 2, 3); + + @Test + public void testDelegate() + { + Assert.assertArrayEquals(array, comparableIntArray.getDelegate()); + Assert.assertEquals(0, ComparableIntArray.of(new int[0]).getDelegate().length); + Assert.assertEquals(0, ComparableIntArray.of().getDelegate().length); + } + + @Test + public void testHashCode() + { + Assert.assertEquals(Arrays.hashCode(array), comparableIntArray.hashCode()); + Set set = new HashSet<>(); + set.add(comparableIntArray); + set.add(ComparableIntArray.of(array)); + Assert.assertEquals(1, set.size()); + } + + @Test + public void testEquals() + { + Assert.assertTrue(comparableIntArray.equals(ComparableIntArray.of(array))); + Assert.assertFalse(comparableIntArray.equals(ComparableIntArray.of(1, 2, 5))); + Assert.assertFalse(comparableIntArray.equals(ComparableIntArray.EMPTY_ARRAY)); + Assert.assertFalse(comparableIntArray.equals(null)); + } + + @Test + public void testCompareTo() + { + Assert.assertEquals(0, comparableIntArray.compareTo(ComparableIntArray.of(array))); + Assert.assertEquals(1, comparableIntArray.compareTo(null)); + Assert.assertEquals(1, comparableIntArray.compareTo(ComparableIntArray.of(1, 2))); + Assert.assertEquals(-1, comparableIntArray.compareTo(ComparableIntArray.of(1, 2, 3, 4))); + Assert.assertTrue(comparableIntArray.compareTo(ComparableIntArray.of(2)) < 0); + } +} diff --git a/processing/src/test/java/org/apache/druid/segment/data/ComparableListTest.java b/processing/src/test/java/org/apache/druid/segment/data/ComparableListTest.java new file mode 100644 index 000000000000..89a8dff12765 --- /dev/null +++ 
b/processing/src/test/java/org/apache/druid/segment/data/ComparableListTest.java @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.segment.data; + +import com.google.common.collect.ImmutableList; +import org.junit.Assert; +import org.junit.Test; + +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Set; + +public class ComparableListTest +{ + private final List integers = ImmutableList.of(1, 2, 3); + private final ComparableList comparableList = new ComparableList(ImmutableList.of(1, 2, 3)); + + @Test + public void testDelegate() + { + Assert.assertEquals(integers, comparableList.getDelegate()); + Assert.assertEquals(0, new ComparableList(ImmutableList.of()).getDelegate().size()); + } + + @Test + public void testHashCode() + { + Assert.assertEquals(integers.hashCode(), comparableList.hashCode()); + Set set = new HashSet<>(); + set.add(comparableList); + set.add(new ComparableList(integers)); + Assert.assertEquals(1, set.size()); + } + + @Test + public void testEquals() + { + Assert.assertTrue(comparableList.equals(new ComparableList(integers))); + Assert.assertFalse(comparableList.equals(new 
ComparableList(ImmutableList.of(1, 2, 5)))); + Assert.assertFalse(comparableList.equals(null)); + } + + @Test + public void testCompareTo() + { + Assert.assertEquals(0, comparableList.compareTo(new ComparableList(integers))); + Assert.assertEquals(1, comparableList.compareTo(null)); + Assert.assertEquals(1, comparableList.compareTo(new ComparableList(ImmutableList.of(1, 2)))); + Assert.assertEquals(-1, comparableList.compareTo(new ComparableList(ImmutableList.of(1, 2, 3, 4)))); + Assert.assertTrue(comparableList.compareTo(new ComparableList(ImmutableList.of(2))) < 0); + ComparableList nullList = new ComparableList(new ArrayList() + { + { + add(null); + add(1); + } + }); + + Assert.assertTrue(comparableList.compareTo(nullList) > 0); + Assert.assertTrue(nullList.compareTo(comparableList) < 0); + Assert.assertTrue(nullList.compareTo(new ComparableList(new ArrayList() + { + { + add(null); + add(1); + } + })) == 0); + } +} diff --git a/processing/src/test/java/org/apache/druid/segment/data/ComparableStringArrayTest.java b/processing/src/test/java/org/apache/druid/segment/data/ComparableStringArrayTest.java new file mode 100644 index 000000000000..a33ad54901a8 --- /dev/null +++ b/processing/src/test/java/org/apache/druid/segment/data/ComparableStringArrayTest.java @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.segment.data; + +import org.junit.Assert; +import org.junit.Test; + +import java.util.Arrays; +import java.util.HashSet; +import java.util.Set; + +public class ComparableStringArrayTest +{ + private final String[] array = new String[]{"a", "b", "c"}; + private final ComparableStringArray comparableStringArray = ComparableStringArray.of("a", "b", "c"); + + @Test + public void testDelegate() + { + Assert.assertArrayEquals(array, comparableStringArray.getDelegate()); + Assert.assertEquals(0, ComparableStringArray.of(new String[0]).getDelegate().length); + Assert.assertEquals(0, ComparableStringArray.of().getDelegate().length); + } + + @Test + public void testHashCode() + { + Assert.assertEquals(Arrays.hashCode(array), comparableStringArray.hashCode()); + Set set = new HashSet<>(); + set.add(comparableStringArray); + set.add(ComparableStringArray.of(array)); + Assert.assertEquals(1, set.size()); + } + + @Test + public void testEquals() + { + Assert.assertTrue(comparableStringArray.equals(ComparableStringArray.of(array))); + Assert.assertFalse(comparableStringArray.equals(ComparableStringArray.of("a", "b", "C"))); + Assert.assertFalse(comparableStringArray.equals(ComparableStringArray.EMPTY_ARRAY)); + Assert.assertFalse(comparableStringArray.equals(null)); + } + + @Test + public void testCompareTo() + { + Assert.assertEquals(0, comparableStringArray.compareTo(ComparableStringArray.of(array))); + Assert.assertEquals(1, comparableStringArray.compareTo(null)); + Assert.assertEquals(1, 
comparableStringArray.compareTo(ComparableStringArray.of("a", "b"))); + Assert.assertEquals(-1, comparableStringArray.compareTo(ComparableStringArray.of("a", "b", "c", "d"))); + Assert.assertTrue(comparableStringArray.compareTo(ComparableStringArray.of("b")) < 0); + + ComparableStringArray nullList = ComparableStringArray.of(null, "a"); + + Assert.assertTrue(comparableStringArray.compareTo(nullList) > 0); + Assert.assertTrue(nullList.compareTo(comparableStringArray) < 0); + Assert.assertTrue(nullList.compareTo(ComparableStringArray.of(null, "a")) == 0); + } +} diff --git a/processing/src/test/java/org/apache/druid/segment/virtual/ExpressionSelectorsTest.java b/processing/src/test/java/org/apache/druid/segment/virtual/ExpressionSelectorsTest.java index e1a9f27ded99..56d6d2d7ee1e 100644 --- a/processing/src/test/java/org/apache/druid/segment/virtual/ExpressionSelectorsTest.java +++ b/processing/src/test/java/org/apache/druid/segment/virtual/ExpressionSelectorsTest.java @@ -171,7 +171,8 @@ public void test_supplierFromObjectSelector_onObject() { final SettableSupplier settableSupplier = new SettableSupplier<>(); final Supplier supplier = ExpressionSelectors.supplierFromObjectSelector( - objectSelectorFromSupplier(settableSupplier, Object.class) + objectSelectorFromSupplier(settableSupplier, Object.class), + true ); Assert.assertNotNull(supplier); @@ -195,7 +196,8 @@ public void test_supplierFromObjectSelector_onNumber() { final SettableSupplier settableSupplier = new SettableSupplier<>(); final Supplier supplier = ExpressionSelectors.supplierFromObjectSelector( - objectSelectorFromSupplier(settableSupplier, Number.class) + objectSelectorFromSupplier(settableSupplier, Number.class), + true ); @@ -214,7 +216,8 @@ public void test_supplierFromObjectSelector_onString() { final SettableSupplier settableSupplier = new SettableSupplier<>(); final Supplier supplier = ExpressionSelectors.supplierFromObjectSelector( - objectSelectorFromSupplier(settableSupplier, String.class) + 
objectSelectorFromSupplier(settableSupplier, String.class), + true ); Assert.assertNotNull(supplier); @@ -232,7 +235,8 @@ public void test_supplierFromObjectSelector_onList() { final SettableSupplier settableSupplier = new SettableSupplier<>(); final Supplier supplier = ExpressionSelectors.supplierFromObjectSelector( - objectSelectorFromSupplier(settableSupplier, List.class) + objectSelectorFromSupplier(settableSupplier, List.class), + true ); Assert.assertNotNull(supplier); diff --git a/server/src/test/java/org/apache/druid/server/ClientQuerySegmentWalkerTest.java b/server/src/test/java/org/apache/druid/server/ClientQuerySegmentWalkerTest.java index 1ad43824f6ec..a912ac6fb397 100644 --- a/server/src/test/java/org/apache/druid/server/ClientQuerySegmentWalkerTest.java +++ b/server/src/test/java/org/apache/druid/server/ClientQuerySegmentWalkerTest.java @@ -71,6 +71,8 @@ import org.apache.druid.segment.TestHelper; import org.apache.druid.segment.column.ColumnType; import org.apache.druid.segment.column.RowSignature; +import org.apache.druid.segment.data.ComparableList; +import org.apache.druid.segment.data.ComparableStringArray; import org.apache.druid.segment.join.InlineJoinableFactory; import org.apache.druid.segment.join.JoinConditionAnalysis; import org.apache.druid.segment.join.JoinType; @@ -807,19 +809,48 @@ public void testGroupByOnArraysDoubles() .setDataSource(ARRAY) .setGranularity(Granularities.ALL) .setInterval(Collections.singletonList(INTERVAL)) - .setDimensions(DefaultDimensionSpec.of("ad")) + .setDimensions( + new DefaultDimensionSpec( + "ad", + "ad", + ColumnType.DOUBLE_ARRAY + )) .build() .withId(DUMMY_QUERY_ID); + testQuery( + query, + ImmutableList.of(ExpectedQuery.cluster(query)), + ImmutableList.of( + new Object[]{new ComparableList(ImmutableList.of(1.0, 2.0))}, + new Object[]{new ComparableList(ImmutableList.of(2.0, 4.0))}, + new Object[]{new ComparableList(ImmutableList.of(3.0, 6.0))}, + new Object[]{new ComparableList(ImmutableList.of(4.0, 
8.0))} + ) + ); + } - // group by cannot handle true array types, expect this, RuntimeExeception with IAE in stack trace - expectedException.expect(RuntimeException.class); - expectedException.expectMessage("Cannot create query type helper from invalid type [ARRAY]"); + @Test + public void testGroupByOnArraysDoublesAsString() + { + final GroupByQuery query = + (GroupByQuery) GroupByQuery.builder() + .setDataSource(ARRAY) + .setGranularity(Granularities.ALL) + .setInterval(Collections.singletonList(INTERVAL)) + .setDimensions(DefaultDimensionSpec.of("ad")) + .build() + .withId(DUMMY_QUERY_ID); testQuery( query, ImmutableList.of(ExpectedQuery.cluster(query)), - ImmutableList.of() + ImmutableList.of( + new Object[]{new ComparableList(ImmutableList.of(1.0, 2.0)).toString()}, + new Object[]{new ComparableList(ImmutableList.of(2.0, 4.0)).toString()}, + new Object[]{new ComparableList(ImmutableList.of(3.0, 6.0)).toString()}, + new Object[]{new ComparableList(ImmutableList.of(4.0, 8.0)).toString()} + ) ); } @@ -865,18 +896,49 @@ public void testGroupByOnArraysLongs() .setDataSource(ARRAY) .setGranularity(Granularities.ALL) .setInterval(Collections.singletonList(INTERVAL)) - .setDimensions(DefaultDimensionSpec.of("al")) + .setDimensions(new DefaultDimensionSpec( + "al", + "al", + ColumnType.LONG_ARRAY + )) .build() .withId(DUMMY_QUERY_ID); - // group by cannot handle true array types, expect this, RuntimeExeception with IAE in stack trace - expectedException.expect(RuntimeException.class); - expectedException.expectMessage("Cannot create query type helper from invalid type [ARRAY]"); testQuery( query, ImmutableList.of(ExpectedQuery.cluster(query)), - ImmutableList.of() + ImmutableList.of( + new Object[]{new ComparableList(ImmutableList.of(1L, 2L))}, + new Object[]{new ComparableList(ImmutableList.of(2L, 4L))}, + new Object[]{new ComparableList(ImmutableList.of(3L, 6L))}, + new Object[]{new ComparableList(ImmutableList.of(4L, 8L))} + ) + ); + } + + @Test + public void 
testGroupByOnArraysLongsAsString() + { + final GroupByQuery query = + (GroupByQuery) GroupByQuery.builder() + .setDataSource(ARRAY) + .setGranularity(Granularities.ALL) + .setInterval(Collections.singletonList(INTERVAL)) + .setDimensions(DefaultDimensionSpec.of("al")) + .build() + .withId(DUMMY_QUERY_ID); + + // when we donot define an outputType, convert {@link ComparableList} to a string + testQuery( + query, + ImmutableList.of(ExpectedQuery.cluster(query)), + ImmutableList.of( + new Object[]{new ComparableList(ImmutableList.of(1L, 2L)).toString()}, + new Object[]{new ComparableList(ImmutableList.of(2L, 4L)).toString()}, + new Object[]{new ComparableList(ImmutableList.of(3L, 6L)).toString()}, + new Object[]{new ComparableList(ImmutableList.of(4L, 8L)).toString()} + ) ); } @@ -922,22 +984,47 @@ public void testGroupByOnArraysStrings() .setDataSource(ARRAY) .setGranularity(Granularities.ALL) .setInterval(Collections.singletonList(INTERVAL)) - .setDimensions(DefaultDimensionSpec.of("as")) + .setDimensions(new DefaultDimensionSpec("as", "as", ColumnType.STRING_ARRAY)) .build() .withId(DUMMY_QUERY_ID); + testQuery( + query, + ImmutableList.of(ExpectedQuery.cluster(query)), + ImmutableList.of( + new Object[]{ComparableStringArray.of("1.0", "2.0")}, + new Object[]{ComparableStringArray.of("2.0", "4.0")}, + new Object[]{ComparableStringArray.of("3.0", "6.0")}, + new Object[]{ComparableStringArray.of("4.0", "8.0")} + ) + ); + } - // group by cannot handle true array types, expect this, RuntimeExeception with IAE in stack trace - expectedException.expect(RuntimeException.class); - expectedException.expectMessage("Cannot create query type helper from invalid type [ARRAY]"); + @Test + public void testGroupByOnArraysStringsasString() + { + final GroupByQuery query = + (GroupByQuery) GroupByQuery.builder() + .setDataSource(ARRAY) + .setGranularity(Granularities.ALL) + .setInterval(Collections.singletonList(INTERVAL)) + .setDimensions(DefaultDimensionSpec.of("as")) + .build() + 
.withId(DUMMY_QUERY_ID); testQuery( query, ImmutableList.of(ExpectedQuery.cluster(query)), - ImmutableList.of() + ImmutableList.of( + new Object[]{ComparableStringArray.of("1.0", "2.0").toString()}, + new Object[]{ComparableStringArray.of("2.0", "4.0").toString()}, + new Object[]{ComparableStringArray.of("3.0", "6.0").toString()}, + new Object[]{ComparableStringArray.of("4.0", "8.0").toString()} + ) ); } + @Test public void testGroupByOnArraysUnknownStrings() { diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/ArraySqlAggregator.java b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/ArraySqlAggregator.java index e2213a39096f..988adbb59376 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/ArraySqlAggregator.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/ArraySqlAggregator.java @@ -54,7 +54,6 @@ import org.apache.druid.sql.calcite.table.RowSignatures; import javax.annotation.Nullable; - import java.util.List; import java.util.Objects; import java.util.stream.Collectors; diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ArrayAppendOperatorConversion.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ArrayAppendOperatorConversion.java index 0c1aee2060d7..115a528adef6 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ArrayAppendOperatorConversion.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ArrayAppendOperatorConversion.java @@ -24,12 +24,12 @@ import org.apache.calcite.sql.SqlFunctionCategory; import org.apache.calcite.sql.SqlOperator; import org.apache.calcite.sql.type.OperandTypes; -import org.apache.calcite.sql.type.ReturnTypes; import org.apache.calcite.sql.type.SqlTypeFamily; import org.apache.druid.segment.column.RowSignature; import org.apache.druid.sql.calcite.expression.DruidExpression; import 
org.apache.druid.sql.calcite.expression.OperatorConversions; import org.apache.druid.sql.calcite.expression.SqlOperatorConversion; +import org.apache.druid.sql.calcite.planner.Calcites; import org.apache.druid.sql.calcite.planner.PlannerContext; public class ArrayAppendOperatorConversion implements SqlOperatorConversion @@ -50,7 +50,7 @@ public class ArrayAppendOperatorConversion implements SqlOperatorConversion ) ) .functionCategory(SqlFunctionCategory.STRING) - .returnTypeInference(ReturnTypes.ARG0_NULLABLE) + .returnTypeInference(Calcites.ARG0_NULLABLE_ARRAY_RETURN_TYPE_INFERENCE) .build(); @Override diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ArrayConcatOperatorConversion.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ArrayConcatOperatorConversion.java index 0f47c3d02b62..09aca9ef882d 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ArrayConcatOperatorConversion.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ArrayConcatOperatorConversion.java @@ -24,12 +24,12 @@ import org.apache.calcite.sql.SqlFunctionCategory; import org.apache.calcite.sql.SqlOperator; import org.apache.calcite.sql.type.OperandTypes; -import org.apache.calcite.sql.type.ReturnTypes; import org.apache.calcite.sql.type.SqlTypeFamily; import org.apache.druid.segment.column.RowSignature; import org.apache.druid.sql.calcite.expression.DruidExpression; import org.apache.druid.sql.calcite.expression.OperatorConversions; import org.apache.druid.sql.calcite.expression.SqlOperatorConversion; +import org.apache.druid.sql.calcite.planner.Calcites; import org.apache.druid.sql.calcite.planner.PlannerContext; public class ArrayConcatOperatorConversion implements SqlOperatorConversion @@ -50,7 +50,7 @@ public class ArrayConcatOperatorConversion implements SqlOperatorConversion ) ) .functionCategory(SqlFunctionCategory.STRING) - .returnTypeInference(ReturnTypes.ARG0_NULLABLE) + 
.returnTypeInference(Calcites.ARG0_NULLABLE_ARRAY_RETURN_TYPE_INFERENCE) .build(); @Override diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ArrayPrependOperatorConversion.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ArrayPrependOperatorConversion.java index 06f270d6d52b..34eee1000c1a 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ArrayPrependOperatorConversion.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ArrayPrependOperatorConversion.java @@ -24,12 +24,12 @@ import org.apache.calcite.sql.SqlFunctionCategory; import org.apache.calcite.sql.SqlOperator; import org.apache.calcite.sql.type.OperandTypes; -import org.apache.calcite.sql.type.ReturnTypes; import org.apache.calcite.sql.type.SqlTypeFamily; import org.apache.druid.segment.column.RowSignature; import org.apache.druid.sql.calcite.expression.DruidExpression; import org.apache.druid.sql.calcite.expression.OperatorConversions; import org.apache.druid.sql.calcite.expression.SqlOperatorConversion; +import org.apache.druid.sql.calcite.planner.Calcites; import org.apache.druid.sql.calcite.planner.PlannerContext; public class ArrayPrependOperatorConversion implements SqlOperatorConversion @@ -50,7 +50,7 @@ public class ArrayPrependOperatorConversion implements SqlOperatorConversion ) ) .functionCategory(SqlFunctionCategory.STRING) - .returnTypeInference(ReturnTypes.ARG1_NULLABLE) + .returnTypeInference(Calcites.ARG1_NULLABLE_ARRAY_RETURN_TYPE_INFERENCE) .build(); @Override diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ArraySliceOperatorConversion.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ArraySliceOperatorConversion.java index 49b670151775..0f4375f21ae2 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ArraySliceOperatorConversion.java +++ 
b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ArraySliceOperatorConversion.java @@ -24,12 +24,12 @@ import org.apache.calcite.sql.SqlFunctionCategory; import org.apache.calcite.sql.SqlOperator; import org.apache.calcite.sql.type.OperandTypes; -import org.apache.calcite.sql.type.ReturnTypes; import org.apache.calcite.sql.type.SqlTypeFamily; import org.apache.druid.segment.column.RowSignature; import org.apache.druid.sql.calcite.expression.DruidExpression; import org.apache.druid.sql.calcite.expression.OperatorConversions; import org.apache.druid.sql.calcite.expression.SqlOperatorConversion; +import org.apache.druid.sql.calcite.planner.Calcites; import org.apache.druid.sql.calcite.planner.PlannerContext; public class ArraySliceOperatorConversion implements SqlOperatorConversion @@ -58,7 +58,7 @@ public class ArraySliceOperatorConversion implements SqlOperatorConversion ) ) .functionCategory(SqlFunctionCategory.STRING) - .returnTypeInference(ReturnTypes.ARG0_NULLABLE) + .returnTypeInference(Calcites.ARG0_NULLABLE_ARRAY_RETURN_TYPE_INFERENCE) .build(); @Override diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/MultiValueStringToArrayOperatorConversion.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/MultiValueStringToArrayOperatorConversion.java new file mode 100644 index 000000000000..975f16ce8b44 --- /dev/null +++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/MultiValueStringToArrayOperatorConversion.java @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.sql.calcite.expression.builtin; + +import org.apache.calcite.rex.RexNode; +import org.apache.calcite.sql.SqlFunction; +import org.apache.calcite.sql.SqlFunctionCategory; +import org.apache.calcite.sql.SqlOperator; +import org.apache.calcite.sql.type.OperandTypes; +import org.apache.calcite.sql.type.SqlTypeFamily; +import org.apache.calcite.sql.type.SqlTypeName; +import org.apache.druid.segment.column.RowSignature; +import org.apache.druid.sql.calcite.expression.DruidExpression; +import org.apache.druid.sql.calcite.expression.OperatorConversions; +import org.apache.druid.sql.calcite.expression.SqlOperatorConversion; +import org.apache.druid.sql.calcite.planner.PlannerContext; + +import javax.annotation.Nullable; + +/** + * Function that converts a String or a Multi Value direct column to an array. + * Input expressions are not supported as one should use the array function for such cases. 
+ **/ + +public class MultiValueStringToArrayOperatorConversion implements SqlOperatorConversion +{ + private static final SqlFunction SQL_FUNCTION = OperatorConversions + .operatorBuilder("MV_TO_ARRAY") + .operandTypeChecker( + OperandTypes.family(SqlTypeFamily.STRING) + ) + .functionCategory(SqlFunctionCategory.STRING) + .returnTypeNullableArray(SqlTypeName.VARCHAR) + .build(); + + @Override + public SqlOperator calciteOperator() + { + return SQL_FUNCTION; + } + + @Nullable + @Override + public DruidExpression toDruidExpression(PlannerContext plannerContext, RowSignature rowSignature, RexNode rexNode) + { + return OperatorConversions.convertCall( + plannerContext, + rowSignature, + rexNode, + druidExpressions -> DruidExpression.of( + null, + DruidExpression.functionCall("mv_to_array", druidExpressions) + ) + ); + } + +} diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/planner/Calcites.java b/sql/src/main/java/org/apache/druid/sql/calcite/planner/Calcites.java index bcfac162683c..ea9290641095 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/planner/Calcites.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/planner/Calcites.java @@ -29,7 +29,10 @@ import org.apache.calcite.rex.RexNode; import org.apache.calcite.sql.SqlCollation; import org.apache.calcite.sql.SqlKind; +import org.apache.calcite.sql.SqlOperatorBinding; +import org.apache.calcite.sql.type.SqlReturnTypeInference; import org.apache.calcite.sql.type.SqlTypeName; +import org.apache.calcite.sql.type.SqlTypeUtil; import org.apache.calcite.util.ConversionUtil; import org.apache.calcite.util.DateString; import org.apache.calcite.util.TimeString; @@ -38,6 +41,7 @@ import org.apache.druid.java.util.common.IAE; import org.apache.druid.java.util.common.ISE; import org.apache.druid.java.util.common.StringUtils; +import org.apache.druid.math.expr.ExpressionProcessing; import org.apache.druid.query.ordering.StringComparator; import org.apache.druid.query.ordering.StringComparators; 
import org.apache.druid.segment.column.ColumnType; @@ -84,6 +88,11 @@ public class Calcites private static final Pattern TRAILING_ZEROS = Pattern.compile("\\.?0+$"); + public static final SqlReturnTypeInference + ARG0_NULLABLE_ARRAY_RETURN_TYPE_INFERENCE = new Arg0NullableArrayTypeInference(); + public static final SqlReturnTypeInference + ARG1_NULLABLE_ARRAY_RETURN_TYPE_INFERENCE = new Arg1NullableArrayTypeInference(); + private Calcites() { // No instantiation. @@ -128,17 +137,16 @@ public static String escapeStringLiteral(final String s) } /** - * Convert {@link RelDataType} to the most appropriate {@link ValueType}, coercing all ARRAY types to STRING (until - * the time is right and we are more comfortable handling Druid ARRAY types in all parts of the engine). - * - * Callers who are not scared of ARRAY types should isntead call {@link #getValueTypeForRelDataTypeFull(RelDataType)}, - * which returns the most accurate conversion of {@link RelDataType} to {@link ValueType}. + * Convert {@link RelDataType} to the most appropriate {@link ValueType}. + * Callers who want to coerce all ARRAY types to STRING can set `druid.expressions.allowArrayToStringCast` + * runtime property in {@link org.apache.druid.math.expr.ExpressionProcessingConfig}. */ @Nullable public static ColumnType getColumnTypeForRelDataType(final RelDataType type) { ColumnType valueType = getValueTypeForRelDataTypeFull(type); - if (valueType != null && valueType.isArray()) { + // coerce array to multi value string + if (ExpressionProcessing.processArraysAsMultiValueStrings() && valueType != null && valueType.isArray()) { return ColumnType.STRING; } return valueType; @@ -468,4 +476,38 @@ public static Class sqlTypeNameJdbcToJavaClass(SqlTypeName typeName) return Object.class; } } + + public static class Arg0NullableArrayTypeInference implements SqlReturnTypeInference + { + @Override + public RelDataType inferReturnType(SqlOperatorBinding opBinding) + { + RelDataType type = opBinding.getOperandType(0); + 
if (SqlTypeUtil.isArray(type)) { + return type; + } + return Calcites.createSqlArrayTypeWithNullability( + opBinding.getTypeFactory(), + type.getSqlTypeName(), + true + ); + } + } + + public static class Arg1NullableArrayTypeInference implements SqlReturnTypeInference + { + @Override + public RelDataType inferReturnType(SqlOperatorBinding opBinding) + { + RelDataType type = opBinding.getOperandType(1); + if (SqlTypeUtil.isArray(type)) { + return type; + } + return Calcites.createSqlArrayTypeWithNullability( + opBinding.getTypeFactory(), + type.getSqlTypeName(), + true + ); + } + } } diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/planner/DruidOperatorTable.java b/sql/src/main/java/org/apache/druid/sql/calcite/planner/DruidOperatorTable.java index 11244d492995..dca7d76bd6f3 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/planner/DruidOperatorTable.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/planner/DruidOperatorTable.java @@ -86,6 +86,7 @@ import org.apache.druid.sql.calcite.expression.builtin.LikeOperatorConversion; import org.apache.druid.sql.calcite.expression.builtin.MillisToTimestampOperatorConversion; import org.apache.druid.sql.calcite.expression.builtin.MultiValueStringOperatorConversions; +import org.apache.druid.sql.calcite.expression.builtin.MultiValueStringToArrayOperatorConversion; import org.apache.druid.sql.calcite.expression.builtin.ParseLongOperatorConversion; import org.apache.druid.sql.calcite.expression.builtin.PositionOperatorConversion; import org.apache.druid.sql.calcite.expression.builtin.RPadOperatorConversion; @@ -242,6 +243,7 @@ public class DruidOperatorTable implements SqlOperatorTable .add(new MultiValueStringOperatorConversions.StringToMultiString()) .add(new MultiValueStringOperatorConversions.FilterOnly()) .add(new MultiValueStringOperatorConversions.FilterNone()) + .add(new MultiValueStringToArrayOperatorConversion()) .build(); private static final List REDUCTION_OPERATOR_CONVERSIONS = diff 
--git a/sql/src/main/java/org/apache/druid/sql/calcite/rel/DruidQuery.java b/sql/src/main/java/org/apache/druid/sql/calcite/rel/DruidQuery.java index 71b618853395..8c00d18bb495 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/rel/DruidQuery.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/rel/DruidQuery.java @@ -87,7 +87,6 @@ import javax.annotation.Nonnull; import javax.annotation.Nullable; - import java.util.ArrayList; import java.util.Collections; import java.util.Comparator; @@ -415,7 +414,6 @@ private static List computeDimensions( final VirtualColumn virtualColumn; - final String dimOutputName = outputNamePrefix + outputNameCounter++; if (!druidExpression.isSimpleExtraction()) { virtualColumn = virtualColumnRegistry.getOrCreateVirtualColumnForExpression( @@ -933,6 +931,10 @@ private TopNQuery toTopNQuery(final QueryFeatureInspector queryFeatureInspector) } final DimensionSpec dimensionSpec = Iterables.getOnlyElement(grouping.getDimensions()).toDimensionSpec(); + // grouping col cannot be type array + if (dimensionSpec.getOutputType().isArray()) { + return null; + } final OrderByColumnSpec limitColumn; if (sorting.getOrderBys().isEmpty()) { limitColumn = new OrderByColumnSpec( diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/run/NativeQueryMaker.java b/sql/src/main/java/org/apache/druid/sql/calcite/run/NativeQueryMaker.java index d160230b7d71..8b40cc838da0 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/run/NativeQueryMaker.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/run/NativeQueryMaker.java @@ -46,6 +46,8 @@ import org.apache.druid.query.timeseries.TimeseriesQuery; import org.apache.druid.segment.DimensionHandlerUtils; import org.apache.druid.segment.column.ColumnHolder; +import org.apache.druid.segment.data.ComparableList; +import org.apache.druid.segment.data.ComparableStringArray; import org.apache.druid.server.QueryLifecycle; import org.apache.druid.server.QueryLifecycleFactory; import 
org.apache.druid.server.security.Access; @@ -340,6 +342,10 @@ private Object coerce(final Object value, final SqlTypeName sqlType) coercedValue = Arrays.asList((Double[]) value); } else if (value instanceof Object[]) { coercedValue = Arrays.asList((Object[]) value); + } else if (value instanceof ComparableStringArray) { + coercedValue = Arrays.asList(((ComparableStringArray) value).getDelegate()); + } else if (value instanceof ComparableList) { + coercedValue = ((ComparableList) value).getDelegate(); } else { throw new ISE("Cannot coerce[%s] to %s", value.getClass().getName(), sqlType); } diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/BaseCalciteQueryTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/BaseCalciteQueryTest.java index fc21872c0131..480bb5870a18 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/BaseCalciteQueryTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/BaseCalciteQueryTest.java @@ -110,7 +110,6 @@ import org.junit.rules.TemporaryFolder; import javax.annotation.Nullable; - import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; @@ -895,6 +894,12 @@ public void assertResultsEquals(String sql, List expectedResults, List } } + public void testQueryThrows(final String sql, Consumer expectedExceptionInitializer) + throws Exception + { + testQueryThrows(sql, new HashMap<>(QUERY_CONTEXT_DEFAULT), ImmutableList.of(), expectedExceptionInitializer); + } + public void testQueryThrows( final String sql, final Map queryContext, diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteArraysQueryTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteArraysQueryTest.java index 61e41c724801..6e429366d86b 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteArraysQueryTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteArraysQueryTest.java @@ -26,6 +26,7 @@ import org.apache.druid.java.util.common.HumanReadableBytes; import 
org.apache.druid.java.util.common.IAE; import org.apache.druid.java.util.common.granularity.Granularities; +import org.apache.druid.math.expr.ExprMacroTable; import org.apache.druid.math.expr.ExpressionProcessing; import org.apache.druid.query.Druids; import org.apache.druid.query.Query; @@ -40,6 +41,7 @@ import org.apache.druid.query.filter.AndDimFilter; import org.apache.druid.query.filter.ExpressionDimFilter; import org.apache.druid.query.filter.InDimFilter; +import org.apache.druid.query.filter.NotDimFilter; import org.apache.druid.query.filter.SelectorDimFilter; import org.apache.druid.query.groupby.GroupByQuery; import org.apache.druid.query.groupby.orderby.DefaultLimitSpec; @@ -47,6 +49,8 @@ import org.apache.druid.query.groupby.orderby.OrderByColumnSpec; import org.apache.druid.query.ordering.StringComparators; import org.apache.druid.query.scan.ScanQuery; +import org.apache.druid.query.topn.DimensionTopNMetricSpec; +import org.apache.druid.query.topn.TopNQuery; import org.apache.druid.segment.column.ColumnType; import org.apache.druid.segment.join.JoinType; import org.apache.druid.sql.calcite.filtration.Filtration; @@ -55,6 +59,7 @@ import org.junit.runner.RunWith; import java.util.Arrays; +import java.util.Collections; import java.util.List; /** @@ -74,7 +79,7 @@ public void testSelectConstantArrayExpressionFromTable() throws Exception newScanQueryBuilder() .dataSource(CalciteTests.DATASOURCE1) .intervals(querySegmentSpec(Filtration.eternity())) - .virtualColumns(expressionVirtualColumn("v0", "array(1,2)", ColumnType.STRING)) + .virtualColumns(expressionVirtualColumn("v0", "array(1,2)", ColumnType.LONG_ARRAY)) .columns("dim1", "v0") .resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST) .limit(1) @@ -93,6 +98,7 @@ public void testGroupByArrayFromCase() throws Exception cannotVectorize(); testQuery( "SELECT CASE WHEN dim4 = 'a' THEN ARRAY['foo','bar','baz'] END as mv_value, count(1) from numfoo GROUP BY 1", + QUERY_CONTEXT_NO_STRINGIFY_ARRAY, 
ImmutableList.of( GroupByQuery.builder() .setDataSource(CalciteTests.DATASOURCE3) @@ -100,9 +106,9 @@ public void testGroupByArrayFromCase() throws Exception .setVirtualColumns(expressionVirtualColumn( "v0", "case_searched((\"dim4\" == 'a'),array('foo','bar','baz'),null)", - ColumnType.STRING + ColumnType.STRING_ARRAY )) - .setDimensions(new DefaultDimensionSpec("v0", "_d0")) + .setDimensions(new DefaultDimensionSpec("v0", "_d0", ColumnType.STRING_ARRAY)) .setGranularity(Granularities.ALL) .setAggregatorSpecs(new CountAggregatorFactory("a0")) .setContext(QUERY_CONTEXT_DEFAULT) @@ -110,9 +116,7 @@ public void testGroupByArrayFromCase() throws Exception ), ImmutableList.of( new Object[]{null, 3L}, - new Object[]{"bar", 3L}, - new Object[]{"baz", 3L}, - new Object[]{"foo", 3L} + new Object[]{ImmutableList.of("foo", "bar", "baz"), 3L} ) ); } @@ -126,7 +130,11 @@ public void testSelectNonConstantArrayExpressionFromTable() throws Exception newScanQueryBuilder() .dataSource(CalciteTests.DATASOURCE1) .intervals(querySegmentSpec(Filtration.eternity())) - .virtualColumns(expressionVirtualColumn("v0", "array(concat(\"dim1\",'word'),'up')", ColumnType.STRING)) + .virtualColumns(expressionVirtualColumn( + "v0", + "array(concat(\"dim1\",'word'),'up')", + ColumnType.STRING_ARRAY + )) .columns("dim1", "v0") .resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST) .limit(5) @@ -150,7 +158,7 @@ public void testSelectNonConstantArrayExpressionFromTableForMultival() throws Ex final Query scanQuery = newScanQueryBuilder() .dataSource(CalciteTests.DATASOURCE1) .intervals(querySegmentSpec(Filtration.eternity())) - .virtualColumns(expressionVirtualColumn("v0", "array(concat(\"dim3\",'word'),'up')", ColumnType.STRING)) + .virtualColumns(expressionVirtualColumn("v0", "array(concat(\"dim3\",'word'),'up')", ColumnType.STRING_ARRAY)) .columns("dim1", "v0") .resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST) .limit(5) @@ -183,7 +191,8 @@ public void 
testSelectNonConstantArrayExpressionFromTableForMultival() throws Ex // if nested arrays are not enabled, this doesn't work expectedException.expect(IAE.class); - expectedException.expectMessage("Cannot create a nested array type [ARRAY>], 'druid.expressions.allowNestedArrays' must be set to true"); + expectedException.expectMessage( + "Cannot create a nested array type [ARRAY>], 'druid.expressions.allowNestedArrays' must be set to true"); testQuery( sql, ImmutableList.of(scanQuery), @@ -288,22 +297,26 @@ public void testSomeArrayFunctionsWithScanQuery() throws Exception .intervals(querySegmentSpec(Filtration.eternity())) .virtualColumns( // these report as strings even though they are not, someday this will not be so - expressionVirtualColumn("v0", "array('a','b','c')", ColumnType.STRING), - expressionVirtualColumn("v1", "array(1,2,3)", ColumnType.STRING), - expressionVirtualColumn("v10", "array_concat(array(\"l1\"),array(\"l2\"))", ColumnType.STRING), - expressionVirtualColumn("v11", "array_concat(array(\"d1\"),array(\"d2\"))", ColumnType.STRING), - expressionVirtualColumn("v12", "array_offset(array(\"l1\"),0)", ColumnType.STRING), - expressionVirtualColumn("v13", "array_offset(array(\"d1\"),0)", ColumnType.STRING), - expressionVirtualColumn("v14", "array_ordinal(array(\"l1\"),1)", ColumnType.STRING), - expressionVirtualColumn("v15", "array_ordinal(array(\"d1\"),1)", ColumnType.STRING), - expressionVirtualColumn("v2", "array(1.9,2.2,4.3)", ColumnType.STRING), - expressionVirtualColumn("v3", "array_append(\"dim3\",'foo')", ColumnType.STRING), - expressionVirtualColumn("v4", "array_prepend('foo',array(\"dim2\"))", ColumnType.STRING), - expressionVirtualColumn("v5", "array_append(array(1,2),\"l1\")", ColumnType.STRING), - expressionVirtualColumn("v6", "array_prepend(\"l2\",array(1,2))", ColumnType.STRING), - expressionVirtualColumn("v7", "array_append(array(1.2,2.2),\"d1\")", ColumnType.STRING), - expressionVirtualColumn("v8", "array_prepend(\"d2\",array(1.1,2.2))", 
ColumnType.STRING), - expressionVirtualColumn("v9", "array_concat(\"dim2\",\"dim3\")", ColumnType.STRING) + expressionVirtualColumn("v0", "array('a','b','c')", ColumnType.STRING_ARRAY), + expressionVirtualColumn("v1", "array(1,2,3)", ColumnType.LONG_ARRAY), + expressionVirtualColumn("v10", "array_concat(array(\"l1\"),array(\"l2\"))", ColumnType.LONG_ARRAY), + expressionVirtualColumn( + "v11", + "array_concat(array(\"d1\"),array(\"d2\"))", + ColumnType.DOUBLE_ARRAY + ), + expressionVirtualColumn("v12", "array_offset(array(\"l1\"),0)", ColumnType.LONG_ARRAY), + expressionVirtualColumn("v13", "array_offset(array(\"d1\"),0)", ColumnType.DOUBLE_ARRAY), + expressionVirtualColumn("v14", "array_ordinal(array(\"l1\"),1)", ColumnType.LONG_ARRAY), + expressionVirtualColumn("v15", "array_ordinal(array(\"d1\"),1)", ColumnType.DOUBLE_ARRAY), + expressionVirtualColumn("v2", "array(1.9,2.2,4.3)", ColumnType.DOUBLE_ARRAY), + expressionVirtualColumn("v3", "array_append(\"dim3\",'foo')", ColumnType.STRING_ARRAY), + expressionVirtualColumn("v4", "array_prepend('foo',array(\"dim2\"))", ColumnType.STRING_ARRAY), + expressionVirtualColumn("v5", "array_append(array(1,2),\"l1\")", ColumnType.LONG_ARRAY), + expressionVirtualColumn("v6", "array_prepend(\"l2\",array(1,2))", ColumnType.LONG_ARRAY), + expressionVirtualColumn("v7", "array_append(array(1.2,2.2),\"d1\")", ColumnType.DOUBLE_ARRAY), + expressionVirtualColumn("v8", "array_prepend(\"d2\",array(1.1,2.2))", ColumnType.DOUBLE_ARRAY), + expressionVirtualColumn("v9", "array_concat(\"dim2\",\"dim3\")", ColumnType.STRING_ARRAY) ) .columns( "d1", @@ -357,13 +370,13 @@ public void testSomeArrayFunctionsWithScanQueryNoStringify() throws Exception Arrays.asList("a", "b", "c"), Arrays.asList(1L, 2L, 3L), Arrays.asList(1.9, 2.2, 4.3), - "[\"a\",\"b\",\"foo\"]", + Arrays.asList("a", "b", "foo"), Arrays.asList("foo", "a"), Arrays.asList(1L, 2L, 7L), Arrays.asList(0L, 1L, 2L), Arrays.asList(1.2, 2.2, 1.0), Arrays.asList(0.0, 1.1, 2.2), - 
"[\"a\",\"a\",\"b\"]", + Arrays.asList("a", "a", "b"), Arrays.asList(7L, 0L), Arrays.asList(1.0, 0.0) } @@ -377,13 +390,13 @@ public void testSomeArrayFunctionsWithScanQueryNoStringify() throws Exception Arrays.asList("a", "b", "c"), Arrays.asList(1L, 2L, 3L), Arrays.asList(1.9, 2.2, 4.3), - "[\"a\",\"b\",\"foo\"]", + Arrays.asList("a", "b", "foo"), Arrays.asList("foo", "a"), Arrays.asList(1L, 2L, 7L), Arrays.asList(null, 1L, 2L), Arrays.asList(1.2, 2.2, 1.0), Arrays.asList(null, 1.1, 2.2), - "[\"a\",\"a\",\"b\"]", + Arrays.asList("a", "a", "b"), Arrays.asList(7L, null), Arrays.asList(1.0, null) } @@ -414,19 +427,22 @@ public void testSomeArrayFunctionsWithScanQueryNoStringify() throws Exception .dataSource(CalciteTests.DATASOURCE3) .intervals(querySegmentSpec(Filtration.eternity())) .virtualColumns( - // these report as strings even though they are not, someday this will not be so - expressionVirtualColumn("v0", "array('a','b','c')", ColumnType.STRING), - expressionVirtualColumn("v1", "array(1,2,3)", ColumnType.STRING), - expressionVirtualColumn("v10", "array_concat(array(\"l1\"),array(\"l2\"))", ColumnType.STRING), - expressionVirtualColumn("v11", "array_concat(array(\"d1\"),array(\"d2\"))", ColumnType.STRING), - expressionVirtualColumn("v2", "array(1.9,2.2,4.3)", ColumnType.STRING), - expressionVirtualColumn("v3", "array_append(\"dim3\",'foo')", ColumnType.STRING), - expressionVirtualColumn("v4", "array_prepend('foo',array(\"dim2\"))", ColumnType.STRING), - expressionVirtualColumn("v5", "array_append(array(1,2),\"l1\")", ColumnType.STRING), - expressionVirtualColumn("v6", "array_prepend(\"l2\",array(1,2))", ColumnType.STRING), - expressionVirtualColumn("v7", "array_append(array(1.2,2.2),\"d1\")", ColumnType.STRING), - expressionVirtualColumn("v8", "array_prepend(\"d2\",array(1.1,2.2))", ColumnType.STRING), - expressionVirtualColumn("v9", "array_concat(\"dim2\",\"dim3\")", ColumnType.STRING) + expressionVirtualColumn("v0", "array('a','b','c')", 
ColumnType.STRING_ARRAY), + expressionVirtualColumn("v1", "array(1,2,3)", ColumnType.LONG_ARRAY), + expressionVirtualColumn("v10", "array_concat(array(\"l1\"),array(\"l2\"))", ColumnType.LONG_ARRAY), + expressionVirtualColumn( + "v11", + "array_concat(array(\"d1\"),array(\"d2\"))", + ColumnType.DOUBLE_ARRAY + ), + expressionVirtualColumn("v2", "array(1.9,2.2,4.3)", ColumnType.DOUBLE_ARRAY), + expressionVirtualColumn("v3", "array_append(\"dim3\",'foo')", ColumnType.STRING_ARRAY), + expressionVirtualColumn("v4", "array_prepend('foo',array(\"dim2\"))", ColumnType.STRING_ARRAY), + expressionVirtualColumn("v5", "array_append(array(1,2),\"l1\")", ColumnType.LONG_ARRAY), + expressionVirtualColumn("v6", "array_prepend(\"l2\",array(1,2))", ColumnType.LONG_ARRAY), + expressionVirtualColumn("v7", "array_append(array(1.2,2.2),\"d1\")", ColumnType.DOUBLE_ARRAY), + expressionVirtualColumn("v8", "array_prepend(\"d2\",array(1.1,2.2))", ColumnType.DOUBLE_ARRAY), + expressionVirtualColumn("v9", "array_concat(\"dim2\",\"dim3\")", ColumnType.STRING_ARRAY) ) .columns( "dim1", @@ -554,6 +570,7 @@ public void testArrayContainsArrayOfNonLiteral() throws Exception { testQuery( "SELECT dim3 FROM druid.numfoo WHERE ARRAY_CONTAINS(dim3, ARRAY[dim2]) LIMIT 5", + QUERY_CONTEXT_NO_STRINGIFY_ARRAY, ImmutableList.of( newScanQueryBuilder() .dataSource(CalciteTests.DATASOURCE3) @@ -562,12 +579,11 @@ public void testArrayContainsArrayOfNonLiteral() throws Exception .columns("dim3") .resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST) .limit(5) - .context(QUERY_CONTEXT_DEFAULT) + .context(QUERY_CONTEXT_NO_STRINGIFY_ARRAY) .build() ), ImmutableList.of( - new Object[]{"[\"a\",\"b\"]"}, - new Object[]{useDefault ? 
"" : null} + new Object[]{"[\"a\",\"b\"]"} ) ); } @@ -577,24 +593,25 @@ public void testArraySlice() throws Exception { testQuery( "SELECT ARRAY_SLICE(dim3, 1) FROM druid.numfoo", + QUERY_CONTEXT_NO_STRINGIFY_ARRAY, ImmutableList.of( new Druids.ScanQueryBuilder() .dataSource(CalciteTests.DATASOURCE3) .intervals(querySegmentSpec(Filtration.eternity())) - .virtualColumns(expressionVirtualColumn("v0", "array_slice(\"dim3\",1)", ColumnType.STRING)) + .virtualColumns(expressionVirtualColumn("v0", "array_slice(\"dim3\",1)", ColumnType.STRING_ARRAY)) .columns(ImmutableList.of("v0")) - .context(QUERY_CONTEXT_DEFAULT) + .context(QUERY_CONTEXT_NO_STRINGIFY_ARRAY) .resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST) .legacy(false) .build() ), ImmutableList.of( - new Object[]{"[\"b\"]"}, - new Object[]{"[\"c\"]"}, - new Object[]{"[]"}, - new Object[]{"[]"}, - new Object[]{"[]"}, - new Object[]{"[]"} + new Object[]{Collections.singletonList("b")}, + new Object[]{Collections.singletonList("c")}, + new Object[]{Collections.emptyList()}, + new Object[]{useDefault ? 
null : Collections.emptyList()}, + new Object[]{null}, + new Object[]{null} ) ); } @@ -651,26 +668,23 @@ public void testArrayAppend() throws Exception ImmutableList results; if (useDefault) { results = ImmutableList.of( - new Object[]{"foo", 6L}, - new Object[]{"", 3L}, - new Object[]{"b", 2L}, - new Object[]{"a", 1L}, - new Object[]{"c", 1L}, - new Object[]{"d", 1L} + new Object[]{null, 3L}, + new Object[]{ImmutableList.of("a", "b", "foo"), 1L}, + new Object[]{ImmutableList.of("b", "c", "foo"), 1L}, + new Object[]{ImmutableList.of("d", "foo"), 1L} ); } else { results = ImmutableList.of( - new Object[]{"foo", 6L}, new Object[]{null, 2L}, - new Object[]{"b", 2L}, - new Object[]{"", 1L}, - new Object[]{"a", 1L}, - new Object[]{"c", 1L}, - new Object[]{"d", 1L} + new Object[]{ImmutableList.of("", "foo"), 1L}, + new Object[]{ImmutableList.of("a", "b", "foo"), 1L}, + new Object[]{ImmutableList.of("b", "c", "foo"), 1L}, + new Object[]{ImmutableList.of("d", "foo"), 1L} ); } testQuery( "SELECT ARRAY_APPEND(dim3, 'foo'), SUM(cnt) FROM druid.numfoo GROUP BY 1 ORDER BY 2 DESC", + QUERY_CONTEXT_NO_STRINGIFY_ARRAY, ImmutableList.of( GroupByQuery.builder() .setDataSource(CalciteTests.DATASOURCE3) @@ -679,11 +693,11 @@ public void testArrayAppend() throws Exception .setVirtualColumns(expressionVirtualColumn( "v0", "array_append(\"dim3\",'foo')", - ColumnType.STRING + ColumnType.STRING_ARRAY )) .setDimensions( dimensions( - new DefaultDimensionSpec("v0", "_d0", ColumnType.STRING) + new DefaultDimensionSpec("v0", "_d0", ColumnType.STRING_ARRAY) ) ) .setAggregatorSpecs(aggregators(new LongSumAggregatorFactory("a0", "cnt"))) @@ -695,7 +709,7 @@ public void testArrayAppend() throws Exception )), Integer.MAX_VALUE )) - .setContext(QUERY_CONTEXT_DEFAULT) + .setContext(QUERY_CONTEXT_NO_STRINGIFY_ARRAY) .build() ), results @@ -711,26 +725,23 @@ public void testArrayPrepend() throws Exception ImmutableList results; if (useDefault) { results = ImmutableList.of( - new Object[]{"foo", 6L}, - 
new Object[]{"", 3L}, - new Object[]{"b", 2L}, - new Object[]{"a", 1L}, - new Object[]{"c", 1L}, - new Object[]{"d", 1L} + new Object[]{null, 3L}, + new Object[]{ImmutableList.of("foo", "a", "b"), 1L}, + new Object[]{ImmutableList.of("foo", "b", "c"), 1L}, + new Object[]{ImmutableList.of("foo", "d"), 1L} ); } else { results = ImmutableList.of( - new Object[]{"foo", 6L}, new Object[]{null, 2L}, - new Object[]{"b", 2L}, - new Object[]{"", 1L}, - new Object[]{"a", 1L}, - new Object[]{"c", 1L}, - new Object[]{"d", 1L} + new Object[]{ImmutableList.of("foo", ""), 1L}, + new Object[]{ImmutableList.of("foo", "a", "b"), 1L}, + new Object[]{ImmutableList.of("foo", "b", "c"), 1L}, + new Object[]{ImmutableList.of("foo", "d"), 1L} ); } testQuery( "SELECT ARRAY_PREPEND('foo', dim3), SUM(cnt) FROM druid.numfoo GROUP BY 1 ORDER BY 2 DESC", + QUERY_CONTEXT_NO_STRINGIFY_ARRAY, ImmutableList.of( GroupByQuery.builder() .setDataSource(CalciteTests.DATASOURCE3) @@ -739,11 +750,11 @@ public void testArrayPrepend() throws Exception .setVirtualColumns(expressionVirtualColumn( "v0", "array_prepend('foo',\"dim3\")", - ColumnType.STRING + ColumnType.STRING_ARRAY )) .setDimensions( dimensions( - new DefaultDimensionSpec("v0", "_d0", ColumnType.STRING) + new DefaultDimensionSpec("v0", "_d0", ColumnType.STRING_ARRAY) ) ) .setAggregatorSpecs(aggregators(new LongSumAggregatorFactory("a0", "cnt"))) @@ -755,7 +766,7 @@ public void testArrayPrepend() throws Exception )), Integer.MAX_VALUE )) - .setContext(QUERY_CONTEXT_DEFAULT) + .setContext(QUERY_CONTEXT_NO_STRINGIFY_ARRAY) .build() ), results @@ -835,24 +846,23 @@ public void testArrayConcat() throws Exception ImmutableList results; if (useDefault) { results = ImmutableList.of( - new Object[]{"", 6L}, - new Object[]{"b", 4L}, - new Object[]{"a", 2L}, - new Object[]{"c", 2L}, - new Object[]{"d", 2L} + new Object[]{null, 3L}, + new Object[]{ImmutableList.of("a", "b", "a", "b"), 1L}, + new Object[]{ImmutableList.of("b", "c", "b", "c"), 1L}, + new 
Object[]{ImmutableList.of("d", "d"), 1L} ); } else { results = ImmutableList.of( - new Object[]{null, 4L}, - new Object[]{"b", 4L}, - new Object[]{"", 2L}, - new Object[]{"a", 2L}, - new Object[]{"c", 2L}, - new Object[]{"d", 2L} + new Object[]{null, 2L}, + new Object[]{ImmutableList.of("", ""), 1L}, + new Object[]{ImmutableList.of("a", "b", "a", "b"), 1L}, + new Object[]{ImmutableList.of("b", "c", "b", "c"), 1L}, + new Object[]{ImmutableList.of("d", "d"), 1L} ); } testQuery( "SELECT ARRAY_CONCAT(dim3, dim3), SUM(cnt) FROM druid.numfoo GROUP BY 1 ORDER BY 2 DESC", + QUERY_CONTEXT_NO_STRINGIFY_ARRAY, ImmutableList.of( GroupByQuery.builder() .setDataSource(CalciteTests.DATASOURCE3) @@ -861,11 +871,11 @@ public void testArrayConcat() throws Exception .setVirtualColumns(expressionVirtualColumn( "v0", "array_concat(\"dim3\",\"dim3\")", - ColumnType.STRING + ColumnType.STRING_ARRAY )) .setDimensions( dimensions( - new DefaultDimensionSpec("v0", "_d0", ColumnType.STRING) + new DefaultDimensionSpec("v0", "_d0", ColumnType.STRING_ARRAY) ) ) .setAggregatorSpecs(aggregators(new LongSumAggregatorFactory("a0", "cnt"))) @@ -877,7 +887,7 @@ public void testArrayConcat() throws Exception )), Integer.MAX_VALUE )) - .setContext(QUERY_CONTEXT_DEFAULT) + .setContext(QUERY_CONTEXT_NO_STRINGIFY_ARRAY) .build() ), results @@ -923,6 +933,197 @@ public void testArrayOffset() throws Exception ); } + @Test + public void testArrayGroupAsLongArray() throws Exception + { + // Cannot vectorize as we do not have support in native query subsystem for grouping on arrays + cannotVectorize(); + testQuery( + "SELECT ARRAY[l1], SUM(cnt) FROM druid.numfoo GROUP BY 1 ORDER BY 2 DESC", + QUERY_CONTEXT_NO_STRINGIFY_ARRAY, + ImmutableList.of( + GroupByQuery.builder() + .setDataSource(CalciteTests.DATASOURCE3) + .setInterval(querySegmentSpec(Filtration.eternity())) + .setGranularity(Granularities.ALL) + .setVirtualColumns(expressionVirtualColumn( + "v0", + "array(\"l1\")", + ColumnType.LONG_ARRAY + )) + 
.setDimensions( + dimensions( + new DefaultDimensionSpec("v0", "_d0", ColumnType.LONG_ARRAY) + ) + ) + .setAggregatorSpecs(aggregators(new LongSumAggregatorFactory("a0", "cnt"))) + .setLimitSpec(new DefaultLimitSpec( + ImmutableList.of(new OrderByColumnSpec( + "a0", + OrderByColumnSpec.Direction.DESCENDING, + StringComparators.NUMERIC + )), + Integer.MAX_VALUE + )) + .setContext(QUERY_CONTEXT_NO_STRINGIFY_ARRAY) + .build() + ), + useDefault ? ImmutableList.of( + new Object[]{ImmutableList.of(0L), 4L}, + new Object[]{ImmutableList.of(7L), 1L}, + new Object[]{ImmutableList.of(325323L), 1L} + ) : ImmutableList.of( + new Object[]{Collections.singletonList(null), 3L}, + new Object[]{ImmutableList.of(0L), 1L}, + new Object[]{ImmutableList.of(7L), 1L}, + new Object[]{ImmutableList.of(325323L), 1L} + ) + ); + } + + + @Test + public void testArrayGroupAsDoubleArray() throws Exception + { + // Cannot vectorize as we do not have support in native query subsystem for grouping on arrays as keys + cannotVectorize(); + testQuery( + "SELECT ARRAY[d1], SUM(cnt) FROM druid.numfoo GROUP BY 1 ORDER BY 2 DESC", + QUERY_CONTEXT_NO_STRINGIFY_ARRAY, + ImmutableList.of( + GroupByQuery.builder() + .setDataSource(CalciteTests.DATASOURCE3) + .setInterval(querySegmentSpec(Filtration.eternity())) + .setGranularity(Granularities.ALL) + .setVirtualColumns(expressionVirtualColumn( + "v0", + "array(\"d1\")", + ColumnType.DOUBLE_ARRAY + )) + .setDimensions( + dimensions( + new DefaultDimensionSpec("v0", "_d0", ColumnType.DOUBLE_ARRAY) + ) + ) + .setAggregatorSpecs(aggregators(new LongSumAggregatorFactory("a0", "cnt"))) + .setLimitSpec(new DefaultLimitSpec( + ImmutableList.of(new OrderByColumnSpec( + "a0", + OrderByColumnSpec.Direction.DESCENDING, + StringComparators.NUMERIC + )), + Integer.MAX_VALUE + )) + .setContext(QUERY_CONTEXT_NO_STRINGIFY_ARRAY) + .build() + ), + useDefault ? 
ImmutableList.of( + new Object[]{ImmutableList.of(0.0), 4L}, + new Object[]{ImmutableList.of(1.0), 1L}, + new Object[]{ImmutableList.of(1.7), 1L} + ) : + ImmutableList.of( + new Object[]{Collections.singletonList(null), 3L}, + new Object[]{ImmutableList.of(0.0), 1L}, + new Object[]{ImmutableList.of(1.0), 1L}, + new Object[]{ImmutableList.of(1.7), 1L} + ) + ); + } + + @Test + public void testArrayGroupAsFloatArray() throws Exception + { + // Cannot vectorize as we do not have support in native query subsystem for grouping on arrays as keys + cannotVectorize(); + testQuery( + "SELECT ARRAY[f1], SUM(cnt) FROM druid.numfoo GROUP BY 1 ORDER BY 2 DESC", + QUERY_CONTEXT_NO_STRINGIFY_ARRAY, + ImmutableList.of( + GroupByQuery.builder() + .setDataSource(CalciteTests.DATASOURCE3) + .setInterval(querySegmentSpec(Filtration.eternity())) + .setGranularity(Granularities.ALL) + .setVirtualColumns(expressionVirtualColumn( + "v0", + "array(\"f1\")", + ColumnType.DOUBLE_ARRAY + )) + .setDimensions( + dimensions( + new DefaultDimensionSpec("v0", "_d0", ColumnType.DOUBLE_ARRAY) + ) + ) + .setAggregatorSpecs(aggregators(new LongSumAggregatorFactory("a0", "cnt"))) + .setLimitSpec(new DefaultLimitSpec( + ImmutableList.of(new OrderByColumnSpec( + "a0", + OrderByColumnSpec.Direction.DESCENDING, + StringComparators.NUMERIC + )), + Integer.MAX_VALUE + )) + .setContext(QUERY_CONTEXT_NO_STRINGIFY_ARRAY) + .build() + ), + useDefault ? ImmutableList.of( + new Object[]{ImmutableList.of(0.0), 4L}, + new Object[]{ImmutableList.of(0.10000000149011612), 1L}, + new Object[]{ImmutableList.of(1.0), 1L} + ) : + ImmutableList.of( + new Object[]{Collections.singletonList(null), 3L}, + new Object[]{ImmutableList.of(0.0), 1L}, + new Object[]{ImmutableList.of(0.10000000149011612), 1L}, + new Object[]{ImmutableList.of(1.0), 1L} + ) + ); + } + + @Test + public void testArrayGroupAsArrayWithFunction() throws Exception + { + // Cannot vectorize due to usage of expressions. 
+ cannotVectorize(); + testQuery( + "SELECT ARRAY[ARRAY_ORDINAL(dim3, 2)], SUM(cnt) FROM druid.numfoo GROUP BY 1 ORDER BY 2 DESC", + QUERY_CONTEXT_NO_STRINGIFY_ARRAY, + ImmutableList.of( + GroupByQuery.builder() + .setDataSource(CalciteTests.DATASOURCE3) + .setInterval(querySegmentSpec(Filtration.eternity())) + .setGranularity(Granularities.ALL) + .setVirtualColumns(expressionVirtualColumn( + "v0", + "array(array_ordinal(\"dim3\",2))", + ColumnType.STRING_ARRAY + )) + .setDimensions( + dimensions( + new DefaultDimensionSpec("v0", "_d0", ColumnType.STRING_ARRAY) + ) + ) + .setAggregatorSpecs(aggregators(new LongSumAggregatorFactory("a0", "cnt"))) + .setLimitSpec(new DefaultLimitSpec( + ImmutableList.of(new OrderByColumnSpec( + "a0", + OrderByColumnSpec.Direction.DESCENDING, + StringComparators.NUMERIC + )), + Integer.MAX_VALUE + ) + ) + .setContext(QUERY_CONTEXT_NO_STRINGIFY_ARRAY) + .build() + ), + ImmutableList.of( + new Object[]{Collections.singletonList(null), 4L}, + new Object[]{ImmutableList.of("b"), 1L}, + new Object[]{ImmutableList.of("c"), 1L} + ) + ); + } + @Test public void testArrayOrdinal() throws Exception { @@ -936,7 +1137,11 @@ public void testArrayOrdinal() throws Exception .setDataSource(CalciteTests.DATASOURCE3) .setInterval(querySegmentSpec(Filtration.eternity())) .setGranularity(Granularities.ALL) - .setVirtualColumns(expressionVirtualColumn("v0", "array_ordinal(\"dim3\",2)", ColumnType.STRING)) + .setVirtualColumns(expressionVirtualColumn( + "v0", + "array_ordinal(\"dim3\",2)", + ColumnType.STRING + )) .setDimensions( dimensions( new DefaultDimensionSpec("v0", "_d0", ColumnType.STRING) @@ -1113,24 +1318,23 @@ public void testArrayToStringToMultiValueString() throws Exception ImmutableList results; if (useDefault) { results = ImmutableList.of( - new Object[]{"d", 7L}, - new Object[]{null, 3L}, - new Object[]{"b", 2L}, - new Object[]{"a", 1L}, - new Object[]{"c", 1L} + new Object[]{ImmutableList.of("", "d"), 3L}, + new 
Object[]{ImmutableList.of("a", "b", "d"), 1L}, + new Object[]{ImmutableList.of("b", "c", "d"), 1L}, + new Object[]{ImmutableList.of("d", "d"), 1L} ); } else { results = ImmutableList.of( - new Object[]{"d", 5L}, new Object[]{null, 2L}, - new Object[]{"b", 2L}, - new Object[]{"", 1L}, - new Object[]{"a", 1L}, - new Object[]{"c", 1L} + new Object[]{ImmutableList.of("", "d"), 1L}, + new Object[]{ImmutableList.of("a", "b", "d"), 1L}, + new Object[]{ImmutableList.of("b", "c", "d"), 1L}, + new Object[]{ImmutableList.of("d", "d"), 1L} ); } testQuery( "SELECT STRING_TO_ARRAY(CONCAT(ARRAY_TO_STRING(dim3, ','), ',d'), ','), SUM(cnt) FROM druid.numfoo WHERE ARRAY_LENGTH(dim3) > 0 GROUP BY 1 ORDER BY 2 DESC", + QUERY_CONTEXT_NO_STRINGIFY_ARRAY, ImmutableList.of( GroupByQuery.builder() .setDataSource(CalciteTests.DATASOURCE3) @@ -1141,13 +1345,13 @@ public void testArrayToStringToMultiValueString() throws Exception expressionVirtualColumn( "v1", "string_to_array(concat(array_to_string(\"dim3\",','),',d'),',')", - ColumnType.STRING + ColumnType.STRING_ARRAY ) ) .setDimFilter(bound("v0", "0", null, true, false, null, StringComparators.NUMERIC)) .setDimensions( dimensions( - new DefaultDimensionSpec("v1", "_d0", ColumnType.STRING) + new DefaultDimensionSpec("v1", "_d0", ColumnType.STRING_ARRAY) ) ) .setAggregatorSpecs(aggregators(new LongSumAggregatorFactory("a0", "cnt"))) @@ -1640,13 +1844,65 @@ public void testArrayAggAsArrayFromJoin() throws Exception public void testArrayAggGroupByArrayAggFromSubquery() throws Exception { cannotVectorize(); - // yo, can't group on array types right now so expect failure - expectedException.expect(RuntimeException.class); - expectedException.expectMessage("Cannot create query type helper from invalid type [ARRAY]"); + testQuery( "SELECT dim2, arr, COUNT(*) FROM (SELECT dim2, ARRAY_AGG(DISTINCT dim1) as arr FROM foo WHERE dim1 is not null GROUP BY 1 LIMIT 5) GROUP BY 1,2", - ImmutableList.of(), - ImmutableList.of() + 
QUERY_CONTEXT_NO_STRINGIFY_ARRAY, + ImmutableList.of( + GroupByQuery.builder() + .setDataSource(new TopNQuery( + new TableDataSource(CalciteTests.DATASOURCE1), + null, + new DefaultDimensionSpec( + "dim2", + "d0", + ColumnType.STRING + ), + new DimensionTopNMetricSpec( + null, + StringComparators.LEXICOGRAPHIC + ), 5, + querySegmentSpec(Filtration.eternity()), + new NotDimFilter(new SelectorDimFilter("dim1", null, null)), + Granularities.ALL, + aggregators(new ExpressionLambdaAggregatorFactory( + "a0", + ImmutableSet.of("dim1"), + "__acc", + "[]", + "[]", + true, + "array_set_add(\"__acc\", \"dim1\")", + "array_set_add_all(\"__acc\", \"a0\")", + null, + null, + new HumanReadableBytes(1024), + ExprMacroTable.nil() + )), + null, + QUERY_CONTEXT_NO_STRINGIFY_ARRAY + )) + .setInterval(querySegmentSpec(Filtration.eternity())) + .setGranularity(Granularities.ALL) + .setDimFilter(null).setDimensions(dimensions( + new DefaultDimensionSpec("d0", "_d0", ColumnType.STRING), + new DefaultDimensionSpec("a0", "_d1", ColumnType.STRING_ARRAY) + )) + .setAggregatorSpecs(aggregators(new CountAggregatorFactory("_a0"))) + .setContext(QUERY_CONTEXT_NO_STRINGIFY_ARRAY).build() + ), + useDefault ? 
+ ImmutableList.of( + new Object[]{"", ImmutableList.of("2", "abc", "10.1"), 1L}, + new Object[]{"a", ImmutableList.of("1"), 1L}, + new Object[]{"abc", ImmutableList.of("def"), 1L} + ) : + ImmutableList.of( + new Object[]{null, ImmutableList.of("abc", "10.1"), 1L}, + new Object[]{"", ImmutableList.of("2"), 1L}, + new Object[]{"a", ImmutableList.of("", "1"), 1L}, + new Object[]{"abc", ImmutableList.of("def"), 1L} + ) ); } diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteMultiValueStringQueryTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteMultiValueStringQueryTest.java index 8d17c701f949..49b959170ce2 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteMultiValueStringQueryTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteMultiValueStringQueryTest.java @@ -38,6 +38,7 @@ import org.apache.druid.query.scan.ScanQuery; import org.apache.druid.segment.column.ColumnType; import org.apache.druid.segment.virtual.ListFilteredVirtualColumn; +import org.apache.druid.sql.SqlPlanningException; import org.apache.druid.sql.calcite.filtration.Filtration; import org.apache.druid.sql.calcite.util.CalciteTests; import org.junit.Test; @@ -337,8 +338,7 @@ public void testMultiValueStringContainsArrayOfNonLiteral() throws Exception .build() ), ImmutableList.of( - new Object[]{"[\"a\",\"b\"]"}, - new Object[]{useDefault ? "" : null} + new Object[]{"[\"a\",\"b\"]"} ) ); } @@ -363,9 +363,9 @@ public void testMultiValueStringSlice() throws Exception new Object[]{"[\"b\"]"}, new Object[]{"[\"c\"]"}, new Object[]{"[]"}, - new Object[]{"[]"}, - new Object[]{"[]"}, - new Object[]{"[]"} + new Object[]{useDefault ? 
NULL_STRING : "[]"}, + new Object[]{NULL_STRING}, + new Object[]{NULL_STRING} ) ); } @@ -422,8 +422,8 @@ public void testMultiValueStringAppend() throws Exception ImmutableList results; if (useDefault) { results = ImmutableList.of( - new Object[]{"foo", 6L}, new Object[]{"", 3L}, + new Object[]{"foo", 3L}, new Object[]{"b", 2L}, new Object[]{"a", 1L}, new Object[]{"c", 1L}, @@ -431,7 +431,7 @@ public void testMultiValueStringAppend() throws Exception ); } else { results = ImmutableList.of( - new Object[]{"foo", 6L}, + new Object[]{"foo", 4L}, new Object[]{null, 2L}, new Object[]{"b", 2L}, new Object[]{"", 1L}, @@ -482,8 +482,8 @@ public void testMultiValueStringPrepend() throws Exception ImmutableList results; if (useDefault) { results = ImmutableList.of( - new Object[]{"foo", 6L}, new Object[]{"", 3L}, + new Object[]{"foo", 3L}, new Object[]{"b", 2L}, new Object[]{"a", 1L}, new Object[]{"c", 1L}, @@ -491,7 +491,7 @@ public void testMultiValueStringPrepend() throws Exception ); } else { results = ImmutableList.of( - new Object[]{"foo", 6L}, + new Object[]{"foo", 4L}, new Object[]{null, 2L}, new Object[]{"b", 2L}, new Object[]{"", 1L}, @@ -606,16 +606,16 @@ public void testMultiValueStringConcat() throws Exception ImmutableList results; if (useDefault) { results = ImmutableList.of( - new Object[]{"", 6L}, new Object[]{"b", 4L}, + new Object[]{"", 3L}, new Object[]{"a", 2L}, new Object[]{"c", 2L}, new Object[]{"d", 2L} ); } else { results = ImmutableList.of( - new Object[]{null, 4L}, new Object[]{"b", 4L}, + new Object[]{null, 2L}, new Object[]{"", 2L}, new Object[]{"a", 2L}, new Object[]{"c", 2L}, @@ -1259,4 +1259,226 @@ public void testFilterOnMultiValueListFilterMatchLike() throws Exception ) ); } + + @Test + public void testMultiValueToArrayGroupAsArrayWithMultiValueDimension() throws Exception + { + // Cannot vectorize as we do not have support in native query subsystem for grouping on arrays as keys + cannotVectorize(); + testQuery( + "SELECT MV_TO_ARRAY(dim3), 
SUM(cnt) FROM druid.numfoo GROUP BY 1 ORDER BY 2 DESC", + QUERY_CONTEXT_NO_STRINGIFY_ARRAY, + ImmutableList.of( + GroupByQuery.builder() + .setDataSource(CalciteTests.DATASOURCE3) + .setInterval(querySegmentSpec(Filtration.eternity())) + .setGranularity(Granularities.ALL) + .setVirtualColumns(expressionVirtualColumn( + "v0", + "mv_to_array(\"dim3\")", + ColumnType.STRING_ARRAY + )) + .setDimensions( + dimensions( + new DefaultDimensionSpec("v0", "_d0", ColumnType.STRING_ARRAY) + ) + ) + .setAggregatorSpecs(aggregators(new LongSumAggregatorFactory("a0", "cnt"))) + .setLimitSpec(new DefaultLimitSpec( + ImmutableList.of(new OrderByColumnSpec( + "a0", + OrderByColumnSpec.Direction.DESCENDING, + StringComparators.NUMERIC + )), + Integer.MAX_VALUE + )) + .setContext(QUERY_CONTEXT_NO_STRINGIFY_ARRAY) + .build() + ), + useDefault ? ImmutableList.of( + new Object[]{null, 3L}, + new Object[]{ImmutableList.of("a", "b"), 1L}, + new Object[]{ImmutableList.of("b", "c"), 1L}, + new Object[]{ImmutableList.of("d"), 1L} + ) : + ImmutableList.of( + new Object[]{null, 2L}, + new Object[]{ImmutableList.of(""), 1L}, + new Object[]{ImmutableList.of("a", "b"), 1L}, + new Object[]{ImmutableList.of("b", "c"), 1L}, + new Object[]{ImmutableList.of("d"), 1L} + ) + ); + } + + + @Test + public void testMultiValueToArrayGroupAsArrayWithSingleValueDim() throws Exception + { + // Cannot vectorize due to usage of expressions. 
+ cannotVectorize(); + testQuery( + "SELECT MV_TO_ARRAY(dim1), SUM(cnt) FROM druid.numfoo GROUP BY 1 ORDER BY 2 DESC", + QUERY_CONTEXT_NO_STRINGIFY_ARRAY, + ImmutableList.of( + GroupByQuery.builder() + .setDataSource(CalciteTests.DATASOURCE3) + .setInterval(querySegmentSpec(Filtration.eternity())) + .setGranularity(Granularities.ALL) + .setVirtualColumns(expressionVirtualColumn( + "v0", + "mv_to_array(\"dim1\")", + ColumnType.STRING_ARRAY + )) + .setDimensions( + dimensions( + new DefaultDimensionSpec("v0", "_d0", ColumnType.STRING_ARRAY) + ) + ) + .setAggregatorSpecs(aggregators(new LongSumAggregatorFactory("a0", "cnt"))) + .setLimitSpec(new DefaultLimitSpec( + ImmutableList.of(new OrderByColumnSpec( + "a0", + OrderByColumnSpec.Direction.DESCENDING, + StringComparators.NUMERIC + )), + Integer.MAX_VALUE + )) + .setContext(QUERY_CONTEXT_NO_STRINGIFY_ARRAY) + .build() + ), + useDefault ? ImmutableList.of( + new Object[]{null, 1L}, + new Object[]{ImmutableList.of("1"), 1L}, + new Object[]{ImmutableList.of("10.1"), 1L}, + new Object[]{ImmutableList.of("2"), 1L}, + new Object[]{ImmutableList.of("abc"), 1L}, + new Object[]{ImmutableList.of("def"), 1L} + ) : + ImmutableList.of( + new Object[]{ImmutableList.of(""), 1L}, + new Object[]{ImmutableList.of("1"), 1L}, + new Object[]{ImmutableList.of("10.1"), 1L}, + new Object[]{ImmutableList.of("2"), 1L}, + new Object[]{ImmutableList.of("abc"), 1L}, + new Object[]{ImmutableList.of("def"), 1L} + ) + ); + } + + @Test + public void testMultiValueToArrayGroupAsArrayWithSingleValueDimIsNotConvertedToTopN() throws Exception + { + // Cannot vectorize due to usage of expressions. 
+ cannotVectorize(); + // Test for method {@link org.apache.druid.sql.calcite.rel.DruidQuery#toTopNQuery()} so that it does not convert + // group by on array to topn + testQuery( + "SELECT MV_TO_ARRAY(dim1), SUM(cnt) FROM druid.numfoo GROUP BY 1 ORDER BY 2 DESC limit 10", + QUERY_CONTEXT_NO_STRINGIFY_ARRAY, + ImmutableList.of( + GroupByQuery.builder() + .setDataSource(CalciteTests.DATASOURCE3) + .setInterval(querySegmentSpec(Filtration.eternity())) + .setGranularity(Granularities.ALL) + .setVirtualColumns(expressionVirtualColumn( + "v0", + "mv_to_array(\"dim1\")", + ColumnType.STRING_ARRAY + )) + .setDimensions( + dimensions( + new DefaultDimensionSpec("v0", "_d0", ColumnType.STRING_ARRAY) + ) + ) + .setAggregatorSpecs(aggregators(new LongSumAggregatorFactory("a0", "cnt"))) + .setLimitSpec(new DefaultLimitSpec( + ImmutableList.of(new OrderByColumnSpec( + "a0", + OrderByColumnSpec.Direction.DESCENDING, + StringComparators.NUMERIC + )), + 10 + )) + .setContext(QUERY_CONTEXT_NO_STRINGIFY_ARRAY) + .build() + ), + useDefault ? 
ImmutableList.of( + new Object[]{null, 1L}, + new Object[]{ImmutableList.of("1"), 1L}, + new Object[]{ImmutableList.of("10.1"), 1L}, + new Object[]{ImmutableList.of("2"), 1L}, + new Object[]{ImmutableList.of("abc"), 1L}, + new Object[]{ImmutableList.of("def"), 1L} + ) : + ImmutableList.of( + new Object[]{ImmutableList.of(""), 1L}, + new Object[]{ImmutableList.of("1"), 1L}, + new Object[]{ImmutableList.of("10.1"), 1L}, + new Object[]{ImmutableList.of("2"), 1L}, + new Object[]{ImmutableList.of("abc"), 1L}, + new Object[]{ImmutableList.of("def"), 1L} + ) + ); + } + + @Test + public void testMultiValueToArrayMoreArgs() throws Exception + { + testQueryThrows( + "SELECT MV_TO_ARRAY(dim3,dim3) FROM druid.numfoo", + exception -> { + exception.expect(SqlPlanningException.class); + exception.expectMessage("Invalid number of arguments to function"); + } + ); + } + + @Test + public void testMultiValueToArrayNoArgs() throws Exception + { + testQueryThrows( + "SELECT MV_TO_ARRAY() FROM druid.numfoo", + exception -> { + exception.expect(SqlPlanningException.class); + exception.expectMessage("Invalid number of arguments to function"); + } + ); + } + + @Test + public void testMultiValueToArrayArgsWithMultiValueDimFunc() throws Exception + { + testQueryThrows( + "SELECT MV_TO_ARRAY(concat(dim3,'c')) FROM druid.numfoo", + exception -> exception.expect(RuntimeException.class) + ); + } + + @Test + public void testMultiValueToArrayArgsWithSingleDimFunc() throws Exception + { + testQueryThrows( + "SELECT MV_TO_ARRAY(concat(dim1,'c')) FROM druid.numfoo", + exception -> exception.expect(RuntimeException.class) + ); + } + + @Test + public void testMultiValueToArrayArgsWithConstant() throws Exception + { + testQueryThrows( + "SELECT MV_TO_ARRAY('c') FROM druid.numfoo", + exception -> exception.expect(RuntimeException.class) + ); + } + + @Test + public void testMultiValueToArrayArgsWithArray() throws Exception + { + testQueryThrows( + "SELECT MV_TO_ARRAY(Array[1,2]) FROM 
druid.numfoo", + exception -> exception.expect(RuntimeException.class) + ); + } } diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteSelectQueryTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteSelectQueryTest.java index 86af750cc11b..049a9358940e 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteSelectQueryTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteSelectQueryTest.java @@ -108,7 +108,7 @@ public void testExpressionContainingNull() throws Exception new ExpressionVirtualColumn( "v0", "array('Hello',null)", - ColumnType.STRING, + ColumnType.STRING_ARRAY, ExprMacroTable.nil() ) )