diff --git a/chill-protobuf/src/main/java/com/twitter/chill/protobuf/ProtobufSerializer.java b/chill-protobuf/src/main/java/com/twitter/chill/protobuf/ProtobufSerializer.java index ffc366c4..072a7636 100644 --- a/chill-protobuf/src/main/java/com/twitter/chill/protobuf/ProtobufSerializer.java +++ b/chill-protobuf/src/main/java/com/twitter/chill/protobuf/ProtobufSerializer.java @@ -43,21 +43,32 @@ public class ProtobufSerializer extends Serializer { * classes in play, which should not be very large. * We can replace with a LRU if we start to see any issues. */ + // Cache for the `parseFrom(byte[] bytes)` method final protected HashMap methodCache = new HashMap(); + // Cache for the `getDefaultInstance()` method + final protected HashMap defaultInstanceMethodCache = new HashMap(); /** * This is slow, so we should cache to avoid killing perf: * See: http://www.jguru.com/faq/view.jsp?EID=246569 */ - protected Method getParse(Class cls) throws Exception { - Method meth = methodCache.get(cls); + private Method getMethodFromCache(Class cls, HashMap cache, String methodName, Class... parameterTypes) throws Exception { + Method meth = cache.get(cls); if (null == meth) { - meth = cls.getMethod("parseFrom", new Class[]{ byte[].class }); - methodCache.put(cls, meth); + meth = cls.getMethod(methodName, parameterTypes); + cache.put(cls, meth); } return meth; } + protected Method getParse(Class cls) throws Exception { + return getMethodFromCache(cls, methodCache, "parseFrom", byte[].class); + } + + protected Method getDefaultInstance(Class cls) throws Exception { + return getMethodFromCache(cls, defaultInstanceMethodCache, "getDefaultInstance"); + } + @Override public void write(Kryo kryo, Output output, Message mes) { byte[] ser = mes.toByteArray(); @@ -69,6 +80,9 @@ public void write(Kryo kryo, Output output, Message mes) { public Message read(Kryo kryo, Input input, Class pbClass) { try { int size = input.readInt(true); + if (size == 0) { + return (Message) getDefaultInstance(pbClass).invoke(null); + } byte[] barr = new byte[size]; input.readBytes(barr); return (Message)getParse(pbClass).invoke(null, barr); diff --git a/chill-protobuf/src/test/scala/com/twitter/chill/protobuf/ProtobufTest.scala b/chill-protobuf/src/test/scala/com/twitter/chill/protobuf/ProtobufTest.scala index 3e02f031..9f2de1d4 100644 --- a/chill-protobuf/src/test/scala/com/twitter/chill/protobuf/ProtobufTest.scala +++ b/chill-protobuf/src/test/scala/com/twitter/chill/protobuf/ProtobufTest.scala @@ -27,6 +27,18 @@ import org.scalatest.matchers.should.Matchers import org.scalatest.wordspec.AnyWordSpec class ProtobufTest extends AnyWordSpec with Matchers { + def buildKyroPoolWithProtoSer(): KryoPool = + KryoPool.withByteArrayOutputStream( + 1, + new KryoInstantiator { + override def newKryo(): Kryo = { + val k = new Kryo + k.addDefaultSerializer(classOf[Message], classOf[ProtobufSerializer]) + k + } + } + ) + def buildFatigueCount(target: Long, id: Long, count: Int, recentClicks: List[Long]): FatigueCount = { val bldr = FatigueCount .newBuilder() @@ -39,16 +51,7 @@ class ProtobufTest extends AnyWordSpec with Matchers { } "Protobuf round-trips" in { - val kpool = KryoPool.withByteArrayOutputStream( - 1, - new KryoInstantiator { - override def newKryo(): Kryo = { - val k = new Kryo - k.addDefaultSerializer(classOf[Message], classOf[ProtobufSerializer]) - k - } - } - ) + val kpool = buildKyroPoolWithProtoSer() kpool.deepCopy(buildFatigueCount(12L, -1L, 42, List(1L, 2L))) should equal( buildFatigueCount(12L, -1L, 42, List(1L, 2L)) @@ -58,4 +61,10 @@ class ProtobufTest extends AnyWordSpec with Matchers { val kpoolBusted = KryoPool.withByteArrayOutputStream(1, new KryoInstantiator) an[Exception] should be thrownBy kpoolBusted.deepCopy(buildFatigueCount(12L, -1L, 42, List(1L, 2L))) } + + "Default Instance of Should be Ser-DeSer correctly" in { + val kpool = buildKyroPoolWithProtoSer() + + kpool.deepCopy(FatigueCount.getDefaultInstance) should equal(FatigueCount.getDefaultInstance) + } }