Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce Tensor API, Tensor Utilities and compatiblity with ONNX RT #369

Merged
merged 77 commits into from
Apr 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
77 commits
Select commit Hold shift + click to select a range
effe843
Add Tensor and DType classes, and tensor testing
mikepapadim Feb 27, 2024
87e90ad
Refactor DataType enum into several class files
mikepapadim Feb 29, 2024
e704149
wip
mikepapadim Feb 29, 2024
4950843
add tensor
mikepapadim Feb 29, 2024
546c15b
Remove Tensor classes and update DType subclasses
mikepapadim Mar 5, 2024
0b393f9
Refactor TensorArray class to Tensor and update its methods
mikepapadim Mar 5, 2024
c2593e3
Rename HalfFloat class to HF and update references
mikepapadim Mar 5, 2024
367a6dd
Update Tensor to support HalfFloat type
mikepapadim Mar 5, 2024
3f05f59
Update Tensor to support HalfFloat type
mikepapadim Mar 5, 2024
96fc95c
Update Tensor to support HalfFloat type
mikepapadim Mar 5, 2024
6a6f6fa
Refactored HalfFloat handling and enabled Tensor support
mikepapadim Mar 5, 2024
650a51e
Update HalfFloat and Tensor handling across components
mikepapadim Mar 6, 2024
e73a7cd
WIP
mikepapadim Mar 6, 2024
00948c9
Refactor and removed verbose logs from OCLTornadoDevice and OCLObject…
mikepapadim Mar 6, 2024
41a01f6
Refactor TornadoDataflowAnalysis code for clarity and efficiency
mikepapadim Mar 6, 2024
baa9acc
Refactor guard elimination in TornadoHalfFloatFixedGuardElimination
mikepapadim Mar 6, 2024
dc3dabd
Remove array creation methods in DType
mikepapadim Mar 6, 2024
d1efa26
Add tensor addition test and remove redundancy in TestTensorTypes
mikepapadim Mar 6, 2024
9620c05
Add tensor addition test and remove redundancy in TestTensorTypes
mikepapadim Mar 6, 2024
91c1d87
Refactor code block in OCLHotSpotBackendFactory.java
mikepapadim Mar 6, 2024
5f18365
Remove calculateSize method from DType.java
mikepapadim Mar 6, 2024
f36373d
Add Tensor creation from array and improve unit tests
mikepapadim Mar 11, 2024
086de25
Merge branch 'mikepapadim/segment_slice' into mikepapadim/tensors_v0.5
mikepapadim Mar 11, 2024
5b6aabc
Add detailed comments and improve documentation in Tensor modules
mikepapadim Mar 13, 2024
6ba50f8
Refactor tensor unit tests and add random data population
mikepapadim Mar 13, 2024
4520e84
Refactor DType.java and enhance documentation
mikepapadim Mar 13, 2024
9e2af7d
Add tensor constructor and float buffer conversion.
mikepapadim Mar 17, 2024
4655bd5
Fix merge conflicts
mikepapadim Mar 25, 2024
58d8216
Add tensor types and adjust base class permissions
mikepapadim Mar 25, 2024
f32c116
Update and extend tensor types for data handling.
mikepapadim Mar 25, 2024
a6eccba
Added AbstractTensor interface to Tornado API.
mikepapadim Mar 25, 2024
799973f
Refactored Tornado API with AbstractTensor interface and updated clas…
mikepapadim Mar 26, 2024
ebe51cf
Add Apache license header to tensor API classes
mikepapadim Mar 26, 2024
3ab369c
Update tensor classes and remove debug prints.
mikepapadim Mar 26, 2024
b3fe88b
Expand tensor classes with functions and attributes.
mikepapadim Mar 26, 2024
5f10c5a
Refactor TornadoNativeArray class for readability.
mikepapadim Mar 26, 2024
ce2ffcd
Add onnxruntime dependency to unittests module
mikepapadim Mar 27, 2024
a2f091b
Add TestTensorAPIWithOnnx class in tornado-unittests
mikepapadim Mar 27, 2024
117790d
Modify Shape class to use long instead of int
mikepapadim Mar 28, 2024
7fa2853
Add toHeapArray and getFloatBuffer methods to TensorFloat32
mikepapadim Mar 28, 2024
a0d8ec4
Refactored TestTensorAPIWithOnnx unit test
mikepapadim Mar 28, 2024
49102c4
Refactored the TestTensorAPIWithOnnx unit
mikepapadim Mar 28, 2024
1601180
Remove unused code in TornadoNativeTypeElimination
mikepapadim Mar 29, 2024
81994ff
Update model path in TestTensorAPIWithOnnx.
mikepapadim Mar 29, 2024
bd24d3e
Merge branch 'develop' of github.com:beehive-lab/TornadoVM into feat/…
mikepapadim Mar 29, 2024
0b4d1c1
Apply PR comments, first iteration
mikepapadim Apr 3, 2024
cf5842c
Refactor tensor API test to assert string outputs
mikepapadim Apr 3, 2024
f8da4ee
Add tensor byte addition test to TensorTypes unit tests
mikepapadim Apr 3, 2024
ab5f0fd
Apply formatter
mikepapadim Apr 3, 2024
dfd35f9
Refactor tensor class names and remove unused imports
mikepapadim Apr 4, 2024
5dbe798
Refactored tensor classes to extend from the Tensor class instead of …
mikepapadim Apr 4, 2024
23fc46e
Refactor TornadoNativeArray class to add permits on Tensor
mikepapadim Apr 4, 2024
de3b57d
Update TensorFloat variable names to TensorFP
mikepapadim Apr 4, 2024
4af314c
Add new class Tensor as part of Tornado API
mikepapadim Apr 4, 2024
45ede41
Improve Onnx compatibility test by adding model downloading.
mikepapadim Apr 4, 2024
ae5aeba
Update Tensor API test and add new ones in TornadoVM
mikepapadim Apr 4, 2024
09a8c01
Merge remote-tracking branch 'origin' into feat/tensors_api_v0.1
mikepapadim Apr 4, 2024
c717299
Add concat function to various Tensor classes in TornadoVM
mikepapadim Apr 4, 2024
c92746c
Add initialize method to TensorFP32 and TensorFP64 classes
mikepapadim Apr 4, 2024
337ddbe
Add buffer conversion methods to tensor classes
mikepapadim Apr 4, 2024
ab580f4
Refactor Shape class and add documentation.
mikepapadim Apr 4, 2024
5a29282
Remove binary of ONNX model
mikepapadim Apr 4, 2024
1192798
Update Tensor classes and remove model cleanup function
mikepapadim Apr 4, 2024
f53e037
Update Tensor constructors to call superclass constructor
mikepapadim Apr 4, 2024
6d132dd
Update Tensor class to abstract
mikepapadim Apr 5, 2024
fb6416a
Remove isNull node from FixedGuardNode
mairooni Apr 5, 2024
efa781d
Merge branch 'feat/tensors_api_v0.1' of github.com:mikepapadim/Tornad…
mairooni Apr 5, 2024
ff3189c
add fix for tensorfp16
mikepapadim Apr 5, 2024
7033276
Remove TensorFP16 handling in OCLTornadoDevice
mikepapadim Apr 8, 2024
2cd3e6a
Update tornado-unittests/src/main/java/uk/ac/manchester/tornado/unitt…
mikepapadim Apr 8, 2024
43ce7d1
Update tornado-unittests/src/main/java/uk/ac/manchester/tornado/unitt…
mikepapadim Apr 8, 2024
f2e295a
Update tornado-unittests/src/main/java/uk/ac/manchester/tornado/unitt…
mikepapadim Apr 8, 2024
cec52f9
Update tornado-unittests/src/main/java/uk/ac/manchester/tornado/unitt…
mikepapadim Apr 8, 2024
42688dd
Remove TensorFP16 handling in OCLTornadoDevice
mikepapadim Apr 8, 2024
68c10c5
Update Tensor import
mikepapadim Apr 8, 2024
ed4c6d9
Update Tensor import
mikepapadim Apr 9, 2024
86ecf73
Update Tensor class
mikepapadim Apr 9, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -2002,6 +2002,11 @@
</build>
<dependencyManagement>
<dependencies>
<dependency>
<groupId>com.microsoft.onnxruntime</groupId>
<artifactId>onnxruntime</artifactId>
<version>1.17.1</version>
</dependency>
<dependency>
<groupId>com.puppycrawl.tools</groupId>
<artifactId>checkstyle</artifactId>
Expand Down
2 changes: 2 additions & 0 deletions tornado-api/src/main/java/module-info.java
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@
exports uk.ac.manchester.tornado.api.types.vectors;
opens uk.ac.manchester.tornado.api.types.vectors;
exports uk.ac.manchester.tornado.api.types;
exports uk.ac.manchester.tornado.api.types.tensors;
opens uk.ac.manchester.tornado.api.types.tensors;
opens uk.ac.manchester.tornado.api.types;
opens uk.ac.manchester.tornado.api.runtime;
}
Original file line number Diff line number Diff line change
Expand Up @@ -49,24 +49,6 @@ public HalfFloat(short halfFloat) {
this.halfFloatValue = halfFloat;
}

/**
* Gets the half-float stored in the class.
*
* @return The half float value stored in the {@code HalfFloat} object.
*/
public short getHalfFloatValue() {
return this.halfFloatValue;
}

/**
* Gets the half-float stored in the class in a 32-bit representation.
*
* @return The float-32 equivalent value the half float stored in the {@code HalfFloat} object.
*/
public float getFloat32() {
return Float.float16ToFloat(halfFloatValue);
}

/**
* Takes two half float values, converts them to a 32-bit representation and performs an addition.
*
Expand Down Expand Up @@ -187,4 +169,27 @@ public static HalfFloat div(HalfFloat a, HalfFloat b) {
return new HalfFloat(result);
}

/**
* Gets the half-float stored in the class.
*
* @return The half float value stored in the {@code HalfFloat} object.
*/
public short getHalfFloatValue() {
return this.halfFloatValue;
}

/**
* Gets the half-float stored in the class in a 32-bit representation.
*
* @return The float-32 equivalent value the half float stored in the {@code HalfFloat} object.
*/
public float getFloat32() {
return Float.float16ToFloat(halfFloatValue);
}

@Override
public String toString() {
return String.format("HalfFloat: %.4f", getFloat32());
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ public static HalfFloatArray fromSegment(MemorySegment segment) {
long byteSize = segment.byteSize();
int numElements = (int) (byteSize / HALF_FLOAT_BYTES);
HalfFloatArray halfFloatArray = new HalfFloatArray(numElements);
MemorySegment.copy(segment, 0, halfFloatArray.segment, halfFloatArray.baseIndex * HALF_FLOAT_BYTES, byteSize);
MemorySegment.copy(segment, 0, halfFloatArray.segment, (long) halfFloatArray.baseIndex * HALF_FLOAT_BYTES, byteSize);
return halfFloatArray;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@

import java.lang.foreign.MemorySegment;

import uk.ac.manchester.tornado.api.types.tensors.Tensor;

/**
* This abstract sealed class represents the common functionality of the TornadoVM custom native arrays,
* (e.g., {@link ByteArray}, {@link IntArray}, etc.)
Expand All @@ -33,7 +35,10 @@
* The constant {@link ARRAY_HEADER} represents the size of the header in bytes.
* </p>
*/
public abstract sealed class TornadoNativeArray permits ByteArray, CharArray, DoubleArray, FloatArray, IntArray, LongArray, ShortArray, HalfFloatArray {
public abstract sealed class TornadoNativeArray //
permits ByteArray, CharArray, DoubleArray, //
FloatArray, HalfFloatArray, IntArray, //
LongArray, ShortArray, Tensor {

/**
* The size of the header in bytes. The default value is 24, but it can be configurable through
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
/*
* Copyright (c) 2024, APT Group, Department of Computer Science,
* The University of Manchester.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package uk.ac.manchester.tornado.api.types.tensors;

import java.lang.foreign.ValueLayout;

/**
* The {@code DType} enum represents the various data types can be found in models.
*/
public enum DType {
mikepapadim marked this conversation as resolved.
Show resolved Hide resolved
// @formatter:off
/**
* Represents a half-precision floating-point data type using 2 bytes.
*/
HALF_FLOAT(2, ValueLayout.JAVA_SHORT),
/**
* Represents a single-precision 32-bit IEEE floating-point data type using 4 bytes.
*/
FLOAT(4, ValueLayout.JAVA_FLOAT),
/**
* Represents a double-precision 64-bit IEEE floating-point data type using 8 bytes.
*/
DOUBLE(8, ValueLayout.JAVA_DOUBLE),
/**
* Represents an 8-bit signed integer data type using 1 byte.
*/
INT8(1, ValueLayout.JAVA_BYTE),
/**
* Represents a 16-bit signed integer data type using 2 bytes.
*/
INT16(2, ValueLayout.JAVA_SHORT),
/**
* Represents a 32-bit signed integer data type using 4 bytes.
*/
INT32(4, ValueLayout.JAVA_INT),
/**
* Represents a 64-bit signed integer data type using 8 bytes.
*/
INT64(8, ValueLayout.JAVA_LONG),
/**
* Represents an 8-bit unsigned integer data type using 1 byte.
*/
UINT8(1, ValueLayout.JAVA_BYTE),
/**
* Represents a boolean data type using 1 byte for true/false values.
*/
BOOL(1, ValueLayout.JAVA_BYTE),
/**
* Represents a quantized 8-bit signed integer used in specialized applications like machine learning, using 1 byte.
*/
QINT8(1, ValueLayout.JAVA_BYTE),
/**
* Represents a quantized 8-bit unsigned integer used in specialized applications like machine learning, using 1 byte.
*/
QUINT8(1, ValueLayout.JAVA_BYTE);
// @formatter:on

/**
* The size of the data type in bytes.
*/
private final int size;

/**
* The layout of the data type in memory.
*/
private final ValueLayout layout;

/**
* Constructs an instance of the enum constant with the specified size and memory layout.
*
* @param size
* The size of the data type in bytes.
* @param layout
* The {@link ValueLayout} specifying how the data is laid out in memory.
*/
DType(int size, ValueLayout layout) {
this.size = size;
this.layout = layout;
}

/**
* Returns the size of the data type in bytes.
*
* @return The size of the data type.
*/
public int getByteSize() {
return size;
}

/**
* Returns the {@link ValueLayout} of the data type, which describes how the data is laid out in memory.
*
* @return The memory layout of the data type.
*/
public ValueLayout getLayout() {
return layout;
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
/*
* Copyright (c) 2024, APT Group, Department of Computer Science,
* The University of Manchester.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package uk.ac.manchester.tornado.api.types.tensors;

import java.util.Arrays;

public record Shape(long... dimensions) {

/**
* Returns the rank of the shape, which is the number of dimensions.
*
* @return the number of dimensions of the shape
*/
public int getRank() {
return dimensions.length;
}

/**
* Returns of the dimensions of the shape.
*
* @return an array of long values representing the dimensions of the shape
*/
public long[] getDimensions() {
return dimensions;
}

/**
* Calculates and returns the size of the shape, which is the product of all its dimensions.
*
* @return the total size of the shape as an int
*/
public int getSize() {
return (int) Arrays.stream(dimensions).reduce(1, (a, b) -> a * b);
}

@Override
public boolean equals(Object o) {
if (this == o)
return true;
if (o == null || getClass() != o.getClass())
return false;
Shape shape = (Shape) o;
return Arrays.equals(dimensions, shape.dimensions);
}

@Override
public int hashCode() {
return Arrays.hashCode(dimensions);
}

@Override
public String toString() {
return STR."Shape{dimensions=\{Arrays.toString(dimensions)}}";
}

/**
* Generates a string representation of the shape compatible with TensorFlow's shape format.
*
* @return a string representing the shape in TensorFlow's format
*/

public String toTensorFlowShapeString() {
StringBuilder sb = new StringBuilder();
sb.append("[");
for (int i = 0; i < dimensions.length; i++) {
sb.append(dimensions[i]);
if (i < dimensions.length - 1) {
sb.append(",");
}
}
sb.append("]");
return sb.toString();
}

/**
* Generates a string representation of the shape compatible with ONNX's shape format.
*
* @return a string representing the shape in ONNX's format
*/
public String toONNXShapeString() {
StringBuilder sb = new StringBuilder();
sb.append("{");
for (int i = 0; i < dimensions.length; i++) {
sb.append("dim_").append(i).append(": ").append(dimensions[i]);
if (i < dimensions.length - 1) {
sb.append(", ");
}
}
sb.append("}");
return sb.toString();
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
/*
* Copyright (c) 2024, APT Group, Department of Computer Science,
* The University of Manchester.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package uk.ac.manchester.tornado.api.types.tensors;

import uk.ac.manchester.tornado.api.types.arrays.TornadoNativeArray;

public abstract non-sealed class Tensor extends TornadoNativeArray {
private final DType dtype;
private final Shape shape;

protected Tensor(DType dtype, Shape shape) {
this.dtype = dtype;
this.shape = shape;
}

public abstract Shape getShape();

public abstract String getDTypeAsString();

public abstract DType getDType();

}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After looking at this class. This only works for FP32 and FP16, but then there are the rest of the classes, such as Int, byte, etc., that are built using the specialized classes.

Loading