From afc91f7ea9a0911d950fb68163957a24fb4bc02e Mon Sep 17 00:00:00 2001 From: Max Neverov Date: Sun, 25 Oct 2020 12:54:59 +0100 Subject: [PATCH] builtins: fix aggregate functions for decimals fixes #55944 Release note: none --- .../logictest/testdata/logic_test/aggregate | 102 +++++++++++++++--- pkg/sql/sem/builtins/aggregate_builtins.go | 3 +- .../sem/builtins/aggregate_builtins_test.go | 39 +++++++ 3 files changed, 131 insertions(+), 13 deletions(-) diff --git a/pkg/sql/logictest/testdata/logic_test/aggregate b/pkg/sql/logictest/testdata/logic_test/aggregate index e6481f66d58f..709b4de5188d 100644 --- a/pkg/sql/logictest/testdata/logic_test/aggregate +++ b/pkg/sql/logictest/testdata/logic_test/aggregate @@ -10,10 +10,12 @@ CREATE TABLE kv ( ) # Aggregate functions return NULL if there are no rows. -query IIIIRRRRRRRRBBTIIR -SELECT min(1), max(1), count(1), sum_int(1), avg(1), sum(1), stddev(1), stddev_samp(1), stddev_pop(1), var_samp(1), variance(1), var_pop(1), bool_and(true), bool_and(false), xor_agg(b'\x01'), bit_and(1), bit_or(1), corr(1, 1) FROM kv +query IIIIRRRRRRRRBBTIIRR +SELECT min(1), max(1), count(1), sum_int(1), avg(1), sum(1), stddev(1), stddev_samp(1), stddev_pop(1), var_samp(1), +variance(1), var_pop(1), bool_and(true), bool_and(false), xor_agg(b'\x01'), bit_and(1), bit_or(1), corr(1, 1), sqrdiff(1) +FROM kv ---- -NULL NULL 0 NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL +NULL NULL 0 NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL # Regression test for #29695 query T @@ -42,10 +44,12 @@ SELECT min(i), avg(i), max(i), sum(i) FROM kv ---- NULL NULL NULL NULL -query IIIIRRRRRRBBTR -SELECT min(v), max(v), count(v), sum_int(1), avg(v), sum(v), stddev(v), stddev_pop(v), variance(v), var_pop(v), bool_and(v = 1), bool_and(v = 1), xor_agg(s::bytes), corr(v,k) FROM kv +query IIIIRRRRRRRBBTRR +SELECT min(v), max(v), count(v), sum_int(1), avg(v), sum(v), stddev(v), stddev_pop(v), variance(v), var_pop(v), var_samp(v), +bool_and(v = 1), bool_and(v = 1), xor_agg(s::bytes), corr(v,k), sqrdiff(v) +FROM kv ---- -NULL NULL 0 NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL +NULL NULL 0 NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL query T SELECT array_agg(v) FROM kv @@ -63,10 +67,11 @@ SELECT jsonb_agg(v) FROM kv NULL # Aggregate functions triggers aggregation and computation when there is no source. -query IIIIRRRRRRBBTR -SELECT min(1), count(1), max(1), sum_int(1), avg(1)::float, sum(1), stddev_samp(1), stddev_pop(1), variance(1), var_pop(1), bool_and(true), bool_or(true), to_hex(xor_agg(b'\x01')), corr(1, 2) +query IIIIRRRRRRRBBTRR +SELECT min(1), count(1), max(1), sum_int(1), avg(1)::float, sum(1), stddev_samp(1), stddev_pop(1), variance(1), +var_pop(1), var_samp(1), bool_and(true), bool_or(true), to_hex(xor_agg(b'\x01')), corr(1, 2), sqrdiff(1) ---- -1 1 1 1 1 1 NULL 0 NULL 0 true true 01 NULL +1 1 1 1 1 1 NULL 0 NULL 0 NULL true true 01 NULL 0 # Aggregate functions triggers aggregation and computation when there is no source. query T @@ -162,10 +167,12 @@ INSERT INTO kv VALUES # Aggregate functions triggers aggregation and computation for every row even when applied to a constant. # NB: The XOR result is 00 because \x01 is XOR'd an even number of times. -query IIIIRRRRRRBBT -SELECT min(1), count(1), max(1), sum_int(1), avg(1)::float, sum(1), stddev(1), stddev_pop(1), variance(1)::float, var_pop(1)::float, bool_and(true), bool_or(true), to_hex(xor_agg(b'\x01')) FROM kv +query IIIIRRRRRRRBBTR +SELECT min(1), count(1), max(1), sum_int(1), avg(1)::float, sum(1), stddev(1), stddev_pop(1), variance(1)::float, +var_pop(1)::float, var_samp(1)::float, bool_and(true), bool_or(true), to_hex(xor_agg(b'\x01')), sqrdiff(1) +FROM kv ---- -1 6 1 6 1 6 0 0 0 0 true true 00 +1 6 1 6 1 6 0 0 0 0 0 true true 00 0 # Aggregate functions triggers aggregation and computation for every row even when applied to a constant. query T @@ -832,6 +839,40 @@ true statement ok DROP TABLE xyz +# SQRDIFF + +statement ok +DROP TABLE IF EXISTS ifd; +CREATE TABLE ifd +( + i int, + f float, + d decimal +); +INSERT INTO ifd (i, f, d) +VALUES (1, 1.1, 1.1), + (2, 2.2, 2.2), + (5, 3.0, 3.0), + (10, 7.8, 7.8), + (11, 9.0, 9.0), + (18, 11.2, 11.2); + +query FRF +select sqrdiff(i), round(sqrdiff(f), 13), sqrdiff(d) +from ifd +---- +206.8333333333333333333334 86.2483333333333 86.24833333333333333333333 + +query FRF +SELECT sqrdiff(i), round(sqrdiff(f), 2), sqrdiff(d) +FROM ifd +where i < 10 +---- +8.666666666666666666666666 1.82 1.82 + +statement ok +DROP TABLE IF EXISTS sqrdiff + # Numerical stability test for VARIANCE/STDDEV. # See https://www.johndcook.com/blog/2008/09/28/theoretical-explanation-for-numerical-results. # Avoid using random() since we do not have the deterministic option to specify a pseudo-random seed yet. @@ -2817,3 +2858,40 @@ query FI SELECT corr(DISTINCT y, x), count(DISTINCT y) FROM t55776 ---- 0.522232967867094 3 + +# Regression test for window aggregate functions for decimals reuse the results +# from the previous iterations (#55944). +statement ok +CREATE TABLE t55944 (x decimal); +INSERT INTO t55944 (x) +VALUES (1.0), + (20.0), + (25.0), + (41.0), + (55.5), + (60.9), + (72.0), + (88.0), + (88.0), + (89.0); + +query RFFFFF +SELECT x, + sqrdiff(x) OVER (ORDER BY x) as sqrdiff, + var_pop(x) OVER (ORDER BY x) as var_pop, + var_samp(x) OVER (ORDER BY x) as var_samp, + stddev_pop(x) OVER (ORDER BY x) as stddev_pop, + stddev_samp(x) OVER (ORDER BY x) as stddev_samp +FROM t55944 +ORDER BY x +---- +1.0 0 0 NULL 0 NULL +20.0 180.5 90.25 180.5 9.5 13.435028842544402964 +25.0 320.6666666666666666666667 106.88888888888888889 160.33333333333333333 10.338708279513881752 12.662279942148385993 +41.0 814.7500000000000000000001 203.6875 271.58333333333333333 14.271912976192084391 16.479785597310825856 +55.5 1726 345.2 431.5 18.579558659989746915 20.772578077840988103 +60.9 2600.8 433.46666666666666667 520.16 20.819862311424316109 22.807016464237491316 +72.0 3845.037142857142857142857 549.29102040816326531 640.83952380952380952 23.436958429117103874 25.314808389745394427 +88.0 7527.842222222222222222222 836.42691358024691358 940.98027777777777778 28.921046204801217626 30.675401835636607970 +88.0 7527.842222222222222222222 836.42691358024691358 940.98027777777777778 28.921046204801217626 30.675401835636607970 +89.0 8885.844 888.5844 987.316 29.809132828715430446 31.421584937746218024 diff --git a/pkg/sql/sem/builtins/aggregate_builtins.go b/pkg/sql/sem/builtins/aggregate_builtins.go index 99c8d40f3782..0ed3a904af43 100644 --- a/pkg/sql/sem/builtins/aggregate_builtins.go +++ b/pkg/sql/sem/builtins/aggregate_builtins.go @@ -2575,7 +2575,8 @@ func (a *decimalSqrDiffAggregate) Result() (tree.Datum, error) { if a.count.Cmp(decimalOne) < 0 { return tree.DNull, nil } - dd := &tree.DDecimal{Decimal: a.sqrDiff} + dd := &tree.DDecimal{} + dd.Set(&a.sqrDiff) // Remove trailing zeros. Depending on the order in which the input // is processed, some number of trailing zeros could be added to the // output. Remove them so that the results are the same regardless of order. diff --git a/pkg/sql/sem/builtins/aggregate_builtins_test.go b/pkg/sql/sem/builtins/aggregate_builtins_test.go index f6160339da03..281afe4bf358 100644 --- a/pkg/sql/sem/builtins/aggregate_builtins_test.go +++ b/pkg/sql/sem/builtins/aggregate_builtins_test.go @@ -256,6 +256,21 @@ func TestVarianceDecimalResultDeepCopy(t *testing.T) { testAggregateResultDeepCopy(t, newDecimalVarianceAggregate, makeDecimalTestDatum(10)) } +func TestSqrDiffIntResultDeepCopy(t *testing.T) { + defer leaktest.AfterTest(t)() + testAggregateResultDeepCopy(t, newIntSqrDiffAggregate, makeIntTestDatum(10)) +} + +func TestSqrDiffFloatResultDeepCopy(t *testing.T) { + defer leaktest.AfterTest(t)() + testAggregateResultDeepCopy(t, newFloatSqrDiffAggregate, makeFloatTestDatum(10)) +} + +func TestSqrDiffDecimalResultDeepCopy(t *testing.T) { + defer leaktest.AfterTest(t)() + testAggregateResultDeepCopy(t, newDecimalSqrDiffAggregate, makeDecimalTestDatum(10)) +} + func TestVarPopIntResultDeepCopy(t *testing.T) { defer leaktest.AfterTest(t)() testAggregateResultDeepCopy(t, newIntVarPopAggregate, makeIntTestDatum(10)) @@ -666,6 +681,30 @@ func BenchmarkVarianceAggregateDecimal(b *testing.B) { } } +func BenchmarkSqrDiffAggregateInt(b *testing.B) { + for _, count := range []int{1000} { + b.Run(fmt.Sprintf("count=%d", count), func(b *testing.B) { + runBenchmarkAggregate(b, newIntSqrDiffAggregate, makeIntTestDatum(count)) + }) + } +} + +func BenchmarkSqrDiffAggregateFloat(b *testing.B) { + for _, count := range []int{1000} { + b.Run(fmt.Sprintf("count=%d", count), func(b *testing.B) { + runBenchmarkAggregate(b, newFloatSqrDiffAggregate, makeFloatTestDatum(count)) + }) + } +} + +func BenchmarkSqrDiffAggregateDecimal(b *testing.B) { + for _, count := range []int{1000} { + b.Run(fmt.Sprintf("count=%d", count), func(b *testing.B) { + runBenchmarkAggregate(b, newDecimalSqrDiffAggregate, makeDecimalTestDatum(count)) + }) + } +} + func BenchmarkVarPopAggregateInt(b *testing.B) { for _, count := range []int{1000} { b.Run(fmt.Sprintf("count=%d", count), func(b *testing.B) {