Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improved framework for writing parametric functions #3926

Merged
merged 11 commits into from
Feb 23, 2016
70 changes: 70 additions & 0 deletions presto-docs/src/main/sphinx/develop/functions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,76 @@ for its native container type.
native container type when using ``@Nullable``. The method must be annotated with
``@Nullable`` if it can return ``NULL`` when the arguments are non-null.

Parametric Scalar Functions
---------------------------

Scalar functions that have type parameters have some additional complexity.
To make our previous example work with any type we need the following:

.. code-block:: java

@ScalarFunction(name = "is_null")
@Description("Returns TRUE if the argument is NULL")
public final class IsNullFunction
{
@TypeParameter("T")
@SqlType(StandardTypes.BOOLEAN)
public static boolean isNullSlice(@Nullable @SqlType("T") Slice value)
{
return (value == null);
}

@TypeParameter("T")
@SqlType(StandardTypes.BOOLEAN)
public static boolean isNullLong(@Nullable @SqlType("T") Long value)
{
return (value == null);
}

@TypeParameter("T")
@SqlType(StandardTypes.BOOLEAN)
public static boolean isNullDouble(@Nullable @SqlType("T") Double value)
{
return (value == null);
}

// ...and so on for each native container type
}

* ``@TypeParameter``:

The ``@TypeParameter`` annotation is used to declare a type parameter which can
be used in the argument types ``@SqlType`` annotation, or return type of the function.
It can also be used to annotate a parameter of type ``Type``. At runtime, the engine
will bind the concrete type to this parameter. ``@OperatorDependency`` may be used
to declare that an additional function for operating on the given type parameter is needed.
For example, the following function will only bind to types which have an equals function
defined:

.. code-block:: java

@ScalarFunction(name = "is_equal_or_null")
@Description("Returns TRUE if arguments are equal or both NULL")
public final class IsEqualOrNullFunction
{
@TypeParameter("T")
@SqlType(StandardTypes.BOOLEAN)
public static boolean isEqualOrNullSlice(
@OperatorDependency(operator = OperatorType.EQUAL, returnType = StandardTypes.BOOLEAN, argumentTypes = {"T", "T"}) MethodHandle equals,
@Nullable @SqlType("T") Slice value1,
@Nullable @SqlType("T") Slice value2)
{
if (value1 == null && value2 == null) {
return true;
}
if (value1 == null || value2 == null) {
return false;
}
return (boolean) equals.invokeExact(value1, value2);
}

// ...and so on for each native container type
}

Aggregation Functions
---------------------
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import com.facebook.presto.operator.aggregation.GenericAggregationFunctionFactory;
import com.facebook.presto.operator.aggregation.InternalAggregationFunction;
import com.facebook.presto.operator.scalar.JsonPath;
import com.facebook.presto.operator.scalar.ReflectionParametricScalar;
import com.facebook.presto.operator.scalar.ScalarFunction;
import com.facebook.presto.operator.scalar.ScalarOperator;
import com.facebook.presto.operator.window.ReflectionWindowFunctionSupplier;
Expand Down Expand Up @@ -179,13 +180,19 @@ private FunctionListBuilder operator(

public FunctionListBuilder scalar(Class<?> clazz)
{
ScalarFunction scalarAnnotation = clazz.getAnnotation(ScalarFunction.class);
ScalarOperator operatorAnnotation = clazz.getAnnotation(ScalarOperator.class);
if (scalarAnnotation != null || operatorAnnotation != null) {
functions.add(ReflectionParametricScalar.parseDefinition(clazz));
return this;
}
try {
boolean foundOne = false;
for (Method method : clazz.getMethods()) {
foundOne = processScalarFunction(method) || foundOne;
foundOne = processScalarOperator(method) || foundOne;
}
checkArgument(foundOne, "Expected class %s to contain at least one method annotated with @%s", clazz.getName(), ScalarFunction.class.getSimpleName());
checkArgument(foundOne, "Expected class %s to be annotated with @%s, or contain at least one method annotated with @%s", clazz.getName(), ScalarFunction.class.getSimpleName(), ScalarFunction.class.getSimpleName());
}
catch (IllegalAccessException e) {
throw Throwables.propagate(e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,14 @@
import com.facebook.presto.operator.aggregation.NumericHistogramAggregation;
import com.facebook.presto.operator.aggregation.RegressionAggregation;
import com.facebook.presto.operator.aggregation.VarianceAggregation;
import com.facebook.presto.operator.scalar.ArrayConcatFunction;
import com.facebook.presto.operator.scalar.ArrayDistinctFunction;
import com.facebook.presto.operator.scalar.ArrayElementAtFunction;
import com.facebook.presto.operator.scalar.ArrayFunctions;
import com.facebook.presto.operator.scalar.ArrayGreaterThanOperator;
import com.facebook.presto.operator.scalar.ArrayMaxFunction;
import com.facebook.presto.operator.scalar.ArrayMinFunction;
import com.facebook.presto.operator.scalar.ArrayRemoveFunction;
import com.facebook.presto.operator.scalar.BitwiseFunctions;
import com.facebook.presto.operator.scalar.ColorFunctions;
import com.facebook.presto.operator.scalar.CombineHashFunction;
Expand Down Expand Up @@ -144,25 +151,18 @@
import static com.facebook.presto.operator.aggregation.MinNAggregationFunction.MIN_N_AGGREGATION;
import static com.facebook.presto.operator.aggregation.MultimapAggregationFunction.MULTIMAP_AGG;
import static com.facebook.presto.operator.scalar.ArrayCardinalityFunction.ARRAY_CARDINALITY;
import static com.facebook.presto.operator.scalar.ArrayConcatFunction.ARRAY_CONCAT_FUNCTION;
import static com.facebook.presto.operator.scalar.ArrayConstructor.ARRAY_CONSTRUCTOR;
import static com.facebook.presto.operator.scalar.ArrayContains.ARRAY_CONTAINS;
import static com.facebook.presto.operator.scalar.ArrayDistinctFunction.ARRAY_DISTINCT_FUNCTION;
import static com.facebook.presto.operator.scalar.ArrayElementAtFunction.ARRAY_ELEMENT_AT_FUNCTION;
import static com.facebook.presto.operator.scalar.ArrayEqualOperator.ARRAY_EQUAL;
import static com.facebook.presto.operator.scalar.ArrayGreaterThanOperator.ARRAY_GREATER_THAN;
import static com.facebook.presto.operator.scalar.ArrayGreaterThanOrEqualOperator.ARRAY_GREATER_THAN_OR_EQUAL;
import static com.facebook.presto.operator.scalar.ArrayHashCodeOperator.ARRAY_HASH_CODE;
import static com.facebook.presto.operator.scalar.ArrayIntersectFunction.ARRAY_INTERSECT_FUNCTION;
import static com.facebook.presto.operator.scalar.ArrayJoin.ARRAY_JOIN;
import static com.facebook.presto.operator.scalar.ArrayJoin.ARRAY_JOIN_WITH_NULL_REPLACEMENT;
import static com.facebook.presto.operator.scalar.ArrayLessThanOperator.ARRAY_LESS_THAN;
import static com.facebook.presto.operator.scalar.ArrayLessThanOrEqualOperator.ARRAY_LESS_THAN_OR_EQUAL;
import static com.facebook.presto.operator.scalar.ArrayMaxFunction.ARRAY_MAX;
import static com.facebook.presto.operator.scalar.ArrayMinFunction.ARRAY_MIN;
import static com.facebook.presto.operator.scalar.ArrayNotEqualOperator.ARRAY_NOT_EQUAL;
import static com.facebook.presto.operator.scalar.ArrayPositionFunction.ARRAY_POSITION;
import static com.facebook.presto.operator.scalar.ArrayRemoveFunction.ARRAY_REMOVE_FUNCTION;
import static com.facebook.presto.operator.scalar.ArraySliceFunction.ARRAY_SLICE_FUNCTION;
import static com.facebook.presto.operator.scalar.ArraySortFunction.ARRAY_SORT_FUNCTION;
import static com.facebook.presto.operator.scalar.ArraySubscriptOperator.ARRAY_SUBSCRIPT;
Expand Down Expand Up @@ -345,12 +345,18 @@ public WindowFunctionSupplier load(SpecializedFunctionKey key)
.scalar(JsonOperators.class)
.scalar(FailureFunction.class)
.functions(IDENTITY_CAST, CAST_FROM_UNKNOWN)
.scalar(ArrayRemoveFunction.class)
.scalar(ArrayGreaterThanOperator.class)
.scalar(ArrayElementAtFunction.class)
.scalar(ArrayMinFunction.class)
.scalar(ArrayMaxFunction.class)
.scalar(ArrayDistinctFunction.class)
.scalar(ArrayConcatFunction.class)
.functions(ARRAY_CONTAINS, ARRAY_JOIN, ARRAY_JOIN_WITH_NULL_REPLACEMENT)
.functions(ARRAY_MIN, ARRAY_MAX)
.functions(ARRAY_TO_ARRAY_CAST, ARRAY_HASH_CODE, ARRAY_EQUAL, ARRAY_NOT_EQUAL, ARRAY_LESS_THAN, ARRAY_LESS_THAN_OR_EQUAL, ARRAY_GREATER_THAN, ARRAY_GREATER_THAN_OR_EQUAL)
.functions(ARRAY_CONCAT_FUNCTION, ARRAY_TO_ELEMENT_CONCAT_FUNCTION, ELEMENT_TO_ARRAY_CONCAT_FUNCTION)
.functions(ARRAY_TO_ARRAY_CAST, ARRAY_HASH_CODE, ARRAY_EQUAL, ARRAY_NOT_EQUAL, ARRAY_LESS_THAN, ARRAY_LESS_THAN_OR_EQUAL, ARRAY_GREATER_THAN_OR_EQUAL)
.functions(ARRAY_TO_ELEMENT_CONCAT_FUNCTION, ELEMENT_TO_ARRAY_CONCAT_FUNCTION)
.functions(MAP_EQUAL, MAP_NOT_EQUAL, MAP_HASH_CODE)
.functions(ARRAY_CONSTRUCTOR, ARRAY_SUBSCRIPT, ARRAY_ELEMENT_AT_FUNCTION, ARRAY_CARDINALITY, ARRAY_POSITION, ARRAY_SORT_FUNCTION, ARRAY_INTERSECT_FUNCTION, ARRAY_TO_JSON, JSON_TO_ARRAY, ARRAY_DISTINCT_FUNCTION, ARRAY_REMOVE_FUNCTION, ARRAY_SLICE_FUNCTION)
.functions(ARRAY_CONSTRUCTOR, ARRAY_SUBSCRIPT, ARRAY_CARDINALITY, ARRAY_POSITION, ARRAY_SORT_FUNCTION, ARRAY_INTERSECT_FUNCTION, ARRAY_TO_JSON, JSON_TO_ARRAY, ARRAY_SLICE_FUNCTION)
.functions(MAP_CONSTRUCTOR, MAP_CARDINALITY, MAP_SUBSCRIPT, MAP_TO_JSON, JSON_TO_MAP, MAP_KEYS, MAP_VALUES, MAP_CONCAT_FUNCTION)
.functions(MAP_AGG, MULTIMAP_AGG)
.function(HISTOGRAM)
Expand All @@ -377,22 +383,19 @@ public WindowFunctionSupplier load(SpecializedFunctionKey key)
addFunctions(builder.getFunctions());
}

@Nullable
private static Signature bindSignature(Signature signature, List<? extends Type> types, boolean allowCoercion, TypeManager typeManager)
public static Signature bindSignature(Signature signature, Map<String, Type> boundParameters, int arity)
{
checkArgument((signature.isVariableArity() && arity >= signature.getArgumentTypes().size() - 1) || arity == signature.getArgumentTypes().size(),
"Illegal arity %d for function %s", arity, signature);
List<TypeSignature> argumentTypes = signature.getArgumentTypes();
Map<String, Type> boundParameters = signature.bindTypeParameters(types, allowCoercion, typeManager);
if (boundParameters == null) {
return null;
}
ImmutableList.Builder<TypeSignature> boundArguments = ImmutableList.builder();
for (int i = 0; i < argumentTypes.size() - 1; i++) {
boundArguments.add(argumentTypes.get(i).bindParameters(boundParameters));
}
if (!argumentTypes.isEmpty()) {
TypeSignature lastArgument = argumentTypes.get(argumentTypes.size() - 1).bindParameters(boundParameters);
if (signature.isVariableArity()) {
for (int i = 0; i < types.size() - (argumentTypes.size() - 1); i++) {
for (int i = 0; i < arity - (argumentTypes.size() - 1); i++) {
boundArguments.add(lastArgument);
}
}
Expand Down Expand Up @@ -435,7 +438,11 @@ public Signature resolveFunction(QualifiedName name, List<TypeSignature> paramet
// search for exact match
Signature match = null;
for (SqlFunction function : candidates) {
Signature signature = bindSignature(function.getSignature(), resolvedTypes, false, typeManager);
Map<String, Type> boundParameters = function.getSignature().bindTypeParameters(resolvedTypes, false, typeManager);
if (boundParameters == null) {
continue;
}
Signature signature = bindSignature(function.getSignature(), boundParameters, resolvedTypes.size());
if (signature != null) {
checkArgument(match == null, "Ambiguous call to %s with parameters %s", name, parameterTypes);
match = signature;
Expand All @@ -448,7 +455,11 @@ public Signature resolveFunction(QualifiedName name, List<TypeSignature> paramet

// search for coerced match
for (SqlFunction function : candidates) {
Signature signature = bindSignature(function.getSignature(), resolvedTypes, true, typeManager);
Map<String, Type> boundParameters = function.getSignature().bindTypeParameters(resolvedTypes, true, typeManager);
if (boundParameters == null) {
continue;
}
Signature signature = bindSignature(function.getSignature(), boundParameters, resolvedTypes.size());
if (signature != null) {
// TODO: This should also check for ambiguities
match = signature;
Expand Down Expand Up @@ -494,7 +505,10 @@ public Signature resolveFunction(QualifiedName name, List<TypeSignature> paramet
if (parameterTypes.size() == 1 && parameterTypes.get(0).getBase().equals(StandardTypes.ROW)) {
SqlFunction fieldReference = getRowFieldReference(name.getSuffix(), parameterTypes.get(0));
if (fieldReference != null) {
return bindSignature(fieldReference.getSignature(), resolvedTypes, true, typeManager);
Map<String, Type> boundParameters = fieldReference.getSignature().bindTypeParameters(resolvedTypes, true, typeManager);
if (boundParameters != null) {
return bindSignature(fieldReference.getSignature(), boundParameters, resolvedTypes.size());
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,11 @@ public static Signature internalOperator(OperatorType operator, Type returnType,
return internalScalarFunction(mangleOperatorName(operator.name()), returnType.getTypeSignature(), argumentTypes.stream().map(Type::getTypeSignature).collect(toImmutableList()));
}

public static Signature internalOperator(OperatorType operator, String returnType, List<String> argumentTypes)
{
return internalScalarFunction(mangleOperatorName(operator.name()), returnType, argumentTypes);
}

public static Signature internalOperator(String name, TypeSignature returnType, List<TypeSignature> argumentTypes)
{
return internalScalarFunction(mangleOperatorName(name), returnType, argumentTypes);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,69 +13,50 @@
*/
package com.facebook.presto.operator.scalar;

import com.facebook.presto.metadata.FunctionRegistry;
import com.facebook.presto.metadata.SqlScalarFunction;
import com.facebook.presto.operator.Description;
import com.facebook.presto.spi.PageBuilder;
import com.facebook.presto.spi.block.Block;
import com.facebook.presto.spi.block.BlockBuilder;
import com.facebook.presto.spi.type.Type;
import com.facebook.presto.spi.type.TypeManager;
import com.facebook.presto.type.SqlType;
import com.google.common.collect.ImmutableList;

import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodType;
import java.util.Map;
import java.util.Optional;

import static com.facebook.presto.metadata.Signature.typeParameter;
import static com.facebook.presto.spi.StandardErrorCode.FUNCTION_IMPLEMENTATION_ERROR;
import static com.facebook.presto.util.Reflection.constructorMethodHandle;
import static com.facebook.presto.util.Reflection.methodHandle;
import static java.lang.invoke.MethodHandles.permuteArguments;

public class ArrayConcatFunction
extends SqlScalarFunction
@ScalarFunction("concat")
@Description("Concatenates given arrays")
public final class ArrayConcatFunction
{
public static final ArrayConcatFunction ARRAY_CONCAT_FUNCTION = new ArrayConcatFunction();
private static final String FUNCTION_NAME = "concat";
private static final MethodHandle CONSTRUCTOR = constructorMethodHandle(FUNCTION_IMPLEMENTATION_ERROR, ArrayConcatUtils.class, Type.class);
private static final MethodHandle METHOD_HANDLE = methodHandle(ArrayConcatUtils.class, FUNCTION_NAME, Type.class, Block.class, Block.class);
private final PageBuilder pageBuilder;

public ArrayConcatFunction()
@TypeParameter("E")
public ArrayConcatFunction(@TypeParameter("E") Type elementType)
{
super(FUNCTION_NAME, ImmutableList.of(typeParameter("E")), "array(E)", ImmutableList.of("array(E)", "array(E)"));
pageBuilder = new PageBuilder(ImmutableList.of(elementType));
}

@Override
public boolean isHidden()
@TypeParameter("E")
@SqlType("array(E)")
public Block concat(@TypeParameter("E") Type elementType, @SqlType("array(E)") Block leftBlock, @SqlType("array(E)") Block rightBlock)
{
return false;
}

@Override
public boolean isDeterministic()
{
return true;
}
if (leftBlock.getPositionCount() == 0) {
return rightBlock;
}
if (rightBlock.getPositionCount() == 0) {
return leftBlock;
}

@Override
public String getDescription()
{
return "Concatenates given arrays";
}
if (pageBuilder.isFull()) {
pageBuilder.reset();
}

@Override
public ScalarFunctionImplementation specialize(Map<String, Type> types, int arity, TypeManager typeManager, FunctionRegistry functionRegistry)
{
Type elementType = types.get("E");
MethodType newType = METHOD_HANDLE.type().changeParameterType(0, Type.class).changeParameterType(1, ArrayConcatUtils.class);
int[] permutedIndices = new int[newType.parameterCount()];
permutedIndices[0] = 1;
permutedIndices[1] = 0;
for (int i = 2; i < permutedIndices.length; i++) {
permutedIndices[i] = i;
BlockBuilder blockBuilder = pageBuilder.getBlockBuilder(0);
for (int i = 0; i < leftBlock.getPositionCount(); i++) {
elementType.appendTo(leftBlock, i, blockBuilder);
}
for (int i = 0; i < rightBlock.getPositionCount(); i++) {
elementType.appendTo(rightBlock, i, blockBuilder);
}
MethodHandle methodHandle = permuteArguments(METHOD_HANDLE, newType, permutedIndices);
methodHandle = methodHandle.bindTo(elementType);
MethodHandle instanceFactory = CONSTRUCTOR.bindTo(elementType);
return new ScalarFunctionImplementation(false, ImmutableList.of(false, false), methodHandle, Optional.of(instanceFactory), isDeterministic());
int total = leftBlock.getPositionCount() + rightBlock.getPositionCount();
pageBuilder.declarePositions(total);
return blockBuilder.getRegion(blockBuilder.getPositionCount() - total, total);
}
}
Loading