Skip to content

Commit

Permalink
Check if plugin version matches running version (#17182)
Browse files Browse the repository at this point in the history
Check if plugin version matches running version

When registering a plugin, we check if the request version matches the
self-reported version from the plugin. If these do not match, we log a
warning.

This uncovered a few missing pieces for getting the database version
code fully working.

We added an environment variable that helps us unit test the running
version behavior as well, but only for approle, postgresql, and consul
plugins.

Return 400 on plugin not found or version mismatch

Populate the running SHA256 of plugins in the mount and auth tables (#17217)
  • Loading branch information
Christopher Swenson authored Sep 21, 2022
1 parent c3c323d commit 0b34b73
Show file tree
Hide file tree
Showing 13 changed files with 394 additions and 72 deletions.
8 changes: 6 additions & 2 deletions builtin/credential/approle/backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ const (
secretIDAccessorLocalPrefix = "accessor_local/"
)

// ReportedVersion is used to report a specific version to Vault.
var ReportedVersion = ""

type backend struct {
*framework.Backend

Expand Down Expand Up @@ -111,8 +114,9 @@ func Backend(conf *logical.BackendConfig) (*backend, error) {
pathTidySecretID(b),
},
),
Invalidate: b.invalidate,
BackendType: logical.TypeCredential,
Invalidate: b.invalidate,
BackendType: logical.TypeCredential,
RunningVersion: ReportedVersion,
}
return b, nil
}
Expand Down
6 changes: 5 additions & 1 deletion builtin/logical/consul/backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ import (
"github.com/hashicorp/vault/sdk/logical"
)

// ReportedVersion is used to report a specific version to Vault.
var ReportedVersion = ""

func Factory(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, error) {
b := Backend()
if err := b.Setup(ctx, conf); err != nil {
Expand Down Expand Up @@ -34,7 +37,8 @@ func Backend() *backend {
Secrets: []*framework.Secret{
secretToken(&b),
},
BackendType: logical.TypeLogical,
BackendType: logical.TypeLogical,
RunningVersion: ReportedVersion,
}

return &b
Expand Down
11 changes: 10 additions & 1 deletion plugins/database/postgresql/postgresql.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"github.com/hashicorp/vault/sdk/database/helper/dbutil"
"github.com/hashicorp/vault/sdk/helper/dbtxn"
"github.com/hashicorp/vault/sdk/helper/template"
"github.com/hashicorp/vault/sdk/logical"
_ "github.com/jackc/pgx/v4/stdlib"
)

Expand All @@ -32,7 +33,8 @@ ALTER ROLE "{{username}}" WITH PASSWORD '{{password}}';
)

var (
_ dbplugin.Database = &PostgreSQL{}
_ dbplugin.Database = (*PostgreSQL)(nil)
_ logical.PluginVersioner = (*PostgreSQL)(nil)

// postgresEndStatement is basically the word "END" but
// surrounded by a word boundary to differentiate it from
Expand All @@ -46,6 +48,9 @@ var (
// singleQuotedPhrases finds substrings like 'hello'
// and pulls them out with the quotes included.
singleQuotedPhrases = regexp.MustCompile(`('.*?')`)

// ReportedVersion is used to report a specific version to Vault.
ReportedVersion = ""
)

func New() (interface{}, error) {
Expand Down Expand Up @@ -469,6 +474,10 @@ func (p *PostgreSQL) secretValues() map[string]string {
}
}

func (p *PostgreSQL) PluginVersion() logical.PluginVersion {
return logical.PluginVersion{Version: ReportedVersion}
}

// containsMultilineStatement is a best effort to determine whether
// a particular statement is multiline, and therefore should not be
// split upon semicolons. If it's unsure, it defaults to false.
Expand Down
33 changes: 31 additions & 2 deletions sdk/database/dbplugin/v5/grpc_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@ package dbplugin

import (
"context"
"errors"
"fmt"
"sync"
"time"

"github.com/golang/protobuf/ptypes"
"github.com/hashicorp/vault/sdk/database/dbplugin/v5/proto"
"github.com/hashicorp/vault/sdk/helper/base62"
"github.com/hashicorp/vault/sdk/helper/pluginutil"
"github.com/hashicorp/vault/sdk/logical"
"google.golang.org/grpc/codes"
Expand Down Expand Up @@ -43,11 +45,14 @@ func (g *gRPCServer) getOrCreateDatabase(ctx context.Context) (Database, error)
if err != nil {
return nil, err
}

if db, ok := g.instances[id]; ok {
return db, nil
}
return g.createDatabase(id)
}

// must hold the g.Lock() to call this function
func (g *gRPCServer) createDatabase(id string) (Database, error) {
db, err := g.factoryFunc()
if err != nil {
return nil, err
Expand Down Expand Up @@ -304,12 +309,36 @@ func (g *gRPCServer) Close(ctx context.Context, _ *proto.Empty) (*proto.Empty, e
return &proto.Empty{}, nil
}

// getOrForceCreateDatabase will create a database even if the multiplexing ID is not present
func (g *gRPCServer) getOrForceCreateDatabase(ctx context.Context) (Database, error) {
impl, err := g.getOrCreateDatabase(ctx)
if errors.Is(err, pluginutil.ErrNoMultiplexingIDFound) {
// if this is called without a multiplexing context, like from the plugin catalog directly,
// then we won't have a database ID, so let's generate a new database instance
id, err := base62.Random(10)
if err != nil {
return nil, err
}

g.Lock()
defer g.Unlock()
impl, err = g.createDatabase(id)
if err != nil {
return nil, err
}
} else if err != nil {
return nil, err
}
return impl, nil
}

// Version forwards the version request to the underlying Database implementation.
func (g *gRPCServer) Version(ctx context.Context, _ *logical.Empty) (*logical.VersionReply, error) {
impl, err := g.getDatabaseInternal(ctx)
impl, err := g.getOrForceCreateDatabase(ctx)
if err != nil {
return nil, err
}

if versioner, ok := impl.(logical.PluginVersioner); ok {
return &logical.VersionReply{PluginVersion: versioner.PluginVersion().Version}, nil
}
Expand Down
12 changes: 11 additions & 1 deletion sdk/database/dbplugin/v5/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,10 @@ func (mw databaseMetricsMiddleware) Close() (err error) {
// Error Sanitizer Middleware Domain
// ///////////////////////////////////////////////////

var _ Database = DatabaseErrorSanitizerMiddleware{}
var (
_ Database = (*DatabaseErrorSanitizerMiddleware)(nil)
_ logical.PluginVersioner = (*DatabaseErrorSanitizerMiddleware)(nil)
)

// DatabaseErrorSanitizerMiddleware wraps an implementation of Databases and
// sanitizes returned error messages
Expand Down Expand Up @@ -280,6 +283,13 @@ func (mw DatabaseErrorSanitizerMiddleware) Close() (err error) {
return mw.sanitize(mw.next.Close())
}

func (mw DatabaseErrorSanitizerMiddleware) PluginVersion() logical.PluginVersion {
if versioner, ok := mw.next.(logical.PluginVersioner); ok {
return versioner.PluginVersion()
}
return logical.EmptyPluginVersion
}

// sanitize errors by removing any sensitive strings within their messages. This uses
// the secretsFn to determine what fields should be sanitized.
func (mw DatabaseErrorSanitizerMiddleware) sanitize(err error) error {
Expand Down
7 changes: 6 additions & 1 deletion sdk/helper/pluginutil/multiplexing.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package pluginutil

import (
"context"
"errors"
"fmt"
"os"
"strings"
Expand All @@ -13,6 +14,8 @@ import (
"google.golang.org/grpc/status"
)

var ErrNoMultiplexingIDFound = errors.New("no multiplexing ID found")

type PluginMultiplexingServerImpl struct {
UnimplementedPluginMultiplexingServer

Expand Down Expand Up @@ -62,7 +65,9 @@ func GetMultiplexIDFromContext(ctx context.Context) (string, error) {
}

multiplexIDs := md[MultiplexingCtxKey]
if len(multiplexIDs) != 1 {
if len(multiplexIDs) == 0 {
return "", ErrNoMultiplexingIDFound
} else if len(multiplexIDs) != 1 {
return "", fmt.Errorf("unexpected number of IDs in metadata: (%d)", len(multiplexIDs))
}

Expand Down
22 changes: 14 additions & 8 deletions vault/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package vault

import (
"context"
"encoding/hex"
"errors"
"fmt"
"strings"
Expand Down Expand Up @@ -170,7 +171,7 @@ func (c *Core) enableCredentialInternal(ctx context.Context, entry *MountEntry,
var backend logical.Backend
// Create the new backend
sysView := c.mountEntrySysView(entry)
backend, err = c.newCredentialBackend(ctx, entry, sysView, view)
backend, entry.RunningSha, err = c.newCredentialBackend(ctx, entry, sysView, view)
if err != nil {
return err
}
Expand Down Expand Up @@ -794,7 +795,7 @@ func (c *Core) setupCredentials(ctx context.Context) error {
// Initialize the backend
sysView := c.mountEntrySysView(entry)

backend, err = c.newCredentialBackend(ctx, entry, sysView, view)
backend, entry.RunningSha, err = c.newCredentialBackend(ctx, entry, sysView, view)
if err != nil {
c.logger.Error("failed to create credential entry", "path", entry.Path, "error", err)
plug, plugerr := c.pluginCatalog.Get(ctx, entry.Type, consts.PluginTypeCredential, "")
Expand Down Expand Up @@ -913,25 +914,30 @@ func (c *Core) teardownCredentials(ctx context.Context) error {
return nil
}

// newCredentialBackend is used to create and configure a new credential backend by name
func (c *Core) newCredentialBackend(ctx context.Context, entry *MountEntry, sysView logical.SystemView, view logical.Storage) (logical.Backend, error) {
// newCredentialBackend is used to create and configure a new credential backend by name.
// It also returns the SHA256 of the plugin, if available.
func (c *Core) newCredentialBackend(ctx context.Context, entry *MountEntry, sysView logical.SystemView, view logical.Storage) (logical.Backend, string, error) {
t := entry.Type
if alias, ok := credentialAliases[t]; ok {
t = alias
}

var runningSha string
f, ok := c.credentialBackends[t]
if !ok {
plug, err := c.pluginCatalog.Get(ctx, t, consts.PluginTypeCredential, entry.Version)
if err != nil {
return nil, err
return nil, "", err
}
if plug == nil {
errContext := t
if entry.Version != "" {
errContext += fmt.Sprintf(", version=%s", entry.Version)
}
return nil, fmt.Errorf("%w: %s", ErrPluginNotFound, errContext)
return nil, "", fmt.Errorf("%w: %s", ErrPluginNotFound, errContext)
}
if len(plug.Sha256) > 0 {
runningSha = hex.EncodeToString(plug.Sha256)
}

f = plugin.Factory
Expand Down Expand Up @@ -967,10 +973,10 @@ func (c *Core) newCredentialBackend(ctx context.Context, entry *MountEntry, sysV

b, err := f(ctx, config)
if err != nil {
return nil, err
return nil, "", err
}

return b, nil
return b, runningSha, nil
}

// defaultAuthTable creates a default auth table
Expand Down
Loading

0 comments on commit 0b34b73

Please sign in to comment.