From 77781b60c29aa1f39ad907a5add39e28d2547bd9 Mon Sep 17 00:00:00 2001 From: Drew Kimball Date: Mon, 7 Oct 2024 20:06:54 -0600 Subject: [PATCH] sql: use memo metadata to add routines to statement bundles This commit updates the statement bundle logic to take advantage of the information stored in the query plan metadata so that only relevant routines are shown in `schema.sql` for a statement bundle. In addition, stored procedures are now tracked in the metadata in addition to UDFs. This has no impact on query-plan caching, since we currently do not cache plans that invoke stored procedures. Fixes #132142 Fixes #104976 Release note (bug fix): Fixed a bug that prevented the create statement for a routine from being shown in a statement bundle. This happened when the routine was created on a schema other than `public`. The bug has existed since v23.1. --- pkg/sql/explain_bundle.go | 90 ++++++++++++---------------------- pkg/sql/explain_bundle_test.go | 62 +++++++++++++++++++++++ pkg/sql/opt/metadata.go | 11 ++++- 3 files changed, 103 insertions(+), 60 deletions(-) diff --git a/pkg/sql/explain_bundle.go b/pkg/sql/explain_bundle.go index d006a79d5349..4077c260b73c 100644 --- a/pkg/sql/explain_bundle.go +++ b/pkg/sql/explain_bundle.go @@ -25,10 +25,12 @@ import ( "github.com/cockroachdb/cockroach/pkg/sql/opt/exec/explain" "github.com/cockroachdb/cockroach/pkg/sql/opt/memo" "github.com/cockroachdb/cockroach/pkg/sql/sem/catconstants" + "github.com/cockroachdb/cockroach/pkg/sql/sem/catid" "github.com/cockroachdb/cockroach/pkg/sql/sem/tree" "github.com/cockroachdb/cockroach/pkg/sql/sessiondata" "github.com/cockroachdb/cockroach/pkg/sql/stmtdiagnostics" "github.com/cockroachdb/cockroach/pkg/util/buildutil" + "github.com/cockroachdb/cockroach/pkg/util/intsets" "github.com/cockroachdb/cockroach/pkg/util/log" "github.com/cockroachdb/cockroach/pkg/util/memzipper" "github.com/cockroachdb/cockroach/pkg/util/pretty" @@ -36,6 +38,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/util/tracing/tracingpb" "github.com/cockroachdb/errors" "github.com/cockroachdb/redact" + "github.com/lib/pq/oid" ) const noPlan = "no plan" @@ -600,12 +603,11 @@ func (b *stmtBundleBuilder) addEnv(ctx context.Context) { // Note: we do not shortcut out of this function if there is no table/sequence/view to report: // the bundle analysis tool require schema.sql to always be present, even if it's empty. - first := true blankLine := func() { - if !first { + if buf.Len() > 0 { + // Don't add newlines to the beginning of the file. buf.WriteByte('\n') } - first = false } blankLine() c.printCreateAllDatabases(&buf, dbNames) @@ -625,28 +627,20 @@ func (b *stmtBundleBuilder) addEnv(ctx context.Context) { } if mem.Metadata().HasUserDefinedRoutines() { // Get all relevant user-defined routines. - blankLine() - err = c.PrintRelevantCreateRoutine( - &buf, strings.ToLower(b.stmt), b.flags.RedactValues, &b.errorStrings, false, /* procedure */ - ) - if err != nil { - b.printError(fmt.Sprintf("-- error getting schema for udfs: %v", err), &buf) - } - } - if call, ok := mem.RootExpr().(*memo.CallExpr); ok { - // Currently, a stored procedure can only be called from a CALL statement, - // which can only be the root expression. - if proc, ok := call.Proc.(*memo.UDFCallExpr); ok { + var ids intsets.Fast + isProcedure := make(map[oid.Oid]bool) + mem.Metadata().ForEachUserDefinedRoutine(func(ol *tree.Overload) { + ids.Add(int(ol.Oid)) + isProcedure[ol.Oid] = ol.Type == tree.ProcedureRoutine + }) + ids.ForEach(func(id int) { blankLine() - err = c.PrintRelevantCreateRoutine( - &buf, strings.ToLower(proc.Def.Name), b.flags.RedactValues, &b.errorStrings, true, /* procedure */ - ) + routineOid := oid.Oid(id) + err = c.PrintCreateRoutine(&buf, routineOid, b.flags.RedactValues, isProcedure[routineOid]) if err != nil { - b.printError(fmt.Sprintf("-- error getting schema for procedure: %v", err), &buf) + b.printError(fmt.Sprintf("-- error getting schema for routine with ID %d: %v", id, err), &buf) } - } else { - b.printError("-- unexpected input expression for CALL statement", &buf) - } + }) } for i := range tables { blankLine() @@ -1030,50 +1024,28 @@ func (c *stmtEnvCollector) PrintCreateEnum(w io.Writer, redactValues bool) error return nil } -func (c *stmtEnvCollector) PrintRelevantCreateRoutine( - w io.Writer, stmt string, redactValues bool, errorStrings *[]string, procedure bool, +func (c *stmtEnvCollector) PrintCreateRoutine( + w io.Writer, id oid.Oid, redactValues bool, procedure bool, ) error { - // The select function_name returns a DOidWrapper, - // we need to cast it to string for queryRows function to process. - // TODO(#104976): consider getting the udf sql body statements from the memo metadata. - var routineTypeName, routineNameQuery string + var createRoutineQuery string + descID := catid.UserDefinedOIDToID(id) + queryTemplate := "SELECT create_statement FROM crdb_internal.create_%[1]s_statements WHERE %[1]s_id = %[2]d" if procedure { - routineTypeName = "PROCEDURE" - routineNameQuery = "SELECT procedure_name::STRING as procedure_name_str FROM [SHOW PROCEDURES]" + createRoutineQuery = fmt.Sprintf(queryTemplate, "procedure", descID) } else { - routineTypeName = "FUNCTION" - routineNameQuery = "SELECT function_name::STRING as function_name_str FROM [SHOW FUNCTIONS]" + createRoutineQuery = fmt.Sprintf(queryTemplate, "function", descID) + } + if redactValues { + createRoutineQuery = fmt.Sprintf( + "SELECT crdb_internal.redact(crdb_internal.redactable_sql_constants(create_statement)) FROM (%s)", + createRoutineQuery, + ) } - routineNames, err := c.queryRows(routineNameQuery) + createStatement, err := c.query(createRoutineQuery) if err != nil { return err } - for _, name := range routineNames { - if strings.Contains(stmt, name) { - createRoutineQuery := fmt.Sprintf( - "SELECT create_statement FROM [ SHOW CREATE %s \"%s\" ]", routineTypeName, name, - ) - if redactValues { - createRoutineQuery = fmt.Sprintf( - "SELECT crdb_internal.redact(crdb_internal.redactable_sql_constants(create_statement)) FROM [ SHOW CREATE %s \"%s\" ]", - routineTypeName, name, - ) - } - createStatement, err := c.query(createRoutineQuery) - if err != nil { - var errString string - if procedure { - errString = fmt.Sprintf("-- error getting stored procedure %s: %s", name, err) - } else { - errString = fmt.Sprintf("-- error getting user defined function %s: %s", name, err) - } - fmt.Fprint(w, errString+"\n") - *errorStrings = append(*errorStrings, errString) - continue - } - fmt.Fprintf(w, "%s\n", createStatement) - } - } + fmt.Fprintf(w, "%s;\n", createStatement) return nil } diff --git a/pkg/sql/explain_bundle_test.go b/pkg/sql/explain_bundle_test.go index 9b6a76bd659e..18056a0ae7ba 100644 --- a/pkg/sql/explain_bundle_test.go +++ b/pkg/sql/explain_bundle_test.go @@ -420,6 +420,68 @@ CREATE TABLE users(id UUID DEFAULT gen_random_uuid() PRIMARY KEY, promo_id INT R "distsql.html vec-v.txt vec.txt") }) + t.Run("different schema UDF", func(t *testing.T) { + r.Exec(t, "CREATE FUNCTION foo() RETURNS INT LANGUAGE SQL AS 'SELECT count(*) FROM abc, s.a';") + r.Exec(t, "CREATE FUNCTION s.foo() RETURNS INT LANGUAGE SQL AS 'SELECT count(*) FROM abc, s.a';") + rows := r.QueryStr(t, "EXPLAIN ANALYZE (DEBUG) SELECT s.foo();") + checkBundle( + t, fmt.Sprint(rows), "s.foo", func(name, contents string) error { + if name == "schema.sql" { + reg := regexp.MustCompile("s.foo") + if reg.FindString(contents) == "" { + return errors.Errorf("could not find definition for 's.foo' function in schema.sql") + } + reg = regexp.MustCompile("^foo") + if reg.FindString(contents) != "" { + return errors.Errorf("found irrelevant function 'foo' in schema.sql") + } + reg = regexp.MustCompile("s.a") + if reg.FindString(contents) == "" { + return errors.Errorf("could not find definition for relation 's.a' in schema.sql") + } + reg = regexp.MustCompile("abc") + if reg.FindString(contents) == "" { + return errors.Errorf("could not find definition for relation 'abc' in schema.sql") + } + } + return nil + }, + false /* expectErrors */, base, plans, + "stats-defaultdb.public.abc.sql stats-defaultdb.s.a.sql distsql.html vec-v.txt vec.txt", + ) + }) + + t.Run("different schema procedure", func(t *testing.T) { + r.Exec(t, "CREATE PROCEDURE bar() LANGUAGE SQL AS 'SELECT count(*) FROM abc, s.a';") + r.Exec(t, "CREATE PROCEDURE s.bar() LANGUAGE SQL AS 'SELECT count(*) FROM abc, s.a';") + rows := r.QueryStr(t, "EXPLAIN ANALYZE (DEBUG) CALL s.bar();") + checkBundle( + t, fmt.Sprint(rows), "s.bar", func(name, contents string) error { + if name == "schema.sql" { + reg := regexp.MustCompile("s.bar") + if reg.FindString(contents) == "" { + return errors.Errorf("could not find definition for 's.bar' procedure in schema.sql") + } + reg = regexp.MustCompile("^bar") + if reg.FindString(contents) != "" { + return errors.Errorf("Found irrelevant procedure 'bar' in schema.sql") + } + reg = regexp.MustCompile("s.a") + if reg.FindString(contents) == "" { + return errors.Errorf("could not find definition for relation 's.a' in schema.sql") + } + reg = regexp.MustCompile("abc") + if reg.FindString(contents) == "" { + return errors.Errorf("could not find definition for relation 'abc' in schema.sql") + } + } + return nil + }, + false /* expectErrors */, base, plans, + "stats-defaultdb.public.abc.sql stats-defaultdb.s.a.sql distsql.html vec-v.txt vec.txt", + ) + }) + t.Run("permission error", func(t *testing.T) { r.Exec(t, "CREATE USER test") r.Exec(t, "SET ROLE test") diff --git a/pkg/sql/opt/metadata.go b/pkg/sql/opt/metadata.go index 5172915f762a..aec784f7e881 100644 --- a/pkg/sql/opt/metadata.go +++ b/pkg/sql/opt/metadata.go @@ -604,7 +604,7 @@ func (md *Metadata) HasUserDefinedRoutines() bool { func (md *Metadata) AddUserDefinedRoutine( overload *tree.Overload, invocationTypes []*types.T, name *tree.UnresolvedObjectName, ) { - if overload.Type != tree.UDFRoutine { + if overload.Type == tree.BuiltinRoutine { return } id := cat.StableID(catid.UserDefinedOIDToID(overload.Oid)) @@ -617,6 +617,15 @@ func (md *Metadata) AddUserDefinedRoutine( } } +// ForEachUserDefinedRoutine executes the given function for each user-defined +// routine (UDF or stored procedure) overload. The order of iteration is +// non-deterministic. +func (md *Metadata) ForEachUserDefinedRoutine(fn func(overload *tree.Overload)) { + for _, dep := range md.routineDeps { + fn(dep.overload) + } +} + // AddBuiltin adds a name used to resolve a builtin function to the metadata for // this query. This is necessary to handle the case when changes to the search // path cause a function call to resolve as a UDF instead of a builtin function.