diff --git a/pkg/server/internal/testserverclient/BUILD.bazel b/pkg/server/internal/testserverclient/BUILD.bazel new file mode 100644 index 0000000000000..b29f420f7e1f5 --- /dev/null +++ b/pkg/server/internal/testserverclient/BUILD.bazel @@ -0,0 +1,24 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_library") + +go_library( + name = "testserverclient", + srcs = ["server_client.go"], + importpath = "github.com/pingcap/tidb/pkg/server/internal/testserverclient", + visibility = ["//pkg/server:__subpackages__"], + deps = [ + "//pkg/config", + "//pkg/errno", + "//pkg/kv", + "//pkg/parser/mysql", + "//pkg/server", + "//pkg/testkit", + "//pkg/testkit/testenv", + "//pkg/util/versioninfo", + "@com_github_go_sql_driver_mysql//:mysql", + "@com_github_pingcap_errors//:errors", + "@com_github_pingcap_failpoint//:failpoint", + "@com_github_pingcap_log//:log", + "@com_github_stretchr_testify//require", + "@org_uber_go_zap//:zap", + ], +) diff --git a/pkg/server/tests/tidb_test.go b/pkg/server/tests/tidb_test.go new file mode 100644 index 0000000000000..365fad6d60ecc --- /dev/null +++ b/pkg/server/tests/tidb_test.go @@ -0,0 +1,3129 @@ +// Copyright 2015 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tests + +import ( + "context" + "crypto/rand" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "database/sql" + "encoding/binary" + "encoding/pem" + "fmt" + "io" + "math/big" + "net" + "net/http" + "os" + "path/filepath" + "strings" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/go-sql-driver/mysql" + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/config" + ddlutil "github.com/pingcap/tidb/pkg/ddl/util" + "github.com/pingcap/tidb/pkg/domain" + "github.com/pingcap/tidb/pkg/extension" + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/parser" + "github.com/pingcap/tidb/pkg/parser/ast" + "github.com/pingcap/tidb/pkg/parser/auth" + tmysql "github.com/pingcap/tidb/pkg/parser/mysql" + server2 "github.com/pingcap/tidb/pkg/server" + "github.com/pingcap/tidb/pkg/server/internal/column" + "github.com/pingcap/tidb/pkg/server/internal/resultset" + "github.com/pingcap/tidb/pkg/server/internal/testserverclient" + "github.com/pingcap/tidb/pkg/server/internal/testutil" + util2 "github.com/pingcap/tidb/pkg/server/internal/util" + "github.com/pingcap/tidb/pkg/session" + "github.com/pingcap/tidb/pkg/sessionctx/variable" + "github.com/pingcap/tidb/pkg/store/mockstore" + "github.com/pingcap/tidb/pkg/store/mockstore/unistore" + "github.com/pingcap/tidb/pkg/testkit" + "github.com/pingcap/tidb/pkg/util" + "github.com/pingcap/tidb/pkg/util/cpuprofile" + "github.com/pingcap/tidb/pkg/util/plancodec" + "github.com/pingcap/tidb/pkg/util/resourcegrouptag" + "github.com/pingcap/tidb/pkg/util/topsql" + "github.com/pingcap/tidb/pkg/util/topsql/collector" + mockTopSQLTraceCPU "github.com/pingcap/tidb/pkg/util/topsql/collector/mock" + topsqlstate "github.com/pingcap/tidb/pkg/util/topsql/state" + "github.com/pingcap/tidb/pkg/util/topsql/stmtstats" + "github.com/stretchr/testify/require" + "github.com/tikv/client-go/v2/tikvrpc" + "go.opencensus.io/stats/view" +) + +type tidbTestSuite struct { + *testserverclient.TestServerClient + tidbdrv *server2.TiDBDriver + server *server2.Server + domain *domain.Domain + store kv.Storage +} + +func createTidbTestSuite(t *testing.T) *tidbTestSuite { + cfg := util2.NewTestConfig() + cfg.Port = 0 + cfg.Status.ReportStatus = true + cfg.Status.StatusPort = 0 + cfg.Status.RecordDBLabel = true + cfg.Performance.TCPKeepAlive = true + return createTidbTestSuiteWithCfg(t, cfg) +} + +func createTidbTestSuiteWithCfg(t *testing.T, cfg *config.Config) *tidbTestSuite { + ts := &tidbTestSuite{TestServerClient: testserverclient.NewTestServerClient()} + + // setup tidbTestSuite + var err error + ts.store, err = mockstore.NewMockStore() + session.DisableStats4Test() + require.NoError(t, err) + ts.domain, err = session.BootstrapSession(ts.store) + require.NoError(t, err) + ts.tidbdrv = server2.NewTiDBDriver(ts.store) + + server, err := server2.NewServer(cfg, ts.tidbdrv) + require.NoError(t, err) + ts.Port = testutil.GetPortFromTCPAddr(server.ListenAddr()) + ts.StatusPort = testutil.GetPortFromTCPAddr(server.StatusListenerAddr()) + ts.server = server + ts.server.SetDomain(ts.domain) + ts.domain.InfoSyncer().SetSessionManager(ts.server) + go func() { + err := ts.server.Run() + require.NoError(t, err) + }() + ts.WaitUntilServerOnline() + + t.Cleanup(func() { + if ts.domain != nil { + ts.domain.Close() + } + if ts.server != nil { + ts.server.Close() + } + if ts.store != nil { + require.NoError(t, ts.store.Close()) + } + view.Stop() + }) + return ts +} + +type tidbTestTopSQLSuite struct { + *tidbTestSuite +} + +func createTidbTestTopSQLSuite(t *testing.T) *tidbTestTopSQLSuite { + base := createTidbTestSuite(t) + + ts := &tidbTestTopSQLSuite{base} + + // Initialize global variable for top-sql test. + db, err := sql.Open("mysql", ts.GetDSN()) + require.NoError(t, err) + defer func() { + err := db.Close() + require.NoError(t, err) + }() + + dbt := testkit.NewDBTestKit(t, db) + topsqlstate.GlobalState.PrecisionSeconds.Store(1) + topsqlstate.GlobalState.ReportIntervalSeconds.Store(2) + dbt.MustExec("set @@global.tidb_top_sql_max_time_series_count=5;") + + require.NoError(t, cpuprofile.StartCPUProfiler()) + t.Cleanup(func() { + cpuprofile.StopCPUProfiler() + topsqlstate.GlobalState.PrecisionSeconds.Store(topsqlstate.DefTiDBTopSQLPrecisionSeconds) + topsqlstate.GlobalState.ReportIntervalSeconds.Store(topsqlstate.DefTiDBTopSQLReportIntervalSeconds) + view.Stop() + }) + return ts +} + +func TestRegression(t *testing.T) { + ts := createTidbTestSuite(t) + if testserverclient.Regression { + ts.RunTestRegression(t, nil, "Regression") + } +} + +func TestUint64(t *testing.T) { + ts := createTidbTestSuite(t) + ts.RunTestPrepareResultFieldType(t) +} + +func TestSpecialType(t *testing.T) { + ts := createTidbTestSuite(t) + ts.RunTestSpecialType(t) +} + +func TestPreparedString(t *testing.T) { + ts := createTidbTestSuite(t) + + ts.RunTestPreparedString(t) +} + +func TestPreparedTimestamp(t *testing.T) { + ts := createTidbTestSuite(t) + + ts.RunTestPreparedTimestamp(t) +} + +func TestConcurrentUpdate(t *testing.T) { + ts := createTidbTestSuite(t) + + ts.RunTestConcurrentUpdate(t) +} + +func TestErrorCode(t *testing.T) { + ts := createTidbTestSuite(t) + + ts.RunTestErrorCode(t) +} + +func TestAuth(t *testing.T) { + ts := createTidbTestSuite(t) + + ts.RunTestAuth(t) + ts.RunTestIssue3682(t) + ts.RunTestAccountLock(t) +} + +func TestIssues(t *testing.T) { + ts := createTidbTestSuite(t) + + ts.RunTestIssue3662(t) + ts.RunTestIssue3680(t) + ts.RunTestIssue22646(t) +} + +func TestDBNameEscape(t *testing.T) { + ts := createTidbTestSuite(t) + ts.RunTestDBNameEscape(t) +} + +func TestResultFieldTableIsNull(t *testing.T) { + ts := createTidbTestSuite(t) + + ts.RunTestResultFieldTableIsNull(t) +} + +func TestStatusAPI(t *testing.T) { + ts := createTidbTestSuite(t) + + ts.RunTestStatusAPI(t) +} + +func TestStatusPort(t *testing.T) { + ts := createTidbTestSuite(t) + + cfg := util2.NewTestConfig() + cfg.Port = 0 + cfg.Status.ReportStatus = true + cfg.Status.StatusPort = ts.StatusPort + cfg.Performance.TCPKeepAlive = true + + server, err := server2.NewServer(cfg, ts.tidbdrv) + require.Error(t, err) + require.Nil(t, server) +} + +func TestStatusAPIWithTLS(t *testing.T) { + ts := createTidbTestSuite(t) + + dir := t.TempDir() + + fileName := func(file string) string { + return filepath.Join(dir, file) + } + + caCert, caKey, err := generateCert(0, "TiDB CA 2", nil, nil, fileName("ca-key-2.pem"), fileName("ca-cert-2.pem")) + require.NoError(t, err) + _, _, err = generateCert(1, "tidb-server-2", caCert, caKey, fileName("server-key-2.pem"), fileName("server-cert-2.pem")) + require.NoError(t, err) + + cli := testserverclient.NewTestServerClient() + cli.StatusScheme = "https" + cfg := util2.NewTestConfig() + cfg.Port = cli.Port + cfg.Status.StatusPort = cli.StatusPort + cfg.Security.ClusterSSLCA = fileName("ca-cert-2.pem") + cfg.Security.ClusterSSLCert = fileName("server-cert-2.pem") + cfg.Security.ClusterSSLKey = fileName("server-key-2.pem") + server, err := server2.NewServer(cfg, ts.tidbdrv) + require.NoError(t, err) + cli.Port = testutil.GetPortFromTCPAddr(server.ListenAddr()) + cli.StatusPort = testutil.GetPortFromTCPAddr(server.StatusListenerAddr()) + go func() { + err := server.Run() + require.NoError(t, err) + }() + time.Sleep(time.Millisecond * 100) + + // https connection should work. + ts.RunTestStatusAPI(t) + + // but plain http connection should fail. + cli.StatusScheme = "http" + //nolint:bodyclose + _, err = cli.FetchStatus("/status") + require.Error(t, err) + + server.Close() +} + +func TestStatusAPIWithTLSCNCheck(t *testing.T) { + ts := createTidbTestSuite(t) + + dir := t.TempDir() + + caPath := filepath.Join(dir, "ca-cert-cn.pem") + serverKeyPath := filepath.Join(dir, "server-key-cn.pem") + serverCertPath := filepath.Join(dir, "server-cert-cn.pem") + client1KeyPath := filepath.Join(dir, "client-key-cn-check-a.pem") + client1CertPath := filepath.Join(dir, "client-cert-cn-check-a.pem") + client2KeyPath := filepath.Join(dir, "client-key-cn-check-b.pem") + client2CertPath := filepath.Join(dir, "client-cert-cn-check-b.pem") + + caCert, caKey, err := generateCert(0, "TiDB CA CN CHECK", nil, nil, filepath.Join(dir, "ca-key-cn.pem"), caPath) + require.NoError(t, err) + _, _, err = generateCert(1, "tidb-server-cn-check", caCert, caKey, serverKeyPath, serverCertPath) + require.NoError(t, err) + _, _, err = generateCert(2, "tidb-client-cn-check-a", caCert, caKey, client1KeyPath, client1CertPath, func(c *x509.Certificate) { + c.Subject.CommonName = "tidb-client-1" + }) + require.NoError(t, err) + _, _, err = generateCert(3, "tidb-client-cn-check-b", caCert, caKey, client2KeyPath, client2CertPath, func(c *x509.Certificate) { + c.Subject.CommonName = "tidb-client-2" + }) + require.NoError(t, err) + + cli := testserverclient.NewTestServerClient() + cli.StatusScheme = "https" + cfg := util2.NewTestConfig() + cfg.Port = cli.Port + cfg.Status.StatusPort = cli.StatusPort + cfg.Security.ClusterSSLCA = caPath + cfg.Security.ClusterSSLCert = serverCertPath + cfg.Security.ClusterSSLKey = serverKeyPath + cfg.Security.ClusterVerifyCN = []string{"tidb-client-2"} + server, err := server2.NewServer(cfg, ts.tidbdrv) + require.NoError(t, err) + + cli.Port = testutil.GetPortFromTCPAddr(server.ListenAddr()) + cli.StatusPort = testutil.GetPortFromTCPAddr(server.StatusListenerAddr()) + go func() { + err := server.Run() + require.NoError(t, err) + }() + defer server.Close() + time.Sleep(time.Millisecond * 100) + + hc := newTLSHttpClient(t, caPath, + client1CertPath, + client1KeyPath, + ) + //nolint:bodyclose + _, err = hc.Get(cli.StatusURL("/status")) + require.Error(t, err) + + hc = newTLSHttpClient(t, caPath, + client2CertPath, + client2KeyPath, + ) + resp, err := hc.Get(cli.StatusURL("/status")) + require.NoError(t, err) + require.Nil(t, resp.Body.Close()) +} + +func newTLSHttpClient(t *testing.T, caFile, certFile, keyFile string) *http.Client { + cert, err := tls.LoadX509KeyPair(certFile, keyFile) + require.NoError(t, err) + caCert, err := os.ReadFile(caFile) + require.NoError(t, err) + caCertPool := x509.NewCertPool() + caCertPool.AppendCertsFromPEM(caCert) + tlsConfig := &tls.Config{ + Certificates: []tls.Certificate{cert}, + RootCAs: caCertPool, + InsecureSkipVerify: true, + } + tlsConfig.BuildNameToCertificate() + return &http.Client{Transport: &http.Transport{TLSClientConfig: tlsConfig}} +} + +func TestMultiStatements(t *testing.T) { + ts := createTidbTestSuite(t) + + ts.RunFailedTestMultiStatements(t) + ts.RunTestMultiStatements(t) +} + +func TestSocketForwarding(t *testing.T) { + tempDir := t.TempDir() + socketFile := tempDir + "/tidbtest.sock" // Unix Socket does not work on Windows, so '/' should be OK + + ts := createTidbTestSuite(t) + + cli := testserverclient.NewTestServerClient() + cfg := util2.NewTestConfig() + cfg.Socket = socketFile + cfg.Port = cli.Port + os.Remove(cfg.Socket) + cfg.Status.ReportStatus = false + + server, err := server2.NewServer(cfg, ts.tidbdrv) + require.NoError(t, err) + server.SetDomain(ts.domain) + cli.Port = testutil.GetPortFromTCPAddr(server.ListenAddr()) + go func() { + err := server.Run() + require.NoError(t, err) + }() + time.Sleep(time.Millisecond * 100) + defer server.Close() + + cli.RunTestRegression(t, func(config *mysql.Config) { + config.User = "root" + config.Net = "unix" + config.Addr = socketFile + config.DBName = "test" + config.Params = map[string]string{"sql_mode": "'STRICT_ALL_TABLES'"} + }, "SocketRegression") +} + +func TestSocket(t *testing.T) { + tempDir := t.TempDir() + socketFile := tempDir + "/tidbtest.sock" // Unix Socket does not work on Windows, so '/' should be OK + + cfg := util2.NewTestConfig() + cfg.Socket = socketFile + cfg.Port = 0 + os.Remove(cfg.Socket) + cfg.Host = "" + cfg.Status.ReportStatus = false + + ts := createTidbTestSuite(t) + + server, err := server2.NewServer(cfg, ts.tidbdrv) + require.NoError(t, err) + server.SetDomain(ts.domain) + go func() { + err := server.Run() + require.NoError(t, err) + }() + time.Sleep(time.Millisecond * 100) + defer server.Close() + + confFunc := func(config *mysql.Config) { + config.User = "root" + config.Net = "unix" + config.Addr = socketFile + config.DBName = "test" + config.Params = map[string]string{"sql_mode": "STRICT_ALL_TABLES"} + } + // a fake server client, config is override, just used to run tests + cli := testserverclient.NewTestServerClient() + cli.WaitUntilCustomServerCanConnect(confFunc) + cli.RunTestRegression(t, confFunc, "SocketRegression") +} + +func TestSocketAndIp(t *testing.T) { + tempDir := t.TempDir() + socketFile := tempDir + "/tidbtest.sock" // Unix Socket does not work on Windows, so '/' should be OK + + cli := testserverclient.NewTestServerClient() + cfg := util2.NewTestConfig() + cfg.Socket = socketFile + cfg.Port = cli.Port + cfg.Status.ReportStatus = false + + ts := createTidbTestSuite(t) + + server, err := server2.NewServer(cfg, ts.tidbdrv) + require.NoError(t, err) + server.SetDomain(ts.domain) + cli.Port = testutil.GetPortFromTCPAddr(server.ListenAddr()) + go func() { + err := server.Run() + require.NoError(t, err) + }() + cli.WaitUntilServerCanConnect() + defer server.Close() + + // Test with Socket connection + Setup user1@% for all host access + cli.Port = testutil.GetPortFromTCPAddr(server.ListenAddr()) + defer func() { + cli.RunTests(t, func(config *mysql.Config) { + config.User = "root" + }, + func(dbt *testkit.DBTestKit) { + dbt.MustExec("DROP USER IF EXISTS 'user1'@'%'") + dbt.MustExec("DROP USER IF EXISTS 'user1'@'localhost'") + dbt.MustExec("DROP USER IF EXISTS 'user1'@'127.0.0.1'") + }) + }() + cli.RunTests(t, func(config *mysql.Config) { + config.User = "root" + config.Net = "unix" + config.Addr = socketFile + config.DBName = "test" + }, + func(dbt *testkit.DBTestKit) { + rows := dbt.MustQuery("select user()") + cli.CheckRows(t, rows, "root@localhost") + rows = dbt.MustQuery("show grants") + cli.CheckRows(t, rows, "GRANT ALL PRIVILEGES ON *.* TO 'root'@'%' WITH GRANT OPTION") + dbt.MustQuery("CREATE USER user1@'%'") + dbt.MustQuery("GRANT SELECT ON test.* TO user1@'%'") + }) + // Test with Network interface connection with all hosts + cli.RunTests(t, func(config *mysql.Config) { + config.User = "user1" + config.DBName = "test" + }, + func(dbt *testkit.DBTestKit) { + rows := dbt.MustQuery("select user()") + // NOTICE: this is not compatible with MySQL! (MySQL would report user1@localhost also for 127.0.0.1) + cli.CheckRows(t, rows, "user1@127.0.0.1") + rows = dbt.MustQuery("show grants") + cli.CheckRows(t, rows, "GRANT USAGE ON *.* TO 'user1'@'%'\nGRANT SELECT ON `test`.* TO 'user1'@'%'") + rows = dbt.MustQuery("select host from information_schema.processlist where user = 'user1'") + records := cli.Rows(t, rows) + require.Contains(t, records[0], ":", "Missing : in is.processlist") + }) + // Test with unix domain socket file connection with all hosts + cli.RunTests(t, func(config *mysql.Config) { + config.Net = "unix" + config.Addr = socketFile + config.User = "user1" + config.DBName = "test" + }, + func(dbt *testkit.DBTestKit) { + rows := dbt.MustQuery("select user()") + cli.CheckRows(t, rows, "user1@localhost") + rows = dbt.MustQuery("show grants") + cli.CheckRows(t, rows, "GRANT USAGE ON *.* TO 'user1'@'%'\nGRANT SELECT ON `test`.* TO 'user1'@'%'") + }) + + // Setup user1@127.0.0.1 for loop back network interface access + cli.RunTests(t, func(config *mysql.Config) { + config.User = "root" + config.DBName = "test" + }, + func(dbt *testkit.DBTestKit) { + rows := dbt.MustQuery("select user()") + // NOTICE: this is not compatible with MySQL! (MySQL would report user1@localhost also for 127.0.0.1) + cli.CheckRows(t, rows, "root@127.0.0.1") + rows = dbt.MustQuery("show grants") + cli.CheckRows(t, rows, "GRANT ALL PRIVILEGES ON *.* TO 'root'@'%' WITH GRANT OPTION") + dbt.MustQuery("CREATE USER user1@127.0.0.1") + dbt.MustQuery("GRANT SELECT,INSERT ON test.* TO user1@'127.0.0.1'") + }) + // Test with Network interface connection with all hosts + cli.RunTests(t, func(config *mysql.Config) { + config.User = "user1" + config.DBName = "test" + }, + func(dbt *testkit.DBTestKit) { + rows := dbt.MustQuery("select user()") + // NOTICE: this is not compatible with MySQL! (MySQL would report user1@localhost also for 127.0.0.1) + cli.CheckRows(t, rows, "user1@127.0.0.1") + rows = dbt.MustQuery("show grants") + cli.CheckRows(t, rows, "GRANT USAGE ON *.* TO 'user1'@'127.0.0.1'\nGRANT SELECT,INSERT ON `test`.* TO 'user1'@'127.0.0.1'") + }) + // Test with unix domain socket file connection with all hosts + cli.RunTests(t, func(config *mysql.Config) { + config.Net = "unix" + config.Addr = socketFile + config.User = "user1" + config.DBName = "test" + }, + func(dbt *testkit.DBTestKit) { + rows := dbt.MustQuery("select user()") + cli.CheckRows(t, rows, "user1@localhost") + rows = dbt.MustQuery("show grants") + cli.CheckRows(t, rows, "GRANT USAGE ON *.* TO 'user1'@'%'\nGRANT SELECT ON `test`.* TO 'user1'@'%'") + }) + + // Setup user1@localhost for socket (and if MySQL compatible; loop back network interface access) + cli.RunTests(t, func(config *mysql.Config) { + config.Net = "unix" + config.Addr = socketFile + config.User = "root" + config.DBName = "test" + }, + func(dbt *testkit.DBTestKit) { + rows := dbt.MustQuery("select user()") + cli.CheckRows(t, rows, "root@localhost") + rows = dbt.MustQuery("show grants") + cli.CheckRows(t, rows, "GRANT ALL PRIVILEGES ON *.* TO 'root'@'%' WITH GRANT OPTION") + dbt.MustExec("CREATE USER user1@localhost") + dbt.MustExec("GRANT SELECT,INSERT,UPDATE,DELETE ON test.* TO user1@localhost") + }) + // Test with Network interface connection with all hosts + cli.RunTests(t, func(config *mysql.Config) { + config.User = "user1" + config.DBName = "test" + }, + func(dbt *testkit.DBTestKit) { + rows := dbt.MustQuery("select user()") + // NOTICE: this is not compatible with MySQL! (MySQL would report user1@localhost also for 127.0.0.1) + cli.CheckRows(t, rows, "user1@127.0.0.1") + require.NoError(t, rows.Close()) + rows = dbt.MustQuery("show grants") + cli.CheckRows(t, rows, "GRANT USAGE ON *.* TO 'user1'@'127.0.0.1'\nGRANT SELECT,INSERT ON `test`.* TO 'user1'@'127.0.0.1'") + require.NoError(t, rows.Close()) + }) + // Test with unix domain socket file connection with all hosts + cli.RunTests(t, func(config *mysql.Config) { + config.Net = "unix" + config.Addr = socketFile + config.User = "user1" + config.DBName = "test" + }, + func(dbt *testkit.DBTestKit) { + rows := dbt.MustQuery("select user()") + cli.CheckRows(t, rows, "user1@localhost") + require.NoError(t, rows.Close()) + rows = dbt.MustQuery("show grants") + cli.CheckRows(t, rows, "GRANT USAGE ON *.* TO 'user1'@'localhost'\nGRANT SELECT,INSERT,UPDATE,DELETE ON `test`.* TO 'user1'@'localhost'") + require.NoError(t, rows.Close()) + }) +} + +// TestOnlySocket for server configuration without network interface for mysql clients +func TestOnlySocket(t *testing.T) { + tempDir := t.TempDir() + socketFile := tempDir + "/tidbtest.sock" // Unix Socket does not work on Windows, so '/' should be OK + + cli := testserverclient.NewTestServerClient() + cfg := util2.NewTestConfig() + cfg.Socket = socketFile + cfg.Host = "" // No network interface listening for mysql traffic + cfg.Status.ReportStatus = false + + ts := createTidbTestSuite(t) + + server, err := server2.NewServer(cfg, ts.tidbdrv) + require.NoError(t, err) + server.SetDomain(ts.domain) + go func() { + err := server.Run() + require.NoError(t, err) + }() + time.Sleep(time.Millisecond * 100) + defer server.Close() + require.Nil(t, server.Listener()) + require.NotNil(t, server.Socket()) + + // Test with Socket connection + Setup user1@% for all host access + defer func() { + cli.RunTests(t, func(config *mysql.Config) { + config.User = "root" + config.Net = "unix" + config.Addr = socketFile + }, + func(dbt *testkit.DBTestKit) { + dbt.MustExec("DROP USER IF EXISTS 'user1'@'%'") + dbt.MustExec("DROP USER IF EXISTS 'user1'@'localhost'") + dbt.MustExec("DROP USER IF EXISTS 'user1'@'127.0.0.1'") + }) + }() + cli.RunTests(t, func(config *mysql.Config) { + config.User = "root" + config.Net = "unix" + config.Addr = socketFile + config.DBName = "test" + }, + func(dbt *testkit.DBTestKit) { + rows := dbt.MustQuery("select user()") + cli.CheckRows(t, rows, "root@localhost") + require.NoError(t, rows.Close()) + rows = dbt.MustQuery("show grants") + cli.CheckRows(t, rows, "GRANT ALL PRIVILEGES ON *.* TO 'root'@'%' WITH GRANT OPTION") + require.NoError(t, rows.Close()) + dbt.MustExec("CREATE USER user1@'%'") + dbt.MustExec("GRANT SELECT ON test.* TO user1@'%'") + }) + // Test with Network interface connection with all hosts, should fail since server not configured + db, err := sql.Open("mysql", cli.GetDSN(func(config *mysql.Config) { + config.User = "root" + config.DBName = "test" + })) + require.NoErrorf(t, err, "Open failed") + err = db.Ping() + require.Errorf(t, err, "Connect succeeded when not configured!?!") + db.Close() + db, err = sql.Open("mysql", cli.GetDSN(func(config *mysql.Config) { + config.User = "user1" + config.DBName = "test" + })) + require.NoErrorf(t, err, "Open failed") + err = db.Ping() + require.Errorf(t, err, "Connect succeeded when not configured!?!") + db.Close() + // Test with unix domain socket file connection with all hosts + cli.RunTests(t, func(config *mysql.Config) { + config.Net = "unix" + config.Addr = socketFile + config.User = "user1" + config.DBName = "test" + }, + func(dbt *testkit.DBTestKit) { + rows := dbt.MustQuery("select user()") + cli.CheckRows(t, rows, "user1@localhost") + require.NoError(t, rows.Close()) + rows = dbt.MustQuery("show grants") + cli.CheckRows(t, rows, "GRANT USAGE ON *.* TO 'user1'@'%'\nGRANT SELECT ON `test`.* TO 'user1'@'%'") + require.NoError(t, rows.Close()) + }) + + // Setup user1@127.0.0.1 for loop back network interface access + cli.RunTests(t, func(config *mysql.Config) { + config.Net = "unix" + config.Addr = socketFile + config.User = "root" + config.DBName = "test" + }, + func(dbt *testkit.DBTestKit) { + rows := dbt.MustQuery("select user()") + // NOTICE: this is not compatible with MySQL! (MySQL would report user1@localhost also for 127.0.0.1) + cli.CheckRows(t, rows, "root@localhost") + require.NoError(t, rows.Close()) + rows = dbt.MustQuery("show grants") + cli.CheckRows(t, rows, "GRANT ALL PRIVILEGES ON *.* TO 'root'@'%' WITH GRANT OPTION") + require.NoError(t, rows.Close()) + dbt.MustExec("CREATE USER user1@127.0.0.1") + dbt.MustExec("GRANT SELECT,INSERT ON test.* TO user1@'127.0.0.1'") + }) + // Test with unix domain socket file connection with all hosts + cli.RunTests(t, func(config *mysql.Config) { + config.Net = "unix" + config.Addr = socketFile + config.User = "user1" + config.DBName = "test" + }, + func(dbt *testkit.DBTestKit) { + rows := dbt.MustQuery("select user()") + cli.CheckRows(t, rows, "user1@localhost") + require.NoError(t, rows.Close()) + rows = dbt.MustQuery("show grants") + cli.CheckRows(t, rows, "GRANT USAGE ON *.* TO 'user1'@'%'\nGRANT SELECT ON `test`.* TO 'user1'@'%'") + require.NoError(t, rows.Close()) + }) + + // Setup user1@localhost for socket (and if MySQL compatible; loop back network interface access) + cli.RunTests(t, func(config *mysql.Config) { + config.Net = "unix" + config.Addr = socketFile + config.User = "root" + config.DBName = "test" + }, + func(dbt *testkit.DBTestKit) { + rows := dbt.MustQuery("select user()") + cli.CheckRows(t, rows, "root@localhost") + require.NoError(t, rows.Close()) + rows = dbt.MustQuery("show grants") + cli.CheckRows(t, rows, "GRANT ALL PRIVILEGES ON *.* TO 'root'@'%' WITH GRANT OPTION") + require.NoError(t, rows.Close()) + dbt.MustExec("CREATE USER user1@localhost") + dbt.MustExec("GRANT SELECT,INSERT,UPDATE,DELETE ON test.* TO user1@localhost") + }) + // Test with unix domain socket file connection with all hosts + cli.RunTests(t, func(config *mysql.Config) { + config.Net = "unix" + config.Addr = socketFile + config.User = "user1" + config.DBName = "test" + }, + func(dbt *testkit.DBTestKit) { + rows := dbt.MustQuery("select user()") + cli.CheckRows(t, rows, "user1@localhost") + require.NoError(t, rows.Close()) + rows = dbt.MustQuery("show grants") + cli.CheckRows(t, rows, "GRANT USAGE ON *.* TO 'user1'@'localhost'\nGRANT SELECT,INSERT,UPDATE,DELETE ON `test`.* TO 'user1'@'localhost'") + require.NoError(t, rows.Close()) + }) +} + +// generateCert generates a private key and a certificate in PEM format based on parameters. +// If parentCert and parentCertKey is specified, the new certificate will be signed by the parentCert. +// Otherwise, the new certificate will be self-signed and is a CA. +func generateCert(sn int, commonName string, parentCert *x509.Certificate, parentCertKey *rsa.PrivateKey, outKeyFile string, outCertFile string, opts ...func(c *x509.Certificate)) (*x509.Certificate, *rsa.PrivateKey, error) { + privateKey, err := rsa.GenerateKey(rand.Reader, 528) + if err != nil { + return nil, nil, errors.Trace(err) + } + notBefore := time.Now().Add(-10 * time.Minute).UTC() + notAfter := notBefore.Add(1 * time.Hour).UTC() + + template := x509.Certificate{ + SerialNumber: big.NewInt(int64(sn)), + Subject: pkix.Name{CommonName: commonName, Names: []pkix.AttributeTypeAndValue{util.MockPkixAttribute(util.CommonName, commonName)}}, + DNSNames: []string{commonName}, + NotBefore: notBefore, + NotAfter: notAfter, + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth}, + BasicConstraintsValid: true, + } + for _, opt := range opts { + opt(&template) + } + + var parent *x509.Certificate + var priv *rsa.PrivateKey + + if parentCert == nil || parentCertKey == nil { + template.IsCA = true + template.KeyUsage |= x509.KeyUsageCertSign + parent = &template + priv = privateKey + } else { + parent = parentCert + priv = parentCertKey + } + + derBytes, err := x509.CreateCertificate(rand.Reader, &template, parent, &privateKey.PublicKey, priv) + if err != nil { + return nil, nil, errors.Trace(err) + } + + cert, err := x509.ParseCertificate(derBytes) + if err != nil { + return nil, nil, errors.Trace(err) + } + + certOut, err := os.Create(outCertFile) + if err != nil { + return nil, nil, errors.Trace(err) + } + err = pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}) + if err != nil { + return nil, nil, errors.Trace(err) + } + err = certOut.Close() + if err != nil { + return nil, nil, errors.Trace(err) + } + + keyOut, err := os.OpenFile(outKeyFile, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600) + if err != nil { + return nil, nil, errors.Trace(err) + } + err = pem.Encode(keyOut, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(privateKey)}) + if err != nil { + return nil, nil, errors.Trace(err) + } + err = keyOut.Close() + if err != nil { + return nil, nil, errors.Trace(err) + } + + return cert, privateKey, nil +} + +// registerTLSConfig registers a mysql client TLS config. +// See https://godoc.org/github.com/go-sql-driver/mysql#RegisterTLSConfig for details. +func registerTLSConfig(configName string, caCertPath string, clientCertPath string, clientKeyPath string, serverName string, verifyServer bool) error { + rootCertPool := x509.NewCertPool() + data, err := os.ReadFile(caCertPath) + if err != nil { + return err + } + if ok := rootCertPool.AppendCertsFromPEM(data); !ok { + return errors.New("Failed to append PEM") + } + clientCert := make([]tls.Certificate, 0, 1) + certs, err := tls.LoadX509KeyPair(clientCertPath, clientKeyPath) + if err != nil { + return err + } + clientCert = append(clientCert, certs) + tlsConfig := &tls.Config{ + RootCAs: rootCertPool, + Certificates: clientCert, + ServerName: serverName, + InsecureSkipVerify: !verifyServer, + } + return mysql.RegisterTLSConfig(configName, tlsConfig) +} + +func TestSystemTimeZone(t *testing.T) { + ts := createTidbTestSuite(t) + + tk := testkit.NewTestKit(t, ts.store) + cfg := util2.NewTestConfig() + cfg.Port, cfg.Status.StatusPort = 0, 0 + cfg.Status.ReportStatus = false + server, err := server2.NewServer(cfg, ts.tidbdrv) + require.NoError(t, err) + defer server.Close() + + tz1 := tk.MustQuery("select variable_value from mysql.tidb where variable_name = 'system_tz'").Rows() + tk.MustQuery("select @@system_time_zone").Check(tz1) +} + +func TestInternalSessionTxnStartTS(t *testing.T) { + ts := createTidbTestSuite(t) + + se, err := session.CreateSession4Test(ts.store) + require.NoError(t, err) + + _, err = se.Execute(context.Background(), "set global tidb_enable_metadata_lock=0") + require.NoError(t, err) + + count := 10 + stmts := make([]ast.StmtNode, count) + for i := 0; i < count; i++ { + stmt, err := session.ParseWithParams4Test(context.Background(), se, "select * from mysql.user limit 1") + require.NoError(t, err) + stmts[i] = stmt + } + // Test an issue that sysSessionPool doesn't call session's Close, cause + // asyncGetTSWorker goroutine leak. + var wg util.WaitGroupWrapper + for i := 0; i < count; i++ { + s := stmts[i] + wg.Run(func() { + _, _, err := session.ExecRestrictedStmt4Test(context.Background(), se, s) + require.NoError(t, err) + }) + } + + wg.Wait() +} + +func TestClientWithCollation(t *testing.T) { + ts := createTidbTestSuite(t) + + ts.RunTestClientWithCollation(t) +} + +func TestCreateTableFlen(t *testing.T) { + ts := createTidbTestSuite(t) + + // issue #4540 + qctx, err := ts.tidbdrv.OpenCtx(uint64(0), 0, uint8(tmysql.DefaultCollationID), "test", nil, nil) + require.NoError(t, err) + _, err = Execute(context.Background(), qctx, "use test;") + require.NoError(t, err) + + ctx := context.Background() + testSQL := "CREATE TABLE `t1` (" + + "`a` char(36) NOT NULL," + + "`b` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP," + + "`c` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP," + + "`d` varchar(50) DEFAULT ''," + + "`e` char(36) NOT NULL DEFAULT ''," + + "`f` char(36) NOT NULL DEFAULT ''," + + "`g` char(1) NOT NULL DEFAULT 'N'," + + "`h` varchar(100) NOT NULL," + + "`i` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP," + + "`j` varchar(10) DEFAULT ''," + + "`k` varchar(10) DEFAULT ''," + + "`l` varchar(20) DEFAULT ''," + + "`m` varchar(20) DEFAULT ''," + + "`n` varchar(30) DEFAULT ''," + + "`o` varchar(100) DEFAULT ''," + + "`p` varchar(50) DEFAULT ''," + + "`q` varchar(50) DEFAULT ''," + + "`r` varchar(100) DEFAULT ''," + + "`s` varchar(20) DEFAULT ''," + + "`t` varchar(50) DEFAULT ''," + + "`u` varchar(100) DEFAULT ''," + + "`v` varchar(50) DEFAULT ''," + + "`w` varchar(300) NOT NULL," + + "`x` varchar(250) DEFAULT ''," + + "`y` decimal(20)," + + "`z` decimal(20, 4)," + + "PRIMARY KEY (`a`)" + + ") ENGINE=InnoDB DEFAULT CHARSET=utf8 COLLATE=utf8_bin" + _, err = Execute(ctx, qctx, testSQL) + require.NoError(t, err) + rs, err := Execute(ctx, qctx, "show create table t1") + require.NoError(t, err) + req := rs.NewChunk(nil) + err = rs.Next(ctx, req) + require.NoError(t, err) + cols := rs.Columns() + require.NoError(t, err) + require.Len(t, cols, 2) + require.Equal(t, 5*tmysql.MaxBytesOfCharacter, int(cols[0].ColumnLength)) + require.Equal(t, len(req.GetRow(0).GetString(1))*tmysql.MaxBytesOfCharacter, int(cols[1].ColumnLength)) + + // for issue#5246 + rs, err = Execute(ctx, qctx, "select y, z from t1") + require.NoError(t, err) + cols = rs.Columns() + require.Len(t, cols, 2) + require.Equal(t, 21, int(cols[0].ColumnLength)) + require.Equal(t, 22, int(cols[1].ColumnLength)) + rs.Close() +} + +func Execute(ctx context.Context, qc *server2.TiDBContext, sql string) (resultset.ResultSet, error) { + stmts, err := qc.Parse(ctx, sql) + if err != nil { + return nil, err + } + if len(stmts) != 1 { + panic("wrong input for Execute: " + sql) + } + return qc.ExecuteStmt(ctx, stmts[0]) +} + +func TestShowTablesFlen(t *testing.T) { + ts := createTidbTestSuite(t) + + qctx, err := ts.tidbdrv.OpenCtx(uint64(0), 0, uint8(tmysql.DefaultCollationID), "test", nil, nil) + require.NoError(t, err) + ctx := context.Background() + _, err = Execute(ctx, qctx, "use test;") + require.NoError(t, err) + + testSQL := "create table abcdefghijklmnopqrstuvwxyz (i int)" + _, err = Execute(ctx, qctx, testSQL) + require.NoError(t, err) + rs, err := Execute(ctx, qctx, "show tables") + require.NoError(t, err) + req := rs.NewChunk(nil) + err = rs.Next(ctx, req) + require.NoError(t, err) + cols := rs.Columns() + require.NoError(t, err) + require.Len(t, cols, 1) + require.Equal(t, 26*tmysql.MaxBytesOfCharacter, int(cols[0].ColumnLength)) +} + +func checkColNames(t *testing.T, columns []*column.Info, names ...string) { + for i, name := range names { + require.Equal(t, name, columns[i].Name) + require.Equal(t, name, columns[i].OrgName) + } +} + +func TestFieldList(t *testing.T) { + ts := createTidbTestSuite(t) + + qctx, err := ts.tidbdrv.OpenCtx(uint64(0), 0, uint8(tmysql.DefaultCollationID), "test", nil, nil) + require.NoError(t, err) + _, err = Execute(context.Background(), qctx, "use test;") + require.NoError(t, err) + + ctx := context.Background() + testSQL := `create table t ( + c_bit bit(10), + c_int_d int, + c_bigint_d bigint, + c_float_d float, + c_double_d double, + c_decimal decimal(6, 3), + c_datetime datetime(2), + c_time time(3), + c_date date, + c_timestamp timestamp(4) DEFAULT CURRENT_TIMESTAMP(4), + c_char char(20), + c_varchar varchar(20), + c_text_d text, + c_binary binary(20), + c_blob_d blob, + c_set set('a', 'b', 'c'), + c_enum enum('a', 'b', 'c'), + c_json JSON, + c_year year + )` + _, err = Execute(ctx, qctx, testSQL) + require.NoError(t, err) + colInfos, err := qctx.FieldList("t") + require.NoError(t, err) + require.Len(t, colInfos, 19) + + checkColNames(t, colInfos, "c_bit", "c_int_d", "c_bigint_d", "c_float_d", + "c_double_d", "c_decimal", "c_datetime", "c_time", "c_date", "c_timestamp", + "c_char", "c_varchar", "c_text_d", "c_binary", "c_blob_d", "c_set", "c_enum", + "c_json", "c_year") + + for _, cols := range colInfos { + require.Equal(t, "test", cols.Schema) + } + + for _, cols := range colInfos { + require.Equal(t, "t", cols.Table) + } + + for i, col := range colInfos { + switch i { + case 10, 11, 12, 15, 16: + // c_char char(20), c_varchar varchar(20), c_text_d text, + // c_set set('a', 'b', 'c'), c_enum enum('a', 'b', 'c') + require.Equalf(t, uint16(tmysql.CharsetNameToID(tmysql.DefaultCharset)), col.Charset, "index %d", i) + continue + } + + require.Equalf(t, uint16(tmysql.CharsetNameToID("binary")), col.Charset, "index %d", i) + } + + // c_decimal decimal(6, 3) + require.Equal(t, uint8(3), colInfos[5].Decimal) + + // for issue#10513 + tooLongColumnAsName := "COALESCE(0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0)" + columnAsName := tooLongColumnAsName[:tmysql.MaxAliasIdentifierLen] + + rs, err := Execute(ctx, qctx, "select "+tooLongColumnAsName) + require.NoError(t, err) + cols := rs.Columns() + require.Equal(t, "", cols[0].OrgName) + require.Equal(t, columnAsName, cols[0].Name) + rs.Close() + + rs, err = Execute(ctx, qctx, "select c_bit as '"+tooLongColumnAsName+"' from t") + require.NoError(t, err) + cols = rs.Columns() + require.Equal(t, "c_bit", cols[0].OrgName) + require.Equal(t, columnAsName, cols[0].Name) + rs.Close() +} + +func TestClientErrors(t *testing.T) { + ts := createTidbTestSuite(t) + ts.RunTestInfoschemaClientErrors(t) +} + +func TestInitConnect(t *testing.T) { + ts := createTidbTestSuite(t) + ts.RunTestInitConnect(t) +} + +func TestSumAvg(t *testing.T) { + ts := createTidbTestSuite(t) + ts.RunTestSumAvg(t) +} + +func TestStmtCountLimit(t *testing.T) { + ts := createTidbTestSuite(t) + ts.RunTestStmtCountLimit(t) +} + +func TestNullFlag(t *testing.T) { + ts := createTidbTestSuite(t) + + qctx, err := ts.tidbdrv.OpenCtx(uint64(0), 0, uint8(tmysql.DefaultCollationID), "test", nil, nil) + require.NoError(t, err) + + ctx := context.Background() + { + // issue #9689 + rs, err := Execute(ctx, qctx, "select 1") + require.NoError(t, err) + cols := rs.Columns() + require.Len(t, cols, 1) + expectFlag := uint16(tmysql.NotNullFlag | tmysql.BinaryFlag) + require.Equal(t, expectFlag, column.DumpFlag(cols[0].Type, cols[0].Flag)) + rs.Close() + } + + { + // issue #19025 + rs, err := Execute(ctx, qctx, "select convert('{}', JSON)") + require.NoError(t, err) + cols := rs.Columns() + require.Len(t, cols, 1) + expectFlag := uint16(tmysql.BinaryFlag) + require.Equal(t, expectFlag, column.DumpFlag(cols[0].Type, cols[0].Flag)) + rs.Close() + } + + { + // issue #18488 + _, err := Execute(ctx, qctx, "use test") + require.NoError(t, err) + _, err = Execute(ctx, qctx, "CREATE TABLE `test` (`iD` bigint(20) NOT NULL, `INT_TEST` int(11) DEFAULT NULL);") + require.NoError(t, err) + rs, err := Execute(ctx, qctx, `SELECT id + int_test as res FROM test GROUP BY res ORDER BY res;`) + require.NoError(t, err) + cols := rs.Columns() + require.Len(t, cols, 1) + expectFlag := uint16(tmysql.BinaryFlag) + require.Equal(t, expectFlag, column.DumpFlag(cols[0].Type, cols[0].Flag)) + rs.Close() + } + + { + rs, err := Execute(ctx, qctx, "select if(1, null, 1) ;") + require.NoError(t, err) + cols := rs.Columns() + require.Len(t, cols, 1) + expectFlag := uint16(tmysql.BinaryFlag) + require.Equal(t, expectFlag, column.DumpFlag(cols[0].Type, cols[0].Flag)) + rs.Close() + } + { + rs, err := Execute(ctx, qctx, "select CASE 1 WHEN 2 THEN 1 END ;") + require.NoError(t, err) + cols := rs.Columns() + require.Len(t, cols, 1) + expectFlag := uint16(tmysql.BinaryFlag) + require.Equal(t, expectFlag, column.DumpFlag(cols[0].Type, cols[0].Flag)) + rs.Close() + } + { + rs, err := Execute(ctx, qctx, "select NULL;") + require.NoError(t, err) + cols := rs.Columns() + require.Len(t, cols, 1) + expectFlag := uint16(tmysql.BinaryFlag) + require.Equal(t, expectFlag, column.DumpFlag(cols[0].Type, cols[0].Flag)) + rs.Close() + } +} + +func TestNO_DEFAULT_VALUEFlag(t *testing.T) { + ts := createTidbTestSuite(t) + + // issue #21465 + qctx, err := ts.tidbdrv.OpenCtx(uint64(0), 0, uint8(tmysql.DefaultCollationID), "test", nil, nil) + require.NoError(t, err) + + ctx := context.Background() + _, err = Execute(ctx, qctx, "use test") + require.NoError(t, err) + _, err = Execute(ctx, qctx, "drop table if exists t") + require.NoError(t, err) + _, err = Execute(ctx, qctx, "create table t(c1 int key, c2 int);") + require.NoError(t, err) + rs, err := Execute(ctx, qctx, "select c1 from t;") + require.NoError(t, err) + defer rs.Close() + cols := rs.Columns() + require.Len(t, cols, 1) + expectFlag := uint16(tmysql.NotNullFlag | tmysql.PriKeyFlag | tmysql.NoDefaultValueFlag) + require.Equal(t, expectFlag, column.DumpFlag(cols[0].Type, cols[0].Flag)) +} + +func TestGracefulShutdown(t *testing.T) { + ts := createTidbTestSuite(t) + + cli := testserverclient.NewTestServerClient() + cfg := util2.NewTestConfig() + cfg.GracefulWaitBeforeShutdown = 2 // wait before shutdown + cfg.Port = 0 + cfg.Status.StatusPort = 0 + cfg.Status.ReportStatus = true + cfg.Performance.TCPKeepAlive = true + server, err := server2.NewServer(cfg, ts.tidbdrv) + require.NoError(t, err) + require.NotNil(t, server) + cli.Port = testutil.GetPortFromTCPAddr(server.ListenAddr()) + cli.StatusPort = testutil.GetPortFromTCPAddr(server.StatusListenerAddr()) + go func() { + err := server.Run() + require.NoError(t, err) + }() + time.Sleep(time.Millisecond * 100) + + resp, err := cli.FetchStatus("/status") // server is up + require.NoError(t, err) + require.Nil(t, resp.Body.Close()) + + go server.Close() + time.Sleep(time.Millisecond * 500) + + resp, _ = cli.FetchStatus("/status") // should return 5xx code + require.Equal(t, 500, resp.StatusCode) + require.Nil(t, resp.Body.Close()) + + time.Sleep(time.Second * 2) + + //nolint:bodyclose + _, err = cli.FetchStatus("/status") // Status is gone + require.Error(t, err) + require.Regexp(t, "connect: connection refused$", err.Error()) +} + +func TestPessimisticInsertSelectForUpdate(t *testing.T) { + ts := createTidbTestSuite(t) + + qctx, err := ts.tidbdrv.OpenCtx(uint64(0), 0, uint8(tmysql.DefaultCollationID), "test", nil, nil) + require.NoError(t, err) + defer qctx.Close() + ctx := context.Background() + _, err = Execute(ctx, qctx, "use test;") + require.NoError(t, err) + _, err = Execute(ctx, qctx, "drop table if exists t1, t2") + require.NoError(t, err) + _, err = Execute(ctx, qctx, "create table t1 (id int)") + require.NoError(t, err) + _, err = Execute(ctx, qctx, "create table t2 (id int)") + require.NoError(t, err) + _, err = Execute(ctx, qctx, "insert into t1 select 1") + require.NoError(t, err) + _, err = Execute(ctx, qctx, "begin pessimistic") + require.NoError(t, err) + rs, err := Execute(ctx, qctx, "INSERT INTO t2 (id) select id from t1 where id = 1 for update") + require.NoError(t, err) + require.Nil(t, rs) // should be no delay +} + +func TestTopSQLCatchRunningSQL(t *testing.T) { + ts := createTidbTestTopSQLSuite(t) + + db, err := sql.Open("mysql", ts.GetDSN()) + require.NoError(t, err) + defer func() { + require.NoError(t, db.Close()) + }() + + dbt := testkit.NewDBTestKit(t, db) + dbt.MustExec("drop database if exists topsql") + dbt.MustExec("create database topsql") + dbt.MustExec("use topsql;") + dbt.MustExec("create table t (a int, b int);") + + for i := 0; i < 5000; i++ { + dbt.MustExec(fmt.Sprintf("insert into t values (%v, %v)", i, i)) + } + + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/pkg/util/topsql/mockHighLoadForEachPlan", `return(true)`)) + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/pkg/domain/skipLoadSysVarCacheLoop", `return(true)`)) + defer func() { + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/pkg/util/topsql/mockHighLoadForEachPlan")) + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/pkg/domain/skipLoadSysVarCacheLoop")) + }() + + mc := mockTopSQLTraceCPU.NewTopSQLCollector() + topsql.SetupTopSQLForTest(mc) + sqlCPUCollector := collector.NewSQLCPUCollector(mc) + sqlCPUCollector.Start() + defer sqlCPUCollector.Stop() + + query := "select count(*) from t as t0 join t as t1 on t0.a != t1.a;" + needEnableTopSQL := int64(0) + ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) + defer cancel() + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + for { + select { + case <-ctx.Done(): + return + default: + } + if atomic.LoadInt64(&needEnableTopSQL) == 1 { + time.Sleep(2 * time.Millisecond) + topsqlstate.EnableTopSQL() + atomic.StoreInt64(&needEnableTopSQL, 0) + } + time.Sleep(time.Millisecond) + } + }() + execFn := func(db *sql.DB) { + dbt := testkit.NewDBTestKit(t, db) + atomic.StoreInt64(&needEnableTopSQL, 1) + mustQuery(t, dbt, query) + topsqlstate.DisableTopSQL() + } + check := func() { + require.NoError(t, ctx.Err()) + stats := mc.GetSQLStatsBySQLWithRetry(query, true) + require.Greaterf(t, len(stats), 0, query) + } + ts.testCase(t, mc, execFn, check) + cancel() + wg.Wait() +} + +func TestTopSQLCPUProfile(t *testing.T) { + ts := createTidbTestTopSQLSuite(t) + + db, err := sql.Open("mysql", ts.GetDSN()) + require.NoError(t, err) + defer func() { + require.NoError(t, db.Close()) + }() + + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/pkg/util/topsql/mockHighLoadForEachSQL", `return(true)`)) + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/pkg/util/topsql/mockHighLoadForEachPlan", `return(true)`)) + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/pkg/domain/skipLoadSysVarCacheLoop", `return(true)`)) + defer func() { + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/pkg/util/topsql/mockHighLoadForEachSQL")) + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/pkg/util/topsql/mockHighLoadForEachPlan")) + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/pkg/domain/skipLoadSysVarCacheLoop")) + }() + + topsqlstate.EnableTopSQL() + defer topsqlstate.DisableTopSQL() + + mc := mockTopSQLTraceCPU.NewTopSQLCollector() + topsql.SetupTopSQLForTest(mc) + sqlCPUCollector := collector.NewSQLCPUCollector(mc) + sqlCPUCollector.Start() + defer sqlCPUCollector.Stop() + + dbt := testkit.NewDBTestKit(t, db) + dbt.MustExec("drop database if exists topsql") + dbt.MustExec("create database topsql") + dbt.MustExec("use topsql;") + dbt.MustExec("create table t (a int auto_increment, b int, unique index idx(a));") + dbt.MustExec("create table t1 (a int auto_increment, b int, unique index idx(a));") + dbt.MustExec("create table t2 (a int auto_increment, b int, unique index idx(a));") + dbt.MustExec("set @@global.tidb_txn_mode = 'pessimistic'") + + checkFn := func(sql, planRegexp string) { + stats := mc.GetSQLStatsBySQLWithRetry(sql, len(planRegexp) > 0) + // since 1 sql may has many plan, check `len(stats) > 0` instead of `len(stats) == 1`. + require.Greaterf(t, len(stats), 0, "sql: "+sql) + + for _, s := range stats { + sqlStr := mc.GetSQL(s.SQLDigest) + encodedPlan := mc.GetPlan(s.PlanDigest) + // Normalize the user SQL before check. + normalizedSQL := parser.Normalize(sql) + require.Equalf(t, normalizedSQL, sqlStr, "sql: %v", sql) + // decode plan before check. + normalizedPlan, err := plancodec.DecodeNormalizedPlan(encodedPlan) + require.NoError(t, err) + // remove '\n' '\t' before do regexp match. + normalizedPlan = strings.Replace(normalizedPlan, "\n", " ", -1) + normalizedPlan = strings.Replace(normalizedPlan, "\t", " ", -1) + require.Regexpf(t, planRegexp, normalizedPlan, "sql: %v", sql) + } + } + + // Test case 1: DML query: insert/update/replace/delete/select + cases1 := []struct { + sql string + planRegexp string + }{ + {sql: "insert into t () values (),(),(),(),(),(),();", planRegexp: ""}, + {sql: "insert into t (b) values (1),(1),(1),(1),(1),(1),(1),(1);", planRegexp: ""}, + {sql: "update t set b=a where b is null limit 1;", planRegexp: ".*Limit.*TableReader.*"}, + {sql: "delete from t where b = a limit 2;", planRegexp: ".*Limit.*TableReader.*"}, + {sql: "replace into t (b) values (1),(1),(1),(1),(1),(1),(1),(1);", planRegexp: ""}, + {sql: "select * from t use index(idx) where a<10;", planRegexp: ".*IndexLookUp.*"}, + {sql: "select * from t ignore index(idx) where a>1000000000;", planRegexp: ".*TableReader.*"}, + {sql: "select /*+ HASH_JOIN(t1, t2) */ * from t t1 join t t2 on t1.a=t2.a where t1.b is not null;", planRegexp: ".*HashJoin.*"}, + {sql: "select /*+ INL_HASH_JOIN(t1, t2) */ * from t t1 join t t2 on t2.a=t1.a where t1.b is not null;", planRegexp: ".*IndexHashJoin.*"}, + {sql: "select * from t where a=1;", planRegexp: ".*Point_Get.*"}, + {sql: "select * from t where a in (1,2,3,4)", planRegexp: ".*Batch_Point_Get.*"}, + } + execFn := func(db *sql.DB) { + dbt := testkit.NewDBTestKit(t, db) + for _, ca := range cases1 { + sqlStr := ca.sql + if strings.HasPrefix(sqlStr, "select") { + mustQuery(t, dbt, sqlStr) + } else { + dbt.MustExec(sqlStr) + } + } + } + check := func() { + for _, ca := range cases1 { + checkFn(ca.sql, ca.planRegexp) + } + } + ts.testCase(t, mc, execFn, check) + + // Test case 2: prepare/execute sql + cases2 := []struct { + prepare string + args []interface{} + planRegexp string + }{ + {prepare: "insert into t1 (b) values (?);", args: []interface{}{1}, planRegexp: ""}, + {prepare: "replace into t1 (b) values (?);", args: []interface{}{1}, planRegexp: ""}, + {prepare: "update t1 set b=a where b is null limit ?;", args: []interface{}{1}, planRegexp: ".*Limit.*TableReader.*"}, + {prepare: "delete from t1 where b = a limit ?;", args: []interface{}{1}, planRegexp: ".*Limit.*TableReader.*"}, + {prepare: "replace into t1 (b) values (?);", args: []interface{}{1}, planRegexp: ""}, + {prepare: "select * from t1 use index(idx) where a?;", args: []interface{}{1000000000}, planRegexp: ".*TableReader.*"}, + {prepare: "select /*+ HASH_JOIN(t1, t2) */ * from t1 t1 join t1 t2 on t1.a=t2.a where t1.b is not null;", args: nil, planRegexp: ".*HashJoin.*"}, + {prepare: "select /*+ INL_HASH_JOIN(t1, t2) */ * from t1 t1 join t1 t2 on t2.a=t1.a where t1.b is not null;", args: nil, planRegexp: ".*IndexHashJoin.*"}, + {prepare: "select * from t1 where a=?;", args: []interface{}{1}, planRegexp: ".*Point_Get.*"}, + {prepare: "select * from t1 where a in (?,?,?,?)", args: []interface{}{1, 2, 3, 4}, planRegexp: ".*Batch_Point_Get.*"}, + } + execFn = func(db *sql.DB) { + dbt := testkit.NewDBTestKit(t, db) + for _, ca := range cases2 { + prepare, args := ca.prepare, ca.args + stmt := dbt.MustPrepare(prepare) + if strings.HasPrefix(prepare, "select") { + rows, err := stmt.Query(args...) + require.NoError(t, err) + for rows.Next() { + } + require.NoError(t, rows.Close()) + } else { + _, err = stmt.Exec(args...) + require.NoError(t, err) + } + } + } + check = func() { + for _, ca := range cases2 { + checkFn(ca.prepare, ca.planRegexp) + } + } + ts.testCase(t, mc, execFn, check) + + // Test case 3: prepare, execute stmt using @val... + cases3 := []struct { + prepare string + args []interface{} + planRegexp string + }{ + {prepare: "insert into t2 (b) values (?);", args: []interface{}{1}, planRegexp: ""}, + {prepare: "update t2 set b=a where b is null limit ?;", args: []interface{}{1}, planRegexp: ".*Limit.*TableReader.*"}, + {prepare: "delete from t2 where b = a limit ?;", args: []interface{}{1}, planRegexp: ".*Limit.*TableReader.*"}, + {prepare: "replace into t2 (b) values (?);", args: []interface{}{1}, planRegexp: ""}, + {prepare: "select * from t2 use index(idx) where a?;", args: []interface{}{1000000000}, planRegexp: ".*TableReader.*"}, + {prepare: "select /*+ HASH_JOIN(t1, t2) */ * from t2 t1 join t2 t2 on t1.a=t2.a where t1.b is not null;", args: nil, planRegexp: ".*HashJoin.*"}, + {prepare: "select /*+ INL_HASH_JOIN(t1, t2) */ * from t2 t1 join t2 t2 on t2.a=t1.a where t1.b is not null;", args: nil, planRegexp: ".*IndexHashJoin.*"}, + {prepare: "select * from t2 where a=?;", args: []interface{}{1}, planRegexp: ".*Point_Get.*"}, + {prepare: "select * from t2 where a in (?,?,?,?)", args: []interface{}{1, 2, 3, 4}, planRegexp: ".*Batch_Point_Get.*"}, + } + execFn = func(db *sql.DB) { + dbt := testkit.NewDBTestKit(t, db) + for _, ca := range cases3 { + prepare, args := ca.prepare, ca.args + dbt.MustExec(fmt.Sprintf("prepare stmt from '%v'", prepare)) + + var params []string + for i := range args { + param := 'a' + i + dbt.MustExec(fmt.Sprintf("set @%c=%v", param, args[i])) + params = append(params, fmt.Sprintf("@%c", param)) + } + + sqlStr := "execute stmt" + if len(params) > 0 { + sqlStr += " using " + sqlStr += strings.Join(params, ",") + } + if strings.HasPrefix(prepare, "select") { + mustQuery(t, dbt, sqlStr) + } else { + dbt.MustExec(sqlStr) + } + } + } + check = func() { + for _, ca := range cases3 { + checkFn(ca.prepare, ca.planRegexp) + } + } + ts.testCase(t, mc, execFn, check) + + // Test case for other statements + cases4 := []struct { + sql string + plan string + isQuery bool + }{ + {"begin", "", false}, + {"insert into t () values (),(),(),(),(),(),(),(),(),(),(),(),(),(),(),(),(),(),(),(),(),()", "", false}, + {"commit", "", false}, + {"analyze table t", "", false}, + {"explain analyze select sum(a+b) from t", ".*TableReader.*", true}, + {"trace select sum(b*a), sum(a+b) from t", "", true}, + {"set global tidb_stmt_summary_history_size=5;", "", false}, + } + execFn = func(db *sql.DB) { + dbt := testkit.NewDBTestKit(t, db) + for _, ca := range cases4 { + if ca.isQuery { + mustQuery(t, dbt, ca.sql) + } else { + dbt.MustExec(ca.sql) + } + } + } + check = func() { + for _, ca := range cases4 { + checkFn(ca.sql, ca.plan) + } + // check for internal SQL. + checkFn("replace into mysql.global_variables (variable_name,variable_value) values ('tidb_stmt_summary_history_size', '5')", "") + } + ts.testCase(t, mc, execFn, check) + + // Test case for multi-statement. + cases5 := []string{ + "delete from t limit 1;", + "update t set b=1 where b is null limit 1;", + "select sum(a+b*2) from t;", + } + multiStatement5 := strings.Join(cases5, "") + execFn = func(db *sql.DB) { + dbt := testkit.NewDBTestKit(t, db) + dbt.MustExec("SET tidb_multi_statement_mode='ON'") + dbt.MustExec(multiStatement5) + } + check = func() { + for _, sqlStr := range cases5 { + checkFn(sqlStr, ".*TableReader.*") + } + } + ts.testCase(t, mc, execFn, check) + + // Test case for multi-statement, but first statements execute failed + cases6 := []string{ + "delete from t_not_exist;", + "update t set a=1 where a is null limit 1;", + } + multiStatement6 := strings.Join(cases6, "") + execFn = func(db *sql.DB) { + dbt := testkit.NewDBTestKit(t, db) + dbt.MustExec("SET tidb_multi_statement_mode='ON'") + _, err := db.Exec(multiStatement6) + require.NotNil(t, err) + require.Equal(t, "Error 1146: Table 'topsql.t_not_exist' doesn't exist", err.Error()) + } + check = func() { + for i := 1; i < len(cases6); i++ { + sqlStr := cases6[i] + stats := mc.GetSQLStatsBySQL(sqlStr, false) + require.Equal(t, 0, len(stats), sqlStr) + } + } + ts.testCase(t, mc, execFn, check) + + // Test case for multi-statement, the first statements execute success but the second statement execute failed. + cases7 := []string{ + "update t set a=1 where a <0 limit 1;", + "delete from t_not_exist;", + } + multiStatement7 := strings.Join(cases7, "") + execFn = func(db *sql.DB) { + dbt := testkit.NewDBTestKit(t, db) + dbt.MustExec("SET tidb_multi_statement_mode='ON'") + _, err = db.Exec(multiStatement7) + require.NotNil(t, err) + require.Equal(t, "Error 1146 (42S02): Table 'topsql.t_not_exist' doesn't exist", err.Error()) + } + check = func() { + checkFn(cases7[0], "") // the first statement execute success, should have topsql data. + } + ts.testCase(t, mc, execFn, check) + + // Test case for statement with wrong syntax. + wrongSyntaxSQL := "select * froms t" + execFn = func(db *sql.DB) { + _, err = db.Exec(wrongSyntaxSQL) + require.NotNil(t, err) + require.Regexp(t, "Error 1064: You have an error in your SQL syntax...", err.Error()) + } + check = func() { + stats := mc.GetSQLStatsBySQL(wrongSyntaxSQL, false) + require.Equal(t, 0, len(stats), wrongSyntaxSQL) + } + ts.testCase(t, mc, execFn, check) + + // Test case for high cost of plan optimize. + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/pkg/planner/mockHighLoadForOptimize", "return")) + selectSQL := "select sum(a+b), count(distinct b) from t where a+b >0" + updateSQL := "update t set a=a+100 where a > 10000000" + selectInPlanSQL := "select * from t where exists (select 1 from t1 where t1.a = 1);" + execFn = func(db *sql.DB) { + dbt := testkit.NewDBTestKit(t, db) + mustQuery(t, dbt, selectSQL) + dbt.MustExec(updateSQL) + mustQuery(t, dbt, selectInPlanSQL) + } + check = func() { + checkFn(selectSQL, "") + checkFn(updateSQL, "") + selectCPUTime := mc.GetSQLCPUTimeBySQL(selectSQL) + updateCPUTime := mc.GetSQLCPUTimeBySQL(updateSQL) + require.Less(t, updateCPUTime, selectCPUTime) + checkFn(selectInPlanSQL, "") + } + ts.testCase(t, mc, execFn, check) + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/pkg/planner/mockHighLoadForOptimize")) + + // Test case for DDL execute failed but should still have CPU data. + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/pkg/ddl/mockHighLoadForAddIndex", "return")) + dbt.MustExec(fmt.Sprintf("insert into t values (%v,%v), (%v, %v);", 2000, 1, 2001, 1)) + addIndexStr := "alter table t add unique index idx_b (b)" + execFn = func(db *sql.DB) { + dbt := testkit.NewDBTestKit(t, db) + dbt.MustExec("alter table t drop index if exists idx_b") + _, err := db.Exec(addIndexStr) + require.NotNil(t, err) + require.Equal(t, "Error 1062 (23000): Duplicate entry '1' for key 't.idx_b'", err.Error()) + } + check = func() { + checkFn(addIndexStr, "") + } + ts.testCase(t, mc, execFn, check) + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/pkg/ddl/mockHighLoadForAddIndex")) + + // Test case for execute failed cause by storage error. + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/pkg/store/copr/handleTaskOnceError", `return(true)`)) + execFailedQuery := "select * from t where a*b < 1000" + execFn = func(db *sql.DB) { + _, err = db.Query(execFailedQuery) + require.NotNil(t, err) + require.Equal(t, "Error 1105 (HY000): mock handleTaskOnce error", err.Error()) + } + check = func() { + checkFn(execFailedQuery, "") + } + ts.testCase(t, mc, execFn, check) + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/pkg/store/copr/handleTaskOnceError")) +} + +func (ts *tidbTestTopSQLSuite) testCase(t *testing.T, mc *mockTopSQLTraceCPU.TopSQLCollector, execFn func(db *sql.DB), checkFn func()) { + var wg sync.WaitGroup + ctx, cancel := context.WithCancel(context.Background()) + wg.Add(1) + go func() { + defer wg.Done() + ts.loopExec(ctx, t, execFn) + }() + + checkFn() + cancel() + wg.Wait() + mc.Reset() +} + +func mustQuery(t *testing.T, dbt *testkit.DBTestKit, query string) { + rows := dbt.MustQuery(query) + for rows.Next() { + } + err := rows.Close() + require.NoError(t, err) +} + +type mockCollector struct { + f func(data stmtstats.StatementStatsMap) +} + +func newMockCollector(f func(data stmtstats.StatementStatsMap)) stmtstats.Collector { + return &mockCollector{f: f} +} + +func (c *mockCollector) CollectStmtStatsMap(data stmtstats.StatementStatsMap) { + c.f(data) +} + +func waitCollected(ch chan struct{}) { + select { + case <-ch: + case <-time.After(time.Second * 3): + } +} + +func TestTopSQLStatementStats(t *testing.T) { + ts, total, tagChecker, collectedNotifyCh := setupForTestTopSQLStatementStats(t) + + const ExecCountPerSQL = 2 + // Test for CRUD. + cases1 := []string{ + "insert into t values (%d, sleep(0.1))", + "update t set a = %[1]d + 1000 where a = %[1]d and sleep(0.1);", + "select a from t where b = %d and sleep(0.1);", + "select a from t where a = %d and sleep(0.1);", // test for point-get + "delete from t where a = %d and sleep(0.1);", + "insert into t values (%d, sleep(0.1)) on duplicate key update b = b+1", + } + var wg sync.WaitGroup + sqlDigests := map[stmtstats.BinaryDigest]string{} + for i, ca := range cases1 { + sqlStr := fmt.Sprintf(ca, i) + _, digest := parser.NormalizeDigest(sqlStr) + sqlDigests[stmtstats.BinaryDigest(digest.Bytes())] = sqlStr + } + wg.Add(1) + go func() { + defer wg.Done() + for _, ca := range cases1 { + db, err := sql.Open("mysql", ts.GetDSN()) + require.NoError(t, err) + dbt := testkit.NewDBTestKit(t, db) + dbt.MustExec("use stmtstats;") + for n := 0; n < ExecCountPerSQL; n++ { + sqlStr := fmt.Sprintf(ca, n) + if strings.HasPrefix(strings.ToLower(sqlStr), "select") { + mustQuery(t, dbt, sqlStr) + } else { + dbt.MustExec(sqlStr) + } + } + err = db.Close() + require.NoError(t, err) + } + }() + + // Test for prepare stmt/execute stmt + cases2 := []struct { + prepare string + execStmt string + setSQLsGen func(idx int) []string + execSQL string + }{ + { + prepare: "prepare stmt from 'insert into t2 values (?, sleep(?))';", + execStmt: "insert into t2 values (1, sleep(0.1))", + setSQLsGen: func(idx int) []string { + return []string{fmt.Sprintf("set @a=%v", idx), "set @b=0.1"} + }, + execSQL: "execute stmt using @a, @b;", + }, + { + prepare: "prepare stmt from 'update t2 set a = a + 1000 where a = ? and sleep(?);';", + execStmt: "update t2 set a = a + 1000 where a = 1 and sleep(0.1);", + setSQLsGen: func(idx int) []string { + return []string{fmt.Sprintf("set @a=%v", idx), "set @b=0.1"} + }, + execSQL: "execute stmt using @a, @b;", + }, + { + // test for point-get + prepare: "prepare stmt from 'select a, sleep(?) from t2 where a = ?';", + execStmt: "select a, sleep(?) from t2 where a = ?", + setSQLsGen: func(idx int) []string { + return []string{"set @a=0.1", fmt.Sprintf("set @b=%v", idx)} + }, + execSQL: "execute stmt using @a, @b;", + }, + { + prepare: "prepare stmt from 'select a, sleep(?) from t2 where b = ?';", + execStmt: "select a, sleep(?) from t2 where b = ?", + setSQLsGen: func(idx int) []string { + return []string{"set @a=0.1", fmt.Sprintf("set @b=%v", idx)} + }, + execSQL: "execute stmt using @a, @b;", + }, + { + prepare: "prepare stmt from 'delete from t2 where sleep(?) and a = ?';", + execStmt: "delete from t2 where sleep(0.1) and a = 1", + setSQLsGen: func(idx int) []string { + return []string{"set @a=0.1", fmt.Sprintf("set @b=%v", idx)} + }, + execSQL: "execute stmt using @a, @b;", + }, + { + prepare: "prepare stmt from 'insert into t2 values (?, sleep(?)) on duplicate key update b = b+1';", + execStmt: "insert into t2 values (1, sleep(0.1)) on duplicate key update b = b+1", + setSQLsGen: func(idx int) []string { + return []string{fmt.Sprintf("set @a=%v", idx), "set @b=0.1"} + }, + execSQL: "execute stmt using @a, @b;", + }, + { + prepare: "prepare stmt from 'set global tidb_enable_top_sql = (? = sleep(?))';", + execStmt: "set global tidb_enable_top_sql = (0 = sleep(0.1))", + setSQLsGen: func(idx int) []string { + return []string{"set @a=0", "set @b=0.1"} + }, + execSQL: "execute stmt using @a, @b;", + }, + } + for _, ca := range cases2 { + _, digest := parser.NormalizeDigest(ca.execStmt) + sqlDigests[stmtstats.BinaryDigest(digest.Bytes())] = ca.execStmt + } + wg.Add(1) + go func() { + defer wg.Done() + for _, ca := range cases2 { + db, err := sql.Open("mysql", ts.GetDSN()) + require.NoError(t, err) + dbt := testkit.NewDBTestKit(t, db) + dbt.MustExec("use stmtstats;") + // prepare stmt + dbt.MustExec(ca.prepare) + for n := 0; n < ExecCountPerSQL; n++ { + setSQLs := ca.setSQLsGen(n) + for _, setSQL := range setSQLs { + dbt.MustExec(setSQL) + } + if strings.HasPrefix(strings.ToLower(ca.execStmt), "select") { + mustQuery(t, dbt, ca.execSQL) + } else { + dbt.MustExec(ca.execSQL) + } + } + err = db.Close() + require.NoError(t, err) + } + }() + + // Test for prepare by db client prepare/exec interface. + cases3 := []struct { + prepare string + execStmt string + argsGen func(idx int) []interface{} + }{ + { + prepare: "insert into t3 values (?, sleep(?))", + argsGen: func(idx int) []interface{} { + return []interface{}{idx, 0.1} + }, + }, + { + prepare: "update t3 set a = a + 1000 where a = ? and sleep(?)", + argsGen: func(idx int) []interface{} { + return []interface{}{idx, 0.1} + }, + }, + { + // test for point-get + prepare: "select a, sleep(?) from t3 where a = ?", + argsGen: func(idx int) []interface{} { + return []interface{}{0.1, idx} + }, + }, + { + prepare: "select a, sleep(?) from t3 where b = ?", + argsGen: func(idx int) []interface{} { + return []interface{}{0.1, idx} + }, + }, + { + prepare: "delete from t3 where sleep(?) and a = ?", + argsGen: func(idx int) []interface{} { + return []interface{}{0.1, idx} + }, + }, + { + prepare: "insert into t3 values (?, sleep(?)) on duplicate key update b = b+1", + argsGen: func(idx int) []interface{} { + return []interface{}{idx, 0.1} + }, + }, + { + prepare: "set global tidb_enable_1pc = (? = sleep(?))", + argsGen: func(idx int) []interface{} { + return []interface{}{0, 0.1} + }, + }, + } + for _, ca := range cases3 { + _, digest := parser.NormalizeDigest(ca.prepare) + sqlDigests[stmtstats.BinaryDigest(digest.Bytes())] = ca.prepare + } + wg.Add(1) + go func() { + defer wg.Done() + for _, ca := range cases3 { + db, err := sql.Open("mysql", ts.GetDSN()) + require.NoError(t, err) + dbt := testkit.NewDBTestKit(t, db) + dbt.MustExec("use stmtstats;") + // prepare stmt + stmt, err := db.Prepare(ca.prepare) + require.NoError(t, err) + for n := 0; n < ExecCountPerSQL; n++ { + args := ca.argsGen(n) + if strings.HasPrefix(strings.ToLower(ca.prepare), "select") { + row, err := stmt.Query(args...) + require.NoError(t, err) + err = row.Close() + require.NoError(t, err) + } else { + _, err := stmt.Exec(args...) + require.NoError(t, err) + } + } + err = db.Close() + require.NoError(t, err) + } + }() + + wg.Wait() + // Wait for collect. + waitCollected(collectedNotifyCh) + + found := 0 + for digest, item := range total { + if sqlStr, ok := sqlDigests[digest.SQLDigest]; ok { + found++ + require.Equal(t, uint64(ExecCountPerSQL), item.ExecCount, sqlStr) + require.Equal(t, uint64(ExecCountPerSQL), item.DurationCount, sqlStr) + require.True(t, item.SumDurationNs > uint64(time.Millisecond*100*ExecCountPerSQL), sqlStr) + require.True(t, item.SumDurationNs < uint64(time.Millisecond*300*ExecCountPerSQL), sqlStr) + if strings.HasPrefix(sqlStr, "set global") { + // set global statement use internal SQL to change global variable, so itself doesn't have KV request. + continue + } + var kvSum uint64 + for _, kvCount := range item.KvStatsItem.KvExecCount { + kvSum += kvCount + } + require.Equal(t, uint64(ExecCountPerSQL), kvSum) + tagChecker.checkExist(t, digest.SQLDigest, sqlStr) + } + } + require.Equal(t, len(sqlDigests), found) + require.Equal(t, 20, found) +} + +type resourceTagChecker struct { + sync.Mutex + sqlDigest2Reqs map[stmtstats.BinaryDigest]map[tikvrpc.CmdType]struct{} +} + +func (c *resourceTagChecker) checkExist(t *testing.T, digest stmtstats.BinaryDigest, sqlStr string) { + if strings.HasPrefix(sqlStr, "set global") { + // `set global` statement will use another internal sql to execute, so `set global` statement won't + // send RPC request. + return + } + if strings.HasPrefix(sqlStr, "trace") { + // `trace` statement will use another internal sql to execute, so remove the `trace` prefix before check. + _, sqlDigest := parser.NormalizeDigest(strings.TrimPrefix(sqlStr, "trace")) + digest = stmtstats.BinaryDigest(sqlDigest.Bytes()) + } + + c.Lock() + defer c.Unlock() + _, ok := c.sqlDigest2Reqs[digest] + require.True(t, ok, sqlStr) +} + +func (c *resourceTagChecker) checkReqExist(t *testing.T, digest stmtstats.BinaryDigest, sqlStr string, reqs ...tikvrpc.CmdType) { + if len(reqs) == 0 { + return + } + c.Lock() + defer c.Unlock() + reqMap, ok := c.sqlDigest2Reqs[digest] + require.True(t, ok, sqlStr) + for _, req := range reqs { + _, ok := reqMap[req] + require.True(t, ok, fmt.Sprintf("sql: %v, expect: %v, got: %v", sqlStr, reqs, reqMap)) + } +} + +func setupForTestTopSQLStatementStats(t *testing.T) (*tidbTestSuite, stmtstats.StatementStatsMap, *resourceTagChecker, chan struct{}) { + // Prepare stmt stats. + stmtstats.SetupAggregator() + + // Register stmt stats collector. + var mu sync.Mutex + collectedNotifyCh := make(chan struct{}) + total := stmtstats.StatementStatsMap{} + mockCollector := newMockCollector(func(data stmtstats.StatementStatsMap) { + mu.Lock() + defer mu.Unlock() + total.Merge(data) + select { + case collectedNotifyCh <- struct{}{}: + default: + } + }) + stmtstats.RegisterCollector(mockCollector) + + ts := createTidbTestSuite(t) + + db, err := sql.Open("mysql", ts.GetDSN()) + require.NoError(t, err) + defer func() { + err := db.Close() + require.NoError(t, err) + }() + + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/pkg/domain/skipLoadSysVarCacheLoop", `return(true)`)) + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/pkg/store/mockstore/unistore/unistoreRPCClientSendHook", `return(true)`)) + + dbt := testkit.NewDBTestKit(t, db) + dbt.MustExec("drop database if exists stmtstats") + dbt.MustExec("create database stmtstats") + dbt.MustExec("use stmtstats;") + dbt.MustExec("create table t (a int, b int, unique index idx(a));") + dbt.MustExec("create table t2 (a int, b int, unique index idx(a));") + dbt.MustExec("create table t3 (a int, b int, unique index idx(a));") + + // Enable TopSQL + topsqlstate.EnableTopSQL() + config.UpdateGlobal(func(conf *config.Config) { + conf.TopSQL.ReceiverAddress = "mock-agent" + }) + + tagChecker := &resourceTagChecker{ + sqlDigest2Reqs: make(map[stmtstats.BinaryDigest]map[tikvrpc.CmdType]struct{}), + } + unistoreRPCClientSendHook := func(req *tikvrpc.Request) { + tag := req.GetResourceGroupTag() + if len(tag) == 0 || ddlutil.IsInternalResourceGroupTaggerForTopSQL(tag) { + // Ignore for internal background request. + return + } + sqlDigest, err := resourcegrouptag.DecodeResourceGroupTag(tag) + require.NoError(t, err) + tagChecker.Lock() + defer tagChecker.Unlock() + + reqMap, ok := tagChecker.sqlDigest2Reqs[stmtstats.BinaryDigest(sqlDigest)] + if !ok { + reqMap = make(map[tikvrpc.CmdType]struct{}) + } + reqMap[req.Type] = struct{}{} + tagChecker.sqlDigest2Reqs[stmtstats.BinaryDigest(sqlDigest)] = reqMap + } + unistore.UnistoreRPCClientSendHook.Store(&unistoreRPCClientSendHook) + + t.Cleanup(func() { + stmtstats.UnregisterCollector(mockCollector) + err = failpoint.Disable("github.com/pingcap/tidb/pkg/domain/skipLoadSysVarCacheLoop") + require.NoError(t, err) + err = failpoint.Disable("github.com/pingcap/tidb/pkg/store/mockstore/unistore/unistoreRPCClientSendHook") + require.NoError(t, err) + stmtstats.CloseAggregator() + view.Stop() + }) + + return ts, total, tagChecker, collectedNotifyCh +} + +func TestTopSQLStatementStats2(t *testing.T) { + ts, total, tagChecker, collectedNotifyCh := setupForTestTopSQLStatementStats(t) + + const ExecCountPerSQL = 3 + sqlDigests := map[stmtstats.BinaryDigest]string{} + + // Test case for other statements + cases4 := []struct { + sql string + plan string + isQuery bool + }{ + {"insert into t () values (),(),(),(),(),(),(),(),(),(),(),(),(),(),(),(),(),(),(),(),(),()", "", false}, + {"analyze table t", "", false}, + {"explain analyze select sum(a+b) from t", ".*TableReader.*", true}, + {"trace select sum(b*a), sum(a+b) from t", "", true}, + {"set global tidb_stmt_summary_history_size=5;", "", false}, + {"select * from stmtstats.t where exists (select 1 from stmtstats.t2 where t2.a = 1);", ".*TableReader.*", true}, + } + executeCaseFn := func(execFn func(db *sql.DB)) { + db, err := sql.Open("mysql", ts.GetDSN()) + require.NoError(t, err) + dbt := testkit.NewDBTestKit(t, db) + dbt.MustExec("use stmtstats;") + require.NoError(t, err) + + for n := 0; n < ExecCountPerSQL; n++ { + execFn(db) + } + err = db.Close() + require.NoError(t, err) + } + execFn := func(db *sql.DB) { + dbt := testkit.NewDBTestKit(t, db) + for _, ca := range cases4 { + if ca.isQuery { + mustQuery(t, dbt, ca.sql) + } else { + dbt.MustExec(ca.sql) + } + } + } + for _, ca := range cases4 { + _, digest := parser.NormalizeDigest(ca.sql) + sqlDigests[stmtstats.BinaryDigest(digest.Bytes())] = ca.sql + } + executeCaseFn(execFn) + + // Test case for multi-statement. + cases5 := []string{ + "delete from t limit 1;", + "update t set b=1 where b is null limit 1;", + "select sum(a+b*2) from t;", + } + multiStatement5 := strings.Join(cases5, "") + // Test case for multi-statement, but first statements execute failed + cases6 := []string{ + "delete from t6_not_exist;", + "update t set a=1 where a is null limit 1;", + } + multiStatement6 := strings.Join(cases6, "") + // Test case for multi-statement, the first statements execute success but the second statement execute failed. + cases7 := []string{ + "update t set a=1 where a <0 limit 1;", + "delete from t7_not_exist;", + } + // Test case for DDL. + cases8 := []string{ + "create table if not exists t10 (a int, b int)", + "alter table t drop index if exists idx_b", + "alter table t add index idx_b (b)", + } + multiStatement7 := strings.Join(cases7, "") + execFn = func(db *sql.DB) { + dbt := testkit.NewDBTestKit(t, db) + dbt.MustExec("SET tidb_multi_statement_mode='ON'") + dbt.MustExec(multiStatement5) + + _, err := db.Exec(multiStatement6) + require.NotNil(t, err) + require.Equal(t, "Error 1146 (42S02): Table 'stmtstats.t6_not_exist' doesn't exist", err.Error()) + + _, err = db.Exec(multiStatement7) + require.NotNil(t, err) + require.Equal(t, "Error 1146 (42S02): Table 'stmtstats.t7_not_exist' doesn't exist", err.Error()) + + for _, ca := range cases8 { + dbt.MustExec(ca) + } + } + executeCaseFn(execFn) + sqlStrs := append([]string{}, cases5...) + sqlStrs = append(sqlStrs, cases7[0]) + sqlStrs = append(sqlStrs, cases8...) + for _, sqlStr := range sqlStrs { + _, digest := parser.NormalizeDigest(sqlStr) + sqlDigests[stmtstats.BinaryDigest(digest.Bytes())] = sqlStr + } + + // Wait for collect. + waitCollected(collectedNotifyCh) + + foundMap := map[stmtstats.BinaryDigest]string{} + for digest, item := range total { + if sqlStr, ok := sqlDigests[digest.SQLDigest]; ok { + require.Equal(t, uint64(ExecCountPerSQL), item.ExecCount, sqlStr) + require.True(t, item.SumDurationNs > 1, sqlStr) + foundMap[digest.SQLDigest] = sqlStr + tagChecker.checkExist(t, digest.SQLDigest, sqlStr) + // The special check uses to test the issue #33202. + if strings.Contains(strings.ToLower(sqlStr), "add index") { + tagChecker.checkReqExist(t, digest.SQLDigest, sqlStr, tikvrpc.CmdScan) + } + } + } + require.Equal(t, len(sqlDigests), len(foundMap), fmt.Sprintf("%v !=\n %v", sqlDigests, foundMap)) +} + +func TestTopSQLStatementStats3(t *testing.T) { + ts, total, tagChecker, collectedNotifyCh := setupForTestTopSQLStatementStats(t) + + err := failpoint.Enable("github.com/pingcap/tidb/pkg/executor/mockSleepInTableReaderNext", "return(2000)") + require.NoError(t, err) + defer func() { + _ = failpoint.Disable("github.com/pingcap/tidb/pkg/executor/mockSleepInTableReaderNext") + }() + + cases := []string{ + "select count(a+b) from stmtstats.t", + "select * from stmtstats.t where b is null", + "update stmtstats.t set b = 1 limit 10", + "delete from stmtstats.t limit 1", + } + var wg sync.WaitGroup + sqlDigests := map[stmtstats.BinaryDigest]string{} + for _, ca := range cases { + wg.Add(1) + go func(sqlStr string) { + defer wg.Done() + db, err := sql.Open("mysql", ts.GetDSN()) + require.NoError(t, err) + dbt := testkit.NewDBTestKit(t, db) + require.NoError(t, err) + if strings.HasPrefix(sqlStr, "select") { + mustQuery(t, dbt, sqlStr) + } else { + dbt.MustExec(sqlStr) + } + err = db.Close() + require.NoError(t, err) + }(ca) + _, digest := parser.NormalizeDigest(ca) + sqlDigests[stmtstats.BinaryDigest(digest.Bytes())] = ca + } + // Wait for collect. + waitCollected(collectedNotifyCh) + + foundMap := map[stmtstats.BinaryDigest]string{} + for digest, item := range total { + if sqlStr, ok := sqlDigests[digest.SQLDigest]; ok { + // since the SQL doesn't execute finish, the ExecCount should be recorded, + // but the DurationCount and SumDurationNs should be 0. + require.Equal(t, uint64(1), item.ExecCount, sqlStr) + require.Equal(t, uint64(0), item.DurationCount, sqlStr) + require.Equal(t, uint64(0), item.SumDurationNs, sqlStr) + foundMap[digest.SQLDigest] = sqlStr + } + } + + // wait sql execute finish. + wg.Wait() + // Wait for collect. + waitCollected(collectedNotifyCh) + + for digest, item := range total { + if sqlStr, ok := sqlDigests[digest.SQLDigest]; ok { + require.Equal(t, uint64(1), item.ExecCount, sqlStr) + require.Equal(t, uint64(1), item.DurationCount, sqlStr) + require.Less(t, uint64(0), item.SumDurationNs, sqlStr) + foundMap[digest.SQLDigest] = sqlStr + tagChecker.checkExist(t, digest.SQLDigest, sqlStr) + } + } +} + +func TestTopSQLStatementStats4(t *testing.T) { + ts, total, tagChecker, collectedNotifyCh := setupForTestTopSQLStatementStats(t) + + err := failpoint.Enable("github.com/pingcap/tidb/pkg/executor/mockSleepInTableReaderNext", "return(2000)") + require.NoError(t, err) + defer func() { + _ = failpoint.Disable("github.com/pingcap/tidb/pkg/executor/mockSleepInTableReaderNext") + }() + + cases := []struct { + prepare string + sql string + args []interface{} + }{ + {prepare: "select count(a+b) from stmtstats.t", sql: "select count(a+b) from stmtstats.t"}, + {prepare: "select * from stmtstats.t where b is null", sql: "select * from stmtstats.t where b is null"}, + {prepare: "update stmtstats.t set b = ? limit ?", sql: "update stmtstats.t set b = 1 limit 10", args: []interface{}{1, 10}}, + {prepare: "delete from stmtstats.t limit ?", sql: "delete from stmtstats.t limit 1", args: []interface{}{1}}, + } + var wg sync.WaitGroup + sqlDigests := map[stmtstats.BinaryDigest]string{} + for _, ca := range cases { + wg.Add(1) + go func(prepare string, args []interface{}) { + defer wg.Done() + db, err := sql.Open("mysql", ts.GetDSN()) + require.NoError(t, err) + stmt, err := db.Prepare(prepare) + require.NoError(t, err) + if strings.HasPrefix(prepare, "select") { + rows, err := stmt.Query(args...) + require.NoError(t, err) + for rows.Next() { + } + err = rows.Close() + require.NoError(t, err) + } else { + _, err := stmt.Exec(args...) + require.NoError(t, err) + } + err = db.Close() + require.NoError(t, err) + }(ca.prepare, ca.args) + _, digest := parser.NormalizeDigest(ca.sql) + sqlDigests[stmtstats.BinaryDigest(digest.Bytes())] = ca.sql + } + // Wait for collect. + waitCollected(collectedNotifyCh) + + foundMap := map[stmtstats.BinaryDigest]string{} + for digest, item := range total { + if sqlStr, ok := sqlDigests[digest.SQLDigest]; ok { + // since the SQL doesn't execute finish, the ExecCount should be recorded, + // but the DurationCount and SumDurationNs should be 0. + require.Equal(t, uint64(1), item.ExecCount, sqlStr) + require.Equal(t, uint64(0), item.DurationCount, sqlStr) + require.Equal(t, uint64(0), item.SumDurationNs, sqlStr) + foundMap[digest.SQLDigest] = sqlStr + } + } + + // wait sql execute finish. + wg.Wait() + // Wait for collect. + waitCollected(collectedNotifyCh) + + for digest, item := range total { + if sqlStr, ok := sqlDigests[digest.SQLDigest]; ok { + require.Equal(t, uint64(1), item.ExecCount, sqlStr) + require.Equal(t, uint64(1), item.DurationCount, sqlStr) + require.Less(t, uint64(0), item.SumDurationNs, sqlStr) + foundMap[digest.SQLDigest] = sqlStr + tagChecker.checkExist(t, digest.SQLDigest, sqlStr) + } + } +} + +func TestTopSQLResourceTag(t *testing.T) { + ts, _, tagChecker, _ := setupForTestTopSQLStatementStats(t) + defer func() { + topsqlstate.DisableTopSQL() + }() + + loadDataFile, err := os.CreateTemp("", "load_data_test0.csv") + require.NoError(t, err) + defer func() { + path := loadDataFile.Name() + err = loadDataFile.Close() + require.NoError(t, err) + err = os.Remove(path) + require.NoError(t, err) + }() + _, err = loadDataFile.WriteString( + "31 31\n" + + "32 32\n" + + "33 33\n") + require.NoError(t, err) + + // Test case for other statements + cases := []struct { + sql string + isQuery bool + reqs []tikvrpc.CmdType + }{ + // Test for curd. + {"insert into t values (1,1), (3,3)", false, []tikvrpc.CmdType{tikvrpc.CmdPrewrite, tikvrpc.CmdCommit}}, + {"insert into t values (1,2) on duplicate key update a = 2", false, []tikvrpc.CmdType{tikvrpc.CmdPrewrite, tikvrpc.CmdCommit, tikvrpc.CmdBatchGet}}, + {"update t set b=b+1 where a=3", false, []tikvrpc.CmdType{tikvrpc.CmdPrewrite, tikvrpc.CmdCommit, tikvrpc.CmdGet}}, + {"update t set b=b+1 where a>1", false, []tikvrpc.CmdType{tikvrpc.CmdPrewrite, tikvrpc.CmdCommit, tikvrpc.CmdCop}}, + {"delete from t where a=3", false, []tikvrpc.CmdType{tikvrpc.CmdPrewrite, tikvrpc.CmdCommit, tikvrpc.CmdGet}}, + {"delete from t where a>1", false, []tikvrpc.CmdType{tikvrpc.CmdPrewrite, tikvrpc.CmdCommit, tikvrpc.CmdCop}}, + {"insert ignore into t values (2,2), (3,3)", false, []tikvrpc.CmdType{tikvrpc.CmdPrewrite, tikvrpc.CmdCommit, tikvrpc.CmdBatchGet}}, + {"select * from t where a in (1,2,3,4)", true, []tikvrpc.CmdType{tikvrpc.CmdBatchGet}}, + {"select * from t where a = 1", true, []tikvrpc.CmdType{tikvrpc.CmdGet}}, + {"select * from t where b > 0", true, []tikvrpc.CmdType{tikvrpc.CmdCop}}, + {"replace into t values (2,2), (4,4)", false, []tikvrpc.CmdType{tikvrpc.CmdPrewrite, tikvrpc.CmdCommit, tikvrpc.CmdBatchGet}}, + + // Test for DDL + {"create database test_db0", false, []tikvrpc.CmdType{tikvrpc.CmdPrewrite, tikvrpc.CmdCommit}}, + {"create table test_db0.test_t0 (a int, b int, index idx(a))", false, []tikvrpc.CmdType{tikvrpc.CmdPrewrite, tikvrpc.CmdCommit}}, + {"create table test_db0.test_t1 (a int, b int, index idx(a))", false, []tikvrpc.CmdType{tikvrpc.CmdPrewrite, tikvrpc.CmdCommit}}, + {"alter table test_db0.test_t0 add column c int", false, []tikvrpc.CmdType{tikvrpc.CmdPrewrite, tikvrpc.CmdCommit}}, + {"drop table test_db0.test_t0", false, []tikvrpc.CmdType{tikvrpc.CmdPrewrite, tikvrpc.CmdCommit}}, + {"drop database test_db0", false, []tikvrpc.CmdType{tikvrpc.CmdPrewrite, tikvrpc.CmdCommit}}, + {"alter table t modify column b double", false, []tikvrpc.CmdType{tikvrpc.CmdPrewrite, tikvrpc.CmdCommit, tikvrpc.CmdScan, tikvrpc.CmdCop}}, + {"alter table t add index idx2 (b,a)", false, []tikvrpc.CmdType{tikvrpc.CmdPrewrite, tikvrpc.CmdCommit, tikvrpc.CmdScan, tikvrpc.CmdCop}}, + {"alter table t drop index idx2", false, []tikvrpc.CmdType{tikvrpc.CmdPrewrite, tikvrpc.CmdCommit}}, + + // Test for transaction + {"begin", false, nil}, + {"insert into t2 values (10,10), (11,11)", false, nil}, + {"insert ignore into t2 values (20,20), (21,21)", false, []tikvrpc.CmdType{tikvrpc.CmdBatchGet}}, + {"commit", false, []tikvrpc.CmdType{tikvrpc.CmdPrewrite, tikvrpc.CmdCommit}}, + + // Test for other statements. + {"set @@global.tidb_enable_1pc = 1", false, nil}, + {fmt.Sprintf("load data local infile %q into table t2", loadDataFile.Name()), false, []tikvrpc.CmdType{tikvrpc.CmdPrewrite, tikvrpc.CmdCommit, tikvrpc.CmdBatchGet}}, + {"admin check table t", false, nil}, + {"admin check index t idx", false, nil}, + {"admin recover index t idx", false, []tikvrpc.CmdType{tikvrpc.CmdBatchGet}}, + {"admin cleanup index t idx", false, []tikvrpc.CmdType{tikvrpc.CmdBatchGet}}, + } + + internalCases := []struct { + sql string + reqs []tikvrpc.CmdType + }{ + {"replace into mysql.global_variables (variable_name,variable_value) values ('tidb_enable_1pc', '1')", []tikvrpc.CmdType{tikvrpc.CmdPrewrite, tikvrpc.CmdCommit, tikvrpc.CmdBatchGet}}, + {"select /*+ read_from_storage(tikv[`stmtstats`.`t`]) */ bit_xor(crc32(md5(concat_ws(0x2, `_tidb_rowid`, `a`)))), ((cast(crc32(md5(concat_ws(0x2, `_tidb_rowid`))) as signed) - 0) div 1 % 1024), count(*) from `stmtstats`.`t` use index() where 0 = 0 group by ((cast(crc32(md5(concat_ws(0x2, `_tidb_rowid`))) as signed) - 0) div 1 % 1024)", []tikvrpc.CmdType{tikvrpc.CmdCop}}, + {"select bit_xor(crc32(md5(concat_ws(0x2, `_tidb_rowid`, `a`)))), ((cast(crc32(md5(concat_ws(0x2, `_tidb_rowid`))) as signed) - 0) div 1 % 1024), count(*) from `stmtstats`.`t` use index(`idx`) where 0 = 0 group by ((cast(crc32(md5(concat_ws(0x2, `_tidb_rowid`))) as signed) - 0) div 1 % 1024)", []tikvrpc.CmdType{tikvrpc.CmdCop}}, + {"select /*+ read_from_storage(tikv[`stmtstats`.`t`]) */ bit_xor(crc32(md5(concat_ws(0x2, `_tidb_rowid`, `a`)))), ((cast(crc32(md5(concat_ws(0x2, `_tidb_rowid`))) as signed) - 0) div 1 % 1024), count(*) from `stmtstats`.`t` use index() where 0 = 0 group by ((cast(crc32(md5(concat_ws(0x2, `_tidb_rowid`))) as signed) - 0) div 1 % 1024)", []tikvrpc.CmdType{tikvrpc.CmdCop}}, + {"select bit_xor(crc32(md5(concat_ws(0x2, `_tidb_rowid`, `a`)))), ((cast(crc32(md5(concat_ws(0x2, `_tidb_rowid`))) as signed) - 0) div 1 % 1024), count(*) from `stmtstats`.`t` use index(`idx`) where 0 = 0 group by ((cast(crc32(md5(concat_ws(0x2, `_tidb_rowid`))) as signed) - 0) div 1 % 1024)", []tikvrpc.CmdType{tikvrpc.CmdCop}}, + } + executeCaseFn := func(execFn func(db *sql.DB)) { + dsn := ts.GetDSN(func(config *mysql.Config) { + config.AllowAllFiles = true + config.Params["sql_mode"] = "''" + }) + db, err := sql.Open("mysql", dsn) + require.NoError(t, err) + dbt := testkit.NewDBTestKit(t, db) + dbt.MustExec("use stmtstats;") + require.NoError(t, err) + + execFn(db) + err = db.Close() + require.NoError(t, err) + } + execFn := func(db *sql.DB) { + dbt := testkit.NewDBTestKit(t, db) + for _, ca := range cases { + if ca.isQuery { + mustQuery(t, dbt, ca.sql) + } else { + dbt.MustExec(ca.sql) + } + } + } + executeCaseFn(execFn) + + for _, ca := range cases { + _, digest := parser.NormalizeDigest(ca.sql) + tagChecker.checkReqExist(t, stmtstats.BinaryDigest(digest.Bytes()), ca.sql, ca.reqs...) + } + for _, ca := range internalCases { + _, digest := parser.NormalizeDigest(ca.sql) + tagChecker.checkReqExist(t, stmtstats.BinaryDigest(digest.Bytes()), ca.sql, ca.reqs...) + } +} + +func (ts *tidbTestTopSQLSuite) loopExec(ctx context.Context, t *testing.T, fn func(db *sql.DB)) { + db, err := sql.Open("mysql", ts.GetDSN()) + require.NoError(t, err, "Error connecting") + defer func() { + err := db.Close() + require.NoError(t, err) + }() + dbt := testkit.NewDBTestKit(t, db) + dbt.MustExec("use topsql;") + for { + select { + case <-ctx.Done(): + return + default: + } + fn(db) + } +} + +func TestLocalhostClientMapping(t *testing.T) { + tempDir := t.TempDir() + socketFile := tempDir + "/tidbtest.sock" // Unix Socket does not work on Windows, so '/' should be OK + + cli := testserverclient.NewTestServerClient() + cfg := util2.NewTestConfig() + cfg.Socket = socketFile + cfg.Port = cli.Port + cfg.Status.ReportStatus = false + + ts := createTidbTestSuite(t) + + server, err := server2.NewServer(cfg, ts.tidbdrv) + require.NoError(t, err) + server.SetDomain(ts.domain) + cli.Port = testutil.GetPortFromTCPAddr(server.ListenAddr()) + go func() { + err := server.Run() + require.NoError(t, err) + }() + defer server.Close() + cli.WaitUntilServerCanConnect() + + cli.Port = testutil.GetPortFromTCPAddr(server.ListenAddr()) + // Create a db connection for root + db, err := sql.Open("mysql", cli.GetDSN(func(config *mysql.Config) { + config.User = "root" + config.Net = "unix" + config.DBName = "test" + config.Addr = socketFile + })) + require.NoErrorf(t, err, "Open failed") + err = db.Ping() + require.NoErrorf(t, err, "Ping failed") + defer db.Close() + dbt := testkit.NewDBTestKit(t, db) + rows := dbt.MustQuery("select user()") + cli.CheckRows(t, rows, "root@localhost") + require.NoError(t, rows.Close()) + rows = dbt.MustQuery("show grants") + cli.CheckRows(t, rows, "GRANT ALL PRIVILEGES ON *.* TO 'root'@'%' WITH GRANT OPTION") + require.NoError(t, rows.Close()) + + dbt.MustExec("CREATE USER 'localhostuser'@'localhost'") + dbt.MustExec("CREATE USER 'localhostuser'@'%'") + defer func() { + dbt.MustExec("DROP USER IF EXISTS 'localhostuser'@'%'") + dbt.MustExec("DROP USER IF EXISTS 'localhostuser'@'localhost'") + dbt.MustExec("DROP USER IF EXISTS 'localhostuser'@'127.0.0.1'") + }() + + dbt.MustExec("GRANT SELECT ON test.* TO 'localhostuser'@'%'") + dbt.MustExec("GRANT SELECT,UPDATE ON test.* TO 'localhostuser'@'localhost'") + + // Test with loopback interface - Should get access to localhostuser@localhost! + cli.RunTests(t, func(config *mysql.Config) { + config.User = "localhostuser" + config.DBName = "test" + }, + func(dbt *testkit.DBTestKit) { + rows := dbt.MustQuery("select user()") + // NOTICE: this is not compatible with MySQL! (MySQL would report localhostuser@localhost also for 127.0.0.1) + cli.CheckRows(t, rows, "localhostuser@127.0.0.1") + require.NoError(t, rows.Close()) + rows = dbt.MustQuery("show grants") + cli.CheckRows(t, rows, "GRANT USAGE ON *.* TO 'localhostuser'@'localhost'\nGRANT SELECT,UPDATE ON `test`.* TO 'localhostuser'@'localhost'") + require.NoError(t, rows.Close()) + }) + + dbt.MustExec("DROP USER IF EXISTS 'localhostuser'@'localhost'") + dbt.MustExec("CREATE USER 'localhostuser'@'127.0.0.1'") + dbt.MustExec("GRANT SELECT,UPDATE ON test.* TO 'localhostuser'@'127.0.0.1'") + // Test with unix domain socket file connection - Should get access to '%' + cli.RunTests(t, func(config *mysql.Config) { + config.Net = "unix" + config.Addr = socketFile + config.User = "localhostuser" + config.DBName = "test" + }, + func(dbt *testkit.DBTestKit) { + rows := dbt.MustQuery("select user()") + cli.CheckRows(t, rows, "localhostuser@localhost") + require.NoError(t, rows.Close()) + rows = dbt.MustQuery("show grants") + cli.CheckRows(t, rows, "GRANT USAGE ON *.* TO 'localhostuser'@'%'\nGRANT SELECT ON `test`.* TO 'localhostuser'@'%'") + require.NoError(t, rows.Close()) + }) + + // Test if only localhost exists + dbt.MustExec("DROP USER 'localhostuser'@'%'") + dbSocket, err := sql.Open("mysql", cli.GetDSN(func(config *mysql.Config) { + config.User = "localhostuser" + config.Net = "unix" + config.DBName = "test" + config.Addr = socketFile + })) + require.NoErrorf(t, err, "Open failed") + defer dbSocket.Close() + err = dbSocket.Ping() + require.Errorf(t, err, "Connection successful without matching host for unix domain socket!") +} + +func TestRcReadCheckTS(t *testing.T) { + ts := createTidbTestSuite(t) + + db, err := sql.Open("mysql", ts.GetDSN()) + require.NoError(t, err) + defer func() { + err := db.Close() + require.NoError(t, err) + }() + + db2, err := sql.Open("mysql", ts.GetDSN()) + require.NoError(t, err) + defer func() { + err := db2.Close() + require.NoError(t, err) + }() + tk2 := testkit.NewDBTestKit(t, db2) + tk2.MustExec("set @@tidb_enable_async_commit = 0") + tk2.MustExec("set @@tidb_enable_1pc = 0") + + cli := testserverclient.NewTestServerClient() + + tk := testkit.NewDBTestKit(t, db) + tk.MustExec("drop table if exists t1") + tk.MustExec("create table t1(c1 int key, c2 int)") + tk.MustExec("insert into t1 values(1, 10), (2, 20), (3, 30)") + + tk.MustExec(`set tidb_rc_read_check_ts = 'on';`) + tk.MustExec(`set tx_isolation = 'READ-COMMITTED';`) + tk.MustExec("begin pessimistic") + // Test point get retry. + rows := tk.MustQuery("select * from t1 where c1 = 1") + cli.CheckRows(t, rows, "1 10") + tk2.MustExec("update t1 set c2 = c2 + 1") + rows = tk.MustQuery("select * from t1 where c1 = 1") + cli.CheckRows(t, rows, "1 11") + // Test batch point get retry. + rows = tk.MustQuery("select * from t1 where c1 in (1, 3)") + cli.CheckRows(t, rows, "1 11", "3 31") + tk2.MustExec("update t1 set c2 = c2 + 1") + rows = tk.MustQuery("select * from t1 where c1 in (1, 3)") + cli.CheckRows(t, rows, "1 12", "3 32") + // Test scan retry. + rows = tk.MustQuery("select * from t1") + cli.CheckRows(t, rows, "1 12", "2 22", "3 32") + tk2.MustExec("update t1 set c2 = c2 + 1") + rows = tk.MustQuery("select * from t1") + cli.CheckRows(t, rows, "1 13", "2 23", "3 33") + // Test reverse scan retry. + rows = tk.MustQuery("select * from t1 order by c1 desc") + cli.CheckRows(t, rows, "3 33", "2 23", "1 13") + tk2.MustExec("update t1 set c2 = c2 + 1") + rows = tk.MustQuery("select * from t1 order by c1 desc") + cli.CheckRows(t, rows, "3 34", "2 24", "1 14") + + // Test retry caused by ongoing prewrite lock. + // As the `defaultLockTTL` is 3s and it's difficult to change it here, the lock + // test is implemented in the uft test cases. +} + +type connEventLogs struct { + sync.Mutex + types []extension.ConnEventTp + infos []extension.ConnEventInfo +} + +func (l *connEventLogs) add(tp extension.ConnEventTp, info *extension.ConnEventInfo) { + l.Lock() + defer l.Unlock() + l.types = append(l.types, tp) + l.infos = append(l.infos, *info) +} + +func (l *connEventLogs) reset() { + l.Lock() + defer l.Unlock() + l.types = l.types[:0] + l.infos = l.infos[:0] +} + +func (l *connEventLogs) check(fn func()) { + l.Lock() + defer l.Unlock() + fn() +} + +func (l *connEventLogs) waitEvent(tp extension.ConnEventTp) error { + totalSleep := 0 + for { + l.Lock() + if l.types[len(l.types)-1] == tp { + l.Unlock() + return nil + } + l.Unlock() + if totalSleep >= 10000 { + break + } + time.Sleep(time.Millisecond * 100) + totalSleep += 100 + } + return errors.New("timeout") +} + +func TestExtensionConnEvent(t *testing.T) { + defer extension.Reset() + extension.Reset() + + logs := &connEventLogs{} + require.NoError(t, extension.Register("test", extension.WithSessionHandlerFactory(func() *extension.SessionHandler { + return &extension.SessionHandler{ + OnConnectionEvent: logs.add, + } + }))) + require.NoError(t, extension.Setup()) + + ts := createTidbTestSuite(t) + // createTidbTestSuite create an inner connection, so wait the previous connection closed + require.NoError(t, logs.waitEvent(extension.ConnDisconnected)) + + // test for login success + logs.reset() + db, err := sql.Open("mysql", ts.GetDSN()) + require.NoError(t, err) + defer func() { + require.NoError(t, db.Close()) + }() + + conn, err := db.Conn(context.Background()) + require.NoError(t, err) + defer func() { + _ = conn.Close() + }() + + var expectedConn2 variable.ConnectionInfo + require.NoError(t, logs.waitEvent(extension.ConnHandshakeAccepted)) + logs.check(func() { + require.Equal(t, []extension.ConnEventTp{ + extension.ConnConnected, + extension.ConnHandshakeAccepted, + }, logs.types) + conn1 := logs.infos[0] + require.Equal(t, "127.0.0.1", conn1.ClientIP) + require.Equal(t, "127.0.0.1", conn1.ServerIP) + require.Empty(t, conn1.User) + require.Empty(t, conn1.DB) + require.Equal(t, int(ts.Port), conn1.ServerPort) + require.NotEqual(t, conn1.ServerPort, conn1.ClientPort) + require.NotEmpty(t, conn1.ConnectionID) + require.Nil(t, conn1.ActiveRoles) + require.NoError(t, conn1.Error) + require.Empty(t, conn1.SessionAlias) + + expectedConn2 = *(conn1.ConnectionInfo) + expectedConn2.User = "root" + expectedConn2.DB = "test" + require.Equal(t, []*auth.RoleIdentity{}, logs.infos[1].ActiveRoles) + require.Nil(t, logs.infos[1].Error) + require.Equal(t, expectedConn2, *(logs.infos[1].ConnectionInfo)) + require.Empty(t, logs.infos[1].SessionAlias) + }) + + _, err = conn.ExecContext(context.TODO(), "create role r1@'%'") + require.NoError(t, err) + _, err = conn.ExecContext(context.TODO(), "grant r1 TO root") + require.NoError(t, err) + _, err = conn.ExecContext(context.TODO(), "set role all") + require.NoError(t, err) + _, err = conn.ExecContext(context.TODO(), "set @@tidb_session_alias='alias123'") + require.NoError(t, err) + + require.NoError(t, conn.Close()) + require.NoError(t, db.Close()) + require.NoError(t, logs.waitEvent(extension.ConnDisconnected)) + logs.check(func() { + require.Equal(t, 3, len(logs.infos)) + require.Equal(t, 1, len(logs.infos[2].ActiveRoles)) + require.Equal(t, auth.RoleIdentity{ + Username: "r1", + Hostname: "%", + }, *logs.infos[2].ActiveRoles[0]) + require.Nil(t, logs.infos[2].Error) + require.Equal(t, expectedConn2, *(logs.infos[2].ConnectionInfo)) + require.Equal(t, "alias123", logs.infos[2].SessionAlias) + }) + + // test for login failed + logs.reset() + cfg := mysql.NewConfig() + cfg.User = "noexist" + cfg.Net = "tcp" + cfg.Addr = fmt.Sprintf("127.0.0.1:%d", ts.Port) + cfg.DBName = "test" + + db, err = sql.Open("mysql", cfg.FormatDSN()) + require.NoError(t, err) + defer func() { + require.NoError(t, db.Close()) + }() + + _, err = db.Conn(context.Background()) + require.Error(t, err) + require.NoError(t, logs.waitEvent(extension.ConnDisconnected)) + logs.check(func() { + require.Equal(t, []extension.ConnEventTp{ + extension.ConnConnected, + extension.ConnHandshakeRejected, + extension.ConnDisconnected, + }, logs.types) + conn1 := logs.infos[0] + require.Equal(t, "127.0.0.1", conn1.ClientIP) + require.Equal(t, "127.0.0.1", conn1.ServerIP) + require.Empty(t, conn1.User) + require.Empty(t, conn1.DB) + require.Equal(t, int(ts.Port), conn1.ServerPort) + require.NotEqual(t, conn1.ServerPort, conn1.ClientPort) + require.NotEmpty(t, conn1.ConnectionID) + require.Nil(t, conn1.ActiveRoles) + require.NoError(t, conn1.Error) + require.Empty(t, conn1.SessionAlias) + + expectedConn2 = *(conn1.ConnectionInfo) + expectedConn2.User = "noexist" + expectedConn2.DB = "test" + require.Equal(t, []*auth.RoleIdentity{}, logs.infos[1].ActiveRoles) + require.EqualError(t, logs.infos[1].Error, "[server:1045]Access denied for user 'noexist'@'127.0.0.1' (using password: NO)") + require.Equal(t, expectedConn2, *(logs.infos[1].ConnectionInfo)) + require.Empty(t, logs.infos[2].SessionAlias) + }) +} + +func TestSandBoxMode(t *testing.T) { + ts := createTidbTestSuite(t) + qctx, err := ts.tidbdrv.OpenCtx(uint64(0), 0, uint8(tmysql.DefaultCollationID), "test", nil, nil) + require.NoError(t, err) + _, err = Execute(context.Background(), qctx, "create user testuser;") + require.NoError(t, err) + qctx.Session.GetSessionVars().User = &auth.UserIdentity{Username: "testuser", AuthUsername: "testuser", AuthHostname: "%"} + + alterPwdStmts := []string{ + "set password = '1234';", + "alter user testuser identified by '1234';", + "alter user current_user() identified by '1234';", + } + + for _, alterPwdStmt := range alterPwdStmts { + require.False(t, qctx.Session.InSandBoxMode()) + _, err = Execute(context.Background(), qctx, "select 1;") + require.NoError(t, err) + + qctx.Session.EnableSandBoxMode() + require.True(t, qctx.Session.InSandBoxMode()) + _, err = Execute(context.Background(), qctx, "select 1;") + require.Error(t, err) + _, err = Execute(context.Background(), qctx, "alter user testuser identified with 'mysql_native_password';") + require.Error(t, err) + _, err = Execute(context.Background(), qctx, alterPwdStmt) + require.NoError(t, err) + _, err = Execute(context.Background(), qctx, "select 1;") + require.NoError(t, err) + } +} + +// See: https://github.com/pingcap/tidb/issues/40979 +// Reusing memory of `chunk.Chunk` may cause some systems variable's memory value to be modified unexpectedly. +func TestChunkReuseCorruptSysVarString(t *testing.T) { + ts := createTidbTestSuite(t) + + db, err := sql.Open("mysql", ts.GetDSN()) + require.NoError(t, err) + defer func() { + require.NoError(t, db.Close()) + }() + + conn, err := db.Conn(context.Background()) + require.NoError(t, err) + defer func() { + require.NoError(t, conn.Close()) + }() + + rs, err := conn.QueryContext(context.Background(), "show tables in test") + ts.Rows(t, rs) + require.NoError(t, err) + + _, err = conn.ExecContext(context.Background(), "set @@time_zone=(select 'Asia/Shanghai')") + require.NoError(t, err) + + rs, err = conn.QueryContext(context.Background(), "select TIDB_TABLE_ID from information_schema.tables where TABLE_SCHEMA='aaaa'") + ts.Rows(t, rs) + require.NoError(t, err) + + rs, err = conn.QueryContext(context.Background(), "select @@time_zone") + require.NoError(t, err) + defer func() { + require.NoError(t, rs.Close()) + }() + + rows := ts.Rows(t, rs) + require.Equal(t, 1, len(rows)) + require.Equal(t, "Asia/Shanghai", rows[0]) +} + +type mockProxyProtocolProxy struct { + frontend string + backend string + clientAddr string + backendIsSock bool + ln net.Listener + run atomic.Bool +} + +func newMockProxyProtocolProxy(frontend, backend, clientAddr string, backendIsSock bool) *mockProxyProtocolProxy { + return &mockProxyProtocolProxy{ + frontend: frontend, + backend: backend, + clientAddr: clientAddr, + backendIsSock: backendIsSock, + ln: nil, + } +} + +func (p *mockProxyProtocolProxy) ListenAddr() net.Addr { + return p.ln.Addr() +} + +func (p *mockProxyProtocolProxy) Run() (err error) { + p.run.Store(true) + p.ln, err = net.Listen("tcp", p.frontend) + if err != nil { + return err + } + for p.run.Load() { + conn, err := p.ln.Accept() + if err != nil { + break + } + go p.onConn(conn) + } + return nil +} + +func (p *mockProxyProtocolProxy) Close() error { + p.run.Store(false) + if p.ln != nil { + return p.ln.Close() + } + return nil +} + +func (p *mockProxyProtocolProxy) connectToBackend() (net.Conn, error) { + if p.backendIsSock { + return net.Dial("unix", p.backend) + } + return net.Dial("tcp", p.backend) +} + +func (p *mockProxyProtocolProxy) onConn(conn net.Conn) { + bconn, err := p.connectToBackend() + if err != nil { + conn.Close() + fmt.Println(err) + } + defer bconn.Close() + ppHeader := p.generateProxyProtocolHeaderV2("tcp4", p.clientAddr, p.frontend) + bconn.Write(ppHeader) + p.proxyPipe(conn, bconn) +} + +func (p *mockProxyProtocolProxy) proxyPipe(p1, p2 io.ReadWriteCloser) { + defer p1.Close() + defer p2.Close() + + // start proxy + p1die := make(chan struct{}) + go func() { io.Copy(p1, p2); close(p1die) }() + + p2die := make(chan struct{}) + go func() { io.Copy(p2, p1); close(p2die) }() + + // wait for proxy termination + select { + case <-p1die: + case <-p2die: + } +} + +func (p *mockProxyProtocolProxy) generateProxyProtocolHeaderV2(network, srcAddr, dstAddr string) []byte { + var ( + proxyProtocolV2Sig = []byte{0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A} + v2CmdPos = 12 + v2FamlyPos = 13 + ) + saddr, _ := net.ResolveTCPAddr(network, srcAddr) + daddr, _ := net.ResolveTCPAddr(network, dstAddr) + buffer := make([]byte, 1024) + copy(buffer, proxyProtocolV2Sig) + // Command + buffer[v2CmdPos] = 0x21 + // Famly + if network == "tcp4" { + buffer[v2FamlyPos] = 0x11 + binary.BigEndian.PutUint16(buffer[14:14+2], 12) + copy(buffer[16:16+4], []byte(saddr.IP.To4())) + copy(buffer[20:20+4], []byte(daddr.IP.To4())) + binary.BigEndian.PutUint16(buffer[24:24+2], uint16(saddr.Port)) + binary.BigEndian.PutUint16(buffer[26:26+2], uint16(saddr.Port)) + return buffer[0:28] + } else if network == "tcp6" { + buffer[v2FamlyPos] = 0x21 + binary.BigEndian.PutUint16(buffer[14:14+2], 36) + copy(buffer[16:16+16], []byte(saddr.IP.To16())) + copy(buffer[32:32+16], []byte(daddr.IP.To16())) + binary.BigEndian.PutUint16(buffer[48:48+2], uint16(saddr.Port)) + binary.BigEndian.PutUint16(buffer[50:50+2], uint16(saddr.Port)) + return buffer[0:52] + } + return buffer +} + +func TestProxyProtocolWithIpFallbackable(t *testing.T) { + cfg := util2.NewTestConfig() + cfg.Port = 4999 + cfg.Status.ReportStatus = false + // Setup proxy protocol config + cfg.ProxyProtocol.Networks = "*" + cfg.ProxyProtocol.Fallbackable = true + + ts := createTidbTestSuite(t) + + // Prepare Server + server, err := server2.NewServer(cfg, ts.tidbdrv) + require.NoError(t, err) + server.SetDomain(ts.domain) + go func() { + err := server.Run() + require.NoError(t, err) + }() + time.Sleep(time.Millisecond * 100) + defer func() { + server.Close() + }() + + require.NotNil(t, server.Listener()) + require.Nil(t, server.Socket()) + + // Prepare Proxy + ppProxy := newMockProxyProtocolProxy("127.0.0.1:5000", "127.0.0.1:4999", "192.168.1.2:60055", false) + go func() { + ppProxy.Run() + }() + time.Sleep(time.Millisecond * 100) + defer func() { + ppProxy.Close() + }() + + cli := testserverclient.NewTestServerClient() + cli.Port = testutil.GetPortFromTCPAddr(ppProxy.ListenAddr()) + cli.WaitUntilServerCanConnect() + + cli.RunTests(t, + func(config *mysql.Config) { + config.User = "root" + }, + func(dbt *testkit.DBTestKit) { + rows := dbt.MustQuery("SHOW PROCESSLIST;") + records := cli.Rows(t, rows) + require.Contains(t, records[0], "192.168.1.2:60055") + }, + ) + + cli2 := testserverclient.NewTestServerClient() + cli2.Port = 4999 + cli2.RunTests(t, + func(config *mysql.Config) { + config.User = "root" + }, + func(dbt *testkit.DBTestKit) { + rows := dbt.MustQuery("SHOW PROCESSLIST;") + records := cli.Rows(t, rows) + require.Contains(t, records[0], "127.0.0.1:") + }, + ) +} + +func TestProxyProtocolWithIpNoFallbackable(t *testing.T) { + cfg := util2.NewTestConfig() + cfg.Port = 0 + cfg.Status.ReportStatus = false + // Setup proxy protocol config + cfg.ProxyProtocol.Networks = "*" + cfg.ProxyProtocol.Fallbackable = false + + ts := createTidbTestSuite(t) + + // Prepare Server + server, err := server2.NewServer(cfg, ts.tidbdrv) + require.NoError(t, err) + server.SetDomain(ts.domain) + go func() { + err := server.Run() + require.NoError(t, err) + }() + time.Sleep(time.Millisecond * 1000) + defer func() { + server.Close() + }() + + require.NotNil(t, server.Listener()) + require.Nil(t, server.Socket()) + + cli := testserverclient.NewTestServerClient() + cli.Port = testutil.GetPortFromTCPAddr(server.ListenAddr()) + dsn := cli.GetDSN(func(config *mysql.Config) { + config.User = "root" + config.DBName = "test" + }) + db, err := sql.Open("mysql", dsn) + require.Nil(t, err) + err = db.Ping() + require.NotNil(t, err) + db.Close() +} diff --git a/pkg/session/session.go b/pkg/session/session.go new file mode 100644 index 0000000000000..74e4877d8bae1 --- /dev/null +++ b/pkg/session/session.go @@ -0,0 +1,4441 @@ +// Copyright 2015 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Copyright 2013 The ql Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSES/QL-LICENSE file. + +package session + +import ( + "bytes" + "context" + "crypto/tls" + "encoding/hex" + "encoding/json" + stderrs "errors" + "fmt" + "math" + "math/rand" + "runtime/pprof" + "strconv" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/ngaut/pools" + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/kvproto/pkg/kvrpcpb" + "github.com/pingcap/tidb/pkg/bindinfo" + "github.com/pingcap/tidb/pkg/config" + "github.com/pingcap/tidb/pkg/ddl" + "github.com/pingcap/tidb/pkg/ddl/placement" + "github.com/pingcap/tidb/pkg/domain" + "github.com/pingcap/tidb/pkg/domain/infosync" + "github.com/pingcap/tidb/pkg/errno" + "github.com/pingcap/tidb/pkg/executor" + "github.com/pingcap/tidb/pkg/expression" + "github.com/pingcap/tidb/pkg/extension" + "github.com/pingcap/tidb/pkg/extension/extensionimpl" + "github.com/pingcap/tidb/pkg/infoschema" + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/meta" + "github.com/pingcap/tidb/pkg/metrics" + "github.com/pingcap/tidb/pkg/owner" + "github.com/pingcap/tidb/pkg/parser" + "github.com/pingcap/tidb/pkg/parser/ast" + "github.com/pingcap/tidb/pkg/parser/auth" + "github.com/pingcap/tidb/pkg/parser/charset" + "github.com/pingcap/tidb/pkg/parser/model" + "github.com/pingcap/tidb/pkg/parser/mysql" + "github.com/pingcap/tidb/pkg/parser/terror" + "github.com/pingcap/tidb/pkg/planner" + plannercore "github.com/pingcap/tidb/pkg/planner/core" + "github.com/pingcap/tidb/pkg/plugin" + "github.com/pingcap/tidb/pkg/privilege" + "github.com/pingcap/tidb/pkg/privilege/conn" + "github.com/pingcap/tidb/pkg/privilege/privileges" + session_metrics "github.com/pingcap/tidb/pkg/session/metrics" + "github.com/pingcap/tidb/pkg/session/txninfo" + "github.com/pingcap/tidb/pkg/sessionctx" + "github.com/pingcap/tidb/pkg/sessionctx/binloginfo" + "github.com/pingcap/tidb/pkg/sessionctx/sessionstates" + "github.com/pingcap/tidb/pkg/sessionctx/stmtctx" + "github.com/pingcap/tidb/pkg/sessionctx/variable" + "github.com/pingcap/tidb/pkg/sessiontxn" + "github.com/pingcap/tidb/pkg/statistics/handle/usage" + storeerr "github.com/pingcap/tidb/pkg/store/driver/error" + "github.com/pingcap/tidb/pkg/store/driver/txn" + "github.com/pingcap/tidb/pkg/store/helper" + "github.com/pingcap/tidb/pkg/table" + "github.com/pingcap/tidb/pkg/table/temptable" + "github.com/pingcap/tidb/pkg/tablecodec" + "github.com/pingcap/tidb/pkg/telemetry" + "github.com/pingcap/tidb/pkg/ttl/ttlworker" + "github.com/pingcap/tidb/pkg/util" + "github.com/pingcap/tidb/pkg/util/chunk" + "github.com/pingcap/tidb/pkg/util/collate" + "github.com/pingcap/tidb/pkg/util/dbterror" + "github.com/pingcap/tidb/pkg/util/execdetails" + "github.com/pingcap/tidb/pkg/util/intest" + "github.com/pingcap/tidb/pkg/util/kvcache" + "github.com/pingcap/tidb/pkg/util/logutil" + "github.com/pingcap/tidb/pkg/util/logutil/consistency" + "github.com/pingcap/tidb/pkg/util/memory" + "github.com/pingcap/tidb/pkg/util/sem" + "github.com/pingcap/tidb/pkg/util/sli" + "github.com/pingcap/tidb/pkg/util/sqlescape" + "github.com/pingcap/tidb/pkg/util/sqlexec" + "github.com/pingcap/tidb/pkg/util/syncutil" + "github.com/pingcap/tidb/pkg/util/tableutil" + "github.com/pingcap/tidb/pkg/util/timeutil" + "github.com/pingcap/tidb/pkg/util/topsql" + topsqlstate "github.com/pingcap/tidb/pkg/util/topsql/state" + "github.com/pingcap/tidb/pkg/util/topsql/stmtstats" + "github.com/pingcap/tidb/pkg/util/tracing" + "github.com/pingcap/tipb/go-binlog" + tikverr "github.com/tikv/client-go/v2/error" + tikvstore "github.com/tikv/client-go/v2/kv" + "github.com/tikv/client-go/v2/oracle" + tikvutil "github.com/tikv/client-go/v2/util" + "go.uber.org/zap" +) + +// Session context, it is consistent with the lifecycle of a client connection. +type Session interface { + sessionctx.Context + Status() uint16 // Flag of current status, such as autocommit. + LastInsertID() uint64 // LastInsertID is the last inserted auto_increment ID. + LastMessage() string // LastMessage is the info message that may be generated by last command + AffectedRows() uint64 // Affected rows by latest executed stmt. + // Execute is deprecated, and only used by plugins. Use ExecuteStmt() instead. + Execute(context.Context, string) ([]sqlexec.RecordSet, error) // Execute a sql statement. + // ExecuteStmt executes a parsed statement. + ExecuteStmt(context.Context, ast.StmtNode) (sqlexec.RecordSet, error) + // Parse is deprecated, use ParseWithParams() instead. + Parse(ctx context.Context, sql string) ([]ast.StmtNode, error) + // ExecuteInternal is a helper around ParseWithParams() and ExecuteStmt(). It is not allowed to execute multiple statements. + ExecuteInternal(context.Context, string, ...interface{}) (sqlexec.RecordSet, error) + String() string // String is used to debug. + CommitTxn(context.Context) error + RollbackTxn(context.Context) + // PrepareStmt executes prepare statement in binary protocol. + PrepareStmt(sql string) (stmtID uint32, paramCount int, fields []*ast.ResultField, err error) + // ExecutePreparedStmt executes a prepared statement. + // Deprecated: please use ExecuteStmt, this function is left for testing only. + // TODO: remove ExecutePreparedStmt. + ExecutePreparedStmt(ctx context.Context, stmtID uint32, param []expression.Expression) (sqlexec.RecordSet, error) + DropPreparedStmt(stmtID uint32) error + // SetSessionStatesHandler sets SessionStatesHandler for type stateType. + SetSessionStatesHandler(stateType sessionstates.SessionStateType, handler sessionctx.SessionStatesHandler) + SetClientCapability(uint32) // Set client capability flags. + SetConnectionID(uint64) + SetCommandValue(byte) + SetProcessInfo(string, time.Time, byte, uint64) + SetTLSState(*tls.ConnectionState) + SetCollation(coID int) error + SetSessionManager(util.SessionManager) + Close() + Auth(user *auth.UserIdentity, auth, salt []byte, authConn conn.AuthConn) error + AuthWithoutVerification(user *auth.UserIdentity) bool + AuthPluginForUser(user *auth.UserIdentity) (string, error) + MatchIdentity(username, remoteHost string) (*auth.UserIdentity, error) + // Return the information of the txn current running + TxnInfo() *txninfo.TxnInfo + // PrepareTxnCtx is exported for test. + PrepareTxnCtx(context.Context) error + // FieldList returns fields list of a table. + FieldList(tableName string) (fields []*ast.ResultField, err error) + SetPort(port string) + + // set cur session operations allowed when tikv disk full happens. + SetDiskFullOpt(level kvrpcpb.DiskFullOpt) + GetDiskFullOpt() kvrpcpb.DiskFullOpt + ClearDiskFullOpt() + + // SetExtensions sets the `*extension.SessionExtensions` object + SetExtensions(extensions *extension.SessionExtensions) +} + +func init() { + executor.CreateSession = func(ctx sessionctx.Context) (sessionctx.Context, error) { + return CreateSession(ctx.GetStore()) + } + executor.CloseSession = func(ctx sessionctx.Context) { + if se, ok := ctx.(Session); ok { + se.Close() + } + } +} + +var _ Session = (*session)(nil) + +type stmtRecord struct { + st sqlexec.Statement + stmtCtx *stmtctx.StatementContext +} + +// StmtHistory holds all histories of statements in a txn. +type StmtHistory struct { + history []*stmtRecord +} + +// Add appends a stmt to history list. +func (h *StmtHistory) Add(st sqlexec.Statement, stmtCtx *stmtctx.StatementContext) { + s := &stmtRecord{ + st: st, + stmtCtx: stmtCtx, + } + h.history = append(h.history, s) +} + +// Count returns the count of the history. +func (h *StmtHistory) Count() int { + return len(h.history) +} + +type session struct { + // processInfo is used by ShowProcess(), and should be modified atomically. + processInfo atomic.Value + txn LazyTxn + + mu struct { + sync.RWMutex + values map[fmt.Stringer]interface{} + } + + currentCtx context.Context // only use for runtime.trace, Please NEVER use it. + currentPlan plannercore.Plan + + store kv.Storage + + sessionPlanCache sessionctx.PlanCache + + sessionVars *variable.SessionVars + sessionManager util.SessionManager + + statsCollector *usage.SessionStatsItem + // ddlOwnerManager is used in `select tidb_is_ddl_owner()` statement; + ddlOwnerManager owner.Manager + // lockedTables use to record the table locks hold by the session. + lockedTables map[int64]model.TableLockTpInfo + + // client shared coprocessor client per session + client kv.Client + + mppClient kv.MPPClient + + // indexUsageCollector collects index usage information. + idxUsageCollector *usage.SessionIndexUsageCollector + + functionUsageMu struct { + syncutil.RWMutex + builtinFunctionUsage telemetry.BuiltinFunctionsUsage + } + // allowed when tikv disk full happened. + diskFullOpt kvrpcpb.DiskFullOpt + + // StmtStats is used to count various indicators of each SQL in this session + // at each point in time. These data will be periodically taken away by the + // background goroutine. The background goroutine will continue to aggregate + // all the local data in each session, and finally report them to the remote + // regularly. + stmtStats *stmtstats.StatementStats + + // Used to encode and decode each type of session states. + sessionStatesHandlers map[sessionstates.SessionStateType]sessionctx.SessionStatesHandler + + // Contains a list of sessions used to collect advisory locks. + advisoryLocks map[string]*advisoryLock + + extensions *extension.SessionExtensions + + sandBoxMode bool +} + +var parserPool = &sync.Pool{New: func() interface{} { return parser.New() }} + +// AddTableLock adds table lock to the session lock map. +func (s *session) AddTableLock(locks []model.TableLockTpInfo) { + for _, l := range locks { + // read only lock is session unrelated, skip it when adding lock to session. + if l.Tp != model.TableLockReadOnly { + s.lockedTables[l.TableID] = l + } + } +} + +// ReleaseTableLocks releases table lock in the session lock map. +func (s *session) ReleaseTableLocks(locks []model.TableLockTpInfo) { + for _, l := range locks { + delete(s.lockedTables, l.TableID) + } +} + +// ReleaseTableLockByTableIDs releases table lock in the session lock map by table ID. +func (s *session) ReleaseTableLockByTableIDs(tableIDs []int64) { + for _, tblID := range tableIDs { + delete(s.lockedTables, tblID) + } +} + +// CheckTableLocked checks the table lock. +func (s *session) CheckTableLocked(tblID int64) (bool, model.TableLockType) { + lt, ok := s.lockedTables[tblID] + if !ok { + return false, model.TableLockNone + } + return true, lt.Tp +} + +// GetAllTableLocks gets all table locks table id and db id hold by the session. +func (s *session) GetAllTableLocks() []model.TableLockTpInfo { + lockTpInfo := make([]model.TableLockTpInfo, 0, len(s.lockedTables)) + for _, tl := range s.lockedTables { + lockTpInfo = append(lockTpInfo, tl) + } + return lockTpInfo +} + +// HasLockedTables uses to check whether this session locked any tables. +// If so, the session can only visit the table which locked by self. +func (s *session) HasLockedTables() bool { + b := len(s.lockedTables) > 0 + return b +} + +// ReleaseAllTableLocks releases all table locks hold by the session. +func (s *session) ReleaseAllTableLocks() { + s.lockedTables = make(map[int64]model.TableLockTpInfo) +} + +// IsDDLOwner checks whether this session is DDL owner. +func (s *session) IsDDLOwner() bool { + return s.ddlOwnerManager.IsOwner() +} + +func (s *session) cleanRetryInfo() { + if s.sessionVars.RetryInfo.Retrying { + return + } + + retryInfo := s.sessionVars.RetryInfo + defer retryInfo.Clean() + if len(retryInfo.DroppedPreparedStmtIDs) == 0 { + return + } + + planCacheEnabled := s.GetSessionVars().EnablePreparedPlanCache + var cacheKey kvcache.Key + var err error + var preparedAst *ast.Prepared + var stmtText, stmtDB string + if planCacheEnabled { + firstStmtID := retryInfo.DroppedPreparedStmtIDs[0] + if preparedPointer, ok := s.sessionVars.PreparedStmts[firstStmtID]; ok { + preparedObj, ok := preparedPointer.(*plannercore.PlanCacheStmt) + if ok { + preparedAst = preparedObj.PreparedAst + stmtText, stmtDB = preparedObj.StmtText, preparedObj.StmtDB + bindSQL, _ := plannercore.GetBindSQL4PlanCache(s, preparedObj) + cacheKey, err = plannercore.NewPlanCacheKey(s.sessionVars, stmtText, stmtDB, preparedAst.SchemaVersion, + 0, bindSQL, expression.ExprPushDownBlackListReloadTimeStamp.Load()) + if err != nil { + logutil.Logger(s.currentCtx).Warn("clean cached plan failed", zap.Error(err)) + return + } + } + } + } + for i, stmtID := range retryInfo.DroppedPreparedStmtIDs { + if planCacheEnabled { + if i > 0 && preparedAst != nil { + plannercore.SetPstmtIDSchemaVersion(cacheKey, stmtText, preparedAst.SchemaVersion, s.sessionVars.IsolationReadEngines) + } + if !s.sessionVars.IgnorePreparedCacheCloseStmt { // keep the plan in cache + s.GetSessionPlanCache().Delete(cacheKey) + } + } + s.sessionVars.RemovePreparedStmt(stmtID) + } +} + +func (s *session) Status() uint16 { + return s.sessionVars.Status +} + +func (s *session) LastInsertID() uint64 { + if s.sessionVars.StmtCtx.LastInsertID > 0 { + return s.sessionVars.StmtCtx.LastInsertID + } + return s.sessionVars.StmtCtx.InsertID +} + +func (s *session) LastMessage() string { + return s.sessionVars.StmtCtx.GetMessage() +} + +func (s *session) AffectedRows() uint64 { + return s.sessionVars.StmtCtx.AffectedRows() +} + +func (s *session) SetClientCapability(capability uint32) { + s.sessionVars.ClientCapability = capability +} + +func (s *session) SetConnectionID(connectionID uint64) { + s.sessionVars.ConnectionID = connectionID +} + +func (s *session) SetTLSState(tlsState *tls.ConnectionState) { + // If user is not connected via TLS, then tlsState == nil. + if tlsState != nil { + s.sessionVars.TLSConnectionState = tlsState + } +} + +func (s *session) SetCommandValue(command byte) { + atomic.StoreUint32(&s.sessionVars.CommandValue, uint32(command)) +} + +func (s *session) SetCollation(coID int) error { + cs, co, err := charset.GetCharsetInfoByID(coID) + if err != nil { + return err + } + // If new collations are enabled, switch to the default + // collation if this one is not supported. + co = collate.SubstituteMissingCollationToDefault(co) + for _, v := range variable.SetNamesVariables { + terror.Log(s.sessionVars.SetSystemVarWithoutValidation(v, cs)) + } + return s.sessionVars.SetSystemVarWithoutValidation(variable.CollationConnection, co) +} + +func (s *session) GetSessionPlanCache() sessionctx.PlanCache { + // use the prepared plan cache + if !s.GetSessionVars().EnablePreparedPlanCache && !s.GetSessionVars().EnableNonPreparedPlanCache { + return nil + } + if s.sessionPlanCache == nil { // lazy construction + s.sessionPlanCache = plannercore.NewLRUPlanCache(uint(s.GetSessionVars().SessionPlanCacheSize), + variable.PreparedPlanCacheMemoryGuardRatio.Load(), plannercore.PreparedPlanCacheMaxMemory.Load(), s, false) + } + return s.sessionPlanCache +} + +func (s *session) SetSessionManager(sm util.SessionManager) { + s.sessionManager = sm +} + +func (s *session) GetSessionManager() util.SessionManager { + return s.sessionManager +} + +func (s *session) UpdateColStatsUsage(predicateColumns []model.TableItemID) { + if s.statsCollector == nil { + return + } + t := time.Now() + colMap := make(map[model.TableItemID]time.Time, len(predicateColumns)) + for _, col := range predicateColumns { + if col.IsIndex { + continue + } + colMap[col] = t + } + s.statsCollector.UpdateColStatsUsage(colMap) +} + +// StoreIndexUsage stores index usage information in idxUsageCollector. +func (s *session) StoreIndexUsage(tblID int64, idxID int64, rowsSelected int64) { + if s.idxUsageCollector == nil { + return + } + s.idxUsageCollector.Update(tblID, idxID, &usage.IndexUsageInformation{QueryCount: 1, RowsSelected: rowsSelected}) +} + +// FieldList returns fields list of a table. +func (s *session) FieldList(tableName string) ([]*ast.ResultField, error) { + is := s.GetInfoSchema().(infoschema.InfoSchema) + dbName := model.NewCIStr(s.GetSessionVars().CurrentDB) + tName := model.NewCIStr(tableName) + pm := privilege.GetPrivilegeManager(s) + if pm != nil && s.sessionVars.User != nil { + if !pm.RequestVerification(s.sessionVars.ActiveRoles, dbName.O, tName.O, "", mysql.AllPrivMask) { + user := s.sessionVars.User + u := user.Username + h := user.Hostname + if len(user.AuthUsername) > 0 && len(user.AuthHostname) > 0 { + u = user.AuthUsername + h = user.AuthHostname + } + return nil, plannercore.ErrTableaccessDenied.GenWithStackByArgs("SELECT", u, h, tableName) + } + } + table, err := is.TableByName(dbName, tName) + if err != nil { + return nil, err + } + + cols := table.Cols() + fields := make([]*ast.ResultField, 0, len(cols)) + for _, col := range table.Cols() { + rf := &ast.ResultField{ + ColumnAsName: col.Name, + TableAsName: tName, + DBName: dbName, + Table: table.Meta(), + Column: col.ColumnInfo, + } + fields = append(fields, rf) + } + return fields, nil +} + +// TxnInfo returns a pointer to a *copy* of the internal TxnInfo, thus is *read only* +func (s *session) TxnInfo() *txninfo.TxnInfo { + s.txn.mu.RLock() + // Copy on read to get a snapshot, this API shouldn't be frequently called. + txnInfo := s.txn.mu.TxnInfo + s.txn.mu.RUnlock() + + if txnInfo.StartTS == 0 { + return nil + } + + processInfo := s.ShowProcess() + if processInfo == nil { + return nil + } + txnInfo.ConnectionID = processInfo.ID + txnInfo.Username = processInfo.User + txnInfo.CurrentDB = processInfo.DB + txnInfo.RelatedTableIDs = make(map[int64]struct{}) + s.GetSessionVars().GetRelatedTableForMDL().Range(func(key, value interface{}) bool { + txnInfo.RelatedTableIDs[key.(int64)] = struct{}{} + return true + }) + + return &txnInfo +} + +func (s *session) doCommit(ctx context.Context) error { + if !s.txn.Valid() { + return nil + } + + // to avoid session set overlap the txn set. + if s.GetDiskFullOpt() != kvrpcpb.DiskFullOpt_NotAllowedOnFull { + s.txn.SetDiskFullOpt(s.GetDiskFullOpt()) + } + + defer func() { + s.txn.changeToInvalid() + s.sessionVars.SetInTxn(false) + s.ClearDiskFullOpt() + }() + // check if the transaction is read-only + if s.txn.IsReadOnly() { + return nil + } + // check if the cluster is read-only + if !s.sessionVars.InRestrictedSQL && variable.RestrictedReadOnly.Load() || variable.VarTiDBSuperReadOnly.Load() { + // It is not internal SQL, and the cluster has one of RestrictedReadOnly or SuperReadOnly + // We need to privilege check again: a privilege check occurred during planning, but we need + // to prevent the case that a long running auto-commit statement is now trying to commit. + pm := privilege.GetPrivilegeManager(s) + roles := s.sessionVars.ActiveRoles + if pm != nil && !pm.HasExplicitlyGrantedDynamicPrivilege(roles, "RESTRICTED_REPLICA_WRITER_ADMIN", false) { + s.RollbackTxn(ctx) + return plannercore.ErrSQLInReadOnlyMode + } + } + err := s.checkPlacementPolicyBeforeCommit() + if err != nil { + return err + } + // mockCommitError and mockGetTSErrorInRetry use to test PR #8743. + failpoint.Inject("mockCommitError", func(val failpoint.Value) { + if val.(bool) { + if _, err := failpoint.Eval("tikvclient/mockCommitErrorOpt"); err == nil { + failpoint.Return(kv.ErrTxnRetryable) + } + } + }) + + if s.sessionVars.BinlogClient != nil { + prewriteValue := binloginfo.GetPrewriteValue(s, false) + if prewriteValue != nil { + prewriteData, err := prewriteValue.Marshal() + if err != nil { + return errors.Trace(err) + } + info := &binloginfo.BinlogInfo{ + Data: &binlog.Binlog{ + Tp: binlog.BinlogType_Prewrite, + PrewriteValue: prewriteData, + }, + Client: s.sessionVars.BinlogClient, + } + s.txn.SetOption(kv.BinlogInfo, info) + } + } + + sessVars := s.GetSessionVars() + // Get the related table or partition IDs. + relatedPhysicalTables := sessVars.TxnCtx.TableDeltaMap + // Get accessed temporary tables in the transaction. + temporaryTables := sessVars.TxnCtx.TemporaryTables + physicalTableIDs := make([]int64, 0, len(relatedPhysicalTables)) + for id := range relatedPhysicalTables { + // Schema change on global temporary tables doesn't affect transactions. + if _, ok := temporaryTables[id]; ok { + continue + } + physicalTableIDs = append(physicalTableIDs, id) + } + needCheckSchema := true + // Set this option for 2 phase commit to validate schema lease. + if s.GetSessionVars().TxnCtx != nil { + needCheckSchema = !s.GetSessionVars().TxnCtx.EnableMDL + } + s.txn.SetOption(kv.SchemaChecker, domain.NewSchemaChecker(domain.GetDomain(s), s.GetInfoSchema().SchemaMetaVersion(), physicalTableIDs, needCheckSchema)) + s.txn.SetOption(kv.InfoSchema, s.sessionVars.TxnCtx.InfoSchema) + s.txn.SetOption(kv.CommitHook, func(info string, _ error) { s.sessionVars.LastTxnInfo = info }) + s.txn.SetOption(kv.EnableAsyncCommit, sessVars.EnableAsyncCommit) + s.txn.SetOption(kv.Enable1PC, sessVars.Enable1PC) + s.txn.SetOption(kv.ResourceGroupTagger, sessVars.StmtCtx.GetResourceGroupTagger()) + s.txn.SetOption(kv.ExplicitRequestSourceType, sessVars.ExplicitRequestSourceType) + if sessVars.StmtCtx.KvExecCounter != nil { + // Bind an interceptor for client-go to count the number of SQL executions of each TiKV. + s.txn.SetOption(kv.RPCInterceptor, sessVars.StmtCtx.KvExecCounter.RPCInterceptor()) + } + // priority of the sysvar is lower than `start transaction with causal consistency only` + if val := s.txn.GetOption(kv.GuaranteeLinearizability); val == nil || val.(bool) { + // We needn't ask the TiKV client to guarantee linearizability for auto-commit transactions + // because the property is naturally holds: + // We guarantee the commitTS of any transaction must not exceed the next timestamp from the TSO. + // An auto-commit transaction fetches its startTS from the TSO so its commitTS > its startTS > the commitTS + // of any previously committed transactions. + s.txn.SetOption(kv.GuaranteeLinearizability, + sessVars.TxnCtx.IsExplicit && sessVars.GuaranteeLinearizability) + } + if tables := sessVars.TxnCtx.TemporaryTables; len(tables) > 0 { + s.txn.SetOption(kv.KVFilter, temporaryTableKVFilter(tables)) + } + + var txnSource uint64 + if val := s.txn.GetOption(kv.TxnSource); val != nil { + txnSource, _ = val.(uint64) + } + // If the transaction is started by CDC, we need to set the CDCWriteSource option. + if sessVars.CDCWriteSource != 0 { + err := kv.SetCDCWriteSource(&txnSource, sessVars.CDCWriteSource) + if err != nil { + return errors.Trace(err) + } + + s.txn.SetOption(kv.TxnSource, txnSource) + } + + if tables := sessVars.TxnCtx.CachedTables; len(tables) > 0 { + c := cachedTableRenewLease{tables: tables} + now := time.Now() + err := c.start(ctx) + defer c.stop(ctx) + sessVars.StmtCtx.WaitLockLeaseTime += time.Since(now) + if err != nil { + return errors.Trace(err) + } + s.txn.SetOption(kv.CommitTSUpperBoundCheck, c.commitTSCheck) + } + + err = s.commitTxnWithTemporaryData(tikvutil.SetSessionID(ctx, sessVars.ConnectionID), &s.txn) + if err != nil { + err = s.handleAssertionFailure(ctx, err) + } + return err +} + +type cachedTableRenewLease struct { + tables map[int64]interface{} + lease []uint64 // Lease for each visited cached tables. + exit chan struct{} +} + +func (c *cachedTableRenewLease) start(ctx context.Context) error { + c.exit = make(chan struct{}) + c.lease = make([]uint64, len(c.tables)) + wg := make(chan error, len(c.tables)) + ith := 0 + for _, raw := range c.tables { + tbl := raw.(table.CachedTable) + go tbl.WriteLockAndKeepAlive(ctx, c.exit, &c.lease[ith], wg) + ith++ + } + + // Wait for all LockForWrite() return, this function can return. + var err error + for ; ith > 0; ith-- { + tmp := <-wg + if tmp != nil { + err = tmp + } + } + return err +} + +func (c *cachedTableRenewLease) stop(_ context.Context) { + close(c.exit) +} + +func (c *cachedTableRenewLease) commitTSCheck(commitTS uint64) bool { + for i := 0; i < len(c.lease); i++ { + lease := atomic.LoadUint64(&c.lease[i]) + if commitTS >= lease { + // Txn fails to commit because the write lease is expired. + return false + } + } + return true +} + +// handleAssertionFailure extracts the possible underlying assertionFailed error, +// gets the corresponding MVCC history and logs it. +// If it's not an assertion failure, returns the original error. +func (s *session) handleAssertionFailure(ctx context.Context, err error) error { + var assertionFailure *tikverr.ErrAssertionFailed + if !stderrs.As(err, &assertionFailure) { + return err + } + key := assertionFailure.Key + newErr := kv.ErrAssertionFailed.GenWithStackByArgs( + hex.EncodeToString(key), assertionFailure.Assertion.String(), assertionFailure.StartTs, + assertionFailure.ExistingStartTs, assertionFailure.ExistingCommitTs, + ) + + if s.GetSessionVars().EnableRedactLog { + return newErr + } + + var decodeFunc func(kv.Key, *kvrpcpb.MvccGetByKeyResponse, map[string]interface{}) + // if it's a record key or an index key, decode it + if infoSchema, ok := s.sessionVars.TxnCtx.InfoSchema.(infoschema.InfoSchema); ok && + infoSchema != nil && (tablecodec.IsRecordKey(key) || tablecodec.IsIndexKey(key)) { + tableOrPartitionID := tablecodec.DecodeTableID(key) + tbl, ok := infoSchema.TableByID(tableOrPartitionID) + if !ok { + tbl, _, _ = infoSchema.FindTableByPartitionID(tableOrPartitionID) + } + if tbl == nil { + logutil.Logger(ctx).Warn("cannot find table by id", zap.Int64("tableID", tableOrPartitionID), zap.String("key", hex.EncodeToString(key))) + return newErr + } + + if tablecodec.IsRecordKey(key) { + decodeFunc = consistency.DecodeRowMvccData(tbl.Meta()) + } else { + tableInfo := tbl.Meta() + _, indexID, _, e := tablecodec.DecodeIndexKey(key) + if e != nil { + logutil.Logger(ctx).Error("assertion failed but cannot decode index key", zap.Error(e)) + return newErr + } + var indexInfo *model.IndexInfo + for _, idx := range tableInfo.Indices { + if idx.ID == indexID { + indexInfo = idx + break + } + } + if indexInfo == nil { + return newErr + } + decodeFunc = consistency.DecodeIndexMvccData(indexInfo) + } + } + if store, ok := s.store.(helper.Storage); ok { + content := consistency.GetMvccByKey(store, key, decodeFunc) + logutil.Logger(ctx).Error("assertion failed", zap.String("message", newErr.Error()), zap.String("mvcc history", content)) + } + return newErr +} + +func (s *session) commitTxnWithTemporaryData(ctx context.Context, txn kv.Transaction) error { + sessVars := s.sessionVars + txnTempTables := sessVars.TxnCtx.TemporaryTables + if len(txnTempTables) == 0 { + failpoint.Inject("mockSleepBeforeTxnCommit", func(v failpoint.Value) { + ms := v.(int) + time.Sleep(time.Millisecond * time.Duration(ms)) + }) + return txn.Commit(ctx) + } + + sessionData := sessVars.TemporaryTableData + var ( + stage kv.StagingHandle + localTempTables *infoschema.SessionTables + ) + + if sessVars.LocalTemporaryTables != nil { + localTempTables = sessVars.LocalTemporaryTables.(*infoschema.SessionTables) + } else { + localTempTables = new(infoschema.SessionTables) + } + + defer func() { + // stage != kv.InvalidStagingHandle means error occurs, we need to cleanup sessionData + if stage != kv.InvalidStagingHandle { + sessionData.Cleanup(stage) + } + }() + + for tblID, tbl := range txnTempTables { + if !tbl.GetModified() { + continue + } + + if tbl.GetMeta().TempTableType != model.TempTableLocal { + continue + } + if _, ok := localTempTables.TableByID(tblID); !ok { + continue + } + + if stage == kv.InvalidStagingHandle { + stage = sessionData.Staging() + } + + tblPrefix := tablecodec.EncodeTablePrefix(tblID) + endKey := tablecodec.EncodeTablePrefix(tblID + 1) + + txnMemBuffer := s.txn.GetMemBuffer() + iter, err := txnMemBuffer.Iter(tblPrefix, endKey) + if err != nil { + return err + } + + for iter.Valid() { + key := iter.Key() + if !bytes.HasPrefix(key, tblPrefix) { + break + } + + value := iter.Value() + if len(value) == 0 { + err = sessionData.DeleteTableKey(tblID, key) + } else { + err = sessionData.SetTableKey(tblID, key, iter.Value()) + } + + if err != nil { + return err + } + + err = iter.Next() + if err != nil { + return err + } + } + } + + err := txn.Commit(ctx) + if err != nil { + return err + } + + if stage != kv.InvalidStagingHandle { + sessionData.Release(stage) + stage = kv.InvalidStagingHandle + } + + return nil +} + +type temporaryTableKVFilter map[int64]tableutil.TempTable + +func (m temporaryTableKVFilter) IsUnnecessaryKeyValue(key, value []byte, flags tikvstore.KeyFlags) (bool, error) { + tid := tablecodec.DecodeTableID(key) + if _, ok := m[tid]; ok { + return true, nil + } + + // This is the default filter for all tables. + defaultFilter := txn.TiDBKVFilter{} + return defaultFilter.IsUnnecessaryKeyValue(key, value, flags) +} + +// errIsNoisy is used to filter DUPLCATE KEY errors. +// These can observed by users in INFORMATION_SCHEMA.CLIENT_ERRORS_SUMMARY_GLOBAL instead. +// +// The rationale for filtering these errors is because they are "client generated errors". i.e. +// of the errors defined in kv/error.go, these look to be clearly related to a client-inflicted issue, +// and the server is only responsible for handling the error correctly. It does not need to log. +func errIsNoisy(err error) bool { + if kv.ErrKeyExists.Equal(err) { + return true + } + if storeerr.ErrLockAcquireFailAndNoWaitSet.Equal(err) { + return true + } + return false +} + +func (s *session) doCommitWithRetry(ctx context.Context) error { + defer func() { + s.GetSessionVars().SetTxnIsolationLevelOneShotStateForNextTxn() + s.txn.changeToInvalid() + s.cleanRetryInfo() + sessiontxn.GetTxnManager(s).OnTxnEnd() + }() + if !s.txn.Valid() { + // If the transaction is invalid, maybe it has already been rolled back by the client. + return nil + } + isInternalTxn := false + if internal := s.txn.GetOption(kv.RequestSourceInternal); internal != nil && internal.(bool) { + isInternalTxn = true + } + var err error + txnSize := s.txn.Size() + isPessimistic := s.txn.IsPessimistic() + r, ctx := tracing.StartRegionEx(ctx, "session.doCommitWithRetry") + defer r.End() + + err = s.doCommit(ctx) + if err != nil { + // polish the Write Conflict error message + newErr := s.tryReplaceWriteConflictError(err) + if newErr != nil { + err = newErr + } + + commitRetryLimit := s.sessionVars.RetryLimit + if !s.sessionVars.TxnCtx.CouldRetry { + commitRetryLimit = 0 + } + // Don't retry in BatchInsert mode. As a counter-example, insert into t1 select * from t2, + // BatchInsert already commit the first batch 1000 rows, then it commit 1000-2000 and retry the statement, + // Finally t1 will have more data than t2, with no errors return to user! + if s.isTxnRetryableError(err) && !s.sessionVars.BatchInsert && commitRetryLimit > 0 && !isPessimistic { + logutil.Logger(ctx).Warn("sql", + zap.String("label", s.GetSQLLabel()), + zap.Error(err), + zap.String("txn", s.txn.GoString())) + // Transactions will retry 2 ~ commitRetryLimit times. + // We make larger transactions retry less times to prevent cluster resource outage. + txnSizeRate := float64(txnSize) / float64(kv.TxnTotalSizeLimit.Load()) + maxRetryCount := commitRetryLimit - int64(float64(commitRetryLimit-1)*txnSizeRate) + err = s.retry(ctx, uint(maxRetryCount)) + } else if !errIsNoisy(err) { + logutil.Logger(ctx).Warn("can not retry txn", + zap.String("label", s.GetSQLLabel()), + zap.Error(err), + zap.Bool("IsBatchInsert", s.sessionVars.BatchInsert), + zap.Bool("IsPessimistic", isPessimistic), + zap.Bool("InRestrictedSQL", s.sessionVars.InRestrictedSQL), + zap.Int64("tidb_retry_limit", s.sessionVars.RetryLimit), + zap.Bool("tidb_disable_txn_auto_retry", s.sessionVars.DisableTxnAutoRetry)) + } + } + counter := s.sessionVars.TxnCtx.StatementCount + duration := time.Since(s.GetSessionVars().TxnCtx.CreateTime).Seconds() + s.recordOnTransactionExecution(err, counter, duration, isInternalTxn) + + if err != nil { + if !errIsNoisy(err) { + logutil.Logger(ctx).Warn("commit failed", + zap.String("finished txn", s.txn.GoString()), + zap.Error(err)) + } + return err + } + s.updateStatsDeltaToCollector() + return nil +} + +// adds more information about the table in the error message +// precondition: oldErr is a 9007:WriteConflict Error +func (s *session) tryReplaceWriteConflictError(oldErr error) (newErr error) { + if !kv.ErrWriteConflict.Equal(oldErr) { + return nil + } + if errors.RedactLogEnabled.Load() { + return nil + } + originErr := errors.Cause(oldErr) + inErr, _ := originErr.(*errors.Error) + args := inErr.Args() + is := sessiontxn.GetTxnManager(s).GetTxnInfoSchema() + if is == nil { + return nil + } + newKeyTableField, ok := addTableNameInTableIDField(args[3], is) + if ok { + args[3] = newKeyTableField + } + newPrimaryKeyTableField, ok := addTableNameInTableIDField(args[5], is) + if ok { + args[5] = newPrimaryKeyTableField + } + return kv.ErrWriteConflict.FastGenByArgs(args...) +} + +// precondition: is != nil +func addTableNameInTableIDField(tableIDField interface{}, is infoschema.InfoSchema) (enhancedMsg string, done bool) { + keyTableID, ok := tableIDField.(string) + if !ok { + return "", false + } + stringsInTableIDField := strings.Split(keyTableID, "=") + if len(stringsInTableIDField) == 0 { + return "", false + } + tableIDStr := stringsInTableIDField[len(stringsInTableIDField)-1] + tableID, err := strconv.ParseInt(tableIDStr, 10, 64) + if err != nil { + return "", false + } + var tableName string + tbl, ok := is.TableByID(tableID) + if !ok { + tableName = "unknown" + } else { + dbInfo, ok := is.SchemaByTable(tbl.Meta()) + if !ok { + tableName = "unknown." + tbl.Meta().Name.String() + } else { + tableName = dbInfo.Name.String() + "." + tbl.Meta().Name.String() + } + } + enhancedMsg = keyTableID + ", tableName=" + tableName + return enhancedMsg, true +} + +func (s *session) updateStatsDeltaToCollector() { + mapper := s.GetSessionVars().TxnCtx.TableDeltaMap + if s.statsCollector != nil && mapper != nil { + for _, item := range mapper { + if item.TableID > 0 { + s.statsCollector.Update(item.TableID, item.Delta, item.Count, &item.ColSize) + } + } + } +} + +func (s *session) CommitTxn(ctx context.Context) error { + r, ctx := tracing.StartRegionEx(ctx, "session.CommitTxn") + defer r.End() + + var commitDetail *tikvutil.CommitDetails + ctx = context.WithValue(ctx, tikvutil.CommitDetailCtxKey, &commitDetail) + err := s.doCommitWithRetry(ctx) + if commitDetail != nil { + s.sessionVars.StmtCtx.MergeExecDetails(nil, commitDetail) + } + + // record the TTLInsertRows in the metric + metrics.TTLInsertRowsCount.Add(float64(s.sessionVars.TxnCtx.InsertTTLRowsCount)) + + failpoint.Inject("keepHistory", func(val failpoint.Value) { + if val.(bool) { + failpoint.Return(err) + } + }) + s.sessionVars.TxnCtx.Cleanup() + s.sessionVars.CleanupTxnReadTSIfUsed() + return err +} + +func (s *session) RollbackTxn(ctx context.Context) { + r, ctx := tracing.StartRegionEx(ctx, "session.RollbackTxn") + defer r.End() + + if s.txn.Valid() { + terror.Log(s.txn.Rollback()) + } + if ctx.Value(inCloseSession{}) == nil { + s.cleanRetryInfo() + } + s.txn.changeToInvalid() + s.sessionVars.TxnCtx.Cleanup() + s.sessionVars.CleanupTxnReadTSIfUsed() + s.sessionVars.SetInTxn(false) + sessiontxn.GetTxnManager(s).OnTxnEnd() +} + +func (s *session) GetClient() kv.Client { + return s.client +} + +func (s *session) GetMPPClient() kv.MPPClient { + return s.mppClient +} + +func (s *session) String() string { + // TODO: how to print binded context in values appropriately? + sessVars := s.sessionVars + data := map[string]interface{}{ + "id": sessVars.ConnectionID, + "user": sessVars.User, + "currDBName": sessVars.CurrentDB, + "status": sessVars.Status, + "strictMode": sessVars.StrictSQLMode, + } + if s.txn.Valid() { + // if txn is committed or rolled back, txn is nil. + data["txn"] = s.txn.String() + } + if sessVars.SnapshotTS != 0 { + data["snapshotTS"] = sessVars.SnapshotTS + } + if sessVars.StmtCtx.LastInsertID > 0 { + data["lastInsertID"] = sessVars.StmtCtx.LastInsertID + } + if len(sessVars.PreparedStmts) > 0 { + data["preparedStmtCount"] = len(sessVars.PreparedStmts) + } + b, err := json.MarshalIndent(data, "", " ") + terror.Log(errors.Trace(err)) + return string(b) +} + +const sqlLogMaxLen = 1024 + +// SchemaChangedWithoutRetry is used for testing. +var SchemaChangedWithoutRetry uint32 + +func (s *session) GetSQLLabel() string { + if s.sessionVars.InRestrictedSQL { + return metrics.LblInternal + } + return metrics.LblGeneral +} + +func (s *session) isInternal() bool { + return s.sessionVars.InRestrictedSQL +} + +func (*session) isTxnRetryableError(err error) bool { + if atomic.LoadUint32(&SchemaChangedWithoutRetry) == 1 { + return kv.IsTxnRetryableError(err) + } + return kv.IsTxnRetryableError(err) || domain.ErrInfoSchemaChanged.Equal(err) +} + +func (s *session) checkTxnAborted(stmt sqlexec.Statement) error { + var err error + if atomic.LoadUint32(&s.GetSessionVars().TxnCtx.LockExpire) == 0 { + return nil + } + err = kv.ErrLockExpire + // If the transaction is aborted, the following statements do not need to execute, except `commit` and `rollback`, + // because they are used to finish the aborted transaction. + if _, ok := stmt.(*executor.ExecStmt).StmtNode.(*ast.CommitStmt); ok { + return nil + } + if _, ok := stmt.(*executor.ExecStmt).StmtNode.(*ast.RollbackStmt); ok { + return nil + } + return err +} + +func (s *session) retry(ctx context.Context, maxCnt uint) (err error) { + var retryCnt uint + defer func() { + s.sessionVars.RetryInfo.Retrying = false + // retryCnt only increments on retryable error, so +1 here. + if s.sessionVars.InRestrictedSQL { + session_metrics.TransactionRetryInternal.Observe(float64(retryCnt + 1)) + } else { + session_metrics.TransactionRetryGeneral.Observe(float64(retryCnt + 1)) + } + s.sessionVars.SetInTxn(false) + if err != nil { + s.RollbackTxn(ctx) + } + s.txn.changeToInvalid() + }() + + connID := s.sessionVars.ConnectionID + s.sessionVars.RetryInfo.Retrying = true + if atomic.LoadUint32(&s.sessionVars.TxnCtx.ForUpdate) == 1 { + err = ErrForUpdateCantRetry.GenWithStackByArgs(connID) + return err + } + + nh := GetHistory(s) + var schemaVersion int64 + sessVars := s.GetSessionVars() + orgStartTS := sessVars.TxnCtx.StartTS + label := s.GetSQLLabel() + for { + if err = s.PrepareTxnCtx(ctx); err != nil { + return err + } + s.sessionVars.RetryInfo.ResetOffset() + for i, sr := range nh.history { + st := sr.st + s.sessionVars.StmtCtx = sr.stmtCtx + s.sessionVars.StmtCtx.ResetForRetry() + s.sessionVars.PlanCacheParams.Reset() + schemaVersion, err = st.RebuildPlan(ctx) + if err != nil { + return err + } + + if retryCnt == 0 { + // We do not have to log the query every time. + // We print the queries at the first try only. + sql := sqlForLog(st.GetTextToLog(false)) + if !sessVars.EnableRedactLog { + sql += sessVars.PlanCacheParams.String() + } + logutil.Logger(ctx).Warn("retrying", + zap.Int64("schemaVersion", schemaVersion), + zap.Uint("retryCnt", retryCnt), + zap.Int("queryNum", i), + zap.String("sql", sql)) + } else { + logutil.Logger(ctx).Warn("retrying", + zap.Int64("schemaVersion", schemaVersion), + zap.Uint("retryCnt", retryCnt), + zap.Int("queryNum", i)) + } + _, digest := s.sessionVars.StmtCtx.SQLDigest() + s.txn.onStmtStart(digest.String()) + if err = sessiontxn.GetTxnManager(s).OnStmtStart(ctx, st.GetStmtNode()); err == nil { + _, err = st.Exec(ctx) + } + s.txn.onStmtEnd() + if err != nil { + s.StmtRollback(ctx, false) + break + } + s.StmtCommit(ctx) + } + logutil.Logger(ctx).Warn("transaction association", + zap.Uint64("retrying txnStartTS", s.GetSessionVars().TxnCtx.StartTS), + zap.Uint64("original txnStartTS", orgStartTS)) + failpoint.Inject("preCommitHook", func() { + hook, ok := ctx.Value("__preCommitHook").(func()) + if ok { + hook() + } + }) + if err == nil { + err = s.doCommit(ctx) + if err == nil { + break + } + } + if !s.isTxnRetryableError(err) { + logutil.Logger(ctx).Warn("sql", + zap.String("label", label), + zap.Stringer("session", s), + zap.Error(err)) + metrics.SessionRetryErrorCounter.WithLabelValues(label, metrics.LblUnretryable).Inc() + return err + } + retryCnt++ + if retryCnt >= maxCnt { + logutil.Logger(ctx).Warn("sql", + zap.String("label", label), + zap.Uint("retry reached max count", retryCnt)) + metrics.SessionRetryErrorCounter.WithLabelValues(label, metrics.LblReachMax).Inc() + return err + } + logutil.Logger(ctx).Warn("sql", + zap.String("label", label), + zap.Error(err), + zap.String("txn", s.txn.GoString())) + kv.BackOff(retryCnt) + s.txn.changeToInvalid() + s.sessionVars.SetInTxn(false) + } + return err +} + +func sqlForLog(sql string) string { + if len(sql) > sqlLogMaxLen { + sql = sql[:sqlLogMaxLen] + fmt.Sprintf("(len:%d)", len(sql)) + } + return executor.QueryReplacer.Replace(sql) +} + +type sessionPool interface { + Get() (pools.Resource, error) + Put(pools.Resource) +} + +func (s *session) sysSessionPool() sessionPool { + return domain.GetDomain(s).SysSessionPool() +} + +func createSessionFunc(store kv.Storage) pools.Factory { + return func() (pools.Resource, error) { + se, err := createSession(store) + if err != nil { + return nil, err + } + err = se.sessionVars.SetSystemVar(variable.AutoCommit, "1") + if err != nil { + return nil, err + } + err = se.sessionVars.SetSystemVar(variable.MaxExecutionTime, "0") + if err != nil { + return nil, errors.Trace(err) + } + err = se.sessionVars.SetSystemVar(variable.MaxAllowedPacket, strconv.FormatUint(variable.DefMaxAllowedPacket, 10)) + if err != nil { + return nil, errors.Trace(err) + } + err = se.sessionVars.SetSystemVar(variable.TiDBEnableWindowFunction, variable.BoolToOnOff(variable.DefEnableWindowFunction)) + if err != nil { + return nil, errors.Trace(err) + } + err = se.sessionVars.SetSystemVar(variable.TiDBConstraintCheckInPlacePessimistic, variable.On) + if err != nil { + return nil, errors.Trace(err) + } + se.sessionVars.CommonGlobalLoaded = true + se.sessionVars.InRestrictedSQL = true + // Internal session uses default format to prevent memory leak problem. + se.sessionVars.EnableChunkRPC = false + return se, nil + } +} + +func createSessionWithDomainFunc(store kv.Storage) func(*domain.Domain) (pools.Resource, error) { + return func(dom *domain.Domain) (pools.Resource, error) { + se, err := CreateSessionWithDomain(store, dom) + if err != nil { + return nil, err + } + err = se.sessionVars.SetSystemVar(variable.AutoCommit, "1") + if err != nil { + return nil, err + } + err = se.sessionVars.SetSystemVar(variable.MaxExecutionTime, "0") + if err != nil { + return nil, errors.Trace(err) + } + err = se.sessionVars.SetSystemVar(variable.MaxAllowedPacket, strconv.FormatUint(variable.DefMaxAllowedPacket, 10)) + if err != nil { + return nil, errors.Trace(err) + } + err = se.sessionVars.SetSystemVar(variable.TiDBConstraintCheckInPlacePessimistic, variable.On) + if err != nil { + return nil, errors.Trace(err) + } + se.sessionVars.CommonGlobalLoaded = true + se.sessionVars.InRestrictedSQL = true + // Internal session uses default format to prevent memory leak problem. + se.sessionVars.EnableChunkRPC = false + return se, nil + } +} + +func drainRecordSet(ctx context.Context, se *session, rs sqlexec.RecordSet, alloc chunk.Allocator) ([]chunk.Row, error) { + var rows []chunk.Row + var req *chunk.Chunk + req = rs.NewChunk(alloc) + for { + err := rs.Next(ctx, req) + if err != nil || req.NumRows() == 0 { + return rows, err + } + iter := chunk.NewIterator4Chunk(req) + for r := iter.Begin(); r != iter.End(); r = iter.Next() { + rows = append(rows, r) + } + req = chunk.Renew(req, se.sessionVars.MaxChunkSize) + } +} + +// getTableValue executes restricted sql and the result is one column. +// It returns a string value. +func (s *session) getTableValue(ctx context.Context, tblName string, varName string) (string, error) { + if ctx.Value(kv.RequestSourceKey) == nil { + ctx = kv.WithInternalSourceType(ctx, kv.InternalTxnSysVar) + } + rows, fields, err := s.ExecRestrictedSQL(ctx, nil, "SELECT VARIABLE_VALUE FROM %n.%n WHERE VARIABLE_NAME=%?", mysql.SystemDB, tblName, varName) + if err != nil { + return "", err + } + if len(rows) == 0 { + return "", errResultIsEmpty + } + d := rows[0].GetDatum(0, &fields[0].Column.FieldType) + value, err := d.ToString() + if err != nil { + return "", err + } + return value, nil +} + +// replaceGlobalVariablesTableValue executes restricted sql updates the variable value +// It will then notify the etcd channel that the value has changed. +func (s *session) replaceGlobalVariablesTableValue(ctx context.Context, varName, val string, updateLocal bool) error { + ctx = kv.WithInternalSourceType(ctx, kv.InternalTxnSysVar) + _, _, err := s.ExecRestrictedSQL(ctx, nil, `REPLACE INTO %n.%n (variable_name, variable_value) VALUES (%?, %?)`, mysql.SystemDB, mysql.GlobalVariablesTable, varName, val) + if err != nil { + return err + } + domain.GetDomain(s).NotifyUpdateSysVarCache(updateLocal) + return err +} + +// GetGlobalSysVar implements GlobalVarAccessor.GetGlobalSysVar interface. +func (s *session) GetGlobalSysVar(name string) (string, error) { + if s.Value(sessionctx.Initing) != nil { + // When running bootstrap or upgrade, we should not access global storage. + return "", nil + } + + sv := variable.GetSysVar(name) + if sv == nil { + // It might be a recently unregistered sysvar. We should return unknown + // since GetSysVar is the canonical version, but we can update the cache + // so the next request doesn't attempt to load this. + logutil.BgLogger().Info("sysvar does not exist. sysvar cache may be stale", zap.String("name", name)) + return "", variable.ErrUnknownSystemVar.GenWithStackByArgs(name) + } + + sysVar, err := domain.GetDomain(s).GetGlobalVar(name) + if err != nil { + // The sysvar exists, but there is no cache entry yet. + // This might be because the sysvar was only recently registered. + // In which case it is safe to return the default, but we can also + // update the cache for the future. + logutil.BgLogger().Info("sysvar not in cache yet. sysvar cache may be stale", zap.String("name", name)) + sysVar, err = s.getTableValue(context.TODO(), mysql.GlobalVariablesTable, name) + if err != nil { + return sv.Value, nil + } + } + // It might have been written from an earlier TiDB version, so we should do type validation + // See https://github.com/pingcap/tidb/issues/30255 for why we don't do full validation. + // If validation fails, we should return the default value: + // See: https://github.com/pingcap/tidb/pull/31566 + sysVar, err = sv.ValidateFromType(s.GetSessionVars(), sysVar, variable.ScopeGlobal) + if err != nil { + return sv.Value, nil + } + return sysVar, nil +} + +// SetGlobalSysVar implements GlobalVarAccessor.SetGlobalSysVar interface. +// it is called (but skipped) when setting instance scope +func (s *session) SetGlobalSysVar(ctx context.Context, name string, value string) (err error) { + sv := variable.GetSysVar(name) + if sv == nil { + return variable.ErrUnknownSystemVar.GenWithStackByArgs(name) + } + if value, err = sv.Validate(s.sessionVars, value, variable.ScopeGlobal); err != nil { + return err + } + if err = sv.SetGlobalFromHook(ctx, s.sessionVars, value, false); err != nil { + return err + } + if sv.HasInstanceScope() { // skip for INSTANCE scope + return nil + } + if sv.GlobalConfigName != "" { + domain.GetDomain(s).NotifyGlobalConfigChange(sv.GlobalConfigName, variable.OnOffToTrueFalse(value)) + } + return s.replaceGlobalVariablesTableValue(context.TODO(), sv.Name, value, true) +} + +// SetGlobalSysVarOnly updates the sysvar, but does not call the validation function or update aliases. +// This is helpful to prevent duplicate warnings being appended from aliases, or recursion. +// updateLocal indicates whether to rebuild the local SysVar Cache. This is helpful to prevent recursion. +func (s *session) SetGlobalSysVarOnly(ctx context.Context, name string, value string, updateLocal bool) (err error) { + sv := variable.GetSysVar(name) + if sv == nil { + return variable.ErrUnknownSystemVar.GenWithStackByArgs(name) + } + if err = sv.SetGlobalFromHook(ctx, s.sessionVars, value, true); err != nil { + return err + } + if sv.HasInstanceScope() { // skip for INSTANCE scope + return nil + } + return s.replaceGlobalVariablesTableValue(ctx, sv.Name, value, updateLocal) +} + +// SetTiDBTableValue implements GlobalVarAccessor.SetTiDBTableValue interface. +func (s *session) SetTiDBTableValue(name, value, comment string) error { + ctx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnSysVar) + _, _, err := s.ExecRestrictedSQL(ctx, nil, `REPLACE INTO mysql.tidb (variable_name, variable_value, comment) VALUES (%?, %?, %?)`, name, value, comment) + return err +} + +// GetTiDBTableValue implements GlobalVarAccessor.GetTiDBTableValue interface. +func (s *session) GetTiDBTableValue(name string) (string, error) { + return s.getTableValue(context.TODO(), mysql.TiDBTable, name) +} + +var _ sqlexec.SQLParser = &session{} + +func (s *session) ParseSQL(ctx context.Context, sql string, params ...parser.ParseParam) ([]ast.StmtNode, []error, error) { + defer tracing.StartRegion(ctx, "ParseSQL").End() + + p := parserPool.Get().(*parser.Parser) + defer parserPool.Put(p) + + sqlMode := s.sessionVars.SQLMode + if s.isInternal() { + sqlMode = mysql.DelSQLMode(sqlMode, mysql.ModeNoBackslashEscapes) + } + p.SetSQLMode(sqlMode) + p.SetParserConfig(s.sessionVars.BuildParserConfig()) + tmp, warn, err := p.ParseSQL(sql, params...) + // The []ast.StmtNode is referenced by the parser, to reuse the parser, make a copy of the result. + res := make([]ast.StmtNode, len(tmp)) + copy(res, tmp) + return res, warn, err +} + +func (s *session) SetProcessInfo(sql string, t time.Time, command byte, maxExecutionTime uint64) { + // If command == mysql.ComSleep, it means the SQL execution is finished. The processinfo is reset to SLEEP. + // If the SQL finished and the session is not in transaction, the current start timestamp need to reset to 0. + // Otherwise, it should be set to the transaction start timestamp. + // Why not reset the transaction start timestamp to 0 when transaction committed? + // Because the select statement and other statements need this timestamp to read data, + // after the transaction is committed. e.g. SHOW MASTER STATUS; + var curTxnStartTS uint64 + var curTxnCreateTime time.Time + if command != mysql.ComSleep || s.GetSessionVars().InTxn() { + curTxnStartTS = s.sessionVars.TxnCtx.StartTS + curTxnCreateTime = s.sessionVars.TxnCtx.CreateTime + } + // Set curTxnStartTS to SnapshotTS directly when the session is trying to historic read. + // It will avoid the session meet GC lifetime too short error. + if s.GetSessionVars().SnapshotTS != 0 { + curTxnStartTS = s.GetSessionVars().SnapshotTS + } + p := s.currentPlan + if explain, ok := p.(*plannercore.Explain); ok && explain.Analyze && explain.TargetPlan != nil { + p = explain.TargetPlan + } + + pi := util.ProcessInfo{ + ID: s.sessionVars.ConnectionID, + Port: s.sessionVars.Port, + DB: s.sessionVars.CurrentDB, + Command: command, + Plan: p, + PlanExplainRows: plannercore.GetExplainRowsForPlan(p), + RuntimeStatsColl: s.sessionVars.StmtCtx.RuntimeStatsColl, + Time: t, + State: s.Status(), + Info: sql, + CurTxnStartTS: curTxnStartTS, + CurTxnCreateTime: curTxnCreateTime, + StmtCtx: s.sessionVars.StmtCtx, + RefCountOfStmtCtx: &s.sessionVars.RefCountOfStmtCtx, + MemTracker: s.sessionVars.MemTracker, + DiskTracker: s.sessionVars.DiskTracker, + StatsInfo: plannercore.GetStatsInfo, + OOMAlarmVariablesInfo: s.getOomAlarmVariablesInfo(), + TableIDs: s.sessionVars.StmtCtx.TableIDs, + IndexNames: s.sessionVars.StmtCtx.IndexNames, + MaxExecutionTime: maxExecutionTime, + RedactSQL: s.sessionVars.EnableRedactLog, + ResourceGroupName: s.sessionVars.ResourceGroupName, + SessionAlias: s.sessionVars.SessionAlias, + } + oldPi := s.ShowProcess() + if p == nil { + // Store the last valid plan when the current plan is nil. + // This is for `explain for connection` statement has the ability to query the last valid plan. + if oldPi != nil && oldPi.Plan != nil && len(oldPi.PlanExplainRows) > 0 { + pi.Plan = oldPi.Plan + pi.PlanExplainRows = oldPi.PlanExplainRows + pi.RuntimeStatsColl = oldPi.RuntimeStatsColl + } + } + // We set process info before building plan, so we extended execution time. + if oldPi != nil && oldPi.Info == pi.Info && oldPi.Command == pi.Command { + pi.Time = oldPi.Time + } + if oldPi != nil && oldPi.CurTxnStartTS != 0 && oldPi.CurTxnStartTS == pi.CurTxnStartTS { + // Keep the last expensive txn log time, avoid print too many expensive txn logs. + pi.ExpensiveTxnLogTime = oldPi.ExpensiveTxnLogTime + } + _, digest := s.sessionVars.StmtCtx.SQLDigest() + pi.Digest = digest.String() + // DO NOT reset the currentPlan to nil until this query finishes execution, otherwise reentrant calls + // of SetProcessInfo would override Plan and PlanExplainRows to nil. + if command == mysql.ComSleep { + s.currentPlan = nil + } + if s.sessionVars.User != nil { + pi.User = s.sessionVars.User.Username + pi.Host = s.sessionVars.User.Hostname + } + s.processInfo.Store(&pi) +} + +// UpdateProcessInfo updates the session's process info for the running statement. +func (s *session) UpdateProcessInfo() { + pi := s.ShowProcess() + if pi == nil || pi.CurTxnStartTS != 0 { + return + } + // Update the current transaction start timestamp. + pi.CurTxnStartTS = s.sessionVars.TxnCtx.StartTS + pi.CurTxnCreateTime = s.sessionVars.TxnCtx.CreateTime +} + +func (s *session) getOomAlarmVariablesInfo() util.OOMAlarmVariablesInfo { + return util.OOMAlarmVariablesInfo{ + SessionAnalyzeVersion: s.sessionVars.AnalyzeVersion, + SessionEnabledRateLimitAction: s.sessionVars.EnabledRateLimitAction, + SessionMemQuotaQuery: s.sessionVars.MemQuotaQuery, + } +} + +func (s *session) SetDiskFullOpt(level kvrpcpb.DiskFullOpt) { + s.diskFullOpt = level +} + +func (s *session) GetDiskFullOpt() kvrpcpb.DiskFullOpt { + return s.diskFullOpt +} + +func (s *session) ClearDiskFullOpt() { + s.diskFullOpt = kvrpcpb.DiskFullOpt_NotAllowedOnFull +} + +func (s *session) ExecuteInternal(ctx context.Context, sql string, args ...interface{}) (rs sqlexec.RecordSet, err error) { + origin := s.sessionVars.InRestrictedSQL + s.sessionVars.InRestrictedSQL = true + defer func() { + s.sessionVars.InRestrictedSQL = origin + // Restore the goroutine label by using the original ctx after execution is finished. + pprof.SetGoroutineLabels(ctx) + }() + + r, ctx := tracing.StartRegionEx(ctx, "session.ExecuteInternal") + defer r.End() + logutil.Eventf(ctx, "execute: %s", sql) + + stmtNode, err := s.ParseWithParams(ctx, sql, args...) + if err != nil { + return nil, err + } + + rs, err = s.ExecuteStmt(ctx, stmtNode) + if err != nil { + s.sessionVars.StmtCtx.AppendError(err) + } + if rs == nil { + return nil, err + } + + return rs, err +} + +// Execute is deprecated, we can remove it as soon as plugins are migrated. +func (s *session) Execute(ctx context.Context, sql string) (recordSets []sqlexec.RecordSet, err error) { + r, ctx := tracing.StartRegionEx(ctx, "session.Execute") + defer r.End() + logutil.Eventf(ctx, "execute: %s", sql) + + stmtNodes, err := s.Parse(ctx, sql) + if err != nil { + return nil, err + } + if len(stmtNodes) != 1 { + return nil, errors.New("Execute() API doesn't support multiple statements any more") + } + + rs, err := s.ExecuteStmt(ctx, stmtNodes[0]) + if err != nil { + s.sessionVars.StmtCtx.AppendError(err) + } + if rs == nil { + return nil, err + } + return []sqlexec.RecordSet{rs}, err +} + +// Parse parses a query string to raw ast.StmtNode. +func (s *session) Parse(ctx context.Context, sql string) ([]ast.StmtNode, error) { + logutil.Logger(ctx).Debug("parse", zap.String("sql", sql)) + parseStartTime := time.Now() + stmts, warns, err := s.ParseSQL(ctx, sql, s.sessionVars.GetParseParams()...) + if err != nil { + s.rollbackOnError(ctx) + err = util.SyntaxError(err) + + // Only print log message when this SQL is from the user. + // Mute the warning for internal SQLs. + if !s.sessionVars.InRestrictedSQL { + if s.sessionVars.EnableRedactLog { + logutil.Logger(ctx).Debug("parse SQL failed", zap.Error(err), zap.String("SQL", sql)) + } else { + logutil.Logger(ctx).Warn("parse SQL failed", zap.Error(err), zap.String("SQL", sql)) + } + s.sessionVars.StmtCtx.AppendError(err) + } + return nil, err + } + + durParse := time.Since(parseStartTime) + s.GetSessionVars().DurationParse = durParse + isInternal := s.isInternal() + if isInternal { + session_metrics.SessionExecuteParseDurationInternal.Observe(durParse.Seconds()) + } else { + session_metrics.SessionExecuteParseDurationGeneral.Observe(durParse.Seconds()) + } + for _, warn := range warns { + s.sessionVars.StmtCtx.AppendWarning(util.SyntaxWarn(warn)) + } + return stmts, nil +} + +// ParseWithParams parses a query string, with arguments, to raw ast.StmtNode. +// Note that it will not do escaping if no variable arguments are passed. +func (s *session) ParseWithParams(ctx context.Context, sql string, args ...interface{}) (ast.StmtNode, error) { + var err error + if len(args) > 0 { + sql, err = sqlescape.EscapeSQL(sql, args...) + if err != nil { + return nil, err + } + } + + internal := s.isInternal() + + var stmts []ast.StmtNode + var warns []error + parseStartTime := time.Now() + if internal { + // Do no respect the settings from clients, if it is for internal usage. + // Charsets from clients may give chance injections. + // Refer to https://stackoverflow.com/questions/5741187/sql-injection-that-gets-around-mysql-real-escape-string/12118602. + stmts, warns, err = s.ParseSQL(ctx, sql) + } else { + stmts, warns, err = s.ParseSQL(ctx, sql, s.sessionVars.GetParseParams()...) + } + if len(stmts) != 1 && err == nil { + err = errors.New("run multiple statements internally is not supported") + } + if err != nil { + s.rollbackOnError(ctx) + logSQL := sql[:min(500, len(sql))] + if s.sessionVars.EnableRedactLog { + logutil.Logger(ctx).Debug("parse SQL failed", zap.Error(err), zap.String("SQL", logSQL)) + } else { + logutil.Logger(ctx).Warn("parse SQL failed", zap.Error(err), zap.String("SQL", logSQL)) + } + return nil, util.SyntaxError(err) + } + durParse := time.Since(parseStartTime) + if internal { + session_metrics.SessionExecuteParseDurationInternal.Observe(durParse.Seconds()) + } else { + session_metrics.SessionExecuteParseDurationGeneral.Observe(durParse.Seconds()) + } + for _, warn := range warns { + s.sessionVars.StmtCtx.AppendWarning(util.SyntaxWarn(warn)) + } + if topsqlstate.TopSQLEnabled() { + normalized, digest := parser.NormalizeDigest(sql) + if digest != nil { + // Reset the goroutine label when internal sql execute finish. + // Specifically reset in ExecRestrictedStmt function. + s.sessionVars.StmtCtx.IsSQLRegistered.Store(true) + topsql.AttachAndRegisterSQLInfo(ctx, normalized, digest, s.sessionVars.InRestrictedSQL) + } + } + return stmts[0], nil +} + +// GetAdvisoryLock acquires an advisory lock of lockName. +// Note that a lock can be acquired multiple times by the same session, +// in which case we increment a reference count. +// Each lock needs to be held in a unique session because +// we need to be able to ROLLBACK in any arbitrary order +// in order to release the locks. +func (s *session) GetAdvisoryLock(lockName string, timeout int64) error { + if lock, ok := s.advisoryLocks[lockName]; ok { + lock.IncrReferences() + return nil + } + sess, err := createSession(s.store) + if err != nil { + return err + } + infosync.StoreInternalSession(sess) + lock := &advisoryLock{session: sess, ctx: context.TODO(), owner: s.ShowProcess().ID} + err = lock.GetLock(lockName, timeout) + if err != nil { + return err + } + s.advisoryLocks[lockName] = lock + return nil +} + +// IsUsedAdvisoryLock checks if a lockName is already in use +func (s *session) IsUsedAdvisoryLock(lockName string) uint64 { + // Same session + if lock, ok := s.advisoryLocks[lockName]; ok { + return lock.owner + } + + // Check for transaction on advisory_locks table + sess, err := createSession(s.store) + if err != nil { + return 0 + } + lock := &advisoryLock{session: sess, ctx: context.TODO(), owner: s.ShowProcess().ID} + err = lock.IsUsedLock(lockName) + if err != nil { + // TODO: Return actual owner pid + // TODO: Check for mysql.ErrLockWaitTimeout and DeadLock + return 1 + } + return 0 +} + +// ReleaseAdvisoryLock releases an advisory locks held by the session. +// It returns FALSE if no lock by this name was held (by this session), +// and TRUE if a lock was held and "released". +// Note that the lock is not actually released if there are multiple +// references to the same lockName by the session, instead the reference +// count is decremented. +func (s *session) ReleaseAdvisoryLock(lockName string) (released bool) { + if lock, ok := s.advisoryLocks[lockName]; ok { + lock.DecrReferences() + if lock.ReferenceCount() <= 0 { + lock.Close() + delete(s.advisoryLocks, lockName) + infosync.DeleteInternalSession(lock.session) + } + return true + } + return false +} + +// ReleaseAllAdvisoryLocks releases all advisory locks held by the session +// and returns a count of the locks that were released. +// The count is based on unique locks held, so multiple references +// to the same lock do not need to be accounted for. +func (s *session) ReleaseAllAdvisoryLocks() int { + var count int + for lockName, lock := range s.advisoryLocks { + lock.Close() + count += lock.ReferenceCount() + delete(s.advisoryLocks, lockName) + infosync.DeleteInternalSession(lock.session) + } + return count +} + +// GetExtensions returns the `*extension.SessionExtensions` object +func (s *session) GetExtensions() *extension.SessionExtensions { + return s.extensions +} + +// SetExtensions sets the `*extension.SessionExtensions` object +func (s *session) SetExtensions(extensions *extension.SessionExtensions) { + s.extensions = extensions +} + +// InSandBoxMode indicates that this session is in sandbox mode +func (s *session) InSandBoxMode() bool { + return s.sandBoxMode +} + +// EnableSandBoxMode enable the sandbox mode. +func (s *session) EnableSandBoxMode() { + s.sandBoxMode = true +} + +// DisableSandBoxMode enable the sandbox mode. +func (s *session) DisableSandBoxMode() { + s.sandBoxMode = false +} + +// ParseWithParams4Test wrapper (s *session) ParseWithParams for test +func ParseWithParams4Test(ctx context.Context, s Session, + sql string, args ...interface{}) (ast.StmtNode, error) { + return s.(*session).ParseWithParams(ctx, sql, args) +} + +var _ sqlexec.RestrictedSQLExecutor = &session{} +var _ sqlexec.SQLExecutor = &session{} + +// ExecRestrictedStmt implements RestrictedSQLExecutor interface. +func (s *session) ExecRestrictedStmt(ctx context.Context, stmtNode ast.StmtNode, opts ...sqlexec.OptionFuncAlias) ( + []chunk.Row, []*ast.ResultField, error) { + defer pprof.SetGoroutineLabels(ctx) + execOption := sqlexec.GetExecOption(opts) + var se *session + var clean func() + var err error + if execOption.UseCurSession { + se, clean, err = s.useCurrentSession(execOption) + } else { + se, clean, err = s.getInternalSession(execOption) + } + if err != nil { + return nil, nil, err + } + defer clean() + + startTime := time.Now() + metrics.SessionRestrictedSQLCounter.Inc() + ctx = context.WithValue(ctx, execdetails.StmtExecDetailKey, &execdetails.StmtExecDetails{}) + ctx = context.WithValue(ctx, tikvutil.ExecDetailsKey, &tikvutil.ExecDetails{}) + rs, err := se.ExecuteStmt(ctx, stmtNode) + if err != nil { + se.sessionVars.StmtCtx.AppendError(err) + } + if rs == nil { + return nil, nil, err + } + defer func() { + if closeErr := rs.Close(); closeErr != nil { + err = closeErr + } + }() + var rows []chunk.Row + rows, err = drainRecordSet(ctx, se, rs, nil) + if err != nil { + return nil, nil, err + } + + vars := se.GetSessionVars() + for _, dbName := range GetDBNames(vars) { + metrics.QueryDurationHistogram.WithLabelValues(metrics.LblInternal, dbName, vars.ResourceGroupName).Observe(time.Since(startTime).Seconds()) + } + return rows, rs.Fields(), err +} + +// ExecRestrictedStmt4Test wrapper `(s *session) ExecRestrictedStmt` for test. +func ExecRestrictedStmt4Test(ctx context.Context, s Session, + stmtNode ast.StmtNode, opts ...sqlexec.OptionFuncAlias) ( + []chunk.Row, []*ast.ResultField, error) { + ctx = kv.WithInternalSourceType(ctx, kv.InternalTxnOthers) + return s.(*session).ExecRestrictedStmt(ctx, stmtNode, opts...) +} + +// only set and clean session with execOption +func (s *session) useCurrentSession(execOption sqlexec.ExecOption) (*session, func(), error) { + var err error + orgSnapshotInfoSchema, orgSnapshotTS := s.sessionVars.SnapshotInfoschema, s.sessionVars.SnapshotTS + if execOption.SnapshotTS != 0 { + if err = s.sessionVars.SetSystemVar(variable.TiDBSnapshot, strconv.FormatUint(execOption.SnapshotTS, 10)); err != nil { + return nil, nil, err + } + s.sessionVars.SnapshotInfoschema, err = getSnapshotInfoSchema(s, execOption.SnapshotTS) + if err != nil { + return nil, nil, err + } + } + prevStatsVer := s.sessionVars.AnalyzeVersion + if execOption.AnalyzeVer != 0 { + s.sessionVars.AnalyzeVersion = execOption.AnalyzeVer + } + prevAnalyzeSnapshot := s.sessionVars.EnableAnalyzeSnapshot + if execOption.AnalyzeSnapshot != nil { + s.sessionVars.EnableAnalyzeSnapshot = *execOption.AnalyzeSnapshot + } + prePruneMode := s.sessionVars.PartitionPruneMode.Load() + if len(execOption.PartitionPruneMode) > 0 { + s.sessionVars.PartitionPruneMode.Store(execOption.PartitionPruneMode) + } + prevSQL := s.sessionVars.StmtCtx.OriginalSQL + prevStmtType := s.sessionVars.StmtCtx.StmtType + prevTables := s.sessionVars.StmtCtx.Tables + return s, func() { + s.sessionVars.AnalyzeVersion = prevStatsVer + s.sessionVars.EnableAnalyzeSnapshot = prevAnalyzeSnapshot + if err := s.sessionVars.SetSystemVar(variable.TiDBSnapshot, ""); err != nil { + logutil.BgLogger().Error("set tidbSnapshot error", zap.Error(err)) + } + s.sessionVars.SnapshotInfoschema = orgSnapshotInfoSchema + s.sessionVars.SnapshotTS = orgSnapshotTS + s.sessionVars.PartitionPruneMode.Store(prePruneMode) + s.sessionVars.StmtCtx.OriginalSQL = prevSQL + s.sessionVars.StmtCtx.StmtType = prevStmtType + s.sessionVars.StmtCtx.Tables = prevTables + s.sessionVars.MemTracker.Detach() + }, nil +} + +func (s *session) getInternalSession(execOption sqlexec.ExecOption) (*session, func(), error) { + tmp, err := s.sysSessionPool().Get() + if err != nil { + return nil, nil, errors.Trace(err) + } + se := tmp.(*session) + + // The special session will share the `InspectionTableCache` with current session + // if the current session in inspection mode. + if cache := s.sessionVars.InspectionTableCache; cache != nil { + se.sessionVars.InspectionTableCache = cache + } + if ok := s.sessionVars.OptimizerUseInvisibleIndexes; ok { + se.sessionVars.OptimizerUseInvisibleIndexes = true + } + + if execOption.SnapshotTS != 0 { + if err := se.sessionVars.SetSystemVar(variable.TiDBSnapshot, strconv.FormatUint(execOption.SnapshotTS, 10)); err != nil { + return nil, nil, err + } + se.sessionVars.SnapshotInfoschema, err = getSnapshotInfoSchema(s, execOption.SnapshotTS) + if err != nil { + return nil, nil, err + } + } + + prevStatsVer := se.sessionVars.AnalyzeVersion + if execOption.AnalyzeVer != 0 { + se.sessionVars.AnalyzeVersion = execOption.AnalyzeVer + } + + prevAnalyzeSnapshot := se.sessionVars.EnableAnalyzeSnapshot + if execOption.AnalyzeSnapshot != nil { + se.sessionVars.EnableAnalyzeSnapshot = *execOption.AnalyzeSnapshot + } + + prePruneMode := se.sessionVars.PartitionPruneMode.Load() + if len(execOption.PartitionPruneMode) > 0 { + se.sessionVars.PartitionPruneMode.Store(execOption.PartitionPruneMode) + } + + return se, func() { + se.sessionVars.AnalyzeVersion = prevStatsVer + se.sessionVars.EnableAnalyzeSnapshot = prevAnalyzeSnapshot + if err := se.sessionVars.SetSystemVar(variable.TiDBSnapshot, ""); err != nil { + logutil.BgLogger().Error("set tidbSnapshot error", zap.Error(err)) + } + se.sessionVars.SnapshotInfoschema = nil + se.sessionVars.SnapshotTS = 0 + if !execOption.IgnoreWarning { + if se != nil && se.GetSessionVars().StmtCtx.WarningCount() > 0 { + warnings := se.GetSessionVars().StmtCtx.GetWarnings() + s.GetSessionVars().StmtCtx.AppendWarnings(warnings) + } + } + se.sessionVars.PartitionPruneMode.Store(prePruneMode) + se.sessionVars.OptimizerUseInvisibleIndexes = false + se.sessionVars.InspectionTableCache = nil + se.sessionVars.MemTracker.Detach() + s.sysSessionPool().Put(tmp) + }, nil +} + +func (s *session) withRestrictedSQLExecutor(ctx context.Context, opts []sqlexec.OptionFuncAlias, fn func(context.Context, *session) ([]chunk.Row, []*ast.ResultField, error)) ([]chunk.Row, []*ast.ResultField, error) { + execOption := sqlexec.GetExecOption(opts) + var se *session + var clean func() + var err error + if execOption.UseCurSession { + se, clean, err = s.useCurrentSession(execOption) + } else { + se, clean, err = s.getInternalSession(execOption) + } + if err != nil { + return nil, nil, errors.Trace(err) + } + defer clean() + if execOption.TrackSysProcID > 0 { + err = execOption.TrackSysProc(execOption.TrackSysProcID, se) + if err != nil { + return nil, nil, errors.Trace(err) + } + // unTrack should be called before clean (return sys session) + defer execOption.UnTrackSysProc(execOption.TrackSysProcID) + } + return fn(ctx, se) +} + +func (s *session) ExecRestrictedSQL(ctx context.Context, opts []sqlexec.OptionFuncAlias, sql string, params ...interface{}) ([]chunk.Row, []*ast.ResultField, error) { + return s.withRestrictedSQLExecutor(ctx, opts, func(ctx context.Context, se *session) ([]chunk.Row, []*ast.ResultField, error) { + stmt, err := se.ParseWithParams(ctx, sql, params...) + if err != nil { + return nil, nil, errors.Trace(err) + } + defer pprof.SetGoroutineLabels(ctx) + startTime := time.Now() + metrics.SessionRestrictedSQLCounter.Inc() + ctx = context.WithValue(ctx, execdetails.StmtExecDetailKey, &execdetails.StmtExecDetails{}) + ctx = context.WithValue(ctx, tikvutil.ExecDetailsKey, &tikvutil.ExecDetails{}) + rs, err := se.ExecuteInternalStmt(ctx, stmt) + if err != nil { + se.sessionVars.StmtCtx.AppendError(err) + } + if rs == nil { + return nil, nil, err + } + defer func() { + if closeErr := rs.Close(); closeErr != nil { + err = closeErr + } + }() + var rows []chunk.Row + rows, err = drainRecordSet(ctx, se, rs, nil) + if err != nil { + return nil, nil, err + } + + vars := se.GetSessionVars() + for _, dbName := range GetDBNames(vars) { + metrics.QueryDurationHistogram.WithLabelValues(metrics.LblInternal, dbName, vars.ResourceGroupName).Observe(time.Since(startTime).Seconds()) + } + return rows, rs.Fields(), err + }) +} + +// ExecuteInternalStmt execute internal stmt +func (s *session) ExecuteInternalStmt(ctx context.Context, stmtNode ast.StmtNode) (sqlexec.RecordSet, error) { + origin := s.sessionVars.InRestrictedSQL + s.sessionVars.InRestrictedSQL = true + defer func() { + s.sessionVars.InRestrictedSQL = origin + // Restore the goroutine label by using the original ctx after execution is finished. + pprof.SetGoroutineLabels(ctx) + }() + return s.ExecuteStmt(ctx, stmtNode) +} + +func (s *session) ExecuteStmt(ctx context.Context, stmtNode ast.StmtNode) (sqlexec.RecordSet, error) { + r, ctx := tracing.StartRegionEx(ctx, "session.ExecuteStmt") + defer r.End() + + if err := s.PrepareTxnCtx(ctx); err != nil { + return nil, err + } + if err := s.loadCommonGlobalVariablesIfNeeded(); err != nil { + return nil, err + } + + sessVars := s.sessionVars + sessVars.StartTime = time.Now() + + // Some executions are done in compile stage, so we reset them before compile. + if err := executor.ResetContextOfStmt(s, stmtNode); err != nil { + return nil, err + } + normalizedSQL, digest := s.sessionVars.StmtCtx.SQLDigest() + cmdByte := byte(atomic.LoadUint32(&s.GetSessionVars().CommandValue)) + if topsqlstate.TopSQLEnabled() { + s.sessionVars.StmtCtx.IsSQLRegistered.Store(true) + ctx = topsql.AttachAndRegisterSQLInfo(ctx, normalizedSQL, digest, s.sessionVars.InRestrictedSQL) + } + if sessVars.InPlanReplayer { + sessVars.StmtCtx.EnableOptimizerDebugTrace = true + } else if dom := domain.GetDomain(s); dom != nil && !sessVars.InRestrictedSQL { + // This is the earliest place we can get the SQL digest for this execution. + // If we find this digest is registered for PLAN REPLAYER CAPTURE, we need to enable optimizer debug trace no matter + // the plan digest will be matched or not. + if planReplayerHandle := dom.GetPlanReplayerHandle(); planReplayerHandle != nil { + tasks := planReplayerHandle.GetTasks() + for _, task := range tasks { + if task.SQLDigest == digest.String() { + sessVars.StmtCtx.EnableOptimizerDebugTrace = true + } + } + } + } + if sessVars.StmtCtx.EnableOptimizerDebugTrace { + plannercore.DebugTraceReceivedCommand(s, cmdByte, stmtNode) + } + + if err := s.validateStatementInTxn(stmtNode); err != nil { + return nil, err + } + + if err := s.validateStatementReadOnlyInStaleness(stmtNode); err != nil { + return nil, err + } + + // Uncorrelated subqueries will execute once when building plan, so we reset process info before building plan. + s.currentPlan = nil // reset current plan + s.SetProcessInfo(stmtNode.Text(), time.Now(), cmdByte, 0) + s.txn.onStmtStart(digest.String()) + defer sessiontxn.GetTxnManager(s).OnStmtEnd() + defer s.txn.onStmtEnd() + + if err := s.onTxnManagerStmtStartOrRetry(ctx, stmtNode); err != nil { + return nil, err + } + + failpoint.Inject("mockStmtSlow", func(val failpoint.Value) { + if strings.Contains(stmtNode.Text(), "/* sleep */") { + v, _ := val.(int) + time.Sleep(time.Duration(v) * time.Millisecond) + } + }) + + var stmtLabel string + if execStmt, ok := stmtNode.(*ast.ExecuteStmt); ok { + prepareStmt, err := plannercore.GetPreparedStmt(execStmt, s.sessionVars) + if err == nil && prepareStmt.PreparedAst != nil { + stmtLabel = ast.GetStmtLabel(prepareStmt.PreparedAst.Stmt) + } + } + if stmtLabel == "" { + stmtLabel = ast.GetStmtLabel(stmtNode) + } + s.setRequestSource(ctx, stmtLabel, stmtNode) + + // Backup the original resource group name since sql hint might change it during optimization + originalResourceGroup := s.GetSessionVars().ResourceGroupName + + // Transform abstract syntax tree to a physical plan(stored in executor.ExecStmt). + compiler := executor.Compiler{Ctx: s} + stmt, err := compiler.Compile(ctx, stmtNode) + // session resource-group might be changed by query hint, ensure restore it back when + // the execution finished. + if sessVars.ResourceGroupName != originalResourceGroup { + // if target resource group doesn't exist, fallback to the origin resource group. + if _, ok := domain.GetDomain(s).InfoSchema().ResourceGroupByName(model.NewCIStr(sessVars.ResourceGroupName)); !ok { + logutil.Logger(ctx).Warn("Unknown resource group from hint", zap.String("name", sessVars.ResourceGroupName)) + sessVars.ResourceGroupName = originalResourceGroup + // if we are in a txn, should also reset the txn resource group. + if txn, err := s.Txn(false); err == nil && txn != nil && txn.Valid() { + kv.SetTxnResourceGroup(txn, originalResourceGroup) + } + } else { + defer func() { + // Restore the resource group for the session + sessVars.ResourceGroupName = originalResourceGroup + }() + } + } + if err != nil { + s.rollbackOnError(ctx) + + // Only print log message when this SQL is from the user. + // Mute the warning for internal SQLs. + if !s.sessionVars.InRestrictedSQL { + if !variable.ErrUnknownSystemVar.Equal(err) { + sql := stmtNode.Text() + if s.sessionVars.EnableRedactLog { + sql = parser.Normalize(sql) + } + logutil.Logger(ctx).Warn("compile SQL failed", zap.Error(err), + zap.String("SQL", sql)) + } + } + return nil, err + } + + durCompile := time.Since(s.sessionVars.StartTime) + s.GetSessionVars().DurationCompile = durCompile + if s.isInternal() { + session_metrics.SessionExecuteCompileDurationInternal.Observe(durCompile.Seconds()) + } else { + session_metrics.SessionExecuteCompileDurationGeneral.Observe(durCompile.Seconds()) + } + s.currentPlan = stmt.Plan + if execStmt, ok := stmtNode.(*ast.ExecuteStmt); ok { + if execStmt.Name == "" { + // for exec-stmt on bin-protocol, ignore the plan detail in `show process` to gain performance benefits. + s.currentPlan = nil + } + } + + // Execute the physical plan. + logStmt(stmt, s) + + var recordSet sqlexec.RecordSet + if stmt.PsStmt != nil { // point plan short path + recordSet, err = stmt.PointGet(ctx) + s.txn.changeToInvalid() + } else { + recordSet, err = runStmt(ctx, s, stmt) + } + + // Observe the resource group query total counter if the resource control is enabled and the + // current session is attached with a resource group. + resourceGroupName := s.GetSessionVars().ResourceGroupName + if len(resourceGroupName) > 0 { + metrics.ResourceGroupQueryTotalCounter.WithLabelValues(resourceGroupName).Inc() + } + + if err != nil { + if !errIsNoisy(err) { + logutil.Logger(ctx).Warn("run statement failed", + zap.Int64("schemaVersion", s.GetInfoSchema().SchemaMetaVersion()), + zap.Error(err), + zap.String("session", s.String())) + } + return recordSet, err + } + if !s.isInternal() && config.GetGlobalConfig().EnableTelemetry { + telemetry.CurrentExecuteCount.Inc() + tiFlashPushDown, tiFlashExchangePushDown := plannercore.IsTiFlashContained(stmt.Plan) + if tiFlashPushDown { + telemetry.CurrentTiFlashPushDownCount.Inc() + } + if tiFlashExchangePushDown { + telemetry.CurrentTiFlashExchangePushDownCount.Inc() + } + } + return recordSet, nil +} + +func (s *session) onTxnManagerStmtStartOrRetry(ctx context.Context, node ast.StmtNode) error { + if s.sessionVars.RetryInfo.Retrying { + return sessiontxn.GetTxnManager(s).OnStmtRetry(ctx) + } + return sessiontxn.GetTxnManager(s).OnStmtStart(ctx, node) +} + +func (s *session) validateStatementInTxn(stmtNode ast.StmtNode) error { + vars := s.GetSessionVars() + if _, ok := stmtNode.(*ast.ImportIntoStmt); ok && vars.InTxn() { + return errors.New("cannot run IMPORT INTO in explicit transaction") + } + return nil +} + +func (s *session) validateStatementReadOnlyInStaleness(stmtNode ast.StmtNode) error { + vars := s.GetSessionVars() + if !vars.TxnCtx.IsStaleness && vars.TxnReadTS.PeakTxnReadTS() == 0 && !vars.EnableExternalTSRead || vars.InRestrictedSQL { + return nil + } + errMsg := "only support read-only statement during read-only staleness transactions" + node := stmtNode.(ast.Node) + switch v := node.(type) { + case *ast.SplitRegionStmt: + return nil + case *ast.SelectStmt: + // select lock statement needs start a transaction which will be conflict to stale read, + // we forbid select lock statement in stale read for now. + if v.LockInfo != nil { + return errors.New("select lock hasn't been supported in stale read yet") + } + if !planner.IsReadOnly(stmtNode, vars) { + return errors.New(errMsg) + } + return nil + case *ast.ExplainStmt, *ast.DoStmt, *ast.ShowStmt, *ast.SetOprStmt, *ast.ExecuteStmt, *ast.SetOprSelectList: + if !planner.IsReadOnly(stmtNode, vars) { + return errors.New(errMsg) + } + return nil + default: + } + // covered DeleteStmt/InsertStmt/UpdateStmt/CallStmt/LoadDataStmt + if _, ok := stmtNode.(ast.DMLNode); ok { + return errors.New(errMsg) + } + return nil +} + +// fileTransInConnKeys contains the keys of queries that will be handled by handleFileTransInConn. +var fileTransInConnKeys = []fmt.Stringer{ + executor.LoadDataVarKey, + executor.LoadStatsVarKey, + executor.IndexAdviseVarKey, + executor.PlanReplayerLoadVarKey, +} + +func (s *session) hasFileTransInConn() bool { + s.mu.RLock() + defer s.mu.RUnlock() + + for _, k := range fileTransInConnKeys { + v := s.mu.values[k] + if v != nil { + return true + } + } + return false +} + +// runStmt executes the sqlexec.Statement and commit or rollback the current transaction. +func runStmt(ctx context.Context, se *session, s sqlexec.Statement) (rs sqlexec.RecordSet, err error) { + failpoint.Inject("assertTxnManagerInRunStmt", func() { + sessiontxn.RecordAssert(se, "assertTxnManagerInRunStmt", true) + if stmt, ok := s.(*executor.ExecStmt); ok { + sessiontxn.AssertTxnManagerInfoSchema(se, stmt.InfoSchema) + } + }) + + r, ctx := tracing.StartRegionEx(ctx, "session.runStmt") + defer r.End() + if r.Span != nil { + r.Span.LogKV("sql", s.OriginText()) + } + + se.SetValue(sessionctx.QueryString, s.OriginText()) + if _, ok := s.(*executor.ExecStmt).StmtNode.(ast.DDLNode); ok { + se.SetValue(sessionctx.LastExecuteDDL, true) + } else { + se.ClearValue(sessionctx.LastExecuteDDL) + } + + sessVars := se.sessionVars + + // Record diagnostic information for DML statements + if stmt, ok := s.(*executor.ExecStmt).StmtNode.(ast.DMLNode); ok { + // Keep the previous queryInfo for `show session_states` because the statement needs to encode it. + if showStmt, ok := stmt.(*ast.ShowStmt); !ok || showStmt.Tp != ast.ShowSessionStates { + defer func() { + sessVars.LastQueryInfo = sessionstates.QueryInfo{ + TxnScope: sessVars.CheckAndGetTxnScope(), + StartTS: sessVars.TxnCtx.StartTS, + ForUpdateTS: sessVars.TxnCtx.GetForUpdateTS(), + } + if err != nil { + sessVars.LastQueryInfo.ErrMsg = err.Error() + } + }() + } + } + + // Save origTxnCtx here to avoid it reset in the transaction retry. + origTxnCtx := sessVars.TxnCtx + err = se.checkTxnAborted(s) + if err != nil { + return nil, err + } + if sessVars.TxnCtx.CouldRetry && !s.IsReadOnly(sessVars) { + // Only when the txn is could retry and the statement is not read only, need to do stmt-count-limit check, + // otherwise, the stmt won't be add into stmt history, and also don't need check. + // About `stmt-count-limit`, see more in https://docs.pingcap.com/tidb/stable/tidb-configuration-file#stmt-count-limit + if err := checkStmtLimit(ctx, se, false); err != nil { + return nil, err + } + } + + rs, err = s.Exec(ctx) + se.updateTelemetryMetric(s.(*executor.ExecStmt)) + sessVars.TxnCtx.StatementCount++ + if rs != nil { + if se.GetSessionVars().StmtCtx.IsExplainAnalyzeDML { + if !sessVars.InTxn() { + se.StmtCommit(ctx) + if err := se.CommitTxn(ctx); err != nil { + return nil, err + } + } + } + return &execStmtResult{ + RecordSet: rs, + sql: s, + se: se, + }, err + } + + err = finishStmt(ctx, se, err, s) + if se.hasFileTransInConn() { + // The query will be handled later in handleFileTransInConn, + // then should call the ExecStmt.FinishExecuteStmt to finish this statement. + se.SetValue(ExecStmtVarKey, s.(*executor.ExecStmt)) + } else { + // If it is not a select statement or special query, we record its slow log here, + // then it could include the transaction commit time. + s.(*executor.ExecStmt).FinishExecuteStmt(origTxnCtx.StartTS, err, false) + } + return nil, err +} + +// ExecStmtVarKeyType is a dummy type to avoid naming collision in context. +type ExecStmtVarKeyType int + +// String defines a Stringer function for debugging and pretty printing. +func (ExecStmtVarKeyType) String() string { + return "exec_stmt_var_key" +} + +// ExecStmtVarKey is a variable key for ExecStmt. +const ExecStmtVarKey ExecStmtVarKeyType = 0 + +// execStmtResult is the return value of ExecuteStmt and it implements the sqlexec.RecordSet interface. +// Why we need a struct to wrap a RecordSet and provide another RecordSet? +// This is because there are so many session state related things that definitely not belongs to the original +// RecordSet, so this struct exists and RecordSet.Close() is overrided handle that. +type execStmtResult struct { + sqlexec.RecordSet + se *session + sql sqlexec.Statement +} + +func (rs *execStmtResult) Close() error { + se := rs.se + if err := rs.RecordSet.Close(); err != nil { + return finishStmt(context.Background(), se, err, rs.sql) + } + return finishStmt(context.Background(), se, nil, rs.sql) +} + +// rollbackOnError makes sure the next statement starts a new transaction with the latest InfoSchema. +func (s *session) rollbackOnError(ctx context.Context) { + if !s.sessionVars.InTxn() { + s.RollbackTxn(ctx) + } +} + +// PrepareStmt is used for executing prepare statement in binary protocol +func (s *session) PrepareStmt(sql string) (stmtID uint32, paramCount int, fields []*ast.ResultField, err error) { + defer func() { + if s.sessionVars.StmtCtx != nil { + s.sessionVars.StmtCtx.DetachMemDiskTracker() + } + }() + if s.sessionVars.TxnCtx.InfoSchema == nil { + // We don't need to create a transaction for prepare statement, just get information schema will do. + s.sessionVars.TxnCtx.InfoSchema = domain.GetDomain(s).InfoSchema() + } + err = s.loadCommonGlobalVariablesIfNeeded() + if err != nil { + return + } + + ctx := context.Background() + // NewPrepareExec may need startTS to build the executor, for example prepare statement has subquery in int. + // So we have to call PrepareTxnCtx here. + if err = s.PrepareTxnCtx(ctx); err != nil { + return + } + + prepareStmt := &ast.PrepareStmt{SQLText: sql} + if err = s.onTxnManagerStmtStartOrRetry(ctx, prepareStmt); err != nil { + return + } + + if err = sessiontxn.GetTxnManager(s).AdviseWarmup(); err != nil { + return + } + prepareExec := executor.NewPrepareExec(s, sql) + err = prepareExec.Next(ctx, nil) + // Rollback even if err is nil. + s.rollbackOnError(ctx) + + if err != nil { + return + } + return prepareExec.ID, prepareExec.ParamCount, prepareExec.Fields, nil +} + +// ExecutePreparedStmt executes a prepared statement. +func (s *session) ExecutePreparedStmt(ctx context.Context, stmtID uint32, params []expression.Expression) (sqlexec.RecordSet, error) { + prepStmt, err := s.sessionVars.GetPreparedStmtByID(stmtID) + if err != nil { + err = plannercore.ErrStmtNotFound + logutil.Logger(ctx).Error("prepared statement not found", zap.Uint32("stmtID", stmtID)) + return nil, err + } + stmt, ok := prepStmt.(*plannercore.PlanCacheStmt) + if !ok { + return nil, errors.Errorf("invalid PlanCacheStmt type") + } + execStmt := &ast.ExecuteStmt{ + BinaryArgs: params, + PrepStmt: stmt, + } + return s.ExecuteStmt(ctx, execStmt) +} + +func (s *session) DropPreparedStmt(stmtID uint32) error { + vars := s.sessionVars + if _, ok := vars.PreparedStmts[stmtID]; !ok { + return plannercore.ErrStmtNotFound + } + vars.RetryInfo.DroppedPreparedStmtIDs = append(vars.RetryInfo.DroppedPreparedStmtIDs, stmtID) + return nil +} + +func (s *session) Txn(active bool) (kv.Transaction, error) { + if !active { + return &s.txn, nil + } + _, err := sessiontxn.GetTxnManager(s).ActivateTxn() + s.SetMemoryFootprintChangeHook() + return &s.txn, err +} + +func (s *session) SetValue(key fmt.Stringer, value interface{}) { + s.mu.Lock() + s.mu.values[key] = value + s.mu.Unlock() +} + +func (s *session) Value(key fmt.Stringer) interface{} { + s.mu.RLock() + value := s.mu.values[key] + s.mu.RUnlock() + return value +} + +func (s *session) ClearValue(key fmt.Stringer) { + s.mu.Lock() + delete(s.mu.values, key) + s.mu.Unlock() +} + +type inCloseSession struct{} + +// Close function does some clean work when session end. +// Close should release the table locks which hold by the session. +func (s *session) Close() { + // TODO: do clean table locks when session exited without execute Close. + // TODO: do clean table locks when tidb-server was `kill -9`. + if s.HasLockedTables() && config.TableLockEnabled() { + if ds := config.TableLockDelayClean(); ds > 0 { + time.Sleep(time.Duration(ds) * time.Millisecond) + } + lockedTables := s.GetAllTableLocks() + err := domain.GetDomain(s).DDL().UnlockTables(s, lockedTables) + if err != nil { + logutil.BgLogger().Error("release table lock failed", zap.Uint64("conn", s.sessionVars.ConnectionID)) + } + } + s.ReleaseAllAdvisoryLocks() + if s.statsCollector != nil { + s.statsCollector.Delete() + } + if s.idxUsageCollector != nil { + s.idxUsageCollector.Delete() + } + telemetry.GlobalBuiltinFunctionsUsage.Collect(s.GetBuiltinFunctionUsage()) + bindValue := s.Value(bindinfo.SessionBindInfoKeyType) + if bindValue != nil { + bindValue.(*bindinfo.SessionHandle).Close() + } + ctx := context.WithValue(context.TODO(), inCloseSession{}, struct{}{}) + s.RollbackTxn(ctx) + if s.sessionVars != nil { + s.sessionVars.WithdrawAllPreparedStmt() + } + if s.stmtStats != nil { + s.stmtStats.SetFinished() + } + s.ClearDiskFullOpt() + if s.sessionPlanCache != nil { + s.sessionPlanCache.Close() + } +} + +// GetSessionVars implements the context.Context interface. +func (s *session) GetSessionVars() *variable.SessionVars { + return s.sessionVars +} + +func (s *session) AuthPluginForUser(user *auth.UserIdentity) (string, error) { + pm := privilege.GetPrivilegeManager(s) + authplugin, err := pm.GetAuthPluginForConnection(user.Username, user.Hostname) + if err != nil { + return "", err + } + return authplugin, nil +} + +// Auth validates a user using an authentication string and salt. +// If the password fails, it will keep trying other users until exhausted. +// This means it can not be refactored to use MatchIdentity yet. +func (s *session) Auth(user *auth.UserIdentity, authentication, salt []byte, authConn conn.AuthConn) error { + hasPassword := "YES" + if len(authentication) == 0 { + hasPassword = "NO" + } + pm := privilege.GetPrivilegeManager(s) + authUser, err := s.MatchIdentity(user.Username, user.Hostname) + if err != nil { + return privileges.ErrAccessDenied.FastGenByArgs(user.Username, user.Hostname, hasPassword) + } + // Check whether continuous login failure is enabled to lock the account. + // If enabled, determine whether to unlock the account and notify TiDB to update the cache. + enableAutoLock := pm.IsAccountAutoLockEnabled(authUser.Username, authUser.Hostname) + if enableAutoLock { + err = failedLoginTrackingBegin(s) + if err != nil { + return err + } + lockStatusChanged, err := verifyAccountAutoLock(s, authUser.Username, authUser.Hostname) + if err != nil { + rollbackErr := failedLoginTrackingRollback(s) + if rollbackErr != nil { + return rollbackErr + } + return err + } + err = failedLoginTrackingCommit(s) + if err != nil { + rollbackErr := failedLoginTrackingRollback(s) + if rollbackErr != nil { + return rollbackErr + } + return err + } + if lockStatusChanged { + // Notification auto unlock. + err = domain.GetDomain(s).NotifyUpdatePrivilege() + if err != nil { + return err + } + } + } + + info, err := pm.ConnectionVerification(user, authUser.Username, authUser.Hostname, authentication, salt, s.sessionVars, authConn) + if err != nil { + if info.FailedDueToWrongPassword { + // when user enables the account locking function for consecutive login failures, + // the system updates the login failure count and determines whether to lock the account when authentication fails. + if enableAutoLock { + err := failedLoginTrackingBegin(s) + if err != nil { + return err + } + lockStatusChanged, passwordLocking, trackingErr := authFailedTracking(s, authUser.Username, authUser.Hostname) + if trackingErr != nil { + if rollBackErr := failedLoginTrackingRollback(s); rollBackErr != nil { + return rollBackErr + } + return trackingErr + } + if err := failedLoginTrackingCommit(s); err != nil { + if rollBackErr := failedLoginTrackingRollback(s); rollBackErr != nil { + return rollBackErr + } + return err + } + if lockStatusChanged { + // Notification auto lock. + err := autolockAction(s, passwordLocking, authUser.Username, authUser.Hostname) + if err != nil { + return err + } + } + } + } + return err + } + + if variable.EnableResourceControl.Load() && info.ResourceGroupName != "" { + s.sessionVars.ResourceGroupName = strings.ToLower(info.ResourceGroupName) + } + + if info.InSandBoxMode { + // Enter sandbox mode, only execute statement for resetting password. + s.EnableSandBoxMode() + } + if enableAutoLock { + err := failedLoginTrackingBegin(s) + if err != nil { + return err + } + // The password is correct. If the account is not locked, the number of login failure statistics will be cleared. + err = authSuccessClearCount(s, authUser.Username, authUser.Hostname) + if err != nil { + if rollBackErr := failedLoginTrackingRollback(s); rollBackErr != nil { + return rollBackErr + } + return err + } + err = failedLoginTrackingCommit(s) + if err != nil { + if rollBackErr := failedLoginTrackingRollback(s); rollBackErr != nil { + return rollBackErr + } + return err + } + } + pm.AuthSuccess(authUser.Username, authUser.Hostname) + user.AuthUsername = authUser.Username + user.AuthHostname = authUser.Hostname + s.sessionVars.User = user + s.sessionVars.ActiveRoles = pm.GetDefaultRoles(user.AuthUsername, user.AuthHostname) + return nil +} + +func authSuccessClearCount(s *session, user string, host string) error { + // Obtain accurate lock status and failure count information. + passwordLocking, err := getFailedLoginUserAttributes(s, user, host) + if err != nil { + return err + } + // If the account is locked, it may be caused by the untimely update of the cache, + // directly report the account lock. + if passwordLocking.AutoAccountLocked { + if passwordLocking.PasswordLockTimeDays == -1 { + return privileges.GenerateAccountAutoLockErr(passwordLocking.FailedLoginAttempts, user, host, + "unlimited", "unlimited") + } + + lds := strconv.FormatInt(passwordLocking.PasswordLockTimeDays, 10) + return privileges.GenerateAccountAutoLockErr(passwordLocking.FailedLoginAttempts, user, host, lds, lds) + } + if passwordLocking.FailedLoginCount != 0 { + // If the number of account login failures is not zero, it will be updated to 0. + passwordLockingJSON := privileges.BuildSuccessPasswordLockingJSON(passwordLocking.FailedLoginAttempts, + passwordLocking.PasswordLockTimeDays) + if passwordLockingJSON != "" { + if err := s.passwordLocking(user, host, passwordLockingJSON); err != nil { + return err + } + } + } + return nil +} + +func verifyAccountAutoLock(s *session, user, host string) (bool, error) { + pm := privilege.GetPrivilegeManager(s) + // Use the cache to determine whether to unlock the account. + // If the account needs to be unlocked, read the database information to determine whether + // the account needs to be unlocked. Otherwise, an error message is displayed. + lockStatusInMemory, err := pm.VerifyAccountAutoLockInMemory(user, host) + if err != nil { + return false, err + } + // If the lock status in the cache is Unlock, the automatic unlock is skipped. + // If memory synchronization is slow and there is a lock in the database, it will be processed upon successful login. + if !lockStatusInMemory { + return false, nil + } + lockStatusChanged := false + var plJSON string + // After checking the cache, obtain the latest data from the database and determine + // whether to automatically unlock the database to prevent repeated unlock errors. + pl, err := getFailedLoginUserAttributes(s, user, host) + if err != nil { + return false, err + } + if pl.AutoAccountLocked { + // If it is locked, need to check whether it can be automatically unlocked. + lockTimeDay := pl.PasswordLockTimeDays + if lockTimeDay == -1 { + return false, privileges.GenerateAccountAutoLockErr(pl.FailedLoginAttempts, user, host, "unlimited", "unlimited") + } + lastChanged := pl.AutoLockedLastChanged + d := time.Now().Unix() - lastChanged + if d <= lockTimeDay*24*60*60 { + lds := strconv.FormatInt(lockTimeDay, 10) + rds := strconv.FormatInt(int64(math.Ceil(float64(lockTimeDay)-float64(d)/(24*60*60))), 10) + return false, privileges.GenerateAccountAutoLockErr(pl.FailedLoginAttempts, user, host, lds, rds) + } + // Generate unlock json string. + plJSON = privileges.BuildPasswordLockingJSON(pl.FailedLoginAttempts, + pl.PasswordLockTimeDays, "N", 0, time.Now().Format(time.UnixDate)) + } + if plJSON != "" { + lockStatusChanged = true + if err = s.passwordLocking(user, host, plJSON); err != nil { + return false, err + } + } + return lockStatusChanged, nil +} + +func authFailedTracking(s *session, user string, host string) (bool, *privileges.PasswordLocking, error) { + // Obtain the number of consecutive password login failures. + passwordLocking, err := getFailedLoginUserAttributes(s, user, host) + if err != nil { + return false, nil, err + } + // Consecutive wrong password login failure times +1, + // If the lock condition is satisfied, the lock status is updated and the update cache is notified. + lockStatusChanged, err := userAutoAccountLocked(s, user, host, passwordLocking) + if err != nil { + return false, nil, err + } + return lockStatusChanged, passwordLocking, nil +} + +func autolockAction(s *session, passwordLocking *privileges.PasswordLocking, user, host string) error { + // Don't want to update the cache frequently, and only trigger the update cache when the lock status is updated. + err := domain.GetDomain(s).NotifyUpdatePrivilege() + if err != nil { + return err + } + // The number of failed login attempts reaches FAILED_LOGIN_ATTEMPTS. + // An error message is displayed indicating permission denial and account lock. + if passwordLocking.PasswordLockTimeDays == -1 { + return privileges.GenerateAccountAutoLockErr(passwordLocking.FailedLoginAttempts, user, host, + "unlimited", "unlimited") + } + lds := strconv.FormatInt(passwordLocking.PasswordLockTimeDays, 10) + return privileges.GenerateAccountAutoLockErr(passwordLocking.FailedLoginAttempts, user, host, lds, lds) +} + +func (s *session) passwordLocking(user string, host string, newAttributesStr string) error { + sql := new(strings.Builder) + sqlescape.MustFormatSQL(sql, "UPDATE %n.%n SET ", mysql.SystemDB, mysql.UserTable) + sqlescape.MustFormatSQL(sql, "user_attributes=json_merge_patch(coalesce(user_attributes, '{}'), %?)", newAttributesStr) + sqlescape.MustFormatSQL(sql, " WHERE Host=%? and User=%?;", host, user) + ctx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnPrivilege) + _, err := s.ExecuteInternal(ctx, sql.String()) + return err +} + +func failedLoginTrackingBegin(s *session) error { + ctx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnPrivilege) + _, err := s.ExecuteInternal(ctx, "BEGIN PESSIMISTIC") + return err +} + +func failedLoginTrackingCommit(s *session) error { + ctx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnPrivilege) + _, err := s.ExecuteInternal(ctx, "COMMIT") + if err != nil { + _, rollBackErr := s.ExecuteInternal(ctx, "ROLLBACK") + if rollBackErr != nil { + return rollBackErr + } + } + return err +} + +func failedLoginTrackingRollback(s *session) error { + ctx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnPrivilege) + _, err := s.ExecuteInternal(ctx, "ROLLBACK") + return err +} + +// getFailedLoginUserAttributes queries the exact number of consecutive password login failures (concurrency is not allowed). +func getFailedLoginUserAttributes(s *session, user string, host string) (*privileges.PasswordLocking, error) { + passwordLocking := &privileges.PasswordLocking{} + ctx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnPrivilege) + rs, err := s.ExecuteInternal(ctx, `SELECT user_attributes from mysql.user WHERE USER = %? AND HOST = %? for update`, user, host) + if err != nil { + return passwordLocking, err + } + defer func() { + if closeErr := rs.Close(); closeErr != nil { + err = closeErr + } + }() + req := rs.NewChunk(nil) + iter := chunk.NewIterator4Chunk(req) + err = rs.Next(ctx, req) + if err != nil { + return passwordLocking, err + } + if req.NumRows() == 0 { + return passwordLocking, fmt.Errorf("user_attributes by `%s`@`%s` not found", user, host) + } + row := iter.Begin() + if !row.IsNull(0) { + passwordLockingJSON := row.GetJSON(0) + return passwordLocking, passwordLocking.ParseJSON(passwordLockingJSON) + } + return passwordLocking, fmt.Errorf("user_attributes by `%s`@`%s` not found", user, host) +} + +func userAutoAccountLocked(s *session, user string, host string, pl *privileges.PasswordLocking) (bool, error) { + // Indicates whether the user needs to update the lock status change. + lockStatusChanged := false + // The number of consecutive login failures is stored in the database. + // If the current login fails, one is added to the number of consecutive login failures + // stored in the database to determine whether the user needs to be locked and the number of update failures. + failedLoginCount := pl.FailedLoginCount + 1 + // If the cache is not updated, but it is already locked, it will report that the account is locked. + if pl.AutoAccountLocked { + if pl.PasswordLockTimeDays == -1 { + return false, privileges.GenerateAccountAutoLockErr(pl.FailedLoginAttempts, user, host, + "unlimited", "unlimited") + } + lds := strconv.FormatInt(pl.PasswordLockTimeDays, 10) + return false, privileges.GenerateAccountAutoLockErr(pl.FailedLoginAttempts, user, host, lds, lds) + } + + autoAccountLocked := "N" + autoLockedLastChanged := "" + if pl.FailedLoginAttempts == 0 || pl.PasswordLockTimeDays == 0 { + return false, nil + } + + if failedLoginCount >= pl.FailedLoginAttempts { + autoLockedLastChanged = time.Now().Format(time.UnixDate) + autoAccountLocked = "Y" + lockStatusChanged = true + } + + newAttributesStr := privileges.BuildPasswordLockingJSON(pl.FailedLoginAttempts, + pl.PasswordLockTimeDays, autoAccountLocked, failedLoginCount, autoLockedLastChanged) + if newAttributesStr != "" { + return lockStatusChanged, s.passwordLocking(user, host, newAttributesStr) + } + return lockStatusChanged, nil +} + +// MatchIdentity finds the matching username + password in the MySQL privilege tables +// for a username + hostname, since MySQL can have wildcards. +func (s *session) MatchIdentity(username, remoteHost string) (*auth.UserIdentity, error) { + pm := privilege.GetPrivilegeManager(s) + var success bool + var skipNameResolve bool + var user = &auth.UserIdentity{} + varVal, err := s.GetSessionVars().GlobalVarsAccessor.GetGlobalSysVar(variable.SkipNameResolve) + if err == nil && variable.TiDBOptOn(varVal) { + skipNameResolve = true + } + user.Username, user.Hostname, success = pm.MatchIdentity(username, remoteHost, skipNameResolve) + if success { + return user, nil + } + // This error will not be returned to the user, access denied will be instead + return nil, fmt.Errorf("could not find matching user in MatchIdentity: %s, %s", username, remoteHost) +} + +// AuthWithoutVerification is required by the ResetConnection RPC +func (s *session) AuthWithoutVerification(user *auth.UserIdentity) bool { + pm := privilege.GetPrivilegeManager(s) + authUser, err := s.MatchIdentity(user.Username, user.Hostname) + if err != nil { + return false + } + if pm.GetAuthWithoutVerification(authUser.Username, authUser.Hostname) { + user.AuthUsername = authUser.Username + user.AuthHostname = authUser.Hostname + s.sessionVars.User = user + s.sessionVars.ActiveRoles = pm.GetDefaultRoles(user.AuthUsername, user.AuthHostname) + return true + } + return false +} + +// SetSessionStatesHandler implements the Session.SetSessionStatesHandler interface. +func (s *session) SetSessionStatesHandler(stateType sessionstates.SessionStateType, handler sessionctx.SessionStatesHandler) { + s.sessionStatesHandlers[stateType] = handler +} + +// CreateSession4Test creates a new session environment for test. +func CreateSession4Test(store kv.Storage) (Session, error) { + se, err := CreateSession4TestWithOpt(store, nil) + if err == nil { + // Cover both chunk rpc encoding and default encoding. + // nolint:gosec + if rand.Intn(2) == 0 { + se.GetSessionVars().EnableChunkRPC = false + } else { + se.GetSessionVars().EnableChunkRPC = true + } + } + return se, err +} + +// Opt describes the option for creating session +type Opt struct { + PreparedPlanCache sessionctx.PlanCache +} + +// CreateSession4TestWithOpt creates a new session environment for test. +func CreateSession4TestWithOpt(store kv.Storage, opt *Opt) (Session, error) { + s, err := CreateSessionWithOpt(store, opt) + if err == nil { + // initialize session variables for test. + s.GetSessionVars().InitChunkSize = 2 + s.GetSessionVars().MaxChunkSize = 32 + s.GetSessionVars().MinPagingSize = variable.DefMinPagingSize + s.GetSessionVars().EnablePaging = variable.DefTiDBEnablePaging + err = s.GetSessionVars().SetSystemVarWithoutValidation(variable.CharacterSetConnection, "utf8mb4") + } + return s, err +} + +// CreateSession creates a new session environment. +func CreateSession(store kv.Storage) (Session, error) { + return CreateSessionWithOpt(store, nil) +} + +// CreateSessionWithOpt creates a new session environment with option. +// Use default option if opt is nil. +func CreateSessionWithOpt(store kv.Storage, opt *Opt) (Session, error) { + s, err := createSessionWithOpt(store, opt) + if err != nil { + return nil, err + } + + // Add auth here. + do, err := domap.Get(store) + if err != nil { + return nil, err + } + extensions, err := extension.GetExtensions() + if err != nil { + return nil, err + } + pm := privileges.NewUserPrivileges(do.PrivilegeHandle(), extensions) + privilege.BindPrivilegeManager(s, pm) + + // Add stats collector, and it will be freed by background stats worker + // which periodically updates stats using the collected data. + if do.StatsHandle() != nil && do.StatsUpdating() { + s.statsCollector = do.StatsHandle().NewSessionStatsItem().(*usage.SessionStatsItem) + if GetIndexUsageSyncLease() > 0 { + s.idxUsageCollector = do.StatsHandle().NewSessionIndexUsageCollector().(*usage.SessionIndexUsageCollector) + } + } + + return s, nil +} + +// loadCollationParameter loads collation parameter from mysql.tidb +func loadCollationParameter(ctx context.Context, se *session) (bool, error) { + para, err := se.getTableValue(ctx, mysql.TiDBTable, tidbNewCollationEnabled) + if err != nil { + return false, err + } + if para == varTrue { + return true, nil + } else if para == varFalse { + return false, nil + } + logutil.BgLogger().Warn( + "Unexpected value of 'new_collation_enabled' in 'mysql.tidb', use 'False' instead", + zap.String("value", para)) + return false, nil +} + +type tableBasicInfo struct { + SQL string + id int64 +} + +var ( + errResultIsEmpty = dbterror.ClassExecutor.NewStd(errno.ErrResultIsEmpty) + // DDLJobTables is a list of tables definitions used in concurrent DDL. + DDLJobTables = []tableBasicInfo{ + {ddl.JobTableSQL, ddl.JobTableID}, + {ddl.ReorgTableSQL, ddl.ReorgTableID}, + {ddl.HistoryTableSQL, ddl.HistoryTableID}, + } + // BackfillTables is a list of tables definitions used in dist reorg DDL. + BackfillTables = []tableBasicInfo{ + {ddl.BackgroundSubtaskTableSQL, ddl.BackgroundSubtaskTableID}, + {ddl.BackgroundSubtaskHistoryTableSQL, ddl.BackgroundSubtaskHistoryTableID}, + } + mdlTable = "create table mysql.tidb_mdl_info(job_id BIGINT NOT NULL PRIMARY KEY, version BIGINT NOT NULL, table_ids text(65535));" +) + +func splitAndScatterTable(store kv.Storage, tableIDs []int64) { + if s, ok := store.(kv.SplittableStore); ok && atomic.LoadUint32(&ddl.EnableSplitTableRegion) == 1 { + ctxWithTimeout, cancel := context.WithTimeout(context.Background(), variable.DefWaitSplitRegionTimeout*time.Second) + var regionIDs []uint64 + for _, id := range tableIDs { + regionIDs = append(regionIDs, ddl.SplitRecordRegion(ctxWithTimeout, s, id, id, variable.DefTiDBScatterRegion)) + } + if variable.DefTiDBScatterRegion { + ddl.WaitScatterRegionFinish(ctxWithTimeout, s, regionIDs...) + } + cancel() + } +} + +// InitDDLJobTables is to create tidb_ddl_job, tidb_ddl_reorg and tidb_ddl_history, or tidb_background_subtask and tidb_background_subtask_history. +func InitDDLJobTables(store kv.Storage, targetVer meta.DDLTableVersion) error { + targetTables := DDLJobTables + if targetVer == meta.BackfillTableVersion { + targetTables = BackfillTables + } + return kv.RunInNewTxn(kv.WithInternalSourceType(context.Background(), kv.InternalTxnDDL), store, true, func(ctx context.Context, txn kv.Transaction) error { + t := meta.NewMeta(txn) + tableVer, err := t.CheckDDLTableVersion() + if err != nil || tableVer >= targetVer { + return errors.Trace(err) + } + dbID, err := t.CreateMySQLDatabaseIfNotExists() + if err != nil { + return err + } + if err = createAndSplitTables(store, t, dbID, targetTables); err != nil { + return err + } + return t.SetDDLTables(targetVer) + }) +} + +func createAndSplitTables(store kv.Storage, t *meta.Meta, dbID int64, tables []tableBasicInfo) error { + tableIDs := make([]int64, 0, len(tables)) + for _, tbl := range tables { + tableIDs = append(tableIDs, tbl.id) + } + splitAndScatterTable(store, tableIDs) + p := parser.New() + for _, tbl := range tables { + stmt, err := p.ParseOneStmt(tbl.SQL, "", "") + if err != nil { + return errors.Trace(err) + } + tblInfo, err := ddl.BuildTableInfoFromAST(stmt.(*ast.CreateTableStmt)) + if err != nil { + return errors.Trace(err) + } + tblInfo.State = model.StatePublic + tblInfo.ID = tbl.id + tblInfo.UpdateTS = t.StartTS + err = t.CreateTableOrView(dbID, tblInfo) + if err != nil { + return errors.Trace(err) + } + } + return nil +} + +// InitMDLTable is to create tidb_mdl_info, which is used for metadata lock. +func InitMDLTable(store kv.Storage) error { + return kv.RunInNewTxn(kv.WithInternalSourceType(context.Background(), kv.InternalTxnDDL), store, true, func(ctx context.Context, txn kv.Transaction) error { + t := meta.NewMeta(txn) + ver, err := t.CheckDDLTableVersion() + if err != nil || ver >= meta.MDLTableVersion { + return errors.Trace(err) + } + dbID, err := t.CreateMySQLDatabaseIfNotExists() + if err != nil { + return err + } + splitAndScatterTable(store, []int64{ddl.MDLTableID}) + p := parser.New() + stmt, err := p.ParseOneStmt(mdlTable, "", "") + if err != nil { + return errors.Trace(err) + } + tblInfo, err := ddl.BuildTableInfoFromAST(stmt.(*ast.CreateTableStmt)) + if err != nil { + return errors.Trace(err) + } + tblInfo.State = model.StatePublic + tblInfo.ID = ddl.MDLTableID + tblInfo.UpdateTS = t.StartTS + err = t.CreateTableOrView(dbID, tblInfo) + if err != nil { + return errors.Trace(err) + } + + return t.SetDDLTables(meta.MDLTableVersion) + }) +} + +// InitMDLVariableForBootstrap initializes the metadata lock variable. +func InitMDLVariableForBootstrap(store kv.Storage) error { + err := kv.RunInNewTxn(kv.WithInternalSourceType(context.Background(), kv.InternalTxnDDL), store, true, func(ctx context.Context, txn kv.Transaction) error { + t := meta.NewMeta(txn) + return t.SetMetadataLock(true) + }) + if err != nil { + return err + } + variable.EnableMDL.Store(true) + return nil +} + +// InitMDLVariableForUpgrade initializes the metadata lock variable. +func InitMDLVariableForUpgrade(store kv.Storage) (bool, error) { + isNull := false + enable := false + var err error + err = kv.RunInNewTxn(kv.WithInternalSourceType(context.Background(), kv.InternalTxnDDL), store, true, func(ctx context.Context, txn kv.Transaction) error { + t := meta.NewMeta(txn) + enable, isNull, err = t.GetMetadataLock() + if err != nil { + return err + } + return nil + }) + if isNull || !enable { + variable.EnableMDL.Store(false) + } else { + variable.EnableMDL.Store(true) + } + return isNull, err +} + +// InitMDLVariable initializes the metadata lock variable. +func InitMDLVariable(store kv.Storage) error { + isNull := false + enable := false + var err error + err = kv.RunInNewTxn(kv.WithInternalSourceType(context.Background(), kv.InternalTxnDDL), store, true, func(ctx context.Context, txn kv.Transaction) error { + t := meta.NewMeta(txn) + enable, isNull, err = t.GetMetadataLock() + if err != nil { + return err + } + if isNull { + // Workaround for version: nightly-2022-11-07 to nightly-2022-11-17. + enable = true + logutil.BgLogger().Warn("metadata lock is null") + err = t.SetMetadataLock(true) + if err != nil { + return err + } + } + return nil + }) + variable.EnableMDL.Store(enable) + return err +} + +// BootstrapSession bootstrap session and domain. +func BootstrapSession(store kv.Storage) (*domain.Domain, error) { + return bootstrapSessionImpl(store, createSessions) +} + +// BootstrapSession4DistExecution bootstrap session and dom for Distributed execution test, only for unit testing. +func BootstrapSession4DistExecution(store kv.Storage) (*domain.Domain, error) { + return bootstrapSessionImpl(store, createSessions4DistExecution) +} + +func bootstrapSessionImpl(store kv.Storage, createSessionsImpl func(store kv.Storage, cnt int) ([]*session, error)) (*domain.Domain, error) { + ctx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnBootstrap) + cfg := config.GetGlobalConfig() + if len(cfg.Instance.PluginLoad) > 0 { + err := plugin.Load(context.Background(), plugin.Config{ + Plugins: strings.Split(cfg.Instance.PluginLoad, ","), + PluginDir: cfg.Instance.PluginDir, + }) + if err != nil { + return nil, err + } + } + err := InitDDLJobTables(store, meta.BaseDDLTableVersion) + if err != nil { + return nil, err + } + err = InitMDLTable(store) + if err != nil { + return nil, err + } + err = InitDDLJobTables(store, meta.BackfillTableVersion) + if err != nil { + return nil, err + } + ver := getStoreBootstrapVersion(store) + if ver == notBootstrapped { + runInBootstrapSession(store, bootstrap) + } else if ver < currentBootstrapVersion { + runInBootstrapSession(store, upgrade) + } else { + err = InitMDLVariable(store) + if err != nil { + return nil, err + } + } + + analyzeConcurrencyQuota := int(config.GetGlobalConfig().Performance.AnalyzePartitionConcurrencyQuota) + concurrency := int(config.GetGlobalConfig().Performance.StatsLoadConcurrency) + ses, err := createSessionsImpl(store, 10) + if err != nil { + return nil, err + } + ses[0].GetSessionVars().InRestrictedSQL = true + + // get system tz from mysql.tidb + tz, err := ses[0].getTableValue(ctx, mysql.TiDBTable, tidbSystemTZ) + if err != nil { + return nil, err + } + timeutil.SetSystemTZ(tz) + + // get the flag from `mysql`.`tidb` which indicating if new collations are enabled. + newCollationEnabled, err := loadCollationParameter(ctx, ses[0]) + if err != nil { + return nil, err + } + collate.SetNewCollationEnabledForTest(newCollationEnabled) + // To deal with the location partition failure caused by inconsistent NewCollationEnabled values(see issue #32416). + rebuildAllPartitionValueMapAndSorted(ses[0]) + + dom := domain.GetDomain(ses[0]) + + // We should make the load bind-info loop before other loops which has internal SQL. + // Because the internal SQL may access the global bind-info handler. As the result, the data race occurs here as the + // LoadBindInfoLoop inits global bind-info handler. + err = dom.LoadBindInfoLoop(ses[1], ses[2]) + if err != nil { + return nil, err + } + + if !config.GetGlobalConfig().Security.SkipGrantTable { + err = dom.LoadPrivilegeLoop(ses[3]) + if err != nil { + return nil, err + } + } + + // Rebuild sysvar cache in a loop + err = dom.LoadSysVarCacheLoop(ses[4]) + if err != nil { + return nil, err + } + + if config.GetGlobalConfig().DisaggregatedTiFlash && !config.GetGlobalConfig().UseAutoScaler { + // Invalid client-go tiflash_compute store cache if necessary. + err = dom.WatchTiFlashComputeNodeChange() + if err != nil { + return nil, err + } + } + + if err = extensionimpl.Bootstrap(context.Background(), dom); err != nil { + return nil, err + } + + if len(cfg.Instance.PluginLoad) > 0 { + err := plugin.Init(context.Background(), plugin.Config{EtcdClient: dom.GetEtcdClient()}) + if err != nil { + return nil, err + } + } + + err = executor.LoadExprPushdownBlacklist(ses[5]) + if err != nil { + return nil, err + } + err = executor.LoadOptRuleBlacklist(ctx, ses[5]) + if err != nil { + return nil, err + } + + if dom.GetEtcdClient() != nil { + // We only want telemetry data in production-like clusters. When TiDB is deployed over other engines, + // for example, unistore engine (used for local tests), we just skip it. Its etcd client is nil. + if config.GetGlobalConfig().EnableTelemetry { + // There is no way to turn telemetry on with global variable `tidb_enable_telemetry` + // when it is disabled in config. See IsTelemetryEnabled function in telemetry/telemetry.go + go func() { + dom.TelemetryReportLoop(ses[5]) + dom.TelemetryRotateSubWindowLoop(ses[5]) + }() + } + } + + planReplayerWorkerCnt := config.GetGlobalConfig().Performance.PlanReplayerDumpWorkerConcurrency + planReplayerWorkersSctx := make([]sessionctx.Context, planReplayerWorkerCnt) + pworkerSes, err := createSessions(store, int(planReplayerWorkerCnt)) + if err != nil { + return nil, err + } + for i := 0; i < int(planReplayerWorkerCnt); i++ { + planReplayerWorkersSctx[i] = pworkerSes[i] + } + // setup plan replayer handle + dom.SetupPlanReplayerHandle(ses[6], planReplayerWorkersSctx) + dom.StartPlanReplayerHandle() + // setup dumpFileGcChecker + dom.SetupDumpFileGCChecker(ses[7]) + dom.DumpFileGcCheckerLoop() + // setup historical stats worker + dom.SetupHistoricalStatsWorker(ses[8]) + dom.StartHistoricalStatsWorker() + failToLoadOrParseSQLFile := false // only used for unit test + if runBootstrapSQLFile { + pm := &privileges.UserPrivileges{ + Handle: dom.PrivilegeHandle(), + } + privilege.BindPrivilegeManager(ses[9], pm) + if err := doBootstrapSQLFile(ses[9]); err != nil && intest.InTest { + failToLoadOrParseSQLFile = true + } + } + // A sub context for update table stats, and other contexts for concurrent stats loading. + cnt := 1 + concurrency + syncStatsCtxs, err := createSessions(store, cnt) + if err != nil { + return nil, err + } + subCtxs := make([]sessionctx.Context, cnt) + for i := 0; i < cnt; i++ { + subCtxs[i] = sessionctx.Context(syncStatsCtxs[i]) + } + + // setup extract Handle + extractWorkers := 1 + sctxs, err := createSessions(store, extractWorkers) + if err != nil { + return nil, err + } + extractWorkerSctxs := make([]sessionctx.Context, 0) + for _, sctx := range sctxs { + extractWorkerSctxs = append(extractWorkerSctxs, sctx) + } + dom.SetupExtractHandle(extractWorkerSctxs) + + // setup init stats loader + initStatsCtx, err := createSession(store) + if err != nil { + return nil, err + } + if err = dom.LoadAndUpdateStatsLoop(subCtxs, initStatsCtx); err != nil { + return nil, err + } + + // start TTL job manager after setup stats collector + // because TTL could modify a lot of columns, and need to trigger auto analyze + ttlworker.AttachStatsCollector = func(s sqlexec.SQLExecutor) sqlexec.SQLExecutor { + if s, ok := s.(*session); ok { + return attachStatsCollector(s, dom) + } + return s + } + ttlworker.DetachStatsCollector = func(s sqlexec.SQLExecutor) sqlexec.SQLExecutor { + if s, ok := s.(*session); ok { + return detachStatsCollector(s) + } + return s + } + dom.StartTTLJobManager() + + analyzeCtxs, err := createSessions(store, analyzeConcurrencyQuota) + if err != nil { + return nil, err + } + subCtxs2 := make([]sessionctx.Context, analyzeConcurrencyQuota) + for i := 0; i < analyzeConcurrencyQuota; i++ { + subCtxs2[i] = analyzeCtxs[i] + } + dom.SetupAnalyzeExec(subCtxs2) + dom.LoadSigningCertLoop(cfg.Security.SessionTokenSigningCert, cfg.Security.SessionTokenSigningKey) + + if raw, ok := store.(kv.EtcdBackend); ok { + err = raw.StartGCWorker() + if err != nil { + return nil, err + } + } + + // This only happens in testing, since the failure of loading or parsing sql file + // would panic the bootstrapping. + if intest.InTest && failToLoadOrParseSQLFile { + dom.Close() + return nil, errors.New("Fail to load or parse sql file") + } + err = dom.InitDistTaskLoop(ctx) + if err != nil { + return nil, err + } + return dom, err +} + +// GetDomain gets the associated domain for store. +func GetDomain(store kv.Storage) (*domain.Domain, error) { + return domap.Get(store) +} + +// runInBootstrapSession create a special session for bootstrap to run. +// If no bootstrap and storage is remote, we must use a little lease time to +// bootstrap quickly, after bootstrapped, we will reset the lease time. +// TODO: Using a bootstrap tool for doing this may be better later. +func runInBootstrapSession(store kv.Storage, bootstrap func(Session)) { + s, err := createSession(store) + if err != nil { + // Bootstrap fail will cause program exit. + logutil.BgLogger().Fatal("createSession error", zap.Error(err)) + } + // For the bootstrap SQLs, the following variables should be compatible with old TiDB versions. + s.sessionVars.EnableClusteredIndex = variable.ClusteredIndexDefModeIntOnly + + s.SetValue(sessionctx.Initing, true) + bootstrap(s) + finishBootstrap(store) + s.ClearValue(sessionctx.Initing) + + dom := domain.GetDomain(s) + dom.Close() + if intest.InTest { + infosync.MockGlobalServerInfoManagerEntry.Close() + } + domap.Delete(store) +} + +func createSessions(store kv.Storage, cnt int) ([]*session, error) { + return createSessionsImpl(store, cnt, createSession) +} + +func createSessions4DistExecution(store kv.Storage, cnt int) ([]*session, error) { + domap.Delete(store) + + return createSessionsImpl(store, cnt, createSession4DistExecution) +} + +func createSessionsImpl(store kv.Storage, cnt int, createSessionImpl func(kv.Storage) (*session, error)) ([]*session, error) { + // Then we can create new dom + ses := make([]*session, cnt) + for i := 0; i < cnt; i++ { + se, err := createSessionImpl(store) + if err != nil { + return nil, err + } + ses[i] = se + } + + return ses, nil +} + +// createSession creates a new session. +// Please note that such a session is not tracked by the internal session list. +// This means the min ts reporter is not aware of it and may report a wrong min start ts. +// In most cases you should use a session pool in domain instead. +func createSession(store kv.Storage) (*session, error) { + return createSessionWithOpt(store, nil) +} + +func createSession4DistExecution(store kv.Storage) (*session, error) { + return createSessionWithOpt(store, nil) +} + +func createSessionWithOpt(store kv.Storage, opt *Opt) (*session, error) { + dom, err := domap.Get(store) + if err != nil { + return nil, err + } + s := &session{ + store: store, + ddlOwnerManager: dom.DDL().OwnerManager(), + client: store.GetClient(), + mppClient: store.GetMPPClient(), + stmtStats: stmtstats.CreateStatementStats(), + sessionStatesHandlers: make(map[sessionstates.SessionStateType]sessionctx.SessionStatesHandler), + } + s.sessionVars = variable.NewSessionVars(s) + + s.functionUsageMu.builtinFunctionUsage = make(telemetry.BuiltinFunctionsUsage) + if opt != nil && opt.PreparedPlanCache != nil { + s.sessionPlanCache = opt.PreparedPlanCache + } + s.mu.values = make(map[fmt.Stringer]interface{}) + s.lockedTables = make(map[int64]model.TableLockTpInfo) + s.advisoryLocks = make(map[string]*advisoryLock) + + domain.BindDomain(s, dom) + // session implements variable.GlobalVarAccessor. Bind it to ctx. + s.sessionVars.GlobalVarsAccessor = s + s.sessionVars.BinlogClient = binloginfo.GetPumpsClient() + s.txn.init() + + sessionBindHandle := bindinfo.NewSessionBindHandle() + s.SetValue(bindinfo.SessionBindInfoKeyType, sessionBindHandle) + s.SetSessionStatesHandler(sessionstates.StateBinding, sessionBindHandle) + return s, nil +} + +// attachStatsCollector attaches the stats collector in the dom for the session +func attachStatsCollector(s *session, dom *domain.Domain) *session { + if dom.StatsHandle() != nil && dom.StatsUpdating() { + if s.statsCollector == nil { + s.statsCollector = dom.StatsHandle().NewSessionStatsItem().(*usage.SessionStatsItem) + } + if s.idxUsageCollector == nil && GetIndexUsageSyncLease() > 0 { + s.idxUsageCollector = dom.StatsHandle().NewSessionIndexUsageCollector().(*usage.SessionIndexUsageCollector) + } + } + + return s +} + +// detachStatsCollector removes the stats collector in the session +func detachStatsCollector(s *session) *session { + if s.statsCollector != nil { + s.statsCollector.Delete() + s.statsCollector = nil + } + if s.idxUsageCollector != nil { + s.idxUsageCollector.Delete() + s.idxUsageCollector = nil + } + return s +} + +// CreateSessionWithDomain creates a new Session and binds it with a Domain. +// We need this because when we start DDL in Domain, the DDL need a session +// to change some system tables. But at that time, we have been already in +// a lock context, which cause we can't call createSession directly. +func CreateSessionWithDomain(store kv.Storage, dom *domain.Domain) (*session, error) { + s := &session{ + store: store, + sessionVars: variable.NewSessionVars(nil), + client: store.GetClient(), + mppClient: store.GetMPPClient(), + stmtStats: stmtstats.CreateStatementStats(), + sessionStatesHandlers: make(map[sessionstates.SessionStateType]sessionctx.SessionStatesHandler), + } + s.functionUsageMu.builtinFunctionUsage = make(telemetry.BuiltinFunctionsUsage) + s.mu.values = make(map[fmt.Stringer]interface{}) + s.lockedTables = make(map[int64]model.TableLockTpInfo) + domain.BindDomain(s, dom) + // session implements variable.GlobalVarAccessor. Bind it to ctx. + s.sessionVars.GlobalVarsAccessor = s + s.txn.init() + return s, nil +} + +const ( + notBootstrapped = 0 +) + +func getStoreBootstrapVersion(store kv.Storage) int64 { + storeBootstrappedLock.Lock() + defer storeBootstrappedLock.Unlock() + // check in memory + _, ok := storeBootstrapped[store.UUID()] + if ok { + return currentBootstrapVersion + } + + var ver int64 + // check in kv store + ctx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnBootstrap) + err := kv.RunInNewTxn(ctx, store, false, func(ctx context.Context, txn kv.Transaction) error { + var err error + t := meta.NewMeta(txn) + ver, err = t.GetBootstrapVersion() + return err + }) + if err != nil { + logutil.BgLogger().Fatal("check bootstrapped failed", + zap.Error(err)) + } + + if ver > notBootstrapped { + // here mean memory is not ok, but other server has already finished it + storeBootstrapped[store.UUID()] = true + } + + modifyBootstrapVersionForTest(ver) + return ver +} + +func finishBootstrap(store kv.Storage) { + setStoreBootstrapped(store.UUID()) + + ctx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnBootstrap) + err := kv.RunInNewTxn(ctx, store, true, func(ctx context.Context, txn kv.Transaction) error { + t := meta.NewMeta(txn) + err := t.FinishBootstrap(currentBootstrapVersion) + return err + }) + if err != nil { + logutil.BgLogger().Fatal("finish bootstrap failed", + zap.Error(err)) + } +} + +const quoteCommaQuote = "', '" + +// loadCommonGlobalVariablesIfNeeded loads and applies commonly used global variables for the session. +func (s *session) loadCommonGlobalVariablesIfNeeded() error { + vars := s.sessionVars + if vars.CommonGlobalLoaded { + return nil + } + if s.Value(sessionctx.Initing) != nil { + // When running bootstrap or upgrade, we should not access global storage. + return nil + } + + vars.CommonGlobalLoaded = true + + // Deep copy sessionvar cache + sessionCache, err := domain.GetDomain(s).GetSessionCache() + if err != nil { + return err + } + for varName, varVal := range sessionCache { + if _, ok := vars.GetSystemVar(varName); !ok { + err = vars.SetSystemVarWithRelaxedValidation(varName, varVal) + if err != nil { + if variable.ErrUnknownSystemVar.Equal(err) { + continue // sessionCache is stale; sysvar has likely been unregistered + } + return err + } + } + } + // when client set Capability Flags CLIENT_INTERACTIVE, init wait_timeout with interactive_timeout + if vars.ClientCapability&mysql.ClientInteractive > 0 { + if varVal, ok := vars.GetSystemVar(variable.InteractiveTimeout); ok { + if err := vars.SetSystemVar(variable.WaitTimeout, varVal); err != nil { + return err + } + } + } + return nil +} + +// PrepareTxnCtx begins a transaction, and creates a new transaction context. +// It is called before we execute a sql query. +func (s *session) PrepareTxnCtx(ctx context.Context) error { + s.currentCtx = ctx + if s.txn.validOrPending() { + return nil + } + + txnMode := ast.Optimistic + if !s.sessionVars.IsAutocommit() || config.GetGlobalConfig().PessimisticTxn.PessimisticAutoCommit.Load() { + if s.sessionVars.TxnMode == ast.Pessimistic { + txnMode = ast.Pessimistic + } + } + + if s.sessionVars.RetryInfo.Retrying { + txnMode = ast.Pessimistic + } + + return sessiontxn.GetTxnManager(s).EnterNewTxn(ctx, &sessiontxn.EnterNewTxnRequest{ + Type: sessiontxn.EnterNewTxnBeforeStmt, + TxnMode: txnMode, + }) +} + +// PrepareTSFuture uses to try to get ts future. +func (s *session) PrepareTSFuture(ctx context.Context, future oracle.Future, scope string) error { + if s.txn.Valid() { + return errors.New("cannot prepare ts future when txn is valid") + } + + failpoint.Inject("assertTSONotRequest", func() { + if _, ok := future.(sessiontxn.ConstantFuture); !ok && !s.isInternal() { + panic("tso shouldn't be requested") + } + }) + + failpoint.InjectContext(ctx, "mockGetTSFail", func() { + future = txnFailFuture{} + }) + + s.txn.changeToPending(&txnFuture{ + future: future, + store: s.store, + txnScope: scope, + }) + return nil +} + +// GetPreparedTxnFuture returns the TxnFuture if it is valid or pending. +// It returns nil otherwise. +func (s *session) GetPreparedTxnFuture() sessionctx.TxnFuture { + if !s.txn.validOrPending() { + return nil + } + return &s.txn +} + +// RefreshTxnCtx implements context.RefreshTxnCtx interface. +func (s *session) RefreshTxnCtx(ctx context.Context) error { + var commitDetail *tikvutil.CommitDetails + ctx = context.WithValue(ctx, tikvutil.CommitDetailCtxKey, &commitDetail) + err := s.doCommit(ctx) + if commitDetail != nil { + s.GetSessionVars().StmtCtx.MergeExecDetails(nil, commitDetail) + } + if err != nil { + return err + } + + s.updateStatsDeltaToCollector() + + return sessiontxn.NewTxn(ctx, s) +} + +// GetStore gets the store of session. +func (s *session) GetStore() kv.Storage { + return s.store +} + +func (s *session) ShowProcess() *util.ProcessInfo { + var pi *util.ProcessInfo + tmp := s.processInfo.Load() + if tmp != nil { + pi = tmp.(*util.ProcessInfo) + } + return pi +} + +// GetStartTSFromSession returns the startTS in the session `se` +func GetStartTSFromSession(se interface{}) (startTS, processInfoID uint64) { + tmp, ok := se.(*session) + if !ok { + logutil.BgLogger().Error("GetStartTSFromSession failed, can't transform to session struct") + return 0, 0 + } + txnInfo := tmp.TxnInfo() + if txnInfo != nil { + startTS = txnInfo.StartTS + processInfoID = txnInfo.ConnectionID + } + + logutil.BgLogger().Debug( + "GetStartTSFromSession getting startTS of internal session", + zap.Uint64("startTS", startTS), zap.Time("start time", oracle.GetTimeFromTS(startTS))) + + return startTS, processInfoID +} + +// logStmt logs some crucial SQL including: CREATE USER/GRANT PRIVILEGE/CHANGE PASSWORD/DDL etc and normal SQL +// if variable.ProcessGeneralLog is set. +func logStmt(execStmt *executor.ExecStmt, s *session) { + vars := s.GetSessionVars() + isCrucial := false + switch stmt := execStmt.StmtNode.(type) { + case *ast.DropIndexStmt: + isCrucial = true + if stmt.IsHypo { + isCrucial = false + } + case *ast.CreateIndexStmt: + isCrucial = true + if stmt.IndexOption != nil && stmt.IndexOption.Tp == model.IndexTypeHypo { + isCrucial = false + } + case *ast.CreateUserStmt, *ast.DropUserStmt, *ast.AlterUserStmt, *ast.SetPwdStmt, *ast.GrantStmt, + *ast.RevokeStmt, *ast.AlterTableStmt, *ast.CreateDatabaseStmt, *ast.CreateTableStmt, + *ast.DropDatabaseStmt, *ast.DropTableStmt, *ast.RenameTableStmt, *ast.TruncateTableStmt, + *ast.RenameUserStmt: + isCrucial = true + } + + if isCrucial { + user := vars.User + schemaVersion := s.GetInfoSchema().SchemaMetaVersion() + if ss, ok := execStmt.StmtNode.(ast.SensitiveStmtNode); ok { + logutil.BgLogger().Info("CRUCIAL OPERATION", + zap.Uint64("conn", vars.ConnectionID), + zap.Int64("schemaVersion", schemaVersion), + zap.String("secure text", ss.SecureText()), + zap.Stringer("user", user)) + } else { + logutil.BgLogger().Info("CRUCIAL OPERATION", + zap.Uint64("conn", vars.ConnectionID), + zap.Int64("schemaVersion", schemaVersion), + zap.String("cur_db", vars.CurrentDB), + zap.String("sql", execStmt.StmtNode.Text()), + zap.Stringer("user", user)) + } + } else { + logGeneralQuery(execStmt, s, false) + } +} + +func logGeneralQuery(execStmt *executor.ExecStmt, s *session, isPrepared bool) { + vars := s.GetSessionVars() + if variable.ProcessGeneralLog.Load() && !vars.InRestrictedSQL { + var query string + if isPrepared { + query = execStmt.OriginText() + } else { + query = execStmt.GetTextToLog(false) + } + + query = executor.QueryReplacer.Replace(query) + if !vars.EnableRedactLog { + query += vars.PlanCacheParams.String() + } + logutil.BgLogger().Info("GENERAL_LOG", + zap.Uint64("conn", vars.ConnectionID), + zap.String("session_alias", vars.SessionAlias), + zap.String("user", vars.User.LoginString()), + zap.Int64("schemaVersion", s.GetInfoSchema().SchemaMetaVersion()), + zap.Uint64("txnStartTS", vars.TxnCtx.StartTS), + zap.Uint64("forUpdateTS", vars.TxnCtx.GetForUpdateTS()), + zap.Bool("isReadConsistency", vars.IsIsolation(ast.ReadCommitted)), + zap.String("currentDB", vars.CurrentDB), + zap.Bool("isPessimistic", vars.TxnCtx.IsPessimistic), + zap.String("sessionTxnMode", vars.GetReadableTxnMode()), + zap.String("sql", query)) + } +} + +func (s *session) recordOnTransactionExecution(err error, counter int, duration float64, isInternal bool) { + if s.sessionVars.TxnCtx.IsPessimistic { + if err != nil { + if isInternal { + session_metrics.TransactionDurationPessimisticAbortInternal.Observe(duration) + session_metrics.StatementPerTransactionPessimisticErrorInternal.Observe(float64(counter)) + } else { + session_metrics.TransactionDurationPessimisticAbortGeneral.Observe(duration) + session_metrics.StatementPerTransactionPessimisticErrorGeneral.Observe(float64(counter)) + } + } else { + if isInternal { + session_metrics.TransactionDurationPessimisticCommitInternal.Observe(duration) + session_metrics.StatementPerTransactionPessimisticOKInternal.Observe(float64(counter)) + } else { + session_metrics.TransactionDurationPessimisticCommitGeneral.Observe(duration) + session_metrics.StatementPerTransactionPessimisticOKGeneral.Observe(float64(counter)) + } + } + } else { + if err != nil { + if isInternal { + session_metrics.TransactionDurationOptimisticAbortInternal.Observe(duration) + session_metrics.StatementPerTransactionOptimisticErrorInternal.Observe(float64(counter)) + } else { + session_metrics.TransactionDurationOptimisticAbortGeneral.Observe(duration) + session_metrics.StatementPerTransactionOptimisticErrorGeneral.Observe(float64(counter)) + } + } else { + if isInternal { + session_metrics.TransactionDurationOptimisticCommitInternal.Observe(duration) + session_metrics.StatementPerTransactionOptimisticOKInternal.Observe(float64(counter)) + } else { + session_metrics.TransactionDurationOptimisticCommitGeneral.Observe(duration) + session_metrics.StatementPerTransactionOptimisticOKGeneral.Observe(float64(counter)) + } + } + } +} + +func (s *session) checkPlacementPolicyBeforeCommit() error { + var err error + // Get the txnScope of the transaction we're going to commit. + txnScope := s.GetSessionVars().TxnCtx.TxnScope + if txnScope == "" { + txnScope = kv.GlobalTxnScope + } + if txnScope != kv.GlobalTxnScope { + is := s.GetInfoSchema().(infoschema.InfoSchema) + deltaMap := s.GetSessionVars().TxnCtx.TableDeltaMap + for physicalTableID := range deltaMap { + var tableName string + var partitionName string + tblInfo, _, partInfo := is.FindTableByPartitionID(physicalTableID) + if tblInfo != nil && partInfo != nil { + tableName = tblInfo.Meta().Name.String() + partitionName = partInfo.Name.String() + } else { + tblInfo, _ := is.TableByID(physicalTableID) + tableName = tblInfo.Meta().Name.String() + } + bundle, ok := is.PlacementBundleByPhysicalTableID(physicalTableID) + if !ok { + errMsg := fmt.Sprintf("table %v doesn't have placement policies with txn_scope %v", + tableName, txnScope) + if len(partitionName) > 0 { + errMsg = fmt.Sprintf("table %v's partition %v doesn't have placement policies with txn_scope %v", + tableName, partitionName, txnScope) + } + err = dbterror.ErrInvalidPlacementPolicyCheck.GenWithStackByArgs(errMsg) + break + } + dcLocation, ok := bundle.GetLeaderDC(placement.DCLabelKey) + if !ok { + errMsg := fmt.Sprintf("table %v's leader placement policy is not defined", tableName) + if len(partitionName) > 0 { + errMsg = fmt.Sprintf("table %v's partition %v's leader placement policy is not defined", tableName, partitionName) + } + err = dbterror.ErrInvalidPlacementPolicyCheck.GenWithStackByArgs(errMsg) + break + } + if dcLocation != txnScope { + errMsg := fmt.Sprintf("table %v's leader location %v is out of txn_scope %v", tableName, dcLocation, txnScope) + if len(partitionName) > 0 { + errMsg = fmt.Sprintf("table %v's partition %v's leader location %v is out of txn_scope %v", + tableName, partitionName, dcLocation, txnScope) + } + err = dbterror.ErrInvalidPlacementPolicyCheck.GenWithStackByArgs(errMsg) + break + } + // FIXME: currently we assume the physicalTableID is the partition ID. In future, we should consider the situation + // if the physicalTableID belongs to a Table. + partitionID := physicalTableID + tbl, _, partitionDefInfo := is.FindTableByPartitionID(partitionID) + if tbl != nil { + tblInfo := tbl.Meta() + state := tblInfo.Partition.GetStateByID(partitionID) + if state == model.StateGlobalTxnOnly { + err = dbterror.ErrInvalidPlacementPolicyCheck.GenWithStackByArgs( + fmt.Sprintf("partition %s of table %s can not be written by local transactions when its placement policy is being altered", + tblInfo.Name, partitionDefInfo.Name)) + break + } + } + } + } + return err +} + +func (s *session) SetPort(port string) { + s.sessionVars.Port = port +} + +// GetTxnWriteThroughputSLI implements the Context interface. +func (s *session) GetTxnWriteThroughputSLI() *sli.TxnWriteThroughputSLI { + return &s.txn.writeSLI +} + +// GetInfoSchema returns snapshotInfoSchema if snapshot schema is set. +// Transaction infoschema is returned if inside an explicit txn. +// Otherwise the latest infoschema is returned. +func (s *session) GetInfoSchema() sessionctx.InfoschemaMetaVersion { + vars := s.GetSessionVars() + var is infoschema.InfoSchema + if snap, ok := vars.SnapshotInfoschema.(infoschema.InfoSchema); ok { + logutil.BgLogger().Info("use snapshot schema", zap.Uint64("conn", vars.ConnectionID), zap.Int64("schemaVersion", snap.SchemaMetaVersion())) + is = snap + } else { + vars.TxnCtxMu.Lock() + if vars.TxnCtx != nil { + if tmp, ok := vars.TxnCtx.InfoSchema.(infoschema.InfoSchema); ok { + is = tmp + } + } + vars.TxnCtxMu.Unlock() + } + + if is == nil { + is = domain.GetDomain(s).InfoSchema() + } + + // Override the infoschema if the session has temporary table. + return temptable.AttachLocalTemporaryTableInfoSchema(s, is) +} + +func (s *session) GetDomainInfoSchema() sessionctx.InfoschemaMetaVersion { + is := domain.GetDomain(s).InfoSchema() + extIs := &infoschema.SessionExtendedInfoSchema{InfoSchema: is} + return temptable.AttachLocalTemporaryTableInfoSchema(s, extIs) +} + +func getSnapshotInfoSchema(s sessionctx.Context, snapshotTS uint64) (infoschema.InfoSchema, error) { + is, err := domain.GetDomain(s).GetSnapshotInfoSchema(snapshotTS) + if err != nil { + return nil, err + } + // Set snapshot does not affect the witness of the local temporary table. + // The session always see the latest temporary tables. + return temptable.AttachLocalTemporaryTableInfoSchema(s, is), nil +} + +func (s *session) updateTelemetryMetric(es *executor.ExecStmt) { + if es.Ti == nil { + return + } + if s.isInternal() { + return + } + + ti := es.Ti + if ti.UseRecursive { + session_metrics.TelemetryCTEUsageRecurCTE.Inc() + } else if ti.UseNonRecursive { + session_metrics.TelemetryCTEUsageNonRecurCTE.Inc() + } else { + session_metrics.TelemetryCTEUsageNotCTE.Inc() + } + + if ti.UseIndexMerge { + session_metrics.TelemetryIndexMerge.Inc() + } + + if ti.UseMultiSchemaChange { + session_metrics.TelemetryMultiSchemaChangeUsage.Inc() + } + + if ti.UseFlashbackToCluster { + session_metrics.TelemetryFlashbackClusterUsage.Inc() + } + + if ti.UseExchangePartition { + session_metrics.TelemetryExchangePartitionUsage.Inc() + } + + if ti.PartitionTelemetry != nil { + if ti.PartitionTelemetry.UseTablePartition { + session_metrics.TelemetryTablePartitionUsage.Inc() + session_metrics.TelemetryTablePartitionMaxPartitionsUsage.Add(float64(ti.PartitionTelemetry.TablePartitionMaxPartitionsNum)) + } + if ti.PartitionTelemetry.UseTablePartitionList { + session_metrics.TelemetryTablePartitionListUsage.Inc() + } + if ti.PartitionTelemetry.UseTablePartitionRange { + session_metrics.TelemetryTablePartitionRangeUsage.Inc() + } + if ti.PartitionTelemetry.UseTablePartitionHash { + session_metrics.TelemetryTablePartitionHashUsage.Inc() + } + if ti.PartitionTelemetry.UseTablePartitionRangeColumns { + session_metrics.TelemetryTablePartitionRangeColumnsUsage.Inc() + } + if ti.PartitionTelemetry.UseTablePartitionRangeColumnsGt1 { + session_metrics.TelemetryTablePartitionRangeColumnsGt1Usage.Inc() + } + if ti.PartitionTelemetry.UseTablePartitionRangeColumnsGt2 { + session_metrics.TelemetryTablePartitionRangeColumnsGt2Usage.Inc() + } + if ti.PartitionTelemetry.UseTablePartitionRangeColumnsGt3 { + session_metrics.TelemetryTablePartitionRangeColumnsGt3Usage.Inc() + } + if ti.PartitionTelemetry.UseTablePartitionListColumns { + session_metrics.TelemetryTablePartitionListColumnsUsage.Inc() + } + if ti.PartitionTelemetry.UseCreateIntervalPartition { + session_metrics.TelemetryTablePartitionCreateIntervalUsage.Inc() + } + if ti.PartitionTelemetry.UseAddIntervalPartition { + session_metrics.TelemetryTablePartitionAddIntervalUsage.Inc() + } + if ti.PartitionTelemetry.UseDropIntervalPartition { + session_metrics.TelemetryTablePartitionDropIntervalUsage.Inc() + } + if ti.PartitionTelemetry.UseCompactTablePartition { + session_metrics.TelemetryTableCompactPartitionUsage.Inc() + } + if ti.PartitionTelemetry.UseReorganizePartition { + session_metrics.TelemetryReorganizePartitionUsage.Inc() + } + } + + if ti.AccountLockTelemetry != nil { + session_metrics.TelemetryLockUserUsage.Add(float64(ti.AccountLockTelemetry.LockUser)) + session_metrics.TelemetryUnlockUserUsage.Add(float64(ti.AccountLockTelemetry.UnlockUser)) + session_metrics.TelemetryCreateOrAlterUserUsage.Add(float64(ti.AccountLockTelemetry.CreateOrAlterUser)) + } + + if ti.UseTableLookUp.Load() && s.sessionVars.StoreBatchSize > 0 { + session_metrics.TelemetryStoreBatchedUsage.Inc() + } +} + +// GetBuiltinFunctionUsage returns the replica of counting of builtin function usage +func (s *session) GetBuiltinFunctionUsage() map[string]uint32 { + replica := make(map[string]uint32) + s.functionUsageMu.RLock() + defer s.functionUsageMu.RUnlock() + for key, value := range s.functionUsageMu.builtinFunctionUsage { + replica[key] = value + } + return replica +} + +// BuiltinFunctionUsageInc increase the counting of the builtin function usage +func (s *session) BuiltinFunctionUsageInc(scalarFuncSigName string) { + s.functionUsageMu.Lock() + defer s.functionUsageMu.Unlock() + s.functionUsageMu.builtinFunctionUsage.Inc(scalarFuncSigName) +} + +func (s *session) GetStmtStats() *stmtstats.StatementStats { + return s.stmtStats +} + +// SetMemoryFootprintChangeHook sets the hook that is called when the memdb changes its size. +// Call this after s.txn becomes valid, since TxnInfo is initialized when the txn becomes valid. +func (s *session) SetMemoryFootprintChangeHook() { + if config.GetGlobalConfig().Performance.TxnTotalSizeLimit != config.DefTxnTotalSizeLimit { + // if the user manually specifies the config, don't involve the new memory tracker mechanism, let the old config + // work as before. + return + } + hook := func(mem uint64) { + if s.sessionVars.MemDBFootprint == nil { + tracker := memory.NewTracker(memory.LabelForMemDB, -1) + tracker.AttachTo(s.sessionVars.MemTracker) + s.sessionVars.MemDBFootprint = tracker + } + s.sessionVars.MemDBFootprint.ReplaceBytesUsed(int64(mem)) + } + s.txn.SetMemoryFootprintChangeHook(hook) +} + +// EncodeSessionStates implements SessionStatesHandler.EncodeSessionStates interface. +func (s *session) EncodeSessionStates(ctx context.Context, + _ sessionctx.Context, sessionStates *sessionstates.SessionStates) error { + // Transaction status is hard to encode, so we do not support it. + s.txn.mu.Lock() + valid := s.txn.Valid() + s.txn.mu.Unlock() + if valid { + return sessionstates.ErrCannotMigrateSession.GenWithStackByArgs("session has an active transaction") + } + // Data in local temporary tables is hard to encode, so we do not support it. + // Check temporary tables here to avoid circle dependency. + if s.sessionVars.LocalTemporaryTables != nil { + localTempTables := s.sessionVars.LocalTemporaryTables.(*infoschema.SessionTables) + if localTempTables.Count() > 0 { + return sessionstates.ErrCannotMigrateSession.GenWithStackByArgs("session has local temporary tables") + } + } + // The advisory locks will be released when the session is closed. + if len(s.advisoryLocks) > 0 { + return sessionstates.ErrCannotMigrateSession.GenWithStackByArgs("session has advisory locks") + } + // The TableInfo stores session ID and server ID, so the session cannot be migrated. + if len(s.lockedTables) > 0 { + return sessionstates.ErrCannotMigrateSession.GenWithStackByArgs("session has locked tables") + } + // It's insecure to migrate sandBoxMode because users can fake it. + if s.InSandBoxMode() { + return sessionstates.ErrCannotMigrateSession.GenWithStackByArgs("session is in sandbox mode") + } + + if err := s.sessionVars.EncodeSessionStates(ctx, sessionStates); err != nil { + return err + } + + hasRestrictVarPriv := false + checker := privilege.GetPrivilegeManager(s) + if checker == nil || checker.RequestDynamicVerification(s.sessionVars.ActiveRoles, "RESTRICTED_VARIABLES_ADMIN", false) { + hasRestrictVarPriv = true + } + // Encode session variables. We put it here instead of SessionVars to avoid cycle import. + sessionStates.SystemVars = make(map[string]string) + for _, sv := range variable.GetSysVars() { + switch { + case sv.HasNoneScope(), !sv.HasSessionScope(): + // Hidden attribute is deprecated. + // None-scoped variables cannot be modified. + // Noop variables should also be migrated even if they are noop. + continue + case sv.ReadOnly: + // Skip read-only variables here. We encode them into SessionStates manually. + continue + } + // Get all session variables because the default values may change between versions. + val, keep, err := s.sessionVars.GetSessionStatesSystemVar(sv.Name) + switch { + case err != nil: + return err + case !keep: + continue + case !hasRestrictVarPriv && sem.IsEnabled() && sem.IsInvisibleSysVar(sv.Name): + // If the variable has a global scope, it should be the same with the global one. + // Otherwise, it should be the same with the default value. + defaultVal := sv.Value + if sv.HasGlobalScope() { + // If the session value is the same with the global one, skip it. + if defaultVal, err = sv.GetGlobalFromHook(ctx, s.sessionVars); err != nil { + return err + } + } + if val != defaultVal { + // Case 1: the RESTRICTED_VARIABLES_ADMIN is revoked after setting the session variable. + // Case 2: the global variable is updated after the session is created. + // In any case, the variable can't be set in the new session, so give up. + return sessionstates.ErrCannotMigrateSession.GenWithStackByArgs(fmt.Sprintf("session has set invisible variable '%s'", sv.Name)) + } + default: + sessionStates.SystemVars[sv.Name] = val + } + } + + // Encode prepared statements and sql bindings. + for _, handler := range s.sessionStatesHandlers { + if err := handler.EncodeSessionStates(ctx, s, sessionStates); err != nil { + return err + } + } + return nil +} + +// DecodeSessionStates implements SessionStatesHandler.DecodeSessionStates interface. +func (s *session) DecodeSessionStates(ctx context.Context, + _ sessionctx.Context, sessionStates *sessionstates.SessionStates) error { + // Decode prepared statements and sql bindings. + for _, handler := range s.sessionStatesHandlers { + if err := handler.DecodeSessionStates(ctx, s, sessionStates); err != nil { + return err + } + } + + // Decode session variables. + names := variable.OrderByDependency(sessionStates.SystemVars) + // Some variables must be set before others, e.g. tidb_enable_noop_functions should be before noop variables. + for _, name := range names { + val := sessionStates.SystemVars[name] + // Experimental system variables may change scope, data types, or even be removed. + // We just ignore the errors and continue. + if err := s.sessionVars.SetSystemVar(name, val); err != nil { + logutil.Logger(ctx).Warn("set session variable during decoding session states error", + zap.String("name", name), zap.String("value", val), zap.Error(err)) + } + } + + // Decoding session vars / prepared statements may override stmt ctx, such as warnings, + // so we decode stmt ctx at last. + return s.sessionVars.DecodeSessionStates(ctx, sessionStates) +} + +func (s *session) setRequestSource(ctx context.Context, stmtLabel string, stmtNode ast.StmtNode) { + if !s.isInternal() { + if txn, _ := s.Txn(false); txn != nil && txn.Valid() { + txn.SetOption(kv.RequestSourceType, stmtLabel) + } + s.sessionVars.RequestSourceType = stmtLabel + return + } + if source := ctx.Value(kv.RequestSourceKey); source != nil { + requestSource := source.(kv.RequestSource) + if requestSource.RequestSourceType != "" { + s.sessionVars.RequestSourceType = requestSource.RequestSourceType + return + } + } + // panic in test mode in case there are requests without source in the future. + // log warnings in production mode. + if intest.InTest { + panic("unexpected no source type context, if you see this error, " + + "the `RequestSourceTypeKey` is missing in your context") + } + logutil.Logger(ctx).Warn("unexpected no source type context, if you see this warning, "+ + "the `RequestSourceTypeKey` is missing in the context", + zap.Bool("internal", s.isInternal()), + zap.String("sql", stmtNode.Text())) +} + +// RemoveLockDDLJobs removes the DDL jobs which doesn't get the metadata lock from job2ver. +func RemoveLockDDLJobs(s Session, job2ver map[int64]int64, job2ids map[int64]string, printLog bool) { + sv := s.GetSessionVars() + if sv.InRestrictedSQL { + return + } + sv.TxnCtxMu.Lock() + defer sv.TxnCtxMu.Unlock() + if sv.TxnCtx == nil { + return + } + sv.GetRelatedTableForMDL().Range(func(tblID, value any) bool { + for jobID, ver := range job2ver { + ids := util.Str2Int64Map(job2ids[jobID]) + if _, ok := ids[tblID.(int64)]; ok && value.(int64) < ver { + delete(job2ver, jobID) + elapsedTime := time.Since(oracle.GetTimeFromTS(sv.TxnCtx.StartTS)) + if elapsedTime > time.Minute && printLog { + logutil.BgLogger().Info("old running transaction block DDL", zap.Int64("table ID", tblID.(int64)), zap.Int64("jobID", jobID), zap.Uint64("connection ID", sv.ConnectionID), zap.Duration("elapsed time", elapsedTime)) + } else { + logutil.BgLogger().Debug("old running transaction block DDL", zap.Int64("table ID", tblID.(int64)), zap.Int64("jobID", jobID), zap.Uint64("connection ID", sv.ConnectionID), zap.Duration("elapsed time", elapsedTime)) + } + } + } + return true + }) +} + +// GetDBNames gets the sql layer database names from the session. +func GetDBNames(seVar *variable.SessionVars) []string { + dbNames := make(map[string]struct{}) + if seVar == nil || !config.GetGlobalConfig().Status.RecordDBLabel { + return []string{""} + } + if seVar.StmtCtx != nil { + for _, t := range seVar.StmtCtx.Tables { + dbNames[t.DB] = struct{}{} + } + } + if len(dbNames) == 0 { + dbNames[seVar.CurrentDB] = struct{}{} + } + ns := make([]string, 0, len(dbNames)) + for n := range dbNames { + ns = append(ns, n) + } + return ns +} diff --git a/pkg/session/test/txn/txn_test.go b/pkg/session/test/txn/txn_test.go new file mode 100644 index 0000000000000..3f5893157ea34 --- /dev/null +++ b/pkg/session/test/txn/txn_test.go @@ -0,0 +1,631 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package txn + +import ( + "context" + "fmt" + "strings" + "sync" + "testing" + "time" + + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/config" + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/parser/auth" + "github.com/pingcap/tidb/pkg/parser/mysql" + "github.com/pingcap/tidb/pkg/parser/terror" + plannercore "github.com/pingcap/tidb/pkg/planner/core" + "github.com/pingcap/tidb/pkg/testkit" + "github.com/stretchr/testify/require" +) + +// TestAutocommit . See https://dev.mysql.com/doc/internals/en/status-flags.html +func TestAutocommit(t *testing.T) { + store := testkit.CreateMockStore(t) + + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + + tk.MustExec("drop table if exists t;") + require.Greater(t, int(tk.Session().Status()&mysql.ServerStatusAutocommit), 0) + tk.MustExec("create table t (id BIGINT PRIMARY KEY AUTO_INCREMENT NOT NULL)") + require.Greater(t, int(tk.Session().Status()&mysql.ServerStatusAutocommit), 0) + tk.MustExec("insert t values ()") + require.Greater(t, int(tk.Session().Status()&mysql.ServerStatusAutocommit), 0) + tk.MustExec("begin") + require.Greater(t, int(tk.Session().Status()&mysql.ServerStatusAutocommit), 0) + tk.MustExec("insert t values ()") + require.Greater(t, int(tk.Session().Status()&mysql.ServerStatusAutocommit), 0) + tk.MustExec("drop table if exists t") + require.Greater(t, int(tk.Session().Status()&mysql.ServerStatusAutocommit), 0) + + tk.MustExec("create table t (id BIGINT PRIMARY KEY AUTO_INCREMENT NOT NULL)") + require.Greater(t, int(tk.Session().Status()&mysql.ServerStatusAutocommit), 0) + tk.MustExec("set autocommit=0") + require.Equal(t, 0, int(tk.Session().Status()&mysql.ServerStatusAutocommit)) + tk.MustExec("insert t values ()") + require.Equal(t, 0, int(tk.Session().Status()&mysql.ServerStatusAutocommit)) + tk.MustExec("commit") + require.Equal(t, 0, int(tk.Session().Status()&mysql.ServerStatusAutocommit)) + tk.MustExec("drop table if exists t") + require.Equal(t, 0, int(tk.Session().Status()&mysql.ServerStatusAutocommit)) + tk.MustExec("set autocommit='On'") + require.Greater(t, int(tk.Session().Status()&mysql.ServerStatusAutocommit), 0) + + // When autocommit is 0, transaction start ts should be the first *valid* + // statement, rather than *any* statement. + tk.MustExec("create table t (id int key)") + tk.MustExec("set @@autocommit = 0") + tk.MustExec("rollback") + tk.MustExec("set @@autocommit = 0") + + tk1 := testkit.NewTestKit(t, store) + tk1.MustExec("use test") + tk1.MustExec("insert into t select 1") + //nolint:all_revive,revive + tk.MustQuery("select * from t").Check(testkit.Rows("1")) + tk.MustExec("delete from t") + + // When the transaction is rolled back, the global set statement would succeed. + tk.MustExec("set @@global.autocommit = 0") + tk.MustExec("begin") + tk.MustExec("insert into t values (1)") + tk.MustExec("set @@global.autocommit = 1") + tk.MustExec("rollback") + tk.MustQuery("select count(*) from t where id = 1").Check(testkit.Rows("0")) + tk.MustQuery("select @@global.autocommit").Check(testkit.Rows("1")) + + // When the transaction is committed because of switching mode, the session set statement shold succeed. + tk.MustExec("set autocommit = 0") + tk.MustExec("begin") + tk.MustExec("insert into t values (1)") + tk.MustExec("set autocommit = 1") + tk.MustExec("rollback") + tk.MustQuery("select count(*) from t where id = 1").Check(testkit.Rows("1")) + tk.MustQuery("select @@autocommit").Check(testkit.Rows("1")) + + tk.MustExec("set autocommit = 0") + tk.MustExec("insert into t values (2)") + tk.MustExec("set autocommit = 1") + tk.MustExec("rollback") + tk.MustQuery("select count(*) from t where id = 2").Check(testkit.Rows("1")) + tk.MustQuery("select @@autocommit").Check(testkit.Rows("1")) + + // Set should not take effect if the mode is not changed. + tk.MustExec("set autocommit = 0") + tk.MustExec("begin") + tk.MustExec("insert into t values (3)") + tk.MustExec("set autocommit = 0") + tk.MustExec("rollback") + tk.MustQuery("select count(*) from t where id = 3").Check(testkit.Rows("0")) + tk.MustQuery("select @@autocommit").Check(testkit.Rows("0")) + + tk.MustExec("set autocommit = 1") + tk.MustExec("begin") + tk.MustExec("insert into t values (4)") + tk.MustExec("set autocommit = 1") + tk.MustExec("rollback") + tk.MustQuery("select count(*) from t where id = 4").Check(testkit.Rows("0")) + tk.MustQuery("select @@autocommit").Check(testkit.Rows("1")) +} + +// TestTxnLazyInitialize tests that when autocommit = 0, not all statement starts +// a new transaction. +func TestTxnLazyInitialize(t *testing.T) { + testTxnLazyInitialize(t, false) + testTxnLazyInitialize(t, true) +} + +func testTxnLazyInitialize(t *testing.T, isPessimistic bool) { + store := testkit.CreateMockStore(t) + + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t") + tk.MustExec("create table t (id int)") + if isPessimistic { + tk.MustExec("set tidb_txn_mode = 'pessimistic'") + } + + tk.MustExec("set @@autocommit = 0") + _, err := tk.Session().Txn(true) + require.True(t, kv.ErrInvalidTxn.Equal(err)) + txn, err := tk.Session().Txn(false) + require.NoError(t, err) + require.False(t, txn.Valid()) + tk.MustQuery("select @@tidb_current_ts").Check(testkit.Rows("0")) + tk.MustQuery("select @@tidb_current_ts").Check(testkit.Rows("0")) + + // Those statements should not start a new transaction automatically. + tk.MustQuery("select 1") + tk.MustQuery("select @@tidb_current_ts").Check(testkit.Rows("0")) + + tk.MustExec("set @@tidb_general_log = 0") + tk.MustQuery("select @@tidb_current_ts").Check(testkit.Rows("0")) + + tk.MustQuery("explain select * from t") + tk.MustQuery("select @@tidb_current_ts").Check(testkit.Rows("0")) + + // Begin statement should start a new transaction. + tk.MustExec("begin") + txn, err = tk.Session().Txn(false) + require.NoError(t, err) + require.True(t, txn.Valid()) + tk.MustExec("rollback") + + tk.MustExec("select * from t") + txn, err = tk.Session().Txn(false) + require.NoError(t, err) + require.True(t, txn.Valid()) + tk.MustExec("rollback") + + tk.MustExec("insert into t values (1)") + txn, err = tk.Session().Txn(false) + require.NoError(t, err) + require.True(t, txn.Valid()) + tk.MustExec("rollback") +} + +func TestDisableTxnAutoRetry(t *testing.T) { + store := testkit.CreateMockStoreWithSchemaLease(t, 1*time.Second) + + setTxnTk := testkit.NewTestKit(t, store) + setTxnTk.MustExec("set global tidb_txn_mode=''") + tk1 := testkit.NewTestKit(t, store) + tk2 := testkit.NewTestKit(t, store) + + tk1.MustExec("use test") + tk2.MustExec("use test") + + tk1.MustExec("create table no_retry (id int)") + tk1.MustExec("insert into no_retry values (1)") + tk1.MustExec("set @@tidb_disable_txn_auto_retry = 1") + + tk1.MustExec("begin") + tk1.MustExec("update no_retry set id = 2") + + tk2.MustExec("begin") + tk2.MustExec("update no_retry set id = 3") + tk2.MustExec("commit") + + // No auto retry because tidb_disable_txn_auto_retry is set to 1. + _, err := tk1.Session().Execute(context.Background(), "commit") + require.Error(t, err) + + // session 1 starts a transaction early. + // execute a select statement to clear retry history. + tk1.MustExec("select 1") + err = tk1.Session().PrepareTxnCtx(context.Background()) + require.NoError(t, err) + // session 2 update the value. + tk2.MustExec("update no_retry set id = 4") + // AutoCommit update will retry, so it would not fail. + tk1.MustExec("update no_retry set id = 5") + + // RestrictedSQL should retry. + ctx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnOthers) + tk1.Session().ExecuteInternal(ctx, "begin") + + tk2.MustExec("update no_retry set id = 6") + + tk1.Session().ExecuteInternal(ctx, "update no_retry set id = 7") + tk1.Session().ExecuteInternal(ctx, "commit") + + // test for disable transaction local latch + defer config.RestoreFunc()() + config.UpdateGlobal(func(conf *config.Config) { + conf.TxnLocalLatches.Enabled = false + }) + tk1.MustExec("begin") + tk1.MustExec("update no_retry set id = 9") + + tk2.MustExec("update no_retry set id = 8") + + _, err = tk1.Session().Execute(context.Background(), "commit") + require.Error(t, err) + require.True(t, kv.ErrWriteConflict.Equal(err), fmt.Sprintf("err %v", err)) + require.Contains(t, err.Error(), kv.TxnRetryableMark) + tk1.MustExec("rollback") + + config.UpdateGlobal(func(conf *config.Config) { + conf.TxnLocalLatches.Enabled = true + }) + tk1.MustExec("begin") + tk2.MustExec("alter table no_retry add index idx(id)") + tk2.MustQuery("select * from no_retry").Check(testkit.Rows("8")) + tk1.MustExec("update no_retry set id = 10") + _, err = tk1.Session().Execute(context.Background(), "commit") + require.Error(t, err) + + // set autocommit to begin and commit + tk1.MustExec("set autocommit = 0") + tk1.MustQuery("select * from no_retry").Check(testkit.Rows("8")) + tk2.MustExec("update no_retry set id = 11") + tk1.MustExec("update no_retry set id = 12") + _, err = tk1.Session().Execute(context.Background(), "set autocommit = 1") + require.Error(t, err) + require.True(t, kv.ErrWriteConflict.Equal(err), fmt.Sprintf("err %v", err)) + require.Contains(t, err.Error(), kv.TxnRetryableMark) + tk1.MustExec("rollback") + tk2.MustQuery("select * from no_retry").Check(testkit.Rows("11")) + + tk1.MustExec("set autocommit = 0") + tk1.MustQuery("select * from no_retry").Check(testkit.Rows("11")) + tk2.MustExec("update no_retry set id = 13") + tk1.MustExec("update no_retry set id = 14") + _, err = tk1.Session().Execute(context.Background(), "commit") + require.Error(t, err) + require.True(t, kv.ErrWriteConflict.Equal(err), fmt.Sprintf("err %v", err)) + require.Contains(t, err.Error(), kv.TxnRetryableMark) + tk1.MustExec("rollback") + tk2.MustQuery("select * from no_retry").Check(testkit.Rows("13")) +} + +// The Read-only flags are checked in the planning stage of queries, +// but this test checks we check them again at commit time. +// The main use case for this is a long-running auto-commit statement. +func TestAutoCommitRespectsReadOnly(t *testing.T) { + store := testkit.CreateMockStore(t) + var wg sync.WaitGroup + tk1 := testkit.NewTestKit(t, store) + tk2 := testkit.NewTestKit(t, store) + require.NoError(t, tk1.Session().Auth(&auth.UserIdentity{Username: "root", Hostname: "%"}, nil, nil, nil)) + require.NoError(t, tk2.Session().Auth(&auth.UserIdentity{Username: "root", Hostname: "%"}, nil, nil, nil)) + + tk1.MustExec("create table test.auto_commit_test (a int)") + wg.Add(1) + go func() { + err := tk1.ExecToErr("INSERT INTO test.auto_commit_test VALUES (SLEEP(1))") + require.True(t, terror.ErrorEqual(err, plannercore.ErrSQLInReadOnlyMode), fmt.Sprintf("err %v", err)) + wg.Done() + }() + tk2.MustExec("SET GLOBAL tidb_restricted_read_only = 1") + err := tk2.ExecToErr("INSERT INTO test.auto_commit_test VALUES (0)") // should also be an error + require.True(t, terror.ErrorEqual(err, plannercore.ErrSQLInReadOnlyMode), fmt.Sprintf("err %v", err)) + // Reset and check with the privilege to ignore the readonly flag and continue to insert. + wg.Wait() + tk1.MustExec("SET GLOBAL tidb_restricted_read_only = 0") + tk1.MustExec("SET GLOBAL tidb_super_read_only = 0") + tk1.MustExec("GRANT RESTRICTED_REPLICA_WRITER_ADMIN on *.* to 'root'") + + wg.Add(1) + go func() { + tk1.MustExec("INSERT INTO test.auto_commit_test VALUES (SLEEP(1))") + wg.Done() + }() + tk2.MustExec("SET GLOBAL tidb_restricted_read_only = 1") + tk2.MustExec("INSERT INTO test.auto_commit_test VALUES (0)") + + // wait for go routines + wg.Wait() + tk1.MustExec("SET GLOBAL tidb_restricted_read_only = 0") + tk1.MustExec("SET GLOBAL tidb_super_read_only = 0") +} + +func TestRetryForCurrentTxn(t *testing.T) { + store := testkit.CreateMockStore(t) + + setTxnTk := testkit.NewTestKit(t, store) + setTxnTk.MustExec("set global tidb_txn_mode=''") + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + + tk.MustExec("create table history (a int)") + tk.MustExec("insert history values (1)") + + // Firstly, enable retry. + tk.MustExec("set tidb_disable_txn_auto_retry = 0") + tk.MustExec("begin") + tk.MustExec("update history set a = 2") + // Disable retry now. + tk.MustExec("set tidb_disable_txn_auto_retry = 1") + + tk1 := testkit.NewTestKit(t, store) + tk1.MustExec("use test") + tk1.MustExec("update history set a = 3") + + tk.MustExec("commit") + tk.MustQuery("select * from history").Check(testkit.Rows("2")) +} + +func TestBatchCommit(t *testing.T) { + store := testkit.CreateMockStore(t) + setTxnTk := testkit.NewTestKit(t, store) + setTxnTk.MustExec("set global tidb_txn_mode=''") + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("set tidb_batch_commit = 1") + tk.MustExec("set tidb_disable_txn_auto_retry = 0") + tk.MustExec("create table t (id int)") + defer config.RestoreFunc()() + config.UpdateGlobal(func(conf *config.Config) { + conf.Performance.StmtCountLimit = 3 + }) + tk1 := testkit.NewTestKit(t, store) + tk1.MustExec("use test") + tk.MustExec("SET SESSION autocommit = 1") + tk.MustExec("begin") + tk.MustExec("insert into t values (1)") + tk1.MustQuery("select * from t").Check(testkit.Rows()) + tk.MustExec("insert into t values (2)") + tk1.MustQuery("select * from t").Check(testkit.Rows()) + tk.MustExec("rollback") + tk1.MustQuery("select * from t").Check(testkit.Rows()) + + // The above rollback will not make the session in transaction. + tk.MustExec("insert into t values (1)") + tk1.MustQuery("select * from t").Check(testkit.Rows("1")) + tk.MustExec("delete from t") + + tk.MustExec("begin") + tk.MustExec("insert into t values (5)") + tk1.MustQuery("select * from t").Check(testkit.Rows()) + tk.MustExec("insert into t values (6)") + tk1.MustQuery("select * from t").Check(testkit.Rows()) + tk.MustExec("insert into t values (7)") + tk1.MustQuery("select * from t").Check(testkit.Rows("5", "6", "7")) + + tk.MustExec("delete from t") + tk.MustExec("commit") + tk.MustExec("begin") + tk.MustExec("explain analyze insert into t values (5)") + tk1.MustQuery("select * from t").Check(testkit.Rows()) + tk.MustExec("explain analyze insert into t values (6)") + tk1.MustQuery("select * from t").Check(testkit.Rows()) + tk.MustExec("explain analyze insert into t values (7)") + tk1.MustQuery("select * from t").Check(testkit.Rows("5", "6", "7")) + + // The session is still in transaction. + tk.MustExec("insert into t values (8)") + tk1.MustQuery("select * from t").Check(testkit.Rows("5", "6", "7")) + tk.MustExec("insert into t values (9)") + tk1.MustQuery("select * from t").Check(testkit.Rows("5", "6", "7")) + tk.MustExec("insert into t values (10)") + tk1.MustQuery("select * from t").Check(testkit.Rows("5", "6", "7")) + tk.MustExec("commit") + tk1.MustQuery("select * from t").Check(testkit.Rows("5", "6", "7", "8", "9", "10")) + + // The above commit will not make the session in transaction. + tk.MustExec("insert into t values (11)") + tk1.MustQuery("select * from t").Check(testkit.Rows("5", "6", "7", "8", "9", "10", "11")) + + tk.MustExec("delete from t") + tk.MustExec("SET SESSION autocommit = 0") + tk.MustExec("insert into t values (1)") + tk.MustExec("insert into t values (2)") + tk.MustExec("insert into t values (3)") + tk.MustExec("rollback") + tk1.MustExec("insert into t values (4)") + tk1.MustExec("insert into t values (5)") + tk.MustQuery("select * from t").Check(testkit.Rows("4", "5")) +} + +func TestTxnRetryErrMsg(t *testing.T) { + store := testkit.CreateMockStore(t) + setTxnTk := testkit.NewTestKit(t, store) + setTxnTk.MustExec("set global tidb_txn_mode=''") + tk1 := testkit.NewTestKit(t, store) + tk2 := testkit.NewTestKit(t, store) + tk1.MustExec("use test") + tk1.MustExec("create table no_retry (id int)") + tk1.MustExec("insert into no_retry values (1)") + tk1.MustExec("begin") + tk2.MustExec("use test") + tk2.MustExec("update no_retry set id = id + 1") + tk1.MustExec("update no_retry set id = id + 1") + require.NoError(t, failpoint.Enable("tikvclient/mockRetryableErrorResp", `return(true)`)) + _, err := tk1.Session().Execute(context.Background(), "commit") + require.NoError(t, failpoint.Disable("tikvclient/mockRetryableErrorResp")) + require.Error(t, err) + require.True(t, kv.ErrTxnRetryable.Equal(err), "error: %s", err) + require.True(t, strings.Contains(err.Error(), "mock retryable error"), "error: %s", err) + require.True(t, strings.Contains(err.Error(), kv.TxnRetryableMark), "error: %s", err) +} + +func TestSetTxnScope(t *testing.T) { + // Check the default value of @@tidb_enable_local_txn and @@txn_scope whitout configuring the zone label. + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustQuery("select @@global.tidb_enable_local_txn;").Check(testkit.Rows("0")) + tk.MustQuery("select @@txn_scope;").Check(testkit.Rows(kv.GlobalTxnScope)) + require.Equal(t, kv.GlobalTxnScope, tk.Session().GetSessionVars().CheckAndGetTxnScope()) + // Check the default value of @@tidb_enable_local_txn and @@txn_scope with configuring the zone label. + require.NoError(t, failpoint.Enable("tikvclient/injectTxnScope", `return("bj")`)) + tk = testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustQuery("select @@global.tidb_enable_local_txn;").Check(testkit.Rows("0")) + tk.MustQuery("select @@txn_scope;").Check(testkit.Rows(kv.GlobalTxnScope)) + require.Equal(t, kv.GlobalTxnScope, tk.Session().GetSessionVars().CheckAndGetTxnScope()) + require.NoError(t, failpoint.Disable("tikvclient/injectTxnScope")) + + // @@tidb_enable_local_txn is off without configuring the zone label. + tk = testkit.NewTestKit(t, store) + tk.MustQuery("select @@global.tidb_enable_local_txn;").Check(testkit.Rows("0")) + tk.MustQuery("select @@txn_scope;").Check(testkit.Rows(kv.GlobalTxnScope)) + require.Equal(t, kv.GlobalTxnScope, tk.Session().GetSessionVars().CheckAndGetTxnScope()) + // Set @@txn_scope to local. + err := tk.ExecToErr("set @@txn_scope = 'local';") + require.Error(t, err) + require.Regexp(t, `.*txn_scope can not be set to local when tidb_enable_local_txn is off.*`, err) + tk.MustQuery("select @@txn_scope;").Check(testkit.Rows(kv.GlobalTxnScope)) + require.Equal(t, kv.GlobalTxnScope, tk.Session().GetSessionVars().CheckAndGetTxnScope()) + // Set @@txn_scope to global. + tk.MustExec("set @@txn_scope = 'global';") + tk.MustQuery("select @@txn_scope;").Check(testkit.Rows(kv.GlobalTxnScope)) + require.Equal(t, kv.GlobalTxnScope, tk.Session().GetSessionVars().CheckAndGetTxnScope()) + + // @@tidb_enable_local_txn is off with configuring the zone label. + require.NoError(t, failpoint.Enable("tikvclient/injectTxnScope", `return("bj")`)) + tk = testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustQuery("select @@global.tidb_enable_local_txn;").Check(testkit.Rows("0")) + tk.MustQuery("select @@txn_scope;").Check(testkit.Rows(kv.GlobalTxnScope)) + require.Equal(t, kv.GlobalTxnScope, tk.Session().GetSessionVars().CheckAndGetTxnScope()) + // Set @@txn_scope to local. + err = tk.ExecToErr("set @@txn_scope = 'local';") + require.Error(t, err) + require.Regexp(t, `.*txn_scope can not be set to local when tidb_enable_local_txn is off.*`, err) + tk.MustQuery("select @@txn_scope;").Check(testkit.Rows(kv.GlobalTxnScope)) + require.Equal(t, kv.GlobalTxnScope, tk.Session().GetSessionVars().CheckAndGetTxnScope()) + // Set @@txn_scope to global. + tk.MustExec("set @@txn_scope = 'global';") + tk.MustQuery("select @@txn_scope;").Check(testkit.Rows(kv.GlobalTxnScope)) + require.Equal(t, kv.GlobalTxnScope, tk.Session().GetSessionVars().CheckAndGetTxnScope()) + require.NoError(t, failpoint.Disable("tikvclient/injectTxnScope")) + + // @@tidb_enable_local_txn is on without configuring the zone label. + tk = testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("set global tidb_enable_local_txn = on;") + tk.MustQuery("select @@txn_scope;").Check(testkit.Rows(kv.GlobalTxnScope)) + require.Equal(t, kv.GlobalTxnScope, tk.Session().GetSessionVars().CheckAndGetTxnScope()) + // Set @@txn_scope to local. + err = tk.ExecToErr("set @@txn_scope = 'local';") + require.Error(t, err) + require.Regexp(t, `.*txn_scope can not be set to local when zone label is empty or "global".*`, err) + tk.MustQuery("select @@txn_scope;").Check(testkit.Rows(kv.GlobalTxnScope)) + require.Equal(t, kv.GlobalTxnScope, tk.Session().GetSessionVars().CheckAndGetTxnScope()) + // Set @@txn_scope to global. + tk.MustExec("set @@txn_scope = 'global';") + tk.MustQuery("select @@txn_scope;").Check(testkit.Rows(kv.GlobalTxnScope)) + require.Equal(t, kv.GlobalTxnScope, tk.Session().GetSessionVars().CheckAndGetTxnScope()) + + // @@tidb_enable_local_txn is on with configuring the zone label. + require.NoError(t, failpoint.Enable("tikvclient/injectTxnScope", `return("bj")`)) + tk = testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("set global tidb_enable_local_txn = on;") + tk.MustQuery("select @@txn_scope;").Check(testkit.Rows(kv.LocalTxnScope)) + require.Equal(t, "bj", tk.Session().GetSessionVars().CheckAndGetTxnScope()) + // Set @@txn_scope to global. + tk.MustExec("set @@txn_scope = 'global';") + tk.MustQuery("select @@txn_scope;").Check(testkit.Rows(kv.GlobalTxnScope)) + require.Equal(t, kv.GlobalTxnScope, tk.Session().GetSessionVars().CheckAndGetTxnScope()) + // Set @@txn_scope to local. + tk.MustExec("set @@txn_scope = 'local';") + tk.MustQuery("select @@txn_scope;").Check(testkit.Rows(kv.LocalTxnScope)) + require.Equal(t, "bj", tk.Session().GetSessionVars().CheckAndGetTxnScope()) + // Try to set @@txn_scope to an invalid value. + err = tk.ExecToErr("set @@txn_scope='foo'") + require.Error(t, err) + require.Regexp(t, `.*txn_scope value should be global or local.*`, err) + require.NoError(t, failpoint.Disable("tikvclient/injectTxnScope")) +} + +func TestErrorRollback(t *testing.T) { + store := testkit.CreateMockStore(t) + + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t_rollback") + tk.MustExec("create table t_rollback (c1 int, c2 int, primary key(c1))") + tk.MustExec("insert into t_rollback values (0, 0)") + + var wg sync.WaitGroup + cnt := 4 + wg.Add(cnt) + num := 20 + + for i := 0; i < cnt; i++ { + go func() { + defer wg.Done() + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("set @@session.tidb_retry_limit = 100") + for j := 0; j < num; j++ { + _, _ = tk.Exec("insert into t_rollback values (1, 1)") + tk.MustExec("update t_rollback set c2 = c2 + 1 where c1 = 0") + } + }() + } + + wg.Wait() + tk.MustQuery("select c2 from t_rollback where c1 = 0").Check(testkit.Rows(fmt.Sprint(cnt * num))) +} + +// TestInTrans . See https://dev.mysql.com/doc/internals/en/status-flags.html +func TestInTrans(t *testing.T) { + store := testkit.CreateMockStore(t) + + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t;") + tk.MustExec("create table t (id BIGINT PRIMARY KEY AUTO_INCREMENT NOT NULL)") + tk.MustExec("insert t values ()") + tk.MustExec("begin") + txn, err := tk.Session().Txn(true) + require.NoError(t, err) + require.True(t, txn.Valid()) + tk.MustExec("insert t values ()") + require.True(t, txn.Valid()) + tk.MustExec("drop table if exists t;") + require.False(t, txn.Valid()) + tk.MustExec("create table t (id BIGINT PRIMARY KEY AUTO_INCREMENT NOT NULL)") + require.False(t, txn.Valid()) + tk.MustExec("insert t values ()") + require.False(t, txn.Valid()) + tk.MustExec("commit") + tk.MustExec("insert t values ()") + + tk.MustExec("set autocommit=0") + tk.MustExec("begin") + require.True(t, txn.Valid()) + tk.MustExec("insert t values ()") + require.True(t, txn.Valid()) + tk.MustExec("commit") + require.False(t, txn.Valid()) + tk.MustExec("insert t values ()") + require.True(t, txn.Valid()) + tk.MustExec("commit") + require.False(t, txn.Valid()) + + tk.MustExec("set autocommit=1") + tk.MustExec("drop table if exists t") + tk.MustExec("create table t (id BIGINT PRIMARY KEY AUTO_INCREMENT NOT NULL)") + tk.MustExec("begin") + require.True(t, txn.Valid()) + tk.MustExec("insert t values ()") + require.True(t, txn.Valid()) + tk.MustExec("rollback") + require.False(t, txn.Valid()) +} + +func TestCommitRetryCount(t *testing.T) { + store := testkit.CreateMockStore(t) + + setTxnTk := testkit.NewTestKit(t, store) + setTxnTk.MustExec("set global tidb_txn_mode=''") + tk1 := testkit.NewTestKit(t, store) + tk1.MustExec("use test") + tk2 := testkit.NewTestKit(t, store) + tk2.MustExec("use test") + + tk1.MustExec("create table no_retry (id int)") + tk1.MustExec("insert into no_retry values (1)") + tk1.MustExec("set @@tidb_retry_limit = 0") + + tk1.MustExec("begin") + tk1.MustExec("update no_retry set id = 2") + + tk2.MustExec("begin") + tk2.MustExec("update no_retry set id = 3") + tk2.MustExec("commit") + + // No auto retry because retry limit is set to 0. + require.Error(t, tk1.ExecToErr("commit")) +} diff --git a/server/server_test.go b/server/server_test.go index f81d8ff5c9069..e0cae87404b51 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -35,11 +35,22 @@ import ( "github.com/pingcap/errors" "github.com/pingcap/failpoint" "github.com/pingcap/log" +<<<<<<< HEAD:server/server_test.go "github.com/pingcap/tidb/errno" "github.com/pingcap/tidb/kv" tmysql "github.com/pingcap/tidb/parser/mysql" "github.com/pingcap/tidb/testkit" "github.com/pingcap/tidb/util/versioninfo" +======= + "github.com/pingcap/tidb/pkg/config" + "github.com/pingcap/tidb/pkg/errno" + "github.com/pingcap/tidb/pkg/kv" + tmysql "github.com/pingcap/tidb/pkg/parser/mysql" + "github.com/pingcap/tidb/pkg/server" + "github.com/pingcap/tidb/pkg/testkit" + "github.com/pingcap/tidb/pkg/testkit/testenv" + "github.com/pingcap/tidb/pkg/util/versioninfo" +>>>>>>> 9d6d6fd3da1 (session: fix select for update statement can't get stmt-count-limit error (#48412)):pkg/server/internal/testserverclient/server_client.go "github.com/stretchr/testify/require" "go.uber.org/zap" ) @@ -2244,3 +2255,81 @@ func (cli *testServerClient) runTestInfoschemaClientErrors(t *testing.T) { }) } +<<<<<<< HEAD:server/server_test.go +======= + +func (cli *TestServerClient) RunTestStmtCountLimit(t *testing.T) { + originalStmtCountLimit := config.GetGlobalConfig().Performance.StmtCountLimit + config.UpdateGlobal(func(conf *config.Config) { + conf.Performance.StmtCountLimit = 3 + }) + defer func() { + config.UpdateGlobal(func(conf *config.Config) { + conf.Performance.StmtCountLimit = originalStmtCountLimit + }) + }() + + cli.RunTests(t, nil, func(dbt *testkit.DBTestKit) { + dbt.MustExec("create table t (id int key);") + dbt.MustExec("set @@tidb_disable_txn_auto_retry=0;") + dbt.MustExec("set autocommit=0;") + dbt.MustExec("begin optimistic;") + dbt.MustExec("insert into t values (1);") + dbt.MustExec("insert into t values (2);") + _, err := dbt.GetDB().Query("select * from t for update;") + require.Error(t, err) + require.Equal(t, "Error 1105 (HY000): statement count 4 exceeds the transaction limitation, transaction has been rollback, autocommit = false", err.Error()) + dbt.MustExec("insert into t values (3);") + dbt.MustExec("commit;") + rows := dbt.MustQuery("select * from t;") + var id int + count := 0 + for rows.Next() { + rows.Scan(&id) + count++ + } + require.NoError(t, rows.Close()) + require.Equal(t, 3, id) + require.Equal(t, 1, count) + + dbt.MustExec("delete from t;") + dbt.MustExec("commit;") + dbt.MustExec("set @@tidb_disable_txn_auto_retry=0;") + dbt.MustExec("set autocommit=0;") + dbt.MustExec("begin optimistic;") + dbt.MustExec("insert into t values (1);") + dbt.MustExec("insert into t values (2);") + _, err = dbt.GetDB().Exec("insert into t values (3);") + require.Error(t, err) + require.Equal(t, "Error 1105 (HY000): statement count 4 exceeds the transaction limitation, transaction has been rollback, autocommit = false", err.Error()) + dbt.MustExec("commit;") + rows = dbt.MustQuery("select count(*) from t;") + for rows.Next() { + rows.Scan(&count) + } + require.NoError(t, rows.Close()) + require.Equal(t, 0, count) + + dbt.MustExec("delete from t;") + dbt.MustExec("commit;") + dbt.MustExec("set @@tidb_batch_commit=1;") + dbt.MustExec("set @@tidb_disable_txn_auto_retry=0;") + dbt.MustExec("set autocommit=0;") + dbt.MustExec("begin optimistic;") + dbt.MustExec("insert into t values (1);") + dbt.MustExec("insert into t values (2);") + dbt.MustExec("insert into t values (3);") + dbt.MustExec("insert into t values (4);") + dbt.MustExec("insert into t values (5);") + dbt.MustExec("commit;") + rows = dbt.MustQuery("select count(*) from t;") + for rows.Next() { + rows.Scan(&count) + } + require.NoError(t, rows.Close()) + require.Equal(t, 5, count) + }) +} + +//revive:enable:exported +>>>>>>> 9d6d6fd3da1 (session: fix select for update statement can't get stmt-count-limit error (#48412)):pkg/server/internal/testserverclient/server_client.go diff --git a/session/tidb.go b/session/tidb.go index 911e64f3727f2..068978b301481 100644 --- a/session/tidb.go +++ b/session/tidb.go @@ -250,7 +250,7 @@ func finishStmt(ctx context.Context, se *session, meetsErr error, sql sqlexec.St if err != nil { return err } - return checkStmtLimit(ctx, se) + return checkStmtLimit(ctx, se, true) } func autoCommitAfterStmt(ctx context.Context, se *session, meetsErr error, sql sqlexec.Statement) error { @@ -280,19 +280,34 @@ func autoCommitAfterStmt(ctx context.Context, se *session, meetsErr error, sql s return nil } -func checkStmtLimit(ctx context.Context, se *session) error { +func checkStmtLimit(ctx context.Context, se *session, isFinish bool) error { // If the user insert, insert, insert ... but never commit, TiDB would OOM. // So we limit the statement count in a transaction here. var err error sessVars := se.GetSessionVars() history := GetHistory(se) - if history.Count() > int(config.GetGlobalConfig().Performance.StmtCountLimit) { + stmtCount := history.Count() + if !isFinish { + // history stmt count + current stmt, since current stmt is not finish, it has not add to history. + stmtCount++ + } + if stmtCount > int(config.GetGlobalConfig().Performance.StmtCountLimit) { if !sessVars.BatchCommit { se.RollbackTxn(ctx) - return errors.Errorf("statement count %d exceeds the transaction limitation, autocommit = %t", - history.Count(), sessVars.IsAutocommit()) + return errors.Errorf("statement count %d exceeds the transaction limitation, transaction has been rollback, autocommit = %t", + stmtCount, sessVars.IsAutocommit()) } +<<<<<<< HEAD:session/tidb.go err = se.NewTxn(ctx) +======= + if !isFinish { + // if the stmt is not finish execute, then just return, since some work need to be done such as StmtCommit. + return nil + } + // If the stmt is finish execute, and exceed the StmtCountLimit, and BatchCommit is true, + // then commit the current transaction and create a new transaction. + err = sessiontxn.NewTxn(ctx, se) +>>>>>>> 9d6d6fd3da1 (session: fix select for update statement can't get stmt-count-limit error (#48412)):pkg/session/tidb.go // The transaction does not committed yet, we need to keep it in transaction. // The last history could not be "commit"/"rollback" statement. // It means it is impossible to start a new transaction at the end of the transaction. @@ -303,6 +318,7 @@ func checkStmtLimit(ctx context.Context, se *session) error { } // GetHistory get all stmtHistory in current txn. Exported only for test. +// If stmtHistory is nil, will create a new one for current txn. func GetHistory(ctx sessionctx.Context) *StmtHistory { hist, ok := ctx.GetSessionVars().TxnCtx.History.(*StmtHistory) if ok {