Skip to content

Commit

Permalink
Block onHalfClose if onMessage was blocked (#241)
Browse files Browse the repository at this point in the history
* test that SecurityInterceptor does not propagate onHalfClose after closing the request

* set the locale in ValidationTest so it passes on non-English systems

* test that ValidatingInterceptor does not propagate onHalfClose after closing the request

* also block onHalfClose after blocking onMessage

fixes #240
  • Loading branch information
jGleitz authored Sep 13, 2021
1 parent b56d6e4 commit 4b458cb
Show file tree
Hide file tree
Showing 6 changed files with 140 additions and 23 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package org.lognet.springboot.grpc;

import io.grpc.ForwardingServerCallListener;
import io.grpc.Metadata;
import io.grpc.ServerCall;
import io.grpc.ServerCallHandler;
import io.grpc.ServerInterceptor;

@GRpcGlobalInterceptor
public class HalfCloseInterceptor implements ServerInterceptor {
@Override
public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(
ServerCall<ReqT, RespT> call, Metadata headers, ServerCallHandler<ReqT, RespT> next
) {
return new ForwardingServerCallListener.SimpleForwardingServerCallListener<ReqT>(next.startCall(call, headers)) {
@Override
public void onHalfClose() {
HalfCloseInterceptor.this.onHalfClose();
super.onHalfClose();
}
};
}

public void onHalfClose() {}
}
Original file line number Diff line number Diff line change
@@ -1,33 +1,39 @@
package org.lognet.springboot.grpc;

import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.emptyOrNullString;
import static org.junit.Assert.assertThrows;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
import static org.springframework.boot.test.context.SpringBootTest.WebEnvironment.NONE;

import java.util.Locale;

import io.grpc.Metadata;
import io.grpc.Status;
import io.grpc.StatusRuntimeException;
import io.grpc.examples.GreeterGrpc;
import io.grpc.examples.GreeterOuterClass;
import org.hamcrest.Matchers;
import org.junit.AfterClass;
import org.junit.Before;
import org.junit.BeforeClass;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.lognet.springboot.grpc.demo.DemoApp;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.boot.test.context.TestConfiguration;
import org.springframework.boot.test.mock.mockito.SpyBean;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Import;
import org.springframework.test.context.ActiveProfiles;
import org.springframework.test.context.junit4.SpringRunner;

import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.emptyOrNullString;
import static org.junit.Assert.assertThrows;
import static org.springframework.boot.test.context.SpringBootTest.WebEnvironment.NONE;

@RunWith(SpringRunner.class)
@SpringBootTest(classes = {DemoApp.class}, webEnvironment = NONE, properties = {"grpc.port=0"})
@Import(ValidationTest.TestCfg.class)
@ActiveProfiles("disable-security")
public class ValidationTest extends GrpcServerTestBase {

@TestConfiguration
static class TestCfg {
@Bean
Expand All @@ -43,7 +49,21 @@ public Status handle(Object message, Status status, Exception exception, Metadat
}
private GreeterGrpc.GreeterBlockingStub stub;

@SpyBean
HalfCloseInterceptor halfCloseInterceptor;

private static Locale systemDefaultLocale;

@BeforeClass
public static void setLocaleToEnglish() {
systemDefaultLocale = Locale.getDefault();
Locale.setDefault(Locale.ENGLISH);
}

@AfterClass
public static void resetDefaultLocale() {
Locale.setDefault(systemDefaultLocale);
}

@Before
public void setUp() throws Exception {
Expand Down Expand Up @@ -147,7 +167,6 @@ public void validMessageValidationTest() {
@Test
public void invalidResponseMessageValidationTest() {
StatusRuntimeException e = assertThrows(StatusRuntimeException.class, () -> {

stub.helloPersonInvalidResponse(GreeterOuterClass.Person.newBuilder()
.setAge(3)//valid
.setName("Dexter")//valid
Expand All @@ -164,6 +183,17 @@ public void invalidResponseMessageValidationTest() {

}

@Test
public void noHalfCloseAfterFailedValidation() {
StatusRuntimeException e = assertThrows(StatusRuntimeException.class, () -> {
stub.helloPersonValidResponse(GreeterOuterClass.Person.newBuilder()
.setAge(49)// valid
.clearName() //invalid
.build());
});
assertThat(e.getStatus().getCode(), Matchers.is(Status.Code.INVALID_ARGUMENT));
verify(halfCloseInterceptor, never()).onHalfClose();
}

String getFieldName(int fieldNumber) {
return GreeterOuterClass.Person.getDescriptor().findFieldByNumber(fieldNumber).getName();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
package org.lognet.springboot.grpc.auth;


import static org.hamcrest.MatcherAssert.assertThat;
import static org.junit.Assert.assertThrows;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;

import com.google.protobuf.Empty;
import io.grpc.Status;
import io.grpc.StatusRuntimeException;
import io.grpc.examples.SecuredGreeterGrpc;
import org.hamcrest.Matchers;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.lognet.springboot.grpc.GrpcServerTestBase;
import org.lognet.springboot.grpc.HalfCloseInterceptor;
import org.lognet.springboot.grpc.demo.DemoApp;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.boot.test.mock.mockito.SpyBean;
import org.springframework.test.context.junit4.SpringRunner;

@SpringBootTest(
classes = DemoApp.class,
properties = "grpc.security.auth.fail-fast=false"
)
@RunWith(SpringRunner.class)
public class FailLateSecurityInterceptorTest extends GrpcServerTestBase {
@SpyBean
HalfCloseInterceptor halfCloseInterceptor;

@Test
public void noHalfCloseOnFailedAuth() {
final StatusRuntimeException statusRuntimeException = assertThrows(
StatusRuntimeException.class,
() -> SecuredGreeterGrpc.newBlockingStub(selectedChanel).sayAuthHello2(Empty.newBuilder().build()).getMessage()
);
assertThat(statusRuntimeException.getStatus().getCode(), Matchers.is(Status.Code.UNAUTHENTICATED));
verify(halfCloseInterceptor, never()).onHalfClose();
}
}
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
package org.lognet.springboot.grpc;

import io.grpc.ForwardingServerCallListener;
import io.grpc.Metadata;
import io.grpc.ServerCall;
import io.grpc.ServerInterceptor;
import io.grpc.Status;
import io.grpc.StatusRuntimeException;

public interface FailureHandlingServerInterceptor extends ServerInterceptor {
default void closeCall(Object o, GRpcErrorHandler errorHandler, ServerCall<?, ?> call, Metadata headers, final Status status, Exception exception){
Expand All @@ -14,4 +14,25 @@ default void closeCall(Object o, GRpcErrorHandler errorHandler, ServerCall<?, ?
call.close(statusToSend, responseHeaders);

}

class MessageBlockingServerCallListener<R> extends ForwardingServerCallListener.SimpleForwardingServerCallListener<R> {
private volatile boolean messageBlocked = false;

public MessageBlockingServerCallListener(ServerCall.Listener<R> delegate) {
super(delegate);
}

@Override
public void onHalfClose() {
// If the message was blocked, downstream never had a chance to react to it. Hence, the half-close signal would look like
// an error to them. So we do not propagate the signal in that case.
if (!messageBlocked) {
super.onHalfClose();
}
}

protected void blockMessage() {
messageBlocked = true;
}
}
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
package org.lognet.springboot.grpc.security;

import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.Optional;

import io.grpc.Context;
import io.grpc.Contexts;
import io.grpc.ForwardingServerCall;
Expand All @@ -22,10 +26,6 @@
import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextHolder;

import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.Optional;

@Slf4j
public class SecurityInterceptor extends AbstractSecurityInterceptor implements FailureHandlingServerInterceptor, Ordered {

Expand Down Expand Up @@ -193,9 +193,10 @@ private <RespT, ReqT> ServerCall.Listener<ReqT> fail(ServerCallHandler<ReqT, Res


} else {
return new ForwardingServerCallListener.SimpleForwardingServerCallListener<ReqT>(next.startCall(call, headers)) {
return new MessageBlockingServerCallListener<ReqT>(next.startCall(call, headers)) {
@Override
public void onMessage(ReqT message) {
blockMessage();
closeCall(message, errorHandler, call, headers, status, exception);
}
};
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
package org.lognet.springboot.grpc.validation;

import java.util.Optional;
import java.util.Set;
import javax.validation.ConstraintViolation;
import javax.validation.ConstraintViolationException;
import javax.validation.Validator;

import io.grpc.ForwardingServerCall;
import io.grpc.ForwardingServerCallListener;
import io.grpc.Metadata;
import io.grpc.ServerCall;
import io.grpc.ServerCallHandler;
Expand All @@ -15,12 +20,6 @@
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.core.Ordered;

import javax.validation.ConstraintViolation;
import javax.validation.ConstraintViolationException;
import javax.validation.Validator;
import java.util.Optional;
import java.util.Set;


public class ValidatingInterceptor implements FailureHandlingServerInterceptor, Ordered {
private Validator validator;
Expand Down Expand Up @@ -53,14 +52,14 @@ public void sendMessage(RespT message) {
}
}
}, headers);
return new ForwardingServerCallListener.SimpleForwardingServerCallListener<ReqT>(listener) {
return new MessageBlockingServerCallListener<ReqT>(listener) {

@Override
public void onMessage(ReqT message) {
final Set<ConstraintViolation<ReqT>> violations = validator.validate(message, RequestMessage.class);
if (!violations.isEmpty()) {
closeCall(message,errorHandler,call,headers,Status.INVALID_ARGUMENT,new ConstraintViolationException(violations));

blockMessage();
closeCall(message,errorHandler,call,headers,Status.INVALID_ARGUMENT,new ConstraintViolationException(violations));
} else {
super.onMessage(message);
}
Expand Down

0 comments on commit 4b458cb

Please sign in to comment.