From fdab5082a76859f5c706ef44116ccd86b23f2c30 Mon Sep 17 00:00:00 2001
From: MiguelNovelo
Date: Thu, 18 Mar 2021 16:45:55 -0600
Subject: [PATCH] sql: code generation for missing pg_catalog columns

Previously, keeping up with postgres by adding the same columns to
pg_catalog was a manual process. This was inadequate because the list of
missing columns can grow considerably from one version to the next.

To address this, this patch closes the gap between postgres and
cockroachdb by adding the missing columns automatically, populated with
nil values.

Release note: None

Fixes: #58001
---
 .../generate-postgres-metadata-tables/main.go |   4 +
 pkg/sql/pg_metadata_diff.go                   |  71 +++-
 pkg/sql/pg_metadata_test.go                   | 390 +++++++++++++++++-
 .../testdata/information_schema_tables.json   |   7 +
 pkg/sql/testdata/pg_catalog_tables.json       |  11 +
 5 files changed, 472 insertions(+), 11 deletions(-)

diff --git a/pkg/cmd/generate-postgres-metadata-tables/main.go b/pkg/cmd/generate-postgres-metadata-tables/main.go
index 2c6fd124c16d..4ed672d3e2c2 100644
--- a/pkg/cmd/generate-postgres-metadata-tables/main.go
+++ b/pkg/cmd/generate-postgres-metadata-tables/main.go
@@ -57,6 +57,10 @@ func main() {
 			panic(err)
 		}
 		pgCatalogFile.PGMetadata.AddColumnMetadata(table, column, dataType, dataTypeOid)
+		columnType := pgCatalogFile.PGMetadata[table][column]
+		if columnType.TypeIsUnimplemented() {
+			pgCatalogFile.AddUnimplementedType(columnType)
+		}
 	}
 
 	pgCatalogFile.Save(os.Stdout)
diff --git a/pkg/sql/pg_metadata_diff.go b/pkg/sql/pg_metadata_diff.go
index 867aeb1fe14e..a0c6651b01dd 100644
--- a/pkg/sql/pg_metadata_diff.go
+++ b/pkg/sql/pg_metadata_diff.go
@@ -66,8 +66,9 @@ type PGMetadataTables map[string]PGMetadataColumns
 // PGMetadataFile is used to export pg_catalog from postgres and store the representation of this structure as a
 // json file
 type PGMetadataFile struct {
-	PGVersion  string           `json:"pgVersion"`
-	PGMetadata PGMetadataTables `json:"pgMetadata"`
+	PGVersion          string             `json:"pgVersion"`
+	PGMetadata         PGMetadataTables   `json:"pgMetadata"`
+	UnimplementedTypes map[oid.Oid]string `json:"unimplementedTypes"`
 }
 
 func (p PGMetadataTables) addColumn(tableName, columnName string, column *PGMetadataColumnType) {
@@ -214,6 +215,53 @@ func (p PGMetadataTables) getUnimplementedTables(source PGMetadataTables) PGMeta
 	return notImplemented
 }
 
+// getUnimplementedColumns is called on the diffs to find columns still missing from already implemented tables.
+func (p PGMetadataTables) getUnimplementedColumns(target PGMetadataTables) PGMetadataTables {
+	unimplementedColumns := make(PGMetadataTables)
+	for tableName, columns := range p {
+		if len(columns) == 0 {
+			// Table is not implemented.
+			continue
+		}
+
+		for columnName, columnType := range columns {
+			if columnType != nil {
+				// dataType mismatch (not a new column).
+				continue
+			}
+			sourceType := target[tableName][columnName]
+			typeOid := oid.Oid(sourceType.Oid)
+			if _, ok := types.OidToType[typeOid]; !ok || typeOid == oid.T_anyarray {
+				// Cannot implement this column due to a missing type.
+				continue
+			}
+			unimplementedColumns.AddColumnMetadata(tableName, columnName, sourceType.DataType, sourceType.Oid)
+		}
+	}
+	return unimplementedColumns
+}
+
+func (p PGMetadataTables) removeImplementedColumns(source PGMetadataTables) {
+	for tableName, columns := range source {
+		pColumns, exists := p[tableName]
+		if !exists {
+			continue
+		}
+		for columnName := range columns {
+			columnType, exists := pColumns[columnName]
+			if !exists {
+				continue
+			}
+			if columnType != nil {
+				// type diff
+				continue
+			}
+
+			delete(pColumns, columnName)
+		}
+	}
+}
+
 // getUnimplementedTypes verifies that all the types are implemented in cockroach db.
 func (c PGMetadataColumns) getUnimplementedTypes() map[oid.Oid]string {
 	unimplemented := make(map[oid.Oid]string)
@@ -226,3 +274,22 @@ func (c PGMetadataColumns) getUnimplementedTypes() map[oid.Oid]string {
 
 	return unimplemented
 }
+
+// AddUnimplementedType reports a type that is not implemented in cockroachdb.
+func (f *PGMetadataFile) AddUnimplementedType(columnType *PGMetadataColumnType) {
+	typeOid := oid.Oid(columnType.Oid)
+	if f.UnimplementedTypes == nil {
+		f.UnimplementedTypes = make(map[oid.Oid]string)
+	}
+
+	f.UnimplementedTypes[typeOid] = columnType.DataType
+}
+
+// TypeIsUnimplemented determines whether the type is implemented in
+// cockroachdb.
+func (t *PGMetadataColumnType) TypeIsUnimplemented() bool {
+	typeOid := oid.Oid(t.Oid)
+	_, ok := types.OidToType[typeOid]
+	// Cannot use type oid.T_anyarray in CREATE TABLE
+	return !ok || typeOid == oid.T_anyarray
+}
diff --git a/pkg/sql/pg_metadata_test.go b/pkg/sql/pg_metadata_test.go
index 1ee452a15c10..98ffd99c3bae 100644
--- a/pkg/sql/pg_metadata_test.go
+++ b/pkg/sql/pg_metadata_test.go
@@ -33,6 +33,9 @@ import (
 	"encoding/json"
 	"flag"
 	"fmt"
+	"go/ast"
+	"go/parser"
+	"go/token"
 	"io"
 	"io/ioutil"
 	"os"
@@ -94,11 +97,14 @@ const (
 var addMissingTables = flag.Bool(
 	"add-missing-tables",
 	false,
-	"add-missing-tables will complete pg_catalog tables in the go code",
+	"add-missing-tables will complete pg_catalog tables and columns in the go code",
 )
 
 var (
 	tableFinderRE = regexp.MustCompile(`(?i)CREATE TABLE pg_catalog\.([^\s]+)\s`)
+	constNameRE   = regexp.MustCompile(`const ([^\s]+)\s`)
+	indentationRE = regexp.MustCompile(`^(\s*)`)
+	indexRE       = regexp.MustCompile(`(?i)INDEX\s*\([^\)]+\)`)
 )
 
 var none = struct{}{}
@@ -252,26 +258,56 @@ func fixConstants(t *testing.T, notImplemented PGMetadataTables) {
 	})
 }
 
-// fixVtable adds missing table's create table constants.
-func fixVtable(t *testing.T, notImplemented PGMetadataTables) {
+// fixVtable adds the create table constants for missing tables and the missing columns of existing ones.
+func fixVtable(
+	t *testing.T,
+	unimplemented PGMetadataTables,
+	unimplementedColumns PGMetadataTables,
+	pgCode *pgCatalogCode,
+) {
 	fileName := filepath.Join(vtablePkg, pgCatalogGo)
 	// rewriteFile first will check existing create table constants to avoid duplicates.
rewriteFile(fileName, func(input *os.File, output outputFile) { existingTables := make(map[string]struct{}) reader := bufio.NewScanner(input) + var constName string + var tableName string + fixedTables := make(map[string]struct{}) + var sb strings.Builder + for reader.Scan() { text := reader.Text() - output.appendString(text) - output.appendString("\n") + trimText := strings.TrimSpace(text) + constDecl := constNameRE.FindStringSubmatch(text) + if constDecl != nil { + constName = constDecl[1] + } + createTable := tableFinderRE.FindStringSubmatch(text) if createTable != nil { - tableName := createTable[1] + tableName = createTable[1] existingTables[tableName] = none } + + nextIsIndex := indexRE.MatchString(strings.ToUpper(trimText)) + if _, fixed := fixedTables[tableName]; !fixed && (text == ")`" || nextIsIndex) { + missingColumnsText := getMissingColumnsText(constName, tableName, nextIsIndex, unimplementedColumns, pgCode) + if len(missingColumnsText) > 0 { + output.seekRelative(-1) + } + output.appendString(missingColumnsText) + fixedTables[tableName] = none + pgCode.addRowPositions.removeIfNoMissingColumns(constName) + pgCode.addRowPositions.reportNewColumns(&sb, constName, tableName) + } + + output.appendString(text) + output.appendString("\n") } - for tableName, columns := range notImplemented { + first := true + for tableName, columns := range unimplemented { if _, ok := existingTables[tableName]; ok { // Table already implemented. continue @@ -282,11 +318,57 @@ func fixVtable(t *testing.T, notImplemented PGMetadataTables) { t.Log(err) continue } + reportNewTable(&sb, tableName, &first) output.appendString(createTable) } + + fmt.Println(sb.String()) }) } +func reportNewTable(sb *strings.Builder, tableName string, first *bool) { + if *first { + sb.WriteString("New Tables:\n") + *first = false + } + sb.WriteString("\t") + sb.WriteString(tableName) + sb.WriteString("\n") +} + +func getMissingColumnsText( + constName string, + tableName string, + nextIsIndex bool, + unimplementedColumns PGMetadataTables, + pgCode *pgCatalogCode, +) string { + nilPopulateTables := pgCode.fixableTables + if _, fixable := nilPopulateTables[constName]; !fixable { + return "" + } + columns, found := unimplementedColumns[tableName] + if !found { + return "" + } + var sb strings.Builder + prefix := ",\n" + if nextIsIndex { + // Previous line already had comma + prefix = "\n" + } + for columnName, columnType := range columns { + formatColumn(&sb, prefix, columnName, columnType) + pgCode.addRowPositions.addMissingColumn(constName, columnName) + prefix = ",\n" + } + if nextIsIndex { + sb.WriteString(",") + } + sb.WriteString("\n") + return sb.String() +} + // fixPgCatalogGo will update pgCatalog.allTableNames, pgCatalog.tableDefs and // will add needed virtualSchemas. 
 func fixPgCatalogGo(notImplemented PGMetadataTables) {
@@ -315,6 +397,53 @@ func fixPgCatalogGo(notImplemented PGMetadataTables) {
 	})
 }
 
+func fixPgCatalogGoColumns(positions addRowPositionList) {
+	rewriteFile(pgCatalogGo, func(input *os.File, output outputFile) {
+		reader := bufio.NewScanner(input)
+		soFar := 0
+		currentPosition := 0
+		for reader.Scan() {
+			text := reader.Text()
+			count := len(text) + 1
+
+			if currentPosition < len(positions) && int64(soFar+count) > positions[currentPosition].insertPosition {
+				relativeIndex := int(positions[currentPosition].insertPosition-int64(soFar)) - 1
+				left := text[:relativeIndex]
+				indentation := indentationRE.FindStringSubmatch(text)[1] // indentationRE always matches, so this yields at least "".
+				if len(strings.TrimSpace(left)) > 0 {
+					// The closing parenthesis sits right after the last argument; the existing indentation is correct.
+					output.appendString(left)
+					output.appendString(",\n")
+				} else {
+					// The closing parenthesis is on its own line, so add one more tab.
+					indentation += "\t"
+				}
+
+				output.appendString(indentation)
+				output.appendString("// These columns were automatically created by pg_catalog_test's missing column generator.")
+				output.appendString("\n")
+
+				for _, columnName := range positions[currentPosition].missingColumns {
+					output.appendString(indentation)
+					output.appendString("tree.DNull,")
+					output.appendString(" // ")
+					output.appendString(columnName)
+					output.appendString("\n")
+				}
+
+				output.appendString(indentation[:len(indentation)-1])
+				output.appendString(text[relativeIndex:])
+				currentPosition++
+			} else {
+				// No insertion point here; just write out whatever was read.
+				output.appendString(text)
+			}
+			output.appendString("\n")
+			soFar += count
+		}
+	})
+}
+
 // printBeforeTerminalString will skip all the lines and print `s` text when finds the terminal string.
 func printBeforeTerminalString(
 	reader *bufio.Scanner, output outputFile, terminalString string, s string,
@@ -372,7 +501,7 @@ func getPgCatalogConstants(
 	return pgConstants
 }
 
-// outputFile wraps an *os.file to avoid explicit error checks on every WriteString.
+// outputFile wraps an *os.File to avoid explicit error checks on every WriteString.
 type outputFile struct {
 	f *os.File
 }
@@ -384,6 +513,12 @@ func (o outputFile) appendString(s string) {
 	}
 }
 
+func (o outputFile) seekRelative(offset int64) {
+	if _, err := o.f.Seek(offset, 1); err != nil {
+		panic(fmt.Errorf("could not seek file"))
+	}
+}
+
 // rewriteFile recreate a file by using the f func, this creates a temporary
 // file to place all the output first then it replaces the original file.
 func rewriteFile(fileName string, f func(*os.File, outputFile)) {
@@ -611,6 +746,239 @@ func getSortedDefKeys(tableDefs map[string]string) []string {
 	return keys
 }
 
+// goParsePgCatalogGo walks the pg_catalog.go AST to record which virtualSchemaTables have a populate function that
+// just returns nil, and to map each addRow call to its schema: how many columns it passes and where the nulls go.
+func goParsePgCatalogGo(t *testing.T) *pgCatalogCode {
+	fs := token.NewFileSet()
+	f, err := parser.ParseFile(fs, pgCatalogGo, nil, parser.AllErrors)
+	if err != nil {
+		t.Fatal(err)
+	}
+	pgCode := &pgCatalogCode{
+		fixableTables:   make(map[string]struct{}),
+		addRowPositions: make(map[string]addRowPositionList),
+		schemaParam:     -1, // This value is calculated later, only once.
+	}
+
+	ast.Walk(&pgCatalogCodeVisitor{
+		pgCode:    pgCode,
+		schema:    "",
+		bodyStmts: 0,
+	}, f)
+
+	return pgCode
+}
+
+// addRowPosition describes an insertion point for new columns.
+type addRowPosition struct {
+	schema         string
+	argSize        int
+	insertPosition int64
+	missingColumns []string
+}
+
+type addRowPositionList []*addRowPosition
+
+type mappedRowPositions map[string]addRowPositionList
+
+// pgCatalogCode describes the pg_catalog.go insertion points for addRow calls, as well as the virtualSchemaTables
+// whose populate function just returns nil.
+type pgCatalogCode struct {
+	fixableTables   map[string]struct{}
+	addRowPositions mappedRowPositions
+	schemaParam     int // Index where the schema parameter is expected at makeAllRelationsVirtualTableWithDescriptorIDIndex.
+}
+
+// pgCatalogCodeVisitor implements ast.Visitor for traversing pg_catalog.go.
+type pgCatalogCodeVisitor struct {
+	pgCode    *pgCatalogCode
+	schema    string
+	bodyStmts int
+}
+
+// next copies this visitor for inner nodes (as inner nodes might change the schema).
+func (v *pgCatalogCodeVisitor) next() *pgCatalogCodeVisitor {
+	return &pgCatalogCodeVisitor{
+		pgCode:    v.pgCode,
+		schema:    v.schema,
+		bodyStmts: v.bodyStmts,
+	}
+}
+
+// nextWithSchema sets the schema for inner node visitors.
+func (v *pgCatalogCodeVisitor) nextWithSchema(schema string) *pgCatalogCodeVisitor {
+	return &pgCatalogCodeVisitor{
+		pgCode:    v.pgCode,
+		schema:    schema,
+		bodyStmts: v.bodyStmts,
+	}
+}
+
+// Visit implements ast.Visitor and sets the rules for detecting the schema, matching the schema with addRow calls,
+// and finding which schemas have a populate function that consists of a single "return nil".
+func (v *pgCatalogCodeVisitor) Visit(node ast.Node) ast.Visitor {
+	if node == nil {
+		return nil
+	}
+
+	switch n := node.(type) {
+	case *ast.KeyValueExpr:
+		// These key/value pairs come from virtualSchemaTable definitions.
+		key, ok := n.Key.(*ast.Ident)
+		if !ok {
+			return v.next()
+		}
+		switch val := n.Value.(type) {
+		case *ast.SelectorExpr:
+			if key.Name != "schema" {
+				return v.next()
+			}
+			v.schema = val.Sel.String()
+		case *ast.FuncLit:
+			if key.Name != "populate" {
+				return v.next()
+			}
+			v.bodyStmts = len(val.Body.List)
+		}
+	case *ast.ReturnStmt:
+		result, ok := n.Results[0].(*ast.Ident)
+		if !ok {
+			return v.next()
+		}
+
+		// This validates that the ReturnStmt comes from a populate function and that it is its only statement.
+		if result.String() == "nil" && v.schema != "" && v.bodyStmts == 1 {
+			// Populate function just returns nil.
+			v.pgCode.fixableTables[v.schema] = none
+		}
+	case *ast.CallExpr:
+		fun, ok := n.Fun.(*ast.Ident)
+		if !ok {
+			return v.next()
+		}
+
+		switch fun.Name {
+		case "addRow":
+			if v.schema == "" {
+				// Could not match addRow with a schema.
+				return v.next()
+			}
+
+			if _, ok = v.pgCode.fixableTables[v.schema]; !ok {
+				v.pgCode.fixableTables[v.schema] = none
+			}
+			if _, ok = v.pgCode.addRowPositions[v.schema]; !ok {
+				v.pgCode.addRowPositions[v.schema] = make([]*addRowPosition, 0, 3)
+			}
+
+			addRow := &addRowPosition{
+				schema:         v.schema,
+				argSize:        len(n.Args), // The number of arguments must match the number of columns.
+				insertPosition: int64(n.Rparen),
+				missingColumns: make([]string, 0, 5), // Filled in when fixing the vtable; used to comment the nulls.
+			}
+			addRowList := v.pgCode.addRowPositions[v.schema]
+			v.pgCode.addRowPositions[v.schema] = append(addRowList, addRow)
+		case "makeAllRelationsVirtualTableWithDescriptorIDIndex":
+			// Special case: the table definition lives inside this function call, and the populate function is passed
+			// to it as an argument.
+			schemaIndex := v.findSchemaIndex(n)
+			if schemaIndex < 0 || len(n.Args) <= schemaIndex {
+				// This is unlikely to happen, but guard against an index out of bounds just in case.
+				return v.next()
+			}
+			val, ok := n.Args[schemaIndex].(*ast.SelectorExpr)
+			if !ok || val == nil {
+				return v.next()
+			}
+
+			return v.nextWithSchema(val.Sel.String())
+		}
+	}
+
+	return v.next()
+}
+
+// singleSortedList retrieves all the positions in ascending order, so columns can be fixed sequentially while reading the file.
+func (m mappedRowPositions) singleSortedList() addRowPositionList {
+	positions := make(addRowPositionList, 0, len(m))
+	for _, val := range m {
+		positions = append(positions, val...)
+	}
+	sort.Slice(positions, func(i int, j int) bool {
+		return positions[i].insertPosition < positions[j].insertPosition
+	})
+	return positions
+}
+
+// removeIfNoMissingColumns cleans up addRow call positions that do not require adding any columns.
+func (m mappedRowPositions) removeIfNoMissingColumns(constName string) {
+	if addRowList, ok := m[constName]; ok && len(addRowList) > 0 {
+		addRow := addRowList[0]
+		if len(addRow.missingColumns) == 0 {
+			delete(m, constName)
+		}
+	}
+}
+
+// addMissingColumn records columnName as missing for the given constName (schema).
+func (m mappedRowPositions) addMissingColumn(constName, columnName string) {
+	if addRows, ok := m[constName]; ok {
+		for _, addRow := range addRows {
+			addRow.missingColumns = append(addRow.missingColumns, columnName)
+		}
+	}
+}
+
+func (m mappedRowPositions) reportNewColumns(sb *strings.Builder, constName, tableName string) {
+	if addRowList, ok := m[constName]; ok && len(addRowList) > 0 {
+		addRow := addRowList[0]
+		if len(addRow.missingColumns) > 0 {
+			sb.WriteString("New columns in table ")
+			sb.WriteString(tableName)
+			sb.WriteString(":\n")
+			for _, columnName := range addRow.missingColumns {
+				sb.WriteString("\t")
+				sb.WriteString(columnName)
+				sb.WriteString("\n")
+			}
+		}
+	}
+}
+
+// findSchemaIndex is a helper that retrieves the parameter index of schemaDef in the signature of
+// makeAllRelationsVirtualTableWithDescriptorIDIndex.
+func (v *pgCatalogCodeVisitor) findSchemaIndex(call *ast.CallExpr) int {
+	if v.pgCode.schemaParam != -1 {
+		return v.pgCode.schemaParam
+	}
+
+	fun, ok := call.Fun.(*ast.Ident)
+	if !ok {
+		return -1
+	}
+	decl, ok := fun.Obj.Decl.(*ast.FuncDecl)
+	if !ok {
+		return -1
+	}
+	for index, field := range decl.Type.Params.List {
+		if len(field.Names) != 1 || field.Names[0] == nil {
+			continue
+		}
+		fieldType, ok := field.Type.(*ast.Ident)
+		if !ok {
+			continue
+		}
+
+		if field.Names[0].String() == "schemaDef" && fieldType.String() == "string" {
+			v.pgCode.schemaParam = index
+			return v.pgCode.schemaParam
+		}
+	}
+
+	return -1
+}
+
 // TestPGCatalog is the pg_catalog diff tool test
which compares pg_catalog // with postgres and cockroach. func TestPGCatalog(t *testing.T) { @@ -671,12 +1039,16 @@ func TestPGCatalog(t *testing.T) { } sum.report(t) + diffs.removeImplementedColumns(crdbTables) rewriteDiffs(t, diffs, filepath.Join(testdata, fmt.Sprintf(expectedDiffs, *catalogName))) if *addMissingTables { unimplemented := diffs.getUnimplementedTables(pgTables) + unimplementedColumns := diffs.getUnimplementedColumns(pgTables) + pgCode := goParsePgCatalogGo(t) fixConstants(t, unimplemented) - fixVtable(t, unimplemented) + fixVtable(t, unimplemented, unimplementedColumns, pgCode) + fixPgCatalogGoColumns(pgCode.addRowPositions.singleSortedList()) fixPgCatalogGo(unimplemented) } } diff --git a/pkg/sql/testdata/information_schema_tables.json b/pkg/sql/testdata/information_schema_tables.json index d42e55ab1599..6da47b71e5f1 100644 --- a/pkg/sql/testdata/information_schema_tables.json +++ b/pkg/sql/testdata/information_schema_tables.json @@ -4103,5 +4103,12 @@ "expectedDataType": null } } + }, + "unimplementedTypes": { + "13438": "cardinal_number", + "13441": "character_data", + "13443": "sql_identifier", + "13448": "time_stamp", + "13450": "yes_or_no" } } \ No newline at end of file diff --git a/pkg/sql/testdata/pg_catalog_tables.json b/pkg/sql/testdata/pg_catalog_tables.json index 1938acbe1f8d..ca3480248d6f 100644 --- a/pkg/sql/testdata/pg_catalog_tables.json +++ b/pkg/sql/testdata/pg_catalog_tables.json @@ -8801,5 +8801,16 @@ "expectedDataType": null } } + }, + "unimplementedTypes": { + "1034": "_aclitem", + "194": "pg_node_tree", + "2275": "cstring", + "2277": "anyarray", + "28": "xid", + "3220": "pg_lsn", + "3361": "pg_ndistinct", + "3402": "pg_dependencies", + "5017": "pg_mcv_list" } } \ No newline at end of file
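
A note on the approach (not part of the patch): the column generator works in two passes. goParsePgCatalogGo parses pkg/sql/pg_catalog.go with go/parser and walks the AST to find every addRow call and the position of its closing parenthesis (n.Rparen); fixPgCatalogGoColumns then rewrites the file, splicing one tree.DNull argument, with a trailing comment naming the column, per missing column right before that parenthesis. The sketch below is illustrative only — the populate function and identifiers inside src are made up — but it shows the locating half of that technique using only the standard go/ast package:

package main

import (
	"fmt"
	"go/ast"
	"go/parser"
	"go/token"
)

// A tiny stand-in for a pg_catalog.go populate function.
const src = `package sql

func populate(addRow func(...interface{}) error) error {
	return addRow(
		relOid,  // oid
		relName, // relname
	)
}
`

func main() {
	fset := token.NewFileSet()
	f, err := parser.ParseFile(fset, "pg_catalog.go", src, parser.AllErrors)
	if err != nil {
		panic(err)
	}
	ast.Inspect(f, func(n ast.Node) bool {
		call, ok := n.(*ast.CallExpr)
		if !ok {
			return true
		}
		if fun, ok := call.Fun.(*ast.Ident); ok && fun.Name == "addRow" {
			// call.Rparen marks the closing parenthesis; new tree.DNull
			// placeholders for missing columns are inserted just before it.
			fmt.Printf("addRow with %d args, closing paren at %v\n",
				len(call.Args), fset.Position(call.Rparen))
		}
		return true
	})
}

Running the test with the flag added by this patch (presumably something like `go test -run TestPGCatalog --add-missing-tables` from pkg/sql; the exact invocation is not shown here) regenerates both the vtable schemas and the addRow calls, so the pg_catalog diff against the postgres dump shrinks to only the columns whose types remain unimplemented.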