diff --git a/cmd/jwt/app.go b/cmd/jwt/app.go index 4068a805..6f542c33 100644 --- a/cmd/jwt/app.go +++ b/cmd/jwt/app.go @@ -182,7 +182,7 @@ func signToken() error { } // create a new token - token := jwt.New(alg) + token := jwt.New(alg, jwt.CompressionNone) token.Claims = claims if isEs() { diff --git a/compression_method.go b/compression_method.go new file mode 100644 index 00000000..dff9f524 --- /dev/null +++ b/compression_method.go @@ -0,0 +1,103 @@ +package jwt + +import ( + "bytes" + "compress/gzip" + "fmt" + "io/ioutil" +) + +var ( + // CompressionNone does not perform any data compression/decompression + CompressionNone CompressionMethod + // CompressionGzip compresses the claims part of the JWT with gzip algorithm + CompressionGzip CompressionMethod +) + +var ( + compressionMethods map[string]CompressionMethod +) + +// CompressionMethod is an interface used to compress/decompress the Claims part of the JWT +type CompressionMethod interface { + // Alg returns the name of the compression algorithm. It is saved in the token header + Alg() string + // Compress takes uncompressed data, and returns the compression result` + Compress(data []byte) ([]byte, error) + // Decompress takes compressed data and returns uncompressed version + Decompress(data []byte) ([]byte, error) +} + +type compressionGzip struct{} + +type compressionNone struct{} + +func (c *compressionGzip) Alg() string { + return "gzip" +} + +func (c *compressionGzip) Compress(data []byte) ([]byte, error) { + var buffer = &bytes.Buffer{} + var writer = gzip.NewWriter(buffer) + + if _, err := writer.Write(data); err != nil { + writer.Close() + return nil, err + } + + writer.Close() + return buffer.Bytes(), nil +} + +func (c *compressionGzip) Decompress(data []byte) ([]byte, error) { + var buffer = bytes.NewBuffer(data) + var reader, err = gzip.NewReader(buffer) + defer reader.Close() + if err != nil { + return nil, err + } + + return ioutil.ReadAll(reader) +} + +func (c *compressionNone) Alg() string { + return "none" +} + +func (c *compressionNone) Compress(data []byte) ([]byte, error) { + return data, nil +} + +func (c *compressionNone) Decompress(data []byte) ([]byte, error) { + return data, nil +} + +// RegisterCompressionMethod adds support for additional compression method in the runtime. +// The name value is saved in the token header and later used to retrieve the method interface +// used to decompress the header +func RegisterCompressionMethod(name string, method CompressionMethod) { + compressionMethods[name] = method +} + +func getCompressionMethod(alg interface{}) (CompressionMethod, error) { + var algString, ok = alg.(string) + if ok == false || len(algString) == 0 { + return compressionMethods["none"], nil + } + + var method = compressionMethods[algString] + if method == nil { + return nil, fmt.Errorf("Compression method %s not registered", alg) + } + + return method, nil +} + +func init() { + CompressionNone = &compressionNone{} + CompressionGzip = &compressionGzip{} + + compressionMethods = make(map[string]CompressionMethod) + compressionMethods["none"] = CompressionNone + compressionMethods["gzip"] = CompressionGzip +} diff --git a/example_test.go b/example_test.go index edb48e4d..a5cee509 100644 --- a/example_test.go +++ b/example_test.go @@ -2,8 +2,9 @@ package jwt_test import ( "fmt" - "github.com/dgrijalva/jwt-go" "time" + + "github.com/dgrijalva/jwt-go" ) func ExampleParse(myToken string, myLookupKey func(interface{}) (interface{}, error)) { @@ -20,7 +21,7 @@ func ExampleParse(myToken string, myLookupKey func(interface{}) (interface{}, er func ExampleNew(mySigningKey []byte) (string, error) { // Create the token - token := jwt.New(jwt.SigningMethodHS256) + token := jwt.New(jwt.SigningMethodHS256, jwt.CompressionNone) // Set some claims token.Claims["foo"] = "bar" token.Claims["exp"] = time.Now().Add(time.Hour * 72).Unix() diff --git a/parser.go b/parser.go index 3fc27bfe..1f3bbe9b 100644 --- a/parser.go +++ b/parser.go @@ -37,6 +37,15 @@ func (p *Parser) Parse(tokenString string, keyFunc Keyfunc) (*Token, error) { if claimBytes, err = DecodeSegment(parts[1]); err != nil { return token, &ValidationError{err: err.Error(), Errors: ValidationErrorMalformed} } + + var compression CompressionMethod + if compression, err = getCompressionMethod(token.Header["cpr"]); err != nil { + return token, &ValidationError{err: err.Error(), Errors: ValidationErrorMalformed} + } + if claimBytes, err = compression.Decompress(claimBytes); err != nil { + return token, &ValidationError{err: err.Error(), Errors: ValidationErrorMalformed} + } + dec := json.NewDecoder(bytes.NewBuffer(claimBytes)) if p.UseJSONNumber { dec.UseNumber() diff --git a/parser_test.go b/parser_test.go index 9115017b..64a6c901 100644 --- a/parser_test.go +++ b/parser_test.go @@ -3,12 +3,13 @@ package jwt_test import ( "encoding/json" "fmt" - "github.com/dgrijalva/jwt-go" "io/ioutil" "net/http" "reflect" "testing" "time" + + "github.com/dgrijalva/jwt-go" ) var ( @@ -136,13 +137,13 @@ func init() { } } -func makeSample(c map[string]interface{}) string { +func makeSample(c map[string]interface{}, compression jwt.CompressionMethod) string { key, e := ioutil.ReadFile("test/sample_key") if e != nil { panic(e.Error()) } - token := jwt.New(jwt.SigningMethodRS256) + token := jwt.New(jwt.SigningMethodRS256, compression) token.Claims = c s, e := token.SignedString(key) @@ -156,7 +157,7 @@ func makeSample(c map[string]interface{}) string { func TestParser_Parse(t *testing.T) { for _, data := range jwtTestData { if data.tokenString == "" { - data.tokenString = makeSample(data.claims) + data.tokenString = makeSample(data.claims, jwt.CompressionNone) } var token *jwt.Token @@ -202,7 +203,7 @@ func TestParseRequest(t *testing.T) { } if data.tokenString == "" { - data.tokenString = makeSample(data.claims) + data.tokenString = makeSample(data.claims, jwt.CompressionNone) } r, _ := http.NewRequest("GET", "/", nil) @@ -225,9 +226,34 @@ func TestParseRequest(t *testing.T) { } } +func TestParseWithCompression(t *testing.T) { + var token = jwt.New(jwt.SigningMethodHS256, jwt.CompressionGzip) + var claimsMap = map[string]interface{}{ + "claim1": "value1", + } + token.Claims = claimsMap + + var tokenString, err = token.SignedString([]byte("TEST KEY")) + if err != nil { + t.Fatalf("Unexpected error creating token string: %s", err.Error()) + } + + var parsedToken *jwt.Token + parsedToken, err = jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) { + return []byte("TEST KEY"), nil + }) + if err != nil { + t.Errorf("Error while parsing the token: %s", err.Error()) + } + + if reflect.DeepEqual(parsedToken.Claims, claimsMap) == false { + t.Errorf("Claims mismatch") + } +} + // Helper method for benchmarking various methods func benchmarkSigning(b *testing.B, method jwt.SigningMethod, key interface{}) { - t := jwt.New(method) + t := jwt.New(method, jwt.CompressionNone) b.RunParallel(func(pb *testing.PB) { for pb.Next() { if _, err := t.SignedString(key); err != nil { diff --git a/token.go b/token.go index d35aaa4a..edaefbcd 100644 --- a/token.go +++ b/token.go @@ -30,15 +30,16 @@ type Token struct { Valid bool // Is the token valid? Populated when you Parse/Verify a token } -// Create a new Token. Takes a signing method -func New(method SigningMethod) *Token { +// Create a new Token. Takes a signing method and compression method +func New(signingMethod SigningMethod, compressionMethod CompressionMethod) *Token { return &Token{ Header: map[string]interface{}{ "typ": "JWT", - "alg": method.Alg(), + "alg": signingMethod.Alg(), + "cpr": compressionMethod.Alg(), }, Claims: make(map[string]interface{}), - Method: method, + Method: signingMethod, } } @@ -61,22 +62,27 @@ func (t *Token) SignedString(key interface{}) (string, error) { // the SignedString. func (t *Token) SigningString() (string, error) { var err error - parts := make([]string, 2) - for i, _ := range parts { - var source map[string]interface{} - if i == 0 { - source = t.Header - } else { - source = t.Claims - } + var parts = []string{} + var jsonValue []byte + var compression CompressionMethod - var jsonValue []byte - if jsonValue, err = json.Marshal(source); err != nil { - return "", err - } + if jsonValue, err = json.Marshal(t.Header); err != nil { + return "", err + } + parts = append(parts, EncodeSegment(jsonValue)) - parts[i] = EncodeSegment(jsonValue) + if jsonValue, err = json.Marshal(t.Claims); err != nil { + return "", err } + + if compression, err = getCompressionMethod(t.Header["cpr"]); err != nil { + return "", err + } + if jsonValue, err = compression.Compress(jsonValue); err != nil { + return "", err + } + parts = append(parts, EncodeSegment(jsonValue)) + return strings.Join(parts, "."), nil } diff --git a/token_test.go b/token_test.go new file mode 100644 index 00000000..1f3bc2da --- /dev/null +++ b/token_test.go @@ -0,0 +1,46 @@ +package jwt_test + +import ( + "strings" + "testing" + + "github.com/dgrijalva/jwt-go" +) + +func TestNewTokenWithGzipCompression(t *testing.T) { + var token = jwt.New(jwt.SigningMethodHS256, jwt.CompressionGzip) + + token.Claims = map[string]interface{}{ + "claim1": "testvalue1", + "claim2": 42, + } + + var tokenString, err = token.SignedString([]byte("TEST KEY")) + if err != nil { + t.Errorf("Error signing gzip compressed claims: %s", err) + } + + parts := strings.Split(tokenString, ".") + if len(parts) != 3 { + t.Errorf("Token string has %d parts", len(parts)) + } + if parts[1] != "H4sIAAAJbogA_6pWSs5JzMw1VLJSKkktLilLzClNNVTSgYgaKVmZGNUCAgAA__9vre6IIwAAAA" { + t.Error("Token claims not encoded properly") + } +} + +func TestNewTokenWithWrongCompressionAlg(t *testing.T) { + var token = jwt.New(jwt.SigningMethodHS256, jwt.CompressionNone) + token.Header["cpr"] = "dummy" // set wrong compression method + token.Claims = map[string]interface{}{ + "claim1": "testvalue1", + "claim2": 42, + } + + var _, err = token.SignedString([]byte("TEST KEY")) + if err == nil { + t.Errorf("Expected error") + } else if err.Error() != "Compression method dummy not registered" { + t.Errorf("Unexpected error description: %s", err.Error()) + } +}