diff --git a/grpc-spring-boot-starter-demo/src/test/java/org/lognet/springboot/grpc/HalfCloseInterceptor.java b/grpc-spring-boot-starter-demo/src/test/java/org/lognet/springboot/grpc/HalfCloseInterceptor.java new file mode 100644 index 00000000..35dbb65b --- /dev/null +++ b/grpc-spring-boot-starter-demo/src/test/java/org/lognet/springboot/grpc/HalfCloseInterceptor.java @@ -0,0 +1,25 @@ +package org.lognet.springboot.grpc; + +import io.grpc.ForwardingServerCallListener; +import io.grpc.Metadata; +import io.grpc.ServerCall; +import io.grpc.ServerCallHandler; +import io.grpc.ServerInterceptor; + +@GRpcGlobalInterceptor +public class HalfCloseInterceptor implements ServerInterceptor { + @Override + public ServerCall.Listener interceptCall( + ServerCall call, Metadata headers, ServerCallHandler next + ) { + return new ForwardingServerCallListener.SimpleForwardingServerCallListener(next.startCall(call, headers)) { + @Override + public void onHalfClose() { + HalfCloseInterceptor.this.onHalfClose(); + super.onHalfClose(); + } + }; + } + + public void onHalfClose() {} +} diff --git a/grpc-spring-boot-starter-demo/src/test/java/org/lognet/springboot/grpc/ValidationTest.java b/grpc-spring-boot-starter-demo/src/test/java/org/lognet/springboot/grpc/ValidationTest.java index 6b818329..44d69b98 100644 --- a/grpc-spring-boot-starter-demo/src/test/java/org/lognet/springboot/grpc/ValidationTest.java +++ b/grpc-spring-boot-starter-demo/src/test/java/org/lognet/springboot/grpc/ValidationTest.java @@ -1,33 +1,39 @@ package org.lognet.springboot.grpc; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.emptyOrNullString; +import static org.junit.Assert.assertThrows; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; +import static org.springframework.boot.test.context.SpringBootTest.WebEnvironment.NONE; + +import java.util.Locale; + import io.grpc.Metadata; import io.grpc.Status; import io.grpc.StatusRuntimeException; import io.grpc.examples.GreeterGrpc; import io.grpc.examples.GreeterOuterClass; import org.hamcrest.Matchers; +import org.junit.AfterClass; import org.junit.Before; +import org.junit.BeforeClass; import org.junit.Test; import org.junit.runner.RunWith; import org.lognet.springboot.grpc.demo.DemoApp; import org.springframework.boot.test.context.SpringBootTest; import org.springframework.boot.test.context.TestConfiguration; +import org.springframework.boot.test.mock.mockito.SpyBean; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Import; import org.springframework.test.context.ActiveProfiles; import org.springframework.test.context.junit4.SpringRunner; -import static org.hamcrest.MatcherAssert.assertThat; -import static org.hamcrest.Matchers.emptyOrNullString; -import static org.junit.Assert.assertThrows; -import static org.springframework.boot.test.context.SpringBootTest.WebEnvironment.NONE; - @RunWith(SpringRunner.class) @SpringBootTest(classes = {DemoApp.class}, webEnvironment = NONE, properties = {"grpc.port=0"}) @Import(ValidationTest.TestCfg.class) @ActiveProfiles("disable-security") public class ValidationTest extends GrpcServerTestBase { - @TestConfiguration static class TestCfg { @Bean @@ -43,7 +49,21 @@ public Status handle(Object message, Status status, Exception exception, Metadat } private GreeterGrpc.GreeterBlockingStub stub; + @SpyBean + HalfCloseInterceptor halfCloseInterceptor; + private static Locale systemDefaultLocale; + + @BeforeClass + public static void setLocaleToEnglish() { + systemDefaultLocale = Locale.getDefault(); + Locale.setDefault(Locale.ENGLISH); + } + + @AfterClass + public static void resetDefaultLocale() { + Locale.setDefault(systemDefaultLocale); + } @Before public void setUp() throws Exception { @@ -147,7 +167,6 @@ public void validMessageValidationTest() { @Test public void invalidResponseMessageValidationTest() { StatusRuntimeException e = assertThrows(StatusRuntimeException.class, () -> { - stub.helloPersonInvalidResponse(GreeterOuterClass.Person.newBuilder() .setAge(3)//valid .setName("Dexter")//valid @@ -164,6 +183,17 @@ public void invalidResponseMessageValidationTest() { } + @Test + public void noHalfCloseAfterFailedValidation() { + StatusRuntimeException e = assertThrows(StatusRuntimeException.class, () -> { + stub.helloPersonValidResponse(GreeterOuterClass.Person.newBuilder() + .setAge(49)// valid + .clearName() //invalid + .build()); + }); + assertThat(e.getStatus().getCode(), Matchers.is(Status.Code.INVALID_ARGUMENT)); + verify(halfCloseInterceptor, never()).onHalfClose(); + } String getFieldName(int fieldNumber) { return GreeterOuterClass.Person.getDescriptor().findFieldByNumber(fieldNumber).getName(); diff --git a/grpc-spring-boot-starter-demo/src/test/java/org/lognet/springboot/grpc/auth/FailLateSecurityInterceptorTest.java b/grpc-spring-boot-starter-demo/src/test/java/org/lognet/springboot/grpc/auth/FailLateSecurityInterceptorTest.java new file mode 100644 index 00000000..d6a388ae --- /dev/null +++ b/grpc-spring-boot-starter-demo/src/test/java/org/lognet/springboot/grpc/auth/FailLateSecurityInterceptorTest.java @@ -0,0 +1,41 @@ +package org.lognet.springboot.grpc.auth; + + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.Assert.assertThrows; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; + +import com.google.protobuf.Empty; +import io.grpc.Status; +import io.grpc.StatusRuntimeException; +import io.grpc.examples.SecuredGreeterGrpc; +import org.hamcrest.Matchers; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.lognet.springboot.grpc.GrpcServerTestBase; +import org.lognet.springboot.grpc.HalfCloseInterceptor; +import org.lognet.springboot.grpc.demo.DemoApp; +import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.boot.test.mock.mockito.SpyBean; +import org.springframework.test.context.junit4.SpringRunner; + +@SpringBootTest( + classes = DemoApp.class, + properties = "grpc.security.auth.fail-fast=false" +) +@RunWith(SpringRunner.class) +public class FailLateSecurityInterceptorTest extends GrpcServerTestBase { + @SpyBean + HalfCloseInterceptor halfCloseInterceptor; + + @Test + public void noHalfCloseOnFailedAuth() { + final StatusRuntimeException statusRuntimeException = assertThrows( + StatusRuntimeException.class, + () -> SecuredGreeterGrpc.newBlockingStub(selectedChanel).sayAuthHello2(Empty.newBuilder().build()).getMessage() + ); + assertThat(statusRuntimeException.getStatus().getCode(), Matchers.is(Status.Code.UNAUTHENTICATED)); + verify(halfCloseInterceptor, never()).onHalfClose(); + } +} diff --git a/grpc-spring-boot-starter/src/main/java/org/lognet/springboot/grpc/FailureHandlingServerInterceptor.java b/grpc-spring-boot-starter/src/main/java/org/lognet/springboot/grpc/FailureHandlingServerInterceptor.java index e67893a1..1df1bb1f 100644 --- a/grpc-spring-boot-starter/src/main/java/org/lognet/springboot/grpc/FailureHandlingServerInterceptor.java +++ b/grpc-spring-boot-starter/src/main/java/org/lognet/springboot/grpc/FailureHandlingServerInterceptor.java @@ -1,10 +1,10 @@ package org.lognet.springboot.grpc; +import io.grpc.ForwardingServerCallListener; import io.grpc.Metadata; import io.grpc.ServerCall; import io.grpc.ServerInterceptor; import io.grpc.Status; -import io.grpc.StatusRuntimeException; public interface FailureHandlingServerInterceptor extends ServerInterceptor { default void closeCall(Object o, GRpcErrorHandler errorHandler, ServerCall call, Metadata headers, final Status status, Exception exception){ @@ -14,4 +14,25 @@ default void closeCall(Object o, GRpcErrorHandler errorHandler, ServerCall extends ForwardingServerCallListener.SimpleForwardingServerCallListener { + private volatile boolean messageBlocked = false; + + public MessageBlockingServerCallListener(ServerCall.Listener delegate) { + super(delegate); + } + + @Override + public void onHalfClose() { + // If the message was blocked, downstream never had a chance to react to it. Hence, the half-close signal would look like + // an error to them. So we do not propagate the signal in that case. + if (!messageBlocked) { + super.onHalfClose(); + } + } + + protected void blockMessage() { + messageBlocked = true; + } + } } diff --git a/grpc-spring-boot-starter/src/main/java/org/lognet/springboot/grpc/security/SecurityInterceptor.java b/grpc-spring-boot-starter/src/main/java/org/lognet/springboot/grpc/security/SecurityInterceptor.java index 44fc28c8..67b507c7 100644 --- a/grpc-spring-boot-starter/src/main/java/org/lognet/springboot/grpc/security/SecurityInterceptor.java +++ b/grpc-spring-boot-starter/src/main/java/org/lognet/springboot/grpc/security/SecurityInterceptor.java @@ -1,5 +1,9 @@ package org.lognet.springboot.grpc.security; +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; +import java.util.Optional; + import io.grpc.Context; import io.grpc.Contexts; import io.grpc.ForwardingServerCall; @@ -22,10 +26,6 @@ import org.springframework.security.core.context.SecurityContext; import org.springframework.security.core.context.SecurityContextHolder; -import java.nio.ByteBuffer; -import java.nio.charset.StandardCharsets; -import java.util.Optional; - @Slf4j public class SecurityInterceptor extends AbstractSecurityInterceptor implements FailureHandlingServerInterceptor, Ordered { @@ -193,9 +193,10 @@ private ServerCall.Listener fail(ServerCallHandler(next.startCall(call, headers)) { + return new MessageBlockingServerCallListener(next.startCall(call, headers)) { @Override public void onMessage(ReqT message) { + blockMessage(); closeCall(message, errorHandler, call, headers, status, exception); } }; diff --git a/grpc-spring-boot-starter/src/main/java/org/lognet/springboot/grpc/validation/ValidatingInterceptor.java b/grpc-spring-boot-starter/src/main/java/org/lognet/springboot/grpc/validation/ValidatingInterceptor.java index 5585cc28..c4883ce8 100644 --- a/grpc-spring-boot-starter/src/main/java/org/lognet/springboot/grpc/validation/ValidatingInterceptor.java +++ b/grpc-spring-boot-starter/src/main/java/org/lognet/springboot/grpc/validation/ValidatingInterceptor.java @@ -1,7 +1,12 @@ package org.lognet.springboot.grpc.validation; +import java.util.Optional; +import java.util.Set; +import javax.validation.ConstraintViolation; +import javax.validation.ConstraintViolationException; +import javax.validation.Validator; + import io.grpc.ForwardingServerCall; -import io.grpc.ForwardingServerCallListener; import io.grpc.Metadata; import io.grpc.ServerCall; import io.grpc.ServerCallHandler; @@ -15,12 +20,6 @@ import org.springframework.beans.factory.annotation.Autowired; import org.springframework.core.Ordered; -import javax.validation.ConstraintViolation; -import javax.validation.ConstraintViolationException; -import javax.validation.Validator; -import java.util.Optional; -import java.util.Set; - public class ValidatingInterceptor implements FailureHandlingServerInterceptor, Ordered { private Validator validator; @@ -53,14 +52,14 @@ public void sendMessage(RespT message) { } } }, headers); - return new ForwardingServerCallListener.SimpleForwardingServerCallListener(listener) { + return new MessageBlockingServerCallListener(listener) { @Override public void onMessage(ReqT message) { final Set> violations = validator.validate(message, RequestMessage.class); if (!violations.isEmpty()) { - closeCall(message,errorHandler,call,headers,Status.INVALID_ARGUMENT,new ConstraintViolationException(violations)); - + blockMessage(); + closeCall(message,errorHandler,call,headers,Status.INVALID_ARGUMENT,new ConstraintViolationException(violations)); } else { super.onMessage(message); }