diff --git a/pkg/sql/explain_bundle.go b/pkg/sql/explain_bundle.go index d006a79d5349..4077c260b73c 100644 --- a/pkg/sql/explain_bundle.go +++ b/pkg/sql/explain_bundle.go @@ -25,10 +25,12 @@ import ( "github.com/cockroachdb/cockroach/pkg/sql/opt/exec/explain" "github.com/cockroachdb/cockroach/pkg/sql/opt/memo" "github.com/cockroachdb/cockroach/pkg/sql/sem/catconstants" + "github.com/cockroachdb/cockroach/pkg/sql/sem/catid" "github.com/cockroachdb/cockroach/pkg/sql/sem/tree" "github.com/cockroachdb/cockroach/pkg/sql/sessiondata" "github.com/cockroachdb/cockroach/pkg/sql/stmtdiagnostics" "github.com/cockroachdb/cockroach/pkg/util/buildutil" + "github.com/cockroachdb/cockroach/pkg/util/intsets" "github.com/cockroachdb/cockroach/pkg/util/log" "github.com/cockroachdb/cockroach/pkg/util/memzipper" "github.com/cockroachdb/cockroach/pkg/util/pretty" @@ -36,6 +38,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/util/tracing/tracingpb" "github.com/cockroachdb/errors" "github.com/cockroachdb/redact" + "github.com/lib/pq/oid" ) const noPlan = "no plan" @@ -600,12 +603,11 @@ func (b *stmtBundleBuilder) addEnv(ctx context.Context) { // Note: we do not shortcut out of this function if there is no table/sequence/view to report: // the bundle analysis tool require schema.sql to always be present, even if it's empty. - first := true blankLine := func() { - if !first { + if buf.Len() > 0 { + // Don't add newlines to the beginning of the file. buf.WriteByte('\n') } - first = false } blankLine() c.printCreateAllDatabases(&buf, dbNames) @@ -625,28 +627,20 @@ func (b *stmtBundleBuilder) addEnv(ctx context.Context) { } if mem.Metadata().HasUserDefinedRoutines() { // Get all relevant user-defined routines. - blankLine() - err = c.PrintRelevantCreateRoutine( - &buf, strings.ToLower(b.stmt), b.flags.RedactValues, &b.errorStrings, false, /* procedure */ - ) - if err != nil { - b.printError(fmt.Sprintf("-- error getting schema for udfs: %v", err), &buf) - } - } - if call, ok := mem.RootExpr().(*memo.CallExpr); ok { - // Currently, a stored procedure can only be called from a CALL statement, - // which can only be the root expression. - if proc, ok := call.Proc.(*memo.UDFCallExpr); ok { + var ids intsets.Fast + isProcedure := make(map[oid.Oid]bool) + mem.Metadata().ForEachUserDefinedRoutine(func(ol *tree.Overload) { + ids.Add(int(ol.Oid)) + isProcedure[ol.Oid] = ol.Type == tree.ProcedureRoutine + }) + ids.ForEach(func(id int) { blankLine() - err = c.PrintRelevantCreateRoutine( - &buf, strings.ToLower(proc.Def.Name), b.flags.RedactValues, &b.errorStrings, true, /* procedure */ - ) + routineOid := oid.Oid(id) + err = c.PrintCreateRoutine(&buf, routineOid, b.flags.RedactValues, isProcedure[routineOid]) if err != nil { - b.printError(fmt.Sprintf("-- error getting schema for procedure: %v", err), &buf) + b.printError(fmt.Sprintf("-- error getting schema for routine with ID %d: %v", id, err), &buf) } - } else { - b.printError("-- unexpected input expression for CALL statement", &buf) - } + }) } for i := range tables { blankLine() @@ -1030,50 +1024,28 @@ func (c *stmtEnvCollector) PrintCreateEnum(w io.Writer, redactValues bool) error return nil } -func (c *stmtEnvCollector) PrintRelevantCreateRoutine( - w io.Writer, stmt string, redactValues bool, errorStrings *[]string, procedure bool, +func (c *stmtEnvCollector) PrintCreateRoutine( + w io.Writer, id oid.Oid, redactValues bool, procedure bool, ) error { - // The select function_name returns a DOidWrapper, - // we need to cast it to string for queryRows function to process. - // TODO(#104976): consider getting the udf sql body statements from the memo metadata. - var routineTypeName, routineNameQuery string + var createRoutineQuery string + descID := catid.UserDefinedOIDToID(id) + queryTemplate := "SELECT create_statement FROM crdb_internal.create_%[1]s_statements WHERE %[1]s_id = %[2]d" if procedure { - routineTypeName = "PROCEDURE" - routineNameQuery = "SELECT procedure_name::STRING as procedure_name_str FROM [SHOW PROCEDURES]" + createRoutineQuery = fmt.Sprintf(queryTemplate, "procedure", descID) } else { - routineTypeName = "FUNCTION" - routineNameQuery = "SELECT function_name::STRING as function_name_str FROM [SHOW FUNCTIONS]" + createRoutineQuery = fmt.Sprintf(queryTemplate, "function", descID) + } + if redactValues { + createRoutineQuery = fmt.Sprintf( + "SELECT crdb_internal.redact(crdb_internal.redactable_sql_constants(create_statement)) FROM (%s)", + createRoutineQuery, + ) } - routineNames, err := c.queryRows(routineNameQuery) + createStatement, err := c.query(createRoutineQuery) if err != nil { return err } - for _, name := range routineNames { - if strings.Contains(stmt, name) { - createRoutineQuery := fmt.Sprintf( - "SELECT create_statement FROM [ SHOW CREATE %s \"%s\" ]", routineTypeName, name, - ) - if redactValues { - createRoutineQuery = fmt.Sprintf( - "SELECT crdb_internal.redact(crdb_internal.redactable_sql_constants(create_statement)) FROM [ SHOW CREATE %s \"%s\" ]", - routineTypeName, name, - ) - } - createStatement, err := c.query(createRoutineQuery) - if err != nil { - var errString string - if procedure { - errString = fmt.Sprintf("-- error getting stored procedure %s: %s", name, err) - } else { - errString = fmt.Sprintf("-- error getting user defined function %s: %s", name, err) - } - fmt.Fprint(w, errString+"\n") - *errorStrings = append(*errorStrings, errString) - continue - } - fmt.Fprintf(w, "%s\n", createStatement) - } - } + fmt.Fprintf(w, "%s;\n", createStatement) return nil } diff --git a/pkg/sql/explain_bundle_test.go b/pkg/sql/explain_bundle_test.go index 9b6a76bd659e..18056a0ae7ba 100644 --- a/pkg/sql/explain_bundle_test.go +++ b/pkg/sql/explain_bundle_test.go @@ -420,6 +420,68 @@ CREATE TABLE users(id UUID DEFAULT gen_random_uuid() PRIMARY KEY, promo_id INT R "distsql.html vec-v.txt vec.txt") }) + t.Run("different schema UDF", func(t *testing.T) { + r.Exec(t, "CREATE FUNCTION foo() RETURNS INT LANGUAGE SQL AS 'SELECT count(*) FROM abc, s.a';") + r.Exec(t, "CREATE FUNCTION s.foo() RETURNS INT LANGUAGE SQL AS 'SELECT count(*) FROM abc, s.a';") + rows := r.QueryStr(t, "EXPLAIN ANALYZE (DEBUG) SELECT s.foo();") + checkBundle( + t, fmt.Sprint(rows), "s.foo", func(name, contents string) error { + if name == "schema.sql" { + reg := regexp.MustCompile("s.foo") + if reg.FindString(contents) == "" { + return errors.Errorf("could not find definition for 's.foo' function in schema.sql") + } + reg = regexp.MustCompile("^foo") + if reg.FindString(contents) != "" { + return errors.Errorf("found irrelevant function 'foo' in schema.sql") + } + reg = regexp.MustCompile("s.a") + if reg.FindString(contents) == "" { + return errors.Errorf("could not find definition for relation 's.a' in schema.sql") + } + reg = regexp.MustCompile("abc") + if reg.FindString(contents) == "" { + return errors.Errorf("could not find definition for relation 'abc' in schema.sql") + } + } + return nil + }, + false /* expectErrors */, base, plans, + "stats-defaultdb.public.abc.sql stats-defaultdb.s.a.sql distsql.html vec-v.txt vec.txt", + ) + }) + + t.Run("different schema procedure", func(t *testing.T) { + r.Exec(t, "CREATE PROCEDURE bar() LANGUAGE SQL AS 'SELECT count(*) FROM abc, s.a';") + r.Exec(t, "CREATE PROCEDURE s.bar() LANGUAGE SQL AS 'SELECT count(*) FROM abc, s.a';") + rows := r.QueryStr(t, "EXPLAIN ANALYZE (DEBUG) CALL s.bar();") + checkBundle( + t, fmt.Sprint(rows), "s.bar", func(name, contents string) error { + if name == "schema.sql" { + reg := regexp.MustCompile("s.bar") + if reg.FindString(contents) == "" { + return errors.Errorf("could not find definition for 's.bar' procedure in schema.sql") + } + reg = regexp.MustCompile("^bar") + if reg.FindString(contents) != "" { + return errors.Errorf("Found irrelevant procedure 'bar' in schema.sql") + } + reg = regexp.MustCompile("s.a") + if reg.FindString(contents) == "" { + return errors.Errorf("could not find definition for relation 's.a' in schema.sql") + } + reg = regexp.MustCompile("abc") + if reg.FindString(contents) == "" { + return errors.Errorf("could not find definition for relation 'abc' in schema.sql") + } + } + return nil + }, + false /* expectErrors */, base, plans, + "stats-defaultdb.public.abc.sql stats-defaultdb.s.a.sql distsql.html vec-v.txt vec.txt", + ) + }) + t.Run("permission error", func(t *testing.T) { r.Exec(t, "CREATE USER test") r.Exec(t, "SET ROLE test") diff --git a/pkg/sql/opt/metadata.go b/pkg/sql/opt/metadata.go index 5172915f762a..aec784f7e881 100644 --- a/pkg/sql/opt/metadata.go +++ b/pkg/sql/opt/metadata.go @@ -604,7 +604,7 @@ func (md *Metadata) HasUserDefinedRoutines() bool { func (md *Metadata) AddUserDefinedRoutine( overload *tree.Overload, invocationTypes []*types.T, name *tree.UnresolvedObjectName, ) { - if overload.Type != tree.UDFRoutine { + if overload.Type == tree.BuiltinRoutine { return } id := cat.StableID(catid.UserDefinedOIDToID(overload.Oid)) @@ -617,6 +617,15 @@ func (md *Metadata) AddUserDefinedRoutine( } } +// ForEachUserDefinedRoutine executes the given function for each user-defined +// routine (UDF or stored procedure) overload. The order of iteration is +// non-deterministic. +func (md *Metadata) ForEachUserDefinedRoutine(fn func(overload *tree.Overload)) { + for _, dep := range md.routineDeps { + fn(dep.overload) + } +} + // AddBuiltin adds a name used to resolve a builtin function to the metadata for // this query. This is necessary to handle the case when changes to the search // path cause a function call to resolve as a UDF instead of a builtin function.