Skip to content

Commit

Permalink
feat: diff auth and storage config on link
Browse files Browse the repository at this point in the history
  • Loading branch information
sweatybridge committed Nov 19, 2024
1 parent 91eabb1 commit f4b5ad4
Show file tree
Hide file tree
Showing 5 changed files with 62 additions and 49 deletions.
91 changes: 52 additions & 39 deletions internal/link/link.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"net/http"
"os"
"strconv"
"strings"
"sync"

"github.com/go-errors/errors"
Expand All @@ -26,10 +25,9 @@ import (
)

func Run(ctx context.Context, projectRef string, fsys afero.Fs, options ...func(*pgx.ConnConfig)) error {
original, err := cliConfig.ToTomlBytes(map[string]interface{}{
"api": utils.Config.Api,
"db": utils.Config.Db,
})
copy := utils.Config.Clone()
copy.Auth.HashSecrets(projectRef)
original, err := cliConfig.ToTomlBytes(copy)
if err != nil {
fmt.Fprintln(utils.GetDebugLogger(), err)
}
Expand Down Expand Up @@ -64,10 +62,7 @@ func Run(ctx context.Context, projectRef string, fsys afero.Fs, options ...func(
fmt.Fprintln(os.Stdout, "Finished "+utils.Aqua("supabase link")+".")

// 4. Suggest config update
updated, err := cliConfig.ToTomlBytes(map[string]interface{}{
"api": utils.Config.Api,
"db": utils.Config.Db,
})
updated, err := cliConfig.ToTomlBytes(utils.Config.Clone())
if err != nil {
fmt.Fprintln(utils.GetDebugLogger(), err)
}
Expand All @@ -82,10 +77,10 @@ func Run(ctx context.Context, projectRef string, fsys afero.Fs, options ...func(
func LinkServices(ctx context.Context, projectRef, anonKey string, fsys afero.Fs) {
// Ignore non-fatal errors linking services
var wg sync.WaitGroup
wg.Add(6)
wg.Add(8)
go func() {
defer wg.Done()
if err := linkDatabaseVersion(ctx, projectRef, fsys); err != nil && viper.GetBool("DEBUG") {
if err := linkDatabaseSettings(ctx, projectRef); err != nil && viper.GetBool("DEBUG") {
fmt.Fprintln(os.Stderr, err)
}
}()
Expand All @@ -95,6 +90,18 @@ func LinkServices(ctx context.Context, projectRef, anonKey string, fsys afero.Fs
fmt.Fprintln(os.Stderr, err)
}
}()
go func() {
defer wg.Done()
if err := linkGotrue(ctx, projectRef); err != nil && viper.GetBool("DEBUG") {
fmt.Fprintln(os.Stderr, err)
}
}()
go func() {
defer wg.Done()
if err := linkStorage(ctx, projectRef); err != nil && viper.GetBool("DEBUG") {
fmt.Fprintln(os.Stderr, err)
}
}()
go func() {
defer wg.Done()
if err := linkPooler(ctx, projectRef, fsys); err != nil && viper.GetBool("DEBUG") {
Expand Down Expand Up @@ -126,12 +133,11 @@ func LinkServices(ctx context.Context, projectRef, anonKey string, fsys afero.Fs
func linkPostgrest(ctx context.Context, projectRef string) error {
resp, err := utils.GetSupabase().V1GetPostgrestServiceConfigWithResponse(ctx, projectRef)
if err != nil {
return errors.Errorf("failed to get postgrest config: %w", err)
}
if resp.JSON200 == nil {
return errors.Errorf("%w: %s", tenant.ErrAuthToken, string(resp.Body))
return errors.Errorf("failed to read API config: %w", err)
} else if resp.JSON200 == nil {
return errors.Errorf("unexpected API config status %d: %s", resp.StatusCode(), string(resp.Body))
}
updateApiConfig(*resp.JSON200)
utils.Config.Api.FromRemoteApiConfig(*resp.JSON200)
return nil
}

Expand All @@ -143,22 +149,15 @@ func linkPostgrestVersion(ctx context.Context, api tenant.TenantAPI, fsys afero.
return utils.WriteFile(utils.RestVersionPath, []byte(version), fsys)
}

func updateApiConfig(config api.PostgrestConfigWithJWTSecretResponse) {
utils.Config.Api.MaxRows = cast.IntToUint(config.MaxRows)
utils.Config.Api.ExtraSearchPath = readCsv(config.DbExtraSearchPath)
utils.Config.Api.Schemas = readCsv(config.DbSchema)
}

func readCsv(line string) []string {
var result []string
tokens := strings.Split(line, ",")
for _, t := range tokens {
trimmed := strings.TrimSpace(t)
if len(trimmed) > 0 {
result = append(result, trimmed)
}
func linkGotrue(ctx context.Context, projectRef string) error {
resp, err := utils.GetSupabase().V1GetAuthServiceConfigWithResponse(ctx, projectRef)
if err != nil {
return errors.Errorf("failed to read Auth config: %w", err)
} else if resp.JSON200 == nil {
return errors.Errorf("unexpected Auth config status %d: %s", resp.StatusCode(), string(resp.Body))
}
return result
utils.Config.Auth.FromRemoteAuthConfig(*resp.JSON200)
return nil
}

func linkGotrueVersion(ctx context.Context, api tenant.TenantAPI, fsys afero.Fs) error {
Expand All @@ -169,6 +168,17 @@ func linkGotrueVersion(ctx context.Context, api tenant.TenantAPI, fsys afero.Fs)
return utils.WriteFile(utils.GotrueVersionPath, []byte(version), fsys)
}

func linkStorage(ctx context.Context, projectRef string) error {
resp, err := utils.GetSupabase().V1GetStorageConfigWithResponse(ctx, projectRef)
if err != nil {
return errors.Errorf("failed to read Storage config: %w", err)
} else if resp.JSON200 == nil {
return errors.Errorf("unexpected Storage config status %d: %s", resp.StatusCode(), string(resp.Body))
}
utils.Config.Storage.FromRemoteStorageConfig(*resp.JSON200)
return nil
}

func linkStorageVersion(ctx context.Context, api tenant.TenantAPI, fsys afero.Fs) error {
version, err := api.GetStorageVersion(ctx)
if err != nil {
Expand All @@ -177,6 +187,17 @@ func linkStorageVersion(ctx context.Context, api tenant.TenantAPI, fsys afero.Fs
return utils.WriteFile(utils.StorageVersionPath, []byte(version), fsys)
}

func linkDatabaseSettings(ctx context.Context, projectRef string) error {
resp, err := utils.GetSupabase().V1GetPostgresConfigWithResponse(ctx, projectRef)
if err != nil {
return errors.Errorf("failed to read DB config: %w", err)
} else if resp.JSON200 == nil {
return errors.Errorf("unexpected DB config status %d: %s", resp.StatusCode(), string(resp.Body))
}
utils.Config.Db.Settings.FromRemotePostgresConfig(*resp.JSON200)
return nil
}

func linkDatabase(ctx context.Context, config pgconn.Config, options ...func(*pgx.ConnConfig)) error {
conn, err := utils.ConnectByConfig(ctx, config, options...)
if err != nil {
Expand All @@ -191,14 +212,6 @@ func linkDatabase(ctx context.Context, config pgconn.Config, options ...func(*pg
return migration.CreateSeedTable(ctx, conn)
}

func linkDatabaseVersion(ctx context.Context, projectRef string, fsys afero.Fs) error {
version, err := tenant.GetDatabaseVersion(ctx, projectRef)
if err != nil {
return err
}
return utils.WriteFile(utils.PostgresVersionPath, []byte(version), fsys)
}

func updatePostgresConfig(conn *pgx.Conn) {
serverVersion := conn.PgConn().ParameterStatus("server_version")
// Safe to assume that supported Postgres version is 10.0 <= n < 100.0
Expand Down
4 changes: 2 additions & 2 deletions pkg/config/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ func (a *api) ToUpdatePostgrestConfigBody() v1API.UpdatePostgrestConfigBody {
return body
}

func (a *api) fromRemoteApiConfig(remoteConfig v1API.PostgrestConfigWithJWTSecretResponse) {
func (a *api) FromRemoteApiConfig(remoteConfig v1API.PostgrestConfigWithJWTSecretResponse) {
if a.Enabled = len(remoteConfig.DbSchema) > 0; !a.Enabled {
return
}
Expand Down Expand Up @@ -84,7 +84,7 @@ func (a *api) DiffWithRemote(remoteConfig v1API.PostgrestConfigWithJWTSecretResp
if err != nil {
return nil, err
}
copy.fromRemoteApiConfig(remoteConfig)
copy.FromRemoteApiConfig(remoteConfig)
remoteCompare, err := ToTomlBytes(copy)
if err != nil {
return nil, err
Expand Down
8 changes: 4 additions & 4 deletions pkg/config/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ func (a *auth) ToUpdateAuthConfigBody() v1API.UpdateAuthConfigBody {
return body
}

func (a *auth) fromRemoteAuthConfig(remoteConfig v1API.AuthConfigResponse) {
func (a *auth) FromRemoteAuthConfig(remoteConfig v1API.AuthConfigResponse) {
a.SiteUrl = cast.Val(remoteConfig.SiteUrl, "")
a.AdditionalRedirectUrls = strToArr(cast.Val(remoteConfig.UriAllowList, ""))
a.JwtExpiry = cast.IntToUint(cast.Val(remoteConfig.JwtExp, 0))
Expand Down Expand Up @@ -775,13 +775,13 @@ func (e external) fromAuthConfig(remoteConfig v1API.AuthConfigResponse) {

func (a *auth) DiffWithRemote(projectRef string, remoteConfig v1API.AuthConfigResponse) ([]byte, error) {
copy := a.Clone()
copy.hashSecrets(projectRef)
copy.HashSecrets(projectRef)
// Convert the config values into easily comparable remoteConfig values
currentValue, err := ToTomlBytes(copy)
if err != nil {
return nil, err
}
copy.fromRemoteAuthConfig(remoteConfig)
copy.FromRemoteAuthConfig(remoteConfig)
remoteCompare, err := ToTomlBytes(copy)
if err != nil {
return nil, err
Expand All @@ -791,7 +791,7 @@ func (a *auth) DiffWithRemote(projectRef string, remoteConfig v1API.AuthConfigRe

const hashPrefix = "hash:"

func (a *auth) hashSecrets(key string) {
func (a *auth) HashSecrets(key string) {
hash := func(v string) string {
return hashPrefix + sha256Hmac(key, v)
}
Expand Down
4 changes: 2 additions & 2 deletions pkg/config/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ func (a *settings) ToUpdatePostgresConfigBody() v1API.UpdatePostgresConfigBody {
return body
}

func (a *settings) fromRemoteConfig(remoteConfig v1API.PostgresConfigResponse) {
func (a *settings) FromRemotePostgresConfig(remoteConfig v1API.PostgresConfigResponse) {
a.EffectiveCacheSize = remoteConfig.EffectiveCacheSize
a.LogicalDecodingWorkMem = remoteConfig.LogicalDecodingWorkMem
a.MaintenanceWorkMem = remoteConfig.MaintenanceWorkMem
Expand Down Expand Up @@ -155,7 +155,7 @@ func (a *settings) DiffWithRemote(remoteConfig v1API.PostgresConfigResponse) ([]
if err != nil {
return nil, err
}
copy.fromRemoteConfig(remoteConfig)
copy.FromRemotePostgresConfig(remoteConfig)
remoteCompare, err := ToTomlBytes(copy)
if err != nil {
return nil, err
Expand Down
4 changes: 2 additions & 2 deletions pkg/config/storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ func (s *storage) ToUpdateStorageConfigBody() v1API.UpdateStorageConfigBody {
return body
}

func (s *storage) fromRemoteStorageConfig(remoteConfig v1API.StorageConfigResponse) {
func (s *storage) FromRemoteStorageConfig(remoteConfig v1API.StorageConfigResponse) {
s.FileSizeLimit = sizeInBytes(remoteConfig.FileSizeLimit)
s.ImageTransformation.Enabled = remoteConfig.Features.ImageTransformation.Enabled
}
Expand All @@ -56,7 +56,7 @@ func (s *storage) DiffWithRemote(remoteConfig v1API.StorageConfigResponse) ([]by
if err != nil {
return nil, err
}
copy.fromRemoteStorageConfig(remoteConfig)
copy.FromRemoteStorageConfig(remoteConfig)
remoteCompare, err := ToTomlBytes(copy)
if err != nil {
return nil, err
Expand Down

0 comments on commit f4b5ad4

Please sign in to comment.