Skip to content

Commit

Permalink
Allow users to bind arbitrary memory using raw pointers (#10428)
Browse files Browse the repository at this point in the history
- Add binding external allocation
- Add negative tests
- Add missing return status check
  • Loading branch information
yuslepukhin authored Feb 2, 2022
1 parent 3c96760 commit 91b8ad5
Show file tree
Hide file tree
Showing 8 changed files with 274 additions and 49 deletions.
10 changes: 8 additions & 2 deletions csharp/src/Microsoft.ML.OnnxRuntime/InferenceSession.shared.cs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ public class InferenceSession : IDisposable
/// Dictionary that represents overridableInitializers metadata
/// </summary>
private Dictionary<string, NodeMetadata> _overridableInitializerMetadata;

private SessionOptions _builtInSessionOptions = null;
private RunOptions _builtInRunOptions = null;
private ModelMetadata _modelMetadata = null;
Expand Down Expand Up @@ -998,9 +998,15 @@ internal static NodeMetadata GetMetadataFromTypeInfo(IntPtr typeInfo)
NativeApiStatus.VerifySuccess(NativeMethods.OrtGetTensorElementType(tensorInfo, out el_type));
type = (TensorElementType)el_type;
}

Type dotnetType = null;
int width = 0;
TensorElementTypeConverter.GetTypeAndWidth(type, out dotnetType, out width);
if (!TensorElementTypeConverter.GetTypeAndWidth(type, out dotnetType, out width))
{
throw new OnnxRuntimeException(ErrorCode.InvalidArgument,
"Unable to query type information for data type: " + type.ToString());
}

UIntPtr numDimensions;
NativeApiStatus.VerifySuccess(NativeMethods.OrtGetDimensionsCount(tensorInfo, out numDimensions));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ protected virtual void Dispose(bool disposing)
// dispose managed state (managed objects).
if (disposing)
{
if(_disposables != null)
if (_disposables != null)
{
_disposables.Dispose();
_disposables = null;
Expand Down Expand Up @@ -106,10 +106,19 @@ public NativeOnnxTensorMemory(OrtValue ortValue)
NativeApiStatus.VerifySuccess(NativeMethods.OrtGetTensorElementType(typeAndShape, out el_type));
elemType = (TensorElementType)el_type;
}
TensorElementTypeConverter.GetTypeAndWidth(elemType, out type, out width);

if (!TensorElementTypeConverter.GetTypeAndWidth(elemType, out type, out width))
{
throw new OnnxRuntimeException(ErrorCode.InvalidArgument,
"Unable to query type information for data type: " + elemType.ToString());
}

if (typeof(T) != type)
throw new NotSupportedException(nameof(NativeOnnxTensorMemory<T>) + " does not support T = " + nameof(T));
{
var message = String.Format("The NativeOnnxTensorMemory<T> type being instantiated for T = : {0} while supplied OrtValue contains T = {1}",
typeof(T), type);
throw new OnnxRuntimeException(ErrorCode.InvalidArgument, message);
}

ElementType = elemType;
ElementWidth = width;
Expand All @@ -136,7 +145,7 @@ public NativeOnnxTensorMemory(OrtValue ortValue)
Dimensions[i] = (int)shape[i];
}

if (typeof(T) != typeof(string))
if (elemType != TensorElementType.String)
{
NativeApiStatus.VerifySuccess(NativeMethods.OrtGetTensorMutableData(ortValue.Handle, out _dataBufferPointer));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,19 +108,22 @@ internal static IntPtr[] ConvertNamesToUtf8<T>(IReadOnlyCollection<T> names, Nam

internal static class TensorElementTypeConverter
{
    /// <summary>
    /// Maps a TensorElementType enum value to its corresponding .NET type
    /// and per-element width in bytes.
    /// </summary>
    /// <param name="elemType">tensor element type to look up</param>
    /// <param name="type">resulting .NET type, or null when the element type is unknown</param>
    /// <param name="width">size of a single element in bytes, or 0 when unknown</param>
    /// <returns>true when the lookup succeeded, false when elemType has no registered mapping</returns>
    public static bool GetTypeAndWidth(TensorElementType elemType, out Type type, out int width)
    {
        TensorElementTypeInfo typeInfo = TensorBase.GetElementTypeInfo(elemType);
        if (typeInfo != null)
        {
            type = typeInfo.TensorType;
            width = typeInfo.TypeSize;
            return true;
        }

        // Unknown element type: signal failure rather than throwing so callers
        // can produce their own context-specific error messages.
        type = null;
        width = 0;
        return false;
    }
}
}
79 changes: 76 additions & 3 deletions csharp/src/Microsoft.ML.OnnxRuntime/OrtAllocator.shared.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

using Microsoft.ML.OnnxRuntime.Tensors;
using System;
using System.Runtime.InteropServices;
using System.Text;
Expand Down Expand Up @@ -61,7 +62,7 @@ internal IntPtr Pointer
}

#region SafeHandle

/// <summary>
/// Overrides SafeHandle.IsInvalid
/// </summary>
Expand Down Expand Up @@ -257,7 +258,7 @@ public OrtAllocatorType GetAllocatorType()
public override bool Equals(object obj)
{
var other = obj as OrtMemoryInfo;
if(other == null)
if (other == null)
{
return false;
}
Expand All @@ -271,7 +272,7 @@ public override bool Equals(object obj)
/// <returns>true if instances are equal according to OrtCompareMemoryInfo.</returns>
public bool Equals(OrtMemoryInfo other)
{
if(this == other)
if (this == other)
{
return true;
}
Expand Down Expand Up @@ -310,6 +311,78 @@ protected override bool ReleaseHandle()
#endregion
}

/// <summary>
/// This class represents an arbitrary buffer of memory
/// allocated and owned by the user. It can be either a CPU, GPU or other device memory
/// that can be suitably represented by IntPtr.
/// This is just a composite of the buffer related information.
/// The memory is assumed to be pinned if necessary and usable immediately
/// in the native code.
/// </summary>
public class OrtExternalAllocation
{
    /// <summary>
    /// Constructor. Validates the arguments and verifies that the supplied buffer
    /// is large enough to hold the number of elements implied by shape.
    /// </summary>
    /// <param name="memInfo">use to accurately describe a piece of memory that this is wrapping</param>
    /// <param name="shape">shape of this buffer</param>
    /// <param name="elementType">element type</param>
    /// <param name="pointer">the actual pointer to memory</param>
    /// <param name="sizeInBytes">size of the allocation in bytes</param>
    /// <exception cref="OnnxRuntimeException">thrown when memInfo or shape is null,
    /// when the element type is unsupported (including strings), or when the buffer
    /// is too small for the given shape</exception>
    public OrtExternalAllocation(OrtMemoryInfo memInfo, long[] shape, Tensors.TensorElementType elementType, IntPtr pointer, long sizeInBytes)
    {
        // Guard against nulls up front so failures surface here with a clear
        // message rather than deep inside shape computation or native interop.
        if (memInfo == null)
        {
            throw new OnnxRuntimeException(ErrorCode.InvalidArgument,
                "memInfo argument can not be null");
        }

        if (shape == null)
        {
            throw new OnnxRuntimeException(ErrorCode.InvalidArgument,
                "shape argument can not be null");
        }

        Type type;
        int width;
        if (!TensorElementTypeConverter.GetTypeAndWidth(elementType, out type, out width))
        {
            throw new OnnxRuntimeException(ErrorCode.InvalidArgument,
                "Unable to query type information for data type: " + elementType.ToString());
        }

        // Variable-length string tensors can not be represented by a single raw buffer.
        if (elementType == TensorElementType.String)
        {
            throw new OnnxRuntimeException(ErrorCode.InvalidArgument,
                "Strings are not supported by this API");
        }

        // Verify the buffer can hold shapeSize elements of the given element width.
        var shapeSize = ArrayUtilities.GetSizeForShape(shape);
        var requiredBufferSize = shapeSize * width;
        if (requiredBufferSize > sizeInBytes)
        {
            var message = String.Format("Shape of {0} elements requires a buffer of at least {1} bytes. Provided: {2} bytes",
                shapeSize, requiredBufferSize, sizeInBytes);
            throw new OnnxRuntimeException(ErrorCode.InvalidArgument, message);
        }

        Info = memInfo;
        Shape = shape;
        ElementType = elementType;
        Pointer = pointer;
        Size = sizeInBytes;
    }

    /// <summary>
    /// OrtMemoryInfo describing the device/location of the wrapped memory
    /// </summary>
    public OrtMemoryInfo Info { get; private set; }
    /// <summary>
    /// Shape of the tensor stored in the buffer
    /// </summary>
    public long[] Shape { get; private set; }
    /// <summary>
    /// Data type of the tensor elements
    /// </summary>
    public Tensors.TensorElementType ElementType { get; private set; }
    /// <summary>
    /// Actual memory ptr
    /// </summary>
    public IntPtr Pointer { get; private set; }
    /// <summary>
    /// Size of the allocation in bytes
    /// </summary>
    public long Size { get; private set; }
}

/// <summary>
/// This class represents memory allocation made by a specific onnxruntime
/// allocator. Use OrtAllocator.Allocate() to obtain an instance of this class.
Expand Down
Loading

0 comments on commit 91b8ad5

Please sign in to comment.