diff --git a/tests/unit/test_saml.py b/tests/unit/test_saml.py index 37d2952..0354759 100644 --- a/tests/unit/test_saml.py +++ b/tests/unit/test_saml.py @@ -196,15 +196,34 @@ def test_non_https_url(self, generic_auth, mock_requests_session, # The error is raised after the call to get the form, but before the # call to submit it. mock_requests_session.get.return_value = mock.Mock( - spec=requests.Response, status_code=200, text=login_form + spec=requests.Response, status_code=200, text=login_form, + url='http://example.com' ) + with pytest.raises(SAMLError, match='HTTPS'): generic_auth.retrieve_saml_assertion(config) + def test_endpoint_form_redirect_url(self, generic_auth, generic_config, + login_form, mock_requests_session): + mock_requests_session.get.return_value = mock.Mock( + spec=requests.Response, status_code=200, text=login_form, + url='https://test.com' + ) + mock_requests_session.post.return_value = mock.Mock( + spec=requests.Response, status_code=200, text=( + '
' + ) + ) + generic_auth.retrieve_saml_assertion(generic_config) + url_used = mock_requests_session.post.call_args[0][0] + assert url_used == "https://test.com/login" + def test_form_action_appended_to_url(self, generic_auth, generic_config, login_form, mock_requests_session): mock_requests_session.get.return_value = mock.Mock( - spec=requests.Response, status_code=200, text=login_form + spec=requests.Response, status_code=200, text=login_form, + url='https://example.com' ) mock_requests_session.post.return_value = mock.Mock( spec=requests.Response, status_code=200, text=( @@ -216,10 +235,37 @@ def test_form_action_appended_to_url(self, generic_auth, generic_config, url_used = mock_requests_session.post.call_args[0][0] assert url_used == "https://example.com/login" + def test_form_action_replaces_url(self, generic_auth, generic_config, + login_form, mock_requests_session): + saml_form = ( + '' + '
' + '' + '' + '' + '
' + '' + ) + + mock_requests_session.get.return_value = mock.Mock( + spec=requests.Response, status_code=200, text=saml_form, + url='https://example.com' + ) + mock_requests_session.post.return_value = mock.Mock( + spec=requests.Response, status_code=200, text=( + '
' + ) + ) + generic_auth.retrieve_saml_assertion(generic_config) + url_used = mock_requests_session.post.call_args[0][0] + assert url_used == "https://www.test.com" + def test_extract_assertion(self, generic_auth, mock_requests_session, generic_config, login_form): mock_requests_session.get.return_value = mock.Mock( - spec=requests.Response, status_code=200, text=login_form + spec=requests.Response, status_code=200, text=login_form, + url='https://example.com' ) mock_requests_session.post.return_value = mock.Mock( spec=requests.Response, status_code=200, text=( @@ -248,7 +294,8 @@ def test_passes_in_other_form_fields(self, generic_auth, generic_config, '' ) mock_requests_session.get.return_value = mock.Mock( - spec=requests.Response, status_code=200, text=saml_form + spec=requests.Response, status_code=200, text=saml_form, + url='https://example.com' ) mock_requests_session.post.return_value = mock.Mock( spec=requests.Response, status_code=200, text=( @@ -280,7 +327,8 @@ def tests_uses_default_form_values(self, generic_auth, generic_config, '' ) mock_requests_session.get.return_value = mock.Mock( - spec=requests.Response, status_code=200, text=saml_form + spec=requests.Response, status_code=200, text=saml_form, + url='https://example.com' ) mock_requests_session.post.return_value = mock.Mock( spec=requests.Response, status_code=200, text=( @@ -384,7 +432,8 @@ def test_missing_form_username(self, generic_auth, mock_requests_session, '' ) mock_requests_session.get.return_value = mock.Mock( - spec=requests.Response, status_code=200, text=missing_form_fields + spec=requests.Response, status_code=200, text=missing_form_fields, + url='https://example.com' ) with pytest.raises(SAMLError, match='could not find'): generic_auth.retrieve_saml_assertion(generic_config) @@ -397,7 +446,8 @@ def test_missing_form_password(self, generic_auth, mock_requests_session, '' ) mock_requests_session.get.return_value = mock.Mock( - spec=requests.Response, status_code=200, text=missing_form_fields + spec=requests.Response, status_code=200, text=missing_form_fields, + url='https://example.com' ) mock_requests_session.post.return_value = mock.Mock( spec=requests.Response, status_code=200, text=( @@ -422,7 +472,8 @@ def test_missing_form_password(self, generic_auth, mock_requests_session, def test_empty_assertion(self, generic_auth, mock_requests_session, login_form, generic_config, assertion_response): mock_requests_session.get.return_value = mock.Mock( - spec=requests.Response, status_code=200, text=login_form + spec=requests.Response, status_code=200, text=login_form, + url='https://example.com' ) mock_requests_session.post.return_value = mock.Mock( spec=requests.Response, status_code=200, text=assertion_response @@ -433,7 +484,8 @@ def test_empty_assertion(self, generic_auth, mock_requests_session, def test_non_200_authenticate_response(self, generic_auth, generic_config, mock_requests_session, login_form): mock_requests_session.get.return_value = mock.Mock( - spec=requests.Response, text=login_form, status_code=200 + spec=requests.Response, text=login_form, status_code=200, + url='https://example.com' ) # This 401 response represents an auth failure, such as a bad password. @@ -449,7 +501,8 @@ def test_non_200_authenticate_response(self, generic_auth, generic_config, def test_no_saml_assertion_in_response(self, generic_auth, generic_config, mock_requests_session, login_form): mock_requests_session.get.return_value = mock.Mock( - spec=requests.Response, text=login_form, status_code=200 + spec=requests.Response, text=login_form, status_code=200, + url='https://example.com' ) mock_requests_session.post.return_value = mock.Mock( spec=requests.Response, text='login failed', @@ -489,7 +542,8 @@ def test_authn_requests_made(self, okta_auth, okta_config, mock_requests_session.get.return_value = mock.Mock( text=('
'), - status_code=200 + status_code=200, + url='https://example.com' ) saml_assertion = okta_auth.retrieve_saml_assertion(okta_config) assert saml_assertion == 'fakeassertion' @@ -543,7 +597,8 @@ def test_uses_adfs_fields(self, adfs_auth, mock_requests_session, '' ) mock_requests_session.get.return_value = mock.Mock( - spec=requests.Response, status_code=200, text=adfs_login_form + spec=requests.Response, status_code=200, text=adfs_login_form, + url='https://example.com' ) mock_requests_session.post.return_value = mock.Mock( spec=requests.Response, status_code=200, text=(