Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Support retrieval from multiple feature views with different join keys #2835

Merged
merged 5 commits into from
Jun 30, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions java/serving/src/main/java/feast/serving/registry/Registry.java
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ public class Registry {
private Map<String, OnDemandFeatureViewProto.OnDemandFeatureViewSpec>
onDemandFeatureViewNameToSpec;
private final Map<String, FeatureServiceProto.FeatureServiceSpec> featureServiceNameToSpec;
private final Map<String, String> entityNameToJoinKey;

Registry(RegistryProto.Registry registry) {
this.registry = registry;
Expand Down Expand Up @@ -60,6 +61,12 @@ public class Registry {
.collect(
Collectors.toMap(
FeatureServiceProto.FeatureServiceSpec::getName, Function.identity()));
this.entityNameToJoinKey =
registry.getEntitiesList().stream()
.map(EntityProto.Entity::getSpec)
.collect(
Collectors.toMap(
EntityProto.EntitySpecV2::getName, EntityProto.EntitySpecV2::getJoinKey));
}

public RegistryProto.Registry getRegistry() {
Expand Down Expand Up @@ -115,4 +122,12 @@ public FeatureServiceProto.FeatureServiceSpec getFeatureServiceSpec(String name)
}
return spec;
}

public String getEntityJoinKey(String name) {
String joinKey = entityNameToJoinKey.get(name);
if (joinKey == null) {
throw new SpecRetrievalException(String.format("Unable to find entity with name: %s", name));
}
return joinKey;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -102,4 +102,8 @@ public Duration getMaxAge(ServingAPIProto.FeatureReferenceV2 featureReference) {
public List<String> getEntitiesList(ServingAPIProto.FeatureReferenceV2 featureReference) {
return getFeatureViewSpec(featureReference).getEntitiesList();
}

public String getEntityJoinKey(String name) {
return this.registry.getEntityJoinKey(name);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@
import feast.serving.registry.RegistryRepository;
import feast.serving.util.Metrics;
import feast.storage.api.retriever.OnlineRetrieverV2;
import io.grpc.Status;
import io.opentracing.Span;
import io.opentracing.Tracer;
import java.util.*;
Expand All @@ -51,6 +50,11 @@ public class OnlineServingServiceV2 implements ServingServiceV2 {
private final OnlineTransformationService onlineTransformationService;
private final String project;

public static final String DUMMY_ENTITY_ID = "__dummy_id";
public static final String DUMMY_ENTITY_VAL = "";
public static final ValueProto.Value DUMMY_ENTITY_VALUE =
ValueProto.Value.newBuilder().setStringVal(DUMMY_ENTITY_VAL).build();

public OnlineServingServiceV2(
OnlineRetrieverV2 retriever,
Tracer tracer,
Expand Down Expand Up @@ -103,31 +107,18 @@ public ServingAPIProto.GetOnlineFeaturesResponse getOnlineFeatures(

List<Map<String, ValueProto.Value>> entityRows = getEntityRows(request);

List<String> entityNames;
if (retrievedFeatureReferences.size() > 0) {
entityNames = this.registryRepository.getEntitiesList(retrievedFeatureReferences.get(0));
} else {
throw new RuntimeException("Requested features list must not be empty");
}

Span storageRetrievalSpan = tracer.buildSpan("storageRetrieval").start();
if (storageRetrievalSpan != null) {
storageRetrievalSpan.setTag("entities", entityRows.size());
storageRetrievalSpan.setTag("features", retrievedFeatureReferences.size());
}

List<List<feast.storage.api.retriever.Feature>> features =
retriever.getOnlineFeatures(entityRows, retrievedFeatureReferences, entityNames);
retrieveFeatures(retrievedFeatureReferences, entityRows);

if (storageRetrievalSpan != null) {
storageRetrievalSpan.finish();
}
if (features.size() != entityRows.size()) {
throw Status.INTERNAL
.withDescription(
"The no. of FeatureRow obtained from OnlineRetriever"
+ "does not match no. of entityRow passed.")
.asRuntimeException();
}

Span postProcessingSpan = tracer.buildSpan("postProcessing").start();

Expand Down Expand Up @@ -255,6 +246,84 @@ private List<Map<String, ValueProto.Value>> getEntityRows(
return entityRows;
}

private List<List<feast.storage.api.retriever.Feature>> retrieveFeatures(
List<FeatureReferenceV2> featureReferences, List<Map<String, ValueProto.Value>> entityRows) {
// Prepare feature reference to index mapping. This mapping will be used to arrange the
// retrieved features to the same order as in the input.
if (featureReferences.isEmpty()) {
throw new RuntimeException("Requested features list must not be empty.");
}
Map<FeatureReferenceV2, Integer> featureReferenceToIndexMap =
new HashMap<>(featureReferences.size());
for (int i = 0; i < featureReferences.size(); i++) {
FeatureReferenceV2 featureReference = featureReferences.get(i);
if (featureReferenceToIndexMap.containsKey(featureReference)) {
throw new RuntimeException(
String.format(
"Found duplicate features %s:%s.",
featureReference.getFeatureViewName(), featureReference.getFeatureName()));
}
featureReferenceToIndexMap.put(featureReference, i);
}

// Create placeholders for retrieved features.
List<List<feast.storage.api.retriever.Feature>> features = new ArrayList<>(entityRows.size());
for (int i = 0; i < entityRows.size(); i++) {
List<feast.storage.api.retriever.Feature> featuresPerEntity =
new ArrayList<>(featureReferences.size());
for (int j = 0; j < featureReferences.size(); j++) {
featuresPerEntity.add(null);
}
features.add(featuresPerEntity);
}

// Group feature references by join keys.
Map<String, List<FeatureReferenceV2>> groupNameToFeatureReferencesMap =
featureReferences.stream()
Copy link
Collaborator

@pyalex pyalex Jun 22, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To speed up this part we might want to extract distinct feature views from all feature references. And then group feature views instead.

Copy link
Contributor Author

@yongheng yongheng Jun 23, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IIUC grouping by join keys results in the same or less groups (therefore same or more efficient) than grouping by feature view. The is because different feature views can have the same join keys. In L286, this.registryRepository.getEntitiesList(featureReference) internally gets feature view spec first, then gets entity names of the feature view spec, then we find join keys for the entity names.

Actually, I grouped by feature view at the beginning. Then I switched to grouping by join keys in the second commit of this PR, as an optimization.

.collect(
Collectors.groupingBy(
featureReference ->
this.registryRepository.getEntitiesList(featureReference).stream()
.map(this.registryRepository::getEntityJoinKey)
.sorted()
.collect(Collectors.joining(","))));

// Retrieve features one group at a time.
for (List<FeatureReferenceV2> featureReferencesPerGroup :
groupNameToFeatureReferencesMap.values()) {
List<String> entityNames =
this.registryRepository.getEntitiesList(featureReferencesPerGroup.get(0));
List<Map<String, ValueProto.Value>> entityRowsPerGroup = new ArrayList<>(entityRows.size());
for (Map<String, ValueProto.Value> entityRow : entityRows) {
Map<String, ValueProto.Value> entityRowPerGroup = new HashMap<>();
entityNames.stream()
.map(this.registryRepository::getEntityJoinKey)
.forEach(
joinKey -> {
if (joinKey.equals(DUMMY_ENTITY_ID)) {
entityRowPerGroup.put(joinKey, DUMMY_ENTITY_VALUE);
} else {
ValueProto.Value value = entityRow.get(joinKey);
if (value != null) {
entityRowPerGroup.put(joinKey, value);
}
}
});
entityRowsPerGroup.add(entityRowPerGroup);
}
List<List<feast.storage.api.retriever.Feature>> featuresPerGroup =
retriever.getOnlineFeatures(entityRowsPerGroup, featureReferencesPerGroup, entityNames);
for (int i = 0; i < featuresPerGroup.size(); i++) {
for (int j = 0; j < featureReferencesPerGroup.size(); j++) {
int k = featureReferenceToIndexMap.get(featureReferencesPerGroup.get(j));
features.get(i).set(k, featuresPerGroup.get(i).get(j));
}
}
}

return features;
}

private void populateOnDemandFeatures(
List<FeatureReferenceV2> onDemandFeatureReferences,
List<FeatureReferenceV2> onDemandFeatureSources,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -172,5 +172,35 @@ public void shouldGetOnlineFeaturesWithStringEntity() {
}
}

@Test
public void shouldGetOnlineFeaturesFromAllFeatureViews() {
Map<String, ValueProto.RepeatedValue> entityRows =
ImmutableMap.of(
"entity",
ValueProto.RepeatedValue.newBuilder()
.addVal(DataGenerator.createStrValue("key-1"))
.build(),
"driver_id",
ValueProto.RepeatedValue.newBuilder()
.addVal(DataGenerator.createInt64Value(1005))
.build());

ImmutableList<String> featureReferences =
ImmutableList.of(
"feature_view_0:feature_0",
"feature_view_0:feature_1",
"driver_hourly_stats:conv_rate",
"driver_hourly_stats:avg_daily_trips");

ServingAPIProto.GetOnlineFeaturesRequest req =
TestUtils.createOnlineFeatureRequest(featureReferences, entityRows);

ServingAPIProto.GetOnlineFeaturesResponse resp = servingStub.getOnlineFeatures(req);

for (final int featureIdx : List.of(0, 1, 2, 3)) {
assertEquals(FieldStatus.PRESENT, resp.getResults(featureIdx).getStatuses(0));
}
}

abstract void updateRegistryFile(RegistryProto.Registry registry);
}
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,8 @@ public void shouldReturnResponseWithValuesAndMetadataIfKeysPresent() {
.thenReturn(featureSpecs.get(0));
when(registry.getFeatureSpec(mockedFeatureRows.get(3).getFeatureReference()))
.thenReturn(featureSpecs.get(1));
when(registry.getEntityJoinKey("entity1")).thenReturn("entity1");
when(registry.getEntityJoinKey("entity2")).thenReturn("entity2");

when(tracer.buildSpan(ArgumentMatchers.any())).thenReturn(Mockito.mock(SpanBuilder.class));

Expand Down Expand Up @@ -237,6 +239,8 @@ public void shouldReturnResponseWithUnsetValuesAndMetadataIfKeysNotPresent() {
.thenReturn(featureSpecs.get(0));
when(registry.getFeatureSpec(mockedFeatureRows.get(1).getFeatureReference()))
.thenReturn(featureSpecs.get(1));
when(registry.getEntityJoinKey("entity1")).thenReturn("entity1");
when(registry.getEntityJoinKey("entity2")).thenReturn("entity2");

when(tracer.buildSpan(ArgumentMatchers.any())).thenReturn(Mockito.mock(SpanBuilder.class));

Expand Down Expand Up @@ -314,6 +318,8 @@ public void shouldReturnResponseWithValuesAndMetadataIfMaxAgeIsExceeded() {
.thenReturn(featureSpecs.get(1));
when(registry.getFeatureSpec(mockedFeatureRows.get(5).getFeatureReference()))
.thenReturn(featureSpecs.get(0));
when(registry.getEntityJoinKey("entity1")).thenReturn("entity1");
when(registry.getEntityJoinKey("entity2")).thenReturn("entity2");

when(tracer.buildSpan(ArgumentMatchers.any())).thenReturn(Mockito.mock(SpanBuilder.class));

Expand Down