Commit 3e40a61: Apply PR #392

zhilingc committed Dec 27, 2019
1 parent a340613 commit 3e40a61

Showing 4 changed files with 84 additions and 18 deletions.
@@ -31,6 +31,7 @@
 import com.google.cloud.bigquery.Schema;
 import com.google.cloud.bigquery.Table;
 import com.google.cloud.bigquery.TableId;
+import com.google.cloud.bigquery.TableInfo;
 import com.google.cloud.storage.Storage;
 import feast.serving.ServingAPIProto;
 import feast.serving.ServingAPIProto.DataFormat;
@@ -56,10 +57,12 @@
 import java.util.Optional;
 import java.util.UUID;
 import java.util.stream.Collectors;
+import org.joda.time.Duration;
 import org.slf4j.Logger;
 
 public class BigQueryServingService implements ServingService {
 
+  public static final long TEMP_TABLE_EXPIRY_DURATION_MS = Duration.standardDays(1).getMillis();
   private static final Logger log = org.slf4j.LoggerFactory.getLogger(BigQueryServingService.class);
 
   private final BigQuery bigquery;
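A quick sanity check of the new constant, as a standalone sketch (the class name is illustrative and not part of this commit; it assumes the joda-time dependency imported above):

// ExpiryConstantSketch.java — illustrative, not part of this commit.
import org.joda.time.Duration;

public class ExpiryConstantSketch {
  public static void main(String[] args) {
    // Duration.standardDays(1) is a fixed 24 hours of millis.
    long expiryMs = Duration.standardDays(1).getMillis();
    System.out.println(expiryMs); // prints 86400000
  }
}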
@@ -179,22 +182,33 @@ private Table loadEntities(DatasetSource datasetSource) {
     switch (datasetSource.getDatasetSourceCase()) {
       case FILE_SOURCE:
         try {
-          String tableName = generateTemporaryTableName();
-          log.info("Loading entity dataset to table {}.{}.{}", projectId, datasetId, tableName);
-          TableId tableId = TableId.of(projectId, datasetId, tableName);
-          // Currently only avro supported
+          // Currently only AVRO format is supported
+
           if (datasetSource.getFileSource().getDataFormat() != DataFormat.DATA_FORMAT_AVRO) {
             throw Status.INVALID_ARGUMENT
-                .withDescription("Invalid file format, only avro supported")
+                .withDescription("Invalid file format, only AVRO is supported.")
                 .asRuntimeException();
           }
+
+          TableId tableId = TableId.of(projectId, datasetId, createTempTableName());
+          log.info("Loading entity rows to: {}.{}.{}", projectId, datasetId, tableId.getTable());
+
           LoadJobConfiguration loadJobConfiguration =
               LoadJobConfiguration.of(
                   tableId, datasetSource.getFileSource().getFileUrisList(), FormatOptions.avro());
           loadJobConfiguration =
               loadJobConfiguration.toBuilder().setUseAvroLogicalTypes(true).build();
           Job job = bigquery.create(JobInfo.of(loadJobConfiguration));
           job.waitFor();
+
+          TableInfo expiry =
+              bigquery
+                  .getTable(tableId)
+                  .toBuilder()
+                  .setExpirationTime(System.currentTimeMillis() + TEMP_TABLE_EXPIRY_DURATION_MS)
+                  .build();
+          bigquery.update(expiry);
+
           loadedEntityTable = bigquery.getTable(tableId);
           if (!loadedEntityTable.exists()) {
             throw new RuntimeException(
@@ -204,7 +218,7 @@ private Table loadEntities(DatasetSource datasetSource) {
       } catch (Exception e) {
         log.error("Exception has occurred in loadEntities method: ", e);
         throw Status.INTERNAL
-            .withDescription("Failed to load entity dataset into store")
+            .withDescription("Failed to load entity dataset into store: " + e.toString())
             .withCause(e)
             .asRuntimeException();
       }
@@ -216,20 +230,23 @@ private Table loadEntities(DatasetSource datasetSource) {
     }
   }
 
-  private String generateTemporaryTableName() {
-    String source = String.format("feastserving%d", System.currentTimeMillis());
-    String guid = UUID.nameUUIDFromBytes(source.getBytes()).toString();
-    String suffix = guid.substring(0, Math.min(guid.length(), 10)).replaceAll("-", "");
-    return String.format("temp_%s", suffix);
-  }
-
   private TableId generateUUIDs(Table loadedEntityTable) {
     try {
       String uuidQuery =
           createEntityTableUUIDQuery(generateFullTableName(loadedEntityTable.getTableId()));
-      QueryJobConfiguration queryJobConfig = QueryJobConfiguration.newBuilder(uuidQuery).build();
+      QueryJobConfiguration queryJobConfig =
+          QueryJobConfiguration.newBuilder(uuidQuery)
+              .setDestinationTable(TableId.of(projectId, datasetId, createTempTableName()))
+              .build();
       Job queryJob = bigquery.create(JobInfo.of(queryJobConfig));
       queryJob.waitFor();
+      TableInfo expiry =
+          bigquery
+              .getTable(queryJobConfig.getDestinationTable())
+              .toBuilder()
+              .setExpirationTime(System.currentTimeMillis() + TEMP_TABLE_EXPIRY_DURATION_MS)
+              .build();
+      bigquery.update(expiry);
       queryJobConfig = queryJob.getConfiguration();
       return queryJobConfig.getDestinationTable();
     } catch (InterruptedException | BigQueryException e) {
@@ -239,4 +256,8 @@ private TableId generateUUIDs(Table loadedEntityTable) {
           .asRuntimeException();
     }
   }
+
+  public static String createTempTableName() {
+    return "_" + UUID.randomUUID().toString().replace("-", "");
+  }
 }
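The hunks above introduce the expiration pattern that recurs through the rest of this commit: write to a generated temp table, then update the table's metadata so BigQuery garbage-collects it a day later. A minimal standalone sketch of that pattern, assuming the google-cloud-bigquery Java client (the class name and dataset name are illustrative, not part of this commit):

// TempTableExpirySketch.java — illustrative, not part of this commit.
import com.google.cloud.bigquery.BigQuery;
import com.google.cloud.bigquery.BigQueryOptions;
import com.google.cloud.bigquery.TableId;
import com.google.cloud.bigquery.TableInfo;
import java.util.UUID;

public class TempTableExpirySketch {
  public static void main(String[] args) {
    BigQuery bigquery = BigQueryOptions.getDefaultInstance().getService();
    // Same naming scheme as createTempTableName() above.
    TableId tableId =
        TableId.of("my_dataset", "_" + UUID.randomUUID().toString().replace("-", ""));

    // ... populate the table here, e.g. as the destination of a load or query job ...

    // Fetch the table's metadata, set an absolute expiration timestamp one day
    // out, and write it back; BigQuery drops the table once the timestamp passes,
    // so abandoned temp tables cannot accumulate.
    TableInfo expiry =
        bigquery
            .getTable(tableId)
            .toBuilder()
            .setExpirationTime(System.currentTimeMillis() + 24L * 60 * 60 * 1000)
            .build();
    bigquery.update(expiry);
  }
}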
@@ -16,6 +16,8 @@
  */
 package feast.serving.store.bigquery;
 
+import static feast.serving.service.BigQueryServingService.TEMP_TABLE_EXPIRY_DURATION_MS;
+import static feast.serving.service.BigQueryServingService.createTempTableName;
 import static feast.serving.store.bigquery.QueryTemplater.createTimestampLimitQuery;
 
 import com.google.auto.value.AutoValue;
@@ -27,6 +29,8 @@
 import com.google.cloud.bigquery.Job;
 import com.google.cloud.bigquery.JobInfo;
 import com.google.cloud.bigquery.QueryJobConfiguration;
+import com.google.cloud.bigquery.TableId;
+import com.google.cloud.bigquery.TableInfo;
 import com.google.cloud.bigquery.TableResult;
 import com.google.cloud.storage.Blob;
 import com.google.cloud.storage.Storage;
@@ -179,10 +183,13 @@ Job runBatchQuery(List<String> featureSetQueries)
 
     for (int i = 0; i < featureSetQueries.size(); i++) {
       QueryJobConfiguration queryJobConfig =
-          QueryJobConfiguration.newBuilder(featureSetQueries.get(i)).build();
+          QueryJobConfiguration.newBuilder(featureSetQueries.get(i))
+              .setDestinationTable(TableId.of(projectId(), datasetId(), createTempTableName()))
+              .build();
       Job subqueryJob = bigquery().create(JobInfo.of(queryJobConfig));
       executorCompletionService.submit(
           SubqueryCallable.builder()
+              .setBigquery(bigquery())
               .setFeatureSetInfo(featureSetInfos().get(i))
               .setSubqueryJob(subqueryJob)
               .build());
@@ -214,10 +221,21 @@ Job runBatchQuery(List<String> featureSetQueries)
     String joinQuery =
         QueryTemplater.createJoinQuery(
             featureSetInfos, entityTableColumnNames(), entityTableName());
-    QueryJobConfiguration queryJobConfig = QueryJobConfiguration.newBuilder(joinQuery).build();
+    QueryJobConfiguration queryJobConfig =
+        QueryJobConfiguration.newBuilder(joinQuery)
+            .setDestinationTable(TableId.of(projectId(), datasetId(), createTempTableName()))
+            .build();
     queryJob = bigquery().create(JobInfo.of(queryJobConfig));
     queryJob.waitFor();
 
+    TableInfo expiry =
+        bigquery()
+            .getTable(queryJobConfig.getDestinationTable())
+            .toBuilder()
+            .setExpirationTime(System.currentTimeMillis() + TEMP_TABLE_EXPIRY_DURATION_MS)
+            .build();
+    bigquery().update(expiry);
+
     return queryJob;
   }
 
@@ -248,10 +266,18 @@ private FieldValueList getTimestampLimits(String entityTableName) {
     QueryJobConfiguration getTimestampLimitsQuery =
         QueryJobConfiguration.newBuilder(createTimestampLimitQuery(entityTableName))
            .setDefaultDataset(DatasetId.of(projectId(), datasetId()))
+           .setDestinationTable(TableId.of(projectId(), datasetId(), createTempTableName()))
            .build();
     try {
       Job job = bigquery().create(JobInfo.of(getTimestampLimitsQuery));
       TableResult getTimestampLimitsQueryResult = job.waitFor().getQueryResults();
+      TableInfo expiry =
+          bigquery()
+              .getTable(getTimestampLimitsQuery.getDestinationTable())
+              .toBuilder()
+              .setExpirationTime(System.currentTimeMillis() + TEMP_TABLE_EXPIRY_DURATION_MS)
+              .build();
+      bigquery().update(expiry);
       FieldValueList result = null;
       for (FieldValueList fields : getTimestampLimitsQueryResult.getValues()) {
         result = fields;
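For context on the submit loop in runBatchQuery above: each feature-set subquery runs as its own BigQuery job, and results are collected as jobs complete. A reduced sketch of that fan-out shape, with the executor setup and result type simplified to stand-ins (not the class's actual code):

// FanOutSketch.java — illustrative reduction, not part of this commit.
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.CompletionService;
import java.util.concurrent.ExecutorCompletionService;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

public class FanOutSketch {
  static List<String> runAll(List<Callable<String>> subqueries) throws Exception {
    ExecutorService executor = Executors.newFixedThreadPool(4);
    CompletionService<String> completionService = new ExecutorCompletionService<>(executor);
    for (Callable<String> subquery : subqueries) {
      completionService.submit(subquery); // fan out: one job per feature set
    }
    List<String> results = new ArrayList<>();
    for (int i = 0; i < subqueries.size(); i++) {
      // take() yields futures in completion order, not submission order.
      results.add(completionService.take().get());
    }
    executor.shutdown();
    return results;
  }
}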
@@ -16,13 +16,16 @@
  */
 package feast.serving.store.bigquery;
 
+import static feast.serving.service.BigQueryServingService.TEMP_TABLE_EXPIRY_DURATION_MS;
 import static feast.serving.store.bigquery.QueryTemplater.generateFullTableName;
 
 import com.google.auto.value.AutoValue;
+import com.google.cloud.bigquery.BigQuery;
 import com.google.cloud.bigquery.BigQueryException;
 import com.google.cloud.bigquery.Job;
 import com.google.cloud.bigquery.QueryJobConfiguration;
 import com.google.cloud.bigquery.TableId;
+import com.google.cloud.bigquery.TableInfo;
 import feast.serving.store.bigquery.model.FeatureSetInfo;
 import java.util.concurrent.Callable;
 
@@ -33,6 +36,8 @@
 @AutoValue
 public abstract class SubqueryCallable implements Callable<FeatureSetInfo> {
 
+  public abstract BigQuery bigquery();
+
   public abstract FeatureSetInfo featureSetInfo();
 
   public abstract Job subqueryJob();
@@ -44,6 +49,8 @@ public static Builder builder() {
   @AutoValue.Builder
   public abstract static class Builder {
 
+    public abstract Builder setBigquery(BigQuery bigquery);
+
     public abstract Builder setFeatureSetInfo(FeatureSetInfo featureSetInfo);
 
     public abstract Builder setSubqueryJob(Job subqueryJob);
@@ -57,6 +64,15 @@ public FeatureSetInfo call() throws BigQueryException, InterruptedException {
     subqueryJob().waitFor();
     subqueryConfig = subqueryJob().getConfiguration();
     TableId destinationTable = subqueryConfig.getDestinationTable();
+
+    TableInfo expiry =
+        bigquery()
+            .getTable(destinationTable)
+            .toBuilder()
+            .setExpirationTime(System.currentTimeMillis() + TEMP_TABLE_EXPIRY_DURATION_MS)
+            .build();
+    bigquery().update(expiry);
+
     String fullTablePath = generateFullTableName(destinationTable);
 
     return new FeatureSetInfo(featureSetInfo(), fullTablePath);
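Since bigquery() is now a required AutoValue property, every construction site has to thread the client through the builder. A hypothetical call site, mirroring the one in runBatchQuery above (the class and method names here are illustrative, not part of this commit):

// SubqueryCallableUsageSketch.java — hypothetical call site, for illustration.
import com.google.cloud.bigquery.BigQuery;
import com.google.cloud.bigquery.Job;
import feast.serving.store.bigquery.SubqueryCallable;
import feast.serving.store.bigquery.model.FeatureSetInfo;

public class SubqueryCallableUsageSketch {
  static FeatureSetInfo runSubquery(BigQuery bigquery, FeatureSetInfo info, Job subqueryJob)
      throws Exception {
    return SubqueryCallable.builder()
        .setBigquery(bigquery) // the property added by this commit
        .setFeatureSetInfo(info)
        .setSubqueryJob(subqueryJob)
        .build()
        .call(); // waits for the job, sets expiry, returns the updated info
  }
}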
7 changes: 5 additions & 2 deletions tests/e2e/bq-batch-retrieval.py
@@ -18,8 +18,11 @@
 from google.protobuf.duration_pb2 import Duration
 from pandavro import to_avro
 
+pd.set_option('display.max_columns', None)
+
 PROJECT_NAME = 'batch_' + uuid.uuid4().hex.upper()[0:6]
 
+
 @pytest.fixture(scope="module")
 def core_url(pytestconfig):
     return pytestconfig.getoption("core_url")
@@ -319,8 +322,8 @@ def test_multiple_featureset_joins(client):
     feature_retrieval_job = client.get_batch_features(
         entity_rows=entity_df, feature_refs=[f"{PROJECT_NAME}/feature_value6:1", f"{PROJECT_NAME}/other_feature_value7:1"]
     )
-    output = feature_retrieval_job.to_dataframe()
-    print(output.head())
+    output = feature_retrieval_job.to_dataframe().sort_values(by=["entity_id"])
+    print(output.head(10))
 
     assert output["entity_id"].to_list() == [int(i) for i in output["feature_value6"].to_list()]
     assert output["other_entity_id"].to_list() == output["other_feature_value7"].to_list()
