Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

executor: add privilege check for prepare stmt #36933

Merged
merged 9 commits into from
Aug 10, 2022
5 changes: 5 additions & 0 deletions executor/prepared.go
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,11 @@ func (e *PrepareExec) Next(ctx context.Context, req *chunk.Chunk) error {
NormalizedSQL4PC: normalizedSQL4PC,
SQLDigest4PC: digest4PC,
}

if err = plannercore.CheckPreparedPriv(e.ctx, preparedObj, ret.InfoSchema); err != nil {
return err
}

return vars.AddPreparedStmt(e.ID, preparedObj)
}

Expand Down
10 changes: 5 additions & 5 deletions planner/core/plan_cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ func GetPlanFromSessionPlanCache(ctx context.Context, sctx sessionctx.Context, i
}

if stmtAst.UseCache && !ignorePlanCache { // for general plans
if plan, names, ok, err := getGeneralPlan(ctx, sctx, cacheKey, bindSQL, is, stmt,
if plan, names, ok, err := getGeneralPlan(sctx, cacheKey, bindSQL, is, stmt,
paramTypes); err != nil || ok {
return plan, names, err
}
Expand Down Expand Up @@ -136,14 +136,14 @@ func getPointQueryPlan(stmt *ast.Prepared, sessVars *variable.SessionVars, stmtC
return plan, names, true, nil
}

func getGeneralPlan(ctx context.Context, sctx sessionctx.Context, cacheKey kvcache.Key, bindSQL string,
func getGeneralPlan(sctx sessionctx.Context, cacheKey kvcache.Key, bindSQL string,
is infoschema.InfoSchema, stmt *PlanCacheStmt, paramTypes []*types.FieldType) (Plan,
[]*types.FieldName, bool, error) {
sessVars := sctx.GetSessionVars()
stmtCtx := sessVars.StmtCtx

if cacheValue, exists := sctx.PreparedPlanCache().Get(cacheKey); exists {
if err := checkPreparedPriv(ctx, sctx, stmt, is); err != nil {
if err := CheckPreparedPriv(sctx, stmt, is); err != nil {
return nil, nil, false, err
}
cachedVals := cacheValue.([]*PlanCacheValue)
Expand Down Expand Up @@ -537,8 +537,8 @@ func buildRangeForIndexScan(sctx sessionctx.Context, is *PhysicalIndexScan) (err
return
}

func checkPreparedPriv(_ context.Context, sctx sessionctx.Context,
stmt *PlanCacheStmt, is infoschema.InfoSchema) error {
// CheckPreparedPriv checks the privilege of the prepared statement
func CheckPreparedPriv(sctx sessionctx.Context, stmt *PlanCacheStmt, is infoschema.InfoSchema) error {
if pm := privilege.GetPrivilegeManager(sctx); pm != nil {
visitInfo := VisitInfo4PrivCheck(is, stmt.PreparedAst.Stmt, stmt.VisitInfos)
if err := CheckPrivilege(sctx.GetSessionVars().ActiveRoles, pm, visitInfo); err != nil {
Expand Down
40 changes: 40 additions & 0 deletions privilege/privileges/privileges_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2895,3 +2895,43 @@ func TestIssue29823(t *testing.T) {
err = tk2.QueryToErr("show tables from test")
require.EqualError(t, err, "[executor:1044]Access denied for user 'u1'@'%' to database 'test'")
}

func TestCheckPreparePrivileges(t *testing.T) {
store := createStoreAndPrepareDB(t)
tk := testkit.NewTestKit(t, store)
tk.MustExec("use test")
tk.MustExec("drop table if exists t")
tk.MustExec("create user u1")
tk.MustExec("create table t (a int)")
tk.MustExec("insert into t values(1)")

tk2 := testkit.NewTestKit(t, store)
require.True(t, tk2.Session().Auth(&auth.UserIdentity{Username: "u1", Hostname: "%"}, nil, nil))

// sql
err := tk2.ExecToErr("prepare s from 'select * from test.t'")
require.EqualError(t, err, "[planner:1142]SELECT command denied to user 'u1'@'%' for table 't'")
err = tk2.ExecToErr("execute s")
require.EqualError(t, err, "[planner:8111]Prepared statement not found")

// binary proto
stmtID, _, _, err := tk2.Session().PrepareStmt("select * from test.t")
require.EqualError(t, err, "[planner:1142]SELECT command denied to user 'u1'@'%' for table 't'")
require.Zero(t, stmtID)

// grant
tk.MustExec("grant SELECT ON test.t TO 'u1'@'%';")

// should success after grant
tk2.MustExec("prepare s from 'select * from test.t'")
tk2.MustQuery("execute s").Check(testkit.Rows("1"))
stmtID, _, _, err = tk2.Session().PrepareStmt("select * from test.t")
require.NoError(t, err)
require.NotZero(t, stmtID)
rs, err := tk2.Session().ExecutePreparedStmt(context.TODO(), stmtID, nil)
require.NoError(t, err)
defer func() {
require.NoError(t, rs.Close())
}()
tk2.ResultSetToResult(rs, "").Check(testkit.Rows("1"))
}