Skip to content

Commit

Permalink
Add negative tests
Browse files Browse the repository at this point in the history
  • Loading branch information
yuslepukhin committed Jan 31, 2022
1 parent 14c5c05 commit af53a53
Showing 1 changed file with 53 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@
// Licensed under the MIT License.

using Microsoft.ML.OnnxRuntime.Tensors;
using Microsoft.Win32.SafeHandles;
using System;
using System.Linq;
using System.Runtime.InteropServices;
using Xunit;
using static Microsoft.ML.OnnxRuntime.Tests.InferenceTest;

Expand All @@ -23,15 +25,36 @@ private static void PopulateNativeBufferFloat(OrtMemoryAllocation buffer, float[
Assert.True(false);
}

PopulateNativeBuffer(buffer.Pointer, elements);
}

private static void PopulateNativeBuffer(IntPtr buffer, float[] elements)
{
unsafe
{
float* p = (float*)buffer.Pointer;
float* p = (float*)buffer;
for (int i = 0; i < elements.Length; ++i)
{
*p++ = elements[i];
}
}
}
/// <summary>
/// Use to free globally allocated memory
/// </summary>
class OrtSafeMemoryHandle : SafeHandle
{
public OrtSafeMemoryHandle(IntPtr allocPtr) : base(allocPtr, true) { }

public override bool IsInvalid => handle == IntPtr.Zero;

protected override bool ReleaseHandle()
{
Marshal.FreeHGlobal(handle);
handle = IntPtr.Zero;
return true;
}
}

[Fact(DisplayName = "TestIOBindingWithOrtAllocation")]
public void TestIOBindingWithOrtAllocation()
Expand Down Expand Up @@ -65,10 +88,12 @@ public void TestIOBindingWithOrtAllocation()
Assert.Equal(shapeSize, inputData.Length);
PopulateNativeBufferFloat(ortAllocationInput, inputData);

// Re-use ORT allocated CPU buffer to present this as external allocation
// Create an external allocation for testing OrtExternalAllocation
var cpuMemInfo = OrtMemoryInfo.DefaultInstance;
var sizeInBytes = shapeSize * sizeof(float);
var externalInputAllocation = new OrtExternalAllocation(ortAllocationInput.Info, inputShape,
Tensors.TensorElementType.Float, ortAllocationInput.Pointer, sizeInBytes);
IntPtr allocPtr = Marshal.AllocHGlobal((int)sizeInBytes);
dispList.Add(new OrtSafeMemoryHandle(allocPtr));
PopulateNativeBuffer(allocPtr, inputData);

var ortAllocationOutput = allocator.Allocate((uint)outputData.Length * sizeof(float));
dispList.Add(ortAllocationOutput);
Expand Down Expand Up @@ -109,8 +134,11 @@ public void TestIOBindingWithOrtAllocation()
Assert.Equal(outputData, tensor.ToArray<float>(), new FloatComparer());
}
}
// 3. Pretend we are using external allocation which is currently on CPU
// 3. Test external allocation
{
var externalInputAllocation = new OrtExternalAllocation(cpuMemInfo, inputShape,
Tensors.TensorElementType.Float, allocPtr, sizeInBytes);

ioBinding.BindInput(inputName, externalInputAllocation);
ioBinding.BindOutput(outputName, Tensors.TensorElementType.Float, outputShape, ortAllocationOutput);
ioBinding.SynchronizeBoundInputs();
Expand All @@ -125,6 +153,26 @@ public void TestIOBindingWithOrtAllocation()
Assert.Equal(outputData, tensor.ToArray<float>(), new FloatComparer());
}
}
// 4. Some negative tests for external allocation
{
// Small buffer size
Action smallBuffer = delegate ()
{
new OrtExternalAllocation(cpuMemInfo, inputShape,
Tensors.TensorElementType.Float, allocPtr, sizeInBytes - 10);
};

Assert.Throws<OnnxRuntimeException>(smallBuffer);

Action stringType = delegate ()
{
new OrtExternalAllocation(cpuMemInfo, inputShape,
Tensors.TensorElementType.String, allocPtr, sizeInBytes);
};

Assert.Throws<OnnxRuntimeException>(stringType);

}

}
}
Expand Down

0 comments on commit af53a53

Please sign in to comment.