diff --git a/.gitignore b/.gitignore index 61fc6e474b8ba..c9e217e135e43 100644 --- a/.gitignore +++ b/.gitignore @@ -39,6 +39,7 @@ onnxprofile_profile_test_*.json /csharp/packages /csharp/src/Microsoft.ML.OnnxRuntime/targets/**/*.targets /csharp/src/Microsoft.ML.OnnxRuntime/targets/**/*.props +/csharp/**/*.vcxproj.user cmake/external/FeaturizersLibrary/ # Java specific ignores java/.gradle diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/DisposableNamedOnnxValue.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/DisposableNamedOnnxValue.shared.cs index b2a4c2ef47cb2..34e71074d9d9d 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/DisposableNamedOnnxValue.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/DisposableNamedOnnxValue.shared.cs @@ -3,11 +3,14 @@ using Microsoft.ML.OnnxRuntime.Tensors; using System; -using System.Buffers; using System.Collections.Generic; namespace Microsoft.ML.OnnxRuntime { + /// + /// Return immutable collection of results + /// + /// public interface IDisposableReadOnlyCollection : IReadOnlyCollection, IDisposable { @@ -52,7 +55,6 @@ public void Dispose() /// /// This class serves as a container for model run output values including /// tensors, sequences of tensors, sequences and maps. - /// It extends NamedOnnxValue, exposes the OnnxValueType and Tensor type /// The class must be disposed of. /// It disposes of _ortValueHolder that owns the underlying Ort output value and /// anything else that would need to be disposed by the instance of the class. @@ -70,24 +72,46 @@ public class DisposableNamedOnnxValue : NamedOnnxValue, IDisposable /// Managed object created to represent output value, such as DenseTensor /// List or Dictionary /// - /// Use this to decide what you want to call to fetch data, AsTensor(), AsDictionary() - /// or AsEnumerable() /// Tensor element type if value type is a Tensor /// Object that holds native resources. /// Typically, this is an output OrtValue that holds native memory where Tensor is mapped but may also be /// other things that would need to be disposed by this instance depending on how IOrtValueOwner is implemented. - private DisposableNamedOnnxValue(string name, Object value, OnnxValueType onnxValueType, TensorElementType elementType, IOrtValueOwner ortValueHolder) - : base(name, value) + private DisposableNamedOnnxValue(string name, Object value, TensorElementType elementType, IOrtValueOwner ortValueHolder) + : base(name, value, OnnxValueType.ONNX_TYPE_TENSOR) { _ortValueHolder = ortValueHolder; - ValueType = onnxValueType; ElementType = elementType; } /// - /// Returns OnnxValueType + /// Ctor for non-tensor values + /// + /// + /// + /// + /// + private DisposableNamedOnnxValue(string name, Object value, OnnxValueType onnxValueType, IOrtValueOwner ortValueHolder) + : base(name, value, onnxValueType) + { + _ortValueHolder = ortValueHolder; + ElementType = TensorElementType.DataTypeMax; + } + + /// + /// Construct an instance that would contain a map in a form of a Dictionary + /// Currently a limited number of primitive types are supported as map keys and values. + /// So this is not a full implementation of the map type. /// - public OnnxValueType ValueType { get; } + /// + /// + /// + /// + private DisposableNamedOnnxValue(string name, Object value, MapHelper mapHelper, IOrtValueOwner ortValueHolder) + : base(name, value, mapHelper) + { + _ortValueHolder = ortValueHolder; + ElementType = TensorElementType.DataTypeMax; + } /// /// Only valid if ValueType is Tensor @@ -101,22 +125,70 @@ private DisposableNamedOnnxValue(string name, Object value, OnnxValueType onnxVa /// to do, as this class maintains a native buffer via _ortValueHolder and the memory will be /// disposed by it. This is the case when we are dealing with an OrtValue that is backed by native memory /// and not by pinned managed memory. + /// + /// This class is generally used for outputs to be created on top of the output OrtValue, + /// but the interface (derived from NamedOnnxValue) allows it to be passed as input and one of the test + /// cases does it. Unless we deprecate and re-do the interface, we must support it. /// /// always set to null /// An instance of OrtValue that does not own underlying memory - internal override OrtValue ToOrtValue(out MemoryHandle? pinnedMemoryHandle) + internal override OrtValue InputToOrtValue(NodeMetadata metadata, out IDisposable memoryHolder) { - if(_ortValueHolder == null) + if (_ortValueHolder == null) { throw new InvalidOperationException("The instance of this class does not own any OrtValues"); } // PinnedMemoryHandle holds the default value as DisposableNamedOnnxValue // doesn't hold any managed buffer (that needs to be pinned) - pinnedMemoryHandle = null; + memoryHolder = null; // Return non-owning instance of OrtValue return _ortValueHolder.Value; } + /// + /// Generally, this class is created on top of the values that are returned by the model run. + /// So, this method is not expected to be called. However, if it is called (an instance fed as output), + /// it will return the OrtValue that was previously created, since the caller must understand what they are doing. + /// + /// + /// + /// + internal override OrtValue OutputToOrtValue(NodeMetadata metadata, out IDisposable memoryOwner) + { + return InputToOrtValue(metadata, out memoryOwner); + } + + internal static DisposableNamedOnnxValue CreateFromOrtValue(string name, OrtValue ortValue) + { + return CreateFromOrtValue(name, ortValue, OrtAllocator.DefaultInstance); + } + + internal static DisposableNamedOnnxValue CreateFromOrtValue(string name, OrtValue ortValue, OrtAllocator allocator) + { + DisposableNamedOnnxValue result = null; + + IntPtr valueType; + NativeApiStatus.VerifySuccess(NativeMethods.OrtGetValueType(ortValue.Handle, out valueType)); + OnnxValueType onnxValueType = (OnnxValueType)valueType; + switch (onnxValueType) + { + case OnnxValueType.ONNX_TYPE_TENSOR: + result = FromNativeTensor(name, ortValue); + break; + + case OnnxValueType.ONNX_TYPE_SEQUENCE: + result = FromNativeSequence(name, ortValue, allocator); + break; + + case OnnxValueType.ONNX_TYPE_MAP: + result = FromNativeMap(name, ortValue, allocator); + break; + default: + throw new NotSupportedException("OnnxValueType : " + onnxValueType + " is not supported"); + } + return result; + } + /// /// Creates an instance of DisposableNamedOnnxValue and takes ownership of ortValueElement /// on success. @@ -124,7 +196,7 @@ internal override OrtValue ToOrtValue(out MemoryHandle? pinnedMemoryHandle) /// name of the value /// underlying OrtValue /// - internal static DisposableNamedOnnxValue CreateTensorFromOnnxValue(string name, OrtValue ortValue) + private static DisposableNamedOnnxValue FromNativeTensor(string name, OrtValue ortValue) { DisposableNamedOnnxValue result = null; @@ -146,46 +218,46 @@ internal static DisposableNamedOnnxValue CreateTensorFromOnnxValue(string name, switch (elemType) { case TensorElementType.Float: - result = DisposableNamedOnnxValueFromNativeTensor(name, ortValue); + result = FromNativeTensor(name, ortValue); break; case TensorElementType.Double: - result = DisposableNamedOnnxValueFromNativeTensor(name, ortValue); + result = FromNativeTensor(name, ortValue); break; case TensorElementType.Int16: - result = DisposableNamedOnnxValueFromNativeTensor(name, ortValue); + result = FromNativeTensor(name, ortValue); break; case TensorElementType.UInt16: - result = DisposableNamedOnnxValueFromNativeTensor(name, ortValue); + result = FromNativeTensor(name, ortValue); break; case TensorElementType.Int32: - result = DisposableNamedOnnxValueFromNativeTensor(name, ortValue); + result = FromNativeTensor(name, ortValue); break; case TensorElementType.UInt32: - result = DisposableNamedOnnxValueFromNativeTensor(name, ortValue); + result = FromNativeTensor(name, ortValue); break; case TensorElementType.Int64: - result = DisposableNamedOnnxValueFromNativeTensor(name, ortValue); + result = FromNativeTensor(name, ortValue); break; case TensorElementType.UInt64: - result = DisposableNamedOnnxValueFromNativeTensor(name, ortValue); + result = FromNativeTensor(name, ortValue); break; case TensorElementType.UInt8: - result = DisposableNamedOnnxValueFromNativeTensor(name, ortValue); + result = FromNativeTensor(name, ortValue); break; case TensorElementType.Int8: - result = DisposableNamedOnnxValueFromNativeTensor(name, ortValue); + result = FromNativeTensor(name, ortValue); break; case TensorElementType.String: - result = DisposableNamedOnnxValueFromNativeTensor(name, ortValue); + result = FromNativeTensor(name, ortValue); break; case TensorElementType.Bool: - result = DisposableNamedOnnxValueFromNativeTensor(name, ortValue); + result = FromNativeTensor(name, ortValue); break; case TensorElementType.Float16: - result = DisposableNamedOnnxValueFromNativeTensor(name, ortValue); + result = FromNativeTensor(name, ortValue); break; case TensorElementType.BFloat16: - result = DisposableNamedOnnxValueFromNativeTensor(name, ortValue); + result = FromNativeTensor(name, ortValue); break; default: throw new NotSupportedException("Tensor of element type: " + elemType + " is not supported"); @@ -195,37 +267,6 @@ internal static DisposableNamedOnnxValue CreateTensorFromOnnxValue(string name, return result; } - internal static DisposableNamedOnnxValue CreateFromOrtValue(string name, OrtValue ortValue) - { - return CreateFromOrtValue(name, ortValue, OrtAllocator.DefaultInstance); - } - - internal static DisposableNamedOnnxValue CreateFromOrtValue(string name, OrtValue ortValue, OrtAllocator allocator) - { - DisposableNamedOnnxValue result = null; - - IntPtr valueType; - NativeApiStatus.VerifySuccess(NativeMethods.OrtGetValueType(ortValue.Handle, out valueType)); - OnnxValueType onnxValueType = (OnnxValueType)valueType; - switch (onnxValueType) - { - case OnnxValueType.ONNX_TYPE_TENSOR: - result = CreateTensorFromOnnxValue(name, ortValue); - break; - - case OnnxValueType.ONNX_TYPE_SEQUENCE: - result = DisposableNamedOnnxValueFromSequence(name, ortValue, allocator); - break; - - case OnnxValueType.ONNX_TYPE_MAP: - result = DisposableNamedOnnxValueFromNativeMap(name, ortValue, allocator); - break; - default: - throw new NotSupportedException("OnnxValueType : " + onnxValueType + " is not supported"); - } - return result; - } - /// /// This method creates an instance of DisposableNamedOnnxValue that has possession of ortValueElement /// native memory Tensor and returns it to the caller. The original ortValueElement argument looses @@ -236,34 +277,26 @@ internal static DisposableNamedOnnxValue CreateFromOrtValue(string name, OrtValu /// name of the output /// native tensor /// DisposableNamedOnnxValue instance - private static DisposableNamedOnnxValue DisposableNamedOnnxValueFromNativeTensor(string name, OrtValue ortValue) + private static DisposableNamedOnnxValue FromNativeTensor(string name, OrtValue ortValue) { - if (typeof(T) == typeof(string)) + var ortValueTensor = new OrtValueTensor(ortValue); + try { - var nativeTensorWrapper = new NativeOnnxTensorMemory(ortValue); - try + if (typeof(T) == typeof(string)) { - var dt = new DenseTensor(nativeTensorWrapper.GetBytesAsStringMemory(), nativeTensorWrapper.Dimensions); - return new DisposableNamedOnnxValue(name, dt, OnnxValueType.ONNX_TYPE_TENSOR, nativeTensorWrapper.ElementType, nativeTensorWrapper); - } catch(Exception) + var dt = new DenseTensor(ortValueTensor.GetBytesAsStringMemory(), ortValueTensor.Dimensions); + return new DisposableNamedOnnxValue(name, dt, ortValueTensor.ElementType, ortValueTensor); + } + else { - nativeTensorWrapper.Dispose(); - throw; + DenseTensor dt = new DenseTensor(ortValueTensor.Memory, ortValueTensor.Dimensions); + return new DisposableNamedOnnxValue(name, dt, ortValueTensor.ElementType, ortValueTensor); } } - else + catch (Exception) { - NativeOnnxTensorMemory nativeTensorWrapper = new NativeOnnxTensorMemory(ortValue); - try - { - DenseTensor dt = new DenseTensor(nativeTensorWrapper.Memory, nativeTensorWrapper.Dimensions); - return new DisposableNamedOnnxValue(name, dt, OnnxValueType.ONNX_TYPE_TENSOR, nativeTensorWrapper.ElementType, nativeTensorWrapper); - } - catch (Exception) - { - nativeTensorWrapper.Dispose(); - throw; - } + ortValueTensor.Dispose(); + throw; } } @@ -275,7 +308,7 @@ private static DisposableNamedOnnxValue DisposableNamedOnnxValueFromNativeTensor /// ortValueElement that has native sequence /// used allocator /// DisposableNamedOnnxValue - private static DisposableNamedOnnxValue DisposableNamedOnnxValueFromSequence(string name, OrtValue ortValueSequence, OrtAllocator allocator) + private static DisposableNamedOnnxValue FromNativeSequence(string name, OrtValue ortValueSequence, OrtAllocator allocator) { DisposableNamedOnnxValue result = null; IntPtr count; @@ -295,8 +328,8 @@ private static DisposableNamedOnnxValue DisposableNamedOnnxValueFromSequence(str } // NativeOrtValueCollectionOwner will take ownership of ortValueSequence and will make sure sequence // is also disposed. - var nativeCollectionManager = new NativeOrtValueCollectionOwner(ortValueSequence, sequence); - result = new DisposableNamedOnnxValue(name, sequence, OnnxValueType.ONNX_TYPE_SEQUENCE, TensorElementType.DataTypeMax, nativeCollectionManager); + var nativeCollectionManager = new NativeOrtValueCollectionOwner(ortValueSequence, sequence); + result = new DisposableNamedOnnxValue(name, sequence, OnnxValueType.ONNX_TYPE_SEQUENCE, nativeCollectionManager); } catch (Exception) { @@ -314,7 +347,7 @@ private static DisposableNamedOnnxValue DisposableNamedOnnxValueFromSequence(str /// This function does not take ownership of the map as it we copy all keys an values into a dictionary. We let the caller dispose of it /// /// DisposableNamedOnnxValue - private static DisposableNamedOnnxValue DisposableNamedOnnxValueFromNativeMap(string name, OrtValue ortValueMap, OrtAllocator allocator) + private static DisposableNamedOnnxValue FromNativeMap(string name, OrtValue ortValueMap, OrtAllocator allocator) { DisposableNamedOnnxValue result = null; // Map processing is currently not recursing. It is assumed to contain @@ -323,44 +356,87 @@ private static DisposableNamedOnnxValue DisposableNamedOnnxValueFromNativeMap(st // not mapped for client consumption. using (var cleanUpList = new DisposableList()) { - // Take possession of the map ortValueElement IntPtr nativeOnnxValueMapKeys = IntPtr.Zero; NativeApiStatus.VerifySuccess(NativeMethods.OrtGetValue(ortValueMap.Handle, 0, allocator.Pointer, out nativeOnnxValueMapKeys)); var ortValueKeys = new OrtValue(nativeOnnxValueMapKeys); cleanUpList.Add(ortValueKeys); + var typeAndShape = IntPtr.Zero; + NativeApiStatus.VerifySuccess(NativeMethods.OrtGetTensorTypeAndShape(nativeOnnxValueMapKeys, out typeAndShape)); + TensorElementType keyElemType; + try + { + IntPtr el_type; + NativeApiStatus.VerifySuccess(NativeMethods.OrtGetTensorElementType(typeAndShape, out el_type)); + keyElemType = (TensorElementType)el_type; + } + finally + { + NativeMethods.OrtReleaseTensorTypeAndShapeInfo(typeAndShape); + } + IntPtr nativeOnnxValueMapValues = IntPtr.Zero; NativeApiStatus.VerifySuccess(NativeMethods.OrtGetValue(ortValueMap.Handle, 1, allocator.Pointer, out nativeOnnxValueMapValues)); var ortValueValues = new OrtValue(nativeOnnxValueMapValues); cleanUpList.Add(ortValueValues); - IntPtr typeAndShape = IntPtr.Zero; - NativeApiStatus.VerifySuccess(NativeMethods.OrtGetTensorTypeAndShape(nativeOnnxValueMapKeys, out typeAndShape)); - TensorElementType elemType = TensorElementType.DataTypeMax; + typeAndShape = IntPtr.Zero; + NativeApiStatus.VerifySuccess(NativeMethods.OrtGetTensorTypeAndShape(nativeOnnxValueMapValues, out typeAndShape)); + TensorElementType valueElemType; try { IntPtr el_type; NativeApiStatus.VerifySuccess(NativeMethods.OrtGetTensorElementType(typeAndShape, out el_type)); - elemType = (TensorElementType)el_type; + valueElemType = (TensorElementType)el_type; } finally { NativeMethods.OrtReleaseTensorTypeAndShapeInfo(typeAndShape); } - /// XXX: This code always assumes that the value type is float and makes no checks - /// similar to that of the key. Also Map type in general can also be another sequence or map, - /// not just a tensor - switch (elemType) + // The supported combinations of key and value types are taken from the ORT C API. + switch (keyElemType) { case TensorElementType.Int64: - result = DisposableNamedOnnxValueFromNativeMapElements(string.Empty, ortValueKeys, ortValueValues); + switch (valueElemType) + { + case TensorElementType.Float: + result = FromNativeMapElements(name, ortValueMap, ortValueKeys, ortValueValues); + break; + case TensorElementType.Double: + result = FromNativeMapElements(name, ortValueMap, ortValueKeys, ortValueValues); + break; + case TensorElementType.Int64: + result = FromNativeMapElements(name, ortValueMap, ortValueKeys, ortValueValues); + break; + case TensorElementType.String: + result = FromNativeMapElements(name, ortValueMap, ortValueKeys, ortValueValues); + break; + default: + break; + } break; case TensorElementType.String: - result = DisposableNamedOnnxValueFromNativeMapElements(string.Empty, ortValueKeys, ortValueValues); + switch (valueElemType) + { + case TensorElementType.Float: + result = FromNativeMapElements(name, ortValueMap, ortValueKeys, ortValueValues); + break; + case TensorElementType.Double: + result = FromNativeMapElements(name, ortValueMap, ortValueKeys, ortValueValues); + break; + case TensorElementType.Int64: + result = FromNativeMapElements(name, ortValueMap, ortValueKeys, ortValueValues); + break; + case TensorElementType.String: + result = FromNativeMapElements(name, ortValueMap, ortValueKeys, ortValueValues); + break; + default: + break; + } break; default: - throw new NotSupportedException("Map of element type: " + elemType + " is not supported"); + throw new NotSupportedException("Map key type: " + keyElemType + " is not supported"); } } return result; @@ -381,40 +457,78 @@ private static DisposableNamedOnnxValue DisposableNamedOnnxValueFromNativeMap(st /// tensor with map keys. /// tensor with map values /// instance of DisposableNamedOnnxValue with Dictionary - private static DisposableNamedOnnxValue DisposableNamedOnnxValueFromNativeMapElements(string name, + private static DisposableNamedOnnxValue FromNativeMapElements(string name, OrtValue ortValueMap, OrtValue ortValueTensorKeys, OrtValue ortValueTensorValues) { - using (var nativeTensorWrapperValues = new NativeOnnxTensorMemory(ortValueTensorValues)) + var listOfKeysValues = new DisposableList(); + var collOwner = new NativeOrtValueCollectionOwner(ortValueMap, listOfKeysValues); + try { - var denseTensorValues = new DenseTensor(nativeTensorWrapperValues.Memory, nativeTensorWrapperValues.Dimensions); + var tensorKeys = new OrtValueTensor(ortValueTensorKeys); + listOfKeysValues.Add(ortValueTensorKeys); + var tensorValues = new OrtValueTensor(ortValueTensorValues); + listOfKeysValues.Add(ortValueTensorValues); + MapHelper mapHelper = null; if (typeof(K) == typeof(string)) { - var map = new Dictionary(); - using (var nativeTensorWrapper = new NativeOnnxTensorMemory(ortValueTensorKeys)) + var denseTensorKeys = new DenseTensor(tensorKeys.GetBytesAsStringMemory(), tensorKeys.Dimensions); + + if (typeof(V) == typeof(string)) + { + var map = new Dictionary(); + var denseTensorValues = new DenseTensor(tensorValues.GetBytesAsStringMemory(), tensorValues.Dimensions); + for (var i = 0; i < denseTensorKeys.Length; i++) + { + map.Add(denseTensorKeys.GetValue(i), denseTensorValues.GetValue(i)); + } + mapHelper = new MapHelper(denseTensorKeys, denseTensorValues); + return new DisposableNamedOnnxValue(name, map, mapHelper, collOwner); + } + else { - var denseTensorKeys = new DenseTensor(nativeTensorWrapper.GetBytesAsStringMemory(), nativeTensorWrapper.Dimensions); + var map = new Dictionary(); + var denseTensorValues = new DenseTensor(tensorValues.Memory, tensorValues.Dimensions); for (var i = 0; i < denseTensorKeys.Length; i++) { map.Add(denseTensorKeys.GetValue(i), denseTensorValues.GetValue(i)); } - return new DisposableNamedOnnxValue(name, map, OnnxValueType.ONNX_TYPE_MAP, TensorElementType.DataTypeMax, null); + mapHelper = new MapHelper(denseTensorKeys, denseTensorValues); + return new DisposableNamedOnnxValue(name, map, mapHelper, collOwner); } } else { - var map = new Dictionary(); - using (var nativeTensorWrapper = new NativeOnnxTensorMemory(ortValueTensorKeys)) + var denseTensorKeys = new DenseTensor(tensorKeys.Memory, tensorKeys.Dimensions); + if (typeof(V) == typeof(string)) { - var denseTensorKeys = new DenseTensor(nativeTensorWrapper.Memory, nativeTensorWrapper.Dimensions); + var map = new Dictionary(); + var denseTensorValues = new DenseTensor(tensorValues.GetBytesAsStringMemory(), tensorValues.Dimensions); for (var i = 0; i < denseTensorKeys.Length; i++) { map.Add(denseTensorKeys.GetValue(i), denseTensorValues.GetValue(i)); } - return new DisposableNamedOnnxValue(name, map, OnnxValueType.ONNX_TYPE_MAP, TensorElementType.DataTypeMax, null); + mapHelper = new MapHelper(denseTensorKeys, denseTensorValues); + return new DisposableNamedOnnxValue(name, map, mapHelper, collOwner); + } + else + { + var denseTensorValues = new DenseTensor(tensorValues.Memory, tensorValues.Dimensions); + var map = new Dictionary(); + for (var i = 0; i < denseTensorKeys.Length; i++) + { + map.Add(denseTensorKeys.GetValue(i), denseTensorValues.GetValue(i)); + } + mapHelper = new MapHelper(denseTensorKeys, denseTensorValues); + return new DisposableNamedOnnxValue(name, map, mapHelper, collOwner); } } } + catch (Exception) + { + collOwner.Dispose(); + throw; + } } #region IDisposable Support @@ -425,7 +539,7 @@ private static DisposableNamedOnnxValue DisposableNamedOnnxValueFromNativeMapEle /// true if invoked by Dispose() protected virtual void Dispose(bool disposing) { - if(_disposed) + if (_disposed) { return; } @@ -448,9 +562,7 @@ protected virtual void Dispose(bool disposing) /// public void Dispose() { - // Do not change this code. Put cleanup code in Dispose(bool disposing) above. Dispose(true); - GC.SuppressFinalize(this); } #endregion diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/InferenceSession.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/InferenceSession.shared.cs index 98e1833aed8a6..a6be0afdad093 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/InferenceSession.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/InferenceSession.shared.cs @@ -8,6 +8,7 @@ using System.Linq; using Microsoft.ML.OnnxRuntime.Tensors; using System.Buffers; +using System.Diagnostics; namespace Microsoft.ML.OnnxRuntime { @@ -31,11 +32,21 @@ public class InferenceSession : IDisposable /// private Dictionary _inputMetadata; + /// + /// Ordered list of input names + /// + private List _inputNames; + /// /// Dictionary that represent output metadata /// private Dictionary _outputMetadata; + /// + /// Ordered list of output names + /// + private List _outputNames; + /// /// Dictionary that represents overridableInitializers metadata /// @@ -163,6 +174,11 @@ public IReadOnlyDictionary InputMetadata } } + /// + /// Ordered list of input names that can be accessed by index; + /// + public IReadOnlyList InputNames { get { return _inputNames; } } + /// /// Metadata regarding the output nodes, keyed by output names /// @@ -174,6 +190,11 @@ public IReadOnlyDictionary OutputMetadata } } + /// + /// Ordered list of output names that can be accessed by index. + /// + public IReadOnlyList OutputNames { get { return _outputNames; } } + /// /// Metadata regarding the overridable initializers, keyed by node names /// @@ -203,7 +224,8 @@ public IDisposableReadOnlyCollection Run(IReadOnlyColl /// Specify a collection of that indicates the input values. /// Specify a collection of string that indicates the output names to fetch. /// Output Tensors in a Collection of NamedOnnxValue. User must dispose the output. - public IDisposableReadOnlyCollection Run(IReadOnlyCollection inputs, IReadOnlyCollection outputNames) + public IDisposableReadOnlyCollection Run(IReadOnlyCollection inputs, + IReadOnlyCollection outputNames) { return Run(inputs, outputNames, _builtInRunOptions); } @@ -215,13 +237,15 @@ public IDisposableReadOnlyCollection Run(IReadOnlyColl /// Specify a collection of string that indicates the output names to fetch. /// /// Output Tensors in a Collection of NamedOnnxValue. User must dispose the output. - public IDisposableReadOnlyCollection Run(IReadOnlyCollection inputs, IReadOnlyCollection outputNames, RunOptions options) + public IDisposableReadOnlyCollection Run(IReadOnlyCollection inputs, + IReadOnlyCollection outputNames, + RunOptions options) { using (var cleanupList = new DisposableList()) { - var inputNamesArray = ConvertNamesToUtf8(inputs, v => v.Name, cleanupList); - var inputValuesArray = GetOrtValuesHandles(inputs, cleanupList); - var outputNamesArray = ConvertNamesToUtf8(outputNames, n => n, cleanupList); + var inputNamesArray = ConvertNamesToUtf8(inputs, v => v.Name, LookupInputMetadata, cleanupList); + var inputValuesArray = GetOrtValuesHandles(inputs, LookupInputMetadata, ExtractOrtValueForInput, cleanupList); + var outputNamesArray = ConvertNamesToUtf8(outputNames, n => n, LookupOutputMetadata, cleanupList); var ortValues = RunImpl(options, inputNamesArray, inputValuesArray, outputNamesArray, cleanupList); return CreateDisposableResult(ortValues, outputNames); @@ -238,9 +262,7 @@ public IDisposableReadOnlyCollection Run( IReadOnlyCollection inputNames, IReadOnlyCollection inputValues) { - string[] outputNames = new string[_outputMetadata.Count]; - _outputMetadata.Keys.CopyTo(outputNames, 0); - return Run(inputNames, inputValues, outputNames, _builtInRunOptions); + return Run(inputNames, inputValues, _outputNames, _builtInRunOptions); } /// @@ -279,9 +301,9 @@ public IDisposableReadOnlyCollection Run( using (var cleanupList = new DisposableList()) { - var inputNamesArray = ConvertNamesToUtf8(inputNames, n => n, cleanupList); + var inputNamesArray = ConvertNamesToUtf8(inputNames, n => n, LookupInputMetadata, cleanupList); IntPtr[] inputValuesArray = GetOrtValuesHandles(inputValues, true); - var outputNamesArray = ConvertNamesToUtf8(outputNames, n => n, cleanupList); + var outputNamesArray = ConvertNamesToUtf8(outputNames, n => n, LookupOutputMetadata, cleanupList); var ortValues = RunImpl(options, inputNamesArray, inputValuesArray, outputNamesArray, cleanupList); @@ -336,11 +358,11 @@ public void Run( using (var cleanupList = new DisposableList()) { // prepare inputs - var inputNamesArray = ConvertNamesToUtf8(inputNames, n => n, cleanupList); + var inputNamesArray = ConvertNamesToUtf8(inputNames, n => n, LookupInputMetadata, cleanupList); IntPtr[] inputValuesArray = GetOrtValuesHandles(inputValues, true); // prepare outputs - var outputNamesArray = ConvertNamesToUtf8(outputNames, n => n, cleanupList); + var outputNamesArray = ConvertNamesToUtf8(outputNames, n => n, LookupOutputMetadata, cleanupList); IntPtr[] outputValuesArray = GetOrtValuesHandles(outputValues, false); NativeApiStatus.VerifySuccess(NativeMethods.OrtRun( @@ -371,7 +393,7 @@ public void Run( } /// - /// + /// /// Runs the loaded model for the given inputs and outputs. Uses the given RunOptions for this run. /// /// Outputs need to be created with correct type and dimension to receive the fetched data. @@ -386,11 +408,11 @@ public void Run( { using (var cleanupList = new DisposableList()) { - var inputNamesArray = ConvertNamesToUtf8(inputs, i => i.Name, cleanupList); - var inputValuesArray = GetOrtValuesHandles(inputs, cleanupList); + var inputNamesArray = ConvertNamesToUtf8(inputs, i => i.Name, LookupInputMetadata, cleanupList); + var inputValuesArray = GetOrtValuesHandles(inputs, LookupInputMetadata, ExtractOrtValueForInput, cleanupList); - var outputNamesArray = ConvertNamesToUtf8(outputs, o => o.Name, cleanupList); - var outputValuesArray = GetOrtValuesHandles(outputs, cleanupList); + var outputNamesArray = ConvertNamesToUtf8(outputs, o => o.Name, LookupOutputMetadata, cleanupList); + var outputValuesArray = GetOrtValuesHandles(outputs, LookupOutputMetadata, ExtractOrtValueForOutput, cleanupList); NativeApiStatus.VerifySuccess(NativeMethods.OrtRun( _nativeHandle, @@ -444,11 +466,11 @@ public void Run( using (var cleanupList = new DisposableList()) { // prepare inputs - var inputNamesArray = ConvertNamesToUtf8(inputs, i => i.Name, cleanupList); - var inputValuesArray = GetOrtValuesHandles(inputs, cleanupList); + var inputNamesArray = ConvertNamesToUtf8(inputs, i => i.Name, LookupInputMetadata, cleanupList); + var inputValuesArray = GetOrtValuesHandles(inputs, LookupInputMetadata, ExtractOrtValueForInput, cleanupList); // prepare outputs - var outputNamesArray = ConvertNamesToUtf8(outputNames, n => n, cleanupList); + var outputNamesArray = ConvertNamesToUtf8(outputNames, n => n, LookupOutputMetadata, cleanupList); var outputValuesArray = GetOrtValuesHandles(outputValues, false); NativeApiStatus.VerifySuccess(NativeMethods.OrtRun( @@ -482,7 +504,7 @@ public void Run( } /// - /// + /// /// Runs the loaded model for the given inputs and outputs. Uses the given RunOptions for this run. /// /// Outputs need to be created with correct type and dimension to receive the fetched data. @@ -505,12 +527,12 @@ public void Run( using (var cleanupList = new DisposableList()) { // prepare inputs - var inputNamesArray = ConvertNamesToUtf8(inputNames, n => n, cleanupList); + var inputNamesArray = ConvertNamesToUtf8(inputNames, n => n, LookupInputMetadata, cleanupList); var inputValuesArray = GetOrtValuesHandles(inputValues, true); // prepare outputs - var outputNamesArray = ConvertNamesToUtf8(outputs, o => o.Name, cleanupList); - var outputValuesArray = GetOrtValuesHandles(outputs, cleanupList); + var outputNamesArray = ConvertNamesToUtf8(outputs, o => o.Name, LookupOutputMetadata, cleanupList); + var outputValuesArray = GetOrtValuesHandles(outputs, LookupOutputMetadata, ExtractOrtValueForOutput, cleanupList); NativeApiStatus.VerifySuccess(NativeMethods.OrtRun( _nativeHandle, @@ -619,22 +641,96 @@ public string EndProfiling() // Delegate for string extraction from an arbitrary input/output object private delegate string NameExtractor(TInput input); + // delegate to fetch input/output OrtValue + private delegate OrtValue OrtValueExtractor(NamedOnnxValue value, NodeMetadata metadata, out IDisposable memOwner); + + // Delegate to lookup metadata for input/initializers/output + private delegate NodeMetadata MetadataLookup(string nodeName); + + /// + /// Checks if the name is a known input or overridable initializer name + /// and if so, returns metadata for it. + /// metadata + /// + /// + /// NodeMetadata for the nodeName + /// + private NodeMetadata LookupInputMetadata(string nodeName) + { + NodeMetadata meta; + if (!_inputMetadata.TryGetValue(nodeName, out meta) && + !_overridableInitializerMetadata.TryGetValue(nodeName, out meta)) + { + throw new OnnxRuntimeException(ErrorCode.InvalidArgument, $"Input name: '{nodeName}' is not in the metadata"); + } + return meta; + } + + /// + /// Checks if the nodeName is a known output name and if so returns metadata for it. + /// + /// + /// + /// + private NodeMetadata LookupOutputMetadata(string nodeName) + { + NodeMetadata meta; + if (!_outputMetadata.TryGetValue(nodeName, out meta)) + { + throw new OnnxRuntimeException(ErrorCode.InvalidArgument, $"Output name: '{nodeName}' is not in the metadata"); + } + return meta; + } + + /// + /// Fetches/creates OrtValue for the content of the input + /// + /// + /// + /// + /// + private static OrtValue ExtractOrtValueForInput(NamedOnnxValue input, NodeMetadata metadata, out IDisposable memOwner) + { + return input.InputToOrtValue(metadata, out memOwner); + } + + /// + /// Fetches/Creates OrtValue for output + /// + /// + /// + /// + /// May return null if the onnx value type does not support pre-creation of output OrtValues + private static OrtValue ExtractOrtValueForOutput(NamedOnnxValue output, NodeMetadata metadata, out IDisposable memOwner) + { + return output.OutputToOrtValue(metadata, out memOwner); + } + /// /// Run helper /// - /// names to convert to zero terminated utf8 and pin + /// names to convert to zero terminated utf8 and pin + /// extractor functor that helps extracting names from inputs + /// inputs/outputs metadata /// list to add pinned memory to for later disposal /// - private IntPtr[] ConvertNamesToUtf8(IReadOnlyCollection inputs, NameExtractor extractor, + private IntPtr[] ConvertNamesToUtf8(IReadOnlyCollection values, NameExtractor nameExtractor, + MetadataLookup metaLookup, DisposableList cleanupList) { - var result = new IntPtr[inputs.Count]; - for (int i = 0; i < inputs.Count; ++i) + cleanupList.Capacity += values.Count; + var result = new IntPtr[values.Count]; + for (int i = 0; i < values.Count; ++i) { - var name = extractor(inputs.ElementAt(i)); - var utf8Name = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(name); - var pinnedHandle = new PinnedGCHandle(GCHandle.Alloc(utf8Name, GCHandleType.Pinned)); - result[i] = pinnedHandle.Pointer; + var name = nameExtractor(values.ElementAt(i)); + NodeMetadata meta = metaLookup(name); + var utf8Name = meta.ZeroTerminatedUtf8Name; + Debug.Assert(utf8Name != null); + var pinnedHandle = new Memory(utf8Name).Pin(); + unsafe + { + result[i] = (IntPtr)pinnedHandle.Pointer; + } cleanupList.Add(pinnedHandle); } return result; @@ -642,28 +738,41 @@ private IntPtr[] ConvertNamesToUtf8(IReadOnlyCollection inputs, NameExtrac /// /// This function obtains ortValues for NamedOnnxValue. - /// The problem with NamedOnnxValue is that it does not contain any Onnx (OrtValue) - /// so calling ToOrtValue creates a new instance of OrtValue that needs to be disposed. + /// The problem with NamedOnnxValue is that it is not disposable and can not contain any disposable items. + /// so calling InputToOrtValue creates a new instance of OrtValue that needs to be disposed. /// The deriving object DisposableNamedValue actually contains and owns OrtValue and it returns /// it. /// - /// - /// + /// a collection of NamedOnnxValues + /// Metadata lookup function (input/initializers/output) + /// list to cleanup in an exception safe manner /// - private IntPtr[] GetOrtValuesHandles(IReadOnlyCollection values, DisposableList cleanupList) + private IntPtr[] GetOrtValuesHandles(IReadOnlyCollection values, MetadataLookup metaLookup, + OrtValueExtractor ortValueExtractor, + DisposableList cleanupList) { + cleanupList.Capacity += values.Count * 2; IntPtr[] result = new IntPtr[values.Count]; - for (int inputIndex = 0; inputIndex < values.Count; ++inputIndex) + for (int valueIndex = 0; valueIndex < values.Count; ++valueIndex) { - var input = values.ElementAt(inputIndex); - MemoryHandle? memHandle; - var ortValue = input.ToOrtValue(out memHandle); - if (memHandle.HasValue) + var value = values.ElementAt(valueIndex); + var meta = metaLookup(value.Name); + var ortValue = ortValueExtractor(value, meta, out IDisposable memHolder); + if (memHolder != null) + { + cleanupList.Add(memHolder); + } + if (ortValue != null) + { + if (ortValue.IsOwned) + cleanupList.Add(ortValue); + + result[valueIndex] = ortValue.Handle; + } + else { - cleanupList.Add(memHandle); + result[valueIndex] = IntPtr.Zero; } - cleanupList.Add(ortValue); - result[inputIndex] = ortValue.Handle; } return result; } @@ -687,7 +796,8 @@ private IntPtr[] GetOrtValuesHandles(IReadOnlyCollection v private DisposableList RunImpl(RunOptions options, IntPtr[] inputNames, IntPtr[] inputValues, IntPtr[] outputNames, DisposableList cleanupList) { - var ortValues = new DisposableList(outputNames.Length); + cleanupList.Capacity += 1; + var ortValues = new DisposableList(outputNames.Length + 1); cleanupList.Add(ortValues); IntPtr[] outputValuesArray = new IntPtr[outputNames.Length]; @@ -717,8 +827,7 @@ IDisposableReadOnlyCollection CreateDisposableResult(L { for (int i = 0; i < ortValues.Count; i++) { - var ortValue = ortValues[i]; - result.Add(DisposableNamedOnnxValue.CreateFromOrtValue(outputNames.ElementAt(i), ortValue)); + result.Add(DisposableNamedOnnxValue.CreateFromOrtValue(outputNames.ElementAt(i), ortValues[i])); } } catch (OnnxRuntimeException) @@ -769,12 +878,11 @@ private void Init(string modelPath, SessionOptions options, PrePackedWeightsContainer prepackedWeightsContainer = null) { var envHandle = OrtEnv.Handle; - var session = IntPtr.Zero; - + IntPtr session; if (prepackedWeightsContainer == null) { - NativeApiStatus.VerifySuccess(NativeMethods.OrtCreateSession(envHandle, NativeOnnxValueHelper.GetPlatformSerializedString(modelPath), - options.Handle, out session)); + NativeApiStatus.VerifySuccess(NativeMethods.OrtCreateSession(envHandle, NativeOnnxValueHelper.GetPlatformSerializedString(modelPath), + options.Handle, out session)); } else @@ -791,8 +899,7 @@ private void Init(byte[] modelData, SessionOptions options, PrePackedWeightsContainer prepackedWeightsContainer = null) { var envHandle = OrtEnv.Handle; - var session = IntPtr.Zero; - + IntPtr session; if (prepackedWeightsContainer == null) { @@ -820,48 +927,57 @@ private void InitWithSessionHandle(IntPtr session, SessionOptions options) _nativeHandle = session; try { - // Initialize input/output metadata - _inputMetadata = new Dictionary(); - _outputMetadata = new Dictionary(); - _overridableInitializerMetadata = new Dictionary(); // get input count - UIntPtr inputCount = UIntPtr.Zero; - NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionGetInputCount(_nativeHandle, out inputCount)); + NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionGetInputCount(_nativeHandle, out UIntPtr inputCount)); // get all the input names and metadata + _inputMetadata = new Dictionary((int)inputCount); + _inputNames = new List((int)inputCount); + for (ulong i = 0; i < (ulong)inputCount; i++) { - var iname = GetInputName(i); - _inputMetadata[iname] = GetInputMetadata(i); + var inputMeta = GetInputMetadata(i); + var iname = GetInputName(i, out byte[] utf8); + _inputNames.Add(iname); + inputMeta.ZeroTerminatedUtf8Name = utf8; + _inputMetadata[iname] = inputMeta; } // get output count - UIntPtr outputCount = UIntPtr.Zero; - NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionGetOutputCount(_nativeHandle, out outputCount)); + NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionGetOutputCount(_nativeHandle, out UIntPtr outputCount)); // get all the output names and metadata + _outputMetadata = new Dictionary((int)outputCount); + _outputNames = new List((int)outputCount); + for (ulong i = 0; i < (ulong)outputCount; i++) { - _outputMetadata[GetOutputName(i)] = GetOutputMetadata(i); + var outputMeta = GetOutputMetadata(i); + var oname = GetOutputName(i, out byte[] utf8); + _outputNames.Add(oname); + outputMeta.ZeroTerminatedUtf8Name = utf8; + _outputMetadata[oname] = outputMeta; } // get overridable initializer count - UIntPtr initilaizerCount = UIntPtr.Zero; - NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionGetOverridableInitializerCount(_nativeHandle, out initilaizerCount)); + NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionGetOverridableInitializerCount(_nativeHandle, out UIntPtr initilaizerCount)); + _overridableInitializerMetadata = new Dictionary((int)initilaizerCount); // get all the overridable initializer names and metadata for (ulong i = 0; i < (ulong)initilaizerCount; i++) { - _overridableInitializerMetadata[GetOverridableInitializerName(i)] = GetOverridableInitializerMetadata(i); + var meta = GetOverridableInitializerMetadata(i); + var iname = GetOverridableInitializerName(i, out byte[] utf8); + meta.ZeroTerminatedUtf8Name = utf8; + _overridableInitializerMetadata[iname] = meta; } // set profiling's start time - UIntPtr startTime = UIntPtr.Zero; NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionGetProfilingStartTimeNs(_nativeHandle, - out startTime)); + out UIntPtr startTime)); _profilingStartTimeNs = (ulong)startTime; } - catch (OnnxRuntimeException) + catch (Exception) { if (_nativeHandle != IntPtr.Zero) { @@ -875,64 +991,60 @@ private void InitWithSessionHandle(IntPtr session, SessionOptions options) } - private string GetOutputName(ulong index) + private string GetOutputName(ulong index, out byte[] utf8) { + string str; var allocator = OrtAllocator.DefaultInstance; - IntPtr nameHandle = IntPtr.Zero; - string str = null; NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionGetOutputName( _nativeHandle, (UIntPtr)index, allocator.Pointer, - out nameHandle)); + out IntPtr nameHandle)); using (var ortAllocation = new OrtMemoryAllocation(allocator, nameHandle, 0)) { - str = NativeOnnxValueHelper.StringFromNativeUtf8(nameHandle); + NativeOnnxValueHelper.StringAndUtf8FromNative(nameHandle, out str, out utf8); } return str; } - private string GetInputName(ulong index) + private string GetInputName(ulong index, out byte[] utf8) { - string str = null; + string str; var allocator = OrtAllocator.DefaultInstance; - IntPtr nameHandle = IntPtr.Zero; NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionGetInputName( _nativeHandle, (UIntPtr)index, allocator.Pointer, - out nameHandle)); + out IntPtr nameHandle)); using (var ortAllocation = new OrtMemoryAllocation(allocator, nameHandle, 0)) { - str = NativeOnnxValueHelper.StringFromNativeUtf8(nameHandle); + NativeOnnxValueHelper.StringAndUtf8FromNative(nameHandle, out str, out utf8); } return str; } - private string GetOverridableInitializerName(ulong index) + private string GetOverridableInitializerName(ulong index, out byte[] utf8) { - string str = null; + string str; var allocator = OrtAllocator.DefaultInstance; - IntPtr nameHandle = IntPtr.Zero; NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionGetOverridableInitializerName( _nativeHandle, (UIntPtr)index, allocator.Pointer, - out nameHandle)); + out IntPtr nameHandle)); using (var ortAllocation = new OrtMemoryAllocation(allocator, nameHandle, 0)) { - str = NativeOnnxValueHelper.StringFromNativeUtf8(nameHandle); + NativeOnnxValueHelper.StringAndUtf8FromNative(nameHandle, out str, out utf8); } return str; } private NodeMetadata GetInputMetadata(ulong index) { - IntPtr typeInfo = IntPtr.Zero; - NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionGetInputTypeInfo(_nativeHandle, (UIntPtr)index, out typeInfo)); + NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionGetInputTypeInfo(_nativeHandle, (UIntPtr)index, out IntPtr typeInfo)); try { return GetMetadataFromTypeInfo(typeInfo); @@ -945,8 +1057,7 @@ private NodeMetadata GetInputMetadata(ulong index) private NodeMetadata GetOutputMetadata(ulong index) { - IntPtr typeInfo = IntPtr.Zero; - NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionGetOutputTypeInfo(_nativeHandle, (UIntPtr)index, out typeInfo)); + NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionGetOutputTypeInfo(_nativeHandle, (UIntPtr)index, out IntPtr typeInfo)); try { return GetMetadataFromTypeInfo(typeInfo); @@ -959,8 +1070,7 @@ private NodeMetadata GetOutputMetadata(ulong index) private NodeMetadata GetOverridableInitializerMetadata(ulong index) { - IntPtr typeInfo = IntPtr.Zero; - NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionGetOverridableInitializerTypeInfo(_nativeHandle, (UIntPtr)index, out typeInfo)); + NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionGetOverridableInitializerTypeInfo(_nativeHandle, (UIntPtr)index, out IntPtr typeInfo)); try { return GetMetadataFromTypeInfo(typeInfo); @@ -975,40 +1085,112 @@ internal static NodeMetadata GetMetadataFromTypeInfo(IntPtr typeInfo) { OnnxValueType valueType; { - IntPtr valType; - NativeApiStatus.VerifySuccess(NativeMethods.OrtGetOnnxTypeFromTypeInfo(typeInfo, out valType)); + NativeApiStatus.VerifySuccess(NativeMethods.OrtGetOnnxTypeFromTypeInfo(typeInfo, out IntPtr valType)); valueType = (OnnxValueType)valType; } - if (valueType != OnnxValueType.ONNX_TYPE_TENSOR && valueType != OnnxValueType.ONNX_TYPE_SPARSETENSOR) + + switch (valueType) { - return new NodeMetadata(valueType, new int[] { }, new string[] { }, typeof(NamedOnnxValue)); + case OnnxValueType.ONNX_TYPE_TENSOR: + case OnnxValueType.ONNX_TYPE_SPARSETENSOR: + return GetTensorNodeMetadata(valueType, typeInfo); + case OnnxValueType.ONNX_TYPE_SEQUENCE: + return GetSequenceMetadataFromTypeInfo(typeInfo); + case OnnxValueType.ONNX_TYPE_MAP: + return GetMapMetadataFromTypeInfo(typeInfo); + case OnnxValueType.ONNX_TYPE_OPTIONAL: + return GetOptionalMetadataFromTypeInfo(typeInfo); } - // This should not be released - IntPtr tensorInfo; - NativeApiStatus.VerifySuccess(NativeMethods.OrtCastTypeInfoToTensorInfo(typeInfo, out tensorInfo)); //(IntPtr)(int)(uint) - // Convert the newly introduced OrtTypeInfo* to the older OrtTypeAndShapeInfo* + throw new OnnxRuntimeException(ErrorCode.NotImplemented, $"Value type: '{valueType}' not supported in this code"); + } - if (tensorInfo == IntPtr.Zero) - return null; + internal static NodeMetadata GetSequenceMetadataFromTypeInfo(IntPtr typeInfo) + { + NativeApiStatus.VerifySuccess(NativeMethods.OrtCastTypeInfoToSequenceTypeInfo(typeInfo, out IntPtr sequenceTypeInfo)); + // Casts API are broken. Always return success, but may return null for the result. + if (sequenceTypeInfo == IntPtr.Zero) + { + throw new OnnxRuntimeException(ErrorCode.Fail, "TypeInfo cast to SequenceTypeInfo failed. The object does not represent a sequence"); + } - TensorElementType type; + NativeApiStatus.VerifySuccess(NativeMethods.OrtGetSequenceElementType(sequenceTypeInfo, out IntPtr elementType)); + try { - IntPtr el_type; - NativeApiStatus.VerifySuccess(NativeMethods.OrtGetTensorElementType(tensorInfo, out el_type)); - type = (TensorElementType)el_type; + var elementMeta = GetMetadataFromTypeInfo(elementType); + var seqMeta = new SequenceMetadata(elementMeta); + return new NodeMetadata(seqMeta); + } + finally + { + NativeMethods.OrtReleaseTypeInfo(elementType); + } + } + + internal static NodeMetadata GetMapMetadataFromTypeInfo(IntPtr typeInfo) + { + NativeApiStatus.VerifySuccess(NativeMethods.OrtCastTypeInfoToMapTypeInfo(typeInfo, out IntPtr mapTypeInfo)); + // Casts API are broken. Always return success, but may return null for the result. + if (mapTypeInfo == IntPtr.Zero) + { + throw new OnnxRuntimeException(ErrorCode.Fail, "TypeInfo cast to MapTypeInfo failed. The object does not represent a map"); + } + + NativeApiStatus.VerifySuccess(NativeMethods.OrtGetMapKeyType(mapTypeInfo, out IntPtr keyType)); + + NativeApiStatus.VerifySuccess(NativeMethods.OrtGetMapValueType(mapTypeInfo, out IntPtr valueTypeInfo)); + try + { + var valueMetadata = GetMetadataFromTypeInfo(valueTypeInfo); + var mapMeta = new MapMetadata((TensorElementType)keyType, valueMetadata); + return new NodeMetadata(mapMeta); + } + finally + { + NativeMethods.OrtReleaseTypeInfo(valueTypeInfo); + } + } + + internal static NodeMetadata GetOptionalMetadataFromTypeInfo(IntPtr typeInfo) + { + // This should not be destroyed + NativeApiStatus.VerifySuccess(NativeMethods.OrtCastTypeInfoToOptionalTypeInfo(typeInfo, out IntPtr optTypeInfo)); + // Casts API are broken. Always return success, but may return null for the result. + if (optTypeInfo == IntPtr.Zero) + { + throw new OnnxRuntimeException(ErrorCode.Fail, "TypeInfo cast to OptionalTypeInfo failed. The object does not represent a optional value"); + } + + NativeApiStatus.VerifySuccess(NativeMethods.OrtGetOptionalContainedTypeInfo(optTypeInfo, out IntPtr elementTypeInfo)); + try + { + var elementMetadata = GetMetadataFromTypeInfo(elementTypeInfo); + var optMetadata = new OptionalMetadata(elementMetadata); + return new NodeMetadata(optMetadata); + } + finally + { + NativeMethods.OrtReleaseTypeInfo(elementTypeInfo); } + } - Type dotnetType = null; - int width = 0; - if (!TensorElementTypeConverter.GetTypeAndWidth(type, out dotnetType, out width)) + internal static NodeMetadata GetTensorNodeMetadata(OnnxValueType valueType, IntPtr typeInfo) + { + // Fetch tensor type and shape from the TypeInfo + NativeApiStatus.VerifySuccess(NativeMethods.OrtCastTypeInfoToTensorInfo(typeInfo, out IntPtr tensorInfo)); //(IntPtr)(int)(uint) + // Casts API are broken. Always return success, but may return null for the result. + if (tensorInfo == IntPtr.Zero) { - throw new OnnxRuntimeException(ErrorCode.InvalidArgument, - "Unable to query type information for data type: " + type.ToString()); + throw new OnnxRuntimeException(ErrorCode.Fail, "TypeInfo cast to TensorTypeInfo failed. The object does not represent a tensor"); } - UIntPtr numDimensions; - NativeApiStatus.VerifySuccess(NativeMethods.OrtGetDimensionsCount(tensorInfo, out numDimensions)); + TensorElementType type; + { + NativeApiStatus.VerifySuccess(NativeMethods.OrtGetTensorElementType(tensorInfo, out IntPtr el_type)); + type = (TensorElementType)el_type; + } + + NativeApiStatus.VerifySuccess(NativeMethods.OrtGetDimensionsCount(tensorInfo, out UIntPtr numDimensions)); long[] dimensions = new long[(int)numDimensions]; NativeApiStatus.VerifySuccess(NativeMethods.OrtGetDimensions(tensorInfo, dimensions, numDimensions)); @@ -1028,7 +1210,8 @@ internal static NodeMetadata GetMetadataFromTypeInfo(IntPtr typeInfo) symbolicDimensions[i] = NativeOnnxValueHelper.StringFromNativeUtf8(dimensionNamePtrs[i]); } - return new NodeMetadata(valueType, intDimensions, symbolicDimensions, dotnetType); + var tensorTypeAndShape = new TensorTypeAndShape(type, intDimensions, symbolicDimensions); + return new NodeMetadata(valueType, tensorTypeAndShape); } /// @@ -1105,23 +1288,27 @@ protected virtual void Dispose(bool disposing) /// - /// Resembles type and shape information of session-graph nodes, used for communicating the shape/type of input/output nodes + /// Represents tensor element type and its shapes /// - public class NodeMetadata + public class TensorTypeAndShape { - internal NodeMetadata(OnnxValueType onnxValueType, int[] dimensions, string[] symbolicDimensions, Type type) + internal TensorTypeAndShape(TensorElementType elementType, int[] dimensions, string[] symbolicDimensions) { - OnnxValueType = onnxValueType; + ElementTypeInfo = TensorBase.GetElementTypeInfo(elementType); + if (ElementTypeInfo == null) + { + throw new OnnxRuntimeException(ErrorCode.InvalidArgument, "Unregistered TensorElementType value of: " + elementType.ToString()); + } + ElementDataType = elementType; Dimensions = dimensions; SymbolicDimensions = symbolicDimensions; - ElementType = type; } /// - /// Type value of the node + /// Tensor Element type /// - /// A value of OnnxValueType enum - public OnnxValueType OnnxValueType { get; } + /// TensorElementType enum + public TensorElementType ElementDataType { get; } /// /// Shape @@ -1136,10 +1323,259 @@ internal NodeMetadata(OnnxValueType onnxValueType, int[] dimensions, string[] sy public string[] SymbolicDimensions { get; } /// - /// .NET type that corresponds to this Node. + /// Tensor element metadata + /// + public TensorElementTypeInfo ElementTypeInfo { get; } + } + + /// + /// Represents sequnce metdata + /// + public class SequenceMetadata + { + /// + /// __ctor + /// + /// + internal SequenceMetadata(NodeMetadata elementData) + { + ElementMeta = elementData; + } + /// + /// Element Metatada, recursive definition with a Tensor being a base case + /// may contain maps, tensors and other sequences + /// + public NodeMetadata ElementMeta { get; } + } + + /// + /// The class contains metadata for an optional input/output + /// + public class OptionalMetadata + { + /// + /// __ctor + /// + /// + internal OptionalMetadata(NodeMetadata elementData) + { + ElementMeta = elementData; + } + + /// + /// Element Metatada, recursive definition with a Tensor being a base case + /// may contain maps, tensors and sequences + /// + public NodeMetadata ElementMeta { get; } + } + + /// + /// Represents Map MetaData. + /// Key is always a tensor denoted by an element type + /// with value type being a recursive structure that may + /// contain other maps, sequences or tensors. + /// + public class MapMetadata + { + internal MapMetadata(TensorElementType keyDataType, NodeMetadata valueMetadata) + { + KeyDataType = keyDataType; + ValueMetadata = valueMetadata; + } + + /// + /// Key tensor data type + /// + /// A value of TensorElementType enum + public TensorElementType KeyDataType { get; } + + /// + /// Value metadata + /// + /// /// Instance of Nodemetadata for the value of the map + public NodeMetadata ValueMetadata { get; } + } + + /// + /// Resembles type and shape information of session-graph nodes, used for communicating the shape/type of input/output nodes + /// + public class NodeMetadata + { + private readonly Object _metadata; + /// + /// Constructs NodeMetadata for tensor + /// + /// either ONNX_TYPE_TENSOR or ONNX_TYPE_SPARSETENSOR + /// Tensor type and shape information + internal NodeMetadata(OnnxValueType onnxValueType, TensorTypeAndShape typeAndShape) + { + OnnxValueType = onnxValueType; + CheckTensor(); + _metadata = typeAndShape; + } + + /// + /// __ctor for map metadata + /// + /// + internal NodeMetadata(MapMetadata mapMetadata) + { + OnnxValueType = OnnxValueType.ONNX_TYPE_MAP; + _metadata = mapMetadata; + } + + /// + /// __ctor for sequence metadata + /// + /// + internal NodeMetadata(SequenceMetadata sequenceMetadata) + { + OnnxValueType = OnnxValueType.ONNX_TYPE_SEQUENCE; + _metadata = sequenceMetadata; + } + + /// + /// __ctor + /// + /// + internal NodeMetadata(OptionalMetadata optMetadata) + { + OnnxValueType = OnnxValueType.ONNX_TYPE_OPTIONAL; + _metadata = optMetadata; + } + + private void CheckTensor() + { + if (!IsTensor) + { + throw new OnnxRuntimeException(ErrorCode.Fail, "OnnxValueType must either be a tensor or sparse tensor"); + } + } + + /// + /// Retrieves MapMetadata, valid only if this node represents a Map. + /// + /// + /// when the instance does not contain map metadata + public MapMetadata AsMapMetadata() + { + if (OnnxValueType != OnnxValueType.ONNX_TYPE_MAP) + { + throw new OnnxRuntimeException(ErrorCode.Fail, "Instance does not contain Map metadata"); + } + return _metadata as MapMetadata; + } + + /// + /// Retrieves SequenceMetadata, valid only if this node represents a Sequence + /// + /// + /// when the instance does not contain sequence metadata + public SequenceMetadata AsSequenceMetadata() + { + if (OnnxValueType != OnnxValueType.ONNX_TYPE_SEQUENCE) + { + throw new OnnxRuntimeException(ErrorCode.Fail, "Instance does not contain Sequence metadata"); + } + return _metadata as SequenceMetadata; + } + + /// + /// Retrieves Optional type metadata, valid if this node is optional + /// Optional metadata is nothing more than just a container for all the usual + /// element types. + /// + /// + /// + public OptionalMetadata AsOptionalMetadata() + { + if (OnnxValueType != OnnxValueType.ONNX_TYPE_OPTIONAL) + { + throw new OnnxRuntimeException(ErrorCode.Fail, "Instance does not contain Optional metadata"); + } + return _metadata as OptionalMetadata; + } + + /// + /// Type value of the node + /// + /// A value of OnnxValueType enum + public OnnxValueType OnnxValueType { get; } + + /// + /// Zero terminated UTF-8 name of the input/output + /// Present only on the top-level instance + /// metadata dictionary entries. + /// + /// Used to avoid utf8 conversions on every run and associated allocations + /// + internal byte[] ZeroTerminatedUtf8Name { get; set; } + + /// + /// Tensor shape valid only if this is a Tensor. + /// Preserved for API compatibility + /// + /// Array of dimensions + public int[] Dimensions + { + get + { + CheckTensor(); + return (_metadata as TensorTypeAndShape).Dimensions; + } + } + + /// + /// Symbolic dimensions valid only if this is a Tensor. + /// Preserved for API compatibility + /// + /// Array of symbolic dimensions if present. + public string[] SymbolicDimensions + { + get + { + CheckTensor(); + return (_metadata as TensorTypeAndShape).SymbolicDimensions; + } + } + + /// + /// .NET type that corresponds to the primitive Tensor data type. + /// Valid only if this is a Tensor. /// /// System.Type - public System.Type ElementType { get; } + public System.Type ElementType + { + get + { + CheckTensor(); + return (_metadata as TensorTypeAndShape).ElementTypeInfo.TensorType; + } + } + + /// + /// Tensor Element Type. Valid if tensor + /// + public TensorElementType ElementDataType + { + get + { + CheckTensor(); + return (_metadata as TensorTypeAndShape).ElementDataType; + } + } + + /// + /// Convinience method to check for string + /// + public bool IsString + { + get + { + CheckTensor(); + return (_metadata as TensorTypeAndShape).ElementTypeInfo.IsString; + } + } /// /// Whether it is a Tensor @@ -1149,7 +1585,7 @@ public bool IsTensor { get { - return true; // currently only Tensor nodes are supported + return (OnnxValueType == OnnxValueType.ONNX_TYPE_TENSOR) || (OnnxValueType == OnnxValueType.ONNX_TYPE_SPARSETENSOR); } } } diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/ManagedProjections.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/ManagedProjections.shared.cs new file mode 100644 index 0000000000000..664ba21cfd1bb --- /dev/null +++ b/csharp/src/Microsoft.ML.OnnxRuntime/ManagedProjections.shared.cs @@ -0,0 +1,277 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using Microsoft.ML.OnnxRuntime.Tensors; +using System; +using System.Buffers; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; + +namespace Microsoft.ML.OnnxRuntime +{ + /// + /// The class helps to feed the NamedOnnxValue as inference input. + /// It projects managed classes to OrtValues so they can be consumed + /// by the native onnxruntime library. if possible, it will avoid copying data. + /// The NamedOnnxValue can be a tensor, sequence or map. + /// For recursive structures, create nested NamedOnnxValue instances. + /// For example, a sequence instance would contain a list of NamedOnnxValue instances + /// that in turn may represent tensors or other ONNX values. + /// + internal class ManagedTypeProjection : IDisposable + { + readonly DisposableList _disposables; + readonly OrtValue _ortValue; + bool _disposed = false; + + /// + /// Provides access to non-owning instance of OrtValue + /// + /// Provides access to the OrtValue to be used as input + internal OrtValue Value { get { return new OrtValue(_ortValue.Handle, false); } } + + /// + /// Constructor to create an input OrtValue projection from managed data + /// + /// + /// + /// + internal ManagedTypeProjection(NamedOnnxValue namedOnnxValue, NodeMetadata metadata) + { + int requiredCapacity = 32; + var disposables = new DisposableList(requiredCapacity); + try + { + _ortValue = CreateDispatchProjection(namedOnnxValue, metadata, disposables); + } + catch (Exception) + { + disposables.Dispose(); + throw; + } + _disposables = disposables; + } + + /// + /// Dispatches the creation of the projection + /// + /// + /// + /// + /// + private OrtValue CreateDispatchProjection(NamedOnnxValue namedOnnxValue, NodeMetadata metadata, DisposableList disposables) + { + OrtValue result; + + NodeMetadata meta = metadata; + // Use element meta to create types + if (metadata.OnnxValueType == OnnxValueType.ONNX_TYPE_OPTIONAL) + { + meta = metadata.AsOptionalMetadata().ElementMeta; + } + + if (namedOnnxValue.ValueType != meta.OnnxValueType) + { + throw new OnnxRuntimeException(ErrorCode.RuntimeException, + $"NamedOnnxValue: {namedOnnxValue.Name} has value type: {namedOnnxValue.ValueType}" + + $" expected: {meta.OnnxValueType} after optional type adjustment"); + } + + switch (namedOnnxValue.ValueType) + { + case OnnxValueType.ONNX_TYPE_TENSOR: + result = CreateTensorProjection(namedOnnxValue, meta, disposables); + break; + case OnnxValueType.ONNX_TYPE_SEQUENCE: + result = CreateSequenceProjection(namedOnnxValue, meta, disposables); + break; + case OnnxValueType.ONNX_TYPE_MAP: + result = CreateMapProjection(namedOnnxValue, meta, disposables); + break; + default: + throw new OnnxRuntimeException(ErrorCode.InvalidArgument, "ManagedTypeProjection can only project tensors, sequences, maps and optional types"); + } + return result; + } + + /// + /// The function creates OrtValue objects for each element of the sequence + /// and then creates an OrtValue for the whole sequence. + /// + /// NamedOnnxValue containing a IEnumeralbe + /// sequence metadata + /// cleanup list + /// + /// + private OrtValue CreateSequenceProjection(NamedOnnxValue namedOnnxValue, NodeMetadata metadata, DisposableList disposables) + { + OrtValue result = null; + var elementMeta = metadata.AsSequenceMetadata().ElementMeta; + var elementOnnxValue = elementMeta.OnnxValueType; + var seqContainer = namedOnnxValue.AsEnumerable(); + + if (seqContainer is null) + { + throw new OnnxRuntimeException(ErrorCode.InvalidArgument, + $"NamedOnnxValue: {namedOnnxValue.Name} sequence does not contain NamedOnnxValue elements"); + } + + int capacity = 0; + + if (seqContainer is ICollection) + { + capacity = ((ICollection)seqContainer).Count; + } + + // Record all the ortValues belonging to the sequence locally + var sequenceOrtValues = new List(capacity); + foreach (var element in seqContainer) + { + if (elementOnnxValue != element.ValueType) + { + throw new OnnxRuntimeException(ErrorCode.InvalidArgument, + $"NamedOnnxValue: {namedOnnxValue.Name} sequence element expected to be {elementOnnxValue}, received {element.ValueType}"); + } + + sequenceOrtValues.Add(CreateDispatchProjection(element, elementMeta, disposables)); + } + + IntPtr[] ortValHandles = new IntPtr[sequenceOrtValues.Count]; + for (int i = 0; i < sequenceOrtValues.Count; i++) + { + ortValHandles[i] = sequenceOrtValues[i].Handle; + } + + NativeApiStatus.VerifySuccess(NativeMethods.OrtCreateValue(ortValHandles, + (UIntPtr)sequenceOrtValues.Count, (IntPtr)OnnxValueType.ONNX_TYPE_SEQUENCE, out IntPtr sequenceHandle)); + result = new OrtValue(sequenceHandle); + disposables.Add(result); + + return result; + } + + /// + /// Creates map projection. Since we support only primitive types in maps + /// we map two tensors (keys and values) + /// + /// + /// + /// + /// OrtValue + /// + private OrtValue CreateMapProjection(NamedOnnxValue node, NodeMetadata elementMeta, DisposableList disposables) + { + OrtValue result = null; + var mapMeta = elementMeta.AsMapMetadata(); + Debug.Assert(mapMeta != null); + // Maps currently support only primitive types expressed as two parallel tensors and not nested Sequences or Maps + + var mapValuesMeta = mapMeta.ValueMetadata; + if (mapValuesMeta.OnnxValueType != OnnxValueType.ONNX_TYPE_TENSOR) + { + throw new OnnxRuntimeException(ErrorCode.InvalidArgument, + $"Node: {node.Name} onnxruntime only supports maps with primitive types values"); + } + + + var keys = node.GetDictionaryKeys(); + var ortValueKeys = OrtValue.CreateFromTensorObject(keys, + out MemoryHandle? memoryHandleKeys, out TensorElementType elementTypeKeys); + disposables.Add(ortValueKeys); + + if (memoryHandleKeys.HasValue) + { + disposables.Add(memoryHandleKeys); + } + + if (elementTypeKeys != mapMeta.KeyDataType) + { + throw new OnnxRuntimeException(ErrorCode.InvalidArgument, + $"Map key data type supplied: {elementTypeKeys} metadata expected: {mapMeta.KeyDataType}"); + } + + var values = node.GetDictionaryValues(); + var ortValueValues = OrtValue.CreateFromTensorObject(values, + out MemoryHandle? memoryHandleValues, out TensorElementType elementTypeValues); + + disposables.Add(ortValueValues); + if (memoryHandleValues.HasValue) + { + disposables.Add(memoryHandleValues); + } + + if (elementTypeValues != mapValuesMeta.ElementDataType) + { + throw new OnnxRuntimeException(ErrorCode.InvalidArgument, + $"Map value data type supplied: {elementTypeValues} metadata expected: {mapValuesMeta.ElementDataType}"); + } + + // Create Map OrtValue + IntPtr[] ortValHandles = { ortValueKeys.Handle, ortValueValues.Handle }; + NativeApiStatus.VerifySuccess(NativeMethods.OrtCreateValue(ortValHandles, (UIntPtr)2, + (IntPtr)OnnxValueType.ONNX_TYPE_MAP, out IntPtr ortValueMap)); + result = new OrtValue(ortValueMap); + disposables.Add(result); + return result; + } + + + /// + /// This pins memory that is contained within DenseTensor. + /// + /// NodeOnnxValue containing DenseTensor + /// + /// cleanup list + /// + /// + private OrtValue CreateTensorProjection(NamedOnnxValue node, NodeMetadata elementMeta, DisposableList disposables) + { + var ortValue = OrtValue.CreateFromTensorObject(node.Value, + out MemoryHandle? memoryHandle, out TensorElementType elementType); + disposables.Add(ortValue); + + if (memoryHandle.HasValue) + { + disposables.Add(memoryHandle); + } + + if (elementType != elementMeta.ElementDataType) + { + throw new OnnxRuntimeException(ErrorCode.InvalidArgument, + $"Tensor element data type discovered: {elementType} metadata expected: {elementMeta.ElementDataType}"); + } + + return ortValue; + } + + #region IDisposable + /// + /// IDisposable implementation + /// + /// true if invoked by Dispose() + protected virtual void Dispose(bool disposing) + { + if (_disposed) + { + return; + } + + // dispose managed state (managed objects). + if (disposing) + { + _disposables.Dispose(); + } + _disposed = true; + } + + + public void Dispose() + { + Dispose(true); + } + + #endregion IDisposable + } +} + diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/NamedOnnxValue.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/NamedOnnxValue.shared.cs index 233e5fa9af87e..5f08daf73806a 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/NamedOnnxValue.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/NamedOnnxValue.shared.cs @@ -5,38 +5,125 @@ using System; using System.Buffers; using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; namespace Microsoft.ML.OnnxRuntime { /// - /// The class associates a name with an Object. Currently it supports Tensor - /// as possible objects. The name of the class is a misnomer, it does not hold any - /// Onnx values. + /// The class holds keys and values for the dictionary + /// in a for of two DenseTensors. The class is used to avoid + /// data copy and make these available to the native code. + /// Strings require special handling. + /// + internal class MapHelper + { + internal MapHelper(object keys, object values) + { + Keys = keys; + Values = values; + } + internal Object Keys { get; } // DenseTensor + internal Object Values { get; } // DenseTensor + } + + /// + /// The class associates a name with an Object. + /// The name of the class is a misnomer, it does not hold any Onnx values, + /// just managed representation of them. + /// + /// The class is currently used as both inputs and outputs. Because it is non- + /// disposable, it can not hold on to any native objects. + /// + /// When used as input, we temporarily create OrtValues that map managed inputs + /// directly. Thus we are able to avoid copying. + /// + /// For outputs, tensor buffers works the same as input, providing it matches + /// the expected output shape. For other types (maps and sequences) we create a copy of the data. + /// This is because, the class is not Disposable and it is a public interface, thus it can not own + /// the underlying OrtValues that must be destroyed before Run() returns. + /// + /// To avoid data copying on output, use DisposableNamedOnnxValue class that is returned from Run() methods. + /// This provides access to the native memory and avoids copying. + /// + /// It is a recursive structure that may contain Tensors (base case) + /// Other sequences and maps. Although the OnnxValueType is exposed, + /// the caller is supposed to know the actual data type contained. + /// + /// The convention is that for tensors, it would contain a DenseTensor instance or + /// anything derived from Tensor. + /// + /// For sequences, it would contain a IList where T is an instance of NamedOnnxValue that + /// would contain a tensor or another type. + /// + /// For Maps, it would contain a IDictionary where K,V are primitive types or strings. + /// /// public class NamedOnnxValue { /// /// Managed Tensor, Dictionary or IList /// - protected Object _value; + private Object _value; /// /// Name of the instance, model input/output /// - protected string _name; + private string _name; + + private MapHelper _mapHelper; // used for maps, otherwise null /// /// Constructs an instance of NamedOnnxValue and represents - /// a model input to an inference session. It also represents a modle output - /// when serves as a base for DisposablenamedOnnxvalue + /// a model input to an inference session. /// /// input/output name /// Object that may be a tensor, Dictionary, IList + [Obsolete("Use constructors with valueType or static factory methods")] protected NamedOnnxValue(string name, Object value) { _name = name; _value = value; + ValueType = OnnxValueType.ONNX_TYPE_UNKNOWN; } + /// + /// Constructs an instance that contains a tensor, sequence or optional type. + /// + /// + /// + /// + internal NamedOnnxValue(string name, Object value, OnnxValueType valueType) + { + _name = name; + _value = value; + ValueType = valueType; + + if (valueType == OnnxValueType.ONNX_TYPE_MAP) + { + throw new OnnxRuntimeException(ErrorCode.InvalidArgument, "Use another __ctor for maps"); + } + } + + /// + /// Use this to construct maps + /// + /// + /// + /// + internal NamedOnnxValue(string name, Object value, MapHelper helper) + { + _name = name; + _value = value; + ValueType = OnnxValueType.ONNX_TYPE_MAP; + _mapHelper = helper; + } + + /// + /// Onnx Value Type if known. In general, NamedOnnxValue is able to contain + /// arbitrary objects. Please, follow the convention described in the class doc. + /// + public OnnxValueType ValueType { get; internal set; } + /// /// This is a factory method that instantiates NamedOnnxValue /// and associated name with an instance of a Tensor @@ -47,7 +134,40 @@ protected NamedOnnxValue(string name, Object value) /// public static NamedOnnxValue CreateFromTensor(string name, Tensor value) { - return new NamedOnnxValue(name, value); + return new NamedOnnxValue(name, value, OnnxValueType.ONNX_TYPE_TENSOR); + } + + /// + /// This is a factory method that instantiates NamedOnnxValue. + /// It would contain a sequence of elements + /// + /// + /// + /// + public static NamedOnnxValue CreateFromSequence(string name, IEnumerable value) + { + return new NamedOnnxValue(name, value, OnnxValueType.ONNX_TYPE_SEQUENCE); + } + + /// + /// Instantiates NamedOnnxValue that contains IDictionary + /// + /// Keys type + /// Values type + /// + /// + /// new instance of NamedOnnxValue + public static NamedOnnxValue CreateFromMap(string name, IDictionary value) + { + // The order in which Keys and Values are unspecified, + // but it is guaranteed to be the same order + // These tensors are 1-D + var keysMemory = new Memory(value.Keys.ToArray()); + var keysTensor = new DenseTensor(keysMemory, new int[1] { keysMemory.Length }); + + var valuesMemory = new Memory(value.Values.ToArray()); + var valuesTensor = new DenseTensor(valuesMemory, new int[1] { valuesMemory.Length }); + return new NamedOnnxValue(name, value, new MapHelper(keysTensor, valuesTensor)); } /// @@ -73,6 +193,8 @@ public Tensor AsTensor() /// /// Try-get value as an Enumerable<T>. + /// T is usually a NamedOnnxValue instance that may contain + /// Tensors, Sequences, Maps or optional types /// /// Type /// Enumerable object if contained value is a Enumerable. Null otherwise @@ -85,8 +207,8 @@ public IEnumerable AsEnumerable() /// /// Try-get value as an Dictionary<K,V>. /// - /// Key type - /// Value type + /// Key type currently primitive type only + /// Value type, currently primitive type only /// Dictionary object if contained value is a Dictionary. Null otherwise public IDictionary AsDictionary() { @@ -94,15 +216,88 @@ public IDictionary AsDictionary() } /// - /// Pin the underlying memory and create an instance of OrtValue + /// Pin the underlying memory and create an instance of OrtValue containing a tensor /// based on the pinned managed memory. The caller is responsible for Disposing /// both OrtValue and pinnedMemoryHandle /// /// dispose after returned OrtValus is disposed /// an instance of OrtValue. The lifespan of OrtValue must overlap pinnedMemoryHandle - internal virtual OrtValue ToOrtValue(out MemoryHandle? pinnedMemoryHandle) + internal virtual OrtValue InputToOrtValue(NodeMetadata metadata, out IDisposable memoryOwner) + { + var projection = new ManagedTypeProjection(this, metadata); + memoryOwner = projection; + return projection.Value; + } + + /// + /// Produces an output value for outputs. This produces an output value + /// only for tensors or optional types that can contain a tensor. + /// For all other Onnx value types, this method throws. Use Run() overloads + /// that return DisposableNamedOnnxValue to get access to all Onnx value types + /// that may be returned as output. + /// + /// + /// + /// + internal virtual OrtValue OutputToOrtValue(NodeMetadata metadata, out IDisposable memoryOwner) { - return OrtValue.CreateFromTensorObject(_value, out pinnedMemoryHandle, out TensorElementType elementType); + // For NamedOnnxValue for output we only allow to produce OrtValue for tensors + // or optional type that may contain a tensor + if (metadata.OnnxValueType == OnnxValueType.ONNX_TYPE_TENSOR) + { + var projection = new ManagedTypeProjection(this, metadata); + memoryOwner = projection; + return projection.Value; + } + + if (metadata.OnnxValueType == OnnxValueType.ONNX_TYPE_OPTIONAL) + { + var meta = metadata.AsOptionalMetadata().ElementMeta; + if (meta.OnnxValueType == OnnxValueType.ONNX_TYPE_TENSOR) + { + var projection = new ManagedTypeProjection(this, meta); + memoryOwner = projection; + return projection.Value; + } + } + + throw new OnnxRuntimeException(ErrorCode.NotImplemented, + $"Can not create output OrtValue for NamedOnnxValue '{metadata.OnnxValueType}' type." + + $" Only tensors can be pre-allocated for outputs " + + $" Use Run() overloads that return DisposableNamedOnnxValue to get access to all Onnx value types that may be returned as output."); + } + + /// + /// This method is used internally to feed dictionary keys + /// to create an OrtValue for map keys + /// + /// + /// DenseTensor" + internal Object GetDictionaryKeys() + { + if (ValueType != OnnxValueType.ONNX_TYPE_MAP) + { + throw new OnnxRuntimeException(ErrorCode.Fail, "This NamedOnnxValue instance does not contain a dictionary"); + } + + Debug.Assert(_mapHelper != null); + return _mapHelper.Keys; + } + + /// + /// + /// + /// + /// DenseTensor" + internal Object GetDictionaryValues() + { + if (ValueType != OnnxValueType.ONNX_TYPE_MAP) + { + throw new OnnxRuntimeException(ErrorCode.Fail, "This NamedOnnxValue instance does not contain a dictionary"); + } + + Debug.Assert(_mapHelper != null); + return _mapHelper.Values; } // may expose different types of getters in future diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs index 57ae4c2905316..c92db2afd4845 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs @@ -261,6 +261,31 @@ public struct OrtApi public IntPtr ReleaseCANNProviderOptions; public IntPtr MemoryInfoGetDeviceType; public IntPtr UpdateEnvWithCustomLogLevel; + public IntPtr SetGlobalIntraOpThreadAffinity; + public IntPtr RegisterCustomOpsLibrary_V2; + public IntPtr RegisterCustomOpsUsingFunction; + public IntPtr KernelInfo_GetInputCount; + public IntPtr KernelInfo_GetOutputCount; + public IntPtr KernelInfo_GetInputName; + public IntPtr KernelInfo_GetOutputName; + public IntPtr KernelInfo_GetInputTypeInfo; + public IntPtr KernelInfo_GetOutputTypeInfo; + public IntPtr KernelInfoGetAttribute_tensor; + public IntPtr HasSessionConfigEntry; + public IntPtr GetSessionConfigEntry; + public IntPtr SessionOptionsAppendExecutionProvider_Dnnl; + public IntPtr CreateDnnlProviderOptions; + public IntPtr UpdateDnnlProviderOptions; + public IntPtr GetDnnlProviderOptionsAsString; + public IntPtr ReleaseDnnlProviderOptions; + public IntPtr KernelInfo_GetNodeName; + public IntPtr KernelInfo_GetLogger; + public IntPtr KernelContext_GetLogger; + public IntPtr Logger_LogMessage; + public IntPtr Logger_GetLoggingSeverityLevel; + public IntPtr KernelInfoGetConstantInput_tensor; + public IntPtr CastTypeInfoToOptionalTypeInfo; + public IntPtr GetOptionalContainedTypeInfo; } internal static class NativeMethods @@ -387,9 +412,10 @@ static NativeMethods() OrtSetLanguageProjection = (DOrtSetLanguageProjection)Marshal.GetDelegateForFunctionPointer(api_.SetLanguageProjection, typeof(DOrtSetLanguageProjection)); OrtGetValue = (DOrtGetValue)Marshal.GetDelegateForFunctionPointer(api_.GetValue, typeof(DOrtGetValue)); + OrtGetValueCount = (DOrtGetValueCount)Marshal.GetDelegateForFunctionPointer(api_.GetValueCount, typeof(DOrtGetValueCount)); + OrtCreateValue = (DOrtCreateValue)Marshal.GetDelegateForFunctionPointer(api_.CreateValue, typeof(DOrtCreateValue)); OrtGetValueType = (DOrtGetValueType)Marshal.GetDelegateForFunctionPointer(api_.GetValueType, typeof(DOrtGetValueType)); OrtGetOnnxTypeFromTypeInfo = (DOrtGetOnnxTypeFromTypeInfo)Marshal.GetDelegateForFunctionPointer(api_.GetOnnxTypeFromTypeInfo, typeof(DOrtGetOnnxTypeFromTypeInfo)); - OrtGetValueCount = (DOrtGetValueCount)Marshal.GetDelegateForFunctionPointer(api_.GetValueCount, typeof(DOrtGetValueCount)); OrtGetTypeInfo = (DOrtGetTypeInfo)Marshal.GetDelegateForFunctionPointer(api_.GetTypeInfo, typeof(DOrtGetTypeInfo)); OrtCreateTensorAsOrtValue = (DOrtCreateTensorAsOrtValue)Marshal.GetDelegateForFunctionPointer(api_.CreateTensorAsOrtValue, typeof(DOrtCreateTensorAsOrtValue)); OrtCreateTensorWithDataAsOrtValue = (DOrtCreateTensorWithDataAsOrtValue)Marshal.GetDelegateForFunctionPointer(api_.CreateTensorWithDataAsOrtValue, typeof(DOrtCreateTensorWithDataAsOrtValue)); @@ -405,6 +431,16 @@ static NativeMethods() OrtGetDimensions = (DOrtGetDimensions)Marshal.GetDelegateForFunctionPointer(api_.GetDimensions, typeof(DOrtGetDimensions)); OrtGetSymbolicDimensions = (DOrtGetSymbolicDimensions)Marshal.GetDelegateForFunctionPointer(api_.GetSymbolicDimensions, typeof(DOrtGetSymbolicDimensions)); OrtGetTensorShapeElementCount = (DOrtGetTensorShapeElementCount)Marshal.GetDelegateForFunctionPointer(api_.GetTensorShapeElementCount, typeof(DOrtGetTensorShapeElementCount)); + // MapTypeInfo + OrtGetMapKeyType = (DGetMapKeyType)Marshal.GetDelegateForFunctionPointer(api_.GetMapKeyType, typeof(DGetMapKeyType)); + OrtCastTypeInfoToMapTypeInfo = (DCastTypeInfoToMapTypeInfo)Marshal.GetDelegateForFunctionPointer(api_.CastTypeInfoToMapTypeInfo, typeof(DCastTypeInfoToMapTypeInfo)); + OrtGetMapValueType = (DGetMapValueType)Marshal.GetDelegateForFunctionPointer(api_.GetMapValueType, typeof(DGetMapValueType)); + // SequenceTypeInfo + OrtCastTypeInfoToSequenceTypeInfo = (DCastTypeInfoToSequenceTypeInfo)Marshal.GetDelegateForFunctionPointer(api_.CastTypeInfoToSequenceTypeInfo, typeof(DCastTypeInfoToSequenceTypeInfo)); + OrtGetSequenceElementType = (DGetSequenceElementType)Marshal.GetDelegateForFunctionPointer(api_.GetSequenceElementType, typeof(DGetSequenceElementType)); + // Optional Type info + OrtCastTypeInfoToOptionalTypeInfo = (DOrtCastTypeInfoToOptionalTypeInfo)Marshal.GetDelegateForFunctionPointer(api_.CastTypeInfoToOptionalTypeInfo, typeof(DOrtCastTypeInfoToOptionalTypeInfo)); + OrtGetOptionalContainedTypeInfo = (DGetOptionalContainedTypeInfo)Marshal.GetDelegateForFunctionPointer(api_.GetOptionalContainedTypeInfo, typeof(DGetOptionalContainedTypeInfo)); OrtReleaseValue = (DOrtReleaseValue)Marshal.GetDelegateForFunctionPointer(api_.ReleaseValue, typeof(DOrtReleaseValue)); OrtSessionGetModelMetadata = (DOrtSessionGetModelMetadata)Marshal.GetDelegateForFunctionPointer(api_.SessionGetModelMetadata, typeof(DOrtSessionGetModelMetadata)); @@ -1571,6 +1607,12 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca public static DOrtGetValueCount OrtGetValueCount; + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr/*(OrtStatus*)*/ DOrtCreateValue(IntPtr[] /* const OrtValue* const* in */ values, + UIntPtr /* size_t */ num_values, IntPtr /* (OnnxValueType */ onnxValueType, out IntPtr /* OrtValue** */ ortValue); + + public static DOrtCreateValue OrtCreateValue; + [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /*(OrtStatus*)*/ DOrtGetTypeInfo(IntPtr /*(OrtValue*)*/ value, IntPtr /*(OrtValue**)*/ typeInfo); @@ -1698,6 +1740,41 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca public static DOrtGetTensorShapeElementCount OrtGetTensorShapeElementCount; + /// Map Type API + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /*(OrtStatus*)*/ DCastTypeInfoToMapTypeInfo(IntPtr /*(const struct OrtTypeInfo*)*/ typeInfo, out IntPtr /*const OrtMapTypeInfo** */ mapTypeInfo); + + public static DCastTypeInfoToMapTypeInfo OrtCastTypeInfoToMapTypeInfo; + + public delegate IntPtr /*(OrtStatus*)*/ DGetMapKeyType(IntPtr /*const OrtMapTypeInfo* */ mapTypeInfo, out IntPtr /*(TensorElementType*)*/ tensorElementType); + + public static DGetMapKeyType OrtGetMapKeyType; + + public delegate IntPtr /*(OrtStatus*)*/ DGetMapValueType(IntPtr /* const OrtMapTypeInfo* */ map_type_info, out IntPtr /* OrtTypeInfo** */ type_info); + + public static DGetMapValueType OrtGetMapValueType; + + // Sequence TypeInfo + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /*(OrtStatus*)*/ DCastTypeInfoToSequenceTypeInfo(IntPtr /*(struct OrtTypeInfo*)*/ typeInfo, out IntPtr /* const OrtSequenceTypeInfo** */ sequenceTypeInfo); + + public static DCastTypeInfoToSequenceTypeInfo OrtCastTypeInfoToSequenceTypeInfo; + + public delegate IntPtr /*(OrtStatus*)*/ DGetSequenceElementType(IntPtr /* const OrtSequenceTypeInfo* */ sequenceTypeInfo, out IntPtr /* OrtTypeInfo** */ elementTypeInfo); + + public static DGetSequenceElementType OrtGetSequenceElementType; + + // OptionalTypeInfo + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /*(OrtStatus*)*/ DOrtCastTypeInfoToOptionalTypeInfo(IntPtr /*(struct OrtTypeInfo*)*/ typeInfo, out IntPtr /* const struct OrtOptionalTypeInfo** */ optionalTypeInfo); + + public static DOrtCastTypeInfoToOptionalTypeInfo OrtCastTypeInfoToOptionalTypeInfo; + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /*(OrtStatus*)*/ DGetOptionalContainedTypeInfo(IntPtr /* const struct OrtOptionalTypeInfo*/ optTypeInfo, out IntPtr /* struct OrtTypeInfo** */ containedTypeInfo); + + public static DGetOptionalContainedTypeInfo OrtGetOptionalContainedTypeInfo; + [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate void DOrtReleaseValue(IntPtr /*(OrtValue*)*/ value); diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/NativeOnnxValueHelper.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/NativeOnnxValueHelper.shared.cs index feff70782f834..704eb4d14222f 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/NativeOnnxValueHelper.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/NativeOnnxValueHelper.shared.cs @@ -43,7 +43,7 @@ public void Dispose() // No need for the finalizer // If this is not disposed timely GC can't help us #endregion - } + } /// /// This helper class contains methods to create native OrtValue from a managed value object @@ -58,8 +58,12 @@ internal static class NativeOnnxValueHelper /// UTF-8 encoded equivalent internal static byte[] StringToZeroTerminatedUtf8(string s) { - byte[] utf8Bytes = UTF8Encoding.UTF8.GetBytes(s); - Array.Resize(ref utf8Bytes, utf8Bytes.Length + 1); + int arraySize = UTF8Encoding.UTF8.GetByteCount(s); + byte[] utf8Bytes = new byte[arraySize + 1]; + if (arraySize != UTF8Encoding.UTF8.GetBytes(s, 0, s.Length, utf8Bytes, 0)) + { + throw new OnnxRuntimeException(ErrorCode.RuntimeException, "Failed to convert to UTF8"); + } utf8Bytes[utf8Bytes.Length - 1] = 0; return utf8Bytes; } @@ -72,7 +76,7 @@ internal static byte[] StringToZeroTerminatedUtf8(string s) /// internal static string StringFromNativeUtf8(IntPtr nativeUtf8) { - // .NET 5.0 has Marshal.PtrToStringUTF8 that does the below + // .NET 8.0 has Marshal.PtrToStringUTF8 that does the below int len = 0; while (Marshal.ReadByte(nativeUtf8, len) != 0) ++len; byte[] buffer = new byte[len]; @@ -80,6 +84,33 @@ internal static string StringFromNativeUtf8(IntPtr nativeUtf8) return Encoding.UTF8.GetString(buffer, 0, buffer.Length); } + /// + /// Reads UTF-8 string from native C zero terminated string, + /// converts it to C# UTF-16 string and returns both C# string and utf-8 + /// bytes as a zero terminated array, suitable for use as a C-string + /// + /// input + /// C# UTF-16 string + /// UTF-8 bytes in a managed buffer, zero terminated + internal static void StringAndUtf8FromNative(IntPtr nativeUtf8, out string str, out byte[] utf8) + { + // .NET 8.0 has Marshal.PtrToStringUTF8 that does the below + int len = 0; + while (Marshal.ReadByte(nativeUtf8, len) != 0) ++len; + utf8 = new byte[len + 1]; + Marshal.Copy(nativeUtf8, utf8, 0, len); + utf8[len] = 0; + str = Encoding.UTF8.GetString(utf8, 0, len); + } + + internal static string StringFromUtf8Span(ReadOnlySpan utf8Span) + { + // XXX: For now we have to copy into byte[], this produces a copy + // Converting from span is available in later versions + var utf8Bytes = utf8Span.ToArray(); + return Encoding.UTF8.GetString(utf8Bytes, 0, utf8Bytes.Length); + } + /// /// Run helper /// @@ -126,7 +157,7 @@ public static bool GetTypeAndWidth(TensorElementType elemType, out Type type, ou { bool result = true; TensorElementTypeInfo typeInfo = TensorBase.GetElementTypeInfo(elemType); - if(typeInfo != null) + if (typeInfo != null) { type = typeInfo.TensorType; width = typeInfo.TypeSize; diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/OrtValue.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/OrtValue.shared.cs index 08609bb4826a6..868cf00ae334e 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/OrtValue.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/OrtValue.shared.cs @@ -20,6 +20,7 @@ public enum OnnxValueType ONNX_TYPE_MAP = 3, // It's a map ONNX_TYPE_OPAQUE = 4, // It's an experimental Opaque object ONNX_TYPE_SPARSETENSOR = 5, // It's a Sparse Tensor + ONNX_TYPE_OPTIONAL = 6, // It's an optional type that designates anything above (except UNKOWN) } /// @@ -345,18 +346,20 @@ out valueHandle // fill the native tensor, using GetValue(index) from the Tensor var len = tensor.Length; var nativeStrings = new IntPtr[len]; - using (var pinnedHandles = new DisposableList((int)len)) + using (var pinnedHandles = new DisposableList((int)len)) { for (int i = 0; i < len; i++) { var utf8str = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(tensor.GetValue(i)); - var gcHandle = GCHandle.Alloc(utf8str, GCHandleType.Pinned); - nativeStrings[i] = gcHandle.AddrOfPinnedObject(); - pinnedHandles.Add(new PinnedGCHandle(gcHandle)); + var pinnedUtf8 = new Memory(utf8str).Pin(); + unsafe + { + nativeStrings[i] = (IntPtr)pinnedUtf8.Pointer; + } + pinnedHandles.Add(pinnedUtf8); } - using (var pinnedStrings = new PinnedGCHandle(GCHandle.Alloc(nativeStrings, GCHandleType.Pinned))) - NativeApiStatus.VerifySuccess(NativeMethods.OrtFillStringTensor(ortValue.Handle, nativeStrings, (UIntPtr)len)); + NativeApiStatus.VerifySuccess(NativeMethods.OrtFillStringTensor(ortValue.Handle, nativeStrings, (UIntPtr)len)); } } catch (OnnxRuntimeException) @@ -380,6 +383,7 @@ protected override bool ReleaseHandle() if (IsOwned) { NativeMethods.OrtReleaseValue(handle); + IsOwned = false; } // Prevent use after disposal handle = IntPtr.Zero; diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/NativeOnnxTensorMemory.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/OrtValueTensor.shared.cs similarity index 89% rename from csharp/src/Microsoft.ML.OnnxRuntime/NativeOnnxTensorMemory.shared.cs rename to csharp/src/Microsoft.ML.OnnxRuntime/OrtValueTensor.shared.cs index ce339b6a528c1..29ae3ab2be30a 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/NativeOnnxTensorMemory.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/OrtValueTensor.shared.cs @@ -23,13 +23,14 @@ internal interface IOrtValueOwner : IDisposable /// This class is used in conjunction with DisposableNamedOnnxValue to /// own native collection OrtValue and dispose of it along with any DisposableNamedOnnxValues /// - internal class NativeOrtValueCollectionOwner : IOrtValueOwner, IDisposable + internal class NativeOrtValueCollectionOwner : IOrtValueOwner, IDisposable + where T:IDisposable { private OrtValue _ortValue; - private DisposableList _disposables; + private DisposableList _disposables; bool _disposed = false; - internal NativeOrtValueCollectionOwner(OrtValue ortValue, DisposableList disposables) + internal NativeOrtValueCollectionOwner(OrtValue ortValue, DisposableList disposables) { Debug.Assert(ortValue.IsOwned); _ortValue = new OrtValue(ortValue.Disown()); @@ -80,19 +81,24 @@ public void Dispose() /// /// This helper class owns the underlying OrtValue that is assumed to be a Tensor, /// it does not support any other ortValues and caches Tensor properties. + /// + /// It is easy to expose as a Tensor as DenseTensor can take Memory Mapping from + /// this. + /// + /// This class is disposable because of the MemoryManager inheritance /// /// - internal class NativeOnnxTensorMemory : MemoryManager, IOrtValueOwner + internal class OrtValueTensor : MemoryManager, IOrtValueOwner { private OrtValue _ortValue; // Disposable - private IntPtr _dataBufferPointer; // pointer to mutable tensor data in native memory - private string[] _dataBufferAsString; // string tensor values copied into managed memory + private readonly IntPtr _dataBufferPointer; // pointer to mutable tensor data in native memory + private readonly string[] _dataBufferAsString; // string tensor values copied into managed memory /// /// Constructs an instance and takes ownership of ortValue on success /// /// ortValue that is a Tensor - public NativeOnnxTensorMemory(OrtValue ortValue) + public OrtValueTensor(OrtValue ortValue) { Type type = null; int width = 0; @@ -115,7 +121,7 @@ public NativeOnnxTensorMemory(OrtValue ortValue) if (typeof(T) != type) { - var message = String.Format("The NativeOnnxTensorMemory type being instantiated for T = : {0} while supplied OrtValue contains T = {1}", + var message = String.Format("The OrtValueTensor type being instantiated for T = : {0} while supplied OrtValue contains T = {1}", typeof(T), type); throw new OnnxRuntimeException(ErrorCode.InvalidArgument, message); } @@ -214,7 +220,7 @@ public NativeOnnxTensorMemory(OrtValue ortValue) public override Span GetSpan() { if (IsDisposed) - throw new ObjectDisposedException(nameof(NativeOnnxTensorMemory)); + throw new ObjectDisposedException(nameof(OrtValueTensor)); Span span = null; unsafe { @@ -226,10 +232,10 @@ public override Span GetSpan() public Memory GetBytesAsStringMemory() { if (IsDisposed) - throw new ObjectDisposedException(nameof(NativeOnnxTensorMemory)); + throw new ObjectDisposedException(nameof(OrtValueTensor)); if (typeof(T) != typeof(string)) - throw new NotSupportedException(nameof(NativeOnnxTensorMemory.GetBytesAsStringMemory) + ": T must be byte"); + throw new NotSupportedException(nameof(OrtValueTensor.GetBytesAsStringMemory) + ": T must be byte"); return (_dataBufferAsString == null) ? new Memory() : new Memory(_dataBufferAsString); } diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/Tensors/Tensor.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/Tensors/Tensor.shared.cs index bb7eea2ad1883..0bc5ca7240e66 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/Tensors/Tensor.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/Tensors/Tensor.shared.cs @@ -71,7 +71,7 @@ public Float16(ushort v) /// /// instance of Float16 /// value member - public static implicit operator ushort (Float16 f) { return f.value; } + public static implicit operator ushort(Float16 f) { return f.value; } /// /// Converts a 16-bit unsigned integer to a Float16. /// @@ -191,7 +191,7 @@ public bool Equals(BFloat16 other) /// represent the same type and value. /// /// An System.Object. - /// true if obj is BFloat16 its value is equal to this instance; otherwise, false. + /// true if obj is BFloat16 its value is equal to this instance; otherwise, false. public override bool Equals(object obj) { bool result = false; @@ -286,7 +286,8 @@ public class TensorBase private static readonly Dictionary tensorElementTypeInfoMap; - static TensorBase () { + static TensorBase() + { typeInfoMap = new Dictionary() { { typeof(float), new TensorTypeInfo( TensorElementType.Float, sizeof(float)) }, @@ -306,11 +307,11 @@ static TensorBase () { }; tensorElementTypeInfoMap = new Dictionary(); - foreach(var info in typeInfoMap) + foreach (var info in typeInfoMap) { tensorElementTypeInfoMap.Add(info.Value.ElementType, new TensorElementTypeInfo(info.Key, info.Value.TypeSize)); } - } + } private readonly Type _primitiveType; /// @@ -559,7 +560,10 @@ internal static T Zero { return (T)(object)(ushort)(0); } - + else if (typeof(T) == typeof(string)) + { + return (T)(object)("0"); + } throw new NotSupportedException(); } } @@ -619,8 +623,8 @@ internal static T One else if (typeof(T) == typeof(ushort)) { return (T)(object)(ushort)(1); - } - else if(typeof(T) == typeof(Float16)) + } + else if (typeof(T) == typeof(Float16)) { return (T)(object)(ushort)(15360); } @@ -628,6 +632,10 @@ internal static T One { return (T)(object)(ushort)(16256); } + else if (typeof(T) == typeof(string)) + { + return (T)(object)("1"); + } throw new NotSupportedException(); } diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.EndToEndTests/Microsoft.ML.OnnxRuntime.EndToEndTests.csproj b/csharp/test/Microsoft.ML.OnnxRuntime.EndToEndTests/Microsoft.ML.OnnxRuntime.EndToEndTests.csproj index e15409216aabd..3e4802bf68f6e 100644 --- a/csharp/test/Microsoft.ML.OnnxRuntime.EndToEndTests/Microsoft.ML.OnnxRuntime.EndToEndTests.csproj +++ b/csharp/test/Microsoft.ML.OnnxRuntime.EndToEndTests/Microsoft.ML.OnnxRuntime.EndToEndTests.csproj @@ -48,6 +48,7 @@ + diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/InferenceTest.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/InferenceTest.cs index 357d1ca8621bd..da83b640f2577 100644 --- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/InferenceTest.cs +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/InferenceTest.cs @@ -588,7 +588,7 @@ private void ThrowWrongInputName() var container = new List(); container.Add(NamedOnnxValue.CreateFromTensor("wrong_name", tensor)); var ex = Assert.Throws(() => session.Run(container)); - Assert.Contains("Invalid Feed Input", ex.Message); + Assert.Contains("Input name: 'wrong_name' is not in the metadata", ex.Message); session.Dispose(); } @@ -604,9 +604,8 @@ private void ThrowWrongInputType() var tensor = new DenseTensor(inputDataInt, inputMeta["data_0"].Dimensions); container.Add(NamedOnnxValue.CreateFromTensor("data_0", tensor)); var ex = Assert.Throws(() => session.Run(container)); - var msg = ex.ToString().Substring(0, 101); - // TODO: message is diff in LInux. Use substring match - Assert.Equal("Microsoft.ML.OnnxRuntime.OnnxRuntimeException: [ErrorCode:InvalidArgument] Unexpected input data type", msg); + var msg = ex.ToString(); + Assert.Contains("Tensor element data type discovered", msg); session.Dispose(); } @@ -624,7 +623,7 @@ private void ThrowExtraInputs() container.Add(nov1); container.Add(nov2); var ex = Assert.Throws(() => session.Run(container)); - Assert.StartsWith("[ErrorCode:InvalidArgument] Invalid Feed Input Name", ex.Message); + Assert.Contains("Input name: 'extra' is not in the metadata", ex.Message); session.Dispose(); } @@ -653,9 +652,10 @@ private void ThrowWrongOutputName() var inputTensor = tuple.Item3; var inputs = new List { NamedOnnxValue.CreateFromTensor("data_0", inputTensor) }; var outputTensor = new DenseTensor((ReadOnlySpan)new[] { 1, 2 }); - var outputs = new List { NamedOnnxValue.CreateFromTensor("bad_output_name", outputTensor) }; - var ex = Assert.Throws(() => session.Run(inputs, outputs)); - Assert.Contains("Invalid Output Name", ex.Message); + // var outputs = new List { NamedOnnxValue.CreateFromTensor("bad_output_name", outputTensor) }; + var bad_names = new string[] {"bad_output_name"}; + var ex = Assert.Throws(() => session.Run(inputs, bad_names)); + Assert.Contains("Output name: 'bad_output_name' is not in the metadata", ex.Message); session.Dispose(); } @@ -1322,8 +1322,29 @@ private void TestModelSequenceOfMapIntFloat() { var outMeta = session.OutputMetadata; - Assert.Equal(OnnxValueType.ONNX_TYPE_TENSOR, outMeta["label"].OnnxValueType); - Assert.Equal(OnnxValueType.ONNX_TYPE_SEQUENCE, outMeta["probabilities"].OnnxValueType); + var label_meta = outMeta["label"]; + Assert.True(label_meta.IsTensor); + Assert.Equal(OnnxValueType.ONNX_TYPE_TENSOR, label_meta.OnnxValueType); + Assert.Equal(TensorElementType.Int64, label_meta.ElementDataType); + Assert.NotEmpty(label_meta.Dimensions); + + // sequence> + var probabilities_meta = outMeta["probabilities"]; + Assert.Equal(OnnxValueType.ONNX_TYPE_SEQUENCE, probabilities_meta.OnnxValueType); + var seqElementMetata = probabilities_meta.AsSequenceMetadata().ElementMeta; + Assert.Equal(OnnxValueType.ONNX_TYPE_MAP, seqElementMetata.OnnxValueType); + var mapMetadata = seqElementMetata.AsMapMetadata(); + // Map + Assert.Equal(Tensors.TensorElementType.Int64, mapMetadata.KeyDataType); + var valueTensorMeta = mapMetadata.ValueMetadata; + Assert.True(valueTensorMeta.IsTensor); + Assert.Equal(Tensors.TensorElementType.Float, valueTensorMeta.ElementDataType); + + // tensor + var inputMeta = session.InputMetadata["input"]; + Assert.True(inputMeta.IsTensor); + Assert.Equal(Tensors.TensorElementType.Float, inputMeta.ElementDataType); + Assert.Equal(2, inputMeta.Dimensions.Length); var container = new List(); var tensorIn = new DenseTensor(new float[] { 5.8f, 2.8f }, new int[] { 1, 2 }); @@ -1392,8 +1413,31 @@ private void TestModelSequenceOfMapStringFloat() using (var session = new InferenceSession(model)) { var outMeta = session.OutputMetadata; - Assert.Equal(OnnxValueType.ONNX_TYPE_TENSOR, outMeta["label"].OnnxValueType); - Assert.Equal(OnnxValueType.ONNX_TYPE_SEQUENCE, outMeta["probabilities"].OnnxValueType); + var label_meta = outMeta["label"]; + Assert.True(label_meta.IsTensor); + Assert.Equal(OnnxValueType.ONNX_TYPE_TENSOR, label_meta.OnnxValueType); + Assert.True(label_meta.IsString); + Assert.Equal(TensorElementType.String, label_meta.ElementDataType); + Assert.NotEmpty(label_meta.Dimensions); + + // sequence> + var probabilities_meta = outMeta["probabilities"]; + Assert.Equal(OnnxValueType.ONNX_TYPE_SEQUENCE, probabilities_meta.OnnxValueType); + var seqElementMetata = probabilities_meta.AsSequenceMetadata().ElementMeta; + Assert.Equal(OnnxValueType.ONNX_TYPE_MAP, seqElementMetata.OnnxValueType); + var mapMetadata = seqElementMetata.AsMapMetadata(); + Assert.Equal(Tensors.TensorElementType.String, mapMetadata.KeyDataType); + var valueTensorMeta = mapMetadata.ValueMetadata; + Assert.True(valueTensorMeta.IsTensor); + Assert.Equal(Tensors.TensorElementType.Float, valueTensorMeta.ElementDataType); + + + // tensor + var inputMeta = session.InputMetadata["input"]; + Assert.True(inputMeta.IsTensor); + Assert.False(inputMeta.IsString); + Assert.Equal(Tensors.TensorElementType.Float, inputMeta.ElementDataType); + Assert.Equal(2, inputMeta.Dimensions.Length); var container = new List(); var tensorIn = new DenseTensor(new float[] { 5.8f, 2.8f }, new int[] { 1, 2 }); @@ -1415,7 +1459,7 @@ private void TestModelSequenceOfMapStringFloat() // Label 1 should have highest probability Assert.Equal("1", outLabelTensor[0]); - // second output is a sequence> + // second output is a sequence> // try-cast to an sequence of NOV var outNode1 = outputs.ElementAtOrDefault(1); Assert.Equal("probabilities", outNode1.Name); @@ -1443,7 +1487,18 @@ private void TestModelSequenceOfTensors() using (var session = new InferenceSession(model)) { var outMeta = session.OutputMetadata; - Assert.Equal(OnnxValueType.ONNX_TYPE_SEQUENCE, outMeta["output_sequence"].OnnxValueType); + var output_seq = outMeta["output_sequence"]; + Assert.False(output_seq.IsTensor); + Assert.Equal(OnnxValueType.ONNX_TYPE_SEQUENCE, output_seq.OnnxValueType); + var elemMeta = output_seq.AsSequenceMetadata().ElementMeta; + Assert.True(elemMeta.IsTensor); + Assert.Equal(Tensors.TensorElementType.Int64, elemMeta.ElementDataType); + + // Inputs + var tensor1Meta = session.InputMetadata["tensor1"]; + Assert.True(tensor1Meta.IsTensor); + Assert.Equal(Tensors.TensorElementType.Int64, tensor1Meta.ElementDataType); + Assert.Equal(2, tensor1Meta.Dimensions.Length); var container = new List(); var firstInputTensor = new DenseTensor(new Int64[] { 1, 2, 3, 4, 5, 6 }, new int[] { 2, 3 }); diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/Microsoft.ML.OnnxRuntime.Tests.Common.csproj b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/Microsoft.ML.OnnxRuntime.Tests.Common.csproj index a373039436e3b..58c9cbe11dbd9 100644 --- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/Microsoft.ML.OnnxRuntime.Tests.Common.csproj +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/Microsoft.ML.OnnxRuntime.Tests.Common.csproj @@ -9,7 +9,7 @@ true true true - $(OnnxRuntimeCsharpRoot)\..\cmake\external\onnx\onnx + $(OnnxRuntimeCsharpRoot)\..\cmake\external\onnx;\..\cmake\external\onnx\onnx 7.2 @@ -76,6 +76,7 @@ + @@ -107,6 +108,10 @@ + + + + @@ -131,4 +136,4 @@ - + \ No newline at end of file diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/OnnxData.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/OnnxData.cs new file mode 100644 index 0000000000000..0d3f3b4d3eddc --- /dev/null +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/OnnxData.cs @@ -0,0 +1,1335 @@ +// +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: onnx/onnx-data.proto3 +// +#pragma warning disable 1591, 0612, 3021 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace Onnx { + + /// Holder for reflection information generated from onnx/onnx-data.proto3 + public static partial class OnnxDataReflection { + + #region Descriptor + /// File descriptor for onnx/onnx-data.proto3 + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static OnnxDataReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "ChVvbm54L29ubngtZGF0YS5wcm90bzMSBG9ubngaE29ubngvb25ueC1tbC5w", + "cm90bzMi8AIKDVNlcXVlbmNlUHJvdG8SDAoEbmFtZRgBIAEoCRIRCgllbGVt", + "X3R5cGUYAiABKAUSKAoNdGVuc29yX3ZhbHVlcxgDIAMoCzIRLm9ubnguVGVu", + "c29yUHJvdG8SNQoUc3BhcnNlX3RlbnNvcl92YWx1ZXMYBCADKAsyFy5vbm54", + "LlNwYXJzZVRlbnNvclByb3RvEiwKD3NlcXVlbmNlX3ZhbHVlcxgFIAMoCzIT", + "Lm9ubnguU2VxdWVuY2VQcm90bxIiCgptYXBfdmFsdWVzGAYgAygLMg4ub25u", + "eC5NYXBQcm90bxIsCg9vcHRpb25hbF92YWx1ZXMYByADKAsyEy5vbm54Lk9w", + "dGlvbmFsUHJvdG8iXQoIRGF0YVR5cGUSDQoJVU5ERUZJTkVEEAASCgoGVEVO", + "U09SEAESEQoNU1BBUlNFX1RFTlNPUhACEgwKCFNFUVVFTkNFEAMSBwoDTUFQ", + "EAQSDAoIT1BUSU9OQUwQBSJyCghNYXBQcm90bxIMCgRuYW1lGAEgASgJEhAK", + "CGtleV90eXBlGAIgASgFEgwKBGtleXMYAyADKAMSEwoLc3RyaW5nX2tleXMY", + "BCADKAwSIwoGdmFsdWVzGAUgASgLMhMub25ueC5TZXF1ZW5jZVByb3RvIusC", + "Cg1PcHRpb25hbFByb3RvEgwKBG5hbWUYASABKAkSEQoJZWxlbV90eXBlGAIg", + "ASgFEicKDHRlbnNvcl92YWx1ZRgDIAEoCzIRLm9ubnguVGVuc29yUHJvdG8S", + "NAoTc3BhcnNlX3RlbnNvcl92YWx1ZRgEIAEoCzIXLm9ubnguU3BhcnNlVGVu", + "c29yUHJvdG8SKwoOc2VxdWVuY2VfdmFsdWUYBSABKAsyEy5vbm54LlNlcXVl", + "bmNlUHJvdG8SIQoJbWFwX3ZhbHVlGAYgASgLMg4ub25ueC5NYXBQcm90bxIr", + "Cg5vcHRpb25hbF92YWx1ZRgHIAEoCzITLm9ubnguT3B0aW9uYWxQcm90byJd", + "CghEYXRhVHlwZRINCglVTkRFRklORUQQABIKCgZURU5TT1IQARIRCg1TUEFS", + "U0VfVEVOU09SEAISDAoIU0VRVUVOQ0UQAxIHCgNNQVAQBBIMCghPUFRJT05B", + "TBAFQgJIA2IGcHJvdG8z")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { global::Onnx.OnnxMlReflection.Descriptor, }, + new pbr::GeneratedClrTypeInfo(null, null, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::Onnx.SequenceProto), global::Onnx.SequenceProto.Parser, new[]{ "Name", "ElemType", "TensorValues", "SparseTensorValues", "SequenceValues", "MapValues", "OptionalValues" }, null, new[]{ typeof(global::Onnx.SequenceProto.Types.DataType) }, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Onnx.MapProto), global::Onnx.MapProto.Parser, new[]{ "Name", "KeyType", "Keys", "StringKeys", "Values" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Onnx.OptionalProto), global::Onnx.OptionalProto.Parser, new[]{ "Name", "ElemType", "TensorValue", "SparseTensorValue", "SequenceValue", "MapValue", "OptionalValue" }, null, new[]{ typeof(global::Onnx.OptionalProto.Types.DataType) }, null, null) + })); + } + #endregion + + } + #region Messages + /// + /// Sequences + /// + /// Defines a dense, ordered, collection of elements that are of homogeneous types. + /// Sequences can be made out of tensors, maps, or sequences. + /// + /// If a sequence is made out of tensors, the tensors must have the same element + /// type (i.e. int32). In some cases, the tensors in a sequence can have different + /// shapes. Whether the tensors can have different shapes or not depends on the + /// type/shape associated with the corresponding "ValueInfo". For example, + /// "Sequence<Tensor<float, [M,N]>" means that all tensors have same shape. However, + /// "Sequence<Tensor<float, [omitted,omitted]>" means they can have different + /// shapes (all of rank 2), where "omitted" means the corresponding dimension has + /// no symbolic/constant value. Finally, "Sequence<Tensor<float, omitted>>" means + /// that the different tensors can have different ranks, when the "shape" itself + /// is omitted from the tensor-type. For a more complete description, refer to + /// https://github.com/onnx/onnx/blob/main/docs/IR.md#static-tensor-shapes. + /// + public sealed partial class SequenceProto : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new SequenceProto()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Onnx.OnnxDataReflection.Descriptor.MessageTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public SequenceProto() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public SequenceProto(SequenceProto other) : this() { + name_ = other.name_; + elemType_ = other.elemType_; + tensorValues_ = other.tensorValues_.Clone(); + sparseTensorValues_ = other.sparseTensorValues_.Clone(); + sequenceValues_ = other.sequenceValues_.Clone(); + mapValues_ = other.mapValues_.Clone(); + optionalValues_ = other.optionalValues_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public SequenceProto Clone() { + return new SequenceProto(this); + } + + /// Field number for the "name" field. + public const int NameFieldNumber = 1; + private string name_ = ""; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string Name { + get { return name_; } + set { + name_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "elem_type" field. + public const int ElemTypeFieldNumber = 2; + private int elemType_; + /// + /// The data type of the element. + /// This field MUST have a valid SequenceProto.DataType value + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int ElemType { + get { return elemType_; } + set { + elemType_ = value; + } + } + + /// Field number for the "tensor_values" field. + public const int TensorValuesFieldNumber = 3; + private static readonly pb::FieldCodec _repeated_tensorValues_codec + = pb::FieldCodec.ForMessage(26, global::Onnx.TensorProto.Parser); + private readonly pbc::RepeatedField tensorValues_ = new pbc::RepeatedField(); + /// + /// For TensorProto values. + /// When this field is present, the elem_type field MUST be TENSOR. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField TensorValues { + get { return tensorValues_; } + } + + /// Field number for the "sparse_tensor_values" field. + public const int SparseTensorValuesFieldNumber = 4; + private static readonly pb::FieldCodec _repeated_sparseTensorValues_codec + = pb::FieldCodec.ForMessage(34, global::Onnx.SparseTensorProto.Parser); + private readonly pbc::RepeatedField sparseTensorValues_ = new pbc::RepeatedField(); + /// + /// For SparseTensorProto values. + /// When this field is present, the elem_type field MUST be SPARSE_TENSOR. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField SparseTensorValues { + get { return sparseTensorValues_; } + } + + /// Field number for the "sequence_values" field. + public const int SequenceValuesFieldNumber = 5; + private static readonly pb::FieldCodec _repeated_sequenceValues_codec + = pb::FieldCodec.ForMessage(42, global::Onnx.SequenceProto.Parser); + private readonly pbc::RepeatedField sequenceValues_ = new pbc::RepeatedField(); + /// + /// For SequenceProto values, allowing sequences to be of themselves. + /// When this field is present, the elem_type field MUST be SEQUENCE. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField SequenceValues { + get { return sequenceValues_; } + } + + /// Field number for the "map_values" field. + public const int MapValuesFieldNumber = 6; + private static readonly pb::FieldCodec _repeated_mapValues_codec + = pb::FieldCodec.ForMessage(50, global::Onnx.MapProto.Parser); + private readonly pbc::RepeatedField mapValues_ = new pbc::RepeatedField(); + /// + /// For MapProto values. + /// When this field is present, the elem_type field MUST be MAP. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField MapValues { + get { return mapValues_; } + } + + /// Field number for the "optional_values" field. + public const int OptionalValuesFieldNumber = 7; + private static readonly pb::FieldCodec _repeated_optionalValues_codec + = pb::FieldCodec.ForMessage(58, global::Onnx.OptionalProto.Parser); + private readonly pbc::RepeatedField optionalValues_ = new pbc::RepeatedField(); + /// + /// For OptionalProto values. + /// When this field is present, the elem_type field MUST be Optional. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField OptionalValues { + get { return optionalValues_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as SequenceProto); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(SequenceProto other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (Name != other.Name) return false; + if (ElemType != other.ElemType) return false; + if(!tensorValues_.Equals(other.tensorValues_)) return false; + if(!sparseTensorValues_.Equals(other.sparseTensorValues_)) return false; + if(!sequenceValues_.Equals(other.sequenceValues_)) return false; + if(!mapValues_.Equals(other.mapValues_)) return false; + if(!optionalValues_.Equals(other.optionalValues_)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (Name.Length != 0) hash ^= Name.GetHashCode(); + if (ElemType != 0) hash ^= ElemType.GetHashCode(); + hash ^= tensorValues_.GetHashCode(); + hash ^= sparseTensorValues_.GetHashCode(); + hash ^= sequenceValues_.GetHashCode(); + hash ^= mapValues_.GetHashCode(); + hash ^= optionalValues_.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (Name.Length != 0) { + output.WriteRawTag(10); + output.WriteString(Name); + } + if (ElemType != 0) { + output.WriteRawTag(16); + output.WriteInt32(ElemType); + } + tensorValues_.WriteTo(output, _repeated_tensorValues_codec); + sparseTensorValues_.WriteTo(output, _repeated_sparseTensorValues_codec); + sequenceValues_.WriteTo(output, _repeated_sequenceValues_codec); + mapValues_.WriteTo(output, _repeated_mapValues_codec); + optionalValues_.WriteTo(output, _repeated_optionalValues_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (Name.Length != 0) { + output.WriteRawTag(10); + output.WriteString(Name); + } + if (ElemType != 0) { + output.WriteRawTag(16); + output.WriteInt32(ElemType); + } + tensorValues_.WriteTo(ref output, _repeated_tensorValues_codec); + sparseTensorValues_.WriteTo(ref output, _repeated_sparseTensorValues_codec); + sequenceValues_.WriteTo(ref output, _repeated_sequenceValues_codec); + mapValues_.WriteTo(ref output, _repeated_mapValues_codec); + optionalValues_.WriteTo(ref output, _repeated_optionalValues_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (Name.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(Name); + } + if (ElemType != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(ElemType); + } + size += tensorValues_.CalculateSize(_repeated_tensorValues_codec); + size += sparseTensorValues_.CalculateSize(_repeated_sparseTensorValues_codec); + size += sequenceValues_.CalculateSize(_repeated_sequenceValues_codec); + size += mapValues_.CalculateSize(_repeated_mapValues_codec); + size += optionalValues_.CalculateSize(_repeated_optionalValues_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(SequenceProto other) { + if (other == null) { + return; + } + if (other.Name.Length != 0) { + Name = other.Name; + } + if (other.ElemType != 0) { + ElemType = other.ElemType; + } + tensorValues_.Add(other.tensorValues_); + sparseTensorValues_.Add(other.sparseTensorValues_); + sequenceValues_.Add(other.sequenceValues_); + mapValues_.Add(other.mapValues_); + optionalValues_.Add(other.optionalValues_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + Name = input.ReadString(); + break; + } + case 16: { + ElemType = input.ReadInt32(); + break; + } + case 26: { + tensorValues_.AddEntriesFrom(input, _repeated_tensorValues_codec); + break; + } + case 34: { + sparseTensorValues_.AddEntriesFrom(input, _repeated_sparseTensorValues_codec); + break; + } + case 42: { + sequenceValues_.AddEntriesFrom(input, _repeated_sequenceValues_codec); + break; + } + case 50: { + mapValues_.AddEntriesFrom(input, _repeated_mapValues_codec); + break; + } + case 58: { + optionalValues_.AddEntriesFrom(input, _repeated_optionalValues_codec); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + Name = input.ReadString(); + break; + } + case 16: { + ElemType = input.ReadInt32(); + break; + } + case 26: { + tensorValues_.AddEntriesFrom(ref input, _repeated_tensorValues_codec); + break; + } + case 34: { + sparseTensorValues_.AddEntriesFrom(ref input, _repeated_sparseTensorValues_codec); + break; + } + case 42: { + sequenceValues_.AddEntriesFrom(ref input, _repeated_sequenceValues_codec); + break; + } + case 50: { + mapValues_.AddEntriesFrom(ref input, _repeated_mapValues_codec); + break; + } + case 58: { + optionalValues_.AddEntriesFrom(ref input, _repeated_optionalValues_codec); + break; + } + } + } + } + #endif + + #region Nested types + /// Container for nested types declared in the SequenceProto message type. + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static partial class Types { + public enum DataType { + [pbr::OriginalName("UNDEFINED")] Undefined = 0, + [pbr::OriginalName("TENSOR")] Tensor = 1, + [pbr::OriginalName("SPARSE_TENSOR")] SparseTensor = 2, + [pbr::OriginalName("SEQUENCE")] Sequence = 3, + [pbr::OriginalName("MAP")] Map = 4, + [pbr::OriginalName("OPTIONAL")] Optional = 5, + } + + } + #endregion + + } + + /// + /// Maps + /// + /// Specifies an associative table, defined by keys and values. + /// MapProto is formed with a repeated field of keys (of type INT8, INT16, INT32, + /// INT64, UINT8, UINT16, UINT32, UINT64, or STRING) and values (of type TENSOR, + /// SPARSE_TENSOR, SEQUENCE, or MAP). Key types and value types have to remain + /// the same throughout the instantiation of the MapProto. + /// + public sealed partial class MapProto : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new MapProto()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Onnx.OnnxDataReflection.Descriptor.MessageTypes[1]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public MapProto() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public MapProto(MapProto other) : this() { + name_ = other.name_; + keyType_ = other.keyType_; + keys_ = other.keys_.Clone(); + stringKeys_ = other.stringKeys_.Clone(); + values_ = other.values_ != null ? other.values_.Clone() : null; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public MapProto Clone() { + return new MapProto(this); + } + + /// Field number for the "name" field. + public const int NameFieldNumber = 1; + private string name_ = ""; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string Name { + get { return name_; } + set { + name_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "key_type" field. + public const int KeyTypeFieldNumber = 2; + private int keyType_; + /// + /// The data type of the key. + /// This field MUST have a valid TensorProto.DataType value of + /// INT8, INT16, INT32, INT64, UINT8, UINT16, UINT32, UINT64, or STRING + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int KeyType { + get { return keyType_; } + set { + keyType_ = value; + } + } + + /// Field number for the "keys" field. + public const int KeysFieldNumber = 3; + private static readonly pb::FieldCodec _repeated_keys_codec + = pb::FieldCodec.ForInt64(26); + private readonly pbc::RepeatedField keys_ = new pbc::RepeatedField(); + /// + /// Every element of keys has to be one of the following data types + /// INT8, INT16, INT32, INT64, UINT8, UINT16, UINT32, UINT64, or STRING. + /// The integer cases are represented by the repeated int64 field keys below. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField Keys { + get { return keys_; } + } + + /// Field number for the "string_keys" field. + public const int StringKeysFieldNumber = 4; + private static readonly pb::FieldCodec _repeated_stringKeys_codec + = pb::FieldCodec.ForBytes(34); + private readonly pbc::RepeatedField stringKeys_ = new pbc::RepeatedField(); + /// + /// If keys are strings, they are represented by the repeated bytes field + /// string_keys below. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField StringKeys { + get { return stringKeys_; } + } + + /// Field number for the "values" field. + public const int ValuesFieldNumber = 5; + private global::Onnx.SequenceProto values_; + /// + /// MapProto values are represented in a SequenceProto of the same length as the + /// repeated keys field and have to be one of the following data types + /// TENSOR, SPARSE_TENSOR, MAP, SEQUENCE. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Onnx.SequenceProto Values { + get { return values_; } + set { + values_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as MapProto); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(MapProto other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (Name != other.Name) return false; + if (KeyType != other.KeyType) return false; + if(!keys_.Equals(other.keys_)) return false; + if(!stringKeys_.Equals(other.stringKeys_)) return false; + if (!object.Equals(Values, other.Values)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (Name.Length != 0) hash ^= Name.GetHashCode(); + if (KeyType != 0) hash ^= KeyType.GetHashCode(); + hash ^= keys_.GetHashCode(); + hash ^= stringKeys_.GetHashCode(); + if (values_ != null) hash ^= Values.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (Name.Length != 0) { + output.WriteRawTag(10); + output.WriteString(Name); + } + if (KeyType != 0) { + output.WriteRawTag(16); + output.WriteInt32(KeyType); + } + keys_.WriteTo(output, _repeated_keys_codec); + stringKeys_.WriteTo(output, _repeated_stringKeys_codec); + if (values_ != null) { + output.WriteRawTag(42); + output.WriteMessage(Values); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (Name.Length != 0) { + output.WriteRawTag(10); + output.WriteString(Name); + } + if (KeyType != 0) { + output.WriteRawTag(16); + output.WriteInt32(KeyType); + } + keys_.WriteTo(ref output, _repeated_keys_codec); + stringKeys_.WriteTo(ref output, _repeated_stringKeys_codec); + if (values_ != null) { + output.WriteRawTag(42); + output.WriteMessage(Values); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (Name.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(Name); + } + if (KeyType != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(KeyType); + } + size += keys_.CalculateSize(_repeated_keys_codec); + size += stringKeys_.CalculateSize(_repeated_stringKeys_codec); + if (values_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(Values); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(MapProto other) { + if (other == null) { + return; + } + if (other.Name.Length != 0) { + Name = other.Name; + } + if (other.KeyType != 0) { + KeyType = other.KeyType; + } + keys_.Add(other.keys_); + stringKeys_.Add(other.stringKeys_); + if (other.values_ != null) { + if (values_ == null) { + Values = new global::Onnx.SequenceProto(); + } + Values.MergeFrom(other.Values); + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + Name = input.ReadString(); + break; + } + case 16: { + KeyType = input.ReadInt32(); + break; + } + case 26: + case 24: { + keys_.AddEntriesFrom(input, _repeated_keys_codec); + break; + } + case 34: { + stringKeys_.AddEntriesFrom(input, _repeated_stringKeys_codec); + break; + } + case 42: { + if (values_ == null) { + Values = new global::Onnx.SequenceProto(); + } + input.ReadMessage(Values); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + Name = input.ReadString(); + break; + } + case 16: { + KeyType = input.ReadInt32(); + break; + } + case 26: + case 24: { + keys_.AddEntriesFrom(ref input, _repeated_keys_codec); + break; + } + case 34: { + stringKeys_.AddEntriesFrom(ref input, _repeated_stringKeys_codec); + break; + } + case 42: { + if (values_ == null) { + Values = new global::Onnx.SequenceProto(); + } + input.ReadMessage(Values); + break; + } + } + } + } + #endif + + } + + /// + /// Optional + /// + public sealed partial class OptionalProto : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new OptionalProto()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Onnx.OnnxDataReflection.Descriptor.MessageTypes[2]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public OptionalProto() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public OptionalProto(OptionalProto other) : this() { + name_ = other.name_; + elemType_ = other.elemType_; + tensorValue_ = other.tensorValue_ != null ? other.tensorValue_.Clone() : null; + sparseTensorValue_ = other.sparseTensorValue_ != null ? other.sparseTensorValue_.Clone() : null; + sequenceValue_ = other.sequenceValue_ != null ? other.sequenceValue_.Clone() : null; + mapValue_ = other.mapValue_ != null ? other.mapValue_.Clone() : null; + optionalValue_ = other.optionalValue_ != null ? other.optionalValue_.Clone() : null; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public OptionalProto Clone() { + return new OptionalProto(this); + } + + /// Field number for the "name" field. + public const int NameFieldNumber = 1; + private string name_ = ""; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string Name { + get { return name_; } + set { + name_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "elem_type" field. + public const int ElemTypeFieldNumber = 2; + private int elemType_; + /// + /// The data type of the element, identifies if the OptionalProto value + /// is Tensor, Sparse Tensor, Sequence, Map, or Optional. + /// The type of the optional value MUST match the elem_type specified. + /// This field MUST have a valid OptionalProto.DataType value. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int ElemType { + get { return elemType_; } + set { + elemType_ = value; + } + } + + /// Field number for the "tensor_value" field. + public const int TensorValueFieldNumber = 3; + private global::Onnx.TensorProto tensorValue_; + /// + /// For TensorProto value. + /// When this field is present, the elem_type field MUST be TENSOR. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Onnx.TensorProto TensorValue { + get { return tensorValue_; } + set { + tensorValue_ = value; + } + } + + /// Field number for the "sparse_tensor_value" field. + public const int SparseTensorValueFieldNumber = 4; + private global::Onnx.SparseTensorProto sparseTensorValue_; + /// + /// For SparseTensorProto value. + /// When this field is present, the elem_type field MUST be SPARSE_TENSOR. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Onnx.SparseTensorProto SparseTensorValue { + get { return sparseTensorValue_; } + set { + sparseTensorValue_ = value; + } + } + + /// Field number for the "sequence_value" field. + public const int SequenceValueFieldNumber = 5; + private global::Onnx.SequenceProto sequenceValue_; + /// + /// For SequenceProto value. + /// When this field is present, the elem_type field MUST be SEQUENCE. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Onnx.SequenceProto SequenceValue { + get { return sequenceValue_; } + set { + sequenceValue_ = value; + } + } + + /// Field number for the "map_value" field. + public const int MapValueFieldNumber = 6; + private global::Onnx.MapProto mapValue_; + /// + /// For MapProto value. + /// When this field is present, the elem_type field MUST be MAP. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Onnx.MapProto MapValue { + get { return mapValue_; } + set { + mapValue_ = value; + } + } + + /// Field number for the "optional_value" field. + public const int OptionalValueFieldNumber = 7; + private global::Onnx.OptionalProto optionalValue_; + /// + /// For OptionalProto value, allowing optional to be of itself (completeness) + /// When this field is present, the elem_type field MUST be OPTIONAL. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Onnx.OptionalProto OptionalValue { + get { return optionalValue_; } + set { + optionalValue_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as OptionalProto); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(OptionalProto other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (Name != other.Name) return false; + if (ElemType != other.ElemType) return false; + if (!object.Equals(TensorValue, other.TensorValue)) return false; + if (!object.Equals(SparseTensorValue, other.SparseTensorValue)) return false; + if (!object.Equals(SequenceValue, other.SequenceValue)) return false; + if (!object.Equals(MapValue, other.MapValue)) return false; + if (!object.Equals(OptionalValue, other.OptionalValue)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (Name.Length != 0) hash ^= Name.GetHashCode(); + if (ElemType != 0) hash ^= ElemType.GetHashCode(); + if (tensorValue_ != null) hash ^= TensorValue.GetHashCode(); + if (sparseTensorValue_ != null) hash ^= SparseTensorValue.GetHashCode(); + if (sequenceValue_ != null) hash ^= SequenceValue.GetHashCode(); + if (mapValue_ != null) hash ^= MapValue.GetHashCode(); + if (optionalValue_ != null) hash ^= OptionalValue.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (Name.Length != 0) { + output.WriteRawTag(10); + output.WriteString(Name); + } + if (ElemType != 0) { + output.WriteRawTag(16); + output.WriteInt32(ElemType); + } + if (tensorValue_ != null) { + output.WriteRawTag(26); + output.WriteMessage(TensorValue); + } + if (sparseTensorValue_ != null) { + output.WriteRawTag(34); + output.WriteMessage(SparseTensorValue); + } + if (sequenceValue_ != null) { + output.WriteRawTag(42); + output.WriteMessage(SequenceValue); + } + if (mapValue_ != null) { + output.WriteRawTag(50); + output.WriteMessage(MapValue); + } + if (optionalValue_ != null) { + output.WriteRawTag(58); + output.WriteMessage(OptionalValue); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (Name.Length != 0) { + output.WriteRawTag(10); + output.WriteString(Name); + } + if (ElemType != 0) { + output.WriteRawTag(16); + output.WriteInt32(ElemType); + } + if (tensorValue_ != null) { + output.WriteRawTag(26); + output.WriteMessage(TensorValue); + } + if (sparseTensorValue_ != null) { + output.WriteRawTag(34); + output.WriteMessage(SparseTensorValue); + } + if (sequenceValue_ != null) { + output.WriteRawTag(42); + output.WriteMessage(SequenceValue); + } + if (mapValue_ != null) { + output.WriteRawTag(50); + output.WriteMessage(MapValue); + } + if (optionalValue_ != null) { + output.WriteRawTag(58); + output.WriteMessage(OptionalValue); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (Name.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(Name); + } + if (ElemType != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(ElemType); + } + if (tensorValue_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(TensorValue); + } + if (sparseTensorValue_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(SparseTensorValue); + } + if (sequenceValue_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(SequenceValue); + } + if (mapValue_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(MapValue); + } + if (optionalValue_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(OptionalValue); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(OptionalProto other) { + if (other == null) { + return; + } + if (other.Name.Length != 0) { + Name = other.Name; + } + if (other.ElemType != 0) { + ElemType = other.ElemType; + } + if (other.tensorValue_ != null) { + if (tensorValue_ == null) { + TensorValue = new global::Onnx.TensorProto(); + } + TensorValue.MergeFrom(other.TensorValue); + } + if (other.sparseTensorValue_ != null) { + if (sparseTensorValue_ == null) { + SparseTensorValue = new global::Onnx.SparseTensorProto(); + } + SparseTensorValue.MergeFrom(other.SparseTensorValue); + } + if (other.sequenceValue_ != null) { + if (sequenceValue_ == null) { + SequenceValue = new global::Onnx.SequenceProto(); + } + SequenceValue.MergeFrom(other.SequenceValue); + } + if (other.mapValue_ != null) { + if (mapValue_ == null) { + MapValue = new global::Onnx.MapProto(); + } + MapValue.MergeFrom(other.MapValue); + } + if (other.optionalValue_ != null) { + if (optionalValue_ == null) { + OptionalValue = new global::Onnx.OptionalProto(); + } + OptionalValue.MergeFrom(other.OptionalValue); + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + Name = input.ReadString(); + break; + } + case 16: { + ElemType = input.ReadInt32(); + break; + } + case 26: { + if (tensorValue_ == null) { + TensorValue = new global::Onnx.TensorProto(); + } + input.ReadMessage(TensorValue); + break; + } + case 34: { + if (sparseTensorValue_ == null) { + SparseTensorValue = new global::Onnx.SparseTensorProto(); + } + input.ReadMessage(SparseTensorValue); + break; + } + case 42: { + if (sequenceValue_ == null) { + SequenceValue = new global::Onnx.SequenceProto(); + } + input.ReadMessage(SequenceValue); + break; + } + case 50: { + if (mapValue_ == null) { + MapValue = new global::Onnx.MapProto(); + } + input.ReadMessage(MapValue); + break; + } + case 58: { + if (optionalValue_ == null) { + OptionalValue = new global::Onnx.OptionalProto(); + } + input.ReadMessage(OptionalValue); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + Name = input.ReadString(); + break; + } + case 16: { + ElemType = input.ReadInt32(); + break; + } + case 26: { + if (tensorValue_ == null) { + TensorValue = new global::Onnx.TensorProto(); + } + input.ReadMessage(TensorValue); + break; + } + case 34: { + if (sparseTensorValue_ == null) { + SparseTensorValue = new global::Onnx.SparseTensorProto(); + } + input.ReadMessage(SparseTensorValue); + break; + } + case 42: { + if (sequenceValue_ == null) { + SequenceValue = new global::Onnx.SequenceProto(); + } + input.ReadMessage(SequenceValue); + break; + } + case 50: { + if (mapValue_ == null) { + MapValue = new global::Onnx.MapProto(); + } + input.ReadMessage(MapValue); + break; + } + case 58: { + if (optionalValue_ == null) { + OptionalValue = new global::Onnx.OptionalProto(); + } + input.ReadMessage(OptionalValue); + break; + } + } + } + } + #endif + + #region Nested types + /// Container for nested types declared in the OptionalProto message type. + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static partial class Types { + public enum DataType { + [pbr::OriginalName("UNDEFINED")] Undefined = 0, + [pbr::OriginalName("TENSOR")] Tensor = 1, + [pbr::OriginalName("SPARSE_TENSOR")] SparseTensor = 2, + [pbr::OriginalName("SEQUENCE")] Sequence = 3, + [pbr::OriginalName("MAP")] Map = 4, + [pbr::OriginalName("OPTIONAL")] Optional = 5, + } + + } + #endregion + + } + + #endregion + +} + +#endregion Designer generated code diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/OnnxMl.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/OnnxMl.cs index 8805b95839f28..72686b3775277 100644 --- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/OnnxMl.cs +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/OnnxMl.cs @@ -3864,7 +3864,7 @@ public int DataType { /// float16 values must be bit-wise converted to an uint16_t prior /// to writing to the buffer. /// When this field is present, the data_type field MUST be - /// INT32, INT16, INT8, UINT16, UINT8, BOOL, or FLOAT16 + /// INT32, INT16, INT8, UINT16, UINT8, BOOL, FLOAT16 or BFLOAT16 /// [global::System.Diagnostics.DebuggerNonUserCodeAttribute] [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/TestDataLoader.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/TestDataLoader.cs index 2086fa0ec3164..42a780b3287ef 100644 --- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/TestDataLoader.cs +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/TestDataLoader.cs @@ -1,8 +1,12 @@ using System; using System.Collections.Generic; +using System.Diagnostics; using System.IO; using System.Linq; +using System.Net.NetworkInformation; +using Google.Protobuf; using Microsoft.ML.OnnxRuntime.Tensors; +using Xunit; namespace Microsoft.ML.OnnxRuntime.Tests { @@ -46,185 +50,361 @@ internal static float[] LoadTensorFromEmbeddedResource(string path) return tensorData.ToArray(); } - internal static void GetTypeAndWidth(Tensors.TensorElementType elemType, out Type type, out int width) + static NamedOnnxValue LoadTensorPb(Onnx.TensorProto tensor, string nodeName, NodeMetadata nodeMeta) { - TensorElementTypeInfo result = TensorBase.GetElementTypeInfo(elemType); - if (result != null) + if (nodeMeta.OnnxValueType != OnnxValueType.ONNX_TYPE_TENSOR) { - type = result.TensorType; - width = result.TypeSize; + throw new InvalidDataException($"Metadata for: '{nodeName}' has a type: '{nodeMeta.OnnxValueType}'" + + $" but loading as tensor: '{tensor.Name}'"); } - else + + var protoDt = (Tensors.TensorElementType)tensor.DataType; + var metaElementType = nodeMeta.ElementDataType; + if (!((protoDt == metaElementType) || + (protoDt == TensorElementType.UInt16 && + (metaElementType == TensorElementType.BFloat16 || metaElementType == TensorElementType.Float16)))) + throw new InvalidDataException($"For node: '{nodeName}' metadata expects: '{metaElementType}' but loaded loaded tensor type: '{protoDt}'"); + + // Tensors within Sequences may have no dimensions as the standard allows + // different dimensions for each tensor element of the sequence + if (nodeMeta.Dimensions.Length > 0 && nodeMeta.Dimensions.Length != tensor.Dims.Count) { - throw new ArgumentException("Unable to get information for type: " + elemType.ToString()); + throw new InvalidDataException($"node: '{nodeName}' nodeMeta.Dim.Length: {nodeMeta.Dimensions.Length} " + + $"is expected to be equal to tensor.Dims.Count {tensor.Dims.Count}"); } - } - static NamedOnnxValue LoadTensorPb(Onnx.TensorProto tensor, IReadOnlyDictionary nodeMetaDict) - { - Type tensorElemType = null; - int width = 0; - GetTypeAndWidth((Tensors.TensorElementType)tensor.DataType, out tensorElemType, out width); var intDims = new int[tensor.Dims.Count]; - for (int i = 0; i < tensor.Dims.Count; i++) { intDims[i] = (int)tensor.Dims[i]; } - NodeMetadata nodeMeta = null; - string nodeName = string.Empty; + for (int i = 0; i < nodeMeta.Dimensions.Length; i++) + { + if ((nodeMeta.Dimensions[i] != -1) && (nodeMeta.Dimensions[i] != tensor.Dims[i])) + throw new InvalidDataException($"Node: '{nodeName}' dimension at idx {i} is {nodeMeta.Dimensions}[{i}] " + + $"is expected to either be -1 or {tensor.Dims[i]}"); + } - if (nodeMetaDict.Count == 1) + // element type for Float16 and BFloat16 in the loaded tensor would always be uint16, so + // we want to use element type from metadata + if (protoDt == TensorElementType.String) + return CreateNamedOnnxValueFromStringTensor(tensor.StringData, nodeName, intDims); + + return CreateNamedOnnxValueFromTensorRawData(nodeName, tensor.RawData.ToArray(), metaElementType, intDims); + } + + internal static NamedOnnxValue CreateNamedOnnxValueFromTensorRawData(string nodeName, byte[] rawData, TensorElementType elementType, int[] intDims) + { + switch (elementType) { - nodeMeta = nodeMetaDict.Values.First(); - nodeName = nodeMetaDict.Keys.First(); // valid for single node input + case TensorElementType.Float: + return CreateNamedOnnxValueFromRawData(nodeName, rawData, sizeof(float), intDims); + case TensorElementType.Double: + return CreateNamedOnnxValueFromRawData(nodeName, rawData, sizeof(double), intDims); + case TensorElementType.Int32: + return CreateNamedOnnxValueFromRawData(nodeName, rawData, sizeof(int), intDims); + case TensorElementType.UInt32: + return CreateNamedOnnxValueFromRawData(nodeName, rawData, sizeof(uint), intDims); + case TensorElementType.Int16: + return CreateNamedOnnxValueFromRawData(nodeName, rawData, sizeof(short), intDims); + case TensorElementType.UInt16: + return CreateNamedOnnxValueFromRawData(nodeName, rawData, sizeof(ushort), intDims); + case TensorElementType.Int64: + return CreateNamedOnnxValueFromRawData(nodeName, rawData, sizeof(long), intDims); + case TensorElementType.UInt64: + return CreateNamedOnnxValueFromRawData(nodeName, rawData, sizeof(ulong), intDims); + case TensorElementType.UInt8: + return CreateNamedOnnxValueFromRawData(nodeName, rawData, sizeof(byte), intDims); + case TensorElementType.Int8: + return CreateNamedOnnxValueFromRawData(nodeName, rawData, sizeof(sbyte), intDims); + case TensorElementType.Bool: + return CreateNamedOnnxValueFromRawData(nodeName, rawData, sizeof(bool), intDims); + case TensorElementType.Float16: + return CreateNamedOnnxValueFromRawData(nodeName, rawData, sizeof(ushort), intDims); + case TensorElementType.BFloat16: + return CreateNamedOnnxValueFromRawData(nodeName, rawData, sizeof(ushort), intDims); + case TensorElementType.String: + throw new ArgumentException("For string tensors of type use: CreateNamedOnnxValueFromStringTensor."); + default: + throw new NotImplementedException($"Tensors of type: {elementType} not currently supported by this function"); } - else if (nodeMetaDict.Count > 1) + } + + internal static NamedOnnxValue LoadTensorFromEmbeddedResourcePb(string path, string nodeName, NodeMetadata nodeMeta) + { + Onnx.TensorProto tensor = null; + + var assembly = typeof(TestDataLoader).Assembly; + + using (Stream stream = assembly.GetManifestResourceStream($"{assembly.GetName().Name}.TestData.{path}")) { - if (tensor.Name.Length > 0) - { - nodeMeta = nodeMetaDict[tensor.Name]; - nodeName = tensor.Name; - } - else + tensor = Onnx.TensorProto.Parser.ParseFrom(stream); + } + + return LoadTensorPb(tensor, nodeName, nodeMeta); + } + + internal static NamedOnnxValue LoadOnnxValueFromFilePb(string fullFilename, string nodeName, NodeMetadata nodeMeta) + { + // No sparse tensor support yet + //Set buffer size to 4MB + int readBufferSize = 4194304; + using (var file = new FileStream(fullFilename, FileMode.Open, FileAccess.Read, FileShare.Read, readBufferSize)) + { + switch (nodeMeta.OnnxValueType) { - bool matchfound = false; - // try to find from matching type and shape - foreach (var key in nodeMetaDict.Keys) - { - var meta = nodeMetaDict[key]; - if (tensorElemType == meta.ElementType && tensor.Dims.Count == meta.Dimensions.Length) + case OnnxValueType.ONNX_TYPE_TENSOR: { - int i = 0; - for (; i < meta.Dimensions.Length; i++) - { - if (meta.Dimensions[i] != -1 && meta.Dimensions[i] != intDims[i]) - { - break; - } - } - if (i >= meta.Dimensions.Length) - { - matchfound = true; - nodeMeta = meta; - nodeName = key; - break; - } + var tensor = Onnx.TensorProto.Parser.ParseFrom(file); + return LoadTensorPb(tensor, nodeName, nodeMeta); } - } - if (!matchfound) - { - // throw error - throw new Exception($"No Matching Tensor found in InputOutputMetadata corresponding to the serialized tensor specified"); - } + case OnnxValueType.ONNX_TYPE_SEQUENCE: + { + var sequence = Onnx.SequenceProto.Parser.ParseFrom(file); + return CreateNamedOnnxValueFromSequence(sequence, nodeName, nodeMeta); + } + case OnnxValueType.ONNX_TYPE_MAP: + { + var map = Onnx.MapProto.Parser.ParseFrom(file); + return CreateNamedOnnxValueFromMap(map, nodeName, nodeMeta); + } + + case OnnxValueType.ONNX_TYPE_OPTIONAL: + { + var opt = Onnx.OptionalProto.Parser.ParseFrom(file); + return CreateNamedOnnxValueFromOptional(opt, nodeName, nodeMeta); + } + default: + throw new NotImplementedException($"Unable to load value type: {nodeMeta.OnnxValueType} not implemented"); } } - else - { - // throw error - throw new Exception($"While reading the serliazed tensor specified, metaDataDict has 0 elements"); - } + } - if (!nodeMeta.IsTensor) - throw new Exception("LoadTensorFromFile can load Tensor types only"); + private static void SequenceCheckMatchOnnxType(string nodeName, SequenceMetadata meta, + OnnxValueType onnxType) + { + if (meta.ElementMeta.OnnxValueType == onnxType) + return; - if (tensorElemType != nodeMeta.ElementType) - throw new Exception($"{nameof(tensorElemType)} is expected to be equal to {nameof(nodeMeta.ElementType)}"); + throw new InvalidDataException($"Sequence node: '{nodeName}' " + + $"has element type: '{onnxType}'" + + $" expected: '{meta.ElementMeta.OnnxValueType}'"); + } - if (nodeMeta.Dimensions.Length != tensor.Dims.Count) - throw new Exception($"{nameof(nodeMeta.Dimensions.Length)} is expected to be equal to {nameof(tensor.Dims.Count)}"); + private static string MakeSequenceElementName(string nodeName, string seqName, int seqNum) + { + if (seqName.Length > 0) + return $"seq.{nodeName}.data.{seqName}.{seqNum}"; + else + return $"seq.{nodeName}.data._.{seqNum}"; + } - for (int i = 0; i < nodeMeta.Dimensions.Length; i++) - { - if ((nodeMeta.Dimensions[i] != -1) && (nodeMeta.Dimensions[i] != intDims[i])) - throw new Exception($"{nameof(nodeMeta.Dimensions)}[{i}] is expected to either be -1 or {nameof(intDims)}[{i}]"); - } + internal static NamedOnnxValue CreateNamedOnnxValueFromSequence(Onnx.SequenceProto sequence, string nodeName, NodeMetadata nodeMeta) + { + var sequenceMeta = nodeMeta.AsSequenceMetadata(); + var elemMeta = sequenceMeta.ElementMeta; - if (nodeMeta.ElementType == typeof(float)) - { - return CreateNamedOnnxValueFromRawData(nodeName, tensor.RawData.ToArray(), sizeof(float), intDims); - } - else if (nodeMeta.ElementType == typeof(double)) - { - return CreateNamedOnnxValueFromRawData(nodeName, tensor.RawData.ToArray(), sizeof(double), intDims); - } - else if (nodeMeta.ElementType == typeof(int)) - { - return CreateNamedOnnxValueFromRawData(nodeName, tensor.RawData.ToArray(), sizeof(int), intDims); - } - else if (nodeMeta.ElementType == typeof(uint)) - { - return CreateNamedOnnxValueFromRawData(nodeName, tensor.RawData.ToArray(), sizeof(uint), intDims); - } - else if (nodeMeta.ElementType == typeof(long)) - { - return CreateNamedOnnxValueFromRawData(nodeName, tensor.RawData.ToArray(), sizeof(long), intDims); - } - else if (nodeMeta.ElementType == typeof(ulong)) + int seqNum = 0; + var seqElemType = (Onnx.SequenceProto.Types.DataType)sequence.ElemType; + switch (seqElemType) { - return CreateNamedOnnxValueFromRawData(nodeName, tensor.RawData.ToArray(), sizeof(ulong), intDims); - } - else if (nodeMeta.ElementType == typeof(short)) - { - return CreateNamedOnnxValueFromRawData(nodeName, tensor.RawData.ToArray(), sizeof(short), intDims); - } - else if (nodeMeta.ElementType == typeof(ushort)) - { - return CreateNamedOnnxValueFromRawData(nodeName, tensor.RawData.ToArray(), sizeof(ushort), intDims); - } - else if (nodeMeta.ElementType == typeof(byte)) - { - return CreateNamedOnnxValueFromRawData(nodeName, tensor.RawData.ToArray(), sizeof(byte), intDims); - } - else if (nodeMeta.ElementType == typeof(sbyte)) - { - return CreateNamedOnnxValueFromRawData(nodeName, tensor.RawData.ToArray(), sizeof(sbyte), intDims); - } - else if (nodeMeta.ElementType == typeof(bool)) - { - return CreateNamedOnnxValueFromRawData(nodeName, tensor.RawData.ToArray(), sizeof(bool), intDims); + case Onnx.SequenceProto.Types.DataType.Tensor: + { + SequenceCheckMatchOnnxType(nodeName, sequenceMeta, OnnxValueType.ONNX_TYPE_TENSOR); + var sequenceOfTensors = new List(sequence.TensorValues.Count); + foreach (var tensor in sequence.TensorValues) + { + var elemName = MakeSequenceElementName(nodeName, sequence.Name, seqNum++); + var namedOnnxValue = LoadTensorPb(tensor, elemName, elemMeta); + sequenceOfTensors.Add(namedOnnxValue); + } + return NamedOnnxValue.CreateFromSequence(nodeName, sequenceOfTensors); + } + case Onnx.SequenceProto.Types.DataType.Sequence: + { + SequenceCheckMatchOnnxType(nodeName, sequenceMeta, OnnxValueType.ONNX_TYPE_SEQUENCE); + var seqOfSequences = new List(sequence.SequenceValues.Count); + foreach (var s in sequence.SequenceValues) + { + var elemName = MakeSequenceElementName(nodeName, sequence.Name, seqNum++); + seqOfSequences.Add(CreateNamedOnnxValueFromSequence(s, elemName, elemMeta)); + } + return NamedOnnxValue.CreateFromSequence(nodeName, seqOfSequences); + } + case Onnx.SequenceProto.Types.DataType.Map: + { + SequenceCheckMatchOnnxType(nodeName, sequenceMeta, OnnxValueType.ONNX_TYPE_MAP); + var seqOfMaps = new List(sequence.MapValues.Count); + foreach (var m in sequence.MapValues) + { + var elemName = MakeSequenceElementName(nodeName, sequence.Name, seqNum++); + seqOfMaps.Add(CreateNamedOnnxValueFromMap(m, elemName, elemMeta)); + } + return NamedOnnxValue.CreateFromSequence(nodeName, seqOfMaps); + } + case Onnx.SequenceProto.Types.DataType.Optional: + { + SequenceCheckMatchOnnxType(nodeName, sequenceMeta, OnnxValueType.ONNX_TYPE_OPTIONAL); + var seqOfOpts = new List(sequence.OptionalValues.Count); + foreach (var opt in sequence.OptionalValues) + { + var elemName = MakeSequenceElementName(nodeName, sequence.Name, seqNum++); + seqOfOpts.Add(CreateNamedOnnxValueFromOptional(opt, elemName, elemMeta)); + } + return NamedOnnxValue.CreateFromSequence(nodeName, seqOfOpts); + } + default: + throw new NotImplementedException($"Sequence test data loading does not support element type: " + + $"'{seqElemType}'"); } - else if (nodeMeta.ElementType == typeof(Float16)) + + } + + internal static NamedOnnxValue CastAndCreateFromMapKeys(string name, TensorElementType elementType, IList keys) + { + switch (elementType) { - return CreateNamedOnnxValueFromRawData(nodeName, tensor.RawData.ToArray(), sizeof(ushort), intDims); + case TensorElementType.Float: + return CastAndCreateTensor(name, keys); + case TensorElementType.Double: + return CastAndCreateTensor(name, keys); + case TensorElementType.Int32: + return CastAndCreateTensor(name, keys); + case TensorElementType.UInt32: + return CastAndCreateTensor(name, keys); + case TensorElementType.Int16: + return CastAndCreateTensor(name, keys); + case TensorElementType.UInt16: + return CastAndCreateTensor(name, keys); + case TensorElementType.Int64: + return CastAndCreateTensor(name, keys); + case TensorElementType.UInt64: + return CastAndCreateTensor(name, keys); + case TensorElementType.UInt8: + return CastAndCreateTensor(name, keys); + case TensorElementType.Int8: + return CastAndCreateTensor(name, keys); + case TensorElementType.Bool: + return CastAndCreateTensor(name, keys); + case TensorElementType.Float16: + return CastAndCreateTensor(name, keys); + case TensorElementType.BFloat16: + return CastAndCreateTensor(name, keys); + default: + throw new NotImplementedException($"Tensors of type: " + elementType.ToString() + + " not currently supported here, use: CreateNamedOnnxValueFromStringTensor."); } - else if (nodeMeta.ElementType == typeof(BFloat16)) + } + + /// + /// All the keys in maps are stored as an array of longs, so + /// to create a real tensor we need to cast to create a continuous buffer + /// essentially packing it as a raw data. + /// + /// + /// + /// + /// + /// + /// + internal static NamedOnnxValue CastAndCreateTensor(string name, IList elements) + { + // Create raw data + T[] castKeys = new T[elements.Count]; + if (typeof(T) == typeof(Float16) || typeof(T) == typeof(BFloat16)) { - return CreateNamedOnnxValueFromRawData(nodeName, tensor.RawData.ToArray(), sizeof(ushort), intDims); + for (int i = 0; i < elements.Count; i++) + { + var obj = Convert.ChangeType(elements[i], typeof(ushort)); + if (obj == null) + { + throw new InvalidDataException($"Conversion from long to {typeof(T)} failed"); + } + castKeys[i] = (T)obj; + } } else { - //TODO: Add support for remaining types - throw new Exception($"Tensors of type {nameof(nodeMeta.ElementType)} not currently supported in the LoadTensorFromEmbeddedResource"); + for (int i = 0; i < elements.Count; i++) + { + var obj = (T)Convert.ChangeType(elements[i], typeof(T)); + if (obj == null) + { + throw new InvalidDataException($"Conversion from long to {typeof(T)} failed"); + } + castKeys[i] = (T)obj; + } } + var tensor = new DenseTensor(castKeys, new int[] { elements.Count }); + return NamedOnnxValue.CreateFromTensor(name, tensor); } - internal static NamedOnnxValue LoadTensorFromEmbeddedResourcePb(string path, IReadOnlyDictionary nodeMetaDict) + internal static NamedOnnxValue CreateNamedOnnxValueFromMap(Onnx.MapProto map, string nodeName, NodeMetadata nodeMetadata) { - Onnx.TensorProto tensor = null; + // See GH issue https://github.com/onnx/onnx/issues/5072 + throw new NotImplementedException($"Loading map node: '{nodeName}' not implemented yet"); - var assembly = typeof(TestDataLoader).Assembly; + //var mapMeta = nodeMetadata.AsMapMetadata(); - using (Stream stream = assembly.GetManifestResourceStream($"{assembly.GetName().Name}.TestData.{path}")) - { - tensor = Onnx.TensorProto.Parser.ParseFrom(stream); - } + //if ((TensorElementType)map.KeyType != mapMeta.KeyDataType) + //{ + // throw new InvalidDataException($"Node: '{nodeName}' map key type expected: " + + // $"'{mapMeta.KeyDataType}', loaded from test data: '{(TensorElementType)map.KeyType}'"); + //} + + //// temp non-generic(!) container + //NamedOnnxValue keysTensor; + //if (mapMeta.KeyDataType == TensorElementType.String) + //{ + // keysTensor = CreateNamedOnnxValueFromStringTensor(map.StringKeys, nodeName, new int[] { map.StringKeys.Count }); + //} + //else + //{ + // keysTensor = CastAndCreateFromMapKeys(nodeName, mapMeta.KeyDataType, map.Keys); + //} + + //switch ((Onnx.SequenceProto.Types.DataType)map.Values.ElemType) + //{ + // case Onnx.SequenceProto.Types.DataType.Tensor: + // var tensorCount = map.Values.TensorValues.Count; + // break; + // default: + // throw new NotImplementedException("Does not support map value type other than a tensor"); + //} - return LoadTensorPb(tensor, nodeMetaDict); + //return new NamedOnnxValue(string.Empty, new Object(), OnnxValueType.ONNX_TYPE_UNKNOWN); } - internal static NamedOnnxValue LoadTensorFromFilePb(string filename, IReadOnlyDictionary nodeMetaDict) + internal static NamedOnnxValue CreateNamedOnnxValueFromOptional(Onnx.OptionalProto optional, string nodeName, NodeMetadata nodeMetadata) { - //Set buffer size to 4MB - int readBufferSize = 4194304; - Onnx.TensorProto tensor = null; - using (var file = new FileStream(filename, FileMode.Open, FileAccess.Read, FileShare.Read, readBufferSize)) + var meta = nodeMetadata.AsOptionalMetadata().ElementMeta; + switch((Onnx.OptionalProto.Types.DataType)optional.ElemType) { - tensor = Onnx.TensorProto.Parser.ParseFrom(file); + case Onnx.OptionalProto.Types.DataType.Tensor: + { + var tensor = optional.TensorValue; + return LoadTensorPb(tensor, nodeName, meta); + } + case Onnx.OptionalProto.Types.DataType.Sequence: + { + var sequence = optional.SequenceValue; + return CreateNamedOnnxValueFromSequence(sequence, nodeName, meta); + } + case Onnx.OptionalProto.Types.DataType.Map: + { + var map = optional.MapValue; + return CreateNamedOnnxValueFromMap(map, nodeName, meta); + } + case Onnx.OptionalProto.Types.DataType.Optional: + throw new NotImplementedException($"Unable to load '{nodeName}' optional contained within optional"); + default: + // Test data contains OptionalProto with the contained element type undefined. + // the premise is, if the element is not fed as an input, we should not care + // what Onnx type it is. However, we do not need to support AFAIK such inputs + // since the value for them could never be supplied. + throw new NotImplementedException($"Unable to load '{nodeName}' optional element type of: {(Onnx.OptionalProto.Types.DataType)optional.ElemType} type"); } - - return LoadTensorPb(tensor, nodeMetaDict); } internal static NamedOnnxValue CreateNamedOnnxValueFromRawData(string name, byte[] rawData, int elemWidth, int[] dimensions) @@ -250,6 +430,19 @@ internal static NamedOnnxValue CreateNamedOnnxValueFromRawData(string name, b return NamedOnnxValue.CreateFromTensor(name, dt); } + internal static NamedOnnxValue CreateNamedOnnxValueFromStringTensor(IList strings, + string nodeName, int[] dimensions) + { + string[] strArray = new string[strings.Count]; + for (int i = 0; i < strings.Count; ++i) + { + strArray[i] = System.Text.Encoding.UTF8.GetString(strings[i].ToByteArray()); + } + + var dt = new DenseTensor(strArray, dimensions); + return NamedOnnxValue.CreateFromTensor(nodeName, dt); + } + internal static float[] LoadTensorFromFile(string filename, bool skipheader = true) { var tensorData = new List(); diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp/InferenceTest.netcore.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp/InferenceTest.netcore.cs index 61ff2a43ffd40..76518e341f935 100644 --- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp/InferenceTest.netcore.cs +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp/InferenceTest.netcore.cs @@ -3,32 +3,33 @@ using System.IO; using System.Linq; using System.Runtime.InteropServices; +using System.Text.RegularExpressions; using Microsoft.ML.OnnxRuntime.Tensors; using Xunit; namespace Microsoft.ML.OnnxRuntime.Tests { - /// - /// This is compensate for the absence of string.Contains() in .NET Standard 2.0 - /// Contains(String, StringComparison) - /// - public static class StringExtensions - { - public static bool Contains(this String str, String substring, - StringComparison comp) + /// + /// This is compensate for the absence of string.Contains() in .NET Standard 2.0 + /// Contains(String, StringComparison) + /// + public static class StringExtensions { - if (substring == null) - throw new ArgumentNullException("substring", - "substring cannot be null."); - else if (!Enum.IsDefined(typeof(StringComparison), comp)) - throw new ArgumentException("comp is not a member of StringComparison", - "comp"); - - return str.IndexOf(substring, comp) >= 0; + public static bool Contains(this String str, String substring, + StringComparison comp) + { + if (substring == null) + throw new ArgumentNullException("substring", + "substring cannot be null."); + else if (!Enum.IsDefined(typeof(StringComparison), comp)) + throw new ArgumentException("comp is not a member of StringComparison", + "comp"); + + return str.IndexOf(substring, comp) >= 0; + } } - } - public partial class InferenceTest - { + public partial class InferenceTest + { private const string module = "onnxruntime.dll"; private const string propertiesFile = "Properties.txt"; @@ -40,8 +41,8 @@ public void CanCreateAndDisposeSessionWithModelPath() { Assert.NotNull(session); Assert.NotNull(session.InputMetadata); - Assert.Equal(1, session.InputMetadata.Count); // 1 input node - Assert.True(session.InputMetadata.ContainsKey("data_0")); // input node name + Assert.Equal(1, session.InputMetadata.Count); // 1 input nodeMeta + Assert.True(session.InputMetadata.ContainsKey("data_0")); // input nodeMeta name Assert.Equal(typeof(float), session.InputMetadata["data_0"].ElementType); Assert.True(session.InputMetadata["data_0"].IsTensor); var expectedInputDimensions = new int[] { 1, 3, 224, 224 }; @@ -52,8 +53,8 @@ public void CanCreateAndDisposeSessionWithModelPath() } Assert.NotNull(session.OutputMetadata); - Assert.Equal(1, session.OutputMetadata.Count); // 1 output node - Assert.True(session.OutputMetadata.ContainsKey("softmaxout_1")); // output node name + Assert.Equal(1, session.OutputMetadata.Count); // 1 output nodeMeta + Assert.True(session.OutputMetadata.ContainsKey("softmaxout_1")); // output nodeMeta name Assert.Equal(typeof(float), session.OutputMetadata["softmaxout_1"].ElementType); Assert.True(session.OutputMetadata["softmaxout_1"].IsTensor); var expectedOutputDimensions = new int[] { 1, 1000, 1, 1 }; @@ -246,9 +247,9 @@ private static Dictionary GetSkippedModels(DirectoryInfo modelsD { "fp16_test_tiny_yolov2", "ImageScaler is not a registered function/op"}, { "fp16_coreml_FNS-Candy", "ImageScaler is not a registered function/op" }, { "fp16_coreml_LinearRegression_NYCTaxi", "Error in Node:featureVectorizer : No Op registered for FeatureVectorizer with domain_version of 1"}, - { "test_bidaf", "Does not run in opset9, runs in other opsets. The model runs but I don't have a data set to debug output locally. Tensors of type ElementType not currently supported in the LoadTensorFromFile." }, { "test_mnist", "Does not run in opset9, runs in other opsets. The model runs but I don't have a data set to debug output locally. Tensors of type ElementType not currently supported in the LoadTensorFromFile" }, - { "BERT_Squad", "Could not find an implementation for the node bert / embeddings / one_hot:OneHot(9)" }, + { "BERT_Squad", "Could not find an implementation for the nodeMeta bert / embeddings / one_hot:OneHot(9)" }, + { "mlperf_ssd_mobilenet_300", "Could not find file output_0.pb" }, { "tf_resnet_v1_50", "result mismatch when Conv BN Fusion is applied" }, { "tf_resnet_v1_101", "result mismatch when Conv BN Fusion is applied" }, @@ -256,108 +257,85 @@ private static Dictionary GetSkippedModels(DirectoryInfo modelsD { "cntk_simple_seg", "Bad onnx test output caused by wrong SAME_UPPER/SAME_LOWER for ConvTranspose" }, { "coreml_Imputer-LogisticRegression_sklearn_load_breast_cancer", "Can't determine model file name" }, { "mask_rcnn_keras", "Model should be edited to remove the extra outputs" }, - { "test_strnormalizer_export_monday_casesensintive_lower", "ElementType not currently supported"}, - { "test_max_float64", "node test error"}, - { "test_min_uint8", "node test error"}, - { "test_mod_mixed_sign_float64", "node test error"}, - { "test_momentum", "node test error"}, - { "test_max_uint16", "node test error"}, - { "test_resize_downsample_scales_linear_align_corners", "node test error"}, - { "test_strnormalizer_nostopwords_nochangecase", "node test error"}, - { "test_adagrad_multiple", "node test error"}, - { "test_einsum_inner_prod", "node test error"}, - { "test_sequence_insert_at_back", "node test error"}, - { "test_mod_mixed_sign_int8", "node test error"}, - { "test_maxunpool_export_with_output_shape", "node test error"}, - { "test_strnormalizer_export_monday_empty_output", "node test error"}, - { "test_strnormalizer_export_monday_insensintive_upper_twodim", "ElementType not currently supported"}, - { "test_min_int16", "node test error"}, - { "test_adagrad", "node test error"}, - { "test_min_float64", "node test error"}, - { "test_max_int16", "node test error"}, - { "test_sequence_insert_at_front", "node test error"}, - { "test_training_dropout_default", "node test error"}, - { "test_training_dropout", "node test error"}, - { "test_adam", "node test error"}, - { "test_training_dropout_mask", "node test error"}, - { "test_clip_default_int8_inbounds", "node test error"}, - { "test_eyelike_with_dtype", "node test error"}, - { "test_cast_STRING_to_FLOAT", "node test error"}, - { "test_cast_FLOAT16_to_DOUBLE", "node test error"}, - { "test_cast_FLOAT_to_DOUBLE", "node test error"}, - { "test_cast_BFLOAT16_to_FLOAT", "node test error"}, - { "test_cast_FLOAT_to_BFLOAT16", "node test error"}, - { "test_cast_FLOAT_to_STRING", "node test error"}, - { "test_castlike_STRING_to_FLOAT", "node test error"}, - { "test_castlike_STRING_to_FLOAT_expanded", "node test error"}, - { "test_castlike_FLOAT16_to_DOUBLE", "node test error"}, - { "test_castlike_FLOAT16_to_DOUBLE_expanded", "node test error"}, - { "test_castlike_FLOAT_to_DOUBLE", "node test error"}, - { "test_castlike_FLOAT_to_DOUBLE_expanded", "node test error"}, - { "test_castlike_BFLOAT16_to_FLOAT", "node test error"}, - { "test_castlike_BFLOAT16_to_FLOAT_expanded", "node test error"}, - { "test_castlike_FLOAT_to_BFLOAT16", "node test error"}, - { "test_castlike_FLOAT_to_BFLOAT16_expanded", "node test error"}, - { "test_castlike_FLOAT_to_STRING", "node test error"}, - { "test_castlike_FLOAT_to_STRING_expanded", "node test error"}, - { "test_bitshift_right_uint16", "node test error"}, - { "test_bitshift_left_uint16", "node test error"}, - { "test_pow_types_float32_uint64", "node test error"}, - { "test_max_uint8", "node test error"}, - { "test_strnormalizer_export_monday_casesensintive_nochangecase", "ElementType not currently supported"}, - { "test_momentum_multiple", "node test error"}, - { "test_pow_types_float32_uint32", "node test error"}, - { "test_if_seq", "sequence type is not supported in test infra."}, - { "test_resize_downsample_scales_cubic_align_corners", "node test error"}, - { "test_einsum_batch_matmul", "node test error"}, - { "test_nesterov_momentum", "node test error"}, - { "test_strnormalizer_export_monday_casesensintive_upper", "node test error"}, - { "test_min_uint16", "node test error"}, - { "test_adam_multiple", "node test error"}, - { "test_loop13_seq", "sequence type is not supported in test infra." }, - { "test_training_dropout_default_mask", "node test error"}, - { "test_min_int8", "node test error"}, - { "test_identity_sequence", "data type not supported"}, + + { "test_maxunpool_export_with_output_shape", "results mismatch"}, + + { "test_min_int8", "Could not find an implementation for Min(13) node with name"}, + { "test_min_uint8", "Could not find an implementation for Min(13) node with name"}, + { "test_min_int16", "Could not find an implementation for Min(13) node with name"}, + { "test_min_uint16", "Could not find an implementation for Min(13) node with name"}, + + { "test_max_int8", "Could not find an implementation for Max(13) node with name"}, + { "test_max_uint8", "Could not find an implementation for Max(13) node with name"}, + { "test_max_int16", "Could not find an implementation for Max(13) node with name"}, + { "test_max_uint16", "Could not find an implementation for Max(13) nodeMeta with name '"}, + + { "test_mul_uint8", "Could not find an implementation for Mul(14) node with name" }, + + { "test_cast_STRING_to_FLOAT", "Output mismatch"}, + { "test_cast_BFLOAT16_to_FLOAT", "Output mismatch"}, + { "test_cast_FLOAT_to_STRING", "Output strings can not be compared exactly"}, + { "test_castlike_STRING_to_FLOAT", "Output mismatch"}, + { "test_castlike_STRING_to_FLOAT_expanded", "Output mismatch"}, + { "test_castlike_BFLOAT16_to_FLOAT", "Length is expected to be equal to Count (metadata and expected data mismatch) "}, + { "test_castlike_BFLOAT16_to_FLOAT_expanded", "Length is expected to be equal to Count metadata and expected data mismatch"}, + { "test_castlike_FLOAT_to_BFLOAT16", "Length is expected to be equal to Count. Testdata dims length do not match that of model metadata"}, + { "test_castlike_FLOAT_to_BFLOAT16_expanded", "Length is expected to be equal to Count"}, + { "test_castlike_FLOAT_to_STRING", "string comparison does not match due to float rounding"}, + { "test_castlike_FLOAT_to_STRING_expanded", "string comparison does not match due to float rounding"}, + + { "test_bitshift_right_uint16", "Could not find an implementation for BitShift(11) nodeMeta with name ''"}, + { "test_bitshift_left_uint16", "Could not find an implementation for BitShift(11)"}, + + { "test_pow_types_float32_uint64", "Could not find an implementation for Pow(15) node with name ''"}, + { "test_pow_types_float32_uint32", "Could not find an implementation for Pow(15) node with name ''"}, + + { "test_resize_downsample_scales_cubic_align_corners", "Results mismatch"}, + { "test_resize_downsample_scales_linear_align_corners", "Results mismatch"}, + { "test_gru_batchwise", "batchwise operations not supported"}, - { "test_lstm_batchwise", "batchwise operations not supported"}, + { "test_lstm_batchwise", "Batchwise recurrent operations(layout == 1) are not supported.If you need support create a github issue with justification."}, { "test_simple_rnn_batchwise", "batchwise operations not supported"}, { "test_batchnorm_example_training_mode", "opset14 version not implemented yet"}, - { "test_bernoulli", "random generator"}, - { "test_bernoulli_seed", "random generator"}, - { "test_bernoulli_double", "random generator"}, - { "test_bernoulli_expanded", "random generator"}, - { "test_bernoulli_seed_expanded", "random generator"}, - { "test_bernoulli_double_expanded", "random generator"}, - { "test_shape", "opset15 version not implemented yet"}, - { "test_optional_get_element", "optional type is not supported in test infra."}, - { "test_optional_get_element_sequence", "optional type is not supported in test infra."}, - { "test_identity_opt", "optional type is not supported in test infra." }, - { "test_if_opt", "optional type is not supported in test infra." }, - { "test_loop16_seq_none", "sequence type is not supported in test infra." }, - { "test_sequence_map_extract_shapes", "sequence type is not supported in test infra." }, - { "test_sequence_map_identity_1_sequence_1_tensor", "sequence type is not supported in test infra." }, - { "test_sequence_map_identity_1_sequence_1_tensor_expanded", "sequence type is not supported in test infra." }, - { "test_sequence_map_add_1_sequence_1_tensor", "sequence type is not supported in test infra." }, - { "test_sequence_map_identity_1_sequence_expanded", "sequence type is not supported in test infra." }, - { "test_sequence_map_identity_2_sequences", "sequence type is not supported in test infra." }, - { "test_sequence_map_add_2_sequences_expanded", "sequence type is not supported in test infra." }, - { "test_sequence_map_identity_2_sequences_expanded", "sequence type is not supported in test infra." }, - { "test_sequence_map_extract_shapes_expanded", "sequence type is not supported in test infra." }, - { "test_sequence_map_add_1_sequence_1_tensor_expanded", "sequence type is not supported in test infra." }, - { "test_sequence_map_add_2_sequences", "sequence type is not supported in test infra." }, - { "test_sequence_map_identity_1_sequence", "sequence type is not supported in test infra." }, - { "BERT-Squad-int8", "training domain"}, - { "YOLOv3-12-int8", "training_domain"}, + + { "test_bernoulli", "random generator, results mismatch"}, + { "test_bernoulli_seed", "random generator, results mismatch"}, + { "test_bernoulli_double", "random generator, results mismatch"}, + { "test_bernoulli_expanded", "random generator, results mismatch"}, + { "test_bernoulli_seed_expanded", "random generator, results mismatch"}, + { "test_bernoulli_double_expanded", "random generator, results mismatch"}, + // the expansion of Softplus uses Exp(1). ORT has a Softplus kernel, so testing the expansion is // unnecessary and fails as ORT support for Exp started at opset 6 (as ORT didn't exist until opset 7). - { "test_softplus_example_expanded", "Not applicable"}, - { "test_softplus_expanded", "Not applicable"}, - { "test_col2im_pads", "due to a typo in test data"}, - { "test_optional_has_element_empty_optional_input", "C# API doesn't support optional input"}, - { "test_optional_get_element_optional_tensor", "C# API doesn't support optional input"}, - { "test_optional_get_element_optional_sequence", "C# API doesn't support optional input"}, - { "test_optional_has_element_tensor_input", "C# API doesn't support optional input"}, - { "test_optional_has_element_optional_input", "C# API doesn't support optional input"}, + + { "test_clip_default_int8_max_expanded", "Could not find an implementation for Less(13) nodeMeta with name ''" }, + { "test_softplus_expanded", "Could not find an implementation for Exp(1) node with name ''"}, + { "test_softplus_example_expanded", "Could not find an implementation for Exp(1) node with name ''"}, + { "test_div_uint8", "Could not find an implementation for Div(14) nodeMeta with name ''"}, + { "test_add_uint8", "Opset18 Could not find an implementation for Add(14) nodeMeta with name ''"}, + { "test_col2im_pads", "Results mismatch due to a typo in test data"}, + + { "test_optional_has_element_empty_optional_input", "OptionalProto test metadata. Unable to load 'optional_input' optional element type of: Undefined type"}, + { "test_loop13_seq", "3rd input is an empty sequence. Ort API does not tolerate empty seq: Number of values should be at least 1" }, + + // Training tests + { "BERT-Squad-int8", "training domain"}, + { "YOLOv3-12-int8", "training_domain"}, + + { "test_training_dropout_default", "results mismatch"}, + { "test_training_dropout_default_mask", "Results mismatch"}, + { "test_training_dropout", "results mismatch"}, + { "test_training_dropout_mask", "results mismatch."}, + + { "test_momentum", "ai.onnx.preview.training:Momentum(-1) is not a registered function/op"}, + { "test_momentum_multiple", "ai.onnx.preview.training:Momentum(-1) is not a registered function/op"}, + { "test_nesterov_momentum", "ai.onnx.preview.training:Momentum(-1) is not a registered function/op"}, + + { "test_adam", "ai.onnx.preview.training:Adam(-1) is not a registered function/op"}, + { "test_adam_multiple", "ai.onnx.preview.training:Adam(-1) is not a registered function/op"}, + + { "test_adagrad", "ai.onnx.preview.training:Adagrad(-1) is not a registered function/op"}, + { "test_adagrad_multiple", "ai.onnx.preview.training:Adagrad(-1) is not a registered function/op"}, }; // The following models fails on nocontribops win CI @@ -460,6 +438,110 @@ public static IEnumerable GetSkippedModelForTest() } } + private string MatchInputOutputWithFile(string fileName, InferenceSession session, bool input, out NodeMetadata result) + { + string nodeName = string.Empty; + result = null; + var names = (input) ? session.InputNames : session.OutputNames; + var metadata = (input) ? session.InputMetadata : session.OutputMetadata; + string regEx = (input) ? @"input_(\d{1,}).pb" : @"output_(\d{1,}).pb"; + var inpOut = (input) ? "input" : "output"; + + // Extract the number from the file name, if not try to match the input/output name with the name of the file. + try + { + // captures start at index 1 + var group = Regex.Matches(fileName, regEx).Single().Groups[1]; + var num = int.Parse(group.Value); + if (num >= 0 && num < names.Count) + { + nodeName = names[num]; + result = metadata[nodeName]; + } + else + { + throw new InvalidDataException($"Filename '{fileName}' {inpOut} number '{num}' is out of range for '{names.Count}' {inpOut}(s)"); + } + } + catch (Exception) + { + // Either does not match or can not parse the number + } + + if (result is null) + { + throw new InvalidDataException($"Unable to match file: {fileName} to input/output metadata"); + } + return nodeName; + } + + // The numbering of the input files does not match the order of outputs + // listed in the metadata of test_BERT_Squad. Model metadata order: + // "unique_ids_raw_output___9:0", "segment_ids:0", "input_mask:0", "input_ids:0" + // The corr input files are: input_0.pb, input_3.pb, input_2.pb, input_1.pb + // Everything in reverse, but the 0. + + // Previously, it worked because our test data has matching + // tensor names that we could match to metadata after we load the tensor. + // But now, we need to know ahead of time what Onnx type we load, and thus match + // metadata with the test data file before loading. Protobuf can happily load whatever + // and give you garbage. + + private string MatchBertSquadInputs(string fileName) + { + string nodeName = string.Empty; + switch (fileName) + { + case "input_0.pb": + nodeName = "unique_ids_raw_output___9:0"; + break; + case "input_1.pb": + nodeName = "input_ids:0"; + break; + case "input_2.pb": + nodeName = "input_mask:0"; + break; + case "input_3.pb": + nodeName = "segment_ids:0"; + break; + default: + throw new InvalidDataException($"Unhandled input file name: '{fileName}' for test_BERT_Squad"); + } + return nodeName; + } + + // The model actually has only 3 outputs, but the Zoo version has 4 files are supplied. + // The numbering of the output files does not match the order of outputs + // listed in the metadata. + + // Previously, it worked because our CI test data version has matching + // tensor names that we could match to metadata after we load the tensor. + // But now, we need to know ahead of time what Onnx type we load, and thus match + // metadata with the test data file before loading. Protobuf can happily load whatever + // and give you garbage. + + // Order in the metadata: unstack:1, unstack:0, unique_ids:0 + // The files are in reverse order + private string MatchBertSquadOutputs(string fileName) + { + string nodeName = string.Empty; + switch (fileName) + { + case "output_0.pb": // Int64 + nodeName = "unique_ids:0"; + break; + case "output_1.pb": + nodeName = "unstack:0"; + break; + case "output_2.pb": + nodeName = "unstack:1"; + break; + default: + throw new InvalidDataException($"Unhandled output file name: '{fileName}' for test_BERT_Squad"); + } + return nodeName; + } + [Theory(DisplayName = "TestPreTrainedModels")] [MemberData(nameof(GetModelsForTest))] [MemberData(nameof(GetSkippedModelForTest), Skip = "Skipped due to Error, please fix the error and enable the test")] @@ -501,6 +583,7 @@ private void TestPreTrainedModels(string opsetDir, string modelName) using (var session = new InferenceSession(onnxModelFileName)) { var inMeta = session.InputMetadata; + var outMeta = session.OutputMetadata; string testDataDirNamePattern = "test_data*"; if (opset == "opset9" && modelName == "LSTM_Seq_lens_unpacked") { @@ -508,15 +591,52 @@ private void TestPreTrainedModels(string opsetDir, string modelName) } foreach (var testDataDir in modelDir.EnumerateDirectories(testDataDirNamePattern)) { - var inputContainer = new List(); - var outputContainer = new List(); + var inputContainer = new List(inMeta.Count); + var outputContainer = new List(outMeta.Count); foreach (var f in testDataDir.EnumerateFiles("input_*.pb")) { - inputContainer.Add(TestDataLoader.LoadTensorFromFilePb(f.FullName, inMeta)); + if (modelName == "keras_prelu_ImageNet_small" && opset == "opset9") + { + // The model has 1 input, match all file names (they are different in each data set) + // to the same input + var nodeName = "p_re_lu_3_input"; + var nodeMeta = inMeta[nodeName]; + inputContainer.Add(TestDataLoader.LoadOnnxValueFromFilePb(f.FullName, nodeName, nodeMeta)); + } + else if (modelName == "test_BERT_Squad" && opset == "opset8") + { + string nodeName = MatchBertSquadInputs(f.Name); + var nodeMeta = inMeta[nodeName]; + inputContainer.Add(TestDataLoader.LoadOnnxValueFromFilePb(f.FullName, nodeName, nodeMeta)); + } + else + { + var nodeName = MatchInputOutputWithFile(f.Name, session, true, out NodeMetadata nodeMeta); + inputContainer.Add(TestDataLoader.LoadOnnxValueFromFilePb(f.FullName, nodeName, nodeMeta)); + } } foreach (var f in testDataDir.EnumerateFiles("output_*.pb")) { - outputContainer.Add(TestDataLoader.LoadTensorFromFilePb(f.FullName, session.OutputMetadata)); + if (modelName == "keras_prelu_ImageNet_small" && opset == "opset9") + { + // The model has 1 output, match all file names (they are different in each data set) + // to the same output + var nodeName = "p_re_lu_3/add:0"; + var nodeMeta = outMeta[nodeName]; + outputContainer.Add(TestDataLoader.LoadOnnxValueFromFilePb(f.FullName, nodeName, nodeMeta)); + } + else if (modelName == "test_BERT_Squad" && opset == "opset8") + { + string nodeName = MatchBertSquadOutputs(f.Name); + var nodeMeta = outMeta[nodeName]; + outputContainer.Add(TestDataLoader.LoadOnnxValueFromFilePb(f.FullName, nodeName, nodeMeta)); + } + else + { + // Otherwise, just match trailing filename number to the input name -> metadata + var nodeName = MatchInputOutputWithFile(f.Name, session, false, out NodeMetadata nodeMeta); + outputContainer.Add(TestDataLoader.LoadOnnxValueFromFilePb(f.FullName, nodeName, nodeMeta)); + } } using (var resultCollection = session.Run(inputContainer)) @@ -524,7 +644,6 @@ private void TestPreTrainedModels(string opsetDir, string modelName) foreach (var result in resultCollection) { Assert.True(session.OutputMetadata.ContainsKey(result.Name)); - var outputMeta = session.OutputMetadata[result.Name]; NamedOnnxValue outputValue = null; foreach (var o in outputContainer) { @@ -534,72 +653,32 @@ private void TestPreTrainedModels(string opsetDir, string modelName) break; } } - if (outputValue == null) - { - outputValue = outputContainer.First(); // in case the output data file does not contain the name - } - if (outputMeta.IsTensor) + + Assert.NotNull(outputValue); + + var outputMeta = session.OutputMetadata[result.Name]; + if (outputMeta.OnnxValueType == OnnxValueType.ONNX_TYPE_OPTIONAL) { - if (outputMeta.ElementType == typeof(float)) - { - Assert.Equal(result.AsTensor(), outputValue.AsTensor(), new FloatComparer()); - } - else if (outputMeta.ElementType == typeof(double)) - { - Assert.Equal(result.AsTensor(), outputValue.AsTensor(), new DoubleComparer()); - } - else if (outputMeta.ElementType == typeof(int)) - { - Assert.Equal(result.AsTensor(), outputValue.AsTensor(), new ExactComparer()); - } - else if (outputMeta.ElementType == typeof(uint)) - { - Assert.Equal(result.AsTensor(), outputValue.AsTensor(), new ExactComparer()); - } - else if (outputMeta.ElementType == typeof(short)) - { - Assert.Equal(result.AsTensor(), outputValue.AsTensor(), new ExactComparer()); - } - else if (outputMeta.ElementType == typeof(ushort)) - { - Assert.Equal(result.AsTensor(), outputValue.AsTensor(), new ExactComparer()); - } - else if (outputMeta.ElementType == typeof(long)) - { - Assert.Equal(result.AsTensor(), outputValue.AsTensor(), new ExactComparer()); - } - else if (outputMeta.ElementType == typeof(ulong)) - { - Assert.Equal(result.AsTensor(), outputValue.AsTensor(), new ExactComparer()); - } - else if (outputMeta.ElementType == typeof(byte)) - { - Assert.Equal(result.AsTensor(), outputValue.AsTensor(), new ExactComparer()); - } - else if (outputMeta.ElementType == typeof(sbyte)) - { - Assert.Equal(result.AsTensor(), outputValue.AsTensor(), new ExactComparer()); - } - else if (outputMeta.ElementType == typeof(bool)) - { - Assert.Equal(result.AsTensor(), outputValue.AsTensor(), new ExactComparer()); - } - else if (outputMeta.ElementType == typeof(Float16)) - { - Assert.Equal(result.AsTensor(), outputValue.AsTensor(), new Float16Comparer { tolerance = 2 }); - } - else if (outputMeta.ElementType == typeof(BFloat16)) - { - Assert.Equal(result.AsTensor(), outputValue.AsTensor(), new BFloat16Comparer { tolerance = 2 }); - } - else - { - Assert.True(false, $"{nameof(TestPreTrainedModels)} does not yet support output of type {outputMeta.ElementType}"); - } + outputMeta = outputMeta.AsOptionalMetadata().ElementMeta; } - else + + Assert.Equal(outputValue.ValueType, outputMeta.OnnxValueType); + + switch (outputValue.ValueType) { - Assert.True(false, $"{nameof(TestPreTrainedModels)} cannot handle non-tensor outputs yet"); + case OnnxValueType.ONNX_TYPE_TENSOR: // Only Dense tensors now + { + VerifyTensorResults(outputMeta.ElementDataType, result, outputValue); + } + break; + case OnnxValueType.ONNX_TYPE_SEQUENCE: + { + VerifySequenceResults(result, outputValue, outputMeta); + } + break; + default: + Assert.True(false, $"TestPreTrainedModels cannot handle Onnxtype: {outputValue.ValueType}"); + break; } } } @@ -624,6 +703,87 @@ private void TestPreTrainedModels(string opsetDir, string modelName) } } + private void VerifySequenceResults(NamedOnnxValue result, NamedOnnxValue expectedValue, NodeMetadata metaData) + { + var meta = metaData.AsSequenceMetadata(); + var resultSequence = result.AsEnumerable(); + var expectedSequence = expectedValue.AsEnumerable(); + Assert.Equal(resultSequence.Count(), expectedSequence.Count()); + + foreach (var (resultItem, expectedItem) in resultSequence.Zip(expectedSequence, (r, e) => (r, e))) + { + Assert.Equal(resultItem.ValueType, expectedItem.ValueType); + Assert.Equal(resultItem.ValueType, meta.ElementMeta.OnnxValueType); + switch (resultItem.ValueType) + { + case OnnxValueType.ONNX_TYPE_TENSOR: + VerifyTensorResults(meta.ElementMeta.ElementDataType, resultItem, expectedItem); + break; + case OnnxValueType.ONNX_TYPE_SEQUENCE: + { + VerifySequenceResults(resultItem, expectedItem, meta.ElementMeta); + } + break; + default: + Assert.True(false, "VerifySequenceResults cannot handle Onnxtype: " + resultItem.ValueType.ToString()); + break; + } + Assert.Equal(resultItem.AsTensor(), expectedItem.AsTensor(), new FloatComparer()); + } + } + + private void VerifyTensorResults(TensorElementType elementType, NamedOnnxValue result, NamedOnnxValue outputValue) + { + switch (elementType) + { + case TensorElementType.Float: + Assert.Equal(result.AsTensor(), outputValue.AsTensor(), new FloatComparer()); + break; + case TensorElementType.Double: + Assert.Equal(result.AsTensor(), outputValue.AsTensor(), new DoubleComparer()); + break; + case TensorElementType.Int32: + Assert.Equal(result.AsTensor(), outputValue.AsTensor(), new ExactComparer()); + break; + case TensorElementType.UInt32: + Assert.Equal(result.AsTensor(), outputValue.AsTensor(), new ExactComparer()); + break; + case TensorElementType.Int16: + Assert.Equal(result.AsTensor(), outputValue.AsTensor(), new ExactComparer()); + break; + case TensorElementType.UInt16: + Assert.Equal(result.AsTensor(), outputValue.AsTensor(), new ExactComparer()); + break; + case TensorElementType.Int64: + Assert.Equal(result.AsTensor(), outputValue.AsTensor(), new ExactComparer()); + break; + case TensorElementType.UInt64: + Assert.Equal(result.AsTensor(), outputValue.AsTensor(), new ExactComparer()); + break; + case TensorElementType.UInt8: + Assert.Equal(result.AsTensor(), outputValue.AsTensor(), new ExactComparer()); + break; + case TensorElementType.Int8: + Assert.Equal(result.AsTensor(), outputValue.AsTensor(), new ExactComparer()); + break; + case TensorElementType.Bool: + Assert.Equal(result.AsTensor(), outputValue.AsTensor(), new ExactComparer()); + break; + case TensorElementType.Float16: + Assert.Equal(result.AsTensor(), outputValue.AsTensor(), new Float16Comparer { tolerance = 2 }); + break; + case TensorElementType.BFloat16: + Assert.Equal(result.AsTensor(), outputValue.AsTensor(), new BFloat16Comparer { tolerance = 2 }); + break; + case TensorElementType.String: + Assert.Equal(result.AsTensor(), outputValue.AsTensor(), new ExactComparer()); + break; + default: + Assert.True(false, "TestPreTrainedModels does not yet support output of type: " + elementType.ToString()); + break; + } + } + // Hint: .NET Core 3.1 has a 'NativeLibrary' class that can be used to free the library handle private void UnloadLibrary(IntPtr libraryHandle) { @@ -669,7 +829,8 @@ private void TestRegisterCustomOpLibrary() var ortEnvInstance = OrtEnv.Instance(); string[] providers = ortEnvInstance.GetAvailableProviders(); - if (Array.Exists(providers, provider => provider == "CUDAExecutionProvider")) { + if (Array.Exists(providers, provider => provider == "CUDAExecutionProvider")) + { option.AppendExecutionProvider_CUDA(0); } diff --git a/csharp/tools/Microsoft.ML.OnnxRuntime.PerfTool/Program.cs b/csharp/tools/Microsoft.ML.OnnxRuntime.PerfTool/Program.cs index 4f29d72b0b146..9370a03f7fbeb 100644 --- a/csharp/tools/Microsoft.ML.OnnxRuntime.PerfTool/Program.cs +++ b/csharp/tools/Microsoft.ML.OnnxRuntime.PerfTool/Program.cs @@ -2,7 +2,6 @@ // Licensed under the MIT License. using CommandLine; -using Google.Protobuf; using Microsoft.ML.OnnxRuntime.Tensors; using System; using System.Collections.Generic; diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 0b41ee9ffad4f..b6e313bd20cb4 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -263,10 +263,11 @@ ORT_RUNTIME_CLASS(Value); ORT_RUNTIME_CLASS(RunOptions); ORT_RUNTIME_CLASS(TypeInfo); ORT_RUNTIME_CLASS(TensorTypeAndShapeInfo); -ORT_RUNTIME_CLASS(SessionOptions); -ORT_RUNTIME_CLASS(CustomOpDomain); ORT_RUNTIME_CLASS(MapTypeInfo); ORT_RUNTIME_CLASS(SequenceTypeInfo); +ORT_RUNTIME_CLASS(OptionalTypeInfo); +ORT_RUNTIME_CLASS(SessionOptions); +ORT_RUNTIME_CLASS(CustomOpDomain); ORT_RUNTIME_CLASS(ModelMetadata); ORT_RUNTIME_CLASS(ThreadPoolParams); ORT_RUNTIME_CLASS(ThreadingOptions); @@ -578,9 +579,14 @@ typedef struct OrtMIGraphXProviderOptions { */ typedef struct OrtOpenVINOProviderOptions { #ifdef __cplusplus - OrtOpenVINOProviderOptions() : device_type{}, enable_vpu_fast_compile{}, device_id{}, - num_of_threads{}, cache_dir{}, - context{}, enable_opencl_throttling{}, enable_dynamic_shapes{} {} + OrtOpenVINOProviderOptions() : device_type{}, + enable_vpu_fast_compile{}, + device_id{}, + num_of_threads{}, + cache_dir{}, + context{}, + enable_opencl_throttling{}, + enable_dynamic_shapes{} {} #endif /** \brief Device type string * @@ -589,8 +595,8 @@ typedef struct OrtOpenVINOProviderOptions { const char* device_type; unsigned char enable_vpu_fast_compile; ///< 0 = disabled, nonzero = enabled const char* device_id; - size_t num_of_threads; ///< 0 = Use default number of threads - const char* cache_dir; // path is set to empty by default + size_t num_of_threads; ///< 0 = Use default number of threads + const char* cache_dir; // path is set to empty by default void* context; unsigned char enable_opencl_throttling; ///< 0 = disabled, nonzero = enabled unsigned char enable_dynamic_shapes; ///< 0 = disabled, nonzero = enabled @@ -1340,8 +1346,9 @@ struct OrtApi { * * \param[in] type_info * \param[out] out Do not free this value, it will be valid until type_info is freed. + * If type_info does not represent tensor, this value will be set to nullptr. * - * \snippet{doc} snippets.dox OrtStatus Return Value + * \snippet{doc} snippets.dox OrtStatus Return Value. Always returns nullptr. */ ORT_API2_STATUS(CastTypeInfoToTensorInfo, _In_ const OrtTypeInfo* type_info, _Outptr_result_maybenull_ const OrtTensorTypeAndShapeInfo** out); @@ -1850,9 +1857,10 @@ struct OrtApi { * This is used by WinML to support model reflection APIs. * * \param[out] type_info - * \param[out] out A pointer to the ::OrtMapTypeInfo. Do not free this value + * \param[out] out A pointer to the ::OrtMapTypeInfo. Do not free this value. If type_info + * does not contain a map, this value will be set to nullptr. * - * \snippet{doc} snippets.dox OrtStatus Return Value + * \snippet{doc} snippets.dox OrtStatus Return Value. Always returns nullptr. */ ORT_API2_STATUS(CastTypeInfoToMapTypeInfo, _In_ const OrtTypeInfo* type_info, _Outptr_result_maybenull_ const OrtMapTypeInfo** out); @@ -1865,9 +1873,10 @@ struct OrtApi { * This is used by WinML to support model reflection APIs. * * \param[in] type_info - * \param[out] out A pointer to the OrtSequenceTypeInfo. Do not free this value + * \param[out] out A pointer to the OrtSequenceTypeInfo. Do not free this value. If type_info + * doesn not contain a sequence, this value will be set to nullptr. * - * \snippet{doc} snippets.dox OrtStatus Return Value + * \snippet{doc} snippets.dox OrtStatus Return Value. Always returns nullptr. */ ORT_API2_STATUS(CastTypeInfoToSequenceTypeInfo, _In_ const OrtTypeInfo* type_info, _Outptr_result_maybenull_ const OrtSequenceTypeInfo** out); @@ -3932,7 +3941,7 @@ struct OrtApi { ORT_API2_STATUS(SessionOptionsAppendExecutionProvider_Dnnl, _In_ OrtSessionOptions* options, _In_ const OrtDnnlProviderOptions* dnnl_options); - /** \brief Create an OrtDnnlProviderOptions + /** \brief Create an OrtDnnlProviderOptions * * \param[out] out Newly created ::OrtDnnlProviderOptions. Must be released with OrtApi::ReleaseDnnlProviderOptions * @@ -4101,9 +4110,47 @@ struct OrtApi { */ ORT_API2_STATUS(KernelInfoGetConstantInput_tensor, _In_ const OrtKernelInfo* info, size_t index, _Out_ int* is_constant, _Outptr_ const OrtValue** out); -#ifdef __cplusplus - OrtApi(const OrtApi&) = delete; // Prevent users from accidentally copying the API structure, it should always be passed as a pointer -#endif + /** \brief Get Optional Type information from an ::OrtTypeInfo + * + * This augments ::OrtTypeInfo to return an ::OrtOptionalTypeInfo when the type is optional. + * The OrtOptionalTypeInfo also has a nested ::OrtTypeInfo that describes the type of the optional value. + * ::OrtOptionalTypeInfo type can only appear within model metadata to describe inputs/outputs. + * The actual OrtValues that are supplied in place of optional type inputs should contain + * specific type that is described by ::OrtOptionalTypeInfo. + * + * So the picture: ::OrtTypeInfo -> ::OrtOptionalTypeInfo -> ::OrtTypeInfo (describes the type that can be supplied + * in place of the optional type when creating the actual ::OrtValue). + * + * \param[in] type_info + * \param[out] out A pointer to the ::OrtOptionalTypeInfo. Do not free this value, + * it is owned by OrtTypeInfo instance. When the type_info does not represent + * optional type, nullptr is returned in out. + * + * \snippet{doc} snippets.dox OrtStatus Return Value. Always returns nullptr. + * + * \since Version 1.15. + */ + ORT_API2_STATUS(CastTypeInfoToOptionalTypeInfo, _In_ const OrtTypeInfo* type_info, + _Outptr_result_maybenull_ const OrtOptionalTypeInfo** out); + + /** \brief Get OrtTypeInfo for the allowed contained type from an ::OrtOptionalTypeInfo. + * + * This augments ::OrtOptionalTypeInfo to return an ::OrtTypeInfo for the contained type. + * The OrtOptionalTypeInfo has a nested ::OrtTypeInfo that describes the type of the optional value. + * ::OrtOptionalTypeInfo type can only appear within model metadata to describe inputs/outputs. + * The actual OrtValues that are supplied in place of optional type inputs should contain + * specific type that is described by the returned ::OrtTypeInfo. + * + * \param[in] optional_type_info + * \param[out] out A pointer to the ::OrtTypeInfo for what the optional value could be. + * it is owned by OrtOptionalTypeInfo instance. + * + * \snippet{doc} snippets.dox OrtStatus Return Value. + * + * \since Version 1.15. + */ + ORT_API2_STATUS(GetOptionalContainedTypeInfo, _In_ const OrtOptionalTypeInfo* optional_type_info, + _Outptr_ OrtTypeInfo** out); }; /* diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index 5af45b8ff38e2..2086193f0c39a 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -890,6 +890,19 @@ struct SequenceTypeInfo : detail::SequenceTypeInfoImpl { ConstSequenceTypeInfo GetConst() const { return ConstSequenceTypeInfo{this->p_}; } }; +namespace detail { +template +struct OptionalTypeInfoImpl : Base { + using B = Base; + using B::B; + TypeInfo GetOptionalElementType() const; ///< Wraps OrtApi::CastOptionalTypeToContainedTypeInfo +}; + +} // namespace detail + +// This is always owned by the TypeInfo and can only be obtained from it. +using ConstOptionalTypeInfo = detail::OptionalTypeInfoImpl>; + namespace detail { template struct MapTypeInfoImpl : detail::Base { @@ -921,6 +934,7 @@ struct TypeInfoImpl : detail::Base { ConstTensorTypeAndShapeInfo GetTensorTypeAndShapeInfo() const; ///< Wraps OrtApi::CastTypeInfoToTensorInfo ConstSequenceTypeInfo GetSequenceTypeInfo() const; ///< Wraps OrtApi::CastTypeInfoToSequenceTypeInfo ConstMapTypeInfo GetMapTypeInfo() const; ///< Wraps OrtApi::CastTypeInfoToMapTypeInfo + ConstOptionalTypeInfo GetOptionalTypeInfo() const; ///< wraps OrtApi::CastTypeInfoToOptionalTypeInfo ONNXType GetONNXType() const; }; diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h index 8e2d845d2f62c..899c6c331a2cd 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h @@ -1074,9 +1074,6 @@ inline std::vector TensorTypeAndShapeInfoImpl::GetShape() const { return out; } -} // namespace detail - -namespace detail { template inline ConstTensorTypeAndShapeInfo TypeInfoImpl::GetTensorTypeAndShapeInfo() const { const OrtTensorTypeAndShapeInfo* out; @@ -1105,9 +1102,6 @@ inline ONNXType TypeInfoImpl::GetONNXType() const { return out; } -} // namespace detail - -namespace detail { template inline TypeInfo SequenceTypeInfoImpl::GetSequenceElementType() const { OrtTypeInfo* output; @@ -1115,9 +1109,13 @@ inline TypeInfo SequenceTypeInfoImpl::GetSequenceElementType() const { return TypeInfo{output}; } -} // namespace detail +template +inline TypeInfo OptionalTypeInfoImpl::GetOptionalElementType() const { + OrtTypeInfo* info; + ThrowOnError(GetApi().GetOptionalContainedTypeInfo(this->p_, &info)); + return TypeInfo{info}; +} -namespace detail { template inline ONNXTensorElementDataType MapTypeInfoImpl::GetMapKeyType() const { ONNXTensorElementDataType out; @@ -1131,6 +1129,14 @@ inline TypeInfo MapTypeInfoImpl::GetMapValueType() const { ThrowOnError(GetApi().GetMapValueType(this->p_, &output)); return TypeInfo{output}; } + +template +inline ConstOptionalTypeInfo TypeInfoImpl::GetOptionalTypeInfo() const { + const OrtOptionalTypeInfo* info; + ThrowOnError(GetApi().CastTypeInfoToOptionalTypeInfo(this->p_, &info)); + return ConstOptionalTypeInfo{info}; +} + } // namespace detail namespace detail { diff --git a/onnxruntime/core/framework/onnxruntime_map_type_info.cc b/onnxruntime/core/framework/onnxruntime_map_type_info.cc index 9b18ba6703693..b87ea179c0700 100644 --- a/onnxruntime/core/framework/onnxruntime_map_type_info.cc +++ b/onnxruntime/core/framework/onnxruntime_map_type_info.cc @@ -1,70 +1,97 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. + #include "core/framework/onnxruntime_map_type_info.h" #include "core/framework/onnxruntime_typeinfo.h" #include "core/graph/onnx_protobuf.h" #include "core/session/ort_apis.h" #include "core/framework/error_code_helper.h" -OrtMapTypeInfo::OrtMapTypeInfo(ONNXTensorElementDataType map_key_type, OrtTypeInfo* map_value_type) noexcept : map_key_type_(map_key_type), map_value_type_(map_value_type, &OrtApis::ReleaseTypeInfo) { +OrtMapTypeInfo::OrtMapTypeInfo(ONNXTensorElementDataType map_key_type, + std::unique_ptr map_value_type) noexcept + : map_key_type_(map_key_type), map_value_type_(std::move(map_value_type)) { } +OrtMapTypeInfo::~OrtMapTypeInfo() = default; + static ONNXTensorElementDataType ToONNXTensorElementDataType(ONNX_NAMESPACE::TensorProto_DataType data_type) { using TensorType = ONNX_NAMESPACE::TensorProto_DataType; switch (data_type) { - case TensorType::TensorProto_DataType_BOOL: { return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL; } - case TensorType::TensorProto_DataType_STRING: { return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING; } // maps to c++ type std::string - case TensorType::TensorProto_DataType_FLOAT16: { return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16; } // maps to c type float16 - case TensorType::TensorProto_DataType_FLOAT: { return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; } // maps to c type float - case TensorType::TensorProto_DataType_DOUBLE: { return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE; } // maps to c type double - case TensorType::TensorProto_DataType_INT8: { return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8; } // maps to c type int8_t - case TensorType::TensorProto_DataType_INT16: { return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16; } // maps to c type int16_t - case TensorType::TensorProto_DataType_INT32: { return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32; } // maps to c type int32_t - case TensorType::TensorProto_DataType_INT64: { return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64; } // maps to c type int64_t - case TensorType::TensorProto_DataType_UINT8: { return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8; } // maps to c type uint8_t - case TensorType::TensorProto_DataType_UINT16: { return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16; } // maps to c type uint16_t - case TensorType::TensorProto_DataType_UINT32: { return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32; } // maps to c type uint32_t - case TensorType::TensorProto_DataType_UINT64: { return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64; } // maps to c type uint64_t - case TensorType::TensorProto_DataType_COMPLEX64: { return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64; } // complex with float32 real and imaginary components - case TensorType::TensorProto_DataType_COMPLEX128: { return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128; } // complex with float64 real and imaginary components - case TensorType::TensorProto_DataType_BFLOAT16: { return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16; } // Non-IEEE floating-point format based on IEEE754 single-precision - default: { return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; } + case TensorType::TensorProto_DataType_BOOL: { + return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL; + } + case TensorType::TensorProto_DataType_STRING: { + return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING; + } // maps to c++ type std::string + case TensorType::TensorProto_DataType_FLOAT16: { + return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16; + } // maps to c type float16 + case TensorType::TensorProto_DataType_FLOAT: { + return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; + } // maps to c type float + case TensorType::TensorProto_DataType_DOUBLE: { + return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE; + } // maps to c type double + case TensorType::TensorProto_DataType_INT8: { + return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8; + } // maps to c type int8_t + case TensorType::TensorProto_DataType_INT16: { + return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16; + } // maps to c type int16_t + case TensorType::TensorProto_DataType_INT32: { + return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32; + } // maps to c type int32_t + case TensorType::TensorProto_DataType_INT64: { + return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64; + } // maps to c type int64_t + case TensorType::TensorProto_DataType_UINT8: { + return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8; + } // maps to c type uint8_t + case TensorType::TensorProto_DataType_UINT16: { + return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16; + } // maps to c type uint16_t + case TensorType::TensorProto_DataType_UINT32: { + return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32; + } // maps to c type uint32_t + case TensorType::TensorProto_DataType_UINT64: { + return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64; + } // maps to c type uint64_t + case TensorType::TensorProto_DataType_COMPLEX64: { + return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64; + } // complex with float32 real and imaginary components + case TensorType::TensorProto_DataType_COMPLEX128: { + return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128; + } // complex with float64 real and imaginary components + case TensorType::TensorProto_DataType_BFLOAT16: { + return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16; + } // Non-IEEE floating-point format based on IEEE754 single-precision + default: { + return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; + } } } -#if defined(_MSC_VER) && !defined(__clang__) -#pragma warning(disable : 26409) -#endif -OrtStatus* OrtMapTypeInfo::FromTypeProto(const ONNX_NAMESPACE::TypeProto* type_proto, OrtMapTypeInfo** out) { - auto value_case = type_proto->value_case(); - if (value_case != ONNX_NAMESPACE::TypeProto::kMapType) - { - return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "type_proto is not of type map!");; - } + +std::unique_ptr OrtMapTypeInfo::FromTypeProto( + const ONNX_NAMESPACE::TypeProto& type_proto) { + auto value_case = type_proto.value_case(); + + ORT_ENFORCE(value_case == ONNX_NAMESPACE::TypeProto::kMapType, "type_proto is not of type map!"); // Get the key type of the map - auto type_proto_map = type_proto->map_type(); - auto map_key_type = ToONNXTensorElementDataType(ONNX_NAMESPACE::TensorProto_DataType(type_proto_map.key_type())); + const auto& type_proto_map = type_proto.map_type(); + const auto map_key_type = ToONNXTensorElementDataType( + ONNX_NAMESPACE::TensorProto_DataType(type_proto_map.key_type())); // Get the value type of the map - OrtTypeInfo* map_value_type_info = nullptr; - if (auto status = OrtTypeInfo::FromTypeProto(&type_proto_map.value_type(), &map_value_type_info)) - { - return status; - } + auto map_value_type_info = OrtTypeInfo::FromTypeProto(type_proto_map.value_type()); - *out = new OrtMapTypeInfo(map_key_type, map_value_type_info); - return nullptr; + return std::make_unique(map_key_type, std::move(map_value_type_info)); } -OrtStatus* OrtMapTypeInfo::Clone(OrtMapTypeInfo** out) { - OrtTypeInfo* map_value_type_copy = nullptr; - if (auto status = map_value_type_->Clone(&map_value_type_copy)) - { - return status; - } - *out = new OrtMapTypeInfo(map_key_type_, map_value_type_copy); - return nullptr; +std::unique_ptr OrtMapTypeInfo::Clone() const { + auto map_value_type_copy = map_value_type_->Clone(); + return std::make_unique(map_key_type_, std::move(map_value_type_copy)); } // OrtMapTypeInfo Accessors @@ -76,12 +103,15 @@ ORT_API_STATUS_IMPL(OrtApis::GetMapKeyType, _In_ const OrtMapTypeInfo* map_type_ API_IMPL_END } -ORT_API_STATUS_IMPL(OrtApis::GetMapValueType, _In_ const OrtMapTypeInfo* map_type_info, _Outptr_ OrtTypeInfo** out) { +ORT_API_STATUS_IMPL(OrtApis::GetMapValueType, + _In_ const OrtMapTypeInfo* map_type_info, _Outptr_ OrtTypeInfo** out) { API_IMPL_BEGIN - return map_type_info->map_value_type_->Clone(out); + auto clone = map_type_info->map_value_type_->Clone(); + *out = clone.release(); + return nullptr; API_IMPL_END } ORT_API(void, OrtApis::ReleaseMapTypeInfo, _Frees_ptr_opt_ OrtMapTypeInfo* ptr) { - delete ptr; + std::unique_ptr p(ptr); } \ No newline at end of file diff --git a/onnxruntime/core/framework/onnxruntime_map_type_info.h b/onnxruntime/core/framework/onnxruntime_map_type_info.h index 46477d8f04fa7..6b20a94b30a56 100644 --- a/onnxruntime/core/framework/onnxruntime_map_type_info.h +++ b/onnxruntime/core/framework/onnxruntime_map_type_info.h @@ -10,17 +10,21 @@ namespace ONNX_NAMESPACE { class TypeProto; } +struct OrtTypeInfo; + struct OrtMapTypeInfo { public: + ONNXTensorElementDataType map_key_type_ = ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING; - std::unique_ptr map_value_type_; + std::unique_ptr map_value_type_; + + static std::unique_ptr FromTypeProto(const ONNX_NAMESPACE::TypeProto&); - static OrtStatus* FromTypeProto(const ONNX_NAMESPACE::TypeProto*, OrtMapTypeInfo** out); + std::unique_ptr Clone() const; - OrtStatus* Clone(OrtMapTypeInfo** out); + OrtMapTypeInfo(ONNXTensorElementDataType map_key_type, std::unique_ptr map_value_type) noexcept; + ~OrtMapTypeInfo(); - private: - OrtMapTypeInfo(ONNXTensorElementDataType map_key_type, OrtTypeInfo* map_value_type)noexcept; OrtMapTypeInfo(const OrtMapTypeInfo& other) = delete; OrtMapTypeInfo& operator=(const OrtMapTypeInfo& other) = delete; diff --git a/onnxruntime/core/framework/onnxruntime_optional_type_info.cc b/onnxruntime/core/framework/onnxruntime_optional_type_info.cc new file mode 100644 index 0000000000000..d0eda08819873 --- /dev/null +++ b/onnxruntime/core/framework/onnxruntime_optional_type_info.cc @@ -0,0 +1,39 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +#include "core/framework/onnxruntime_optional_type_info.h" +#include "core/framework/onnxruntime_typeinfo.h" +#include "core/graph/onnx_protobuf.h" +#include "core/session/ort_apis.h" +#include "core/framework/error_code_helper.h" + +OrtOptionalTypeInfo::OrtOptionalTypeInfo(std::unique_ptr contained_type) noexcept + : contained_type_(std::move(contained_type)) { +} + +OrtOptionalTypeInfo::~OrtOptionalTypeInfo() = default; + +std::unique_ptr OrtOptionalTypeInfo::FromTypeProto( + const ONNX_NAMESPACE::TypeProto& type_proto) { + const auto value_case = type_proto.value_case(); + + ORT_ENFORCE(value_case == ONNX_NAMESPACE::TypeProto::kOptionalType, "type_proto is not of optional type"); + + const auto& type_proto_optional = type_proto.optional_type(); + auto contained_type_info = OrtTypeInfo::FromTypeProto(type_proto_optional.elem_type()); + + return std::make_unique(std::move(contained_type_info)); +} + +std::unique_ptr OrtOptionalTypeInfo::Clone() const { + auto contained_type_copy = contained_type_->Clone(); + return std::make_unique(std::move(contained_type_copy)); +} + +ORT_API_STATUS_IMPL(OrtApis::GetOptionalContainedTypeInfo, _In_ const OrtOptionalTypeInfo* optional_type_info, + _Outptr_ OrtTypeInfo** out) { + API_IMPL_BEGIN + auto type_info = optional_type_info->contained_type_->Clone(); + *out = type_info.release(); + return nullptr; + API_IMPL_END +} diff --git a/onnxruntime/core/framework/onnxruntime_optional_type_info.h b/onnxruntime/core/framework/onnxruntime_optional_type_info.h new file mode 100644 index 0000000000000..561d055689b53 --- /dev/null +++ b/onnxruntime/core/framework/onnxruntime_optional_type_info.h @@ -0,0 +1,26 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +#pragma once + +#include + +struct OrtTypeInfo; + +namespace ONNX_NAMESPACE { +class TypeProto; +} + +struct OrtOptionalTypeInfo { + + explicit OrtOptionalTypeInfo(std::unique_ptr contained_type) noexcept; + ~OrtOptionalTypeInfo(); + + std::unique_ptr contained_type_; + + std::unique_ptr Clone() const; + + static std::unique_ptr FromTypeProto(const ONNX_NAMESPACE::TypeProto&); + + OrtOptionalTypeInfo(const OrtOptionalTypeInfo& other) = delete; + OrtOptionalTypeInfo& operator=(const OrtOptionalTypeInfo& other) = delete; +}; diff --git a/onnxruntime/core/framework/onnxruntime_sequence_type_info.cc b/onnxruntime/core/framework/onnxruntime_sequence_type_info.cc index acae583c7a21f..3f1d852610912 100644 --- a/onnxruntime/core/framework/onnxruntime_sequence_type_info.cc +++ b/onnxruntime/core/framework/onnxruntime_sequence_type_info.cc @@ -6,43 +6,39 @@ #include "core/session/ort_apis.h" #include "core/framework/error_code_helper.h" -OrtSequenceTypeInfo::OrtSequenceTypeInfo(OrtTypeInfo* sequence_key_type) noexcept : sequence_key_type_(sequence_key_type, &OrtApis::ReleaseTypeInfo) { +OrtSequenceTypeInfo::OrtSequenceTypeInfo(std::unique_ptr sequence_key_type) noexcept + : sequence_key_type_(std::move(sequence_key_type)) { } -#if defined(_MSC_VER) && !defined(__clang__) -#pragma warning(disable : 26409) -#endif -OrtStatus* OrtSequenceTypeInfo::FromTypeProto(const ONNX_NAMESPACE::TypeProto* type_proto, OrtSequenceTypeInfo** out) { - auto value_case = type_proto->value_case(); - if (value_case != ONNX_NAMESPACE::TypeProto::kSequenceType) { - return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "type_proto is not of type sequence!"); - } - - auto type_proto_sequence = type_proto->sequence_type(); - OrtTypeInfo* sequence_key_type_info = nullptr; - if (auto status = OrtTypeInfo::FromTypeProto(&type_proto_sequence.elem_type(), &sequence_key_type_info)) { - return status; - } - - *out = new OrtSequenceTypeInfo(sequence_key_type_info); - return nullptr; + +OrtSequenceTypeInfo::~OrtSequenceTypeInfo() = default; + +std::unique_ptr OrtSequenceTypeInfo::FromTypeProto( + const ONNX_NAMESPACE::TypeProto& type_proto) { + const auto value_case = type_proto.value_case(); + + ORT_ENFORCE(value_case == ONNX_NAMESPACE::TypeProto::kSequenceType, "type_proto is not of type sequence!"); + + const auto& type_proto_sequence = type_proto.sequence_type(); + auto key_type_info = OrtTypeInfo::FromTypeProto(type_proto_sequence.elem_type()); + + return std::make_unique(std::move(key_type_info)); } -OrtStatus* OrtSequenceTypeInfo::Clone(OrtSequenceTypeInfo** out) { - OrtTypeInfo* sequence_key_type_copy = nullptr; - if (auto status = sequence_key_type_->Clone(&sequence_key_type_copy)) { - return status; - } - *out = new OrtSequenceTypeInfo(sequence_key_type_copy); - return nullptr; +std::unique_ptr OrtSequenceTypeInfo::Clone() const { + auto key_type_copy = sequence_key_type_->Clone(); + return std::make_unique(std::move(key_type_copy)); } -ORT_API_STATUS_IMPL(OrtApis::GetSequenceElementType, _In_ const OrtSequenceTypeInfo* sequence_type_info, +ORT_API_STATUS_IMPL(OrtApis::GetSequenceElementType, + _In_ const OrtSequenceTypeInfo* sequence_type_info, _Outptr_ OrtTypeInfo** out) { API_IMPL_BEGIN - return sequence_type_info->sequence_key_type_->Clone(out); + auto key_type_copy = sequence_type_info->sequence_key_type_->Clone(); + *out = key_type_copy.release(); + return nullptr; API_IMPL_END } ORT_API(void, OrtApis::ReleaseSequenceTypeInfo, _Frees_ptr_opt_ OrtSequenceTypeInfo* ptr) { - delete ptr; + std::unique_ptr p(ptr); } \ No newline at end of file diff --git a/onnxruntime/core/framework/onnxruntime_sequence_type_info.h b/onnxruntime/core/framework/onnxruntime_sequence_type_info.h index 5378dd578abe5..d1d1412b92ce7 100644 --- a/onnxruntime/core/framework/onnxruntime_sequence_type_info.h +++ b/onnxruntime/core/framework/onnxruntime_sequence_type_info.h @@ -2,25 +2,26 @@ // Licensed under the MIT License. #pragma once -#include "onnxruntime_c_api.h" - #include +#include "core/framework/onnxruntime_typeinfo.h" + namespace ONNX_NAMESPACE { class TypeProto; } struct OrtSequenceTypeInfo { public: - explicit OrtSequenceTypeInfo(OrtTypeInfo* sequence_key_type) noexcept; - std::unique_ptr sequence_key_type_; + explicit OrtSequenceTypeInfo(std::unique_ptr sequence_key_type) noexcept; + ~OrtSequenceTypeInfo(); + + std::unique_ptr sequence_key_type_; - OrtStatus* Clone(OrtSequenceTypeInfo** out); + std::unique_ptr Clone() const; - static OrtStatus* FromTypeProto(const ONNX_NAMESPACE::TypeProto*, OrtSequenceTypeInfo** out); + static std::unique_ptr FromTypeProto(const ONNX_NAMESPACE::TypeProto&); - private: OrtSequenceTypeInfo(const OrtSequenceTypeInfo& other) = delete; OrtSequenceTypeInfo& operator=(const OrtSequenceTypeInfo& other) = delete; }; diff --git a/onnxruntime/core/framework/onnxruntime_typeinfo.cc b/onnxruntime/core/framework/onnxruntime_typeinfo.cc index e3a07f84ef32a..678e7e6e78237 100644 --- a/onnxruntime/core/framework/onnxruntime_typeinfo.cc +++ b/onnxruntime/core/framework/onnxruntime_typeinfo.cc @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -//this file contains implementations of the C API +// this file contains implementations of the C API #include #include "onnxruntime_typeinfo.h" @@ -15,6 +15,7 @@ #include "core/framework/tensor_type_and_shape.h" #include "core/framework/onnxruntime_map_type_info.h" #include "core/framework/onnxruntime_sequence_type_info.h" +#include "core/framework/onnxruntime_optional_type_info.h" #include "core/framework/TensorSeq.h" using onnxruntime::BFloat16; @@ -27,47 +28,44 @@ using onnxruntime::Tensor; using onnxruntime::TensorShape; namespace on = ONNX_NAMESPACE; -#if defined(_MSC_VER) && !defined(__clang__) -#pragma warning(disable : 26409) -#endif -OrtTypeInfo::OrtTypeInfo(ONNXType type1) noexcept : type(type1) { -} -OrtTypeInfo::OrtTypeInfo(ONNXType type1, OrtTensorTypeAndShapeInfo* data1) noexcept : type(type1), data(data1) { +OrtTypeInfo::OrtTypeInfo(ONNXType type) noexcept : type(type) { } -OrtTypeInfo::OrtTypeInfo(ONNXType type1, OrtMapTypeInfo* map_type_info1) noexcept : type(type1), map_type_info(map_type_info1) { -} +OrtTypeInfo::OrtTypeInfo(std::unique_ptr map_type_info) noexcept + : type(ONNX_TYPE_MAP), map_type_info(std::move(map_type_info)) {} -OrtTypeInfo::OrtTypeInfo(ONNXType type1, OrtSequenceTypeInfo* sequence_type_info1) noexcept : type(type1), sequence_type_info(sequence_type_info1) { -} +OrtTypeInfo::OrtTypeInfo(std::unique_ptr sequence_type_info) noexcept + : type(ONNX_TYPE_SEQUENCE), sequence_type_info(std::move(sequence_type_info)) {} -OrtTypeInfo::~OrtTypeInfo() { - OrtApis::ReleaseTensorTypeAndShapeInfo(data); +OrtTypeInfo::OrtTypeInfo(std::unique_ptr optional_type_info) noexcept + : type(ONNX_TYPE_OPTIONAL), optional_type_info(std::move(optional_type_info)) {} - if (map_type_info) { - OrtApis::ReleaseMapTypeInfo(map_type_info); - } - if (sequence_type_info) { - OrtApis::ReleaseSequenceTypeInfo(sequence_type_info); - } +OrtTypeInfo::OrtTypeInfo(ONNXType type, std::unique_ptr data) noexcept + : type(type), data(std::move(data)) { } +OrtTypeInfo::~OrtTypeInfo() = default; + ORT_API_STATUS_IMPL(OrtApis::GetOnnxTypeFromTypeInfo, _In_ const struct OrtTypeInfo* input, _Out_ ONNXType* out) { + API_IMPL_BEGIN *out = input->type; return nullptr; + API_IMPL_END } ORT_API_STATUS_IMPL(OrtApis::CastTypeInfoToTensorInfo, _In_ const struct OrtTypeInfo* input, _Outptr_result_maybenull_ const struct OrtTensorTypeAndShapeInfo** out) { - *out = (input->type == ONNX_TYPE_TENSOR || input->type == ONNX_TYPE_SPARSETENSOR) ? input->data : nullptr; + API_IMPL_BEGIN + *out = (input->type == ONNX_TYPE_TENSOR || input->type == ONNX_TYPE_SPARSETENSOR) ? input->data.get() : nullptr; return nullptr; + API_IMPL_END } ORT_API_STATUS_IMPL(OrtApis::CastTypeInfoToMapTypeInfo, _In_ const OrtTypeInfo* type_info, _Outptr_result_maybenull_ const OrtMapTypeInfo** out) { API_IMPL_BEGIN - *out = type_info->type == ONNX_TYPE_MAP ? type_info->map_type_info : nullptr; + *out = type_info->type == ONNX_TYPE_MAP ? type_info->map_type_info.get() : nullptr; return nullptr; API_IMPL_END } @@ -75,7 +73,15 @@ ORT_API_STATUS_IMPL(OrtApis::CastTypeInfoToMapTypeInfo, _In_ const OrtTypeInfo* ORT_API_STATUS_IMPL(OrtApis::CastTypeInfoToSequenceTypeInfo, _In_ const OrtTypeInfo* type_info, _Outptr_result_maybenull_ const OrtSequenceTypeInfo** out) { API_IMPL_BEGIN - *out = type_info->type == ONNX_TYPE_SEQUENCE ? type_info->sequence_type_info : nullptr; + *out = type_info->type == ONNX_TYPE_SEQUENCE ? type_info->sequence_type_info.get() : nullptr; + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtApis::CastTypeInfoToOptionalTypeInfo, _In_ const OrtTypeInfo* type_info, + _Outptr_result_maybenull_ const OrtOptionalTypeInfo** out) { + API_IMPL_BEGIN + *out = (type_info->type == ONNX_TYPE_OPTIONAL) ? type_info->optional_type_info.get() : nullptr; return nullptr; API_IMPL_END } @@ -90,19 +96,15 @@ ORT_API_STATUS_IMPL(OrtApis::GetDenotationFromTypeInfo, _In_ const OrtTypeInfo* } ORT_API(void, OrtApis::ReleaseTypeInfo, _Frees_ptr_opt_ OrtTypeInfo* ptr) { - delete ptr; + std::unique_ptr p(ptr); } -OrtStatus* GetTensorShapeAndType(const TensorShape& shape, const onnxruntime::DataTypeImpl& tensor_data_type, - OrtTensorTypeAndShapeInfo** out); -OrtStatus* GetTensorShapeAndType(const TensorShape& shape, const std::vector* dim_params, - const ONNX_NAMESPACE::TypeProto& type_proto, OrtTensorTypeAndShapeInfo** out); +std::unique_ptr OrtTypeInfo::FromOrtValue(const OrtValue& value) { + auto result = MakePtr(ONNX_TYPE_UNKNOWN); -OrtStatus* OrtTypeInfo::FromOrtValue(const OrtValue& value, OrtTypeInfo** out) { onnxruntime::MLDataType type = value.Type(); if (type == nullptr) { - *out = new OrtTypeInfo(ONNX_TYPE_UNKNOWN); - return nullptr; + return result; } // GetType and GetType do not have TypeProto populated because they return a static @@ -110,51 +112,38 @@ OrtStatus* OrtTypeInfo::FromOrtValue(const OrtValue& value, OrtTypeInfo** out) { // unless they are primitive data types, in which case we as before return them not implemented // however, this way we can support Opaque and we can avoid excessive calls to GetType() if (type->IsTensorType()) { - OrtTensorTypeAndShapeInfo* info = nullptr; const Tensor& tensor = value.Get(); const auto* tensor_data_type = tensor.DataType(); if (tensor_data_type != nullptr) { - OrtStatus* st = GetTensorShapeAndType(tensor.Shape(), *tensor_data_type, &info); - if (st != nullptr) - return st; + auto type_shape = OrtTensorTypeAndShapeInfo::GetTensorShapeAndType(tensor.Shape(), *tensor_data_type); + return MakePtr(ONNX_TYPE_TENSOR, std::move(type_shape)); } - *out = new OrtTypeInfo(ONNX_TYPE_TENSOR, info); - return nullptr; + return MakePtr(ONNX_TYPE_TENSOR); } if (type->IsSparseTensorType()) { #if !defined(DISABLE_SPARSE_TENSORS) - OrtTensorTypeAndShapeInfo* info = nullptr; const SparseTensor& tensor = value.Get(); const auto* tensor_data_type = tensor.DataType(); if (tensor_data_type != nullptr) { - OrtStatus* st = GetTensorShapeAndType(tensor.DenseShape(), *tensor_data_type, &info); - if (st != nullptr) return st; + auto type_shape = OrtTensorTypeAndShapeInfo::GetTensorShapeAndType(tensor.DenseShape(), *tensor_data_type); + return MakePtr(ONNX_TYPE_SPARSETENSOR, std::move(type_shape)); } - *out = new OrtTypeInfo(ONNX_TYPE_SPARSETENSOR, info); - return nullptr; + return MakePtr(ONNX_TYPE_SPARSETENSOR); #else - return OrtApis::CreateStatus(ORT_FAIL, "SparseTensor is not supported in this build."); + ORT_NOT_IMPLEMENTED("SparseTensor is not supported in this build."); #endif } if (type->IsTensorSequenceType()) { - OrtTensorTypeAndShapeInfo* info = nullptr; const auto* tensor_data_type = value.Get().DataType(); - if (tensor_data_type != nullptr) { - TensorShape void_shape = {}; - OrtStatus* st = GetTensorShapeAndType(void_shape, *tensor_data_type, &info); - if (st != nullptr) { - return st; - } + ORT_ENFORCE(tensor_data_type != nullptr, "OrtValue is TensorSequence type but has no element Tensor DataType."); - auto element_type_info = new OrtTypeInfo(ONNX_TYPE_TENSOR, info); - auto sequence_type_info = new OrtSequenceTypeInfo(element_type_info); - *out = new OrtTypeInfo(ONNX_TYPE_SEQUENCE, sequence_type_info); - return nullptr; - } else { - return OrtApis::CreateStatus(ORT_FAIL, "OrtValue is TensorSequence type but has no element Tensor DataType."); - } + TensorShape void_shape = {}; + auto type_shape = OrtTensorTypeAndShapeInfo::GetTensorShapeAndType(void_shape, *tensor_data_type); + auto type_info = MakePtr(ONNX_TYPE_TENSOR, std::move(type_shape)); + auto sequence_type_info = std::make_unique(std::move(type_info)); + return MakePtr(std::move(sequence_type_info)); } const auto* type_proto = type->GetTypeProto(); @@ -162,74 +151,57 @@ OrtStatus* OrtTypeInfo::FromOrtValue(const OrtValue& value, OrtTypeInfo** out) { // Place Opaque first as tensors will be mostly handled above and maps and sequences are not common switch (type_proto->value_case()) { case on::TypeProto::kOpaqueType: { - *out = new OrtTypeInfo(ONNX_TYPE_OPAQUE); - return nullptr; - } -#if !defined(DISABLE_ML_OPS) + result = MakePtr(ONNX_TYPE_OPAQUE); + } break; case on::TypeProto::kMapType: { - return OrtTypeInfo::FromTypeProto(type_proto, out); - } +#if !defined(DISABLE_ML_OPS) + auto map_type_info = OrtMapTypeInfo::FromTypeProto(*type_proto); + result = MakePtr(std::move(map_type_info)); +#else + ORT_NOT_IMPLEMENTED("Map types are not supported in this build"); #endif + } break; case on::TypeProto::kSequenceType: { - return OrtTypeInfo::FromTypeProto(type_proto, out); - } + auto seq_info = OrtSequenceTypeInfo::FromTypeProto(*type_proto); + result = MakePtr(std::move(seq_info)); + } break; // Real Tensor support - case on::TypeProto::kTensorType: + case on::TypeProto::kSparseTensorType: #if !defined(DISABLE_SPARSE_TENSORS) - case on::TypeProto::kSparseTensorType: { - return OrtApis::CreateStatus(ORT_FAIL, "Tensor types should have been handled already"); - } + [[fallthrough]]; +#else + ORT_NOT_IMPLEMENTED("SparseTensor types are not supported in this build"); + break; #endif + case on::TypeProto::kTensorType: + ORT_THROW("Tensor types should have been handled already"); + break; default: - // NOT_IMPLEMENTED + ORT_NOT_IMPLEMENTED("This OrtValue is neither Tensor, SparseTensor, Map or Sequence type"); break; } } - - return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "not implemented"); + return result; } const DataTypeImpl* OrtTypeInfo::ElementTypeFromProto(int type) { - switch (type) { - case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: - return DataTypeImpl::GetType(); - case ONNX_NAMESPACE::TensorProto_DataType_BOOL: - return DataTypeImpl::GetType(); - case ONNX_NAMESPACE::TensorProto_DataType_INT32: - return DataTypeImpl::GetType(); - case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE: - return DataTypeImpl::GetType(); - case ONNX_NAMESPACE::TensorProto_DataType_STRING: - return DataTypeImpl::GetType(); - case ONNX_NAMESPACE::TensorProto_DataType_INT8: - return DataTypeImpl::GetType(); - case ONNX_NAMESPACE::TensorProto_DataType_UINT8: - return DataTypeImpl::GetType(); - case ONNX_NAMESPACE::TensorProto_DataType_UINT16: - return DataTypeImpl::GetType(); - case ONNX_NAMESPACE::TensorProto_DataType_INT16: - return DataTypeImpl::GetType(); - case ONNX_NAMESPACE::TensorProto_DataType_INT64: - return DataTypeImpl::GetType(); - case ONNX_NAMESPACE::TensorProto_DataType_UINT32: - return DataTypeImpl::GetType(); - case ONNX_NAMESPACE::TensorProto_DataType_UINT64: - return DataTypeImpl::GetType(); - case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: - return DataTypeImpl::GetType(); - case ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16: - return DataTypeImpl::GetType(); - - default: - ORT_NOT_IMPLEMENTED(__FUNCTION__, ":tensor type ", type, " is not supported"); - } + auto tensor_type = DataTypeImpl::TensorTypeFromONNXEnum(type); + return tensor_type->GetElementType(); } -OrtStatus* OrtTypeInfo::FromTypeProto(const ONNX_NAMESPACE::TypeProto* input, OrtTypeInfo** out) { - auto value_case = input->value_case(); +std::unique_ptr OrtTypeInfo::FromTypeProto(const ONNX_NAMESPACE::TypeProto& input) { + std::unique_ptr result; + + auto value_case = input.value_case(); switch (value_case) { - case on::TypeProto::kTensorType: - case on::TypeProto::kSparseTensorType: { + case on::TypeProto::kSparseTensorType: +#if !defined(DISABLE_SPARSE_TENSORS) + [[fallthrough]]; +#else + ORT_NOT_IMPLEMENTED("SparseTensor types are not supported in this build"); + break; +#endif + case on::TypeProto::kTensorType: { ONNXType ten_type = ONNX_TYPE_UNKNOWN; const on::TypeProto_Tensor* tensor_type = nullptr; #if !defined(DISABLE_SPARSE_TENSORS) @@ -237,29 +209,30 @@ OrtStatus* OrtTypeInfo::FromTypeProto(const ONNX_NAMESPACE::TypeProto* input, Or #endif const on::TensorShapeProto* sp = nullptr; if (value_case == on::TypeProto::kTensorType) { - tensor_type = &input->tensor_type(); + tensor_type = &input.tensor_type(); ten_type = ONNX_TYPE_TENSOR; if (onnxruntime::utils::HasShape(*tensor_type)) { sp = &tensor_type->shape(); } } else if (value_case == on::TypeProto::kSparseTensorType) { #if !defined(DISABLE_SPARSE_TENSORS) - sparse_type = &input->sparse_tensor_type(); + sparse_type = &input.sparse_tensor_type(); ten_type = ONNX_TYPE_SPARSETENSOR; if (onnxruntime::utils::HasShape(*sparse_type)) { sp = &sparse_type->shape(); } +#else + ORT_NOT_IMPLEMENTED("SparseTensor types are not supported in this build"); #endif } - OrtStatus* st = nullptr; - OrtTensorTypeAndShapeInfo* info = nullptr; + std::unique_ptr type_shape; if (sp != nullptr) { const on::TensorShapeProto& s = *sp; std::vector dims(s.dim_size()); std::vector dim_params(s.dim_size()); TensorShape shape_data(std::move(dims)); - for (int i = 0; i < s.dim_size(); ++i) { + for (int i = 0, dim_size = s.dim_size(); i < dim_size; ++i) { auto& t = s.dim(i); switch (t.value_case()) { case on::TensorShapeProto::Dimension::kDimValue: @@ -275,97 +248,88 @@ OrtStatus* OrtTypeInfo::FromTypeProto(const ONNX_NAMESPACE::TypeProto* input, Or assert(false); } } - st = GetTensorShapeAndType(shape_data, &dim_params, *input, &info); + type_shape = OrtTensorTypeAndShapeInfo::GetTensorShapeAndType(std::move(shape_data), &dim_params, input); } else { - st = GetTensorShapeAndType(TensorShape(), nullptr, *input, &info); + type_shape = OrtTensorTypeAndShapeInfo::GetTensorShapeAndType(TensorShape(), nullptr, input); } - if (st != nullptr) return st; - auto type_info = new OrtTypeInfo(ten_type, info); - type_info->denotation = input->denotation(); - *out = type_info; - return nullptr; + + result = MakePtr(ten_type, std::move(type_shape)); + result->denotation = input.denotation(); } break; case on::TypeProto::kSequenceType: { - OrtSequenceTypeInfo* sequence_type_info = nullptr; - - if (auto status = OrtSequenceTypeInfo::FromTypeProto(input, &sequence_type_info)) { - return status; - } - - auto type_info = new OrtTypeInfo(ONNX_TYPE_SEQUENCE, sequence_type_info); - type_info->denotation = input->denotation(); - *out = type_info; - return nullptr; + auto sequence_type_info = OrtSequenceTypeInfo::FromTypeProto(input); + result = MakePtr(std::move(sequence_type_info)); + result->denotation = input.denotation(); } break; case on::TypeProto::kMapType: { - OrtMapTypeInfo* map_type_info = nullptr; - - if (auto status = OrtMapTypeInfo::FromTypeProto(input, &map_type_info)) { - return status; - } - - auto type_info = new OrtTypeInfo(ONNX_TYPE_MAP, map_type_info); - type_info->denotation = input->denotation(); - *out = type_info; - return nullptr; +#if !defined(DISABLE_ML_OPS) + auto map_type_info = OrtMapTypeInfo::FromTypeProto(input); + result = MakePtr(std::move(map_type_info)); + result->denotation = input.denotation(); +#else + ORT_NOT_IMPLEMENTED("Map types are not supported in this build"); +#endif + } break; + case on::TypeProto::kOptionalType: { + auto optional_type_info = OrtOptionalTypeInfo::FromTypeProto(input); + result = MakePtr(std::move(optional_type_info)); + result->denotation = input.denotation(); } break; case on::TypeProto::kOpaqueType: { - auto type_info = new OrtTypeInfo(ONNX_TYPE_OPAQUE); - type_info->denotation = input->denotation(); - *out = type_info; - return nullptr; + result = MakePtr(ONNX_TYPE_OPAQUE); + result->denotation = input.denotation(); } break; case on::TypeProto::VALUE_NOT_SET: + ORT_THROW("This TypeProto does not have ValueCase set"); break; default: - // Not implemented + ORT_NOT_IMPLEMENTED("The type is not tensor, sparse tensor, sequence, map or optional type"); break; } - return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "not implemented"); + return result; } -OrtStatus* OrtTypeInfo::Clone(OrtTypeInfo** out) { +std::unique_ptr OrtTypeInfo::Clone() const { + std::unique_ptr result; switch (type) { - case ONNX_TYPE_TENSOR: - case ONNX_TYPE_SPARSETENSOR: { + case ONNX_TYPE_SPARSETENSOR: #if !defined(DISABLE_SPARSE_TENSORS) - OrtTensorTypeAndShapeInfo* clone; - if (auto status = data->Clone(&clone)) { - return status; - } - *out = new OrtTypeInfo(type, clone); - (*out)->denotation = denotation; - return nullptr; + [[fallthrough]]; #else - return OrtApis::CreateStatus(ORT_FAIL, "SparseTensor is not supported in this build."); + ORT_NOT_IMPLEMENTED("SparseTensor is not supported in this build."); #endif - } - case ONNX_TYPE_SEQUENCE: { - OrtSequenceTypeInfo* clone; - if (auto status = sequence_type_info->Clone(&clone)) { - return status; + case ONNX_TYPE_TENSOR: { + std::unique_ptr info; + if (data) { + info = data->Clone(); } - *out = new OrtTypeInfo(type, clone); - (*out)->denotation = denotation; - return nullptr; - } + result = MakePtr(type, std::move(info)); + result->denotation = denotation; + } break; + + case ONNX_TYPE_SEQUENCE: { + auto seq_clone = sequence_type_info->Clone(); + result = MakePtr(std::move(seq_clone)); + result->denotation = denotation; + } break; case ONNX_TYPE_MAP: { - OrtMapTypeInfo* clone; - if (auto status = map_type_info->Clone(&clone)) { - return status; - } - *out = new OrtTypeInfo(type, clone); - (*out)->denotation = denotation; - return nullptr; - } + auto map_clone = map_type_info->Clone(); + result = MakePtr(std::move(map_clone)); + result->denotation = denotation; + } break; + case ONNX_TYPE_OPTIONAL: { + auto opt_clone = optional_type_info->Clone(); + result = MakePtr(std::move(opt_clone)); + result->denotation = denotation; + } break; case ONNX_TYPE_OPAQUE: { - *out = new OrtTypeInfo(type); - (*out)->denotation = denotation; - return nullptr; - } + result = MakePtr(type); + result->denotation = denotation; + } break; default: - // Not implemented + ORT_NOT_IMPLEMENTED("The type is not tensor, sparse tensor, sequence, map or optional type"); break; } - return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "not implemented"); + + return result; } diff --git a/onnxruntime/core/framework/onnxruntime_typeinfo.h b/onnxruntime/core/framework/onnxruntime_typeinfo.h index 5b9145d32e28c..06b0b8d989f55 100644 --- a/onnxruntime/core/framework/onnxruntime_typeinfo.h +++ b/onnxruntime/core/framework/onnxruntime_typeinfo.h @@ -2,13 +2,15 @@ // Licensed under the MIT License. #pragma once -#include + +#include #include +#include + #include "core/session/onnxruntime_c_api.h" namespace onnxruntime { class DataTypeImpl; -class TensorShape; } // namespace onnxruntime namespace ONNX_NAMESPACE { @@ -18,33 +20,48 @@ class TypeProto; // These types are only present in the winml adapter c api, so they are forward declared. struct OrtMapTypeInfo; struct OrtSequenceTypeInfo; +struct OrtOptionalTypeInfo; +struct OrtTensorTypeAndShapeInfo; /** * the equivalent of ONNX_NAMESPACE::TypeProto * This class is mainly for the C API */ struct OrtTypeInfo { - public: - ONNXType type = ONNX_TYPE_UNKNOWN; - std::string denotation; - ~OrtTypeInfo(); + ONNXType type; + std::string denotation; - //owned by this - OrtTensorTypeAndShapeInfo* data = nullptr; - OrtMapTypeInfo* map_type_info = nullptr; - OrtSequenceTypeInfo* sequence_type_info = nullptr; - OrtTypeInfo(const OrtTypeInfo& other) = delete; - OrtTypeInfo& operator=(const OrtTypeInfo& other) = delete; + std::unique_ptr data; + std::unique_ptr map_type_info; + std::unique_ptr sequence_type_info; + std::unique_ptr optional_type_info; - OrtStatus* Clone(OrtTypeInfo** out); + std::unique_ptr Clone() const; - static OrtStatus* FromOrtValue(const OrtValue& value, OrtTypeInfo** out); - static OrtStatus* FromTypeProto(const ONNX_NAMESPACE::TypeProto*, OrtTypeInfo** out); + static std::unique_ptr FromOrtValue(const OrtValue& value); + static std::unique_ptr FromTypeProto(const ONNX_NAMESPACE::TypeProto&); static const onnxruntime::DataTypeImpl* ElementTypeFromProto(int type); - OrtTypeInfo(ONNXType type) noexcept; - OrtTypeInfo(ONNXType type, OrtTensorTypeAndShapeInfo* data) noexcept; - OrtTypeInfo(ONNXType type, OrtMapTypeInfo* map_type_info) noexcept; - OrtTypeInfo(ONNXType type, OrtSequenceTypeInfo* sequence_type_info) noexcept; + explicit OrtTypeInfo(ONNXType type) noexcept; + + explicit OrtTypeInfo(std::unique_ptr map_type_info) noexcept; + + OrtTypeInfo(ONNXType type, std::unique_ptr data) noexcept; + + explicit OrtTypeInfo(std::unique_ptr sequence_type_info) noexcept; + + explicit OrtTypeInfo(std::unique_ptr optional_type_info) noexcept; + + + OrtTypeInfo(const OrtTypeInfo&) = delete; + OrtTypeInfo& operator=(const OrtTypeInfo&) = delete; + + ~OrtTypeInfo(); + + template + static std::unique_ptr MakePtr(Args... args) { + return std::make_unique(std::forward(args)...); + } + }; diff --git a/onnxruntime/core/framework/tensor_type_and_shape.cc b/onnxruntime/core/framework/tensor_type_and_shape.cc index bebfa72c546a5..1b73ed1d837b2 100644 --- a/onnxruntime/core/framework/tensor_type_and_shape.cc +++ b/onnxruntime/core/framework/tensor_type_and_shape.cc @@ -24,20 +24,23 @@ using onnxruntime::MLFloat16; #if !defined(DISABLE_SPARSE_TENSORS) using onnxruntime::SparseTensor; #endif -using onnxruntime::Tensor; using onnxruntime::narrow; -#if defined(_MSC_VER) && !defined(__clang__) -#pragma warning(disable : 26409) -#endif +using onnxruntime::Tensor; + +OrtTensorTypeAndShapeInfo::OrtTensorTypeAndShapeInfo() = default; +OrtTensorTypeAndShapeInfo::~OrtTensorTypeAndShapeInfo() = default; +OrtTensorTypeAndShapeInfo::OrtTensorTypeAndShapeInfo(const OrtTensorTypeAndShapeInfo& other) = default; +OrtTensorTypeAndShapeInfo& OrtTensorTypeAndShapeInfo::operator=(const OrtTensorTypeAndShapeInfo& other) = default; + ORT_API_STATUS_IMPL(OrtApis::CreateTensorTypeAndShapeInfo, _Outptr_ OrtTensorTypeAndShapeInfo** out) { API_IMPL_BEGIN - *out = new OrtTensorTypeAndShapeInfo(); + *out = std::make_unique().release(); return nullptr; API_IMPL_END } ORT_API(void, OrtApis::ReleaseTensorTypeAndShapeInfo, _Frees_ptr_opt_ OrtTensorTypeAndShapeInfo* ptr) { - delete ptr; + std::unique_ptr p(ptr); } ORT_API_STATUS_IMPL(OrtApis::SetTensorElementType, _Inout_ OrtTensorTypeAndShapeInfo* this_ptr, @@ -48,29 +51,34 @@ ORT_API_STATUS_IMPL(OrtApis::SetTensorElementType, _Inout_ OrtTensorTypeAndShape API_IMPL_END } -ORT_API_STATUS_IMPL(OrtApis::SetDimensions, OrtTensorTypeAndShapeInfo* this_ptr, _In_ const int64_t* dim_values, size_t dim_count) { +ORT_API_STATUS_IMPL(OrtApis::SetDimensions, OrtTensorTypeAndShapeInfo* this_ptr, + _In_ const int64_t* dim_values, size_t dim_count) { API_IMPL_BEGIN this_ptr->shape = onnxruntime::TensorShape(dim_values, dim_count); return nullptr; API_IMPL_END } -ORT_API_STATUS_IMPL(OrtApis::GetTensorElementType, _In_ const struct OrtTensorTypeAndShapeInfo* info, _Out_ ONNXTensorElementDataType* out) { +ORT_API_STATUS_IMPL(OrtApis::GetTensorElementType, _In_ const struct OrtTensorTypeAndShapeInfo* info, + _Out_ ONNXTensorElementDataType* out) { *out = info->type; return nullptr; } -ORT_API_STATUS_IMPL(OrtApis::GetDimensionsCount, _In_ const struct OrtTensorTypeAndShapeInfo* info, _Out_ size_t* out) { +ORT_API_STATUS_IMPL(OrtApis::GetDimensionsCount, _In_ const struct OrtTensorTypeAndShapeInfo* info, + _Out_ size_t* out) { *out = info->shape.NumDimensions(); return nullptr; } -ORT_API_STATUS_IMPL(OrtApis::GetDimensions, _In_ const struct OrtTensorTypeAndShapeInfo* info, _Out_ int64_t* dim_values, size_t dim_values_length) { +ORT_API_STATUS_IMPL(OrtApis::GetDimensions, _In_ const struct OrtTensorTypeAndShapeInfo* info, + _Out_ int64_t* dim_values, size_t dim_values_length) { info->shape.CopyDims(dim_values, dim_values_length); return nullptr; } -ORT_API_STATUS_IMPL(OrtApis::GetSymbolicDimensions, _In_ const struct OrtTensorTypeAndShapeInfo* info, +ORT_API_STATUS_IMPL(OrtApis::GetSymbolicDimensions, + _In_ const struct OrtTensorTypeAndShapeInfo* info, _Out_writes_all_(dim_params_length) const char** names, size_t dim_params_length) { for (size_t idx = 0, end = std::min(info->dim_params.size(), dim_params_length); idx < end; ++idx) { names[idx] = info->dim_params[idx].c_str(); @@ -79,7 +87,8 @@ ORT_API_STATUS_IMPL(OrtApis::GetSymbolicDimensions, _In_ const struct OrtTensorT return nullptr; } -ORT_API_STATUS_IMPL(OrtApis::GetTensorShapeElementCount, _In_ const OrtTensorTypeAndShapeInfo* this_ptr, _Out_ size_t* out) { +ORT_API_STATUS_IMPL(OrtApis::GetTensorShapeElementCount, + _In_ const OrtTensorTypeAndShapeInfo* this_ptr, _Out_ size_t* out) { API_IMPL_BEGIN *out = SafeInt{this_ptr->shape.Size()}; return nullptr; @@ -151,45 +160,39 @@ ONNXTensorElementDataType MLDataTypeToOnnxRuntimeTensorElementDataType( return TensorDataTypeToOnnxRuntimeTensorElementDataType(prim_type->GetDataType()); } -OrtStatus* GetTensorShapeAndTypeHelper(ONNXTensorElementDataType type, const onnxruntime::TensorShape shape, - const std::vector* dim_params, OrtTensorTypeAndShapeInfo** out) { - OrtTensorTypeAndShapeInfo* ret; - if (auto* status = OrtApis::CreateTensorTypeAndShapeInfo(&ret)) - return status; - if (auto* status = OrtApis::SetTensorElementType(ret, type)) { - OrtApis::ReleaseTensorTypeAndShapeInfo(ret); - return status; - } - - auto* status = OrtApis::SetDimensions(ret, shape.GetDims().data(), shape.GetDims().size()); - if (status != nullptr) { - OrtApis::ReleaseTensorTypeAndShapeInfo(ret); - return status; - } +std::unique_ptr OrtTensorTypeAndShapeInfo::GetTensorShapeAndTypeHelper( + ONNXTensorElementDataType type, + onnxruntime::TensorShape shape, + const std::vector* dim_params) { + auto type_and_shape = std::make_unique(); + type_and_shape->type = type; + type_and_shape->shape = std::move(shape); if (dim_params != nullptr) { - ret->dim_params = *dim_params; + type_and_shape->dim_params = *dim_params; } else { // we expect to be being called with a concrete shape so validate that - assert(shape.Size() >= 0); - ret->dim_params.resize(shape.NumDimensions(), ""); + assert(type_and_shape->shape.Size() >= 0); + type_and_shape->dim_params.resize(type_and_shape->shape.NumDimensions(), ""); } - *out = ret; - return nullptr; + return type_and_shape; } -OrtStatus* GetTensorShapeAndType(const onnxruntime::TensorShape& shape, - const onnxruntime::DataTypeImpl& tensor_data_type, OrtTensorTypeAndShapeInfo** out) { +std::unique_ptr OrtTensorTypeAndShapeInfo::GetTensorShapeAndType( + onnxruntime::TensorShape shape, + const onnxruntime::DataTypeImpl& tensor_data_type) { ONNXTensorElementDataType type = MLDataTypeToOnnxRuntimeTensorElementDataType(&tensor_data_type); if (ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED == type) { - return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "Not implemented"); + ORT_NOT_IMPLEMENTED("Tensor type is undefined"); } - return GetTensorShapeAndTypeHelper(type, shape, nullptr, out); + return GetTensorShapeAndTypeHelper(type, std::move(shape), nullptr); } -OrtStatus* GetTensorShapeAndType(const onnxruntime::TensorShape& shape, const std::vector* dim_params, - const ONNX_NAMESPACE::TypeProto& type_proto, OrtTensorTypeAndShapeInfo** out) { +std::unique_ptr OrtTensorTypeAndShapeInfo::GetTensorShapeAndType( + onnxruntime::TensorShape shape, + const std::vector* dim_params, + const ONNX_NAMESPACE::TypeProto& type_proto) { auto value_case = type_proto.value_case(); assert(value_case == ONNX_NAMESPACE::TypeProto::kTensorType || value_case == ONNX_NAMESPACE::TypeProto::kSparseTensorType); @@ -198,19 +201,17 @@ OrtStatus* GetTensorShapeAndType(const onnxruntime::TensorShape& shape, const st : type_proto.sparse_tensor_type().elem_type(); ONNXTensorElementDataType type = TensorDataTypeToOnnxRuntimeTensorElementDataType(dtype); if (ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED == type) { - return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "Not implemented"); + ORT_NOT_IMPLEMENTED("Tensor type is undefined"); } - return GetTensorShapeAndTypeHelper(type, shape, dim_params, out); + return GetTensorShapeAndTypeHelper(type, std::move(shape), dim_params); } -OrtStatus* OrtTensorTypeAndShapeInfo::Clone(OrtTensorTypeAndShapeInfo** out) { - return GetTensorShapeAndTypeHelper(type, shape, &dim_params, out); -} - -ORT_API_STATUS_IMPL(OrtApis::GetTensorTypeAndShape, _In_ const OrtValue* v, _Outptr_ OrtTensorTypeAndShapeInfo** out) { +ORT_API_STATUS_IMPL(OrtApis::GetTensorTypeAndShape, + _In_ const OrtValue* v, _Outptr_ OrtTensorTypeAndShapeInfo** out) { API_IMPL_BEGIN if (!v->IsAllocated()) { - return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "the ort_value must contain a constructed tensor or sparse tensor"); + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, + "the ort_value must contain a constructed tensor or sparse tensor"); } if (v->IsTensor() || v->IsSparseTensor()) { const onnxruntime::TensorShape* shape = nullptr; @@ -219,17 +220,23 @@ ORT_API_STATUS_IMPL(OrtApis::GetTensorTypeAndShape, _In_ const OrtValue* v, _Out const Tensor& tensor = v->Get(); shape = &tensor.Shape(); data_type = tensor.DataType(); + auto ptr = OrtTensorTypeAndShapeInfo::GetTensorShapeAndType(*shape, *data_type); + *out = ptr.release(); } else { #if !defined(DISABLE_SPARSE_TENSORS) const SparseTensor& tensor = v->Get(); shape = &tensor.DenseShape(); data_type = tensor.DataType(); + auto ptr = OrtTensorTypeAndShapeInfo::GetTensorShapeAndType(*shape, *data_type); + *out = ptr.release(); +#else + ORT_NOT_IMPLEMENTED("SparseTensor is not supported in this build."); #endif } - return GetTensorShapeAndType(*shape, *data_type, out); } else { ORT_THROW("Argument is not a tensor"); } + return nullptr; API_IMPL_END } @@ -239,7 +246,9 @@ ORT_API_STATUS_IMPL(OrtApis::GetSparseTensorValuesTypeAndShape, _In_ const OrtVa #if !defined(DISABLE_SPARSE_TENSORS) const auto& sparse_tensor = SparseTensor::GetSparseTensorFromOrtValue(*v); const auto& values = sparse_tensor.Values(); - return GetTensorShapeAndType(values.Shape(), *values.DataType(), out); + auto ptr = OrtTensorTypeAndShapeInfo::GetTensorShapeAndType(values.Shape(), *values.DataType()); + *out = ptr.release(); + return nullptr; #else ORT_UNUSED_PARAMETER(v); ORT_UNUSED_PARAMETER(out); @@ -279,7 +288,9 @@ ORT_API_STATUS_IMPL(OrtApis::GetSparseTensorIndicesTypeShape, _In_ const OrtValu API_IMPL_BEGIN #if !defined(DISABLE_SPARSE_TENSORS) const Tensor& indices_tensor = GetIndicesTensor(*v, indices_format); - return GetTensorShapeAndType(indices_tensor.Shape(), *indices_tensor.DataType(), out); + auto ptr = OrtTensorTypeAndShapeInfo::GetTensorShapeAndType(indices_tensor.Shape(), *indices_tensor.DataType()); + *out = ptr.release(); + return nullptr; #else ORT_UNUSED_PARAMETER(v); ORT_UNUSED_PARAMETER(indices_format); @@ -309,13 +320,8 @@ ORT_API_STATUS_IMPL(OrtApis::GetSparseTensorIndices, _In_ const OrtValue* v, ORT_API_STATUS_IMPL(OrtApis::GetValueType, _In_ const OrtValue* v, _Out_ ONNXType* out) { API_IMPL_BEGIN - OrtTypeInfo* type_info; - auto status = OrtTypeInfo::FromOrtValue(*v, &type_info); - if (status != nullptr) - return status; - + auto type_info = OrtTypeInfo::FromOrtValue(*v); *out = type_info->type; - OrtApis::ReleaseTypeInfo(type_info); return nullptr; API_IMPL_END } @@ -325,7 +331,8 @@ ORT_API_STATUS_IMPL(OrtApis::GetValueType, _In_ const OrtValue* v, _Out_ ONNXTyp * \param value * \return The returned value should be freed by OrtReleaseTypeInfo after use */ -ORT_API_STATUS_IMPL(OrtApis::GetTypeInfo, _In_ const OrtValue* v, _Outptr_result_maybenull_ struct OrtTypeInfo** out) { +ORT_API_STATUS_IMPL(OrtApis::GetTypeInfo, + _In_ const OrtValue* v, _Outptr_result_maybenull_ struct OrtTypeInfo** out) { API_IMPL_BEGIN // TODO: This is consistent with the previous implementation but inconsistent with GetValueType which returns // ONNX_TYPE_UNKNOWN if v->Type() is null. Should we instead just call OrtTypeInfo::FromOrtValue and @@ -334,8 +341,8 @@ ORT_API_STATUS_IMPL(OrtApis::GetTypeInfo, _In_ const OrtValue* v, _Outptr_result *out = nullptr; return nullptr; } - - auto status = OrtTypeInfo::FromOrtValue(*v, out); - return status; + auto ptr = OrtTypeInfo::FromOrtValue(*v); + *out = ptr.release(); + return nullptr; API_IMPL_END } \ No newline at end of file diff --git a/onnxruntime/core/framework/tensor_type_and_shape.h b/onnxruntime/core/framework/tensor_type_and_shape.h index affc3c98a506c..9da1d8cd64145 100644 --- a/onnxruntime/core/framework/tensor_type_and_shape.h +++ b/onnxruntime/core/framework/tensor_type_and_shape.h @@ -2,25 +2,59 @@ // Licensed under the MIT License. #pragma once +#include #include #include #include "core/framework/tensor_shape.h" #include "core/session/onnxruntime_c_api.h" +namespace ONNX_NAMESPACE { +class TypeProto; +} + +namespace onnxruntime { +class DataTypeImpl; +} + struct OrtTensorTypeAndShapeInfo { public: + ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; onnxruntime::TensorShape shape; // dim_param values. empty string if dim_value or no dim_param was specified. // one entry per dimension in shape. only guaranteed to be populated for graph inputs and outputs std::vector dim_params; - OrtTensorTypeAndShapeInfo() = default; - OrtTensorTypeAndShapeInfo(const OrtTensorTypeAndShapeInfo& other) = delete; - OrtTensorTypeAndShapeInfo& operator=(const OrtTensorTypeAndShapeInfo& other) = delete; + OrtTensorTypeAndShapeInfo(); + ~OrtTensorTypeAndShapeInfo(); + + // Utils + static std::unique_ptr GetTensorShapeAndTypeHelper( + ONNXTensorElementDataType type, + onnxruntime::TensorShape shape, + const std::vector* dim_params); + + static std::unique_ptr GetTensorShapeAndType( + onnxruntime::TensorShape shape, + const onnxruntime::DataTypeImpl& tensor_data_type); + + static std::unique_ptr GetTensorShapeAndType( + onnxruntime::TensorShape shape, + const std::vector* dim_params, + const ONNX_NAMESPACE::TypeProto&); + + // We provide Clone() here to satisfy the existing coding pattern + // as we need copies made on the heap even though we achieve that + // via a copy __ctor which can not be made private due to make_unique + // which is a requirement. + std::unique_ptr Clone() const { + return std::make_unique(*this); + } - OrtStatus* Clone(OrtTensorTypeAndShapeInfo** out); + // Copy ops are public because std::make_unique above requires them to be accessible + OrtTensorTypeAndShapeInfo(const OrtTensorTypeAndShapeInfo& other); + OrtTensorTypeAndShapeInfo& operator=(const OrtTensorTypeAndShapeInfo& other); }; constexpr ONNXTensorElementDataType TensorDataTypeToOnnxRuntimeTensorElementDataType(int32_t dtype); diff --git a/onnxruntime/core/session/custom_ops.cc b/onnxruntime/core/session/custom_ops.cc index 924bc34b2b9ea..154e5302382b6 100644 --- a/onnxruntime/core/session/custom_ops.cc +++ b/onnxruntime/core/session/custom_ops.cc @@ -256,7 +256,9 @@ ORT_API_STATUS_IMPL(OrtApis::KernelInfo_GetInputTypeInfo, _In_ const OrtKernelIn return OrtApis::CreateStatus(ORT_INVALID_GRAPH, "::OrtKernelInfo input does not have a type"); } - return OrtTypeInfo::FromTypeProto(type_proto, type_info); + auto type_info_ret = OrtTypeInfo::FromTypeProto(*type_proto); + *type_info = type_info_ret.release(); + return nullptr; API_IMPL_END } @@ -277,7 +279,9 @@ ORT_API_STATUS_IMPL(OrtApis::KernelInfo_GetOutputTypeInfo, _In_ const OrtKernelI return OrtApis::CreateStatus(ORT_INVALID_GRAPH, "::OrtKernelInfo output does not have a type"); } - return OrtTypeInfo::FromTypeProto(type_proto, type_info); + auto type_info_ret = OrtTypeInfo::FromTypeProto(*type_proto); + *type_info = type_info_ret.release(); + return nullptr; API_IMPL_END } diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index 2283a0dbfe965..ae201eebb2fee 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -1324,7 +1324,9 @@ static ORT_STATUS_PTR GetNodeDefTypeInfoHelper(const OrtSession* sess, GetDefLis if (p.second->size() <= index) return OrtApis::CreateStatus(ORT_FAIL, "out of index"); const ONNX_NAMESPACE::TypeProto* type_proto = (*p.second)[index]->TypeAsProto(); - return OrtTypeInfo::FromTypeProto(type_proto, out); + auto type_info = OrtTypeInfo::FromTypeProto(*type_proto); + *out = type_info.release(); + return nullptr; API_IMPL_END } @@ -2703,6 +2705,8 @@ static constexpr OrtApi ort_api_1_to_15 = { &OrtApis::Logger_LogMessage, &OrtApis::Logger_GetLoggingSeverityLevel, &OrtApis::KernelInfoGetConstantInput_tensor, + &OrtApis::CastTypeInfoToOptionalTypeInfo, + &OrtApis::GetOptionalContainedTypeInfo }; // Asserts to do a some checks to ensure older Versions of the OrtApi never change (will detect an addition or deletion but not if they cancel out each other) diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index 67b425a5e2c91..7cdacbbaf9e72 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -453,4 +453,10 @@ ORT_API_STATUS_IMPL(Logger_GetLoggingSeverityLevel, _In_ const OrtLogger* logger ORT_API_STATUS_IMPL(KernelInfoGetConstantInput_tensor, _In_ const OrtKernelInfo* info, _In_ size_t index, _Out_ int* is_constant, _Outptr_ const OrtValue** out); + +ORT_API_STATUS_IMPL(CastTypeInfoToOptionalTypeInfo, _In_ const OrtTypeInfo* type_info, + _Outptr_result_maybenull_ const OrtOptionalTypeInfo** out); + +ORT_API_STATUS_IMPL(GetOptionalContainedTypeInfo, _In_ const OrtOptionalTypeInfo* optional_type_info, + _Outptr_ OrtTypeInfo** out); } // namespace OrtApis diff --git a/onnxruntime/test/framework/model_builder_utils.h b/onnxruntime/test/framework/model_builder_utils.h index b959d5cb2b571..6b8fdf03112c0 100644 --- a/onnxruntime/test/framework/model_builder_utils.h +++ b/onnxruntime/test/framework/model_builder_utils.h @@ -57,6 +57,29 @@ struct Type { dim->set_dim_param(d); } } + + static Type MakeSequence(const ONNX_NAMESPACE::TypeProto& element_proto) { + ONNX_NAMESPACE::TypeProto proto; + proto.mutable_sequence_type()->mutable_elem_type()->CopyFrom(element_proto); + return Type(std::move(proto)); + } + + static Type MakeMap(ONNX_NAMESPACE::TensorProto_DataType dtype, const ONNX_NAMESPACE::TypeProto& value_proto) { + ONNX_NAMESPACE::TypeProto proto; + auto& mut_map = *proto.mutable_map_type(); + mut_map.set_key_type(static_cast(dtype)); + mut_map.mutable_value_type()->CopyFrom(value_proto); + return Type(std::move(proto)); + } + + static Type MakeOptional(const ONNX_NAMESPACE::TypeProto& contained_proto) { + ONNX_NAMESPACE::TypeProto proto; + proto.mutable_optional_type()->mutable_elem_type()->CopyFrom(contained_proto); + return Type(std::move(proto)); + } + +private: + explicit Type(ONNX_NAMESPACE::TypeProto type_proto) : value(std::move(type_proto)) {} }; } // namespace modelbuilder diff --git a/onnxruntime/test/framework/type_info_test.cc b/onnxruntime/test/framework/type_info_test.cc new file mode 100644 index 0000000000000..ee787fb071d97 --- /dev/null +++ b/onnxruntime/test/framework/type_info_test.cc @@ -0,0 +1,84 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "gtest/gtest.h" +#include "gmock/gmock.h" + +#include "model_builder_utils.h" + +#include "core/framework/onnxruntime_optional_type_info.h" +#include "core/framework/onnxruntime_map_type_info.h" +#include "core/framework/onnxruntime_sequence_type_info.h" +#include "core/framework/tensor_type_and_shape.h" +#include "core/framework/onnxruntime_typeinfo.h" + +namespace onnxruntime { +namespace test { + +namespace mb = modelbuilder; + +TEST(TypeInfoTests, TensorProto) { + mb::Type tensor_type = {1, 2, 3, 4}; + + auto tensor_type_info = OrtTypeInfo::FromTypeProto(tensor_type.value); + ASSERT_EQ(ONNX_TYPE_TENSOR, tensor_type_info->type); + ASSERT_NE(nullptr, tensor_type_info->data); + ASSERT_EQ(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, tensor_type_info->data->type); + ASSERT_TRUE(SpanEq(AsSpan({1, 2, 3, 4}), tensor_type_info->data->shape.GetDims())); +} + +TEST(TypeInfoTests, SequenceWithTensorElement) { + mb::Type tensor_type = {1, 2, 3, 4}; + auto sequence_proto = mb::Type::MakeSequence(tensor_type.value); + auto seq_type_info = OrtTypeInfo::FromTypeProto(sequence_proto.value); + + ASSERT_EQ(ONNX_TYPE_SEQUENCE, seq_type_info->type); + ASSERT_NE(nullptr, seq_type_info->sequence_type_info); + const auto& tensor_type_info = *seq_type_info->sequence_type_info->sequence_key_type_; + + ASSERT_EQ(ONNX_TYPE_TENSOR, tensor_type_info.type); + ASSERT_NE(nullptr, tensor_type_info.data); + ASSERT_EQ(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, tensor_type_info.data->type); + ASSERT_TRUE(SpanEq(AsSpan({1, 2, 3, 4}), tensor_type_info.data->shape.GetDims())); +} + +TEST(TypeInfoTests, OptionalWithTensorProto) { + mb::Type tensor_type = {1, 2, 3, 4}; + auto optional_proto = mb::Type::MakeOptional(tensor_type.value); + + auto optional_type_info = OrtTypeInfo::FromTypeProto(optional_proto.value); + + ASSERT_EQ(ONNX_TYPE_OPTIONAL, optional_type_info->type); + ASSERT_NE(nullptr, optional_type_info->optional_type_info); + ASSERT_NE(nullptr, optional_type_info->optional_type_info->contained_type_); + + const auto& contained_type = *optional_type_info->optional_type_info->contained_type_; + ASSERT_EQ(ONNX_TYPE_TENSOR, contained_type.type); + ASSERT_NE(nullptr, contained_type.data); + ASSERT_EQ(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, contained_type.data->type); + ASSERT_TRUE(SpanEq(AsSpan({1, 2, 3, 4}), contained_type.data->shape.GetDims())); +} + +#if !defined(DISABLE_ML_OPS) +TEST(TypeInfoTests, MapWithTensorValue) { + mb::Type value_type = {1, 2, 3, 4}; + auto map_proto = mb::Type::MakeMap(ONNX_NAMESPACE::TensorProto_DataType_FLOAT, value_type.value); + auto map_type_info = OrtTypeInfo::FromTypeProto(map_proto.value); + + ASSERT_EQ(ONNX_TYPE_MAP, map_type_info->type); + ASSERT_NE(nullptr, map_type_info->map_type_info); + const auto& map_info = *map_type_info->map_type_info; + + ASSERT_EQ(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, map_info.map_key_type_); + ASSERT_NE(nullptr, map_info.map_value_type_); + const auto& tensor_type_info = *map_info.map_value_type_; + + ASSERT_EQ(ONNX_TYPE_TENSOR, tensor_type_info.type); + ASSERT_NE(nullptr, tensor_type_info.data); + ASSERT_EQ(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, tensor_type_info.data->type); + ASSERT_TRUE(SpanEq(AsSpan({1, 2, 3, 4}), tensor_type_info.data->shape.GetDims())); +} +#endif + +} // namespace test +} // namespace onnxruntime \ No newline at end of file diff --git a/tools/ci_build/github/Doxyfile_csharp.cfg b/tools/ci_build/github/Doxyfile_csharp.cfg index 78fb4d5e9af54..dccc15ed1137d 100644 --- a/tools/ci_build/github/Doxyfile_csharp.cfg +++ b/tools/ci_build/github/Doxyfile_csharp.cfg @@ -1,20 +1,146 @@ -## Onnxruntime C# API Doxygen configuration file -# Doxyfile 1.8.20 +# Doxyfile 1.9.4 + +# This file describes the settings to be used by the documentation system +# doxygen (www.doxygen.org) for a project. +# +# All text after a double hash (##) is considered a comment and is placed in +# front of the TAG it is preceding. +# +# All text after a single hash (#) is considered a comment and will be ignored. +# The format is: +# TAG = value [value, ...] +# For lists, items can also be appended using: +# TAG += value [value, ...] +# Values that contain spaces should be placed between quotes (\" \"). +# +# Note: +# +# Use doxygen to compare the used configuration file with the template +# configuration file: +# doxygen -x [configFile] +# Use doxygen to compare the used configuration file with the template +# configuration file without replacing the environment variables: +# doxygen -x_noenv [configFile] #--------------------------------------------------------------------------- # Project related configuration options #--------------------------------------------------------------------------- + +# This tag specifies the encoding used for all characters in the configuration +# file that follow. The default is UTF-8 which is also the encoding used for all +# text before the first occurrence of this tag. Doxygen uses libiconv (or the +# iconv built into libc) for the transcoding. See +# https://www.gnu.org/software/libiconv/ for the list of possible encodings. +# The default value is: UTF-8. + +## Onnxruntime C# API Doxygen configuration file + DOXYFILE_ENCODING = UTF-8 -PROJECT_NAME = "Onnxruntime" + +# The PROJECT_NAME tag is a single word (or a sequence of words surrounded by +# double-quotes, unless you are using Doxywizard) that should identify the +# project for which the documentation is generated. This name is used in the +# title of most generated pages and in a few other places. +# The default value is: My Project. + +PROJECT_NAME = Onnxruntime + +# The PROJECT_NUMBER tag can be used to enter a project or revision number. This +# could be handy for archiving the generated documentation or if some version +# control system is used. + PROJECT_NUMBER = + +# Using the PROJECT_BRIEF tag one can provide an optional one line description +# for a project that appears at the top of each page and should give viewer a +# quick idea about the purpose of the project. Keep the description short. + PROJECT_BRIEF = + +# With the PROJECT_LOGO tag one can specify a logo or an icon that is included +# in the documentation. The maximum height of the logo should not exceed 55 +# pixels and the maximum width should not exceed 200 pixels. Doxygen will copy +# the logo to the output directory. + PROJECT_LOGO = + +# The OUTPUT_DIRECTORY tag is used to specify the (relative or absolute) path +# into which the generated documentation will be written. If a relative path is +# entered, it will be relative to the location where doxygen was started. If +# left blank the current directory will be used. + OUTPUT_DIRECTORY = $(ORT_DOXY_OUT)\csharp_dox + +# If the CREATE_SUBDIRS tag is set to YES then doxygen will create up to 4096 +# sub-directories (in 2 levels) under the output directory of each output format +# and will distribute the generated files over these directories. Enabling this +# option can be useful when feeding doxygen a huge amount of source files, where +# putting all generated files in the same directory would otherwise causes +# performance problems for the file system. Adapt CREATE_SUBDIRS_LEVEL to +# control the number of sub-directories. +# The default value is: NO. + CREATE_SUBDIRS = NO + +# Controls the number of sub-directories that will be created when +# CREATE_SUBDIRS tag is set to YES. Level 0 represents 16 directories, and every +# level increment doubles the number of directories, resulting in 4096 +# directories at level 8 which is the default and also the maximum value. The +# sub-directories are organized in 2 levels, the first level always has a fixed +# numer of 16 directories. +# Minimum value: 0, maximum value: 8, default value: 8. +# This tag requires that the tag CREATE_SUBDIRS is set to YES. + +CREATE_SUBDIRS_LEVEL = 8 + +# If the ALLOW_UNICODE_NAMES tag is set to YES, doxygen will allow non-ASCII +# characters to appear in the names of generated files. If set to NO, non-ASCII +# characters will be escaped, for example _xE3_x81_x84 will be used for Unicode +# U+3044. +# The default value is: NO. + ALLOW_UNICODE_NAMES = NO + +# The OUTPUT_LANGUAGE tag is used to specify the language in which all +# documentation generated by doxygen is written. Doxygen will use this +# information to generate all constant output in the proper language. +# Possible values are: Afrikaans, Arabic, Armenian, Brazilian, Bulgarian, +# Catalan, Chinese, Chinese-Traditional, Croatian, Czech, Danish, Dutch, English +# (United States), Esperanto, Farsi (Persian), Finnish, French, German, Greek, +# Hindi, Hungarian, Indonesian, Italian, Japanese, Japanese-en (Japanese with +# English messages), Korean, Korean-en (Korean with English messages), Latvian, +# Lithuanian, Macedonian, Norwegian, Persian (Farsi), Polish, Portuguese, +# Romanian, Russian, Serbian, Serbian-Cyrillic, Slovak, Slovene, Spanish, +# Swedish, Turkish, Ukrainian and Vietnamese. +# The default value is: English. + OUTPUT_LANGUAGE = English + +# If the BRIEF_MEMBER_DESC tag is set to YES, doxygen will include brief member +# descriptions after the members that are listed in the file and class +# documentation (similar to Javadoc). Set to NO to disable this. +# The default value is: YES. + BRIEF_MEMBER_DESC = YES + +# If the REPEAT_BRIEF tag is set to YES, doxygen will prepend the brief +# description of a member or function before the detailed description +# +# Note: If both HIDE_UNDOC_MEMBERS and BRIEF_MEMBER_DESC are set to NO, the +# brief descriptions will be completely suppressed. +# The default value is: YES. + REPEAT_BRIEF = YES + +# This tag implements a quasi-intelligent brief description abbreviator that is +# used to form the text in various listings. Each string in this list, if found +# as the leading text of the brief description, will be stripped from the text +# and the result, after processing the whole list, is used as the annotated +# text. Otherwise, the brief description is used as-is. If left blank, the +# following values are used ($name is automatically replaced with the name of +# the entity):The $name class, The $name widget, The $name file, is, provides, +# specifies, contains, represents, a, an and the. + ABBREVIATE_BRIEF = "The $name class" \ "The $name widget" \ "The $name file" \ @@ -26,313 +152,2539 @@ ABBREVIATE_BRIEF = "The $name class" \ a \ an \ the + +# If the ALWAYS_DETAILED_SEC and REPEAT_BRIEF tags are both set to YES then +# doxygen will generate a detailed section even if there is only a brief +# description. +# The default value is: NO. + ALWAYS_DETAILED_SEC = NO + +# If the INLINE_INHERITED_MEMB tag is set to YES, doxygen will show all +# inherited members of a class in the documentation of that class as if those +# members were ordinary class members. Constructors, destructors and assignment +# operators of the base classes will not be shown. +# The default value is: NO. + INLINE_INHERITED_MEMB = NO + +# If the FULL_PATH_NAMES tag is set to YES, doxygen will prepend the full path +# before files name in the file list and in the header files. If set to NO the +# shortest path that makes the file name unique will be used +# The default value is: YES. + FULL_PATH_NAMES = YES + +# The STRIP_FROM_PATH tag can be used to strip a user-defined part of the path. +# Stripping is only done if one of the specified strings matches the left-hand +# part of the path. The tag can be used to show relative paths in the file list. +# If left blank the directory from which doxygen is run is used as the path to +# strip. +# +# Note that you can specify absolute paths here, but also relative paths, which +# will be relative from the directory where doxygen is started. +# This tag requires that the tag FULL_PATH_NAMES is set to YES. + STRIP_FROM_PATH = + +# The STRIP_FROM_INC_PATH tag can be used to strip a user-defined part of the +# path mentioned in the documentation of a class, which tells the reader which +# header file to include in order to use a class. If left blank only the name of +# the header file containing the class definition is used. Otherwise one should +# specify the list of include paths that are normally passed to the compiler +# using the -I flag. + STRIP_FROM_INC_PATH = + +# If the SHORT_NAMES tag is set to YES, doxygen will generate much shorter (but +# less readable) file names. This can be useful is your file systems doesn't +# support long names like on DOS, Mac, or CD-ROM. +# The default value is: NO. + SHORT_NAMES = NO + +# If the JAVADOC_AUTOBRIEF tag is set to YES then doxygen will interpret the +# first line (until the first dot) of a Javadoc-style comment as the brief +# description. If set to NO, the Javadoc-style will behave just like regular Qt- +# style comments (thus requiring an explicit @brief command for a brief +# description.) +# The default value is: NO. + JAVADOC_AUTOBRIEF = NO + +# If the JAVADOC_BANNER tag is set to YES then doxygen will interpret a line +# such as +# /*************** +# as being the beginning of a Javadoc-style comment "banner". If set to NO, the +# Javadoc-style will behave just like regular comments and it will not be +# interpreted by doxygen. +# The default value is: NO. + JAVADOC_BANNER = NO + +# If the QT_AUTOBRIEF tag is set to YES then doxygen will interpret the first +# line (until the first dot) of a Qt-style comment as the brief description. If +# set to NO, the Qt-style will behave just like regular Qt-style comments (thus +# requiring an explicit \brief command for a brief description.) +# The default value is: NO. + QT_AUTOBRIEF = NO + +# The MULTILINE_CPP_IS_BRIEF tag can be set to YES to make doxygen treat a +# multi-line C++ special comment block (i.e. a block of //! or /// comments) as +# a brief description. This used to be the default behavior. The new default is +# to treat a multi-line C++ comment block as a detailed description. Set this +# tag to YES if you prefer the old behavior instead. +# +# Note that setting this tag to YES also means that rational rose comments are +# not recognized any more. +# The default value is: NO. + MULTILINE_CPP_IS_BRIEF = NO + +# By default Python docstrings are displayed as preformatted text and doxygen's +# special commands cannot be used. By setting PYTHON_DOCSTRING to NO the +# doxygen's special commands can be used and the contents of the docstring +# documentation blocks is shown as doxygen documentation. +# The default value is: YES. + PYTHON_DOCSTRING = YES + +# If the INHERIT_DOCS tag is set to YES then an undocumented member inherits the +# documentation from any documented member that it re-implements. +# The default value is: YES. + INHERIT_DOCS = YES + +# If the SEPARATE_MEMBER_PAGES tag is set to YES then doxygen will produce a new +# page for each member. If set to NO, the documentation of a member will be part +# of the file/class/namespace that contains it. +# The default value is: NO. + SEPARATE_MEMBER_PAGES = NO + +# The TAB_SIZE tag can be used to set the number of spaces in a tab. Doxygen +# uses this value to replace tabs by spaces in code fragments. +# Minimum value: 1, maximum value: 16, default value: 4. + TAB_SIZE = 4 + +# This tag can be used to specify a number of aliases that act as commands in +# the documentation. An alias has the form: +# name=value +# For example adding +# "sideeffect=@par Side Effects:^^" +# will allow you to put the command \sideeffect (or @sideeffect) in the +# documentation, which will result in a user-defined paragraph with heading +# "Side Effects:". Note that you cannot put \n's in the value part of an alias +# to insert newlines (in the resulting output). You can put ^^ in the value part +# of an alias to insert a newline as if a physical newline was in the original +# file. When you need a literal { or } or , in the value part of an alias you +# have to escape them by means of a backslash (\), this can lead to conflicts +# with the commands \{ and \} for these it is advised to use the version @{ and +# @} or use a double escape (\\{ and \\}) + ALIASES = + +# Set the OPTIMIZE_OUTPUT_FOR_C tag to YES if your project consists of C sources +# only. Doxygen will then generate output that is more tailored for C. For +# instance, some of the names that are used will be different. The list of all +# members will be omitted, etc. +# The default value is: NO. + OPTIMIZE_OUTPUT_FOR_C = NO + +# Set the OPTIMIZE_OUTPUT_JAVA tag to YES if your project consists of Java or +# Python sources only. Doxygen will then generate output that is more tailored +# for that language. For instance, namespaces will be presented as packages, +# qualified scopes will look different, etc. +# The default value is: NO. + OPTIMIZE_OUTPUT_JAVA = NO + +# Set the OPTIMIZE_FOR_FORTRAN tag to YES if your project consists of Fortran +# sources. Doxygen will then generate output that is tailored for Fortran. +# The default value is: NO. + OPTIMIZE_FOR_FORTRAN = NO + +# Set the OPTIMIZE_OUTPUT_VHDL tag to YES if your project consists of VHDL +# sources. Doxygen will then generate output that is tailored for VHDL. +# The default value is: NO. + OPTIMIZE_OUTPUT_VHDL = NO + +# Set the OPTIMIZE_OUTPUT_SLICE tag to YES if your project consists of Slice +# sources only. Doxygen will then generate output that is more tailored for that +# language. For instance, namespaces will be presented as modules, types will be +# separated into more groups, etc. +# The default value is: NO. + OPTIMIZE_OUTPUT_SLICE = NO + +# Doxygen selects the parser to use depending on the extension of the files it +# parses. With this tag you can assign which parser to use for a given +# extension. Doxygen has a built-in mapping, but you can override or extend it +# using this tag. The format is ext=language, where ext is a file extension, and +# language is one of the parsers supported by doxygen: IDL, Java, JavaScript, +# Csharp (C#), C, C++, Lex, D, PHP, md (Markdown), Objective-C, Python, Slice, +# VHDL, Fortran (fixed format Fortran: FortranFixed, free formatted Fortran: +# FortranFree, unknown formatted Fortran: Fortran. In the later case the parser +# tries to guess whether the code is fixed or free formatted code, this is the +# default for Fortran type files). For instance to make doxygen treat .inc files +# as Fortran files (default is PHP), and .f files as C (default is Fortran), +# use: inc=Fortran f=C. +# +# Note: For files without extension you can use no_extension as a placeholder. +# +# Note that for custom extensions you also need to set FILE_PATTERNS otherwise +# the files are not read by doxygen. When specifying no_extension you should add +# * to the FILE_PATTERNS. +# +# Note see also the list of default file extension mappings. + EXTENSION_MAPPING = + +# If the MARKDOWN_SUPPORT tag is enabled then doxygen pre-processes all comments +# according to the Markdown format, which allows for more readable +# documentation. See https://daringfireball.net/projects/markdown/ for details. +# The output of markdown processing is further processed by doxygen, so you can +# mix doxygen, HTML, and XML commands with Markdown formatting. Disable only in +# case of backward compatibilities issues. +# The default value is: YES. + MARKDOWN_SUPPORT = YES + +# When the TOC_INCLUDE_HEADINGS tag is set to a non-zero value, all headings up +# to that level are automatically included in the table of contents, even if +# they do not have an id attribute. +# Note: This feature currently applies only to Markdown headings. +# Minimum value: 0, maximum value: 99, default value: 5. +# This tag requires that the tag MARKDOWN_SUPPORT is set to YES. + TOC_INCLUDE_HEADINGS = 5 + +# When enabled doxygen tries to link words that correspond to documented +# classes, or namespaces to their corresponding documentation. Such a link can +# be prevented in individual cases by putting a % sign in front of the word or +# globally by setting AUTOLINK_SUPPORT to NO. +# The default value is: YES. + AUTOLINK_SUPPORT = YES + +# If you use STL classes (i.e. std::string, std::vector, etc.) but do not want +# to include (a tag file for) the STL sources as input, then you should set this +# tag to YES in order to let doxygen match functions declarations and +# definitions whose arguments contain STL classes (e.g. func(std::string); +# versus func(std::string) {}). This also make the inheritance and collaboration +# diagrams that involve STL classes more complete and accurate. +# The default value is: NO. + BUILTIN_STL_SUPPORT = NO + +# If you use Microsoft's C++/CLI language, you should set this option to YES to +# enable parsing support. +# The default value is: NO. + CPP_CLI_SUPPORT = NO + +# Set the SIP_SUPPORT tag to YES if your project consists of sip (see: +# https://www.riverbankcomputing.com/software/sip/intro) sources only. Doxygen +# will parse them like normal C++ but will assume all classes use public instead +# of private inheritance when no explicit protection keyword is present. +# The default value is: NO. + SIP_SUPPORT = NO + +# For Microsoft's IDL there are propget and propput attributes to indicate +# getter and setter methods for a property. Setting this option to YES will make +# doxygen to replace the get and set methods by a property in the documentation. +# This will only work if the methods are indeed getting or setting a simple +# type. If this is not the case, or you want to show the methods anyway, you +# should set this option to NO. +# The default value is: YES. + IDL_PROPERTY_SUPPORT = YES + +# If member grouping is used in the documentation and the DISTRIBUTE_GROUP_DOC +# tag is set to YES then doxygen will reuse the documentation of the first +# member in the group (if any) for the other members of the group. By default +# all members of a group must be documented explicitly. +# The default value is: NO. + DISTRIBUTE_GROUP_DOC = NO + +# If one adds a struct or class to a group and this option is enabled, then also +# any nested class or struct is added to the same group. By default this option +# is disabled and one has to add nested compounds explicitly via \ingroup. +# The default value is: NO. + GROUP_NESTED_COMPOUNDS = NO + +# Set the SUBGROUPING tag to YES to allow class member groups of the same type +# (for instance a group of public functions) to be put as a subgroup of that +# type (e.g. under the Public Functions section). Set it to NO to prevent +# subgrouping. Alternatively, this can be done per class using the +# \nosubgrouping command. +# The default value is: YES. + SUBGROUPING = YES + +# When the INLINE_GROUPED_CLASSES tag is set to YES, classes, structs and unions +# are shown inside the group in which they are included (e.g. using \ingroup) +# instead of on a separate page (for HTML and Man pages) or section (for LaTeX +# and RTF). +# +# Note that this feature does not work in combination with +# SEPARATE_MEMBER_PAGES. +# The default value is: NO. + INLINE_GROUPED_CLASSES = NO + +# When the INLINE_SIMPLE_STRUCTS tag is set to YES, structs, classes, and unions +# with only public data fields or simple typedef fields will be shown inline in +# the documentation of the scope in which they are defined (i.e. file, +# namespace, or group documentation), provided this scope is documented. If set +# to NO, structs, classes, and unions are shown on a separate page (for HTML and +# Man pages) or section (for LaTeX and RTF). +# The default value is: NO. + INLINE_SIMPLE_STRUCTS = NO + +# When TYPEDEF_HIDES_STRUCT tag is enabled, a typedef of a struct, union, or +# enum is documented as struct, union, or enum with the name of the typedef. So +# typedef struct TypeS {} TypeT, will appear in the documentation as a struct +# with name TypeT. When disabled the typedef will appear as a member of a file, +# namespace, or class. And the struct will be named TypeS. This can typically be +# useful for C code in case the coding convention dictates that all compound +# types are typedef'ed and only the typedef is referenced, never the tag name. +# The default value is: NO. + TYPEDEF_HIDES_STRUCT = NO + +# The size of the symbol lookup cache can be set using LOOKUP_CACHE_SIZE. This +# cache is used to resolve symbols given their name and scope. Since this can be +# an expensive process and often the same symbol appears multiple times in the +# code, doxygen keeps a cache of pre-resolved symbols. If the cache is too small +# doxygen will become slower. If the cache is too large, memory is wasted. The +# cache size is given by this formula: 2^(16+LOOKUP_CACHE_SIZE). The valid range +# is 0..9, the default is 0, corresponding to a cache size of 2^16=65536 +# symbols. At the end of a run doxygen will report the cache usage and suggest +# the optimal cache size from a speed point of view. +# Minimum value: 0, maximum value: 9, default value: 0. + LOOKUP_CACHE_SIZE = 0 + +# The NUM_PROC_THREADS specifies the number of threads doxygen is allowed to use +# during processing. When set to 0 doxygen will based this on the number of +# cores available in the system. You can set it explicitly to a value larger +# than 0 to get more control over the balance between CPU load and processing +# speed. At this moment only the input processing can be done using multiple +# threads. Since this is still an experimental feature the default is set to 1, +# which effectively disables parallel processing. Please report any issues you +# encounter. Generating dot graphs in parallel is controlled by the +# DOT_NUM_THREADS setting. +# Minimum value: 0, maximum value: 32, default value: 1. + NUM_PROC_THREADS = 1 + #--------------------------------------------------------------------------- # Build related configuration options #--------------------------------------------------------------------------- + +# If the EXTRACT_ALL tag is set to YES, doxygen will assume all entities in +# documentation are documented, even if no documentation was available. Private +# class members and static file members will be hidden unless the +# EXTRACT_PRIVATE respectively EXTRACT_STATIC tags are set to YES. +# Note: This will also disable the warnings about undocumented members that are +# normally produced when WARNINGS is set to YES. +# The default value is: NO. + EXTRACT_ALL = NO + +# If the EXTRACT_PRIVATE tag is set to YES, all private members of a class will +# be included in the documentation. +# The default value is: NO. + EXTRACT_PRIVATE = NO + +# If the EXTRACT_PRIV_VIRTUAL tag is set to YES, documented private virtual +# methods of a class will be included in the documentation. +# The default value is: NO. + EXTRACT_PRIV_VIRTUAL = NO + +# If the EXTRACT_PACKAGE tag is set to YES, all members with package or internal +# scope will be included in the documentation. +# The default value is: NO. + EXTRACT_PACKAGE = NO + +# If the EXTRACT_STATIC tag is set to YES, all static members of a file will be +# included in the documentation. +# The default value is: NO. + EXTRACT_STATIC = NO + +# If the EXTRACT_LOCAL_CLASSES tag is set to YES, classes (and structs) defined +# locally in source files will be included in the documentation. If set to NO, +# only classes defined in header files are included. Does not have any effect +# for Java sources. +# The default value is: YES. + EXTRACT_LOCAL_CLASSES = YES + +# This flag is only useful for Objective-C code. If set to YES, local methods, +# which are defined in the implementation section but not in the interface are +# included in the documentation. If set to NO, only methods in the interface are +# included. +# The default value is: NO. + EXTRACT_LOCAL_METHODS = NO + +# If this flag is set to YES, the members of anonymous namespaces will be +# extracted and appear in the documentation as a namespace called +# 'anonymous_namespace{file}', where file will be replaced with the base name of +# the file that contains the anonymous namespace. By default anonymous namespace +# are hidden. +# The default value is: NO. + EXTRACT_ANON_NSPACES = NO + +# If this flag is set to YES, the name of an unnamed parameter in a declaration +# will be determined by the corresponding definition. By default unnamed +# parameters remain unnamed in the output. +# The default value is: YES. + +RESOLVE_UNNAMED_PARAMS = YES + +# If the HIDE_UNDOC_MEMBERS tag is set to YES, doxygen will hide all +# undocumented members inside documented classes or files. If set to NO these +# members will be included in the various overviews, but no documentation +# section is generated. This option has no effect if EXTRACT_ALL is enabled. +# The default value is: NO. + HIDE_UNDOC_MEMBERS = NO + +# If the HIDE_UNDOC_CLASSES tag is set to YES, doxygen will hide all +# undocumented classes that are normally visible in the class hierarchy. If set +# to NO, these classes will be included in the various overviews. This option +# has no effect if EXTRACT_ALL is enabled. +# The default value is: NO. + HIDE_UNDOC_CLASSES = NO + +# If the HIDE_FRIEND_COMPOUNDS tag is set to YES, doxygen will hide all friend +# declarations. If set to NO, these declarations will be included in the +# documentation. +# The default value is: NO. + HIDE_FRIEND_COMPOUNDS = NO + +# If the HIDE_IN_BODY_DOCS tag is set to YES, doxygen will hide any +# documentation blocks found inside the body of a function. If set to NO, these +# blocks will be appended to the function's detailed documentation block. +# The default value is: NO. + HIDE_IN_BODY_DOCS = NO + +# The INTERNAL_DOCS tag determines if documentation that is typed after a +# \internal command is included. If the tag is set to NO then the documentation +# will be excluded. Set it to YES to include the internal documentation. +# The default value is: NO. + INTERNAL_DOCS = NO + +# With the correct setting of option CASE_SENSE_NAMES doxygen will better be +# able to match the capabilities of the underlying filesystem. In case the +# filesystem is case sensitive (i.e. it supports files in the same directory +# whose names only differ in casing), the option must be set to YES to properly +# deal with such files in case they appear in the input. For filesystems that +# are not case sensitive the option should be set to NO to properly deal with +# output files written for symbols that only differ in casing, such as for two +# classes, one named CLASS and the other named Class, and to also support +# references to files without having to specify the exact matching casing. On +# Windows (including Cygwin) and MacOS, users should typically set this option +# to NO, whereas on Linux or other Unix flavors it should typically be set to +# YES. +# The default value is: system dependent. + CASE_SENSE_NAMES = NO + +# If the HIDE_SCOPE_NAMES tag is set to NO then doxygen will show members with +# their full class and namespace scopes in the documentation. If set to YES, the +# scope will be hidden. +# The default value is: NO. + HIDE_SCOPE_NAMES = NO + +# If the HIDE_COMPOUND_REFERENCE tag is set to NO (default) then doxygen will +# append additional text to a page's title, such as Class Reference. If set to +# YES the compound reference will be hidden. +# The default value is: NO. + HIDE_COMPOUND_REFERENCE= NO + +# If the SHOW_HEADERFILE tag is set to YES then the documentation for a class +# will show which file needs to be included to use the class. +# The default value is: YES. + +SHOW_HEADERFILE = YES + +# If the SHOW_INCLUDE_FILES tag is set to YES then doxygen will put a list of +# the files that are included by a file in the documentation of that file. +# The default value is: YES. + SHOW_INCLUDE_FILES = YES + +# If the SHOW_GROUPED_MEMB_INC tag is set to YES then Doxygen will add for each +# grouped member an include statement to the documentation, telling the reader +# which file to include in order to use the member. +# The default value is: NO. + SHOW_GROUPED_MEMB_INC = NO + +# If the FORCE_LOCAL_INCLUDES tag is set to YES then doxygen will list include +# files with double quotes in the documentation rather than with sharp brackets. +# The default value is: NO. + FORCE_LOCAL_INCLUDES = NO + +# If the INLINE_INFO tag is set to YES then a tag [inline] is inserted in the +# documentation for inline members. +# The default value is: YES. + INLINE_INFO = YES + +# If the SORT_MEMBER_DOCS tag is set to YES then doxygen will sort the +# (detailed) documentation of file and class members alphabetically by member +# name. If set to NO, the members will appear in declaration order. +# The default value is: YES. + SORT_MEMBER_DOCS = YES + +# If the SORT_BRIEF_DOCS tag is set to YES then doxygen will sort the brief +# descriptions of file, namespace and class members alphabetically by member +# name. If set to NO, the members will appear in declaration order. Note that +# this will also influence the order of the classes in the class list. +# The default value is: NO. + SORT_BRIEF_DOCS = NO + +# If the SORT_MEMBERS_CTORS_1ST tag is set to YES then doxygen will sort the +# (brief and detailed) documentation of class members so that constructors and +# destructors are listed first. If set to NO the constructors will appear in the +# respective orders defined by SORT_BRIEF_DOCS and SORT_MEMBER_DOCS. +# Note: If SORT_BRIEF_DOCS is set to NO this option is ignored for sorting brief +# member documentation. +# Note: If SORT_MEMBER_DOCS is set to NO this option is ignored for sorting +# detailed member documentation. +# The default value is: NO. + SORT_MEMBERS_CTORS_1ST = NO + +# If the SORT_GROUP_NAMES tag is set to YES then doxygen will sort the hierarchy +# of group names into alphabetical order. If set to NO the group names will +# appear in their defined order. +# The default value is: NO. + SORT_GROUP_NAMES = NO + +# If the SORT_BY_SCOPE_NAME tag is set to YES, the class list will be sorted by +# fully-qualified names, including namespaces. If set to NO, the class list will +# be sorted only by class name, not including the namespace part. +# Note: This option is not very useful if HIDE_SCOPE_NAMES is set to YES. +# Note: This option applies only to the class list, not to the alphabetical +# list. +# The default value is: NO. + SORT_BY_SCOPE_NAME = NO + +# If the STRICT_PROTO_MATCHING option is enabled and doxygen fails to do proper +# type resolution of all parameters of a function it will reject a match between +# the prototype and the implementation of a member function even if there is +# only one candidate or it is obvious which candidate to choose by doing a +# simple string match. By disabling STRICT_PROTO_MATCHING doxygen will still +# accept a match between prototype and implementation in such cases. +# The default value is: NO. + STRICT_PROTO_MATCHING = NO + +# The GENERATE_TODOLIST tag can be used to enable (YES) or disable (NO) the todo +# list. This list is created by putting \todo commands in the documentation. +# The default value is: YES. + GENERATE_TODOLIST = YES + +# The GENERATE_TESTLIST tag can be used to enable (YES) or disable (NO) the test +# list. This list is created by putting \test commands in the documentation. +# The default value is: YES. + GENERATE_TESTLIST = YES + +# The GENERATE_BUGLIST tag can be used to enable (YES) or disable (NO) the bug +# list. This list is created by putting \bug commands in the documentation. +# The default value is: YES. + GENERATE_BUGLIST = YES + +# The GENERATE_DEPRECATEDLIST tag can be used to enable (YES) or disable (NO) +# the deprecated list. This list is created by putting \deprecated commands in +# the documentation. +# The default value is: YES. + GENERATE_DEPRECATEDLIST= YES + +# The ENABLED_SECTIONS tag can be used to enable conditional documentation +# sections, marked by \if ... \endif and \cond +# ... \endcond blocks. + ENABLED_SECTIONS = + +# The MAX_INITIALIZER_LINES tag determines the maximum number of lines that the +# initial value of a variable or macro / define can have for it to appear in the +# documentation. If the initializer consists of more lines than specified here +# it will be hidden. Use a value of 0 to hide initializers completely. The +# appearance of the value of individual variables and macros / defines can be +# controlled using \showinitializer or \hideinitializer command in the +# documentation regardless of this setting. +# Minimum value: 0, maximum value: 10000, default value: 30. + MAX_INITIALIZER_LINES = 30 + +# Set the SHOW_USED_FILES tag to NO to disable the list of files generated at +# the bottom of the documentation of classes and structs. If set to YES, the +# list will mention the files that were used to generate the documentation. +# The default value is: YES. + SHOW_USED_FILES = NO + +# Set the SHOW_FILES tag to NO to disable the generation of the Files page. This +# will remove the Files entry from the Quick Index and from the Folder Tree View +# (if specified). +# The default value is: YES. + SHOW_FILES = YES + +# Set the SHOW_NAMESPACES tag to NO to disable the generation of the Namespaces +# page. This will remove the Namespaces entry from the Quick Index and from the +# Folder Tree View (if specified). +# The default value is: YES. + SHOW_NAMESPACES = YES + +# The FILE_VERSION_FILTER tag can be used to specify a program or script that +# doxygen should invoke to get the current version for each file (typically from +# the version control system). Doxygen will invoke the program by executing (via +# popen()) the command command input-file, where command is the value of the +# FILE_VERSION_FILTER tag, and input-file is the name of an input file provided +# by doxygen. Whatever the program writes to standard output is used as the file +# version. For an example see the documentation. + FILE_VERSION_FILTER = "git -C $(ORT_DOXY_SRC) log -n 1 --format=%h -- afile" + +# The LAYOUT_FILE tag can be used to specify a layout file which will be parsed +# by doxygen. The layout file controls the global structure of the generated +# output files in an output format independent way. To create the layout file +# that represents doxygen's defaults, run doxygen with the -l option. You can +# optionally specify a file name after the option, if omitted DoxygenLayout.xml +# will be used as the name of the layout file. See also section "Changing the +# layout of pages" for information. +# +# Note that if you run doxygen from a directory containing a file called +# DoxygenLayout.xml, doxygen will parse it automatically even if the LAYOUT_FILE +# tag is left empty. + LAYOUT_FILE = + +# The CITE_BIB_FILES tag can be used to specify one or more bib files containing +# the reference definitions. This must be a list of .bib files. The .bib +# extension is automatically appended if omitted. This requires the bibtex tool +# to be installed. See also https://en.wikipedia.org/wiki/BibTeX for more info. +# For LaTeX the style of the bibliography can be controlled using +# LATEX_BIB_STYLE. To use this feature you need bibtex and perl available in the +# search path. See also \cite for info how to create references. + CITE_BIB_FILES = + #--------------------------------------------------------------------------- # Configuration options related to warning and progress messages #--------------------------------------------------------------------------- + +# The QUIET tag can be used to turn on/off the messages that are generated to +# standard output by doxygen. If QUIET is set to YES this implies that the +# messages are off. +# The default value is: NO. + QUIET = NO + +# The WARNINGS tag can be used to turn on/off the warning messages that are +# generated to standard error (stderr) by doxygen. If WARNINGS is set to YES +# this implies that the warnings are on. +# +# Tip: Turn warnings on while writing the documentation. +# The default value is: YES. + WARNINGS = YES + +# If the WARN_IF_UNDOCUMENTED tag is set to YES then doxygen will generate +# warnings for undocumented members. If EXTRACT_ALL is set to YES then this flag +# will automatically be disabled. +# The default value is: YES. + WARN_IF_UNDOCUMENTED = YES + +# If the WARN_IF_DOC_ERROR tag is set to YES, doxygen will generate warnings for +# potential errors in the documentation, such as documenting some parameters in +# a documented function twice, or documenting parameters that don't exist or +# using markup commands wrongly. +# The default value is: YES. + WARN_IF_DOC_ERROR = YES + +# If WARN_IF_INCOMPLETE_DOC is set to YES, doxygen will warn about incomplete +# function parameter documentation. If set to NO, doxygen will accept that some +# parameters have no documentation without warning. +# The default value is: YES. + +WARN_IF_INCOMPLETE_DOC = YES + +# This WARN_NO_PARAMDOC option can be enabled to get warnings for functions that +# are documented, but have no documentation for their parameters or return +# value. If set to NO, doxygen will only warn about wrong parameter +# documentation, but not about the absence of documentation. If EXTRACT_ALL is +# set to YES then this flag will automatically be disabled. See also +# WARN_IF_INCOMPLETE_DOC +# The default value is: NO. + WARN_NO_PARAMDOC = YES + +# If the WARN_AS_ERROR tag is set to YES then doxygen will immediately stop when +# a warning is encountered. If the WARN_AS_ERROR tag is set to FAIL_ON_WARNINGS +# then doxygen will continue running as if WARN_AS_ERROR tag is set to NO, but +# at the end of the doxygen process doxygen will return with a non-zero status. +# Possible values are: NO, YES and FAIL_ON_WARNINGS. +# The default value is: NO. + WARN_AS_ERROR = YES + +# The WARN_FORMAT tag determines the format of the warning messages that doxygen +# can produce. The string should contain the $file, $line, and $text tags, which +# will be replaced by the file and line number from which the warning originated +# and the warning text. Optionally the format may contain $version, which will +# be replaced by the version of the file (if it could be obtained via +# FILE_VERSION_FILTER) +# See also: WARN_LINE_FORMAT +# The default value is: $file:$line: $text. + WARN_FORMAT = "$file:$line: $text" -WARN_LOGFILE = + +# In the $text part of the WARN_FORMAT command it is possible that a reference +# to a more specific place is given. To make it easier to jump to this place +# (outside of doxygen) the user can define a custom "cut" / "paste" string. +# Example: +# WARN_LINE_FORMAT = "'vi $file +$line'" +# See also: WARN_FORMAT +# The default value is: at line $line of file $file. + +WARN_LINE_FORMAT = "at line $line of file $file" + +# The WARN_LOGFILE tag can be used to specify a file to which warning and error +# messages should be written. If left blank the output is written to standard +# error (stderr). In case the file specified cannot be opened for writing the +# warning and error messages are written to standard error. When as file - is +# specified the warning and error messages are written to standard output +# (stdout). + +WARN_LOGFILE = + #--------------------------------------------------------------------------- # Configuration options related to the input files #--------------------------------------------------------------------------- + +# The INPUT tag is used to specify the files and/or directories that contain +# documented source files. You may enter file names like myfile.cpp or +# directories like /usr/src/myproject. Separate the files or directories with +# spaces. See also FILE_PATTERNS and EXTENSION_MAPPING +# Note: If this tag is empty the current directory is searched. + INPUT = $(ORT_DOXY_SRC)\csharp\src\Microsoft.ML.OnnxRuntime \ - $(ORT_DOXY_SRC)\csharp\src\Microsoft.ML.OnnxRuntime\Tensors + $(ORT_DOXY_SRC)\csharp\src\Microsoft.ML.OnnxRuntime\Tensors + +# This tag can be used to specify the character encoding of the source files +# that doxygen parses. Internally doxygen uses the UTF-8 encoding. Doxygen uses +# libiconv (or the iconv built into libc) for the transcoding. See the libiconv +# documentation (see: +# https://www.gnu.org/software/libiconv/) for the list of possible encodings. +# The default value is: UTF-8. + INPUT_ENCODING = UTF-8 + +# If the value of the INPUT tag contains directories, you can use the +# FILE_PATTERNS tag to specify one or more wildcard patterns (like *.cpp and +# *.h) to filter out the source-files in the directories. +# +# Note that for custom extensions or not directly supported extensions you also +# need to set EXTENSION_MAPPING for the extension otherwise the files are not +# read by doxygen. +# +# Note the list of default checked file patterns might differ from the list of +# default file extension mappings. +# +# If left blank the following patterns are tested:*.c, *.cc, *.cxx, *.cpp, +# *.c++, *.java, *.ii, *.ixx, *.ipp, *.i++, *.inl, *.idl, *.ddl, *.odl, *.h, +# *.hh, *.hxx, *.hpp, *.h++, *.l, *.cs, *.d, *.php, *.php4, *.php5, *.phtml, +# *.inc, *.m, *.markdown, *.md, *.mm, *.dox (to be provided as doxygen C +# comment), *.py, *.pyw, *.f90, *.f95, *.f03, *.f08, *.f18, *.f, *.for, *.vhd, +# *.vhdl, *.ucf, *.qsf and *.ice. + FILE_PATTERNS = *.cs + +# The RECURSIVE tag can be used to specify whether or not subdirectories should +# be searched for input files as well. +# The default value is: NO. + RECURSIVE = NO + +# The EXCLUDE tag can be used to specify files and/or directories that should be +# excluded from the INPUT source files. This way you can easily exclude a +# subdirectory from a directory tree whose root is specified with the INPUT tag. +# +# Note that relative paths are relative to the directory from which doxygen is +# run. + EXCLUDE = + +# The EXCLUDE_SYMLINKS tag can be used to select whether or not files or +# directories that are symbolic links (a Unix file system feature) are excluded +# from the input. +# The default value is: NO. + EXCLUDE_SYMLINKS = NO + +# If the value of the INPUT tag contains directories, you can use the +# EXCLUDE_PATTERNS tag to specify one or more wildcard patterns to exclude +# certain files from those directories. +# +# Note that the wildcards are matched against the file with absolute path, so to +# exclude all test directories for example use the pattern */test/* + EXCLUDE_PATTERNS = Native*.cs + +# The EXCLUDE_SYMBOLS tag can be used to specify one or more symbol names +# (namespaces, classes, functions, etc.) that should be excluded from the +# output. The symbol name can be a fully qualified name, a word, or if the +# wildcard * is used, a substring. Examples: ANamespace, AClass, +# ANamespace::AClass, ANamespace::*Test +# +# Note that the wildcards are matched against the file with absolute path, so to +# exclude all test directories use the pattern */test/* + EXCLUDE_SYMBOLS = -EXAMPLE_PATH = + +# The EXAMPLE_PATH tag can be used to specify one or more files or directories +# that contain example code fragments that are included (see the \include +# command). + +EXAMPLE_PATH = + +# If the value of the EXAMPLE_PATH tag contains directories, you can use the +# EXAMPLE_PATTERNS tag to specify one or more wildcard pattern (like *.cpp and +# *.h) to filter out the source-files in the directories. If left blank all +# files are included. + EXAMPLE_PATTERNS = * + +# If the EXAMPLE_RECURSIVE tag is set to YES then subdirectories will be +# searched for input files to be used with the \include or \dontinclude commands +# irrespective of the value of the RECURSIVE tag. +# The default value is: NO. + EXAMPLE_RECURSIVE = NO + +# The IMAGE_PATH tag can be used to specify one or more files or directories +# that contain images that are to be included in the documentation (see the +# \image command). + IMAGE_PATH = + +# The INPUT_FILTER tag can be used to specify a program that doxygen should +# invoke to filter for each input file. Doxygen will invoke the filter program +# by executing (via popen()) the command: +# +# +# +# where is the value of the INPUT_FILTER tag, and is the +# name of an input file. Doxygen will then use the output that the filter +# program writes to standard output. If FILTER_PATTERNS is specified, this tag +# will be ignored. +# +# Note that the filter must not add or remove lines; it is applied before the +# code is scanned, but not when the output code is generated. If lines are added +# or removed, the anchors will not be placed correctly. +# +# Note that for custom extensions or not directly supported extensions you also +# need to set EXTENSION_MAPPING for the extension otherwise the files are not +# properly processed by doxygen. + INPUT_FILTER = + +# The FILTER_PATTERNS tag can be used to specify filters on a per file pattern +# basis. Doxygen will compare the file name with each pattern and apply the +# filter if there is a match. The filters are a list of the form: pattern=filter +# (like *.cpp=my_cpp_filter). See INPUT_FILTER for further information on how +# filters are used. If the FILTER_PATTERNS tag is empty or if none of the +# patterns match the file name, INPUT_FILTER is applied. +# +# Note that for custom extensions or not directly supported extensions you also +# need to set EXTENSION_MAPPING for the extension otherwise the files are not +# properly processed by doxygen. + FILTER_PATTERNS = + +# If the FILTER_SOURCE_FILES tag is set to YES, the input filter (if set using +# INPUT_FILTER) will also be used to filter the input files that are used for +# producing the source files to browse (i.e. when SOURCE_BROWSER is set to YES). +# The default value is: NO. + FILTER_SOURCE_FILES = NO + +# The FILTER_SOURCE_PATTERNS tag can be used to specify source filters per file +# pattern. A pattern will override the setting for FILTER_PATTERN (if any) and +# it is also possible to disable source filtering for a specific pattern using +# *.ext= (so without naming a filter). +# This tag requires that the tag FILTER_SOURCE_FILES is set to YES. + FILTER_SOURCE_PATTERNS = + +# If the USE_MDFILE_AS_MAINPAGE tag refers to the name of a markdown file that +# is part of the input, its contents will be placed on the main page +# (index.html). This can be useful if you have a project on for instance GitHub +# and want to reuse the introduction page also for the doxygen output. + USE_MDFILE_AS_MAINPAGE = + #--------------------------------------------------------------------------- # Configuration options related to source browsing #--------------------------------------------------------------------------- + +# If the SOURCE_BROWSER tag is set to YES then a list of source files will be +# generated. Documented entities will be cross-referenced with these sources. +# +# Note: To get rid of all source code in the generated output, make sure that +# also VERBATIM_HEADERS is set to NO. +# The default value is: NO. + SOURCE_BROWSER = NO + +# Setting the INLINE_SOURCES tag to YES will include the body of functions, +# classes and enums directly into the documentation. +# The default value is: NO. + INLINE_SOURCES = NO + +# Setting the STRIP_CODE_COMMENTS tag to YES will instruct doxygen to hide any +# special comment blocks from generated source code fragments. Normal C, C++ and +# Fortran comments will always remain visible. +# The default value is: YES. + STRIP_CODE_COMMENTS = YES + +# If the REFERENCED_BY_RELATION tag is set to YES then for each documented +# entity all documented functions referencing it will be listed. +# The default value is: NO. + REFERENCED_BY_RELATION = NO + +# If the REFERENCES_RELATION tag is set to YES then for each documented function +# all documented entities called/used by that function will be listed. +# The default value is: NO. + REFERENCES_RELATION = NO + +# If the REFERENCES_LINK_SOURCE tag is set to YES and SOURCE_BROWSER tag is set +# to YES then the hyperlinks from functions in REFERENCES_RELATION and +# REFERENCED_BY_RELATION lists will link to the source code. Otherwise they will +# link to the documentation. +# The default value is: YES. + REFERENCES_LINK_SOURCE = YES + +# If SOURCE_TOOLTIPS is enabled (the default) then hovering a hyperlink in the +# source code will show a tooltip with additional information such as prototype, +# brief description and links to the definition and documentation. Since this +# will make the HTML file larger and loading of large files a bit slower, you +# can opt to disable this feature. +# The default value is: YES. +# This tag requires that the tag SOURCE_BROWSER is set to YES. + SOURCE_TOOLTIPS = YES + +# If the USE_HTAGS tag is set to YES then the references to source code will +# point to the HTML generated by the htags(1) tool instead of doxygen built-in +# source browser. The htags tool is part of GNU's global source tagging system +# (see https://www.gnu.org/software/global/global.html). You will need version +# 4.8.6 or higher. +# +# To use it do the following: +# - Install the latest version of global +# - Enable SOURCE_BROWSER and USE_HTAGS in the configuration file +# - Make sure the INPUT points to the root of the source tree +# - Run doxygen as normal +# +# Doxygen will invoke htags (and that will in turn invoke gtags), so these +# tools must be available from the command line (i.e. in the search path). +# +# The result: instead of the source browser generated by doxygen, the links to +# source code will now point to the output of htags. +# The default value is: NO. +# This tag requires that the tag SOURCE_BROWSER is set to YES. + USE_HTAGS = NO + +# If the VERBATIM_HEADERS tag is set the YES then doxygen will generate a +# verbatim copy of the header file for each class for which an include is +# specified. Set to NO to disable this. +# See also: Section \class. +# The default value is: YES. + VERBATIM_HEADERS = YES + +# If the CLANG_ASSISTED_PARSING tag is set to YES then doxygen will use the +# clang parser (see: +# http://clang.llvm.org/) for more accurate parsing at the cost of reduced +# performance. This can be particularly helpful with template rich C++ code for +# which doxygen's built-in parser lacks the necessary type information. +# Note: The availability of this option depends on whether or not doxygen was +# generated with the -Duse_libclang=ON option for CMake. +# The default value is: NO. + CLANG_ASSISTED_PARSING = NO + +# If the CLANG_ASSISTED_PARSING tag is set to YES and the CLANG_ADD_INC_PATHS +# tag is set to YES then doxygen will add the directory of each input to the +# include path. +# The default value is: YES. +# This tag requires that the tag CLANG_ASSISTED_PARSING is set to YES. + +CLANG_ADD_INC_PATHS = YES + +# If clang assisted parsing is enabled you can provide the compiler with command +# line options that you would normally use when invoking the compiler. Note that +# the include paths will already be set by doxygen for the files and directories +# specified with INPUT and INCLUDE_PATH. +# This tag requires that the tag CLANG_ASSISTED_PARSING is set to YES. + CLANG_OPTIONS = + +# If clang assisted parsing is enabled you can provide the clang parser with the +# path to the directory containing a file called compile_commands.json. This +# file is the compilation database (see: +# http://clang.llvm.org/docs/HowToSetupToolingForLLVM.html) containing the +# options used when the source files were built. This is equivalent to +# specifying the -p option to a clang tool, such as clang-check. These options +# will then be passed to the parser. Any options specified with CLANG_OPTIONS +# will be added as well. +# Note: The availability of this option depends on whether or not doxygen was +# generated with the -Duse_libclang=ON option for CMake. + CLANG_DATABASE_PATH = + #--------------------------------------------------------------------------- # Configuration options related to the alphabetical class index #--------------------------------------------------------------------------- + +# If the ALPHABETICAL_INDEX tag is set to YES, an alphabetical index of all +# compounds will be generated. Enable this if the project contains a lot of +# classes, structs, unions or interfaces. +# The default value is: YES. + ALPHABETICAL_INDEX = YES + +# In case all classes in a project start with a common prefix, all classes will +# be put under the same header in the alphabetical index. The IGNORE_PREFIX tag +# can be used to specify a prefix (or a list of prefixes) that should be ignored +# while generating the index headers. +# This tag requires that the tag ALPHABETICAL_INDEX is set to YES. + IGNORE_PREFIX = + #--------------------------------------------------------------------------- # Configuration options related to the HTML output #--------------------------------------------------------------------------- + +# If the GENERATE_HTML tag is set to YES, doxygen will generate HTML output +# The default value is: YES. + GENERATE_HTML = YES + +# The HTML_OUTPUT tag is used to specify where the HTML docs will be put. If a +# relative path is entered the value of OUTPUT_DIRECTORY will be put in front of +# it. +# The default directory is: html. +# This tag requires that the tag GENERATE_HTML is set to YES. + HTML_OUTPUT = html + +# The HTML_FILE_EXTENSION tag can be used to specify the file extension for each +# generated HTML page (for example: .htm, .php, .asp). +# The default value is: .html. +# This tag requires that the tag GENERATE_HTML is set to YES. + HTML_FILE_EXTENSION = .html + +# The HTML_HEADER tag can be used to specify a user-defined HTML header file for +# each generated HTML page. If the tag is left blank doxygen will generate a +# standard header. +# +# To get valid HTML the header file that includes any scripts and style sheets +# that doxygen needs, which is dependent on the configuration options used (e.g. +# the setting GENERATE_TREEVIEW). It is highly recommended to start with a +# default header using +# doxygen -w html new_header.html new_footer.html new_stylesheet.css +# YourConfigFile +# and then modify the file new_header.html. See also section "Doxygen usage" +# for information on how to generate the default header that doxygen normally +# uses. +# Note: The header is subject to change so you typically have to regenerate the +# default header when upgrading to a newer version of doxygen. For a description +# of the possible markers and block names see the documentation. +# This tag requires that the tag GENERATE_HTML is set to YES. + HTML_HEADER = + +# The HTML_FOOTER tag can be used to specify a user-defined HTML footer for each +# generated HTML page. If the tag is left blank doxygen will generate a standard +# footer. See HTML_HEADER for more information on how to generate a default +# footer and what special commands can be used inside the footer. See also +# section "Doxygen usage" for information on how to generate the default footer +# that doxygen normally uses. +# This tag requires that the tag GENERATE_HTML is set to YES. + HTML_FOOTER = + +# The HTML_STYLESHEET tag can be used to specify a user-defined cascading style +# sheet that is used by each HTML page. It can be used to fine-tune the look of +# the HTML output. If left blank doxygen will generate a default style sheet. +# See also section "Doxygen usage" for information on how to generate the style +# sheet that doxygen normally uses. +# Note: It is recommended to use HTML_EXTRA_STYLESHEET instead of this tag, as +# it is more robust and this tag (HTML_STYLESHEET) will in the future become +# obsolete. +# This tag requires that the tag GENERATE_HTML is set to YES. + HTML_STYLESHEET = + +# The HTML_EXTRA_STYLESHEET tag can be used to specify additional user-defined +# cascading style sheets that are included after the standard style sheets +# created by doxygen. Using this option one can overrule certain style aspects. +# This is preferred over using HTML_STYLESHEET since it does not replace the +# standard style sheet and is therefore more robust against future updates. +# Doxygen will copy the style sheet files to the output directory. +# Note: The order of the extra style sheet files is of importance (e.g. the last +# style sheet in the list overrules the setting of the previous ones in the +# list). For an example see the documentation. +# This tag requires that the tag GENERATE_HTML is set to YES. + HTML_EXTRA_STYLESHEET = + +# The HTML_EXTRA_FILES tag can be used to specify one or more extra images or +# other source files which should be copied to the HTML output directory. Note +# that these files will be copied to the base HTML output directory. Use the +# $relpath^ marker in the HTML_HEADER and/or HTML_FOOTER files to load these +# files. In the HTML_STYLESHEET file, use the file name only. Also note that the +# files will be copied as-is; there are no commands or markers available. +# This tag requires that the tag GENERATE_HTML is set to YES. + HTML_EXTRA_FILES = + +# The HTML_COLORSTYLE_HUE tag controls the color of the HTML output. Doxygen +# will adjust the colors in the style sheet and background images according to +# this color. Hue is specified as an angle on a color-wheel, see +# https://en.wikipedia.org/wiki/Hue for more information. For instance the value +# 0 represents red, 60 is yellow, 120 is green, 180 is cyan, 240 is blue, 300 +# purple, and 360 is red again. +# Minimum value: 0, maximum value: 359, default value: 220. +# This tag requires that the tag GENERATE_HTML is set to YES. + HTML_COLORSTYLE_HUE = 220 + +# The HTML_COLORSTYLE_SAT tag controls the purity (or saturation) of the colors +# in the HTML output. For a value of 0 the output will use gray-scales only. A +# value of 255 will produce the most vivid colors. +# Minimum value: 0, maximum value: 255, default value: 100. +# This tag requires that the tag GENERATE_HTML is set to YES. + HTML_COLORSTYLE_SAT = 100 + +# The HTML_COLORSTYLE_GAMMA tag controls the gamma correction applied to the +# luminance component of the colors in the HTML output. Values below 100 +# gradually make the output lighter, whereas values above 100 make the output +# darker. The value divided by 100 is the actual gamma applied, so 80 represents +# a gamma of 0.8, The value 220 represents a gamma of 2.2, and 100 does not +# change the gamma. +# Minimum value: 40, maximum value: 240, default value: 80. +# This tag requires that the tag GENERATE_HTML is set to YES. + HTML_COLORSTYLE_GAMMA = 80 + +# If the HTML_TIMESTAMP tag is set to YES then the footer of each generated HTML +# page will contain the date and time when the page was generated. Setting this +# to YES can help to show when doxygen was last run and thus if the +# documentation is up to date. +# The default value is: NO. +# This tag requires that the tag GENERATE_HTML is set to YES. + HTML_TIMESTAMP = NO + +# If the HTML_DYNAMIC_MENUS tag is set to YES then the generated HTML +# documentation will contain a main index with vertical navigation menus that +# are dynamically created via JavaScript. If disabled, the navigation index will +# consists of multiple levels of tabs that are statically embedded in every HTML +# page. Disable this option to support browsers that do not have JavaScript, +# like the Qt help browser. +# The default value is: YES. +# This tag requires that the tag GENERATE_HTML is set to YES. + HTML_DYNAMIC_MENUS = YES + +# If the HTML_DYNAMIC_SECTIONS tag is set to YES then the generated HTML +# documentation will contain sections that can be hidden and shown after the +# page has loaded. +# The default value is: NO. +# This tag requires that the tag GENERATE_HTML is set to YES. + HTML_DYNAMIC_SECTIONS = NO + +# With HTML_INDEX_NUM_ENTRIES one can control the preferred number of entries +# shown in the various tree structured indices initially; the user can expand +# and collapse entries dynamically later on. Doxygen will expand the tree to +# such a level that at most the specified number of entries are visible (unless +# a fully collapsed tree already exceeds this amount). So setting the number of +# entries 1 will produce a full collapsed tree by default. 0 is a special value +# representing an infinite number of entries and will result in a full expanded +# tree by default. +# Minimum value: 0, maximum value: 9999, default value: 100. +# This tag requires that the tag GENERATE_HTML is set to YES. + HTML_INDEX_NUM_ENTRIES = 100 + +# If the GENERATE_DOCSET tag is set to YES, additional index files will be +# generated that can be used as input for Apple's Xcode 3 integrated development +# environment (see: +# https://developer.apple.com/xcode/), introduced with OSX 10.5 (Leopard). To +# create a documentation set, doxygen will generate a Makefile in the HTML +# output directory. Running make will produce the docset in that directory and +# running make install will install the docset in +# ~/Library/Developer/Shared/Documentation/DocSets so that Xcode will find it at +# startup. See https://developer.apple.com/library/archive/featuredarticles/Doxy +# genXcode/_index.html for more information. +# The default value is: NO. +# This tag requires that the tag GENERATE_HTML is set to YES. + GENERATE_DOCSET = NO + +# This tag determines the name of the docset feed. A documentation feed provides +# an umbrella under which multiple documentation sets from a single provider +# (such as a company or product suite) can be grouped. +# The default value is: Doxygen generated docs. +# This tag requires that the tag GENERATE_DOCSET is set to YES. + DOCSET_FEEDNAME = "Doxygen generated docs" + +# This tag determines the URL of the docset feed. A documentation feed provides +# an umbrella under which multiple documentation sets from a single provider +# (such as a company or product suite) can be grouped. +# This tag requires that the tag GENERATE_DOCSET is set to YES. + +DOCSET_FEEDURL = + +# This tag specifies a string that should uniquely identify the documentation +# set bundle. This should be a reverse domain-name style string, e.g. +# com.mycompany.MyDocSet. Doxygen will append .docset to the name. +# The default value is: org.doxygen.Project. +# This tag requires that the tag GENERATE_DOCSET is set to YES. + DOCSET_BUNDLE_ID = org.doxygen.Project + +# The DOCSET_PUBLISHER_ID tag specifies a string that should uniquely identify +# the documentation publisher. This should be a reverse domain-name style +# string, e.g. com.mycompany.MyDocSet.documentation. +# The default value is: org.doxygen.Publisher. +# This tag requires that the tag GENERATE_DOCSET is set to YES. + DOCSET_PUBLISHER_ID = org.doxygen.Publisher + +# The DOCSET_PUBLISHER_NAME tag identifies the documentation publisher. +# The default value is: Publisher. +# This tag requires that the tag GENERATE_DOCSET is set to YES. + DOCSET_PUBLISHER_NAME = Publisher + +# If the GENERATE_HTMLHELP tag is set to YES then doxygen generates three +# additional HTML index files: index.hhp, index.hhc, and index.hhk. The +# index.hhp is a project file that can be read by Microsoft's HTML Help Workshop +# on Windows. In the beginning of 2021 Microsoft took the original page, with +# a.o. the download links, offline the HTML help workshop was already many years +# in maintenance mode). You can download the HTML help workshop from the web +# archives at Installation executable (see: +# http://web.archive.org/web/20160201063255/http://download.microsoft.com/downlo +# ad/0/A/9/0A939EF6-E31C-430F-A3DF-DFAE7960D564/htmlhelp.exe). +# +# The HTML Help Workshop contains a compiler that can convert all HTML output +# generated by doxygen into a single compiled HTML file (.chm). Compiled HTML +# files are now used as the Windows 98 help format, and will replace the old +# Windows help format (.hlp) on all Windows platforms in the future. Compressed +# HTML files also contain an index, a table of contents, and you can search for +# words in the documentation. The HTML workshop also contains a viewer for +# compressed HTML files. +# The default value is: NO. +# This tag requires that the tag GENERATE_HTML is set to YES. + GENERATE_HTMLHELP = NO + +# The CHM_FILE tag can be used to specify the file name of the resulting .chm +# file. You can add a path in front of the file if the result should not be +# written to the html output directory. +# This tag requires that the tag GENERATE_HTMLHELP is set to YES. + CHM_FILE = + +# The HHC_LOCATION tag can be used to specify the location (absolute path +# including file name) of the HTML help compiler (hhc.exe). If non-empty, +# doxygen will try to run the HTML help compiler on the generated index.hhp. +# The file has to be specified with full path. +# This tag requires that the tag GENERATE_HTMLHELP is set to YES. + HHC_LOCATION = + +# The GENERATE_CHI flag controls if a separate .chi index file is generated +# (YES) or that it should be included in the main .chm file (NO). +# The default value is: NO. +# This tag requires that the tag GENERATE_HTMLHELP is set to YES. + GENERATE_CHI = NO + +# The CHM_INDEX_ENCODING is used to encode HtmlHelp index (hhk), content (hhc) +# and project file content. +# This tag requires that the tag GENERATE_HTMLHELP is set to YES. + CHM_INDEX_ENCODING = + +# The BINARY_TOC flag controls whether a binary table of contents is generated +# (YES) or a normal table of contents (NO) in the .chm file. Furthermore it +# enables the Previous and Next buttons. +# The default value is: NO. +# This tag requires that the tag GENERATE_HTMLHELP is set to YES. + BINARY_TOC = NO + +# The TOC_EXPAND flag can be set to YES to add extra items for group members to +# the table of contents of the HTML help documentation and to the tree view. +# The default value is: NO. +# This tag requires that the tag GENERATE_HTMLHELP is set to YES. + TOC_EXPAND = NO + +# If the GENERATE_QHP tag is set to YES and both QHP_NAMESPACE and +# QHP_VIRTUAL_FOLDER are set, an additional index file will be generated that +# can be used as input for Qt's qhelpgenerator to generate a Qt Compressed Help +# (.qch) of the generated HTML documentation. +# The default value is: NO. +# This tag requires that the tag GENERATE_HTML is set to YES. + GENERATE_QHP = NO + +# If the QHG_LOCATION tag is specified, the QCH_FILE tag can be used to specify +# the file name of the resulting .qch file. The path specified is relative to +# the HTML output folder. +# This tag requires that the tag GENERATE_QHP is set to YES. + QCH_FILE = + +# The QHP_NAMESPACE tag specifies the namespace to use when generating Qt Help +# Project output. For more information please see Qt Help Project / Namespace +# (see: +# https://doc.qt.io/archives/qt-4.8/qthelpproject.html#namespace). +# The default value is: org.doxygen.Project. +# This tag requires that the tag GENERATE_QHP is set to YES. + QHP_NAMESPACE = org.doxygen.Project + +# The QHP_VIRTUAL_FOLDER tag specifies the namespace to use when generating Qt +# Help Project output. For more information please see Qt Help Project / Virtual +# Folders (see: +# https://doc.qt.io/archives/qt-4.8/qthelpproject.html#virtual-folders). +# The default value is: doc. +# This tag requires that the tag GENERATE_QHP is set to YES. + QHP_VIRTUAL_FOLDER = doc + +# If the QHP_CUST_FILTER_NAME tag is set, it specifies the name of a custom +# filter to add. For more information please see Qt Help Project / Custom +# Filters (see: +# https://doc.qt.io/archives/qt-4.8/qthelpproject.html#custom-filters). +# This tag requires that the tag GENERATE_QHP is set to YES. + QHP_CUST_FILTER_NAME = + +# The QHP_CUST_FILTER_ATTRS tag specifies the list of the attributes of the +# custom filter to add. For more information please see Qt Help Project / Custom +# Filters (see: +# https://doc.qt.io/archives/qt-4.8/qthelpproject.html#custom-filters). +# This tag requires that the tag GENERATE_QHP is set to YES. + QHP_CUST_FILTER_ATTRS = + +# The QHP_SECT_FILTER_ATTRS tag specifies the list of the attributes this +# project's filter section matches. Qt Help Project / Filter Attributes (see: +# https://doc.qt.io/archives/qt-4.8/qthelpproject.html#filter-attributes). +# This tag requires that the tag GENERATE_QHP is set to YES. + QHP_SECT_FILTER_ATTRS = + +# The QHG_LOCATION tag can be used to specify the location (absolute path +# including file name) of Qt's qhelpgenerator. If non-empty doxygen will try to +# run qhelpgenerator on the generated .qhp file. +# This tag requires that the tag GENERATE_QHP is set to YES. + QHG_LOCATION = + +# If the GENERATE_ECLIPSEHELP tag is set to YES, additional index files will be +# generated, together with the HTML files, they form an Eclipse help plugin. To +# install this plugin and make it available under the help contents menu in +# Eclipse, the contents of the directory containing the HTML and XML files needs +# to be copied into the plugins directory of eclipse. The name of the directory +# within the plugins directory should be the same as the ECLIPSE_DOC_ID value. +# After copying Eclipse needs to be restarted before the help appears. +# The default value is: NO. +# This tag requires that the tag GENERATE_HTML is set to YES. + GENERATE_ECLIPSEHELP = NO + +# A unique identifier for the Eclipse help plugin. When installing the plugin +# the directory name containing the HTML and XML files should also have this +# name. Each documentation set should have its own identifier. +# The default value is: org.doxygen.Project. +# This tag requires that the tag GENERATE_ECLIPSEHELP is set to YES. + ECLIPSE_DOC_ID = org.doxygen.Project + +# If you want full control over the layout of the generated HTML pages it might +# be necessary to disable the index and replace it with your own. The +# DISABLE_INDEX tag can be used to turn on/off the condensed index (tabs) at top +# of each HTML page. A value of NO enables the index and the value YES disables +# it. Since the tabs in the index contain the same information as the navigation +# tree, you can set this option to YES if you also set GENERATE_TREEVIEW to YES. +# The default value is: NO. +# This tag requires that the tag GENERATE_HTML is set to YES. + DISABLE_INDEX = NO + +# The GENERATE_TREEVIEW tag is used to specify whether a tree-like index +# structure should be generated to display hierarchical information. If the tag +# value is set to YES, a side panel will be generated containing a tree-like +# index structure (just like the one that is generated for HTML Help). For this +# to work a browser that supports JavaScript, DHTML, CSS and frames is required +# (i.e. any modern browser). Windows users are probably better off using the +# HTML help feature. Via custom style sheets (see HTML_EXTRA_STYLESHEET) one can +# further fine tune the look of the index (see "Fine-tuning the output"). As an +# example, the default style sheet generated by doxygen has an example that +# shows how to put an image at the root of the tree instead of the PROJECT_NAME. +# Since the tree basically has the same information as the tab index, you could +# consider setting DISABLE_INDEX to YES when enabling this option. +# The default value is: NO. +# This tag requires that the tag GENERATE_HTML is set to YES. + GENERATE_TREEVIEW = NO + +# When both GENERATE_TREEVIEW and DISABLE_INDEX are set to YES, then the +# FULL_SIDEBAR option determines if the side bar is limited to only the treeview +# area (value NO) or if it should extend to the full height of the window (value +# YES). Setting this to YES gives a layout similar to +# https://docs.readthedocs.io with more room for contents, but less room for the +# project logo, title, and description. If either GENERATE_TREEVIEW or +# DISABLE_INDEX is set to NO, this option has no effect. +# The default value is: NO. +# This tag requires that the tag GENERATE_HTML is set to YES. + +FULL_SIDEBAR = NO + +# The ENUM_VALUES_PER_LINE tag can be used to set the number of enum values that +# doxygen will group on one line in the generated HTML documentation. +# +# Note that a value of 0 will completely suppress the enum values from appearing +# in the overview section. +# Minimum value: 0, maximum value: 20, default value: 4. +# This tag requires that the tag GENERATE_HTML is set to YES. + ENUM_VALUES_PER_LINE = 4 + +# If the treeview is enabled (see GENERATE_TREEVIEW) then this tag can be used +# to set the initial width (in pixels) of the frame in which the tree is shown. +# Minimum value: 0, maximum value: 1500, default value: 250. +# This tag requires that the tag GENERATE_HTML is set to YES. + TREEVIEW_WIDTH = 250 + +# If the EXT_LINKS_IN_WINDOW option is set to YES, doxygen will open links to +# external symbols imported via tag files in a separate window. +# The default value is: NO. +# This tag requires that the tag GENERATE_HTML is set to YES. + EXT_LINKS_IN_WINDOW = NO + +# If the OBFUSCATE_EMAILS tag is set to YES, doxygen will obfuscate email +# addresses. +# The default value is: YES. +# This tag requires that the tag GENERATE_HTML is set to YES. + +OBFUSCATE_EMAILS = YES + +# If the HTML_FORMULA_FORMAT option is set to svg, doxygen will use the pdf2svg +# tool (see https://github.com/dawbarton/pdf2svg) or inkscape (see +# https://inkscape.org) to generate formulas as SVG images instead of PNGs for +# the HTML output. These images will generally look nicer at scaled resolutions. +# Possible values are: png (the default) and svg (looks nicer but requires the +# pdf2svg or inkscape tool). +# The default value is: png. +# This tag requires that the tag GENERATE_HTML is set to YES. + HTML_FORMULA_FORMAT = png + +# Use this tag to change the font size of LaTeX formulas included as images in +# the HTML documentation. When you change the font size after a successful +# doxygen run you need to manually remove any form_*.png images from the HTML +# output directory to force them to be regenerated. +# Minimum value: 8, maximum value: 50, default value: 10. +# This tag requires that the tag GENERATE_HTML is set to YES. + FORMULA_FONTSIZE = 10 + +# Use the FORMULA_TRANSPARENT tag to determine whether or not the images +# generated for formulas are transparent PNGs. Transparent PNGs are not +# supported properly for IE 6.0, but are supported on all modern browsers. +# +# Note that when changing this option you need to delete any form_*.png files in +# the HTML output directory before the changes have effect. +# The default value is: YES. +# This tag requires that the tag GENERATE_HTML is set to YES. + FORMULA_TRANSPARENT = YES + +# The FORMULA_MACROFILE can contain LaTeX \newcommand and \renewcommand commands +# to create new LaTeX commands to be used in formulas as building blocks. See +# the section "Including formulas" for details. + FORMULA_MACROFILE = + +# Enable the USE_MATHJAX option to render LaTeX formulas using MathJax (see +# https://www.mathjax.org) which uses client side JavaScript for the rendering +# instead of using pre-rendered bitmaps. Use this if you do not have LaTeX +# installed or if you want to formulas look prettier in the HTML output. When +# enabled you may also need to install MathJax separately and configure the path +# to it using the MATHJAX_RELPATH option. +# The default value is: NO. +# This tag requires that the tag GENERATE_HTML is set to YES. + USE_MATHJAX = NO + +# With MATHJAX_VERSION it is possible to specify the MathJax version to be used. +# Note that the different versions of MathJax have different requirements with +# regards to the different settings, so it is possible that also other MathJax +# settings have to be changed when switching between the different MathJax +# versions. +# Possible values are: MathJax_2 and MathJax_3. +# The default value is: MathJax_2. +# This tag requires that the tag USE_MATHJAX is set to YES. + +MATHJAX_VERSION = MathJax_2 + +# When MathJax is enabled you can set the default output format to be used for +# the MathJax output. For more details about the output format see MathJax +# version 2 (see: +# http://docs.mathjax.org/en/v2.7-latest/output.html) and MathJax version 3 +# (see: +# http://docs.mathjax.org/en/latest/web/components/output.html). +# Possible values are: HTML-CSS (which is slower, but has the best +# compatibility. This is the name for Mathjax version 2, for MathJax version 3 +# this will be translated into chtml), NativeMML (i.e. MathML. Only supported +# for NathJax 2. For MathJax version 3 chtml will be used instead.), chtml (This +# is the name for Mathjax version 3, for MathJax version 2 this will be +# translated into HTML-CSS) and SVG. +# The default value is: HTML-CSS. +# This tag requires that the tag USE_MATHJAX is set to YES. + MATHJAX_FORMAT = HTML-CSS + +# When MathJax is enabled you need to specify the location relative to the HTML +# output directory using the MATHJAX_RELPATH option. The destination directory +# should contain the MathJax.js script. For instance, if the mathjax directory +# is located at the same level as the HTML output directory, then +# MATHJAX_RELPATH should be ../mathjax. The default value points to the MathJax +# Content Delivery Network so you can quickly see the result without installing +# MathJax. However, it is strongly recommended to install a local copy of +# MathJax from https://www.mathjax.org before deployment. The default value is: +# - in case of MathJax version 2: https://cdn.jsdelivr.net/npm/mathjax@2 +# - in case of MathJax version 3: https://cdn.jsdelivr.net/npm/mathjax@3 +# This tag requires that the tag USE_MATHJAX is set to YES. + MATHJAX_RELPATH = https://cdn.jsdelivr.net/npm/mathjax@2 + +# The MATHJAX_EXTENSIONS tag can be used to specify one or more MathJax +# extension names that should be enabled during MathJax rendering. For example +# for MathJax version 2 (see +# https://docs.mathjax.org/en/v2.7-latest/tex.html#tex-and-latex-extensions): +# MATHJAX_EXTENSIONS = TeX/AMSmath TeX/AMSsymbols +# For example for MathJax version 3 (see +# http://docs.mathjax.org/en/latest/input/tex/extensions/index.html): +# MATHJAX_EXTENSIONS = ams +# This tag requires that the tag USE_MATHJAX is set to YES. + MATHJAX_EXTENSIONS = + +# The MATHJAX_CODEFILE tag can be used to specify a file with javascript pieces +# of code that will be used on startup of the MathJax code. See the MathJax site +# (see: +# http://docs.mathjax.org/en/v2.7-latest/output.html) for more details. For an +# example see the documentation. +# This tag requires that the tag USE_MATHJAX is set to YES. + MATHJAX_CODEFILE = + +# When the SEARCHENGINE tag is enabled doxygen will generate a search box for +# the HTML output. The underlying search engine uses javascript and DHTML and +# should work on any modern browser. Note that when using HTML help +# (GENERATE_HTMLHELP), Qt help (GENERATE_QHP), or docsets (GENERATE_DOCSET) +# there is already a search function so this one should typically be disabled. +# For large projects the javascript based search engine can be slow, then +# enabling SERVER_BASED_SEARCH may provide a better solution. It is possible to +# search using the keyboard; to jump to the search box use + S +# (what the is depends on the OS and browser, but it is typically +# , /