diff --git a/tests/integration/saml_test.go b/tests/integration/saml_test.go index 9115270fde182..4dfb5ba9296e3 100644 --- a/tests/integration/saml_test.go +++ b/tests/integration/saml_test.go @@ -5,7 +5,6 @@ package integration import ( "fmt" - "html" "io" "net/http" "net/http/cookiejar" @@ -77,26 +76,25 @@ func TestSAMLRegistration(t *testing.T) { req, err = http.NewRequest("GET", test.RedirectURL(resp), nil) assert.NoError(t, err) + var formRedirectURL *url.URL + client.CheckRedirect = func(req *http.Request, via []*http.Request) error { + // capture the redirected destination to use in POST request + formRedirectURL = req.URL + return nil + } + res, err := client.Do(req) + client.CheckRedirect = nil assert.NoError(t, err) assert.Equal(t, http.StatusOK, res.StatusCode) - - // find the auth state hidden input - authStateMatcher := regexp.MustCompile(``) - body, err := io.ReadAll(res.Body) - assert.NoError(t, err) - - matches := authStateMatcher.FindStringSubmatch(string(body)) - assert.Len(t, matches, 2) - assert.NoError(t, res.Body.Close()) + assert.NotNil(t, formRedirectURL) form := url.Values{ - "username": {"user1"}, - "password": {"user1pass"}, - "AuthState": {html.UnescapeString(matches[1])}, + "username": {"user1"}, + "password": {"user1pass"}, } - req, err = http.NewRequest("POST", fmt.Sprintf("http://%s/simplesaml/module.php/core/loginuserpass.php", samlURL), strings.NewReader(form.Encode())) + req, err = http.NewRequest("POST", formRedirectURL.String(), strings.NewReader(form.Encode())) assert.NoError(t, err) req.Header.Add("Content-Type", "application/x-www-form-urlencoded") @@ -104,12 +102,11 @@ func TestSAMLRegistration(t *testing.T) { assert.NoError(t, err) assert.Equal(t, http.StatusOK, res.StatusCode) - samlResMatcher := regexp.MustCompile(``) - - body, err = io.ReadAll(res.Body) + body, err := io.ReadAll(res.Body) assert.NoError(t, err) - matches = samlResMatcher.FindStringSubmatch(string(body)) + samlResMatcher := regexp.MustCompile(``) + matches := samlResMatcher.FindStringSubmatch(string(body)) assert.Len(t, matches, 2) assert.NoError(t, res.Body.Close())