diff --git a/.gitignore b/.gitignore index e573fb39d6..fcbf063ec7 100644 --- a/.gitignore +++ b/.gitignore @@ -27,4 +27,4 @@ tornado-examples/target/ tornado-runtime/target/ tornado.iml tornado_unittests.log - +OpenCL-Headers/ diff --git a/tornado-api/src/main/java/uk/ac/manchester/tornado/api/TornadoDeviceContext.java b/tornado-api/src/main/java/uk/ac/manchester/tornado/api/TornadoDeviceContext.java index ba756444f4..b534ba5642 100644 --- a/tornado-api/src/main/java/uk/ac/manchester/tornado/api/TornadoDeviceContext.java +++ b/tornado-api/src/main/java/uk/ac/manchester/tornado/api/TornadoDeviceContext.java @@ -40,7 +40,7 @@ public interface TornadoDeviceContext { boolean isFP64Supported(); - boolean isCached(String methodName, SchedulableTask task); + boolean isCached(long executionPlanId, String methodName, SchedulableTask task); int getDeviceIndex(); diff --git a/tornado-api/src/main/java/uk/ac/manchester/tornado/api/TornadoExecutionPlan.java b/tornado-api/src/main/java/uk/ac/manchester/tornado/api/TornadoExecutionPlan.java index c072bb86ec..37a5c6afe6 100644 --- a/tornado-api/src/main/java/uk/ac/manchester/tornado/api/TornadoExecutionPlan.java +++ b/tornado-api/src/main/java/uk/ac/manchester/tornado/api/TornadoExecutionPlan.java @@ -66,7 +66,7 @@ public class TornadoExecutionPlan implements AutoCloseable { */ public TornadoExecutionPlan(ImmutableTaskGraph... immutableTaskGraphs) { this.tornadoExecutor = new TornadoExecutor(immutableTaskGraphs); - long id = globalExecutionPlanCounter.incrementAndGet(); + final long id = globalExecutionPlanCounter.incrementAndGet(); executionPackage = new ExecutorFrame(id); } diff --git a/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/OCLDeviceContext.java b/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/OCLDeviceContext.java index 84a5566deb..7f8ca316ab 100644 --- a/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/OCLDeviceContext.java +++ b/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/OCLDeviceContext.java @@ -59,22 +59,25 @@ public class OCLDeviceContext implements OCLDeviceContextInterface { /** * Table to represent {@link uk.ac.manchester.tornado.api.TornadoExecutionPlan} -> {@link OCLCommandQueueTable} */ - private Map commandQueueTable; - + private final Map commandQueueTable; private final OCLContext context; private final PowerMetric powerMetric; private final OCLMemoryManager memoryManager; - private final OCLCodeCache codeCache; private final Map oclEventPool; private final TornadoBufferProvider bufferProvider; private boolean wasReset; - private Set executionIDs; + private final Set executionIDs; + + /** + * Map table to represent the compiled-code per execution plan. Each entry in the execution plan has its own + * code cache. The code cache manages the compilation and the cache for each task within an execution plan. + */ + private final Map codeCache; public OCLDeviceContext(OCLTargetDevice device, OCLContext context) { this.device = device; this.context = context; this.memoryManager = new OCLMemoryManager(this); - this.codeCache = new OCLCodeCache(this); this.oclEventPool = new ConcurrentHashMap<>(); this.bufferProvider = new OCLBufferProvider(this); this.commandQueueTable = new ConcurrentHashMap<>(); @@ -85,6 +88,7 @@ public OCLDeviceContext(OCLTargetDevice device, OCLContext context) { } else { this.powerMetric = new OCLEmptyPowerMetric(); } + codeCache = new ConcurrentHashMap<>(); } private boolean isDeviceContextOfNvidia() { @@ -523,7 +527,9 @@ public void reset(long executionPlanId) { executionIDs.remove(executionPlanId); } getMemoryManager().releaseKernelStackFrame(executionPlanId); - codeCache.reset(); + OCLCodeCache oclCodeCache = getOCLCodeCache(executionPlanId); + oclCodeCache.reset(); + codeCache.remove(executionPlanId); wasReset = true; } @@ -588,11 +594,6 @@ public int getDevicePlatform() { return context.getPlatformIndex(); } - public void retainEvent(long executionPlanId, int localEventId) { - OCLEventPool eventPool = getOCLEventPool(executionPlanId); - eventPool.retainEvent(localEventId); - } - @Override public Event resolveEvent(long executionPlanId, int event) { if (event == -1) { @@ -609,57 +610,66 @@ public void flush(long executionPlanId) { commandQueue.flush(); } - public void finish(long executionPlanId) { - OCLCommandQueue commandQueue = getCommandQueue(executionPlanId); - commandQueue.finish(); - } - @Override public void flushEvents(long executionPlanId) { OCLCommandQueue commandQueue = getCommandQueue(executionPlanId); commandQueue.flushEvents(); } + private OCLCodeCache getOCLCodeCache(long executionPlanId) { + if (!codeCache.containsKey(executionPlanId)) { + codeCache.put(executionPlanId, new OCLCodeCache(this)); + } + return codeCache.get(executionPlanId); + } + @Override - public boolean isKernelAvailable() { - return codeCache.isKernelAvailable(); + public boolean isKernelAvailable(long executionPlanId) { + OCLCodeCache oclCodeCache = getOCLCodeCache(executionPlanId); + return oclCodeCache.isKernelAvailable(); } - public OCLInstalledCode installCode(OCLCompilationResult result) { - return installCode(result.getMeta(), result.getId(), result.getName(), result.getTargetCode()); + public OCLInstalledCode installCode(long executionPlanId, OCLCompilationResult result) { + return installCode(executionPlanId, result.getMeta(), result.getId(), result.getName(), result.getTargetCode()); } @Override - public OCLInstalledCode installCode(TaskDataContext meta, String id, String entryPoint, byte[] code) { + public OCLInstalledCode installCode(long executionPlanId, TaskDataContext meta, String id, String entryPoint, byte[] code) { entryPoint = checkKernelName(entryPoint); - return codeCache.installSource(meta, id, entryPoint, code); + OCLCodeCache oclCodeCache = getOCLCodeCache(executionPlanId); + return oclCodeCache.installSource(meta, id, entryPoint, code); } @Override - public OCLInstalledCode installCode(String id, String entryPoint, byte[] code, boolean printKernel) { - return codeCache.installFPGASource(id, entryPoint, code, printKernel); + public OCLInstalledCode installCode(long executionPlanId, String id, String entryPoint, byte[] code, boolean printKernel) { + OCLCodeCache oclCodeCache = getOCLCodeCache(executionPlanId); + return oclCodeCache.installFPGASource(id, entryPoint, code, printKernel); } @Override - public boolean isCached(String id, String entryPoint) { + public boolean isCached(long executionPlanId, String id, String entryPoint) { entryPoint = checkKernelName(entryPoint); - return codeCache.isCached(id + "-" + entryPoint); + OCLCodeCache oclCodeCache = getOCLCodeCache(executionPlanId); + return oclCodeCache.isCached(id + "-" + entryPoint); } @Override - public boolean isCached(String methodName, SchedulableTask task) { + public boolean isCached(long executionPlanId, String methodName, SchedulableTask task) { methodName = checkKernelName(methodName); - return codeCache.isCached(task.getId() + "-" + methodName); + OCLCodeCache oclCodeCache = getOCLCodeCache(executionPlanId); + return oclCodeCache.isCached(task.getId() + "-" + methodName); } - public OCLInstalledCode getInstalledCode(String id, String entryPoint) { + @Override + public OCLInstalledCode getInstalledCode(long executionPlanId, String id, String entryPoint) { entryPoint = checkKernelName(entryPoint); - return codeCache.getInstalledCode(id, entryPoint); + OCLCodeCache oclCodeCache = getOCLCodeCache(executionPlanId); + return oclCodeCache.getInstalledCode(id, entryPoint); } @Override - public OCLCodeCache getCodeCache() { - return this.codeCache; + public OCLCodeCache getCodeCache(long executionPlanId) { + return getOCLCodeCache(executionPlanId); } } diff --git a/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/OCLDeviceContextInterface.java b/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/OCLDeviceContextInterface.java index d2f64f5600..f8e197b936 100644 --- a/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/OCLDeviceContextInterface.java +++ b/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/OCLDeviceContextInterface.java @@ -36,19 +36,19 @@ public interface OCLDeviceContextInterface extends TornadoDeviceContext { OCLTargetDevice getDevice(); - OCLCodeCache getCodeCache(); + OCLCodeCache getCodeCache(long executionPlanId); - boolean isCached(String id, String entryPoint); + boolean isCached(long executionPlanId, String id, String entryPoint); - OCLInstalledCode getInstalledCode(String id, String entryPoint); + OCLInstalledCode getInstalledCode(long executionPlanId, String id, String entryPoint); - OCLInstalledCode installCode(String id, String entryPoint, byte[] code, boolean printKernel); + OCLInstalledCode installCode(long executionPlanId, OCLCompilationResult result); - OCLInstalledCode installCode(OCLCompilationResult result); + OCLInstalledCode installCode(long executionPlanId, TaskDataContext meta, String id, String entryPoint, byte[] code); - OCLInstalledCode installCode(TaskDataContext meta, String id, String entryPoint, byte[] code); + OCLInstalledCode installCode(long executionPlanId, String id, String entryPoint, byte[] code, boolean printKernel); - boolean isKernelAvailable(); + boolean isKernelAvailable(long executionPlanId); void reset(long executionPlanId); diff --git a/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/OCLJIT.java b/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/OCLJIT.java index 2756c67af5..8871369e07 100644 --- a/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/OCLJIT.java +++ b/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/OCLJIT.java @@ -72,7 +72,9 @@ public static void main(String[] args) { OCLCompilationResult result = OCLCompiler.compileCodeForDevice(resolvedMethod, new Object[] {}, meta, (OCLProviders) backend.getProviders(), backend, new EmptyProfiler()); - OCLInstalledCode code = OpenCL.defaultDevice().getDeviceContext().installCode(result); + final long executionPlanId = 0; + + OCLInstalledCode code = OpenCL.defaultDevice().getDeviceContext().installCode(executionPlanId, result); for (byte b : code.getCode()) { System.out.printf("%c", b); diff --git a/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/runtime/OCLTornadoDevice.java b/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/runtime/OCLTornadoDevice.java index b7fda7c7b6..2020d2f491 100644 --- a/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/runtime/OCLTornadoDevice.java +++ b/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/runtime/OCLTornadoDevice.java @@ -238,20 +238,20 @@ public XPUBuffer createOrReuseAtomicsBuffer(int[] array) { return reuseBuffer; } - private boolean isOpenCLPreLoadBinary(OCLDeviceContextInterface deviceContext, String deviceInfo) { - OCLCodeCache installedCode = deviceContext.getCodeCache(); + private boolean isOpenCLPreLoadBinary(long executionPlanId, OCLDeviceContextInterface deviceContext, String deviceInfo) { + OCLCodeCache installedCode = deviceContext.getCodeCache(executionPlanId); return (installedCode.isLoadBinaryOptionEnabled() && (installedCode.getOpenCLBinary(deviceInfo) != null)); } - private TornadoInstalledCode compileTask(SchedulableTask task) { + private TornadoInstalledCode compileTask(long executionPlanId, SchedulableTask task) { final OCLDeviceContextInterface deviceContext = getDeviceContext(); final CompilableTask executable = (CompilableTask) task; final ResolvedJavaMethod resolvedMethod = TornadoCoreRuntime.getTornadoRuntime().resolveMethod(executable.getMethod()); final Sketch sketch = TornadoSketcher.lookup(resolvedMethod, task.meta().getBackendIndex(), task.meta().getDeviceIndex()); // Return the code from the cache - if (!task.shouldCompile() && deviceContext.isCached(task.getId(), resolvedMethod.getName())) { - return deviceContext.getInstalledCode(task.getId(), resolvedMethod.getName()); + if (!task.shouldCompile() && deviceContext.isCached(executionPlanId, task.getId(), resolvedMethod.getName())) { + return deviceContext.getInstalledCode(executionPlanId, task.getId(), resolvedMethod.getName()); } // copy meta data into task @@ -289,10 +289,10 @@ private TornadoInstalledCode compileTask(SchedulableTask task) { OCLInstalledCode installedCode; if (OCLBackend.isDeviceAnFPGAAccelerator(deviceContext)) { // A) for FPGA - installedCode = deviceContext.installCode(result.getId(), result.getName(), result.getTargetCode(), task.meta().isPrintKernelEnabled()); + installedCode = deviceContext.installCode(executionPlanId, result.getId(), result.getName(), result.getTargetCode(), task.meta().isPrintKernelEnabled()); } else { // B) for CPU multi-core or GPU - installedCode = deviceContext.installCode(result); + installedCode = deviceContext.installCode(executionPlanId, result); } profiler.stop(ProfilerType.TASK_COMPILE_DRIVER_TIME, taskMeta.getId()); profiler.sum(ProfilerType.TOTAL_DRIVER_COMPILE_TIME, profiler.getTaskTimer(ProfilerType.TASK_COMPILE_DRIVER_TIME, taskMeta.getId())); @@ -310,11 +310,11 @@ private TornadoInstalledCode compileTask(SchedulableTask task) { } } - private TornadoInstalledCode compilePreBuiltTask(SchedulableTask task) { + private TornadoInstalledCode compilePreBuiltTask(long executionPlanId, SchedulableTask task) { final OCLDeviceContextInterface deviceContext = getDeviceContext(); final PrebuiltTask executable = (PrebuiltTask) task; - if (deviceContext.isCached(task.getId(), executable.getEntryPoint())) { - return deviceContext.getInstalledCode(task.getId(), executable.getEntryPoint()); + if (deviceContext.isCached(executionPlanId, task.getId(), executable.getEntryPoint())) { + return deviceContext.getInstalledCode(executionPlanId, task.getId(), executable.getEntryPoint()); } final Path path = Paths.get(executable.getFilename()); @@ -325,10 +325,10 @@ private TornadoInstalledCode compilePreBuiltTask(SchedulableTask task) { OCLInstalledCode installedCode; if (OCLBackend.isDeviceAnFPGAAccelerator(deviceContext)) { // A) for FPGA - installedCode = deviceContext.installCode(task.getId(), executable.getEntryPoint(), source, task.meta().isPrintKernelEnabled()); + installedCode = deviceContext.installCode(executionPlanId, task.getId(), executable.getEntryPoint(), source, task.meta().isPrintKernelEnabled()); } else { // B) for CPU multi-core or GPU - installedCode = deviceContext.installCode(executable.meta(), task.getId(), executable.getEntryPoint(), source); + installedCode = deviceContext.installCode(executionPlanId, executable.meta(), task.getId(), executable.getEntryPoint(), source); } return installedCode; } catch (IOException e) { @@ -337,11 +337,11 @@ private TornadoInstalledCode compilePreBuiltTask(SchedulableTask task) { return null; } - private TornadoInstalledCode compileJavaToAccelerator(SchedulableTask task) { + private TornadoInstalledCode compileJavaToAccelerator(long executionPlanId, SchedulableTask task) { if (task instanceof CompilableTask) { - return compileTask(task); + return compileTask(executionPlanId, task); } else if (task instanceof PrebuiltTask) { - return compilePreBuiltTask(task); + return compilePreBuiltTask(executionPlanId, task); } TornadoInternalError.shouldNotReachHere("task of unknown type: " + task.getClass().getSimpleName()); return null; @@ -351,15 +351,15 @@ private String getTaskEntryName(SchedulableTask task) { return task.getTaskName(); } - private TornadoInstalledCode loadPreCompiledBinaryForTask(SchedulableTask task) { + private TornadoInstalledCode loadPreCompiledBinaryForTask(long executionPlanId, SchedulableTask task) { final OCLDeviceContextInterface deviceContext = getDeviceContext(); - final OCLCodeCache codeCache = deviceContext.getCodeCache(); + final OCLCodeCache codeCache = deviceContext.getCodeCache(executionPlanId); final String deviceFullName = getFullTaskIdDevice(task); final Path lookupPath = Paths.get(codeCache.getOpenCLBinary(deviceFullName)); String entry = getTaskEntryName(task); - if (deviceContext.getInstalledCode(task.getId(), entry) != null) { - return deviceContext.getInstalledCode(task.getId(), entry); + if (deviceContext.getInstalledCode(executionPlanId, task.getId(), entry) != null) { + return deviceContext.getInstalledCode(executionPlanId, task.getId(), entry); } else { return codeCache.installEntryPointForBinaryForFPGAs(task.getId(), lookupPath, entry); } @@ -376,16 +376,16 @@ private String getFullTaskIdDevice(SchedulableTask task) { } @Override - public boolean isFullJITMode(SchedulableTask task) { + public boolean isFullJITMode(long executionPlanId, SchedulableTask task) { final OCLDeviceContextInterface deviceContext = getDeviceContext(); final String deviceFullName = getFullTaskIdDevice(task); - return (!isOpenCLPreLoadBinary(deviceContext, deviceFullName) && deviceContext.isPlatformFPGA()); + return (!isOpenCLPreLoadBinary(executionPlanId, deviceContext, deviceFullName) && deviceContext.isPlatformFPGA()); } @Override - public TornadoInstalledCode getCodeFromCache(SchedulableTask task) { + public TornadoInstalledCode getCodeFromCache(long executionPlanId, SchedulableTask task) { String entry = getTaskEntryName(task); - return getDeviceContext().getInstalledCode(task.getId(), entry); + return getDeviceContext().getInstalledCode(executionPlanId, task.getId(), entry); } @Override @@ -441,34 +441,34 @@ public boolean checkAtomicsParametersForTask(SchedulableTask task) { return TornadoAtomicIntegerNode.globalAtomicsParameters.containsKey(task.meta().getCompiledResolvedJavaMethod()); } - private boolean isJITTaskForFGPA(SchedulableTask task) { + private boolean isJITTaskForFGPA(long executionPlanId, SchedulableTask task) { final OCLDeviceContextInterface deviceContext = getDeviceContext(); final String deviceFullName = getFullTaskIdDevice(task); - return !isOpenCLPreLoadBinary(deviceContext, deviceFullName) && deviceContext.isPlatformFPGA(); + return !isOpenCLPreLoadBinary(executionPlanId, deviceContext, deviceFullName) && deviceContext.isPlatformFPGA(); } - private boolean isJITTaskForGPUsAndCPUs(SchedulableTask task) { + private boolean isJITTaskForGPUsAndCPUs(long executionplanId, SchedulableTask task) { final OCLDeviceContextInterface deviceContext = getDeviceContext(); final String deviceFullName = getFullTaskIdDevice(task); - return !isOpenCLPreLoadBinary(deviceContext, deviceFullName) && !deviceContext.isPlatformFPGA(); + return !isOpenCLPreLoadBinary(executionplanId, deviceContext, deviceFullName) && !deviceContext.isPlatformFPGA(); } - private TornadoInstalledCode compileJavaForFPGAs(SchedulableTask task) { - TornadoInstalledCode tornadoInstalledCode = compileJavaToAccelerator(task); + private TornadoInstalledCode compileJavaForFPGAs(long executionPlanId, SchedulableTask task) { + TornadoInstalledCode tornadoInstalledCode = compileJavaToAccelerator(executionPlanId, task); if (tornadoInstalledCode != null) { - return loadPreCompiledBinaryForTask(task); + return loadPreCompiledBinaryForTask(executionPlanId, task); } return null; } @Override - public TornadoInstalledCode installCode(SchedulableTask task) { - if (isJITTaskForFGPA(task)) { - return compileJavaForFPGAs(task); - } else if (isJITTaskForGPUsAndCPUs(task)) { - return compileJavaToAccelerator(task); + public TornadoInstalledCode installCode(long executionPlanId, SchedulableTask task) { + if (isJITTaskForFGPA(executionPlanId, task)) { + return compileJavaForFPGAs(executionPlanId, task); + } else if (isJITTaskForGPUsAndCPUs(executionPlanId, task)) { + return compileJavaToAccelerator(executionPlanId, task); } - return loadPreCompiledBinaryForTask(task); + return loadPreCompiledBinaryForTask(executionPlanId, task); } private XPUBuffer createArrayWrapper(Class type, OCLDeviceContext device, long batchSize) { diff --git a/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/tests/TestOpenCLJITCompiler.java b/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/tests/TestOpenCLJITCompiler.java index fe4bd33d2c..4879ef249c 100644 --- a/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/tests/TestOpenCLJITCompiler.java +++ b/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/tests/TestOpenCLJITCompiler.java @@ -75,7 +75,7 @@ public static void main(String[] args) { new TestOpenCLJITCompiler().test(); } - public MetaCompilation compileMethod(Class klass, String methodName, OCLTornadoDevice tornadoDevice, Object... parameters) { + public MetaCompilation compileMethod(long executionPlanId, Class klass, String methodName, OCLTornadoDevice tornadoDevice, Object... parameters) { // Get the method object to be compiled Method methodToCompile = CompilerUtil.getMethodForName(klass, methodName); @@ -108,7 +108,7 @@ public MetaCompilation compileMethod(Class klass, String methodName, OCLTorna OCLCompilationResult compilationResult = OCLCompiler.compileSketchForDevice(sketch, compilableTask, (OCLProviders) providers, openCLBackend, new EmptyProfiler()); // Install the OpenCL Code in the VM - OCLInstalledCode openCLCode = tornadoDevice.getDeviceContext().installCode(compilationResult); + OCLInstalledCode openCLCode = tornadoDevice.getDeviceContext().installCode(executionPlanId, compilationResult); return new MetaCompilation(taskMeta, openCLCode); } @@ -164,15 +164,16 @@ public void test() { Arrays.fill(a, -10); Arrays.fill(b, 10); + long executionPlanId = 0; OCLTornadoDevice tornadoDevice = OpenCL.defaultDevice(); - MetaCompilation compileMethod = compileMethod(TestOpenCLJITCompiler.class, "methodToCompile", tornadoDevice, a, b, c); + MetaCompilation compileMethod = compileMethod(executionPlanId, TestOpenCLJITCompiler.class, "methodToCompile", tornadoDevice, a, b, c); // Check with all internal APIs run(tornadoDevice, (OCLInstalledCode) compileMethod.getInstalledCode(), compileMethod.getTaskMeta(), a, b, c); - long executionPlanId = 0; + // Check with OpenCL API runWithOpenCLAPI(executionPlanId, tornadoDevice, (OCLInstalledCode) compileMethod.getInstalledCode(), compileMethod.getTaskMeta(), a, b, c); diff --git a/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/virtual/VirtualOCLDeviceContext.java b/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/virtual/VirtualOCLDeviceContext.java index 307a22f6af..57b41618bf 100644 --- a/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/virtual/VirtualOCLDeviceContext.java +++ b/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/virtual/VirtualOCLDeviceContext.java @@ -204,42 +204,42 @@ public int getDevicePlatform() { } @Override - public boolean isKernelAvailable() { + public boolean isKernelAvailable(long executionPlanId) { return true; } @Override - public OCLInstalledCode installCode(OCLCompilationResult result) { + public OCLInstalledCode installCode(long executionPlanId, OCLCompilationResult result) { return null; } @Override - public OCLInstalledCode installCode(TaskDataContext meta, String id, String entryPoint, byte[] code) { + public OCLInstalledCode installCode(long executionPlanId, TaskDataContext meta, String id, String entryPoint, byte[] code) { return null; } @Override - public OCLInstalledCode installCode(String id, String entryPoint, byte[] code, boolean printKernel) { + public OCLInstalledCode installCode(long executionPlanId, String id, String entryPoint, byte[] code, boolean printKernel) { return null; } @Override - public boolean isCached(String id, String entryPoint) { + public boolean isCached(long executionPlanId, String id, String entryPoint) { return false; } @Override - public OCLInstalledCode getInstalledCode(String id, String entryPoint) { + public OCLInstalledCode getInstalledCode(long executionPlanId, String id, String entryPoint) { return null; } @Override - public OCLCodeCache getCodeCache() { + public OCLCodeCache getCodeCache(long executionPlanId) { return codeCache; } @Override - public boolean isCached(String methodName, SchedulableTask task) { + public boolean isCached(long executionPlanId, String methodName, SchedulableTask task) { return false; } diff --git a/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/virtual/VirtualOCLTornadoDevice.java b/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/virtual/VirtualOCLTornadoDevice.java index 37566b330c..eec3e5f4b2 100644 --- a/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/virtual/VirtualOCLTornadoDevice.java +++ b/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/virtual/VirtualOCLTornadoDevice.java @@ -244,12 +244,12 @@ private TornadoInstalledCode compileJavaToAccelerator(SchedulableTask task) { } @Override - public boolean isFullJITMode(SchedulableTask task) { + public boolean isFullJITMode(long executionPlanId, SchedulableTask task) { return true; } @Override - public TornadoInstalledCode getCodeFromCache(SchedulableTask task) { + public TornadoInstalledCode getCodeFromCache(long executionPlanId, SchedulableTask task) { return null; } @@ -279,7 +279,7 @@ public boolean checkAtomicsParametersForTask(SchedulableTask task) { } @Override - public TornadoInstalledCode installCode(SchedulableTask task) { + public TornadoInstalledCode installCode(long executionPlanId, SchedulableTask task) { return compileJavaToAccelerator(task); } diff --git a/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/PTXCodeCache.java b/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/PTXCodeCache.java index 87c17af5d4..117a4ae1eb 100644 --- a/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/PTXCodeCache.java +++ b/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/PTXCodeCache.java @@ -34,7 +34,7 @@ public class PTXCodeCache { private final PTXDeviceContext deviceContext; private final ConcurrentHashMap cache; - public PTXCodeCache(PTXDeviceContext deviceContext) { + PTXCodeCache(PTXDeviceContext deviceContext) { this.deviceContext = deviceContext; cache = new ConcurrentHashMap<>(); } @@ -60,15 +60,15 @@ public PTXInstalledCode installSource(String name, byte[] targetCode, String res return cache.get(name); } - public PTXInstalledCode getCachedCode(String name) { + PTXInstalledCode getCachedCode(String name) { return cache.get(name); } - public boolean isCached(String name) { + boolean isCached(String name) { return cache.containsKey(name); } - public void reset() { + void reset() { for (PTXInstalledCode code : cache.values()) { code.invalidate(); } diff --git a/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/PTXDeviceContext.java b/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/PTXDeviceContext.java index 432c610d22..98d020388d 100644 --- a/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/PTXDeviceContext.java +++ b/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/PTXDeviceContext.java @@ -62,20 +62,25 @@ public class PTXDeviceContext implements TornadoDeviceContext { private final PTXDevice device; private final PTXMemoryManager memoryManager; - private final PTXCodeCache codeCache; private final PTXScheduler scheduler; private final TornadoBufferProvider bufferProvider; private final PowerMetric powerMetric; private final Map streamTable; private boolean wasReset; - private Set executionIDs; + private final Set executionIDs; + + /** + * Map table to represent the compiled-code per execution plan. Each entry in the execution plan has its own + * code cache. The code cache manages the compilation and the cache for each task within an execution plan. + */ + private final Map codeCache; public PTXDeviceContext(PTXDevice device) { this.device = device; streamTable = new ConcurrentHashMap<>(); this.scheduler = new PTXScheduler(device); this.powerMetric = new PTXNvidiaPowerMetric(this); - codeCache = new PTXCodeCache(this); + codeCache = new ConcurrentHashMap<>(); memoryManager = new PTXMemoryManager(this); bufferProvider = new PTXBufferProvider(this); wasReset = false; @@ -120,20 +125,23 @@ public PTXTornadoDevice toDevice() { return new PTXTornadoDevice(device.getDeviceIndex()); } - public TornadoInstalledCode installCode(PTXCompilationResult result, String resolvedMethodName) { - return codeCache.installSource(result.getName(), result.getTargetCode(), resolvedMethodName, result.metaData().isPrintKernelEnabled()); + public TornadoInstalledCode installCode(long executionPlanId, PTXCompilationResult result, String resolvedMethodName) { + PTXCodeCache ptxCodeCache = getPTXCodeCache(executionPlanId); + return ptxCodeCache.installSource(result.getName(), result.getTargetCode(), resolvedMethodName, result.metaData().isPrintKernelEnabled()); } - public TornadoInstalledCode installCode(String name, byte[] code, String resolvedMethodName, boolean printKernel) { - return codeCache.installSource(name, code, resolvedMethodName, printKernel); + public TornadoInstalledCode installCode(long executionPlanId, String name, byte[] code, String resolvedMethodName, boolean printKernel) { + PTXCodeCache ptxCodeCache = getPTXCodeCache(executionPlanId); + return ptxCodeCache.installSource(name, code, resolvedMethodName, printKernel); } - public TornadoInstalledCode getInstalledCode(String name) { - return codeCache.getCachedCode(name); + public TornadoInstalledCode getInstalledCode(long executionPlanId, String name) { + PTXCodeCache ptxCodeCache = getPTXCodeCache(executionPlanId); + return ptxCodeCache.getCachedCode(name); } - public PTXCodeCache getCodeCache() { - return codeCache; + public PTXCodeCache getCodeCache(long executionPlanId) { + return getPTXCodeCache(executionPlanId); } public PTXDevice getDevice() { @@ -237,12 +245,12 @@ public void syncIfNeeded(long executionPlanId) { } public void flush(long executionPlanId) { - // I don't think there is anything like this in CUDA so I am calling sync + // I don't think there is anything like this in CUDA, so I am calling sync sync(executionPlanId); } @Override - public void reset(long executionPlanId) { + public synchronized void reset(long executionPlanId) { PTXStreamTable table = streamTable.get(executionPlanId); if (table != null) { table.cleanup(device); @@ -252,7 +260,8 @@ public void reset(long executionPlanId) { executionIDs.remove(executionPlanId); } getMemoryManager().releaseKernelStackFrame(executionPlanId); - codeCache.reset(); + PTXCodeCache ptxCodeCache = getPTXCodeCache(executionPlanId); + ptxCodeCache.reset(); wasReset = true; } @@ -359,8 +368,9 @@ private void updateProfiler(long executionPlanId, final int taskEvent, final Tas } @Override - public boolean isCached(String methodName, SchedulableTask task) { - return codeCache.isCached(buildKernelName(methodName, task)); + public boolean isCached(long executionPlanId, String methodName, SchedulableTask task) { + PTXCodeCache ptxCodeCache = getPTXCodeCache(executionPlanId); + return ptxCodeCache.isCached(buildKernelName(methodName, task)); } public void destroyStream(long executionPlanId) { @@ -570,6 +580,13 @@ private PTXStream getStream(long executionPlanId) { return streamTable.get(executionPlanId).get(device); } + private PTXCodeCache getPTXCodeCache(long executionPlanId) { + if (!codeCache.containsKey(executionPlanId)) { + codeCache.put(executionPlanId, new PTXCodeCache(this)); + } + return codeCache.get(executionPlanId); + } + private PTXStream getStreamIfNeeded(long executionPlanId) { if (!streamTable.containsKey(executionPlanId)) { return null; diff --git a/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/runtime/PTXTornadoDevice.java b/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/runtime/PTXTornadoDevice.java index 6d97f73447..33eba38527 100644 --- a/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/runtime/PTXTornadoDevice.java +++ b/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/runtime/PTXTornadoDevice.java @@ -153,15 +153,15 @@ public boolean checkAtomicsParametersForTask(SchedulableTask task) { } @Override - public TornadoInstalledCode installCode(SchedulableTask task) { + public TornadoInstalledCode installCode(long executionPlanId, SchedulableTask task) { return switch (task) { - case CompilableTask _ -> compileTask(task); - case PrebuiltTask _ -> compilePreBuiltTask(task); + case CompilableTask _ -> compileTask(executionPlanId, task); + case PrebuiltTask _ -> compilePreBuiltTask(executionPlanId, task); default -> throw new TornadoInternalError("task of unknown type: " + task.getClass().getSimpleName()); }; } - private TornadoInstalledCode compileTask(SchedulableTask task) { + private TornadoInstalledCode compileTask(long executionPlanId, SchedulableTask task) { TornadoProfiler profiler = task.getProfiler(); final PTXDeviceContext deviceContext = getDeviceContext(); @@ -177,7 +177,7 @@ private TornadoInstalledCode compileTask(SchedulableTask task) { try { PTXCompilationResult result; - if (!deviceContext.isCached(resolvedMethod.getName(), executable)) { + if (!deviceContext.isCached(executionPlanId, resolvedMethod.getName(), executable)) { PTXProviders providers = (PTXProviders) getBackend().getProviders(); profiler.start(ProfilerType.TASK_COMPILE_GRAAL_TIME, taskMeta.getId()); result = PTXCompiler.compileSketchForDevice(sketch, executable, providers, getBackend(), executable.getProfiler()); @@ -188,7 +188,7 @@ private TornadoInstalledCode compileTask(SchedulableTask task) { } profiler.start(ProfilerType.TASK_COMPILE_DRIVER_TIME, taskMeta.getId()); - TornadoInstalledCode installedCode = deviceContext.installCode(result, resolvedMethod.getName()); + TornadoInstalledCode installedCode = deviceContext.installCode(executionPlanId, result, resolvedMethod.getName()); profiler.stop(ProfilerType.TASK_COMPILE_DRIVER_TIME, taskMeta.getId()); profiler.sum(ProfilerType.TOTAL_DRIVER_COMPILE_TIME, profiler.getTaskTimer(ProfilerType.TASK_COMPILE_DRIVER_TIME, taskMeta.getId())); return installedCode; @@ -202,12 +202,12 @@ private TornadoInstalledCode compileTask(SchedulableTask task) { } } - private TornadoInstalledCode compilePreBuiltTask(SchedulableTask task) { + private TornadoInstalledCode compilePreBuiltTask(long executionPlanId, SchedulableTask task) { final PTXDeviceContext deviceContext = getDeviceContext(); final PrebuiltTask executable = (PrebuiltTask) task; String functionName = buildKernelName(executable.getEntryPoint(), executable); - if (deviceContext.isCached(executable.getEntryPoint(), executable)) { - return deviceContext.getInstalledCode(functionName); + if (deviceContext.isCached(executionPlanId, executable.getEntryPoint(), executable)) { + return deviceContext.getInstalledCode(executionPlanId, functionName); } final Path path = Paths.get(executable.getFilename()); @@ -215,7 +215,7 @@ private TornadoInstalledCode compilePreBuiltTask(SchedulableTask task) { try { byte[] source = Files.readAllBytes(path); source = PTXCodeUtil.getCodeWithAttachedPTXHeader(source, getBackend()); - return deviceContext.installCode(functionName, source, executable.getEntryPoint(), task.meta().isPrintKernelEnabled()); + return deviceContext.installCode(executionPlanId, functionName, source, executable.getEntryPoint(), task.meta().isPrintKernelEnabled()); } catch (IOException e) { e.printStackTrace(); } @@ -223,12 +223,12 @@ private TornadoInstalledCode compilePreBuiltTask(SchedulableTask task) { } @Override - public boolean isFullJITMode(SchedulableTask task) { + public boolean isFullJITMode(long executionPlanId, SchedulableTask task) { return true; } @Override - public TornadoInstalledCode getCodeFromCache(SchedulableTask task) { + public TornadoInstalledCode getCodeFromCache(long executionPlanId, SchedulableTask task) { String methodName; if (task instanceof PrebuiltTask) { PrebuiltTask prebuiltTask = (PrebuiltTask) task; @@ -239,7 +239,7 @@ public TornadoInstalledCode getCodeFromCache(SchedulableTask task) { methodName = resolvedMethod.getName(); } String functionName = buildKernelName(methodName, task); - return getDeviceContext().getInstalledCode(functionName); + return getDeviceContext().getInstalledCode(executionPlanId, functionName); } private XPUBuffer createDeviceBuffer(Class type, Object object, long batchSize) { diff --git a/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/tests/TestPTXJITCompiler.java b/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/tests/TestPTXJITCompiler.java index 82620b42a1..69a08214e1 100644 --- a/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/tests/TestPTXJITCompiler.java +++ b/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/tests/TestPTXJITCompiler.java @@ -74,7 +74,7 @@ public static void main(String[] args) { new TestPTXJITCompiler().test(); } - public MetaCompilation compileMethod(Class klass, String methodName, PTXTornadoDevice tornadoDevice, Object... parameters) { + public MetaCompilation compileMethod(long executionPlanId, Class klass, String methodName, PTXTornadoDevice tornadoDevice, Object... parameters) { // Get the method object to be compiled Method methodToCompile = CompilerUtil.getMethodForName(klass, methodName); @@ -107,7 +107,7 @@ public MetaCompilation compileMethod(Class klass, String methodName, PTXTorna PTXCompilationResult compilationResult = PTXCompiler.compileSketchForDevice(sketch, compilableTask, (PTXProviders) providers, ptxBackend, new EmptyProfiler()); // Install the PTX Code in the VM - TornadoInstalledCode ptxCode = tornadoDevice.getDeviceContext().installCode(compilationResult, resolvedJavaMethod.getName()); + TornadoInstalledCode ptxCode = tornadoDevice.getDeviceContext().installCode(executionPlanId, compilationResult, resolvedJavaMethod.getName()); return new MetaCompilation(taskMeta, (PTXInstalledCode) ptxCode); } @@ -161,10 +161,11 @@ public void test() { Arrays.fill(a, -10); Arrays.fill(b, 10); + final long executionPlanId = 0; PTXTornadoDevice tornadoDevice = PTX.defaultDevice(); - MetaCompilation compileMethod = compileMethod(TestPTXJITCompiler.class, "methodToCompile", tornadoDevice, a, b, c); + MetaCompilation compileMethod = compileMethod(executionPlanId, TestPTXJITCompiler.class, "methodToCompile", tornadoDevice, a, b, c); // Check with all internal APIs run(tornadoDevice, (PTXInstalledCode) compileMethod.getInstalledCode(), compileMethod.getTaskMeta(), a, b, c); diff --git a/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/tests/TestPTXTornadoCompiler.java b/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/tests/TestPTXTornadoCompiler.java index 007fa9e7f7..bd5f083c7c 100644 --- a/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/tests/TestPTXTornadoCompiler.java +++ b/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/tests/TestPTXTornadoCompiler.java @@ -77,8 +77,9 @@ public class TestPTXTornadoCompiler { public static void main(String[] args) { + final long executionPlanId = 0; PTXPlatform platform = PTX.getPlatform(); - PTXCodeCache codeCache = platform.getDevice(0).getPTXContext().getDeviceContext().getCodeCache(); + PTXCodeCache codeCache = platform.getDevice(0).getPTXContext().getDeviceContext().getCodeCache(executionPlanId); TornadoCoreRuntime tornadoRuntime = TornadoCoreRuntime.getTornadoRuntime(); PTXBackend backend = tornadoRuntime.getBackend(PTXBackendImpl.class).getDefaultBackend(); diff --git a/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/SPIRVDeviceContext.java b/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/SPIRVDeviceContext.java index c410d353da..0dc8f3c676 100644 --- a/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/SPIRVDeviceContext.java +++ b/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/SPIRVDeviceContext.java @@ -65,12 +65,16 @@ public abstract class SPIRVDeviceContext implements TornadoDeviceContext { protected SPIRVContext spirvContext; protected SPIRVTornadoDevice tornadoDevice; protected SPIRVMemoryManager memoryManager; - protected SPIRVCodeCache codeCache; protected boolean wasReset; protected Map spirvEventPool; private TornadoBufferProvider bufferProvider; + private final Set executionIds; - private Set executionIds; + /** + * Map table to represent the compiled-code per execution plan. Each entry in the execution plan has its own + * code cache. The code cache manages the compilation and the cache for each task within an execution plan. + */ + protected Map codeCache; protected SPIRVDeviceContext(SPIRVDevice device, SPIRVContext context) { init(device); @@ -82,11 +86,7 @@ private void init(SPIRVDevice device) { this.device = device; this.tornadoDevice = new SPIRVTornadoDevice(device); this.memoryManager = new SPIRVMemoryManager(this); - if (this instanceof SPIRVLevelZeroDeviceContext) { - this.codeCache = new SPIRVLevelZeroCodeCache(this); - } else { - this.codeCache = new SPIRVOCLCodeCache(this); - } + this.codeCache = new ConcurrentHashMap<>(); this.wasReset = false; this.spirvEventPool = new ConcurrentHashMap<>(); this.bufferProvider = new SPIRVBufferProvider(this); @@ -159,10 +159,25 @@ public void reset(long executionPlanId) { spirvContext.reset(executionPlanId, getDeviceIndex()); spirvEventPool.remove(executionPlanId); getMemoryManager().releaseKernelStackFrame(executionPlanId); - codeCache.reset(); + + SPIRVCodeCache spirvCodeCache = getSPIRVCodeCache(executionPlanId); + spirvCodeCache.reset(); wasReset = true; } + private SPIRVCodeCache getSPIRVCodeCache(long executionPlanId) { + if (!codeCache.containsKey(executionPlanId)) { + SPIRVCodeCache spirvCodeCache; + if (this instanceof SPIRVLevelZeroDeviceContext) { + spirvCodeCache = new SPIRVLevelZeroCodeCache(this); + } else { + spirvCodeCache = new SPIRVOCLCodeCache(this); + } + codeCache.put(executionPlanId, spirvCodeCache); + } + return codeCache.get(executionPlanId); + } + public int readBuffer(long executionPlanId, long bufferId, long offset, long bytes, byte[] value, long hostOffset, int[] waitEvents) { ProfilerTransfer profilerTransfer = createStartAndStopBufferTimers(); executionIds.add(executionPlanId); @@ -382,29 +397,34 @@ public void flush(long executionPlanId, int deviceIndex) { spirvContext.flush(executionPlanId, deviceIndex); } - public TornadoInstalledCode installBinary(SPIRVCompilationResult result) { - return installBinary(result.getMeta(), result.getId(), result.getName(), result.getSPIRVBinary()); + public TornadoInstalledCode installBinary(long executionPlanId, SPIRVCompilationResult result) { + return installBinary(executionPlanId, result.getMeta(), result.getId(), result.getName(), result.getSPIRVBinary()); } - public SPIRVInstalledCode installBinary(TaskDataContext meta, String id, String entryPoint, byte[] code) { - return codeCache.installSPIRVBinary(meta, id, entryPoint, code); + public SPIRVInstalledCode installBinary(long executionPlanId, TaskDataContext meta, String id, String entryPoint, byte[] code) { + SPIRVCodeCache spirvCodeCache = getSPIRVCodeCache(executionPlanId); + return spirvCodeCache.installSPIRVBinary(meta, id, entryPoint, code); } - public SPIRVInstalledCode installBinary(TaskDataContext meta, String id, String entryPoint, String pathToFile) { - return codeCache.installSPIRVBinary(meta, id, entryPoint, pathToFile); + public SPIRVInstalledCode installBinary(long executionPlanId, TaskDataContext meta, String id, String entryPoint, String pathToFile) { + SPIRVCodeCache spirvCodeCache = getSPIRVCodeCache(executionPlanId); + return spirvCodeCache.installSPIRVBinary(meta, id, entryPoint, pathToFile); } - public boolean isCached(String id, String entryPoint) { - return codeCache.isCached(id + "-" + entryPoint); + public boolean isCached(long executionPlanId, String id, String entryPoint) { + SPIRVCodeCache spirvCodeCache = getSPIRVCodeCache(executionPlanId); + return spirvCodeCache.isCached(id + "-" + entryPoint); } @Override - public boolean isCached(String methodName, SchedulableTask task) { - return codeCache.isCached(task.getId() + "-" + methodName); + public boolean isCached(long executionPlanId, String methodName, SchedulableTask task) { + SPIRVCodeCache spirvCodeCache = getSPIRVCodeCache(executionPlanId); + return spirvCodeCache.isCached(task.getId() + "-" + methodName); } - public SPIRVInstalledCode getInstalledCode(String id, String entryPoint) { - return codeCache.getInstalledCode(id, entryPoint); + public SPIRVInstalledCode getInstalledCode(long executionPlanId, String id, String entryPoint) { + SPIRVCodeCache spirvCodeCache = getSPIRVCodeCache(executionPlanId); + return spirvCodeCache.getInstalledCode(id, entryPoint); } public int enqueueMarker(long executionPlanId) { @@ -437,7 +457,7 @@ public Event resolveEvent(long executionPlanId, int eventId) { OCLEventPool eventPool = context.getOCLEventPool(executionPlanId); return new OCLEvent(eventPool.getDescriptor(eventId).getNameDescription(), commandQueue, eventId, eventPool.getOCLEvent(eventId)); } else { - throw new RuntimeException("Not implemented yet"); + throw new TornadoRuntimeException("[Error] SPIR-V Device Context Class not implemented yet."); } } diff --git a/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/runtime/SPIRVTornadoDevice.java b/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/runtime/SPIRVTornadoDevice.java index 454fd0af8d..25384bcfcf 100644 --- a/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/runtime/SPIRVTornadoDevice.java +++ b/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/runtime/SPIRVTornadoDevice.java @@ -122,31 +122,31 @@ public XPUBuffer createOrReuseAtomicsBuffer(int[] arr) { } @Override - public TornadoInstalledCode installCode(SchedulableTask task) { + public TornadoInstalledCode installCode(long executionPlanId, SchedulableTask task) { if (task instanceof CompilableTask) { - return compileTask((CompilableTask) task); + return compileTask(executionPlanId, (CompilableTask) task); } else if (task instanceof PrebuiltTask) { - return compilePreBuiltTask((PrebuiltTask) task); + return compilePreBuiltTask(executionPlanId, (PrebuiltTask) task); } else { throw new RuntimeException("SchedulableTask type is not supported: " + task.getClass()); } } - private TornadoInstalledCode compilePreBuiltTask(PrebuiltTask task) { + private TornadoInstalledCode compilePreBuiltTask(long executionPlanId, PrebuiltTask task) { final SPIRVDeviceContext deviceContext = getDeviceContext(); - if (deviceContext.isCached(task.getId(), task.getEntryPoint())) { - return deviceContext.getInstalledCode(task.getId(), task.getEntryPoint()); + if (deviceContext.isCached(executionPlanId, task.getId(), task.getEntryPoint())) { + return deviceContext.getInstalledCode(executionPlanId, task.getId(), task.getEntryPoint()); } final Path pathToSPIRVBin = Paths.get(task.getFilename()); TornadoInternalError.guarantee(pathToSPIRVBin.toFile().exists(), "files does not exists %s", task.getFilename()); - return deviceContext.installBinary(task.meta(), task.getId(), task.getEntryPoint(), task.getFilename()); + return deviceContext.installBinary(executionPlanId, task.meta(), task.getId(), task.getEntryPoint(), task.getFilename()); } public SPIRVBackend getBackend() { return findDriver().getBackendOfDevice(device); } - private TornadoInstalledCode compileTask(CompilableTask task) { + private TornadoInstalledCode compileTask(long executionPlanId, CompilableTask task) { TornadoProfiler profiler = task.getProfiler(); final SPIRVDeviceContext deviceContext = getDeviceContext(); @@ -157,8 +157,8 @@ private TornadoInstalledCode compileTask(CompilableTask task) { final TaskDataContext taskMeta = task.meta(); // Return the code from the cache - if (!task.shouldCompile() && deviceContext.isCached(task.getId(), resolvedMethod.getName())) { - return deviceContext.getInstalledCode(task.getId(), resolvedMethod.getName()); + if (!task.shouldCompile() && deviceContext.isCached(executionPlanId, task.getId(), resolvedMethod.getName())) { + return deviceContext.getInstalledCode(executionPlanId, task.getId(), resolvedMethod.getName()); } final Access[] sketchAccess = sketch.getArgumentsAccess(); @@ -176,7 +176,7 @@ private TornadoInstalledCode compileTask(CompilableTask task) { profiler.sum(ProfilerType.TOTAL_GRAAL_COMPILE_TIME, profiler.getTaskTimer(ProfilerType.TASK_COMPILE_GRAAL_TIME, taskMeta.getId())); profiler.start(ProfilerType.TASK_COMPILE_DRIVER_TIME, taskMeta.getId()); - TornadoInstalledCode installedCode = deviceContext.installBinary(result); + TornadoInstalledCode installedCode = deviceContext.installBinary(executionPlanId, result); profiler.stop(ProfilerType.TASK_COMPILE_DRIVER_TIME, taskMeta.getId()); profiler.sum(ProfilerType.TOTAL_DRIVER_COMPILE_TIME, profiler.getTaskTimer(ProfilerType.TASK_COMPILE_DRIVER_TIME, taskMeta.getId())); return installedCode; @@ -193,12 +193,12 @@ private TornadoInstalledCode compileTask(CompilableTask task) { } @Override - public boolean isFullJITMode(SchedulableTask task) { + public boolean isFullJITMode(long executionPlanId, SchedulableTask task) { return false; } @Override - public TornadoInstalledCode getCodeFromCache(SchedulableTask task) { + public TornadoInstalledCode getCodeFromCache(long executionPlanId, SchedulableTask task) { return null; } diff --git a/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/tests/TestSPIRVJITCompiler.java b/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/tests/TestSPIRVJITCompiler.java index a514991f6c..fdbc5c88bb 100644 --- a/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/tests/TestSPIRVJITCompiler.java +++ b/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/tests/TestSPIRVJITCompiler.java @@ -76,7 +76,7 @@ public static void main(String[] args) { new TestSPIRVJITCompiler().test(); } - public MetaCompilation compileMethod(Class klass, String methodName, Object... parameters) { + public MetaCompilation compileMethod(long executionPlanId, Class klass, String methodName, Object... parameters) { // Get the method object to be compiled Method methodToCompile = CompilerUtil.getMethodForName(klass, methodName); @@ -111,7 +111,7 @@ public MetaCompilation compileMethod(Class klass, String methodName, Object.. // 3. Install the SPIR-V code into the VM SPIRVDevice spirvDevice = (SPIRVDevice) device.getDeviceContext().getDevice(); - SPIRVInstalledCode spirvInstalledCode = (SPIRVInstalledCode) spirvDevice.getDeviceContext().installBinary(spirvCompilationResult); + SPIRVInstalledCode spirvInstalledCode = (SPIRVInstalledCode) spirvDevice.getDeviceContext().installBinary(executionPlanId, spirvCompilationResult); return new MetaCompilation(taskMeta, spirvInstalledCode); } @@ -159,12 +159,13 @@ public void test() { int[] a = new int[N]; int[] b = new int[N]; float[] c = new float[N]; + final long executionPlanId = 0; Arrays.fill(a, -10); Arrays.fill(b, 10); // Obtain the SPIR-V binary from the Java method - MetaCompilation compileMethod = compileMethod(TestSPIRVJITCompiler.class, "methodToCompile", a, b, c); + MetaCompilation compileMethod = compileMethod(executionPlanId, TestSPIRVJITCompiler.class, "methodToCompile", a, b, c); TornadoDevice device = TornadoCoreRuntime.getTornadoRuntime().getBackend(SPIRVBackendImpl.class).getDefaultDevice(); diff --git a/tornado-runtime/src/main/java/uk/ac/manchester/tornado/runtime/JVMMapping.java b/tornado-runtime/src/main/java/uk/ac/manchester/tornado/runtime/JVMMapping.java index da2fc2b4b7..3e7d948364 100644 --- a/tornado-runtime/src/main/java/uk/ac/manchester/tornado/runtime/JVMMapping.java +++ b/tornado-runtime/src/main/java/uk/ac/manchester/tornado/runtime/JVMMapping.java @@ -131,7 +131,7 @@ public XPUBuffer createOrReuseAtomicsBuffer(int[] arr) { } @Override - public TornadoInstalledCode installCode(SchedulableTask task) { + public TornadoInstalledCode installCode(long executionPlanId, SchedulableTask task) { return null; } @@ -201,12 +201,12 @@ public TornadoDeviceType getDeviceType() { } @Override - public boolean isFullJITMode(SchedulableTask task) { + public boolean isFullJITMode(long executionPlanId, SchedulableTask task) { return false; } @Override - public TornadoInstalledCode getCodeFromCache(SchedulableTask task) { + public TornadoInstalledCode getCodeFromCache(long executionPlanId, SchedulableTask task) { return null; } diff --git a/tornado-runtime/src/main/java/uk/ac/manchester/tornado/runtime/common/TornadoXPUDevice.java b/tornado-runtime/src/main/java/uk/ac/manchester/tornado/runtime/common/TornadoXPUDevice.java index e8514797cc..ef030bd76b 100644 --- a/tornado-runtime/src/main/java/uk/ac/manchester/tornado/runtime/common/TornadoXPUDevice.java +++ b/tornado-runtime/src/main/java/uk/ac/manchester/tornado/runtime/common/TornadoXPUDevice.java @@ -65,32 +65,38 @@ public interface TornadoXPUDevice extends TornadoDevice { /** * It installs the Tornado code for the specified schedulable task. * + * @param executionPlanId + * ID number for the execution plan that the task belongs to. * @param task * The {@link SchedulableTask} to install the code for. * @return The {@link TornadoInstalledCode} indicating the installation status. */ - TornadoInstalledCode installCode(SchedulableTask task); + TornadoInstalledCode installCode(long executionPlanId, SchedulableTask task); /** * It checks if the specified schedulable task is in full Just-In-Time (JIT) * mode. * + * @param executionPlanId + * ID for the execution plan to be compiled. * @param task * The {@link SchedulableTask} to check for full JIT mode. * @return True if the task is in full JIT mode, false otherwise. */ - boolean isFullJITMode(SchedulableTask task); + boolean isFullJITMode(long executionPlanId, SchedulableTask task); /** * It retrieves the Tornado installed code from the cache for the specified * schedulable task. * + * @param executionPlanId + * ID for the execution plan to be obtained from the code cache * @param task * The {@link SchedulableTask} to get the installed code from the * cache. * @return The {@link TornadoInstalledCode} from the cache. */ - TornadoInstalledCode getCodeFromCache(SchedulableTask task); + TornadoInstalledCode getCodeFromCache(long executionPlanId, SchedulableTask task); /** * It checks for atomic operations in the specified schedulable task and returns diff --git a/tornado-runtime/src/main/java/uk/ac/manchester/tornado/runtime/interpreter/TornadoVMInterpreter.java b/tornado-runtime/src/main/java/uk/ac/manchester/tornado/runtime/interpreter/TornadoVMInterpreter.java index 6e44a14fad..3fd24fce84 100644 --- a/tornado-runtime/src/main/java/uk/ac/manchester/tornado/runtime/interpreter/TornadoVMInterpreter.java +++ b/tornado-runtime/src/main/java/uk/ac/manchester/tornado/runtime/interpreter/TornadoVMInterpreter.java @@ -661,7 +661,7 @@ private XPUExecutionFrame compileTaskFromBytecodeToBinary(final int callWrapperI task.forceCompilation(); } - installedCodes[globalToLocalTaskIndex(taskIndex)] = interpreterDevice.installCode(task); + installedCodes[globalToLocalTaskIndex(taskIndex)] = interpreterDevice.installCode(graphExecutionContext.getExecutionPlanId(), task); profilerUpdateForPreCompiledTask(task); // After the compilation has been completed, increment // the batch number of the task and update it. @@ -697,7 +697,7 @@ private int executeLaunch(StringBuilder tornadoVMBytecodeList, final int numArgs if (installedCodes[globalToLocalTaskIndex(taskIndex)] == null) { // After warming-up, it is possible to get a null pointer in the task-cache due // to lazy compilation for FPGAs. In tha case, we check again the code cache. - installedCodes[globalToLocalTaskIndex(taskIndex)] = interpreterDevice.getCodeFromCache(task); + installedCodes[globalToLocalTaskIndex(taskIndex)] = interpreterDevice.getCodeFromCache(graphExecutionContext.getExecutionPlanId(), task); } final TornadoInstalledCode installedCode = installedCodes[globalToLocalTaskIndex(taskIndex)]; diff --git a/tornado-runtime/src/main/java/uk/ac/manchester/tornado/runtime/tasks/TornadoTaskGraph.java b/tornado-runtime/src/main/java/uk/ac/manchester/tornado/runtime/tasks/TornadoTaskGraph.java index e1db368b49..89165f43c6 100644 --- a/tornado-runtime/src/main/java/uk/ac/manchester/tornado/runtime/tasks/TornadoTaskGraph.java +++ b/tornado-runtime/src/main/java/uk/ac/manchester/tornado/runtime/tasks/TornadoTaskGraph.java @@ -541,7 +541,7 @@ public void setDevice(TornadoDevice device) { task.meta().setDevice(device); if (task instanceof CompilableTask compilableTask) { ResolvedJavaMethod method = TornadoCoreRuntime.getTornadoRuntime().resolveMethod(compilableTask.getMethod()); - if (!meta().getXPUDevice().getDeviceContext().isCached(method.getName(), compilableTask)) { + if (!meta().getXPUDevice().getDeviceContext().isCached(executionPlanId, method.getName(), compilableTask)) { updateInner(i, executionContext.getTask(i)); } } @@ -584,7 +584,7 @@ public void setDevice(String taskName, TornadoDevice device) { task.meta().setDevice(device); if (task instanceof CompilableTask) { ResolvedJavaMethod method = TornadoCoreRuntime.getTornadoRuntime().resolveMethod(((CompilableTask) task).getMethod()); - if (!task.getDevice().getDeviceContext().isCached(method.getName(), task)) { + if (!task.getDevice().getDeviceContext().isCached(executionPlanId, method.getName(), task)) { updateInner(i, task); } } @@ -817,7 +817,7 @@ private void preCompileForFPGAs() { if (TornadoOptions.FPGA_EMULATION) { compile = true; } else if (executionContext.getDeviceOfFirstTask() instanceof TornadoXPUDevice tornadoAcceleratorDevice) { - if (tornadoAcceleratorDevice.isFullJITMode(executionContext.getTask(0))) { + if (tornadoAcceleratorDevice.isFullJITMode(executionPlanId, executionContext.getTask(0))) { compile = true; } }