diff --git a/pkg/config/justification_config.go b/pkg/config/justification_config.go index 9e4386ee..cdfdfcd4 100644 --- a/pkg/config/justification_config.go +++ b/pkg/config/justification_config.go @@ -19,6 +19,7 @@ import ( "fmt" "time" + "github.com/abcxyz/pkg/timeutil" "github.com/hashicorp/go-multierror" ) @@ -43,6 +44,14 @@ type JustificationConfig struct { // Issuer will be used to set the issuer field when signing JWTs Issuer string `yaml:"issuer" env:"ISSUER,overwrite,default=jvs.abcxyz.dev"` + + // DefaultTTL sets the default TTL for JVS tokens that do not explicitly + // request a TTL. MaxTTL is the system-configured maximum TTL that a token can + // request. + // + // The DefaultTTL must be less than or equal to MaxTTL. + DefaultTTL time.Duration `yaml:"default_ttl" env:"DEFAULT_TTL,overwrite,default=15m"` + MaxTTL time.Duration `yaml:"max_ttl" env:"MAX_TTL,overwrite,default=4h"` } // Validate checks if the config is valid. @@ -58,5 +67,11 @@ func (cfg *JustificationConfig) Validate() error { err = multierror.Append(err, fmt.Errorf("cache timeout must be a positive duration, got %s", cfg.SignerCacheTimeout)) } + + if def, max := cfg.DefaultTTL, cfg.MaxTTL; def > max { + err = multierror.Append(err, fmt.Errorf("default ttl (%s) must be less than or equal to the max ttl (%s)", + timeutil.HumanDuration(def), timeutil.HumanDuration(max))) + } + return err.ErrorOrNil() } diff --git a/pkg/config/justification_config_test.go b/pkg/config/justification_config_test.go index 18bcb84e..836d0b14 100644 --- a/pkg/config/justification_config_test.go +++ b/pkg/config/justification_config_test.go @@ -41,6 +41,8 @@ func TestJustificationConfig_Defaults(t *testing.T) { Port: "8080", SignerCacheTimeout: 5 * time.Minute, Issuer: "jvs.abcxyz.dev", + DefaultTTL: 15 * time.Minute, + MaxTTL: 4 * time.Hour, } if diff := cmp.Diff(want, &justificationConfig); diff != "" { @@ -66,12 +68,16 @@ port: 123 version: 1 signer_cache_timeout: 1m issuer: jvs +default_ttl: 5m +max_ttl: 30m `, wantConfig: &JustificationConfig{ Port: "123", Version: "1", SignerCacheTimeout: 1 * time.Minute, Issuer: "jvs", + DefaultTTL: 5 * time.Minute, + MaxTTL: 30 * time.Minute, }, }, { @@ -82,10 +88,12 @@ issuer: jvs Version: "1", SignerCacheTimeout: 5 * time.Minute, Issuer: "jvs.abcxyz.dev", + DefaultTTL: 15 * time.Minute, + MaxTTL: 4 * time.Hour, }, }, { - name: "test_wrong_version", + name: "wrong_version", cfg: ` version: 255 `, @@ -93,13 +101,22 @@ version: 255 wantErr: `version "255" is invalid, valid versions are:`, }, { - name: "test_invalid_signer_cache_timeout", + name: "invalid_signer_cache_timeout", cfg: ` signer_cache_timeout: -1m `, wantConfig: nil, wantErr: `cache timeout must be a positive duration, got -1m0s`, }, + { + name: "default_ttl_greater_than_max_ttl", + cfg: ` +default_ttl: 1h +max_ttl: 30m +`, + wantConfig: nil, + wantErr: `default ttl (1h) must be less than or equal to the max ttl (30m)`, + }, { name: "all_values_specified_env_override", cfg: ` @@ -107,18 +124,24 @@ version: 1 port: 8080 signer_cache_timeout: 1m issuer: jvs +default_ttl: 15m +max_ttl: 1h `, envs: map[string]string{ "VERSION": "1", "PORT": "tcp", "SIGNER_CACHE_TIMEOUT": "2m", "ISSUER": "other", + "DEFAULT_TTL": "30m", + "MAX_TTL": "2h", }, wantConfig: &JustificationConfig{ Version: "1", Port: "tcp", SignerCacheTimeout: 2 * time.Minute, Issuer: "other", + DefaultTTL: 30 * time.Minute, + MaxTTL: 2 * time.Hour, }, }, } diff --git a/pkg/justification/processor.go b/pkg/justification/processor.go index 3d46af0e..d70393d3 100644 --- a/pkg/justification/processor.go +++ b/pkg/justification/processor.go @@ -26,6 +26,7 @@ import ( "github.com/abcxyz/pkg/cache" "github.com/abcxyz/pkg/grpcutil" "github.com/abcxyz/pkg/logging" + "github.com/abcxyz/pkg/timeutil" "github.com/google/uuid" "github.com/hashicorp/go-multierror" "github.com/lestrrat-go/jwx/v2/jwa" @@ -79,28 +80,28 @@ func (p *Processor) CreateToken(ctx context.Context, req *jvspb.CreateJustificat if err := p.runValidations(req); err != nil { logger.Errorw("failed to validate request", "error", err) - return nil, status.Error(codes.InvalidArgument, "failed to validate request") + return nil, status.Errorf(codes.InvalidArgument, "failed to validate request: %s", err) } token, err := p.createToken(ctx, now, req) if err != nil { logger.Errorw("failed to create token", "error", err) - return nil, status.Error(codes.Internal, "failed to create token") + return nil, status.Errorf(codes.Internal, "failed to create token: %s", err) } signer, err := p.cache.WriteThruLookup(cacheKey, func() (*signerWithID, error) { return p.getPrimarySigner(ctx) }) if err != nil { - logger.Errorw("failed to get signer", "error", err) - return nil, status.Error(codes.Internal, "failed to get token signer") + logger.Errorw("failed to get token signer", "error", err) + return nil, status.Errorf(codes.Internal, "failed to get token signer: %s", err) } // Build custom headers and set the "kid" as the signer ID. headers := jws.NewHeaders() if err := headers.Set(jws.KeyIDKey, signer.id); err != nil { logger.Errorw("failed to set kid header", "error", err) - return nil, status.Error(codes.Internal, "failed to set token headers") + return nil, status.Errorf(codes.Internal, "failed to set token headers: %s", err) } // Sign the token. @@ -116,14 +117,14 @@ func (p *Processor) CreateToken(ctx context.Context, req *jvspb.CreateJustificat func (p *Processor) getPrimarySigner(ctx context.Context) (*signerWithID, error) { primaryVer, err := jvscrypto.GetPrimary(ctx, p.kms, p.config.KeyName) if err != nil { - return nil, fmt.Errorf("unable to determine primary, %w", err) + return nil, fmt.Errorf("failed to determine primary signing key: %w", err) } if primaryVer == "" { return nil, fmt.Errorf("no primary version found") } sig, err := gcpkms.NewSigner(ctx, p.kms, primaryVer) if err != nil { - return nil, fmt.Errorf("failed to create signer, %w", err) + return nil, fmt.Errorf("failed to create signer: %w", err) } return &signerWithID{ Signer: sig, @@ -137,10 +138,6 @@ func (p *Processor) runValidations(request *jvspb.CreateJustificationRequest) er return fmt.Errorf("no justifications specified") } - if request.Ttl == nil { - return fmt.Errorf("no ttl specified") - } - var err *multierror.Error for _, j := range request.Justifications { switch j.Category { @@ -160,12 +157,18 @@ func (p *Processor) runValidations(request *jvspb.CreateJustificationRequest) er func (p *Processor) createToken(ctx context.Context, now time.Time, req *jvspb.CreateJustificationRequest) (jwt.Token, error) { email, err := p.authHandler.RequestPrincipal(ctx) if err != nil { - return nil, fmt.Errorf("unable to get email of requestor: %w", err) + return nil, fmt.Errorf("failed to get email of requestor: %w", err) + } + + ttl, err := computeTTL(req.Ttl.AsDuration(), p.config.DefaultTTL, p.config.MaxTTL) + if err != nil { + return nil, fmt.Errorf("failed to compute ttl: %w", err) } id := uuid.New().String() - exp := now.Add(req.Ttl.AsDuration()) + exp := now.Add(ttl) justs := req.Justifications + iss := p.config.Issuer // Use audiences in the request if provided. aud := req.Audiences @@ -177,7 +180,7 @@ func (p *Processor) createToken(ctx context.Context, now time.Time, req *jvspb.C Audience(aud). Expiration(exp). IssuedAt(now). - Issuer(p.config.Issuer). + Issuer(iss). JwtID(id). NotBefore(now). Subject(email). @@ -192,3 +195,20 @@ func (p *Processor) createToken(ctx context.Context, now time.Time, req *jvspb.C return token, nil } + +// computeTTL is a helper that computes the best TTL given the requested TTL, +// default TTL, and maximum configured TTL. If the requested TTL is greater than +// the maximum TTL, it returns an error. If the requested TTL is 0, it returns +// the default TTL. +func computeTTL(req, def, max time.Duration) (time.Duration, error) { + if req <= 0 { + return def, nil + } + + if req > max { + return 0, fmt.Errorf("requested ttl (%s) cannot be greater than max tll (%s)", + timeutil.HumanDuration(req), timeutil.HumanDuration(max)) + } + + return req, nil +} diff --git a/pkg/justification/processor_test.go b/pkg/justification/processor_test.go index e361d03c..1ca26c2c 100644 --- a/pkg/justification/processor_test.go +++ b/pkg/justification/processor_test.go @@ -32,6 +32,7 @@ import ( "github.com/abcxyz/jvs/pkg/jvscrypto" "github.com/abcxyz/jvs/pkg/testutil" "github.com/abcxyz/pkg/grpcutil" + "github.com/abcxyz/pkg/logging" pkgtestutil "github.com/abcxyz/pkg/testutil" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" @@ -59,6 +60,7 @@ func TestCreateToken(t *testing.T) { tests := []struct { name string request *jvspb.CreateJustificationRequest + wantTTL time.Duration wantAudiences []string wantErr string serverErr error @@ -74,6 +76,7 @@ func TestCreateToken(t *testing.T) { }, Ttl: durationpb.New(3600 * time.Second), }, + wantTTL: 1 * time.Hour, wantAudiences: []string{DefaultAudience}, }, { @@ -88,6 +91,7 @@ func TestCreateToken(t *testing.T) { Ttl: durationpb.New(3600 * time.Second), Audiences: []string{"aud1", "aud2"}, }, + wantTTL: 1 * time.Hour, wantAudiences: []string{"aud1", "aud2"}, }, { @@ -107,7 +111,21 @@ func TestCreateToken(t *testing.T) { }, }, }, - wantErr: "failed to validate request", + wantTTL: 15 * time.Minute, // comes from default + wantAudiences: []string{"dev.abcxyz.jvs"}, + }, + { + name: "ttl_exceeds_max", + request: &jvspb.CreateJustificationRequest{ + Justifications: []*jvspb.Justification{ + { + Category: "explanation", + Value: "test", + }, + }, + Ttl: durationpb.New(10 * time.Hour), + }, + wantErr: "requested ttl (10h) cannot be greater than max tll (1h)", }, } @@ -117,7 +135,7 @@ func TestCreateToken(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() - ctx := context.Background() + ctx := logging.WithLogger(context.Background(), logging.TestLogger(t)) now := time.Now().UTC() var clientOpt option.ClientOption @@ -187,6 +205,8 @@ func TestCreateToken(t *testing.T) { KeyName: key, SignerCacheTimeout: 5 * time.Minute, Issuer: "test-iss", + DefaultTTL: 1 * time.Minute, + MaxTTL: 1 * time.Hour, }, authHandler) mockKeyManagement.Reqs = nil @@ -196,7 +216,7 @@ func TestCreateToken(t *testing.T) { response, gotErr := processor.CreateToken(ctx, tc.request) if diff := pkgtestutil.DiffErrString(gotErr, tc.wantErr); diff != "" { - t.Errorf("Unexpected err: %s", diff) + t.Error(diff) } if gotErr != nil { return @@ -266,3 +286,62 @@ func TestCreateToken(t *testing.T) { }) } } + +func TestComputeTTL(t *testing.T) { + t.Parallel() + + cases := []struct { + name string + req time.Duration + def time.Duration + max time.Duration + exp time.Duration + err string + }{ + { + name: "request_zero_uses_default", + req: 0, + def: 15 * time.Minute, + max: 30 * time.Minute, + exp: 15 * time.Minute, + }, + { + name: "request_negative_uses_default", + req: -10 * time.Second, + def: 15 * time.Minute, + max: 30 * time.Minute, + exp: 15 * time.Minute, + }, + { + name: "request_uses_self_in_bounds", + req: 12 * time.Minute, + def: 15 * time.Minute, + max: 30 * time.Minute, + exp: 12 * time.Minute, + }, + { + name: "request_greater_than_max_errors", + req: 1 * time.Hour, + def: 15 * time.Minute, + max: 30 * time.Minute, + err: "requested ttl (1h) cannot be greater than max tll (30m)", + }, + } + + for _, tc := range cases { + tc := tc + + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + got, err := computeTTL(tc.req, tc.def, tc.max) + if result := pkgtestutil.DiffErrString(err, tc.err); result != "" { + t.Fatal(result) + } + + if want := tc.exp; got != want { + t.Errorf("expected %q to be %q", got, want) + } + }) + } +} diff --git a/pkg/jvscrypto/rotation_handler.go b/pkg/jvscrypto/rotation_handler.go index 001e3362..b13bc658 100644 --- a/pkg/jvscrypto/rotation_handler.go +++ b/pkg/jvscrypto/rotation_handler.go @@ -65,15 +65,15 @@ func (h *RotationHandler) RotateKey(ctx context.Context, key string) error { // Get any relevant Key Version information from the StateStore primaryName, err := GetPrimary(ctx, h.KMSClient, key) if err != nil { - return fmt.Errorf("unable to determine primary: %w", err) + return fmt.Errorf("failed to determine primary: %w", err) } actions, err := h.determineActions(ctx, vers, primaryName, curTime) if err != nil { - return fmt.Errorf("unable to determine cert actions: %w", err) + return fmt.Errorf("failed to determine cert actions: %w", err) } if err = h.performActions(ctx, key, actions); err != nil { - return fmt.Errorf("unable to perform some cert actions: %w", err) + return fmt.Errorf("failed to perform some cert actions: %w", err) } return nil } diff --git a/test/integ/main_test.go b/test/integ/main_test.go index c4fe279b..7527fe12 100644 --- a/test/integ/main_test.go +++ b/test/integ/main_test.go @@ -88,6 +88,8 @@ func TestJVS(t *testing.T) { KeyName: keyName, Issuer: "ci-test", SignerCacheTimeout: 1 * time.Nanosecond, // no caching + DefaultTTL: 15 * time.Minute, + MaxTTL: 2 * time.Hour, } if err := cfg.Validate(); err != nil { t.Fatal(err) @@ -189,18 +191,6 @@ func TestJVS(t *testing.T) { }, wantErr: "failed to validate request", }, - { - name: "no_ttl", - request: &jvspb.CreateJustificationRequest{ - Justifications: []*jvspb.Justification{ - { - Category: "explanation", - Value: "This is a test.", - }, - }, - }, - wantErr: "failed to validate request", - }, { name: "no_justification", request: &jvspb.CreateJustificationRequest{