Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
181 changes: 133 additions & 48 deletions common/djangoapps/third_party_auth/management/commands/saml.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import logging

from django.core.management.base import BaseCommand, CommandError
from edx_django_utils.monitoring import set_custom_attribute

from common.djangoapps.third_party_auth.tasks import fetch_saml_metadata
from common.djangoapps.third_party_auth.models import SAMLProviderConfig, SAMLConfiguration
Expand All @@ -18,34 +19,24 @@ class Command(BaseCommand):
def add_arguments(self, parser):
parser.add_argument('--pull', action='store_true', help="Pull updated metadata from external IDPs")
parser.add_argument(
'--fix-references',
'--run-checks',
action='store_true',
help="Fix SAMLProviderConfig references to use current SAMLConfiguration versions"
)
parser.add_argument(
'--site-id',
type=int,
help='Only fix configurations for a specific site ID (to be used with --fix-references)'
)
parser.add_argument(
'--dry-run',
action='store_true',
help='Show what would be changed, but do not make any changes.'
help="Run checks on SAMLProviderConfig configurations and report potential issues"
)

def handle(self, *args, **options):
should_pull_saml_metadata = options.get('pull', False)
should_fix_references = options.get('fix_references', False)
dry_run = options.get('dry_run', False)

if not should_pull_saml_metadata and not should_fix_references:
raise CommandError("Command must be used with '--pull' or '--fix-references' option.")
should_run_checks = options.get('run_checks', False)

if should_pull_saml_metadata:
self._handle_pull_metadata()
return

if should_fix_references:
self._handle_fix_references(options, dry_run=dry_run)
if should_run_checks:
self._handle_run_checks()
return

raise CommandError("Command must be used with '--pull' or '--run-checks' option.")

def _handle_pull_metadata(self):
"""
Expand Down Expand Up @@ -76,45 +67,139 @@ def _handle_pull_metadata(self):
)
)

def _handle_fix_references(self, options, dry_run=False):
"""Handle the --fix-references option for fixing outdated SAML configuration references."""
site_id = options.get('site_id')
updated_count = 0
def _handle_run_checks(self):
"""
Handle the --run-checks option for checking SAMLProviderConfig configuration issues.

This is a report-only command. It identifies potential configuration problems such as:
- Outdated SAMLConfiguration references (provider pointing to old config version)
- Site ID mismatches between SAMLProviderConfig and its SAMLConfiguration
- Slug mismatches (except 'default' slugs) # noqa: E501
- SAMLProviderConfig objects with null SAMLConfiguration references (informational)

Includes observability attributes for monitoring.
"""
# Set custom attributes for monitoring the check operation
# .. custom_attribute_name: saml_management_command.operation
# .. custom_attribute_description: Records current SAML operation ('run_checks').
set_custom_attribute('saml_management_command.operation', 'run_checks')

metrics = self._check_provider_configurations()
self._report_check_summary(metrics)

def _check_provider_configurations(self):
"""
Check each provider configuration for potential issues.
Returns a dictionary of metrics about the found issues.
"""
outdated_count = 0
site_mismatch_count = 0
slug_mismatch_count = 0
null_config_count = 0
error_count = 0
total_providers = 0

# Filter by site if specified
provider_configs = SAMLProviderConfig.objects.current_set()
if site_id:
provider_configs = provider_configs.filter(site_id=site_id)

self.stdout.write(self.style.SUCCESS("SAML Configuration Check Report"))
self.stdout.write("=" * 50)
self.stdout.write("")

for provider_config in provider_configs:
if provider_config.saml_configuration:
try:
current_config = SAMLConfiguration.current(
provider_config.site_id,
provider_config.saml_configuration.slug
)
total_providers += 1
provider_info = (
f"Provider (id={provider_config.id}, name={provider_config.name}, "
f"slug={provider_config.slug}, site_id={provider_config.site_id})"
)

if not provider_config.saml_configuration:
self.stdout.write(
f"[INFO] {provider_info} has no SAML configuration because "
"a matching default was not found."
)
null_config_count += 1
continue

if current_config and current_config.id != provider_config.saml_configuration_id:
try:
current_config = SAMLConfiguration.current(
provider_config.saml_configuration.site_id,
provider_config.saml_configuration.slug
)

# Check for outdated configuration references
if current_config:
if current_config.id != provider_config.saml_configuration_id:
self.stdout.write(
f"Provider '{provider_config.slug}' (site {provider_config.site_id}) "
f"has outdated config (ID: {provider_config.saml_configuration_id} -> {current_config.id})"
f"[WARNING] {provider_info} "
f"has outdated SAML config (id={provider_config.saml_configuration_id} which "
f"should be updated to the current SAML config (id={current_config.id})."
)
outdated_count += 1

if provider_config.saml_configuration.site_id != provider_config.site_id:
config_site_id = provider_config.saml_configuration.site_id
provider_site_id = provider_config.site_id
self.stdout.write(
f"[WARNING] {provider_info} "
f"SAML config (id={provider_config.saml_configuration_id}, site_id={config_site_id}) "
"does not match the provider's site_id."
)
site_mismatch_count += 1

if not dry_run:
provider_config.saml_configuration = current_config
provider_config.save()
updated_count += 1
saml_configuration_slug = provider_config.saml_configuration.slug
provider_config_slug = provider_config.slug

except Exception as e: # pylint: disable=broad-except
self.stderr.write(
f"Error processing provider '{provider_config.slug}': {e}"
if saml_configuration_slug not in (provider_config_slug, 'default'):
self.stdout.write(
f"[WARNING] {provider_info} "
f"SAML config (id={provider_config.saml_configuration_id}, slug='{saml_configuration_slug}') "
"does not match the provider's slug."
)
error_count += 1
slug_mismatch_count += 1

except Exception as e: # pylint: disable=broad-except
self.stderr.write(f"[ERROR] Error processing {provider_info}: {e}")
error_count += 1

metrics = {
'total_providers': {'count': total_providers, 'requires_attention': False},
'outdated_count': {'count': outdated_count, 'requires_attention': True},
'site_mismatch_count': {'count': site_mismatch_count, 'requires_attention': True},
'slug_mismatch_count': {'count': slug_mismatch_count, 'requires_attention': True},
'null_config_count': {'count': null_config_count, 'requires_attention': False},
'error_count': {'count': error_count, 'requires_attention': True},
}

for key, metric_data in metrics.items():
# .. custom_attribute_name: saml_management_command.{key}
# .. custom_attribute_description: Records metrics from SAML configuration checks.
set_custom_attribute(f'saml_management_command.{key}', metric_data['count'])

return metrics

def _report_check_summary(self, metrics):
"""
Print a summary of the check results and set the total_requiring_attention custom attribute.
"""
total_requiring_attention = sum(
metric_data['count'] for metric_data in metrics.values()
if metric_data['requires_attention']
)

style = self.style.SUCCESS
if dry_run:
msg = f"[DRY RUN] Would update {updated_count} provider configurations. {error_count} errors encountered."
# .. custom_attribute_name: saml_management_command.total_requiring_attention
# .. custom_attribute_description: The total number of configuration issues requiring attention.
set_custom_attribute('saml_management_command.total_requiring_attention', total_requiring_attention)

self.stdout.write(self.style.SUCCESS("CHECK SUMMARY:"))
self.stdout.write(f" Providers checked: {metrics['total_providers']['count']}")
self.stdout.write(f" Null configs: {metrics['null_config_count']['count']}")

if total_requiring_attention > 0:
self.stdout.write("\nIssues requiring attention:")
self.stdout.write(f" Outdated: {metrics['outdated_count']['count']}")
self.stdout.write(f" Site mismatches: {metrics['site_mismatch_count']['count']}")
self.stdout.write(f" Slug mismatches: {metrics['slug_mismatch_count']['count']}")
self.stdout.write(f" Errors: {metrics['error_count']['count']}")
self.stdout.write(f"\nTotal issues requiring attention: {total_requiring_attention}")
else:
msg = f"Updated {updated_count} provider configurations. {error_count} errors encountered."
self.stdout.write(style(msg))
self.stdout.write(self.style.SUCCESS("\nNo configuration issues found!"))
Loading
Loading