diff --git a/tests/unit/test_saml.py b/tests/unit/test_saml.py index 37d2952..0354759 100644 --- a/tests/unit/test_saml.py +++ b/tests/unit/test_saml.py @@ -196,15 +196,34 @@ def test_non_https_url(self, generic_auth, mock_requests_session, # The error is raised after the call to get the form, but before the # call to submit it. mock_requests_session.get.return_value = mock.Mock( - spec=requests.Response, status_code=200, text=login_form + spec=requests.Response, status_code=200, text=login_form, + url='http://example.com' ) + with pytest.raises(SAMLError, match='HTTPS'): generic_auth.retrieve_saml_assertion(config) + def test_endpoint_form_redirect_url(self, generic_auth, generic_config, + login_form, mock_requests_session): + mock_requests_session.get.return_value = mock.Mock( + spec=requests.Response, status_code=200, text=login_form, + url='https://test.com' + ) + mock_requests_session.post.return_value = mock.Mock( + spec=requests.Response, status_code=200, text=( + '
' + ) + ) + generic_auth.retrieve_saml_assertion(generic_config) + url_used = mock_requests_session.post.call_args[0][0] + assert url_used == "https://test.com/login" + def test_form_action_appended_to_url(self, generic_auth, generic_config, login_form, mock_requests_session): mock_requests_session.get.return_value = mock.Mock( - spec=requests.Response, status_code=200, text=login_form + spec=requests.Response, status_code=200, text=login_form, + url='https://example.com' ) mock_requests_session.post.return_value = mock.Mock( spec=requests.Response, status_code=200, text=( @@ -216,10 +235,37 @@ def test_form_action_appended_to_url(self, generic_auth, generic_config, url_used = mock_requests_session.post.call_args[0][0] assert url_used == "https://example.com/login" + def test_form_action_replaces_url(self, generic_auth, generic_config, + login_form, mock_requests_session): + saml_form = ( + '' + '' + '' + ) + + mock_requests_session.get.return_value = mock.Mock( + spec=requests.Response, status_code=200, text=saml_form, + url='https://example.com' + ) + mock_requests_session.post.return_value = mock.Mock( + spec=requests.Response, status_code=200, text=( + '' + ) + ) + generic_auth.retrieve_saml_assertion(generic_config) + url_used = mock_requests_session.post.call_args[0][0] + assert url_used == "https://www.test.com" + def test_extract_assertion(self, generic_auth, mock_requests_session, generic_config, login_form): mock_requests_session.get.return_value = mock.Mock( - spec=requests.Response, status_code=200, text=login_form + spec=requests.Response, status_code=200, text=login_form, + url='https://example.com' ) mock_requests_session.post.return_value = mock.Mock( spec=requests.Response, status_code=200, text=( @@ -248,7 +294,8 @@ def test_passes_in_other_form_fields(self, generic_auth, generic_config, '