diff --git a/engine_test.go b/engine_test.go index 9b853ab8d..49826275d 100644 --- a/engine_test.go +++ b/engine_test.go @@ -529,6 +529,18 @@ var queries = []struct { `SELECT -1`, []sql.Row{{int64(-1)}}, }, + { + ` + SHOW WARNINGS + `, + []sql.Row{}, + }, + { + ` + SHOW WARNINGS LIMIT 0 + `, + []sql.Row{}, + }, } func TestQueries(t *testing.T) { @@ -548,6 +560,121 @@ func TestQueries(t *testing.T) { } }) } + +func TestWarnings(t *testing.T) { + ctx := newCtx() + ctx.Session.Warn(&sql.Warning{Code: 1}) + ctx.Session.Warn(&sql.Warning{Code: 2}) + ctx.Session.Warn(&sql.Warning{Code: 3}) + + var queries = []struct { + query string + expected []sql.Row + }{ + { + ` + SHOW WARNINGS + `, + []sql.Row{ + {"", 3, ""}, + {"", 2, ""}, + {"", 1, ""}, + }, + }, + { + ` + SHOW WARNINGS LIMIT 1 + `, + []sql.Row{ + {"", 3, ""}, + }, + }, + { + ` + SHOW WARNINGS LIMIT 1,2 + `, + []sql.Row{ + {"", 2, ""}, + {"", 1, ""}, + }, + }, + { + ` + SHOW WARNINGS LIMIT 0 + `, + []sql.Row{ + {"", 3, ""}, + {"", 2, ""}, + {"", 1, ""}, + }, + }, + { + ` + SHOW WARNINGS LIMIT 2,0 + `, + []sql.Row{ + {"", 1, ""}, + }, + }, + { + ` + SHOW WARNINGS LIMIT 10 + `, + []sql.Row{ + {"", 3, ""}, + {"", 2, ""}, + {"", 1, ""}, + }, + }, + { + ` + SHOW WARNINGS LIMIT 10,1 + `, + []sql.Row{}, + }, + } + + e := newEngine(t) + ep := newEngineWithParallelism(t, 2) + + t.Run("sequential", func(t *testing.T) { + for _, tt := range queries { + testQueryWithContext(ctx, t, e, tt.query, tt.expected) + } + }) + + t.Run("parallel", func(t *testing.T) { + for _, tt := range queries { + testQueryWithContext(ctx, t, ep, tt.query, tt.expected) + } + }) +} + +func TestClearWarnings(t *testing.T) { + require := require.New(t) + e := newEngine(t) + ctx := newCtx() + ctx.Session.Warn(&sql.Warning{Code: 1}) + ctx.Session.Warn(&sql.Warning{Code: 2}) + ctx.Session.Warn(&sql.Warning{Code: 3}) + + _, iter, err := e.Query(ctx, "SHOW WARNINGS") + require.NoError(err) + rows, err := sql.RowIterToRows(iter) + require.NoError(err) + require.Equal(3, len(rows)) + + _, iter, err = e.Query(ctx, "SHOW WARNINGS LIMIT 1") + require.NoError(err) + rows, err = sql.RowIterToRows(iter) + require.NoError(err) + require.Equal(1, len(rows)) + + _, _, err = e.Query(ctx, "SELECT * FROM mytable LIMIT 1") + require.NoError(err) + require.Equal(0, len(ctx.Session.Warnings())) +} + func TestDescribe(t *testing.T) { e := newEngine(t) @@ -923,11 +1050,13 @@ func TestInnerNestedInNaturalJoins(t *testing.T) { } func testQuery(t *testing.T, e *sqle.Engine, q string, expected []sql.Row) { + testQueryWithContext(newCtx(), t, e, q, expected) +} + +func testQueryWithContext(ctx *sql.Context, t *testing.T, e *sqle.Engine, q string, expected []sql.Row) { t.Run(q, func(t *testing.T) { require := require.New(t) - session := newCtx() - - _, iter, err := e.Query(session, q) + _, iter, err := e.Query(ctx, q) require.NoError(err) rows, err := sql.RowIterToRows(iter) diff --git a/sql/analyzer/rules.go b/sql/analyzer/rules.go index cc197559d..f4ea6a919 100644 --- a/sql/analyzer/rules.go +++ b/sql/analyzer/rules.go @@ -42,6 +42,7 @@ var OnceAfterDefault = []Rule{ var OnceAfterAll = []Rule{ {"track_process", trackProcess}, {"parallelize", parallelize}, + {"clear_warnings", clearWarnings}, } var ( diff --git a/sql/analyzer/warnings.go b/sql/analyzer/warnings.go new file mode 100644 index 000000000..642ab9211 --- /dev/null +++ b/sql/analyzer/warnings.go @@ -0,0 +1,27 @@ +package analyzer + +import ( + "gopkg.in/src-d/go-mysql-server.v0/sql" + "gopkg.in/src-d/go-mysql-server.v0/sql/plan" +) + +func clearWarnings(ctx *sql.Context, a *Analyzer, node sql.Node) (sql.Node, error) { + children := node.Children() + if len(children) == 0 { + return node, nil + } + + switch ch := children[0].(type) { + case plan.ShowWarnings: + return node, nil + case *plan.Offset: + clearWarnings(ctx, a, ch) + return node, nil + case *plan.Limit: + clearWarnings(ctx, a, ch) + return node, nil + } + + ctx.ClearWarnings() + return node, nil +} diff --git a/sql/parse/parse.go b/sql/parse/parse.go index 9f37f4125..bac013bff 100644 --- a/sql/parse/parse.go +++ b/sql/parse/parse.go @@ -40,6 +40,7 @@ var ( showIndexRegex = regexp.MustCompile(`^show\s+(index|indexes|keys)\s+(from|in)\s+\S+\s*`) showCreateRegex = regexp.MustCompile(`^show create\s+\S+\s*`) showVariablesRegex = regexp.MustCompile(`^show\s+(.*)?variables\s*`) + showWarningsRegex = regexp.MustCompile(`^show\s+warnings\s*`) describeRegex = regexp.MustCompile(`^(describe|desc|explain)\s+(.*)\s+`) fullProcessListRegex = regexp.MustCompile(`^show\s+(full\s+)?processlist$`) unlockTablesRegex = regexp.MustCompile(`^unlock\s+tables$`) @@ -63,6 +64,7 @@ func Parse(ctx *sql.Context, query string) (sql.Node, error) { } lowerQuery := strings.ToLower(s) + switch true { case describeTablesRegex.MatchString(lowerQuery): return parseDescribeTables(lowerQuery) @@ -76,6 +78,8 @@ func Parse(ctx *sql.Context, query string) (sql.Node, error) { return parseShowCreate(s) case showVariablesRegex.MatchString(lowerQuery): return parseShowVariables(ctx, s) + case showWarningsRegex.MatchString(lowerQuery): + return parseShowWarnings(ctx, s) case describeRegex.MatchString(lowerQuery): return parseDescribeQuery(ctx, s) case fullProcessListRegex.MatchString(lowerQuery): diff --git a/sql/parse/parse_test.go b/sql/parse/parse_test.go index 782e48677..a2742232e 100644 --- a/sql/parse/parse_test.go +++ b/sql/parse/parse_test.go @@ -837,6 +837,9 @@ var fixtures = map[string]sql.Node{ }, plan.NewUnresolvedTable("mytable", ""), ), + `SHOW WARNINGS`: plan.NewOffset(0, plan.ShowWarnings(sql.NewEmptyContext().Warnings())), + `SHOW WARNINGS LIMIT 10`: plan.NewLimit(10, plan.NewOffset(0, plan.ShowWarnings(sql.NewEmptyContext().Warnings()))), + `SHOW WARNINGS LIMIT 5,10`: plan.NewLimit(10, plan.NewOffset(5, plan.ShowWarnings(sql.NewEmptyContext().Warnings()))), } func TestParse(t *testing.T) { diff --git a/sql/parse/warnings.go b/sql/parse/warnings.go new file mode 100644 index 000000000..10a01c331 --- /dev/null +++ b/sql/parse/warnings.go @@ -0,0 +1,77 @@ +package parse + +import ( + "bufio" + "strconv" + "strings" + + errors "gopkg.in/src-d/go-errors.v1" + "gopkg.in/src-d/go-mysql-server.v0/sql" + "gopkg.in/src-d/go-mysql-server.v0/sql/plan" +) + +var errInvalidIndex = errors.NewKind("invalid %s index %d (index must be non-negative)") + +func parseShowWarnings(ctx *sql.Context, s string) (sql.Node, error) { + var ( + offstr string + cntstr string + ) + + r := bufio.NewReader(strings.NewReader(s)) + for _, fn := range []parseFunc{ + expect("show"), + skipSpaces, + expect("warnings"), + skipSpaces, + func(in *bufio.Reader) error { + if expect("limit")(in) == nil { + skipSpaces(in) + readValue(&cntstr)(in) + skipSpaces(in) + if expectRune(',')(in) == nil { + if readValue(&offstr)(in) == nil { + offstr, cntstr = cntstr, offstr + } + } + + } + return nil + }, + skipSpaces, + checkEOF, + } { + if err := fn(r); err != nil { + return nil, err + } + } + + var ( + node sql.Node = plan.ShowWarnings(ctx.Session.Warnings()) + offset int + count int + err error + ) + if offstr != "" { + if offset, err = strconv.Atoi(offstr); err != nil { + return nil, err + } + if offset < 0 { + return nil, errInvalidIndex.New("offset", offset) + } + } + node = plan.NewOffset(int64(offset), node) + if cntstr != "" { + if count, err = strconv.Atoi(cntstr); err != nil { + return nil, err + } + if count < 0 { + return nil, errInvalidIndex.New("count", count) + } + if count > 0 { + node = plan.NewLimit(int64(count), node) + } + } + + return node, nil +} diff --git a/sql/plan/showwarnings.go b/sql/plan/showwarnings.go new file mode 100644 index 000000000..388c46e7c --- /dev/null +++ b/sql/plan/showwarnings.go @@ -0,0 +1,51 @@ +package plan + +import ( + "gopkg.in/src-d/go-mysql-server.v0/sql" +) + +// ShowWarnings is a node that shows the session warnings +type ShowWarnings []*sql.Warning + +// Resolved implements sql.Node interface. The function always returns true. +func (ShowWarnings) Resolved() bool { + return true +} + +// TransformUp implements the sq.Transformable interface. +func (sw ShowWarnings) TransformUp(f sql.TransformNodeFunc) (sql.Node, error) { + return f(sw) +} + +// TransformExpressionsUp implements the sql.Transformable interface. +func (sw ShowWarnings) TransformExpressionsUp(f sql.TransformExprFunc) (sql.Node, error) { + return sw, nil +} + +// String implements the Stringer interface. +func (ShowWarnings) String() string { + return "SHOW WARNINGS" +} + +// Schema returns a new Schema reference for "SHOW VARIABLES" query. +func (ShowWarnings) Schema() sql.Schema { + return sql.Schema{ + &sql.Column{Name: "Level", Type: sql.Text, Nullable: false}, + &sql.Column{Name: "Code", Type: sql.Int32, Nullable: true}, + &sql.Column{Name: "Message", Type: sql.Text, Nullable: false}, + } +} + +// Children implements sql.Node interface. The function always returns nil. +func (ShowWarnings) Children() []sql.Node { return nil } + +// RowIter implements the sql.Node interface. +// The function returns an iterator for warnings (considering offset and counter) +func (sw ShowWarnings) RowIter(ctx *sql.Context) (sql.RowIter, error) { + var rows []sql.Row + for _, w := range sw { + rows = append(rows, sql.NewRow(w.Level, w.Code, w.Message)) + } + + return sql.RowsToRowIter(rows...), nil +} diff --git a/sql/plan/showwarnings_test.go b/sql/plan/showwarnings_test.go new file mode 100644 index 000000000..b04b555e1 --- /dev/null +++ b/sql/plan/showwarnings_test.go @@ -0,0 +1,40 @@ +package plan + +import ( + "io" + "testing" + + "github.com/stretchr/testify/require" + "gopkg.in/src-d/go-mysql-server.v0/sql" +) + +func TestShowWarnings(t *testing.T) { + require := require.New(t) + + ctx := sql.NewEmptyContext() + ctx.Session.Warn(&sql.Warning{"l1", "w1", 1}) + ctx.Session.Warn(&sql.Warning{"l2", "w2", 2}) + ctx.Session.Warn(&sql.Warning{"l4", "w3", 3}) + + sw := ShowWarnings(ctx.Session.Warnings()) + require.True(sw.Resolved()) + + it, err := sw.RowIter(ctx) + require.NoError(err) + + n := 3 + for row, err := it.Next(); err == nil; row, err = it.Next() { + level := row[0].(string) + code := row[1].(int) + message := row[2].(string) + + t.Logf("level: %s\tcode: %v\tmessage: %s\n", level, code, message) + + require.Equal(n, code) + n-- + } + if err != io.EOF { + require.NoError(err) + } + require.NoError(it.Close()) +} diff --git a/sql/session.go b/sql/session.go index cb177214b..9a90ab86d 100644 --- a/sql/session.go +++ b/sql/session.go @@ -32,15 +32,22 @@ type Session interface { GetAll() map[string]TypedValue // ID returns the unique ID of the connection. ID() uint32 + // Warn stores the warning in the session. + Warn(warn *Warning) + // Warnings returns a copy of session warnings (from the most recent) + Warnings() []*Warning + // ClearWarnings cleans up session warnings + ClearWarnings() } // BaseSession is the basic session type. type BaseSession struct { - id uint32 - addr string - user string - mu sync.RWMutex - config map[string]TypedValue + id uint32 + addr string + user string + mu sync.RWMutex + config map[string]TypedValue + warnings []*Warning } // User returns the current user of the session. @@ -83,12 +90,52 @@ func (s *BaseSession) GetAll() map[string]TypedValue { // ID implements the Session interface. func (s *BaseSession) ID() uint32 { return s.id } -// TypedValue is a value along with its type. -type TypedValue struct { - Typ Type - Value interface{} +// Warn stores the warning in the session. +func (s *BaseSession) Warn(warn *Warning) { + s.mu.Lock() + defer s.mu.Unlock() + s.warnings = append(s.warnings, warn) +} + +// Warnings returns a copy of session warnings (from the most recent - the last one) +// The function implements sql.Session interface +func (s *BaseSession) Warnings() []*Warning { + s.mu.RLock() + defer s.mu.RUnlock() + + n := len(s.warnings) + warns := make([]*Warning, n) + for i := 0; i < n; i++ { + warns[i] = s.warnings[n-i-1] + } + + return warns } +// ClearWarnings cleans up session warnings +func (s *BaseSession) ClearWarnings() { + s.mu.Lock() + defer s.mu.Unlock() + if s.warnings != nil { + s.warnings = s.warnings[:0] + } +} + +type ( + // TypedValue is a value along with its type. + TypedValue struct { + Typ Type + Value interface{} + } + + // Warning stands for mySQL warning record. + Warning struct { + Level string + Message string + Code int + } +) + func defaultSessionConfig() map[string]TypedValue { return map[string]TypedValue{ "auto_increment_increment": TypedValue{Int64, int64(1)}, diff --git a/sql/session_test.go b/sql/session_test.go index 76740cedf..9240944e6 100644 --- a/sql/session_test.go +++ b/sql/session_test.go @@ -22,6 +22,19 @@ func TestSessionConfig(t *testing.T) { typ, v = sess.Get("foo") require.Equal(Int64, typ) require.Equal(1, v) + + require.Equal(0, len(sess.Warnings())) + + sess.Warn(&Warning{Code: 1}) + sess.Warn(&Warning{Code: 2}) + sess.Warn(&Warning{Code: 3}) + + require.Equal(3, len(sess.Warnings())) + + require.Equal(3, sess.Warnings()[0].Code) + require.Equal(2, sess.Warnings()[1].Code) + require.Equal(1, sess.Warnings()[2].Code) + } type testNode struct{}