Skip to content

Commit

Permalink
Add configuration for setting default and max TTLs
Browse files Browse the repository at this point in the history
This introduces new configuration into the JVS server - DefaultTTL and MaxTTL. The DefaultTTL is the TTL that the server will assign if the client does not request a TTL. The MaxTTL is the maximum value a client can request. If a client requests a larger TTL than the MaxTTL, they will get an error back in the response. This required threading through some data into the response that was not previously provided.

Fixes GH-169
  • Loading branch information
sethvargo committed Jan 25, 2023
1 parent 0d8c9af commit 0c08e32
Show file tree
Hide file tree
Showing 6 changed files with 161 additions and 34 deletions.
15 changes: 15 additions & 0 deletions pkg/config/justification_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"fmt"
"time"

"github.com/abcxyz/pkg/timeutil"
"github.com/hashicorp/go-multierror"
)

Expand All @@ -43,6 +44,14 @@ type JustificationConfig struct {

// Issuer will be used to set the issuer field when signing JWTs
Issuer string `yaml:"issuer" env:"ISSUER,overwrite,default=jvs.abcxyz.dev"`

// DefaultTTL sets the default TTL for JVS tokens that do not explicitly
// request a TTL. MaxTTL is the system-configured maximum TTL that a token can
// request.
//
// The DefaultTTL must be less than or equal to MaxTTL.
DefaultTTL time.Duration `yaml:"default_ttl" env:"DEFAULT_TTL,overwrite,default=15m"`
MaxTTL time.Duration `yaml:"max_ttl" env:"MAX_TTL,overwrite,default=4h"`
}

// Validate checks if the config is valid.
Expand All @@ -58,5 +67,11 @@ func (cfg *JustificationConfig) Validate() error {
err = multierror.Append(err, fmt.Errorf("cache timeout must be a positive duration, got %s",
cfg.SignerCacheTimeout))
}

if def, max := cfg.DefaultTTL, cfg.MaxTTL; def > max {
err = multierror.Append(err, fmt.Errorf("default ttl (%s) must be less than or equal to the max ttl (%s)",
timeutil.HumanDuration(def), timeutil.HumanDuration(max)))
}

return err.ErrorOrNil()
}
27 changes: 25 additions & 2 deletions pkg/config/justification_config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ func TestJustificationConfig_Defaults(t *testing.T) {
Port: "8080",
SignerCacheTimeout: 5 * time.Minute,
Issuer: "jvs.abcxyz.dev",
DefaultTTL: 15 * time.Minute,
MaxTTL: 4 * time.Hour,
}

if diff := cmp.Diff(want, &justificationConfig); diff != "" {
Expand All @@ -66,12 +68,16 @@ port: 123
version: 1
signer_cache_timeout: 1m
issuer: jvs
default_ttl: 5m
max_ttl: 30m
`,
wantConfig: &JustificationConfig{
Port: "123",
Version: "1",
SignerCacheTimeout: 1 * time.Minute,
Issuer: "jvs",
DefaultTTL: 5 * time.Minute,
MaxTTL: 30 * time.Minute,
},
},
{
Expand All @@ -82,43 +88,60 @@ issuer: jvs
Version: "1",
SignerCacheTimeout: 5 * time.Minute,
Issuer: "jvs.abcxyz.dev",
DefaultTTL: 15 * time.Minute,
MaxTTL: 4 * time.Hour,
},
},
{
name: "test_wrong_version",
name: "wrong_version",
cfg: `
version: 255
`,
wantConfig: nil,
wantErr: `version "255" is invalid, valid versions are:`,
},
{
name: "test_invalid_signer_cache_timeout",
name: "invalid_signer_cache_timeout",
cfg: `
signer_cache_timeout: -1m
`,
wantConfig: nil,
wantErr: `cache timeout must be a positive duration, got -1m0s`,
},
{
name: "default_ttl_greater_than_max_ttl",
cfg: `
default_ttl: 1h
max_ttl: 30m
`,
wantConfig: nil,
wantErr: `default ttl (1h) must be less than or equal to the max ttl (30m)`,
},
{
name: "all_values_specified_env_override",
cfg: `
version: 1
port: 8080
signer_cache_timeout: 1m
issuer: jvs
default_ttl: 15m
max_ttl: 1h
`,
envs: map[string]string{
"VERSION": "1",
"PORT": "tcp",
"SIGNER_CACHE_TIMEOUT": "2m",
"ISSUER": "other",
"DEFAULT_TTL": "30m",
"MAX_TTL": "2h",
},
wantConfig: &JustificationConfig{
Version: "1",
Port: "tcp",
SignerCacheTimeout: 2 * time.Minute,
Issuer: "other",
DefaultTTL: 30 * time.Minute,
MaxTTL: 2 * time.Hour,
},
},
}
Expand Down
48 changes: 34 additions & 14 deletions pkg/justification/processor.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"github.com/abcxyz/pkg/cache"
"github.com/abcxyz/pkg/grpcutil"
"github.com/abcxyz/pkg/logging"
"github.com/abcxyz/pkg/timeutil"
"github.com/google/uuid"
"github.com/hashicorp/go-multierror"
"github.com/lestrrat-go/jwx/v2/jwa"
Expand Down Expand Up @@ -79,28 +80,28 @@ func (p *Processor) CreateToken(ctx context.Context, req *jvspb.CreateJustificat

if err := p.runValidations(req); err != nil {
logger.Errorw("failed to validate request", "error", err)
return nil, status.Error(codes.InvalidArgument, "failed to validate request")
return nil, status.Errorf(codes.InvalidArgument, "failed to validate request: %s", err)
}

token, err := p.createToken(ctx, now, req)
if err != nil {
logger.Errorw("failed to create token", "error", err)
return nil, status.Error(codes.Internal, "failed to create token")
return nil, status.Errorf(codes.Internal, "failed to create token: %s", err)
}

signer, err := p.cache.WriteThruLookup(cacheKey, func() (*signerWithID, error) {
return p.getPrimarySigner(ctx)
})
if err != nil {
logger.Errorw("failed to get signer", "error", err)
return nil, status.Error(codes.Internal, "failed to get token signer")
logger.Errorw("failed to get token signer", "error", err)
return nil, status.Errorf(codes.Internal, "failed to get token signer: %s", err)
}

// Build custom headers and set the "kid" as the signer ID.
headers := jws.NewHeaders()
if err := headers.Set(jws.KeyIDKey, signer.id); err != nil {
logger.Errorw("failed to set kid header", "error", err)
return nil, status.Error(codes.Internal, "failed to set token headers")
return nil, status.Errorf(codes.Internal, "failed to set token headers: %s", err)
}

// Sign the token.
Expand All @@ -116,14 +117,14 @@ func (p *Processor) CreateToken(ctx context.Context, req *jvspb.CreateJustificat
func (p *Processor) getPrimarySigner(ctx context.Context) (*signerWithID, error) {
primaryVer, err := jvscrypto.GetPrimary(ctx, p.kms, p.config.KeyName)
if err != nil {
return nil, fmt.Errorf("unable to determine primary, %w", err)
return nil, fmt.Errorf("failed to determine primary signing key: %w", err)
}
if primaryVer == "" {
return nil, fmt.Errorf("no primary version found")
}
sig, err := gcpkms.NewSigner(ctx, p.kms, primaryVer)
if err != nil {
return nil, fmt.Errorf("failed to create signer, %w", err)
return nil, fmt.Errorf("failed to create signer: %w", err)
}
return &signerWithID{
Signer: sig,
Expand All @@ -137,10 +138,6 @@ func (p *Processor) runValidations(request *jvspb.CreateJustificationRequest) er
return fmt.Errorf("no justifications specified")
}

if request.Ttl == nil {
return fmt.Errorf("no ttl specified")
}

var err *multierror.Error
for _, j := range request.Justifications {
switch j.Category {
Expand All @@ -160,12 +157,18 @@ func (p *Processor) runValidations(request *jvspb.CreateJustificationRequest) er
func (p *Processor) createToken(ctx context.Context, now time.Time, req *jvspb.CreateJustificationRequest) (jwt.Token, error) {
email, err := p.authHandler.RequestPrincipal(ctx)
if err != nil {
return nil, fmt.Errorf("unable to get email of requestor: %w", err)
return nil, fmt.Errorf("failed to get email of requestor: %w", err)
}

ttl, err := computeTTL(req.Ttl.AsDuration(), p.config.DefaultTTL, p.config.MaxTTL)
if err != nil {
return nil, fmt.Errorf("failed to compute ttl: %w", err)
}

id := uuid.New().String()
exp := now.Add(req.Ttl.AsDuration())
exp := now.Add(ttl)
justs := req.Justifications
iss := p.config.Issuer

// Use audiences in the request if provided.
aud := req.Audiences
Expand All @@ -177,7 +180,7 @@ func (p *Processor) createToken(ctx context.Context, now time.Time, req *jvspb.C
Audience(aud).
Expiration(exp).
IssuedAt(now).
Issuer(p.config.Issuer).
Issuer(iss).
JwtID(id).
NotBefore(now).
Subject(email).
Expand All @@ -192,3 +195,20 @@ func (p *Processor) createToken(ctx context.Context, now time.Time, req *jvspb.C

return token, nil
}

// computeTTL is a helper that computes the best TTL given the requested TTL,
// default TTL, and maximum configured TTL. If the requested TTL is greater than
// the maximum TTL, it returns an error. If the requested TTL is 0, it returns
// the default TTL.
func computeTTL(req, def, max time.Duration) (time.Duration, error) {
if req <= 0 {
return def, nil
}

if req > max {
return 0, fmt.Errorf("requested ttl (%s) cannot be greater than max tll (%s)",
timeutil.HumanDuration(req), timeutil.HumanDuration(max))
}

return req, nil
}
85 changes: 82 additions & 3 deletions pkg/justification/processor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import (
"github.com/abcxyz/jvs/pkg/jvscrypto"
"github.com/abcxyz/jvs/pkg/testutil"
"github.com/abcxyz/pkg/grpcutil"
"github.com/abcxyz/pkg/logging"
pkgtestutil "github.com/abcxyz/pkg/testutil"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
Expand Down Expand Up @@ -59,6 +60,7 @@ func TestCreateToken(t *testing.T) {
tests := []struct {
name string
request *jvspb.CreateJustificationRequest
wantTTL time.Duration
wantAudiences []string
wantErr string
serverErr error
Expand All @@ -74,6 +76,7 @@ func TestCreateToken(t *testing.T) {
},
Ttl: durationpb.New(3600 * time.Second),
},
wantTTL: 1 * time.Hour,
wantAudiences: []string{DefaultAudience},
},
{
Expand All @@ -88,6 +91,7 @@ func TestCreateToken(t *testing.T) {
Ttl: durationpb.New(3600 * time.Second),
Audiences: []string{"aud1", "aud2"},
},
wantTTL: 1 * time.Hour,
wantAudiences: []string{"aud1", "aud2"},
},
{
Expand All @@ -107,7 +111,21 @@ func TestCreateToken(t *testing.T) {
},
},
},
wantErr: "failed to validate request",
wantTTL: 15 * time.Minute, // comes from default
wantAudiences: []string{"dev.abcxyz.jvs"},
},
{
name: "ttl_exceeds_max",
request: &jvspb.CreateJustificationRequest{
Justifications: []*jvspb.Justification{
{
Category: "explanation",
Value: "test",
},
},
Ttl: durationpb.New(10 * time.Hour),
},
wantErr: "requested ttl (10h) cannot be greater than max tll (1h)",
},
}

Expand All @@ -117,7 +135,7 @@ func TestCreateToken(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()

ctx := context.Background()
ctx := logging.WithLogger(context.Background(), logging.TestLogger(t))
now := time.Now().UTC()

var clientOpt option.ClientOption
Expand Down Expand Up @@ -187,6 +205,8 @@ func TestCreateToken(t *testing.T) {
KeyName: key,
SignerCacheTimeout: 5 * time.Minute,
Issuer: "test-iss",
DefaultTTL: 1 * time.Minute,
MaxTTL: 1 * time.Hour,
}, authHandler)

mockKeyManagement.Reqs = nil
Expand All @@ -196,7 +216,7 @@ func TestCreateToken(t *testing.T) {

response, gotErr := processor.CreateToken(ctx, tc.request)
if diff := pkgtestutil.DiffErrString(gotErr, tc.wantErr); diff != "" {
t.Errorf("Unexpected err: %s", diff)
t.Error(diff)
}
if gotErr != nil {
return
Expand Down Expand Up @@ -266,3 +286,62 @@ func TestCreateToken(t *testing.T) {
})
}
}

func TestComputeTTL(t *testing.T) {
t.Parallel()

cases := []struct {
name string
req time.Duration
def time.Duration
max time.Duration
exp time.Duration
err string
}{
{
name: "request_zero_uses_default",
req: 0,
def: 15 * time.Minute,
max: 30 * time.Minute,
exp: 15 * time.Minute,
},
{
name: "request_negative_uses_default",
req: -10 * time.Second,
def: 15 * time.Minute,
max: 30 * time.Minute,
exp: 15 * time.Minute,
},
{
name: "request_uses_self_in_bounds",
req: 12 * time.Minute,
def: 15 * time.Minute,
max: 30 * time.Minute,
exp: 12 * time.Minute,
},
{
name: "request_greater_than_max_errors",
req: 1 * time.Hour,
def: 15 * time.Minute,
max: 30 * time.Minute,
err: "requested ttl (1h) cannot be greater than max tll (30m)",
},
}

for _, tc := range cases {
tc := tc

t.Run(tc.name, func(t *testing.T) {
t.Parallel()

got, err := computeTTL(tc.req, tc.def, tc.max)
if result := pkgtestutil.DiffErrString(err, tc.err); result != "" {
t.Fatal(result)
}

if want := tc.exp; got != want {
t.Errorf("expected %q to be %q", got, want)
}
})
}
}
Loading

0 comments on commit 0c08e32

Please sign in to comment.