diff --git a/CHANGELOG.md b/CHANGELOG.md index cb5511798..5ba76b395 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,7 @@ Unreleased changes are available as `avenga/couper:edge` container. * **Fixed** * Missing support for `set_response_status` within a plain `error_handler` block ([#257](https://github.com/avenga/couper/pull/257)) + * Panic in jwt_sign() and saml_sso_url() functions without proper configuration ([#243](https://github.com/avenga/couper/issues/243)) --- diff --git a/accesscontrol/jwt.go b/accesscontrol/jwt.go index cc1f79706..a7861b6bb 100644 --- a/accesscontrol/jwt.go +++ b/accesscontrol/jwt.go @@ -6,9 +6,7 @@ import ( "crypto/x509" "encoding/pem" "fmt" - "io/ioutil" "net/http" - "path/filepath" "strings" "time" @@ -52,8 +50,7 @@ type JWTOptions struct { ClaimsRequired []string Name string // TODO: more generic (validate) Source JWTSource - Key string - KeyFile string + Key []byte } func NewJWTSource(cookie, header string) JWTSource { @@ -88,23 +85,6 @@ func NewJWT(options *JWTOptions) (*JWT, error) { source: options.Source, } - if options.Key != "" && options.KeyFile != "" { - return nil, confErr.Message("key and keyFile provided") - } - - key := []byte(options.Key) - if options.KeyFile != "" { - k, err := readKeyFile(options.KeyFile) - if err != nil { - return nil, confErr.With(err) - } - key = k - } - - if len(key) == 0 { - return nil, confErr.Message("key required") - } - if jwtAC.source.Type == Invalid { return nil, confErr.Message("token source is invalid") } @@ -120,11 +100,11 @@ func NewJWT(options *JWTOptions) (*JWT, error) { jwtAC.parser = parser if jwtAC.algorithm.IsHMAC() { - jwtAC.hmacSecret = key + jwtAC.hmacSecret = options.Key return jwtAC, nil } - pubKey, err := parsePublicPEMKey(key) + pubKey, err := parsePublicPEMKey(options.Key) if err != nil { return nil, confErr.With(err) } @@ -281,8 +261,8 @@ func parsePublicPEMKey(key []byte) (pub *rsa.PublicKey, err error) { } pubKey, pubErr := x509.ParsePKCS1PublicKey(pemBlock.Bytes) if pubErr != nil { - pkixKey, err := x509.ParsePKIXPublicKey(pemBlock.Bytes) - if err != nil { + pkixKey, pkerr := x509.ParsePKIXPublicKey(pemBlock.Bytes) + if pkerr != nil { cert, cerr := x509.ParseCertificate(pemBlock.Bytes) if cerr != nil { return nil, jwt.ErrNotRSAPublicKey @@ -301,17 +281,6 @@ func parsePublicPEMKey(key []byte) (pub *rsa.PublicKey, err error) { return pubKey, nil } -func readKeyFile(filePath string) ([]byte, error) { - if filePath != "" { - p, err := filepath.Abs(filePath) - if err != nil { - return nil, err - } - return ioutil.ReadFile(p) - } - return nil, nil -} - func isStringType(val interface{}) error { switch val.(type) { case string: diff --git a/accesscontrol/jwt_test.go b/accesscontrol/jwt_test.go index ba0c7e60f..eddec7570 100644 --- a/accesscontrol/jwt_test.go +++ b/accesscontrol/jwt_test.go @@ -8,14 +8,15 @@ import ( "fmt" "net/http" "net/http/httptest" + "strings" "testing" - "github.com/avenga/couper/errors" - "github.com/dgrijalva/jwt-go/v4" ac "github.com/avenga/couper/accesscontrol" + "github.com/avenga/couper/config/reader" "github.com/avenga/couper/config/request" + "github.com/avenga/couper/errors" "github.com/avenga/couper/internal/test" ) @@ -27,6 +28,7 @@ func Test_JWT_NewJWT_RSA(t *testing.T) { claims map[string]interface{} claimsRequired []string pubKey []byte + pubKeyPath string } privKey, err := rsa.GenerateKey(rand.Reader, 2048) @@ -74,7 +76,8 @@ QolLGgj3tz4NbDEitq+zKMr0uTHvP1Vyu1mXAflcpYcJA4ZmuB3Oj39e0U0gnmr/ fields fields wantErr string }{ - {"missing key", fields{}, "configuration error: test_ac: key required"}, + {"missing key-file path", fields{}, "configuration error: jwt key: read error: required: configured attribute or file"}, + {"missing key-file", fields{pubKeyPath: "./not-there.file"}, "not-there.file: no such file or directory"}, {"PKIX", fields{ algorithm: alg, pubKey: pubKeyBytesPKIX, @@ -94,27 +97,36 @@ QolLGgj3tz4NbDEitq+zKMr0uTHvP1Vyu1mXAflcpYcJA4ZmuB3Oj39e0U0gnmr/ } for _, tt := range tests { - t.Run(fmt.Sprintf("%v / %s", signingMethod, tt.name), func(t *testing.T) { + t.Run(fmt.Sprintf("%v / %s", signingMethod, tt.name), func(subT *testing.T) { + key, rerr := reader.ReadFromAttrFile("jwt key", string(tt.fields.pubKey), tt.fields.pubKeyPath) + if rerr != nil { + logErr := rerr.(errors.GoError) + if tt.wantErr != "" && !strings.HasSuffix(logErr.LogError(), tt.wantErr) { + subT.Errorf("\nWant:\t%q\nGot:\t%q", tt.wantErr, logErr.LogError()) + } else if tt.wantErr == "" { + subT.Fatal(logErr.LogError()) + } + return + } + j, jerr := ac.NewJWT(&ac.JWTOptions{ Algorithm: tt.fields.algorithm, Claims: tt.fields.claims, ClaimsRequired: tt.fields.claimsRequired, Name: "test_ac", - Key: string(tt.fields.pubKey), + Key: key, Source: ac.NewJWTSource("", "Authorization"), }) if jerr != nil { gerr := jerr.(errors.GoError) if tt.wantErr != gerr.LogError() { - t.Errorf("error: %v, want: %v", gerr.LogError(), tt.wantErr) + subT.Errorf("error: %v, want: %v", gerr.LogError(), tt.wantErr) } } else if tt.wantErr != "" { - t.Errorf("error expected: %v", tt.wantErr) + subT.Errorf("error expected: %v", tt.wantErr) } - if tt.wantErr == "" { - if j == nil { - t.Errorf("JWT struct expected") - } + if tt.wantErr == "" && j == nil { + subT.Errorf("JWT struct expected") } }) } @@ -225,7 +237,7 @@ func Test_JWT_Validate(t *testing.T) { ClaimsRequired: tt.fields.claimsRequired, Name: "test_ac", Source: tt.fields.source, - Key: string(tt.fields.pubKey), + Key: tt.fields.pubKey, }) if err != nil { t.Error(err) diff --git a/accesscontrol/saml2.go b/accesscontrol/saml2.go index bb0300883..215c645f2 100644 --- a/accesscontrol/saml2.go +++ b/accesscontrol/saml2.go @@ -6,9 +6,7 @@ import ( "encoding/base64" "encoding/xml" "fmt" - "io/ioutil" "net/http" - "path/filepath" "sort" saml2 "github.com/russellhaering/gosaml2" @@ -25,20 +23,9 @@ type Saml2 struct { sp *saml2.SAMLServiceProvider } -func NewSAML2ACS(metadataFile string, name string, acsUrl string, spEntityId string, arrayAttributes []string) (*Saml2, error) { - p, err := filepath.Abs(metadataFile) - if err != nil { - return nil, err - } - - rawMetadata, err := ioutil.ReadFile(p) - if err != nil { - return nil, err - } - - metadata := &types.EntityDescriptor{} - err = xml.Unmarshal(rawMetadata, metadata) - if err != nil { +func NewSAML2ACS(metadata []byte, name string, acsUrl string, spEntityId string, arrayAttributes []string) (*Saml2, error) { + metadataEntity := &types.EntityDescriptor{} + if err := xml.Unmarshal(metadata, metadataEntity); err != nil { return nil, err } @@ -46,7 +33,7 @@ func NewSAML2ACS(metadataFile string, name string, acsUrl string, spEntityId str Roots: []*x509.Certificate{}, } - for _, kd := range metadata.IDPSSODescriptor.KeyDescriptors { + for _, kd := range metadataEntity.IDPSSODescriptor.KeyDescriptors { for idx, xcert := range kd.KeyInfo.X509Data.X509Certificates { if xcert.Data == "" { return nil, fmt.Errorf("metadata certificate(%d) must not be empty", idx) @@ -69,7 +56,7 @@ func NewSAML2ACS(metadataFile string, name string, acsUrl string, spEntityId str AssertionConsumerServiceURL: acsUrl, AudienceURI: spEntityId, IDPCertificateStore: &certStore, - IdentityProviderIssuer: metadata.EntityID, + IdentityProviderIssuer: metadataEntity.EntityID, } if arrayAttributes != nil { sort.Strings(arrayAttributes) diff --git a/accesscontrol/saml2_test.go b/accesscontrol/saml2_test.go index ed247f760..4587f3ed6 100644 --- a/accesscontrol/saml2_test.go +++ b/accesscontrol/saml2_test.go @@ -15,36 +15,46 @@ import ( "github.com/russellhaering/gosaml2/types" ac "github.com/avenga/couper/accesscontrol" + "github.com/avenga/couper/config/reader" + "github.com/avenga/couper/errors" + "github.com/avenga/couper/internal/test" ) func Test_NewSAML2ACS(t *testing.T) { + helper := test.New(t) + type testCase struct { metadataFile, acsUrl, spEntityId string arrayAttributes []string expErrMsg string shouldFail bool } + for _, tc := range []testCase{ {"testdata/idp-metadata.xml", "http://www.examle.org/saml/acs", "my-sp-entity-id", []string{}, "", false}, {"not-there.xml", "http://www.examle.org/saml/acs", "my-sp-entity-id", []string{}, "not-there.xml: no such file or directory", true}, } { - sa, err := ac.NewSAML2ACS(tc.metadataFile, "test", tc.acsUrl, tc.spEntityId, tc.arrayAttributes) - if tc.shouldFail && sa != nil { - t.Error("Expected no successful saml acs creation") - } - - if tc.shouldFail && err != nil && tc.expErrMsg != "" { - if !strings.HasSuffix(err.Error(), tc.expErrMsg) { - t.Errorf("Expected error message suffix: %q, got: %q", tc.expErrMsg, err.Error()) + metadata, err := reader.ReadFromAttrFile("saml2", "", tc.metadataFile) + if err != nil { + readErr := err.(errors.GoError) + if tc.shouldFail { + if !strings.HasSuffix(readErr.LogError(), tc.expErrMsg) { + t.Errorf("Want: %q, got: %q", tc.expErrMsg, readErr.LogError()) + } + continue } - } else if err != nil { t.Error(err) + continue } + + _, err = ac.NewSAML2ACS(metadata, "test", tc.acsUrl, tc.spEntityId, tc.arrayAttributes) + helper.Must(err) } } func Test_SAML2ACS_Validate(t *testing.T) { - sa, err := ac.NewSAML2ACS("testdata/idp-metadata.xml", "test", "http://www.examle.org/saml/acs", "my-sp-entity-id", []string{"memberOf"}) + metadata, err := reader.ReadFromAttrFile("saml2", "", "testdata/idp-metadata.xml") + sa, err := ac.NewSAML2ACS(metadata, "test", "http://www.examle.org/saml/acs", "my-sp-entity-id", []string{"memberOf"}) if err != nil || sa == nil { t.Fatal("Expected a saml acs object") } @@ -87,7 +97,8 @@ func Test_SAML2ACS_Validate(t *testing.T) { } func Test_SAML2ACS_ValidateAssertionInfo(t *testing.T) { - sa, err := ac.NewSAML2ACS("testdata/idp-metadata.xml", "test", "http://www.examle.org/saml/acs", "my-sp-entity-id", []string{"memberOf"}) + metadata, err := reader.ReadFromAttrFile("saml2", "", "testdata/idp-metadata.xml") + sa, err := ac.NewSAML2ACS(metadata, "test", "http://www.examle.org/saml/acs", "my-sp-entity-id", []string{"memberOf"}) if err != nil || sa == nil { t.Fatal("Expected a saml acs object") } @@ -124,7 +135,8 @@ func Test_SAML2ACS_ValidateAssertionInfo(t *testing.T) { } func Test_SAML2ACS_GetAssertionData(t *testing.T) { - sa, err := ac.NewSAML2ACS("testdata/idp-metadata.xml", "test", "http://www.examle.org/saml/acs", "my-sp-entity-id", []string{"memberOf"}) + metadata, err := reader.ReadFromAttrFile("saml2", "", "testdata/idp-metadata.xml") + sa, err := ac.NewSAML2ACS(metadata, "test", "http://www.examle.org/saml/acs", "my-sp-entity-id", []string{"memberOf"}) if err != nil || sa == nil { t.Fatal("Expected a saml acs object") } diff --git a/config/ac_saml.go b/config/ac_saml.go index 65e09ead8..1aa0aa5c9 100644 --- a/config/ac_saml.go +++ b/config/ac_saml.go @@ -13,6 +13,9 @@ type SAML struct { Remain hcl.Body `hcl:",remain"` SpAcsUrl string `hcl:"sp_acs_url"` SpEntityId string `hcl:"sp_entity_id"` + + // internally used + MetadataBytes []byte } // HCLBody implements the interface. diff --git a/config/configload/load.go b/config/configload/load.go index 5070da051..67bfc34ac 100644 --- a/config/configload/load.go +++ b/config/configload/load.go @@ -17,6 +17,7 @@ import ( "github.com/avenga/couper/config" hclbody "github.com/avenga/couper/config/body" "github.com/avenga/couper/config/parser" + "github.com/avenga/couper/config/reader" "github.com/avenga/couper/errors" "github.com/avenga/couper/eval" ) @@ -217,6 +218,22 @@ func LoadConfig(body hcl.Body, src []byte, filename string) (*config.Couper, err } // Prepare dynamic functions + for _, profile := range couperConfig.Definitions.JWTSigningProfile { + key, err := reader.ReadFromAttrFile("jwt_signing_profile key", profile.Key, profile.KeyFile) + if err != nil { + return nil, errors.Configuration.Label(profile.Name).With(err) + } + profile.KeyBytes = key + } + + for _, saml := range couperConfig.Definitions.SAML { + metadata, err := reader.ReadFromFile("saml2 idp_metadata_file", saml.IdpMetadataFile) + if err != nil { + return nil, errors.Configuration.Label(saml.Name).With(err) + } + saml.MetadataBytes = metadata + } + couperConfig.Context = evalContext. WithJWTProfiles(couperConfig.Definitions.JWTSigningProfile). WithSAML(couperConfig.Definitions.SAML) diff --git a/config/jwt_signing_profile.go b/config/jwt_signing_profile.go index ca6bb8a31..f4d1d9fff 100644 --- a/config/jwt_signing_profile.go +++ b/config/jwt_signing_profile.go @@ -7,4 +7,7 @@ type JWTSigningProfile struct { Name string `hcl:"name,label"` SignatureAlgorithm string `hcl:"signature_algorithm"` TTL string `hcl:"ttl"` + + // internally used + KeyBytes []byte } diff --git a/config/reader/file.go b/config/reader/file.go new file mode 100644 index 000000000..9890ff3df --- /dev/null +++ b/config/reader/file.go @@ -0,0 +1,43 @@ +package reader + +import ( + "io/ioutil" + "path/filepath" + + "github.com/avenga/couper/errors" +) + +func ReadFromAttrFile(context, attribute, path string) ([]byte, error) { + readErr := errors.Configuration.Label(context + ": read error") + if attribute != "" && path != "" { + return nil, readErr.Message("configured attribute and file") + } else if attribute == "" && path == "" { + return nil, readErr.Message("required: configured attribute or file") + } + + if path != "" { + return ReadFromFile(context, path) + } + + return []byte(attribute), nil +} + +func ReadFromFile(context, path string) ([]byte, error) { + readErr := errors.Configuration.Label(context + ": read error") + if path == "" { + return nil, readErr.Message("required: configured file") + } + + absPath, err := filepath.Abs(path) + if err != nil { + return nil, readErr.With(err) + } + b, err := ioutil.ReadFile(absPath) + if err != nil { + return nil, readErr.With(err) + } + if len(b) == 0 { + return nil, readErr.Message("empty file") + } + return b, nil +} diff --git a/config/reader/file_test.go b/config/reader/file_test.go new file mode 100644 index 000000000..01361a15d --- /dev/null +++ b/config/reader/file_test.go @@ -0,0 +1,48 @@ +package reader_test + +import ( + "io/ioutil" + "reflect" + "runtime" + "testing" + + "github.com/avenga/couper/config/reader" +) + +func TestReadFromAttrFile(t *testing.T) { + _, file, _, _ := runtime.Caller(0) + expBytes, ferr := ioutil.ReadFile(file) + if ferr != nil { + t.Fatal(ferr) + } + + type args struct { + context string + attribute string + path string + } + tests := []struct { + name string + args args + want []byte + wantErr bool + }{ + {"not configured", args{context: "testcase"}, nil, true}, + {"both configured", args{context: "testcase", attribute: "", path: ""}, nil, true}, + {"attr configured", args{context: "testcase", attribute: "key", path: ""}, []byte("key"), false}, + {"path configured", args{context: "testcase", attribute: "", path: file}, expBytes, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := reader.ReadFromAttrFile(tt.args.context, tt.args.attribute, tt.args.path) + if (err != nil) != tt.wantErr { + t.Errorf("ReadFromAttrFile() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("ReadFromAttrFile() got = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/config/runtime/server.go b/config/runtime/server.go index 5b63d06c8..4d76051af 100644 --- a/config/runtime/server.go +++ b/config/runtime/server.go @@ -21,6 +21,7 @@ import ( ac "github.com/avenga/couper/accesscontrol" "github.com/avenga/couper/cache" "github.com/avenga/couper/config" + "github.com/avenga/couper/config/reader" "github.com/avenga/couper/config/runtime/server" "github.com/avenga/couper/errors" "github.com/avenga/couper/eval" @@ -84,8 +85,7 @@ func GetHostPort(hostPort string) (string, int, error) { // NewServerConfiguration sets http handler specific defaults and validates the given gateway configuration. // Wire up all endpoints and maps them within the returned Server. -func NewServerConfiguration(conf *config.Couper, log *logrus.Entry, memStore *cache.MemoryStore, -) (ServerConfiguration, error) { +func NewServerConfiguration(conf *config.Couper, log *logrus.Entry, memStore *cache.MemoryStore) (ServerConfiguration, error) { // confCtx is created to evaluate request / response related configuration errors on start. noopReq, _ := http.NewRequest(http.MethodGet, "https://couper.io", nil) noopResp := httptest.NewRecorder().Result() @@ -413,6 +413,11 @@ func configureAccessControls(conf *config.Couper, confCtx *hcl.EvalContext) (ACD } for _, jwtConf := range conf.Definitions.JWT { + key, err := reader.ReadFromAttrFile("jwt key", jwtConf.Key, jwtConf.KeyFile) + if err != nil { + return nil, errors.Configuration.Label(jwtConf.Name).With(err) + } + var claims map[string]interface{} if jwtConf.Claims != nil { // TODO: dynamic expr eval ? c, diags := seetie.ExpToMap(confCtx, jwtConf.Claims) @@ -425,8 +430,7 @@ func configureAccessControls(conf *config.Couper, confCtx *hcl.EvalContext) (ACD Algorithm: jwtConf.SignatureAlgorithm, Claims: claims, ClaimsRequired: jwtConf.ClaimsRequired, - Key: jwtConf.Key, - KeyFile: jwtConf.KeyFile, + Key: key, Name: jwtConf.Name, Source: ac.NewJWTSource(jwtConf.Cookie, jwtConf.Header), }) @@ -440,7 +444,11 @@ func configureAccessControls(conf *config.Couper, confCtx *hcl.EvalContext) (ACD } for _, saml := range conf.Definitions.SAML { - s, err := ac.NewSAML2ACS(saml.IdpMetadataFile, saml.Name, saml.SpAcsUrl, saml.SpEntityId, saml.ArrayAttributes) + metadata, err := reader.ReadFromFile("saml2 idp_metadata_file", saml.IdpMetadataFile) + if err != nil { + return nil, errors.Configuration.Label(saml.Name).With(err) + } + s, err := ac.NewSAML2ACS(metadata, saml.Name, saml.SpAcsUrl, saml.SpEntityId, saml.ArrayAttributes) if err != nil { return nil, fmt.Errorf("loading saml definition failed: %s", err) } diff --git a/eval/context.go b/eval/context.go index e49181fad..68cfaf67b 100644 --- a/eval/context.go +++ b/eval/context.go @@ -204,14 +204,11 @@ func (c *Context) HCLContext() *hcl.EvalContext { // updateFunctions recreates the listed functions with latest evaluation context. func updateFunctions(ctx *Context) { - if len(ctx.profiles) > 0 { - jwtfn := lib.NewJwtSignFunction(ctx.profiles, ctx.eval) - ctx.eval.Functions[lib.FnJWTSign] = jwtfn - } - if len(ctx.saml) > 0 { - samlfn := lib.NewSamlSsoUrlFunction(ctx.saml) - ctx.eval.Functions[lib.FnSamlSsoUrl] = samlfn - } + jwtfn := lib.NewJwtSignFunction(ctx.profiles, ctx.eval) + ctx.eval.Functions[lib.FnJWTSign] = jwtfn + + samlfn := lib.NewSamlSsoUrlFunction(ctx.saml) + ctx.eval.Functions[lib.FnSamlSsoUrl] = samlfn } const defaultMaxMemory = 32 << 20 // 32 MB diff --git a/eval/lib/jwt.go b/eval/lib/jwt.go index 27af9f971..4d3b67aba 100644 --- a/eval/lib/jwt.go +++ b/eval/lib/jwt.go @@ -1,10 +1,9 @@ package lib import ( + "crypto/rsa" "encoding/json" - "errors" - "io/ioutil" - "path/filepath" + "fmt" "strings" "time" @@ -20,25 +19,24 @@ import ( const FnJWTSign = "jwt_sign" -var ( - ErrorNoProfileForLabel = errors.New("no signing profile for label") - ErrorMissingKey = errors.New("either key_file or key must be specified") - ErrorUnsupportedSigningMethod = errors.New("unsupported signing method") -) - -type JwtSigningError struct { - error -} - -func (e *JwtSigningError) Error() string { - return e.error.Error() -} +var rsaParseError = &rsa.PrivateKey{} func NewJwtSignFunction(jwtSigningProfiles []*config.JWTSigningProfile, confCtx *hcl.EvalContext) function.Function { signingProfiles := make(map[string]*config.JWTSigningProfile) + rsaKeys := make(map[string]*rsa.PrivateKey) + for _, sp := range jwtSigningProfiles { signingProfiles[sp.Name] = sp + if strings.HasPrefix(sp.SignatureAlgorithm, "RS") { + key, err := jwt.ParseRSAPrivateKeyFromPEM(sp.KeyBytes) + if err != nil { + rsaKeys[sp.Name] = rsaParseError + continue + } + rsaKeys[sp.Name] = key + } } + return function.New(&function.Spec{ Params: []function.Parameter{ { @@ -52,29 +50,14 @@ func NewJwtSignFunction(jwtSigningProfiles []*config.JWTSigningProfile, confCtx }, Type: function.StaticReturnType(cty.String), Impl: func(args []cty.Value, _ cty.Type) (ret cty.Value, err error) { + if len(signingProfiles) == 0 { + return cty.StringVal(""), fmt.Errorf("missing jwt_signing_profile definitions") + } + label := args[0].AsString() signingProfile := signingProfiles[label] if signingProfile == nil { - return cty.StringVal(""), &JwtSigningError{error: ErrorNoProfileForLabel} - } - - // get key or secret - var keyData []byte - if signingProfile.KeyFile != "" { - p, err := filepath.Abs(signingProfile.KeyFile) - if err != nil { - return cty.StringVal(""), err - } - content, err := ioutil.ReadFile(p) - if err != nil { - return cty.StringVal(""), err - } - keyData = content - } else if signingProfile.Key != "" { - keyData = []byte(signingProfile.Key) - } - if len(keyData) == 0 { - return cty.StringVal(""), &JwtSigningError{error: ErrorMissingKey} + return cty.StringVal(""), fmt.Errorf("missing jwt_signing_profile for given label: %s", label) } mapClaims := jwt.MapClaims{} @@ -93,9 +76,9 @@ func NewJwtSignFunction(jwtSigningProfiles []*config.JWTSigningProfile, confCtx mapClaims[k] = v } if signingProfile.TTL != "0" { - ttl, err := time.ParseDuration(signingProfile.TTL) - if err != nil { - return cty.StringVal(""), err + ttl, parseErr := time.ParseDuration(signingProfile.TTL) + if parseErr != nil { + return cty.StringVal(""), parseErr } mapClaims["exp"] = time.Now().Unix() + int64(ttl.Seconds()) } @@ -118,19 +101,19 @@ func NewJwtSignFunction(jwtSigningProfiles []*config.JWTSigningProfile, confCtx // create token signingMethod := jwt.GetSigningMethod(signingProfile.SignatureAlgorithm) if signingMethod == nil { - return cty.StringVal(""), &JwtSigningError{error: ErrorUnsupportedSigningMethod} + return cty.StringVal(""), fmt.Errorf("no signing method for given algorithm: %s", signingProfile.SignatureAlgorithm) } token := jwt.NewWithClaims(signingMethod, mapClaims) var key interface{} - if strings.HasPrefix(signingProfile.SignatureAlgorithm, "RS") { - key, err = jwt.ParseRSAPrivateKeyFromPEM(keyData) - if err != nil { - return cty.StringVal(""), err + if rsaKey, exist := rsaKeys[signingProfile.Name]; exist { + if rsaKey == rsaParseError { + return cty.StringVal(""), fmt.Errorf("could not parse rsa private key from pem: %s", signingProfile.Name) } + key = rsaKey } else { - key = keyData + key = signingProfile.KeyBytes } // sign token diff --git a/eval/lib/jwt_test.go b/eval/lib/jwt_test.go index 4af3ebf92..e191979ea 100644 --- a/eval/lib/jwt_test.go +++ b/eval/lib/jwt_test.go @@ -477,55 +477,46 @@ func TestJwtSignError(t *testing.T) { wantErr string }{ { - "No profile for label", + "missing jwt_signing_profile definitions", ` server "test" { - } - definitions { - jwt_signing_profile "MyToken" { - signature_algorithm = "HS256" - key = "$3cRe4" - ttl = "0" - claims = { - iss = to_lower("The_Issuer") - aud = to_upper("The_Audience") + endpoint "/" { + response { + body = jwt_sign() } } } `, - "NoProfileForThisLabel", - `{"sub":"12345"}`, - "no signing profile for label", + "MyToken", + `{"sub": "12345"}`, + "missing jwt_signing_profile definitions", }, { - "Missing file for key_file", + "invalid PEM key format", ` server "test" { } definitions { jwt_signing_profile "MyToken" { - signature_algorithm = "HS256" - key_file = "not_there.txt" - ttl = "0" - claims = { - iss = to_lower("The_Issuer") - aud = to_upper("The_Audience") - } + signature_algorithm = "RS256" + key = "invalid" + ttl = 0 } } `, "MyToken", - `{"sub":"12345"}`, - "no such file or directory", + `{"sub": "12345"}`, + "could not parse rsa private key from pem: MyToken", }, { - "Missing key and key_file", + "No profile for label", ` server "test" { } definitions { jwt_signing_profile "MyToken" { signature_algorithm = "HS256" + key = "$3cRe4" ttl = "0" claims = { iss = to_lower("The_Issuer") @@ -534,9 +525,9 @@ func TestJwtSignError(t *testing.T) { } } `, - "MyToken", + "NoProfileForThisLabel", `{"sub":"12345"}`, - "either key_file or key must be specified", + "missing jwt_signing_profile for given label: NoProfileForThisLabel", }, { "Invalid ttl value", @@ -587,29 +578,27 @@ func TestJwtSignError(t *testing.T) { `, "MyToken", `{"sub": "12345"}`, - "unsupported signing method", + "no signing method for given algorithm: invalid", }, } for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { + t.Run(tt.name, func(st *testing.T) { + helper := test.New(st) cf, err := configload.LoadBytes([]byte(tt.hcl), "couper.hcl") - if err != nil { - t.Fatal(err) - } + helper.Must(err) claims, err := stdlib.JSONDecode(cty.StringVal(tt.claims)) - if err != nil { - t.Fatal(err) - } + helper.Must(err) hclContext := cf.Context.Value(eval.ContextType).(*eval.Context).HCLContext() _, err = hclContext.Functions[lib.FnJWTSign].Call([]cty.Value{cty.StringVal(tt.jspLabel), claims}) if err == nil { - t.Fatal(err) + t.Error("expected an error, got nothing") + return } if !strings.Contains(err.Error(), tt.wantErr) { - t.Errorf("Expected %q, got: %#v", tt.wantErr, err.Error()) + t.Errorf("Want:\t%q\nGot:\t%q", tt.wantErr, err.Error()) } }) } diff --git a/eval/lib/saml.go b/eval/lib/saml.go index e3544af05..b8dba22b9 100644 --- a/eval/lib/saml.go +++ b/eval/lib/saml.go @@ -2,8 +2,7 @@ package lib import ( "encoding/xml" - "io/ioutil" - "path/filepath" + "fmt" saml2 "github.com/russellhaering/gosaml2" "github.com/russellhaering/gosaml2/types" @@ -18,11 +17,24 @@ const ( NameIdFormatUnspecified = "urn:oasis:names:tc:SAML:1.1:nameid-format:unspecified" ) -func NewSamlSsoUrlFunction(samlConfigs []*config.SAML) function.Function { - samls := make(map[string]*config.SAML) - for _, s := range samlConfigs { - samls[s.Name] = s +func NewSamlSsoUrlFunction(configs []*config.SAML) function.Function { + type entity struct { + config *config.SAML + descriptor *types.EntityDescriptor + err error } + + samlEntities := make(map[string]*entity) + for _, conf := range configs { + metadata := &types.EntityDescriptor{} + err := xml.Unmarshal(conf.MetadataBytes, metadata) + samlEntities[conf.Name] = &entity{ + config: conf, + descriptor: metadata, + err: err, + } + } + return function.New(&function.Spec{ Params: []function.Parameter{ { @@ -32,24 +44,20 @@ func NewSamlSsoUrlFunction(samlConfigs []*config.SAML) function.Function { }, Type: function.StaticReturnType(cty.String), Impl: func(args []cty.Value, _ cty.Type) (ret cty.Value, err error) { - label := args[0].AsString() - saml := samls[label] - p, err := filepath.Abs(saml.IdpMetadataFile) - if err != nil { - return cty.StringVal(""), err + if len(samlEntities) == 0 { + return cty.StringVal(""), fmt.Errorf("missing saml2 definitions") } - rawMetadata, err := ioutil.ReadFile(p) - if err != nil { - return cty.StringVal(""), err + if len(args) == 0 { + return cty.StringVal(""), fmt.Errorf("missing saml2 definition reference") } - metadata := &types.EntityDescriptor{} - err = xml.Unmarshal(rawMetadata, metadata) - if err != nil { - return cty.StringVal(""), err + ent := samlEntities[args[0].AsString()] + if ent.err != nil { + return cty.StringVal(""), ent.err } + metadata := ent.descriptor var ssoUrl string for _, ssoService := range metadata.IDPSSODescriptor.SingleSignOnServices { if ssoService.Binding == "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect" { @@ -61,9 +69,9 @@ func NewSamlSsoUrlFunction(samlConfigs []*config.SAML) function.Function { nameIDFormat := getNameIDFormat(metadata.IDPSSODescriptor.NameIDFormats) sp := &saml2.SAMLServiceProvider{ - AssertionConsumerServiceURL: saml.SpAcsUrl, + AssertionConsumerServiceURL: ent.config.SpAcsUrl, IdentityProviderSSOURL: ssoUrl, - ServiceProviderIssuer: saml.SpEntityId, + ServiceProviderIssuer: ent.config.SpEntityId, SignAuthnRequests: false, } if nameIDFormat != "" { diff --git a/eval/lib/saml_test.go b/eval/lib/saml_test.go index 595078027..3509182e9 100644 --- a/eval/lib/saml_test.go +++ b/eval/lib/saml_test.go @@ -10,12 +10,12 @@ import ( "strings" "testing" - "github.com/avenga/couper/eval" - "github.com/zclconf/go-cty/cty" "github.com/avenga/couper/config/configload" + "github.com/avenga/couper/eval" "github.com/avenga/couper/eval/lib" + "github.com/avenga/couper/internal/test" ) func Test_SamlSsoUrl(t *testing.T) { @@ -82,57 +82,53 @@ func Test_SamlSsoUrl(t *testing.T) { }, } for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { + t.Run(tt.name, func(st *testing.T) { + h := test.New(st) cf, err := configload.LoadBytes([]byte(tt.hcl), "couper.hcl") if err != nil { - t.Fatal(err) + if tt.wantErr { + return + } + h.Must(err) } hclContext := cf.Context.Value(eval.ContextType).(*eval.Context).HCLContext() ssoUrl, err := hclContext.Functions[lib.FnSamlSsoUrl].Call([]cty.Value{cty.StringVal(tt.samlLabel)}) if err == nil && tt.wantErr { - t.Fatal("Error expected") + st.Fatal("Error expected") } if err != nil { if !tt.wantErr { - t.Fatal(err) + h.Must(err) } else { return } } if !strings.HasPrefix(ssoUrl.AsString(), tt.wantPfx) { - t.Errorf("Expected to start with %q, got: %#v", tt.wantPfx, ssoUrl.AsString()) + st.Errorf("Expected to start with %q, got: %#v", tt.wantPfx, ssoUrl.AsString()) } u, err := url.Parse(ssoUrl.AsString()) - if err != nil { - t.Fatal(err) - } + h.Must(err) q := u.Query() samlRequest := q.Get("SAMLRequest") if samlRequest == "" { - t.Fatal("Expected SAMLRequest query param") + st.Fatal("Expected SAMLRequest query param") } b64Decoded, err := base64.StdEncoding.DecodeString(samlRequest) - if err != nil { - t.Fatal(err) - } + h.Must(err) fr := flate.NewReader(bytes.NewReader(b64Decoded)) deflated, err := ioutil.ReadAll(fr) - if err != nil { - t.Fatal(err) - } + h.Must(err) var x interface{} err = xml.Unmarshal(deflated, &x) - if err != nil { - t.Fatal(err) - } + h.Must(err) }) } diff --git a/internal/test/helper.go b/internal/test/helper.go index d8a103f80..146a0915d 100644 --- a/internal/test/helper.go +++ b/internal/test/helper.go @@ -1,6 +1,10 @@ package test -import "testing" +import ( + "testing" + + "github.com/avenga/couper/errors" +) type Helper struct { tb testing.TB @@ -13,6 +17,10 @@ func New(tb testing.TB) *Helper { func (h *Helper) Must(err error) { h.tb.Helper() if err != nil { + if logErr, ok := err.(errors.GoError); ok { + h.tb.Fatal(logErr.LogError()) + return + } h.tb.Fatal(err) } }