diff --git a/build.sbt b/build.sbt index 52c9da927..e81a1bc26 100644 --- a/build.sbt +++ b/build.sbt @@ -1,3 +1,5 @@ +import sbt.Keys.publishLocalConfiguration + ThisBuild / resolvers += Resolver.mavenLocal ThisBuild / scalaVersion := "2.12.15" ThisBuild / version := "0.7.2" @@ -5,6 +7,8 @@ ThisBuild / organization := "com.linkedin.feathr" ThisBuild / organizationName := "linkedin" val sparkVersion = "3.1.3" +publishLocalConfiguration := publishLocalConfiguration.value.withOverwrite(true) + val localAndCloudDiffDependencies = Seq( "org.apache.spark" %% "spark-avro" % sparkVersion, "org.apache.spark" %% "spark-sql" % sparkVersion, @@ -101,4 +105,4 @@ assembly / assemblyMergeStrategy := { // Some systems(like Hadoop) use different versinos of protobuf(like v2) so we have to shade it. assemblyShadeRules in assembly := Seq( ShadeRule.rename("com.google.protobuf.**" -> "shade.protobuf.@1").inAll, -) +) \ No newline at end of file diff --git a/src/main/scala/com/linkedin/feathr/offline/PostTransformationUtil.scala b/src/main/scala/com/linkedin/feathr/offline/PostTransformationUtil.scala index eb2f4f0ae..b1f75d662 100644 --- a/src/main/scala/com/linkedin/feathr/offline/PostTransformationUtil.scala +++ b/src/main/scala/com/linkedin/feathr/offline/PostTransformationUtil.scala @@ -1,10 +1,10 @@ package com.linkedin.feathr.offline import java.io.Serializable - import com.linkedin.feathr.common import com.linkedin.feathr.common.{FeatureTypes, FeatureValue} import com.linkedin.feathr.offline.exception.FeatureTransformationException +import com.linkedin.feathr.offline.mvel.plugins.FeathrExpressionExecutionContext import com.linkedin.feathr.offline.mvel.{FeatureVariableResolverFactory, MvelContext} import com.linkedin.feathr.offline.transformation.MvelDefinition import com.linkedin.feathr.offline.util.{CoercionUtilsScala, FeaturizedDatasetUtils} @@ -32,9 +32,9 @@ private[offline] object PostTransformationUtil { * @param input input feature value * @return transformed feature value */ - def booleanTransformer(featureName: String, mvelExpression: MvelDefinition, compiledExpression: Serializable, input: Boolean): Boolean = { + def booleanTransformer(featureName: String, mvelExpression: MvelDefinition, compiledExpression: Serializable, input: Boolean, mvelContext: Option[FeathrExpressionExecutionContext]): Boolean = { val toFeatureValue = common.FeatureValue.createBoolean(input) - val transformedFeatureValue = transformFeatureValues(featureName, toFeatureValue, compiledExpression, FeatureTypes.TERM_VECTOR) + val transformedFeatureValue = transformFeatureValues(featureName, toFeatureValue, compiledExpression, FeatureTypes.TERM_VECTOR, mvelContext) transformedFeatureValue match { case Success(fVal) => fVal.getAsTermVector.containsKey("true") case Failure(ex) => @@ -57,12 +57,12 @@ private[offline] object PostTransformationUtil { featureName: String, mvelExpression: MvelDefinition, compiledExpression: Serializable, - input: GenericRowWithSchema): Map[String, Float] = { + input: GenericRowWithSchema, mvelContext: Option[FeathrExpressionExecutionContext]): Map[String, Float] = { if (input != null) { val inputMapKey = input.getAs[Seq[String]](FeaturizedDatasetUtils.FDS_1D_TENSOR_DIM) val inputMapVal = input.getAs[Seq[Float]](FeaturizedDatasetUtils.FDS_1D_TENSOR_VALUE) val inputMap = inputMapKey.zip(inputMapVal).toMap - mapTransformer(featureName, mvelExpression, compiledExpression, inputMap) + mapTransformer(featureName, mvelExpression, compiledExpression, inputMap, mvelContext) } else Map() } @@ -79,7 +79,8 @@ private[offline] object PostTransformationUtil { featureNameColumnTuples: Seq[(String, String)], contextDF: DataFrame, transformationDef: Map[String, MvelDefinition], - defaultTransformation: (DataType, String) => Column): DataFrame = { + defaultTransformation: (DataType, String) => Column, + mvelContext: Option[FeathrExpressionExecutionContext]): DataFrame = { val featureColumnNames = featureNameColumnTuples.map(_._2) // Transform the features with the provided transformations @@ -93,11 +94,11 @@ private[offline] object PostTransformationUtil { val parserContext = MvelContext.newParserContext() val compiledExpression = MVEL.compileExpression(mvelExpressionDef.mvelDef, parserContext) val featureType = mvelExpressionDef.featureType - val convertToString = udf(stringTransformer(featureName, mvelExpressionDef, compiledExpression, _: String)) - val convertToBoolean = udf(booleanTransformer(featureName, mvelExpressionDef, compiledExpression, _: Boolean)) - val convertToFloat = udf(floatTransformer(featureName, mvelExpressionDef, compiledExpression, _: Float)) - val convertToMap = udf(mapTransformer(featureName, mvelExpressionDef, compiledExpression, _: Map[String, Float])) - val convertFDS1dTensorToMap = udf(fds1dTensorTransformer(featureName, mvelExpressionDef, compiledExpression, _: GenericRowWithSchema)) + val convertToString = udf(stringTransformer(featureName, mvelExpressionDef, compiledExpression, _: String, mvelContext)) + val convertToBoolean = udf(booleanTransformer(featureName, mvelExpressionDef, compiledExpression, _: Boolean, mvelContext)) + val convertToFloat = udf(floatTransformer(featureName, mvelExpressionDef, compiledExpression, _: Float, mvelContext)) + val convertToMap = udf(mapTransformer(featureName, mvelExpressionDef, compiledExpression, _: Map[String, Float], mvelContext)) + val convertFDS1dTensorToMap = udf(fds1dTensorTransformer(featureName, mvelExpressionDef, compiledExpression, _: GenericRowWithSchema, mvelContext)) fieldType.dataType match { case _: StringType => convertToString(contextDF(columnName)) case _: NumericType => convertToFloat(contextDF(columnName)) @@ -126,16 +127,17 @@ private[offline] object PostTransformationUtil { featureName: String, featureValue: FeatureValue, compiledExpression: Serializable, - featureType: FeatureTypes): Try[FeatureValue] = Try { + featureType: FeatureTypes, + mvelContext: Option[FeathrExpressionExecutionContext]): Try[FeatureValue] = Try { val args = Map(featureName -> Some(featureValue)) val variableResolverFactory = new FeatureVariableResolverFactory(args) - val transformedValue = MvelContext.executeExpressionWithPluginSupport(compiledExpression, featureValue, variableResolverFactory) + val transformedValue = MvelContext.executeExpressionWithPluginSupportWithFactory(compiledExpression, featureValue, variableResolverFactory, mvelContext.orNull) CoercionUtilsScala.coerceToFeatureValue(transformedValue, featureType) } - private def floatTransformer(featureName: String, mvelExpression: MvelDefinition, compiledExpression: Serializable, input: Float): Float = { + private def floatTransformer(featureName: String, mvelExpression: MvelDefinition, compiledExpression: Serializable, input: Float, mvelContext: Option[FeathrExpressionExecutionContext]): Float = { val toFeatureValue = common.FeatureValue.createNumeric(input) - val transformedFeatureValue = transformFeatureValues(featureName, toFeatureValue, compiledExpression, FeatureTypes.NUMERIC) + val transformedFeatureValue = transformFeatureValues(featureName, toFeatureValue, compiledExpression, FeatureTypes.NUMERIC, mvelContext) transformedFeatureValue match { case Success(fVal) => fVal.getAsNumeric case Failure(ex) => @@ -146,9 +148,9 @@ private[offline] object PostTransformationUtil { } } - private def stringTransformer(featureName: String, mvelExpression: MvelDefinition, compiledExpression: Serializable, input: String): String = { + private def stringTransformer(featureName: String, mvelExpression: MvelDefinition, compiledExpression: Serializable, input: String, mvelContext: Option[FeathrExpressionExecutionContext]): String = { val toFeatureValue = common.FeatureValue.createCategorical(input) - val transformedFeatureValue = transformFeatureValues(featureName, toFeatureValue, compiledExpression, FeatureTypes.CATEGORICAL) + val transformedFeatureValue = transformFeatureValues(featureName, toFeatureValue, compiledExpression, FeatureTypes.CATEGORICAL, mvelContext) transformedFeatureValue match { case Success(fVal) => fVal.getAsString case Failure(ex) => @@ -163,12 +165,13 @@ private[offline] object PostTransformationUtil { featureName: String, mvelExpression: MvelDefinition, compiledExpression: Serializable, - input: Map[String, Float]): Map[String, Float] = { + input: Map[String, Float], + mvelContext: Option[FeathrExpressionExecutionContext]): Map[String, Float] = { if (input == null) { return Map() } val toFeatureValue = new common.FeatureValue(input.asJava) - val transformedFeatureValue = transformFeatureValues(featureName, toFeatureValue, compiledExpression, FeatureTypes.TERM_VECTOR) + val transformedFeatureValue = transformFeatureValues(featureName, toFeatureValue, compiledExpression, FeatureTypes.TERM_VECTOR, mvelContext) transformedFeatureValue match { case Success(fVal) => fVal.getAsTermVector.asScala.map(kv => (kv._1.asInstanceOf[String], kv._2.asInstanceOf[Float])).toMap case Failure(ex) => diff --git a/src/main/scala/com/linkedin/feathr/offline/anchored/anchorExtractor/DebugMvelAnchorExtractor.scala b/src/main/scala/com/linkedin/feathr/offline/anchored/anchorExtractor/DebugMvelAnchorExtractor.scala index 264b2ee9a..c4b574c8a 100644 --- a/src/main/scala/com/linkedin/feathr/offline/anchored/anchorExtractor/DebugMvelAnchorExtractor.scala +++ b/src/main/scala/com/linkedin/feathr/offline/anchored/anchorExtractor/DebugMvelAnchorExtractor.scala @@ -1,24 +1,22 @@ package com.linkedin.feathr.offline.anchored.anchorExtractor -import java.io.Serializable - import com.linkedin.feathr.offline.config.MVELFeatureDefinition import com.linkedin.feathr.offline.mvel.{MvelContext, MvelUtils} import org.mvel2.MVEL +import java.io.Serializable import scala.collection.convert.wrapAll._ private[offline] class DebugMvelAnchorExtractor(keyExprs: Seq[String], features: Map[String, MVELFeatureDefinition]) extends SimpleConfigurableAnchorExtractor(keyExprs, features) { private val debugExpressions = features.mapValues(value => findDebugExpressions(value.featureExpr)).map(identity) - private val debugCompiledExpressions = debugExpressions.mapValues(_.map(x => (x, compile(x)))).map(identity) def evaluateDebugExpressions(input: Any): Map[String, Seq[(String, Any)]] = { debugCompiledExpressions .mapValues(_.map { case (expr, compiled) => - (expr, MvelUtils.executeExpression(compiled, input, null).orNull) + (expr, MvelUtils.executeExpression(compiled, input, null, "", None).orNull) }) .map(identity) } diff --git a/src/main/scala/com/linkedin/feathr/offline/anchored/anchorExtractor/SimpleConfigurableAnchorExtractor.scala b/src/main/scala/com/linkedin/feathr/offline/anchored/anchorExtractor/SimpleConfigurableAnchorExtractor.scala index 479167e36..59f5bfbe7 100644 --- a/src/main/scala/com/linkedin/feathr/offline/anchored/anchorExtractor/SimpleConfigurableAnchorExtractor.scala +++ b/src/main/scala/com/linkedin/feathr/offline/anchored/anchorExtractor/SimpleConfigurableAnchorExtractor.scala @@ -6,6 +6,7 @@ import com.linkedin.feathr.common.util.CoercionUtils import com.linkedin.feathr.common.{AnchorExtractor, FeatureTypeConfig, FeatureTypes, FeatureValue, SparkRowExtractor} import com.linkedin.feathr.offline import com.linkedin.feathr.offline.config.MVELFeatureDefinition +import com.linkedin.feathr.offline.mvel.plugins.FeathrExpressionExecutionContext import com.linkedin.feathr.offline.mvel.{MvelContext, MvelUtils} import com.linkedin.feathr.offline.util.FeatureValueTypeValidator import org.apache.log4j.Logger @@ -28,6 +29,7 @@ private[offline] class SimpleConfigurableAnchorExtractor( @JsonProperty("key") k @JsonProperty("features") features: Map[String, MVELFeatureDefinition]) extends AnchorExtractor[Any] with SparkRowExtractor { + var mvelContext: Option[FeathrExpressionExecutionContext] = None @transient private lazy val log = Logger.getLogger(getClass) def getKeyExpression(): Seq[String] = key @@ -73,7 +75,7 @@ private[offline] class SimpleConfigurableAnchorExtractor( @JsonProperty("key") k // be more strict for resolving keys (don't swallow exceptions) keyExpression.map(k => try { - Option(MvelContext.executeExpressionWithPluginSupport(k, datum)) match { + Option(MvelContext.executeExpressionWithPluginSupport(k, datum, mvelContext.orNull)) match { case None => null case Some(keys) => keys.toString } @@ -92,7 +94,7 @@ private[offline] class SimpleConfigurableAnchorExtractor( @JsonProperty("key") k featureExpressions collect { case (featureRefStr, (expression, featureType)) if selectedFeatures.contains(featureRefStr) => - (featureRefStr, (MvelUtils.executeExpression(expression, datum, null, featureRefStr), featureType)) + (featureRefStr, (MvelUtils.executeExpression(expression, datum, null, featureRefStr, mvelContext), featureType)) } collect { // Apply a partial function only for non-empty feature values, empty feature values will be set to default later case (featureRefStr, (Some(value), fType)) => @@ -165,7 +167,7 @@ private[offline] class SimpleConfigurableAnchorExtractor( @JsonProperty("key") k * for building a tensor. Feature's value type and dimension type(s) are obtained via Feathr's Feature Metadata * Library during tensor construction. */ - (featureRefStr, MvelUtils.executeExpression(expression, datum, null, featureRefStr)) + (featureRefStr, MvelUtils.executeExpression(expression, datum, null, featureRefStr, mvelContext)) } } diff --git a/src/main/scala/com/linkedin/feathr/offline/client/FeathrClient.scala b/src/main/scala/com/linkedin/feathr/offline/client/FeathrClient.scala index 45f8b2b02..b289ba3c5 100644 --- a/src/main/scala/com/linkedin/feathr/offline/client/FeathrClient.scala +++ b/src/main/scala/com/linkedin/feathr/offline/client/FeathrClient.scala @@ -8,9 +8,10 @@ import com.linkedin.feathr.offline.generation.{DataFrameFeatureGenerator, Featur import com.linkedin.feathr.offline.job._ import com.linkedin.feathr.offline.join.DataFrameFeatureJoiner import com.linkedin.feathr.offline.logical.{FeatureGroups, MultiStageJoinPlanner} +import com.linkedin.feathr.offline.mvel.plugins.FeathrExpressionExecutionContext import com.linkedin.feathr.offline.source.DataSource import com.linkedin.feathr.offline.source.accessor.DataPathHandler -import com.linkedin.feathr.offline.util.{FeathrUtils, _} +import com.linkedin.feathr.offline.util._ import org.apache.log4j.Logger import org.apache.spark.sql.{DataFrame, SparkSession} import org.apache.spark.sql.internal.SQLConf @@ -27,7 +28,7 @@ import scala.util.{Failure, Success} * */ class FeathrClient private[offline] (sparkSession: SparkSession, featureGroups: FeatureGroups, logicalPlanner: MultiStageJoinPlanner, - featureGroupsUpdater: FeatureGroupsUpdater, dataPathHandlers: List[DataPathHandler]) { + featureGroupsUpdater: FeatureGroupsUpdater, dataPathHandlers: List[DataPathHandler], mvelContext: Option[FeathrExpressionExecutionContext]) { private val log = Logger.getLogger(getClass) type KeyTagStringTuple = Seq[String] @@ -91,7 +92,7 @@ class FeathrClient private[offline] (sparkSession: SparkSession, featureGroups: // Get logical plan val logicalPlan = logicalPlanner.getLogicalPlan(featureGroups, keyTaggedRequiredFeatures) // This pattern is consistent with the join use case which uses DataFrameFeatureJoiner. - val dataFrameFeatureGenerator = new DataFrameFeatureGenerator(logicalPlan=logicalPlan,dataPathHandlers=dataPathHandlers) + val dataFrameFeatureGenerator = new DataFrameFeatureGenerator(logicalPlan=logicalPlan,dataPathHandlers=dataPathHandlers, mvelContext) val featureMap: Map[TaggedFeatureName, (DataFrame, Header)] = dataFrameFeatureGenerator.generateFeaturesAsDF(sparkSession, featureGenSpec, featureGroups, keyTaggedRequiredFeatures) @@ -263,7 +264,7 @@ class FeathrClient private[offline] (sparkSession: SparkSession, featureGroups: s"Please rename feature ${conflictFeatureNames} or rename the same field names in the observation data.") } - val joiner = new DataFrameFeatureJoiner(logicalPlan=logicalPlan,dataPathHandlers=dataPathHandlers) + val joiner = new DataFrameFeatureJoiner(logicalPlan=logicalPlan,dataPathHandlers=dataPathHandlers, mvelContext) joiner.joinFeaturesAsDF(sparkSession, joinConfig, updatedFeatureGroups, keyTaggedFeatures, left, rowBloomFilterThreshold) } @@ -337,6 +338,7 @@ object FeathrClient { private var localOverrideDefPath: List[String] = List() private var featureDefConfs: List[FeathrConfig] = List() private var dataPathHandlers: List[DataPathHandler] = List() + private var mvelContext: Option[FeathrExpressionExecutionContext] = None; /** @@ -495,6 +497,10 @@ object FeathrClient { this.featureDefConfs = featureDefConfs this } + def addFeathrExpressionContext(_mvelContext: Option[FeathrExpressionExecutionContext]): Builder = { + this.mvelContext = _mvelContext + this + } /** * Build a new instance of the FeathrClient from the added feathr definition configs and any local overrides. @@ -529,7 +535,7 @@ object FeathrClient { featureDefConfigs = featureDefConfigs ++ featureDefConfs val featureGroups = FeatureGroupsGenerator(featureDefConfigs, Some(localDefConfigs)).getFeatureGroups() - val feathrClient = new FeathrClient(sparkSession, featureGroups, MultiStageJoinPlanner(), FeatureGroupsUpdater(), dataPathHandlers) + val feathrClient = new FeathrClient(sparkSession, featureGroups, MultiStageJoinPlanner(), FeatureGroupsUpdater(), dataPathHandlers, mvelContext) feathrClient } diff --git a/src/main/scala/com/linkedin/feathr/offline/client/plugins/FeathrUdfPluginContext.scala b/src/main/scala/com/linkedin/feathr/offline/client/plugins/FeathrUdfPluginContext.scala index 852c2a2e6..d67e5b6d5 100644 --- a/src/main/scala/com/linkedin/feathr/offline/client/plugins/FeathrUdfPluginContext.scala +++ b/src/main/scala/com/linkedin/feathr/offline/client/plugins/FeathrUdfPluginContext.scala @@ -1,4 +1,6 @@ package com.linkedin.feathr.offline.client.plugins +import org.apache.spark.SparkContext +import org.apache.spark.broadcast.Broadcast import scala.collection.mutable @@ -9,15 +11,21 @@ import scala.collection.mutable * All "external" UDF classes are required to have a public default zero-arg constructor. */ object FeathrUdfPluginContext { - val registeredUdfAdaptors = mutable.Buffer[UdfAdaptor[_]]() - - def registerUdfAdaptor(adaptor: UdfAdaptor[_]): Unit = { + private val localRegisteredUdfAdaptors = mutable.Buffer[UdfAdaptor[_]]() + private var registeredUdfAdaptors: Broadcast[mutable.Buffer[UdfAdaptor[_]]] = null + def registerUdfAdaptor(adaptor: UdfAdaptor[_], sc: SparkContext): Unit = { this.synchronized { - registeredUdfAdaptors += adaptor + localRegisteredUdfAdaptors += adaptor + if (registeredUdfAdaptors != null) { + registeredUdfAdaptors.destroy() + } + registeredUdfAdaptors = sc.broadcast(localRegisteredUdfAdaptors) } } def getRegisteredUdfAdaptor(clazz: Class[_]): Option[UdfAdaptor[_]] = { - registeredUdfAdaptors.find(_.canAdapt(clazz)) + if (registeredUdfAdaptors != null) { + registeredUdfAdaptors.value.find(_.canAdapt(clazz)) + } else None } } \ No newline at end of file diff --git a/src/main/scala/com/linkedin/feathr/offline/derived/DerivedFeatureEvaluator.scala b/src/main/scala/com/linkedin/feathr/offline/derived/DerivedFeatureEvaluator.scala index 36de11fae..ff16ebe18 100644 --- a/src/main/scala/com/linkedin/feathr/offline/derived/DerivedFeatureEvaluator.scala +++ b/src/main/scala/com/linkedin/feathr/offline/derived/DerivedFeatureEvaluator.scala @@ -6,10 +6,11 @@ import com.linkedin.feathr.common.exception.{ErrorLabel, FeathrException} import com.linkedin.feathr.offline.{ErasedEntityTaggedFeature, FeatureDataFrame} import com.linkedin.feathr.offline.client.DataFrameColName import com.linkedin.feathr.offline.client.plugins.{FeathrUdfPluginContext, FeatureDerivationFunctionAdaptor} -import com.linkedin.feathr.offline.derived.functions.SeqJoinDerivationFunction +import com.linkedin.feathr.offline.derived.functions.{MvelFeatureDerivationFunction, SeqJoinDerivationFunction} import com.linkedin.feathr.offline.derived.strategies.{DerivationStrategies, RowBasedDerivation, SequentialJoinAsDerivation, SparkUdfDerivation} import com.linkedin.feathr.offline.join.algorithms.{SequentialJoinConditionBuilder, SparkJoinWithJoinCondition} import com.linkedin.feathr.offline.logical.FeatureGroups +import com.linkedin.feathr.offline.mvel.plugins.FeathrExpressionExecutionContext import com.linkedin.feathr.offline.util.FeaturizedDatasetUtils import com.linkedin.feathr.offline.source.accessor.DataPathHandler import com.linkedin.feathr.sparkcommon.FeatureDerivationFunctionSpark @@ -20,7 +21,7 @@ import org.apache.spark.sql.{DataFrame, SparkSession} * This class is responsible for applying feature derivations. * @param derivationStrategies strategies for executing various derivation functions. */ -private[offline] class DerivedFeatureEvaluator(derivationStrategies: DerivationStrategies) { +private[offline] class DerivedFeatureEvaluator(derivationStrategies: DerivationStrategies, mvelContext: Option[FeathrExpressionExecutionContext]) { /** * Calculate a derived feature, this function support all kinds of derived features @@ -39,23 +40,23 @@ private[offline] class DerivedFeatureEvaluator(derivationStrategies: DerivationS derivedFeature.derivation match { case g: SeqJoinDerivationFunction => - val resultDF = derivationStrategies.sequentialJoinDerivationStrategy(keyTag, keyTagList, contextDF, derivedFeature, g) + val resultDF = derivationStrategies.sequentialJoinDerivationStrategy(keyTag, keyTagList, contextDF, derivedFeature, g, mvelContext) convertFeatureColumnToQuinceFds(producedFeatureColName, derivedFeature, resultDF) case h: FeatureDerivationFunctionSpark => - val resultDF = derivationStrategies.customDerivationSparkStrategy(keyTag, keyTagList, contextDF, derivedFeature, h) + val resultDF = derivationStrategies.customDerivationSparkStrategy(keyTag, keyTagList, contextDF, derivedFeature, h, mvelContext) convertFeatureColumnToQuinceFds(producedFeatureColName, derivedFeature, resultDF) case x: FeatureDerivationFunction => // We should do the FDS conversion inside the rowBasedDerivationStrategy here. The result of rowBasedDerivationStrategy // can be NTV FeatureValue or TensorData-based Feature. NTV FeatureValue has fixed FDS schema. However, TensorData // doesn't have fixed DataFrame schema so that we can't return TensorData but has to return FDS. - val resultDF = derivationStrategies.rowBasedDerivationStrategy(keyTag, keyTagList, contextDF, derivedFeature, x) + val resultDF = derivationStrategies.rowBasedDerivationStrategy(keyTag, keyTagList, contextDF, derivedFeature, x, mvelContext) offline.FeatureDataFrame(resultDF, getTypeConfigs(producedFeatureColName, derivedFeature, resultDF)) case derivation => FeathrUdfPluginContext.getRegisteredUdfAdaptor(derivation.getClass) match { case Some(adaptor: FeatureDerivationFunctionAdaptor) => // replicating the FeatureDerivationFunction case above val featureDerivationFunction = adaptor.adaptUdf(derivation) - val resultDF = derivationStrategies.rowBasedDerivationStrategy(keyTag, keyTagList, contextDF, derivedFeature, featureDerivationFunction) + val resultDF = derivationStrategies.rowBasedDerivationStrategy(keyTag, keyTagList, contextDF, derivedFeature, featureDerivationFunction, mvelContext) offline.FeatureDataFrame(resultDF, getTypeConfigs(producedFeatureColName, derivedFeature, resultDF)) case _ => throw new FeathrException(ErrorLabel.FEATHR_ERROR, s"Unsupported feature derivation function for feature ${derivedFeature.producedFeatureNames.head}.") @@ -108,17 +109,18 @@ private[offline] class DerivedFeatureEvaluator(derivationStrategies: DerivationS private[offline] object DerivedFeatureEvaluator { private val log = Logger.getLogger(getClass) - def apply(derivationStrategies: DerivationStrategies): DerivedFeatureEvaluator = new DerivedFeatureEvaluator(derivationStrategies) + def apply(derivationStrategies: DerivationStrategies, mvelContext: Option[FeathrExpressionExecutionContext]): DerivedFeatureEvaluator = new DerivedFeatureEvaluator(derivationStrategies, mvelContext) def apply(ss: SparkSession, featureGroups: FeatureGroups, - dataPathHandlers: List[DataPathHandler]): DerivedFeatureEvaluator = { + dataPathHandlers: List[DataPathHandler], + mvelContext: Option[FeathrExpressionExecutionContext]): DerivedFeatureEvaluator = { val defaultStrategies = strategies.DerivationStrategies( new SparkUdfDerivation(), - new RowBasedDerivation(featureGroups.allTypeConfigs), + new RowBasedDerivation(featureGroups.allTypeConfigs, mvelContext), new SequentialJoinAsDerivation(ss, featureGroups, SparkJoinWithJoinCondition(SequentialJoinConditionBuilder), dataPathHandlers) ) - new DerivedFeatureEvaluator(defaultStrategies) + new DerivedFeatureEvaluator(defaultStrategies, mvelContext) } /** @@ -132,7 +134,9 @@ private[offline] object DerivedFeatureEvaluator { def evaluateFromFeatureValues( keyTag: Seq[Int], derivedFeature: DerivedFeature, - contextFeatureValues: Map[common.ErasedEntityTaggedFeature, common.FeatureValue]): Map[common.ErasedEntityTaggedFeature, common.FeatureValue] = { + contextFeatureValues: Map[common.ErasedEntityTaggedFeature, common.FeatureValue], + mvelContext: Option[FeathrExpressionExecutionContext] + ): Map[common.ErasedEntityTaggedFeature, common.FeatureValue] = { try { val linkedInputParams = derivedFeature.consumedFeatureNames.map { case ErasedEntityTaggedFeature(calleeTag, featureName) => @@ -141,7 +145,13 @@ private[offline] object DerivedFeatureEvaluator { // for features with value `null`, convert Some(null) to None, to avoid null pointer exception in downstream analysis val keyedContextFeatureValues = contextFeatureValues.map(kv => (kv._1.getErasedTagFeatureName, kv._2)) val resolvedInputArgs = linkedInputParams.map(taggedFeature => keyedContextFeatureValues.get(taggedFeature.getErasedTagFeatureName).flatMap(Option(_))) - val unlinkedOutput = derivedFeature.getAsFeatureDerivationFunction.getFeatures(resolvedInputArgs) + val derivedFunc = derivedFeature.getAsFeatureDerivationFunction match { + case derivedFunc: MvelFeatureDerivationFunction => + derivedFunc.mvelContext = mvelContext + derivedFunc + case func => func + } + val unlinkedOutput = derivedFunc.getFeatures(resolvedInputArgs) val callerKeyTags = derivedFeature.producedFeatureNames.map(ErasedEntityTaggedFeature(keyTag, _)) // This would indicate a problem with the DerivedFeature, where there are a different number of features included in diff --git a/src/main/scala/com/linkedin/feathr/offline/derived/functions/MvelFeatureDerivationFunction.scala b/src/main/scala/com/linkedin/feathr/offline/derived/functions/MvelFeatureDerivationFunction.scala index 58902e669..42f09ad21 100644 --- a/src/main/scala/com/linkedin/feathr/offline/derived/functions/MvelFeatureDerivationFunction.scala +++ b/src/main/scala/com/linkedin/feathr/offline/derived/functions/MvelFeatureDerivationFunction.scala @@ -4,6 +4,7 @@ import com.linkedin.feathr.common import com.linkedin.feathr.common.{FeatureDerivationFunction, FeatureTypeConfig, TaggedFeatureName} import com.linkedin.feathr.offline.FeatureValue import com.linkedin.feathr.offline.config.TaggedDependency +import com.linkedin.feathr.offline.mvel.plugins.FeathrExpressionExecutionContext import com.linkedin.feathr.offline.mvel.{FeatureVariableResolverFactory, MvelContext, MvelUtils} import org.mvel2.MVEL @@ -31,6 +32,7 @@ private[offline] class MvelFeatureDerivationFunction( featureTypeConfigOpt: Option[FeatureTypeConfig] = None) extends FeatureDerivationFunction { + var mvelContext: Option[FeathrExpressionExecutionContext] = None val parameterNames: Seq[String] = inputFeatures.keys.toIndexedSeq private val compiledExpression = { @@ -42,7 +44,7 @@ private[offline] class MvelFeatureDerivationFunction( val argMap = (parameterNames zip inputs).toMap val variableResolverFactory = new FeatureVariableResolverFactory(argMap) - MvelUtils.executeExpression(compiledExpression, null, variableResolverFactory) match { + MvelUtils.executeExpression(compiledExpression, null, variableResolverFactory, featureName, mvelContext) match { case Some(value) => val featureTypeConfig = featureTypeConfigOpt.getOrElse(FeatureTypeConfig.UNDEFINED_TYPE_CONFIG) if (value.isInstanceOf[common.FeatureValue]) { diff --git a/src/main/scala/com/linkedin/feathr/offline/derived/functions/SimpleMvelDerivationFunction.scala b/src/main/scala/com/linkedin/feathr/offline/derived/functions/SimpleMvelDerivationFunction.scala index 9e1f6b0bb..203d1886f 100644 --- a/src/main/scala/com/linkedin/feathr/offline/derived/functions/SimpleMvelDerivationFunction.scala +++ b/src/main/scala/com/linkedin/feathr/offline/derived/functions/SimpleMvelDerivationFunction.scala @@ -3,6 +3,7 @@ package com.linkedin.feathr.offline.derived.functions import com.linkedin.feathr.common import com.linkedin.feathr.common.{FeatureDerivationFunction, FeatureTypeConfig} import com.linkedin.feathr.offline.FeatureValue +import com.linkedin.feathr.offline.mvel.plugins.FeathrExpressionExecutionContext import com.linkedin.feathr.offline.mvel.{FeatureVariableResolverFactory, MvelContext, MvelUtils} import com.linkedin.feathr.offline.testfwk.TestFwkUtils import org.apache.log4j.Logger @@ -19,6 +20,7 @@ private[offline] class SimpleMvelDerivationFunction(expression: String, featureN extends FeatureDerivationFunction { @transient private lazy val log = Logger.getLogger(getClass) + var mvelContext: Option[FeathrExpressionExecutionContext] = None // strictMode should only be modified by FeathrConfigLoader when loading config, default value to be false var strictMode = false @@ -51,7 +53,7 @@ private[offline] class SimpleMvelDerivationFunction(expression: String, featureN } } - MvelUtils.executeExpression(compiledExpression, null, variableResolverFactory) match { + MvelUtils.executeExpression(compiledExpression, null, variableResolverFactory, featureName, mvelContext) match { case Some(value) => val featureTypeConfig = featureTypeConfigOpt.getOrElse(FeatureTypeConfig.UNDEFINED_TYPE_CONFIG) val featureValue = FeatureValue.fromTypeConfig(value, featureTypeConfig) diff --git a/src/main/scala/com/linkedin/feathr/offline/derived/strategies/DerivationStrategies.scala b/src/main/scala/com/linkedin/feathr/offline/derived/strategies/DerivationStrategies.scala index 6f7ea1eab..e54d68f59 100644 --- a/src/main/scala/com/linkedin/feathr/offline/derived/strategies/DerivationStrategies.scala +++ b/src/main/scala/com/linkedin/feathr/offline/derived/strategies/DerivationStrategies.scala @@ -3,6 +3,7 @@ package com.linkedin.feathr.offline.derived.strategies import com.linkedin.feathr.common.{FeatureDerivationFunction, FeatureDerivationFunctionBase} import com.linkedin.feathr.offline.derived.functions.SeqJoinDerivationFunction import com.linkedin.feathr.offline.derived.DerivedFeature +import com.linkedin.feathr.offline.mvel.plugins.FeathrExpressionExecutionContext import com.linkedin.feathr.sparkcommon.FeatureDerivationFunctionSpark import org.apache.spark.sql.DataFrame @@ -12,7 +13,6 @@ import org.apache.spark.sql.DataFrame * A derivation strategy encapsulates the execution of derivations. */ private[offline] trait DerivationStrategy[T <: FeatureDerivationFunctionBase] { - /** * Apply the derivation strategy. * @param keyTags keyTags for the derived feature. @@ -22,7 +22,7 @@ private[offline] trait DerivationStrategy[T <: FeatureDerivationFunctionBase] { * @param derivationFunction Derivation function to evaluate the derived feature * @return output DataFrame with derived feature. */ - def apply(keyTags: Seq[Int], keyTagList: Seq[String], df: DataFrame, derivedFeature: DerivedFeature, derivationFunction: T): DataFrame + def apply(keyTags: Seq[Int], keyTagList: Seq[String], df: DataFrame, derivedFeature: DerivedFeature, derivationFunction: T, mvelContext: Option[FeathrExpressionExecutionContext]): DataFrame } /** diff --git a/src/main/scala/com/linkedin/feathr/offline/derived/strategies/RowBasedDerivation.scala b/src/main/scala/com/linkedin/feathr/offline/derived/strategies/RowBasedDerivation.scala index 389c530ee..ca78ff464 100644 --- a/src/main/scala/com/linkedin/feathr/offline/derived/strategies/RowBasedDerivation.scala +++ b/src/main/scala/com/linkedin/feathr/offline/derived/strategies/RowBasedDerivation.scala @@ -6,6 +6,7 @@ import com.linkedin.feathr.common.{FeatureDerivationFunction, FeatureTypeConfig, import com.linkedin.feathr.offline.ErasedEntityTaggedFeature import com.linkedin.feathr.offline.client.DataFrameColName import com.linkedin.feathr.offline.derived.{DerivedFeature, DerivedFeatureEvaluator} +import com.linkedin.feathr.offline.mvel.plugins.FeathrExpressionExecutionContext import com.linkedin.feathr.offline.testfwk.TestFwkUtils import com.linkedin.feathr.offline.transformation.FDSConversionUtils import com.linkedin.feathr.offline.util.FeaturizedDatasetUtils.tensorTypeToDataFrameSchema @@ -21,7 +22,9 @@ import scala.collection.mutable /** * This class executes custom derivation logic defined in an implementation of FeatureDerivationFunction. */ -class RowBasedDerivation(dependentFeatureTypeConfigs: Map[String, FeatureTypeConfig]) extends RowBasedDerivationStrategy with Serializable { +class RowBasedDerivation(dependentFeatureTypeConfigs: Map[String, FeatureTypeConfig], + val mvelContext: Option[FeathrExpressionExecutionContext], + ) extends RowBasedDerivationStrategy with Serializable { /** * Calculate a Row-based derived features such as Mvel based derivations or UDFs. @@ -44,7 +47,8 @@ class RowBasedDerivation(dependentFeatureTypeConfigs: Map[String, FeatureTypeCon keyTagList: Seq[String], df: DataFrame, derivedFeature: DerivedFeature, - derivationFunction: FeatureDerivationFunction): DataFrame = { + derivationFunction: FeatureDerivationFunction, + mvelContext: Option[FeathrExpressionExecutionContext]): DataFrame = { if (derivationFunction.isInstanceOf[FeatureDerivationFunctionSpark]) { throw new FeathrException(ErrorLabel.FEATHR_USER_ERROR, s"Unsupported user customized derived feature ${derivedFeature.producedFeatureNames}") } @@ -96,7 +100,7 @@ class RowBasedDerivation(dependentFeatureTypeConfigs: Map[String, FeatureTypeCon contextFeatureValues.put(ErasedEntityTaggedFeature(dependFeature.getBinding, dependFeature.getFeatureName), featureValue) }) // calculate using original function - val features = DerivedFeatureEvaluator.evaluateFromFeatureValues(keyTags, derivedFeature, contextFeatureValues.toMap) + val features = DerivedFeatureEvaluator.evaluateFromFeatureValues(keyTags, derivedFeature, contextFeatureValues.toMap, mvelContext) val taggFeatures = features.map(kv => (kv._1.getErasedTagFeatureName, kv._2)) val featureValues = featureNames.map(featureName => { taggFeatures.get(ErasedEntityTaggedFeature(keyTags, featureName).getErasedTagFeatureName).map { featureValue => diff --git a/src/main/scala/com/linkedin/feathr/offline/derived/strategies/SequentialJoinAsDerivation.scala b/src/main/scala/com/linkedin/feathr/offline/derived/strategies/SequentialJoinAsDerivation.scala index 9cc3080d9..2cee39d95 100644 --- a/src/main/scala/com/linkedin/feathr/offline/derived/strategies/SequentialJoinAsDerivation.scala +++ b/src/main/scala/com/linkedin/feathr/offline/derived/strategies/SequentialJoinAsDerivation.scala @@ -14,13 +14,14 @@ import com.linkedin.feathr.offline.job.FeatureTransformation._ import com.linkedin.feathr.offline.job.{AnchorFeatureGroups, FeatureTransformation, KeyedTransformedResult} import com.linkedin.feathr.offline.join.algorithms.{JoinType, SeqJoinExplodedJoinKeyColumnAppender, SparkJoinWithJoinCondition} import com.linkedin.feathr.offline.logical.FeatureGroups +import com.linkedin.feathr.offline.mvel.plugins.FeathrExpressionExecutionContext import com.linkedin.feathr.offline.source.accessor.DataPathHandler import com.linkedin.feathr.offline.transformation.DataFrameDefaultValueSubstituter.substituteDefaults import com.linkedin.feathr.offline.transformation.{AnchorToDataSourceMapper, MvelDefinition} -import com.linkedin.feathr.offline.util.{CoercionUtilsScala, DataFrameSplitterMerger, FeaturizedDatasetUtils, FeathrUtils} +import com.linkedin.feathr.offline.util.{CoercionUtilsScala, DataFrameSplitterMerger, FeathrUtils, FeaturizedDatasetUtils} import com.linkedin.feathr.sparkcommon.{ComplexAggregation, SeqJoinCustomAggregation} import org.apache.log4j.Logger -import org.apache.spark.sql.functions.{expr, udf, _} +import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ import org.apache.spark.sql.{Column, DataFrame, Row, SparkSession} @@ -44,7 +45,8 @@ private[offline] class SequentialJoinAsDerivation(ss: SparkSession, keyTagList: Seq[String], df: DataFrame, derivedFeature: DerivedFeature, - derivationFunction: SeqJoinDerivationFunction): DataFrame = { + derivationFunction: SeqJoinDerivationFunction, + mvelContext: Option[FeathrExpressionExecutionContext]): DataFrame = { val allAnchoredFeatures = featureGroups.allAnchoredFeatures // gather sequential join feature info val seqJoinDerivationFunction = derivationFunction @@ -70,7 +72,7 @@ private[offline] class SequentialJoinAsDerivation(ss: SparkSession, */ val (expansion, expansionJoinKey): (DataFrame, Seq[String]) = if (allAnchoredFeatures.contains(expansionFeatureName)) { // prepare and get right table - loadExpansionAnchor(expansionFeatureName, derivedFeature, allAnchoredFeatures, seqJoinColumnName) + loadExpansionAnchor(expansionFeatureName, derivedFeature, allAnchoredFeatures, seqJoinColumnName, mvelContext) } else { throw new FeathrException( ErrorLabel.FEATHR_ERROR, @@ -93,7 +95,7 @@ private[offline] class SequentialJoinAsDerivation(ss: SparkSession, Map(baseTaggedDependency.feature -> MvelDefinition(transformation)) } getOrElse Map.empty[String, MvelDefinition] - val left: DataFrame = PostTransformationUtil.transformFeatures(featureNameColumnTuples, obsWithLeftJoined, transformationDef, getDefaultTransformation) + val left: DataFrame = PostTransformationUtil.transformFeatures(featureNameColumnTuples, obsWithLeftJoined, transformationDef, getDefaultTransformation, mvelContext) // Partition build side of the join based on null values val (dfWithNoNull, dfWithNull) = DataFrameSplitterMerger.splitOnNull(left, baseFeatureJoinKey.head) @@ -207,7 +209,8 @@ private[offline] class SequentialJoinAsDerivation(ss: SparkSession, def getAnchorFeatureDF( allAnchoredFeatures: Map[String, FeatureAnchorWithSource], anchorFeatureName: String, - anchorToDataSourceMapper: AnchorToDataSourceMapper): KeyedTransformedResult = { + anchorToDataSourceMapper: AnchorToDataSourceMapper, + mvelContext: Option[FeathrExpressionExecutionContext]): KeyedTransformedResult = { val featureAnchor = allAnchoredFeatures(anchorFeatureName) val requestedFeatures = featureAnchor.featureAnchor.getProvidedFeatureNames val anchorGroup = AnchorFeatureGroups(Seq(featureAnchor), requestedFeatures) @@ -219,7 +222,9 @@ private[offline] class SequentialJoinAsDerivation(ss: SparkSession, anchorDFMap1(featureAnchor), featureAnchor.featureAnchor.sourceKeyExtractor, None, - None) + None, + None, + mvelContext) (featureInfo) } @@ -590,10 +595,11 @@ private[offline] class SequentialJoinAsDerivation(ss: SparkSession, expansionFeatureName: String, derivedFeature: DerivedFeature, allAnchoredFeatures: Map[String, FeatureAnchorWithSource], - seqJoinproducedFeatureName: String): (DataFrame, Seq[String]) = { + seqJoinproducedFeatureName: String, + mvelContext: Option[FeathrExpressionExecutionContext]): (DataFrame, Seq[String]) = { val expansionFeatureKeys = (derivedFeature.derivation.asInstanceOf[SeqJoinDerivationFunction].right.key) val expansionAnchor = allAnchoredFeatures(expansionFeatureName) - val expandFeatureInfo = getAnchorFeatureDF(allAnchoredFeatures, expansionFeatureName, new AnchorToDataSourceMapper(dataPathHandlers)) + val expandFeatureInfo = getAnchorFeatureDF(allAnchoredFeatures, expansionFeatureName, new AnchorToDataSourceMapper(dataPathHandlers), mvelContext) val transformedFeatureDF = expandFeatureInfo.transformedResult.df val expansionAnchorKeyColumnNames = expandFeatureInfo.joinKey if (expansionFeatureKeys.size != expansionAnchorKeyColumnNames.size) { diff --git a/src/main/scala/com/linkedin/feathr/offline/derived/strategies/SparkUdfDerivation.scala b/src/main/scala/com/linkedin/feathr/offline/derived/strategies/SparkUdfDerivation.scala index ba65e3f23..1d4a9212e 100644 --- a/src/main/scala/com/linkedin/feathr/offline/derived/strategies/SparkUdfDerivation.scala +++ b/src/main/scala/com/linkedin/feathr/offline/derived/strategies/SparkUdfDerivation.scala @@ -6,6 +6,7 @@ import com.linkedin.feathr.offline.ErasedEntityTaggedFeature import com.linkedin.feathr.offline.client.DataFrameColName import com.linkedin.feathr.offline.derived.DerivedFeature import com.linkedin.feathr.offline.exception.FeatureTransformationException +import com.linkedin.feathr.offline.mvel.plugins.FeathrExpressionExecutionContext import com.linkedin.feathr.sparkcommon.FeatureDerivationFunctionSpark import org.apache.spark.sql.DataFrame @@ -30,7 +31,8 @@ class SparkUdfDerivation extends SparkUdfDerivationStrategy { keyTagList: Seq[String], df: DataFrame, derivedFeature: DerivedFeature, - derivationFunction: FeatureDerivationFunctionSpark): DataFrame = { + derivationFunction: FeatureDerivationFunctionSpark, + mvelContext: Option[FeathrExpressionExecutionContext]): DataFrame = { if (derivedFeature.parameterNames.isEmpty) { throw new FeathrException( ErrorLabel.FEATHR_USER_ERROR, diff --git a/src/main/scala/com/linkedin/feathr/offline/generation/DataFrameFeatureGenerator.scala b/src/main/scala/com/linkedin/feathr/offline/generation/DataFrameFeatureGenerator.scala index f52b0a4b5..310c3931e 100644 --- a/src/main/scala/com/linkedin/feathr/offline/generation/DataFrameFeatureGenerator.scala +++ b/src/main/scala/com/linkedin/feathr/offline/generation/DataFrameFeatureGenerator.scala @@ -10,6 +10,7 @@ import com.linkedin.feathr.offline.derived.{DerivedFeature, DerivedFeatureEvalua import com.linkedin.feathr.offline.evaluator.DerivedFeatureGenStage import com.linkedin.feathr.offline.job.{FeatureGenSpec, FeatureTransformation} import com.linkedin.feathr.offline.logical.{FeatureGroups, MultiStageJoinPlan} +import com.linkedin.feathr.offline.mvel.plugins.FeathrExpressionExecutionContext import com.linkedin.feathr.offline.source.accessor.DataPathHandler import com.linkedin.feathr.offline.source.dataloader.DataLoaderHandler import com.linkedin.feathr.offline.transformation.AnchorToDataSourceMapper @@ -20,7 +21,9 @@ import org.apache.spark.sql.{DataFrame, SparkSession} * Feature generator that is responsible for generating anchored and derived features. * @param logicalPlan logical plan for feature generation job. */ -private[offline] class DataFrameFeatureGenerator(logicalPlan: MultiStageJoinPlan, dataPathHandlers: List[DataPathHandler]) extends Serializable { +private[offline] class DataFrameFeatureGenerator(logicalPlan: MultiStageJoinPlan, + dataPathHandlers: List[DataPathHandler], + mvelContext: Option[FeathrExpressionExecutionContext]) extends Serializable { @transient val incrementalAggSnapshotLoader = IncrementalAggSnapshotLoader @transient val anchorToDataFrameMapper = new AnchorToDataSourceMapper(dataPathHandlers) @transient val featureGenFeatureGrouper = FeatureGenFeatureGrouper() @@ -72,7 +75,7 @@ private[offline] class DataFrameFeatureGenerator(logicalPlan: MultiStageJoinPlan val anchoredDFThisStage = anchorDFRDDMap.filterKeys(anchoredFeaturesThisStage.toSet) FeatureTransformation - .transformFeatures(anchoredDFThisStage, anchoredFeatureNamesThisStage, None, Some(incrementalAggContext)) + .transformFeatures(anchoredDFThisStage, anchoredFeatureNamesThisStage, None, Some(incrementalAggContext), mvelContext) .map(f => (f._1, (offline.FeatureDataFrame(f._2.transformedResult.df, f._2.transformedResult.inferredFeatureTypes), f._2.joinKey))) }.toMap @@ -117,18 +120,18 @@ private[offline] class DataFrameFeatureGenerator(logicalPlan: MultiStageJoinPlan DerivedFeatureEvaluator( DerivationStrategies( new SparkUdfDerivation(), - new RowBasedDerivation(featureGroups.allTypeConfigs), + new RowBasedDerivation(featureGroups.allTypeConfigs, mvelContext), new SequentialJoinDerivationStrategy { override def apply( keyTags: Seq[Int], keyTagList: Seq[String], df: DataFrame, derivedFeature: DerivedFeature, - derivationFunction: SeqJoinDerivationFunction): DataFrame = { + derivationFunction: SeqJoinDerivationFunction, mvelContext: Option[FeathrExpressionExecutionContext]): DataFrame = { // Feature generation does not support sequential join features throw new FeathrException( ErrorLabel.FEATHR_ERROR, s"Feature Generation does not support Sequential Join features : ${derivedFeature.producedFeatureNames.head}") } - })) + }), mvelContext) } diff --git a/src/main/scala/com/linkedin/feathr/offline/job/FeatureTransformation.scala b/src/main/scala/com/linkedin/feathr/offline/job/FeatureTransformation.scala index 9b713a882..94de8e645 100644 --- a/src/main/scala/com/linkedin/feathr/offline/job/FeatureTransformation.scala +++ b/src/main/scala/com/linkedin/feathr/offline/job/FeatureTransformation.scala @@ -6,11 +6,11 @@ import com.linkedin.feathr.offline.anchored.anchorExtractor.{SQLConfigurableAnch import com.linkedin.feathr.offline.anchored.feature.{FeatureAnchor, FeatureAnchorWithSource} import com.linkedin.feathr.offline.anchored.keyExtractor.MVELSourceKeyExtractor import com.linkedin.feathr.offline.client.DataFrameColName -import com.linkedin.feathr.offline.client.plugins.{SimpleAnchorExtractorSparkAdaptor, FeathrUdfPluginContext, AnchorExtractorAdaptor} import com.linkedin.feathr.offline.config.{MVELFeatureDefinition, TimeWindowFeatureDefinition} import com.linkedin.feathr.offline.generation.IncrementalAggContext import com.linkedin.feathr.offline.job.FeatureJoinJob.FeatureName import com.linkedin.feathr.offline.join.DataFrameKeyCombiner +import com.linkedin.feathr.offline.mvel.plugins.FeathrExpressionExecutionContext import com.linkedin.feathr.offline.source.accessor.{DataSourceAccessor, NonTimeBasedDataSourceAccessor, TimeBasedDataSourceAccessor} import com.linkedin.feathr.offline.swa.SlidingWindowFeatureUtils import com.linkedin.feathr.offline.transformation.FeatureColumnFormat.FeatureColumnFormat @@ -164,7 +164,8 @@ private[offline] object FeatureTransformation { featureAnchorWithSource: FeatureAnchorWithSource, df: DataFrame, requestedFeatureRefString: Seq[String], - inputDateInterval: Option[DateTimeInterval]): TransformedResult = { + inputDateInterval: Option[DateTimeInterval], + mvelContext: Option[FeathrExpressionExecutionContext]): TransformedResult = { val featureNamePrefix = getFeatureNamePrefix(featureAnchorWithSource.featureAnchor.extractor) val featureNamePrefixPairs = requestedFeatureRefString.map((_, featureNamePrefix)) @@ -178,7 +179,7 @@ private[offline] object FeatureTransformation { // so that transformation logic can be written only once DataFrameBasedSqlEvaluator.transform(transformer, df, featureNamePrefixPairs, featureTypeConfigs) case transformer: AnchorExtractor[_] => - DataFrameBasedRowEvaluator.transform(transformer, df, featureNamePrefixPairs, featureTypeConfigs) + DataFrameBasedRowEvaluator.transform(transformer, df, featureNamePrefixPairs, featureTypeConfigs, mvelContext) case _ => throw new FeathrFeatureTransformationException(ErrorLabel.FEATHR_USER_ERROR, s"cannot find valid Transformer for ${featureAnchorWithSource}") } @@ -286,7 +287,8 @@ private[offline] object FeatureTransformation { keyExtractor: SourceKeyExtractor, bloomFilter: Option[BloomFilter], inputDateInterval: Option[DateTimeInterval], - preprocessedDf: Option[DataFrame] = None): KeyedTransformedResult = { + preprocessedDf: Option[DataFrame] = None, + mvelContext: Option[FeathrExpressionExecutionContext]): KeyedTransformedResult = { // Can two diff anchors have different keyExtractor? assert(anchorFeatureGroup.anchorsWithSameSource.map(_.dateParam).distinct.size == 1) val defaultInterval = anchorFeatureGroup.anchorsWithSameSource.head.dateParam.map(OfflineDateTimeUtils.createIntervalFromFeatureGenDateParam) @@ -314,7 +316,7 @@ private[offline] object FeatureTransformation { (prevTransformedResult, featureAnchorWithSource) => { val requestedFeatures = featureAnchorWithSource.selectedFeatures val transformedResultWithoutKey = - transformSingleAnchorDF(featureAnchorWithSource, prevTransformedResult.df, requestedFeatures, inputDateInterval) + transformSingleAnchorDF(featureAnchorWithSource, prevTransformedResult.df, requestedFeatures, inputDateInterval, mvelContext) val namePrefixPairs = prevTransformedResult.featureNameAndPrefixPairs ++ transformedResultWithoutKey.featureNameAndPrefixPairs val columnNameToFeatureNameAndType = prevTransformedResult.inferredFeatureTypes ++ transformedResultWithoutKey.inferredFeatureTypes val featureColumnFormats = prevTransformedResult.featureColumnFormats ++ transformedResultWithoutKey.featureColumnFormats @@ -437,7 +439,8 @@ private[offline] object FeatureTransformation { anchorToSourceDFThisStage: Map[FeatureAnchorWithSource, DataSourceAccessor], requestedFeatureNames: Seq[FeatureName], bloomFilter: Option[BloomFilter], - incrementalAggContext: Option[IncrementalAggContext] = None): Map[FeatureName, KeyedTransformedResult] = { + incrementalAggContext: Option[IncrementalAggContext] = None, + mvelContext: Option[FeathrExpressionExecutionContext]): Map[FeatureName, KeyedTransformedResult] = { val executionService = Executors.newFixedThreadPool(MAX_PARALLEL_FEATURE_GROUP) implicit val executionContext = ExecutionContext.fromExecutorService(executionService) val groupedAnchorToFeatureGroups: Map[FeatureGroupingCriteria, Map[FeatureAnchorWithSource, FeatureGroupWithSameTimeWindow]] = @@ -457,7 +460,7 @@ private[offline] object FeatureTransformation { val sourceDF = featureGroupingFactors.source val transformedResults: Seq[KeyedTransformedResult] = transformMultiAnchorsOnSingleDataFrame(sourceDF, - keyExtractor, featureAnchorWithSource, bloomFilter, selectedFeatures, incrementalAggContext) + keyExtractor, featureAnchorWithSource, bloomFilter, selectedFeatures, incrementalAggContext, mvelContext) val res = transformedResults .map { transformedResultWithKey => @@ -854,7 +857,8 @@ private[offline] object FeatureTransformation { anchorsWithSameSource: Seq[FeatureAnchorWithSource], bloomFilter: Option[BloomFilter], allRequestedFeatures: Seq[String], - incrementalAggContext: Option[IncrementalAggContext]): Seq[KeyedTransformedResult] = { + incrementalAggContext: Option[IncrementalAggContext], + mvelContext: Option[FeathrExpressionExecutionContext]): Seq[KeyedTransformedResult] = { // based on source and feature definition, divide features into direct transform and incremental // transform groups @@ -864,7 +868,7 @@ private[offline] object FeatureTransformation { val preprocessedDf = PreprocessedDataFrameManager.getPreprocessedDataframe(anchorsWithSameSource) val directTransformedResult = - directTransformAnchorGroup.map(anchorGroup => Seq(directCalculate(anchorGroup, source, keyExtractor, bloomFilter, None, preprocessedDf))) + directTransformAnchorGroup.map(anchorGroup => Seq(directCalculate(anchorGroup, source, keyExtractor, bloomFilter, None, preprocessedDf, mvelContext))) val incrementalTransformedResult = incrementalTransformAnchorGroup.map { anchorGroup => { @@ -883,7 +887,7 @@ private[offline] object FeatureTransformation { baseDF.join(curDF, keyColumnNames) }) val preAggRootDir = incrAggCtx.previousSnapshotRootDirMap(anchorGroup.anchorsWithSameSource.head.selectedFeatures.head) - Seq(incrementalCalculate(anchorGroup, joinedPreAggDFs, source, keyExtractor, bloomFilter, preAggRootDir)) + Seq(incrementalCalculate(anchorGroup, joinedPreAggDFs, source, keyExtractor, bloomFilter, preAggRootDir, mvelContext)) } } @@ -1000,7 +1004,8 @@ private[offline] object FeatureTransformation { source: DataSourceAccessor, keyExtractor: SourceKeyExtractor, bloomFilter: Option[BloomFilter], - preAggRootDir: String): KeyedTransformedResult = { + preAggRootDir: String, + mvelContext: Option[FeathrExpressionExecutionContext]): KeyedTransformedResult = { // get the aggregation window of the feature val aggWindow = getFeatureAggWindow(featureAnchorWithSource) @@ -1013,7 +1018,7 @@ private[offline] object FeatureTransformation { // If so, even though the incremental aggregation succeeds, the result is incorrect. // And the incorrect result will be propagated to all subsequent incremental aggregation because the incorrect result will be used as the snapshot. - val newDeltaSourceAgg = directCalculate(featureAnchorWithSource, source, keyExtractor, bloomFilter, Some(dateParam)) + val newDeltaSourceAgg = directCalculate(featureAnchorWithSource, source, keyExtractor, bloomFilter, Some(dateParam), None, mvelContext) // if the new delta window size is smaller than the request feature window, need to use the pre-aggregated results, if (newDeltaWindowSize < aggWindow) { // add prefixes to feature columns and keys for the previous aggregation snapshot @@ -1034,7 +1039,7 @@ private[offline] object FeatureTransformation { renamedPreAgg } else { // preAgg - oldDeltaAgg - val oldDeltaSourceAgg = directCalculate(featureAnchorWithSource, source, keyExtractor, bloomFilter, Some(oldDeltaWindowInterval)) + val oldDeltaSourceAgg = directCalculate(featureAnchorWithSource, source, keyExtractor, bloomFilter, Some(oldDeltaWindowInterval), None, mvelContext) val oldDeltaAgg = oldDeltaSourceAgg.transformedResult.df mergeDeltaDF(renamedPreAgg, oldDeltaAgg, leftKeyColumnNames, joinKeys, newDeltaFeatureColumnNames, false) } diff --git a/src/main/scala/com/linkedin/feathr/offline/job/LocalFeatureJoinJob.scala b/src/main/scala/com/linkedin/feathr/offline/job/LocalFeatureJoinJob.scala index b3adfc03b..4a38d2304 100644 --- a/src/main/scala/com/linkedin/feathr/offline/job/LocalFeatureJoinJob.scala +++ b/src/main/scala/com/linkedin/feathr/offline/job/LocalFeatureJoinJob.scala @@ -2,6 +2,7 @@ package com.linkedin.feathr.offline.job import com.linkedin.feathr.offline.client.FeathrClient import com.linkedin.feathr.offline.config.FeatureJoinConfig +import com.linkedin.feathr.offline.mvel.plugins.FeathrExpressionExecutionContext import com.linkedin.feathr.offline.source.dataloader.DataLoaderHandler import com.linkedin.feathr.offline.source.accessor.DataPathHandler import com.linkedin.feathr.offline.source.dataloader.DataLoaderFactory @@ -34,11 +35,13 @@ object LocalFeatureJoinJob { observationData: SparkFeaturizedDataset, extraParams: Array[String] = Array(), ss: SparkSession = ss, - dataPathHandlers: List[DataPathHandler]): SparkFeaturizedDataset = { + dataPathHandlers: List[DataPathHandler], + mvelContext: Option[FeathrExpressionExecutionContext]): SparkFeaturizedDataset = { val joinConfig = FeatureJoinConfig.parseJoinConfig(joinConfigAsHoconString) val feathrClient = FeathrClient.builder(ss) .addFeatureDef(featureDefAsString) .addDataPathHandlers(dataPathHandlers) + .addFeathrExpressionContext(mvelContext) .build() val outputPath: String = FeatureJoinJob.SKIP_OUTPUT @@ -66,10 +69,11 @@ object LocalFeatureJoinJob { observationDataPath: String, extraParams: Array[String] = Array(), ss: SparkSession = ss, - dataPathHandlers: List[DataPathHandler]): SparkFeaturizedDataset = { + dataPathHandlers: List[DataPathHandler], + mvelContext: Option[FeathrExpressionExecutionContext]=None): SparkFeaturizedDataset = { val dataLoaderHandlers: List[DataLoaderHandler] = dataPathHandlers.map(_.dataLoaderHandler) val obsDf = loadObservationAsFDS(ss, observationDataPath,dataLoaderHandlers=dataLoaderHandlers) - joinWithObsDFAndHoconJoinConfig(joinConfigAsHoconString, featureDefAsString, obsDf, extraParams, ss, dataPathHandlers=dataPathHandlers) + joinWithObsDFAndHoconJoinConfig(joinConfigAsHoconString, featureDefAsString, obsDf, extraParams, ss, dataPathHandlers=dataPathHandlers, mvelContext) } /** diff --git a/src/main/scala/com/linkedin/feathr/offline/join/DataFrameFeatureJoiner.scala b/src/main/scala/com/linkedin/feathr/offline/join/DataFrameFeatureJoiner.scala index e7fccbd08..a03abc83c 100644 --- a/src/main/scala/com/linkedin/feathr/offline/join/DataFrameFeatureJoiner.scala +++ b/src/main/scala/com/linkedin/feathr/offline/join/DataFrameFeatureJoiner.scala @@ -12,6 +12,7 @@ import com.linkedin.feathr.offline.join.algorithms._ import com.linkedin.feathr.offline.join.util.{FrequentItemEstimatorFactory, FrequentItemEstimatorType} import com.linkedin.feathr.offline.join.workflow._ import com.linkedin.feathr.offline.logical.{FeatureGroups, MultiStageJoinPlan} +import com.linkedin.feathr.offline.mvel.plugins.FeathrExpressionExecutionContext import com.linkedin.feathr.offline.source.accessor.DataPathHandler import com.linkedin.feathr.offline.swa.SlidingWindowAggregationJoiner import com.linkedin.feathr.offline.transformation.AnchorToDataSourceMapper @@ -30,7 +31,7 @@ import scala.collection.JavaConverters._ * Joiner to join observation with feature data using Spark DataFrame API * @param logicalPlan analyzed feature info */ -private[offline] class DataFrameFeatureJoiner(logicalPlan: MultiStageJoinPlan, dataPathHandlers: List[DataPathHandler]) extends Serializable { +private[offline] class DataFrameFeatureJoiner(logicalPlan: MultiStageJoinPlan, dataPathHandlers: List[DataPathHandler], mvelContext: Option[FeathrExpressionExecutionContext]) extends Serializable { @transient lazy val log = Logger.getLogger(getClass.getName) @transient lazy val anchorToDataSourceMapper = new AnchorToDataSourceMapper(dataPathHandlers) private val windowAggFeatureStages = logicalPlan.windowAggFeatureStages @@ -69,7 +70,7 @@ private[offline] class DataFrameFeatureJoiner(logicalPlan: MultiStageJoinPlan, d (dfWithFeatureNames, featureAnchorWithSourcePair) => { val featureAnchorWithSource = featureAnchorWithSourcePair._1 val requestedFeatures = featureAnchorWithSourcePair._2.toSeq - val resultWithoutKey = transformSingleAnchorDF(featureAnchorWithSource, dfWithFeatureNames.df, requestedFeatures, None) + val resultWithoutKey = transformSingleAnchorDF(featureAnchorWithSource, dfWithFeatureNames.df, requestedFeatures, None, mvelContext) val namePrefixPairs = dfWithFeatureNames.featureNameAndPrefixPairs ++ resultWithoutKey.featureNameAndPrefixPairs val inferredFeatureTypeConfigs = dfWithFeatureNames.inferredFeatureTypes ++ resultWithoutKey.inferredFeatureTypes val featureColumnFormats = resultWithoutKey.featureColumnFormats ++ dfWithFeatureNames.featureColumnFormats @@ -201,12 +202,12 @@ private[offline] class DataFrameFeatureJoiner(logicalPlan: MultiStageJoinPlan, d AnchoredFeatureJoinStep( SlickJoinLeftJoinKeyColumnAppender, SlickJoinRightJoinKeyColumnAppender, - SparkJoinWithJoinCondition(EqualityJoinConditionBuilder)) + SparkJoinWithJoinCondition(EqualityJoinConditionBuilder), mvelContext) } else { AnchoredFeatureJoinStep( SqlTransformedLeftJoinKeyColumnAppender, IdentityJoinKeyColumnAppender, - SparkJoinWithJoinCondition(EqualityJoinConditionBuilder)) + SparkJoinWithJoinCondition(EqualityJoinConditionBuilder), mvelContext) } val FeatureDataFrameOutput(FeatureDataFrame(withAllBasicAnchoredFeatureDF, inferredBasicAnchoredFeatureTypes)) = anchoredFeatureJoinStep.joinFeatures(requiredRegularFeatureAnchors, AnchorJoinStepInput(withWindowAggFeatureDF, anchorSourceAccessorMap)) @@ -223,7 +224,7 @@ private[offline] class DataFrameFeatureJoiner(logicalPlan: MultiStageJoinPlan, d } else withAllBasicAnchoredFeatureDF // 6. Join Derived Features - val derivedFeatureEvaluator = DerivedFeatureEvaluator(ss=ss, featureGroups=featureGroups, dataPathHandlers=dataPathHandlers) + val derivedFeatureEvaluator = DerivedFeatureEvaluator(ss=ss, featureGroups=featureGroups, dataPathHandlers=dataPathHandlers, mvelContext) val derivedFeatureJoinStep = DerivedFeatureJoinStep(derivedFeatureEvaluator) val FeatureDataFrameOutput(FeatureDataFrame(withDerivedFeatureDF, inferredDerivedFeatureTypes)) = derivedFeatureJoinStep.joinFeatures(allRequiredFeatures.filter { diff --git a/src/main/scala/com/linkedin/feathr/offline/join/workflow/AnchoredFeatureJoinStep.scala b/src/main/scala/com/linkedin/feathr/offline/join/workflow/AnchoredFeatureJoinStep.scala index 5e69438f8..7abe3901b 100644 --- a/src/main/scala/com/linkedin/feathr/offline/join/workflow/AnchoredFeatureJoinStep.scala +++ b/src/main/scala/com/linkedin/feathr/offline/join/workflow/AnchoredFeatureJoinStep.scala @@ -12,6 +12,7 @@ import com.linkedin.feathr.offline.job.KeyedTransformedResult import com.linkedin.feathr.offline.join._ import com.linkedin.feathr.offline.join.algorithms._ import com.linkedin.feathr.offline.join.util.FrequentItemEstimatorFactory +import com.linkedin.feathr.offline.mvel.plugins.FeathrExpressionExecutionContext import com.linkedin.feathr.offline.source.accessor.DataSourceAccessor import com.linkedin.feathr.offline.transformation.DataFrameDefaultValueSubstituter.substituteDefaults import com.linkedin.feathr.offline.util.FeathrUtils @@ -31,7 +32,8 @@ import org.apache.spark.sql.functions.lit private[offline] class AnchoredFeatureJoinStep( leftJoinColumnExtractor: JoinKeyColumnsAppender, rightJoinColumnExtractor: JoinKeyColumnsAppender, - joiner: SparkJoinWithJoinCondition) + joiner: SparkJoinWithJoinCondition, + mvelContext: Option[FeathrExpressionExecutionContext]) extends FeatureJoinStep[AnchorJoinStepInput, DataFrameJoinStepOutput] { @transient lazy val log = Logger.getLogger(getClass.getName) @@ -126,7 +128,7 @@ private[offline] class AnchoredFeatureJoinStep( val anchoredFeaturesThisStage = featureNames.filter(allAnchoredFeatures.contains).map(allAnchoredFeatures).distinct val anchoredDFThisStage = anchorDFMap.filterKeys(anchoredFeaturesThisStage.toSet) // map feature name to its transformed dataframe and the join key of the dataframe - val featureToDFAndJoinKeys = transformFeatures(anchoredDFThisStage, anchoredFeatureNamesThisStage, bloomFilter) + val featureToDFAndJoinKeys = transformFeatures(anchoredDFThisStage, anchoredFeatureNamesThisStage, bloomFilter, None, mvelContext) featureToDFAndJoinKeys .groupBy(_._2.transformedResult.df) // group by dataframe, join one at a time .map(grouped => (grouped._2.keys.toSeq, grouped._2.values.toSeq)) // extract the feature names and their (dataframe,join keys) pairs @@ -226,6 +228,7 @@ private[offline] object AnchoredFeatureJoinStep { def apply( leftJoinColumnExtractor: JoinKeyColumnsAppender, rightJoinColumnExtractor: JoinKeyColumnsAppender, - joiner: SparkJoinWithJoinCondition): AnchoredFeatureJoinStep = - new AnchoredFeatureJoinStep(leftJoinColumnExtractor, rightJoinColumnExtractor, joiner) + joiner: SparkJoinWithJoinCondition, + mvelContext: Option[FeathrExpressionExecutionContext]): AnchoredFeatureJoinStep = + new AnchoredFeatureJoinStep(leftJoinColumnExtractor, rightJoinColumnExtractor, joiner, mvelContext) } diff --git a/src/main/scala/com/linkedin/feathr/offline/mvel/MvelContext.java b/src/main/scala/com/linkedin/feathr/offline/mvel/MvelContext.java index ce5926605..1ce8136c9 100644 --- a/src/main/scala/com/linkedin/feathr/offline/mvel/MvelContext.java +++ b/src/main/scala/com/linkedin/feathr/offline/mvel/MvelContext.java @@ -4,11 +4,11 @@ import com.google.common.collect.ImmutableSet; import com.linkedin.feathr.common.FeatureValue; import com.linkedin.feathr.common.util.MvelContextUDFs; +import com.linkedin.feathr.offline.mvel.plugins.FeathrExpressionExecutionContext; import org.apache.avro.generic.GenericEnumSymbol; import org.apache.avro.generic.GenericRecord; import org.apache.avro.util.Utf8; import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema; -import org.mvel2.DataConversion; import org.mvel2.MVEL; import org.mvel2.ParserConfiguration; import org.mvel2.ParserContext; @@ -114,9 +114,9 @@ public static ParserContext newParserContext() { * {@link com.linkedin.feathr.offline.mvel.plugins.FeathrMvelPluginContext}. (Output objects that can be converted * to {@link FeatureValue} via plugins, will be converted after MVEL returns.) */ - public static Object executeExpressionWithPluginSupport(Object compiledExpression, Object ctx) { + public static Object executeExpressionWithPluginSupport(Object compiledExpression, Object ctx, FeathrExpressionExecutionContext mvelContext) { Object output = MVEL.executeExpression(compiledExpression, ctx); - return coerceToFeatureValueViaMvelDataConversionPlugins(output); + return coerceToFeatureValueViaMvelDataConversionPlugins(output, mvelContext); } /** @@ -124,15 +124,18 @@ public static Object executeExpressionWithPluginSupport(Object compiledExpressio * {@link com.linkedin.feathr.offline.mvel.plugins.FeathrMvelPluginContext}. (Output objects that can be converted * to {@link FeatureValue} via plugins, will be converted after MVEL returns.) */ - public static Object executeExpressionWithPluginSupport(Object compiledExpression, Object ctx, - VariableResolverFactory variableResolverFactory) { + public static Object executeExpressionWithPluginSupportWithFactory(Object compiledExpression, + Object ctx, + VariableResolverFactory variableResolverFactory, + FeathrExpressionExecutionContext mvelContext) { Object output = MVEL.executeExpression(compiledExpression, ctx, variableResolverFactory); - return coerceToFeatureValueViaMvelDataConversionPlugins(output); + return coerceToFeatureValueViaMvelDataConversionPlugins(output, mvelContext); } - private static Object coerceToFeatureValueViaMvelDataConversionPlugins(Object input) { - if (input != null && DataConversion.canConvert(FeatureValue.class, input.getClass())) { - return DataConversion.convert(input, FeatureValue.class); + private static Object coerceToFeatureValueViaMvelDataConversionPlugins(Object input, FeathrExpressionExecutionContext mvelContext) { + // Convert the input to feature value using the given MvelContext if possible + if (input != null && mvelContext!= null && mvelContext.canConvert(FeatureValue.class, input.getClass())) { + return mvelContext.convert(input, FeatureValue.class); } else { return input; } diff --git a/src/main/scala/com/linkedin/feathr/offline/mvel/MvelUtils.scala b/src/main/scala/com/linkedin/feathr/offline/mvel/MvelUtils.scala index 8da9e2272..db467b0cf 100644 --- a/src/main/scala/com/linkedin/feathr/offline/mvel/MvelUtils.scala +++ b/src/main/scala/com/linkedin/feathr/offline/mvel/MvelUtils.scala @@ -1,9 +1,10 @@ package com.linkedin.feathr.offline.mvel +import com.linkedin.feathr.offline.mvel.plugins.FeathrExpressionExecutionContext import org.apache.commons.lang.exception.ExceptionUtils import org.apache.log4j.Logger +import org.mvel2.PropertyAccessException import org.mvel2.integration.VariableResolverFactory -import org.mvel2.{MVEL, PropertyAccessException} private[offline] object MvelUtils { @transient private lazy val log = Logger.getLogger(getClass) @@ -15,9 +16,9 @@ private[offline] object MvelUtils { // This approach has pros and cons and will likely be controversial // But it should allow for much simpler expressions for extracting features from data sets whose values may often be null // (We might not want to check for null explicitly everywhere) - def executeExpression(compiledExpression: Any, input: Any, resolverFactory: VariableResolverFactory, featureName: String = ""): Option[AnyRef] = { + def executeExpression(compiledExpression: Any, input: Any, resolverFactory: VariableResolverFactory, featureName: String = "", mvelContext: Option[FeathrExpressionExecutionContext]): Option[AnyRef] = { try { - Option(MvelContext.executeExpressionWithPluginSupport(compiledExpression, input, resolverFactory)) + Option(MvelContext.executeExpressionWithPluginSupportWithFactory(compiledExpression, input, resolverFactory, mvelContext.orNull)) } catch { case e: RuntimeException => log.debug(s"Expression $compiledExpression on input record $input threw exception", e) diff --git a/src/main/scala/com/linkedin/feathr/offline/mvel/plugins/FeathrExpressionExecutionContext.scala b/src/main/scala/com/linkedin/feathr/offline/mvel/plugins/FeathrExpressionExecutionContext.scala new file mode 100644 index 000000000..67371464f --- /dev/null +++ b/src/main/scala/com/linkedin/feathr/offline/mvel/plugins/FeathrExpressionExecutionContext.scala @@ -0,0 +1,144 @@ +package com.linkedin.feathr.offline.mvel.plugins + +import com.linkedin.feathr.common.FeatureValue +import org.apache.spark.SparkContext +import org.apache.spark.broadcast.Broadcast +import org.mvel2.ConversionHandler +import org.mvel2.conversion.ArrayHandler +import org.mvel2.util.ReflectionUtil.{isAssignableFrom, toNonPrimitiveType} + +import java.io.Serializable +import scala.collection.mutable + +/** + * The context needed for the Feathr expression transformation language, in order to + * support the automatic conversion between the Feather Feature value class and + * some customized external data, e.g. 3rd-party feature value class. + * It is intended for advanced cases to enable compatibility with old versions of FeathrExpression language + * and that most users would not need to use it. + */ +class FeathrExpressionExecutionContext extends Serializable { + + // A map of converters that are registered to convert a class into customized data format. + // This include convert from and to feature value. + // The Map is broadcast from the driver to executors + private var converters: Broadcast[mutable.HashMap[String, ConversionHandler]] = null + // A map of adaptors that are registered to convert a Feathr FeatureValue to customized external data format + // The Map is broadcasted from the driver to executors + private var featureValueTypeAdaptors: Broadcast[mutable.HashMap[String, FeatureValueTypeAdaptor[AnyRef]]] = null + + // Same as converters, used to build the map on the driver during the job initialization. + // Will be broadcasted to all executors and available as converters + private val localConverters = new mutable.HashMap[String, ConversionHandler] + // Same as featureValueTypeAdaptors, used to build the map on the driver during the job initialization. + // Will be broadcasted to all executors and available as converters + private val localFeatureValueTypeAdaptors = new mutable.HashMap[String, FeatureValueTypeAdaptor[AnyRef]] + + /** + * Setup Executor Mvel Expression Context by adding a type adaptor to Feathr's MVEL runtime, + * it will enable Feathr's expressions to support some alternative + * class representation of {@link FeatureValue} via coercion. + * + * @param clazz the class of the "other" alternative representation of feature value + * @param typeAdaptor the type adaptor that can convert between the "other" representation and {@link FeatureValue} + * @param < T> type parameter for the "other" feature value class + */ + def setupExecutorMvelContext[T](clazz: Class[T], typeAdaptor: FeatureValueTypeAdaptor[T], sc: SparkContext): Unit = { + localFeatureValueTypeAdaptors.put(clazz.getCanonicalName, typeAdaptor.asInstanceOf[FeatureValueTypeAdaptor[AnyRef]]) + featureValueTypeAdaptors = sc.broadcast(localFeatureValueTypeAdaptors) + // Add a converter that can convert external data to feature value + addConversionHandler(classOf[FeatureValue], new ExternalDataToFeatureValueHandler(featureValueTypeAdaptors), sc) + // Add a converter that can convert a feature value to external data + addConversionHandler(clazz, new FeatureValueToExternalDataHandler(typeAdaptor), sc) + } + + /** + * Check if there is registered converters that can handle the conversion. + * @param toType type to convert to + * @param convertFrom type to convert from + * @return whether it can be converted or not + */ + def canConvert(toType: Class[_], convertFrom: Class[_]): Boolean = { + if (isAssignableFrom(toType, convertFrom)) return true + if (converters.value.contains(toType.getCanonicalName)) { + converters.value.get(toType.getCanonicalName).get.canConvertFrom(toNonPrimitiveType(convertFrom)) + } else if (toType.isArray && canConvert(toType.getComponentType, convertFrom)) { + true + } else { + false + } + } + + /** + * Convert the input to output type using the registered converters + * @param in value to be converted + * @param toType output type + * @tparam T + * @return + */ + def convert[T](in: Any, toType: Class[T]): T = { + if ((toType eq in.getClass) || toType.isAssignableFrom(in.getClass)) return in.asInstanceOf[T] + val converter = if (converters.value != null) { + converters.value.get(toType.getCanonicalName).get + } else { + throw new RuntimeException(s"Cannot convert ${in} to ${toType} due to no converters found.") + } + if (converter == null && toType.isArray) { + val handler = new ArrayHandler(toType) + converters.value.put(toType.getCanonicalName, handler) + handler.convertFrom(in).asInstanceOf[T] + } + else converter.convertFrom(in).asInstanceOf[T] + } + + /** + * Register a new {@link ConversionHandler} with the factory. + * + * @param type - Target type represented by the conversion handler. + * @param handler - An instance of the handler. + */ + private[plugins] def addConversionHandler(`type`: Class[_], handler: ConversionHandler, sc: SparkContext): Unit = { + localConverters.put(`type`.getCanonicalName, handler) + converters = sc.broadcast( localConverters) + } + + /** + * Convert Feathr FeatureValue to external FeatureValue + * @param adaptor An adaptor that knows how to convert the Feathr feature value to requested external data + */ + class FeatureValueToExternalDataHandler(val adaptor: FeatureValueTypeAdaptor[_]) + extends ConversionHandler with Serializable { + /** + * Convert a FeatureValue into requested external data + * @param fv the input feature value + * @return requested external data + */ + override def convertFrom(fv: Any): AnyRef = adaptor.fromFeathrFeatureValue(fv.asInstanceOf[FeatureValue]).asInstanceOf[AnyRef] + + override def canConvertFrom(cls: Class[_]): Boolean = classOf[FeatureValue] == cls + } + + + /** + * Convert external data types to Feathr FeatureValue automatically + * @param adaptors a map of adaptors that knows how to convert external data to feature value. + * It maps the supported input class name to its adaptor. + */ + class ExternalDataToFeatureValueHandler(val adaptors: Broadcast[mutable.HashMap[String, FeatureValueTypeAdaptor[AnyRef]]]) + extends ConversionHandler with Serializable { + + /** + * Convert external data to a Feature value + * + * @param externalData to convert + * @return result feature value + */ + def convertFrom(externalData: Any): AnyRef = { + val adaptor = adaptors.value.get(externalData.getClass.getCanonicalName).get + if (adaptor == null) throw new IllegalArgumentException("Can't convert to Feathr FeatureValue from " + externalData + ", current type adaptors: " + adaptors.value.keySet.mkString(",")) + adaptor.toFeathrFeatureValue(externalData.asInstanceOf[AnyRef]) + } + + override def canConvertFrom(cls: Class[_]): Boolean = adaptors.value.contains(cls.getCanonicalName) + } +} diff --git a/src/main/scala/com/linkedin/feathr/offline/mvel/plugins/FeathrMvelPluginContext.java b/src/main/scala/com/linkedin/feathr/offline/mvel/plugins/FeathrMvelPluginContext.java deleted file mode 100644 index b672007bc..000000000 --- a/src/main/scala/com/linkedin/feathr/offline/mvel/plugins/FeathrMvelPluginContext.java +++ /dev/null @@ -1,79 +0,0 @@ -package com.linkedin.feathr.offline.mvel.plugins; - -import com.linkedin.feathr.common.FeatureValue; -import com.linkedin.feathr.common.InternalApi; -import org.mvel2.ConversionHandler; -import org.mvel2.DataConversion; - -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.ConcurrentMap; - - -/** - * A plugin that allows an advanced user to add additional capabilities or behaviors to Feathr's MVEL runtime. - * - * NOTE: This class is intended for advanced users only, and specifically as a "migration aid" for migrating from - * some previous versions of Feathr whose FeatureValue representations had a different class name, while preserving - * compatibility with feature definitions written against those older versions of Feathr. - */ -public class FeathrMvelPluginContext { - // TODO: Does this need to be "translated" into a different pattern whereby we track the CLASSNAME of the type adaptors - // instead of the instance, such that the class mappings can be broadcasted via Spark and then reinitialized on - // executor hosts? - private static final ConcurrentMap, FeatureValueTypeAdaptor> TYPE_ADAPTORS; - - static { - TYPE_ADAPTORS = new ConcurrentHashMap<>(); - DataConversion.addConversionHandler(FeatureValue.class, new FeathrFeatureValueConversionHandler()); - } - - /** - * Add a type adaptor to Feathr's MVEL runtime, that will enable Feathr's expressions to support some alternative - * class representation of {@link FeatureValue} via coercion. - * @param clazz the class of the "other" alternative representation of feature value - * @param typeAdaptor the type adaptor that can convert between the "other" representation and {@link FeatureValue} - * @param type parameter for the "other" feature value class - */ - @SuppressWarnings("unchecked") - public static void addFeatureTypeAdaptor(Class clazz, FeatureValueTypeAdaptor typeAdaptor) { - // TODO: MAKE SURE clazz IS NOT ONE OF THE CLASSES ALREADY COVERED IN org.mvel2.DataConversion.CONVERTERS! - // IF WE OVERRIDE ANY OF THOSE, IT MIGHT CAUSE MVEL TO BEHAVE IN STRANGE AND UNEXPECTED WAYS! - TYPE_ADAPTORS.put(clazz, typeAdaptor); - DataConversion.addConversionHandler(clazz, new ExternalFeatureValueConversionHandler(typeAdaptor)); - } - - static class FeathrFeatureValueConversionHandler implements ConversionHandler { - @Override - @SuppressWarnings("unchecked") - public Object convertFrom(Object in) { - FeatureValueTypeAdaptor adaptor = (FeatureValueTypeAdaptor) TYPE_ADAPTORS.get(in.getClass()); - if (adaptor == null) { - throw new IllegalArgumentException("Can't convert to Feathr FeatureValue from " + in); - } - return adaptor.toFeathrFeatureValue(in); - } - - @Override - public boolean canConvertFrom(Class cls) { - return TYPE_ADAPTORS.containsKey(cls); - } - } - - static class ExternalFeatureValueConversionHandler implements ConversionHandler { - private final FeatureValueTypeAdaptor _adaptor; - - public ExternalFeatureValueConversionHandler(FeatureValueTypeAdaptor adaptor) { - _adaptor = adaptor; - } - - @Override - public Object convertFrom(Object in) { - return _adaptor.fromFeathrFeatureValue((FeatureValue) in); - } - - @Override - public boolean canConvertFrom(Class cls) { - return FeatureValue.class.equals(cls); - } - } -} diff --git a/src/main/scala/com/linkedin/feathr/offline/transformation/DataFrameBasedRowEvaluator.scala b/src/main/scala/com/linkedin/feathr/offline/transformation/DataFrameBasedRowEvaluator.scala index 0bdb013d1..d242372bf 100644 --- a/src/main/scala/com/linkedin/feathr/offline/transformation/DataFrameBasedRowEvaluator.scala +++ b/src/main/scala/com/linkedin/feathr/offline/transformation/DataFrameBasedRowEvaluator.scala @@ -5,7 +5,9 @@ import com.linkedin.feathr.common.tensor.TensorData import com.linkedin.feathr.common.{AnchorExtractor, FeatureTypeConfig, FeatureTypes, SparkRowExtractor} import com.linkedin.feathr.offline import com.linkedin.feathr.offline.FeatureDataFrame +import com.linkedin.feathr.offline.anchored.anchorExtractor.SimpleConfigurableAnchorExtractor import com.linkedin.feathr.offline.job.{FeatureTransformation, FeatureTypeInferenceContext, TransformedResult} +import com.linkedin.feathr.offline.mvel.plugins.FeathrExpressionExecutionContext import com.linkedin.feathr.offline.util.FeaturizedDatasetUtils import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema @@ -32,7 +34,8 @@ private[offline] object DataFrameBasedRowEvaluator { def transform(transformer: AnchorExtractor[_], inputDf: DataFrame, requestedFeatureNameAndPrefix: Seq[(String, String)], - featureTypeConfigs: Map[String, FeatureTypeConfig]): TransformedResult = { + featureTypeConfigs: Map[String, FeatureTypeConfig], + mvelContext: Option[FeathrExpressionExecutionContext]): TransformedResult = { if (!transformer.isInstanceOf[SparkRowExtractor]) { throw new FeathrException(ErrorLabel.FEATHR_USER_ERROR, s"${transformer} must extend SparkRowExtractor.") } @@ -42,7 +45,7 @@ private[offline] object DataFrameBasedRowEvaluator { val featureFormat = FeatureColumnFormat.FDS_TENSOR // features to calculate, if empty, will calculate all features defined in the extractor val selectedFeatureNames = if (requestedFeatureRefString.nonEmpty) requestedFeatureRefString else transformer.getProvidedFeatureNames - val FeatureDataFrame(transformedDF, transformedFeatureTypes) = transformToFDSTensor(extractor, inputDf, selectedFeatureNames, featureTypeConfigs) + val FeatureDataFrame(transformedDF, transformedFeatureTypes) = transformToFDSTensor(extractor, inputDf, selectedFeatureNames, featureTypeConfigs, mvelContext) TransformedResult( // Re-compute the featureNamePrefixPairs because feature names can be coming from the extractor. selectedFeatureNames.map((_, featureNamePrefix)), @@ -64,7 +67,8 @@ private[offline] object DataFrameBasedRowEvaluator { private def transformToFDSTensor(rowExtractor: SparkRowExtractor, inputDF: DataFrame, featureRefStrs: Seq[String], - featureTypeConfigs: Map[String, FeatureTypeConfig]): FeatureDataFrame = { + featureTypeConfigs: Map[String, FeatureTypeConfig], + mvelContext: Option[FeathrExpressionExecutionContext]): FeatureDataFrame = { val inputSchema = inputDF.schema val spark = SparkSession.builder().getOrCreate() val featureTypes = featureTypeConfigs.mapValues(_.getFeatureType) @@ -78,6 +82,9 @@ private[offline] object DataFrameBasedRowEvaluator { } else { new GenericRowWithSchema(row.toSeq.toArray, inputSchema) } + if (rowExtractor.isInstanceOf[SimpleConfigurableAnchorExtractor]) { + rowExtractor.asInstanceOf[SimpleConfigurableAnchorExtractor].mvelContext = mvelContext + } val result = rowExtractor.getFeaturesFromRow(rowWithSchema) val featureValues = featureRefStrs map { featureRef => diff --git a/src/main/scala/com/linkedin/feathr/offline/util/FeathrTestUtils.scala b/src/main/scala/com/linkedin/feathr/offline/util/FeathrTestUtils.scala index 23fca857f..47af7d5b1 100644 --- a/src/main/scala/com/linkedin/feathr/offline/util/FeathrTestUtils.scala +++ b/src/main/scala/com/linkedin/feathr/offline/util/FeathrTestUtils.scala @@ -1,11 +1,10 @@ package com.linkedin.feathr.offline.util -import org.apache.spark.sql.internal.SQLConf -import Transformations.sortColumns import com.linkedin.feathr.offline.config.datasource.{DataSourceConfigUtils, DataSourceConfigs} -import com.linkedin.feathr.offline.job.FeatureGenJob +import com.linkedin.feathr.offline.util.Transformations.sortColumns import org.apache.avro.generic.GenericRecord import org.apache.spark.SparkConf +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.{DataFrame, Row, SparkSession} private[offline] object FeathrTestUtils { diff --git a/src/main/scala/com/linkedin/feathr/offline/util/SourceUtils.scala b/src/main/scala/com/linkedin/feathr/offline/util/SourceUtils.scala index 211ef7e46..e9d3a2bf1 100644 --- a/src/main/scala/com/linkedin/feathr/offline/util/SourceUtils.scala +++ b/src/main/scala/com/linkedin/feathr/offline/util/SourceUtils.scala @@ -9,15 +9,16 @@ import com.linkedin.feathr.common.{AnchorExtractor, DateParam} import com.linkedin.feathr.offline.client.InputData import com.linkedin.feathr.offline.config.location.{DataLocation, SimplePath} import com.linkedin.feathr.offline.generation.SparkIOUtils +import com.linkedin.feathr.offline.mvel.plugins.FeathrExpressionExecutionContext import com.linkedin.feathr.offline.mvel.{MvelContext, MvelUtils} import com.linkedin.feathr.offline.source.SourceFormatType import com.linkedin.feathr.offline.source.SourceFormatType.SourceFormatType +import com.linkedin.feathr.offline.source.dataloader.DataLoaderHandler import com.linkedin.feathr.offline.source.dataloader.hdfs.FileFormat import com.linkedin.feathr.offline.source.dataloader.jdbc.JdbcUtils import com.linkedin.feathr.offline.source.pathutil.{PathChecker, TimeBasedHdfsPathAnalyzer, TimeBasedHdfsPathGenerator} import com.linkedin.feathr.offline.util.AclCheckUtils.getLatestPath import com.linkedin.feathr.offline.util.datetime.OfflineDateTimeUtils -import com.linkedin.feathr.offline.source.dataloader.DataLoaderHandler import org.apache.avro.generic.GenericData.{Array, Record} import org.apache.avro.generic.{GenericDatumReader, GenericRecord, IndexedRecord} import org.apache.avro.io.DecoderFactory @@ -29,14 +30,13 @@ import org.apache.avro.{Schema, SchemaBuilder} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.io.NullWritable -import org.apache.hadoop.mapreduce.Job import org.apache.hadoop.mapred.JobConf +import org.apache.hadoop.mapreduce.Job import org.apache.log4j.Logger import org.apache.spark.rdd.RDD import org.apache.spark.sql._ import org.apache.spark.sql.avro.SchemaConverters import org.apache.spark.sql.types.StructType -import org.codehaus.jackson.JsonNode import org.joda.time.{Days, Hours, Interval, DateTime => JodaDateTime, DateTimeZone => JodaTimeZone} import org.mvel2.MVEL @@ -234,51 +234,6 @@ private[offline] object SourceUtils { field.defaultVal() } - /* Defines a symmetric relationship for two keys regarding to that target fields, for example - * ( (viewerId, vieweeId), affinity ) <=> ( (vieweedId, viewerId), affinity ), so in the dataset, - * they are only stored once on HDFS, here this operation should generate the full data - */ - def getRDDViewSymmKeys(rawRDD: RDD[_], targetFields: Option[Seq[String]], otherFields: Option[Seq[String]] = None): RDD[_] = { - - val symmKeys = targetFields match { - case Some(v: Seq[String]) => v - case None => - throw new FeathrConfigException( - ErrorLabel.FEATHR_USER_ERROR, - s"Trying to get symmetric RDD view. Symmetric keys are not defined. Please provide targetFields fields.") - } - - if (symmKeys.size != 2) { - throw new FeathrConfigException( - ErrorLabel.FEATHR_USER_ERROR, - s"Trying to get symmetric RDD view. Symmetric keys (targetFields) must have size of two, found ${symmKeys.size}." + - s" Please provide the targetFields.") - } - - val otherFeatures = otherFields match { - case Some(v: Seq[String]) => v - case None => - throw new FeathrConfigException( - ErrorLabel.FEATHR_USER_ERROR, - s"Trying to get symmetric RDD view. Oother feature fields are not defined. Please provide other feature fields.") - } - - val allFields = (otherFeatures ++ symmKeys).distinct - val extractorForFields = extractorForFieldNames(allFields) - - val rddView = rawRDD.flatMap(record => { - val extractedRecord: Map[String, Any] = extractorForFields(record) - val symmKeyVal0 = extractedRecord(symmKeys(0)) - val symmKeyVal1 = extractedRecord(symmKeys(1)) - // to create the symmetric version of the data (swapping the two keys) - // procedure: remove the original keys from the Map and then add the symmetric pairs - val extractedRecordDup = extractedRecord - symmKeys(0) - symmKeys(1) + (symmKeys(0) -> symmKeyVal1, symmKeys(1) -> symmKeyVal0) - Seq(extractedRecord.asJava, extractedRecordDup.asJava) - }) - - rddView - } - /** * Get the needed fact/feature dataset for a feature anchor as a DataFrame. * @param ss Spark Session @@ -435,7 +390,7 @@ private[offline] object SourceUtils { /* * Given a sequence of field names, return the corresponding field, must be the top level */ - private def extractorForFieldNames(allFields: Seq[String]): Any => Map[String, Any] = { + private def extractorForFieldNames(allFields: Seq[String], mvelContext: Option[FeathrExpressionExecutionContext]): Any => Map[String, Any] = { val compiledExpressionMap = allFields .map( fieldName => @@ -446,7 +401,7 @@ private[offline] object SourceUtils { compiledExpressionMap .mapValues(expression => { MvelContext.ensureInitialized() - MvelUtils.executeExpression(expression, record, null) + MvelUtils.executeExpression(expression, record, null, "", mvelContext) }) .collect { case (name, Some(value)) => (name, value) } .toMap diff --git a/src/test/java/com/linkedin/feathr/offline/plugins/AlienFeatureValueTypeAdaptor.java b/src/test/java/com/linkedin/feathr/offline/plugins/AlienFeatureValueTypeAdaptor.java index fd771fe73..bbe48f850 100644 --- a/src/test/java/com/linkedin/feathr/offline/plugins/AlienFeatureValueTypeAdaptor.java +++ b/src/test/java/com/linkedin/feathr/offline/plugins/AlienFeatureValueTypeAdaptor.java @@ -4,7 +4,9 @@ import com.linkedin.feathr.common.types.NumericFeatureType; import com.linkedin.feathr.offline.mvel.plugins.FeatureValueTypeAdaptor; -public class AlienFeatureValueTypeAdaptor implements FeatureValueTypeAdaptor { +import java.io.Serializable; + +public class AlienFeatureValueTypeAdaptor implements FeatureValueTypeAdaptor, Serializable { @Override public FeatureValue toFeathrFeatureValue(AlienFeatureValue other) { if (other.isFloat()) { diff --git a/src/test/scala/com/linkedin/feathr/offline/AnchoredFeaturesIntegTest.scala b/src/test/scala/com/linkedin/feathr/offline/AnchoredFeaturesIntegTest.scala index 3ca387b55..061b42598 100644 --- a/src/test/scala/com/linkedin/feathr/offline/AnchoredFeaturesIntegTest.scala +++ b/src/test/scala/com/linkedin/feathr/offline/AnchoredFeaturesIntegTest.scala @@ -5,8 +5,6 @@ import com.linkedin.feathr.common.exception.FeathrConfigException import com.linkedin.feathr.offline.config.location.SimplePath import com.linkedin.feathr.offline.generation.SparkIOUtils import com.linkedin.feathr.offline.job.PreprocessedDataFrameManager -import com.linkedin.feathr.offline.mvel.plugins.FeathrMvelPluginContext -import com.linkedin.feathr.offline.plugins.{AlienFeatureValue, AlienFeatureValueTypeAdaptor} import com.linkedin.feathr.offline.source.dataloader.{AvroJsonDataLoader, CsvDataLoader} import com.linkedin.feathr.offline.util.FeathrTestUtils import org.apache.spark.sql.Row diff --git a/src/test/scala/com/linkedin/feathr/offline/FeathrIntegTest.scala b/src/test/scala/com/linkedin/feathr/offline/FeathrIntegTest.scala index dc6078d13..13bf5578e 100644 --- a/src/test/scala/com/linkedin/feathr/offline/FeathrIntegTest.scala +++ b/src/test/scala/com/linkedin/feathr/offline/FeathrIntegTest.scala @@ -2,6 +2,7 @@ package com.linkedin.feathr.offline import com.linkedin.feathr.common.TaggedFeatureName import com.linkedin.feathr.offline.job.{LocalFeatureGenJob, LocalFeatureJoinJob} +import com.linkedin.feathr.offline.mvel.plugins.FeathrExpressionExecutionContext import com.linkedin.feathr.offline.util.{FeathrTestUtils, SparkFeaturizedDataset} import org.apache.hadoop.conf.Configuration import org.apache.spark.sql.SparkSession @@ -36,8 +37,9 @@ abstract class FeathrIntegTest extends TestFeathr { joinConfigAsString: String, featureDefAsString: String, observationDataPath: String, - extraParams: Array[String] = Array()): SparkFeaturizedDataset = { - LocalFeatureJoinJob.joinWithHoconJoinConfig(joinConfigAsString, featureDefAsString, observationDataPath, extraParams, dataPathHandlers=List()) + extraParams: Array[String] = Array(), + mvelContext: Option[FeathrExpressionExecutionContext] = None): SparkFeaturizedDataset = { + LocalFeatureJoinJob.joinWithHoconJoinConfig(joinConfigAsString, featureDefAsString, observationDataPath, extraParams, dataPathHandlers=List(), mvelContext=mvelContext) } def getOrCreateSparkSession: SparkSession = { diff --git a/src/test/scala/com/linkedin/feathr/offline/TestFeathr.scala b/src/test/scala/com/linkedin/feathr/offline/TestFeathr.scala index bb451c6a5..f052663e3 100644 --- a/src/test/scala/com/linkedin/feathr/offline/TestFeathr.scala +++ b/src/test/scala/com/linkedin/feathr/offline/TestFeathr.scala @@ -4,7 +4,7 @@ import com.linkedin.feathr.common import com.linkedin.feathr.common.JoiningFeatureParams import com.linkedin.feathr.offline.client.FeathrClient import com.linkedin.feathr.offline.config.{FeathrConfig, FeathrConfigLoader} -import com.linkedin.feathr.offline.mvel.plugins.FeathrMvelPluginContext +import com.linkedin.feathr.offline.mvel.plugins.FeathrExpressionExecutionContext import com.linkedin.feathr.offline.plugins.{AlienFeatureValue, AlienFeatureValueTypeAdaptor} import com.linkedin.feathr.offline.util.FeathrTestUtils import org.apache.avro.generic.GenericRecord @@ -23,14 +23,16 @@ abstract class TestFeathr extends TestNGSuite { protected var feathr: FeathrClient = _ val FeathrFeatureNamePrefix = "__feathr_feature_" protected var feathrConfigLoader: FeathrConfig = FeathrConfigLoader() + + private val mvelContext = new FeathrExpressionExecutionContext() import org.apache.log4j.{Level, Logger} Logger.getLogger("org").setLevel(Level.OFF) Logger.getLogger("akka").setLevel(Level.OFF) @BeforeClass def setup(): Unit = { - FeathrMvelPluginContext.addFeatureTypeAdaptor(classOf[AlienFeatureValue], new AlienFeatureValueTypeAdaptor) setupSpark() + mvelContext.setupExecutorMvelContext(classOf[AlienFeatureValue], new AlienFeatureValueTypeAdaptor(), ss.sparkContext) } /** diff --git a/src/test/scala/com/linkedin/feathr/offline/TestFeathrUdfPlugins.scala b/src/test/scala/com/linkedin/feathr/offline/TestFeathrUdfPlugins.scala index 68ead2408..63637a989 100644 --- a/src/test/scala/com/linkedin/feathr/offline/TestFeathrUdfPlugins.scala +++ b/src/test/scala/com/linkedin/feathr/offline/TestFeathrUdfPlugins.scala @@ -4,7 +4,7 @@ import com.linkedin.feathr.common.FeatureTypes import com.linkedin.feathr.offline.anchored.keyExtractor.AlienSourceKeyExtractorAdaptor import com.linkedin.feathr.offline.client.plugins.FeathrUdfPluginContext import com.linkedin.feathr.offline.derived.AlienDerivationFunctionAdaptor -import com.linkedin.feathr.offline.mvel.plugins.FeathrMvelPluginContext +import com.linkedin.feathr.offline.mvel.plugins.FeathrExpressionExecutionContext import com.linkedin.feathr.offline.plugins.{AlienFeatureValue, AlienFeatureValueTypeAdaptor} import com.linkedin.feathr.offline.util.FeathrTestUtils import org.apache.spark.sql.Row @@ -16,11 +16,12 @@ class TestFeathrUdfPlugins extends FeathrIntegTest { val MULTILINE_QUOTE = "\"\"\"" + private val mvelContext = new FeathrExpressionExecutionContext() @Test def testMvelUdfPluginSupport: Unit = { - FeathrMvelPluginContext.addFeatureTypeAdaptor(classOf[AlienFeatureValue], new AlienFeatureValueTypeAdaptor()) - FeathrUdfPluginContext.registerUdfAdaptor(new AlienDerivationFunctionAdaptor()) - FeathrUdfPluginContext.registerUdfAdaptor(new AlienSourceKeyExtractorAdaptor()) + mvelContext.setupExecutorMvelContext(classOf[AlienFeatureValue], new AlienFeatureValueTypeAdaptor(), ss.sparkContext) + FeathrUdfPluginContext.registerUdfAdaptor(new AlienDerivationFunctionAdaptor(), ss.sparkContext) + FeathrUdfPluginContext.registerUdfAdaptor(new AlienSourceKeyExtractorAdaptor(), ss.sparkContext) val df = runLocalFeatureJoinForTest( joinConfigAsString = """ | features: { @@ -107,7 +108,8 @@ class TestFeathrUdfPlugins extends FeathrIntegTest { | } |} """.stripMargin, - observationDataPath = "anchorAndDerivations/testMVELLoopExpFeature-observations.csv") + observationDataPath = "anchorAndDerivations/testMVELLoopExpFeature-observations.csv", + mvelContext = Some(mvelContext)) val f8Type = df.fdsMetadata.header.get.featureInfoMap.filter(_._1.getFeatureName == "f8").head._2.featureType.getFeatureType assertEquals(f8Type, FeatureTypes.NUMERIC) diff --git a/src/test/scala/com/linkedin/feathr/offline/derived/TestSequentialJoinAsDerivation.scala b/src/test/scala/com/linkedin/feathr/offline/derived/TestSequentialJoinAsDerivation.scala index fd9f2d147..33e4ac822 100644 --- a/src/test/scala/com/linkedin/feathr/offline/derived/TestSequentialJoinAsDerivation.scala +++ b/src/test/scala/com/linkedin/feathr/offline/derived/TestSequentialJoinAsDerivation.scala @@ -10,6 +10,7 @@ import com.linkedin.feathr.offline.derived.strategies.SequentialJoinAsDerivation import com.linkedin.feathr.offline.job.FeatureTransformation.FEATURE_NAME_PREFIX import com.linkedin.feathr.offline.join.algorithms.{SeqJoinExplodedJoinKeyColumnAppender, SequentialJoinConditionBuilder, SparkJoinWithJoinCondition} import com.linkedin.feathr.offline.logical.FeatureGroups +import com.linkedin.feathr.offline.mvel.plugins.FeathrExpressionExecutionContext import com.linkedin.feathr.offline.{TestFeathr, TestUtils} import org.apache.log4j.{Level, Logger} import org.apache.spark.SparkException @@ -25,7 +26,7 @@ import org.apache.spark.sql.internal.SQLConf class TestSequentialJoinAsDerivation extends TestFeathr with MockitoSugar { Logger.getLogger("org").setLevel(Level.OFF) Logger.getLogger("akka").setLevel(Level.OFF) - + val mvelContext = new FeathrExpressionExecutionContext() private def getSampleEmployeeDF = { val schema = { StructType( @@ -997,7 +998,7 @@ class TestSequentialJoinAsDerivation extends TestFeathr with MockitoSugar { when(mockBaseTaggedDependency.outputKey).thenReturn(Some(Seq("outputKey1", "outputKey2"))) when(mockTaggedDependency.key).thenReturn(Seq("expansionKey1")) - seqJoinDerivations(Seq(0, 1, 2), Seq("keyTag1", "keyTag2", "keyTag3", "keyTag4"), ss.emptyDataFrame, mockDerivedFeature, mockDerivationFunction) + seqJoinDerivations(Seq(0, 1, 2), Seq("keyTag1", "keyTag2", "keyTag3", "keyTag4"), ss.emptyDataFrame, mockDerivedFeature, mockDerivationFunction, Some(mvelContext)) } /** @@ -1071,7 +1072,7 @@ class TestSequentialJoinAsDerivation extends TestFeathr with MockitoSugar { when(mockBaseTaggedDependency.outputKey).thenReturn(Some(Seq("outputKey1"))) when(mockTaggedDependency.key).thenReturn(Seq("expansionKey1")) - seqJoinDerivations(Seq(0, 1, 2), Seq("keyTag1", "keyTag2", "keyTag3", "keyTag4"), ss.emptyDataFrame, mockDerivedFeature, mockDerivationFunction) + seqJoinDerivations(Seq(0, 1, 2), Seq("keyTag1", "keyTag2", "keyTag3", "keyTag4"), ss.emptyDataFrame, mockDerivedFeature, mockDerivationFunction, Some(mvelContext)) } /** diff --git a/src/test/scala/com/linkedin/feathr/offline/join/workflow/TestAnchoredFeatureJoinStep.scala b/src/test/scala/com/linkedin/feathr/offline/join/workflow/TestAnchoredFeatureJoinStep.scala index dc699276d..f9ec6d50e 100644 --- a/src/test/scala/com/linkedin/feathr/offline/join/workflow/TestAnchoredFeatureJoinStep.scala +++ b/src/test/scala/com/linkedin/feathr/offline/join/workflow/TestAnchoredFeatureJoinStep.scala @@ -50,7 +50,7 @@ class TestAnchoredFeatureJoinStep extends TestFeathr with MockitoSugar { val mockAnchorStepInput = mock[AnchorJoinStepInput] when(mockAnchorStepInput.observation).thenReturn(ss.emptyDataFrame) val basicAnchoredFeatureJoinStep = - AnchoredFeatureJoinStep(SqlTransformedLeftJoinKeyColumnAppender, IdentityJoinKeyColumnAppender, SparkJoinWithJoinCondition(EqualityJoinConditionBuilder)) + AnchoredFeatureJoinStep(SqlTransformedLeftJoinKeyColumnAppender, IdentityJoinKeyColumnAppender, SparkJoinWithJoinCondition(EqualityJoinConditionBuilder), None) val FeatureDataFrameOutput(FeatureDataFrame(outputDF, inferredFeatureType)) = basicAnchoredFeatureJoinStep.joinFeatures(Seq(ErasedEntityTaggedFeature(Seq(0), "featureName1")), mockAnchorStepInput)(mockExecutionContext) @@ -77,7 +77,7 @@ class TestAnchoredFeatureJoinStep extends TestFeathr with MockitoSugar { KeyedTransformedResult(Seq("joinKey1", "joinKey2"), mockTransformedResult), KeyedTransformedResult(Seq("joinKey2", "joinKey3"), mockTransformedResult)) val basicAnchoredFeatureJoinStep = - AnchoredFeatureJoinStep(SqlTransformedLeftJoinKeyColumnAppender, IdentityJoinKeyColumnAppender, SparkJoinWithJoinCondition(EqualityJoinConditionBuilder)) + AnchoredFeatureJoinStep(SqlTransformedLeftJoinKeyColumnAppender, IdentityJoinKeyColumnAppender, SparkJoinWithJoinCondition(EqualityJoinConditionBuilder), None) basicAnchoredFeatureJoinStep.joinFeaturesOnSingleDF( Seq(0), Seq("leftJoinKeyColumn"), @@ -103,7 +103,7 @@ class TestAnchoredFeatureJoinStep extends TestFeathr with MockitoSugar { // observation DF val leftDF = getDefaultDataFrame() val basicAnchoredFeatureJoinStep = - AnchoredFeatureJoinStep(SqlTransformedLeftJoinKeyColumnAppender, IdentityJoinKeyColumnAppender, SparkJoinWithJoinCondition(EqualityJoinConditionBuilder)) + AnchoredFeatureJoinStep(SqlTransformedLeftJoinKeyColumnAppender, IdentityJoinKeyColumnAppender, SparkJoinWithJoinCondition(EqualityJoinConditionBuilder), None) val resultDF = basicAnchoredFeatureJoinStep.joinFeaturesOnSingleDF(Seq(0), Seq("x"), leftDF, (Seq("feature1", "feature2"), keyedTransformedResults))( mockExecutionContext) resultDF.show() @@ -141,7 +141,7 @@ class TestAnchoredFeatureJoinStep extends TestFeathr with MockitoSugar { // observation DF val leftDF = getDefaultDataFrame() - val basicAnchoredFeatureJoinStep = AnchoredFeatureJoinStep(SqlTransformedLeftJoinKeyColumnAppender, IdentityJoinKeyColumnAppender, mockJoiner) + val basicAnchoredFeatureJoinStep = AnchoredFeatureJoinStep(SqlTransformedLeftJoinKeyColumnAppender, IdentityJoinKeyColumnAppender, mockJoiner, None) val resultDF = basicAnchoredFeatureJoinStep.joinFeaturesOnSingleDF(Seq(0), Seq("x"), leftDF, (Seq("feature1", "feature2"), keyedTransformedResults))( mockExecutionContext) // Verify that the joiner was called by validating an empty DataFrame was indeed returned