Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add function resolution to SPI #12588

Merged
merged 13 commits into from
Sep 4, 2022
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import io.trino.spi.connector.ConnectorPageSinkProvider;
import io.trino.spi.connector.ConnectorPageSourceProvider;
import io.trino.spi.connector.ConnectorSplitManager;
import io.trino.spi.function.FunctionProvider;

import javax.inject.Inject;
import javax.inject.Singleton;
Expand Down Expand Up @@ -162,6 +163,13 @@ public static CatalogServiceProvider<Optional<ConnectorAccessControl>> createAcc
return new ConnectorCatalogServiceProvider<>("access control", connectorServicesProvider, ConnectorServices::getAccessControl);
}

@Provides
@Singleton
public static CatalogServiceProvider<FunctionProvider> createFunctionProvider(ConnectorServicesProvider connectorServicesProvider)
{
return new ConnectorCatalogServiceProvider<>("function provider", connectorServicesProvider, ConnectorServices::getFunctionProvider);
}

private static class ConnectorAccessControlLazyRegister
{
@Inject
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import io.trino.spi.connector.SystemTable;
import io.trino.spi.connector.TableProcedureMetadata;
import io.trino.spi.eventlistener.EventListener;
import io.trino.spi.function.FunctionProvider;
import io.trino.spi.procedure.Procedure;
import io.trino.spi.ptf.ArgumentSpecification;
import io.trino.spi.ptf.ConnectorTableFunction;
Expand Down Expand Up @@ -65,6 +66,7 @@ public class ConnectorServices
private final Set<SystemTable> systemTables;
private final CatalogProcedures procedures;
private final CatalogTableProcedures tableProcedures;
private final Optional<FunctionProvider> functionProvider;
private final CatalogTableFunctions tableFunctions;
private final Optional<ConnectorSplitManager> splitManager;
private final Optional<ConnectorPageSourceProvider> pageSourceProvider;
Expand Down Expand Up @@ -101,6 +103,8 @@ public ConnectorServices(CatalogHandle catalogHandle, Connector connector, Runna
requireNonNull(procedures, format("Connector '%s' returned a null table procedures set", catalogHandle));
this.tableProcedures = new CatalogTableProcedures(tableProcedures);

this.functionProvider = requireNonNull(connector.getFunctionProvider(), format("Connector '%s' returned a null function provider", catalogHandle));

Set<ConnectorTableFunction> tableFunctions = connector.getTableFunctions();
requireNonNull(tableFunctions, format("Connector '%s' returned a null table functions set", catalogHandle));
this.tableFunctions = new CatalogTableFunctions(tableFunctions);
Expand Down Expand Up @@ -226,6 +230,12 @@ public CatalogTableProcedures getTableProcedures()
return tableProcedures;
}

public FunctionProvider getFunctionProvider()
{
checkArgument(functionProvider.isPresent(), "Connector '%s' does not have functions", catalogHandle);
return functionProvider.get();
}

public CatalogTableFunctions getTableFunctions()
{
return tableFunctions;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@
import io.trino.Session;
import io.trino.collect.cache.NonEvictableCache;
import io.trino.json.ir.IrPathNode;
import io.trino.metadata.BoundSignature;
import io.trino.metadata.Metadata;
import io.trino.metadata.OperatorNotFoundException;
import io.trino.metadata.ResolvedFunction;
import io.trino.spi.connector.ConnectorSession;
import io.trino.spi.function.BoundSignature;
import io.trino.spi.function.OperatorType;
import io.trino.spi.type.Type;
import io.trino.spi.type.TypeManager;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
*/
package io.trino.metadata;

import io.trino.spi.function.SchemaFunctionName;

import java.util.Objects;

public final class CatalogSchemaFunctionName
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
package io.trino.metadata;

import com.google.common.collect.Maps;
import io.trino.spi.function.SchemaFunctionName;
import io.trino.spi.ptf.ConnectorTableFunction;

import javax.annotation.concurrent.ThreadSafe;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
package io.trino.metadata;

import com.google.common.collect.ImmutableSortedMap;
import io.trino.spi.function.BoundSignature;
import io.trino.spi.function.FunctionId;
import io.trino.spi.type.Type;

import java.util.Map;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,16 @@
*/
package io.trino.metadata;

import io.trino.operator.aggregation.AggregationMetadata;
import io.trino.operator.window.WindowFunctionSupplier;
import io.trino.spi.function.AggregationFunctionMetadata;
import io.trino.spi.function.AggregationImplementation;
import io.trino.spi.function.BoundSignature;
import io.trino.spi.function.FunctionDependencies;
import io.trino.spi.function.FunctionDependencyDeclaration;
import io.trino.spi.function.FunctionId;
import io.trino.spi.function.FunctionMetadata;
import io.trino.spi.function.InvocationConvention;
import io.trino.spi.function.ScalarFunctionImplementation;
import io.trino.spi.function.WindowFunctionSupplier;

import java.util.Collection;

Expand All @@ -27,13 +34,13 @@ public interface FunctionBundle

FunctionDependencyDeclaration getFunctionDependencies(FunctionId functionId, BoundSignature boundSignature);

FunctionInvoker getScalarFunctionInvoker(
ScalarFunctionImplementation getScalarFunctionImplementation(
FunctionId functionId,
BoundSignature boundSignature,
FunctionDependencies functionDependencies,
InvocationConvention invocationConvention);

AggregationMetadata getAggregateFunctionImplementation(FunctionId functionId, BoundSignature boundSignature, FunctionDependencies functionDependencies);
AggregationImplementation getAggregationImplementation(FunctionId functionId, BoundSignature boundSignature, FunctionDependencies functionDependencies);

WindowFunctionSupplier getWindowFunctionImplementation(FunctionId functionId, BoundSignature boundSignature, FunctionDependencies functionDependencies);
WindowFunctionSupplier getWindowFunctionSupplier(FunctionId functionId, BoundSignature boundSignature, FunctionDependencies functionDependencies);
}

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,21 @@
import com.google.common.util.concurrent.UncheckedExecutionException;
import io.trino.FeaturesConfig;
import io.trino.collect.cache.NonEvictableCache;
import io.trino.operator.aggregation.AggregationMetadata;
import io.trino.operator.window.WindowFunctionSupplier;
import io.trino.connector.CatalogServiceProvider;
import io.trino.connector.system.GlobalSystemConnector;
import io.trino.spi.TrinoException;
import io.trino.spi.block.Block;
import io.trino.spi.connector.ConnectorSession;
import io.trino.spi.function.AggregationImplementation;
import io.trino.spi.function.BoundSignature;
import io.trino.spi.function.FunctionDependencies;
import io.trino.spi.function.FunctionId;
import io.trino.spi.function.FunctionProvider;
import io.trino.spi.function.InOut;
import io.trino.spi.function.InvocationConvention;
import io.trino.spi.function.InvocationConvention.InvocationArgumentConvention;
import io.trino.spi.function.ScalarFunctionImplementation;
import io.trino.spi.function.WindowFunctionSupplier;
import io.trino.spi.type.Type;
import io.trino.spi.type.TypeOperators;
import io.trino.type.BlockTypeOperators;
Expand All @@ -51,14 +58,15 @@

public class FunctionManager
{
private final NonEvictableCache<FunctionKey, FunctionInvoker> specializedScalarCache;
private final NonEvictableCache<FunctionKey, AggregationMetadata> specializedAggregationCache;
private final NonEvictableCache<FunctionKey, ScalarFunctionImplementation> specializedScalarCache;
private final NonEvictableCache<FunctionKey, AggregationImplementation> specializedAggregationCache;
private final NonEvictableCache<FunctionKey, WindowFunctionSupplier> specializedWindowCache;

private final CatalogServiceProvider<FunctionProvider> functionProviders;
private final GlobalFunctionCatalog globalFunctionCatalog;

@Inject
public FunctionManager(GlobalFunctionCatalog globalFunctionCatalog)
public FunctionManager(CatalogServiceProvider<FunctionProvider> functionProviders, GlobalFunctionCatalog globalFunctionCatalog)
{
specializedScalarCache = buildNonEvictableCache(CacheBuilder.newBuilder()
.maximumSize(1000)
Expand All @@ -72,80 +80,92 @@ public FunctionManager(GlobalFunctionCatalog globalFunctionCatalog)
.maximumSize(1000)
.expireAfterWrite(1, HOURS));

this.globalFunctionCatalog = globalFunctionCatalog;
this.functionProviders = requireNonNull(functionProviders, "functionProviders is null");
this.globalFunctionCatalog = requireNonNull(globalFunctionCatalog, "globalFunctionCatalog is null");
}

public FunctionInvoker getScalarFunctionInvoker(ResolvedFunction resolvedFunction, InvocationConvention invocationConvention)
public ScalarFunctionImplementation getScalarFunctionImplementation(ResolvedFunction resolvedFunction, InvocationConvention invocationConvention)
{
try {
return uncheckedCacheGet(specializedScalarCache, new FunctionKey(resolvedFunction, invocationConvention), () -> getScalarFunctionInvokerInternal(resolvedFunction, invocationConvention));
return uncheckedCacheGet(specializedScalarCache, new FunctionKey(resolvedFunction, invocationConvention), () -> getScalarFunctionImplementationInternal(resolvedFunction, invocationConvention));
}
catch (UncheckedExecutionException e) {
throwIfInstanceOf(e.getCause(), TrinoException.class);
throw new RuntimeException(e.getCause());
}
}

private FunctionInvoker getScalarFunctionInvokerInternal(ResolvedFunction resolvedFunction, InvocationConvention invocationConvention)
private ScalarFunctionImplementation getScalarFunctionImplementationInternal(ResolvedFunction resolvedFunction, InvocationConvention invocationConvention)
{
FunctionDependencies functionDependencies = getFunctionDependencies(resolvedFunction);
FunctionInvoker functionInvoker = globalFunctionCatalog.getScalarFunctionInvoker(
ScalarFunctionImplementation scalarFunctionImplementation = getFunctionProvider(resolvedFunction).getScalarFunctionImplementation(
resolvedFunction.getFunctionId(),
resolvedFunction.getSignature(),
functionDependencies,
invocationConvention);
verifyMethodHandleSignature(resolvedFunction.getSignature(), functionInvoker, invocationConvention);
return functionInvoker;
verifyMethodHandleSignature(resolvedFunction.getSignature(), scalarFunctionImplementation, invocationConvention);
return scalarFunctionImplementation;
}

public AggregationMetadata getAggregateFunctionImplementation(ResolvedFunction resolvedFunction)
public AggregationImplementation getAggregationImplementation(ResolvedFunction resolvedFunction)
{
try {
return uncheckedCacheGet(specializedAggregationCache, new FunctionKey(resolvedFunction), () -> getAggregateFunctionImplementationInternal(resolvedFunction));
return uncheckedCacheGet(specializedAggregationCache, new FunctionKey(resolvedFunction), () -> getAggregationImplementationInternal(resolvedFunction));
}
catch (UncheckedExecutionException e) {
throwIfInstanceOf(e.getCause(), TrinoException.class);
throw new RuntimeException(e.getCause());
}
}

private AggregationMetadata getAggregateFunctionImplementationInternal(ResolvedFunction resolvedFunction)
private AggregationImplementation getAggregationImplementationInternal(ResolvedFunction resolvedFunction)
{
FunctionDependencies functionDependencies = getFunctionDependencies(resolvedFunction);
return globalFunctionCatalog.getAggregateFunctionImplementation(
return getFunctionProvider(resolvedFunction).getAggregationImplementation(
resolvedFunction.getFunctionId(),
resolvedFunction.getSignature(),
functionDependencies);
}

public WindowFunctionSupplier getWindowFunctionImplementation(ResolvedFunction resolvedFunction)
public WindowFunctionSupplier getWindowFunctionSupplier(ResolvedFunction resolvedFunction)
{
try {
return uncheckedCacheGet(specializedWindowCache, new FunctionKey(resolvedFunction), () -> getWindowFunctionImplementationInternal(resolvedFunction));
return uncheckedCacheGet(specializedWindowCache, new FunctionKey(resolvedFunction), () -> getWindowFunctionSupplierInternal(resolvedFunction));
}
catch (UncheckedExecutionException e) {
throwIfInstanceOf(e.getCause(), TrinoException.class);
throw new RuntimeException(e.getCause());
}
}

private WindowFunctionSupplier getWindowFunctionImplementationInternal(ResolvedFunction resolvedFunction)
private WindowFunctionSupplier getWindowFunctionSupplierInternal(ResolvedFunction resolvedFunction)
{
FunctionDependencies functionDependencies = getFunctionDependencies(resolvedFunction);
return globalFunctionCatalog.getWindowFunctionImplementation(
return getFunctionProvider(resolvedFunction).getWindowFunctionSupplier(
resolvedFunction.getFunctionId(),
resolvedFunction.getSignature(),
functionDependencies);
}

private FunctionDependencies getFunctionDependencies(ResolvedFunction resolvedFunction)
{
return new FunctionDependencies(this::getScalarFunctionInvoker, resolvedFunction.getTypeDependencies(), resolvedFunction.getFunctionDependencies());
return new InternalFunctionDependencies(this::getScalarFunctionImplementation, resolvedFunction.getTypeDependencies(), resolvedFunction.getFunctionDependencies());
}

private static void verifyMethodHandleSignature(BoundSignature boundSignature, FunctionInvoker functionInvoker, InvocationConvention convention)
private FunctionProvider getFunctionProvider(ResolvedFunction resolvedFunction)
{
MethodHandle methodHandle = functionInvoker.getMethodHandle();
if (resolvedFunction.getCatalogHandle().equals(GlobalSystemConnector.CATALOG_HANDLE)) {
return globalFunctionCatalog;
}

FunctionProvider functionProvider = functionProviders.getService(resolvedFunction.getCatalogHandle());
checkArgument(functionProvider != null, "No function provider for catalog: '%s' (function '%s')", resolvedFunction.getCatalogHandle(), resolvedFunction.getSignature().getName());
return functionProvider;
}

private static void verifyMethodHandleSignature(BoundSignature boundSignature, ScalarFunctionImplementation scalarFunctionImplementation, InvocationConvention convention)
{
MethodHandle methodHandle = scalarFunctionImplementation.getMethodHandle();
MethodType methodType = methodHandle.type();

checkArgument(convention.getArgumentConventions().size() == boundSignature.getArgumentTypes().size(),
Expand All @@ -155,16 +175,16 @@ private static void verifyMethodHandleSignature(BoundSignature boundSignature, F
.mapToInt(InvocationArgumentConvention::getParameterCount)
.sum();
expectedParameterCount += methodType.parameterList().stream().filter(ConnectorSession.class::equals).count();
if (functionInvoker.getInstanceFactory().isPresent()) {
if (scalarFunctionImplementation.getInstanceFactory().isPresent()) {
expectedParameterCount++;
}
checkArgument(expectedParameterCount == methodType.parameterCount(),
"Expected %s method parameters, but got %s", expectedParameterCount, methodType.parameterCount());

int parameterIndex = 0;
if (functionInvoker.getInstanceFactory().isPresent()) {
if (scalarFunctionImplementation.getInstanceFactory().isPresent()) {
verifyFunctionSignature(convention.supportsInstanceFactory(), "Method requires instance factory, but calling convention does not support an instance factory");
MethodHandle factoryMethod = functionInvoker.getInstanceFactory().orElseThrow();
MethodHandle factoryMethod = scalarFunctionImplementation.getInstanceFactory().orElseThrow();
verifyFunctionSignature(methodType.parameterType(parameterIndex).equals(factoryMethod.type().returnType()), "Invalid return type");
parameterIndex++;
}
Expand Down Expand Up @@ -203,7 +223,7 @@ private static void verifyMethodHandleSignature(BoundSignature boundSignature, F
verifyFunctionSignature(parameterType.equals(InOut.class), "Expected IN_OUT argument type to be InOut");
break;
case FUNCTION:
Class<?> lambdaInterface = functionInvoker.getLambdaInterfaces().get(lambdaArgumentIndex);
Class<?> lambdaInterface = scalarFunctionImplementation.getLambdaInterfaces().get(lambdaArgumentIndex);
verifyFunctionSignature(parameterType.equals(lambdaInterface),
"Expected function interface to be %s, but is %s", lambdaInterface, parameterType);
lambdaArgumentIndex++;
Expand Down Expand Up @@ -297,6 +317,6 @@ public static FunctionManager createTestingFunctionManager()
GlobalFunctionCatalog functionCatalog = new GlobalFunctionCatalog();
functionCatalog.addFunctions(SystemFunctionBundle.create(new FeaturesConfig(), typeOperators, new BlockTypeOperators(typeOperators), UNKNOWN));
functionCatalog.addFunctions(new InternalFunctionBundle(new LiteralFunction(new InternalBlockEncodingSerde(new BlockEncodingManager(), TESTING_TYPE_MANAGER))));
return new FunctionManager(functionCatalog);
return new FunctionManager(CatalogServiceProvider.fail(), functionCatalog);
}
}
Loading