From 24fc4cf16da589d315b7d6b0557385577c3d150a Mon Sep 17 00:00:00 2001 From: Maciej Galkowski Date: Thu, 26 Nov 2015 15:23:14 +0000 Subject: [PATCH 1/4] Add gzip compression support for the JWT Claims part --- compression_method.go | 103 ++++++++++++++++++++++++++++++++++++++++++ parser.go | 16 +++++-- parser_test.go | 38 +++++++++++++--- token.go | 58 +++++++++++++----------- token_test.go | 46 +++++++++++++++++++ 5 files changed, 226 insertions(+), 35 deletions(-) create mode 100644 compression_method.go create mode 100644 token_test.go diff --git a/compression_method.go b/compression_method.go new file mode 100644 index 00000000..dff9f524 --- /dev/null +++ b/compression_method.go @@ -0,0 +1,103 @@ +package jwt + +import ( + "bytes" + "compress/gzip" + "fmt" + "io/ioutil" +) + +var ( + // CompressionNone does not perform any data compression/decompression + CompressionNone CompressionMethod + // CompressionGzip compresses the claims part of the JWT with gzip algorithm + CompressionGzip CompressionMethod +) + +var ( + compressionMethods map[string]CompressionMethod +) + +// CompressionMethod is an interface used to compress/decompress the Claims part of the JWT +type CompressionMethod interface { + // Alg returns the name of the compression algorithm. It is saved in the token header + Alg() string + // Compress takes uncompressed data, and returns the compression result` + Compress(data []byte) ([]byte, error) + // Decompress takes compressed data and returns uncompressed version + Decompress(data []byte) ([]byte, error) +} + +type compressionGzip struct{} + +type compressionNone struct{} + +func (c *compressionGzip) Alg() string { + return "gzip" +} + +func (c *compressionGzip) Compress(data []byte) ([]byte, error) { + var buffer = &bytes.Buffer{} + var writer = gzip.NewWriter(buffer) + + if _, err := writer.Write(data); err != nil { + writer.Close() + return nil, err + } + + writer.Close() + return buffer.Bytes(), nil +} + +func (c *compressionGzip) Decompress(data []byte) ([]byte, error) { + var buffer = bytes.NewBuffer(data) + var reader, err = gzip.NewReader(buffer) + defer reader.Close() + if err != nil { + return nil, err + } + + return ioutil.ReadAll(reader) +} + +func (c *compressionNone) Alg() string { + return "none" +} + +func (c *compressionNone) Compress(data []byte) ([]byte, error) { + return data, nil +} + +func (c *compressionNone) Decompress(data []byte) ([]byte, error) { + return data, nil +} + +// RegisterCompressionMethod adds support for additional compression method in the runtime. +// The name value is saved in the token header and later used to retrieve the method interface +// used to decompress the header +func RegisterCompressionMethod(name string, method CompressionMethod) { + compressionMethods[name] = method +} + +func getCompressionMethod(alg interface{}) (CompressionMethod, error) { + var algString, ok = alg.(string) + if ok == false || len(algString) == 0 { + return compressionMethods["none"], nil + } + + var method = compressionMethods[algString] + if method == nil { + return nil, fmt.Errorf("Compression method %s not registered", alg) + } + + return method, nil +} + +func init() { + CompressionNone = &compressionNone{} + CompressionGzip = &compressionGzip{} + + compressionMethods = make(map[string]CompressionMethod) + compressionMethods["none"] = CompressionNone + compressionMethods["gzip"] = CompressionGzip +} diff --git a/parser.go b/parser.go index 3fc27bfe..2c21f3d4 100644 --- a/parser.go +++ b/parser.go @@ -32,11 +32,21 @@ func (p *Parser) Parse(tokenString string, keyFunc Keyfunc) (*Token, error) { return token, &ValidationError{err: err.Error(), Errors: ValidationErrorMalformed} } + //var cMethod, ok = token.Header["cpr"].(string) // parse Claims var claimBytes []byte if claimBytes, err = DecodeSegment(parts[1]); err != nil { return token, &ValidationError{err: err.Error(), Errors: ValidationErrorMalformed} } + + var compression CompressionMethod + if compression, err = getCompressionMethod(token.Header["cpr"]); err != nil { + return token, &ValidationError{err: err.Error(), Errors: ValidationErrorMalformed} + } + if claimBytes, err = compression.Decompress(claimBytes); err != nil { + return token, &ValidationError{err: err.Error(), Errors: ValidationErrorMalformed} + } + dec := json.NewDecoder(bytes.NewBuffer(claimBytes)) if p.UseJSONNumber { dec.UseNumber() @@ -47,7 +57,7 @@ func (p *Parser) Parse(tokenString string, keyFunc Keyfunc) (*Token, error) { // Lookup signature method if method, ok := token.Header["alg"].(string); ok { - if token.Method = GetSigningMethod(method); token.Method == nil { + if token.SigningMethod = GetSigningMethod(method); token.SigningMethod == nil { return token, &ValidationError{err: "signing method (alg) is unavailable.", Errors: ValidationErrorUnverifiable} } } else { @@ -57,7 +67,7 @@ func (p *Parser) Parse(tokenString string, keyFunc Keyfunc) (*Token, error) { // Verify signing method is in the required set if p.ValidMethods != nil { var signingMethodValid = false - var alg = token.Method.Alg() + var alg = token.SigningMethod.Alg() for _, m := range p.ValidMethods { if m == alg { signingMethodValid = true @@ -99,7 +109,7 @@ func (p *Parser) Parse(tokenString string, keyFunc Keyfunc) (*Token, error) { // Perform validation token.Signature = parts[2] - if err = token.Method.Verify(strings.Join(parts[0:2], "."), token.Signature, key); err != nil { + if err = token.SigningMethod.Verify(strings.Join(parts[0:2], "."), token.Signature, key); err != nil { vErr.err = err.Error() vErr.Errors |= ValidationErrorSignatureInvalid } diff --git a/parser_test.go b/parser_test.go index 9115017b..64a6c901 100644 --- a/parser_test.go +++ b/parser_test.go @@ -3,12 +3,13 @@ package jwt_test import ( "encoding/json" "fmt" - "github.com/dgrijalva/jwt-go" "io/ioutil" "net/http" "reflect" "testing" "time" + + "github.com/dgrijalva/jwt-go" ) var ( @@ -136,13 +137,13 @@ func init() { } } -func makeSample(c map[string]interface{}) string { +func makeSample(c map[string]interface{}, compression jwt.CompressionMethod) string { key, e := ioutil.ReadFile("test/sample_key") if e != nil { panic(e.Error()) } - token := jwt.New(jwt.SigningMethodRS256) + token := jwt.New(jwt.SigningMethodRS256, compression) token.Claims = c s, e := token.SignedString(key) @@ -156,7 +157,7 @@ func makeSample(c map[string]interface{}) string { func TestParser_Parse(t *testing.T) { for _, data := range jwtTestData { if data.tokenString == "" { - data.tokenString = makeSample(data.claims) + data.tokenString = makeSample(data.claims, jwt.CompressionNone) } var token *jwt.Token @@ -202,7 +203,7 @@ func TestParseRequest(t *testing.T) { } if data.tokenString == "" { - data.tokenString = makeSample(data.claims) + data.tokenString = makeSample(data.claims, jwt.CompressionNone) } r, _ := http.NewRequest("GET", "/", nil) @@ -225,9 +226,34 @@ func TestParseRequest(t *testing.T) { } } +func TestParseWithCompression(t *testing.T) { + var token = jwt.New(jwt.SigningMethodHS256, jwt.CompressionGzip) + var claimsMap = map[string]interface{}{ + "claim1": "value1", + } + token.Claims = claimsMap + + var tokenString, err = token.SignedString([]byte("TEST KEY")) + if err != nil { + t.Fatalf("Unexpected error creating token string: %s", err.Error()) + } + + var parsedToken *jwt.Token + parsedToken, err = jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) { + return []byte("TEST KEY"), nil + }) + if err != nil { + t.Errorf("Error while parsing the token: %s", err.Error()) + } + + if reflect.DeepEqual(parsedToken.Claims, claimsMap) == false { + t.Errorf("Claims mismatch") + } +} + // Helper method for benchmarking various methods func benchmarkSigning(b *testing.B, method jwt.SigningMethod, key interface{}) { - t := jwt.New(method) + t := jwt.New(method, jwt.CompressionNone) b.RunParallel(func(pb *testing.PB) { for pb.Next() { if _, err := t.SignedString(key); err != nil { diff --git a/token.go b/token.go index d35aaa4a..ae3e36fe 100644 --- a/token.go +++ b/token.go @@ -22,23 +22,24 @@ type Keyfunc func(*Token) (interface{}, error) // A JWT Token. Different fields will be used depending on whether you're // creating or parsing/verifying a token. type Token struct { - Raw string // The raw token. Populated when you Parse a token - Method SigningMethod // The signing method used or to be used - Header map[string]interface{} // The first segment of the token - Claims map[string]interface{} // The second segment of the token - Signature string // The third segment of the token. Populated when you Parse a token - Valid bool // Is the token valid? Populated when you Parse/Verify a token + Raw string // The raw token. Populated when you Parse a token + SigningMethod SigningMethod // The signing method used or to be used + Header map[string]interface{} // The first segment of the token + Claims map[string]interface{} // The second segment of the token + Signature string // The third segment of the token. Populated when you Parse a token + Valid bool // Is the token valid? Populated when you Parse/Verify a token } -// Create a new Token. Takes a signing method -func New(method SigningMethod) *Token { +// Create a new Token. Takes a signing method and compression method +func New(signingMethod SigningMethod, compressionMethod CompressionMethod) *Token { return &Token{ Header: map[string]interface{}{ - "typ": "JWT", - "alg": method.Alg(), + "typ": "JWT", + "alg": signingMethod.Alg(), + "cpr": compressionMethod.Alg(), }, - Claims: make(map[string]interface{}), - Method: method, + Claims: make(map[string]interface{}), + SigningMethod: signingMethod, } } @@ -49,7 +50,7 @@ func (t *Token) SignedString(key interface{}) (string, error) { if sstr, err = t.SigningString(); err != nil { return "", err } - if sig, err = t.Method.Sign(sstr, key); err != nil { + if sig, err = t.SigningMethod.Sign(sstr, key); err != nil { return "", err } return strings.Join([]string{sstr, sig}, "."), nil @@ -61,22 +62,27 @@ func (t *Token) SignedString(key interface{}) (string, error) { // the SignedString. func (t *Token) SigningString() (string, error) { var err error - parts := make([]string, 2) - for i, _ := range parts { - var source map[string]interface{} - if i == 0 { - source = t.Header - } else { - source = t.Claims - } + var parts = []string{} + var jsonValue []byte + var compression CompressionMethod - var jsonValue []byte - if jsonValue, err = json.Marshal(source); err != nil { - return "", err - } + if jsonValue, err = json.Marshal(t.Header); err != nil { + return "", err + } + parts = append(parts, EncodeSegment(jsonValue)) - parts[i] = EncodeSegment(jsonValue) + if jsonValue, err = json.Marshal(t.Claims); err != nil { + return "", err } + + if compression, err = getCompressionMethod(t.Header["cpr"]); err != nil { + return "", err + } + if jsonValue, err = compression.Compress(jsonValue); err != nil { + return "", err + } + parts = append(parts, EncodeSegment(jsonValue)) + return strings.Join(parts, "."), nil } diff --git a/token_test.go b/token_test.go new file mode 100644 index 00000000..1f3bc2da --- /dev/null +++ b/token_test.go @@ -0,0 +1,46 @@ +package jwt_test + +import ( + "strings" + "testing" + + "github.com/dgrijalva/jwt-go" +) + +func TestNewTokenWithGzipCompression(t *testing.T) { + var token = jwt.New(jwt.SigningMethodHS256, jwt.CompressionGzip) + + token.Claims = map[string]interface{}{ + "claim1": "testvalue1", + "claim2": 42, + } + + var tokenString, err = token.SignedString([]byte("TEST KEY")) + if err != nil { + t.Errorf("Error signing gzip compressed claims: %s", err) + } + + parts := strings.Split(tokenString, ".") + if len(parts) != 3 { + t.Errorf("Token string has %d parts", len(parts)) + } + if parts[1] != "H4sIAAAJbogA_6pWSs5JzMw1VLJSKkktLilLzClNNVTSgYgaKVmZGNUCAgAA__9vre6IIwAAAA" { + t.Error("Token claims not encoded properly") + } +} + +func TestNewTokenWithWrongCompressionAlg(t *testing.T) { + var token = jwt.New(jwt.SigningMethodHS256, jwt.CompressionNone) + token.Header["cpr"] = "dummy" // set wrong compression method + token.Claims = map[string]interface{}{ + "claim1": "testvalue1", + "claim2": 42, + } + + var _, err = token.SignedString([]byte("TEST KEY")) + if err == nil { + t.Errorf("Expected error") + } else if err.Error() != "Compression method dummy not registered" { + t.Errorf("Unexpected error description: %s", err.Error()) + } +} From 5e58bd6b1c205094e94236dce72d34760ac10908 Mon Sep 17 00:00:00 2001 From: Maciej Galkowski Date: Thu, 26 Nov 2015 15:30:35 +0000 Subject: [PATCH 2/4] Cosmetic changes --- parser.go | 7 +++---- token.go | 22 +++++++++++----------- 2 files changed, 14 insertions(+), 15 deletions(-) diff --git a/parser.go b/parser.go index 2c21f3d4..1f3bbe9b 100644 --- a/parser.go +++ b/parser.go @@ -32,7 +32,6 @@ func (p *Parser) Parse(tokenString string, keyFunc Keyfunc) (*Token, error) { return token, &ValidationError{err: err.Error(), Errors: ValidationErrorMalformed} } - //var cMethod, ok = token.Header["cpr"].(string) // parse Claims var claimBytes []byte if claimBytes, err = DecodeSegment(parts[1]); err != nil { @@ -57,7 +56,7 @@ func (p *Parser) Parse(tokenString string, keyFunc Keyfunc) (*Token, error) { // Lookup signature method if method, ok := token.Header["alg"].(string); ok { - if token.SigningMethod = GetSigningMethod(method); token.SigningMethod == nil { + if token.Method = GetSigningMethod(method); token.Method == nil { return token, &ValidationError{err: "signing method (alg) is unavailable.", Errors: ValidationErrorUnverifiable} } } else { @@ -67,7 +66,7 @@ func (p *Parser) Parse(tokenString string, keyFunc Keyfunc) (*Token, error) { // Verify signing method is in the required set if p.ValidMethods != nil { var signingMethodValid = false - var alg = token.SigningMethod.Alg() + var alg = token.Method.Alg() for _, m := range p.ValidMethods { if m == alg { signingMethodValid = true @@ -109,7 +108,7 @@ func (p *Parser) Parse(tokenString string, keyFunc Keyfunc) (*Token, error) { // Perform validation token.Signature = parts[2] - if err = token.SigningMethod.Verify(strings.Join(parts[0:2], "."), token.Signature, key); err != nil { + if err = token.Method.Verify(strings.Join(parts[0:2], "."), token.Signature, key); err != nil { vErr.err = err.Error() vErr.Errors |= ValidationErrorSignatureInvalid } diff --git a/token.go b/token.go index ae3e36fe..edaefbcd 100644 --- a/token.go +++ b/token.go @@ -22,24 +22,24 @@ type Keyfunc func(*Token) (interface{}, error) // A JWT Token. Different fields will be used depending on whether you're // creating or parsing/verifying a token. type Token struct { - Raw string // The raw token. Populated when you Parse a token - SigningMethod SigningMethod // The signing method used or to be used - Header map[string]interface{} // The first segment of the token - Claims map[string]interface{} // The second segment of the token - Signature string // The third segment of the token. Populated when you Parse a token - Valid bool // Is the token valid? Populated when you Parse/Verify a token + Raw string // The raw token. Populated when you Parse a token + Method SigningMethod // The signing method used or to be used + Header map[string]interface{} // The first segment of the token + Claims map[string]interface{} // The second segment of the token + Signature string // The third segment of the token. Populated when you Parse a token + Valid bool // Is the token valid? Populated when you Parse/Verify a token } // Create a new Token. Takes a signing method and compression method func New(signingMethod SigningMethod, compressionMethod CompressionMethod) *Token { return &Token{ Header: map[string]interface{}{ - "typ": "JWT", - "alg": signingMethod.Alg(), + "typ": "JWT", + "alg": signingMethod.Alg(), "cpr": compressionMethod.Alg(), }, - Claims: make(map[string]interface{}), - SigningMethod: signingMethod, + Claims: make(map[string]interface{}), + Method: signingMethod, } } @@ -50,7 +50,7 @@ func (t *Token) SignedString(key interface{}) (string, error) { if sstr, err = t.SigningString(); err != nil { return "", err } - if sig, err = t.SigningMethod.Sign(sstr, key); err != nil { + if sig, err = t.Method.Sign(sstr, key); err != nil { return "", err } return strings.Join([]string{sstr, sig}, "."), nil From 195d80f48a7f5ff15f5a51a1495093c4cc995c64 Mon Sep 17 00:00:00 2001 From: Maciej Galkowski Date: Fri, 27 Nov 2015 10:06:43 +0000 Subject: [PATCH 3/4] Fix build error in the cmd app --- cmd/jwt/app.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmd/jwt/app.go b/cmd/jwt/app.go index 4068a805..6f542c33 100644 --- a/cmd/jwt/app.go +++ b/cmd/jwt/app.go @@ -182,7 +182,7 @@ func signToken() error { } // create a new token - token := jwt.New(alg) + token := jwt.New(alg, jwt.CompressionNone) token.Claims = claims if isEs() { From 15ad6f96f0e2d526e8bdc9fff8a6f1abfdefee29 Mon Sep 17 00:00:00 2001 From: Maciej Galkowski Date: Fri, 27 Nov 2015 10:11:02 +0000 Subject: [PATCH 4/4] Fix more compile errors --- example_test.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/example_test.go b/example_test.go index edb48e4d..a5cee509 100644 --- a/example_test.go +++ b/example_test.go @@ -2,8 +2,9 @@ package jwt_test import ( "fmt" - "github.com/dgrijalva/jwt-go" "time" + + "github.com/dgrijalva/jwt-go" ) func ExampleParse(myToken string, myLookupKey func(interface{}) (interface{}, error)) { @@ -20,7 +21,7 @@ func ExampleParse(myToken string, myLookupKey func(interface{}) (interface{}, er func ExampleNew(mySigningKey []byte) (string, error) { // Create the token - token := jwt.New(jwt.SigningMethodHS256) + token := jwt.New(jwt.SigningMethodHS256, jwt.CompressionNone) // Set some claims token.Claims["foo"] = "bar" token.Claims["exp"] = time.Now().Add(time.Hour * 72).Unix()