Skip to content

Commit

Permalink
Add configuration for setting default and max TTLs (#176)
Browse files Browse the repository at this point in the history
This introduces new configuration into the JVS server - DefaultTTL and MaxTTL. The DefaultTTL is the TTL that the server will assign if the client does not request a TTL. The MaxTTL is the maximum value a client can request. If a client requests a larger TTL than the MaxTTL, they will get an error back in the response. This required threading through some data into the response that was not previously provided.

Fixes GH-169
  • Loading branch information
sethvargo authored Jan 25, 2023
1 parent fb49004 commit c958a66
Show file tree
Hide file tree
Showing 6 changed files with 161 additions and 34 deletions.
15 changes: 15 additions & 0 deletions pkg/config/justification_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"fmt"
"time"

"github.com/abcxyz/pkg/timeutil"
"github.com/hashicorp/go-multierror"
)

Expand All @@ -43,6 +44,14 @@ type JustificationConfig struct {

// Issuer will be used to set the issuer field when signing JWTs
Issuer string `yaml:"issuer" env:"ISSUER,overwrite,default=jvs.abcxyz.dev"`

// DefaultTTL sets the default TTL for JVS tokens that do not explicitly
// request a TTL. MaxTTL is the system-configured maximum TTL that a token can
// request.
//
// The DefaultTTL must be less than or equal to MaxTTL.
DefaultTTL time.Duration `yaml:"default_ttl" env:"DEFAULT_TTL,overwrite,default=15m"`
MaxTTL time.Duration `yaml:"max_ttl" env:"MAX_TTL,overwrite,default=4h"`
}

// Validate checks if the config is valid.
Expand All @@ -58,5 +67,11 @@ func (cfg *JustificationConfig) Validate() error {
err = multierror.Append(err, fmt.Errorf("cache timeout must be a positive duration, got %s",
cfg.SignerCacheTimeout))
}

if def, max := cfg.DefaultTTL, cfg.MaxTTL; def > max {
err = multierror.Append(err, fmt.Errorf("default ttl (%s) must be less than or equal to the max ttl (%s)",
timeutil.HumanDuration(def), timeutil.HumanDuration(max)))
}

return err.ErrorOrNil()
}
27 changes: 25 additions & 2 deletions pkg/config/justification_config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ func TestJustificationConfig_Defaults(t *testing.T) {
Port: "8080",
SignerCacheTimeout: 5 * time.Minute,
Issuer: "jvs.abcxyz.dev",
DefaultTTL: 15 * time.Minute,
MaxTTL: 4 * time.Hour,
}

if diff := cmp.Diff(want, &justificationConfig); diff != "" {
Expand All @@ -66,12 +68,16 @@ port: 123
version: 1
signer_cache_timeout: 1m
issuer: jvs
default_ttl: 5m
max_ttl: 30m
`,
wantConfig: &JustificationConfig{
Port: "123",
Version: "1",
SignerCacheTimeout: 1 * time.Minute,
Issuer: "jvs",
DefaultTTL: 5 * time.Minute,
MaxTTL: 30 * time.Minute,
},
},
{
Expand All @@ -82,43 +88,60 @@ issuer: jvs
Version: "1",
SignerCacheTimeout: 5 * time.Minute,
Issuer: "jvs.abcxyz.dev",
DefaultTTL: 15 * time.Minute,
MaxTTL: 4 * time.Hour,
},
},
{
name: "test_wrong_version",
name: "wrong_version",
cfg: `
version: 255
`,
wantConfig: nil,
wantErr: `version "255" is invalid, valid versions are:`,
},
{
name: "test_invalid_signer_cache_timeout",
name: "invalid_signer_cache_timeout",
cfg: `
signer_cache_timeout: -1m
`,
wantConfig: nil,
wantErr: `cache timeout must be a positive duration, got -1m0s`,
},
{
name: "default_ttl_greater_than_max_ttl",
cfg: `
default_ttl: 1h
max_ttl: 30m
`,
wantConfig: nil,
wantErr: `default ttl (1h) must be less than or equal to the max ttl (30m)`,
},
{
name: "all_values_specified_env_override",
cfg: `
version: 1
port: 8080
signer_cache_timeout: 1m
issuer: jvs
default_ttl: 15m
max_ttl: 1h
`,
envs: map[string]string{
"VERSION": "1",
"PORT": "tcp",
"SIGNER_CACHE_TIMEOUT": "2m",
"ISSUER": "other",
"DEFAULT_TTL": "30m",
"MAX_TTL": "2h",
},
wantConfig: &JustificationConfig{
Version: "1",
Port: "tcp",
SignerCacheTimeout: 2 * time.Minute,
Issuer: "other",
DefaultTTL: 30 * time.Minute,
MaxTTL: 2 * time.Hour,
},
},
}
Expand Down
48 changes: 34 additions & 14 deletions pkg/justification/processor.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"github.com/abcxyz/pkg/cache"
"github.com/abcxyz/pkg/grpcutil"
"github.com/abcxyz/pkg/logging"
"github.com/abcxyz/pkg/timeutil"
"github.com/google/uuid"
"github.com/hashicorp/go-multierror"
"github.com/lestrrat-go/jwx/v2/jwa"
Expand Down Expand Up @@ -79,28 +80,28 @@ func (p *Processor) CreateToken(ctx context.Context, req *jvspb.CreateJustificat

if err := p.runValidations(req); err != nil {
logger.Errorw("failed to validate request", "error", err)
return nil, status.Error(codes.InvalidArgument, "failed to validate request")
return nil, status.Errorf(codes.InvalidArgument, "failed to validate request: %s", err)
}

token, err := p.createToken(ctx, now, req)
if err != nil {
logger.Errorw("failed to create token", "error", err)
return nil, status.Error(codes.Internal, "failed to create token")
return nil, status.Errorf(codes.Internal, "failed to create token: %s", err)
}

signer, err := p.cache.WriteThruLookup(cacheKey, func() (*signerWithID, error) {
return p.getPrimarySigner(ctx)
})
if err != nil {
logger.Errorw("failed to get signer", "error", err)
return nil, status.Error(codes.Internal, "failed to get token signer")
logger.Errorw("failed to get token signer", "error", err)
return nil, status.Errorf(codes.Internal, "failed to get token signer: %s", err)
}

// Build custom headers and set the "kid" as the signer ID.
headers := jws.NewHeaders()
if err := headers.Set(jws.KeyIDKey, signer.id); err != nil {
logger.Errorw("failed to set kid header", "error", err)
return nil, status.Error(codes.Internal, "failed to set token headers")
return nil, status.Errorf(codes.Internal, "failed to set token headers: %s", err)
}

// Sign the token.
Expand All @@ -116,14 +117,14 @@ func (p *Processor) CreateToken(ctx context.Context, req *jvspb.CreateJustificat
func (p *Processor) getPrimarySigner(ctx context.Context) (*signerWithID, error) {
primaryVer, err := jvscrypto.GetPrimary(ctx, p.kms, p.config.KeyName)
if err != nil {
return nil, fmt.Errorf("unable to determine primary, %w", err)
return nil, fmt.Errorf("failed to determine primary signing key: %w", err)
}
if primaryVer == "" {
return nil, fmt.Errorf("no primary version found")
}
sig, err := gcpkms.NewSigner(ctx, p.kms, primaryVer)
if err != nil {
return nil, fmt.Errorf("failed to create signer, %w", err)
return nil, fmt.Errorf("failed to create signer: %w", err)
}
return &signerWithID{
Signer: sig,
Expand All @@ -137,10 +138,6 @@ func (p *Processor) runValidations(request *jvspb.CreateJustificationRequest) er
return fmt.Errorf("no justifications specified")
}

if request.Ttl == nil {
return fmt.Errorf("no ttl specified")
}

var err *multierror.Error
for _, j := range request.Justifications {
switch j.Category {
Expand All @@ -160,12 +157,18 @@ func (p *Processor) runValidations(request *jvspb.CreateJustificationRequest) er
func (p *Processor) createToken(ctx context.Context, now time.Time, req *jvspb.CreateJustificationRequest) (jwt.Token, error) {
email, err := p.authHandler.RequestPrincipal(ctx)
if err != nil {
return nil, fmt.Errorf("unable to get email of requestor: %w", err)
return nil, fmt.Errorf("failed to get email of requestor: %w", err)
}

ttl, err := computeTTL(req.Ttl.AsDuration(), p.config.DefaultTTL, p.config.MaxTTL)
if err != nil {
return nil, fmt.Errorf("failed to compute ttl: %w", err)
}

id := uuid.New().String()
exp := now.Add(req.Ttl.AsDuration())
exp := now.Add(ttl)
justs := req.Justifications
iss := p.config.Issuer

// Use audiences in the request if provided.
aud := req.Audiences
Expand All @@ -177,7 +180,7 @@ func (p *Processor) createToken(ctx context.Context, now time.Time, req *jvspb.C
Audience(aud).
Expiration(exp).
IssuedAt(now).
Issuer(p.config.Issuer).
Issuer(iss).
JwtID(id).
NotBefore(now).
Subject(email).
Expand All @@ -192,3 +195,20 @@ func (p *Processor) createToken(ctx context.Context, now time.Time, req *jvspb.C

return token, nil
}

// computeTTL is a helper that computes the best TTL given the requested TTL,
// default TTL, and maximum configured TTL. If the requested TTL is greater than
// the maximum TTL, it returns an error. If the requested TTL is 0, it returns
// the default TTL.
func computeTTL(req, def, max time.Duration) (time.Duration, error) {
if req <= 0 {
return def, nil
}

if req > max {
return 0, fmt.Errorf("requested ttl (%s) cannot be greater than max tll (%s)",
timeutil.HumanDuration(req), timeutil.HumanDuration(max))
}

return req, nil
}
85 changes: 82 additions & 3 deletions pkg/justification/processor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import (
"github.com/abcxyz/jvs/pkg/jvscrypto"
"github.com/abcxyz/jvs/pkg/testutil"
"github.com/abcxyz/pkg/grpcutil"
"github.com/abcxyz/pkg/logging"
pkgtestutil "github.com/abcxyz/pkg/testutil"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
Expand Down Expand Up @@ -59,6 +60,7 @@ func TestCreateToken(t *testing.T) {
tests := []struct {
name string
request *jvspb.CreateJustificationRequest
wantTTL time.Duration
wantAudiences []string
wantErr string
serverErr error
Expand All @@ -74,6 +76,7 @@ func TestCreateToken(t *testing.T) {
},
Ttl: durationpb.New(3600 * time.Second),
},
wantTTL: 1 * time.Hour,
wantAudiences: []string{DefaultAudience},
},
{
Expand All @@ -88,6 +91,7 @@ func TestCreateToken(t *testing.T) {
Ttl: durationpb.New(3600 * time.Second),
Audiences: []string{"aud1", "aud2"},
},
wantTTL: 1 * time.Hour,
wantAudiences: []string{"aud1", "aud2"},
},
{
Expand All @@ -107,7 +111,21 @@ func TestCreateToken(t *testing.T) {
},
},
},
wantErr: "failed to validate request",
wantTTL: 15 * time.Minute, // comes from default
wantAudiences: []string{"dev.abcxyz.jvs"},
},
{
name: "ttl_exceeds_max",
request: &jvspb.CreateJustificationRequest{
Justifications: []*jvspb.Justification{
{
Category: "explanation",
Value: "test",
},
},
Ttl: durationpb.New(10 * time.Hour),
},
wantErr: "requested ttl (10h) cannot be greater than max tll (1h)",
},
}

Expand All @@ -117,7 +135,7 @@ func TestCreateToken(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()

ctx := context.Background()
ctx := logging.WithLogger(context.Background(), logging.TestLogger(t))
now := time.Now().UTC()

var clientOpt option.ClientOption
Expand Down Expand Up @@ -187,6 +205,8 @@ func TestCreateToken(t *testing.T) {
KeyName: key,
SignerCacheTimeout: 5 * time.Minute,
Issuer: "test-iss",
DefaultTTL: 1 * time.Minute,
MaxTTL: 1 * time.Hour,
}, authHandler)

mockKeyManagement.Reqs = nil
Expand All @@ -196,7 +216,7 @@ func TestCreateToken(t *testing.T) {

response, gotErr := processor.CreateToken(ctx, tc.request)
if diff := pkgtestutil.DiffErrString(gotErr, tc.wantErr); diff != "" {
t.Errorf("Unexpected err: %s", diff)
t.Error(diff)
}
if gotErr != nil {
return
Expand Down Expand Up @@ -266,3 +286,62 @@ func TestCreateToken(t *testing.T) {
})
}
}

func TestComputeTTL(t *testing.T) {
t.Parallel()

cases := []struct {
name string
req time.Duration
def time.Duration
max time.Duration
exp time.Duration
err string
}{
{
name: "request_zero_uses_default",
req: 0,
def: 15 * time.Minute,
max: 30 * time.Minute,
exp: 15 * time.Minute,
},
{
name: "request_negative_uses_default",
req: -10 * time.Second,
def: 15 * time.Minute,
max: 30 * time.Minute,
exp: 15 * time.Minute,
},
{
name: "request_uses_self_in_bounds",
req: 12 * time.Minute,
def: 15 * time.Minute,
max: 30 * time.Minute,
exp: 12 * time.Minute,
},
{
name: "request_greater_than_max_errors",
req: 1 * time.Hour,
def: 15 * time.Minute,
max: 30 * time.Minute,
err: "requested ttl (1h) cannot be greater than max tll (30m)",
},
}

for _, tc := range cases {
tc := tc

t.Run(tc.name, func(t *testing.T) {
t.Parallel()

got, err := computeTTL(tc.req, tc.def, tc.max)
if result := pkgtestutil.DiffErrString(err, tc.err); result != "" {
t.Fatal(result)
}

if want := tc.exp; got != want {
t.Errorf("expected %q to be %q", got, want)
}
})
}
}
Loading

0 comments on commit c958a66

Please sign in to comment.