Skip to content

Commit

Permalink
Merge pull request quarkusio#34520 from michalvavrik/feature/grpc-che…
Browse files Browse the repository at this point in the history
…ck-event-loop-with-updated-value

Invoke secured blocking Grpc methods on worker thread
  • Loading branch information
cescoffier authored Jul 5, 2023
2 parents dc0d198 + a0e5579 commit ea13bc3
Show file tree
Hide file tree
Showing 6 changed files with 158 additions and 2 deletions.
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package io.quarkus.grpc.deployment;

import static io.quarkus.deployment.Feature.GRPC_SERVER;
import static io.quarkus.deployment.annotations.ExecutionTime.RUNTIME_INIT;
import static io.quarkus.grpc.deployment.GrpcDotNames.BLOCKING;
import static io.quarkus.grpc.deployment.GrpcDotNames.MUTINY_SERVICE;
import static io.quarkus.grpc.deployment.GrpcDotNames.NON_BLOCKING;
Expand Down Expand Up @@ -40,6 +41,7 @@
import io.quarkus.arc.deployment.AnnotationsTransformerBuildItem;
import io.quarkus.arc.deployment.BeanArchiveIndexBuildItem;
import io.quarkus.arc.deployment.BeanArchivePredicateBuildItem;
import io.quarkus.arc.deployment.BeanContainerBuildItem;
import io.quarkus.arc.deployment.CustomScopeAnnotationsBuildItem;
import io.quarkus.arc.deployment.RecorderBeanInitializedBuildItem;
import io.quarkus.arc.deployment.SyntheticBeanBuildItem;
Expand Down Expand Up @@ -71,6 +73,7 @@
import io.quarkus.grpc.GrpcService;
import io.quarkus.grpc.auth.DefaultAuthExceptionHandlerProvider;
import io.quarkus.grpc.auth.GrpcSecurityInterceptor;
import io.quarkus.grpc.auth.GrpcSecurityRecorder;
import io.quarkus.grpc.deployment.devmode.FieldDefinalizingVisitor;
import io.quarkus.grpc.protoc.plugin.MutinyGrpcGenerator;
import io.quarkus.grpc.runtime.GrpcContainer;
Expand Down Expand Up @@ -728,4 +731,26 @@ UnremovableBeanBuildItem unremovableServerInterceptors() {
return UnremovableBeanBuildItem.beanTypes(GrpcDotNames.SERVER_INTERCEPTOR);
}

@Consume(SyntheticBeansRuntimeInitBuildItem.class)
@Record(RUNTIME_INIT)
@BuildStep
void initGrpcSecurityInterceptor(List<BindableServiceBuildItem> bindables, Capabilities capabilities,
GrpcSecurityRecorder recorder, BeanContainerBuildItem beanContainer) {
if (capabilities.isPresent(Capability.SECURITY)) {

// Grpc service to blocking method
Map<String, List<String>> blocking = new HashMap<>();
for (BindableServiceBuildItem bindable : bindables) {
if (bindable.hasBlockingMethods()) {
blocking.put(bindable.serviceClass.toString(), bindable.blockingMethods);
}
}

if (!blocking.isEmpty()) {
// provide GrpcSecurityInterceptor with blocking methods
recorder.initGrpcSecurityInterceptor(blocking, beanContainer.getValue());
}
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import org.jboss.shrinkwrap.api.ShrinkWrap;
import org.jboss.shrinkwrap.api.asset.StringAsset;
import org.jboss.shrinkwrap.api.spec.JavaArchive;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

Expand All @@ -32,6 +33,7 @@
import io.quarkus.security.identity.request.AuthenticationRequest;
import io.quarkus.security.identity.request.UsernamePasswordAuthenticationRequest;
import io.quarkus.test.QuarkusUnitTest;
import io.smallrye.common.annotation.Blocking;
import io.smallrye.mutiny.Multi;
import io.smallrye.mutiny.Uni;
import io.vertx.core.Context;
Expand Down Expand Up @@ -66,7 +68,30 @@ void shouldSecureUniEndpoint() {
SecuredService client = GrpcClientUtils.attachHeaders(securityClient, headers);
AtomicInteger resultCount = new AtomicInteger();
client.unaryCall(Security.Container.newBuilder().setText("woo-hoo").build())
.subscribe().with(e -> resultCount.incrementAndGet());
.subscribe().with(e -> {
if (!e.getIsOnEventLoop()) {
Assertions.fail("Secured method should be run on event loop");
}
resultCount.incrementAndGet();
});

await().atMost(10, TimeUnit.SECONDS)
.until(() -> resultCount.get() == 1);
}

@Test
void shouldSecureBlockingUniEndpoint() {
Metadata headers = new Metadata();
headers.put(AUTHORIZATION, "Basic " + JOHN_BASIC_CREDS);
SecuredService client = GrpcClientUtils.attachHeaders(securityClient, headers);
AtomicInteger resultCount = new AtomicInteger();
client.unaryCallBlocking(Security.Container.newBuilder().setText("woo-hoo").build())
.subscribe().with(e -> {
if (e.getIsOnEventLoop()) {
Assertions.fail("Secured method annotated with @Blocking should be executed on worker thread");
}
resultCount.incrementAndGet();
});

await().atMost(10, TimeUnit.SECONDS)
.until(() -> resultCount.get() == 1);
Expand All @@ -88,6 +113,22 @@ void shouldSecureMultiEndpoint() {
assertThat(results.stream().filter(e -> !e)).isEmpty();
}

@Test
void shouldSecureBlockingMultiEndpoint() {
Metadata headers = new Metadata();
headers.put(AUTHORIZATION, "Basic " + PAUL_BASIC_CREDS);
SecuredService client = GrpcClientUtils.attachHeaders(securityClient, headers);
List<Boolean> results = new CopyOnWriteArrayList<>();
client.streamCallBlocking(Multi.createBy().repeating()
.supplier(() -> (Security.Container.newBuilder().setText("woo-hoo").build())).atMost(4))
.subscribe().with(e -> results.add(e.getIsOnEventLoop()));

await().atMost(10, TimeUnit.SECONDS)
.until(() -> results.size() == 5);

assertThat(results.stream().filter(e -> e)).isEmpty();
}

@Test
void shouldFailWithInvalidCredentials() {
Metadata headers = new Metadata();
Expand Down Expand Up @@ -139,6 +180,22 @@ public Multi<Security.ThreadInfo> streamCall(Multi<Security.Container> request)
.atMost(5);
}

@Blocking
@Override
@RolesAllowed("employees")
public Uni<Security.ThreadInfo> unaryCallBlocking(Security.Container request) {
return Uni.createFrom()
.item(newBuilder().setIsOnEventLoop(Context.isOnEventLoopThread()).build());
}

@Blocking
@Override
@RolesAllowed("interns")
public Multi<Security.ThreadInfo> streamCallBlocking(Multi<Security.Container> request) {
return Multi.createBy()
.repeating().supplier(() -> newBuilder().setIsOnEventLoop(Context.isOnEventLoopThread()).build())
.atMost(5);
}
}

@Singleton
Expand Down
2 changes: 2 additions & 0 deletions extensions/grpc/deployment/src/test/proto/security.proto
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ option java_package = "com.example.security";
service SecuredService {
rpc unaryCall(Container) returns (ThreadInfo);
rpc streamCall(stream Container) returns (stream ThreadInfo);
rpc unaryCallBlocking(Container) returns (ThreadInfo);
rpc streamCallBlocking(stream Container) returns (stream ThreadInfo);
}

message ThreadInfo {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@

import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.Executor;

import jakarta.enterprise.inject.Instance;
Expand Down Expand Up @@ -42,6 +44,9 @@ public final class GrpcSecurityInterceptor implements ServerInterceptor, Priorit
private final AuthExceptionHandlerProvider exceptionHandlerProvider;
private final List<GrpcSecurityMechanism> securityMechanisms;

private final Map<String, List<String>> serviceToBlockingMethods = new HashMap<>();
private boolean hasBlockingMethods = false;

@Inject
public GrpcSecurityInterceptor(
CurrentIdentityAssociation identityAssociation,
Expand Down Expand Up @@ -79,13 +84,25 @@ public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(ServerCall<ReqT, Re
Context context = Vertx.currentContext();
boolean onEventLoopThread = Context.isOnEventLoopThread();

final boolean isBlockingMethod;
if (hasBlockingMethods) {
var methods = serviceToBlockingMethods.get(serverCall.getMethodDescriptor().getServiceName());
if (methods != null) {
isBlockingMethod = methods.contains(serverCall.getMethodDescriptor().getFullMethodName());
} else {
isBlockingMethod = false;
}
} else {
isBlockingMethod = false;
}

if (authenticationRequest != null) {
Uni<SecurityIdentity> auth = identityProviderManager
.authenticate(authenticationRequest)
.emitOn(new Executor() {
@Override
public void execute(Runnable command) {
if (onEventLoopThread) {
if (onEventLoopThread && !isBlockingMethod) {
context.runOnContext(new Handler<>() {
@Override
public void handle(Void event) {
Expand Down Expand Up @@ -119,4 +136,9 @@ public void handle(Void event) {
public int getPriority() {
return Integer.MAX_VALUE - 100;
}

void init(Map<String, List<String>> serviceToBlockingMethods) {
this.serviceToBlockingMethods.putAll(serviceToBlockingMethods);
this.hasBlockingMethods = true;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
package io.quarkus.grpc.auth;

import static io.quarkus.grpc.runtime.GrpcServerRecorder.GrpcServiceDefinition.getImplementationClassName;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import io.grpc.BindableService;
import io.grpc.ServerMethodDefinition;
import io.quarkus.arc.runtime.BeanContainer;
import io.quarkus.grpc.runtime.GrpcContainer;
import io.quarkus.runtime.annotations.Recorder;

@Recorder
public class GrpcSecurityRecorder {

public void initGrpcSecurityInterceptor(Map<String, List<String>> serviceClassToBlockingMethod,
BeanContainer container) {

// service to full method names
var svcToMethods = new HashMap<String, List<String>>();
var services = container.beanInstance(GrpcContainer.class).getServices();
for (BindableService service : services) {
var className = getImplementationClassName(service);
var blockingMethods = serviceClassToBlockingMethod.get(className);
if (blockingMethods != null && !blockingMethods.isEmpty()) {
var svcName = service.bindService().getServiceDescriptor().getName();
var methods = new ArrayList<String>();
for (String blockingMethod : blockingMethods) {
for (ServerMethodDefinition<?, ?> method : service.bindService().getMethods()) {
if (blockingMethod.equals(method.getMethodDescriptor().getBareMethodName())) {
methods.add(method.getMethodDescriptor().getFullMethodName());
break;
}
}
}
svcToMethods.put(svcName, methods);
}
}

container.beanInstance(GrpcSecurityInterceptor.class).init(svcToMethods);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,10 @@ public static final class GrpcServiceDefinition {
}

public String getImplementationClassName() {
return getImplementationClassName(service);
}

public static String getImplementationClassName(BindableService service) {
if (service instanceof Subclass) {
// All intercepted services are represented by a generated subclass
return service.getClass().getSuperclass().getName();
Expand Down

0 comments on commit ea13bc3

Please sign in to comment.