Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

New root namespace plugin reload API sys/plugins/reload/:type/:name #24878

Merged
merged 2 commits into from
Jan 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 35 additions & 1 deletion api/plugin_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@ package api
// https://github.com/hashicorp/vault/blob/main/sdk/helper/consts/plugin_types.go
// Any changes made should be made to both files at the same time.

import "fmt"
import (
"encoding/json"
"fmt"
)

var PluginTypes = []PluginType{
PluginTypeUnknown,
Expand Down Expand Up @@ -64,3 +67,34 @@ func ParsePluginType(pluginType string) (PluginType, error) {
return PluginTypeUnknown, fmt.Errorf("%q is not a supported plugin type", pluginType)
}
}

// UnmarshalJSON implements json.Unmarshaler. It supports unmarshaling either a
// string or a uint32. All new serialization will be as a string, but we
// previously serialized as a uint32 so we need to support that for backwards
// compatibility.
func (p *PluginType) UnmarshalJSON(data []byte) error {
var asString string
err := json.Unmarshal(data, &asString)
if err == nil {
*p, err = ParsePluginType(asString)
return err
}

var asUint32 uint32
err = json.Unmarshal(data, &asUint32)
if err != nil {
return err
}
*p = PluginType(asUint32)
switch *p {
case PluginTypeUnknown, PluginTypeCredential, PluginTypeDatabase, PluginTypeSecrets:
return nil
default:
return fmt.Errorf("%d is not a supported plugin type", asUint32)
}
}

// MarshalJSON implements json.Marshaler.
func (p PluginType) MarshalJSON() ([]byte, error) {
return json.Marshal(p.String())
}
101 changes: 101 additions & 0 deletions api/plugin_types_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0

package api

// NOTE: this file was copied from
// https://github.com/hashicorp/vault/blob/main/sdk/helper/consts/plugin_types_test.go
// Any changes made should be made to both files at the same time.

import (
"encoding/json"
"testing"
)

type testType struct {
PluginType PluginType `json:"plugin_type"`
}

func TestPluginTypeJSONRoundTrip(t *testing.T) {
for _, pluginType := range PluginTypes {
original := testType{
PluginType: pluginType,
}
asBytes, err := json.Marshal(original)
if err != nil {
t.Fatal(err)
}

var roundTripped testType
err = json.Unmarshal(asBytes, &roundTripped)
if err != nil {
t.Fatal(err)
}

if original != roundTripped {
t.Fatalf("expected %v, got %v", original, roundTripped)
}
}
}

func TestPluginTypeJSONUnmarshal(t *testing.T) {
// Failure/unsupported cases.
for name, tc := range map[string]string{
"unsupported": `{"plugin_type":"unsupported"}`,
"random string": `{"plugin_type":"foo"}`,
"boolean": `{"plugin_type":true}`,
"empty": `{"plugin_type":""}`,
"negative": `{"plugin_type":-1}`,
"out of range": `{"plugin_type":10}`,
} {
t.Run(name, func(t *testing.T) {
var result testType
err := json.Unmarshal([]byte(tc), &result)
if err == nil {
t.Fatal("expected error")
}
})
}

// Valid cases.
for name, tc := range map[string]struct {
json string
expected PluginType
}{
"unknown": {`{"plugin_type":"unknown"}`, PluginTypeUnknown},
"auth": {`{"plugin_type":"auth"}`, PluginTypeCredential},
"secret": {`{"plugin_type":"secret"}`, PluginTypeSecrets},
"database": {`{"plugin_type":"database"}`, PluginTypeDatabase},
"absent": {`{}`, PluginTypeUnknown},
"integer unknown": {`{"plugin_type":0}`, PluginTypeUnknown},
"integer auth": {`{"plugin_type":1}`, PluginTypeCredential},
"integer db": {`{"plugin_type":2}`, PluginTypeDatabase},
"integer secret": {`{"plugin_type":3}`, PluginTypeSecrets},
} {
t.Run(name, func(t *testing.T) {
var result testType
err := json.Unmarshal([]byte(tc.json), &result)
if err != nil {
t.Fatal(err)
}
if tc.expected != result.PluginType {
t.Fatalf("expected %v, got %v", tc.expected, result.PluginType)
}
})
}
}

func TestUnknownTypeExcludedWithOmitEmpty(t *testing.T) {
type testTypeOmitEmpty struct {
Type PluginType `json:"type,omitempty"`
}
bytes, err := json.Marshal(testTypeOmitEmpty{})
if err != nil {
t.Fatal(err)
}
m := map[string]any{}
json.Unmarshal(bytes, &m)
if _, exists := m["type"]; exists {
t.Fatal("type should not be present")
}
}
29 changes: 25 additions & 4 deletions api/sys_plugins.go
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,22 @@ func (c *Sys) DeregisterPluginWithContext(ctx context.Context, i *DeregisterPlug
return err
}

// RootReloadPluginInput is used as input to the RootReloadPlugin function.
type RootReloadPluginInput struct {
Plugin string `json:"-"` // Plugin name, as registered in the plugin catalog.
Type PluginType `json:"-"` // Plugin type: auth, secret, or database.
Scope string `json:"scope,omitempty"` // Empty to reload on current node, "global" for all nodes.
}

// RootReloadPlugin reloads plugins, possibly returning reloadID for a global
// scoped reload. This is only available in the root namespace, and reloads
// plugins across all namespaces, whereas ReloadPlugin is available in all
// namespaces but only reloads plugins in use in the request's namespace.
func (c *Sys) RootReloadPlugin(ctx context.Context, i *RootReloadPluginInput) (string, error) {
path := fmt.Sprintf("/v1/sys/plugins/reload/%s/%s", i.Type.String(), i.Plugin)
return c.reloadPluginInternal(ctx, path, i, i.Scope == "global")
}

// ReloadPluginInput is used as input to the ReloadPlugin function.
type ReloadPluginInput struct {
// Plugin is the name of the plugin to reload, as registered in the plugin catalog
Expand All @@ -292,15 +308,20 @@ func (c *Sys) ReloadPlugin(i *ReloadPluginInput) (string, error) {
}

// ReloadPluginWithContext reloads mounted plugin backends, possibly returning
// reloadId for a cluster scoped reload
// reloadID for a cluster scoped reload. It is limited to reloading plugins that
// are in use in the request's namespace. See RootReloadPlugin for an API that
// can reload plugins across all namespaces.
func (c *Sys) ReloadPluginWithContext(ctx context.Context, i *ReloadPluginInput) (string, error) {
return c.reloadPluginInternal(ctx, "/v1/sys/plugins/reload/backend", i, i.Scope == "global")
}

func (c *Sys) reloadPluginInternal(ctx context.Context, path string, body any, global bool) (string, error) {
ctx, cancelFunc := c.c.withConfiguredTimeout(ctx)
defer cancelFunc()

path := "/v1/sys/plugins/reload/backend"
req := c.c.NewRequest(http.MethodPut, path)

if err := req.SetJSONBody(i); err != nil {
if err := req.SetJSONBody(body); err != nil {
return "", err
}

Expand All @@ -310,7 +331,7 @@ func (c *Sys) ReloadPluginWithContext(ctx context.Context, i *ReloadPluginInput)
}
defer resp.Body.Close()

if i.Scope == "global" {
if global {
// Get the reload id
secret, parseErr := ParseSecret(resp.Body)
if parseErr != nil {
Expand Down
6 changes: 6 additions & 0 deletions changelog/24878.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
```release-note:improvement
plugins: New API `sys/plugins/reload/:type/:name` available in the root namespace for reloading a specific plugin across all namespaces.
```
```release-note:change
cli: Using `vault plugin reload` with `-plugin` in the root namespace will now reload the plugin across all namespaces instead of just the root namespace.
```
6 changes: 3 additions & 3 deletions command/plugin_register_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ func TestFlagParsing(t *testing.T) {
pluginType: api.PluginTypeUnknown,
name: "foo",
sha256: "abc123",
expectedPayload: `{"type":0,"command":"foo","sha256":"abc123"}`,
expectedPayload: `{"type":"unknown","command":"foo","sha256":"abc123"}`,
},
"full": {
pluginType: api.PluginTypeCredential,
Expand All @@ -261,14 +261,14 @@ func TestFlagParsing(t *testing.T) {
sha256: "abc123",
args: []string{"--a=b", "--b=c", "positional"},
env: []string{"x=1", "y=2"},
expectedPayload: `{"type":1,"args":["--a=b","--b=c","positional"],"command":"cmd","sha256":"abc123","version":"v1.0.0","oci_image":"image","runtime":"runtime","env":["x=1","y=2"]}`,
expectedPayload: `{"type":"auth","args":["--a=b","--b=c","positional"],"command":"cmd","sha256":"abc123","version":"v1.0.0","oci_image":"image","runtime":"runtime","env":["x=1","y=2"]}`,
},
"command remains empty if oci_image specified": {
pluginType: api.PluginTypeCredential,
name: "name",
ociImage: "image",
sha256: "abc123",
expectedPayload: `{"type":1,"sha256":"abc123","oci_image":"image"}`,
expectedPayload: `{"type":"auth","sha256":"abc123","oci_image":"image"}`,
},
} {
tc := tc
Expand Down
72 changes: 57 additions & 15 deletions command/plugin_reload.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
package command

import (
"context"
"fmt"
"strings"

Expand All @@ -19,9 +20,10 @@ var (

type PluginReloadCommand struct {
*BaseCommand
plugin string
mounts []string
scope string
plugin string
mounts []string
scope string
pluginType string
}

func (c *PluginReloadCommand) Synopsis() string {
Expand All @@ -36,9 +38,16 @@ Usage: vault plugin reload [options]
mount(s) must be provided, but not both. In case the plugin name is provided,
all of its corresponding mounted paths that use the plugin backend will be reloaded.

Reload the plugin named "my-custom-plugin":
If run with a Vault namespace other than the root namespace, only plugins
running in the same namespace will be reloaded.

$ vault plugin reload -plugin=my-custom-plugin
Reload the secret plugin named "my-custom-plugin" on the current node:

$ vault plugin reload -type=secret -plugin=my-custom-plugin

Reload the secret plugin named "my-custom-plugin" across all nodes and replicated clusters:

$ vault plugin reload -type=secret -plugin=my-custom-plugin -scope=global

` + c.Flags().Help()

Expand Down Expand Up @@ -68,7 +77,15 @@ func (c *PluginReloadCommand) Flags() *FlagSets {
Name: "scope",
Target: &c.scope,
Completion: complete.PredictAnything,
Usage: "The scope of the reload, omitted for local, 'global', for replicated reloads",
Usage: "The scope of the reload, omitted for local, 'global', for replicated reloads.",
})

f.StringVar(&StringVar{
Name: "type",
Target: &c.pluginType,
Completion: complete.PredictAnything,
Usage: "The type of plugin to reload, one of auth, secret, or database. Mutually " +
"exclusive with -mounts. If not provided, all plugins with a matching name will be reloaded.",
})

return set
Expand Down Expand Up @@ -103,6 +120,10 @@ func (c *PluginReloadCommand) Run(args []string) int {
return 1
case c.scope != "" && c.scope != "global":
c.UI.Error(fmt.Sprintf("Invalid reload scope: %s", c.scope))
return 1
case len(c.mounts) > 0 && c.pluginType != "":
c.UI.Error("Cannot specify -type with -mounts")
return 1
}

client, err := c.Client()
Expand All @@ -111,25 +132,46 @@ func (c *PluginReloadCommand) Run(args []string) int {
return 2
}

rid, err := client.Sys().ReloadPlugin(&api.ReloadPluginInput{
Plugin: c.plugin,
Mounts: c.mounts,
Scope: c.scope,
})
var reloadID string
if client.Namespace() == "" {
pluginType := api.PluginTypeUnknown
pluginTypeStr := strings.TrimSpace(c.pluginType)
if pluginTypeStr != "" {
var err error
pluginType, err = api.ParsePluginType(pluginTypeStr)
if err != nil {
c.UI.Error(fmt.Sprintf("Error parsing -type as a plugin type, must be unset or one of auth, secret, or database: %s", err))
return 1
}
}

reloadID, err = client.Sys().RootReloadPlugin(context.Background(), &api.RootReloadPluginInput{
Plugin: c.plugin,
Type: pluginType,
Scope: c.scope,
})
} else {
reloadID, err = client.Sys().ReloadPlugin(&api.ReloadPluginInput{
Plugin: c.plugin,
Mounts: c.mounts,
Scope: c.scope,
})
}

if err != nil {
c.UI.Error(fmt.Sprintf("Error reloading plugin/mounts: %s", err))
return 2
}

if len(c.mounts) > 0 {
if rid != "" {
c.UI.Output(fmt.Sprintf("Success! Reloading mounts: %s, reload_id: %s", c.mounts, rid))
if reloadID != "" {
c.UI.Output(fmt.Sprintf("Success! Reloading mounts: %s, reload_id: %s", c.mounts, reloadID))
} else {
c.UI.Output(fmt.Sprintf("Success! Reloaded mounts: %s", c.mounts))
}
} else {
if rid != "" {
c.UI.Output(fmt.Sprintf("Success! Reloading plugin: %s, reload_id: %s", c.plugin, rid))
if reloadID != "" {
c.UI.Output(fmt.Sprintf("Success! Reloading plugin: %s, reload_id: %s", c.plugin, reloadID))
} else {
c.UI.Output(fmt.Sprintf("Success! Reloaded plugin: %s", c.plugin))
}
Expand Down
12 changes: 12 additions & 0 deletions command/plugin_reload_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,18 @@ func TestPluginReloadCommand_Run(t *testing.T) {
"Must specify exactly one of -plugin or -mounts",
1,
},
{
"type_and_mounts_mutually_exclusive",
[]string{"-mounts", "bar", "-type", "secret"},
"Cannot specify -type with -mounts",
1,
},
{
"invalid_type",
[]string{"-plugin", "bar", "-type", "unsupported"},
"Error parsing -type as a plugin type",
1,
},
}

for _, tc := range cases {
Expand Down
Loading