Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Sum and Average aggregations #1387

Merged
merged 14 commits into from
Oct 9, 2023
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
/*
* Copyright 2023 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.google.cloud.firestore;

import java.util.Objects;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;

/** Represents an aggregation that can be performed by Firestore. */
public abstract class AggregateField {
ehsannas marked this conversation as resolved.
Show resolved Hide resolved
/**
* Create a {@link CountAggregateField} object that can be used to compute the count of documents
* in the result set of a query.
*/
@Nonnull
public static CountAggregateField count() {
return new CountAggregateField();
}

/**
* Create a {@link SumAggregateField} object that can be used to compute the sum of a specified
* field over a range of documents in the result set of a query.
*
* @param field Specifies the field to sum across the result set.
*/
@Nonnull
public static SumAggregateField sum(@Nonnull String field) {
return new SumAggregateField(FieldPath.fromDotSeparatedString(field));
}

/**
* Create a {@link SumAggregateField} object that can be used to compute the sum of a specified
* field over a range of documents in the result set of a query.
*
* @param fieldPath Specifies the field to sum across the result set.
*/
@Nonnull
public static SumAggregateField sum(@Nonnull FieldPath fieldPath) {
return new SumAggregateField(fieldPath);
}

/**
* Create an {@link AverageAggregateField} object that can be used to compute the average of a
* specified field over a range of documents in the result set of a query.
*
* @param field Specifies the field to average across the result set.
*/
@Nonnull
public static AverageAggregateField average(@Nonnull String field) {
return new AverageAggregateField(FieldPath.fromDotSeparatedString(field));
}

/**
* Create an {@link AverageAggregateField} object that can be used to compute the average of a
* specified field over a range of documents in the result set of a query.
*
* @param fieldPath Specifies the field to average across the result set.
*/
@Nonnull
public static AverageAggregateField average(@Nonnull FieldPath fieldPath) {
return new AverageAggregateField(fieldPath);
}

/** The field over which the aggregation is performed. */
@Nullable FieldPath fieldPath;

/** Returns the alias used internally for this aggregate field. */
@Nonnull
String getAlias() {
// Use $operator_$field format if it's an aggregation of a specific field. For example: sum_foo.
// Use $operator format if there's no field. For example: count.
return getOperator() + (fieldPath == null ? "" : "_" + fieldPath.getEncodedPath());
}

/**
* Returns the field on which the aggregation takes place. Returns an empty string if there's no
* field (e.g. for count).
*/
@Nonnull
String getFieldPath() {
return fieldPath == null ? "" : fieldPath.getEncodedPath();
}

/** Returns a string representation of this aggregation's operator. For example: "sum" */
abstract @Nonnull String getOperator();

/**
* Returns true if the given object is equal to this object. Two `AggregateField` objects are
* considered equal if they have the same operator and operate on the same field.
*/
@Override
public boolean equals(Object other) {
if (this == other) {
return true;
}
if (!(other instanceof AggregateField)) {
return false;
}
AggregateField otherAggregateField = (AggregateField) other;
return getOperator().equals(otherAggregateField.getOperator())
&& getFieldPath().equals(otherAggregateField.getFieldPath());
}

/** Calculates and returns the hash code for this object. */
@Override
public int hashCode() {
return Objects.hash(getOperator(), getFieldPath());
}

/** Represents a "sum" aggregation that can be performed by Firestore. */
public static class SumAggregateField extends AggregateField {
private SumAggregateField(@Nonnull FieldPath field) {
fieldPath = field;
}

@Override
@Nonnull
public String getOperator() {
return "sum";
}
}

/** Represents an "average" aggregation that can be performed by Firestore. */
public static class AverageAggregateField extends AggregateField {
private AverageAggregateField(@Nonnull FieldPath field) {
fieldPath = field;
}

@Override
@Nonnull
public String getOperator() {
return "average";
}
}

/** Represents a "count" aggregation that can be performed by Firestore. */
public static class CountAggregateField extends AggregateField {
private CountAggregateField() {}

@Override
@Nonnull
public String getOperator() {
return "count";
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,32 +25,27 @@
import com.google.api.gax.rpc.StreamController;
import com.google.cloud.Timestamp;
import com.google.cloud.firestore.v1.FirestoreSettings;
import com.google.firestore.v1.RunAggregationQueryRequest;
import com.google.firestore.v1.RunAggregationQueryResponse;
import com.google.firestore.v1.RunQueryRequest;
import com.google.firestore.v1.StructuredAggregationQuery;
import com.google.firestore.v1.Value;
import com.google.firestore.v1.*;
ehsannas marked this conversation as resolved.
Show resolved Hide resolved
import com.google.firestore.v1.StructuredAggregationQuery.Aggregation;
import com.google.protobuf.ByteString;
import java.util.Set;
import java.util.*;
import java.util.concurrent.atomic.AtomicBoolean;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;

/** A query that calculates aggregations over an underlying query. */
@InternalExtensionOnly
public class AggregateQuery {
@Nonnull private final Query query;

/**
* The "alias" to specify in the {@link RunAggregationQueryRequest} proto when running a count
* query. The actual value is not meaningful, but will be used to get the count out of the {@link
* RunAggregationQueryResponse}.
*/
private static final String ALIAS_COUNT = "count";
@Nonnull private List<AggregateField> aggregateFieldList;

@Nonnull private final Query query;
@Nonnull private Map<String, String> aliasMap;

AggregateQuery(@Nonnull Query query) {
AggregateQuery(@Nonnull Query query, @Nonnull List<AggregateField> aggregateFields) {
this.query = query;
this.aggregateFieldList = aggregateFields;
this.aliasMap = new HashMap<>();
}

/** Returns the query whose aggregations will be calculated by this object. */
Expand Down Expand Up @@ -112,9 +107,11 @@ long getStartTimeNanos() {
return startTimeNanos;
}

void deliverResult(long count, Timestamp readTime) {
void deliverResult(@Nonnull Map<String, Value> data, Timestamp readTime) {
if (isFutureCompleted.compareAndSet(false, true)) {
future.set(new AggregateQuerySnapshot(AggregateQuery.this, readTime, count));
Map<String, Value> mappedData = new HashMap<>();
data.forEach((serverAlias, value) -> mappedData.put(aliasMap.get(serverAlias), value));
future.set(new AggregateQuerySnapshot(AggregateQuery.this, readTime, mappedData));
}
}

Expand Down Expand Up @@ -145,26 +142,13 @@ public void onResponse(RunAggregationQueryResponse response) {
// Close the stream to avoid it dangling, since we're not expecting any more responses.
streamController.cancel();

// Extract the count and read time from the RunAggregationQueryResponse.
// Extract the aggregations and read time from the RunAggregationQueryResponse.
Timestamp readTime = Timestamp.fromProto(response.getReadTime());
Value value = response.getResult().getAggregateFieldsMap().get(ALIAS_COUNT);
if (value == null) {
throw new IllegalArgumentException(
"RunAggregationQueryResponse is missing required alias: " + ALIAS_COUNT);
} else if (value.getValueTypeCase() != Value.ValueTypeCase.INTEGER_VALUE) {
throw new IllegalArgumentException(
"RunAggregationQueryResponse alias "
+ ALIAS_COUNT
+ " has incorrect type: "
+ value.getValueTypeCase());
}
long count = value.getIntegerValue();

// Deliver the result; even though the `RunAggregationQuery` RPC is a "streaming" RPC, meaning
// that `onResponse()` can be called multiple times, it _should_ only be called once for count
// queries. But even if it is called more than once, `responseDeliverer` will drop superfluous
// results.
responseDeliverer.deliverResult(count, readTime);
// that `onResponse()` can be called multiple times, it _should_ only be called once. But even
// if it is called more than once, `responseDeliverer` will drop superfluous results.
responseDeliverer.deliverResult(response.getResult().getAggregateFieldsMap(), readTime);
}

@Override
Expand Down Expand Up @@ -215,12 +199,45 @@ RunAggregationQueryRequest toProto(@Nullable final ByteString transactionId) {
request.getStructuredAggregationQueryBuilder();
structuredAggregationQuery.setStructuredQuery(runQueryRequest.getStructuredQuery());

StructuredAggregationQuery.Aggregation.Builder aggregation =
StructuredAggregationQuery.Aggregation.newBuilder();
aggregation.setCount(StructuredAggregationQuery.Aggregation.Count.getDefaultInstance());
aggregation.setAlias(ALIAS_COUNT);
structuredAggregationQuery.addAggregations(aggregation);
// We use this set to remove duplicate aggregates. e.g. `aggregate(sum("foo"), sum("foo"))`
HashSet<String> uniqueAggregates = new HashSet<>();
List<StructuredAggregationQuery.Aggregation> aggregations = new ArrayList<>();
int aggregationNum = 0;
for (AggregateField aggregateField : aggregateFieldList) {
// `getAlias()` provides a unique representation of an AggregateField.
boolean isNewAggregateField = uniqueAggregates.add(aggregateField.getAlias());
if (!isNewAggregateField) {
// This is a duplicate AggregateField. We don't need to include it in the request.
continue;
}

// If there's a field for this aggregation, build its proto.
StructuredQuery.FieldReference field = null;
if (!aggregateField.getFieldPath().isEmpty()) {
field =
StructuredQuery.FieldReference.newBuilder()
.setFieldPath(aggregateField.getFieldPath())
.build();
}
// Build the aggregation proto.
Aggregation.Builder aggregation = Aggregation.newBuilder();
if (aggregateField instanceof AggregateField.CountAggregateField) {
aggregation.setCount(Aggregation.Count.getDefaultInstance());
} else if (aggregateField instanceof AggregateField.SumAggregateField) {
aggregation.setSum(Aggregation.Sum.newBuilder().setField(field).build());
} else if (aggregateField instanceof AggregateField.AverageAggregateField) {
aggregation.setAvg(Aggregation.Avg.newBuilder().setField(field).build());
} else {
throw new RuntimeException("Unsupported aggregation");
}
// Map all client-side aliases to a unique short-form alias.
// This avoids issues with client-side aliases that exceed the 1500-byte string size limit.
String serverAlias = "aggregate_" + aggregationNum++;
aliasMap.put(serverAlias, aggregateField.getAlias());
aggregation.setAlias(serverAlias);
aggregations.add(aggregation.build());
}
structuredAggregationQuery.addAllAggregations(aggregations);
return request.build();
}

Expand All @@ -243,7 +260,23 @@ public static AggregateQuery fromProto(Firestore firestore, RunAggregationQueryR
.setStructuredQuery(proto.getStructuredAggregationQuery().getStructuredQuery())
.build();
Query query = Query.fromProto(firestore, runQueryRequest);
return new AggregateQuery(query);

List<AggregateField> aggregateFields = new ArrayList<>();
List<Aggregation> aggregations = proto.getStructuredAggregationQuery().getAggregationsList();
aggregations.forEach(
aggregation -> {
if (aggregation.hasCount()) {
aggregateFields.add(AggregateField.count());
} else if (aggregation.hasAvg()) {
aggregateFields.add(
AggregateField.average(aggregation.getAvg().getField().getFieldPath()));
} else if (aggregation.hasSum()) {
aggregateFields.add(AggregateField.sum(aggregation.getSum().getField().getFieldPath()));
} else {
throw new RuntimeException("Unsupported aggregation.");
}
});
return new AggregateQuery(query, aggregateFields);
}

/**
Expand All @@ -253,7 +286,7 @@ public static AggregateQuery fromProto(Firestore firestore, RunAggregationQueryR
*/
@Override
public int hashCode() {
return query.hashCode();
return Objects.hash(query, aggregateFieldList);
}

/**
Expand All @@ -280,6 +313,6 @@ public boolean equals(Object object) {
return false;
}
AggregateQuery other = (AggregateQuery) object;
return query.equals(other.query);
return query.equals(other.query) && aggregateFieldList.equals(other.aggregateFieldList);
}
}
Loading