Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update CLI with multi-schema support and database schema source #108

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions cmd/pg-schema-diff/apply_cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,12 @@ func buildApplyCmd() *cobra.Command {
" (example: --allowed-hazards DELETES_DATA,INDEX_BUILD)")
cmd.RunE = func(cmd *cobra.Command, args []string) error {
logger := log.SimpleLogger()
connConfig, err := connFlags.parseConnConfig(logger)
connConfig, err := parseConnConfig(*connFlags, logger)
if err != nil {
return err
}

planConfig, err := planFlags.parsePlanConfig()
planConfig, err := parsePlanConfig(*planFlags)
if err != nil {
return err
}
Expand Down
25 changes: 10 additions & 15 deletions cmd/pg-schema-diff/flags.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,28 +7,23 @@ import (
)

type connFlags struct {
dsn *string
dsn string
}

func createConnFlags(cmd *cobra.Command) connFlags {
dsn := cmd.Flags().String("dsn", "", "Connection string for the database (DB password can be specified through PGPASSWORD environment variable)")
func createConnFlags(cmd *cobra.Command) *connFlags {
flags := &connFlags{}

cmd.Flags().StringVar(&flags.dsn, "dsn", "", "Connection string for the database (DB password can be specified through PGPASSWORD environment variable)")
// Don't mark dsn as a required flag.
// Allow users to use the "PGHOST" etc environment variables like `psql`.
return connFlags{
dsn: dsn,
}

return flags
}

func (c connFlags) parseConnConfig(logger log.Logger) (*pgx.ConnConfig, error) {
if c.dsn == nil || *c.dsn == "" {
func parseConnConfig(c connFlags, logger log.Logger) (*pgx.ConnConfig, error) {
if c.dsn == "" {
logger.Warnf("DSN flag not set. Using libpq environment variables and default values.")
}

return pgx.ParseConfig(*c.dsn)
}

func mustMarkFlagAsRequired(cmd *cobra.Command, flagName string) {
if err := cmd.MarkFlagRequired(flagName); err != nil {
panic(err)
}
return pgx.ParseConfig(c.dsn)
}
148 changes: 115 additions & 33 deletions cmd/pg-schema-diff/plan_cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"database/sql"
"fmt"
"io"
"os"
"path/filepath"
"regexp"
Expand All @@ -18,6 +19,10 @@ import (
"github.com/stripe/pg-schema-diff/pkg/tempdb"
)

const (
defaultMaxConnections = 5
)

var (
// Match arguments in the format "regex=duration" where duration is any duration valid in time.ParseDuration
// We'll let time.ParseDuration handle the complexity of parsing invalid duration, so the regex we're extracting is
Expand Down Expand Up @@ -47,12 +52,12 @@ func buildPlanCmd() *cobra.Command {
planFlags := createPlanFlags(cmd)
cmd.RunE = func(cmd *cobra.Command, args []string) error {
logger := log.SimpleLogger()
connConfig, err := connFlags.parseConnConfig(logger)
connConfig, err := parseConnConfig(*connFlags, logger)
if err != nil {
return err
}

planConfig, err := planFlags.parsePlanConfig()
planConfig, err := parsePlanConfig(*planFlags)
if err != nil {
return err
}
Expand All @@ -75,11 +80,24 @@ func buildPlanCmd() *cobra.Command {
}

type (
schemaFlags struct {
includeSchemas []string
excludeSchemas []string
}

schemaSourceFlags struct {
schemaDir string
targetDatabaseDSN string
}

planFlags struct {
schemaDir *string
statementTimeoutModifiers *[]string
lockTimeoutModifiers *[]string
insertStatements *[]string
dbSchemaSourceFlags schemaSourceFlags

schemaFlags schemaFlags

statementTimeoutModifiers []string
lockTimeoutModifiers []string
insertStatements []string
}

timeoutModifiers struct {
Expand All @@ -93,43 +111,65 @@ type (
timeout time.Duration
}

schemaSourceFactory func() (diff.SchemaSource, io.Closer, error)

planConfig struct {
schemaDir string
schemaSourceFactory schemaSourceFactory
opts []diff.PlanOpt

statementTimeoutModifiers []timeoutModifiers
lockTimeoutModifiers []timeoutModifiers
insertStatements []insertStatement
}
)

func createPlanFlags(cmd *cobra.Command) planFlags {
schemaDir := cmd.Flags().String("schema-dir", "", "Directory containing schema files")
mustMarkFlagAsRequired(cmd, "schema-dir")
func createPlanFlags(cmd *cobra.Command) *planFlags {
flags := &planFlags{}

schemaSourceFlagsVar(cmd, &flags.dbSchemaSourceFlags)

statementTimeoutModifiers := timeoutModifierFlag(cmd, "statement", "t")
lockTimeoutModifiers := timeoutModifierFlag(cmd, "lock", "l")
insertStatements := cmd.Flags().StringArrayP("insert-statement", "s", nil,
schemaFlagsVar(cmd, &flags.schemaFlags)

timeoutModifierFlagVar(cmd, &flags.statementTimeoutModifiers, "statement", "t")
timeoutModifierFlagVar(cmd, &flags.lockTimeoutModifiers, "lock", "l")
cmd.Flags().StringArrayVarP(&flags.insertStatements, "insert-statement", "s", nil,
"<index>_<timeout>:<statement> values. Will insert the statement at the index in the "+
"generated plan with the specified timeout. This follows normal insert semantics. Example: -s '0 5s:SELECT 1''")

return planFlags{
schemaDir: schemaDir,
statementTimeoutModifiers: statementTimeoutModifiers,
lockTimeoutModifiers: lockTimeoutModifiers,
insertStatements: insertStatements,
return flags
}

func schemaSourceFlagsVar(cmd *cobra.Command, p *schemaSourceFlags) {
cmd.Flags().StringVar(&p.schemaDir, "schema-dir", "", "Directory of .SQL files to use as the schema source. Use to generate a diff between the target database and the schema in this directory.")
if err := cmd.MarkFlagDirname("schema-dir"); err != nil {
panic(err)
}
cmd.Flags().StringVar(&p.targetDatabaseDSN, "schema-source-dsn", "", "DSN for the database to use as the schema source. Use to generate a diff between the target database and the schema in this database.")

cmd.MarkFlagsMutuallyExclusive("schema-dir", "schema-source-dsn")
}

func schemaFlagsVar(cmd *cobra.Command, p *schemaFlags) {
cmd.Flags().StringArrayVar(&p.includeSchemas, "include-schema", nil, "Include the specified schema in the plan")
cmd.Flags().StringArrayVar(&p.excludeSchemas, "exclude-schema", nil, "Exclude the specified schema in the plan")
}

func timeoutModifierFlag(cmd *cobra.Command, timeoutType string, shorthand string) *[]string {
func timeoutModifierFlagVar(cmd *cobra.Command, p *[]string, timeoutType string, shorthand string) {
flagName := fmt.Sprintf("%s-timeout-modifier", timeoutType)
desc := fmt.Sprintf("regex=timeout key-value pairs, where if a statement matches the regex, the statement "+
"will be modified to have the %s timeout. If multiple regexes match, the latest regex will take priority. "+
"Example: -t 'CREATE TABLE=5m' -t 'CONCURRENTLY=10s'", timeoutType)
return cmd.Flags().StringArrayP(flagName, shorthand, nil, desc)
cmd.Flags().StringArrayVarP(p, flagName, shorthand, nil, desc)
}

func (p planFlags) parsePlanConfig() (planConfig, error) {
func parsePlanConfig(p planFlags) (planConfig, error) {
schemaSourceFactory, err := parseSchemaSource(p.dbSchemaSourceFlags)
if err != nil {
return planConfig{}, err
}

var statementTimeoutModifiers []timeoutModifiers
for _, s := range *p.statementTimeoutModifiers {
for _, s := range p.statementTimeoutModifiers {
stm, err := parseTimeoutModifier(s)
if err != nil {
return planConfig{}, fmt.Errorf("parsing statement timeout modifier from %q: %w", s, err)
Expand All @@ -138,7 +178,7 @@ func (p planFlags) parsePlanConfig() (planConfig, error) {
}

var lockTimeoutModifiers []timeoutModifiers
for _, s := range *p.lockTimeoutModifiers {
for _, s := range p.lockTimeoutModifiers {
ltm, err := parseTimeoutModifier(s)
if err != nil {
return planConfig{}, fmt.Errorf("parsing statement timeout modifier from %q: %w", s, err)
Expand All @@ -147,7 +187,7 @@ func (p planFlags) parsePlanConfig() (planConfig, error) {
}

var insertStatements []insertStatement
for _, i := range *p.insertStatements {
for _, i := range p.insertStatements {
is, err := parseInsertStatementStr(i)
if err != nil {
return planConfig{}, fmt.Errorf("parsing insert statement from %q: %w", i, err)
Expand All @@ -156,13 +196,49 @@ func (p planFlags) parsePlanConfig() (planConfig, error) {
}

return planConfig{
schemaDir: *p.schemaDir,
schemaSourceFactory: schemaSourceFactory,
opts: parseSchemaConfig(p.schemaFlags),
statementTimeoutModifiers: statementTimeoutModifiers,
lockTimeoutModifiers: lockTimeoutModifiers,
insertStatements: insertStatements,
}, nil
}

func parseSchemaSource(p schemaSourceFlags) (schemaSourceFactory, error) {
if p.schemaDir != "" {
ddl, err := getDDLFromPath(p.schemaDir)
if err != nil {
return nil, err
}
return func() (diff.SchemaSource, io.Closer, error) {
return diff.DDLSchemaSource(ddl), nil, nil
}, nil
}

if p.targetDatabaseDSN != "" {
connConfig, err := pgx.ParseConfig(p.targetDatabaseDSN)
if err != nil {
return nil, fmt.Errorf("parsing DSN %q: %w", p.targetDatabaseDSN, err)
}
return func() (diff.SchemaSource, io.Closer, error) {
connPool, err := openDbWithPgxConfig(connConfig)
if err != nil {
return nil, nil, fmt.Errorf("opening db with pgx config: %w", err)
}
return diff.DBSchemaSource(connPool), connPool, nil
}, nil
}

return nil, fmt.Errorf("either --schema-dir or --schema-source-dsn must be set")
}

func parseSchemaConfig(p schemaFlags) []diff.PlanOpt {
return []diff.PlanOpt{
diff.WithIncludeSchemas(p.includeSchemas...),
diff.WithExcludeSchemas(p.excludeSchemas...),
}
}

func parseTimeoutModifier(val string) (timeoutModifiers, error) {
submatches := statementTimeoutModifierRegex.FindStringSubmatch(val)
if len(submatches) <= regexSTMRegexIndex || len(submatches) <= durationSTMRegexIndex {
Expand Down Expand Up @@ -216,11 +292,6 @@ func parseInsertStatementStr(val string) (insertStatement, error) {
}

func generatePlan(ctx context.Context, logger log.Logger, connConfig *pgx.ConnConfig, planConfig planConfig) (diff.Plan, error) {
ddl, err := getDDLFromPath(planConfig.schemaDir)
if err != nil {
return diff.Plan{}, nil
}

tempDbFactory, err := tempdb.NewOnInstanceFactory(ctx, func(ctx context.Context, dbName string) (*sql.DB, error) {
copiedConfig := connConfig.Copy()
copiedConfig.Database = dbName
Expand All @@ -241,11 +312,22 @@ func generatePlan(ctx context.Context, logger log.Logger, connConfig *pgx.ConnCo
return diff.Plan{}, err
}
defer connPool.Close()
connPool.SetMaxOpenConns(defaultMaxConnections)

connPool.SetMaxOpenConns(5)
schemaSource, schemaSourceCloser, err := planConfig.schemaSourceFactory()
if err != nil {
return diff.Plan{}, fmt.Errorf("creating schema source: %w", err)
}
if schemaSourceCloser != nil {
defer schemaSourceCloser.Close()
}

plan, err := diff.GeneratePlan(ctx, connPool, tempDbFactory, ddl,
diff.WithDataPackNewTables(),
plan, err := diff.Generate(ctx, connPool, schemaSource,
append(
planConfig.opts,
diff.WithTempDbFactory(tempDbFactory),
diff.WithDataPackNewTables(),
)...,
)
if err != nil {
return diff.Plan{}, fmt.Errorf("generating plan: %w", err)
Expand Down
6 changes: 3 additions & 3 deletions pkg/diff/plan_generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ func WithLogger(logger log.Logger) PlanOpt {
}
}

func WithSchemas(schemas ...string) PlanOpt {
func WithIncludeSchemas(schemas ...string) PlanOpt {
return func(opts *planOptions) {
opts.getSchemaOpts = append(opts.getSchemaOpts, schema.WithIncludeSchemas(schemas...))
}
Expand All @@ -96,7 +96,7 @@ func WithGetSchemaOpts(getSchemaOpts ...externalschema.GetSchemaOpt) PlanOpt {
// deprecated: GeneratePlan generates a migration plan to migrate the database to the target schema. This function only
// diffs the public schemas.
//
// Use Generate instead with the DDLSchemaSource(newDDL) and WithSchemas("public") and WithTempDbFactory options.
// Use Generate instead with the DDLSchemaSource(newDDL) and WithIncludeSchemas("public") and WithTempDbFactory options.
//
// Parameters:
// queryable: The target database to generate the diff for. It is recommended to pass in *sql.DB of the db you
Expand All @@ -106,7 +106,7 @@ func WithGetSchemaOpts(getSchemaOpts ...externalschema.GetSchemaOpt) PlanOpt {
// newDDL: DDL encoding the new schema
// opts: Additional options to configure the plan generation
func GeneratePlan(ctx context.Context, queryable sqldb.Queryable, tempdbFactory tempdb.Factory, newDDL []string, opts ...PlanOpt) (Plan, error) {
return Generate(ctx, queryable, DDLSchemaSource(newDDL), append(opts, WithTempDbFactory(tempdbFactory), WithSchemas("public"))...)
return Generate(ctx, queryable, DDLSchemaSource(newDDL), append(opts, WithTempDbFactory(tempdbFactory), WithIncludeSchemas("public"))...)
}

// Generate generates a migration plan to migrate the database to the target schema
Expand Down
4 changes: 2 additions & 2 deletions pkg/diff/plan_generator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ func (suite *planGeneratorTestSuite) TestGenerate_CannotBuildMigrationFromDDLWit
pool := suite.mustGetTestDBPool()
defer pool.Close()
_, err := Generate(context.Background(), pool, DDLSchemaSource([]string{``}),
WithSchemas("public"),
WithIncludeSchemas("public"),
WithDoNotValidatePlan(),
)
suite.ErrorContains(err, "tempDbFactory is required")
Expand All @@ -185,7 +185,7 @@ func (suite *planGeneratorTestSuite) TestGenerate_CannotValidateWithoutTempDbFac
pool := suite.mustGetTestDBPool()
defer pool.Close()
_, err := Generate(context.Background(), pool, DDLSchemaSource([]string{``}),
WithSchemas("public"),
WithIncludeSchemas("public"),
WithDoNotValidatePlan(),
)
suite.ErrorContains(err, "tempDbFactory is required")
Expand Down
Loading