From c9bcee87801f5ab280ae9904b3fed5f22ff7da2f Mon Sep 17 00:00:00 2001 From: Jeffsky Date: Sun, 12 Feb 2023 00:34:31 +0800 Subject: [PATCH] fix: handle database switch correctly --- pkg/constants/mysql/constants.go | 3 +- pkg/executor/redirect.go | 73 ++++++---- pkg/executor/redirect_test.go | 27 +++- pkg/mysql/client.go | 4 +- pkg/mysql/conn.go | 62 +++++++-- pkg/mysql/context.go | 33 ----- pkg/mysql/execute_handle.go | 53 +++---- pkg/mysql/server.go | 33 ++--- pkg/proto/interface.go | 54 +++++--- pkg/runtime/function/acos_test.go | 4 +- pkg/runtime/function/asin_test.go | 4 +- pkg/runtime/function/cast_date.go | 10 +- pkg/runtime/function/cast_datetime.go | 18 ++- pkg/runtime/function/cast_time.go | 8 +- pkg/runtime/misc/extvalue/visitor.go | 2 +- pkg/runtime/optimize/ddl/check_table.go | 2 +- pkg/runtime/plan/dml/mapping.go | 2 +- testdata/mock_interface.go | 177 ++++++++++++++++++++++++ 18 files changed, 398 insertions(+), 171 deletions(-) delete mode 100644 pkg/mysql/context.go create mode 100644 testdata/mock_interface.go diff --git a/pkg/constants/mysql/constants.go b/pkg/constants/mysql/constants.go index da58e250..c088bdd3 100644 --- a/pkg/constants/mysql/constants.go +++ b/pkg/constants/mysql/constants.go @@ -589,8 +589,7 @@ const ( // SSNoDatabaseSelected is ER_NO_DB SSNoDatabaseSelected = "3D000" - // SSSPNotExist is ER_SP_DOES_NOT_EXIST - SSSPNotExist = "42000" + SS42000 = "42000" ) // Status flags. They are returned by the server in a few cases. diff --git a/pkg/executor/redirect.go b/pkg/executor/redirect.go index 064296b7..c477c343 100644 --- a/pkg/executor/redirect.go +++ b/pkg/executor/redirect.go @@ -20,6 +20,7 @@ package executor import ( "bytes" stdErrors "errors" + "fmt" "strings" "sync" "time" @@ -32,6 +33,8 @@ import ( pMysql "github.com/arana-db/parser/mysql" "github.com/pkg/errors" + + "golang.org/x/exp/slices" ) import ( @@ -77,7 +80,7 @@ func IsErrMissingTx(err error) bool { } type RedirectExecutor struct { - localTransactionMap sync.Map // map[uint32]proto.Tx, (ConnectionID,Tx) + localTransactionMap sync.Map // map[uint32]proto.Tx, (connectionID,Tx) } func NewRedirectExecutor() *RedirectExecutor { @@ -89,7 +92,7 @@ func (executor *RedirectExecutor) ProcessDistributedTransaction() bool { } func (executor *RedirectExecutor) InLocalTransaction(ctx *proto.Context) bool { - _, ok := executor.localTransactionMap.Load(ctx.ConnectionID) + _, ok := executor.localTransactionMap.Load(ctx.C.ID()) return ok } @@ -97,22 +100,27 @@ func (executor *RedirectExecutor) InGlobalTransaction(ctx *proto.Context) bool { return false } -func (executor *RedirectExecutor) ExecuteUseDB(ctx *proto.Context) error { - // TODO: check permission, target database should belong to same tenant. - // TODO: process transactions when database switched? +func (executor *RedirectExecutor) ExecuteUseDB(ctx *proto.Context, db string) error { + if ctx.C.Schema() == db { + return nil + } + + clusters := security.DefaultTenantManager().GetClusters(ctx.C.Tenant()) + if !slices.Contains(clusters, db) { + return mysqlErrors.NewSQLError(mConstants.ERBadDb, mConstants.SS42000, fmt.Sprintf("Unknown database '%s'", db)) + } + + if hasTx := executor.InLocalTransaction(ctx); hasTx { + // TODO: should commit existing TX when DB switched + log.Debugf("commit tx when db switched: conn=%s", ctx.C) + } + + // bind schema + ctx.C.SetSchema(db) + + // reset transient variables + ctx.C.SetTransientVariables(make(map[string]proto.Value)) - // do nothing. - //resourcePool := resource.GetDataSourceManager().GetMasterResourcePool(executor.dataSources[0].Master.Name) - //r, err := resourcePool.Get(ctx) - //defer func() { - // resourcePool.Put(r) - //}() - //if err != nil { - // return err - //} - //backendConn := r.(*mysql.BackendConnection) - //db := string(ctx.Data[1:]) - //return backendConn.WriteComInitDB(db) return nil } @@ -121,7 +129,7 @@ func (executor *RedirectExecutor) ExecuteFieldList(ctx *proto.Context) ([]proto. table := string(ctx.Data[1:index]) wildcard := string(ctx.Data[index+1:]) - rt, err := runtime.Load(ctx.Schema) + rt, err := runtime.Load(ctx.C.Schema()) if err != nil { return nil, errors.WithStack(err) } @@ -141,6 +149,15 @@ func (executor *RedirectExecutor) ExecuteFieldList(ctx *proto.Context) ([]proto. } func (executor *RedirectExecutor) doExecutorComQuery(ctx *proto.Context, act ast.StmtNode) (proto.Result, uint16, error) { + // switch DB + switch u := act.(type) { + case *ast.UseStmt: + if err := executor.ExecuteUseDB(ctx, u.DBName); err != nil { + return nil, 0, err + } + return resultx.New(), 0, nil + } + var ( start = time.Now() schemaless bool // true if schema is not specified @@ -158,15 +175,15 @@ func (executor *RedirectExecutor) doExecutorComQuery(ctx *proto.Context, act ast trace.Extract(ctx, hints) metrics.ParserDuration.Observe(time.Since(start).Seconds()) - if len(ctx.Schema) < 1 { + if len(ctx.C.Schema()) < 1 { // TODO: handle multiple clusters - clusters := security.DefaultTenantManager().GetClusters(ctx.Tenant) + clusters := security.DefaultTenantManager().GetClusters(ctx.C.Tenant()) if len(clusters) != 1 { // reject if no schema specified return nil, 0, mysqlErrors.NewSQLError(mConstants.ERNoDb, mConstants.SSNoDatabaseSelected, "No database selected") } schemaless = true - ctx.Schema = security.DefaultTenantManager().GetClusters(ctx.Tenant)[0] + ctx.C.SetSchema(security.DefaultTenantManager().GetClusters(ctx.C.Tenant())[0]) } ctx.Stmt = &proto.Stmt{ @@ -174,7 +191,7 @@ func (executor *RedirectExecutor) doExecutorComQuery(ctx *proto.Context, act ast StmtNode: act, } - rt, err := runtime.Load(ctx.Schema) + rt, err := runtime.Load(ctx.C.Schema()) if err != nil { return nil, 0, err } @@ -285,7 +302,7 @@ func (executor *RedirectExecutor) ExecutorComQuery(ctx *proto.Context, h func(re query := ctx.GetQuery() log.Debugf("ComQuery: %s", query) - charset, collation := getCharsetCollation(ctx.CharacterSet) + charset, collation := getCharsetCollation(ctx.C.CharacterSet()) switch strings.IndexByte(query, ';') { case -1: // no ';' exists @@ -352,7 +369,7 @@ func (executor *RedirectExecutor) ExecutorComStmtExecute(ctx *proto.Context) (pr executable = tx } else { var rt runtime.Runtime - if rt, err = runtime.Load(ctx.Schema); err != nil { + if rt, err = runtime.Load(ctx.C.Schema()); err != nil { return nil, 0, err } executable = rt @@ -380,7 +397,7 @@ func (executor *RedirectExecutor) ConnectionClose(ctx *proto.Context) { } //resourcePool := resource.GetDataSourceManager().GetMasterResourcePool(executor.dataSources[0].Master.Name) - //r, ok := executor.localTransactionMap[ctx.ConnectionID] + //r, ok := executor.localTransactionMap[ctx.connectionID] //if ok { // defer func() { // resourcePool.Put(r) @@ -394,11 +411,11 @@ func (executor *RedirectExecutor) ConnectionClose(ctx *proto.Context) { } func (executor *RedirectExecutor) putTx(ctx *proto.Context, tx proto.Tx) { - executor.localTransactionMap.Store(ctx.ConnectionID, tx) + executor.localTransactionMap.Store(ctx.C.ID(), tx) } func (executor *RedirectExecutor) removeTx(ctx *proto.Context) (proto.Tx, bool) { - exist, ok := executor.localTransactionMap.LoadAndDelete(ctx.ConnectionID) + exist, ok := executor.localTransactionMap.LoadAndDelete(ctx.C.ID()) if !ok { return nil, false } @@ -406,7 +423,7 @@ func (executor *RedirectExecutor) removeTx(ctx *proto.Context) (proto.Tx, bool) } func (executor *RedirectExecutor) getTx(ctx *proto.Context) (proto.Tx, bool) { - exist, ok := executor.localTransactionMap.Load(ctx.ConnectionID) + exist, ok := executor.localTransactionMap.Load(ctx.C.ID()) if !ok { return nil, false } diff --git a/pkg/executor/redirect_test.go b/pkg/executor/redirect_test.go index 440a1211..b9563a23 100644 --- a/pkg/executor/redirect_test.go +++ b/pkg/executor/redirect_test.go @@ -22,6 +22,8 @@ import ( ) import ( + "github.com/golang/mock/gomock" + "github.com/pkg/errors" "github.com/stretchr/testify/assert" @@ -29,6 +31,7 @@ import ( import ( "github.com/arana-db/arana/pkg/proto" + "github.com/arana-db/arana/testdata" ) func TestIsErrMissingTx(t *testing.T) { @@ -42,21 +45,33 @@ func TestProcessDistributedTransaction(t *testing.T) { } func TestInGlobalTransaction(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + c := testdata.NewMockFrontConn(ctrl) + c.EXPECT().ID().Return(uint32(0)).AnyTimes() + redirect := NewRedirectExecutor() - assert.False(t, redirect.InGlobalTransaction(createContext())) + assert.False(t, redirect.InGlobalTransaction(createContext(c))) } func TestInLocalTransaction(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + c := testdata.NewMockFrontConn(ctrl) + c.EXPECT().ID().Return(uint32(0)).Times(1) + redirect := NewRedirectExecutor() - result := redirect.InLocalTransaction(createContext()) + result := redirect.InLocalTransaction(createContext(c)) assert.False(t, result) } -func createContext() *proto.Context { +func createContext(c proto.FrontConn) *proto.Context { result := &proto.Context{ - ConnectionID: 0, - Data: make([]byte, 0), - Stmt: nil, + C: c, + Data: make([]byte, 0), + Stmt: nil, } return result } diff --git a/pkg/mysql/client.go b/pkg/mysql/client.go index 0d0a15f6..7372c593 100644 --- a/pkg/mysql/client.go +++ b/pkg/mysql/client.go @@ -501,7 +501,7 @@ func (c *Connector) NewBackendConnection(ctx context.Context) (*BackendConnectio // \ / // >>> SYNC <<< func (conn *BackendConnection) SyncVariables(vars map[string]proto.Value) error { - transient := conn.c.TransientVariables + transient := conn.c.TransientVariables() if len(vars) < 1 && len(transient) < 1 { return nil @@ -799,7 +799,7 @@ func (conn *BackendConnection) parseInitialHandshakePacket(data []byte) (uint32, } // Read the connection id. - conn.c.ConnectionID, pos, ok = readUint32(data, pos) + conn.c.connectionID, pos, ok = readUint32(data, pos) if !ok { return 0, nil, "", err2.NewSQLError(mysql.CRMalformedPacket, mysql.SSUnknownSQLState, "parseInitialHandshakePacket: packet has no connection id") } diff --git a/pkg/mysql/conn.go b/pkg/mysql/conn.go index ce65b48a..4298d9a4 100644 --- a/pkg/mysql/conn.go +++ b/pkg/mysql/conn.go @@ -91,21 +91,21 @@ type Conn struct { // If there are any ongoing reads or writes, they may get interrupted. conn net.Conn - // Schema is the current database name. - Schema string + // schema is the current database name. + schema string - // Tenant is the current tenant login. - Tenant string + // tenant is the current tenant login. + tenant string - // ConnectionID is set: + // connectionID is set: // - at Connect() time for clients, with the value returned by // the server. // - at accept time for the server. - ConnectionID uint32 + connectionID uint32 - // TransientVariables represents local transient variables. + // transientVariables represents local transient variables. // These variables will always keep sync with backend mysql conns. - TransientVariables map[string]proto.Value + transientVariables map[string]proto.Value // closed is set to true when Close() is called on the connection. closed *atomic.Bool @@ -149,7 +149,9 @@ type Conn struct { // connection. // It is set during the initial handshake. // See the values in constants.go. - CharacterSet uint8 + characterSet uint8 + + serverVersion string } // newConn is an internal method to create a Conn. Used by client and server @@ -159,10 +161,42 @@ func newConn(conn net.Conn) *Conn { conn: conn, closed: atomic.NewBool(false), bufferedReader: bufio.NewReaderSize(conn, connBufferSize), - TransientVariables: make(map[string]proto.Value), + transientVariables: make(map[string]proto.Value), } } +func (c *Conn) ServerVersion() string { + return c.serverVersion +} + +func (c *Conn) CharacterSet() uint8 { + return c.characterSet +} + +func (c *Conn) Schema() string { + return c.schema +} + +func (c *Conn) SetSchema(schema string) { + c.schema = schema +} + +func (c *Conn) Tenant() string { + return c.tenant +} + +func (c *Conn) SetTenant(t string) { + c.tenant = t +} + +func (c *Conn) TransientVariables() map[string]proto.Value { + return c.transientVariables +} + +func (c *Conn) SetTransientVariables(v map[string]proto.Value) { + c.transientVariables = v +} + // startWriterBuffering starts using buffered writes. This should // be terminated by a call to endWriteBuffering. func (c *Conn) startWriterBuffering() { @@ -631,13 +665,13 @@ func (c *Conn) RemoteAddr() net.Addr { } // ID returns the MySQL connection ID for this connection. -func (c *Conn) ID() int64 { - return int64(c.ConnectionID) +func (c *Conn) ID() uint32 { + return c.connectionID } // Ident returns a useful identification string for error logging func (c *Conn) String() string { - return fmt.Sprintf("client %v (%s)", c.ConnectionID, c.RemoteAddr().String()) + return fmt.Sprintf("client %v (%s)", c.ID(), c.RemoteAddr().String()) } // Close closes the connection. It can be called from a different go @@ -708,7 +742,7 @@ func (c *Conn) fixErrNoSuchTable(errorMessage string) string { var sb strings.Builder sb.Grow(len(errorMessage)) sb.WriteString("Table '") - sb.WriteString(c.Schema) + sb.WriteString(c.Schema()) sb.WriteByte('.') sb.WriteString(matches[2]) sb.WriteString("' doesn't exist") diff --git a/pkg/mysql/context.go b/pkg/mysql/context.go deleted file mode 100644 index 05f865f2..00000000 --- a/pkg/mysql/context.go +++ /dev/null @@ -1,33 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package mysql - -import ( - "context" -) - -// Context -type Context struct { - context.Context - - Conn *Conn - - CommandType byte - // sql Data - Data []byte -} diff --git a/pkg/mysql/execute_handle.go b/pkg/mysql/execute_handle.go index ed1f98b7..083b9108 100644 --- a/pkg/mysql/execute_handle.go +++ b/pkg/mysql/execute_handle.go @@ -23,6 +23,8 @@ import ( import ( "github.com/arana-db/parser" + + "golang.org/x/exp/slices" ) import ( @@ -39,28 +41,19 @@ func (l *Listener) handleInitDB(c *Conn, ctx *proto.Context) error { db := string(ctx.Data[1:]) c.recycleReadPacket() - var allow bool - for _, it := range security.DefaultTenantManager().GetClusters(c.Tenant) { - if db == it { - allow = true - break - } - } - - if !allow { - if err := c.writeErrorPacketFromError(errors.NewSQLError(mysql.ERBadDb, "", "Unknown database '%s'", db)); err != nil { + if clusters := security.DefaultTenantManager().GetClusters(c.Tenant()); !slices.Contains(clusters, db) { + if err := c.writeErrorPacketFromError(errors.NewSQLError(mysql.ERBadDb, mysql.SS42000, "Unknown database '%s'", db)); err != nil { log.Errorf("failed to write ComInitDB error to %s: %v", c, err) return err } return nil } - c.Schema = db - err := l.executor.ExecuteUseDB(ctx) - if err != nil { + if err := l.executor.ExecuteUseDB(ctx, db); err != nil { return err } - if err = c.writeOKPacket(0, 0, c.StatusFlags, 0); err != nil { + + if err := c.writeOKPacket(0, 0, c.StatusFlags, 0); err != nil { log.Errorf("Error writing ComInitDB result to %s: %v", c, err) return err } @@ -75,23 +68,23 @@ func (l *Listener) handleQuery(c *Conn, ctx *proto.Context) error { c.startWriterBuffering() defer func() { if err := c.endWriterBuffering(); err != nil { - log.Errorf("conn %v: flush() failed: %v", ctx.ConnectionID, err) + log.Errorf("conn %v: flush() failed: %v", ctx.C.ID(), err) } }() if failure != nil { - log.Errorf("executor com_query error %v: %+v", ctx.ConnectionID, failure) + log.Errorf("executor com_query error %v: %+v", ctx.C.ID(), failure) if err := c.writeErrorPacketFromError(failure); err != nil { - log.Errorf("Error writing query error to client %v: %v", ctx.ConnectionID, err) + log.Errorf("Error writing query error to client %v: %v", ctx.C.ID(), err) return err } return nil } if result == nil { - log.Errorf("executor com_query error %v: %+v", ctx.ConnectionID, "un dataset") + log.Errorf("executor com_query error %v: %+v", ctx.C.ID(), "un dataset") if err := c.writeErrorPacketFromError(errors.NewSQLError(mysql.ERBadNullError, mysql.SSUnknownSQLState, "un dataset")); err != nil { - log.Errorf("Error writing query error to client %v: %v", ctx.ConnectionID, failure) + log.Errorf("Error writing query error to client %v: %v", ctx.C.ID(), failure) return err } return nil @@ -99,9 +92,9 @@ func (l *Listener) handleQuery(c *Conn, ctx *proto.Context) error { var ds proto.Dataset if ds, failure = result.Dataset(); failure != nil { - log.Errorf("get dataset error %v: %v", ctx.ConnectionID, failure) + log.Errorf("get dataset error %v: %v", ctx.C.ID(), failure) if err := c.writeErrorPacketFromError(failure); err != nil { - log.Errorf("Error writing query error to client %v: %v", ctx.ConnectionID, err) + log.Errorf("Error writing query error to client %v: %v", ctx.C.ID(), err) return err } return nil @@ -125,7 +118,7 @@ func (l *Listener) handleQuery(c *Conn, ctx *proto.Context) error { } if err := c.writeOKPacket(affected, insertId, statusFlag, warn); err != nil { - log.Errorf("failed to write OK packet into client %v: %v", ctx.ConnectionID, err) + log.Errorf("failed to write OK packet into client %v: %v", ctx.C.ID(), err) return err } return nil @@ -134,11 +127,11 @@ func (l *Listener) handleQuery(c *Conn, ctx *proto.Context) error { fields, _ := ds.Fields() if err := c.writeFields(fields); err != nil { - log.Errorf("write fields error %v: %v", ctx.ConnectionID, err) + log.Errorf("write fields error %v: %v", ctx.C.ID(), err) return err } if err := c.writeDataset(ds); err != nil { - log.Errorf("write dataset error %v: %v", ctx.ConnectionID, err) + log.Errorf("write dataset error %v: %v", ctx.C.ID(), err) return err } if err := c.writeEndResult(hasMore, 0, 0, warn); err != nil { @@ -212,7 +205,7 @@ func (l *Listener) handleStmtExecute(c *Conn, ctx *proto.Context) error { c.startWriterBuffering() defer func() { if err := c.endWriterBuffering(); err != nil { - log.Errorf("conn %v: flush() failed: %v", ctx.ConnectionID, err) + log.Errorf("conn %v: flush() failed: %v", ctx.C.ID(), err) } }() @@ -237,7 +230,7 @@ func (l *Listener) handleStmtExecute(c *Conn, ctx *proto.Context) error { if err != nil { if wErr := c.writeErrorPacketFromError(err); wErr != nil { // If we can't even write the error, we're done. - log.Error("Error writing query error to client %v: %v", ctx.ConnectionID, wErr) + log.Error("Error writing query error to client %v: %v", ctx.C.ID(), wErr) return wErr } return nil @@ -253,7 +246,7 @@ func (l *Listener) handleStmtExecute(c *Conn, ctx *proto.Context) error { if result, warn, err = l.executor.ExecutorComStmtExecute(ctx); err != nil { if wErr := c.writeErrorPacketFromError(err); wErr != nil { - log.Errorf("Error writing query error to client %v: %v, executor error: %v", ctx.ConnectionID, wErr, err) + log.Errorf("Error writing query error to client %v: %v, executor error: %v", ctx.C.ID(), wErr, err) return wErr } return nil @@ -262,7 +255,7 @@ func (l *Listener) handleStmtExecute(c *Conn, ctx *proto.Context) error { var ds proto.Dataset if ds, err = result.Dataset(); err != nil { if wErr := c.writeErrorPacketFromError(err); wErr != nil { - log.Errorf("Error writing query error to client %v: %v, executor error: %v", ctx.ConnectionID, wErr, err) + log.Errorf("Error writing query error to client %v: %v, executor error: %v", ctx.C.ID(), wErr, err) return wErr } return nil @@ -370,7 +363,7 @@ func (l *Listener) handleSetOption(c *Conn, ctx *proto.Context) error { case 1: c.Capabilities &^= mysql.CapabilityClientMultiStatements default: - log.Errorf("Got unhandled packet (ComSetOption default) from client %v, returning error: %v", ctx.ConnectionID, ctx.Data) + log.Errorf("Got unhandled packet (ComSetOption default) from client %v, returning error: %v", ctx.C.ID(), ctx.Data) if err := c.writeErrorPacket(mysql.ERUnknownComError, mysql.SSUnknownComError, "error handling packet: %v", ctx.Data); err != nil { log.Errorf("Error writing error packet to client: %v", err) return err @@ -381,7 +374,7 @@ func (l *Listener) handleSetOption(c *Conn, ctx *proto.Context) error { return err } } - log.Errorf("Got unhandled packet (ComSetOption else) from client %v, returning error: %v", ctx.ConnectionID, ctx.Data) + log.Errorf("Got unhandled packet (ComSetOption else) from client %v, returning error: %v", ctx.C.ID(), ctx.Data) if err := c.writeErrorPacket(mysql.ERUnknownComError, mysql.SSUnknownComError, "error handling packet: %v", ctx.Data); err != nil { log.Errorf("Error writing error packet to client: %v", err) return err diff --git a/pkg/mysql/server.go b/pkg/mysql/server.go index ef9d501f..35c5eb2a 100644 --- a/pkg/mysql/server.go +++ b/pkg/mysql/server.go @@ -25,6 +25,7 @@ import ( "io" "math" "net" + "runtime/debug" "strconv" "strings" "sync" @@ -155,17 +156,17 @@ func (l *Listener) Close() { func (l *Listener) handle(conn net.Conn, connectionID uint32) { c := newConn(conn) - c.ConnectionID = connectionID + c.connectionID = connectionID // Catch panics, and close the connection in any case. defer func() { if x := recover(); x != nil { - log.Errorf("mysql_server caught panic:\n%v", x) + log.Errorf("mysql_server caught panic:\n%v\n%v", x, string(debug.Stack())) } conn.Close() l.executor.ConnectionClose(&proto.Context{ - Context: context.Background(), - ConnectionID: c.ConnectionID, + Context: context.Background(), + C: c, }) }() @@ -179,7 +180,8 @@ func (l *Listener) handle(conn net.Conn, connectionID uint32) { } c.Capabilities = l.capabilities - c.CharacterSet = l.characterSet + c.characterSet = l.characterSet + c.serverVersion = l.conf.ServerVersion // Negotiation worked, send OK packet. if err = c.writeOKPacket(0, 0, c.StatusFlags, 0); err != nil { @@ -201,19 +203,14 @@ func (l *Listener) handle(conn net.Conn, connectionID uint32) { content := make([]byte, len(data)) copy(content, data) ctx := &proto.Context{ - Context: context.Background(), - Schema: c.Schema, - Tenant: c.Tenant, - ServerVersion: l.conf.ServerVersion, - ConnectionID: c.ConnectionID, - Data: content, - TransientVariables: c.TransientVariables, - CharacterSet: c.CharacterSet, + Context: context.Background(), + C: c, + Data: content, } if err = l.ExecuteCommand(c, ctx); err != nil { if err == io.EOF { - log.Debugf("the connection#%d of remote client %s requests quit", c.ConnectionID, c.conn.(*net.TCPConn).RemoteAddr()) + log.Debugf("the connection#%d of remote client %s requests quit", c.ID(), c.conn.(*net.TCPConn).RemoteAddr()) } else { log.Errorf("failed to execute command: %v", err) } @@ -253,7 +250,7 @@ func (l *Listener) handshake(c *Conn) error { log.Errorf("Cannot parse client handshake response from %s: %v", c, err) return err } - handshake.connectionID = c.ConnectionID + handshake.connectionID = c.ID() handshake.salt = salt if err = l.ValidateHash(handshake); err != nil { @@ -261,8 +258,8 @@ func (l *Listener) handshake(c *Conn) error { return err } - c.Schema = handshake.schema - c.Tenant = handshake.tenant + c.SetSchema(handshake.schema) + c.SetTenant(handshake.tenant) return nil } @@ -311,7 +308,7 @@ func (l *Listener) writeHandshakeV10(c *Conn, enableTLS bool, salt []byte) error pos = writeNullString(data, pos, l.conf.ServerVersion) // Add connectionID in. - pos = writeUint32(data, pos, c.ConnectionID) + pos = writeUint32(data, pos, c.ID()) pos += copy(data[pos:], salt[:8]) diff --git a/pkg/proto/interface.go b/pkg/proto/interface.go index ed693dbe..fd3f5678 100644 --- a/pkg/proto/interface.go +++ b/pkg/proto/interface.go @@ -15,6 +15,7 @@ * limitations under the License. */ +//go:generate mockgen -destination=../../testdata/mock_interface.go -package=testdata . FrontConn package proto import ( @@ -35,27 +36,46 @@ type ( ) type ( + // FrontConn represents a frontend connection. + // APP ---> FRONTEND_CONN ---> ARANA ---> BACKEND_CONN ---> MySQL + FrontConn interface { + // ID returns connection id. + ID() uint32 + + // Schema returns the current schema. + Schema() string + + // SetSchema sets the current schema. + SetSchema(schema string) + + // Tenant returns the tenant. + Tenant() string + + // SetTenant sets the tenant. + SetTenant(tenant string) + + // TransientVariables returns the transient variables. + TransientVariables() map[string]Value + + // SetTransientVariables sets the transient variables. + SetTransientVariables(v map[string]Value) + + // CharacterSet returns the character set. + CharacterSet() uint8 + + // ServerVersion returns the server version. + ServerVersion() string + } // Context is used to carry context objects Context struct { context.Context - Tenant string - Schema string - ServerVersion string - - ConnectionID uint32 + C FrontConn // sql Data Data []byte Stmt *Stmt - - CharacterSet uint8 - - // TransientVariables stores the transient local variables, it will sync with the remote node automatically. - // - SYSTEM: @@xxx - // - USER: @xxx - TransientVariables map[string]Value } Listener interface { @@ -68,7 +88,7 @@ type ( ProcessDistributedTransaction() bool InLocalTransaction(ctx *Context) bool InGlobalTransaction(ctx *Context) bool - ExecuteUseDB(ctx *Context) error + ExecuteUseDB(ctx *Context, schema string) error ExecuteFieldList(ctx *Context) ([]Field, error) ExecutorComQuery(ctx *Context, callback func(Result, uint16, error) error) error ExecutorComStmtExecute(ctx *Context) (Result, uint16, error) @@ -111,15 +131,15 @@ func (c Context) GetArgs() []Value { func (c Context) Value(key interface{}) interface{} { switch key.(type) { case ContextKeyTenant: - return c.Tenant + return c.C.Tenant() case ContextKeySchema: - return c.Schema + return c.C.Schema() case ContextKeyTransientVariables: - return c.TransientVariables + return c.C.TransientVariables() case ContextKeySQL: return c.GetQuery() case ContextKeyServerVersion: - return c.ServerVersion + return c.C.ServerVersion() } return c.Context.Value(key) } diff --git a/pkg/runtime/function/acos_test.go b/pkg/runtime/function/acos_test.go index 9079e10e..b13ceaae 100644 --- a/pkg/runtime/function/acos_test.go +++ b/pkg/runtime/function/acos_test.go @@ -25,11 +25,11 @@ import ( ) import ( - "github.com/arana-db/arana/pkg/proto" + "github.com/stretchr/testify/assert" ) import ( - "github.com/stretchr/testify/assert" + "github.com/arana-db/arana/pkg/proto" ) func TestAcos(t *testing.T) { diff --git a/pkg/runtime/function/asin_test.go b/pkg/runtime/function/asin_test.go index bca54ee1..9f7d9677 100644 --- a/pkg/runtime/function/asin_test.go +++ b/pkg/runtime/function/asin_test.go @@ -25,11 +25,11 @@ import ( ) import ( - "github.com/arana-db/arana/pkg/proto" + "github.com/stretchr/testify/assert" ) import ( - "github.com/stretchr/testify/assert" + "github.com/arana-db/arana/pkg/proto" ) func TestAsin(t *testing.T) { diff --git a/pkg/runtime/function/cast_date.go b/pkg/runtime/function/cast_date.go index 9ac960b5..a0f7cc47 100644 --- a/pkg/runtime/function/cast_date.go +++ b/pkg/runtime/function/cast_date.go @@ -36,10 +36,12 @@ import ( // FuncCastDate is https://dev.mysql.com/doc/refman/5.6/en/cast-functions.html#function_cast const FuncCastDate = "CAST_DATE" -var DateSep = `[~!@#$%^&*_+=:;,.|/?\(\)\[\]\{\}\-\\]+` -var _dateReplace = regexp.MustCompile(DateSep) -var _dateMatchString = regexp.MustCompile(fmt.Sprintf(`^\d{1,4}%s\d{1,2}%s\d{1,2}$`, DateSep, DateSep)) -var _dateMatchInt = regexp.MustCompile(`^\d{5,8}$`) +var ( + DateSep = `[~!@#$%^&*_+=:;,.|/?\(\)\[\]\{\}\-\\]+` + _dateReplace = regexp.MustCompile(DateSep) + _dateMatchString = regexp.MustCompile(fmt.Sprintf(`^\d{1,4}%s\d{1,2}%s\d{1,2}$`, DateSep, DateSep)) + _dateMatchInt = regexp.MustCompile(`^\d{5,8}$`) +) var _ proto.Func = (*castDateFunc)(nil) diff --git a/pkg/runtime/function/cast_datetime.go b/pkg/runtime/function/cast_datetime.go index a0ac6c8a..99ea86a5 100644 --- a/pkg/runtime/function/cast_datetime.go +++ b/pkg/runtime/function/cast_datetime.go @@ -37,11 +37,13 @@ import ( // FuncCastDatetime is https://dev.mysql.com/doc/refman/5.6/en/cast-functions.html#function_cast const FuncCastDatetime = "CAST_DATETIME" -var DatetimeSep = `[~!@#$%^&*_+=:;,|/?\(\)\[\]\{\}\-\\]+` -var _datetimeReplace = regexp.MustCompile(DatetimeSep) -var _datetimeMatchUpperString = regexp.MustCompile(fmt.Sprintf(`^\d{1,4}%s\d{1,2}%s\d{1,2}$`, DatetimeSep, DatetimeSep)) -var _datetimeMatchLowerString = regexp.MustCompile(fmt.Sprintf(`^\d{1,2}%s\d{1,2}%s\d{1,2}$`, DatetimeSep, DatetimeSep)) -var _datetimeMatchInt = regexp.MustCompile(`^\d{11,14}$`) +var ( + DatetimeSep = `[~!@#$%^&*_+=:;,|/?\(\)\[\]\{\}\-\\]+` + _datetimeReplace = regexp.MustCompile(DatetimeSep) + _datetimeMatchUpperString = regexp.MustCompile(fmt.Sprintf(`^\d{1,4}%s\d{1,2}%s\d{1,2}$`, DatetimeSep, DatetimeSep)) + _datetimeMatchLowerString = regexp.MustCompile(fmt.Sprintf(`^\d{1,2}%s\d{1,2}%s\d{1,2}$`, DatetimeSep, DatetimeSep)) + _datetimeMatchInt = regexp.MustCompile(`^\d{11,14}$`) +) var _ proto.Func = (*castDatetimeFunc)(nil) @@ -51,8 +53,10 @@ func init() { type castDatetimeFunc struct{} -var castDate castDateFunc -var castTime castTimeFunc +var ( + castDate castDateFunc + castTime castTimeFunc +) func (a castDatetimeFunc) Apply(ctx context.Context, inputs ...proto.Valuer) (proto.Value, error) { // expr diff --git a/pkg/runtime/function/cast_time.go b/pkg/runtime/function/cast_time.go index 6553bf28..65818857 100644 --- a/pkg/runtime/function/cast_time.go +++ b/pkg/runtime/function/cast_time.go @@ -36,9 +36,11 @@ import ( // FuncCastTime is https://dev.mysql.com/doc/refman/5.6/en/cast-functions.html#function_cast const FuncCastTime = "CAST_TIME" -var _timeMatchDay = regexp.MustCompile(`^\d{1,2} \d{1,3}(:\d{1,2}){0,2}$`) -var _timeMatchString = regexp.MustCompile(`^\d{1,3}(:\d{1,2}){1,2}$`) -var _timeMatchInt = regexp.MustCompile(`^\d{1,7}$`) +var ( + _timeMatchDay = regexp.MustCompile(`^\d{1,2} \d{1,3}(:\d{1,2}){0,2}$`) + _timeMatchString = regexp.MustCompile(`^\d{1,3}(:\d{1,2}){1,2}$`) + _timeMatchInt = regexp.MustCompile(`^\d{1,7}$`) +) var _ proto.Func = (*castTimeFunc)(nil) diff --git a/pkg/runtime/misc/extvalue/visitor.go b/pkg/runtime/misc/extvalue/visitor.go index 4242a937..2750bd18 100644 --- a/pkg/runtime/misc/extvalue/visitor.go +++ b/pkg/runtime/misc/extvalue/visitor.go @@ -279,7 +279,7 @@ func (vv *valueVisitor) VisitFunction(node *ast.Function) (interface{}, error) { return mysqlErrors.NewSQLError( mConstants.ERSPDoseNotExist, - mConstants.SSSPNotExist, + mConstants.SS42000, sb.String(), ) } diff --git a/pkg/runtime/optimize/ddl/check_table.go b/pkg/runtime/optimize/ddl/check_table.go index b07db450..f5935c68 100644 --- a/pkg/runtime/optimize/ddl/check_table.go +++ b/pkg/runtime/optimize/ddl/check_table.go @@ -19,7 +19,6 @@ package ddl import ( "context" - "github.com/arana-db/arana/pkg/runtime/plan/ddl" ) import ( @@ -27,6 +26,7 @@ import ( "github.com/arana-db/arana/pkg/proto/rule" "github.com/arana-db/arana/pkg/runtime/ast" "github.com/arana-db/arana/pkg/runtime/optimize" + "github.com/arana-db/arana/pkg/runtime/plan/ddl" ) func init() { diff --git a/pkg/runtime/plan/dml/mapping.go b/pkg/runtime/plan/dml/mapping.go index fa52de81..bc63fc09 100644 --- a/pkg/runtime/plan/dml/mapping.go +++ b/pkg/runtime/plan/dml/mapping.go @@ -358,7 +358,7 @@ func (vt *virtualValueVisitor) VisitFunction(node *ast.Function) (interface{}, e return nil, mysqlErrors.NewSQLError( mConstants.ERSPDoseNotExist, - mConstants.SSSPNotExist, + mConstants.SS42000, sb.String(), ) } diff --git a/testdata/mock_interface.go b/testdata/mock_interface.go new file mode 100644 index 00000000..e56f8f9a --- /dev/null +++ b/testdata/mock_interface.go @@ -0,0 +1,177 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/arana-db/arana/pkg/proto (interfaces: FrontConn) + +// Package testdata is a generated GoMock package. +package testdata + +import ( + reflect "reflect" +) + +import ( + gomock "github.com/golang/mock/gomock" +) + +import ( + proto "github.com/arana-db/arana/pkg/proto" +) + +// MockFrontConn is a mock of FrontConn interface. +type MockFrontConn struct { + ctrl *gomock.Controller + recorder *MockFrontConnMockRecorder +} + +// MockFrontConnMockRecorder is the mock recorder for MockFrontConn. +type MockFrontConnMockRecorder struct { + mock *MockFrontConn +} + +// NewMockFrontConn creates a new mock instance. +func NewMockFrontConn(ctrl *gomock.Controller) *MockFrontConn { + mock := &MockFrontConn{ctrl: ctrl} + mock.recorder = &MockFrontConnMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockFrontConn) EXPECT() *MockFrontConnMockRecorder { + return m.recorder +} + +// CharacterSet mocks base method. +func (m *MockFrontConn) CharacterSet() byte { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CharacterSet") + ret0, _ := ret[0].(byte) + return ret0 +} + +// CharacterSet indicates an expected call of CharacterSet. +func (mr *MockFrontConnMockRecorder) CharacterSet() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CharacterSet", reflect.TypeOf((*MockFrontConn)(nil).CharacterSet)) +} + +// ID mocks base method. +func (m *MockFrontConn) ID() uint32 { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ID") + ret0, _ := ret[0].(uint32) + return ret0 +} + +// ID indicates an expected call of ID. +func (mr *MockFrontConnMockRecorder) ID() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ID", reflect.TypeOf((*MockFrontConn)(nil).ID)) +} + +// Schema mocks base method. +func (m *MockFrontConn) Schema() string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Schema") + ret0, _ := ret[0].(string) + return ret0 +} + +// Schema indicates an expected call of Schema. +func (mr *MockFrontConnMockRecorder) Schema() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Schema", reflect.TypeOf((*MockFrontConn)(nil).Schema)) +} + +// ServerVersion mocks base method. +func (m *MockFrontConn) ServerVersion() string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ServerVersion") + ret0, _ := ret[0].(string) + return ret0 +} + +// ServerVersion indicates an expected call of ServerVersion. +func (mr *MockFrontConnMockRecorder) ServerVersion() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ServerVersion", reflect.TypeOf((*MockFrontConn)(nil).ServerVersion)) +} + +// SetSchema mocks base method. +func (m *MockFrontConn) SetSchema(arg0 string) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SetSchema", arg0) +} + +// SetSchema indicates an expected call of SetSchema. +func (mr *MockFrontConnMockRecorder) SetSchema(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetSchema", reflect.TypeOf((*MockFrontConn)(nil).SetSchema), arg0) +} + +// SetTenant mocks base method. +func (m *MockFrontConn) SetTenant(arg0 string) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SetTenant", arg0) +} + +// SetTenant indicates an expected call of SetTenant. +func (mr *MockFrontConnMockRecorder) SetTenant(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetTenant", reflect.TypeOf((*MockFrontConn)(nil).SetTenant), arg0) +} + +// SetTransientVariables mocks base method. +func (m *MockFrontConn) SetTransientVariables(arg0 map[string]proto.Value) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SetTransientVariables", arg0) +} + +// SetTransientVariables indicates an expected call of SetTransientVariables. +func (mr *MockFrontConnMockRecorder) SetTransientVariables(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetTransientVariables", reflect.TypeOf((*MockFrontConn)(nil).SetTransientVariables), arg0) +} + +// Tenant mocks base method. +func (m *MockFrontConn) Tenant() string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Tenant") + ret0, _ := ret[0].(string) + return ret0 +} + +// Tenant indicates an expected call of Tenant. +func (mr *MockFrontConnMockRecorder) Tenant() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Tenant", reflect.TypeOf((*MockFrontConn)(nil).Tenant)) +} + +// TransientVariables mocks base method. +func (m *MockFrontConn) TransientVariables() map[string]proto.Value { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "TransientVariables") + ret0, _ := ret[0].(map[string]proto.Value) + return ret0 +} + +// TransientVariables indicates an expected call of TransientVariables. +func (mr *MockFrontConnMockRecorder) TransientVariables() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TransientVariables", reflect.TypeOf((*MockFrontConn)(nil).TransientVariables)) +}