diff --git a/application/src/main/java/run/halo/app/security/device/DeviceServiceImpl.java b/application/src/main/java/run/halo/app/security/device/DeviceServiceImpl.java index c2ad61f178..001265a790 100644 --- a/application/src/main/java/run/halo/app/security/device/DeviceServiceImpl.java +++ b/application/src/main/java/run/halo/app/security/device/DeviceServiceImpl.java @@ -40,7 +40,7 @@ public class DeviceServiceImpl implements DeviceService { @Override public Mono loginSuccess(ServerWebExchange exchange, Authentication authentication) { - return updateExistingDevice(exchange) + return updateExistingDevice(exchange, authentication) .switchIfEmpty(createDevice(exchange, authentication) .flatMap(client::create) .doOnNext(device -> { @@ -61,10 +61,7 @@ public Mono changeSessionId(ServerWebExchange exchange) { .map(context -> context.getAuthentication().getName()) .flatMap(username -> { var deviceId = deviceIdCookie.getValue(); - return updateWithRetry(deviceId, device -> { - if (!device.getSpec().getPrincipalName().equals(username)) { - return Mono.empty(); - } + return updateWithRetry(deviceId, username, device -> { var oldSessionId = device.getSpec().getSessionId(); return exchange.getSession() .filter(session -> !session.getId().equals(oldSessionId)) @@ -78,9 +75,10 @@ public Mono changeSessionId(ServerWebExchange exchange) { }); } - private Mono updateWithRetry(String deviceId, + private Mono updateWithRetry(String deviceId, String username, Function> updateFunction) { return Mono.defer(() -> client.fetch(Device.class, deviceId) + .filter(device -> device.getSpec().getPrincipalName().equals(username)) .flatMap(updateFunction) .flatMap(client::update) ) @@ -88,38 +86,41 @@ private Mono updateWithRetry(String deviceId, .filter(OptimisticLockingFailureException.class::isInstance)); } - private Mono updateExistingDevice(ServerWebExchange exchange) { + private Mono updateExistingDevice(ServerWebExchange exchange, + Authentication authentication) { var deviceIdCookie = deviceCookieResolver.resolveCookie(exchange); if (deviceIdCookie == null) { return Mono.empty(); } - return updateWithRetry(deviceIdCookie.getValue(), (Device existingDevice) -> { - var sessionId = existingDevice.getSpec().getSessionId(); - return exchange.getSession() - .flatMap(session -> { - var userAgent = - exchange.getRequest().getHeaders().getFirst(HttpHeaders.USER_AGENT); - var deviceUa = existingDevice.getSpec().getUserAgent(); - if (!StringUtils.equals(deviceUa, userAgent)) { - // User agent changed, create a new device - return Mono.empty(); - } - return Mono.just(session); - }) - .flatMap(session -> { - if (session.getId().equals(sessionId)) { + var principalName = authentication.getName(); + return updateWithRetry(deviceIdCookie.getValue(), principalName, + (Device existingDevice) -> { + var sessionId = existingDevice.getSpec().getSessionId(); + return exchange.getSession() + .flatMap(session -> { + var userAgent = + exchange.getRequest().getHeaders().getFirst(HttpHeaders.USER_AGENT); + var deviceUa = existingDevice.getSpec().getUserAgent(); + if (!StringUtils.equals(deviceUa, userAgent)) { + // User agent changed, create a new device + return Mono.empty(); + } return Mono.just(session); - } - return sessionRepository.deleteById(sessionId).thenReturn(session); - }) - .map(session -> { - existingDevice.getSpec().setSessionId(session.getId()); - existingDevice.getSpec().setLastAccessedTime(session.getLastAccessTime()); - existingDevice.getSpec().setLastAuthenticatedTime(Instant.now()); - return existingDevice; - }) - .flatMap(this::removeRememberMeToken); - }); + }) + .flatMap(session -> { + if (session.getId().equals(sessionId)) { + return Mono.just(session); + } + return sessionRepository.deleteById(sessionId).thenReturn(session); + }) + .map(session -> { + existingDevice.getSpec().setSessionId(session.getId()); + existingDevice.getSpec().setLastAccessedTime(session.getLastAccessTime()); + existingDevice.getSpec().setLastAuthenticatedTime(Instant.now()); + return existingDevice; + }) + .flatMap(this::removeRememberMeToken); + }); } @Override