diff --git a/crypto/jwt.go b/crypto/jwt.go index 777adde5..f81f4916 100644 --- a/crypto/jwt.go +++ b/crypto/jwt.go @@ -8,6 +8,7 @@ import ( "github.com/goccy/go-json" "github.com/lestrrat-go/jwx/jwa" "github.com/lestrrat-go/jwx/jwk" + "github.com/lestrrat-go/jwx/jws" "github.com/lestrrat-go/jwx/jwt" "github.com/pkg/errors" "github.com/sirupsen/logrus" @@ -95,13 +96,6 @@ func (sv *JWTSigner) SignJWT(kvs map[string]interface{}) ([]byte, error) { if err := t.Set(jwt.IssuerKey, kid); err != nil { return nil, fmt.Errorf("could not set iss with provided value: %s", kid) } - if err := t.Set(jwk.KeyIDKey, kid); err != nil { - return nil, fmt.Errorf("could not set kid with provided value: %s", kid) - } - alg := sv.Key.Algorithm() - if err := t.Set(jwk.AlgorithmKey, alg); err != nil { - return nil, fmt.Errorf("could not set alg with value: %s", alg) - } iat := time.Now().Unix() if err := t.Set(jwt.IssuedAtKey, iat); err != nil { return nil, fmt.Errorf("could not set iat with value: %d", iat) @@ -146,6 +140,20 @@ func (*JWTVerifier) ParseJWT(token string) (jwt.Token, error) { return parsed, nil } +// ParseJWS attempts to pull of a single signature from a token, containing its headers +func (*JWTVerifier) ParseJWS(token string) (*jws.Signature, error) { + parsed, err := jws.Parse([]byte(token)) + if err != nil { + logrus.WithError(err).Error("could not parse JWS") + return nil, err + } + signatures := parsed.Signatures() + if len(signatures) != 1 { + return nil, fmt.Errorf("expected 1 signature, got %d", len(signatures)) + } + return signatures[0], nil +} + // VerifyAndParseJWT attempts to turn a string into a jwt.Token and verify its signature using the verifier func (sv *JWTVerifier) VerifyAndParseJWT(token string) (jwt.Token, error) { parsed, err := jwt.Parse([]byte(token), jwt.WithVerify(jwa.SignatureAlgorithm(sv.Algorithm()), sv.Key)) diff --git a/crypto/jwt_test.go b/crypto/jwt_test.go index 4de8e1d2..66c9c1b2 100644 --- a/crypto/jwt_test.go +++ b/crypto/jwt_test.go @@ -3,6 +3,7 @@ package crypto import ( "testing" + "github.com/lestrrat-go/jwx/jwt" "github.com/stretchr/testify/assert" ) @@ -79,7 +80,7 @@ func TestSignVerifyGenericJWT(t *testing.T) { assert.True(t, ok) assert.EqualValues(t, "abcd", gotID) - gotJTI, ok := parsed.Get("jti") + gotJTI, ok := parsed.Get(jwt.JwtIDKey) assert.True(t, ok) assert.EqualValues(t, "1234", gotJTI) @@ -89,6 +90,13 @@ func TestSignVerifyGenericJWT(t *testing.T) { _, err = verifier.VerifyAndParseJWT(string(token)) assert.NoError(t, err) + + // parse out the headers + jws, err := verifier.ParseJWS(string(token)) + assert.NoError(t, err) + assert.NotEmpty(t, jws) + assert.EqualValues(t, "EdDSA", jws.ProtectedHeaders().Algorithm()) + assert.EqualValues(t, "did:example:123#key-0", jws.ProtectedHeaders().KeyID()) } func getTestVectorKey0Signer(t *testing.T) JWTSigner {