From 395c1da4e1ec8ec808736c69be18325581cf69e8 Mon Sep 17 00:00:00 2001 From: Jasper Potts <1466205+jasperpotts@users.noreply.github.com> Date: Thu, 1 Feb 2024 13:30:51 -0800 Subject: [PATCH] Fix writeBytes performance, 1000x improvement. (#193) Signed-off-by: jasperpotts Signed-off-by: Anthony Petrov Co-authored-by: jasperpotts Co-authored-by: Anthony Petrov --- .../kotlin/com.hedera.pbj.runtime.gradle.kts | 6 + .../runtime/io/WriteBufferedDataBench.java | 111 ++++++++++++++++++ .../pbj/runtime/io/WriteBytesBench.java | 110 +++++++++++++++++ .../pbj/runtime/io/buffer/BufferedData.java | 30 +++++ .../hedera/pbj/runtime/io/buffer/Bytes.java | 14 +-- .../runtime/io/buffer/RandomAccessData.java | 23 +++- .../io/stream/WritableStreamingData.java | 17 +++ .../buffer/StubbedRandomAccessDataTest.java | 21 ++++ .../io/stream/WritableStreamingDataTest.java | 23 ++++ 9 files changed, 344 insertions(+), 11 deletions(-) create mode 100644 pbj-core/pbj-runtime/src/jmh/java/com/hedera/pbj/runtime/io/WriteBufferedDataBench.java create mode 100644 pbj-core/pbj-runtime/src/jmh/java/com/hedera/pbj/runtime/io/WriteBytesBench.java diff --git a/pbj-core/buildSrc/src/main/kotlin/com.hedera.pbj.runtime.gradle.kts b/pbj-core/buildSrc/src/main/kotlin/com.hedera.pbj.runtime.gradle.kts index 2ffb669d..e6359a7c 100644 --- a/pbj-core/buildSrc/src/main/kotlin/com.hedera.pbj.runtime.gradle.kts +++ b/pbj-core/buildSrc/src/main/kotlin/com.hedera.pbj.runtime.gradle.kts @@ -36,3 +36,9 @@ protobuf { val maven = publishing.publications.create("maven") { from(components["java"]) } signing.sign(maven) + +// Filter JMH benchmarks for testing +//jmh { +// includes.add("WriteBytesBench") +// includes.add("WriteBufferedDataBench") +//} diff --git a/pbj-core/pbj-runtime/src/jmh/java/com/hedera/pbj/runtime/io/WriteBufferedDataBench.java b/pbj-core/pbj-runtime/src/jmh/java/com/hedera/pbj/runtime/io/WriteBufferedDataBench.java new file mode 100644 index 00000000..458e6822 --- /dev/null +++ b/pbj-core/pbj-runtime/src/jmh/java/com/hedera/pbj/runtime/io/WriteBufferedDataBench.java @@ -0,0 +1,111 @@ +package com.hedera.pbj.runtime.io; + +import com.hedera.pbj.runtime.FieldDefinition; +import com.hedera.pbj.runtime.FieldType; +import com.hedera.pbj.runtime.ProtoParserTools; +import com.hedera.pbj.runtime.ProtoWriterTools; +import com.hedera.pbj.runtime.io.buffer.BufferedData; +import com.hedera.pbj.runtime.io.stream.ReadableStreamingData; +import com.hedera.pbj.runtime.io.stream.WritableStreamingData; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.OperationsPerInvocation; +import org.openjdk.jmh.annotations.OutputTimeUnit; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.TearDown; +import org.openjdk.jmh.annotations.Warmup; +import org.openjdk.jmh.infra.Blackhole; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.OutputStream; +import java.io.UncheckedIOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.Random; +import java.util.concurrent.TimeUnit; + +@SuppressWarnings("unused") +@State(Scope.Benchmark) +@Fork(1) +@Warmup(iterations = 4, time = 2) +@Measurement(iterations = 5, time = 2) +@OutputTimeUnit(TimeUnit.SECONDS) +@BenchmarkMode(Mode.Throughput) +public class WriteBufferedDataBench { + + public static final FieldDefinition BYTES_FIELD = new FieldDefinition("bytesField", FieldType.BYTES, false, false, false, 17); + final static BufferedData sampleData; + final static byte[] sampleWrittenData; + + static { + final Random random = new Random(6262266); + byte[] data = new byte[1024*16]; + random.nextBytes(data); + sampleData = BufferedData.wrap(data); + + ByteArrayOutputStream bout = new ByteArrayOutputStream(); + try (WritableStreamingData out = new WritableStreamingData(bout)) { + for (int i = 0; i < 100; i++) { + random.nextBytes(data); + ProtoWriterTools.writeBytes(out, BYTES_FIELD, sampleData); + } + } catch (IOException e) { + e.printStackTrace(); + } + sampleWrittenData = bout.toByteArray(); + } + + Path tempFileWriting; + Path tempFileReading; + OutputStream fout; + WritableStreamingData dataOut; + + @Setup + public void prepare() { + try { + tempFileWriting = Files.createTempFile("WriteBytesBench", "dat"); + tempFileWriting.toFile().deleteOnExit(); + fout = Files.newOutputStream(tempFileWriting); + dataOut = new WritableStreamingData(fout); + tempFileReading = Files.createTempFile("WriteBytesBench", "dat"); + tempFileReading.toFile().deleteOnExit(); + Files.write(tempFileReading, sampleWrittenData); + } catch (IOException e) { + e.printStackTrace(); + throw new UncheckedIOException(e); + } + } + + @TearDown + public void cleanUp() { + try { + dataOut.close(); + fout.close(); + } catch (IOException e){ + e.printStackTrace(); + throw new UncheckedIOException(e); + } + } + + @Benchmark + public void writeBytes(Blackhole blackhole) throws IOException { + ProtoWriterTools.writeBytes(dataOut, BYTES_FIELD, sampleData); + } + + @Benchmark + @OperationsPerInvocation(100) + public void readBytes(Blackhole blackhole) throws IOException { + try (ReadableStreamingData in = new ReadableStreamingData(Files.newInputStream(tempFileReading)) ) { + for (int i = 0; i < 100; i++) { + blackhole.consume(in.readVarInt(false)); + blackhole.consume(ProtoParserTools.readBytes(in)); + } + } + } +} diff --git a/pbj-core/pbj-runtime/src/jmh/java/com/hedera/pbj/runtime/io/WriteBytesBench.java b/pbj-core/pbj-runtime/src/jmh/java/com/hedera/pbj/runtime/io/WriteBytesBench.java new file mode 100644 index 00000000..aca18db0 --- /dev/null +++ b/pbj-core/pbj-runtime/src/jmh/java/com/hedera/pbj/runtime/io/WriteBytesBench.java @@ -0,0 +1,110 @@ +package com.hedera.pbj.runtime.io; + +import com.hedera.pbj.runtime.FieldDefinition; +import com.hedera.pbj.runtime.FieldType; +import com.hedera.pbj.runtime.ProtoParserTools; +import com.hedera.pbj.runtime.ProtoWriterTools; +import com.hedera.pbj.runtime.io.buffer.Bytes; +import com.hedera.pbj.runtime.io.stream.ReadableStreamingData; +import com.hedera.pbj.runtime.io.stream.WritableStreamingData; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.OutputStream; +import java.io.UncheckedIOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.Random; +import java.util.concurrent.TimeUnit; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.OperationsPerInvocation; +import org.openjdk.jmh.annotations.OutputTimeUnit; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.TearDown; +import org.openjdk.jmh.annotations.Warmup; +import org.openjdk.jmh.infra.Blackhole; + +@SuppressWarnings("unused") +@State(Scope.Benchmark) +@Fork(1) +@Warmup(iterations = 4, time = 2) +@Measurement(iterations = 5, time = 2) +@OutputTimeUnit(TimeUnit.SECONDS) +@BenchmarkMode(Mode.Throughput) +public class WriteBytesBench { + + public static final FieldDefinition BYTES_FIELD = new FieldDefinition("bytesField", FieldType.BYTES, false, false, false, 17); + final static Bytes sampleData; + final static byte[] sampleWrittenData; + + static { + final Random random = new Random(6262266); + byte[] data = new byte[1024*16]; + random.nextBytes(data); + sampleData = Bytes.wrap(data); + + ByteArrayOutputStream bout = new ByteArrayOutputStream(); + try (WritableStreamingData out = new WritableStreamingData(bout)) { + for (int i = 0; i < 100; i++) { + random.nextBytes(data); + ProtoWriterTools.writeBytes(out, BYTES_FIELD, sampleData); + } + } catch (IOException e) { + e.printStackTrace(); + } + sampleWrittenData = bout.toByteArray(); + } + + Path tempFileWriting; + Path tempFileReading; + OutputStream fout; + WritableStreamingData dataOut; + + @Setup + public void prepare() { + try { + tempFileWriting = Files.createTempFile("WriteBytesBench", "dat"); + tempFileWriting.toFile().deleteOnExit(); + fout = Files.newOutputStream(tempFileWriting); + dataOut = new WritableStreamingData(fout); + tempFileReading = Files.createTempFile("WriteBytesBench", "dat"); + tempFileReading.toFile().deleteOnExit(); + Files.write(tempFileReading, sampleWrittenData); + } catch (IOException e) { + e.printStackTrace(); + throw new UncheckedIOException(e); + } + } + + @TearDown + public void cleanUp() { + try { + dataOut.close(); + fout.close(); + } catch (IOException e){ + e.printStackTrace(); + throw new UncheckedIOException(e); + } + } + + @Benchmark + public void writeBytes(Blackhole blackhole) throws IOException { + ProtoWriterTools.writeBytes(dataOut, BYTES_FIELD, sampleData); + } + + @Benchmark + @OperationsPerInvocation(100) + public void readBytes(Blackhole blackhole) throws IOException { + try (ReadableStreamingData in = new ReadableStreamingData(Files.newInputStream(tempFileReading)) ) { + for (int i = 0; i < 100; i++) { + blackhole.consume(in.readVarInt(false)); + blackhole.consume(ProtoParserTools.readBytes(in)); + } + } + } +} diff --git a/pbj-core/pbj-runtime/src/main/java/com/hedera/pbj/runtime/io/buffer/BufferedData.java b/pbj-core/pbj-runtime/src/main/java/com/hedera/pbj/runtime/io/buffer/BufferedData.java index c0d254c9..aaa7d0f1 100644 --- a/pbj-core/pbj-runtime/src/main/java/com/hedera/pbj/runtime/io/buffer/BufferedData.java +++ b/pbj-core/pbj-runtime/src/main/java/com/hedera/pbj/runtime/io/buffer/BufferedData.java @@ -8,10 +8,13 @@ import edu.umd.cs.findbugs.annotations.NonNull; import java.io.IOException; import java.io.InputStream; +import java.io.OutputStream; import java.nio.BufferOverflowException; import java.nio.BufferUnderflowException; import java.nio.ByteBuffer; import java.nio.ByteOrder; +import java.nio.channels.Channels; +import java.nio.channels.WritableByteChannel; /** * A buffer backed by a {@link ByteBuffer} that is a {@link BufferedSequentialData} (and therefore contains @@ -869,6 +872,33 @@ public void writeVarLong(long value, final boolean zigZag) { } } + /** + * {@inheritDoc} + */ + @Override + public void writeTo(@NonNull OutputStream outStream) { + try { + final WritableByteChannel channel = Channels.newChannel(outStream); + channel.write(buffer.duplicate().position(0).limit(buffer.limit())); + } catch (IOException e) { + throw new DataAccessException(e); + } + } + + /** + * {@inheritDoc} + */ + @Override + public void writeTo(@NonNull OutputStream outStream, int offset, int length) { + validateCanRead(offset, length); + try { + final WritableByteChannel channel = Channels.newChannel(outStream); + channel.write(buffer.duplicate().position(offset).limit(offset + length)); + } catch (IOException e) { + throw new DataAccessException(e); + } + } + // Helper methods protected void validateLen(final long len) { diff --git a/pbj-core/pbj-runtime/src/main/java/com/hedera/pbj/runtime/io/buffer/Bytes.java b/pbj-core/pbj-runtime/src/main/java/com/hedera/pbj/runtime/io/buffer/Bytes.java index e0352932..e848ed6b 100644 --- a/pbj-core/pbj-runtime/src/main/java/com/hedera/pbj/runtime/io/buffer/Bytes.java +++ b/pbj-core/pbj-runtime/src/main/java/com/hedera/pbj/runtime/io/buffer/Bytes.java @@ -238,11 +238,9 @@ public void writeTo(@NonNull final ByteBuffer dstBuffer, final int offset, final } /** - * A helper method for efficient copy of our data into an OutputStream without creating a defensive copy - * of the data. The implementation relies on a well-behaved OutputStream that doesn't modify the buffer data. - * - * @param outStream the OutputStream to copy into + * {@inheritDoc} */ + @Override public void writeTo(@NonNull final OutputStream outStream) { try { outStream.write(buffer, start, length); @@ -252,13 +250,9 @@ public void writeTo(@NonNull final OutputStream outStream) { } /** - * A helper method for efficient copy of our data into an OutputStream without creating a defensive copy - * of the data. The implementation relies on a well-behaved OutputStream that doesn't modify the buffer data. - * - * @param outStream The OutputStream to copy into. - * @param offset The offset from the start of this {@link Bytes} object to get the bytes from. - * @param length The number of bytes to extract. + * {@inheritDoc} */ + @Override public void writeTo(@NonNull final OutputStream outStream, final int offset, final int length) { try { outStream.write(buffer, offset, length); diff --git a/pbj-core/pbj-runtime/src/main/java/com/hedera/pbj/runtime/io/buffer/RandomAccessData.java b/pbj-core/pbj-runtime/src/main/java/com/hedera/pbj/runtime/io/buffer/RandomAccessData.java index d2e5b864..2c217165 100644 --- a/pbj-core/pbj-runtime/src/main/java/com/hedera/pbj/runtime/io/buffer/RandomAccessData.java +++ b/pbj-core/pbj-runtime/src/main/java/com/hedera/pbj/runtime/io/buffer/RandomAccessData.java @@ -4,13 +4,16 @@ import com.hedera.pbj.runtime.io.SequentialData; import edu.umd.cs.findbugs.annotations.NonNull; +import java.io.IOException; +import java.io.OutputStream; +import java.io.UncheckedIOException; import java.nio.BufferUnderflowException; import java.nio.ByteBuffer; import java.nio.ByteOrder; import java.nio.charset.StandardCharsets; /** - * Represents data which may be accessed out of order in some random manner. Unliked {@link SequentialData}, + * Represents data which may be accessed out of order in some random manner. Unlike {@link SequentialData}, * this interface is only backed by a buffer of some kind: an array, a {@link ByteBuffer}, a memory-mapped file, etc. * Unlike {@link BufferedSequentialData}, it does not define any kind of "position" cursor, just a "length" representing * the valid range of indexes and methods for reading data at any of those indexes. @@ -546,4 +549,22 @@ default boolean contains(final long offset, @NonNull final RandomAccessData data } return true; } + + /** + * A helper method for efficient copy of our data into an OutputStream without creating a defensive copy + * of the data. The implementation relies on a well-behaved OutputStream that doesn't modify the buffer data. + * + * @param outStream the OutputStream to copy into + */ + void writeTo(@NonNull final OutputStream outStream); + + /** + * A helper method for efficient copy of our data into an OutputStream without creating a defensive copy + * of the data. The implementation relies on a well-behaved OutputStream that doesn't modify the buffer data. + * + * @param outStream The OutputStream to copy into. + * @param offset The offset from the start of this {@link Bytes} object to get the bytes from. + * @param length The number of bytes to extract. + */ + void writeTo(@NonNull final OutputStream outStream, final int offset, final int length); } diff --git a/pbj-core/pbj-runtime/src/main/java/com/hedera/pbj/runtime/io/stream/WritableStreamingData.java b/pbj-core/pbj-runtime/src/main/java/com/hedera/pbj/runtime/io/stream/WritableStreamingData.java index 079424b5..d14b1adf 100644 --- a/pbj-core/pbj-runtime/src/main/java/com/hedera/pbj/runtime/io/stream/WritableStreamingData.java +++ b/pbj-core/pbj-runtime/src/main/java/com/hedera/pbj/runtime/io/stream/WritableStreamingData.java @@ -2,6 +2,7 @@ import com.hedera.pbj.runtime.io.DataAccessException; import com.hedera.pbj.runtime.io.WritableSequentialData; +import com.hedera.pbj.runtime.io.buffer.RandomAccessData; import edu.umd.cs.findbugs.annotations.NonNull; import java.io.Closeable; import java.io.Flushable; @@ -193,6 +194,9 @@ public void writeBytes(@NonNull final byte[] src) { } } + /** + * {@inheritDoc} + */ @Override public void writeBytes(@NonNull final ByteBuffer src) { if (!src.hasArray()) { @@ -221,6 +225,19 @@ public void writeBytes(@NonNull final ByteBuffer src) { } } + /** + * {@inheritDoc} + */ + @Override + public void writeBytes(@NonNull final RandomAccessData src) { + final long len = src.length(); + if (remaining() < len) { + throw new BufferOverflowException(); + } + src.writeTo(out); + position += len; + } + // ================================================================================================================ // Flushable Methods diff --git a/pbj-core/pbj-runtime/src/test/java/com/hedera/pbj/runtime/io/buffer/StubbedRandomAccessDataTest.java b/pbj-core/pbj-runtime/src/test/java/com/hedera/pbj/runtime/io/buffer/StubbedRandomAccessDataTest.java index e58c7ce1..ff27c080 100644 --- a/pbj-core/pbj-runtime/src/test/java/com/hedera/pbj/runtime/io/buffer/StubbedRandomAccessDataTest.java +++ b/pbj-core/pbj-runtime/src/test/java/com/hedera/pbj/runtime/io/buffer/StubbedRandomAccessDataTest.java @@ -1,7 +1,10 @@ package com.hedera.pbj.runtime.io.buffer; +import com.hedera.pbj.runtime.io.DataAccessException; import com.hedera.pbj.runtime.io.ReadableSequentialData; import edu.umd.cs.findbugs.annotations.NonNull; +import java.io.IOException; +import java.io.OutputStream; import static org.assertj.core.api.Assertions.assertThat; @@ -43,5 +46,23 @@ public long length() { public byte getByte(long offset) { return bytes[Math.toIntExact(offset)]; } + + @Override + public void writeTo(@NonNull OutputStream outStream) { + try { + outStream.write(bytes); + } catch (IOException e) { + throw new DataAccessException(e); + } + } + + @Override + public void writeTo(@NonNull OutputStream outStream, int offset, int length) { + try { + outStream.write(bytes, offset, length); + } catch (IOException e) { + throw new DataAccessException(e); + } + } } } diff --git a/pbj-core/pbj-runtime/src/test/java/com/hedera/pbj/runtime/io/stream/WritableStreamingDataTest.java b/pbj-core/pbj-runtime/src/test/java/com/hedera/pbj/runtime/io/stream/WritableStreamingDataTest.java index c9c0ce26..8e08ce1c 100644 --- a/pbj-core/pbj-runtime/src/test/java/com/hedera/pbj/runtime/io/stream/WritableStreamingDataTest.java +++ b/pbj-core/pbj-runtime/src/test/java/com/hedera/pbj/runtime/io/stream/WritableStreamingDataTest.java @@ -3,6 +3,7 @@ import com.hedera.pbj.runtime.io.DataAccessException; import com.hedera.pbj.runtime.io.WritableSequentialData; import com.hedera.pbj.runtime.io.WritableTestBase; +import com.hedera.pbj.runtime.io.buffer.RandomAccessData; import edu.umd.cs.findbugs.annotations.NonNull; import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.Test; @@ -15,9 +16,11 @@ import java.nio.charset.StandardCharsets; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.junit.jupiter.api.Assertions.assertEquals; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyInt; import static org.mockito.Mockito.doNothing; +import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; @@ -118,4 +121,24 @@ void testFlushable() throws IOException { verify(out, times(1)).flush(); verifyNoMoreInteractions(out); } + + @Test + @DisplayName("writeBytes(RandomAccessData) should delegate to RandomAccessData.writeTo(OutputStream)") + void testWriteBytesFastPath() { + final OutputStream out = mock(OutputStream.class); + final RandomAccessData data = mock(RandomAccessData.class); + doReturn(10L).when(data).length(); + doNothing().when(data).writeTo(out); + + final WritableStreamingData seq = new WritableStreamingData(out); + + seq.writeBytes(data); + + verify(data, times(1)).length(); + verify(data, times(1)).writeTo(out); + verifyNoMoreInteractions(data, out); + + assertEquals(10L, seq.position()); + } + }