diff --git a/.gitignore b/.gitignore
index 61fc6e474b8ba..c9e217e135e43 100644
--- a/.gitignore
+++ b/.gitignore
@@ -39,6 +39,7 @@ onnxprofile_profile_test_*.json
/csharp/packages
/csharp/src/Microsoft.ML.OnnxRuntime/targets/**/*.targets
/csharp/src/Microsoft.ML.OnnxRuntime/targets/**/*.props
+/csharp/**/*.vcxproj.user
cmake/external/FeaturizersLibrary/
# Java specific ignores
java/.gradle
diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/DisposableNamedOnnxValue.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/DisposableNamedOnnxValue.shared.cs
index b2a4c2ef47cb2..34e71074d9d9d 100644
--- a/csharp/src/Microsoft.ML.OnnxRuntime/DisposableNamedOnnxValue.shared.cs
+++ b/csharp/src/Microsoft.ML.OnnxRuntime/DisposableNamedOnnxValue.shared.cs
@@ -3,11 +3,14 @@
using Microsoft.ML.OnnxRuntime.Tensors;
using System;
-using System.Buffers;
using System.Collections.Generic;
namespace Microsoft.ML.OnnxRuntime
{
+ ///
+ /// Return immutable collection of results
+ ///
+ ///
public interface IDisposableReadOnlyCollection : IReadOnlyCollection, IDisposable
{
@@ -52,7 +55,6 @@ public void Dispose()
///
/// This class serves as a container for model run output values including
/// tensors, sequences of tensors, sequences and maps.
- /// It extends NamedOnnxValue, exposes the OnnxValueType and Tensor type
/// The class must be disposed of.
/// It disposes of _ortValueHolder that owns the underlying Ort output value and
/// anything else that would need to be disposed by the instance of the class.
@@ -70,24 +72,46 @@ public class DisposableNamedOnnxValue : NamedOnnxValue, IDisposable
/// Managed object created to represent output value, such as DenseTensor
/// List or Dictionary
///
- /// Use this to decide what you want to call to fetch data, AsTensor(), AsDictionary()
- /// or AsEnumerable()
/// Tensor element type if value type is a Tensor
/// Object that holds native resources.
/// Typically, this is an output OrtValue that holds native memory where Tensor is mapped but may also be
/// other things that would need to be disposed by this instance depending on how IOrtValueOwner is implemented.
- private DisposableNamedOnnxValue(string name, Object value, OnnxValueType onnxValueType, TensorElementType elementType, IOrtValueOwner ortValueHolder)
- : base(name, value)
+ private DisposableNamedOnnxValue(string name, Object value, TensorElementType elementType, IOrtValueOwner ortValueHolder)
+ : base(name, value, OnnxValueType.ONNX_TYPE_TENSOR)
{
_ortValueHolder = ortValueHolder;
- ValueType = onnxValueType;
ElementType = elementType;
}
///
- /// Returns OnnxValueType
+ /// Ctor for non-tensor values
+ ///
+ ///
+ ///
+ ///
+ ///
+ private DisposableNamedOnnxValue(string name, Object value, OnnxValueType onnxValueType, IOrtValueOwner ortValueHolder)
+ : base(name, value, onnxValueType)
+ {
+ _ortValueHolder = ortValueHolder;
+ ElementType = TensorElementType.DataTypeMax;
+ }
+
+ ///
+ /// Construct an instance that would contain a map in a form of a Dictionary
+ /// Currently a limited number of primitive types are supported as map keys and values.
+ /// So this is not a full implementation of the map type.
///
- public OnnxValueType ValueType { get; }
+ ///
+ ///
+ ///
+ ///
+ private DisposableNamedOnnxValue(string name, Object value, MapHelper mapHelper, IOrtValueOwner ortValueHolder)
+ : base(name, value, mapHelper)
+ {
+ _ortValueHolder = ortValueHolder;
+ ElementType = TensorElementType.DataTypeMax;
+ }
///
/// Only valid if ValueType is Tensor
@@ -101,22 +125,70 @@ private DisposableNamedOnnxValue(string name, Object value, OnnxValueType onnxVa
/// to do, as this class maintains a native buffer via _ortValueHolder and the memory will be
/// disposed by it. This is the case when we are dealing with an OrtValue that is backed by native memory
/// and not by pinned managed memory.
+ ///
+ /// This class is generally used for outputs to be created on top of the output OrtValue,
+ /// but the interface (derived from NamedOnnxValue) allows it to be passed as input and one of the test
+ /// cases does it. Unless we deprecate and re-do the interface, we must support it.
///
/// always set to null
/// An instance of OrtValue that does not own underlying memory
- internal override OrtValue ToOrtValue(out MemoryHandle? pinnedMemoryHandle)
+ internal override OrtValue InputToOrtValue(NodeMetadata metadata, out IDisposable memoryHolder)
{
- if(_ortValueHolder == null)
+ if (_ortValueHolder == null)
{
throw new InvalidOperationException("The instance of this class does not own any OrtValues");
}
// PinnedMemoryHandle holds the default value as DisposableNamedOnnxValue
// doesn't hold any managed buffer (that needs to be pinned)
- pinnedMemoryHandle = null;
+ memoryHolder = null;
// Return non-owning instance of OrtValue
return _ortValueHolder.Value;
}
+ ///
+ /// Generally, this class is created on top of the values that are returned by the model run.
+ /// So, this method is not expected to be called. However, if it is called (an instance fed as output),
+ /// it will return the OrtValue that was previously created, since the caller must understand what they are doing.
+ ///
+ ///
+ ///
+ ///
+ internal override OrtValue OutputToOrtValue(NodeMetadata metadata, out IDisposable memoryOwner)
+ {
+ return InputToOrtValue(metadata, out memoryOwner);
+ }
+
+ internal static DisposableNamedOnnxValue CreateFromOrtValue(string name, OrtValue ortValue)
+ {
+ return CreateFromOrtValue(name, ortValue, OrtAllocator.DefaultInstance);
+ }
+
+ internal static DisposableNamedOnnxValue CreateFromOrtValue(string name, OrtValue ortValue, OrtAllocator allocator)
+ {
+ DisposableNamedOnnxValue result = null;
+
+ IntPtr valueType;
+ NativeApiStatus.VerifySuccess(NativeMethods.OrtGetValueType(ortValue.Handle, out valueType));
+ OnnxValueType onnxValueType = (OnnxValueType)valueType;
+ switch (onnxValueType)
+ {
+ case OnnxValueType.ONNX_TYPE_TENSOR:
+ result = FromNativeTensor(name, ortValue);
+ break;
+
+ case OnnxValueType.ONNX_TYPE_SEQUENCE:
+ result = FromNativeSequence(name, ortValue, allocator);
+ break;
+
+ case OnnxValueType.ONNX_TYPE_MAP:
+ result = FromNativeMap(name, ortValue, allocator);
+ break;
+ default:
+ throw new NotSupportedException("OnnxValueType : " + onnxValueType + " is not supported");
+ }
+ return result;
+ }
+
///
/// Creates an instance of DisposableNamedOnnxValue and takes ownership of ortValueElement
/// on success.
@@ -124,7 +196,7 @@ internal override OrtValue ToOrtValue(out MemoryHandle? pinnedMemoryHandle)
/// name of the value
/// underlying OrtValue
///
- internal static DisposableNamedOnnxValue CreateTensorFromOnnxValue(string name, OrtValue ortValue)
+ private static DisposableNamedOnnxValue FromNativeTensor(string name, OrtValue ortValue)
{
DisposableNamedOnnxValue result = null;
@@ -146,46 +218,46 @@ internal static DisposableNamedOnnxValue CreateTensorFromOnnxValue(string name,
switch (elemType)
{
case TensorElementType.Float:
- result = DisposableNamedOnnxValueFromNativeTensor(name, ortValue);
+ result = FromNativeTensor(name, ortValue);
break;
case TensorElementType.Double:
- result = DisposableNamedOnnxValueFromNativeTensor(name, ortValue);
+ result = FromNativeTensor(name, ortValue);
break;
case TensorElementType.Int16:
- result = DisposableNamedOnnxValueFromNativeTensor(name, ortValue);
+ result = FromNativeTensor(name, ortValue);
break;
case TensorElementType.UInt16:
- result = DisposableNamedOnnxValueFromNativeTensor(name, ortValue);
+ result = FromNativeTensor(name, ortValue);
break;
case TensorElementType.Int32:
- result = DisposableNamedOnnxValueFromNativeTensor(name, ortValue);
+ result = FromNativeTensor(name, ortValue);
break;
case TensorElementType.UInt32:
- result = DisposableNamedOnnxValueFromNativeTensor(name, ortValue);
+ result = FromNativeTensor(name, ortValue);
break;
case TensorElementType.Int64:
- result = DisposableNamedOnnxValueFromNativeTensor(name, ortValue);
+ result = FromNativeTensor(name, ortValue);
break;
case TensorElementType.UInt64:
- result = DisposableNamedOnnxValueFromNativeTensor(name, ortValue);
+ result = FromNativeTensor(name, ortValue);
break;
case TensorElementType.UInt8:
- result = DisposableNamedOnnxValueFromNativeTensor(name, ortValue);
+ result = FromNativeTensor(name, ortValue);
break;
case TensorElementType.Int8:
- result = DisposableNamedOnnxValueFromNativeTensor(name, ortValue);
+ result = FromNativeTensor(name, ortValue);
break;
case TensorElementType.String:
- result = DisposableNamedOnnxValueFromNativeTensor(name, ortValue);
+ result = FromNativeTensor(name, ortValue);
break;
case TensorElementType.Bool:
- result = DisposableNamedOnnxValueFromNativeTensor(name, ortValue);
+ result = FromNativeTensor(name, ortValue);
break;
case TensorElementType.Float16:
- result = DisposableNamedOnnxValueFromNativeTensor(name, ortValue);
+ result = FromNativeTensor(name, ortValue);
break;
case TensorElementType.BFloat16:
- result = DisposableNamedOnnxValueFromNativeTensor(name, ortValue);
+ result = FromNativeTensor(name, ortValue);
break;
default:
throw new NotSupportedException("Tensor of element type: " + elemType + " is not supported");
@@ -195,37 +267,6 @@ internal static DisposableNamedOnnxValue CreateTensorFromOnnxValue(string name,
return result;
}
- internal static DisposableNamedOnnxValue CreateFromOrtValue(string name, OrtValue ortValue)
- {
- return CreateFromOrtValue(name, ortValue, OrtAllocator.DefaultInstance);
- }
-
- internal static DisposableNamedOnnxValue CreateFromOrtValue(string name, OrtValue ortValue, OrtAllocator allocator)
- {
- DisposableNamedOnnxValue result = null;
-
- IntPtr valueType;
- NativeApiStatus.VerifySuccess(NativeMethods.OrtGetValueType(ortValue.Handle, out valueType));
- OnnxValueType onnxValueType = (OnnxValueType)valueType;
- switch (onnxValueType)
- {
- case OnnxValueType.ONNX_TYPE_TENSOR:
- result = CreateTensorFromOnnxValue(name, ortValue);
- break;
-
- case OnnxValueType.ONNX_TYPE_SEQUENCE:
- result = DisposableNamedOnnxValueFromSequence(name, ortValue, allocator);
- break;
-
- case OnnxValueType.ONNX_TYPE_MAP:
- result = DisposableNamedOnnxValueFromNativeMap(name, ortValue, allocator);
- break;
- default:
- throw new NotSupportedException("OnnxValueType : " + onnxValueType + " is not supported");
- }
- return result;
- }
-
///
/// This method creates an instance of DisposableNamedOnnxValue that has possession of ortValueElement
/// native memory Tensor and returns it to the caller. The original ortValueElement argument looses
@@ -236,34 +277,26 @@ internal static DisposableNamedOnnxValue CreateFromOrtValue(string name, OrtValu
/// name of the output
/// native tensor
/// DisposableNamedOnnxValue instance
- private static DisposableNamedOnnxValue DisposableNamedOnnxValueFromNativeTensor(string name, OrtValue ortValue)
+ private static DisposableNamedOnnxValue FromNativeTensor(string name, OrtValue ortValue)
{
- if (typeof(T) == typeof(string))
+ var ortValueTensor = new OrtValueTensor(ortValue);
+ try
{
- var nativeTensorWrapper = new NativeOnnxTensorMemory(ortValue);
- try
+ if (typeof(T) == typeof(string))
{
- var dt = new DenseTensor(nativeTensorWrapper.GetBytesAsStringMemory(), nativeTensorWrapper.Dimensions);
- return new DisposableNamedOnnxValue(name, dt, OnnxValueType.ONNX_TYPE_TENSOR, nativeTensorWrapper.ElementType, nativeTensorWrapper);
- } catch(Exception)
+ var dt = new DenseTensor(ortValueTensor.GetBytesAsStringMemory(), ortValueTensor.Dimensions);
+ return new DisposableNamedOnnxValue(name, dt, ortValueTensor.ElementType, ortValueTensor);
+ }
+ else
{
- nativeTensorWrapper.Dispose();
- throw;
+ DenseTensor dt = new DenseTensor(ortValueTensor.Memory, ortValueTensor.Dimensions);
+ return new DisposableNamedOnnxValue(name, dt, ortValueTensor.ElementType, ortValueTensor);
}
}
- else
+ catch (Exception)
{
- NativeOnnxTensorMemory nativeTensorWrapper = new NativeOnnxTensorMemory(ortValue);
- try
- {
- DenseTensor dt = new DenseTensor(nativeTensorWrapper.Memory, nativeTensorWrapper.Dimensions);
- return new DisposableNamedOnnxValue(name, dt, OnnxValueType.ONNX_TYPE_TENSOR, nativeTensorWrapper.ElementType, nativeTensorWrapper);
- }
- catch (Exception)
- {
- nativeTensorWrapper.Dispose();
- throw;
- }
+ ortValueTensor.Dispose();
+ throw;
}
}
@@ -275,7 +308,7 @@ private static DisposableNamedOnnxValue DisposableNamedOnnxValueFromNativeTensor
/// ortValueElement that has native sequence
/// used allocator
/// DisposableNamedOnnxValue
- private static DisposableNamedOnnxValue DisposableNamedOnnxValueFromSequence(string name, OrtValue ortValueSequence, OrtAllocator allocator)
+ private static DisposableNamedOnnxValue FromNativeSequence(string name, OrtValue ortValueSequence, OrtAllocator allocator)
{
DisposableNamedOnnxValue result = null;
IntPtr count;
@@ -295,8 +328,8 @@ private static DisposableNamedOnnxValue DisposableNamedOnnxValueFromSequence(str
}
// NativeOrtValueCollectionOwner will take ownership of ortValueSequence and will make sure sequence
// is also disposed.
- var nativeCollectionManager = new NativeOrtValueCollectionOwner(ortValueSequence, sequence);
- result = new DisposableNamedOnnxValue(name, sequence, OnnxValueType.ONNX_TYPE_SEQUENCE, TensorElementType.DataTypeMax, nativeCollectionManager);
+ var nativeCollectionManager = new NativeOrtValueCollectionOwner(ortValueSequence, sequence);
+ result = new DisposableNamedOnnxValue(name, sequence, OnnxValueType.ONNX_TYPE_SEQUENCE, nativeCollectionManager);
}
catch (Exception)
{
@@ -314,7 +347,7 @@ private static DisposableNamedOnnxValue DisposableNamedOnnxValueFromSequence(str
/// This function does not take ownership of the map as it we copy all keys an values into a dictionary. We let the caller dispose of it
///
/// DisposableNamedOnnxValue
- private static DisposableNamedOnnxValue DisposableNamedOnnxValueFromNativeMap(string name, OrtValue ortValueMap, OrtAllocator allocator)
+ private static DisposableNamedOnnxValue FromNativeMap(string name, OrtValue ortValueMap, OrtAllocator allocator)
{
DisposableNamedOnnxValue result = null;
// Map processing is currently not recursing. It is assumed to contain
@@ -323,44 +356,87 @@ private static DisposableNamedOnnxValue DisposableNamedOnnxValueFromNativeMap(st
// not mapped for client consumption.
using (var cleanUpList = new DisposableList())
{
- // Take possession of the map ortValueElement
IntPtr nativeOnnxValueMapKeys = IntPtr.Zero;
NativeApiStatus.VerifySuccess(NativeMethods.OrtGetValue(ortValueMap.Handle, 0, allocator.Pointer, out nativeOnnxValueMapKeys));
var ortValueKeys = new OrtValue(nativeOnnxValueMapKeys);
cleanUpList.Add(ortValueKeys);
+ var typeAndShape = IntPtr.Zero;
+ NativeApiStatus.VerifySuccess(NativeMethods.OrtGetTensorTypeAndShape(nativeOnnxValueMapKeys, out typeAndShape));
+ TensorElementType keyElemType;
+ try
+ {
+ IntPtr el_type;
+ NativeApiStatus.VerifySuccess(NativeMethods.OrtGetTensorElementType(typeAndShape, out el_type));
+ keyElemType = (TensorElementType)el_type;
+ }
+ finally
+ {
+ NativeMethods.OrtReleaseTensorTypeAndShapeInfo(typeAndShape);
+ }
+
IntPtr nativeOnnxValueMapValues = IntPtr.Zero;
NativeApiStatus.VerifySuccess(NativeMethods.OrtGetValue(ortValueMap.Handle, 1, allocator.Pointer, out nativeOnnxValueMapValues));
var ortValueValues = new OrtValue(nativeOnnxValueMapValues);
cleanUpList.Add(ortValueValues);
- IntPtr typeAndShape = IntPtr.Zero;
- NativeApiStatus.VerifySuccess(NativeMethods.OrtGetTensorTypeAndShape(nativeOnnxValueMapKeys, out typeAndShape));
- TensorElementType elemType = TensorElementType.DataTypeMax;
+ typeAndShape = IntPtr.Zero;
+ NativeApiStatus.VerifySuccess(NativeMethods.OrtGetTensorTypeAndShape(nativeOnnxValueMapValues, out typeAndShape));
+ TensorElementType valueElemType;
try
{
IntPtr el_type;
NativeApiStatus.VerifySuccess(NativeMethods.OrtGetTensorElementType(typeAndShape, out el_type));
- elemType = (TensorElementType)el_type;
+ valueElemType = (TensorElementType)el_type;
}
finally
{
NativeMethods.OrtReleaseTensorTypeAndShapeInfo(typeAndShape);
}
- /// XXX: This code always assumes that the value type is float and makes no checks
- /// similar to that of the key. Also Map type in general can also be another sequence or map,
- /// not just a tensor
- switch (elemType)
+ // The supported combinations of key and value types are taken from the ORT C API.
+ switch (keyElemType)
{
case TensorElementType.Int64:
- result = DisposableNamedOnnxValueFromNativeMapElements(string.Empty, ortValueKeys, ortValueValues);
+ switch (valueElemType)
+ {
+ case TensorElementType.Float:
+ result = FromNativeMapElements(name, ortValueMap, ortValueKeys, ortValueValues);
+ break;
+ case TensorElementType.Double:
+ result = FromNativeMapElements(name, ortValueMap, ortValueKeys, ortValueValues);
+ break;
+ case TensorElementType.Int64:
+ result = FromNativeMapElements(name, ortValueMap, ortValueKeys, ortValueValues);
+ break;
+ case TensorElementType.String:
+ result = FromNativeMapElements(name, ortValueMap, ortValueKeys, ortValueValues);
+ break;
+ default:
+ break;
+ }
break;
case TensorElementType.String:
- result = DisposableNamedOnnxValueFromNativeMapElements(string.Empty, ortValueKeys, ortValueValues);
+ switch (valueElemType)
+ {
+ case TensorElementType.Float:
+ result = FromNativeMapElements(name, ortValueMap, ortValueKeys, ortValueValues);
+ break;
+ case TensorElementType.Double:
+ result = FromNativeMapElements(name, ortValueMap, ortValueKeys, ortValueValues);
+ break;
+ case TensorElementType.Int64:
+ result = FromNativeMapElements(name, ortValueMap, ortValueKeys, ortValueValues);
+ break;
+ case TensorElementType.String:
+ result = FromNativeMapElements(name, ortValueMap, ortValueKeys, ortValueValues);
+ break;
+ default:
+ break;
+ }
break;
default:
- throw new NotSupportedException("Map of element type: " + elemType + " is not supported");
+ throw new NotSupportedException("Map key type: " + keyElemType + " is not supported");
}
}
return result;
@@ -381,40 +457,78 @@ private static DisposableNamedOnnxValue DisposableNamedOnnxValueFromNativeMap(st
/// tensor with map keys.
/// tensor with map values
/// instance of DisposableNamedOnnxValue with Dictionary
- private static DisposableNamedOnnxValue DisposableNamedOnnxValueFromNativeMapElements(string name,
+ private static DisposableNamedOnnxValue FromNativeMapElements(string name, OrtValue ortValueMap,
OrtValue ortValueTensorKeys, OrtValue ortValueTensorValues)
{
- using (var nativeTensorWrapperValues = new NativeOnnxTensorMemory(ortValueTensorValues))
+ var listOfKeysValues = new DisposableList();
+ var collOwner = new NativeOrtValueCollectionOwner(ortValueMap, listOfKeysValues);
+ try
{
- var denseTensorValues = new DenseTensor(nativeTensorWrapperValues.Memory, nativeTensorWrapperValues.Dimensions);
+ var tensorKeys = new OrtValueTensor(ortValueTensorKeys);
+ listOfKeysValues.Add(ortValueTensorKeys);
+ var tensorValues = new OrtValueTensor(ortValueTensorValues);
+ listOfKeysValues.Add(ortValueTensorValues);
+ MapHelper mapHelper = null;
if (typeof(K) == typeof(string))
{
- var map = new Dictionary();
- using (var nativeTensorWrapper = new NativeOnnxTensorMemory(ortValueTensorKeys))
+ var denseTensorKeys = new DenseTensor(tensorKeys.GetBytesAsStringMemory(), tensorKeys.Dimensions);
+
+ if (typeof(V) == typeof(string))
+ {
+ var map = new Dictionary();
+ var denseTensorValues = new DenseTensor(tensorValues.GetBytesAsStringMemory(), tensorValues.Dimensions);
+ for (var i = 0; i < denseTensorKeys.Length; i++)
+ {
+ map.Add(denseTensorKeys.GetValue(i), denseTensorValues.GetValue(i));
+ }
+ mapHelper = new MapHelper(denseTensorKeys, denseTensorValues);
+ return new DisposableNamedOnnxValue(name, map, mapHelper, collOwner);
+ }
+ else
{
- var denseTensorKeys = new DenseTensor(nativeTensorWrapper.GetBytesAsStringMemory(), nativeTensorWrapper.Dimensions);
+ var map = new Dictionary();
+ var denseTensorValues = new DenseTensor(tensorValues.Memory, tensorValues.Dimensions);
for (var i = 0; i < denseTensorKeys.Length; i++)
{
map.Add(denseTensorKeys.GetValue(i), denseTensorValues.GetValue(i));
}
- return new DisposableNamedOnnxValue(name, map, OnnxValueType.ONNX_TYPE_MAP, TensorElementType.DataTypeMax, null);
+ mapHelper = new MapHelper(denseTensorKeys, denseTensorValues);
+ return new DisposableNamedOnnxValue(name, map, mapHelper, collOwner);
}
}
else
{
- var map = new Dictionary();
- using (var nativeTensorWrapper = new NativeOnnxTensorMemory(ortValueTensorKeys))
+ var denseTensorKeys = new DenseTensor(tensorKeys.Memory, tensorKeys.Dimensions);
+ if (typeof(V) == typeof(string))
{
- var denseTensorKeys = new DenseTensor(nativeTensorWrapper.Memory, nativeTensorWrapper.Dimensions);
+ var map = new Dictionary();
+ var denseTensorValues = new DenseTensor(tensorValues.GetBytesAsStringMemory(), tensorValues.Dimensions);
for (var i = 0; i < denseTensorKeys.Length; i++)
{
map.Add(denseTensorKeys.GetValue(i), denseTensorValues.GetValue(i));
}
- return new DisposableNamedOnnxValue(name, map, OnnxValueType.ONNX_TYPE_MAP, TensorElementType.DataTypeMax, null);
+ mapHelper = new MapHelper(denseTensorKeys, denseTensorValues);
+ return new DisposableNamedOnnxValue(name, map, mapHelper, collOwner);
+ }
+ else
+ {
+ var denseTensorValues = new DenseTensor(tensorValues.Memory, tensorValues.Dimensions);
+ var map = new Dictionary();
+ for (var i = 0; i < denseTensorKeys.Length; i++)
+ {
+ map.Add(denseTensorKeys.GetValue(i), denseTensorValues.GetValue(i));
+ }
+ mapHelper = new MapHelper(denseTensorKeys, denseTensorValues);
+ return new DisposableNamedOnnxValue(name, map, mapHelper, collOwner);
}
}
}
+ catch (Exception)
+ {
+ collOwner.Dispose();
+ throw;
+ }
}
#region IDisposable Support
@@ -425,7 +539,7 @@ private static DisposableNamedOnnxValue DisposableNamedOnnxValueFromNativeMapEle
/// true if invoked by Dispose()
protected virtual void Dispose(bool disposing)
{
- if(_disposed)
+ if (_disposed)
{
return;
}
@@ -448,9 +562,7 @@ protected virtual void Dispose(bool disposing)
///
public void Dispose()
{
- // Do not change this code. Put cleanup code in Dispose(bool disposing) above.
Dispose(true);
- GC.SuppressFinalize(this);
}
#endregion
diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/InferenceSession.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/InferenceSession.shared.cs
index 98e1833aed8a6..a6be0afdad093 100644
--- a/csharp/src/Microsoft.ML.OnnxRuntime/InferenceSession.shared.cs
+++ b/csharp/src/Microsoft.ML.OnnxRuntime/InferenceSession.shared.cs
@@ -8,6 +8,7 @@
using System.Linq;
using Microsoft.ML.OnnxRuntime.Tensors;
using System.Buffers;
+using System.Diagnostics;
namespace Microsoft.ML.OnnxRuntime
{
@@ -31,11 +32,21 @@ public class InferenceSession : IDisposable
///
private Dictionary _inputMetadata;
+ ///
+ /// Ordered list of input names
+ ///
+ private List _inputNames;
+
///
/// Dictionary that represent output metadata
///
private Dictionary _outputMetadata;
+ ///
+ /// Ordered list of output names
+ ///
+ private List _outputNames;
+
///
/// Dictionary that represents overridableInitializers metadata
///
@@ -163,6 +174,11 @@ public IReadOnlyDictionary InputMetadata
}
}
+ ///
+ /// Ordered list of input names that can be accessed by index;
+ ///
+ public IReadOnlyList InputNames { get { return _inputNames; } }
+
///
/// Metadata regarding the output nodes, keyed by output names
///
@@ -174,6 +190,11 @@ public IReadOnlyDictionary OutputMetadata
}
}
+ ///
+ /// Ordered list of output names that can be accessed by index.
+ ///
+ public IReadOnlyList OutputNames { get { return _outputNames; } }
+
///
/// Metadata regarding the overridable initializers, keyed by node names
///
@@ -203,7 +224,8 @@ public IDisposableReadOnlyCollection Run(IReadOnlyColl
/// Specify a collection of that indicates the input values.
/// Specify a collection of string that indicates the output names to fetch.
/// Output Tensors in a Collection of NamedOnnxValue. User must dispose the output.
- public IDisposableReadOnlyCollection Run(IReadOnlyCollection inputs, IReadOnlyCollection outputNames)
+ public IDisposableReadOnlyCollection Run(IReadOnlyCollection inputs,
+ IReadOnlyCollection outputNames)
{
return Run(inputs, outputNames, _builtInRunOptions);
}
@@ -215,13 +237,15 @@ public IDisposableReadOnlyCollection Run(IReadOnlyColl
/// Specify a collection of string that indicates the output names to fetch.
///
/// Output Tensors in a Collection of NamedOnnxValue. User must dispose the output.
- public IDisposableReadOnlyCollection Run(IReadOnlyCollection inputs, IReadOnlyCollection outputNames, RunOptions options)
+ public IDisposableReadOnlyCollection Run(IReadOnlyCollection inputs,
+ IReadOnlyCollection outputNames,
+ RunOptions options)
{
using (var cleanupList = new DisposableList())
{
- var inputNamesArray = ConvertNamesToUtf8(inputs, v => v.Name, cleanupList);
- var inputValuesArray = GetOrtValuesHandles(inputs, cleanupList);
- var outputNamesArray = ConvertNamesToUtf8(outputNames, n => n, cleanupList);
+ var inputNamesArray = ConvertNamesToUtf8(inputs, v => v.Name, LookupInputMetadata, cleanupList);
+ var inputValuesArray = GetOrtValuesHandles(inputs, LookupInputMetadata, ExtractOrtValueForInput, cleanupList);
+ var outputNamesArray = ConvertNamesToUtf8(outputNames, n => n, LookupOutputMetadata, cleanupList);
var ortValues = RunImpl(options, inputNamesArray, inputValuesArray, outputNamesArray, cleanupList);
return CreateDisposableResult(ortValues, outputNames);
@@ -238,9 +262,7 @@ public IDisposableReadOnlyCollection Run(
IReadOnlyCollection inputNames,
IReadOnlyCollection inputValues)
{
- string[] outputNames = new string[_outputMetadata.Count];
- _outputMetadata.Keys.CopyTo(outputNames, 0);
- return Run(inputNames, inputValues, outputNames, _builtInRunOptions);
+ return Run(inputNames, inputValues, _outputNames, _builtInRunOptions);
}
///
@@ -279,9 +301,9 @@ public IDisposableReadOnlyCollection Run(
using (var cleanupList = new DisposableList())
{
- var inputNamesArray = ConvertNamesToUtf8(inputNames, n => n, cleanupList);
+ var inputNamesArray = ConvertNamesToUtf8(inputNames, n => n, LookupInputMetadata, cleanupList);
IntPtr[] inputValuesArray = GetOrtValuesHandles(inputValues, true);
- var outputNamesArray = ConvertNamesToUtf8(outputNames, n => n, cleanupList);
+ var outputNamesArray = ConvertNamesToUtf8(outputNames, n => n, LookupOutputMetadata, cleanupList);
var ortValues = RunImpl(options, inputNamesArray, inputValuesArray, outputNamesArray, cleanupList);
@@ -336,11 +358,11 @@ public void Run(
using (var cleanupList = new DisposableList())
{
// prepare inputs
- var inputNamesArray = ConvertNamesToUtf8(inputNames, n => n, cleanupList);
+ var inputNamesArray = ConvertNamesToUtf8(inputNames, n => n, LookupInputMetadata, cleanupList);
IntPtr[] inputValuesArray = GetOrtValuesHandles(inputValues, true);
// prepare outputs
- var outputNamesArray = ConvertNamesToUtf8(outputNames, n => n, cleanupList);
+ var outputNamesArray = ConvertNamesToUtf8(outputNames, n => n, LookupOutputMetadata, cleanupList);
IntPtr[] outputValuesArray = GetOrtValuesHandles(outputValues, false);
NativeApiStatus.VerifySuccess(NativeMethods.OrtRun(
@@ -371,7 +393,7 @@ public void Run(
}
///
- ///
+ ///
/// Runs the loaded model for the given inputs and outputs. Uses the given RunOptions for this run.
///
/// Outputs need to be created with correct type and dimension to receive the fetched data.
@@ -386,11 +408,11 @@ public void Run(
{
using (var cleanupList = new DisposableList())
{
- var inputNamesArray = ConvertNamesToUtf8(inputs, i => i.Name, cleanupList);
- var inputValuesArray = GetOrtValuesHandles(inputs, cleanupList);
+ var inputNamesArray = ConvertNamesToUtf8(inputs, i => i.Name, LookupInputMetadata, cleanupList);
+ var inputValuesArray = GetOrtValuesHandles(inputs, LookupInputMetadata, ExtractOrtValueForInput, cleanupList);
- var outputNamesArray = ConvertNamesToUtf8(outputs, o => o.Name, cleanupList);
- var outputValuesArray = GetOrtValuesHandles(outputs, cleanupList);
+ var outputNamesArray = ConvertNamesToUtf8(outputs, o => o.Name, LookupOutputMetadata, cleanupList);
+ var outputValuesArray = GetOrtValuesHandles(outputs, LookupOutputMetadata, ExtractOrtValueForOutput, cleanupList);
NativeApiStatus.VerifySuccess(NativeMethods.OrtRun(
_nativeHandle,
@@ -444,11 +466,11 @@ public void Run(
using (var cleanupList = new DisposableList())
{
// prepare inputs
- var inputNamesArray = ConvertNamesToUtf8(inputs, i => i.Name, cleanupList);
- var inputValuesArray = GetOrtValuesHandles(inputs, cleanupList);
+ var inputNamesArray = ConvertNamesToUtf8(inputs, i => i.Name, LookupInputMetadata, cleanupList);
+ var inputValuesArray = GetOrtValuesHandles(inputs, LookupInputMetadata, ExtractOrtValueForInput, cleanupList);
// prepare outputs
- var outputNamesArray = ConvertNamesToUtf8(outputNames, n => n, cleanupList);
+ var outputNamesArray = ConvertNamesToUtf8(outputNames, n => n, LookupOutputMetadata, cleanupList);
var outputValuesArray = GetOrtValuesHandles(outputValues, false);
NativeApiStatus.VerifySuccess(NativeMethods.OrtRun(
@@ -482,7 +504,7 @@ public void Run(
}
///
- ///
+ ///
/// Runs the loaded model for the given inputs and outputs. Uses the given RunOptions for this run.
///
/// Outputs need to be created with correct type and dimension to receive the fetched data.
@@ -505,12 +527,12 @@ public void Run(
using (var cleanupList = new DisposableList())
{
// prepare inputs
- var inputNamesArray = ConvertNamesToUtf8(inputNames, n => n, cleanupList);
+ var inputNamesArray = ConvertNamesToUtf8(inputNames, n => n, LookupInputMetadata, cleanupList);
var inputValuesArray = GetOrtValuesHandles(inputValues, true);
// prepare outputs
- var outputNamesArray = ConvertNamesToUtf8(outputs, o => o.Name, cleanupList);
- var outputValuesArray = GetOrtValuesHandles(outputs, cleanupList);
+ var outputNamesArray = ConvertNamesToUtf8(outputs, o => o.Name, LookupOutputMetadata, cleanupList);
+ var outputValuesArray = GetOrtValuesHandles(outputs, LookupOutputMetadata, ExtractOrtValueForOutput, cleanupList);
NativeApiStatus.VerifySuccess(NativeMethods.OrtRun(
_nativeHandle,
@@ -619,22 +641,96 @@ public string EndProfiling()
// Delegate for string extraction from an arbitrary input/output object
private delegate string NameExtractor(TInput input);
+ // delegate to fetch input/output OrtValue
+ private delegate OrtValue OrtValueExtractor(NamedOnnxValue value, NodeMetadata metadata, out IDisposable memOwner);
+
+ // Delegate to lookup metadata for input/initializers/output
+ private delegate NodeMetadata MetadataLookup(string nodeName);
+
+ ///
+ /// Checks if the name is a known input or overridable initializer name
+ /// and if so, returns metadata for it.
+ /// metadata
+ ///
+ ///
+ /// NodeMetadata for the nodeName
+ ///
+ private NodeMetadata LookupInputMetadata(string nodeName)
+ {
+ NodeMetadata meta;
+ if (!_inputMetadata.TryGetValue(nodeName, out meta) &&
+ !_overridableInitializerMetadata.TryGetValue(nodeName, out meta))
+ {
+ throw new OnnxRuntimeException(ErrorCode.InvalidArgument, $"Input name: '{nodeName}' is not in the metadata");
+ }
+ return meta;
+ }
+
+ ///
+ /// Checks if the nodeName is a known output name and if so returns metadata for it.
+ ///
+ ///
+ ///
+ ///
+ private NodeMetadata LookupOutputMetadata(string nodeName)
+ {
+ NodeMetadata meta;
+ if (!_outputMetadata.TryGetValue(nodeName, out meta))
+ {
+ throw new OnnxRuntimeException(ErrorCode.InvalidArgument, $"Output name: '{nodeName}' is not in the metadata");
+ }
+ return meta;
+ }
+
+ ///
+ /// Fetches/creates OrtValue for the content of the input
+ ///
+ ///
+ ///
+ ///
+ ///
+ private static OrtValue ExtractOrtValueForInput(NamedOnnxValue input, NodeMetadata metadata, out IDisposable memOwner)
+ {
+ return input.InputToOrtValue(metadata, out memOwner);
+ }
+
+ ///
+ /// Fetches/Creates OrtValue for output
+ ///
+ ///
+ ///
+ ///
+ /// May return null if the onnx value type does not support pre-creation of output OrtValues
+ private static OrtValue ExtractOrtValueForOutput(NamedOnnxValue output, NodeMetadata metadata, out IDisposable memOwner)
+ {
+ return output.OutputToOrtValue(metadata, out memOwner);
+ }
+
///
/// Run helper
///
- /// names to convert to zero terminated utf8 and pin
+ /// names to convert to zero terminated utf8 and pin
+ /// extractor functor that helps extracting names from inputs
+ /// inputs/outputs metadata
/// list to add pinned memory to for later disposal
///
- private IntPtr[] ConvertNamesToUtf8(IReadOnlyCollection inputs, NameExtractor extractor,
+ private IntPtr[] ConvertNamesToUtf8(IReadOnlyCollection values, NameExtractor nameExtractor,
+ MetadataLookup metaLookup,
DisposableList cleanupList)
{
- var result = new IntPtr[inputs.Count];
- for (int i = 0; i < inputs.Count; ++i)
+ cleanupList.Capacity += values.Count;
+ var result = new IntPtr[values.Count];
+ for (int i = 0; i < values.Count; ++i)
{
- var name = extractor(inputs.ElementAt(i));
- var utf8Name = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(name);
- var pinnedHandle = new PinnedGCHandle(GCHandle.Alloc(utf8Name, GCHandleType.Pinned));
- result[i] = pinnedHandle.Pointer;
+ var name = nameExtractor(values.ElementAt(i));
+ NodeMetadata meta = metaLookup(name);
+ var utf8Name = meta.ZeroTerminatedUtf8Name;
+ Debug.Assert(utf8Name != null);
+ var pinnedHandle = new Memory(utf8Name).Pin();
+ unsafe
+ {
+ result[i] = (IntPtr)pinnedHandle.Pointer;
+ }
cleanupList.Add(pinnedHandle);
}
return result;
@@ -642,28 +738,41 @@ private IntPtr[] ConvertNamesToUtf8(IReadOnlyCollection inputs, NameExtrac
///
/// This function obtains ortValues for NamedOnnxValue.
- /// The problem with NamedOnnxValue is that it does not contain any Onnx (OrtValue)
- /// so calling ToOrtValue creates a new instance of OrtValue that needs to be disposed.
+ /// The problem with NamedOnnxValue is that it is not disposable and can not contain any disposable items.
+ /// so calling InputToOrtValue creates a new instance of OrtValue that needs to be disposed.
/// The deriving object DisposableNamedValue actually contains and owns OrtValue and it returns
/// it.
///
- ///
- ///
+ /// a collection of NamedOnnxValues
+ /// Metadata lookup function (input/initializers/output)
+ /// list to cleanup in an exception safe manner
///
- private IntPtr[] GetOrtValuesHandles(IReadOnlyCollection values, DisposableList cleanupList)
+ private IntPtr[] GetOrtValuesHandles(IReadOnlyCollection values, MetadataLookup metaLookup,
+ OrtValueExtractor ortValueExtractor,
+ DisposableList cleanupList)
{
+ cleanupList.Capacity += values.Count * 2;
IntPtr[] result = new IntPtr[values.Count];
- for (int inputIndex = 0; inputIndex < values.Count; ++inputIndex)
+ for (int valueIndex = 0; valueIndex < values.Count; ++valueIndex)
{
- var input = values.ElementAt(inputIndex);
- MemoryHandle? memHandle;
- var ortValue = input.ToOrtValue(out memHandle);
- if (memHandle.HasValue)
+ var value = values.ElementAt(valueIndex);
+ var meta = metaLookup(value.Name);
+ var ortValue = ortValueExtractor(value, meta, out IDisposable memHolder);
+ if (memHolder != null)
+ {
+ cleanupList.Add(memHolder);
+ }
+ if (ortValue != null)
+ {
+ if (ortValue.IsOwned)
+ cleanupList.Add(ortValue);
+
+ result[valueIndex] = ortValue.Handle;
+ }
+ else
{
- cleanupList.Add(memHandle);
+ result[valueIndex] = IntPtr.Zero;
}
- cleanupList.Add(ortValue);
- result[inputIndex] = ortValue.Handle;
}
return result;
}
@@ -687,7 +796,8 @@ private IntPtr[] GetOrtValuesHandles(IReadOnlyCollection v
private DisposableList RunImpl(RunOptions options, IntPtr[] inputNames, IntPtr[] inputValues, IntPtr[] outputNames,
DisposableList cleanupList)
{
- var ortValues = new DisposableList(outputNames.Length);
+ cleanupList.Capacity += 1;
+ var ortValues = new DisposableList(outputNames.Length + 1);
cleanupList.Add(ortValues);
IntPtr[] outputValuesArray = new IntPtr[outputNames.Length];
@@ -717,8 +827,7 @@ IDisposableReadOnlyCollection CreateDisposableResult(L
{
for (int i = 0; i < ortValues.Count; i++)
{
- var ortValue = ortValues[i];
- result.Add(DisposableNamedOnnxValue.CreateFromOrtValue(outputNames.ElementAt(i), ortValue));
+ result.Add(DisposableNamedOnnxValue.CreateFromOrtValue(outputNames.ElementAt(i), ortValues[i]));
}
}
catch (OnnxRuntimeException)
@@ -769,12 +878,11 @@ private void Init(string modelPath, SessionOptions options,
PrePackedWeightsContainer prepackedWeightsContainer = null)
{
var envHandle = OrtEnv.Handle;
- var session = IntPtr.Zero;
-
+ IntPtr session;
if (prepackedWeightsContainer == null)
{
- NativeApiStatus.VerifySuccess(NativeMethods.OrtCreateSession(envHandle, NativeOnnxValueHelper.GetPlatformSerializedString(modelPath),
- options.Handle, out session));
+ NativeApiStatus.VerifySuccess(NativeMethods.OrtCreateSession(envHandle, NativeOnnxValueHelper.GetPlatformSerializedString(modelPath),
+ options.Handle, out session));
}
else
@@ -791,8 +899,7 @@ private void Init(byte[] modelData, SessionOptions options,
PrePackedWeightsContainer prepackedWeightsContainer = null)
{
var envHandle = OrtEnv.Handle;
- var session = IntPtr.Zero;
-
+ IntPtr session;
if (prepackedWeightsContainer == null)
{
@@ -820,48 +927,57 @@ private void InitWithSessionHandle(IntPtr session, SessionOptions options)
_nativeHandle = session;
try
{
-
// Initialize input/output metadata
- _inputMetadata = new Dictionary();
- _outputMetadata = new Dictionary();
- _overridableInitializerMetadata = new Dictionary();
// get input count
- UIntPtr inputCount = UIntPtr.Zero;
- NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionGetInputCount(_nativeHandle, out inputCount));
+ NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionGetInputCount(_nativeHandle, out UIntPtr inputCount));
// get all the input names and metadata
+ _inputMetadata = new Dictionary((int)inputCount);
+ _inputNames = new List((int)inputCount);
+
for (ulong i = 0; i < (ulong)inputCount; i++)
{
- var iname = GetInputName(i);
- _inputMetadata[iname] = GetInputMetadata(i);
+ var inputMeta = GetInputMetadata(i);
+ var iname = GetInputName(i, out byte[] utf8);
+ _inputNames.Add(iname);
+ inputMeta.ZeroTerminatedUtf8Name = utf8;
+ _inputMetadata[iname] = inputMeta;
}
// get output count
- UIntPtr outputCount = UIntPtr.Zero;
- NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionGetOutputCount(_nativeHandle, out outputCount));
+ NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionGetOutputCount(_nativeHandle, out UIntPtr outputCount));
// get all the output names and metadata
+ _outputMetadata = new Dictionary((int)outputCount);
+ _outputNames = new List((int)outputCount);
+
for (ulong i = 0; i < (ulong)outputCount; i++)
{
- _outputMetadata[GetOutputName(i)] = GetOutputMetadata(i);
+ var outputMeta = GetOutputMetadata(i);
+ var oname = GetOutputName(i, out byte[] utf8);
+ _outputNames.Add(oname);
+ outputMeta.ZeroTerminatedUtf8Name = utf8;
+ _outputMetadata[oname] = outputMeta;
}
// get overridable initializer count
- UIntPtr initilaizerCount = UIntPtr.Zero;
- NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionGetOverridableInitializerCount(_nativeHandle, out initilaizerCount));
+ NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionGetOverridableInitializerCount(_nativeHandle, out UIntPtr initilaizerCount));
+ _overridableInitializerMetadata = new Dictionary((int)initilaizerCount);
// get all the overridable initializer names and metadata
for (ulong i = 0; i < (ulong)initilaizerCount; i++)
{
- _overridableInitializerMetadata[GetOverridableInitializerName(i)] = GetOverridableInitializerMetadata(i);
+ var meta = GetOverridableInitializerMetadata(i);
+ var iname = GetOverridableInitializerName(i, out byte[] utf8);
+ meta.ZeroTerminatedUtf8Name = utf8;
+ _overridableInitializerMetadata[iname] = meta;
}
// set profiling's start time
- UIntPtr startTime = UIntPtr.Zero;
NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionGetProfilingStartTimeNs(_nativeHandle,
- out startTime));
+ out UIntPtr startTime));
_profilingStartTimeNs = (ulong)startTime;
}
- catch (OnnxRuntimeException)
+ catch (Exception)
{
if (_nativeHandle != IntPtr.Zero)
{
@@ -875,64 +991,60 @@ private void InitWithSessionHandle(IntPtr session, SessionOptions options)
}
- private string GetOutputName(ulong index)
+ private string GetOutputName(ulong index, out byte[] utf8)
{
+ string str;
var allocator = OrtAllocator.DefaultInstance;
- IntPtr nameHandle = IntPtr.Zero;
- string str = null;
NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionGetOutputName(
_nativeHandle,
(UIntPtr)index,
allocator.Pointer,
- out nameHandle));
+ out IntPtr nameHandle));
using (var ortAllocation = new OrtMemoryAllocation(allocator, nameHandle, 0))
{
- str = NativeOnnxValueHelper.StringFromNativeUtf8(nameHandle);
+ NativeOnnxValueHelper.StringAndUtf8FromNative(nameHandle, out str, out utf8);
}
return str;
}
- private string GetInputName(ulong index)
+ private string GetInputName(ulong index, out byte[] utf8)
{
- string str = null;
+ string str;
var allocator = OrtAllocator.DefaultInstance;
- IntPtr nameHandle = IntPtr.Zero;
NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionGetInputName(
_nativeHandle,
(UIntPtr)index,
allocator.Pointer,
- out nameHandle));
+ out IntPtr nameHandle));
using (var ortAllocation = new OrtMemoryAllocation(allocator, nameHandle, 0))
{
- str = NativeOnnxValueHelper.StringFromNativeUtf8(nameHandle);
+ NativeOnnxValueHelper.StringAndUtf8FromNative(nameHandle, out str, out utf8);
}
return str;
}
- private string GetOverridableInitializerName(ulong index)
+ private string GetOverridableInitializerName(ulong index, out byte[] utf8)
{
- string str = null;
+ string str;
var allocator = OrtAllocator.DefaultInstance;
- IntPtr nameHandle = IntPtr.Zero;
NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionGetOverridableInitializerName(
_nativeHandle,
(UIntPtr)index,
allocator.Pointer,
- out nameHandle));
+ out IntPtr nameHandle));
using (var ortAllocation = new OrtMemoryAllocation(allocator, nameHandle, 0))
{
- str = NativeOnnxValueHelper.StringFromNativeUtf8(nameHandle);
+ NativeOnnxValueHelper.StringAndUtf8FromNative(nameHandle, out str, out utf8);
}
return str;
}
private NodeMetadata GetInputMetadata(ulong index)
{
- IntPtr typeInfo = IntPtr.Zero;
- NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionGetInputTypeInfo(_nativeHandle, (UIntPtr)index, out typeInfo));
+ NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionGetInputTypeInfo(_nativeHandle, (UIntPtr)index, out IntPtr typeInfo));
try
{
return GetMetadataFromTypeInfo(typeInfo);
@@ -945,8 +1057,7 @@ private NodeMetadata GetInputMetadata(ulong index)
private NodeMetadata GetOutputMetadata(ulong index)
{
- IntPtr typeInfo = IntPtr.Zero;
- NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionGetOutputTypeInfo(_nativeHandle, (UIntPtr)index, out typeInfo));
+ NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionGetOutputTypeInfo(_nativeHandle, (UIntPtr)index, out IntPtr typeInfo));
try
{
return GetMetadataFromTypeInfo(typeInfo);
@@ -959,8 +1070,7 @@ private NodeMetadata GetOutputMetadata(ulong index)
private NodeMetadata GetOverridableInitializerMetadata(ulong index)
{
- IntPtr typeInfo = IntPtr.Zero;
- NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionGetOverridableInitializerTypeInfo(_nativeHandle, (UIntPtr)index, out typeInfo));
+ NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionGetOverridableInitializerTypeInfo(_nativeHandle, (UIntPtr)index, out IntPtr typeInfo));
try
{
return GetMetadataFromTypeInfo(typeInfo);
@@ -975,40 +1085,112 @@ internal static NodeMetadata GetMetadataFromTypeInfo(IntPtr typeInfo)
{
OnnxValueType valueType;
{
- IntPtr valType;
- NativeApiStatus.VerifySuccess(NativeMethods.OrtGetOnnxTypeFromTypeInfo(typeInfo, out valType));
+ NativeApiStatus.VerifySuccess(NativeMethods.OrtGetOnnxTypeFromTypeInfo(typeInfo, out IntPtr valType));
valueType = (OnnxValueType)valType;
}
- if (valueType != OnnxValueType.ONNX_TYPE_TENSOR && valueType != OnnxValueType.ONNX_TYPE_SPARSETENSOR)
+
+ switch (valueType)
{
- return new NodeMetadata(valueType, new int[] { }, new string[] { }, typeof(NamedOnnxValue));
+ case OnnxValueType.ONNX_TYPE_TENSOR:
+ case OnnxValueType.ONNX_TYPE_SPARSETENSOR:
+ return GetTensorNodeMetadata(valueType, typeInfo);
+ case OnnxValueType.ONNX_TYPE_SEQUENCE:
+ return GetSequenceMetadataFromTypeInfo(typeInfo);
+ case OnnxValueType.ONNX_TYPE_MAP:
+ return GetMapMetadataFromTypeInfo(typeInfo);
+ case OnnxValueType.ONNX_TYPE_OPTIONAL:
+ return GetOptionalMetadataFromTypeInfo(typeInfo);
}
- // This should not be released
- IntPtr tensorInfo;
- NativeApiStatus.VerifySuccess(NativeMethods.OrtCastTypeInfoToTensorInfo(typeInfo, out tensorInfo)); //(IntPtr)(int)(uint)
- // Convert the newly introduced OrtTypeInfo* to the older OrtTypeAndShapeInfo*
+ throw new OnnxRuntimeException(ErrorCode.NotImplemented, $"Value type: '{valueType}' not supported in this code");
+ }
- if (tensorInfo == IntPtr.Zero)
- return null;
+ internal static NodeMetadata GetSequenceMetadataFromTypeInfo(IntPtr typeInfo)
+ {
+ NativeApiStatus.VerifySuccess(NativeMethods.OrtCastTypeInfoToSequenceTypeInfo(typeInfo, out IntPtr sequenceTypeInfo));
+ // Casts API are broken. Always return success, but may return null for the result.
+ if (sequenceTypeInfo == IntPtr.Zero)
+ {
+ throw new OnnxRuntimeException(ErrorCode.Fail, "TypeInfo cast to SequenceTypeInfo failed. The object does not represent a sequence");
+ }
- TensorElementType type;
+ NativeApiStatus.VerifySuccess(NativeMethods.OrtGetSequenceElementType(sequenceTypeInfo, out IntPtr elementType));
+ try
{
- IntPtr el_type;
- NativeApiStatus.VerifySuccess(NativeMethods.OrtGetTensorElementType(tensorInfo, out el_type));
- type = (TensorElementType)el_type;
+ var elementMeta = GetMetadataFromTypeInfo(elementType);
+ var seqMeta = new SequenceMetadata(elementMeta);
+ return new NodeMetadata(seqMeta);
+ }
+ finally
+ {
+ NativeMethods.OrtReleaseTypeInfo(elementType);
+ }
+ }
+
+ internal static NodeMetadata GetMapMetadataFromTypeInfo(IntPtr typeInfo)
+ {
+ NativeApiStatus.VerifySuccess(NativeMethods.OrtCastTypeInfoToMapTypeInfo(typeInfo, out IntPtr mapTypeInfo));
+ // Casts API are broken. Always return success, but may return null for the result.
+ if (mapTypeInfo == IntPtr.Zero)
+ {
+ throw new OnnxRuntimeException(ErrorCode.Fail, "TypeInfo cast to MapTypeInfo failed. The object does not represent a map");
+ }
+
+ NativeApiStatus.VerifySuccess(NativeMethods.OrtGetMapKeyType(mapTypeInfo, out IntPtr keyType));
+
+ NativeApiStatus.VerifySuccess(NativeMethods.OrtGetMapValueType(mapTypeInfo, out IntPtr valueTypeInfo));
+ try
+ {
+ var valueMetadata = GetMetadataFromTypeInfo(valueTypeInfo);
+ var mapMeta = new MapMetadata((TensorElementType)keyType, valueMetadata);
+ return new NodeMetadata(mapMeta);
+ }
+ finally
+ {
+ NativeMethods.OrtReleaseTypeInfo(valueTypeInfo);
+ }
+ }
+
+ internal static NodeMetadata GetOptionalMetadataFromTypeInfo(IntPtr typeInfo)
+ {
+ // This should not be destroyed
+ NativeApiStatus.VerifySuccess(NativeMethods.OrtCastTypeInfoToOptionalTypeInfo(typeInfo, out IntPtr optTypeInfo));
+ // Casts API are broken. Always return success, but may return null for the result.
+ if (optTypeInfo == IntPtr.Zero)
+ {
+ throw new OnnxRuntimeException(ErrorCode.Fail, "TypeInfo cast to OptionalTypeInfo failed. The object does not represent a optional value");
+ }
+
+ NativeApiStatus.VerifySuccess(NativeMethods.OrtGetOptionalContainedTypeInfo(optTypeInfo, out IntPtr elementTypeInfo));
+ try
+ {
+ var elementMetadata = GetMetadataFromTypeInfo(elementTypeInfo);
+ var optMetadata = new OptionalMetadata(elementMetadata);
+ return new NodeMetadata(optMetadata);
+ }
+ finally
+ {
+ NativeMethods.OrtReleaseTypeInfo(elementTypeInfo);
}
+ }
- Type dotnetType = null;
- int width = 0;
- if (!TensorElementTypeConverter.GetTypeAndWidth(type, out dotnetType, out width))
+ internal static NodeMetadata GetTensorNodeMetadata(OnnxValueType valueType, IntPtr typeInfo)
+ {
+ // Fetch tensor type and shape from the TypeInfo
+ NativeApiStatus.VerifySuccess(NativeMethods.OrtCastTypeInfoToTensorInfo(typeInfo, out IntPtr tensorInfo)); //(IntPtr)(int)(uint)
+ // Casts API are broken. Always return success, but may return null for the result.
+ if (tensorInfo == IntPtr.Zero)
{
- throw new OnnxRuntimeException(ErrorCode.InvalidArgument,
- "Unable to query type information for data type: " + type.ToString());
+ throw new OnnxRuntimeException(ErrorCode.Fail, "TypeInfo cast to TensorTypeInfo failed. The object does not represent a tensor");
}
- UIntPtr numDimensions;
- NativeApiStatus.VerifySuccess(NativeMethods.OrtGetDimensionsCount(tensorInfo, out numDimensions));
+ TensorElementType type;
+ {
+ NativeApiStatus.VerifySuccess(NativeMethods.OrtGetTensorElementType(tensorInfo, out IntPtr el_type));
+ type = (TensorElementType)el_type;
+ }
+
+ NativeApiStatus.VerifySuccess(NativeMethods.OrtGetDimensionsCount(tensorInfo, out UIntPtr numDimensions));
long[] dimensions = new long[(int)numDimensions];
NativeApiStatus.VerifySuccess(NativeMethods.OrtGetDimensions(tensorInfo, dimensions, numDimensions));
@@ -1028,7 +1210,8 @@ internal static NodeMetadata GetMetadataFromTypeInfo(IntPtr typeInfo)
symbolicDimensions[i] = NativeOnnxValueHelper.StringFromNativeUtf8(dimensionNamePtrs[i]);
}
- return new NodeMetadata(valueType, intDimensions, symbolicDimensions, dotnetType);
+ var tensorTypeAndShape = new TensorTypeAndShape(type, intDimensions, symbolicDimensions);
+ return new NodeMetadata(valueType, tensorTypeAndShape);
}
///
@@ -1105,23 +1288,27 @@ protected virtual void Dispose(bool disposing)
///
- /// Resembles type and shape information of session-graph nodes, used for communicating the shape/type of input/output nodes
+ /// Represents tensor element type and its shapes
///
- public class NodeMetadata
+ public class TensorTypeAndShape
{
- internal NodeMetadata(OnnxValueType onnxValueType, int[] dimensions, string[] symbolicDimensions, Type type)
+ internal TensorTypeAndShape(TensorElementType elementType, int[] dimensions, string[] symbolicDimensions)
{
- OnnxValueType = onnxValueType;
+ ElementTypeInfo = TensorBase.GetElementTypeInfo(elementType);
+ if (ElementTypeInfo == null)
+ {
+ throw new OnnxRuntimeException(ErrorCode.InvalidArgument, "Unregistered TensorElementType value of: " + elementType.ToString());
+ }
+ ElementDataType = elementType;
Dimensions = dimensions;
SymbolicDimensions = symbolicDimensions;
- ElementType = type;
}
///
- /// Type value of the node
+ /// Tensor Element type
///
- /// A value of OnnxValueType enum
- public OnnxValueType OnnxValueType { get; }
+ /// TensorElementType enum
+ public TensorElementType ElementDataType { get; }
///
/// Shape
@@ -1136,10 +1323,259 @@ internal NodeMetadata(OnnxValueType onnxValueType, int[] dimensions, string[] sy
public string[] SymbolicDimensions { get; }
///
- /// .NET type that corresponds to this Node.
+ /// Tensor element metadata
+ ///
+ public TensorElementTypeInfo ElementTypeInfo { get; }
+ }
+
+ ///
+ /// Represents sequnce metdata
+ ///
+ public class SequenceMetadata
+ {
+ ///
+ /// __ctor
+ ///
+ ///
+ internal SequenceMetadata(NodeMetadata elementData)
+ {
+ ElementMeta = elementData;
+ }
+ ///
+ /// Element Metatada, recursive definition with a Tensor being a base case
+ /// may contain maps, tensors and other sequences
+ ///
+ public NodeMetadata ElementMeta { get; }
+ }
+
+ ///
+ /// The class contains metadata for an optional input/output
+ ///
+ public class OptionalMetadata
+ {
+ ///
+ /// __ctor
+ ///
+ ///
+ internal OptionalMetadata(NodeMetadata elementData)
+ {
+ ElementMeta = elementData;
+ }
+
+ ///
+ /// Element Metatada, recursive definition with a Tensor being a base case
+ /// may contain maps, tensors and sequences
+ ///
+ public NodeMetadata ElementMeta { get; }
+ }
+
+ ///
+ /// Represents Map MetaData.
+ /// Key is always a tensor denoted by an element type
+ /// with value type being a recursive structure that may
+ /// contain other maps, sequences or tensors.
+ ///
+ public class MapMetadata
+ {
+ internal MapMetadata(TensorElementType keyDataType, NodeMetadata valueMetadata)
+ {
+ KeyDataType = keyDataType;
+ ValueMetadata = valueMetadata;
+ }
+
+ ///
+ /// Key tensor data type
+ ///
+ /// A value of TensorElementType enum
+ public TensorElementType KeyDataType { get; }
+
+ ///
+ /// Value metadata
+ ///
+ /// /// Instance of Nodemetadata for the value of the map
+ public NodeMetadata ValueMetadata { get; }
+ }
+
+ ///
+ /// Resembles type and shape information of session-graph nodes, used for communicating the shape/type of input/output nodes
+ ///
+ public class NodeMetadata
+ {
+ private readonly Object _metadata;
+ ///
+ /// Constructs NodeMetadata for tensor
+ ///
+ /// either ONNX_TYPE_TENSOR or ONNX_TYPE_SPARSETENSOR
+ /// Tensor type and shape information
+ internal NodeMetadata(OnnxValueType onnxValueType, TensorTypeAndShape typeAndShape)
+ {
+ OnnxValueType = onnxValueType;
+ CheckTensor();
+ _metadata = typeAndShape;
+ }
+
+ ///
+ /// __ctor for map metadata
+ ///
+ ///
+ internal NodeMetadata(MapMetadata mapMetadata)
+ {
+ OnnxValueType = OnnxValueType.ONNX_TYPE_MAP;
+ _metadata = mapMetadata;
+ }
+
+ ///
+ /// __ctor for sequence metadata
+ ///
+ ///
+ internal NodeMetadata(SequenceMetadata sequenceMetadata)
+ {
+ OnnxValueType = OnnxValueType.ONNX_TYPE_SEQUENCE;
+ _metadata = sequenceMetadata;
+ }
+
+ ///
+ /// __ctor
+ ///
+ ///
+ internal NodeMetadata(OptionalMetadata optMetadata)
+ {
+ OnnxValueType = OnnxValueType.ONNX_TYPE_OPTIONAL;
+ _metadata = optMetadata;
+ }
+
+ private void CheckTensor()
+ {
+ if (!IsTensor)
+ {
+ throw new OnnxRuntimeException(ErrorCode.Fail, "OnnxValueType must either be a tensor or sparse tensor");
+ }
+ }
+
+ ///
+ /// Retrieves MapMetadata, valid only if this node represents a Map.
+ ///
+ ///
+ /// when the instance does not contain map metadata
+ public MapMetadata AsMapMetadata()
+ {
+ if (OnnxValueType != OnnxValueType.ONNX_TYPE_MAP)
+ {
+ throw new OnnxRuntimeException(ErrorCode.Fail, "Instance does not contain Map metadata");
+ }
+ return _metadata as MapMetadata;
+ }
+
+ ///
+ /// Retrieves SequenceMetadata, valid only if this node represents a Sequence
+ ///
+ ///
+ /// when the instance does not contain sequence metadata
+ public SequenceMetadata AsSequenceMetadata()
+ {
+ if (OnnxValueType != OnnxValueType.ONNX_TYPE_SEQUENCE)
+ {
+ throw new OnnxRuntimeException(ErrorCode.Fail, "Instance does not contain Sequence metadata");
+ }
+ return _metadata as SequenceMetadata;
+ }
+
+ ///
+ /// Retrieves Optional type metadata, valid if this node is optional
+ /// Optional metadata is nothing more than just a container for all the usual
+ /// element types.
+ ///
+ ///
+ ///
+ public OptionalMetadata AsOptionalMetadata()
+ {
+ if (OnnxValueType != OnnxValueType.ONNX_TYPE_OPTIONAL)
+ {
+ throw new OnnxRuntimeException(ErrorCode.Fail, "Instance does not contain Optional metadata");
+ }
+ return _metadata as OptionalMetadata;
+ }
+
+ ///
+ /// Type value of the node
+ ///
+ /// A value of OnnxValueType enum
+ public OnnxValueType OnnxValueType { get; }
+
+ ///
+ /// Zero terminated UTF-8 name of the input/output
+ /// Present only on the top-level instance
+ /// metadata dictionary entries.
+ ///
+ /// Used to avoid utf8 conversions on every run and associated allocations
+ ///
+ internal byte[] ZeroTerminatedUtf8Name { get; set; }
+
+ ///
+ /// Tensor shape valid only if this is a Tensor.
+ /// Preserved for API compatibility
+ ///
+ /// Array of dimensions
+ public int[] Dimensions
+ {
+ get
+ {
+ CheckTensor();
+ return (_metadata as TensorTypeAndShape).Dimensions;
+ }
+ }
+
+ ///
+ /// Symbolic dimensions valid only if this is a Tensor.
+ /// Preserved for API compatibility
+ ///
+ /// Array of symbolic dimensions if present.
+ public string[] SymbolicDimensions
+ {
+ get
+ {
+ CheckTensor();
+ return (_metadata as TensorTypeAndShape).SymbolicDimensions;
+ }
+ }
+
+ ///
+ /// .NET type that corresponds to the primitive Tensor data type.
+ /// Valid only if this is a Tensor.
///
/// System.Type
- public System.Type ElementType { get; }
+ public System.Type ElementType
+ {
+ get
+ {
+ CheckTensor();
+ return (_metadata as TensorTypeAndShape).ElementTypeInfo.TensorType;
+ }
+ }
+
+ ///
+ /// Tensor Element Type. Valid if tensor
+ ///
+ public TensorElementType ElementDataType
+ {
+ get
+ {
+ CheckTensor();
+ return (_metadata as TensorTypeAndShape).ElementDataType;
+ }
+ }
+
+ ///
+ /// Convinience method to check for string
+ ///
+ public bool IsString
+ {
+ get
+ {
+ CheckTensor();
+ return (_metadata as TensorTypeAndShape).ElementTypeInfo.IsString;
+ }
+ }
///
/// Whether it is a Tensor
@@ -1149,7 +1585,7 @@ public bool IsTensor
{
get
{
- return true; // currently only Tensor nodes are supported
+ return (OnnxValueType == OnnxValueType.ONNX_TYPE_TENSOR) || (OnnxValueType == OnnxValueType.ONNX_TYPE_SPARSETENSOR);
}
}
}
diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/ManagedProjections.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/ManagedProjections.shared.cs
new file mode 100644
index 0000000000000..664ba21cfd1bb
--- /dev/null
+++ b/csharp/src/Microsoft.ML.OnnxRuntime/ManagedProjections.shared.cs
@@ -0,0 +1,277 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+using Microsoft.ML.OnnxRuntime.Tensors;
+using System;
+using System.Buffers;
+using System.Collections.Generic;
+using System.Diagnostics;
+using System.Linq;
+
+namespace Microsoft.ML.OnnxRuntime
+{
+ ///
+ /// The class helps to feed the NamedOnnxValue as inference input.
+ /// It projects managed classes to OrtValues so they can be consumed
+ /// by the native onnxruntime library. if possible, it will avoid copying data.
+ /// The NamedOnnxValue can be a tensor, sequence or map.
+ /// For recursive structures, create nested NamedOnnxValue instances.
+ /// For example, a sequence instance would contain a list of NamedOnnxValue instances
+ /// that in turn may represent tensors or other ONNX values.
+ ///
+ internal class ManagedTypeProjection : IDisposable
+ {
+ readonly DisposableList _disposables;
+ readonly OrtValue _ortValue;
+ bool _disposed = false;
+
+ ///
+ /// Provides access to non-owning instance of OrtValue
+ ///
+ /// Provides access to the OrtValue to be used as input
+ internal OrtValue Value { get { return new OrtValue(_ortValue.Handle, false); } }
+
+ ///
+ /// Constructor to create an input OrtValue projection from managed data
+ ///
+ ///
+ ///
+ ///
+ internal ManagedTypeProjection(NamedOnnxValue namedOnnxValue, NodeMetadata metadata)
+ {
+ int requiredCapacity = 32;
+ var disposables = new DisposableList(requiredCapacity);
+ try
+ {
+ _ortValue = CreateDispatchProjection(namedOnnxValue, metadata, disposables);
+ }
+ catch (Exception)
+ {
+ disposables.Dispose();
+ throw;
+ }
+ _disposables = disposables;
+ }
+
+ ///
+ /// Dispatches the creation of the projection
+ ///
+ ///
+ ///
+ ///
+ ///
+ private OrtValue CreateDispatchProjection(NamedOnnxValue namedOnnxValue, NodeMetadata metadata, DisposableList disposables)
+ {
+ OrtValue result;
+
+ NodeMetadata meta = metadata;
+ // Use element meta to create types
+ if (metadata.OnnxValueType == OnnxValueType.ONNX_TYPE_OPTIONAL)
+ {
+ meta = metadata.AsOptionalMetadata().ElementMeta;
+ }
+
+ if (namedOnnxValue.ValueType != meta.OnnxValueType)
+ {
+ throw new OnnxRuntimeException(ErrorCode.RuntimeException,
+ $"NamedOnnxValue: {namedOnnxValue.Name} has value type: {namedOnnxValue.ValueType}" +
+ $" expected: {meta.OnnxValueType} after optional type adjustment");
+ }
+
+ switch (namedOnnxValue.ValueType)
+ {
+ case OnnxValueType.ONNX_TYPE_TENSOR:
+ result = CreateTensorProjection(namedOnnxValue, meta, disposables);
+ break;
+ case OnnxValueType.ONNX_TYPE_SEQUENCE:
+ result = CreateSequenceProjection(namedOnnxValue, meta, disposables);
+ break;
+ case OnnxValueType.ONNX_TYPE_MAP:
+ result = CreateMapProjection(namedOnnxValue, meta, disposables);
+ break;
+ default:
+ throw new OnnxRuntimeException(ErrorCode.InvalidArgument, "ManagedTypeProjection can only project tensors, sequences, maps and optional types");
+ }
+ return result;
+ }
+
+ ///
+ /// The function creates OrtValue objects for each element of the sequence
+ /// and then creates an OrtValue for the whole sequence.
+ ///
+ /// NamedOnnxValue containing a IEnumeralbe
+ /// sequence metadata
+ /// cleanup list
+ ///
+ ///
+ private OrtValue CreateSequenceProjection(NamedOnnxValue namedOnnxValue, NodeMetadata metadata, DisposableList disposables)
+ {
+ OrtValue result = null;
+ var elementMeta = metadata.AsSequenceMetadata().ElementMeta;
+ var elementOnnxValue = elementMeta.OnnxValueType;
+ var seqContainer = namedOnnxValue.AsEnumerable();
+
+ if (seqContainer is null)
+ {
+ throw new OnnxRuntimeException(ErrorCode.InvalidArgument,
+ $"NamedOnnxValue: {namedOnnxValue.Name} sequence does not contain NamedOnnxValue elements");
+ }
+
+ int capacity = 0;
+
+ if (seqContainer is ICollection)
+ {
+ capacity = ((ICollection)seqContainer).Count;
+ }
+
+ // Record all the ortValues belonging to the sequence locally
+ var sequenceOrtValues = new List(capacity);
+ foreach (var element in seqContainer)
+ {
+ if (elementOnnxValue != element.ValueType)
+ {
+ throw new OnnxRuntimeException(ErrorCode.InvalidArgument,
+ $"NamedOnnxValue: {namedOnnxValue.Name} sequence element expected to be {elementOnnxValue}, received {element.ValueType}");
+ }
+
+ sequenceOrtValues.Add(CreateDispatchProjection(element, elementMeta, disposables));
+ }
+
+ IntPtr[] ortValHandles = new IntPtr[sequenceOrtValues.Count];
+ for (int i = 0; i < sequenceOrtValues.Count; i++)
+ {
+ ortValHandles[i] = sequenceOrtValues[i].Handle;
+ }
+
+ NativeApiStatus.VerifySuccess(NativeMethods.OrtCreateValue(ortValHandles,
+ (UIntPtr)sequenceOrtValues.Count, (IntPtr)OnnxValueType.ONNX_TYPE_SEQUENCE, out IntPtr sequenceHandle));
+ result = new OrtValue(sequenceHandle);
+ disposables.Add(result);
+
+ return result;
+ }
+
+ ///
+ /// Creates map projection. Since we support only primitive types in maps
+ /// we map two tensors (keys and values)
+ ///
+ ///
+ ///
+ ///
+ /// OrtValue
+ ///
+ private OrtValue CreateMapProjection(NamedOnnxValue node, NodeMetadata elementMeta, DisposableList disposables)
+ {
+ OrtValue result = null;
+ var mapMeta = elementMeta.AsMapMetadata();
+ Debug.Assert(mapMeta != null);
+ // Maps currently support only primitive types expressed as two parallel tensors and not nested Sequences or Maps
+
+ var mapValuesMeta = mapMeta.ValueMetadata;
+ if (mapValuesMeta.OnnxValueType != OnnxValueType.ONNX_TYPE_TENSOR)
+ {
+ throw new OnnxRuntimeException(ErrorCode.InvalidArgument,
+ $"Node: {node.Name} onnxruntime only supports maps with primitive types values");
+ }
+
+
+ var keys = node.GetDictionaryKeys();
+ var ortValueKeys = OrtValue.CreateFromTensorObject(keys,
+ out MemoryHandle? memoryHandleKeys, out TensorElementType elementTypeKeys);
+ disposables.Add(ortValueKeys);
+
+ if (memoryHandleKeys.HasValue)
+ {
+ disposables.Add(memoryHandleKeys);
+ }
+
+ if (elementTypeKeys != mapMeta.KeyDataType)
+ {
+ throw new OnnxRuntimeException(ErrorCode.InvalidArgument,
+ $"Map key data type supplied: {elementTypeKeys} metadata expected: {mapMeta.KeyDataType}");
+ }
+
+ var values = node.GetDictionaryValues();
+ var ortValueValues = OrtValue.CreateFromTensorObject(values,
+ out MemoryHandle? memoryHandleValues, out TensorElementType elementTypeValues);
+
+ disposables.Add(ortValueValues);
+ if (memoryHandleValues.HasValue)
+ {
+ disposables.Add(memoryHandleValues);
+ }
+
+ if (elementTypeValues != mapValuesMeta.ElementDataType)
+ {
+ throw new OnnxRuntimeException(ErrorCode.InvalidArgument,
+ $"Map value data type supplied: {elementTypeValues} metadata expected: {mapValuesMeta.ElementDataType}");
+ }
+
+ // Create Map OrtValue
+ IntPtr[] ortValHandles = { ortValueKeys.Handle, ortValueValues.Handle };
+ NativeApiStatus.VerifySuccess(NativeMethods.OrtCreateValue(ortValHandles, (UIntPtr)2,
+ (IntPtr)OnnxValueType.ONNX_TYPE_MAP, out IntPtr ortValueMap));
+ result = new OrtValue(ortValueMap);
+ disposables.Add(result);
+ return result;
+ }
+
+
+ ///
+ /// This pins memory that is contained within DenseTensor.
+ ///
+ /// NodeOnnxValue containing DenseTensor
+ ///
+ /// cleanup list
+ ///
+ ///
+ private OrtValue CreateTensorProjection(NamedOnnxValue node, NodeMetadata elementMeta, DisposableList disposables)
+ {
+ var ortValue = OrtValue.CreateFromTensorObject(node.Value,
+ out MemoryHandle? memoryHandle, out TensorElementType elementType);
+ disposables.Add(ortValue);
+
+ if (memoryHandle.HasValue)
+ {
+ disposables.Add(memoryHandle);
+ }
+
+ if (elementType != elementMeta.ElementDataType)
+ {
+ throw new OnnxRuntimeException(ErrorCode.InvalidArgument,
+ $"Tensor element data type discovered: {elementType} metadata expected: {elementMeta.ElementDataType}");
+ }
+
+ return ortValue;
+ }
+
+ #region IDisposable
+ ///
+ /// IDisposable implementation
+ ///
+ /// true if invoked by Dispose()
+ protected virtual void Dispose(bool disposing)
+ {
+ if (_disposed)
+ {
+ return;
+ }
+
+ // dispose managed state (managed objects).
+ if (disposing)
+ {
+ _disposables.Dispose();
+ }
+ _disposed = true;
+ }
+
+
+ public void Dispose()
+ {
+ Dispose(true);
+ }
+
+ #endregion IDisposable
+ }
+}
+
diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/NamedOnnxValue.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/NamedOnnxValue.shared.cs
index 233e5fa9af87e..5f08daf73806a 100644
--- a/csharp/src/Microsoft.ML.OnnxRuntime/NamedOnnxValue.shared.cs
+++ b/csharp/src/Microsoft.ML.OnnxRuntime/NamedOnnxValue.shared.cs
@@ -5,38 +5,125 @@
using System;
using System.Buffers;
using System.Collections.Generic;
+using System.Diagnostics;
+using System.Linq;
namespace Microsoft.ML.OnnxRuntime
{
///
- /// The class associates a name with an Object. Currently it supports Tensor
- /// as possible objects. The name of the class is a misnomer, it does not hold any
- /// Onnx values.
+ /// The class holds keys and values for the dictionary
+ /// in a for of two DenseTensors. The class is used to avoid
+ /// data copy and make these available to the native code.
+ /// Strings require special handling.
+ ///
+ internal class MapHelper
+ {
+ internal MapHelper(object keys, object values)
+ {
+ Keys = keys;
+ Values = values;
+ }
+ internal Object Keys { get; } // DenseTensor
+ internal Object Values { get; } // DenseTensor
+ }
+
+ ///
+ /// The class associates a name with an Object.
+ /// The name of the class is a misnomer, it does not hold any Onnx values,
+ /// just managed representation of them.
+ ///
+ /// The class is currently used as both inputs and outputs. Because it is non-
+ /// disposable, it can not hold on to any native objects.
+ ///
+ /// When used as input, we temporarily create OrtValues that map managed inputs
+ /// directly. Thus we are able to avoid copying.
+ ///
+ /// For outputs, tensor buffers works the same as input, providing it matches
+ /// the expected output shape. For other types (maps and sequences) we create a copy of the data.
+ /// This is because, the class is not Disposable and it is a public interface, thus it can not own
+ /// the underlying OrtValues that must be destroyed before Run() returns.
+ ///
+ /// To avoid data copying on output, use DisposableNamedOnnxValue class that is returned from Run() methods.
+ /// This provides access to the native memory and avoids copying.
+ ///
+ /// It is a recursive structure that may contain Tensors (base case)
+ /// Other sequences and maps. Although the OnnxValueType is exposed,
+ /// the caller is supposed to know the actual data type contained.
+ ///
+ /// The convention is that for tensors, it would contain a DenseTensor instance or
+ /// anything derived from Tensor.
+ ///
+ /// For sequences, it would contain a IList where T is an instance of NamedOnnxValue that
+ /// would contain a tensor or another type.
+ ///
+ /// For Maps, it would contain a IDictionary where K,V are primitive types or strings.
+ ///
///
public class NamedOnnxValue
{
///
/// Managed Tensor, Dictionary or IList
///
- protected Object _value;
+ private Object _value;
///
/// Name of the instance, model input/output
///
- protected string _name;
+ private string _name;
+
+ private MapHelper _mapHelper; // used for maps, otherwise null
///
/// Constructs an instance of NamedOnnxValue and represents
- /// a model input to an inference session. It also represents a modle output
- /// when serves as a base for DisposablenamedOnnxvalue
+ /// a model input to an inference session.
///
/// input/output name
/// Object that may be a tensor, Dictionary, IList
+ [Obsolete("Use constructors with valueType or static factory methods")]
protected NamedOnnxValue(string name, Object value)
{
_name = name;
_value = value;
+ ValueType = OnnxValueType.ONNX_TYPE_UNKNOWN;
}
+ ///
+ /// Constructs an instance that contains a tensor, sequence or optional type.
+ ///
+ ///
+ ///
+ ///
+ internal NamedOnnxValue(string name, Object value, OnnxValueType valueType)
+ {
+ _name = name;
+ _value = value;
+ ValueType = valueType;
+
+ if (valueType == OnnxValueType.ONNX_TYPE_MAP)
+ {
+ throw new OnnxRuntimeException(ErrorCode.InvalidArgument, "Use another __ctor for maps");
+ }
+ }
+
+ ///
+ /// Use this to construct maps
+ ///
+ ///
+ ///
+ ///
+ internal NamedOnnxValue(string name, Object value, MapHelper helper)
+ {
+ _name = name;
+ _value = value;
+ ValueType = OnnxValueType.ONNX_TYPE_MAP;
+ _mapHelper = helper;
+ }
+
+ ///
+ /// Onnx Value Type if known. In general, NamedOnnxValue is able to contain
+ /// arbitrary objects. Please, follow the convention described in the class doc.
+ ///
+ public OnnxValueType ValueType { get; internal set; }
+
///
/// This is a factory method that instantiates NamedOnnxValue
/// and associated name with an instance of a Tensor
@@ -47,7 +134,40 @@ protected NamedOnnxValue(string name, Object value)
///
public static NamedOnnxValue CreateFromTensor(string name, Tensor value)
{
- return new NamedOnnxValue(name, value);
+ return new NamedOnnxValue(name, value, OnnxValueType.ONNX_TYPE_TENSOR);
+ }
+
+ ///
+ /// This is a factory method that instantiates NamedOnnxValue.
+ /// It would contain a sequence of elements
+ ///
+ ///
+ ///
+ ///
+ public static NamedOnnxValue CreateFromSequence(string name, IEnumerable value)
+ {
+ return new NamedOnnxValue(name, value, OnnxValueType.ONNX_TYPE_SEQUENCE);
+ }
+
+ ///
+ /// Instantiates NamedOnnxValue that contains IDictionary
+ ///
+ /// Keys type
+ /// Values type
+ ///
+ ///
+ /// new instance of NamedOnnxValue
+ public static NamedOnnxValue CreateFromMap(string name, IDictionary value)
+ {
+ // The order in which Keys and Values are unspecified,
+ // but it is guaranteed to be the same order
+ // These tensors are 1-D
+ var keysMemory = new Memory(value.Keys.ToArray());
+ var keysTensor = new DenseTensor(keysMemory, new int[1] { keysMemory.Length });
+
+ var valuesMemory = new Memory(value.Values.ToArray());
+ var valuesTensor = new DenseTensor(valuesMemory, new int[1] { valuesMemory.Length });
+ return new NamedOnnxValue(name, value, new MapHelper(keysTensor, valuesTensor));
}
///
@@ -73,6 +193,8 @@ public Tensor AsTensor()
///
/// Try-get value as an Enumerable<T>.
+ /// T is usually a NamedOnnxValue instance that may contain
+ /// Tensors, Sequences, Maps or optional types
///
/// Type
/// Enumerable object if contained value is a Enumerable. Null otherwise
@@ -85,8 +207,8 @@ public IEnumerable AsEnumerable()
///
/// Try-get value as an Dictionary<K,V>.
///
- /// Key type
- /// Value type
+ /// Key type currently primitive type only
+ /// Value type, currently primitive type only
/// Dictionary object if contained value is a Dictionary. Null otherwise
public IDictionary AsDictionary()
{
@@ -94,15 +216,88 @@ public IDictionary AsDictionary()
}
///
- /// Pin the underlying memory and create an instance of OrtValue
+ /// Pin the underlying memory and create an instance of OrtValue containing a tensor
/// based on the pinned managed memory. The caller is responsible for Disposing
/// both OrtValue and pinnedMemoryHandle
///
/// dispose after returned OrtValus is disposed
/// an instance of OrtValue. The lifespan of OrtValue must overlap pinnedMemoryHandle
- internal virtual OrtValue ToOrtValue(out MemoryHandle? pinnedMemoryHandle)
+ internal virtual OrtValue InputToOrtValue(NodeMetadata metadata, out IDisposable memoryOwner)
+ {
+ var projection = new ManagedTypeProjection(this, metadata);
+ memoryOwner = projection;
+ return projection.Value;
+ }
+
+ ///
+ /// Produces an output value for outputs. This produces an output value
+ /// only for tensors or optional types that can contain a tensor.
+ /// For all other Onnx value types, this method throws. Use Run() overloads
+ /// that return DisposableNamedOnnxValue to get access to all Onnx value types
+ /// that may be returned as output.
+ ///
+ ///
+ ///
+ ///
+ internal virtual OrtValue OutputToOrtValue(NodeMetadata metadata, out IDisposable memoryOwner)
{
- return OrtValue.CreateFromTensorObject(_value, out pinnedMemoryHandle, out TensorElementType elementType);
+ // For NamedOnnxValue for output we only allow to produce OrtValue for tensors
+ // or optional type that may contain a tensor
+ if (metadata.OnnxValueType == OnnxValueType.ONNX_TYPE_TENSOR)
+ {
+ var projection = new ManagedTypeProjection(this, metadata);
+ memoryOwner = projection;
+ return projection.Value;
+ }
+
+ if (metadata.OnnxValueType == OnnxValueType.ONNX_TYPE_OPTIONAL)
+ {
+ var meta = metadata.AsOptionalMetadata().ElementMeta;
+ if (meta.OnnxValueType == OnnxValueType.ONNX_TYPE_TENSOR)
+ {
+ var projection = new ManagedTypeProjection(this, meta);
+ memoryOwner = projection;
+ return projection.Value;
+ }
+ }
+
+ throw new OnnxRuntimeException(ErrorCode.NotImplemented,
+ $"Can not create output OrtValue for NamedOnnxValue '{metadata.OnnxValueType}' type." +
+ $" Only tensors can be pre-allocated for outputs " +
+ $" Use Run() overloads that return DisposableNamedOnnxValue to get access to all Onnx value types that may be returned as output.");
+ }
+
+ ///
+ /// This method is used internally to feed dictionary keys
+ /// to create an OrtValue for map keys
+ ///
+ ///
+ /// DenseTensor"
+ internal Object GetDictionaryKeys()
+ {
+ if (ValueType != OnnxValueType.ONNX_TYPE_MAP)
+ {
+ throw new OnnxRuntimeException(ErrorCode.Fail, "This NamedOnnxValue instance does not contain a dictionary");
+ }
+
+ Debug.Assert(_mapHelper != null);
+ return _mapHelper.Keys;
+ }
+
+ ///
+ ///
+ ///
+ ///
+ /// DenseTensor"
+ internal Object GetDictionaryValues()
+ {
+ if (ValueType != OnnxValueType.ONNX_TYPE_MAP)
+ {
+ throw new OnnxRuntimeException(ErrorCode.Fail, "This NamedOnnxValue instance does not contain a dictionary");
+ }
+
+ Debug.Assert(_mapHelper != null);
+ return _mapHelper.Values;
}
// may expose different types of getters in future
diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs
index 57ae4c2905316..c92db2afd4845 100644
--- a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs
+++ b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs
@@ -261,6 +261,31 @@ public struct OrtApi
public IntPtr ReleaseCANNProviderOptions;
public IntPtr MemoryInfoGetDeviceType;
public IntPtr UpdateEnvWithCustomLogLevel;
+ public IntPtr SetGlobalIntraOpThreadAffinity;
+ public IntPtr RegisterCustomOpsLibrary_V2;
+ public IntPtr RegisterCustomOpsUsingFunction;
+ public IntPtr KernelInfo_GetInputCount;
+ public IntPtr KernelInfo_GetOutputCount;
+ public IntPtr KernelInfo_GetInputName;
+ public IntPtr KernelInfo_GetOutputName;
+ public IntPtr KernelInfo_GetInputTypeInfo;
+ public IntPtr KernelInfo_GetOutputTypeInfo;
+ public IntPtr KernelInfoGetAttribute_tensor;
+ public IntPtr HasSessionConfigEntry;
+ public IntPtr GetSessionConfigEntry;
+ public IntPtr SessionOptionsAppendExecutionProvider_Dnnl;
+ public IntPtr CreateDnnlProviderOptions;
+ public IntPtr UpdateDnnlProviderOptions;
+ public IntPtr GetDnnlProviderOptionsAsString;
+ public IntPtr ReleaseDnnlProviderOptions;
+ public IntPtr KernelInfo_GetNodeName;
+ public IntPtr KernelInfo_GetLogger;
+ public IntPtr KernelContext_GetLogger;
+ public IntPtr Logger_LogMessage;
+ public IntPtr Logger_GetLoggingSeverityLevel;
+ public IntPtr KernelInfoGetConstantInput_tensor;
+ public IntPtr CastTypeInfoToOptionalTypeInfo;
+ public IntPtr GetOptionalContainedTypeInfo;
}
internal static class NativeMethods
@@ -387,9 +412,10 @@ static NativeMethods()
OrtSetLanguageProjection = (DOrtSetLanguageProjection)Marshal.GetDelegateForFunctionPointer(api_.SetLanguageProjection, typeof(DOrtSetLanguageProjection));
OrtGetValue = (DOrtGetValue)Marshal.GetDelegateForFunctionPointer(api_.GetValue, typeof(DOrtGetValue));
+ OrtGetValueCount = (DOrtGetValueCount)Marshal.GetDelegateForFunctionPointer(api_.GetValueCount, typeof(DOrtGetValueCount));
+ OrtCreateValue = (DOrtCreateValue)Marshal.GetDelegateForFunctionPointer(api_.CreateValue, typeof(DOrtCreateValue));
OrtGetValueType = (DOrtGetValueType)Marshal.GetDelegateForFunctionPointer(api_.GetValueType, typeof(DOrtGetValueType));
OrtGetOnnxTypeFromTypeInfo = (DOrtGetOnnxTypeFromTypeInfo)Marshal.GetDelegateForFunctionPointer(api_.GetOnnxTypeFromTypeInfo, typeof(DOrtGetOnnxTypeFromTypeInfo));
- OrtGetValueCount = (DOrtGetValueCount)Marshal.GetDelegateForFunctionPointer(api_.GetValueCount, typeof(DOrtGetValueCount));
OrtGetTypeInfo = (DOrtGetTypeInfo)Marshal.GetDelegateForFunctionPointer(api_.GetTypeInfo, typeof(DOrtGetTypeInfo));
OrtCreateTensorAsOrtValue = (DOrtCreateTensorAsOrtValue)Marshal.GetDelegateForFunctionPointer(api_.CreateTensorAsOrtValue, typeof(DOrtCreateTensorAsOrtValue));
OrtCreateTensorWithDataAsOrtValue = (DOrtCreateTensorWithDataAsOrtValue)Marshal.GetDelegateForFunctionPointer(api_.CreateTensorWithDataAsOrtValue, typeof(DOrtCreateTensorWithDataAsOrtValue));
@@ -405,6 +431,16 @@ static NativeMethods()
OrtGetDimensions = (DOrtGetDimensions)Marshal.GetDelegateForFunctionPointer(api_.GetDimensions, typeof(DOrtGetDimensions));
OrtGetSymbolicDimensions = (DOrtGetSymbolicDimensions)Marshal.GetDelegateForFunctionPointer(api_.GetSymbolicDimensions, typeof(DOrtGetSymbolicDimensions));
OrtGetTensorShapeElementCount = (DOrtGetTensorShapeElementCount)Marshal.GetDelegateForFunctionPointer(api_.GetTensorShapeElementCount, typeof(DOrtGetTensorShapeElementCount));
+ // MapTypeInfo
+ OrtGetMapKeyType = (DGetMapKeyType)Marshal.GetDelegateForFunctionPointer(api_.GetMapKeyType, typeof(DGetMapKeyType));
+ OrtCastTypeInfoToMapTypeInfo = (DCastTypeInfoToMapTypeInfo)Marshal.GetDelegateForFunctionPointer(api_.CastTypeInfoToMapTypeInfo, typeof(DCastTypeInfoToMapTypeInfo));
+ OrtGetMapValueType = (DGetMapValueType)Marshal.GetDelegateForFunctionPointer(api_.GetMapValueType, typeof(DGetMapValueType));
+ // SequenceTypeInfo
+ OrtCastTypeInfoToSequenceTypeInfo = (DCastTypeInfoToSequenceTypeInfo)Marshal.GetDelegateForFunctionPointer(api_.CastTypeInfoToSequenceTypeInfo, typeof(DCastTypeInfoToSequenceTypeInfo));
+ OrtGetSequenceElementType = (DGetSequenceElementType)Marshal.GetDelegateForFunctionPointer(api_.GetSequenceElementType, typeof(DGetSequenceElementType));
+ // Optional Type info
+ OrtCastTypeInfoToOptionalTypeInfo = (DOrtCastTypeInfoToOptionalTypeInfo)Marshal.GetDelegateForFunctionPointer(api_.CastTypeInfoToOptionalTypeInfo, typeof(DOrtCastTypeInfoToOptionalTypeInfo));
+ OrtGetOptionalContainedTypeInfo = (DGetOptionalContainedTypeInfo)Marshal.GetDelegateForFunctionPointer(api_.GetOptionalContainedTypeInfo, typeof(DGetOptionalContainedTypeInfo));
OrtReleaseValue = (DOrtReleaseValue)Marshal.GetDelegateForFunctionPointer(api_.ReleaseValue, typeof(DOrtReleaseValue));
OrtSessionGetModelMetadata = (DOrtSessionGetModelMetadata)Marshal.GetDelegateForFunctionPointer(api_.SessionGetModelMetadata, typeof(DOrtSessionGetModelMetadata));
@@ -1571,6 +1607,12 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca
public static DOrtGetValueCount OrtGetValueCount;
+ [UnmanagedFunctionPointer(CallingConvention.Winapi)]
+ public delegate IntPtr/*(OrtStatus*)*/ DOrtCreateValue(IntPtr[] /* const OrtValue* const* in */ values,
+ UIntPtr /* size_t */ num_values, IntPtr /* (OnnxValueType */ onnxValueType, out IntPtr /* OrtValue** */ ortValue);
+
+ public static DOrtCreateValue OrtCreateValue;
+
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate IntPtr /*(OrtStatus*)*/ DOrtGetTypeInfo(IntPtr /*(OrtValue*)*/ value, IntPtr /*(OrtValue**)*/ typeInfo);
@@ -1698,6 +1740,41 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca
public static DOrtGetTensorShapeElementCount OrtGetTensorShapeElementCount;
+ /// Map Type API
+ [UnmanagedFunctionPointer(CallingConvention.Winapi)]
+ public delegate IntPtr /*(OrtStatus*)*/ DCastTypeInfoToMapTypeInfo(IntPtr /*(const struct OrtTypeInfo*)*/ typeInfo, out IntPtr /*const OrtMapTypeInfo** */ mapTypeInfo);
+
+ public static DCastTypeInfoToMapTypeInfo OrtCastTypeInfoToMapTypeInfo;
+
+ public delegate IntPtr /*(OrtStatus*)*/ DGetMapKeyType(IntPtr /*const OrtMapTypeInfo* */ mapTypeInfo, out IntPtr /*(TensorElementType*)*/ tensorElementType);
+
+ public static DGetMapKeyType OrtGetMapKeyType;
+
+ public delegate IntPtr /*(OrtStatus*)*/ DGetMapValueType(IntPtr /* const OrtMapTypeInfo* */ map_type_info, out IntPtr /* OrtTypeInfo** */ type_info);
+
+ public static DGetMapValueType OrtGetMapValueType;
+
+ // Sequence TypeInfo
+ [UnmanagedFunctionPointer(CallingConvention.Winapi)]
+ public delegate IntPtr /*(OrtStatus*)*/ DCastTypeInfoToSequenceTypeInfo(IntPtr /*(struct OrtTypeInfo*)*/ typeInfo, out IntPtr /* const OrtSequenceTypeInfo** */ sequenceTypeInfo);
+
+ public static DCastTypeInfoToSequenceTypeInfo OrtCastTypeInfoToSequenceTypeInfo;
+
+ public delegate IntPtr /*(OrtStatus*)*/ DGetSequenceElementType(IntPtr /* const OrtSequenceTypeInfo* */ sequenceTypeInfo, out IntPtr /* OrtTypeInfo** */ elementTypeInfo);
+
+ public static DGetSequenceElementType OrtGetSequenceElementType;
+
+ // OptionalTypeInfo
+ [UnmanagedFunctionPointer(CallingConvention.Winapi)]
+ public delegate IntPtr /*(OrtStatus*)*/ DOrtCastTypeInfoToOptionalTypeInfo(IntPtr /*(struct OrtTypeInfo*)*/ typeInfo, out IntPtr /* const struct OrtOptionalTypeInfo** */ optionalTypeInfo);
+
+ public static DOrtCastTypeInfoToOptionalTypeInfo OrtCastTypeInfoToOptionalTypeInfo;
+
+ [UnmanagedFunctionPointer(CallingConvention.Winapi)]
+ public delegate IntPtr /*(OrtStatus*)*/ DGetOptionalContainedTypeInfo(IntPtr /* const struct OrtOptionalTypeInfo*/ optTypeInfo, out IntPtr /* struct OrtTypeInfo** */ containedTypeInfo);
+
+ public static DGetOptionalContainedTypeInfo OrtGetOptionalContainedTypeInfo;
+
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate void DOrtReleaseValue(IntPtr /*(OrtValue*)*/ value);
diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/NativeOnnxValueHelper.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/NativeOnnxValueHelper.shared.cs
index feff70782f834..704eb4d14222f 100644
--- a/csharp/src/Microsoft.ML.OnnxRuntime/NativeOnnxValueHelper.shared.cs
+++ b/csharp/src/Microsoft.ML.OnnxRuntime/NativeOnnxValueHelper.shared.cs
@@ -43,7 +43,7 @@ public void Dispose()
// No need for the finalizer
// If this is not disposed timely GC can't help us
#endregion
- }
+ }
///
/// This helper class contains methods to create native OrtValue from a managed value object
@@ -58,8 +58,12 @@ internal static class NativeOnnxValueHelper
/// UTF-8 encoded equivalent
internal static byte[] StringToZeroTerminatedUtf8(string s)
{
- byte[] utf8Bytes = UTF8Encoding.UTF8.GetBytes(s);
- Array.Resize(ref utf8Bytes, utf8Bytes.Length + 1);
+ int arraySize = UTF8Encoding.UTF8.GetByteCount(s);
+ byte[] utf8Bytes = new byte[arraySize + 1];
+ if (arraySize != UTF8Encoding.UTF8.GetBytes(s, 0, s.Length, utf8Bytes, 0))
+ {
+ throw new OnnxRuntimeException(ErrorCode.RuntimeException, "Failed to convert to UTF8");
+ }
utf8Bytes[utf8Bytes.Length - 1] = 0;
return utf8Bytes;
}
@@ -72,7 +76,7 @@ internal static byte[] StringToZeroTerminatedUtf8(string s)
///
internal static string StringFromNativeUtf8(IntPtr nativeUtf8)
{
- // .NET 5.0 has Marshal.PtrToStringUTF8 that does the below
+ // .NET 8.0 has Marshal.PtrToStringUTF8 that does the below
int len = 0;
while (Marshal.ReadByte(nativeUtf8, len) != 0) ++len;
byte[] buffer = new byte[len];
@@ -80,6 +84,33 @@ internal static string StringFromNativeUtf8(IntPtr nativeUtf8)
return Encoding.UTF8.GetString(buffer, 0, buffer.Length);
}
+ ///
+ /// Reads UTF-8 string from native C zero terminated string,
+ /// converts it to C# UTF-16 string and returns both C# string and utf-8
+ /// bytes as a zero terminated array, suitable for use as a C-string
+ ///
+ /// input
+ /// C# UTF-16 string
+ /// UTF-8 bytes in a managed buffer, zero terminated
+ internal static void StringAndUtf8FromNative(IntPtr nativeUtf8, out string str, out byte[] utf8)
+ {
+ // .NET 8.0 has Marshal.PtrToStringUTF8 that does the below
+ int len = 0;
+ while (Marshal.ReadByte(nativeUtf8, len) != 0) ++len;
+ utf8 = new byte[len + 1];
+ Marshal.Copy(nativeUtf8, utf8, 0, len);
+ utf8[len] = 0;
+ str = Encoding.UTF8.GetString(utf8, 0, len);
+ }
+
+ internal static string StringFromUtf8Span(ReadOnlySpan utf8Span)
+ {
+ // XXX: For now we have to copy into byte[], this produces a copy
+ // Converting from span is available in later versions
+ var utf8Bytes = utf8Span.ToArray();
+ return Encoding.UTF8.GetString(utf8Bytes, 0, utf8Bytes.Length);
+ }
+
///
/// Run helper
///
@@ -126,7 +157,7 @@ public static bool GetTypeAndWidth(TensorElementType elemType, out Type type, ou
{
bool result = true;
TensorElementTypeInfo typeInfo = TensorBase.GetElementTypeInfo(elemType);
- if(typeInfo != null)
+ if (typeInfo != null)
{
type = typeInfo.TensorType;
width = typeInfo.TypeSize;
diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/OrtValue.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/OrtValue.shared.cs
index 08609bb4826a6..868cf00ae334e 100644
--- a/csharp/src/Microsoft.ML.OnnxRuntime/OrtValue.shared.cs
+++ b/csharp/src/Microsoft.ML.OnnxRuntime/OrtValue.shared.cs
@@ -20,6 +20,7 @@ public enum OnnxValueType
ONNX_TYPE_MAP = 3, // It's a map
ONNX_TYPE_OPAQUE = 4, // It's an experimental Opaque object
ONNX_TYPE_SPARSETENSOR = 5, // It's a Sparse Tensor
+ ONNX_TYPE_OPTIONAL = 6, // It's an optional type that designates anything above (except UNKOWN)
}
///
@@ -345,18 +346,20 @@ out valueHandle
// fill the native tensor, using GetValue(index) from the Tensor
var len = tensor.Length;
var nativeStrings = new IntPtr[len];
- using (var pinnedHandles = new DisposableList((int)len))
+ using (var pinnedHandles = new DisposableList((int)len))
{
for (int i = 0; i < len; i++)
{
var utf8str = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(tensor.GetValue(i));
- var gcHandle = GCHandle.Alloc(utf8str, GCHandleType.Pinned);
- nativeStrings[i] = gcHandle.AddrOfPinnedObject();
- pinnedHandles.Add(new PinnedGCHandle(gcHandle));
+ var pinnedUtf8 = new Memory(utf8str).Pin();
+ unsafe
+ {
+ nativeStrings[i] = (IntPtr)pinnedUtf8.Pointer;
+ }
+ pinnedHandles.Add(pinnedUtf8);
}
- using (var pinnedStrings = new PinnedGCHandle(GCHandle.Alloc(nativeStrings, GCHandleType.Pinned)))
- NativeApiStatus.VerifySuccess(NativeMethods.OrtFillStringTensor(ortValue.Handle, nativeStrings, (UIntPtr)len));
+ NativeApiStatus.VerifySuccess(NativeMethods.OrtFillStringTensor(ortValue.Handle, nativeStrings, (UIntPtr)len));
}
}
catch (OnnxRuntimeException)
@@ -380,6 +383,7 @@ protected override bool ReleaseHandle()
if (IsOwned)
{
NativeMethods.OrtReleaseValue(handle);
+ IsOwned = false;
}
// Prevent use after disposal
handle = IntPtr.Zero;
diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/NativeOnnxTensorMemory.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/OrtValueTensor.shared.cs
similarity index 89%
rename from csharp/src/Microsoft.ML.OnnxRuntime/NativeOnnxTensorMemory.shared.cs
rename to csharp/src/Microsoft.ML.OnnxRuntime/OrtValueTensor.shared.cs
index ce339b6a528c1..29ae3ab2be30a 100644
--- a/csharp/src/Microsoft.ML.OnnxRuntime/NativeOnnxTensorMemory.shared.cs
+++ b/csharp/src/Microsoft.ML.OnnxRuntime/OrtValueTensor.shared.cs
@@ -23,13 +23,14 @@ internal interface IOrtValueOwner : IDisposable
/// This class is used in conjunction with DisposableNamedOnnxValue to
/// own native collection OrtValue and dispose of it along with any DisposableNamedOnnxValues
///
- internal class NativeOrtValueCollectionOwner : IOrtValueOwner, IDisposable
+ internal class NativeOrtValueCollectionOwner : IOrtValueOwner, IDisposable
+ where T:IDisposable
{
private OrtValue _ortValue;
- private DisposableList _disposables;
+ private DisposableList _disposables;
bool _disposed = false;
- internal NativeOrtValueCollectionOwner(OrtValue ortValue, DisposableList disposables)
+ internal NativeOrtValueCollectionOwner(OrtValue ortValue, DisposableList disposables)
{
Debug.Assert(ortValue.IsOwned);
_ortValue = new OrtValue(ortValue.Disown());
@@ -80,19 +81,24 @@ public void Dispose()
///
/// This helper class owns the underlying OrtValue that is assumed to be a Tensor,
/// it does not support any other ortValues and caches Tensor properties.
+ ///
+ /// It is easy to expose as a Tensor as DenseTensor can take Memory Mapping from
+ /// this.
+ ///
+ /// This class is disposable because of the MemoryManager inheritance
///
///
- internal class NativeOnnxTensorMemory : MemoryManager, IOrtValueOwner
+ internal class OrtValueTensor : MemoryManager, IOrtValueOwner
{
private OrtValue _ortValue; // Disposable
- private IntPtr _dataBufferPointer; // pointer to mutable tensor data in native memory
- private string[] _dataBufferAsString; // string tensor values copied into managed memory
+ private readonly IntPtr _dataBufferPointer; // pointer to mutable tensor data in native memory
+ private readonly string[] _dataBufferAsString; // string tensor values copied into managed memory
///
/// Constructs an instance and takes ownership of ortValue on success
///
/// ortValue that is a Tensor
- public NativeOnnxTensorMemory(OrtValue ortValue)
+ public OrtValueTensor(OrtValue ortValue)
{
Type type = null;
int width = 0;
@@ -115,7 +121,7 @@ public NativeOnnxTensorMemory(OrtValue ortValue)
if (typeof(T) != type)
{
- var message = String.Format("The NativeOnnxTensorMemory type being instantiated for T = : {0} while supplied OrtValue contains T = {1}",
+ var message = String.Format("The OrtValueTensor type being instantiated for T = : {0} while supplied OrtValue contains T = {1}",
typeof(T), type);
throw new OnnxRuntimeException(ErrorCode.InvalidArgument, message);
}
@@ -214,7 +220,7 @@ public NativeOnnxTensorMemory(OrtValue ortValue)
public override Span GetSpan()
{
if (IsDisposed)
- throw new ObjectDisposedException(nameof(NativeOnnxTensorMemory));
+ throw new ObjectDisposedException(nameof(OrtValueTensor));
Span span = null;
unsafe
{
@@ -226,10 +232,10 @@ public override Span GetSpan()
public Memory GetBytesAsStringMemory()
{
if (IsDisposed)
- throw new ObjectDisposedException(nameof(NativeOnnxTensorMemory));
+ throw new ObjectDisposedException(nameof(OrtValueTensor));
if (typeof(T) != typeof(string))
- throw new NotSupportedException(nameof(NativeOnnxTensorMemory.GetBytesAsStringMemory) + ": T must be byte");
+ throw new NotSupportedException(nameof(OrtValueTensor.GetBytesAsStringMemory) + ": T must be byte");
return (_dataBufferAsString == null) ? new Memory() : new Memory(_dataBufferAsString);
}
diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/Tensors/Tensor.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/Tensors/Tensor.shared.cs
index bb7eea2ad1883..0bc5ca7240e66 100644
--- a/csharp/src/Microsoft.ML.OnnxRuntime/Tensors/Tensor.shared.cs
+++ b/csharp/src/Microsoft.ML.OnnxRuntime/Tensors/Tensor.shared.cs
@@ -71,7 +71,7 @@ public Float16(ushort v)
///
/// instance of Float16
/// value member
- public static implicit operator ushort (Float16 f) { return f.value; }
+ public static implicit operator ushort(Float16 f) { return f.value; }
///
/// Converts a 16-bit unsigned integer to a Float16.
///
@@ -191,7 +191,7 @@ public bool Equals(BFloat16 other)
/// represent the same type and value.
///
/// An System.Object.
- /// true if obj is BFloat16 its value is equal to this instance; otherwise, false.
+ /// true if obj is BFloat16 its value is equal to this instance; otherwise, false.
public override bool Equals(object obj)
{
bool result = false;
@@ -286,7 +286,8 @@ public class TensorBase
private static readonly Dictionary tensorElementTypeInfoMap;
- static TensorBase () {
+ static TensorBase()
+ {
typeInfoMap = new Dictionary()
{
{ typeof(float), new TensorTypeInfo( TensorElementType.Float, sizeof(float)) },
@@ -306,11 +307,11 @@ static TensorBase () {
};
tensorElementTypeInfoMap = new Dictionary();
- foreach(var info in typeInfoMap)
+ foreach (var info in typeInfoMap)
{
tensorElementTypeInfoMap.Add(info.Value.ElementType, new TensorElementTypeInfo(info.Key, info.Value.TypeSize));
}
- }
+ }
private readonly Type _primitiveType;
///
@@ -559,7 +560,10 @@ internal static T Zero
{
return (T)(object)(ushort)(0);
}
-
+ else if (typeof(T) == typeof(string))
+ {
+ return (T)(object)("0");
+ }
throw new NotSupportedException();
}
}
@@ -619,8 +623,8 @@ internal static T One
else if (typeof(T) == typeof(ushort))
{
return (T)(object)(ushort)(1);
- }
- else if(typeof(T) == typeof(Float16))
+ }
+ else if (typeof(T) == typeof(Float16))
{
return (T)(object)(ushort)(15360);
}
@@ -628,6 +632,10 @@ internal static T One
{
return (T)(object)(ushort)(16256);
}
+ else if (typeof(T) == typeof(string))
+ {
+ return (T)(object)("1");
+ }
throw new NotSupportedException();
}
diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.EndToEndTests/Microsoft.ML.OnnxRuntime.EndToEndTests.csproj b/csharp/test/Microsoft.ML.OnnxRuntime.EndToEndTests/Microsoft.ML.OnnxRuntime.EndToEndTests.csproj
index e15409216aabd..3e4802bf68f6e 100644
--- a/csharp/test/Microsoft.ML.OnnxRuntime.EndToEndTests/Microsoft.ML.OnnxRuntime.EndToEndTests.csproj
+++ b/csharp/test/Microsoft.ML.OnnxRuntime.EndToEndTests/Microsoft.ML.OnnxRuntime.EndToEndTests.csproj
@@ -48,6 +48,7 @@
+
diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/InferenceTest.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/InferenceTest.cs
index 357d1ca8621bd..da83b640f2577 100644
--- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/InferenceTest.cs
+++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/InferenceTest.cs
@@ -588,7 +588,7 @@ private void ThrowWrongInputName()
var container = new List();
container.Add(NamedOnnxValue.CreateFromTensor("wrong_name", tensor));
var ex = Assert.Throws(() => session.Run(container));
- Assert.Contains("Invalid Feed Input", ex.Message);
+ Assert.Contains("Input name: 'wrong_name' is not in the metadata", ex.Message);
session.Dispose();
}
@@ -604,9 +604,8 @@ private void ThrowWrongInputType()
var tensor = new DenseTensor(inputDataInt, inputMeta["data_0"].Dimensions);
container.Add(NamedOnnxValue.CreateFromTensor("data_0", tensor));
var ex = Assert.Throws(() => session.Run(container));
- var msg = ex.ToString().Substring(0, 101);
- // TODO: message is diff in LInux. Use substring match
- Assert.Equal("Microsoft.ML.OnnxRuntime.OnnxRuntimeException: [ErrorCode:InvalidArgument] Unexpected input data type", msg);
+ var msg = ex.ToString();
+ Assert.Contains("Tensor element data type discovered", msg);
session.Dispose();
}
@@ -624,7 +623,7 @@ private void ThrowExtraInputs()
container.Add(nov1);
container.Add(nov2);
var ex = Assert.Throws(() => session.Run(container));
- Assert.StartsWith("[ErrorCode:InvalidArgument] Invalid Feed Input Name", ex.Message);
+ Assert.Contains("Input name: 'extra' is not in the metadata", ex.Message);
session.Dispose();
}
@@ -653,9 +652,10 @@ private void ThrowWrongOutputName()
var inputTensor = tuple.Item3;
var inputs = new List { NamedOnnxValue.CreateFromTensor("data_0", inputTensor) };
var outputTensor = new DenseTensor((ReadOnlySpan)new[] { 1, 2 });
- var outputs = new List { NamedOnnxValue.CreateFromTensor("bad_output_name", outputTensor) };
- var ex = Assert.Throws(() => session.Run(inputs, outputs));
- Assert.Contains("Invalid Output Name", ex.Message);
+ // var outputs = new List { NamedOnnxValue.CreateFromTensor("bad_output_name", outputTensor) };
+ var bad_names = new string[] {"bad_output_name"};
+ var ex = Assert.Throws(() => session.Run(inputs, bad_names));
+ Assert.Contains("Output name: 'bad_output_name' is not in the metadata", ex.Message);
session.Dispose();
}
@@ -1322,8 +1322,29 @@ private void TestModelSequenceOfMapIntFloat()
{
var outMeta = session.OutputMetadata;
- Assert.Equal(OnnxValueType.ONNX_TYPE_TENSOR, outMeta["label"].OnnxValueType);
- Assert.Equal(OnnxValueType.ONNX_TYPE_SEQUENCE, outMeta["probabilities"].OnnxValueType);
+ var label_meta = outMeta["label"];
+ Assert.True(label_meta.IsTensor);
+ Assert.Equal(OnnxValueType.ONNX_TYPE_TENSOR, label_meta.OnnxValueType);
+ Assert.Equal(TensorElementType.Int64, label_meta.ElementDataType);
+ Assert.NotEmpty(label_meta.Dimensions);
+
+ // sequence
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/TestDataLoader.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/TestDataLoader.cs
index 2086fa0ec3164..42a780b3287ef 100644
--- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/TestDataLoader.cs
+++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/TestDataLoader.cs
@@ -1,8 +1,12 @@
using System;
using System.Collections.Generic;
+using System.Diagnostics;
using System.IO;
using System.Linq;
+using System.Net.NetworkInformation;
+using Google.Protobuf;
using Microsoft.ML.OnnxRuntime.Tensors;
+using Xunit;
namespace Microsoft.ML.OnnxRuntime.Tests
{
@@ -46,185 +50,361 @@ internal static float[] LoadTensorFromEmbeddedResource(string path)
return tensorData.ToArray();
}
- internal static void GetTypeAndWidth(Tensors.TensorElementType elemType, out Type type, out int width)
+ static NamedOnnxValue LoadTensorPb(Onnx.TensorProto tensor, string nodeName, NodeMetadata nodeMeta)
{
- TensorElementTypeInfo result = TensorBase.GetElementTypeInfo(elemType);
- if (result != null)
+ if (nodeMeta.OnnxValueType != OnnxValueType.ONNX_TYPE_TENSOR)
{
- type = result.TensorType;
- width = result.TypeSize;
+ throw new InvalidDataException($"Metadata for: '{nodeName}' has a type: '{nodeMeta.OnnxValueType}'" +
+ $" but loading as tensor: '{tensor.Name}'");
}
- else
+
+ var protoDt = (Tensors.TensorElementType)tensor.DataType;
+ var metaElementType = nodeMeta.ElementDataType;
+ if (!((protoDt == metaElementType) ||
+ (protoDt == TensorElementType.UInt16 &&
+ (metaElementType == TensorElementType.BFloat16 || metaElementType == TensorElementType.Float16))))
+ throw new InvalidDataException($"For node: '{nodeName}' metadata expects: '{metaElementType}' but loaded loaded tensor type: '{protoDt}'");
+
+ // Tensors within Sequences may have no dimensions as the standard allows
+ // different dimensions for each tensor element of the sequence
+ if (nodeMeta.Dimensions.Length > 0 && nodeMeta.Dimensions.Length != tensor.Dims.Count)
{
- throw new ArgumentException("Unable to get information for type: " + elemType.ToString());
+ throw new InvalidDataException($"node: '{nodeName}' nodeMeta.Dim.Length: {nodeMeta.Dimensions.Length} " +
+ $"is expected to be equal to tensor.Dims.Count {tensor.Dims.Count}");
}
- }
- static NamedOnnxValue LoadTensorPb(Onnx.TensorProto tensor, IReadOnlyDictionary nodeMetaDict)
- {
- Type tensorElemType = null;
- int width = 0;
- GetTypeAndWidth((Tensors.TensorElementType)tensor.DataType, out tensorElemType, out width);
var intDims = new int[tensor.Dims.Count];
-
for (int i = 0; i < tensor.Dims.Count; i++)
{
intDims[i] = (int)tensor.Dims[i];
}
- NodeMetadata nodeMeta = null;
- string nodeName = string.Empty;
+ for (int i = 0; i < nodeMeta.Dimensions.Length; i++)
+ {
+ if ((nodeMeta.Dimensions[i] != -1) && (nodeMeta.Dimensions[i] != tensor.Dims[i]))
+ throw new InvalidDataException($"Node: '{nodeName}' dimension at idx {i} is {nodeMeta.Dimensions}[{i}] " +
+ $"is expected to either be -1 or {tensor.Dims[i]}");
+ }
- if (nodeMetaDict.Count == 1)
+ // element type for Float16 and BFloat16 in the loaded tensor would always be uint16, so
+ // we want to use element type from metadata
+ if (protoDt == TensorElementType.String)
+ return CreateNamedOnnxValueFromStringTensor(tensor.StringData, nodeName, intDims);
+
+ return CreateNamedOnnxValueFromTensorRawData(nodeName, tensor.RawData.ToArray(), metaElementType, intDims);
+ }
+
+ internal static NamedOnnxValue CreateNamedOnnxValueFromTensorRawData(string nodeName, byte[] rawData, TensorElementType elementType, int[] intDims)
+ {
+ switch (elementType)
{
- nodeMeta = nodeMetaDict.Values.First();
- nodeName = nodeMetaDict.Keys.First(); // valid for single node input
+ case TensorElementType.Float:
+ return CreateNamedOnnxValueFromRawData(nodeName, rawData, sizeof(float), intDims);
+ case TensorElementType.Double:
+ return CreateNamedOnnxValueFromRawData(nodeName, rawData, sizeof(double), intDims);
+ case TensorElementType.Int32:
+ return CreateNamedOnnxValueFromRawData(nodeName, rawData, sizeof(int), intDims);
+ case TensorElementType.UInt32:
+ return CreateNamedOnnxValueFromRawData(nodeName, rawData, sizeof(uint), intDims);
+ case TensorElementType.Int16:
+ return CreateNamedOnnxValueFromRawData(nodeName, rawData, sizeof(short), intDims);
+ case TensorElementType.UInt16:
+ return CreateNamedOnnxValueFromRawData(nodeName, rawData, sizeof(ushort), intDims);
+ case TensorElementType.Int64:
+ return CreateNamedOnnxValueFromRawData(nodeName, rawData, sizeof(long), intDims);
+ case TensorElementType.UInt64:
+ return CreateNamedOnnxValueFromRawData(nodeName, rawData, sizeof(ulong), intDims);
+ case TensorElementType.UInt8:
+ return CreateNamedOnnxValueFromRawData(nodeName, rawData, sizeof(byte), intDims);
+ case TensorElementType.Int8:
+ return CreateNamedOnnxValueFromRawData(nodeName, rawData, sizeof(sbyte), intDims);
+ case TensorElementType.Bool:
+ return CreateNamedOnnxValueFromRawData(nodeName, rawData, sizeof(bool), intDims);
+ case TensorElementType.Float16:
+ return CreateNamedOnnxValueFromRawData(nodeName, rawData, sizeof(ushort), intDims);
+ case TensorElementType.BFloat16:
+ return CreateNamedOnnxValueFromRawData(nodeName, rawData, sizeof(ushort), intDims);
+ case TensorElementType.String:
+ throw new ArgumentException("For string tensors of type use: CreateNamedOnnxValueFromStringTensor.");
+ default:
+ throw new NotImplementedException($"Tensors of type: {elementType} not currently supported by this function");
}
- else if (nodeMetaDict.Count > 1)
+ }
+
+ internal static NamedOnnxValue LoadTensorFromEmbeddedResourcePb(string path, string nodeName, NodeMetadata nodeMeta)
+ {
+ Onnx.TensorProto tensor = null;
+
+ var assembly = typeof(TestDataLoader).Assembly;
+
+ using (Stream stream = assembly.GetManifestResourceStream($"{assembly.GetName().Name}.TestData.{path}"))
{
- if (tensor.Name.Length > 0)
- {
- nodeMeta = nodeMetaDict[tensor.Name];
- nodeName = tensor.Name;
- }
- else
+ tensor = Onnx.TensorProto.Parser.ParseFrom(stream);
+ }
+
+ return LoadTensorPb(tensor, nodeName, nodeMeta);
+ }
+
+ internal static NamedOnnxValue LoadOnnxValueFromFilePb(string fullFilename, string nodeName, NodeMetadata nodeMeta)
+ {
+ // No sparse tensor support yet
+ //Set buffer size to 4MB
+ int readBufferSize = 4194304;
+ using (var file = new FileStream(fullFilename, FileMode.Open, FileAccess.Read, FileShare.Read, readBufferSize))
+ {
+ switch (nodeMeta.OnnxValueType)
{
- bool matchfound = false;
- // try to find from matching type and shape
- foreach (var key in nodeMetaDict.Keys)
- {
- var meta = nodeMetaDict[key];
- if (tensorElemType == meta.ElementType && tensor.Dims.Count == meta.Dimensions.Length)
+ case OnnxValueType.ONNX_TYPE_TENSOR:
{
- int i = 0;
- for (; i < meta.Dimensions.Length; i++)
- {
- if (meta.Dimensions[i] != -1 && meta.Dimensions[i] != intDims[i])
- {
- break;
- }
- }
- if (i >= meta.Dimensions.Length)
- {
- matchfound = true;
- nodeMeta = meta;
- nodeName = key;
- break;
- }
+ var tensor = Onnx.TensorProto.Parser.ParseFrom(file);
+ return LoadTensorPb(tensor, nodeName, nodeMeta);
}
- }
- if (!matchfound)
- {
- // throw error
- throw new Exception($"No Matching Tensor found in InputOutputMetadata corresponding to the serialized tensor specified");
- }
+ case OnnxValueType.ONNX_TYPE_SEQUENCE:
+ {
+ var sequence = Onnx.SequenceProto.Parser.ParseFrom(file);
+ return CreateNamedOnnxValueFromSequence(sequence, nodeName, nodeMeta);
+ }
+ case OnnxValueType.ONNX_TYPE_MAP:
+ {
+ var map = Onnx.MapProto.Parser.ParseFrom(file);
+ return CreateNamedOnnxValueFromMap(map, nodeName, nodeMeta);
+ }
+
+ case OnnxValueType.ONNX_TYPE_OPTIONAL:
+ {
+ var opt = Onnx.OptionalProto.Parser.ParseFrom(file);
+ return CreateNamedOnnxValueFromOptional(opt, nodeName, nodeMeta);
+ }
+ default:
+ throw new NotImplementedException($"Unable to load value type: {nodeMeta.OnnxValueType} not implemented");
}
}
- else
- {
- // throw error
- throw new Exception($"While reading the serliazed tensor specified, metaDataDict has 0 elements");
- }
+ }
- if (!nodeMeta.IsTensor)
- throw new Exception("LoadTensorFromFile can load Tensor types only");
+ private static void SequenceCheckMatchOnnxType(string nodeName, SequenceMetadata meta,
+ OnnxValueType onnxType)
+ {
+ if (meta.ElementMeta.OnnxValueType == onnxType)
+ return;
- if (tensorElemType != nodeMeta.ElementType)
- throw new Exception($"{nameof(tensorElemType)} is expected to be equal to {nameof(nodeMeta.ElementType)}");
+ throw new InvalidDataException($"Sequence node: '{nodeName}' " +
+ $"has element type: '{onnxType}'" +
+ $" expected: '{meta.ElementMeta.OnnxValueType}'");
+ }
- if (nodeMeta.Dimensions.Length != tensor.Dims.Count)
- throw new Exception($"{nameof(nodeMeta.Dimensions.Length)} is expected to be equal to {nameof(tensor.Dims.Count)}");
+ private static string MakeSequenceElementName(string nodeName, string seqName, int seqNum)
+ {
+ if (seqName.Length > 0)
+ return $"seq.{nodeName}.data.{seqName}.{seqNum}";
+ else
+ return $"seq.{nodeName}.data._.{seqNum}";
+ }
- for (int i = 0; i < nodeMeta.Dimensions.Length; i++)
- {
- if ((nodeMeta.Dimensions[i] != -1) && (nodeMeta.Dimensions[i] != intDims[i]))
- throw new Exception($"{nameof(nodeMeta.Dimensions)}[{i}] is expected to either be -1 or {nameof(intDims)}[{i}]");
- }
+ internal static NamedOnnxValue CreateNamedOnnxValueFromSequence(Onnx.SequenceProto sequence, string nodeName, NodeMetadata nodeMeta)
+ {
+ var sequenceMeta = nodeMeta.AsSequenceMetadata();
+ var elemMeta = sequenceMeta.ElementMeta;
- if (nodeMeta.ElementType == typeof(float))
- {
- return CreateNamedOnnxValueFromRawData(nodeName, tensor.RawData.ToArray(), sizeof(float), intDims);
- }
- else if (nodeMeta.ElementType == typeof(double))
- {
- return CreateNamedOnnxValueFromRawData(nodeName, tensor.RawData.ToArray(), sizeof(double), intDims);
- }
- else if (nodeMeta.ElementType == typeof(int))
- {
- return CreateNamedOnnxValueFromRawData(nodeName, tensor.RawData.ToArray(), sizeof(int), intDims);
- }
- else if (nodeMeta.ElementType == typeof(uint))
- {
- return CreateNamedOnnxValueFromRawData(nodeName, tensor.RawData.ToArray(), sizeof(uint), intDims);
- }
- else if (nodeMeta.ElementType == typeof(long))
- {
- return CreateNamedOnnxValueFromRawData(nodeName, tensor.RawData.ToArray(), sizeof(long), intDims);
- }
- else if (nodeMeta.ElementType == typeof(ulong))
+ int seqNum = 0;
+ var seqElemType = (Onnx.SequenceProto.Types.DataType)sequence.ElemType;
+ switch (seqElemType)
{
- return CreateNamedOnnxValueFromRawData(nodeName, tensor.RawData.ToArray(), sizeof(ulong), intDims);
- }
- else if (nodeMeta.ElementType == typeof(short))
- {
- return CreateNamedOnnxValueFromRawData(nodeName, tensor.RawData.ToArray(), sizeof(short), intDims);
- }
- else if (nodeMeta.ElementType == typeof(ushort))
- {
- return CreateNamedOnnxValueFromRawData(nodeName, tensor.RawData.ToArray(), sizeof(ushort), intDims);
- }
- else if (nodeMeta.ElementType == typeof(byte))
- {
- return CreateNamedOnnxValueFromRawData(nodeName, tensor.RawData.ToArray(), sizeof(byte), intDims);
- }
- else if (nodeMeta.ElementType == typeof(sbyte))
- {
- return CreateNamedOnnxValueFromRawData(nodeName, tensor.RawData.ToArray(), sizeof(sbyte), intDims);
- }
- else if (nodeMeta.ElementType == typeof(bool))
- {
- return CreateNamedOnnxValueFromRawData(nodeName, tensor.RawData.ToArray(), sizeof(bool), intDims);
+ case Onnx.SequenceProto.Types.DataType.Tensor:
+ {
+ SequenceCheckMatchOnnxType(nodeName, sequenceMeta, OnnxValueType.ONNX_TYPE_TENSOR);
+ var sequenceOfTensors = new List(sequence.TensorValues.Count);
+ foreach (var tensor in sequence.TensorValues)
+ {
+ var elemName = MakeSequenceElementName(nodeName, sequence.Name, seqNum++);
+ var namedOnnxValue = LoadTensorPb(tensor, elemName, elemMeta);
+ sequenceOfTensors.Add(namedOnnxValue);
+ }
+ return NamedOnnxValue.CreateFromSequence(nodeName, sequenceOfTensors);
+ }
+ case Onnx.SequenceProto.Types.DataType.Sequence:
+ {
+ SequenceCheckMatchOnnxType(nodeName, sequenceMeta, OnnxValueType.ONNX_TYPE_SEQUENCE);
+ var seqOfSequences = new List(sequence.SequenceValues.Count);
+ foreach (var s in sequence.SequenceValues)
+ {
+ var elemName = MakeSequenceElementName(nodeName, sequence.Name, seqNum++);
+ seqOfSequences.Add(CreateNamedOnnxValueFromSequence(s, elemName, elemMeta));
+ }
+ return NamedOnnxValue.CreateFromSequence(nodeName, seqOfSequences);
+ }
+ case Onnx.SequenceProto.Types.DataType.Map:
+ {
+ SequenceCheckMatchOnnxType(nodeName, sequenceMeta, OnnxValueType.ONNX_TYPE_MAP);
+ var seqOfMaps = new List(sequence.MapValues.Count);
+ foreach (var m in sequence.MapValues)
+ {
+ var elemName = MakeSequenceElementName(nodeName, sequence.Name, seqNum++);
+ seqOfMaps.Add(CreateNamedOnnxValueFromMap(m, elemName, elemMeta));
+ }
+ return NamedOnnxValue.CreateFromSequence(nodeName, seqOfMaps);
+ }
+ case Onnx.SequenceProto.Types.DataType.Optional:
+ {
+ SequenceCheckMatchOnnxType(nodeName, sequenceMeta, OnnxValueType.ONNX_TYPE_OPTIONAL);
+ var seqOfOpts = new List(sequence.OptionalValues.Count);
+ foreach (var opt in sequence.OptionalValues)
+ {
+ var elemName = MakeSequenceElementName(nodeName, sequence.Name, seqNum++);
+ seqOfOpts.Add(CreateNamedOnnxValueFromOptional(opt, elemName, elemMeta));
+ }
+ return NamedOnnxValue.CreateFromSequence(nodeName, seqOfOpts);
+ }
+ default:
+ throw new NotImplementedException($"Sequence test data loading does not support element type: " +
+ $"'{seqElemType}'");
}
- else if (nodeMeta.ElementType == typeof(Float16))
+
+ }
+
+ internal static NamedOnnxValue CastAndCreateFromMapKeys(string name, TensorElementType elementType, IList keys)
+ {
+ switch (elementType)
{
- return CreateNamedOnnxValueFromRawData(nodeName, tensor.RawData.ToArray(), sizeof(ushort), intDims);
+ case TensorElementType.Float:
+ return CastAndCreateTensor(name, keys);
+ case TensorElementType.Double:
+ return CastAndCreateTensor(name, keys);
+ case TensorElementType.Int32:
+ return CastAndCreateTensor(name, keys);
+ case TensorElementType.UInt32:
+ return CastAndCreateTensor(name, keys);
+ case TensorElementType.Int16:
+ return CastAndCreateTensor(name, keys);
+ case TensorElementType.UInt16:
+ return CastAndCreateTensor(name, keys);
+ case TensorElementType.Int64:
+ return CastAndCreateTensor(name, keys);
+ case TensorElementType.UInt64:
+ return CastAndCreateTensor(name, keys);
+ case TensorElementType.UInt8:
+ return CastAndCreateTensor(name, keys);
+ case TensorElementType.Int8:
+ return CastAndCreateTensor(name, keys);
+ case TensorElementType.Bool:
+ return CastAndCreateTensor(name, keys);
+ case TensorElementType.Float16:
+ return CastAndCreateTensor(name, keys);
+ case TensorElementType.BFloat16:
+ return CastAndCreateTensor(name, keys);
+ default:
+ throw new NotImplementedException($"Tensors of type: " + elementType.ToString() +
+ " not currently supported here, use: CreateNamedOnnxValueFromStringTensor.");
}
- else if (nodeMeta.ElementType == typeof(BFloat16))
+ }
+
+ ///
+ /// All the keys in maps are stored as an array of longs, so
+ /// to create a real tensor we need to cast to create a continuous buffer
+ /// essentially packing it as a raw data.
+ ///
+ ///
+ ///
+ ///
+ ///
+ ///
+ ///
+ internal static NamedOnnxValue CastAndCreateTensor(string name, IList elements)
+ {
+ // Create raw data
+ T[] castKeys = new T[elements.Count];
+ if (typeof(T) == typeof(Float16) || typeof(T) == typeof(BFloat16))
{
- return CreateNamedOnnxValueFromRawData(nodeName, tensor.RawData.ToArray(), sizeof(ushort), intDims);
+ for (int i = 0; i < elements.Count; i++)
+ {
+ var obj = Convert.ChangeType(elements[i], typeof(ushort));
+ if (obj == null)
+ {
+ throw new InvalidDataException($"Conversion from long to {typeof(T)} failed");
+ }
+ castKeys[i] = (T)obj;
+ }
}
else
{
- //TODO: Add support for remaining types
- throw new Exception($"Tensors of type {nameof(nodeMeta.ElementType)} not currently supported in the LoadTensorFromEmbeddedResource");
+ for (int i = 0; i < elements.Count; i++)
+ {
+ var obj = (T)Convert.ChangeType(elements[i], typeof(T));
+ if (obj == null)
+ {
+ throw new InvalidDataException($"Conversion from long to {typeof(T)} failed");
+ }
+ castKeys[i] = (T)obj;
+ }
}
+ var tensor = new DenseTensor(castKeys, new int[] { elements.Count });
+ return NamedOnnxValue.CreateFromTensor(name, tensor);
}
- internal static NamedOnnxValue LoadTensorFromEmbeddedResourcePb(string path, IReadOnlyDictionary nodeMetaDict)
+ internal static NamedOnnxValue CreateNamedOnnxValueFromMap(Onnx.MapProto map, string nodeName, NodeMetadata nodeMetadata)
{
- Onnx.TensorProto tensor = null;
+ // See GH issue https://github.com/onnx/onnx/issues/5072
+ throw new NotImplementedException($"Loading map node: '{nodeName}' not implemented yet");
- var assembly = typeof(TestDataLoader).Assembly;
+ //var mapMeta = nodeMetadata.AsMapMetadata();
- using (Stream stream = assembly.GetManifestResourceStream($"{assembly.GetName().Name}.TestData.{path}"))
- {
- tensor = Onnx.TensorProto.Parser.ParseFrom(stream);
- }
+ //if ((TensorElementType)map.KeyType != mapMeta.KeyDataType)
+ //{
+ // throw new InvalidDataException($"Node: '{nodeName}' map key type expected: " +
+ // $"'{mapMeta.KeyDataType}', loaded from test data: '{(TensorElementType)map.KeyType}'");
+ //}
+
+ //// temp non-generic(!) container
+ //NamedOnnxValue keysTensor;
+ //if (mapMeta.KeyDataType == TensorElementType.String)
+ //{
+ // keysTensor = CreateNamedOnnxValueFromStringTensor(map.StringKeys, nodeName, new int[] { map.StringKeys.Count });
+ //}
+ //else
+ //{
+ // keysTensor = CastAndCreateFromMapKeys(nodeName, mapMeta.KeyDataType, map.Keys);
+ //}
+
+ //switch ((Onnx.SequenceProto.Types.DataType)map.Values.ElemType)
+ //{
+ // case Onnx.SequenceProto.Types.DataType.Tensor:
+ // var tensorCount = map.Values.TensorValues.Count;
+ // break;
+ // default:
+ // throw new NotImplementedException("Does not support map value type other than a tensor");
+ //}
- return LoadTensorPb(tensor, nodeMetaDict);
+ //return new NamedOnnxValue(string.Empty, new Object(), OnnxValueType.ONNX_TYPE_UNKNOWN);
}
- internal static NamedOnnxValue LoadTensorFromFilePb(string filename, IReadOnlyDictionary nodeMetaDict)
+ internal static NamedOnnxValue CreateNamedOnnxValueFromOptional(Onnx.OptionalProto optional, string nodeName, NodeMetadata nodeMetadata)
{
- //Set buffer size to 4MB
- int readBufferSize = 4194304;
- Onnx.TensorProto tensor = null;
- using (var file = new FileStream(filename, FileMode.Open, FileAccess.Read, FileShare.Read, readBufferSize))
+ var meta = nodeMetadata.AsOptionalMetadata().ElementMeta;
+ switch((Onnx.OptionalProto.Types.DataType)optional.ElemType)
{
- tensor = Onnx.TensorProto.Parser.ParseFrom(file);
+ case Onnx.OptionalProto.Types.DataType.Tensor:
+ {
+ var tensor = optional.TensorValue;
+ return LoadTensorPb(tensor, nodeName, meta);
+ }
+ case Onnx.OptionalProto.Types.DataType.Sequence:
+ {
+ var sequence = optional.SequenceValue;
+ return CreateNamedOnnxValueFromSequence(sequence, nodeName, meta);
+ }
+ case Onnx.OptionalProto.Types.DataType.Map:
+ {
+ var map = optional.MapValue;
+ return CreateNamedOnnxValueFromMap(map, nodeName, meta);
+ }
+ case Onnx.OptionalProto.Types.DataType.Optional:
+ throw new NotImplementedException($"Unable to load '{nodeName}' optional contained within optional");
+ default:
+ // Test data contains OptionalProto with the contained element type undefined.
+ // the premise is, if the element is not fed as an input, we should not care
+ // what Onnx type it is. However, we do not need to support AFAIK such inputs
+ // since the value for them could never be supplied.
+ throw new NotImplementedException($"Unable to load '{nodeName}' optional element type of: {(Onnx.OptionalProto.Types.DataType)optional.ElemType} type");
}
-
- return LoadTensorPb(tensor, nodeMetaDict);
}
internal static NamedOnnxValue CreateNamedOnnxValueFromRawData(string name, byte[] rawData, int elemWidth, int[] dimensions)
@@ -250,6 +430,19 @@ internal static NamedOnnxValue CreateNamedOnnxValueFromRawData(string name, b
return NamedOnnxValue.CreateFromTensor(name, dt);
}
+ internal static NamedOnnxValue CreateNamedOnnxValueFromStringTensor(IList strings,
+ string nodeName, int[] dimensions)
+ {
+ string[] strArray = new string[strings.Count];
+ for (int i = 0; i < strings.Count; ++i)
+ {
+ strArray[i] = System.Text.Encoding.UTF8.GetString(strings[i].ToByteArray());
+ }
+
+ var dt = new DenseTensor(strArray, dimensions);
+ return NamedOnnxValue.CreateFromTensor(nodeName, dt);
+ }
+
internal static float[] LoadTensorFromFile(string filename, bool skipheader = true)
{
var tensorData = new List();
diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp/InferenceTest.netcore.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp/InferenceTest.netcore.cs
index 61ff2a43ffd40..76518e341f935 100644
--- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp/InferenceTest.netcore.cs
+++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp/InferenceTest.netcore.cs
@@ -3,32 +3,33 @@
using System.IO;
using System.Linq;
using System.Runtime.InteropServices;
+using System.Text.RegularExpressions;
using Microsoft.ML.OnnxRuntime.Tensors;
using Xunit;
namespace Microsoft.ML.OnnxRuntime.Tests
{
- ///
- /// This is compensate for the absence of string.Contains() in .NET Standard 2.0
- /// Contains(String, StringComparison)
- ///
- public static class StringExtensions
- {
- public static bool Contains(this String str, String substring,
- StringComparison comp)
+ ///
+ /// This is compensate for the absence of string.Contains() in .NET Standard 2.0
+ /// Contains(String, StringComparison)
+ ///
+ public static class StringExtensions
{
- if (substring == null)
- throw new ArgumentNullException("substring",
- "substring cannot be null.");
- else if (!Enum.IsDefined(typeof(StringComparison), comp))
- throw new ArgumentException("comp is not a member of StringComparison",
- "comp");
-
- return str.IndexOf(substring, comp) >= 0;
+ public static bool Contains(this String str, String substring,
+ StringComparison comp)
+ {
+ if (substring == null)
+ throw new ArgumentNullException("substring",
+ "substring cannot be null.");
+ else if (!Enum.IsDefined(typeof(StringComparison), comp))
+ throw new ArgumentException("comp is not a member of StringComparison",
+ "comp");
+
+ return str.IndexOf(substring, comp) >= 0;
+ }
}
- }
- public partial class InferenceTest
- {
+ public partial class InferenceTest
+ {
private const string module = "onnxruntime.dll";
private const string propertiesFile = "Properties.txt";
@@ -40,8 +41,8 @@ public void CanCreateAndDisposeSessionWithModelPath()
{
Assert.NotNull(session);
Assert.NotNull(session.InputMetadata);
- Assert.Equal(1, session.InputMetadata.Count); // 1 input node
- Assert.True(session.InputMetadata.ContainsKey("data_0")); // input node name
+ Assert.Equal(1, session.InputMetadata.Count); // 1 input nodeMeta
+ Assert.True(session.InputMetadata.ContainsKey("data_0")); // input nodeMeta name
Assert.Equal(typeof(float), session.InputMetadata["data_0"].ElementType);
Assert.True(session.InputMetadata["data_0"].IsTensor);
var expectedInputDimensions = new int[] { 1, 3, 224, 224 };
@@ -52,8 +53,8 @@ public void CanCreateAndDisposeSessionWithModelPath()
}
Assert.NotNull(session.OutputMetadata);
- Assert.Equal(1, session.OutputMetadata.Count); // 1 output node
- Assert.True(session.OutputMetadata.ContainsKey("softmaxout_1")); // output node name
+ Assert.Equal(1, session.OutputMetadata.Count); // 1 output nodeMeta
+ Assert.True(session.OutputMetadata.ContainsKey("softmaxout_1")); // output nodeMeta name
Assert.Equal(typeof(float), session.OutputMetadata["softmaxout_1"].ElementType);
Assert.True(session.OutputMetadata["softmaxout_1"].IsTensor);
var expectedOutputDimensions = new int[] { 1, 1000, 1, 1 };
@@ -246,9 +247,9 @@ private static Dictionary GetSkippedModels(DirectoryInfo modelsD
{ "fp16_test_tiny_yolov2", "ImageScaler is not a registered function/op"},
{ "fp16_coreml_FNS-Candy", "ImageScaler is not a registered function/op" },
{ "fp16_coreml_LinearRegression_NYCTaxi", "Error in Node:featureVectorizer : No Op registered for FeatureVectorizer with domain_version of 1"},
- { "test_bidaf", "Does not run in opset9, runs in other opsets. The model runs but I don't have a data set to debug output locally. Tensors of type ElementType not currently supported in the LoadTensorFromFile." },
{ "test_mnist", "Does not run in opset9, runs in other opsets. The model runs but I don't have a data set to debug output locally. Tensors of type ElementType not currently supported in the LoadTensorFromFile" },
- { "BERT_Squad", "Could not find an implementation for the node bert / embeddings / one_hot:OneHot(9)" },
+ { "BERT_Squad", "Could not find an implementation for the nodeMeta bert / embeddings / one_hot:OneHot(9)" },
+
{ "mlperf_ssd_mobilenet_300", "Could not find file output_0.pb" },
{ "tf_resnet_v1_50", "result mismatch when Conv BN Fusion is applied" },
{ "tf_resnet_v1_101", "result mismatch when Conv BN Fusion is applied" },
@@ -256,108 +257,85 @@ private static Dictionary GetSkippedModels(DirectoryInfo modelsD
{ "cntk_simple_seg", "Bad onnx test output caused by wrong SAME_UPPER/SAME_LOWER for ConvTranspose" },
{ "coreml_Imputer-LogisticRegression_sklearn_load_breast_cancer", "Can't determine model file name" },
{ "mask_rcnn_keras", "Model should be edited to remove the extra outputs" },
- { "test_strnormalizer_export_monday_casesensintive_lower", "ElementType not currently supported"},
- { "test_max_float64", "node test error"},
- { "test_min_uint8", "node test error"},
- { "test_mod_mixed_sign_float64", "node test error"},
- { "test_momentum", "node test error"},
- { "test_max_uint16", "node test error"},
- { "test_resize_downsample_scales_linear_align_corners", "node test error"},
- { "test_strnormalizer_nostopwords_nochangecase", "node test error"},
- { "test_adagrad_multiple", "node test error"},
- { "test_einsum_inner_prod", "node test error"},
- { "test_sequence_insert_at_back", "node test error"},
- { "test_mod_mixed_sign_int8", "node test error"},
- { "test_maxunpool_export_with_output_shape", "node test error"},
- { "test_strnormalizer_export_monday_empty_output", "node test error"},
- { "test_strnormalizer_export_monday_insensintive_upper_twodim", "ElementType not currently supported"},
- { "test_min_int16", "node test error"},
- { "test_adagrad", "node test error"},
- { "test_min_float64", "node test error"},
- { "test_max_int16", "node test error"},
- { "test_sequence_insert_at_front", "node test error"},
- { "test_training_dropout_default", "node test error"},
- { "test_training_dropout", "node test error"},
- { "test_adam", "node test error"},
- { "test_training_dropout_mask", "node test error"},
- { "test_clip_default_int8_inbounds", "node test error"},
- { "test_eyelike_with_dtype", "node test error"},
- { "test_cast_STRING_to_FLOAT", "node test error"},
- { "test_cast_FLOAT16_to_DOUBLE", "node test error"},
- { "test_cast_FLOAT_to_DOUBLE", "node test error"},
- { "test_cast_BFLOAT16_to_FLOAT", "node test error"},
- { "test_cast_FLOAT_to_BFLOAT16", "node test error"},
- { "test_cast_FLOAT_to_STRING", "node test error"},
- { "test_castlike_STRING_to_FLOAT", "node test error"},
- { "test_castlike_STRING_to_FLOAT_expanded", "node test error"},
- { "test_castlike_FLOAT16_to_DOUBLE", "node test error"},
- { "test_castlike_FLOAT16_to_DOUBLE_expanded", "node test error"},
- { "test_castlike_FLOAT_to_DOUBLE", "node test error"},
- { "test_castlike_FLOAT_to_DOUBLE_expanded", "node test error"},
- { "test_castlike_BFLOAT16_to_FLOAT", "node test error"},
- { "test_castlike_BFLOAT16_to_FLOAT_expanded", "node test error"},
- { "test_castlike_FLOAT_to_BFLOAT16", "node test error"},
- { "test_castlike_FLOAT_to_BFLOAT16_expanded", "node test error"},
- { "test_castlike_FLOAT_to_STRING", "node test error"},
- { "test_castlike_FLOAT_to_STRING_expanded", "node test error"},
- { "test_bitshift_right_uint16", "node test error"},
- { "test_bitshift_left_uint16", "node test error"},
- { "test_pow_types_float32_uint64", "node test error"},
- { "test_max_uint8", "node test error"},
- { "test_strnormalizer_export_monday_casesensintive_nochangecase", "ElementType not currently supported"},
- { "test_momentum_multiple", "node test error"},
- { "test_pow_types_float32_uint32", "node test error"},
- { "test_if_seq", "sequence type is not supported in test infra."},
- { "test_resize_downsample_scales_cubic_align_corners", "node test error"},
- { "test_einsum_batch_matmul", "node test error"},
- { "test_nesterov_momentum", "node test error"},
- { "test_strnormalizer_export_monday_casesensintive_upper", "node test error"},
- { "test_min_uint16", "node test error"},
- { "test_adam_multiple", "node test error"},
- { "test_loop13_seq", "sequence type is not supported in test infra." },
- { "test_training_dropout_default_mask", "node test error"},
- { "test_min_int8", "node test error"},
- { "test_identity_sequence", "data type not supported"},
+
+ { "test_maxunpool_export_with_output_shape", "results mismatch"},
+
+ { "test_min_int8", "Could not find an implementation for Min(13) node with name"},
+ { "test_min_uint8", "Could not find an implementation for Min(13) node with name"},
+ { "test_min_int16", "Could not find an implementation for Min(13) node with name"},
+ { "test_min_uint16", "Could not find an implementation for Min(13) node with name"},
+
+ { "test_max_int8", "Could not find an implementation for Max(13) node with name"},
+ { "test_max_uint8", "Could not find an implementation for Max(13) node with name"},
+ { "test_max_int16", "Could not find an implementation for Max(13) node with name"},
+ { "test_max_uint16", "Could not find an implementation for Max(13) nodeMeta with name '"},
+
+ { "test_mul_uint8", "Could not find an implementation for Mul(14) node with name" },
+
+ { "test_cast_STRING_to_FLOAT", "Output mismatch"},
+ { "test_cast_BFLOAT16_to_FLOAT", "Output mismatch"},
+ { "test_cast_FLOAT_to_STRING", "Output strings can not be compared exactly"},
+ { "test_castlike_STRING_to_FLOAT", "Output mismatch"},
+ { "test_castlike_STRING_to_FLOAT_expanded", "Output mismatch"},
+ { "test_castlike_BFLOAT16_to_FLOAT", "Length is expected to be equal to Count (metadata and expected data mismatch) "},
+ { "test_castlike_BFLOAT16_to_FLOAT_expanded", "Length is expected to be equal to Count metadata and expected data mismatch"},
+ { "test_castlike_FLOAT_to_BFLOAT16", "Length is expected to be equal to Count. Testdata dims length do not match that of model metadata"},
+ { "test_castlike_FLOAT_to_BFLOAT16_expanded", "Length is expected to be equal to Count"},
+ { "test_castlike_FLOAT_to_STRING", "string comparison does not match due to float rounding"},
+ { "test_castlike_FLOAT_to_STRING_expanded", "string comparison does not match due to float rounding"},
+
+ { "test_bitshift_right_uint16", "Could not find an implementation for BitShift(11) nodeMeta with name ''"},
+ { "test_bitshift_left_uint16", "Could not find an implementation for BitShift(11)"},
+
+ { "test_pow_types_float32_uint64", "Could not find an implementation for Pow(15) node with name ''"},
+ { "test_pow_types_float32_uint32", "Could not find an implementation for Pow(15) node with name ''"},
+
+ { "test_resize_downsample_scales_cubic_align_corners", "Results mismatch"},
+ { "test_resize_downsample_scales_linear_align_corners", "Results mismatch"},
+
{ "test_gru_batchwise", "batchwise operations not supported"},
- { "test_lstm_batchwise", "batchwise operations not supported"},
+ { "test_lstm_batchwise", "Batchwise recurrent operations(layout == 1) are not supported.If you need support create a github issue with justification."},
{ "test_simple_rnn_batchwise", "batchwise operations not supported"},
{ "test_batchnorm_example_training_mode", "opset14 version not implemented yet"},
- { "test_bernoulli", "random generator"},
- { "test_bernoulli_seed", "random generator"},
- { "test_bernoulli_double", "random generator"},
- { "test_bernoulli_expanded", "random generator"},
- { "test_bernoulli_seed_expanded", "random generator"},
- { "test_bernoulli_double_expanded", "random generator"},
- { "test_shape", "opset15 version not implemented yet"},
- { "test_optional_get_element", "optional type is not supported in test infra."},
- { "test_optional_get_element_sequence", "optional type is not supported in test infra."},
- { "test_identity_opt", "optional type is not supported in test infra." },
- { "test_if_opt", "optional type is not supported in test infra." },
- { "test_loop16_seq_none", "sequence type is not supported in test infra." },
- { "test_sequence_map_extract_shapes", "sequence type is not supported in test infra." },
- { "test_sequence_map_identity_1_sequence_1_tensor", "sequence type is not supported in test infra." },
- { "test_sequence_map_identity_1_sequence_1_tensor_expanded", "sequence type is not supported in test infra." },
- { "test_sequence_map_add_1_sequence_1_tensor", "sequence type is not supported in test infra." },
- { "test_sequence_map_identity_1_sequence_expanded", "sequence type is not supported in test infra." },
- { "test_sequence_map_identity_2_sequences", "sequence type is not supported in test infra." },
- { "test_sequence_map_add_2_sequences_expanded", "sequence type is not supported in test infra." },
- { "test_sequence_map_identity_2_sequences_expanded", "sequence type is not supported in test infra." },
- { "test_sequence_map_extract_shapes_expanded", "sequence type is not supported in test infra." },
- { "test_sequence_map_add_1_sequence_1_tensor_expanded", "sequence type is not supported in test infra." },
- { "test_sequence_map_add_2_sequences", "sequence type is not supported in test infra." },
- { "test_sequence_map_identity_1_sequence", "sequence type is not supported in test infra." },
- { "BERT-Squad-int8", "training domain"},
- { "YOLOv3-12-int8", "training_domain"},
+
+ { "test_bernoulli", "random generator, results mismatch"},
+ { "test_bernoulli_seed", "random generator, results mismatch"},
+ { "test_bernoulli_double", "random generator, results mismatch"},
+ { "test_bernoulli_expanded", "random generator, results mismatch"},
+ { "test_bernoulli_seed_expanded", "random generator, results mismatch"},
+ { "test_bernoulli_double_expanded", "random generator, results mismatch"},
+
// the expansion of Softplus uses Exp(1). ORT has a Softplus kernel, so testing the expansion is
// unnecessary and fails as ORT support for Exp started at opset 6 (as ORT didn't exist until opset 7).
- { "test_softplus_example_expanded", "Not applicable"},
- { "test_softplus_expanded", "Not applicable"},
- { "test_col2im_pads", "due to a typo in test data"},
- { "test_optional_has_element_empty_optional_input", "C# API doesn't support optional input"},
- { "test_optional_get_element_optional_tensor", "C# API doesn't support optional input"},
- { "test_optional_get_element_optional_sequence", "C# API doesn't support optional input"},
- { "test_optional_has_element_tensor_input", "C# API doesn't support optional input"},
- { "test_optional_has_element_optional_input", "C# API doesn't support optional input"},
+
+ { "test_clip_default_int8_max_expanded", "Could not find an implementation for Less(13) nodeMeta with name ''" },
+ { "test_softplus_expanded", "Could not find an implementation for Exp(1) node with name ''"},
+ { "test_softplus_example_expanded", "Could not find an implementation for Exp(1) node with name ''"},
+ { "test_div_uint8", "Could not find an implementation for Div(14) nodeMeta with name ''"},
+ { "test_add_uint8", "Opset18 Could not find an implementation for Add(14) nodeMeta with name ''"},
+ { "test_col2im_pads", "Results mismatch due to a typo in test data"},
+
+ { "test_optional_has_element_empty_optional_input", "OptionalProto test metadata. Unable to load 'optional_input' optional element type of: Undefined type"},
+ { "test_loop13_seq", "3rd input is an empty sequence. Ort API does not tolerate empty seq: Number of values should be at least 1" },
+
+ // Training tests
+ { "BERT-Squad-int8", "training domain"},
+ { "YOLOv3-12-int8", "training_domain"},
+
+ { "test_training_dropout_default", "results mismatch"},
+ { "test_training_dropout_default_mask", "Results mismatch"},
+ { "test_training_dropout", "results mismatch"},
+ { "test_training_dropout_mask", "results mismatch."},
+
+ { "test_momentum", "ai.onnx.preview.training:Momentum(-1) is not a registered function/op"},
+ { "test_momentum_multiple", "ai.onnx.preview.training:Momentum(-1) is not a registered function/op"},
+ { "test_nesterov_momentum", "ai.onnx.preview.training:Momentum(-1) is not a registered function/op"},
+
+ { "test_adam", "ai.onnx.preview.training:Adam(-1) is not a registered function/op"},
+ { "test_adam_multiple", "ai.onnx.preview.training:Adam(-1) is not a registered function/op"},
+
+ { "test_adagrad", "ai.onnx.preview.training:Adagrad(-1) is not a registered function/op"},
+ { "test_adagrad_multiple", "ai.onnx.preview.training:Adagrad(-1) is not a registered function/op"},
};
// The following models fails on nocontribops win CI
@@ -460,6 +438,110 @@ public static IEnumerable