Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Restores PackableMethod serialization mode #14323

Merged
merged 10 commits into from
Jul 23, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,9 @@ public void cancel(Throwable throwable) {
closed();
}
}
this.cancellationContext.cancel(throwable);
if (cancellationContext != null) {
cancellationContext.cancel(throwable);
}
long errorCode = 0;
if (throwable instanceof ErrorCodeHolder) {
errorCode = ((ErrorCodeHolder) throwable).getErrorCode();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,10 @@
*/
package org.apache.dubbo.remoting.http12.message;

import org.apache.dubbo.common.io.StreamUtils;
import org.apache.dubbo.remoting.http12.CompositeInputStream;
import org.apache.dubbo.remoting.http12.exception.DecodeException;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;

Expand All @@ -47,8 +45,6 @@ public class LengthFieldStreamingDecoder implements StreamingDecoder {

private int requiredLength;

private InputStream dataHeader = StreamUtils.EMPTY;

public LengthFieldStreamingDecoder() {
this(4);
}
Expand Down Expand Up @@ -147,16 +143,12 @@ private void deliver() {
}

private void processHeader() throws IOException {
ByteArrayOutputStream bos = new ByteArrayOutputStream(lengthFieldOffset + lengthFieldLength);
byte[] offsetData = new byte[lengthFieldOffset];
int ignore = accumulate.read(offsetData);
bos.write(offsetData);
processOffset(new ByteArrayInputStream(offsetData), lengthFieldOffset);
byte[] lengthBytes = new byte[lengthFieldLength];
ignore = accumulate.read(lengthBytes);
bos.write(lengthBytes);
requiredLength = bytesToInt(lengthBytes);
this.dataHeader = new ByteArrayInputStream(bos.toByteArray());

// Continue reading the frame body.
state = DecodeState.PAYLOAD;
Expand Down Expand Up @@ -184,8 +176,8 @@ private void processBody() throws IOException {
requiredLength = lengthFieldOffset + lengthFieldLength;
}

protected void invokeListener(InputStream inputStream) {
this.listener.onFragmentMessage(dataHeader, inputStream);
public void invokeListener(InputStream inputStream) {
this.listener.onFragmentMessage(inputStream);
}

protected byte[] readRawMessage(InputStream inputStream, int length) throws IOException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,6 @@ interface FragmentListener {
*/
void onFragmentMessage(InputStream rawMessage);

/**
* @param rawMessage raw message
*/
default void onFragmentMessage(InputStream dataHeader, InputStream rawMessage) {
onFragmentMessage(rawMessage);
}

default void onClose() {}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import org.apache.dubbo.common.URL;
import org.apache.dubbo.common.constants.CommonConstants;
import org.apache.dubbo.common.io.StreamUtils;
import org.apache.dubbo.common.utils.CollectionUtils;
import org.apache.dubbo.remoting.http12.exception.UnimplementedException;
import org.apache.dubbo.rpc.Invoker;
Expand All @@ -28,6 +29,8 @@
import org.apache.dubbo.rpc.service.ServiceDescriptorInternalCache;
import org.apache.dubbo.rpc.stub.StubSuppliers;

import java.io.IOException;
import java.io.InputStream;
import java.util.Arrays;
import java.util.List;

Expand Down Expand Up @@ -124,9 +127,10 @@ public static MethodDescriptor findReflectionMethodDescriptor(
}

public static MethodDescriptor findTripleMethodDescriptor(
ServiceDescriptor serviceDescriptor, String methodName, byte[] data) {
ServiceDescriptor serviceDescriptor, String methodName, InputStream rawMessage) throws IOException {
MethodDescriptor methodDescriptor = findReflectionMethodDescriptor(serviceDescriptor, methodName);
if (methodDescriptor == null) {
byte[] data = StreamUtils.readBytes(rawMessage);
List<MethodDescriptor> methodDescriptors = serviceDescriptor.getMethods(methodName);
TripleRequestWrapper request = TripleRequestWrapper.parseFrom(data);
String[] paramTypes = request.getArgTypes().toArray(new String[0]);
Expand All @@ -141,6 +145,7 @@ public static MethodDescriptor findTripleMethodDescriptor(
if (methodDescriptor == null) {
throw new UnimplementedException("method:" + methodName);
}
rawMessage.reset();
}
return methodDescriptor;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,7 @@ protected RpcInvocation buildRpcInvocation(RpcInvocationBuildContext context) {
methodDescriptor = DescriptorUtils.findMethodDescriptor(
context.getServiceDescriptor(), context.getMethodName(), context.isHasStub());
context.setMethodDescriptor(methodDescriptor);
onSettingMethodDescriptor(methodDescriptor);
}
MethodMetadata methodMetadata = context.getMethodMetadata();
if (methodMetadata == null) {
Expand Down Expand Up @@ -280,4 +281,6 @@ protected final HttpMessageListener getHttpMessageListener() {
protected void setHttpMessageListener(HttpMessageListener httpMessageListener) {
this.httpMessageListener = httpMessageListener;
}

protected void onSettingMethodDescriptor(MethodDescriptor methodDescriptor) {}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,38 +16,62 @@
*/
package org.apache.dubbo.rpc.protocol.tri.h12.grpc;

import org.apache.dubbo.common.URL;
import org.apache.dubbo.common.config.ConfigurationUtils;
import org.apache.dubbo.common.io.StreamUtils;
import org.apache.dubbo.common.utils.ArrayUtils;
import org.apache.dubbo.remoting.http12.exception.DecodeException;
import org.apache.dubbo.remoting.http12.exception.EncodeException;
import org.apache.dubbo.remoting.http12.exception.HttpStatusException;
import org.apache.dubbo.remoting.http12.message.HttpMessageCodec;
import org.apache.dubbo.remoting.http12.message.MediaType;
import org.apache.dubbo.rpc.model.FrameworkModel;
import org.apache.dubbo.rpc.model.MethodDescriptor;
import org.apache.dubbo.rpc.model.PackableMethod;
import org.apache.dubbo.rpc.model.PackableMethodFactory;

import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.charset.Charset;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

import com.google.protobuf.Message;

import static org.apache.dubbo.common.constants.CommonConstants.PROTOBUF_MESSAGE_CLASS_NAME;
import static org.apache.dubbo.common.constants.CommonConstants.DEFAULT_KEY;
import static org.apache.dubbo.common.constants.CommonConstants.DUBBO_PACKABLE_METHOD_FACTORY;

public class GrpcCompositeCodec implements HttpMessageCodec {

private final ProtobufHttpMessageCodec protobufHttpMessageCodec;
private static final String PACKABLE_METHOD_CACHE = "PACKABLE_METHOD_CACHE";

private final WrapperHttpMessageCodec wrapperHttpMessageCodec;
private final URL url;

public GrpcCompositeCodec(
ProtobufHttpMessageCodec protobufHttpMessageCodec, WrapperHttpMessageCodec wrapperHttpMessageCodec) {
this.protobufHttpMessageCodec = protobufHttpMessageCodec;
this.wrapperHttpMessageCodec = wrapperHttpMessageCodec;
}
private final FrameworkModel frameworkModel;

private final String mediaType;

public void setEncodeTypes(Class<?>[] encodeTypes) {
this.wrapperHttpMessageCodec.setEncodeTypes(encodeTypes);
private PackableMethod packableMethod;

public GrpcCompositeCodec(URL url, FrameworkModel frameworkModel, String mediaType) {
this.url = url;
this.frameworkModel = frameworkModel;
this.mediaType = mediaType;
}

public void setDecodeTypes(Class<?>[] decodeTypes) {
this.wrapperHttpMessageCodec.setDecodeTypes(decodeTypes);
public void loadPackableMethod(MethodDescriptor methodDescriptor) {
if (methodDescriptor instanceof PackableMethod) {
packableMethod = (PackableMethod) methodDescriptor;
return;
}
Map<MethodDescriptor, PackableMethod> cacheMap = (Map<MethodDescriptor, PackableMethod>) url.getServiceModel()
.getServiceMetadata()
.getAttributeMap()
.computeIfAbsent(PACKABLE_METHOD_CACHE, k -> new ConcurrentHashMap<>());
packableMethod = cacheMap.computeIfAbsent(methodDescriptor, md -> frameworkModel
.getExtensionLoader(PackableMethodFactory.class)
.getExtension(ConfigurationUtils.getGlobalConfiguration(url.getApplicationModel())
.getString(DUBBO_PACKABLE_METHOD_FACTORY, DEFAULT_KEY))
.create(methodDescriptor, url, mediaType));
}

@Override
Expand All @@ -58,34 +82,38 @@ public void encode(OutputStream outputStream, Object data, Charset charset) thro
try {
int compressed = 0;
outputStream.write(compressed);
if (isProtobuf(data)) {
ProtobufWriter.write(protobufHttpMessageCodec, outputStream, data);
return;
}
// wrapper
wrapperHttpMessageCodec.encode(outputStream, data);
} catch (IOException e) {
byte[] bytes = packableMethod.packResponse(data);
writeLength(outputStream, bytes.length);
outputStream.write(bytes);
} catch (HttpStatusException e) {
throw e;
} catch (Exception e) {
throw new EncodeException(e);
}
}

@Override
public Object decode(InputStream inputStream, Class<?> targetType, Charset charset) throws DecodeException {
if (isProtoClass(targetType)) {
return protobufHttpMessageCodec.decode(inputStream, targetType, charset);
try {
byte[] data = StreamUtils.readBytes(inputStream);
return packableMethod.parseRequest(data);
} catch (HttpStatusException e) {
throw e;
} catch (Exception e) {
throw new DecodeException(e);
}
return wrapperHttpMessageCodec.decode(inputStream, targetType, charset);
}

@Override
public Object[] decode(InputStream inputStream, Class<?>[] targetTypes, Charset charset) throws DecodeException {
if (targetTypes.length > 1) {
return wrapperHttpMessageCodec.decode(inputStream, targetTypes, charset);
Object message = decode(inputStream, ArrayUtils.isEmpty(targetTypes) ? null : targetTypes[0], charset);
if (message instanceof Object[]) {
return (Object[]) message;
}
return HttpMessageCodec.super.decode(inputStream, targetTypes, charset);
return new Object[] {message};
}

private static void writeLength(OutputStream outputStream, int length) {
private void writeLength(OutputStream outputStream, int length) {
try {
outputStream.write(((length >> 24) & 0xFF));
outputStream.write(((length >> 16) & 0xFF));
Expand All @@ -100,39 +128,4 @@ private static void writeLength(OutputStream outputStream, int length) {
public MediaType mediaType() {
return MediaType.APPLICATION_GRPC;
}

private static boolean isProtobuf(Object data) {
if (data == null) {
return false;
}
return isProtoClass(data.getClass());
}

private static boolean isProtoClass(Class<?> clazz) {
while (clazz != Object.class && clazz != null) {
Class<?>[] interfaces = clazz.getInterfaces();
if (interfaces.length > 0) {
for (Class<?> clazzInterface : interfaces) {
if (PROTOBUF_MESSAGE_CLASS_NAME.equalsIgnoreCase(clazzInterface.getName())) {
return true;
}
}
}
clazz = clazz.getSuperclass();
}
return false;
}

/**
* lazy init protobuf class
*/
private static class ProtobufWriter {

private static void write(HttpMessageCodec codec, OutputStream outputStream, Object data) {
int serializedSize = ((Message) data).getSerializedSize();
// write length
writeLength(outputStream, serializedSize);
codec.encode(outputStream, data);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,19 +22,14 @@
import org.apache.dubbo.remoting.http12.message.HttpMessageDecoderFactory;
import org.apache.dubbo.remoting.http12.message.HttpMessageEncoderFactory;
import org.apache.dubbo.remoting.http12.message.MediaType;
import org.apache.dubbo.remoting.utils.UrlUtils;
import org.apache.dubbo.rpc.model.FrameworkModel;

@Activate
public class GrpcCompositeCodecFactory implements HttpMessageEncoderFactory, HttpMessageDecoderFactory {

@Override
public HttpMessageCodec createCodec(URL url, FrameworkModel frameworkModel, String mediaType) {
String serializeName = UrlUtils.serializationOrDefault(url);
WrapperHttpMessageCodec wrapperHttpMessageCodec = new WrapperHttpMessageCodec(url, frameworkModel);
wrapperHttpMessageCodec.setSerializeType(serializeName);
ProtobufHttpMessageCodec protobufHttpMessageCodec = new ProtobufHttpMessageCodec();
return new GrpcCompositeCodec(protobufHttpMessageCodec, wrapperHttpMessageCodec);
return new GrpcCompositeCodec(url, frameworkModel, mediaType);
}

@Override
Expand Down
Loading
Loading