Skip to content

Commit

Permalink
add grpc request validator
Browse files Browse the repository at this point in the history
  • Loading branch information
sangyongchoi committed Jun 23, 2022
1 parent 5edaef0 commit 5f5b81c
Show file tree
Hide file tree
Showing 8 changed files with 278 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package com.linecorp.armeria.server.grpc.validation;

import io.grpc.Metadata;
import io.grpc.ServerCall;
import io.grpc.ServerCallHandler;
import io.grpc.ServerInterceptor;

public class RequestValidationInterceptor implements ServerInterceptor {

private RequestValidatorResolver requestValidatorResolver;

public RequestValidationInterceptor(RequestValidatorResolver requestValidatorResolver) {
this.requestValidatorResolver = requestValidatorResolver;
}

@Override
public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(
ServerCall<ReqT, RespT> call,
Metadata headers,
ServerCallHandler<ReqT, RespT> next
) {
ServerCall.Listener<ReqT> delegate = next.startCall(call, headers);

return new RequestValidationListener<>(delegate, call, headers, requestValidatorResolver);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
package com.linecorp.armeria.server.grpc.validation;

import com.google.protobuf.MessageLiteOrBuilder;
import io.grpc.ForwardingServerCallListener;
import io.grpc.Metadata;
import io.grpc.ServerCall;
import io.grpc.Status;

public class RequestValidationListener<ReqT, ResT> extends ForwardingServerCallListener.SimpleForwardingServerCallListener<ReqT> {

private ServerCall<ReqT, ResT> serverCall;
private Metadata headers;
private RequestValidatorResolver requestValidatorResolver;

public RequestValidationListener(
ServerCall.Listener<ReqT> delegate,
ServerCall<ReqT, ResT> serverCall,
Metadata headers,
RequestValidatorResolver requestValidatorResolver
) {
super(delegate);
this.serverCall = serverCall;
this.headers = headers;
this.requestValidatorResolver = requestValidatorResolver;
}

@Override
public void onMessage(ReqT message) {
MessageLiteOrBuilder convertMessage = (MessageLiteOrBuilder) message;
RequestValidator<MessageLiteOrBuilder> validator = requestValidatorResolver.find(convertMessage.getClass().getTypeName());

if (validator == null) {
super.onMessage(message);
} else {
try {
ValidationResult validationResult = validator.isValid(convertMessage);

if (validationResult.isValid()) {
super.onMessage(message);
} else {
Status status = Status.INVALID_ARGUMENT
.withDescription("invalid argument. " + validationResult.getMessage());
handleInvalidRequest(status);
}
} catch (Exception e) {
Status status = Status.INTERNAL.withDescription(e.getMessage());

handleInvalidRequest(status);
}
}
}

private void handleInvalidRequest(Status status) {
if (!serverCall.isCancelled()) {
serverCall.close(status, headers);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package com.linecorp.armeria.server.grpc.validation;

import com.google.protobuf.MessageLiteOrBuilder;

interface RequestValidator<T extends MessageLiteOrBuilder> {
ValidationResult isValid(T request);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
package com.linecorp.armeria.server.grpc.validation;

import com.google.protobuf.MessageLiteOrBuilder;

import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

public class RequestValidatorResolver {

private List<RequestValidator<MessageLiteOrBuilder>> validators;

private Map<String, RequestValidator<MessageLiteOrBuilder>> validatorMap;

public RequestValidatorResolver(List<RequestValidator<MessageLiteOrBuilder>> validators) {
this.validators = validators;

validatorMap = validators.stream()
.collect(Collectors.toMap(this::getClassName, it -> it));
}

private String getClassName(RequestValidator<MessageLiteOrBuilder> it) {
Type[] genericInterfaces = it.getClass().getGenericInterfaces();

if (genericInterfaces.length == 0) {
return null;
}

return ((ParameterizedType) genericInterfaces[0]).getActualTypeArguments()[0].getTypeName();
}

public RequestValidator<MessageLiteOrBuilder> find(String typeName) {
return validatorMap.get(typeName);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package com.linecorp.armeria.server.grpc.validation;

public class ValidationResult {

private boolean isValid;

private String message;

public ValidationResult(boolean isValid, String message) {
this.isValid = isValid;
this.message = message;
}

public boolean isValid() {
return isValid;
}

public String getMessage() {
return message;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package com.linecorp.armeria.internal.common.grpc;

import com.linecorp.armeria.grpc.testing.Hello;
import com.linecorp.armeria.grpc.testing.HelloServiceGrpc;
import io.grpc.stub.StreamObserver;

public class HeloServiceImpl extends HelloServiceGrpc.HelloServiceImplBase {

@Override
public void hello(Hello.HelloRequest request, StreamObserver<Hello.HelloResponse> responseObserver) {
Hello.HelloResponse response = Hello.HelloResponse.newBuilder()
.setMessage("success")
.build();

responseObserver.onNext(response);
responseObserver.onCompleted();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
package com.linecorp.armeria.server.grpc.validation;

import com.google.protobuf.MessageLiteOrBuilder;
import com.linecorp.armeria.client.grpc.GrpcClients;
import com.linecorp.armeria.grpc.testing.Hello;
import com.linecorp.armeria.grpc.testing.HelloServiceGrpc.HelloServiceBlockingStub;
import com.linecorp.armeria.internal.common.grpc.HeloServiceImpl;
import com.linecorp.armeria.server.ServerBuilder;
import com.linecorp.armeria.server.grpc.GrpcService;
import com.linecorp.armeria.testing.junit5.server.ServerExtension;
import io.grpc.Status;
import io.grpc.StatusRuntimeException;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

import java.util.ArrayList;
import java.util.List;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.catchThrowable;

class RequestValidationInterceptorTest {

static String ERROR_MESSAGE = "invalid argument";

@RegisterExtension
static ServerExtension server = new ServerExtension() {
@Override
protected void configure(ServerBuilder sb) {
List<RequestValidator<MessageLiteOrBuilder>> validators = new ArrayList<>();

validators.add((RequestValidator) new HelloRequestValidator());

RequestValidatorResolver requestValidatorResolver = new RequestValidatorResolver(validators);
sb.service(GrpcService.builder()
.addService(new HeloServiceImpl())
.intercept(new RequestValidationInterceptor(requestValidatorResolver))
.build());
}
};

@Test
void validation_fail_test() {
HelloServiceBlockingStub client = GrpcClients.builder(server.httpUri())
.build(HelloServiceBlockingStub.class);

final Throwable cause = catchThrowable(() -> client.hello(Hello.HelloRequest.getDefaultInstance()));
assertThat(cause).isInstanceOf(StatusRuntimeException.class);
assertThat(((StatusRuntimeException) cause).getStatus().getCode()).isEqualTo(Status.INVALID_ARGUMENT.getCode());
}

@Test
void validation_success_test() {
HelloServiceBlockingStub client = GrpcClients.builder(server.httpUri())
.build(HelloServiceBlockingStub.class);

Hello.HelloResponse response = client.hello(
Hello.HelloRequest.newBuilder()
.setMessage("success")
.build()
);

assertThat(response.getMessage()).isEqualTo("success");
}

private static class HelloRequestValidator implements RequestValidator<Hello.HelloRequest> {

@Override
public ValidationResult isValid(Hello.HelloRequest request) {
if (request.getMessage().equals("success")) {
return new ValidationResult(true, null);
}

return new ValidationResult(false, ERROR_MESSAGE);
}
}

}
33 changes: 33 additions & 0 deletions grpc/src/test/proto/com/linecorp/armeria/grpc/testing/hello.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// Copyright 2022 LINE Corporation
//
// LINE Corporation licenses this file to you under the Apache License,
// version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at:
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.

syntax = "proto3";

package armeria.grpc.testing;

option java_package = "com.linecorp.armeria.grpc.testing";

import "google/api/annotations.proto";

service HelloService {
rpc hello (HelloRequest) returns (HelloResponse) {}
}

message HelloRequest {
string message = 1;
}

message HelloResponse {
string message = 1;
}

0 comments on commit 5f5b81c

Please sign in to comment.