diff --git a/src/main/java/com/xiaomi/infra/pegasus/rpc/async/ReplicaSession.java b/src/main/java/com/xiaomi/infra/pegasus/rpc/async/ReplicaSession.java index a3ed99a3..667d3e31 100644 --- a/src/main/java/com/xiaomi/infra/pegasus/rpc/async/ReplicaSession.java +++ b/src/main/java/com/xiaomi/infra/pegasus/rpc/async/ReplicaSession.java @@ -385,6 +385,24 @@ public void run() { TimeUnit.MILLISECONDS); } + public void onAuthSucceed() { + Queue swappedPendingSend = new LinkedList<>(); + synchronized (authPendingSend) { + authSucceed = true; + swappedPendingSend.addAll(authPendingSend); + authPendingSend.clear(); + } + + while (!swappedPendingSend.isEmpty()) { + RequestEntry e = swappedPendingSend.poll(); + if (pendingResponse.get(e.sequenceId) != null) { + write(e, fields); + } else { + logger.info("{}: {} is removed from pending, perhaps timeout", name(), e.sequenceId); + } + } + } + // return value: // true - pend succeed // false - pend failed diff --git a/src/main/java/com/xiaomi/infra/pegasus/security/Negotiation.java b/src/main/java/com/xiaomi/infra/pegasus/security/Negotiation.java index fa6502d8..bb41ba6a 100644 --- a/src/main/java/com/xiaomi/infra/pegasus/security/Negotiation.java +++ b/src/main/java/com/xiaomi/infra/pegasus/security/Negotiation.java @@ -40,7 +40,7 @@ class Negotiation { private static final List expectedMechanisms = Collections.singletonList("GSSAPI"); private negotiation_status status; - private ReplicaSession session; + ReplicaSession session; SaslWrapper saslWrapper; Negotiation(ReplicaSession session, Subject subject, String serviceName, String serviceFQDN) { @@ -95,7 +95,7 @@ private void handleResponse() throws Exception { break; case SASL_INITIATE: case SASL_CHALLENGE_RESP: - // TBD(zlw): + onChallenge(resp); break; default: throw new Exception("unexpected negotiation status: " + resp.status); @@ -125,6 +125,21 @@ void onMechanismSelected(negotiation_response response) throws Exception { send(status, msg); } + void onChallenge(negotiation_response response) throws Exception { + switch (response.status) { + case SASL_CHALLENGE: + blob msg = saslWrapper.evaluateChallenge(response.msg.data); + status = negotiation_status.SASL_CHALLENGE_RESP; + send(status, msg); + break; + case SASL_SUCC: + negotiationSucceed(); + break; + default: + throw new Exception("receive wrong negotiation msg type" + response.status.toString()); + } + } + public String getMatchMechanism(String respString) { String matchMechanism = ""; String[] serverSupportMechanisms = respString.split(","); @@ -150,6 +165,11 @@ private void negotiationFailed() { session.closeSession(); } + private void negotiationSucceed() { + status = negotiation_status.SASL_SUCC; + session.onAuthSucceed(); + } + negotiation_status getStatus() { return status; } diff --git a/src/main/java/com/xiaomi/infra/pegasus/security/SaslWrapper.java b/src/main/java/com/xiaomi/infra/pegasus/security/SaslWrapper.java index 1648afdf..a76ce87b 100644 --- a/src/main/java/com/xiaomi/infra/pegasus/security/SaslWrapper.java +++ b/src/main/java/com/xiaomi/infra/pegasus/security/SaslWrapper.java @@ -67,4 +67,12 @@ public blob getInitialResponse() throws PrivilegedActionException { } }); } + + // If a challenge is received from the server during the authentication process, + // this method is called to prepare an appropriate next response to submit to the server. + public blob evaluateChallenge(final byte[] data) throws PrivilegedActionException { + return Subject.doAs( + subject, + (PrivilegedExceptionAction) () -> new blob(saslClient.evaluateChallenge(data))); + } } diff --git a/src/test/java/com/xiaomi/infra/pegasus/security/NegotiationTest.java b/src/test/java/com/xiaomi/infra/pegasus/security/NegotiationTest.java index 9c51dd58..d08ac18b 100644 --- a/src/test/java/com/xiaomi/infra/pegasus/security/NegotiationTest.java +++ b/src/test/java/com/xiaomi/infra/pegasus/security/NegotiationTest.java @@ -23,6 +23,7 @@ import com.xiaomi.infra.pegasus.apps.negotiation_response; import com.xiaomi.infra.pegasus.apps.negotiation_status; import com.xiaomi.infra.pegasus.base.blob; +import com.xiaomi.infra.pegasus.rpc.async.ReplicaSession; import java.nio.charset.Charset; import javax.security.auth.Subject; import org.junit.Assert; @@ -134,4 +135,45 @@ public void testMechanismSelected() { response.status = negotiation_status.SASL_LIST_MECHANISMS; Assertions.assertThrows(Exception.class, () -> mockNegotiation.onMechanismSelected(response)); } + + @Test + public void testChallenge() { + Negotiation mockNegotiation = Mockito.spy(negotiation); + SaslWrapper mockSaslWrapper = Mockito.mock(SaslWrapper.class); + ReplicaSession mockSession = Mockito.mock(ReplicaSession.class); + mockNegotiation.saslWrapper = mockSaslWrapper; + mockNegotiation.session = mockSession; + + // mock operation + Mockito.doNothing().when(mockNegotiation).send(any(), any()); + Mockito.doNothing().when(mockNegotiation.session).onAuthSucceed(); + try { + Mockito.when(mockNegotiation.saslWrapper.evaluateChallenge(any())).thenReturn(new blob()); + } catch (Exception ex) { + Assert.fail(); + } + + // normal case + Assertions.assertDoesNotThrow( + () -> { + negotiation_response response = + new negotiation_response(negotiation_status.SASL_CHALLENGE, new blob(new byte[0])); + mockNegotiation.onChallenge(response); + Assert.assertEquals(mockNegotiation.getStatus(), negotiation_status.SASL_CHALLENGE_RESP); + + response = new negotiation_response(negotiation_status.SASL_SUCC, new blob(new byte[0])); + mockNegotiation.onChallenge(response); + Assert.assertEquals(mockNegotiation.getStatus(), negotiation_status.SASL_SUCC); + }); + + // deal with wrong response.status + Assertions.assertThrows( + Exception.class, + () -> { + negotiation_response response = + new negotiation_response( + negotiation_status.SASL_LIST_MECHANISMS, new blob(new byte[0])); + mockNegotiation.onChallenge(response); + }); + } }