Skip to content

Commit ecff4ff

Browse files
committed
Update simple row batch to improve performance & use SimpleRowBatch by default
1 parent a158125 commit ecff4ff

File tree

2 files changed

+55
-65
lines changed

2 files changed

+55
-65
lines changed

sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SimpleRowBatch.java

Lines changed: 53 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,6 @@
4242
public final class SimpleRowBatch extends MemoryConsumer{
4343
private static final int DEFAULT_CAPACITY = 1 << 16;
4444

45-
private final TaskMemoryManager taskMemoryManager;
46-
4745
private final StructType keySchema;
4846
private final StructType valueSchema;
4947
private final int capacity;
@@ -55,17 +53,10 @@ public final class SimpleRowBatch extends MemoryConsumer{
5553

5654
// ids for current key row and value row being retrieved
5755
private int keyRowId = -1;
58-
private int valueRowId = -1;
5956

6057
// full addresses for key rows and value rows
6158
// TODO: opt: this could be eliminated if all fields are fixed length
62-
private long[] keyFullAddress;
63-
private long[] valueFullAddress;
64-
// shortcuts for lengths, which can also be retrieved directly from UnsafeRow
65-
// TODO: might want to remove this shortcut, retrieving directly from UnsafeRow could be
66-
// faster due to cache locality
67-
private int[] keyLength;
68-
private int[] valueLength;
59+
private long[] keyOffsets;
6960

7061
// if all data types in the schema are fixed length
7162
private boolean allFixedLength;
@@ -118,7 +109,6 @@ private long getKeyOffsetForFixedLengthRecords(int rowId) {
118109

119110
public UnsafeRow appendRow(Object kbase, long koff, int klen,
120111
Object vbase, long voff, int vlen) {
121-
122112
final long recordLength = 8 + klen + vlen + 8;
123113
// if run out of max supported rows or page size, return null
124114
if (numRows >= capacity || currentPage == null
@@ -137,8 +127,10 @@ public UnsafeRow appendRow(Object kbase, long koff, int klen,
137127
final Object base = currentPage.getBaseObject();
138128
long offset = currentPage.getBaseOffset() + pageCursor;
139129
final long recordOffset = offset;
140-
Platform.putInt(base, offset, klen + vlen + 4);
141-
Platform.putInt(base, offset + 4, klen);
130+
if (!allFixedLength) { // we only put lengths info for variable length
131+
Platform.putInt(base, offset, klen + vlen + 4);
132+
Platform.putInt(base, offset + 4, klen);
133+
}
142134
offset += 8;
143135
Platform.copyMemory(kbase, koff, base, offset, klen);
144136
offset += klen;
@@ -150,15 +142,10 @@ public UnsafeRow appendRow(Object kbase, long koff, int klen,
150142
Platform.putInt(base, offset, Platform.getInt(base, offset) + 1);
151143
pageCursor += recordLength;
152144

153-
final long storedKeyAddress = taskMemoryManager.encodePageNumberAndOffset(currentPage,
154-
recordOffset);
155-
keyFullAddress[numRows] = storedKeyAddress + 8;
156-
valueFullAddress[numRows] = storedKeyAddress + 8 + klen;
157-
keyLength[numRows] = klen;
158-
valueLength[numRows] = vlen;
145+
146+
if (!allFixedLength) keyOffsets[numRows] = recordOffset + 8;
159147

160148
keyRowId = numRows;
161-
valueRowId = numRows;
162149
keyRow.pointTo(base, recordOffset + 8, klen);
163150
valueRow.pointTo(base, recordOffset + 8 + klen, vlen + 4);
164151
numRows++;
@@ -171,47 +158,43 @@ public UnsafeRow appendRow(Object kbase, long koff, int klen,
171158
public UnsafeRow getKeyRow(int rowId) {
172159
assert(rowId >= 0);
173160
assert(rowId < numRows);
174-
if (keyRowId != rowId) {
175-
long offset = getKeyOffsetForFixedLengthRecords(rowId);
176-
keyRow.pointTo(currentAndOnlyBase, offset, klen);
161+
if (keyRowId != rowId) { // if keyRowId == rowId, desired keyRow is already cached
162+
if (allFixedLength) {
163+
long offset = getKeyOffsetForFixedLengthRecords(rowId);
164+
keyRow.pointTo(currentAndOnlyBase, offset, klen);
165+
} else {
166+
long offset = keyOffsets[rowId];
167+
klen = Platform.getInt(currentAndOnlyBase, offset - 4);
168+
keyRow.pointTo(currentAndOnlyBase, offset, klen);
169+
}
170+
// set keyRowId so we can check if desired row is cached
177171
keyRowId = rowId;
178172
}
179173
return keyRow;
180174
}
181175

182-
/**
183-
* Returns the value row in this batch at `rowId`. Returned value row is reused across calls.
184-
* Should be avoided if `getValueFromKey()` gives better performance.
185-
*/
186-
public UnsafeRow getValueRow(int rowId) {
187-
assert(rowId >= 0);
188-
assert(rowId < numRows);
189-
if (valueRowId != rowId) {
190-
long offset = getKeyOffsetForFixedLengthRecords(rowId) + klen;
191-
valueRow.pointTo(currentAndOnlyBase, offset, vlen + 4);
192-
valueRowId = rowId;
193-
}
194-
return valueRow;
195-
}
196-
197176
/**
198177
* Returns the value row in this batch at `rowId`.
199178
* It can be a faster path if `keyRowId` is equal to `rowId`, which means the preceding
200-
* key row has just been accessed. As this is often the case, this method should be preferred
201-
* over `getValueRow()`.
202-
* This method is faster than `getValueRow()` because it avoids address decoding, instead reuse
203-
* the page and offset information from the preceding key row.
179+
* key row has just been accessed. This is always the case so far.
204180
* Returned value row is reused across calls.
205181
*/
206182
public UnsafeRow getValueFromKey(int rowId) {
207183
if (keyRowId != rowId) {
208184
getKeyRow(rowId);
209185
}
210186
assert(rowId >= 0);
211-
valueRow.pointTo(currentAndOnlyBase,
212-
keyRow.getBaseOffset() + klen,
213-
vlen + 4);
214-
valueRowId = rowId;
187+
if (allFixedLength) {
188+
valueRow.pointTo(currentAndOnlyBase,
189+
keyRow.getBaseOffset() + klen,
190+
vlen + 4);
191+
} else {
192+
long offset = keyOffsets[rowId];
193+
vlen = Platform.getInt(currentAndOnlyBase, offset - 8) - klen - 4;
194+
valueRow.pointTo(currentAndOnlyBase,
195+
offset + klen,
196+
vlen + 4);
197+
}
215198
return valueRow;
216199
}
217200

@@ -232,8 +215,8 @@ public org.apache.spark.unsafe.KVIterator<UnsafeRow, UnsafeRow> rowIterator() {
232215
private long offsetInPage = 0;
233216
private int recordsInPage = 0;
234217

235-
private int klen;
236-
private int vlen;
218+
private int currentklen;
219+
private int currentvlen;
237220
private int totalLength;
238221

239222
private boolean inited = false;
@@ -257,12 +240,18 @@ public boolean next() {
257240
if (!advanceToNextPage()) return false;
258241
}
259242

260-
totalLength = Platform.getInt(pageBaseObject, offsetInPage);
261-
klen = Platform.getInt(pageBaseObject, offsetInPage + 4);
262-
vlen = totalLength - klen;
243+
if (allFixedLength) {
244+
totalLength = klen + vlen + 4;
245+
currentklen = klen;
246+
currentvlen = vlen;
247+
} else {
248+
totalLength = Platform.getInt(pageBaseObject, offsetInPage);
249+
klen = Platform.getInt(pageBaseObject, offsetInPage + 4);
250+
currentvlen = totalLength - currentklen - 4;
251+
}
263252

264-
key.pointTo(pageBaseObject, offsetInPage + 8, klen);
265-
value.pointTo(pageBaseObject, offsetInPage + 8 + klen, vlen);
253+
key.pointTo(pageBaseObject, offsetInPage + 8, currentklen);
254+
value.pointTo(pageBaseObject, offsetInPage + 8 + currentklen, currentvlen + 4);
266255
offsetInPage += 4 + totalLength + 8;
267256
recordsInPage -= 1;
268257
return true;
@@ -306,17 +295,12 @@ private SimpleRowBatch(StructType keySchema, StructType valueSchema, int maxRows
306295
this.keySchema = keySchema;
307296
this.valueSchema = valueSchema;
308297
this.capacity = maxRows;
309-
this.taskMemoryManager = manager;
310-
this.keyFullAddress = new long[maxRows];
311-
this.valueFullAddress = new long[maxRows];
312-
this.keyLength = new int[maxRows];
313-
this.valueLength = new int[maxRows];
314298

315299
this.keyRow = new UnsafeRow(keySchema.length());
316300
this.valueRow = new UnsafeRow(valueSchema.length());
317301

318302
// checking if there is any variable length fields
319-
// there is probably more succint impl of this
303+
// there is probably a more succint impl of this
320304
allFixedLength = true;
321305
for (String name : keySchema.fieldNames()) {
322306
allFixedLength = allFixedLength
@@ -326,10 +310,16 @@ private SimpleRowBatch(StructType keySchema, StructType valueSchema, int maxRows
326310
allFixedLength = allFixedLength
327311
&& UnsafeRow.isFixedLength(valueSchema.apply(name).dataType());
328312
}
329-
klen = keySchema.defaultSize() + UnsafeRow.calculateBitSetWidthInBytes(keySchema.length());
330-
vlen = valueSchema.defaultSize()
331-
+ UnsafeRow.calculateBitSetWidthInBytes(valueSchema.length());
332-
recordLength = 8 + klen + vlen + 8;
313+
if (allFixedLength) {
314+
klen = keySchema.defaultSize()
315+
+ UnsafeRow.calculateBitSetWidthInBytes(keySchema.length());
316+
vlen = valueSchema.defaultSize()
317+
+ UnsafeRow.calculateBitSetWidthInBytes(valueSchema.length());
318+
recordLength = 8 + klen + vlen + 8;
319+
} else {
320+
// we only need the following data structures for variable length cases
321+
this.keyOffsets = new long[maxRows];
322+
}
333323

334324
if (!acquireNewPage(64 * 1024 * 1024)) { //64MB
335325
currentPage = null;

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ class RowBasedHashMapGenerator(
109109
}.mkString("\n").concat(";")
110110

111111
s"""
112-
| private org.apache.spark.sql.catalyst.expressions.RowBatch batch;
112+
| private org.apache.spark.sql.catalyst.expressions.SimpleRowBatch batch;
113113
| private int[] buckets;
114114
| private int capacity = 1 << 16;
115115
| private double loadFactor = 0.5;
@@ -130,7 +130,7 @@ class RowBasedHashMapGenerator(
130130
| super(taskMemoryManager,
131131
| taskMemoryManager.pageSizeBytes(),
132132
| taskMemoryManager.getTungstenMemoryMode());
133-
| batch = org.apache.spark.sql.catalyst.expressions.RowBatch.allocate(keySchema,
133+
| batch = org.apache.spark.sql.catalyst.expressions.SimpleRowBatch.allocate(keySchema,
134134
| valueSchema, taskMemoryManager, capacity);
135135
|
136136
| final UnsafeProjection valueProjection = UnsafeProjection.create(valueSchema);

0 commit comments

Comments
 (0)