Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support half precision floating point types #313

Merged
merged 22 commits into from
Jan 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
cd86683
Create new half-float type and array
mairooni Dec 14, 2023
7ccb7bf
Include unittests to check the current functionality of half floats
mairooni Dec 14, 2023
bf4fb52
Merge branch 'develop' into feat/float16
mairooni Dec 14, 2023
81c043b
Prototype for half float addition
mairooni Dec 21, 2023
b90cb07
Merge branch 'develop' into feat/float16
mairooni Dec 21, 2023
1c9889d
Support the creation of a new HalfFloat instance in the kernel
mairooni Jan 8, 2024
eec336e
A node that represents the creation of a new halfnode instance
mairooni Jan 8, 2024
68bdd10
Add support for subtraction, multiplication and division between half…
mairooni Jan 8, 2024
be02782
[WIP] Support float-16 types for SPIRV
mairooni Jan 9, 2024
d0d6f18
Remove unused method
mairooni Jan 9, 2024
f784ce1
Support half-float type for ptx backend
mairooni Jan 18, 2024
31e38a7
Merge branch 'develop' into feat/float16
mairooni Jan 18, 2024
8d7eec4
Merge branch 'develop' into feat/float16
mairooni Jan 22, 2024
fccdbec
Merge branch 'develop' into feat/float16
mairooni Jan 22, 2024
45ae694
[half][spirv] Initial support for FP16 (Half) for SPIR-V
jjfumero Jan 22, 2024
53481f4
Merge remote-tracking branch 'refs/remotes/origin/feat/float16' into …
jjfumero Jan 22, 2024
64fdd76
Update license headers with the latest date
mairooni Jan 23, 2024
95a7557
[spirv][half] Initialization supported
jjfumero Jan 23, 2024
7a2f41d
[spirv] Toolkit dependency updated
jjfumero Jan 23, 2024
7b80f92
Merge remote-tracking branch 'refs/remotes/origin/feat/float16' into …
jjfumero Jan 23, 2024
805669e
Query the opencl extensions to check if fp16 is supported
mairooni Jan 24, 2024
f2579d1
Merge branch 'develop' into feat/float16
mairooni Jan 24, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion bin/compile
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ def build_spirv_toolkit_and_level_zero(rebuild=False):

if (rebuild or build):
os.chdir(spirv_tool_kit)
subprocess.run(["git", "pull", "origin", "master"])
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this needs to be reverted

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We keep this until we merge the code from the Beehive SPIR-V Toolkit

subprocess.run(["git", "pull", "origin", "feat/half"])
subprocess.run(["mvn", "clean", "package"])
subprocess.run(["mvn", "install"])
os.chdir(current)
Expand Down
1 change: 1 addition & 0 deletions tornado-api/src/main/java/module-info.java
Original file line number Diff line number Diff line change
Expand Up @@ -45,4 +45,5 @@
opens uk.ac.manchester.tornado.api.types.volumes;
exports uk.ac.manchester.tornado.api.types.vectors;
opens uk.ac.manchester.tornado.api.types.vectors;
exports uk.ac.manchester.tornado.api.types;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,190 @@
/*
* Copyright (c) 2024, APT Group, Department of Computer Science,
* The University of Manchester.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package uk.ac.manchester.tornado.api.types;

/**
* This class represents a float-16 instance (half float). The data is stored in a short field, to be
* compliant with the representation for float-16 used in the {@link Float} class. The class encapsulates
* methods for getting the data in float-16 and float-32 format, and for basic arithmetic operations (i.e.
* addition, subtraction, multiplication and division).
*/
public class HalfFloat {

private short halfFloatValue;

/**
* Constructs a new instance of the {@code HalfFloat} out of a float value.
* To convert the float to a float-16, the floatToFloat16 function of the {@link Float}
* class is used.
*
* @param halfFloat
* The float value that will be stored in a half-float format.
*/
public HalfFloat(float halfFloat) {
this.halfFloatValue = Float.floatToFloat16(halfFloat);
}

/**
* Constructs a new instance of the {@code HalfFloat} with a given short value.
*
* @param halfFloat
* The short value that represents the half float.
*/
public HalfFloat(short halfFloat) {
this.halfFloatValue = halfFloat;
}

/**
* Gets the half-float stored in the class.
*
* @return The half float value stored in the {@code HalfFloat} object.
*/
public short getHalfFloatValue() {
return this.halfFloatValue;
}

/**
* Gets the half-float stored in the class in a 32-bit representation.
*
* @return The float-32 equivalent value the half float stored in the {@code HalfFloat} object.
*/
public float getFloat32() {
return Float.float16ToFloat(halfFloatValue);
}

/**
* Takes two half float values, converts them to a 32-bit representation and performs an addition.
*
* @param a
* The first float-16 input for the addition.
* @param b
* The second float-16 input for the addition.
* @return The result of the addition.
*/
private static float addHalfFloat(short a, short b) {
float floatA = Float.float16ToFloat(a);
float floatB = Float.float16ToFloat(b);
return floatA + floatB;
}

/**
* Takes two {@code HalfFloat} objects and returns a new {@HalfFloat} instance
* that contains the results of the addition.
*
* @param a
* The first {@code HalfFloat} input for the addition.
* @param b
* The second {@code HalfFloat} input for the addition.
* @return A new {@HalfFloat} containing the results of the addition.
*/
public static HalfFloat add(HalfFloat a, HalfFloat b) {
float result = addHalfFloat(a.getHalfFloatValue(), b.getHalfFloatValue());
return new HalfFloat(result);
}

/**
* Takes two half float values, converts them to a 32-bit representation and performs a subtraction.
*
* @param a
* The first float-16 input for the subtraction.
* @param b
* The second float-16 input for the subtraction.
* @return The result of the subtraction.
*/
private static float subHalfFloat(short a, short b) {
float floatA = Float.float16ToFloat(a);
float floatB = Float.float16ToFloat(b);
return floatA - floatB;
}

/**
* Takes two {@code HalfFloat} objects and returns a new {@HalfFloat} instance
* that contains the results of the subtraction.
*
* @param a
* The first {@code HalfFloat} input for the subtraction.
* @param b
* The second {@code HalfFloat} input for the subtraction.
* @return A new {@HalfFloat} containing the results of the subtraction.
*/
public static HalfFloat sub(HalfFloat a, HalfFloat b) {
float result = subHalfFloat(a.getHalfFloatValue(), b.getHalfFloatValue());
return new HalfFloat(result);
}

/**
* Takes two half float values, converts them to a 32-bit representation and performs a multiplication.
*
* @param a
* The first float-16 input for the multiplication.
* @param b
* The second float-16 input for the multiplication.
* @return The result of the multiplication.
*/
private static float multHalfFloat(short a, short b) {
float floatA = Float.float16ToFloat(a);
float floatB = Float.float16ToFloat(b);
return floatA * floatB;
}

/**
* Takes two {@code HalfFloat} objects and returns a new {@HalfFloat} instance
* that contains the results of the multiplication.
*
* @param a
* The first {@code HalfFloat} input for the multiplication.
* @param b
* The second {@code HalfFloat} input for the multiplication.
* @return A new {@HalfFloat} containing the results of the multiplication.
*/
public static HalfFloat mult(HalfFloat a, HalfFloat b) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: any reason we called it mult and mul

float result = multHalfFloat(a.getHalfFloatValue(), b.getHalfFloatValue());
return new HalfFloat(result);
}

/**
* Takes two half float values, converts them to a 32-bit representation and performs a division.
*
* @param a
* The first float-16 input for the division.
* @param b
* The second float-16 input for the division.
* @return The result of the division.
*/
private static float divHalfFloat(short a, short b) {
float floatA = Float.float16ToFloat(a);
float floatB = Float.float16ToFloat(b);
return floatA / floatB;
}

/**
* Takes two {@code HalfFloat} objects and returns a new {@HalfFloat} instance
* that contains the results of the division.
*
* @param a
* The first {@code HalfFloat} input for the division.
* @param b
* The second {@code HalfFloat} input for the division.
* @return A new {@HalfFloat} containing the results of the division.
*/
public static HalfFloat div(HalfFloat a, HalfFloat b) {
float result = divHalfFloat(a.getHalfFloatValue(), b.getHalfFloatValue());
return new HalfFloat(result);
}

}
Loading