diff --git a/go/vt/sqlparser/comments.go b/go/vt/sqlparser/comments.go index afffd326c8c..e8cd4b87652 100644 --- a/go/vt/sqlparser/comments.go +++ b/go/vt/sqlparser/comments.go @@ -55,7 +55,7 @@ func leadingCommentEnd(text string) (end int) { // Found visible characters. Look for '/*' at the beginning // and '*/' somewhere after that. - if len(remainingText) < 4 || remainingText[:2] != "/*" { + if len(remainingText) < 4 || remainingText[:2] != "/*" || remainingText[2] == '!' { break } commentLength := 4 + strings.Index(remainingText[2:], "*/") @@ -93,8 +93,8 @@ func trailingCommentStart(text string) (start int) { // Find the beginning of the comment startCommentPos := strings.LastIndex(text[:reducedLen-2], "/*") - if startCommentPos < 0 { - // Badly formatted sql :/ + if startCommentPos < 0 || text[startCommentPos+2] == '!' { + // Badly formatted sql, or a special /*! comment break } @@ -164,28 +164,6 @@ func hasCommentPrefix(sql string) bool { return len(sql) > 1 && ((sql[0] == '/' && sql[1] == '*') || (sql[0] == '-' && sql[1] == '-')) } -// StripComments removes all comments from the string regardless -// of where they occur -func StripComments(sql string) string { - sql = StripLeadingComments(sql) // handle -- or /* ... */ at the beginning - - for { - start := strings.Index(sql, "/*") - if start == -1 { - break - } - end := strings.Index(sql, "*/") - if end <= 1 { - break - } - sql = sql[:start] + sql[end+2:] - } - - sql = strings.TrimFunc(sql, unicode.IsSpace) - - return sql -} - // ExtractMysqlComment extracts the version and SQL from a comment-only query // such as /*!50708 sql here */ func ExtractMysqlComment(sql string) (version string, innerSQL string) { diff --git a/go/vt/sqlparser/comments_test.go b/go/vt/sqlparser/comments_test.go index bd9c26f9b8d..3d875faf1cb 100644 --- a/go/vt/sqlparser/comments_test.go +++ b/go/vt/sqlparser/comments_test.go @@ -119,20 +119,32 @@ func TestSplitComments(t *testing.T) { outSQL: "foo", outLeadingComments: "", outTrailingComments: "", + }, { + input: "select 1 from t where col = '*//*'", + outSQL: "select 1 from t where col = '*//*'", + outLeadingComments: "", + outTrailingComments: "", + }, { + input: "/*! select 1 */", + outSQL: "/*! select 1 */", + outLeadingComments: "", + outTrailingComments: "", }} for _, testCase := range testCases { - gotSQL, gotComments := SplitMarginComments(testCase.input) - gotLeadingComments, gotTrailingComments := gotComments.Leading, gotComments.Trailing + t.Run(testCase.input, func(t *testing.T) { + gotSQL, gotComments := SplitMarginComments(testCase.input) + gotLeadingComments, gotTrailingComments := gotComments.Leading, gotComments.Trailing - if gotSQL != testCase.outSQL { - t.Errorf("test input: '%s', got SQL\n%+v, want\n%+v", testCase.input, gotSQL, testCase.outSQL) - } - if gotLeadingComments != testCase.outLeadingComments { - t.Errorf("test input: '%s', got LeadingComments\n%+v, want\n%+v", testCase.input, gotLeadingComments, testCase.outLeadingComments) - } - if gotTrailingComments != testCase.outTrailingComments { - t.Errorf("test input: '%s', got TrailingComments\n%+v, want\n%+v", testCase.input, gotTrailingComments, testCase.outTrailingComments) - } + if gotSQL != testCase.outSQL { + t.Errorf("test input: '%s', got SQL\n%+v, want\n%+v", testCase.input, gotSQL, testCase.outSQL) + } + if gotLeadingComments != testCase.outLeadingComments { + t.Errorf("test input: '%s', got LeadingComments\n%+v, want\n%+v", testCase.input, gotLeadingComments, testCase.outLeadingComments) + } + if gotTrailingComments != testCase.outTrailingComments { + t.Errorf("test input: '%s', got TrailingComments\n%+v, want\n%+v", testCase.input, gotTrailingComments, testCase.outTrailingComments) + } + }) } } @@ -212,85 +224,6 @@ a`, } } -func TestRemoveComments(t *testing.T) { - var testCases = []struct { - input, outSQL string - }{{ - input: "/", - outSQL: "/", - }, { - input: "*/", - outSQL: "*/", - }, { - input: "/*/", - outSQL: "/*/", - }, { - input: "/*a", - outSQL: "/*a", - }, { - input: "/*a*", - outSQL: "/*a*", - }, { - input: "/*a**", - outSQL: "/*a**", - }, { - input: "/*b**a*/", - outSQL: "", - }, { - input: "/*a*/", - outSQL: "", - }, { - input: "/**/", - outSQL: "", - }, { - input: "/*!*/", - outSQL: "", - }, { - input: "/*!a*/", - outSQL: "", - }, { - input: "/*b*/ /*a*/", - outSQL: "", - }, { - input: `/*b*/ --foo -bar`, - outSQL: "bar", - }, { - input: "foo /* bar */", - outSQL: "foo", - }, { - input: "foo /* bar */ baz", - outSQL: "foo baz", - }, { - input: "/* foo */ bar", - outSQL: "bar", - }, { - input: "-- /* foo */ bar", - outSQL: "", - }, { - input: "foo -- bar */", - outSQL: "foo -- bar */", - }, { - input: `/* -foo */ bar`, - outSQL: "bar", - }, { - input: `-- foo bar -a`, - outSQL: "a", - }, { - input: `-- foo bar`, - outSQL: "", - }} - for _, testCase := range testCases { - gotSQL := StripComments(testCase.input) - - if gotSQL != testCase.outSQL { - t.Errorf("test input: '%s', got SQL\n%+v, want\n%+v", testCase.input, gotSQL, testCase.outSQL) - } - } -} - func TestExtractMysqlComment(t *testing.T) { var testCases = []struct { input, outSQL, outVersion string diff --git a/go/vt/vtexplain/testdata/test-schema.sql b/go/vt/vtexplain/testdata/test-schema.sql index e89714a474e..b901cceb25e 100644 --- a/go/vt/vtexplain/testdata/test-schema.sql +++ b/go/vt/vtexplain/testdata/test-schema.sql @@ -60,10 +60,4 @@ create table test_partitioned ( date_create int, primary key(id) ) Engine=InnoDB -/*!50100 PARTITION BY RANGE (date_create) -(PARTITION p2018_06_14 VALUES LESS THAN (1528959600) ENGINE = InnoDB, - PARTITION p2018_06_15 VALUES LESS THAN (1529046000) ENGINE = InnoDB, - PARTITION p2018_06_16 VALUES LESS THAN (1529132400) ENGINE = InnoDB, - PARTITION p2018_06_17 VALUES LESS THAN (1529218800) ENGINE = InnoDB) -*/ ; diff --git a/go/vt/vtexplain/vtexplain.go b/go/vt/vtexplain/vtexplain.go index 1887fb241a0..d1db8cea0e7 100644 --- a/go/vt/vtexplain/vtexplain.go +++ b/go/vt/vtexplain/vtexplain.go @@ -192,7 +192,7 @@ func parseSchema(sqlSchema string, opts *Options) ([]*sqlparser.DDL, error) { if sql == "" { break } - sql = sqlparser.StripComments(sql) + sql, _ = sqlparser.SplitMarginComments(sql) if sql == "" { continue } diff --git a/go/vt/vtexplain/vtexplain_flaky_test.go b/go/vt/vtexplain/vtexplain_flaky_test.go index 59a1efb2fcf..36b84ba9351 100644 --- a/go/vt/vtexplain/vtexplain_flaky_test.go +++ b/go/vt/vtexplain/vtexplain_flaky_test.go @@ -48,7 +48,7 @@ func initTest(mode string, opts *Options, t *testing.T) { opts.ExecutionMode = mode err = Init(string(vSchema), string(schema), opts) - require.NoError(t, err, "vtexplain Init error") + require.NoError(t, err, "vtexplain Init error\n%s", string(schema)) } func testExplain(testcase string, opts *Options, t *testing.T) {