Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

protocol: support query attribute since mysql 8.0.23 #55175

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions pkg/expression/builtin.go
Original file line number Diff line number Diff line change
Expand Up @@ -983,9 +983,10 @@ var funcs = map[string]functionClass{
ast.TiDBEncodeSQLDigest: &tidbEncodeSQLDigestFunctionClass{baseFunctionClass{ast.TiDBEncodeSQLDigest, 1, 1}},

// TiDB Sequence function.
ast.NextVal: &nextValFunctionClass{baseFunctionClass{ast.NextVal, 1, 1}},
ast.LastVal: &lastValFunctionClass{baseFunctionClass{ast.LastVal, 1, 1}},
ast.SetVal: &setValFunctionClass{baseFunctionClass{ast.SetVal, 2, 2}},
ast.NextVal: &nextValFunctionClass{baseFunctionClass{ast.NextVal, 1, 1}},
ast.LastVal: &lastValFunctionClass{baseFunctionClass{ast.LastVal, 1, 1}},
ast.SetVal: &setValFunctionClass{baseFunctionClass{ast.SetVal, 2, 2}},
ast.QueryAttrString: &getQueryAttrFunctionClass{baseFunctionClass{ast.QueryAttrString, 1, 1}},
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
ast.QueryAttrString: &getQueryAttrFunctionClass{baseFunctionClass{ast.QueryAttrString, 1, 1}},
// Query attributes and ...
ast.QueryAttrString: &getQueryAttrFunctionClass{baseFunctionClass{ast.QueryAttrString, 1, 1}},

I don't think this belongs under "TiDB Sequence function." Maybe think of a good name or leave it out if it is obvious.

}

// IsFunctionSupported check if given function name is a builtin sql function.
Expand Down
59 changes: 59 additions & 0 deletions pkg/expression/builtin_other.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (

"github.com/pingcap/errors"
"github.com/pingcap/tidb/pkg/expression/expropt"
"github.com/pingcap/tidb/pkg/param"
"github.com/pingcap/tidb/pkg/parser/ast"
"github.com/pingcap/tidb/pkg/parser/model"
"github.com/pingcap/tidb/pkg/parser/mysql"
Expand All @@ -46,6 +47,7 @@ var (
_ functionClass = &valuesFunctionClass{}
_ functionClass = &bitCountFunctionClass{}
_ functionClass = &getParamFunctionClass{}
_ functionClass = &getQueryAttrFunctionClass{}
)

var (
Expand Down Expand Up @@ -1676,3 +1678,60 @@ func (b *builtinGetParamStringSig) evalString(ctx EvalContext, row chunk.Row) (s
}
return str, false, nil
}

// getQueryAttrFunctionClass for plan cache of prepared statements
type getQueryAttrFunctionClass struct {
baseFunctionClass
}

func (c *getQueryAttrFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) {
if err := c.verifyArgs(args); err != nil {
return nil, err
}
bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETString, types.ETString)
if err != nil {
return nil, err
}
bf.tp.SetFlen(mysql.MaxFieldVarCharLength)
sig := &builtinGetQueryAttrStringSig{baseBuiltinFunc: bf}
return sig, nil
}

type builtinGetQueryAttrStringSig struct {
baseBuiltinFunc
expropt.SessionVarsPropReader
}

func (b *builtinGetQueryAttrStringSig) Clone() builtinFunc {
newSig := &builtinGetQueryAttrStringSig{}
newSig.cloneFrom(&b.baseBuiltinFunc)
return newSig
}

func (b *builtinGetQueryAttrStringSig) RequiredOptionalEvalProps() OptionalEvalPropKeySet {
return b.SessionVarsPropReader.RequiredOptionalEvalProps()
}

func (b *builtinGetQueryAttrStringSig) evalString(ctx EvalContext, row chunk.Row) (string, bool, error) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
func (b *builtinGetQueryAttrStringSig) evalString(ctx EvalContext, row chunk.Row) (string, bool, error) {
// This implements `mysql_query_attribute_string(str)`
func (b *builtinGetQueryAttrStringSig) evalString(ctx EvalContext, row chunk.Row) (string, bool, error) {

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Often people know the function name and then have to fine the location of the code. This might help to make it easier to find.

sessionVars, err := b.GetSessionVars(ctx)
if err != nil {
return "", true, err
}

varName, isNull, err := b.args[0].EvalString(ctx, row)
if isNull || err != nil {
return "", true, err
}
attrs := sessionVars.QueryAttributes
if attrs == nil {
return "", true, nil
}
if v, ok := attrs[varName]; ok {
paramData, err := ExecBinaryParam(sessionVars.StmtCtx.TypeCtx(), []param.BinaryParam{v})
if err != nil {
return "", true, err
}
return paramData[0].EvalString(ctx, row)
}
return "", true, nil
}
33 changes: 17 additions & 16 deletions pkg/expression/function_traits.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,22 +47,23 @@ var UnCacheableFunctions = map[string]struct{}{

// unFoldableFunctions stores functions which can not be folded duration constant folding stage.
var unFoldableFunctions = map[string]struct{}{
ast.Sysdate: {},
ast.FoundRows: {},
ast.Rand: {},
ast.UUID: {},
ast.Sleep: {},
ast.RowFunc: {},
ast.Values: {},
ast.SetVar: {},
ast.GetVar: {},
ast.GetParam: {},
ast.Benchmark: {},
ast.DayName: {},
ast.NextVal: {},
ast.LastVal: {},
ast.SetVal: {},
ast.AnyValue: {},
ast.Sysdate: {},
ast.FoundRows: {},
ast.Rand: {},
ast.UUID: {},
ast.Sleep: {},
ast.RowFunc: {},
ast.Values: {},
ast.SetVar: {},
ast.GetVar: {},
ast.GetParam: {},
ast.Benchmark: {},
ast.DayName: {},
ast.NextVal: {},
ast.LastVal: {},
ast.SetVal: {},
ast.AnyValue: {},
ast.QueryAttrString: {},
}

// DisableFoldFunctions stores functions which prevent child scope functions from being constant folded.
Expand Down
2 changes: 1 addition & 1 deletion pkg/param/binary_params.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import (
var ErrUnknownFieldType = dbterror.ClassServer.NewStd(errno.ErrUnknownFieldType)

// BinaryParam stores the information decoded from the binary protocol
// It can be further parsed into `expression.Expression` through the `ExecArgs` function in this package
// It can be further parsed into `expression.Expression` through the expression.ExecBinaryParam function in expression package
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// It can be further parsed into `expression.Expression` through the expression.ExecBinaryParam function in expression package
// It can be further parsed into `expression.Expression` through the expression.ExecBinaryParam function in the expression package

type BinaryParam struct {
Tp byte
IsUnsigned bool
Expand Down
7 changes: 4 additions & 3 deletions pkg/parser/ast/functions.go
Original file line number Diff line number Diff line change
Expand Up @@ -377,9 +377,10 @@ const (
GetMvccInfo = "get_mvcc_info"

// Sequence function.
NextVal = "nextval"
LastVal = "lastval"
SetVal = "setval"
NextVal = "nextval"
LastVal = "lastval"
SetVal = "setval"
QueryAttrString = "mysql_query_attribute_string"
)

type FuncCallExprType int8
Expand Down
4 changes: 3 additions & 1 deletion pkg/parser/mysql/const.go
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ const (
ClientDeprecateEOF // CLIENT_DEPRECATE_EOF
ClientOptionalResultsetMetadata // CLIENT_OPTIONAL_RESULTSET_METADATA, Not supported: https://dev.mysql.com/doc/c-api/8.0/en/c-api-optional-metadata.html
ClientZstdCompressionAlgorithm // CLIENT_ZSTD_COMPRESSION_ALGORITHM
// 1 << 27 == CLIENT_QUERY_ATTRIBUTES
ClientQueryAttributes // CLIENT_QUERY_ATTRIBUTES
// 1 << 28 == MULTI_FACTOR_AUTHENTICATION
// 1 << 29 == CLIENT_CAPABILITY_EXTENSION
// 1 << 30 == CLIENT_SSL_VERIFY_SERVER_CERT
Expand Down Expand Up @@ -665,6 +665,8 @@ const (
CursorTypeReadOnly = 1 << iota
CursorTypeForUpdate
CursorTypeScrollable
// ParameterCountAvailable On when the client will send the parameter count even for 0 parameters.
ParameterCountAvailable
)

// ZlibCompressDefaultLevel is the zlib compression level for the compressed protocol
Expand Down
1 change: 1 addition & 0 deletions pkg/parser/parser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1979,6 +1979,7 @@ func TestBuiltin(t *testing.T) {
{`SELECT IS_USED_LOCK(@str);`, true, "SELECT IS_USED_LOCK(@`str`)"},
{`SELECT MASTER_POS_WAIT(@log_name, @log_pos), MASTER_POS_WAIT(@log_name, @log_pos, @timeout), MASTER_POS_WAIT(@log_name, @log_pos, @timeout, @channel_name);`, true, "SELECT MASTER_POS_WAIT(@`log_name`, @`log_pos`),MASTER_POS_WAIT(@`log_name`, @`log_pos`, @`timeout`),MASTER_POS_WAIT(@`log_name`, @`log_pos`, @`timeout`, @`channel_name`)"},
{`SELECT NAME_CONST('myname', 14);`, true, "SELECT NAME_CONST(_UTF8MB4'myname', 14)"},
{`SELECT MYSQL_QUERY_ATTRIBUTE_STRING(@str);`, true, "SELECT MYSQL_QUERY_ATTRIBUTE_STRING(@`str`)"},
{`SELECT RELEASE_ALL_LOCKS();`, true, "SELECT RELEASE_ALL_LOCKS()"},
{`SELECT UUID();`, true, "SELECT UUID()"},
{`SELECT UUID_SHORT()`, true, "SELECT UUID_SHORT()"},
Expand Down
65 changes: 64 additions & 1 deletion pkg/server/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ import (
"github.com/pingcap/tidb/pkg/kv"
"github.com/pingcap/tidb/pkg/meta/model"
"github.com/pingcap/tidb/pkg/metrics"
"github.com/pingcap/tidb/pkg/param"
"github.com/pingcap/tidb/pkg/parser"
"github.com/pingcap/tidb/pkg/parser/ast"
"github.com/pingcap/tidb/pkg/parser/auth"
Expand Down Expand Up @@ -1337,6 +1338,7 @@ func (cc *clientConn) dispatch(ctx context.Context, data []byte) error {

cc.server.releaseToken(token)
cc.lastActive = time.Now()
cc.ctx.GetSessionVars().QueryAttributes = nil
}()

vars := cc.ctx.GetSessionVars()
Expand Down Expand Up @@ -1373,8 +1375,14 @@ func (cc *clientConn) dispatch(ctx context.Context, data []byte) error {
// See http://dev.mysql.com/doc/internals/en/com-query.html
if len(data) > 0 && data[len(data)-1] == 0 {
data = data[:len(data)-1]
dataStr = string(hack.String(data))
}
pos, err := cc.parseQueryAttributes(ctx, data)
if err != nil {
return err
}
// fix lastPacket for display/log
cc.lastPacket = append([]byte{cc.lastPacket[0]}, data[pos:]...)
dataStr = string(hack.String(data[pos:]))
return cc.handleQuery(ctx, dataStr)
case mysql.ComFieldList:
return cc.handleFieldList(ctx, dataStr)
Expand Down Expand Up @@ -1698,6 +1706,61 @@ func (cc *clientConn) audit(eventType plugin.GeneralEvent) {
}
}

// parseQueryAttributes support query attributes since mysql 8.0.23
// see https://dev.mysql.com/doc/refman/8.0/en/query-attributes.html
// https://archive.fosdem.org/2021/schedule/event/mysql_protocl/attachments/slides/4274/export/events/attachments/mysql_protocl/slides/4274/FOSDEM21_MySQL_Protocols_Query_Attributes.pdf
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// https://archive.fosdem.org/2021/schedule/event/mysql_protocl/attachments/slides/4274/export/events/attachments/mysql_protocl/slides/4274/FOSDEM21_MySQL_Protocols_Query_Attributes.pdf
// https://archive.fosdem.org/2021/schedule/event/mysql_protocl/attachments/slides/4274/export/events/attachments/mysql_protocl/slides/4274/FOSDEM21_MySQL_Protocols_Query_Attributes.pdf
// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_query.html
// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_stmt_execute.html

func (cc *clientConn) parseQueryAttributes(ctx context.Context, data []byte) (pos int, err error) {
if cc.capability&mysql.ClientQueryAttributes > 0 {
paraCount, _, np := util2.ParseLengthEncodedInt(data)
numParams := int(paraCount)
pos += np
_, _, np = util2.ParseLengthEncodedInt(data[pos:])
pos += np
ps := make([]param.BinaryParam, numParams)
names := make([]string, numParams)
if paraCount > 0 {
var (
nullBitmaps []byte
paramTypes []byte
)
cc.initInputEncoder(ctx)
nullBitmapLen := (numParams + 7) >> 3
nullBitmaps = data[pos : pos+nullBitmapLen]
pos += nullBitmapLen
if data[pos] != 1 {
return 0, mysql.ErrMalformPacket
}

pos++
for i := 0; i < numParams; i++ {
paramTypes = append(paramTypes, data[pos:pos+2]...)
pos += 2
s, _, p, e := util2.ParseLengthEncodedBytes(data[pos:])
if e != nil {
return 0, mysql.ErrMalformPacket
}
names[i] = string(hack.String(s))
pos += p
}

boundParams := make([][]byte, numParams)
p := 0
if p, err = parseBinaryParams(ps, boundParams, nullBitmaps, paramTypes, data[pos:], cc.inputDecoder); err != nil {
return
}

pos += p
psWithName := make(map[string]param.BinaryParam, numParams)
for i := range names {
psWithName[names[i]] = ps[i]
}
cc.ctx.GetSessionVars().QueryAttributes = psWithName
}
}

return
}

// handleQuery executes the sql query string and writes result set or result ok to the client.
// As the execution time of this function represents the performance of TiDB, we do time log and metrics here.
// Some special queries like `load data` that does not return result, which is handled in handleFileTransInConn.
Expand Down
34 changes: 29 additions & 5 deletions pkg/server/conn_stmt.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ import (
"github.com/pingcap/tidb/pkg/server/internal/dump"
"github.com/pingcap/tidb/pkg/server/internal/parse"
"github.com/pingcap/tidb/pkg/server/internal/resultset"
util2 "github.com/pingcap/tidb/pkg/server/internal/util"
"github.com/pingcap/tidb/pkg/sessionctx/variable"
"github.com/pingcap/tidb/pkg/sessiontxn"
storeerr "github.com/pingcap/tidb/pkg/store/driver/error"
Expand Down Expand Up @@ -180,6 +181,13 @@ func (cc *clientConn) handleStmtExecute(ctx context.Context, data []byte) (err e
)
cc.initInputEncoder(ctx)
numParams := stmt.NumParams()
clientHasQueryAttr := cc.capability&mysql.ClientQueryAttributes > 0
if clientHasQueryAttr && (numParams > 0 || flag&mysql.ParameterCountAvailable > 0) {
paraCount, _, np := util2.ParseLengthEncodedInt(data[pos:])
numParams = int(paraCount)
pos += np
}

args := make([]param.BinaryParam, numParams)
if numParams > 0 {
nullBitmapLen := (numParams + 7) >> 3
Expand All @@ -192,12 +200,28 @@ func (cc *clientConn) handleStmtExecute(ctx context.Context, data []byte) (err e
// new param bound flag
if data[pos] == 1 {
pos++
if len(data) < (pos + (numParams << 1)) {
return mysql.ErrMalformPacket
// For client that has query attribute ability, parameter's name will also be sent.
// However, it is useless for execute statement, so we ignore it here.
if clientHasQueryAttr {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we also set SessionVars.QueryAttributes in this branch? Or a statement cannot use related functions if it's executed through COM_EXECUTE.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I do not think set SessionVars.QueryAttributes in this branch is a good idea since query attribute only 'works' inCOM_QUERY. SessionVars.QueryAttributes has no other purpose beyond being used for related functions.

for i := 0; i < numParams; i++ {
paramTypes = append(paramTypes, data[pos:pos+2]...)
pos += 2
// parse names
_, _, p, e := util2.ParseLengthEncodedBytes(data[pos:])
if e != nil {
return mysql.ErrMalformPacket
}
pos += p
}
} else {
if len(data) < (pos + (numParams << 1)) {
return mysql.ErrMalformPacket
}

paramTypes = data[pos : pos+(numParams<<1)]
pos += numParams << 1
}

paramTypes = data[pos : pos+(numParams<<1)]
pos += numParams << 1
paramValues = data[pos:]
// Just the first StmtExecute packet contain parameters type,
// we need save it for further use.
Expand All @@ -206,7 +230,7 @@ func (cc *clientConn) handleStmtExecute(ctx context.Context, data []byte) (err e
paramValues = data[pos+1:]
}

err = parseBinaryParams(args, stmt.BoundParams(), nullBitmaps, stmt.GetParamsType(), paramValues, cc.inputDecoder)
_, err = parseBinaryParams(args, stmt.BoundParams(), nullBitmaps, stmt.GetParamsType(), paramValues, cc.inputDecoder)
// This `.Reset` resets the arguments, so it's fine to just ignore the error (and the it'll be reset again in the following routine)
errReset := stmt.Reset()
if errReset != nil {
Expand Down
5 changes: 2 additions & 3 deletions pkg/server/conn_stmt_params.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,7 @@ import (
var errUnknownFieldType = dbterror.ClassServer.NewStd(errno.ErrUnknownFieldType)

// parseBinaryParams decodes the binary params according to the protocol
func parseBinaryParams(params []param.BinaryParam, boundParams [][]byte, nullBitmap, paramTypes, paramValues []byte, enc *util2.InputDecoder) (err error) {
pos := 0
func parseBinaryParams(params []param.BinaryParam, boundParams [][]byte, nullBitmap, paramTypes, paramValues []byte, enc *util2.InputDecoder) (pos int, err error) {
if enc == nil {
enc = util2.NewInputDecoder(charset.CharsetUTF8)
}
Expand Down Expand Up @@ -76,7 +75,7 @@ func parseBinaryParams(params []param.BinaryParam, boundParams [][]byte, nullBit
}

if (i<<1)+1 >= len(paramTypes) {
return mysql.ErrMalformPacket
return pos, mysql.ErrMalformPacket
}

tp := paramTypes[i<<1]
Expand Down
Loading