diff --git a/README.md b/README.md index d6000d975..74f0d2998 100644 --- a/README.md +++ b/README.md @@ -53,6 +53,7 @@ The process goes something like this: * OneLogin * NetIQ * Browser, this uses [playwright-go](github.com/mxschmitt/playwright-go) to run a sandbox chromium window. + * [Auth0](pkg/provider/auth0/README.md) NOTE: Currently, MFA not supported * AWS SAML Provider configured ## Caveats diff --git a/cmd/saml2aws/main.go b/cmd/saml2aws/main.go index e497968fd..603654db1 100644 --- a/cmd/saml2aws/main.go +++ b/cmd/saml2aws/main.go @@ -69,7 +69,7 @@ func main() { commonFlags := new(flags.CommonFlags) app.Flag("config", "Path/filename of saml2aws config file (env: SAML2AWS_CONFIGFILE)").Envar("SAML2AWS_CONFIGFILE").StringVar(&commonFlags.ConfigFile) app.Flag("idp-account", "The name of the configured IDP account. (env: SAML2AWS_IDP_ACCOUNT)").Envar("SAML2AWS_IDP_ACCOUNT").Short('a').Default("default").StringVar(&commonFlags.IdpAccount) - app.Flag("idp-provider", "The configured IDP provider. (env: SAML2AWS_IDP_PROVIDER)").Envar("SAML2AWS_IDP_PROVIDER").EnumVar(&commonFlags.IdpProvider, "Akamai", "AzureAD", "ADFS", "ADFS2", "GoogleApps", "Ping", "JumpCloud", "Okta", "OneLogin", "PSU", "KeyCloak", "F5APM", "Shibboleth", "ShibbolethECP", "NetIQ") + app.Flag("idp-provider", "The configured IDP provider. (env: SAML2AWS_IDP_PROVIDER)").Envar("SAML2AWS_IDP_PROVIDER").EnumVar(&commonFlags.IdpProvider, "Akamai", "AzureAD", "ADFS", "ADFS2", "GoogleApps", "Ping", "JumpCloud", "Okta", "OneLogin", "PSU", "KeyCloak", "F5APM", "Shibboleth", "ShibbolethECP", "NetIQ", "Auth0") app.Flag("mfa", "The name of the mfa. (env: SAML2AWS_MFA)").Envar("SAML2AWS_MFA").StringVar(&commonFlags.MFA) app.Flag("skip-verify", "Skip verification of server certificate. (env: SAML2AWS_SKIP_VERIFY)").Envar("SAML2AWS_SKIP_VERIFY").Short('s').BoolVar(&commonFlags.SkipVerify) app.Flag("url", "The URL of the SAML IDP server used to login. (env: SAML2AWS_URL)").Envar("SAML2AWS_URL").StringVar(&commonFlags.URL) diff --git a/pkg/provider/auth0/README.md b/pkg/provider/auth0/README.md new file mode 100644 index 000000000..ced801b57 --- /dev/null +++ b/pkg/provider/auth0/README.md @@ -0,0 +1,33 @@ +## Auth0 Provider + +* https://auth0.com/ + +## Instructions + +You need the SAML policy ID for the AWS account and Auth0 issues URL like below: + +``` +https://.auth0.com/samlp/ +``` + +Example config: + +```ini +[default] +url = https://.auth0.com/samlp/ +username = +provider = Auth0 +skip_verify = false +timeout = 0 +aws_urn = urn:amazon:webservices +aws_session_duration = 3600 +aws_profile = +``` + +## Features + +* Currently, this provider does not support MFA. + +## More details + +* https://auth0.com/docs/protocols/saml-protocol/saml-configuration-options/configure-saml2-web-app-addon-for-aws diff --git a/pkg/provider/auth0/auth0.go b/pkg/provider/auth0/auth0.go new file mode 100644 index 000000000..61d7508b6 --- /dev/null +++ b/pkg/provider/auth0/auth0.go @@ -0,0 +1,423 @@ +package auth0 + +import ( + "bytes" + "encoding/base64" + "encoding/json" + "fmt" + "html" + "io/ioutil" + "net/http" + "net/url" + "regexp" + "strings" + + "github.com/PuerkitoBio/goquery" + "github.com/pkg/errors" + "github.com/tidwall/gjson" + "github.com/versent/saml2aws/v2/pkg/cfg" + "github.com/versent/saml2aws/v2/pkg/creds" + "github.com/versent/saml2aws/v2/pkg/prompter" + "github.com/versent/saml2aws/v2/pkg/provider" +) + +const ( + connectionInfoJSURLFmt = "https://cdn.auth0.com/client/%s.js" + authOriginURLFmt = "https://%s.auth0.com" + authSubmitURLFmt = "https://%s.auth0.com/usernamepassword/login" +) + +var ( + authURLPattern = regexp.MustCompile(`https://([^.]+)\.auth0\.com/samlp/(.+)`) + connectionInfoPattern = regexp.MustCompile(`Auth0\.setClient\((.*)\)`) + sessionInfoPattern = regexp.MustCompile(`window\.atob\('(.*)'\)`) + + defaultPrompter = prompter.NewCli() +) + +// Client wrapper around Auth0. +type Client struct { + provider.ValidateBase + client *provider.HTTPClient +} + +// authInfo represents Auth0 first auth request +type authInfo struct { + clientID string + tenant string + connection string + state string + csrf string + connectionInfoURLFmt string + authOriginURLFmt string + authSubmitURLFmt string +} + +// authRequest represents Auth0 request +type authRequest struct { + ClientID string `json:"client_id"` + Connection string `json:"connection"` + Password string `json:"password"` + PopupOptions interface{} `json:"popup_options"` + Protocol string `json:"protocol"` + RedirectURI string `json:"redirect_uri"` + ResponseType string `json:"response_type"` + Scope string `json:"scope"` + SSO bool `json:"sso"` + State string `json:"state"` + Tenant string `json:"tenant"` + Username string `json:"username"` + CSRF string `json:"_csrf"` + Intstate string `json:"_intstate"` +} + +// clientInfo represents Auth0 client information +type clientInfo struct { + id string + tenantName string +} + +// sessionInfo represents Auth0 session information +type sessionInfo struct { + state string + csrf string +} + +//authCallbackRequest represents Auth0 authentication callback request +type authCallbackRequest struct { + method string + url string + body string +} + +type authInfoOption func(*authInfo) + +func defaultAuthInfoOptions() authInfoOption { + return func(ai *authInfo) { + ai.connectionInfoURLFmt = connectionInfoJSURLFmt + ai.authOriginURLFmt = authOriginURLFmt + ai.authSubmitURLFmt = authSubmitURLFmt + } +} + +// New create a new Auth0 Client +func New(idpAccount *cfg.IDPAccount) (*Client, error) { + tr := provider.NewDefaultTransport(idpAccount.SkipVerify) + client, err := provider.NewHTTPClient(tr, provider.BuildHttpClientOpts(idpAccount)) + if err != nil { + return nil, errors.Wrap(err, "error building http client") + } + + client.CheckResponseStatus = provider.SuccessOrRedirectResponseValidator + + return &Client{ + client: client, + }, nil +} + +// Authenticate logs into Auth0 and returns a SAML response +func (ac *Client) Authenticate(loginDetails *creds.LoginDetails) (string, error) { + authInfo, err := ac.buildAuthInfo(loginDetails.URL, defaultPrompter) + if err != nil { + return "", errors.Wrap(err, "error failed to build authentication info") + } + + formHTML, err := ac.doLogin(loginDetails, authInfo) + if err != nil { + return "", errors.Wrap(err, "error failed to fetch SAML") + } + + samlAssertion, err := mustFindInputByName(formHTML, "SAMLResponse") + if err != nil { + return "", errors.Wrap(err, "error failed to parse SAML") + } + + return samlAssertion, nil +} + +func (ac *Client) buildAuthInfo( + loginURL string, + prompter prompter.Prompter, + opts ...authInfoOption, +) (*authInfo, error) { + var ai authInfo + if len(opts) == 0 { + opts = []authInfoOption{defaultAuthInfoOptions()} + } + for _, opt := range opts { + opt(&ai) + } + + ci, err := extractClientInfo(loginURL) + if err != nil { + return nil, errors.Wrap(err, "error extractClientInfo") + } + + connectionNames, err := ac.getConnectionNames(fmt.Sprintf(ai.connectionInfoURLFmt, ci.id)) + if err != nil { + return nil, errors.Wrap(err, "error getConnectionNames") + } + + var connection string + switch { + case len(connectionNames) == 0: + return nil, errors.New("error connection name") + case len(connectionNames) == 1: + connection = connectionNames[0] + default: + index := prompter.Choose("Select connection", connectionNames) + connection = connectionNames[index] + } + + si, err := ac.fetchSessionInfo(loginURL) + if err != nil { + return nil, errors.Wrap(err, "error fetchSessionInfo") + } + + ai.clientID = ci.id + ai.tenant = ci.tenantName + ai.connection = connection + ai.state = si.state + ai.csrf = si.csrf + + return &ai, nil +} + +func (ac *Client) fetchSessionInfo(loginURL string) (*sessionInfo, error) { + req, err := http.NewRequest("GET", loginURL, nil) + if err != nil { + return nil, errors.Wrap(err, "error building request") + } + + resp, err := ac.client.Do(req) + if err != nil { + return nil, errors.Wrap(err, "error retrieving response") + } + + respBody, err := ioutil.ReadAll(resp.Body) + if err != nil { + return nil, errors.Wrap(err, "error retrieving response body") + } + defer resp.Body.Close() + + tokenEncoded := sessionInfoPattern.FindStringSubmatch(string(respBody)) + if len(tokenEncoded) < 1 { + return nil, errors.New("error response doesn't match") + } + + jsonByte, err := base64.StdEncoding.DecodeString(tokenEncoded[1]) + if err != nil { + return nil, errors.Wrap(err, "error decoding matcher part by base64") + } + + state := gjson.Get(string(jsonByte), "state").String() + csrf := gjson.Get(string(jsonByte), "_csrf").String() + if len(state) == 0 || len(csrf) == 0 { + return nil, errors.New("error response doesn't include session info") + } + + return &sessionInfo{ + state: state, + csrf: csrf, + }, nil +} + +func (ac *Client) getConnectionNames(connectionInfoURL string) ([]string, error) { + req, err := http.NewRequest("GET", connectionInfoURL, nil) + if err != nil { + return nil, errors.Wrap(err, "error building request") + } + + resp, err := ac.client.Do(req) + if err != nil { + return nil, errors.Wrap(err, "error retrieving response") + } + + respBody, err := ioutil.ReadAll(resp.Body) + if err != nil { + return nil, errors.Wrap(err, "error retrieving body from response") + } + defer resp.Body.Close() + + match := connectionInfoPattern.FindStringSubmatch(string(respBody)) + if len(match) < 2 { + return nil, errors.New("cannot find connection name") + } + + var connectionNames []string + result := gjson.Get(match[1], `strategies.#.connections.#.name`) + for _, ary := range result.Array() { + for _, name := range ary.Array() { + connectionNames = append(connectionNames, name.String()) + } + } + + return connectionNames, nil +} + +func (ac *Client) doLogin(loginDetails *creds.LoginDetails, ai *authInfo) (string, error) { + responseDoc, err := ac.loginAuth0(loginDetails, ai) + if err != nil { + return "", errors.Wrap(err, "error failed to login Auth0") + } + + authCallback, err := parseResponseForm(responseDoc) + if err != nil { + return "", errors.Wrap(err, "error parse response document") + } + + resp, err := ac.doAuthCallback(authCallback, ai) + if err != nil { + return "", errors.Wrap(err, "error failed to make callback") + } + + return resp, nil +} + +func (ac *Client) loginAuth0(loginDetails *creds.LoginDetails, ai *authInfo) (string, error) { + authReq := authRequest{ + ClientID: ai.clientID, + Connection: ai.connection, + Password: loginDetails.Password, + PopupOptions: "{}", + Protocol: "samlp", + RedirectURI: "https://signin.aws.amazon.com/saml", + ResponseType: "code", + Scope: "openid profile email", + SSO: true, + State: ai.state, + Tenant: ai.tenant, + Username: loginDetails.Username, + CSRF: ai.csrf, + Intstate: "deprecated", + } + + authBody := new(bytes.Buffer) + err := json.NewEncoder(authBody).Encode(authReq) + if err != nil { + return "", errors.Wrap(err, "error encoding authentication request") + } + + authSubmitURL := fmt.Sprintf(ai.authSubmitURLFmt, ai.tenant) + req, err := http.NewRequest("POST", authSubmitURL, authBody) + if err != nil { + return "", errors.Wrap(err, "error building authentication request") + } + + req.Header.Add("Content-Type", "application/json") + req.Header.Add("Origin", fmt.Sprintf(ai.authOriginURLFmt, ai.tenant)) + req.Header.Add( + "Auth0-Client", + base64.StdEncoding.EncodeToString( + []byte(`{"name":"lock.js","version":"11.11.0","lib_version":{"raw":"9.8.1"}}`), + ), + ) + + resp, err := ac.client.Do(req) + if err != nil { + return "", errors.Wrap(err, "error retrieving auth response") + } + + respBody, err := ioutil.ReadAll(resp.Body) + if err != nil { + return "", errors.Wrap(err, "error retrieving body from response") + } + defer resp.Body.Close() + + return string(respBody), nil +} + +func (ac *Client) doAuthCallback(authCallback *authCallbackRequest, ai *authInfo) (string, error) { + req, err := http.NewRequest(authCallback.method, authCallback.url, strings.NewReader(authCallback.body)) + if err != nil { + return "", errors.Wrap(err, "error building authentication callback request") + } + + req.Header.Add("Content-Type", "application/x-www-form-urlencoded") + req.Header.Add("Origin", fmt.Sprintf(ai.authOriginURLFmt, ai.tenant)) + resp, err := ac.client.Do(req) + if err != nil { + return "", errors.Wrap(err, "error retrieving auth callback response") + } + + respBody, err := ioutil.ReadAll(resp.Body) + if err != nil { + return "", errors.Wrap(err, "error retrieving body from response") + } + defer resp.Body.Close() + + return string(respBody), nil +} + +func extractClientInfo(url string) (*clientInfo, error) { + matches := authURLPattern.FindStringSubmatch(url) + if len(matches) < 3 { + return nil, errors.New("error invalid Auth0 URL") + } + + return &clientInfo{ + id: matches[2], + tenantName: matches[1], + }, nil +} + +func parseResponseForm(responseForm string) (*authCallbackRequest, error) { + doc, err := goquery.NewDocumentFromReader(strings.NewReader(responseForm)) + if err != nil { + return nil, errors.Wrap(err, "error build goquery error") + } + + form := doc.Find("form") + methodDownCase, ok := form.Attr("method") + if !ok { + return nil, errors.New("invalid form method") + } + + authCallbackURL, ok := form.Attr("action") + if !ok { + return nil, errors.New("invalid form action") + } + + authCallBackForm := url.Values{} + + input := doc.Find("input") + input.Each(func(_ int, selection *goquery.Selection) { + name, nameOk := selection.Attr("name") + value, valueOk := selection.Attr("value") + + if nameOk && valueOk { + authCallBackForm.Add(name, html.UnescapeString(value)) + } + }) + + authCallbackBody := authCallBackForm.Encode() + if len(authCallbackBody) == 0 { + return nil, errors.New("invalid input values") + } + + return &authCallbackRequest{ + method: strings.ToUpper(methodDownCase), + url: authCallbackURL, + body: authCallbackBody, + }, nil +} + +func mustFindInputByName(formHTML string, name string) (string, error) { + doc, err := goquery.NewDocumentFromReader(strings.NewReader(formHTML)) + if err != nil { + return "", errors.Wrap(err, "error parse document") + } + + var fieldValue string + doc.Find(fmt.Sprintf(`input[name="%s"]`, name)).Each( + func(i int, s *goquery.Selection) { + val, _ := s.Attr("value") + fieldValue = val + }, + ) + if len(fieldValue) == 0 { + return "", errors.New("error unable to get value") + } + + return fieldValue, nil +} diff --git a/pkg/provider/auth0/auth0_test.go b/pkg/provider/auth0/auth0_test.go new file mode 100644 index 000000000..2b40d7e17 --- /dev/null +++ b/pkg/provider/auth0/auth0_test.go @@ -0,0 +1,669 @@ +package auth0 + +import ( + "encoding/base64" + "fmt" + "github.com/versent/saml2aws/v2/pkg/cfg" + "github.com/versent/saml2aws/v2/pkg/creds" + "net/http" + "net/http/httptest" + "reflect" + "testing" + + "github.com/versent/saml2aws/v2/pkg/provider" +) + +const testSAMLFormHTMLFmt = `test +
+ + + +
` + +func newTestProviderHTTPClientHelper(t *testing.T) *Client { + t.Helper() + + tr := provider.NewDefaultTransport(false) + httpClient, _ := provider.NewHTTPClient(tr, provider.BuildHttpClientOpts(&cfg.IDPAccount{})) + httpClient.CheckResponseStatus = provider.SuccessOrRedirectResponseValidator + + return &Client{ + ValidateBase: provider.ValidateBase{}, + client: httpClient, + } +} + +func Test_defaultAuthInfoOptions(t *testing.T) { + tests := []struct { + name string + want authInfo + }{ + { + name: "standard case", + want: authInfo{ + connectionInfoURLFmt: connectionInfoJSURLFmt, + authOriginURLFmt: authOriginURLFmt, + authSubmitURLFmt: authSubmitURLFmt, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var got authInfo + opts := defaultAuthInfoOptions() + opts(&got) + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("defaultAuthInfoOptions() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestClient_fetchSessionInfo(t *testing.T) { + type fields struct { + mockServerHandlerFunc func(w http.ResponseWriter, r *http.Request) + } + tests := []struct { + name string + fields fields + want *sessionInfo + wantErr bool + }{ + { + name: "standard case", + fields: fields{ + mockServerHandlerFunc: func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + jsonStr := `{"state": "StateToken", "_csrf": "CSRFToken"}` + base64Encoded := base64.StdEncoding.EncodeToString([]byte(jsonStr)) + _, _ = w.Write([]byte(fmt.Sprintf(`window.atob('%s')`, base64Encoded))) + }, + }, + want: &sessionInfo{ + state: "StateToken", + csrf: "CSRFToken", + }, + }, + { + name: "error case: server returns error", + fields: fields{ + mockServerHandlerFunc: func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusBadRequest) + }, + }, + wantErr: true, + }, + { + name: "error case: server returns invalid response(not match)", + fields: fields{ + mockServerHandlerFunc: func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + jsonStr := `{"invalid": "response"}` + base64Encoded := base64.StdEncoding.EncodeToString([]byte(jsonStr)) + _, _ = w.Write([]byte(fmt.Sprintf(`%s`, base64Encoded))) + }, + }, + wantErr: true, + }, + { + name: "error case: server returns invalid response(not base64 encoded)", + fields: fields{ + mockServerHandlerFunc: func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + jsonStr := `{"invalid": "response"}` + w.Write([]byte(fmt.Sprintf(`window.atob('%s')`, jsonStr))) + }, + }, + wantErr: true, + }, + { + name: "error case: server returns invalid response(not value included)", + fields: fields{ + mockServerHandlerFunc: func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + jsonStr := `{"invalid": "response"}` + base64Encoded := base64.StdEncoding.EncodeToString([]byte(jsonStr)) + w.Write([]byte(fmt.Sprintf(`window.atob('%s')`, base64Encoded))) + }, + }, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + testServer := httptest.NewServer(http.HandlerFunc(tt.fields.mockServerHandlerFunc)) + defer testServer.Close() + + ac := newTestProviderHTTPClientHelper(t) + + got, err := ac.fetchSessionInfo(testServer.URL) + if (err != nil) != tt.wantErr { + t.Errorf("fetchSessionInfo() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("fetchSessionInfo() got = %v, want %v", got, tt.want) + } + }) + } +} + +func TestClient_getConnectionNames(t *testing.T) { + type fields struct { + mockServerHandlerFunc func(w http.ResponseWriter, r *http.Request) + } + tests := []struct { + name string + fields fields + want []string + wantErr bool + }{ + { + name: "standard case: single connection", + fields: fields{ + mockServerHandlerFunc: func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte( + `Auth0.setClient({"strategies":[{"name":"user_pool_name","connections":[` + + `{"name":"connection-name1","display_name":"Connection name 1"}]}` + + `]});`), + ) + }, + }, + want: []string{ + "connection-name1", + }, + }, + { + name: "standard case: multiple connection", + fields: fields{ + mockServerHandlerFunc: func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte( + `Auth0.setClient({"strategies":[{"name":"user_pool_name","connections":[` + + `{"name":"connection-name1","display_name":"Connection name 1"},` + + `{"name":"connection-name2","display_name":"Connection name 2"}` + + `]});`)) + }, + }, + want: []string{ + "connection-name1", + "connection-name2", + }, + }, + { + name: "error case: server returns BadRequest", + fields: fields{ + mockServerHandlerFunc: func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusBadRequest) + }, + }, + wantErr: true, + }, + { + name: "error case: invalid response format", + fields: fields{ + mockServerHandlerFunc: func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte(`invalid`)) + }, + }, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + testServer := httptest.NewServer(http.HandlerFunc(tt.fields.mockServerHandlerFunc)) + defer testServer.Close() + + ac := newTestProviderHTTPClientHelper(t) + + got, err := ac.getConnectionNames(testServer.URL) + if (err != nil) != tt.wantErr { + t.Errorf("getConnectionNames() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("getConnectionNames() got = %v, want %v", got, tt.want) + } + }) + } +} + +func TestClient_doLogin(t *testing.T) { + type fields struct { + mockServerHandlerFunc func(w http.ResponseWriter, r *http.Request) + } + type args struct { + loginDetails *creds.LoginDetails + ai *authInfo + } + testArgs := args{ + loginDetails: &creds.LoginDetails{ + ClientID: "clientID", + ClientSecret: "clientSecret", + Username: "username", + Password: "password", + MFAToken: "mfaToken", + DuoMFAOption: "duoMFAOption", + URL: "URL", + StateToken: "stateToken", + }, + ai: &authInfo{ + clientID: "clientID", + tenant: "tenant", + connection: "connectionName", + state: "state", + csrf: "csrf", + }, + } + tests := []struct { + name string + fields fields + args args + want string + wantErr bool + }{ + { + name: "standard case", + fields: fields{ + mockServerHandlerFunc: func(w http.ResponseWriter, r *http.Request) { + switch r.RequestURI { + case "/tenant": + callbackURL := "http://" + r.Host + r.RequestURI + "/saml" + w.WriteHeader(http.StatusOK) + w.Write([]byte(fmt.Sprintf(testSAMLFormHTMLFmt, callbackURL, "SAMLBase64Encoded"))) + case "/tenant/saml": + w.WriteHeader(http.StatusOK) + w.Write([]byte(`response`)) + default: + w.WriteHeader(http.StatusBadRequest) + } + }, + }, + args: testArgs, + want: "response", + }, + { + name: "error case: loginAuth0 cause error", + fields: fields{ + mockServerHandlerFunc: func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusBadRequest) + }, + }, + args: testArgs, + wantErr: true, + }, + { + name: "error case: parseResponseForm cause error", + fields: fields{ + mockServerHandlerFunc: func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte(`invalid response for parseResponseForm`)) + }, + }, + args: testArgs, + wantErr: true, + }, + { + name: "error case: doAuthCallback cause error", + fields: fields{ + mockServerHandlerFunc: func(w http.ResponseWriter, r *http.Request) { + switch r.RequestURI { + case "/tenant": + callbackURL := "http://" + r.Host + r.RequestURI + "/saml" + w.WriteHeader(http.StatusOK) + w.Write([]byte(fmt.Sprintf(testSAMLFormHTMLFmt, callbackURL, "SAMLBase64Encoded"))) + default: + w.WriteHeader(http.StatusBadRequest) + } + }, + }, + args: testArgs, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + testServer := httptest.NewServer(http.HandlerFunc(tt.fields.mockServerHandlerFunc)) + defer testServer.Close() + + ac := newTestProviderHTTPClientHelper(t) + tt.args.ai.authSubmitURLFmt = testServer.URL + "/%s" + tt.args.ai.authOriginURLFmt = testServer.URL + "/%s" + + got, err := ac.doLogin(tt.args.loginDetails, tt.args.ai) + if (err != nil) != tt.wantErr { + t.Errorf("doLogin() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("doLogin() got = %v, want %v", got, tt.want) + } + }) + } +} + +func TestClient_loginAuth0(t *testing.T) { + type fields struct { + mockServerHandlerFunc func(w http.ResponseWriter, r *http.Request) + } + type args struct { + loginDetails *creds.LoginDetails + ai *authInfo + } + tests := []struct { + name string + fields fields + args args + want string + wantErr bool + }{ + { + name: "standard case", + fields: fields{ + mockServerHandlerFunc: func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte(`response`)) + }, + }, + args: args{ + loginDetails: &creds.LoginDetails{ + ClientID: "clientID", + ClientSecret: "clientSecret", + Username: "username", + Password: "password", + MFAToken: "mfaToken", + DuoMFAOption: "duoMFAOption", + URL: "URL", + StateToken: "stateToken", + }, + ai: &authInfo{ + clientID: "clientID", + tenant: "tenant", + connection: "connectionName", + state: "state", + csrf: "csrf", + }, + }, + want: "response", + }, + { + name: "error case: server returns BadRequest", + fields: fields{ + mockServerHandlerFunc: func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusBadRequest) + }, + }, + args: args{ + loginDetails: &creds.LoginDetails{ + ClientID: "clientID", + ClientSecret: "clientSecret", + Username: "username", + Password: "password", + MFAToken: "mfaToken", + DuoMFAOption: "duoMFAOption", + URL: "URL", + StateToken: "stateToken", + }, + ai: &authInfo{ + clientID: "clientID", + tenant: "tenant", + connection: "connectionName", + state: "state", + csrf: "csrf", + }, + }, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + testServer := httptest.NewServer(http.HandlerFunc(tt.fields.mockServerHandlerFunc)) + defer testServer.Close() + + ac := newTestProviderHTTPClientHelper(t) + tt.args.ai.authSubmitURLFmt = testServer.URL + "/%s" + tt.args.ai.authOriginURLFmt = testServer.URL + "/%s" + + got, err := ac.loginAuth0(tt.args.loginDetails, tt.args.ai) + if (err != nil) != tt.wantErr { + t.Errorf("loginAuth0() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("loginAuth0() got = %v, want %v", got, tt.want) + } + }) + } +} + +func TestClient_doAuthCallback(t *testing.T) { + type fields struct { + mockServerHandlerFunc func(w http.ResponseWriter, r *http.Request) + } + type args struct { + authCallback *authCallbackRequest + ai *authInfo + } + tests := []struct { + name string + fields fields + args args + want string + wantErr bool + }{ + { + name: "standard case", + fields: fields{ + mockServerHandlerFunc: func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("authCallbackResponse")) + }, + }, + args: args{ + authCallback: &authCallbackRequest{ + method: "POST", + body: "RelayState=&SAMLResponse=SAMLBase64Encoded", + }, + ai: &authInfo{ + tenant: "tenant", + }, + }, + want: "authCallbackResponse", + }, + { + name: "error case: invalid HTTP method", + fields: fields{ + mockServerHandlerFunc: func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("authCallbackResponse")) + }, + }, + args: args{ + authCallback: &authCallbackRequest{ + method: "無効", // means invalid in Japanese + body: "RelayState=&SAMLResponse=SAMLBase64Encoded", + }, + }, + wantErr: true, + }, + { + name: "error case: server returns BadRequest", + fields: fields{ + mockServerHandlerFunc: func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusBadRequest) + }, + }, + args: args{ + authCallback: &authCallbackRequest{ + method: "POST", + body: "RelayState=&SAMLResponse=SAMLBase64Encoded", + }, + ai: &authInfo{ + tenant: "tenant", + }, + }, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + testServer := httptest.NewServer(http.HandlerFunc(tt.fields.mockServerHandlerFunc)) + defer testServer.Close() + + ac := newTestProviderHTTPClientHelper(t) + tt.args.authCallback.url = testServer.URL + + got, err := ac.doAuthCallback(tt.args.authCallback, tt.args.ai) + if (err != nil) != tt.wantErr { + t.Errorf("doAuthCallback() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("doAuthCallback() got = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_extractClientInfo(t *testing.T) { + type args struct { + urlStr string + } + tests := []struct { + name string + args args + want *clientInfo + wantErr bool + }{ + { + name: "standard case", + args: args{ + urlStr: "https://tenant.auth0.com/samlp/client_id", + }, + want: &clientInfo{ + id: "client_id", + tenantName: "tenant", + }, + }, + { + name: "error case: invalid URL", + args: args{ + urlStr: "https://example.com", + }, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := extractClientInfo(tt.args.urlStr) + if (err != nil) != tt.wantErr { + t.Errorf("extractClientInfo() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("extractClientInfo() got = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_parseResponseForm(t *testing.T) { + type args struct { + responseForm string + } + tests := []struct { + name string + args args + want *authCallbackRequest + wantErr bool + }{ + { + name: "standard case", + args: args{ + responseForm: fmt.Sprintf(testSAMLFormHTMLFmt, "https://example.com/saml", "SAMLBase64Encoded"), + }, + want: &authCallbackRequest{ + method: "POST", + url: "https://example.com/saml", + body: "RelayState=&SAMLResponse=SAMLBase64Encoded", + }, + }, + { + name: "error case: no method attribute on form element", + args: args{ + responseForm: `
`, + }, + wantErr: true, + }, + { + name: "error case: no action attribute on form element", + args: args{ + responseForm: `
`, + }, + wantErr: true, + }, + { + name: "error case: input element in form element", + args: args{ + responseForm: `
`, + }, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := parseResponseForm(tt.args.responseForm) + if (err != nil) != tt.wantErr { + t.Errorf("parseResponseForm() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("parseResponseForm() got = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_mustFindInputByName(t *testing.T) { + type args struct { + formHTML string + name string + } + tests := []struct { + name string + args args + want string + wantErr bool + }{ + { + name: "standard case", + args: args{ + formHTML: fmt.Sprintf(testSAMLFormHTMLFmt, "https://example.com/saml", "SAMLBase64Encoded"), + name: "SAMLResponse", + }, + want: "SAMLBase64Encoded", + }, + { + name: "error case: SAML value is empty", + args: args{ + formHTML: fmt.Sprintf(testSAMLFormHTMLFmt, "https://example.com/saml", ""), + name: "SAMLResponse", + }, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := mustFindInputByName(tt.args.formHTML, tt.args.name) + if (err != nil) != tt.wantErr { + t.Errorf("mustFindInputByName() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("mustFindInputByName() got = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/saml2aws.go b/saml2aws.go index 1e99f7b91..4d422a0a3 100644 --- a/saml2aws.go +++ b/saml2aws.go @@ -4,19 +4,19 @@ import ( "fmt" "sort" - "github.com/versent/saml2aws/v2/pkg/provider/browser" - "github.com/versent/saml2aws/v2/pkg/provider/netiq" - "github.com/versent/saml2aws/v2/pkg/cfg" "github.com/versent/saml2aws/v2/pkg/creds" "github.com/versent/saml2aws/v2/pkg/provider/aad" "github.com/versent/saml2aws/v2/pkg/provider/adfs" "github.com/versent/saml2aws/v2/pkg/provider/adfs2" "github.com/versent/saml2aws/v2/pkg/provider/akamai" + "github.com/versent/saml2aws/v2/pkg/provider/auth0" + "github.com/versent/saml2aws/v2/pkg/provider/browser" "github.com/versent/saml2aws/v2/pkg/provider/f5apm" "github.com/versent/saml2aws/v2/pkg/provider/googleapps" "github.com/versent/saml2aws/v2/pkg/provider/jumpcloud" "github.com/versent/saml2aws/v2/pkg/provider/keycloak" + "github.com/versent/saml2aws/v2/pkg/provider/netiq" "github.com/versent/saml2aws/v2/pkg/provider/okta" "github.com/versent/saml2aws/v2/pkg/provider/onelogin" "github.com/versent/saml2aws/v2/pkg/provider/pingfed" @@ -47,6 +47,7 @@ var MFAsByProvider = ProviderList{ "ShibbolethECP": []string{"auto", "phone", "push", "passcode"}, "NetIQ": []string{"Auto", "Privileged"}, "Browser": []string{"Auto"}, + "Auth0": []string{"Auto"}, } // Names get a list of provider names @@ -172,6 +173,11 @@ func NewSAMLClient(idpAccount *cfg.IDPAccount) (SAMLClient, error) { return netiq.New(idpAccount, idpAccount.MFA) case "Browser": return browser.New(idpAccount) + case "Auth0": + if invalidMFA(idpAccount.Provider, idpAccount.MFA) { + return nil, fmt.Errorf("Invalid MFA type: %v for %v provider", idpAccount.MFA, idpAccount.Provider) + } + return auth0.New(idpAccount) default: return nil, fmt.Errorf("Invalid provider: %v", idpAccount.Provider) } diff --git a/saml2aws_test.go b/saml2aws_test.go index 9963e24f4..51e0955c4 100644 --- a/saml2aws_test.go +++ b/saml2aws_test.go @@ -10,7 +10,7 @@ func TestProviderList_Keys(t *testing.T) { names := MFAsByProvider.Names() - require.Len(t, names, 16) + require.Len(t, names, 17) }