From 9cbfe1f64ff8e22be6e83a32646eef05565d27bf Mon Sep 17 00:00:00 2001 From: Jorge Luis Betancourt Gonzalez Date: Fri, 26 Jun 2020 13:43:44 +0200 Subject: [PATCH] Add basic tests for the verifier middleware --- go.mod | 9 +- go.sum | 6 ++ main.go | 84 ++++++++------- main_test.go | 252 ++++++++++++++++++++++++++++++++++++++++++++ testdata/rsa_1.pem | 27 +++++ testdata/rsa_gen.sh | 21 ++++ 6 files changed, 360 insertions(+), 39 deletions(-) create mode 100644 main_test.go create mode 100644 testdata/rsa_1.pem create mode 100755 testdata/rsa_gen.sh diff --git a/go.mod b/go.mod index 4971b2b..011299f 100644 --- a/go.mod +++ b/go.mod @@ -4,10 +4,13 @@ go 1.14 require ( github.com/coreos/go-oidc v2.2.1+incompatible - github.com/google/go-cmp v0.4.1 // indirect + github.com/google/go-cmp v0.4.1 + github.com/kelseyhightower/envconfig v1.4.0 github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35 // indirect - github.com/stretchr/testify v1.6.1 // indirect + github.com/quay/jwtproxy v0.0.4 + github.com/stretchr/testify v1.6.1 golang.org/x/crypto v0.0.0-20200604202706-70a84ac30bf9 // indirect golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d // indirect - gopkg.in/square/go-jose.v2 v2.5.1 // indirect + gopkg.in/square/go-jose.v2 v2.5.1 + gopkg.in/yaml.v2 v2.3.0 // indirect ) diff --git a/go.sum b/go.sum index 9af92b4..0ef8e08 100644 --- a/go.sum +++ b/go.sum @@ -7,10 +7,14 @@ github.com/golang/protobuf v1.2.0 h1:P3YflyNX/ehuJFLhxviNdFxQPkGK5cDcApsge1SqnvM github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/google/go-cmp v0.4.1 h1:/exdXoGamhu5ONeUJH0deniYLWYvQwW66yvlfiiKTu0= github.com/google/go-cmp v0.4.1/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/kelseyhightower/envconfig v1.4.0 h1:Im6hONhd3pLkfDFsbRgu68RDNkGF1r3dvMUtDTo2cv8= +github.com/kelseyhightower/envconfig v1.4.0/go.mod h1:cccZRl6mQpaq41TPp5QxidR+Sa3axMbJDNb//FQX6Gg= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35 h1:J9b7z+QKAmPf4YLrFg6oQUotqHQeUNWwkvo7jZp1GLU= github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35/go.mod h1:prYjPmNq4d1NPVmpShWobRqXY3q7Vp+80DqgxxUrUIA= +github.com/quay/jwtproxy v0.0.4 h1:M7YZxrqLaY0MA20AkWqH+1HGFjxQPLmNrC8TjrkfbwQ= +github.com/quay/jwtproxy v0.0.4/go.mod h1:Q0Zg96r0uvf49Ny3uRJ0Y09CCdtXU54LBntn6NZLShg= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= @@ -37,5 +41,7 @@ gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+ gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/square/go-jose.v2 v2.5.1 h1:7odma5RETjNHWJnR32wx8t+Io4djHE1PqxCFx3iiZ2w= gopkg.in/square/go-jose.v2 v2.5.1/go.mod h1:M9dMgbHiYLoDGQrXy7OpJDJWiKiU//h+vD76mk0e1AI= +gopkg.in/yaml.v2 v2.3.0 h1:clyUAQHOM3G0M3f5vQj7LuJrETvjVot3Z5el9nffUtU= +gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/main.go b/main.go index c9b0d1f..4c4fdb1 100644 --- a/main.go +++ b/main.go @@ -20,45 +20,44 @@ import ( "log" "net/http" "net/http/httputil" - "os" "strings" "github.com/coreos/go-oidc" + "github.com/kelseyhightower/envconfig" ) -var ( - ctx = context.Background() - authDomain = os.Getenv("AUTHDOMAIN") - certsURL = fmt.Sprintf("%s/cdn-cgi/access/certs", authDomain) - - // policyAUD is your application AUD value - policyAUD = os.Getenv("POLICYAUD") - - // forwardHeader is the header to be set from the email claim embedded in the JWT token - forwardHeader = os.Getenv("FORWARDHEADER") +const ( + // CFJWTHeader is the header key set by Cloudflare Access after a successful authentication + CFJWTHeader = "Cf-Access-Jwt-Assertion" +) - // forwardHost is the host to bet used to forward the request. If set it will override the Host - // header of the original request - forwardHost = os.Getenv("FORWARDHOST") +// CloudflareClaim holds the claims about the End-User/Authentication event. +type CloudflareClaim struct { + Email string `json:"email"` + Type string `json:"type"` +} - // listenAddr is the port where this proxy will be listening - listenAddr = os.Getenv("ADDR") +// Config is the general configuration (read from environment variables) +type Config struct { + AuthDomain string + PolicyAUD string + ForwardHeader string + ForwardHost string + ListenAddr string `envconfig:"ADDR"` +} - config = &oidc.Config{ - ClientID: policyAUD, - } - keySet = oidc.NewRemoteKeySet(ctx, certsURL) - verifier = oidc.NewVerifier(authDomain, keySet, config) +var ( + ctx = context.Background() ) // VerifyToken is a middleware to verify a CF Access token -func VerifyToken(next http.Handler) http.Handler { +func VerifyToken(next http.Handler, tokenVerifier *oidc.IDTokenVerifier, cfg *Config) http.Handler { fn := func(w http.ResponseWriter, r *http.Request) { headers := r.Header // Make sure that the incoming request has our token header // Could also look in the cookies for CF_AUTHORIZATION - accessJWT := headers.Get("Cf-Access-Jwt-Assertion") + accessJWT := headers.Get(CFJWTHeader) if accessJWT == "" { w.WriteHeader(http.StatusUnauthorized) w.Write([]byte("No token on the request")) @@ -67,7 +66,7 @@ func VerifyToken(next http.Handler) http.Handler { // Verify the access token ctx := r.Context() - token, err := verifier.Verify(ctx, accessJWT) + token, err := tokenVerifier.Verify(ctx, accessJWT) if err != nil { w.WriteHeader(http.StatusUnauthorized) w.Write([]byte(fmt.Sprintf("Invalid token: %s", err.Error()))) @@ -75,18 +74,14 @@ func VerifyToken(next http.Handler) http.Handler { } // Extract custom claims - var claims struct { - Email string `json:"email"` - Type string `json:"type"` - } - + var claims CloudflareClaim if err := token.Claims(&claims); err != nil { w.WriteHeader(http.StatusUnauthorized) w.Write([]byte(fmt.Sprintf("Invalid claims in token: %s", err.Error()))) } // set the authentication forward header before proxying the request - r.Header.Add(forwardHeader, claims.Email) + r.Header.Add(cfg.ForwardHeader, claims.Email) log.Printf("Authenticated as: %s", claims.Email) next.ServeHTTP(w, r) @@ -96,23 +91,40 @@ func VerifyToken(next http.Handler) http.Handler { } func main() { + var cfg Config + err := envconfig.Process("", &cfg) + if err != nil { + log.Fatal(err.Error()) + } + + var ( + certsURL = fmt.Sprintf("%s/cdn-cgi/access/certs", cfg.AuthDomain) + + config = &oidc.Config{ + ClientID: cfg.PolicyAUD, + } + keySet = oidc.NewRemoteKeySet(ctx, certsURL) + verifier = oidc.NewVerifier(cfg.AuthDomain, keySet, config) + ) + director := func(req *http.Request) { req.Header.Add("X-Forwarded-Host", req.Host) req.Header.Add("X-Origin-Host", "cloudflare-access-proxy") + // TODO: should we trust on the Schema of the original request? req.URL.Scheme = "http" - if len(strings.TrimSpace(forwardHost)) > 0 { - req.URL.Host = forwardHost + if len(strings.TrimSpace(cfg.ForwardHost)) > 0 { + req.URL.Host = cfg.ForwardHost } } proxy := &httputil.ReverseProxy{Director: director} http.Handle("/", VerifyToken(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { proxy.ServeHTTP(w, r) - }))) + }), verifier, &cfg)) - log.Printf("Listening on %s", listenAddr) - if err := http.ListenAndServe(listenAddr, nil); err != nil { - log.Fatalf("Unable to start server on [%s], error: %s", listenAddr, err.Error()) + log.Printf("Listening on %s", cfg.ListenAddr) + if err := http.ListenAndServe(cfg.ListenAddr, nil); err != nil { + log.Fatalf("Unable to start server on [%s], error: %s", cfg.ListenAddr, err.Error()) } } diff --git a/main_test.go b/main_test.go new file mode 100644 index 0000000..e49d1e8 --- /dev/null +++ b/main_test.go @@ -0,0 +1,252 @@ +// Copyright 2020 Jorge Luis Betancourt +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "context" + "crypto" + "crypto/x509" + "encoding/hex" + "encoding/pem" + "fmt" + "io/ioutil" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/coreos/go-oidc" + "github.com/google/go-cmp/cmp" + "gopkg.in/square/go-jose.v2" +) + +var ( + now, _ = time.Parse(time.RFC3339Nano, "2009-11-10T23:00:00Z") + valid, _ = time.Parse(time.RFC3339Nano, "2009-11-11T00:00:00Z") +) + +// utilities for loading JOSE keys. +func loadRSAKey(t *testing.T, filepath string, alg jose.SignatureAlgorithm) *jose.JSONWebKey { + return loadKey(t, filepath, alg, func(b []byte) (interface{}, error) { + key, err := x509.ParsePKCS1PrivateKey(b) + if err != nil { + return nil, err + } + + return key.Public(), nil + }) +} + +func loadRSAPrivKey(t *testing.T, filepath string, alg jose.SignatureAlgorithm) *jose.JSONWebKey { + return loadKey(t, filepath, alg, func(b []byte) (interface{}, error) { + return x509.ParsePKCS1PrivateKey(b) + }) +} + +func loadKey(t *testing.T, filepath string, alg jose.SignatureAlgorithm, unmarshal func([]byte) (interface{}, error)) *jose.JSONWebKey { + data, err := ioutil.ReadFile(filepath) + if err != nil { + t.Fatalf("load file: %v", err) + } + block, _ := pem.Decode(data) + if block == nil { + t.Fatalf("file contained no PEM encoded data: %s", filepath) + } + priv, err := unmarshal(block.Bytes) + if err != nil { + t.Fatalf("unmarshal key: %v", err) + } + key := &jose.JSONWebKey{Key: priv, Use: "sig", Algorithm: string(alg)} + thumbprint, err := key.Thumbprint(crypto.SHA256) + if err != nil { + t.Fatalf("computing thumbprint: %v", err) + } + key.KeyID = hex.EncodeToString(thumbprint) + return key +} + +// staticKeySet implements oidc.KeySet. +type staticKeySet struct { + keys []*jose.JSONWebKey +} + +func (s *staticKeySet) VerifySignature(ctx context.Context, jwt string) (payload []byte, err error) { + jws, err := jose.ParseSigned(jwt) + if err != nil { + return nil, err + } + if len(jws.Signatures) == 0 { + return nil, fmt.Errorf("jwt contained no signatures") + } + kid := jws.Signatures[0].Header.KeyID + + for _, key := range s.keys { + if key.KeyID == kid { + return jws.Verify(key) + } + } + + return nil, fmt.Errorf("no keys matches jwk keyid") +} + +type claimsTest struct { + name string + now time.Time + signingKey *jose.JSONWebKey + pubKeys []*jose.JSONWebKey + claims string + upstream bool + fn func(*testing.T, *http.Request, *httptest.ResponseRecorder) +} + +func (c *claimsTest) run(t *testing.T) { + cfg := &Config{ + AuthDomain: "https://your-own.cloudflareaccess.com", + PolicyAUD: "my-policy-aud", + ForwardHeader: "X-WEBAUTH-USER", + ForwardHost: "localhost:3000", + ListenAddr: ":3002", + } + + config := &oidc.Config{ + ClientID: "my-policy-aud", + Now: func() time.Time { return c.now }, + } + verifier := oidc.NewVerifier( + "https://your-own.cloudflareaccess.com", + &staticKeySet{ + keys: c.pubKeys, + }, + config, + ) + + // Sign and serialize the claims in a JWT. + signer, err := jose.NewSigner(jose.SigningKey{ + Algorithm: jose.SignatureAlgorithm(c.signingKey.Algorithm), + Key: c.signingKey, + }, nil) + if err != nil { + t.Fatalf("initialize signer: %v", err) + } + + jws, err := signer.Sign([]byte(c.claims)) + if err != nil { + t.Fatalf("sign claims: %v", err) + } + + token, err := jws.CompactSerialize() + if err != nil { + t.Fatalf("serialize token: %v", err) + } + + upstream := false + rr := httptest.NewRecorder() + next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + upstream = true + if c.fn != nil { + c.fn(t, r, rr) + } + }) + + req := httptest.NewRequest("GET", "http://domain.com", nil) + req.Header.Add(CFJWTHeader, token) + + VerifyToken(next, verifier, cfg).ServeHTTP(rr, req) + + if c.upstream != upstream { + t.Fatalf("Forward to upstream got: %t, want: %t", upstream, c.upstream) + } + + // We do not expect the upstream to be called so the VerifyToken middleware must have sent a + // reply back. + if c.upstream == false && c.fn != nil { + c.fn(t, req, rr) + } +} + +func TestVerifierMiddleware(t *testing.T) { + tests := []claimsTest{ + { + name: "valid token", + now: now, + signingKey: loadRSAPrivKey(t, "testdata/rsa_1.pem", jose.RS256), + pubKeys: []*jose.JSONWebKey{ + loadRSAKey(t, "testdata/rsa_1.pem", jose.RS256), + }, + claims: fmt.Sprintf(`{ + "iss": "https://your-own.cloudflareaccess.com", + "aud": "my-policy-aud", + "email": "test@example.com", + "type": "app", + "exp": %d + }`, valid.Unix()), + upstream: true, + fn: func(t *testing.T, r *http.Request, rr *httptest.ResponseRecorder) { + email := r.Header.Get("X-WEBAUTH-USER") + if diff := cmp.Diff(email, "test@example.com"); diff != "" { + t.Errorf("Wrong user was authenticated (-want +got):\n%s", diff) + } + }, + }, + { + name: "expired token", + now: now.Add(24 * time.Hour), + signingKey: loadRSAPrivKey(t, "testdata/rsa_1.pem", jose.RS256), + pubKeys: []*jose.JSONWebKey{ + loadRSAKey(t, "testdata/rsa_1.pem", jose.RS256), + }, + claims: fmt.Sprintf(`{ + "iss": "https://your-own.cloudflareaccess.com", + "aud": "my-policy-aud", + "email": "test@example.com", + "type": "app", + "exp": %d + }`, valid.Unix()), + upstream: false, + fn: func(t *testing.T, r *http.Request, rr *httptest.ResponseRecorder) { + expected := `Invalid token: oidc: token is expired (Token Expiry: 2009-11-11 01:00:00 +0100 CET)` + if diff := cmp.Diff(expected, rr.Body.String()); diff != "" { + t.Errorf("Wrong user was authenticated (-want +got):\n%s", diff) + } + }, + }, + { + name: "invalid token", + now: now, + signingKey: loadRSAPrivKey(t, "testdata/rsa_1.pem", jose.RS256), + pubKeys: []*jose.JSONWebKey{ + loadRSAKey(t, "testdata/rsa_1.pem", jose.RS256), + }, + claims: fmt.Sprintf(`{ + "iss": "https://another-domain.cloudflareaccess.com", + "aud": "my-policy-aud", + "email": "test@example.com", + "type": "app", + "exp": %d + }`, valid.Unix()), + upstream: false, + fn: func(t *testing.T, r *http.Request, rr *httptest.ResponseRecorder) { + expected := `Invalid token: oidc: id token issued by a different provider, expected "https://your-own.cloudflareaccess.com" got "https://another-domain.cloudflareaccess.com"` + if diff := cmp.Diff(expected, rr.Body.String()); diff != "" { + t.Errorf("Wrong user was authenticated (-want +got):\n%s", diff) + } + }, + }, + } + + for _, test := range tests { + t.Run(test.name, test.run) + } +} diff --git a/testdata/rsa_1.pem b/testdata/rsa_1.pem new file mode 100644 index 0000000..2c4a923 --- /dev/null +++ b/testdata/rsa_1.pem @@ -0,0 +1,27 @@ +-----BEGIN RSA PRIVATE KEY----- +MIIEpAIBAAKCAQEA0pXWMYjWRjBEds/fKj/u9r2E6SIDx0J+TAg+eyVeR20Ky9jZ +mIXW5zSxE/EKpNQpiBWm1e6G9kmhMuqjr7g455S7E+3rD3OVkdTT6SU5AKBNSFoR +XUd+G/YJEtRzrpEYNtEJHkxUxWuyfCHblHSt+wsrE6t0DccCqC87lKQiGb/QfC8u +P6ZS99SCjKBEFp1fZvyNkYwStFc2OH5fBGPXXb6SNsquvDeKX9NeWjXkmxDkbOg2 +kSkel4s/zw5KwcW3JzERfEcLStrDQ8fRbJ1C3uC088sUk4q4APQmKI/8FTvJe431 +Vne9sOSptphiqCjlR+Knja58rc/vt4TkSPZf2wIDAQABAoIBAQDO3UgjMtOi8Wlf ++YW1IEbjdXrp9XMWu9gLYpHWMPgzXAeeBfCDJv7b8uP8ve2By7TcrMBOKVnE+MF0 +nhCb3nFv9KftxOsDK70DG7nrrpgXaGFisK+cHU3hs8hoCfF1y6yotKGrdLpVkR0t +Wak1ZYU/NlJjqSqBGj0e7/8sXivtc7oME8tBBRBCEa8OqPqaelCInfFF1rX5vmxX +pQjPpZoA+vroSJy8SYE0N5oqtGwOPT+9rVuDOL10eaMbGUcssZl8ofwuvzOYPMW4 +KFSVtvdtKnACq94Qy6XQbK5hZbZXSpzxANKq8SFyG2N1wOlpu/ktdXqkyDs08AZY +c/KkpXspAoGBAPdC73GOZn/hxzkwZ2Dl+S9rgrLT3VbeuhMp6GXSdiT+f9babMuw +HlYw6uULmvL1gD/0GmyWrHopPFJxodBG5SlwYS5wl49slcxeKCjK26vbNfK2eCbu +9uMtED4dN/5NlaXF4hqy/FmSyaFhQT+5hvx8n/zvLsgpuSQ+SCiDAHMfAoGBANoH +FCZeCWzzUFhObYG9wxGJ9FBPQa0htafIYEgTwezlKPsrfXfCTnVg1lLkr6Z4IwYQ +9VufJZNAc5V0X9H/ceyKJYxhQ+E01NEVzVpoK8fOC4yCYSYtbJnqkOUQzZJzkjFT +mNcIa8o4UrBOWzMhMQa0AOZH4VrbtZDCZhid+hfFAoGAAbKh9kOmDIa+WXQtoYqy +tVKlqRivUmNhH7SP9fMGAKcGtbD2QkfJTYo0crIrtDNfWBETBV/be1NBKMfC9q0l +8azl3e3D/KYgOTEEUZNjAsEUk8AQ/yNw6opqrCKDOenKd0LulIRaGztYyxTh39Ak +TyOD7bauuY0fylHrKOwNWr0CgYEAsVZ0o0h1rjKyNUGFfLQWyFtHZ1Mv/lye3tvy +xG2dnMxAaxvSr+hR3NNpQH9WB7dL9ZExoNZvv7f6y6OelLaLuXQcWnR6u+E3AOIU +5+Y3RgtoBV+/GUh1PzQ1qrviGa77SDfQ54an9hGd4F27fHkQ4XzkBmqM+FQg+J/G +X1uPomkCgYBo4ZBEA20Wvf1k2iWOVdfsxZNeOLxwcN5x89dAvm0v6RHg2NMy+XKw +Rj+YRuudFdxfg39J/V/Md9qsvjW+4FthD8GhgPs22dksV+7j6ApWkYTmIKG4rmh3 +RhHOr6uLg9BeShnlvMMaMJKf2eA7SaVtmuS6uBGgEUNaa3qEBq0R+Q== +-----END RSA PRIVATE KEY----- diff --git a/testdata/rsa_gen.sh b/testdata/rsa_gen.sh new file mode 100755 index 0000000..dd67d8e --- /dev/null +++ b/testdata/rsa_gen.sh @@ -0,0 +1,21 @@ +#!/bin/bash -e + +# Copyright 2020 Jorge Luis Betancourt +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +for N in $(seq 1 1); do + ssh-keygen -t rsa -b 2048 -f rsa_$N.pem -N '' -m PEM +done + +rm *.pub