diff --git a/.travis.yml b/.travis.yml
index 7923eda025b14..5c0285508c0c4 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -3,10 +3,18 @@ language: go
 go_import_path: github.com/pingcap/tidb
 go:
-  - "1.11.2"
+  - "1.11.3"
 env:
   - TRAVIS_COVERAGE=0
+  - TRAVIS_COVERAGE=1
+
+# Run coverage tests.
+matrix:
+  fast_finish: true
+  allow_failures:
+    - go: "1.11.3"
+      env: TRAVIS_COVERAGE=1
 
 before_install:
   # create /logs/unit-test for unit test.
@@ -15,4 +23,4 @@ before_install:
   # See https://github.com/golang/go/issues/12933
   - bash gitcookie.sh
 script:
-  - make dev
+  - make dev upload-coverage
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index 3e23f56c60f38..eed7fbd2b1279 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -15,7 +15,7 @@ A Contributor refers to the person who contributes to the following projects:
 
 ## How to become a TiDB Contributor?
 
-If a PR (Pull Request) submitted to the TiDB / TiKV / TiSpark / PD / Docs/Docs-cn projects by you is approved and merged, then you become a TiDB Contributor.
+If a PR (Pull Request) submitted to the TiDB/TiKV/TiSpark/PD/Docs/Docs-cn projects by you is approved and merged, then you become a TiDB Contributor.
 
 You are also encouraged to participate in the projects in the following ways:
 - Actively answer technical questions asked by community users.
diff --git a/Makefile b/Makefile
index ff64fceb30611..be1c07d0cccdd 100644
--- a/Makefile
+++ b/Makefile
@@ -14,8 +14,7 @@ export PATH := $(path_to_add):$(PATH)
 GO := GO111MODULE=on go
 GOBUILD := CGO_ENABLED=0 $(GO) build $(BUILD_FLAG)
 GOTEST := CGO_ENABLED=1 $(GO) test -p 3
-OVERALLS := CGO_ENABLED=1 overalls
-GOVERALLS := goveralls
+OVERALLS := CGO_ENABLED=1 GO111MODULE=on overalls
 
 ARCH := "`uname -s`"
 LINUX := "Linux"
@@ -25,8 +24,8 @@ PACKAGES := $$($(PACKAGE_LIST))
 PACKAGE_DIRECTORIES := $(PACKAGE_LIST) | sed 's|github.com/pingcap/$(PROJECT)/||'
 FILES := $$(find $$($(PACKAGE_DIRECTORIES)) -name "*.go")
 
-GOFAIL_ENABLE := $$(find $$PWD/ -type d | grep -vE "(\.git|tools)" | xargs gofail enable)
-GOFAIL_DISABLE := $$(find $$PWD/ -type d | grep -vE "(\.git|tools)" | xargs gofail disable)
+GOFAIL_ENABLE := $$(find $$PWD/ -type d | grep -vE "(\.git|tools)" | xargs tools/bin/gofail enable)
+GOFAIL_DISABLE := $$(find $$PWD/ -type d | grep -vE "(\.git|tools)" | xargs tools/bin/gofail disable)
 
 LDFLAGS += -X "github.com/pingcap/parser/mysql.TiDBReleaseVersion=$(shell git describe --tags --dirty)"
 LDFLAGS += -X "github.com/pingcap/tidb/util/printer.TiDBBuildTS=$(shell date -u '+%Y-%m-%d %I:%M:%S')"
@@ -117,18 +116,19 @@ test: checklist checkdep gotest explaintest
 
 explaintest: server
 	@cd cmd/explaintest && ./run-tests.sh -s ../../bin/tidb-server
 
-gotest:
-	@rm -rf $GOPATH/bin/gofail
-	$(GO) get github.com/pingcap/gofail
-	@which gofail
-	@$(GOFAIL_ENABLE)
+upload-coverage: SHELL:=/bin/bash
+upload-coverage:
+ifeq ("$(TRAVIS_COVERAGE)", "1")
+	mv overalls.coverprofile coverage.txt
+	bash <(curl -s https://codecov.io/bash)
+endif
+
+gotest: gofail-enable
 ifeq ("$(TRAVIS_COVERAGE)", "1")
 	@echo "Running in TRAVIS_COVERAGE mode."
 	@export log_level=error; \
-	go get github.com/go-playground/overalls
-	go get github.com/mattn/goveralls
-	$(OVERALLS) -project=github.com/pingcap/tidb -covermode=count -ignore='.git,vendor,cmd,docs,LICENSES' || { $(GOFAIL_DISABLE); exit 1; }
-	$(GOVERALLS) -service=travis-ci -coverprofile=overalls.coverprofile || { $(GOFAIL_DISABLE); exit 1; }
+	$(GO) get github.com/go-playground/overalls
+	$(OVERALLS) -project=github.com/pingcap/tidb -covermode=count -ignore='.git,vendor,cmd,docs,LICENSES' -concurrency=1 || { $(GOFAIL_DISABLE); exit 1; }
 else
 	@echo "Running in native mode."
 	@export log_level=error; \
@@ -136,23 +136,17 @@ else
 endif
 	@$(GOFAIL_DISABLE)
 
-race:
-	$(GO) get github.com/pingcap/gofail
-	@$(GOFAIL_ENABLE)
+race: gofail-enable
 	@export log_level=debug; \
 	$(GOTEST) -timeout 20m -race $(PACKAGES) || { $(GOFAIL_DISABLE); exit 1; }
 	@$(GOFAIL_DISABLE)
 
-leak:
-	$(GO) get github.com/pingcap/gofail
-	@$(GOFAIL_ENABLE)
+leak: gofail-enable
 	@export log_level=debug; \
 	$(GOTEST) -tags leak $(PACKAGES) || { $(GOFAIL_DISABLE); exit 1; }
 	@$(GOFAIL_DISABLE)
 
-tikv_integration_test:
-	$(GO) get github.com/pingcap/gofail
-	@$(GOFAIL_ENABLE)
+tikv_integration_test: gofail-enable
 	$(GOTEST) ./store/tikv/. -with-tikv=true || { $(GOFAIL_DISABLE); exit 1; }
 	@$(GOFAIL_DISABLE)
 
@@ -196,37 +190,40 @@ importer:
 
 checklist:
 	cat checklist.md
 
-gofail-enable:
+gofail-enable: tools/bin/gofail
 # Converting gofail failpoints...
 	@$(GOFAIL_ENABLE)
 
-gofail-disable:
+gofail-disable: tools/bin/gofail
 # Restoring gofail failpoints...
 	@$(GOFAIL_DISABLE)
 
 checkdep:
 	$(GO) list -f '{{ join .Imports "\n" }}' github.com/pingcap/tidb/store/tikv | grep ^github.com/pingcap/parser$$ || exit 0; exit 1
 
-tools/bin/megacheck:
+tools/bin/megacheck: tools/check/go.mod
 	cd tools/check; \
-	$go build -o ../bin/megacheck honnef.co/go/tools/cmd/megacheck
+	$(GO) build -o ../bin/megacheck honnef.co/go/tools/cmd/megacheck
 
-tools/bin/revive:
+tools/bin/revive: tools/check/go.mod
 	cd tools/check; \
 	$(GO) build -o ../bin/revive github.com/mgechev/revive
 
-tools/bin/goword:
+tools/bin/goword: tools/check/go.mod
 	cd tools/check; \
 	$(GO) build -o ../bin/goword github.com/chzchzchz/goword
 
-tools/bin/gometalinter:
+tools/bin/gometalinter: tools/check/go.mod
 	cd tools/check; \
 	$(GO) build -o ../bin/gometalinter gopkg.in/alecthomas/gometalinter.v2
 
-tools/bin/gosec:
+tools/bin/gosec: tools/check/go.mod
 	cd tools/check; \
 	$(GO) build -o ../bin/gosec github.com/securego/gosec/cmd/gosec
 
-tools/bin/errcheck:
+tools/bin/errcheck: tools/check/go.mod
 	cd tools/check; \
 	$(GO) build -o ../bin/errcheck github.com/kisielk/errcheck
+
+tools/bin/gofail: go.mod
+	$(GO) build -o $@ github.com/pingcap/gofail
diff --git a/README.md b/README.md
index 88f21ea71e972..0029c901c0f2a 100644
--- a/README.md
+++ b/README.md
@@ -2,9 +2,9 @@
 
 [![Build Status](https://travis-ci.org/pingcap/tidb.svg?branch=master)](https://travis-ci.org/pingcap/tidb)
 [![Go Report Card](https://goreportcard.com/badge/github.com/pingcap/tidb)](https://goreportcard.com/report/github.com/pingcap/tidb)
-![GitHub release](https://img.shields.io/github/release/pingcap/tidb.svg)
+![GitHub release](https://img.shields.io/github/tag/pingcap/tidb.svg?label=release)
 [![CircleCI Status](https://circleci.com/gh/pingcap/tidb.svg?style=shield)](https://circleci.com/gh/pingcap/tidb)
-[![Coverage Status](https://coveralls.io/repos/github/pingcap/tidb/badge.svg?branch=master)](https://coveralls.io/github/pingcap/tidb?branch=master)
+[![Coverage 
Status](https://codecov.io/gh/pingcap/tidb/branch/master/graph/badge.svg)](https://codecov.io/gh/pingcap/tidb) ## What is TiDB? diff --git a/cmd/explaintest/r/explain_complex.result b/cmd/explaintest/r/explain_complex.result index 2b43ae3a2357a..a9d2b6b249e15 100644 --- a/cmd/explaintest/r/explain_complex.result +++ b/cmd/explaintest/r/explain_complex.result @@ -179,23 +179,27 @@ CREATE TABLE `tbl_008` (`a` int, `b` int); CREATE TABLE `tbl_009` (`a` int, `b` int); explain select sum(a) from (select * from tbl_001 union all select * from tbl_002 union all select * from tbl_003 union all select * from tbl_004 union all select * from tbl_005 union all select * from tbl_006 union all select * from tbl_007 union all select * from tbl_008 union all select * from tbl_009) x group by b; id count task operator info -HashAgg_34 72000.00 root group by:x.b, funcs:sum(x.a) -└─Union_35 90000.00 root - ├─TableReader_38 10000.00 root data:TableScan_37 - │ └─TableScan_37 10000.00 cop table:tbl_001, range:[-inf,+inf], keep order:false, stats:pseudo - ├─TableReader_41 10000.00 root data:TableScan_40 - │ └─TableScan_40 10000.00 cop table:tbl_002, range:[-inf,+inf], keep order:false, stats:pseudo - ├─TableReader_44 10000.00 root data:TableScan_43 - │ └─TableScan_43 10000.00 cop table:tbl_003, range:[-inf,+inf], keep order:false, stats:pseudo - ├─TableReader_47 10000.00 root data:TableScan_46 - │ └─TableScan_46 10000.00 cop table:tbl_004, range:[-inf,+inf], keep order:false, stats:pseudo - ├─TableReader_50 10000.00 root data:TableScan_49 - │ └─TableScan_49 10000.00 cop table:tbl_005, range:[-inf,+inf], keep order:false, stats:pseudo - ├─TableReader_53 10000.00 root data:TableScan_52 - │ └─TableScan_52 10000.00 cop table:tbl_006, range:[-inf,+inf], keep order:false, stats:pseudo - ├─TableReader_56 10000.00 root data:TableScan_55 - │ └─TableScan_55 10000.00 cop table:tbl_007, range:[-inf,+inf], keep order:false, stats:pseudo - ├─TableReader_59 10000.00 root data:TableScan_58 - │ └─TableScan_58 10000.00 cop table:tbl_008, range:[-inf,+inf], keep order:false, stats:pseudo - └─TableReader_62 10000.00 root data:TableScan_61 - └─TableScan_61 10000.00 cop table:tbl_009, range:[-inf,+inf], keep order:false, stats:pseudo +HashAgg_34 72000.00 root group by:col_1, funcs:sum(col_0) +└─Projection_63 90000.00 root cast(x.a), x.b + └─Union_35 90000.00 root + ├─TableReader_38 10000.00 root data:TableScan_37 + │ └─TableScan_37 10000.00 cop table:tbl_001, range:[-inf,+inf], keep order:false, stats:pseudo + ├─TableReader_41 10000.00 root data:TableScan_40 + │ └─TableScan_40 10000.00 cop table:tbl_002, range:[-inf,+inf], keep order:false, stats:pseudo + ├─TableReader_44 10000.00 root data:TableScan_43 + │ └─TableScan_43 10000.00 cop table:tbl_003, range:[-inf,+inf], keep order:false, stats:pseudo + ├─TableReader_47 10000.00 root data:TableScan_46 + │ └─TableScan_46 10000.00 cop table:tbl_004, range:[-inf,+inf], keep order:false, stats:pseudo + ├─TableReader_50 10000.00 root data:TableScan_49 + │ └─TableScan_49 10000.00 cop table:tbl_005, range:[-inf,+inf], keep order:false, stats:pseudo + ├─TableReader_53 10000.00 root data:TableScan_52 + │ └─TableScan_52 10000.00 cop table:tbl_006, range:[-inf,+inf], keep order:false, stats:pseudo + ├─TableReader_56 10000.00 root data:TableScan_55 + │ └─TableScan_55 10000.00 cop table:tbl_007, range:[-inf,+inf], keep order:false, stats:pseudo + ├─TableReader_59 10000.00 root data:TableScan_58 + │ └─TableScan_58 10000.00 cop table:tbl_008, range:[-inf,+inf], keep order:false, stats:pseudo + 
└─TableReader_62 10000.00 root data:TableScan_61 + └─TableScan_61 10000.00 cop table:tbl_009, range:[-inf,+inf], keep order:false, stats:pseudo + + + diff --git a/cmd/explaintest/r/explain_complex_stats.result b/cmd/explaintest/r/explain_complex_stats.result index 1569a2d3a850c..7b6a89f0e13f8 100644 --- a/cmd/explaintest/r/explain_complex_stats.result +++ b/cmd/explaintest/r/explain_complex_stats.result @@ -205,23 +205,27 @@ CREATE TABLE tbl_009 (a int, b int); load stats 's/explain_complex_stats_tbl_009.json'; explain select sum(a) from (select * from tbl_001 union all select * from tbl_002 union all select * from tbl_003 union all select * from tbl_004 union all select * from tbl_005 union all select * from tbl_006 union all select * from tbl_007 union all select * from tbl_008 union all select * from tbl_009) x group by b; id count task operator info -HashAgg_34 18000.00 root group by:x.b, funcs:sum(x.a) -└─Union_35 18000.00 root - ├─TableReader_38 2000.00 root data:TableScan_37 - │ └─TableScan_37 2000.00 cop table:tbl_001, range:[-inf,+inf], keep order:false - ├─TableReader_41 2000.00 root data:TableScan_40 - │ └─TableScan_40 2000.00 cop table:tbl_002, range:[-inf,+inf], keep order:false - ├─TableReader_44 2000.00 root data:TableScan_43 - │ └─TableScan_43 2000.00 cop table:tbl_003, range:[-inf,+inf], keep order:false - ├─TableReader_47 2000.00 root data:TableScan_46 - │ └─TableScan_46 2000.00 cop table:tbl_004, range:[-inf,+inf], keep order:false - ├─TableReader_50 2000.00 root data:TableScan_49 - │ └─TableScan_49 2000.00 cop table:tbl_005, range:[-inf,+inf], keep order:false - ├─TableReader_53 2000.00 root data:TableScan_52 - │ └─TableScan_52 2000.00 cop table:tbl_006, range:[-inf,+inf], keep order:false - ├─TableReader_56 2000.00 root data:TableScan_55 - │ └─TableScan_55 2000.00 cop table:tbl_007, range:[-inf,+inf], keep order:false - ├─TableReader_59 2000.00 root data:TableScan_58 - │ └─TableScan_58 2000.00 cop table:tbl_008, range:[-inf,+inf], keep order:false - └─TableReader_62 2000.00 root data:TableScan_61 - └─TableScan_61 2000.00 cop table:tbl_009, range:[-inf,+inf], keep order:false +HashAgg_34 18000.00 root group by:col_1, funcs:sum(col_0) +└─Projection_63 18000.00 root cast(x.a), x.b + └─Union_35 18000.00 root + ├─TableReader_38 2000.00 root data:TableScan_37 + │ └─TableScan_37 2000.00 cop table:tbl_001, range:[-inf,+inf], keep order:false + ├─TableReader_41 2000.00 root data:TableScan_40 + │ └─TableScan_40 2000.00 cop table:tbl_002, range:[-inf,+inf], keep order:false + ├─TableReader_44 2000.00 root data:TableScan_43 + │ └─TableScan_43 2000.00 cop table:tbl_003, range:[-inf,+inf], keep order:false + ├─TableReader_47 2000.00 root data:TableScan_46 + │ └─TableScan_46 2000.00 cop table:tbl_004, range:[-inf,+inf], keep order:false + ├─TableReader_50 2000.00 root data:TableScan_49 + │ └─TableScan_49 2000.00 cop table:tbl_005, range:[-inf,+inf], keep order:false + ├─TableReader_53 2000.00 root data:TableScan_52 + │ └─TableScan_52 2000.00 cop table:tbl_006, range:[-inf,+inf], keep order:false + ├─TableReader_56 2000.00 root data:TableScan_55 + │ └─TableScan_55 2000.00 cop table:tbl_007, range:[-inf,+inf], keep order:false + ├─TableReader_59 2000.00 root data:TableScan_58 + │ └─TableScan_58 2000.00 cop table:tbl_008, range:[-inf,+inf], keep order:false + └─TableReader_62 2000.00 root data:TableScan_61 + └─TableScan_61 2000.00 cop table:tbl_009, range:[-inf,+inf], keep order:false + + + diff --git a/cmd/explaintest/r/explain_easy.result b/cmd/explaintest/r/explain_easy.result index 
dca41dcc8e468..340f36e457af8 100644 --- a/cmd/explaintest/r/explain_easy.result +++ b/cmd/explaintest/r/explain_easy.result @@ -86,12 +86,13 @@ TableReader_7 0.33 root data:Selection_6 └─TableScan_5 1.00 cop table:t1, range:[1,1], keep order:false, stats:pseudo explain select sum(t1.c1 in (select c1 from t2)) from t1; id count task operator info -StreamAgg_12 1.00 root funcs:sum(5_aux_0) -└─MergeJoin_28 10000.00 root left outer semi join, left key:test.t1.c1, right key:test.t2.c1 - ├─TableReader_19 10000.00 root data:TableScan_18 - │ └─TableScan_18 10000.00 cop table:t1, range:[-inf,+inf], keep order:true, stats:pseudo - └─IndexReader_23 10000.00 root index:IndexScan_22 - └─IndexScan_22 10000.00 cop table:t2, index:c1, range:[NULL,+inf], keep order:true, stats:pseudo +StreamAgg_12 1.00 root funcs:sum(col_0) +└─Projection_35 10000.00 root cast(5_aux_0) + └─MergeJoin_28 10000.00 root left outer semi join, left key:test.t1.c1, right key:test.t2.c1 + ├─TableReader_19 10000.00 root data:TableScan_18 + │ └─TableScan_18 10000.00 cop table:t1, range:[-inf,+inf], keep order:true, stats:pseudo + └─IndexReader_23 10000.00 root index:IndexScan_22 + └─IndexScan_22 10000.00 cop table:t2, index:c1, range:[NULL,+inf], keep order:true, stats:pseudo explain select c1 from t1 where c1 in (select c2 from t2); id count task operator info Projection_9 10000.00 root test.t1.c1 @@ -190,12 +191,13 @@ HashAgg_18 24000.00 root group by:c1, funcs:firstrow(join_agg_0) set @@session.tidb_opt_insubq_to_join_and_agg=0; explain select sum(t1.c1 in (select c1 from t2)) from t1; id count task operator info -StreamAgg_12 1.00 root funcs:sum(5_aux_0) -└─MergeJoin_28 10000.00 root left outer semi join, left key:test.t1.c1, right key:test.t2.c1 - ├─TableReader_19 10000.00 root data:TableScan_18 - │ └─TableScan_18 10000.00 cop table:t1, range:[-inf,+inf], keep order:true, stats:pseudo - └─IndexReader_23 10000.00 root index:IndexScan_22 - └─IndexScan_22 10000.00 cop table:t2, index:c1, range:[NULL,+inf], keep order:true, stats:pseudo +StreamAgg_12 1.00 root funcs:sum(col_0) +└─Projection_35 10000.00 root cast(5_aux_0) + └─MergeJoin_28 10000.00 root left outer semi join, left key:test.t1.c1, right key:test.t2.c1 + ├─TableReader_19 10000.00 root data:TableScan_18 + │ └─TableScan_18 10000.00 cop table:t1, range:[-inf,+inf], keep order:true, stats:pseudo + └─IndexReader_23 10000.00 root index:IndexScan_22 + └─IndexScan_22 10000.00 cop table:t2, index:c1, range:[NULL,+inf], keep order:true, stats:pseudo explain select 1 in (select c2 from t2) from t1; id count task operator info Projection_6 10000.00 root 5_aux_0 @@ -207,13 +209,14 @@ Projection_6 10000.00 root 5_aux_0 └─TableScan_10 10000.00 cop table:t2, range:[-inf,+inf], keep order:false, stats:pseudo explain select sum(6 in (select c2 from t2)) from t1; id count task operator info -StreamAgg_12 1.00 root funcs:sum(5_aux_0) -└─HashLeftJoin_19 10000.00 root left outer semi join, inner:TableReader_18 - ├─TableReader_21 10000.00 root data:TableScan_20 - │ └─TableScan_20 10000.00 cop table:t1, range:[-inf,+inf], keep order:false, stats:pseudo - └─TableReader_18 10.00 root data:Selection_17 - └─Selection_17 10.00 cop eq(6, test.t2.c2) - └─TableScan_16 10000.00 cop table:t2, range:[-inf,+inf], keep order:false, stats:pseudo +StreamAgg_12 1.00 root funcs:sum(col_0) +└─Projection_22 10000.00 root cast(5_aux_0) + └─HashLeftJoin_19 10000.00 root left outer semi join, inner:TableReader_18 + ├─TableReader_21 10000.00 root data:TableScan_20 + │ └─TableScan_20 10000.00 cop table:t1, 
range:[-inf,+inf], keep order:false, stats:pseudo + └─TableReader_18 10.00 root data:Selection_17 + └─Selection_17 10.00 cop eq(6, test.t2.c2) + └─TableScan_16 10000.00 cop table:t2, range:[-inf,+inf], keep order:false, stats:pseudo explain format="dot" select sum(t1.c1 in (select c1 from t2)) from t1; dot contents @@ -222,7 +225,8 @@ subgraph cluster12{ node [style=filled, color=lightgrey] color=black label = "root" -"StreamAgg_12" -> "MergeJoin_28" +"StreamAgg_12" -> "Projection_35" +"Projection_35" -> "MergeJoin_28" "MergeJoin_28" -> "TableReader_19" "MergeJoin_28" -> "IndexReader_23" } diff --git a/cmd/explaintest/r/tpch.result b/cmd/explaintest/r/tpch.result index 5a9b2b50fd2f8..333c9d9b92c0e 100644 --- a/cmd/explaintest/r/tpch.result +++ b/cmd/explaintest/r/tpch.result @@ -250,19 +250,20 @@ limit 10; id count task operator info Projection_14 10.00 root tpch.lineitem.l_orderkey, 7_col_0, tpch.orders.o_orderdate, tpch.orders.o_shippriority └─TopN_17 10.00 root 7_col_0:desc, tpch.orders.o_orderdate:asc, offset:0, count:10 - └─HashAgg_23 40227041.09 root group by:tpch.lineitem.l_orderkey, tpch.orders.o_orderdate, tpch.orders.o_shippriority, funcs:sum(mul(tpch.lineitem.l_extendedprice, minus(1, tpch.lineitem.l_discount))), firstrow(tpch.orders.o_orderdate), firstrow(tpch.orders.o_shippriority), firstrow(tpch.lineitem.l_orderkey) - └─IndexJoin_29 91515927.49 root inner join, inner:IndexLookUp_28, outer key:tpch.orders.o_orderkey, inner key:tpch.lineitem.l_orderkey - ├─HashRightJoin_49 22592975.51 root inner join, inner:TableReader_55, equal:[eq(tpch.customer.c_custkey, tpch.orders.o_custkey)] - │ ├─TableReader_55 1498236.00 root data:Selection_54 - │ │ └─Selection_54 1498236.00 cop eq(tpch.customer.c_mktsegment, "AUTOMOBILE") - │ │ └─TableScan_53 7500000.00 cop table:customer, range:[-inf,+inf], keep order:false - │ └─TableReader_52 36870000.00 root data:Selection_51 - │ └─Selection_51 36870000.00 cop lt(tpch.orders.o_orderdate, 1995-03-13 00:00:00.000000) - │ └─TableScan_50 75000000.00 cop table:orders, range:[-inf,+inf], keep order:false - └─IndexLookUp_28 162945114.27 root - ├─IndexScan_25 1.00 cop table:lineitem, index:L_ORDERKEY, L_LINENUMBER, range: decided by [tpch.orders.o_orderkey], keep order:false - └─Selection_27 162945114.27 cop gt(tpch.lineitem.l_shipdate, 1995-03-13 00:00:00.000000) - └─TableScan_26 1.00 cop table:lineitem, keep order:false + └─HashAgg_23 40227041.09 root group by:col_4, col_5, col_6, funcs:sum(col_0), firstrow(col_1), firstrow(col_2), firstrow(col_3) + └─Projection_59 91515927.49 root mul(tpch.lineitem.l_extendedprice, minus(1, tpch.lineitem.l_discount)), tpch.orders.o_orderdate, tpch.orders.o_shippriority, tpch.lineitem.l_orderkey, tpch.lineitem.l_orderkey, tpch.orders.o_orderdate, tpch.orders.o_shippriority + └─IndexJoin_29 91515927.49 root inner join, inner:IndexLookUp_28, outer key:tpch.orders.o_orderkey, inner key:tpch.lineitem.l_orderkey + ├─HashRightJoin_49 22592975.51 root inner join, inner:TableReader_55, equal:[eq(tpch.customer.c_custkey, tpch.orders.o_custkey)] + │ ├─TableReader_55 1498236.00 root data:Selection_54 + │ │ └─Selection_54 1498236.00 cop eq(tpch.customer.c_mktsegment, "AUTOMOBILE") + │ │ └─TableScan_53 7500000.00 cop table:customer, range:[-inf,+inf], keep order:false + │ └─TableReader_52 36870000.00 root data:Selection_51 + │ └─Selection_51 36870000.00 cop lt(tpch.orders.o_orderdate, 1995-03-13 00:00:00.000000) + │ └─TableScan_50 75000000.00 cop table:orders, range:[-inf,+inf], keep order:false + └─IndexLookUp_28 162945114.27 root + 
├─IndexScan_25 1.00 cop table:lineitem, index:L_ORDERKEY, L_LINENUMBER, range: decided by [tpch.orders.o_orderkey], keep order:false + └─Selection_27 162945114.27 cop gt(tpch.lineitem.l_shipdate, 1995-03-13 00:00:00.000000) + └─TableScan_26 1.00 cop table:lineitem, keep order:false /* Q4 Order Priority Checking Query This query determines how well the order priority system is working and gives an assessment of customer satisfaction. @@ -343,26 +344,27 @@ revenue desc; id count task operator info Sort_23 5.00 root revenue:desc └─Projection_25 5.00 root tpch.nation.n_name, 13_col_0 - └─HashAgg_28 5.00 root group by:tpch.nation.n_name, funcs:sum(mul(tpch.lineitem.l_extendedprice, minus(1, tpch.lineitem.l_discount))), firstrow(tpch.nation.n_name) - └─IndexJoin_31 11822812.50 root inner join, inner:TableReader_30, outer key:tpch.orders.o_custkey, inner key:tpch.customer.c_custkey, other cond:eq(tpch.supplier.s_nationkey, tpch.customer.c_nationkey) - ├─IndexJoin_38 11822812.50 root inner join, inner:TableReader_37, outer key:tpch.lineitem.l_orderkey, inner key:tpch.orders.o_orderkey - │ ├─HashRightJoin_42 61163763.01 root inner join, inner:HashRightJoin_44, equal:[eq(tpch.supplier.s_suppkey, tpch.lineitem.l_suppkey)] - │ │ ├─HashRightJoin_44 100000.00 root inner join, inner:HashRightJoin_50, equal:[eq(tpch.nation.n_nationkey, tpch.supplier.s_nationkey)] - │ │ │ ├─HashRightJoin_50 5.00 root inner join, inner:TableReader_55, equal:[eq(tpch.region.r_regionkey, tpch.nation.n_regionkey)] - │ │ │ │ ├─TableReader_55 1.00 root data:Selection_54 - │ │ │ │ │ └─Selection_54 1.00 cop eq(tpch.region.r_name, "MIDDLE EAST") - │ │ │ │ │ └─TableScan_53 5.00 cop table:region, range:[-inf,+inf], keep order:false - │ │ │ │ └─TableReader_52 25.00 root data:TableScan_51 - │ │ │ │ └─TableScan_51 25.00 cop table:nation, range:[-inf,+inf], keep order:false - │ │ │ └─TableReader_57 500000.00 root data:TableScan_56 - │ │ │ └─TableScan_56 500000.00 cop table:supplier, range:[-inf,+inf], keep order:false - │ │ └─TableReader_59 300005811.00 root data:TableScan_58 - │ │ └─TableScan_58 300005811.00 cop table:lineitem, range:[-inf,+inf], keep order:false - │ └─TableReader_37 11822812.50 root data:Selection_36 - │ └─Selection_36 11822812.50 cop ge(tpch.orders.o_orderdate, 1994-01-01 00:00:00.000000), lt(tpch.orders.o_orderdate, 1995-01-01) - │ └─TableScan_35 1.00 cop table:orders, range: decided by [tpch.lineitem.l_orderkey], keep order:false - └─TableReader_30 1.00 root data:TableScan_29 - └─TableScan_29 1.00 cop table:customer, range: decided by [tpch.supplier.s_nationkey tpch.orders.o_custkey], keep order:false + └─HashAgg_28 5.00 root group by:col_2, funcs:sum(col_0), firstrow(col_1) + └─Projection_66 11822812.50 root mul(tpch.lineitem.l_extendedprice, minus(1, tpch.lineitem.l_discount)), tpch.nation.n_name, tpch.nation.n_name + └─IndexJoin_31 11822812.50 root inner join, inner:TableReader_30, outer key:tpch.orders.o_custkey, inner key:tpch.customer.c_custkey, other cond:eq(tpch.supplier.s_nationkey, tpch.customer.c_nationkey) + ├─IndexJoin_38 11822812.50 root inner join, inner:TableReader_37, outer key:tpch.lineitem.l_orderkey, inner key:tpch.orders.o_orderkey + │ ├─HashRightJoin_42 61163763.01 root inner join, inner:HashRightJoin_44, equal:[eq(tpch.supplier.s_suppkey, tpch.lineitem.l_suppkey)] + │ │ ├─HashRightJoin_44 100000.00 root inner join, inner:HashRightJoin_50, equal:[eq(tpch.nation.n_nationkey, tpch.supplier.s_nationkey)] + │ │ │ ├─HashRightJoin_50 5.00 root inner join, inner:TableReader_55, 
equal:[eq(tpch.region.r_regionkey, tpch.nation.n_regionkey)] + │ │ │ │ ├─TableReader_55 1.00 root data:Selection_54 + │ │ │ │ │ └─Selection_54 1.00 cop eq(tpch.region.r_name, "MIDDLE EAST") + │ │ │ │ │ └─TableScan_53 5.00 cop table:region, range:[-inf,+inf], keep order:false + │ │ │ │ └─TableReader_52 25.00 root data:TableScan_51 + │ │ │ │ └─TableScan_51 25.00 cop table:nation, range:[-inf,+inf], keep order:false + │ │ │ └─TableReader_57 500000.00 root data:TableScan_56 + │ │ │ └─TableScan_56 500000.00 cop table:supplier, range:[-inf,+inf], keep order:false + │ │ └─TableReader_59 300005811.00 root data:TableScan_58 + │ │ └─TableScan_58 300005811.00 cop table:lineitem, range:[-inf,+inf], keep order:false + │ └─TableReader_37 11822812.50 root data:Selection_36 + │ └─Selection_36 11822812.50 cop ge(tpch.orders.o_orderdate, 1994-01-01 00:00:00.000000), lt(tpch.orders.o_orderdate, 1995-01-01) + │ └─TableScan_35 1.00 cop table:orders, range: decided by [tpch.lineitem.l_orderkey], keep order:false + └─TableReader_30 1.00 root data:TableScan_29 + └─TableScan_29 1.00 cop table:customer, range: decided by [tpch.supplier.s_nationkey tpch.orders.o_custkey], keep order:false /* Q6 Forecasting Revenue Change Query This query quantifies the amount of revenue increase that would have resulted from eliminating certain companywide @@ -515,35 +517,36 @@ o_year; id count task operator info Sort_29 718.01 root all_nations.o_year:asc └─Projection_31 718.01 root all_nations.o_year, div(18_col_0, 18_col_1) - └─HashAgg_34 718.01 root group by:all_nations.o_year, funcs:sum(case(eq(all_nations.nation, "INDIA"), all_nations.volume, 0)), sum(all_nations.volume), firstrow(all_nations.o_year) - └─Projection_35 562348.12 root extract("YEAR", tpch.orders.o_orderdate), mul(tpch.lineitem.l_extendedprice, minus(1, tpch.lineitem.l_discount)), n2.n_name - └─HashLeftJoin_39 562348.12 root inner join, inner:TableReader_87, equal:[eq(tpch.supplier.s_nationkey, n2.n_nationkey)] - ├─IndexJoin_43 562348.12 root inner join, inner:TableReader_42, outer key:tpch.lineitem.l_suppkey, inner key:tpch.supplier.s_suppkey - │ ├─HashLeftJoin_50 562348.12 root inner join, inner:TableReader_83, equal:[eq(tpch.lineitem.l_partkey, tpch.part.p_partkey)] - │ │ ├─IndexJoin_56 90661378.61 root inner join, inner:IndexLookUp_55, outer key:tpch.orders.o_orderkey, inner key:tpch.lineitem.l_orderkey - │ │ │ ├─HashRightJoin_60 22382008.93 root inner join, inner:HashRightJoin_62, equal:[eq(tpch.customer.c_custkey, tpch.orders.o_custkey)] - │ │ │ │ ├─HashRightJoin_62 1500000.00 root inner join, inner:HashRightJoin_68, equal:[eq(n1.n_nationkey, tpch.customer.c_nationkey)] - │ │ │ │ │ ├─HashRightJoin_68 5.00 root inner join, inner:TableReader_73, equal:[eq(tpch.region.r_regionkey, n1.n_regionkey)] - │ │ │ │ │ │ ├─TableReader_73 1.00 root data:Selection_72 - │ │ │ │ │ │ │ └─Selection_72 1.00 cop eq(tpch.region.r_name, "ASIA") - │ │ │ │ │ │ │ └─TableScan_71 5.00 cop table:region, range:[-inf,+inf], keep order:false - │ │ │ │ │ │ └─TableReader_70 25.00 root data:TableScan_69 - │ │ │ │ │ │ └─TableScan_69 25.00 cop table:n1, range:[-inf,+inf], keep order:false - │ │ │ │ │ └─TableReader_75 7500000.00 root data:TableScan_74 - │ │ │ │ │ └─TableScan_74 7500000.00 cop table:customer, range:[-inf,+inf], keep order:false - │ │ │ │ └─TableReader_78 22382008.93 root data:Selection_77 - │ │ │ │ └─Selection_77 22382008.93 cop ge(tpch.orders.o_orderdate, 1995-01-01 00:00:00.000000), le(tpch.orders.o_orderdate, 1996-12-31 00:00:00.000000) - │ │ │ │ └─TableScan_76 75000000.00 cop 
table:orders, range:[-inf,+inf], keep order:false - │ │ │ └─IndexLookUp_55 1.00 root - │ │ │ ├─IndexScan_53 1.00 cop table:lineitem, index:L_ORDERKEY, L_LINENUMBER, range: decided by [tpch.orders.o_orderkey], keep order:false - │ │ │ └─TableScan_54 1.00 cop table:lineitem, keep order:false - │ │ └─TableReader_83 61674.00 root data:Selection_82 - │ │ └─Selection_82 61674.00 cop eq(tpch.part.p_type, "SMALL PLATED COPPER") - │ │ └─TableScan_81 10000000.00 cop table:part, range:[-inf,+inf], keep order:false - │ └─TableReader_42 1.00 root data:TableScan_41 - │ └─TableScan_41 1.00 cop table:supplier, range: decided by [tpch.lineitem.l_suppkey], keep order:false - └─TableReader_87 25.00 root data:TableScan_86 - └─TableScan_86 25.00 cop table:n2, range:[-inf,+inf], keep order:false + └─HashAgg_34 718.01 root group by:col_3, funcs:sum(col_0), sum(col_1), firstrow(col_2) + └─Projection_89 562348.12 root case(eq(all_nations.nation, "INDIA"), all_nations.volume, 0), all_nations.volume, all_nations.o_year, all_nations.o_year + └─Projection_35 562348.12 root extract("YEAR", tpch.orders.o_orderdate), mul(tpch.lineitem.l_extendedprice, minus(1, tpch.lineitem.l_discount)), n2.n_name + └─HashLeftJoin_39 562348.12 root inner join, inner:TableReader_87, equal:[eq(tpch.supplier.s_nationkey, n2.n_nationkey)] + ├─IndexJoin_43 562348.12 root inner join, inner:TableReader_42, outer key:tpch.lineitem.l_suppkey, inner key:tpch.supplier.s_suppkey + │ ├─HashLeftJoin_50 562348.12 root inner join, inner:TableReader_83, equal:[eq(tpch.lineitem.l_partkey, tpch.part.p_partkey)] + │ │ ├─IndexJoin_56 90661378.61 root inner join, inner:IndexLookUp_55, outer key:tpch.orders.o_orderkey, inner key:tpch.lineitem.l_orderkey + │ │ │ ├─HashRightJoin_60 22382008.93 root inner join, inner:HashRightJoin_62, equal:[eq(tpch.customer.c_custkey, tpch.orders.o_custkey)] + │ │ │ │ ├─HashRightJoin_62 1500000.00 root inner join, inner:HashRightJoin_68, equal:[eq(n1.n_nationkey, tpch.customer.c_nationkey)] + │ │ │ │ │ ├─HashRightJoin_68 5.00 root inner join, inner:TableReader_73, equal:[eq(tpch.region.r_regionkey, n1.n_regionkey)] + │ │ │ │ │ │ ├─TableReader_73 1.00 root data:Selection_72 + │ │ │ │ │ │ │ └─Selection_72 1.00 cop eq(tpch.region.r_name, "ASIA") + │ │ │ │ │ │ │ └─TableScan_71 5.00 cop table:region, range:[-inf,+inf], keep order:false + │ │ │ │ │ │ └─TableReader_70 25.00 root data:TableScan_69 + │ │ │ │ │ │ └─TableScan_69 25.00 cop table:n1, range:[-inf,+inf], keep order:false + │ │ │ │ │ └─TableReader_75 7500000.00 root data:TableScan_74 + │ │ │ │ │ └─TableScan_74 7500000.00 cop table:customer, range:[-inf,+inf], keep order:false + │ │ │ │ └─TableReader_78 22382008.93 root data:Selection_77 + │ │ │ │ └─Selection_77 22382008.93 cop ge(tpch.orders.o_orderdate, 1995-01-01 00:00:00.000000), le(tpch.orders.o_orderdate, 1996-12-31 00:00:00.000000) + │ │ │ │ └─TableScan_76 75000000.00 cop table:orders, range:[-inf,+inf], keep order:false + │ │ │ └─IndexLookUp_55 1.00 root + │ │ │ ├─IndexScan_53 1.00 cop table:lineitem, index:L_ORDERKEY, L_LINENUMBER, range: decided by [tpch.orders.o_orderkey], keep order:false + │ │ │ └─TableScan_54 1.00 cop table:lineitem, keep order:false + │ │ └─TableReader_83 61674.00 root data:Selection_82 + │ │ └─Selection_82 61674.00 cop eq(tpch.part.p_type, "SMALL PLATED COPPER") + │ │ └─TableScan_81 10000000.00 cop table:part, range:[-inf,+inf], keep order:false + │ └─TableReader_42 1.00 root data:TableScan_41 + │ └─TableScan_41 1.00 cop table:supplier, range: decided by [tpch.lineitem.l_suppkey], keep 
order:false + └─TableReader_87 25.00 root data:TableScan_86 + └─TableScan_86 25.00 cop table:n2, range:[-inf,+inf], keep order:false /* Q9 Product Type Profit Measure Query This query determines how much profit is made on a given line of parts, broken out by supplier nation and year. @@ -657,21 +660,22 @@ limit 20; id count task operator info Projection_17 20.00 root tpch.customer.c_custkey, tpch.customer.c_name, 9_col_0, tpch.customer.c_acctbal, tpch.nation.n_name, tpch.customer.c_address, tpch.customer.c_phone, tpch.customer.c_comment └─TopN_20 20.00 root 9_col_0:desc, offset:0, count:20 - └─HashAgg_26 3017307.69 root group by:tpch.customer.c_acctbal, tpch.customer.c_address, tpch.customer.c_comment, tpch.customer.c_custkey, tpch.customer.c_name, tpch.customer.c_phone, tpch.nation.n_name, funcs:sum(mul(tpch.lineitem.l_extendedprice, minus(1, tpch.lineitem.l_discount))), firstrow(tpch.customer.c_custkey), firstrow(tpch.customer.c_name), firstrow(tpch.customer.c_address), firstrow(tpch.customer.c_phone), firstrow(tpch.customer.c_acctbal), firstrow(tpch.customer.c_comment), firstrow(tpch.nation.n_name) - └─IndexJoin_32 12222016.17 root inner join, inner:IndexLookUp_31, outer key:tpch.orders.o_orderkey, inner key:tpch.lineitem.l_orderkey - ├─HashLeftJoin_35 3017307.69 root inner join, inner:TableReader_48, equal:[eq(tpch.customer.c_custkey, tpch.orders.o_custkey)] - │ ├─HashRightJoin_41 7500000.00 root inner join, inner:TableReader_45, equal:[eq(tpch.nation.n_nationkey, tpch.customer.c_nationkey)] - │ │ ├─TableReader_45 25.00 root data:TableScan_44 - │ │ │ └─TableScan_44 25.00 cop table:nation, range:[-inf,+inf], keep order:false - │ │ └─TableReader_43 7500000.00 root data:TableScan_42 - │ │ └─TableScan_42 7500000.00 cop table:customer, range:[-inf,+inf], keep order:false - │ └─TableReader_48 3017307.69 root data:Selection_47 - │ └─Selection_47 3017307.69 cop ge(tpch.orders.o_orderdate, 1993-08-01 00:00:00.000000), lt(tpch.orders.o_orderdate, 1993-11-01) - │ └─TableScan_46 75000000.00 cop table:orders, range:[-inf,+inf], keep order:false - └─IndexLookUp_31 73916005.00 root - ├─IndexScan_28 1.00 cop table:lineitem, index:L_ORDERKEY, L_LINENUMBER, range: decided by [tpch.orders.o_orderkey], keep order:false - └─Selection_30 73916005.00 cop eq(tpch.lineitem.l_returnflag, "R") - └─TableScan_29 1.00 cop table:lineitem, keep order:false + └─HashAgg_26 3017307.69 root group by:col_10, col_11, col_12, col_13, col_14, col_8, col_9, funcs:sum(col_0), firstrow(col_1), firstrow(col_2), firstrow(col_3), firstrow(col_4), firstrow(col_5), firstrow(col_6), firstrow(col_7) + └─Projection_52 12222016.17 root mul(tpch.lineitem.l_extendedprice, minus(1, tpch.lineitem.l_discount)), tpch.customer.c_custkey, tpch.customer.c_name, tpch.customer.c_address, tpch.customer.c_phone, tpch.customer.c_acctbal, tpch.customer.c_comment, tpch.nation.n_name, tpch.customer.c_custkey, tpch.customer.c_name, tpch.customer.c_acctbal, tpch.customer.c_phone, tpch.nation.n_name, tpch.customer.c_address, tpch.customer.c_comment + └─IndexJoin_32 12222016.17 root inner join, inner:IndexLookUp_31, outer key:tpch.orders.o_orderkey, inner key:tpch.lineitem.l_orderkey + ├─HashLeftJoin_35 3017307.69 root inner join, inner:TableReader_48, equal:[eq(tpch.customer.c_custkey, tpch.orders.o_custkey)] + │ ├─HashRightJoin_41 7500000.00 root inner join, inner:TableReader_45, equal:[eq(tpch.nation.n_nationkey, tpch.customer.c_nationkey)] + │ │ ├─TableReader_45 25.00 root data:TableScan_44 + │ │ │ └─TableScan_44 25.00 cop table:nation, 
range:[-inf,+inf], keep order:false + │ │ └─TableReader_43 7500000.00 root data:TableScan_42 + │ │ └─TableScan_42 7500000.00 cop table:customer, range:[-inf,+inf], keep order:false + │ └─TableReader_48 3017307.69 root data:Selection_47 + │ └─Selection_47 3017307.69 cop ge(tpch.orders.o_orderdate, 1993-08-01 00:00:00.000000), lt(tpch.orders.o_orderdate, 1993-11-01) + │ └─TableScan_46 75000000.00 cop table:orders, range:[-inf,+inf], keep order:false + └─IndexLookUp_31 73916005.00 root + ├─IndexScan_28 1.00 cop table:lineitem, index:L_ORDERKEY, L_LINENUMBER, range: decided by [tpch.orders.o_orderkey], keep order:false + └─Selection_30 73916005.00 cop eq(tpch.lineitem.l_returnflag, "R") + └─TableScan_29 1.00 cop table:lineitem, keep order:false /* Q11 Important Stock Identification Query This query finds the most important subset of suppliers' stock in a given nation. @@ -711,16 +715,17 @@ id count task operator info Projection_63 1304801.67 root tpch.partsupp.ps_partkey, value └─Sort_64 1304801.67 root value:desc └─Selection_66 1304801.67 root gt(sel_agg_4, NULL) - └─HashAgg_69 1631002.09 root group by:tpch.partsupp.ps_partkey, funcs:sum(mul(tpch.partsupp.ps_supplycost, cast(tpch.partsupp.ps_availqty))), firstrow(tpch.partsupp.ps_partkey) - └─HashRightJoin_73 1631002.09 root inner join, inner:HashRightJoin_79, equal:[eq(tpch.supplier.s_suppkey, tpch.partsupp.ps_suppkey)] - ├─HashRightJoin_79 20000.00 root inner join, inner:TableReader_84, equal:[eq(tpch.nation.n_nationkey, tpch.supplier.s_nationkey)] - │ ├─TableReader_84 1.00 root data:Selection_83 - │ │ └─Selection_83 1.00 cop eq(tpch.nation.n_name, "MOZAMBIQUE") - │ │ └─TableScan_82 25.00 cop table:nation, range:[-inf,+inf], keep order:false - │ └─TableReader_81 500000.00 root data:TableScan_80 - │ └─TableScan_80 500000.00 cop table:supplier, range:[-inf,+inf], keep order:false - └─TableReader_86 40000000.00 root data:TableScan_85 - └─TableScan_85 40000000.00 cop table:partsupp, range:[-inf,+inf], keep order:false + └─HashAgg_69 1631002.09 root group by:col_2, funcs:sum(col_0), firstrow(col_1) + └─Projection_88 1631002.09 root mul(tpch.partsupp.ps_supplycost, cast(tpch.partsupp.ps_availqty)), tpch.partsupp.ps_partkey, tpch.partsupp.ps_partkey + └─HashRightJoin_73 1631002.09 root inner join, inner:HashRightJoin_79, equal:[eq(tpch.supplier.s_suppkey, tpch.partsupp.ps_suppkey)] + ├─HashRightJoin_79 20000.00 root inner join, inner:TableReader_84, equal:[eq(tpch.nation.n_nationkey, tpch.supplier.s_nationkey)] + │ ├─TableReader_84 1.00 root data:Selection_83 + │ │ └─Selection_83 1.00 cop eq(tpch.nation.n_name, "MOZAMBIQUE") + │ │ └─TableScan_82 25.00 cop table:nation, range:[-inf,+inf], keep order:false + │ └─TableReader_81 500000.00 root data:TableScan_80 + │ └─TableScan_80 500000.00 cop table:supplier, range:[-inf,+inf], keep order:false + └─TableReader_86 40000000.00 root data:TableScan_85 + └─TableScan_85 40000000.00 cop table:partsupp, range:[-inf,+inf], keep order:false /* Q12 Shipping Modes and Order Priority Query This query determines whether selecting less expensive modes of shipping is negatively affecting the critical-priority @@ -763,13 +768,14 @@ l_shipmode; id count task operator info Sort_9 1.00 root tpch.lineitem.l_shipmode:asc └─Projection_11 1.00 root tpch.lineitem.l_shipmode, 5_col_0, 5_col_1 - └─HashAgg_14 1.00 root group by:tpch.lineitem.l_shipmode, funcs:sum(case(or(eq(tpch.orders.o_orderpriority, "1-URGENT"), eq(tpch.orders.o_orderpriority, "2-HIGH")), 1, 0)), sum(case(and(ne(tpch.orders.o_orderpriority, "1-URGENT"), 
ne(tpch.orders.o_orderpriority, "2-HIGH")), 1, 0)), firstrow(tpch.lineitem.l_shipmode) - └─IndexJoin_18 10023369.01 root inner join, inner:TableReader_17, outer key:tpch.lineitem.l_orderkey, inner key:tpch.orders.o_orderkey - ├─TableReader_35 10023369.01 root data:Selection_34 - │ └─Selection_34 10023369.01 cop ge(tpch.lineitem.l_receiptdate, 1997-01-01 00:00:00.000000), in(tpch.lineitem.l_shipmode, "RAIL", "FOB"), lt(tpch.lineitem.l_commitdate, tpch.lineitem.l_receiptdate), lt(tpch.lineitem.l_receiptdate, 1998-01-01), lt(tpch.lineitem.l_shipdate, tpch.lineitem.l_commitdate) - │ └─TableScan_33 300005811.00 cop table:lineitem, range:[-inf,+inf], keep order:false - └─TableReader_17 1.00 root data:TableScan_16 - └─TableScan_16 1.00 cop table:orders, range: decided by [tpch.lineitem.l_orderkey], keep order:false + └─HashAgg_14 1.00 root group by:col_3, funcs:sum(col_0), sum(col_1), firstrow(col_2) + └─Projection_39 10023369.01 root cast(case(or(eq(tpch.orders.o_orderpriority, "1-URGENT"), eq(tpch.orders.o_orderpriority, "2-HIGH")), 1, 0)), cast(case(and(ne(tpch.orders.o_orderpriority, "1-URGENT"), ne(tpch.orders.o_orderpriority, "2-HIGH")), 1, 0)), tpch.lineitem.l_shipmode, tpch.lineitem.l_shipmode + └─IndexJoin_18 10023369.01 root inner join, inner:TableReader_17, outer key:tpch.lineitem.l_orderkey, inner key:tpch.orders.o_orderkey + ├─TableReader_35 10023369.01 root data:Selection_34 + │ └─Selection_34 10023369.01 cop ge(tpch.lineitem.l_receiptdate, 1997-01-01 00:00:00.000000), in(tpch.lineitem.l_shipmode, "RAIL", "FOB"), lt(tpch.lineitem.l_commitdate, tpch.lineitem.l_receiptdate), lt(tpch.lineitem.l_receiptdate, 1998-01-01), lt(tpch.lineitem.l_shipdate, tpch.lineitem.l_commitdate) + │ └─TableScan_33 300005811.00 cop table:lineitem, range:[-inf,+inf], keep order:false + └─TableReader_17 1.00 root data:TableScan_16 + └─TableScan_16 1.00 cop table:orders, range: decided by [tpch.lineitem.l_orderkey], keep order:false /* Q13 Customer Distribution Query This query seeks relationships between customers and the size of their orders. 
@@ -833,13 +839,14 @@ and l_shipdate >= '1996-12-01' and l_shipdate < date_add('1996-12-01', interval '1' month); id count task operator info Projection_8 1.00 root div(mul(100.00, 5_col_0), 5_col_1) -└─StreamAgg_13 1.00 root funcs:sum(case(like(tpch.part.p_type, "PROMO%", 92), mul(tpch.lineitem.l_extendedprice, minus(1, tpch.lineitem.l_discount)), 0)), sum(mul(tpch.lineitem.l_extendedprice, minus(1, tpch.lineitem.l_discount))) - └─IndexJoin_26 4121984.49 root inner join, inner:TableReader_25, outer key:tpch.lineitem.l_partkey, inner key:tpch.part.p_partkey - ├─TableReader_31 4121984.49 root data:Selection_30 - │ └─Selection_30 4121984.49 cop ge(tpch.lineitem.l_shipdate, 1996-12-01 00:00:00.000000), lt(tpch.lineitem.l_shipdate, 1997-01-01) - │ └─TableScan_29 300005811.00 cop table:lineitem, range:[-inf,+inf], keep order:false - └─TableReader_25 1.00 root data:TableScan_24 - └─TableScan_24 1.00 cop table:part, range: decided by [tpch.lineitem.l_partkey], keep order:false +└─StreamAgg_13 1.00 root funcs:sum(col_0), sum(col_1) + └─Projection_34 4121984.49 root case(like(tpch.part.p_type, "PROMO%", 92), mul(tpch.lineitem.l_extendedprice, minus(1, tpch.lineitem.l_discount)), 0), mul(tpch.lineitem.l_extendedprice, minus(1, tpch.lineitem.l_discount)) + └─IndexJoin_26 4121984.49 root inner join, inner:TableReader_25, outer key:tpch.lineitem.l_partkey, inner key:tpch.part.p_partkey + ├─TableReader_31 4121984.49 root data:Selection_30 + │ └─Selection_30 4121984.49 cop ge(tpch.lineitem.l_shipdate, 1996-12-01 00:00:00.000000), lt(tpch.lineitem.l_shipdate, 1997-01-01) + │ └─TableScan_29 300005811.00 cop table:lineitem, range:[-inf,+inf], keep order:false + └─TableReader_25 1.00 root data:TableScan_24 + └─TableScan_24 1.00 cop table:part, range: decided by [tpch.lineitem.l_partkey], keep order:false /* Q15 Top Supplier Query This query determines the top supplier so it can be rewarded, given more business, or identified for special recognition. 
@@ -1083,14 +1090,15 @@ and l_shipmode in ('AIR', 'AIR REG') and l_shipinstruct = 'DELIVER IN PERSON' ); id count task operator info -StreamAgg_13 1.00 root funcs:sum(mul(tpch.lineitem.l_extendedprice, minus(1, tpch.lineitem.l_discount))) -└─IndexJoin_29 6286493.79 root inner join, inner:TableReader_28, outer key:tpch.lineitem.l_partkey, inner key:tpch.part.p_partkey, other cond:or(and(and(eq(tpch.part.p_brand, "Brand#52"), in(tpch.part.p_container, "SM CASE", "SM BOX", "SM PACK", "SM PKG")), and(ge(tpch.lineitem.l_quantity, 4), and(le(tpch.lineitem.l_quantity, 14), le(tpch.part.p_size, 5)))), or(and(and(eq(tpch.part.p_brand, "Brand#11"), in(tpch.part.p_container, "MED BAG", "MED BOX", "MED PKG", "MED PACK")), and(ge(tpch.lineitem.l_quantity, 18), and(le(tpch.lineitem.l_quantity, 28), le(tpch.part.p_size, 10)))), and(and(eq(tpch.part.p_brand, "Brand#51"), in(tpch.part.p_container, "LG CASE", "LG BOX", "LG PACK", "LG PKG")), and(ge(tpch.lineitem.l_quantity, 29), and(le(tpch.lineitem.l_quantity, 39), le(tpch.part.p_size, 15)))))) - ├─TableReader_34 6286493.79 root data:Selection_33 - │ └─Selection_33 6286493.79 cop eq(tpch.lineitem.l_shipinstruct, "DELIVER IN PERSON"), in(tpch.lineitem.l_shipmode, "AIR", "AIR REG"), or(and(ge(tpch.lineitem.l_quantity, 4), le(tpch.lineitem.l_quantity, 14)), or(and(ge(tpch.lineitem.l_quantity, 18), le(tpch.lineitem.l_quantity, 28)), and(ge(tpch.lineitem.l_quantity, 29), le(tpch.lineitem.l_quantity, 39)))) - │ └─TableScan_32 300005811.00 cop table:lineitem, range:[-inf,+inf], keep order:false - └─TableReader_28 8000000.00 root data:Selection_27 - └─Selection_27 8000000.00 cop ge(tpch.part.p_size, 1), or(and(eq(tpch.part.p_brand, "Brand#52"), and(in(tpch.part.p_container, "SM CASE", "SM BOX", "SM PACK", "SM PKG"), le(tpch.part.p_size, 5))), or(and(eq(tpch.part.p_brand, "Brand#11"), and(in(tpch.part.p_container, "MED BAG", "MED BOX", "MED PKG", "MED PACK"), le(tpch.part.p_size, 10))), and(eq(tpch.part.p_brand, "Brand#51"), and(in(tpch.part.p_container, "LG CASE", "LG BOX", "LG PACK", "LG PKG"), le(tpch.part.p_size, 15))))) - └─TableScan_26 1.00 cop table:part, range: decided by [tpch.lineitem.l_partkey], keep order:false +StreamAgg_13 1.00 root funcs:sum(col_0) +└─Projection_38 6286493.79 root mul(tpch.lineitem.l_extendedprice, minus(1, tpch.lineitem.l_discount)) + └─IndexJoin_29 6286493.79 root inner join, inner:TableReader_28, outer key:tpch.lineitem.l_partkey, inner key:tpch.part.p_partkey, other cond:or(and(and(eq(tpch.part.p_brand, "Brand#52"), in(tpch.part.p_container, "SM CASE", "SM BOX", "SM PACK", "SM PKG")), and(ge(tpch.lineitem.l_quantity, 4), and(le(tpch.lineitem.l_quantity, 14), le(tpch.part.p_size, 5)))), or(and(and(eq(tpch.part.p_brand, "Brand#11"), in(tpch.part.p_container, "MED BAG", "MED BOX", "MED PKG", "MED PACK")), and(ge(tpch.lineitem.l_quantity, 18), and(le(tpch.lineitem.l_quantity, 28), le(tpch.part.p_size, 10)))), and(and(eq(tpch.part.p_brand, "Brand#51"), in(tpch.part.p_container, "LG CASE", "LG BOX", "LG PACK", "LG PKG")), and(ge(tpch.lineitem.l_quantity, 29), and(le(tpch.lineitem.l_quantity, 39), le(tpch.part.p_size, 15)))))) + ├─TableReader_34 6286493.79 root data:Selection_33 + │ └─Selection_33 6286493.79 cop eq(tpch.lineitem.l_shipinstruct, "DELIVER IN PERSON"), in(tpch.lineitem.l_shipmode, "AIR", "AIR REG"), or(and(ge(tpch.lineitem.l_quantity, 4), le(tpch.lineitem.l_quantity, 14)), or(and(ge(tpch.lineitem.l_quantity, 18), le(tpch.lineitem.l_quantity, 28)), and(ge(tpch.lineitem.l_quantity, 29), le(tpch.lineitem.l_quantity, 39)))) + 
│ └─TableScan_32 300005811.00 cop table:lineitem, range:[-inf,+inf], keep order:false + └─TableReader_28 8000000.00 root data:Selection_27 + └─Selection_27 8000000.00 cop ge(tpch.part.p_size, 1), or(and(eq(tpch.part.p_brand, "Brand#52"), and(in(tpch.part.p_container, "SM CASE", "SM BOX", "SM PACK", "SM PKG"), le(tpch.part.p_size, 5))), or(and(eq(tpch.part.p_brand, "Brand#11"), and(in(tpch.part.p_container, "MED BAG", "MED BOX", "MED PKG", "MED PACK"), le(tpch.part.p_size, 10))), and(eq(tpch.part.p_brand, "Brand#51"), and(in(tpch.part.p_container, "LG CASE", "LG BOX", "LG PACK", "LG PKG"), le(tpch.part.p_size, 15))))) + └─TableScan_26 1.00 cop table:part, range: decided by [tpch.lineitem.l_partkey], keep order:false /* Q20 Potential Part Promotion Query The Potential Part Promotion Query identifies suppliers in a particular nation having selected parts that may be candidates diff --git a/ddl/column.go b/ddl/column.go index 6329b616158a0..7854e6b0df739 100644 --- a/ddl/column.go +++ b/ddl/column.go @@ -15,7 +15,9 @@ package ddl import ( "fmt" + "strings" "sync/atomic" + "time" "github.com/pingcap/errors" "github.com/pingcap/parser/ast" @@ -25,6 +27,8 @@ import ( "github.com/pingcap/tidb/infoschema" "github.com/pingcap/tidb/meta" "github.com/pingcap/tidb/sessionctx" + "github.com/pingcap/tidb/table" + "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/sqlexec" log "github.com/sirupsen/logrus" ) @@ -121,32 +125,19 @@ func createColumnInfo(tblInfo *model.TableInfo, colInfo *model.ColumnInfo, pos * return colInfo, position, nil } -func onAddColumn(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, err error) { - // Handle the rolling back job. - if job.IsRollingback() { - ver, err = onDropColumn(t, job) - if err != nil { - return ver, errors.Trace(err) - } - return ver, nil - } - +func checkAddColumn(t *meta.Meta, job *model.Job) (*model.TableInfo, *model.ColumnInfo, *model.ColumnInfo, *ast.ColumnPosition, int, error) { schemaID := job.SchemaID tblInfo, err := getTableInfo(t, job, schemaID) if err != nil { - return ver, errors.Trace(err) + return nil, nil, nil, nil, 0, errors.Trace(err) } - // gofail: var errorBeforeDecodeArgs bool - // if errorBeforeDecodeArgs { - // return ver, errors.New("occur an error before decode args") - // } col := &model.ColumnInfo{} pos := &ast.ColumnPosition{} offset := 0 err = job.DecodeArgs(col, pos, &offset) if err != nil { job.State = model.JobStateCancelled - return ver, errors.Trace(err) + return nil, nil, nil, nil, 0, errors.Trace(err) } columnInfo := model.FindColumnInfo(tblInfo.Columns, col.Name.L) @@ -154,9 +145,32 @@ func onAddColumn(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, err error) if columnInfo.State == model.StatePublic { // We already have a column with the same column name. job.State = model.JobStateCancelled - return ver, infoschema.ErrColumnExists.GenWithStackByArgs(col.Name) + return nil, nil, nil, nil, 0, infoschema.ErrColumnExists.GenWithStackByArgs(col.Name) } - } else { + } + return tblInfo, columnInfo, col, pos, offset, nil +} + +func onAddColumn(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, err error) { + // Handle the rolling back job. 
+	if job.IsRollingback() {
+		ver, err = onDropColumn(t, job)
+		if err != nil {
+			return ver, errors.Trace(err)
+		}
+		return ver, nil
+	}
+
+	// gofail: var errorBeforeDecodeArgs bool
+	// if errorBeforeDecodeArgs {
+	// 	return ver, errors.New("occur an error before decode args")
+	// }
+
+	tblInfo, columnInfo, col, pos, offset, err := checkAddColumn(t, job)
+	if err != nil {
+		return ver, errors.Trace(err)
+	}
+	if columnInfo == nil {
 		columnInfo, offset, err = createColumnInfo(tblInfo, col, pos)
 		if err != nil {
 			job.State = model.JobStateCancelled
@@ -211,26 +225,8 @@ func onAddColumn(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, err error)
 }
 
 func onDropColumn(t *meta.Meta, job *model.Job) (ver int64, _ error) {
-	schemaID := job.SchemaID
-	tblInfo, err := getTableInfo(t, job, schemaID)
-	if err != nil {
-		return ver, errors.Trace(err)
-	}
-
-	var colName model.CIStr
-	err = job.DecodeArgs(&colName)
+	tblInfo, colInfo, err := checkDropColumn(t, job)
 	if err != nil {
-		job.State = model.JobStateCancelled
-		return ver, errors.Trace(err)
-	}
-
-	colInfo := model.FindColumnInfo(tblInfo.Columns, colName.L)
-	if colInfo == nil {
-		job.State = model.JobStateCancelled
-		return ver, ErrCantDropFieldOrKey.GenWithStack("column %s doesn't exist", colName)
-	}
-	if err = isDroppableColumn(tblInfo, colName); err != nil {
-		job.State = model.JobStateCancelled
 		return ver, errors.Trace(err)
 	}
 
@@ -242,6 +238,15 @@ func onDropColumn(t *meta.Meta, job *model.Job) (ver int64, _ error) {
 		colInfo.State = model.StateWriteOnly
 		// Set this column's offset to the last and reset all following columns' offsets.
 		adjustColumnInfoInDropColumn(tblInfo, colInfo.Offset)
+		// When the dropping column has the not-null flag and no default value, we can backfill the column value the same way "add column" does.
+		// NOTE: If the StateWriteOnly state can be rolled back, we'd better reconsider the original default value.
+		// We also need to consider columns without the not-null flag.
+		if colInfo.OriginDefaultValue == nil && mysql.HasNotNullFlag(colInfo.Flag) {
+			colInfo.OriginDefaultValue, err = generateOriginDefaultValue(colInfo)
+			if err != nil {
+				return ver, errors.Trace(err)
+			}
+		}
 		ver, err = updateVersionAndTableInfo(t, job, tblInfo, originalState != colInfo.State)
 	case model.StateWriteOnly:
 		// write only -> delete only
@@ -275,6 +280,32 @@ func onDropColumn(t *meta.Meta, job *model.Job) (ver int64, _ error) {
 	return ver, errors.Trace(err)
 }
 
+func checkDropColumn(t *meta.Meta, job *model.Job) (*model.TableInfo, *model.ColumnInfo, error) {
+	schemaID := job.SchemaID
+	tblInfo, err := getTableInfo(t, job, schemaID)
+	if err != nil {
+		return nil, nil, errors.Trace(err)
+	}
+
+	var colName model.CIStr
+	err = job.DecodeArgs(&colName)
+	if err != nil {
+		job.State = model.JobStateCancelled
+		return nil, nil, errors.Trace(err)
+	}
+
+	colInfo := model.FindColumnInfo(tblInfo.Columns, colName.L)
+	if colInfo == nil {
+		job.State = model.JobStateCancelled
+		return nil, nil, ErrCantDropFieldOrKey.GenWithStack("column %s doesn't exist", colName)
+	}
+	if err = isDroppableColumn(tblInfo, colName); err != nil {
+		job.State = model.JobStateCancelled
+		return nil, nil, errors.Trace(err)
+	}
+	return tblInfo, colInfo, nil
+}
+
 func onSetDefaultValue(t *meta.Meta, job *model.Job) (ver int64, _ error) {
 	newCol := &model.ColumnInfo{}
 	err := job.DecodeArgs(newCol)
@@ -527,3 +558,21 @@ func modifyColumnFromNull2NotNull(w *worker, t *meta.Meta, dbInfo *model.DBInfo,
 	ver, err = updateVersionAndTableInfo(t, job, tblInfo, true)
 	return ver, errors.Trace(err)
 }
+
+func generateOriginDefaultValue(col *model.ColumnInfo) (interface{}, error) {
+	var err error
+	odValue := col.GetDefaultValue()
+	if odValue == nil && mysql.HasNotNullFlag(col.Flag) {
+		zeroVal := table.GetZeroValue(col)
+		odValue, err = zeroVal.ToString()
+		if err != nil {
+			return nil, errors.Trace(err)
+		}
+	}
+
+	if odValue == strings.ToUpper(ast.CurrentTimestamp) &&
+		(col.Tp == mysql.TypeTimestamp || col.Tp == mysql.TypeDatetime) {
+		odValue = time.Now().Format(types.TimeFormat)
+	}
+	return odValue, nil
+}
diff --git a/ddl/db_change_test.go b/ddl/db_change_test.go
index 30629dde80802..b60915c9d2a9f 100644
--- a/ddl/db_change_test.go
+++ b/ddl/db_change_test.go
@@ -106,6 +106,8 @@ func (s *testStateChangeSuite) TestShowCreateTable(c *C) {
 		}
 	}
 	d := s.dom.DDL()
+	originalCallback := d.GetHook()
+	defer d.(ddl.DDLForTest).SetHook(originalCallback)
 	d.(ddl.DDLForTest).SetHook(callback)
 	tk.MustExec("alter table t add index idx1(id)")
 	c.Assert(checkErr, IsNil)
@@ -113,6 +115,61 @@ func (s *testStateChangeSuite) TestShowCreateTable(c *C) {
 	c.Assert(checkErr, IsNil)
 }
 
+// TestDropNotNullColumn is used to test issue #8654.
+func (s *testStateChangeSuite) TestDropNotNullColumn(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec("create table t (id int, a int not null default 11)") + tk.MustExec("insert into t values(1, 1)") + tk.MustExec("create table t1 (id int, b varchar(255) not null)") + tk.MustExec("insert into t1 values(2, '')") + tk.MustExec("create table t2 (id int, c time not null)") + tk.MustExec("insert into t2 values(3, '11:22:33')") + tk.MustExec("create table t3 (id int, d json not null)") + tk.MustExec("insert into t3 values(4, d)") + tk1 := testkit.NewTestKit(c, s.store) + tk1.MustExec("use test") + + var checkErr error + d := s.dom.DDL() + originalCallback := d.GetHook() + callback := &ddl.TestDDLCallback{} + sqlNum := 0 + callback.OnJobUpdatedExported = func(job *model.Job) { + if checkErr != nil { + return + } + originalCallback.OnChanged(nil) + if job.SchemaState == model.StateWriteOnly { + switch sqlNum { + case 0: + _, checkErr = tk1.Exec("insert into t set id = 1") + case 1: + _, checkErr = tk1.Exec("insert into t1 set id = 2") + case 2: + _, checkErr = tk1.Exec("insert into t2 set id = 3") + case 3: + _, checkErr = tk1.Exec("insert into t3 set id = 4") + } + } + } + + d.(ddl.DDLForTest).SetHook(callback) + tk.MustExec("alter table t drop column a") + c.Assert(checkErr, IsNil) + sqlNum++ + tk.MustExec("alter table t1 drop column b") + c.Assert(checkErr, IsNil) + sqlNum++ + tk.MustExec("alter table t2 drop column c") + c.Assert(checkErr, IsNil) + sqlNum++ + tk.MustExec("alter table t3 drop column d") + c.Assert(checkErr, IsNil) + d.(ddl.DDLForTest).SetHook(originalCallback) + tk.MustExec("drop table t, t1, t2, t3") +} + func (s *testStateChangeSuite) TestTwoStates(c *C) { cnt := 5 // New the testExecInfo. @@ -212,6 +269,8 @@ func (s *testStateChangeSuite) test(c *C, tableName, alterTableSQL string, testI } } d := s.dom.DDL() + originalCallback := d.GetHook() + defer d.(ddl.DDLForTest).SetHook(originalCallback) d.(ddl.DDLForTest).SetHook(callback) _, err = s.se.Execute(context.Background(), alterTableSQL) c.Assert(err, IsNil) @@ -223,8 +282,6 @@ func (s *testStateChangeSuite) test(c *C, tableName, alterTableSQL string, testI err = testInfo.execSQL(3) c.Assert(err, IsNil) c.Assert(errors.ErrorStack(checkErr), Equals, "") - callback = &ddl.TestDDLCallback{} - d.(ddl.DDLForTest).SetHook(callback) } type stateCase struct { @@ -481,12 +538,12 @@ func (s *testStateChangeSuite) runTestInSchemaState(c *C, state model.SchemaStat } } d := s.dom.DDL() + originalCallback := d.GetHook() d.(ddl.DDLForTest).SetHook(callback) _, err = s.se.Execute(context.Background(), alterTableSQL) c.Assert(err, IsNil) c.Assert(errors.ErrorStack(checkErr), Equals, "") - callback = &ddl.TestDDLCallback{} - d.(ddl.DDLForTest).SetHook(callback) + d.(ddl.DDLForTest).SetHook(originalCallback) if expectQuery != nil { tk := testkit.NewTestKit(c, s.store) @@ -554,6 +611,7 @@ func (s *testStateChangeSuite) TestShowIndex(c *C) { } d := s.dom.DDL() + originalCallback := d.GetHook() d.(ddl.DDLForTest).SetHook(callback) alterTableSQL := `alter table t add index c2(c2)` _, err = s.se.Execute(context.Background(), alterTableSQL) @@ -564,8 +622,7 @@ func (s *testStateChangeSuite) TestShowIndex(c *C) { c.Assert(err, IsNil) err = checkResult(result, testkit.Rows("t 0 PRIMARY 1 c1 A 0 BTREE ", "t 1 c2 1 c2 A 0 YES BTREE ")) c.Assert(err, IsNil) - callback = &ddl.TestDDLCallback{} - d.(ddl.DDLForTest).SetHook(callback) + d.(ddl.DDLForTest).SetHook(originalCallback) c.Assert(err, IsNil) @@ -716,6 +773,8 
@@ func (s *testStateChangeSuite) testControlParallelExecSQL(c *C, sql1, sql2 strin times++ } d := s.dom.DDL() + originalCallback := d.GetHook() + defer d.(ddl.DDLForTest).SetHook(originalCallback) d.(ddl.DDLForTest).SetHook(callback) wg := sync.WaitGroup{} @@ -763,9 +822,6 @@ func (s *testStateChangeSuite) testControlParallelExecSQL(c *C, sql1, sql2 strin wg.Wait() f(c, err1, err2) - - callback = &ddl.TestDDLCallback{} - d.(ddl.DDLForTest).SetHook(callback) } func (s *testStateChangeSuite) testParallelExecSQL(c *C, sql string) { @@ -792,6 +848,8 @@ func (s *testStateChangeSuite) testParallelExecSQL(c *C, sql string) { } d := s.dom.DDL() + originalCallback := d.GetHook() + defer d.(ddl.DDLForTest).SetHook(originalCallback) d.(ddl.DDLForTest).SetHook(callback) wg.Add(2) @@ -807,8 +865,6 @@ func (s *testStateChangeSuite) testParallelExecSQL(c *C, sql string) { wg.Wait() c.Assert(err2, IsNil) c.Assert(err3, IsNil) - callback = &ddl.TestDDLCallback{} - d.(ddl.DDLForTest).SetHook(callback) } // TestCreateTableIfNotExists parallel exec create table if not exists xxx. No error returns is expected. diff --git a/ddl/db_test.go b/ddl/db_test.go index 5ceb47775021a..4b20295015fe0 100644 --- a/ddl/db_test.go +++ b/ddl/db_test.go @@ -473,6 +473,64 @@ func (s *testDBSuite) TestCancelDropIndex(c *C) { s.mustExec(c, "alter table t drop index idx_c2") } +// TestCancelRenameIndex tests cancel ddl job which type is rename index. +func (s *testDBSuite) TestCancelRenameIndex(c *C) { + s.tk = testkit.NewTestKit(c, s.store) + s.mustExec(c, "use test_db") + s.mustExec(c, "create database if not exists test_rename_index") + s.mustExec(c, "drop table if exists t") + s.mustExec(c, "create table t(c1 int, c2 int)") + defer s.mustExec(c, "drop table t;") + for i := 0; i < 100; i++ { + s.mustExec(c, "insert into t values (?, ?)", i, i) + } + s.mustExec(c, "alter table t add index idx_c2(c2)") + var checkErr error + hook := &ddl.TestDDLCallback{} + hook.OnJobRunBeforeExported = func(job *model.Job) { + if job.Type == model.ActionRenameIndex && job.State == model.JobStateNone { + jobIDs := []int64{job.ID} + hookCtx := mock.NewContext() + hookCtx.Store = s.store + err := hookCtx.NewTxn(context.Background()) + if err != nil { + checkErr = errors.Trace(err) + return + } + txn, err := hookCtx.Txn(true) + if err != nil { + checkErr = errors.Trace(err) + return + } + errs, err := admin.CancelJobs(txn, jobIDs) + if err != nil { + checkErr = errors.Trace(err) + return + } + if errs[0] != nil { + checkErr = errors.Trace(errs[0]) + return + } + checkErr = txn.Commit(context.Background()) + } + } + originalHook := s.dom.DDL().GetHook() + s.dom.DDL().(ddl.DDLForTest).SetHook(hook) + rs, err := s.tk.Exec("alter table t rename index idx_c2 to idx_c3") + if rs != nil { + rs.Close() + } + c.Assert(checkErr, IsNil) + c.Assert(err, NotNil) + c.Assert(err.Error(), Equals, "[ddl:12]cancelled DDL job") + s.dom.DDL().(ddl.DDLForTest).SetHook(originalHook) + t := s.testGetTable(c, "t") + for _, idx := range t.Indices() { + c.Assert(strings.EqualFold(idx.Meta().Name.L, "idx_c3"), IsFalse) + } + s.mustExec(c, "alter table t rename index idx_c2 to idx_c3") +} + // TestCancelDropTable tests cancel ddl job which type is drop table. 
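The TestCancelRenameIndex hook above uses the cancellation pattern that recurs throughout this patch (and again in TestCancelAddIndexPanic in the new ddl/serial_test.go): intercept the job in OnJobRunBeforeExported, open an independent transaction through a mock session, and cancel the job by ID with admin.CancelJobs. A hedged sketch of that recurring shape, pulled out into a helper; cancelJobFromHook is an illustrative name, not part of the patch.

```go
// cancelJobFromHook cancels the given DDL job from inside a ddl.TestDDLCallback
// hook. The mock context gets its own transaction so the cancellation does not
// interfere with the session that is running the DDL statement.
func cancelJobFromHook(store kv.Storage, job *model.Job) error {
	hookCtx := mock.NewContext()
	hookCtx.Store = store
	if err := hookCtx.NewTxn(context.Background()); err != nil {
		return errors.Trace(err)
	}
	txn, err := hookCtx.Txn(true)
	if err != nil {
		return errors.Trace(err)
	}
	errs, err := admin.CancelJobs(txn, []int64{job.ID})
	if err != nil {
		return errors.Trace(err)
	}
	if errs[0] != nil {
		return errors.Trace(errs[0])
	}
	return txn.Commit(context.Background())
}
```

The calling test then asserts that the DDL statement itself fails with "[ddl:12]cancelled DDL job" and that the schema object is unchanged.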
func (s *testDBSuite) TestCancelDropTableAndSchema(c *C) { s.tk = testkit.NewTestKit(c, s.store) @@ -1619,14 +1677,49 @@ func (s *testDBSuite) testRenameTable(c *C, sql string, isAlterTable bool) { // for failure case failSQL := fmt.Sprintf(sql, "test_not_exist.t", "test_not_exist.t") - s.testErrorCode(c, failSQL, tmysql.ErrFileNotFound) + if isAlterTable { + s.testErrorCode(c, failSQL, tmysql.ErrNoSuchTable) + } else { + s.testErrorCode(c, failSQL, tmysql.ErrFileNotFound) + } failSQL = fmt.Sprintf(sql, "test.test_not_exist", "test.test_not_exist") - s.testErrorCode(c, failSQL, tmysql.ErrFileNotFound) + if isAlterTable { + s.testErrorCode(c, failSQL, tmysql.ErrNoSuchTable) + } else { + s.testErrorCode(c, failSQL, tmysql.ErrFileNotFound) + } failSQL = fmt.Sprintf(sql, "test.t_not_exist", "test_not_exist.t") - s.testErrorCode(c, failSQL, tmysql.ErrFileNotFound) + if isAlterTable { + s.testErrorCode(c, failSQL, tmysql.ErrNoSuchTable) + } else { + s.testErrorCode(c, failSQL, tmysql.ErrFileNotFound) + } failSQL = fmt.Sprintf(sql, "test1.t2", "test_not_exist.t") s.testErrorCode(c, failSQL, tmysql.ErrErrorOnRename) + s.tk.MustExec("use test1") + s.tk.MustExec("create table if not exists t_exist (c1 int, c2 int)") + failSQL = fmt.Sprintf(sql, "test1.t2", "test1.t_exist") + s.testErrorCode(c, failSQL, tmysql.ErrTableExists) + failSQL = fmt.Sprintf(sql, "test.t_not_exist", "test1.t_exist") + if isAlterTable { + s.testErrorCode(c, failSQL, tmysql.ErrNoSuchTable) + } else { + s.testErrorCode(c, failSQL, tmysql.ErrTableExists) + } + failSQL = fmt.Sprintf(sql, "test_not_exist.t", "test1.t_exist") + if isAlterTable { + s.testErrorCode(c, failSQL, tmysql.ErrNoSuchTable) + } else { + s.testErrorCode(c, failSQL, tmysql.ErrTableExists) + } + failSQL = fmt.Sprintf(sql, "test_not_exist.t", "test1.t_not_exist") + if isAlterTable { + s.testErrorCode(c, failSQL, tmysql.ErrNoSuchTable) + } else { + s.testErrorCode(c, failSQL, tmysql.ErrFileNotFound) + } + // for the same table name s.tk.MustExec("use test1") s.tk.MustExec("create table if not exists t (c1 int, c2 int)") @@ -2034,3 +2127,174 @@ LOOP: s.dom.DDL().(ddl.DDLForTest).SetHook(originalHook) s.mustExec(c, "drop table t1") } + +func (s *testDBSuite) TestTransactionOnAddDropColumn(c *C) { + s.tk = testkit.NewTestKit(c, s.store) + s.mustExec(c, "use test_db") + s.mustExec(c, "drop table if exists t1") + s.mustExec(c, "create table t1 (a int, b int);") + s.mustExec(c, "create table t2 (a int, b int);") + s.mustExec(c, "insert into t2 values (2,0)") + + transactions := [][]string{ + { + "begin", + "insert into t1 set a=1", + "update t1 set b=1 where a=1", + "commit", + }, + { + "begin", + "insert into t1 select a,b from t2", + "update t1 set b=2 where a=2", + "commit", + }, + } + + originHook := s.dom.DDL().GetHook() + defer s.dom.DDL().(ddl.DDLForTest).SetHook(originHook) + hook := &ddl.TestDDLCallback{} + hook.OnJobRunBeforeExported = func(job *model.Job) { + switch job.SchemaState { + case model.StateWriteOnly, model.StateWriteReorganization, model.StateDeleteOnly, model.StateDeleteReorganization: + default: + return + } + // do transaction. + for _, transaction := range transactions { + for _, sql := range transaction { + s.mustExec(c, sql) + } + } + } + s.dom.DDL().(ddl.DDLForTest).SetHook(hook) + done := make(chan error, 1) + // test transaction on add column. 
+ go backgroundExec(s.store, "alter table t1 add column c int not null after a", done) + err := <-done + c.Assert(err, IsNil) + s.tk.MustQuery("select a,b from t1 order by a").Check(testkit.Rows("1 1", "1 1", "1 1", "2 2", "2 2", "2 2")) + s.mustExec(c, "delete from t1") + + // test transaction on drop column. + go backgroundExec(s.store, "alter table t1 drop column c", done) + err = <-done + c.Assert(err, IsNil) + s.tk.MustQuery("select a,b from t1 order by a").Check(testkit.Rows("1 1", "1 1", "1 1", "2 2", "2 2", "2 2")) +} + +func (s *testDBSuite) TestTransactionWithWriteOnlyColumn(c *C) { + s.tk = testkit.NewTestKit(c, s.store) + s.mustExec(c, "use test_db") + s.mustExec(c, "drop table if exists t1") + s.mustExec(c, "create table t1 (a int key);") + + transactions := [][]string{ + { + "begin", + "insert into t1 set a=1", + "update t1 set a=2 where a=1", + "commit", + }, + } + + originHook := s.dom.DDL().GetHook() + defer s.dom.DDL().(ddl.DDLForTest).SetHook(originHook) + hook := &ddl.TestDDLCallback{} + hook.OnJobRunBeforeExported = func(job *model.Job) { + switch job.SchemaState { + case model.StateWriteOnly: + default: + return + } + // do transaction. + for _, transaction := range transactions { + for _, sql := range transaction { + s.mustExec(c, sql) + } + } + } + s.dom.DDL().(ddl.DDLForTest).SetHook(hook) + done := make(chan error, 1) + // test transaction on add column. + go backgroundExec(s.store, "alter table t1 add column c int not null", done) + err := <-done + c.Assert(err, IsNil) + s.tk.MustQuery("select a from t1").Check(testkit.Rows("2")) + s.mustExec(c, "delete from t1") + + // test transaction on drop column. + go backgroundExec(s.store, "alter table t1 drop column c", done) + err = <-done + c.Assert(err, IsNil) + s.tk.MustQuery("select a from t1").Check(testkit.Rows("2")) +} + +func (s *testDBSuite) TestAddColumn2(c *C) { + s.tk = testkit.NewTestKit(c, s.store) + s.mustExec(c, "use test_db") + s.mustExec(c, "drop table if exists t1") + s.mustExec(c, "create table t1 (a int key, b int);") + defer s.mustExec(c, "drop table if exists t1, t2") + + originHook := s.dom.DDL().GetHook() + defer s.dom.DDL().(ddl.DDLForTest).SetHook(originHook) + hook := &ddl.TestDDLCallback{} + var writeOnlyTable table.Table + hook.OnJobRunBeforeExported = func(job *model.Job) { + if job.SchemaState == model.StateWriteOnly { + writeOnlyTable, _ = s.dom.InfoSchema().TableByID(job.TableID) + } + } + s.dom.DDL().(ddl.DDLForTest).SetHook(hook) + done := make(chan error, 1) + // test transaction on add column. + go backgroundExec(s.store, "alter table t1 add column c int not null", done) + err := <-done + c.Assert(err, IsNil) + + s.mustExec(c, "insert into t1 values (1,1,1)") + s.tk.MustQuery("select a,b,c from t1").Check(testkit.Rows("1 1 1")) + + // mock for outdated tidb update record. 
+ c.Assert(writeOnlyTable, NotNil) + ctx := context.Background() + err = s.tk.Se.NewTxn(ctx) + c.Assert(err, IsNil) + oldRow, err := writeOnlyTable.RowWithCols(s.tk.Se, 1, writeOnlyTable.WritableCols()) + c.Assert(err, IsNil) + c.Assert(len(oldRow), Equals, 3) + err = writeOnlyTable.RemoveRecord(s.tk.Se, 1, oldRow) + c.Assert(err, IsNil) + _, err = writeOnlyTable.AddRecord(s.tk.Se, types.MakeDatums(oldRow[0].GetInt64(), 2, oldRow[2].GetInt64()), + &table.AddRecordOpt{IsUpdate: true}) + c.Assert(err, IsNil) + err = s.tk.Se.StmtCommit() + c.Assert(err, IsNil) + err = s.tk.Se.CommitTxn(ctx) + + s.tk.MustQuery("select a,b,c from t1").Check(testkit.Rows("1 2 1")) + + // Test for _tidb_rowid + var re *testkit.Result + s.mustExec(c, "create table t2 (a int);") + hook.OnJobRunBeforeExported = func(job *model.Job) { + if job.SchemaState != model.StateWriteOnly { + return + } + // allow write _tidb_rowid first + s.mustExec(c, "set @@tidb_opt_write_row_id=1") + s.mustExec(c, "begin") + s.mustExec(c, "insert into t2 (a,_tidb_rowid) values (1,2);") + re = s.tk.MustQuery(" select a,_tidb_rowid from t2;") + s.mustExec(c, "commit") + + } + s.dom.DDL().(ddl.DDLForTest).SetHook(hook) + + go backgroundExec(s.store, "alter table t2 add column b int not null default 3", done) + err = <-done + c.Assert(err, IsNil) + re.Check(testkit.Rows("1 2")) + s.tk.MustQuery("select a,b,_tidb_rowid from t2").Check(testkit.Rows("1 3 2")) +} diff --git a/ddl/ddl.go b/ddl/ddl.go index fec3e64e15bdb..24e2db1d93801 100644 --- a/ddl/ddl.go +++ b/ddl/ddl.go @@ -203,6 +203,8 @@ var ( ErrCoalesceOnlyOnHashPartition = terror.ClassDDL.New(codeCoalesceOnlyOnHashPartition, mysql.MySQLErrName[mysql.ErrCoalesceOnlyOnHashPartition]) // ErrViewWrongList returns create view must include all columns in the select clause ErrViewWrongList = terror.ClassDDL.New(codeViewWrongList, mysql.MySQLErrName[mysql.ErrViewWrongList]) + // ErrTableIsNotView returns for table is not base table. + ErrTableIsNotView = terror.ClassDDL.New(codeErrWrongObject, "'%s.%s' is not VIEW") ) // DDL is responsible for updating schema in data store and maintaining in-memory InfoSchema cache. @@ -213,6 +215,7 @@ type DDL interface { CreateView(ctx sessionctx.Context, stmt *ast.CreateViewStmt) error CreateTableWithLike(ctx sessionctx.Context, ident, referIdent ast.Ident, ifNotExists bool) error DropTable(ctx sessionctx.Context, tableIdent ast.Ident) (err error) + DropView(ctx sessionctx.Context, tableIdent ast.Ident) (err error) CreateIndex(ctx sessionctx.Context, tableIdent ast.Ident, unique bool, indexName model.CIStr, columnNames []*ast.IndexColName, indexOption *ast.IndexOption) error DropIndex(ctx sessionctx.Context, tableIdent ast.Ident, indexName model.CIStr) error @@ -561,7 +564,7 @@ func (d *ddl) doDDLJob(ctx sessionctx.Context, job *model.Job) error { continue } - // If a job is a history job, the state must be JobSynced or JobCancel. + // If a job is a history job, the state must be JobStateSynced or JobStateRollbackDone or JobStateCancelled. 
if historyJob.IsSynced() { log.Infof("[ddl] DDL job ID:%d is finished", jobID) return nil @@ -570,7 +573,7 @@ func (d *ddl) doDDLJob(ctx sessionctx.Context, job *model.Job) error { if historyJob.Error != nil { return errors.Trace(historyJob.Error) } - panic("When the state is JobCancel, historyJob.Error should never be nil") + panic("When the state is JobStateRollbackDone or JobStateCancelled, historyJob.Error should never be nil") } } @@ -648,6 +651,7 @@ const ( codeWrongKeyColumn = 1167 codeBlobKeyWithoutLength = 1170 codeInvalidOnUpdate = 1294 + codeErrWrongObject = terror.ErrCode(mysql.ErrWrongObject) codeViewWrongList = 1353 codeUnsupportedOnGeneratedColumn = 3106 codeGeneratedColumnNonPrior = 3107 @@ -727,6 +731,7 @@ func init() { codeWarnDataTruncated: mysql.WarnDataTruncated, codeCoalesceOnlyOnHashPartition: mysql.ErrCoalesceOnlyOnHashPartition, codeUnknownPartition: mysql.ErrUnknownPartition, + codeErrWrongObject: mysql.ErrWrongObject, } terror.ErrClassToMySQLCodes[terror.ClassDDL] = ddlMySQLErrCodes } diff --git a/ddl/ddl_api.go b/ddl/ddl_api.go index 2526eb1b6ea49..ed699dca0dc93 100644 --- a/ddl/ddl_api.go +++ b/ddl/ddl_api.go @@ -22,7 +22,6 @@ import ( "fmt" "strings" "sync/atomic" - "time" "github.com/cznic/mathutil" "github.com/pingcap/errors" @@ -1495,18 +1494,9 @@ func (d *ddl) AddColumn(ctx sessionctx.Context, ti ast.Ident, spec *ast.AlterTab if err != nil { return errors.Trace(err) } - col.OriginDefaultValue = col.GetDefaultValue() - if col.OriginDefaultValue == nil && mysql.HasNotNullFlag(col.Flag) { - zeroVal := table.GetZeroValue(col.ToInfo()) - col.OriginDefaultValue, err = zeroVal.ToString() - if err != nil { - return errors.Trace(err) - } - } - - if col.OriginDefaultValue == strings.ToUpper(ast.CurrentTimestamp) && - (col.Tp == mysql.TypeTimestamp || col.Tp == mysql.TypeDatetime) { - col.OriginDefaultValue = time.Now().Format(types.TimeFormat) + col.OriginDefaultValue, err = generateOriginDefaultValue(col.ToInfo()) + if err != nil { + return errors.Trace(err) } job := &model.Job{ @@ -2218,8 +2208,8 @@ func (d *ddl) DropTable(ctx sessionctx.Context, ti ast.Ident) (err error) { } tb, err := is.TableByName(ti.Schema, ti.Name) - if err != nil { - return errors.Trace(infoschema.ErrTableNotExists.GenWithStackByArgs(ti.Schema, ti.Name)) + if err != nil || tb.Meta().IsView() { + return infoschema.ErrTableNotExists.GenWithStackByArgs(ti.Schema, ti.Name) } job := &model.Job{ @@ -2234,6 +2224,35 @@ func (d *ddl) DropTable(ctx sessionctx.Context, ti ast.Ident) (err error) { return errors.Trace(err) } +// DropView will proceed even if some view in the list does not exists. 
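Taken together, the DropTable change above (a view is treated as a nonexistent table) and the DropView method that follows give DROP TABLE and DROP VIEW symmetric, object-kind-aware behavior, backed by the new ErrTableIsNotView (ER_WRONG_OBJECT, 1347). A hedged sketch of the user-visible contract, matching what TestCreateDropView asserts later in this diff; tk is assumed to be a testkit.TestKit on schema test.

```go
tk.MustExec("create table t_v (a int)")
tk.MustExec("create view v as select 1")

// DROP VIEW on a base table is rejected with ER_WRONG_OBJECT.
_, err := tk.Exec("drop view t_v")
// expected: "[ddl:1347]'test.t_v' is not VIEW"

// DROP TABLE on a view behaves as if the table does not exist, so the plain
// form is expected to report an unknown table, while DROP TABLE IF EXISTS
// passes through the not-exist list quietly.
_, err = tk.Exec("drop table v")
_ = err
```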
+func (d *ddl) DropView(ctx sessionctx.Context, ti ast.Ident) (err error) { + is := d.GetInformationSchema(ctx) + schema, ok := is.SchemaByName(ti.Schema) + if !ok { + return infoschema.ErrDatabaseNotExists.GenWithStackByArgs(ti.Schema) + } + + tb, err := is.TableByName(ti.Schema, ti.Name) + if err != nil { + return infoschema.ErrTableNotExists.GenWithStackByArgs(ti.Schema, ti.Name) + } + + if !tb.Meta().IsView() { + return ErrTableIsNotView.GenWithStackByArgs(ti.Schema, ti.Name) + } + + job := &model.Job{ + SchemaID: schema.ID, + TableID: tb.Meta().ID, + Type: model.ActionDropView, + BinlogInfo: &model.HistoryInfo{}, + } + + err = d.doDDLJob(ctx, job) + err = d.callHookOnChanged(err) + return errors.Trace(err) +} + func (d *ddl) TruncateTable(ctx sessionctx.Context, ti ast.Ident) error { is := d.GetInformationSchema(ctx) schema, ok := is.SchemaByName(ti.Schema) @@ -2264,10 +2283,22 @@ func (d *ddl) RenameTable(ctx sessionctx.Context, oldIdent, newIdent ast.Ident, is := d.GetInformationSchema(ctx) oldSchema, ok := is.SchemaByName(oldIdent.Schema) if !ok { + if isAlterTable { + return infoschema.ErrTableNotExists.GenWithStackByArgs(oldIdent.Schema, oldIdent.Name) + } + if is.TableExists(newIdent.Schema, newIdent.Name) { + return infoschema.ErrTableExists.GenWithStackByArgs(newIdent) + } return errFileNotFound.GenWithStackByArgs(oldIdent.Schema, oldIdent.Name) } oldTbl, err := is.TableByName(oldIdent.Schema, oldIdent.Name) if err != nil { + if isAlterTable { + return infoschema.ErrTableNotExists.GenWithStackByArgs(oldIdent.Schema, oldIdent.Name) + } + if is.TableExists(newIdent.Schema, newIdent.Name) { + return infoschema.ErrTableExists.GenWithStackByArgs(newIdent) + } return errFileNotFound.GenWithStackByArgs(oldIdent.Schema, oldIdent.Name) } if isAlterTable && newIdent.Schema.L == oldIdent.Schema.L && newIdent.Name.L == oldIdent.Name.L { diff --git a/ddl/ddl_worker.go b/ddl/ddl_worker.go index fdc787a24c85b..8723804d3099a 100644 --- a/ddl/ddl_worker.go +++ b/ddl/ddl_worker.go @@ -459,8 +459,8 @@ func (w *worker) runDDLJob(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, ver, err = onCreateTable(d, t, job) case model.ActionCreateView: ver, err = onCreateView(d, t, job) - case model.ActionDropTable: - ver, err = onDropTable(t, job) + case model.ActionDropTable, model.ActionDropView: + ver, err = onDropTableOrView(t, job) case model.ActionDropTablePartition: ver, err = onDropTablePartition(t, job) case model.ActionTruncateTablePartition: diff --git a/ddl/ddl_worker_test.go b/ddl/ddl_worker_test.go index a9648bd05aa67..e375205ca4f76 100644 --- a/ddl/ddl_worker_test.go +++ b/ddl/ddl_worker_test.go @@ -375,9 +375,9 @@ func buildCancelJobTests(firstID int64) []testCancelJob { // Test create database, watch out, database id will alloc a globalID. 
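The RenameTable changes above split the error behavior by entry point: plain RENAME TABLE keeps MySQL's historical codes (ER_FILE_NOT_FOUND for an unresolvable source, ER_TABLE_EXISTS_ERROR when the target already exists), while ALTER TABLE ... RENAME reports ER_NO_SUCH_TABLE for a missing source. A hedged sketch of the expectations testRenameTable encodes; the literal statements that testRenameTable formats are not shown in this hunk, so these are their plain-SQL equivalents, and s, c and tmysql are the suite helpers already used in ddl/db_test.go.

```go
s.tk.MustExec("create table if not exists test1.t_exist (c1 int, c2 int)")

// RENAME TABLE: an existing target surfaces as ER_TABLE_EXISTS_ERROR even
// when the source is missing; otherwise a missing source is ER_FILE_NOT_FOUND.
s.testErrorCode(c, "rename table test.t_not_exist to test1.t_exist", tmysql.ErrTableExists)
s.testErrorCode(c, "rename table test_not_exist.t to test1.t_not_exist", tmysql.ErrFileNotFound)

// ALTER TABLE ... RENAME: a missing source is always ER_NO_SUCH_TABLE.
s.testErrorCode(c, "alter table test.t_not_exist rename to test1.t_exist", tmysql.ErrNoSuchTable)
```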
{act: model.ActionCreateSchema, jobIDs: []int64{firstID + 12}, cancelRetErrs: noErrs, cancelState: model.StateNone, ddlRetErr: err}, - {act: model.ActionDropColumn, jobIDs: []int64{firstID + 13}, cancelRetErrs: []error{admin.ErrCancelFinishedDDLJob.GenWithStackByArgs(firstID + 13)}, cancelState: model.StateDeleteOnly, ddlRetErr: err}, - {act: model.ActionDropColumn, jobIDs: []int64{firstID + 14}, cancelRetErrs: []error{admin.ErrCancelFinishedDDLJob.GenWithStackByArgs(firstID + 14)}, cancelState: model.StateWriteOnly, ddlRetErr: err}, - {act: model.ActionDropColumn, jobIDs: []int64{firstID + 15}, cancelRetErrs: []error{admin.ErrCancelFinishedDDLJob.GenWithStackByArgs(firstID + 15)}, cancelState: model.StateWriteReorganization, ddlRetErr: err}, + {act: model.ActionDropColumn, jobIDs: []int64{firstID + 13}, cancelRetErrs: []error{admin.ErrCannotCancelDDLJob.GenWithStackByArgs(firstID + 13)}, cancelState: model.StateDeleteOnly, ddlRetErr: err}, + {act: model.ActionDropColumn, jobIDs: []int64{firstID + 14}, cancelRetErrs: []error{admin.ErrCannotCancelDDLJob.GenWithStackByArgs(firstID + 14)}, cancelState: model.StateWriteOnly, ddlRetErr: err}, + {act: model.ActionDropColumn, jobIDs: []int64{firstID + 15}, cancelRetErrs: []error{admin.ErrCannotCancelDDLJob.GenWithStackByArgs(firstID + 15)}, cancelState: model.StateWriteReorganization, ddlRetErr: err}, } return tests @@ -427,15 +427,15 @@ func (s *testDDLSuite) TestCancelJob(c *C) { dbInfo := testSchemaInfo(c, d, "test_cancel_job") testCreateSchema(c, testNewContext(d), d, dbInfo) - // create table t (c1 int, c2 int); - tblInfo := testTableInfo(c, d, "t", 2) + // create table t (c1 int, c2 int, c3 int, c4 int, c5 int); + tblInfo := testTableInfo(c, d, "t", 5) ctx := testNewContext(d) err := ctx.NewTxn(context.Background()) c.Assert(err, IsNil) job := testCreateTable(c, ctx, d, dbInfo, tblInfo) - // insert t values (1, 2); + // insert t values (1, 2, 3, 4, 5); originTable := testGetTable(c, d, dbInfo.ID, tblInfo.ID) - row := types.MakeDatums(1, 2) + row := types.MakeDatums(1, 2, 3, 4, 5) _, err = originTable.AddRecord(ctx, row) c.Assert(err, IsNil) txn, err := ctx.Txn(true) @@ -559,21 +559,26 @@ func (s *testDDLSuite) TestCancelJob(c *C) { // for drop column. 
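The buildCancelJobTests rows above change what a cancellation attempt on an in-flight drop-column job is expected to return: once the job has left StatePublic it is past the point of no return, so admin.CancelJobs now answers ErrCannotCancelDDLJob instead of pretending the job had already finished (ErrCancelFinishedDDLJob). A hedged fragment showing how a caller distinguishes that outcome; txn and jobID are assumed to be in scope.

```go
errs, err := admin.CancelJobs(txn, []int64{jobID})
if err == nil && errs[0] != nil && admin.ErrCannotCancelDDLJob.Equal(errs[0]) {
	// Expected for drop-column jobs already in delete-only, write-only or
	// write-reorganization: the job keeps running and the column is still
	// dropped, so there is nothing to roll back.
}
```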
test = &tests[10] - dropColName := "c2" - dropColumnArgs := []interface{}{model.NewCIStr(dropColName)} - doDDLJobErrWithSchemaState(ctx, d, c, dbInfo.ID, tblInfo.ID, model.ActionDropColumn, dropColumnArgs, &cancelState) - c.Check(errors.ErrorStack(checkErr), Equals, "") + dropColName := "c3" s.checkCancelDropColumn(c, d, dbInfo.ID, tblInfo.ID, dropColName, false) + testDropColumn(c, ctx, d, dbInfo, tblInfo, dropColName, false) + c.Check(errors.ErrorStack(checkErr), Equals, "") + s.checkCancelDropColumn(c, d, dbInfo.ID, tblInfo.ID, dropColName, true) test = &tests[11] - doDDLJobErrWithSchemaState(ctx, d, c, dbInfo.ID, tblInfo.ID, model.ActionDropColumn, dropColumnArgs, &cancelState) - c.Check(errors.ErrorStack(checkErr), Equals, "") + + dropColName = "c4" s.checkCancelDropColumn(c, d, dbInfo.ID, tblInfo.ID, dropColName, false) + testDropColumn(c, ctx, d, dbInfo, tblInfo, dropColName, false) + c.Check(errors.ErrorStack(checkErr), Equals, "") + s.checkCancelDropColumn(c, d, dbInfo.ID, tblInfo.ID, dropColName, true) test = &tests[12] - doDDLJobErrWithSchemaState(ctx, d, c, dbInfo.ID, tblInfo.ID, model.ActionDropColumn, dropColumnArgs, &cancelState) - c.Check(errors.ErrorStack(checkErr), Equals, "") + dropColName = "c5" s.checkCancelDropColumn(c, d, dbInfo.ID, tblInfo.ID, dropColName, false) + testDropColumn(c, ctx, d, dbInfo, tblInfo, dropColName, false) + c.Check(errors.ErrorStack(checkErr), Equals, "") + s.checkCancelDropColumn(c, d, dbInfo.ID, tblInfo.ID, dropColName, true) } func (s *testDDLSuite) TestIgnorableSpec(c *C) { diff --git a/ddl/failtest/fail_db_test.go b/ddl/failtest/fail_db_test.go index 491704bca3eb9..adfac9551c04d 100644 --- a/ddl/failtest/fail_db_test.go +++ b/ddl/failtest/fail_db_test.go @@ -29,6 +29,7 @@ import ( "github.com/pingcap/parser/model" "github.com/pingcap/tidb/ddl" "github.com/pingcap/tidb/ddl/testutil" + ddlutil "github.com/pingcap/tidb/ddl/util" "github.com/pingcap/tidb/domain" "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/session" @@ -342,11 +343,13 @@ func (s *testFailDBSuite) TestAddIndexWorkerNum(c *C) { // Split table to multi region. 
s.cluster.SplitTable(s.mvccStore, tbl.Meta().ID, splitCount) + err = ddlutil.LoadDDLReorgVars(tk.Se) + c.Assert(err, IsNil) originDDLAddIndexWorkerCnt := variable.GetDDLReorgWorkerCounter() lastSetWorkerCnt := originDDLAddIndexWorkerCnt atomic.StoreInt32(&ddl.TestCheckWorkerNumber, lastSetWorkerCnt) ddl.TestCheckWorkerNumber = lastSetWorkerCnt - defer variable.SetDDLReorgWorkerCounter(originDDLAddIndexWorkerCnt) + defer tk.MustExec(fmt.Sprintf("set @@global.tidb_ddl_reorg_worker_cnt=%d", originDDLAddIndexWorkerCnt)) gofail.Enable("github.com/pingcap/tidb/ddl/checkIndexWorkerNum", `return(true)`) defer gofail.Disable("github.com/pingcap/tidb/ddl/checkIndexWorkerNum") @@ -364,7 +367,7 @@ LOOP: c.Assert(err, IsNil, Commentf("err:%v", errors.ErrorStack(err))) case <-ddl.TestCheckWorkerNumCh: lastSetWorkerCnt = int32(rand.Intn(8) + 8) - tk.MustExec(fmt.Sprintf("set @@tidb_ddl_reorg_worker_cnt=%d", lastSetWorkerCnt)) + tk.MustExec(fmt.Sprintf("set @@global.tidb_ddl_reorg_worker_cnt=%d", lastSetWorkerCnt)) atomic.StoreInt32(&ddl.TestCheckWorkerNumber, lastSetWorkerCnt) checkNum++ } diff --git a/ddl/index.go b/ddl/index.go index 2dafb9b66ef2a..1fa405fb1d61d 100644 --- a/ddl/index.go +++ b/ddl/index.go @@ -23,6 +23,7 @@ import ( "github.com/pingcap/parser/ast" "github.com/pingcap/parser/model" "github.com/pingcap/parser/mysql" + ddlutil "github.com/pingcap/tidb/ddl/util" "github.com/pingcap/tidb/expression" "github.com/pingcap/tidb/infoschema" "github.com/pingcap/tidb/kv" @@ -200,26 +201,11 @@ func validateRenameIndex(from, to model.CIStr, tbl *model.TableInfo) (ignore boo } func onRenameIndex(t *meta.Meta, job *model.Job) (ver int64, _ error) { - var from, to model.CIStr - if err := job.DecodeArgs(&from, &to); err != nil { - job.State = model.JobStateCancelled - return ver, errors.Trace(err) - } - tblInfo, err := getTableInfo(t, job, job.SchemaID) + tblInfo, from, to, err := checkRenameIndex(t, job) if err != nil { - job.State = model.JobStateCancelled return ver, errors.Trace(err) } - // Double check. 
See function `RenameIndex` in ddl_api.go - duplicate, err := validateRenameIndex(from, to, tblInfo) - if duplicate { - return ver, nil - } - if err != nil { - job.State = model.JobStateCancelled - return ver, errors.Trace(err) - } idx := schemautil.FindIndexByName(from.L, tblInfo.Indices) idx.Name = to if ver, err = updateVersionAndTableInfo(t, job, tblInfo, true); err != nil { @@ -360,24 +346,11 @@ func (w *worker) onCreateIndex(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int } func onDropIndex(t *meta.Meta, job *model.Job) (ver int64, _ error) { - schemaID := job.SchemaID - tblInfo, err := getTableInfo(t, job, schemaID) + tblInfo, indexInfo, err := checkDropIndex(t, job) if err != nil { return ver, errors.Trace(err) } - var indexName model.CIStr - if err = job.DecodeArgs(&indexName); err != nil { - job.State = model.JobStateCancelled - return ver, errors.Trace(err) - } - - indexInfo := schemautil.FindIndexByName(indexName.L, tblInfo.Indices) - if indexInfo == nil { - job.State = model.JobStateCancelled - return ver, ErrCantDropFieldOrKey.GenWithStack("index %s doesn't exist", indexName) - } - originalState := indexInfo.State switch indexInfo.State { case model.StatePublic: @@ -399,7 +372,7 @@ func onDropIndex(t *meta.Meta, job *model.Job) (ver int64, _ error) { // reorganization -> absent newIndices := make([]*model.IndexInfo, 0, len(tblInfo.Indices)) for _, idx := range tblInfo.Indices { - if idx.Name.L != indexName.L { + if idx.Name.L != indexInfo.Name.L { newIndices = append(newIndices, idx) } } @@ -428,6 +401,52 @@ func onDropIndex(t *meta.Meta, job *model.Job) (ver int64, _ error) { return ver, errors.Trace(err) } +func checkDropIndex(t *meta.Meta, job *model.Job) (*model.TableInfo, *model.IndexInfo, error) { + schemaID := job.SchemaID + tblInfo, err := getTableInfo(t, job, schemaID) + if err != nil { + return nil, nil, errors.Trace(err) + } + + var indexName model.CIStr + if err = job.DecodeArgs(&indexName); err != nil { + job.State = model.JobStateCancelled + return nil, nil, errors.Trace(err) + } + + indexInfo := schemautil.FindIndexByName(indexName.L, tblInfo.Indices) + if indexInfo == nil { + job.State = model.JobStateCancelled + return nil, nil, ErrCantDropFieldOrKey.GenWithStack("index %s doesn't exist", indexName) + } + return tblInfo, indexInfo, nil +} + +func checkRenameIndex(t *meta.Meta, job *model.Job) (*model.TableInfo, model.CIStr, model.CIStr, error) { + var from, to model.CIStr + schemaID := job.SchemaID + tblInfo, err := getTableInfo(t, job, schemaID) + if err != nil { + return nil, from, to, errors.Trace(err) + } + + if err := job.DecodeArgs(&from, &to); err != nil { + job.State = model.JobStateCancelled + return nil, from, to, errors.Trace(err) + } + + // Double check. See function `RenameIndex` in ddl_api.go + duplicate, err := validateRenameIndex(from, to, tblInfo) + if duplicate { + return nil, from, to, nil + } + if err != nil { + job.State = model.JobStateCancelled + return nil, from, to, errors.Trace(err) + } + return tblInfo, from, to, errors.Trace(err) +} + const ( // DefaultTaskHandleCnt is default batch size of adding indices. DefaultTaskHandleCnt = 128 @@ -730,6 +749,11 @@ func (w *addIndexWorker) batchCheckUniqueKey(txn kv.Transaction, idxRecords []*i // backfillIndexInTxn will add w.batchCnt indices once, default value of w.batchCnt is 128. // TODO: make w.batchCnt can be modified by system variable. 
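checkDropIndex and checkRenameIndex above follow the same contract as checkDropColumn earlier in this diff: the helper decodes the job arguments, validates them, and on failure sets job.State to JobStateCancelled itself, so the forward handlers (onDropIndex, onRenameIndex) and the rollback paths added in ddl/rollingback.go share one precheck and only propagate the error. A hedged fragment of that calling convention as used inside onDropIndex; the trailing blank assignment is only there to keep the fragment self-contained.

```go
tblInfo, indexInfo, err := checkDropIndex(t, job)
if err != nil {
	// The helper has already cancelled the job when the table or index was
	// missing; the caller just bubbles the error up.
	return ver, errors.Trace(err)
}
_, _ = tblInfo, indexInfo
```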
func (w *addIndexWorker) backfillIndexInTxn(handleRange reorgIndexTask) (taskCtx addIndexTaskContext, errInTxn error) { + // gofail: var errorMockPanic bool + // if errorMockPanic { + // panic("panic test") + // } + oprStartTime := time.Now() errInTxn = kv.RunInNewTxn(w.sessCtx.GetStore(), true, func(txn kv.Transaction) error { taskCtx.addedCount = 0 @@ -792,13 +816,17 @@ func (w *addIndexWorker) handleBackfillTask(d *ddlCtx, task *reorgIndexTask) *ad startTime := lastLogTime for { - taskCtx, err := w.backfillIndexInTxn(handleRange) - if err == nil { - // Because reorgIndexTask may run a long time, - // we should check whether this ddl job is still runnable. - err = w.ddlWorker.isReorgRunnable(d) + // Give job chance to be canceled, if we not check it here, + // if there is panic in w.backfillIndexInTxn we will never cancel the job. + // Because reorgIndexTask may run a long time, + // we should check whether this ddl job is still runnable. + err := w.ddlWorker.isReorgRunnable(d) + if err != nil { + result.err = err + return result } + taskCtx, err := w.backfillIndexInTxn(handleRange) if err != nil { result.err = err return result @@ -1053,6 +1081,17 @@ var ( TestCheckWorkerNumber = int32(16) ) +func loadDDLReorgVars(w *worker) error { + // Get sessionctx from context resource pool. + var ctx sessionctx.Context + ctx, err := w.sessPool.get() + if err != nil { + return errors.Trace(err) + } + defer w.sessPool.put(ctx) + return ddlutil.LoadDDLReorgVars(ctx) +} + // addPhysicalTableIndex handles the add index reorganization state for a non-partitioned table or a partition. // For a partitioned table, it should be handled partition by partition. // @@ -1093,6 +1132,9 @@ func (w *worker) addPhysicalTableIndex(t table.PhysicalTable, indexInfo *model.I } // For dynamic adjust add index worker number. + if err := loadDDLReorgVars(w); err != nil { + log.Error(err) + } workerCnt = variable.GetDDLReorgWorkerCounter() // If only have 1 range, we can only start 1 worker. if len(kvRanges) < int(workerCnt) { diff --git a/ddl/rollingback.go b/ddl/rollingback.go index 2451a39d5ab85..3af1b01e37586 100644 --- a/ddl/rollingback.go +++ b/ddl/rollingback.go @@ -17,7 +17,6 @@ import ( "github.com/pingcap/errors" "github.com/pingcap/parser/ast" "github.com/pingcap/parser/model" - "github.com/pingcap/tidb/infoschema" "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/meta" "github.com/pingcap/tidb/util/schemautil" @@ -78,34 +77,15 @@ func convertNotStartAddIdxJob2RollbackJob(t *meta.Meta, job *model.Job, occuredE func rollingbackAddColumn(t *meta.Meta, job *model.Job) (ver int64, err error) { job.State = model.JobStateRollingback - col := &model.ColumnInfo{} - pos := &ast.ColumnPosition{} - offset := 0 - err = job.DecodeArgs(col, pos, &offset) + tblInfo, columnInfo, col, _, _, err := checkAddColumn(t, job) if err != nil { - job.State = model.JobStateCancelled - return ver, errors.Trace(err) - } - - schemaID := job.SchemaID - tblInfo, err := getTableInfo(t, job, schemaID) - if err != nil { - job.State = model.JobStateCancelled return ver, errors.Trace(err) } - - columnInfo := model.FindColumnInfo(tblInfo.Columns, col.Name.L) if columnInfo == nil { job.State = model.JobStateCancelled return ver, errCancelledDDLJob } - if columnInfo.State == model.StatePublic { - // We already have a column with the same column name. 
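Two related changes sit in the ddl/index.go hunks above. First, backfillIndexInTxn gains a gofail failpoint (errorMockPanic) written as comments that the gofail tool rewrites into real code when failpoints are enabled. Second, handleBackfillTask now calls isReorgRunnable before backfillIndexInTxn, so a cancelled job is noticed even if the backfill itself panics. The new TestCancelAddIndexPanic exercises both; a hedged reduction of how a test drives the failpoint is shown here.

```go
// gofail is imported as "github.com/pingcap/gofail/runtime". Turn the
// failpoint on so backfillIndexInTxn panics, run the ALTER statement, verify
// the job can still be cancelled, and always turn the failpoint off again.
gofail.Enable("github.com/pingcap/tidb/ddl/errorMockPanic", `return(true)`)
defer gofail.Disable("github.com/pingcap/tidb/ddl/errorMockPanic")
```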
- job.State = model.JobStateCancelled - return ver, infoschema.ErrColumnExists.GenWithStackByArgs(col.Name) - } - originalState := columnInfo.State columnInfo.State = model.StateDeleteOnly job.SchemaState = model.StateDeleteOnly @@ -119,57 +99,29 @@ func rollingbackAddColumn(t *meta.Meta, job *model.Job) (ver int64, err error) { } func rollingbackDropColumn(t *meta.Meta, job *model.Job) (ver int64, err error) { - tblInfo, err := getTableInfo(t, job, job.SchemaID) + tblInfo, colInfo, err := checkDropColumn(t, job) if err != nil { return ver, errors.Trace(err) } - var colName model.CIStr - err = job.DecodeArgs(&colName) - if err != nil { - job.State = model.JobStateCancelled - return ver, errors.Trace(err) - } - - colInfo := model.FindColumnInfo(tblInfo.Columns, colName.L) - if colInfo == nil { - job.State = model.JobStateCancelled - return ver, ErrCantDropFieldOrKey.GenWithStack("column %s doesn't exist", colName) - } - // StatePublic means when the job is not running yet. if colInfo.State == model.StatePublic { job.State = model.JobStateCancelled - } else { - // In the state of drop column `write only -> delete only -> reorganization`, - // We can not rollback now, so just continue to drop column. - job.State = model.JobStateRunning - return ver, errors.Trace(nil) + job.FinishTableJob(model.JobStateRollbackDone, model.StatePublic, ver, tblInfo) + return ver, errCancelledDDLJob } - job.FinishTableJob(model.JobStateRollbackDone, model.StatePublic, ver, tblInfo) - return ver, errors.Trace(errCancelledDDLJob) + // In the state of drop column `write only -> delete only -> reorganization`, + // We can not rollback now, so just continue to drop column. + job.State = model.JobStateRunning + return ver, nil } func rollingbackDropIndex(t *meta.Meta, job *model.Job) (ver int64, err error) { - schemaID := job.SchemaID - tblInfo, err := getTableInfo(t, job, schemaID) + tblInfo, indexInfo, err := checkDropIndex(t, job) if err != nil { return ver, errors.Trace(err) } - var indexName model.CIStr - err = job.DecodeArgs(&indexName) - if err != nil { - job.State = model.JobStateCancelled - return ver, errors.Trace(err) - } - - indexInfo := schemautil.FindIndexByName(indexName.L, tblInfo.Indices) - if indexInfo == nil { - job.State = model.JobStateCancelled - return ver, ErrCantDropFieldOrKey.GenWithStack("index %s doesn't exist", indexName) - } - originalState := indexInfo.State switch indexInfo.State { case model.StateDeleteOnly, model.StateDeleteReorganization, model.StateNone: @@ -208,6 +160,51 @@ func rollingbackAddindex(w *worker, d *ddlCtx, t *meta.Meta, job *model.Job) (ve return } +func rollingbackDropTableOrView(t *meta.Meta, job *model.Job) error { + tblInfo, err := checkTableExist(t, job, job.SchemaID) + if err != nil { + return errors.Trace(err) + } + // To simplify the rollback logic, cannot be canceled after job start to run. + // Normally won't fetch here, because there is check when cancel ddl jobs. see function: isJobRollbackable. + if tblInfo.State == model.StatePublic { + job.State = model.JobStateCancelled + return errCancelledDDLJob + } + job.State = model.JobStateRunning + return nil +} + +func rollingbackDropSchema(t *meta.Meta, job *model.Job) error { + dbInfo, err := checkDropSchema(t, job) + if err != nil { + return errors.Trace(err) + } + // To simplify the rollback logic, cannot be canceled after job start to run. + // Normally won't fetch here, because there is check when cancel ddl jobs. see function: isJobRollbackable. 
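The rollingback* helpers in this hunk (drop column, drop index, drop table/view, drop schema, and rename index just below) all reduce to one rule, which convertJob2RollbackJob then dispatches on: if the affected object is still StatePublic the job has made no user-visible progress and is cancelled; otherwise it is flipped back to JobStateRunning and allowed to finish, because a partially executed drop cannot be undone safely. A hedged distillation of that shared shape; rollbackIfNotStarted is an illustrative name inside package ddl, not a function the patch adds.

```go
func rollbackIfNotStarted(job *model.Job, state model.SchemaState) error {
	if state == model.StatePublic {
		// Nothing visible has changed yet; cancel the job outright.
		job.State = model.JobStateCancelled
		return errCancelledDDLJob
	}
	// Too late to roll back; let the original job run to completion.
	job.State = model.JobStateRunning
	return nil
}
```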
+ if dbInfo.State == model.StatePublic { + job.State = model.JobStateCancelled + return errCancelledDDLJob + } + job.State = model.JobStateRunning + return nil +} + +func rollingbackRenameIndex(t *meta.Meta, job *model.Job) (ver int64, err error) { + tblInfo, from, _, err := checkRenameIndex(t, job) + if err != nil { + return ver, errors.Trace(err) + } + // Here rename index is done in a transaction, if the job is not completed, it can be canceled. + idx := schemautil.FindIndexByName(from.L, tblInfo.Indices) + if idx.State == model.StatePublic { + job.State = model.JobStateCancelled + return ver, errCancelledDDLJob + } + job.State = model.JobStateRunning + return ver, errors.Trace(err) +} + func convertJob2RollbackJob(w *worker, d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, err error) { switch job.Type { case model.ActionAddColumn: @@ -218,8 +215,12 @@ func convertJob2RollbackJob(w *worker, d *ddlCtx, t *meta.Meta, job *model.Job) ver, err = rollingbackDropColumn(t, job) case model.ActionDropIndex: ver, err = rollingbackDropIndex(t, job) - case model.ActionDropTable, model.ActionDropSchema: - job.State = model.JobStateRollingback + case model.ActionDropTable, model.ActionDropView: + err = rollingbackDropTableOrView(t, job) + case model.ActionDropSchema: + err = rollingbackDropSchema(t, job) + case model.ActionRenameIndex: + ver, err = rollingbackRenameIndex(t, job) default: job.State = model.JobStateCancelled err = errCancelledDDLJob diff --git a/ddl/schema.go b/ddl/schema.go index 8c506b67e242b..3660c7268c14f 100644 --- a/ddl/schema.go +++ b/ddl/schema.go @@ -71,24 +71,10 @@ func onCreateSchema(t *meta.Meta, job *model.Job) (ver int64, _ error) { } func onDropSchema(t *meta.Meta, job *model.Job) (ver int64, _ error) { - dbInfo, err := t.GetDatabase(job.SchemaID) + dbInfo, err := checkDropSchema(t, job) if err != nil { return ver, errors.Trace(err) } - if dbInfo == nil { - job.State = model.JobStateCancelled - return ver, infoschema.ErrDatabaseDropExists.GenWithStackByArgs("") - } - - if job.IsRollingback() { - // To simplify the rollback logic, cannot be canceled after job start to run. - // Normally won't fetch here, because there is check when cancel ddl jobs. see function: isJobRollbackable. - if dbInfo.State == model.StatePublic { - job.State = model.JobStateCancelled - return ver, errCancelledDDLJob - } - job.State = model.JobStateRunning - } ver, err = updateSchemaVersion(t, job) if err != nil { @@ -134,6 +120,18 @@ func onDropSchema(t *meta.Meta, job *model.Job) (ver int64, _ error) { return ver, errors.Trace(err) } +func checkDropSchema(t *meta.Meta, job *model.Job) (*model.DBInfo, error) { + dbInfo, err := t.GetDatabase(job.SchemaID) + if err != nil { + return nil, errors.Trace(err) + } + if dbInfo == nil { + job.State = model.JobStateCancelled + return nil, infoschema.ErrDatabaseDropExists.GenWithStackByArgs("") + } + return dbInfo, nil +} + func getIDs(tables []*model.TableInfo) []int64 { ids := make([]int64, 0, len(tables)) for _, t := range tables { diff --git a/ddl/serial_test.go b/ddl/serial_test.go new file mode 100644 index 0000000000000..d846cdafe7df1 --- /dev/null +++ b/ddl/serial_test.go @@ -0,0 +1,125 @@ +// Copyright 2019 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package ddl_test + +import ( + "context" + "time" + + . "github.com/pingcap/check" + "github.com/pingcap/errors" + gofail "github.com/pingcap/gofail/runtime" + "github.com/pingcap/parser/model" + "github.com/pingcap/tidb/ddl" + "github.com/pingcap/tidb/domain" + "github.com/pingcap/tidb/kv" + "github.com/pingcap/tidb/session" + "github.com/pingcap/tidb/store/mockstore" + "github.com/pingcap/tidb/util/admin" + "github.com/pingcap/tidb/util/mock" + "github.com/pingcap/tidb/util/testkit" + "github.com/pingcap/tidb/util/testleak" +) + +var _ = SerialSuites(&testSerialSuite{}) + +type testSerialSuite struct { + store kv.Storage + dom *domain.Domain +} + +func (s *testSerialSuite) SetUpSuite(c *C) { + testleak.BeforeTest() + session.SetSchemaLease(200 * time.Millisecond) + session.SetStatsLease(0) + + ddl.WaitTimeWhenErrorOccured = 1 * time.Microsecond + var err error + s.store, err = mockstore.NewMockTikvStore() + c.Assert(err, IsNil) + + s.dom, err = session.BootstrapSession(s.store) + c.Assert(err, IsNil) +} + +func (s *testSerialSuite) TearDownSuite(c *C) { + if s.dom != nil { + s.dom.Close() + } + if s.store != nil { + s.store.Close() + } + testleak.AfterTest(c)() +} + +// TestCancelAddIndex1 tests canceling ddl job when the add index worker is not started. +func (s *testSerialSuite) TestCancelAddIndexPanic(c *C) { + gofail.Enable("github.com/pingcap/tidb/ddl/errorMockPanic", `return(true)`) + defer gofail.Disable("github.com/pingcap/tidb/ddl/errorMockPanic") + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t") + tk.MustExec("create table t(c1 int, c2 int)") + defer tk.MustExec("drop table t;") + for i := 0; i < 5; i++ { + tk.MustExec("insert into t values (?, ?)", i, i) + } + var checkErr error + oldReorgWaitTimeout := ddl.ReorgWaitTimeout + ddl.ReorgWaitTimeout = 50 * time.Millisecond + defer func() { ddl.ReorgWaitTimeout = oldReorgWaitTimeout }() + hook := &ddl.TestDDLCallback{} + hook.OnJobRunBeforeExported = func(job *model.Job) { + if job.Type == model.ActionAddIndex && job.State == model.JobStateRunning && job.SchemaState == model.StateWriteReorganization && job.SnapshotVer != 0 { + jobIDs := []int64{job.ID} + hookCtx := mock.NewContext() + hookCtx.Store = s.store + err := hookCtx.NewTxn(context.Background()) + if err != nil { + checkErr = errors.Trace(err) + return + } + txn, err := hookCtx.Txn(true) + if err != nil { + checkErr = errors.Trace(err) + return + } + errs, err := admin.CancelJobs(txn, jobIDs) + if err != nil { + checkErr = errors.Trace(err) + return + } + if errs[0] != nil { + checkErr = errors.Trace(errs[0]) + return + } + txn, err = hookCtx.Txn(true) + if err != nil { + checkErr = errors.Trace(err) + return + } + checkErr = txn.Commit(context.Background()) + } + } + origHook := s.dom.DDL().GetHook() + defer s.dom.DDL().(ddl.DDLForTest).SetHook(origHook) + s.dom.DDL().(ddl.DDLForTest).SetHook(hook) + rs, err := tk.Exec("alter table t add index idx_c2(c2)") + if rs != nil { + rs.Close() + } + c.Assert(checkErr, IsNil) + c.Assert(err, NotNil) + c.Assert(err.Error(), Equals, "[ddl:12]cancelled DDL job") +} diff --git a/ddl/table.go b/ddl/table.go index 
d6ec9370291f1..e53256153c594 100644 --- a/ddl/table.go +++ b/ddl/table.go @@ -109,40 +109,12 @@ func onCreateView(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, _ error) } } -func onDropTable(t *meta.Meta, job *model.Job) (ver int64, _ error) { - schemaID := job.SchemaID - tableID := job.TableID - // Check this table's database. - tblInfo, err := t.GetTable(schemaID, tableID) +func onDropTableOrView(t *meta.Meta, job *model.Job) (ver int64, _ error) { + tblInfo, err := checkTableExist(t, job, job.SchemaID) if err != nil { - if meta.ErrDBNotExists.Equal(err) { - job.State = model.JobStateCancelled - return ver, errors.Trace(infoschema.ErrDatabaseNotExists.GenWithStackByArgs( - fmt.Sprintf("(Schema ID %d)", schemaID), - )) - } return ver, errors.Trace(err) } - // Check the table. - if tblInfo == nil { - job.State = model.JobStateCancelled - return ver, errors.Trace(infoschema.ErrTableNotExists.GenWithStackByArgs( - fmt.Sprintf("(Schema ID %d)", schemaID), - fmt.Sprintf("(Table ID %d)", tableID), - )) - } - - if job.IsRollingback() { - // To simplify the rollback logic, cannot be canceled after job start to run. - // Normally won't fetch here, because there is check when cancel ddl jobs. see function: isJobRollbackable. - if tblInfo.State == model.StatePublic { - job.State = model.JobStateCancelled - return ver, errCancelledDDLJob - } - job.State = model.JobStateRunning - } - originalState := job.SchemaState switch tblInfo.State { case model.StatePublic: @@ -161,12 +133,12 @@ func onDropTable(t *meta.Meta, job *model.Job) (ver int64, _ error) { if err != nil { return ver, errors.Trace(err) } - if err = t.DropTable(job.SchemaID, tableID, true); err != nil { + if err = t.DropTable(job.SchemaID, job.TableID, true); err != nil { break } // Finish this job. job.FinishTableJob(model.JobStateDone, model.StateNone, ver, tblInfo) - startKey := tablecodec.EncodeTablePrefix(tableID) + startKey := tablecodec.EncodeTablePrefix(job.TableID) job.Args = append(job.Args, startKey, getPartitionIDs(tblInfo)) default: err = ErrInvalidTableState.GenWithStack("invalid table state %v", tblInfo.State) @@ -198,7 +170,22 @@ func getTable(store kv.Storage, schemaID int64, tblInfo *model.TableInfo) (table } func getTableInfo(t *meta.Meta, job *model.Job, schemaID int64) (*model.TableInfo, error) { + tblInfo, err := checkTableExist(t, job, schemaID) + if err != nil { + return nil, errors.Trace(err) + } + + if tblInfo.State != model.StatePublic { + job.State = model.JobStateCancelled + return nil, ErrInvalidTableState.GenWithStack("table %s is not in public, but %s", tblInfo.Name, tblInfo.State) + } + + return tblInfo, nil +} + +func checkTableExist(t *meta.Meta, job *model.Job, schemaID int64) (*model.TableInfo, error) { tableID := job.TableID + // Check this table's database. tblInfo, err := t.GetTable(schemaID, tableID) if err != nil { if meta.ErrDBNotExists.Equal(err) { @@ -208,19 +195,16 @@ func getTableInfo(t *meta.Meta, job *model.Job, schemaID int64) (*model.TableInf )) } return nil, errors.Trace(err) - } else if tblInfo == nil { + } + + // Check the table. 
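The ddl/table.go refactor above splits the old getTableInfo into two layers: checkTableExist only verifies that the schema and table exist, while getTableInfo keeps the additional StatePublic requirement. onDropTableOrView and the new rollback helpers use the weaker check, since a table being dropped is legitimately non-public; every other job type is unaffected. A hedged fragment of the resulting call sites, with ver, t and job as in the surrounding handlers:

```go
// Drop table/view (and its rollback) only needs existence.
tblInfo, err := checkTableExist(t, job, job.SchemaID)
if err != nil {
	return ver, errors.Trace(err)
}
// Other handlers keep the stricter entry point, which now wraps the same
// existence check and then rejects non-public tables:
//   tblInfo, err := getTableInfo(t, job, job.SchemaID)
_ = tblInfo
```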
+ if tblInfo == nil { job.State = model.JobStateCancelled return nil, errors.Trace(infoschema.ErrTableNotExists.GenWithStackByArgs( fmt.Sprintf("(Schema ID %d)", schemaID), fmt.Sprintf("(Table ID %d)", tableID), )) } - - if tblInfo.State != model.StatePublic { - job.State = model.JobStateCancelled - return nil, ErrInvalidTableState.GenWithStack("table %s is not in public, but %s", tblInfo.Name, tblInfo.State) - } - return tblInfo, nil } diff --git a/ddl/util/util.go b/ddl/util/util.go index 9635c66380db2..da6a7f38b3b31 100644 --- a/ddl/util/util.go +++ b/ddl/util/util.go @@ -22,6 +22,7 @@ import ( "github.com/pingcap/parser/terror" "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/sessionctx" + "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/util/chunk" "github.com/pingcap/tidb/util/sqlexec" ) @@ -128,3 +129,23 @@ func UpdateDeleteRange(ctx sessionctx.Context, dr DelRangeTask, newStartKey, old _, err := ctx.(sqlexec.SQLExecutor).Execute(context.TODO(), sql) return errors.Trace(err) } + +const loadDDLReorgVarsSQL = "select HIGH_PRIORITY variable_name, variable_value from mysql.global_variables where variable_name in ('" + + variable.TiDBDDLReorgWorkerCount + "', '" + + variable.TiDBDDLReorgBatchSize + "')" + +// LoadDDLReorgVars loads ddl reorg variable from mysql.global_variables. +func LoadDDLReorgVars(ctx sessionctx.Context) error { + if sctx, ok := ctx.(sqlexec.RestrictedSQLExecutor); ok { + rows, _, err := sctx.ExecRestrictedSQL(ctx, loadDDLReorgVarsSQL) + if err != nil { + return errors.Trace(err) + } + for _, row := range rows { + varName := row.GetString(0) + varValue := row.GetString(1) + variable.SetLocalSystemVar(varName, varValue) + } + } + return nil +} diff --git a/domain/domain.go b/domain/domain.go index c122e17ffb61d..c2d91bfe132cd 100644 --- a/domain/domain.go +++ b/domain/domain.go @@ -33,11 +33,11 @@ import ( "github.com/pingcap/tidb/config" "github.com/pingcap/tidb/ddl" "github.com/pingcap/tidb/infoschema" + "github.com/pingcap/tidb/infoschema/perfschema" "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/meta" "github.com/pingcap/tidb/metrics" "github.com/pingcap/tidb/owner" - "github.com/pingcap/tidb/perfschema" "github.com/pingcap/tidb/privilege/privileges" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/sessionctx/variable" diff --git a/executor/adapter.go b/executor/adapter.go index f176181b7f664..4fd004260a43d 100644 --- a/executor/adapter.go +++ b/executor/adapter.go @@ -396,12 +396,12 @@ func (a *ExecStmt) LogSlowQuery(txnTS uint64, succ bool) { execDetail := sessVars.StmtCtx.GetExecDetails() if costTime < threshold { logutil.SlowQueryLogger.Debugf( - "[QUERY] %vcost_time:%v %s succ:%v con:%v user:%s txn_start_ts:%v database:%v %v%vsql:%v", - internal, costTime, execDetail, succ, connID, user, txnTS, currentDB, tableIDs, indexIDs, sql) + "[QUERY] %vcost_time:%vs %s succ:%v con:%v user:%s txn_start_ts:%v database:%v %v%vsql:%v", + internal, costTime.Seconds(), execDetail, succ, connID, user, txnTS, currentDB, tableIDs, indexIDs, sql) } else { logutil.SlowQueryLogger.Warnf( - "[SLOW_QUERY] %vcost_time:%v %s succ:%v con:%v user:%s txn_start_ts:%v database:%v %v%vsql:%v", - internal, costTime, execDetail, succ, connID, user, txnTS, currentDB, tableIDs, indexIDs, sql) + "[SLOW_QUERY] %vcost_time:%vs %s succ:%v con:%v user:%s txn_start_ts:%v database:%v %v%vsql:%v", + internal, costTime.Seconds(), execDetail, succ, connID, user, txnTS, currentDB, tableIDs, indexIDs, sql) 
metrics.TotalQueryProcHistogram.Observe(costTime.Seconds()) metrics.TotalCopProcHistogram.Observe(execDetail.ProcessTime.Seconds()) metrics.TotalCopWaitHistogram.Observe(execDetail.WaitTime.Seconds()) diff --git a/executor/aggregate_test.go b/executor/aggregate_test.go index 8db9d8cdfcef2..413d015e0c5c2 100644 --- a/executor/aggregate_test.go +++ b/executor/aggregate_test.go @@ -271,9 +271,9 @@ func (s *testSuite1) TestAggregation(c *C) { tk.MustQuery("select 10 from idx_agg group by b").Check(testkit.Rows("10", "10")) tk.MustQuery("select 11 from idx_agg group by a").Check(testkit.Rows("11", "11")) - tk.MustExec("set @@tidb_max_chunk_size=1;") + tk.MustExec("set @@tidb_init_chunk_size=1;") tk.MustQuery("select group_concat(b) from idx_agg group by b;").Check(testkit.Rows("1", "2,2")) - tk.MustExec("set @@tidb_max_chunk_size=2;") + tk.MustExec("set @@tidb_init_chunk_size=2;") tk.MustExec("drop table if exists t") tk.MustExec("create table t(a int(11), b decimal(15,2))") diff --git a/executor/builder.go b/executor/builder.go index 659432fba12e3..c3e6d07a94f6a 100644 --- a/executor/builder.go +++ b/executor/builder.go @@ -16,7 +16,6 @@ package executor import ( "bytes" "context" - "fmt" "math" "sort" "strings" @@ -503,6 +502,7 @@ func (b *executorBuilder) buildShow(v *plannercore.Show) Executor { Table: v.Table, Column: v.Column, User: v.User, + IfNotExists: v.IfNotExists, Flag: v.Flag, Full: v.Full, GlobalScope: v.GlobalScope, @@ -916,6 +916,9 @@ func (b *executorBuilder) buildHashJoin(v *plannercore.PhysicalHashJoin) Executo v.OtherConditions, lhsTypes, rhsTypes) } metrics.ExecutorCounter.WithLabelValues("HashJoinExec").Inc() + if e.ctx.GetSessionVars().EnableRadixJoin { + return &RadixHashJoinExec{HashJoinExec: e} + } return e } @@ -967,83 +970,12 @@ func (b *executorBuilder) wrapCastForAggArgs(funcs []*aggregation.AggFuncDesc) { } } -// buildProjBelowAgg builds a ProjectionExec below AggregationExec. -// If all the args of `aggFuncs`, and all the item of `groupByItems` -// are columns or constants, we do not need to build the `proj`. -func (b *executorBuilder) buildProjBelowAgg(aggFuncs []*aggregation.AggFuncDesc, groupByItems []expression.Expression, src Executor) Executor { - hasScalarFunc := false - // If the mode is FinalMode, we do not need to wrap cast upon the args, - // since the types of the args are already the expected. 
- if len(aggFuncs) > 0 && aggFuncs[0].Mode != aggregation.FinalMode { - b.wrapCastForAggArgs(aggFuncs) - } - for i := 0; !hasScalarFunc && i < len(aggFuncs); i++ { - f := aggFuncs[i] - for _, arg := range f.Args { - _, isScalarFunc := arg.(*expression.ScalarFunction) - hasScalarFunc = hasScalarFunc || isScalarFunc - } - } - for i, isScalarFunc := 0, false; !hasScalarFunc && i < len(groupByItems); i++ { - _, isScalarFunc = groupByItems[i].(*expression.ScalarFunction) - hasScalarFunc = hasScalarFunc || isScalarFunc - } - if !hasScalarFunc { - return src - } - - b.ctx.GetSessionVars().PlanID++ - id := b.ctx.GetSessionVars().PlanID - projFromID := fmt.Sprintf("%s_%d", plannercore.TypeProj, id) - - projSchemaCols := make([]*expression.Column, 0, len(aggFuncs)+len(groupByItems)) - projExprs := make([]expression.Expression, 0, cap(projSchemaCols)) - cursor := 0 - for _, f := range aggFuncs { - for i, arg := range f.Args { - if _, isCnst := arg.(*expression.Constant); isCnst { - continue - } - projExprs = append(projExprs, arg) - newArg := &expression.Column{ - RetType: arg.GetType(), - ColName: model.NewCIStr(fmt.Sprintf("%s_%d", f.Name, i)), - Index: cursor, - } - projSchemaCols = append(projSchemaCols, newArg) - f.Args[i] = newArg - cursor++ - } - } - for i, item := range groupByItems { - if _, isCnst := item.(*expression.Constant); isCnst { - continue - } - projExprs = append(projExprs, item) - newArg := &expression.Column{ - RetType: item.GetType(), - ColName: model.NewCIStr(fmt.Sprintf("group_%d", i)), - Index: cursor, - } - projSchemaCols = append(projSchemaCols, newArg) - groupByItems[i] = newArg - cursor++ - } - - return &ProjectionExec{ - baseExecutor: newBaseExecutor(b.ctx, expression.NewSchema(projSchemaCols...), projFromID, src), - numWorkers: b.ctx.GetSessionVars().ProjectionConcurrency, - evaluatorSuit: expression.NewEvaluatorSuite(projExprs, false), - } -} - func (b *executorBuilder) buildHashAgg(v *plannercore.PhysicalHashAgg) Executor { src := b.build(v.Children()[0]) if b.err != nil { b.err = errors.Trace(b.err) return nil } - src = b.buildProjBelowAgg(v.AggFuncs, v.GroupByItems, src) sessionVars := b.ctx.GetSessionVars() e := &HashAggExec{ baseExecutor: newBaseExecutor(b.ctx, v.Schema(), v.ExplainID(), src), @@ -1125,7 +1057,6 @@ func (b *executorBuilder) buildStreamAgg(v *plannercore.PhysicalStreamAgg) Execu b.err = errors.Trace(b.err) return nil } - src = b.buildProjBelowAgg(v.AggFuncs, v.GroupByItems, src) e := &StreamAggExec{ baseExecutor: newBaseExecutor(b.ctx, v.Schema(), v.ExplainID(), src), StmtCtx: b.ctx.GetSessionVars().StmtCtx, diff --git a/executor/ddl.go b/executor/ddl.go index 802b0fc99d527..05c4f18eba002 100644 --- a/executor/ddl.go +++ b/executor/ddl.go @@ -93,7 +93,7 @@ func (e *DDLExec) Next(ctx context.Context, chk *chunk.Chunk) (err error) { case *ast.DropDatabaseStmt: err = e.executeDropDatabase(x) case *ast.DropTableStmt: - err = e.executeDropTable(x) + err = e.executeDropTableOrView(x) case *ast.DropIndexStmt: err = e.executeDropIndex(x) case *ast.AlterTableStmt: @@ -222,7 +222,7 @@ func isSystemTable(schema, table string) bool { return false } -func (e *DDLExec) executeDropTable(s *ast.DropTableStmt) error { +func (e *DDLExec) executeDropTableOrView(s *ast.DropTableStmt) error { var notExistTables []string for _, tn := range s.Tables { fullti := ast.Ident{Schema: tn.Schema, Name: tn.Name} @@ -256,7 +256,11 @@ func (e *DDLExec) executeDropTable(s *ast.DropTableStmt) error { } } - err = domain.GetDomain(e.ctx).DDL().DropTable(e.ctx, fullti) + if s.IsView { 
+ err = domain.GetDomain(e.ctx).DDL().DropView(e.ctx, fullti) + } else { + err = domain.GetDomain(e.ctx).DDL().DropTable(e.ctx, fullti) + } if infoschema.ErrDatabaseNotExists.Equal(err) || infoschema.ErrTableNotExists.Equal(err) { notExistTables = append(notExistTables, fullti.String()) } else if err != nil { diff --git a/executor/ddl_test.go b/executor/ddl_test.go index def0f570e27da..9cada7e004648 100644 --- a/executor/ddl_test.go +++ b/executor/ddl_test.go @@ -25,7 +25,9 @@ import ( "github.com/pingcap/parser/mysql" "github.com/pingcap/parser/terror" "github.com/pingcap/tidb/ddl" + ddlutil "github.com/pingcap/tidb/ddl/util" "github.com/pingcap/tidb/domain" + "github.com/pingcap/tidb/meta/autoid" plannercore "github.com/pingcap/tidb/planner/core" "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/table" @@ -142,6 +144,7 @@ func (s *testSuite3) TestCreateView(c *C) { tk.MustExec("CREATE TABLE source_table (id INT NOT NULL DEFAULT 1, name varchar(255), PRIMARY KEY(id));") //test create a exist view tk.MustExec("CREATE VIEW view_t AS select id , name from source_table") + defer tk.MustExec("DROP VIEW IF EXISTS view_t") _, err := tk.Exec("CREATE VIEW view_t AS select id , name from source_table") c.Assert(err.Error(), Equals, "[schema:1050]Table 'test.view_t' already exists") //create view on nonexistent table @@ -151,31 +154,25 @@ func (s *testSuite3) TestCreateView(c *C) { tk.MustExec("create table t1 (a int ,b int)") tk.MustExec("insert into t1 values (1,2), (1,3), (2,4), (2,5), (3,10)") //view with colList and SelectFieldExpr - _, err = tk.Exec("create view v1 (c) as select b+1 from t1") - c.Assert(err, IsNil) + tk.MustExec("create view v1 (c) as select b+1 from t1") //view with SelectFieldExpr - _, err = tk.Exec("create view v2 as select b+1 from t1") - c.Assert(err, IsNil) + tk.MustExec("create view v2 as select b+1 from t1") //view with SelectFieldExpr and AsName - _, err = tk.Exec("create view v3 as select b+1 as c from t1") - c.Assert(err, IsNil) + tk.MustExec("create view v3 as select b+1 as c from t1") //view with colList , SelectField and AsName - _, err = tk.Exec("create view v4 (c) as select b+1 as d from t1") - c.Assert(err, IsNil) + tk.MustExec("create view v4 (c) as select b+1 as d from t1") //view with select wild card - _, err = tk.Exec("create view v5 as select * from t1") - c.Assert(err, IsNil) - _, err = tk.Exec("create view v6 (c,d) as select * from t1") - c.Assert(err, IsNil) + tk.MustExec("create view v5 as select * from t1") + tk.MustExec("create view v6 (c,d) as select * from t1") _, err = tk.Exec("create view v7 (c,d,e) as select * from t1") c.Assert(err.Error(), Equals, ddl.ErrViewWrongList.Error()) - tk.MustExec("drop table v1,v2,v3,v4,v5,v6") + //drop multiple views in a statement + tk.MustExec("drop view v1,v2,v3,v4,v5,v6") //view with variable - _, err = tk.Exec("create view v1 (c,d) as select a,b+@@global.max_user_connections from t1") - c.Assert(err, IsNil) + tk.MustExec("create view v1 (c,d) as select a,b+@@global.max_user_connections from t1") _, err = tk.Exec("create view v1 (c,d) as select a,b from t1 where a = @@global.max_user_connections") c.Assert(err.Error(), Equals, "[schema:1050]Table 'test.v1' already exists") - tk.MustExec("drop table v1") + tk.MustExec("drop view v1") //view with different col counts _, err = tk.Exec("create view v1 (c,d,e) as select a,b from t1 ") c.Assert(err.Error(), Equals, ddl.ErrViewWrongList.Error()) @@ -211,6 +208,24 @@ func (s *testSuite3) TestCreateDropTable(c *C) { c.Assert(err, NotNil) } +func 
(s *testSuite3) TestCreateDropView(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec("create or replace view drop_test as select 1,2") + _, err := tk.Exec("drop view if exists drop_test") + c.Assert(err, IsNil) + + _, err = tk.Exec("drop view mysql.gc_delete_range") + c.Assert(err.Error(), Equals, "Drop tidb system table 'mysql.gc_delete_range' is forbidden") + + _, err = tk.Exec("drop view drop_test") + c.Assert(err.Error(), Equals, "[schema:1051]Unknown table 'test.drop_test'") + + tk.MustExec("create table t_v(a int)") + _, err = tk.Exec("drop view t_v") + c.Assert(err.Error(), Equals, "[ddl:1347]'test.t_v' is not VIEW") +} + func (s *testSuite3) TestCreateDropIndex(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") @@ -456,6 +471,24 @@ func (s *testSuite3) TestShardRowIDBits(c *C) { _, err = tk.Exec("alter table auto shard_row_id_bits = 4") c.Assert(err, NotNil) tk.MustExec("alter table auto shard_row_id_bits = 0") + + // Test overflow + tk.MustExec("drop table if exists t1") + tk.MustExec("create table t1 (a int) shard_row_id_bits = 15") + defer tk.MustExec("drop table if exists t1") + + tbl, err = domain.GetDomain(tk.Se).InfoSchema().TableByName(model.NewCIStr("test"), model.NewCIStr("t1")) + c.Assert(err, IsNil) + maxID := 1<<(64-15-1) - 1 + err = tbl.RebaseAutoID(tk.Se, int64(maxID)-1, false) + c.Assert(err, IsNil) + tk.MustExec("insert into t1 values(1)") + + // continue inserting will fail. + _, err = tk.Exec("insert into t1 values(2)") + c.Assert(autoid.ErrAutoincReadFailed.Equal(err), IsTrue, Commentf("err:%v", err)) + _, err = tk.Exec("insert into t1 values(3)") + c.Assert(autoid.ErrAutoincReadFailed.Equal(err), IsTrue, Commentf("err:%v", err)) } func (s *testSuite3) TestMaxHandleAddIndex(c *C) { @@ -478,24 +511,32 @@ func (s *testSuite3) TestMaxHandleAddIndex(c *C) { func (s *testSuite3) TestSetDDLReorgWorkerCnt(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") + err := ddlutil.LoadDDLReorgVars(tk.Se) + c.Assert(err, IsNil) c.Assert(variable.GetDDLReorgWorkerCounter(), Equals, int32(variable.DefTiDBDDLReorgWorkerCount)) - tk.MustExec("set tidb_ddl_reorg_worker_cnt = 1") + tk.MustExec("set @@global.tidb_ddl_reorg_worker_cnt = 1") + err = ddlutil.LoadDDLReorgVars(tk.Se) + c.Assert(err, IsNil) c.Assert(variable.GetDDLReorgWorkerCounter(), Equals, int32(1)) - tk.MustExec("set tidb_ddl_reorg_worker_cnt = 100") + tk.MustExec("set @@global.tidb_ddl_reorg_worker_cnt = 100") + err = ddlutil.LoadDDLReorgVars(tk.Se) + c.Assert(err, IsNil) c.Assert(variable.GetDDLReorgWorkerCounter(), Equals, int32(100)) - _, err := tk.Exec("set tidb_ddl_reorg_worker_cnt = invalid_val") + _, err = tk.Exec("set @@global.tidb_ddl_reorg_worker_cnt = invalid_val") c.Assert(terror.ErrorEqual(err, variable.ErrWrongTypeForVar), IsTrue, Commentf("err %v", err)) - tk.MustExec("set tidb_ddl_reorg_worker_cnt = 100") + tk.MustExec("set @@global.tidb_ddl_reorg_worker_cnt = 100") + err = ddlutil.LoadDDLReorgVars(tk.Se) + c.Assert(err, IsNil) c.Assert(variable.GetDDLReorgWorkerCounter(), Equals, int32(100)) - _, err = tk.Exec("set tidb_ddl_reorg_worker_cnt = -1") + _, err = tk.Exec("set @@global.tidb_ddl_reorg_worker_cnt = -1") c.Assert(terror.ErrorEqual(err, variable.ErrWrongValueForVar), IsTrue, Commentf("err %v", err)) - tk.MustExec("set tidb_ddl_reorg_worker_cnt = 100") - res := tk.MustQuery("select @@tidb_ddl_reorg_worker_cnt") + tk.MustExec("set @@global.tidb_ddl_reorg_worker_cnt = 100") + res := tk.MustQuery("select 
@@global.tidb_ddl_reorg_worker_cnt") res.Check(testkit.Rows("100")) res = tk.MustQuery("select @@global.tidb_ddl_reorg_worker_cnt") - res.Check(testkit.Rows(fmt.Sprintf("%v", variable.DefTiDBDDLReorgWorkerCount))) + res.Check(testkit.Rows("100")) tk.MustExec("set @@global.tidb_ddl_reorg_worker_cnt = 100") res = tk.MustQuery("select @@global.tidb_ddl_reorg_worker_cnt") res.Check(testkit.Rows("100")) @@ -504,28 +545,39 @@ func (s *testSuite3) TestSetDDLReorgWorkerCnt(c *C) { func (s *testSuite3) TestSetDDLReorgBatchSize(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") + err := ddlutil.LoadDDLReorgVars(tk.Se) + c.Assert(err, IsNil) c.Assert(variable.GetDDLReorgBatchSize(), Equals, int32(variable.DefTiDBDDLReorgBatchSize)) - tk.MustExec("set tidb_ddl_reorg_batch_size = 1") + tk.MustExec("set @@global.tidb_ddl_reorg_batch_size = 1") tk.MustQuery("show warnings;").Check(testkit.Rows("Warning 1292 Truncated incorrect tidb_ddl_reorg_batch_size value: '1'")) + err = ddlutil.LoadDDLReorgVars(tk.Se) + c.Assert(err, IsNil) c.Assert(variable.GetDDLReorgBatchSize(), Equals, int32(variable.MinDDLReorgBatchSize)) - tk.MustExec(fmt.Sprintf("set tidb_ddl_reorg_batch_size = %v", variable.MaxDDLReorgBatchSize+1)) + tk.MustExec(fmt.Sprintf("set @@global.tidb_ddl_reorg_batch_size = %v", variable.MaxDDLReorgBatchSize+1)) tk.MustQuery("show warnings;").Check(testkit.Rows(fmt.Sprintf("Warning 1292 Truncated incorrect tidb_ddl_reorg_batch_size value: '%d'", variable.MaxDDLReorgBatchSize+1))) + err = ddlutil.LoadDDLReorgVars(tk.Se) + c.Assert(err, IsNil) c.Assert(variable.GetDDLReorgBatchSize(), Equals, int32(variable.MaxDDLReorgBatchSize)) - _, err := tk.Exec("set tidb_ddl_reorg_batch_size = invalid_val") + _, err = tk.Exec("set @@global.tidb_ddl_reorg_batch_size = invalid_val") c.Assert(terror.ErrorEqual(err, variable.ErrWrongTypeForVar), IsTrue, Commentf("err %v", err)) - tk.MustExec("set tidb_ddl_reorg_batch_size = 100") + tk.MustExec("set @@global.tidb_ddl_reorg_batch_size = 100") + err = ddlutil.LoadDDLReorgVars(tk.Se) + c.Assert(err, IsNil) c.Assert(variable.GetDDLReorgBatchSize(), Equals, int32(100)) - tk.MustExec("set tidb_ddl_reorg_batch_size = -1") + tk.MustExec("set @@global.tidb_ddl_reorg_batch_size = -1") tk.MustQuery("show warnings;").Check(testkit.Rows("Warning 1292 Truncated incorrect tidb_ddl_reorg_batch_size value: '-1'")) - tk.MustExec("set tidb_ddl_reorg_batch_size = 100") - res := tk.MustQuery("select @@tidb_ddl_reorg_batch_size") + tk.MustExec("set @@global.tidb_ddl_reorg_batch_size = 100") + res := tk.MustQuery("select @@global.tidb_ddl_reorg_batch_size") res.Check(testkit.Rows("100")) res = tk.MustQuery("select @@global.tidb_ddl_reorg_batch_size") - res.Check(testkit.Rows(fmt.Sprintf("%v", variable.DefTiDBDDLReorgBatchSize))) + res.Check(testkit.Rows(fmt.Sprintf("%v", 100))) tk.MustExec("set @@global.tidb_ddl_reorg_batch_size = 1000") res = tk.MustQuery("select @@global.tidb_ddl_reorg_batch_size") res.Check(testkit.Rows("1000")) + + // If do not LoadDDLReorgVars, the local variable will be the last loaded value. 
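+ // Here the last reload happened while the global value was 100, so even
+ // though the global variable has since been set to 1000, the cached local
+ // value is still 100: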
+ c.Assert(variable.GetDDLReorgBatchSize(), Equals, int32(100)) } diff --git a/executor/distsql.go b/executor/distsql.go index 8b2b6a21820e6..beed71598d4fc 100644 --- a/executor/distsql.go +++ b/executor/distsql.go @@ -123,7 +123,7 @@ func statementContextToFlags(sc *stmtctx.StatementContext) uint64 { var flags uint64 if sc.InInsertStmt { flags |= model.FlagInInsertStmt - } else if sc.InUpdateOrDeleteStmt { + } else if sc.InUpdateStmt || sc.InDeleteStmt { flags |= model.FlagInUpdateOrDeleteStmt } else if sc.InSelectStmt { flags |= model.FlagInSelectStmt diff --git a/executor/executor.go b/executor/executor.go index 7e6009ac72166..25a27a36b0ca4 100644 --- a/executor/executor.go +++ b/executor/executor.go @@ -134,7 +134,7 @@ func newBaseExecutor(ctx sessionctx.Context, schema *expression.Schema, id strin ctx: ctx, id: id, schema: schema, - initCap: ctx.GetSessionVars().MaxChunkSize, + initCap: ctx.GetSessionVars().InitChunkSize, maxChunkSize: ctx.GetSessionVars().MaxChunkSize, } if ctx.GetSessionVars().StmtCtx.RuntimeStatsColl != nil { @@ -1272,7 +1272,7 @@ func ResetContextOfStmt(ctx sessionctx.Context, s ast.StmtNode) (err error) { // pushing them down to TiKV as flags. switch stmt := s.(type) { case *ast.UpdateStmt: - sc.InUpdateOrDeleteStmt = true + sc.InUpdateStmt = true sc.DupKeyAsWarning = stmt.IgnoreErr sc.BadNullAsWarning = !vars.StrictSQLMode || stmt.IgnoreErr sc.TruncateAsWarning = !vars.StrictSQLMode || stmt.IgnoreErr @@ -1280,7 +1280,7 @@ func ResetContextOfStmt(ctx sessionctx.Context, s ast.StmtNode) (err error) { sc.IgnoreZeroInDate = !vars.StrictSQLMode || stmt.IgnoreErr sc.Priority = stmt.Priority case *ast.DeleteStmt: - sc.InUpdateOrDeleteStmt = true + sc.InDeleteStmt = true sc.DupKeyAsWarning = stmt.IgnoreErr sc.BadNullAsWarning = !vars.StrictSQLMode || stmt.IgnoreErr sc.TruncateAsWarning = !vars.StrictSQLMode || stmt.IgnoreErr @@ -1301,6 +1301,7 @@ func ResetContextOfStmt(ctx sessionctx.Context, s ast.StmtNode) (err error) { sc.DupKeyAsWarning = true sc.BadNullAsWarning = true sc.TruncateAsWarning = !vars.StrictSQLMode + sc.InLoadDataStmt = true case *ast.SelectStmt: sc.InSelectStmt = true @@ -1341,7 +1342,7 @@ func ResetContextOfStmt(ctx sessionctx.Context, s ast.StmtNode) (err error) { sc.PrevLastInsertID = vars.StmtCtx.PrevLastInsertID } sc.PrevAffectedRows = 0 - if vars.StmtCtx.InUpdateOrDeleteStmt || vars.StmtCtx.InInsertStmt { + if vars.StmtCtx.InUpdateStmt || vars.StmtCtx.InDeleteStmt || vars.StmtCtx.InInsertStmt { sc.PrevAffectedRows = int64(vars.StmtCtx.AffectedRows()) } else if vars.StmtCtx.InSelectStmt { sc.PrevAffectedRows = -1 diff --git a/executor/executor_test.go b/executor/executor_test.go index f57b1376810f3..7a90438ae3214 100644 --- a/executor/executor_test.go +++ b/executor/executor_test.go @@ -65,7 +65,7 @@ import ( "github.com/pingcap/tipb/go-tipb" ) -// TestLeakCheckCnt is the check count in the pacakge of executor. +// TestLeakCheckCnt is the check count in the package of executor. // In this package CustomParallelSuiteFlag is true, so we need to increase check count. 
const TestLeakCheckCnt = 1000 @@ -85,6 +85,7 @@ var _ = Suite(&testSuite1{}) var _ = Suite(&testSuite2{}) var _ = Suite(&testSuite3{}) var _ = Suite(&testBypassSuite{}) +var _ = Suite(&testUpdateSuite{}) type testSuite struct { cluster *mocktikv.Cluster @@ -341,6 +342,7 @@ func checkCases(tests []testCase, ld *executor.LoadDataInfo, c.Assert(ctx.NewTxn(context.Background()), IsNil) ctx.GetSessionVars().StmtCtx.DupKeyAsWarning = true ctx.GetSessionVars().StmtCtx.BadNullAsWarning = true + ctx.GetSessionVars().StmtCtx.InLoadDataStmt = true data, reachLimit, err1 := ld.InsertData(tt.data1, tt.data2) c.Assert(err1, IsNil) c.Assert(reachLimit, IsFalse) @@ -1082,7 +1084,7 @@ func (s *testSuite) TestUnion(c *C) { tk.MustExec(`insert into t1 select * from t1;`) tk.MustExec(`insert into t1 select * from t1;`) tk.MustExec(`insert into t2 values(1, 1);`) - tk.MustExec(`set @@tidb_max_chunk_size=2;`) + tk.MustExec(`set @@tidb_init_chunk_size=2;`) tk.MustQuery(`select count(*) from (select t1.a, t1.b from t1 left join t2 on t1.a=t2.a union all select t1.a, t1.a from t1 left join t2 on t1.a=t2.a) tmp;`).Check(testkit.Rows("128")) tk.MustQuery(`select tmp.a, count(*) from (select t1.a, t1.b from t1 left join t2 on t1.a=t2.a union all select t1.a, t1.a from t1 left join t2 on t1.a=t2.a) tmp;`).Check(testkit.Rows("1 128")) @@ -1095,7 +1097,7 @@ func (s *testSuite) TestUnion(c *C) { tk.MustExec("drop table if exists t1") tk.MustExec("create table t1(a int, b int)") tk.MustExec("insert into t1 value(1,2),(1,1),(2,2),(2,2),(3,2),(3,2)") - tk.MustExec("set @@tidb_max_chunk_size=2;") + tk.MustExec("set @@tidb_init_chunk_size=2;") tk.MustQuery("select count(*) from (select a as c, a as d from t1 union all select a, b from t1) t;").Check(testkit.Rows("12")) // #issue 8189 and #issue 8199 @@ -2974,7 +2976,7 @@ func (s *testSuite) TestLimit(c *C) { "4 4", "5 5", )) - tk.MustExec(`set @@tidb_max_chunk_size=2;`) + tk.MustExec(`set @@tidb_init_chunk_size=2;`) tk.MustQuery(`select * from t order by a limit 2, 1;`).Check(testkit.Rows( "3 3", )) @@ -3232,7 +3234,7 @@ func (s *testSuite3) TestMaxOneRow(c *C) { tk.MustExec(`create table t2(a double, b double);`) tk.MustExec(`insert into t1 values(1, 1), (2, 2), (3, 3);`) tk.MustExec(`insert into t2 values(0, 0);`) - tk.MustExec(`set @@tidb_max_chunk_size=1;`) + tk.MustExec(`set @@tidb_init_chunk_size=1;`) rs, err := tk.Exec(`select (select t1.a from t1 where t1.a > t2.a) as a from t2;`) c.Assert(err, IsNil) @@ -3433,10 +3435,14 @@ func (s *testSuite2) TearDownSuite(c *C) { func (s *testSuite2) TearDownTest(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") - r := tk.MustQuery("show tables") + r := tk.MustQuery("show full tables") for _, tb := range r.Rows() { tableName := tb[0] - tk.MustExec(fmt.Sprintf("drop table %v", tableName)) + if tb[1] == "VIEW" { + tk.MustExec(fmt.Sprintf("drop view %v", tableName)) + } else { + tk.MustExec(fmt.Sprintf("drop table %v", tableName)) + } } } @@ -3488,3 +3494,43 @@ func (s *testSuite3) TearDownTest(c *C) { tk.MustExec(fmt.Sprintf("drop table %v", tableName)) } } + +func (s *testSuite) TestStrToDateBuiltin(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustQuery(`select str_to_date('18/10/22','%y/%m/%d') from dual`).Check(testkit.Rows("2018-10-22")) + tk.MustQuery(`select str_to_date('a18/10/22','%y/%m/%d') from dual`).Check(testkit.Rows("")) + tk.MustQuery(`select str_to_date('69/10/22','%y/%m/%d') from dual`).Check(testkit.Rows("2069-10-22")) + tk.MustQuery(`select str_to_date('70/10/22','%y/%m/%d') from 
dual`).Check(testkit.Rows("1970-10-22")) + tk.MustQuery(`select str_to_date('8/10/22','%y/%m/%d') from dual`).Check(testkit.Rows("2008-10-22")) + tk.MustQuery(`select str_to_date('8/10/22','%Y/%m/%d') from dual`).Check(testkit.Rows("2008-10-22")) + tk.MustQuery(`select str_to_date('18/10/22','%Y/%m/%d') from dual`).Check(testkit.Rows("2018-10-22")) + tk.MustQuery(`select str_to_date('a18/10/22','%Y/%m/%d') from dual`).Check(testkit.Rows("")) + tk.MustQuery(`select str_to_date('69/10/22','%Y/%m/%d') from dual`).Check(testkit.Rows("2069-10-22")) + tk.MustQuery(`select str_to_date('70/10/22','%Y/%m/%d') from dual`).Check(testkit.Rows("1970-10-22")) + tk.MustQuery(`select str_to_date('018/10/22','%Y/%m/%d') from dual`).Check(testkit.Rows("0018-10-22")) + tk.MustQuery(`select str_to_date('2018/10/22','%Y/%m/%d') from dual`).Check(testkit.Rows("2018-10-22")) + tk.MustQuery(`select str_to_date('018/10/22','%y/%m/%d') from dual`).Check(testkit.Rows("")) + tk.MustQuery(`select str_to_date('18/10/22','%y0/%m/%d') from dual`).Check(testkit.Rows("")) + tk.MustQuery(`select str_to_date('18/10/22','%Y0/%m/%d') from dual`).Check(testkit.Rows("")) + tk.MustQuery(`select str_to_date('18a/10/22','%y/%m/%d') from dual`).Check(testkit.Rows("")) + tk.MustQuery(`select str_to_date('18a/10/22','%Y/%m/%d') from dual`).Check(testkit.Rows("")) + tk.MustQuery(`select str_to_date('20188/10/22','%Y/%m/%d') from dual`).Check(testkit.Rows("")) + tk.MustQuery(`select str_to_date('2018510522','%Y5%m5%d') from dual`).Check(testkit.Rows("2018-10-22")) + tk.MustQuery(`select str_to_date('2018^10^22','%Y^%m^%d') from dual`).Check(testkit.Rows("2018-10-22")) + tk.MustQuery(`select str_to_date('2018@10@22','%Y@%m@%d') from dual`).Check(testkit.Rows("2018-10-22")) + tk.MustQuery(`select str_to_date('2018%10%22','%Y%%m%%d') from dual`).Check(testkit.Rows("")) + tk.MustQuery(`select str_to_date('2018(10(22','%Y(%m(%d') from dual`).Check(testkit.Rows("2018-10-22")) + tk.MustQuery(`select str_to_date('2018\10\22','%Y\%m\%d') from dual`).Check(testkit.Rows("")) + tk.MustQuery(`select str_to_date('2018=10=22','%Y=%m=%d') from dual`).Check(testkit.Rows("2018-10-22")) + tk.MustQuery(`select str_to_date('2018+10+22','%Y+%m+%d') from dual`).Check(testkit.Rows("2018-10-22")) + tk.MustQuery(`select str_to_date('2018_10_22','%Y_%m_%d') from dual`).Check(testkit.Rows("2018-10-22")) + tk.MustQuery(`select str_to_date('69510522','%y5%m5%d') from dual`).Check(testkit.Rows("2069-10-22")) + tk.MustQuery(`select str_to_date('69^10^22','%y^%m^%d') from dual`).Check(testkit.Rows("2069-10-22")) + tk.MustQuery(`select str_to_date('18@10@22','%y@%m@%d') from dual`).Check(testkit.Rows("2018-10-22")) + tk.MustQuery(`select str_to_date('18%10%22','%y%%m%%d') from dual`).Check(testkit.Rows("")) + tk.MustQuery(`select str_to_date('18(10(22','%y(%m(%d') from dual`).Check(testkit.Rows("2018-10-22")) + tk.MustQuery(`select str_to_date('18\10\22','%y\%m\%d') from dual`).Check(testkit.Rows("")) + tk.MustQuery(`select str_to_date('18+10+22','%y+%m+%d') from dual`).Check(testkit.Rows("2018-10-22")) + tk.MustQuery(`select str_to_date('18=10=22','%y=%m=%d') from dual`).Check(testkit.Rows("2018-10-22")) + tk.MustQuery(`select str_to_date('18_10_22','%y_%m_%d') from dual`).Check(testkit.Rows("2018-10-22")) +} diff --git a/executor/index_lookup_join_test.go b/executor/index_lookup_join_test.go index adeb966a1f764..d2f494f84c962 100644 --- a/executor/index_lookup_join_test.go +++ b/executor/index_lookup_join_test.go @@ -96,7 +96,7 @@ func (s *testSuite1) 
TestBatchIndexJoinUnionScan(c *C) { tk := testkit.NewTestKitWithInit(c, s.store) tk.MustExec("create table t1(id int primary key, a int)") tk.MustExec("create table t2(id int primary key, a int, key idx_a(a))") - tk.MustExec("set @@session.tidb_max_chunk_size=1") + tk.MustExec("set @@session.tidb_init_chunk_size=1") tk.MustExec("set @@session.tidb_index_join_batch_size=1") tk.MustExec("set @@session.tidb_index_lookup_join_concurrency=4") tk.MustExec("begin") diff --git a/executor/join.go b/executor/join.go index 3f64e7a318cd0..3d0e9093c600d 100644 --- a/executor/join.go +++ b/executor/join.go @@ -15,7 +15,6 @@ package executor import ( "context" - "math" "sync" "sync/atomic" "time" @@ -48,19 +47,22 @@ type HashJoinExec struct { outerKeys []*expression.Column innerKeys []*expression.Column - prepared bool - concurrency uint // concurrency is number of concurrent channels and join workers. - hashTable *mvmap.MVMap + prepared bool + // concurrency is the number of partition, build and join workers. + concurrency uint + globalHashTable *mvmap.MVMap innerFinished chan error hashJoinBuffers []*hashJoinBuffer - workerWaitGroup sync.WaitGroup // workerWaitGroup is for sync multiple join workers. - finished atomic.Value - closeCh chan struct{} // closeCh add a lock for closing executor. - joinType plannercore.JoinType - innerIdx int - - // We build individual joiner for each join worker when use chunk-based execution, - // to avoid the concurrency of joiner.chk and joiner.selected. + // joinWorkerWaitGroup is for sync multiple join workers. + joinWorkerWaitGroup sync.WaitGroup + finished atomic.Value + // closeCh add a lock for closing executor. + closeCh chan struct{} + joinType plannercore.JoinType + innerIdx int + + // We build individual joiner for each join worker when use chunk-based + // execution, to avoid the concurrency of joiner.chk and joiner.selected. joiners []joiner outerKeyColIdx []int @@ -73,12 +75,6 @@ type HashJoinExec struct { hashTableValBufs [][][]byte memTracker *memory.Tracker // track memory usage. - - // radixBits indicates the bit number using for radix partitioning. Inner - // relation will be split to 2^radixBits sub-relations before building the - // hash tables. If the complete inner relation can be hold in L2Cache in - // which case radixBits will be 0, we can skip the partition phase. - radixBits int } // outerChkResource stores the result of the join outer fetch worker, @@ -161,9 +157,14 @@ func (e *HashJoinExec) Open(ctx context.Context) error { e.hashJoinBuffers = append(e.hashJoinBuffers, buffer) } + e.innerKeyColIdx = make([]int, len(e.innerKeys)) + for i := range e.innerKeys { + e.innerKeyColIdx[i] = e.innerKeys[i].Index + } + e.closeCh = make(chan struct{}) e.finished.Store(false) - e.workerWaitGroup = sync.WaitGroup{} + e.joinWorkerWaitGroup = sync.WaitGroup{} return nil } @@ -251,77 +252,43 @@ func (e *HashJoinExec) wait4Inner() (finished bool, err error) { return false, errors.Trace(err) } } - if e.hashTable.Len() == 0 && (e.joinType == plannercore.InnerJoin || e.joinType == plannercore.SemiJoin) { + if e.innerResult.Len() == 0 && (e.joinType == plannercore.InnerJoin || e.joinType == plannercore.SemiJoin) { return true, nil } return false, nil } -// fetchInnerRows fetches all rows from inner executor, -// and append them to e.innerResult. 
-func (e *HashJoinExec) fetchInnerRows(ctx context.Context, chkCh chan<- *chunk.Chunk, doneCh chan struct{}) { - defer close(chkCh) +// fetchInnerRows fetches all rows from inner executor, and append them to +// e.innerResult. +func (e *HashJoinExec) fetchInnerRows(ctx context.Context) error { e.innerResult = chunk.NewList(e.innerExec.retTypes(), e.initCap, e.maxChunkSize) e.innerResult.GetMemTracker().AttachTo(e.memTracker) e.innerResult.GetMemTracker().SetLabel("innerResult") var err error for { - select { - case <-doneCh: - return - case <-e.closeCh: - return - default: - if e.finished.Load().(bool) { - return - } - chk := e.children[e.innerIdx].newFirstChunk() - err = e.innerExec.Next(ctx, chk) - if err != nil { - e.innerFinished <- errors.Trace(err) - return - } - if chk.NumRows() == 0 { - e.evalRadixBitNum() - return - } - select { - case chkCh <- chk: - break - case <-e.closeCh: - return - } - e.innerResult.Add(chk) + if e.finished.Load().(bool) { + return nil } + chk := e.children[e.innerIdx].newFirstChunk() + err = e.innerExec.Next(ctx, chk) + if err != nil || chk.NumRows() == 0 { + return err + } + e.innerResult.Add(chk) } } -// evalRadixBitNum evaluates the radix bit numbers. -func (e *HashJoinExec) evalRadixBitNum() { - sv := e.ctx.GetSessionVars() - // Calculate the bit number needed when using radix partition. - if !sv.EnableRadixJoin { - return - } - innerResultSize := float64(e.innerResult.GetMemTracker().BytesConsumed()) - // To ensure that one partition of inner relation, one hash - // table and one partition of outer relation fit into the L2 - // cache when the input data obeys the uniform distribution, - // we suppose every sub-partition of inner relation using - // three quarters of the L2 cache size. - l2CacheSize := float64(sv.L2CacheSize) * 3 / 4 - e.radixBits = int(math.Log2(innerResultSize / l2CacheSize)) -} - func (e *HashJoinExec) initializeForProbe() { - // e.outerResultChs is for transmitting the chunks which store the data of outerExec, - // it'll be written by outer worker goroutine, and read by join workers. + // e.outerResultChs is for transmitting the chunks which store the data of + // outerExec, it'll be written by outer worker goroutine, and read by join + // workers. e.outerResultChs = make([]chan *chunk.Chunk, e.concurrency) for i := uint(0); i < e.concurrency; i++ { e.outerResultChs[i] = make(chan *chunk.Chunk, 1) } - // e.outerChkResourceCh is for transmitting the used outerExec chunks from join workers to outerExec worker. + // e.outerChkResourceCh is for transmitting the used outerExec chunks from + // join workers to outerExec worker. e.outerChkResourceCh = make(chan *outerChkResource, e.concurrency) for i := uint(0); i < e.concurrency; i++ { e.outerChkResourceCh <- &outerChkResource{ @@ -338,7 +305,8 @@ func (e *HashJoinExec) initializeForProbe() { e.joinChkResourceCh[i] <- e.newFirstChunk() } - // e.joinResultCh is for transmitting the join result chunks to the main thread. + // e.joinResultCh is for transmitting the join result chunks to the main + // thread. 
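+ // Together these channels form the probe pipeline: the outer fetcher fills
+ // e.outerResultChs, each join worker probes the global hash table and sends
+ // its result chunks to e.joinResultCh, and used chunks are handed back
+ // through e.outerChkResourceCh and e.joinChkResourceCh for reuse.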
e.joinResultCh = make(chan *hashjoinWorkerResult, e.concurrency+1) e.outerKeyColIdx = make([]int, len(e.outerKeys)) @@ -349,37 +317,38 @@ func (e *HashJoinExec) initializeForProbe() { func (e *HashJoinExec) fetchOuterAndProbeHashTable(ctx context.Context) { e.initializeForProbe() - e.workerWaitGroup.Add(1) - go util.WithRecovery(func() { e.fetchOuterChunks(ctx) }, e.finishOuterFetcher) + e.joinWorkerWaitGroup.Add(1) + go util.WithRecovery(func() { e.fetchOuterChunks(ctx) }, e.handleOuterFetcherPanic) - // Start e.concurrency join workers to probe hash table and join inner and outer rows. + // Start e.concurrency join workers to probe hash table and join inner and + // outer rows. for i := uint(0); i < e.concurrency; i++ { - e.workerWaitGroup.Add(1) + e.joinWorkerWaitGroup.Add(1) workID := i - go util.WithRecovery(func() { e.runJoinWorker(workID) }, e.finishJoinWorker) + go util.WithRecovery(func() { e.runJoinWorker(workID) }, e.handleJoinWorkerPanic) } go util.WithRecovery(e.waitJoinWorkersAndCloseResultChan, nil) } -func (e *HashJoinExec) finishOuterFetcher(r interface{}) { +func (e *HashJoinExec) handleOuterFetcherPanic(r interface{}) { for i := range e.outerResultChs { close(e.outerResultChs[i]) } if r != nil { e.joinResultCh <- &hashjoinWorkerResult{err: errors.Errorf("%v", r)} } - e.workerWaitGroup.Done() + e.joinWorkerWaitGroup.Done() } -func (e *HashJoinExec) finishJoinWorker(r interface{}) { +func (e *HashJoinExec) handleJoinWorkerPanic(r interface{}) { if r != nil { e.joinResultCh <- &hashjoinWorkerResult{err: errors.Errorf("%v", r)} } - e.workerWaitGroup.Done() + e.joinWorkerWaitGroup.Done() } func (e *HashJoinExec) waitJoinWorkersAndCloseResultChan() { - e.workerWaitGroup.Wait() + e.joinWorkerWaitGroup.Wait() close(e.joinResultCh) } @@ -436,7 +405,7 @@ func (e *HashJoinExec) joinMatchedOuterRow2Chunk(workerID uint, outerRow chunk.R e.joiners[workerID].onMissMatch(outerRow, joinResult.chk) return true, joinResult } - e.hashTableValBufs[workerID] = e.hashTable.Get(joinKey, e.hashTableValBufs[workerID][:0]) + e.hashTableValBufs[workerID] = e.globalHashTable.Get(joinKey, e.hashTableValBufs[workerID][:0]) innerPtrs := e.hashTableValBufs[workerID] if len(innerPtrs) == 0 { e.joiners[workerID].onMissMatch(outerRow, joinResult.chk) @@ -525,7 +494,7 @@ func (e *HashJoinExec) Next(ctx context.Context, chk *chunk.Chunk) (err error) { } if !e.prepared { e.innerFinished = make(chan error, 1) - go util.WithRecovery(func() { e.fetchInnerAndBuildHashTable(ctx) }, e.finishFetchInnerAndBuildHashTable) + go util.WithRecovery(func() { e.fetchInnerAndBuildHashTable(ctx) }, e.handleFetchInnerAndBuildHashTablePanic) e.fetchOuterAndProbeHashTable(ctx) e.prepared = true } @@ -546,7 +515,7 @@ func (e *HashJoinExec) Next(ctx context.Context, chk *chunk.Chunk) (err error) { return nil } -func (e *HashJoinExec) finishFetchInnerAndBuildHashTable(r interface{}) { +func (e *HashJoinExec) handleFetchInnerAndBuildHashTablePanic(r interface{}) { if r != nil { e.innerFinished <- errors.Errorf("%v", r) } @@ -554,31 +523,21 @@ func (e *HashJoinExec) finishFetchInnerAndBuildHashTable(r interface{}) { } func (e *HashJoinExec) fetchInnerAndBuildHashTable(ctx context.Context) { - // innerResultCh transfer inner result chunk from inner fetch to build hash table. - innerResultCh := make(chan *chunk.Chunk, e.concurrency) - doneCh := make(chan struct{}) - go util.WithRecovery(func() { e.fetchInnerRows(ctx, innerResultCh, doneCh) }, nil) - - // TODO: Parallel build hash table. 
Currently not support because `mvmap` is not thread-safe. - err := e.buildHashTableForList(innerResultCh) - if err != nil { - e.innerFinished <- errors.Trace(err) - close(doneCh) + if err := e.fetchInnerRows(ctx); err != nil { + e.innerFinished <- err + return } - // wait fetchInnerRows be finished. - for range innerResultCh { + + if err := e.buildGlobalHashTable(); err != nil { + e.innerFinished <- err } } -// buildHashTableForList builds hash table from `list`. +// buildGlobalHashTable builds a global hash table for the inner relation. // key of hash table: hash value of key columns // value of hash table: RowPtr of the corresponded row -func (e *HashJoinExec) buildHashTableForList(innerResultCh chan *chunk.Chunk) error { - e.hashTable = mvmap.NewMVMap() - e.innerKeyColIdx = make([]int, len(e.innerKeys)) - for i := range e.innerKeys { - e.innerKeyColIdx[i] = e.innerKeys[i].Index - } +func (e *HashJoinExec) buildGlobalHashTable() error { + e.globalHashTable = mvmap.NewMVMap() var ( hasNull bool err error @@ -586,25 +545,23 @@ func (e *HashJoinExec) buildHashTableForList(innerResultCh chan *chunk.Chunk) er valBuf = make([]byte, 8) ) - chkIdx := uint32(0) - for chk := range innerResultCh { + for chkIdx := 0; chkIdx < e.innerResult.NumChunks(); chkIdx++ { if e.finished.Load().(bool) { return nil } - numRows := chk.NumRows() - for j := 0; j < numRows; j++ { + chk := e.innerResult.GetChunk(chkIdx) + for j, numRows := 0, chk.NumRows(); j < numRows; j++ { hasNull, keyBuf, err = e.getJoinKeyFromChkRow(false, chk.GetRow(j), keyBuf) if err != nil { - return errors.Trace(err) + return err } if hasNull { continue } - rowPtr := chunk.RowPtr{ChkIdx: chkIdx, RowIdx: uint32(j)} + rowPtr := chunk.RowPtr{ChkIdx: uint32(chkIdx), RowIdx: uint32(j)} *(*chunk.RowPtr)(unsafe.Pointer(&valBuf[0])) = rowPtr - e.hashTable.Put(keyBuf, valBuf) + e.globalHashTable.Put(keyBuf, valBuf) } - chkIdx++ } return nil } diff --git a/executor/join_test.go b/executor/join_test.go index e51fd86df9ab4..be2dc543fd7fc 100644 --- a/executor/join_test.go +++ b/executor/join_test.go @@ -334,10 +334,10 @@ func (s *testSuite2) TestJoinCast(c *C) { tk.MustExec("drop table if exists t") tk.MustExec("create table t(a varchar(10), index idx(a))") tk.MustExec("insert into t values('1'), ('2'), ('3')") - tk.MustExec("set @@tidb_max_chunk_size=1") + tk.MustExec("set @@tidb_init_chunk_size=1") result = tk.MustQuery("select a from (select /*+ TIDB_INLJ(t1, t2) */ t1.a from t t1 join t t2 on t1.a=t2.a) t group by a") result.Sort().Check(testkit.Rows("1", "2", "3")) - tk.MustExec("set @@tidb_max_chunk_size=1024") + tk.MustExec("set @@tidb_init_chunk_size=32") } func (s *testSuite2) TestUsing(c *C) { @@ -819,7 +819,7 @@ func (s *testSuite2) TestIssue5278(c *C) { func (s *testSuite2) TestIndexLookupJoin(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") - tk.MustExec("set @@tidb_max_chunk_size=2") + tk.MustExec("set @@tidb_init_chunk_size=2") tk.MustExec("DROP TABLE IF EXISTS t") tk.MustExec("CREATE TABLE `t` (`a` int, pk integer auto_increment,`b` char (20),primary key (pk))") tk.MustExec("CREATE INDEX idx_t_a ON t(`a`)") @@ -849,7 +849,7 @@ func (s *testSuite2) TestIndexLookupJoin(c *C) { tk.MustExec(`drop table if exists t;`) tk.MustExec(`create table t(a bigint, b bigint, unique key idx1(a, b));`) tk.MustExec(`insert into t values(1, 1), (1, 2), (1, 3), (1, 4), (1, 5), (1, 6);`) - tk.MustExec(`set @@tidb_max_chunk_size = 2;`) + tk.MustExec(`set @@tidb_init_chunk_size = 2;`) tk.MustQuery(`select /*+ TIDB_INLJ(t2) */ * from t t1 
left join t t2 on t1.a = t2.a and t1.b = t2.b + 4;`).Check(testkit.Rows( `1 1 `, `1 2 `, @@ -894,7 +894,7 @@ func (s *testSuite2) TestMergejoinOrder(c *C) { " └─TableScan_12 6666.67 cop table:t2, range:[-inf,3), (3,+inf], keep order:true, stats:pseudo", )) - tk.MustExec("set @@tidb_max_chunk_size=1") + tk.MustExec("set @@tidb_init_chunk_size=1") tk.MustQuery("select /*+ TIDB_SMJ(t2) */ * from t1 left outer join t2 on t1.a=t2.a and t1.a!=3 order by t1.a;").Check(testkit.Rows( "1 100 ", "2 100 ", @@ -942,7 +942,7 @@ func (s *testSuite2) TestHashJoin(c *C) { tk.MustExec("insert into t1 values(1,1),(2,2),(3,3),(4,4),(5,5);") tk.MustQuery("select count(*) from t1").Check(testkit.Rows("5")) tk.MustQuery("select count(*) from t2").Check(testkit.Rows("0")) - tk.MustExec("set @@tidb_max_chunk_size=1;") + tk.MustExec("set @@tidb_init_chunk_size=1;") result := tk.MustQuery("explain analyze select /*+ TIDB_HJ(t1, t2) */ * from t1 where exists (select a from t2 where t1.a = t2.a);") // id count task operator info execution info // HashLeftJoin_9 8000.00 root semi join, inner:TableReader_13, equal:[eq(test.t1.a, test.t2.a)] time:1.036712ms, loops:1, rows:0 @@ -953,7 +953,8 @@ func (s *testSuite2) TestHashJoin(c *C) { row := result.Rows() c.Assert(len(row), Equals, 5) outerExecInfo := row[1][4].(string) - c.Assert(outerExecInfo[len(outerExecInfo)-1:], Equals, "1") + // FIXME: revert this result to 1 after TableReaderExecutor can handle initChunkSize. + c.Assert(outerExecInfo[len(outerExecInfo)-1:], Equals, "5") innerExecInfo := row[3][4].(string) c.Assert(innerExecInfo[len(innerExecInfo)-1:], Equals, "0") @@ -988,3 +989,22 @@ func (s *testSuite2) TestHashJoin(c *C) { innerExecInfo = row[3][4].(string) c.Assert(innerExecInfo[len(innerExecInfo)-1:], LessEqual, "5") } + +func (s *testSuite2) TestJoinDifferentDecimals(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("Use test") + tk.MustExec("Drop table if exists t1") + tk.MustExec("Create table t1 (v int)") + tk.MustExec("Insert into t1 value (1)") + tk.MustExec("Insert into t1 value (2)") + tk.MustExec("Insert into t1 value (3)") + tk.MustExec("Drop table if exists t2") + tk.MustExec("Create table t2 (v decimal(12, 3))") + tk.MustExec("Insert into t2 value (1)") + tk.MustExec("Insert into t2 value (2.0)") + tk.MustExec("Insert into t2 value (000003.000000)") + rst := tk.MustQuery("Select * from t1, t2 where t1.v = t2.v order by t1.v") + row := rst.Rows() + c.Assert(len(row), Equals, 3) + rst.Check(testkit.Rows("1 1.000", "2 2.000", "3 3.000")) +} diff --git a/executor/load_data.go b/executor/load_data.go index df81f8ccd2730..5d690470791e0 100644 --- a/executor/load_data.go +++ b/executor/load_data.go @@ -202,7 +202,7 @@ func (e *LoadDataInfo) getLine(prevData, curData []byte) ([]byte, []byte, bool) } // InsertData inserts data into specified table according to the specified format. -// If it has the rest of data isn't completed the processing, then is returns without completed data. +// If it has the rest of data isn't completed the processing, then it returns without completed data. // If the number of inserted rows reaches the batchRows, then the second return value is true. // If prevData isn't nil and curData is nil, there are no other data to deal with and the isEOF is true. 
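// The return values are the data left unprocessed, whether the number of inserted rows reached batchRows, and any error.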
func (e *LoadDataInfo) InsertData(prevData, curData []byte) ([]byte, bool, error) { diff --git a/executor/pkg_test.go b/executor/pkg_test.go index f15dadfe759dd..f6ea626172ca7 100644 --- a/executor/pkg_test.go +++ b/executor/pkg_test.go @@ -3,15 +3,21 @@ package executor import ( "context" "fmt" + "testing" + "github.com/klauspost/cpuid" . "github.com/pingcap/check" "github.com/pingcap/parser/ast" "github.com/pingcap/parser/mysql" "github.com/pingcap/tidb/expression" plannercore "github.com/pingcap/tidb/planner/core" + "github.com/pingcap/tidb/sessionctx" + "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/chunk" + "github.com/pingcap/tidb/util/memory" "github.com/pingcap/tidb/util/mock" + "github.com/spaolacci/murmur3" ) var _ = Suite(&pkgTestSuite{}) @@ -111,6 +117,97 @@ func (s *pkgTestSuite) TestNestedLoopApply(c *C) { } } +func prepareOneColChildExec(sctx sessionctx.Context, rowCount int) Executor { + col0 := &expression.Column{Index: 0, RetType: types.NewFieldType(mysql.TypeLong)} + schema := expression.NewSchema(col0) + exec := &MockExec{ + baseExecutor: newBaseExecutor(sctx, schema, ""), + Rows: make([]chunk.MutRow, rowCount)} + for i := 0; i < len(exec.Rows); i++ { + exec.Rows[i] = chunk.MutRowFromDatums(types.MakeDatums(i % 10)) + } + return exec +} + +func buildExec4RadixHashJoin(sctx sessionctx.Context, rowCount int) *RadixHashJoinExec { + childExec0 := prepareOneColChildExec(sctx, rowCount) + childExec1 := prepareOneColChildExec(sctx, rowCount) + + col0 := &expression.Column{Index: 0, RetType: types.NewFieldType(mysql.TypeLong)} + col1 := &expression.Column{Index: 0, RetType: types.NewFieldType(mysql.TypeLong)} + joinSchema := expression.NewSchema(col0, col1) + hashJoinExec := &HashJoinExec{ + baseExecutor: newBaseExecutor(sctx, joinSchema, "HashJoin", childExec0, childExec1), + concurrency: 4, + joinType: 0, // InnerJoin + innerIdx: 0, + innerKeys: []*expression.Column{{Index: 0, RetType: types.NewFieldType(mysql.TypeLong)}}, + innerKeyColIdx: []int{0}, + outerKeys: []*expression.Column{{Index: 0, RetType: types.NewFieldType(mysql.TypeLong)}}, + outerKeyColIdx: []int{0}, + innerExec: childExec0, + outerExec: childExec1, + } + return &RadixHashJoinExec{HashJoinExec: hashJoinExec} +} + +func (s *pkgTestSuite) TestRadixPartition(c *C) { + sctx := mock.NewContext() + hashJoinExec := buildExec4RadixHashJoin(sctx, 200) + sv := sctx.GetSessionVars() + originL2CacheSize, originEnableRadixJoin, originMaxChunkSize := sv.L2CacheSize, sv.EnableRadixJoin, sv.MaxChunkSize + sv.L2CacheSize = 100 + sv.EnableRadixJoin = true + // FIXME: use initChunkSize when join support initChunkSize. 
+ sv.MaxChunkSize = 100 + defer func() { + sv.L2CacheSize, sv.EnableRadixJoin, sv.MaxChunkSize = originL2CacheSize, originEnableRadixJoin, originMaxChunkSize + }() + sv.StmtCtx.MemTracker = memory.NewTracker("RootMemTracker", variable.DefTiDBMemQuotaHashJoin) + + ctx := context.Background() + err := hashJoinExec.Open(ctx) + c.Assert(err, IsNil) + + hashJoinExec.fetchInnerRows(ctx) + c.Assert(hashJoinExec.innerResult.GetMemTracker().BytesConsumed(), Equals, int64(14400)) + + hashJoinExec.evalRadixBit() + // ceil(log_2(14400/(100*3/4))) = 8 + c.Assert(len(hashJoinExec.innerParts), Equals, 256) + radixBits := hashJoinExec.radixBits + c.Assert(radixBits, Equals, uint32(0x00ff)) + + err = hashJoinExec.partitionInnerRows() + c.Assert(err, IsNil) + totalRowCnt := 0 + for i, part := range hashJoinExec.innerParts { + if part == nil { + continue + } + totalRowCnt += part.NumRows() + iter := chunk.NewIterator4Chunk(part) + hasNull, keyBuf := false, make([]byte, 0, 64) + for row := iter.Begin(); row != iter.End(); row = iter.Next() { + hasNull, keyBuf, err = hashJoinExec.getJoinKeyFromChkRow(false, row, keyBuf) + c.Assert(err, IsNil) + c.Assert(hasNull, IsFalse) + joinHash := murmur3.Sum32(keyBuf) + c.Assert(joinHash&radixBits, Equals, uint32(i)&radixBits) + } + } + c.Assert(totalRowCnt, Equals, 200) + + for _, ch := range hashJoinExec.outerResultChs { + close(ch) + } + for _, ch := range hashJoinExec.joinChkResourceCh { + close(ch) + } + err = hashJoinExec.Close() + c.Assert(err, IsNil) +} + func (s *pkgTestSuite) TestMoveInfoSchemaToFront(c *C) { dbss := [][]string{ {}, @@ -140,3 +237,56 @@ func (s *pkgTestSuite) TestMoveInfoSchemaToFront(c *C) { } } } + +func BenchmarkPartitionInnerRows(b *testing.B) { + sctx := mock.NewContext() + hashJoinExec := buildExec4RadixHashJoin(sctx, 1500000) + sv := sctx.GetSessionVars() + originL2CacheSize, originEnableRadixJoin, originMaxChunkSize := sv.L2CacheSize, sv.EnableRadixJoin, sv.MaxChunkSize + sv.L2CacheSize = cpuid.CPU.Cache.L2 + sv.EnableRadixJoin = true + sv.MaxChunkSize = 1024 + defer func() { + sv.L2CacheSize, sv.EnableRadixJoin, sv.MaxChunkSize = originL2CacheSize, originEnableRadixJoin, originMaxChunkSize + }() + sv.StmtCtx.MemTracker = memory.NewTracker("RootMemTracker", variable.DefTiDBMemQuotaHashJoin) + + ctx := context.Background() + hashJoinExec.Open(ctx) + hashJoinExec.fetchInnerRows(ctx) + hashJoinExec.evalRadixBit() + b.ResetTimer() + hashJoinExec.concurrency = 16 + hashJoinExec.maxChunkSize = 1024 + hashJoinExec.initCap = 1024 + for i := 0; i < b.N; i++ { + hashJoinExec.partitionInnerRows() + hashJoinExec.innerRowPrts = hashJoinExec.innerRowPrts[:0] + } +} + +func (s *pkgTestSuite) TestParallelBuildHashTable4RadixJoin(c *C) { + sctx := mock.NewContext() + hashJoinExec := buildExec4RadixHashJoin(sctx, 200) + + sv := sctx.GetSessionVars() + sv.L2CacheSize = 100 + sv.EnableRadixJoin = true + sv.MaxChunkSize = 100 + sv.StmtCtx.MemTracker = memory.NewTracker("RootMemTracker", variable.DefTiDBMemQuotaHashJoin) + + ctx := context.Background() + err := hashJoinExec.Open(ctx) + c.Assert(err, IsNil) + + hashJoinExec.partitionInnerAndBuildHashTables(ctx) + innerParts := hashJoinExec.innerParts + c.Assert(len(hashJoinExec.hashTables), Equals, len(innerParts)) + for i := 0; i < len(innerParts); i++ { + if innerParts[i] == nil { + c.Assert(hashJoinExec.hashTables[i], IsNil) + } else { + c.Assert(hashJoinExec.hashTables[i], NotNil) + } + } +} diff --git a/executor/radix_hash_join.go b/executor/radix_hash_join.go new file mode 100644 index 
0000000000000..1820999c592bf --- /dev/null +++ b/executor/radix_hash_join.go @@ -0,0 +1,270 @@ +// Copyright 2019 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package executor + +import ( + "context" + "math" + "sync" + "time" + "unsafe" + + "github.com/pingcap/errors" + "github.com/pingcap/tidb/util" + "github.com/pingcap/tidb/util/chunk" + "github.com/pingcap/tidb/util/mvmap" + log "github.com/sirupsen/logrus" + "github.com/spaolacci/murmur3" +) + +var ( + _ Executor = &RadixHashJoinExec{} +) + +// RadixHashJoinExec implements the radix partition-based hash join algorithm. +// It will partition the input relations into small pairs of partitions where +// one of the partitions typically fits into one of the caches. The overall goal +// of this method is to minimize the number of cache misses when building and +// probing hash tables. +type RadixHashJoinExec struct { + *HashJoinExec + + // radixBits indicates the bits using for radix partitioning. Inner relation + // will be split to 2^radixBitsNumber sub-relations before building the hash + // tables. If the complete inner relation can be hold in L2Cache in which + // case radixBits will be 1, we can skip the partition phase. + // Note: We actually check whether `size of sub inner relation < 3/4 * L2 + // cache size` to make sure one inner sub-relation, hash table, one outer + // sub-relation and join result of the sub-relations can be totally loaded + // in L2 cache size. `3/4` is a magic number, we may adjust it after + // benchmark. + radixBits uint32 + innerParts []partition + numNonEmptyPart int + // innerRowPrts indicates the position in corresponding partition of every + // row in innerResult. + innerRowPrts [][]partRowPtr + // hashTables stores the hash tables built from the inner relation, if there + // is no partition phase, a global hash table will be stored in + // hashTables[0]. + hashTables []*mvmap.MVMap +} + +// partition stores the sub-relations of inner relation and outer relation after +// partition phase. Every partition can be fully stored in L2 cache thus can +// reduce the cache miss ratio when building and probing the hash table. +type partition = *chunk.Chunk + +// partRowPtr stores the actual index in `innerParts` or `outerParts`. +type partRowPtr struct { + partitionIdx uint32 + rowIdx uint32 +} + +// partPtr4NullKey indicates a partition pointer which points to a row with null-join-key. +var partPtr4NullKey = partRowPtr{math.MaxUint32, math.MaxUint32} + +// Next implements the Executor Next interface. +// radix hash join constructs the result following these steps: +// step 1. fetch data from inner child +// step 2. parallel partition the inner relation into sub-relations and build an +// individual hash table for every partition +// step 3. fetch data from outer child in a background goroutine and partition +// it into sub-relations +// step 4. 
probe the corresponded sub-hash-table for every sub-outer-relation in +// multiple join workers +func (e *RadixHashJoinExec) Next(ctx context.Context, chk *chunk.Chunk) (err error) { + if e.runtimeStats != nil { + start := time.Now() + defer func() { e.runtimeStats.Record(time.Now().Sub(start), chk.NumRows()) }() + } + if !e.prepared { + e.innerFinished = make(chan error, 1) + go util.WithRecovery(func() { e.partitionInnerAndBuildHashTables(ctx) }, e.handleFetchInnerAndBuildHashTablePanic) + // TODO: parallel fetch outer rows, partition them and do parallel join + e.prepared = true + } + return <-e.innerFinished +} + +// partitionInnerRows re-order e.innerResults into sub-relations. +func (e *RadixHashJoinExec) partitionInnerRows() error { + e.evalRadixBit() + if err := e.preAlloc4InnerParts(); err != nil { + return err + } + + wg := sync.WaitGroup{} + defer wg.Wait() + wg.Add(int(e.concurrency)) + for i := 0; i < int(e.concurrency); i++ { + workerID := i + go util.WithRecovery(func() { + defer wg.Done() + e.doInnerPartition(workerID) + }, e.handlePartitionPanic) + } + return nil +} + +func (e *RadixHashJoinExec) handlePartitionPanic(r interface{}) { + if r != nil { + e.joinResultCh <- &hashjoinWorkerResult{err: errors.Errorf("%v", r)} + } +} + +// doInnerPartition runs concurrently, partitions and copies the inner relation +// to several pre-allocated data partitions. The input inner Chunk idx for each +// partitioner is workerId + x*numPartitioners. +func (e *RadixHashJoinExec) doInnerPartition(workerID int) { + chkIdx, chkNum := workerID, e.innerResult.NumChunks() + for ; chkIdx < chkNum; chkIdx += int(e.concurrency) { + if e.finished.Load().(bool) { + return + } + chk := e.innerResult.GetChunk(chkIdx) + for srcRowIdx, partPtr := range e.innerRowPrts[chkIdx] { + if partPtr == partPtr4NullKey { + continue + } + partIdx, destRowIdx := partPtr.partitionIdx, partPtr.rowIdx + part := e.innerParts[partIdx] + part.Insert(int(destRowIdx), chk.GetRow(srcRowIdx)) + } + } +} + +// preAlloc4InnerParts evaluates partRowPtr and pre-alloc the memory space +// for every inner row to help re-order the inner relation. +// TODO: we need to evaluate the skewness for the partitions size, if the +// skewness exceeds a threshold, we do not use partition phase. 
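+// The partition phase is therefore two-pass: this function is pass one, which
+// computes a (partitionIdx, rowIdx) pair for every inner row and reserves its
+// slot via PreAlloc; doInnerPartition is pass two, which copies the rows into
+// those reserved slots concurrently.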
+func (e *RadixHashJoinExec) preAlloc4InnerParts() (err error) { + hasNull, keyBuf := false, make([]byte, 0, 64) + for chkIdx, chkNum := 0, e.innerResult.NumChunks(); chkIdx < chkNum; chkIdx++ { + chk := e.innerResult.GetChunk(chkIdx) + partPtrs := make([]partRowPtr, chk.NumRows()) + for rowIdx := 0; rowIdx < chk.NumRows(); rowIdx++ { + row := chk.GetRow(rowIdx) + hasNull, keyBuf, err = e.getJoinKeyFromChkRow(false, row, keyBuf) + if err != nil { + return err + } + if hasNull { + partPtrs[rowIdx] = partPtr4NullKey + continue + } + joinHash := murmur3.Sum32(keyBuf) + partIdx := e.radixBits & joinHash + partPtrs[rowIdx].partitionIdx = partIdx + partPtrs[rowIdx].rowIdx = e.getPartition(partIdx).PreAlloc(row) + } + e.innerRowPrts = append(e.innerRowPrts, partPtrs) + } + if e.numNonEmptyPart < len(e.innerParts) { + numTotalPart := len(e.innerParts) + numEmptyPart := numTotalPart - e.numNonEmptyPart + log.Debugf("[EMPTY_PART_IN_RADIX_HASH_JOIN] txn_start_ts:%v, num_empty_parts:%v, "+ + "num_total_parts:%v, empty_ratio:%v", e.ctx.GetSessionVars().TxnCtx.StartTS, + numEmptyPart, numTotalPart, float64(numEmptyPart)/float64(numTotalPart)) + } + return +} + +func (e *RadixHashJoinExec) getPartition(idx uint32) partition { + if e.innerParts[idx] == nil { + e.numNonEmptyPart++ + e.innerParts[idx] = chunk.New(e.innerExec.retTypes(), e.initCap, e.maxChunkSize) + } + return e.innerParts[idx] +} + +// evalRadixBit evaluates the radix bit numbers. +// To ensure that one partition of inner relation, one hash table, one partition +// of outer relation and the join result of these two partitions fit into the L2 +// cache when the input data obeys the uniform distribution, we suppose every +// sub-partition of inner relation using three quarters of the L2 cache size. +func (e *RadixHashJoinExec) evalRadixBit() { + sv := e.ctx.GetSessionVars() + innerResultSize := float64(e.innerResult.GetMemTracker().BytesConsumed()) + l2CacheSize := float64(sv.L2CacheSize) * 3 / 4 + radixBitsNum := math.Ceil(math.Log2(innerResultSize / l2CacheSize)) + if radixBitsNum <= 0 { + radixBitsNum = 1 + } + // Take the rightmost radixBitsNum bits as the bitmask. 
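+ // For example, with the numbers used in TestRadixPartition above
+ // (innerResultSize = 14400, L2CacheSize = 100): radixBitsNum =
+ // ceil(log2(14400 / 75)) = 8, so radixBits becomes 0x00ff and
+ // 1<<8 = 256 partitions are allocated.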
+ e.radixBits = ^(math.MaxUint32 << uint(radixBitsNum)) + e.innerParts = make([]partition, 1< 0 { - fmt.Fprintf(&buf, " /* !40100 DEFAULT CHARACTER SET %s */", s) + fmt.Fprintf(&buf, " /*!40100 DEFAULT CHARACTER SET %s */", s) } e.appendRow([]interface{}{db.Name.O, buf.String()}) diff --git a/executor/show_test.go b/executor/show_test.go index 89dba5b3f6af4..e5175810206a9 100644 --- a/executor/show_test.go +++ b/executor/show_test.go @@ -173,6 +173,9 @@ func (s *testSuite2) TestShow2(c *C) { tk.MustExec(`create table if not exists t (c int) comment '注释'`) tk.MustExec("create or replace view v as select * from t") tk.MustQuery(`show columns from t`).Check(testutil.RowsWithSep(",", "c,int(11),YES,,,")) + tk.MustQuery(`describe t`).Check(testutil.RowsWithSep(",", "c,int(11),YES,,,")) + tk.MustQuery(`show columns from v`).Check(testutil.RowsWithSep(",", "c,int(11),YES,,,")) + tk.MustQuery(`describe v`).Check(testutil.RowsWithSep(",", "c,int(11),YES,,,")) tk.MustQuery("show collation where Charset = 'utf8' and Collation = 'utf8_bin'").Check(testutil.RowsWithSep(",", "utf8_bin,utf8,83,,Yes,1")) tk.MustQuery("show tables").Check(testkit.Rows("t", "v")) diff --git a/executor/simple_test.go b/executor/simple_test.go index 73e0958a5e5c4..fb1f669d50090 100644 --- a/executor/simple_test.go +++ b/executor/simple_test.go @@ -15,6 +15,7 @@ package executor_test import ( "context" + "github.com/pingcap/tidb/planner/core" . "github.com/pingcap/check" "github.com/pingcap/parser/auth" @@ -302,3 +303,12 @@ func (s *testSuite3) TestFlushTables(c *C) { c.Check(err, NotNil) } + +func (s *testSuite3) TestUseDB(c *C) { + tk := testkit.NewTestKit(c, s.store) + _, err := tk.Exec("USE test") + c.Check(err, IsNil) + + _, err = tk.Exec("USE ``") + c.Assert(terror.ErrorEqual(core.ErrNoDB, err), IsTrue, Commentf("err %v", err)) +} diff --git a/executor/update_test.go b/executor/update_test.go new file mode 100644 index 0000000000000..fda7bc0c35c07 --- /dev/null +++ b/executor/update_test.go @@ -0,0 +1,92 @@ +// Copyright 2019 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package executor_test + +import ( + "flag" + "fmt" + + . 
"github.com/pingcap/check" + "github.com/pingcap/parser" + "github.com/pingcap/tidb/domain" + "github.com/pingcap/tidb/kv" + "github.com/pingcap/tidb/session" + "github.com/pingcap/tidb/store/mockstore" + "github.com/pingcap/tidb/store/mockstore/mocktikv" + "github.com/pingcap/tidb/util/mock" + "github.com/pingcap/tidb/util/testkit" + "github.com/pingcap/tidb/util/testleak" +) + +type testUpdateSuite struct { + cluster *mocktikv.Cluster + mvccStore mocktikv.MVCCStore + store kv.Storage + domain *domain.Domain + *parser.Parser + ctx *mock.Context +} + +func (s *testUpdateSuite) SetUpSuite(c *C) { + testleak.BeforeTest() + s.Parser = parser.New() + flag.Lookup("mockTikv") + useMockTikv := *mockTikv + if useMockTikv { + s.cluster = mocktikv.NewCluster() + mocktikv.BootstrapWithSingleStore(s.cluster) + s.mvccStore = mocktikv.MustNewMVCCStore() + store, err := mockstore.NewMockTikvStore( + mockstore.WithCluster(s.cluster), + mockstore.WithMVCCStore(s.mvccStore), + ) + c.Assert(err, IsNil) + s.store = store + session.SetSchemaLease(0) + session.SetStatsLease(0) + } + d, err := session.BootstrapSession(s.store) + c.Assert(err, IsNil) + d.SetStatsUpdating(true) + s.domain = d +} + +func (s *testUpdateSuite) TearDownSuite(c *C) { + s.domain.Close() + s.store.Close() + testleak.AfterTest(c, TestLeakCheckCnt)() +} + +func (s *testUpdateSuite) TearDownTest(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + r := tk.MustQuery("show tables") + for _, tb := range r.Rows() { + tableName := tb[0] + tk.MustExec(fmt.Sprintf("drop table %v", tableName)) + } +} + +func (s *testUpdateSuite) TestUpdateGenColInTxn(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec(`create table t(a bigint, b bigint as (a+1));`) + tk.MustExec(`begin;`) + tk.MustExec(`insert into t(a) values(1);`) + err := tk.ExecToErr(`update t set b=6 where b=2;`) + c.Assert(err.Error(), Equals, "[planner:3105]The value specified for generated column 'b' in table 't' is not allowed.") + tk.MustExec(`commit;`) + tk.MustQuery(`select * from t;`).Check(testkit.Rows( + `1 2`)) +} diff --git a/executor/write.go b/executor/write.go index 4e231355b4c54..58e591a0bb424 100644 --- a/executor/write.go +++ b/executor/write.go @@ -142,7 +142,7 @@ func updateRecord(ctx sessionctx.Context, h int64, oldData, newData []types.Datu } // the `affectedRows` is increased when adding new record. 
newHandle, err = t.AddRecord(ctx, newData, - &table.AddRecordOpt{CreateIdxOpt: table.CreateIdxOpt{SkipHandleCheck: sc.DupKeyAsWarning}}) + &table.AddRecordOpt{CreateIdxOpt: table.CreateIdxOpt{SkipHandleCheck: sc.DupKeyAsWarning}, IsUpdate: true}) if err != nil { return false, false, 0, errors.Trace(err) } diff --git a/executor/write_test.go b/executor/write_test.go index e01aaaa4fac48..2096f1eb13a4b 100644 --- a/executor/write_test.go +++ b/executor/write_test.go @@ -237,6 +237,27 @@ func (s *testSuite2) TestInsert(c *C) { tk.MustExec("insert into test values(2, 3)") tk.MustQuery("select * from test use index (id) where id = 2").Check(testkit.Rows("2 2", "2 3")) + // issue 6360 + tk.MustExec("drop table if exists t;") + tk.MustExec("create table t(a bigint unsigned);") + tk.MustExec(" set @orig_sql_mode = @@sql_mode; set @@sql_mode = 'strict_all_tables';") + _, err = tk.Exec("insert into t value (-1);") + c.Assert(types.ErrWarnDataOutOfRange.Equal(err), IsTrue) + tk.MustExec("set @@sql_mode = '';") + tk.MustExec("insert into t value (-1);") + // TODO: the following warning messages are not consistent with MySQL, fix them in the future PRs + tk.MustQuery("show warnings").Check(testkit.Rows("Warning 1690 constant -1 overflows bigint")) + tk.MustExec("insert into t select -1;") + tk.MustQuery("show warnings").Check(testkit.Rows("Warning 1690 constant -1 overflows bigint")) + tk.MustExec("insert into t select cast(-1 as unsigned);") + tk.MustExec("insert into t value (-1.111);") + tk.MustQuery("show warnings").Check(testkit.Rows("Warning 1690 constant -1 overflows bigint")) + tk.MustExec("insert into t value ('-1.111');") + tk.MustQuery("show warnings").Check(testkit.Rows("Warning 1690 BIGINT UNSIGNED value is out of range in '-1'")) + r = tk.MustQuery("select * from t;") + r.Check(testkit.Rows("0", "0", "18446744073709551615", "0", "0")) + tk.MustExec("set @@sql_mode = @orig_sql_mode;") + // issue 6424 tk.MustExec("drop table if exists t") tk.MustExec("create table t(a time(6))") @@ -610,7 +631,7 @@ commit;` create table m (id int primary key auto_increment, code int unique); insert tmp (code) values (1); insert tmp (code) values (1); - set tidb_max_chunk_size=1; + set tidb_init_chunk_size=1; insert m (code) select code from tmp on duplicate key update code = values(code);` tk.MustExec(testSQL) testSQL = `select * from m;` @@ -1920,6 +1941,26 @@ func (s *testSuite2) TestLoadDataIgnoreLines(c *C) { checkCases(tests, ld, c, tk, ctx, selectSQL, deleteSQL) } +// related to issue 6360 +func (s *testSuite2) TestLoadDataOverflowBigintUnsigned(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test; drop table if exists load_data_test;") + tk.MustExec("CREATE TABLE load_data_test (a bigint unsigned);") + tk.MustExec("load data local infile '/tmp/nonexistence.csv' into table load_data_test") + ctx := tk.Se.(sessionctx.Context) + ld, ok := ctx.Value(executor.LoadDataVarKey).(*executor.LoadDataInfo) + c.Assert(ok, IsTrue) + defer ctx.SetValue(executor.LoadDataVarKey, nil) + c.Assert(ld, NotNil) + tests := []testCase{ + {nil, []byte("-1\n-18446744073709551615\n-18446744073709551616\n"), []string{"0", "0", "0"}, nil, "Records: 3 Deleted: 0 Skipped: 0 Warnings: 3"}, + {nil, []byte("-9223372036854775809\n18446744073709551616\n"), []string{"0", "18446744073709551615"}, nil, "Records: 2 Deleted: 0 Skipped: 0 Warnings: 2"}, + } + deleteSQL := "delete from load_data_test" + selectSQL := "select * from load_data_test;" + checkCases(tests, ld, c, tk, ctx, selectSQL, deleteSQL) +} + func (s 
*testSuite2) TestBatchInsertDelete(c *C) { originLimit := atomic.LoadUint64(&kv.TxnEntryCountLimit) defer func() { diff --git a/expression/aggregation/window_func.go b/expression/aggregation/window_func.go new file mode 100644 index 0000000000000..9555f900bdb73 --- /dev/null +++ b/expression/aggregation/window_func.go @@ -0,0 +1,29 @@ +// Copyright 2018 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package aggregation + +import ( + "github.com/pingcap/tidb/expression" + "github.com/pingcap/tidb/sessionctx" +) + +// WindowFuncDesc describes a window function signature, only used in planner. +type WindowFuncDesc struct { + baseFuncDesc +} + +// NewWindowFuncDesc creates a window function signature descriptor. +func NewWindowFuncDesc(ctx sessionctx.Context, name string, args []expression.Expression) *WindowFuncDesc { + return &WindowFuncDesc{newBaseFuncDesc(ctx, name, args)} +} diff --git a/expression/bench_test.go b/expression/bench_test.go index bc4a58ed81b45..79a481c1e6c8c 100644 --- a/expression/bench_test.go +++ b/expression/bench_test.go @@ -45,6 +45,7 @@ func (h *benchHelper) init() { h.ctx = mock.NewContext() h.ctx.GetSessionVars().StmtCtx.TimeZone = time.Local + h.ctx.GetSessionVars().InitChunkSize = 32 h.ctx.GetSessionVars().MaxChunkSize = numRows h.inputTypes = make([]*types.FieldType, 0, 10) diff --git a/expression/builtin_cast.go b/expression/builtin_cast.go index e1550ef7c3012..80163ef3699b1 100644 --- a/expression/builtin_cast.go +++ b/expression/builtin_cast.go @@ -464,7 +464,8 @@ func (b *builtinCastIntAsRealSig) evalReal(row chunk.Row) (res float64, isNull b res = 0 } else { var uVal uint64 - uVal, err = types.ConvertIntToUint(val, types.UnsignedUpperBound[mysql.TypeLonglong], mysql.TypeLonglong) + sc := b.ctx.GetSessionVars().StmtCtx + uVal, err = types.ConvertIntToUint(sc, val, types.UnsignedUpperBound[mysql.TypeLonglong], mysql.TypeLonglong) res = float64(uVal) } return res, false, err @@ -491,7 +492,8 @@ func (b *builtinCastIntAsDecimalSig) evalDecimal(row chunk.Row) (res *types.MyDe res = &types.MyDecimal{} } else { var uVal uint64 - uVal, err = types.ConvertIntToUint(val, types.UnsignedUpperBound[mysql.TypeLonglong], mysql.TypeLonglong) + sc := b.ctx.GetSessionVars().StmtCtx + uVal, err = types.ConvertIntToUint(sc, val, types.UnsignedUpperBound[mysql.TypeLonglong], mysql.TypeLonglong) if err != nil { return res, false, err } @@ -520,7 +522,8 @@ func (b *builtinCastIntAsStringSig) evalString(row chunk.Row) (res string, isNul res = strconv.FormatInt(val, 10) } else { var uVal uint64 - uVal, err = types.ConvertIntToUint(val, types.UnsignedUpperBound[mysql.TypeLonglong], mysql.TypeLonglong) + sc := b.ctx.GetSessionVars().StmtCtx + uVal, err = types.ConvertIntToUint(sc, val, types.UnsignedUpperBound[mysql.TypeLonglong], mysql.TypeLonglong) if err != nil { return res, false, err } @@ -750,7 +753,8 @@ func (b *builtinCastRealAsIntSig) evalInt(row chunk.Row) (res int64, isNull bool res = 0 } else { var uintVal uint64 - uintVal, err = types.ConvertFloatToUint(val, types.UnsignedUpperBound[mysql.TypeLonglong], mysql.TypeDouble) 
+ sc := b.ctx.GetSessionVars().StmtCtx + uintVal, err = types.ConvertFloatToUint(sc, val, types.UnsignedUpperBound[mysql.TypeLonglong], mysql.TypeDouble) res = int64(uintVal) } return res, isNull, err @@ -1001,7 +1005,7 @@ func (b *builtinCastDecimalAsDurationSig) Clone() builtinFunc { func (b *builtinCastDecimalAsDurationSig) evalDuration(row chunk.Row) (res types.Duration, isNull bool, err error) { val, isNull, err := b.args[0].EvalDecimal(b.ctx, row) if isNull || err != nil { - return res, false, err + return res, true, err } res, err = types.ParseDuration(b.ctx.GetSessionVars().StmtCtx, string(val.ToString()), b.tp.Decimal) if types.ErrTruncatedWrongVal.Equal(err) { diff --git a/expression/builtin_encryption.go b/expression/builtin_encryption.go index cf4744efb8d0e..85083cc5cfeb3 100644 --- a/expression/builtin_encryption.go +++ b/expression/builtin_encryption.go @@ -93,6 +93,9 @@ var aesModes = map[string]*aesModeAttr{ "aes-128-cbc": {"cbc", 16, true}, "aes-192-cbc": {"cbc", 24, true}, "aes-256-cbc": {"cbc", 32, true}, + "aes-128-ofb": {"ofb", 16, true}, + "aes-192-ofb": {"ofb", 24, true}, + "aes-256-ofb": {"ofb", 32, true}, "aes-128-cfb": {"cfb", 16, true}, "aes-192-cfb": {"cfb", 24, true}, "aes-256-cfb": {"cfb", 32, true}, @@ -212,6 +215,8 @@ func (b *builtinAesDecryptIVSig) evalString(row chunk.Row) (string, bool, error) switch b.modeName { case "cbc": plainText, err = encrypt.AESDecryptWithCBC([]byte(cryptStr), key, []byte(iv)) + case "ofb": + plainText, err = encrypt.AESDecryptWithOFB([]byte(cryptStr), key, []byte(iv)) case "cfb": plainText, err = encrypt.AESDecryptWithCFB([]byte(cryptStr), key, []byte(iv)) default: @@ -337,6 +342,8 @@ func (b *builtinAesEncryptIVSig) evalString(row chunk.Row) (string, bool, error) switch b.modeName { case "cbc": cipherText, err = encrypt.AESEncryptWithCBC([]byte(str), key, []byte(iv)) + case "ofb": + cipherText, err = encrypt.AESEncryptWithOFB([]byte(str), key, []byte(iv)) case "cfb": cipherText, err = encrypt.AESEncryptWithCFB([]byte(str), key, []byte(iv)) default: diff --git a/expression/builtin_encryption_test.go b/expression/builtin_encryption_test.go index e66b9053dc2cb..4a12f14f1a9af 100644 --- a/expression/builtin_encryption_test.go +++ b/expression/builtin_encryption_test.go @@ -101,6 +101,13 @@ var aesTests = []struct { {"aes-256-cbc", "pingcap", []interface{}{"1234567890123456", "1234567890123456"}, "5D0E22C1E77523AEF5C3E10B65653C8F"}, {"aes-256-cbc", "pingcap", []interface{}{"12345678901234561234567890123456", "1234567890123456"}, "A26BA27CA4BE9D361D545AA84A17002D"}, {"aes-256-cbc", "pingcap", []interface{}{"1234567890123456", "12345678901234561234567890123456"}, "5D0E22C1E77523AEF5C3E10B65653C8F"}, + // test for ofb + {"aes-128-ofb", "pingcap", []interface{}{"1234567890123456", "1234567890123456"}, "0515A36BBF3DE0"}, + {"aes-128-ofb", "pingcap", []interface{}{"123456789012345678901234", "1234567890123456"}, "C2A93A93818546"}, + {"aes-192-ofb", "pingcap", []interface{}{"1234567890123456", "1234567890123456"}, "FE09DCCF14D458"}, + {"aes-256-ofb", "pingcap", []interface{}{"1234567890123456", "1234567890123456"}, "2E70FCAC0C0834"}, + {"aes-256-ofb", "pingcap", []interface{}{"12345678901234561234567890123456", "1234567890123456"}, "83E2B30A71F011"}, + {"aes-256-ofb", "pingcap", []interface{}{"1234567890123456", "12345678901234561234567890123456"}, "2E70FCAC0C0834"}, // test for cfb {"aes-128-cfb", "pingcap", []interface{}{"1234567890123456", "1234567890123456"}, "0515A36BBF3DE0"}, {"aes-128-cfb", "pingcap", 
[]interface{}{"123456789012345678901234", "1234567890123456"}, "C2A93A93818546"}, diff --git a/expression/builtin_json.go b/expression/builtin_json.go index 4f5486e9b1af3..7fb2aec1e2164 100644 --- a/expression/builtin_json.go +++ b/expression/builtin_json.go @@ -403,6 +403,11 @@ func (b *builtinJSONMergeSig) evalJSON(row chunk.Row) (res json.BinaryJSON, isNu values = append(values, value) } res = json.MergeBinary(values) + // function "JSON_MERGE" is deprecated since MySQL 5.7.22. Synonym for function "JSON_MERGE_PRESERVE". + // See https://dev.mysql.com/doc/refman/5.7/en/json-modification-functions.html#function_json-merge + if b.pbCode == tipb.ScalarFuncSig_JsonMergeSig { + b.ctx.GetSessionVars().StmtCtx.AppendWarning(errDeprecatedSyntaxNoReplacement.GenWithStackByArgs("JSON_MERGE")) + } return res, false, nil } @@ -720,7 +725,17 @@ type jsonMergePreserveFunctionClass struct { } func (c *jsonMergePreserveFunctionClass) getFunction(ctx sessionctx.Context, args []Expression) (builtinFunc, error) { - return nil, errFunctionNotExists.GenWithStackByArgs("FUNCTION", "JSON_MERGE_PRESERVE") + if err := c.verifyArgs(args); err != nil { + return nil, err + } + argTps := make([]types.EvalType, 0, len(args)) + for range args { + argTps = append(argTps, types.ETJson) + } + bf := newBaseBuiltinFuncWithTp(ctx, args, types.ETJson, argTps...) + sig := &builtinJSONMergeSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_JsonMergePreserveSig) + return sig, nil } type jsonPrettyFunctionClass struct { diff --git a/expression/builtin_json_test.go b/expression/builtin_json_test.go index 7b78081803a7e..8f6888e8c035c 100644 --- a/expression/builtin_json_test.go +++ b/expression/builtin_json_test.go @@ -187,6 +187,39 @@ func (s *testEvaluatorSuite) TestJSONMerge(c *C) { j2 := d.GetMysqlJSON() cmp := json.CompareBinary(j1, j2) c.Assert(cmp, Equals, 0, Commentf("got %v expect %v", j1.String(), j2.String())) + case nil: + c.Assert(d.IsNull(), IsTrue) + } + } +} + +func (s *testEvaluatorSuite) TestJSONMergePreserve(c *C) { + defer testleak.AfterTest(c)() + fc := funcs[ast.JSONMergePreserve] + tbl := []struct { + Input []interface{} + Expected interface{} + }{ + {[]interface{}{nil, nil}, nil}, + {[]interface{}{`{}`, `[]`}, `[{}]`}, + {[]interface{}{`{}`, `[]`, `3`, `"4"`}, `[{}, 3, "4"]`}, + } + for _, t := range tbl { + args := types.MakeDatums(t.Input...) 
+ f, err := fc.getFunction(s.ctx, s.datumsToConstants(args)) + c.Assert(err, IsNil) + d, err := evalBuiltinFunc(f, chunk.Row{}) + c.Assert(err, IsNil) + + switch x := t.Expected.(type) { + case string: + j1, err := json.ParseBinaryFromString(x) + c.Assert(err, IsNil) + j2 := d.GetMysqlJSON() + cmp := json.CompareBinary(j1, j2) + c.Assert(cmp, Equals, 0, Commentf("got %v expect %v", j1.String(), j2.String())) + case nil: + c.Assert(d.IsNull(), IsTrue) } } } diff --git a/expression/constant_test.go b/expression/constant_test.go index f07047a5f9bbe..e0eba43757412 100644 --- a/expression/constant_test.go +++ b/expression/constant_test.go @@ -32,12 +32,16 @@ var _ = Suite(&testExpressionSuite{}) type testExpressionSuite struct{} func newColumn(id int) *Column { + return newColumnWithType(id, types.NewFieldType(mysql.TypeLonglong)) +} + +func newColumnWithType(id int, t *types.FieldType) *Column { return &Column{ UniqueID: int64(id), ColName: model.NewCIStr(fmt.Sprint(id)), TblName: model.NewCIStr("t"), DBName: model.NewCIStr("test"), - RetType: types.NewFieldType(mysql.TypeLonglong), + RetType: t, } } @@ -48,6 +52,18 @@ func newLonglong(value int64) *Constant { } } +func newDate(year, month, day int) *Constant { + var tmp types.Datum + tmp.SetMysqlTime(types.Time{ + Time: types.FromDate(year, month, day, 0, 0, 0, 0), + Type: mysql.TypeDate, + }) + return &Constant{ + Value: tmp, + RetType: types.NewFieldType(mysql.TypeDate), + } +} + func newFunction(funcName string, args ...Expression) Expression { typeLong := types.NewFieldType(mysql.TypeLonglong) return NewFunctionInternal(mock.NewContext(), funcName, typeLong, args...) @@ -180,6 +196,91 @@ func (*testExpressionSuite) TestConstantPropagation(c *C) { } } +func (*testExpressionSuite) TestConstraintPropagation(c *C) { + defer testleak.AfterTest(c)() + col1 := newColumnWithType(1, types.NewFieldType(mysql.TypeDate)) + tests := []struct { + solver constraintSolver + conditions []Expression + result string + }{ + // Don't propagate this any more, because it makes the code more complex but not + // useful for partition pruning. 
+ // { + // solver: newConstraintSolver(ruleColumnGTConst), + // conditions: []Expression{ + // newFunction(ast.GT, newColumn(0), newLonglong(5)), + // newFunction(ast.GT, newColumn(0), newLonglong(7)), + // }, + // result: "gt(test.t.0, 7)", + // }, + { + solver: newConstraintSolver(ruleColumnOPConst), + conditions: []Expression{ + newFunction(ast.GT, newColumn(0), newLonglong(5)), + newFunction(ast.LT, newColumn(0), newLonglong(5)), + }, + result: "0", + }, + { + solver: newConstraintSolver(ruleColumnOPConst), + conditions: []Expression{ + newFunction(ast.GT, newColumn(0), newLonglong(7)), + newFunction(ast.LT, newColumn(0), newLonglong(5)), + }, + result: "0", + }, + { + solver: newConstraintSolver(ruleColumnOPConst), + // col1 > '2018-12-11' and to_days(col1) < 5 => false + conditions: []Expression{ + newFunction(ast.GT, col1, newDate(2018, 12, 11)), + newFunction(ast.LT, newFunction(ast.ToDays, col1), newLonglong(5)), + }, + result: "0", + }, + { + solver: newConstraintSolver(ruleColumnOPConst), + conditions: []Expression{ + newFunction(ast.LT, newColumn(0), newLonglong(5)), + newFunction(ast.GT, newColumn(0), newLonglong(5)), + }, + result: "0", + }, + { + solver: newConstraintSolver(ruleColumnOPConst), + conditions: []Expression{ + newFunction(ast.LT, newColumn(0), newLonglong(5)), + newFunction(ast.GT, newColumn(0), newLonglong(7)), + }, + result: "0", + }, + { + solver: newConstraintSolver(ruleColumnOPConst), + // col1 < '2018-12-11' and to_days(col1) > 737999 => false + conditions: []Expression{ + newFunction(ast.LT, col1, newDate(2018, 12, 11)), + newFunction(ast.GT, newFunction(ast.ToDays, col1), newLonglong(737999)), + }, + result: "0", + }, + } + for _, tt := range tests { + ctx := mock.NewContext() + conds := make([]Expression, 0, len(tt.conditions)) + for _, cd := range tt.conditions { + conds = append(conds, FoldConstant(cd)) + } + newConds := tt.solver.Solve(ctx, conds) + var result []string + for _, v := range newConds { + result = append(result, v.String()) + } + sort.Strings(result) + c.Assert(strings.Join(result, ", "), Equals, tt.result, Commentf("different for expr %s", tt.conditions)) + } +} + func (*testExpressionSuite) TestConstantFolding(c *C) { defer testleak.AfterTest(c)() tests := []struct { diff --git a/expression/constraint_propagation.go b/expression/constraint_propagation.go index 986576aa3261e..aafbf1660d306 100644 --- a/expression/constraint_propagation.go +++ b/expression/constraint_propagation.go @@ -16,6 +16,7 @@ package expression import ( "bytes" + "github.com/pingcap/parser/ast" "github.com/pingcap/parser/mysql" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/types" @@ -78,24 +79,35 @@ func newExprSet(conditions []Expression) *exprSet { return &exprs } +type constraintSolver []constraintPropagateRule + +func newConstraintSolver(rules ...constraintPropagateRule) constraintSolver { + return constraintSolver(rules) +} + type pgSolver2 struct{} -// PropagateConstant propagate constant values of deterministic predicates in a condition. func (s pgSolver2) PropagateConstant(ctx sessionctx.Context, conditions []Expression) []Expression { + solver := newConstraintSolver(ruleConstantFalse, ruleColumnEQConst) + return solver.Solve(ctx, conditions) +} + +// Solve propagate constraint according to the rules in the constraintSolver. 
+func (s constraintSolver) Solve(ctx sessionctx.Context, conditions []Expression) []Expression { exprs := newExprSet(conditions) s.fixPoint(ctx, exprs) return exprs.Slice() } -// fixPoint is the core of the constant propagation algorithm. +// fixPoint is the core of the constraint propagation algorithm. // It will iterate the expression set over and over again, pick two expressions, // apply one to another. // If new conditions can be inferred, they will be append into the expression set. // Until no more conditions can be inferred from the set, the algorithm finish. -func (s pgSolver2) fixPoint(ctx sessionctx.Context, exprs *exprSet) { +func (s constraintSolver) fixPoint(ctx sessionctx.Context, exprs *exprSet) { for { saveLen := len(exprs.data) - iterOnce(ctx, exprs) + s.iterOnce(ctx, exprs) if saveLen == len(exprs.data) { break } @@ -104,7 +116,7 @@ func (s pgSolver2) fixPoint(ctx sessionctx.Context, exprs *exprSet) { } // iterOnce picks two expressions from the set, try to propagate new conditions from them. -func iterOnce(ctx sessionctx.Context, exprs *exprSet) { +func (s constraintSolver) iterOnce(ctx sessionctx.Context, exprs *exprSet) { for i := 0; i < len(exprs.data); i++ { if exprs.tombstone[i] { continue @@ -116,24 +128,19 @@ func iterOnce(ctx sessionctx.Context, exprs *exprSet) { if i == j { continue } - solve(ctx, i, j, exprs) + s.solve(ctx, i, j, exprs) } } } // solve uses exprs[i] exprs[j] to propagate new conditions. -func solve(ctx sessionctx.Context, i, j int, exprs *exprSet) { - for _, rule := range rules { +func (s constraintSolver) solve(ctx sessionctx.Context, i, j int, exprs *exprSet) { + for _, rule := range s { rule(ctx, i, j, exprs) } } -type constantPropagateRule func(ctx sessionctx.Context, i, j int, exprs *exprSet) - -var rules = []constantPropagateRule{ - ruleConstantFalse, - ruleColumnEQConst, -} +type constraintPropagateRule func(ctx sessionctx.Context, i, j int, exprs *exprSet) // ruleConstantFalse propagates from CNF condition that false plus anything returns false. // false, a = 1, b = c ... => false @@ -164,3 +171,134 @@ func ruleColumnEQConst(ctx sessionctx.Context, i, j int, exprs *exprSet) { } } } + +// ruleColumnOPConst propagates the "column OP const" condition. +func ruleColumnOPConst(ctx sessionctx.Context, i, j int, exprs *exprSet) { + cond := exprs.data[i] + f1, ok := cond.(*ScalarFunction) + if !ok { + return + } + if f1.FuncName.L != ast.GE && f1.FuncName.L != ast.GT && + f1.FuncName.L != ast.LE && f1.FuncName.L != ast.LT { + return + } + OP1 := f1.FuncName.L + + var col1 *Column + var con1 *Constant + col1, ok = f1.GetArgs()[0].(*Column) + if !ok { + return + } + con1, ok = f1.GetArgs()[1].(*Constant) + if !ok { + return + } + + expr := exprs.data[j] + f2, ok := expr.(*ScalarFunction) + if !ok { + return + } + + // The simple case: + // col >= c1, col < c2, c1 >= c2 => false + // col >= c1, col <= c2, c1 > c2 => false + // col >= c1, col OP c2, c1 ^OP c2, where OP in [< , <=] => false + // col OP1 c1 where OP1 in [>= , <], col OP2 c2 where OP1 opsite OP2, c1 ^OP2 c2 => false + // + // The extended case: + // col >= c1, f(col) < c2, f is monotonous, f(c1) >= c2 => false + // + // Proof: + // col > c1, f is monotonous => f(col) > f(c1) + // f(col) > f(c1), f(col) < c2, f(c1) >= c2 => false + OP2 := f2.FuncName.L + if !opsiteOP(OP1, OP2) { + return + } + + con2, ok := f2.GetArgs()[1].(*Constant) + if !ok { + return + } + arg0 := f2.GetArgs()[0] + // The simple case. 
+ var fc1 Expression + col2, ok := arg0.(*Column) + if ok { + fc1 = con1 + } else { + // The extended case. + scalarFunc, ok := arg0.(*ScalarFunction) + if !ok { + return + } + _, ok = monotoneIncFuncs[scalarFunc.FuncName.L] + if !ok { + return + } + col2, ok = scalarFunc.GetArgs()[0].(*Column) + if !ok { + return + } + var err error + fc1, err = NewFunction(ctx, scalarFunc.FuncName.L, scalarFunc.RetType, con1) + if err != nil { + log.Warn(err) + return + } + } + if !col1.Equal(ctx, col2) { + return + } + v, isNull, err := compareConstant(ctx, negOP(OP2), fc1, con2) + if err != nil { + log.Warn(err) + return + } + if !isNull && v > 0 { + exprs.SetConstFalse() + } + return +} + +// opsiteOP the opsite direction of a compare operation, used in ruleColumnOPConst. +func opsiteOP(op1, op2 string) bool { + switch { + case op1 == ast.GE || op1 == ast.GT: + return op2 == ast.LT || op2 == ast.LE + case op1 == ast.LE || op1 == ast.LT: + return op2 == ast.GT || op2 == ast.GE + } + return false +} + +func negOP(cmp string) string { + switch cmp { + case ast.LT: + return ast.GE + case ast.LE: + return ast.GT + case ast.GT: + return ast.LE + case ast.GE: + return ast.LT + } + return "" +} + +// monotoneIncFuncs are those functions that for any x y, if x > y => f(x) > f(y) +var monotoneIncFuncs = map[string]struct{}{ + ast.ToDays: {}, +} + +// compareConstant compares two expressions. c1 and c2 should be constant with the same type. +func compareConstant(ctx sessionctx.Context, fn string, c1, c2 Expression) (int64, bool, error) { + cmp, err := NewFunction(ctx, fn, types.NewFieldType(mysql.TypeTiny), c1, c2) + if err != nil { + return 0, false, err + } + return cmp.EvalInt(ctx, chunk.Row{}) +} diff --git a/expression/errors.go b/expression/errors.go index 8a18acfa2ad5a..1ebc4ffdaba7b 100644 --- a/expression/errors.go +++ b/expression/errors.go @@ -72,7 +72,7 @@ func handleInvalidTimeError(ctx sessionctx.Context, err error) error { return err } sc := ctx.GetSessionVars().StmtCtx - if ctx.GetSessionVars().StrictSQLMode && (sc.InInsertStmt || sc.InUpdateOrDeleteStmt) { + if ctx.GetSessionVars().StrictSQLMode && (sc.InInsertStmt || sc.InUpdateStmt || sc.InDeleteStmt) { return err } sc.AppendWarning(err) @@ -82,7 +82,7 @@ func handleInvalidTimeError(ctx sessionctx.Context, err error) error { // handleDivisionByZeroError reports error or warning depend on the context. 
func handleDivisionByZeroError(ctx sessionctx.Context) error { sc := ctx.GetSessionVars().StmtCtx - if sc.InInsertStmt || sc.InUpdateOrDeleteStmt { + if sc.InInsertStmt || sc.InUpdateStmt || sc.InDeleteStmt { if !ctx.GetSessionVars().SQLMode.HasErrorForDivisionByZeroMode() { return nil } diff --git a/expression/integration_test.go b/expression/integration_test.go index 5e5f73d7770f9..cf6d02b574954 100644 --- a/expression/integration_test.go +++ b/expression/integration_test.go @@ -1006,6 +1006,21 @@ func (s *testIntegrationSuite) TestEncryptionBuiltin(c *C) { result.Check(testkit.Rows("341672829F84CB6B0BE690FEC4C4DAE9 341672829F84CB6B0BE690FEC4C4DAE9 D43734E147A12BB96C6897C4BBABA283 16F2C972411948DCEF3659B726D2CCB04AD1379A1A367FA64242058A50211B67 41E71D0C58967C1F50EEC074523946D1 1117D292E2D39C3EAA3B435371BE56FC 8ACB7ECC0883B672D7BD1CFAA9FA5FAF5B731ADE978244CD581F114D591C2E7E D2B13C30937E3251AEDA73859BA32E4B 2CF4A6051FF248A67598A17AA2C17267")) result = tk.MustQuery("select HEX(AES_ENCRYPT('123', 'foobar', '1234567890123456')), HEX(AES_ENCRYPT(123, 'foobar', '1234567890123456')), HEX(AES_ENCRYPT('', 'foobar', '1234567890123456')), HEX(AES_ENCRYPT('你好', 'foobar', '1234567890123456')), AES_ENCRYPT(NULL, 'foobar', '1234567890123456')") result.Check(testkit.Rows(`80D5646F07B4654B05A02D9085759770 80D5646F07B4654B05A02D9085759770 B3C14BA15030D2D7E99376DBE011E752 0CD2936EE4FEC7A8CDF6208438B2BC05 `)) + tk.MustExec("SET block_encryption_mode='aes-128-ofb';") + result = tk.MustQuery("select HEX(AES_ENCRYPT(a, 'key', '1234567890123456')), HEX(AES_ENCRYPT(b, 'key', '1234567890123456')), HEX(AES_ENCRYPT(c, 'key', '1234567890123456')), HEX(AES_ENCRYPT(d, 'key', '1234567890123456')), HEX(AES_ENCRYPT(e, 'key', '1234567890123456')), HEX(AES_ENCRYPT(f, 'key', '1234567890123456')), HEX(AES_ENCRYPT(g, 'key', '1234567890123456')), HEX(AES_ENCRYPT(h, 'key', '1234567890123456')), HEX(AES_ENCRYPT(i, 'key', '1234567890123456')) from t") + result.Check(testkit.Rows("40 40 40C35C 40DD5EBDFCAA397102386E27DDF97A39ECCEC5 43DF55BAE0A0386D 78 47DC5D8AD19A085C32094E16EFC34A08D6FEF459 46D5 06840BE8")) + result = tk.MustQuery("select HEX(AES_ENCRYPT('123', 'foobar', '1234567890123456')), HEX(AES_ENCRYPT(123, 'foobar', '1234567890123456')), HEX(AES_ENCRYPT('', 'foobar', '1234567890123456')), HEX(AES_ENCRYPT('你好', 'foobar', '1234567890123456')), AES_ENCRYPT(NULL, 'foobar', '1234567890123456')") + result.Check(testkit.Rows(`48E38A 48E38A 9D6C199101C3 `)) + tk.MustExec("SET block_encryption_mode='aes-192-ofb';") + result = tk.MustQuery("select HEX(AES_ENCRYPT(a, 'key', '1234567890123456')), HEX(AES_ENCRYPT(b, 'key', '1234567890123456')), HEX(AES_ENCRYPT(c, 'key', '1234567890123456')), HEX(AES_ENCRYPT(d, 'key', '1234567890123456')), HEX(AES_ENCRYPT(e, 'key', '1234567890123456')), HEX(AES_ENCRYPT(f, 'key', '1234567890123456')), HEX(AES_ENCRYPT(g, 'key', '1234567890123456')), HEX(AES_ENCRYPT(h, 'key', '1234567890123456')), HEX(AES_ENCRYPT(i, 'key', '1234567890123456')) from t") + result.Check(testkit.Rows("4B 4B 4B573F 4B493D42572E6477233A429BF3E0AD39DB816D 484B36454B24656B 73 4C483E757A1E555A130B62AAC1DA9D08E1B15C47 4D41 0D106817")) + result = tk.MustQuery("select HEX(AES_ENCRYPT('123', 'foobar', '1234567890123456')), HEX(AES_ENCRYPT(123, 'foobar', '1234567890123456')), HEX(AES_ENCRYPT('', 'foobar', '1234567890123456')), HEX(AES_ENCRYPT('你好', 'foobar', '1234567890123456')), AES_ENCRYPT(NULL, 'foobar', '1234567890123456')") + result.Check(testkit.Rows(`3A76B0 3A76B0 EFF92304268E `)) + tk.MustExec("SET 
block_encryption_mode='aes-256-ofb';") + result = tk.MustQuery("select HEX(AES_ENCRYPT(a, 'key', '1234567890123456')), HEX(AES_ENCRYPT(b, 'key', '1234567890123456')), HEX(AES_ENCRYPT(c, 'key', '1234567890123456')), HEX(AES_ENCRYPT(d, 'key', '1234567890123456')), HEX(AES_ENCRYPT(e, 'key', '1234567890123456')), HEX(AES_ENCRYPT(f, 'key', '1234567890123456')), HEX(AES_ENCRYPT(g, 'key', '1234567890123456')), HEX(AES_ENCRYPT(h, 'key', '1234567890123456')), HEX(AES_ENCRYPT(i, 'key', '1234567890123456')) from t") + result.Check(testkit.Rows("16 16 16D103 16CF01CBC95D33E2ED721CBD930262415A69AD 15CD0ACCD55732FE 2E 11CE02FCE46D02CFDD433C8CA138527060599C35 10C7 5096549E")) + result = tk.MustQuery("select HEX(AES_ENCRYPT('123', 'foobar', '1234567890123456')), HEX(AES_ENCRYPT(123, 'foobar', '1234567890123456')), HEX(AES_ENCRYPT('', 'foobar', '1234567890123456')), HEX(AES_ENCRYPT('你好', 'foobar', '1234567890123456')), AES_ENCRYPT(NULL, 'foobar', '1234567890123456')") + result.Check(testkit.Rows(`E842C5 E842C5 3DCD5646767D `)) // for AES_DECRYPT tk.MustExec("SET block_encryption_mode='aes-128-ecb';") @@ -1018,6 +1033,21 @@ func (s *testIntegrationSuite) TestEncryptionBuiltin(c *C) { result.Check(testkit.Rows("foo")) result = tk.MustQuery("select AES_DECRYPT(UNHEX('80D5646F07B4654B05A02D9085759770'), 'foobar', '1234567890123456'), AES_DECRYPT(UNHEX('B3C14BA15030D2D7E99376DBE011E752'), 'foobar', '1234567890123456'), AES_DECRYPT(UNHEX('0CD2936EE4FEC7A8CDF6208438B2BC05'), 'foobar', '1234567890123456'), AES_DECRYPT(NULL, 'foobar', '1234567890123456'), AES_DECRYPT('SOME_THING_STRANGE', 'foobar', '1234567890123456')") result.Check(testkit.Rows(`123 你好 `)) + tk.MustExec("SET block_encryption_mode='aes-128-ofb';") + result = tk.MustQuery("select AES_DECRYPT(AES_ENCRYPT('foo', 'bar', '1234567890123456'), 'bar', '1234567890123456')") + result.Check(testkit.Rows("foo")) + result = tk.MustQuery("select AES_DECRYPT(UNHEX('48E38A'), 'foobar', '1234567890123456'), AES_DECRYPT(UNHEX(''), 'foobar', '1234567890123456'), AES_DECRYPT(UNHEX('9D6C199101C3'), 'foobar', '1234567890123456'), AES_DECRYPT(NULL, 'foobar', '1234567890123456'), HEX(AES_DECRYPT('SOME_THING_STRANGE', 'foobar', '1234567890123456'))") + result.Check(testkit.Rows(`123 你好 2A9EF431FB2ACB022D7F2E7C71EEC48C7D2B`)) + tk.MustExec("SET block_encryption_mode='aes-192-ofb';") + result = tk.MustQuery("select AES_DECRYPT(AES_ENCRYPT('foo', 'bar', '1234567890123456'), 'bar', '1234567890123456')") + result.Check(testkit.Rows("foo")) + result = tk.MustQuery("select AES_DECRYPT(UNHEX('3A76B0'), 'foobar', '1234567890123456'), AES_DECRYPT(UNHEX(''), 'foobar', '1234567890123456'), AES_DECRYPT(UNHEX('EFF92304268E'), 'foobar', '1234567890123456'), AES_DECRYPT(NULL, 'foobar', '1234567890123456'), HEX(AES_DECRYPT('SOME_THING_STRANGE', 'foobar', '1234567890123456'))") + result.Check(testkit.Rows(`123 你好 580BCEA4DC67CF33FF2C7C570D36ECC89437`)) + tk.MustExec("SET block_encryption_mode='aes-256-ofb';") + result = tk.MustQuery("select AES_DECRYPT(AES_ENCRYPT('foo', 'bar', '1234567890123456'), 'bar', '1234567890123456')") + result.Check(testkit.Rows("foo")) + result = tk.MustQuery("select AES_DECRYPT(UNHEX('E842C5'), 'foobar', '1234567890123456'), AES_DECRYPT(UNHEX(''), 'foobar', '1234567890123456'), AES_DECRYPT(UNHEX('3DCD5646767D'), 'foobar', '1234567890123456'), AES_DECRYPT(NULL, 'foobar', '1234567890123456'), HEX(AES_DECRYPT('SOME_THING_STRANGE', 'foobar', '1234567890123456'))") + result.Check(testkit.Rows(`123 你好 8A3FBBE68C9465834584430E3AEEBB04B1F5`)) // for COMPRESS 
tk.MustExec("DROP TABLE IF EXISTS t1;") @@ -3701,8 +3731,8 @@ func (s *testIntegrationSuite) TestUnknowHintIgnore(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("USE test") tk.MustExec("create table t(a int)") - tk.MustQuery("select /*+ unkown_hint(c1)*/ 1").Check(testkit.Rows("1")) - tk.MustQuery("show warnings").Check(testkit.Rows("Warning 1105 line 1 column 28 near \" 1\" (total length 30)")) + tk.MustQuery("select /*+ unknown_hint(c1)*/ 1").Check(testkit.Rows("1")) + tk.MustQuery("show warnings").Check(testkit.Rows("Warning 1105 line 1 column 29 near \"select /*+ unknown_hint(c1)*/ 1\" (total length 31)")) _, err := tk.Exec("select 1 from /*+ test1() */ t") c.Assert(err, NotNil) } @@ -3797,3 +3827,33 @@ func (s *testIntegrationSuite) TestUserVarMockWindFunc(c *C) { `3 6 3 key3-value6 insert_order6`, )) } + +func (s *testIntegrationSuite) TestCastAsTime(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec(`use test;`) + tk.MustExec(`drop table if exists t;`) + tk.MustExec(`create table t (col1 bigint, col2 double, col3 decimal, col4 varchar(20), col5 json);`) + tk.MustExec(`insert into t values (1, 1, 1, "1", "1");`) + tk.MustExec(`insert into t values (null, null, null, null, null);`) + tk.MustQuery(`select cast(col1 as time), cast(col2 as time), cast(col3 as time), cast(col4 as time), cast(col5 as time) from t where col1 = 1;`).Check(testkit.Rows( + `00:00:01 00:00:01 00:00:01 00:00:01 00:00:01`, + )) + tk.MustQuery(`select cast(col1 as time), cast(col2 as time), cast(col3 as time), cast(col4 as time), cast(col5 as time) from t where col1 is null;`).Check(testkit.Rows( + ` `, + )) + + err := tk.ExecToErr(`select cast(col1 as time(31)) from t where col1 is null;`) + c.Assert(err.Error(), Equals, "[expression:1426]Too big precision 31 specified for column 'CAST'. Maximum is 6.") + + err = tk.ExecToErr(`select cast(col2 as time(31)) from t where col1 is null;`) + c.Assert(err.Error(), Equals, "[expression:1426]Too big precision 31 specified for column 'CAST'. Maximum is 6.") + + err = tk.ExecToErr(`select cast(col3 as time(31)) from t where col1 is null;`) + c.Assert(err.Error(), Equals, "[expression:1426]Too big precision 31 specified for column 'CAST'. Maximum is 6.") + + err = tk.ExecToErr(`select cast(col4 as time(31)) from t where col1 is null;`) + c.Assert(err.Error(), Equals, "[expression:1426]Too big precision 31 specified for column 'CAST'. Maximum is 6.") + + err = tk.ExecToErr(`select cast(col5 as time(31)) from t where col1 is null;`) + c.Assert(err.Error(), Equals, "[expression:1426]Too big precision 31 specified for column 'CAST'. 
Maximum is 6.") +} diff --git a/go.mod b/go.mod index b3beb315f66cb..45ba9b40cc01a 100644 --- a/go.mod +++ b/go.mod @@ -43,14 +43,14 @@ require ( github.com/onsi/gomega v1.4.3 // indirect github.com/opentracing/basictracer-go v1.0.0 github.com/opentracing/opentracing-go v1.0.2 - github.com/pingcap/check v0.0.0-20181222140913-41d022e836db + github.com/pingcap/check v0.0.0-20190102082844-67f458068fc8 github.com/pingcap/errors v0.11.0 github.com/pingcap/gofail v0.0.0-20181217135706-6a951c1e42c3 github.com/pingcap/goleveldb v0.0.0-20171020122428-b9ff6c35079e github.com/pingcap/kvproto v0.0.0-20181203065228-c14302da291c - github.com/pingcap/parser v0.0.0-20181225032741-ff56f7f11ed6 + github.com/pingcap/parser v0.0.0-20190108104142-35fab0be7fca github.com/pingcap/pd v2.1.0-rc.4+incompatible - github.com/pingcap/tidb-tools v2.1.1-0.20181218072513-b2235d442b06+incompatible + github.com/pingcap/tidb-tools v2.1.3-0.20190104033906-883b07a04a73+incompatible github.com/pingcap/tipb v0.0.0-20181012112600-11e33c750323 github.com/pkg/errors v0.8.0 // indirect github.com/prometheus/client_golang v0.9.0 diff --git a/go.sum b/go.sum index 60c7b9c477f0d..c7358410e7861 100644 --- a/go.sum +++ b/go.sum @@ -100,8 +100,8 @@ github.com/opentracing/basictracer-go v1.0.0 h1:YyUAhaEfjoWXclZVJ9sGoNct7j4TVk7l github.com/opentracing/basictracer-go v1.0.0/go.mod h1:QfBfYuafItcjQuMwinw9GhYKwFXS9KnPs5lxoYwgW74= github.com/opentracing/opentracing-go v1.0.2 h1:3jA2P6O1F9UOrWVpwrIo17pu01KWvNWg4X946/Y5Zwg= github.com/opentracing/opentracing-go v1.0.2/go.mod h1:UkNAQd3GIcIGf0SeVgPpRdFStlNbqXla1AfSYxPUl2o= -github.com/pingcap/check v0.0.0-20181222140913-41d022e836db h1:yg93sLvBszRnzcd+Z5gkCUdYgud2scHYYxnRwljvRAM= -github.com/pingcap/check v0.0.0-20181222140913-41d022e836db/go.mod h1:B1+S9LNcuMyLH/4HMTViQOJevkGiik3wW2AN9zb2fNQ= +github.com/pingcap/check v0.0.0-20190102082844-67f458068fc8 h1:USx2/E1bX46VG32FIw034Au6seQ2fY9NEILmNh/UlQg= +github.com/pingcap/check v0.0.0-20190102082844-67f458068fc8/go.mod h1:B1+S9LNcuMyLH/4HMTViQOJevkGiik3wW2AN9zb2fNQ= github.com/pingcap/errors v0.11.0 h1:DCJQB8jrHbQ1VVlMFIrbj2ApScNNotVmkSNplu2yUt4= github.com/pingcap/errors v0.11.0/go.mod h1:Oi8TUi2kEtXXLMJk9l1cGmz20kV3TaQ0usTwv5KuLY8= github.com/pingcap/gofail v0.0.0-20181217135706-6a951c1e42c3 h1:04yuCf5NMvLU8rB2m4Qs3rynH7EYpMno3lHkewIOdMo= @@ -110,12 +110,12 @@ github.com/pingcap/goleveldb v0.0.0-20171020122428-b9ff6c35079e h1:P73/4dPCL96rG github.com/pingcap/goleveldb v0.0.0-20171020122428-b9ff6c35079e/go.mod h1:O17XtbryoCJhkKGbT62+L2OlrniwqiGLSqrmdHCMzZw= github.com/pingcap/kvproto v0.0.0-20181203065228-c14302da291c h1:Qf5St5XGwKgKQLar9lEXoeO0hJMVaFBj3JqvFguWtVg= github.com/pingcap/kvproto v0.0.0-20181203065228-c14302da291c/go.mod h1:Ja9XPjot9q4/3JyCZodnWDGNXt4pKemhIYCvVJM7P24= -github.com/pingcap/parser v0.0.0-20181225032741-ff56f7f11ed6 h1:ooapyJxH6uSHNvpYjPOggHtd2dzLKwSvYLVzs3OjoM0= -github.com/pingcap/parser v0.0.0-20181225032741-ff56f7f11ed6/go.mod h1:1FNvfp9+J0wvc4kl8eGNh7Rqrxveg15jJoWo/a0uHwA= +github.com/pingcap/parser v0.0.0-20190108104142-35fab0be7fca h1:BIk063GpkHOJzR4Bp9I8xfsjp3nSvElvhIUjDgmInoo= +github.com/pingcap/parser v0.0.0-20190108104142-35fab0be7fca/go.mod h1:1FNvfp9+J0wvc4kl8eGNh7Rqrxveg15jJoWo/a0uHwA= github.com/pingcap/pd v2.1.0-rc.4+incompatible h1:/buwGk04aHO5odk/+O8ZOXGs4qkUjYTJ2UpCJXna8NE= github.com/pingcap/pd v2.1.0-rc.4+incompatible/go.mod h1:nD3+EoYes4+aNNODO99ES59V83MZSI+dFbhyr667a0E= -github.com/pingcap/tidb-tools v2.1.1-0.20181218072513-b2235d442b06+incompatible h1:Bsd+NHosPVowEGB3BCx+2d8wUQGDTXSSC5ljeNS6cXo= 
-github.com/pingcap/tidb-tools v2.1.1-0.20181218072513-b2235d442b06+incompatible/go.mod h1:XGdcy9+yqlDSEMTpOXnwf3hiTeqrV6MN/u1se9N8yIM= +github.com/pingcap/tidb-tools v2.1.3-0.20190104033906-883b07a04a73+incompatible h1:Ba48wwPwPq5hd1kkQpgua49dqB5cthC2zXVo7fUUDec= +github.com/pingcap/tidb-tools v2.1.3-0.20190104033906-883b07a04a73+incompatible/go.mod h1:XGdcy9+yqlDSEMTpOXnwf3hiTeqrV6MN/u1se9N8yIM= github.com/pingcap/tipb v0.0.0-20170310053819-1043caee48da/go.mod h1:RtkHW8WbcNxj8lsbzjaILci01CtYnYbIkQhjyZWrWVI= github.com/pingcap/tipb v0.0.0-20181012112600-11e33c750323 h1:mRKKzRjDNaUNPnAkPAHnRqpNmwNWBX1iA+hxlmvQ93I= github.com/pingcap/tipb v0.0.0-20181012112600-11e33c750323/go.mod h1:RtkHW8WbcNxj8lsbzjaILci01CtYnYbIkQhjyZWrWVI= diff --git a/infoschema/builder.go b/infoschema/builder.go index 1515939beb0f3..10a1dadf40974 100644 --- a/infoschema/builder.go +++ b/infoschema/builder.go @@ -54,7 +54,7 @@ func (b *Builder) ApplyDiff(m *meta.Meta, diff *model.SchemaDiff) ([]int64, erro case model.ActionCreateTable: newTableID = diff.TableID tblIDs = append(tblIDs, newTableID) - case model.ActionDropTable: + case model.ActionDropTable, model.ActionDropView: oldTableID = diff.TableID tblIDs = append(tblIDs, oldTableID) case model.ActionTruncateTable: diff --git a/infoschema/infoschema_test.go b/infoschema/infoschema_test.go index 3381f19d64aa0..41c2e128d0f88 100644 --- a/infoschema/infoschema_test.go +++ b/infoschema/infoschema_test.go @@ -24,7 +24,6 @@ import ( "github.com/pingcap/tidb/infoschema" "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/meta" - "github.com/pingcap/tidb/perfschema" "github.com/pingcap/tidb/store/mockstore" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/testleak" @@ -120,7 +119,7 @@ func (*testSuite) TestT(c *C) { schemaNames := is.AllSchemaNames() c.Assert(schemaNames, HasLen, 3) - c.Assert(testutil.CompareUnorderedStringSlice(schemaNames, []string{infoschema.Name, perfschema.Name, "Test"}), IsTrue) + c.Assert(testutil.CompareUnorderedStringSlice(schemaNames, []string{infoschema.Name, "PERFORMANCE_SCHEMA", "Test"}), IsTrue) schemas := is.AllSchemas() c.Assert(schemas, HasLen, 3) diff --git a/perfschema/const.go b/infoschema/perfschema/const.go similarity index 100% rename from perfschema/const.go rename to infoschema/perfschema/const.go diff --git a/perfschema/init.go b/infoschema/perfschema/init.go similarity index 100% rename from perfschema/init.go rename to infoschema/perfschema/init.go diff --git a/infoschema/perfschema/tables.go b/infoschema/perfschema/tables.go new file mode 100644 index 0000000000000..e0712aad7514f --- /dev/null +++ b/infoschema/perfschema/tables.go @@ -0,0 +1,66 @@ +// Copyright 2017 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package perfschema + +import ( + "github.com/pingcap/parser/model" + "github.com/pingcap/tidb/infoschema" + "github.com/pingcap/tidb/meta/autoid" + "github.com/pingcap/tidb/table" +) + +// perfSchemaTable stands for the fake table all its data is in the memory. 
+type perfSchemaTable struct { + infoschema.VirtualTable + meta *model.TableInfo + cols []*table.Column +} + +func tableFromMeta(alloc autoid.Allocator, meta *model.TableInfo) (table.Table, error) { + return createPerfSchemaTable(meta), nil +} + +// createPerfSchemaTable creates all perfSchemaTables +func createPerfSchemaTable(meta *model.TableInfo) *perfSchemaTable { + columns := make([]*table.Column, 0, len(meta.Columns)) + for _, colInfo := range meta.Columns { + col := table.ToColumn(colInfo) + columns = append(columns, col) + } + t := &perfSchemaTable{ + meta: meta, + cols: columns, + } + return t +} + +// Cols implements table.Table Type interface. +func (vt *perfSchemaTable) Cols() []*table.Column { + return vt.cols +} + +// WritableCols implements table.Table Type interface. +func (vt *perfSchemaTable) WritableCols() []*table.Column { + return vt.cols +} + +// GetID implements table.Table GetID interface. +func (vt *perfSchemaTable) GetPhysicalID() int64 { + return vt.meta.ID +} + +// Meta implements table.Table Type interface. +func (vt *perfSchemaTable) Meta() *model.TableInfo { + return vt.meta +} diff --git a/perfschema/tables_test.go b/infoschema/perfschema/tables_test.go similarity index 100% rename from perfschema/tables_test.go rename to infoschema/perfschema/tables_test.go diff --git a/infoschema/tables.go b/infoschema/tables.go index 02e60854ac291..dc18a9a1d302d 100644 --- a/infoschema/tables.go +++ b/infoschema/tables.go @@ -144,7 +144,7 @@ var tablesCols = []columnInfo{ {"UPDATE_TIME", mysql.TypeDatetime, 19, 0, nil, nil}, {"CHECK_TIME", mysql.TypeDatetime, 19, 0, nil, nil}, {"TABLE_COLLATION", mysql.TypeVarchar, 32, mysql.NotNullFlag, "utf8_bin", nil}, - {"CHECK_SUM", mysql.TypeLonglong, 21, 0, nil, nil}, + {"CHECKSUM", mysql.TypeLonglong, 21, 0, nil, nil}, {"CREATE_OPTIONS", mysql.TypeVarchar, 255, 0, nil, nil}, {"TABLE_COMMENT", mysql.TypeVarchar, 2048, 0, nil, nil}, } @@ -870,6 +870,43 @@ func getAutoIncrementID(ctx sessionctx.Context, schema *model.DBInfo, tblInfo *m return autoIncID, nil } +func dataForViews(ctx sessionctx.Context, schemas []*model.DBInfo) ([][]types.Datum, error) { + checker := privilege.GetPrivilegeManager(ctx) + var rows [][]types.Datum + for _, schema := range schemas { + for _, table := range schema.Tables { + if !table.IsView() { + continue + } + collation := table.Collate + charset := table.Charset + if collation == "" { + collation = mysql.DefaultCollationName + } + if charset == "" { + charset = mysql.DefaultCharset + } + if checker != nil && !checker.RequestVerification(schema.Name.L, table.Name.L, "", mysql.AllPrivMask) { + continue + } + record := types.MakeDatums( + catalogVal, // TABLE_CATALOG + schema.Name.O, // TABLE_SCHEMA + table.Name.O, // TABLE_NAME + table.View.SelectStmt, // VIEW_DEFINITION + table.View.CheckOption.String(), // CHECK_OPTION + "NO", // IS_UPDATABLE + table.View.Definer.String(), // DEFINER + table.View.Security.String(), // SECURITY_TYPE + charset, // CHARACTER_SET_CLIENT + collation, // COLLATION_CONNECTION + ) + rows = append(rows, record) + } + } + return rows, nil +} + func dataForTables(ctx sessionctx.Context, schemas []*model.DBInfo) ([][]types.Datum, error) { tableRowsMap, colLengthMap, err := tableStatsCache.get(ctx) if err != nil { @@ -892,48 +929,75 @@ func dataForTables(ctx sessionctx.Context, schemas []*model.DBInfo) ([][]types.D } createOptions := "" - if table.GetPartitionInfo() != nil { - createOptions = "partitioned" - } if checker != nil && !checker.RequestVerification(schema.Name.L, 
table.Name.L, "", mysql.AllPrivMask) { continue } - autoIncID, err := getAutoIncrementID(ctx, schema, table) - if err != nil { - return nil, errors.Trace(err) - } - rowCount := tableRowsMap[table.ID] - dataLength, indexLength := getDataAndIndexLength(table, rowCount, colLengthMap) - avgRowLength := uint64(0) - if rowCount != 0 { - avgRowLength = dataLength / rowCount + if !table.IsView() { + if table.GetPartitionInfo() != nil { + createOptions = "partitioned" + } + autoIncID, err := getAutoIncrementID(ctx, schema, table) + if err != nil { + return nil, errors.Trace(err) + } + rowCount := tableRowsMap[table.ID] + dataLength, indexLength := getDataAndIndexLength(table, rowCount, colLengthMap) + avgRowLength := uint64(0) + if rowCount != 0 { + avgRowLength = dataLength / rowCount + } + record := types.MakeDatums( + catalogVal, // TABLE_CATALOG + schema.Name.O, // TABLE_SCHEMA + table.Name.O, // TABLE_NAME + "BASE TABLE", // TABLE_TYPE + "InnoDB", // ENGINE + uint64(10), // VERSION + "Compact", // ROW_FORMAT + rowCount, // TABLE_ROWS + avgRowLength, // AVG_ROW_LENGTH + dataLength, // DATA_LENGTH + uint64(0), // MAX_DATA_LENGTH + indexLength, // INDEX_LENGTH + uint64(0), // DATA_FREE + autoIncID, // AUTO_INCREMENT + createTime, // CREATE_TIME + nil, // UPDATE_TIME + nil, // CHECK_TIME + collation, // TABLE_COLLATION + nil, // CHECKSUM + createOptions, // CREATE_OPTIONS + table.Comment, // TABLE_COMMENT + ) + rows = append(rows, record) + } else { + record := types.MakeDatums( + catalogVal, // TABLE_CATALOG + schema.Name.O, // TABLE_SCHEMA + table.Name.O, // TABLE_NAME + "VIEW", // TABLE_TYPE + nil, // ENGINE + nil, // VERSION + nil, // ROW_FORMAT + nil, // TABLE_ROWS + nil, // AVG_ROW_LENGTH + nil, // DATA_LENGTH + nil, // MAX_DATA_LENGTH + nil, // INDEX_LENGTH + nil, // DATA_FREE + nil, // AUTO_INCREMENT + createTime, // CREATE_TIME + nil, // UPDATE_TIME + nil, // CHECK_TIME + nil, // TABLE_COLLATION + nil, // CHECKSUM + nil, // CREATE_OPTIONS + "VIEW", // TABLE_COMMENT + ) + rows = append(rows, record) } - record := types.MakeDatums( - catalogVal, // TABLE_CATALOG - schema.Name.O, // TABLE_SCHEMA - table.Name.O, // TABLE_NAME - "BASE TABLE", // TABLE_TYPE - "InnoDB", // ENGINE - uint64(10), // VERSION - "Compact", // ROW_FORMAT - rowCount, // TABLE_ROWS - avgRowLength, // AVG_ROW_LENGTH - dataLength, // DATA_LENGTH - uint64(0), // MAX_DATA_LENGTH - indexLength, // INDEX_LENGTH - uint64(0), // DATA_FREE - autoIncID, // AUTO_INCREMENT - createTime, // CREATE_TIME - nil, // UPDATE_TIME - nil, // CHECK_TIME - collation, // TABLE_COLLATION - nil, // CHECKSUM - createOptions, // CREATE_OPTIONS - table.Comment, // TABLE_COMMENT - ) - rows = append(rows, record) } } return rows, nil @@ -1398,6 +1462,7 @@ func (it *infoschemaTable) getRows(ctx sessionctx.Context, cols []*table.Column) case tableEngines: fullRows = dataForEngines() case tableViews: + fullRows, err = dataForViews(ctx, dbs) case tableRoutines: // TODO: Fill the following tables. case tableSchemaPrivileges: @@ -1538,3 +1603,120 @@ func (it *infoschemaTable) Seek(ctx sessionctx.Context, h int64) (int64, bool, e func (it *infoschemaTable) Type() table.Type { return table.VirtualTable } + +// VirtualTable is a dummy table.Table implementation. +type VirtualTable struct{} + +// IterRecords implements table.Table Type interface. 
+func (vt *VirtualTable) IterRecords(ctx sessionctx.Context, startKey kv.Key, cols []*table.Column, + fn table.RecordIterFunc) error { + if len(startKey) != 0 { + return table.ErrUnsupportedOp + } + return nil +} + +// RowWithCols implements table.Table Type interface. +func (vt *VirtualTable) RowWithCols(ctx sessionctx.Context, h int64, cols []*table.Column) ([]types.Datum, error) { + return nil, table.ErrUnsupportedOp +} + +// Row implements table.Table Type interface. +func (vt *VirtualTable) Row(ctx sessionctx.Context, h int64) ([]types.Datum, error) { + return nil, table.ErrUnsupportedOp +} + +// Cols implements table.Table Type interface. +func (vt *VirtualTable) Cols() []*table.Column { + return nil +} + +// WritableCols implements table.Table Type interface. +func (vt *VirtualTable) WritableCols() []*table.Column { + return nil +} + +// Indices implements table.Table Type interface. +func (vt *VirtualTable) Indices() []table.Index { + return nil +} + +// WritableIndices implements table.Table Type interface. +func (vt *VirtualTable) WritableIndices() []table.Index { + return nil +} + +// DeletableIndices implements table.Table Type interface. +func (vt *VirtualTable) DeletableIndices() []table.Index { + return nil +} + +// RecordPrefix implements table.Table Type interface. +func (vt *VirtualTable) RecordPrefix() kv.Key { + return nil +} + +// IndexPrefix implements table.Table Type interface. +func (vt *VirtualTable) IndexPrefix() kv.Key { + return nil +} + +// FirstKey implements table.Table Type interface. +func (vt *VirtualTable) FirstKey() kv.Key { + return nil +} + +// RecordKey implements table.Table Type interface. +func (vt *VirtualTable) RecordKey(h int64) kv.Key { + return nil +} + +// AddRecord implements table.Table Type interface. +func (vt *VirtualTable) AddRecord(ctx sessionctx.Context, r []types.Datum, opts ...*table.AddRecordOpt) (recordID int64, err error) { + return 0, table.ErrUnsupportedOp +} + +// RemoveRecord implements table.Table Type interface. +func (vt *VirtualTable) RemoveRecord(ctx sessionctx.Context, h int64, r []types.Datum) error { + return table.ErrUnsupportedOp +} + +// UpdateRecord implements table.Table Type interface. +func (vt *VirtualTable) UpdateRecord(ctx sessionctx.Context, h int64, oldData, newData []types.Datum, touched []bool) error { + return table.ErrUnsupportedOp +} + +// AllocAutoID implements table.Table Type interface. +func (vt *VirtualTable) AllocAutoID(ctx sessionctx.Context) (int64, error) { + return 0, table.ErrUnsupportedOp +} + +// Allocator implements table.Table Type interface. +func (vt *VirtualTable) Allocator(ctx sessionctx.Context) autoid.Allocator { + return nil +} + +// RebaseAutoID implements table.Table Type interface. +func (vt *VirtualTable) RebaseAutoID(ctx sessionctx.Context, newBase int64, isSetStep bool) error { + return table.ErrUnsupportedOp +} + +// Meta implements table.Table Type interface. +func (vt *VirtualTable) Meta() *model.TableInfo { + return nil +} + +// GetPhysicalID implements table.Table GetPhysicalID interface. +func (vt *VirtualTable) GetPhysicalID() int64 { + return 0 +} + +// Seek implements table.Table Type interface. +func (vt *VirtualTable) Seek(ctx sessionctx.Context, h int64) (int64, bool, error) { + return 0, false, table.ErrUnsupportedOp +} + +// Type implements table.Table Type interface. 
+func (vt *VirtualTable) Type() table.Type { + return table.VirtualTable +} diff --git a/infoschema/tables_test.go b/infoschema/tables_test.go index dfe826ea595c4..1f07f28e5996b 100644 --- a/infoschema/tables_test.go +++ b/infoschema/tables_test.go @@ -191,3 +191,20 @@ func (s *testSuite) TestProfiling(c *C) { tk.MustExec("set @@profiling=1") tk.MustQuery("select * from information_schema.profiling").Check(testkit.Rows("0 0 0 0 0 0 0 0 0 0 0 0 0 0 0")) } + +func (s *testSuite) TestViews(c *C) { + testleak.BeforeTest() + defer testleak.AfterTest(c)() + store, err := mockstore.NewMockTikvStore() + c.Assert(err, IsNil) + defer store.Close() + session.SetStatsLease(0) + do, err := session.BootstrapSession(store) + c.Assert(err, IsNil) + defer do.Close() + + tk := testkit.NewTestKit(c, store) + tk.MustExec("CREATE DEFINER='root'@'localhost' VIEW test.v1 AS SELECT 1") + tk.MustQuery("SELECT * FROM information_schema.views WHERE table_schema='test' AND table_name='v1'").Check(testkit.Rows("def test v1 SELECT 1 CASCADED NO root@localhost DEFINER utf8mb4 utf8mb4_bin")) + tk.MustQuery("SELECT table_catalog, table_schema, table_name, table_type, engine, version, row_format, table_rows, avg_row_length, data_length, max_data_length, index_length, data_free, auto_increment, update_time, check_time, table_collation, checksum, create_options, table_comment FROM information_schema.tables WHERE table_schema='test' AND table_name='v1'").Check(testkit.Rows("def test v1 VIEW VIEW")) +} diff --git a/perfschema/tables.go b/perfschema/tables.go deleted file mode 100644 index 2288311cd9bea..0000000000000 --- a/perfschema/tables.go +++ /dev/null @@ -1,161 +0,0 @@ -// Copyright 2017 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// See the License for the specific language governing permissions and -// limitations under the License. - -package perfschema - -import ( - "github.com/pingcap/parser/model" - "github.com/pingcap/tidb/kv" - "github.com/pingcap/tidb/meta/autoid" - "github.com/pingcap/tidb/sessionctx" - "github.com/pingcap/tidb/table" - "github.com/pingcap/tidb/types" -) - -// perfSchemaTable stands for the fake table all its data is in the memory. -type perfSchemaTable struct { - meta *model.TableInfo - cols []*table.Column -} - -func tableFromMeta(alloc autoid.Allocator, meta *model.TableInfo) (table.Table, error) { - return createPerfSchemaTable(meta), nil -} - -// createPerfSchemaTable creates all perfSchemaTables -func createPerfSchemaTable(meta *model.TableInfo) *perfSchemaTable { - columns := make([]*table.Column, 0, len(meta.Columns)) - for _, colInfo := range meta.Columns { - col := table.ToColumn(colInfo) - columns = append(columns, col) - } - t := &perfSchemaTable{ - meta: meta, - cols: columns, - } - return t -} - -// IterRecords implements table.Table Type interface. -func (vt *perfSchemaTable) IterRecords(ctx sessionctx.Context, startKey kv.Key, cols []*table.Column, - fn table.RecordIterFunc) error { - if len(startKey) != 0 { - return table.ErrUnsupportedOp - } - return nil -} - -// RowWithCols implements table.Table Type interface. 
-func (vt *perfSchemaTable) RowWithCols(ctx sessionctx.Context, h int64, cols []*table.Column) ([]types.Datum, error) { - return nil, table.ErrUnsupportedOp -} - -// Row implements table.Table Type interface. -func (vt *perfSchemaTable) Row(ctx sessionctx.Context, h int64) ([]types.Datum, error) { - return nil, table.ErrUnsupportedOp -} - -// Cols implements table.Table Type interface. -func (vt *perfSchemaTable) Cols() []*table.Column { - return vt.cols -} - -// WritableCols implements table.Table Type interface. -func (vt *perfSchemaTable) WritableCols() []*table.Column { - return vt.cols -} - -// Indices implements table.Table Type interface. -func (vt *perfSchemaTable) Indices() []table.Index { - return nil -} - -// WritableIndices implements table.Table Type interface. -func (vt *perfSchemaTable) WritableIndices() []table.Index { - return nil -} - -// DeletableIndices implements table.Table Type interface. -func (vt *perfSchemaTable) DeletableIndices() []table.Index { - return nil -} - -// RecordPrefix implements table.Table Type interface. -func (vt *perfSchemaTable) RecordPrefix() kv.Key { - return nil -} - -// IndexPrefix implements table.Table Type interface. -func (vt *perfSchemaTable) IndexPrefix() kv.Key { - return nil -} - -// FirstKey implements table.Table Type interface. -func (vt *perfSchemaTable) FirstKey() kv.Key { - return nil -} - -// RecordKey implements table.Table Type interface. -func (vt *perfSchemaTable) RecordKey(h int64) kv.Key { - return nil -} - -// AddRecord implements table.Table Type interface. -func (vt *perfSchemaTable) AddRecord(ctx sessionctx.Context, r []types.Datum, opts ...*table.AddRecordOpt) (recordID int64, err error) { - return 0, table.ErrUnsupportedOp -} - -// RemoveRecord implements table.Table Type interface. -func (vt *perfSchemaTable) RemoveRecord(ctx sessionctx.Context, h int64, r []types.Datum) error { - return table.ErrUnsupportedOp -} - -// UpdateRecord implements table.Table Type interface. -func (vt *perfSchemaTable) UpdateRecord(ctx sessionctx.Context, h int64, oldData, newData []types.Datum, touched []bool) error { - return table.ErrUnsupportedOp -} - -// AllocAutoID implements table.Table Type interface. -func (vt *perfSchemaTable) AllocAutoID(ctx sessionctx.Context) (int64, error) { - return 0, table.ErrUnsupportedOp -} - -// Allocator implements table.Table Type interface. -func (vt *perfSchemaTable) Allocator(ctx sessionctx.Context) autoid.Allocator { - return nil -} - -// RebaseAutoID implements table.Table Type interface. -func (vt *perfSchemaTable) RebaseAutoID(ctx sessionctx.Context, newBase int64, isSetStep bool) error { - return table.ErrUnsupportedOp -} - -// Meta implements table.Table Type interface. -func (vt *perfSchemaTable) Meta() *model.TableInfo { - return vt.meta -} - -// GetID implements table.Table GetID interface. -func (vt *perfSchemaTable) GetPhysicalID() int64 { - return vt.meta.ID -} - -// Seek implements table.Table Type interface. -func (vt *perfSchemaTable) Seek(ctx sessionctx.Context, h int64) (int64, bool, error) { - return 0, false, table.ErrUnsupportedOp -} - -// Type implements table.Table Type interface. 
-func (vt *perfSchemaTable) Type() table.Type { - return table.VirtualTable -} diff --git a/planner/core/cbo_test.go b/planner/core/cbo_test.go index 5cf32c9b54a12..8454b042f7b7d 100644 --- a/planner/core/cbo_test.go +++ b/planner/core/cbo_test.go @@ -77,7 +77,7 @@ func (s *testAnalyzeSuite) TestExplainAnalyze(c *C) { tk.MustExec("insert into t2 values (2, 22), (3, 33), (5, 55)") tk.MustExec("analyze table t1, t2") rs := tk.MustQuery("explain analyze select t1.a, t1.b, sum(t1.c) from t1 join t2 on t1.a = t2.b where t1.a > 1") - c.Assert(len(rs.Rows()), Equals, 10) + c.Assert(len(rs.Rows()), Equals, 11) for _, row := range rs.Rows() { c.Assert(len(row), Equals, 5) taskType := row[2].(string) @@ -367,7 +367,7 @@ func (s *testAnalyzeSuite) TestIndexRead(c *C) { }, { sql: "select sum(a) from t1 use index(idx) where a = 3 and b = 100000 group by a limit 1", - best: "IndexLookUp(Index(t1.idx)[[3,3]], Table(t1)->Sel([eq(test.t1.b, 100000)]))->StreamAgg->Limit", + best: "IndexLookUp(Index(t1.idx)[[3,3]], Table(t1)->Sel([eq(test.t1.b, 100000)]))->Projection->StreamAgg->Limit", }, } for _, tt := range tests { diff --git a/planner/core/common_plans.go b/planner/core/common_plans.go index 217c4449dafc0..efca01ffd618a 100644 --- a/planner/core/common_plans.go +++ b/planner/core/common_plans.go @@ -305,13 +305,14 @@ type Deallocate struct { type Show struct { baseSchemaProducer - Tp ast.ShowStmtType // Databases/Tables/Columns/.... - DBName string - Table *ast.TableName // Used for showing columns. - Column *ast.ColumnName // Used for `desc table column`. - Flag int // Some flag parsed from sql, such as FULL. - Full bool - User *auth.UserIdentity // Used for show grants. + Tp ast.ShowStmtType // Databases/Tables/Columns/.... + DBName string + Table *ast.TableName // Used for showing columns. + Column *ast.ColumnName // Used for `desc table column`. + Flag int // Some flag parsed from sql, such as FULL. + Full bool + User *auth.UserIdentity // Used for show grants. 
+ IfNotExists bool // Used for `show create database if not exists` Conditions []expression.Expression diff --git a/planner/core/errors.go b/planner/core/errors.go index 6d029c490f533..4f43ca6670f11 100644 --- a/planner/core/errors.go +++ b/planner/core/errors.go @@ -26,31 +26,41 @@ const ( codeWrongParamCount = 5 codeSchemaChanged = 6 - codeWrongUsage = mysql.ErrWrongUsage - codeAmbiguous = mysql.ErrNonUniq - codeUnknown = mysql.ErrUnknown - codeUnknownColumn = mysql.ErrBadField - codeUnknownTable = mysql.ErrUnknownTable - codeWrongArguments = mysql.ErrWrongArguments - codeBadGeneratedColumn = mysql.ErrBadGeneratedColumn - codeFieldNotInGroupBy = mysql.ErrFieldNotInGroupBy - codeBadTable = mysql.ErrBadTable - codeKeyDoesNotExist = mysql.ErrKeyDoesNotExist - codeOperandColumns = mysql.ErrOperandColumns - codeInvalidWildCard = mysql.ErrParse - codeInvalidGroupFuncUse = mysql.ErrInvalidGroupFuncUse - codeIllegalReference = mysql.ErrIllegalReference - codeNoDB = mysql.ErrNoDB - codeUnknownExplainFormat = mysql.ErrUnknownExplainFormat - codeWrongGroupField = mysql.ErrWrongGroupField - codeDupFieldName = mysql.ErrDupFieldName - codeNonUpdatableTable = mysql.ErrNonUpdatableTable - codeInternal = mysql.ErrInternal - codeMixOfGroupFuncAndFields = mysql.ErrMixOfGroupFuncAndFields - codeNonUniqTable = mysql.ErrNonuniqTable - codeWrongNumberOfColumnsInSelect = mysql.ErrWrongNumberOfColumnsInSelect - codeWrongValueCountOnRow = mysql.ErrWrongValueCountOnRow - codeTablenameNotAllowedHere = mysql.ErrTablenameNotAllowedHere + codeWrongUsage = mysql.ErrWrongUsage + codeAmbiguous = mysql.ErrNonUniq + codeUnknown = mysql.ErrUnknown + codeUnknownColumn = mysql.ErrBadField + codeUnknownTable = mysql.ErrUnknownTable + codeWrongArguments = mysql.ErrWrongArguments + codeBadGeneratedColumn = mysql.ErrBadGeneratedColumn + codeFieldNotInGroupBy = mysql.ErrFieldNotInGroupBy + codeBadTable = mysql.ErrBadTable + codeKeyDoesNotExist = mysql.ErrKeyDoesNotExist + codeOperandColumns = mysql.ErrOperandColumns + codeInvalidWildCard = mysql.ErrParse + codeInvalidGroupFuncUse = mysql.ErrInvalidGroupFuncUse + codeIllegalReference = mysql.ErrIllegalReference + codeNoDB = mysql.ErrNoDB + codeUnknownExplainFormat = mysql.ErrUnknownExplainFormat + codeWrongGroupField = mysql.ErrWrongGroupField + codeDupFieldName = mysql.ErrDupFieldName + codeNonUpdatableTable = mysql.ErrNonUpdatableTable + codeInternal = mysql.ErrInternal + codeMixOfGroupFuncAndFields = mysql.ErrMixOfGroupFuncAndFields + codeNonUniqTable = mysql.ErrNonuniqTable + codeWrongNumberOfColumnsInSelect = mysql.ErrWrongNumberOfColumnsInSelect + codeWrongValueCountOnRow = mysql.ErrWrongValueCountOnRow + codeTablenameNotAllowedHere = mysql.ErrTablenameNotAllowedHere + codePrivilegeCheckFail = mysql.ErrUnknown + codeWindowInvalidWindowFuncUse = mysql.ErrWindowInvalidWindowFuncUse + codeWindowInvalidWindowFuncAliasUse = mysql.ErrWindowInvalidWindowFuncAliasUse + codeWindowNoSuchWindow = mysql.ErrWindowNoSuchWindow + codeWindowCircularityInWindowGraph = mysql.ErrWindowCircularityInWindowGraph + codeWindowNoChildPartitioning = mysql.ErrWindowNoChildPartitioning + codeWindowNoInherentFrame = mysql.ErrWindowNoInherentFrame + codeWindowNoRedefineOrderBy = mysql.ErrWindowNoRedefineOrderBy + codeWindowDuplicateName = mysql.ErrWindowDuplicateName + codeErrTooBigPrecision = mysql.ErrTooBigPrecision ) // error definitions. 
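For readers unfamiliar with the terror plumbing used in this file, here is a minimal standalone sketch (not part of the patch) of how one of the new window-function errors is defined and rendered. The rendered prefix and the fact that the message template takes the window name as its single argument are assumptions based on the upstream MySQL error text, not taken from this diff:

package main

import (
	"fmt"

	"github.com/pingcap/parser/mysql"
	"github.com/pingcap/parser/terror"
)

func main() {
	// Same pattern as planner/core/errors.go above: bind the MySQL message
	// template to a terror error under ClassOptimizer, then format it with
	// GenWithStackByArgs when the planner detects the problem.
	errNoSuchWindow := terror.ClassOptimizer.New(
		mysql.ErrWindowNoSuchWindow,
		mysql.MySQLErrName[mysql.ErrWindowNoSuchWindow],
	)
	// Expected to print something like:
	//   [planner:3579]Window name 'w' is not defined.
	fmt.Println(errNoSuchWindow.GenWithStackByArgs("w"))
}

The init() hunk later in this file registers the same codes in terror.ErrClassToMySQLCodes, so clients see the standard MySQL error numbers rather than a generic unknown-error code.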
@@ -63,31 +73,41 @@ var ( ErrSchemaChanged = terror.ClassOptimizer.New(codeSchemaChanged, "Schema has changed") ErrTablenameNotAllowedHere = terror.ClassOptimizer.New(codeTablenameNotAllowedHere, "Table '%s' from one of the %ss cannot be used in %s") - ErrWrongUsage = terror.ClassOptimizer.New(codeWrongUsage, mysql.MySQLErrName[mysql.ErrWrongUsage]) - ErrAmbiguous = terror.ClassOptimizer.New(codeAmbiguous, mysql.MySQLErrName[mysql.ErrNonUniq]) - ErrUnknown = terror.ClassOptimizer.New(codeUnknown, mysql.MySQLErrName[mysql.ErrUnknown]) - ErrUnknownColumn = terror.ClassOptimizer.New(codeUnknownColumn, mysql.MySQLErrName[mysql.ErrBadField]) - ErrUnknownTable = terror.ClassOptimizer.New(codeUnknownTable, mysql.MySQLErrName[mysql.ErrUnknownTable]) - ErrWrongArguments = terror.ClassOptimizer.New(codeWrongArguments, mysql.MySQLErrName[mysql.ErrWrongArguments]) - ErrWrongNumberOfColumnsInSelect = terror.ClassOptimizer.New(codeWrongNumberOfColumnsInSelect, mysql.MySQLErrName[mysql.ErrWrongNumberOfColumnsInSelect]) - ErrBadGeneratedColumn = terror.ClassOptimizer.New(codeBadGeneratedColumn, mysql.MySQLErrName[mysql.ErrBadGeneratedColumn]) - ErrFieldNotInGroupBy = terror.ClassOptimizer.New(codeFieldNotInGroupBy, mysql.MySQLErrName[mysql.ErrFieldNotInGroupBy]) - ErrBadTable = terror.ClassOptimizer.New(codeBadTable, mysql.MySQLErrName[mysql.ErrBadTable]) - ErrKeyDoesNotExist = terror.ClassOptimizer.New(codeKeyDoesNotExist, mysql.MySQLErrName[mysql.ErrKeyDoesNotExist]) - ErrOperandColumns = terror.ClassOptimizer.New(codeOperandColumns, mysql.MySQLErrName[mysql.ErrOperandColumns]) - ErrInvalidWildCard = terror.ClassOptimizer.New(codeInvalidWildCard, "Wildcard fields without any table name appears in wrong place") - ErrInvalidGroupFuncUse = terror.ClassOptimizer.New(codeInvalidGroupFuncUse, mysql.MySQLErrName[mysql.ErrInvalidGroupFuncUse]) - ErrIllegalReference = terror.ClassOptimizer.New(codeIllegalReference, mysql.MySQLErrName[mysql.ErrIllegalReference]) - ErrNoDB = terror.ClassOptimizer.New(codeNoDB, mysql.MySQLErrName[mysql.ErrNoDB]) - ErrUnknownExplainFormat = terror.ClassOptimizer.New(codeUnknownExplainFormat, mysql.MySQLErrName[mysql.ErrUnknownExplainFormat]) - ErrWrongGroupField = terror.ClassOptimizer.New(codeWrongGroupField, mysql.MySQLErrName[mysql.ErrWrongGroupField]) - ErrDupFieldName = terror.ClassOptimizer.New(codeDupFieldName, mysql.MySQLErrName[mysql.ErrDupFieldName]) - ErrNonUpdatableTable = terror.ClassOptimizer.New(codeNonUpdatableTable, mysql.MySQLErrName[mysql.ErrNonUpdatableTable]) - ErrInternal = terror.ClassOptimizer.New(codeInternal, mysql.MySQLErrName[mysql.ErrInternal]) - ErrMixOfGroupFuncAndFields = terror.ClassOptimizer.New(codeMixOfGroupFuncAndFields, "In aggregated query without GROUP BY, expression #%d of SELECT list contains nonaggregated column '%s'; this is incompatible with sql_mode=only_full_group_by") - ErrNonUniqTable = terror.ClassOptimizer.New(codeNonUniqTable, mysql.MySQLErrName[mysql.ErrNonuniqTable]) - ErrWrongValueCountOnRow = terror.ClassOptimizer.New(mysql.ErrWrongValueCountOnRow, mysql.MySQLErrName[mysql.ErrWrongValueCountOnRow]) - ErrViewInvalid = terror.ClassOptimizer.New(mysql.ErrViewInvalid, mysql.MySQLErrName[mysql.ErrViewInvalid]) + ErrWrongUsage = terror.ClassOptimizer.New(codeWrongUsage, mysql.MySQLErrName[mysql.ErrWrongUsage]) + ErrAmbiguous = terror.ClassOptimizer.New(codeAmbiguous, mysql.MySQLErrName[mysql.ErrNonUniq]) + ErrUnknown = terror.ClassOptimizer.New(codeUnknown, mysql.MySQLErrName[mysql.ErrUnknown]) + ErrUnknownColumn = 
terror.ClassOptimizer.New(codeUnknownColumn, mysql.MySQLErrName[mysql.ErrBadField]) + ErrUnknownTable = terror.ClassOptimizer.New(codeUnknownTable, mysql.MySQLErrName[mysql.ErrUnknownTable]) + ErrWrongArguments = terror.ClassOptimizer.New(codeWrongArguments, mysql.MySQLErrName[mysql.ErrWrongArguments]) + ErrWrongNumberOfColumnsInSelect = terror.ClassOptimizer.New(codeWrongNumberOfColumnsInSelect, mysql.MySQLErrName[mysql.ErrWrongNumberOfColumnsInSelect]) + ErrBadGeneratedColumn = terror.ClassOptimizer.New(codeBadGeneratedColumn, mysql.MySQLErrName[mysql.ErrBadGeneratedColumn]) + ErrFieldNotInGroupBy = terror.ClassOptimizer.New(codeFieldNotInGroupBy, mysql.MySQLErrName[mysql.ErrFieldNotInGroupBy]) + ErrBadTable = terror.ClassOptimizer.New(codeBadTable, mysql.MySQLErrName[mysql.ErrBadTable]) + ErrKeyDoesNotExist = terror.ClassOptimizer.New(codeKeyDoesNotExist, mysql.MySQLErrName[mysql.ErrKeyDoesNotExist]) + ErrOperandColumns = terror.ClassOptimizer.New(codeOperandColumns, mysql.MySQLErrName[mysql.ErrOperandColumns]) + ErrInvalidWildCard = terror.ClassOptimizer.New(codeInvalidWildCard, "Wildcard fields without any table name appears in wrong place") + ErrInvalidGroupFuncUse = terror.ClassOptimizer.New(codeInvalidGroupFuncUse, mysql.MySQLErrName[mysql.ErrInvalidGroupFuncUse]) + ErrIllegalReference = terror.ClassOptimizer.New(codeIllegalReference, mysql.MySQLErrName[mysql.ErrIllegalReference]) + ErrNoDB = terror.ClassOptimizer.New(codeNoDB, mysql.MySQLErrName[mysql.ErrNoDB]) + ErrUnknownExplainFormat = terror.ClassOptimizer.New(codeUnknownExplainFormat, mysql.MySQLErrName[mysql.ErrUnknownExplainFormat]) + ErrWrongGroupField = terror.ClassOptimizer.New(codeWrongGroupField, mysql.MySQLErrName[mysql.ErrWrongGroupField]) + ErrDupFieldName = terror.ClassOptimizer.New(codeDupFieldName, mysql.MySQLErrName[mysql.ErrDupFieldName]) + ErrNonUpdatableTable = terror.ClassOptimizer.New(codeNonUpdatableTable, mysql.MySQLErrName[mysql.ErrNonUpdatableTable]) + ErrInternal = terror.ClassOptimizer.New(codeInternal, mysql.MySQLErrName[mysql.ErrInternal]) + ErrMixOfGroupFuncAndFields = terror.ClassOptimizer.New(codeMixOfGroupFuncAndFields, "In aggregated query without GROUP BY, expression #%d of SELECT list contains nonaggregated column '%s'; this is incompatible with sql_mode=only_full_group_by") + ErrNonUniqTable = terror.ClassOptimizer.New(codeNonUniqTable, mysql.MySQLErrName[mysql.ErrNonuniqTable]) + ErrWrongValueCountOnRow = terror.ClassOptimizer.New(mysql.ErrWrongValueCountOnRow, mysql.MySQLErrName[mysql.ErrWrongValueCountOnRow]) + ErrViewInvalid = terror.ClassOptimizer.New(mysql.ErrViewInvalid, mysql.MySQLErrName[mysql.ErrViewInvalid]) + ErrPrivilegeCheckFail = terror.ClassOptimizer.New(codePrivilegeCheckFail, "privilege check fail") + ErrWindowInvalidWindowFuncUse = terror.ClassOptimizer.New(codeWindowInvalidWindowFuncUse, mysql.MySQLErrName[mysql.ErrWindowInvalidWindowFuncUse]) + ErrWindowInvalidWindowFuncAliasUse = terror.ClassOptimizer.New(codeWindowInvalidWindowFuncAliasUse, mysql.MySQLErrName[mysql.ErrWindowInvalidWindowFuncAliasUse]) + ErrWindowNoSuchWindow = terror.ClassOptimizer.New(codeWindowNoSuchWindow, mysql.MySQLErrName[mysql.ErrWindowNoSuchWindow]) + ErrWindowCircularityInWindowGraph = terror.ClassOptimizer.New(codeWindowCircularityInWindowGraph, mysql.MySQLErrName[mysql.ErrWindowCircularityInWindowGraph]) + ErrWindowNoChildPartitioning = terror.ClassOptimizer.New(codeWindowNoChildPartitioning, mysql.MySQLErrName[mysql.ErrWindowNoChildPartitioning]) + ErrWindowNoInherentFrame = 
terror.ClassOptimizer.New(codeWindowNoInherentFrame, mysql.MySQLErrName[mysql.ErrWindowNoInherentFrame]) + ErrWindowNoRedefineOrderBy = terror.ClassOptimizer.New(codeWindowNoRedefineOrderBy, mysql.MySQLErrName[mysql.ErrWindowNoRedefineOrderBy]) + ErrWindowDuplicateName = terror.ClassOptimizer.New(codeWindowDuplicateName, mysql.MySQLErrName[mysql.ErrWindowDuplicateName]) + errTooBigPrecision = terror.ClassExpression.New(mysql.ErrTooBigPrecision, mysql.MySQLErrName[mysql.ErrTooBigPrecision]) ) func init() { @@ -115,6 +135,16 @@ func init() { codeNonUniqTable: mysql.ErrNonuniqTable, codeWrongNumberOfColumnsInSelect: mysql.ErrWrongNumberOfColumnsInSelect, codeWrongValueCountOnRow: mysql.ErrWrongValueCountOnRow, + + codeWindowInvalidWindowFuncUse: mysql.ErrWindowInvalidWindowFuncUse, + codeWindowInvalidWindowFuncAliasUse: mysql.ErrWindowInvalidWindowFuncAliasUse, + codeWindowNoSuchWindow: mysql.ErrWindowNoSuchWindow, + codeWindowCircularityInWindowGraph: mysql.ErrWindowCircularityInWindowGraph, + codeWindowNoChildPartitioning: mysql.ErrWindowNoChildPartitioning, + codeWindowNoInherentFrame: mysql.ErrWindowNoInherentFrame, + codeWindowNoRedefineOrderBy: mysql.ErrWindowNoRedefineOrderBy, + codeWindowDuplicateName: mysql.ErrWindowDuplicateName, + codeErrTooBigPrecision: mysql.ErrTooBigPrecision, } terror.ErrClassToMySQLCodes[terror.ClassOptimizer] = mysqlErrCodeMap } diff --git a/planner/core/exhaust_physical_plans.go b/planner/core/exhaust_physical_plans.go index 76275eb6a35cf..c8f3e84275d77 100644 --- a/planner/core/exhaust_physical_plans.go +++ b/planner/core/exhaust_physical_plans.go @@ -789,6 +789,18 @@ func (la *LogicalApply) exhaustPhysicalPlans(prop *property.PhysicalProperty) [] return []PhysicalPlan{apply} } +func (p *LogicalWindow) exhaustPhysicalPlans(prop *property.PhysicalProperty) []PhysicalPlan { + childProperty := &property.PhysicalProperty{ExpectedCnt: math.MaxFloat64, Items: p.ByItems, Enforced: true} + if !prop.IsPrefix(childProperty) { + return nil + } + window := PhysicalWindow{ + WindowFuncDesc: p.WindowFuncDesc, + }.Init(p.ctx, p.stats.ScaleByExpectCnt(prop.ExpectedCnt), childProperty) + window.SetSchema(p.Schema()) + return []PhysicalPlan{window} +} + // exhaustPhysicalPlans is only for implementing interface. DataSource and Dual generate task in `findBestTask` directly. func (p *baseLogicalPlan) exhaustPhysicalPlans(_ *property.PhysicalProperty) []PhysicalPlan { panic("baseLogicalPlan.exhaustPhysicalPlans() should never be called.") diff --git a/planner/core/explain.go b/planner/core/explain.go index f4b37efd869d1..44e9ef4c81099 100644 --- a/planner/core/explain.go +++ b/planner/core/explain.go @@ -297,3 +297,9 @@ func (p *PhysicalTopN) ExplainInfo() string { fmt.Fprintf(buffer, ", offset:%v, count:%v", p.Offset, p.Count) return buffer.String() } + +// ExplainInfo implements PhysicalPlan interface. +func (p *PhysicalWindow) ExplainInfo() string { + // TODO: Add explain info for partition by, order by and frame. 
+ return p.WindowFuncDesc.String() +} diff --git a/planner/core/expression_rewriter.go b/planner/core/expression_rewriter.go index 2a4af05983ace..5adc459cbca8f 100644 --- a/planner/core/expression_rewriter.go +++ b/planner/core/expression_rewriter.go @@ -27,6 +27,7 @@ import ( "github.com/pingcap/tidb/infoschema" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/sessionctx/variable" + "github.com/pingcap/tidb/table" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/types/parser_driver" "github.com/pingcap/tidb/util/chunk" @@ -169,6 +170,7 @@ type expressionRewriter struct { insertPlan *Insert } +// constructBinaryOpFunction converts binary operator functions // 1. If op are EQ or NE or NullEQ, constructBinaryOpFunctions converts (a0,a1,a2) op (b0,b1,b2) to (a0 op b0) and (a1 op b1) and (a2 op b2) // 2. Else constructBinaryOpFunctions converts (a0,a1,a2) op (b0,b1,b2) to // `IF( a0 NE b0, a0 op b0, @@ -303,12 +305,25 @@ func (er *expressionRewriter) Enter(inNode ast.Node) (ast.Node, bool) { } er.ctxStack = append(er.ctxStack, expression.NewValuesFunc(er.ctx, col.Index, col.RetType)) return inNode, true + case *ast.WindowFuncExpr: + return er.handleWindowFunction(v) default: er.asScalar = true } return inNode, false } +func (er *expressionRewriter) handleWindowFunction(v *ast.WindowFuncExpr) (ast.Node, bool) { + windowPlan, err := er.b.buildWindowFunction(er.p, v, er.aggrMap) + if err != nil { + er.err = err + return v, false + } + er.ctxStack = append(er.ctxStack, windowPlan.GetWindowResultColumn()) + er.p = windowPlan + return v, true +} + func (er *expressionRewriter) handleCompareSubquery(v *ast.CompareSubqueryExpr) (ast.Node, bool) { v.L.Accept(er) if er.err != nil { @@ -751,7 +766,7 @@ func (er *expressionRewriter) Leave(originInNode ast.Node) (retNode ast.Node, ok } switch v := inNode.(type) { case *ast.AggregateFuncExpr, *ast.ColumnNameExpr, *ast.ParenthesesExpr, *ast.WhenClause, - *ast.SubqueryExpr, *ast.ExistsSubqueryExpr, *ast.CompareSubqueryExpr, *ast.ValuesExpr: + *ast.SubqueryExpr, *ast.ExistsSubqueryExpr, *ast.CompareSubqueryExpr, *ast.ValuesExpr, *ast.WindowFuncExpr: case *driver.ValueExpr: value := &expression.Constant{Value: v.Datum, RetType: &v.Type} er.ctxStack = append(er.ctxStack, value) @@ -782,6 +797,13 @@ func (er *expressionRewriter) Leave(originInNode ast.Node) (retNode ast.Node, ok if er.err != nil { return retNode, false } + + // check the decimal precision of "CAST(AS TIME)". 
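// Illustrative aside (not part of the patch): assuming types.MaxFsp is 6, a cast such as
//     SELECT CAST(d AS TIME(7))
// is now rejected here with errTooBigPrecision by the checkTimePrecision helper added below,
// rather than being passed straight to BuildCastFunction unchecked.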
+ er.err = er.checkTimePrecision(v.Tp) + if er.err != nil { + return retNode, false + } + er.ctxStack[len(er.ctxStack)-1] = expression.BuildCastFunction(er.ctx, arg, v.Tp) case *ast.PatternLikeExpr: er.likeToScalarFunc(v) @@ -799,6 +821,8 @@ func (er *expressionRewriter) Leave(originInNode ast.Node) (retNode ast.Node, ok er.isNullToExpression(v) case *ast.IsTruthExpr: er.isTrueToScalarFunc(v) + case *ast.DefaultExpr: + er.evalDefaultExpr(v) default: er.err = errors.Errorf("UnknownType: %T", v) return retNode, false @@ -810,6 +834,13 @@ func (er *expressionRewriter) Leave(originInNode ast.Node) (retNode ast.Node, ok return originInNode, true } +func (er *expressionRewriter) checkTimePrecision(ft *types.FieldType) error { + if ft.EvalType() == types.ETDuration && ft.Decimal > types.MaxFsp { + return errTooBigPrecision.GenWithStackByArgs(ft.Decimal, "CAST", types.MaxFsp) + } + return nil +} + func (er *expressionRewriter) useCache() bool { return er.ctx.GetSessionVars().StmtCtx.UseCache } @@ -848,7 +879,13 @@ func (er *expressionRewriter) rewriteVariable(v *ast.VariableExpr) { return } } - if v.IsGlobal { + sysVar := variable.SysVars[name] + if sysVar == nil { + er.err = variable.UnknownSystemVar.GenWithStackByArgs(name) + return + } + // Variable is @@gobal.variable_name or variable is only global scope variable. + if v.IsGlobal || sysVar.Scope == variable.ScopeGlobal { val, err = variable.GetGlobalSystemVar(sessionVars, name) } else { val, err = variable.GetSessionSystemVar(sessionVars, name) @@ -1292,3 +1329,86 @@ func (er *expressionRewriter) toColumn(v *ast.ColumnName) { } er.err = ErrUnknownColumn.GenWithStackByArgs(v.String(), clauseMsg[er.b.curClause]) } + +func (er *expressionRewriter) evalDefaultExpr(v *ast.DefaultExpr) { + stkLen := len(er.ctxStack) + var colExpr *expression.Column + switch c := er.ctxStack[stkLen-1].(type) { + case *expression.Column: + colExpr = c + case *expression.CorrelatedColumn: + colExpr = &c.Column + default: + colExpr, er.err = er.schema.FindColumn(v.Name) + if er.err != nil { + er.err = errors.Trace(er.err) + return + } + if colExpr == nil { + er.err = ErrUnknownColumn.GenWithStackByArgs(v.Name.OrigColName(), "field_list") + return + } + } + dbName := colExpr.DBName + if dbName.O == "" { + // if database name is not specified, use current database name + dbName = model.NewCIStr(er.ctx.GetSessionVars().CurrentDB) + } + if colExpr.OrigTblName.O == "" { + // column is evaluated by some expressions, for example: + // `select default(c) from (select (a+1) as c from t) as t0` + // in such case, a 'no default' error is returned + er.err = table.ErrNoDefaultValue.GenWithStackByArgs(colExpr.ColName) + return + } + var tbl table.Table + tbl, er.err = er.b.is.TableByName(dbName, colExpr.OrigTblName) + if er.err != nil { + return + } + colName := colExpr.OrigColName.O + if colName == "" { + // in some cases, OrigColName is empty, use ColName instead + colName = colExpr.ColName.O + } + col := table.FindCol(tbl.Cols(), colName) + if col == nil { + er.err = ErrUnknownColumn.GenWithStackByArgs(v.Name, "field_list") + return + } + isCurrentTimestamp := hasCurrentDatetimeDefault(col) + var val *expression.Constant + switch { + case isCurrentTimestamp && col.Tp == mysql.TypeDatetime: + // for DATETIME column with current_timestamp, use NULL to be compatible with MySQL 5.7 + val = expression.Null + case isCurrentTimestamp && col.Tp == mysql.TypeTimestamp: + // for TIMESTAMP column with current_timestamp, use 0 to be compatible with MySQL 5.7 + zero := types.Time{ + Time: 
types.ZeroTime, + Type: mysql.TypeTimestamp, + Fsp: col.Decimal, + } + val = &expression.Constant{ + Value: types.NewDatum(zero), + RetType: types.NewFieldType(mysql.TypeTimestamp), + } + default: + // for other columns, just use what it is + val, er.err = er.b.getDefaultValue(col) + } + if er.err != nil { + return + } + er.ctxStack = er.ctxStack[:stkLen-1] + er.ctxStack = append(er.ctxStack, val) +} + +// hasCurrentDatetimeDefault checks if column has current_timestamp default value +func hasCurrentDatetimeDefault(col *table.Column) bool { + x, ok := col.DefaultValue.(string) + if !ok { + return false + } + return strings.ToLower(x) == ast.CurrentTimestamp +} diff --git a/planner/core/expression_rewriter_test.go b/planner/core/expression_rewriter_test.go index bedd1328446ca..6089ac43f20ea 100644 --- a/planner/core/expression_rewriter_test.go +++ b/planner/core/expression_rewriter_test.go @@ -17,6 +17,7 @@ import ( . "github.com/pingcap/check" "github.com/pingcap/tidb/util/testkit" "github.com/pingcap/tidb/util/testleak" + "github.com/pingcap/tidb/util/testutil" ) var _ = Suite(&testExpressionRewriterSuite{}) @@ -58,3 +59,73 @@ func (s *testExpressionRewriterSuite) TestBinaryOpFunction(c *C) { tk.MustQuery("SELECT * FROM t WHERE (a,b,c) <= (1,2,3) order by b").Check(testkit.Rows("1 1 ", "1 2 3")) tk.MustQuery("SELECT * FROM t WHERE (a,b,c) > (1,2,3) order by b").Check(testkit.Rows("1 3 ")) } + +func (s *testExpressionRewriterSuite) TestDefaultFunction(c *C) { + defer testleak.AfterTest(c)() + store, dom, err := newStoreWithBootstrap() + c.Assert(err, IsNil) + tk := testkit.NewTestKit(c, store) + defer func() { + dom.Close() + store.Close() + }() + tk.MustExec("use test") + tk.MustExec("drop table if exists t1") + tk.MustExec(`create table t1( + a varchar(10) default 'def', + b varchar(10), + c int default '10', + d double default '3.14', + e datetime default '20180101', + f datetime default current_timestamp);`) + tk.MustExec("insert into t1(a, b, c, d) values ('1', '1', 1, 1)") + tk.MustQuery(`select + default(a) as defa, + default(b) as defb, + default(c) as defc, + default(d) as defd, + default(e) as defe, + default(f) as deff + from t1`).Check(testutil.RowsWithSep("|", "def||10|3.14|2018-01-01 00:00:00|")) + err = tk.ExecToErr("select default(x) from t1") + c.Assert(err.Error(), Equals, "[planner:1054]Unknown column 'x' in 'field list'") + + tk.MustQuery("select default(a0) from (select a as a0 from t1) as t0").Check(testkit.Rows("def")) + err = tk.ExecToErr("select default(a0) from (select a+1 as a0 from t1) as t0") + c.Assert(err.Error(), Equals, "[table:1364]Field 'a0' doesn't have a default value") + + tk.MustExec("create table t2(a varchar(10), b varchar(10))") + tk.MustExec("insert into t2 values ('1', '1')") + err = tk.ExecToErr("select default(a) from t1, t2") + c.Assert(err.Error(), Equals, "[planner:1052]Column 'a' in field list is ambiguous") + tk.MustQuery("select default(t1.a) from t1, t2").Check(testkit.Rows("def")) + + tk.MustExec(`create table t3( + a datetime default current_timestamp, + b timestamp default current_timestamp, + c timestamp(6) default current_timestamp(6), + d varchar(20) default 'current_timestamp')`) + tk.MustExec("insert into t3 values ()") + tk.MustQuery(`select + default(a) as defa, + default(b) as defb, + default(c) as defc, + default(d) as defd + from t3`).Check(testutil.RowsWithSep("|", "|0000-00-00 00:00:00|0000-00-00 00:00:00.000000|current_timestamp")) + + tk.MustExec(`create table t4(a int default 1, b varchar(5))`) + tk.MustExec(`insert 
into t4 values (0, 'B'), (1, 'B'), (2, 'B')`) + tk.MustExec(`create table t5(d int default 0, e varchar(5))`) + tk.MustExec(`insert into t5 values (5, 'B')`) + + tk.MustQuery(`select a from t4 where a > (select default(d) from t5 where t4.b = t5.e)`).Check(testkit.Rows("1", "2")) + tk.MustQuery(`select a from t4 where a > (select default(a) from t5 where t4.b = t5.e)`).Check(testkit.Rows("2")) + + tk.MustExec("prepare stmt from 'select default(a) from t1';") + tk.MustQuery("execute stmt").Check(testkit.Rows("def")) + tk.MustExec("alter table t1 modify a varchar(10) default 'DEF'") + tk.MustQuery("execute stmt").Check(testkit.Rows("DEF")) + + tk.MustExec("update t1 set c = c + default(c)") + tk.MustQuery("select c from t1").Check(testkit.Rows("11")) +} diff --git a/planner/core/initialize.go b/planner/core/initialize.go index 9c2b698c927f0..ff2a3dbc7b059 100644 --- a/planner/core/initialize.go +++ b/planner/core/initialize.go @@ -81,6 +81,8 @@ const ( TypeTableReader = "TableReader" // TypeIndexReader is the type of IndexReader. TypeIndexReader = "IndexReader" + // TypeWindow is the type of Window. + TypeWindow = "Window" ) // Init initializes LogicalAggregation. @@ -231,6 +233,20 @@ func (p PhysicalMaxOneRow) Init(ctx sessionctx.Context, stats *property.StatsInf return &p } +// Init initializes LogicalWindow. +func (p LogicalWindow) Init(ctx sessionctx.Context) *LogicalWindow { + p.baseLogicalPlan = newBaseLogicalPlan(ctx, TypeWindow, &p) + return &p +} + +// Init initializes PhysicalWindow. +func (p PhysicalWindow) Init(ctx sessionctx.Context, stats *property.StatsInfo, props ...*property.PhysicalProperty) *PhysicalWindow { + p.basePhysicalPlan = newBasePhysicalPlan(ctx, TypeWindow, &p) + p.childrenReqProps = props + p.stats = stats + return &p +} + // Init initializes Update. func (p Update) Init(ctx sessionctx.Context) *Update { p.basePlan = newBasePlan(ctx, TypeUpdate) diff --git a/planner/core/logical_plan_builder.go b/planner/core/logical_plan_builder.go index 592bc86519e1f..6e4eb5fbccacb 100644 --- a/planner/core/logical_plan_builder.go +++ b/planner/core/logical_plan_builder.go @@ -33,6 +33,7 @@ import ( "github.com/pingcap/tidb/expression/aggregation" "github.com/pingcap/tidb/infoschema" "github.com/pingcap/tidb/metrics" + "github.com/pingcap/tidb/planner/property" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/statistics" "github.com/pingcap/tidb/table" @@ -596,13 +597,35 @@ func (b *PlanBuilder) buildProjectionField(id, position int, field *ast.SelectFi } // buildProjection returns a Projection plan and non-aux columns length. -func (b *PlanBuilder) buildProjection(p LogicalPlan, fields []*ast.SelectField, mapper map[*ast.AggregateFuncExpr]int) (LogicalPlan, int, error) { +func (b *PlanBuilder) buildProjection(p LogicalPlan, fields []*ast.SelectField, mapper map[*ast.AggregateFuncExpr]int, considerWindow bool) (LogicalPlan, int, error) { b.optFlag |= flagEliminateProjection b.curClause = fieldList proj := LogicalProjection{Exprs: make([]expression.Expression, 0, len(fields))}.Init(b.ctx) schema := expression.NewSchema(make([]*expression.Column, 0, len(fields))...) oldLen := 0 - for _, field := range fields { + for i, field := range fields { + if !field.Auxiliary { + oldLen++ + } + + isWindowFuncField := ast.HasWindowFlag(field.Expr) + // Although window functions occur in the select fields, they have to be processed after the having clause. + // So when we build the projection for select fields, we need to skip the window functions.
+ // When `considerWindow` is false, we will only build fields for non-window functions, so we add fake placeholders. + // for window functions. These fake placeholders will be erased in column pruning. + // When `considerWindow` is true, all the non-window fields have been built, so we just use the schema columns. + if (considerWindow && !isWindowFuncField) || (!considerWindow && isWindowFuncField) { + var expr expression.Expression + if isWindowFuncField { + expr = expression.Zero + } else { + expr = p.Schema().Columns[i] + } + proj.Exprs = append(proj.Exprs, expr) + col := b.buildProjectionField(proj.id, schema.Len()+1, field, expr) + schema.Append(col) + continue + } newExpr, np, err := b.rewrite(field.Expr, p, mapper, true) if err != nil { return nil, 0, errors.Trace(err) @@ -613,10 +636,6 @@ func (b *PlanBuilder) buildProjection(p LogicalPlan, fields []*ast.SelectField, col := b.buildProjectionField(proj.id, schema.Len()+1, field, newExpr) schema.Append(col) - - if !field.Auxiliary { - oldLen++ - } } proj.SetSchema(schema) proj.SetChildren(p) @@ -999,10 +1018,11 @@ func resolveFromSelectFields(v *ast.ColumnNameExpr, fields []*ast.SelectField, i return } -// havingAndOrderbyExprResolver visits Expr tree. +// havingWindowAndOrderbyExprResolver visits Expr tree. // It converts ColunmNameExpr to AggregateFuncExpr and collects AggregateFuncExpr. -type havingAndOrderbyExprResolver struct { +type havingWindowAndOrderbyExprResolver struct { inAggFunc bool + inWindowFunc bool inExpr bool orderBy bool err error @@ -1016,10 +1036,12 @@ type havingAndOrderbyExprResolver struct { } // Enter implements Visitor interface. -func (a *havingAndOrderbyExprResolver) Enter(n ast.Node) (node ast.Node, skipChildren bool) { +func (a *havingWindowAndOrderbyExprResolver) Enter(n ast.Node) (node ast.Node, skipChildren bool) { switch n.(type) { case *ast.AggregateFuncExpr: a.inAggFunc = true + case *ast.WindowFuncExpr: + a.inWindowFunc = true case *driver.ParamMarkerExpr, *ast.ColumnNameExpr, *ast.ColumnName: case *ast.SubqueryExpr, *ast.ExistsSubqueryExpr: // Enter a new context, skip it. @@ -1031,7 +1053,7 @@ func (a *havingAndOrderbyExprResolver) Enter(n ast.Node) (node ast.Node, skipChi return n, false } -func (a *havingAndOrderbyExprResolver) resolveFromSchema(v *ast.ColumnNameExpr, schema *expression.Schema) (int, error) { +func (a *havingWindowAndOrderbyExprResolver) resolveFromSchema(v *ast.ColumnNameExpr, schema *expression.Schema) (int, error) { col, err := schema.FindColumn(v.Name) if err != nil { return -1, errors.Trace(err) @@ -1045,7 +1067,7 @@ func (a *havingAndOrderbyExprResolver) resolveFromSchema(v *ast.ColumnNameExpr, Name: col.ColName, } for i, field := range a.selectFields { - if c, ok := field.Expr.(*ast.ColumnNameExpr); ok && colMatch(newColName, c.Name) { + if c, ok := field.Expr.(*ast.ColumnNameExpr); ok && colMatch(c.Name, newColName) { return i, nil } } @@ -1059,7 +1081,7 @@ func (a *havingAndOrderbyExprResolver) resolveFromSchema(v *ast.ColumnNameExpr, } // Leave implements Visitor interface. 
-func (a *havingAndOrderbyExprResolver) Leave(n ast.Node) (node ast.Node, ok bool) { +func (a *havingWindowAndOrderbyExprResolver) Leave(n ast.Node) (node ast.Node, ok bool) { switch v := n.(type) { case *ast.AggregateFuncExpr: a.inAggFunc = false @@ -1069,9 +1091,15 @@ func (a *havingAndOrderbyExprResolver) Leave(n ast.Node) (node ast.Node, ok bool Expr: v, AsName: model.NewCIStr(fmt.Sprintf("sel_agg_%d", len(a.selectFields))), }) + case *ast.WindowFuncExpr: + a.inWindowFunc = false + if a.curClause == havingClause { + a.err = ErrWindowInvalidWindowFuncUse.GenWithStackByArgs(v.F) + return node, false + } case *ast.ColumnNameExpr: resolveFieldsFirst := true - if a.inAggFunc || (a.orderBy && a.inExpr) { + if a.inAggFunc || a.inWindowFunc || (a.orderBy && a.inExpr) { resolveFieldsFirst = false } if !a.inAggFunc && !a.orderBy { @@ -1089,6 +1117,10 @@ func (a *havingAndOrderbyExprResolver) Leave(n ast.Node) (node ast.Node, ok bool if a.err != nil { return node, false } + if index != -1 && a.curClause == havingClause && ast.HasWindowFlag(a.selectFields[index].Expr) { + a.err = ErrWindowInvalidWindowFuncAliasUse.GenWithStackByArgs(v.Name.Name.O) + return node, false + } if index == -1 { if a.orderBy { index, a.err = a.resolveFromSchema(v, a.p.Schema()) @@ -1102,8 +1134,12 @@ func (a *havingAndOrderbyExprResolver) Leave(n ast.Node) (node ast.Node, ok bool var err error index, err = a.resolveFromSchema(v, a.p.Schema()) _ = err - if index == -1 { + if index == -1 && a.curClause != windowClause { index, a.err = resolveFromSelectFields(v, a.selectFields, false) + if index != -1 && a.curClause == havingClause && ast.HasWindowFlag(a.selectFields[index].Expr) { + a.err = ErrWindowInvalidWindowFuncAliasUse.GenWithStackByArgs(v.Name.Name.O) + return node, false + } } } if a.err != nil { @@ -1137,7 +1173,7 @@ func (a *havingAndOrderbyExprResolver) Leave(n ast.Node) (node ast.Node, ok bool // When we rewrite the order by / having expression, we will find column in map at first. func (b *PlanBuilder) resolveHavingAndOrderBy(sel *ast.SelectStmt, p LogicalPlan) ( map[*ast.AggregateFuncExpr]int, map[*ast.AggregateFuncExpr]int, error) { - extractor := &havingAndOrderbyExprResolver{ + extractor := &havingWindowAndOrderbyExprResolver{ p: p, selectFields: sel.Fields.Fields, aggMapper: make(map[*ast.AggregateFuncExpr]int), @@ -1190,6 +1226,31 @@ func (b *PlanBuilder) extractAggFuncs(fields []*ast.SelectField) ([]*ast.Aggrega return aggList, totalAggMapper } +// resolveWindowFunction will process window functions and resolve the columns that don't exist in select fields. +func (b *PlanBuilder) resolveWindowFunction(sel *ast.SelectStmt, p LogicalPlan) ( + map[*ast.AggregateFuncExpr]int, error) { + extractor := &havingWindowAndOrderbyExprResolver{ + p: p, + selectFields: sel.Fields.Fields, + aggMapper: make(map[*ast.AggregateFuncExpr]int), + colMapper: b.colMapper, + outerSchemas: b.outerSchemas, + } + extractor.curClause = windowClause + for _, field := range sel.Fields.Fields { + if !ast.HasWindowFlag(field.Expr) { + continue + } + n, ok := field.Expr.Accept(extractor) + if !ok { + return nil, extractor.err + } + field.Expr = n.(ast.ExprNode) + } + sel.Fields.Fields = extractor.selectFields + return extractor.aggMapper, nil +} + // gbyResolver resolves group by items from select fields. 
type gbyResolver struct { ctx sessionctx.Context @@ -1234,6 +1295,8 @@ func (g *gbyResolver) Leave(inNode ast.Node) (ast.Node, bool) { ret.Accept(extractor) if len(extractor.AggFuncs) != 0 { err = ErrIllegalReference.GenWithStackByArgs(v.Name.OrigColName(), "reference to group function") + } else if ast.HasWindowFlag(ret) { + err = ErrIllegalReference.GenWithStackByArgs(v.Name.OrigColName(), "reference to window function") } else { return ret, true } @@ -1270,7 +1333,7 @@ func (g *gbyResolver) Leave(inNode ast.Node) (ast.Node, bool) { func tblInfoFromCol(from ast.ResultSetNode, col *expression.Column) *model.TableInfo { var tableList []*ast.TableName - tableList = extractTableList(from, tableList) + tableList = extractTableList(from, tableList, true) for _, field := range tableList { if field.Name.L == col.TblName.L { return field.TableInfo @@ -1727,6 +1790,7 @@ func (b *PlanBuilder) buildSelect(sel *ast.SelectStmt) (p LogicalPlan, err error var ( aggFuncs []*ast.AggregateFuncExpr havingMap, orderMap, totalMap map[*ast.AggregateFuncExpr]int + windowMap map[*ast.AggregateFuncExpr]int gbyCols []expression.Expression ) @@ -1759,6 +1823,13 @@ func (b *PlanBuilder) buildSelect(sel *ast.SelectStmt) (p LogicalPlan, err error } } + hasWindowFuncField := b.detectSelectWindow(sel) + if hasWindowFuncField { + windowMap, err = b.resolveWindowFunction(sel, p) + if err != nil { + return nil, err + } + } // We must resolve having and order by clause before build projection, // because when the query is "select a+1 as b from t having sum(b) < 0", we must replace sum(b) to sum(a+1), // which only can be done before building projection and extracting Agg functions. @@ -1792,7 +1863,9 @@ func (b *PlanBuilder) buildSelect(sel *ast.SelectStmt) (p LogicalPlan, err error } var oldLen int - p, oldLen, err = b.buildProjection(p, sel.Fields.Fields, totalMap) + // According to https://dev.mysql.com/doc/refman/8.0/en/window-functions-usage.html, + // we can only process window functions after having clause, so `considerWindow` is false now. + p, oldLen, err = b.buildProjection(p, sel.Fields.Fields, totalMap, false) if err != nil { return nil, errors.Trace(err) } @@ -1805,6 +1878,19 @@ func (b *PlanBuilder) buildSelect(sel *ast.SelectStmt) (p LogicalPlan, err error } } + b.windowSpecs, err = buildWindowSpecs(sel.WindowSpecs) + if err != nil { + return nil, err + } + + if hasWindowFuncField { + // Now we build the window function fields. + p, oldLen, err = b.buildProjection(p, sel.Fields.Fields, windowMap, true) + if err != nil { + return nil, err + } + } + if sel.Distinct { p = b.buildDistinct(p, oldLen) } @@ -1900,10 +1986,10 @@ func (b *PlanBuilder) buildDataSource(tn *ast.TableName) (LogicalPlan, error) { } tableInfo := tbl.Meta() - b.visitInfo = appendVisitInfo(b.visitInfo, mysql.SelectPriv, dbName.L, tableInfo.Name.L, "") + b.visitInfo = appendVisitInfo(b.visitInfo, mysql.SelectPriv, dbName.L, tableInfo.Name.L, "", nil) if tableInfo.IsView() { - return b.buildDataSourceFromView(dbName, tableInfo) + return b.BuildDataSourceFromView(dbName, tableInfo) } if tableInfo.GetPartitionInfo() != nil { @@ -1917,6 +2003,13 @@ func (b *PlanBuilder) buildDataSource(tn *ast.TableName) (LogicalPlan, error) { var columns []*table.Column if b.inUpdateStmt { + // create table t(a int, b int). + // Imagine that, There are 2 TiDB instances in the cluster, name A, B. We add a column `c` to table t in the TiDB cluster. + // One of the TiDB, A, the column type in its infoschema is changed to public. 
And in the other TiDB, B, the column type is + still StateWriteReorganization. + // TiDB A: insert into t values(1, 2, 3); + // TiDB B: update t set a = 2 where b = 2; + // If we use tbl.Cols() here, the update statement will ignore the col `c`, and the data `3` will be lost. columns = tbl.WritableCols() } else { columns = tbl.Cols() @@ -1998,7 +2091,8 @@ func (b *PlanBuilder) buildDataSource(tn *ast.TableName) (LogicalPlan, error) { return result, nil } -func (b *PlanBuilder) buildDataSourceFromView(dbName model.CIStr, tableInfo *model.TableInfo) (LogicalPlan, error) { +// BuildDataSourceFromView is used to build LogicalPlan from view +func (b *PlanBuilder) BuildDataSourceFromView(dbName model.CIStr, tableInfo *model.TableInfo) (LogicalPlan, error) { charset, collation := b.ctx.GetSessionVars().GetCharsetInfo() selectNode, err := parser.New().ParseOneStmt(tableInfo.View.SelectStmt, charset, collation) if err != nil { @@ -2210,13 +2304,13 @@ func (b *PlanBuilder) buildUpdate(update *ast.UpdateStmt) (Plan, error) { } var tableList []*ast.TableName - tableList = extractTableList(sel.From.TableRefs, tableList) + tableList = extractTableList(sel.From.TableRefs, tableList, false) for _, t := range tableList { dbName := t.Schema.L if dbName == "" { dbName = b.ctx.GetSessionVars().CurrentDB } - b.visitInfo = appendVisitInfo(b.visitInfo, mysql.SelectPriv, dbName, t.Name.L, "") + b.visitInfo = appendVisitInfo(b.visitInfo, mysql.SelectPriv, dbName, t.Name.L, "", nil) } if sel.Where != nil { @@ -2330,7 +2424,12 @@ func (b *PlanBuilder) buildUpdateLists(tableList []*ast.TableName, list []*ast.A } for _, assign := range newList { col := assign.Col - b.visitInfo = appendVisitInfo(b.visitInfo, mysql.UpdatePriv, col.DBName.L, col.TblName.L, "") + + dbName := col.DBName.L + if dbName == "" { + dbName = b.ctx.GetSessionVars().CurrentDB + } + b.visitInfo = appendVisitInfo(b.visitInfo, mysql.UpdatePriv, dbName, col.OrigTblName.L, "", nil) } return newList, p, nil } @@ -2347,15 +2446,27 @@ func extractTableAsNameForUpdate(p LogicalPlan, asNames map[*model.TableInfo][]* asNames[x.tableInfo] = append(asNames[x.tableInfo], alias) } case *LogicalProjection: - if x.calculateGenCols { - ds := x.Children()[0].(*DataSource) - alias := extractTableAlias(x) - if alias != nil { - if _, ok := asNames[ds.tableInfo]; !ok { - asNames[ds.tableInfo] = make([]*model.CIStr, 0, 1) - } - asNames[ds.tableInfo] = append(asNames[ds.tableInfo], alias) + if !x.calculateGenCols { + return + } + + ds, isDS := x.Children()[0].(*DataSource) + if !isDS { + // try to extract the DataSource below a LogicalUnionScan. + if us, isUS := x.Children()[0].(*LogicalUnionScan); isUS { + ds, isDS = us.Children()[0].(*DataSource) + } + } + if !isDS { + return + } + + alias := extractTableAlias(x) + if alias != nil { + if _, ok := asNames[ds.tableInfo]; !ok { + asNames[ds.tableInfo] = make([]*model.CIStr, 0, 1) } + asNames[ds.tableInfo] = append(asNames[ds.tableInfo], alias) } default: for _, child := range p.Children() { @@ -2432,7 +2543,7 @@ func (b *PlanBuilder) buildDelete(delete *ast.DeleteStmt) (Plan, error) { del.SetSchema(expression.NewSchema()) var tableList []*ast.TableName - tableList = extractTableList(delete.TableRefs.TableRefs, tableList) + tableList = extractTableList(delete.TableRefs.TableRefs, tableList, true) // Collect visitInfo.
if delete.Tables != nil { @@ -2468,7 +2579,7 @@ func (b *PlanBuilder) buildDelete(delete *ast.DeleteStmt) (Plan, error) { // check sql like: `delete b from (select * from t) as a, t` return nil, ErrUnknownTable.GenWithStackByArgs(tn.Name.O, "MULTI DELETE") } - b.visitInfo = appendVisitInfo(b.visitInfo, mysql.DeletePriv, tn.Schema.L, tn.TableInfo.Name.L, "") + b.visitInfo = appendVisitInfo(b.visitInfo, mysql.DeletePriv, tn.Schema.L, tn.TableInfo.Name.L, "", nil) } } else { // Delete from a, b, c, d. @@ -2477,22 +2588,190 @@ func (b *PlanBuilder) buildDelete(delete *ast.DeleteStmt) (Plan, error) { if dbName == "" { dbName = b.ctx.GetSessionVars().CurrentDB } - b.visitInfo = appendVisitInfo(b.visitInfo, mysql.DeletePriv, dbName, v.Name.L, "") + b.visitInfo = appendVisitInfo(b.visitInfo, mysql.DeletePriv, dbName, v.Name.L, "", nil) } } return del, nil } +// buildProjectionForWindow builds the projection for expressions in the window specification that are not columns, +// so after the projection, window functions only need to deal with columns. +func (b *PlanBuilder) buildProjectionForWindow(p LogicalPlan, expr *ast.WindowFuncExpr, aggMap map[*ast.AggregateFuncExpr]int) (LogicalPlan, []property.Item, []expression.Expression, error) { + b.optFlag |= flagEliminateProjection + + var items []*ast.ByItem + spec := expr.Spec + if spec.Ref.L != "" { + ref, ok := b.windowSpecs[spec.Ref.L] + if !ok { + return nil, nil, nil, ErrWindowNoSuchWindow.GenWithStackByArgs(spec.Ref.O) + } + err := mergeWindowSpec(&spec, &ref) + if err != nil { + return nil, nil, nil, err + } + } + if spec.PartitionBy != nil { + items = append(items, spec.PartitionBy.Items...) + } + if spec.OrderBy != nil { + items = append(items, spec.OrderBy.Items...) + } + projLen := len(p.Schema().Columns) + len(items) + len(expr.Args) + proj := LogicalProjection{Exprs: make([]expression.Expression, 0, projLen)}.Init(b.ctx) + schema := expression.NewSchema(make([]*expression.Column, 0, projLen)...)
+ for _, col := range p.Schema().Columns { + proj.Exprs = append(proj.Exprs, col) + schema.Append(col) + } + + transformer := &itemTransformer{} + propertyItems := make([]property.Item, 0, len(items)) + for _, item := range items { + newExpr, _ := item.Expr.Accept(transformer) + item.Expr = newExpr.(ast.ExprNode) + it, np, err := b.rewrite(item.Expr, p, aggMap, true) + if err != nil { + return nil, nil, nil, err + } + p = np + if col, ok := it.(*expression.Column); ok { + propertyItems = append(propertyItems, property.Item{Col: col, Desc: item.Desc}) + continue + } + proj.Exprs = append(proj.Exprs, it) + col := &expression.Column{ + ColName: model.NewCIStr(fmt.Sprintf("%d_proj_window_%d", p.ID(), schema.Len())), + UniqueID: b.ctx.GetSessionVars().AllocPlanColumnID(), + RetType: it.GetType(), + } + schema.Append(col) + propertyItems = append(propertyItems, property.Item{Col: col, Desc: item.Desc}) + } + + newArgList := make([]expression.Expression, 0, len(expr.Args)) + for _, arg := range expr.Args { + newArg, np, err := b.rewrite(arg, p, aggMap, true) + if err != nil { + return nil, nil, nil, err + } + p = np + if col, ok := newArg.(*expression.Column); ok { + newArgList = append(newArgList, col) + continue + } + proj.Exprs = append(proj.Exprs, newArg) + col := &expression.Column{ + ColName: model.NewCIStr(fmt.Sprintf("%d_proj_window_%d", p.ID(), schema.Len())), + UniqueID: b.ctx.GetSessionVars().AllocPlanColumnID(), + RetType: newArg.GetType(), + } + schema.Append(col) + newArgList = append(newArgList, col) + } + + proj.SetSchema(schema) + proj.SetChildren(p) + return proj, propertyItems, newArgList, nil +} + +func (b *PlanBuilder) buildWindowFunction(p LogicalPlan, expr *ast.WindowFuncExpr, aggMap map[*ast.AggregateFuncExpr]int) (*LogicalWindow, error) { + p, byItems, args, err := b.buildProjectionForWindow(p, expr, aggMap) + if err != nil { + return nil, err + } + + desc := aggregation.NewWindowFuncDesc(b.ctx, expr.F, args) + window := LogicalWindow{ + WindowFuncDesc: desc, + ByItems: byItems, + }.Init(b.ctx) + schema := p.Schema().Clone() + schema.Append(&expression.Column{ + ColName: model.NewCIStr(fmt.Sprintf("%d_window_%d", window.id, p.Schema().Len())), + UniqueID: b.ctx.GetSessionVars().AllocPlanColumnID(), + IsReferenced: true, + RetType: desc.RetTp, + }) + window.SetChildren(p) + window.SetSchema(schema) + return window, nil +} + +// resolveWindowSpec resolve window specifications for sql like `select ... from t window w1 as (w2), w2 as (partition by a)`. +// We need to resolve the referenced window to get the definition of current window spec. 
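// A worked example (illustrative): given
//     window w1 as (w2), w2 as (partition by a order by b)
// resolving w1 copies w2's PARTITION BY and ORDER BY into w1 via mergeWindowSpec,
// while mutually referencing definitions such as
//     window w1 as (w2), w2 as (w1)
// are caught through the inStack map and reported as ErrWindowCircularityInWindowGraph.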
+func resolveWindowSpec(spec *ast.WindowSpec, specs map[string]ast.WindowSpec, inStack map[string]bool) error { + if inStack[spec.Name.L] { + return errors.Trace(ErrWindowCircularityInWindowGraph) + } + if spec.Ref.L == "" { + return nil + } + ref, ok := specs[spec.Ref.L] + if !ok { + return ErrWindowNoSuchWindow.GenWithStackByArgs(spec.Ref.O) + } + inStack[spec.Name.L] = true + err := resolveWindowSpec(&ref, specs, inStack) + if err != nil { + return err + } + inStack[spec.Name.L] = false + return mergeWindowSpec(spec, &ref) +} + +func mergeWindowSpec(spec, ref *ast.WindowSpec) error { + if ref.Frame != nil { + return ErrWindowNoInherentFrame.GenWithStackByArgs(ref.Name.O) + } + if ref.OrderBy != nil { + if spec.OrderBy != nil { + name := spec.Name.O + if name == "" { + name = "" + } + return ErrWindowNoRedefineOrderBy.GenWithStackByArgs(name, ref.Name.O) + } + spec.OrderBy = ref.OrderBy + } + if spec.PartitionBy != nil { + return errors.Trace(ErrWindowNoChildPartitioning) + } + spec.PartitionBy = ref.PartitionBy + spec.Ref = model.NewCIStr("") + return nil +} + +func buildWindowSpecs(specs []ast.WindowSpec) (map[string]ast.WindowSpec, error) { + specsMap := make(map[string]ast.WindowSpec, len(specs)) + for _, spec := range specs { + if _, ok := specsMap[spec.Name.L]; ok { + return nil, ErrWindowDuplicateName.GenWithStackByArgs(spec.Name.O) + } + specsMap[spec.Name.L] = spec + } + inStack := make(map[string]bool, len(specs)) + for _, spec := range specs { + err := resolveWindowSpec(&spec, specsMap, inStack) + if err != nil { + return nil, err + } + } + return specsMap, nil +} + // extractTableList extracts all the TableNames from node. -func extractTableList(node ast.ResultSetNode, input []*ast.TableName) []*ast.TableName { +// If asName is true, extract AsName prior to OrigName. +// Privilege check should use OrigName, while expression may use AsName. 
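// Concretely: buildUpdate above calls extractTableList with asName=false so privileges
// are collected against the original table names, while buildDelete and tblInfoFromCol
// pass asName=true so an aliased reference such as `delete a1 from t as a1 ...` is
// resolved by its alias.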
+func extractTableList(node ast.ResultSetNode, input []*ast.TableName, asName bool) []*ast.TableName { switch x := node.(type) { case *ast.Join: - input = extractTableList(x.Left, input) - input = extractTableList(x.Right, input) + input = extractTableList(x.Left, input, asName) + input = extractTableList(x.Right, input, asName) case *ast.TableSource: if s, ok := x.Source.(*ast.TableName); ok { - if x.AsName.L != "" { + if x.AsName.L != "" && asName { newTableName := *s newTableName.Name = x.AsName s.Name = x.AsName @@ -2527,12 +2806,13 @@ func extractTableSourceAsNames(node ast.ResultSetNode, input []string, onlySelec return input } -func appendVisitInfo(vi []visitInfo, priv mysql.PrivilegeType, db, tbl, col string) []visitInfo { +func appendVisitInfo(vi []visitInfo, priv mysql.PrivilegeType, db, tbl, col string, err error) []visitInfo { return append(vi, visitInfo{ privilege: priv, db: db, table: tbl, column: col, + err: err, }) } diff --git a/planner/core/logical_plan_test.go b/planner/core/logical_plan_test.go index d0c173e7ae374..b45f738681abc 100644 --- a/planner/core/logical_plan_test.go +++ b/planner/core/logical_plan_test.go @@ -1354,130 +1354,137 @@ func (s *testPlanSuite) TestVisitInfo(c *C) { { sql: "insert into t (a) values (1)", ans: []visitInfo{ - {mysql.InsertPriv, "test", "t", ""}, + {mysql.InsertPriv, "test", "t", "", nil}, }, }, { sql: "delete from t where a = 1", ans: []visitInfo{ - {mysql.DeletePriv, "test", "t", ""}, - {mysql.SelectPriv, "test", "t", ""}, + {mysql.DeletePriv, "test", "t", "", nil}, + {mysql.SelectPriv, "test", "t", "", nil}, }, }, { sql: "delete from a1 using t as a1 inner join t as a2 where a1.a = a2.a", ans: []visitInfo{ - {mysql.DeletePriv, "test", "t", ""}, - {mysql.SelectPriv, "test", "t", ""}, + {mysql.DeletePriv, "test", "t", "", nil}, + {mysql.SelectPriv, "test", "t", "", nil}, }, }, { sql: "update t set a = 7 where a = 1", ans: []visitInfo{ - {mysql.UpdatePriv, "test", "t", ""}, - {mysql.SelectPriv, "test", "t", ""}, + {mysql.UpdatePriv, "test", "t", "", nil}, + {mysql.SelectPriv, "test", "t", "", nil}, }, }, { sql: "update t, (select * from t) a1 set t.a = a1.a;", ans: []visitInfo{ - {mysql.UpdatePriv, "test", "t", ""}, - {mysql.SelectPriv, "test", "t", ""}, + {mysql.UpdatePriv, "test", "t", "", nil}, + {mysql.SelectPriv, "test", "t", "", nil}, + }, + }, + { + sql: "update t a1 set a1.a = a1.a + 1", + ans: []visitInfo{ + {mysql.UpdatePriv, "test", "t", "", nil}, + {mysql.SelectPriv, "test", "t", "", nil}, }, }, { sql: "select a, sum(e) from t group by a", ans: []visitInfo{ - {mysql.SelectPriv, "test", "t", ""}, + {mysql.SelectPriv, "test", "t", "", nil}, }, }, { sql: "truncate table t", ans: []visitInfo{ - {mysql.DeletePriv, "test", "t", ""}, + {mysql.DeletePriv, "test", "t", "", nil}, }, }, { sql: "drop table t", ans: []visitInfo{ - {mysql.DropPriv, "test", "t", ""}, + {mysql.DropPriv, "test", "t", "", nil}, }, }, { sql: "create table t (a int)", ans: []visitInfo{ - {mysql.CreatePriv, "test", "t", ""}, + {mysql.CreatePriv, "test", "t", "", nil}, }, }, { sql: "create table t1 like t", ans: []visitInfo{ - {mysql.CreatePriv, "test", "t1", ""}, - {mysql.SelectPriv, "test", "t", ""}, + {mysql.CreatePriv, "test", "t1", "", nil}, + {mysql.SelectPriv, "test", "t", "", nil}, }, }, { sql: "create database test", ans: []visitInfo{ - {mysql.CreatePriv, "test", "", ""}, + {mysql.CreatePriv, "test", "", "", nil}, }, }, { sql: "drop database test", ans: []visitInfo{ - {mysql.DropPriv, "test", "", ""}, + {mysql.DropPriv, "test", "", "", nil}, }, }, { sql: 
"create index t_1 on t (a)", ans: []visitInfo{ - {mysql.IndexPriv, "test", "t", ""}, + {mysql.IndexPriv, "test", "t", "", nil}, }, }, { sql: "drop index e on t", ans: []visitInfo{ - {mysql.IndexPriv, "test", "t", ""}, + {mysql.IndexPriv, "test", "t", "", nil}, }, }, { sql: `create user 'test'@'%' identified by '123456'`, ans: []visitInfo{ - {mysql.CreateUserPriv, "", "", ""}, + {mysql.CreateUserPriv, "", "", "", nil}, }, }, { sql: `drop user 'test'@'%'`, ans: []visitInfo{ - {mysql.CreateUserPriv, "", "", ""}, + {mysql.CreateUserPriv, "", "", "", nil}, }, }, { sql: `grant all privileges on test.* to 'test'@'%'`, ans: []visitInfo{ - {mysql.SelectPriv, "test", "", ""}, - {mysql.InsertPriv, "test", "", ""}, - {mysql.UpdatePriv, "test", "", ""}, - {mysql.DeletePriv, "test", "", ""}, - {mysql.CreatePriv, "test", "", ""}, - {mysql.DropPriv, "test", "", ""}, - {mysql.GrantPriv, "test", "", ""}, - {mysql.AlterPriv, "test", "", ""}, - {mysql.ExecutePriv, "test", "", ""}, - {mysql.IndexPriv, "test", "", ""}, + {mysql.SelectPriv, "test", "", "", nil}, + {mysql.InsertPriv, "test", "", "", nil}, + {mysql.UpdatePriv, "test", "", "", nil}, + {mysql.DeletePriv, "test", "", "", nil}, + {mysql.CreatePriv, "test", "", "", nil}, + {mysql.DropPriv, "test", "", "", nil}, + {mysql.GrantPriv, "test", "", "", nil}, + {mysql.AlterPriv, "test", "", "", nil}, + {mysql.ExecutePriv, "test", "", "", nil}, + {mysql.IndexPriv, "test", "", "", nil}, }, }, { sql: `grant select on test.ttt to 'test'@'%'`, ans: []visitInfo{ - {mysql.SelectPriv, "test", "ttt", ""}, - {mysql.GrantPriv, "test", "ttt", ""}, + {mysql.SelectPriv, "test", "ttt", "", nil}, + {mysql.GrantPriv, "test", "ttt", "", nil}, }, }, { sql: `revoke all privileges on *.* from 'test'@'%'`, ans: []visitInfo{ - {mysql.SuperPriv, "", "", ""}, + {mysql.SuperPriv, "", "", "", nil}, }, }, { @@ -1487,7 +1494,7 @@ func (s *testPlanSuite) TestVisitInfo(c *C) { { sql: `show create table test.ttt`, ans: []visitInfo{ - {mysql.AllPrivMask, "test", "ttt", ""}, + {mysql.AllPrivMask, "test", "ttt", "", nil}, }, }, } @@ -1897,3 +1904,132 @@ func (s *testPlanSuite) TestSelectView(c *C) { c.Assert(ToString(p), Equals, tt.best, comment) } } + +func (s *testPlanSuite) TestWindowFunction(c *C) { + defer testleak.AfterTest(c)() + tests := []struct { + sql string + result string + }{ + { + sql: "select a, avg(a) over(partition by a) from t", + result: "TableReader(Table(t))->Window(avg(test.t.a))->Projection", + }, + { + sql: "select a, avg(a) over(partition by b) from t", + result: "TableReader(Table(t))->Sort->Window(avg(test.t.a))->Projection", + }, + { + sql: "select a, avg(a+1) over(partition by (a+1)) from t", + result: "TableReader(Table(t))->Projection->Sort->Window(avg(2_proj_window_3))->Projection", + }, + { + sql: "select a, avg(a) over(order by a asc, b desc) from t order by a asc, b desc", + result: "TableReader(Table(t))->Sort->Window(avg(test.t.a))->Projection", + }, + { + sql: "select a, b as a, avg(a) over(partition by a) from t", + result: "TableReader(Table(t))->Window(avg(test.t.a))->Projection", + }, + { + sql: "select a, b as z, sum(z) over() from t", + result: "[planner:1054]Unknown column 'z' in 'field list'", + }, + { + sql: "select a, b as z from t order by (sum(z) over())", + result: "TableReader(Table(t))->Window(sum(test.t.z))->Sort->Projection", + }, + { + sql: "select sum(avg(a)) over() from t", + result: "TableReader(Table(t)->StreamAgg)->StreamAgg->Window(sum(sel_agg_2))->Projection", + }, + { + sql: "select b from t order by(sum(a) over())", + result: 
"TableReader(Table(t))->Window(sum(test.t.a))->Sort->Projection", + }, + { + sql: "select b from t order by(sum(a) over(partition by a))", + result: "TableReader(Table(t))->Window(sum(test.t.a))->Sort->Projection", + }, + { + sql: "select b from t order by(sum(avg(a)) over())", + result: "TableReader(Table(t)->StreamAgg)->StreamAgg->Window(sum(sel_agg_2))->Sort->Projection", + }, + { + sql: "select a from t having (select sum(a) over() as w from t tt where a > t.a)", + result: "Apply{TableReader(Table(t))->TableReader(Table(t)->Sel([gt(tt.a, test.t.a)]))->Window(sum(tt.a))->MaxOneRow->Sel([w])}->Projection", + }, + { + sql: "select avg(a) over() as w from t having w > 1", + result: "[planner:3594]You cannot use the alias 'w' of an expression containing a window function in this context.'", + }, + { + sql: "select sum(a) over() as sum_a from t group by sum_a", + result: "[planner:1247]Reference 'sum_a' not supported (reference to window function)", + }, + { + sql: "select sum(a) over() from t window w1 as (w2)", + result: "[planner:3579]Window name 'w2' is not defined.", + }, + { + sql: "select sum(a) over(w) from t", + result: "[planner:3579]Window name 'w' is not defined.", + }, + { + sql: "select sum(a) over() from t window w1 as (w2), w2 as (w1)", + result: "[planner:3580]There is a circularity in the window dependency graph.", + }, + { + sql: "select sum(a) over(w partition by a) from t window w as ()", + result: "[planner:3581]A window which depends on another cannot define partitioning.", + }, + { + sql: "select sum(a) over(w) from t window w as (rows between 1 preceding AND 1 following)", + result: "[planner:3582]Window 'w' has a frame definition, so cannot be referenced by another window.", + }, + { + sql: "select sum(a) over(w order by b) from t window w as (order by a)", + result: "[planner:3583]Window '' cannot inherit 'w' since both contain an ORDER BY clause.", + }, + { + sql: "select sum(a) over() from t window w1 as (), w1 as ()", + result: "[planner:3591]Window 'w1' is defined twice.", + }, + { + sql: "select sum(a) over(w1), avg(a) over(w2) from t window w1 as (partition by a), w2 as (w1)", + result: "TableReader(Table(t))->Window(sum(test.t.a))->Window(avg(test.t.a))->Projection", + }, + { + sql: "select a from t window w1 as (partition by a) order by (sum(a) over(w1))", + result: "TableReader(Table(t))->Window(sum(test.t.a))->Sort->Projection", + }, + } + + s.Parser.EnableWindowFunc(true) + defer func() { + s.Parser.EnableWindowFunc(false) + }() + for i, tt := range tests { + comment := Commentf("case:%v sql:%s", i, tt.sql) + stmt, err := s.ParseOneStmt(tt.sql, "", "") + c.Assert(err, IsNil, comment) + Preprocess(s.ctx, stmt, s.is, false) + builder := &PlanBuilder{ + ctx: MockContext(), + is: s.is, + colMapper: make(map[*ast.ColumnNameExpr]int), + } + p, err := builder.Build(stmt) + if err != nil { + c.Assert(err.Error(), Equals, tt.result, comment) + continue + } + c.Assert(err, IsNil) + p, err = logicalOptimize(builder.optFlag, p.(LogicalPlan)) + c.Assert(err, IsNil) + lp, ok := p.(LogicalPlan) + c.Assert(ok, IsTrue) + p, err = physicalOptimize(lp) + c.Assert(ToString(p), Equals, tt.result, comment) + } +} diff --git a/planner/core/logical_plans.go b/planner/core/logical_plans.go index d7e8d18b53e98..cee54de15f6b9 100644 --- a/planner/core/logical_plans.go +++ b/planner/core/logical_plans.go @@ -22,6 +22,7 @@ import ( "github.com/pingcap/parser/mysql" "github.com/pingcap/tidb/expression" "github.com/pingcap/tidb/expression/aggregation" + 
"github.com/pingcap/tidb/planner/property" "github.com/pingcap/tidb/statistics" "github.com/pingcap/tidb/table" "github.com/pingcap/tidb/types" @@ -42,6 +43,7 @@ var ( _ LogicalPlan = &LogicalSort{} _ LogicalPlan = &LogicalLock{} _ LogicalPlan = &LogicalLimit{} + _ LogicalPlan = &LogicalWindow{} ) // JoinType contains CrossJoin, InnerJoin, LeftOuterJoin, RightOuterJoin, FullOuterJoin, SemiJoin. @@ -617,3 +619,17 @@ type LogicalLock struct { Lock ast.SelectLockType } + +// LogicalWindow represents a logical window function plan. +type LogicalWindow struct { + logicalSchemaProducer + + WindowFuncDesc *aggregation.WindowFuncDesc + ByItems []property.Item // ByItems is composed of `PARTITION BY` and `ORDER BY` items. + // TODO: add frame clause +} + +// GetWindowResultColumn returns the column storing the result of the window function. +func (p *LogicalWindow) GetWindowResultColumn() *expression.Column { + return p.schema.Columns[p.schema.Len()-1] +} diff --git a/planner/core/optimizer.go b/planner/core/optimizer.go index 42aa815c26da1..8446e37309ce4 100644 --- a/planner/core/optimizer.go +++ b/planner/core/optimizer.go @@ -83,13 +83,16 @@ func BuildLogicalPlan(ctx sessionctx.Context, node ast.Node, is infoschema.InfoS } // CheckPrivilege checks the privilege for a user. -func CheckPrivilege(pm privilege.Manager, vs []visitInfo) bool { +func CheckPrivilege(pm privilege.Manager, vs []visitInfo) error { for _, v := range vs { if !pm.RequestVerification(v.db, v.table, v.column, v.privilege) { - return false + if v.err == nil { + return ErrPrivilegeCheckFail + } + return v.err } } - return true + return nil } // DoOptimize optimizes a logical plan to a physical plan. @@ -105,10 +108,18 @@ func DoOptimize(flag uint64, logic LogicalPlan) (PhysicalPlan, error) { if err != nil { return nil, errors.Trace(err) } - finalPlan := eliminatePhysicalProjection(physical) + finalPlan := postOptimize(physical) return finalPlan, nil } +func postOptimize(plan PhysicalPlan) PhysicalPlan { + plan = eliminatePhysicalProjection(plan) + + // build projection below aggregation + plan = buildProjBelowAgg(plan) + return plan +} + func logicalOptimize(flag uint64, logic LogicalPlan) (LogicalPlan, error) { var err error for i, rule := range optRuleList { diff --git a/planner/core/physical_plan_test.go b/planner/core/physical_plan_test.go index 2b350faf6977f..7e34f1f588177 100644 --- a/planner/core/physical_plan_test.go +++ b/planner/core/physical_plan_test.go @@ -472,7 +472,7 @@ func (s *testPlanSuite) TestDAGPlanBuilderSubquery(c *C) { // Test Nested sub query. { sql: "select * from t where exists (select s.a from t s where s.c in (select c from t as k where k.d = s.d) having sum(s.a) = t.a )", - best: "LeftHashJoin{TableReader(Table(t))->Projection->MergeSemiJoin{IndexReader(Index(t.c_d_e)[[NULL,+inf]])->IndexReader(Index(t.c_d_e)[[NULL,+inf]])}(s.c,k.c)(s.d,k.d)->StreamAgg}(cast(test.t.a),sel_agg_1)->Projection", + best: "LeftHashJoin{TableReader(Table(t))->Projection->MergeSemiJoin{IndexReader(Index(t.c_d_e)[[NULL,+inf]])->IndexReader(Index(t.c_d_e)[[NULL,+inf]])}(s.c,k.c)(s.d,k.d)->Projection->StreamAgg}(cast(test.t.a),sel_agg_1)->Projection", }, // Test Semi Join + Order by. 
{ @@ -821,7 +821,7 @@ func (s *testPlanSuite) TestDAGPlanBuilderAgg(c *C) { }, { sql: "select sum(distinct a), avg(b + c) from t group by d", - best: "TableReader(Table(t))->HashAgg", + best: "TableReader(Table(t))->Projection->HashAgg", }, // Test group by (c + d) { @@ -841,27 +841,27 @@ func (s *testPlanSuite) TestDAGPlanBuilderAgg(c *C) { // Test hash agg + index double. { sql: "select sum(e), avg(b + c) from t where c = 1 and e = 1 group by d", - best: "IndexLookUp(Index(t.c_d_e)[[1,1]]->Sel([eq(test.t.e, 1)]), Table(t))->HashAgg", + best: "IndexLookUp(Index(t.c_d_e)[[1,1]]->Sel([eq(test.t.e, 1)]), Table(t))->Projection->HashAgg", }, // Test stream agg + index double. { sql: "select sum(e), avg(b + c) from t where c = 1 and b = 1 group by c", - best: "IndexLookUp(Index(t.c_d_e)[[1,1]], Table(t)->Sel([eq(test.t.b, 1)]))->StreamAgg", + best: "IndexLookUp(Index(t.c_d_e)[[1,1]], Table(t)->Sel([eq(test.t.b, 1)]))->Projection->StreamAgg", }, // Test hash agg + order. { sql: "select sum(e) as k, avg(b + c) from t where c = 1 and b = 1 and e = 1 group by d order by k", - best: "IndexLookUp(Index(t.c_d_e)[[1,1]]->Sel([eq(test.t.e, 1)]), Table(t)->Sel([eq(test.t.b, 1)]))->HashAgg->Sort", + best: "IndexLookUp(Index(t.c_d_e)[[1,1]]->Sel([eq(test.t.e, 1)]), Table(t)->Sel([eq(test.t.b, 1)]))->Projection->HashAgg->Sort", }, // Test stream agg + order. { sql: "select sum(e) as k, avg(b + c) from t where c = 1 and b = 1 and e = 1 group by c order by k", - best: "IndexLookUp(Index(t.c_d_e)[[1,1]]->Sel([eq(test.t.e, 1)]), Table(t)->Sel([eq(test.t.b, 1)]))->StreamAgg->Sort", + best: "IndexLookUp(Index(t.c_d_e)[[1,1]]->Sel([eq(test.t.e, 1)]), Table(t)->Sel([eq(test.t.b, 1)]))->Projection->StreamAgg->Sort", }, // Test agg can't push down. { sql: "select sum(to_base64(e)) from t where c = 1", - best: "IndexReader(Index(t.c_d_e)[[1,1]])->StreamAgg", + best: "IndexReader(Index(t.c_d_e)[[1,1]])->Projection->StreamAgg", }, { sql: "select (select count(1) k from t s where s.a = t.a having k != 0) from t", @@ -870,7 +870,7 @@ func (s *testPlanSuite) TestDAGPlanBuilderAgg(c *C) { // Test stream agg with multi group by columns. 
{ sql: "select sum(to_base64(e)) from t group by e,d,c order by c", - best: "IndexReader(Index(t.c_d_e)[[NULL,+inf]])->StreamAgg->Projection", + best: "IndexReader(Index(t.c_d_e)[[NULL,+inf]])->Projection->StreamAgg->Projection", }, { sql: "select sum(e+1) from t group by e,d,c order by c", @@ -878,7 +878,7 @@ func (s *testPlanSuite) TestDAGPlanBuilderAgg(c *C) { }, { sql: "select sum(to_base64(e)) from t group by e,d,c order by c,e", - best: "IndexReader(Index(t.c_d_e)[[NULL,+inf]])->StreamAgg->Sort->Projection", + best: "IndexReader(Index(t.c_d_e)[[NULL,+inf]])->Projection->StreamAgg->Sort->Projection", }, { sql: "select sum(e+1) from t group by e,d,c order by c,e", @@ -917,16 +917,16 @@ func (s *testPlanSuite) TestDAGPlanBuilderAgg(c *C) { // Test merge join + stream agg { sql: "select sum(a.g), sum(b.g) from t a join t b on a.g = b.g group by a.g", - best: "MergeInnerJoin{IndexReader(Index(t.g)[[NULL,+inf]])->IndexReader(Index(t.g)[[NULL,+inf]])}(a.g,b.g)->StreamAgg", + best: "MergeInnerJoin{IndexReader(Index(t.g)[[NULL,+inf]])->IndexReader(Index(t.g)[[NULL,+inf]])}(a.g,b.g)->Projection->StreamAgg", }, // Test index join + stream agg { sql: "select /*+ tidb_inlj(a,b) */ sum(a.g), sum(b.g) from t a join t b on a.g = b.g and a.g > 60 group by a.g order by a.g limit 1", - best: "IndexJoin{IndexReader(Index(t.g)[(60,+inf]])->IndexReader(Index(t.g)[[NULL,+inf]]->Sel([gt(b.g, 60)]))}(a.g,b.g)->StreamAgg->Limit->Projection", + best: "IndexJoin{IndexReader(Index(t.g)[(60,+inf]])->IndexReader(Index(t.g)[[NULL,+inf]]->Sel([gt(b.g, 60)]))}(a.g,b.g)->Projection->StreamAgg->Limit->Projection", }, { sql: "select sum(a.g), sum(b.g) from t a join t b on a.g = b.g and a.a>5 group by a.g order by a.g limit 1", - best: "IndexJoin{IndexReader(Index(t.g)[[NULL,+inf]]->Sel([gt(a.a, 5)]))->IndexReader(Index(t.g)[[NULL,+inf]])}(a.g,b.g)->StreamAgg->Limit->Projection", + best: "IndexJoin{IndexReader(Index(t.g)[[NULL,+inf]]->Sel([gt(a.a, 5)]))->IndexReader(Index(t.g)[[NULL,+inf]])}(a.g,b.g)->Projection->StreamAgg->Limit->Projection", }, { sql: "select sum(d) from t", @@ -1217,7 +1217,7 @@ func (s *testPlanSuite) TestAggEliminater(c *C) { // If max/min contains scalar function, we can still do transformation. { sql: "select max(a+1) from t;", - best: "TableReader(Table(t)->Sel([not(isnull(plus(test.t.a, 1)))])->TopN([plus(test.t.a, 1) true],0,1))->TopN([plus(test.t.a, 1) true],0,1)->StreamAgg", + best: "TableReader(Table(t)->Sel([not(isnull(plus(test.t.a, 1)))])->TopN([plus(test.t.a, 1) true],0,1))->TopN([plus(test.t.a, 1) true],0,1)->Projection->StreamAgg", }, // Do nothing to max+min. { diff --git a/planner/core/physical_plans.go b/planner/core/physical_plans.go index 4c66330fabdfe..4f404e88d9f2c 100644 --- a/planner/core/physical_plans.go +++ b/planner/core/physical_plans.go @@ -47,6 +47,7 @@ var ( _ PhysicalPlan = &PhysicalHashJoin{} _ PhysicalPlan = &PhysicalMergeJoin{} _ PhysicalPlan = &PhysicalUnionScan{} + _ PhysicalPlan = &PhysicalWindow{} ) // PhysicalTableReader is the table reader in tidb. @@ -373,3 +374,10 @@ type PhysicalTableDual struct { RowCount int } + +// PhysicalWindow is the physical operator of window function. 
+type PhysicalWindow struct { + physicalSchemaProducer + + WindowFuncDesc *aggregation.WindowFuncDesc +} diff --git a/planner/core/planbuilder.go b/planner/core/planbuilder.go index 87ec96a350385..ccd0f72e2f156 100644 --- a/planner/core/planbuilder.go +++ b/planner/core/planbuilder.go @@ -41,6 +41,7 @@ type visitInfo struct { db string table string column string + err error } type tableHintInfo struct { @@ -93,6 +94,7 @@ const ( onClause orderByClause whereClause + windowClause groupByClause showStatement globalOrderByClause @@ -108,6 +110,7 @@ var clauseMsg = map[clauseCode]string{ groupByClause: "group statement", showStatement: "show statement", globalOrderByClause: "global ORDER clause", + windowClause: "field list", // For window functions that in field list. } // PlanBuilder builds Plan from an ast.Node. @@ -134,6 +137,8 @@ type PlanBuilder struct { // inStraightJoin represents whether the current "SELECT" statement has // "STRAIGHT_JOIN" option. inStraightJoin bool + + windowSpecs map[string]ast.WindowSpec } // GetVisitInfo gets the visitInfo of the PlanBuilder. @@ -196,7 +201,7 @@ func (b *PlanBuilder) Build(node ast.Node) (Plan, error) { case *ast.BinlogStmt, *ast.FlushStmt, *ast.UseStmt, *ast.BeginStmt, *ast.CommitStmt, *ast.RollbackStmt, *ast.CreateUserStmt, *ast.SetPwdStmt, *ast.GrantStmt, *ast.DropUserStmt, *ast.AlterUserStmt, *ast.RevokeStmt, *ast.KillStmt, *ast.DropStatsStmt: - return b.buildSimple(node.(ast.StmtNode)), nil + return b.buildSimple(node.(ast.StmtNode)) case ast.DDLNode: return b.buildDDL(x) } @@ -245,6 +250,9 @@ func (b *PlanBuilder) buildDo(v *ast.DoStmt) (Plan, error) { func (b *PlanBuilder) buildSet(v *ast.SetStmt) (Plan, error) { p := &Set{} for _, vars := range v.Variables { + if vars.IsGlobal { + b.visitInfo = appendVisitInfo(b.visitInfo, mysql.SuperPriv, "", "", "", nil) + } assign := &expression.VarAssignment{ Name: vars.Name, IsGlobal: vars.IsGlobal, @@ -300,6 +308,15 @@ func (b *PlanBuilder) detectSelectAgg(sel *ast.SelectStmt) bool { return false } +func (b *PlanBuilder) detectSelectWindow(sel *ast.SelectStmt) bool { + for _, f := range sel.Fields.Fields { + if ast.HasWindowFlag(f.Expr) { + return true + } + } + return false +} + func getPathByIndexName(paths []*accessPath, idxName model.CIStr, tblInfo *model.TableInfo) *accessPath { var tablePath *accessPath for _, path := range paths { @@ -538,7 +555,7 @@ func (b *PlanBuilder) buildAdmin(as *ast.AdminStmt) (Plan, error) { } // Admin command can only be executed by administrator. 
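+	// The new trailing argument fills the err field just added to visitInfo;
+	// nil keeps the default behaviour, while a builder may attach a
+	// statement-specific error for the privilege check to report.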
- b.visitInfo = appendVisitInfo(b.visitInfo, mysql.SuperPriv, "", "", "") + b.visitInfo = appendVisitInfo(b.visitInfo, mysql.SuperPriv, "", "", "", nil) return ret, nil } @@ -750,8 +767,8 @@ const ( func (b *PlanBuilder) buildAnalyze(as *ast.AnalyzeTableStmt) (Plan, error) { for _, tbl := range as.TableNames { - b.visitInfo = appendVisitInfo(b.visitInfo, mysql.InsertPriv, tbl.Schema.O, tbl.Name.O, "") - b.visitInfo = appendVisitInfo(b.visitInfo, mysql.SelectPriv, tbl.Schema.O, tbl.Name.O, "") + b.visitInfo = appendVisitInfo(b.visitInfo, mysql.InsertPriv, tbl.Schema.O, tbl.Name.O, "", nil) + b.visitInfo = appendVisitInfo(b.visitInfo, mysql.SelectPriv, tbl.Schema.O, tbl.Name.O, "", nil) } if as.MaxNumBuckets == 0 { as.MaxNumBuckets = defaultMaxNumBuckets @@ -903,6 +920,7 @@ func (b *PlanBuilder) buildShow(show *ast.ShowStmt) (Plan, error) { Flag: show.Flag, Full: show.Full, User: show.User, + IfNotExists: show.IfNotExists, GlobalScope: show.GlobalScope, }.Init(b.ctx) switch showTp := show.Tp; showTp { @@ -921,7 +939,7 @@ func (b *PlanBuilder) buildShow(show *ast.ShowStmt) (Plan, error) { return nil, ErrNoDB } case ast.ShowCreateTable: - b.visitInfo = appendVisitInfo(b.visitInfo, mysql.AllPrivMask, show.Table.Schema.L, show.Table.Name.L, "") + b.visitInfo = appendVisitInfo(b.visitInfo, mysql.AllPrivMask, show.Table.Schema.L, show.Table.Name.L, "", nil) } p.SetSchema(buildShowSchema(show)) } @@ -954,16 +972,16 @@ func (b *PlanBuilder) buildShow(show *ast.ShowStmt) (Plan, error) { return p, nil } -func (b *PlanBuilder) buildSimple(node ast.StmtNode) Plan { +func (b *PlanBuilder) buildSimple(node ast.StmtNode) (Plan, error) { p := &Simple{Statement: node} switch raw := node.(type) { case *ast.CreateUserStmt, *ast.DropUserStmt, *ast.AlterUserStmt: - b.visitInfo = appendVisitInfo(b.visitInfo, mysql.CreateUserPriv, "", "", "") + b.visitInfo = appendVisitInfo(b.visitInfo, mysql.CreateUserPriv, "", "", "", nil) case *ast.GrantStmt: b.visitInfo = collectVisitInfoFromGrantStmt(b.visitInfo, raw) case *ast.RevokeStmt: - b.visitInfo = appendVisitInfo(b.visitInfo, mysql.SuperPriv, "", "", "") + b.visitInfo = appendVisitInfo(b.visitInfo, mysql.SuperPriv, "", "", "", nil) case *ast.KillStmt: // If you have the SUPER privilege, you can kill all threads and statements. // Otherwise, you can kill only your own threads and statements. @@ -973,12 +991,16 @@ func (b *PlanBuilder) buildSimple(node ast.StmtNode) Plan { if pi, ok := processList[raw.ConnectionID]; ok { loginUser := b.ctx.GetSessionVars().User if pi.User != loginUser.Username { - b.visitInfo = appendVisitInfo(b.visitInfo, mysql.SuperPriv, "", "", "") + b.visitInfo = appendVisitInfo(b.visitInfo, mysql.SuperPriv, "", "", "", nil) } } } + case *ast.UseStmt: + if raw.DBName == "" { + return nil, ErrNoDB + } } - return p + return p, nil } func collectVisitInfoFromGrantStmt(vi []visitInfo, stmt *ast.GrantStmt) []visitInfo { @@ -986,7 +1008,7 @@ func collectVisitInfoFromGrantStmt(vi []visitInfo, stmt *ast.GrantStmt) []visitI // and you must have the privileges that you are granting. 
dbName := stmt.Level.DBName tableName := stmt.Level.TableName - vi = appendVisitInfo(vi, mysql.GrantPriv, dbName, tableName, "") + vi = appendVisitInfo(vi, mysql.GrantPriv, dbName, tableName, "", nil) var allPrivs []mysql.PrivilegeType for _, item := range stmt.Privs { @@ -1001,11 +1023,11 @@ func collectVisitInfoFromGrantStmt(vi []visitInfo, stmt *ast.GrantStmt) []visitI } break } - vi = appendVisitInfo(vi, item.Priv, dbName, tableName, "") + vi = appendVisitInfo(vi, item.Priv, dbName, tableName, "", nil) } for _, priv := range allPrivs { - vi = appendVisitInfo(vi, priv, dbName, tableName, "") + vi = appendVisitInfo(vi, priv, dbName, tableName, "", nil) } return vi @@ -1093,6 +1115,7 @@ func (b *PlanBuilder) buildInsert(insert *ast.InsertStmt) (Plan, error) { privilege: mysql.InsertPriv, db: tn.DBInfo.Name.L, table: tableInfo.Name.L, + err: nil, }) mockTablePlan := LogicalTableDual{}.Init(b.ctx) @@ -1403,29 +1426,34 @@ func (b *PlanBuilder) buildDDL(node ast.DDLNode) (Plan, error) { privilege: mysql.AlterPriv, db: v.Table.Schema.L, table: v.Table.Name.L, + err: nil, }) case *ast.CreateDatabaseStmt: b.visitInfo = append(b.visitInfo, visitInfo{ privilege: mysql.CreatePriv, db: v.Name, + err: nil, }) case *ast.CreateIndexStmt: b.visitInfo = append(b.visitInfo, visitInfo{ privilege: mysql.IndexPriv, db: v.Table.Schema.L, table: v.Table.Name.L, + err: nil, }) case *ast.CreateTableStmt: b.visitInfo = append(b.visitInfo, visitInfo{ privilege: mysql.CreatePriv, db: v.Table.Schema.L, table: v.Table.Name.L, + err: nil, }) if v.ReferTable != nil { b.visitInfo = append(b.visitInfo, visitInfo{ privilege: mysql.SelectPriv, db: v.ReferTable.Schema.L, table: v.ReferTable.Name.L, + err: nil, }) } case *ast.CreateViewStmt: @@ -1450,18 +1478,21 @@ func (b *PlanBuilder) buildDDL(node ast.DDLNode) (Plan, error) { privilege: mysql.CreatePriv, db: v.ViewName.Schema.L, table: v.ViewName.Name.L, + err: nil, }) } case *ast.DropDatabaseStmt: b.visitInfo = append(b.visitInfo, visitInfo{ privilege: mysql.DropPriv, db: v.Name, + err: nil, }) case *ast.DropIndexStmt: b.visitInfo = append(b.visitInfo, visitInfo{ privilege: mysql.IndexPriv, db: v.Table.Schema.L, table: v.Table.Name.L, + err: nil, }) case *ast.DropTableStmt: for _, tableVal := range v.Tables { @@ -1469,6 +1500,7 @@ func (b *PlanBuilder) buildDDL(node ast.DDLNode) (Plan, error) { privilege: mysql.DropPriv, db: tableVal.Schema.L, table: tableVal.Name.L, + err: nil, }) } case *ast.TruncateTableStmt: @@ -1476,17 +1508,20 @@ func (b *PlanBuilder) buildDDL(node ast.DDLNode) (Plan, error) { privilege: mysql.DeletePriv, db: v.Table.Schema.L, table: v.Table.Name.L, + err: nil, }) case *ast.RenameTableStmt: b.visitInfo = append(b.visitInfo, visitInfo{ privilege: mysql.AlterPriv, db: v.OldTable.Schema.L, table: v.OldTable.Name.L, + err: nil, }) b.visitInfo = append(b.visitInfo, visitInfo{ privilege: mysql.AlterPriv, db: v.NewTable.Schema.L, table: v.NewTable.Name.L, + err: nil, }) } diff --git a/planner/core/rule_build_proj_below_agg.go b/planner/core/rule_build_proj_below_agg.go new file mode 100644 index 0000000000000..69d1a00cdb79d --- /dev/null +++ b/planner/core/rule_build_proj_below_agg.go @@ -0,0 +1,168 @@ +// Copyright 2018 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package core + +import ( + "fmt" + + "github.com/pingcap/parser/ast" + "github.com/pingcap/parser/model" + "github.com/pingcap/tidb/expression" + "github.com/pingcap/tidb/expression/aggregation" + "github.com/pingcap/tidb/sessionctx" + "github.com/pingcap/tidb/types" +) + +// buildProjBelowAgg builds a ProjOperator below AggOperator. +// If all the args of `aggFuncs`, and all the item of `groupByItems` +// are columns or constants, we do not need to build the `proj`. +func buildProjBelowAgg(plan PhysicalPlan) PhysicalPlan { + for _, child := range plan.Children() { + buildProjBelowAgg(child) + } + + var aggFuncs []*aggregation.AggFuncDesc + var groupByItems []expression.Expression + if aggHash, ok := plan.(*PhysicalHashAgg); ok { + aggFuncs = aggHash.AggFuncs + groupByItems = aggHash.GroupByItems + } else if aggStream, ok := plan.(*PhysicalStreamAgg); ok { + aggFuncs = aggStream.AggFuncs + groupByItems = aggStream.GroupByItems + } else { + return plan + } + + return doBuildProjBelowAgg(plan, aggFuncs, groupByItems) +} + +func doBuildProjBelowAgg(aggPlan PhysicalPlan, aggFuncs []*aggregation.AggFuncDesc, groupByItems []expression.Expression) PhysicalPlan { + hasScalarFunc := false + + // If the mode is FinalMode, we do not need to wrap cast upon the args, + // since the types of the args are already the expected. + if len(aggFuncs) > 0 && aggFuncs[0].Mode != aggregation.FinalMode { + wrapCastForAggArgs(aggPlan.context(), aggFuncs) + } + + for i := 0; !hasScalarFunc && i < len(aggFuncs); i++ { + for _, arg := range aggFuncs[i].Args { + _, isScalarFunc := arg.(*expression.ScalarFunction) + hasScalarFunc = hasScalarFunc || isScalarFunc + } + } + for i := 0; !hasScalarFunc && i < len(groupByItems); i++ { + _, isScalarFunc := groupByItems[i].(*expression.ScalarFunction) + hasScalarFunc = hasScalarFunc || isScalarFunc + } + if !hasScalarFunc { + return aggPlan + } + + projSchemaCols := make([]*expression.Column, 0, len(aggFuncs)+len(groupByItems)) + projExprs := make([]expression.Expression, 0, cap(projSchemaCols)) + cursor := 0 + + for _, f := range aggFuncs { + for i, arg := range f.Args { + if _, isCnst := arg.(*expression.Constant); isCnst { + continue + } + projExprs = append(projExprs, arg) + newArg := &expression.Column{ + RetType: arg.GetType(), + ColName: model.NewCIStr(fmt.Sprintf("col_%d", len(projSchemaCols))), + Index: cursor, + } + projSchemaCols = append(projSchemaCols, newArg) + f.Args[i] = newArg + cursor++ + } + } + + for i, item := range groupByItems { + if _, isCnst := item.(*expression.Constant); isCnst { + continue + } + projExprs = append(projExprs, item) + newArg := &expression.Column{ + RetType: item.GetType(), + ColName: model.NewCIStr(fmt.Sprintf("col_%d", len(projSchemaCols))), + Index: cursor, + } + projSchemaCols = append(projSchemaCols, newArg) + groupByItems[i] = newArg + cursor++ + } + + child := aggPlan.Children()[0] + prop := aggPlan.GetChildReqProps(0).Clone() + proj := PhysicalProjection{ + Exprs: projExprs, + AvoidColumnEvaluator: false, + }.Init(aggPlan.context(), child.statsInfo().ScaleByExpectCnt(prop.ExpectedCnt), prop) + proj.SetSchema(expression.NewSchema(projSchemaCols...)) + proj.SetChildren(child) + + 
aggPlan.SetChildren(proj) + return aggPlan +} + +// wrapCastForAggArgs wraps the args of an aggregate function with a cast function. +func wrapCastForAggArgs(ctx sessionctx.Context, funcs []*aggregation.AggFuncDesc) { + for _, f := range funcs { + // We do not need to wrap cast upon these functions, + // since the EvalXXX method called by the arg is determined by the corresponding arg type. + if f.Name == ast.AggFuncCount || f.Name == ast.AggFuncMin || f.Name == ast.AggFuncMax || f.Name == ast.AggFuncFirstRow { + continue + } + var castFunc func(ctx sessionctx.Context, expr expression.Expression) expression.Expression + switch retTp := f.RetTp; retTp.EvalType() { + case types.ETInt: + castFunc = expression.WrapWithCastAsInt + case types.ETReal: + castFunc = expression.WrapWithCastAsReal + case types.ETString: + castFunc = expression.WrapWithCastAsString + case types.ETDecimal: + castFunc = expression.WrapWithCastAsDecimal + default: + panic("should never happen in executorBuilder.wrapCastForAggArgs") + } + for i := range f.Args { + f.Args[i] = castFunc(ctx, f.Args[i]) + if f.Name != ast.AggFuncAvg && f.Name != ast.AggFuncSum { + continue + } + // After wrapping cast on the argument, flen etc. may not the same + // as the type of the aggregation function. The following part set + // the type of the argument exactly as the type of the aggregation + // function. + // Note: If the `Tp` of argument is the same as the `Tp` of the + // aggregation function, it will not wrap cast function on it + // internally. The reason of the special handling for `Column` is + // that the `RetType` of `Column` refers to the `infoschema`, so we + // need to set a new variable for it to avoid modifying the + // definition in `infoschema`. + if col, ok := f.Args[i].(*expression.Column); ok { + col.RetType = types.NewFieldType(col.RetType.Tp) + } + // originTp is used when the the `Tp` of column is TypeFloat32 while + // the type of the aggregation function is TypeFloat64. + originTp := f.Args[i].GetType().Tp + *(f.Args[i].GetType()) = *(f.RetTp) + f.Args[i].GetType().Tp = originTp + } + } +} diff --git a/planner/core/rule_column_pruning.go b/planner/core/rule_column_pruning.go index 05b239ee3baf7..bf2d86a916575 100644 --- a/planner/core/rule_column_pruning.go +++ b/planner/core/rule_column_pruning.go @@ -15,7 +15,6 @@ package core import ( "fmt" - "github.com/pingcap/parser/ast" "github.com/pingcap/parser/model" "github.com/pingcap/tidb/expression" @@ -265,3 +264,30 @@ func (p *LogicalLock) PruneColumns(parentUsedCols []*expression.Column) { p.children[0].PruneColumns(parentUsedCols) } } + +// PruneColumns implements LogicalPlan interface. +func (p *LogicalWindow) PruneColumns(parentUsedCols []*expression.Column) { + windowColumn := p.GetWindowResultColumn() + len := 0 + for _, col := range parentUsedCols { + if !windowColumn.Equal(nil, col) { + parentUsedCols[len] = col + len++ + } + } + parentUsedCols = parentUsedCols[:len] + parentUsedCols = p.extractUsedCols(parentUsedCols) + p.children[0].PruneColumns(parentUsedCols) + p.SetSchema(p.children[0].Schema().Clone()) + p.Schema().Append(windowColumn) +} + +func (p *LogicalWindow) extractUsedCols(parentUsedCols []*expression.Column) []*expression.Column { + for _, arg := range p.WindowFuncDesc.Args { + parentUsedCols = append(parentUsedCols, expression.ExtractColumns(arg)...) 
+ } + for _, by := range p.ByItems { + parentUsedCols = append(parentUsedCols, by.Col) + } + return parentUsedCols +} diff --git a/planner/core/rule_eliminate_projection.go b/planner/core/rule_eliminate_projection.go index 3e70bdb998327..597ce50a9e8d6 100644 --- a/planner/core/rule_eliminate_projection.go +++ b/planner/core/rule_eliminate_projection.go @@ -117,6 +117,8 @@ func (pe *projectionEliminater) eliminate(p LogicalPlan, replace map[string]*exp childFlag = false } else if _, isAgg := p.(*LogicalAggregation); isAgg || isProj { childFlag = true + } else if _, isWindow := p.(*LogicalWindow); isWindow { + childFlag = true } for i, child := range p.Children() { p.Children()[i] = pe.eliminate(child, replace, childFlag) @@ -204,3 +206,12 @@ func (lt *LogicalTopN) replaceExprColumns(replace map[string]*expression.Column) resolveExprAndReplace(byItem.Expr, replace) } } + +func (p *LogicalWindow) replaceExprColumns(replace map[string]*expression.Column) { + for _, arg := range p.WindowFuncDesc.Args { + resolveExprAndReplace(arg, replace) + } + for _, item := range p.ByItems { + resolveColumnAndReplace(item.Col, replace) + } +} diff --git a/planner/core/rule_predicate_push_down.go b/planner/core/rule_predicate_push_down.go index bf352289da908..1523ebeaf136f 100644 --- a/planner/core/rule_predicate_push_down.go +++ b/planner/core/rule_predicate_push_down.go @@ -465,3 +465,10 @@ func (p *LogicalJoin) outerJoinPropConst(predicates []expression.Expression) []e p.attachOnConds(joinConds) return predicates } + +// PredicatePushDown implements LogicalPlan PredicatePushDown interface. +func (p *LogicalWindow) PredicatePushDown(predicates []expression.Expression) ([]expression.Expression, LogicalPlan) { + // Window function forbids any condition to push down. + p.baseLogicalPlan.PredicatePushDown(nil) + return predicates, p +} diff --git a/planner/core/stats.go b/planner/core/stats.go index c9dac5b58c045..eadafd4961541 100644 --- a/planner/core/stats.go +++ b/planner/core/stats.go @@ -337,3 +337,16 @@ func (p *LogicalMaxOneRow) DeriveStats(childStats []*property.StatsInfo) (*prope p.stats = getSingletonStats(p.Schema().Len()) return p.stats, nil } + +// DeriveStats implement LogicalPlan DeriveStats interface. 
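+// A window function neither filters nor duplicates rows, so the child's
+// RowCount is kept as is; the only change is one extra Cardinality entry for
+// the appended window-result column, estimated as the row count.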
+func (p *LogicalWindow) DeriveStats(childStats []*property.StatsInfo) (*property.StatsInfo, error) { + childProfile := childStats[0] + childLen := len(childProfile.Cardinality) + p.stats = &property.StatsInfo{ + RowCount: childProfile.RowCount, + Cardinality: make([]float64, childLen+1), + } + copy(p.stats.Cardinality, childProfile.Cardinality) + p.stats.Cardinality[childLen] = childProfile.RowCount + return p.stats, nil +} diff --git a/planner/core/stringer.go b/planner/core/stringer.go index 6cd078354469a..62459b66d4524 100644 --- a/planner/core/stringer.go +++ b/planner/core/stringer.go @@ -220,6 +220,10 @@ func toString(in Plan, strs []string, idxs []int) ([]string, []int) { if x.SelectPlan != nil { str = fmt.Sprintf("%s->Insert", ToString(x.SelectPlan)) } + case *LogicalWindow: + str = fmt.Sprintf("Window(%s)", x.WindowFuncDesc.String()) + case *PhysicalWindow: + str = fmt.Sprintf("Window(%s)", x.WindowFuncDesc.String()) default: str = fmt.Sprintf("%T", in) } diff --git a/planner/optimize.go b/planner/optimize.go index 88cb5e557a8d5..0f2cfe3a5ba67 100644 --- a/planner/optimize.go +++ b/planner/optimize.go @@ -44,8 +44,8 @@ func Optimize(ctx sessionctx.Context, node ast.Node, is infoschema.InfoSchema) ( // we need the table information to check privilege, which is collected // into the visitInfo in the logical plan builder. if pm := privilege.GetPrivilegeManager(ctx); pm != nil { - if !plannercore.CheckPrivilege(pm, builder.GetVisitInfo()) { - return nil, errors.New("privilege check fail") + if err := plannercore.CheckPrivilege(pm, builder.GetVisitInfo()); err != nil { + return nil, err } } diff --git a/privilege/privileges/privileges_test.go b/privilege/privileges/privileges_test.go index a827e041241aa..9f412715ed20a 100644 --- a/privilege/privileges/privileges_test.go +++ b/privilege/privileges/privileges_test.go @@ -326,6 +326,21 @@ func (s *testPrivilegeSuite) TestUseDb(c *C) { } +func (s *testPrivilegeSuite) TestSetGlobal(c *C) { + se := newSession(c, s.store, s.dbName) + mustExec(c, se, `CREATE USER setglobal_a@localhost`) + mustExec(c, se, `CREATE USER setglobal_b@localhost`) + mustExec(c, se, `GRANT SUPER ON *.* to setglobal_a@localhost`) + mustExec(c, se, `FLUSH PRIVILEGES`) + + c.Assert(se.Auth(&auth.UserIdentity{Username: "setglobal_a", Hostname: "localhost"}, nil, nil), IsTrue) + mustExec(c, se, `set global innodb_commit_concurrency=16`) + + c.Assert(se.Auth(&auth.UserIdentity{Username: "setglobal_b", Hostname: "localhost"}, nil, nil), IsTrue) + _, err := se.Execute(context.Background(), `set global innodb_commit_concurrency=16`) + c.Assert(strings.Contains(err.Error(), "privilege check fail"), IsTrue) +} + func (s *testPrivilegeSuite) TestAnalyzeTable(c *C) { se := newSession(c, s.store, s.dbName) diff --git a/server/server.go b/server/server.go index 6f52d0ddb2d19..608f387004635 100644 --- a/server/server.go +++ b/server/server.go @@ -32,6 +32,7 @@ import ( "crypto/tls" "crypto/x509" "fmt" + "io" "io/ioutil" "math/rand" "net" @@ -80,6 +81,7 @@ type Server struct { tlsConfig *tls.Config driver IDriver listener net.Listener + socket net.Listener rwlock *sync.RWMutex concurrentLimiter *TokenLimiter clients map[uint32]*clientConn @@ -133,6 +135,39 @@ func (s *Server) isUnixSocket() bool { return s.cfg.Socket != "" } +func (s *Server) forwardUnixSocketToTCP() { + addr := fmt.Sprintf("%s:%d", s.cfg.Host, s.cfg.Port) + for { + if s.listener == nil { + return // server shutdown has started + } + if uconn, err := s.socket.Accept(); err == nil { + log.Infof("server socket 
forwarding from [%s] to [%s]", s.cfg.Socket, addr) + go s.handleForwardedConnection(uconn, addr) + } else { + if s.listener != nil { + log.Errorf("server failed to forward from [%s] to [%s], err: %s", s.cfg.Socket, addr, err) + } + } + } +} + +func (s *Server) handleForwardedConnection(uconn net.Conn, addr string) { + defer terror.Call(uconn.Close) + if tconn, err := net.Dial("tcp", addr); err == nil { + go func() { + if _, err := io.Copy(uconn, tconn); err != nil { + log.Warningf("copy server to socket failed: %s", err) + } + }() + if _, err := io.Copy(tconn, uconn); err != nil { + log.Warningf("socket forward copy failed: %s", err) + } + } else { + log.Warningf("socket forward failed: could not connect to [%s], err: %s", addr, err) + } +} + // NewServer creates a new Server. func NewServer(cfg *config.Config, driver IDriver) (*Server, error) { s := &Server{ @@ -151,15 +186,24 @@ func NewServer(cfg *config.Config, driver IDriver) (*Server, error) { } var err error - if cfg.Socket != "" { - if s.listener, err = net.Listen("unix", cfg.Socket); err == nil { - log.Infof("Server is running MySQL Protocol through Socket [%s]", cfg.Socket) - } - } else { + + if s.cfg.Host != "" && s.cfg.Port != 0 { addr := fmt.Sprintf("%s:%d", s.cfg.Host, s.cfg.Port) if s.listener, err = net.Listen("tcp", addr); err == nil { log.Infof("Server is running MySQL Protocol at [%s]", addr) + if cfg.Socket != "" { + if s.socket, err = net.Listen("unix", s.cfg.Socket); err == nil { + log.Infof("Server redirecting [%s] to [%s]", s.cfg.Socket, addr) + go s.forwardUnixSocketToTCP() + } + } } + } else if cfg.Socket != "" { + if s.listener, err = net.Listen("unix", cfg.Socket); err == nil { + log.Infof("Server is running MySQL Protocol through Socket [%s]", cfg.Socket) + } + } else { + err = errors.New("Server not configured to listen on either -socket or -host and -port") } if cfg.ProxyProtocol.Networks != "" { @@ -292,6 +336,11 @@ func (s *Server) Close() { terror.Log(errors.Trace(err)) s.listener = nil } + if s.socket != nil { + err := s.socket.Close() + terror.Log(errors.Trace(err)) + s.socket = nil + } if s.statusServer != nil { err := s.statusServer.Close() terror.Log(errors.Trace(err)) @@ -419,7 +468,7 @@ func (s *Server) kickIdleConnection() { for _, cc := range conns { err := cc.Close() if err != nil { - log.Error("close connection error:", err) + log.Errorf("close connection error: %s", err) } } } diff --git a/server/tidb_test.go b/server/tidb_test.go index 2a7a18264f15e..200263caf5931 100644 --- a/server/tidb_test.go +++ b/server/tidb_test.go @@ -160,9 +160,34 @@ func (ts *TidbTestSuite) TestMultiStatements(c *C) { runTestMultiStatements(c) } +func (ts *TidbTestSuite) TestSocketForwarding(c *C) { + cfg := config.NewConfig() + cfg.Socket = "/tmp/tidbtest.sock" + cfg.Port = 3999 + os.Remove(cfg.Socket) + cfg.Status.ReportStatus = false + + server, err := NewServer(cfg, ts.tidbdrv) + c.Assert(err, IsNil) + go server.Run() + time.Sleep(time.Millisecond * 100) + defer server.Close() + + runTestRegression(c, func(config *mysql.Config) { + config.User = "root" + config.Net = "unix" + config.Addr = "/tmp/tidbtest.sock" + config.DBName = "test" + config.Strict = true + }, "SocketRegression") +} + func (ts *TidbTestSuite) TestSocket(c *C) { cfg := config.NewConfig() cfg.Socket = "/tmp/tidbtest.sock" + cfg.Port = 0 + os.Remove(cfg.Socket) + cfg.Host = "" cfg.Status.ReportStatus = false server, err := NewServer(cfg, ts.tidbdrv) @@ -178,6 +203,7 @@ func (ts *TidbTestSuite) TestSocket(c *C) { config.DBName = "test" 
config.Strict = true }, "SocketRegression") + } // generateCert generates a private key and a certificate in PEM format based on parameters. diff --git a/session/bootstrap.go b/session/bootstrap.go index bf73987d9cbcd..a0ed79cc58105 100644 --- a/session/bootstrap.go +++ b/session/bootstrap.go @@ -262,6 +262,7 @@ const ( version22 = 22 version23 = 23 version24 = 24 + version25 = 25 ) func checkBootstrapped(s Session) (bool, error) { @@ -416,6 +417,10 @@ func upgrade(s Session) { upgradeToVer24(s) } + if ver < version25 { + upgradeToVer25(s) + } + updateBootstrapVer(s) _, err = s.Execute(context.Background(), "COMMIT") @@ -670,6 +675,13 @@ func upgradeToVer24(s Session) { writeSystemTZ(s) } +// upgradeToVer25 updates tidb_max_chunk_size to new low bound value 32 if previous value is small than 32. +func upgradeToVer25(s Session) { + sql := fmt.Sprintf("UPDATE HIGH_PRIORITY %[1]s.%[2]s SET VARIABLE_VALUE = '%[4]d' WHERE VARIABLE_NAME = '%[3]s' AND VARIABLE_VALUE < %[4]d", + mysql.SystemDB, mysql.GlobalVariablesTable, variable.TiDBMaxChunkSize, variable.DefInitChunkSize) + mustExecute(s, sql) +} + // updateBootstrapVer updates bootstrap version variable in mysql.TiDB table. func updateBootstrapVer(s Session) { // Update bootstrap version. diff --git a/session/session.go b/session/session.go index f8a0a4e999412..f2e8a51baef32 100644 --- a/session/session.go +++ b/session/session.go @@ -514,23 +514,42 @@ func (s *session) isRetryableError(err error) bool { return kv.IsRetryableError(err) || domain.ErrInfoSchemaChanged.Equal(err) } -func (s *session) retry(ctx context.Context, maxCnt uint) error { - connID := s.sessionVars.ConnectionID - if s.sessionVars.TxnCtx.ForUpdate { - return errForUpdateCantRetry.GenWithStackByArgs(connID) +func (s *session) checkTxnAborted(stmt sqlexec.Statement) error { + if s.txn.doNotCommit == nil { + return nil } - s.sessionVars.RetryInfo.Retrying = true + // If the transaction is aborted, the following statements do not need to execute, except `commit` and `rollback`, + // because they are used to finish the aborted transaction. + if _, ok := stmt.(*executor.ExecStmt).StmtNode.(*ast.CommitStmt); ok { + return nil + } + if _, ok := stmt.(*executor.ExecStmt).StmtNode.(*ast.RollbackStmt); ok { + return nil + } + return errors.New("current transaction is aborted, commands ignored until end of transaction block") +} + +func (s *session) retry(ctx context.Context, maxCnt uint) (err error) { var retryCnt uint defer func() { s.sessionVars.RetryInfo.Retrying = false - s.txn.changeToInvalid() // retryCnt only increments on retryable error, so +1 here. 
metrics.SessionRetry.Observe(float64(retryCnt + 1)) s.sessionVars.SetStatusFlag(mysql.ServerStatusInTrans, false) + if err != nil { + s.rollbackOnError(ctx) + } + s.txn.changeToInvalid() }() + connID := s.sessionVars.ConnectionID + s.sessionVars.RetryInfo.Retrying = true + if s.sessionVars.TxnCtx.ForUpdate { + err = errForUpdateCantRetry.GenWithStackByArgs(connID) + return err + } + nh := GetHistory(s) - var err error var schemaVersion int64 sessVars := s.GetSessionVars() orgStartTS := sessVars.TxnCtx.StartTS @@ -825,8 +844,9 @@ func (s *session) SetGlobalSysVar(name, value string) error { if err != nil { return errors.Trace(err) } + name = strings.ToLower(name) sql := fmt.Sprintf(`REPLACE %s.%s VALUES ('%s', '%s');`, - mysql.SystemDB, mysql.GlobalVariablesTable, strings.ToLower(name), sVal) + mysql.SystemDB, mysql.GlobalVariablesTable, name, sVal) _, _, err = s.ExecRestrictedSQL(s, sql) return errors.Trace(err) } @@ -1202,7 +1222,8 @@ func CreateSession4Test(store kv.Storage) (Session, error) { s, err := CreateSession(store) if err == nil { // initialize session variables for test. - s.GetSessionVars().MaxChunkSize = 2 + s.GetSessionVars().InitChunkSize = 2 + s.GetSessionVars().MaxChunkSize = 32 } return s, errors.Trace(err) } @@ -1376,7 +1397,7 @@ func createSessionWithDomain(store kv.Storage, dom *domain.Domain) (*session, er const ( notBootstrapped = 0 - currentBootstrapVersion = 24 + currentBootstrapVersion = 25 ) func getStoreBootstrapVersion(store kv.Storage) int64 { @@ -1447,10 +1468,9 @@ const loadCommonGlobalVarsSQL = "select HIGH_PRIORITY * from mysql.global_variab variable.TiDBHashAggFinalConcurrency + quoteCommaQuote + variable.TiDBBackoffLockFast + quoteCommaQuote + variable.TiDBConstraintCheckInPlace + quoteCommaQuote + - variable.TiDBDDLReorgWorkerCount + quoteCommaQuote + - variable.TiDBDDLReorgBatchSize + quoteCommaQuote + variable.TiDBOptInSubqToJoinAndAgg + quoteCommaQuote + variable.TiDBDistSQLScanConcurrency + quoteCommaQuote + + variable.TiDBInitChunkSize + quoteCommaQuote + variable.TiDBMaxChunkSize + quoteCommaQuote + variable.TiDBEnableCascadesPlanner + quoteCommaQuote + variable.TiDBRetryLimit + quoteCommaQuote + diff --git a/session/session_fail_test.go b/session/session_fail_test.go index f7524c86ea45b..9f160a8fda1ec 100644 --- a/session/session_fail_test.go +++ b/session/session_fail_test.go @@ -22,17 +22,60 @@ import ( ) func (s *testSessionSuite) TestFailStatementCommit(c *C) { - defer gofail.Disable("github.com/pingcap/tidb/session/mockStmtCommitError") tk := testkit.NewTestKitWithInit(c, s.store) tk.MustExec("create table t (id int)") tk.MustExec("begin") tk.MustExec("insert into t values (1)") + gofail.Enable("github.com/pingcap/tidb/session/mockStmtCommitError", `return(true)`) - _, err := tk.Exec("insert into t values (2)") + _, err := tk.Exec("insert into t values (2),(3),(4),(5)") + c.Assert(err, NotNil) + + gofail.Disable("github.com/pingcap/tidb/session/mockStmtCommitError") + + _, err = tk.Exec("select * from t") + c.Assert(err, NotNil) + _, err = tk.Exec("insert into t values (3)") + c.Assert(err, NotNil) + _, err = tk.Exec("insert into t values (4)") + c.Assert(err, NotNil) + _, err = tk.Exec("commit") c.Assert(err, NotNil) + + tk.MustQuery(`select * from t`).Check(testkit.Rows()) + + tk.MustExec("insert into t values (1)") + + tk.MustExec("begin") + tk.MustExec("insert into t values (2)") tk.MustExec("commit") - tk.MustQuery(`select * from t`).Check(testkit.Rows("1")) + + tk.MustExec("begin") + tk.MustExec("insert into t values (3)") + 
tk.MustExec("rollback") + + tk.MustQuery(`select * from t`).Check(testkit.Rows("1", "2")) +} + +func (s *testSessionSuite) TestFailStatementCommitInRetry(c *C) { + tk := testkit.NewTestKitWithInit(c, s.store) + tk.MustExec("create table t (id int)") + + tk.MustExec("begin") + tk.MustExec("insert into t values (1)") + tk.MustExec("insert into t values (2),(3),(4),(5)") + tk.MustExec("insert into t values (6)") + + gofail.Enable("github.com/pingcap/tidb/session/mockCommitError8942", `return(true)`) + gofail.Enable("github.com/pingcap/tidb/session/mockStmtCommitError", `return(true)`) + _, err := tk.Exec("commit") + c.Assert(err, NotNil) + gofail.Disable("github.com/pingcap/tidb/session/mockCommitError8942") + gofail.Disable("github.com/pingcap/tidb/session/mockStmtCommitError") + + tk.MustExec("insert into t values (6)") + tk.MustQuery(`select * from t`).Check(testkit.Rows("6")) } func (s *testSessionSuite) TestGetTSFailDirtyState(c *C) { diff --git a/session/session_test.go b/session/session_test.go index f95bd913cc15d..b7185580bdc28 100644 --- a/session/session_test.go +++ b/session/session_test.go @@ -1871,9 +1871,9 @@ func (s *testSchemaSuite) TestTableReaderChunk(c *C) { s.cluster.SplitTable(s.mvccStore, tbl.Meta().ID, 10) tk.Se.GetSessionVars().DistSQLScanConcurrency = 1 - tk.MustExec("set tidb_max_chunk_size = 2") + tk.MustExec("set tidb_init_chunk_size = 2") defer func() { - tk.MustExec(fmt.Sprintf("set tidb_max_chunk_size = %d", variable.DefMaxChunkSize)) + tk.MustExec(fmt.Sprintf("set tidb_init_chunk_size = %d", variable.DefInitChunkSize)) }() rs, err := tk.Exec("select * from chk") c.Assert(err, IsNil) @@ -1894,7 +1894,8 @@ func (s *testSchemaSuite) TestTableReaderChunk(c *C) { numChunks++ } c.Assert(count, Equals, 100) - c.Assert(numChunks, Equals, 50) + // FIXME: revert this result to new group value after distsql can handle initChunkSize. + c.Assert(numChunks, Equals, 1) rs.Close() } @@ -2456,6 +2457,17 @@ func (s *testSessionSuite) TestUpdatePrivilege(c *C) { // In fact, the privlege check for t1 should be update, and for t2 should be select. _, err = tk1.Exec("update t1,t2 set t1.id = t2.id;") c.Assert(err, IsNil) + + // Fix issue 8911 + tk.MustExec("create database weperk") + tk.MustExec("use weperk") + tk.MustExec("create table tb_wehub_server (id int, active_count int, used_count int)") + tk.MustExec("create user 'weperk'") + tk.MustExec("grant all privileges on weperk.* to 'weperk'@'%'") + c.Assert(tk1.Se.Auth(&auth.UserIdentity{Username: "weperk", Hostname: "%"}, + []byte(""), []byte("")), IsTrue) + tk1.MustExec("use weperk") + tk1.MustExec("update tb_wehub_server a set a.active_count=a.active_count+1,a.used_count=a.used_count+1 where id=1") } func (s *testSessionSuite) TestTxnGoString(c *C) { diff --git a/session/tidb.go b/session/tidb.go index 3df30dc955367..41925440b24e4 100644 --- a/session/tidb.go +++ b/session/tidb.go @@ -192,6 +192,10 @@ func runStmt(ctx context.Context, sctx sessionctx.Context, s sqlexec.Statement) var err error var rs sqlexec.RecordSet se := sctx.(*session) + err = se.checkTxnAborted(s) + if err != nil { + return nil, err + } rs, err = s.Exec(ctx) sessVars := se.GetSessionVars() // All the history should be added here. 
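With runStmt now calling checkTxnAborted, a statement-commit failure inside an explicit transaction marks the transaction with doNotCommit, and every later statement is rejected with "current transaction is aborted, commands ignored until end of transaction block" until the client issues COMMIT or ROLLBACK. A minimal client-side sketch of that behaviour (assumptions: a local TiDB listening on 127.0.0.1:4000, the go-sql-driver/mysql driver, a table t(id int), and some server-side statement-commit failure such as the mockStmtCommitError failpoint used in the tests above):

package main

import (
	"database/sql"
	"fmt"

	_ "github.com/go-sql-driver/mysql"
)

func main() {
	db, err := sql.Open("mysql", "root@tcp(127.0.0.1:4000)/test")
	if err != nil {
		panic(err)
	}
	defer db.Close()

	tx, err := db.Begin()
	if err != nil {
		panic(err)
	}
	// Assume the statement commit of this INSERT fails inside TiDB; the
	// transaction is then flagged with doNotCommit.
	if _, err := tx.Exec("insert into t values (2), (3), (4), (5)"); err != nil {
		fmt.Println("statement failed:", err)
	}
	// Every later statement is rejected by checkTxnAborted until the
	// transaction is finished:
	// "current transaction is aborted, commands ignored until end of transaction block"
	if _, err := tx.Exec("insert into t values (6)"); err != nil {
		fmt.Println("rejected:", err)
	}
	// COMMIT surfaces the stored error and rolls the transaction back;
	// ROLLBACK would simply end it.
	if err := tx.Commit(); err != nil {
		fmt.Println("commit failed:", err)
	}
}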
diff --git a/session/txn.go b/session/txn.go index 4aee11490bff6..a4f28f94732d6 100644 --- a/session/txn.go +++ b/session/txn.go @@ -47,6 +47,10 @@ type TxnState struct { buf kv.MemBuffer mutations map[int64]*binlog.TableMutation dirtyTableOP []dirtyTableOperation + + // If doNotCommit is not nil, Commit() will not commit the transaction. + // doNotCommit flag may be set when StmtCommit fail. + doNotCommit error } func (st *TxnState) init() { @@ -144,6 +148,7 @@ type dirtyTableOperation struct { // Commit overrides the Transaction interface. func (st *TxnState) Commit(ctx context.Context) error { + defer st.reset() if len(st.mutations) != 0 || len(st.dirtyTableOP) != 0 || st.buf.Len() != 0 { log.Errorf("The code should never run here, TxnState=%#v, mutations=%#v, dirtyTableOP=%#v, buf=%#v something must be wrong: %s", st, @@ -151,12 +156,36 @@ func (st *TxnState) Commit(ctx context.Context) error { st.dirtyTableOP, st.buf, debug.Stack()) - st.cleanup() return errors.New("invalid transaction") } + if st.doNotCommit != nil { + if err1 := st.Transaction.Rollback(); err1 != nil { + log.Error(err1) + } + return errors.Trace(st.doNotCommit) + } + + // mockCommitError8942 is used for PR #8942. + // gofail: var mockCommitError8942 bool + // if mockCommitError8942 { + // return kv.ErrRetryable + // } + return errors.Trace(st.Transaction.Commit(ctx)) } +// Rollback overrides the Transaction interface. +func (st *TxnState) Rollback() error { + defer st.reset() + return errors.Trace(st.Transaction.Rollback()) +} + +func (st *TxnState) reset() { + st.doNotCommit = nil + st.cleanup() + st.changeToInvalid() +} + // Get overrides the Transaction interface. func (st *TxnState) Get(k kv.Key) ([]byte, error) { val, err := st.buf.Get(k) @@ -311,17 +340,24 @@ func (s *session) getTxnFuture(ctx context.Context) *txnFuture { func (s *session) StmtCommit() error { defer s.txn.cleanup() st := &s.txn + var count int err := kv.WalkMemBuffer(st.buf, func(k kv.Key, v []byte) error { + // gofail: var mockStmtCommitError bool // if mockStmtCommitError { - // return errors.New("mock stmt commit error") + // count++ // } + if count > 3 { + return errors.New("mock stmt commit error") + } + if len(v) == 0 { return errors.Trace(st.Transaction.Delete(k)) } return errors.Trace(st.Transaction.Set(k, v)) }) if err != nil { + st.doNotCommit = err return errors.Trace(err) } diff --git a/sessionctx/stmtctx/stmtctx.go b/sessionctx/stmtctx/stmtctx.go index 1830f096e46dd..87d856d089c41 100644 --- a/sessionctx/stmtctx/stmtctx.go +++ b/sessionctx/stmtctx/stmtctx.go @@ -47,8 +47,10 @@ type StatementContext struct { // If IsDDLJobInQueue is true, it means the DDL job is in the queue of storage, and it can be handled by the DDL worker. IsDDLJobInQueue bool InInsertStmt bool - InUpdateOrDeleteStmt bool + InUpdateStmt bool + InDeleteStmt bool InSelectStmt bool + InLoadDataStmt bool IgnoreTruncate bool IgnoreZeroInDate bool DupKeyAsWarning bool @@ -379,3 +381,21 @@ func (sc *StatementContext) GetExecDetails() execdetails.ExecDetails { sc.mu.Unlock() return details } + +// ShouldClipToZero indicates whether values less than 0 should be clipped to 0 for unsigned integer types. +// This is the case for `insert`, `update`, `alter table` and `load data infile` statements, when not in strict SQL mode. +// see https://dev.mysql.com/doc/refman/5.7/en/out-of-range-and-overflow.html +func (sc *StatementContext) ShouldClipToZero() bool { + // TODO: Currently altering column of integer to unsigned integer is not supported. 
+ // If it is supported one day, that case should be added here. + return sc.InInsertStmt || sc.InLoadDataStmt +} + +// ShouldIgnoreOverflowError indicates whether we should ignore the error when type conversion overflows, +// so we can leave it for further processing like clipping values less than 0 to 0 for unsigned integer types. +func (sc *StatementContext) ShouldIgnoreOverflowError() bool { + if (sc.InInsertStmt && sc.TruncateAsWarning) || sc.InLoadDataStmt { + return true + } + return false +} diff --git a/sessionctx/variable/session.go b/sessionctx/variable/session.go index 86fd36b5837d2..66ba563f0d422 100644 --- a/sessionctx/variable/session.go +++ b/sessionctx/variable/session.go @@ -381,6 +381,7 @@ func NewSessionVars() *SessionVars { vars.BatchSize = BatchSize{ IndexJoinBatchSize: DefIndexJoinBatchSize, IndexLookupSize: DefIndexLookupSize, + InitChunkSize: DefInitChunkSize, MaxChunkSize: DefMaxChunkSize, DMLBatchSize: DefDMLBatchSize, } @@ -644,6 +645,8 @@ func (s *SessionVars) SetSystemVar(name string, val string) error { return ErrReadOnly case TiDBMaxChunkSize: s.MaxChunkSize = tidbOptPositiveInt32(val, DefMaxChunkSize) + case TiDBInitChunkSize: + s.InitChunkSize = tidbOptPositiveInt32(val, DefInitChunkSize) case TIDBMemQuotaQuery: s.MemQuotaQuery = tidbOptInt64(val, config.GetGlobalConfig().MemQuotaQuery) case TIDBMemQuotaHashJoin: @@ -678,10 +681,6 @@ func (s *SessionVars) SetSystemVar(name string, val string) error { s.OptimizerSelectivityLevel = tidbOptPositiveInt32(val, DefTiDBOptimizerSelectivityLevel) case TiDBEnableTablePartition: s.EnableTablePartition = val - case TiDBDDLReorgWorkerCount: - SetDDLReorgWorkerCounter(int32(tidbOptPositiveInt32(val, DefTiDBDDLReorgWorkerCount))) - case TiDBDDLReorgBatchSize: - SetDDLReorgBatchSize(int32(tidbOptPositiveInt32(val, DefTiDBDDLReorgBatchSize))) case TiDBDDLReorgPriority: s.setDDLReorgPriority(val) case TiDBForcePriority: @@ -695,6 +694,16 @@ func (s *SessionVars) SetSystemVar(name string, val string) error { return nil } +// SetLocalSystemVar sets values of the local variables which in "server" scope. +func SetLocalSystemVar(name string, val string) { + switch name { + case TiDBDDLReorgWorkerCount: + SetDDLReorgWorkerCounter(int32(tidbOptPositiveInt32(val, DefTiDBDDLReorgWorkerCount))) + case TiDBDDLReorgBatchSize: + SetDDLReorgBatchSize(int32(tidbOptPositiveInt32(val, DefTiDBDDLReorgBatchSize))) + } +} + // special session variables. const ( SQLModeVar = "sql_mode" @@ -784,6 +793,9 @@ type BatchSize struct { // IndexLookupSize is the number of handles for an index lookup task in index double read executor. IndexLookupSize int + // InitChunkSize defines init row count of a Chunk during query execution. + InitChunkSize int + // MaxChunkSize defines max row count of a Chunk during query execution. 
MaxChunkSize int } diff --git a/sessionctx/variable/sysvar.go b/sessionctx/variable/sysvar.go index 1e7437931ae78..1814985a7b084 100644 --- a/sessionctx/variable/sysvar.go +++ b/sessionctx/variable/sysvar.go @@ -644,6 +644,7 @@ var defaultSysVars = []*SysVar{ {ScopeSession, TiDBDMLBatchSize, strconv.Itoa(DefDMLBatchSize)}, {ScopeSession, TiDBCurrentTS, strconv.Itoa(DefCurretTS)}, {ScopeGlobal | ScopeSession, TiDBMaxChunkSize, strconv.Itoa(DefMaxChunkSize)}, + {ScopeGlobal | ScopeSession, TiDBInitChunkSize, strconv.Itoa(DefInitChunkSize)}, {ScopeGlobal | ScopeSession, TiDBEnableCascadesPlanner, "0"}, {ScopeSession, TIDBMemQuotaQuery, strconv.FormatInt(config.GetGlobalConfig().MemQuotaQuery, 10)}, {ScopeSession, TIDBMemQuotaHashJoin, strconv.FormatInt(DefTiDBMemQuotaHashJoin, 10)}, @@ -671,8 +672,8 @@ var defaultSysVars = []*SysVar{ {ScopeSession, TiDBSlowLogThreshold, strconv.Itoa(logutil.DefaultSlowThreshold)}, {ScopeSession, TiDBQueryLogMaxLen, strconv.Itoa(logutil.DefaultQueryLogMaxLen)}, {ScopeSession, TiDBConfig, ""}, - {ScopeGlobal | ScopeSession, TiDBDDLReorgWorkerCount, strconv.Itoa(DefTiDBDDLReorgWorkerCount)}, - {ScopeGlobal | ScopeSession, TiDBDDLReorgBatchSize, strconv.Itoa(DefTiDBDDLReorgBatchSize)}, + {ScopeGlobal, TiDBDDLReorgWorkerCount, strconv.Itoa(DefTiDBDDLReorgWorkerCount)}, + {ScopeGlobal, TiDBDDLReorgBatchSize, strconv.Itoa(DefTiDBDDLReorgBatchSize)}, {ScopeSession, TiDBDDLReorgPriority, "PRIORITY_LOW"}, {ScopeSession, TiDBForcePriority, mysql.Priority2Str[DefTiDBForcePriority]}, {ScopeSession, TiDBEnableRadixJoin, boolToIntStr(DefTiDBUseRadixJoin)}, diff --git a/sessionctx/variable/tidb_vars.go b/sessionctx/variable/tidb_vars.go index 4b04b48e1707d..684e2883741f4 100644 --- a/sessionctx/variable/tidb_vars.go +++ b/sessionctx/variable/tidb_vars.go @@ -172,9 +172,12 @@ const ( // when we need to keep the data output order the same as the order of index data. TiDBIndexSerialScanConcurrency = "tidb_index_serial_scan_concurrency" - // tidb_max_chunk_capacity is used to control the max chunk size during query execution. + // TiDBMaxChunkSize is used to control the max chunk size during query execution. TiDBMaxChunkSize = "tidb_max_chunk_size" + // TiDBInitChunkSize is used to control the init chunk size during query execution. + TiDBInitChunkSize = "tidb_init_chunk_size" + // tidb_enable_cascades_planner is used to control whether to enable the cascades planner. TiDBEnableCascadesPlanner = "tidb_enable_cascades_planner" @@ -249,7 +252,8 @@ const ( DefBatchDelete = false DefBatchCommit = false DefCurretTS = 0 - DefMaxChunkSize = 32 + DefInitChunkSize = 32 + DefMaxChunkSize = 1024 DefDMLBatchSize = 20000 DefMaxPreparedStmtCount = -1 DefWaitTimeout = 28800 diff --git a/sessionctx/variable/varsutil.go b/sessionctx/variable/varsutil.go index 17a687f612cc1..9c20d2e4d261a 100644 --- a/sessionctx/variable/varsutil.go +++ b/sessionctx/variable/varsutil.go @@ -229,6 +229,13 @@ func checkInt64SystemVar(name, value string, min, max int64, vars *SessionVars) return value, nil } +const ( + // initChunkSizeUpperBound indicates upper bound value of tidb_init_chunk_size. + initChunkSizeUpperBound = 32 + // maxChunkSizeLowerBound indicates lower bound value of tidb_max_chunk_size. + maxChunkSizeLowerBound = 32 +) + // ValidateSetSystemVar checks if system variable satisfies specific restriction. 
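+// For the chunk-size variables the bounds enforced below are:
+// tidb_init_chunk_size must be a positive integer no larger than
+// initChunkSizeUpperBound (32), and tidb_max_chunk_size must be no smaller
+// than maxChunkSizeLowerBound (32).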
func ValidateSetSystemVar(vars *SessionVars, name string, value string) (string, error) { if strings.EqualFold(value, "DEFAULT") { @@ -356,7 +363,7 @@ func ValidateSetSystemVar(vars *SessionVars, name string, value string) (string, TiDBHashAggFinalConcurrency, TiDBDistSQLScanConcurrency, TiDBIndexSerialScanConcurrency, TiDBDDLReorgWorkerCount, - TiDBBackoffLockFast, TiDBMaxChunkSize, + TiDBBackoffLockFast, TiDBDMLBatchSize, TiDBOptimizerSelectivityLevel: v, err := strconv.Atoi(value) if err != nil { @@ -400,6 +407,27 @@ func ValidateSetSystemVar(vars *SessionVars, name string, value string) (string, return "", ErrUnsupportedValueForVar.GenWithStackByArgs(name, value) } return upVal, nil + case TiDBInitChunkSize: + v, err := strconv.Atoi(value) + if err != nil { + return value, ErrWrongTypeForVar.GenWithStackByArgs(name) + } + if v <= 0 { + return value, ErrWrongValueForVar.GenWithStackByArgs(name, value) + } + if v > initChunkSizeUpperBound { + return value, errors.Errorf("tidb_init_chunk_size(%d) cannot be bigger than %d", v, initChunkSizeUpperBound) + } + return value, nil + case TiDBMaxChunkSize: + v, err := strconv.Atoi(value) + if err != nil { + return value, ErrWrongTypeForVar.GenWithStackByArgs(name) + } + if v < maxChunkSizeLowerBound { + return value, errors.Errorf("tidb_max_chunk_size(%d) cannot be smaller than %d", v, maxChunkSizeLowerBound) + } + return value, nil } return value, nil } diff --git a/sessionctx/variable/varsutil_test.go b/sessionctx/variable/varsutil_test.go index 8fcf3c4d15f92..b54dd8d392b96 100644 --- a/sessionctx/variable/varsutil_test.go +++ b/sessionctx/variable/varsutil_test.go @@ -185,9 +185,12 @@ func (s *testVarsutilSuite) TestVarsutil(c *C) { SetSessionSystemVar(v, TiDBBatchInsert, types.NewStringDatum("1")) c.Assert(v.BatchInsert, IsTrue) - c.Assert(v.MaxChunkSize, Equals, 32) - SetSessionSystemVar(v, TiDBMaxChunkSize, types.NewStringDatum("2")) - c.Assert(v.MaxChunkSize, Equals, 2) + c.Assert(v.InitChunkSize, Equals, 32) + c.Assert(v.MaxChunkSize, Equals, 1024) + err = SetSessionSystemVar(v, TiDBMaxChunkSize, types.NewStringDatum("2")) + c.Assert(err, NotNil) + err = SetSessionSystemVar(v, TiDBInitChunkSize, types.NewStringDatum("1024")) + c.Assert(err, NotNil) // Test case for TiDBConfig session variable. 
err = SetSessionSystemVar(v, TiDBConfig, types.NewStringDatum("abc")) @@ -213,10 +216,6 @@ func (s *testVarsutilSuite) TestVarsutil(c *C) { SetSessionSystemVar(v, TiDBOptimizerSelectivityLevel, types.NewIntDatum(1)) c.Assert(v.OptimizerSelectivityLevel, Equals, 1) - c.Assert(GetDDLReorgWorkerCounter(), Equals, int32(DefTiDBDDLReorgWorkerCount)) - SetSessionSystemVar(v, TiDBDDLReorgWorkerCount, types.NewIntDatum(1)) - c.Assert(GetDDLReorgWorkerCounter(), Equals, int32(1)) - err = SetSessionSystemVar(v, TiDBDDLReorgWorkerCount, types.NewIntDatum(-1)) c.Assert(terror.ErrorEqual(err, ErrWrongValueForVar), IsTrue) diff --git a/statistics/feedback.go b/statistics/feedback.go index e21dedbe23b03..d960c8ac6f9d0 100644 --- a/statistics/feedback.go +++ b/statistics/feedback.go @@ -294,8 +294,7 @@ func buildBucketFeedback(h *Histogram, feedback *QueryFeedback) (map[int]*Bucket } total := 0 sc := &stmtctx.StatementContext{TimeZone: time.UTC} - kind := feedback.feedback[0].lower.Kind() - min, max := getMinValue(kind, h.Tp), getMaxValue(kind, h.Tp) + min, max := getMinValue(h.tp), getMaxValue(h.tp) for _, fb := range feedback.feedback { skip, err := fb.adjustFeedbackBoundaries(sc, &min, &max) if err != nil { @@ -723,11 +722,18 @@ func decodeFeedbackForIndex(q *QueryFeedback, pb *queryFeedback, c *CMSketch) { } } -func decodeFeedbackForPK(q *QueryFeedback, pb *queryFeedback) { +func decodeFeedbackForPK(q *QueryFeedback, pb *queryFeedback, isUnsigned bool) { q.tp = pkType // decode feedback for primary key for i := 0; i < len(pb.IntRanges); i += 2 { - lower, upper := types.NewIntDatum(pb.IntRanges[i]), types.NewIntDatum(pb.IntRanges[i+1]) + var lower, upper types.Datum + if isUnsigned { + lower.SetUint64(uint64(pb.IntRanges[i])) + upper.SetUint64(uint64(pb.IntRanges[i+1])) + } else { + lower.SetInt64(pb.IntRanges[i]) + upper.SetInt64(pb.IntRanges[i+1]) + } q.feedback = append(q.feedback, feedback{&lower, &upper, pb.Counts[i/2], 0}) } } @@ -748,7 +754,7 @@ func decodeFeedbackForColumn(q *QueryFeedback, pb *queryFeedback) error { return nil } -func decodeFeedback(val []byte, q *QueryFeedback, c *CMSketch) error { +func decodeFeedback(val []byte, q *QueryFeedback, c *CMSketch, isUnsigned bool) error { buf := bytes.NewBuffer(val) dec := gob.NewDecoder(buf) pb := &queryFeedback{} @@ -759,7 +765,7 @@ func decodeFeedback(val []byte, q *QueryFeedback, c *CMSketch) error { if len(pb.IndexRanges) > 0 || len(pb.HashValues) > 0 { decodeFeedbackForIndex(q, pb, c) } else if len(pb.IntRanges) > 0 { - decodeFeedbackForPK(q, pb) + decodeFeedbackForPK(q, pb, isUnsigned) } else { err := decodeFeedbackForColumn(q, pb) if err != nil { @@ -1072,15 +1078,14 @@ func (q *QueryFeedback) dumpRangeFeedback(sc *stmtctx.StatementContext, h *Handl ran.LowVal[0].SetBytes(lower) ran.HighVal[0].SetBytes(upper) } else { - k := q.hist.GetLower(0).Kind() - if !supportColumnType(k) { + if !supportColumnType(q.hist.tp) { return nil } if ran.LowVal[0].Kind() == types.KindMinNotNull { - ran.LowVal[0] = getMinValue(k, q.hist.Tp) + ran.LowVal[0] = getMinValue(q.hist.tp) } if ran.HighVal[0].Kind() == types.KindMaxValue { - ran.HighVal[0] = getMaxValue(k, q.hist.Tp) + ran.HighVal[0] = getMaxValue(q.hist.tp) } } ranges := q.hist.SplitRange(sc, []*ranger.Range{ran}, q.tp == indexType) @@ -1127,27 +1132,30 @@ func setNextValue(d *types.Datum) { } // supportColumnType checks if the type of the column can be updated by feedback. 
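+// The check is now keyed on the column's FieldType rather than the datum
+// kind; getMaxValue and getMinValue below likewise take the FieldType and
+// pick signed or unsigned integer bounds from the type's unsigned flag.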
-func supportColumnType(k byte) bool { - switch k { - case types.KindInt64, types.KindUint64, types.KindFloat32, types.KindFloat64, types.KindString, types.KindBytes, - types.KindMysqlDecimal, types.KindMysqlDuration, types.KindMysqlTime: +func supportColumnType(ft *types.FieldType) bool { + switch ft.Tp { + case mysql.TypeTiny, mysql.TypeShort, mysql.TypeInt24, mysql.TypeLong, mysql.TypeLonglong, mysql.TypeFloat, + mysql.TypeDouble, mysql.TypeString, mysql.TypeVarString, mysql.TypeVarchar, mysql.TypeBlob, mysql.TypeTinyBlob, mysql.TypeMediumBlob, mysql.TypeLongBlob, + mysql.TypeNewDecimal, mysql.TypeDuration, mysql.TypeDate, mysql.TypeDatetime, mysql.TypeTimestamp: return true default: return false } } -func getMaxValue(k byte, ft *types.FieldType) (max types.Datum) { - switch k { - case types.KindInt64: - max.SetInt64(types.SignedUpperBound[ft.Tp]) - case types.KindUint64: - max.SetUint64(types.UnsignedUpperBound[ft.Tp]) - case types.KindFloat32: +func getMaxValue(ft *types.FieldType) (max types.Datum) { + switch ft.Tp { + case mysql.TypeTiny, mysql.TypeShort, mysql.TypeInt24, mysql.TypeLong, mysql.TypeLonglong: + if mysql.HasUnsignedFlag(ft.Flag) { + max.SetUint64(types.UnsignedUpperBound[ft.Tp]) + } else { + max.SetInt64(types.SignedUpperBound[ft.Tp]) + } + case mysql.TypeFloat: max.SetFloat32(float32(types.GetMaxFloat(ft.Flen, ft.Decimal))) - case types.KindFloat64: + case mysql.TypeDouble: max.SetFloat64(types.GetMaxFloat(ft.Flen, ft.Decimal)) - case types.KindString, types.KindBytes: + case mysql.TypeString, mysql.TypeVarString, mysql.TypeVarchar, mysql.TypeBlob, mysql.TypeTinyBlob, mysql.TypeMediumBlob, mysql.TypeLongBlob: val := types.MaxValueDatum() bytes, err := codec.EncodeKey(nil, nil, val) // should not happen @@ -1155,11 +1163,11 @@ func getMaxValue(k byte, ft *types.FieldType) (max types.Datum) { log.Error(err) } max.SetBytes(bytes) - case types.KindMysqlDecimal: + case mysql.TypeNewDecimal: max.SetMysqlDecimal(types.NewMaxOrMinDec(false, ft.Flen, ft.Decimal)) - case types.KindMysqlDuration: + case mysql.TypeDuration: max.SetMysqlDuration(types.Duration{Duration: math.MaxInt64}) - case types.KindMysqlTime: + case mysql.TypeDate, mysql.TypeDatetime, mysql.TypeTimestamp: if ft.Tp == mysql.TypeDate || ft.Tp == mysql.TypeDatetime { max.SetMysqlTime(types.Time{Time: types.MaxDatetime, Type: ft.Tp}) } else { @@ -1169,17 +1177,19 @@ func getMaxValue(k byte, ft *types.FieldType) (max types.Datum) { return } -func getMinValue(k byte, ft *types.FieldType) (min types.Datum) { - switch k { - case types.KindInt64: - min.SetInt64(types.SignedLowerBound[ft.Tp]) - case types.KindUint64: - min.SetUint64(0) - case types.KindFloat32: +func getMinValue(ft *types.FieldType) (min types.Datum) { + switch ft.Tp { + case mysql.TypeTiny, mysql.TypeShort, mysql.TypeInt24, mysql.TypeLong, mysql.TypeLonglong: + if mysql.HasUnsignedFlag(ft.Flag) { + min.SetUint64(0) + } else { + min.SetInt64(types.SignedLowerBound[ft.Tp]) + } + case mysql.TypeFloat: min.SetFloat32(float32(-types.GetMaxFloat(ft.Flen, ft.Decimal))) - case types.KindFloat64: + case mysql.TypeDouble: min.SetFloat64(-types.GetMaxFloat(ft.Flen, ft.Decimal)) - case types.KindString, types.KindBytes: + case mysql.TypeString, mysql.TypeVarString, mysql.TypeVarchar, mysql.TypeBlob, mysql.TypeTinyBlob, mysql.TypeMediumBlob, mysql.TypeLongBlob: val := types.MinNotNullDatum() bytes, err := codec.EncodeKey(nil, nil, val) // should not happen @@ -1187,11 +1197,11 @@ func getMinValue(k byte, ft *types.FieldType) (min types.Datum) { log.Error(err) } 
min.SetBytes(bytes) - case types.KindMysqlDecimal: + case mysql.TypeNewDecimal: min.SetMysqlDecimal(types.NewMaxOrMinDec(true, ft.Flen, ft.Decimal)) - case types.KindMysqlDuration: + case mysql.TypeDuration: min.SetMysqlDuration(types.Duration{Duration: math.MinInt64}) - case types.KindMysqlTime: + case mysql.TypeDate, mysql.TypeDatetime, mysql.TypeTimestamp: if ft.Tp == mysql.TypeDate || ft.Tp == mysql.TypeDatetime { min.SetMysqlTime(types.Time{Time: types.MinDatetime, Type: ft.Tp}) } else { diff --git a/statistics/feedback_test.go b/statistics/feedback_test.go index 7ea226febf840..5c71aa0013da2 100644 --- a/statistics/feedback_test.go +++ b/statistics/feedback_test.go @@ -221,7 +221,7 @@ func (s *testFeedbackSuite) TestFeedbackEncoding(c *C) { val, err := encodeFeedback(q) c.Assert(err, IsNil) rq := &QueryFeedback{} - c.Assert(decodeFeedback(val, rq, nil), IsNil) + c.Assert(decodeFeedback(val, rq, nil, false), IsNil) for _, fb := range rq.feedback { fb.lower.SetBytes(codec.EncodeInt(nil, fb.lower.GetInt64())) fb.upper.SetBytes(codec.EncodeInt(nil, fb.upper.GetInt64())) @@ -236,7 +236,7 @@ func (s *testFeedbackSuite) TestFeedbackEncoding(c *C) { c.Assert(err, IsNil) rq = &QueryFeedback{} cms := NewCMSketch(4, 4) - c.Assert(decodeFeedback(val, rq, cms), IsNil) + c.Assert(decodeFeedback(val, rq, cms, false), IsNil) c.Assert(cms.QueryBytes(codec.EncodeInt(nil, 0)), Equals, uint32(1)) q.feedback = q.feedback[:1] c.Assert(q.Equal(rq), IsTrue) diff --git a/statistics/handle.go b/statistics/handle.go index 8484f448ea2ed..2a65288e9991d 100644 --- a/statistics/handle.go +++ b/statistics/handle.go @@ -72,7 +72,8 @@ func (h *Handle) Clear() { <-h.ddlEventCh } h.feedback = h.feedback[:0] - h.mu.ctx.GetSessionVars().MaxChunkSize = 1 + h.mu.ctx.GetSessionVars().InitChunkSize = 1 + h.mu.ctx.GetSessionVars().MaxChunkSize = 32 h.listHead = &SessionStatsCollector{mapper: make(tableDeltaMap), rateMap: make(errorRateDeltaMap)} h.globalMap = make(tableDeltaMap) h.mu.rateMap = make(errorRateDeltaMap) diff --git a/statistics/update.go b/statistics/update.go index 143a88a546d77..b65c4cdbb7734 100644 --- a/statistics/update.go +++ b/statistics/update.go @@ -24,6 +24,7 @@ import ( "github.com/pingcap/errors" "github.com/pingcap/parser/model" + "github.com/pingcap/parser/mysql" "github.com/pingcap/tidb/infoschema" "github.com/pingcap/tidb/metrics" "github.com/pingcap/tidb/sessionctx/variable" @@ -561,7 +562,7 @@ func (h *Handle) handleSingleHistogramUpdate(is infoschema.InfoSchema, rows []ch } q := &QueryFeedback{} for _, row := range rows { - err1 := decodeFeedback(row.GetBytes(3), q, cms) + err1 := decodeFeedback(row.GetBytes(3), q, cms, mysql.HasUnsignedFlag(hist.tp.Flag)) if err1 != nil { log.Debugf("decode feedback failed, err: %v", errors.ErrorStack(err)) } diff --git a/statistics/update_test.go b/statistics/update_test.go index 0f152f74fe029..0684d341e4a40 100644 --- a/statistics/update_test.go +++ b/statistics/update_test.go @@ -1362,3 +1362,61 @@ func (s *testStatsSuite) TestFeedbackRanges(c *C) { c.Assert(tbl.Columns[t.colID].ToString(0), Equals, tests[i].hist) } } + +func (s *testStatsSuite) TestUnsignedFeedbackRanges(c *C) { + defer cleanEnv(c, s.store, s.do) + testKit := testkit.NewTestKit(c, s.store) + h := s.do.StatsHandle() + oriProbability := statistics.FeedbackProbability + oriNumber := statistics.MaxNumberOfRanges + defer func() { + statistics.FeedbackProbability = oriProbability + statistics.MaxNumberOfRanges = oriNumber + }() + statistics.FeedbackProbability = 1 + + testKit.MustExec("use test") 
+ testKit.MustExec("create table t (a tinyint unsigned, primary key(a))") + for i := 0; i < 20; i++ { + testKit.MustExec(fmt.Sprintf("insert into t values (%d)", i)) + } + h.HandleDDLEvent(<-h.DDLEventCh()) + c.Assert(h.DumpStatsDeltaToKV(statistics.DumpAll), IsNil) + testKit.MustExec("analyze table t with 3 buckets") + for i := 30; i < 40; i++ { + testKit.MustExec(fmt.Sprintf("insert into t values (%d)", i)) + } + c.Assert(h.DumpStatsDeltaToKV(statistics.DumpAll), IsNil) + tests := []struct { + sql string + hist string + }{ + { + sql: "select * from t where a <= 50", + hist: "column:1 ndv:30 totColSize:0\n" + + "num: 8 lower_bound: 0 upper_bound: 7 repeats: 0\n" + + "num: 8 lower_bound: 8 upper_bound: 15 repeats: 0\n" + + "num: 14 lower_bound: 16 upper_bound: 50 repeats: 0", + }, + { + sql: "select count(*) from t", + hist: "column:1 ndv:30 totColSize:0\n" + + "num: 8 lower_bound: 0 upper_bound: 7 repeats: 0\n" + + "num: 8 lower_bound: 8 upper_bound: 15 repeats: 0\n" + + "num: 14 lower_bound: 16 upper_bound: 255 repeats: 0", + }, + } + is := s.do.InfoSchema() + table, err := is.TableByName(model.NewCIStr("test"), model.NewCIStr("t")) + for i, t := range tests { + testKit.MustQuery(t.sql) + c.Assert(h.DumpStatsDeltaToKV(statistics.DumpAll), IsNil) + c.Assert(h.DumpStatsFeedbackToKV(), IsNil) + c.Assert(h.HandleUpdateStats(s.do.InfoSchema()), IsNil) + c.Assert(err, IsNil) + h.Update(is) + tblInfo := table.Meta() + tbl := h.GetTableStats(tblInfo) + c.Assert(tbl.Columns[1].ToString(0), Equals, tests[i].hist) + } +} diff --git a/store/tikv/backoff.go b/store/tikv/backoff.go index a24dc39aabe33..ca99dfa7785e9 100644 --- a/store/tikv/backoff.go +++ b/store/tikv/backoff.go @@ -87,7 +87,7 @@ const ( boTiKVRPC backoffType = iota BoTxnLock boTxnLockFast - boPDRPC + BoPDRPC BoRegionMiss BoUpdateLeader boServerBusy @@ -104,7 +104,7 @@ func (t backoffType) createFn(vars *kv.Variables) func(context.Context) int { return NewBackoffFn(200, 3000, EqualJitter) case boTxnLockFast: return NewBackoffFn(vars.BackoffLockFast, 3000, EqualJitter) - case boPDRPC: + case BoPDRPC: return NewBackoffFn(500, 3000, EqualJitter) case BoRegionMiss: // change base time to 2ms, because it may recover soon. @@ -125,7 +125,7 @@ func (t backoffType) String() string { return "txnLock" case boTxnLockFast: return "txnLockFast" - case boPDRPC: + case BoPDRPC: return "pdRPC" case BoRegionMiss: return "regionMiss" @@ -143,7 +143,7 @@ func (t backoffType) TError() error { return ErrTiKVServerTimeout case BoTxnLock, boTxnLockFast: return ErrResolveLockTimeout - case boPDRPC: + case BoPDRPC: return ErrPDServerTimeout.GenWithStackByArgs(txnRetryableMark) case BoRegionMiss, BoUpdateLeader: return ErrRegionUnavailable diff --git a/store/tikv/gcworker/gc_worker.go b/store/tikv/gcworker/gc_worker.go index 4561b586f176f..bc31122fb1984 100644 --- a/store/tikv/gcworker/gc_worker.go +++ b/store/tikv/gcworker/gc_worker.go @@ -122,6 +122,11 @@ const ( gcEnableValue = "true" gcDisableValue = "false" gcDefaultEnableValue = true + + gcModeKey = "tikv_gc_mode" + gcModeCentral = "central" + gcModeDistributed = "distributed" + gcModeDefault = gcModeDistributed ) var gcSafePointCacheInterval = tikv.GcSafePointCacheInterval @@ -136,6 +141,7 @@ var gcVariableComments = map[string]string{ gcSafePointKey: "All versions after safe point can be accessed. 
(DO NOT EDIT)", gcConcurrencyKey: "How many go routines used to do GC parallel, [1, 128], default 2", gcEnableKey: "Current GC enable status", + gcModeKey: "Mode of GC, \"central\" or \"distributed\"", } func (w *GCWorker) start(ctx context.Context, wg *sync.WaitGroup) { @@ -400,14 +406,34 @@ func (w *GCWorker) runGCJob(ctx context.Context, safePoint uint64) { w.done <- errors.Trace(err) return } - err = w.doGC(ctx, safePoint) + + useDistributedGC, err := w.checkUseDistributedGC() if err != nil { - log.Errorf("[gc worker] %s do GC returns an error %v", w.uuid, errors.ErrorStack(err)) - w.gcIsRunning = false - metrics.GCJobFailureCounter.WithLabelValues("gc").Inc() - w.done <- errors.Trace(err) - return + log.Errorf("[gc worker] %s failed to load gc mode, fall back to central mode. err: %v", w.uuid, errors.ErrorStack(err)) + metrics.GCJobFailureCounter.WithLabelValues("check_gc_mode").Inc() + useDistributedGC = false + } + + if useDistributedGC { + err = w.uploadSafePointToPD(ctx, safePoint) + if err != nil { + log.Errorf("[gc worker] %s failed to upload safe point to PD: %v", w.uuid, errors.ErrorStack(err)) + w.gcIsRunning = false + metrics.GCJobFailureCounter.WithLabelValues("upload_safe_point").Inc() + w.done <- errors.Trace(err) + return + } + } else { + err = w.doGC(ctx, safePoint) + if err != nil { + log.Errorf("[gc worker] %s do GC returns an error %v", w.uuid, errors.ErrorStack(err)) + w.gcIsRunning = false + metrics.GCJobFailureCounter.WithLabelValues("gc").Inc() + w.done <- errors.Trace(err) + return + } } + w.done <- nil } @@ -484,7 +510,7 @@ func (w *GCWorker) sendUnsafeDestroyRangeRequest(ctx context.Context, startKey [ // Get all stores every time deleting a region. So the store list is less probably to be stale. stores, err := w.pdClient.GetAllStores(ctx) if err != nil { - log.Errorf("[gc worker] %s delete ranges: got an error while trying to get store list from pd: %v", w.uuid, errors.ErrorStack(err)) + log.Errorf("[gc worker] %s delete ranges: got an error while trying to get store list from PD: %v", w.uuid, errors.ErrorStack(err)) return errors.Trace(err) } @@ -550,6 +576,28 @@ func (w *GCWorker) loadGCConcurrencyWithDefault() (int, error) { return jobConcurrency, nil } +func (w *GCWorker) checkUseDistributedGC() (bool, error) { + str, err := w.loadValueFromSysTable(gcModeKey) + if err != nil { + return false, errors.Trace(err) + } + if str == "" { + err = w.saveValueToSysTable(gcModeKey, gcModeDefault) + if err != nil { + return false, errors.Trace(err) + } + str = gcModeDefault + } + if strings.EqualFold(str, gcModeDistributed) { + return true, nil + } + if strings.EqualFold(str, gcModeCentral) { + return false, nil + } + log.Warnf("[gc worker] \"%v\" is not a valid gc mode. 
distributed mode will be used.", str) + return true, nil +} + func (w *GCWorker) resolveLocks(ctx context.Context, safePoint uint64) error { metrics.GCWorkerCounter.WithLabelValues("resolve_locks").Inc() @@ -642,6 +690,34 @@ func (w *GCWorker) resolveLocks(ctx context.Context, safePoint uint64) error { return nil } +func (w *GCWorker) uploadSafePointToPD(ctx context.Context, safePoint uint64) error { + var newSafePoint uint64 + var err error + + bo := tikv.NewBackoffer(ctx, tikv.GcOneRegionMaxBackoff) + for { + newSafePoint, err = w.pdClient.UpdateGCSafePoint(ctx, safePoint) + if err != nil { + if errors.Cause(err) == context.Canceled { + return errors.Trace(err) + } + err = bo.Backoff(tikv.BoPDRPC, errors.Errorf("failed to upload safe point to PD, err: %v", err)) + if err != nil { + return errors.Trace(err) + } + continue + } + break + } + + if newSafePoint != safePoint { + log.Warnf("[gc worker] %s, PD rejected our safe point %v but is using another safe point %v", w.uuid, safePoint, newSafePoint) + return errors.Errorf("PD rejected our safe point %v but is using another safe point %v", safePoint, newSafePoint) + } + log.Infof("[gc worker] %s sent safe point %v to PD", w.uuid, safePoint) + return nil +} + type gcTask struct { startKey []byte endKey []byte diff --git a/store/tikv/gcworker/gc_worker_test.go b/store/tikv/gcworker/gc_worker_test.go index 1bf14628ba45b..579a8b9b456f5 100644 --- a/store/tikv/gcworker/gc_worker_test.go +++ b/store/tikv/gcworker/gc_worker_test.go @@ -214,3 +214,28 @@ func (s *testGCWorkerSuite) TestDoGC(c *C) { err = s.gcWorker.doGC(ctx, 20) c.Assert(err, IsNil) } + +func (s *testGCWorkerSuite) TestCheckGCMode(c *C) { + useDistributedGC, err := s.gcWorker.checkUseDistributedGC() + c.Assert(err, IsNil) + c.Assert(useDistributedGC, Equals, true) + // Now the row must be set to the default value. 
+ str, err := s.gcWorker.loadValueFromSysTable(gcModeKey) + c.Assert(err, IsNil) + c.Assert(str, Equals, gcModeDistributed) + + s.gcWorker.saveValueToSysTable(gcModeKey, gcModeCentral) + useDistributedGC, err = s.gcWorker.checkUseDistributedGC() + c.Assert(err, IsNil) + c.Assert(useDistributedGC, Equals, false) + + s.gcWorker.saveValueToSysTable(gcModeKey, gcModeDistributed) + useDistributedGC, err = s.gcWorker.checkUseDistributedGC() + c.Assert(err, IsNil) + c.Assert(useDistributedGC, Equals, true) + + s.gcWorker.saveValueToSysTable(gcModeKey, "invalid_mode") + useDistributedGC, err = s.gcWorker.checkUseDistributedGC() + c.Assert(err, IsNil) + c.Assert(useDistributedGC, Equals, true) +} diff --git a/store/tikv/kv.go b/store/tikv/kv.go index beb7451718448..a816826f696d4 100644 --- a/store/tikv/kv.go +++ b/store/tikv/kv.go @@ -317,7 +317,7 @@ func (s *tikvStore) getTimestampWithRetry(bo *Backoffer) (uint64, error) { if err == nil { return startTS, nil } - err = bo.Backoff(boPDRPC, errors.Errorf("get timestamp failed: %v", err)) + err = bo.Backoff(BoPDRPC, errors.Errorf("get timestamp failed: %v", err)) if err != nil { return 0, errors.Trace(err) } diff --git a/store/tikv/region_cache.go b/store/tikv/region_cache.go index 919f3aa6d3472..18dfcf3589d80 100644 --- a/store/tikv/region_cache.go +++ b/store/tikv/region_cache.go @@ -330,7 +330,7 @@ func (c *RegionCache) loadRegion(bo *Backoffer, key []byte) (*Region, error) { var backoffErr error for { if backoffErr != nil { - err := bo.Backoff(boPDRPC, backoffErr) + err := bo.Backoff(BoPDRPC, backoffErr) if err != nil { return nil, errors.Trace(err) } @@ -364,7 +364,7 @@ func (c *RegionCache) loadRegionByID(bo *Backoffer, regionID uint64) (*Region, e var backoffErr error for { if backoffErr != nil { - err := bo.Backoff(boPDRPC, backoffErr) + err := bo.Backoff(BoPDRPC, backoffErr) if err != nil { return nil, errors.Trace(err) } @@ -437,7 +437,7 @@ func (c *RegionCache) loadStoreAddr(bo *Backoffer, id uint64) (string, error) { return "", errors.Trace(err) } err = errors.Errorf("loadStore from PD failed, id: %d, err: %v", id, err) - if err = bo.Backoff(boPDRPC, err); err != nil { + if err = bo.Backoff(BoPDRPC, err); err != nil { return "", errors.Trace(err) } continue diff --git a/table/column.go b/table/column.go index d6e2d6affd832..2c578417f3c34 100644 --- a/table/column.go +++ b/table/column.go @@ -390,7 +390,7 @@ func getColDefaultValueFromNil(ctx sessionctx.Context, col *model.ColumnInfo) (t sc.AppendWarning(ErrColumnCantNull.GenWithStackByArgs(col.Name)) return GetZeroValue(col), nil } - return types.Datum{}, ErrNoDefaultValue.GenWithStack("Field '%s' doesn't have a default value", col.Name) + return types.Datum{}, ErrNoDefaultValue.GenWithStackByArgs(col.Name) } // GetZeroValue gets zero value for given column type. diff --git a/table/table.go b/table/table.go index b23a77dce57c8..1ca0325c5b172 100644 --- a/table/table.go +++ b/table/table.go @@ -58,7 +58,7 @@ var ( // ErrNoDefaultValue is used when insert a row, the column value is not given, and the column has not null flag // and it doesn't have a default value. - ErrNoDefaultValue = terror.ClassTable.New(codeNoDefaultValue, "field doesn't have a default value") + ErrNoDefaultValue = terror.ClassTable.New(codeNoDefaultValue, mysql.MySQLErrName[mysql.ErrNoDefaultForField]) // ErrIndexOutBound returns for index column offset out of bound. ErrIndexOutBound = terror.ClassTable.New(codeIndexOutBound, "index column offset out of bound") // ErrUnsupportedOp returns for unsupported operation. 
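A note on the backoff rename that runs through this diff: boPDRPC is exported as BoPDRPC so that packages outside store/tikv, such as the GC worker's new uploadSafePointToPD, can retry PD RPCs through the shared backoffer; the call sites in kv.go and region_cache.go above simply switch to the exported name. A minimal sketch of that retry pattern, using only names that already appear in this diff (callPD is a hypothetical stand-in for any PD call, e.g. UpdateGCSafePoint):

package gcworker

import (
	"context"

	"github.com/pingcap/errors"
	"github.com/pingcap/tidb/store/tikv"
)

// callPDWithBackoff retries a PD RPC with the exported BoPDRPC backoff type,
// mirroring the loop used by uploadSafePointToPD in this diff.
func callPDWithBackoff(ctx context.Context, callPD func() error) error {
	bo := tikv.NewBackoffer(ctx, tikv.GcOneRegionMaxBackoff)
	for {
		err := callPD()
		if err == nil {
			return nil
		}
		// Stop immediately if the caller canceled the context.
		if errors.Cause(err) == context.Canceled {
			return errors.Trace(err)
		}
		// Otherwise sleep according to the PD RPC backoff policy and retry.
		if err = bo.Backoff(tikv.BoPDRPC, err); err != nil {
			return errors.Trace(err)
		}
	}
}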
@@ -89,6 +89,7 @@ type RecordIterFunc func(h int64, rec []types.Datum, cols []*Column) (more bool, // AddRecordOpt contains the options will be used when adding a record. type AddRecordOpt struct { CreateIdxOpt + IsUpdate bool } // Table is used to retrieve and modify rows in table. diff --git a/table/tables/tables.go b/table/tables/tables.go index e46ee22091b7e..82d7e4aad2949 100644 --- a/table/tables/tables.go +++ b/table/tables/tables.go @@ -425,7 +425,10 @@ func (t *tableCommon) AddRecord(ctx sessionctx.Context, r []types.Datum, opts .. } var hasRecordID bool cols := t.Cols() - if len(r) > len(cols) { + // opt.IsUpdate is a flag for update. + // If the handle ID is changed during an update, the update deletes the old record first and then calls `AddRecord` to insert a new one. + // Currently, only insert can set _tidb_rowid; update cannot change it. + if len(r) > len(cols) && !opt.IsUpdate { // The last value is _tidb_rowid. recordID = r[len(r)-1].GetInt64() hasRecordID = true @@ -466,12 +469,21 @@ func (t *tableCommon) AddRecord(ctx sessionctx.Context, r []types.Datum, opts .. for _, col := range t.WritableCols() { var value types.Datum - if col.State != model.StatePublic { + // When update calls `AddRecord`, the default values of write-only columns have already been handled. + // Only insert needs to fill in default values for write-only columns. + if col.State != model.StatePublic && !opt.IsUpdate { // If col is in write only or write reorganization state, we must add it with its default value. value, err = table.GetColOriginDefaultValue(ctx, col.ToInfo()) if err != nil { return 0, errors.Trace(err) } + // Add the default value to `r` so the transaction's dirty DB sees it. + // Otherwise, update would panic when it reads the value of a write-only column from the dirty DB. + if col.Offset < len(r) { + r[col.Offset] = value + } else { + r = append(r, value) + } } else { value = r[col.Offset] } @@ -907,6 +919,16 @@ func (t *tableCommon) AllocAutoID(ctx sessionctx.Context) (int64, error) { return 0, errors.Trace(err) } if t.meta.ShardRowIDBits > 0 { + if t.overflowShardBits(rowID) { + // If it overflows, the rowID may be duplicated. For example, + // t.meta.ShardRowIDBits = 4 + // rowID = 0010111111111111111111111111111111111111111111111111111111111111 + // shard = 0100000000000000000000000000000000000000000000000000000000000000 + // will be duplicated with: + // rowID = 0100111111111111111111111111111111111111111111111111111111111111 + // shard = 0010000000000000000000000000000000000000000000000000000000000000 + return 0, autoid.ErrAutoincReadFailed + } txnCtx := ctx.GetSessionVars().TxnCtx if txnCtx.Shard == nil { shard := t.calcShard(txnCtx.StartTS) @@ -917,6 +939,12 @@ func (t *tableCommon) AllocAutoID(ctx sessionctx.Context) (int64, error) { return rowID, nil } +// overflowShardBits checks whether the rowID overflows `1<<(64-t.meta.ShardRowIDBits-1) -1`. 
+func (t *tableCommon) overflowShardBits(rowID int64) bool { + mask := (1<<t.meta.ShardRowIDBits - 1) << (64 - t.meta.ShardRowIDBits - 1) + return rowID&int64(mask) > 0 +} + func (t *tableCommon) calcShard(startTS uint64) int64 { var buf [8]byte binary.LittleEndian.PutUint64(buf[:], startTS) @@ -1070,7 +1098,8 @@ func newCtxForPartitionExpr() sessionctx.Context { sctx := &ctxForPartitionExpr{ sessionVars: variable.NewSessionVars(), } - sctx.sessionVars.MaxChunkSize = 2 + sctx.sessionVars.InitChunkSize = 2 + sctx.sessionVars.MaxChunkSize = 32 sctx.sessionVars.StmtCtx.TimeZone = time.UTC return sctx } diff --git a/types/convert.go b/types/convert.go index 27df51b771997..87e3cd82e24a8 100644 --- a/types/convert.go +++ b/types/convert.go @@ -106,7 +106,11 @@ func ConvertUintToInt(val uint64, upperBound int64, tp byte) (int64, error) { } // ConvertIntToUint converts an int value to an uint value. -func ConvertIntToUint(val int64, upperBound uint64, tp byte) (uint64, error) { +func ConvertIntToUint(sc *stmtctx.StatementContext, val int64, upperBound uint64, tp byte) (uint64, error) { + if sc.ShouldClipToZero() && val < 0 { + return 0, overflow(val, tp) + } + if uint64(val) > upperBound { return upperBound, overflow(val, tp) } @@ -124,9 +128,12 @@ func ConvertUintToUint(val uint64, upperBound uint64, tp byte) (uint64, error) { } // ConvertFloatToUint converts a float value to an uint value. -func ConvertFloatToUint(fval float64, upperBound uint64, tp byte) (uint64, error) { +func ConvertFloatToUint(sc *stmtctx.StatementContext, fval float64, upperBound uint64, tp byte) (uint64, error) { val := RoundFloat(fval) if val < 0 { + if sc.ShouldClipToZero() { + return 0, overflow(val, tp) + } return uint64(int64(val)), overflow(val, tp) } @@ -400,7 +407,7 @@ func ConvertJSONToInt(sc *stmtctx.StatementContext, j json.BinaryJSON, unsigned return ConvertFloatToInt(f, lBound, uBound, mysql.TypeDouble) } bound := UnsignedUpperBound[mysql.TypeLonglong] - u, err := ConvertFloatToUint(f, bound, mysql.TypeDouble) + u, err := ConvertFloatToUint(sc, f, bound, mysql.TypeDouble) return int64(u), errors.Trace(err) case json.TypeCodeString: return StrToInt(sc, hack.String(j.GetString())) @@ -423,7 +430,7 @@ func ConvertJSONToFloat(sc *stmtctx.StatementContext, j json.BinaryJSON) (float6 case json.TypeCodeInt64: return float64(j.GetInt64()), nil case json.TypeCodeUint64: - u, err := ConvertIntToUint(j.GetInt64(), UnsignedUpperBound[mysql.TypeLonglong], mysql.TypeLonglong) + u, err := ConvertIntToUint(sc, j.GetInt64(), UnsignedUpperBound[mysql.TypeLonglong], mysql.TypeLonglong) return float64(u), errors.Trace(err) case json.TypeCodeFloat64: return j.GetFloat64(), nil diff --git a/types/datum.go b/types/datum.go index a476d4d37e17e..b8bea0dc25452 100644 --- a/types/datum.go +++ b/types/datum.go @@ -872,21 +872,21 @@ func (d *Datum) convertToUint(sc *stmtctx.StatementContext, target *FieldType) ( ) switch d.k { case KindInt64: - val, err = ConvertIntToUint(d.GetInt64(), upperBound, tp) + val, err = ConvertIntToUint(sc, d.GetInt64(), upperBound, tp) case KindUint64: val, err = ConvertUintToUint(d.GetUint64(), upperBound, tp) case KindFloat32, KindFloat64: - val, err = ConvertFloatToUint(d.GetFloat64(), upperBound, tp) + val, err = ConvertFloatToUint(sc, d.GetFloat64(), upperBound, tp) case KindString, KindBytes: - val, err = StrToUint(sc, d.GetString()) - if err != nil { - return ret, errors.Trace(err) + uval, err1 := StrToUint(sc, d.GetString()) + if err1 != nil && ErrOverflow.Equal(err1) && !sc.ShouldIgnoreOverflowError() { + return ret, errors.Trace(err1) } - val, err = ConvertUintToUint(val, upperBound, tp) + 
val, err = ConvertUintToUint(uval, upperBound, tp) if err != nil { return ret, errors.Trace(err) } - ret.SetUint64(val) + err = err1 case KindMysqlTime: dec := d.GetMysqlTime().ToNumber() err = dec.Round(dec, 0, ModeHalfEven) @@ -894,7 +894,7 @@ func (d *Datum) convertToUint(sc *stmtctx.StatementContext, target *FieldType) ( if err == nil { err = err1 } - val, err1 = ConvertIntToUint(ival, upperBound, tp) + val, err1 = ConvertIntToUint(sc, ival, upperBound, tp) if err == nil { err = err1 } @@ -903,18 +903,18 @@ func (d *Datum) convertToUint(sc *stmtctx.StatementContext, target *FieldType) ( err = dec.Round(dec, 0, ModeHalfEven) ival, err1 := dec.ToInt() if err1 == nil { - val, err = ConvertIntToUint(ival, upperBound, tp) + val, err = ConvertIntToUint(sc, ival, upperBound, tp) } case KindMysqlDecimal: fval, err1 := d.GetMysqlDecimal().ToFloat64() - val, err = ConvertFloatToUint(fval, upperBound, tp) + val, err = ConvertFloatToUint(sc, fval, upperBound, tp) if err == nil { err = err1 } case KindMysqlEnum: - val, err = ConvertFloatToUint(d.GetMysqlEnum().ToNumber(), upperBound, tp) + val, err = ConvertFloatToUint(sc, d.GetMysqlEnum().ToNumber(), upperBound, tp) case KindMysqlSet: - val, err = ConvertFloatToUint(d.GetMysqlSet().ToNumber(), upperBound, tp) + val, err = ConvertFloatToUint(sc, d.GetMysqlSet().ToNumber(), upperBound, tp) case KindBinaryLiteral, KindMysqlBit: val, err = d.GetBinaryLiteral().ToInt(sc) case KindMysqlJSON: @@ -1144,7 +1144,7 @@ func ProduceDecWithSpecifiedTp(dec *MyDecimal, tp *FieldType, sc *stmtctx.Statem return nil, errors.Trace(err) } if !dec.IsZero() && frac > decimal && dec.Compare(&old) != 0 { - if sc.InInsertStmt || sc.InUpdateOrDeleteStmt { + if sc.InInsertStmt || sc.InUpdateStmt || sc.InDeleteStmt { // fix https://github.com/pingcap/tidb/issues/3895 // fix https://github.com/pingcap/tidb/issues/5532 sc.AppendWarning(ErrTruncated) diff --git a/types/format_test.go b/types/format_test.go index 16fbf6a963971..493ac31e75ccf 100644 --- a/types/format_test.go +++ b/types/format_test.go @@ -103,6 +103,17 @@ func (s *testTimeSuite) TestStrToDate(c *C) { {`10:13 PM`, `%l:%i %p`, types.FromDate(0, 0, 0, 22, 13, 0, 0)}, {`12:00:00 AM`, `%h:%i:%s %p`, types.FromDate(0, 0, 0, 0, 0, 0, 0)}, {`12:00:00 PM`, `%h:%i:%s %p`, types.FromDate(0, 0, 0, 12, 0, 0, 0)}, + {`18/10/22`, `%y/%m/%d`, types.FromDate(2018, 10, 22, 0, 0, 0, 0)}, + {`8/10/22`, `%y/%m/%d`, types.FromDate(2008, 10, 22, 0, 0, 0, 0)}, + {`69/10/22`, `%y/%m/%d`, types.FromDate(2069, 10, 22, 0, 0, 0, 0)}, + {`70/10/22`, `%y/%m/%d`, types.FromDate(1970, 10, 22, 0, 0, 0, 0)}, + {`18/10/22`, `%Y/%m/%d`, types.FromDate(2018, 10, 22, 0, 0, 0, 0)}, + {`2018/10/22`, `%Y/%m/%d`, types.FromDate(2018, 10, 22, 0, 0, 0, 0)}, + {`8/10/22`, `%Y/%m/%d`, types.FromDate(2008, 10, 22, 0, 0, 0, 0)}, + {`69/10/22`, `%Y/%m/%d`, types.FromDate(2069, 10, 22, 0, 0, 0, 0)}, + {`70/10/22`, `%Y/%m/%d`, types.FromDate(1970, 10, 22, 0, 0, 0, 0)}, + {`18/10/22`, `%Y/%m/%d`, types.FromDate(2018, 10, 22, 0, 0, 0, 0)}, + {`100/10/22`, `%Y/%m/%d`, types.FromDate(100, 10, 22, 0, 0, 0, 0)}, } for i, tt := range tests { var t types.Time @@ -121,6 +132,7 @@ func (s *testTimeSuite) TestStrToDate(c *C) { {`23:60:12`, `%T`}, // invalid minute {`18`, `%l`}, {`00:21:22 AM`, `%h:%i:%s %p`}, + {`100/10/22`, `%y/%m/%d`}, } for _, tt := range errTests { var t types.Time diff --git a/types/mydecimal.go b/types/mydecimal.go index afb34eca3523f..30c0d736bfcd8 100644 --- a/types/mydecimal.go +++ b/types/mydecimal.go @@ -234,6 +234,23 @@ func (d *MyDecimal) 
removeLeadingZeros() (wordIdx int, digitsInt int) { return } +func (d *MyDecimal) removeTrailingZeros() (lastWordIdx int, digitsFrac int) { + digitsFrac = int(d.digitsFrac) + i := ((digitsFrac - 1) % digitsPerWord) + 1 + lastWordIdx = digitsToWords(int(d.digitsInt)) + digitsToWords(int(d.digitsFrac)) + for digitsFrac > 0 && d.wordBuf[lastWordIdx-1] == 0 { + digitsFrac -= i + i = digitsPerWord + lastWordIdx-- + } + if digitsFrac > 0 { + digitsFrac -= countTrailingZeroes(9-((digitsFrac-1)%digitsPerWord), d.wordBuf[lastWordIdx-1]) + } else { + digitsFrac = 0 + } + return +} + // ToString converts decimal to its printable string representation without rounding. // // RETURN VALUE @@ -1212,6 +1229,26 @@ func (d *MyDecimal) ToBin(precision, frac int) ([]byte, error) { return bin, err } +// ToHashKey removes the leading and trailing zeros and generates a hash key. +// Two Decimals dec0 and dec1 with different fraction will generate the same hash keys if dec0.Compare(dec1) == 0. +func (d *MyDecimal) ToHashKey() ([]byte, error) { + _, digitsInt := d.removeLeadingZeros() + _, digitsFrac := d.removeTrailingZeros() + prec := digitsInt + digitsFrac + if prec == 0 { // zeroDecimal + prec = 1 + } + buf, err := d.ToBin(prec, digitsFrac) + if err == ErrTruncated { + // This err is caused by shorter digitsFrac; + // After removing the trailing zeros from a Decimal, + // so digitsFrac may be less than the real digitsFrac of the Decimal, + // thus ErrTruncated may be raised, we can ignore it here. + err = nil + } + return buf, err +} + // PrecisionAndFrac returns the internal precision and frac number. func (d *MyDecimal) PrecisionAndFrac() (precision, frac int) { frac = int(d.digitsFrac) diff --git a/types/mydecimal_test.go b/types/mydecimal_test.go index 1770554629b8c..e799692231c6a 100644 --- a/types/mydecimal_test.go +++ b/types/mydecimal_test.go @@ -15,6 +15,7 @@ package types import ( "strings" + "testing" . 
"github.com/pingcap/check" ) @@ -145,6 +146,114 @@ func (s *testMyDecimalSuite) TestToFloat(c *C) { } } +func (s *testMyDecimalSuite) TestToHashKey(c *C) { + tests := []struct { + numbers []string + }{ + {[]string{"1.1", "1.1000", "1.1000000", "1.10000000000", "01.1", "0001.1", "001.1000000"}}, + {[]string{"-1.1", "-1.1000", "-1.1000000", "-1.10000000000", "-01.1", "-0001.1", "-001.1000000"}}, + {[]string{".1", "0.1", "000000.1", ".10000", "0000.10000", "000000000000000000.1"}}, + {[]string{"0", "0000", ".0", ".00000", "00000.00000", "-0", "-0000", "-.0", "-.00000", "-00000.00000"}}, + {[]string{".123456789123456789", ".1234567891234567890", ".12345678912345678900", ".123456789123456789000", ".1234567891234567890000", "0.123456789123456789", + ".1234567891234567890000000000", "0000000.123456789123456789000"}}, + {[]string{"12345", "012345", "0012345", "0000012345", "0000000012345", "00000000000012345", "12345.", "12345.00", "12345.000000000", "000012345.0000"}}, + {[]string{"123E5", "12300000", "00123E5", "000000123E5", "12300000.00000000"}}, + {[]string{"123E-2", "1.23", "00000001.23", "1.2300000000000000", "000000001.23000000000000"}}, + } + for _, ca := range tests { + keys := make([]string, 0, len(ca.numbers)) + for _, num := range ca.numbers { + var dec MyDecimal + c.Check(dec.FromString([]byte(num)), IsNil) + key, err := dec.ToHashKey() + c.Check(err, IsNil) + keys = append(keys, string(key)) + } + + for i := 1; i < len(keys); i++ { + c.Check(keys[0], Equals, keys[i]) + } + } + + binTests := []struct { + hashNumbers []string + binNumbers []string + }{ + {[]string{"1.1", "1.1000", "1.1000000", "1.10000000000", "01.1", "0001.1", "001.1000000"}, + []string{"1.1", "0001.1", "01.1"}}, + {[]string{"-1.1", "-1.1000", "-1.1000000", "-1.10000000000", "-01.1", "-0001.1", "-001.1000000"}, + []string{"-1.1", "-0001.1", "-01.1"}}, + {[]string{".1", "0.1", "000000.1", ".10000", "0000.10000", "000000000000000000.1"}, + []string{".1", "0.1", "000000.1", "00.1"}}, + {[]string{"0", "0000", ".0", ".00000", "00000.00000", "-0", "-0000", "-.0", "-.00000", "-00000.00000"}, + []string{"0", "0000", "00", "-0", "-00", "-000000"}}, + {[]string{".123456789123456789", ".1234567891234567890", ".12345678912345678900", ".123456789123456789000", ".1234567891234567890000", "0.123456789123456789", + ".1234567891234567890000000000", "0000000.123456789123456789000"}, + []string{".123456789123456789", "0.123456789123456789", "0000.123456789123456789", "0000000.123456789123456789"}}, + {[]string{"12345", "012345", "0012345", "0000012345", "0000000012345", "00000000000012345", "12345.", "12345.00", "12345.000000000", "000012345.0000"}, + []string{"12345", "012345", "000012345", "000000000000012345"}}, + {[]string{"123E5", "12300000", "00123E5", "000000123E5", "12300000.00000000"}, + []string{"12300000", "123E5", "00123E5", "0000000000123E5"}}, + {[]string{"123E-2", "1.23", "00000001.23", "1.2300000000000000", "000000001.23000000000000"}, + []string{"123E-2", "1.23", "000001.23", "0000000000001.23"}}, + } + for _, ca := range binTests { + keys := make([]string, 0, len(ca.hashNumbers)+len(ca.binNumbers)) + for _, num := range ca.hashNumbers { + var dec MyDecimal + c.Check(dec.FromString([]byte(num)), IsNil) + key, err := dec.ToHashKey() + c.Check(err, IsNil) + keys = append(keys, string(key)) + } + for _, num := range ca.binNumbers { + var dec MyDecimal + c.Check(dec.FromString([]byte(num)), IsNil) + prec, frac := dec.PrecisionAndFrac() // remove leading zeros but trailing zeros remain + key, err := dec.ToBin(prec, frac) + 
c.Check(err, IsNil) + keys = append(keys, string(key)) + } + + for i := 1; i < len(keys); i++ { + c.Check(keys[0], Equals, keys[i]) + } + } +} + +func (s *testMyDecimalSuite) TestRemoveTrailingZeros(c *C) { + tests := []string{ + "0", "0.0", ".0", ".00000000", "0.0000", "0000", "0000.0", "0000.000", + "-0", "-0.0", "-.0", "-.00000000", "-0.0000", "-0000", "-0000.0", "-0000.000", + "123123123", "213123.", "21312.000", "21321.123", "213.1230000", "213123.000123000", + "-123123123", "-213123.", "-21312.000", "-21321.123", "-213.1230000", "-213123.000123000", + "123E5", "12300E-5", "0.00100E1", "0.001230E-3", + "123987654321.123456789000", "000000000123", "123456789.987654321", "999.999000", + } + for _, ca := range tests { + var dec MyDecimal + c.Check(dec.FromString([]byte(ca)), IsNil) + + // calculate the number of digits after point but trailing zero + digitsFracExp := 0 + str := string(dec.ToString()) + point := strings.Index(str, ".") + if point != -1 { + pos := len(str) - 1 + for pos > point { + if str[pos] != '0' { + break + } + pos-- + } + digitsFracExp = pos - point + } + + _, digitsFrac := dec.removeTrailingZeros() + c.Check(digitsFrac, Equals, digitsFracExp) + } +} + func (s *testMyDecimalSuite) TestShift(c *C) { type tcase struct { input string @@ -778,3 +887,57 @@ func (s *testMyDecimalSuite) TestMaxOrMin(c *C) { c.Assert(dec.String(), Equals, tt.result) } } + +func benchmarkMyDecimalToBinOrHashCases() []string { + return []string{ + "1.000000000000", "3", "12.000000000", "120", + "120000", "100000000000.00000", "0.000000001200000000", + "98765.4321", "-123.456000000000000000", + "0", "0000000000", "0.00000000000", + } +} + +func BenchmarkMyDecimalToBin(b *testing.B) { + cases := benchmarkMyDecimalToBinOrHashCases() + decs := make([]*MyDecimal, 0, len(cases)) + for _, ca := range cases { + var dec MyDecimal + if err := dec.FromString([]byte(ca)); err != nil { + b.Fatal(err) + } + decs = append(decs, &dec) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + for _, dec := range decs { + prec, frac := dec.PrecisionAndFrac() + _, err := dec.ToBin(prec, frac) + if err != nil { + b.Fatal(err) + } + } + } +} + +func BenchmarkMyDecimalToHashKey(b *testing.B) { + cases := benchmarkMyDecimalToBinOrHashCases() + decs := make([]*MyDecimal, 0, len(cases)) + for _, ca := range cases { + var dec MyDecimal + if err := dec.FromString([]byte(ca)); err != nil { + b.Fatal(err) + } + decs = append(decs, &dec) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + for _, dec := range decs { + _, err := dec.ToHashKey() + if err != nil { + b.Fatal(err) + } + } + } +} diff --git a/types/parser_driver/value_expr.go b/types/parser_driver/value_expr.go index 08b6872ce2f21..9ab8e66e8f650 100644 --- a/types/parser_driver/value_expr.go +++ b/types/parser_driver/value_expr.go @@ -20,6 +20,7 @@ import ( "github.com/pingcap/errors" "github.com/pingcap/parser/ast" + "github.com/pingcap/parser/format" "github.com/pingcap/parser/mysql" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/hack" @@ -69,7 +70,7 @@ type ValueExpr struct { } // Restore implements Node interface. -func (n *ValueExpr) Restore(ctx *ast.RestoreCtx) error { +func (n *ValueExpr) Restore(ctx *format.RestoreCtx) error { switch n.Kind() { case types.KindNull: ctx.WriteKeyWord("NULL") @@ -195,7 +196,7 @@ type ParamMarkerExpr struct { } // Restore implements Node interface. 
-func (n *ParamMarkerExpr) Restore(ctx *ast.RestoreCtx) error { +func (n *ParamMarkerExpr) Restore(ctx *format.RestoreCtx) error { ctx.WritePlain("?") return nil } diff --git a/types/parser_driver/value_exprl_test.go b/types/parser_driver/value_expr_test.go similarity index 93% rename from types/parser_driver/value_exprl_test.go rename to types/parser_driver/value_expr_test.go index 07746f77dc42d..7cbc6643246a6 100644 --- a/types/parser_driver/value_exprl_test.go +++ b/types/parser_driver/value_expr_test.go @@ -18,7 +18,7 @@ import ( "testing" . "github.com/pingcap/check" - "github.com/pingcap/parser/ast" + "github.com/pingcap/parser/format" "github.com/pingcap/tidb/types" ) @@ -52,7 +52,7 @@ func (s *testValueExprRestoreSuite) TestValueExprRestore(c *C) { for _, testCase := range testCases { sb.Reset() expr := &ValueExpr{Datum: testCase.datum} - err := expr.Restore(ast.NewRestoreCtx(ast.DefaultRestoreFlags, &sb)) + err := expr.Restore(format.NewRestoreCtx(format.DefaultRestoreFlags, &sb)) c.Assert(err, IsNil) c.Assert(sb.String(), Equals, testCase.expect, Commentf("Datum: %#v", testCase.datum)) } diff --git a/types/time.go b/types/time.go index fec734025aea1..41d32a7422245 100644 --- a/types/time.go +++ b/types/time.go @@ -2165,6 +2165,8 @@ var dateFormatParserTable = map[string]dateFormatParser{ "%S": secondsNumeric, // Seconds (00..59) "%T": time24Hour, // Time, 24-hour (hh:mm:ss) "%Y": yearNumericFourDigits, // Year, numeric, four digits + // Deprecated since MySQL 5.7.5 + "%y": yearNumericTwoDigits, // Year, numeric (two digits) // TODO: Add the following... // "%a": abbreviatedWeekday, // Abbreviated weekday name (Sun..Sat) // "%D": dayOfMonthWithSuffix, // Day of the month with English suffix (0th, 1st, 2nd, 3rd) @@ -2176,8 +2178,6 @@ var dateFormatParserTable = map[string]dateFormatParser{ // "%w": dayOfWeek, // Day of the week (0=Sunday..6=Saturday) // "%X": yearOfWeek, // Year for the week where Sunday is the first day of the week, numeric, four digits; used with %V // "%x": yearOfWeek, // Year for the week, where Monday is the first day of the week, numeric, four digits; used with %v - // Deprecated since MySQL 5.7.5 - // "%y": yearTwoDigits, // Year, numeric (two digits) } // GetFormatType checks the type(Duration, Date or Datetime) of a format string. 
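The new %y entry in dateFormatParserTable follows MySQL's two-digit-year convention, which the added format_test.go cases above also pin down (18 parses as 2018, 69 as 2069, 70 as 1970). Below is a small illustrative sketch of the mapping that adjustYear, called from yearNumericNDigits later in this diff, is expected to apply; the helper name here is hypothetical:

// adjustTwoDigitYear illustrates the assumed two-digit year convention:
// 00-69 map to 2000-2069 and 70-99 map to 1970-1999.
// Larger values are returned unchanged, matching the effectiveCount <= 2
// guard in yearNumericNDigits.
func adjustTwoDigitYear(y int) int {
	switch {
	case y >= 0 && y <= 69:
		return y + 2000
	case y >= 70 && y <= 99:
		return y + 1900
	default:
		return y
	}
}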
@@ -2235,7 +2235,7 @@ func matchDateWithToken(t *MysqlTime, date string, token string, ctx map[string] } func parseDigits(input string, count int) (int, bool) { - if len(input) < count { + if count <= 0 || len(input) < count { return 0, false } @@ -2432,12 +2432,31 @@ func microSeconds(t *MysqlTime, input string, ctx map[string]int) (string, bool) } func yearNumericFourDigits(t *MysqlTime, input string, ctx map[string]int) (string, bool) { - v, succ := parseDigits(input, 4) - if !succ { + return yearNumericNDigits(t, input, ctx, 4) +} + +func yearNumericTwoDigits(t *MysqlTime, input string, ctx map[string]int) (string, bool) { + return yearNumericNDigits(t, input, ctx, 2) +} + +func yearNumericNDigits(t *MysqlTime, input string, ctx map[string]int, n int) (string, bool) { + effectiveCount, effectiveValue := 0, 0 + for effectiveCount+1 <= n { + value, succeed := parseDigits(input, effectiveCount+1) + if !succeed { + break + } + effectiveCount++ + effectiveValue = value + } + if effectiveCount == 0 { return input, false } - t.year = uint16(v) - return input[4:], true + if effectiveCount <= 2 { + effectiveValue = adjustYear(effectiveValue) + } + t.year = uint16(effectiveValue) + return input[effectiveCount:], true } func dayOfYearThreeDigits(t *MysqlTime, input string, ctx map[string]int) (string, bool) { diff --git a/util/chunk/chunk.go b/util/chunk/chunk.go index 3b91cf3bb693d..52ad133c4fd2b 100644 --- a/util/chunk/chunk.go +++ b/util/chunk/chunk.go @@ -288,7 +288,8 @@ func (c *Chunk) AppendPartialRow(colIdx int, row Row) { // when the Insert() function is called parallelly, the data race on a byte // can not be avoided although the manipulated bits are different inside a // byte. -func (c *Chunk) PreAlloc(row Row) { +func (c *Chunk) PreAlloc(row Row) (rowIdx uint32) { + rowIdx = uint32(c.NumRows()) for i, srcCol := range row.c.columns { dstCol := c.columns[i] dstCol.appendNullBitmap(!srcCol.isNull(row.idx)) @@ -304,16 +305,28 @@ func (c *Chunk) PreAlloc(row Row) { continue } // Grow the capacity according to golang.growslice. + // Implementation differences with golang: + // 1. We double the capacity when `dstCol.data < 1024*elemLen bytes` but + // not `1024 bytes`. + // 2. We expand the capacity to 1.5*originCap rather than 1.25*originCap + // during the slow-increasing phase. newCap := cap(dstCol.data) doubleCap := newCap << 1 if needCap > doubleCap { newCap = needCap } else { - if len(dstCol.data) < 1024 { + avgElemLen := elemLen + if !srcCol.isFixed() { + avgElemLen = len(dstCol.data) / len(dstCol.offsets) + } + // slowIncThreshold indicates the threshold exceeding which the + // dstCol.data capacity increase fold decreases from 2 to 1.5. + slowIncThreshold := 1024 * avgElemLen + if len(dstCol.data) < slowIncThreshold { newCap = doubleCap } else { for 0 < newCap && newCap < needCap { - newCap += newCap / 4 + newCap += newCap / 2 } if newCap <= 0 { newCap = needCap @@ -322,6 +335,7 @@ func (c *Chunk) PreAlloc(row Row) { } dstCol.data = make([]byte, len(dstCol.data)+elemLen, newCap) } + return } // Insert inserts `row` on the position specified by `rowIdx`. 
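The rewritten PreAlloc above documents its growth policy in comments; the helper below restates that policy on its own so the two phases are easier to see. The function name growPreAllocCap is hypothetical, while the thresholds and factors (double below 1024*avgElemLen bytes, then grow by 1.5x, or jump straight to needCap when even doubling is insufficient) are taken from the diff:

// growPreAllocCap mirrors the capacity growth described in Chunk.PreAlloc.
func growPreAllocCap(curLen, curCap, needCap, avgElemLen int) int {
	doubleCap := curCap << 1
	if needCap > doubleCap {
		// Even doubling is not enough, so allocate exactly what is needed.
		return needCap
	}
	// Below this threshold the column data is still small and doubling is cheap.
	slowIncThreshold := 1024 * avgElemLen
	if curLen < slowIncThreshold {
		return doubleCap
	}
	// Slow-increasing phase: grow by 1.5x until the requirement is met.
	newCap := curCap
	for 0 < newCap && newCap < needCap {
		newCap += newCap / 2
	}
	if newCap <= 0 { // overflow guard, as in the original loop
		newCap = needCap
	}
	return newCap
}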
diff --git a/util/chunk/list.go b/util/chunk/list.go index e7a939c69c70c..f3dd06f613767 100644 --- a/util/chunk/list.go +++ b/util/chunk/list.go @@ -158,8 +158,7 @@ func (l *List) PreAlloc4Row(row Row) (ptr RowPtr) { chkIdx++ } chk := l.chunks[chkIdx] - rowIdx := chk.NumRows() - chk.PreAlloc(row) + rowIdx := chk.PreAlloc(row) l.length++ return RowPtr{ChkIdx: uint32(chkIdx), RowIdx: uint32(rowIdx)} } diff --git a/util/chunk/list_test.go b/util/chunk/list_test.go index 9669e784a9e2b..014470a58bb3b 100644 --- a/util/chunk/list_test.go +++ b/util/chunk/list_test.go @@ -183,3 +183,38 @@ func BenchmarkListMemoryUsage(b *testing.B) { list.GetMemTracker().BytesConsumed() } } + +func BenchmarkPreAllocList(b *testing.B) { + fieldTypes := make([]*types.FieldType, 0, 1) + fieldTypes = append(fieldTypes, &types.FieldType{Tp: mysql.TypeLonglong}) + chk := NewChunkWithCapacity(fieldTypes, 1) + chk.AppendInt64(0, 1) + row := chk.GetRow(0) + + b.ResetTimer() + list := NewList(fieldTypes, 1024, 1024) + for i := 0; i < b.N; i++ { + list.Reset() + // 32768 indicates the number of int64 rows to fill 256KB L2 cache. + for j := 0; j < 32768; j++ { + list.PreAlloc4Row(row) + } + } +} + +func BenchmarkPreAllocChunk(b *testing.B) { + fieldTypes := make([]*types.FieldType, 0, 1) + fieldTypes = append(fieldTypes, &types.FieldType{Tp: mysql.TypeLonglong}) + chk := NewChunkWithCapacity(fieldTypes, 1) + chk.AppendInt64(0, 1) + row := chk.GetRow(0) + + b.ResetTimer() + finalChk := New(fieldTypes, 33000, 1024) + for i := 0; i < b.N; i++ { + finalChk.Reset() + for j := 0; j < 32768; j++ { + finalChk.PreAlloc(row) + } + } +} diff --git a/util/codec/codec.go b/util/codec/codec.go index fae2935538246..0af81e36ae5f7 100644 --- a/util/codec/codec.go +++ b/util/codec/codec.go @@ -91,9 +91,8 @@ func encode(sc *stmtctx.StatementContext, b []byte, vals []types.Datum, comparab if hash { // If hash is true, we only consider the original value of this decimal and ignore it's precision. dec := vals[i].GetMysqlDecimal() - precision, frac := dec.PrecisionAndFrac() var bin []byte - bin, err = dec.ToBin(precision, frac) + bin, err = dec.ToHashKey() if err != nil { return nil, errors.Trace(err) } @@ -238,9 +237,7 @@ func encodeChunkRow(sc *stmtctx.StatementContext, b []byte, row chunk.Row, allTy if hash { // If hash is true, we only consider the original value of this decimal and ignore it's precision. dec := row.GetMyDecimal(i) - precision, frac := dec.PrecisionAndFrac() - var bin []byte - bin, err = dec.ToBin(precision, frac) + bin, err := dec.ToHashKey() if err != nil { return nil, errors.Trace(err) } diff --git a/util/encrypt/aes.go b/util/encrypt/aes.go index 435f3bdc7c6ee..7b1d39644b059 100644 --- a/util/encrypt/aes.go +++ b/util/encrypt/aes.go @@ -173,6 +173,30 @@ func AESDecryptWithCBC(cryptStr, key []byte, iv []byte) ([]byte, error) { return aesDecrypt(cryptStr, mode) } +// AESEncryptWithOFB encrypts data using AES with OFB mode. +func AESEncryptWithOFB(plainStr []byte, key []byte, iv []byte) ([]byte, error) { + cb, err := aes.NewCipher(key) + if err != nil { + return nil, errors.Trace(err) + } + mode := cipher.NewOFB(cb, iv) + crypted := make([]byte, len(plainStr)) + mode.XORKeyStream(crypted, plainStr) + return crypted, nil +} + +// AESDecryptWithOFB decrypts data using AES with OFB mode. 
+func AESDecryptWithOFB(cipherStr []byte, key []byte, iv []byte) ([]byte, error) { + cb, err := aes.NewCipher(key) + if err != nil { + return nil, errors.Trace(err) + } + mode := cipher.NewOFB(cb, iv) + plainStr := make([]byte, len(cipherStr)) + mode.XORKeyStream(plainStr, cipherStr) + return plainStr, nil +} + // AESEncryptWithCFB decrypts data using AES with CFB mode. func AESEncryptWithCFB(cryptStr, key []byte, iv []byte) ([]byte, error) { cb, err := aes.NewCipher(key) diff --git a/util/encrypt/aes_test.go b/util/encrypt/aes_test.go index 14cd1e4b77fa2..0ea417c82090e 100644 --- a/util/encrypt/aes_test.go +++ b/util/encrypt/aes_test.go @@ -304,6 +304,75 @@ func (s *testEncryptSuite) TestAESEncryptWithCBC(c *C) { } } +func (s *testEncryptSuite) TestAESEncryptWithOFB(c *C) { + defer testleak.AfterTest(c)() + tests := []struct { + str string + key string + iv string + expect string + isError bool + }{ + // 128 bits key + {"pingcap", "1234567890123456", "1234567890123456", "0515A36BBF3DE0", false}, + {"pingcap123", "1234567890123456", "1234567890123456", "0515A36BBF3DE0DBE9DD", false}, + // 192 bits key + {"pingcap", "123456789012345678901234", "1234567890123456", "45A57592449893", false}, // 192 bit + // negtive cases: invalid key length + {"pingcap", "12345678901234567", "1234567890123456", "", true}, + {"pingcap", "123456789012345", "1234567890123456", "", true}, + } + + for _, t := range tests { + str := []byte(t.str) + key := []byte(t.key) + iv := []byte(t.iv) + + crypted, err := AESEncryptWithOFB(str, key, iv) + if t.isError { + c.Assert(err, NotNil, Commentf("%v", t)) + continue + } + c.Assert(err, IsNil, Commentf("%v", t)) + result := toHex(crypted) + c.Assert(result, Equals, t.expect, Commentf("%v", t)) + } +} + +func (s *testEncryptSuite) TestAESDecryptWithOFB(c *C) { + defer testleak.AfterTest(c)() + tests := []struct { + str string + key string + iv string + expect string + isError bool + }{ + // 128 bits key + {"0515A36BBF3DE0", "1234567890123456", "1234567890123456", "pingcap", false}, + {"0515A36BBF3DE0DBE9DD", "1234567890123456", "1234567890123456", "pingcap123", false}, + // 192 bits key + {"45A57592449893", "123456789012345678901234", "1234567890123456", "pingcap", false}, // 192 bit + // negtive cases: invalid key length + {"pingcap", "12345678901234567", "1234567890123456", "", true}, + {"pingcap", "123456789012345", "1234567890123456", "", true}, + } + + for _, t := range tests { + str, _ := hex.DecodeString(t.str) + key := []byte(t.key) + iv := []byte(t.iv) + + plainText, err := AESDecryptWithOFB(str, key, iv) + if t.isError { + c.Assert(err, NotNil, Commentf("%v", t)) + continue + } + c.Assert(err, IsNil, Commentf("%v", t)) + c.Assert(string(plainText), Equals, t.expect, Commentf("%v", t)) + } +} + func (s *testEncryptSuite) TestAESDecryptWithCBC(c *C) { defer testleak.AfterTest(c)() tests := []struct { @@ -342,7 +411,7 @@ func (s *testEncryptSuite) TestAESDecryptWithCBC(c *C) { } } -func (s *testEncryptSuite) TestAESEncryptWithOFB(c *C) { +func (s *testEncryptSuite) TestAESEncryptWithCFB(c *C) { defer testleak.AfterTest(c)() tests := []struct { str string diff --git a/util/execdetails/execdetails.go b/util/execdetails/execdetails.go index 184f5d4fde6e4..aaebbda1a3cc8 100644 --- a/util/execdetails/execdetails.go +++ b/util/execdetails/execdetails.go @@ -53,13 +53,13 @@ type CommitDetails struct { func (d ExecDetails) String() string { parts := make([]string, 0, 6) if d.ProcessTime > 0 { - parts = append(parts, fmt.Sprintf("process_time:%v", d.ProcessTime)) + parts = 
append(parts, fmt.Sprintf("process_time:%vs", d.ProcessTime.Seconds())) } if d.WaitTime > 0 { - parts = append(parts, fmt.Sprintf("wait_time:%v", d.WaitTime)) + parts = append(parts, fmt.Sprintf("wait_time:%vs", d.WaitTime.Seconds())) } if d.BackoffTime > 0 { - parts = append(parts, fmt.Sprintf("backoff_time:%v", d.BackoffTime)) + parts = append(parts, fmt.Sprintf("backoff_time:%vs", d.BackoffTime.Seconds())) } if d.RequestCount > 0 { parts = append(parts, fmt.Sprintf("request_count:%d", d.RequestCount)) @@ -73,23 +73,23 @@ func (d ExecDetails) String() string { commitDetails := d.CommitDetail if commitDetails != nil { if commitDetails.PrewriteTime > 0 { - parts = append(parts, fmt.Sprintf("prewrite_time:%v", commitDetails.PrewriteTime)) + parts = append(parts, fmt.Sprintf("prewrite_time:%vs", commitDetails.PrewriteTime.Seconds())) } if commitDetails.CommitTime > 0 { - parts = append(parts, fmt.Sprintf("commit_time:%v", commitDetails.CommitTime)) + parts = append(parts, fmt.Sprintf("commit_time:%vs", commitDetails.CommitTime.Seconds())) } if commitDetails.GetCommitTsTime > 0 { - parts = append(parts, fmt.Sprintf("get_commit_ts_time:%v", commitDetails.GetCommitTsTime)) + parts = append(parts, fmt.Sprintf("get_commit_ts_time:%vs", commitDetails.GetCommitTsTime.Seconds())) } if commitDetails.TotalBackoffTime > 0 { - parts = append(parts, fmt.Sprintf("total_backoff_time:%v", commitDetails.TotalBackoffTime)) + parts = append(parts, fmt.Sprintf("total_backoff_time:%vs", commitDetails.TotalBackoffTime.Seconds())) } resolveLockTime := atomic.LoadInt64(&commitDetails.ResolveLockTime) if resolveLockTime > 0 { - parts = append(parts, fmt.Sprintf("resolve_lock_time:%d", time.Duration(resolveLockTime))) + parts = append(parts, fmt.Sprintf("resolve_lock_time:%vs", time.Duration(resolveLockTime).Seconds())) } if commitDetails.LocalLatchTime > 0 { - parts = append(parts, fmt.Sprintf("local_latch_wait_time:%v", commitDetails.LocalLatchTime)) + parts = append(parts, fmt.Sprintf("local_latch_wait_time:%vs", commitDetails.LocalLatchTime.Seconds())) } if commitDetails.WriteKeys > 0 { parts = append(parts, fmt.Sprintf("write_keys:%d", commitDetails.WriteKeys)) diff --git a/util/execdetails/execdetails_test.go b/util/execdetails/execdetails_test.go index b69f2229c7668..cd69856070bf8 100644 --- a/util/execdetails/execdetails_test.go +++ b/util/execdetails/execdetails_test.go @@ -20,14 +20,26 @@ import ( func TestString(t *testing.T) { detail := &ExecDetails{ - ProcessTime: time.Second, + ProcessTime: 2*time.Second + 5*time.Millisecond, WaitTime: time.Second, BackoffTime: time.Second, RequestCount: 1, TotalKeys: 100, ProcessedKeys: 10, + CommitDetail: &CommitDetails{ + GetCommitTsTime: time.Second, + PrewriteTime: time.Second, + CommitTime: time.Second, + LocalLatchTime: time.Second, + TotalBackoffTime: time.Second, + ResolveLockTime: 1000000000, // 10^9 ns = 1s + WriteKeys: 1, + WriteSize: 1, + PrewriteRegionNum: 1, + TxnRetry: 1, + }, } - expected := "process_time:1s wait_time:1s backoff_time:1s request_count:1 total_keys:100 processed_keys:10" + expected := "process_time:2.005s wait_time:1s backoff_time:1s request_count:1 total_keys:100 processed_keys:10 prewrite_time:1s commit_time:1s get_commit_ts_time:1s total_backoff_time:1s resolve_lock_time:1s local_latch_wait_time:1s write_keys:1 write_size:1 prewrite_region:1 txn_retry:1" if str := detail.String(); str != expected { t.Errorf("got:\n%s\nexpected:\n%s", str, expected) } diff --git a/util/mock/context.go b/util/mock/context.go index 
96bdd9c5c8e52..c3419792ac857 100644 --- a/util/mock/context.go +++ b/util/mock/context.go @@ -224,7 +224,8 @@ func NewContext() *Context { ctx: ctx, cancel: cancel, } - sctx.sessionVars.MaxChunkSize = 2 + sctx.sessionVars.InitChunkSize = 2 + sctx.sessionVars.MaxChunkSize = 32 sctx.sessionVars.StmtCtx.TimeZone = time.UTC sctx.sessionVars.GlobalVarsAccessor = variable.NewMockGlobalAccessor() return sctx
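Finally, a usage sketch for the OFB helpers added to util/encrypt earlier in this diff. The key and IV sizes follow the new tests (16-byte key, 16-byte IV), and since OFB is symmetric, decrypting with the same key and IV returns the original plaintext:

package main

import (
	"fmt"

	"github.com/pingcap/tidb/util/encrypt"
)

func main() {
	key := []byte("1234567890123456") // 128-bit key, as in the new test vectors
	iv := []byte("1234567890123456")  // 16-byte initialization vector
	plain := []byte("pingcap")

	crypted, err := encrypt.AESEncryptWithOFB(plain, key, iv)
	if err != nil {
		panic(err)
	}
	decrypted, err := encrypt.AESDecryptWithOFB(crypted, key, iv)
	if err != nil {
		panic(err)
	}
	// According to the new tests this prints "0515a36bbf3de0 pingcap".
	fmt.Printf("%x %s\n", crypted, decrypted)
}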