Skip to content

Commit

Permalink
Fix Feature value adaptor and UDF adaptor on Spark executors (#660)
Browse files Browse the repository at this point in the history
* Fix Feature value adaptor and UDF adaptor on Spark executors

* Fix path with #LATEST

* Add comments

* Defer version bump
  • Loading branch information
jaymo001 authored Sep 23, 2022
1 parent 6035f04 commit 25aa097
Show file tree
Hide file tree
Showing 32 changed files with 362 additions and 262 deletions.
6 changes: 5 additions & 1 deletion build.sbt
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
import sbt.Keys.publishLocalConfiguration

ThisBuild / resolvers += Resolver.mavenLocal
ThisBuild / scalaVersion := "2.12.15"
ThisBuild / version := "0.7.2"
ThisBuild / organization := "com.linkedin.feathr"
ThisBuild / organizationName := "linkedin"
val sparkVersion = "3.1.3"

publishLocalConfiguration := publishLocalConfiguration.value.withOverwrite(true)

val localAndCloudDiffDependencies = Seq(
"org.apache.spark" %% "spark-avro" % sparkVersion,
"org.apache.spark" %% "spark-sql" % sparkVersion,
Expand Down Expand Up @@ -101,4 +105,4 @@ assembly / assemblyMergeStrategy := {
// Some systems (like Hadoop) use different versions of protobuf (like v2), so we have to shade it.
assemblyShadeRules in assembly := Seq(
ShadeRule.rename("com.google.protobuf.**" -> "shade.protobuf.@1").inAll,
)
)
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
package com.linkedin.feathr.offline

import java.io.Serializable

import com.linkedin.feathr.common
import com.linkedin.feathr.common.{FeatureTypes, FeatureValue}
import com.linkedin.feathr.offline.exception.FeatureTransformationException
import com.linkedin.feathr.offline.mvel.plugins.FeathrExpressionExecutionContext
import com.linkedin.feathr.offline.mvel.{FeatureVariableResolverFactory, MvelContext}
import com.linkedin.feathr.offline.transformation.MvelDefinition
import com.linkedin.feathr.offline.util.{CoercionUtilsScala, FeaturizedDatasetUtils}
Expand Down Expand Up @@ -32,9 +32,9 @@ private[offline] object PostTransformationUtil {
* @param input input feature value
* @return transformed feature value
*/
def booleanTransformer(featureName: String, mvelExpression: MvelDefinition, compiledExpression: Serializable, input: Boolean): Boolean = {
def booleanTransformer(featureName: String, mvelExpression: MvelDefinition, compiledExpression: Serializable, input: Boolean, mvelContext: Option[FeathrExpressionExecutionContext]): Boolean = {
val toFeatureValue = common.FeatureValue.createBoolean(input)
val transformedFeatureValue = transformFeatureValues(featureName, toFeatureValue, compiledExpression, FeatureTypes.TERM_VECTOR)
val transformedFeatureValue = transformFeatureValues(featureName, toFeatureValue, compiledExpression, FeatureTypes.TERM_VECTOR, mvelContext)
transformedFeatureValue match {
case Success(fVal) => fVal.getAsTermVector.containsKey("true")
case Failure(ex) =>
Expand All @@ -57,12 +57,12 @@ private[offline] object PostTransformationUtil {
featureName: String,
mvelExpression: MvelDefinition,
compiledExpression: Serializable,
input: GenericRowWithSchema): Map[String, Float] = {
input: GenericRowWithSchema, mvelContext: Option[FeathrExpressionExecutionContext]): Map[String, Float] = {
if (input != null) {
val inputMapKey = input.getAs[Seq[String]](FeaturizedDatasetUtils.FDS_1D_TENSOR_DIM)
val inputMapVal = input.getAs[Seq[Float]](FeaturizedDatasetUtils.FDS_1D_TENSOR_VALUE)
val inputMap = inputMapKey.zip(inputMapVal).toMap
mapTransformer(featureName, mvelExpression, compiledExpression, inputMap)
mapTransformer(featureName, mvelExpression, compiledExpression, inputMap, mvelContext)
} else Map()
}

Expand All @@ -79,7 +79,8 @@ private[offline] object PostTransformationUtil {
featureNameColumnTuples: Seq[(String, String)],
contextDF: DataFrame,
transformationDef: Map[String, MvelDefinition],
defaultTransformation: (DataType, String) => Column): DataFrame = {
defaultTransformation: (DataType, String) => Column,
mvelContext: Option[FeathrExpressionExecutionContext]): DataFrame = {
val featureColumnNames = featureNameColumnTuples.map(_._2)

// Transform the features with the provided transformations
Expand All @@ -93,11 +94,11 @@ private[offline] object PostTransformationUtil {
val parserContext = MvelContext.newParserContext()
val compiledExpression = MVEL.compileExpression(mvelExpressionDef.mvelDef, parserContext)
val featureType = mvelExpressionDef.featureType
val convertToString = udf(stringTransformer(featureName, mvelExpressionDef, compiledExpression, _: String))
val convertToBoolean = udf(booleanTransformer(featureName, mvelExpressionDef, compiledExpression, _: Boolean))
val convertToFloat = udf(floatTransformer(featureName, mvelExpressionDef, compiledExpression, _: Float))
val convertToMap = udf(mapTransformer(featureName, mvelExpressionDef, compiledExpression, _: Map[String, Float]))
val convertFDS1dTensorToMap = udf(fds1dTensorTransformer(featureName, mvelExpressionDef, compiledExpression, _: GenericRowWithSchema))
val convertToString = udf(stringTransformer(featureName, mvelExpressionDef, compiledExpression, _: String, mvelContext))
val convertToBoolean = udf(booleanTransformer(featureName, mvelExpressionDef, compiledExpression, _: Boolean, mvelContext))
val convertToFloat = udf(floatTransformer(featureName, mvelExpressionDef, compiledExpression, _: Float, mvelContext))
val convertToMap = udf(mapTransformer(featureName, mvelExpressionDef, compiledExpression, _: Map[String, Float], mvelContext))
val convertFDS1dTensorToMap = udf(fds1dTensorTransformer(featureName, mvelExpressionDef, compiledExpression, _: GenericRowWithSchema, mvelContext))
fieldType.dataType match {
case _: StringType => convertToString(contextDF(columnName))
case _: NumericType => convertToFloat(contextDF(columnName))
Expand Down Expand Up @@ -126,16 +127,17 @@ private[offline] object PostTransformationUtil {
featureName: String,
featureValue: FeatureValue,
compiledExpression: Serializable,
featureType: FeatureTypes): Try[FeatureValue] = Try {
featureType: FeatureTypes,
mvelContext: Option[FeathrExpressionExecutionContext]): Try[FeatureValue] = Try {
val args = Map(featureName -> Some(featureValue))
val variableResolverFactory = new FeatureVariableResolverFactory(args)
val transformedValue = MvelContext.executeExpressionWithPluginSupport(compiledExpression, featureValue, variableResolverFactory)
val transformedValue = MvelContext.executeExpressionWithPluginSupportWithFactory(compiledExpression, featureValue, variableResolverFactory, mvelContext.orNull)
CoercionUtilsScala.coerceToFeatureValue(transformedValue, featureType)
}

private def floatTransformer(featureName: String, mvelExpression: MvelDefinition, compiledExpression: Serializable, input: Float): Float = {
private def floatTransformer(featureName: String, mvelExpression: MvelDefinition, compiledExpression: Serializable, input: Float, mvelContext: Option[FeathrExpressionExecutionContext]): Float = {
val toFeatureValue = common.FeatureValue.createNumeric(input)
val transformedFeatureValue = transformFeatureValues(featureName, toFeatureValue, compiledExpression, FeatureTypes.NUMERIC)
val transformedFeatureValue = transformFeatureValues(featureName, toFeatureValue, compiledExpression, FeatureTypes.NUMERIC, mvelContext)
transformedFeatureValue match {
case Success(fVal) => fVal.getAsNumeric
case Failure(ex) =>
Expand All @@ -146,9 +148,9 @@ private[offline] object PostTransformationUtil {
}
}

private def stringTransformer(featureName: String, mvelExpression: MvelDefinition, compiledExpression: Serializable, input: String): String = {
private def stringTransformer(featureName: String, mvelExpression: MvelDefinition, compiledExpression: Serializable, input: String, mvelContext: Option[FeathrExpressionExecutionContext]): String = {
val toFeatureValue = common.FeatureValue.createCategorical(input)
val transformedFeatureValue = transformFeatureValues(featureName, toFeatureValue, compiledExpression, FeatureTypes.CATEGORICAL)
val transformedFeatureValue = transformFeatureValues(featureName, toFeatureValue, compiledExpression, FeatureTypes.CATEGORICAL, mvelContext)
transformedFeatureValue match {
case Success(fVal) => fVal.getAsString
case Failure(ex) =>
Expand All @@ -163,12 +165,13 @@ private[offline] object PostTransformationUtil {
featureName: String,
mvelExpression: MvelDefinition,
compiledExpression: Serializable,
input: Map[String, Float]): Map[String, Float] = {
input: Map[String, Float],
mvelContext: Option[FeathrExpressionExecutionContext]): Map[String, Float] = {
if (input == null) {
return Map()
}
val toFeatureValue = new common.FeatureValue(input.asJava)
val transformedFeatureValue = transformFeatureValues(featureName, toFeatureValue, compiledExpression, FeatureTypes.TERM_VECTOR)
val transformedFeatureValue = transformFeatureValues(featureName, toFeatureValue, compiledExpression, FeatureTypes.TERM_VECTOR, mvelContext)
transformedFeatureValue match {
case Success(fVal) => fVal.getAsTermVector.asScala.map(kv => (kv._1.asInstanceOf[String], kv._2.asInstanceOf[Float])).toMap
case Failure(ex) =>
Expand Down
Original file line number Diff line number Diff line change
@@ -1,24 +1,22 @@
package com.linkedin.feathr.offline.anchored.anchorExtractor

import java.io.Serializable

import com.linkedin.feathr.offline.config.MVELFeatureDefinition
import com.linkedin.feathr.offline.mvel.{MvelContext, MvelUtils}
import org.mvel2.MVEL

import java.io.Serializable
import scala.collection.convert.wrapAll._

private[offline] class DebugMvelAnchorExtractor(keyExprs: Seq[String], features: Map[String, MVELFeatureDefinition])
extends SimpleConfigurableAnchorExtractor(keyExprs, features) {
private val debugExpressions = features.mapValues(value => findDebugExpressions(value.featureExpr)).map(identity)

private val debugCompiledExpressions = debugExpressions.mapValues(_.map(x => (x, compile(x)))).map(identity)

def evaluateDebugExpressions(input: Any): Map[String, Seq[(String, Any)]] = {
debugCompiledExpressions
.mapValues(_.map {
case (expr, compiled) =>
(expr, MvelUtils.executeExpression(compiled, input, null).orNull)
(expr, MvelUtils.executeExpression(compiled, input, null, "", None).orNull)
})
.map(identity)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import com.linkedin.feathr.common.util.CoercionUtils
import com.linkedin.feathr.common.{AnchorExtractor, FeatureTypeConfig, FeatureTypes, FeatureValue, SparkRowExtractor}
import com.linkedin.feathr.offline
import com.linkedin.feathr.offline.config.MVELFeatureDefinition
import com.linkedin.feathr.offline.mvel.plugins.FeathrExpressionExecutionContext
import com.linkedin.feathr.offline.mvel.{MvelContext, MvelUtils}
import com.linkedin.feathr.offline.util.FeatureValueTypeValidator
import org.apache.log4j.Logger
Expand All @@ -28,6 +29,7 @@ private[offline] class SimpleConfigurableAnchorExtractor( @JsonProperty("key") k
@JsonProperty("features") features: Map[String, MVELFeatureDefinition])
extends AnchorExtractor[Any] with SparkRowExtractor {

var mvelContext: Option[FeathrExpressionExecutionContext] = None
@transient private lazy val log = Logger.getLogger(getClass)

def getKeyExpression(): Seq[String] = key
Expand Down Expand Up @@ -73,7 +75,7 @@ private[offline] class SimpleConfigurableAnchorExtractor( @JsonProperty("key") k
// be more strict for resolving keys (don't swallow exceptions)
keyExpression.map(k =>
try {
Option(MvelContext.executeExpressionWithPluginSupport(k, datum)) match {
Option(MvelContext.executeExpressionWithPluginSupport(k, datum, mvelContext.orNull)) match {
case None => null
case Some(keys) => keys.toString
}
Expand All @@ -92,7 +94,7 @@ private[offline] class SimpleConfigurableAnchorExtractor( @JsonProperty("key") k

featureExpressions collect {
case (featureRefStr, (expression, featureType)) if selectedFeatures.contains(featureRefStr) =>
(featureRefStr, (MvelUtils.executeExpression(expression, datum, null, featureRefStr), featureType))
(featureRefStr, (MvelUtils.executeExpression(expression, datum, null, featureRefStr, mvelContext), featureType))
} collect {
// Apply a partial function only for non-empty feature values, empty feature values will be set to default later
case (featureRefStr, (Some(value), fType)) =>
Expand Down Expand Up @@ -165,7 +167,7 @@ private[offline] class SimpleConfigurableAnchorExtractor( @JsonProperty("key") k
* for building a tensor. Feature's value type and dimension type(s) are obtained via Feathr's Feature Metadata
* Library during tensor construction.
*/
(featureRefStr, MvelUtils.executeExpression(expression, datum, null, featureRefStr))
(featureRefStr, MvelUtils.executeExpression(expression, datum, null, featureRefStr, mvelContext))
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,10 @@ import com.linkedin.feathr.offline.generation.{DataFrameFeatureGenerator, Featur
import com.linkedin.feathr.offline.job._
import com.linkedin.feathr.offline.join.DataFrameFeatureJoiner
import com.linkedin.feathr.offline.logical.{FeatureGroups, MultiStageJoinPlanner}
import com.linkedin.feathr.offline.mvel.plugins.FeathrExpressionExecutionContext
import com.linkedin.feathr.offline.source.DataSource
import com.linkedin.feathr.offline.source.accessor.DataPathHandler
import com.linkedin.feathr.offline.util.{FeathrUtils, _}
import com.linkedin.feathr.offline.util._
import org.apache.log4j.Logger
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.internal.SQLConf
Expand All @@ -27,7 +28,7 @@ import scala.util.{Failure, Success}
*
*/
class FeathrClient private[offline] (sparkSession: SparkSession, featureGroups: FeatureGroups, logicalPlanner: MultiStageJoinPlanner,
featureGroupsUpdater: FeatureGroupsUpdater, dataPathHandlers: List[DataPathHandler]) {
featureGroupsUpdater: FeatureGroupsUpdater, dataPathHandlers: List[DataPathHandler], mvelContext: Option[FeathrExpressionExecutionContext]) {
private val log = Logger.getLogger(getClass)

type KeyTagStringTuple = Seq[String]
Expand Down Expand Up @@ -91,7 +92,7 @@ class FeathrClient private[offline] (sparkSession: SparkSession, featureGroups:
// Get logical plan
val logicalPlan = logicalPlanner.getLogicalPlan(featureGroups, keyTaggedRequiredFeatures)
// This pattern is consistent with the join use case which uses DataFrameFeatureJoiner.
val dataFrameFeatureGenerator = new DataFrameFeatureGenerator(logicalPlan=logicalPlan,dataPathHandlers=dataPathHandlers)
val dataFrameFeatureGenerator = new DataFrameFeatureGenerator(logicalPlan=logicalPlan,dataPathHandlers=dataPathHandlers, mvelContext)
val featureMap: Map[TaggedFeatureName, (DataFrame, Header)] =
dataFrameFeatureGenerator.generateFeaturesAsDF(sparkSession, featureGenSpec, featureGroups, keyTaggedRequiredFeatures)

Expand Down Expand Up @@ -263,7 +264,7 @@ class FeathrClient private[offline] (sparkSession: SparkSession, featureGroups:
s"Please rename feature ${conflictFeatureNames} or rename the same field names in the observation data.")
}

val joiner = new DataFrameFeatureJoiner(logicalPlan=logicalPlan,dataPathHandlers=dataPathHandlers)
val joiner = new DataFrameFeatureJoiner(logicalPlan=logicalPlan,dataPathHandlers=dataPathHandlers, mvelContext)
joiner.joinFeaturesAsDF(sparkSession, joinConfig, updatedFeatureGroups, keyTaggedFeatures, left, rowBloomFilterThreshold)
}

Expand Down Expand Up @@ -337,6 +338,7 @@ object FeathrClient {
private var localOverrideDefPath: List[String] = List()
private var featureDefConfs: List[FeathrConfig] = List()
private var dataPathHandlers: List[DataPathHandler] = List()
private var mvelContext: Option[FeathrExpressionExecutionContext] = None;


/**
Expand Down Expand Up @@ -495,6 +497,10 @@ object FeathrClient {
this.featureDefConfs = featureDefConfs
this
}
def addFeathrExpressionContext(_mvelContext: Option[FeathrExpressionExecutionContext]): Builder = {
this.mvelContext = _mvelContext
this
}

/**
* Build a new instance of the FeathrClient from the added feathr definition configs and any local overrides.
Expand Down Expand Up @@ -529,7 +535,7 @@ object FeathrClient {
featureDefConfigs = featureDefConfigs ++ featureDefConfs

val featureGroups = FeatureGroupsGenerator(featureDefConfigs, Some(localDefConfigs)).getFeatureGroups()
val feathrClient = new FeathrClient(sparkSession, featureGroups, MultiStageJoinPlanner(), FeatureGroupsUpdater(), dataPathHandlers)
val feathrClient = new FeathrClient(sparkSession, featureGroups, MultiStageJoinPlanner(), FeatureGroupsUpdater(), dataPathHandlers, mvelContext)

feathrClient
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
package com.linkedin.feathr.offline.client.plugins
import org.apache.spark.SparkContext
import org.apache.spark.broadcast.Broadcast

import scala.collection.mutable

Expand All @@ -9,15 +11,21 @@ import scala.collection.mutable
* All "external" UDF classes are required to have a public default zero-arg constructor.
*/
object FeathrUdfPluginContext {
  // Driver-side master list of registered adaptors; the broadcast below is rebuilt
  // from this buffer every time a new adaptor is registered.
  private val localRegisteredUdfAdaptors = mutable.Buffer[UdfAdaptor[_]]()
  // Broadcast copy of the adaptor list so lookups also work on Spark executors.
  // Remains null until the first adaptor is registered.
  private var registeredUdfAdaptors: Broadcast[mutable.Buffer[UdfAdaptor[_]]] = null

  /**
   * Register a UDF adaptor and publish the updated adaptor list to executors.
   *
   * The previous broadcast (if any) is destroyed before a fresh one is created, so
   * executors always resolve against the latest registration state. Synchronized to
   * keep the buffer mutation and broadcast swap atomic across driver threads.
   *
   * @param adaptor the adaptor to register
   * @param sc      SparkContext used to broadcast the updated adaptor list
   */
  def registerUdfAdaptor(adaptor: UdfAdaptor[_], sc: SparkContext): Unit = {
    this.synchronized {
      localRegisteredUdfAdaptors += adaptor
      // Drop the stale broadcast before re-broadcasting the updated buffer.
      if (registeredUdfAdaptors != null) {
        registeredUdfAdaptors.destroy()
      }
      registeredUdfAdaptors = sc.broadcast(localRegisteredUdfAdaptors)
    }
  }

  /**
   * Find a registered adaptor that can handle the given UDF class.
   *
   * Reads through the broadcast so this works on executors as well as the driver.
   *
   * @param clazz the UDF class to adapt
   * @return the first adaptor whose `canAdapt` accepts `clazz`, or None if no
   *         adaptor matches or none has been registered yet
   */
  def getRegisteredUdfAdaptor(clazz: Class[_]): Option[UdfAdaptor[_]] = {
    if (registeredUdfAdaptors != null) {
      registeredUdfAdaptors.value.find(_.canAdapt(clazz))
    } else None
  }
}
Loading

0 comments on commit 25aa097

Please sign in to comment.