Skip to content

Commit

Permalink
UDF plugin API (feathr-ai#507)
Browse files Browse the repository at this point in the history
* untested prototype of UDF plugin API

* demo of how to enable coercion between different kinds of FeatureValue classes in MVEL expressions

* add plugin APIs (untested)

* added docs, cleaned up APIs, still working on testing

* update how MVEL is invoked, to conditionally apply plugin to convert to FeatureValue. also added tests.

* small doc fix

* attempted bugfix regarding not holding on to the adapted anchorExtractor object

Co-authored-by: David Stein <dstein@dstein-mn1.linkedin.biz>
  • Loading branch information
2 people authored and ahlag committed Aug 26, 2022
1 parent 34eed2d commit f3eeffd
Show file tree
Hide file tree
Showing 20 changed files with 666 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ private[offline] object PostTransformationUtil {
featureType: FeatureTypes): Try[FeatureValue] = Try {
val args = Map(featureName -> Some(featureValue))
val variableResolverFactory = new FeatureVariableResolverFactory(args)
val transformedValue = MVEL.executeExpression(compiledExpression, featureValue, variableResolverFactory)
val transformedValue = MvelContext.executeExpressionWithPluginSupport(compiledExpression, featureValue, variableResolverFactory)
CoercionUtilsScala.coerceToFeatureValue(transformedValue, featureType)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ private[offline] class SimpleConfigurableAnchorExtractor( @JsonProperty("key") k
// be more strict for resolving keys (don't swallow exceptions)
keyExpression.map(k =>
try {
Option(MVEL.executeExpression(k, datum)) match {
Option(MvelContext.executeExpressionWithPluginSupport(k, datum)) match {
case None => null
case Some(keys) => keys.toString
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package com.linkedin.feathr.offline.client.plugins

import scala.collection.mutable

/**
 * A shared registry for loading [[UdfAdaptor]]s, which basically can tell Feathr's runtime how to support different
 * kinds of "external" UDFs not natively known to Feathr, but which have similar behavior to Feathr's.
 *
 * All "external" UDF classes are required to have a public default zero-arg constructor.
 *
 * Both registration and lookup go through this object's monitor, so adaptors may be registered from one thread while
 * another thread performs lookups. Code that touches [[registeredUdfAdaptors]] directly bypasses that protection.
 */
object FeathrUdfPluginContext {
  /** Registered adaptors, in registration order. Prefer the methods below over accessing this buffer directly. */
  val registeredUdfAdaptors: mutable.Buffer[UdfAdaptor[_]] = mutable.Buffer[UdfAdaptor[_]]()

  /**
   * Appends an adaptor to the registry.
   *
   * @param adaptor the adaptor to make available to later lookups
   */
  def registerUdfAdaptor(adaptor: UdfAdaptor[_]): Unit = this.synchronized {
    registeredUdfAdaptors += adaptor
  }

  /**
   * Finds the first registered adaptor (in registration order) whose `canAdapt` accepts the given class.
   *
   * @param clazz class of the external UDF
   * @return the matching adaptor, or None when no registered adaptor accepts `clazz`
   */
  def getRegisteredUdfAdaptor(clazz: Class[_]): Option[UdfAdaptor[_]] = {
    // Copy under the same lock used by registerUdfAdaptor so the buffer is never traversed while it is being
    // appended to; the user-supplied canAdapt callbacks then run outside the lock.
    val snapshot = this.synchronized(registeredUdfAdaptors.toList)
    snapshot.find(_.canAdapt(clazz))
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package com.linkedin.feathr.offline.client.plugins

import com.linkedin.feathr.common.{AnchorExtractor, FeatureDerivationFunction}
import com.linkedin.feathr.sparkcommon.{SimpleAnchorExtractorSpark, SourceKeyExtractor}

/**
* Tells Feathr how to use UDFs that are defined in "external" non-Feathr UDF classes (i.e. that don't extend from
* Feathr's AnchorExtractor or other UDF traits). An adaptor must match the external UDF class to a specific kind of
* Feathr UDF – see child traits below for the various options.
*
* This trait is sealed, so it can only be extended directly within this source file; adaptor implementations defined
* elsewhere should extend one of the child traits rather than this trait itself.
*
* All "external" UDF classes are required to have a public default zero-arg constructor.
*
* @tparam T the internal Feathr UDF class whose behavior the external UDF can be translated to
*/
sealed trait UdfAdaptor[T] extends Serializable {
/**
* Indicates whether this adaptor can be applied to an object of the provided class.
*
* Implementations should usually look like <pre>classOf[UdfTraitThatIsNotPartOfFeathr].isAssignableFrom(clazz)</pre>
*
* @param clazz some external UDF type
* @return true if this adaptor can "adapt" the given class type; false otherwise
*/
def canAdapt(clazz: Class[_]): Boolean

/**
* Returns an instance of a Feathr UDF, that follows the behavior of some external UDF instance, e.g. via delegation.
*
* The parameter is typed only as AnyRef, so implementations are responsible for narrowing it to the external UDF
* type they support.
*
* @param externalUdf instance of the "external" UDF
* @return the Feathr UDF
*/
def adaptUdf(externalUdf: AnyRef): T
}

/**
* An adaptor that can "tell Feathr how to use" a UDF type that can act in place of [[AnchorExtractor]].
* Its adaptUdf result is typed as AnchorExtractor[_].
*/
trait AnchorExtractorAdaptor extends UdfAdaptor[AnchorExtractor[_]]

/**
* An adaptor that can "tell Feathr how to use" a UDF type that can act in place of [[SimpleAnchorExtractorSpark]].
* Its adaptUdf result is typed as SimpleAnchorExtractorSpark.
*/
trait SimpleAnchorExtractorSparkAdaptor extends UdfAdaptor[SimpleAnchorExtractorSpark]

/**
* An adaptor that can "tell Feathr how to use" a UDF type that can act in place of [[FeatureDerivationFunction]].
* Its adaptUdf result is typed as FeatureDerivationFunction.
*/
trait FeatureDerivationFunctionAdaptor extends UdfAdaptor[FeatureDerivationFunction]

/**
* An adaptor that can "tell Feathr how to use" a UDF type that can act in place of [[SourceKeyExtractor]].
* Its adaptUdf result is typed as SourceKeyExtractor.
*/
trait SourceKeyExtractorAdaptor extends UdfAdaptor[SourceKeyExtractor]
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import com.linkedin.feathr.offline.ErasedEntityTaggedFeature
import com.linkedin.feathr.offline.anchored.anchorExtractor.{SQLConfigurableAnchorExtractor, SimpleConfigurableAnchorExtractor, TimeWindowConfigurableAnchorExtractor}
import com.linkedin.feathr.offline.anchored.feature.{FeatureAnchor, FeatureAnchorWithSource}
import com.linkedin.feathr.offline.anchored.keyExtractor.{MVELSourceKeyExtractor, SQLSourceKeyExtractor}
import com.linkedin.feathr.offline.client.plugins.{AnchorExtractorAdaptor, FeathrUdfPluginContext, FeatureDerivationFunctionAdaptor, SimpleAnchorExtractorSparkAdaptor, SourceKeyExtractorAdaptor}
import com.linkedin.feathr.offline.config.location.{InputLocation, Jdbc, KafkaEndpoint, LocationUtils, SimplePath}
import com.linkedin.feathr.offline.derived._
import com.linkedin.feathr.offline.derived.functions.{MvelFeatureDerivationFunction, SQLFeatureDerivationFunction, SeqJoinDerivationFunction, SimpleMvelDerivationFunction}
Expand Down Expand Up @@ -252,7 +253,16 @@ private[offline] class AnchorLoader extends JsonDeserializer[FeatureAnchor] {
// SimpleConfigurableAnchorExtractor will extract the feature type if defined

// if it is UDF, no extra information other than the class name is required
val anchorExtractor: AnyRef = codec.treeToValue(node, extractorClass).asInstanceOf[AnyRef]
// Deserialize the configured extractor and, when a plugin adaptor of a suitable kind is registered for its class,
// wrap it so that the code below sees a Feathr-native extractor.
val anchorExtractor: AnyRef = {
  val extractor = codec.treeToValue(node, extractorClass).asInstanceOf[AnyRef]
  FeathrUdfPluginContext.getRegisteredUdfAdaptor(extractorClass) match {
    case Some(adaptor: SimpleAnchorExtractorSparkAdaptor) =>
      adaptor.adaptUdf(extractor)
    case Some(adaptor: AnchorExtractorAdaptor) =>
      adaptor.adaptUdf(extractor)
    // No adaptor, or an adaptor of a kind that does not produce an anchor extractor (e.g. a derivation-function or
    // key-extractor adaptor whose canAdapt also accepts this class): keep the extractor unchanged instead of
    // failing with a MatchError, matching the fallback used at the other adaptor lookup sites.
    case _ => extractor
  }
}

anchorExtractor match {
case extractor: AnchorExtractorBase[_] =>
Expand Down Expand Up @@ -362,8 +372,13 @@ private[offline] class AnchorLoader extends JsonDeserializer[FeatureAnchor] {
anchorExtractorBase: AnyRef,
lateralViewParameters: Option[LateralViewParams]): SourceKeyExtractor = {
Option(node.get("keyExtractor")).map(_.textValue) match {
case Some(keyExtractor) =>
Class.forName(keyExtractor).newInstance().asInstanceOf[SourceKeyExtractor]
case Some(keyExtractorClassName) =>
val keyExtractorClass = Class.forName(keyExtractorClassName)
val newInstance = keyExtractorClass.getDeclaredConstructor().newInstance().asInstanceOf[AnyRef]
FeathrUdfPluginContext.getRegisteredUdfAdaptor(keyExtractorClass) match {
case Some(adaptor: SourceKeyExtractorAdaptor) => adaptor.adaptUdf(newInstance)
case _ => newInstance.asInstanceOf[SourceKeyExtractor]
}
case _ =>
Option(node.get("key")) match {
case Some(keyNode) =>
Expand Down Expand Up @@ -545,8 +560,16 @@ private[offline] class DerivationLoader extends JsonDeserializer[DerivedFeature]
loadAdvancedDerivedFeature(x, codec)
} else if (x.has("class")) {
val config = codec.treeToValue(x, classOf[CustomDerivedFeatureConfig])
val derivationFunction = config.`class`.newInstance().asInstanceOf[AnyRef]
val derivationFunctionClass = config.`class`
val derivationFunction = derivationFunctionClass.getDeclaredConstructor().newInstance().asInstanceOf[AnyRef]
// possibly "adapt" the derivation function, in case it doesn't implement Feathr's FeatureDerivationFunction,
// using FeathrUdfPluginContext
val maybeAdaptedDerivationFunction = FeathrUdfPluginContext.getRegisteredUdfAdaptor(derivationFunctionClass) match {
case Some(adaptor: FeatureDerivationFunctionAdaptor) => adaptor.adaptUdf(derivationFunction)
case _ => derivationFunction
}
val consumedFeatures = config.inputs.map(x => ErasedEntityTaggedFeature(x.key.map(config.key.zipWithIndex.toMap), x.feature)).toIndexedSeq

// consumedFeatures and parameterNames have same order, since they are all from config.inputs
DerivedFeature(consumedFeatures, producedFeatures, derivationFunction, config.parameterNames, featureTypeConfigMap)
} else if (x.has("join")) { // when the derived feature config is a seqJoin config
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ private[offline] case class DerivedFeature(
* get the row-based FeatureDerivationFunction, note that some of the derivations that derive from FeatureDerivationFunctionBase
* are not subclass of FeatureDerivationFunction, e.g, [[FeatureDerivationFunctionSpark]], in such cases, this function will
* throw exception, make sure you will not call this function for such cases.
*
* TODO: The above described condition is bad; ideally this class should capture the information about what type of
* derivation function this is in a type-safe way.
*/
def getAsFeatureDerivationFunction(): FeatureDerivationFunction = derivation.asInstanceOf[FeatureDerivationFunction]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import com.linkedin.feathr.common.{FeatureDerivationFunction, FeatureTypeConfig}
import com.linkedin.feathr.common.exception.{ErrorLabel, FeathrException}
import com.linkedin.feathr.offline.{ErasedEntityTaggedFeature, FeatureDataFrame}
import com.linkedin.feathr.offline.client.DataFrameColName
import com.linkedin.feathr.offline.client.plugins.{FeathrUdfPluginContext, FeatureDerivationFunctionAdaptor}
import com.linkedin.feathr.offline.derived.functions.SeqJoinDerivationFunction
import com.linkedin.feathr.offline.derived.strategies.{DerivationStrategies, RowBasedDerivation, SequentialJoinAsDerivation, SparkUdfDerivation}
import com.linkedin.feathr.offline.join.algorithms.{SequentialJoinConditionBuilder, SparkJoinWithJoinCondition}
Expand Down Expand Up @@ -49,8 +50,16 @@ private[offline] class DerivedFeatureEvaluator(derivationStrategies: DerivationS
// doesn't have fixed DataFrame schema so that we can't return TensorData but has to return FDS.
val resultDF = derivationStrategies.rowBasedDerivationStrategy(keyTag, keyTagList, contextDF, derivedFeature, x)
offline.FeatureDataFrame(resultDF, getTypeConfigs(producedFeatureColName, derivedFeature, resultDF))
case _ =>
throw new FeathrException(ErrorLabel.FEATHR_ERROR, s"Unsupported feature derivation function for feature ${derivedFeature.producedFeatureNames.head}.")
case derivation =>
FeathrUdfPluginContext.getRegisteredUdfAdaptor(derivation.getClass) match {
case Some(adaptor: FeatureDerivationFunctionAdaptor) =>
// replicating the FeatureDerivationFunction case above
val featureDerivationFunction = adaptor.adaptUdf(derivation)
val resultDF = derivationStrategies.rowBasedDerivationStrategy(keyTag, keyTagList, contextDF, derivedFeature, featureDerivationFunction)
offline.FeatureDataFrame(resultDF, getTypeConfigs(producedFeatureColName, derivedFeature, resultDF))
case _ =>
throw new FeathrException(ErrorLabel.FEATHR_ERROR, s"Unsupported feature derivation function for feature ${derivedFeature.producedFeatureNames.head}.")
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import com.linkedin.feathr.offline.anchored.anchorExtractor.{SQLConfigurableAnch
import com.linkedin.feathr.offline.anchored.feature.{FeatureAnchor, FeatureAnchorWithSource}
import com.linkedin.feathr.offline.anchored.keyExtractor.MVELSourceKeyExtractor
import com.linkedin.feathr.offline.client.DataFrameColName
import com.linkedin.feathr.offline.client.plugins.{SimpleAnchorExtractorSparkAdaptor, FeathrUdfPluginContext, AnchorExtractorAdaptor}
import com.linkedin.feathr.offline.config.{MVELFeatureDefinition, TimeWindowFeatureDefinition}
import com.linkedin.feathr.offline.generation.IncrementalAggContext
import com.linkedin.feathr.offline.job.FeatureJoinJob.FeatureName
Expand All @@ -23,7 +24,6 @@ import com.linkedin.feathr.swj.aggregate.AggregationType
import com.linkedin.feathr.{common, offline}
import org.apache.log4j.Logger
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
Expand Down Expand Up @@ -179,8 +179,8 @@ private[offline] object FeatureTransformation {
DataFrameBasedSqlEvaluator.transform(transformer, df, featureNamePrefixPairs, featureTypeConfigs)
case transformer: AnchorExtractor[_] =>
DataFrameBasedRowEvaluator.transform(transformer, df, featureNamePrefixPairs, featureTypeConfigs)
case _ => throw new FeathrFeatureTransformationException(ErrorLabel.FEATHR_USER_ERROR, s"cannot find valid Transformer for ${featureAnchorWithSource}")

case _ =>
throw new FeathrFeatureTransformationException(ErrorLabel.FEATHR_USER_ERROR, s"cannot find valid Transformer for ${featureAnchorWithSource}")
}
// Check for whether there are duplicate columns in the transformed DataFrame, typically this is because
// the feature name is the same as some field name
Expand Down
32 changes: 32 additions & 0 deletions src/main/scala/com/linkedin/feathr/offline/mvel/MvelContext.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,14 @@

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.linkedin.feathr.common.FeatureValue;
import com.linkedin.feathr.common.util.MvelContextUDFs;
import org.apache.avro.generic.GenericEnumSymbol;
import org.apache.avro.generic.GenericRecord;
import org.apache.avro.util.Utf8;
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema;
import org.mvel2.DataConversion;
import org.mvel2.MVEL;
import org.mvel2.ParserConfiguration;
import org.mvel2.ParserContext;
import org.mvel2.integration.PropertyHandler;
Expand Down Expand Up @@ -106,6 +109,35 @@ public static ParserContext newParserContext() {
return new ParserContext(PARSER_CONFIG);
}

/**
 * Runs a compiled MVEL expression exactly like {@link MVEL#executeExpression(Object, Object)} and then gives
 * {@link com.linkedin.feathr.offline.mvel.plugins.FeathrMvelPluginContext} plugins a chance to post-process the
 * result: a non-null result whose class MVEL reports as convertible to {@link FeatureValue} is converted; any
 * other result is returned untouched.
 *
 * @param compiledExpression the compiled expression, passed to MVEL as-is
 * @param ctx the context object the expression is evaluated against
 * @return the (possibly converted) evaluation result
 */
public static Object executeExpressionWithPluginSupport(Object compiledExpression, Object ctx) {
  return coerceToFeatureValueViaMvelDataConversionPlugins(MVEL.executeExpression(compiledExpression, ctx));
}

/**
 * Runs a compiled MVEL expression exactly like
 * {@link MVEL#executeExpression(Object, Object, VariableResolverFactory)} and then gives
 * {@link com.linkedin.feathr.offline.mvel.plugins.FeathrMvelPluginContext} plugins a chance to post-process the
 * result: a non-null result whose class MVEL reports as convertible to {@link FeatureValue} is converted; any
 * other result is returned untouched.
 *
 * @param compiledExpression the compiled expression, passed to MVEL as-is
 * @param ctx the context object the expression is evaluated against
 * @param variableResolverFactory resolver factory forwarded to MVEL unchanged
 * @return the (possibly converted) evaluation result
 */
public static Object executeExpressionWithPluginSupport(Object compiledExpression, Object ctx,
    VariableResolverFactory variableResolverFactory) {
  return coerceToFeatureValueViaMvelDataConversionPlugins(
      MVEL.executeExpression(compiledExpression, ctx, variableResolverFactory));
}

/**
 * Converts {@code input} to {@link FeatureValue} when MVEL's {@link DataConversion} reports that its class can be
 * converted; null and non-convertible values are handed back unchanged.
 */
private static Object coerceToFeatureValueViaMvelDataConversionPlugins(Object input) {
  if (input == null) {
    // Nothing to inspect; callers receive the null back as-is.
    return input;
  }
  boolean convertible = DataConversion.canConvert(FeatureValue.class, input.getClass());
  return convertible ? DataConversion.convert(input, FeatureValue.class) : input;
}

/**
* Allows easy access to the properties of GenericRecord object from MVEL.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ private[offline] object MvelUtils {
// (We might not want to check for null explicitly everywhere)
def executeExpression(compiledExpression: Any, input: Any, resolverFactory: VariableResolverFactory, featureName: String = ""): Option[AnyRef] = {
try {
Option(MVEL.executeExpression(compiledExpression, input, resolverFactory))
Option(MvelContext.executeExpressionWithPluginSupport(compiledExpression, input, resolverFactory))
} catch {
case e: RuntimeException =>
log.debug(s"Expression $compiledExpression on input record $input threw exception", e)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
package com.linkedin.feathr.offline.mvel.plugins;

import com.linkedin.feathr.common.FeatureValue;
import com.linkedin.feathr.common.InternalApi;
import org.mvel2.ConversionHandler;
import org.mvel2.DataConversion;

import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;


/**
 * Extension point that lets an advanced user add capabilities or behaviors to Feathr's MVEL runtime.
 *
 * NOTE: This class is intended for advanced users only, and specifically as a "migration aid" for migrating from
 * some previous versions of Feathr whose FeatureValue representations had a different class name, while preserving
 * compatibility with feature definitions written against those older versions of Feathr.
 */
public class FeathrMvelPluginContext {
  // TODO: Does this need to be "translated" into a different pattern whereby we track the CLASSNAME of the type
  //       adaptors instead of the instance, such that the class mappings can be broadcasted via Spark and then
  //       reinitialized on executor hosts?
  private static final ConcurrentMap<Class<?>, FeatureValueTypeAdaptor<?>> TYPE_ADAPTORS = new ConcurrentHashMap<>();

  static {
    // Register the handler MVEL consults when asked to convert a value to FeatureValue; it answers from
    // TYPE_ADAPTORS, which is initialized just above.
    DataConversion.addConversionHandler(FeatureValue.class, new FeathrFeatureValueConversionHandler());
  }

  /**
   * Registers a type adaptor with Feathr's MVEL runtime so that expressions can work with an alternative class
   * representation of {@link FeatureValue} via coercion. The adaptor is recorded for conversions into
   * {@link FeatureValue}, and a conversion handler is installed for {@code clazz} for the opposite direction.
   *
   * @param clazz the class of the "other" alternative representation of feature value
   * @param typeAdaptor the type adaptor that can convert between the "other" representation and {@link FeatureValue}
   * @param <T> type parameter for the "other" feature value class
   */
  @SuppressWarnings("unchecked")
  public static <T> void addFeatureTypeAdaptor(Class<T> clazz, FeatureValueTypeAdaptor<T> typeAdaptor) {
    // TODO: MAKE SURE clazz IS NOT ONE OF THE CLASSES ALREADY COVERED IN org.mvel2.DataConversion.CONVERTERS!
    //       IF WE OVERRIDE ANY OF THOSE, IT MIGHT CAUSE MVEL TO BEHAVE IN STRANGE AND UNEXPECTED WAYS!
    TYPE_ADAPTORS.put(clazz, typeAdaptor);
    ConversionHandler fromFeathrHandler = new ExternalFeatureValueConversionHandler(typeAdaptor);
    DataConversion.addConversionHandler(clazz, fromFeathrHandler);
  }

  /**
   * Conversion handler registered for {@link FeatureValue}: turns an instance of a registered "other" class into a
   * Feathr {@link FeatureValue} using the adaptor stored under the instance's exact runtime class.
   */
  static class FeathrFeatureValueConversionHandler implements ConversionHandler {
    @Override
    @SuppressWarnings("unchecked")
    public Object convertFrom(Object in) {
      FeatureValueTypeAdaptor<?> registered = TYPE_ADAPTORS.get(in.getClass());
      if (registered != null) {
        return ((FeatureValueTypeAdaptor<Object>) registered).toFeathrFeatureValue(in);
      }
      throw new IllegalArgumentException("Can't convert to Feathr FeatureValue from " + in);
    }

    @Override
    public boolean canConvertFrom(Class cls) {
      // Lookup is by exact class; subclasses of a registered class are not matched.
      return TYPE_ADAPTORS.containsKey(cls);
    }
  }

  /**
   * Conversion handler registered for an "other" feature value class: turns a Feathr {@link FeatureValue} into
   * that representation by delegating to the wrapped adaptor.
   */
  static class ExternalFeatureValueConversionHandler implements ConversionHandler {
    private final FeatureValueTypeAdaptor<?> adaptor;

    public ExternalFeatureValueConversionHandler(FeatureValueTypeAdaptor<?> adaptor) {
      this.adaptor = adaptor;
    }

    @Override
    public Object convertFrom(Object in) {
      FeatureValue feathrValue = (FeatureValue) in;
      return adaptor.fromFeathrFeatureValue(feathrValue);
    }

    @Override
    public boolean canConvertFrom(Class cls) {
      // Only the exact FeatureValue class is accepted as a conversion source.
      return cls == FeatureValue.class;
    }
  }
}
Loading

0 comments on commit f3eeffd

Please sign in to comment.