diff --git a/bin/compile b/bin/compile index 7f7f9c3aa0..56695d0f1e 100755 --- a/bin/compile +++ b/bin/compile @@ -191,7 +191,7 @@ def build_spirv_toolkit_and_level_zero(rebuild=False): if (rebuild or build): os.chdir(spirv_tool_kit) - subprocess.run(["git", "pull", "origin", "master"]) + subprocess.run(["git", "pull", "origin", "feat/half"]) subprocess.run(["mvn", "clean", "package"]) subprocess.run(["mvn", "install"]) os.chdir(current) diff --git a/tornado-api/src/main/java/module-info.java b/tornado-api/src/main/java/module-info.java index 420554a6fc..da0e5cb170 100644 --- a/tornado-api/src/main/java/module-info.java +++ b/tornado-api/src/main/java/module-info.java @@ -45,4 +45,5 @@ opens uk.ac.manchester.tornado.api.types.volumes; exports uk.ac.manchester.tornado.api.types.vectors; opens uk.ac.manchester.tornado.api.types.vectors; + exports uk.ac.manchester.tornado.api.types; } diff --git a/tornado-api/src/main/java/uk/ac/manchester/tornado/api/types/HalfFloat.java b/tornado-api/src/main/java/uk/ac/manchester/tornado/api/types/HalfFloat.java new file mode 100644 index 0000000000..79ed55a7b6 --- /dev/null +++ b/tornado-api/src/main/java/uk/ac/manchester/tornado/api/types/HalfFloat.java @@ -0,0 +1,190 @@ +/* + * Copyright (c) 2024, APT Group, Department of Computer Science, + * The University of Manchester. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ +package uk.ac.manchester.tornado.api.types; + +/** + * This class represents a float-16 instance (half float). The data is stored in a short field, to be + * compliant with the representation for float-16 used in the {@link Float} class. The class encapsulates + * methods for getting the data in float-16 and float-32 format, and for basic arithmetic operations (i.e. + * addition, subtraction, multiplication and division). + */ +public class HalfFloat { + + private short halfFloatValue; + + /** + * Constructs a new instance of the {@code HalfFloat} out of a float value. + * To convert the float to a float-16, the floatToFloat16 function of the {@link Float} + * class is used. + * + * @param halfFloat + * The float value that will be stored in a half-float format. + */ + public HalfFloat(float halfFloat) { + this.halfFloatValue = Float.floatToFloat16(halfFloat); + } + + /** + * Constructs a new instance of the {@code HalfFloat} with a given short value. + * + * @param halfFloat + * The short value that represents the half float. + */ + public HalfFloat(short halfFloat) { + this.halfFloatValue = halfFloat; + } + + /** + * Gets the half-float stored in the class. + * + * @return The half float value stored in the {@code HalfFloat} object. + */ + public short getHalfFloatValue() { + return this.halfFloatValue; + } + + /** + * Gets the half-float stored in the class in a 32-bit representation. + * + * @return The float-32 equivalent value the half float stored in the {@code HalfFloat} object. + */ + public float getFloat32() { + return Float.float16ToFloat(halfFloatValue); + } + + /** + * Takes two half float values, converts them to a 32-bit representation and performs an addition. + * + * @param a + * The first float-16 input for the addition. + * @param b + * The second float-16 input for the addition. + * @return The result of the addition. 
+ */ + private static float addHalfFloat(short a, short b) { + float floatA = Float.float16ToFloat(a); + float floatB = Float.float16ToFloat(b); + return floatA + floatB; + } + + /** + * Takes two {@code HalfFloat} objects and returns a new {@code HalfFloat} instance + * that contains the results of the addition. + * + * @param a + * The first {@code HalfFloat} input for the addition. + * @param b + * The second {@code HalfFloat} input for the addition. + * @return A new {@code HalfFloat} containing the results of the addition. + */ + public static HalfFloat add(HalfFloat a, HalfFloat b) { + float result = addHalfFloat(a.getHalfFloatValue(), b.getHalfFloatValue()); + return new HalfFloat(result); + } + + /** + * Takes two half float values, converts them to a 32-bit representation and performs a subtraction. + * + * @param a + * The first float-16 input for the subtraction. + * @param b + * The second float-16 input for the subtraction. + * @return The result of the subtraction. + */ + private static float subHalfFloat(short a, short b) { + float floatA = Float.float16ToFloat(a); + float floatB = Float.float16ToFloat(b); + return floatA - floatB; + } + + /** + * Takes two {@code HalfFloat} objects and returns a new {@code HalfFloat} instance + * that contains the results of the subtraction. + * + * @param a + * The first {@code HalfFloat} input for the subtraction. + * @param b + * The second {@code HalfFloat} input for the subtraction. + * @return A new {@code HalfFloat} containing the results of the subtraction. + */ + public static HalfFloat sub(HalfFloat a, HalfFloat b) { + float result = subHalfFloat(a.getHalfFloatValue(), b.getHalfFloatValue()); + return new HalfFloat(result); + } + + /** + * Takes two half float values, converts them to a 32-bit representation and performs a multiplication. + * + * @param a + * The first float-16 input for the multiplication. + * @param b + * The second float-16 input for the multiplication. + * @return The result of the multiplication. 
+ */ + private static float multHalfFloat(short a, short b) { + float floatA = Float.float16ToFloat(a); + float floatB = Float.float16ToFloat(b); + return floatA * floatB; + } + + /** + * Takes two {@code HalfFloat} objects and returns a new {@code HalfFloat} instance + * that contains the results of the multiplication. + * + * @param a + * The first {@code HalfFloat} input for the multiplication. + * @param b + * The second {@code HalfFloat} input for the multiplication. + * @return A new {@code HalfFloat} containing the results of the multiplication. + */ + public static HalfFloat mult(HalfFloat a, HalfFloat b) { + float result = multHalfFloat(a.getHalfFloatValue(), b.getHalfFloatValue()); + return new HalfFloat(result); + } + + /** + * Takes two half float values, converts them to a 32-bit representation and performs a division. + * + * @param a + * The first float-16 input for the division. + * @param b + * The second float-16 input for the division. + * @return The result of the division. + */ + private static float divHalfFloat(short a, short b) { + float floatA = Float.float16ToFloat(a); + float floatB = Float.float16ToFloat(b); + return floatA / floatB; + } + + /** + * Takes two {@code HalfFloat} objects and returns a new {@code HalfFloat} instance + * that contains the results of the division. + * + * @param a + * The first {@code HalfFloat} input for the division. + * @param b + * The second {@code HalfFloat} input for the division. + * @return A new {@code HalfFloat} containing the results of the division. 
+ */ + public static HalfFloat div(HalfFloat a, HalfFloat b) { + float result = divHalfFloat(a.getHalfFloatValue(), b.getHalfFloatValue()); + return new HalfFloat(result); + } + +} diff --git a/tornado-api/src/main/java/uk/ac/manchester/tornado/api/types/arrays/HalfFloatArray.java b/tornado-api/src/main/java/uk/ac/manchester/tornado/api/types/arrays/HalfFloatArray.java new file mode 100644 index 0000000000..d45e200ac1 --- /dev/null +++ b/tornado-api/src/main/java/uk/ac/manchester/tornado/api/types/arrays/HalfFloatArray.java @@ -0,0 +1,221 @@ +/* + * Copyright (c) 2024, APT Group, Department of Computer Science, + * The University of Manchester. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ +package uk.ac.manchester.tornado.api.types.arrays; + +import static java.lang.foreign.ValueLayout.JAVA_INT; +import static java.lang.foreign.ValueLayout.JAVA_SHORT; + +import java.lang.foreign.Arena; +import java.lang.foreign.MemorySegment; + +import uk.ac.manchester.tornado.api.internal.annotations.SegmentElementSize; +import uk.ac.manchester.tornado.api.types.HalfFloat; + +/** + * This class represents an array of half floats (float16 types) stored in native memory. + * The half float data is stored in a {@link MemorySegment}, which represents a contiguous region of off-heap memory. 
+ * The class also encapsulates methods for setting and getting half float values, + * for initializing the half float array, and for converting the array to and from different representations. + */ +@SegmentElementSize(size = 2) +public final class HalfFloatArray extends TornadoNativeArray { + + private static final int HALF_FLOAT_BYTES = 2; + private MemorySegment segment; + + private int numberOfElements; + + private int arrayHeaderSize; + + private int baseIndex; + + private long segmentByteSize; + + /** + * Constructs a new instance of the {@code HalfFloatArray} that will store a user-specified number of elements. + * + * @param numberOfElements + * The number of elements in the array. + */ + public HalfFloatArray(int numberOfElements) { + this.numberOfElements = numberOfElements; + arrayHeaderSize = (int) TornadoNativeArray.ARRAY_HEADER; + baseIndex = arrayHeaderSize / HALF_FLOAT_BYTES; + segmentByteSize = (long) numberOfElements * HALF_FLOAT_BYTES + arrayHeaderSize; // widen to long before multiplying so large element counts cannot overflow int arithmetic + + segment = Arena.ofAuto().allocate(segmentByteSize, 1); + segment.setAtIndex(JAVA_INT, 0, numberOfElements); + } + + /** + * Internal method used to create a new instance of the {@code HalfFloatArray} from on-heap data. + * + * @param values + * The on-heap {@link HalfFloat} array to create the instance from. + * @return A new {@code HalfFloatArray} instance, initialized with values of the on-heap {@link HalfFloat} array. + */ + private static HalfFloatArray createSegment(HalfFloat[] values) { + HalfFloatArray array = new HalfFloatArray(values.length); + for (int i = 0; i < values.length; i++) { + array.set(i, values[i]); + } + return array; + } + + /** + * Creates a new instance of the {@code HalfFloatArray} class from an on-heap {@link HalfFloat} array. + * + * @param values + * The on-heap {@link HalfFloat} array to create the instance from. + * @return A new {@code HalfFloatArray} instance, initialized with values of the on-heap {@link HalfFloat} array. 
+ */ + public static HalfFloatArray fromArray(HalfFloat[] values) { + return createSegment(values); + } + + /** + * Creates a new instance of the {@code HalfFloatArray} class from a set of {@link HalfFloat} values. + * + * @param values + * The {@link HalfFloat} values to initialize the array with. + * @return A new {@code HalfFloatArray} instance, initialized with the given values. + */ + public static HalfFloatArray fromElements(HalfFloat... values) { + return createSegment(values); + } + + /** + * Creates a new instance of the {@code HalfFloatArray} class from a {@link MemorySegment}. + * + * @param segment + * The {@link MemorySegment} containing the off-heap half float data. + * @return A new {@code HalfFloatArray} instance, initialized with the segment data. + */ + public static HalfFloatArray fromSegment(MemorySegment segment) { + long byteSize = segment.byteSize(); + int numElements = (int) (byteSize / HALF_FLOAT_BYTES); + HalfFloatArray halfFloatArray = new HalfFloatArray(numElements); + MemorySegment.copy(segment, 0, halfFloatArray.segment, halfFloatArray.baseIndex * HALF_FLOAT_BYTES, byteSize); + return halfFloatArray; + } + + /** + * Converts the {@link HalfFloat} data from off-heap to on-heap, by copying the values of a {@code HalfFloatArray} + * instance into a new on-heap {@link HalfFloat} array. + * + * @return A new on-heap {@link HalfFloat} array, initialized with the values stored in the {@code HalfFloatArray} instance. + */ + public HalfFloat[] toHeapArray() { + HalfFloat[] outputArray = new HalfFloat[getSize()]; + for (int i = 0; i < getSize(); i++) { + outputArray[i] = get(i); + } + return outputArray; + } + + /** + * Sets the {@link HalfFloat} value at a specified index of the {@code HalfFloatArray} instance. + * + * @param index + * The index at which to set the {@link HalfFloat} value. + * @param value + * The {@link HalfFloat} value to store at the specified index. 
+ */ + public void set(int index, HalfFloat value) { + segment.setAtIndex(JAVA_SHORT, baseIndex + index, value.getHalfFloatValue()); + } + + /** + * Gets the {@link HalfFloat} value stored at the specified index of the {@code HalfFloatArray} instance. + * + * @param index + * The index from which to retrieve the {@link HalfFloat} value. + * @return A new {@link HalfFloat} wrapping the 16-bit value read at the specified index. + */ + public HalfFloat get(int index) { + short halfFloatValue = segment.getAtIndex(JAVA_SHORT, baseIndex + index); + return new HalfFloat(halfFloatValue); + } + + /** + * Sets all the values of the {@code HalfFloatArray} instance to zero. + */ + @Override + public void clear() { + init(new HalfFloat(0.0f)); + } + + @Override + public int getElementSize() { + return HALF_FLOAT_BYTES; + } + + /** + * Initializes all the elements of the {@code HalfFloatArray} instance with a specified value. + * + * @param value + * The {@link HalfFloat} value to initialize the {@code HalfFloatArray} instance with. + */ + public void init(HalfFloat value) { + for (int i = 0; i < getSize(); i++) { + segment.setAtIndex(JAVA_SHORT, baseIndex + i, value.getHalfFloatValue()); + } + } + + /** + * Returns the number of half float elements stored in the {@code HalfFloatArray} instance. + * + * @return The number of elements in the array. + */ + @Override + public int getSize() { + return numberOfElements; + } + + /** + * Returns the underlying {@link MemorySegment} of the {@code HalfFloatArray} instance. + * + * @return The {@link MemorySegment} associated with the {@code HalfFloatArray} instance. + */ + @Override + public MemorySegment getSegment() { + return segment; + } + + /** + * Returns the total number of bytes that the {@link MemorySegment}, associated with the {@code HalfFloatArray} instance, occupies. + * + * @return The total number of bytes of the {@link MemorySegment}. 
+ */ + @Override + public long getNumBytesOfSegment() { + return segmentByteSize; + } + + /** + * Returns the number of bytes of the {@link MemorySegment} that is associated with the {@code HalfFloatArray} instance, + * excluding the header bytes. + * + * @return The number of bytes of the raw data in the {@link MemorySegment}. + */ + @Override + public long getNumBytesWithoutHeader() { + return segmentByteSize - TornadoNativeArray.ARRAY_HEADER; + } + +} diff --git a/tornado-api/src/main/java/uk/ac/manchester/tornado/api/types/arrays/TornadoNativeArray.java b/tornado-api/src/main/java/uk/ac/manchester/tornado/api/types/arrays/TornadoNativeArray.java index 31f619a128..e7b424c499 100644 --- a/tornado-api/src/main/java/uk/ac/manchester/tornado/api/types/arrays/TornadoNativeArray.java +++ b/tornado-api/src/main/java/uk/ac/manchester/tornado/api/types/arrays/TornadoNativeArray.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2013-2023, APT Group, Department of Computer Science, + * Copyright (c) 2013-2024, APT Group, Department of Computer Science, * The University of Manchester. * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -33,9 +33,7 @@ * The constant {@code ARRAY_HEADER} represents the size of the header in bytes. *
*/ -public abstract sealed class TornadoNativeArray permits // - IntArray, FloatArray, DoubleArray, LongArray, ShortArray, // - ByteArray, CharArray { +public abstract sealed class TornadoNativeArray permits ByteArray, CharArray, DoubleArray, FloatArray, IntArray, LongArray, ShortArray, HalfFloatArray { /** * The size of the header in bytes. The default value is 24, but it can be configurable through diff --git a/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/OCLTargetDescription.java b/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/OCLTargetDescription.java index c1a54cec00..63040d0664 100644 --- a/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/OCLTargetDescription.java +++ b/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/OCLTargetDescription.java @@ -54,6 +54,8 @@ public class OCLTargetDescription extends TargetDescription { private final String extensions; private final boolean supportsInt64Atomics; + private final boolean supportsF16; + public OCLTargetDescription(Architecture arch, boolean supportsFP64, String extensions) { this(arch, false, STACK_ALIGNMENT, IMPLICIT_NULL_CHECK_LIMIT, INLINE_OBJECTS, supportsFP64, extensions); } @@ -63,6 +65,7 @@ protected OCLTargetDescription(Architecture arch, boolean isMP, int stackAlignme this.supportsFP64 = supportsFP64; this.extensions = extensions; supportsInt64Atomics = extensions.contains("cl_khr_int64_base_atomics"); + supportsF16 = extensions.contains("cl_khr_fp16"); } //@formatter:on @@ -92,6 +95,10 @@ public boolean supportsFP64() { return supportsFP64; } + public boolean supportsFP16() { + return supportsF16; + } + public boolean supportsInt64Atomics() { return supportsInt64Atomics; } diff --git a/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/asm/OCLAssembler.java 
b/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/asm/OCLAssembler.java index 1d9760f7a7..281e0e06c3 100644 --- a/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/asm/OCLAssembler.java +++ b/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/asm/OCLAssembler.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018, 2022-2023, APT Group, Department of Computer Science, + * Copyright (c) 2018, 2022-2024, APT Group, Department of Computer Science, * The University of Manchester. All rights reserved. * Copyright (c) 2009, 2012, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. @@ -74,6 +74,8 @@ public OCLAssembler(TargetDescription target) { emitLine("#pragma OPENCL EXTENSION cl_khr_fp64 : enable "); } + emitLine("#pragma OPENCL EXTENSION cl_khr_fp16 : enable "); + if (((OCLTargetDescription) target).supportsInt64Atomics()) { emitLine("#pragma OPENCL EXTENSION cl_khr_int64_base_atomics : enable "); } diff --git a/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/asm/OCLVariablePrefix.java b/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/asm/OCLVariablePrefix.java index 9d0b1f59a1..609fa312de 100644 --- a/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/asm/OCLVariablePrefix.java +++ b/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/asm/OCLVariablePrefix.java @@ -2,7 +2,7 @@ * This file is part of Tornado: A heterogeneous programming framework: * https://github.com/beehive-lab/tornadovm * - * Copyright (c) 2023, APT Group, Department of Computer Science, + * Copyright (c) 2023, 2024, APT Group, Department of Computer Science, * The University of Manchester. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. 
* @@ -67,7 +67,9 @@ public enum OCLVariablePrefix { BYTE16("byte16", "b16_"), SHORT("short", "sh_"), SHORT2("short2", "sh2_"), - SHORT3("short3", "sh3_"); + SHORT3("short3", "sh3_"), + HALF("half", "half_"); + // @formatter:on diff --git a/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/compiler/OCLHighTier.java b/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/compiler/OCLHighTier.java index bc129e17e4..8011b10d36 100644 --- a/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/compiler/OCLHighTier.java +++ b/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/compiler/OCLHighTier.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020, APT Group, Department of Computer Science, + * Copyright (c) 2020, 2024, APT Group, Department of Computer Science, * School of Engineering, The University of Manchester. All rights reserved. * Copyright (c) 2018, 2020, APT Group, Department of Computer Science, * The University of Manchester. All rights reserved. 
@@ -50,6 +50,7 @@ import uk.ac.manchester.tornado.drivers.common.compiler.phases.memalloc.TornadoLocalMemoryAllocation; import uk.ac.manchester.tornado.drivers.common.compiler.phases.memalloc.TornadoNewArrayDevirtualizationReplacement; import uk.ac.manchester.tornado.drivers.common.compiler.phases.memalloc.TornadoPrivateArrayPiRemoval; +import uk.ac.manchester.tornado.drivers.opencl.graal.phases.TornadoHalfFloatReplacement; import uk.ac.manchester.tornado.drivers.opencl.graal.phases.TornadoOpenCLIntrinsicsReplacements; import uk.ac.manchester.tornado.drivers.opencl.graal.phases.TornadoParallelScheduler; import uk.ac.manchester.tornado.drivers.opencl.graal.phases.TornadoTaskSpecialisation; @@ -86,6 +87,8 @@ public OCLHighTier(OptionValues options, TornadoDeviceContext deviceContext, Can appendPhase(new TornadoNewArrayDevirtualizationReplacement()); + appendPhase(new TornadoHalfFloatReplacement()); + if (PartialEscapeAnalysis.getValue(options)) { appendPhase(new PartialEscapePhase(true, canonicalizer, options)); } diff --git a/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/compiler/plugins/OCLGraphBuilderPlugins.java b/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/compiler/plugins/OCLGraphBuilderPlugins.java index 697547444d..ae52ccad44 100644 --- a/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/compiler/plugins/OCLGraphBuilderPlugins.java +++ b/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/compiler/plugins/OCLGraphBuilderPlugins.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, APT Group, Department of Computer Science, + * Copyright (c) 2022, 2024 APT Group, Department of Computer Science, * School of Engineering, The University of Manchester. All rights reserved. * Copyright (c) 2018, 2020, APT Group, Department of Computer Science, * The University of Manchester. All rights reserved. 
@@ -116,6 +116,8 @@ public static void registerInvocationPlugins(final Plugins ps, final InvocationP // Register TornadoAtomicInteger registerTornadoAtomicInteger(ps, plugins); + OCLHalfFloatPlugins.registerPlugins(ps, plugins); + registerMemoryAccessPlugins(plugins); } diff --git a/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/compiler/plugins/OCLHalfFloatPlugins.java b/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/compiler/plugins/OCLHalfFloatPlugins.java new file mode 100644 index 0000000000..b9bd8b83b0 --- /dev/null +++ b/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/compiler/plugins/OCLHalfFloatPlugins.java @@ -0,0 +1,109 @@ +/* + * Copyright (c) 2024, APT Group, Department of Computer Science, + * The University of Manchester. All rights reserved. + * Copyright (c) 2009, 2017, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. 
+ * + */ +package uk.ac.manchester.tornado.drivers.opencl.graal.compiler.plugins; + +import org.graalvm.compiler.nodes.ValueNode; +import org.graalvm.compiler.nodes.graphbuilderconf.GraphBuilderConfiguration; +import org.graalvm.compiler.nodes.graphbuilderconf.GraphBuilderContext; +import org.graalvm.compiler.nodes.graphbuilderconf.InvocationPlugin; +import org.graalvm.compiler.nodes.graphbuilderconf.InvocationPlugins; +import org.graalvm.compiler.nodes.graphbuilderconf.NodePlugin; + +import jdk.vm.ci.meta.JavaKind; +import jdk.vm.ci.meta.ResolvedJavaMethod; +import uk.ac.manchester.tornado.api.types.HalfFloat; +import uk.ac.manchester.tornado.runtime.graal.nodes.AddHalfFloatNode; +import uk.ac.manchester.tornado.runtime.graal.nodes.DivHalfFloatNode; +import uk.ac.manchester.tornado.runtime.graal.nodes.MultHalfFloatNode; +import uk.ac.manchester.tornado.runtime.graal.nodes.SubHalfFloatNode; +import uk.ac.manchester.tornado.runtime.graal.nodes.HalfFloatPlaceholder; +import uk.ac.manchester.tornado.runtime.graal.nodes.NewHalfFloatInstance; + +public class OCLHalfFloatPlugins { + + public static void registerPlugins(final GraphBuilderConfiguration.Plugins ps, final InvocationPlugins plugins) { + registerHalfFloatInit(ps, plugins); + } + + private static void registerHalfFloatInit(GraphBuilderConfiguration.Plugins ps, InvocationPlugins plugins) { + + final InvocationPlugins.Registration r = new InvocationPlugins.Registration(plugins, HalfFloat.class); + + ps.appendNodePlugin(new NodePlugin() { + @Override + public boolean handleInvoke(GraphBuilderContext b, ResolvedJavaMethod method, ValueNode[] args) { + if (method.getName().equals("
- * Map<%returnType, Map>>
+ * Map<%returnType, Map>>
*
*
* If we have the same number of parameters with the same return type, when we
@@ -400,9 +402,9 @@ private SPIRVId createNewFunctionAndUpdateTables(SPIRVId returnType, SPIRVId...
* parameter (stored in the {@link FunctionTable ) class).
*
* @param returnType
- * ID with the return value.
+ * ID with the return value.
* @param operands
- * List of IDs for the operads.
+ * List of IDs for the operands.
* @return A {@link SPIRVId} for the {@link SPIRVOpFunction}
*/
public SPIRVId emitOpTypeFunction(SPIRVId returnType, SPIRVId... operands) {
@@ -445,7 +447,7 @@ public SPIRVId emitOpTypeFunction(SPIRVId returnType, SPIRVId... operands) {
return functionSignature;
}
- public void emitEntryPointMainKernel(StructuredGraph graph, String kernelName, boolean fp64Capability) {
+ public void emitEntryPointMainKernel(StructuredGraph graph, String kernelName, boolean fp64Capability, boolean fp16Capability) {
mainFunctionID = module.getNextId();
SPIRVMultipleOperands operands;
@@ -480,7 +482,7 @@ public void emitEntryPointMainKernel(StructuredGraph graph, String kernelName, b
operands = new SPIRVMultipleOperands(array);
}
- if (fp64Capability) {
+ if (fp64Capability && fp16Capability) {
module.add(new SPIRVOpExecutionMode(mainFunctionID, SPIRVExecutionMode.ContractionOff()));
}
@@ -545,23 +547,13 @@ public SPIRVId emitConstantValue(SPIRVKind type, String valueConstant) {
SPIRVId newConstantId = module.getNextId();
SPIRVId typeID = primitives.getTypePrimitive(type);
switch (type) {
- case OP_TYPE_INT_8:
- case OP_TYPE_INT_16:
- case OP_TYPE_INT_32:
- module.add(new SPIRVOpConstant(typeID, newConstantId, new SPIRVContextDependentInt(BigInteger.valueOf(Integer.parseInt(valueConstant)))));
- break;
- case OP_TYPE_INT_64:
- module.add(new SPIRVOpConstant(typeID, newConstantId, new SPIRVContextDependentLong(BigInteger.valueOf(Integer.parseInt(valueConstant)))));
- break;
- case OP_TYPE_FLOAT_16:
- case OP_TYPE_FLOAT_32:
- module.add(new SPIRVOpConstant(typeID, newConstantId, new SPIRVContextDependentFloat(Float.parseFloat(valueConstant))));
- break;
- case OP_TYPE_FLOAT_64:
- module.add(new SPIRVOpConstant(typeID, newConstantId, new SPIRVContextDependentDouble(Double.parseDouble(valueConstant))));
- break;
- default:
- throw new RuntimeException("Data type not supported yet: " + type);
+ case OP_TYPE_INT_8, OP_TYPE_INT_16, OP_TYPE_INT_32 -> module.add(new SPIRVOpConstant(typeID, newConstantId, new SPIRVContextDependentInt(BigInteger.valueOf(Integer.parseInt(
+ valueConstant)))));
+ case OP_TYPE_INT_64 -> module.add(new SPIRVOpConstant(typeID, newConstantId, new SPIRVContextDependentLong(BigInteger.valueOf(Long.parseLong(valueConstant))))); // parsed as long: a 64-bit constant may not fit in an int
+ case OP_TYPE_FLOAT_16 -> module.add(new SPIRVOpConstant(typeID, newConstantId, new SPIRVContextDependentHalfFloat(Float.floatToFloat16(Float.parseFloat(valueConstant)))));
+ case OP_TYPE_FLOAT_32 -> module.add(new SPIRVOpConstant(typeID, newConstantId, new SPIRVContextDependentFloat(Float.parseFloat(valueConstant))));
+ case OP_TYPE_FLOAT_64 -> module.add(new SPIRVOpConstant(typeID, newConstantId, new SPIRVContextDependentDouble(Double.parseDouble(valueConstant))));
+ default -> throw new TornadoRuntimeException(STR."Data type not supported yet: \{type}");
}
return newConstantId;
}
@@ -634,8 +626,8 @@ public void emitValue(SPIRVCompilationResultBuilder crb, Value value) {
}
public void emitValueOrOp(SPIRVCompilationResultBuilder crb, Value value) {
- if (value instanceof SPIRVLIROp) {
- ((SPIRVLIROp) value).emit(crb, this);
+ if (value instanceof SPIRVLIROp spirvValue) {
+ spirvValue.emit(crb, this);
} else {
emitValue(crb, value);
}
diff --git a/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/compiler/SPIRVHighTier.java b/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/compiler/SPIRVHighTier.java
index 60fbbce6be..e2d8e5c775 100644
--- a/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/compiler/SPIRVHighTier.java
+++ b/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/compiler/SPIRVHighTier.java
@@ -2,7 +2,7 @@
* This file is part of Tornado: A heterogeneous programming framework:
* https://github.com/beehive-lab/tornadovm
*
- * Copyright (c) 2021, APT Group, Department of Computer Science,
+ * Copyright (c) 2021, 2024, APT Group, Department of Computer Science,
* School of Engineering, The University of Manchester. All rights reserved.
* Copyright (c) 2009-2021, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
@@ -52,6 +52,7 @@
import uk.ac.manchester.tornado.drivers.common.compiler.phases.memalloc.TornadoLocalMemoryAllocation;
import uk.ac.manchester.tornado.drivers.common.compiler.phases.memalloc.TornadoNewArrayDevirtualizationReplacement;
import uk.ac.manchester.tornado.drivers.common.compiler.phases.memalloc.TornadoPrivateArrayPiRemoval;
+import uk.ac.manchester.tornado.drivers.spirv.graal.phases.TornadoHalfFloatReplacement;
import uk.ac.manchester.tornado.drivers.spirv.graal.phases.TornadoParallelScheduler;
import uk.ac.manchester.tornado.drivers.spirv.graal.phases.TornadoSPIRVIntrinsicsReplacements;
import uk.ac.manchester.tornado.drivers.spirv.graal.phases.TornadoTaskSpecialization;
@@ -88,6 +89,8 @@ public SPIRVHighTier(OptionValues options, TornadoDeviceContext deviceContext, C
appendPhase(new TornadoNewArrayDevirtualizationReplacement());
+ appendPhase(new TornadoHalfFloatReplacement());
+
if (PartialEscapeAnalysis.getValue(options)) {
appendPhase(new PartialEscapePhase(true, canonicalizer, options));
}
diff --git a/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/compiler/lir/SPIRVArithmeticTool.java b/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/compiler/lir/SPIRVArithmeticTool.java
index ef77c9a9a1..9d676d14b2 100644
--- a/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/compiler/lir/SPIRVArithmeticTool.java
+++ b/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/compiler/lir/SPIRVArithmeticTool.java
@@ -2,7 +2,7 @@
* This file is part of Tornado: A heterogeneous programming framework:
* https://github.com/beehive-lab/tornadovm
*
- * Copyright (c) 2021, APT Group, Department of Computer Science,
+ * Copyright (c) 2021, 2024, APT Group, Department of Computer Science,
* School of Engineering, The University of Manchester. All rights reserved.
* Copyright (c) 2009-2021, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
@@ -13,7 +13,7 @@
*
* This code is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
- * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
+ * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
* version 2 for more details (a copy is included in the LICENSE file that
* accompanied this code).
*
@@ -119,6 +119,7 @@ protected Variable emitAdd(LIRKind resultKind, Value a, Value b, boolean setFlag
case OP_TYPE_INT_32:
binaryOp = SPIRVBinaryOp.ADD_INTEGER;
break;
+ case OP_TYPE_FLOAT_16:
case OP_TYPE_FLOAT_32:
case OP_TYPE_FLOAT_64:
binaryOp = SPIRVBinaryOp.ADD_FLOAT;
@@ -158,6 +159,7 @@ protected Variable emitSub(LIRKind resultKind, Value a, Value b, boolean setFlag
break;
case OP_TYPE_FLOAT_64:
case OP_TYPE_FLOAT_32:
+ case OP_TYPE_FLOAT_16:
binaryOp = SPIRVBinaryOp.SUB_FLOAT;
break;
default:
@@ -204,6 +206,7 @@ public Value emitMul(Value a, Value b, boolean setFlags) {
break;
case OP_TYPE_FLOAT_64:
case OP_TYPE_FLOAT_32:
+ case OP_TYPE_FLOAT_16:
binaryOp = SPIRVBinaryOp.MULT_FLOAT;
break;
default:
@@ -251,6 +254,7 @@ public Value emitDiv(Value a, Value b, LIRFrameState state) {
break;
case OP_TYPE_FLOAT_64:
case OP_TYPE_FLOAT_32:
+ case OP_TYPE_FLOAT_16:
binaryOp = SPIRVBinaryOp.DIV_FLOAT;
break;
default:
diff --git a/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/compiler/plugins/SPIRVGraphBuilderPlugins.java b/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/compiler/plugins/SPIRVGraphBuilderPlugins.java
index 925fea9791..a90b55415d 100644
--- a/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/compiler/plugins/SPIRVGraphBuilderPlugins.java
+++ b/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/compiler/plugins/SPIRVGraphBuilderPlugins.java
@@ -2,7 +2,7 @@
* This file is part of Tornado: A heterogeneous programming framework:
* https://github.com/beehive-lab/tornadovm
*
- * Copyright (c) 2021-2022, APT Group, Department of Computer Science,
+ * Copyright (c) 2021-2022, 2024, APT Group, Department of Computer Science,
* School of Engineering, The University of Manchester. All rights reserved.
* Copyright (c) 2009-2021, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
@@ -109,6 +109,7 @@ public static void registerInvocationPlugins(Plugins plugins, final InvocationPl
SPIRVMathPlugins.registerTornadoMathPlugins(invocationPlugins);
SPIRVVectorPlugins.registerPlugins(plugins, invocationPlugins);
+ SPIRVHalfFloatPlugins.registerPlugins(plugins, invocationPlugins);
// Register plugins for Off-Heap Arrays with Panama
registerMemoryAccessPlugins(invocationPlugins);
}
diff --git a/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/compiler/plugins/SPIRVHalfFloatPlugins.java b/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/compiler/plugins/SPIRVHalfFloatPlugins.java
new file mode 100644
index 0000000000..96c59e9d7a
--- /dev/null
+++ b/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/compiler/plugins/SPIRVHalfFloatPlugins.java
@@ -0,0 +1,112 @@
+/*
+ * This file is part of Tornado: A heterogeneous programming framework:
+ * https://github.com/beehive-lab/tornadovm
+ *
+ * Copyright (c) 2024, APT Group, Department of Computer Science,
+ * School of Engineering, The University of Manchester. All rights reserved.
+ * Copyright (c) 2009-2021, Oracle and/or its affiliates. All rights reserved.
+ * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
+ *
+ * This code is free software; you can redistribute it and/or modify it
+ * under the terms of the GNU General Public License version 2 only, as
+ * published by the Free Software Foundation.
+ *
+ * This code is distributed in the hope that it will be useful, but WITHOUT
+ * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
+ * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
+ * version 2 for more details (a copy is included in the LICENSE file that
+ * accompanied this code).
+ *
+ * You should have received a copy of the GNU General Public License version
+ * 2 along with this work; if not, write to the Free Software Foundation,
+ * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
+ *
+ */
+package uk.ac.manchester.tornado.drivers.spirv.graal.compiler.plugins;
+
+import org.graalvm.compiler.nodes.ValueNode;
+import org.graalvm.compiler.nodes.graphbuilderconf.GraphBuilderConfiguration;
+import org.graalvm.compiler.nodes.graphbuilderconf.GraphBuilderContext;
+import org.graalvm.compiler.nodes.graphbuilderconf.InvocationPlugin;
+import org.graalvm.compiler.nodes.graphbuilderconf.InvocationPlugins;
+import org.graalvm.compiler.nodes.graphbuilderconf.NodePlugin;
+
+import jdk.vm.ci.meta.JavaKind;
+import jdk.vm.ci.meta.ResolvedJavaMethod;
+import uk.ac.manchester.tornado.api.types.HalfFloat;
+import uk.ac.manchester.tornado.runtime.graal.nodes.AddHalfFloatNode;
+import uk.ac.manchester.tornado.runtime.graal.nodes.DivHalfFloatNode;
+import uk.ac.manchester.tornado.runtime.graal.nodes.HalfFloatPlaceholder;
+import uk.ac.manchester.tornado.runtime.graal.nodes.MultHalfFloatNode;
+import uk.ac.manchester.tornado.runtime.graal.nodes.NewHalfFloatInstance;
+import uk.ac.manchester.tornado.runtime.graal.nodes.SubHalfFloatNode;
+
+public class SPIRVHalfFloatPlugins {
+
+ public static void registerPlugins(final GraphBuilderConfiguration.Plugins ps, final InvocationPlugins plugins) {
+ registerHalfFloatInit(ps, plugins);
+ }
+
+ private static void registerHalfFloatInit(GraphBuilderConfiguration.Plugins ps, InvocationPlugins plugins) {
+
+ final InvocationPlugins.Registration r = new InvocationPlugins.Registration(plugins, HalfFloat.class);
+
+ ps.appendNodePlugin(new NodePlugin() {
+ @Override
+ public boolean handleInvoke(GraphBuilderContext b, ResolvedJavaMethod method, ValueNode[] args) {
+ if (method.getName().equals("
- * %43 = OpLoad %_ptr_CrossWorkgroup_ulong %frame Aligned 8
- * %44 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_ulong %43 %uint_3
- * %45 = OpLoad %ulong %44 Aligned 8
- * OpStore %spirv_l_0F0 %45 Aligned 8
- * %46 = OpLoad %ulong %spirv_l_0F0 Aligned 8
- * %48 = OpConvertUToPtr %_ptr_CrossWorkgroup_float %46
- * %49 = OpExtInst %v2float %1 vloadn %ulong_0 %48 2
- * OpStore %spirv_v2f_1F0 %49 Aligned 8
+ * %43 = OpLoad %_ptr_CrossWorkgroup_ulong %frame Aligned 8
+ * %44 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_ulong %43 %uint_3
+ * %45 = OpLoad %ulong %44 Aligned 8
+ * OpStore %spirv_l_0F0 %45 Aligned 8
+ * %46 = OpLoad %ulong %spirv_l_0F0 Aligned 8
+ * %48 = OpConvertUToPtr %_ptr_CrossWorkgroup_float %46
+ * %49 = OpExtInst %v2float %1 vloadn %ulong_0 %48 2
+ * OpStore %spirv_v2f_1F0 %49 Aligned 8
*
*
* @param crb
- * {@link SPIRVCompilationResultBuilder}
+ * {@link SPIRVCompilationResultBuilder}
* @param asm
- * {@link SPIRVAssembler}
+ * {@link SPIRVAssembler}
*/
@Override
public void emitCode(SPIRVCompilationResultBuilder crb, SPIRVAssembler asm) {
@@ -1040,12 +1040,19 @@ public StoreStmt(SPIRVAddressCast cast, MemoryAccess address, Value rhs) {
@Override
protected void emitCode(SPIRVCompilationResultBuilder crb, SPIRVAssembler asm) {
+ Logger.traceCodeGen(Logger.BACKEND.SPIRV, "emit StoreStmt in address: " + cast + " <- " + rhs);
cast.emit(crb, asm);
+ boolean isFP16Cast = cast.getLIRKind().getPlatformKind() == SPIRVKind.OP_TYPE_FLOAT_16;
+
SPIRVId value;
if (rhs instanceof ConstantValue) {
- value = asm.lookUpConstant(((ConstantValue) this.rhs).getConstant().toValueString(), (SPIRVKind) rhs.getPlatformKind());
+ if (isFP16Cast) {
+ value = asm.lookUpConstant(((ConstantValue) this.rhs).getConstant().toValueString(), SPIRVKind.OP_TYPE_FLOAT_16);
+ } else {
+ value = asm.lookUpConstant(((ConstantValue) this.rhs).getConstant().toValueString(), (SPIRVKind) rhs.getPlatformKind());
+ }
} else {
value = asm.lookUpLIRInstructions(rhs);
if (TornadoOptions.OPTIMIZE_LOAD_STORE_SPIRV) {
@@ -1478,8 +1485,8 @@ public IndexedLoadMemAccess(SPIRVUnary.MemoryIndexedAccess address, AllocatableV
@Override
protected void emitCode(SPIRVCompilationResultBuilder crb, SPIRVAssembler asm) {
- Logger.traceCodeGen(Logger.BACKEND.SPIRV,
- "emit IndexedLoadMemAccess in address: " + address + "[ " + address.getIndex() + "] -- region: " + address.getMemoryRegion().getMemorySpace().getName());
+ Logger.traceCodeGen(Logger.BACKEND.SPIRV, "emit IndexedLoadMemAccess in address: " + address + "[ " + address.getIndex() + "] -- region: " + address.getMemoryRegion().getMemorySpace()
+ .getName());
SPIRVKind spirvKind = (SPIRVKind) result.getPlatformKind();
SPIRVId type = asm.primitives.getTypePrimitive(spirvKind);
@@ -1525,22 +1532,22 @@ public IndexedLoadMemCollectionAccess(SPIRVUnary.MemoryIndexedAccess address, Al
* (e.g., VectorFloat2)
*
*
- * %301 = OpInBoundsPtrAccessChain %_ptr_Function_double %spirv_d_1F0 %ulong_0 %ulong_0
- * %302 = OpPtrCastToGeneric %_ptr_Generic_double %301
- * %303 = OpExtInst %v2double %1 vloadn %ulong_0 %302 2
- * OpStore %spirv_v2d_50F0 %303 Aligned 16
+ * %301 = OpInBoundsPtrAccessChain %_ptr_Function_double %spirv_d_1F0 %ulong_0 %ulong_0
+ * %302 = OpPtrCastToGeneric %_ptr_Generic_double %301
+ * %303 = OpExtInst %v2double %1 vloadn %ulong_0 %302 2
+ * OpStore %spirv_v2d_50F0 %303 Aligned 16
*
*
* @param crb
- * {@link SPIRVCompilationResultBuilder}
+ * {@link SPIRVCompilationResultBuilder}
* @param asm
- * {@link SPIRVAssembler}
+ * {@link SPIRVAssembler}
*/
@Override
protected void emitCode(SPIRVCompilationResultBuilder crb, SPIRVAssembler asm) {
- Logger.traceCodeGen(Logger.BACKEND.SPIRV,
- "emit IndexedLoadMemCollectionAccess in address: " + address + "[ " + address.getIndex() + "] -- region: " + address.getMemoryRegion().getMemorySpace().getName());
+ Logger.traceCodeGen(Logger.BACKEND.SPIRV, "emit IndexedLoadMemCollectionAccess in address: " + address + "[ " + address.getIndex() + "] -- region: " + address.getMemoryRegion()
+ .getMemorySpace().getName());
SPIRVKind spirvKind = (SPIRVKind) result.getPlatformKind();
diff --git a/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/lir/SPIRVUnary.java b/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/lir/SPIRVUnary.java
index 383b23180d..2ee45f4984 100644
--- a/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/lir/SPIRVUnary.java
+++ b/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/lir/SPIRVUnary.java
@@ -2,7 +2,7 @@
* This file is part of Tornado: A heterogeneous programming framework:
* https://github.com/beehive-lab/tornadovm
*
- * Copyright (c) 2021-2023, APT Group, Department of Computer Science,
+ * Copyright (c) 2021-2024, APT Group, Department of Computer Science,
* School of Engineering, The University of Manchester. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
@@ -534,13 +534,10 @@ public void emitForLoad(SPIRVAssembler asm, SPIRVKind resultKind) {
public static class SPIRVAddressCast extends UnaryConsumer {
- private final SPIRVMemoryBase base;
-
private final Value address;
public SPIRVAddressCast(Value address, SPIRVMemoryBase base, LIRKind valueKind) {
super(null, null, valueKind, address);
- this.base = base;
this.address = address;
}
@@ -562,6 +559,7 @@ public void emit(SPIRVCompilationResultBuilder crb, SPIRVAssembler asm) {
final SPIRVId idLoad;
SPIRVId addressId = asm.lookUpLIRInstructions(this.address);
+
if (TornadoOptions.OPTIMIZE_LOAD_STORE_SPIRV) {
idLoad = addressId;
} else {
@@ -604,9 +602,7 @@ public ThreadBuiltinCallForSPIRV(SPIRVThreadBuiltIn builtIn, Variable result, LI
*
*
*
- * %37 = OpLoad %v3ulong %__spirv_BuiltInGlobalInvocationId Aligned 32
- * %call = OpCompositeExtract %ulong %37 0
- * %conv = OpUConvert %uint %call
+ * %37 = OpLoad %v3ulong %__spirv_BuiltInGlobalInvocationId Aligned 32 %call = OpCompositeExtract %ulong %37 0 %conv = OpUConvert %uint %call
*
*/
@Override
@@ -740,15 +736,13 @@ public SignNarrowValue(LIRKind lirKind, Variable result, Value inputVal, int toB
}
/**
- * Following this:
- * {@url https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#OpSConvert}
+ * Following this: <a href="https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#OpSConvert">OpSConvert</a>
*
*
* Convert signed width. This is either a truncate or a sign extend.
*
*
- * OpSConvert can be used for sign extend as well as truncate. The "S" symbol
- * represents signed format.
+ * OpSConvert can be used for sign extend as well as truncate. The "S" symbol represents signed format.
*
* @param crb
* {@link SPIRVCompilationResultBuilder}
@@ -932,14 +926,11 @@ public void emit(SPIRVCompilationResultBuilder crb, SPIRVAssembler asm) {
}
/**
- * OpenCL Extended Instruction Set Intrinsics. As specified in the SPIR-V 1.0
- * standard, the following intrinsics in SPIR-V represents builtin functions
- * from the OpenCL standard.
+ * OpenCL Extended Instruction Set Intrinsics. As specified in the SPIR-V 1.0 standard, the following intrinsics in SPIR-V represent built-in functions from the OpenCL standard.
*
* For obtaining the correct Int-Reference of the function:
*
*
- * How to test?:
+ * How to test? This test requires at least two devices.
*