Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions common/djangoapps/third_party_auth/signals/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,9 @@ def update_saml_provider_configs_on_configuration_change(sender, instance, creat
# Find all existing SAMLProviderConfig instances (current_set) that should be
# pointing to this slug but are pointing to an older version
existing_providers = SAMLProviderConfig.objects.current_set().filter(
site_id=instance.site_id,
saml_configuration__site_id=instance.site_id,
saml_configuration__slug=instance.slug
).exclude(saml_configuration_id=instance.id)
).exclude(saml_configuration_id=instance.id).exclude(saml_configuration_id__isnull=True)

updated_count = 0
for provider_config in existing_providers:
Expand Down
250 changes: 168 additions & 82 deletions common/djangoapps/third_party_auth/signals/tests/test_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
from unittest import mock
from unittest.mock import call
from django.test import TestCase, override_settings
from common.djangoapps.third_party_auth.tests.factories import SAMLConfigurationFactory
from django.contrib.sites.models import Site
from common.djangoapps.third_party_auth.tests.factories import SAMLConfigurationFactory, SAMLProviderConfigFactory
from common.djangoapps.third_party_auth.models import SAMLProviderConfig


@ddt.ddt
Expand All @@ -21,97 +23,181 @@ def setUp(self):
org_info_str='{"en-US": {"url": "http://test.com", "displayname": "Test", "name": "test"}}'
)

@ddt.data(
# Case 1: Tests behavior when SAML config signal handlers are disabled
# Verifies that basic attributes are set but no provider updates are attempted
{
'enabled': False,
'simulate_error': False,
'description': 'handlers disabled',
'expected_calls': [
self.site1 = Site.objects.get_or_create(domain='test-site1.com', name='Site 1')[0]
self.site2 = Site.objects.get_or_create(domain='test-site2.com', name='Site 2')[0]

# Existing SAML config used by provider update tests
self.existing_saml_config = SAMLConfigurationFactory(
site=self.site1,
slug='slug',
entity_id='https://existing.example.com'
)

@mock.patch('common.djangoapps.third_party_auth.signals.handlers.set_custom_attribute')
def test_saml_config_signal_handlers_disabled(self, mock_set_custom_attribute):
"""
Test behavior when SAML config signal handlers are disabled.

Verifies that basic attributes are set but no provider updates are attempted.
"""
with override_settings(ENABLE_SAML_CONFIG_SIGNAL_HANDLERS=False):
self.saml_config.entity_id = 'https://updated.example.com'
self.saml_config.save()

expected_calls = [
call('saml_config_signal.enabled', False),
call('saml_config_signal.new_config_id', 'CONFIG_ID'),
call('saml_config_signal.slug', 'test-config'),
],
'expected_call_count': 3,
},
# Case 2: Tests behavior when SAML config signal handlers are enabled
# Verifies that attributes are set and provider updates are attempted successfully
{
'enabled': True,
'simulate_error': False,
'description': 'handlers enabled',
'expected_calls': [
call('saml_config_signal.enabled', True),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ktyagiapphelix2u: Sorry if it wasn't clear, but I wanted thes custom attribute assertions for these tests moved to our new ddt test, which is the new test that serves as a replacement for this test.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure Robert Will Add This. Thankyou

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks. This obviously can wait until tomorrow. And please don't clobber my changes, which I added to your branch. Thanks again.

call('saml_config_signal.new_config_id', 'CONFIG_ID'),
call('saml_config_signal.new_config_id', self.saml_config.id),
call('saml_config_signal.slug', 'test-config'),
call('saml_config_signal.updated_count', 0),
],
'expected_call_count': 4,
},
# Case 3: Tests error handling when signal handlers are enabled but encounter an exception
# Verifies that error information is properly captured when provider updates fail
{
'enabled': True,
'simulate_error': True,
'description': 'handlers enabled with exception',
'expected_calls': [
]

mock_set_custom_attribute.assert_has_calls(expected_calls, any_order=False)
assert mock_set_custom_attribute.call_count == 3

@mock.patch('common.djangoapps.third_party_auth.signals.handlers.set_custom_attribute')
def test_saml_config_signal_handlers_with_error(self, mock_set_custom_attribute):
"""
Test error handling when signal handlers encounter an exception.

Verifies that error information is properly captured when provider updates fail.
"""
error_message = "Test error"
with override_settings(ENABLE_SAML_CONFIG_SIGNAL_HANDLERS=True):
# Simulate an exception in the provider config update logic
with mock.patch(
'common.djangoapps.third_party_auth.models.SAMLProviderConfig.objects.current_set',
side_effect=Exception(error_message)
):
self.saml_config.entity_id = 'https://updated.example.com'
self.saml_config.save()

expected_calls = [
call('saml_config_signal.enabled', True),
call('saml_config_signal.new_config_id', 'CONFIG_ID'),
call('saml_config_signal.new_config_id', self.saml_config.id),
call('saml_config_signal.slug', 'test-config'),
],
'expected_call_count': 4, # includes error_message call
'error_message': 'Test error',
},
]

mock_set_custom_attribute.assert_has_calls(expected_calls, any_order=False)
assert mock_set_custom_attribute.call_count == 4

# Verify error message was logged
mock_set_custom_attribute.assert_any_call(
'saml_config_signal.error_message',
mock.ANY
)
error_calls = [
call for call in mock_set_custom_attribute.mock_calls
if call[1][0] == 'saml_config_signal.error_message'
]
assert error_message in error_calls[0][1][1], (
f"Expected '{error_message}' in error message, "
f"got: {error_calls[0][1][1]}"
)

def _get_current_provider(self, slug):
"""
Helper to get current version of provider by slug.
"""
return SAMLProviderConfig.objects.current_set().get(slug=slug)

def _get_site(self, site_id):
"""
Helper to get site by ID (1 = site1, 2 = site2).
"""
if site_id == 1:
return self.site1
elif site_id == 2:
return self.site2
else:
raise ValueError(f"Unexpected site_id: {site_id}.")

@ddt.data(
# Args: provider_site_id, provider_slug, signal_saml_site_id, signal_saml_slug, is_provider_updated

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# Args: provider_site_id, provider_slug, signal_saml_site_id, signal_saml_slug, is_provider_updated
# Args: provider_site_id, provider_slug, signal_saml_site_id, signal_saml_slug, should_provider_update

Is this more apt ?

Or maybe will_provider_update ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Both are fine. Its just which will sounds more clear.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or, provider_update_expected? I'm good with any option, include leaving as-is. I'll leave it up to you two to decide.

# All tests: provider's saml_configuration has site_id=1, slug='slug'
# Signal matches provider's saml config and should update
(1, 'slug', 1, 'slug', True), # Same site, same slug
(2, 'slug', 1, 'slug', True), # Cross-site provider, matching saml config
(1, 'provider-slug', 1, 'slug', True), # Different provider slug, matching saml config
# Signal does not match provider's saml config and should not update
(1, 'slug', 2, 'slug', False), # Different saml config site
(2, 'slug', 2, 'slug', False), # Different saml config site (cross-site)
(1, 'provider-slug', 1, 'provider-slug', False), # Different saml config slug
(2, 'provider-slug', 1, 'provider-slug', False), # Different saml config slug (cross-site)
)
@ddt.unpack
@mock.patch('common.djangoapps.third_party_auth.signals.handlers.set_custom_attribute')
def test_saml_config_signal_handlers(
self, mock_set_custom_attribute, enabled, simulate_error,
description, expected_calls, expected_call_count, error_message=None):
@override_settings(ENABLE_SAML_CONFIG_SIGNAL_HANDLERS=True)
def test_saml_provider_config_updates(self, provider_site_id, provider_slug,
signal_saml_site_id, signal_saml_slug, is_provider_updated,
mock_set_custom_attribute):
"""
Test SAML configuration signal handlers under different conditions.
Test SAML provider config updates under different scenarios.

Tests that providers are updated only when the signal's SAML configuration
matches the provider's existing SAML configuration (by site and slug).
"""
with override_settings(ENABLE_SAML_CONFIG_SIGNAL_HANDLERS=enabled):
if simulate_error:
# Simulate an exception in the provider config update logic
with mock.patch(
'common.djangoapps.third_party_auth.models.SAMLProviderConfig.objects.current_set',
side_effect=Exception(error_message)
):
self.saml_config.entity_id = 'https://updated.example.com'
self.saml_config.save()
else:
self.saml_config.entity_id = 'https://updated.example.com'
self.saml_config.save()
provider_site = self._get_site(provider_site_id)
signal_saml_site = self._get_site(signal_saml_site_id)

provider = SAMLProviderConfigFactory(
slug=provider_slug,
site=provider_site,
saml_configuration=self.existing_saml_config
)
original_config_id = provider.saml_configuration_id

expected_calls_with_id = []
for call_obj in expected_calls:
args = list(call_obj[1])
if args[1] == 'CONFIG_ID':
args[1] = self.saml_config.id
expected_calls_with_id.append(call(args[0], args[1]))
new_saml_config = SAMLConfigurationFactory(
site=signal_saml_site,
slug=signal_saml_slug,
entity_id='https://new.example.com'
)

# Verify expected calls were made
mock_set_custom_attribute.assert_has_calls(expected_calls_with_id, any_order=False)
current_provider = self._get_current_provider(provider_slug)

# Verify total call count
assert mock_set_custom_attribute.call_count == expected_call_count, (
f"Expected {expected_call_count} calls for {description}, "
f"got {mock_set_custom_attribute.call_count}"
)
mock_set_custom_attribute.assert_any_call('saml_config_signal.enabled', True)
mock_set_custom_attribute.assert_any_call('saml_config_signal.new_config_id', new_saml_config.id)
mock_set_custom_attribute.assert_any_call('saml_config_signal.slug', signal_saml_slug)
Comment on lines +156 to +158
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[optional] Using expected_calls = [...] like you do elsewhere makes this easier to read, because there is less redundant code.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure Robert Will Look Into It And Test It Tommorow

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, @ktyagiapphelix2u, since this is optional, consider doing it on the follow-up PR. I will merge so you can work from master again and open a new PR.


if is_provider_updated:
mock_set_custom_attribute.assert_any_call('saml_config_signal.updated_count', 1)
self.assertEqual(current_provider.saml_configuration_id, new_saml_config.id,
"Provider should be updated when signal SAML config matches")
else:
mock_set_custom_attribute.assert_any_call('saml_config_signal.updated_count', 0)
self.assertEqual(current_provider.saml_configuration_id, original_config_id,
"Provider should NOT be updated when signal SAML config doesn't match")

@ddt.data(
# Args: provider_site_id, provider_slug, signal_saml_site_id, signal_saml_slug
# All tests: provider's saml config is None and should never be updated
(1, 'slug', 1, 'default'),
(1, 'default', 1, 'default'),
(2, 'slug', 1, 'default'),
)
@ddt.unpack
@override_settings(ENABLE_SAML_CONFIG_SIGNAL_HANDLERS=True)
def test_saml_provider_with_null_config_not_updated(self, provider_site_id, provider_slug,
signal_saml_site_id, signal_saml_slug):
"""
Test that providers with NULL SAML configuration are never updated by signal handler.

This is critical for fallback authentication scenarios where providers
intentionally have no SAML configuration.
"""
provider_site = self._get_site(provider_site_id)
signal_saml_site = self._get_site(signal_saml_site_id)

null_provider = SAMLProviderConfigFactory(
slug=provider_slug,
site=provider_site,
saml_configuration=None
)

new_saml_config = SAMLConfigurationFactory(
site=signal_saml_site,
slug=signal_saml_slug,
entity_id='https://new.example.com'
)

# If error is expected, verify error message was logged
if error_message:
mock_set_custom_attribute.assert_any_call(
'saml_config_signal.error_message',
mock.ANY
)
error_calls = [
call for call in mock_set_custom_attribute.mock_calls
if call[1][0] == 'saml_config_signal.error_message'
]
assert error_message in error_calls[0][1][1], (
f"Expected '{error_message}' in error message for {description}, "
f"got: {error_calls[0][1][1]}"
)
current_provider = self._get_current_provider(provider_slug)
self.assertIsNone(current_provider.saml_configuration_id,
"Provider with NULL SAML config should never be updated")
Loading