diff --git a/builtin/credential/approle/backend.go b/builtin/credential/approle/backend.go index 9612b65f7afd..ebd8d3c06a80 100644 --- a/builtin/credential/approle/backend.go +++ b/builtin/credential/approle/backend.go @@ -18,6 +18,9 @@ const ( secretIDAccessorLocalPrefix = "accessor_local/" ) +// ReportedVersion is used to report a specific version to Vault. +var ReportedVersion = "" + type backend struct { *framework.Backend @@ -111,8 +114,9 @@ func Backend(conf *logical.BackendConfig) (*backend, error) { pathTidySecretID(b), }, ), - Invalidate: b.invalidate, - BackendType: logical.TypeCredential, + Invalidate: b.invalidate, + BackendType: logical.TypeCredential, + RunningVersion: ReportedVersion, } return b, nil } diff --git a/builtin/logical/consul/backend.go b/builtin/logical/consul/backend.go index 3e37d1510d6e..7fce10e26294 100644 --- a/builtin/logical/consul/backend.go +++ b/builtin/logical/consul/backend.go @@ -7,6 +7,9 @@ import ( "github.com/hashicorp/vault/sdk/logical" ) +// ReportedVersion is used to report a specific version to Vault. +var ReportedVersion = "" + func Factory(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, error) { b := Backend() if err := b.Setup(ctx, conf); err != nil { @@ -34,7 +37,8 @@ func Backend() *backend { Secrets: []*framework.Secret{ secretToken(&b), }, - BackendType: logical.TypeLogical, + BackendType: logical.TypeLogical, + RunningVersion: ReportedVersion, } return &b diff --git a/plugins/database/postgresql/postgresql.go b/plugins/database/postgresql/postgresql.go index b46813727be6..c76558350586 100644 --- a/plugins/database/postgresql/postgresql.go +++ b/plugins/database/postgresql/postgresql.go @@ -14,6 +14,7 @@ import ( "github.com/hashicorp/vault/sdk/database/helper/dbutil" "github.com/hashicorp/vault/sdk/helper/dbtxn" "github.com/hashicorp/vault/sdk/helper/template" + "github.com/hashicorp/vault/sdk/logical" _ "github.com/jackc/pgx/v4/stdlib" ) @@ -32,7 +33,8 @@ ALTER ROLE "{{username}}" WITH PASSWORD '{{password}}'; ) var ( - _ dbplugin.Database = &PostgreSQL{} + _ dbplugin.Database = (*PostgreSQL)(nil) + _ logical.PluginVersioner = (*PostgreSQL)(nil) // postgresEndStatement is basically the word "END" but // surrounded by a word boundary to differentiate it from @@ -46,6 +48,9 @@ var ( // singleQuotedPhrases finds substrings like 'hello' // and pulls them out with the quotes included. singleQuotedPhrases = regexp.MustCompile(`('.*?')`) + + // ReportedVersion is used to report a specific version to Vault. + ReportedVersion = "" ) func New() (interface{}, error) { @@ -469,6 +474,10 @@ func (p *PostgreSQL) secretValues() map[string]string { } } +func (p *PostgreSQL) PluginVersion() logical.PluginVersion { + return logical.PluginVersion{Version: ReportedVersion} +} + // containsMultilineStatement is a best effort to determine whether // a particular statement is multiline, and therefore should not be // split upon semicolons. If it's unsure, it defaults to false. diff --git a/sdk/database/dbplugin/v5/grpc_server.go b/sdk/database/dbplugin/v5/grpc_server.go index b8fc7b672353..ce3be1efb7c6 100644 --- a/sdk/database/dbplugin/v5/grpc_server.go +++ b/sdk/database/dbplugin/v5/grpc_server.go @@ -2,12 +2,14 @@ package dbplugin import ( "context" + "errors" "fmt" "sync" "time" "github.com/golang/protobuf/ptypes" "github.com/hashicorp/vault/sdk/database/dbplugin/v5/proto" + "github.com/hashicorp/vault/sdk/helper/base62" "github.com/hashicorp/vault/sdk/helper/pluginutil" "github.com/hashicorp/vault/sdk/logical" "google.golang.org/grpc/codes" @@ -43,11 +45,14 @@ func (g *gRPCServer) getOrCreateDatabase(ctx context.Context) (Database, error) if err != nil { return nil, err } - if db, ok := g.instances[id]; ok { return db, nil } + return g.createDatabase(id) +} +// must hold the g.Lock() to call this function +func (g *gRPCServer) createDatabase(id string) (Database, error) { db, err := g.factoryFunc() if err != nil { return nil, err @@ -304,12 +309,36 @@ func (g *gRPCServer) Close(ctx context.Context, _ *proto.Empty) (*proto.Empty, e return &proto.Empty{}, nil } +// getOrForceCreateDatabase will create a database even if the multiplexing ID is not present +func (g *gRPCServer) getOrForceCreateDatabase(ctx context.Context) (Database, error) { + impl, err := g.getOrCreateDatabase(ctx) + if errors.Is(err, pluginutil.ErrNoMultiplexingIDFound) { + // if this is called without a multiplexing context, like from the plugin catalog directly, + // then we won't have a database ID, so let's generate a new database instance + id, err := base62.Random(10) + if err != nil { + return nil, err + } + + g.Lock() + defer g.Unlock() + impl, err = g.createDatabase(id) + if err != nil { + return nil, err + } + } else if err != nil { + return nil, err + } + return impl, nil +} + // Version forwards the version request to the underlying Database implementation. func (g *gRPCServer) Version(ctx context.Context, _ *logical.Empty) (*logical.VersionReply, error) { - impl, err := g.getDatabaseInternal(ctx) + impl, err := g.getOrForceCreateDatabase(ctx) if err != nil { return nil, err } + if versioner, ok := impl.(logical.PluginVersioner); ok { return &logical.VersionReply{PluginVersion: versioner.PluginVersion().Version}, nil } diff --git a/sdk/database/dbplugin/v5/middleware.go b/sdk/database/dbplugin/v5/middleware.go index 3c1a85a28bc8..240d64e6915e 100644 --- a/sdk/database/dbplugin/v5/middleware.go +++ b/sdk/database/dbplugin/v5/middleware.go @@ -233,7 +233,10 @@ func (mw databaseMetricsMiddleware) Close() (err error) { // Error Sanitizer Middleware Domain // /////////////////////////////////////////////////// -var _ Database = DatabaseErrorSanitizerMiddleware{} +var ( + _ Database = (*DatabaseErrorSanitizerMiddleware)(nil) + _ logical.PluginVersioner = (*DatabaseErrorSanitizerMiddleware)(nil) +) // DatabaseErrorSanitizerMiddleware wraps an implementation of Databases and // sanitizes returned error messages @@ -280,6 +283,13 @@ func (mw DatabaseErrorSanitizerMiddleware) Close() (err error) { return mw.sanitize(mw.next.Close()) } +func (mw DatabaseErrorSanitizerMiddleware) PluginVersion() logical.PluginVersion { + if versioner, ok := mw.next.(logical.PluginVersioner); ok { + return versioner.PluginVersion() + } + return logical.EmptyPluginVersion +} + // sanitize errors by removing any sensitive strings within their messages. This uses // the secretsFn to determine what fields should be sanitized. func (mw DatabaseErrorSanitizerMiddleware) sanitize(err error) error { diff --git a/sdk/helper/pluginutil/multiplexing.go b/sdk/helper/pluginutil/multiplexing.go index 9ebc78381d0f..41316ec49df2 100644 --- a/sdk/helper/pluginutil/multiplexing.go +++ b/sdk/helper/pluginutil/multiplexing.go @@ -2,6 +2,7 @@ package pluginutil import ( "context" + "errors" "fmt" "os" "strings" @@ -13,6 +14,8 @@ import ( "google.golang.org/grpc/status" ) +var ErrNoMultiplexingIDFound = errors.New("no multiplexing ID found") + type PluginMultiplexingServerImpl struct { UnimplementedPluginMultiplexingServer @@ -62,7 +65,9 @@ func GetMultiplexIDFromContext(ctx context.Context) (string, error) { } multiplexIDs := md[MultiplexingCtxKey] - if len(multiplexIDs) != 1 { + if len(multiplexIDs) == 0 { + return "", ErrNoMultiplexingIDFound + } else if len(multiplexIDs) != 1 { return "", fmt.Errorf("unexpected number of IDs in metadata: (%d)", len(multiplexIDs)) } diff --git a/vault/auth.go b/vault/auth.go index be789c624170..91cca4120de1 100644 --- a/vault/auth.go +++ b/vault/auth.go @@ -2,6 +2,7 @@ package vault import ( "context" + "encoding/hex" "errors" "fmt" "strings" @@ -170,7 +171,7 @@ func (c *Core) enableCredentialInternal(ctx context.Context, entry *MountEntry, var backend logical.Backend // Create the new backend sysView := c.mountEntrySysView(entry) - backend, err = c.newCredentialBackend(ctx, entry, sysView, view) + backend, entry.RunningSha, err = c.newCredentialBackend(ctx, entry, sysView, view) if err != nil { return err } @@ -794,7 +795,7 @@ func (c *Core) setupCredentials(ctx context.Context) error { // Initialize the backend sysView := c.mountEntrySysView(entry) - backend, err = c.newCredentialBackend(ctx, entry, sysView, view) + backend, entry.RunningSha, err = c.newCredentialBackend(ctx, entry, sysView, view) if err != nil { c.logger.Error("failed to create credential entry", "path", entry.Path, "error", err) plug, plugerr := c.pluginCatalog.Get(ctx, entry.Type, consts.PluginTypeCredential, "") @@ -913,25 +914,30 @@ func (c *Core) teardownCredentials(ctx context.Context) error { return nil } -// newCredentialBackend is used to create and configure a new credential backend by name -func (c *Core) newCredentialBackend(ctx context.Context, entry *MountEntry, sysView logical.SystemView, view logical.Storage) (logical.Backend, error) { +// newCredentialBackend is used to create and configure a new credential backend by name. +// It also returns the SHA256 of the plugin, if available. +func (c *Core) newCredentialBackend(ctx context.Context, entry *MountEntry, sysView logical.SystemView, view logical.Storage) (logical.Backend, string, error) { t := entry.Type if alias, ok := credentialAliases[t]; ok { t = alias } + var runningSha string f, ok := c.credentialBackends[t] if !ok { plug, err := c.pluginCatalog.Get(ctx, t, consts.PluginTypeCredential, entry.Version) if err != nil { - return nil, err + return nil, "", err } if plug == nil { errContext := t if entry.Version != "" { errContext += fmt.Sprintf(", version=%s", entry.Version) } - return nil, fmt.Errorf("%w: %s", ErrPluginNotFound, errContext) + return nil, "", fmt.Errorf("%w: %s", ErrPluginNotFound, errContext) + } + if len(plug.Sha256) > 0 { + runningSha = hex.EncodeToString(plug.Sha256) } f = plugin.Factory @@ -967,10 +973,10 @@ func (c *Core) newCredentialBackend(ctx context.Context, entry *MountEntry, sysV b, err := f(ctx, config) if err != nil { - return nil, err + return nil, "", err } - return b, nil + return b, runningSha, nil } // defaultAuthTable creates a default auth table diff --git a/vault/external_plugin_test.go b/vault/external_plugin_test.go index 29b665d8f69e..be45e86f9689 100644 --- a/vault/external_plugin_test.go +++ b/vault/external_plugin_test.go @@ -3,6 +3,7 @@ package vault import ( "context" "crypto/sha256" + "encoding/hex" "errors" "fmt" "os" @@ -16,18 +17,19 @@ import ( "github.com/hashicorp/vault/helper/namespace" "github.com/hashicorp/vault/sdk/framework" "github.com/hashicorp/vault/sdk/helper/consts" + "github.com/hashicorp/vault/sdk/helper/pluginutil" + "github.com/hashicorp/vault/sdk/logical" ) var ( - compileAuthOnce sync.Once - compileSecretOnce sync.Once - authPluginBytes []byte - secretPluginBytes []byte + pluginCacheLock sync.Mutex + pluginCache = map[string][]byte{} ) -func testCoreWithPlugin(t *testing.T, typ consts.PluginType) (*Core, string, string) { +// version is used to override the plugin's self-reported version +func testCoreWithPlugin(t *testing.T, typ consts.PluginType, version string) (*Core, string, string) { t.Helper() - pluginName, pluginSHA256, pluginDir := compilePlugin(t, typ) + pluginName, pluginSHA256, pluginDir := compilePlugin(t, typ, version) conf := &CoreConfig{ BuiltinRegistry: NewMockBuiltinRegistry(), PluginDirectory: pluginDir, @@ -37,29 +39,46 @@ func testCoreWithPlugin(t *testing.T, typ consts.PluginType) (*Core, string, str return core, pluginName, pluginSHA256 } -// to mount a plugin, we need a working binary plugin, so we compile one here. -func compilePlugin(t *testing.T, typ consts.PluginType) (name string, shasum string, pluginDir string) { +func getPlugin(t *testing.T, typ consts.PluginType) (string, string, string, string) { t.Helper() + var pluginName string + var pluginType string + var pluginMain string + var pluginVersionLocation string - var pluginType, pluginName, builtinDirectory string - var once *sync.Once - var pluginBytes *[]byte switch typ { case consts.PluginTypeCredential: pluginType = "approle" pluginName = "vault-plugin-auth-" + pluginType - builtinDirectory = "credential" - once = &compileAuthOnce - pluginBytes = &authPluginBytes + pluginMain = filepath.Join("builtin", "credential", pluginType, "cmd", pluginType, "main.go") + pluginVersionLocation = fmt.Sprintf("github.com/hashicorp/vault/builtin/credential/%s.ReportedVersion", pluginType) case consts.PluginTypeSecrets: pluginType = "consul" pluginName = "vault-plugin-secrets-" + pluginType - builtinDirectory = "logical" - once = &compileSecretOnce - pluginBytes = &secretPluginBytes + pluginMain = filepath.Join("builtin", "logical", pluginType, "cmd", pluginType, "main.go") + pluginVersionLocation = fmt.Sprintf("github.com/hashicorp/vault/builtin/logical/%s.ReportedVersion", pluginType) + case consts.PluginTypeDatabase: + pluginType = "postgresql" + pluginName = "vault-plugin-database-" + pluginType + pluginMain = filepath.Join("plugins", "database", pluginType, fmt.Sprintf("%s-database-plugin", pluginType), "main.go") + pluginVersionLocation = fmt.Sprintf("github.com/hashicorp/vault/plugins/database/%s.ReportedVersion", pluginType) default: t.Fatal(typ.String()) } + return pluginName, pluginType, pluginMain, pluginVersionLocation +} + +// to mount a plugin, we need a working binary plugin, so we compile one here. +// pluginVersion is used to override the plugin's self-reported version +func compilePlugin(t *testing.T, typ consts.PluginType, pluginVersion string) (pluginName string, shasum string, pluginDir string) { + t.Helper() + + pluginName, pluginType, pluginMain, pluginVersionLocation := getPlugin(t, typ) + + pluginCacheLock.Lock() + defer pluginCacheLock.Unlock() + + var pluginBytes []byte dir := "" // detect if we are in the "vault/" or the root directory and compensate @@ -76,31 +95,41 @@ func compilePlugin(t *testing.T, typ consts.PluginType) (name string, shasum str pluginPath := path.Join(pluginDir, pluginName) + key := fmt.Sprintf("%s %s %s", pluginName, pluginType, pluginVersion) // cache the compilation to only run once - once.Do(func() { - cmd := exec.Command("go", "build", "-o", pluginPath, fmt.Sprintf("builtin/%s/%s/cmd/%s/main.go", builtinDirectory, pluginType, pluginType)) + var ok bool + pluginBytes, ok = pluginCache[key] + if !ok { + // we need to compile + line := []string{"build"} + if pluginVersion != "" { + line = append(line, "-ldflags", fmt.Sprintf("-X %s=%s", pluginVersionLocation, pluginVersion)) + } + line = append(line, "-o", pluginPath, pluginMain) + cmd := exec.Command("go", line...) cmd.Dir = dir output, err := cmd.CombinedOutput() if err != nil { t.Fatal(fmt.Errorf("error running go build %v output: %s", err, output)) } - *pluginBytes, err = os.ReadFile(pluginPath) + pluginCache[key], err = os.ReadFile(pluginPath) if err != nil { t.Fatal(err) } - }) + pluginBytes = pluginCache[key] + } // write the cached plugin if necessary var err error if _, err := os.Stat(pluginPath); os.IsNotExist(err) { - err = os.WriteFile(pluginPath, *pluginBytes, 0o777) + err = os.WriteFile(pluginPath, pluginBytes, 0o777) } if err != nil { t.Fatal(err) } sha := sha256.New() - _, err = sha.Write(*pluginBytes) + _, err = sha.Write(pluginBytes) if err != nil { t.Fatal(err) } @@ -125,7 +154,7 @@ func TestCore_EnableExternalPlugin(t *testing.T) { }, } { t.Run(name, func(t *testing.T) { - c, pluginName, pluginSHA256 := testCoreWithPlugin(t, tc.pluginType) + c, pluginName, pluginSHA256 := testCoreWithPlugin(t, tc.pluginType, "") d := &framework.FieldData{ Raw: map[string]interface{}{ "name": pluginName, @@ -201,7 +230,7 @@ func TestCore_EnableExternalPlugin_MultipleVersions(t *testing.T) { }, } { t.Run(name, func(t *testing.T) { - c, pluginName, pluginSHA256 := testCoreWithPlugin(t, tc.pluginType) + c, pluginName, pluginSHA256 := testCoreWithPlugin(t, tc.pluginType, "") for _, version := range tc.registerVersions { d := &framework.FieldData{ Raw: map[string]interface{}{ @@ -247,6 +276,10 @@ func TestCore_EnableExternalPlugin_MultipleVersions(t *testing.T) { if raw.(*routeEntry).mountEntry.RunningVersion != "" { t.Errorf("Expected mount to have no running version but got %s", raw.(*routeEntry).mountEntry.RunningVersion) } + + if raw.(*routeEntry).mountEntry.RunningSha == "" { + t.Errorf("Expected RunningSha to be present: %+v", raw.(*routeEntry).mountEntry.RunningSha) + } }) } } @@ -269,7 +302,7 @@ func TestCore_EnableExternalPlugin_NoVersionsOkay(t *testing.T) { }, } { t.Run(name, func(t *testing.T) { - c, pluginName, pluginSHA256 := testCoreWithPlugin(t, tc.pluginType) + c, pluginName, pluginSHA256 := testCoreWithPlugin(t, tc.pluginType, "") d := &framework.FieldData{ Raw: map[string]interface{}{ "name": pluginName, @@ -328,7 +361,7 @@ func TestCore_EnableExternalCredentialPlugin_NoVersionOnRegister(t *testing.T) { }, } { t.Run(name, func(t *testing.T) { - c, pluginName, pluginSHA256 := testCoreWithPlugin(t, tc.pluginType) + c, pluginName, pluginSHA256 := testCoreWithPlugin(t, tc.pluginType, "") d := &framework.FieldData{ Raw: map[string]interface{}{ "name": pluginName, @@ -372,7 +405,7 @@ func TestCore_EnableExternalCredentialPlugin_InvalidName(t *testing.T) { }, } { t.Run(name, func(t *testing.T) { - c, pluginName, pluginSHA256 := testCoreWithPlugin(t, tc.pluginType) + c, pluginName, pluginSHA256 := testCoreWithPlugin(t, tc.pluginType, "") d := &framework.FieldData{ Raw: map[string]interface{}{ "name": pluginName, @@ -390,6 +423,69 @@ func TestCore_EnableExternalCredentialPlugin_InvalidName(t *testing.T) { } } +func TestExternalPlugin_getBackendTypeVersion(t *testing.T) { + for name, tc := range map[string]struct { + pluginType consts.PluginType + setRunningVersion string + }{ + "external credential plugin": { + pluginType: consts.PluginTypeCredential, + setRunningVersion: "v1.2.3", + }, + "external secrets plugin": { + pluginType: consts.PluginTypeSecrets, + setRunningVersion: "v1.2.3", + }, + "external database plugin": { + pluginType: consts.PluginTypeDatabase, + setRunningVersion: "v1.2.3", + }, + } { + t.Run(name, func(t *testing.T) { + c, pluginName, pluginSHA256 := testCoreWithPlugin(t, tc.pluginType, tc.setRunningVersion) + d := &framework.FieldData{ + Raw: map[string]interface{}{ + "name": pluginName, + "sha256": pluginSHA256, + "version": tc.setRunningVersion, + "command": pluginName, + }, + Schema: c.systemBackend.pluginsCatalogCRUDPath().Fields, + } + resp, err := c.systemBackend.handlePluginCatalogUpdate(context.Background(), nil, d) + if err != nil { + t.Fatal(err) + } + if resp.Error() != nil { + t.Fatalf("%#v", resp) + } + + shaBytes, _ := hex.DecodeString(pluginSHA256) + commandFull := filepath.Join(c.pluginCatalog.directory, pluginName) + entry := &pluginutil.PluginRunner{ + Name: pluginName, + Command: commandFull, + Args: nil, + Sha256: shaBytes, + Builtin: false, + } + + var version logical.PluginVersion + if tc.pluginType == consts.PluginTypeDatabase { + version, err = c.pluginCatalog.getDatabaseRunningVersion(context.Background(), entry) + } else { + version, err = c.pluginCatalog.getBackendRunningVersion(context.Background(), entry) + } + if err != nil { + t.Fatal(err) + } + if version.Version != tc.setRunningVersion { + t.Errorf("Expected to get version %v but got %v", tc.setRunningVersion, version.Version) + } + }) + } +} + func mountTable(pluginType consts.PluginType) string { switch pluginType { case consts.PluginTypeCredential: diff --git a/vault/identity_store_entities_test.go b/vault/identity_store_entities_test.go index 52462832cfad..aefaadcad3b6 100644 --- a/vault/identity_store_entities_test.go +++ b/vault/identity_store_entities_test.go @@ -8,7 +8,7 @@ import ( "strings" "testing" - uuid "github.com/hashicorp/go-uuid" + "github.com/hashicorp/go-uuid" credGithub "github.com/hashicorp/vault/builtin/credential/github" "github.com/hashicorp/vault/helper/identity" "github.com/hashicorp/vault/helper/namespace" @@ -688,7 +688,7 @@ func TestIdentityStore_LoadingEntities(t *testing.T) { ghSysview := c.mountEntrySysView(meGH) // Create new github auth credential backend - ghAuth, err := c.newCredentialBackend(context.Background(), meGH, ghSysview, ghView) + ghAuth, _, err := c.newCredentialBackend(context.Background(), meGH, ghSysview, ghView) if err != nil { t.Fatal(err) } diff --git a/vault/logical_system.go b/vault/logical_system.go index f78d5895c280..c635e39524a5 100644 --- a/vault/logical_system.go +++ b/vault/logical_system.go @@ -517,6 +517,9 @@ func (b *SystemBackend) handlePluginCatalogUpdate(ctx context.Context, _ *logica err = b.Core.pluginCatalog.Set(ctx, pluginName, pluginType, pluginVersion, parts[0], args, env, sha256Bytes) if err != nil { + if errors.Is(err, ErrPluginNotFound) || strings.HasPrefix(err.Error(), "plugin version mismatch") { + return logical.ErrorResponse(err.Error()), nil + } return nil, err } diff --git a/vault/mount.go b/vault/mount.go index 7fad75af882a..ccf78b6fd603 100644 --- a/vault/mount.go +++ b/vault/mount.go @@ -2,6 +2,7 @@ package vault import ( "context" + "encoding/hex" "errors" "fmt" "os" @@ -608,7 +609,7 @@ func (c *Core) mountInternal(ctx context.Context, entry *MountEntry, updateStora var backend logical.Backend sysView := c.mountEntrySysView(entry) - backend, err = c.newLogicalBackend(ctx, entry, sysView, view) + backend, entry.RunningSha, err = c.newLogicalBackend(ctx, entry, sysView, view) if err != nil { return err } @@ -1419,7 +1420,7 @@ func (c *Core) setupMounts(ctx context.Context) error { var backend logical.Backend // Create the new backend sysView := c.mountEntrySysView(entry) - backend, err = c.newLogicalBackend(ctx, entry, sysView, view) + backend, entry.RunningSha, err = c.newLogicalBackend(ctx, entry, sysView, view) if err != nil { c.logger.Error("failed to create mount entry", "path", entry.Path, "error", err) if !c.builtinRegistry.Contains(entry.Type, consts.PluginTypeSecrets) { @@ -1523,25 +1524,30 @@ func (c *Core) unloadMounts(ctx context.Context) error { return nil } -// newLogicalBackend is used to create and configure a new logical backend by name -func (c *Core) newLogicalBackend(ctx context.Context, entry *MountEntry, sysView logical.SystemView, view logical.Storage) (logical.Backend, error) { +// newLogicalBackend is used to create and configure a new logical backend by name. +// It also returns the SHA256 of the plugin, if available. +func (c *Core) newLogicalBackend(ctx context.Context, entry *MountEntry, sysView logical.SystemView, view logical.Storage) (logical.Backend, string, error) { t := entry.Type if alias, ok := mountAliases[t]; ok { t = alias } + var runningSha string f, ok := c.logicalBackends[t] if !ok { plug, err := c.pluginCatalog.Get(ctx, t, consts.PluginTypeSecrets, entry.Version) if err != nil { - return nil, err + return nil, "", err } if plug == nil { errContext := t if entry.Version != "" { errContext += fmt.Sprintf(", version=%s", entry.Version) } - return nil, fmt.Errorf("%w: %s", ErrPluginNotFound, errContext) + return nil, "", fmt.Errorf("%w: %s", ErrPluginNotFound, errContext) + } + if len(plug.Sha256) > 0 { + runningSha = hex.EncodeToString(plug.Sha256) } f = plugin.Factory @@ -1578,14 +1584,14 @@ func (c *Core) newLogicalBackend(ctx context.Context, entry *MountEntry, sysView ctx = context.WithValue(ctx, "core_number", c.coreNumber) b, err := f(ctx, config) if err != nil { - return nil, err + return nil, "", err } if b == nil { - return nil, fmt.Errorf("nil backend of type %q returned from factory", t) + return nil, "", fmt.Errorf("nil backend of type %q returned from factory", t) } addLicenseCallback(c, b) - return b, nil + return b, runningSha, nil } // defaultMountTable creates a default mount table diff --git a/vault/plugin_catalog.go b/vault/plugin_catalog.go index 95cb50421842..329a02b8e49a 100644 --- a/vault/plugin_catalog.go +++ b/vault/plugin_catalog.go @@ -365,7 +365,7 @@ func (c *PluginCatalog) newPluginClient(ctx context.Context, pluginRunner *plugi // getPluginTypeFromUnknown will attempt to run the plugin to determine the // type. It will first attempt to run as a database plugin then a backend // plugin. -func (c *PluginCatalog) getPluginTypeFromUnknown(ctx context.Context, logger log.Logger, plugin *pluginutil.PluginRunner) (consts.PluginType, error) { +func (c *PluginCatalog) getPluginTypeFromUnknown(ctx context.Context, plugin *pluginutil.PluginRunner) (consts.PluginType, error) { merr := &multierror.Error{} err := c.isDatabasePlugin(ctx, plugin) if err == nil { @@ -461,6 +461,124 @@ func (c *PluginCatalog) getBackendPluginType(ctx context.Context, pluginRunner * return consts.PluginTypeUnknown, merr.ErrorOrNil() } +// getBackendRunningVersion attempts to get the plugin version +func (c *PluginCatalog) getBackendRunningVersion(ctx context.Context, pluginRunner *pluginutil.PluginRunner) (logical.PluginVersion, error) { + merr := &multierror.Error{} + // Attempt to run as backend plugin + config := pluginutil.PluginClientConfig{ + Name: pluginRunner.Name, + PluginSets: backendplugin.PluginSet, + HandshakeConfig: backendplugin.HandshakeConfig, + Logger: log.NewNullLogger(), + IsMetadataMode: false, + AutoMTLS: true, + } + + var client logical.Backend + // First, attempt to run as backend V5 plugin + c.logger.Debug("attempting to load backend plugin", "name", pluginRunner.Name) + pc, err := c.newPluginClient(ctx, pluginRunner, config) + if err == nil { + // we spawned a subprocess, so make sure to clean it up + defer c.cleanupExternalPlugin(pluginRunner.Name, pc.id) + + // dispense the plugin so we can get its version + client, err = backendplugin.Dispense(pc.ClientProtocol, pc) + if err == nil { + c.logger.Debug("successfully dispensed v5 backend plugin", "name", pluginRunner.Name) + + err = client.Setup(ctx, &logical.BackendConfig{}) + if err != nil { + return logical.EmptyPluginVersion, nil + } + if versioner, ok := client.(logical.PluginVersioner); ok { + return versioner.PluginVersion(), nil + } + return logical.EmptyPluginVersion, nil + } + merr = multierror.Append(merr, fmt.Errorf("failed to dispense plugin as backend v5: %w", err)) + } + c.logger.Debug("failed to dispense v5 backend plugin", "name", pluginRunner.Name, "error", err) + config.AutoMTLS = false + config.IsMetadataMode = true + // attempt to run as a v4 backend plugin + client, err = backendplugin.NewPluginClient(ctx, nil, pluginRunner, log.NewNullLogger(), true) + if err != nil { + merr = multierror.Append(merr, fmt.Errorf("failed to dispense v4 backend plugin: %w", err)) + c.logger.Debug("failed to dispense v4 backend plugin", "name", pluginRunner.Name, "error", merr) + return logical.EmptyPluginVersion, merr.ErrorOrNil() + } + c.logger.Debug("successfully dispensed v4 backend plugin", "name", pluginRunner.Name) + defer client.Cleanup(ctx) + + err = client.Setup(ctx, &logical.BackendConfig{}) + if err != nil { + return logical.EmptyPluginVersion, err + } + if versioner, ok := client.(logical.PluginVersioner); ok { + return versioner.PluginVersion(), nil + } + return logical.EmptyPluginVersion, nil +} + +// getDatabaseRunningVersion returns the version reported by a database plugin +func (c *PluginCatalog) getDatabaseRunningVersion(ctx context.Context, pluginRunner *pluginutil.PluginRunner) (logical.PluginVersion, error) { + merr := &multierror.Error{} + config := pluginutil.PluginClientConfig{ + Name: pluginRunner.Name, + PluginSets: v5.PluginSets, + PluginType: consts.PluginTypeDatabase, + Version: pluginRunner.Version, + HandshakeConfig: v5.HandshakeConfig, + Logger: log.Default(), + IsMetadataMode: true, + AutoMTLS: true, + } + + // Attempt to run as database V5+ multiplexed plugin + c.logger.Debug("attempting to load database plugin as v5", "name", pluginRunner.Name) + v5Client, err := c.newPluginClient(ctx, pluginRunner, config) + if err == nil { + defer func() { + // Close the client and cleanup the plugin process + err = c.cleanupExternalPlugin(pluginRunner.Name, v5Client.id) + if err != nil { + c.logger.Error("error closing plugin client", "error", err) + } + }() + + raw, err := v5Client.Dispense("database") + if err != nil { + return logical.EmptyPluginVersion, err + } + if versioner, ok := raw.(logical.PluginVersioner); ok { + return versioner.PluginVersion(), nil + } + return logical.EmptyPluginVersion, nil + } + merr = multierror.Append(merr, fmt.Errorf("failed to load plugin as database v5: %w", err)) + + c.logger.Debug("attempting to load database plugin as v4", "name", pluginRunner.Name) + v4Client, err := v4.NewPluginClient(ctx, nil, pluginRunner, log.NewNullLogger(), true) + if err == nil { + // Close the client and cleanup the plugin process + defer func() { + err = v4Client.Close() + if err != nil { + c.logger.Error("error closing plugin client", "error", err) + } + }() + + if versioner, ok := v4Client.(logical.PluginVersioner); ok { + return versioner.PluginVersion(), nil + } + + return logical.EmptyPluginVersion, nil + } + merr = multierror.Append(merr, fmt.Errorf("failed to load plugin as database v4: %w", err)) + return logical.EmptyPluginVersion, merr +} + // isDatabasePlugin returns an error if the plugin is not a database plugin. func (c *PluginCatalog) isDatabasePlugin(ctx context.Context, pluginRunner *pluginutil.PluginRunner) error { merr := &multierror.Error{} @@ -475,7 +593,7 @@ func (c *PluginCatalog) isDatabasePlugin(ctx context.Context, pluginRunner *plug AutoMTLS: true, } - // Attempt to run as database V5 or V6 multiplexed plugin + // Attempt to run as database V5+ multiplexed plugin c.logger.Debug("attempting to load database plugin as v5", "name", pluginRunner.Name) v5Client, err := c.newPluginClient(ctx, pluginRunner, config) if err == nil { @@ -671,20 +789,19 @@ func (c *PluginCatalog) setInternal(ctx context.Context, name string, pluginType return nil, errors.New("cannot execute files outside of configured plugin directory") } + // entryTmp should only be used for the below type and version checks, it uses the + // full command instead of the relative command. + entryTmp := &pluginutil.PluginRunner{ + Name: name, + Command: commandFull, + Args: args, + Env: env, + Sha256: sha256, + Builtin: false, + } // If the plugin type is unknown, we want to attempt to determine the type if pluginType == consts.PluginTypeUnknown { - // entryTmp should only be used for the below type check, it uses the - // full command instead of the relative command. - entryTmp := &pluginutil.PluginRunner{ - Name: name, - Command: commandFull, - Args: args, - Env: env, - Sha256: sha256, - Builtin: false, - } - - pluginType, err = c.getPluginTypeFromUnknown(ctx, log.Default(), entryTmp) + pluginType, err = c.getPluginTypeFromUnknown(ctx, entryTmp) if err != nil { return nil, err } @@ -693,6 +810,24 @@ func (c *PluginCatalog) setInternal(ctx context.Context, name string, pluginType } } + // getting the plugin version is best-effort, so errors are not fatal + runningVersion := logical.EmptyPluginVersion + var versionErr error + switch pluginType { + case consts.PluginTypeSecrets, consts.PluginTypeCredential: + runningVersion, versionErr = c.getBackendRunningVersion(ctx, entryTmp) + case consts.PluginTypeDatabase: + runningVersion, versionErr = c.getDatabaseRunningVersion(ctx, entryTmp) + default: + return nil, fmt.Errorf("unknown plugin type: %v", pluginType) + } + if versionErr != nil { + c.logger.Warn("Error determining plugin version", "error", versionErr) + } else if version != "" && runningVersion.Version != "" && version != runningVersion.Version { + c.logger.Warn("Plugin self-reported version did not match requested version", "plugin", name, "requestedVersion", version, "reportedVersion", runningVersion.Version) + return nil, fmt.Errorf("plugin version mismatch: %s reported version (%s) did not match requested version (%s)", name, runningVersion.Version, version) + } + entry := &pluginutil.PluginRunner{ Name: name, Type: pluginType, diff --git a/vault/plugin_reload.go b/vault/plugin_reload.go index e60228a9b2c8..b65e8bf5f265 100644 --- a/vault/plugin_reload.go +++ b/vault/plugin_reload.go @@ -7,7 +7,7 @@ import ( "github.com/hashicorp/vault/helper/namespace" - multierror "github.com/hashicorp/go-multierror" + "github.com/hashicorp/go-multierror" "github.com/hashicorp/go-secure-stdlib/strutil" "github.com/hashicorp/vault/sdk/logical" "github.com/hashicorp/vault/sdk/plugin" @@ -174,11 +174,12 @@ func (c *Core) reloadBackendCommon(ctx context.Context, entry *MountEntry, isAut } var backend logical.Backend + oldSha := entry.RunningSha if !isAuth { // Dispense a new backend - backend, err = c.newLogicalBackend(ctx, entry, sysView, view) + backend, entry.RunningSha, err = c.newLogicalBackend(ctx, entry, sysView, view) } else { - backend, err = c.newCredentialBackend(ctx, entry, sysView, view) + backend, entry.RunningSha, err = c.newCredentialBackend(ctx, entry, sysView, view) } if err != nil { return err @@ -187,6 +188,20 @@ func (c *Core) reloadBackendCommon(ctx context.Context, entry *MountEntry, isAut return fmt.Errorf("nil backend of type %q returned from creation function", entry.Type) } + // update the mount table since we changed the runningSha + if oldSha != entry.RunningSha && MountTableUpdateStorage { + if isAuth { + err = c.persistAuth(ctx, c.auth, &entry.Local) + if err != nil { + return err + } + } else { + err = c.persistMounts(ctx, c.mounts, &entry.Local) + if err != nil { + return err + } + } + } addPathCheckers(c, entry, backend, viewPath) if nilMount {