Skip to content

Commit

Permalink
Minor code refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
jaymo001 committed Oct 4, 2022
1 parent 6f82b6f commit 6991bf4
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,12 @@ trait SparkRowExtractor {
* This is especially useful when users need to convert the source DataFrame
* into specific datatype, e.g. Avro GenericRecord or SpecificRecord.
*/
def hasBatchPreProcessing() = false
def isLowLevelRddExtractor() = false

/**
* One time batch preprocess the input data source into a RDD[_] for feature extraction later
* One time batch preprocess the input data source into a RDD[IndexedRecord] for feature extraction later
* @param df input data source
* @return batch preprocessed dataframe, as RDD[IndexedRecord]
*/
def batchPreProcess(df: DataFrame) : RDD[IndexedRecord] = throw new NotImplementedError("Batch preprocess is not implemented")
def convertToAvroRdd(df: DataFrame) : RDD[IndexedRecord] = throw new NotImplementedError("Batch preprocess is not implemented")
}
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,16 @@ import scala.concurrent.{Await, ExecutionContext, Future}
*/
private[offline] case class AnchorFeatureGroups(anchorsWithSameSource: Seq[FeatureAnchorWithSource], requestedFeatures: Seq[String])

/**
 * Context info needed in feature transformation.
 * @param featureAnchorWithSource feature anchor together with its source
 * @param featureNamePrefixPairs pairs of (feature name, feature name prefix) for the anchored features
 * @param transformer the anchor's extractor, operating on Avro IndexedRecord rows
 */
private[offline] case class TransformInfo(featureAnchorWithSource: FeatureAnchorWithSource,
featureNamePrefixPairs: Seq[(FeatureName, FeatureName)],
transformer: AnchorExtractorBase[IndexedRecord])

/**
* Represent the transformed result of an anchor extractor after evaluating its features
* @param featureNameAndPrefixPairs pairs of feature name and feature name prefix
Expand Down Expand Up @@ -765,7 +775,7 @@ private[offline] object FeatureTransformation {
* @param bloomFilter bloomfilter to apply on source rdd
* @param requestedFeatureNames requested features
* @param featureTypeConfigs user specified feature types
* @return TransformedResultWithKey
* @return TransformedResultWithKey The output feature DataFrame conforms to FDS format
*/
private def transformFeaturesOnAvroRecord(df: DataFrame,
keyExtractor: SourceKeyExtractor,
Expand All @@ -778,11 +788,11 @@ private[offline] object FeatureTransformation {
s"Key extractor ${keyExtractor} must extends MVELSourceKeyExtractor.")
}
val extractor = keyExtractor.asInstanceOf[MVELSourceKeyExtractor]
if (!extractor.anchorExtractorV1.hasBatchPreProcessing()) {
if (!extractor.anchorExtractorV1.isLowLevelRddExtractor()) {
throw new FeathrException(ErrorLabel.FEATHR_ERROR, s"Error processing requested Feature :${requestedFeatureNames}. " +
s"Missing batch preprocessors.")
s"isLowLevelRddExtractor() should return true and convertToAvroRdd should be implemented.")
}
val rdd = extractor.anchorExtractorV1.batchPreProcess(df)
val rdd = extractor.anchorExtractorV1.convertToAvroRdd(df)
val filteredFactData = applyBloomFilterRdd(keyExtractor, rdd, bloomFilter)

// Build a sequence of 3-tuple of (FeatureAnchorWithSource, featureNamePrefixPairs, AnchorExtractorBase)
Expand All @@ -795,25 +805,25 @@ private[offline] object FeatureTransformation {
val featureNamePrefix = ""
val featureNames = featureAnchorWithSource.selectedFeatures.filter(requestedFeatureNames.contains)
val featureNamePrefixPairs = featureNames.map((_, featureNamePrefix))
(featureAnchorWithSource, featureNamePrefixPairs, transformer)
TransformInfo(featureAnchorWithSource, featureNamePrefixPairs, transformer)

case _ =>
throw new FeathrFeatureTransformationException(ErrorLabel.FEATHR_USER_ERROR, s"Unsupported transformer $extractor for features: $requestedFeatureNames")
}
}

// to avoid name conflict between feature names and the raw data field names
val sourceKeyExtractors = transformInfo.map(_._1.featureAnchor.sourceKeyExtractor)
val sourceKeyExtractors = transformInfo.map(_.featureAnchorWithSource.featureAnchor.sourceKeyExtractor)
assert(sourceKeyExtractors.map(_.toString).distinct.size == 1)

val transformers = transformInfo map (_._3)
val transformers = transformInfo map (_.transformer)

/*
* Transform the given RDD by applying extractors to each row to create an RDD[Row] where each Row
* represents keys and feature values
*/
val spark = SparkSession.builder().getOrCreate()
val userProvidedFeatureTypes = transformInfo.flatMap(_._1.featureAnchor.getFeatureTypes.getOrElse(Map.empty[String, FeatureTypes])).toMap
val userProvidedFeatureTypes = transformInfo.flatMap(_.featureAnchorWithSource.featureAnchor.getFeatureTypes.getOrElse(Map.empty[String, FeatureTypes])).toMap
val FeatureTypeInferenceContext(featureTypeAccumulators) =
FeatureTransformation.getTypeInferenceContext(spark, userProvidedFeatureTypes, requestedFeatureNames)
val transformedRdd = filteredFactData map { record =>
Expand Down Expand Up @@ -849,7 +859,7 @@ private[offline] object FeatureTransformation {

val featureFormat = FeatureColumnFormat.FDS_TENSOR
val featureColumnFormats = requestedFeatureNames.map(name => name -> featureFormat).toMap
val transformedInfo = TransformedResult(transformInfo.flatMap(_._2), transformedDF, featureColumnFormats, inferredFeatureTypeConfigs)
val transformedInfo = TransformedResult(transformInfo.flatMap(_.featureNamePrefixPairs), transformedDF, featureColumnFormats, inferredFeatureTypeConfigs)
KeyedTransformedResult(keyNames, transformedInfo)
}

Expand Down Expand Up @@ -893,8 +903,8 @@ private[offline] object FeatureTransformation {
}

/*
* Retain feature values for only the requested features, and represent each feature value as a term-vector or as
* a tensor, as specified. If tensors are required, create a row for each feature value (that is, the tensor).
* Retain feature values for only the requested features, and represent each feature value as
* a tensor, as specified.
*/
val featureValuesWithType = requestedFeatureNames map { name =>
features.get(name) map {
Expand Down

0 comments on commit 6991bf4

Please sign in to comment.