Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support sql_calc_found_rows with limit in sharded keyspace #6680

Merged
merged 17 commits into from
Sep 14, 2020
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 57 additions & 0 deletions go/test/endtoend/vtgate/found_rows_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
/*
Copyright 2020 The Vitess Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package vtgate

import (
"context"
"fmt"
"testing"

"github.com/stretchr/testify/require"

"vitess.io/vitess/go/mysql"
"vitess.io/vitess/go/test/endtoend/cluster"
)

func TestFoundRows(t *testing.T) {
defer cluster.PanicHandler(t)
ctx := context.Background()
conn, err := mysql.Connect(ctx, &vtParams)
require.Nil(t, err)
defer conn.Close()

exec(t, conn, "insert into t2(id3,id4) values(1,2), (2,2), (3,3), (4,3), (5,3)")

assertFoundRowsValue(t, conn, "select * from t2", 5)
assertFoundRowsValue(t, conn, "select * from t2 limit 2", 2)
assertFoundRowsValue(t, conn, "select SQL_CALC_FOUND_ROWS * from t2 limit 2", 5)
assertFoundRowsValue(t, conn, "select SQL_CALC_FOUND_ROWS * from t2 where id3 = 4 limit 2", 1)
assertFoundRowsValue(t, conn, "select SQL_CALC_FOUND_ROWS * from t2 where id4 = 3 limit 2", 3)
assertFoundRowsValue(t, conn, "select SQL_CALC_FOUND_ROWS id4, count(id3) from t2 where id3 = 3 group by id4 limit 1", 1)

// cleanup test data
exec(t, conn, "delete from t2")
exec(t, conn, "delete from t2_id4_idx") // TODO systay do we really need to do this manually?
}

func assertFoundRowsValue(t *testing.T, conn *mysql.Conn, query string, count int) {
exec(t, conn, query)
qr := exec(t, conn, "select found_rows()")
got := fmt.Sprintf("%v", qr.Rows)
want := fmt.Sprintf(`[[UINT64(%d)]]`, count)
require.Equal(t, want, got)
}
20 changes: 4 additions & 16 deletions go/test/endtoend/vtgate/lookup_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -410,29 +410,17 @@ func TestHashLookupMultiInsertIgnore(t *testing.T) {
defer conn2.Close()

// DB should start out clean
qr := exec(t, conn, "select count(*) from t2_id4_idx")
if got, want := fmt.Sprintf("%v", qr.Rows), "[[INT64(0)]]"; got != want {
t.Errorf("select:\n%v want\n%v", got, want)
}
qr = exec(t, conn, "select count(*) from t2")
if got, want := fmt.Sprintf("%v", qr.Rows), "[[INT64(0)]]"; got != want {
t.Errorf("select:\n%v want\n%v", got, want)
}
assertMatches(t, conn, "select count(*) from t2_id4_idx", "[[INT64(0)]]")
assertMatches(t, conn, "select count(*) from t2", "[[INT64(0)]]")

// Try inserting a bunch of ids at once
exec(t, conn, "begin")
exec(t, conn, "insert ignore into t2(id3, id4) values(50,60), (30,40), (10,20)")
exec(t, conn, "commit")

// Verify
qr = exec(t, conn, "select id3, id4 from t2 order by id3")
if got, want := fmt.Sprintf("%v", qr.Rows), "[[INT64(10) INT64(20)] [INT64(30) INT64(40)] [INT64(50) INT64(60)]]"; got != want {
t.Errorf("select:\n%v want\n%v", got, want)
}
qr = exec(t, conn, "select id3, id4 from t2_id4_idx order by id3")
if got, want := fmt.Sprintf("%v", qr.Rows), "[[INT64(10) INT64(20)] [INT64(30) INT64(40)] [INT64(50) INT64(60)]]"; got != want {
t.Errorf("select:\n%v want\n%v", got, want)
}
assertMatches(t, conn, "select id3, id4 from t2 order by id3", "[[INT64(10) INT64(20)] [INT64(30) INT64(40)] [INT64(50) INT64(60)]]")
assertMatches(t, conn, "select id3, id4 from t2_id4_idx order by id3", "[[INT64(10) INT64(20)] [INT64(30) INT64(40)] [INT64(50) INT64(60)]]")
}

func TestConsistentLookupUpdate(t *testing.T) {
Expand Down
8 changes: 8 additions & 0 deletions go/vt/vtgate/engine/fake_vcursor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,10 @@ type noopVCursor struct {
ctx context.Context
}

func (t noopVCursor) SetFoundRows(u uint64) {
panic("implement me")
}

func (t noopVCursor) InTransactionAndIsDML() bool {
panic("implement me")
}
Expand Down Expand Up @@ -199,6 +203,10 @@ type loggingVCursor struct {
resolvedTargetTabletType topodatapb.TabletType
}

func (f *loggingVCursor) SetFoundRows(u uint64) {
panic("implement me")
}

func (f *loggingVCursor) InTransactionAndIsDML() bool {
return false
}
Expand Down
1 change: 1 addition & 0 deletions go/vt/vtgate/engine/primitive.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ type (
SetSQLSelectLimit(int64)
SetTransactionMode(vtgatepb.TransactionMode)
SetWorkload(querypb.ExecuteOptions_Workload)
SetFoundRows(uint64)
}

// Plan represents the execution strategy for a given query.
Expand Down
110 changes: 110 additions & 0 deletions go/vt/vtgate/engine/sql_calc_found_rows.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
/*
Copyright 2020 The Vitess Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package engine

import (
"vitess.io/vitess/go/sqltypes"
querypb "vitess.io/vitess/go/vt/proto/query"
"vitess.io/vitess/go/vt/proto/vtrpc"
"vitess.io/vitess/go/vt/vterrors"
"vitess.io/vitess/go/vt/vtgate/evalengine"
)

var _ Primitive = (*SQLCalcFoundRows)(nil)

//SQLCalcFoundRows is a primitive to execute limit and count query as per their individual plan.
type SQLCalcFoundRows struct {
LimitPrimitive Primitive
CountPrimitive Primitive
}

//RouteType implements the Primitive interface
func (s SQLCalcFoundRows) RouteType() string {
return "SQLCalcFoundRows"
}

//GetKeyspaceName implements the Primitive interface
func (s SQLCalcFoundRows) GetKeyspaceName() string {
return s.LimitPrimitive.GetKeyspaceName()
}

//GetTableName implements the Primitive interface
func (s SQLCalcFoundRows) GetTableName() string {
return s.LimitPrimitive.GetTableName()
}

//Execute implements the Primitive interface
func (s SQLCalcFoundRows) Execute(vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool) (*sqltypes.Result, error) {
limitQr, err := s.LimitPrimitive.Execute(vcursor, bindVars, wantfields)
if err != nil {
return nil, err
}
countQr, err := s.CountPrimitive.Execute(vcursor, bindVars, false)
if err != nil {
return nil, err
}
if len(countQr.Rows) != 1 || len(countQr.Rows[0]) != 1 {
return nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "count query is not a scalar")
}
fr, err := evalengine.ToUint64(countQr.Rows[0][0])
if err != nil {
return nil, err
}
vcursor.Session().SetFoundRows(fr)
return limitQr, nil
}

//StreamExecute implements the Primitive interface
func (s SQLCalcFoundRows) StreamExecute(vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error {
err := s.LimitPrimitive.StreamExecute(vcursor, bindVars, wantfields, callback)
if err != nil {
return err
}

return s.CountPrimitive.StreamExecute(vcursor, bindVars, wantfields, func(countQr *sqltypes.Result) error {
if len(countQr.Rows) != 1 || len(countQr.Rows[0]) != 1 {
return vterrors.Errorf(vtrpc.Code_INTERNAL, "count query is not a scalar")
}
fr, err := evalengine.ToUint64(countQr.Rows[0][0])
if err != nil {
return err
}
vcursor.Session().SetFoundRows(fr)
return nil
})
}

systay marked this conversation as resolved.
Show resolved Hide resolved
//GetFields implements the Primitive interface
func (s SQLCalcFoundRows) GetFields(vcursor VCursor, bindVars map[string]*querypb.BindVariable) (*sqltypes.Result, error) {
return s.LimitPrimitive.GetFields(vcursor, bindVars)
}

//NeedsTransaction implements the Primitive interface
func (s SQLCalcFoundRows) NeedsTransaction() bool {
return s.LimitPrimitive.NeedsTransaction()
}

//Inputs implements the Primitive interface
func (s SQLCalcFoundRows) Inputs() []Primitive {
return []Primitive{s.LimitPrimitive, s.CountPrimitive}
}

func (s SQLCalcFoundRows) description() PrimitiveDescription {
return PrimitiveDescription{
OperatorType: "SQL_CALC_FOUND_ROWS",
}
}
4 changes: 3 additions & 1 deletion go/vt/vtgate/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,9 @@ func saveSessionStats(safeSession *SafeSession, stmtType sqlparser.StatementType
if err != nil {
return
}
safeSession.FoundRows = result.RowsAffected
if !safeSession.foundRowsHandled {
safeSession.FoundRows = result.RowsAffected
}
if result.InsertID > 0 {
safeSession.LastInsertId = result.InsertID
}
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/planbuilder/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,7 @@ func buildRoutePlan(stmt sqlparser.Statement, vschema ContextVSchema, f func(sta
func createInstructionFor(query string, stmt sqlparser.Statement, vschema ContextVSchema) (engine.Primitive, error) {
switch stmt := stmt.(type) {
case *sqlparser.Select:
return buildRoutePlan(stmt, vschema, buildSelectPlan)
return buildRoutePlan(stmt, vschema, buildSelectPlan(query))
case *sqlparser.Insert:
return buildRoutePlan(stmt, vschema, buildInsertPlan)
case *sqlparser.Update:
Expand Down
4 changes: 2 additions & 2 deletions go/vt/vtgate/planbuilder/expr.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ func (pb *primitiveBuilder) findOrigin(expr sqlparser.Expr) (pullouts []*pullout
spb := newPrimitiveBuilder(pb.vschema, pb.jt)
switch stmt := node.Select.(type) {
case *sqlparser.Select:
if err := spb.processSelect(stmt, pb.st); err != nil {
if err := spb.processSelect(stmt, pb.st, ""); err != nil {
return false, err
}
case *sqlparser.Union:
Expand Down Expand Up @@ -230,7 +230,7 @@ func (pb *primitiveBuilder) finalizeUnshardedDMLSubqueries(nodes ...sqlparser.SQ
return true, nil
}
spb := newPrimitiveBuilder(pb.vschema, pb.jt)
if err := spb.processSelect(nodeType, pb.st); err != nil {
if err := spb.processSelect(nodeType, pb.st, ""); err != nil {
samePlan = false
return false, err
}
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/planbuilder/from.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ func (pb *primitiveBuilder) processAliasedTable(tableExpr *sqlparser.AliasedTabl
spb := newPrimitiveBuilder(pb.vschema, pb.jt)
switch stmt := expr.Select.(type) {
case *sqlparser.Select:
if err := spb.processSelect(stmt, nil); err != nil {
if err := spb.processSelect(stmt, nil, ""); err != nil {
return err
}
case *sqlparser.Union:
Expand Down
Loading