Skip to content

Commit

Permalink
builtins: fix aggregate functions for decimals
Browse files Browse the repository at this point in the history
fixes cockroachdb#55944

Release note: none
  • Loading branch information
mneverov committed Oct 25, 2020
1 parent 8aceac3 commit afc91f7
Show file tree
Hide file tree
Showing 3 changed files with 131 additions and 13 deletions.
102 changes: 90 additions & 12 deletions pkg/sql/logictest/testdata/logic_test/aggregate
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,12 @@ CREATE TABLE kv (
)

# Aggregate functions return NULL if there are no rows.
query IIIIRRRRRRRRBBTIIR
SELECT min(1), max(1), count(1), sum_int(1), avg(1), sum(1), stddev(1), stddev_samp(1), stddev_pop(1), var_samp(1), variance(1), var_pop(1), bool_and(true), bool_and(false), xor_agg(b'\x01'), bit_and(1), bit_or(1), corr(1, 1) FROM kv
query IIIIRRRRRRRRBBTIIRR
SELECT min(1), max(1), count(1), sum_int(1), avg(1), sum(1), stddev(1), stddev_samp(1), stddev_pop(1), var_samp(1),
variance(1), var_pop(1), bool_and(true), bool_and(false), xor_agg(b'\x01'), bit_and(1), bit_or(1), corr(1, 1), sqrdiff(1)
FROM kv
----
NULL NULL 0 NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL
NULL NULL 0 NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL

# Regression test for #29695
query T
Expand Down Expand Up @@ -42,10 +44,12 @@ SELECT min(i), avg(i), max(i), sum(i) FROM kv
----
NULL NULL NULL NULL

query IIIIRRRRRRBBTR
SELECT min(v), max(v), count(v), sum_int(1), avg(v), sum(v), stddev(v), stddev_pop(v), variance(v), var_pop(v), bool_and(v = 1), bool_and(v = 1), xor_agg(s::bytes), corr(v,k) FROM kv
query IIIIRRRRRRRBBTRR
SELECT min(v), max(v), count(v), sum_int(1), avg(v), sum(v), stddev(v), stddev_pop(v), variance(v), var_pop(v), var_samp(v),
bool_and(v = 1), bool_and(v = 1), xor_agg(s::bytes), corr(v,k), sqrdiff(v)
FROM kv
----
NULL NULL 0 NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL
NULL NULL 0 NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL

query T
SELECT array_agg(v) FROM kv
Expand All @@ -63,10 +67,11 @@ SELECT jsonb_agg(v) FROM kv
NULL

# Aggregate functions trigger aggregation and computation when there is no source.
query IIIIRRRRRRBBTR
SELECT min(1), count(1), max(1), sum_int(1), avg(1)::float, sum(1), stddev_samp(1), stddev_pop(1), variance(1), var_pop(1), bool_and(true), bool_or(true), to_hex(xor_agg(b'\x01')), corr(1, 2)
query IIIIRRRRRRRBBTRR
SELECT min(1), count(1), max(1), sum_int(1), avg(1)::float, sum(1), stddev_samp(1), stddev_pop(1), variance(1),
var_pop(1), var_samp(1), bool_and(true), bool_or(true), to_hex(xor_agg(b'\x01')), corr(1, 2), sqrdiff(1)
----
1 1 1 1 1 1 NULL 0 NULL 0 true true 01 NULL
1 1 1 1 1 1 NULL 0 NULL 0 NULL true true 01 NULL 0

# Aggregate functions trigger aggregation and computation when there is no source.
query T
Expand Down Expand Up @@ -162,10 +167,12 @@ INSERT INTO kv VALUES

# Aggregate functions trigger aggregation and computation for every row even when applied to a constant.
# NB: The XOR result is 00 because \x01 is XOR'd an even number of times.
query IIIIRRRRRRBBT
SELECT min(1), count(1), max(1), sum_int(1), avg(1)::float, sum(1), stddev(1), stddev_pop(1), variance(1)::float, var_pop(1)::float, bool_and(true), bool_or(true), to_hex(xor_agg(b'\x01')) FROM kv
query IIIIRRRRRRRBBTR
SELECT min(1), count(1), max(1), sum_int(1), avg(1)::float, sum(1), stddev(1), stddev_pop(1), variance(1)::float,
var_pop(1)::float, var_samp(1)::float, bool_and(true), bool_or(true), to_hex(xor_agg(b'\x01')), sqrdiff(1)
FROM kv
----
1 6 1 6 1 6 0 0 0 0 true true 00
1 6 1 6 1 6 0 0 0 0 0 true true 00 0

# Aggregate functions trigger aggregation and computation for every row even when applied to a constant.
query T
Expand Down Expand Up @@ -832,6 +839,40 @@ true
statement ok
DROP TABLE xyz

# SQRDIFF
#
# Exercises sqrdiff() over an INT, a FLOAT and a DECIMAL column of a six-row
# fixture table, first over all rows and then restricted to i < 10.

statement ok
DROP TABLE IF EXISTS ifd;
CREATE TABLE ifd
(
    i INT,
    f FLOAT,
    d DECIMAL
);
INSERT INTO ifd (i, f, d)
VALUES (1, 1.1, 1.1),
       (2, 2.2, 2.2),
       (5, 3.0, 3.0),
       (10, 7.8, 7.8),
       (11, 9.0, 9.0),
       (18, 11.2, 11.2);

# Column types: sqrdiff over INT and DECIMAL yields a decimal (R); the FLOAT
# result is a float (F) and is rounded so the expected output has a fixed
# number of digits.
query RFR
SELECT sqrdiff(i), round(sqrdiff(f), 13), sqrdiff(d)
FROM ifd
----
206.8333333333333333333334 86.2483333333333 86.24833333333333333333333

query RFR
SELECT sqrdiff(i), round(sqrdiff(f), 2), sqrdiff(d)
FROM ifd
WHERE i < 10
----
8.666666666666666666666666 1.82 1.82

# Clean up the fixture table created at the top of this section.
statement ok
DROP TABLE IF EXISTS ifd

# Numerical stability test for VARIANCE/STDDEV.
# See https://www.johndcook.com/blog/2008/09/28/theoretical-explanation-for-numerical-results.
# Avoid using random() since we do not have the deterministic option to specify a pseudo-random seed yet.
Expand Down Expand Up @@ -2817,3 +2858,40 @@ query FI
SELECT corr(DISTINCT y, x), count(DISTINCT y) FROM t55776
----
0.522232967867094 3

# Regression test for #55944: when used as window functions, the decimal
# aggregates must return a result for each row that does not share state with
# the results already returned for previous rows.
statement ok
CREATE TABLE t55944 (x DECIMAL);
INSERT INTO t55944 (x)
VALUES (1.0),
       (20.0),
       (25.0),
       (41.0),
       (55.5),
       (60.9),
       (72.0),
       (88.0),
       (88.0),
       (89.0);

# Every output column is a decimal, hence the all-R type signature. The two
# 88.0 rows are peers under ORDER BY x and therefore report identical values.
query RRRRRR
SELECT x,
       sqrdiff(x) OVER (ORDER BY x) AS sqrdiff,
       var_pop(x) OVER (ORDER BY x) AS var_pop,
       var_samp(x) OVER (ORDER BY x) AS var_samp,
       stddev_pop(x) OVER (ORDER BY x) AS stddev_pop,
       stddev_samp(x) OVER (ORDER BY x) AS stddev_samp
FROM t55944
ORDER BY x
----
1.0 0 0 NULL 0 NULL
20.0 180.5 90.25 180.5 9.5 13.435028842544402964
25.0 320.6666666666666666666667 106.88888888888888889 160.33333333333333333 10.338708279513881752 12.662279942148385993
41.0 814.7500000000000000000001 203.6875 271.58333333333333333 14.271912976192084391 16.479785597310825856
55.5 1726 345.2 431.5 18.579558659989746915 20.772578077840988103
60.9 2600.8 433.46666666666666667 520.16 20.819862311424316109 22.807016464237491316
72.0 3845.037142857142857142857 549.29102040816326531 640.83952380952380952 23.436958429117103874 25.314808389745394427
88.0 7527.842222222222222222222 836.42691358024691358 940.98027777777777778 28.921046204801217626 30.675401835636607970
88.0 7527.842222222222222222222 836.42691358024691358 940.98027777777777778 28.921046204801217626 30.675401835636607970
89.0 8885.844 888.5844 987.316 29.809132828715430446 31.421584937746218024
3 changes: 2 additions & 1 deletion pkg/sql/sem/builtins/aggregate_builtins.go
Original file line number Diff line number Diff line change
Expand Up @@ -2575,7 +2575,8 @@ func (a *decimalSqrDiffAggregate) Result() (tree.Datum, error) {
if a.count.Cmp(decimalOne) < 0 {
return tree.DNull, nil
}
dd := &tree.DDecimal{Decimal: a.sqrDiff}
dd := &tree.DDecimal{}
dd.Set(&a.sqrDiff)
// Remove trailing zeros. Depending on the order in which the input
// is processed, some number of trailing zeros could be added to the
// output. Remove them so that the results are the same regardless of order.
Expand Down
39 changes: 39 additions & 0 deletions pkg/sql/sem/builtins/aggregate_builtins_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,21 @@ func TestVarianceDecimalResultDeepCopy(t *testing.T) {
testAggregateResultDeepCopy(t, newDecimalVarianceAggregate, makeDecimalTestDatum(10))
}

// TestSqrDiffIntResultDeepCopy runs the shared testAggregateResultDeepCopy
// check against the aggregate built by newIntSqrDiffAggregate, using the
// input produced by makeIntTestDatum(10).
func TestSqrDiffIntResultDeepCopy(t *testing.T) {
	defer leaktest.AfterTest(t)()

	input := makeIntTestDatum(10)
	testAggregateResultDeepCopy(t, newIntSqrDiffAggregate, input)
}

// TestSqrDiffFloatResultDeepCopy runs the shared testAggregateResultDeepCopy
// check against the aggregate built by newFloatSqrDiffAggregate, using the
// input produced by makeFloatTestDatum(10).
func TestSqrDiffFloatResultDeepCopy(t *testing.T) {
	defer leaktest.AfterTest(t)()

	input := makeFloatTestDatum(10)
	testAggregateResultDeepCopy(t, newFloatSqrDiffAggregate, input)
}

// TestSqrDiffDecimalResultDeepCopy runs the shared testAggregateResultDeepCopy
// check against the aggregate built by newDecimalSqrDiffAggregate, using the
// input produced by makeDecimalTestDatum(10).
func TestSqrDiffDecimalResultDeepCopy(t *testing.T) {
	defer leaktest.AfterTest(t)()

	input := makeDecimalTestDatum(10)
	testAggregateResultDeepCopy(t, newDecimalSqrDiffAggregate, input)
}

func TestVarPopIntResultDeepCopy(t *testing.T) {
defer leaktest.AfterTest(t)()
testAggregateResultDeepCopy(t, newIntVarPopAggregate, makeIntTestDatum(10))
Expand Down Expand Up @@ -666,6 +681,30 @@ func BenchmarkVarianceAggregateDecimal(b *testing.B) {
}
}

// BenchmarkSqrDiffAggregateInt benchmarks the aggregate built by
// newIntSqrDiffAggregate through runBenchmarkAggregate. One sub-benchmark
// named "count=<n>" is run per configured input size (currently only 1000).
func BenchmarkSqrDiffAggregateInt(b *testing.B) {
	sizes := []int{1000}
	for i := range sizes {
		n := sizes[i]
		name := fmt.Sprintf("count=%d", n)
		b.Run(name, func(b *testing.B) {
			runBenchmarkAggregate(b, newIntSqrDiffAggregate, makeIntTestDatum(n))
		})
	}
}

// BenchmarkSqrDiffAggregateFloat benchmarks the aggregate built by
// newFloatSqrDiffAggregate through runBenchmarkAggregate. One sub-benchmark
// named "count=<n>" is run per configured input size (currently only 1000).
func BenchmarkSqrDiffAggregateFloat(b *testing.B) {
	sizes := []int{1000}
	for i := range sizes {
		n := sizes[i]
		name := fmt.Sprintf("count=%d", n)
		b.Run(name, func(b *testing.B) {
			runBenchmarkAggregate(b, newFloatSqrDiffAggregate, makeFloatTestDatum(n))
		})
	}
}

// BenchmarkSqrDiffAggregateDecimal benchmarks the aggregate built by
// newDecimalSqrDiffAggregate through runBenchmarkAggregate. One sub-benchmark
// named "count=<n>" is run per configured input size (currently only 1000).
func BenchmarkSqrDiffAggregateDecimal(b *testing.B) {
	sizes := []int{1000}
	for i := range sizes {
		n := sizes[i]
		name := fmt.Sprintf("count=%d", n)
		b.Run(name, func(b *testing.B) {
			runBenchmarkAggregate(b, newDecimalSqrDiffAggregate, makeDecimalTestDatum(n))
		})
	}
}

func BenchmarkVarPopAggregateInt(b *testing.B) {
for _, count := range []int{1000} {
b.Run(fmt.Sprintf("count=%d", count), func(b *testing.B) {
Expand Down

0 comments on commit afc91f7

Please sign in to comment.