diff --git a/caraml-store-serving/build.gradle b/caraml-store-serving/build.gradle
index ccdf03f..0d04182 100644
--- a/caraml-store-serving/build.gradle
+++ b/caraml-store-serving/build.gradle
@@ -8,7 +8,8 @@ dependencies {
     implementation 'org.apache.commons:commons-lang3:3.10'
     implementation 'org.apache.avro:avro:1.10.2'
     implementation platform('com.google.cloud:libraries-bom:26.43.0')
-    implementation 'com.google.cloud:google-cloud-bigtable:2.40.0'
+    implementation 'com.google.cloud:google-cloud-bigtable:2.39.2'
+    implementation 'com.google.cloud.bigtable:bigtable-hbase-2.x:2.14.3'
     implementation 'commons-codec:commons-codec:1.17.1'
     implementation 'io.lettuce:lettuce-core:6.2.0.RELEASE'
     implementation 'io.netty:netty-transport-native-epoll:4.1.52.Final:linux-x86_64'
diff --git a/caraml-store-serving/src/main/java/dev/caraml/serving/store/bigtable/BigTableStoreConfig.java b/caraml-store-serving/src/main/java/dev/caraml/serving/store/bigtable/BigTableStoreConfig.java
index 9c0609a..f6c6dd2 100644
--- a/caraml-store-serving/src/main/java/dev/caraml/serving/store/bigtable/BigTableStoreConfig.java
+++ b/caraml-store-serving/src/main/java/dev/caraml/serving/store/bigtable/BigTableStoreConfig.java
@@ -2,51 +2,66 @@
 
 import com.google.cloud.bigtable.data.v2.BigtableDataClient;
 import com.google.cloud.bigtable.data.v2.BigtableDataSettings;
+import com.google.cloud.bigtable.hbase.BigtableConfiguration;
+import com.google.cloud.bigtable.hbase.BigtableOptionsFactory;
 import dev.caraml.serving.store.OnlineRetriever;
-import java.io.IOException;
 import lombok.Getter;
 import lombok.Setter;
+import org.apache.hadoop.hbase.client.Connection;
 import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
 import org.springframework.boot.context.properties.ConfigurationProperties;
 import org.springframework.context.annotation.Bean;
 import org.springframework.context.annotation.Configuration;
 import org.threeten.bp.Duration;
 
+import java.io.IOException;
+
 @Configuration
 @ConfigurationProperties(prefix = "caraml.store.bigtable")
 @ConditionalOnProperty(prefix = "caraml.store", name = "active", havingValue = "bigtable")
 @Getter
 @Setter
 public class BigTableStoreConfig {
-  private String projectId;
-  private String instanceId;
-  private String appProfileId;
-  private Boolean enableClientSideMetrics;
-  private Long timeoutMs;
+    private String projectId;
+    private String instanceId;
+    private String appProfileId;
+    private Boolean enableClientSideMetrics;
+    private Long timeoutMs;
+    private Boolean isUsingHBaseSDK;
+
+    @Bean
+    public OnlineRetriever getRetriever() {
+        // Using HBase SDK
+        if (isUsingHBaseSDK) {
+            org.apache.hadoop.conf.Configuration config = BigtableConfiguration.configure(projectId, instanceId);
+            config.set(BigtableOptionsFactory.APP_PROFILE_ID_KEY, appProfileId);
+
+            Connection connection = BigtableConfiguration.connect(config);
+            return new HBaseOnlineRetriever(connection);
+        }
 
-  @Bean
-  public OnlineRetriever getRetriever() {
-    try {
-      BigtableDataSettings.Builder builder =
-          BigtableDataSettings.newBuilder()
-              .setProjectId(projectId)
-              .setInstanceId(instanceId)
-              .setAppProfileId(appProfileId);
-      if (timeoutMs > 0) {
-        builder
-            .stubSettings()
-            .readRowsSettings()
-            .retrySettings()
-            .setTotalTimeout(Duration.ofMillis(timeoutMs));
-      }
-      BigtableDataSettings settings = builder.build();
-      if (enableClientSideMetrics) {
-        BigtableDataSettings.enableBuiltinMetrics();
-      }
-      BigtableDataClient client = BigtableDataClient.create(settings);
-      return new BigTableOnlineRetriever(client);
-    } catch (IOException e) {
-      throw new RuntimeException(e);
-    }
-  }
+        // Using BigTable SDK
+        try {
+            BigtableDataSettings.Builder builder =
+                BigtableDataSettings.newBuilder()
+                    .setProjectId(projectId)
+                    .setInstanceId(instanceId)
+                    .setAppProfileId(appProfileId);
+            if (timeoutMs > 0) {
+                builder
+                    .stubSettings()
+                    .readRowsSettings()
+                    .retrySettings()
+                    .setTotalTimeout(Duration.ofMillis(timeoutMs));
+            }
+            BigtableDataSettings settings = builder.build();
+            if (enableClientSideMetrics) {
+                BigtableDataSettings.enableBuiltinMetrics();
+            }
+            BigtableDataClient client = BigtableDataClient.create(settings);
+            return new BigTableOnlineRetriever(client);
+        } catch (IOException e) {
+            throw new RuntimeException(e);
+        }
+    }
 }
diff --git a/caraml-store-serving/src/main/java/dev/caraml/serving/store/bigtable/HBaseOnlineRetriever.java b/caraml-store-serving/src/main/java/dev/caraml/serving/store/bigtable/HBaseOnlineRetriever.java
new file mode 100644
index 0000000..532fb78
--- /dev/null
+++ b/caraml-store-serving/src/main/java/dev/caraml/serving/store/bigtable/HBaseOnlineRetriever.java
@@ -0,0 +1,118 @@
+package dev.caraml.serving.store.bigtable;
+
+import com.google.protobuf.ByteString;
+import com.google.protobuf.Timestamp;
+import dev.caraml.serving.store.AvroFeature;
+import dev.caraml.serving.store.Feature;
+import dev.caraml.store.protobuf.serving.ServingServiceProto;
+
+import java.io.IOException;
+import java.util.*;
+import java.util.stream.Collectors;
+
+import org.apache.avro.AvroRuntimeException;
+import org.apache.avro.generic.GenericDatumReader;
+import org.apache.avro.generic.GenericRecord;
+import org.apache.avro.io.BinaryDecoder;
+import org.apache.avro.io.DecoderFactory;
+import org.apache.hadoop.hbase.Cell;
+import org.apache.hadoop.hbase.TableName;
+import org.apache.hadoop.hbase.client.*;
+import org.apache.hadoop.hbase.util.Bytes;
+
+public class HBaseOnlineRetriever implements SSTableOnlineRetriever<ByteString, Result> {
+    private final Connection client;
+    private final HBaseSchemaRegistry schemaRegistry;
+
+    public HBaseOnlineRetriever(Connection client) {
+        this.client = client;
+        this.schemaRegistry = new HBaseSchemaRegistry(client);
+    }
+
+    @Override
+    public ByteString convertEntityValueToKey(
+            ServingServiceProto.GetOnlineFeaturesRequest.EntityRow entityRow, List<String> entityNames) {
+        return ByteString.copyFrom(
+                entityNames.stream()
+                        .sorted()
+                        .map(entity -> entityRow.getFieldsMap().get(entity))
+                        .map(this::valueToString)
+                        .collect(Collectors.joining("#"))
+                        .getBytes());
+    }
+
+    @Override
+    public List<List<Feature>> convertRowToFeature(
+            String tableName,
+            List<ByteString> rowKeys,
+            Map<ByteString, Result> rows,
+            List<ServingServiceProto.FeatureReference> featureReferences) {
+        BinaryDecoder reusedDecoder = DecoderFactory.get().binaryDecoder(new byte[0], null);
+
+        return rowKeys.stream()
+                .map(rowKey -> {
+                    if (!rows.containsKey(rowKey)) {
+                        return Collections.<Feature>emptyList();
+                    } else {
+                        Result row = rows.get(rowKey);
+                        return featureReferences.stream()
+                                .map(ServingServiceProto.FeatureReference::getFeatureTable)
+                                .distinct()
+                                .map(cf -> row.getColumnCells(cf.getBytes(), null))
+                                .filter(ls -> !ls.isEmpty())
+                                .flatMap(rowCells -> {
+                                    Cell rowCell = rowCells.get(0); // Latest cell
+                                    String family = Bytes.toString(rowCell.getFamilyArray());
+                                    ByteString value = ByteString.copyFrom(rowCell.getValueArray());
+
+                                    List<Feature> features;
+                                    List<ServingServiceProto.FeatureReference> localFeatureReferences =
+                                            featureReferences.stream()
+                                                    .filter(featureReference ->
+                                                            featureReference.getFeatureTable().equals(family))
+                                                    .collect(Collectors.toList());
+
+                                    try {
+                                        features = decodeFeatures(
+                                                tableName, value, localFeatureReferences,
+                                                reusedDecoder, rowCell.getTimestamp());
+                                    } catch (IOException e) {
+                                        throw new RuntimeException("Failed to decode features from BigTable");
+                                    }
+
+                                    return features.stream();
+                                })
+                                .collect(Collectors.toList());
+                    }
+                })
+                .collect(Collectors.toList());
+    }
+
+    @Override
+    public Map<ByteString, Result> getFeaturesFromSSTable(
+            String tableName, List<ByteString> rowKeys, List<String> columnFamilies) {
+        try {
+            Table table = this.client.getTable(TableName.valueOf(tableName));
+
+            List<Get> getList = new ArrayList<>();
+            for (ByteString rowKey : rowKeys) {
+                Get get = new Get(rowKey.toByteArray());
+                for (String columnFamily : columnFamilies) {
+                    get.addFamily(columnFamily.getBytes());
+                }
+                getList.add(get);
+            }
+
+            Result[] rows = table.get(getList);
+
+            Map<ByteString, Result> result = new HashMap<>();
+            for (Result row : rows) {
+                if (row.isEmpty()) {
+                    continue;
+                }
+                result.put(ByteString.copyFrom(row.getRow()), row);
+            }
+
+            return result;
+        } catch (IOException e) {
+            throw new RuntimeException(e);
+        }
+    }
+
+    private List<Feature> decodeFeatures(
+            String tableName,
+            ByteString value,
+            List<ServingServiceProto.FeatureReference> featureReferences,
+            BinaryDecoder reusedDecoder,
+            long timestamp) throws IOException {
+        ByteString schemaReferenceBytes = value.substring(0, 4);
+        byte[] featureValueBytes = value.substring(4).toByteArray();
+
+        HBaseSchemaRegistry.SchemaReference schemaReference =
+                new HBaseSchemaRegistry.SchemaReference(tableName, schemaReferenceBytes);
+
+        GenericDatumReader<GenericRecord> reader = this.schemaRegistry.getReader(schemaReference);
+
+        reusedDecoder = DecoderFactory.get().binaryDecoder(featureValueBytes, reusedDecoder);
+        GenericRecord record = reader.read(null, reusedDecoder);
+
+        return featureReferences.stream()
+                .map(featureReference -> {
+                    Object featureValue;
+                    try {
+                        featureValue = record.get(featureReference.getName());
+                    } catch (AvroRuntimeException e) {
+                        // Feature is not found in schema
+                        return null;
+                    }
+                    return new AvroFeature(
+                            featureReference,
+                            Timestamp.newBuilder().setSeconds(timestamp / 1000).build(),
+                            Objects.requireNonNullElseGet(featureValue, Object::new));
+                })
+                .filter(Objects::nonNull)
+                .collect(Collectors.toList());
+    }
+}
diff --git a/caraml-store-serving/src/main/java/dev/caraml/serving/store/bigtable/HBaseSchemaRegistry.java b/caraml-store-serving/src/main/java/dev/caraml/serving/store/bigtable/HBaseSchemaRegistry.java
new file mode 100644
index 0000000..d2856e5
--- /dev/null
+++ b/caraml-store-serving/src/main/java/dev/caraml/serving/store/bigtable/HBaseSchemaRegistry.java
@@ -0,0 +1,101 @@
+package dev.caraml.serving.store.bigtable;
+
+import com.google.common.cache.CacheBuilder;
+import com.google.common.cache.CacheLoader;
+import com.google.common.cache.LoadingCache;
+import com.google.protobuf.ByteString;
+
+import java.io.IOException;
+import java.util.concurrent.ExecutionException;
+
+import org.apache.avro.Schema;
+import org.apache.avro.generic.GenericDatumReader;
+import org.apache.avro.generic.GenericRecord;
+import org.apache.hadoop.hbase.Cell;
+import org.apache.hadoop.hbase.TableName;
+import org.apache.hadoop.hbase.client.Connection;
+import org.apache.hadoop.hbase.client.Get;
+import org.apache.hadoop.hbase.client.Result;
+import org.apache.hadoop.hbase.client.Table;
+import org.apache.hadoop.hbase.util.Bytes;
+
+public class HBaseSchemaRegistry {
+    private final Connection hbaseClient;
+    private final LoadingCache<SchemaReference, GenericDatumReader<GenericRecord>> cache;
+
+    private static String COLUMN_FAMILY = "metadata";
+    private static String QUALIFIER = "avro";
+    private static String KEY_PREFIX = "schema#";
+
+    public static class SchemaReference {
+        private final String tableName;
+        private final ByteString schemaHash;
+
+        public SchemaReference(String tableName, ByteString schemaHash) {
+            this.tableName = tableName;
+            this.schemaHash = schemaHash;
+        }
+
+        public String getTableName() {
+            return tableName;
+        }
+
+        public ByteString getSchemaHash() {
+            return schemaHash;
+        }
+
+        @Override
+        public int hashCode() {
+            int result = tableName.hashCode();
+            result = 31 * result + schemaHash.hashCode();
+            return result;
+        }
+
+        @Override
+        public boolean equals(Object o) {
+            if (this == o) return true;
+            if (o == null || getClass() != o.getClass()) return false;
+
+            SchemaReference that = (SchemaReference) o;
+
+            if (!tableName.equals(that.tableName)) return false;
+            return schemaHash.equals(that.schemaHash);
+        }
+    }
+
+    public HBaseSchemaRegistry(Connection hbaseClient) {
+        this.hbaseClient = hbaseClient;
+
+        CacheLoader<SchemaReference, GenericDatumReader<GenericRecord>> schemaCacheLoader =
+                CacheLoader.from(this::loadReader);
+
+        cache = CacheBuilder.newBuilder().build(schemaCacheLoader);
+    }
+
+    public GenericDatumReader<GenericRecord> getReader(SchemaReference reference) {
+        GenericDatumReader<GenericRecord> reader;
+        try {
+            reader = this.cache.get(reference);
+        } catch (ExecutionException | CacheLoader.InvalidCacheLoadException e) {
+            throw new RuntimeException(String.format("Unable to find Schema"), e);
+        }
+        return reader;
+    }
+
+    private GenericDatumReader<GenericRecord> loadReader(SchemaReference reference) {
+        try {
+            Table table = this.hbaseClient.getTable(TableName.valueOf(reference.getTableName()));
+
+            byte[] rowKey =
+                    ByteString.copyFrom(KEY_PREFIX.getBytes())
+                            .concat(reference.getSchemaHash())
+                            .toByteArray();
+            Get query = new Get(rowKey);
+            query.addColumn(COLUMN_FAMILY.getBytes(), QUALIFIER.getBytes());
+
+            Result result = table.get(query);
+
+            Cell last = result.getColumnLatestCell(COLUMN_FAMILY.getBytes(), QUALIFIER.getBytes());
+            Schema schema = new Schema.Parser().parse(Bytes.toString(last.getValueArray()));
+            return new GenericDatumReader<>(schema);
+        } catch (IOException e) {
+            throw new RuntimeException(e);
+        }
+    }
+}
diff --git a/caraml-store-serving/src/main/resources/application.yaml b/caraml-store-serving/src/main/resources/application.yaml
index 5b01830..20cb4cd 100644
--- a/caraml-store-serving/src/main/resources/application.yaml
+++ b/caraml-store-serving/src/main/resources/application.yaml
@@ -50,24 +50,24 @@ caraml:
       readFrom: MASTER
       # Redis operation timeout in ISO-8601 format
       timeout: PT0.5S
-# # Uncomment to customize netty behaviour
-# tcp:
-#   # Epoll Channel Option: TCP_KEEPIDLE
-#   keepIdle: 15
-#   # Epoll Channel Option: TCP_KEEPINTVL
-#   keepInterval: 5
-#   # Epoll Channel Option: TCP_KEEPCNT
-#   keepConnection: 3
-#   # Epoll Channel Option: TCP_USER_TIMEOUT
-#   userConnection: 60000
-# # Uncomment to customize redis cluster topology refresh config
-# topologyRefresh:
-#   # enable adaptive topology refresh from all triggers : MOVED_REDIRECT, ASK_REDIRECT, PERSISTENT_RECONNECTS, UNKNOWN_NODE (since 5.1), and UNCOVERED_SLOT (since 5.2) (see also reconnect attempts for the reconnect trigger)
-#   enableAllAdaptiveTriggerRefresh: true
-#   # enable periodic refresh
-#   enablePeriodicRefresh: false
-#   # topology refresh period in seconds
-#   refreshPeriodSecond: 30
+      # # Uncomment to customize netty behaviour
+      # tcp:
+      #   # Epoll Channel Option: TCP_KEEPIDLE
+      #   keepIdle: 15
+      #   # Epoll Channel Option: TCP_KEEPINTVL
+      #   keepInterval: 5
+      #   # Epoll Channel Option: TCP_KEEPCNT
+      #   keepConnection: 3
+      #   # Epoll Channel Option: TCP_USER_TIMEOUT
+      #   userConnection: 60000
+      # # Uncomment to customize redis cluster topology refresh config
+      # topologyRefresh:
+      #   # enable adaptive topology refresh from all triggers : MOVED_REDIRECT, ASK_REDIRECT, PERSISTENT_RECONNECTS, UNKNOWN_NODE (since 5.1), and UNCOVERED_SLOT (since 5.2) (see also reconnect attempts for the reconnect trigger)
+      #   enableAllAdaptiveTriggerRefresh: true
+      #   # enable periodic refresh
+      #   enablePeriodicRefresh: false
+      #   # topology refresh period in seconds
+      #   refreshPeriodSecond: 30
 
     bigtable:
       projectId: gcp-project-name
@@ -76,6 +76,7 @@ caraml:
       enableClientSideMetrics: false
       # Timeout configuration for BigTable client. Set 0 to use the default client configuration.
       timeoutMs: 0
+      isUsingHBaseSDK: true
 
 grpc:
   server:
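Note on the storage layout assumed by the new retriever (derived from the code in this diff, not from separate documentation): row keys are the entity values joined with "#", ordered by sorted entity name; each feature-table column family stores a cell whose value is a 4-byte schema reference followed by the Avro-encoded feature record; the Avro schema itself is resolved from a "schema#" + hash row in the "metadata" column family, "avro" qualifier. The sketch below is an illustration only, not part of the change; the entity names, values, and schemaHash/avroPayload bytes are hypothetical placeholders.

// Illustration of the row-key and cell-value framing that HBaseOnlineRetriever
// and HBaseSchemaRegistry expect. Placeholder data only.
import com.google.protobuf.ByteString;
import java.nio.charset.StandardCharsets;
import java.util.Map;
import java.util.stream.Collectors;

public class HBaseLayoutSketch {
    public static void main(String[] args) {
        // Row key: entity values joined with '#', ordered by sorted entity name
        // (mirrors convertEntityValueToKey). Example entities are hypothetical.
        Map<String, String> entityRow = Map.of("driver_id", "7", "merchant_id", "42");
        ByteString rowKey = ByteString.copyFrom(
                entityRow.keySet().stream()
                        .sorted()
                        .map(entityRow::get)
                        .collect(Collectors.joining("#"))
                        .getBytes(StandardCharsets.UTF_8));

        // Cell value: 4-byte schema reference followed by the Avro-encoded record
        // (mirrors decodeFeatures). Placeholder bytes stand in for a real payload.
        byte[] schemaHash = new byte[] {0x01, 0x02, 0x03, 0x04};
        byte[] avroPayload = new byte[] {};
        ByteString cellValue = ByteString.copyFrom(schemaHash).concat(ByteString.copyFrom(avroPayload));

        // The retriever splits the value the same way decodeFeatures() does...
        ByteString schemaReferenceBytes = cellValue.substring(0, 4);
        byte[] featureValueBytes = cellValue.substring(4).toByteArray();

        // ...and HBaseSchemaRegistry looks up the schema under the row
        // "schema#" + schemaHash, column family "metadata", qualifier "avro".
        byte[] schemaRowKey = ByteString.copyFrom("schema#".getBytes(StandardCharsets.UTF_8))
                .concat(schemaReferenceBytes)
                .toByteArray();

        System.out.println("feature row key: " + rowKey.toStringUtf8());
        System.out.println("schema row key bytes: " + schemaRowKey.length);
        System.out.println("avro payload bytes: " + featureValueBytes.length);
    }
}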