Skip to content

Commit

Permalink
Refactor driver selection to allow registering new drivers
Browse files Browse the repository at this point in the history
Signed-off-by: Sambhav Kothari <skothari44@bloomberg.net>
  • Loading branch information
sambhav committed Sep 25, 2024
1 parent 70a99f8 commit 8ef18f1
Show file tree
Hide file tree
Showing 13 changed files with 343 additions and 150 deletions.
139 changes: 139 additions & 0 deletions pkg/app/app.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
package app

import (
"fmt"
"time"

"github.com/k3s-io/kine/pkg/endpoint"
"github.com/k3s-io/kine/pkg/metrics"
"github.com/k3s-io/kine/pkg/signals"
"github.com/k3s-io/kine/pkg/version"
"github.com/sirupsen/logrus"
"github.com/urfave/cli/v2"
)

var (
config endpoint.Config
metricsConfig metrics.Config
)

func New() *cli.App {
app := cli.NewApp()
app.Name = "kine"
app.Usage = "Minimal etcd v3 API to support custom Kubernetes storage engines"
app.Version = fmt.Sprintf("%s (%s)", version.Version, version.GitCommit)
app.Flags = []cli.Flag{
&cli.StringFlag{
Name: "listen-address",
Value: "0.0.0.0:2379",
Destination: &config.Listener,
},
&cli.StringFlag{
Name: "endpoint",
Usage: "Storage endpoint (default is sqlite)",
Destination: &config.Endpoint,
},
&cli.StringFlag{
Name: "ca-file",
Usage: "CA cert for DB connection",
Destination: &config.BackendTLSConfig.CAFile,
},
&cli.StringFlag{
Name: "cert-file",
Usage: "Certificate for DB connection",
Destination: &config.BackendTLSConfig.CertFile,
},
&cli.StringFlag{
Name: "key-file",
Usage: "Key file for DB connection",
Destination: &config.BackendTLSConfig.KeyFile,
},
&cli.BoolFlag{
Name: "skip-verify",
Usage: "Whether the TLS client should verify the server certificate.",
Destination: &config.BackendTLSConfig.SkipVerify,
Value: false,
},
&cli.StringFlag{
Name: "metrics-bind-address",
Usage: "The address the metric endpoint binds to. Default :8080, set 0 to disable metrics serving.",
Destination: &metricsConfig.ServerAddress,
Value: ":8080",
},
&cli.StringFlag{
Name: "server-cert-file",
Usage: "Certificate for etcd connection",
Destination: &config.ServerTLSConfig.CertFile,
},
&cli.StringFlag{
Name: "server-key-file",
Usage: "Key file for etcd connection",
Destination: &config.ServerTLSConfig.KeyFile,
},
&cli.IntFlag{
Name: "datastore-max-idle-connections",
Usage: "Maximum number of idle connections retained by datastore. If value = 0, the system default will be used. If value < 0, idle connections will not be reused.",
Destination: &config.ConnectionPoolConfig.MaxIdle,
Value: 0,
},
&cli.IntFlag{
Name: "datastore-max-open-connections",
Usage: "Maximum number of open connections used by datastore. If value <= 0, then there is no limit",
Destination: &config.ConnectionPoolConfig.MaxOpen,
Value: 0,
},
&cli.DurationFlag{
Name: "datastore-connection-max-lifetime",
Usage: "Maximum amount of time a connection may be reused. If value <= 0, then there is no limit.",
Destination: &config.ConnectionPoolConfig.MaxLifetime,
Value: 0,
},
&cli.DurationFlag{
Name: "slow-sql-threshold",
Usage: "The duration which SQL executed longer than will be logged. Default 1s, set <= 0 to disable slow SQL log.",
Destination: &metrics.SlowSQLThreshold,
Value: time.Second,
},
&cli.BoolFlag{
Name: "metrics-enable-profiling",
Usage: "Enable net/http/pprof handlers on the metrics bind address. Default is false.",
Destination: &metricsConfig.EnableProfiling,
},
&cli.DurationFlag{
Name: "watch-progress-notify-interval",
Usage: "Interval between periodic watch progress notifications. Default is 5s to ensure support for watch progress notifications.",
Destination: &config.NotifyInterval,
Value: time.Second * 5,
},
&cli.StringFlag{
Name: "emulated-etcd-version",
Usage: "The emulated etcd version to return on a call to the status endpoint. Defaults to 3.5.13, in order to indicate support for watch progress notifications.",
Destination: &config.EmulatedETCDVersion,
Value: "3.5.13",
},
&cli.BoolFlag{Name: "debug"},
}
app.Action = run
return app
}

func run(c *cli.Context) error {
logrus.SetFormatter(&logrus.TextFormatter{
FullTimestamp: true,
TimestampFormat: time.RFC3339Nano,
})
if c.Bool("debug") {
logrus.SetLevel(logrus.TraceLevel)
}
ctx := signals.SetupSignalContext()

metricsConfig.ServerTLSConfig = config.ServerTLSConfig
go metrics.Serve(ctx, metricsConfig)
config.MetricsRegisterer = metrics.Registry
_, err := endpoint.Listen(ctx, config)
if err != nil {
return err
}
<-ctx.Done()
return ctx.Err()
}
10 changes: 6 additions & 4 deletions pkg/drivers/dqlite/dqlite.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ import (
"github.com/k3s-io/kine/pkg/drivers/sqlite"
"github.com/k3s-io/kine/pkg/server"
"github.com/pkg/errors"
"github.com/prometheus/client_golang/prometheus"
"github.com/sirupsen/logrus"
)

Expand All @@ -29,6 +28,7 @@ var (
)

func init() {
generic.RegisterDriver("dqlite", New)
// We assume SQLite will be used multi-threaded
if err := dqlite.ConfigMultiThread(); err != nil {
panic(errors.Wrap(err, "failed to set dqlite multithreaded mode"))
Expand Down Expand Up @@ -69,7 +69,8 @@ outer:
return nil
}

func New(ctx context.Context, datasourceName string, connPoolConfig generic.ConnectionPoolConfig, metricsRegisterer prometheus.Registerer) (server.Backend, error) {
func New(ctx context.Context, cfg *generic.Config) (bool, server.Backend, error) {
dataSourceName = cfg.Address
opts, err := parseOpts(datasourceName)
if err != nil {
return nil, err
Expand Down Expand Up @@ -98,7 +99,8 @@ func New(ctx context.Context, datasourceName string, connPoolConfig generic.Conn
}

sql.Register("dqlite", d)
backend, generic, err := sqlite.NewVariant(ctx, "dqlite", opts.dsn, connPoolConfig, metricsRegisterer)
cfg.Address = opts.dsn
backend, generic, err := sqlite.NewVariant(ctx, "dqlite", cfg)
if err != nil {
return nil, errors.Wrap(err, "sqlite client")
}
Expand All @@ -120,7 +122,7 @@ func New(ctx context.Context, datasourceName string, connPoolConfig generic.Conn
return err
}

return backend, nil
return true, backend, nil
}

func migrate(ctx context.Context, newDB *sql.DB) (exitErr error) {
Expand Down
9 changes: 6 additions & 3 deletions pkg/drivers/dqlite/no_dqlite.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,12 @@ import (

"github.com/k3s-io/kine/pkg/drivers/generic"
"github.com/k3s-io/kine/pkg/server"
"github.com/prometheus/client_golang/prometheus"
)

func New(ctx context.Context, datasourceName string, connPoolConfig generic.ConnectionPoolConfig, metricsRegisterer prometheus.Registerer) (server.Backend, error) {
return nil, errors.New(`this binary is built without dqlite support, compile with "-tags dqlite"`)
func New(ctx context.Context, cfg *generic.Config) (bool, server.Backend, error) {
return false, nil, errors.New(`this binary is built without dqlite support, compile with "-tags dqlite"`)
}

func init() {
generic.RegisterDriver("dqlite", New)
}
45 changes: 45 additions & 0 deletions pkg/drivers/factory.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
package drivers

import (
"context"
"errors"
"strings"

"github.com/k3s-io/kine/pkg/drivers/generic"

"github.com/k3s-io/kine/pkg/server"
"github.com/k3s-io/kine/pkg/util"
)

var ErrUnknownDriver = errors.New("unknown driver")

func New(ctx context.Context, cfg *generic.Config) (leaderElect bool, backend server.Backend, err error) {
if cfg.Endpoint == "" {
driver := generic.GetDefaultDriver()
if driver == nil {
return false, nil, errors.New("no default driver found")
}
return driver(ctx, cfg)
}

if err := validateDSNuri(cfg.Endpoint); err != nil {
return false, nil, err
}

cfg.Scheme, cfg.DataSourceName = util.SchemeAndAddress(cfg.Endpoint)

driver, ok := generic.GetDriver(cfg.Scheme)
if !ok {
return false, nil, ErrUnknownDriver
}
return driver(ctx, cfg)
}

// validateDSNuri ensure that the given string is of that format <scheme://<authority>
func validateDSNuri(str string) error {
parts := strings.SplitN(str, "://", 2)
if len(parts) > 1 {
return nil
}
return errors.New("invalid datastore endpoint; endpoint should be a DSN URI in the format <scheme>://<authority>")
}
35 changes: 35 additions & 0 deletions pkg/drivers/generic/generic.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"github.com/Rican7/retry/strategy"
"github.com/k3s-io/kine/pkg/metrics"
"github.com/k3s-io/kine/pkg/server"
"github.com/k3s-io/kine/pkg/tls"
"github.com/k3s-io/kine/pkg/util"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/collectors"
Expand All @@ -25,6 +26,40 @@ const (
defaultMaxIdleConns = 2 // copied from database/sql
)

type Config struct {
MetricsRegisterer prometheus.Registerer
Endpoint string
Scheme string
DataSourceName string
ConnectionPoolConfig ConnectionPoolConfig
BackendTLSConfig tls.Config
}

type Constructor func(ctx context.Context, cfg *Config) (leaderElect bool, backend server.Backend, err error)

var driverRegistry map[string]Constructor = map[string]Constructor{}
var defaultDriver string

func RegisterDriver(driver string, constructor Constructor) {
driverRegistry[driver] = constructor
}

func SetDefaultDriver(driver string) {
defaultDriver = driver
}

func GetDefaultDriver() Constructor {
return driverRegistry[defaultDriver]
}

func GetDriver(scheme string) (Constructor, bool) {
constructor, ok := driverRegistry[scheme]
if !ok {
return nil, false
}
return constructor, true
}

// explicit interface check
var _ server.Dialect = (*Generic)(nil)

Expand Down
17 changes: 17 additions & 0 deletions pkg/drivers/http/http.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package http

import (
"context"

"github.com/k3s-io/kine/pkg/drivers/generic"
"github.com/k3s-io/kine/pkg/server"
)

func New(ctx context.Context, cfg *generic.Config) (leaderElect bool, backend server.Backend, err error) {
return true, nil, nil
}

func init() {
generic.RegisterDriver("http", New)
generic.RegisterDriver("https", New)
}
26 changes: 14 additions & 12 deletions pkg/drivers/mysql/mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,12 @@ import (
"strconv"

"github.com/go-sql-driver/mysql"
"github.com/prometheus/client_golang/prometheus"
"github.com/sirupsen/logrus"

"github.com/k3s-io/kine/pkg/drivers/generic"
"github.com/k3s-io/kine/pkg/logstructured"
"github.com/k3s-io/kine/pkg/logstructured/sqllog"
"github.com/k3s-io/kine/pkg/server"
"github.com/k3s-io/kine/pkg/tls"
"github.com/k3s-io/kine/pkg/util"
)

Expand Down Expand Up @@ -52,28 +50,28 @@ var (
createDB = "CREATE DATABASE IF NOT EXISTS "
)

func New(ctx context.Context, dataSourceName string, tlsInfo tls.Config, connPoolConfig generic.ConnectionPoolConfig, metricsRegisterer prometheus.Registerer) (server.Backend, error) {
tlsConfig, err := tlsInfo.ClientConfig()
func New(ctx context.Context, cfg *generic.Config) (bool, server.Backend, error) {
tlsConfig, err := cfg.BackendTLSConfig.ClientConfig()
if err != nil {
return nil, err
return false, nil, err
}

if tlsConfig != nil {
tlsConfig.MinVersion = cryptotls.VersionTLS11
}

parsedDSN, err := prepareDSN(dataSourceName, tlsConfig)
parsedDSN, err := prepareDSN(cfg.DataSourceName, tlsConfig)
if err != nil {
return nil, err
return false, nil, err
}

if err := createDBIfNotExist(parsedDSN); err != nil {
return nil, err
return false, nil, err
}

dialect, err := generic.Open(ctx, "mysql", parsedDSN, connPoolConfig, "?", false, metricsRegisterer)
dialect, err := generic.Open(ctx, "mysql", parsedDSN, cfg.ConnectionPoolConfig, "?", false, cfg.MetricsRegisterer)
if err != nil {
return nil, err
return false, nil, err
}

dialect.LastInsertID = true
Expand Down Expand Up @@ -114,11 +112,11 @@ func New(ctx context.Context, dataSourceName string, tlsInfo tls.Config, connPoo
return err.Error()
}
if err := setup(dialect.DB); err != nil {
return nil, err
return false, nil, err
}

dialect.Migrate(context.Background())
return logstructured.New(sqllog.New(dialect)), nil
return true, logstructured.New(sqllog.New(dialect)), nil
}

func setup(db *sql.DB) error {
Expand Down Expand Up @@ -223,3 +221,7 @@ func prepareDSN(dataSourceName string, tlsConfig *cryptotls.Config) (string, err

return parsedDSN, nil
}

func init() {
generic.RegisterDriver("mysql", New)
}
Loading

0 comments on commit 8ef18f1

Please sign in to comment.