Skip to content

Commit

Permalink
💥 LoginSuccessHandler state for oauth #238
Browse files Browse the repository at this point in the history
  • Loading branch information
trydofor committed May 20, 2024
1 parent 6796e75 commit f6f6923
Show file tree
Hide file tree
Showing 6 changed files with 70 additions and 28 deletions.
2 changes: 1 addition & 1 deletion observe/docs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import lombok.Setter;
import lombok.extern.slf4j.Slf4j;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.security.core.Authentication;
Expand Down Expand Up @@ -40,23 +39,24 @@ public class LoginSuccessHandler extends NonceLoginSuccessHandler implements Ini

@Override
protected void onResponse(@NotNull HttpServletRequest req, @NotNull HttpServletResponse res, @NotNull Authentication aun,
@Nullable String sid, long uid, @Nullable String state) throws IOException, ServletException {
@NotNull State state) throws IOException, ServletException {

if (state != null && !state.isEmpty()) {
if (state.startsWith("/") || isSafeRedirect(state)) {
log.debug("redirect to {}", state);
res.sendRedirect(state);
String cts = state.getStateClient();
if (cts != null && !cts.isEmpty()) {
if (cts.startsWith("/") || isSafeRedirect(cts)) {
log.debug("redirect to {}", cts);
res.sendRedirect(cts);
}
else {
writeResponseBody(state, req, res, aun, sid, uid, state);
writeResponseBody(req, res, aun, state, cts);
}
}
else {
if (warlockSecurityProp.isLoginSuccessRedirect()) {
super.onResponse(req, res, aun, sid, uid, state);
super.onResponse(req, res, aun, state);
}
else {
writeResponseBody(warlockSecurityProp.getLoginSuccessBody(), req, res, aun, sid, uid, state);
writeResponseBody(req, res, aun, state, warlockSecurityProp.getLoginSuccessBody());
}
}
}
Expand All @@ -65,8 +65,8 @@ protected boolean isSafeRedirect(String state) {
return SafeHttpHelper.isSafeRedirect(state, warlockJustAuthProp.getSafeHost());
}

protected void writeResponseBody(@NotNull String body, @NotNull HttpServletRequest req, @NotNull HttpServletResponse res,
@NotNull Authentication aun, @Nullable String sid, long uid, @Nullable String state) {
protected void writeResponseBody(@NotNull HttpServletRequest req, @NotNull HttpServletResponse res, @NotNull Authentication aun,
@NotNull State state, @NotNull String body) {
ResponseHelper.writeBodyUtf8(res, body);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import jakarta.servlet.http.HttpSession;
import lombok.Data;
import lombok.Setter;
import lombok.extern.slf4j.Slf4j;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.security.core.Authentication;
import org.springframework.security.web.authentication.SavedRequestAwareAuthenticationSuccessHandler;
Expand All @@ -32,21 +32,49 @@ public class NonceLoginSuccessHandler extends SavedRequestAwareAuthenticationSuc

@Override
public final void onAuthenticationSuccess(HttpServletRequest request, HttpServletResponse response, Authentication authentication) throws ServletException, IOException {
final HttpSession session = request.getSession(false);
final State state = new State();

String sid = null;
final long uid = SecurityContextUtil.getUserId();
final String state = request.getParameter(AuthStateBuilder.ParamState);
state.setUserId(uid);

final String sts = request.getParameter(AuthStateBuilder.ParamState);
state.setStateOauth(sts);

final HttpSession session = request.getSession(false);
if (session != null) {
sid = session.getId();
if (state != null) {
NonceTokenSessionHelper.bindNonceSession(state, sid);
log.debug("parse client state={}, uid={}", state, uid);
String sid = session.getId();
state.setSessionId(sid);

if (sts != null) {
NonceTokenSessionHelper.bindNonceSession(sts, sid);
log.debug("parse oauth state={}, uid={}", sts, uid);
}
}
state.setStateClient(authStateBuilder.parseState(request));
onResponse(request, response, authentication, state);
}

@Data
public static class State {
/**
* SecurityContextUtil.getUserId()
*/
private long userId;

/**
* session id, null if no-login
*/
private String sessionId;

/**
* the state via oauth builder
*/
private String stateOauth;

onResponse(request, response, authentication, sid, uid, authStateBuilder.parseState(request));
/**
* the safe-state send by client
*/
private String stateClient;
}

/**
Expand All @@ -55,12 +83,10 @@ public final void onAuthenticationSuccess(HttpServletRequest request, HttpServle
* @param req HttpServletRequest
* @param res HttpServletResponse
* @param aun Authentication
* @param sid session id, null if no-login
* @param uid user id
* @param state The state set by the client contained in the oauth2 state
* @param state login state
*/
protected void onResponse(@NotNull HttpServletRequest req, @NotNull HttpServletResponse res, @NotNull Authentication aun,
@Nullable String sid, long uid, @Nullable String state) throws ServletException, IOException {
@NotNull State state) throws ServletException, IOException {
super.onAuthenticationSuccess(req, res, aun);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,16 @@ public Map<String, String[]> parseParam(HttpServletRequest request) {
return args;
}

/**
* <pre>
* parse client state, then merge them to safe format.
* e.g. given,
* (1) .safe-state[/order-list]={1}/#{0}
* (2) GET ?state=/order-list&state=http://localhost%3A8080
* then state=['/order-list', 'http://localhost:8080']
* and parseState returns http://localhost:8080/#/order-list
* </pre>
*/
@NotNull
public String parseState(HttpServletRequest request) {
final Map<String, String[]> map = parseParam(request);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,12 @@
import org.springframework.core.Ordered;
import org.springframework.security.authentication.InsufficientAuthenticationException;
import org.springframework.security.authentication.InternalAuthenticationServiceException;
import pro.fessional.mirana.best.AssertArgs;
import pro.fessional.mirana.flow.FlowEnum;
import pro.fessional.wings.silencer.spring.WingsOrdered;
import pro.fessional.wings.slardar.security.impl.ComboWingsAuthDetailsSource;
import pro.fessional.wings.slardar.security.impl.DefaultWingsAuthDetails;
import pro.fessional.wings.warlock.errcode.CommonErrorEnum;
import pro.fessional.wings.warlock.security.session.NonceTokenSessionHelper;

import java.util.ArrayList;
Expand Down Expand Up @@ -87,14 +89,15 @@ public class JustAuthRequestBuilder implements ComboWingsAuthDetailsSource.Combo
public DefaultWingsAuthDetails buildDetails(@NotNull Enum<?> authType, @NotNull HttpServletRequest request) {
AuthRequest ar = buildRequest(authType, request);
if (ar == null) return null;
AuthCallback callback = new AuthCallback();
final String state = request.getParameter("state");
AssertArgs.notEmpty(state, CommonErrorEnum.AssertNotFound1, "state");

AuthCallback callback = new AuthCallback();
callback.setAuth_code(request.getParameter("auth_code"));
callback.setAuthorization_code(request.getParameter("authorization_code"));
callback.setCode(request.getParameter("code"));
callback.setOauth_token(request.getParameter("oauth_token"));
callback.setOauth_verifier(request.getParameter("oauth_verifier"));
final String state = request.getParameter("state");
callback.setState(state);

try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ private static class Sf {
* Init one-time token
*/
public static void initNonce(String token, String ip) {
if (token == null) return;
if (token == null || token.isEmpty()) return;
final Sf s = new Sf();
s.ip = ip;
cache.put(token, s);
Expand All @@ -42,6 +42,7 @@ public static void initNonce(String token, String ip) {
* bind token to sessionId
*/
public static void bindNonceSession(String token, String sid) {
if (token == null || token.isEmpty()) return;
final SidData data = () -> sid;
final R<?> result = R.okData(data);
bindNonceResult(token, result);
Expand All @@ -51,6 +52,7 @@ public static void bindNonceSession(String token, String sid) {
* bind token to result
*/
public static void bindNonceResult(String token, R<?> result) {
if (token == null || token.isEmpty()) return;
final Sf s = cache.get(token);
if (s != null) {
s.result = result;
Expand All @@ -61,6 +63,7 @@ public static void bindNonceResult(String token, R<?> result) {
* invalid the token
*/
public static void invalidNonce(String token) {
if (token == null || token.isEmpty()) return;
cache.remove(token);
}

Expand Down

0 comments on commit f6f6923

Please sign in to comment.