Skip to content

Commit

Permalink
Add rate limiter for login endpoint (#4062)
Browse files Browse the repository at this point in the history
#### What type of PR is this?

/kind feature
/area core

#### What this PR does / why we need it:

This PR introduces https://github.com/resilience4j/resilience4j to archive the feature. The login endpoint has limited login failures at a rate of 3 per minute.

See #4044 for more.

#### Which issue(s) this PR fixes:

Fixes #4044

#### Special notes for your reviewer:

1. Start Halo.
2. Try to login with incorrect credential 4 times
3. Check the response.

#### Does this PR introduce a user-facing change?

```release-note
增加登录失败次数限制功能
```
  • Loading branch information
JohnNiang authored Jun 16, 2023
1 parent 350e54d commit 02369fb
Show file tree
Hide file tree
Showing 12 changed files with 237 additions and 84 deletions.
3 changes: 3 additions & 0 deletions api/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@ dependencies {
api "com.github.java-json-tools:json-patch"
api "org.thymeleaf.extras:thymeleaf-extras-springsecurity6"

api "io.github.resilience4j:resilience4j-spring-boot3"
api "io.github.resilience4j:resilience4j-reactor"

runtimeOnly 'io.r2dbc:r2dbc-h2'
runtimeOnly 'org.postgresql:postgresql'
runtimeOnly 'org.postgresql:r2dbc-postgresql'
Expand Down
11 changes: 0 additions & 11 deletions application/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -46,17 +46,6 @@ tasks.named('jar') {
enabled = false
}

ext {
commonsLang3 = "3.12.0"
base62 = "0.1.3"
pf4j = '3.9.0'
javaDiffUtils = "4.12"
jsoup = '1.15.3'
jsonPatch = "1.13"
springDocOpenAPI = "2.0.2"
lucene = "9.5.0"
}

dependencies {
implementation project(':api')

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package run.halo.app.infra.exception;

public enum Exceptions {
;

public static final String THEME_ALREADY_EXISTS_TYPE =
"https://halo.run/probs/theme-alreay-exists";

public static final String INVALID_CREDENTIAL_TYPE =
"https://halo.run/probs/invalid-credential";

public static final String REQUEST_NOT_PERMITTED_TYPE =
"https://halo.run/probs/request-not-permitted";

}
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,6 @@
*/
public class ThemeAlreadyExistsException extends ServerWebInputException {

public static final String THEME_ALREADY_EXISTS_TYPE =
"https://halo.run/probs/theme-alreay-exists";

/**
* Constructs a {@code ThemeAlreadyExistsException} with the given theme name.
*
Expand All @@ -23,7 +20,7 @@ public class ThemeAlreadyExistsException extends ServerWebInputException {
public ThemeAlreadyExistsException(@NonNull String themeName) {
super("Theme already exists.", null, null, "problemDetail.theme.install.alreadyExists",
new Object[] {themeName});
setType(URI.create(THEME_ALREADY_EXISTS_TYPE));
setType(URI.create(Exceptions.THEME_ALREADY_EXISTS_TYPE));
getBody().setProperty("themeName", themeName);
}
}
Original file line number Diff line number Diff line change
@@ -1,21 +1,54 @@
package run.halo.app.infra.utils;

import org.apache.commons.lang3.StringUtils;
import org.springframework.http.HttpHeaders;
import lombok.extern.slf4j.Slf4j;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.web.reactive.function.server.ServerRequest;

/**
* Ip address utils.
* Code from internet.
*/
@Slf4j
public class IpAddressUtils {
private static final String UNKNOWN = "unknown";
private static final String X_REAL_IP = "X-Real-IP";
private static final String X_FORWARDED_FOR = "X-Forwarded-For";
private static final String PROXY_CLIENT_IP = "Proxy-Client-IP";
private static final String WL_PROXY_CLIENT_IP = "WL-Proxy-Client-IP";
private static final String HTTP_CLIENT_IP = "HTTP_CLIENT_IP";
private static final String HTTP_X_FORWARDED_FOR = "HTTP_X_FORWARDED_FOR";
public static final String UNKNOWN = "unknown";

private static final String[] IP_HEADER_NAMES = {
"X-Forwarded-For",
"Proxy-Client-IP",
"WL-Proxy-Client-IP",
"CF-Connecting-IP",
"HTTP_X_FORWARDED_FOR",
"HTTP_X_FORWARDED",
"HTTP_X_CLUSTER_CLIENT_IP",
"HTTP_CLIENT_IP",
"HTTP_FORWARDED_FOR",
"HTTP_FORWARDED",
"HTTP_VIA",
"REMOTE_ADDR",
};

/**
* Gets the IP address from request.
*
* @param request is server http request
* @return IP address if found, otherwise {@link #UNKNOWN}.
*/
public static String getClientIp(ServerHttpRequest request) {
for (String header : IP_HEADER_NAMES) {
String ipList = request.getHeaders().getFirst(header);
if (ipList != null && ipList.length() != 0 && !"unknown".equalsIgnoreCase(ipList)) {
String[] ips = ipList.trim().split("[,;]");
for (String ip : ips) {
if (ip != null && ip.length() != 0 && !"unknown".equalsIgnoreCase(ip)) {
return ip;
}
}
}
}
var remoteAddress = request.getRemoteAddress();
return remoteAddress == null ? UNKNOWN : remoteAddress.getAddress().getHostAddress();
}


/**
* Gets the ip address from request.
Expand All @@ -25,48 +58,11 @@ public class IpAddressUtils {
*/
public static String getIpAddress(ServerRequest request) {
try {
return getIpAddressInternal(request);
return getClientIp(request.exchange().getRequest());
} catch (Exception e) {
log.warn("Failed to obtain client IP, and fallback to unknown.", e);
return UNKNOWN;
}
}

private static String getIpAddressInternal(ServerRequest request) {
HttpHeaders httpHeaders = request.headers().asHttpHeaders();
String xrealIp = httpHeaders.getFirst(X_REAL_IP);
String xforwardedFor = httpHeaders.getFirst(X_FORWARDED_FOR);

if (StringUtils.isNotEmpty(xforwardedFor) && !UNKNOWN.equalsIgnoreCase(xforwardedFor)) {
// After multiple reverse proxies, there will be multiple IP values. The first IP is
// the real IP
int index = xforwardedFor.indexOf(",");
if (index != -1) {
return xforwardedFor.substring(0, index);
} else {
return xforwardedFor;
}
}
xforwardedFor = xrealIp;
if (StringUtils.isNotEmpty(xforwardedFor) && !UNKNOWN.equalsIgnoreCase(xforwardedFor)) {
return xforwardedFor;
}
if (StringUtils.isBlank(xforwardedFor) || UNKNOWN.equalsIgnoreCase(xforwardedFor)) {
xforwardedFor = httpHeaders.getFirst(PROXY_CLIENT_IP);
}
if (StringUtils.isBlank(xforwardedFor) || UNKNOWN.equalsIgnoreCase(xforwardedFor)) {
xforwardedFor = httpHeaders.getFirst(WL_PROXY_CLIENT_IP);
}
if (StringUtils.isBlank(xforwardedFor) || UNKNOWN.equalsIgnoreCase(xforwardedFor)) {
xforwardedFor = httpHeaders.getFirst(HTTP_CLIENT_IP);
}
if (StringUtils.isBlank(xforwardedFor) || UNKNOWN.equalsIgnoreCase(xforwardedFor)) {
xforwardedFor = httpHeaders.getFirst(HTTP_X_FORWARDED_FOR);
}
if (StringUtils.isBlank(xforwardedFor) || UNKNOWN.equalsIgnoreCase(xforwardedFor)) {
xforwardedFor = request.remoteAddress()
.map(remoteAddress -> remoteAddress.getAddress().getHostAddress())
.orElse(UNKNOWN);
}
return xforwardedFor;
}
}
Original file line number Diff line number Diff line change
@@ -1,12 +1,23 @@
package run.halo.app.security.authentication.login;

import static org.springframework.http.HttpStatus.TOO_MANY_REQUESTS;
import static org.springframework.http.HttpStatus.UNAUTHORIZED;
import static org.springframework.http.MediaType.APPLICATION_JSON;
import static run.halo.app.infra.exception.Exceptions.INVALID_CREDENTIAL_TYPE;
import static run.halo.app.infra.exception.Exceptions.REQUEST_NOT_PERMITTED_TYPE;
import static run.halo.app.security.authentication.WebExchangeMatchers.ignoringMediaTypeAll;

import io.github.resilience4j.ratelimiter.RateLimiterRegistry;
import io.github.resilience4j.ratelimiter.RequestNotPermitted;
import io.github.resilience4j.reactor.ratelimiter.operator.RateLimiterOperator;
import io.micrometer.observation.ObservationRegistry;
import java.util.Map;
import java.net.URI;
import java.time.Instant;
import java.util.Locale;
import lombok.extern.slf4j.Slf4j;
import org.springframework.context.MessageSource;
import org.springframework.http.HttpMethod;
import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType;
import org.springframework.lang.NonNull;
import org.springframework.security.authentication.ObservationReactiveAuthenticationManager;
import org.springframework.security.authentication.ReactiveAuthenticationManager;
Expand All @@ -28,10 +39,12 @@
import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher;
import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatchers;
import org.springframework.stereotype.Component;
import org.springframework.web.ErrorResponse;
import org.springframework.web.reactive.function.server.ServerResponse;
import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.server.WebFilterChain;
import reactor.core.publisher.Mono;
import run.halo.app.infra.utils.IpAddressUtils;
import run.halo.app.security.AdditionalWebFilter;

/**
Expand All @@ -40,6 +53,7 @@
* @author guqing
* @since 2.4.0
*/
@Slf4j
@Component
public class UsernamePasswordAuthenticator implements AdditionalWebFilter {

Expand All @@ -59,17 +73,23 @@ public class UsernamePasswordAuthenticator implements AdditionalWebFilter {

private final AuthenticationWebFilter authenticationWebFilter;

private final RateLimiterRegistry rateLimiterRegistry;
private final MessageSource messageSource;

public UsernamePasswordAuthenticator(ServerResponse.Context context,
ObservationRegistry observationRegistry, ReactiveUserDetailsService userDetailsService,
ReactiveUserDetailsPasswordService passwordService, PasswordEncoder passwordEncoder,
ServerSecurityContextRepository securityContextRepository, CryptoService cryptoService) {
ServerSecurityContextRepository securityContextRepository, CryptoService cryptoService,
RateLimiterRegistry rateLimiterRegistry, MessageSource messageSource) {
this.context = context;
this.observationRegistry = observationRegistry;
this.userDetailsService = userDetailsService;
this.passwordService = passwordService;
this.passwordEncoder = passwordEncoder;
this.securityContextRepository = securityContextRepository;
this.cryptoService = cryptoService;
this.rateLimiterRegistry = rateLimiterRegistry;
this.messageSource = messageSource;

this.authenticationWebFilter = new AuthenticationWebFilter(authenticationManager());
configureAuthenticationWebFilter(this.authenticationWebFilter);
Expand All @@ -91,7 +111,8 @@ void configureAuthenticationWebFilter(AuthenticationWebFilter filter) {
filter.setRequiresAuthenticationMatcher(requiresMatcher);
filter.setAuthenticationFailureHandler(new LoginFailureHandler());
filter.setAuthenticationSuccessHandler(new LoginSuccessHandler());
filter.setServerAuthenticationConverter(new LoginAuthenticationConverter(cryptoService));
filter.setServerAuthenticationConverter(new LoginAuthenticationConverter(cryptoService
));
filter.setSecurityContextRepository(securityContextRepository);
}

Expand All @@ -102,6 +123,62 @@ ReactiveAuthenticationManager authenticationManager() {
return new ObservationReactiveAuthenticationManager(observationRegistry, manager);
}


private <T> RateLimiterOperator<T> createIPBasedRateLimiter(ServerWebExchange exchange) {
var clientIp = IpAddressUtils.getClientIp(exchange.getRequest());
var rateLimiter =
rateLimiterRegistry.rateLimiter("authentication-from-ip-" + clientIp,
"authentication");
if (log.isDebugEnabled()) {
var metrics = rateLimiter.getMetrics();
log.debug(
"Authentication with Rate Limiter: {}, available permissions: {}, number of "
+ "waiting threads: {}",
rateLimiter, metrics.getAvailablePermissions(),
metrics.getNumberOfWaitingThreads());
}
return RateLimiterOperator.of(rateLimiter);
}

private Mono<Void> handleRequestNotPermitted(RequestNotPermitted e,
ServerWebExchange exchange) {
var errorResponse =
createErrorResponse(e, TOO_MANY_REQUESTS, REQUEST_NOT_PERMITTED_TYPE, exchange);
return writeErrorResponse(errorResponse, exchange);
}

private Mono<Void> handleAuthenticationException(AuthenticationException exception,
ServerWebExchange exchange) {
var errorResponse =
createErrorResponse(exception, UNAUTHORIZED, INVALID_CREDENTIAL_TYPE, exchange);
return writeErrorResponse(errorResponse, exchange);
}

private ErrorResponse createErrorResponse(Throwable t, HttpStatus status, String type,
ServerWebExchange exchange) {
var errorResponse =
ErrorResponse.create(t, status, t.getMessage());
var problemDetail = errorResponse.updateAndGetBody(messageSource, getLocale(exchange));
problemDetail.setType(URI.create(type));
problemDetail.setInstance(exchange.getRequest().getURI());
problemDetail.setProperty("requestId", exchange.getRequest().getId());
problemDetail.setProperty("timestamp", Instant.now());
return errorResponse;
}

private Mono<Void> writeErrorResponse(ErrorResponse errorResponse,
ServerWebExchange exchange) {
return ServerResponse.status(errorResponse.getStatusCode())
.contentType(APPLICATION_JSON)
.bodyValue(errorResponse.getBody())
.flatMap(response -> response.writeTo(exchange, context));
}

private Locale getLocale(ServerWebExchange exchange) {
var locale = exchange.getLocaleContext().getLocale();
return locale == null ? Locale.getDefault() : locale;
}

public class LoginSuccessHandler implements ServerAuthenticationSuccessHandler {

private final ServerAuthenticationSuccessHandler defaultHandler =
Expand All @@ -110,8 +187,9 @@ public class LoginSuccessHandler implements ServerAuthenticationSuccessHandler {
@Override
public Mono<Void> onAuthenticationSuccess(WebFilterExchange webFilterExchange,
Authentication authentication) {
return ignoringMediaTypeAll(MediaType.APPLICATION_JSON)
.matches(webFilterExchange.getExchange())
var exchange = webFilterExchange.getExchange();
return ignoringMediaTypeAll(APPLICATION_JSON)
.matches(exchange)
.filter(ServerWebExchangeMatcher.MatchResult::isMatch)
.switchIfEmpty(
defaultHandler.onAuthenticationSuccess(webFilterExchange, authentication)
Expand All @@ -124,11 +202,14 @@ public Mono<Void> onAuthenticationSuccess(WebFilterExchange webFilterExchange,
}

return ServerResponse.ok()
.contentType(MediaType.APPLICATION_JSON)
.contentType(APPLICATION_JSON)
.bodyValue(principal)
.flatMap(serverResponse ->
serverResponse.writeTo(webFilterExchange.getExchange(), context));
});
serverResponse.writeTo(exchange, context));
})
.transformDeferred(createIPBasedRateLimiter(exchange))
.onErrorResume(RequestNotPermitted.class,
e -> handleRequestNotPermitted(e, exchange));
}
}

Expand All @@ -142,21 +223,25 @@ public class LoginFailureHandler implements ServerAuthenticationFailureHandler {
private final ServerAuthenticationFailureHandler defaultHandler =
new RedirectServerAuthenticationFailureHandler("/console?error#/login");

public LoginFailureHandler() {
}

@Override
public Mono<Void> onAuthenticationFailure(WebFilterExchange webFilterExchange,
AuthenticationException exception) {
return ignoringMediaTypeAll(MediaType.APPLICATION_JSON).matches(
webFilterExchange.getExchange())
var exchange = webFilterExchange.getExchange();
return ignoringMediaTypeAll(APPLICATION_JSON)
.matches(exchange)
.filter(ServerWebExchangeMatcher.MatchResult::isMatch)
.flatMap(matchResult -> ServerResponse.status(HttpStatus.UNAUTHORIZED)
.contentType(MediaType.APPLICATION_JSON)
.bodyValue(Map.of(
"error", exception.getLocalizedMessage()
))
.flatMap(serverResponse -> serverResponse.writeTo(
webFilterExchange.getExchange(), context)))
.switchIfEmpty(
defaultHandler.onAuthenticationFailure(webFilterExchange, exception));
.switchIfEmpty(defaultHandler.onAuthenticationFailure(webFilterExchange, exception)
// Skip the handleAuthenticationException.
.then(Mono.empty())
)
.flatMap(matchResult -> handleAuthenticationException(exception, exchange))
.transformDeferred(createIPBasedRateLimiter(exchange))
.onErrorResume(RequestNotPermitted.class,
e -> handleRequestNotPermitted(e, exchange));
}

}
}
Loading

0 comments on commit 02369fb

Please sign in to comment.