Skip to content

Commit

Permalink
Load plugins when starting API server (#294)
Browse files Browse the repository at this point in the history
  • Loading branch information
dandans-verily authored Jul 7, 2023
1 parent 1d81363 commit 525f241
Show file tree
Hide file tree
Showing 7 changed files with 183 additions and 8 deletions.
10 changes: 9 additions & 1 deletion pkg/cli/api_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"github.com/abcxyz/jvs/internal/version"
"github.com/abcxyz/jvs/pkg/config"
"github.com/abcxyz/jvs/pkg/justification"
"github.com/abcxyz/jvs/pkg/plugin"
"github.com/abcxyz/pkg/cli"
"github.com/abcxyz/pkg/healthcheck"
"github.com/abcxyz/pkg/logging"
Expand Down Expand Up @@ -117,7 +118,14 @@ func (c *APIServerCommand) RunUnstarted(ctx context.Context, args []string) (*se
// Create basic health check
healthcheck.RegisterGRPCHealthCheck(grpcServer)

p := justification.NewProcessor(kmsClient, c.cfg)
validators, pluginClosers, err := plugin.LoadPlugins(c.cfg.PluginDir)
if err != nil {
return nil, nil, closer, fmt.Errorf("failed to load plugins: %w", err)
}
logger.Infow("plugins loaded", "validators", validators)
closer = multicloser.Append(closer, pluginClosers.Close)

p := justification.NewProcessor(kmsClient, c.cfg).WithValidators(validators)
jvsAgent := justification.NewJVSAgent(p)
jvspb.RegisterJVSServiceServer(grpcServer, jvsAgent)
reflection.Register(grpcServer)
Expand Down
11 changes: 11 additions & 0 deletions pkg/config/justification_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@ type JustificationConfig struct {
// Issuer will be used to set the issuer field when signing JWTs
Issuer string `env:"JVS_API_ISSUER,overwrite,default=jvs.abcxyz.dev"`

// PluginDir is the path of the directory to load plugins.
PluginDir string `env:"JVS_PLUGIN_DIR,overwrite,default=/var/jvs/plugins"`

// DefaultTTL sets the default TTL for JVS tokens that do not explicitly
// request a TTL. MaxTTL is the system-configured maximum TTL that a token can
// request.
Expand Down Expand Up @@ -124,6 +127,14 @@ func (cfg *JustificationConfig) ToFlags(set *cli.FlagSet) *cli.FlagSet {
Usage: `The value to set to the issuer claim when signing JVS tokens.`,
})

f.StringVar(&cli.StringVar{
Name: "plugin-dir",
Target: &cfg.PluginDir,
EnvVar: "JVS_PLUGIN_DIR",
Default: "/var/jvs/plugins",
Usage: `The path of the directory to load plugins.`,
})

f.DurationVar(&cli.DurationVar{
Name: "signer-cache-timeout",
Target: &cfg.SignerCacheTimeout,
Expand Down
21 changes: 21 additions & 0 deletions pkg/config/justification_config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ func TestJustificationConfig_ToFlags(t *testing.T) {
"JVS_KEY": "fake/key",
"JVS_API_SIGNER_CACHE_TIMEOUT": "10m",
"JVS_API_ISSUER": "example.com",
"JVS_PLUGIN_DIR": "/var/jvs/pluginsDir",
"JVS_API_DEFAULT_TTL": "30m",
"JVS_API_MAX_TTL": "8h",
},
Expand All @@ -51,6 +52,7 @@ func TestJustificationConfig_ToFlags(t *testing.T) {
KeyName: "fake/key",
SignerCacheTimeout: 10 * time.Minute,
Issuer: "example.com",
PluginDir: "/var/jvs/pluginsDir",
DefaultTTL: 30 * time.Minute,
MaxTTL: 8 * time.Hour,
},
Expand All @@ -61,6 +63,7 @@ func TestJustificationConfig_ToFlags(t *testing.T) {
Port: "8080",
SignerCacheTimeout: 5 * time.Minute,
Issuer: "jvs.abcxyz.dev",
PluginDir: "/var/jvs/plugins",
DefaultTTL: 15 * time.Minute,
MaxTTL: 4 * time.Hour,
},
Expand Down Expand Up @@ -102,6 +105,20 @@ func TestJustificationConfig_Validate(t *testing.T) {
KeyName: "fake/key",
SignerCacheTimeout: 5 * time.Minute,
Issuer: "jvs.abcxyz.dev",
PluginDir: "/var/jvs/pluginsDir",
DefaultTTL: 15 * time.Minute,
MaxTTL: 4 * time.Hour,
},
},
{
name: "relative_plugin_dir",
cfg: &JustificationConfig{
ProjectID: "example-project",
Port: "8080",
KeyName: "fake/key",
SignerCacheTimeout: 5 * time.Minute,
Issuer: "jvs.abcxyz.dev",
PluginDir: "./pluginsDir",
DefaultTTL: 15 * time.Minute,
MaxTTL: 4 * time.Hour,
},
Expand All @@ -113,6 +130,7 @@ func TestJustificationConfig_Validate(t *testing.T) {
KeyName: "fake/key",
SignerCacheTimeout: 5 * time.Minute,
Issuer: "jvs.abcxyz.dev",
PluginDir: "/var/jvs/pluginsDir",
DefaultTTL: 15 * time.Minute,
MaxTTL: 4 * time.Hour,
},
Expand All @@ -125,6 +143,7 @@ func TestJustificationConfig_Validate(t *testing.T) {
Port: "8080",
SignerCacheTimeout: 5 * time.Minute,
Issuer: "jvs.abcxyz.dev",
PluginDir: "/var/jvs/pluginsDir",
DefaultTTL: 15 * time.Minute,
MaxTTL: 4 * time.Hour,
},
Expand All @@ -138,6 +157,7 @@ func TestJustificationConfig_Validate(t *testing.T) {
KeyName: "fake/key",
SignerCacheTimeout: -5 * time.Minute,
Issuer: "jvs.abcxyz.dev",
PluginDir: "/var/jvs/pluginsDir",
DefaultTTL: 15 * time.Minute,
MaxTTL: 4 * time.Hour,
},
Expand All @@ -151,6 +171,7 @@ func TestJustificationConfig_Validate(t *testing.T) {
KeyName: "fake/key",
SignerCacheTimeout: 5 * time.Minute,
Issuer: "jvs.abcxyz.dev",
PluginDir: "/var/jvs/pluginsDir",
DefaultTTL: 15 * time.Minute,
MaxTTL: 10 * time.Minute,
},
Expand Down
6 changes: 6 additions & 0 deletions pkg/config/ui_config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ func TestUIServiceConfig_ToFlags(t *testing.T) {
"JVS_KEY": "fake/key",
"JVS_API_SIGNER_CACHE_TIMEOUT": "10m",
"JVS_API_ISSUER": "example.com",
"JVS_PLUGIN_DIR": "/var/jvs/pluginsDir",
"JVS_API_DEFAULT_TTL": "30m",
"JVS_API_MAX_TTL": "8h",
"JVS_UI_ALLOWLIST": "example.com,*.foo.bar",
Expand All @@ -53,6 +54,7 @@ func TestUIServiceConfig_ToFlags(t *testing.T) {
KeyName: "fake/key",
SignerCacheTimeout: 10 * time.Minute,
Issuer: "example.com",
PluginDir: "/var/jvs/pluginsDir",
DefaultTTL: 30 * time.Minute,
MaxTTL: 8 * time.Hour,
},
Expand All @@ -66,6 +68,7 @@ func TestUIServiceConfig_ToFlags(t *testing.T) {
Port: "8080",
SignerCacheTimeout: 5 * time.Minute,
Issuer: "jvs.abcxyz.dev",
PluginDir: "/var/jvs/plugins",
DefaultTTL: 15 * time.Minute,
MaxTTL: 4 * time.Hour,
},
Expand Down Expand Up @@ -109,6 +112,7 @@ func TestUIServiceConfig_Validate(t *testing.T) {
KeyName: "fake/key",
SignerCacheTimeout: 5 * time.Minute,
Issuer: "jvs.abcxyz.dev",
PluginDir: "/var/jvs/pluginsDir",
DefaultTTL: 15 * time.Minute,
MaxTTL: 4 * time.Hour,
},
Expand All @@ -124,6 +128,7 @@ func TestUIServiceConfig_Validate(t *testing.T) {
KeyName: "fake/key",
SignerCacheTimeout: 5 * time.Minute,
Issuer: "jvs.abcxyz.dev",
PluginDir: "/var/jvs/pluginsDir",
DefaultTTL: 15 * time.Minute,
MaxTTL: 4 * time.Hour,
},
Expand All @@ -139,6 +144,7 @@ func TestUIServiceConfig_Validate(t *testing.T) {
KeyName: "fake/key",
SignerCacheTimeout: 5 * time.Minute,
Issuer: "jvs.abcxyz.dev",
PluginDir: "/var/jvs/pluginsDir",
DefaultTTL: 15 * time.Minute,
MaxTTL: 4 * time.Hour,
},
Expand Down
34 changes: 28 additions & 6 deletions pkg/justification/processor.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,10 @@ import (
// mints a token.
type Processor struct {
jvspb.UnimplementedJVSServiceServer
kms *kms.KeyManagementClient
config *config.JustificationConfig
cache *cache.Cache[*signerWithID]
kms *kms.KeyManagementClient
config *config.JustificationConfig
cache *cache.Cache[*signerWithID]
validators map[string]jvspb.Validator
}

type signerWithID struct {
Expand All @@ -68,14 +69,19 @@ const (
DefaultAudience = "dev.abcxyz.jvs"
)

func (p *Processor) WithValidators(v map[string]jvspb.Validator) *Processor {
p.validators = v
return p
}

// CreateToken implements the create token API which creates and signs a JWT
// token if the provided justifications are valid.
func (p *Processor) CreateToken(ctx context.Context, requestor string, req *jvspb.CreateJustificationRequest) ([]byte, error) {
now := time.Now().UTC()

logger := logging.FromContext(ctx)

if err := p.runValidations(req); err != nil {
if err := p.runValidations(ctx, req); err != nil {
logger.Errorw("failed to validate request", "error", err)
return nil, status.Errorf(codes.InvalidArgument, "failed to validate request: %s", err)
}
Expand Down Expand Up @@ -130,7 +136,7 @@ func (p *Processor) getPrimarySigner(ctx context.Context) (*signerWithID, error)
}

// TODO: Each category should have its own validator struct, with a shared interface.
func (p *Processor) runValidations(req *jvspb.CreateJustificationRequest) error {
func (p *Processor) runValidations(ctx context.Context, req *jvspb.CreateJustificationRequest) error {
if len(req.Justifications) < 1 {
return fmt.Errorf("no justifications specified")
}
Expand All @@ -146,7 +152,23 @@ func (p *Processor) runValidations(req *jvspb.CreateJustificationRequest) error
err = multierror.Append(err, fmt.Errorf("no value specified for 'explanation' category"))
}
default:
err = multierror.Append(err, fmt.Errorf("unexpected justification %v unrecognized", j))
v, ok := p.validators[j.Category]
if !ok {
err = multierror.Append(err, fmt.Errorf("missing validator for category %q", j.Category))
continue
}
resp, verr := v.Validate(ctx, &jvspb.ValidateJustificationRequest{
Justification: j,
})
if verr != nil {
err = multierror.Append(err, fmt.Errorf("unexpected error from validator %q: %w", j.Category, verr))
continue
}

if !resp.Valid {
err = multierror.Append(err,
fmt.Errorf("failed validation criteria with error %v and warning %v", resp.Error, resp.Warning))
}
}
}

Expand Down
108 changes: 107 additions & 1 deletion pkg/justification/processor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"crypto/rand"
"crypto/x509"
"encoding/pem"
"fmt"
"reflect"
"strings"
"testing"
Expand All @@ -45,13 +46,23 @@ import (
"google.golang.org/protobuf/types/known/durationpb"
)

type mockValidator struct {
resp *jvspb.ValidateJustificationResponse
err error
}

func (m *mockValidator) Validate(ctx context.Context, req *jvspb.ValidateJustificationRequest) (*jvspb.ValidateJustificationResponse, error) {
return m.resp, m.err
}

func TestCreateToken(t *testing.T) {
t.Parallel()

tests := []struct {
name string
request *jvspb.CreateJustificationRequest
requestor string
validators map[string]jvspb.Validator
wantTTL time.Duration
wantSubject string
wantAudiences []string
Expand Down Expand Up @@ -179,6 +190,101 @@ func TestCreateToken(t *testing.T) {
},
wantErr: "must be less than 1000 bytes",
},
{
name: "happy_path_with_validator",
request: &jvspb.CreateJustificationRequest{
Justifications: []*jvspb.Justification{
{
Category: "jira",
Value: "test",
},
},
Ttl: durationpb.New(3600 * time.Second),
},
validators: map[string]jvspb.Validator{
"jira": &mockValidator{
resp: &jvspb.ValidateJustificationResponse{
Valid: true,
},
},
},
wantTTL: 1 * time.Hour,
wantAudiences: []string{DefaultAudience},
},
{
name: "happy_path_with_unused_validator",
request: &jvspb.CreateJustificationRequest{
Justifications: []*jvspb.Justification{
{
Category: "explanation",
Value: "test",
},
},
Ttl: durationpb.New(3600 * time.Second),
},
validators: map[string]jvspb.Validator{
"jira": &mockValidator{
resp: &jvspb.ValidateJustificationResponse{
Valid: false,
Error: []string{"bad jira ticket"},
},
},
},
wantTTL: 1 * time.Hour,
wantAudiences: []string{DefaultAudience},
},
{
name: "failed_validator_criteria",
request: &jvspb.CreateJustificationRequest{
Justifications: []*jvspb.Justification{
{
Category: "jira",
Value: "test",
},
},
Ttl: durationpb.New(3600 * time.Second),
},
validators: map[string]jvspb.Validator{
"jira": &mockValidator{
resp: &jvspb.ValidateJustificationResponse{
Valid: false,
Error: []string{"bad explanation"},
},
},
},
wantErr: "failed validation criteria with error [bad explanation] and warning []",
},
{
name: "validator_err",
request: &jvspb.CreateJustificationRequest{
Justifications: []*jvspb.Justification{
{
Category: "jira",
Value: "test",
},
},
Ttl: durationpb.New(3600 * time.Second),
},
validators: map[string]jvspb.Validator{
"jira": &mockValidator{
err: fmt.Errorf("Cannot connect to validator"),
},
},
wantErr: "unexpected error from validator \"jira\": Cannot connect to validator",
},
{
name: "missing_validator",
request: &jvspb.CreateJustificationRequest{
Justifications: []*jvspb.Justification{
{
Category: "jira",
Value: "test",
},
},
Ttl: durationpb.New(3600 * time.Second),
},
wantErr: "missing validator for category \"jira\"",
},
}

for _, tc := range tests {
Expand Down Expand Up @@ -246,7 +352,7 @@ func TestCreateToken(t *testing.T) {
Issuer: "test-iss",
DefaultTTL: 1 * time.Minute,
MaxTTL: 1 * time.Hour,
})
}).WithValidators(tc.validators)

mockKeyManagement.Reqs = nil
mockKeyManagement.Err = tc.serverErr
Expand Down
1 change: 1 addition & 0 deletions test/integ/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ func TestJVS(t *testing.T) {
ProjectID: os.Getenv("PROJECT_ID"),
KeyName: keyName,
Issuer: "ci-test",
PluginDir: "/var/jvs/plugins",
SignerCacheTimeout: 1 * time.Nanosecond, // no caching
DefaultTTL: 15 * time.Minute,
MaxTTL: 2 * time.Hour,
Expand Down

0 comments on commit 525f241

Please sign in to comment.