From cc06faf6c34c1631f3911b5035d467988d4c21f6 Mon Sep 17 00:00:00 2001 From: Scott Lanning Date: Tue, 23 Jan 2018 14:59:43 +0100 Subject: [PATCH] ignore /*! mysql-specific */ comment statements This is related to https://github.com/youtube/vitess/issues/3520 Also ignore /*!50708 mysql-version-specific */ comments. I'm not sure if it's a good idea. Instead of stripping leading comments, in case it's /*! ... */ we don't strip it, and consider it a new StmtComment statement type. (Not sure if that should be added to ast.go . I didn't.) I assumed this kind of comment was the only thing in the query, so if there's something like "/*! ... */ select ..." it wouldn't work. The handleComment in executor.go basically does nothing, returning &sqltypes.Result{} --- go/vt/sqlparser/analyzer.go | 4 ++++ go/vt/sqlparser/analyzer_test.go | 2 ++ go/vt/sqlparser/comments.go | 16 +++++++++++++++ go/vt/sqlparser/comments_test.go | 34 ++++++++++++++++++++++++++++++++ go/vt/vtgate/executor.go | 9 +++++++++ go/vt/vtgate/executor_test.go | 20 +++++++++++++++++++ 6 files changed, 85 insertions(+) diff --git a/go/vt/sqlparser/analyzer.go b/go/vt/sqlparser/analyzer.go index 08d3ccb6e06..97bd8aeaa80 100644 --- a/go/vt/sqlparser/analyzer.go +++ b/go/vt/sqlparser/analyzer.go @@ -45,6 +45,7 @@ const ( StmtUse StmtOther StmtUnknown + StmtComment ) // Preview analyzes the beginning of the query using a simpler and faster @@ -91,6 +92,9 @@ func Preview(sql string) int { case "analyze", "describe", "desc", "explain", "repair", "optimize", "truncate": return StmtOther } + if strings.Index(trimmed, "/*!") == 0 { + return StmtComment + } return StmtUnknown } diff --git a/go/vt/sqlparser/analyzer_test.go b/go/vt/sqlparser/analyzer_test.go index c345feb70ff..c8f6ec0ed3d 100644 --- a/go/vt/sqlparser/analyzer_test.go +++ b/go/vt/sqlparser/analyzer_test.go @@ -66,6 +66,8 @@ func TestPreview(t *testing.T) { {"/* leading comment */ select ...", StmtSelect}, {"/* leading comment */ /* leading comment 2 */ select ...", StmtSelect}, + {"/*! MySQL-specific comment */", StmtComment}, + {"/*!50708 MySQL-version comment */", StmtComment}, {"-- leading single line comment \n select ...", StmtSelect}, {"-- leading single line comment \n -- leading single line comment 2\n select ...", StmtSelect}, diff --git a/go/vt/sqlparser/comments.go b/go/vt/sqlparser/comments.go index 62ad15e063e..b5ac471df5e 100644 --- a/go/vt/sqlparser/comments.go +++ b/go/vt/sqlparser/comments.go @@ -121,6 +121,10 @@ func StripLeadingComments(sql string) string { if index <= 1 { return sql } + // don't strip /*! ... */ or /*!50700 ... */ + if len(sql) > 2 && sql[2] == '!' { + return sql + } sql = sql[index+2:] case '-': // Single line comment @@ -140,3 +144,15 @@ func StripLeadingComments(sql string) string { func hasCommentPrefix(sql string) bool { return len(sql) > 1 && ((sql[0] == '/' && sql[1] == '*') || (sql[0] == '-' && sql[1] == '-')) } + +// ExtractMysqlComment extracts the version and SQL from a comment-only query +// such as /*!50708 sql here */ +func ExtractMysqlComment(sql string) (version string, innerSQL string) { + sql = sql[3 : len(sql)-2] + + endOfVersionIndex := strings.IndexFunc(sql, func(c rune) bool { return !unicode.IsDigit(c) }) + version = sql[0:endOfVersionIndex] + innerSQL = strings.TrimFunc(sql[endOfVersionIndex:], unicode.IsSpace) + + return version, innerSQL +} diff --git a/go/vt/sqlparser/comments_test.go b/go/vt/sqlparser/comments_test.go index f23c4f99e0c..d174a9aa10d 100644 --- a/go/vt/sqlparser/comments_test.go +++ b/go/vt/sqlparser/comments_test.go @@ -128,6 +128,12 @@ func TestStripLeadingComments(t *testing.T) { }, { input: "/**/", outSQL: "", + }, { + input: "/*!*/", + outSQL: "/*!*/", + }, { + input: "/*!a*/", + outSQL: "/*!a*/", }, { input: "/*b*/ /*a*/", outSQL: "", @@ -167,3 +173,31 @@ a`, } } } + +func TestExtractMysqlComment(t *testing.T) { + var testCases = []struct { + input, outSQL, outVersion string + }{{ + input: "/*!50708SET max_execution_time=5000 */", + outSQL: "SET max_execution_time=5000", + outVersion: "50708", + }, { + input: "/*!50708 SET max_execution_time=5000*/", + outSQL: "SET max_execution_time=5000", + outVersion: "50708", + }, { + input: "/*! SET max_execution_time=5000*/", + outSQL: "SET max_execution_time=5000", + outVersion: "", + }} + for _, testCase := range testCases { + gotVersion, gotSQL := ExtractMysqlComment(testCase.input) + + if gotVersion != testCase.outVersion { + t.Errorf("test input: '%s', got version\n%+v, want\n%+v", testCase.input, gotVersion, testCase.outVersion) + } + if gotSQL != testCase.outSQL { + t.Errorf("test input: '%s', got SQL\n%+v, want\n%+v", testCase.input, gotSQL, testCase.outSQL) + } + } +} diff --git a/go/vt/vtgate/executor.go b/go/vt/vtgate/executor.go index 89b990b0909..10cd0fc5a6d 100644 --- a/go/vt/vtgate/executor.go +++ b/go/vt/vtgate/executor.go @@ -195,6 +195,8 @@ func (e *Executor) execute(ctx context.Context, safeSession *SafeSession, sql st return e.handleUse(ctx, safeSession, sql, bindVars) case sqlparser.StmtOther: return e.handleOther(ctx, safeSession, sql, bindVars, target, logStats) + case sqlparser.StmtComment: + return e.handleComment(ctx, safeSession, sql, bindVars, target, logStats) } return nil, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "unrecognized statement: %s", sql) } @@ -705,6 +707,13 @@ func (e *Executor) handleOther(ctx context.Context, safeSession *SafeSession, sq return result, err } +func (e *Executor) handleComment(ctx context.Context, safeSession *SafeSession, sql string, bindVars map[string]*querypb.BindVariable, target querypb.Target, logStats *LogStats) (*sqltypes.Result, error) { + _, sql = sqlparser.ExtractMysqlComment(sql) + + // Not sure if this is a good idea. + return &sqltypes.Result{}, nil +} + // StreamExecute executes a streaming query. func (e *Executor) StreamExecute(ctx context.Context, method string, safeSession *SafeSession, sql string, bindVars map[string]*querypb.BindVariable, target querypb.Target, callback func(*sqltypes.Result) error) (err error) { logStats := NewLogStats(ctx, method, sql, bindVars) diff --git a/go/vt/vtgate/executor_test.go b/go/vt/vtgate/executor_test.go index 15a249c5a16..53c56903df5 100644 --- a/go/vt/vtgate/executor_test.go +++ b/go/vt/vtgate/executor_test.go @@ -734,6 +734,26 @@ func TestExecutorUse(t *testing.T) { } } +func TestExecutorComment(t *testing.T) { + executor, _, _, _ := createExecutorEnv() + + stmts := []string{ + "/*! SET max_execution_time=5000*/", + "/*!50708 SET max_execution_time=5000*/", + } + wantResult := &sqltypes.Result{} + + for _, stmt := range stmts { + gotResult, err := executor.Execute(context.Background(), "TestExecute", NewSafeSession(&vtgatepb.Session{TargetString: KsTestUnsharded}), stmt, nil) + if err != nil { + t.Error(err) + } + if !reflect.DeepEqual(gotResult, wantResult) { + t.Errorf("Exec %s: %v, want %v", stmt, gotResult, wantResult) + } + } +} + func TestExecutorOther(t *testing.T) { executor, sbc1, sbc2, sbclookup := createExecutorEnv()