From 9982540d99b63cecaca9af2d8fae51133d6cffd7 Mon Sep 17 00:00:00 2001 From: Robbie Cronin Date: Tue, 26 Nov 2024 05:03:33 +1100 Subject: [PATCH] Refactor guacgql command to support out-of-tree backends (#2247) Signed-off-by: robert-cronin Co-authored-by: Parth Patel <88045217+pxp928@users.noreply.github.com> --- cmd/guacgql/cmd/ent.go | 55 --------- cmd/guacgql/cmd/root.go | 77 +++---------- cmd/guacgql/cmd/server.go | 107 +++--------------- pkg/assembler/backends/arangodb/backend.go | 34 +++++- pkg/assembler/backends/ent/backend/backend.go | 47 ++++++++ .../backends/ent/backend/register.go | 2 +- pkg/assembler/backends/keyvalue/backend.go | 54 ++++++++- .../assembler/backends/keyvalue}/tikv.go | 2 +- pkg/assembler/backends/neo4j/backend.go | 37 +++++- pkg/assembler/backends/neptune/neptune.go | 46 +++++++- pkg/assembler/backends/register.go | 38 ++++++- pkg/cli/store.go | 26 ----- 12 files changed, 281 insertions(+), 244 deletions(-) delete mode 100644 cmd/guacgql/cmd/ent.go rename {cmd/guacgql/cmd => pkg/assembler/backends/keyvalue}/tikv.go (98%) diff --git a/cmd/guacgql/cmd/ent.go b/cmd/guacgql/cmd/ent.go deleted file mode 100644 index b3685d6809..0000000000 --- a/cmd/guacgql/cmd/ent.go +++ /dev/null @@ -1,55 +0,0 @@ -// -// Copyright 2023 The GUAC Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -//go:build !(386 || arm || mips) - -package cmd - -import ( - "context" - "fmt" - "os" - "time" - - "github.com/guacsec/guac/pkg/assembler/backends" - entbackend "github.com/guacsec/guac/pkg/assembler/backends/ent/backend" -) - -func init() { - if getOpts == nil { - getOpts = make(map[string]optsFunc) - } - getOpts[ent] = getEnt -} - -func getEnt(_ context.Context) backends.BackendArgs { - var connTimeout *time.Duration - if flags.dbConnTime != "" { - if timeout, err := time.ParseDuration(flags.dbConnTime); err != nil { - fmt.Printf("failed to parser duration with error: %v", err) - os.Exit(1) - } else { - connTimeout = &timeout - } - } - - return &entbackend.BackendOptions{ - DriverName: flags.dbDriver, - Address: flags.dbAddress, - Debug: flags.dbDebug, - AutoMigrate: flags.dbMigrate, - ConnectionMaxLifeTime: connTimeout, - } -} diff --git a/cmd/guacgql/cmd/root.go b/cmd/guacgql/cmd/root.go index 547c83b57d..87001e515b 100644 --- a/cmd/guacgql/cmd/root.go +++ b/cmd/guacgql/cmd/root.go @@ -20,6 +20,7 @@ import ( "os" "strings" + "github.com/guacsec/guac/pkg/assembler/backends" "github.com/guacsec/guac/pkg/cli" "github.com/guacsec/guac/pkg/version" "github.com/spf13/cobra" @@ -34,35 +35,6 @@ var flags = struct { tlsKeyFile string debug bool tracegql bool - - // Needed only if using neo4j backend - nAddr string - nUser string - nPass string - nRealm string - - // Needed only if using ent backend - dbAddress string - dbDriver string - dbDebug bool - dbMigrate bool - dbConnTime string - - // Needed only if using arangodb backend - arangoAddr string - arangoUser string - arangoPass string - - // Needed only if using neptune backend - neptuneEndpoint string - neptunePort int - neptuneRegion string - neptuneUser string - neptuneRealm string - - kvStore string - kvRedis string - kvTiKV string }{} var rootCmd = &cobra.Command{ @@ -77,31 +49,6 @@ var rootCmd = &cobra.Command{ flags.debug = viper.GetBool("gql-debug") flags.tracegql = viper.GetBool("gql-trace") - flags.nUser = viper.GetString("neo4j-user") - flags.nPass = viper.GetString("neo4j-pass") - flags.nAddr = viper.GetString("neo4j-addr") - flags.nRealm = viper.GetString("neo4j-realm") - - // Needed only if using ent backend - flags.dbAddress = viper.GetString("db-address") - flags.dbDriver = viper.GetString("db-driver") - flags.dbDebug = viper.GetBool("db-debug") - flags.dbMigrate = viper.GetBool("db-migrate") - flags.dbConnTime = viper.GetString("db-conn-time") - - flags.arangoUser = viper.GetString("arango-user") - flags.arangoPass = viper.GetString("arango-pass") - flags.arangoAddr = viper.GetString("arango-addr") - - flags.neptuneEndpoint = viper.GetString("neptune-endpoint") - flags.neptunePort = viper.GetInt("neptune-port") - flags.neptuneRegion = viper.GetString("neptune-region") - flags.neptuneUser = viper.GetString("neptune-user") - flags.neptuneRealm = viper.GetString("neptune-realm") - - flags.kvStore = viper.GetString("kv-store") - flags.kvRedis = viper.GetString("kv-redis") - flags.kvTiKV = viper.GetString("kv-tikv") startServer(cmd) }, } @@ -109,19 +56,29 @@ var rootCmd = &cobra.Command{ func init() { cobra.OnInitialize(cli.InitConfig) + // Register common flags set, err := cli.BuildFlags([]string{ - "arango-addr", "arango-user", "arango-pass", - "neo4j-addr", "neo4j-user", "neo4j-pass", "neo4j-realm", - "neptune-endpoint", "neptune-port", "neptune-region", "neptune-user", "neptune-realm", - "gql-listen-port", "gql-tls-cert-file", "gql-tls-key-file", "gql-debug", "gql-backend", "gql-trace", - "db-address", "db-driver", "db-debug", "db-migrate", "db-conn-time", - "kv-store", "kv-redis", "kv-tikv", "enable-prometheus", + "gql-listen-port", + "gql-tls-cert-file", + "gql-tls-key-file", + "gql-debug", + "gql-backend", + "gql-trace", + "enable-prometheus", }) if err != nil { fmt.Fprintf(os.Stderr, "failed to setup flag: %v", err) os.Exit(1) } rootCmd.Flags().AddFlagSet(set) + + // Register backend-specific flags + err = backends.RegisterFlags(rootCmd) + if err != nil { + fmt.Fprintf(os.Stderr, "failed to register backend flags: %v", err) + os.Exit(1) + } + if err := viper.BindPFlags(rootCmd.Flags()); err != nil { fmt.Fprintf(os.Stderr, "failed to bind flags: %v", err) os.Exit(1) diff --git a/cmd/guacgql/cmd/server.go b/cmd/guacgql/cmd/server.go index 845d5aea76..4339ab583f 100644 --- a/cmd/guacgql/cmd/server.go +++ b/cmd/guacgql/cmd/server.go @@ -26,47 +26,24 @@ import ( "syscall" "time" - "github.com/guacsec/guac/pkg/version" + // import all known backends + _ "github.com/guacsec/guac/pkg/assembler/backends/neo4j" + _ "github.com/guacsec/guac/pkg/assembler/backends/neptune" + _ "github.com/guacsec/guac/pkg/assembler/backends/ent/backend" + _ "github.com/guacsec/guac/pkg/assembler/backends/keyvalue" + _ "github.com/guacsec/guac/pkg/assembler/backends/arangodb" "github.com/99designs/gqlgen/graphql/handler/debug" "github.com/99designs/gqlgen/graphql/playground" "github.com/guacsec/guac/pkg/assembler/backends" - "github.com/guacsec/guac/pkg/assembler/backends/arangodb" - _ "github.com/guacsec/guac/pkg/assembler/backends/keyvalue" - "github.com/guacsec/guac/pkg/assembler/backends/neo4j" - "github.com/guacsec/guac/pkg/assembler/backends/neptune" - "github.com/guacsec/guac/pkg/assembler/kv" - "github.com/guacsec/guac/pkg/assembler/kv/redis" "github.com/guacsec/guac/pkg/assembler/server" "github.com/guacsec/guac/pkg/logging" "github.com/guacsec/guac/pkg/metrics" + "github.com/guacsec/guac/pkg/version" "github.com/spf13/cobra" "github.com/spf13/viper" - "golang.org/x/exp/maps" ) -const ( - arango = "arango" - neo4js = "neo4j" - ent = "ent" - neptunes = "neptune" - keyvalue = "keyvalue" -) - -type optsFunc func(context.Context) backends.BackendArgs - -var getOpts map[string]optsFunc - -func init() { - if getOpts == nil { - getOpts = make(map[string]optsFunc) - } - getOpts[arango] = getArango - getOpts[neo4js] = getNeo4j - getOpts[neptunes] = getNeptune - getOpts[keyvalue] = getKeyValue -} - func startServer(cmd *cobra.Command) { var srvHandler http.Handler ctx := logging.WithLogger(context.Background()) @@ -78,9 +55,15 @@ func startServer(cmd *cobra.Command) { os.Exit(1) } - backend, err := backends.Get(flags.backend, ctx, getOpts[flags.backend](ctx)) + backendArgs, err := backends.GetBackendArgs(ctx, flags.backend) if err != nil { - logger.Errorf("error creating %v backend: %w", flags.backend, err) + logger.Errorf("failed to parse backend flags with error: %v", err) + os.Exit(1) + } + + backend, err := backends.Get(flags.backend, ctx, backendArgs) + if err != nil { + logger.Errorf("Error creating %v backend: %v", flags.backend, err) os.Exit(1) } @@ -161,15 +144,9 @@ func setupPrometheus(ctx context.Context, name string) (metrics.MetricCollector, } func validateFlags() error { - if !slices.Contains(maps.Keys(getOpts), flags.backend) { - return fmt.Errorf("invalid graphql backend specified: %v", flags.backend) - } if !slices.Contains(backends.List(), flags.backend) { return fmt.Errorf("invalid graphql backend specified: %v", flags.backend) } - if !slices.Contains([]string{"memmap", "redis", "tikv"}, flags.kvStore) { - return fmt.Errorf("invalid kv store specified: %v", flags.kvStore) - } return nil } @@ -183,57 +160,3 @@ func versionHandler(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) _, _ = fmt.Fprint(w, version.Version) } - -func getArango(_ context.Context) backends.BackendArgs { - return &arangodb.ArangoConfig{ - User: flags.arangoUser, - Pass: flags.arangoPass, - DBAddr: flags.arangoAddr, - } -} - -func getNeo4j(_ context.Context) backends.BackendArgs { - return &neo4j.Neo4jConfig{ - User: flags.nUser, - Pass: flags.nPass, - Realm: flags.nRealm, - DBAddr: flags.nAddr, - } -} - -var tikvGS func(context.Context, string) (kv.Store, error) - -func getKeyValue(ctx context.Context) backends.BackendArgs { - logger := logging.FromContext(ctx) - switch flags.kvStore { - case "memmap": - // default is memmap - return nil - case "redis": - s, err := redis.GetStore(flags.kvRedis) - if err != nil { - logger.Fatalf("error with Redis: %v", err) - } - return s - case "tikv": - if tikvGS == nil { - logger.Fatal("TiKV not supported on 32-bit") - } - s, err := tikvGS(ctx, flags.kvTiKV) - if err != nil { - logger.Fatalf("error with TiKV: %v", err) - } - return s - } - return nil -} - -func getNeptune(_ context.Context) backends.BackendArgs { - return &neptune.NeptuneConfig{ - Endpoint: flags.neptuneEndpoint, - Port: flags.neptunePort, - Region: flags.neptuneRegion, - User: flags.neptuneUser, - Realm: flags.neptuneRealm, - } -} diff --git a/pkg/assembler/backends/arangodb/backend.go b/pkg/assembler/backends/arangodb/backend.go index f81050942e..5e74aa3977 100644 --- a/pkg/assembler/backends/arangodb/backend.go +++ b/pkg/assembler/backends/arangodb/backend.go @@ -23,6 +23,8 @@ import ( "time" "github.com/99designs/gqlgen/graphql" + "github.com/spf13/cobra" + "github.com/spf13/viper" jsoniter "github.com/json-iterator/go" @@ -42,6 +44,13 @@ type ArangoConfig struct { TestData bool } +// flags holds the command-line flags for ArangoDB configuration +var flags = struct { + addr string + user string + pass string +}{} + type arangoQueryBuilder struct { query strings.Builder } @@ -59,7 +68,7 @@ type index struct { } func init() { - backends.Register("arango", getBackend) + backends.Register("arango", getBackend, registerFlags, parseFlags) } func initIndex(name string, fields []string, unique bool) index { @@ -117,6 +126,29 @@ func DeleteDatabase(ctx context.Context, args backends.BackendArgs) error { return nil } +// registerFlags registers ArangoDB-specific command line flags +func registerFlags(cmd *cobra.Command) error { + flagSet := cmd.Flags() + flagSet.StringVar(&flags.addr, "arango-addr", "http://localhost:8529", "address to arango db") + flagSet.StringVar(&flags.user, "arango-user", "", "arango user to connect to graph db") + flagSet.StringVar(&flags.pass, "arango-pass", "", "arango password to connect to graph db") + + if err := viper.BindPFlags(flagSet); err != nil { + return fmt.Errorf("failed to bind flags: %w", err) + } + + return nil +} + +// parseFlags returns the ArangoDB configuration from parsed flags +func parseFlags(ctx context.Context) (backends.BackendArgs, error) { + return &ArangoConfig{ + DBAddr: flags.addr, + User: flags.user, + Pass: flags.pass, + }, nil +} + func getBackend(ctx context.Context, args backends.BackendArgs) (backends.Backend, error) { config, ok := args.(*ArangoConfig) if !ok { diff --git a/pkg/assembler/backends/ent/backend/backend.go b/pkg/assembler/backends/ent/backend/backend.go index a4011642bd..4141ca2770 100644 --- a/pkg/assembler/backends/ent/backend/backend.go +++ b/pkg/assembler/backends/ent/backend/backend.go @@ -18,9 +18,12 @@ package backend import ( "context" "fmt" + "time" "github.com/guacsec/guac/pkg/assembler/backends" "github.com/guacsec/guac/pkg/assembler/backends/ent" + "github.com/spf13/cobra" + "github.com/spf13/viper" "github.com/vektah/gqlparser/v2/gqlerror" // Import regular postgres driver @@ -38,6 +41,50 @@ type EntBackend struct { client *ent.Client } +// flags holds the command-line flags for Ent configuration +var flags = struct { + dbAddress string + dbDriver string + dbDebug bool + dbMigrate bool + dbConnTime string +}{} + +// registerFlags registers Ent-specific command line flags +func registerFlags(cmd *cobra.Command) error { + flagSet := cmd.Flags() + flagSet.StringVar(&flags.dbAddress, "db-address", "postgres://guac:guac@0.0.0.0:5432/guac?sslmode=disable", "Full URL of database to connect to") + flagSet.StringVar(&flags.dbDriver, "db-driver", "postgres", "database driver to use, one of [postgres | sqlite3 | mysql] or anything supported by sql.DB") + flagSet.BoolVar(&flags.dbDebug, "db-debug", false, "enable debug logging for database queries") + flagSet.BoolVar(&flags.dbMigrate, "db-migrate", true, "automatically run database migrations on start") + flagSet.StringVar(&flags.dbConnTime, "db-conn-time", "", "sets the maximum amount of time a connection may be reused in m, h, s, etc.") + + if err := viper.BindPFlags(flagSet); err != nil { + return fmt.Errorf("failed to bind flags: %w", err) + } + + return nil +} + +// parseFlags returns the Ent configuration from parsed flags +func parseFlags(ctx context.Context) (backends.BackendArgs, error) { + var connTimeout *time.Duration + if flags.dbConnTime != "" { + if timeout, err := time.ParseDuration(flags.dbConnTime); err == nil { + connTimeout = &timeout + } else { + return nil, fmt.Errorf("failed to parse duration with error: %w", err) + } + } + return &BackendOptions{ + DriverName: flags.dbDriver, + Address: flags.dbAddress, + Debug: flags.dbDebug, + AutoMigrate: flags.dbMigrate, + ConnectionMaxLifeTime: connTimeout, + }, nil +} + func getBackend(ctx context.Context, args backends.BackendArgs) (backends.Backend, error) { config, ok := args.(*BackendOptions) if !ok { diff --git a/pkg/assembler/backends/ent/backend/register.go b/pkg/assembler/backends/ent/backend/register.go index 5daf384675..a0b134e075 100644 --- a/pkg/assembler/backends/ent/backend/register.go +++ b/pkg/assembler/backends/ent/backend/register.go @@ -20,5 +20,5 @@ package backend import "github.com/guacsec/guac/pkg/assembler/backends" func init() { - backends.Register("ent", getBackend) + backends.Register("ent", getBackend, registerFlags, parseFlags) } diff --git a/pkg/assembler/backends/keyvalue/backend.go b/pkg/assembler/backends/keyvalue/backend.go index 05f08f659d..0be6c1362d 100644 --- a/pkg/assembler/backends/keyvalue/backend.go +++ b/pkg/assembler/backends/keyvalue/backend.go @@ -31,10 +31,62 @@ import ( "github.com/guacsec/guac/pkg/assembler/graphql/model" "github.com/guacsec/guac/pkg/assembler/kv" "github.com/guacsec/guac/pkg/assembler/kv/memmap" + "github.com/guacsec/guac/pkg/assembler/kv/redis" + "github.com/spf13/cobra" + "github.com/spf13/viper" ) +// flags holds the command-line flags for KeyValue configuration +var flags = struct { + kvStore string + kvRedis string + kvTiKV string +}{} + +// registerFlags registers KeyValue-specific command line flags +func registerFlags(cmd *cobra.Command) error { + flagSet := cmd.Flags() + flagSet.StringVar(&flags.kvStore, "kv-store", "memmap", "Which keyvalue store to use: memmap, redis, tikv.") + flagSet.StringVar(&flags.kvRedis, "kv-redis", "redis://user@localhost:6379/0", "Experimental: Redis connection string for keyvalue backend") + flagSet.StringVar(&flags.kvTiKV, "kv-tikv", "127.0.0.1:2379", "Experimental: TiKV address and port") + + if err := viper.BindPFlags(flagSet); err != nil { + return fmt.Errorf("failed to bind flags: %w", err) + } + + return nil +} + +var tikvGS func(context.Context, string) (kv.Store, error) + +// parseFlags returns the KeyValue store configuration from parsed flags +func parseFlags(ctx context.Context) (backends.BackendArgs, error) { + switch flags.kvStore { + case "memmap": + // default is memmap + return nil, nil + case "redis": + s, err := redis.GetStore(flags.kvRedis) + if err != nil { + return nil, fmt.Errorf("error with Redis: %w", err) + } + return s, nil + case "tikv": + if tikvGS == nil { + return nil, fmt.Errorf("TiKV not supported on 32-bit") + } + s, err := tikvGS(ctx, flags.kvTiKV) + if err != nil { + return nil, fmt.Errorf("error with TiKV: %w", err) + } + return s, nil + } + // default is memmap + return nil, fmt.Errorf("invalid kv store specified: %v", flags.kvStore) +} + func init() { - backends.Register("keyvalue", getBackend) + backends.Register("keyvalue", getBackend, registerFlags, parseFlags) } // node is the common interface of all backend nodes. diff --git a/cmd/guacgql/cmd/tikv.go b/pkg/assembler/backends/keyvalue/tikv.go similarity index 98% rename from cmd/guacgql/cmd/tikv.go rename to pkg/assembler/backends/keyvalue/tikv.go index 52fcfa8fb4..2e7337bc9b 100644 --- a/cmd/guacgql/cmd/tikv.go +++ b/pkg/assembler/backends/keyvalue/tikv.go @@ -15,7 +15,7 @@ //go:build !(386 || arm || mips || darwin) -package cmd +package keyvalue import "github.com/guacsec/guac/pkg/assembler/kv/tikv" diff --git a/pkg/assembler/backends/neo4j/backend.go b/pkg/assembler/backends/neo4j/backend.go index 9bc353b077..9ddb719cbb 100644 --- a/pkg/assembler/backends/neo4j/backend.go +++ b/pkg/assembler/backends/neo4j/backend.go @@ -23,6 +23,8 @@ import ( "github.com/guacsec/guac/pkg/assembler/backends" "github.com/guacsec/guac/pkg/assembler/graphql/model" "github.com/neo4j/neo4j-go-driver/v4/neo4j" + "github.com/spf13/cobra" + "github.com/spf13/viper" ) const ( @@ -50,8 +52,41 @@ type neo4jClient struct { driver neo4j.Driver } +// flags holds the command-line flags for Neo4j configuration +var flags = struct { + addr string + user string + pass string + realm string +}{} + func init() { - backends.Register("neo4j", getBackend) + backends.Register("neo4j", getBackend, registerFlags, parseFlags) +} + +// registerFlags registers Neo4j-specific command line flags +func registerFlags(cmd *cobra.Command) error { + flagSet := cmd.Flags() + flagSet.StringVar(&flags.addr, "neo4j-addr", "neo4j://localhost:7687", "address to neo4j db") + flagSet.StringVar(&flags.user, "neo4j-user", "", "neo4j user credential to connect to graph db") + flagSet.StringVar(&flags.pass, "neo4j-pass", "", "neo4j password credential to connect to graph db") + flagSet.StringVar(&flags.realm, "neo4j-realm", "neo4j", "realm to connect to graph db") + + if err := viper.BindPFlags(flagSet); err != nil { + return fmt.Errorf("failed to bind flags: %w", err) + } + + return nil +} + +// parseFlags returns the Neo4j configuration from parsed flags +func parseFlags(ctx context.Context) (backends.BackendArgs, error) { + return &Neo4jConfig{ + DBAddr: flags.addr, + User: flags.user, + Pass: flags.pass, + Realm: flags.realm, + }, nil } func getBackend(_ context.Context, args backends.BackendArgs) (backends.Backend, error) { diff --git a/pkg/assembler/backends/neptune/neptune.go b/pkg/assembler/backends/neptune/neptune.go index 899ab2eb16..ac445f6307 100644 --- a/pkg/assembler/backends/neptune/neptune.go +++ b/pkg/assembler/backends/neptune/neptune.go @@ -29,16 +29,14 @@ import ( v4 "github.com/aws/aws-sdk-go/aws/signer/v4" "github.com/guacsec/guac/pkg/assembler/backends" "github.com/guacsec/guac/pkg/assembler/backends/neo4j" + "github.com/spf13/cobra" + "github.com/spf13/viper" ) var json = jsoniter.ConfigCompatibleWithStandardLibrary const neptuneServiceName = "neptune-db" -func init() { - backends.Register("neptune", getBackend) -} - type NeptuneConfig struct { Endpoint string Port int @@ -47,6 +45,46 @@ type NeptuneConfig struct { Realm string } +// flags holds the command-line flags for Neptune configuration +var flags = struct { + endpoint string + port int + region string + user string + realm string +}{} + +func init() { + backends.Register("neptune", getBackend, registerFlags, parseFlags) +} + +// registerFlags registers Neptune-specific command line flags +func registerFlags(cmd *cobra.Command) error { + flagSet := cmd.Flags() + flagSet.StringVar(&flags.endpoint, "neptune-endpoint", "localhost", "address to neptune db") + flagSet.IntVar(&flags.port, "neptune-port", 8182, "port used for neptune db connection") + flagSet.StringVar(&flags.region, "neptune-region", "us-east-1", "region to connect to neptune db") + flagSet.StringVar(&flags.user, "neptune-user", "", "neptune user credential to connect to graph db") + flagSet.StringVar(&flags.realm, "neptune-realm", "neptune", "realm to connect to graph db") + + if err := viper.BindPFlags(flagSet); err != nil { + return fmt.Errorf("failed to bind flags: %w", err) + } + + return nil +} + +// parseFlags returns the Neptune configuration from parsed flags +func parseFlags(ctx context.Context) (backends.BackendArgs, error) { + return &NeptuneConfig{ + Endpoint: flags.endpoint, + Port: flags.port, + Region: flags.region, + User: flags.user, + Realm: flags.realm, + }, nil +} + func getBackend(ctx context.Context, args backends.BackendArgs) (backends.Backend, error) { config, ok := args.(*NeptuneConfig) if !ok { diff --git a/pkg/assembler/backends/register.go b/pkg/assembler/backends/register.go index ce9e2bde84..e679ea7469 100644 --- a/pkg/assembler/backends/register.go +++ b/pkg/assembler/backends/register.go @@ -17,20 +17,54 @@ package backends import ( "context" + "fmt" + "github.com/spf13/cobra" "golang.org/x/exp/maps" ) type GBFunc func(context.Context, BackendArgs) (Backend, error) +type FlagRegistrarFunc func(*cobra.Command) error +type FlagParserFunc func(ctx context.Context) (BackendArgs, error) -var getBackend map[string]GBFunc +var ( + getBackend map[string]GBFunc + flagRegistrar map[string]FlagRegistrarFunc + flagParser map[string]FlagParserFunc +) func init() { getBackend = make(map[string]GBFunc) + flagRegistrar = make(map[string]FlagRegistrarFunc) + flagParser = make(map[string]FlagParserFunc) } -func Register(name string, gb GBFunc) { +// Register registers a backend with its flag handling functions +func Register(name string, gb GBFunc, fr FlagRegistrarFunc, fp FlagParserFunc) { getBackend[name] = gb + flagRegistrar[name] = fr + flagParser[name] = fp +} + +// RegisterFlags registers all backend-specific flags to the given command +func RegisterFlags(cmd *cobra.Command) error { + var err error + for _, register := range flagRegistrar { + err = register(cmd) + if err != nil { + return err + } + } + + return nil +} + +// GetBackendArgs returns the parsed backend arguments for the given backend +func GetBackendArgs(ctx context.Context, name string) (BackendArgs, error) { + if parser, ok := flagParser[name]; ok { + return parser(ctx) + } + return nil, fmt.Errorf("backend %s not found", name) } func Get(name string, ctx context.Context, args BackendArgs) (Backend, error) { diff --git a/pkg/cli/store.go b/pkg/cli/store.go index 33e1da0869..20745123bb 100644 --- a/pkg/cli/store.go +++ b/pkg/cli/store.go @@ -51,11 +51,6 @@ func init() { set.Bool("gql-debug", false, "debug flag which enables the graphQL playground") set.Bool("gql-trace", false, "flag which enables tracing of graphQL requests and responses on the console") - set.String("neo4j-addr", "neo4j://localhost:7687", "address to neo4j db") - set.String("neo4j-user", "", "neo4j user credential to connect to graph db") - set.String("neo4j-pass", "", "neo4j password credential to connect to graph db") - set.String("neo4j-realm", "neo4j", "realm to connect to graph db") - // blob store address set.String("blob-addr", "file:///tmp/blobstore?no_tmp_dir=true", "gocloud connection string for blob store configured via https://gocloud.dev/howto/blob/ (default: filesystem)") @@ -75,22 +70,6 @@ func init() { // the ingestor will query and ingest endoflife.date for EOL set.Bool("add-eol-on-ingest", false, "if enabled, the ingestor will query and ingest endoflife.date for EOL data. Warning: This will increase ingestion times") - set.String("neptune-endpoint", "localhost", "address to neptune db") - set.Int("neptune-port", 8182, "port used for neptune db connection") - set.String("neptune-region", "us-east-1", "region to connect to neptune db") - set.String("neptune-user", "", "neptune user credential to connect to graph db") - set.String("neptune-realm", "neptune", "realm to connect to graph db") - - set.String("db-address", "postgres://guac:guac@0.0.0.0:5432/guac?sslmode=disable", "Full URL of database to connect to") - set.String("db-driver", "postgres", "database driver to use, one of [postgres | sqlite3 | mysql] or anything supported by sql.DB") - set.Bool("db-debug", false, "enable debug logging for database queries") - set.Bool("db-migrate", true, "automatically run database migrations on start") - set.String("db-conn-time", "", "sets the maximum amount of time a connection may be reused in m, h, s, etc.") - - set.String("arango-addr", "http://localhost:8529", "address to arango db") - set.String("arango-user", "", "arango user to connect to graph db") - set.String("arango-pass", "", "arango password to connect to graph db") - set.String("gql-addr", "http://localhost:8080/query", "endpoint used to connect to graphQL server") set.String("rest-api-server-port", "8081", "port to serve the REST API from") @@ -151,11 +130,6 @@ func init() { set.String("s3-queues", "", "comma-separated list of queue/topic names") set.String("s3-region", "us-east-1", "aws region") - // KeyValue Backend Store options. - set.String("kv-store", "memmap", "Which keyvalue store to use: memmap, redis, tikv.") - set.String("kv-redis", "redis://user@localhost:6379/0", "Experimental: Redis connection string for keyvalue backend") - set.String("kv-tikv", "127.0.0.1:2379", "Experimental: TiKV address and port") - // GitHub collector options set.String("github-mode", "release", "mode to run github collector in: [release | workflow]") set.String("github-sbom", "", "name of sbom file to look for in github release.")