diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/FixedLengthRowBasedKeyValueBatch.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/FixedLengthRowBasedKeyValueBatch.java index a88a315bf479f..f2ff34ae50dbc 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/FixedLengthRowBasedKeyValueBatch.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/FixedLengthRowBasedKeyValueBatch.java @@ -79,6 +79,9 @@ public UnsafeRow getKeyRow(int rowId) { keyRow.pointTo(base, offset, klen); // set keyRowId so we can check if desired row is cached keyRowId = rowId; + isValueCached = false; + } else { + isValueCached = true; } return keyRow; } @@ -91,10 +94,12 @@ public UnsafeRow getKeyRow(int rowId) { */ @Override protected UnsafeRow getValueFromKey(int rowId) { - if (keyRowId != rowId) { + assert(rowId >= 0); + if (keyRowId == rowId && isValueCached) { + return valueRow; + } else if (keyRowId != rowId) { getKeyRow(rowId); } - assert(rowId >= 0); valueRow.pointTo(base, keyRow.getBaseOffset() + klen, vlen + 4); return valueRow; } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/RowBasedKeyValueBatch.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/RowBasedKeyValueBatch.java index 551443a11298b..3f2225ab3cadf 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/RowBasedKeyValueBatch.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/RowBasedKeyValueBatch.java @@ -62,6 +62,9 @@ public abstract class RowBasedKeyValueBatch extends MemoryConsumer { protected final UnsafeRow keyRow; protected final UnsafeRow valueRow; + // mark valueRow as cached after calling getKeyRow if rowId == keyRowId + protected boolean isValueCached = false; + protected MemoryBlock page = null; protected Object base = null; protected final long recordStartOffset; diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/VariableLengthRowBasedKeyValueBatch.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/VariableLengthRowBasedKeyValueBatch.java index ea4f984be24e5..3c4c6cd804667 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/VariableLengthRowBasedKeyValueBatch.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/VariableLengthRowBasedKeyValueBatch.java @@ -83,6 +83,9 @@ public UnsafeRow getKeyRow(int rowId) { keyRow.pointTo(base, offset, klen); // set keyRowId so we can check if desired row is cached keyRowId = rowId; + isValueCached = false; + } else { + isValueCached = true; } return keyRow; } @@ -95,10 +98,12 @@ public UnsafeRow getKeyRow(int rowId) { */ @Override public UnsafeRow getValueFromKey(int rowId) { - if (keyRowId != rowId) { + assert(rowId >= 0); + if (keyRowId == rowId && isValueCached) { + return valueRow; + } else if (keyRowId != rowId) { getKeyRow(rowId); } - assert(rowId >= 0); long offset = keyRow.getBaseOffset(); int klen = keyRow.getSizeInBytes(); int vlen = Platform.getInt(base, offset - 8) - klen - 4;