Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[fix] Data Access for PrebuiltTaskGraph fixed #541

Merged
merged 1 commit into from
Aug 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
*/
package uk.ac.manchester.tornado.api.common;

import java.util.stream.IntStream;

import uk.ac.manchester.tornado.api.AccessorParameters;

public class PrebuiltTaskPackage extends TaskPackage {
Expand All @@ -32,13 +34,11 @@ public class PrebuiltTaskPackage extends TaskPackage {
this.entryPoint = entryPoint;
this.filename = fileName;
this.args = new Object[accessorParameters.numAccessors()];
for (int i = 0; i < accessorParameters.numAccessors(); i++) {
this.args[i] = accessorParameters.getAccessor(i).object();
}
this.accesses = new Access[accessorParameters.numAccessors()];
for (int i = 0; i < accessorParameters.numAccessors(); i++) {
IntStream.range(0, accessorParameters.numAccessors()).forEach(i -> {
this.args[i] = accessorParameters.getAccessor(i).object();
this.accesses[i] = accessorParameters.getAccessor(i).access();
}
});
}

public PrebuiltTaskPackage withAtomics(int[] atomics) {
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,15 @@ private static DomainTree buildDomainTree(int[] dims) {

}

/**
* Marshal object from {@link PrebuiltTaskPackage} to {@link PrebuiltTask}.
*
* @param meta
* {@link ScheduleContext}
* @param taskPackage
* {@link PrebuiltTaskPackage}
* @return {@link PrebuiltTask}
*/
public static PrebuiltTask createTask(ScheduleContext meta, PrebuiltTaskPackage taskPackage) {
PrebuiltTask prebuiltTask = new PrebuiltTask(meta, //
taskPackage.getId(), //
Expand All @@ -319,7 +328,7 @@ public static PrebuiltTask createTask(ScheduleContext meta, PrebuiltTaskPackage
taskPackage.getArgs(), //
taskPackage.getAccesses());
if (taskPackage.getAtomics() != null) {
prebuiltTask.withAtomics(taskPackage.getAtomics());
prebuiltTask.setAtomics(taskPackage.getAtomics());
}
return prebuiltTask;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
import uk.ac.manchester.tornado.api.common.TornadoDevice;
import uk.ac.manchester.tornado.api.profiler.TornadoProfiler;
import uk.ac.manchester.tornado.runtime.common.TornadoXPUDevice;
import uk.ac.manchester.tornado.runtime.domain.DomainTree;
import uk.ac.manchester.tornado.runtime.tasks.meta.ScheduleContext;
import uk.ac.manchester.tornado.runtime.tasks.meta.TaskDataContext;

Expand All @@ -50,41 +49,17 @@ public class PrebuiltTask implements SchedulableTask {
private boolean forceCompiler;
private int[] atomics;

public PrebuiltTask(ScheduleContext scheduleMeta, String id, String entryPoint, String filename, Object[] args, Access[] access, TornadoDevice device, DomainTree domain) {
this.entryPoint = entryPoint;
this.filename = filename;
this.args = args;
this.argumentsAccess = access;
meta = new TaskDataContext(scheduleMeta, id, access.length);
for (int i = 0; i < access.length; i++) {
meta.getArgumentsAccess()[i] = access[i];
}
meta.setDevice(device);
meta.setDomain(domain);

final long[] values = new long[domain.getDepth()];
for (int i = 0; i < domain.getDepth(); i++) {
values[i] = domain.get(i).cardinality();
}
meta.setGlobalWork(values);

}

public PrebuiltTask(ScheduleContext scheduleMeta, String id, String entryPoint, String filename, Object[] args, Access[] access) {
this.entryPoint = entryPoint;
this.filename = filename;
this.args = args;
this.argumentsAccess = access;
meta = new TaskDataContext(scheduleMeta, id, access.length);
for (int i = 0; i < access.length; i++) {
meta.getArgumentsAccess()[i] = access[i];
}

meta.setArgumentsAccess(access);
}

public PrebuiltTask withAtomics(int[] atomics) {
public void setAtomics(int[] atomics) {
this.atomics = atomics;
return this;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
package uk.ac.manchester.tornado.runtime.tasks.meta;

import static uk.ac.manchester.tornado.api.exceptions.TornadoInternalError.guarantee;
import static uk.ac.manchester.tornado.runtime.common.TornadoOptions.EVENT_WINDOW;

import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
Expand Down Expand Up @@ -142,12 +141,8 @@ public long[] initLocalWork() {
return localWork;
}

public void addProfile(int id) {
final TornadoXPUDevice device = getXPUDevice();
BitSet events;
profiles.computeIfAbsent(device, k -> new BitSet(EVENT_WINDOW));
events = profiles.get(device);
events.set(id);
public void setArgumentsAccess(Access[] access) {
this.argumentsAccess = access;
}

public Access[] getArgumentsAccess() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,10 +115,74 @@ public void testPrebuilt01() throws TornadoExecutionPlanException {
.withDevice(defaultDevice) //
.execute();
}
for (int j = 0; j < c.getSize(); j++) {
assertEquals(a.get(j) + b.get(j), c.get(j));
}

}

@Test
public void testPrebuilt01Multi() throws TornadoExecutionPlanException {

final int numElements = 8;
IntArray a = new IntArray(numElements);
IntArray b = new IntArray(numElements);
IntArray c = new IntArray(numElements);

a.init(1);
b.init(2);

switch (backendType) {
case PTX:
FILE_PATH += "add.ptx";
break;
case OPENCL:
FILE_PATH += "add.cl";
break;
case SPIRV:
FILE_PATH += "add.spv";
break;
default:
throw new TornadoRuntimeException("Backend not supported");
}

// Define accessors for each parameter
AccessorParameters accessorParameters = new AccessorParameters(3);
accessorParameters.set(0, a, Access.READ_WRITE);
accessorParameters.set(1, b, Access.READ_WRITE);
accessorParameters.set(2, c, Access.WRITE_ONLY);

// Define the Task-Graph
TaskGraph taskGraph = new TaskGraph("s0") //
.transferToDevice(DataTransferMode.EVERY_EXECUTION, a, b) //
.prebuiltTask("t0", //task name
"add", // name of the low-level kernel to invoke
FILE_PATH, // file name
accessorParameters) // accessors
.transferToHost(DataTransferMode.EVERY_EXECUTION, c);

for (int i = 0; i < c.getSize(); i++) {
assertEquals(a.get(i) + b.get(i), c.get(i));
// When using the prebuilt API, we need to define the WorkerGrid, otherwise it will launch 1 thread
// on the target device
WorkerGrid workerGrid = new WorkerGrid1D(numElements);
GridScheduler gridScheduler = new GridScheduler("s0.t0", workerGrid);

// Launch the application on the target device
try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(taskGraph.snapshot())) {

executionPlan.withGridScheduler(gridScheduler) //
.withDevice(defaultDevice) //
.execute();

// Run task multiple times
for (int i = 0; i < 10; i++) {
executionPlan.execute();
for (int j = 0; j < c.getSize(); j++) {
assertEquals(a.get(j) + b.get(j), c.get(j));
}
IntStream.range(0, numElements).forEach(k -> a.set(k, c.get(k)));
}
}

}

@Test
Expand Down