From be654f52742e2a167ce36379e9d82b3440278c81 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mihkel=20P=C3=A4rna?= Date: Mon, 12 Feb 2024 14:56:42 +0200 Subject: [PATCH] Wrap redis client creation to a separate function to return different implementations based on 'enable-tls' flag MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Mihkel Pärna --- cmd/backfill-redis/main.go | 40 +++++++++++++++++++++++++----------- cmd/rekor-server/app/root.go | 3 ++- pkg/api/api.go | 29 ++++++++++++++++++-------- 3 files changed, 51 insertions(+), 21 deletions(-) diff --git a/cmd/backfill-redis/main.go b/cmd/backfill-redis/main.go index 5c37626a0..b6eeebe69 100644 --- a/cmd/backfill-redis/main.go +++ b/cmd/backfill-redis/main.go @@ -70,7 +70,8 @@ var ( redisPassword = flag.String("password", "", "Password for Redis authentication") startIndex = flag.Int("start", -1, "First index to backfill") endIndex = flag.Int("end", -1, "Last index to backfill") - insecureSkipVerify = flag.Bool("insecure-skip-verify", false, "Whether to skip TLS verification or not") + enableTls = flag.Bool("enable-tls", false, "Enable TLS for Redis client") + insecureSkipVerify = flag.Bool("insecure-skip-verify", false, "Whether to skip TLS verification for Redis client or not") rekorAddress = flag.String("rekor-address", "", "Address for Rekor, e.g. https://rekor.sigstore.dev") versionFlag = flag.Bool("version", false, "Print the current version of Backfill Redis") concurrency = flag.Int("concurrency", 1, "Number of workers to use for backfill") @@ -103,18 +104,8 @@ func main() { } log.Printf("running backfill redis Version: %s GitCommit: %s BuildDate: %s", versionInfo.GitVersion, versionInfo.GitCommit, versionInfo.BuildDate) - // #nosec G402 - tlsConfig := &tls.Config{ - InsecureSkipVerify: *insecureSkipVerify, - } - redisClient := redis.NewClient(&redis.Options{ - Addr: fmt.Sprintf("%s:%s", *redisHostname, *redisPort), - Password: *redisPassword, - Network: "tcp", - TLSConfig: tlsConfig, - DB: 0, // default DB - }) + redisClient := redisClient() rekorClient, err := client.GetRekorClient(*rekorAddress) if err != nil { @@ -217,6 +208,31 @@ func main() { } } +func redisClient() *redis.Client { + + // #nosec G402 + tlsConfig := &tls.Config{ + InsecureSkipVerify: *insecureSkipVerify, + } + + if *enableTls == true { + return redis.NewClient(&redis.Options{ + Addr: fmt.Sprintf("%s:%s", *redisHostname, *redisPort), + Password: *redisPassword, + Network: "tcp", + TLSConfig: tlsConfig, + DB: 0, // default DB + }) + } else { + return redis.NewClient(&redis.Options{ + Addr: fmt.Sprintf("%s:%s", *redisHostname, *redisPort), + Password: *redisPassword, + Network: "tcp", + DB: 0, // default DB + }) + } +} + // unmarshalEntryImpl decodes the base64-encoded entry to a specific entry type (types.EntryImpl). // Taken from Cosign func unmarshalEntryImpl(e string) (types.EntryImpl, string, string, error) { diff --git a/cmd/rekor-server/app/root.go b/cmd/rekor-server/app/root.go index 6bf50ef61..a17054ff9 100644 --- a/cmd/rekor-server/app/root.go +++ b/cmd/rekor-server/app/root.go @@ -115,7 +115,8 @@ Memory and file-based signers should only be used for testing.`) rootCmd.PersistentFlags().String("redis_server.address", "127.0.0.1", "Redis server address") rootCmd.PersistentFlags().Uint16("redis_server.port", 6379, "Redis server port") rootCmd.PersistentFlags().String("redis_server.password", "", "Redis server password") - rootCmd.PersistentFlags().Bool("redis_server.insecure-skip-verify", false, "Whether to skip TLS verification when connecting to Redis endpoint") + rootCmd.PersistentFlags().Bool("redis_server.enable-tls", false, "Whether to enable TLS verification when connecting to Redis endpoint") + rootCmd.PersistentFlags().Bool("redis_server.insecure-skip-verify", false, "Whether to skip TLS verification when connecting to Redis endpoint, only applicable when 'redis_server.enable-tls' is set to 'true'") rootCmd.PersistentFlags().Bool("enable_attestation_storage", false, "enables rich attestation storage") rootCmd.PersistentFlags().String("attestation_storage_bucket", "", "url for attestation storage bucket") diff --git a/pkg/api/api.go b/pkg/api/api.go index 43932bcab..6e6bfc804 100644 --- a/pkg/api/api.go +++ b/pkg/api/api.go @@ -176,7 +176,20 @@ func ConfigureAPI(treeID uint) { } if viper.GetBool("enable_stable_checkpoint") { - redisClient = redis.NewClient(&redis.Options{ + redisClient = NewRedisClient() + checkpointPublisher := witness.NewCheckpointPublisher(context.Background(), api.logClient, api.logRanges.ActiveTreeID(), + viper.GetString("rekor_server.hostname"), api.signer, redisClient, viper.GetUint("publish_frequency"), CheckpointPublishCount) + + // create context to cancel goroutine on server shutdown + ctx, cancel := context.WithCancel(context.Background()) + api.checkpointPublishCancel = cancel + checkpointPublisher.StartPublisher(ctx) + } +} + +func NewRedisClient() *redis.Client { + if viper.GetBool("redis_server.enable-tls") == true { + return redis.NewClient(&redis.Options{ Addr: fmt.Sprintf("%v:%v", viper.GetString("redis_server.address"), viper.GetUint64("redis_server.port")), Password: viper.GetString("redis_server.password"), Network: "tcp", @@ -186,13 +199,13 @@ func ConfigureAPI(treeID uint) { }, DB: 0, // default DB }) - checkpointPublisher := witness.NewCheckpointPublisher(context.Background(), api.logClient, api.logRanges.ActiveTreeID(), - viper.GetString("rekor_server.hostname"), api.signer, redisClient, viper.GetUint("publish_frequency"), CheckpointPublishCount) - - // create context to cancel goroutine on server shutdown - ctx, cancel := context.WithCancel(context.Background()) - api.checkpointPublishCancel = cancel - checkpointPublisher.StartPublisher(ctx) + } else { + return redis.NewClient(&redis.Options{ + Addr: fmt.Sprintf("%v:%v", viper.GetString("redis_server.address"), viper.GetUint64("redis_server.port")), + Password: viper.GetString("redis_server.password"), + Network: "tcp", + DB: 0, // default DB + }) } }