From 0b34b73c47e049b450552b5d2f0983d2a6962ccc Mon Sep 17 00:00:00 2001 From: Christopher Swenson Date: Wed, 21 Sep 2022 12:25:04 -0700 Subject: [PATCH] Check if plugin version matches running version (#17182) Check if plugin version matches running version When registering a plugin, we check if the request version matches the self-reported version from the plugin. If these do not match, we log a warning. This uncovered a few missing pieces for getting the database version code fully working. We added an environment variable that helps us unit test the running version behavior as well, but only for approle, postgresql, and consul plugins. Return 400 on plugin not found or version mismatch Populate the running SHA256 of plugins in the mount and auth tables (#17217) --- builtin/credential/approle/backend.go | 8 +- builtin/logical/consul/backend.go | 6 +- plugins/database/postgresql/postgresql.go | 11 +- sdk/database/dbplugin/v5/grpc_server.go | 33 ++++- sdk/database/dbplugin/v5/middleware.go | 12 +- sdk/helper/pluginutil/multiplexing.go | 7 +- vault/auth.go | 22 +-- vault/external_plugin_test.go | 152 ++++++++++++++++---- vault/identity_store_entities_test.go | 4 +- vault/logical_system.go | 3 + vault/mount.go | 24 ++-- vault/plugin_catalog.go | 163 ++++++++++++++++++++-- vault/plugin_reload.go | 21 ++- 13 files changed, 394 insertions(+), 72 deletions(-) diff --git a/builtin/credential/approle/backend.go b/builtin/credential/approle/backend.go index 9612b65f7afd..ebd8d3c06a80 100644 --- a/builtin/credential/approle/backend.go +++ b/builtin/credential/approle/backend.go @@ -18,6 +18,9 @@ const ( secretIDAccessorLocalPrefix = "accessor_local/" ) +// ReportedVersion is used to report a specific version to Vault. +var ReportedVersion = "" + type backend struct { *framework.Backend @@ -111,8 +114,9 @@ func Backend(conf *logical.BackendConfig) (*backend, error) { pathTidySecretID(b), }, ), - Invalidate: b.invalidate, - BackendType: logical.TypeCredential, + Invalidate: b.invalidate, + BackendType: logical.TypeCredential, + RunningVersion: ReportedVersion, } return b, nil } diff --git a/builtin/logical/consul/backend.go b/builtin/logical/consul/backend.go index 3e37d1510d6e..7fce10e26294 100644 --- a/builtin/logical/consul/backend.go +++ b/builtin/logical/consul/backend.go @@ -7,6 +7,9 @@ import ( "github.com/hashicorp/vault/sdk/logical" ) +// ReportedVersion is used to report a specific version to Vault. +var ReportedVersion = "" + func Factory(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, error) { b := Backend() if err := b.Setup(ctx, conf); err != nil { @@ -34,7 +37,8 @@ func Backend() *backend { Secrets: []*framework.Secret{ secretToken(&b), }, - BackendType: logical.TypeLogical, + BackendType: logical.TypeLogical, + RunningVersion: ReportedVersion, } return &b diff --git a/plugins/database/postgresql/postgresql.go b/plugins/database/postgresql/postgresql.go index b46813727be6..c76558350586 100644 --- a/plugins/database/postgresql/postgresql.go +++ b/plugins/database/postgresql/postgresql.go @@ -14,6 +14,7 @@ import ( "github.com/hashicorp/vault/sdk/database/helper/dbutil" "github.com/hashicorp/vault/sdk/helper/dbtxn" "github.com/hashicorp/vault/sdk/helper/template" + "github.com/hashicorp/vault/sdk/logical" _ "github.com/jackc/pgx/v4/stdlib" ) @@ -32,7 +33,8 @@ ALTER ROLE "{{username}}" WITH PASSWORD '{{password}}'; ) var ( - _ dbplugin.Database = &PostgreSQL{} + _ dbplugin.Database = (*PostgreSQL)(nil) + _ logical.PluginVersioner = (*PostgreSQL)(nil) // postgresEndStatement is basically the word "END" but // surrounded by a word boundary to differentiate it from @@ -46,6 +48,9 @@ var ( // singleQuotedPhrases finds substrings like 'hello' // and pulls them out with the quotes included. singleQuotedPhrases = regexp.MustCompile(`('.*?')`) + + // ReportedVersion is used to report a specific version to Vault. + ReportedVersion = "" ) func New() (interface{}, error) { @@ -469,6 +474,10 @@ func (p *PostgreSQL) secretValues() map[string]string { } } +func (p *PostgreSQL) PluginVersion() logical.PluginVersion { + return logical.PluginVersion{Version: ReportedVersion} +} + // containsMultilineStatement is a best effort to determine whether // a particular statement is multiline, and therefore should not be // split upon semicolons. If it's unsure, it defaults to false. diff --git a/sdk/database/dbplugin/v5/grpc_server.go b/sdk/database/dbplugin/v5/grpc_server.go index b8fc7b672353..ce3be1efb7c6 100644 --- a/sdk/database/dbplugin/v5/grpc_server.go +++ b/sdk/database/dbplugin/v5/grpc_server.go @@ -2,12 +2,14 @@ package dbplugin import ( "context" + "errors" "fmt" "sync" "time" "github.com/golang/protobuf/ptypes" "github.com/hashicorp/vault/sdk/database/dbplugin/v5/proto" + "github.com/hashicorp/vault/sdk/helper/base62" "github.com/hashicorp/vault/sdk/helper/pluginutil" "github.com/hashicorp/vault/sdk/logical" "google.golang.org/grpc/codes" @@ -43,11 +45,14 @@ func (g *gRPCServer) getOrCreateDatabase(ctx context.Context) (Database, error) if err != nil { return nil, err } - if db, ok := g.instances[id]; ok { return db, nil } + return g.createDatabase(id) +} +// must hold the g.Lock() to call this function +func (g *gRPCServer) createDatabase(id string) (Database, error) { db, err := g.factoryFunc() if err != nil { return nil, err @@ -304,12 +309,36 @@ func (g *gRPCServer) Close(ctx context.Context, _ *proto.Empty) (*proto.Empty, e return &proto.Empty{}, nil } +// getOrForceCreateDatabase will create a database even if the multiplexing ID is not present +func (g *gRPCServer) getOrForceCreateDatabase(ctx context.Context) (Database, error) { + impl, err := g.getOrCreateDatabase(ctx) + if errors.Is(err, pluginutil.ErrNoMultiplexingIDFound) { + // if this is called without a multiplexing context, like from the plugin catalog directly, + // then we won't have a database ID, so let's generate a new database instance + id, err := base62.Random(10) + if err != nil { + return nil, err + } + + g.Lock() + defer g.Unlock() + impl, err = g.createDatabase(id) + if err != nil { + return nil, err + } + } else if err != nil { + return nil, err + } + return impl, nil +} + // Version forwards the version request to the underlying Database implementation. func (g *gRPCServer) Version(ctx context.Context, _ *logical.Empty) (*logical.VersionReply, error) { - impl, err := g.getDatabaseInternal(ctx) + impl, err := g.getOrForceCreateDatabase(ctx) if err != nil { return nil, err } + if versioner, ok := impl.(logical.PluginVersioner); ok { return &logical.VersionReply{PluginVersion: versioner.PluginVersion().Version}, nil } diff --git a/sdk/database/dbplugin/v5/middleware.go b/sdk/database/dbplugin/v5/middleware.go index 3c1a85a28bc8..240d64e6915e 100644 --- a/sdk/database/dbplugin/v5/middleware.go +++ b/sdk/database/dbplugin/v5/middleware.go @@ -233,7 +233,10 @@ func (mw databaseMetricsMiddleware) Close() (err error) { // Error Sanitizer Middleware Domain // /////////////////////////////////////////////////// -var _ Database = DatabaseErrorSanitizerMiddleware{} +var ( + _ Database = (*DatabaseErrorSanitizerMiddleware)(nil) + _ logical.PluginVersioner = (*DatabaseErrorSanitizerMiddleware)(nil) +) // DatabaseErrorSanitizerMiddleware wraps an implementation of Databases and // sanitizes returned error messages @@ -280,6 +283,13 @@ func (mw DatabaseErrorSanitizerMiddleware) Close() (err error) { return mw.sanitize(mw.next.Close()) } +func (mw DatabaseErrorSanitizerMiddleware) PluginVersion() logical.PluginVersion { + if versioner, ok := mw.next.(logical.PluginVersioner); ok { + return versioner.PluginVersion() + } + return logical.EmptyPluginVersion +} + // sanitize errors by removing any sensitive strings within their messages. This uses // the secretsFn to determine what fields should be sanitized. func (mw DatabaseErrorSanitizerMiddleware) sanitize(err error) error { diff --git a/sdk/helper/pluginutil/multiplexing.go b/sdk/helper/pluginutil/multiplexing.go index 9ebc78381d0f..41316ec49df2 100644 --- a/sdk/helper/pluginutil/multiplexing.go +++ b/sdk/helper/pluginutil/multiplexing.go @@ -2,6 +2,7 @@ package pluginutil import ( "context" + "errors" "fmt" "os" "strings" @@ -13,6 +14,8 @@ import ( "google.golang.org/grpc/status" ) +var ErrNoMultiplexingIDFound = errors.New("no multiplexing ID found") + type PluginMultiplexingServerImpl struct { UnimplementedPluginMultiplexingServer @@ -62,7 +65,9 @@ func GetMultiplexIDFromContext(ctx context.Context) (string, error) { } multiplexIDs := md[MultiplexingCtxKey] - if len(multiplexIDs) != 1 { + if len(multiplexIDs) == 0 { + return "", ErrNoMultiplexingIDFound + } else if len(multiplexIDs) != 1 { return "", fmt.Errorf("unexpected number of IDs in metadata: (%d)", len(multiplexIDs)) } diff --git a/vault/auth.go b/vault/auth.go index be789c624170..91cca4120de1 100644 --- a/vault/auth.go +++ b/vault/auth.go @@ -2,6 +2,7 @@ package vault import ( "context" + "encoding/hex" "errors" "fmt" "strings" @@ -170,7 +171,7 @@ func (c *Core) enableCredentialInternal(ctx context.Context, entry *MountEntry, var backend logical.Backend // Create the new backend sysView := c.mountEntrySysView(entry) - backend, err = c.newCredentialBackend(ctx, entry, sysView, view) + backend, entry.RunningSha, err = c.newCredentialBackend(ctx, entry, sysView, view) if err != nil { return err } @@ -794,7 +795,7 @@ func (c *Core) setupCredentials(ctx context.Context) error { // Initialize the backend sysView := c.mountEntrySysView(entry) - backend, err = c.newCredentialBackend(ctx, entry, sysView, view) + backend, entry.RunningSha, err = c.newCredentialBackend(ctx, entry, sysView, view) if err != nil { c.logger.Error("failed to create credential entry", "path", entry.Path, "error", err) plug, plugerr := c.pluginCatalog.Get(ctx, entry.Type, consts.PluginTypeCredential, "") @@ -913,25 +914,30 @@ func (c *Core) teardownCredentials(ctx context.Context) error { return nil } -// newCredentialBackend is used to create and configure a new credential backend by name -func (c *Core) newCredentialBackend(ctx context.Context, entry *MountEntry, sysView logical.SystemView, view logical.Storage) (logical.Backend, error) { +// newCredentialBackend is used to create and configure a new credential backend by name. +// It also returns the SHA256 of the plugin, if available. +func (c *Core) newCredentialBackend(ctx context.Context, entry *MountEntry, sysView logical.SystemView, view logical.Storage) (logical.Backend, string, error) { t := entry.Type if alias, ok := credentialAliases[t]; ok { t = alias } + var runningSha string f, ok := c.credentialBackends[t] if !ok { plug, err := c.pluginCatalog.Get(ctx, t, consts.PluginTypeCredential, entry.Version) if err != nil { - return nil, err + return nil, "", err } if plug == nil { errContext := t if entry.Version != "" { errContext += fmt.Sprintf(", version=%s", entry.Version) } - return nil, fmt.Errorf("%w: %s", ErrPluginNotFound, errContext) + return nil, "", fmt.Errorf("%w: %s", ErrPluginNotFound, errContext) + } + if len(plug.Sha256) > 0 { + runningSha = hex.EncodeToString(plug.Sha256) } f = plugin.Factory @@ -967,10 +973,10 @@ func (c *Core) newCredentialBackend(ctx context.Context, entry *MountEntry, sysV b, err := f(ctx, config) if err != nil { - return nil, err + return nil, "", err } - return b, nil + return b, runningSha, nil } // defaultAuthTable creates a default auth table diff --git a/vault/external_plugin_test.go b/vault/external_plugin_test.go index 29b665d8f69e..be45e86f9689 100644 --- a/vault/external_plugin_test.go +++ b/vault/external_plugin_test.go @@ -3,6 +3,7 @@ package vault import ( "context" "crypto/sha256" + "encoding/hex" "errors" "fmt" "os" @@ -16,18 +17,19 @@ import ( "github.com/hashicorp/vault/helper/namespace" "github.com/hashicorp/vault/sdk/framework" "github.com/hashicorp/vault/sdk/helper/consts" + "github.com/hashicorp/vault/sdk/helper/pluginutil" + "github.com/hashicorp/vault/sdk/logical" ) var ( - compileAuthOnce sync.Once - compileSecretOnce sync.Once - authPluginBytes []byte - secretPluginBytes []byte + pluginCacheLock sync.Mutex + pluginCache = map[string][]byte{} ) -func testCoreWithPlugin(t *testing.T, typ consts.PluginType) (*Core, string, string) { +// version is used to override the plugin's self-reported version +func testCoreWithPlugin(t *testing.T, typ consts.PluginType, version string) (*Core, string, string) { t.Helper() - pluginName, pluginSHA256, pluginDir := compilePlugin(t, typ) + pluginName, pluginSHA256, pluginDir := compilePlugin(t, typ, version) conf := &CoreConfig{ BuiltinRegistry: NewMockBuiltinRegistry(), PluginDirectory: pluginDir, @@ -37,29 +39,46 @@ func testCoreWithPlugin(t *testing.T, typ consts.PluginType) (*Core, string, str return core, pluginName, pluginSHA256 } -// to mount a plugin, we need a working binary plugin, so we compile one here. -func compilePlugin(t *testing.T, typ consts.PluginType) (name string, shasum string, pluginDir string) { +func getPlugin(t *testing.T, typ consts.PluginType) (string, string, string, string) { t.Helper() + var pluginName string + var pluginType string + var pluginMain string + var pluginVersionLocation string - var pluginType, pluginName, builtinDirectory string - var once *sync.Once - var pluginBytes *[]byte switch typ { case consts.PluginTypeCredential: pluginType = "approle" pluginName = "vault-plugin-auth-" + pluginType - builtinDirectory = "credential" - once = &compileAuthOnce - pluginBytes = &authPluginBytes + pluginMain = filepath.Join("builtin", "credential", pluginType, "cmd", pluginType, "main.go") + pluginVersionLocation = fmt.Sprintf("github.com/hashicorp/vault/builtin/credential/%s.ReportedVersion", pluginType) case consts.PluginTypeSecrets: pluginType = "consul" pluginName = "vault-plugin-secrets-" + pluginType - builtinDirectory = "logical" - once = &compileSecretOnce - pluginBytes = &secretPluginBytes + pluginMain = filepath.Join("builtin", "logical", pluginType, "cmd", pluginType, "main.go") + pluginVersionLocation = fmt.Sprintf("github.com/hashicorp/vault/builtin/logical/%s.ReportedVersion", pluginType) + case consts.PluginTypeDatabase: + pluginType = "postgresql" + pluginName = "vault-plugin-database-" + pluginType + pluginMain = filepath.Join("plugins", "database", pluginType, fmt.Sprintf("%s-database-plugin", pluginType), "main.go") + pluginVersionLocation = fmt.Sprintf("github.com/hashicorp/vault/plugins/database/%s.ReportedVersion", pluginType) default: t.Fatal(typ.String()) } + return pluginName, pluginType, pluginMain, pluginVersionLocation +} + +// to mount a plugin, we need a working binary plugin, so we compile one here. +// pluginVersion is used to override the plugin's self-reported version +func compilePlugin(t *testing.T, typ consts.PluginType, pluginVersion string) (pluginName string, shasum string, pluginDir string) { + t.Helper() + + pluginName, pluginType, pluginMain, pluginVersionLocation := getPlugin(t, typ) + + pluginCacheLock.Lock() + defer pluginCacheLock.Unlock() + + var pluginBytes []byte dir := "" // detect if we are in the "vault/" or the root directory and compensate @@ -76,31 +95,41 @@ func compilePlugin(t *testing.T, typ consts.PluginType) (name string, shasum str pluginPath := path.Join(pluginDir, pluginName) + key := fmt.Sprintf("%s %s %s", pluginName, pluginType, pluginVersion) // cache the compilation to only run once - once.Do(func() { - cmd := exec.Command("go", "build", "-o", pluginPath, fmt.Sprintf("builtin/%s/%s/cmd/%s/main.go", builtinDirectory, pluginType, pluginType)) + var ok bool + pluginBytes, ok = pluginCache[key] + if !ok { + // we need to compile + line := []string{"build"} + if pluginVersion != "" { + line = append(line, "-ldflags", fmt.Sprintf("-X %s=%s", pluginVersionLocation, pluginVersion)) + } + line = append(line, "-o", pluginPath, pluginMain) + cmd := exec.Command("go", line...) cmd.Dir = dir output, err := cmd.CombinedOutput() if err != nil { t.Fatal(fmt.Errorf("error running go build %v output: %s", err, output)) } - *pluginBytes, err = os.ReadFile(pluginPath) + pluginCache[key], err = os.ReadFile(pluginPath) if err != nil { t.Fatal(err) } - }) + pluginBytes = pluginCache[key] + } // write the cached plugin if necessary var err error if _, err := os.Stat(pluginPath); os.IsNotExist(err) { - err = os.WriteFile(pluginPath, *pluginBytes, 0o777) + err = os.WriteFile(pluginPath, pluginBytes, 0o777) } if err != nil { t.Fatal(err) } sha := sha256.New() - _, err = sha.Write(*pluginBytes) + _, err = sha.Write(pluginBytes) if err != nil { t.Fatal(err) } @@ -125,7 +154,7 @@ func TestCore_EnableExternalPlugin(t *testing.T) { }, } { t.Run(name, func(t *testing.T) { - c, pluginName, pluginSHA256 := testCoreWithPlugin(t, tc.pluginType) + c, pluginName, pluginSHA256 := testCoreWithPlugin(t, tc.pluginType, "") d := &framework.FieldData{ Raw: map[string]interface{}{ "name": pluginName, @@ -201,7 +230,7 @@ func TestCore_EnableExternalPlugin_MultipleVersions(t *testing.T) { }, } { t.Run(name, func(t *testing.T) { - c, pluginName, pluginSHA256 := testCoreWithPlugin(t, tc.pluginType) + c, pluginName, pluginSHA256 := testCoreWithPlugin(t, tc.pluginType, "") for _, version := range tc.registerVersions { d := &framework.FieldData{ Raw: map[string]interface{}{ @@ -247,6 +276,10 @@ func TestCore_EnableExternalPlugin_MultipleVersions(t *testing.T) { if raw.(*routeEntry).mountEntry.RunningVersion != "" { t.Errorf("Expected mount to have no running version but got %s", raw.(*routeEntry).mountEntry.RunningVersion) } + + if raw.(*routeEntry).mountEntry.RunningSha == "" { + t.Errorf("Expected RunningSha to be present: %+v", raw.(*routeEntry).mountEntry.RunningSha) + } }) } } @@ -269,7 +302,7 @@ func TestCore_EnableExternalPlugin_NoVersionsOkay(t *testing.T) { }, } { t.Run(name, func(t *testing.T) { - c, pluginName, pluginSHA256 := testCoreWithPlugin(t, tc.pluginType) + c, pluginName, pluginSHA256 := testCoreWithPlugin(t, tc.pluginType, "") d := &framework.FieldData{ Raw: map[string]interface{}{ "name": pluginName, @@ -328,7 +361,7 @@ func TestCore_EnableExternalCredentialPlugin_NoVersionOnRegister(t *testing.T) { }, } { t.Run(name, func(t *testing.T) { - c, pluginName, pluginSHA256 := testCoreWithPlugin(t, tc.pluginType) + c, pluginName, pluginSHA256 := testCoreWithPlugin(t, tc.pluginType, "") d := &framework.FieldData{ Raw: map[string]interface{}{ "name": pluginName, @@ -372,7 +405,7 @@ func TestCore_EnableExternalCredentialPlugin_InvalidName(t *testing.T) { }, } { t.Run(name, func(t *testing.T) { - c, pluginName, pluginSHA256 := testCoreWithPlugin(t, tc.pluginType) + c, pluginName, pluginSHA256 := testCoreWithPlugin(t, tc.pluginType, "") d := &framework.FieldData{ Raw: map[string]interface{}{ "name": pluginName, @@ -390,6 +423,69 @@ func TestCore_EnableExternalCredentialPlugin_InvalidName(t *testing.T) { } } +func TestExternalPlugin_getBackendTypeVersion(t *testing.T) { + for name, tc := range map[string]struct { + pluginType consts.PluginType + setRunningVersion string + }{ + "external credential plugin": { + pluginType: consts.PluginTypeCredential, + setRunningVersion: "v1.2.3", + }, + "external secrets plugin": { + pluginType: consts.PluginTypeSecrets, + setRunningVersion: "v1.2.3", + }, + "external database plugin": { + pluginType: consts.PluginTypeDatabase, + setRunningVersion: "v1.2.3", + }, + } { + t.Run(name, func(t *testing.T) { + c, pluginName, pluginSHA256 := testCoreWithPlugin(t, tc.pluginType, tc.setRunningVersion) + d := &framework.FieldData{ + Raw: map[string]interface{}{ + "name": pluginName, + "sha256": pluginSHA256, + "version": tc.setRunningVersion, + "command": pluginName, + }, + Schema: c.systemBackend.pluginsCatalogCRUDPath().Fields, + } + resp, err := c.systemBackend.handlePluginCatalogUpdate(context.Background(), nil, d) + if err != nil { + t.Fatal(err) + } + if resp.Error() != nil { + t.Fatalf("%#v", resp) + } + + shaBytes, _ := hex.DecodeString(pluginSHA256) + commandFull := filepath.Join(c.pluginCatalog.directory, pluginName) + entry := &pluginutil.PluginRunner{ + Name: pluginName, + Command: commandFull, + Args: nil, + Sha256: shaBytes, + Builtin: false, + } + + var version logical.PluginVersion + if tc.pluginType == consts.PluginTypeDatabase { + version, err = c.pluginCatalog.getDatabaseRunningVersion(context.Background(), entry) + } else { + version, err = c.pluginCatalog.getBackendRunningVersion(context.Background(), entry) + } + if err != nil { + t.Fatal(err) + } + if version.Version != tc.setRunningVersion { + t.Errorf("Expected to get version %v but got %v", tc.setRunningVersion, version.Version) + } + }) + } +} + func mountTable(pluginType consts.PluginType) string { switch pluginType { case consts.PluginTypeCredential: diff --git a/vault/identity_store_entities_test.go b/vault/identity_store_entities_test.go index 52462832cfad..aefaadcad3b6 100644 --- a/vault/identity_store_entities_test.go +++ b/vault/identity_store_entities_test.go @@ -8,7 +8,7 @@ import ( "strings" "testing" - uuid "github.com/hashicorp/go-uuid" + "github.com/hashicorp/go-uuid" credGithub "github.com/hashicorp/vault/builtin/credential/github" "github.com/hashicorp/vault/helper/identity" "github.com/hashicorp/vault/helper/namespace" @@ -688,7 +688,7 @@ func TestIdentityStore_LoadingEntities(t *testing.T) { ghSysview := c.mountEntrySysView(meGH) // Create new github auth credential backend - ghAuth, err := c.newCredentialBackend(context.Background(), meGH, ghSysview, ghView) + ghAuth, _, err := c.newCredentialBackend(context.Background(), meGH, ghSysview, ghView) if err != nil { t.Fatal(err) } diff --git a/vault/logical_system.go b/vault/logical_system.go index f78d5895c280..c635e39524a5 100644 --- a/vault/logical_system.go +++ b/vault/logical_system.go @@ -517,6 +517,9 @@ func (b *SystemBackend) handlePluginCatalogUpdate(ctx context.Context, _ *logica err = b.Core.pluginCatalog.Set(ctx, pluginName, pluginType, pluginVersion, parts[0], args, env, sha256Bytes) if err != nil { + if errors.Is(err, ErrPluginNotFound) || strings.HasPrefix(err.Error(), "plugin version mismatch") { + return logical.ErrorResponse(err.Error()), nil + } return nil, err } diff --git a/vault/mount.go b/vault/mount.go index 7fad75af882a..ccf78b6fd603 100644 --- a/vault/mount.go +++ b/vault/mount.go @@ -2,6 +2,7 @@ package vault import ( "context" + "encoding/hex" "errors" "fmt" "os" @@ -608,7 +609,7 @@ func (c *Core) mountInternal(ctx context.Context, entry *MountEntry, updateStora var backend logical.Backend sysView := c.mountEntrySysView(entry) - backend, err = c.newLogicalBackend(ctx, entry, sysView, view) + backend, entry.RunningSha, err = c.newLogicalBackend(ctx, entry, sysView, view) if err != nil { return err } @@ -1419,7 +1420,7 @@ func (c *Core) setupMounts(ctx context.Context) error { var backend logical.Backend // Create the new backend sysView := c.mountEntrySysView(entry) - backend, err = c.newLogicalBackend(ctx, entry, sysView, view) + backend, entry.RunningSha, err = c.newLogicalBackend(ctx, entry, sysView, view) if err != nil { c.logger.Error("failed to create mount entry", "path", entry.Path, "error", err) if !c.builtinRegistry.Contains(entry.Type, consts.PluginTypeSecrets) { @@ -1523,25 +1524,30 @@ func (c *Core) unloadMounts(ctx context.Context) error { return nil } -// newLogicalBackend is used to create and configure a new logical backend by name -func (c *Core) newLogicalBackend(ctx context.Context, entry *MountEntry, sysView logical.SystemView, view logical.Storage) (logical.Backend, error) { +// newLogicalBackend is used to create and configure a new logical backend by name. +// It also returns the SHA256 of the plugin, if available. +func (c *Core) newLogicalBackend(ctx context.Context, entry *MountEntry, sysView logical.SystemView, view logical.Storage) (logical.Backend, string, error) { t := entry.Type if alias, ok := mountAliases[t]; ok { t = alias } + var runningSha string f, ok := c.logicalBackends[t] if !ok { plug, err := c.pluginCatalog.Get(ctx, t, consts.PluginTypeSecrets, entry.Version) if err != nil { - return nil, err + return nil, "", err } if plug == nil { errContext := t if entry.Version != "" { errContext += fmt.Sprintf(", version=%s", entry.Version) } - return nil, fmt.Errorf("%w: %s", ErrPluginNotFound, errContext) + return nil, "", fmt.Errorf("%w: %s", ErrPluginNotFound, errContext) + } + if len(plug.Sha256) > 0 { + runningSha = hex.EncodeToString(plug.Sha256) } f = plugin.Factory @@ -1578,14 +1584,14 @@ func (c *Core) newLogicalBackend(ctx context.Context, entry *MountEntry, sysView ctx = context.WithValue(ctx, "core_number", c.coreNumber) b, err := f(ctx, config) if err != nil { - return nil, err + return nil, "", err } if b == nil { - return nil, fmt.Errorf("nil backend of type %q returned from factory", t) + return nil, "", fmt.Errorf("nil backend of type %q returned from factory", t) } addLicenseCallback(c, b) - return b, nil + return b, runningSha, nil } // defaultMountTable creates a default mount table diff --git a/vault/plugin_catalog.go b/vault/plugin_catalog.go index 95cb50421842..329a02b8e49a 100644 --- a/vault/plugin_catalog.go +++ b/vault/plugin_catalog.go @@ -365,7 +365,7 @@ func (c *PluginCatalog) newPluginClient(ctx context.Context, pluginRunner *plugi // getPluginTypeFromUnknown will attempt to run the plugin to determine the // type. It will first attempt to run as a database plugin then a backend // plugin. -func (c *PluginCatalog) getPluginTypeFromUnknown(ctx context.Context, logger log.Logger, plugin *pluginutil.PluginRunner) (consts.PluginType, error) { +func (c *PluginCatalog) getPluginTypeFromUnknown(ctx context.Context, plugin *pluginutil.PluginRunner) (consts.PluginType, error) { merr := &multierror.Error{} err := c.isDatabasePlugin(ctx, plugin) if err == nil { @@ -461,6 +461,124 @@ func (c *PluginCatalog) getBackendPluginType(ctx context.Context, pluginRunner * return consts.PluginTypeUnknown, merr.ErrorOrNil() } +// getBackendRunningVersion attempts to get the plugin version +func (c *PluginCatalog) getBackendRunningVersion(ctx context.Context, pluginRunner *pluginutil.PluginRunner) (logical.PluginVersion, error) { + merr := &multierror.Error{} + // Attempt to run as backend plugin + config := pluginutil.PluginClientConfig{ + Name: pluginRunner.Name, + PluginSets: backendplugin.PluginSet, + HandshakeConfig: backendplugin.HandshakeConfig, + Logger: log.NewNullLogger(), + IsMetadataMode: false, + AutoMTLS: true, + } + + var client logical.Backend + // First, attempt to run as backend V5 plugin + c.logger.Debug("attempting to load backend plugin", "name", pluginRunner.Name) + pc, err := c.newPluginClient(ctx, pluginRunner, config) + if err == nil { + // we spawned a subprocess, so make sure to clean it up + defer c.cleanupExternalPlugin(pluginRunner.Name, pc.id) + + // dispense the plugin so we can get its version + client, err = backendplugin.Dispense(pc.ClientProtocol, pc) + if err == nil { + c.logger.Debug("successfully dispensed v5 backend plugin", "name", pluginRunner.Name) + + err = client.Setup(ctx, &logical.BackendConfig{}) + if err != nil { + return logical.EmptyPluginVersion, nil + } + if versioner, ok := client.(logical.PluginVersioner); ok { + return versioner.PluginVersion(), nil + } + return logical.EmptyPluginVersion, nil + } + merr = multierror.Append(merr, fmt.Errorf("failed to dispense plugin as backend v5: %w", err)) + } + c.logger.Debug("failed to dispense v5 backend plugin", "name", pluginRunner.Name, "error", err) + config.AutoMTLS = false + config.IsMetadataMode = true + // attempt to run as a v4 backend plugin + client, err = backendplugin.NewPluginClient(ctx, nil, pluginRunner, log.NewNullLogger(), true) + if err != nil { + merr = multierror.Append(merr, fmt.Errorf("failed to dispense v4 backend plugin: %w", err)) + c.logger.Debug("failed to dispense v4 backend plugin", "name", pluginRunner.Name, "error", merr) + return logical.EmptyPluginVersion, merr.ErrorOrNil() + } + c.logger.Debug("successfully dispensed v4 backend plugin", "name", pluginRunner.Name) + defer client.Cleanup(ctx) + + err = client.Setup(ctx, &logical.BackendConfig{}) + if err != nil { + return logical.EmptyPluginVersion, err + } + if versioner, ok := client.(logical.PluginVersioner); ok { + return versioner.PluginVersion(), nil + } + return logical.EmptyPluginVersion, nil +} + +// getDatabaseRunningVersion returns the version reported by a database plugin +func (c *PluginCatalog) getDatabaseRunningVersion(ctx context.Context, pluginRunner *pluginutil.PluginRunner) (logical.PluginVersion, error) { + merr := &multierror.Error{} + config := pluginutil.PluginClientConfig{ + Name: pluginRunner.Name, + PluginSets: v5.PluginSets, + PluginType: consts.PluginTypeDatabase, + Version: pluginRunner.Version, + HandshakeConfig: v5.HandshakeConfig, + Logger: log.Default(), + IsMetadataMode: true, + AutoMTLS: true, + } + + // Attempt to run as database V5+ multiplexed plugin + c.logger.Debug("attempting to load database plugin as v5", "name", pluginRunner.Name) + v5Client, err := c.newPluginClient(ctx, pluginRunner, config) + if err == nil { + defer func() { + // Close the client and cleanup the plugin process + err = c.cleanupExternalPlugin(pluginRunner.Name, v5Client.id) + if err != nil { + c.logger.Error("error closing plugin client", "error", err) + } + }() + + raw, err := v5Client.Dispense("database") + if err != nil { + return logical.EmptyPluginVersion, err + } + if versioner, ok := raw.(logical.PluginVersioner); ok { + return versioner.PluginVersion(), nil + } + return logical.EmptyPluginVersion, nil + } + merr = multierror.Append(merr, fmt.Errorf("failed to load plugin as database v5: %w", err)) + + c.logger.Debug("attempting to load database plugin as v4", "name", pluginRunner.Name) + v4Client, err := v4.NewPluginClient(ctx, nil, pluginRunner, log.NewNullLogger(), true) + if err == nil { + // Close the client and cleanup the plugin process + defer func() { + err = v4Client.Close() + if err != nil { + c.logger.Error("error closing plugin client", "error", err) + } + }() + + if versioner, ok := v4Client.(logical.PluginVersioner); ok { + return versioner.PluginVersion(), nil + } + + return logical.EmptyPluginVersion, nil + } + merr = multierror.Append(merr, fmt.Errorf("failed to load plugin as database v4: %w", err)) + return logical.EmptyPluginVersion, merr +} + // isDatabasePlugin returns an error if the plugin is not a database plugin. func (c *PluginCatalog) isDatabasePlugin(ctx context.Context, pluginRunner *pluginutil.PluginRunner) error { merr := &multierror.Error{} @@ -475,7 +593,7 @@ func (c *PluginCatalog) isDatabasePlugin(ctx context.Context, pluginRunner *plug AutoMTLS: true, } - // Attempt to run as database V5 or V6 multiplexed plugin + // Attempt to run as database V5+ multiplexed plugin c.logger.Debug("attempting to load database plugin as v5", "name", pluginRunner.Name) v5Client, err := c.newPluginClient(ctx, pluginRunner, config) if err == nil { @@ -671,20 +789,19 @@ func (c *PluginCatalog) setInternal(ctx context.Context, name string, pluginType return nil, errors.New("cannot execute files outside of configured plugin directory") } + // entryTmp should only be used for the below type and version checks, it uses the + // full command instead of the relative command. + entryTmp := &pluginutil.PluginRunner{ + Name: name, + Command: commandFull, + Args: args, + Env: env, + Sha256: sha256, + Builtin: false, + } // If the plugin type is unknown, we want to attempt to determine the type if pluginType == consts.PluginTypeUnknown { - // entryTmp should only be used for the below type check, it uses the - // full command instead of the relative command. - entryTmp := &pluginutil.PluginRunner{ - Name: name, - Command: commandFull, - Args: args, - Env: env, - Sha256: sha256, - Builtin: false, - } - - pluginType, err = c.getPluginTypeFromUnknown(ctx, log.Default(), entryTmp) + pluginType, err = c.getPluginTypeFromUnknown(ctx, entryTmp) if err != nil { return nil, err } @@ -693,6 +810,24 @@ func (c *PluginCatalog) setInternal(ctx context.Context, name string, pluginType } } + // getting the plugin version is best-effort, so errors are not fatal + runningVersion := logical.EmptyPluginVersion + var versionErr error + switch pluginType { + case consts.PluginTypeSecrets, consts.PluginTypeCredential: + runningVersion, versionErr = c.getBackendRunningVersion(ctx, entryTmp) + case consts.PluginTypeDatabase: + runningVersion, versionErr = c.getDatabaseRunningVersion(ctx, entryTmp) + default: + return nil, fmt.Errorf("unknown plugin type: %v", pluginType) + } + if versionErr != nil { + c.logger.Warn("Error determining plugin version", "error", versionErr) + } else if version != "" && runningVersion.Version != "" && version != runningVersion.Version { + c.logger.Warn("Plugin self-reported version did not match requested version", "plugin", name, "requestedVersion", version, "reportedVersion", runningVersion.Version) + return nil, fmt.Errorf("plugin version mismatch: %s reported version (%s) did not match requested version (%s)", name, runningVersion.Version, version) + } + entry := &pluginutil.PluginRunner{ Name: name, Type: pluginType, diff --git a/vault/plugin_reload.go b/vault/plugin_reload.go index e60228a9b2c8..b65e8bf5f265 100644 --- a/vault/plugin_reload.go +++ b/vault/plugin_reload.go @@ -7,7 +7,7 @@ import ( "github.com/hashicorp/vault/helper/namespace" - multierror "github.com/hashicorp/go-multierror" + "github.com/hashicorp/go-multierror" "github.com/hashicorp/go-secure-stdlib/strutil" "github.com/hashicorp/vault/sdk/logical" "github.com/hashicorp/vault/sdk/plugin" @@ -174,11 +174,12 @@ func (c *Core) reloadBackendCommon(ctx context.Context, entry *MountEntry, isAut } var backend logical.Backend + oldSha := entry.RunningSha if !isAuth { // Dispense a new backend - backend, err = c.newLogicalBackend(ctx, entry, sysView, view) + backend, entry.RunningSha, err = c.newLogicalBackend(ctx, entry, sysView, view) } else { - backend, err = c.newCredentialBackend(ctx, entry, sysView, view) + backend, entry.RunningSha, err = c.newCredentialBackend(ctx, entry, sysView, view) } if err != nil { return err @@ -187,6 +188,20 @@ func (c *Core) reloadBackendCommon(ctx context.Context, entry *MountEntry, isAut return fmt.Errorf("nil backend of type %q returned from creation function", entry.Type) } + // update the mount table since we changed the runningSha + if oldSha != entry.RunningSha && MountTableUpdateStorage { + if isAuth { + err = c.persistAuth(ctx, c.auth, &entry.Local) + if err != nil { + return err + } + } else { + err = c.persistMounts(ctx, c.mounts, &entry.Local) + if err != nil { + return err + } + } + } addPathCheckers(c, entry, backend, viewPath) if nilMount {