Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor memory limit checks to take into account primitve type wrappers when using the withMemoryLimit API #352

Merged
merged 2 commits into from
Mar 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,22 @@
import java.util.Arrays;

public enum DataTypeSize {
// @formatter:off
BYTE(byte.class, (byte) 1), //
CHAR(char.class, (byte) 2), //
SHORT(short.class, (byte) 2), //
INT(int.class, (byte) 4), //
FLOAT(float.class, (byte) 4), //
LONG(long.class, (byte) 8), //
DOUBLE(double.class, (byte) 8);
DOUBLE(double.class, (byte) 8), //
BYTE_WRAPPER(Byte.class, (byte) 1),
CHAR_WRAPPER(Character.class, (byte) 2),
SHORT_WRAPPER(Short.class, (byte) 2),
INT_WRAPPER(Integer.class, (byte) 4),
FLOAT_WRAPPER(Float.class, (byte) 4),
LONG_WRAPPER(Long.class, (byte) 8),
DOUBLE_WRAPPER(Double.class, (byte) 8);
// @formatter:on

private final Class<?> dataType;
private final byte size;
Expand All @@ -42,15 +51,15 @@ public enum DataTypeSize {
this.size = size;
}

public static DataTypeSize findDataTypeSize(Class<?> dataType) {
return Arrays.stream(DataTypeSize.values()).filter(size -> size.getDataType().equals(dataType)).findFirst().orElse(null);
}

public Class<?> getDataType() {
return dataType;
}

public byte getSize() {
return size;
}

public static DataTypeSize findDataTypeSize(Class<?> dataType) {
return Arrays.stream(DataTypeSize.values()).filter(size -> size.getDataType().equals(dataType)).findFirst().orElse(null);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -157,8 +157,9 @@ public boolean isMemoryLimited() {

public boolean doesExceedExecutionPlanLimit() {
long totalSize = 0;

for (Object parameter : getObjects()) {

if (parameter.getClass().isArray()) {
Class<?> componentType = parameter.getClass().getComponentType();
DataTypeSize dataTypeSize = DataTypeSize.findDataTypeSize(componentType);
Expand All @@ -185,6 +186,16 @@ public boolean doesExceedExecutionPlanLimit() {
throw new TornadoRuntimeException(STR."Unsupported type: \{parameter.getClass()}");
}
}

if (!constants.isEmpty()) {
for (Object field : constants) {
DataTypeSize dataTypeSize = DataTypeSize.findDataTypeSize(field.getClass());
if (dataTypeSize == null) {
throw new TornadoRuntimeException("[UNSUPPORTED] Data type not supported for processing in batches");
}
totalSize += dataTypeSize.getSize();
}
}
return totalSize > getExecutionPlanMemoryLimit();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,21 @@
*/
package uk.ac.manchester.tornado.unittests.memoryplan;

import static org.junit.Assert.assertEquals;

import org.junit.BeforeClass;
import org.junit.Test;

import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
import uk.ac.manchester.tornado.api.TaskGraph;
import uk.ac.manchester.tornado.api.TornadoExecutionPlan;
import uk.ac.manchester.tornado.api.annotations.Parallel;
import uk.ac.manchester.tornado.api.enums.DataTransferMode;
import uk.ac.manchester.tornado.api.exceptions.TornadoMemoryException;
import uk.ac.manchester.tornado.api.types.arrays.IntArray;
import uk.ac.manchester.tornado.unittests.TestHello;
import uk.ac.manchester.tornado.unittests.common.TornadoTestBase;

import static org.junit.Assert.assertEquals;

/**
* How to test?
*
Expand All @@ -49,6 +51,8 @@ public class TestMemoryLimit extends TornadoTestBase {
private static IntArray b = new IntArray(NUM_ELEMENTS);
private static IntArray c = new IntArray(NUM_ELEMENTS);

private static int value = 10000000;

@BeforeClass
public static void setUpBeforeClass() {
a = new IntArray(NUM_ELEMENTS);
Expand All @@ -58,12 +62,18 @@ public static void setUpBeforeClass() {
b.init(2);
}

public static void add(IntArray a, IntArray b, IntArray c, int value) {
for (@Parallel int i = 0; i < c.getSize(); i++) {
c.set(i, a.get(i) + b.get(i) + value);
}
}

@Test
public void testWithMemoryLimitOver() {

TaskGraph taskGraph = new TaskGraph("s0") //
.transferToDevice(DataTransferMode.FIRST_EXECUTION, a, b) //
.task("t0", TestHello::add, a, b, c) //
.task("t0", TestMemoryLimit::add, a, b, c, value) //
.transferToHost(DataTransferMode.EVERY_EXECUTION, c);

ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot();
Expand All @@ -73,7 +83,7 @@ public void testWithMemoryLimitOver() {
executionPlan.withMemoryLimit("1GB").execute();

for (int i = 0; i < c.getSize(); i++) {
assertEquals(a.get(i) + b.get(i), c.get(i), 0.001);
assertEquals(a.get(i) + b.get(i) + value, c.get(i), 0.001);
}
executionPlan.freeDeviceMemory();
}
Expand Down