Skip to content

Commit

Permalink
Merge pull request #376 from mairooni/fix/batch_copyout
Browse files Browse the repository at this point in the history
Support lazy copy-out for batch processing
  • Loading branch information
jjfumero authored Apr 16, 2024
2 parents ac476de + 0187e0b commit 4204e3b
Show file tree
Hide file tree
Showing 3 changed files with 204 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ public int read(long executionPlanId, final Object reference, long hostOffset, l
} else if (batchSize <= 0) {
returnEvent = deviceContext.readBuffer(executionPlanId, toBuffer(), numBytes, segment.address(), hostOffset, (useDeps) ? events : null);
} else {
returnEvent = deviceContext.readBuffer(executionPlanId, toBuffer() + TornadoNativeArray.ARRAY_HEADER, bufferSize, segment.address(), hostOffset + TornadoNativeArray.ARRAY_HEADER, (useDeps)
returnEvent = deviceContext.readBuffer(executionPlanId, toBuffer() + TornadoNativeArray.ARRAY_HEADER, numBytes, segment.address(), hostOffset + TornadoNativeArray.ARRAY_HEADER, (useDeps)
? events
: null);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@
import uk.ac.manchester.tornado.runtime.analyzer.MetaReduceCodeAnalysis;
import uk.ac.manchester.tornado.runtime.analyzer.ReduceCodeAnalysis;
import uk.ac.manchester.tornado.runtime.analyzer.TaskUtils;
import uk.ac.manchester.tornado.runtime.common.BatchConfiguration;
import uk.ac.manchester.tornado.runtime.common.RuntimeUtilities;
import uk.ac.manchester.tornado.runtime.common.Tornado;
import uk.ac.manchester.tornado.runtime.common.TornadoOptions;
Expand Down Expand Up @@ -161,7 +162,7 @@ public class TornadoTaskGraph implements TornadoTaskGraphInterface {

private TornadoVM vm; // One TornadoVM instance per TornadoExecutionPlan

// HashMap to keep an instance of the TornadoVM per Device
// HashMap to keep an instance of the TornadoVM per Device
private Map<TornadoXPUDevice, TornadoVM> vmTable;
private Event event;
private String taskGraphName;
Expand Down Expand Up @@ -934,7 +935,7 @@ public void transferToDevice(final int mode, Object... objects) {
}

executionContext.getLocalStateObject(parameter).setStreamIn(isObjectForStreaming);

// List of input objects for the dynamic reconfiguration
inputModesObjects.add(new StreamingObject(mode, parameter));

Expand Down Expand Up @@ -979,7 +980,7 @@ public void transferToHost(final int mode, Object... objects) {
// List of output objects for the dynamic reconfiguration
outputModeObjects.add(new StreamingObject(mode, functionParameter));

if (TornadoOptions.isReusedBuffersEnabled()) {
if (TornadoOptions.isReusedBuffersEnabled() || mode == DataTransferMode.UNDER_DEMAND) {
if (!argumentsLookUp.contains(functionParameter)) {
// We already set function parameter in transferToDevice
lockObjectsInMemory(functionParameter);
Expand Down Expand Up @@ -1102,8 +1103,35 @@ private Event syncObjectInner(Object object, long offset, long partialCopySize)
return null;
}

/**
 * Lazily streams out a single batch (sub-region) of {@code object} from the
 * device to the host.
 *
 * @param object     host object whose device buffer is copied back
 * @param hostOffset byte offset into the host buffer where this chunk lands
 * @param bufferSize number of bytes of this chunk
 * @return the resolved transfer {@link Event}, or {@code null} when the object
 *         has no locked buffer on the device (nothing to copy)
 */
private Event syncObjectInnerLazy(Object object, long hostOffset, long bufferSize) {
    final LocalObjectState objectState = executionContext.getLocalStateObject(object);
    final DataObjectState dataState = objectState.getGlobalState();
    final TornadoXPUDevice xpuDevice = meta().getLogicDevice();
    final XPUDeviceBufferState bufferState = dataState.getDeviceState(xpuDevice);
    if (!bufferState.isLockedBuffer()) {
        // No resident device buffer for this object: no transfer is required.
        return null;
    }
    // Restrict the copy to the requested sub-region before streaming out.
    bufferState.getObjectBuffer().setSizeSubRegion(bufferSize);
    final int eventId = xpuDevice.streamOutBlocking(executionPlanId, object, hostOffset, bufferState, null);
    return xpuDevice.resolveEvent(executionPlanId, eventId);
}

private Event syncParameter(Object object) {
Event eventParameter = syncObjectInner(object);
Event eventParameter = null;
if (batchSizeBytes != TornadoExecutionContext.INIT_VALUE) {
BatchConfiguration batchConfiguration = BatchConfiguration.computeChunkSizes(executionContext, batchSizeBytes);
long hostOffset = 0;
for (int i = 0; i < batchConfiguration.getTotalChunks(); i++) {
hostOffset = (batchSizeBytes * i);
eventParameter = syncObjectInnerLazy(object, hostOffset, batchSizeBytes);
}
// Last chunk
if (batchConfiguration.getRemainingChunkSize() != 0) {
hostOffset += batchSizeBytes;
eventParameter = syncObjectInnerLazy(object, hostOffset, batchConfiguration.getRemainingChunkSize());
}
} else {
eventParameter = syncObjectInner(object);
}
if (eventParameter != null) {
eventParameter.waitOn();
}
Expand Down Expand Up @@ -1583,7 +1611,7 @@ private void runParallelTaskGraphs(int numDevices, Thread[] threads, Timer timer
for (StreamingObject streamingObject : outputModeObjects) {
performStreamOutThreads(streamingObject.mode, task, streamingObject.object);
}

ImmutableTaskGraph immutableTaskGraph = task.snapshot();
TornadoExecutionPlan executor = new TornadoExecutionPlan(immutableTaskGraph);

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2013-2020, APT Group, Department of Computer Science,
* Copyright (c) 2013-2020, 2024, APT Group, Department of Computer Science,
* The University of Manchester.
*
* Licensed under the Apache License, Version 2.0 (the "License");
Expand Down Expand Up @@ -29,6 +29,7 @@
import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
import uk.ac.manchester.tornado.api.TaskGraph;
import uk.ac.manchester.tornado.api.TornadoExecutionPlan;
import uk.ac.manchester.tornado.api.TornadoExecutionResult;
import uk.ac.manchester.tornado.api.annotations.Parallel;
import uk.ac.manchester.tornado.api.enums.DataTransferMode;
import uk.ac.manchester.tornado.api.exceptions.TornadoBailoutRuntimeException;
Expand Down Expand Up @@ -194,6 +195,41 @@ public void test100MBSmall() {
executionPlan.freeDeviceMemory();
}

@Test
public void test100MBSmallLazy() {

    long maxAllocMemory = checkMaxHeapAllocationOnDevice(100, MemoryUnit.MB);

    // Target ~120 MB of floats (30M elements * 4 bytes), or shrink to fit in
    // half of the device heap with a 10% safety margin.
    int numElements = 30000000;
    if (numElements * 4 > maxAllocMemory) {
        numElements = (int) ((maxAllocMemory / 4 / 2) * 0.9);
    }
    FloatArray input = new FloatArray(numElements);
    FloatArray output = new FloatArray(numElements);

    IntStream.range(0, input.getSize()).sequential().forEach(i -> input.set(i, 0));

    TaskGraph taskGraph = new TaskGraph("s0") //
            .transferToDevice(DataTransferMode.FIRST_EXECUTION, input) //
            .task("t0", TestBatches::compute, input, output) //
            .transferToHost(DataTransferMode.UNDER_DEMAND, output);

    TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(taskGraph.snapshot());
    // Run in batches of 60 MB; UNDER_DEMAND output is copied back lazily.
    TornadoExecutionResult executionResult = executionPlan.withBatch("60MB") //
            .execute();

    // Explicitly trigger the lazy (on-demand) copy-out of the result buffer.
    executionResult.transferToHost(output);

    for (int i = 0; i < output.getSize(); i++) {
        assertEquals(input.get(i) + 100, output.get(i), 0.1f);
    }

    executionPlan.freeDeviceMemory();
}

@Test
public void test100MB() {

Expand Down Expand Up @@ -227,6 +263,41 @@ public void test100MB() {
executionPlan.freeDeviceMemory();
}

@Test
public void test100MBLazy() {

    long maxAllocMemory = checkMaxHeapAllocationOnDevice(100, MemoryUnit.MB);

    // Target ~800 MB of floats per array (200M elements * 4 bytes), or shrink
    // to fit in half of the device heap with a 10% safety margin.
    int numElements = 200000000;
    if (numElements * 4 > maxAllocMemory) {
        numElements = (int) ((maxAllocMemory / 4 / 2) * 0.9);
    }
    FloatArray input = new FloatArray(numElements);
    FloatArray output = new FloatArray(numElements);

    IntStream.range(0, input.getSize()).sequential().forEach(i -> input.set(i, 0));

    TaskGraph taskGraph = new TaskGraph("s0") //
            .transferToDevice(DataTransferMode.FIRST_EXECUTION, input) //
            .task("t0", TestBatches::compute, input, output) //
            .transferToHost(DataTransferMode.UNDER_DEMAND, output);

    TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(taskGraph.snapshot());
    // Run in batches of 100 MB; UNDER_DEMAND output is copied back lazily.
    TornadoExecutionResult executionResult = executionPlan.withBatch("100MB") //
            .execute();

    // Explicitly trigger the lazy (on-demand) copy-out of the result buffer.
    executionResult.transferToHost(output);

    for (int i = 0; i < output.getSize(); i++) {
        assertEquals(input.get(i) + 100, output.get(i), 0.1f);
    }

    executionPlan.freeDeviceMemory();
}

@Test
public void test300MB() {

Expand Down Expand Up @@ -261,6 +332,42 @@ public void test300MB() {
executionPlan.freeDeviceMemory();
}

@Test
public void test300MBLazy() {

    long maxAllocMemory = checkMaxHeapAllocationOnDevice(300, MemoryUnit.MB);

    // Target ~1.0 GB of floats per array (250M elements * 4 bytes), or shrink
    // to fit in half of the device heap with a 10% safety margin.
    int numElements = 250_000_000;
    if (numElements * 4 > maxAllocMemory) {
        numElements = (int) ((maxAllocMemory / 4 / 2) * 0.9);
    }
    FloatArray input = new FloatArray(numElements);
    FloatArray output = new FloatArray(numElements);

    Random random = new Random();
    IntStream.range(0, input.getSize()).sequential().forEach(i -> input.set(i, random.nextFloat()));

    TaskGraph taskGraph = new TaskGraph("s0") //
            .transferToDevice(DataTransferMode.FIRST_EXECUTION, input) //
            .task("t0", TestBatches::compute, input, output) //
            .transferToHost(DataTransferMode.UNDER_DEMAND, output);

    TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(taskGraph.snapshot());
    // Run in batches of 300 MB; UNDER_DEMAND output is copied back lazily.
    TornadoExecutionResult executionResult = executionPlan.withBatch("300MB") //
            .execute();

    // Explicitly trigger the lazy (on-demand) copy-out of the result buffer.
    executionResult.transferToHost(output);

    for (int i = 0; i < output.getSize(); i++) {
        assertEquals(input.get(i) + 100, output.get(i), 1.0f);
    }

    executionPlan.freeDeviceMemory();
}

@Test
public void test512MB() {

Expand Down Expand Up @@ -293,6 +400,40 @@ public void test512MB() {
executionPlan.freeDeviceMemory();
}

@Test
public void test512MBLazy() {

    long maxAllocMemory = checkMaxHeapAllocationOnDevice(512, MemoryUnit.MB);

    // Target ~800 MB of floats (200M elements * 4 bytes), or shrink to fit in
    // the device heap with a 10% safety margin (single array, updated in place).
    int numElements = 200000000;
    if (numElements * 4 > maxAllocMemory) {
        numElements = (int) ((maxAllocMemory / 4) * 0.9);
    }
    FloatArray data = new FloatArray(numElements);

    IntStream.range(0, data.getSize()).sequential().forEach(i -> data.set(i, i));

    TaskGraph taskGraph = new TaskGraph("s0") //
            .transferToDevice(DataTransferMode.FIRST_EXECUTION, data) //
            .task("t0", TestBatches::compute, data) //
            .transferToHost(DataTransferMode.UNDER_DEMAND, data);

    TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(taskGraph.snapshot());
    // Run in batches of 512 MB; UNDER_DEMAND output is copied back lazily.
    TornadoExecutionResult executionResult = executionPlan.withBatch("512MB") //
            .execute();

    // Explicitly trigger the lazy (on-demand) copy-out of the result buffer.
    executionResult.transferToHost(data);

    for (int i = 0; i < data.getSize(); i++) {
        assertEquals(i, data.get(i), 0.1f);
    }

    executionPlan.freeDeviceMemory();
}

@Test
public void test50MB() {

Expand Down Expand Up @@ -739,6 +880,34 @@ public void testBatchNotEven2() {
executionPlan.freeDeviceMemory();
}

@Test
public void testBatchNotEven2Lazy() {
    checkMaxHeapAllocationOnDevice(64, MemoryUnit.MB);

    // Two arrays of 16M floats (~64 MB each); only the first one is processed,
    // the second keeps the initial values as the reference for the check.
    FloatArray processed = new FloatArray(1024 * 1024 * 16);
    FloatArray reference = new FloatArray(1024 * 1024 * 16);
    processed.init(1.0f);
    reference.init(1.0f);

    TaskGraph taskGraph = new TaskGraph("s0") //
            .transferToDevice(DataTransferMode.EVERY_EXECUTION, processed) //
            .task("t1", TestBatches::compute2, processed) //
            .task("t2", TestBatches::compute2, processed) //
            .transferToHost(DataTransferMode.EVERY_EXECUTION, processed);

    // Batch size does not evenly divide the array size on purpose.
    TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(taskGraph.snapshot());
    TornadoExecutionResult executionResult = executionPlan.withBatch("10MB") //
            .execute();

    executionResult.transferToHost(processed);

    // compute2 is applied twice, so each element is scaled by 4 overall.
    for (int i = 0; i < processed.getSize(); i++) {
        assertEquals(reference.get(i) * 4, processed.get(i), 0.01f);
    }
    executionPlan.freeDeviceMemory();
}

private long checkMaxHeapAllocationOnDevice(int size, MemoryUnit memoryUnit) throws UnsupportedConfigurationException {

long maxAllocMemory = getTornadoRuntime().getDefaultDevice().getDeviceContext().getMemoryManager().getHeapSize();
Expand Down

0 comments on commit 4204e3b

Please sign in to comment.