Skip to content

Commit

Permalink
[refactor] Replaced KernelContext methods for retrieving work group s…
Browse files Browse the repository at this point in the history
…izes with fields
  • Loading branch information
stratika committed Sep 8, 2021
1 parent 0a15618 commit e1ebd66
Show file tree
Hide file tree
Showing 6 changed files with 89 additions and 60 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@
import uk.ac.manchester.tornado.drivers.opencl.graal.nodes.GroupIdNode;
import uk.ac.manchester.tornado.drivers.opencl.graal.nodes.LocalArrayNode;
import uk.ac.manchester.tornado.drivers.opencl.graal.nodes.LocalThreadIdNode;
import uk.ac.manchester.tornado.drivers.opencl.graal.nodes.LocalThreadSizeNode;
import uk.ac.manchester.tornado.drivers.opencl.graal.nodes.calc.DivNode;
import uk.ac.manchester.tornado.drivers.opencl.graal.nodes.vector.VectorLoadNode;
import uk.ac.manchester.tornado.drivers.opencl.graal.nodes.vector.VectorStoreNode;
Expand All @@ -111,6 +112,8 @@
import uk.ac.manchester.tornado.drivers.opencl.graal.snippets.ReduceGPUSnippets;
import uk.ac.manchester.tornado.runtime.TornadoVMConfig;
import uk.ac.manchester.tornado.runtime.graal.nodes.GetGroupIdFixedWithNextNode;
import uk.ac.manchester.tornado.runtime.graal.nodes.GlobalGroupSizeFixedWithNextNode;
import uk.ac.manchester.tornado.runtime.graal.nodes.LocalGroupSizeFixedWithNextNode;
import uk.ac.manchester.tornado.runtime.graal.nodes.NewArrayNonVirtualizableNode;
import uk.ac.manchester.tornado.runtime.graal.nodes.StoreAtomicIndexedNode;
import uk.ac.manchester.tornado.runtime.graal.nodes.ThreadIdFixedWithNextNode;
Expand Down Expand Up @@ -194,6 +197,10 @@ public void lower(Node node, LoweringTool tool) {
lowerLocalThreadIdNode((ThreadLocalIdFixedWithNextNode) node);
} else if (node instanceof GetGroupIdFixedWithNextNode) {
lowerGetGroupIdNode((GetGroupIdFixedWithNextNode) node);
} else if (node instanceof GlobalGroupSizeFixedWithNextNode) {
lowerGlobalGroupSizeNode((GlobalGroupSizeFixedWithNextNode) node);
} else if (node instanceof LocalGroupSizeFixedWithNextNode) {
lowerLocalGroupSizeNode((LocalGroupSizeFixedWithNextNode) node);
} else {
super.lower(node, tool);
}
Expand Down Expand Up @@ -310,6 +317,18 @@ private void lowerGetGroupIdNode(GetGroupIdFixedWithNextNode getGroupIdNode) {
graph.replaceFixedWithFloating(getGroupIdNode, groupIdNode);
}

private void lowerGlobalGroupSizeNode(GlobalGroupSizeFixedWithNextNode globalGroupSizeNode) {
StructuredGraph graph = globalGroupSizeNode.graph();
GlobalThreadSizeNode globalThreadSizeNode = graph.addOrUnique(new GlobalThreadSizeNode(ConstantNode.forInt(globalGroupSizeNode.getDimension(), graph)));
graph.replaceFixedWithFloating(globalGroupSizeNode, globalThreadSizeNode);
}

private void lowerLocalGroupSizeNode(LocalGroupSizeFixedWithNextNode localGroupSizeNode) {
StructuredGraph graph = localGroupSizeNode.graph();
LocalThreadSizeNode localThreadSizeNode = graph.addOrUnique(new LocalThreadSizeNode(ConstantNode.forInt(localGroupSizeNode.getDimension(), graph)));
graph.replaceFixedWithFloating(localGroupSizeNode, localThreadSizeNode);
}

@Override
protected void lowerArrayLengthNode(ArrayLengthNode arrayLengthNode, LoweringTool tool) {
StructuredGraph graph = arrayLengthNode.graph();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -224,29 +224,6 @@ public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Rec
});
}

private static void registerLocalWorkGroup(Registration r, JavaKind returnedJavaKind) {
r.register2("getLocalGroupSize", Receiver.class, int.class, new InvocationPlugin() {
@Override
public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Receiver receiver, ValueNode size) {
LocalThreadSizeNode localThreadSizeNode = new LocalThreadSizeNode((ConstantNode) size);
b.push(returnedJavaKind, localThreadSizeNode);
return true;
}
});
}

private static void registerGlobalWorkGroupSize(Registration r) {
JavaKind returnedJavaKind = JavaKind.Int;
r.register2("getGlobalGroupSize", Receiver.class, int.class, new InvocationPlugin() {
@Override
public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Receiver receiver, ValueNode size) {
GlobalThreadSizeNode threadSize = new GlobalThreadSizeNode((ConstantNode) size);
b.push(returnedJavaKind, threadSize);
return true;
}
});
}

private static void registerIntLocalArray(Registration r, JavaKind returnedJavaKind, JavaKind elementType) {
r.register2("allocateIntLocalArray", Receiver.class, int.class, new InvocationPlugin() {
@Override
Expand Down Expand Up @@ -295,11 +272,6 @@ public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Rec
});
}

private static void localWorkGroupPlugin(Registration r) {
JavaKind returnedJavaKind = JavaKind.Int;
registerLocalWorkGroup(r, returnedJavaKind);
}

private static void localArraysPlugins(Registration r) {
JavaKind returnedJavaKind = JavaKind.Object;

Expand All @@ -321,8 +293,6 @@ private static void registerKernelContextPlugins(InvocationPlugins plugins) {

registerLocalBarrier(r);
registerGlobalBarrier(r);
localWorkGroupPlugin(r);
registerGlobalWorkGroupSize(r);
localArraysPlugins(r);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,15 +87,19 @@
import uk.ac.manchester.tornado.drivers.ptx.graal.nodes.CastNode;
import uk.ac.manchester.tornado.drivers.ptx.graal.nodes.FixedArrayNode;
import uk.ac.manchester.tornado.drivers.ptx.graal.nodes.GlobalThreadIdNode;
import uk.ac.manchester.tornado.drivers.ptx.graal.nodes.GlobalThreadSizeNode;
import uk.ac.manchester.tornado.drivers.ptx.graal.nodes.GroupIdNode;
import uk.ac.manchester.tornado.drivers.ptx.graal.nodes.LocalArrayNode;
import uk.ac.manchester.tornado.drivers.ptx.graal.nodes.LocalThreadIdNode;
import uk.ac.manchester.tornado.drivers.ptx.graal.nodes.LocalThreadSizeNode;
import uk.ac.manchester.tornado.drivers.ptx.graal.nodes.calc.DivNode;
import uk.ac.manchester.tornado.drivers.ptx.graal.nodes.vector.LoadIndexedVectorNode;
import uk.ac.manchester.tornado.drivers.ptx.graal.phases.TornadoFloatingReadReplacement;
import uk.ac.manchester.tornado.drivers.ptx.graal.snippets.PTXGPUReduceSnippets;
import uk.ac.manchester.tornado.runtime.TornadoVMConfig;
import uk.ac.manchester.tornado.runtime.graal.nodes.GetGroupIdFixedWithNextNode;
import uk.ac.manchester.tornado.runtime.graal.nodes.GlobalGroupSizeFixedWithNextNode;
import uk.ac.manchester.tornado.runtime.graal.nodes.LocalGroupSizeFixedWithNextNode;
import uk.ac.manchester.tornado.runtime.graal.nodes.NewArrayNonVirtualizableNode;
import uk.ac.manchester.tornado.runtime.graal.nodes.StoreAtomicIndexedNode;
import uk.ac.manchester.tornado.runtime.graal.nodes.ThreadIdFixedWithNextNode;
Expand Down Expand Up @@ -224,6 +228,10 @@ public void lower(Node node, LoweringTool tool) {
lowerLocalThreadIdNode((ThreadLocalIdFixedWithNextNode) node);
} else if (node instanceof GetGroupIdFixedWithNextNode) {
lowerGetGroupIdNode((GetGroupIdFixedWithNextNode) node);
} else if (node instanceof GlobalGroupSizeFixedWithNextNode) {
lowerGlobalGroupSizeNode((GlobalGroupSizeFixedWithNextNode) node);
} else if (node instanceof LocalGroupSizeFixedWithNextNode) {
lowerLocalGroupSizeNode((LocalGroupSizeFixedWithNextNode) node);
} else {
super.lower(node, tool);
}
Expand Down Expand Up @@ -369,6 +377,18 @@ private void lowerGetGroupIdNode(GetGroupIdFixedWithNextNode getGroupIdNode) {
graph.replaceFixedWithFloating(getGroupIdNode, groupIdNode);
}

private void lowerGlobalGroupSizeNode(GlobalGroupSizeFixedWithNextNode globalGroupSizeNode) {
StructuredGraph graph = globalGroupSizeNode.graph();
GlobalThreadSizeNode globalThreadSizeNode = graph.addOrUnique(new GlobalThreadSizeNode(ConstantNode.forInt(globalGroupSizeNode.getDimension(), graph)));
graph.replaceFixedWithFloating(globalGroupSizeNode, globalThreadSizeNode);
}

private void lowerLocalGroupSizeNode(LocalGroupSizeFixedWithNextNode localGroupSizeNode) {
StructuredGraph graph = localGroupSizeNode.graph();
LocalThreadSizeNode localThreadSizeNode = graph.addOrUnique(new LocalThreadSizeNode(ConstantNode.forInt(localGroupSizeNode.getDimension(), graph)));
graph.replaceFixedWithFloating(localGroupSizeNode, localThreadSizeNode);
}

@Override
protected void lowerArrayLengthNode(ArrayLengthNode arrayLengthNode, LoweringTool tool) {
StructuredGraph graph = arrayLengthNode.graph();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -194,29 +194,6 @@ public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Rec
});
}

private static void registerLocalWorkGroup(Registration r, JavaKind returnedJavaKind) {
r.register2("getLocalGroupSize", InvocationPlugin.Receiver.class, int.class, new InvocationPlugin() {
@Override
public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Receiver receiver, ValueNode size) {
LocalThreadSizeNode localThreadSizeNode = new LocalThreadSizeNode((ConstantNode) size);
b.push(returnedJavaKind, localThreadSizeNode);
return true;
}
});
}

private static void registerGlobalWorkGroupSize(Registration r) {
JavaKind returnedJavaKind = JavaKind.Int;
r.register2("getGlobalGroupSize", InvocationPlugin.Receiver.class, int.class, new InvocationPlugin() {
@Override
public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Receiver receiver, ValueNode size) {
GlobalThreadSizeNode threadSize = new GlobalThreadSizeNode((ConstantNode) size);
b.push(returnedJavaKind, threadSize);
return true;
}
});
}

private static void registerIntLocalArray(Registration r, JavaKind returnedJavaKind, JavaKind elementType) {
r.register2("allocateIntLocalArray", InvocationPlugin.Receiver.class, int.class, new InvocationPlugin() {
@Override
Expand Down Expand Up @@ -265,11 +242,6 @@ public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Rec
});
}

private static void localWorkGroupPlugin(Registration r) {
JavaKind returnedJavaKind = JavaKind.Int;
registerLocalWorkGroup(r, returnedJavaKind);
}

private static void localArraysPlugins(Registration r) {
JavaKind returnedJavaKind = JavaKind.Object;

Expand All @@ -291,8 +263,6 @@ private static void registerKernelContextPlugins(InvocationPlugins plugins) {

registerLocalBarrier(r);
registerGlobalBarrier(r);
localWorkGroupPlugin(r);
registerGlobalWorkGroupSize(r);
localArraysPlugins(r);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@
import uk.ac.manchester.tornado.api.KernelContext;
import uk.ac.manchester.tornado.api.exceptions.TornadoRuntimeException;
import uk.ac.manchester.tornado.runtime.graal.nodes.GetGroupIdFixedWithNextNode;
import uk.ac.manchester.tornado.runtime.graal.nodes.GlobalGroupSizeFixedWithNextNode;
import uk.ac.manchester.tornado.runtime.graal.nodes.LocalGroupSizeFixedWithNextNode;
import uk.ac.manchester.tornado.runtime.graal.nodes.ThreadIdFixedWithNextNode;
import uk.ac.manchester.tornado.runtime.graal.nodes.ThreadLocalIdFixedWithNextNode;

Expand Down Expand Up @@ -114,6 +116,32 @@ private void introduceKernelContext(StructuredGraph graph) {
}

replaceKernelContextNode(graph, nodesToBeRemoved, node, groupIdNode);
} else if (field.contains("KernelContext.globalGroupSize")) {
GlobalGroupSizeFixedWithNextNode globalGroupSizeNode;
if (field.contains("globalGroupSizeX")) {
globalGroupSizeNode = new GlobalGroupSizeFixedWithNextNode(node.getValue(), 0);
} else if (field.contains("globalGroupSizeY")) {
globalGroupSizeNode = new GlobalGroupSizeFixedWithNextNode(node.getValue(), 1);
} else if (field.contains("globalGroupSizeZ")) {
globalGroupSizeNode = new GlobalGroupSizeFixedWithNextNode(node.getValue(), 2);
} else {
throw new TornadoRuntimeException("Unrecognized dimension");
}

replaceKernelContextNode(graph, nodesToBeRemoved, node, globalGroupSizeNode);
} else if (field.contains("KernelContext.localGroupSize")) {
LocalGroupSizeFixedWithNextNode localGroupSizeNode;
if (field.contains("localGroupSizeX")) {
localGroupSizeNode = new LocalGroupSizeFixedWithNextNode(node.getValue(), 0);
} else if (field.contains("localGroupSizeY")) {
localGroupSizeNode = new LocalGroupSizeFixedWithNextNode(node.getValue(), 1);
} else if (field.contains("localGroupSizeZ")) {
localGroupSizeNode = new LocalGroupSizeFixedWithNextNode(node.getValue(), 2);
} else {
throw new TornadoRuntimeException("Unrecognized dimension");
}

replaceKernelContextNode(graph, nodesToBeRemoved, node, localGroupSizeNode);
} else {
return;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,28 @@ public class KernelContext implements ExecutionContext {
public final Integer localIdy = 0;
public final Integer localIdz = 0;

/**
* It returns the global group size of a particular dimension (e.g. X, Y, Z).
* <p>
* OpenCL equivalent: get_global_size();
* <p>
* PTX equivalent: gridDim * blockDim
*/
public final Integer globalGroupSizeX = 0;
public final Integer globalGroupSizeY = 0;
public final Integer globalGroupSizeZ = 0;

/**
* It returns the global group size of a particular dimension (e.g. X, Y, Z).
* <p>
* OpenCL equivalent: get_local_size();
* <p>
* PTX equivalent: blockDim
*/
public final Integer localGroupSizeX = 0;
public final Integer localGroupSizeY = 0;
public final Integer localGroupSizeZ = 0;

/**
* Class constructor specifying a particular {@link WorkerGrid} object.
*/
Expand Down

0 comments on commit e1ebd66

Please sign in to comment.