diff --git a/common/djangoapps/third_party_auth/management/commands/saml.py b/common/djangoapps/third_party_auth/management/commands/saml.py index e3891cefa08f..afe369c2ade0 100644 --- a/common/djangoapps/third_party_auth/management/commands/saml.py +++ b/common/djangoapps/third_party_auth/management/commands/saml.py @@ -6,6 +6,7 @@ import logging from django.core.management.base import BaseCommand, CommandError +from edx_django_utils.monitoring import set_custom_attribute from common.djangoapps.third_party_auth.tasks import fetch_saml_metadata from common.djangoapps.third_party_auth.models import SAMLProviderConfig, SAMLConfiguration @@ -18,34 +19,24 @@ class Command(BaseCommand): def add_arguments(self, parser): parser.add_argument('--pull', action='store_true', help="Pull updated metadata from external IDPs") parser.add_argument( - '--fix-references', + '--run-checks', action='store_true', - help="Fix SAMLProviderConfig references to use current SAMLConfiguration versions" - ) - parser.add_argument( - '--site-id', - type=int, - help='Only fix configurations for a specific site ID (to be used with --fix-references)' - ) - parser.add_argument( - '--dry-run', - action='store_true', - help='Show what would be changed, but do not make any changes.' + help="Run checks on SAMLProviderConfig configurations and report potential issues" ) def handle(self, *args, **options): should_pull_saml_metadata = options.get('pull', False) - should_fix_references = options.get('fix_references', False) - dry_run = options.get('dry_run', False) - - if not should_pull_saml_metadata and not should_fix_references: - raise CommandError("Command must be used with '--pull' or '--fix-references' option.") + should_run_checks = options.get('run_checks', False) if should_pull_saml_metadata: self._handle_pull_metadata() + return - if should_fix_references: - self._handle_fix_references(options, dry_run=dry_run) + if should_run_checks: + self._handle_run_checks() + return + + raise CommandError("Command must be used with '--pull' or '--run-checks' option.") def _handle_pull_metadata(self): """ @@ -76,45 +67,139 @@ def _handle_pull_metadata(self): ) ) - def _handle_fix_references(self, options, dry_run=False): - """Handle the --fix-references option for fixing outdated SAML configuration references.""" - site_id = options.get('site_id') - updated_count = 0 + def _handle_run_checks(self): + """ + Handle the --run-checks option for checking SAMLProviderConfig configuration issues. + + This is a report-only command. It identifies potential configuration problems such as: + - Outdated SAMLConfiguration references (provider pointing to old config version) + - Site ID mismatches between SAMLProviderConfig and its SAMLConfiguration + - Slug mismatches (except 'default' slugs) # noqa: E501 + - SAMLProviderConfig objects with null SAMLConfiguration references (informational) + + Includes observability attributes for monitoring. + """ + # Set custom attributes for monitoring the check operation + # .. custom_attribute_name: saml_management_command.operation + # .. custom_attribute_description: Records current SAML operation ('run_checks'). + set_custom_attribute('saml_management_command.operation', 'run_checks') + + metrics = self._check_provider_configurations() + self._report_check_summary(metrics) + + def _check_provider_configurations(self): + """ + Check each provider configuration for potential issues. + Returns a dictionary of metrics about the found issues. + """ + outdated_count = 0 + site_mismatch_count = 0 + slug_mismatch_count = 0 + null_config_count = 0 error_count = 0 + total_providers = 0 - # Filter by site if specified provider_configs = SAMLProviderConfig.objects.current_set() - if site_id: - provider_configs = provider_configs.filter(site_id=site_id) + + self.stdout.write(self.style.SUCCESS("SAML Configuration Check Report")) + self.stdout.write("=" * 50) + self.stdout.write("") for provider_config in provider_configs: - if provider_config.saml_configuration: - try: - current_config = SAMLConfiguration.current( - provider_config.site_id, - provider_config.saml_configuration.slug - ) + total_providers += 1 + provider_info = ( + f"Provider (id={provider_config.id}, name={provider_config.name}, " + f"slug={provider_config.slug}, site_id={provider_config.site_id})" + ) + + if not provider_config.saml_configuration: + self.stdout.write( + f"[INFO] {provider_info} has no SAML configuration because " + "a matching default was not found." + ) + null_config_count += 1 + continue - if current_config and current_config.id != provider_config.saml_configuration_id: + try: + current_config = SAMLConfiguration.current( + provider_config.saml_configuration.site_id, + provider_config.saml_configuration.slug + ) + + # Check for outdated configuration references + if current_config: + if current_config.id != provider_config.saml_configuration_id: self.stdout.write( - f"Provider '{provider_config.slug}' (site {provider_config.site_id}) " - f"has outdated config (ID: {provider_config.saml_configuration_id} -> {current_config.id})" + f"[WARNING] {provider_info} " + f"has outdated SAML config (id={provider_config.saml_configuration_id} which " + f"should be updated to the current SAML config (id={current_config.id})." ) + outdated_count += 1 + + if provider_config.saml_configuration.site_id != provider_config.site_id: + config_site_id = provider_config.saml_configuration.site_id + provider_site_id = provider_config.site_id + self.stdout.write( + f"[WARNING] {provider_info} " + f"SAML config (id={provider_config.saml_configuration_id}, site_id={config_site_id}) " + "does not match the provider's site_id." + ) + site_mismatch_count += 1 - if not dry_run: - provider_config.saml_configuration = current_config - provider_config.save() - updated_count += 1 + saml_configuration_slug = provider_config.saml_configuration.slug + provider_config_slug = provider_config.slug - except Exception as e: # pylint: disable=broad-except - self.stderr.write( - f"Error processing provider '{provider_config.slug}': {e}" + if saml_configuration_slug not in (provider_config_slug, 'default'): + self.stdout.write( + f"[WARNING] {provider_info} " + f"SAML config (id={provider_config.saml_configuration_id}, slug='{saml_configuration_slug}') " + "does not match the provider's slug." ) - error_count += 1 + slug_mismatch_count += 1 + + except Exception as e: # pylint: disable=broad-except + self.stderr.write(f"[ERROR] Error processing {provider_info}: {e}") + error_count += 1 + + metrics = { + 'total_providers': {'count': total_providers, 'requires_attention': False}, + 'outdated_count': {'count': outdated_count, 'requires_attention': True}, + 'site_mismatch_count': {'count': site_mismatch_count, 'requires_attention': True}, + 'slug_mismatch_count': {'count': slug_mismatch_count, 'requires_attention': True}, + 'null_config_count': {'count': null_config_count, 'requires_attention': False}, + 'error_count': {'count': error_count, 'requires_attention': True}, + } + + for key, metric_data in metrics.items(): + # .. custom_attribute_name: saml_management_command.{key} + # .. custom_attribute_description: Records metrics from SAML configuration checks. + set_custom_attribute(f'saml_management_command.{key}', metric_data['count']) + + return metrics + + def _report_check_summary(self, metrics): + """ + Print a summary of the check results and set the total_requiring_attention custom attribute. + """ + total_requiring_attention = sum( + metric_data['count'] for metric_data in metrics.values() + if metric_data['requires_attention'] + ) - style = self.style.SUCCESS - if dry_run: - msg = f"[DRY RUN] Would update {updated_count} provider configurations. {error_count} errors encountered." + # .. custom_attribute_name: saml_management_command.total_requiring_attention + # .. custom_attribute_description: The total number of configuration issues requiring attention. + set_custom_attribute('saml_management_command.total_requiring_attention', total_requiring_attention) + + self.stdout.write(self.style.SUCCESS("CHECK SUMMARY:")) + self.stdout.write(f" Providers checked: {metrics['total_providers']['count']}") + self.stdout.write(f" Null configs: {metrics['null_config_count']['count']}") + + if total_requiring_attention > 0: + self.stdout.write("\nIssues requiring attention:") + self.stdout.write(f" Outdated: {metrics['outdated_count']['count']}") + self.stdout.write(f" Site mismatches: {metrics['site_mismatch_count']['count']}") + self.stdout.write(f" Slug mismatches: {metrics['slug_mismatch_count']['count']}") + self.stdout.write(f" Errors: {metrics['error_count']['count']}") + self.stdout.write(f"\nTotal issues requiring attention: {total_requiring_attention}") else: - msg = f"Updated {updated_count} provider configurations. {error_count} errors encountered." - self.stdout.write(style(msg)) + self.stdout.write(self.style.SUCCESS("\nNo configuration issues found!")) diff --git a/common/djangoapps/third_party_auth/management/commands/tests/test_saml.py b/common/djangoapps/third_party_auth/management/commands/tests/test_saml.py index 168d88ae3b21..6963d5dcd0d5 100644 --- a/common/djangoapps/third_party_auth/management/commands/tests/test_saml.py +++ b/common/djangoapps/third_party_auth/management/commands/tests/test_saml.py @@ -8,7 +8,7 @@ from io import StringIO from unittest import mock -from ddt import ddt, data, unpack +from ddt import ddt from django.contrib.sites.models import Site from django.core.management import call_command from django.core.management.base import CommandError @@ -18,8 +18,6 @@ from openedx.core.djangolib.testing.utils import CacheIsolationTestCase, skip_unless_lms from common.djangoapps.third_party_auth.tests.factories import SAMLConfigurationFactory, SAMLProviderConfigFactory -from common.djangoapps.third_party_auth.models import SAMLProviderConfig - def mock_get(status_code=200): """ @@ -64,6 +62,7 @@ def setUp(self): self.stdout = StringIO() self.site = Site.objects.get_current() + self.other_site = Site.objects.create(domain='other.example.com', name='Other Site') # We are creating SAMLConfiguration instance here so that there is always at-least one # disabled saml configuration instance, this is done to verify that disabled configurations are @@ -82,9 +81,9 @@ def setUp(self): metadata_source='https://www.testshib.org/metadata/testshib-providers.xml', ) - def _setup_test_configs_for_fix_references(self): + def _setup_test_configs_for_run_checks(self): """ - Helper method to create SAML configurations for fix-references tests. + Helper method to create SAML configurations for run-checks tests. Returns tuple of (old_config, new_config, provider_config) @@ -108,7 +107,7 @@ def _setup_test_configs_for_fix_references(self): entity_id='https://updated.example.com' ) - # Create a provider config that references the old config for fix-references tests + # Create a provider config that references the old config for run-checks tests test_provider_config = SAMLProviderConfigFactory.create( site=self.site, slug='test-provider', @@ -148,14 +147,10 @@ def test_raises_command_error_for_invalid_arguments(self): This test would fail with an error if ValueError is raised. """ - # Call `saml` command without any argument so that it raises a CommandError - with self.assertRaisesMessage(CommandError, "Command must be used with '--pull' or '--fix-references' option."): + # Call `saml` command without any arguments so that it raises a CommandError + with self.assertRaisesMessage(CommandError, "Command must be used with '--pull' or '--run-checks' option."): call_command("saml") - # Call `saml` command without any argument so that it raises a CommandError - with self.assertRaisesMessage(CommandError, "Command must be used with '--pull' or '--fix-references' option."): - call_command("saml", pull=False) - def test_no_saml_configuration(self): """ Test that management command completes without errors and logs correct information when no @@ -334,59 +329,144 @@ def test_xml_parse_exceptions(self, mocked_get): call_command("saml", pull=True, stdout=self.stdout) assert expected in self.stdout.getvalue() - @data( - (True, '[DRY RUN]', 'should not update provider configs'), - (False, '', 'should create new provider config for new version') - ) - @unpack - def test_fix_references(self, dry_run, expected_output_marker, test_description): + def _run_checks_command(self): """ - Test the --fix-references command with and without --dry-run option. + Helper method to run the --run-checks command and return output. + """ + out = StringIO() + call_command('saml', '--run-checks', stdout=out) + return out.getvalue() - Args: - dry_run (bool): Whether to run with --dry-run flag - expected_output_marker (str): Expected marker in output - test_description (str): Description of what the test should do + @mock.patch('common.djangoapps.third_party_auth.management.commands.saml.set_custom_attribute') + def test_run_checks_outdated_configs(self, mock_set_custom_attribute): """ - old_config, new_config, test_provider_config = self._setup_test_configs_for_fix_references() - new_config_id = new_config.id - original_config_id = old_config.id + Test the --run-checks command identifies outdated configurations. + """ + old_config, new_config, test_provider_config = self._setup_test_configs_for_run_checks() - out = StringIO() - if dry_run: - call_command('saml', '--fix-references', '--dry-run', stdout=out) - else: - call_command('saml', '--fix-references', stdout=out) + output = self._run_checks_command() + + self.assertIn('[WARNING]', output) + self.assertIn('test-provider', output) + self.assertIn( + f'id={old_config.id} which should be updated to the current SAML config (id={new_config.id})', + output + ) + self.assertIn('CHECK SUMMARY:', output) + self.assertIn('Providers checked: 2', output) + self.assertIn('Outdated: 1', output) + + # Check key observability calls + expected_calls = [ + mock.call('saml_management_command.operation', 'run_checks'), + mock.call('saml_management_command.total_providers', 2), + mock.call('saml_management_command.outdated_count', 1), + mock.call('saml_management_command.site_mismatch_count', 0), + mock.call('saml_management_command.slug_mismatch_count', 1), + mock.call('saml_management_command.null_config_count', 1), + mock.call('saml_management_command.error_count', 0), + mock.call('saml_management_command.total_requiring_attention', 2), + ] + mock_set_custom_attribute.assert_has_calls(expected_calls, any_order=False) + + @mock.patch('common.djangoapps.third_party_auth.management.commands.saml.set_custom_attribute') + def test_run_checks_site_mismatches(self, mock_set_custom_attribute): + """ + Test the --run-checks command identifies site ID mismatches. + """ + config = SAMLConfigurationFactory.create( + site=self.other_site, + slug='test-config', + entity_id='https://example.com' + ) + + SAMLProviderConfigFactory.create( + site=self.site, + slug='test-provider', + saml_configuration=config + ) - output = out.getvalue() + output = self._run_checks_command() + self.assertIn('[WARNING]', output) self.assertIn('test-provider', output) - if expected_output_marker: - self.assertIn(expected_output_marker, output) - - test_provider_config.refresh_from_db() - - if dry_run: - # For dry run, ensure the provider config was NOT updated - self.assertEqual( - test_provider_config.saml_configuration_id, - original_config_id, - "Provider config should not be updated in dry run mode" - ) - else: - # For actual run, check that a new provider config was created - new_provider = SAMLProviderConfig.objects.filter( - site=self.site, - slug='test-provider', - saml_configuration_id=new_config_id - ).exclude(id=test_provider_config.id).first() - - self.assertIsNotNone(new_provider, "New provider config should be created") - self.assertEqual(new_provider.saml_configuration_id, new_config_id) - - # Original provider config should still reference the old config - self.assertEqual( - test_provider_config.saml_configuration_id, - original_config_id, - "Original provider config should still reference old config" - ) + self.assertIn('does not match the provider\'s site_id', output) + + # Check observability calls + expected_calls = [ + mock.call('saml_management_command.operation', 'run_checks'), + mock.call('saml_management_command.total_providers', 2), + mock.call('saml_management_command.outdated_count', 0), + mock.call('saml_management_command.site_mismatch_count', 1), + mock.call('saml_management_command.slug_mismatch_count', 1), + mock.call('saml_management_command.null_config_count', 1), + mock.call('saml_management_command.error_count', 0), + mock.call('saml_management_command.total_requiring_attention', 2), + ] + mock_set_custom_attribute.assert_has_calls(expected_calls, any_order=False) + + @mock.patch('common.djangoapps.third_party_auth.management.commands.saml.set_custom_attribute') + def test_run_checks_slug_mismatches(self, mock_set_custom_attribute): + """ + Test the --run-checks command identifies slug mismatches. + """ + config = SAMLConfigurationFactory.create( + site=self.site, + slug='config-slug', + entity_id='https://example.com' + ) + + SAMLProviderConfigFactory.create( + site=self.site, + slug='provider-slug', + saml_configuration=config + ) + + output = self._run_checks_command() + + self.assertIn('[WARNING]', output) + self.assertIn('provider-slug', output) + self.assertIn('does not match the provider\'s slug', output) + + # Check observability calls + expected_calls = [ + mock.call('saml_management_command.operation', 'run_checks'), + mock.call('saml_management_command.total_providers', 2), + mock.call('saml_management_command.outdated_count', 0), + mock.call('saml_management_command.site_mismatch_count', 0), + mock.call('saml_management_command.slug_mismatch_count', 1), + mock.call('saml_management_command.null_config_count', 1), + mock.call('saml_management_command.error_count', 0), + mock.call('saml_management_command.total_requiring_attention', 1), + ] + mock_set_custom_attribute.assert_has_calls(expected_calls, any_order=False) + + @mock.patch('common.djangoapps.third_party_auth.management.commands.saml.set_custom_attribute') + def test_run_checks_null_configurations(self, mock_set_custom_attribute): + """ + Test the --run-checks command identifies providers with null configurations. + """ + SAMLProviderConfigFactory.create( + site=self.site, + slug='null-provider', + saml_configuration=None + ) + + output = self._run_checks_command() + + self.assertIn('[INFO]', output) + self.assertIn('null-provider', output) + self.assertIn('has no SAML configuration because a matching default was not found', output) + + # Check observability calls + expected_calls = [ + mock.call('saml_management_command.operation', 'run_checks'), + mock.call('saml_management_command.total_providers', 2), + mock.call('saml_management_command.outdated_count', 0), + mock.call('saml_management_command.site_mismatch_count', 0), + mock.call('saml_management_command.slug_mismatch_count', 0), + mock.call('saml_management_command.null_config_count', 2), + mock.call('saml_management_command.error_count', 0), + mock.call('saml_management_command.total_requiring_attention', 0), + ] + mock_set_custom_attribute.assert_has_calls(expected_calls, any_order=False) diff --git a/common/djangoapps/third_party_auth/signals/tests/test_handlers.py b/common/djangoapps/third_party_auth/signals/tests/test_handlers.py index 7875f0fcfa57..1dd2fd6d7d01 100644 --- a/common/djangoapps/third_party_auth/signals/tests/test_handlers.py +++ b/common/djangoapps/third_party_auth/signals/tests/test_handlers.py @@ -153,9 +153,12 @@ def test_saml_provider_config_updates(self, provider_site_id, provider_slug, current_provider = self._get_current_provider(provider_slug) - mock_set_custom_attribute.assert_any_call('saml_config_signal.enabled', True) - mock_set_custom_attribute.assert_any_call('saml_config_signal.new_config_id', new_saml_config.id) - mock_set_custom_attribute.assert_any_call('saml_config_signal.slug', signal_saml_slug) + expected_calls = [ + call('saml_config_signal.enabled', True), + call('saml_config_signal.new_config_id', new_saml_config.id), + call('saml_config_signal.slug', signal_saml_slug), + ] + mock_set_custom_attribute.assert_has_calls(expected_calls, any_order=False) if is_provider_updated: mock_set_custom_attribute.assert_any_call('saml_config_signal.updated_count', 1)