BytesToBytesMap.java
@@ -164,20 +164,17 @@ public final class BytesToBytesMap extends MemoryConsumer {

private long peakMemoryUsedBytes = 0L;

private final BlockManager blockManager;
private volatile MapIterator destructiveIterator = null;
private LinkedList<UnsafeSorterSpillWriter> spillWriters = new LinkedList<>();

public BytesToBytesMap(
TaskMemoryManager taskMemoryManager,
BlockManager blockManager,
int initialCapacity,
double loadFactor,
long pageSizeBytes,
boolean enablePerfMetrics) {
super(taskMemoryManager, pageSizeBytes);
this.taskMemoryManager = taskMemoryManager;
this.blockManager = blockManager;
this.loadFactor = loadFactor;
this.loc = new Location();
this.pageSizeBytes = pageSizeBytes;
@@ -210,7 +207,6 @@ public BytesToBytesMap(
boolean enablePerfMetrics) {
this(
taskMemoryManager,
SparkEnv.get() != null ? SparkEnv.get().blockManager() : null,
initialCapacity,
0.70,
pageSizeBytes,
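
After this change the map no longer takes a BlockManager in its constructors; the block manager is only needed when the map actually spills, and it is then fetched from SparkEnv. A minimal sketch of a call site under the new signature (the capacity, load factor, and page size values below are illustrative, and taskMemoryManager is assumed to already exist):

// Sketch only: constructing the map after this patch. There is no BlockManager
// argument; SparkEnv.get().blockManager() is resolved internally when (and only
// when) the map spills.
BytesToBytesMap map = new BytesToBytesMap(
    taskMemoryManager,  // TaskMemoryManager for the running task
    64,                 // initial capacity (illustrative)
    0.70,               // load factor
    1 << 20,            // page size in bytes (illustrative)
    false);             // enablePerfMetrics
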
@@ -272,7 +268,7 @@ private void advanceToNextPage() {
}
}
try {
reader = spillWriters.getFirst().getReader(blockManager);
reader = spillWriters.getFirst().getReader(SparkEnv.get().blockManager());
recordsInPage = -1;
} catch (IOException e) {
// Scala iterator does not handle exception
@@ -347,6 +343,7 @@ public long spill(long numBytes) throws IOException {
long offset = block.getBaseOffset();
int numRecords = Platform.getInt(base, offset);
offset += 4;
BlockManager blockManager = SparkEnv.get().blockManager();
final UnsafeSorterSpillWriter writer =
new UnsafeSorterSpillWriter(blockManager, 32 * 1024, writeMetrics, numRecords);
while (numRecords > 0) {
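
Taken together, the two hunks above replace the cached blockManager field with an on-demand lookup. A condensed view of the new pattern (the two lookup lines are reproduced from the diff; surrounding code is elided):

// In spill(): fetch the BlockManager just before creating the spill writer.
BlockManager blockManager = SparkEnv.get().blockManager();
final UnsafeSorterSpillWriter writer =
    new UnsafeSorterSpillWriter(blockManager, 32 * 1024, writeMetrics, numRecords);

// In the iterator's advanceToNextPage(): fetch it again when reading spilled data back.
reader = spillWriters.getFirst().getReader(SparkEnv.get().blockManager());

Because the lookup is deferred to spill time, code that never spills no longer needs a SparkEnv at all, which is why the constructor's "SparkEnv.get() != null ? ... : null" fallback could be dropped.
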
AbstractBytesToBytesMapSuite.java
@@ -17,45 +17,29 @@

package org.apache.spark.unsafe.map;

import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.ByteBuffer;
import java.util.*;

import scala.Tuple2;
import scala.Tuple2$;
import scala.runtime.AbstractFunction1;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.Random;

import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;

import org.apache.spark.SparkConf;
import org.apache.spark.executor.ShuffleWriteMetrics;
import org.apache.spark.memory.TestMemoryManager;
import org.apache.spark.SparkContext;
import org.apache.spark.memory.TaskMemoryManager;
import org.apache.spark.serializer.SerializerInstance;
import org.apache.spark.storage.*;
import org.apache.spark.memory.TestMemoryManager;
import org.apache.spark.unsafe.Platform;
import org.apache.spark.unsafe.array.ByteArrayMethods;
import org.apache.spark.unsafe.memory.MemoryLocation;
import org.apache.spark.util.Utils;

import static org.hamcrest.Matchers.greaterThan;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.mockito.AdditionalAnswers.returnsSecondArg;
import static org.mockito.Answers.RETURNS_SMART_NULLS;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.anyInt;
import static org.mockito.Mockito.when;


public abstract class AbstractBytesToBytesMapSuite {
@@ -66,19 +50,6 @@ public abstract class AbstractBytesToBytesMapSuite {
private TaskMemoryManager taskMemoryManager;
private final long PAGE_SIZE_BYTES = 1L << 26; // 64 megabytes

final LinkedList<File> spillFilesCreated = new LinkedList<File>();
File tempDir;

@Mock(answer = RETURNS_SMART_NULLS) BlockManager blockManager;
@Mock(answer = RETURNS_SMART_NULLS) DiskBlockManager diskBlockManager;

private static final class CompressStream extends AbstractFunction1<OutputStream, OutputStream> {
@Override
public OutputStream apply(OutputStream stream) {
return stream;
}
}

@Before
public void setup() {
memoryManager =
@@ -87,50 +58,10 @@ public void setup() {
.set("spark.unsafe.offHeap", "" + useOffHeapMemoryAllocator())
.set("spark.memory.offHeapSize", "256mb"));
taskMemoryManager = new TaskMemoryManager(memoryManager, 0);

tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "unsafe-test");
spillFilesCreated.clear();
MockitoAnnotations.initMocks(this);
when(blockManager.diskBlockManager()).thenReturn(diskBlockManager);
when(diskBlockManager.createTempLocalBlock()).thenAnswer(new Answer<Tuple2<TempLocalBlockId, File>>() {
@Override
public Tuple2<TempLocalBlockId, File> answer(InvocationOnMock invocationOnMock) throws Throwable {
TempLocalBlockId blockId = new TempLocalBlockId(UUID.randomUUID());
File file = File.createTempFile("spillFile", ".spill", tempDir);
spillFilesCreated.add(file);
return Tuple2$.MODULE$.apply(blockId, file);
}
});
when(blockManager.getDiskWriter(
any(BlockId.class),
any(File.class),
any(SerializerInstance.class),
anyInt(),
any(ShuffleWriteMetrics.class))).thenAnswer(new Answer<DiskBlockObjectWriter>() {
@Override
public DiskBlockObjectWriter answer(InvocationOnMock invocationOnMock) throws Throwable {
Object[] args = invocationOnMock.getArguments();

return new DiskBlockObjectWriter(
(File) args[1],
(SerializerInstance) args[2],
(Integer) args[3],
new CompressStream(),
false,
(ShuffleWriteMetrics) args[4],
(BlockId) args[0]
);
}
});
when(blockManager.wrapForCompression(any(BlockId.class), any(InputStream.class)))
.then(returnsSecondArg());
}

@After
public void tearDown() {
Utils.deleteRecursively(tempDir);
tempDir = null;

Assert.assertEquals(0L, taskMemoryManager.cleanUpAllAllocatedMemory());
if (taskMemoryManager != null) {
long leakedMemory = taskMemoryManager.getMemoryConsumptionForThisTask();
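
With the mocked BlockManager, DiskBlockManager, and temp-directory bookkeeping removed by this hunk, setup() reduces to wiring up the memory managers. A sketch of the trimmed method, assuming the field declarations shown earlier in this file:

@Before
public void setup() {
  // Only the memory-management plumbing remains; no Mockito stubs for
  // BlockManager or DiskBlockManager are needed anymore.
  memoryManager =
    new TestMemoryManager(
      new SparkConf()
        .set("spark.unsafe.offHeap", "" + useOffHeapMemoryAllocator())
        .set("spark.memory.offHeapSize", "256mb"));
  taskMemoryManager = new TaskMemoryManager(memoryManager, 0);
}
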
@@ -537,7 +468,8 @@ public void failureToGrow() {

@Test
public void spillInIterator() throws IOException {
BytesToBytesMap map = new BytesToBytesMap(taskMemoryManager, blockManager, 1, 0.75, 1024, false);
SparkContext sc = new SparkContext("local", "BytesToBytesMapSuite");
BytesToBytesMap map = new BytesToBytesMap(taskMemoryManager, 1, 0.75, 1024, false);
try {
int i;
for (i = 0; i < 1024; i++) {
@@ -566,10 +498,7 @@ public void spillInIterator() throws IOException {
assertFalse(iter2.hasNext());
} finally {
map.free();
for (File spillFile : spillFilesCreated) {
assertFalse("Spill file " + spillFile.getPath() + " was not cleaned up",
spillFile.exists());
}
sc.stop();
}
}
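
Because the map now reaches for SparkEnv when it spills, this test can no longer rely on a Mockito BlockManager; instead it starts a local SparkContext so a real SparkEnv (and BlockManager) exists. A condensed sketch of the test's new shape, with the record insertion and assertion details elided:

@Test
public void spillInIterator() throws IOException {
  // A local SparkContext supplies the SparkEnv that BytesToBytesMap consults
  // when it spills; without it SparkEnv.get() would return null.
  SparkContext sc = new SparkContext("local", "BytesToBytesMapSuite");
  BytesToBytesMap map = new BytesToBytesMap(taskMemoryManager, 1, 0.75, 1024, false);
  try {
    // ... insert records and force the iterator to spill mid-iteration ...
  } finally {
    map.free();
    sc.stop();  // replaces the old assertion that mocked spill files were deleted
  }
}
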
