-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
9178b66
commit 4dc8df6
Showing
5 changed files
with
285 additions
and
49 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
118 changes: 118 additions & 0 deletions
118
...l-store-serving/src/main/java/dev/caraml/serving/store/bigtable/HBaseOnlineRetriever.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,118 @@ | ||
package dev.caraml.serving.store.bigtable; | ||
|
||
import com.google.protobuf.ByteString; | ||
import com.google.protobuf.Timestamp; | ||
import dev.caraml.serving.store.AvroFeature; | ||
import dev.caraml.serving.store.Feature; | ||
import dev.caraml.store.protobuf.serving.ServingServiceProto; | ||
|
||
import java.io.IOException; | ||
import java.util.*; | ||
import java.util.stream.Collectors; | ||
|
||
import org.apache.avro.AvroRuntimeException; | ||
import org.apache.avro.generic.GenericDatumReader; | ||
import org.apache.avro.generic.GenericRecord; | ||
import org.apache.avro.io.BinaryDecoder; | ||
import org.apache.avro.io.DecoderFactory; | ||
import org.apache.hadoop.hbase.Cell; | ||
import org.apache.hadoop.hbase.TableName; | ||
import org.apache.hadoop.hbase.client.*; | ||
import org.apache.hadoop.hbase.util.Bytes; | ||
|
||
public class HBaseOnlineRetriever implements SSTableOnlineRetriever<ByteString, Result> { | ||
private final Connection client; | ||
private final HBaseSchemaRegistry schemaRegistry; | ||
|
||
public HBaseOnlineRetriever(Connection client) { | ||
this.client = client; | ||
this.schemaRegistry = new HBaseSchemaRegistry(client); | ||
} | ||
|
||
@Override | ||
public ByteString convertEntityValueToKey(ServingServiceProto.GetOnlineFeaturesRequest.EntityRow entityRow, List<String> entityNames) { | ||
return ByteString.copyFrom(entityNames.stream().sorted().map(entity -> entityRow.getFieldsMap().get(entity)).map(this::valueToString).collect(Collectors.joining("#")).getBytes()); | ||
} | ||
|
||
@Override | ||
public List<List<Feature>> convertRowToFeature(String tableName, List<ByteString> rowKeys, Map<ByteString, Result> rows, List<ServingServiceProto.FeatureReference> featureReferences) { | ||
BinaryDecoder reusedDecoder = DecoderFactory.get().binaryDecoder(new byte[0], null); | ||
|
||
return rowKeys.stream().map(rowKey -> { | ||
if (!rows.containsKey(rowKey)) { | ||
return Collections.<Feature>emptyList(); | ||
} else { | ||
Result row = rows.get(rowKey); | ||
return featureReferences.stream().map(ServingServiceProto.FeatureReference::getFeatureTable).distinct().map(cf -> row.getColumnCells(cf.getBytes(), null)).filter(ls -> !ls.isEmpty()).flatMap(rowCells -> { | ||
Cell rowCell = rowCells.get(0); // Latest cell | ||
String family = Bytes.toString(rowCell.getFamilyArray()); | ||
ByteString value = ByteString.copyFrom(rowCell.getValueArray()); | ||
|
||
List<Feature> features; | ||
List<ServingServiceProto.FeatureReference> localFeatureReferences = featureReferences.stream().filter(featureReference -> featureReference.getFeatureTable().equals(family)).collect(Collectors.toList()); | ||
|
||
try { | ||
features = decodeFeatures(tableName, value, localFeatureReferences, reusedDecoder, rowCell.getTimestamp()); | ||
} catch (IOException e) { | ||
throw new RuntimeException("Failed to decode features from BigTable"); | ||
} | ||
|
||
return features.stream(); | ||
}).collect(Collectors.toList()); | ||
} | ||
}).collect(Collectors.toList()); | ||
} | ||
|
||
@Override | ||
public Map<ByteString, Result> getFeaturesFromSSTable(String tableName, List<ByteString> rowKeys, List<String> columnFamilies) { | ||
try { | ||
Table table = this.client.getTable(TableName.valueOf(tableName)); | ||
|
||
List<Get> getList = new ArrayList<>(); | ||
for (ByteString rowKey : rowKeys) { | ||
Get get = new Get(rowKey.toByteArray()); | ||
for (String columnFamily : columnFamilies) { | ||
get.addFamily(columnFamily.getBytes()); | ||
} | ||
getList.add(get); | ||
} | ||
|
||
Result[] rows = table.get(getList); | ||
|
||
Map<ByteString, Result> result = new HashMap<>(); | ||
for (Result row : rows) { | ||
if (row.isEmpty()) { | ||
continue; | ||
} | ||
result.put(ByteString.copyFrom(row.getRow()), row); | ||
} | ||
|
||
return result; | ||
} catch (IOException e) { | ||
throw new RuntimeException(e); | ||
} | ||
} | ||
|
||
private List<Feature> decodeFeatures(String tableName, ByteString value, List<ServingServiceProto.FeatureReference> featureReferences, BinaryDecoder reusedDecoder, long timestamp) throws IOException { | ||
ByteString schemaReferenceBytes = value.substring(0, 4); | ||
byte[] featureValueBytes = value.substring(4).toByteArray(); | ||
|
||
HBaseSchemaRegistry.SchemaReference schemaReference = new HBaseSchemaRegistry.SchemaReference(tableName, schemaReferenceBytes); | ||
|
||
GenericDatumReader<GenericRecord> reader = this.schemaRegistry.getReader(schemaReference); | ||
|
||
reusedDecoder = DecoderFactory.get().binaryDecoder(featureValueBytes, reusedDecoder); | ||
GenericRecord record = reader.read(null, reusedDecoder); | ||
|
||
return featureReferences.stream().map(featureReference -> { | ||
Object featureValue; | ||
try { | ||
featureValue = record.get(featureReference.getName()); | ||
} catch (AvroRuntimeException e) { | ||
// Feature is not found in schema | ||
return null; | ||
} | ||
return new AvroFeature(featureReference, Timestamp.newBuilder().setSeconds(timestamp / 1000).build(), Objects.requireNonNullElseGet(featureValue, Object::new)); | ||
}).filter(Objects::nonNull).collect(Collectors.toList()); | ||
} | ||
} |
101 changes: 101 additions & 0 deletions
101
...ml-store-serving/src/main/java/dev/caraml/serving/store/bigtable/HBaseSchemaRegistry.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,101 @@ | ||
package dev.caraml.serving.store.bigtable; | ||
|
||
import com.google.common.cache.CacheBuilder; | ||
import com.google.common.cache.CacheLoader; | ||
import com.google.common.cache.LoadingCache; | ||
import com.google.protobuf.ByteString; | ||
|
||
import java.io.IOException; | ||
import java.util.concurrent.ExecutionException; | ||
|
||
import org.apache.avro.Schema; | ||
import org.apache.avro.generic.GenericDatumReader; | ||
import org.apache.avro.generic.GenericRecord; | ||
import org.apache.hadoop.hbase.Cell; | ||
import org.apache.hadoop.hbase.TableName; | ||
import org.apache.hadoop.hbase.client.Connection; | ||
import org.apache.hadoop.hbase.client.Get; | ||
import org.apache.hadoop.hbase.client.Result; | ||
import org.apache.hadoop.hbase.client.Table; | ||
import org.apache.hadoop.hbase.util.Bytes; | ||
|
||
public class HBaseSchemaRegistry { | ||
private final Connection hbaseClient; | ||
private final LoadingCache<SchemaReference, GenericDatumReader<GenericRecord>> cache; | ||
|
||
private static String COLUMN_FAMILY = "metadata"; | ||
private static String QUALIFIER = "avro"; | ||
private static String KEY_PREFIX = "schema#"; | ||
|
||
public static class SchemaReference { | ||
private final String tableName; | ||
private final ByteString schemaHash; | ||
|
||
public SchemaReference(String tableName, ByteString schemaHash) { | ||
this.tableName = tableName; | ||
this.schemaHash = schemaHash; | ||
} | ||
|
||
public String getTableName() { | ||
return tableName; | ||
} | ||
|
||
public ByteString getSchemaHash() { | ||
return schemaHash; | ||
} | ||
|
||
@Override | ||
public int hashCode() { | ||
int result = tableName.hashCode(); | ||
result = 31 * result + schemaHash.hashCode(); | ||
return result; | ||
} | ||
|
||
@Override | ||
public boolean equals(Object o) { | ||
if (this == o) return true; | ||
if (o == null || getClass() != o.getClass()) return false; | ||
|
||
SchemaReference that = (SchemaReference) o; | ||
|
||
if (!tableName.equals(that.tableName)) return false; | ||
return schemaHash.equals(that.schemaHash); | ||
} | ||
} | ||
|
||
public HBaseSchemaRegistry(Connection hbaseClient) { | ||
this.hbaseClient = hbaseClient; | ||
|
||
CacheLoader<SchemaReference, GenericDatumReader<GenericRecord>> schemaCacheLoader = CacheLoader.from(this::loadReader); | ||
|
||
cache = CacheBuilder.newBuilder().build(schemaCacheLoader); | ||
} | ||
|
||
public GenericDatumReader<GenericRecord> getReader(SchemaReference reference) { | ||
GenericDatumReader<GenericRecord> reader; | ||
try { | ||
reader = this.cache.get(reference); | ||
} catch (ExecutionException | CacheLoader.InvalidCacheLoadException e) { | ||
throw new RuntimeException(String.format("Unable to find Schema"), e); | ||
} | ||
return reader; | ||
} | ||
|
||
private GenericDatumReader<GenericRecord> loadReader(SchemaReference reference) { | ||
try { | ||
Table table = this.hbaseClient.getTable(TableName.valueOf(reference.getTableName())); | ||
|
||
byte[] rowKey = ByteString.copyFrom(KEY_PREFIX.getBytes()).concat(reference.getSchemaHash()).toByteArray(); | ||
Get query = new Get(rowKey); | ||
query.addColumn(COLUMN_FAMILY.getBytes(), QUALIFIER.getBytes()); | ||
|
||
Result result = table.get(query); | ||
|
||
Cell last = result.getColumnLatestCell(COLUMN_FAMILY.getBytes(), QUALIFIER.getBytes()); | ||
Schema schema = new Schema.Parser().parse(Bytes.toString(last.getValueArray())); | ||
return new GenericDatumReader<>(schema); | ||
} catch (IOException e) { | ||
throw new RuntimeException(e); | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters