Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

UDF plugin API #507

Merged
merged 7 commits into from
Aug 8, 2022
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ private[offline] object PostTransformationUtil {
featureType: FeatureTypes): Try[FeatureValue] = Try {
val args = Map(featureName -> Some(featureValue))
val variableResolverFactory = new FeatureVariableResolverFactory(args)
val transformedValue = MVEL.executeExpression(compiledExpression, featureValue, variableResolverFactory)
val transformedValue = MvelContext.executeExpressionWithPluginSupport(compiledExpression, featureValue, variableResolverFactory)
CoercionUtilsScala.coerceToFeatureValue(transformedValue, featureType)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ private[offline] class SimpleConfigurableAnchorExtractor( @JsonProperty("key") k
// be more strict for resolving keys (don't swallow exceptions)
keyExpression.map(k =>
try {
Option(MVEL.executeExpression(k, datum)) match {
Option(MvelContext.executeExpressionWithPluginSupport(k, datum)) match {
case None => null
case Some(keys) => keys.toString
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package com.linkedin.feathr.offline.client.plugins

import scala.collection.mutable

/**
* A shared registry for loading [[UdfAdaptor]]s, which basically can tell Feathr's runtime how to support different
* kinds of "external" UDFs not natively known to Feathr, but which have similar behavior to Feathr's.
*
* All "external" UDF classes are required to have a public default zero-arg constructor.
*/
object FeathrUdfPluginContext {
  // NOTE: kept public for backward compatibility with existing callers; mutate only
  // via registerUdfAdaptor so the synchronization below is honored.
  val registeredUdfAdaptors = mutable.Buffer[UdfAdaptor[_]]()

  /**
   * Registers a [[UdfAdaptor]] so Feathr's runtime can recognize and adapt the external
   * UDF type it handles. Safe to call from multiple threads.
   *
   * @param adaptor the adaptor to add to the shared registry
   */
  def registerUdfAdaptor(adaptor: UdfAdaptor[_]): Unit = {
    this.synchronized {
      registeredUdfAdaptors += adaptor
    }
  }

  /**
   * Finds a registered adaptor that can handle the given external UDF class, if any.
   * Returns the first matching adaptor in registration order.
   *
   * @param clazz class of the external (non-Feathr) UDF
   * @return the first adaptor whose `canAdapt` accepts `clazz`, or None
   */
  def getRegisteredUdfAdaptor(clazz: Class[_]): Option[UdfAdaptor[_]] = {
    // BUGFIX: the read must also be synchronized — registration is synchronized, but
    // iterating a mutable.Buffer concurrently with a write is not thread-safe and the
    // original unsynchronized find() could observe a partially-updated buffer.
    this.synchronized {
      registeredUdfAdaptors.find(_.canAdapt(clazz))
    }
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package com.linkedin.feathr.offline.client.plugins

import com.linkedin.feathr.common.{AnchorExtractor, FeatureDerivationFunction}
import com.linkedin.feathr.sparkcommon.{SimpleAnchorExtractorSpark, SourceKeyExtractor}

/**
* Tells Feathr how to use UDFs that are defined in "external" non-Feathr UDF classes (i.e. that don't extend from
* Feathr's AnchorExtractor or other UDF traits). An adaptor must match the external UDF class to a specific kind of
* Feathr UDF – see child traits below for the various options.
*
* All "external" UDF classes are required to have a public default zero-arg constructor.
*
* @tparam T the internal Feathr UDF class whose behavior the external UDF can be translated to
*/
sealed trait UdfAdaptor[T] extends Serializable {
  /**
   * Indicates whether this adaptor can be applied to an object of the provided class.
   *
   * Implementations should usually look like <pre>classOf[UdfTraitThatIsNotPartOfFeathr].isAssignableFrom(clazz)</pre>
   *
   * @param clazz some external UDF type
   * @return true if this adaptor can "adapt" the given class type; false otherwise
   */
  def canAdapt(clazz: Class[_]): Boolean

  /**
   * Returns an instance of a Feathr UDF, that follows the behavior of some external UDF instance, e.g. via delegation.
   *
   * Callers are expected to check [[canAdapt]] first; behavior for a non-adaptable
   * `externalUdf` is implementation-defined.
   *
   * @param externalUdf instance of the "external" UDF
   * @return the Feathr UDF
   */
  def adaptUdf(externalUdf: AnyRef): T
}

/**
 * An adaptor that can "tell Feathr how to use" a UDF type that can act in place of [[AnchorExtractor]]
 * (row-based anchor feature extraction).
 */
trait AnchorExtractorAdaptor extends UdfAdaptor[AnchorExtractor[_]]

/**
 * An adaptor that can "tell Feathr how to use" a UDF type that can act in place of [[SimpleAnchorExtractorSpark]]
 * (DataFrame-based anchor feature extraction).
 */
trait SimpleAnchorExtractorSparkAdaptor extends UdfAdaptor[SimpleAnchorExtractorSpark]

/**
 * An adaptor that can "tell Feathr how to use" a UDF type that can act in place of [[FeatureDerivationFunction]]
 * (row-based derived-feature computation).
 */
trait FeatureDerivationFunctionAdaptor extends UdfAdaptor[FeatureDerivationFunction]

/**
 * An adaptor that can "tell Feathr how to use" a UDF type that can act in place of [[SourceKeyExtractor]]
 * (key extraction from a source).
 */
trait SourceKeyExtractorAdaptor extends UdfAdaptor[SourceKeyExtractor]
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import com.linkedin.feathr.offline.ErasedEntityTaggedFeature
import com.linkedin.feathr.offline.anchored.anchorExtractor.{SQLConfigurableAnchorExtractor, SimpleConfigurableAnchorExtractor, TimeWindowConfigurableAnchorExtractor}
import com.linkedin.feathr.offline.anchored.feature.{FeatureAnchor, FeatureAnchorWithSource}
import com.linkedin.feathr.offline.anchored.keyExtractor.{MVELSourceKeyExtractor, SQLSourceKeyExtractor}
import com.linkedin.feathr.offline.client.plugins.{AnchorExtractorAdaptor, FeathrUdfPluginContext, FeatureDerivationFunctionAdaptor, SimpleAnchorExtractorSparkAdaptor, SourceKeyExtractorAdaptor}
import com.linkedin.feathr.offline.config.location.{InputLocation, Jdbc, KafkaEndpoint, LocationUtils, SimplePath}
import com.linkedin.feathr.offline.derived._
import com.linkedin.feathr.offline.derived.functions.{MvelFeatureDerivationFunction, SQLFeatureDerivationFunction, SeqJoinDerivationFunction, SimpleMvelDerivationFunction}
Expand Down Expand Up @@ -254,14 +255,23 @@ private[offline] class AnchorLoader extends JsonDeserializer[FeatureAnchor] {
// if it is UDF, no extra information other than the class name is required
val anchorExtractor: AnyRef = codec.treeToValue(node, extractorClass).asInstanceOf[AnyRef]

anchorExtractor match {
case extractor: AnchorExtractorBase[_] =>
val extractorNode = node.get("extractor")
if (extractorNode != null && extractorNode.get("params") != null) {
// init the param into the extractor
val config = ConfigFactory.parseString(extractorNode.get("params").toString)
extractor.init(config)
val extractorNode = node.get("extractor")
if (extractorNode != null && extractorNode.get("params") != null) {
// init the param into the extractor
val config = ConfigFactory.parseString(extractorNode.get("params").toString)
anchorExtractor match {
case aebExtractor: AnchorExtractorBase[_] =>
aebExtractor.init(config)
case otherExtractor =>
FeathrUdfPluginContext.getRegisteredUdfAdaptor(extractorClass) match {
case Some(adaptor: SimpleAnchorExtractorSparkAdaptor) =>
adaptor.adaptUdf(otherExtractor).init(config)
jaymo001 marked this conversation as resolved.
Show resolved Hide resolved
case Some(adaptor: AnchorExtractorAdaptor) =>
adaptor.adaptUdf(otherExtractor).init(config)
case _ =>
throw new FeathrConfigException(ErrorLabel.FEATHR_ERROR, s"Unknown extractor type ${extractorClass}")
}
}
}

// cast the the extractor class to AnchorExtractor[Any]
Expand Down Expand Up @@ -362,8 +372,13 @@ private[offline] class AnchorLoader extends JsonDeserializer[FeatureAnchor] {
anchorExtractorBase: AnyRef,
lateralViewParameters: Option[LateralViewParams]): SourceKeyExtractor = {
Option(node.get("keyExtractor")).map(_.textValue) match {
case Some(keyExtractor) =>
Class.forName(keyExtractor).newInstance().asInstanceOf[SourceKeyExtractor]
case Some(keyExtractorClassName) =>
val keyExtractorClass = Class.forName(keyExtractorClassName)
val newInstance = keyExtractorClass.getDeclaredConstructor().newInstance().asInstanceOf[AnyRef]
FeathrUdfPluginContext.getRegisteredUdfAdaptor(keyExtractorClass) match {
case Some(adaptor: SourceKeyExtractorAdaptor) => adaptor.adaptUdf(newInstance)
case _ => newInstance.asInstanceOf[SourceKeyExtractor]
}
case _ =>
Option(node.get("key")) match {
case Some(keyNode) =>
Expand Down Expand Up @@ -545,8 +560,16 @@ private[offline] class DerivationLoader extends JsonDeserializer[DerivedFeature]
loadAdvancedDerivedFeature(x, codec)
} else if (x.has("class")) {
val config = codec.treeToValue(x, classOf[CustomDerivedFeatureConfig])
val derivationFunction = config.`class`.newInstance().asInstanceOf[AnyRef]
val derivationFunctionClass = config.`class`
val derivationFunction = derivationFunctionClass.getDeclaredConstructor().newInstance().asInstanceOf[AnyRef]
// possibly "adapt" the derivation function, in case it doesn't implement Feathr's FeatureDerivationFunction,
// using FeathrUdfPluginContext
val maybeAdaptedDerivationFunction = FeathrUdfPluginContext.getRegisteredUdfAdaptor(derivationFunctionClass) match {
case Some(adaptor: FeatureDerivationFunctionAdaptor) => adaptor.adaptUdf(derivationFunction)
case _ => derivationFunction
}
val consumedFeatures = config.inputs.map(x => ErasedEntityTaggedFeature(x.key.map(config.key.zipWithIndex.toMap), x.feature)).toIndexedSeq

// consumedFeatures and parameterNames have same order, since they are all from config.inputs
DerivedFeature(consumedFeatures, producedFeatures, derivationFunction, config.parameterNames, featureTypeConfigMap)
} else if (x.has("join")) { // when the derived feature config is a seqJoin config
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ private[offline] case class DerivedFeature(
* get the row-based FeatureDerivationFunction, note that some of the derivations that derive from FeatureDerivationFunctionBase
* are not subclass of FeatureDerivationFunction, e.g, [[FeatureDerivationFunctionSpark]], in such cases, this function will
* throw exception, make sure you will not call this function for such cases.
*
* TODO: The above described condition is bad; ideally this class should capture the information about what type of
* derivation function this is in a type-safe way.
*/
def getAsFeatureDerivationFunction(): FeatureDerivationFunction = derivation.asInstanceOf[FeatureDerivationFunction]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import com.linkedin.feathr.common.{FeatureDerivationFunction, FeatureTypeConfig}
import com.linkedin.feathr.common.exception.{ErrorLabel, FeathrException}
import com.linkedin.feathr.offline.{ErasedEntityTaggedFeature, FeatureDataFrame}
import com.linkedin.feathr.offline.client.DataFrameColName
import com.linkedin.feathr.offline.client.plugins.{FeathrUdfPluginContext, FeatureDerivationFunctionAdaptor}
import com.linkedin.feathr.offline.derived.functions.SeqJoinDerivationFunction
import com.linkedin.feathr.offline.derived.strategies.{DerivationStrategies, RowBasedDerivation, SequentialJoinAsDerivation, SparkUdfDerivation}
import com.linkedin.feathr.offline.join.algorithms.{SequentialJoinConditionBuilder, SparkJoinWithJoinCondition}
Expand Down Expand Up @@ -49,8 +50,16 @@ private[offline] class DerivedFeatureEvaluator(derivationStrategies: DerivationS
// doesn't have fixed DataFrame schema so that we can't return TensorData but has to return FDS.
val resultDF = derivationStrategies.rowBasedDerivationStrategy(keyTag, keyTagList, contextDF, derivedFeature, x)
offline.FeatureDataFrame(resultDF, getTypeConfigs(producedFeatureColName, derivedFeature, resultDF))
case _ =>
throw new FeathrException(ErrorLabel.FEATHR_ERROR, s"Unsupported feature derivation function for feature ${derivedFeature.producedFeatureNames.head}.")
case derivation =>
FeathrUdfPluginContext.getRegisteredUdfAdaptor(derivation.getClass) match {
case Some(adaptor: FeatureDerivationFunctionAdaptor) =>
// replicating the FeatureDerivationFunction case above
val featureDerivationFunction = adaptor.adaptUdf(derivation)
val resultDF = derivationStrategies.rowBasedDerivationStrategy(keyTag, keyTagList, contextDF, derivedFeature, featureDerivationFunction)
offline.FeatureDataFrame(resultDF, getTypeConfigs(producedFeatureColName, derivedFeature, resultDF))
case _ =>
throw new FeathrException(ErrorLabel.FEATHR_ERROR, s"Unsupported feature derivation function for feature ${derivedFeature.producedFeatureNames.head}.")
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import com.linkedin.feathr.offline.anchored.anchorExtractor.{SQLConfigurableAnch
import com.linkedin.feathr.offline.anchored.feature.{FeatureAnchor, FeatureAnchorWithSource}
import com.linkedin.feathr.offline.anchored.keyExtractor.MVELSourceKeyExtractor
import com.linkedin.feathr.offline.client.DataFrameColName
import com.linkedin.feathr.offline.client.plugins.{SimpleAnchorExtractorSparkAdaptor, FeathrUdfPluginContext, AnchorExtractorAdaptor}
import com.linkedin.feathr.offline.config.{MVELFeatureDefinition, TimeWindowFeatureDefinition}
import com.linkedin.feathr.offline.generation.IncrementalAggContext
import com.linkedin.feathr.offline.job.FeatureJoinJob.FeatureName
Expand All @@ -23,7 +24,6 @@ import com.linkedin.feathr.swj.aggregate.AggregationType
import com.linkedin.feathr.{common, offline}
import org.apache.log4j.Logger
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
Expand Down Expand Up @@ -179,8 +179,8 @@ private[offline] object FeatureTransformation {
DataFrameBasedSqlEvaluator.transform(transformer, df, featureNamePrefixPairs, featureTypeConfigs)
case transformer: AnchorExtractor[_] =>
DataFrameBasedRowEvaluator.transform(transformer, df, featureNamePrefixPairs, featureTypeConfigs)
case _ => throw new FeathrFeatureTransformationException(ErrorLabel.FEATHR_USER_ERROR, s"cannot find valid Transformer for ${featureAnchorWithSource}")

case _ =>
throw new FeathrFeatureTransformationException(ErrorLabel.FEATHR_USER_ERROR, s"cannot find valid Transformer for ${featureAnchorWithSource}")
}
// Check for whether there are duplicate columns in the transformed DataFrame, typically this is because
// the feature name is the same as some field name
Expand Down
32 changes: 32 additions & 0 deletions src/main/scala/com/linkedin/feathr/offline/mvel/MvelContext.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,14 @@

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.linkedin.feathr.common.FeatureValue;
import com.linkedin.feathr.common.util.MvelContextUDFs;
import org.apache.avro.generic.GenericEnumSymbol;
import org.apache.avro.generic.GenericRecord;
import org.apache.avro.util.Utf8;
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema;
import org.mvel2.DataConversion;
import org.mvel2.MVEL;
import org.mvel2.ParserConfiguration;
import org.mvel2.ParserContext;
import org.mvel2.integration.PropertyHandler;
Expand Down Expand Up @@ -106,6 +109,35 @@ public static ParserContext newParserContext() {
return new ParserContext(PARSER_CONFIG);
}

/**
 * Evaluates an MVEL expression exactly as {@link MVEL#executeExpression(Object, Object)} would,
 * then post-processes the result: if a registered MVEL data-conversion plugin can turn the output
 * into a {@link FeatureValue}, the converted value is returned instead of the raw output.
 *
 * @param compiledExpression a compiled MVEL expression
 * @param ctx the context object the expression is evaluated against
 * @return the expression result, possibly coerced to {@link FeatureValue} via plugins
 */
public static Object executeExpressionWithPluginSupport(Object compiledExpression, Object ctx) {
  return coerceToFeatureValueViaMvelDataConversionPlugins(MVEL.executeExpression(compiledExpression, ctx));
}

/**
 * Evaluates an MVEL expression exactly as
 * {@link MVEL#executeExpression(Object, Object, VariableResolverFactory)} would, then post-processes
 * the result: if a registered MVEL data-conversion plugin can turn the output into a
 * {@link FeatureValue}, the converted value is returned instead of the raw output.
 *
 * @param compiledExpression a compiled MVEL expression
 * @param ctx the context object the expression is evaluated against
 * @param variableResolverFactory resolver used to look up variables during evaluation
 * @return the expression result, possibly coerced to {@link FeatureValue} via plugins
 */
public static Object executeExpressionWithPluginSupport(Object compiledExpression, Object ctx,
    VariableResolverFactory variableResolverFactory) {
  Object rawResult = MVEL.executeExpression(compiledExpression, ctx, variableResolverFactory);
  return coerceToFeatureValueViaMvelDataConversionPlugins(rawResult);
}

/**
 * Coerces an MVEL evaluation result to {@link FeatureValue} when a registered
 * {@link DataConversion} handler supports the input's runtime class; otherwise the
 * input is returned untouched. Null passes through unchanged.
 *
 * @param input raw MVEL output, may be null
 * @return the converted {@link FeatureValue}, or {@code input} if no conversion applies
 */
private static Object coerceToFeatureValueViaMvelDataConversionPlugins(Object input) {
  // Guard clauses: nothing to do for null, or when no converter is registered.
  if (input == null) {
    return input;
  }
  if (!DataConversion.canConvert(FeatureValue.class, input.getClass())) {
    return input;
  }
  return DataConversion.convert(input, FeatureValue.class);
}

/**
* Allows easy access to the properties of GenericRecord object from MVEL.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ private[offline] object MvelUtils {
// (We might not want to check for null explicitly everywhere)
def executeExpression(compiledExpression: Any, input: Any, resolverFactory: VariableResolverFactory, featureName: String = ""): Option[AnyRef] = {
try {
Option(MVEL.executeExpression(compiledExpression, input, resolverFactory))
Option(MvelContext.executeExpressionWithPluginSupport(compiledExpression, input, resolverFactory))
} catch {
case e: RuntimeException =>
log.debug(s"Expression $compiledExpression on input record $input threw exception", e)
Expand Down
Loading