diff --git a/caraml-store-serving/build.gradle b/caraml-store-serving/build.gradle
index ccdf03f..0d04182 100644
--- a/caraml-store-serving/build.gradle
+++ b/caraml-store-serving/build.gradle
@@ -8,7 +8,8 @@ dependencies {
     implementation 'org.apache.commons:commons-lang3:3.10'
     implementation 'org.apache.avro:avro:1.10.2'
     implementation platform('com.google.cloud:libraries-bom:26.43.0')
-    implementation 'com.google.cloud:google-cloud-bigtable:2.40.0'
+    implementation 'com.google.cloud:google-cloud-bigtable:2.39.2'
+    implementation 'com.google.cloud.bigtable:bigtable-hbase-2.x:2.14.3'
     implementation 'commons-codec:commons-codec:1.17.1'
     implementation 'io.lettuce:lettuce-core:6.2.0.RELEASE'
     implementation 'io.netty:netty-transport-native-epoll:4.1.52.Final:linux-x86_64'
diff --git a/caraml-store-serving/src/main/java/dev/caraml/serving/store/bigtable/BaseSchemaRegistry.java b/caraml-store-serving/src/main/java/dev/caraml/serving/store/bigtable/BaseSchemaRegistry.java
new file mode 100644
index 0000000..60f400d
--- /dev/null
+++ b/caraml-store-serving/src/main/java/dev/caraml/serving/store/bigtable/BaseSchemaRegistry.java
@@ -0,0 +1,65 @@
+package dev.caraml.serving.store.bigtable;
+
+import com.google.common.cache.CacheLoader;
+import com.google.common.cache.LoadingCache;
+import com.google.protobuf.ByteString;
+import java.util.concurrent.ExecutionException;
+import org.apache.avro.generic.GenericDatumReader;
+import org.apache.avro.generic.GenericRecord;
+
+public abstract class BaseSchemaRegistry {
+  protected LoadingCache<SchemaReference, GenericDatumReader<GenericRecord>> cache = null;
+
+  protected static String COLUMN_FAMILY = "metadata";
+  protected static String QUALIFIER = "avro";
+  protected static String KEY_PREFIX = "schema#";
+  public static final int SCHEMA_REFERENCE_LENGTH = 4;
+
+  public static class SchemaReference {
+    private final String tableName;
+    private final ByteString schemaHash;
+
+    public SchemaReference(String tableName, ByteString schemaHash) {
+      this.tableName = tableName;
+      this.schemaHash = schemaHash;
+    }
+
+    public String getTableName() {
+      return tableName;
+    }
+
+    public ByteString getSchemaHash() {
+      return schemaHash;
+    }
+
+    @Override
+    public int hashCode() {
+      int result = tableName.hashCode();
+      result = 31 * result + schemaHash.hashCode();
+      return result;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+      if (this == o) return true;
+      if (o == null || getClass() != o.getClass()) return false;
+
+      SchemaReference that = (SchemaReference) o;
+
+      if (!tableName.equals(that.tableName)) return false;
+      return schemaHash.equals(that.schemaHash);
+    }
+  }
+
+  public GenericDatumReader<GenericRecord> getReader(SchemaReference reference) {
+    GenericDatumReader<GenericRecord> reader;
+    try {
+      reader = this.cache.get(reference);
+    } catch (ExecutionException | CacheLoader.InvalidCacheLoadException e) {
+      throw new RuntimeException("Unable to find Schema", e);
+    }
+    return reader;
+  }
+
+  public abstract GenericDatumReader<GenericRecord> loadReader(SchemaReference reference);
+}
diff --git a/caraml-store-serving/src/main/java/dev/caraml/serving/store/bigtable/BigTableOnlineRetriever.java b/caraml-store-serving/src/main/java/dev/caraml/serving/store/bigtable/BigTableOnlineRetriever.java
index c10009b..921784b 100644
--- a/caraml-store-serving/src/main/java/dev/caraml/serving/store/bigtable/BigTableOnlineRetriever.java
+++ b/caraml-store-serving/src/main/java/dev/caraml/serving/store/bigtable/BigTableOnlineRetriever.java
@@ -161,8 +161,10 @@ private List<Feature> decodeFeatures(
       BinaryDecoder reusedDecoder,
       long timestamp)
       throws IOException {
-    ByteString schemaReferenceBytes = value.substring(0, 4);
-    byte[] featureValueBytes = value.substring(4).toByteArray();
+    ByteString schemaReferenceBytes =
+        value.substring(0, BigTableSchemaRegistry.SCHEMA_REFERENCE_LENGTH);
+    byte[] featureValueBytes =
+        value.substring(BigTableSchemaRegistry.SCHEMA_REFERENCE_LENGTH).toByteArray();
 
     BigTableSchemaRegistry.SchemaReference schemaReference =
         new BigTableSchemaRegistry.SchemaReference(tableName, schemaReferenceBytes);
 
diff --git a/caraml-store-serving/src/main/java/dev/caraml/serving/store/bigtable/BigTableSchemaRegistry.java b/caraml-store-serving/src/main/java/dev/caraml/serving/store/bigtable/BigTableSchemaRegistry.java
index d37ae0e..054284b 100644
--- a/caraml-store-serving/src/main/java/dev/caraml/serving/store/bigtable/BigTableSchemaRegistry.java
+++ b/caraml-store-serving/src/main/java/dev/caraml/serving/store/bigtable/BigTableSchemaRegistry.java
@@ -6,57 +6,14 @@
 import com.google.cloud.bigtable.data.v2.models.RowCell;
 import com.google.common.cache.CacheBuilder;
 import com.google.common.cache.CacheLoader;
-import com.google.common.cache.LoadingCache;
 import com.google.common.collect.Iterables;
 import com.google.protobuf.ByteString;
-import java.util.concurrent.ExecutionException;
 import org.apache.avro.Schema;
 import org.apache.avro.generic.GenericDatumReader;
 import org.apache.avro.generic.GenericRecord;
 
-public class BigTableSchemaRegistry {
+public class BigTableSchemaRegistry extends BaseSchemaRegistry {
   private final BigtableDataClient client;
-  private final LoadingCache<SchemaReference, GenericDatumReader<GenericRecord>> cache;
-
-  private static String COLUMN_FAMILY = "metadata";
-  private static String QUALIFIER = "avro";
-  private static String KEY_PREFIX = "schema#";
-
-  public static class SchemaReference {
-    private final String tableName;
-    private final ByteString schemaHash;
-
-    public SchemaReference(String tableName, ByteString schemaHash) {
-      this.tableName = tableName;
-      this.schemaHash = schemaHash;
-    }
-
-    public String getTableName() {
-      return tableName;
-    }
-
-    public ByteString getSchemaHash() {
-      return schemaHash;
-    }
-
-    @Override
-    public int hashCode() {
-      int result = tableName.hashCode();
-      result = 31 * result + schemaHash.hashCode();
-      return result;
-    }
-
-    @Override
-    public boolean equals(Object o) {
-      if (this == o) return true;
-      if (o == null || getClass() != o.getClass()) return false;
-
-      SchemaReference that = (SchemaReference) o;
-
-      if (!tableName.equals(that.tableName)) return false;
-      return schemaHash.equals(that.schemaHash);
-    }
-  }
 
   public BigTableSchemaRegistry(BigtableDataClient client) {
     this.client = client;
@@ -67,17 +24,8 @@ public BigTableSchemaRegistry(BigtableDataClient client) {
     cache = CacheBuilder.newBuilder().build(schemaCacheLoader);
   }
 
-  public GenericDatumReader<GenericRecord> getReader(SchemaReference reference) {
-    GenericDatumReader<GenericRecord> reader;
-    try {
-      reader = this.cache.get(reference);
-    } catch (ExecutionException | CacheLoader.InvalidCacheLoadException e) {
-      throw new RuntimeException(String.format("Unable to find Schema"), e);
-    }
-    return reader;
-  }
-
-  private GenericDatumReader<GenericRecord> loadReader(SchemaReference reference) {
+  @Override
+  public GenericDatumReader<GenericRecord> loadReader(SchemaReference reference) {
     Row row =
         client.readRow(
             reference.getTableName(),
diff --git a/caraml-store-serving/src/main/java/dev/caraml/serving/store/bigtable/BigTableStoreConfig.java b/caraml-store-serving/src/main/java/dev/caraml/serving/store/bigtable/BigTableStoreConfig.java
index 9c0609a..bd80774 100644
--- a/caraml-store-serving/src/main/java/dev/caraml/serving/store/bigtable/BigTableStoreConfig.java
+++ b/caraml-store-serving/src/main/java/dev/caraml/serving/store/bigtable/BigTableStoreConfig.java
@@ -2,10 +2,13 @@
 import com.google.cloud.bigtable.data.v2.BigtableDataClient;
 import com.google.cloud.bigtable.data.v2.BigtableDataSettings;
+import com.google.cloud.bigtable.hbase.BigtableConfiguration;
+import com.google.cloud.bigtable.hbase.BigtableOptionsFactory;
 import dev.caraml.serving.store.OnlineRetriever;
 import java.io.IOException;
 import lombok.Getter;
 import lombok.Setter;
+import org.apache.hadoop.hbase.client.Connection;
 import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
 import org.springframework.boot.context.properties.ConfigurationProperties;
 import org.springframework.context.annotation.Bean;
@@ -23,10 +26,22 @@ public class BigTableStoreConfig {
   private String appProfileId;
   private Boolean enableClientSideMetrics;
   private Long timeoutMs;
+  private Boolean isUsingHBaseSDK;
 
   @Bean
   public OnlineRetriever getRetriever() {
     try {
+      // Using HBase SDK
+      if (isUsingHBaseSDK) {
+        org.apache.hadoop.conf.Configuration config =
+            BigtableConfiguration.configure(projectId, instanceId);
+        config.set(BigtableOptionsFactory.APP_PROFILE_ID_KEY, appProfileId);
+
+        Connection connection = BigtableConfiguration.connect(config);
+        return new HBaseOnlineRetriever(connection);
+      }
+
+      // Using BigTable SDK
       BigtableDataSettings.Builder builder =
           BigtableDataSettings.newBuilder()
               .setProjectId(projectId)
@@ -45,6 +60,7 @@ public OnlineRetriever getRetriever() {
       }
       BigtableDataClient client = BigtableDataClient.create(settings);
       return new BigTableOnlineRetriever(client);
+
     } catch (IOException e) {
       throw new RuntimeException(e);
     }
diff --git a/caraml-store-serving/src/main/java/dev/caraml/serving/store/bigtable/HBaseOnlineRetriever.java b/caraml-store-serving/src/main/java/dev/caraml/serving/store/bigtable/HBaseOnlineRetriever.java
new file mode 100644
index 0000000..3dbcc12
--- /dev/null
+++ b/caraml-store-serving/src/main/java/dev/caraml/serving/store/bigtable/HBaseOnlineRetriever.java
@@ -0,0 +1,222 @@
+package dev.caraml.serving.store.bigtable;
+
+import com.google.protobuf.ByteString;
+import com.google.protobuf.Timestamp;
+import dev.caraml.serving.store.AvroFeature;
+import dev.caraml.serving.store.Feature;
+import dev.caraml.store.protobuf.serving.ServingServiceProto;
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.*;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+import org.apache.avro.AvroRuntimeException;
+import org.apache.avro.generic.GenericDatumReader;
+import org.apache.avro.generic.GenericRecord;
+import org.apache.avro.io.BinaryDecoder;
+import org.apache.avro.io.DecoderFactory;
+import org.apache.hadoop.hbase.Cell;
+import org.apache.hadoop.hbase.TableName;
+import org.apache.hadoop.hbase.client.*;
+
+public class HBaseOnlineRetriever implements SSTableOnlineRetriever<ByteString, Result> {
+  private final Connection client;
+  private final HBaseSchemaRegistry schemaRegistry;
+
+  public HBaseOnlineRetriever(Connection client) {
+    this.client = client;
+    this.schemaRegistry = new HBaseSchemaRegistry(client);
+  }
+
+  /**
+   * Generate HBase key in the form of entity values joined by #.
+   *
+   * @param entityRow Single EntityRow representation in feature retrieval call
+   * @param entityNames List of entities related to feature references in retrieval call
+   * @return HBase row key for the entity
+   */
+  @Override
+  public ByteString convertEntityValueToKey(
+      ServingServiceProto.GetOnlineFeaturesRequest.EntityRow entityRow, List<String> entityNames) {
+    return ByteString.copyFrom(
+        entityNames.stream()
+            .sorted()
+            .map(entity -> entityRow.getFieldsMap().get(entity))
+            .map(this::valueToString)
+            .collect(Collectors.joining("#"))
+            .getBytes());
+  }
+
+  /**
+   * Converts rowCell features into @NativeFeature type, HBase specific implementation.
+   *
+   * @param tableName Name of SSTable
+   * @param rowKeys List of keys of rows to retrieve
+   * @param rows Map of rowKey to Row related to it
+   * @param featureReferences List of feature references
+   * @return List of List of Features associated with respective rowKey
+   */
+  @Override
+  public List<List<Feature>> convertRowToFeature(
+      String tableName,
+      List<ByteString> rowKeys,
+      Map<ByteString, Result> rows,
+      List<ServingServiceProto.FeatureReference> featureReferences) {
+    BinaryDecoder reusedDecoder = DecoderFactory.get().binaryDecoder(new byte[0], null);
+
+    return rowKeys.stream()
+        .map(
+            rowKey -> {
+              if (!rows.containsKey(rowKey)) {
+                return Collections.<Feature>emptyList();
+              }
+
+              Result row = rows.get(rowKey);
+              return featureReferences.stream()
+                  .map(ServingServiceProto.FeatureReference::getFeatureTable)
+                  .distinct()
+                  .map(cf -> row.getColumnCells(cf.getBytes(), null))
+                  .filter(ls -> !ls.isEmpty())
+                  .flatMap(
+                      rowCells ->
+                          this.convertRowCellsToFeatures(
+                              featureReferences, reusedDecoder, tableName, rowCells))
+                  .collect(Collectors.toList());
+            })
+        .collect(Collectors.toList());
+  }
+
+  /**
+   * Converts rowCells features into a stream of @NativeFeature type.
+   *
+   * @param featureReferences List of feature references
+   * @param reusedDecoder Decoder for decoding feature values
+   * @param tableName Name of SSTable
+   * @param rowCells row cells data from SSTable
+   * @return Stream of @NativeFeature
+   * @throws RuntimeException failed to decode features
+   */
+  private Stream<Feature> convertRowCellsToFeatures(
+      List<ServingServiceProto.FeatureReference> featureReferences,
+      BinaryDecoder reusedDecoder,
+      String tableName,
+      List<Cell> rowCells) {
+
+    Cell rowCell = rowCells.get(0); // Latest cell
+    ByteBuffer valueBuffer = HBaseSchemaRegistry.GetValueByteBufferFromRowCell(rowCell);
+    ByteBuffer familyBuffer =
+        ByteBuffer.wrap(rowCell.getFamilyArray())
+            .position(rowCell.getFamilyOffset())
+            .limit(rowCell.getFamilyOffset() + rowCell.getFamilyLength())
+            .slice();
+    String family = ByteString.copyFrom(familyBuffer).toStringUtf8();
+    ByteString value = ByteString.copyFrom(valueBuffer);
+
+    List<Feature> features;
+    List<ServingServiceProto.FeatureReference> localFeatureReferences =
+        featureReferences.stream()
+            .filter(featureReference -> featureReference.getFeatureTable().equals(family))
+            .collect(Collectors.toList());
+
+    try {
+      features =
+          decodeFeatures(
+              tableName, value, localFeatureReferences, reusedDecoder, rowCell.getTimestamp());
+    } catch (IOException e) {
+      throw new RuntimeException("Failed to decode features from HBase", e);
+    }
+
+    return features.stream();
+  }
+
+  /**
+   * Retrieve rows with required column families for each row entity by sending batch Get request,
+   * HBase specific implementation.
+   *
+   * @param tableName Name of SSTable
+   * @param rowKeys List of keys of rows to retrieve
+   * @param columnFamilies List of column families
+   * @return Map of retrieved rows keyed by row key
+   */
+  @Override
+  public Map<ByteString, Result> getFeaturesFromSSTable(
+      String tableName, List<ByteString> rowKeys, List<String> columnFamilies) {
+    try {
+      Table table = this.client.getTable(TableName.valueOf(tableName));
+
+      // construct query get list
+      List<Get> queryGetList = new ArrayList<>();
+      rowKeys.forEach(
+          rowKey -> {
+            Get get = new Get(rowKey.toByteArray());
+            columnFamilies.forEach(cf -> get.addFamily(cf.getBytes()));
+
+            queryGetList.add(get);
+          });
+
+      // fetch data from table
+      Result[] rows = table.get(queryGetList);
+
+      // construct result
+      Map<ByteString, Result> result = new HashMap<>();
+      Arrays.stream(rows)
+          .filter(row -> !row.isEmpty())
+          .forEach(row -> result.put(ByteString.copyFrom(row.getRow()), row));
+
+      return result;
+    } catch (IOException e) {
+      throw new RuntimeException(e);
+    }
+  }
+
+  /**
+   * Decode features from Avro serialized bytes.
+   *
+   * @param tableName Name of HBase table
+   * @param value Value of HBase cell where the first 4 bytes represent the schema reference and
+   *     the remaining bytes represent the avro-serialized features
+   * @param featureReferences List of feature references
+   * @param reusedDecoder Decoder for decoding feature values
+   * @param timestamp Timestamp of rowcell
+   * @return @NativeFeature with retrieved value stored in HBase cell
+   * @throws IOException if the Avro payload cannot be read
+   */
+  private List<Feature> decodeFeatures(
+      String tableName,
+      ByteString value,
+      List<ServingServiceProto.FeatureReference> featureReferences,
+      BinaryDecoder reusedDecoder,
+      long timestamp)
+      throws IOException {
+    ByteString schemaReferenceBytes =
+        value.substring(0, HBaseSchemaRegistry.SCHEMA_REFERENCE_LENGTH);
+    byte[] featureValueBytes =
+        value.substring(HBaseSchemaRegistry.SCHEMA_REFERENCE_LENGTH).toByteArray();
+
+    HBaseSchemaRegistry.SchemaReference schemaReference =
+        new HBaseSchemaRegistry.SchemaReference(tableName, schemaReferenceBytes);
+
+    GenericDatumReader<GenericRecord> reader = this.schemaRegistry.getReader(schemaReference);
+
+    reusedDecoder = DecoderFactory.get().binaryDecoder(featureValueBytes, reusedDecoder);
+    GenericRecord record = reader.read(null, reusedDecoder);
+
+    return featureReferences.stream()
+        .map(
+            featureReference -> {
+              Object featureValue;
+              try {
+                featureValue = record.get(featureReference.getName());
+              } catch (AvroRuntimeException e) {
+                // Feature is not found in schema
+                return null;
+              }
+              return new AvroFeature(
+                  featureReference,
+                  Timestamp.newBuilder().setSeconds(timestamp / 1000).build(),
+                  Objects.requireNonNullElseGet(featureValue, Object::new));
+            })
+        .filter(Objects::nonNull)
+        .collect(Collectors.toList());
+  }
+}
diff --git a/caraml-store-serving/src/main/java/dev/caraml/serving/store/bigtable/HBaseSchemaRegistry.java b/caraml-store-serving/src/main/java/dev/caraml/serving/store/bigtable/HBaseSchemaRegistry.java
new file mode 100644
index 0000000..8fc4612
--- /dev/null
+++ b/caraml-store-serving/src/main/java/dev/caraml/serving/store/bigtable/HBaseSchemaRegistry.java
@@ -0,0 +1,63 @@
+package dev.caraml.serving.store.bigtable;
+
+import com.google.common.cache.CacheBuilder;
+import com.google.common.cache.CacheLoader;
+import com.google.protobuf.ByteString;
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import org.apache.avro.Schema;
+import org.apache.avro.generic.GenericDatumReader;
+import org.apache.avro.generic.GenericRecord;
+import org.apache.hadoop.hbase.Cell;
+import org.apache.hadoop.hbase.TableName;
+import org.apache.hadoop.hbase.client.Connection;
+import org.apache.hadoop.hbase.client.Get;
+import org.apache.hadoop.hbase.client.Result;
+import org.apache.hadoop.hbase.client.Table;
+
+public class HBaseSchemaRegistry extends BaseSchemaRegistry {
+  private final Connection hbaseClient;
+
+  public HBaseSchemaRegistry(Connection
hbaseClient) { + this.hbaseClient = hbaseClient; + + CacheLoader> schemaCacheLoader = + CacheLoader.from(this::loadReader); + + cache = CacheBuilder.newBuilder().build(schemaCacheLoader); + } + + @Override + public GenericDatumReader loadReader(SchemaReference reference) { + try { + Table table = this.hbaseClient.getTable(TableName.valueOf(reference.getTableName())); + + byte[] rowKey = + ByteString.copyFrom(KEY_PREFIX.getBytes()) + .concat(reference.getSchemaHash()) + .toByteArray(); + Get query = new Get(rowKey); + query.addColumn(COLUMN_FAMILY.getBytes(), QUALIFIER.getBytes()); + + Result result = table.get(query); + + Cell last = result.getColumnLatestCell(COLUMN_FAMILY.getBytes(), QUALIFIER.getBytes()); + if (last == null) { + // NOTE: this should never happen + throw new RuntimeException("Schema not found"); + } + ByteBuffer schemaBuffer = GetValueByteBufferFromRowCell(last); + Schema schema = new Schema.Parser().parse(ByteString.copyFrom(schemaBuffer).toStringUtf8()); + return new GenericDatumReader<>(schema); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + public static ByteBuffer GetValueByteBufferFromRowCell(Cell cell) { + return ByteBuffer.wrap(cell.getValueArray()) + .position(cell.getValueOffset()) + .limit(cell.getValueOffset() + cell.getValueLength()) + .slice(); + } +} diff --git a/caraml-store-serving/src/main/java/dev/caraml/serving/store/bigtable/HBaseStoreConfig.java b/caraml-store-serving/src/main/java/dev/caraml/serving/store/bigtable/HBaseStoreConfig.java new file mode 100644 index 0000000..d36203c --- /dev/null +++ b/caraml-store-serving/src/main/java/dev/caraml/serving/store/bigtable/HBaseStoreConfig.java @@ -0,0 +1,39 @@ +package dev.caraml.serving.store.bigtable; + +import dev.caraml.serving.store.OnlineRetriever; +import java.io.IOException; +import lombok.Getter; +import lombok.Setter; +import org.apache.hadoop.hbase.HBaseConfiguration; +import org.apache.hadoop.hbase.client.Connection; +import org.apache.hadoop.hbase.client.ConnectionFactory; +import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; +import org.springframework.boot.context.properties.ConfigurationProperties; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; + +@Configuration +@ConfigurationProperties(prefix = "caraml.store.hbase") +@ConditionalOnProperty(prefix = "caraml.store", name = "active", havingValue = "hbase") +@Getter +@Setter +public class HBaseStoreConfig { + private String zookeeperQuorum; + private String zookeeperClientPort; + + @Bean + public OnlineRetriever getRetriever() { + org.apache.hadoop.conf.Configuration conf; + conf = HBaseConfiguration.create(); + conf.set("hbase.zookeeper.quorum", zookeeperQuorum); + conf.set("hbase.zookeeper.property.clientPort", zookeeperClientPort); + Connection connection; + try { + connection = ConnectionFactory.createConnection(conf); + } catch (IOException e) { + throw new RuntimeException(e); + } + + return new HBaseOnlineRetriever(connection); + } +} diff --git a/caraml-store-serving/src/main/resources/application.yaml b/caraml-store-serving/src/main/resources/application.yaml index 5b01830..e15d3a4 100644 --- a/caraml-store-serving/src/main/resources/application.yaml +++ b/caraml-store-serving/src/main/resources/application.yaml @@ -33,41 +33,41 @@ caraml: maxExpectedCount: 150 store: - # Active store. Possible values: [redisCluster, redis, bigtable] + # Active store. 
Possible values: [redisCluster, redis, bigtable, hbase] active: redis - - redis: - host: localhost - port: 6379 - password: "" - ssl: false - - redisCluster: - # Connection string specifies the host:port of Redis instances in the redis cluster. - connectionString: "localhost:7000,localhost:7001,localhost:7002,localhost:7003,localhost:7004,localhost:7005" - # Password authentication. Empty string if password is not set. - password: "" - readFrom: MASTER - # Redis operation timeout in ISO-8601 format - timeout: PT0.5S -# # Uncomment to customize netty behaviour -# tcp: -# # Epoll Channel Option: TCP_KEEPIDLE -# keepIdle: 15 -# # Epoll Channel Option: TCP_KEEPINTVL -# keepInterval: 5 -# # Epoll Channel Option: TCP_KEEPCNT -# keepConnection: 3 -# # Epoll Channel Option: TCP_USER_TIMEOUT -# userConnection: 60000 -# # Uncomment to customize redis cluster topology refresh config -# topologyRefresh: -# # enable adaptive topology refresh from all triggers : MOVED_REDIRECT, ASK_REDIRECT, PERSISTENT_RECONNECTS, UNKNOWN_NODE (since 5.1), and UNCOVERED_SLOT (since 5.2) (see also reconnect attempts for the reconnect trigger) -# enableAllAdaptiveTriggerRefresh: true -# # enable periodic refresh -# enablePeriodicRefresh: false -# # topology refresh period in seconds -# refreshPeriodSecond: 30 + # + # redis: + # host: localhost + # port: 6379 + # password: "" + # ssl: false + # + # redisCluster: + # # Connection string specifies the host:port of Redis instances in the redis cluster. + # connectionString: "localhost:7000,localhost:7001,localhost:7002,localhost:7003,localhost:7004,localhost:7005" + # # Password authentication. Empty string if password is not set. + # password: "" + # readFrom: MASTER + # # Redis operation timeout in ISO-8601 format + # timeout: PT0.5S + # # Uncomment to customize netty behaviour + # tcp: + # # Epoll Channel Option: TCP_KEEPIDLE + # keepIdle: 15 + # # Epoll Channel Option: TCP_KEEPINTVL + # keepInterval: 5 + # # Epoll Channel Option: TCP_KEEPCNT + # keepConnection: 3 + # # Epoll Channel Option: TCP_USER_TIMEOUT + # userConnection: 60000 + # # Uncomment to customize redis cluster topology refresh config + # topologyRefresh: + # # enable adaptive topology refresh from all triggers : MOVED_REDIRECT, ASK_REDIRECT, PERSISTENT_RECONNECTS, UNKNOWN_NODE (since 5.1), and UNCOVERED_SLOT (since 5.2) (see also reconnect attempts for the reconnect trigger) + # enableAllAdaptiveTriggerRefresh: true + # # enable periodic refresh + # enablePeriodicRefresh: false + # # topology refresh period in seconds + # refreshPeriodSecond: 30 bigtable: projectId: gcp-project-name @@ -76,6 +76,11 @@ caraml: enableClientSideMetrics: false # Timeout configuration for BigTable client. Set 0 to use the default client configuration. 
timeoutMs: 0 + isUsingHBaseSDK: true + + hbase: + zookeeperQuorum: 127.0.0.1 + zookeeperClientPort: 2181 grpc: server: @@ -95,4 +100,4 @@ spring: logging: level: - root: "info" \ No newline at end of file + root: "info" diff --git a/caraml-store-serving/src/test/java/dev/caraml/serving/store/bigtable/BigTableOnlineRetrieverTest.java b/caraml-store-serving/src/test/java/dev/caraml/serving/store/bigtable/BigTableOnlineRetrieverTest.java index 7b51c5f..0e9fabe 100644 --- a/caraml-store-serving/src/test/java/dev/caraml/serving/store/bigtable/BigTableOnlineRetrieverTest.java +++ b/caraml-store-serving/src/test/java/dev/caraml/serving/store/bigtable/BigTableOnlineRetrieverTest.java @@ -8,6 +8,8 @@ import com.google.cloud.bigtable.data.v2.BigtableDataClient; import com.google.cloud.bigtable.data.v2.BigtableDataSettings; import com.google.cloud.bigtable.data.v2.models.RowMutation; +import com.google.cloud.bigtable.hbase.BigtableConfiguration; +import com.google.cloud.bigtable.hbase.BigtableOptionsFactory; import com.google.common.hash.Hashing; import com.google.protobuf.ByteString; import dev.caraml.serving.store.Feature; @@ -26,6 +28,8 @@ import org.apache.avro.generic.GenericRecordBuilder; import org.apache.avro.io.Encoder; import org.apache.avro.io.EncoderFactory; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.hbase.client.Connection; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; import org.testcontainers.containers.GenericContainer; @@ -40,6 +44,7 @@ public class BigTableOnlineRetrieverTest { static final Integer BIGTABLE_EMULATOR_PORT = 8086; static final String FEAST_PROJECT = "default"; static BigtableDataClient client; + static Connection hbaseClient; static BigtableTableAdminClient adminClient; @Container @@ -74,6 +79,11 @@ public static void setup() throws IOException { .setProjectId(PROJECT_ID) .setInstanceId(INSTANCE_ID) .build()); + Configuration config = BigtableConfiguration.configure(PROJECT_ID, INSTANCE_ID); + config.set( + BigtableOptionsFactory.BIGTABLE_EMULATOR_HOST_KEY, + "localhost:" + bigtableEmulator.getMappedPort(BIGTABLE_EMULATOR_PORT)); + hbaseClient = BigtableConfiguration.connect(config); ingestData(); } @@ -227,4 +237,40 @@ public void shouldFilterOutMissingFeatureRef() { assertEquals(1, features.size()); assertEquals(0, features.get(0).size()); } + + @Test + public void shouldRetrieveFeaturesSuccessfullyWhenUsingHbase() { + HBaseOnlineRetriever retriever = new HBaseOnlineRetriever(hbaseClient); + List featureReferences = + Stream.of("trip_cost", "trip_distance") + .map(f -> FeatureReference.newBuilder().setFeatureTable("rides").setName(f).build()) + .toList(); + List entityNames = List.of("driver"); + List entityRows = + List.of(DataGenerator.createEntityRow("driver", DataGenerator.createInt64Value(1), 100)); + List> featuresForRows = + retriever.getOnlineFeatures(FEAST_PROJECT, entityRows, featureReferences, entityNames); + assertEquals(1, featuresForRows.size()); + List features = featuresForRows.get(0); + assertEquals(2, features.size()); + assertEquals(5L, features.get(0).getFeatureValue(ValueType.Enum.INT64).getInt64Val()); + assertEquals(featureReferences.get(0), features.get(0).getFeatureReference()); + assertEquals(3.5, features.get(1).getFeatureValue(ValueType.Enum.DOUBLE).getDoubleVal()); + assertEquals(featureReferences.get(1), features.get(1).getFeatureReference()); + } + + @Test + public void shouldFilterOutMissingFeatureRefUsingHbase() { + HBaseOnlineRetriever retriever = new 
HBaseOnlineRetriever(hbaseClient); + List featureReferences = + List.of( + FeatureReference.newBuilder().setFeatureTable("rides").setName("not_exists").build()); + List entityNames = List.of("driver"); + List entityRows = + List.of(DataGenerator.createEntityRow("driver", DataGenerator.createInt64Value(1), 100)); + List> features = + retriever.getOnlineFeatures(FEAST_PROJECT, entityRows, featureReferences, entityNames); + assertEquals(1, features.size()); + assertEquals(0, features.get(0).size()); + } } diff --git a/caraml-store-serving/src/test/java/dev/caraml/serving/store/bigtable/GenericHbase2Container.java b/caraml-store-serving/src/test/java/dev/caraml/serving/store/bigtable/GenericHbase2Container.java new file mode 100644 index 0000000..58fa2bb --- /dev/null +++ b/caraml-store-serving/src/test/java/dev/caraml/serving/store/bigtable/GenericHbase2Container.java @@ -0,0 +1,43 @@ +package dev.caraml.serving.store.bigtable; + +import java.time.Duration; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.hbase.HBaseConfiguration; +import org.testcontainers.containers.GenericContainer; +import org.testcontainers.containers.wait.strategy.Wait; +import org.testcontainers.utility.DockerImageName; + +public class GenericHbase2Container extends GenericContainer { + + private final String hostName = "hbase-docker"; + public final Configuration hbase2Configuration = HBaseConfiguration.create(); + + public GenericHbase2Container() { + super(DockerImageName.parse("dajobe/hbase:latest")); + withCreateContainerCmdModifier( + cmd -> { + cmd.withHostName(hostName); + }); + + withNetworkMode("host"); + withEnv("HBASE_DOCKER_HOSTNAME", "127.0.0.1"); + + waitingFor(Wait.forLogMessage(".*master.HMaster: Master has completed initialization.*", 1)); + withStartupTimeout(Duration.ofMinutes(10)); + } + + @Override + protected void doStart() { + super.doStart(); + + hbase2Configuration.set("hbase.client.pause", "200"); + hbase2Configuration.set("hbase.client.retries.number", "10"); + hbase2Configuration.set("hbase.rpc.timeout", "3000"); + hbase2Configuration.set("hbase.client.operation.timeout", "3000"); + hbase2Configuration.set("hbase.rpc.timeout", "3000"); + hbase2Configuration.set("hbase.client.scanner.timeout.period", "10000"); + hbase2Configuration.set("zookeeper.session.timeout", "10000"); + hbase2Configuration.set("hbase.zookeeper.quorum", "localhost"); + hbase2Configuration.set("hbase.zookeeper.property.clientPort", "2181"); + } +} diff --git a/caraml-store-serving/src/test/java/dev/caraml/serving/store/bigtable/HbaseOnlineRetrieverTest.java b/caraml-store-serving/src/test/java/dev/caraml/serving/store/bigtable/HbaseOnlineRetrieverTest.java new file mode 100644 index 0000000..45fff5f --- /dev/null +++ b/caraml-store-serving/src/test/java/dev/caraml/serving/store/bigtable/HbaseOnlineRetrieverTest.java @@ -0,0 +1,207 @@ +package dev.caraml.serving.store.bigtable; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import com.google.common.hash.Hashing; +import dev.caraml.serving.store.Feature; +import dev.caraml.store.protobuf.serving.ServingServiceProto; +import dev.caraml.store.protobuf.types.ValueProto; +import dev.caraml.store.testutils.it.DataGenerator; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.util.List; +import java.util.stream.Stream; +import org.apache.avro.Schema; +import org.apache.avro.SchemaBuilder; +import org.apache.avro.generic.GenericDatumWriter; +import org.apache.avro.generic.GenericRecord; +import 
org.apache.avro.generic.GenericRecordBuilder; +import org.apache.avro.io.Encoder; +import org.apache.avro.io.EncoderFactory; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.hbase.HBaseConfiguration; +import org.apache.hadoop.hbase.TableName; +import org.apache.hadoop.hbase.client.*; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; + +@Testcontainers +public class HbaseOnlineRetrieverTest { + static Connection hbaseClient; + static HBaseAdmin admin; + static Configuration hbaseConfiguration = HBaseConfiguration.create(); + static final String FEAST_PROJECT = "default"; + + @Container public static GenericHbase2Container hbase = new GenericHbase2Container(); + + @BeforeAll + public static void setup() throws IOException { + hbaseClient = ConnectionFactory.createConnection(hbase.hbase2Configuration); + admin = (HBaseAdmin) hbaseClient.getAdmin(); + ingestData(); + } + + private static void ingestData() throws IOException { + String featureTableName = "rides"; + + /** Single Entity Ingestion Workflow */ + Schema schema = + SchemaBuilder.record("DriverData") + .namespace(featureTableName) + .fields() + .requiredLong("trip_cost") + .requiredDouble("trip_distance") + .nullableString("trip_empty", "null") + .requiredString("trip_wrong_type") + .endRecord(); + createTable(FEAST_PROJECT, List.of("driver"), List.of(featureTableName)); + insertSchema(FEAST_PROJECT, List.of("driver"), schema); + + GenericRecord record = + new GenericRecordBuilder(schema) + .set("trip_cost", 5L) + .set("trip_distance", 3.5) + .set("trip_empty", null) + .set("trip_wrong_type", "test") + .build(); + String entityKey = String.valueOf(DataGenerator.createInt64Value(1).getInt64Val()); + insertRow(FEAST_PROJECT, List.of("driver"), entityKey, featureTableName, schema, record); + } + + private static String getTableName(String project, List entityNames) { + return String.format("%s__%s", project, String.join("__", entityNames)); + } + + private static byte[] serializedSchemaReference(Schema schema) { + return Hashing.murmur3_32().hashBytes(schema.toString().getBytes()).asBytes(); + } + + private static void createTable( + String project, List entityNames, List featureTables) { + String tableName = getTableName(project, entityNames); + + List columnFamilies = + Stream.concat(featureTables.stream(), Stream.of("metadata")).toList(); + TableDescriptorBuilder tb = TableDescriptorBuilder.newBuilder(TableName.valueOf(tableName)); + columnFamilies.forEach(cf -> tb.setColumnFamily(ColumnFamilyDescriptorBuilder.of(cf))); + try { + if (admin.tableExists(TableName.valueOf(tableName))) { + return; + } + admin.createTable(tb.build()); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + private static void insertSchema(String project, List entityNames, Schema schema) + throws IOException { + String tableName = getTableName(project, entityNames); + byte[] schemaReference = serializedSchemaReference(schema); + byte[] schemaKey = createSchemaKey(schemaReference); + Table table = hbaseClient.getTable(TableName.valueOf(tableName)); + Put put = new Put(schemaKey); + put.addColumn("metadata".getBytes(), "avro".getBytes(), schema.toString().getBytes()); + table.put(put); + table.close(); + } + + private static byte[] createSchemaKey(byte[] schemaReference) throws IOException { + String schemaKeyPrefix = "schema#"; + ByteArrayOutputStream concatOutputStream = new 
ByteArrayOutputStream(); + concatOutputStream.write(schemaKeyPrefix.getBytes()); + concatOutputStream.write(schemaReference); + return concatOutputStream.toByteArray(); + } + + private static byte[] createEntityValue(Schema schema, GenericRecord record) throws IOException { + byte[] schemaReference = serializedSchemaReference(schema); + // Entity-Feature Row + byte[] avroSerializedFeatures = recordToAvro(record, schema); + + ByteArrayOutputStream concatOutputStream = new ByteArrayOutputStream(); + concatOutputStream.write(schemaReference); + concatOutputStream.write("".getBytes()); + concatOutputStream.write(avroSerializedFeatures); + byte[] entityFeatureValue = concatOutputStream.toByteArray(); + + return entityFeatureValue; + } + + private static byte[] recordToAvro(GenericRecord datum, Schema schema) throws IOException { + GenericDatumWriter writer = new GenericDatumWriter<>(schema); + ByteArrayOutputStream output = new ByteArrayOutputStream(); + Encoder encoder = EncoderFactory.get().binaryEncoder(output, null); + writer.write(datum, encoder); + encoder.flush(); + + return output.toByteArray(); + } + + private static void insertRow( + String project, + List entityNames, + String entityKey, + String featureTableName, + Schema schema, + GenericRecord record) + throws IOException { + byte[] entityFeatureValue = createEntityValue(schema, record); + String tableName = getTableName(project, entityNames); + + // Update Compound Entity-Feature Row + Table table = hbaseClient.getTable(TableName.valueOf(tableName)); + Put put = new Put(entityKey.getBytes()); + put.addColumn(featureTableName.getBytes(), "".getBytes(), entityFeatureValue); + table.put(put); + table.close(); + } + + @Test + public void shouldRetrieveFeaturesSuccessfully() { + HBaseOnlineRetriever retriever = new HBaseOnlineRetriever(hbaseClient); + List featureReferences = + Stream.of("trip_cost", "trip_distance") + .map( + f -> + ServingServiceProto.FeatureReference.newBuilder() + .setFeatureTable("rides") + .setName(f) + .build()) + .toList(); + List entityNames = List.of("driver"); + List entityRows = + List.of(DataGenerator.createEntityRow("driver", DataGenerator.createInt64Value(1), 100)); + List> featuresForRows = + retriever.getOnlineFeatures(FEAST_PROJECT, entityRows, featureReferences, entityNames); + assertEquals(1, featuresForRows.size()); + List features = featuresForRows.get(0); + assertEquals(2, features.size()); + assertEquals( + 5L, features.get(0).getFeatureValue(ValueProto.ValueType.Enum.INT64).getInt64Val()); + assertEquals(featureReferences.get(0), features.get(0).getFeatureReference()); + assertEquals( + 3.5, features.get(1).getFeatureValue(ValueProto.ValueType.Enum.DOUBLE).getDoubleVal()); + assertEquals(featureReferences.get(1), features.get(1).getFeatureReference()); + } + + @Test + public void shouldFilterOutMissingFeatureRefUsingHbase() { + HBaseOnlineRetriever retriever = new HBaseOnlineRetriever(hbaseClient); + List featureReferences = + List.of( + ServingServiceProto.FeatureReference.newBuilder() + .setFeatureTable("rides") + .setName("not_exists") + .build()); + List entityNames = List.of("driver"); + List entityRows = + List.of(DataGenerator.createEntityRow("driver", DataGenerator.createInt64Value(1), 100)); + List> features = + retriever.getOnlineFeatures(FEAST_PROJECT, entityRows, featureReferences, entityNames); + assertEquals(1, features.size()); + assertEquals(0, features.get(0).size()); + } +} diff --git a/caraml-store-spark/docker/Dockerfile b/caraml-store-spark/docker/Dockerfile index 
710d9b8..bb2c48b 100644 --- a/caraml-store-spark/docker/Dockerfile +++ b/caraml-store-spark/docker/Dockerfile @@ -1,4 +1,4 @@ -FROM apache/spark-py:v3.1.3 +FROM --platform=linux/amd64 apache/spark-py:v3.1.3 ARG GCS_CONNECTOR_VERSION=2.2.5 ARG BQ_CONNECTOR_VERSION=0.27.1 diff --git a/caraml-store-spark/src/main/scala/dev/caraml/spark/BasePipeline.scala b/caraml-store-spark/src/main/scala/dev/caraml/spark/BasePipeline.scala index cb854aa..82f1d08 100644 --- a/caraml-store-spark/src/main/scala/dev/caraml/spark/BasePipeline.scala +++ b/caraml-store-spark/src/main/scala/dev/caraml/spark/BasePipeline.scala @@ -33,6 +33,15 @@ object BasePipeline { conf .set("spark.bigtable.projectId", projectId) .set("spark.bigtable.instanceId", instanceId) + case HBaseConfig(zookeeperQuorum, zookeeperPort, hbaseProperties) => + conf + .set("spark.hbase.zookeeper.quorum", zookeeperQuorum) + .set("spark.hbase.zookeeper.port", zookeeperPort.toString) + .set( + "spark.hbase.properties.regionSplitPolicyClassName", + hbaseProperties.regionSplitPolicy + ) + .set("spark.hbase.properties.compressionAlgorithm", hbaseProperties.compressionAlgorithm) } jobConfig.metrics match { diff --git a/caraml-store-spark/src/main/scala/dev/caraml/spark/BatchPipeline.scala b/caraml-store-spark/src/main/scala/dev/caraml/spark/BatchPipeline.scala index 4ce5d55..733d7d2 100644 --- a/caraml-store-spark/src/main/scala/dev/caraml/spark/BatchPipeline.scala +++ b/caraml-store-spark/src/main/scala/dev/caraml/spark/BatchPipeline.scala @@ -66,11 +66,19 @@ object BatchPipeline extends BasePipeline { .map(metrics.incrementRead) .filter(rowValidator.allChecks) + val onlineStore = config.store match { + case _: RedisConfig => "redis" + case _: BigTableConfig => "bigtable" + case _: HBaseConfig => "hbase" + } + validRows.write .format(config.store match { case _: RedisConfig => "dev.caraml.spark.stores.redis" case _: BigTableConfig => "dev.caraml.spark.stores.bigtable" + case _: HBaseConfig => "dev.caraml.spark.stores.bigtable" }) + .option("online_store", onlineStore) .option("entity_columns", featureTable.entities.map(_.name).mkString(",")) .option("namespace", featureTable.name) .option("project_name", featureTable.project) diff --git a/caraml-store-spark/src/main/scala/dev/caraml/spark/IngestionJob.scala b/caraml-store-spark/src/main/scala/dev/caraml/spark/IngestionJob.scala index 69196c9..20939b2 100644 --- a/caraml-store-spark/src/main/scala/dev/caraml/spark/IngestionJob.scala +++ b/caraml-store-spark/src/main/scala/dev/caraml/spark/IngestionJob.scala @@ -87,6 +87,9 @@ object IngestionJob { opt[String](name = "bigtable") .action((x, c) => c.copy(store = parseJSON(x).camelizeKeys.extract[BigTableConfig])) + opt[String](name = "hbase") + .action((x, c) => c.copy(store = parseJSON(x).extract[HBaseConfig])) + opt[String](name = "statsd") .action((x, c) => c.copy(metrics = Some(parseJSON(x).extract[StatsDConfig]))) diff --git a/caraml-store-spark/src/main/scala/dev/caraml/spark/IngestionJobConfig.scala b/caraml-store-spark/src/main/scala/dev/caraml/spark/IngestionJobConfig.scala index a13524d..639f2e5 100644 --- a/caraml-store-spark/src/main/scala/dev/caraml/spark/IngestionJobConfig.scala +++ b/caraml-store-spark/src/main/scala/dev/caraml/spark/IngestionJobConfig.scala @@ -27,6 +27,16 @@ case class RedisWriteProperties( ratePerSecondLimit: Int = 50000 ) case class BigTableConfig(projectId: String, instanceId: String) extends StoreConfig +case class HBaseConfig( + zookeeperQuorum: String, + zookeeperPort: Int, + hbaseProperties: HBaseProperties = 
HBaseProperties() +) extends StoreConfig +case class HBaseProperties( + regionSplitPolicy: String = + "org.apache.hadoop.hbase.regionserver.IncreasingToUpperBoundRegionSplitPolicy", + compressionAlgorithm: String = "ZSTD" +) sealed trait MetricConfig diff --git a/caraml-store-spark/src/main/scala/dev/caraml/spark/StreamingPipeline.scala b/caraml-store-spark/src/main/scala/dev/caraml/spark/StreamingPipeline.scala index 1620705..fedae24 100644 --- a/caraml-store-spark/src/main/scala/dev/caraml/spark/StreamingPipeline.scala +++ b/caraml-store-spark/src/main/scala/dev/caraml/spark/StreamingPipeline.scala @@ -77,6 +77,12 @@ object StreamingPipeline extends BasePipeline with Serializable { case _ => Array() } + val onlineStore = config.store match { + case _: RedisConfig => "redis" + case _: BigTableConfig => "bigtable" + case _: HBaseConfig => "hbase" + } + val parsed = input .withColumn("features", featureStruct) .select(metadata :+ col("features.*"): _*) @@ -108,7 +114,9 @@ object StreamingPipeline extends BasePipeline with Serializable { .format(config.store match { case _: RedisConfig => "dev.caraml.spark.stores.redis" case _: BigTableConfig => "dev.caraml.spark.stores.bigtable" + case _: HBaseConfig => "dev.caraml.spark.stores.bigtable" }) + .option("online_store", onlineStore) .option("entity_columns", featureTable.entities.map(_.name).mkString(",")) .option("namespace", featureTable.name) .option("project_name", featureTable.project) diff --git a/caraml-store-spark/src/main/scala/dev/caraml/spark/stores/bigtable/BigTableSinkRelation.scala b/caraml-store-spark/src/main/scala/dev/caraml/spark/stores/bigtable/BigTableSinkRelation.scala index 8cf36b8..a5049d3 100644 --- a/caraml-store-spark/src/main/scala/dev/caraml/spark/stores/bigtable/BigTableSinkRelation.scala +++ b/caraml-store-spark/src/main/scala/dev/caraml/spark/stores/bigtable/BigTableSinkRelation.scala @@ -4,7 +4,13 @@ import com.google.cloud.bigtable.hbase.BigtableConfiguration import dev.caraml.spark.serialization.Serializer import dev.caraml.spark.utils.StringUtils import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.hbase.client.Put +import org.apache.hadoop.hbase.client.{ + Admin, + ColumnFamilyDescriptorBuilder, + Connection, + Put, + TableDescriptorBuilder +} import org.apache.hadoop.hbase.mapred.TableOutputFormat import org.apache.hadoop.hbase.{HColumnDescriptor, HTableDescriptor, TableName} import org.apache.hadoop.mapred.JobConf @@ -30,42 +36,49 @@ class BigTableSinkRelation( override def schema: StructType = ??? 
+ def getConnection(hadoopConfig: Configuration): Connection = { + BigtableConfiguration.connect(hadoopConfig) + } + def createTable(): Unit = { - val btConn = BigtableConfiguration.connect(hadoopConfig) + val btConn = getConnection(hadoopConfig) try { val admin = btConn.getAdmin val table = if (!admin.isTableAvailable(TableName.valueOf(tableName))) { - val t = new HTableDescriptor(TableName.valueOf(tableName)) - val metadataCF = new HColumnDescriptor(metadataColumnFamily) - t.addFamily(metadataCF) - t + val tableBuilder = TableDescriptorBuilder.newBuilder(TableName.valueOf(tableName)) + val cf = ColumnFamilyDescriptorBuilder.of(metadataColumnFamily) + tableBuilder.setColumnFamily(cf) + val table = tableBuilder.build() + table } else { - admin.getTableDescriptor(TableName.valueOf(tableName)) + val t = btConn.getTable(TableName.valueOf(tableName)) + t.getDescriptor() } - - val featuresCF = new HColumnDescriptor(config.namespace) + val featuresCFBuilder = ColumnFamilyDescriptorBuilder.newBuilder(config.namespace.getBytes) if (config.maxAge > 0) { - featuresCF.setTimeToLive(config.maxAge.toInt) + featuresCFBuilder.setTimeToLive(config.maxAge.toInt) } + featuresCFBuilder.setMaxVersions(1) + val featuresCF = featuresCFBuilder.build() - featuresCF.setMaxVersions(1) + val tdb = TableDescriptorBuilder.newBuilder(table) if (!table.getColumnFamilyNames.contains(config.namespace.getBytes)) { - table.addFamily(featuresCF) - + tdb.setColumnFamily(featuresCF) + val t = tdb.build() if (!admin.isTableAvailable(table.getTableName)) { - admin.createTable(table) + admin.createTable(t) } else { - admin.modifyTable(table) + admin.modifyTable(t) } } else if ( config.maxAge > 0 && table .getColumnFamily(config.namespace.getBytes) .getTimeToLive != featuresCF.getTimeToLive ) { - table.modifyFamily(featuresCF) - admin.modifyTable(table) + tdb.modifyColumnFamily(featuresCF) + admin.modifyTable(tdb.build()) } } finally { btConn.close() @@ -115,7 +128,7 @@ class BigTableSinkRelation( val qualifier = "avro".getBytes put.addColumn(metadataColumnFamily.getBytes, qualifier, schema.asInstanceOf[String].getBytes) - val btConn = BigtableConfiguration.connect(hadoopConfig) + val btConn = getConnection(hadoopConfig) try { val table = btConn.getTable(TableName.valueOf(tableName)) table.checkAndPut( @@ -130,19 +143,19 @@ class BigTableSinkRelation( } } - private def tableName: String = { + protected def tableName: String = { val entities = config.entityColumns.sorted.mkString("__") StringUtils.trimAndHash(s"${config.projectName}__${entities}", maxTableNameLength) } - private def joinEntityKey: UserDefinedFunction = udf { r: Row => + protected def joinEntityKey: UserDefinedFunction = udf { r: Row => ((0 until r.size)).map(r.getString).mkString("#").getBytes } - private val metadataColumnFamily = "metadata" - private val schemaKeyPrefix = "schema#" - private val emptyQualifier = "" - private val maxTableNameLength = 50 + protected val metadataColumnFamily = "metadata" + protected val schemaKeyPrefix = "schema#" + protected val emptyQualifier = "" + protected val maxTableNameLength = 50 private def isSystemColumn(name: String) = (config.entityColumns ++ Seq(config.timestampColumn)).contains(name) diff --git a/caraml-store-spark/src/main/scala/dev/caraml/spark/stores/bigtable/DefaultSource.scala b/caraml-store-spark/src/main/scala/dev/caraml/spark/stores/bigtable/DefaultSource.scala index 3c31c89..5838aab 100644 --- a/caraml-store-spark/src/main/scala/dev/caraml/spark/stores/bigtable/DefaultSource.scala +++ 
b/caraml-store-spark/src/main/scala/dev/caraml/spark/stores/bigtable/DefaultSource.scala @@ -23,27 +23,43 @@ class DefaultSource extends CreatableRelationProvider { parameters: Map[String, String], data: DataFrame ): BaseRelation = { - val bigtableConf = BigtableConfiguration.configure( - sqlContext.getConf(PROJECT_KEY), - sqlContext.getConf(INSTANCE_KEY) - ) - - if (sqlContext.getConf("spark.bigtable.emulatorHost", "").nonEmpty) { - bigtableConf.set( - BIGTABLE_EMULATOR_HOST_KEY, - sqlContext.getConf("spark.bigtable.emulatorHost") + val onlineStore = parameters.getOrElse("online_store", "bigtable") + var rel: BigTableSinkRelation = null + println(s"onlineStore: $onlineStore") + if (onlineStore == "bigtable") { + val bigtableConf = BigtableConfiguration.configure( + sqlContext.getConf(PROJECT_KEY), + sqlContext.getConf(INSTANCE_KEY) ) - } - configureBigTableClient(bigtableConf, sqlContext) + if (sqlContext.getConf("spark.bigtable.emulatorHost", "").nonEmpty) { + bigtableConf.set( + BIGTABLE_EMULATOR_HOST_KEY, + sqlContext.getConf("spark.bigtable.emulatorHost") + ) + } + + configureBigTableClient(bigtableConf, sqlContext) - val rel = - new BigTableSinkRelation( + rel = new BigTableSinkRelation( sqlContext, new AvroSerializer, SparkBigtableConfig.parse(parameters), bigtableConf ) + } else if (onlineStore == "hbase") { + val hbaseConf = new Configuration() + hbaseConf.set("hbase.zookeeper.quorum", sqlContext.getConf(ZOOKEEPER_QUOROM_KEY)) + hbaseConf.set("hbase.zookeeper.property.clientPort", sqlContext.getConf(ZOOKEEPER_PORT_KEY)) + rel = new HbaseSinkRelation( + sqlContext, + new AvroSerializer, + SparkBigtableConfig.parse(parameters), + hbaseConf + ) + } else { + throw new UnsupportedOperationException(s"Unsupported online store: $onlineStore") + } rel.createTable() rel.saveWriteSchema(data) rel.insert(data, overwrite = false) @@ -79,4 +95,7 @@ object DefaultSource { private val THROTTLING_THRESHOLD_MILLIS_KEY = "spark.bigtable.throttlingThresholdMs" private val MAX_ROW_COUNT_KEY = "spark.bigtable.maxRowCount" private val MAX_INFLIGHT_KEY = "spark.bigtable.maxInflightRpcs" + + private val ZOOKEEPER_QUOROM_KEY = "spark.hbase.zookeeper.quorum" + private val ZOOKEEPER_PORT_KEY = "spark.hbase.zookeeper.port" } diff --git a/caraml-store-spark/src/main/scala/dev/caraml/spark/stores/bigtable/HbaseSinkRelation.scala b/caraml-store-spark/src/main/scala/dev/caraml/spark/stores/bigtable/HbaseSinkRelation.scala new file mode 100644 index 0000000..97cc575 --- /dev/null +++ b/caraml-store-spark/src/main/scala/dev/caraml/spark/stores/bigtable/HbaseSinkRelation.scala @@ -0,0 +1,79 @@ +package dev.caraml.spark.stores.bigtable + +import dev.caraml.spark.serialization.Serializer +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.hbase.TableName +import org.apache.hadoop.hbase.client.{ + ColumnFamilyDescriptorBuilder, + Connection, + ConnectionFactory, + TableDescriptorBuilder +} +import org.apache.hadoop.hbase.io.compress.Compression +import org.apache.spark.sql.SQLContext + +class HbaseSinkRelation( + sqlContext: SQLContext, + serializer: Serializer, + config: SparkBigtableConfig, + hadoopConfig: Configuration +) extends BigTableSinkRelation(sqlContext, serializer, config, hadoopConfig) { + override def getConnection(hadoopConfig: Configuration): Connection = { + ConnectionFactory.createConnection(hadoopConfig) + } + override def createTable(): Unit = { + val hbaseConn = getConnection(hadoopConfig) + try { + val admin = hbaseConn.getAdmin + + val table = if 
(!admin.isTableAvailable(TableName.valueOf(tableName))) { + val tableBuilder = TableDescriptorBuilder.newBuilder(TableName.valueOf(tableName)) + val cf = ColumnFamilyDescriptorBuilder.of(metadataColumnFamily) + tableBuilder.setColumnFamily(cf) + val table = tableBuilder.build() + table + } else { + val t = hbaseConn.getTable(TableName.valueOf(tableName)) + t.getDescriptor() + } + val featuresCFBuilder = ColumnFamilyDescriptorBuilder.newBuilder(config.namespace.getBytes) + if (config.maxAge > 0) { + featuresCFBuilder.setTimeToLive(config.maxAge.toInt) + } + featuresCFBuilder.setMaxVersions(1) + sqlContext.getConf("spark.hbase.properties.compressionAlgorithm") match { + case "ZSTD" => featuresCFBuilder.setCompressionType(Compression.Algorithm.ZSTD) + case "GZ" => featuresCFBuilder.setCompressionType(Compression.Algorithm.GZ) + case "LZ4" => featuresCFBuilder.setCompressionType(Compression.Algorithm.LZ4) + case "SNAPPY" => featuresCFBuilder.setCompressionType(Compression.Algorithm.SNAPPY) + case _ => featuresCFBuilder.setCompressionType(Compression.Algorithm.NONE) + } + val featuresCF = featuresCFBuilder.build() + + val tdb = TableDescriptorBuilder.newBuilder(table) + tdb.setRegionSplitPolicyClassName( + sqlContext.getConf("spark.hbase.properties.regionSplitPolicyClassName") + ) + + if (!table.getColumnFamilyNames.contains(config.namespace.getBytes)) { + tdb.setColumnFamily(featuresCF) + val t = tdb.build() + if (!admin.isTableAvailable(table.getTableName)) { + admin.createTable(t) + } else { + admin.modifyTable(t) + } + } else if ( + config.maxAge > 0 && table + .getColumnFamily(config.namespace.getBytes) + .getTimeToLive != featuresCF.getTimeToLive + ) { + tdb.modifyColumnFamily(featuresCF) + admin.modifyTable(tdb.build()) + } + + } finally { + hbaseConn.close() + } + } +}
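Reviewer note: the serving-side HBaseOnlineRetriever/HBaseSchemaRegistry and the Spark HbaseSinkRelation in this diff agree on one row layout per entity table: a schema row keyed by "schema#" plus the murmur3_32 hash of the Avro schema, stored under metadata:avro, and one feature row per entity whose key is the entity values joined by "#" and whose cell value is the 4-byte schema reference followed by the Avro-encoded record. The Java sketch below is condensed from the test fixtures above (HbaseOnlineRetrieverTest) to illustrate that layout only; the table name "default__driver", column family "rides", and field values are illustrative, not part of the change.

// Illustrative sketch of the HBase row layout expected by HBaseOnlineRetriever.
// Assumes an open org.apache.hadoop.hbase.client.Connection; names are examples.
import com.google.common.hash.Hashing;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import org.apache.avro.Schema;
import org.apache.avro.SchemaBuilder;
import org.apache.avro.generic.GenericData;
import org.apache.avro.generic.GenericDatumWriter;
import org.apache.avro.generic.GenericRecord;
import org.apache.avro.io.BinaryEncoder;
import org.apache.avro.io.EncoderFactory;
import org.apache.hadoop.hbase.TableName;
import org.apache.hadoop.hbase.client.Connection;
import org.apache.hadoop.hbase.client.Put;
import org.apache.hadoop.hbase.client.Table;

public class HBaseRowLayoutSketch {
  public static void writeExampleRows(Connection connection) throws IOException {
    Schema schema =
        SchemaBuilder.record("DriverData").fields().requiredLong("trip_cost").endRecord();
    byte[] schemaRef = Hashing.murmur3_32().hashBytes(schema.toString().getBytes()).asBytes();

    // Table per (project, entities), e.g. "default__driver" (illustrative name).
    Table table = connection.getTable(TableName.valueOf("default__driver"));

    // Schema row: key = "schema#" + hash, value = Avro schema JSON in metadata:avro.
    ByteArrayOutputStream schemaKey = new ByteArrayOutputStream();
    schemaKey.write("schema#".getBytes());
    schemaKey.write(schemaRef);
    Put schemaPut = new Put(schemaKey.toByteArray());
    schemaPut.addColumn("metadata".getBytes(), "avro".getBytes(), schema.toString().getBytes());
    table.put(schemaPut);

    // Feature row: key = entity values joined by "#"; value = 4-byte schema
    // reference followed by the Avro-serialized feature record, stored under the
    // feature table's column family with an empty qualifier.
    GenericRecord record = new GenericData.Record(schema);
    record.put("trip_cost", 5L);
    ByteArrayOutputStream payload = new ByteArrayOutputStream();
    payload.write(schemaRef);
    GenericDatumWriter<GenericRecord> writer = new GenericDatumWriter<>(schema);
    BinaryEncoder encoder = EncoderFactory.get().binaryEncoder(payload, null);
    writer.write(record, encoder);
    encoder.flush();

    Put featurePut = new Put("1".getBytes()); // single entity value "1"
    featurePut.addColumn("rides".getBytes(), "".getBytes(), payload.toByteArray());
    table.put(featurePut);
    table.close();
  }
}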