diff --git a/bin/compile b/bin/compile index 7f7f9c3aa0..56695d0f1e 100755 --- a/bin/compile +++ b/bin/compile @@ -191,7 +191,7 @@ def build_spirv_toolkit_and_level_zero(rebuild=False): if (rebuild or build): os.chdir(spirv_tool_kit) - subprocess.run(["git", "pull", "origin", "master"]) + subprocess.run(["git", "pull", "origin", "feat/half"]) subprocess.run(["mvn", "clean", "package"]) subprocess.run(["mvn", "install"]) os.chdir(current) diff --git a/tornado-api/src/main/java/module-info.java b/tornado-api/src/main/java/module-info.java index 420554a6fc..da0e5cb170 100644 --- a/tornado-api/src/main/java/module-info.java +++ b/tornado-api/src/main/java/module-info.java @@ -45,4 +45,5 @@ opens uk.ac.manchester.tornado.api.types.volumes; exports uk.ac.manchester.tornado.api.types.vectors; opens uk.ac.manchester.tornado.api.types.vectors; + exports uk.ac.manchester.tornado.api.types; } diff --git a/tornado-api/src/main/java/uk/ac/manchester/tornado/api/types/HalfFloat.java b/tornado-api/src/main/java/uk/ac/manchester/tornado/api/types/HalfFloat.java new file mode 100644 index 0000000000..79ed55a7b6 --- /dev/null +++ b/tornado-api/src/main/java/uk/ac/manchester/tornado/api/types/HalfFloat.java @@ -0,0 +1,190 @@ +/* + * Copyright (c) 2024, APT Group, Department of Computer Science, + * The University of Manchester. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ +package uk.ac.manchester.tornado.api.types; + +/** + * This class represents a float-16 instance (half float). The data is stored in a short field, to be + * compliant with the representation for float-16 used in the {@link Float} class. The class encapsulates + * methods for getting the data in float-16 and float-32 format, and for basic arithmetic operations (i.e. + * addition, subtraction, multiplication and division). + */ +public class HalfFloat { + + private short halfFloatValue; + + /** + * Constructs a new instance of the {@code HalfFloat} out of a float value. + * To convert the float to a float-16, the floatToFloat16 function of the {@link Float} + * class is used. + * + * @param halfFloat + * The float value that will be stored in a half-float format. + */ + public HalfFloat(float halfFloat) { + this.halfFloatValue = Float.floatToFloat16(halfFloat); + } + + /** + * Constructs a new instance of the {@code HalfFloat} with a given short value. + * + * @param halfFloat + * The short value that represents the half float. + */ + public HalfFloat(short halfFloat) { + this.halfFloatValue = halfFloat; + } + + /** + * Gets the half-float stored in the class. + * + * @return The half float value stored in the {@code HalfFloat} object. + */ + public short getHalfFloatValue() { + return this.halfFloatValue; + } + + /** + * Gets the half-float stored in the class in a 32-bit representation. + * + * @return The float-32 equivalent value the half float stored in the {@code HalfFloat} object. + */ + public float getFloat32() { + return Float.float16ToFloat(halfFloatValue); + } + + /** + * Takes two half float values, converts them to a 32-bit representation and performs an addition. + * + * @param a + * The first float-16 input for the addition. + * @param b + * The second float-16 input for the addition. + * @return The result of the addition. + */ + private static float addHalfFloat(short a, short b) { + float floatA = Float.float16ToFloat(a); + float floatB = Float.float16ToFloat(b); + return floatA + floatB; + } + + /** + * Takes two {@code HalfFloat} objects and returns a new {@HalfFloat} instance + * that contains the results of the addition. + * + * @param a + * The first {@code HalfFloat} input for the addition. + * @param b + * The second {@code HalfFloat} input for the addition. + * @return A new {@HalfFloat} containing the results of the addition. + */ + public static HalfFloat add(HalfFloat a, HalfFloat b) { + float result = addHalfFloat(a.getHalfFloatValue(), b.getHalfFloatValue()); + return new HalfFloat(result); + } + + /** + * Takes two half float values, converts them to a 32-bit representation and performs a subtraction. + * + * @param a + * The first float-16 input for the subtraction. + * @param b + * The second float-16 input for the subtraction. + * @return The result of the subtraction. + */ + private static float subHalfFloat(short a, short b) { + float floatA = Float.float16ToFloat(a); + float floatB = Float.float16ToFloat(b); + return floatA - floatB; + } + + /** + * Takes two {@code HalfFloat} objects and returns a new {@HalfFloat} instance + * that contains the results of the subtraction. + * + * @param a + * The first {@code HalfFloat} input for the subtraction. + * @param b + * The second {@code HalfFloat} input for the subtraction. + * @return A new {@HalfFloat} containing the results of the subtraction. + */ + public static HalfFloat sub(HalfFloat a, HalfFloat b) { + float result = subHalfFloat(a.getHalfFloatValue(), b.getHalfFloatValue()); + return new HalfFloat(result); + } + + /** + * Takes two half float values, converts them to a 32-bit representation and performs a multiplication. + * + * @param a + * The first float-16 input for the multiplication. + * @param b + * The second float-16 input for the multiplication. + * @return The result of the multiplication. + */ + private static float multHalfFloat(short a, short b) { + float floatA = Float.float16ToFloat(a); + float floatB = Float.float16ToFloat(b); + return floatA * floatB; + } + + /** + * Takes two {@code HalfFloat} objects and returns a new {@HalfFloat} instance + * that contains the results of the multiplication. + * + * @param a + * The first {@code HalfFloat} input for the multiplication. + * @param b + * The second {@code HalfFloat} input for the multiplication. + * @return A new {@HalfFloat} containing the results of the multiplication. + */ + public static HalfFloat mult(HalfFloat a, HalfFloat b) { + float result = multHalfFloat(a.getHalfFloatValue(), b.getHalfFloatValue()); + return new HalfFloat(result); + } + + /** + * Takes two half float values, converts them to a 32-bit representation and performs a division. + * + * @param a + * The first float-16 input for the division. + * @param b + * The second float-16 input for the division. + * @return The result of the division. + */ + private static float divHalfFloat(short a, short b) { + float floatA = Float.float16ToFloat(a); + float floatB = Float.float16ToFloat(b); + return floatA / floatB; + } + + /** + * Takes two {@code HalfFloat} objects and returns a new {@HalfFloat} instance + * that contains the results of the division. + * + * @param a + * The first {@code HalfFloat} input for the division. + * @param b + * The second {@code HalfFloat} input for the division. + * @return A new {@HalfFloat} containing the results of the division. + */ + public static HalfFloat div(HalfFloat a, HalfFloat b) { + float result = divHalfFloat(a.getHalfFloatValue(), b.getHalfFloatValue()); + return new HalfFloat(result); + } + +} diff --git a/tornado-api/src/main/java/uk/ac/manchester/tornado/api/types/arrays/HalfFloatArray.java b/tornado-api/src/main/java/uk/ac/manchester/tornado/api/types/arrays/HalfFloatArray.java new file mode 100644 index 0000000000..d45e200ac1 --- /dev/null +++ b/tornado-api/src/main/java/uk/ac/manchester/tornado/api/types/arrays/HalfFloatArray.java @@ -0,0 +1,221 @@ +/* + * Copyright (c) 2024, APT Group, Department of Computer Science, + * The University of Manchester. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ +package uk.ac.manchester.tornado.api.types.arrays; + +import static java.lang.foreign.ValueLayout.JAVA_INT; +import static java.lang.foreign.ValueLayout.JAVA_SHORT; + +import java.lang.foreign.Arena; +import java.lang.foreign.MemorySegment; + +import uk.ac.manchester.tornado.api.internal.annotations.SegmentElementSize; +import uk.ac.manchester.tornado.api.types.HalfFloat; + +/** + * This class represents an array of half floats (float16 types) stored in native memory. + * The half float data is stored in a {@link MemorySegment}, which represents a contiguous region of off-heap memory. + * The class also encapsulates methods for setting and getting half float values, + * for initializing the half float array, and for converting the array to and from different representations. + */ +@SegmentElementSize(size = 2) +public final class HalfFloatArray extends TornadoNativeArray { + + private static final int HALF_FLOAT_BYTES = 2; + private MemorySegment segment; + + private int numberOfElements; + + private int arrayHeaderSize; + + private int baseIndex; + + private long segmentByteSize; + + /** + * Constructs a new instance of the {@code HalfFloatArray} that will store a user-specified number of elements. + * + * @param numberOfElements + * The number of elements in the array. + */ + public HalfFloatArray(int numberOfElements) { + this.numberOfElements = numberOfElements; + arrayHeaderSize = (int) TornadoNativeArray.ARRAY_HEADER; + baseIndex = arrayHeaderSize / HALF_FLOAT_BYTES; + segmentByteSize = numberOfElements * HALF_FLOAT_BYTES + arrayHeaderSize; + + segment = Arena.ofAuto().allocate(segmentByteSize, 1); + segment.setAtIndex(JAVA_INT, 0, numberOfElements); + } + + /** + * Internal method used to create a new instance of the {@code HalfFloatArray} from on-heap data. + * + * @param values + * The on-heap {@link HalfFloat} to create the instance from. + * @return A new {@code HalfFloatArray} instance, initialized with values of the on-heap {@link HalfFloat} array. + */ + private static HalfFloatArray createSegment(HalfFloat[] values) { + HalfFloatArray array = new HalfFloatArray(values.length); + for (int i = 0; i < values.length; i++) { + array.set(i, values[i]); + } + return array; + } + + /** + * Creates a new instance of the {@code HalfFloatArray} class from an on-heap {@link HalfFloat}. + * + * @param values + * The on-heap {@link HalfFloat} array to create the instance from. + * @return A new {@code HalfFloatArray} instance, initialized with values of the on-heap {@link HalfFloat} array. + */ + public static HalfFloatArray fromArray(HalfFloat[] values) { + return createSegment(values); + } + + /** + * Creates a new instance of the {@code HalfFloatArray} class from a set of {@link HalfFloat} values. + * + * @param values + * The {@link HalfFloat} values to initialize the array with. + * @return A new {@code FloatArray} instance, initialized with the given values. + */ + public static HalfFloatArray fromElements(HalfFloat... values) { + return createSegment(values); + } + + /** + * Creates a new instance of the {@code HalfFloatArray} class from a {@link MemorySegment}. + * + * @param segment + * The {@link MemorySegment} containing the off-heap half float data. + * @return A new {@code HalfFloatArray} instance, initialized with the segment data. + */ + public static HalfFloatArray fromSegment(MemorySegment segment) { + long byteSize = segment.byteSize(); + int numElements = (int) (byteSize / HALF_FLOAT_BYTES); + HalfFloatArray halfFloatArray = new HalfFloatArray(numElements); + MemorySegment.copy(segment, 0, halfFloatArray.segment, halfFloatArray.baseIndex * HALF_FLOAT_BYTES, byteSize); + return halfFloatArray; + } + + /** + * Converts the {@link HalfFloat} data from off-heap to on-heap, by copying the values of a {@code HalfFloatArray} + * instance into a new on-heap {@link HalfFloat}. + * + * @return A new on-heap {@link HalfFloat} array, initialized with the values stored in the {@code HalfFloatArray} instance. + */ + public HalfFloat[] toHeapArray() { + HalfFloat[] outputArray = new HalfFloat[getSize()]; + for (int i = 0; i < getSize(); i++) { + outputArray[i] = get(i); + } + return outputArray; + } + + /** + * Sets the {@link HalfFloat} value at a specified index of the {@code HalfFloatArray} instance. + * + * @param index + * The index at which to set the {@link HalfFloat} value. + * @param value + * The {@link HalfFloat} value to store at the specified index. + */ + public void set(int index, HalfFloat value) { + segment.setAtIndex(JAVA_SHORT, baseIndex + index, value.getHalfFloatValue()); + } + + /** + * Gets the {@link HalfFloat} value stored at the specified index of the {@code HalfFloatArray} instance. + * + * @param index + * The index of which to retrieve the {@link HalfFloat} value. + * @return + */ + public HalfFloat get(int index) { + short halfFloatValue = segment.getAtIndex(JAVA_SHORT, baseIndex + index); + return new HalfFloat(halfFloatValue); + } + + /** + * Sets all the values of the {@code HalfFloatArray} instance to zero. + */ + @Override + public void clear() { + init(new HalfFloat(0.0f)); + } + + @Override + public int getElementSize() { + return HALF_FLOAT_BYTES; + } + + /** + * Initializes all the elements of the {@code HalfFloatArray} instance with a specified value. + * + * @param value + * The {@link HalfFloat} value to initialize the {@code HalfFloatArray} instance with. + */ + public void init(HalfFloat value) { + for (int i = 0; i < getSize(); i++) { + segment.setAtIndex(JAVA_SHORT, baseIndex + i, value.getHalfFloatValue()); + } + } + + /** + * Returns the number of half float elements stored in the {@code HalfFloatArray} instance. + * + * @return + */ + @Override + public int getSize() { + return numberOfElements; + } + + /** + * Returns the underlying {@link MemorySegment} of the {@code HalfFloatArray} instance. + * + * @return The {@link MemorySegment} associated with the {@code HalfFloatArray} instance. + */ + @Override + public MemorySegment getSegment() { + return segment; + } + + /** + * Returns the total number of bytes that the {@link MemorySegment}, associated with the {@code HalfFloatArray} instance, occupies. + * + * @return The total number of bytes of the {@link MemorySegment}. + */ + @Override + public long getNumBytesOfSegment() { + return segmentByteSize; + } + + /** + * Returns the number of bytes of the {@link MemorySegment} that is associated with the {@code HalfFloatArray} instance, + * excluding the header bytes. + * + * @return The number of bytes of the raw data in the {@link MemorySegment}. + */ + @Override + public long getNumBytesWithoutHeader() { + return segmentByteSize - TornadoNativeArray.ARRAY_HEADER; + } + +} diff --git a/tornado-api/src/main/java/uk/ac/manchester/tornado/api/types/arrays/TornadoNativeArray.java b/tornado-api/src/main/java/uk/ac/manchester/tornado/api/types/arrays/TornadoNativeArray.java index 31f619a128..e7b424c499 100644 --- a/tornado-api/src/main/java/uk/ac/manchester/tornado/api/types/arrays/TornadoNativeArray.java +++ b/tornado-api/src/main/java/uk/ac/manchester/tornado/api/types/arrays/TornadoNativeArray.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2013-2023, APT Group, Department of Computer Science, + * Copyright (c) 2013-2024, APT Group, Department of Computer Science, * The University of Manchester. * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -33,9 +33,7 @@ * The constant {@code ARRAY_HEADER} represents the size of the header in bytes. *

*/ -public abstract sealed class TornadoNativeArray permits // - IntArray, FloatArray, DoubleArray, LongArray, ShortArray, // - ByteArray, CharArray { +public abstract sealed class TornadoNativeArray permits ByteArray, CharArray, DoubleArray, FloatArray, IntArray, LongArray, ShortArray, HalfFloatArray { /** * The size of the header in bytes. The default value is 24, but it can be configurable through diff --git a/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/OCLTargetDescription.java b/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/OCLTargetDescription.java index c1a54cec00..63040d0664 100644 --- a/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/OCLTargetDescription.java +++ b/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/OCLTargetDescription.java @@ -54,6 +54,8 @@ public class OCLTargetDescription extends TargetDescription { private final String extensions; private final boolean supportsInt64Atomics; + private final boolean supportsF16; + public OCLTargetDescription(Architecture arch, boolean supportsFP64, String extensions) { this(arch, false, STACK_ALIGNMENT, IMPLICIT_NULL_CHECK_LIMIT, INLINE_OBJECTS, supportsFP64, extensions); } @@ -63,6 +65,7 @@ protected OCLTargetDescription(Architecture arch, boolean isMP, int stackAlignme this.supportsFP64 = supportsFP64; this.extensions = extensions; supportsInt64Atomics = extensions.contains("cl_khr_int64_base_atomics"); + supportsF16 = extensions.contains("cl_khr_fp16"); } //@formatter:on @@ -92,6 +95,10 @@ public boolean supportsFP64() { return supportsFP64; } + public boolean supportsFP16() { + return supportsF16; + } + public boolean supportsInt64Atomics() { return supportsInt64Atomics; } diff --git a/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/asm/OCLAssembler.java b/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/asm/OCLAssembler.java index 1d9760f7a7..281e0e06c3 100644 --- a/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/asm/OCLAssembler.java +++ b/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/asm/OCLAssembler.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018, 2022-2023, APT Group, Department of Computer Science, + * Copyright (c) 2018, 2022-2024, APT Group, Department of Computer Science, * The University of Manchester. All rights reserved. * Copyright (c) 2009, 2012, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. @@ -74,6 +74,8 @@ public OCLAssembler(TargetDescription target) { emitLine("#pragma OPENCL EXTENSION cl_khr_fp64 : enable "); } + emitLine("#pragma OPENCL EXTENSION cl_khr_fp16 : enable "); + if (((OCLTargetDescription) target).supportsInt64Atomics()) { emitLine("#pragma OPENCL EXTENSION cl_khr_int64_base_atomics : enable "); } diff --git a/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/asm/OCLVariablePrefix.java b/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/asm/OCLVariablePrefix.java index 9d0b1f59a1..609fa312de 100644 --- a/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/asm/OCLVariablePrefix.java +++ b/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/asm/OCLVariablePrefix.java @@ -2,7 +2,7 @@ * This file is part of Tornado: A heterogeneous programming framework: * https://github.com/beehive-lab/tornadovm * - * Copyright (c) 2023, APT Group, Department of Computer Science, + * Copyright (c) 2023, 2024, APT Group, Department of Computer Science, * The University of Manchester. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * @@ -67,7 +67,9 @@ public enum OCLVariablePrefix { BYTE16("byte16", "b16_"), SHORT("short", "sh_"), SHORT2("short2", "sh2_"), - SHORT3("short3", "sh3_"); + SHORT3("short3", "sh3_"), + HALF("half", "half_"); + // @formatter:on diff --git a/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/compiler/OCLHighTier.java b/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/compiler/OCLHighTier.java index bc129e17e4..8011b10d36 100644 --- a/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/compiler/OCLHighTier.java +++ b/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/compiler/OCLHighTier.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020, APT Group, Department of Computer Science, + * Copyright (c) 2020, 2024, APT Group, Department of Computer Science, * School of Engineering, The University of Manchester. All rights reserved. * Copyright (c) 2018, 2020, APT Group, Department of Computer Science, * The University of Manchester. All rights reserved. @@ -50,6 +50,7 @@ import uk.ac.manchester.tornado.drivers.common.compiler.phases.memalloc.TornadoLocalMemoryAllocation; import uk.ac.manchester.tornado.drivers.common.compiler.phases.memalloc.TornadoNewArrayDevirtualizationReplacement; import uk.ac.manchester.tornado.drivers.common.compiler.phases.memalloc.TornadoPrivateArrayPiRemoval; +import uk.ac.manchester.tornado.drivers.opencl.graal.phases.TornadoHalfFloatReplacement; import uk.ac.manchester.tornado.drivers.opencl.graal.phases.TornadoOpenCLIntrinsicsReplacements; import uk.ac.manchester.tornado.drivers.opencl.graal.phases.TornadoParallelScheduler; import uk.ac.manchester.tornado.drivers.opencl.graal.phases.TornadoTaskSpecialisation; @@ -86,6 +87,8 @@ public OCLHighTier(OptionValues options, TornadoDeviceContext deviceContext, Can appendPhase(new TornadoNewArrayDevirtualizationReplacement()); + appendPhase(new TornadoHalfFloatReplacement()); + if (PartialEscapeAnalysis.getValue(options)) { appendPhase(new PartialEscapePhase(true, canonicalizer, options)); } diff --git a/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/compiler/plugins/OCLGraphBuilderPlugins.java b/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/compiler/plugins/OCLGraphBuilderPlugins.java index 697547444d..ae52ccad44 100644 --- a/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/compiler/plugins/OCLGraphBuilderPlugins.java +++ b/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/compiler/plugins/OCLGraphBuilderPlugins.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, APT Group, Department of Computer Science, + * Copyright (c) 2022, 2024 APT Group, Department of Computer Science, * School of Engineering, The University of Manchester. All rights reserved. * Copyright (c) 2018, 2020, APT Group, Department of Computer Science, * The University of Manchester. All rights reserved. @@ -116,6 +116,8 @@ public static void registerInvocationPlugins(final Plugins ps, final InvocationP // Register TornadoAtomicInteger registerTornadoAtomicInteger(ps, plugins); + OCLHalfFloatPlugins.registerPlugins(ps, plugins); + registerMemoryAccessPlugins(plugins); } diff --git a/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/compiler/plugins/OCLHalfFloatPlugins.java b/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/compiler/plugins/OCLHalfFloatPlugins.java new file mode 100644 index 0000000000..b9bd8b83b0 --- /dev/null +++ b/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/compiler/plugins/OCLHalfFloatPlugins.java @@ -0,0 +1,109 @@ +/* + * Copyright (c) 2024, APT Group, Department of Computer Science, + * The University of Manchester. All rights reserved. + * Copyright (c) 2009, 2017, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + */ +package uk.ac.manchester.tornado.drivers.opencl.graal.compiler.plugins; + +import org.graalvm.compiler.nodes.ValueNode; +import org.graalvm.compiler.nodes.graphbuilderconf.GraphBuilderConfiguration; +import org.graalvm.compiler.nodes.graphbuilderconf.GraphBuilderContext; +import org.graalvm.compiler.nodes.graphbuilderconf.InvocationPlugin; +import org.graalvm.compiler.nodes.graphbuilderconf.InvocationPlugins; +import org.graalvm.compiler.nodes.graphbuilderconf.NodePlugin; + +import jdk.vm.ci.meta.JavaKind; +import jdk.vm.ci.meta.ResolvedJavaMethod; +import uk.ac.manchester.tornado.api.types.HalfFloat; +import uk.ac.manchester.tornado.runtime.graal.nodes.AddHalfFloatNode; +import uk.ac.manchester.tornado.runtime.graal.nodes.DivHalfFloatNode; +import uk.ac.manchester.tornado.runtime.graal.nodes.MultHalfFloatNode; +import uk.ac.manchester.tornado.runtime.graal.nodes.SubHalfFloatNode; +import uk.ac.manchester.tornado.runtime.graal.nodes.HalfFloatPlaceholder; +import uk.ac.manchester.tornado.runtime.graal.nodes.NewHalfFloatInstance; + +public class OCLHalfFloatPlugins { + + public static void registerPlugins(final GraphBuilderConfiguration.Plugins ps, final InvocationPlugins plugins) { + registerHalfFloatInit(ps, plugins); + } + + private static void registerHalfFloatInit(GraphBuilderConfiguration.Plugins ps, InvocationPlugins plugins) { + + final InvocationPlugins.Registration r = new InvocationPlugins.Registration(plugins, HalfFloat.class); + + ps.appendNodePlugin(new NodePlugin() { + @Override + public boolean handleInvoke(GraphBuilderContext b, ResolvedJavaMethod method, ValueNode[] args) { + if (method.getName().equals("") && (method.toString().contains("HalfFloat."))) { + NewHalfFloatInstance newHalfFloatInstance = b.append(new NewHalfFloatInstance(args[1])); + b.add(newHalfFloatInstance); + return true; + } + return false; + } + }); + + r.register(new InvocationPlugin("add", HalfFloat.class, HalfFloat.class) { + @Override + public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Receiver receiver, ValueNode halfFloat1, ValueNode halfFloat2) { + AddHalfFloatNode addNode = b.append(new AddHalfFloatNode(halfFloat1, halfFloat2)); + b.push(JavaKind.Object, addNode); + return true; + } + }); + + r.register(new InvocationPlugin("sub", HalfFloat.class, HalfFloat.class) { + @Override + public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Receiver receiver, ValueNode halfFloat1, ValueNode halfFloat2) { + SubHalfFloatNode subNode = b.append(new SubHalfFloatNode(halfFloat1, halfFloat2)); + b.push(JavaKind.Object, subNode); + return true; + } + }); + + r.register(new InvocationPlugin("mult", HalfFloat.class, HalfFloat.class) { + @Override + public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Receiver receiver, ValueNode halfFloat1, ValueNode halfFloat2) { + MultHalfFloatNode multNode = b.append(new MultHalfFloatNode(halfFloat1, halfFloat2)); + b.push(JavaKind.Object, multNode); + return true; + } + }); + + r.register(new InvocationPlugin("div", HalfFloat.class, HalfFloat.class) { + @Override + public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Receiver receiver, ValueNode halfFloat1, ValueNode halfFloat2) { + DivHalfFloatNode divNode = b.append(new DivHalfFloatNode(halfFloat1, halfFloat2)); + b.push(JavaKind.Object, divNode); + return true; + } + }); + + r.register(new InvocationPlugin("getHalfFloatValue", InvocationPlugin.Receiver.class) { + @Override + public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Receiver receiver) { + b.push(JavaKind.Short, b.append(new HalfFloatPlaceholder(receiver.get()))); + return true; + } + }); + + } + +} diff --git a/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/lir/OCLKind.java b/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/lir/OCLKind.java index 9602feec5b..d626c1d9fe 100644 --- a/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/lir/OCLKind.java +++ b/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/lir/OCLKind.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020, APT Group, Department of Computer Science, + * Copyright (c) 2020, 2024, APT Group, Department of Computer Science, * School of Engineering, The University of Manchester. All rights reserved. * Copyright (c) 2018, 2020, APT Group, Department of Computer Science, * The University of Manchester. All rights reserved. @@ -94,7 +94,7 @@ public enum OCLKind implements PlatformKind { UINT(4, null), LONG(8, java.lang.Long.TYPE), ULONG(8, null), - HALF(2, null), + HALF(2, java.lang.Short.TYPE), FLOAT(4, java.lang.Float.TYPE), DOUBLE(8, java.lang.Double.TYPE), CHAR2(2, null, CHAR), diff --git a/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/lir/OCLUnary.java b/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/lir/OCLUnary.java index da8c5cb4e4..450eb57cc0 100644 --- a/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/lir/OCLUnary.java +++ b/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/lir/OCLUnary.java @@ -2,7 +2,7 @@ * This file is part of Tornado: A heterogeneous programming framework: * https://github.com/beehive-lab/tornadovm * - * Copyright (c) 2013-2022, APT Group, Department of Computer Science, + * Copyright (c) 2013-2022, 2024, APT Group, Department of Computer Science, * The University of Manchester. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * @@ -287,7 +287,7 @@ public static class OCLAddressCast extends UnaryConsumer { private final OCLMemoryBase base; - OCLAddressCast(OCLMemoryBase base, LIRKind lirKind) { + public OCLAddressCast(OCLMemoryBase base, LIRKind lirKind) { super(OCLUnaryTemplate.CAST_TO_POINTER, lirKind, null); this.base = base; } diff --git a/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/nodes/ReadHalfFloatNode.java b/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/nodes/ReadHalfFloatNode.java new file mode 100644 index 0000000000..e83d4c3773 --- /dev/null +++ b/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/nodes/ReadHalfFloatNode.java @@ -0,0 +1,64 @@ +/* + * Copyright (c) 2024, APT Group, Department of Computer Science, + * School of Engineering, The University of Manchester. All rights reserved. + * Copyright (c) 2009, 2017, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + */ +package uk.ac.manchester.tornado.drivers.opencl.graal.nodes; + +import org.graalvm.compiler.core.common.LIRKind; +import org.graalvm.compiler.core.common.type.StampFactory; +import org.graalvm.compiler.graph.NodeClass; +import org.graalvm.compiler.lir.Variable; +import org.graalvm.compiler.lir.gen.LIRGeneratorTool; +import org.graalvm.compiler.nodeinfo.NodeInfo; +import org.graalvm.compiler.nodes.FixedWithNextNode; +import org.graalvm.compiler.nodes.memory.address.AddressNode; +import org.graalvm.compiler.nodes.spi.LIRLowerable; +import org.graalvm.compiler.nodes.spi.NodeLIRBuilderTool; + +import jdk.vm.ci.meta.JavaKind; +import jdk.vm.ci.meta.Value; +import uk.ac.manchester.tornado.drivers.opencl.graal.OCLArchitecture; +import uk.ac.manchester.tornado.drivers.opencl.graal.lir.OCLKind; +import uk.ac.manchester.tornado.drivers.opencl.graal.lir.OCLLIRStmt; +import uk.ac.manchester.tornado.drivers.opencl.graal.lir.OCLUnary; + +@NodeInfo +public class ReadHalfFloatNode extends FixedWithNextNode implements LIRLowerable { + + public static final NodeClass TYPE = NodeClass.create(ReadHalfFloatNode.class); + + @Input + private AddressNode addressNode; + + public ReadHalfFloatNode(AddressNode addressNode) { + super(TYPE, StampFactory.forKind(JavaKind.Short)); + this.addressNode = addressNode; + } + + public void generate(NodeLIRBuilderTool generator) { + LIRGeneratorTool tool = generator.getLIRGeneratorTool(); + Variable result = tool.newVariable(LIRKind.value(OCLKind.HALF)); + Value addressValue = generator.operand(addressNode); + OCLArchitecture.OCLMemoryBase base = ((OCLUnary.MemoryAccess) addressValue).getBase(); + OCLUnary.OCLAddressCast cast = new OCLUnary.OCLAddressCast(base, LIRKind.value(OCLKind.HALF)); + tool.append(new OCLLIRStmt.LoadStmt(result, cast, (OCLUnary.MemoryAccess) addressValue)); + generator.setResult(this, result); + } +} diff --git a/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/nodes/WriteHalfFloatNode.java b/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/nodes/WriteHalfFloatNode.java new file mode 100644 index 0000000000..5669aa1cb1 --- /dev/null +++ b/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/nodes/WriteHalfFloatNode.java @@ -0,0 +1,67 @@ +/* + * Copyright (c) 2024, APT Group, Department of Computer Science, + * School of Engineering, The University of Manchester. All rights reserved. + * Copyright (c) 2009, 2017, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + */ +package uk.ac.manchester.tornado.drivers.opencl.graal.nodes; + +import org.graalvm.compiler.core.common.LIRKind; +import org.graalvm.compiler.core.common.type.StampFactory; +import org.graalvm.compiler.graph.NodeClass; +import org.graalvm.compiler.lir.gen.LIRGeneratorTool; +import org.graalvm.compiler.nodeinfo.NodeInfo; +import org.graalvm.compiler.nodes.FixedWithNextNode; +import org.graalvm.compiler.nodes.ValueNode; +import org.graalvm.compiler.nodes.memory.address.AddressNode; +import org.graalvm.compiler.nodes.spi.LIRLowerable; +import org.graalvm.compiler.nodes.spi.NodeLIRBuilderTool; + +import jdk.vm.ci.meta.JavaKind; +import jdk.vm.ci.meta.Value; +import uk.ac.manchester.tornado.drivers.opencl.graal.OCLArchitecture; +import uk.ac.manchester.tornado.drivers.opencl.graal.lir.OCLKind; +import uk.ac.manchester.tornado.drivers.opencl.graal.lir.OCLLIRStmt; +import uk.ac.manchester.tornado.drivers.opencl.graal.lir.OCLUnary; + +@NodeInfo +public class WriteHalfFloatNode extends FixedWithNextNode implements LIRLowerable { + + public static final NodeClass TYPE = NodeClass.create(WriteHalfFloatNode.class); + + @Input + private AddressNode addressNode; + + @Input + private ValueNode valueNode; + + public WriteHalfFloatNode(AddressNode addressNode, ValueNode valueNode) { + super(TYPE, StampFactory.forKind(JavaKind.Short)); + this.addressNode = addressNode; + this.valueNode = valueNode; + } + + public void generate(NodeLIRBuilderTool generator) { + LIRGeneratorTool tool = generator.getLIRGeneratorTool(); + Value addressValue = generator.operand(addressNode); + OCLArchitecture.OCLMemoryBase base = ((OCLUnary.MemoryAccess) addressValue).getBase(); + OCLUnary.OCLAddressCast cast = new OCLUnary.OCLAddressCast(base, LIRKind.value(OCLKind.HALF)); + Value input = generator.operand(valueNode); + tool.append(new OCLLIRStmt.StoreStmt(cast, (OCLUnary.MemoryAccess) addressValue, input)); + } +} diff --git a/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/phases/TornadoHalfFloatReplacement.java b/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/phases/TornadoHalfFloatReplacement.java new file mode 100644 index 0000000000..c2d531b897 --- /dev/null +++ b/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/phases/TornadoHalfFloatReplacement.java @@ -0,0 +1,194 @@ +/* + * Copyright (c) 2024, APT Group, Department of Computer Science, + * The University of Manchester. All rights reserved. + * Copyright (c) 2009, 2017, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + */ +package uk.ac.manchester.tornado.drivers.opencl.graal.phases; + +import java.util.Optional; + +import org.graalvm.compiler.graph.Node; +import org.graalvm.compiler.nodes.GraphState; +import org.graalvm.compiler.nodes.StructuredGraph; +import org.graalvm.compiler.nodes.ValueNode; +import org.graalvm.compiler.nodes.calc.AddNode; +import org.graalvm.compiler.nodes.calc.FloatDivNode; +import org.graalvm.compiler.nodes.calc.MulNode; +import org.graalvm.compiler.nodes.calc.SubNode; +import org.graalvm.compiler.nodes.extended.JavaReadNode; +import org.graalvm.compiler.nodes.extended.JavaWriteNode; +import org.graalvm.compiler.nodes.java.NewInstanceNode; +import org.graalvm.compiler.nodes.memory.address.AddressNode; +import org.graalvm.compiler.phases.BasePhase; + +import uk.ac.manchester.tornado.drivers.opencl.graal.nodes.ReadHalfFloatNode; +import uk.ac.manchester.tornado.drivers.opencl.graal.nodes.WriteHalfFloatNode; +import uk.ac.manchester.tornado.runtime.graal.nodes.AddHalfFloatNode; +import uk.ac.manchester.tornado.runtime.graal.nodes.DivHalfFloatNode; +import uk.ac.manchester.tornado.runtime.graal.nodes.HalfFloatPlaceholder; +import uk.ac.manchester.tornado.runtime.graal.nodes.MultHalfFloatNode; +import uk.ac.manchester.tornado.runtime.graal.nodes.NewHalfFloatInstance; +import uk.ac.manchester.tornado.runtime.graal.nodes.SubHalfFloatNode; +import uk.ac.manchester.tornado.runtime.graal.phases.TornadoHighTierContext; + +public class TornadoHalfFloatReplacement extends BasePhase { + + @Override + public Optional notApplicableTo(GraphState graphState) { + return ALWAYS_APPLICABLE; + } + + protected void run(StructuredGraph graph, TornadoHighTierContext context) { + + // replace reads with halfFloat reads + for (JavaReadNode javaRead : graph.getNodes().filter(JavaReadNode.class)) { + if (javaRead.successors().first() instanceof NewInstanceNode) { + NewInstanceNode newInstanceNode = (NewInstanceNode) javaRead.successors().first(); + if (newInstanceNode.instanceClass().toString().contains("HalfFloat")) { + if (newInstanceNode.successors().first() instanceof NewHalfFloatInstance) { + NewHalfFloatInstance newHalfFloatInstance = (NewHalfFloatInstance) newInstanceNode.successors().first(); + deleteFixed(newHalfFloatInstance); + } + AddressNode readingAddress = javaRead.getAddress(); + ReadHalfFloatNode readHalfFloatNode = new ReadHalfFloatNode(readingAddress); + graph.addWithoutUnique(readHalfFloatNode); + replaceFixed(javaRead, readHalfFloatNode); + newInstanceNode.replaceAtUsages(readHalfFloatNode); + deleteFixed(newInstanceNode); + } + } + } + + // replace writes with halfFloat writes + for (JavaWriteNode javaWrite : graph.getNodes().filter(JavaWriteNode.class)) { + if (isWriteHalfFloat(javaWrite)) { + // This casting is safe to do as it is already checked by the isWriteHalfFloat function + HalfFloatPlaceholder placeholder = (HalfFloatPlaceholder) javaWrite.value(); + ValueNode writingValue; + if (javaWrite.predecessor() instanceof NewHalfFloatInstance) { + // if a new HalfFloat instance is written + NewHalfFloatInstance newHalfFloatInstance = (NewHalfFloatInstance) javaWrite.predecessor(); + writingValue = newHalfFloatInstance.getValue(); + if (newHalfFloatInstance.predecessor() instanceof NewInstanceNode) { + NewInstanceNode newInstanceNode = (NewInstanceNode) newHalfFloatInstance.predecessor(); + if (newInstanceNode.instanceClass().toString().contains("HalfFloat")) { + deleteFixed(newInstanceNode); + deleteFixed(newHalfFloatInstance); + } + } + } else { + // if the result of an operation or a stored value is written + writingValue = placeholder.getInput(); + } + placeholder.replaceAtUsages(writingValue); + placeholder.safeDelete(); + AddressNode writingAddress = javaWrite.getAddress(); + WriteHalfFloatNode writeHalfFloatNode = new WriteHalfFloatNode(writingAddress, writingValue); + graph.addWithoutUnique(writeHalfFloatNode); + replaceFixed(javaWrite, writeHalfFloatNode); + deleteFixed(javaWrite); + } + } + + // replace the half float operator nodes with the corresponding regular operators + replaceAddHalfFloatNodes(graph); + replaceSubHalfFloatNodes(graph); + replaceMultHalfFloatNodes(graph); + replaceDivHalfFloatNodes(graph); + + } + + private static void replaceAddHalfFloatNodes(StructuredGraph graph) { + for (AddHalfFloatNode addHalfFloatNode : graph.getNodes().filter(AddHalfFloatNode.class)) { + AddNode addNode = new AddNode(addHalfFloatNode.getX(), addHalfFloatNode.getY()); + graph.addWithoutUnique(addNode); + addHalfFloatNode.replaceAtUsages(addNode); + addHalfFloatNode.safeDelete(); + } + } + + private static void replaceSubHalfFloatNodes(StructuredGraph graph) { + for (SubHalfFloatNode subHalfFloatNode : graph.getNodes().filter(SubHalfFloatNode.class)) { + SubNode subNode = new SubNode(subHalfFloatNode.getX(), subHalfFloatNode.getY()); + graph.addWithoutUnique(subNode); + subHalfFloatNode.replaceAtUsages(subNode); + subHalfFloatNode.safeDelete(); + } + } + + private static void replaceMultHalfFloatNodes(StructuredGraph graph) { + for (MultHalfFloatNode multHalfFloatNode : graph.getNodes().filter(MultHalfFloatNode.class)) { + MulNode mulNode = new MulNode(multHalfFloatNode.getX(), multHalfFloatNode.getY()); + graph.addWithoutUnique(mulNode); + multHalfFloatNode.replaceAtUsages(mulNode); + multHalfFloatNode.safeDelete(); + } + } + + private static void replaceDivHalfFloatNodes(StructuredGraph graph) { + for (DivHalfFloatNode divHalfFloatNode : graph.getNodes().filter(DivHalfFloatNode.class)) { + FloatDivNode divNode = new FloatDivNode(divHalfFloatNode.getX(), divHalfFloatNode.getY()); + graph.addWithoutUnique(divNode); + divHalfFloatNode.replaceAtUsages(divNode); + divHalfFloatNode.safeDelete(); + } + } + + private static boolean isWriteHalfFloat(JavaWriteNode javaWrite) { + if (javaWrite.value() instanceof HalfFloatPlaceholder) { + return true; + } + return false; + } + + private static void replaceFixed(Node n, Node other) { + Node pred = n.predecessor(); + Node suc = n.successors().first(); + + n.replaceFirstSuccessor(suc, null); + n.replaceAtPredecessor(other); + pred.replaceFirstSuccessor(n, other); + other.replaceFirstSuccessor(null, suc); + + for (Node us : n.usages()) { + n.removeUsage(us); + } + n.clearInputs(); + n.safeDelete(); + + } + + private static void deleteFixed(Node node) { + if (!node.isDeleted()) { + Node predecessor = node.predecessor(); + Node successor = node.successors().first(); + + node.replaceFirstSuccessor(successor, null); + node.replaceAtPredecessor(successor); + predecessor.replaceFirstSuccessor(node, successor); + + for (Node us : node.usages()) { + node.removeUsage(us); + } + node.clearInputs(); + node.safeDelete(); + } + } + +} diff --git a/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/mm/OCLMemorySegmentWrapper.java b/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/mm/OCLMemorySegmentWrapper.java index 4d39fe7c60..5357adf2c5 100644 --- a/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/mm/OCLMemorySegmentWrapper.java +++ b/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/mm/OCLMemorySegmentWrapper.java @@ -2,7 +2,7 @@ * This file is part of Tornado: A heterogeneous programming framework: * https://github.com/beehive-lab/tornadovm * - * Copyright (c) 2023, APT Group, Department of Computer Science, + * Copyright (c) 2023, 2024, APT Group, Department of Computer Science, * The University of Manchester. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * @@ -36,6 +36,7 @@ import uk.ac.manchester.tornado.api.types.arrays.CharArray; import uk.ac.manchester.tornado.api.types.arrays.DoubleArray; import uk.ac.manchester.tornado.api.types.arrays.FloatArray; +import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray; import uk.ac.manchester.tornado.api.types.arrays.IntArray; import uk.ac.manchester.tornado.api.types.arrays.LongArray; import uk.ac.manchester.tornado.api.types.arrays.ShortArray; @@ -118,6 +119,7 @@ private MemorySegment getSegment(final Object reference) { case ShortArray shortArray -> shortArray.getSegment(); case ByteArray byteArray -> byteArray.getSegment(); case CharArray charArray -> charArray.getSegment(); + case HalfFloatArray halfFloatArray -> halfFloatArray.getSegment(); case VectorFloat2 vectorFloat2 -> vectorFloat2.getArray().getSegment(); case VectorFloat3 vectorFloat3 -> vectorFloat3.getArray().getSegment(); case VectorFloat4 vectorFloat4 -> vectorFloat4.getArray().getSegment(); diff --git a/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/runtime/OCLTornadoDevice.java b/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/runtime/OCLTornadoDevice.java index 4b092ee1f4..a503d64bba 100644 --- a/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/runtime/OCLTornadoDevice.java +++ b/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/runtime/OCLTornadoDevice.java @@ -2,7 +2,7 @@ * This file is part of Tornado: A heterogeneous programming framework: * https://github.com/beehive-lab/tornadovm * - * Copyright (c) 2013-2020, 2023, APT Group, Department of Computer Science, + * Copyright (c) 2013-2020, 2023, 2024, APT Group, Department of Computer Science, * The University of Manchester. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * @@ -55,6 +55,7 @@ import uk.ac.manchester.tornado.api.types.arrays.CharArray; import uk.ac.manchester.tornado.api.types.arrays.DoubleArray; import uk.ac.manchester.tornado.api.types.arrays.FloatArray; +import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray; import uk.ac.manchester.tornado.api.types.arrays.IntArray; import uk.ac.manchester.tornado.api.types.arrays.LongArray; import uk.ac.manchester.tornado.api.types.arrays.ShortArray; @@ -540,6 +541,8 @@ private ObjectBuffer createDeviceBuffer(Class type, Object object, OCLDeviceC result = new OCLMemorySegmentWrapper(deviceContext, batchSize); } else if (object instanceof CharArray) { result = new OCLMemorySegmentWrapper(deviceContext, batchSize); + } else if (object instanceof HalfFloatArray) { + result = new OCLMemorySegmentWrapper(deviceContext, batchSize); } else { result = new OCLObjectWrapper(deviceContext, object); } diff --git a/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/PTXVariablePrefix.java b/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/PTXVariablePrefix.java index f02561da9b..eaf2ec6ee1 100644 --- a/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/PTXVariablePrefix.java +++ b/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/PTXVariablePrefix.java @@ -2,7 +2,7 @@ * This file is part of Tornado: A heterogeneous programming framework: * https://github.com/beehive-lab/tornadovm * - * Copyright (c) 2023, APT Group, Department of Computer Science, The University + * Copyright (c) 2023, 2024, APT Group, Department of Computer Science, The University * of Manchester. All rights reserved. DO NOT ALTER OR REMOVE COPYRIGHT NOTICES * OR THIS FILE HEADER. * @@ -29,6 +29,7 @@ public enum PTXVariablePrefix { */ // @formatter:off B8("b8", "rub"), + B16("b16", "rufh"), B32("b32", "rui"), B64("b64", "rbd"), S8("s8", "rsb"), @@ -37,6 +38,7 @@ public enum PTXVariablePrefix { S64("s64", "rsd"), U32("u32", "rui"), U64("u64", "rud"), + F16("f16", "rfh"), F32("f32", "rfi"), F64("f64", "rfd"), PRED("pred", "rpb"); diff --git a/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/asm/PTXAssembler.java b/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/asm/PTXAssembler.java index 3d3e9328d4..e2ee3fa3ba 100644 --- a/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/asm/PTXAssembler.java +++ b/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/asm/PTXAssembler.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, 2022-2023, APT Group, Department of Computer Science, + * Copyright (c) 2021, 2022-2024, APT Group, Department of Computer Science, * School of Engineering, The University of Manchester. All rights reserved. * Copyright (c) 2009, 2017, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. @@ -10,7 +10,7 @@ * * This code is distributed in the hope that it will be useful, but WITHOUT * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or - * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License * version 2 for more details (a copy is included in the LICENSE file that * accompanied this code). * @@ -160,7 +160,7 @@ public static String toString(Value value) { * generated format strings. * * @param input - * The {@link Value} input to be converted. + * The {@link Value} input to be converted. * @return The converted format string in the PTX backend format. */ public static String convertValueFromGraalFormat(Value input) { @@ -184,8 +184,8 @@ public static String convertValueFromGraalFormat(Value input) { } // Find the PTXVariablePrefix corresponding to the input's platform type. - PTXVariablePrefix typePrefix = Arrays.stream(PTXVariablePrefix.values()).filter(tp -> tp.getType().equals(input.getPlatformKind().name().toLowerCase())).findFirst() - .orElseThrow(AssertionError::new); + PTXVariablePrefix typePrefix = Arrays.stream(PTXVariablePrefix.values()).filter(tp -> tp.getType().equals(input.getPlatformKind().name().toLowerCase())).findFirst().orElseThrow( + AssertionError::new); // Create the formatted index value. String indexValue = isArray ? arraylocalIndexes.get(ptxKind).toString() : String.valueOf(localIndexes.get(ptxKind)); @@ -621,6 +621,7 @@ public static class PTXBinaryOp extends PTXOp { public static final PTXBinaryOp MUL_WIDE = new PTXBinaryOp("mul.wide"); public static final PTXBinaryOp DIV = new PTXBinaryOp("div"); public static final PTXBinaryOp DIV_FULL = new PTXBinaryOp("div.full", false); + public static final PTXBinaryOp DIV_APPROX = new PTXBinaryOp("div.approx", false); public static final PTXBinaryOp REM = new PTXBinaryOp("rem", false); public static final PTXBinaryOp RELATIONAL_EQ = new PTXBinaryOp("=="); diff --git a/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/asm/PTXAssemblerConstants.java b/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/asm/PTXAssemblerConstants.java index 28daede900..8a0e44793f 100644 --- a/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/asm/PTXAssemblerConstants.java +++ b/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/asm/PTXAssemblerConstants.java @@ -2,7 +2,7 @@ * This file is part of Tornado: A heterogeneous programming framework: * https://github.com/beehive-lab/tornadovm * - * Copyright (c) 2020, 2022-2023, APT Group, Department of Computer Science, + * Copyright (c) 2020, 2022-2024, APT Group, Department of Computer Science, * School of Engineering, The University of Manchester. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * @@ -12,7 +12,7 @@ * * This code is distributed in the hope that it will be useful, but WITHOUT * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or - * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License * version 2 for more details (a copy is included in the LICENSE file that * accompanied this code). * @@ -32,6 +32,7 @@ public class PTXAssemblerConstants { public static final String VECTOR = "v"; public static final String CONVERT = "cvt"; + public static final String CONVERT_RN = "cvt.rn"; public static final String CONVERT_ADDRESS = "cvta"; public static final String MOVE = "mov"; public static final String TEST_NUMBER = "testp.number"; diff --git a/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/compiler/PTXHighTier.java b/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/compiler/PTXHighTier.java index d498a0ca40..75bc518b93 100644 --- a/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/compiler/PTXHighTier.java +++ b/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/compiler/PTXHighTier.java @@ -2,7 +2,7 @@ * This file is part of Tornado: A heterogeneous programming framework: * https://github.com/beehive-lab/tornadovm * - * Copyright (c) 2020, APT Group, Department of Computer Science, + * Copyright (c) 2020, 2024, APT Group, Department of Computer Science, * School of Engineering, The University of Manchester. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * @@ -51,6 +51,7 @@ import uk.ac.manchester.tornado.drivers.common.compiler.phases.memalloc.TornadoLocalMemoryAllocation; import uk.ac.manchester.tornado.drivers.common.compiler.phases.memalloc.TornadoNewArrayDevirtualizationReplacement; import uk.ac.manchester.tornado.drivers.common.compiler.phases.memalloc.TornadoPrivateArrayPiRemoval; +import uk.ac.manchester.tornado.drivers.ptx.graal.phases.TornadoHalfFloatReplacement; import uk.ac.manchester.tornado.drivers.ptx.graal.phases.TornadoPTXIntrinsicsReplacements; import uk.ac.manchester.tornado.drivers.ptx.graal.phases.TornadoParallelScheduler; import uk.ac.manchester.tornado.drivers.ptx.graal.phases.TornadoTaskSpecialisation; @@ -86,6 +87,8 @@ public PTXHighTier(OptionValues options, CanonicalizerPhase.CustomSimplification appendPhase(new TornadoNewArrayDevirtualizationReplacement()); + appendPhase(new TornadoHalfFloatReplacement()); + if (PartialEscapeAnalysis.getValue(options)) { appendPhase(new PartialEscapePhase(true, canonicalizer, options)); } diff --git a/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/compiler/plugins/PTXGraphBuilderPlugins.java b/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/compiler/plugins/PTXGraphBuilderPlugins.java index 7554255bf3..b2c9b4b385 100644 --- a/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/compiler/plugins/PTXGraphBuilderPlugins.java +++ b/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/compiler/plugins/PTXGraphBuilderPlugins.java @@ -2,7 +2,7 @@ * This file is part of Tornado: A heterogeneous programming framework: * https://github.com/beehive-lab/tornadovm * - * Copyright (c) 2020, 2022, APT Group, Department of Computer Science, + * Copyright (c) 2020, 2022, 2024, APT Group, Department of Computer Science, * School of Engineering, The University of Manchester. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * @@ -94,6 +94,7 @@ public static void registerInvocationPlugins(final Plugins ps, final InvocationP registerPTXBuiltinPlugins(plugins); PTXMathPlugins.registerTornadoMathPlugins(plugins); PTXVectorPlugins.registerPlugins(ps, plugins); + PTXHalfFloatPlugin.registerPlugins(ps, plugins); registerMemoryAccessPlugins(plugins); registerKernelContextPlugins(plugins); } diff --git a/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/compiler/plugins/PTXHalfFloatPlugin.java b/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/compiler/plugins/PTXHalfFloatPlugin.java new file mode 100644 index 0000000000..19c1be979a --- /dev/null +++ b/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/compiler/plugins/PTXHalfFloatPlugin.java @@ -0,0 +1,111 @@ +/* + * This file is part of Tornado: A heterogeneous programming framework: + * https://github.com/beehive-lab/tornadovm + * + * Copyright (c) 2024, APT Group, Department of Computer Science, + * School of Engineering, The University of Manchester. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + */ +package uk.ac.manchester.tornado.drivers.ptx.graal.compiler.plugins; + +import org.graalvm.compiler.nodes.ValueNode; +import org.graalvm.compiler.nodes.graphbuilderconf.GraphBuilderConfiguration; +import org.graalvm.compiler.nodes.graphbuilderconf.GraphBuilderContext; +import org.graalvm.compiler.nodes.graphbuilderconf.InvocationPlugin; +import org.graalvm.compiler.nodes.graphbuilderconf.InvocationPlugins; +import org.graalvm.compiler.nodes.graphbuilderconf.NodePlugin; + +import jdk.vm.ci.meta.JavaKind; +import jdk.vm.ci.meta.ResolvedJavaMethod; +import uk.ac.manchester.tornado.api.types.HalfFloat; +import uk.ac.manchester.tornado.runtime.graal.nodes.AddHalfFloatNode; +import uk.ac.manchester.tornado.runtime.graal.nodes.DivHalfFloatNode; +import uk.ac.manchester.tornado.runtime.graal.nodes.HalfFloatPlaceholder; +import uk.ac.manchester.tornado.runtime.graal.nodes.MultHalfFloatNode; +import uk.ac.manchester.tornado.runtime.graal.nodes.NewHalfFloatInstance; +import uk.ac.manchester.tornado.runtime.graal.nodes.SubHalfFloatNode; + +public class PTXHalfFloatPlugin { + + public static void registerPlugins(final GraphBuilderConfiguration.Plugins ps, final InvocationPlugins plugins) { + registerHalfFloatInit(ps, plugins); + } + + private static void registerHalfFloatInit(GraphBuilderConfiguration.Plugins ps, InvocationPlugins plugins) { + + final InvocationPlugins.Registration r = new InvocationPlugins.Registration(plugins, HalfFloat.class); + + ps.appendNodePlugin(new NodePlugin() { + @Override + public boolean handleInvoke(GraphBuilderContext b, ResolvedJavaMethod method, ValueNode[] args) { + if (method.getName().equals("") && (method.toString().contains("HalfFloat."))) { + NewHalfFloatInstance newHalfFloatInstance = b.append(new NewHalfFloatInstance(args[1])); + b.add(newHalfFloatInstance); + return true; + } + return false; + } + }); + + r.register(new InvocationPlugin("add", HalfFloat.class, HalfFloat.class) { + @Override + public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Receiver receiver, ValueNode halfFloat1, ValueNode halfFloat2) { + AddHalfFloatNode addNode = b.append(new AddHalfFloatNode(halfFloat1, halfFloat2)); + b.push(JavaKind.Object, addNode); + return true; + } + }); + + r.register(new InvocationPlugin("sub", HalfFloat.class, HalfFloat.class) { + @Override + public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Receiver receiver, ValueNode halfFloat1, ValueNode halfFloat2) { + SubHalfFloatNode subNode = b.append(new SubHalfFloatNode(halfFloat1, halfFloat2)); + b.push(JavaKind.Object, subNode); + return true; + } + }); + + r.register(new InvocationPlugin("mult", HalfFloat.class, HalfFloat.class) { + @Override + public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Receiver receiver, ValueNode halfFloat1, ValueNode halfFloat2) { + MultHalfFloatNode multNode = b.append(new MultHalfFloatNode(halfFloat1, halfFloat2)); + b.push(JavaKind.Object, multNode); + return true; + } + }); + + r.register(new InvocationPlugin("div", HalfFloat.class, HalfFloat.class) { + @Override + public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Receiver receiver, ValueNode halfFloat1, ValueNode halfFloat2) { + DivHalfFloatNode divNode = b.append(new DivHalfFloatNode(halfFloat1, halfFloat2)); + b.push(JavaKind.Object, divNode); + return true; + } + }); + + r.register(new InvocationPlugin("getHalfFloatValue", InvocationPlugin.Receiver.class) { + @Override + public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Receiver receiver) { + b.push(JavaKind.Short, b.append(new HalfFloatPlaceholder(receiver.get()))); + return true; + } + }); + + } + +} diff --git a/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/lir/PTXKind.java b/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/lir/PTXKind.java index ade375fe2a..27f4e9001f 100644 --- a/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/lir/PTXKind.java +++ b/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/lir/PTXKind.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020, APT Group, Department of Computer Science, + * Copyright (c) 2020, 2024, APT Group, Department of Computer Science, * School of Engineering, The University of Manchester. All rights reserved. * Copyright (c) 2009, 2017, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. @@ -84,7 +84,7 @@ public enum PTXKind implements PlatformKind { B8(1, null), S16(2, Short.TYPE), - F16(2, null), + F16(2, Short.TYPE), U16(2, Character.TYPE), B16(2, null), @@ -151,6 +151,7 @@ public enum PTXKind implements PlatformKind { private final PTXKind kind; private final PTXKind elementKind; private final Class javaClass; + PTXKind(int size, Class javaClass) { this(size, javaClass, null); } diff --git a/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/lir/PTXLIRStmt.java b/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/lir/PTXLIRStmt.java index 89306899bd..7224e547cc 100644 --- a/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/lir/PTXLIRStmt.java +++ b/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/lir/PTXLIRStmt.java @@ -2,7 +2,7 @@ * This file is part of Tornado: A heterogeneous programming framework: * https://github.com/beehive-lab/tornadovm * - * Copyright (c) 2020, APT Group, Department of Computer Science, + * Copyright (c) 2020, 2024, APT Group, Department of Computer Science, * School of Engineering, The University of Manchester. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * @@ -12,7 +12,7 @@ * * This code is distributed in the hope that it will be useful, but WITHOUT * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or - * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License * version 2 for more details (a copy is included in the LICENSE file that * accompanied this code). * @@ -25,9 +25,11 @@ package uk.ac.manchester.tornado.drivers.ptx.graal.lir; import static uk.ac.manchester.tornado.drivers.ptx.graal.PTXCodeUtil.getFPURoundingMode; +import static uk.ac.manchester.tornado.drivers.ptx.graal.asm.PTXAssembler.PTXBinaryOp.DIV_APPROX; import static uk.ac.manchester.tornado.drivers.ptx.graal.asm.PTXAssemblerConstants.ASSIGN; import static uk.ac.manchester.tornado.drivers.ptx.graal.asm.PTXAssemblerConstants.COMMA; import static uk.ac.manchester.tornado.drivers.ptx.graal.asm.PTXAssemblerConstants.CONVERT; +import static uk.ac.manchester.tornado.drivers.ptx.graal.asm.PTXAssemblerConstants.CONVERT_RN; import static uk.ac.manchester.tornado.drivers.ptx.graal.asm.PTXAssemblerConstants.CURLY_BRACKETS_CLOSE; import static uk.ac.manchester.tornado.drivers.ptx.graal.asm.PTXAssemblerConstants.CURLY_BRACKETS_OPEN; import static uk.ac.manchester.tornado.drivers.ptx.graal.asm.PTXAssemblerConstants.DOT; @@ -73,6 +75,122 @@ public final void emitCode(CompilationResultBuilder crb) { public abstract void emitCode(PTXCompilationResultBuilder crb, PTXAssembler asm); } + @Opcode("DIVHALF") + public static class DivideHalfFloatStmt extends AbstractInstruction { + + public static final LIRInstructionClass TYPE = LIRInstructionClass.create(DivideHalfFloatStmt.class); + + @Use + protected Value dividend; + @Use + protected Value divisor; + @Def + protected Value dividendFloat; + @Def + protected Value divisorFloat; + @Def + protected Value floatResult; + @Def + protected Value halfFloatResult; + + public DivideHalfFloatStmt(Value dividend, Value divisor, Value dividendFloat, Value divisorFloat, Value floatResult, Value halfFloatResult) { + super(TYPE); + this.dividend = dividend; + this.divisor = divisor; + this.dividendFloat = dividendFloat; + this.divisorFloat = divisorFloat; + this.floatResult = floatResult; + this.halfFloatResult = halfFloatResult; + } + + @Override + public void emitCode(PTXCompilationResultBuilder crb, PTXAssembler asm) { + // convert divident from half-float to float + asm.emitSymbol(TAB); + asm.emit(CONVERT + DOT + dividendFloat.getPlatformKind() + DOT + dividend.getPlatformKind()); + asm.emitSymbol(SPACE); + asm.emitValue(dividendFloat); + asm.emitSymbol(COMMA + SPACE); + asm.emitValue(dividend); + asm.delimiter(); + asm.eol(); + //convert divisor from half-float to float + asm.emitSymbol(TAB); + asm.emit(CONVERT + DOT + divisorFloat.getPlatformKind() + DOT + divisor.getPlatformKind()); + asm.emitSymbol(SPACE); + asm.emitValue(divisorFloat); + asm.emitSymbol(COMMA + SPACE); + asm.emitValue(divisor); + asm.delimiter(); + asm.eol(); + // divide the two float values + asm.emitSymbol(TAB); + asm.emit(DIV_APPROX + DOT + floatResult.getPlatformKind()); + asm.emitSymbol(SPACE); + asm.emitValue(floatResult); + asm.emitSymbol(COMMA + SPACE); + asm.emitValue(dividendFloat); + asm.emitSymbol(COMMA + SPACE); + asm.emitValue(divisorFloat); + asm.delimiter(); + asm.eol(); + //convert the result from float to half-float + asm.emitSymbol(TAB); + asm.emit(CONVERT_RN + DOT + halfFloatResult.getPlatformKind() + DOT + floatResult.getPlatformKind()); + asm.emitSymbol(SPACE); + asm.emitValue(halfFloatResult); + asm.emitSymbol(COMMA + SPACE); + asm.emitValue(floatResult); + asm.delimiter(); + asm.eol(); + } + + } + + @Opcode("CONVERTHALF") + public static class ConvertHalfFloatStmt extends AbstractInstruction { + + public static final LIRInstructionClass TYPE = LIRInstructionClass.create(ConvertHalfFloatStmt.class); + + @Def + protected Value halfFloatVariable; + @Use + protected Value input; + @Def + protected Value intermediate; + + public ConvertHalfFloatStmt(Value halfFloatVariable, Value input, Value intermediate) { + super(TYPE); + this.halfFloatVariable = halfFloatVariable; + this.input = input; + this.intermediate = intermediate; + } + + @Override + public void emitCode(PTXCompilationResultBuilder crb, PTXAssembler asm) { + // move value to a float variable + asm.emitSymbol(TAB); + asm.emit(MOVE + DOT + intermediate.getPlatformKind()); + asm.emitSymbol(SPACE); + asm.emitValue(intermediate); + asm.emitSymbol(COMMA + SPACE); + asm.emitValue(input); + asm.delimiter(); + asm.eol(); + + // convert float to half float + asm.emitSymbol(TAB); + asm.emit(CONVERT_RN + DOT + halfFloatVariable.getPlatformKind() + DOT + intermediate.getPlatformKind()); + asm.emitSymbol(SPACE); + asm.emitValue(halfFloatVariable); + asm.emitSymbol(COMMA + SPACE); + asm.emitValue(intermediate); + asm.delimiter(); + asm.eol(); + } + + } + @Opcode("ASSIGN") public static class AssignStmt extends AbstractInstruction { @@ -265,6 +383,43 @@ public void emitCode(PTXCompilationResultBuilder crb, PTXAssembler asm) { } } + @Opcode("HALFLOAD") + public static class HalfFloatLoadStmt extends AbstractInstruction { + public static final LIRInstructionClass TYPE = LIRInstructionClass.create(HalfFloatLoadStmt.class); + + @Use + protected Variable dest; + + @Use + PTXUnary.MemoryAccess address; + + @Use + PTXNullaryOp loadOp; + + public HalfFloatLoadStmt(PTXUnary.MemoryAccess address, Variable dest, PTXNullaryOp op) { + super(TYPE); + this.dest = dest; + this.loadOp = op; + this.address = address; + } + + @Override + public void emitCode(PTXCompilationResultBuilder crb, PTXAssembler asm) { + loadOp.emit(crb, null); + asm.emitSymbol(DOT); + asm.emit(address.getBase().memorySpace.getName()); + asm.emit(DOT + PTXKind.B16); + asm.emitSymbol(TAB); + + asm.emitValue(dest); + asm.emitSymbol(COMMA); + asm.space(); + address.emit(crb, asm, null); + asm.delimiter(); + asm.eol(); + } + } + @Opcode("VLOAD") public static class VectorLoadStmt extends AbstractInstruction { @@ -363,6 +518,44 @@ public PTXUnary.MemoryAccess getAddress() { } } + @Opcode("STOREHALF") + public static class HalfFloatStoreStmt extends AbstractInstruction { + + public static final LIRInstructionClass TYPE = LIRInstructionClass.create(HalfFloatStoreStmt.class); + + @Use + protected Value rhs; + @Use + protected PTXUnary.MemoryAccess address; + + public HalfFloatStoreStmt(PTXUnary.MemoryAccess address, Value rhs) { + super(TYPE); + this.rhs = rhs; + this.address = address; + } + + public void emitNormalCode(PTXCompilationResultBuilder crb, PTXAssembler asm) { + PTXNullaryOp.ST.emit(crb, null); + asm.emitSymbol(DOT); + asm.emit(address.getBase().memorySpace.getName()); + asm.emit(DOT + PTXKind.B16); + asm.emitSymbol(TAB); + + address.emit(crb, asm, null); + asm.emitSymbol(COMMA); + asm.space(); + + asm.emitValueOrOp(crb, rhs, null); + asm.delimiter(); + asm.eol(); + } + + @Override + public void emitCode(PTXCompilationResultBuilder crb, PTXAssembler asm) { + emitNormalCode(crb, asm); + } + } + @Opcode("VSTORE") public static class VectorStoreStmt extends AbstractInstruction { diff --git a/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/nodes/PTXHalfFloatDivisionNode.java b/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/nodes/PTXHalfFloatDivisionNode.java new file mode 100644 index 0000000000..953ae41073 --- /dev/null +++ b/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/nodes/PTXHalfFloatDivisionNode.java @@ -0,0 +1,70 @@ +/* + * Copyright (c) 2024, APT Group, Department of Computer Science, + * School of Engineering, The University of Manchester. All rights reserved. + * Copyright (c) 2009, 2017, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + */ +package uk.ac.manchester.tornado.drivers.ptx.graal.nodes; + +import jdk.vm.ci.meta.JavaKind; +import jdk.vm.ci.meta.Value; +import org.graalvm.compiler.core.common.LIRKind; +import org.graalvm.compiler.core.common.type.StampFactory; +import org.graalvm.compiler.graph.NodeClass; +import org.graalvm.compiler.lir.Variable; +import org.graalvm.compiler.lir.gen.LIRGeneratorTool; +import org.graalvm.compiler.nodeinfo.NodeInfo; +import org.graalvm.compiler.nodes.ValueNode; +import org.graalvm.compiler.nodes.calc.FloatingNode; +import org.graalvm.compiler.nodes.spi.LIRLowerable; +import org.graalvm.compiler.nodes.spi.NodeLIRBuilderTool; +import uk.ac.manchester.tornado.drivers.ptx.graal.lir.PTXKind; +import uk.ac.manchester.tornado.drivers.ptx.graal.lir.PTXLIRStmt; + +@NodeInfo +public class PTXHalfFloatDivisionNode extends FloatingNode implements LIRLowerable { + + public static final NodeClass TYPE = NodeClass.create(PTXHalfFloatDivisionNode.class); + + @Input + private ValueNode dividend; + @Input + private ValueNode divisor; + + public PTXHalfFloatDivisionNode(ValueNode dividend, ValueNode divisor) { + super(TYPE, StampFactory.forKind(JavaKind.Short)); + this.dividend = dividend; + this.divisor = divisor; + } + + @Override + public void generate(NodeLIRBuilderTool generator) { + LIRGeneratorTool tool = generator.getLIRGeneratorTool(); + Value dividendValue = generator.operand(dividend); + Value divisorValue = generator.operand(divisor); + + Variable dividendConvertedToFloat = tool.newVariable(LIRKind.value(PTXKind.F32)); + Variable divisorConvertedToFloat = tool.newVariable(LIRKind.value(PTXKind.F32)); + Variable floatResult = tool.newVariable(LIRKind.value(PTXKind.F32)); + + Variable halfFloatResult = tool.newVariable(LIRKind.value(PTXKind.F16)); + + tool.append(new PTXLIRStmt.DivideHalfFloatStmt(dividendValue, divisorValue, dividendConvertedToFloat, divisorConvertedToFloat, floatResult, halfFloatResult)); + generator.setResult(this, halfFloatResult); + } +} diff --git a/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/nodes/ReadHalfFloatNode.java b/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/nodes/ReadHalfFloatNode.java new file mode 100644 index 0000000000..a4cb9f560c --- /dev/null +++ b/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/nodes/ReadHalfFloatNode.java @@ -0,0 +1,62 @@ +/* + * Copyright (c) 2024, APT Group, Department of Computer Science, + * School of Engineering, The University of Manchester. All rights reserved. + * Copyright (c) 2009, 2017, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + */ +package uk.ac.manchester.tornado.drivers.ptx.graal.nodes; + +import org.graalvm.compiler.core.common.LIRKind; +import org.graalvm.compiler.core.common.type.StampFactory; +import org.graalvm.compiler.graph.NodeClass; +import org.graalvm.compiler.lir.Variable; +import org.graalvm.compiler.lir.gen.LIRGeneratorTool; +import org.graalvm.compiler.nodeinfo.NodeInfo; +import org.graalvm.compiler.nodes.FixedWithNextNode; +import org.graalvm.compiler.nodes.memory.address.AddressNode; +import org.graalvm.compiler.nodes.spi.LIRLowerable; +import org.graalvm.compiler.nodes.spi.NodeLIRBuilderTool; + +import jdk.vm.ci.meta.JavaKind; +import jdk.vm.ci.meta.Value; +import uk.ac.manchester.tornado.drivers.ptx.graal.asm.PTXAssembler; +import uk.ac.manchester.tornado.drivers.ptx.graal.lir.PTXKind; +import uk.ac.manchester.tornado.drivers.ptx.graal.lir.PTXLIRStmt; +import uk.ac.manchester.tornado.drivers.ptx.graal.lir.PTXUnary; + +@NodeInfo +public class ReadHalfFloatNode extends FixedWithNextNode implements LIRLowerable { + + public static final NodeClass TYPE = NodeClass.create(ReadHalfFloatNode.class); + + @Input + private AddressNode addressNode; + + public ReadHalfFloatNode(AddressNode addressNode) { + super(TYPE, StampFactory.forKind(JavaKind.Short)); + this.addressNode = addressNode; + } + + public void generate(NodeLIRBuilderTool generator) { + LIRGeneratorTool tool = generator.getLIRGeneratorTool(); + Variable result = tool.newVariable(LIRKind.value(PTXKind.F16)); + Value addressValue = generator.operand(addressNode); + tool.append(new PTXLIRStmt.HalfFloatLoadStmt((PTXUnary.MemoryAccess) addressValue, result, PTXAssembler.PTXNullaryOp.LD)); + generator.setResult(this, result); + } +} diff --git a/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/nodes/WriteHalfFloatNode.java b/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/nodes/WriteHalfFloatNode.java new file mode 100644 index 0000000000..52a313146a --- /dev/null +++ b/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/nodes/WriteHalfFloatNode.java @@ -0,0 +1,80 @@ +/* + * Copyright (c) 2024, APT Group, Department of Computer Science, + * School of Engineering, The University of Manchester. All rights reserved. + * Copyright (c) 2009, 2017, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + */ +package uk.ac.manchester.tornado.drivers.ptx.graal.nodes; + +import org.graalvm.compiler.core.common.LIRKind; +import org.graalvm.compiler.core.common.type.StampFactory; +import org.graalvm.compiler.graph.NodeClass; +import org.graalvm.compiler.lir.Variable; +import org.graalvm.compiler.lir.gen.LIRGeneratorTool; +import org.graalvm.compiler.nodeinfo.NodeInfo; +import org.graalvm.compiler.nodes.FixedWithNextNode; +import org.graalvm.compiler.nodes.NodeView; +import org.graalvm.compiler.nodes.ValueNode; +import org.graalvm.compiler.nodes.memory.address.AddressNode; +import org.graalvm.compiler.nodes.spi.LIRLowerable; +import org.graalvm.compiler.nodes.spi.NodeLIRBuilderTool; + +import jdk.vm.ci.meta.JavaKind; +import jdk.vm.ci.meta.Value; +import uk.ac.manchester.tornado.drivers.ptx.graal.compiler.PTXLIRGenerationResult; +import uk.ac.manchester.tornado.drivers.ptx.graal.lir.PTXKind; +import uk.ac.manchester.tornado.drivers.ptx.graal.lir.PTXLIRStmt; +import uk.ac.manchester.tornado.drivers.ptx.graal.lir.PTXUnary; +import uk.ac.manchester.tornado.runtime.graph.nodes.ConstantNode; + +@NodeInfo +public class WriteHalfFloatNode extends FixedWithNextNode implements LIRLowerable { + + public static final NodeClass TYPE = NodeClass.create(WriteHalfFloatNode.class); + + @Input + private AddressNode addressNode; + + @Input + private ValueNode valueNode; + + public WriteHalfFloatNode(AddressNode addressNode, ValueNode valueNode) { + super(TYPE, StampFactory.forKind(JavaKind.Short)); + this.addressNode = addressNode; + this.valueNode = valueNode; + } + + public void generate(NodeLIRBuilderTool generator) { + LIRGeneratorTool tool = generator.getLIRGeneratorTool(); + Value valueToStore; + if (valueNode.stamp(NodeView.DEFAULT).isFloatStamp()) { + // the value to be written is in float format, so the bytecodes to convert + // to half float need to be generated + Value value = generator.operand(valueNode); + Variable intermediate = tool.newVariable(LIRKind.value(PTXKind.F32)); + Variable result = tool.newVariable(LIRKind.value(PTXKind.F16)); + tool.append(new PTXLIRStmt.ConvertHalfFloatStmt(result, value, intermediate)); + valueToStore = result; + } else { + valueToStore = generator.operand(valueNode); + } + Value addressValue = generator.operand(addressNode); + PTXUnary.MemoryAccess access = (PTXUnary.MemoryAccess) addressValue; + tool.append(new PTXLIRStmt.HalfFloatStoreStmt(access, valueToStore)); + } +} diff --git a/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/phases/TornadoHalfFloatReplacement.java b/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/phases/TornadoHalfFloatReplacement.java new file mode 100644 index 0000000000..d34e127416 --- /dev/null +++ b/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/phases/TornadoHalfFloatReplacement.java @@ -0,0 +1,194 @@ +/* + * Copyright (c) 2024, APT Group, Department of Computer Science, + * The University of Manchester. All rights reserved. + * Copyright (c) 2009, 2017, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + */ +package uk.ac.manchester.tornado.drivers.ptx.graal.phases; + +import java.util.Optional; + +import org.graalvm.compiler.graph.Node; +import org.graalvm.compiler.nodes.GraphState; +import org.graalvm.compiler.nodes.StructuredGraph; +import org.graalvm.compiler.nodes.ValueNode; +import org.graalvm.compiler.nodes.calc.AddNode; +import org.graalvm.compiler.nodes.calc.FloatDivNode; +import org.graalvm.compiler.nodes.calc.MulNode; +import org.graalvm.compiler.nodes.calc.SubNode; +import org.graalvm.compiler.nodes.extended.JavaReadNode; +import org.graalvm.compiler.nodes.extended.JavaWriteNode; +import org.graalvm.compiler.nodes.java.NewInstanceNode; +import org.graalvm.compiler.nodes.memory.address.AddressNode; +import org.graalvm.compiler.phases.BasePhase; + +import uk.ac.manchester.tornado.drivers.ptx.graal.nodes.PTXHalfFloatDivisionNode; +import uk.ac.manchester.tornado.drivers.ptx.graal.nodes.ReadHalfFloatNode; +import uk.ac.manchester.tornado.drivers.ptx.graal.nodes.WriteHalfFloatNode; +import uk.ac.manchester.tornado.runtime.graal.nodes.AddHalfFloatNode; +import uk.ac.manchester.tornado.runtime.graal.nodes.DivHalfFloatNode; +import uk.ac.manchester.tornado.runtime.graal.nodes.HalfFloatPlaceholder; +import uk.ac.manchester.tornado.runtime.graal.nodes.MultHalfFloatNode; +import uk.ac.manchester.tornado.runtime.graal.nodes.NewHalfFloatInstance; +import uk.ac.manchester.tornado.runtime.graal.nodes.SubHalfFloatNode; +import uk.ac.manchester.tornado.runtime.graal.phases.TornadoHighTierContext; + +public class TornadoHalfFloatReplacement extends BasePhase { + + @Override + public Optional notApplicableTo(GraphState graphState) { + return ALWAYS_APPLICABLE; + } + + protected void run(StructuredGraph graph, TornadoHighTierContext context) { + + // replace reads with halfFloat reads + for (JavaReadNode javaRead : graph.getNodes().filter(JavaReadNode.class)) { + if (javaRead.successors().first() instanceof NewInstanceNode) { + NewInstanceNode newInstanceNode = (NewInstanceNode) javaRead.successors().first(); + if (newInstanceNode.instanceClass().toString().contains("HalfFloat")) { + if (newInstanceNode.successors().first() instanceof NewHalfFloatInstance) { + NewHalfFloatInstance newHalfFloatInstance = (NewHalfFloatInstance) newInstanceNode.successors().first(); + deleteFixed(newHalfFloatInstance); + } + AddressNode readingAddress = javaRead.getAddress(); + ReadHalfFloatNode readHalfFloatNode = new ReadHalfFloatNode(readingAddress); + graph.addWithoutUnique(readHalfFloatNode); + replaceFixed(javaRead, readHalfFloatNode); + newInstanceNode.replaceAtUsages(readHalfFloatNode); + deleteFixed(newInstanceNode); + } + } + } + + // replace writes with halfFloat writes + for (JavaWriteNode javaWrite : graph.getNodes().filter(JavaWriteNode.class)) { + if (isWriteHalfFloat(javaWrite)) { + // This casting is safe to do as it is already checked by the isWriteHalfFloat function + HalfFloatPlaceholder placeholder = (HalfFloatPlaceholder) javaWrite.value(); + ValueNode writingValue; + if (javaWrite.predecessor() instanceof NewHalfFloatInstance) { + // if a new HalfFloat instance is written + NewHalfFloatInstance newHalfFloatInstance = (NewHalfFloatInstance) javaWrite.predecessor(); + writingValue = newHalfFloatInstance.getValue(); + if (newHalfFloatInstance.predecessor() instanceof NewInstanceNode) { + NewInstanceNode newInstanceNode = (NewInstanceNode) newHalfFloatInstance.predecessor(); + if (newInstanceNode.instanceClass().toString().contains("HalfFloat")) { + deleteFixed(newInstanceNode); + deleteFixed(newHalfFloatInstance); + } + } + } else { + // if the result of an operation or a stored value is written + writingValue = placeholder.getInput(); + } + placeholder.replaceAtUsages(writingValue); + placeholder.safeDelete(); + AddressNode writingAddress = javaWrite.getAddress(); + WriteHalfFloatNode writeHalfFloatNode = new WriteHalfFloatNode(writingAddress, writingValue); + graph.addWithoutUnique(writeHalfFloatNode); + replaceFixed(javaWrite, writeHalfFloatNode); + deleteFixed(javaWrite); + } + } + + // replace the half float operator nodes with the corresponding regular operators + replaceAddHalfFloatNodes(graph); + replaceSubHalfFloatNodes(graph); + replaceMultHalfFloatNodes(graph); + replaceDivHalfFloatNodes(graph); + + } + + private static void replaceAddHalfFloatNodes(StructuredGraph graph) { + for (AddHalfFloatNode addHalfFloatNode : graph.getNodes().filter(AddHalfFloatNode.class)) { + AddNode addNode = new AddNode(addHalfFloatNode.getX(), addHalfFloatNode.getY()); + graph.addWithoutUnique(addNode); + addHalfFloatNode.replaceAtUsages(addNode); + addHalfFloatNode.safeDelete(); + } + } + + private static void replaceSubHalfFloatNodes(StructuredGraph graph) { + for (SubHalfFloatNode subHalfFloatNode : graph.getNodes().filter(SubHalfFloatNode.class)) { + SubNode subNode = new SubNode(subHalfFloatNode.getX(), subHalfFloatNode.getY()); + graph.addWithoutUnique(subNode); + subHalfFloatNode.replaceAtUsages(subNode); + subHalfFloatNode.safeDelete(); + } + } + + private static void replaceMultHalfFloatNodes(StructuredGraph graph) { + for (MultHalfFloatNode multHalfFloatNode : graph.getNodes().filter(MultHalfFloatNode.class)) { + MulNode mulNode = new MulNode(multHalfFloatNode.getX(), multHalfFloatNode.getY()); + graph.addWithoutUnique(mulNode); + multHalfFloatNode.replaceAtUsages(mulNode); + multHalfFloatNode.safeDelete(); + } + } + + private static void replaceDivHalfFloatNodes(StructuredGraph graph) { + for (DivHalfFloatNode divHalfFloatNode : graph.getNodes().filter(DivHalfFloatNode.class)) { + PTXHalfFloatDivisionNode divNode = new PTXHalfFloatDivisionNode(divHalfFloatNode.getX(), divHalfFloatNode.getY()); + graph.addWithoutUnique(divNode); + divHalfFloatNode.replaceAtUsages(divNode); + divHalfFloatNode.safeDelete(); + } + } + + private static boolean isWriteHalfFloat(JavaWriteNode javaWrite) { + if (javaWrite.value() instanceof HalfFloatPlaceholder) { + return true; + } + return false; + } + + private static void replaceFixed(Node n, Node other) { + Node pred = n.predecessor(); + Node suc = n.successors().first(); + + n.replaceFirstSuccessor(suc, null); + n.replaceAtPredecessor(other); + pred.replaceFirstSuccessor(n, other); + other.replaceFirstSuccessor(null, suc); + + for (Node us : n.usages()) { + n.removeUsage(us); + } + n.clearInputs(); + n.safeDelete(); + + } + + private static void deleteFixed(Node node) { + if (!node.isDeleted()) { + Node predecessor = node.predecessor(); + Node successor = node.successors().first(); + + node.replaceFirstSuccessor(successor, null); + node.replaceAtPredecessor(successor); + predecessor.replaceFirstSuccessor(node, successor); + + for (Node us : node.usages()) { + node.removeUsage(us); + } + node.clearInputs(); + node.safeDelete(); + } + } +} diff --git a/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/mm/PTXMemorySegmentWrapper.java b/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/mm/PTXMemorySegmentWrapper.java index 0eedf358a6..ba278eb8e1 100644 --- a/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/mm/PTXMemorySegmentWrapper.java +++ b/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/mm/PTXMemorySegmentWrapper.java @@ -2,7 +2,7 @@ * This file is part of Tornado: A heterogeneous programming framework: * https://github.com/beehive-lab/tornadovm * - * Copyright (c) 2023, APT Group, Department of Computer Science, + * Copyright (c) 2023, 2024, APT Group, Department of Computer Science, * School of Engineering, The University of Manchester. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * @@ -27,18 +27,6 @@ import java.util.ArrayList; import java.util.List; -import uk.ac.manchester.tornado.api.exceptions.TornadoInternalError; -import uk.ac.manchester.tornado.api.exceptions.TornadoMemoryException; -import uk.ac.manchester.tornado.api.exceptions.TornadoOutOfMemoryException; -import uk.ac.manchester.tornado.api.memory.ObjectBuffer; -import uk.ac.manchester.tornado.api.types.arrays.ByteArray; -import uk.ac.manchester.tornado.api.types.arrays.CharArray; -import uk.ac.manchester.tornado.api.types.arrays.DoubleArray; -import uk.ac.manchester.tornado.api.types.arrays.FloatArray; -import uk.ac.manchester.tornado.api.types.arrays.IntArray; -import uk.ac.manchester.tornado.api.types.arrays.LongArray; -import uk.ac.manchester.tornado.api.types.arrays.ShortArray; -import uk.ac.manchester.tornado.api.types.arrays.TornadoNativeArray; import uk.ac.manchester.tornado.api.types.collections.VectorDouble2; import uk.ac.manchester.tornado.api.types.collections.VectorDouble3; import uk.ac.manchester.tornado.api.types.collections.VectorDouble4; @@ -51,6 +39,19 @@ import uk.ac.manchester.tornado.api.types.collections.VectorInt3; import uk.ac.manchester.tornado.api.types.collections.VectorInt4; import uk.ac.manchester.tornado.api.types.collections.VectorInt8; +import uk.ac.manchester.tornado.api.types.arrays.ByteArray; +import uk.ac.manchester.tornado.api.types.arrays.CharArray; +import uk.ac.manchester.tornado.api.types.arrays.DoubleArray; +import uk.ac.manchester.tornado.api.types.arrays.FloatArray; +import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray; +import uk.ac.manchester.tornado.api.types.arrays.IntArray; +import uk.ac.manchester.tornado.api.types.arrays.LongArray; +import uk.ac.manchester.tornado.api.types.arrays.ShortArray; +import uk.ac.manchester.tornado.api.types.arrays.TornadoNativeArray; +import uk.ac.manchester.tornado.api.exceptions.TornadoInternalError; +import uk.ac.manchester.tornado.api.exceptions.TornadoMemoryException; +import uk.ac.manchester.tornado.api.exceptions.TornadoOutOfMemoryException; +import uk.ac.manchester.tornado.api.memory.ObjectBuffer; import uk.ac.manchester.tornado.drivers.ptx.PTXDeviceContext; import uk.ac.manchester.tornado.runtime.common.Tornado; import uk.ac.manchester.tornado.runtime.common.TornadoLogger; @@ -114,6 +115,7 @@ private MemorySegment getSegment(final Object reference) { case ShortArray shortArray -> shortArray.getSegment(); case ByteArray byteArray -> byteArray.getSegment(); case CharArray charArray -> charArray.getSegment(); + case HalfFloatArray halfFloatArray -> halfFloatArray.getSegment(); case VectorFloat2 vectorFloat2 -> vectorFloat2.getArray().getSegment(); case VectorFloat3 vectorFloat3 -> vectorFloat3.getArray().getSegment(); case VectorFloat4 vectorFloat4 -> vectorFloat4.getArray().getSegment(); diff --git a/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/runtime/PTXTornadoDevice.java b/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/runtime/PTXTornadoDevice.java index 54a72bb1a9..16d8509e22 100644 --- a/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/runtime/PTXTornadoDevice.java +++ b/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/runtime/PTXTornadoDevice.java @@ -2,7 +2,7 @@ * This file is part of Tornado: A heterogeneous programming framework: * https://github.com/beehive-lab/tornadovm * - * Copyright (c) 2020, APT Group, Department of Computer Science, + * Copyright (c) 2020, 2024, APT Group, Department of Computer Science, * School of Engineering, The University of Manchester. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * @@ -52,6 +52,7 @@ import uk.ac.manchester.tornado.api.types.arrays.CharArray; import uk.ac.manchester.tornado.api.types.arrays.DoubleArray; import uk.ac.manchester.tornado.api.types.arrays.FloatArray; +import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray; import uk.ac.manchester.tornado.api.types.arrays.IntArray; import uk.ac.manchester.tornado.api.types.arrays.LongArray; import uk.ac.manchester.tornado.api.types.arrays.ShortArray; @@ -272,6 +273,8 @@ private ObjectBuffer createDeviceBuffer(Class type, Object object, long batch result = new PTXMemorySegmentWrapper(getDeviceContext(), batchSize); } else if (object instanceof CharArray) { result = new PTXMemorySegmentWrapper(getDeviceContext(), batchSize); + } else if (object instanceof HalfFloatArray) { + result = new PTXMemorySegmentWrapper(getDeviceContext(), batchSize); } else { result = new PTXObjectWrapper(getDeviceContext(), object); } diff --git a/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/SPIRVBackend.java b/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/SPIRVBackend.java index 29b980e5c2..6572018d84 100644 --- a/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/SPIRVBackend.java +++ b/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/SPIRVBackend.java @@ -2,7 +2,7 @@ * This file is part of Tornado: A heterogeneous programming framework: * https://github.com/beehive-lab/tornadovm * - * Copyright (c) 2021-2022, APT Group, Department of Computer Science, + * Copyright (c) 2021-2022, 2024, APT Group, Department of Computer Science, * School of Engineering, The University of Manchester. All rights reserved. * Copyright (c) 2009-2021, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. @@ -104,13 +104,16 @@ import uk.ac.manchester.beehivespirvtoolkit.lib.instructions.operands.SPIRVSourceLanguage; import uk.ac.manchester.beehivespirvtoolkit.lib.instructions.operands.SPIRVStorageClass; import uk.ac.manchester.tornado.api.KernelContext; +import uk.ac.manchester.tornado.api.exceptions.TornadoBailoutRuntimeException; import uk.ac.manchester.tornado.api.exceptions.TornadoDeviceFP64NotSupported; import uk.ac.manchester.tornado.api.exceptions.TornadoInternalError; +import uk.ac.manchester.tornado.api.exceptions.TornadoRuntimeException; import uk.ac.manchester.tornado.api.profiler.ProfilerType; import uk.ac.manchester.tornado.api.profiler.TornadoProfiler; import uk.ac.manchester.tornado.drivers.common.logging.Logger; import uk.ac.manchester.tornado.drivers.common.utils.BackendDeopt; import uk.ac.manchester.tornado.drivers.opencl.graal.nodes.FPGAWorkGroupSizeNode; +import uk.ac.manchester.tornado.drivers.opencl.graal.nodes.WriteHalfFloatNode; import uk.ac.manchester.tornado.drivers.spirv.graal.SPIRVArchitecture; import uk.ac.manchester.tornado.drivers.spirv.graal.SPIRVCodeProvider; import uk.ac.manchester.tornado.drivers.spirv.graal.SPIRVFrameContext; @@ -148,6 +151,8 @@ public class SPIRVBackend extends TornadoBackend implements Fram private SPIRVId pointerToULongFunction; private SPIRVInstScope blockScope; private boolean fp64CapabilityEnabled; + + private boolean fp16CapabilityEnabled; private boolean supportsFP64; private AtomicInteger methodIndex; @@ -348,6 +353,11 @@ private void emitFP64Capability(SPIRVModule module) { fp64CapabilityEnabled = true; } + private void emitFP16Capability(SPIRVModule module) { + module.add(new SPIRVOpCapability(SPIRVCapability.Float16Buffer())); // To use FP16 + fp16CapabilityEnabled = true; + } + private void emitSPIRVCapabilities(SPIRVModule module) { // Emit Capabilities module.add(new SPIRVOpCapability(SPIRVCapability.Addresses())); // Uses physical addressing, non-logical addressing modes. @@ -376,7 +386,9 @@ private void emitOpSourceForOpenCL(SPIRVModule module, int version) { } private SPIRVLiteralContextDependentNumber buildLiteralContextNumber(SPIRVKind kind, Constant value) { - if (kind == SPIRVKind.OP_TYPE_INT_32) { + if (kind == SPIRVKind.OP_TYPE_FLOAT_16) { + return new SPIRVContextDependentFloat(Float.parseFloat(value.toValueString())); + } else if (kind == SPIRVKind.OP_TYPE_INT_32) { return new SPIRVContextDependentInt(BigInteger.valueOf(Integer.parseInt(value.toValueString()))); } else if (kind == SPIRVKind.OP_TYPE_INT_64) { return new SPIRVContextDependentLong(BigInteger.valueOf(Long.parseLong(value.toValueString()))); @@ -385,7 +397,7 @@ private SPIRVLiteralContextDependentNumber buildLiteralContextNumber(SPIRVKind k } else if (kind == SPIRVKind.OP_TYPE_FLOAT_64) { return new SPIRVContextDependentDouble(Double.parseDouble(value.toValueString())); } else { - throw new RuntimeException("SPIRV - SPIRVLiteralContextDependentNumber Type not supported"); + throw new TornadoRuntimeException("SPIRV - SPIRVLiteralContextDependentNumber Type not supported"); } } @@ -421,14 +433,12 @@ private IDTable emitVariableDefs(SPIRVCompilationResultBuilder crb, SPIRVAssembl for (LIRInstruction lirInstruction : lir.getLIRforBlock(lir.getBlockById(block))) { lirInstruction.forEachOutput((instruction, value, mode, flags) -> { - if (value instanceof ArrayVariable) { + if (value instanceof ArrayVariable variable) { // All function variables, including arrays, must be defined a consecutive block // of instructions from the block 0. We detect array declaration and define // these as array for the SPIR-V Function StorageClass. - ArrayVariable variable = (ArrayVariable) value; resultArrays.add(variable); - } else if (value instanceof Variable) { - Variable variable = (Variable) value; + } else if (value instanceof Variable variable) { if (variable.toString() != null) { addVariableDef(kindToVariable, variable); variableCount.incrementAndGet(); @@ -606,6 +616,9 @@ private void emitPrologueForMainKernel(SPIRVCompilationResultBuilder crb, SPIRVA if (idTable.kindToVariable.containsKey(SPIRVKind.OP_TYPE_FLOAT_64)) { emitFP64Capability(asm.module); } + if (idTable.kindToVariable.containsKey(SPIRVKind.OP_TYPE_FLOAT_16)) { + emitFP16Capability(asm.module); + } // Emit the Store between from the parameter value and the local variable // assigned @@ -766,7 +779,7 @@ private void emitPrologueForMainKernelEntry(SPIRVCompilationResultBuilder crb, S final ControlFlowGraph cfg = (ControlFlowGraph) lir.getControlFlowGraph(); if (cfg.getStartBlock().getEndNode().predecessor() instanceof FPGAWorkGroupSizeNode) { - throw new RuntimeException("FPGA Thread Attributes not supported yet."); + throw new TornadoBailoutRuntimeException("FPGA Thread Attributes not supported yet."); } emitSPIRVCapabilities(module); @@ -783,6 +796,9 @@ private void emitPrologueForMainKernelEntry(SPIRVCompilationResultBuilder crb, S if (idTable.kindToVariable.containsKey(SPIRVKind.OP_TYPE_FLOAT_64)) { emitFP64Capability(module); } + if (idTable.kindToVariable.containsKey(SPIRVKind.OP_TYPE_FLOAT_16)) { + emitFP16Capability(module); + } // ---------------------------------- // Emit Entry Kernel @@ -790,7 +806,7 @@ private void emitPrologueForMainKernelEntry(SPIRVCompilationResultBuilder crb, S if (fp64CapabilityEnabled && !supportsFP64) { throw new TornadoDeviceFP64NotSupported("Error - The current SPIR-V device does not support FP64"); } - asm.emitEntryPointMainKernel(cfg.graph, method.getName(), fp64CapabilityEnabled); + asm.emitEntryPointMainKernel(cfg.graph, method.getName(), fp64CapabilityEnabled, fp16CapabilityEnabled); // Add all KINDS we generate the corresponding declaration for (SPIRVKind kind : idTable.kindToVariable.keySet()) { @@ -808,12 +824,18 @@ private void emitPrologueForMainKernelEntry(SPIRVCompilationResultBuilder crb, S JavaKind stackKind = constantNode.getStackKind(); Constant value = constantNode.getValue(); - SPIRVKind kind = SPIRVKind.fromJavaKind(stackKind); + SPIRVKind kind; + if (constantNode.usages().filter(WriteHalfFloatNode.class).isNotEmpty()) { + kind = SPIRVKind.OP_TYPE_FLOAT_16; + } else { + kind = SPIRVKind.fromJavaKind(stackKind); + } + SPIRVId typeId; if (kind.isPrimitive()) { typeId = asm.primitives.getTypePrimitive(kind); } else { - throw new RuntimeException("Type not supported"); + throw new TornadoRuntimeException("Type not supported"); } SPIRVLiteralContextDependentNumber literalNumber = buildLiteralContextNumber(kind, value); @@ -828,7 +850,7 @@ private void emitPrologueForMainKernelEntry(SPIRVCompilationResultBuilder crb, S while (!stack.isEmpty()) { TypeConstant t = stack.pop(); SPIRVId idConstant = module.getNextId(); - module.add(new SPIRVOpConstant(t.typeID, idConstant, t.n)); + module.add(new SPIRVOpConstant(t.typeID, idConstant, t.literalContextNumber)); asm.getConstants().put(new SPIRVAssembler.ConstantKeyPair(t.valueString, t.kind), idConstant); } @@ -869,13 +891,13 @@ private static class SPIRV_HEADER_VALUES { private static class TypeConstant { public SPIRVId typeID; - public SPIRVLiteralContextDependentNumber n; + public SPIRVLiteralContextDependentNumber literalContextNumber; public String valueString; public SPIRVKind kind; - public TypeConstant(SPIRVId typeID, SPIRVLiteralContextDependentNumber n, String valueString, SPIRVKind kind) { + public TypeConstant(SPIRVId typeID, SPIRVLiteralContextDependentNumber literal, String valueString, SPIRVKind kind) { this.typeID = typeID; - this.n = n; + this.literalContextNumber = literal; this.valueString = valueString; this.kind = kind; } diff --git a/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/SPIRVPrimitiveTypes.java b/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/SPIRVPrimitiveTypes.java index 658333d14e..acc958ca8e 100644 --- a/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/SPIRVPrimitiveTypes.java +++ b/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/SPIRVPrimitiveTypes.java @@ -40,6 +40,7 @@ import uk.ac.manchester.beehivespirvtoolkit.lib.instructions.operands.SPIRVId; import uk.ac.manchester.beehivespirvtoolkit.lib.instructions.operands.SPIRVLiteralInteger; import uk.ac.manchester.beehivespirvtoolkit.lib.instructions.operands.SPIRVStorageClass; +import uk.ac.manchester.tornado.api.exceptions.TornadoRuntimeException; import uk.ac.manchester.tornado.drivers.spirv.graal.lir.SPIRVKind; public class SPIRVPrimitiveTypes { @@ -118,6 +119,7 @@ public SPIRVId getTypePrimitive(SPIRVKind primitive) { break; case OP_TYPE_FLOAT_16: if (!capabilities.contains(primitive)) { + module.add(new SPIRVOpCapability(SPIRVCapability.Float16Buffer())); module.add(new SPIRVOpCapability(SPIRVCapability.Float16())); } module.add(new SPIRVOpTypeFloat(typeID, new SPIRVLiteralInteger(sizeInBytes))); @@ -132,7 +134,7 @@ public SPIRVId getTypePrimitive(SPIRVKind primitive) { module.add(new SPIRVOpTypeFloat(typeID, new SPIRVLiteralInteger(sizeInBytes))); break; default: - throw new RuntimeException("DataType Not supported yet"); + throw new TornadoRuntimeException("DataType Not supported yet"); } primitives.put(primitive, typeID); } diff --git a/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/asm/SPIRVAssembler.java b/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/asm/SPIRVAssembler.java index f496c9eba0..2be6c8252f 100644 --- a/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/asm/SPIRVAssembler.java +++ b/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/asm/SPIRVAssembler.java @@ -2,7 +2,7 @@ * This file is part of Tornado: A heterogeneous programming framework: * https://github.com/beehive-lab/tornadovm * - * Copyright (c) 2021-2022, APT Group, Department of Computer Science, + * Copyright (c) 2021-2022, 2024, APT Group, Department of Computer Science, * School of Engineering, The University of Manchester. All rights reserved. * Copyright (c) 2009-2021, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. @@ -13,7 +13,7 @@ * * This code is distributed in the hope that it will be useful, but WITHOUT * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or - * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License * version 2 for more details (a copy is included in the LICENSE file that * accompanied this code). * @@ -81,6 +81,7 @@ import uk.ac.manchester.beehivespirvtoolkit.lib.instructions.SPIRVOpTypePointer; import uk.ac.manchester.beehivespirvtoolkit.lib.instructions.operands.SPIRVContextDependentDouble; import uk.ac.manchester.beehivespirvtoolkit.lib.instructions.operands.SPIRVContextDependentFloat; +import uk.ac.manchester.beehivespirvtoolkit.lib.instructions.operands.SPIRVContextDependentHalfFloat; import uk.ac.manchester.beehivespirvtoolkit.lib.instructions.operands.SPIRVContextDependentInt; import uk.ac.manchester.beehivespirvtoolkit.lib.instructions.operands.SPIRVContextDependentLong; import uk.ac.manchester.beehivespirvtoolkit.lib.instructions.operands.SPIRVDecoration; @@ -95,6 +96,7 @@ import uk.ac.manchester.beehivespirvtoolkit.lib.instructions.operands.SPIRVMultipleOperands; import uk.ac.manchester.beehivespirvtoolkit.lib.instructions.operands.SPIRVOptionalOperand; import uk.ac.manchester.beehivespirvtoolkit.lib.instructions.operands.SPIRVStorageClass; +import uk.ac.manchester.tornado.api.exceptions.TornadoRuntimeException; import uk.ac.manchester.tornado.drivers.spirv.SPIRVPrimitiveTypes; import uk.ac.manchester.tornado.drivers.spirv.SPIRVThreadBuiltIn; import uk.ac.manchester.tornado.drivers.spirv.graal.compiler.SPIRVCompilationResultBuilder; @@ -320,7 +322,7 @@ public SPIRVId getFunctionPtrToLocalArray(SPIRVId resultArrayId) { * If we want to return the same names per module, just return the labelName. * * @param labelName - * String + * String * @return a new label name. */ public String composeUniqueLabelName(String labelName) { @@ -392,7 +394,7 @@ private SPIRVId createNewFunctionAndUpdateTables(SPIRVId returnType, SPIRVId... * follows: * * - * Map<%returnType, Map>> + * Map<%returnType, Map>> * * * If we have the same number of parameters with the same return type, when we @@ -400,9 +402,9 @@ private SPIRVId createNewFunctionAndUpdateTables(SPIRVId returnType, SPIRVId... * parameter (stored in the {@link FunctionTable ) class). * * @param returnType - * ID with the return value. + * ID with the return value. * @param operands - * List of IDs for the operads. + * List of IDs for the operads. * @return A {@link SPIRVId} for the {@link SPIRVOpFunction} */ public SPIRVId emitOpTypeFunction(SPIRVId returnType, SPIRVId... operands) { @@ -445,7 +447,7 @@ public SPIRVId emitOpTypeFunction(SPIRVId returnType, SPIRVId... operands) { return functionSignature; } - public void emitEntryPointMainKernel(StructuredGraph graph, String kernelName, boolean fp64Capability) { + public void emitEntryPointMainKernel(StructuredGraph graph, String kernelName, boolean fp64Capability, boolean fp16Capability) { mainFunctionID = module.getNextId(); SPIRVMultipleOperands operands; @@ -480,7 +482,7 @@ public void emitEntryPointMainKernel(StructuredGraph graph, String kernelName, b operands = new SPIRVMultipleOperands(array); } - if (fp64Capability) { + if (fp64Capability && fp16Capability) { module.add(new SPIRVOpExecutionMode(mainFunctionID, SPIRVExecutionMode.ContractionOff())); } @@ -545,23 +547,13 @@ public SPIRVId emitConstantValue(SPIRVKind type, String valueConstant) { SPIRVId newConstantId = module.getNextId(); SPIRVId typeID = primitives.getTypePrimitive(type); switch (type) { - case OP_TYPE_INT_8: - case OP_TYPE_INT_16: - case OP_TYPE_INT_32: - module.add(new SPIRVOpConstant(typeID, newConstantId, new SPIRVContextDependentInt(BigInteger.valueOf(Integer.parseInt(valueConstant))))); - break; - case OP_TYPE_INT_64: - module.add(new SPIRVOpConstant(typeID, newConstantId, new SPIRVContextDependentLong(BigInteger.valueOf(Integer.parseInt(valueConstant))))); - break; - case OP_TYPE_FLOAT_16: - case OP_TYPE_FLOAT_32: - module.add(new SPIRVOpConstant(typeID, newConstantId, new SPIRVContextDependentFloat(Float.parseFloat(valueConstant)))); - break; - case OP_TYPE_FLOAT_64: - module.add(new SPIRVOpConstant(typeID, newConstantId, new SPIRVContextDependentDouble(Double.parseDouble(valueConstant)))); - break; - default: - throw new RuntimeException("Data type not supported yet: " + type); + case OP_TYPE_INT_8, OP_TYPE_INT_16, OP_TYPE_INT_32 -> module.add(new SPIRVOpConstant(typeID, newConstantId, new SPIRVContextDependentInt(BigInteger.valueOf(Integer.parseInt( + valueConstant))))); + case OP_TYPE_INT_64 -> module.add(new SPIRVOpConstant(typeID, newConstantId, new SPIRVContextDependentLong(BigInteger.valueOf(Integer.parseInt(valueConstant))))); + case OP_TYPE_FLOAT_16 -> module.add(new SPIRVOpConstant(typeID, newConstantId, new SPIRVContextDependentHalfFloat(Float.floatToFloat16(Float.parseFloat(valueConstant))))); + case OP_TYPE_FLOAT_32 -> module.add(new SPIRVOpConstant(typeID, newConstantId, new SPIRVContextDependentFloat(Float.parseFloat(valueConstant)))); + case OP_TYPE_FLOAT_64 -> module.add(new SPIRVOpConstant(typeID, newConstantId, new SPIRVContextDependentDouble(Double.parseDouble(valueConstant)))); + default -> throw new TornadoRuntimeException(STR."Data type not supported yet: \{type}"); } return newConstantId; } @@ -634,8 +626,8 @@ public void emitValue(SPIRVCompilationResultBuilder crb, Value value) { } public void emitValueOrOp(SPIRVCompilationResultBuilder crb, Value value) { - if (value instanceof SPIRVLIROp) { - ((SPIRVLIROp) value).emit(crb, this); + if (value instanceof SPIRVLIROp spirvValue) { + spirvValue.emit(crb, this); } else { emitValue(crb, value); } diff --git a/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/compiler/SPIRVHighTier.java b/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/compiler/SPIRVHighTier.java index 60fbbce6be..e2d8e5c775 100644 --- a/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/compiler/SPIRVHighTier.java +++ b/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/compiler/SPIRVHighTier.java @@ -2,7 +2,7 @@ * This file is part of Tornado: A heterogeneous programming framework: * https://github.com/beehive-lab/tornadovm * - * Copyright (c) 2021, APT Group, Department of Computer Science, + * Copyright (c) 2021, 2024, APT Group, Department of Computer Science, * School of Engineering, The University of Manchester. All rights reserved. * Copyright (c) 2009-2021, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. @@ -52,6 +52,7 @@ import uk.ac.manchester.tornado.drivers.common.compiler.phases.memalloc.TornadoLocalMemoryAllocation; import uk.ac.manchester.tornado.drivers.common.compiler.phases.memalloc.TornadoNewArrayDevirtualizationReplacement; import uk.ac.manchester.tornado.drivers.common.compiler.phases.memalloc.TornadoPrivateArrayPiRemoval; +import uk.ac.manchester.tornado.drivers.spirv.graal.phases.TornadoHalfFloatReplacement; import uk.ac.manchester.tornado.drivers.spirv.graal.phases.TornadoParallelScheduler; import uk.ac.manchester.tornado.drivers.spirv.graal.phases.TornadoSPIRVIntrinsicsReplacements; import uk.ac.manchester.tornado.drivers.spirv.graal.phases.TornadoTaskSpecialization; @@ -88,6 +89,8 @@ public SPIRVHighTier(OptionValues options, TornadoDeviceContext deviceContext, C appendPhase(new TornadoNewArrayDevirtualizationReplacement()); + appendPhase(new TornadoHalfFloatReplacement()); + if (PartialEscapeAnalysis.getValue(options)) { appendPhase(new PartialEscapePhase(true, canonicalizer, options)); } diff --git a/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/compiler/lir/SPIRVArithmeticTool.java b/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/compiler/lir/SPIRVArithmeticTool.java index ef77c9a9a1..9d676d14b2 100644 --- a/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/compiler/lir/SPIRVArithmeticTool.java +++ b/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/compiler/lir/SPIRVArithmeticTool.java @@ -2,7 +2,7 @@ * This file is part of Tornado: A heterogeneous programming framework: * https://github.com/beehive-lab/tornadovm * - * Copyright (c) 2021, APT Group, Department of Computer Science, + * Copyright (c) 2021, 2024, APT Group, Department of Computer Science, * School of Engineering, The University of Manchester. All rights reserved. * Copyright (c) 2009-2021, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. @@ -13,7 +13,7 @@ * * This code is distributed in the hope that it will be useful, but WITHOUT * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or - * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License * version 2 for more details (a copy is included in the LICENSE file that * accompanied this code). * @@ -119,6 +119,7 @@ protected Variable emitAdd(LIRKind resultKind, Value a, Value b, boolean setFlag case OP_TYPE_INT_32: binaryOp = SPIRVBinaryOp.ADD_INTEGER; break; + case OP_TYPE_FLOAT_16: case OP_TYPE_FLOAT_32: case OP_TYPE_FLOAT_64: binaryOp = SPIRVBinaryOp.ADD_FLOAT; @@ -158,6 +159,7 @@ protected Variable emitSub(LIRKind resultKind, Value a, Value b, boolean setFlag break; case OP_TYPE_FLOAT_64: case OP_TYPE_FLOAT_32: + case OP_TYPE_FLOAT_16: binaryOp = SPIRVBinaryOp.SUB_FLOAT; break; default: @@ -204,6 +206,7 @@ public Value emitMul(Value a, Value b, boolean setFlags) { break; case OP_TYPE_FLOAT_64: case OP_TYPE_FLOAT_32: + case OP_TYPE_FLOAT_16: binaryOp = SPIRVBinaryOp.MULT_FLOAT; break; default: @@ -251,6 +254,7 @@ public Value emitDiv(Value a, Value b, LIRFrameState state) { break; case OP_TYPE_FLOAT_64: case OP_TYPE_FLOAT_32: + case OP_TYPE_FLOAT_16: binaryOp = SPIRVBinaryOp.DIV_FLOAT; break; default: diff --git a/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/compiler/plugins/SPIRVGraphBuilderPlugins.java b/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/compiler/plugins/SPIRVGraphBuilderPlugins.java index 925fea9791..a90b55415d 100644 --- a/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/compiler/plugins/SPIRVGraphBuilderPlugins.java +++ b/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/compiler/plugins/SPIRVGraphBuilderPlugins.java @@ -2,7 +2,7 @@ * This file is part of Tornado: A heterogeneous programming framework: * https://github.com/beehive-lab/tornadovm * - * Copyright (c) 2021-2022, APT Group, Department of Computer Science, + * Copyright (c) 2021-2022, 2024, APT Group, Department of Computer Science, * School of Engineering, The University of Manchester. All rights reserved. * Copyright (c) 2009-2021, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. @@ -109,6 +109,7 @@ public static void registerInvocationPlugins(Plugins plugins, final InvocationPl SPIRVMathPlugins.registerTornadoMathPlugins(invocationPlugins); SPIRVVectorPlugins.registerPlugins(plugins, invocationPlugins); + SPIRVHalfFloatPlugins.registerPlugins(plugins, invocationPlugins); // Register plugins for Off-Heap Arrays with Panama registerMemoryAccessPlugins(invocationPlugins); } diff --git a/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/compiler/plugins/SPIRVHalfFloatPlugins.java b/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/compiler/plugins/SPIRVHalfFloatPlugins.java new file mode 100644 index 0000000000..96c59e9d7a --- /dev/null +++ b/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/compiler/plugins/SPIRVHalfFloatPlugins.java @@ -0,0 +1,112 @@ +/* + * This file is part of Tornado: A heterogeneous programming framework: + * https://github.com/beehive-lab/tornadovm + * + * Copyright (c) 2024, APT Group, Department of Computer Science, + * School of Engineering, The University of Manchester. All rights reserved. + * Copyright (c) 2009-2021, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + */ +package uk.ac.manchester.tornado.drivers.spirv.graal.compiler.plugins; + +import org.graalvm.compiler.nodes.ValueNode; +import org.graalvm.compiler.nodes.graphbuilderconf.GraphBuilderConfiguration; +import org.graalvm.compiler.nodes.graphbuilderconf.GraphBuilderContext; +import org.graalvm.compiler.nodes.graphbuilderconf.InvocationPlugin; +import org.graalvm.compiler.nodes.graphbuilderconf.InvocationPlugins; +import org.graalvm.compiler.nodes.graphbuilderconf.NodePlugin; + +import jdk.vm.ci.meta.JavaKind; +import jdk.vm.ci.meta.ResolvedJavaMethod; +import uk.ac.manchester.tornado.api.types.HalfFloat; +import uk.ac.manchester.tornado.runtime.graal.nodes.AddHalfFloatNode; +import uk.ac.manchester.tornado.runtime.graal.nodes.DivHalfFloatNode; +import uk.ac.manchester.tornado.runtime.graal.nodes.HalfFloatPlaceholder; +import uk.ac.manchester.tornado.runtime.graal.nodes.MultHalfFloatNode; +import uk.ac.manchester.tornado.runtime.graal.nodes.NewHalfFloatInstance; +import uk.ac.manchester.tornado.runtime.graal.nodes.SubHalfFloatNode; + +public class SPIRVHalfFloatPlugins { + + public static void registerPlugins(final GraphBuilderConfiguration.Plugins ps, final InvocationPlugins plugins) { + registerHalfFloatInit(ps, plugins); + } + + private static void registerHalfFloatInit(GraphBuilderConfiguration.Plugins ps, InvocationPlugins plugins) { + + final InvocationPlugins.Registration r = new InvocationPlugins.Registration(plugins, HalfFloat.class); + + ps.appendNodePlugin(new NodePlugin() { + @Override + public boolean handleInvoke(GraphBuilderContext b, ResolvedJavaMethod method, ValueNode[] args) { + if (method.getName().equals("") && (method.toString().contains("HalfFloat."))) { + NewHalfFloatInstance newHalfFloatInstance = b.append(new NewHalfFloatInstance(args[1])); + b.add(newHalfFloatInstance); + return true; + } + return false; + } + }); + + r.register(new InvocationPlugin("add", HalfFloat.class, HalfFloat.class) { + @Override + public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Receiver receiver, ValueNode halfFloat1, ValueNode halfFloat2) { + AddHalfFloatNode addNode = b.append(new AddHalfFloatNode(halfFloat1, halfFloat2)); + b.push(JavaKind.Object, addNode); + return true; + } + }); + + r.register(new InvocationPlugin("sub", HalfFloat.class, HalfFloat.class) { + @Override + public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Receiver receiver, ValueNode halfFloat1, ValueNode halfFloat2) { + SubHalfFloatNode subNode = b.append(new SubHalfFloatNode(halfFloat1, halfFloat2)); + b.push(JavaKind.Object, subNode); + return true; + } + }); + + r.register(new InvocationPlugin("mult", HalfFloat.class, HalfFloat.class) { + @Override + public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Receiver receiver, ValueNode halfFloat1, ValueNode halfFloat2) { + MultHalfFloatNode multNode = b.append(new MultHalfFloatNode(halfFloat1, halfFloat2)); + b.push(JavaKind.Object, multNode); + return true; + } + }); + + r.register(new InvocationPlugin("div", HalfFloat.class, HalfFloat.class) { + @Override + public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Receiver receiver, ValueNode halfFloat1, ValueNode halfFloat2) { + DivHalfFloatNode divNode = b.append(new DivHalfFloatNode(halfFloat1, halfFloat2)); + b.push(JavaKind.Object, divNode); + return true; + } + }); + + r.register(new InvocationPlugin("getHalfFloatValue", InvocationPlugin.Receiver.class) { + @Override + public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Receiver receiver) { + b.push(JavaKind.Short, b.append(new HalfFloatPlaceholder(receiver.get()))); + return true; + } + }); + + } + +} diff --git a/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/lir/SPIRVBinary.java b/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/lir/SPIRVBinary.java index c589ddc9fb..59d9fc5364 100644 --- a/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/lir/SPIRVBinary.java +++ b/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/lir/SPIRVBinary.java @@ -2,7 +2,7 @@ * This file is part of Tornado: A heterogeneous programming framework: * https://github.com/beehive-lab/tornadovm * - * Copyright (c) 2021, APT Group, Department of Computer Science, + * Copyright (c) 2021, 2024, APT Group, Department of Computer Science, * School of Engineering, The University of Manchester. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * @@ -12,7 +12,7 @@ * * This code is distributed in the hope that it will be useful, but WITHOUT * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or - * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License * version 2 for more details (a copy is included in the LICENSE file that * accompanied this code). * @@ -227,8 +227,8 @@ public void emit(SPIRVCompilationResultBuilder crb, SPIRVAssembler asm) { SPIRVId typeResultOperationId = asm.primitives.getTypePrimitive(resultKind); - Logger.traceCodeGen(Logger.BACKEND.SPIRV, - "emitBinaryOperation " + binaryOperation.getInstruction() + ": " + x + " " + binaryOperation.getOpcode() + " " + y + " Result Kind: " + resultKind); + Logger.traceCodeGen(Logger.BACKEND.SPIRV, "emitBinaryOperation " + binaryOperation.getInstruction() + ": " + x + " " + binaryOperation + .getOpcode() + " " + y + " Result Kind: " + resultKind); SPIRVId operationId = obtainPhiValueIdIfNeeded(asm); diff --git a/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/lir/SPIRVKind.java b/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/lir/SPIRVKind.java index 55edd856ab..5caa7c7180 100644 --- a/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/lir/SPIRVKind.java +++ b/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/lir/SPIRVKind.java @@ -86,7 +86,7 @@ public enum SPIRVKind implements PlatformKind { OP_TYPE_VECTOR4_INT_64(4, uk.ac.manchester.tornado.api.types.vectors.Int4.TYPE, OP_TYPE_INT_64), OP_TYPE_VECTORINT4_INT_32(4, uk.ac.manchester.tornado.api.types.collections.VectorInt4.TYPE, OP_TYPE_INT_32), - // OP_TYPE_VECTOR 8 + // OP_TYPE_VECTOR 8 OP_TYPE_VECTOR8_INT_32(8, uk.ac.manchester.tornado.api.types.vectors.Int8.TYPE, OP_TYPE_INT_32), OP_TYPE_VECTOR16_INT_32(16, uk.ac.manchester.tornado.api.types.vectors.Int16.TYPE, OP_TYPE_INT_32), OP_TYPE_VECTOR8_INT_64(8, uk.ac.manchester.tornado.api.types.vectors.Int8.TYPE, OP_TYPE_INT_64), @@ -200,54 +200,35 @@ public enum SPIRVKind implements PlatformKind { } public static SPIRVKind fromJavaKind(JavaKind stackKind) { - switch (stackKind) { - case Void: - return SPIRVKind.OP_TYPE_VOID; - case Boolean: - return SPIRVKind.OP_TYPE_BOOL; - case Byte: - return SPIRVKind.OP_TYPE_INT_8; - case Short: - return SPIRVKind.OP_TYPE_INT_16; - case Int: - return SPIRVKind.OP_TYPE_INT_32; - case Long: - return SPIRVKind.OP_TYPE_INT_64; - case Float: - return SPIRVKind.OP_TYPE_FLOAT_32; - case Double: - return SPIRVKind.OP_TYPE_FLOAT_64; - default: - throw new TornadoRuntimeException("Java type not supported: " + stackKind); - } + return switch (stackKind) { + case Void -> SPIRVKind.OP_TYPE_VOID; + case Boolean -> SPIRVKind.OP_TYPE_BOOL; + case Byte -> SPIRVKind.OP_TYPE_INT_8; + case Short -> SPIRVKind.OP_TYPE_INT_16; + case Int -> SPIRVKind.OP_TYPE_INT_32; + case Long -> SPIRVKind.OP_TYPE_INT_64; + case Float -> SPIRVKind.OP_TYPE_FLOAT_32; + case Double -> SPIRVKind.OP_TYPE_FLOAT_64; + default -> throw new TornadoRuntimeException("Java type not supported: " + stackKind); + }; } public static SPIRVKind fromJavaKindForMethodCalls(JavaKind stackKind) { - switch (stackKind) { - case Void: - return SPIRVKind.OP_TYPE_VOID; - case Boolean: - return SPIRVKind.OP_TYPE_BOOL; - case Char: - return SPIRVKind.OP_TYPE_INT_8; - case Byte: - return SPIRVKind.OP_TYPE_INT_8; - case Short: - return SPIRVKind.OP_TYPE_INT_16; - case Int: - return SPIRVKind.OP_TYPE_INT_32; - case Long: - return SPIRVKind.OP_TYPE_INT_64; - case Float: - return SPIRVKind.OP_TYPE_FLOAT_32; - case Double: - return SPIRVKind.OP_TYPE_FLOAT_64; - case Object: + return switch (stackKind) { + case Void -> SPIRVKind.OP_TYPE_VOID; + case Boolean -> SPIRVKind.OP_TYPE_BOOL; + case Char -> SPIRVKind.OP_TYPE_INT_8; + case Byte -> SPIRVKind.OP_TYPE_INT_8; + case Short -> SPIRVKind.OP_TYPE_INT_16; + case Int -> SPIRVKind.OP_TYPE_INT_32; + case Long -> SPIRVKind.OP_TYPE_INT_64; + case Float -> SPIRVKind.OP_TYPE_FLOAT_32; + case Double -> SPIRVKind.OP_TYPE_FLOAT_64; + case Object -> // we return a 64-bit long value - return SPIRVKind.OP_TYPE_INT_64; - default: - throw new TornadoRuntimeException("Java type not supported: " + stackKind); - } + SPIRVKind.OP_TYPE_INT_64; + default -> throw new TornadoRuntimeException("Java type not supported: " + stackKind); + }; } public static SPIRVKind fromResolvedJavaTypeToVectorKind(ResolvedJavaType type) { @@ -307,55 +288,20 @@ public SPIRVKind getElementKind() { @Override public char getTypeChar() { - switch (kind) { - case OP_TYPE_BOOL: - return 'z'; - case OP_TYPE_INT_8: - return 'c'; - case OP_TYPE_INT_16: - return 's'; - case OP_TYPE_INT_32: - return 'i'; - case OP_TYPE_INT_64: - return 'l'; - case OP_TYPE_FLOAT_32: - return 'f'; - case OP_TYPE_FLOAT_64: - return 'd'; - case OP_TYPE_VECTOR2_INT_16: - case OP_TYPE_VECTOR2_INT_32: - case OP_TYPE_VECTOR2_INT_64: - - case OP_TYPE_VECTOR3_INT_8: - case OP_TYPE_VECTOR3_INT_16: - case OP_TYPE_VECTOR3_INT_32: - case OP_TYPE_VECTOR3_INT_64: - - case OP_TYPE_VECTOR4_INT_8: - case OP_TYPE_VECTOR4_INT_32: - case OP_TYPE_VECTOR4_INT_64: - - case OP_TYPE_VECTOR8_INT_32: - case OP_TYPE_VECTOR8_INT_64: - - case OP_TYPE_VECTOR2_FLOAT_16: - case OP_TYPE_VECTOR2_FLOAT_32: - case OP_TYPE_VECTOR2_FLOAT_64: - - case OP_TYPE_VECTOR4_FLOAT_16: - case OP_TYPE_VECTOR4_FLOAT_32: - case OP_TYPE_VECTOR4_FLOAT_64: - - case OP_TYPE_VECTOR8_FLOAT_16: - case OP_TYPE_VECTOR8_FLOAT_32: - case OP_TYPE_VECTOR8_FLOAT_64: - case OP_TYPE_VECTOR16_FLOAT_32: - case OP_TYPE_VECTOR16_INT_32: - case OP_TYPE_VECTOR16_FLOAT_64: - return 'v'; - default: - return '-'; - } + return switch (kind) { + case OP_TYPE_BOOL -> 'z'; + case OP_TYPE_INT_8 -> 'c'; + case OP_TYPE_INT_16 -> 's'; + case OP_TYPE_INT_32 -> 'i'; + case OP_TYPE_INT_64 -> 'l'; + case OP_TYPE_FLOAT_32 -> 'f'; + case OP_TYPE_FLOAT_64 -> 'd'; + case OP_TYPE_VECTOR2_INT_16, OP_TYPE_VECTOR2_INT_32, OP_TYPE_VECTOR2_INT_64, OP_TYPE_VECTOR3_INT_8, OP_TYPE_VECTOR3_INT_16, OP_TYPE_VECTOR3_INT_32, OP_TYPE_VECTOR3_INT_64, + OP_TYPE_VECTOR4_INT_8, OP_TYPE_VECTOR4_INT_32, OP_TYPE_VECTOR4_INT_64, OP_TYPE_VECTOR8_INT_32, OP_TYPE_VECTOR8_INT_64, OP_TYPE_VECTOR2_FLOAT_16, OP_TYPE_VECTOR2_FLOAT_32, + OP_TYPE_VECTOR2_FLOAT_64, OP_TYPE_VECTOR4_FLOAT_16, OP_TYPE_VECTOR4_FLOAT_32, OP_TYPE_VECTOR4_FLOAT_64, OP_TYPE_VECTOR8_FLOAT_16, OP_TYPE_VECTOR8_FLOAT_32, + OP_TYPE_VECTOR8_FLOAT_64, OP_TYPE_VECTOR16_FLOAT_32, OP_TYPE_VECTOR16_INT_32, OP_TYPE_VECTOR16_FLOAT_64 -> 'v'; + default -> '-'; + }; } @Override diff --git a/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/lir/SPIRVLIROp.java b/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/lir/SPIRVLIROp.java index a0db6d9312..6c6bc6b117 100644 --- a/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/lir/SPIRVLIROp.java +++ b/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/lir/SPIRVLIROp.java @@ -2,7 +2,7 @@ * This file is part of Tornado: A heterogeneous programming framework: * https://github.com/beehive-lab/tornadovm * - * Copyright (c) 2021, APT Group, Department of Computer Science, + * Copyright (c) 2021, 2024, APT Group, Department of Computer Science, * School of Engineering, The University of Manchester. All rights reserved. * Copyright (c) 2009-2021, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. @@ -13,7 +13,7 @@ * * This code is distributed in the hope that it will be useful, but WITHOUT * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or - * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License * version 2 for more details (a copy is included in the LICENSE file that * accompanied this code). * @@ -30,7 +30,6 @@ import jdk.vm.ci.meta.AllocatableValue; import jdk.vm.ci.meta.PlatformKind; import jdk.vm.ci.meta.Value; -import uk.ac.manchester.beehivespirvtoolkit.lib.SPIRVModule; import uk.ac.manchester.beehivespirvtoolkit.lib.instructions.SPIRVOpLoad; import uk.ac.manchester.beehivespirvtoolkit.lib.instructions.operands.SPIRVId; import uk.ac.manchester.beehivespirvtoolkit.lib.instructions.operands.SPIRVLiteralInteger; @@ -43,17 +42,10 @@ public abstract class SPIRVLIROp extends Value { - // protected SPIRVModule module; - protected SPIRVLIROp(LIRKind valueKind) { super(valueKind); } - protected SPIRVLIROp(LIRKind valueKind, SPIRVModule module) { - super(valueKind); - // this.module = module; - } - public final void emit(SPIRVCompilationResultBuilder crb) { emit(crb, crb.getAssembler()); } diff --git a/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/lir/SPIRVLIRStmt.java b/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/lir/SPIRVLIRStmt.java index 88eb43bff1..4c71fc8609 100644 --- a/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/lir/SPIRVLIRStmt.java +++ b/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/lir/SPIRVLIRStmt.java @@ -2,7 +2,7 @@ * This file is part of Tornado: A heterogeneous programming framework: * https://github.com/beehive-lab/tornadovm * - * Copyright (c) 2021-2022 APT Group, Department of Computer Science, + * Copyright (c) 2021-2022, 2024, APT Group, Department of Computer Science, * School of Engineering, The University of Manchester. All rights reserved. * Copyright (c) 2009-2021, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. @@ -13,7 +13,7 @@ * * This code is distributed in the hope that it will be useful, but WITHOUT * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or - * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License * version 2 for more details (a copy is included in the LICENSE file that * accompanied this code). * @@ -484,15 +484,15 @@ private SPIRVId assignIdToPhiResult(SPIRVAssembler asm) { * in the ASM lookup tables. * * @param asm - * {@link SPIRVAssembler} Assembler + * {@link SPIRVAssembler} Assembler * @param previousID - * {@link SPIRVId} Previous ID in SPIRVId format + * {@link SPIRVId} Previous ID in SPIRVId format * @param previousBranch - * {@link SPIRVId} of the previous branch + * {@link SPIRVId} of the previous branch * @param newID - * {@link SPIRVId} of the new ID + * {@link SPIRVId} of the new ID * @param currentBranch - * {@link SPIRVId} basic block of the new ID + * {@link SPIRVId} basic block of the new ID * @return {@link SPIRVMultipleOperands} */ private SPIRVMultipleOperands composeOperands(SPIRVAssembler asm, SPIRVId previousID, SPIRVId previousBranch, SPIRVId newID, SPIRVId currentBranch) { @@ -601,9 +601,9 @@ public ASSIGNParameter(AllocatableValue lhs, Value rhs, int alignment, int param * * * @param crb - * {@link SPIRVCompilationResultBuilder} + * {@link SPIRVCompilationResultBuilder} * @param asm - * {@link SPIRVAssembler} + * {@link SPIRVAssembler} */ @Override protected void emitCode(SPIRVCompilationResultBuilder crb, SPIRVAssembler asm) { @@ -657,9 +657,9 @@ public ASSIGNParameterWithNoStore(AllocatableValue lhs, Value rhs) { * Loads the stack frame. This version optimizes Loads/Stores. * * @param crb - * {@link SPIRVCompilationResultBuilder} + * {@link SPIRVCompilationResultBuilder} * @param asm - * {@link SPIRVAssembler} + * {@link SPIRVAssembler} */ @Override protected void emitCode(SPIRVCompilationResultBuilder crb, SPIRVAssembler asm) { @@ -706,9 +706,9 @@ public ASSIGNParameterWithNoStoreNewMemoryModel(AllocatableValue lhs, Value rhs) * Loads the stack frame. This version optimizes Loads/Stores. * * @param crb - * {@link SPIRVCompilationResultBuilder} + * {@link SPIRVCompilationResultBuilder} * @param asm - * {@link SPIRVAssembler} + * {@link SPIRVAssembler} */ @Override protected void emitCode(SPIRVCompilationResultBuilder crb, SPIRVAssembler asm) { @@ -758,9 +758,9 @@ public ASSIGNIndexedParameter(AllocatableValue lhs, Value rhs) { * * * @param crb - * {@link SPIRVCompilationResultBuilder} + * {@link SPIRVCompilationResultBuilder} * @param asm - * {@link SPIRVAssembler} + * {@link SPIRVAssembler} */ @Override protected void emitCode(SPIRVCompilationResultBuilder crb, SPIRVAssembler asm) { @@ -940,20 +940,20 @@ public LoadVectorStmt(AllocatableValue result, SPIRVAddressCast cast, MemoryAcce * Then the SPIR-V Optimizer is not enabled: * * - * %43 = OpLoad %_ptr_CrossWorkgroup_ulong %frame Aligned 8 - * %44 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_ulong %43 %uint_3 - * %45 = OpLoad %ulong %44 Aligned 8 - * OpStore %spirv_l_0F0 %45 Aligned 8 - * %46 = OpLoad %ulong %spirv_l_0F0 Aligned 8 - * %48 = OpConvertUToPtr %_ptr_CrossWorkgroup_float %46 - * %49 = OpExtInst %v2float %1 vloadn %ulong_0 %48 2 - * OpStore %spirv_v2f_1F0 %49 Aligned 8 + * %43 = OpLoad %_ptr_CrossWorkgroup_ulong %frame Aligned 8 + * %44 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_ulong %43 %uint_3 + * %45 = OpLoad %ulong %44 Aligned 8 + * OpStore %spirv_l_0F0 %45 Aligned 8 + * %46 = OpLoad %ulong %spirv_l_0F0 Aligned 8 + * %48 = OpConvertUToPtr %_ptr_CrossWorkgroup_float %46 + * %49 = OpExtInst %v2float %1 vloadn %ulong_0 %48 2 + * OpStore %spirv_v2f_1F0 %49 Aligned 8 * * * @param crb - * {@link SPIRVCompilationResultBuilder} + * {@link SPIRVCompilationResultBuilder} * @param asm - * {@link SPIRVAssembler} + * {@link SPIRVAssembler} */ @Override public void emitCode(SPIRVCompilationResultBuilder crb, SPIRVAssembler asm) { @@ -1040,12 +1040,19 @@ public StoreStmt(SPIRVAddressCast cast, MemoryAccess address, Value rhs) { @Override protected void emitCode(SPIRVCompilationResultBuilder crb, SPIRVAssembler asm) { + Logger.traceCodeGen(Logger.BACKEND.SPIRV, "emit StoreStmt in address: " + cast + " <- " + rhs); cast.emit(crb, asm); + boolean isFP16Cast = cast.getLIRKind().getPlatformKind() == SPIRVKind.OP_TYPE_FLOAT_16; + SPIRVId value; if (rhs instanceof ConstantValue) { - value = asm.lookUpConstant(((ConstantValue) this.rhs).getConstant().toValueString(), (SPIRVKind) rhs.getPlatformKind()); + if (isFP16Cast) { + value = asm.lookUpConstant(((ConstantValue) this.rhs).getConstant().toValueString(), SPIRVKind.OP_TYPE_FLOAT_16); + } else { + value = asm.lookUpConstant(((ConstantValue) this.rhs).getConstant().toValueString(), (SPIRVKind) rhs.getPlatformKind()); + } } else { value = asm.lookUpLIRInstructions(rhs); if (TornadoOptions.OPTIMIZE_LOAD_STORE_SPIRV) { @@ -1478,8 +1485,8 @@ public IndexedLoadMemAccess(SPIRVUnary.MemoryIndexedAccess address, AllocatableV @Override protected void emitCode(SPIRVCompilationResultBuilder crb, SPIRVAssembler asm) { - Logger.traceCodeGen(Logger.BACKEND.SPIRV, - "emit IndexedLoadMemAccess in address: " + address + "[ " + address.getIndex() + "] -- region: " + address.getMemoryRegion().getMemorySpace().getName()); + Logger.traceCodeGen(Logger.BACKEND.SPIRV, "emit IndexedLoadMemAccess in address: " + address + "[ " + address.getIndex() + "] -- region: " + address.getMemoryRegion().getMemorySpace() + .getName()); SPIRVKind spirvKind = (SPIRVKind) result.getPlatformKind(); SPIRVId type = asm.primitives.getTypePrimitive(spirvKind); @@ -1525,22 +1532,22 @@ public IndexedLoadMemCollectionAccess(SPIRVUnary.MemoryIndexedAccess address, Al * (e.g., VectorFloat2) * * - * %301 = OpInBoundsPtrAccessChain %_ptr_Function_double %spirv_d_1F0 %ulong_0 %ulong_0 - * %302 = OpPtrCastToGeneric %_ptr_Generic_double %301 - * %303 = OpExtInst %v2double %1 vloadn %ulong_0 %302 2 - * OpStore %spirv_v2d_50F0 %303 Aligned 16 + * %301 = OpInBoundsPtrAccessChain %_ptr_Function_double %spirv_d_1F0 %ulong_0 %ulong_0 + * %302 = OpPtrCastToGeneric %_ptr_Generic_double %301 + * %303 = OpExtInst %v2double %1 vloadn %ulong_0 %302 2 + * OpStore %spirv_v2d_50F0 %303 Aligned 16 * * * @param crb - * {@link SPIRVCompilationResultBuilder} + * {@link SPIRVCompilationResultBuilder} * @param asm - * {@link SPIRVAssembler} + * {@link SPIRVAssembler} */ @Override protected void emitCode(SPIRVCompilationResultBuilder crb, SPIRVAssembler asm) { - Logger.traceCodeGen(Logger.BACKEND.SPIRV, - "emit IndexedLoadMemCollectionAccess in address: " + address + "[ " + address.getIndex() + "] -- region: " + address.getMemoryRegion().getMemorySpace().getName()); + Logger.traceCodeGen(Logger.BACKEND.SPIRV, "emit IndexedLoadMemCollectionAccess in address: " + address + "[ " + address.getIndex() + "] -- region: " + address.getMemoryRegion() + .getMemorySpace().getName()); SPIRVKind spirvKind = (SPIRVKind) result.getPlatformKind(); diff --git a/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/lir/SPIRVUnary.java b/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/lir/SPIRVUnary.java index 383b23180d..2ee45f4984 100644 --- a/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/lir/SPIRVUnary.java +++ b/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/lir/SPIRVUnary.java @@ -2,7 +2,7 @@ * This file is part of Tornado: A heterogeneous programming framework: * https://github.com/beehive-lab/tornadovm * - * Copyright (c) 2021-2023, APT Group, Department of Computer Science, + * Copyright (c) 2021-2024, APT Group, Department of Computer Science, * School of Engineering, The University of Manchester. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * @@ -534,13 +534,10 @@ public void emitForLoad(SPIRVAssembler asm, SPIRVKind resultKind) { public static class SPIRVAddressCast extends UnaryConsumer { - private final SPIRVMemoryBase base; - private final Value address; public SPIRVAddressCast(Value address, SPIRVMemoryBase base, LIRKind valueKind) { super(null, null, valueKind, address); - this.base = base; this.address = address; } @@ -562,6 +559,7 @@ public void emit(SPIRVCompilationResultBuilder crb, SPIRVAssembler asm) { final SPIRVId idLoad; SPIRVId addressId = asm.lookUpLIRInstructions(this.address); + if (TornadoOptions.OPTIMIZE_LOAD_STORE_SPIRV) { idLoad = addressId; } else { @@ -604,9 +602,7 @@ public ThreadBuiltinCallForSPIRV(SPIRVThreadBuiltIn builtIn, Variable result, LI * * * - * %37 = OpLoad %v3ulong %__spirv_BuiltInGlobalInvocationId Aligned 32 - * %call = OpCompositeExtract %ulong %37 0 - * %conv = OpUConvert %uint %call + * %37 = OpLoad %v3ulong %__spirv_BuiltInGlobalInvocationId Aligned 32 %call = OpCompositeExtract %ulong %37 0 %conv = OpUConvert %uint %call * */ @Override @@ -740,15 +736,13 @@ public SignNarrowValue(LIRKind lirKind, Variable result, Value inputVal, int toB } /** - * Following this: - * {@url https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#OpSConvert} + * Following this: {@url https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#OpSConvert} * * * Convert signed width. This is either a truncate or a sign extend. * *

- * OpSConvert can be used for sign extend as well as truncate. The "S" symbol - * represents signed format. + * OpSConvert can be used for sign extend as well as truncate. The "S" symbol represents signed format. * * @param crb * {@link SPIRVCompilationResultBuilder} @@ -932,14 +926,11 @@ public void emit(SPIRVCompilationResultBuilder crb, SPIRVAssembler asm) { } /** - * OpenCL Extended Instruction Set Intrinsics. As specified in the SPIR-V 1.0 - * standard, the following intrinsics in SPIR-V represents builtin functions - * from the OpenCL standard. + * OpenCL Extended Instruction Set Intrinsics. As specified in the SPIR-V 1.0 standard, the following intrinsics in SPIR-V represents builtin functions from the OpenCL standard. * * For obtaining the correct Int-Reference of the function: * * https://www.khronos.org/registry/spir-v/specs/1.0/OpenCL.ExtendedInstructionSet.100.html - * */ public static class Intrinsic extends UnaryConsumer { diff --git a/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/nodes/CastNode.java b/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/nodes/CastNode.java index 7f271fefec..0c19a6b3fa 100644 --- a/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/nodes/CastNode.java +++ b/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/nodes/CastNode.java @@ -2,7 +2,7 @@ * This file is part of Tornado: A heterogeneous programming framework: * https://github.com/beehive-lab/tornadovm * - * Copyright (c) 2021, APT Group, Department of Computer Science, + * Copyright (c) 2021, 2024, APT Group, Department of Computer Science, * School of Engineering, The University of Manchester. All rights reserved. * Copyright (c) 2009-2021, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. @@ -58,27 +58,16 @@ public CastNode(Stamp stamp, FloatConvert op, ValueNode value) { } private SPIRVUnary.CastOperations resolveOp(Variable result, LIRKind lirKind, Value value) { - switch (op) { - case I2F: - return new SPIRVUnary.CastIToFloat(lirKind, result, value, SPIRVKind.OP_TYPE_FLOAT_32); - case I2D: - return new SPIRVUnary.CastIToFloat(lirKind, result, value, SPIRVKind.OP_TYPE_FLOAT_64); - case D2F: - return new SPIRVUnary.CastFloatDouble(lirKind, result, value, SPIRVKind.OP_TYPE_FLOAT_32); - case F2D: - return new SPIRVUnary.CastFloatDouble(lirKind, result, value, SPIRVKind.OP_TYPE_FLOAT_64); - case L2D: - return new SPIRVUnary.CastFloatDouble(lirKind, result, value, SPIRVKind.OP_TYPE_FLOAT_64); - case L2F: - return new SPIRVUnary.CastFloatToLong(lirKind, result, value, SPIRVKind.OP_TYPE_FLOAT_32); - case F2I: - return new SPIRVUnary.CastFloatToInt(lirKind, result, value, SPIRVKind.OP_TYPE_INT_32); - case D2L: - case F2L: - case D2I: - default: - throw new RuntimeException("Conversion Cast Operation unimplemented: " + op); - } + return switch (op) { + case I2F -> new SPIRVUnary.CastIToFloat(lirKind, result, value, SPIRVKind.OP_TYPE_FLOAT_32); + case I2D -> new SPIRVUnary.CastIToFloat(lirKind, result, value, SPIRVKind.OP_TYPE_FLOAT_64); + case D2F -> new SPIRVUnary.CastFloatDouble(lirKind, result, value, SPIRVKind.OP_TYPE_FLOAT_32); + case F2D -> new SPIRVUnary.CastFloatDouble(lirKind, result, value, SPIRVKind.OP_TYPE_FLOAT_64); + case L2D -> new SPIRVUnary.CastFloatDouble(lirKind, result, value, SPIRVKind.OP_TYPE_FLOAT_64); + case L2F -> new SPIRVUnary.CastFloatToLong(lirKind, result, value, SPIRVKind.OP_TYPE_FLOAT_32); + case F2I -> new SPIRVUnary.CastFloatToInt(lirKind, result, value, SPIRVKind.OP_TYPE_INT_32); + default -> throw new RuntimeException("Conversion Cast Operation unimplemented: " + op); + }; } @Override diff --git a/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/nodes/ReadHalfFloatNode.java b/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/nodes/ReadHalfFloatNode.java new file mode 100644 index 0000000000..95139ee480 --- /dev/null +++ b/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/nodes/ReadHalfFloatNode.java @@ -0,0 +1,68 @@ +/* + * This file is part of Tornado: A heterogeneous programming framework: + * https://github.com/beehive-lab/tornadovm + * + * Copyright (c) 2024, APT Group, Department of Computer Science, + * School of Engineering, The University of Manchester. All rights reserved. + * Copyright (c) 2009-2021, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + */ +package uk.ac.manchester.tornado.drivers.spirv.graal.nodes; + +import org.graalvm.compiler.core.common.LIRKind; +import org.graalvm.compiler.core.common.type.StampFactory; +import org.graalvm.compiler.graph.NodeClass; +import org.graalvm.compiler.lir.Variable; +import org.graalvm.compiler.lir.gen.LIRGeneratorTool; +import org.graalvm.compiler.nodeinfo.NodeInfo; +import org.graalvm.compiler.nodes.FixedWithNextNode; +import org.graalvm.compiler.nodes.memory.address.AddressNode; +import org.graalvm.compiler.nodes.spi.LIRLowerable; +import org.graalvm.compiler.nodes.spi.NodeLIRBuilderTool; + +import jdk.vm.ci.meta.JavaKind; +import jdk.vm.ci.meta.Value; +import uk.ac.manchester.tornado.drivers.spirv.graal.SPIRVArchitecture; +import uk.ac.manchester.tornado.drivers.spirv.graal.lir.SPIRVKind; +import uk.ac.manchester.tornado.drivers.spirv.graal.lir.SPIRVLIRStmt; +import uk.ac.manchester.tornado.drivers.spirv.graal.lir.SPIRVUnary; + +@NodeInfo +public class ReadHalfFloatNode extends FixedWithNextNode implements LIRLowerable { + + public static final NodeClass TYPE = NodeClass.create(ReadHalfFloatNode.class); + + @Input + private AddressNode addressNode; + + public ReadHalfFloatNode(AddressNode addressNode) { + super(TYPE, StampFactory.forKind(JavaKind.Short)); + this.addressNode = addressNode; + } + + public void generate(NodeLIRBuilderTool generator) { + LIRGeneratorTool tool = generator.getLIRGeneratorTool(); + Variable result = tool.newVariable(LIRKind.value(SPIRVKind.OP_TYPE_FLOAT_16)); + Value addressValue = generator.operand(addressNode); + SPIRVArchitecture.SPIRVMemoryBase base = ((SPIRVUnary.MemoryAccess) (addressValue)).getMemoryRegion(); + SPIRVUnary.SPIRVAddressCast cast = new SPIRVUnary.SPIRVAddressCast(addressValue, base, LIRKind.value(SPIRVKind.OP_TYPE_FLOAT_16)); + tool.append(new SPIRVLIRStmt.LoadStmt(result, cast, (SPIRVUnary.MemoryAccess) addressValue)); + generator.setResult(this, result); + } + +} diff --git a/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/nodes/WriteHalfFloatNode.java b/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/nodes/WriteHalfFloatNode.java new file mode 100644 index 0000000000..487f694f40 --- /dev/null +++ b/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/nodes/WriteHalfFloatNode.java @@ -0,0 +1,79 @@ +/* + * This file is part of Tornado: A heterogeneous programming framework: + * https://github.com/beehive-lab/tornadovm + * + * Copyright (c) 2024, APT Group, Department of Computer Science, + * School of Engineering, The University of Manchester. All rights reserved. + * Copyright (c) 2009-2021, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + */ +package uk.ac.manchester.tornado.drivers.spirv.graal.nodes; + +import org.graalvm.compiler.core.common.LIRKind; +import org.graalvm.compiler.core.common.type.StampFactory; +import org.graalvm.compiler.graph.NodeClass; +import org.graalvm.compiler.lir.gen.LIRGeneratorTool; +import org.graalvm.compiler.nodeinfo.NodeInfo; +import org.graalvm.compiler.nodes.FixedWithNextNode; +import org.graalvm.compiler.nodes.ValueNode; +import org.graalvm.compiler.nodes.memory.address.AddressNode; +import org.graalvm.compiler.nodes.spi.LIRLowerable; +import org.graalvm.compiler.nodes.spi.NodeLIRBuilderTool; + +import jdk.vm.ci.meta.JavaKind; +import jdk.vm.ci.meta.Value; +import uk.ac.manchester.tornado.drivers.spirv.graal.lir.SPIRVKind; +import uk.ac.manchester.tornado.drivers.spirv.graal.lir.SPIRVLIRStmt; +import uk.ac.manchester.tornado.drivers.spirv.graal.lir.SPIRVUnary; +import uk.ac.manchester.tornado.drivers.spirv.graal.lir.SPIRVUnary.MemoryAccess; +import uk.ac.manchester.tornado.drivers.spirv.graal.lir.SPIRVUnary.MemoryIndexedAccess; + +@NodeInfo +public class WriteHalfFloatNode extends FixedWithNextNode implements LIRLowerable { + + public static final NodeClass TYPE = NodeClass.create(WriteHalfFloatNode.class); + + @Input + private AddressNode addressNode; + + @Input + private ValueNode valueNode; + + public WriteHalfFloatNode(AddressNode addressNode, ValueNode valueNode) { + super(TYPE, StampFactory.forKind(JavaKind.Short)); + this.addressNode = addressNode; + this.valueNode = valueNode; + } + + public void generate(NodeLIRBuilderTool generator) { + LIRGeneratorTool tool = generator.getLIRGeneratorTool(); + + Value addressValue = generator.operand(addressNode); + Value valueToStore = generator.operand(valueNode); + + if (addressValue instanceof MemoryAccess memoryAccess) { + SPIRVUnary.SPIRVAddressCast cast = new SPIRVUnary.SPIRVAddressCast(memoryAccess.getValue(), memoryAccess.getMemoryRegion(), LIRKind.value(SPIRVKind.OP_TYPE_FLOAT_16)); + if (memoryAccess.getIndex() == null) { + tool.append(new SPIRVLIRStmt.StoreStmt(cast, memoryAccess, valueToStore)); + } + } else if (addressValue instanceof MemoryIndexedAccess indexedAccess) { + tool.append(new SPIRVLIRStmt.StoreIndexedMemAccess(indexedAccess, valueToStore)); + } + + } +} diff --git a/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/phases/TornadoHalfFloatReplacement.java b/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/phases/TornadoHalfFloatReplacement.java new file mode 100644 index 0000000000..559dff8a68 --- /dev/null +++ b/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/phases/TornadoHalfFloatReplacement.java @@ -0,0 +1,197 @@ +/* + * This file is part of Tornado: A heterogeneous programming framework: + * https://github.com/beehive-lab/tornadovm + * + * Copyright (c) 2024, APT Group, Department of Computer Science, + * School of Engineering, The University of Manchester. All rights reserved. + * Copyright (c) 2009-2021, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + */ +package uk.ac.manchester.tornado.drivers.spirv.graal.phases; + +import java.util.Optional; + +import org.graalvm.compiler.graph.Node; +import org.graalvm.compiler.nodes.GraphState; +import org.graalvm.compiler.nodes.StructuredGraph; +import org.graalvm.compiler.nodes.ValueNode; +import org.graalvm.compiler.nodes.calc.AddNode; +import org.graalvm.compiler.nodes.calc.FloatDivNode; +import org.graalvm.compiler.nodes.calc.MulNode; +import org.graalvm.compiler.nodes.calc.SubNode; +import org.graalvm.compiler.nodes.extended.JavaReadNode; +import org.graalvm.compiler.nodes.extended.JavaWriteNode; +import org.graalvm.compiler.nodes.java.NewInstanceNode; +import org.graalvm.compiler.nodes.memory.address.AddressNode; +import org.graalvm.compiler.phases.BasePhase; + +import uk.ac.manchester.tornado.drivers.spirv.graal.nodes.ReadHalfFloatNode; +import uk.ac.manchester.tornado.drivers.spirv.graal.nodes.WriteHalfFloatNode; +import uk.ac.manchester.tornado.runtime.graal.nodes.AddHalfFloatNode; +import uk.ac.manchester.tornado.runtime.graal.nodes.DivHalfFloatNode; +import uk.ac.manchester.tornado.runtime.graal.nodes.HalfFloatPlaceholder; +import uk.ac.manchester.tornado.runtime.graal.nodes.MultHalfFloatNode; +import uk.ac.manchester.tornado.runtime.graal.nodes.NewHalfFloatInstance; +import uk.ac.manchester.tornado.runtime.graal.nodes.SubHalfFloatNode; +import uk.ac.manchester.tornado.runtime.graal.phases.TornadoHighTierContext; + +public class TornadoHalfFloatReplacement extends BasePhase { + + @Override + public Optional notApplicableTo(GraphState graphState) { + return ALWAYS_APPLICABLE; + } + + protected void run(StructuredGraph graph, TornadoHighTierContext context) { + + // replace reads with halfFloat reads + for (JavaReadNode javaRead : graph.getNodes().filter(JavaReadNode.class)) { + if (javaRead.successors().first() instanceof NewInstanceNode) { + NewInstanceNode newInstanceNode = (NewInstanceNode) javaRead.successors().first(); + if (newInstanceNode.instanceClass().toString().contains("HalfFloat")) { + if (newInstanceNode.successors().first() instanceof NewHalfFloatInstance) { + NewHalfFloatInstance newHalfFloatInstance = (NewHalfFloatInstance) newInstanceNode.successors().first(); + deleteFixed(newHalfFloatInstance); + } + AddressNode readingAddress = javaRead.getAddress(); + ReadHalfFloatNode readHalfFloatNode = new ReadHalfFloatNode(readingAddress); + graph.addWithoutUnique(readHalfFloatNode); + replaceFixed(javaRead, readHalfFloatNode); + newInstanceNode.replaceAtUsages(readHalfFloatNode); + deleteFixed(newInstanceNode); + } + } + } + + // replace writes with halfFloat writes + for (JavaWriteNode javaWrite : graph.getNodes().filter(JavaWriteNode.class)) { + if (isWriteHalfFloat(javaWrite)) { + // This casting is safe to do as it is already checked by the isWriteHalfFloat function + HalfFloatPlaceholder placeholder = (HalfFloatPlaceholder) javaWrite.value(); + ValueNode writingValue; + if (javaWrite.predecessor() instanceof NewHalfFloatInstance) { + // if a new HalfFloat instance is written + NewHalfFloatInstance newHalfFloatInstance = (NewHalfFloatInstance) javaWrite.predecessor(); + writingValue = newHalfFloatInstance.getValue(); + if (newHalfFloatInstance.predecessor() instanceof NewInstanceNode) { + NewInstanceNode newInstanceNode = (NewInstanceNode) newHalfFloatInstance.predecessor(); + if (newInstanceNode.instanceClass().toString().contains("HalfFloat")) { + deleteFixed(newInstanceNode); + deleteFixed(newHalfFloatInstance); + } + } + } else { + // if the result of an operation or a stored value is written + writingValue = placeholder.getInput(); + } + placeholder.replaceAtUsages(writingValue); + placeholder.safeDelete(); + AddressNode writingAddress = javaWrite.getAddress(); + WriteHalfFloatNode writeHalfFloatNode = new WriteHalfFloatNode(writingAddress, writingValue); + graph.addWithoutUnique(writeHalfFloatNode); + replaceFixed(javaWrite, writeHalfFloatNode); + deleteFixed(javaWrite); + } + } + + // replace the half float operator nodes with the corresponding regular operators + replaceAddHalfFloatNodes(graph); + replaceSubHalfFloatNodes(graph); + replaceMultHalfFloatNodes(graph); + replaceDivHalfFloatNodes(graph); + + } + + private static void replaceAddHalfFloatNodes(StructuredGraph graph) { + for (AddHalfFloatNode addHalfFloatNode : graph.getNodes().filter(AddHalfFloatNode.class)) { + AddNode addNode = new AddNode(addHalfFloatNode.getX(), addHalfFloatNode.getY()); + graph.addWithoutUnique(addNode); + addHalfFloatNode.replaceAtUsages(addNode); + addHalfFloatNode.safeDelete(); + } + } + + private static void replaceSubHalfFloatNodes(StructuredGraph graph) { + for (SubHalfFloatNode subHalfFloatNode : graph.getNodes().filter(SubHalfFloatNode.class)) { + SubNode subNode = new SubNode(subHalfFloatNode.getX(), subHalfFloatNode.getY()); + graph.addWithoutUnique(subNode); + subHalfFloatNode.replaceAtUsages(subNode); + subHalfFloatNode.safeDelete(); + } + } + + private static void replaceMultHalfFloatNodes(StructuredGraph graph) { + for (MultHalfFloatNode multHalfFloatNode : graph.getNodes().filter(MultHalfFloatNode.class)) { + MulNode mulNode = new MulNode(multHalfFloatNode.getX(), multHalfFloatNode.getY()); + graph.addWithoutUnique(mulNode); + multHalfFloatNode.replaceAtUsages(mulNode); + multHalfFloatNode.safeDelete(); + } + } + + private static void replaceDivHalfFloatNodes(StructuredGraph graph) { + for (DivHalfFloatNode divHalfFloatNode : graph.getNodes().filter(DivHalfFloatNode.class)) { + FloatDivNode divNode = new FloatDivNode(divHalfFloatNode.getX(), divHalfFloatNode.getY()); + graph.addWithoutUnique(divNode); + divHalfFloatNode.replaceAtUsages(divNode); + divHalfFloatNode.safeDelete(); + } + } + + private static boolean isWriteHalfFloat(JavaWriteNode javaWrite) { + if (javaWrite.value() instanceof HalfFloatPlaceholder) { + return true; + } + return false; + } + + private static void replaceFixed(Node n, Node other) { + Node pred = n.predecessor(); + Node suc = n.successors().first(); + + n.replaceFirstSuccessor(suc, null); + n.replaceAtPredecessor(other); + pred.replaceFirstSuccessor(n, other); + other.replaceFirstSuccessor(null, suc); + + for (Node us : n.usages()) { + n.removeUsage(us); + } + n.clearInputs(); + n.safeDelete(); + + } + + private static void deleteFixed(Node node) { + if (!node.isDeleted()) { + Node predecessor = node.predecessor(); + Node successor = node.successors().first(); + + node.replaceFirstSuccessor(successor, null); + node.replaceAtPredecessor(successor); + predecessor.replaceFirstSuccessor(node, successor); + + for (Node us : node.usages()) { + node.removeUsage(us); + } + node.clearInputs(); + node.safeDelete(); + } + } + +} diff --git a/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/mm/SPIRVMemorySegmentWrapper.java b/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/mm/SPIRVMemorySegmentWrapper.java index 066416f8d8..fa5b75f470 100644 --- a/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/mm/SPIRVMemorySegmentWrapper.java +++ b/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/mm/SPIRVMemorySegmentWrapper.java @@ -2,7 +2,7 @@ * This file is part of Tornado: A heterogeneous programming framework: * https://github.com/beehive-lab/tornadovm * - * Copyright (c) 2023, APT Group, Department of Computer Science, + * Copyright (c) 2023, 2024, APT Group, Department of Computer Science, * The University of Manchester. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * @@ -36,6 +36,7 @@ import uk.ac.manchester.tornado.api.types.arrays.CharArray; import uk.ac.manchester.tornado.api.types.arrays.DoubleArray; import uk.ac.manchester.tornado.api.types.arrays.FloatArray; +import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray; import uk.ac.manchester.tornado.api.types.arrays.IntArray; import uk.ac.manchester.tornado.api.types.arrays.LongArray; import uk.ac.manchester.tornado.api.types.arrays.ShortArray; @@ -122,6 +123,7 @@ private MemorySegment getSegment(final Object reference) { case ShortArray shortArray -> shortArray.getSegment(); case ByteArray byteArray -> byteArray.getSegment(); case CharArray charArray -> charArray.getSegment(); + case HalfFloatArray halfFloatArray -> halfFloatArray.getSegment(); case VectorFloat2 vectorFloat2 -> vectorFloat2.getArray().getSegment(); case VectorFloat3 vectorFloat3 -> vectorFloat3.getArray().getSegment(); case VectorFloat4 vectorFloat4 -> vectorFloat4.getArray().getSegment(); diff --git a/tornado-runtime/src/main/java/uk/ac/manchester/tornado/runtime/graal/compiler/TornadoSketchTier.java b/tornado-runtime/src/main/java/uk/ac/manchester/tornado/runtime/graal/compiler/TornadoSketchTier.java index 4a2eeb12cb..0466c77e82 100644 --- a/tornado-runtime/src/main/java/uk/ac/manchester/tornado/runtime/graal/compiler/TornadoSketchTier.java +++ b/tornado-runtime/src/main/java/uk/ac/manchester/tornado/runtime/graal/compiler/TornadoSketchTier.java @@ -41,6 +41,7 @@ import uk.ac.manchester.tornado.runtime.graal.phases.sketcher.TornadoAutoParalleliser; import uk.ac.manchester.tornado.runtime.graal.phases.sketcher.TornadoDataflowAnalysis; import uk.ac.manchester.tornado.runtime.graal.phases.sketcher.TornadoFullInliningPolicy; +import uk.ac.manchester.tornado.runtime.graal.phases.TornadoHalfFloatFixedGuardElimination; import uk.ac.manchester.tornado.runtime.graal.phases.sketcher.TornadoKernelContextReplacement; import uk.ac.manchester.tornado.runtime.graal.phases.sketcher.TornadoNativeTypeElimination; import uk.ac.manchester.tornado.runtime.graal.phases.sketcher.TornadoNumericPromotionPhase; @@ -74,6 +75,7 @@ public TornadoSketchTier(OptionValues options, CanonicalizerPhase.CustomSimplifi } appendPhase(new TornadoStampResolver()); + appendPhase(new TornadoHalfFloatFixedGuardElimination()); appendPhase(new TornadoNativeTypeElimination()); appendPhase(new TornadoReduceReplacement()); appendPhase(new TornadoApiReplacement()); diff --git a/tornado-runtime/src/main/java/uk/ac/manchester/tornado/runtime/graal/nodes/AddHalfFloatNode.java b/tornado-runtime/src/main/java/uk/ac/manchester/tornado/runtime/graal/nodes/AddHalfFloatNode.java new file mode 100644 index 0000000000..ae182fa02b --- /dev/null +++ b/tornado-runtime/src/main/java/uk/ac/manchester/tornado/runtime/graal/nodes/AddHalfFloatNode.java @@ -0,0 +1,55 @@ +/* + * Copyright (c) 2024, APT Group, Department of Computer Science, + * School of Engineering, The University of Manchester. All rights reserved. + * Copyright (c) 2009, 2017, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + */ +package uk.ac.manchester.tornado.runtime.graal.nodes; + +import jdk.vm.ci.meta.JavaKind; +import org.graalvm.compiler.core.common.type.StampFactory; +import org.graalvm.compiler.graph.Node; +import org.graalvm.compiler.graph.NodeClass; +import org.graalvm.compiler.nodeinfo.NodeInfo; +import org.graalvm.compiler.nodes.ValueNode; + +@NodeInfo(shortName = "FLOAT16(+)") +public class AddHalfFloatNode extends ValueNode { + + public static final NodeClass TYPE = NodeClass.create(AddHalfFloatNode.class); + + @Node.Input + ValueNode input1; + + @Node.Input + ValueNode input2; + + public AddHalfFloatNode(ValueNode input1, ValueNode input2) { + super(TYPE, StampFactory.forKind(JavaKind.Object)); + this.input1 = input1; + this.input2 = input2; + } + + public ValueNode getX() { + return input1; + } + + public ValueNode getY() { + return input2; + } +} diff --git a/tornado-runtime/src/main/java/uk/ac/manchester/tornado/runtime/graal/nodes/DivHalfFloatNode.java b/tornado-runtime/src/main/java/uk/ac/manchester/tornado/runtime/graal/nodes/DivHalfFloatNode.java new file mode 100644 index 0000000000..86229c4b52 --- /dev/null +++ b/tornado-runtime/src/main/java/uk/ac/manchester/tornado/runtime/graal/nodes/DivHalfFloatNode.java @@ -0,0 +1,54 @@ +/* + * Copyright (c) 2024, APT Group, Department of Computer Science, + * School of Engineering, The University of Manchester. All rights reserved. + * Copyright (c) 2009, 2017, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + */ +package uk.ac.manchester.tornado.runtime.graal.nodes; + +import jdk.vm.ci.meta.JavaKind; +import org.graalvm.compiler.core.common.type.StampFactory; +import org.graalvm.compiler.graph.Node; +import org.graalvm.compiler.graph.NodeClass; +import org.graalvm.compiler.nodeinfo.NodeInfo; +import org.graalvm.compiler.nodes.ValueNode; + +@NodeInfo(shortName = "FLOAT16(/)") +public class DivHalfFloatNode extends ValueNode { + public static final NodeClass TYPE = NodeClass.create(DivHalfFloatNode.class); + + @Node.Input + ValueNode input1; + + @Node.Input + ValueNode input2; + + public DivHalfFloatNode(ValueNode input1, ValueNode input2) { + super(TYPE, StampFactory.forKind(JavaKind.Object)); + this.input1 = input1; + this.input2 = input2; + } + + public ValueNode getX() { + return input1; + } + + public ValueNode getY() { + return input2; + } +} diff --git a/tornado-runtime/src/main/java/uk/ac/manchester/tornado/runtime/graal/nodes/HalfFloatPlaceholder.java b/tornado-runtime/src/main/java/uk/ac/manchester/tornado/runtime/graal/nodes/HalfFloatPlaceholder.java new file mode 100644 index 0000000000..9b232f03b5 --- /dev/null +++ b/tornado-runtime/src/main/java/uk/ac/manchester/tornado/runtime/graal/nodes/HalfFloatPlaceholder.java @@ -0,0 +1,52 @@ +/* + * Copyright (c) 2024, APT Group, Department of Computer Science, + * The University of Manchester. All rights reserved. + * Copyright (c) 2009, 2017, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + */ +package uk.ac.manchester.tornado.runtime.graal.nodes; + +import org.graalvm.compiler.core.common.type.StampFactory; +import org.graalvm.compiler.graph.NodeClass; +import org.graalvm.compiler.nodeinfo.NodeInfo; +import org.graalvm.compiler.nodes.ValueNode; + +import jdk.vm.ci.meta.JavaKind; + +@NodeInfo +public class HalfFloatPlaceholder extends ValueNode { + + public static final NodeClass TYPE = NodeClass.create(HalfFloatPlaceholder.class); + + @Input + private ValueNode input; + + public HalfFloatPlaceholder(ValueNode input) { + super(TYPE, StampFactory.forKind(JavaKind.Short)); + this.input = input; + } + + public ValueNode getInput() { + return this.input; + } + + public void setInput(ValueNode input) { + this.input = input; + } + +} diff --git a/tornado-runtime/src/main/java/uk/ac/manchester/tornado/runtime/graal/nodes/MultHalfFloatNode.java b/tornado-runtime/src/main/java/uk/ac/manchester/tornado/runtime/graal/nodes/MultHalfFloatNode.java new file mode 100644 index 0000000000..81ea3f97cd --- /dev/null +++ b/tornado-runtime/src/main/java/uk/ac/manchester/tornado/runtime/graal/nodes/MultHalfFloatNode.java @@ -0,0 +1,56 @@ +/* + * Copyright (c) 2024, APT Group, Department of Computer Science, + * School of Engineering, The University of Manchester. All rights reserved. + * Copyright (c) 2009, 2017, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + */ +package uk.ac.manchester.tornado.runtime.graal.nodes; + +import jdk.vm.ci.meta.JavaKind; +import org.graalvm.compiler.core.common.type.StampFactory; +import org.graalvm.compiler.graph.Node; +import org.graalvm.compiler.graph.NodeClass; +import org.graalvm.compiler.nodeinfo.NodeInfo; +import org.graalvm.compiler.nodes.ValueNode; + +@NodeInfo(shortName = "FLOAT16(*)") +public class MultHalfFloatNode extends ValueNode { + + public static final NodeClass TYPE = NodeClass.create(MultHalfFloatNode.class); + + @Node.Input + ValueNode input1; + + @Node.Input + ValueNode input2; + + public MultHalfFloatNode(ValueNode input1, ValueNode input2) { + super(TYPE, StampFactory.forKind(JavaKind.Object)); + this.input1 = input1; + this.input2 = input2; + } + + public ValueNode getX() { + return input1; + } + + public ValueNode getY() { + return input2; + } + +} diff --git a/tornado-runtime/src/main/java/uk/ac/manchester/tornado/runtime/graal/nodes/NewHalfFloatInstance.java b/tornado-runtime/src/main/java/uk/ac/manchester/tornado/runtime/graal/nodes/NewHalfFloatInstance.java new file mode 100644 index 0000000000..4dc1496ed1 --- /dev/null +++ b/tornado-runtime/src/main/java/uk/ac/manchester/tornado/runtime/graal/nodes/NewHalfFloatInstance.java @@ -0,0 +1,47 @@ +/* + * Copyright (c) 2024, APT Group, Department of Computer Science, + * The University of Manchester. All rights reserved. + * Copyright (c) 2009, 2017, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + */ +package uk.ac.manchester.tornado.runtime.graal.nodes; + +import org.graalvm.compiler.core.common.type.StampFactory; +import org.graalvm.compiler.graph.NodeClass; +import org.graalvm.compiler.nodeinfo.NodeInfo; +import org.graalvm.compiler.nodes.FixedWithNextNode; +import jdk.vm.ci.meta.JavaKind; +import org.graalvm.compiler.nodes.ValueNode; + +@NodeInfo +public class NewHalfFloatInstance extends FixedWithNextNode { + + public static final NodeClass TYPE = NodeClass.create(NewHalfFloatInstance.class); + + @Input + private ValueNode value; + + public NewHalfFloatInstance(ValueNode value) { + super(TYPE, StampFactory.forKind(JavaKind.Short)); + this.value = value; + } + + public ValueNode getValue() { + return this.value; + } +} diff --git a/tornado-runtime/src/main/java/uk/ac/manchester/tornado/runtime/graal/nodes/SubHalfFloatNode.java b/tornado-runtime/src/main/java/uk/ac/manchester/tornado/runtime/graal/nodes/SubHalfFloatNode.java new file mode 100644 index 0000000000..a0576f1e28 --- /dev/null +++ b/tornado-runtime/src/main/java/uk/ac/manchester/tornado/runtime/graal/nodes/SubHalfFloatNode.java @@ -0,0 +1,54 @@ +/* + * Copyright (c) 2024, APT Group, Department of Computer Science, + * School of Engineering, The University of Manchester. All rights reserved. + * Copyright (c) 2009, 2017, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + */ +package uk.ac.manchester.tornado.runtime.graal.nodes; + +import jdk.vm.ci.meta.JavaKind; +import org.graalvm.compiler.core.common.type.StampFactory; +import org.graalvm.compiler.graph.Node; +import org.graalvm.compiler.graph.NodeClass; +import org.graalvm.compiler.nodeinfo.NodeInfo; +import org.graalvm.compiler.nodes.ValueNode; + +@NodeInfo(shortName = "FLOAT16(-)") +public class SubHalfFloatNode extends ValueNode { + public static final NodeClass TYPE = NodeClass.create(SubHalfFloatNode.class); + + @Node.Input + ValueNode input1; + + @Node.Input + ValueNode input2; + + public SubHalfFloatNode(ValueNode input1, ValueNode input2) { + super(TYPE, StampFactory.forKind(JavaKind.Object)); + this.input1 = input1; + this.input2 = input2; + } + + public ValueNode getX() { + return input1; + } + + public ValueNode getY() { + return input2; + } +} diff --git a/tornado-runtime/src/main/java/uk/ac/manchester/tornado/runtime/graal/phases/TornadoHalfFloatFixedGuardElimination.java b/tornado-runtime/src/main/java/uk/ac/manchester/tornado/runtime/graal/phases/TornadoHalfFloatFixedGuardElimination.java new file mode 100644 index 0000000000..e005addc49 --- /dev/null +++ b/tornado-runtime/src/main/java/uk/ac/manchester/tornado/runtime/graal/phases/TornadoHalfFloatFixedGuardElimination.java @@ -0,0 +1,78 @@ +/* + * Copyright (c) 2024, APT Group, Department of Computer Science, + * School of Engineering, The University of Manchester. All rights reserved. + * Copyright (c) 2009, 2017, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + */ +package uk.ac.manchester.tornado.runtime.graal.phases; + +import java.util.ArrayList; +import java.util.Optional; + +import org.graalvm.compiler.graph.Node; +import org.graalvm.compiler.nodes.FixedGuardNode; +import org.graalvm.compiler.nodes.GraphState; +import org.graalvm.compiler.nodes.PiNode; +import org.graalvm.compiler.nodes.StructuredGraph; +import org.graalvm.compiler.nodes.ValueNode; +import org.graalvm.compiler.phases.BasePhase; + +import uk.ac.manchester.tornado.runtime.graal.nodes.HalfFloatPlaceholder; + +public class TornadoHalfFloatFixedGuardElimination extends BasePhase { + + @Override + public Optional notApplicableTo(GraphState graphState) { + return ALWAYS_APPLICABLE; + } + + protected void run(StructuredGraph graph, TornadoSketchTierContext context) { + ArrayList nodesToBeDeleted = new ArrayList(); + for (HalfFloatPlaceholder placeholderNode : graph.getNodes().filter(HalfFloatPlaceholder.class)) { + if (placeholderNode.getInput() instanceof PiNode) { + PiNode placeholderInput = (PiNode) placeholderNode.getInput(); + ValueNode halfFloatValue = placeholderInput.object(); + FixedGuardNode placeholderGuard = (FixedGuardNode) placeholderInput.getGuard(); + deleteFixed(placeholderGuard); + placeholderNode.setInput(halfFloatValue); + nodesToBeDeleted.add(placeholderInput); + } + } + + for (ValueNode node : nodesToBeDeleted) { + node.safeDelete(); + } + } + + private static void deleteFixed(Node node) { + if (!node.isDeleted()) { + Node predecessor = node.predecessor(); + Node successor = node.successors().first(); + + node.replaceFirstSuccessor(successor, null); + node.replaceAtPredecessor(successor); + predecessor.replaceFirstSuccessor(node, successor); + + for (Node us : node.usages()) { + node.removeUsage(us); + } + node.clearInputs(); + node.safeDelete(); + } + } +} diff --git a/tornado-runtime/src/main/java/uk/ac/manchester/tornado/runtime/sketcher/TornadoSketcher.java b/tornado-runtime/src/main/java/uk/ac/manchester/tornado/runtime/sketcher/TornadoSketcher.java index f0e2e7a77e..1701a9fa14 100644 --- a/tornado-runtime/src/main/java/uk/ac/manchester/tornado/runtime/sketcher/TornadoSketcher.java +++ b/tornado-runtime/src/main/java/uk/ac/manchester/tornado/runtime/sketcher/TornadoSketcher.java @@ -2,7 +2,7 @@ * This file is part of Tornado: A heterogeneous programming framework: * https://github.com/beehive-lab/tornadovm * - * Copyright (c) 2020, 2023 APT Group, Department of Computer Science, + * Copyright (c) 2020, 2023, 2024, APT Group, Department of Computer Science, * School of Engineering, The University of Manchester. All rights reserved. * Copyright (c) 2013-2020, APT Group, Department of Computer Science, * The University of Manchester. All rights reserved. diff --git a/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/api/TestAPI.java b/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/api/TestAPI.java index c7521cb059..1e14d2fbfe 100644 --- a/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/api/TestAPI.java +++ b/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/api/TestAPI.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2013-2020, 2022 APT Group, Department of Computer Science, + * Copyright (c) 2013-2020, 2022, 2024, APT Group, Department of Computer Science, * The University of Manchester. * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -34,10 +34,12 @@ import uk.ac.manchester.tornado.api.TornadoExecutionPlan; import uk.ac.manchester.tornado.api.TornadoExecutionResult; import uk.ac.manchester.tornado.api.enums.DataTransferMode; +import uk.ac.manchester.tornado.api.types.HalfFloat; import uk.ac.manchester.tornado.api.types.arrays.ByteArray; import uk.ac.manchester.tornado.api.types.arrays.CharArray; import uk.ac.manchester.tornado.api.types.arrays.DoubleArray; import uk.ac.manchester.tornado.api.types.arrays.FloatArray; +import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray; import uk.ac.manchester.tornado.api.types.arrays.IntArray; import uk.ac.manchester.tornado.api.types.arrays.LongArray; import uk.ac.manchester.tornado.api.types.arrays.ShortArray; @@ -192,6 +194,25 @@ public void testSegmentsDouble() { /** * Perform the copy out under demand. It performs a copy from the device to the host of the entire array via the execution result. */ + @Test + public void testSegmentsHalfFloats() { + HalfFloatArray dataA = HalfFloatArray.fromElements(new HalfFloat(0), new HalfFloat(1), new HalfFloat(2), new HalfFloat(3)); + HalfFloatArray dataB = HalfFloatArray.fromArray(new HalfFloat[] { new HalfFloat(0), new HalfFloat(1), new HalfFloat(2), new HalfFloat(3) }); + + for (int i = 0; i < dataA.getSize(); i++) { + assertEquals(dataA.get(i).getFloat32(), dataB.get(i).getFloat32(), 0.01f); + } + HalfFloat[] fArray = dataA.toHeapArray(); + for (int i = 0; i < dataA.getSize(); i++) { + assertEquals(fArray[i].getFloat32(), dataA.get(i).getFloat32(), 0.01f); + } + + HalfFloat[] fArrayB = dataB.toHeapArray(); + for (int i = 0; i < dataA.getSize(); i++) { + assertEquals(fArrayB[i].getFloat32(), dataB.get(i).getFloat32(), 0.01f); + } + } + @Test public void testLazyCopyOut() { final int N = 1024; @@ -459,5 +480,26 @@ public void testBuildWithSegmentsChar() { assertEquals((char) 10 + i, charArray.get(i)); } } + + @Test + public void testBuildWithSegmentsHalfFloat() { + + final int n = 10; + // Allocate 10 elements + MemorySegment m = Arena.ofAuto().allocate(ValueLayout.JAVA_SHORT.byteSize() * n); + + // Set 10 elements + for (int i = 0; i < n; i++) { + m.setAtIndex(ValueLayout.JAVA_SHORT, i, Float.floatToFloat16(10 + i)); + } + + // Factory method to build a float array from a segment + HalfFloatArray halfFloatArray = HalfFloatArray.fromSegment(m); + + for (int i = 0; i < n; i++) { + assertEquals(10 + i, halfFloatArray.get(i).getFloat32(), 0.001f); + } + } + // CHECKSTYLE:ON } diff --git a/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/arrays/TestArrays.java b/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/arrays/TestArrays.java index 7ce5367b09..4753928d9f 100644 --- a/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/arrays/TestArrays.java +++ b/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/arrays/TestArrays.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2013-2022, APT Group, Department of Computer Science, + * Copyright (c) 2013-2022, 2024, APT Group, Department of Computer Science, * The University of Manchester. * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -35,10 +35,12 @@ import uk.ac.manchester.tornado.api.WorkerGrid1D; import uk.ac.manchester.tornado.api.annotations.Parallel; import uk.ac.manchester.tornado.api.enums.DataTransferMode; +import uk.ac.manchester.tornado.api.types.HalfFloat; import uk.ac.manchester.tornado.api.types.arrays.ByteArray; import uk.ac.manchester.tornado.api.types.arrays.CharArray; import uk.ac.manchester.tornado.api.types.arrays.DoubleArray; import uk.ac.manchester.tornado.api.types.arrays.FloatArray; +import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray; import uk.ac.manchester.tornado.api.types.arrays.IntArray; import uk.ac.manchester.tornado.api.types.arrays.LongArray; import uk.ac.manchester.tornado.api.types.arrays.ShortArray; @@ -115,6 +117,36 @@ public static void addChars(CharArray a, IntArray b) { } } + public static void initHalfFloatVector(HalfFloatArray c) { + for (@Parallel int i = 0; i < c.getSize(); i++) { + c.set(i, new HalfFloat(100.0f)); + } + } + + public static void vectorAddHalfFloat(HalfFloatArray a, HalfFloatArray b, HalfFloatArray c) { + for (@Parallel int i = 0; i < c.getSize(); i++) { + c.set(i, HalfFloat.add(a.get(i), b.get(i))); + } + } + + public static void vectorSubHalfFloat(HalfFloatArray a, HalfFloatArray b, HalfFloatArray c) { + for (@Parallel int i = 0; i < c.getSize(); i++) { + c.set(i, HalfFloat.sub(a.get(i), b.get(i))); + } + } + + public static void vectorMultHalfFloat(HalfFloatArray a, HalfFloatArray b, HalfFloatArray c) { + for (@Parallel int i = 0; i < c.getSize(); i++) { + c.set(i, HalfFloat.mult(a.get(i), b.get(i))); + } + } + + public static void vectorDivHalfFloat(HalfFloatArray a, HalfFloatArray b, HalfFloatArray c) { + for (@Parallel int i = 0; i < c.getSize(); i++) { + c.set(i, HalfFloat.div(a.get(i), b.get(i))); + } + } + public static void initializeSequentialByte(ByteArray a) { for (int i = 0; i < a.getSize(); i++) { a.set(i, (byte) 21); @@ -484,6 +516,115 @@ public void testVectorBytes() { } } + @Test + public void testHalfFloatInitialization() { + final int numElements = 4096; + HalfFloatArray c = new HalfFloatArray(numElements); + + TaskGraph taskGraph = new TaskGraph("s0") // + .task("t0", TestArrays::initHalfFloatVector, c) // + .transferToHost(DataTransferMode.EVERY_EXECUTION, c); + + ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot(); + TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(immutableTaskGraph); + executionPlan.execute(); + + for (int i = 0; i < c.getSize(); i++) { + assertEquals(100.0f, c.get(i).getFloat32(), 0.01f); + } + } + + @Test + public void testVectorAdditionHalfFloat() { + final int numElements = 4096; + HalfFloatArray a = new HalfFloatArray(numElements); + HalfFloatArray b = new HalfFloatArray(numElements); + HalfFloatArray c = new HalfFloatArray(numElements); + a.init(new HalfFloat(6.0f)); + b.init(new HalfFloat(2.0f)); + + TaskGraph taskGraph = new TaskGraph("s0") // + .transferToDevice(DataTransferMode.EVERY_EXECUTION, a, b) // + .task("t0", TestArrays::vectorAddHalfFloat, a, b, c) // + .transferToHost(DataTransferMode.EVERY_EXECUTION, c); + + ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot(); + TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(immutableTaskGraph); + executionPlan.execute(); + + for (int i = 0; i < c.getSize(); i++) { + assertEquals(8.0f, c.get(i).getFloat32(), 0.01f); + } + } + + @Test + public void testVectorSubtractionHalfFloat() { + final int numElements = 4096; + HalfFloatArray a = new HalfFloatArray(numElements); + HalfFloatArray b = new HalfFloatArray(numElements); + HalfFloatArray c = new HalfFloatArray(numElements); + a.init(new HalfFloat(6.0f)); + b.init(new HalfFloat(2.0f)); + TaskGraph taskGraph = new TaskGraph("s0") // + .transferToDevice(DataTransferMode.EVERY_EXECUTION, a, b) // + .task("t0", TestArrays::vectorSubHalfFloat, a, b, c) // + .transferToHost(DataTransferMode.EVERY_EXECUTION, c); + + ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot(); + TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(immutableTaskGraph); + executionPlan.execute(); + + for (int i = 0; i < c.getSize(); i++) { + assertEquals(4.0f, c.get(i).getFloat32(), 0.01f); + } + } + + @Test + public void testVectorMultiplicationHalfFloat() { + final int numElements = 4096; + HalfFloatArray a = new HalfFloatArray(numElements); + HalfFloatArray b = new HalfFloatArray(numElements); + HalfFloatArray c = new HalfFloatArray(numElements); + a.init(new HalfFloat(6.0f)); + b.init(new HalfFloat(2.0f)); + + TaskGraph taskGraph = new TaskGraph("s0") // + .transferToDevice(DataTransferMode.EVERY_EXECUTION, a, b) // + .task("t0", TestArrays::vectorMultHalfFloat, a, b, c) // + .transferToHost(DataTransferMode.EVERY_EXECUTION, c); + + ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot(); + TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(immutableTaskGraph); + executionPlan.execute(); + + for (int i = 0; i < c.getSize(); i++) { + assertEquals(12.0f, c.get(i).getFloat32(), 0.01f); + } + } + + @Test + public void testVectorDivisionHalfFloat() { + final int numElements = 4096; + HalfFloatArray a = new HalfFloatArray(numElements); + HalfFloatArray b = new HalfFloatArray(numElements); + HalfFloatArray c = new HalfFloatArray(numElements); + a.init(new HalfFloat(6.0f)); + b.init(new HalfFloat(2.0f)); + + TaskGraph taskGraph = new TaskGraph("s0") // + .transferToDevice(DataTransferMode.EVERY_EXECUTION, a, b) // + .task("t0", TestArrays::vectorDivHalfFloat, a, b, c) // + .transferToHost(DataTransferMode.EVERY_EXECUTION, c); + + ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot(); + TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(immutableTaskGraph); + executionPlan.execute(); + + for (int i = 0; i < c.getSize(); i++) { + assertEquals(3.0f, c.get(i).getFloat32(), 0.01f); + } + } + /** * Inspired by the CUDA Hello World from Computer Graphics: * diff --git a/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/vm/concurrency/TestParallelTaskGraph.java b/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/vm/concurrency/TestParallelTaskGraph.java index f27bc0d60d..92f86d99d2 100644 --- a/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/vm/concurrency/TestParallelTaskGraph.java +++ b/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/vm/concurrency/TestParallelTaskGraph.java @@ -38,7 +38,7 @@ /** * *

- * How to test?: + * How to test? This test requires at least two devices. *

* *