Skip to content

Commit

Permalink
Fix feature type bug where inferred feature type might not be honored…
Browse files Browse the repository at this point in the history
… when all the feature types are not provided (feathr-ai#701)
  • Loading branch information
jaymo001 authored and hyingyang-linkedin committed Oct 25, 2022
1 parent 89835c3 commit 930640f
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 9 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package com.linkedin.feathr.offline.client

import com.google.common.annotations.VisibleForTesting
import com.linkedin.feathr.common._
import com.linkedin.feathr.common.exception.{ErrorLabel, FeathrFeatureTransformationException}
import com.linkedin.feathr.offline.anchored.feature.FeatureAnchorWithSource
Expand Down Expand Up @@ -357,11 +358,13 @@ object DataFrameColName {
/**
* generate header info (e.g, feature type, feature column name map) for output dataframe of
* feature join or feature generation
*
* @param featureToColumnNameMap map of feature to its column name in the dataframe
* @param inferredFeatureTypeConfigs feature name to inferred feature types
* @return header info for a dataframe that contains the features in featureToColumnNameMap
*/
private def generateHeader(
@VisibleForTesting
def generateHeader(
featureToColumnNameMap: Map[TaggedFeatureName, String],
allAnchoredFeatures: Map[String, FeatureAnchorWithSource],
allDerivedFeatures: Map[String, DerivedFeature],
Expand All @@ -370,13 +373,10 @@ object DataFrameColName {
// if the feature type is unspecified in the anchor config, we will use FeatureTypes.UNSPECIFIED
val anchoredFeatureTypes: Map[String, FeatureTypeConfig] = allAnchoredFeatures.map {
case (featureName, anchorWithSource) =>
val featureTypeOpt = anchorWithSource.featureAnchor.getFeatureTypes.map(types => {
// Get the actual type in the output dataframe, the type is inferred and stored previously, if not specified by users
val inferredType = inferredFeatureTypeConfigs.getOrElse(featureName, FeatureTypeConfig.UNDEFINED_TYPE_CONFIG)
val fType = new FeatureTypeConfig(types.getOrElse(featureName, FeatureTypes.UNSPECIFIED))
if (fType == FeatureTypeConfig.UNDEFINED_TYPE_CONFIG) inferredType else fType
})
val featureType = featureTypeOpt.getOrElse(FeatureTypeConfig.UNDEFINED_TYPE_CONFIG)
val featureTypeOpt = anchorWithSource.featureAnchor.featureTypeConfigs.get(featureName)
// Get the actual type in the output dataframe, the type is inferred and stored previously, if not specified by users
val inferredType = inferredFeatureTypeConfigs.getOrElse(featureName, FeatureTypeConfig.UNDEFINED_TYPE_CONFIG)
val featureType = featureTypeOpt.getOrElse(inferredType)
featureName -> featureType
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
package com.linkedin.feathr.offline.client

import com.linkedin.feathr.common.{DateParam, JoiningFeatureParams, TaggedFeatureName}
import com.linkedin.feathr.common.{DateParam, FeatureTypeConfig, JoiningFeatureParams, TaggedFeatureName}
import com.linkedin.feathr.offline.TestFeathr
import com.linkedin.feathr.offline.anchored.feature.{FeatureAnchor, FeatureAnchorWithSource}
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{DoubleType, StringType, StructField, StructType}
import org.mockito.Mockito.when
import org.scalatest.mockito.MockitoSugar.mock
import org.testng.Assert.assertEquals
import org.testng.annotations.Test

Expand Down Expand Up @@ -59,4 +62,26 @@ class TestDataFrameColName extends TestFeathr {
val taggedFeature3 = new TaggedFeatureName("x", "seq_join_a_names")
assertEquals(taggedFeatureToNewColumnNameMap(taggedFeature3)._2, "seq_join_a_names")
}

@Test(description = "Inferred feature type should be honored when user does not provide feature type")
def testGenerateHeader(): Unit = {
val mockFeatureAnchor = mock[FeatureAnchor]
// Mock if the user does not define feature type
when(mockFeatureAnchor.featureTypeConfigs).thenReturn(Map.empty[String, FeatureTypeConfig])

val mockFeatureAnchorWithSource = mock[FeatureAnchorWithSource]
when(mockFeatureAnchorWithSource.featureAnchor).thenReturn(mockFeatureAnchor)
val taggedFeatureName = new TaggedFeatureName("id", "f")
val featureToColumnNameMap: Map[TaggedFeatureName, String] = Map(taggedFeatureName -> "f")
val allAnchoredFeatures: Map[String, FeatureAnchorWithSource] = Map("f" -> mockFeatureAnchorWithSource)
// Mock if the type if inferred to be numeric
val inferredFeatureTypeConfigs: Map[String, FeatureTypeConfig] = Map("f" -> FeatureTypeConfig.NUMERIC_TYPE_CONFIG)
val header = DataFrameColName.generateHeader(
featureToColumnNameMap,
allAnchoredFeatures,
Map(),
inferredFeatureTypeConfigs)
// output should be using the inferred type, i.e. numeric
assertEquals(header.featureInfoMap.get(taggedFeatureName).get.featureType, FeatureTypeConfig.NUMERIC_TYPE_CONFIG)
}
}

0 comments on commit 930640f

Please sign in to comment.