diff --git a/identity-server/src/IdentityServer/Hosting/ServerSideSessionCleanupHost.cs b/identity-server/src/IdentityServer/Hosting/ServerSideSessionCleanupHost.cs index 76b152e36..3f8886a0b 100644 --- a/identity-server/src/IdentityServer/Hosting/ServerSideSessionCleanupHost.cs +++ b/identity-server/src/IdentityServer/Hosting/ServerSideSessionCleanupHost.cs @@ -11,155 +11,102 @@ namespace Microsoft.Extensions.DependencyInjection; /// -/// Helper to cleanup expired server side sessions. +/// Helper to clean up expired server side sessions. /// -public class ServerSideSessionCleanupHost : IHostedService +public class ServerSideSessionCleanupHost( + IServiceProvider serviceProvider, + IdentityServerOptions options, + ILogger logger) : BackgroundService { - private readonly IServiceProvider _serviceProvider; - private readonly IdentityServerOptions _options; - private readonly ILogger _logger; - - private CancellationTokenSource _source; - - /// - /// Constructor for ServerSideSessionCleanupHost. - /// - /// - /// - /// - public ServerSideSessionCleanupHost(IServiceProvider serviceProvider, IdentityServerOptions options, ILogger logger) + /// + public override Task StartAsync(CancellationToken cancellationToken) => + !options.ServerSideSessions.RemoveExpiredSessions + ? Task.CompletedTask + : base.StartAsync(cancellationToken); + + /// + protected override async Task ExecuteAsync(CancellationToken stoppingToken) { - _serviceProvider = serviceProvider ?? throw new ArgumentNullException(nameof(serviceProvider)); - _options = options ?? throw new ArgumentNullException(nameof(options)); - _logger = logger; - } - - /// - /// Starts the token cleanup polling. - /// - public Task StartAsync(CancellationToken cancellationToken) - { - if (_options.ServerSideSessions.RemoveExpiredSessions) - { - if (_source != null) - { - throw new InvalidOperationException("Already started. Call Stop first."); - } - - _logger.LogDebug("Starting server-side session removal"); - - _source = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); - - Task.Factory.StartNew(() => StartInternalAsync(_source.Token), cancellationToken, TaskCreationOptions.None, TaskScheduler.Default); - } - - return Task.CompletedTask; - } - - /// - /// Stops the token cleanup polling. - /// - public async Task StopAsync(CancellationToken cancellationToken) - { - if (_options.ServerSideSessions.RemoveExpiredSessions) - { - if (_source == null) - { - throw new InvalidOperationException("Not started. Call Start first."); - } - - _logger.LogDebug("Stopping server-side session removal"); - - await _source.CancelAsync(); - _source = null; - } - } + logger.LogDebug("Starting server-side session removal"); - private async Task StartInternalAsync(CancellationToken cancellationToken) - { - var removalFrequencySeconds = (int)_options.ServerSideSessions.RemoveExpiredSessionsFrequency.TotalSeconds; + var removalFrequencySeconds = (int)options.ServerSideSessions.RemoveExpiredSessionsFrequency.TotalSeconds; // Start the first run at a random interval. - var delay = _options.ServerSideSessions.FuzzExpiredSessionRemovalStart + var delay = options.ServerSideSessions.FuzzExpiredSessionRemovalStart #pragma warning disable CA5394 // Randomness for security does not apply here ? TimeSpan.FromSeconds(Random.Shared.Next(removalFrequencySeconds)) #pragma warning restore CA5394 - : _options.ServerSideSessions.RemoveExpiredSessionsFrequency; + : options.ServerSideSessions.RemoveExpiredSessionsFrequency; - while (true) + while (!stoppingToken.IsCancellationRequested) { - if (cancellationToken.IsCancellationRequested) - { - _logger.LogDebug("CancellationRequested. Exiting."); - break; - } - try { - await Task.Delay(delay, cancellationToken); + await Task.Delay(delay, stoppingToken); } catch (TaskCanceledException) { - _logger.LogDebug("TaskCanceledException. Exiting."); + logger.LogDebug("TaskCanceledException. Exiting."); break; } catch (Exception ex) { - _logger.LogError("Task.Delay exception: {ExceptionMessage}. Exiting.", ex.Message); + logger.LogError("Task.Delay exception: {ExceptionMessage}. Exiting.", ex.Message); break; } - if (cancellationToken.IsCancellationRequested) + if (stoppingToken.IsCancellationRequested) { - _logger.LogDebug("CancellationRequested. Exiting."); break; } - await RunAsync(cancellationToken); + await RunAsync(stoppingToken); - delay = _options.ServerSideSessions.RemoveExpiredSessionsFrequency; + delay = options.ServerSideSessions.RemoveExpiredSessionsFrequency; } + + logger.LogDebug("Stopping server-side session removal"); } private async Task RunAsync(CancellationToken cancellationToken = default) { // this is here for testing - if (!_options.ServerSideSessions.RemoveExpiredSessions) + if (!options.ServerSideSessions.RemoveExpiredSessions) { return; } try { - await using (var serviceScope = _serviceProvider.GetRequiredService().CreateAsyncScope()) + await using var serviceScope = serviceProvider.GetRequiredService().CreateAsyncScope(); + var scopedLogger = serviceScope.ServiceProvider.GetRequiredService>(); + var scopedOptions = serviceScope.ServiceProvider.GetRequiredService(); + var serverSideTicketStore = serviceScope.ServiceProvider.GetRequiredService(); + var sessionCoordinationService = serviceScope.ServiceProvider.GetRequiredService(); + + var found = int.MaxValue; + + while (found > 0) { - var logger = serviceScope.ServiceProvider.GetRequiredService>(); - var options = serviceScope.ServiceProvider.GetRequiredService(); - var serverSideTicketStore = serviceScope.ServiceProvider.GetRequiredService(); - var sessionCoordinationService = serviceScope.ServiceProvider.GetRequiredService(); + var sessions = await serverSideTicketStore.GetAndRemoveExpiredSessionsAsync(scopedOptions.ServerSideSessions.RemoveExpiredSessionsBatchSize, cancellationToken); + found = sessions.Count; + + if (found <= 0) + { + continue; + } - var found = int.MaxValue; + scopedLogger.LogDebug("Processing expiration for {count} expired server-side sessions.", found); - while (found > 0) + foreach (var session in sessions) { - var sessions = await serverSideTicketStore.GetAndRemoveExpiredSessionsAsync(options.ServerSideSessions.RemoveExpiredSessionsBatchSize, cancellationToken); - found = sessions.Count; - - if (found > 0) - { - logger.LogDebug("Processing expiration for {count} expired server-side sessions.", found); - - foreach (var session in sessions) - { - await sessionCoordinationService.ProcessExpirationAsync(session); - } - } + await sessionCoordinationService.ProcessExpirationAsync(session); } } } catch (Exception ex) { - _logger.LogError(ex, "Exception removing expired sessions"); + logger.LogError(ex, "Exception removing expired sessions"); } } }