Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support sql_calc_found_rows with limit in sharded keyspace #6680

Merged
merged 17 commits into from
Sep 14, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 66 additions & 0 deletions go/test/endtoend/vtgate/found_rows_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
/*
Copyright 2020 The Vitess Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package vtgate

import (
"context"
"fmt"
"testing"

"github.com/stretchr/testify/assert"

"github.com/stretchr/testify/require"

"vitess.io/vitess/go/mysql"
"vitess.io/vitess/go/test/endtoend/cluster"
)

func TestFoundRows(t *testing.T) {
defer cluster.PanicHandler(t)
ctx := context.Background()
conn, err := mysql.Connect(ctx, &vtParams)
require.Nil(t, err)
defer conn.Close()

exec(t, conn, "insert into t2(id3,id4) values(1,2), (2,2), (3,3), (4,3), (5,3)")

runTests := func(workload string) {
assertFoundRowsValue(t, conn, "select * from t2", workload, 5)
assertFoundRowsValue(t, conn, "select * from t2 limit 2", workload, 2)
assertFoundRowsValue(t, conn, "select SQL_CALC_FOUND_ROWS * from t2 limit 2", workload, 5)
assertFoundRowsValue(t, conn, "select SQL_CALC_FOUND_ROWS * from t2 where id3 = 4 limit 2", workload, 1)
assertFoundRowsValue(t, conn, "select SQL_CALC_FOUND_ROWS * from t2 where id4 = 3 limit 2", workload, 3)
assertFoundRowsValue(t, conn, "select SQL_CALC_FOUND_ROWS id4, count(id3) from t2 where id3 = 3 group by id4 limit 1", workload, 1)
}

runTests("oltp")
exec(t, conn, "set workload = olap")
runTests("olap")

// cleanup test data
exec(t, conn, "set workload = oltp")
exec(t, conn, "delete from t2")
exec(t, conn, "delete from t2_id4_idx")
}

func assertFoundRowsValue(t *testing.T, conn *mysql.Conn, query, workload string, count int) {
exec(t, conn, query)
qr := exec(t, conn, "select found_rows()")
got := fmt.Sprintf("%v", qr.Rows)
want := fmt.Sprintf(`[[UINT64(%d)]]`, count)
assert.Equalf(t, want, got, "Workload: %s\nQuery:%s\n", workload, query)
}
20 changes: 4 additions & 16 deletions go/test/endtoend/vtgate/lookup_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -410,29 +410,17 @@ func TestHashLookupMultiInsertIgnore(t *testing.T) {
defer conn2.Close()

// DB should start out clean
qr := exec(t, conn, "select count(*) from t2_id4_idx")
if got, want := fmt.Sprintf("%v", qr.Rows), "[[INT64(0)]]"; got != want {
t.Errorf("select:\n%v want\n%v", got, want)
}
qr = exec(t, conn, "select count(*) from t2")
if got, want := fmt.Sprintf("%v", qr.Rows), "[[INT64(0)]]"; got != want {
t.Errorf("select:\n%v want\n%v", got, want)
}
assertMatches(t, conn, "select count(*) from t2_id4_idx", "[[INT64(0)]]")
assertMatches(t, conn, "select count(*) from t2", "[[INT64(0)]]")

// Try inserting a bunch of ids at once
exec(t, conn, "begin")
exec(t, conn, "insert ignore into t2(id3, id4) values(50,60), (30,40), (10,20)")
exec(t, conn, "commit")

// Verify
qr = exec(t, conn, "select id3, id4 from t2 order by id3")
if got, want := fmt.Sprintf("%v", qr.Rows), "[[INT64(10) INT64(20)] [INT64(30) INT64(40)] [INT64(50) INT64(60)]]"; got != want {
t.Errorf("select:\n%v want\n%v", got, want)
}
qr = exec(t, conn, "select id3, id4 from t2_id4_idx order by id3")
if got, want := fmt.Sprintf("%v", qr.Rows), "[[INT64(10) INT64(20)] [INT64(30) INT64(40)] [INT64(50) INT64(60)]]"; got != want {
t.Errorf("select:\n%v want\n%v", got, want)
}
assertMatches(t, conn, "select id3, id4 from t2 order by id3", "[[INT64(10) INT64(20)] [INT64(30) INT64(40)] [INT64(50) INT64(60)]]")
assertMatches(t, conn, "select id3, id4 from t2_id4_idx order by id3", "[[INT64(10) INT64(20)] [INT64(30) INT64(40)] [INT64(50) INT64(60)]]")
}

func TestConsistentLookupUpdate(t *testing.T) {
Expand Down
8 changes: 8 additions & 0 deletions go/vt/vtgate/engine/fake_vcursor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,10 @@ type noopVCursor struct {
ctx context.Context
}

func (t noopVCursor) SetFoundRows(u uint64) {
panic("implement me")
}

func (t noopVCursor) InTransactionAndIsDML() bool {
panic("implement me")
}
Expand Down Expand Up @@ -199,6 +203,10 @@ type loggingVCursor struct {
resolvedTargetTabletType topodatapb.TabletType
}

func (f *loggingVCursor) SetFoundRows(u uint64) {
panic("implement me")
}

func (f *loggingVCursor) InTransactionAndIsDML() bool {
return false
}
Expand Down
1 change: 1 addition & 0 deletions go/vt/vtgate/engine/primitive.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ type (
SetSQLSelectLimit(int64)
SetTransactionMode(vtgatepb.TransactionMode)
SetWorkload(querypb.ExecuteOptions_Workload)
SetFoundRows(uint64)
}

// Plan represents the execution strategy for a given query.
Expand Down
124 changes: 124 additions & 0 deletions go/vt/vtgate/engine/sql_calc_found_rows.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
/*
Copyright 2020 The Vitess Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package engine

import (
"vitess.io/vitess/go/sqltypes"
querypb "vitess.io/vitess/go/vt/proto/query"
vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc"
"vitess.io/vitess/go/vt/vterrors"
"vitess.io/vitess/go/vt/vtgate/evalengine"
)

var _ Primitive = (*SQLCalcFoundRows)(nil)

//SQLCalcFoundRows is a primitive to execute limit and count query as per their individual plan.
type SQLCalcFoundRows struct {
LimitPrimitive Primitive
CountPrimitive Primitive
}

//RouteType implements the Primitive interface
func (s SQLCalcFoundRows) RouteType() string {
return "SQLCalcFoundRows"
}

//GetKeyspaceName implements the Primitive interface
func (s SQLCalcFoundRows) GetKeyspaceName() string {
return s.LimitPrimitive.GetKeyspaceName()
}

//GetTableName implements the Primitive interface
func (s SQLCalcFoundRows) GetTableName() string {
return s.LimitPrimitive.GetTableName()
}

//Execute implements the Primitive interface
func (s SQLCalcFoundRows) Execute(vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool) (*sqltypes.Result, error) {
limitQr, err := s.LimitPrimitive.Execute(vcursor, bindVars, wantfields)
if err != nil {
return nil, err
}
countQr, err := s.CountPrimitive.Execute(vcursor, bindVars, false)
if err != nil {
return nil, err
}
if len(countQr.Rows) != 1 || len(countQr.Rows[0]) != 1 {
return nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "count query is not a scalar")
}
fr, err := evalengine.ToUint64(countQr.Rows[0][0])
if err != nil {
return nil, err
}
vcursor.Session().SetFoundRows(fr)
return limitQr, nil
}

//StreamExecute implements the Primitive interface
func (s SQLCalcFoundRows) StreamExecute(vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error {
err := s.LimitPrimitive.StreamExecute(vcursor, bindVars, wantfields, callback)
if err != nil {
return err
}

var fr *uint64

err = s.CountPrimitive.StreamExecute(vcursor, bindVars, wantfields, func(countQr *sqltypes.Result) error {
if len(countQr.Rows) == 0 && countQr.Fields != nil {
// this is the fields, which we can ignore
return nil
}
if len(countQr.Rows) != 1 || len(countQr.Rows[0]) != 1 {
return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "count query is not a scalar")
}
toUint64, err := evalengine.ToUint64(countQr.Rows[0][0])
if err != nil {
return err
}
fr = &toUint64
return nil
})
if err != nil {
return err
}
if fr == nil {
return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "count query for SQL_CALC_FOUND_ROWS never returned a value")
}
vcursor.Session().SetFoundRows(*fr)
return nil
}

//GetFields implements the Primitive interface
func (s SQLCalcFoundRows) GetFields(vcursor VCursor, bindVars map[string]*querypb.BindVariable) (*sqltypes.Result, error) {
return s.LimitPrimitive.GetFields(vcursor, bindVars)
}

//NeedsTransaction implements the Primitive interface
func (s SQLCalcFoundRows) NeedsTransaction() bool {
return s.LimitPrimitive.NeedsTransaction()
}

//Inputs implements the Primitive interface
func (s SQLCalcFoundRows) Inputs() []Primitive {
return []Primitive{s.LimitPrimitive, s.CountPrimitive}
}

func (s SQLCalcFoundRows) description() PrimitiveDescription {
return PrimitiveDescription{
OperatorType: "SQL_CALC_FOUND_ROWS",
}
}
13 changes: 12 additions & 1 deletion go/vt/vtgate/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,9 @@ func saveSessionStats(safeSession *SafeSession, stmtType sqlparser.StatementType
if err != nil {
return
}
safeSession.FoundRows = result.RowsAffected
if !safeSession.foundRowsHandled {
safeSession.FoundRows = result.RowsAffected
}
if result.InsertID > 0 {
safeSession.LastInsertId = result.InsertID
}
Expand Down Expand Up @@ -997,6 +999,7 @@ func (e *Executor) StreamExecute(ctx context.Context, method string, safeSession
result := &sqltypes.Result{}
byteCount := 0
seenResults := false
var foundRows uint64
err = plan.Instructions.StreamExecute(vcursor, bindVars, true, func(qr *sqltypes.Result) error {
// If the row has field info, send it separately.
// TODO(sougou): this behavior is for handling tests because
Expand All @@ -1009,8 +1012,10 @@ func (e *Executor) StreamExecute(ctx context.Context, method string, safeSession
seenResults = true
}

foundRows += uint64(len(qr.Rows))
for _, row := range qr.Rows {
result.Rows = append(result.Rows, row)

for _, col := range row {
byteCount += col.Len()
}
Expand Down Expand Up @@ -1038,6 +1043,12 @@ func (e *Executor) StreamExecute(ctx context.Context, method string, safeSession
logStats.ExecuteTime = time.Since(execStart)
e.updateQueryCounts(plan.Instructions.RouteType(), plan.Instructions.GetKeyspaceName(), plan.Instructions.GetTableName(), int64(logStats.ShardQueries))

// save session stats for future queries
if !safeSession.foundRowsHandled {
safeSession.FoundRows = foundRows
}
safeSession.RowCount = -1

return err
}

Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/planbuilder/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,7 @@ func buildRoutePlan(stmt sqlparser.Statement, vschema ContextVSchema, f func(sta
func createInstructionFor(query string, stmt sqlparser.Statement, vschema ContextVSchema) (engine.Primitive, error) {
switch stmt := stmt.(type) {
case *sqlparser.Select:
return buildRoutePlan(stmt, vschema, buildSelectPlan)
return buildRoutePlan(stmt, vschema, buildSelectPlan(query))
case *sqlparser.Insert:
return buildRoutePlan(stmt, vschema, buildInsertPlan)
case *sqlparser.Update:
Expand Down
4 changes: 2 additions & 2 deletions go/vt/vtgate/planbuilder/expr.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ func (pb *primitiveBuilder) findOrigin(expr sqlparser.Expr) (pullouts []*pullout
spb := newPrimitiveBuilder(pb.vschema, pb.jt)
switch stmt := node.Select.(type) {
case *sqlparser.Select:
if err := spb.processSelect(stmt, pb.st); err != nil {
if err := spb.processSelect(stmt, pb.st, ""); err != nil {
return false, err
}
case *sqlparser.Union:
Expand Down Expand Up @@ -230,7 +230,7 @@ func (pb *primitiveBuilder) finalizeUnshardedDMLSubqueries(nodes ...sqlparser.SQ
return true, nil
}
spb := newPrimitiveBuilder(pb.vschema, pb.jt)
if err := spb.processSelect(nodeType, pb.st); err != nil {
if err := spb.processSelect(nodeType, pb.st, ""); err != nil {
samePlan = false
return false, err
}
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/planbuilder/from.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ func (pb *primitiveBuilder) processAliasedTable(tableExpr *sqlparser.AliasedTabl
spb := newPrimitiveBuilder(pb.vschema, pb.jt)
switch stmt := expr.Select.(type) {
case *sqlparser.Select:
if err := spb.processSelect(stmt, nil); err != nil {
if err := spb.processSelect(stmt, nil, ""); err != nil {
return err
}
case *sqlparser.Union:
Expand Down
Loading