This repository was archived by the owner on Nov 27, 2024. It is now read-only.

Initial Float16 and BFloat16 ONNX type support #31

Merged · 1 commit · Nov 16, 2023
148 changes: 145 additions & 3 deletions OnnxStack.Core/Extensions.cs
@@ -2,10 +2,12 @@
using Microsoft.ML.OnnxRuntime.Tensors;
using OnnxStack.Core.Config;
using System;
using System.Buffers;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Linq;
using System.Numerics;
using System.Runtime.InteropServices;

namespace OnnxStack.Core
{
@@ -205,26 +207,166 @@ public static T GetBufferLength<T>(this ReadOnlySpan<T> array) where T : INumber
}


/// <summary>
/// Converts to long.
/// </summary>
/// <param name="array">The array.</param>
/// <returns></returns>
public static long[] ToLong(this ReadOnlySpan<int> array)
{
return Array.ConvertAll(array.ToArray(), Convert.ToInt64);
}



/// <summary>
/// Converts an array of long to an array of int.
/// </summary>
/// <param name="array">The array.</param>
/// <returns></returns>
public static int[] ToInt(this long[] array)
{
return Array.ConvertAll(array, Convert.ToInt32);
}


/// <summary>
/// Converts to long.
/// </summary>
/// <param name="array">The array.</param>
/// <returns></returns>
public static long[] ToLong(this int[] array)
{
return Array.ConvertAll(array, Convert.ToInt64);
}


public static OrtValue ToOrtValue<T>(this DenseTensor<T> tensor) where T : unmanaged
/// <summary>
/// Creates an OrtValue from the DenseTensor and NodeMetadata provided.
/// </summary>
/// <param name="tensor">The tensor.</param>
/// <param name="nodeMetadata">The node metadata.</param>
/// <returns></returns>
public static OrtValue ToOrtValue(this DenseTensor<float> tensor, NodeMetadata nodeMetadata)
{
return OrtValue.CreateTensorValueFromMemory(OrtMemoryInfo.DefaultInstance, tensor.Buffer, tensor.Dimensions.ToLong());
var dimensions = tensor.Dimensions.ToLong();
return nodeMetadata.ElementDataType switch
{
TensorElementType.Float16 => OrtValue.CreateTensorValueFromMemory(OrtMemoryInfo.DefaultInstance, tensor.Buffer.ToFloat16(), dimensions),
TensorElementType.BFloat16 => OrtValue.CreateTensorValueFromMemory(OrtMemoryInfo.DefaultInstance, tensor.Buffer.ToBFloat16(), dimensions),
_ => OrtValue.CreateTensorValueFromMemory(OrtMemoryInfo.DefaultInstance, tensor.Buffer, dimensions)
};
}
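
A usage sketch (illustrative only, not part of this diff; the model path, input name and shape are hypothetical): the element type now comes from the node metadata, so callers no longer need to know whether the model was exported as Float32, Float16 or BFloat16.

// Hypothetical caller; assumes Microsoft.ML.OnnxRuntime and OnnxStack.Core are in scope
using var session = new InferenceSession("model-fp16.onnx");
var inputMetadata = session.InputMetadata["sample"]; // NodeMetadata for the "sample" input
var tensor = new DenseTensor<float>(new[] { 1, 4, 64, 64 }); // float data on the managed side
using var inputValue = tensor.ToOrtValue(inputMetadata); // narrowed to Float16 here if the model requires it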


/// <summary>
/// Creates and allocates an output tensor buffer.
/// </summary>
/// <param name="nodeMetadata">The node metadata.</param>
/// <param name="dimensions">The dimensions.</param>
/// <returns></returns>
public static OrtValue CreateOutputBuffer(this NodeMetadata nodeMetadata, ReadOnlySpan<int> dimensions)
{
return OrtValue.CreateAllocatedTensorValue(OrtAllocator.DefaultInstance, nodeMetadata.ElementDataType, dimensions.ToLong());
}
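
Paired with ToOrtValue above, the output buffer can be pre-allocated in the model's native element type rather than a managed DenseTensor<float> (continuing the hypothetical sketch; the output name is an assumption):

var outputMetadata = session.OutputMetadata["out_sample"]; // hypothetical output name
using var outputValue = outputMetadata.CreateOutputBuffer(new[] { 1, 4, 64, 64 });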


/// <summary>
/// Converts to DenseTensor&lt;float&gt;.
/// </summary>
/// <param name="ortValue">The ort value.</param>
/// <returns></returns>
public static DenseTensor<float> ToDenseTensor(this OrtValue ortValue)
{
var typeInfo = ortValue.GetTensorTypeAndShape();
var dimensions = typeInfo.Shape.ToInt();
return typeInfo.ElementDataType switch
{
TensorElementType.Float16 => new DenseTensor<float>(ortValue.GetTensorDataAsSpan<Float16>().ToFloat(), dimensions),
TensorElementType.BFloat16 => new DenseTensor<float>(ortValue.GetTensorDataAsSpan<BFloat16>().ToFloat(), dimensions),
_ => new DenseTensor<float>(ortValue.GetTensorDataAsSpan<float>().ToArray(), dimensions)
};
}
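
After inference, the pre-allocated buffer reads back as float regardless of the model's precision; the switch above widens Float16/BFloat16 element by element. Sketch completing the example (the OrtValue-based Run overload is assumed from onnxruntime 1.16):

using var runOptions = new RunOptions();
session.Run(runOptions, new[] { "sample" }, new[] { inputValue }, new[] { "out_sample" }, new[] { outputValue });
DenseTensor<float> result = outputValue.ToDenseTensor(); // always DenseTensor<float> on the managed side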


/// <summary>
/// Converts the OrtValue to a float array.
/// </summary>
/// <param name="ortValue">The ort value.</param>
/// <returns></returns>
public static float[] ToArray(this OrtValue ortValue)
{
var typeInfo = ortValue.GetTensorTypeAndShape();
var dimensions = typeInfo.Shape.ToInt();
return typeInfo.ElementDataType switch
{
TensorElementType.Float16 => ortValue.GetTensorDataAsSpan<Float16>().ToFloat().ToArray(),
TensorElementType.BFloat16 => ortValue.GetTensorDataAsSpan<BFloat16>().ToFloat().ToArray(),
_ => ortValue.GetTensorDataAsSpan<float>().ToArray()
};
}


/// <summary>
/// Converts to float16.
/// </summary>
/// <param name="inputMemory">The input memory.</param>
/// <returns></returns>
internal static Memory<Float16> ToFloat16(this Memory<float> inputMemory)
{
var elementCount = inputMemory.Length;
var float16Array = new Float16[elementCount];
var inputSpan = inputMemory.Span;
for (int i = 0; i < elementCount; i++)
float16Array[i] = (Float16)inputSpan[i];

return float16Array.AsMemory();
}


/// <summary>
/// Converts to BFloat16.
/// </summary>
/// <param name="inputMemory">The input memory.</param>
/// <returns></returns>
internal static Memory<BFloat16> ToBFloat16(this Memory<float> inputMemory)
{
var elementCount = inputMemory.Length;
var bfloat16Array = new BFloat16[elementCount];
var inputSpan = inputMemory.Span;
for (int i = 0; i < elementCount; i++)
bfloat16Array[i] = (BFloat16)inputSpan[i];

return bfloat16Array.AsMemory();
}


/// <summary>
/// Converts to float.
/// </summary>
/// <param name="inputMemory">The input memory.</param>
/// <returns></returns>
internal static Memory<float> ToFloat(this ReadOnlySpan<Float16> inputMemory)
{
var elementCount = inputMemory.Length;
var floatArray = new float[elementCount];
for (int i = 0; i < elementCount; i++)
floatArray[i] = (float)inputMemory[i];

return floatArray.AsMemory();
}


/// <summary>
/// Converts to float.
/// </summary>
/// <param name="inputMemory">The input memory.</param>
/// <returns></returns>
internal static Memory<float> ToFloat(this ReadOnlySpan<BFloat16> inputMemory)
{
var elementCount = inputMemory.Length;
var floatArray = new float[elementCount];
for (int i = 0; i < elementCount; i++)
floatArray[i] = (float)inputMemory[i];

return floatArray.AsMemory();
}
}
}
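
A caveat on the four conversion helpers above: both half formats are lossy. Float16 keeps a 10-bit mantissa with a 5-bit exponent; BFloat16 keeps only a 7-bit mantissa but float32's full 8-bit exponent range. A hypothetical round-trip check (callable from inside OnnxStack.Core, since the helpers are internal):

var source = new float[] { 0.1f, 1234.567f }.AsMemory();
var half = source.ToFloat16();
var roundTripped = ((ReadOnlySpan<Float16>)half.Span).ToFloat();
// roundTripped.Span[1] comes back as 1235.0f, not 1234.567f; Float16 spacing is 1.0 between 1024 and 2048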
24 changes: 14 additions & 10 deletions OnnxStack.StableDiffusion/Diffusers/DiffuserBase.cs
Original file line number Diff line number Diff line change
@@ -212,19 +212,20 @@ protected virtual async Task<DenseTensor<float>> DecodeLatentsAsync(IModelOption

var inputNames = _onnxModelService.GetInputNames(model, OnnxModelType.VaeDecoder);
var outputNames = _onnxModelService.GetOutputNames(model, OnnxModelType.VaeDecoder);
var outputMetaData = _onnxModelService.GetOutputMetadata(model, OnnxModelType.VaeDecoder);
var outputTensorMetaData = outputMetaData[outputNames[0]];

var outputDim = new[] { 1, 3, options.Height, options.Width };
var outputBuffer = new DenseTensor<float>(outputDim);
using (var inputTensorValue = latents.ToOrtValue())
using (var outputTensorValue = outputBuffer.ToOrtValue())
using (var inputTensorValue = latents.ToOrtValue(outputTensorMetaData))
using (var outputTensorValue = outputTensorMetaData.CreateOutputBuffer(outputDim))
{
var inputs = new Dictionary<string, OrtValue> { { inputNames[0], inputTensorValue } };
var outputs = new Dictionary<string, OrtValue> { { outputNames[0], outputTensorValue } };
var results = await _onnxModelService.RunInferenceAsync(model, OnnxModelType.VaeDecoder, inputs, outputs);
using (var imageResult = results.First())
{
_logger?.LogEnd("End", timestamp);
return outputBuffer;
return imageResult.ToDenseTensor();
}
}
}
@@ -237,13 +238,16 @@ protected virtual async Task<DenseTensor<float>> DecodeLatentsAsync(IModelOption
/// <param name="timestepInputName">Name of the timestep input.</param>
/// <param name="timestep">The timestep.</param>
/// <returns></returns>
protected static OrtValue CreateTimestepNamedOrtValue(IReadOnlyDictionary<string, NodeMetadata> nodeMetadata, string timestepInputName, int timestep)
protected static OrtValue CreateTimestepNamedOrtValue(NodeMetadata timestepMetaData, int timestep)
{
// Some models take Int64 timesteps, others Float; for now support Int64, Float32, Float16 and BFloat16
var timestepMetaData = nodeMetadata[timestepInputName];
return timestepMetaData.ElementDataType == TensorElementType.Int64
? OrtValue.CreateTensorValueFromMemory(new long[] { timestep }, new long[] { 1 })
: OrtValue.CreateTensorValueFromMemory(new float[] { timestep }, new long[] { 1 });
var dimension = new long[] { 1 };
return timestepMetaData.ElementDataType switch
{
TensorElementType.Int64 => OrtValue.CreateTensorValueFromMemory(new long[] { timestep }, dimension),
TensorElementType.Float16 => OrtValue.CreateTensorValueFromMemory(new Float16[] { (Float16)timestep }, dimension),
TensorElementType.BFloat16 => OrtValue.CreateTensorValueFromMemory(new BFloat16[] { (BFloat16)timestep }, dimension),
_ => OrtValue.CreateTensorValueFromMemory(new float[] { timestep }, dimension) // TODO: Default to Float32 for now
};
}
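
As an illustration, a UNet exported with half-precision timesteps now hits the new Float16 arm instead of silently feeding Float32 (hypothetical metadata; 981 is just a sample scheduler timestep):

// assume timestepMetaData.ElementDataType == TensorElementType.Float16 for this model
using var timestepValue = CreateTimestepNamedOrtValue(timestepMetaData, 981);
// wraps new Float16[] { (Float16)981 } with shape [1]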


@@ -1,6 +1,7 @@
using Microsoft.Extensions.Logging;
using Microsoft.ML.OnnxRuntime;
using Microsoft.ML.OnnxRuntime.Tensors;
using OnnxStack.Core;
using OnnxStack.Core.Config;
using OnnxStack.Core.Services;
using OnnxStack.StableDiffusion.Common;
@@ -12,7 +13,6 @@
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using OnnxStack.Core;

namespace OnnxStack.StableDiffusion.Diffusers.LatentConsistency
{
@@ -61,19 +61,22 @@ protected override async Task<DenseTensor<float>> PrepareLatentsAsync(IModelOpti
var imageTensor = prompt.InputImage.ToDenseTensor(new[] { 1, 3, options.Height, options.Width });
var inputNames = _onnxModelService.GetInputNames(model, OnnxModelType.VaeEncoder);
var outputNames = _onnxModelService.GetOutputNames(model, OnnxModelType.VaeEncoder);
var outputMetaData = _onnxModelService.GetOutputMetadata(model, OnnxModelType.VaeEncoder);
var outputTensorMetaData = outputMetaData[outputNames[0]];

//TODO: Model Config, Channels
var outputBuffer = new DenseTensor<float>(options.GetScaledDimension());
using (var inputTensorValue = imageTensor.ToOrtValue())
using (var outputTensorValue = outputBuffer.ToOrtValue())
var outputDimension = options.GetScaledDimension();
using (var inputTensorValue = imageTensor.ToOrtValue(outputTensorMetaData))
using (var outputTensorValue = outputTensorMetaData.CreateOutputBuffer(outputDimension))
{
var inputs = new Dictionary<string, OrtValue> { { inputNames[0], inputTensorValue } };
var outputs = new Dictionary<string, OrtValue> { { outputNames[0], outputTensorValue } };
var results = await _onnxModelService.RunInferenceAsync(model, OnnxModelType.VaeEncoder, inputs, outputs);
using (var result = results.First())
{
var scaledSample = outputBuffer
.Add(scheduler.CreateRandomSample(outputBuffer.Dimensions, options.InitialNoiseLevel))
var outputResult = outputTensorValue.ToDenseTensor();
var scaledSample = outputResult
.Add(scheduler.CreateRandomSample(outputDimension, options.InitialNoiseLevel))
.MultiplyBy(model.ScaleFactor);

return scheduler.AddNoise(scaledSample, scheduler.CreateRandomSample(scaledSample.Dimensions), timesteps);
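
In effect this hunk computes, with z the VAE encoder output, eps and eps' fresh Gaussian samples (eps drawn with standard deviation InitialNoiseLevel), and s the model ScaleFactor:

latents = scheduler.AddNoise((z + eps) * s, eps', timesteps)

That is, the encoded image is perturbed, scaled into latent space, then noised to the starting timestep; the behavioural change in this diff is that z is now read out of the pre-allocated OrtValue via ToDenseTensor() instead of a managed output buffer.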
@@ -107,6 +107,9 @@ protected override async Task<DenseTensor<float>> SchedulerStepAsync(IModelOptio
var inputNames = _onnxModelService.GetInputNames(modelOptions, OnnxModelType.Unet);
var outputNames = _onnxModelService.GetOutputNames(modelOptions, OnnxModelType.Unet);
var inputMetaData = _onnxModelService.GetInputMetadata(modelOptions, OnnxModelType.Unet);
var outputMetaData = _onnxModelService.GetOutputMetadata(modelOptions, OnnxModelType.Unet);
var timestepMetaData = inputMetaData[inputNames[1]];
var outputTensorMetaData = outputMetaData[outputNames[0]];

// Loop through the timesteps
var step = 0;
@@ -120,17 +123,17 @@ protected override async Task<DenseTensor<float>> SchedulerStepAsync(IModelOptio
var inputTensor = scheduler.ScaleInput(latents, timestep);

var outputChannels = 1;
var outputBuffer = new DenseTensor<float>(schedulerOptions.GetScaledDimension(outputChannels));
using (var outputTensorValue = outputBuffer.ToOrtValue())
using (var inputTensorValue = inputTensor.ToOrtValue())
using (var timestepOrtValue = CreateTimestepNamedOrtValue(inputMetaData, inputNames[1], timestep))
using (var promptTensorValue = promptEmbeddings.ToOrtValue())
using (var guidanceTensorValue = guidanceEmbeddings.ToOrtValue())
var outputDimension = schedulerOptions.GetScaledDimension(outputChannels);
using (var outputTensorValue = outputTensorMetaData.CreateOutputBuffer(outputDimension))
using (var inputTensorValue = inputTensor.ToOrtValue(outputTensorMetaData))
using (var promptTensorValue = promptEmbeddings.ToOrtValue(outputTensorMetaData))
using (var guidanceTensorValue = guidanceEmbeddings.ToOrtValue(outputTensorMetaData))
using (var timestepTensorValue = CreateTimestepNamedOrtValue(timestepMetaData, timestep))
{
var inputs = new Dictionary<string, OrtValue>
{
{ inputNames[0], inputTensorValue },
{ inputNames[1], timestepOrtValue },
{ inputNames[1], timestepTensorValue },
{ inputNames[2], promptTensorValue },
{ inputNames[3], guidanceTensorValue }
};
@@ -139,7 +142,7 @@ protected override async Task<DenseTensor<float>> SchedulerStepAsync(IModelOptio
var results = await _onnxModelService.RunInferenceAsync(modelOptions, OnnxModelType.Unet, inputs, outputs);
using (var result = results.First())
{
var noisePred = outputBuffer;
var noisePred = outputTensorValue.ToDenseTensor();

// Scheduler Step
var schedulerResult = scheduler.Step(noisePred, timestep, latents);
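
The per-step pattern here mirrors DecodeLatentsAsync: allocate the output in the model's native element type, run inference into it, then widen to float for the scheduler, which stays float-only. Schematically (names from this diff):

// per timestep:
// outputTensorValue = outputTensorMetaData.CreateOutputBuffer(outputDimension) // Float16/BFloat16/Float32
// RunInferenceAsync(...) fills the native-precision buffer in place
// noisePred = outputTensorValue.ToDenseTensor() // widened to DenseTensor<float>
// schedulerResult = scheduler.Step(noisePred, timestep, latents) // scheduler math unchanged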
@@ -1,6 +1,7 @@
using Microsoft.Extensions.Logging;
using Microsoft.ML.OnnxRuntime;
using Microsoft.ML.OnnxRuntime.Tensors;
using OnnxStack.Core;
using OnnxStack.Core.Config;
using OnnxStack.Core.Services;
using OnnxStack.StableDiffusion.Common;
@@ -12,7 +13,6 @@
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using OnnxStack.Core;

namespace OnnxStack.StableDiffusion.Diffusers.StableDiffusion
{
@@ -63,19 +63,22 @@ protected override async Task<DenseTensor<float>> PrepareLatentsAsync(IModelOpti
var imageTensor = prompt.InputImage.ToDenseTensor(new[] { 1, 3, options.Height, options.Width });
var inputNames = _onnxModelService.GetInputNames(model, OnnxModelType.VaeEncoder);
var outputNames = _onnxModelService.GetOutputNames(model, OnnxModelType.VaeEncoder);
var outputMetaData = _onnxModelService.GetOutputMetadata(model, OnnxModelType.VaeEncoder);
var outputTensorMetaData = outputMetaData[outputNames[0]];

//TODO: Model Config, Channels
var outputBuffer = new DenseTensor<float>(options.GetScaledDimension());
using (var inputTensorValue = imageTensor.ToOrtValue())
using (var outputTensorValue = outputBuffer.ToOrtValue())
var outputDimension = options.GetScaledDimension();
using (var inputTensorValue = imageTensor.ToOrtValue(outputTensorMetaData))
using (var outputTensorValue = outputTensorMetaData.CreateOutputBuffer(outputDimension))
{
var inputs = new Dictionary<string, OrtValue> { { inputNames[0], inputTensorValue } };
var outputs = new Dictionary<string, OrtValue> { { outputNames[0], outputTensorValue } };
var results = await _onnxModelService.RunInferenceAsync(model, OnnxModelType.VaeEncoder, inputs, outputs);
using (var result = results.First())
{
var scaledSample = outputBuffer
.Add(scheduler.CreateRandomSample(outputBuffer.Dimensions, options.InitialNoiseLevel))
var outputResult = outputTensorValue.ToDenseTensor();
var scaledSample = outputResult
.Add(scheduler.CreateRandomSample(outputDimension, options.InitialNoiseLevel))
.MultiplyBy(model.ScaleFactor);

return scheduler.AddNoise(scaledSample, scheduler.CreateRandomSample(scaledSample.Dimensions), timesteps);