Skip to content

Commit

Permalink
Add more defense about unmarshalling bad types
Browse files Browse the repository at this point in the history
  • Loading branch information
sethvargo committed Sep 21, 2022
1 parent c5bced2 commit 5c58380
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 10 deletions.
37 changes: 28 additions & 9 deletions apis/v0/jwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"fmt"

"github.com/lestrrat-go/jwx/v2/jwt"
"github.com/mitchellh/mapstructure"
)

const (
Expand All @@ -31,15 +32,22 @@ const (

// WithTypedJustifications is an option for parsing JWTs that will convert
// decode the [Justification] claims into the correct Go structure. If this is
// not supplied, the claims will be "any" and future type assertions will fail.
// not supplied, the claims will be "any" and future type assertions may fail.
func WithTypedJustifications() jwt.ParseOption {
return jwt.WithTypedClaim(jwtJustificationsKey, []*Justification{})
}

// GetJustifications retrieves a copy of the justifications on the token. If the
// token does not have any justifications, it returns an empty slice of
// justifications. Modifying the slice does not modify the underlying token -
// you must call SetJustifications to update the data on the token.
// justifications.
//
// This function is incredibly defensive against a poorly-parsed jwt. It handles
// situations where the JWT was not properly decoded (i.e. the caller did not
// use [WithTypedJustifications]), and when the token uses a single
// justification instead of a slice.
//
// Modifying the slice does not modify the underlying token - you must call
// [SetJustifications] to update the data on the token.
func GetJustifications(t jwt.Token) ([]*Justification, error) {
if t == nil {
return nil, fmt.Errorf("token cannot be nil")
Expand All @@ -50,15 +58,26 @@ func GetJustifications(t jwt.Token) ([]*Justification, error) {
return []*Justification{}, nil
}

typ, ok := raw.([]*Justification)
if !ok {
return nil, fmt.Errorf("found justifications, but was %T (expected %T)",
raw, []*Justification{})
var claims []*Justification
switch list := raw.(type) {
case []*Justification:
// Token was decoded with typed claims.
claims = list
case *Justification:
// Token did not provide a list.
claims = []*Justification{list}
case []any:
// Token was a proto but wasn't decoded.
if err := mapstructure.Decode(list, &claims); err != nil {
return nil, fmt.Errorf("found justifications, but could not decode map data: %w", err)
}
default:
return nil, fmt.Errorf("found justifications, but was of unknown type %T", raw)
}

// Make a copy of the slice so we don't modify the underlying data structure.
cp := make([]*Justification, 0, len(typ))
cp = append(cp, typ...)
cp := make([]*Justification, 0, len(claims))
cp = append(cp, claims...)
return cp, nil
}

Expand Down
56 changes: 55 additions & 1 deletion apis/v0/jwt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"github.com/abcxyz/pkg/testutil"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/lestrrat-go/jwx/v2/jwa"
"github.com/lestrrat-go/jwx/v2/jwt"
)

Expand Down Expand Up @@ -48,7 +49,60 @@ func TestGetJustifications(t *testing.T) {
token: testTokenBuilder(t, jwt.
NewBuilder().
Claim(jwtJustificationsKey, "not_valid")),
expErr: "found justifications, but was string",
expErr: "unknown type",
},
{
// This test checks that we still properly decode justifications even if
// the caller did not specify decoding the custom type claims. To drop all
// type information, we serialize the token and then parse it without type
// information.
name: "not_decoded_claims",
token: func() jwt.Token {
token, err := jwt.NewBuilder().
Claim(jwtJustificationsKey, []*Justification{
{
Category: "category",
Value: "value",
},
}).
Build()
if err != nil {
t.Fatal(err)
}

b, err := jwt.Sign(token, jwt.WithKey(jwa.HS256, []byte("KEY")))
if err != nil {
t.Fatal(err)
}

parsed, err := jwt.Parse(b, jwt.WithVerify(false))
if err != nil {
t.Fatal(err)
}
return parsed
}(),
exp: []*Justification{
{
Category: "category",
Value: "value",
},
},
},
{
name: "single_justification",
token: testTokenBuilder(t, jwt.
NewBuilder().
Claim(jwtJustificationsKey, &Justification{
Category: "category",
Value: "value",
}),
),
exp: []*Justification{
{
Category: "category",
Value: "value",
},
},
},
{
name: "returns_justifications",
Expand Down

0 comments on commit 5c58380

Please sign in to comment.