Skip to content

Commit

Permalink
Merge pull request #352 from mikepapadim/mikepapadim/field_checks
Browse files Browse the repository at this point in the history
Refactor memory limit checks to take into account primitve type wrappers  when using the `withMemoryLimit` API
  • Loading branch information
mikepapadim authored Mar 19, 2024
2 parents 26691bb + 8c4b8f9 commit 11b8c81
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,22 @@
import java.util.Arrays;

public enum DataTypeSize {
// @formatter:off
BYTE(byte.class, (byte) 1), //
CHAR(char.class, (byte) 2), //
SHORT(short.class, (byte) 2), //
INT(int.class, (byte) 4), //
FLOAT(float.class, (byte) 4), //
LONG(long.class, (byte) 8), //
DOUBLE(double.class, (byte) 8);
DOUBLE(double.class, (byte) 8), //
BYTE_WRAPPER(Byte.class, (byte) 1),
CHAR_WRAPPER(Character.class, (byte) 2),
SHORT_WRAPPER(Short.class, (byte) 2),
INT_WRAPPER(Integer.class, (byte) 4),
FLOAT_WRAPPER(Float.class, (byte) 4),
LONG_WRAPPER(Long.class, (byte) 8),
DOUBLE_WRAPPER(Double.class, (byte) 8);
// @formatter:on

private final Class<?> dataType;
private final byte size;
Expand All @@ -42,15 +51,15 @@ public enum DataTypeSize {
this.size = size;
}

public static DataTypeSize findDataTypeSize(Class<?> dataType) {
return Arrays.stream(DataTypeSize.values()).filter(size -> size.getDataType().equals(dataType)).findFirst().orElse(null);
}

public Class<?> getDataType() {
return dataType;
}

public byte getSize() {
return size;
}

public static DataTypeSize findDataTypeSize(Class<?> dataType) {
return Arrays.stream(DataTypeSize.values()).filter(size -> size.getDataType().equals(dataType)).findFirst().orElse(null);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -157,8 +157,9 @@ public boolean isMemoryLimited() {

public boolean doesExceedExecutionPlanLimit() {
long totalSize = 0;

for (Object parameter : getObjects()) {

if (parameter.getClass().isArray()) {
Class<?> componentType = parameter.getClass().getComponentType();
DataTypeSize dataTypeSize = DataTypeSize.findDataTypeSize(componentType);
Expand All @@ -185,6 +186,16 @@ public boolean doesExceedExecutionPlanLimit() {
throw new TornadoRuntimeException(STR."Unsupported type: \{parameter.getClass()}");
}
}

if (!constants.isEmpty()) {
for (Object field : constants) {
DataTypeSize dataTypeSize = DataTypeSize.findDataTypeSize(field.getClass());
if (dataTypeSize == null) {
throw new TornadoRuntimeException("[UNSUPPORTED] Data type not supported for processing in batches");
}
totalSize += dataTypeSize.getSize();
}
}
return totalSize > getExecutionPlanMemoryLimit();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,21 @@
*/
package uk.ac.manchester.tornado.unittests.memoryplan;

import static org.junit.Assert.assertEquals;

import org.junit.BeforeClass;
import org.junit.Test;

import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
import uk.ac.manchester.tornado.api.TaskGraph;
import uk.ac.manchester.tornado.api.TornadoExecutionPlan;
import uk.ac.manchester.tornado.api.annotations.Parallel;
import uk.ac.manchester.tornado.api.enums.DataTransferMode;
import uk.ac.manchester.tornado.api.exceptions.TornadoMemoryException;
import uk.ac.manchester.tornado.api.types.arrays.IntArray;
import uk.ac.manchester.tornado.unittests.TestHello;
import uk.ac.manchester.tornado.unittests.common.TornadoTestBase;

import static org.junit.Assert.assertEquals;

/**
* How to test?
*
Expand All @@ -49,6 +51,8 @@ public class TestMemoryLimit extends TornadoTestBase {
private static IntArray b = new IntArray(NUM_ELEMENTS);
private static IntArray c = new IntArray(NUM_ELEMENTS);

private static int value = 10000000;

@BeforeClass
public static void setUpBeforeClass() {
a = new IntArray(NUM_ELEMENTS);
Expand All @@ -58,12 +62,18 @@ public static void setUpBeforeClass() {
b.init(2);
}

public static void add(IntArray a, IntArray b, IntArray c, int value) {
for (@Parallel int i = 0; i < c.getSize(); i++) {
c.set(i, a.get(i) + b.get(i) + value);
}
}

@Test
public void testWithMemoryLimitOver() {

TaskGraph taskGraph = new TaskGraph("s0") //
.transferToDevice(DataTransferMode.FIRST_EXECUTION, a, b) //
.task("t0", TestHello::add, a, b, c) //
.task("t0", TestMemoryLimit::add, a, b, c, value) //
.transferToHost(DataTransferMode.EVERY_EXECUTION, c);

ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot();
Expand All @@ -73,7 +83,7 @@ public void testWithMemoryLimitOver() {
executionPlan.withMemoryLimit("1GB").execute();

for (int i = 0; i < c.getSize(); i++) {
assertEquals(a.get(i) + b.get(i), c.get(i), 0.001);
assertEquals(a.get(i) + b.get(i) + value, c.get(i), 0.001);
}
executionPlan.freeDeviceMemory();
}
Expand Down

0 comments on commit 11b8c81

Please sign in to comment.