Skip to content

Commit

Permalink
extension: disable some optimizations for extension function
Browse files Browse the repository at this point in the history
  • Loading branch information
lcwangchao committed Mar 20, 2024
1 parent 161a223 commit 31ff7ce
Show file tree
Hide file tree
Showing 5 changed files with 136 additions and 4 deletions.
4 changes: 4 additions & 0 deletions pkg/expression/constant_fold.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,10 @@ func foldConstant(ctx BuildContext, expr Expression) (Expression, bool) {
if _, ok := unFoldableFunctions[x.FuncName.L]; ok {
return expr, false
}
if _, ok := x.Function.(*extensionFuncSig); ok {
// we should not fold the extension function, because it may have a side effect.
return expr, false
}
if function := specialFoldHandler[x.FuncName.L]; function != nil && !MaybeOverOptimized4PlanCache(ctx, []Expression{expr}) {
return function(ctx, x)
}
Expand Down
19 changes: 16 additions & 3 deletions pkg/expression/extension.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ func newExtensionFuncClass(def *extension.FunctionDef) (*extensionFuncClass, err
}

func (c *extensionFuncClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) {
if err := c.checkPrivileges(ctx); err != nil {
if err := checkPrivileges(ctx, &c.funcDef); err != nil {
return nil, err
}

Expand All @@ -107,13 +107,18 @@ func (c *extensionFuncClass) getFunction(ctx BuildContext, args []Expression) (b
if err != nil {
return nil, err
}

// Though currently, `getFunction` does not require too much information that makes it safe to be cached,
// we still skip the plan cache for extension functions because there are no strong requirements to do it.
// Skipping the plan cache can make the behavior simple.
ctx.GetSessionVars().StmtCtx.SetSkipPlanCache(errors.NewNoStackError("extension function should not be cached"))
bf.tp.SetFlen(c.flen)
sig := &extensionFuncSig{bf, c.funcDef}
return sig, nil
}

func (c *extensionFuncClass) checkPrivileges(ctx BuildContext) error {
fn := c.funcDef.RequireDynamicPrivileges
func checkPrivileges(ctx EvalContext, fnDef *extension.FunctionDef) error {
fn := fnDef.RequireDynamicPrivileges
if fn == nil {
return nil
}
Expand Down Expand Up @@ -155,6 +160,10 @@ func (b *extensionFuncSig) Clone() builtinFunc {
}

func (b *extensionFuncSig) evalString(ctx EvalContext, row chunk.Row) (string, bool, error) {
if err := checkPrivileges(ctx, &b.FunctionDef); err != nil {
return "", true, err
}

if b.EvalTp == types.ETString {
fnCtx := newExtensionFnContext(ctx, b)
return b.EvalStringFunc(fnCtx, row)
Expand All @@ -163,6 +172,10 @@ func (b *extensionFuncSig) evalString(ctx EvalContext, row chunk.Row) (string, b
}

func (b *extensionFuncSig) evalInt(ctx EvalContext, row chunk.Row) (int64, bool, error) {
if err := checkPrivileges(ctx, &b.FunctionDef); err != nil {
return 0, true, err
}

if b.EvalTp == types.ETInt {
fnCtx := newExtensionFnContext(ctx, b)
return b.EvalIntFunc(fnCtx, row)
Expand Down
5 changes: 5 additions & 0 deletions pkg/expression/scalar_function.go
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,11 @@ func (sf *ScalarFunction) ConstLevel() ConstLevel {
return ConstNone
}

if _, ok := sf.Function.(*extensionFuncSig); ok {
// we should return `ConstNone` for extension functions for safety, because it may have a side effect.
return ConstNone
}

level := ConstStrict
for _, arg := range sf.GetArgs() {
argLevel := arg.ConstLevel()
Expand Down
4 changes: 3 additions & 1 deletion pkg/extension/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,13 @@ go_test(
],
embed = [":extension"],
flaky = True,
shard_count = 14,
shard_count = 15,
deps = [
"//pkg/expression",
"//pkg/parser/ast",
"//pkg/parser/auth",
"//pkg/parser/mysql",
"//pkg/planner/util/fixcontrol",
"//pkg/privilege/privileges",
"//pkg/server",
"//pkg/sessionctx",
Expand All @@ -55,6 +56,7 @@ go_test(
"//pkg/testkit/testsetup",
"//pkg/types",
"//pkg/util/chunk",
"//pkg/util/mock",
"//pkg/util/sem",
"@com_github_pingcap_errors//:errors",
"@com_github_stretchr_testify//require",
Expand Down
108 changes: 108 additions & 0 deletions pkg/extension/function_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,19 @@ import (
"fmt"
"sort"
"strings"
"sync/atomic"
"testing"

"github.com/pingcap/errors"
"github.com/pingcap/tidb/pkg/expression"
"github.com/pingcap/tidb/pkg/extension"
"github.com/pingcap/tidb/pkg/parser/auth"
"github.com/pingcap/tidb/pkg/planner/util/fixcontrol"
"github.com/pingcap/tidb/pkg/sessionctx/variable"
"github.com/pingcap/tidb/pkg/testkit"
"github.com/pingcap/tidb/pkg/types"
"github.com/pingcap/tidb/pkg/util/chunk"
"github.com/pingcap/tidb/pkg/util/mock"
"github.com/pingcap/tidb/pkg/util/sem"
"github.com/stretchr/testify/require"
)
Expand Down Expand Up @@ -318,6 +321,19 @@ func TestExtensionFuncPrivilege(t *testing.T) {
return "ghi", false, nil
},
},
{
Name: "custom_eval_int_func",
EvalTp: types.ETInt,
RequireDynamicPrivileges: func(sem bool) []string {
if sem {
return []string{"RESTRICTED_CUSTOM_DYN_PRIV_2"}
}
return []string{"CUSTOM_DYN_PRIV_1"}
},
EvalIntFunc: func(ctx extension.FunctionContext, row chunk.Row) (int64, bool, error) {
return 1, false, nil
},
},
}),
extension.WithCustomDynPrivs([]string{
"CUSTOM_DYN_PRIV_1",
Expand Down Expand Up @@ -349,34 +365,43 @@ func TestExtensionFuncPrivilege(t *testing.T) {
tk1.MustQuery("select custom_only_dyn_priv_func()").Check(testkit.Rows("abc"))
tk1.MustQuery("select custom_only_sem_dyn_priv_func()").Check(testkit.Rows("def"))
tk1.MustQuery("select custom_both_dyn_priv_func()").Check(testkit.Rows("ghi"))
tk1.MustQuery("select custom_eval_int_func()").Check(testkit.Rows("1"))

// u1 in non-sem
require.NoError(t, tk1.Session().Auth(&auth.UserIdentity{Username: "u1", Hostname: "localhost"}, nil, nil, nil))
tk1.MustQuery("select custom_no_priv_func()").Check(testkit.Rows("zzz"))
require.EqualError(t, tk1.ExecToErr("select custom_only_dyn_priv_func()"), "[expression:1227]Access denied; you need (at least one of) the SUPER or CUSTOM_DYN_PRIV_1 privilege(s) for this operation")
tk1.MustQuery("select custom_only_sem_dyn_priv_func()").Check(testkit.Rows("def"))
require.EqualError(t, tk1.ExecToErr("select custom_both_dyn_priv_func()"), "[expression:1227]Access denied; you need (at least one of) the SUPER or CUSTOM_DYN_PRIV_1 privilege(s) for this operation")
require.EqualError(t, tk1.ExecToErr("select custom_eval_int_func()"), "[expression:1227]Access denied; you need (at least one of) the SUPER or CUSTOM_DYN_PRIV_1 privilege(s) for this operation")

// prepare should check privilege
require.EqualError(t, tk1.ExecToErr("prepare stmt1 from 'select custom_only_dyn_priv_func()'"), "[expression:1227]Access denied; you need (at least one of) the SUPER or CUSTOM_DYN_PRIV_1 privilege(s) for this operation")
require.EqualError(t, tk1.ExecToErr("prepare stmt2 from 'select custom_eval_int_func()'"), "[expression:1227]Access denied; you need (at least one of) the SUPER or CUSTOM_DYN_PRIV_1 privilege(s) for this operation")

// u2 in non-sem
require.NoError(t, tk1.Session().Auth(&auth.UserIdentity{Username: "u2", Hostname: "localhost"}, nil, nil, nil))
tk1.MustQuery("select custom_no_priv_func()").Check(testkit.Rows("zzz"))
tk1.MustQuery("select custom_only_dyn_priv_func()").Check(testkit.Rows("abc"))
tk1.MustQuery("select custom_only_sem_dyn_priv_func()").Check(testkit.Rows("def"))
tk1.MustQuery("select custom_both_dyn_priv_func()").Check(testkit.Rows("ghi"))
tk1.MustQuery("select custom_eval_int_func()").Check(testkit.Rows("1"))

// u3 in non-sem
require.NoError(t, tk1.Session().Auth(&auth.UserIdentity{Username: "u3", Hostname: "localhost"}, nil, nil, nil))
tk1.MustQuery("select custom_no_priv_func()").Check(testkit.Rows("zzz"))
require.EqualError(t, tk1.ExecToErr("select custom_only_dyn_priv_func()"), "[expression:1227]Access denied; you need (at least one of) the SUPER or CUSTOM_DYN_PRIV_1 privilege(s) for this operation")
tk1.MustQuery("select custom_only_sem_dyn_priv_func()").Check(testkit.Rows("def"))
require.EqualError(t, tk1.ExecToErr("select custom_both_dyn_priv_func()"), "[expression:1227]Access denied; you need (at least one of) the SUPER or CUSTOM_DYN_PRIV_1 privilege(s) for this operation")
require.EqualError(t, tk1.ExecToErr("select custom_eval_int_func()"), "[expression:1227]Access denied; you need (at least one of) the SUPER or CUSTOM_DYN_PRIV_1 privilege(s) for this operation")

// u4 in non-sem
require.NoError(t, tk1.Session().Auth(&auth.UserIdentity{Username: "u4", Hostname: "localhost"}, nil, nil, nil))
tk1.MustQuery("select custom_no_priv_func()").Check(testkit.Rows("zzz"))
tk1.MustQuery("select custom_only_dyn_priv_func()").Check(testkit.Rows("abc"))
tk1.MustQuery("select custom_only_sem_dyn_priv_func()").Check(testkit.Rows("def"))
tk1.MustQuery("select custom_both_dyn_priv_func()").Check(testkit.Rows("ghi"))
tk1.MustQuery("select custom_eval_int_func()").Check(testkit.Rows("1"))

sem.Enable()

Expand All @@ -386,32 +411,115 @@ func TestExtensionFuncPrivilege(t *testing.T) {
tk1.MustQuery("select custom_only_dyn_priv_func()").Check(testkit.Rows("abc"))
require.EqualError(t, tk1.ExecToErr("select custom_only_sem_dyn_priv_func()"), "[expression:1227]Access denied; you need (at least one of) the RESTRICTED_CUSTOM_DYN_PRIV_2 privilege(s) for this operation")
require.EqualError(t, tk1.ExecToErr("select custom_both_dyn_priv_func()"), "[expression:1227]Access denied; you need (at least one of) the RESTRICTED_CUSTOM_DYN_PRIV_2 privilege(s) for this operation")
require.EqualError(t, tk1.ExecToErr("select custom_eval_int_func()"), "[expression:1227]Access denied; you need (at least one of) the RESTRICTED_CUSTOM_DYN_PRIV_2 privilege(s) for this operation")

// u1 in sem
require.NoError(t, tk1.Session().Auth(&auth.UserIdentity{Username: "u1", Hostname: "localhost"}, nil, nil, nil))
tk1.MustQuery("select custom_no_priv_func()").Check(testkit.Rows("zzz"))
require.EqualError(t, tk1.ExecToErr("select custom_only_dyn_priv_func()"), "[expression:1227]Access denied; you need (at least one of) the CUSTOM_DYN_PRIV_1 privilege(s) for this operation")
require.EqualError(t, tk1.ExecToErr("select custom_only_sem_dyn_priv_func()"), "[expression:1227]Access denied; you need (at least one of) the RESTRICTED_CUSTOM_DYN_PRIV_2 privilege(s) for this operation")
require.EqualError(t, tk1.ExecToErr("select custom_both_dyn_priv_func()"), "[expression:1227]Access denied; you need (at least one of) the RESTRICTED_CUSTOM_DYN_PRIV_2 privilege(s) for this operation")
require.EqualError(t, tk1.ExecToErr("select custom_eval_int_func()"), "[expression:1227]Access denied; you need (at least one of) the RESTRICTED_CUSTOM_DYN_PRIV_2 privilege(s) for this operation")

// u2 in sem
require.NoError(t, tk1.Session().Auth(&auth.UserIdentity{Username: "u2", Hostname: "localhost"}, nil, nil, nil))
tk1.MustQuery("select custom_no_priv_func()").Check(testkit.Rows("zzz"))
tk1.MustQuery("select custom_only_dyn_priv_func()").Check(testkit.Rows("abc"))
require.EqualError(t, tk1.ExecToErr("select custom_only_sem_dyn_priv_func()"), "[expression:1227]Access denied; you need (at least one of) the RESTRICTED_CUSTOM_DYN_PRIV_2 privilege(s) for this operation")
require.EqualError(t, tk1.ExecToErr("select custom_both_dyn_priv_func()"), "[expression:1227]Access denied; you need (at least one of) the RESTRICTED_CUSTOM_DYN_PRIV_2 privilege(s) for this operation")
require.EqualError(t, tk1.ExecToErr("select custom_eval_int_func()"), "[expression:1227]Access denied; you need (at least one of) the RESTRICTED_CUSTOM_DYN_PRIV_2 privilege(s) for this operation")

// u3 in sem
require.NoError(t, tk1.Session().Auth(&auth.UserIdentity{Username: "u3", Hostname: "localhost"}, nil, nil, nil))
tk1.MustQuery("select custom_no_priv_func()").Check(testkit.Rows("zzz"))
require.EqualError(t, tk1.ExecToErr("select custom_only_dyn_priv_func()"), "[expression:1227]Access denied; you need (at least one of) the CUSTOM_DYN_PRIV_1 privilege(s) for this operation")
tk1.MustQuery("select custom_only_sem_dyn_priv_func()").Check(testkit.Rows("def"))
tk1.MustQuery("select custom_both_dyn_priv_func()").Check(testkit.Rows("ghi"))
tk1.MustQuery("select custom_eval_int_func()").Check(testkit.Rows("1"))

// u4 in sem
require.NoError(t, tk1.Session().Auth(&auth.UserIdentity{Username: "u4", Hostname: "localhost"}, nil, nil, nil))
tk1.MustQuery("select custom_no_priv_func()").Check(testkit.Rows("zzz"))
tk1.MustQuery("select custom_only_dyn_priv_func()").Check(testkit.Rows("abc"))
tk1.MustQuery("select custom_only_sem_dyn_priv_func()").Check(testkit.Rows("def"))
tk1.MustQuery("select custom_both_dyn_priv_func()").Check(testkit.Rows("ghi"))
tk1.MustQuery("select custom_eval_int_func()").Check(testkit.Rows("1"))

// Test the privilege should also be checked when evaluating especially for when privilege is revoked.
// We enable `fixcontrol.Fix49736` to force enable plan cache to make sure `Expression.EvalXXX` will be invoked.
tk1.Session().GetSessionVars().OptimizerFixControl[fixcontrol.Fix49736] = "ON"
tk1.MustExec("prepare s1 from 'select custom_both_dyn_priv_func()'")
tk1.MustExec("prepare s2 from 'select custom_eval_int_func()'")
tk1.MustQuery("execute s1").Check(testkit.Rows("ghi"))
tk1.MustQuery("execute s2").Check(testkit.Rows("1"))
tk.MustExec("REVOKE RESTRICTED_CUSTOM_DYN_PRIV_2 on *.* FROM u4@localhost")
require.EqualError(t, tk1.QueryToErr("execute s1"), "[expression:1227]Access denied; you need (at least one of) the RESTRICTED_CUSTOM_DYN_PRIV_2 privilege(s) for this operation")
require.EqualError(t, tk1.QueryToErr("execute s2"), "[expression:1227]Access denied; you need (at least one of) the RESTRICTED_CUSTOM_DYN_PRIV_2 privilege(s) for this operation")
delete(tk1.Session().GetSessionVars().OptimizerFixControl, fixcontrol.Fix49736)
}

func TestShouldNotOptimizeExtensionFunc(t *testing.T) {
defer func() {
extension.Reset()
sem.Disable()
}()

extension.Reset()
var cnt atomic.Int64
require.NoError(t, extension.Register("test",
extension.WithCustomFunctions([]*extension.FunctionDef{
{
Name: "my_func1",
EvalTp: types.ETInt,
EvalIntFunc: func(ctx extension.FunctionContext, row chunk.Row) (int64, bool, error) {
val := cnt.Add(1)
return val, false, nil
},
},
{
Name: "my_func2",
EvalTp: types.ETString,
EvalStringFunc: func(ctx extension.FunctionContext, row chunk.Row) (string, bool, error) {
val := cnt.Add(1)
if val%2 == 0 {
return "abc", false, nil
}
return "def", false, nil
},
},
}),
))

store := testkit.CreateMockStore(t)
tk := testkit.NewTestKit(t, store)
tk.MustExec("use test")
tk.MustExec("create table t1(a int primary key)")
tk.MustExec("insert into t1 values(1000), (2000)")

// Test extension function should not fold.
// if my_func1 is folded, the result will be "1000 1", "2000 1",
// because after fold the function will be called only once.
tk.MustQuery("select a, my_func1() from t1 order by a").Check(testkit.Rows("1000 1", "2000 2"))
require.Equal(t, int64(2), cnt.Load())

// Test extension function should not be seen as a constant, i.e., its `ConstantLevel()` should return `ConstNone`.
// my_func2 should be called twice to return different regexp string for the below query.
// If it is optimized by mistake, a wrong result "1000 0", "2000 0" will be produced.
cnt.Store(0)
tk.MustQuery("select a, 'abc' regexp my_func2() from t1 order by a").Check(testkit.Rows("1000 0", "2000 1"))

// Test flags after building expression
for _, exprStr := range []string{
"my_func1()",
"my_func2()",
} {
ctx := mock.NewContext()
ctx.GetSessionVars().StmtCtx.UseCache = true
expr, err := expression.ParseSimpleExpr(ctx, exprStr)
require.NoError(t, err)
scalar, ok := expr.(*expression.ScalarFunction)
require.True(t, ok)
require.Equal(t, expression.ConstNone, scalar.ConstLevel())
require.False(t, ctx.GetSessionVars().StmtCtx.UseCache)
}
}

0 comments on commit 31ff7ce

Please sign in to comment.