Skip to content

Commit

Permalink
feature: split statements for postgres
Browse files Browse the repository at this point in the history
  • Loading branch information
kanzihuang committed Oct 18, 2023
1 parent 655bbd9 commit ac8d9f0
Show file tree
Hide file tree
Showing 5 changed files with 130 additions and 20 deletions.
8 changes: 4 additions & 4 deletions go/vt/sqlparser/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -243,8 +243,8 @@ var ErrEmpty = vterrors.NewErrorf(vtrpcpb.Code_INVALID_ARGUMENT, vterrors.EmptyQ

// SplitStatement returns the first sql statement up to either a ; or EOF
// and the remainder from the given buffer
func SplitStatement(blob string) (string, string, error) {
tokenizer := NewStringTokenizer(blob)
func SplitStatement(blob string, opts ...TokenizerOpt) (string, string, error) {
tokenizer := NewStringTokenizer(blob, opts...)
tkn := 0
for {
tkn, _ = tokenizer.Scan()
Expand All @@ -263,7 +263,7 @@ func SplitStatement(blob string) (string, string, error) {

// SplitStatementToPieces split raw sql statement that may have multi sql pieces to sql pieces
// returns the sql pieces blob contains; or error if sql cannot be parsed
func SplitStatementToPieces(blob string) ([]string, error) {
func SplitStatementToPieces(blob string, opts ...TokenizerOpt) ([]string, error) {
// fast path: the vast majority of SQL statements do not have semicolons in them
if blob == "" {
return nil, nil
Expand All @@ -276,7 +276,7 @@ func SplitStatementToPieces(blob string) ([]string, error) {
}

pieces := make([]string, 0, 16)
tokenizer := NewStringTokenizer(blob)
tokenizer := NewStringTokenizer(blob, opts...)
for {
stmt, err := SplitNext(tokenizer)
if err == io.EOF {
Expand Down
54 changes: 47 additions & 7 deletions go/vt/sqlparser/split_next_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,10 @@ limitations under the License.
package sqlparser

import (
"github.com/stretchr/testify/require"
"io"
"strings"
"testing"

"github.com/stretchr/testify/require"
)

const functionShowCreateTable = `CREATE OR REPLACE FUNCTION "public"."showcreatetable"("namespace" varchar, "tablename" varchar) RETURNS "pg_catalog"."varchar" AS $BODY$
Expand Down Expand Up @@ -104,13 +103,44 @@ const functionShowCreateTable = `CREATE OR REPLACE FUNCTION "public"."showcreate
end
$BODY$ LANGUAGE plpgsql VOLATILE COST 100`

// insertIntoRuleDescSections is a PostgreSQL test fixture: a multi-line INSERT
// whose string literal contains backslashes and doubled single quotes ('')
// followed by a second statement ("select 1") after the terminating semicolon.
// With PostgresDialect the backslashes must NOT be treated as escapes, so the
// splitter should yield exactly two statements.
const insertIntoRuleDescSections = `INSERT INTO "public"."rule_desc_sections" VALUES ('AYq8jq9Zzf2mMqOTe-St', 'AYq8jq9Zzf2mMqOTe-Ss', 'default', '<p>Typically, backslashes are seen only as part of escape sequences. Therefore, the use of a backslash outside of a raw string or escape sequence
looks suspiciously like a broken escape sequence.</p>
<p>Characters recognized as escape-able are: <code>abfnrtvox\''"</code></p>
<h2>Noncompliant Code Example</h2>
<pre>
s = "Hello \world."
t = "Nice to \ meet you"
u = "Let''s have \ lunch"
</pre>
<h2>Compliant Solution</h2>
<pre>
s = "Hello world."
t = "Nice to \\ meet you"
u = r"Let''s have \ lunch" // raw string
</pre>
<h2>Deprecated</h2>
<p>This rule is deprecated, and will eventually be removed.</p>', NULL, NULL);select 1`

func Test_SplitNext(t *testing.T) {
testcases := []struct {
name string
input string
output string
count int
name string
input string
output string
count int
dialect Dialect
}{
{
name: "mysql select '\\'''",
input: "select '\\\\\\'hello''';select 2",
count: 2,
dialect: MysqlDialect{},
},
{
name: "postgres select '\\'''",
input: "select '\\''hello''' from dual;select 2",
count: 2,
dialect: PostgresDialect{},
},
{
name: "with blanks",
input: "select * from `my-table`; \t; \n; \n\t\t ;select * from `my-table`;",
Expand Down Expand Up @@ -211,14 +241,24 @@ func Test_SplitNext(t *testing.T) {
input: functionShowCreateTable,
count: 1,
},
{
name: "insert into rule_desc_sections",
input: insertIntoRuleDescSections,
count: 2,
dialect: PostgresDialect{},
},
}

for _, tcase := range testcases {
t.Run(tcase.name, func(t *testing.T) {
if tcase.output == "" {
tcase.output = tcase.input
}
tokenizer := NewReaderTokenizer(strings.NewReader(tcase.input), WithCacheInBuffer())
if tcase.dialect == nil {
tcase.dialect = MysqlDialect{}
}
tokenizer := NewReaderTokenizer(strings.NewReader(tcase.input),
WithCacheInBuffer(), WithDialect(tcase.dialect))
var sb strings.Builder
var i int
for {
Expand Down
27 changes: 27 additions & 0 deletions go/vt/sqlparser/sqlparser.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,3 +110,30 @@ loop:
}
return strings.TrimSpace(sb.String()), nil
}

// Dialect selects per-dialect lexing behavior for the tokenizer.
// The unexported iDialect method restricts implementations to this package.
type Dialect interface {
	iDialect()

	// EscapingBackslash reports whether the dialect treats a backslash
	// inside a string literal as an escape character.
	EscapingBackslash() bool
}

// Compile-time checks that both dialects satisfy Dialect.
var (
	_ Dialect = MysqlDialect{}
	_ Dialect = PostgresDialect{}
)

// MysqlDialect lexes with MySQL semantics.
type MysqlDialect struct{}

func (MysqlDialect) iDialect() {}

// EscapingBackslash reports that MySQL honors backslash escapes in strings.
func (MysqlDialect) EscapingBackslash() bool { return true }

// PostgresDialect lexes with PostgreSQL semantics: a backslash in a string
// literal is an ordinary character, not an escape.
type PostgresDialect struct{}

func (PostgresDialect) iDialect() {}

// EscapingBackslash reports that PostgreSQL does not honor backslash escapes.
func (PostgresDialect) EscapingBackslash() bool { return false }
20 changes: 15 additions & 5 deletions go/vt/sqlparser/sqlparser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,18 @@ import (

func Test_ParseNext(t *testing.T) {
tests := []struct {
name string
input string
want string
err string
name string
dialect Dialect
input string
want string
err string
}{
{
name: "mysql select '\\\\\\'hello'''",
input: "select '\\\\\\'hello''' from dual",
want: "select '\\\\\\'hello\\'' from dual",
dialect: MysqlDialect{},
},
{
name: "create table `my-table`",
input: "create table `my-table` (\n\t`my-id` bigint(20)\n)",
Expand All @@ -30,7 +37,10 @@ func Test_ParseNext(t *testing.T) {
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
tokens := NewReaderTokenizer(strings.NewReader(test.input))
if test.dialect == nil {
test.dialect = MysqlDialect{}
}
tokens := NewReaderTokenizer(strings.NewReader(test.input), WithDialect(test.dialect))
tree, err := ParseNext(tokens)
if len(test.err) > 0 {
require.Error(t, err)
Expand Down
41 changes: 37 additions & 4 deletions go/vt/sqlparser/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ type Tokenizer struct {
multi bool
specialComment *Tokenizer

buf *buffer.Buffer
buf *buffer.Buffer
dialect Dialect
}

// TokenizerOpt is a functional option that configures a Tokenizer; pass
// options to NewStringTokenizer or NewReaderTokenizer.
type TokenizerOpt func(*Tokenizer)
Expand All @@ -56,6 +57,12 @@ func WithCacheInBuffer() TokenizerOpt {
}
}

// WithDialect returns a TokenizerOpt that sets the SQL dialect the tokenizer
// uses when lexing (e.g. MysqlDialect or PostgresDialect).
func WithDialect(dialect Dialect) TokenizerOpt {
	return func(tkn *Tokenizer) {
		tkn.dialect = dialect
	}
}

// NewStringTokenizer creates a new Tokenizer for the
// sql string.
func NewStringTokenizer(sql string, opts ...TokenizerOpt) *Tokenizer {
Expand All @@ -64,6 +71,7 @@ func NewStringTokenizer(sql string, opts ...TokenizerOpt) *Tokenizer {
tokenizer := &Tokenizer{
buf: buffer.NewStringBuffer(sql),
BindVars: make(map[string]struct{}),
dialect: MysqlDialect{},
}
for _, opt := range opts {
opt(tokenizer)
Expand All @@ -79,6 +87,7 @@ func NewReaderTokenizer(reader io.Reader, opts ...TokenizerOpt) *Tokenizer {
tokenizer := &Tokenizer{
buf: buffer.NewReaderBuffer(reader),
BindVars: make(map[string]struct{}),
dialect: MysqlDialect{},
}
for _, opt := range opts {
opt(tokenizer)
Expand Down Expand Up @@ -585,11 +594,34 @@ func (tkn *Tokenizer) scanString(delim uint16, typ int) (int, string) {
tkn.skip(1)
return typ, sb.String()
}
fallthrough
tkn.next()
sb.WriteString(tkn.buf.ReadBuffer())
tkn.skip(1)

case '\\':
sb.WriteString(tkn.buf.ReadBuffer())
return tkn.scanStringSlow(&sb, delim, typ)
if tkn.dialect.EscapingBackslash() {
var ch uint16
sb.WriteString(tkn.buf.ReadBuffer())
tkn.skip(1)
if tkn.cur() == eofChar {
// String terminates mid escape character.
return LEX_ERROR, sb.String()
}
// Preserve escaping of % and _
if tkn.cur() == '%' || tkn.cur() == '_' {
sb.WriteByte('\\')
ch = tkn.cur()
} else if decodedChar := sqltypes.SQLDecodeMap[byte(tkn.cur())]; decodedChar == sqltypes.DontEscape {
ch = tkn.cur()
} else {
ch = uint16(decodedChar)
}
sb.WriteByte(byte(ch))
tkn.skip(1)
//return tkn.scanStringSlow(&sb, delim, typ)
} else {
tkn.next()
}

case eofChar:
sb.WriteString(tkn.buf.ReadBuffer())
Expand Down Expand Up @@ -639,6 +671,7 @@ func (tkn *Tokenizer) scanStringSlow(buffer *strings.Builder, delim uint16, typ
tkn.skip(1) // Read one past the delim or escape character.

if ch == '\\' {
tkn.skip(1)
if tkn.cur() == eofChar {
// String terminates mid escape character.
return LEX_ERROR, buffer.String()
Expand Down

0 comments on commit ac8d9f0

Please sign in to comment.