Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add configuration for setting default and max TTLs #176

Merged
merged 1 commit into from
Jan 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions pkg/config/justification_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"fmt"
"time"

"github.com/abcxyz/pkg/timeutil"
"github.com/hashicorp/go-multierror"
)

Expand All @@ -43,6 +44,14 @@ type JustificationConfig struct {

// Issuer will be used to set the issuer field when signing JWTs
Issuer string `yaml:"issuer" env:"ISSUER,overwrite,default=jvs.abcxyz.dev"`

// DefaultTTL sets the default TTL for JVS tokens that do not explicitly
// request a TTL. MaxTTL is the system-configured maximum TTL that a token can
// request.
//
// The DefaultTTL must be less than or equal to MaxTTL.
DefaultTTL time.Duration `yaml:"default_ttl" env:"DEFAULT_TTL,overwrite,default=15m"`
MaxTTL time.Duration `yaml:"max_ttl" env:"MAX_TTL,overwrite,default=4h"`
}

// Validate checks if the config is valid.
Expand All @@ -58,5 +67,11 @@ func (cfg *JustificationConfig) Validate() error {
err = multierror.Append(err, fmt.Errorf("cache timeout must be a positive duration, got %s",
cfg.SignerCacheTimeout))
}

if def, max := cfg.DefaultTTL, cfg.MaxTTL; def > max {
err = multierror.Append(err, fmt.Errorf("default ttl (%s) must be less than or equal to the max ttl (%s)",
timeutil.HumanDuration(def), timeutil.HumanDuration(max)))
}

return err.ErrorOrNil()
}
27 changes: 25 additions & 2 deletions pkg/config/justification_config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ func TestJustificationConfig_Defaults(t *testing.T) {
Port: "8080",
SignerCacheTimeout: 5 * time.Minute,
Issuer: "jvs.abcxyz.dev",
DefaultTTL: 15 * time.Minute,
MaxTTL: 4 * time.Hour,
}

if diff := cmp.Diff(want, &justificationConfig); diff != "" {
Expand All @@ -66,12 +68,16 @@ port: 123
version: 1
signer_cache_timeout: 1m
issuer: jvs
default_ttl: 5m
max_ttl: 30m
`,
wantConfig: &JustificationConfig{
Port: "123",
Version: "1",
SignerCacheTimeout: 1 * time.Minute,
Issuer: "jvs",
DefaultTTL: 5 * time.Minute,
MaxTTL: 30 * time.Minute,
},
},
{
Expand All @@ -82,43 +88,60 @@ issuer: jvs
Version: "1",
SignerCacheTimeout: 5 * time.Minute,
Issuer: "jvs.abcxyz.dev",
DefaultTTL: 15 * time.Minute,
MaxTTL: 4 * time.Hour,
},
},
{
name: "test_wrong_version",
name: "wrong_version",
cfg: `
version: 255
`,
wantConfig: nil,
wantErr: `version "255" is invalid, valid versions are:`,
},
{
name: "test_invalid_signer_cache_timeout",
name: "invalid_signer_cache_timeout",
cfg: `
signer_cache_timeout: -1m
`,
wantConfig: nil,
wantErr: `cache timeout must be a positive duration, got -1m0s`,
},
{
name: "default_ttl_greater_than_max_ttl",
cfg: `
default_ttl: 1h
max_ttl: 30m
`,
wantConfig: nil,
wantErr: `default ttl (1h) must be less than or equal to the max ttl (30m)`,
},
{
name: "all_values_specified_env_override",
cfg: `
version: 1
port: 8080
signer_cache_timeout: 1m
issuer: jvs
default_ttl: 15m
max_ttl: 1h
`,
envs: map[string]string{
"VERSION": "1",
"PORT": "tcp",
"SIGNER_CACHE_TIMEOUT": "2m",
"ISSUER": "other",
"DEFAULT_TTL": "30m",
"MAX_TTL": "2h",
},
wantConfig: &JustificationConfig{
Version: "1",
Port: "tcp",
SignerCacheTimeout: 2 * time.Minute,
Issuer: "other",
DefaultTTL: 30 * time.Minute,
MaxTTL: 2 * time.Hour,
},
},
}
Expand Down
48 changes: 34 additions & 14 deletions pkg/justification/processor.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"github.com/abcxyz/pkg/cache"
"github.com/abcxyz/pkg/grpcutil"
"github.com/abcxyz/pkg/logging"
"github.com/abcxyz/pkg/timeutil"
"github.com/google/uuid"
"github.com/hashicorp/go-multierror"
"github.com/lestrrat-go/jwx/v2/jwa"
Expand Down Expand Up @@ -79,28 +80,28 @@ func (p *Processor) CreateToken(ctx context.Context, req *jvspb.CreateJustificat

if err := p.runValidations(req); err != nil {
logger.Errorw("failed to validate request", "error", err)
return nil, status.Error(codes.InvalidArgument, "failed to validate request")
return nil, status.Errorf(codes.InvalidArgument, "failed to validate request: %s", err)
}

token, err := p.createToken(ctx, now, req)
if err != nil {
logger.Errorw("failed to create token", "error", err)
return nil, status.Error(codes.Internal, "failed to create token")
return nil, status.Errorf(codes.Internal, "failed to create token: %s", err)
}

signer, err := p.cache.WriteThruLookup(cacheKey, func() (*signerWithID, error) {
return p.getPrimarySigner(ctx)
})
if err != nil {
logger.Errorw("failed to get signer", "error", err)
return nil, status.Error(codes.Internal, "failed to get token signer")
logger.Errorw("failed to get token signer", "error", err)
return nil, status.Errorf(codes.Internal, "failed to get token signer: %s", err)
}

// Build custom headers and set the "kid" as the signer ID.
headers := jws.NewHeaders()
if err := headers.Set(jws.KeyIDKey, signer.id); err != nil {
logger.Errorw("failed to set kid header", "error", err)
return nil, status.Error(codes.Internal, "failed to set token headers")
return nil, status.Errorf(codes.Internal, "failed to set token headers: %s", err)
}

// Sign the token.
Expand All @@ -116,14 +117,14 @@ func (p *Processor) CreateToken(ctx context.Context, req *jvspb.CreateJustificat
func (p *Processor) getPrimarySigner(ctx context.Context) (*signerWithID, error) {
primaryVer, err := jvscrypto.GetPrimary(ctx, p.kms, p.config.KeyName)
if err != nil {
return nil, fmt.Errorf("unable to determine primary, %w", err)
return nil, fmt.Errorf("failed to determine primary signing key: %w", err)
}
if primaryVer == "" {
return nil, fmt.Errorf("no primary version found")
}
sig, err := gcpkms.NewSigner(ctx, p.kms, primaryVer)
if err != nil {
return nil, fmt.Errorf("failed to create signer, %w", err)
return nil, fmt.Errorf("failed to create signer: %w", err)
}
return &signerWithID{
Signer: sig,
Expand All @@ -137,10 +138,6 @@ func (p *Processor) runValidations(request *jvspb.CreateJustificationRequest) er
return fmt.Errorf("no justifications specified")
}

if request.Ttl == nil {
return fmt.Errorf("no ttl specified")
}

var err *multierror.Error
for _, j := range request.Justifications {
switch j.Category {
Expand All @@ -160,12 +157,18 @@ func (p *Processor) runValidations(request *jvspb.CreateJustificationRequest) er
func (p *Processor) createToken(ctx context.Context, now time.Time, req *jvspb.CreateJustificationRequest) (jwt.Token, error) {
email, err := p.authHandler.RequestPrincipal(ctx)
if err != nil {
return nil, fmt.Errorf("unable to get email of requestor: %w", err)
return nil, fmt.Errorf("failed to get email of requestor: %w", err)
}

ttl, err := computeTTL(req.Ttl.AsDuration(), p.config.DefaultTTL, p.config.MaxTTL)
if err != nil {
return nil, fmt.Errorf("failed to compute ttl: %w", err)
}

id := uuid.New().String()
exp := now.Add(req.Ttl.AsDuration())
exp := now.Add(ttl)
justs := req.Justifications
iss := p.config.Issuer

// Use audiences in the request if provided.
aud := req.Audiences
Expand All @@ -177,7 +180,7 @@ func (p *Processor) createToken(ctx context.Context, now time.Time, req *jvspb.C
Audience(aud).
Expiration(exp).
IssuedAt(now).
Issuer(p.config.Issuer).
Issuer(iss).
JwtID(id).
NotBefore(now).
Subject(email).
Expand All @@ -192,3 +195,20 @@ func (p *Processor) createToken(ctx context.Context, now time.Time, req *jvspb.C

return token, nil
}

// computeTTL is a helper that computes the best TTL given the requested TTL,
// default TTL, and maximum configured TTL. If the requested TTL is greater than
// the maximum TTL, it returns an error. If the requested TTL is 0, it returns
// the default TTL.
func computeTTL(req, def, max time.Duration) (time.Duration, error) {
if req <= 0 {
return def, nil
}

if req > max {
return 0, fmt.Errorf("requested ttl (%s) cannot be greater than max tll (%s)",
timeutil.HumanDuration(req), timeutil.HumanDuration(max))
}

return req, nil
}
85 changes: 82 additions & 3 deletions pkg/justification/processor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import (
"github.com/abcxyz/jvs/pkg/jvscrypto"
"github.com/abcxyz/jvs/pkg/testutil"
"github.com/abcxyz/pkg/grpcutil"
"github.com/abcxyz/pkg/logging"
pkgtestutil "github.com/abcxyz/pkg/testutil"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
Expand Down Expand Up @@ -59,6 +60,7 @@ func TestCreateToken(t *testing.T) {
tests := []struct {
name string
request *jvspb.CreateJustificationRequest
wantTTL time.Duration
wantAudiences []string
wantErr string
serverErr error
Expand All @@ -74,6 +76,7 @@ func TestCreateToken(t *testing.T) {
},
Ttl: durationpb.New(3600 * time.Second),
},
wantTTL: 1 * time.Hour,
wantAudiences: []string{DefaultAudience},
},
{
Expand All @@ -88,6 +91,7 @@ func TestCreateToken(t *testing.T) {
Ttl: durationpb.New(3600 * time.Second),
Audiences: []string{"aud1", "aud2"},
},
wantTTL: 1 * time.Hour,
wantAudiences: []string{"aud1", "aud2"},
},
{
Expand All @@ -107,7 +111,21 @@ func TestCreateToken(t *testing.T) {
},
},
},
wantErr: "failed to validate request",
wantTTL: 15 * time.Minute, // comes from default
wantAudiences: []string{"dev.abcxyz.jvs"},
},
{
name: "ttl_exceeds_max",
request: &jvspb.CreateJustificationRequest{
Justifications: []*jvspb.Justification{
{
Category: "explanation",
Value: "test",
},
},
Ttl: durationpb.New(10 * time.Hour),
},
wantErr: "requested ttl (10h) cannot be greater than max tll (1h)",
},
}

Expand All @@ -117,7 +135,7 @@ func TestCreateToken(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()

ctx := context.Background()
ctx := logging.WithLogger(context.Background(), logging.TestLogger(t))
now := time.Now().UTC()

var clientOpt option.ClientOption
Expand Down Expand Up @@ -187,6 +205,8 @@ func TestCreateToken(t *testing.T) {
KeyName: key,
SignerCacheTimeout: 5 * time.Minute,
Issuer: "test-iss",
DefaultTTL: 1 * time.Minute,
MaxTTL: 1 * time.Hour,
}, authHandler)

mockKeyManagement.Reqs = nil
Expand All @@ -196,7 +216,7 @@ func TestCreateToken(t *testing.T) {

response, gotErr := processor.CreateToken(ctx, tc.request)
if diff := pkgtestutil.DiffErrString(gotErr, tc.wantErr); diff != "" {
t.Errorf("Unexpected err: %s", diff)
t.Error(diff)
}
if gotErr != nil {
return
Expand Down Expand Up @@ -266,3 +286,62 @@ func TestCreateToken(t *testing.T) {
})
}
}

func TestComputeTTL(t *testing.T) {
t.Parallel()

cases := []struct {
name string
req time.Duration
def time.Duration
max time.Duration
exp time.Duration
err string
}{
{
name: "request_zero_uses_default",
req: 0,
def: 15 * time.Minute,
max: 30 * time.Minute,
exp: 15 * time.Minute,
},
{
name: "request_negative_uses_default",
req: -10 * time.Second,
def: 15 * time.Minute,
max: 30 * time.Minute,
exp: 15 * time.Minute,
},
{
name: "request_uses_self_in_bounds",
req: 12 * time.Minute,
def: 15 * time.Minute,
max: 30 * time.Minute,
exp: 12 * time.Minute,
},
{
name: "request_greater_than_max_errors",
req: 1 * time.Hour,
def: 15 * time.Minute,
max: 30 * time.Minute,
err: "requested ttl (1h) cannot be greater than max tll (30m)",
},
}

for _, tc := range cases {
tc := tc

t.Run(tc.name, func(t *testing.T) {
t.Parallel()

got, err := computeTTL(tc.req, tc.def, tc.max)
if result := pkgtestutil.DiffErrString(err, tc.err); result != "" {
t.Fatal(result)
}

if want := tc.exp; got != want {
t.Errorf("expected %q to be %q", got, want)
}
})
}
}
Loading