Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix: Panic in jwt_sign() function #242

Merged
merged 26 commits into from
Jun 25, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ Unreleased changes are available as `avenga/couper:edge` container.

* **Fixed**
* Missing support for `set_response_status` within a plain `error_handler` block ([#257](https://github.com/avenga/couper/pull/257))
* Panic in jwt_sign() and saml_sso_url() functions without proper configuration ([#243](https://github.com/avenga/couper/issues/243))

---

Expand Down
41 changes: 5 additions & 36 deletions accesscontrol/jwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,7 @@ import (
"crypto/x509"
"encoding/pem"
"fmt"
"io/ioutil"
"net/http"
"path/filepath"
"strings"
"time"

Expand Down Expand Up @@ -52,8 +50,7 @@ type JWTOptions struct {
ClaimsRequired []string
Name string // TODO: more generic (validate)
Source JWTSource
Key string
KeyFile string
Key []byte
}

func NewJWTSource(cookie, header string) JWTSource {
Expand Down Expand Up @@ -88,23 +85,6 @@ func NewJWT(options *JWTOptions) (*JWT, error) {
source: options.Source,
}

if options.Key != "" && options.KeyFile != "" {
return nil, confErr.Message("key and keyFile provided")
}

key := []byte(options.Key)
if options.KeyFile != "" {
k, err := readKeyFile(options.KeyFile)
if err != nil {
return nil, confErr.With(err)
}
key = k
}

if len(key) == 0 {
return nil, confErr.Message("key required")
}

if jwtAC.source.Type == Invalid {
return nil, confErr.Message("token source is invalid")
}
Expand All @@ -120,11 +100,11 @@ func NewJWT(options *JWTOptions) (*JWT, error) {
jwtAC.parser = parser

if jwtAC.algorithm.IsHMAC() {
jwtAC.hmacSecret = key
jwtAC.hmacSecret = options.Key
return jwtAC, nil
}

pubKey, err := parsePublicPEMKey(key)
pubKey, err := parsePublicPEMKey(options.Key)
if err != nil {
return nil, confErr.With(err)
}
Expand Down Expand Up @@ -281,8 +261,8 @@ func parsePublicPEMKey(key []byte) (pub *rsa.PublicKey, err error) {
}
pubKey, pubErr := x509.ParsePKCS1PublicKey(pemBlock.Bytes)
if pubErr != nil {
pkixKey, err := x509.ParsePKIXPublicKey(pemBlock.Bytes)
if err != nil {
pkixKey, pkerr := x509.ParsePKIXPublicKey(pemBlock.Bytes)
if pkerr != nil {
cert, cerr := x509.ParseCertificate(pemBlock.Bytes)
if cerr != nil {
return nil, jwt.ErrNotRSAPublicKey
Expand All @@ -301,17 +281,6 @@ func parsePublicPEMKey(key []byte) (pub *rsa.PublicKey, err error) {
return pubKey, nil
}

func readKeyFile(filePath string) ([]byte, error) {
if filePath != "" {
p, err := filepath.Abs(filePath)
if err != nil {
return nil, err
}
return ioutil.ReadFile(p)
}
return nil, nil
}

func isStringType(val interface{}) error {
switch val.(type) {
case string:
Expand Down
36 changes: 24 additions & 12 deletions accesscontrol/jwt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,15 @@ import (
"fmt"
"net/http"
"net/http/httptest"
"strings"
"testing"

"github.com/avenga/couper/errors"

"github.com/dgrijalva/jwt-go/v4"

ac "github.com/avenga/couper/accesscontrol"
"github.com/avenga/couper/config/reader"
"github.com/avenga/couper/config/request"
"github.com/avenga/couper/errors"
"github.com/avenga/couper/internal/test"
)

Expand All @@ -27,6 +28,7 @@ func Test_JWT_NewJWT_RSA(t *testing.T) {
claims map[string]interface{}
claimsRequired []string
pubKey []byte
pubKeyPath string
}

privKey, err := rsa.GenerateKey(rand.Reader, 2048)
Expand Down Expand Up @@ -74,7 +76,8 @@ QolLGgj3tz4NbDEitq+zKMr0uTHvP1Vyu1mXAflcpYcJA4ZmuB3Oj39e0U0gnmr/
fields fields
wantErr string
}{
{"missing key", fields{}, "configuration error: test_ac: key required"},
{"missing key-file path", fields{}, "configuration error: jwt key: read error: required: configured attribute or file"},
{"missing key-file", fields{pubKeyPath: "./not-there.file"}, "not-there.file: no such file or directory"},
{"PKIX", fields{
algorithm: alg,
pubKey: pubKeyBytesPKIX,
Expand All @@ -94,27 +97,36 @@ QolLGgj3tz4NbDEitq+zKMr0uTHvP1Vyu1mXAflcpYcJA4ZmuB3Oj39e0U0gnmr/
}

for _, tt := range tests {
t.Run(fmt.Sprintf("%v / %s", signingMethod, tt.name), func(t *testing.T) {
t.Run(fmt.Sprintf("%v / %s", signingMethod, tt.name), func(subT *testing.T) {
key, rerr := reader.ReadFromAttrFile("jwt key", string(tt.fields.pubKey), tt.fields.pubKeyPath)
if rerr != nil {
logErr := rerr.(errors.GoError)
if tt.wantErr != "" && !strings.HasSuffix(logErr.LogError(), tt.wantErr) {
subT.Errorf("\nWant:\t%q\nGot:\t%q", tt.wantErr, logErr.LogError())
} else if tt.wantErr == "" {
subT.Fatal(logErr.LogError())
}
return
}

j, jerr := ac.NewJWT(&ac.JWTOptions{
Algorithm: tt.fields.algorithm,
Claims: tt.fields.claims,
ClaimsRequired: tt.fields.claimsRequired,
Name: "test_ac",
Key: string(tt.fields.pubKey),
Key: key,
Source: ac.NewJWTSource("", "Authorization"),
})
if jerr != nil {
gerr := jerr.(errors.GoError)
if tt.wantErr != gerr.LogError() {
t.Errorf("error: %v, want: %v", gerr.LogError(), tt.wantErr)
subT.Errorf("error: %v, want: %v", gerr.LogError(), tt.wantErr)
}
} else if tt.wantErr != "" {
t.Errorf("error expected: %v", tt.wantErr)
subT.Errorf("error expected: %v", tt.wantErr)
}
if tt.wantErr == "" {
if j == nil {
t.Errorf("JWT struct expected")
}
if tt.wantErr == "" && j == nil {
subT.Errorf("JWT struct expected")
}
})
}
Expand Down Expand Up @@ -225,7 +237,7 @@ func Test_JWT_Validate(t *testing.T) {
ClaimsRequired: tt.fields.claimsRequired,
Name: "test_ac",
Source: tt.fields.source,
Key: string(tt.fields.pubKey),
Key: tt.fields.pubKey,
})
if err != nil {
t.Error(err)
Expand Down
23 changes: 5 additions & 18 deletions accesscontrol/saml2.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,7 @@ import (
"encoding/base64"
"encoding/xml"
"fmt"
"io/ioutil"
"net/http"
"path/filepath"
"sort"

saml2 "github.com/russellhaering/gosaml2"
Expand All @@ -25,28 +23,17 @@ type Saml2 struct {
sp *saml2.SAMLServiceProvider
}

func NewSAML2ACS(metadataFile string, name string, acsUrl string, spEntityId string, arrayAttributes []string) (*Saml2, error) {
p, err := filepath.Abs(metadataFile)
if err != nil {
return nil, err
}

rawMetadata, err := ioutil.ReadFile(p)
if err != nil {
return nil, err
}

metadata := &types.EntityDescriptor{}
err = xml.Unmarshal(rawMetadata, metadata)
if err != nil {
func NewSAML2ACS(metadata []byte, name string, acsUrl string, spEntityId string, arrayAttributes []string) (*Saml2, error) {
metadataEntity := &types.EntityDescriptor{}
if err := xml.Unmarshal(metadata, metadataEntity); err != nil {
return nil, err
}

certStore := dsig.MemoryX509CertificateStore{
Roots: []*x509.Certificate{},
}

for _, kd := range metadata.IDPSSODescriptor.KeyDescriptors {
for _, kd := range metadataEntity.IDPSSODescriptor.KeyDescriptors {
for idx, xcert := range kd.KeyInfo.X509Data.X509Certificates {
if xcert.Data == "" {
return nil, fmt.Errorf("metadata certificate(%d) must not be empty", idx)
Expand All @@ -69,7 +56,7 @@ func NewSAML2ACS(metadataFile string, name string, acsUrl string, spEntityId str
AssertionConsumerServiceURL: acsUrl,
AudienceURI: spEntityId,
IDPCertificateStore: &certStore,
IdentityProviderIssuer: metadata.EntityID,
IdentityProviderIssuer: metadataEntity.EntityID,
}
if arrayAttributes != nil {
sort.Strings(arrayAttributes)
Expand Down
36 changes: 24 additions & 12 deletions accesscontrol/saml2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,36 +15,46 @@ import (
"github.com/russellhaering/gosaml2/types"

ac "github.com/avenga/couper/accesscontrol"
"github.com/avenga/couper/config/reader"
"github.com/avenga/couper/errors"
"github.com/avenga/couper/internal/test"
)

func Test_NewSAML2ACS(t *testing.T) {
helper := test.New(t)

type testCase struct {
metadataFile, acsUrl, spEntityId string
arrayAttributes []string
expErrMsg string
shouldFail bool
}

for _, tc := range []testCase{
{"testdata/idp-metadata.xml", "http://www.examle.org/saml/acs", "my-sp-entity-id", []string{}, "", false},
{"not-there.xml", "http://www.examle.org/saml/acs", "my-sp-entity-id", []string{}, "not-there.xml: no such file or directory", true},
} {
sa, err := ac.NewSAML2ACS(tc.metadataFile, "test", tc.acsUrl, tc.spEntityId, tc.arrayAttributes)
if tc.shouldFail && sa != nil {
t.Error("Expected no successful saml acs creation")
}

if tc.shouldFail && err != nil && tc.expErrMsg != "" {
if !strings.HasSuffix(err.Error(), tc.expErrMsg) {
t.Errorf("Expected error message suffix: %q, got: %q", tc.expErrMsg, err.Error())
metadata, err := reader.ReadFromAttrFile("saml2", "", tc.metadataFile)
if err != nil {
readErr := err.(errors.GoError)
if tc.shouldFail {
if !strings.HasSuffix(readErr.LogError(), tc.expErrMsg) {
t.Errorf("Want: %q, got: %q", tc.expErrMsg, readErr.LogError())
}
continue
}
} else if err != nil {
t.Error(err)
continue
}

_, err = ac.NewSAML2ACS(metadata, "test", tc.acsUrl, tc.spEntityId, tc.arrayAttributes)
helper.Must(err)
}
}

func Test_SAML2ACS_Validate(t *testing.T) {
sa, err := ac.NewSAML2ACS("testdata/idp-metadata.xml", "test", "http://www.examle.org/saml/acs", "my-sp-entity-id", []string{"memberOf"})
metadata, err := reader.ReadFromAttrFile("saml2", "", "testdata/idp-metadata.xml")
sa, err := ac.NewSAML2ACS(metadata, "test", "http://www.examle.org/saml/acs", "my-sp-entity-id", []string{"memberOf"})
if err != nil || sa == nil {
t.Fatal("Expected a saml acs object")
}
Expand Down Expand Up @@ -87,7 +97,8 @@ func Test_SAML2ACS_Validate(t *testing.T) {
}

func Test_SAML2ACS_ValidateAssertionInfo(t *testing.T) {
sa, err := ac.NewSAML2ACS("testdata/idp-metadata.xml", "test", "http://www.examle.org/saml/acs", "my-sp-entity-id", []string{"memberOf"})
metadata, err := reader.ReadFromAttrFile("saml2", "", "testdata/idp-metadata.xml")
sa, err := ac.NewSAML2ACS(metadata, "test", "http://www.examle.org/saml/acs", "my-sp-entity-id", []string{"memberOf"})
if err != nil || sa == nil {
t.Fatal("Expected a saml acs object")
}
Expand Down Expand Up @@ -124,7 +135,8 @@ func Test_SAML2ACS_ValidateAssertionInfo(t *testing.T) {
}

func Test_SAML2ACS_GetAssertionData(t *testing.T) {
sa, err := ac.NewSAML2ACS("testdata/idp-metadata.xml", "test", "http://www.examle.org/saml/acs", "my-sp-entity-id", []string{"memberOf"})
metadata, err := reader.ReadFromAttrFile("saml2", "", "testdata/idp-metadata.xml")
sa, err := ac.NewSAML2ACS(metadata, "test", "http://www.examle.org/saml/acs", "my-sp-entity-id", []string{"memberOf"})
if err != nil || sa == nil {
t.Fatal("Expected a saml acs object")
}
Expand Down
3 changes: 3 additions & 0 deletions config/ac_saml.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ type SAML struct {
Remain hcl.Body `hcl:",remain"`
SpAcsUrl string `hcl:"sp_acs_url"`
SpEntityId string `hcl:"sp_entity_id"`

// internally used
MetadataBytes []byte
}

// HCLBody implements the <Body> interface.
Expand Down
17 changes: 17 additions & 0 deletions config/configload/load.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
"github.com/avenga/couper/config"
hclbody "github.com/avenga/couper/config/body"
"github.com/avenga/couper/config/parser"
"github.com/avenga/couper/config/reader"
"github.com/avenga/couper/errors"
"github.com/avenga/couper/eval"
)
Expand Down Expand Up @@ -217,6 +218,22 @@ func LoadConfig(body hcl.Body, src []byte, filename string) (*config.Couper, err
}

// Prepare dynamic functions
for _, profile := range couperConfig.Definitions.JWTSigningProfile {
key, err := reader.ReadFromAttrFile("jwt_signing_profile key", profile.Key, profile.KeyFile)
if err != nil {
return nil, errors.Configuration.Label(profile.Name).With(err)
}
profile.KeyBytes = key
}

for _, saml := range couperConfig.Definitions.SAML {
metadata, err := reader.ReadFromFile("saml2 idp_metadata_file", saml.IdpMetadataFile)
if err != nil {
return nil, errors.Configuration.Label(saml.Name).With(err)
}
saml.MetadataBytes = metadata
}

couperConfig.Context = evalContext.
WithJWTProfiles(couperConfig.Definitions.JWTSigningProfile).
WithSAML(couperConfig.Definitions.SAML)
Expand Down
3 changes: 3 additions & 0 deletions config/jwt_signing_profile.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,7 @@ type JWTSigningProfile struct {
Name string `hcl:"name,label"`
SignatureAlgorithm string `hcl:"signature_algorithm"`
TTL string `hcl:"ttl"`

// internally used
KeyBytes []byte
}
Loading