Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support reloading database plugins across multiple mounts #24512

Merged
merged 5 commits into from
Jan 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
271 changes: 133 additions & 138 deletions builtin/logical/database/backend_test.go

Large diffs are not rendered by default.

7 changes: 4 additions & 3 deletions builtin/logical/database/versioning_large_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,10 @@ func TestPlugin_lifecycle(t *testing.T) {
cluster, sys := getCluster(t)
defer cluster.Cleanup()

vault.TestAddTestPlugin(t, cluster.Cores[0].Core, "mock-v4-database-plugin", consts.PluginTypeDatabase, "", "TestBackend_PluginMain_MockV4", []string{})
vault.TestAddTestPlugin(t, cluster.Cores[0].Core, "mock-v5-database-plugin", consts.PluginTypeDatabase, "", "TestBackend_PluginMain_MockV5", []string{})
vault.TestAddTestPlugin(t, cluster.Cores[0].Core, "mock-v6-database-plugin-muxed", consts.PluginTypeDatabase, "", "TestBackend_PluginMain_MockV6Multiplexed", []string{})
env := []string{fmt.Sprintf("%s=%s", pluginutil.PluginCACertPEMEnv, cluster.CACertPEMFile)}
vault.TestAddTestPlugin(t, cluster.Cores[0].Core, "mock-v4-database-plugin", consts.PluginTypeDatabase, "", "TestBackend_PluginMain_MockV4", env)
vault.TestAddTestPlugin(t, cluster.Cores[0].Core, "mock-v5-database-plugin", consts.PluginTypeDatabase, "", "TestBackend_PluginMain_MockV5", env)
vault.TestAddTestPlugin(t, cluster.Cores[0].Core, "mock-v6-database-plugin-muxed", consts.PluginTypeDatabase, "", "TestBackend_PluginMain_MockV6Multiplexed", env)

config := logical.TestBackendConfig()
config.StorageView = &logical.InmemStorage{}
Expand Down
5 changes: 2 additions & 3 deletions builtin/plugin/backend_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,9 +140,8 @@ func testConfig(t *testing.T, pluginCmd string) (*logical.BackendConfig, func())
},
}

os.Setenv(pluginutil.PluginCACertPEMEnv, cluster.CACertPEMFile)

vault.TestAddTestPlugin(t, core.Core, "mock-plugin", consts.PluginTypeSecrets, "", pluginCmd, []string{})
vault.TestAddTestPlugin(t, core.Core, "mock-plugin", consts.PluginTypeSecrets, "", pluginCmd,
[]string{fmt.Sprintf("%s=%s", pluginutil.PluginCACertPEMEnv, cluster.CACertPEMFile)})

return config, func() {
cluster.Cleanup()
Expand Down
6 changes: 6 additions & 0 deletions changelog/24512.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
```release-note:change
plugins: Add a warning to the response from sys/plugins/reload/backend if no plugins were reloaded.
```
```release-note:improvement
secrets/database: Support reloading named database plugins using the sys/plugins/reload/backend API endpoint.
```
6 changes: 3 additions & 3 deletions http/plugin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package http

import (
"encoding/json"
"fmt"
"io/ioutil"
"os"
"reflect"
Expand Down Expand Up @@ -55,10 +56,9 @@ func getPluginClusterAndCore(t *testing.T, logger log.Logger) (*vault.TestCluste
cores := cluster.Cores
core := cores[0]

os.Setenv(pluginutil.PluginCACertPEMEnv, cluster.CACertPEMFile)

vault.TestWaitActive(benchhelpers.TBtoT(t), core.Core)
vault.TestAddTestPlugin(benchhelpers.TBtoT(t), core.Core, "mock-plugin", consts.PluginTypeSecrets, "", "TestPlugin_PluginMain", []string{})
vault.TestAddTestPlugin(benchhelpers.TBtoT(t), core.Core, "mock-plugin", consts.PluginTypeSecrets, "", "TestPlugin_PluginMain",
[]string{fmt.Sprintf("%s=%s", pluginutil.PluginCACertPEMEnv, cluster.CACertPEMFile)})

// Mount the mock plugin
err = core.Client.Sys().Mount("mock", &api.MountInput{
Expand Down
32 changes: 32 additions & 0 deletions vault/external_tests/plugin/plugin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -560,6 +560,9 @@ func testSystemBackend_PluginReload(t *testing.T, reqData map[string]interface{}
if resp.Data["reload_id"] == nil {
t.Fatal("no reload_id in response")
}
if len(resp.Warnings) != 0 {
t.Fatal(resp.Warnings)
}

for i := 0; i < 2; i++ {
// Ensure internal backed value is reset
Expand All @@ -578,6 +581,35 @@ func testSystemBackend_PluginReload(t *testing.T, reqData map[string]interface{}
}
}

func TestSystemBackend_PluginReload_WarningIfNoneReloaded(t *testing.T) {
cluster := testSystemBackendMock(t, 1, 2, logical.TypeLogical, "v5")
defer cluster.Cleanup()

core := cluster.Cores[0]
client := core.Client

for _, backendType := range []logical.BackendType{logical.TypeLogical, logical.TypeCredential} {
t.Run(backendType.String(), func(t *testing.T) {
// Perform plugin reload
resp, err := client.Logical().Write("sys/plugins/reload/backend", map[string]any{
"plugin": "does-not-exist",
})
if err != nil {
t.Fatalf("err: %v", err)
}
if resp == nil {
t.Fatalf("bad: %v", resp)
}
if resp.Data["reload_id"] == nil {
t.Fatal("no reload_id in response")
}
if len(resp.Warnings) == 0 {
t.Fatal("expected warning")
}
})
}
}

// testSystemBackendMock returns a systemBackend with the desired number
// of mounted mock plugin backends. numMounts alternates between different
// ways of providing the plugin_name.
Expand Down
25 changes: 16 additions & 9 deletions vault/logical_system.go
Original file line number Diff line number Diff line change
Expand Up @@ -737,32 +737,39 @@ func (b *SystemBackend) handlePluginReloadUpdate(ctx context.Context, req *logic
return logical.ErrorResponse("plugin or mounts must be provided"), nil
}

resp := logical.Response{
Data: map[string]interface{}{
"reload_id": req.ID,
},
}

if pluginName != "" {
err := b.Core.reloadMatchingPlugin(ctx, pluginName)
reloaded, err := b.Core.reloadMatchingPlugin(ctx, pluginName)
if err != nil {
return nil, err
}
if reloaded == 0 {
if scope == globalScope {
resp.AddWarning("no plugins were reloaded locally (but they may be reloaded on other nodes)")
} else {
resp.AddWarning("no plugins were reloaded")
}
}
} else if len(pluginMounts) > 0 {
err := b.Core.reloadMatchingPluginMounts(ctx, pluginMounts)
if err != nil {
return nil, err
}
}

r := logical.Response{
Data: map[string]interface{}{
"reload_id": req.ID,
},
}

if scope == globalScope {
err := handleGlobalPluginReload(ctx, b.Core, req.ID, pluginName, pluginMounts)
if err != nil {
return nil, err
}
return logical.RespondWithStatusCode(&r, req, http.StatusAccepted)
return logical.RespondWithStatusCode(&resp, req, http.StatusAccepted)
}
return &r, nil
return &resp, nil
}

func (b *SystemBackend) handlePluginRuntimeCatalogUpdate(ctx context.Context, _ *logical.Request, d *framework.FieldData) (*logical.Response, error) {
Expand Down
2 changes: 1 addition & 1 deletion vault/mount.go
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,7 @@ const mountStateUnmounting = "unmounting"
// MountEntry is used to represent a mount table entry
type MountEntry struct {
Table string `json:"table"` // The table it belongs to
Path string `json:"path"` // Mount Path
Path string `json:"path"` // Mount Path, as provided in the mount API call but with a trailing slash, i.e. no auth/ or namespace prefix.
Type string `json:"type"` // Logical backend Type. NB: This is the plugin name, e.g. my-vault-plugin, NOT plugin type (e.g. auth).
Description string `json:"description"` // User-provided description
UUID string `json:"uuid"` // Barrier view UUID
Expand Down
47 changes: 36 additions & 11 deletions vault/plugin_reload.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,36 +70,60 @@ func (c *Core) reloadMatchingPluginMounts(ctx context.Context, mounts []string)
return errors
}

// reloadPlugin reloads all mounted backends that are of
// plugin pluginName (name of the plugin as registered in
// the plugin catalog).
func (c *Core) reloadMatchingPlugin(ctx context.Context, pluginName string) error {
// reloadMatchingPlugin reloads all mounted backends that are named pluginName
// (name of the plugin as registered in the plugin catalog). It returns the
// number of plugins that were reloaded and an error if any.
func (c *Core) reloadMatchingPlugin(ctx context.Context, pluginName string) (reloaded int, err error) {
c.mountsLock.RLock()
defer c.mountsLock.RUnlock()
c.authLock.RLock()
defer c.authLock.RUnlock()

ns, err := namespace.FromContext(ctx)
if err != nil {
return err
return reloaded, err
}

// Filter mount entries that only matches the plugin name
for _, entry := range c.mounts.Entries {
// We dont reload mounts that are not in the same namespace
if ns.ID != entry.Namespace().ID {
continue
}

if entry.Type == pluginName || (entry.Type == "plugin" && entry.Config.PluginName == pluginName) {
err := c.reloadBackendCommon(ctx, entry, false)
if err != nil {
return err
return reloaded, err
}
reloaded++
c.logger.Info("successfully reloaded plugin", "plugin", pluginName, "namespace", entry.Namespace(), "path", entry.Path, "version", entry.Version)
} else if entry.Type == "database" {
// The combined database plugin is itself a secrets engine, but
// knowledge of whether a database plugin is in use within a particular
// mount is internal to the combined database plugin's storage, so
// we delegate the reload request with an internally routed request.
req := &logical.Request{
Operation: logical.UpdateOperation,
Path: entry.Path + "reload/" + pluginName,
}
resp, err := c.router.Route(ctx, req)
if err != nil {
return reloaded, err
}
if resp == nil {
return reloaded, fmt.Errorf("failed to reload %q database plugin(s) mounted under %s", pluginName, entry.Path)
}
if resp.IsError() {
return reloaded, fmt.Errorf("failed to reload %q database plugin(s) mounted under %s: %s", pluginName, entry.Path, resp.Error())
}

if count, ok := resp.Data["count"].(int); ok && count > 0 {
c.logger.Info("successfully reloaded database plugin(s)", "plugin", pluginName, "namespace", entry.Namespace(), "path", entry.Path, "connections", resp.Data["connections"])
reloaded += count
}
c.logger.Info("successfully reloaded plugin", "plugin", pluginName, "path", entry.Path, "version", entry.Version)
}
}

// Filter auth mount entries that ony matches the plugin name
for _, entry := range c.auth.Entries {
// We dont reload mounts that are not in the same namespace
if ns.ID != entry.Namespace().ID {
Expand All @@ -109,13 +133,14 @@ func (c *Core) reloadMatchingPlugin(ctx context.Context, pluginName string) erro
if entry.Type == pluginName || (entry.Type == "plugin" && entry.Config.PluginName == pluginName) {
err := c.reloadBackendCommon(ctx, entry, true)
if err != nil {
return err
return reloaded, err
}
reloaded++
c.logger.Info("successfully reloaded plugin", "plugin", entry.Accessor, "path", entry.Path, "version", entry.Version)
}
}

return nil
return reloaded, nil
}

// reloadBackendCommon is a generic method to reload a backend provided a
Expand Down
Loading