From 50105d9c349076237dbfc995459718714be9dfd1 Mon Sep 17 00:00:00 2001 From: John-Michael Faircloth Date: Mon, 18 Jul 2022 16:25:18 -0500 Subject: [PATCH] Backport of AutoMTLS for secrets/auth plugins into release/1.11.x --- api/plugin_helpers.go | 6 +- builtin/plugin/backend.go | 29 +- builtin/plugin/backend_lazyLoad_test.go | 2 +- changelog/15671.txt | 3 + sdk/helper/pluginutil/env.go | 4 + sdk/helper/pluginutil/multiplexing.go | 2 + sdk/helper/pluginutil/run_config.go | 12 +- sdk/helper/pluginutil/run_config_test.go | 12 +- sdk/helper/pluginutil/runner.go | 2 - sdk/plugin/backend.go | 20 +- sdk/plugin/plugin.go | 60 +- sdk/plugin/serve.go | 10 +- vault/logical_system_integ_test.go | 924 ++++++++++++++--------- vault/plugin_catalog.go | 71 +- 14 files changed, 729 insertions(+), 428 deletions(-) create mode 100644 changelog/15671.txt diff --git a/api/plugin_helpers.go b/api/plugin_helpers.go index e8ceb9c2fd6e..47d7ddca03d4 100644 --- a/api/plugin_helpers.go +++ b/api/plugin_helpers.go @@ -17,6 +17,10 @@ import ( ) var ( + // PluginAutoMTLSEnv ensures AutoMTLS is used. This overrides setting a + // TLSProviderFunc for a plugin. + PluginAutoMTLSEnv = "VAULT_PLUGIN_AUTOMTLS" + // PluginMetadataModeEnv is an ENV name used to disable TLS communication // to bootstrap mounting plugins. PluginMetadataModeEnv = "VAULT_PLUGIN_METADATA_MODE" @@ -120,7 +124,7 @@ func VaultPluginTLSProvider(apiTLSConfig *TLSConfig) func() (*tls.Config, error) // VaultPluginTLSProviderContext is run inside a plugin and retrieves the response // wrapped TLS certificate from vault. It returns a configured TLS Config. func VaultPluginTLSProviderContext(ctx context.Context, apiTLSConfig *TLSConfig) func() (*tls.Config, error) { - if os.Getenv(PluginMetadataModeEnv) == "true" { + if os.Getenv(PluginAutoMTLSEnv) == "true" || os.Getenv(PluginMetadataModeEnv) == "true" { return nil } diff --git a/builtin/plugin/backend.go b/builtin/plugin/backend.go index d33fe9c1a8eb..67bfbd34cd28 100644 --- a/builtin/plugin/backend.go +++ b/builtin/plugin/backend.go @@ -7,6 +7,7 @@ import ( "reflect" "sync" + "github.com/hashicorp/go-multierror" uuid "github.com/hashicorp/go-uuid" "github.com/hashicorp/vault/sdk/framework" "github.com/hashicorp/vault/sdk/helper/consts" @@ -49,17 +50,32 @@ func Backend(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, sys := conf.System - // NewBackend with isMetadataMode set to true - raw, err := bplugin.NewBackend(ctx, name, pluginType, sys, conf, true) + merr := &multierror.Error{} + // NewBackend with isMetadataMode set to false + raw, err := bplugin.NewBackend(ctx, name, pluginType, sys, conf, false, true) if err != nil { - return nil, err + merr = multierror.Append(merr, err) + // NewBackend with isMetadataMode set to true + raw, err = bplugin.NewBackend(ctx, name, pluginType, sys, conf, true, false) + if err != nil { + merr = multierror.Append(merr, err) + return nil, merr + } + } else { + b.Backend = raw + b.config = conf + b.loaded = true + b.autoMTLSSupported = true + + return &b, nil } + + // Setup the backend so we can inspect the SpecialPaths and Type err = raw.Setup(ctx, conf) if err != nil { raw.Cleanup(ctx) return nil, err } - // Get SpecialPaths and BackendType paths := raw.SpecialPaths() btype := raw.Type() @@ -83,7 +99,8 @@ type PluginBackend struct { logical.Backend sync.RWMutex - config *logical.BackendConfig + autoMTLSSupported bool + config *logical.BackendConfig // Used to detect if we already reloaded canary string @@ -103,7 +120,7 @@ func (b *PluginBackend) startBackend(ctx context.Context, storage logical.Storag // Ensure proper cleanup of the backend (i.e. call client.Kill()) b.Backend.Cleanup(ctx) - nb, err := bplugin.NewBackend(ctx, pluginName, pluginType, b.config.System, b.config, false) + nb, err := bplugin.NewBackend(ctx, pluginName, pluginType, b.config.System, b.config, false, b.autoMTLSSupported) if err != nil { return err } diff --git a/builtin/plugin/backend_lazyLoad_test.go b/builtin/plugin/backend_lazyLoad_test.go index 53c6f9611829..f6f61b28fda8 100644 --- a/builtin/plugin/backend_lazyLoad_test.go +++ b/builtin/plugin/backend_lazyLoad_test.go @@ -59,7 +59,7 @@ func testLazyLoad(t *testing.T, methodWrapper func() error) *PluginBackend { } // this is a dummy plugin that hasn't really been loaded yet - orig, err := plugin.NewBackend(ctx, "test-plugin", consts.PluginTypeSecrets, sysView, config, true) + orig, err := plugin.NewBackend(ctx, "test-plugin", consts.PluginTypeSecrets, sysView, config, true, false) if err != nil { t.Fatal(err) } diff --git a/changelog/15671.txt b/changelog/15671.txt new file mode 100644 index 000000000000..aaf0ca4d2b34 --- /dev/null +++ b/changelog/15671.txt @@ -0,0 +1,3 @@ +```release-note:improvement +plugins: Use AutoMTLS for secrets engines and auth methods run as external plugins. +``` diff --git a/sdk/helper/pluginutil/env.go b/sdk/helper/pluginutil/env.go index fd0cd4fb8308..ba6f8f6b580a 100644 --- a/sdk/helper/pluginutil/env.go +++ b/sdk/helper/pluginutil/env.go @@ -8,6 +8,10 @@ import ( ) var ( + // PluginAutoMTLSEnv is used to ensure AutoMTLS is used. This will override + // setting a TLSProviderFunc for a plugin. + PluginAutoMTLSEnv = "VAULT_PLUGIN_AUTOMTLS" + // PluginMlockEnabled is the ENV name used to pass the configuration for // enabling mlock PluginMlockEnabled = "VAULT_PLUGIN_MLOCK_ENABLED" diff --git a/sdk/helper/pluginutil/multiplexing.go b/sdk/helper/pluginutil/multiplexing.go index cbf50335d0bf..05d9dd63086c 100644 --- a/sdk/helper/pluginutil/multiplexing.go +++ b/sdk/helper/pluginutil/multiplexing.go @@ -9,6 +9,8 @@ import ( status "google.golang.org/grpc/status" ) +const MultiplexingCtxKey string = "multiplex_id" + type PluginMultiplexingServerImpl struct { UnimplementedPluginMultiplexingServer diff --git a/sdk/helper/pluginutil/run_config.go b/sdk/helper/pluginutil/run_config.go index cb804f60d873..56e6c0f58ff1 100644 --- a/sdk/helper/pluginutil/run_config.go +++ b/sdk/helper/pluginutil/run_config.go @@ -22,6 +22,7 @@ type PluginClientConfig struct { IsMetadataMode bool AutoMTLS bool MLock bool + Wrapper RunnerUtil } type runConfig struct { @@ -33,8 +34,6 @@ type runConfig struct { // Initialized with what's in PluginRunner.Env, but can be added to env []string - wrapper RunnerUtil - PluginClientConfig } @@ -43,7 +42,7 @@ func (rc runConfig) makeConfig(ctx context.Context) (*plugin.ClientConfig, error cmd.Env = append(cmd.Env, rc.env...) // Add the mlock setting to the ENV of the plugin - if rc.MLock || (rc.wrapper != nil && rc.wrapper.MlockEnabled()) { + if rc.MLock || (rc.Wrapper != nil && rc.Wrapper.MlockEnabled()) { cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", PluginMlockEnabled, "true")) } cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", PluginVaultVersionEnv, version.GetVersion().Version)) @@ -54,6 +53,9 @@ func (rc runConfig) makeConfig(ctx context.Context) (*plugin.ClientConfig, error metadataEnv := fmt.Sprintf("%s=%t", PluginMetadataModeEnv, rc.IsMetadataMode) cmd.Env = append(cmd.Env, metadataEnv) + automtlsEnv := fmt.Sprintf("%s=%t", PluginAutoMTLSEnv, rc.AutoMTLS) + cmd.Env = append(cmd.Env, automtlsEnv) + var clientTLSConfig *tls.Config if !rc.AutoMTLS && !rc.IsMetadataMode { // Get a CA TLS Certificate @@ -70,7 +72,7 @@ func (rc runConfig) makeConfig(ctx context.Context) (*plugin.ClientConfig, error // Use CA to sign a server cert and wrap the values in a response wrapped // token. - wrapToken, err := wrapServerConfig(ctx, rc.wrapper, certBytes, key) + wrapToken, err := wrapServerConfig(ctx, rc.Wrapper, certBytes, key) if err != nil { return nil, err } @@ -120,7 +122,7 @@ func Env(env ...string) RunOpt { func Runner(wrapper RunnerUtil) RunOpt { return func(rc *runConfig) { - rc.wrapper = wrapper + rc.Wrapper = wrapper } } diff --git a/sdk/helper/pluginutil/run_config_test.go b/sdk/helper/pluginutil/run_config_test.go index f2373fe9b4a5..4d1948c7bc38 100644 --- a/sdk/helper/pluginutil/run_config_test.go +++ b/sdk/helper/pluginutil/run_config_test.go @@ -4,7 +4,6 @@ import ( "context" "fmt" "os/exec" - "reflect" "testing" "time" @@ -14,6 +13,7 @@ import ( "github.com/hashicorp/go-plugin" "github.com/hashicorp/vault/sdk/helper/wrapping" "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" ) func TestMakeConfig(t *testing.T) { @@ -78,6 +78,7 @@ func TestMakeConfig(t *testing.T) { "initial=true", fmt.Sprintf("%s=%s", PluginVaultVersionEnv, version.GetVersion().Version), fmt.Sprintf("%s=%t", PluginMetadataModeEnv, true), + fmt.Sprintf("%s=%t", PluginAutoMTLSEnv, false), }, ), SecureConfig: &plugin.SecureConfig{ @@ -143,6 +144,7 @@ func TestMakeConfig(t *testing.T) { fmt.Sprintf("%s=%t", PluginMlockEnabled, true), fmt.Sprintf("%s=%s", PluginVaultVersionEnv, version.GetVersion().Version), fmt.Sprintf("%s=%t", PluginMetadataModeEnv, false), + fmt.Sprintf("%s=%t", PluginAutoMTLSEnv, false), fmt.Sprintf("%s=%s", PluginUnwrapTokenEnv, "testtoken"), }, ), @@ -205,6 +207,7 @@ func TestMakeConfig(t *testing.T) { "initial=true", fmt.Sprintf("%s=%s", PluginVaultVersionEnv, version.GetVersion().Version), fmt.Sprintf("%s=%t", PluginMetadataModeEnv, true), + fmt.Sprintf("%s=%t", PluginAutoMTLSEnv, true), }, ), SecureConfig: &plugin.SecureConfig{ @@ -266,6 +269,7 @@ func TestMakeConfig(t *testing.T) { "initial=true", fmt.Sprintf("%s=%s", PluginVaultVersionEnv, version.GetVersion().Version), fmt.Sprintf("%s=%t", PluginMetadataModeEnv, false), + fmt.Sprintf("%s=%t", PluginAutoMTLSEnv, true), }, ), SecureConfig: &plugin.SecureConfig{ @@ -290,7 +294,7 @@ func TestMakeConfig(t *testing.T) { Return(test.responseWrapInfo, test.responseWrapInfoErr) mockWrapper.On("MlockEnabled"). Return(test.mlockEnabled) - test.rc.wrapper = mockWrapper + test.rc.Wrapper = mockWrapper defer mockWrapper.AssertNumberOfCalls(t, "ResponseWrapData", test.responseWrapInfoTimes) defer mockWrapper.AssertNumberOfCalls(t, "MlockEnabled", test.mlockEnabledTimes) @@ -318,9 +322,7 @@ func TestMakeConfig(t *testing.T) { } config.TLSConfig = nil - if !reflect.DeepEqual(config, test.expectedConfig) { - t.Fatalf("Actual config: %#v\nExpected config: %#v", config, test.expectedConfig) - } + require.Equal(t, config, test.expectedConfig) }) } } diff --git a/sdk/helper/pluginutil/runner.go b/sdk/helper/pluginutil/runner.go index f2822efc1040..e2bf2965519e 100644 --- a/sdk/helper/pluginutil/runner.go +++ b/sdk/helper/pluginutil/runner.go @@ -38,8 +38,6 @@ type PluginClient interface { plugin.ClientProtocol } -const MultiplexingCtxKey string = "multiplex_id" - // PluginRunner defines the metadata needed to run a plugin securely with // go-plugin. type PluginRunner struct { diff --git a/sdk/plugin/backend.go b/sdk/plugin/backend.go index 82c728732703..7b93c85836b6 100644 --- a/sdk/plugin/backend.go +++ b/sdk/plugin/backend.go @@ -20,9 +20,10 @@ var ( // GRPCBackendPlugin is the plugin.Plugin implementation that only supports GRPC // transport type GRPCBackendPlugin struct { - Factory logical.Factory - MetadataMode bool - Logger log.Logger + Factory logical.Factory + MetadataMode bool + AutoMTLSSupported bool + Logger log.Logger // Embeding this will disable the netRPC protocol plugin.NetRPCUnsupportedPlugin @@ -41,12 +42,13 @@ func (b GRPCBackendPlugin) GRPCServer(broker *plugin.GRPCBroker, s *grpc.Server) func (b *GRPCBackendPlugin) GRPCClient(ctx context.Context, broker *plugin.GRPCBroker, c *grpc.ClientConn) (interface{}, error) { ret := &backendGRPCPluginClient{ - client: pb.NewBackendClient(c), - clientConn: c, - broker: broker, - cleanupCh: make(chan struct{}), - doneCtx: ctx, - metadataMode: b.MetadataMode, + client: pb.NewBackendClient(c), + clientConn: c, + broker: broker, + cleanupCh: make(chan struct{}), + doneCtx: ctx, + // Only run in metadata mode if mode is true and autoMTLS is not supported + metadataMode: b.MetadataMode && !b.AutoMTLSSupported, } // Create the value and set the type diff --git a/sdk/plugin/plugin.go b/sdk/plugin/plugin.go index f4f2d8e18f67..830d83f5f90d 100644 --- a/sdk/plugin/plugin.go +++ b/sdk/plugin/plugin.go @@ -7,7 +7,6 @@ import ( "sync" "github.com/hashicorp/errwrap" - log "github.com/hashicorp/go-hclog" plugin "github.com/hashicorp/go-plugin" "github.com/hashicorp/vault/sdk/helper/consts" "github.com/hashicorp/vault/sdk/helper/pluginutil" @@ -35,7 +34,7 @@ func (b *BackendPluginClient) Cleanup(ctx context.Context) { // external plugins, or a concrete implementation of the backend if it is a builtin backend. // The backend is returned as a logical.Backend interface. The isMetadataMode param determines whether // the plugin should run in metadata mode. -func NewBackend(ctx context.Context, pluginName string, pluginType consts.PluginType, sys pluginutil.LookRunnerUtil, conf *logical.BackendConfig, isMetadataMode bool) (logical.Backend, error) { +func NewBackend(ctx context.Context, pluginName string, pluginType consts.PluginType, sys pluginutil.LookRunnerUtil, conf *logical.BackendConfig, isMetadataMode bool, autoMTLS bool) (logical.Backend, error) { // Look for plugin in the plugin catalog pluginRunner, err := sys.LookupPlugin(ctx, pluginName, pluginType) if err != nil { @@ -59,8 +58,16 @@ func NewBackend(ctx context.Context, pluginName string, pluginType consts.Plugin } } } else { + config := pluginutil.PluginClientConfig{ + Name: pluginName, + PluginType: pluginType, + Logger: conf.Logger.Named(pluginName), + IsMetadataMode: isMetadataMode, + AutoMTLS: autoMTLS, + Wrapper: sys, + } // create a backendPluginClient instance - backend, err = NewPluginClient(ctx, sys, pluginRunner, conf.Logger, isMetadataMode) + backend, err = NewPluginClient(ctx, pluginRunner, config) if err != nil { return nil, err } @@ -69,34 +76,49 @@ func NewBackend(ctx context.Context, pluginName string, pluginType consts.Plugin return backend, nil } -func NewPluginClient(ctx context.Context, sys pluginutil.RunnerUtil, pluginRunner *pluginutil.PluginRunner, logger log.Logger, isMetadataMode bool) (logical.Backend, error) { - // pluginMap is the map of plugins we can dispense. - pluginSet := map[int]plugin.PluginSet{ +// pluginSet returns the go-plugin PluginSet that we can dispense. This ensures +// that plugins that don't support AutoMTLS are run on the appropriate version. +func pluginSet(autoMTLS, metadataMode bool) map[int]plugin.PluginSet { + if autoMTLS { + return map[int]plugin.PluginSet{ + 5: { + "backend": &GRPCBackendPlugin{ + MetadataMode: false, + AutoMTLSSupported: true, + }, + }, + } + } + return map[int]plugin.PluginSet{ // Version 3 used to supports both protocols. We want to keep it around // since it's possible old plugins built against this version will still // work with gRPC. There is currently no difference between version 3 // and version 4. 3: { "backend": &GRPCBackendPlugin{ - MetadataMode: isMetadataMode, + MetadataMode: metadataMode, }, }, 4: { "backend": &GRPCBackendPlugin{ - MetadataMode: isMetadataMode, + MetadataMode: metadataMode, }, }, } +} - namedLogger := logger.Named(pluginRunner.Name) - - var client *plugin.Client - var err error - if isMetadataMode { - client, err = pluginRunner.RunMetadataMode(ctx, sys, pluginSet, handshakeConfig, []string{}, namedLogger) - } else { - client, err = pluginRunner.Run(ctx, sys, pluginSet, handshakeConfig, []string{}, namedLogger) - } +func NewPluginClient(ctx context.Context, pluginRunner *pluginutil.PluginRunner, config pluginutil.PluginClientConfig) (logical.Backend, error) { + ps := pluginSet(config.AutoMTLS, config.IsMetadataMode) + + client, err := pluginRunner.RunConfig(ctx, + pluginutil.Runner(config.Wrapper), + pluginutil.PluginSets(ps), + pluginutil.HandshakeConfig(handshakeConfig), + pluginutil.Env(), + pluginutil.Logger(config.Logger), + pluginutil.MetadataMode(config.IsMetadataMode), + pluginutil.AutoMTLS(config.AutoMTLS), + ) if err != nil { return nil, err } @@ -126,9 +148,9 @@ func NewPluginClient(ctx context.Context, sys pluginutil.RunnerUtil, pluginRunne } // Wrap the backend in a tracing middleware - if namedLogger.IsTrace() { + if config.Logger.IsTrace() { backend = &backendTracingMiddleware{ - logger: namedLogger.With("transport", transport), + logger: config.Logger.With("transport", transport), next: backend, } } diff --git a/sdk/plugin/serve.go b/sdk/plugin/serve.go index 1119a2dac645..e518888d658c 100644 --- a/sdk/plugin/serve.go +++ b/sdk/plugin/serve.go @@ -37,12 +37,13 @@ func Serve(opts *ServeOpts) error { }) } - // pluginMap is the map of plugins we can dispense. + // pluginSets is the map of plugins we can dispense. pluginSets := map[int]plugin.PluginSet{ // Version 3 used to supports both protocols. We want to keep it around // since it's possible old plugins built against this version will still // work with gRPC. There is currently no difference between version 3 // and version 4. + // AutoMTLS is not supported by versions lower than 5. 3: { "backend": &GRPCBackendPlugin{ Factory: opts.BackendFactoryFunc, @@ -55,6 +56,13 @@ func Serve(opts *ServeOpts) error { Logger: logger, }, }, + 5: { + "backend": &GRPCBackendPlugin{ + Factory: opts.BackendFactoryFunc, + Logger: logger, + AutoMTLSSupported: true, + }, + }, } err := pluginutil.OptionallyEnableMlock() diff --git a/vault/logical_system_integ_test.go b/vault/logical_system_integ_test.go index 1742316422d3..b74231a8fd70 100644 --- a/vault/logical_system_integ_test.go +++ b/vault/logical_system_integ_test.go @@ -31,219 +31,306 @@ const ( expectedEnvValue = "BAR" ) -func TestSystemBackend_Plugin_secret(t *testing.T) { - cluster := testSystemBackendMock(t, 1, 1, logical.TypeLogical) - defer cluster.Cleanup() +// logicalVersionMap is a map of version to test plugin +var logicalVersionMap = map[string]string{ + "v4": "TestBackend_PluginMain_V4_Logical", + "v5": "TestBackend_PluginMainLogical", +} - core := cluster.Cores[0] +// credentialVersionMap is a map of version to test plugin +var credentialVersionMap = map[string]string{ + "v4": "TestBackend_PluginMain_V4_Credentials", + "v5": "TestBackend_PluginMainCredentials", +} - // Make a request to lazy load the plugin - req := logical.TestRequest(t, logical.ReadOperation, "mock-0/internal") - req.ClientToken = core.Client.Token() - resp, err := core.HandleRequest(namespace.RootContext(nil), req) - if err != nil { - t.Fatalf("err: %v", err) - } - if resp == nil { - t.Fatalf("bad: response should not be nil") +func TestSystemBackend_Plugin_secret(t *testing.T) { + testCases := []struct { + pluginVersion string + }{ + { + pluginVersion: "v5", + }, + { + pluginVersion: "v4", + }, } - // Seal the cluster - cluster.EnsureCoresSealed(t) + for _, tc := range testCases { + t.Run(tc.pluginVersion, func(t *testing.T) { + cluster := testSystemBackendMock(t, 1, 1, logical.TypeLogical, tc.pluginVersion) + defer cluster.Cleanup() + + core := cluster.Cores[0] - // Unseal the cluster - barrierKeys := cluster.BarrierKeys - for _, core := range cluster.Cores { - for _, key := range barrierKeys { - _, err := core.Unseal(vault.TestKeyCopy(key)) + // Make a request to lazy load the plugin + req := logical.TestRequest(t, logical.ReadOperation, "mock-0/internal") + req.ClientToken = core.Client.Token() + resp, err := core.HandleRequest(namespace.RootContext(nil), req) if err != nil { - t.Fatal(err) + t.Fatalf("err: %v", err) } - } - if core.Sealed() { - t.Fatal("should not be sealed") - } - // Wait for active so post-unseal takes place - // If it fails, it means unseal process failed - vault.TestWaitActive(t, core.Core) + if resp == nil { + t.Fatalf("bad: response should not be nil") + } + + // Seal the cluster + cluster.EnsureCoresSealed(t) + + // Unseal the cluster + barrierKeys := cluster.BarrierKeys + for _, core := range cluster.Cores { + for _, key := range barrierKeys { + _, err := core.Unseal(vault.TestKeyCopy(key)) + if err != nil { + t.Fatal(err) + } + } + if core.Sealed() { + t.Fatal("should not be sealed") + } + // Wait for active so post-unseal takes place + // If it fails, it means unseal process failed + vault.TestWaitActive(t, core.Core) + } + }) } } func TestSystemBackend_Plugin_auth(t *testing.T) { - cluster := testSystemBackendMock(t, 1, 1, logical.TypeCredential) - defer cluster.Cleanup() - - core := cluster.Cores[0] - - // Make a request to lazy load the plugin - req := logical.TestRequest(t, logical.ReadOperation, "auth/mock-0/internal") - req.ClientToken = core.Client.Token() - resp, err := core.HandleRequest(namespace.RootContext(nil), req) - if err != nil { - t.Fatalf("err: %v", err) - } - if resp == nil { - t.Fatalf("bad: response should not be nil") + testCases := []struct { + pluginVersion string + }{ + { + pluginVersion: "v5", + }, + { + pluginVersion: "v4", + }, } - // Seal the cluster - cluster.EnsureCoresSealed(t) + for _, tc := range testCases { + t.Run(tc.pluginVersion, func(t *testing.T) { + cluster := testSystemBackendMock(t, 1, 1, logical.TypeCredential, tc.pluginVersion) + defer cluster.Cleanup() + + core := cluster.Cores[0] - // Unseal the cluster - barrierKeys := cluster.BarrierKeys - for _, core := range cluster.Cores { - for _, key := range barrierKeys { - _, err := core.Unseal(vault.TestKeyCopy(key)) + // Make a request to lazy load the plugin + req := logical.TestRequest(t, logical.ReadOperation, "auth/mock-0/internal") + req.ClientToken = core.Client.Token() + resp, err := core.HandleRequest(namespace.RootContext(nil), req) if err != nil { - t.Fatal(err) + t.Fatalf("err: %v", err) } - } - if core.Sealed() { - t.Fatal("should not be sealed") - } - // Wait for active so post-unseal takes place - // If it fails, it means unseal process failed - vault.TestWaitActive(t, core.Core) + if resp == nil { + t.Fatalf("bad: response should not be nil") + } + + // Seal the cluster + cluster.EnsureCoresSealed(t) + + // Unseal the cluster + barrierKeys := cluster.BarrierKeys + for _, core := range cluster.Cores { + for _, key := range barrierKeys { + _, err := core.Unseal(vault.TestKeyCopy(key)) + if err != nil { + t.Fatal(err) + } + } + if core.Sealed() { + t.Fatal("should not be sealed") + } + // Wait for active so post-unseal takes place + // If it fails, it means unseal process failed + vault.TestWaitActive(t, core.Core) + } + }) } } func TestSystemBackend_Plugin_MissingBinary(t *testing.T) { - cluster := testSystemBackendMock(t, 1, 1, logical.TypeLogical) - defer cluster.Cleanup() + testCases := []struct { + pluginVersion string + }{ + { + pluginVersion: "v5", + }, + { + pluginVersion: "v4", + }, + } - core := cluster.Cores[0] + for _, tc := range testCases { + t.Run(tc.pluginVersion, func(t *testing.T) { + cluster := testSystemBackendMock(t, 1, 1, logical.TypeLogical, tc.pluginVersion) + defer cluster.Cleanup() - // Make a request to lazy load the plugin - req := logical.TestRequest(t, logical.ReadOperation, "mock-0/internal") - req.ClientToken = core.Client.Token() - resp, err := core.HandleRequest(namespace.RootContext(nil), req) - if err != nil { - t.Fatalf("err: %v", err) - } - if resp == nil { - t.Fatalf("bad: response should not be nil") - } + core := cluster.Cores[0] - // Seal the cluster - cluster.EnsureCoresSealed(t) + // Make a request to lazy load the plugin + req := logical.TestRequest(t, logical.ReadOperation, "mock-0/internal") + req.ClientToken = core.Client.Token() + resp, err := core.HandleRequest(namespace.RootContext(nil), req) + if err != nil { + t.Fatalf("err: %v", err) + } + if resp == nil { + t.Fatalf("bad: response should not be nil") + } - // Simulate removal of the plugin binary. Use os.Args to determine file name - // since that's how we create the file for catalog registration in the test - // helper. - pluginFileName := filepath.Base(os.Args[0]) - err = os.Remove(filepath.Join(cluster.TempDir, pluginFileName)) - if err != nil { - t.Fatal(err) - } + // Seal the cluster + cluster.EnsureCoresSealed(t) - // Unseal the cluster - cluster.UnsealCores(t) + // Simulate removal of the plugin binary. Use os.Args to determine file name + // since that's how we create the file for catalog registration in the test + // helper. + pluginFileName := filepath.Base(os.Args[0]) + err = os.Remove(filepath.Join(cluster.TempDir, pluginFileName)) + if err != nil { + t.Fatal(err) + } + + // Unseal the cluster + cluster.UnsealCores(t) - // Make a request against on tune after it is removed - req = logical.TestRequest(t, logical.ReadOperation, "sys/mounts/mock-0/tune") - req.ClientToken = core.Client.Token() - resp, err = core.HandleRequest(namespace.RootContext(nil), req) - if err == nil { - t.Fatalf("expected error") + // Make a request against on tune after it is removed + req = logical.TestRequest(t, logical.ReadOperation, "sys/mounts/mock-0/tune") + req.ClientToken = core.Client.Token() + resp, err = core.HandleRequest(namespace.RootContext(nil), req) + if err == nil { + t.Fatalf("expected error") + } + }) } } func TestSystemBackend_Plugin_MismatchType(t *testing.T) { - cluster := testSystemBackendMock(t, 1, 1, logical.TypeLogical) - defer cluster.Cleanup() + testCases := []struct { + pluginVersion string + }{ + { + pluginVersion: "v5", + }, + { + pluginVersion: "v4", + }, + } - core := cluster.Cores[0] + for _, tc := range testCases { + t.Run(tc.pluginVersion, func(t *testing.T) { + cluster := testSystemBackendMock(t, 1, 1, logical.TypeLogical, tc.pluginVersion) + defer cluster.Cleanup() - // Add a credential backend with the same name - vault.TestAddTestPlugin(t, core.Core, "mock-plugin", consts.PluginTypeCredential, "TestBackend_PluginMainCredentials", []string{}, "") + core := cluster.Cores[0] - // Make a request to lazy load the now-credential plugin - // and expect an error - req := logical.TestRequest(t, logical.ReadOperation, "mock-0/internal") - req.ClientToken = core.Client.Token() - _, err := core.HandleRequest(namespace.RootContext(nil), req) - if err != nil { - t.Fatalf("adding a same-named plugin of a different type should be no problem: %s", err) - } + // Add a credential backend with the same name + vault.TestAddTestPlugin(t, core.Core, "mock-plugin", consts.PluginTypeCredential, "TestBackend_PluginMainCredentials", []string{}, "") - // Sleep a bit before cleanup is called - time.Sleep(1 * time.Second) + // Make a request to lazy load the now-credential plugin + // and expect an error + req := logical.TestRequest(t, logical.ReadOperation, "mock-0/internal") + req.ClientToken = core.Client.Token() + _, err := core.HandleRequest(namespace.RootContext(nil), req) + if err != nil { + t.Fatalf("adding a same-named plugin of a different type should be no problem: %s", err) + } + + // Sleep a bit before cleanup is called + time.Sleep(1 * time.Second) + }) + } } func TestSystemBackend_Plugin_CatalogRemoved(t *testing.T) { t.Run("secret", func(t *testing.T) { - testPlugin_CatalogRemoved(t, logical.TypeLogical, false) + testPlugin_CatalogRemoved(t, logical.TypeLogical, false, logicalVersionMap) }) t.Run("auth", func(t *testing.T) { - testPlugin_CatalogRemoved(t, logical.TypeCredential, false) + testPlugin_CatalogRemoved(t, logical.TypeCredential, false, credentialVersionMap) }) t.Run("secret-mount-existing", func(t *testing.T) { - testPlugin_CatalogRemoved(t, logical.TypeLogical, true) + testPlugin_CatalogRemoved(t, logical.TypeLogical, true, logicalVersionMap) }) t.Run("auth-mount-existing", func(t *testing.T) { - testPlugin_CatalogRemoved(t, logical.TypeCredential, true) + testPlugin_CatalogRemoved(t, logical.TypeCredential, true, credentialVersionMap) }) } -func testPlugin_CatalogRemoved(t *testing.T, btype logical.BackendType, testMount bool) { - cluster := testSystemBackendMock(t, 1, 1, btype) - defer cluster.Cleanup() - - core := cluster.Cores[0] - - // Remove the plugin from the catalog - req := logical.TestRequest(t, logical.DeleteOperation, "sys/plugins/catalog/database/mock-plugin") - req.ClientToken = core.Client.Token() - resp, err := core.HandleRequest(namespace.RootContext(nil), req) - if err != nil || (resp != nil && resp.IsError()) { - t.Fatalf("err:%v resp:%#v", err, resp) +func testPlugin_CatalogRemoved(t *testing.T, btype logical.BackendType, testMount bool, versionMap map[string]string) { + testCases := []struct { + pluginVersion string + }{ + { + pluginVersion: "v5", + }, + { + pluginVersion: "v4", + }, } - // Seal the cluster - cluster.EnsureCoresSealed(t) - - // Unseal the cluster - barrierKeys := cluster.BarrierKeys - for _, core := range cluster.Cores { - for _, key := range barrierKeys { - _, err := core.Unseal(vault.TestKeyCopy(key)) - if err != nil { - t.Fatal(err) + for _, tc := range testCases { + t.Run(tc.pluginVersion, func(t *testing.T) { + cluster := testSystemBackendMock(t, 1, 1, logical.TypeLogical, tc.pluginVersion) + defer cluster.Cleanup() + + core := cluster.Cores[0] + + // Remove the plugin from the catalog + req := logical.TestRequest(t, logical.DeleteOperation, "sys/plugins/catalog/database/mock-plugin") + req.ClientToken = core.Client.Token() + resp, err := core.HandleRequest(namespace.RootContext(nil), req) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%v resp:%#v", err, resp) } - } - if core.Sealed() { - t.Fatal("should not be sealed") - } - } - // Wait for active so post-unseal takes place - // If it fails, it means unseal process failed - vault.TestWaitActive(t, core.Core) + // Seal the cluster + cluster.EnsureCoresSealed(t) + + // Unseal the cluster + barrierKeys := cluster.BarrierKeys + for _, core := range cluster.Cores { + for _, key := range barrierKeys { + _, err := core.Unseal(vault.TestKeyCopy(key)) + if err != nil { + t.Fatal(err) + } + } + if core.Sealed() { + t.Fatal("should not be sealed") + } + } - if testMount { - // Mount the plugin at the same path after plugin is re-added to the catalog - // and expect an error due to existing path. - var err error - switch btype { - case logical.TypeLogical: - // Add plugin back to the catalog - vault.TestAddTestPlugin(t, core.Core, "mock-plugin", consts.PluginTypeSecrets, "TestBackend_PluginMainLogical", []string{}, "") - _, err = core.Client.Logical().Write("sys/mounts/mock-0", map[string]interface{}{ - "type": "test", - }) - case logical.TypeCredential: - // Add plugin back to the catalog - vault.TestAddTestPlugin(t, core.Core, "mock-plugin", consts.PluginTypeCredential, "TestBackend_PluginMainCredentials", []string{}, "") - _, err = core.Client.Logical().Write("sys/auth/mock-0", map[string]interface{}{ - "type": "test", - }) - } - if err == nil { - t.Fatal("expected error when mounting on existing path") - } + // Wait for active so post-unseal takes place + // If it fails, it means unseal process failed + vault.TestWaitActive(t, core.Core) + + if testMount { + // Mount the plugin at the same path after plugin is re-added to the catalog + // and expect an error due to existing path. + var err error + switch btype { + case logical.TypeLogical: + // Add plugin back to the catalog + vault.TestAddTestPlugin(t, core.Core, "mock-plugin", consts.PluginTypeSecrets, logicalVersionMap[tc.pluginVersion], []string{}, "") + _, err = core.Client.Logical().Write("sys/mounts/mock-0", map[string]interface{}{ + "type": "test", + }) + case logical.TypeCredential: + // Add plugin back to the catalog + vault.TestAddTestPlugin(t, core.Core, "mock-plugin", consts.PluginTypeCredential, credentialVersionMap[tc.pluginVersion], []string{}, "") + _, err = core.Client.Logical().Write("sys/auth/mock-0", map[string]interface{}{ + "type": "test", + }) + } + if err == nil { + t.Fatal("expected error when mounting on existing path") + } + } + }) } } @@ -278,180 +365,215 @@ func TestSystemBackend_Plugin_continueOnError(t *testing.T) { } func testPlugin_continueOnError(t *testing.T, btype logical.BackendType, mismatch bool, mountPoint string, pluginType consts.PluginType) { - cluster := testSystemBackendMock(t, 1, 1, btype) - defer cluster.Cleanup() - - core := cluster.Cores[0] - - // Get the registered plugin - req := logical.TestRequest(t, logical.ReadOperation, fmt.Sprintf("sys/plugins/catalog/%s/mock-plugin", pluginType)) - req.ClientToken = core.Client.Token() - resp, err := core.HandleRequest(namespace.RootContext(nil), req) - if err != nil || resp == nil || (resp != nil && resp.IsError()) { - t.Fatalf("err:%v resp:%#v", err, resp) - } - - command, ok := resp.Data["command"].(string) - if !ok || command == "" { - t.Fatal("invalid command") - } - - // Mount credential type plugins - switch btype { - case logical.TypeCredential: - vault.TestAddTestPlugin(t, core.Core, mountPoint, consts.PluginTypeCredential, "TestBackend_PluginMainCredentials", []string{}, cluster.TempDir) - _, err = core.Client.Logical().Write(fmt.Sprintf("sys/auth/%s", mountPoint), map[string]interface{}{ - "type": "mock-plugin", - }) - if err != nil { - t.Fatalf("err:%v", err) - } + testCases := []struct { + pluginVersion string + }{ + { + pluginVersion: "v5", + }, + { + pluginVersion: "v4", + }, } - // Trigger a sha256 mismatch or missing plugin error - if mismatch { - req = logical.TestRequest(t, logical.UpdateOperation, fmt.Sprintf("sys/plugins/catalog/%s/mock-plugin", pluginType)) - req.Data = map[string]interface{}{ - "sha256": "d17bd7334758e53e6fbab15745d2520765c06e296f2ce8e25b7919effa0ac216", - "command": filepath.Base(command), - } - req.ClientToken = core.Client.Token() - resp, err = core.HandleRequest(namespace.RootContext(nil), req) - if err != nil || (resp != nil && resp.IsError()) { - t.Fatalf("err:%v resp:%#v", err, resp) - } - } else { - err := os.Remove(filepath.Join(cluster.TempDir, filepath.Base(command))) - if err != nil { - t.Fatal(err) - } - } + for _, tc := range testCases { + t.Run(tc.pluginVersion, func(t *testing.T) { + cluster := testSystemBackendMock(t, 1, 1, btype, tc.pluginVersion) + defer cluster.Cleanup() + + core := cluster.Cores[0] + + // Get the registered plugin + req := logical.TestRequest(t, logical.ReadOperation, fmt.Sprintf("sys/plugins/catalog/%s/mock-plugin", pluginType)) + req.ClientToken = core.Client.Token() + resp, err := core.HandleRequest(namespace.RootContext(nil), req) + if err != nil || resp == nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%v resp:%#v", err, resp) + } - // Seal the cluster - cluster.EnsureCoresSealed(t) + command, ok := resp.Data["command"].(string) + if !ok || command == "" { + t.Fatal("invalid command") + } - // Unseal the cluster - barrierKeys := cluster.BarrierKeys - for _, core := range cluster.Cores { - for _, key := range barrierKeys { - _, err := core.Unseal(vault.TestKeyCopy(key)) - if err != nil { - t.Fatal(err) + // Trigger a sha256 mismatch or missing plugin error + if mismatch { + req = logical.TestRequest(t, logical.UpdateOperation, fmt.Sprintf("sys/plugins/catalog/%s/mock-plugin", pluginType)) + req.Data = map[string]interface{}{ + "sha256": "d17bd7334758e53e6fbab15745d2520765c06e296f2ce8e25b7919effa0ac216", + "command": filepath.Base(command), + } + req.ClientToken = core.Client.Token() + resp, err = core.HandleRequest(namespace.RootContext(nil), req) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%v resp:%#v", err, resp) + } + } else { + err := os.Remove(filepath.Join(cluster.TempDir, filepath.Base(command))) + if err != nil { + t.Fatal(err) + } } - } - if core.Sealed() { - t.Fatal("should not be sealed") - } - } - // Wait for active so post-unseal takes place - // If it fails, it means unseal process failed - vault.TestWaitActive(t, core.Core) + // Seal the cluster + cluster.EnsureCoresSealed(t) + + // Unseal the cluster + barrierKeys := cluster.BarrierKeys + for _, core := range cluster.Cores { + for _, key := range barrierKeys { + _, err := core.Unseal(vault.TestKeyCopy(key)) + if err != nil { + t.Fatal(err) + } + } + if core.Sealed() { + t.Fatal("should not be sealed") + } + } - // Re-add the plugin to the catalog - switch btype { - case logical.TypeLogical: - vault.TestAddTestPlugin(t, core.Core, "mock-plugin", consts.PluginTypeSecrets, "TestBackend_PluginMainLogical", []string{}, cluster.TempDir) - case logical.TypeCredential: - vault.TestAddTestPlugin(t, core.Core, "mock-plugin", consts.PluginTypeCredential, "TestBackend_PluginMainCredentials", []string{}, cluster.TempDir) - } + // Wait for active so post-unseal takes place + // If it fails, it means unseal process failed + vault.TestWaitActive(t, core.Core) + + // Re-add the plugin to the catalog + switch btype { + case logical.TypeLogical: + plugin := logicalVersionMap[tc.pluginVersion] + vault.TestAddTestPlugin(t, core.Core, "mock-plugin", consts.PluginTypeSecrets, plugin, []string{}, cluster.TempDir) + case logical.TypeCredential: + plugin := credentialVersionMap[tc.pluginVersion] + vault.TestAddTestPlugin(t, core.Core, "mock-plugin", consts.PluginTypeCredential, plugin, []string{}, cluster.TempDir) + } - // Reload the plugin - req = logical.TestRequest(t, logical.UpdateOperation, "sys/plugins/reload/backend") - req.Data = map[string]interface{}{ - "plugin": "mock-plugin", - } - req.ClientToken = core.Client.Token() - resp, err = core.HandleRequest(namespace.RootContext(nil), req) - if err != nil || (resp != nil && resp.IsError()) { - t.Fatalf("err:%v resp:%#v", err, resp) - } + // Reload the plugin + req = logical.TestRequest(t, logical.UpdateOperation, "sys/plugins/reload/backend") + req.Data = map[string]interface{}{ + "plugin": "mock-plugin", + } + req.ClientToken = core.Client.Token() + resp, err = core.HandleRequest(namespace.RootContext(nil), req) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%v resp:%#v", err, resp) + } - // Make a request to lazy load the plugin - var reqPath string - switch btype { - case logical.TypeLogical: - reqPath = "mock-0/internal" - case logical.TypeCredential: - reqPath = "auth/mock-0/internal" - } + // Make a request to lazy load the plugin + var reqPath string + switch btype { + case logical.TypeLogical: + reqPath = "mock-0/internal" + case logical.TypeCredential: + reqPath = "auth/mock-0/internal" + } - req = logical.TestRequest(t, logical.ReadOperation, reqPath) - req.ClientToken = core.Client.Token() - resp, err = core.HandleRequest(namespace.RootContext(nil), req) - if err != nil { - t.Fatalf("err: %v", err) - } - if resp == nil { - t.Fatalf("bad: response should not be nil") + req = logical.TestRequest(t, logical.ReadOperation, reqPath) + req.ClientToken = core.Client.Token() + resp, err = core.HandleRequest(namespace.RootContext(nil), req) + if err != nil { + t.Fatalf("err: %v", err) + } + if resp == nil { + t.Fatalf("bad: response should not be nil") + } + }) } } func TestSystemBackend_Plugin_autoReload(t *testing.T) { - cluster := testSystemBackendMock(t, 1, 1, logical.TypeLogical) - defer cluster.Cleanup() + testCases := []struct { + pluginVersion string + }{ + { + pluginVersion: "v5", + }, + { + pluginVersion: "v4", + }, + } - core := cluster.Cores[0] + for _, tc := range testCases { + t.Run(tc.pluginVersion, func(t *testing.T) { + cluster := testSystemBackendMock(t, 1, 1, logical.TypeLogical, tc.pluginVersion) + defer cluster.Cleanup() - // Update internal value - req := logical.TestRequest(t, logical.UpdateOperation, "mock-0/internal") - req.ClientToken = core.Client.Token() - req.Data["value"] = "baz" - resp, err := core.HandleRequest(namespace.RootContext(nil), req) - if err != nil { - t.Fatalf("err: %v", err) - } - if resp != nil { - t.Fatalf("bad: %v", resp) - } + core := cluster.Cores[0] - // Call errors/rpc endpoint to trigger reload - req = logical.TestRequest(t, logical.ReadOperation, "mock-0/errors/rpc") - req.ClientToken = core.Client.Token() - resp, err = core.HandleRequest(namespace.RootContext(nil), req) - if err == nil { - t.Fatalf("expected error from error/rpc request") - } + // Update internal value + req := logical.TestRequest(t, logical.UpdateOperation, "mock-0/internal") + req.ClientToken = core.Client.Token() + req.Data["value"] = "baz" + resp, err := core.HandleRequest(namespace.RootContext(nil), req) + if err != nil { + t.Fatalf("err: %v", err) + } + if resp != nil { + t.Fatalf("bad: %v", resp) + } - // Check internal value to make sure it's reset - req = logical.TestRequest(t, logical.ReadOperation, "mock-0/internal") - req.ClientToken = core.Client.Token() - resp, err = core.HandleRequest(namespace.RootContext(nil), req) - if err != nil { - t.Fatalf("err: %v", err) - } - if resp == nil { - t.Fatalf("bad: response should not be nil") - } - if resp.Data["value"].(string) == "baz" { - t.Fatal("did not expect backend internal value to be 'baz'") + // Call errors/rpc endpoint to trigger reload + req = logical.TestRequest(t, logical.ReadOperation, "mock-0/errors/rpc") + req.ClientToken = core.Client.Token() + resp, err = core.HandleRequest(namespace.RootContext(nil), req) + if err == nil { + t.Fatalf("expected error from error/rpc request") + } + + // Check internal value to make sure it's reset + req = logical.TestRequest(t, logical.ReadOperation, "mock-0/internal") + req.ClientToken = core.Client.Token() + resp, err = core.HandleRequest(namespace.RootContext(nil), req) + if err != nil { + t.Fatalf("err: %v", err) + } + if resp == nil { + t.Fatalf("bad: response should not be nil") + } + if resp.Data["value"].(string) == "baz" { + t.Fatal("did not expect backend internal value to be 'baz'") + } + }) } } func TestSystemBackend_Plugin_SealUnseal(t *testing.T) { - cluster := testSystemBackendMock(t, 1, 1, logical.TypeLogical) - defer cluster.Cleanup() - - // Seal the cluster - cluster.EnsureCoresSealed(t) + testCases := []struct { + pluginVersion string + }{ + { + pluginVersion: "v5", + }, + { + pluginVersion: "v4", + }, + } - // Unseal the cluster - barrierKeys := cluster.BarrierKeys - for _, core := range cluster.Cores { - for _, key := range barrierKeys { - _, err := core.Unseal(vault.TestKeyCopy(key)) - if err != nil { - t.Fatal(err) + for _, tc := range testCases { + t.Run(tc.pluginVersion, func(t *testing.T) { + cluster := testSystemBackendMock(t, 1, 1, logical.TypeLogical, tc.pluginVersion) + defer cluster.Cleanup() + + // Seal the cluster + cluster.EnsureCoresSealed(t) + + // Unseal the cluster + barrierKeys := cluster.BarrierKeys + for _, core := range cluster.Cores { + for _, key := range barrierKeys { + _, err := core.Unseal(vault.TestKeyCopy(key)) + if err != nil { + t.Fatal(err) + } + } + if core.Sealed() { + t.Fatal("should not be sealed") + } } - } - if core.Sealed() { - t.Fatal("should not be sealed") - } - } - // Wait for active so post-unseal takes place - // If it fails, it means unseal process failed - vault.TestWaitActive(t, cluster.Cores[0].Core) + // Wait for active so post-unseal takes place + // If it fails, it means unseal process failed + vault.TestWaitActive(t, cluster.Cores[0].Core) + }) + } } func TestSystemBackend_Plugin_reload(t *testing.T) { @@ -498,53 +620,68 @@ func TestSystemBackend_Plugin_reload(t *testing.T) { // Helper func to test different reload methods on plugin reload endpoint func testSystemBackend_PluginReload(t *testing.T, reqData map[string]interface{}, backendType logical.BackendType) { - cluster := testSystemBackendMock(t, 1, 2, backendType) - defer cluster.Cleanup() + testCases := []struct { + pluginVersion string + }{ + { + pluginVersion: "v5", + }, + { + pluginVersion: "v4", + }, + } - core := cluster.Cores[0] - client := core.Client + for _, tc := range testCases { + t.Run(tc.pluginVersion, func(t *testing.T) { + cluster := testSystemBackendMock(t, 1, 2, backendType, tc.pluginVersion) + defer cluster.Cleanup() - pathPrefix := "mock-" - if backendType == logical.TypeCredential { - pathPrefix = "auth/" + pathPrefix - } - for i := 0; i < 2; i++ { - // Update internal value in the backend - resp, err := client.Logical().Write(fmt.Sprintf("%s%d/internal", pathPrefix, i), map[string]interface{}{ - "value": "baz", - }) - if err != nil { - t.Fatalf("err: %v", err) - } - if resp != nil { - t.Fatalf("bad: %v", resp) - } - } + core := cluster.Cores[0] + client := core.Client - // Perform plugin reload - resp, err := client.Logical().Write("sys/plugins/reload/backend", reqData) - if err != nil { - t.Fatalf("err: %v", err) - } - if resp == nil { - t.Fatalf("bad: %v", resp) - } - if resp.Data["reload_id"] == nil { - t.Fatal("no reload_id in response") - } + pathPrefix := "mock-" + if backendType == logical.TypeCredential { + pathPrefix = "auth/" + pathPrefix + } + for i := 0; i < 2; i++ { + // Update internal value in the backend + resp, err := client.Logical().Write(fmt.Sprintf("%s%d/internal", pathPrefix, i), map[string]interface{}{ + "value": "baz", + }) + if err != nil { + t.Fatalf("err: %v", err) + } + if resp != nil { + t.Fatalf("bad: %v", resp) + } + } - for i := 0; i < 2; i++ { - // Ensure internal backed value is reset - resp, err := client.Logical().Read(fmt.Sprintf("%s%d/internal", pathPrefix, i)) - if err != nil { - t.Fatalf("err: %v", err) - } - if resp == nil { - t.Fatalf("bad: response should not be nil") - } - if resp.Data["value"].(string) == "baz" { - t.Fatal("did not expect backend internal value to be 'baz'") - } + // Perform plugin reload + resp, err := client.Logical().Write("sys/plugins/reload/backend", reqData) + if err != nil { + t.Fatalf("err: %v", err) + } + if resp == nil { + t.Fatalf("bad: %v", resp) + } + if resp.Data["reload_id"] == nil { + t.Fatal("no reload_id in response") + } + + for i := 0; i < 2; i++ { + // Ensure internal backed value is reset + resp, err := client.Logical().Read(fmt.Sprintf("%s%d/internal", pathPrefix, i)) + if err != nil { + t.Fatalf("err: %v", err) + } + if resp == nil { + t.Fatalf("bad: response should not be nil") + } + if resp.Data["value"].(string) == "baz" { + t.Fatal("did not expect backend internal value to be 'baz'") + } + } + }) } } @@ -553,7 +690,7 @@ func testSystemBackend_PluginReload(t *testing.T, reqData map[string]interface{} // ways of providing the plugin_name. // // The mounts are mounted at sys/mounts/mock-[numMounts] or sys/auth/mock-[numMounts] -func testSystemBackendMock(t *testing.T, numCores, numMounts int, backendType logical.BackendType) *vault.TestCluster { +func testSystemBackendMock(t *testing.T, numCores, numMounts int, backendType logical.BackendType, pluginVersion string) *vault.TestCluster { coreConfig := &vault.CoreConfig{ LogicalBackends: map[string]logical.Factory{ "plugin": plugin.Factory, @@ -585,7 +722,8 @@ func testSystemBackendMock(t *testing.T, numCores, numMounts int, backendType lo switch backendType { case logical.TypeLogical: - vault.TestAddTestPlugin(t, core.Core, "mock-plugin", consts.PluginTypeSecrets, "TestBackend_PluginMainLogical", []string{}, tempDir) + plugin := logicalVersionMap[pluginVersion] + vault.TestAddTestPlugin(t, core.Core, "mock-plugin", consts.PluginTypeSecrets, plugin, []string{}, tempDir) for i := 0; i < numMounts; i++ { // Alternate input styles for plugin_name on every other mount options := map[string]interface{}{ @@ -600,7 +738,8 @@ func testSystemBackendMock(t *testing.T, numCores, numMounts int, backendType lo } } case logical.TypeCredential: - vault.TestAddTestPlugin(t, core.Core, "mock-plugin", consts.PluginTypeCredential, "TestBackend_PluginMainCredentials", []string{}, tempDir) + plugin := credentialVersionMap[pluginVersion] + vault.TestAddTestPlugin(t, core.Core, "mock-plugin", consts.PluginTypeCredential, plugin, []string{}, tempDir) for i := 0; i < numMounts; i++ { // Alternate input styles for plugin_name on every other mount options := map[string]interface{}{ @@ -671,9 +810,15 @@ func testSystemBackend_SingleCluster_Env(t *testing.T, env []string) *vault.Test return cluster } -func TestBackend_PluginMainLogical(t *testing.T) { +func TestBackend_PluginMain_V4_Logical(t *testing.T) { args := []string{} - if os.Getenv(pluginutil.PluginUnwrapTokenEnv) == "" && os.Getenv(pluginutil.PluginMetadataModeEnv) != "true" { + // don't run as a standalone unit test + if os.Getenv(pluginutil.PluginVaultVersionEnv) == "" { + return + } + + // don't run as a V5 plugin + if os.Getenv(pluginutil.PluginAutoMTLSEnv) == "true" { return } @@ -686,6 +831,8 @@ func TestBackend_PluginMainLogical(t *testing.T) { apiClientMeta := &api.PluginAPIClientMeta{} flags := apiClientMeta.FlagSet() flags.Parse(args) + + // V4 does not support AutoMTLS so we set a TLSConfig via TLSProviderFunc tlsConfig := apiClientMeta.GetTLSConfig() tlsProviderFunc := api.VaultPluginTLSProvider(tlsConfig) @@ -700,9 +847,9 @@ func TestBackend_PluginMainLogical(t *testing.T) { } } -func TestBackend_PluginMainCredentials(t *testing.T) { +func TestBackend_PluginMainLogical(t *testing.T) { args := []string{} - if os.Getenv(pluginutil.PluginUnwrapTokenEnv) == "" && os.Getenv(pluginutil.PluginMetadataModeEnv) != "true" { + if os.Getenv(pluginutil.PluginVaultVersionEnv) == "" { return } @@ -715,6 +862,40 @@ func TestBackend_PluginMainCredentials(t *testing.T) { apiClientMeta := &api.PluginAPIClientMeta{} flags := apiClientMeta.FlagSet() flags.Parse(args) + + factoryFunc := mock.FactoryType(logical.TypeLogical) + + err := lplugin.Serve(&lplugin.ServeOpts{ + BackendFactoryFunc: factoryFunc, + }) + if err != nil { + t.Fatal(err) + } +} + +func TestBackend_PluginMain_V4_Credentials(t *testing.T) { + args := []string{} + // don't run as a standalone unit test + if os.Getenv(pluginutil.PluginVaultVersionEnv) == "" { + return + } + + // don't run as a V5 plugin + if os.Getenv(pluginutil.PluginAutoMTLSEnv) == "true" { + return + } + + caPEM := os.Getenv(pluginutil.PluginCACertPEMEnv) + if caPEM == "" { + t.Fatal("CA cert not passed in") + } + args = append(args, fmt.Sprintf("--ca-cert=%s", caPEM)) + + apiClientMeta := &api.PluginAPIClientMeta{} + flags := apiClientMeta.FlagSet() + flags.Parse(args) + + // V4 does not support AutoMTLS so we set a TLSConfig via TLSProviderFunc tlsConfig := apiClientMeta.GetTLSConfig() tlsProviderFunc := api.VaultPluginTLSProvider(tlsConfig) @@ -729,6 +910,32 @@ func TestBackend_PluginMainCredentials(t *testing.T) { } } +func TestBackend_PluginMainCredentials(t *testing.T) { + args := []string{} + if os.Getenv(pluginutil.PluginVaultVersionEnv) == "" { + return + } + + caPEM := os.Getenv(pluginutil.PluginCACertPEMEnv) + if caPEM == "" { + t.Fatal("CA cert not passed in") + } + args = append(args, fmt.Sprintf("--ca-cert=%s", caPEM)) + + apiClientMeta := &api.PluginAPIClientMeta{} + flags := apiClientMeta.FlagSet() + flags.Parse(args) + + factoryFunc := mock.FactoryType(logical.TypeCredential) + + err := lplugin.Serve(&lplugin.ServeOpts{ + BackendFactoryFunc: factoryFunc, + }) + if err != nil { + t.Fatal(err) + } +} + // TestBackend_PluginMainEnv is a mock plugin that simply checks for the existence of FOO env var. func TestBackend_PluginMainEnv(t *testing.T) { args := []string{} @@ -751,14 +958,11 @@ func TestBackend_PluginMainEnv(t *testing.T) { apiClientMeta := &api.PluginAPIClientMeta{} flags := apiClientMeta.FlagSet() flags.Parse(args) - tlsConfig := apiClientMeta.GetTLSConfig() - tlsProviderFunc := api.VaultPluginTLSProvider(tlsConfig) factoryFunc := mock.FactoryType(logical.TypeLogical) err := lplugin.Serve(&lplugin.ServeOpts{ BackendFactoryFunc: factoryFunc, - TLSProviderFunc: tlsProviderFunc, }) if err != nil { t.Fatal(err) diff --git a/vault/plugin_catalog.go b/vault/plugin_catalog.go index 9caaa2410f2f..d839cdce29cf 100644 --- a/vault/plugin_catalog.go +++ b/vault/plugin_catalog.go @@ -325,39 +325,72 @@ func (c *PluginCatalog) getPluginTypeFromUnknown(ctx context.Context, logger log } merr = multierror.Append(merr, err) - // Attempt to run as backend plugin - client, err := backendplugin.NewPluginClient(ctx, nil, plugin, log.NewNullLogger(), true) + pluginType, err := c.getBackendPluginType(ctx, plugin) if err == nil { - err := client.Setup(ctx, &logical.BackendConfig{}) + return pluginType, nil + } + merr = multierror.Append(merr, err) + + return consts.PluginTypeUnknown, merr +} + +// getBackendPluginType returns the plugin type (secrets/auth) and an error if +// the plugin is not a backend plugin. +func (c *PluginCatalog) getBackendPluginType(ctx context.Context, pluginRunner *pluginutil.PluginRunner) (consts.PluginType, error) { + var client logical.Backend + var merr *multierror.Error + config := pluginutil.PluginClientConfig{ + Name: pluginRunner.Name, + Logger: log.NewNullLogger(), + IsMetadataMode: false, + AutoMTLS: true, + } + + // Attempt to run as backend V5 plugin + c.logger.Debug("attempting to load backend plugin", "name", pluginRunner.Name) + client, err := backendplugin.NewPluginClient(ctx, pluginRunner, config) + if err != nil { + merr = multierror.Append(merr, err) + c.logger.Debug("failed to dispense v5 backend plugin", "name", pluginRunner.Name, "error", err) + config.AutoMTLS = false + config.IsMetadataMode = true + // attemtp to run as a v4 backend plugin + client, err = backendplugin.NewPluginClient(ctx, pluginRunner, config) if err != nil { - return consts.PluginTypeUnknown, err + c.logger.Debug("failed to dispense v4 backend plugin", "name", pluginRunner.Name, "error", err) + return consts.PluginTypeUnknown, merr.ErrorOrNil() } + c.logger.Debug("successfully dispensed v4 backend plugin", "name", pluginRunner.Name) + } - backendType := client.Type() - client.Cleanup(ctx) + err = client.Setup(ctx, &logical.BackendConfig{}) + if err != nil { + return consts.PluginTypeUnknown, err + } + backendType := client.Type() + client.Cleanup(ctx) - switch backendType { - case logical.TypeCredential: - return consts.PluginTypeCredential, nil - case logical.TypeLogical: - return consts.PluginTypeSecrets, nil - } - } else { - merr = multierror.Append(merr, err) + switch backendType { + case logical.TypeCredential: + return consts.PluginTypeCredential, nil + case logical.TypeLogical: + return consts.PluginTypeSecrets, nil } if client == nil || client.Type() == logical.TypeUnknown { - logger.Warn("unknown plugin type", - "plugin name", plugin.Name, + c.logger.Warn("unknown plugin type", + "plugin name", pluginRunner.Name, "error", merr.Error()) } else { - logger.Warn("unsupported plugin type", - "plugin name", plugin.Name, + c.logger.Warn("unsupported plugin type", + "plugin name", pluginRunner.Name, "plugin type", client.Type().String(), "error", merr.Error()) } - return consts.PluginTypeUnknown, nil + merr = multierror.Append(merr, fmt.Errorf("failed to load plugin as backend plugin: %w", err)) + + return consts.PluginTypeUnknown, merr.ErrorOrNil() } // isDatabasePlugin returns true if the plugin supports multiplexing. An error