diff --git a/planner/core/point_get_plan.go b/planner/core/point_get_plan.go index 35d5fdb68a030..1f2c7ee967c6e 100644 --- a/planner/core/point_get_plan.go +++ b/planner/core/point_get_plan.go @@ -38,6 +38,7 @@ import ( // This plan is much faster to build and to execute because it avoid the optimization and coprocessor cost. type PointGetPlan struct { basePlan + dbName string schema *expression.Schema TblInfo *model.TableInfo IndexInfo *model.IndexInfo @@ -223,7 +224,11 @@ func tryPointGetPlan(ctx sessionctx.Context, selStmt *ast.SelectStmt) *PointGetP if schema == nil { return nil } - p := newPointGetPlan(ctx, schema, tbl) + dbName := tblName.Schema.L + if dbName == "" { + dbName = ctx.GetSessionVars().CurrentDB + } + p := newPointGetPlan(ctx, dbName, schema, tbl) intDatum, err := handlePair.value.ConvertTo(ctx.GetSessionVars().StmtCtx, fieldType) if err != nil { if terror.ErrorEqual(types.ErrOverflow, err) { @@ -263,7 +268,11 @@ func tryPointGetPlan(ctx sessionctx.Context, selStmt *ast.SelectStmt) *PointGetP if schema == nil { return nil } - p := newPointGetPlan(ctx, schema, tbl) + dbName := tblName.Schema.L + if dbName == "" { + dbName = ctx.GetSessionVars().CurrentDB + } + p := newPointGetPlan(ctx, dbName, schema, tbl) p.IndexInfo = idxInfo p.IndexValues = idxValues p.IndexValueParams = idxValueParams @@ -272,9 +281,10 @@ func tryPointGetPlan(ctx sessionctx.Context, selStmt *ast.SelectStmt) *PointGetP return nil } -func newPointGetPlan(ctx sessionctx.Context, schema *expression.Schema, tbl *model.TableInfo) *PointGetPlan { +func newPointGetPlan(ctx sessionctx.Context, dbName string, schema *expression.Schema, tbl *model.TableInfo) *PointGetPlan { p := &PointGetPlan{ basePlan: newBasePlan(ctx, "Point_Get"), + dbName: dbName, schema: schema, TblInfo: tbl, } @@ -287,9 +297,8 @@ func checkFastPlanPrivilege(ctx sessionctx.Context, fastPlan *PointGetPlan, chec if pm == nil { return nil } - dbName := ctx.GetSessionVars().CurrentDB for _, checkType := range checkTypes { - if !pm.RequestVerification(ctx.GetSessionVars().ActiveRoles, dbName, fastPlan.TblInfo.Name.L, "", checkType) { + if !pm.RequestVerification(ctx.GetSessionVars().ActiveRoles, fastPlan.dbName, fastPlan.TblInfo.Name.L, "", checkType) { return errors.New("privilege check fail") } } diff --git a/privilege/privileges/privileges_test.go b/privilege/privileges/privileges_test.go index 14cad3c030bab..6ada150296ddb 100644 --- a/privilege/privileges/privileges_test.go +++ b/privilege/privileges/privileges_test.go @@ -129,6 +129,24 @@ func (s *testPrivilegeSuite) TestCheckDBPrivilege(c *C) { c.Assert(pc.RequestVerification(activeRoles, "test", "", "", mysql.UpdatePriv), IsTrue) } +func (s *testPrivilegeSuite) TestCheckPointGetDBPrivilege(c *C) { + rootSe := newSession(c, s.store, s.dbName) + mustExec(c, rootSe, `CREATE USER 'tester'@'localhost';`) + mustExec(c, rootSe, `GRANT SELECT,UPDATE ON test.* TO 'tester'@'localhost';`) + mustExec(c, rootSe, `flush privileges;`) + mustExec(c, rootSe, `create database test2`) + mustExec(c, rootSe, `create table test2.t(id int, v int, primary key(id))`) + mustExec(c, rootSe, `insert into test2.t(id, v) values(1, 1)`) + + se := newSession(c, s.store, s.dbName) + c.Assert(se.Auth(&auth.UserIdentity{Username: "tester", Hostname: "localhost"}, nil, nil), IsTrue) + mustExec(c, se, `use test;`) + _, err := se.Execute(context.Background(), `select * from test2.t where id = 1`) + c.Assert(terror.ErrorEqual(err, core.ErrTableaccessDenied), IsTrue) + _, err = se.Execute(context.Background(), "update test2.t set v = 2 where id = 1") + c.Assert(terror.ErrorEqual(err, core.ErrTableaccessDenied), IsTrue) +} + func (s *testPrivilegeSuite) TestCheckTablePrivilege(c *C) { rootSe := newSession(c, s.store, s.dbName) mustExec(c, rootSe, `CREATE USER 'test1'@'localhost';`)