From d8dd1faffe14639134d8883c333acba0ee7595eb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?K=C3=A9vin=20Dunglas?= Date: Thu, 1 Apr 2021 18:57:26 +0200 Subject: [PATCH] fix!: NumericDate parsing conformance --- claims.go | 62 +++++++++++++++++++-------------- claims_test.go | 81 ++++++++++++++++++++++++++++++++++++++++++++ http_example_test.go | 7 ++-- map_claims.go | 21 ++++++------ parser_test.go | 2 +- rsa_pss_test.go | 2 +- 6 files changed, 135 insertions(+), 40 deletions(-) create mode 100644 claims_test.go diff --git a/claims.go b/claims.go index 62489066..06c84a45 100644 --- a/claims.go +++ b/claims.go @@ -3,6 +3,7 @@ package jwt import ( "crypto/subtle" "fmt" + "math" "time" ) @@ -17,12 +18,12 @@ type Claims interface { // See examples for how to use this with your own claim types type StandardClaims struct { Audience []string `json:"aud,omitempty"` - ExpiresAt int64 `json:"exp,omitempty"` - Id string `json:"jti,omitempty"` - IssuedAt int64 `json:"iat,omitempty"` - Issuer string `json:"iss,omitempty"` - NotBefore int64 `json:"nbf,omitempty"` - Subject string `json:"sub,omitempty"` + ExpiresAt float64 `json:"exp,omitempty"` + Id string `json:"jti,omitempty"` + IssuedAt float64 `json:"iat,omitempty"` + Issuer string `json:"iss,omitempty"` + NotBefore float64 `json:"nbf,omitempty"` + Subject string `json:"sub,omitempty"` } // Validates time based claims "exp, iat, nbf". @@ -31,12 +32,12 @@ type StandardClaims struct { // be considered a valid claim. func (c StandardClaims) Valid() error { vErr := new(ValidationError) - now := TimeFunc().Unix() + now := TimeFunc() // The claims below are optional, by default, so if they are set to the // default value in Go, let's not fail the verification for them. if c.VerifyExpiresAt(now, false) == false { - delta := time.Unix(now, 0).Sub(time.Unix(c.ExpiresAt, 0)) + delta := now.Sub(parseUnixFloat(c.ExpiresAt)) vErr.Inner = fmt.Errorf("token is expired by %v", delta) vErr.Errors |= ValidationErrorExpired } @@ -66,13 +67,13 @@ func (c *StandardClaims) VerifyAudience(cmp string, req bool) bool { // Compares the exp claim against cmp. // If required is false, this method will return true if the value matches or is unset -func (c *StandardClaims) VerifyExpiresAt(cmp int64, req bool) bool { +func (c *StandardClaims) VerifyExpiresAt(cmp time.Time, req bool) bool { return verifyExp(c.ExpiresAt, cmp, req) } // Compares the iat claim against cmp. // If required is false, this method will return true if the value matches or is unset -func (c *StandardClaims) VerifyIssuedAt(cmp int64, req bool) bool { +func (c *StandardClaims) VerifyIssuedAt(cmp time.Time, req bool) bool { return verifyIat(c.IssuedAt, cmp, req) } @@ -84,7 +85,7 @@ func (c *StandardClaims) VerifyIssuer(cmp string, req bool) bool { // Compares the nbf claim against cmp. // If required is false, this method will return true if the value matches or is unset -func (c *StandardClaims) VerifyNotBefore(cmp int64, req bool) bool { +func (c *StandardClaims) VerifyNotBefore(cmp time.Time, req bool) bool { return verifyNbf(c.NotBefore, cmp, req) } @@ -103,34 +104,45 @@ func verifyAud(aud []string, cmp string, required bool) bool { return false } -func verifyExp(exp int64, now int64, required bool) bool { - if exp == 0 { +func verifyExp(exp float64, now time.Time, required bool) bool { + if exp == 0. { return !required } - return now <= exp + + pexp := parseUnixFloat(exp) + + return pexp.Equal(now) || now.Before(pexp) } -func verifyIat(iat int64, now int64, required bool) bool { - if iat == 0 { +func verifyIat(iat float64, now time.Time, required bool) bool { + if iat == 0. { return !required } - return now >= iat + + piat := parseUnixFloat(iat) + + return piat.Equal(now) || now.After(piat) } func verifyIss(iss string, cmp string, required bool) bool { if iss == "" { return !required } - if subtle.ConstantTimeCompare([]byte(iss), []byte(cmp)) != 0 { - return true - } else { - return false - } + + return subtle.ConstantTimeCompare([]byte(iss), []byte(cmp)) != 0 } -func verifyNbf(nbf int64, now int64, required bool) bool { - if nbf == 0 { +func verifyNbf(nbf float64, now time.Time, required bool) bool { + if nbf == 0. { return !required } - return now >= nbf + + pnbf := parseUnixFloat(nbf) + + return pnbf.Equal(now) || now.After(pnbf) +} + +func parseUnixFloat(ts float64) time.Time { + int, frac := math.Modf(ts) + return time.Unix(int64(int), int64(frac*(1e9))) } diff --git a/claims_test.go b/claims_test.go new file mode 100644 index 00000000..b30aa268 --- /dev/null +++ b/claims_test.go @@ -0,0 +1,81 @@ +package jwt + +import ( + "testing" + "time" +) + +func Test_StandardClaims_VerifyExpiresAt_empty(t *testing.T) { + c := StandardClaims{} + if !c.VerifyExpiresAt(time.Now(), false) { + t.Fatalf("Failed to verify exp claim, wanted: %v got %v", true, false) + } +} + +func Test_StandardClaims_VerifyExpiresAt_expired(t *testing.T) { + c := StandardClaims{ + ExpiresAt: float64(time.Now().Add(-1*time.Hour).Unix()) + 0.123, + } + if c.VerifyExpiresAt(time.Now(), true) { + t.Fatalf("Failed to verify exp claim, wanted: %v got %v", false, true) + } +} + +func Test_StandardClaims_VerifyExpiresAt_not_expired(t *testing.T) { + c := StandardClaims{ + ExpiresAt: float64(time.Now().Add(1*time.Hour).Unix()) + 0.123, + } + if !c.VerifyExpiresAt(time.Now(), true) { + t.Fatalf("Failed to verify exp claim, wanted: %v got %v", true, false) + } +} + +func Test_StandardClaims_VerifyIssuedAt_empty(t *testing.T) { + c := StandardClaims{} + if !c.VerifyIssuedAt(time.Now(), false) { + t.Fatalf("Failed to verify iat claim, wanted: %v got %v", true, false) + } +} + +func Test_StandardClaims_VerifyIssuedAt_expired(t *testing.T) { + c := StandardClaims{ + IssuedAt: float64(time.Now().Add(1*time.Hour).Unix()) + 0.123, + } + if c.VerifyIssuedAt(time.Now(), true) { + t.Fatalf("Failed to verify iat claim, wanted: %v got %v", false, true) + } +} + +func Test_StandardClaims_VerifyIssuedAt_past(t *testing.T) { + c := StandardClaims{ + IssuedAt: float64(time.Now().Add(-1*time.Hour).Unix()) + 0.123, + } + if !c.VerifyIssuedAt(time.Now(), true) { + t.Fatalf("Failed to verify iat claim, wanted: %v got %v", true, false) + } +} + +func Test_StandardClaims_VerifyNotBefore_empty(t *testing.T) { + c := StandardClaims{} + if !c.VerifyNotBefore(time.Now(), false) { + t.Fatalf("Failed to verify nbf claim, wanted: %v got %v", true, false) + } +} + +func Test_StandardClaims_VerifyNotBefore_expired(t *testing.T) { + c := StandardClaims{ + NotBefore: float64(time.Now().Add(1*time.Hour).Unix()) + 0.123, + } + if c.VerifyNotBefore(time.Now(), true) { + t.Fatalf("Failed to verify nbf claim, wanted: %v got %v", false, true) + } +} + +func Test_StandardClaims_VerifyNotBefore_passed(t *testing.T) { + c := StandardClaims{ + NotBefore: float64(time.Now().Add(-1*time.Hour).Unix()) + 0.123, + } + if !c.VerifyNotBefore(time.Now(), true) { + t.Fatalf("Failed to verify nbf claim, wanted: %v got %v", true, false) + } +} diff --git a/http_example_test.go b/http_example_test.go index 6d1f835f..55df78fc 100644 --- a/http_example_test.go +++ b/http_example_test.go @@ -7,8 +7,6 @@ import ( "bytes" "crypto/rsa" "fmt" - "github.com/form3tech-oss/jwt-go" - "github.com/form3tech-oss/jwt-go/request" "io" "io/ioutil" "log" @@ -17,6 +15,9 @@ import ( "net/url" "strings" "time" + + "github.com/form3tech-oss/jwt-go" + "github.com/form3tech-oss/jwt-go/request" ) // location of the files used for signing and verification @@ -150,7 +151,7 @@ func createToken(user string) (string, error) { &jwt.StandardClaims{ // set the expire time // see http://tools.ietf.org/html/draft-ietf-oauth-json-web-token-20#section-4.1.4 - ExpiresAt: time.Now().Add(time.Minute * 1).Unix(), + ExpiresAt: float64(time.Now().Add(time.Minute * 1).Unix()), }, "level1", CustomerInfo{user, "human"}, diff --git a/map_claims.go b/map_claims.go index 90ab6bea..590104a8 100644 --- a/map_claims.go +++ b/map_claims.go @@ -3,6 +3,7 @@ package jwt import ( "encoding/json" "errors" + "time" // "fmt" ) @@ -27,12 +28,12 @@ func (m MapClaims) VerifyAudience(cmp string, req bool) bool { // Compares the exp claim against cmp. // If required is false, this method will return true if the value matches or is unset -func (m MapClaims) VerifyExpiresAt(cmp int64, req bool) bool { +func (m MapClaims) VerifyExpiresAt(cmp time.Time, req bool) bool { switch exp := m["exp"].(type) { case float64: - return verifyExp(int64(exp), cmp, req) + return verifyExp(exp, cmp, req) case json.Number: - v, _ := exp.Int64() + v, _ := exp.Float64() return verifyExp(v, cmp, req) } return req == false @@ -40,12 +41,12 @@ func (m MapClaims) VerifyExpiresAt(cmp int64, req bool) bool { // Compares the iat claim against cmp. // If required is false, this method will return true if the value matches or is unset -func (m MapClaims) VerifyIssuedAt(cmp int64, req bool) bool { +func (m MapClaims) VerifyIssuedAt(cmp time.Time, req bool) bool { switch iat := m["iat"].(type) { case float64: - return verifyIat(int64(iat), cmp, req) + return verifyIat(iat, cmp, req) case json.Number: - v, _ := iat.Int64() + v, _ := iat.Float64() return verifyIat(v, cmp, req) } return req == false @@ -60,12 +61,12 @@ func (m MapClaims) VerifyIssuer(cmp string, req bool) bool { // Compares the nbf claim against cmp. // If required is false, this method will return true if the value matches or is unset -func (m MapClaims) VerifyNotBefore(cmp int64, req bool) bool { +func (m MapClaims) VerifyNotBefore(cmp time.Time, req bool) bool { switch nbf := m["nbf"].(type) { case float64: - return verifyNbf(int64(nbf), cmp, req) + return verifyNbf(nbf, cmp, req) case json.Number: - v, _ := nbf.Int64() + v, _ := nbf.Float64() return verifyNbf(v, cmp, req) } return req == false @@ -77,7 +78,7 @@ func (m MapClaims) VerifyNotBefore(cmp int64, req bool) bool { // be considered a valid claim. func (m MapClaims) Valid() error { vErr := new(ValidationError) - now := TimeFunc().Unix() + now := TimeFunc() if m.VerifyExpiresAt(now, false) == false { vErr.Inner = errors.New("Token is expired") diff --git a/parser_test.go b/parser_test.go index 8a4a408a..f0fb0811 100644 --- a/parser_test.go +++ b/parser_test.go @@ -139,7 +139,7 @@ var jwtTestData = []struct { "", defaultKeyFunc, &jwt.StandardClaims{ - ExpiresAt: time.Now().Add(time.Second * 10).Unix(), + ExpiresAt: float64(time.Now().Add(time.Second * 10).Unix()), }, true, 0, diff --git a/rsa_pss_test.go b/rsa_pss_test.go index f4a0a0a8..8eebb928 100644 --- a/rsa_pss_test.go +++ b/rsa_pss_test.go @@ -133,7 +133,7 @@ func TestRSAPSSSaltLengthCompatibility(t *testing.T) { func makeToken(method jwt.SigningMethod) string { token := jwt.NewWithClaims(method, jwt.StandardClaims{ Issuer: "example", - IssuedAt: time.Now().Unix(), + IssuedAt: float64(time.Now().Unix()), }) privateKey := test.LoadRSAPrivateKeyFromDisk("test/sample_key") signed, err := token.SignedString(privateKey)