Skip to content

Commit

Permalink
Refactor duplicated online retriever code
Browse files Browse the repository at this point in the history
Signed-off-by: Terence <terencelimxp@gmail.com>
  • Loading branch information
terryyylim committed Oct 13, 2020
1 parent 68973c9 commit aacce71
Show file tree
Hide file tree
Showing 6 changed files with 209 additions and 249 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,7 @@
import feast.serving.service.ServingServiceV2;
import feast.serving.specs.CachedSpecService;
import feast.storage.api.retriever.OnlineRetrieverV2;
import feast.storage.connectors.redis.retriever.RedisClusterOnlineRetrieverV2;
import feast.storage.connectors.redis.retriever.RedisOnlineRetrieverV2;
import feast.storage.connectors.redis.retriever.*;
import io.opentracing.Tracer;
import java.util.Map;
import org.slf4j.Logger;
Expand All @@ -46,11 +45,13 @@ public ServingServiceV2 servingServiceV2(

switch (storeType) {
case REDIS_CLUSTER:
OnlineRetrieverV2 redisClusterRetriever = RedisClusterOnlineRetrieverV2.create(config);
RedisClientWrapper redisClusterClient = RedisClusterClient.create(config);
OnlineRetrieverV2 redisClusterRetriever = new OnlineRetriever(redisClusterClient);
servingService = new OnlineServingServiceV2(redisClusterRetriever, specService, tracer);
break;
case REDIS:
OnlineRetrieverV2 redisRetriever = RedisOnlineRetrieverV2.create(config);
RedisClientWrapper redisClient = RedisClient.create(config);
OnlineRetrieverV2 redisRetriever = new OnlineRetriever(redisClient);
servingService = new OnlineServingServiceV2(redisRetriever, specService, tracer);
break;
case CASSANDRA:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,66 +18,46 @@

import com.google.common.collect.Lists;
import com.google.protobuf.InvalidProtocolBufferException;
import feast.proto.serving.ServingAPIProto.FeatureReferenceV2;
import feast.proto.serving.ServingAPIProto.GetOnlineFeaturesRequestV2.EntityRow;
import feast.proto.storage.RedisProto.RedisKeyV2;
import feast.proto.serving.ServingAPIProto;
import feast.proto.storage.RedisProto;
import feast.storage.api.retriever.Feature;
import feast.storage.api.retriever.OnlineRetrieverV2;
import feast.storage.connectors.redis.common.RedisHashDecoder;
import feast.storage.connectors.redis.common.RedisKeyGenerator;
import io.grpc.Status;
import io.lettuce.core.KeyValue;
import io.lettuce.core.RedisClient;
import io.lettuce.core.RedisFuture;
import io.lettuce.core.RedisURI;
import io.lettuce.core.api.StatefulRedisConnection;
import io.lettuce.core.api.async.RedisAsyncCommands;
import io.lettuce.core.codec.ByteArrayCodec;
import java.util.*;
import java.util.concurrent.ExecutionException;
import java.util.stream.Collectors;

public class RedisOnlineRetrieverV2 implements OnlineRetrieverV2 {
public class OnlineRetriever implements OnlineRetrieverV2 {

private static final String timestampPrefix = "_ts";
private final RedisAsyncCommands<byte[], byte[]> asyncCommands;
RedisClientWrapper redisClientWrapper;

private RedisOnlineRetrieverV2(StatefulRedisConnection<byte[], byte[]> connection) {
this.asyncCommands = connection.async();

// Disable auto-flushing
this.asyncCommands.setAutoFlushCommands(false);
}

public static OnlineRetrieverV2 create(Map<String, String> config) {

StatefulRedisConnection<byte[], byte[]> connection =
RedisClient.create(
RedisURI.create(config.get("host"), Integer.parseInt(config.get("port"))))
.connect(new ByteArrayCodec());

return new RedisOnlineRetrieverV2(connection);
}

public static OnlineRetrieverV2 create(StatefulRedisConnection<byte[], byte[]> connection) {
return new RedisOnlineRetrieverV2(connection);
public OnlineRetriever(RedisClientWrapper redisClientWrapper) {
this.redisClientWrapper = redisClientWrapper;
}

@Override
public List<List<Optional<Feature>>> getOnlineFeatures(
String project, List<EntityRow> entityRows, List<FeatureReferenceV2> featureReferences) {
String project,
List<ServingAPIProto.GetOnlineFeaturesRequestV2.EntityRow> entityRows,
List<ServingAPIProto.FeatureReferenceV2> featureReferences) {

List<RedisKeyV2> redisKeys = RedisKeyGenerator.buildRedisKeys(project, entityRows);
List<RedisProto.RedisKeyV2> redisKeys = RedisKeyGenerator.buildRedisKeys(project, entityRows);
List<List<Optional<Feature>>> features = getFeaturesFromRedis(redisKeys, featureReferences);

return features;
}

private List<List<Optional<Feature>>> getFeaturesFromRedis(
List<RedisKeyV2> redisKeys, List<FeatureReferenceV2> featureReferences) {
List<RedisProto.RedisKeyV2> redisKeys,
List<ServingAPIProto.FeatureReferenceV2> featureReferences) {
List<List<Optional<Feature>>> features = new ArrayList<>();
// To decode bytes back to Feature Reference
Map<String, FeatureReferenceV2> byteToFeatureReferenceMap = new HashMap<>();
Map<String, ServingAPIProto.FeatureReferenceV2> byteToFeatureReferenceMap = new HashMap<>();

// Serialize using proto
List<byte[]> binaryRedisKeys =
Expand Down Expand Up @@ -106,11 +86,11 @@ private List<List<Optional<Feature>>> getFeaturesFromRedis(
byte[][] featureReferenceWithTsByteArrays =
featureReferenceWithTsByteList.toArray(new byte[0][]);
// Access redis keys and extract features
futures.add(asyncCommands.hmget(binaryRedisKey, featureReferenceWithTsByteArrays));
futures.add(redisClientWrapper.hmget(binaryRedisKey, featureReferenceWithTsByteArrays));
}

// Write all commands to the transport layer
asyncCommands.flushCommands();
redisClientWrapper.flushCommands();

futures.forEach(
future -> {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
/*
* SPDX-License-Identifier: Apache-2.0
* Copyright 2018-2020 The Feast Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package feast.storage.connectors.redis.retriever;

import io.lettuce.core.KeyValue;
import io.lettuce.core.RedisFuture;
import io.lettuce.core.RedisURI;
import io.lettuce.core.api.StatefulRedisConnection;
import io.lettuce.core.api.async.RedisAsyncCommands;
import io.lettuce.core.codec.ByteArrayCodec;
import java.util.List;
import java.util.Map;

public class RedisClient implements RedisClientWrapper {

public final RedisAsyncCommands<byte[], byte[]> asyncCommands;

@Override
public RedisFuture<List<KeyValue<byte[], byte[]>>> hmget(byte[] key, byte[]... fields) {
return asyncCommands.hmget(key, fields);
}

@Override
public void flushCommands() {
asyncCommands.flushCommands();
}

private RedisClient(StatefulRedisConnection<byte[], byte[]> connection) {
this.asyncCommands = connection.async();

// Disable auto-flushing
this.asyncCommands.setAutoFlushCommands(false);
}

public static RedisClientWrapper create(Map<String, String> config) {
StatefulRedisConnection<byte[], byte[]> connection =
io.lettuce.core.RedisClient.create(
RedisURI.create(config.get("host"), Integer.parseInt(config.get("port"))))
.connect(new ByteArrayCodec());

return new RedisClient(connection);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
/*
* SPDX-License-Identifier: Apache-2.0
* Copyright 2018-2020 The Feast Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package feast.storage.connectors.redis.retriever;

import io.lettuce.core.*;
import java.util.List;

public interface RedisClientWrapper {
RedisFuture<List<KeyValue<byte[], byte[]>>> hmget(byte[] key, byte[]... fields);

void flushCommands();
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
/*
* SPDX-License-Identifier: Apache-2.0
* Copyright 2018-2020 The Feast Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package feast.storage.connectors.redis.retriever;

import feast.storage.connectors.redis.serializer.RedisKeyPrefixSerializerV2;
import feast.storage.connectors.redis.serializer.RedisKeySerializerV2;
import io.lettuce.core.KeyValue;
import io.lettuce.core.RedisFuture;
import io.lettuce.core.RedisURI;
import io.lettuce.core.cluster.api.StatefulRedisClusterConnection;
import io.lettuce.core.cluster.api.async.RedisAdvancedClusterAsyncCommands;
import io.lettuce.core.codec.ByteArrayCodec;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import javax.annotation.Nullable;

public class RedisClusterClient implements RedisClientWrapper {

public final RedisAdvancedClusterAsyncCommands<byte[], byte[]> asyncCommands;
public final RedisKeySerializerV2 serializer;
@Nullable public final RedisKeySerializerV2 fallbackSerializer;

@Override
public RedisFuture<List<KeyValue<byte[], byte[]>>> hmget(byte[] key, byte[]... fields) {
return asyncCommands.hmget(key, fields);
}

@Override
public void flushCommands() {
asyncCommands.flushCommands();
}

static class Builder {
private final StatefulRedisClusterConnection<byte[], byte[]> connection;
private final RedisKeySerializerV2 serializer;
@Nullable private RedisKeySerializerV2 fallbackSerializer;

Builder(
StatefulRedisClusterConnection<byte[], byte[]> connection,
RedisKeySerializerV2 serializer) {
this.connection = connection;
this.serializer = serializer;
}

Builder withFallbackSerializer(RedisKeySerializerV2 fallbackSerializer) {
this.fallbackSerializer = fallbackSerializer;
return this;
}

RedisClusterClient build() {
return new RedisClusterClient(this);
}
}

private RedisClusterClient(Builder builder) {
this.asyncCommands = builder.connection.async();
this.serializer = builder.serializer;
this.fallbackSerializer = builder.fallbackSerializer;

// Disable auto-flushing
this.asyncCommands.setAutoFlushCommands(false);
}

public static RedisClientWrapper create(Map<String, String> config) {
List<RedisURI> redisURIList =
Arrays.stream(config.get("connection_string").split(","))
.map(
hostPort -> {
String[] hostPortSplit = hostPort.trim().split(":");
return RedisURI.create(hostPortSplit[0], Integer.parseInt(hostPortSplit[1]));
})
.collect(Collectors.toList());
StatefulRedisClusterConnection<byte[], byte[]> connection =
io.lettuce.core.cluster.RedisClusterClient.create(redisURIList)
.connect(new ByteArrayCodec());

RedisKeySerializerV2 serializer =
new RedisKeyPrefixSerializerV2(config.getOrDefault("key_prefix", ""));

Builder builder = new Builder(connection, serializer);

if (Boolean.parseBoolean(config.getOrDefault("enable_fallback", "false"))) {
RedisKeySerializerV2 fallbackSerializer =
new RedisKeyPrefixSerializerV2(config.getOrDefault("fallback_prefix", ""));
builder = builder.withFallbackSerializer(fallbackSerializer);
}

return builder.build();
}
}
Loading

0 comments on commit aacce71

Please sign in to comment.