Fix Feature value adaptor and UDF adaptor on Spark executors #660

Merged: 4 commits, Sep 23, 2022
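
In short: Feathr's feature-value adaptor and UDF adaptor were registered in driver-side singletons, so tasks running on Spark executors could not see them. A minimal sketch of that failure mode (illustrative names, not Feathr code):

```scala
// Each JVM (the driver and every executor) gets its own copy of a Scala
// `object`, so state mutated on the driver is invisible inside executor tasks.
object DriverOnlyRegistry {
  var adaptor: Option[String] = None // set on the driver only
}
// Driver:        DriverOnlyRegistry.adaptor = Some("configured")
// Executor task: DriverOnlyRegistry.adaptor == None   <- the bug
```

The diff below threads an optional FeathrExpressionExecutionContext through every MVEL call site, so it travels with the serialized closures, and broadcasts the UDF-adaptor registry via SparkContext.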
6 changes: 5 additions & 1 deletion build.sbt
@@ -1,10 +1,14 @@
import sbt.Keys.publishLocalConfiguration

ThisBuild / resolvers += Resolver.mavenLocal
ThisBuild / scalaVersion := "2.12.15"
ThisBuild / version := "0.7.2"
ThisBuild / organization := "com.linkedin.feathr"
ThisBuild / organizationName := "linkedin"
val sparkVersion = "3.1.3"

publishLocalConfiguration := publishLocalConfiguration.value.withOverwrite(true)

val localAndCloudDiffDependencies = Seq(
"org.apache.spark" %% "spark-avro" % sparkVersion,
"org.apache.spark" %% "spark-sql" % sparkVersion,
@@ -101,4 +105,4 @@ assembly / assemblyMergeStrategy := {
// Some systems (like Hadoop) use different versions of protobuf (like v2), so we have to shade it.
assemblyShadeRules in assembly := Seq(
ShadeRule.rename("com.google.protobuf.**" -> "shade.protobuf.@1").inAll,
)
)
PostTransformationUtil.scala
@@ -1,10 +1,10 @@
package com.linkedin.feathr.offline

import java.io.Serializable

import com.linkedin.feathr.common
import com.linkedin.feathr.common.{FeatureTypes, FeatureValue}
import com.linkedin.feathr.offline.exception.FeatureTransformationException
import com.linkedin.feathr.offline.mvel.plugins.FeathrExpressionExecutionContext
import com.linkedin.feathr.offline.mvel.{FeatureVariableResolverFactory, MvelContext}
import com.linkedin.feathr.offline.transformation.MvelDefinition
import com.linkedin.feathr.offline.util.{CoercionUtilsScala, FeaturizedDatasetUtils}
@@ -32,9 +32,9 @@ private[offline] object PostTransformationUtil {
* @param input input feature value
* @return transformed feature value
*/
def booleanTransformer(featureName: String, mvelExpression: MvelDefinition, compiledExpression: Serializable, input: Boolean): Boolean = {
def booleanTransformer(featureName: String, mvelExpression: MvelDefinition, compiledExpression: Serializable, input: Boolean, mvelContext: Option[FeathrExpressionExecutionContext]): Boolean = {
val toFeatureValue = common.FeatureValue.createBoolean(input)
val transformedFeatureValue = transformFeatureValues(featureName, toFeatureValue, compiledExpression, FeatureTypes.TERM_VECTOR)
val transformedFeatureValue = transformFeatureValues(featureName, toFeatureValue, compiledExpression, FeatureTypes.TERM_VECTOR, mvelContext)
transformedFeatureValue match {
case Success(fVal) => fVal.getAsTermVector.containsKey("true")
case Failure(ex) =>
@@ -57,12 +57,12 @@ private[offline] object PostTransformationUtil {
featureName: String,
mvelExpression: MvelDefinition,
compiledExpression: Serializable,
input: GenericRowWithSchema): Map[String, Float] = {
input: GenericRowWithSchema, mvelContext: Option[FeathrExpressionExecutionContext]): Map[String, Float] = {
if (input != null) {
val inputMapKey = input.getAs[Seq[String]](FeaturizedDatasetUtils.FDS_1D_TENSOR_DIM)
val inputMapVal = input.getAs[Seq[Float]](FeaturizedDatasetUtils.FDS_1D_TENSOR_VALUE)
val inputMap = inputMapKey.zip(inputMapVal).toMap
mapTransformer(featureName, mvelExpression, compiledExpression, inputMap)
mapTransformer(featureName, mvelExpression, compiledExpression, inputMap, mvelContext)
} else Map()
}

@@ -79,7 +79,8 @@ private[offline] object PostTransformationUtil {
featureNameColumnTuples: Seq[(String, String)],
contextDF: DataFrame,
transformationDef: Map[String, MvelDefinition],
defaultTransformation: (DataType, String) => Column): DataFrame = {
defaultTransformation: (DataType, String) => Column,
mvelContext: Option[FeathrExpressionExecutionContext]): DataFrame = {
val featureColumnNames = featureNameColumnTuples.map(_._2)

// Transform the features with the provided transformations
@@ -93,11 +94,11 @@ private[offline] object PostTransformationUtil {
val parserContext = MvelContext.newParserContext()
val compiledExpression = MVEL.compileExpression(mvelExpressionDef.mvelDef, parserContext)
val featureType = mvelExpressionDef.featureType
val convertToString = udf(stringTransformer(featureName, mvelExpressionDef, compiledExpression, _: String))
val convertToBoolean = udf(booleanTransformer(featureName, mvelExpressionDef, compiledExpression, _: Boolean))
val convertToFloat = udf(floatTransformer(featureName, mvelExpressionDef, compiledExpression, _: Float))
val convertToMap = udf(mapTransformer(featureName, mvelExpressionDef, compiledExpression, _: Map[String, Float]))
val convertFDS1dTensorToMap = udf(fds1dTensorTransformer(featureName, mvelExpressionDef, compiledExpression, _: GenericRowWithSchema))
val convertToString = udf(stringTransformer(featureName, mvelExpressionDef, compiledExpression, _: String, mvelContext))
val convertToBoolean = udf(booleanTransformer(featureName, mvelExpressionDef, compiledExpression, _: Boolean, mvelContext))
val convertToFloat = udf(floatTransformer(featureName, mvelExpressionDef, compiledExpression, _: Float, mvelContext))
val convertToMap = udf(mapTransformer(featureName, mvelExpressionDef, compiledExpression, _: Map[String, Float], mvelContext))
val convertFDS1dTensorToMap = udf(fds1dTensorTransformer(featureName, mvelExpressionDef, compiledExpression, _: GenericRowWithSchema, mvelContext))
fieldType.dataType match {
case _: StringType => convertToString(contextDF(columnName))
case _: NumericType => convertToFloat(contextDF(columnName))
@@ -126,16 +127,17 @@ private[offline] object PostTransformationUtil {
featureName: String,
featureValue: FeatureValue,
compiledExpression: Serializable,
featureType: FeatureTypes): Try[FeatureValue] = Try {
featureType: FeatureTypes,
mvelContext: Option[FeathrExpressionExecutionContext]): Try[FeatureValue] = Try {
val args = Map(featureName -> Some(featureValue))
val variableResolverFactory = new FeatureVariableResolverFactory(args)
val transformedValue = MvelContext.executeExpressionWithPluginSupport(compiledExpression, featureValue, variableResolverFactory)
val transformedValue = MvelContext.executeExpressionWithPluginSupportWithFactory(compiledExpression, featureValue, variableResolverFactory, mvelContext.orNull)
CoercionUtilsScala.coerceToFeatureValue(transformedValue, featureType)
}

private def floatTransformer(featureName: String, mvelExpression: MvelDefinition, compiledExpression: Serializable, input: Float): Float = {
private def floatTransformer(featureName: String, mvelExpression: MvelDefinition, compiledExpression: Serializable, input: Float, mvelContext: Option[FeathrExpressionExecutionContext]): Float = {
val toFeatureValue = common.FeatureValue.createNumeric(input)
val transformedFeatureValue = transformFeatureValues(featureName, toFeatureValue, compiledExpression, FeatureTypes.NUMERIC)
val transformedFeatureValue = transformFeatureValues(featureName, toFeatureValue, compiledExpression, FeatureTypes.NUMERIC, mvelContext)
transformedFeatureValue match {
case Success(fVal) => fVal.getAsNumeric
case Failure(ex) =>
@@ -146,9 +148,9 @@ private[offline] object PostTransformationUtil {
}
}

private def stringTransformer(featureName: String, mvelExpression: MvelDefinition, compiledExpression: Serializable, input: String): String = {
private def stringTransformer(featureName: String, mvelExpression: MvelDefinition, compiledExpression: Serializable, input: String, mvelContext: Option[FeathrExpressionExecutionContext]): String = {
val toFeatureValue = common.FeatureValue.createCategorical(input)
val transformedFeatureValue = transformFeatureValues(featureName, toFeatureValue, compiledExpression, FeatureTypes.CATEGORICAL)
val transformedFeatureValue = transformFeatureValues(featureName, toFeatureValue, compiledExpression, FeatureTypes.CATEGORICAL, mvelContext)
transformedFeatureValue match {
case Success(fVal) => fVal.getAsString
case Failure(ex) =>
@@ -163,12 +165,13 @@ private[offline] object PostTransformationUtil {
featureName: String,
mvelExpression: MvelDefinition,
compiledExpression: Serializable,
input: Map[String, Float]): Map[String, Float] = {
input: Map[String, Float],
mvelContext: Option[FeathrExpressionExecutionContext]): Map[String, Float] = {
if (input == null) {
return Map()
}
val toFeatureValue = new common.FeatureValue(input.asJava)
val transformedFeatureValue = transformFeatureValues(featureName, toFeatureValue, compiledExpression, FeatureTypes.TERM_VECTOR)
val transformedFeatureValue = transformFeatureValues(featureName, toFeatureValue, compiledExpression, FeatureTypes.TERM_VECTOR, mvelContext)
transformedFeatureValue match {
case Success(fVal) => fVal.getAsTermVector.asScala.map(kv => (kv._1.asInstanceOf[String], kv._2.asInstanceOf[Float])).toMap
case Failure(ex) =>
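
The recurring change in this file is that each transformer gains an `mvelContext: Option[FeathrExpressionExecutionContext]` parameter, which is forwarded into the `udf(...)` wrappers. This works because Spark serializes whatever a UDF closure captures and ships it to the executors. A minimal, self-contained sketch of the pattern, with illustrative names rather than Feathr's:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.udf

object ClosureCaptureSketch {
  // Stands in for FeathrExpressionExecutionContext; anything a UDF closure
  // captures must be serializable so Spark can ship it to executors.
  final case class ExprContext(scale: Float)

  def transform(input: Float, ctx: Option[ExprContext]): Float =
    input * ctx.map(_.scale).getOrElse(1.0f)

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder.master("local[*]").appName("sketch").getOrCreate()
    import spark.implicits._

    val ctx: Option[ExprContext] = Some(ExprContext(2.0f))
    // `ctx` is captured here, serialized with the closure, and is therefore
    // present when the UDF runs on an executor.
    val scaled = udf((x: Float) => transform(x, ctx))
    Seq(1.0f, 2.0f).toDF("x").select(scaled($"x").as("scaled")).show()
    spark.stop()
  }
}
```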
DebugMvelAnchorExtractor.scala
@@ -1,24 +1,22 @@
package com.linkedin.feathr.offline.anchored.anchorExtractor

import java.io.Serializable

import com.linkedin.feathr.offline.config.MVELFeatureDefinition
import com.linkedin.feathr.offline.mvel.{MvelContext, MvelUtils}
import org.mvel2.MVEL

import java.io.Serializable
import scala.collection.convert.wrapAll._

private[offline] class DebugMvelAnchorExtractor(keyExprs: Seq[String], features: Map[String, MVELFeatureDefinition])
extends SimpleConfigurableAnchorExtractor(keyExprs, features) {
private val debugExpressions = features.mapValues(value => findDebugExpressions(value.featureExpr)).map(identity)

private val debugCompiledExpressions = debugExpressions.mapValues(_.map(x => (x, compile(x)))).map(identity)

def evaluateDebugExpressions(input: Any): Map[String, Seq[(String, Any)]] = {
debugCompiledExpressions
.mapValues(_.map {
case (expr, compiled) =>
(expr, MvelUtils.executeExpression(compiled, input, null).orNull)
(expr, MvelUtils.executeExpression(compiled, input, null, "", None).orNull)
})
.map(identity)
}
SimpleConfigurableAnchorExtractor.scala
@@ -6,6 +6,7 @@ import com.linkedin.feathr.common.util.CoercionUtils
import com.linkedin.feathr.common.{AnchorExtractor, FeatureTypeConfig, FeatureTypes, FeatureValue, SparkRowExtractor}
import com.linkedin.feathr.offline
import com.linkedin.feathr.offline.config.MVELFeatureDefinition
import com.linkedin.feathr.offline.mvel.plugins.FeathrExpressionExecutionContext
import com.linkedin.feathr.offline.mvel.{MvelContext, MvelUtils}
import com.linkedin.feathr.offline.util.FeatureValueTypeValidator
import org.apache.log4j.Logger
@@ -28,6 +29,7 @@ private[offline] class SimpleConfigurableAnchorExtractor( @JsonProperty("key") k
@JsonProperty("features") features: Map[String, MVELFeatureDefinition])
extends AnchorExtractor[Any] with SparkRowExtractor {

var mvelContext: Option[FeathrExpressionExecutionContext] = None
@transient private lazy val log = Logger.getLogger(getClass)

def getKeyExpression(): Seq[String] = key
@@ -73,7 +75,7 @@ private[offline] class SimpleConfigurableAnchorExtractor( @JsonProperty("key") k
// be more strict for resolving keys (don't swallow exceptions)
keyExpression.map(k =>
try {
Option(MvelContext.executeExpressionWithPluginSupport(k, datum)) match {
Option(MvelContext.executeExpressionWithPluginSupport(k, datum, mvelContext.orNull)) match {
case None => null
case Some(keys) => keys.toString
}
@@ -92,7 +94,7 @@ private[offline] class SimpleConfigurableAnchorExtractor( @JsonProperty("key") k

featureExpressions collect {
case (featureRefStr, (expression, featureType)) if selectedFeatures.contains(featureRefStr) =>
(featureRefStr, (MvelUtils.executeExpression(expression, datum, null, featureRefStr), featureType))
(featureRefStr, (MvelUtils.executeExpression(expression, datum, null, featureRefStr, mvelContext), featureType))
} collect {
// Apply a partial function only for non-empty feature values, empty feature values will be set to default later
case (featureRefStr, (Some(value), fType)) =>
@@ -165,7 +167,7 @@ private[offline] class SimpleConfigurableAnchorExtractor( @JsonProperty("key") k
* for building a tensor. Feature's value type and dimension type(s) are obtained via Feathr's Feature Metadata
* Library during tensor construction.
*/
(featureRefStr, MvelUtils.executeExpression(expression, datum, null, featureRefStr))
(featureRefStr, MvelUtils.executeExpression(expression, datum, null, featureRefStr, mvelContext))
}
}

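The extractor now carries the context in a public mutable field rather than a constructor parameter, presumably to leave the Jackson-annotated constructor untouched. The assignment site is not part of this diff; injection plausibly looks like this hedged sketch, where `keyExprs`, `featureDefs`, and `mvelContext` come from the surrounding planner code:

```scala
val extractor = new SimpleConfigurableAnchorExtractor(keyExprs, featureDefs)
// Must be set before the extractor is serialized out to executors.
extractor.mvelContext = mvelContext
```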
FeathrClient.scala
@@ -8,9 +8,10 @@ import com.linkedin.feathr.offline.generation.{DataFrameFeatureGenerator, Featur
import com.linkedin.feathr.offline.job._
import com.linkedin.feathr.offline.join.DataFrameFeatureJoiner
import com.linkedin.feathr.offline.logical.{FeatureGroups, MultiStageJoinPlanner}
import com.linkedin.feathr.offline.mvel.plugins.FeathrExpressionExecutionContext
import com.linkedin.feathr.offline.source.DataSource
import com.linkedin.feathr.offline.source.accessor.DataPathHandler
import com.linkedin.feathr.offline.util.{FeathrUtils, _}
import com.linkedin.feathr.offline.util._
import org.apache.log4j.Logger
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.internal.SQLConf
@@ -27,7 +28,7 @@ import scala.util.{Failure, Success}
*
*/
class FeathrClient private[offline] (sparkSession: SparkSession, featureGroups: FeatureGroups, logicalPlanner: MultiStageJoinPlanner,
featureGroupsUpdater: FeatureGroupsUpdater, dataPathHandlers: List[DataPathHandler]) {
featureGroupsUpdater: FeatureGroupsUpdater, dataPathHandlers: List[DataPathHandler], mvelContext: Option[FeathrExpressionExecutionContext]) {
private val log = Logger.getLogger(getClass)

type KeyTagStringTuple = Seq[String]
@@ -91,7 +92,7 @@ class FeathrClient private[offline] (sparkSession: SparkSession, featureGroups:
// Get logical plan
val logicalPlan = logicalPlanner.getLogicalPlan(featureGroups, keyTaggedRequiredFeatures)
// This pattern is consistent with the join use case which uses DataFrameFeatureJoiner.
val dataFrameFeatureGenerator = new DataFrameFeatureGenerator(logicalPlan=logicalPlan,dataPathHandlers=dataPathHandlers)
val dataFrameFeatureGenerator = new DataFrameFeatureGenerator(logicalPlan=logicalPlan,dataPathHandlers=dataPathHandlers, mvelContext)
val featureMap: Map[TaggedFeatureName, (DataFrame, Header)] =
dataFrameFeatureGenerator.generateFeaturesAsDF(sparkSession, featureGenSpec, featureGroups, keyTaggedRequiredFeatures)

@@ -263,7 +264,7 @@ class FeathrClient private[offline] (sparkSession: SparkSession, featureGroups:
s"Please rename feature ${conflictFeatureNames} or rename the same field names in the observation data.")
}

val joiner = new DataFrameFeatureJoiner(logicalPlan=logicalPlan,dataPathHandlers=dataPathHandlers)
val joiner = new DataFrameFeatureJoiner(logicalPlan=logicalPlan,dataPathHandlers=dataPathHandlers, mvelContext)
joiner.joinFeaturesAsDF(sparkSession, joinConfig, updatedFeatureGroups, keyTaggedFeatures, left, rowBloomFilterThreshold)
}

@@ -337,6 +338,7 @@ object FeathrClient {
private var localOverrideDefPath: List[String] = List()
private var featureDefConfs: List[FeathrConfig] = List()
private var dataPathHandlers: List[DataPathHandler] = List()
private var mvelContext: Option[FeathrExpressionExecutionContext] = None


/**
@@ -495,6 +497,10 @@ object FeathrClient {
this.featureDefConfs = featureDefConfs
this
}
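/**
 * Provide an optional FeathrExpressionExecutionContext that will be propagated
 * to MVEL expression evaluation, including on Spark executors.
 */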
def addFeathrExpressionContext(_mvelContext: Option[FeathrExpressionExecutionContext]): Builder = {
this.mvelContext = _mvelContext
this
}
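
A hedged usage sketch for the new hook. It assumes the builder is obtained in the usual way (shown here as `FeathrClient.builder(spark)`), elides the feature-definition and data-path-handler calls, and uses `myExpressionContext` as a stand-in for however the context is constructed:

```scala
val ctx: Option[FeathrExpressionExecutionContext] = Some(myExpressionContext)
val client = FeathrClient.builder(spark)
  .addFeathrExpressionContext(ctx)
  .build()
```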

/**
* Build a new instance of the FeathrClient from the added feathr definition configs and any local overrides.
@@ -529,7 +535,7 @@ object FeathrClient {
featureDefConfigs = featureDefConfigs ++ featureDefConfs

val featureGroups = FeatureGroupsGenerator(featureDefConfigs, Some(localDefConfigs)).getFeatureGroups()
val feathrClient = new FeathrClient(sparkSession, featureGroups, MultiStageJoinPlanner(), FeatureGroupsUpdater(), dataPathHandlers)
val feathrClient = new FeathrClient(sparkSession, featureGroups, MultiStageJoinPlanner(), FeatureGroupsUpdater(), dataPathHandlers, mvelContext)

feathrClient
}
FeathrUdfPluginContext.scala
@@ -1,4 +1,6 @@
package com.linkedin.feathr.offline.client.plugins
import org.apache.spark.SparkContext
import org.apache.spark.broadcast.Broadcast

import scala.collection.mutable

@@ -9,15 +11,21 @@ import scala.collection.mutable
* All "external" UDF classes are required to have a public default zero-arg constructor.
*/
object FeathrUdfPluginContext {
val registeredUdfAdaptors = mutable.Buffer[UdfAdaptor[_]]()

def registerUdfAdaptor(adaptor: UdfAdaptor[_]): Unit = {
private val localRegisteredUdfAdaptors = mutable.Buffer[UdfAdaptor[_]]()
private var registeredUdfAdaptors: Broadcast[mutable.Buffer[UdfAdaptor[_]]] = null
def registerUdfAdaptor(adaptor: UdfAdaptor[_], sc: SparkContext): Unit = {
this.synchronized {
registeredUdfAdaptors += adaptor
localRegisteredUdfAdaptors += adaptor
if (registeredUdfAdaptors != null) {
registeredUdfAdaptors.destroy()
}
registeredUdfAdaptors = sc.broadcast(localRegisteredUdfAdaptors)
}
}

def getRegisteredUdfAdaptor(clazz: Class[_]): Option[UdfAdaptor[_]] = {
registeredUdfAdaptors.find(_.canAdapt(clazz))
if (registeredUdfAdaptors != null) {
registeredUdfAdaptors.value.find(_.canAdapt(clazz))
} else None
}
}
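
Callers must now supply a SparkContext when registering, so the adaptor list can be broadcast to executors. A hedged sketch, where `MyUdfAdaptor` is a hypothetical adaptor with the required public zero-arg constructor:

```scala
FeathrUdfPluginContext.registerUdfAdaptor(new MyUdfAdaptor(), spark.sparkContext)
```

Re-broadcasting on every registration keeps executors in sync with the driver-side buffer, and destroying the stale broadcast frees its blocks, at the cost of assuming no in-flight task still reads the old value.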