Skip to content

Commit

Permalink
Expose the registrations for avg and count (facebookincubator#4)
Browse files Browse the repository at this point in the history
  • Loading branch information
rui-mo authored and zhejiangxiaomai committed Jul 26, 2022
1 parent 0f53935 commit d07c4bf
Show file tree
Hide file tree
Showing 5 changed files with 528 additions and 489 deletions.
332 changes: 2 additions & 330 deletions velox/functions/prestosql/aggregates/AverageAggregate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,340 +13,12 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "velox/exec/Aggregate.h"
#include "velox/expression/FunctionSignature.h"
#include "velox/functions/prestosql/aggregates/AverageAggregate.h"
#include "velox/functions/prestosql/aggregates/AggregateNames.h"
#include "velox/vector/ComplexVector.h"
#include "velox/vector/DecodedVector.h"
#include "velox/vector/FlatVector.h"

namespace facebook::velox::aggregate {

namespace {

struct SumCount {
double sum{0};
int64_t count{0};
};

// Partial aggregation produces a pair of sum and count.
// Final aggregation takes a pair of sum and count and returns a real for real
// input types and double for other input types.
// T is the input type for partial aggregation. Not used for final aggregation.
template <typename T>
class AverageAggregate : public exec::Aggregate {
public:
explicit AverageAggregate(TypePtr resultType) : exec::Aggregate(resultType) {}

int32_t accumulatorFixedWidthSize() const override {
return sizeof(SumCount);
}

void initializeNewGroups(
char** groups,
folly::Range<const vector_size_t*> indices) override {
setAllNulls(groups, indices);
for (auto i : indices) {
new (groups[i] + offset_) SumCount();
}
}

void finalize(char** /* unused */, int32_t /* unused */) override {}

void extractValues(char** groups, int32_t numGroups, VectorPtr* result)
override {
// Real input type in Presto has special case and returns REAL, not DOUBLE.
if (resultType_->isDouble()) {
extractValuesImpl<double>(groups, numGroups, result);
} else {
extractValuesImpl<float>(groups, numGroups, result);
}
}

void extractAccumulators(char** groups, int32_t numGroups, VectorPtr* result)
override {
auto rowVector = (*result)->as<RowVector>();
auto sumVector = rowVector->childAt(0)->asFlatVector<double>();
auto countVector = rowVector->childAt(1)->asFlatVector<int64_t>();

rowVector->resize(numGroups);
sumVector->resize(numGroups);
countVector->resize(numGroups);
uint64_t* rawNulls = getRawNulls(rowVector);

int64_t* rawCounts = countVector->mutableRawValues();
double* rawSums = sumVector->mutableRawValues();
for (auto i = 0; i < numGroups; ++i) {
char* group = groups[i];
if (isNull(group)) {
rowVector->setNull(i, true);
} else {
clearNull(rawNulls, i);
auto* sumCount = accumulator(group);
rawCounts[i] = sumCount->count;
rawSums[i] = sumCount->sum;
}
}
}

void addRawInput(
char** groups,
const SelectivityVector& rows,
const std::vector<VectorPtr>& args,
bool /*mayPushdown*/) override {
decodedRaw_.decode(*args[0], rows);
if (decodedRaw_.isConstantMapping()) {
if (!decodedRaw_.isNullAt(0)) {
auto value = decodedRaw_.valueAt<T>(0);
rows.applyToSelected(
[&](vector_size_t i) { updateNonNullValue(groups[i], value); });
}
} else if (decodedRaw_.mayHaveNulls()) {
rows.applyToSelected([&](vector_size_t i) {
if (decodedRaw_.isNullAt(i)) {
return;
}
updateNonNullValue(groups[i], decodedRaw_.valueAt<T>(i));
});
} else if (!exec::Aggregate::numNulls_ && decodedRaw_.isIdentityMapping()) {
auto data = decodedRaw_.data<T>();
rows.applyToSelected([&](vector_size_t i) {
updateNonNullValue<false>(groups[i], data[i]);
});
} else {
rows.applyToSelected([&](vector_size_t i) {
updateNonNullValue(groups[i], decodedRaw_.valueAt<T>(i));
});
}
}

void addSingleGroupRawInput(
char* group,
const SelectivityVector& rows,
const std::vector<VectorPtr>& args,
bool /*mayPushdown*/) override {
decodedRaw_.decode(*args[0], rows);

if (decodedRaw_.isConstantMapping()) {
if (!decodedRaw_.isNullAt(0)) {
const T value = decodedRaw_.valueAt<T>(0);
const auto numRows = rows.countSelected();
updateNonNullValue(group, numRows, value * numRows);
}
} else if (decodedRaw_.mayHaveNulls()) {
rows.applyToSelected([&](vector_size_t i) {
if (!decodedRaw_.isNullAt(i)) {
updateNonNullValue(group, decodedRaw_.valueAt<T>(i));
}
});
} else if (!exec::Aggregate::numNulls_ && decodedRaw_.isIdentityMapping()) {
const T* data = decodedRaw_.data<T>();
double totalSum = 0;
rows.applyToSelected([&](vector_size_t i) { totalSum += data[i]; });
updateNonNullValue<false>(group, rows.countSelected(), totalSum);
} else {
double totalSum = 0;
rows.applyToSelected(
[&](vector_size_t i) { totalSum += decodedRaw_.valueAt<T>(i); });
updateNonNullValue(group, rows.countSelected(), totalSum);
}
}

void addIntermediateResults(
char** groups,
const SelectivityVector& rows,
const std::vector<VectorPtr>& args,
bool /* mayPushdown */) override {
decodedPartial_.decode(*args[0], rows);
auto baseRowVector = dynamic_cast<const RowVector*>(decodedPartial_.base());
auto baseSumVector = baseRowVector->childAt(0)->as<SimpleVector<double>>();
auto baseCountVector =
baseRowVector->childAt(1)->as<SimpleVector<int64_t>>();

if (decodedPartial_.isConstantMapping()) {
if (!decodedPartial_.isNullAt(0)) {
auto decodedIndex = decodedPartial_.index(0);
auto count = baseCountVector->valueAt(decodedIndex);
auto sum = baseSumVector->valueAt(decodedIndex);
rows.applyToSelected([&](vector_size_t i) {
updateNonNullValue(groups[i], count, sum);
});
}
} else if (decodedPartial_.mayHaveNulls()) {
rows.applyToSelected([&](vector_size_t i) {
if (decodedPartial_.isNullAt(i)) {
return;
}
auto decodedIndex = decodedPartial_.index(i);
updateNonNullValue(
groups[i],
baseCountVector->valueAt(decodedIndex),
baseSumVector->valueAt(decodedIndex));
});
} else {
rows.applyToSelected([&](vector_size_t i) {
auto decodedIndex = decodedPartial_.index(i);
updateNonNullValue(
groups[i],
baseCountVector->valueAt(decodedIndex),
baseSumVector->valueAt(decodedIndex));
});
}
}

void addSingleGroupIntermediateResults(
char* group,
const SelectivityVector& rows,
const std::vector<VectorPtr>& args,
bool /* mayPushdown */) override {
decodedPartial_.decode(*args[0], rows);
auto baseRowVector = dynamic_cast<const RowVector*>(decodedPartial_.base());
auto baseSumVector = baseRowVector->childAt(0)->as<SimpleVector<double>>();
auto baseCountVector =
baseRowVector->childAt(1)->as<SimpleVector<int64_t>>();

if (decodedPartial_.isConstantMapping()) {
if (!decodedPartial_.isNullAt(0)) {
auto decodedIndex = decodedPartial_.index(0);
const auto numRows = rows.countSelected();
auto totalCount = baseCountVector->valueAt(decodedIndex) * numRows;
auto totalSum = baseSumVector->valueAt(decodedIndex) * numRows;
updateNonNullValue(group, totalCount, totalSum);
}
} else if (decodedPartial_.mayHaveNulls()) {
rows.applyToSelected([&](vector_size_t i) {
if (!decodedPartial_.isNullAt(i)) {
auto decodedIndex = decodedPartial_.index(i);
updateNonNullValue(
group,
baseCountVector->valueAt(decodedIndex),
baseSumVector->valueAt(decodedIndex));
}
});
} else {
double totalSum = 0;
int64_t totalCount = 0;
rows.applyToSelected([&](vector_size_t i) {
auto decodedIndex = decodedPartial_.index(i);
totalCount += baseCountVector->valueAt(decodedIndex);
totalSum += baseSumVector->valueAt(decodedIndex);
});
updateNonNullValue(group, totalCount, totalSum);
}
}

private:
// partial
template <bool tableHasNulls = true>
inline void updateNonNullValue(char* group, T value) {
if constexpr (tableHasNulls) {
exec::Aggregate::clearNull(group);
}
accumulator(group)->sum += value;
accumulator(group)->count += 1;
}

template <bool tableHasNulls = true>
inline void updateNonNullValue(char* group, int64_t count, double sum) {
if constexpr (tableHasNulls) {
exec::Aggregate::clearNull(group);
}
accumulator(group)->sum += sum;
accumulator(group)->count += count;
}

inline SumCount* accumulator(char* group) {
return exec::Aggregate::value<SumCount>(group);
}

template <typename TResult>
void extractValuesImpl(char** groups, int32_t numGroups, VectorPtr* result) {
auto vector = (*result)->as<FlatVector<TResult>>();
VELOX_CHECK(vector);
vector->resize(numGroups);
uint64_t* rawNulls = getRawNulls(vector);

TResult* rawValues = vector->mutableRawValues();
for (int32_t i = 0; i < numGroups; ++i) {
char* group = groups[i];
if (isNull(group)) {
vector->setNull(i, true);
} else {
clearNull(rawNulls, i);
auto* sumCount = accumulator(group);
rawValues[i] = (TResult)sumCount->sum / sumCount->count;
}
}
}

DecodedVector decodedRaw_;
DecodedVector decodedPartial_;
};

void checkSumCountRowType(TypePtr type, const std::string& errorMessage) {
VELOX_CHECK_EQ(type->kind(), TypeKind::ROW, "{}", errorMessage);
VELOX_CHECK_EQ(
type->childAt(0)->kind(), TypeKind::DOUBLE, "{}", errorMessage);
VELOX_CHECK_EQ(
type->childAt(1)->kind(), TypeKind::BIGINT, "{}", errorMessage);
}

bool registerAverageAggregate(const std::string& name) {
std::vector<std::shared_ptr<exec::AggregateFunctionSignature>> signatures;

for (const auto& inputType : {"smallint", "integer", "bigint", "double"}) {
signatures.push_back(exec::AggregateFunctionSignatureBuilder()
.returnType("double")
.intermediateType("row(double,bigint)")
.argumentType(inputType)
.build());
}
// Real input type in Presto has special case and returns REAL, not DOUBLE.
signatures.push_back(exec::AggregateFunctionSignatureBuilder()
.returnType("real")
.intermediateType("row(double,bigint)")
.argumentType("real")
.build());

exec::registerAggregateFunction(
name,
std::move(signatures),
[name](
core::AggregationNode::Step step,
const std::vector<TypePtr>& argTypes,
const TypePtr& resultType) -> std::unique_ptr<exec::Aggregate> {
VELOX_CHECK_LE(
argTypes.size(), 1, "{} takes at most one argument", name);
auto inputType = argTypes[0];
if (exec::isRawInput(step)) {
switch (inputType->kind()) {
case TypeKind::SMALLINT:
return std::make_unique<AverageAggregate<int16_t>>(resultType);
case TypeKind::INTEGER:
return std::make_unique<AverageAggregate<int32_t>>(resultType);
case TypeKind::BIGINT:
return std::make_unique<AverageAggregate<int64_t>>(resultType);
case TypeKind::REAL:
return std::make_unique<AverageAggregate<float>>(resultType);
case TypeKind::DOUBLE:
return std::make_unique<AverageAggregate<double>>(resultType);
default:
VELOX_FAIL(
"Unknown input type for {} aggregation {}",
name,
inputType->kindName());
}
} else {
checkSumCountRowType(
inputType,
"Input type for final aggregation must be (sum:double, count:bigint) struct");
return std::make_unique<AverageAggregate<int64_t>>(resultType);
}
});
return true;
}

static bool FB_ANONYMOUS_VARIABLE(g_AggregateFunction) =
registerAverageAggregate(kAvg);
} // namespace

} // namespace facebook::velox::aggregate
Loading

0 comments on commit d07c4bf

Please sign in to comment.