diff --git a/OnnxStack.Core/Extensions.cs b/OnnxStack.Core/Extensions.cs index 508cf8e0..e6dc1e52 100644 --- a/OnnxStack.Core/Extensions.cs +++ b/OnnxStack.Core/Extensions.cs @@ -2,10 +2,12 @@ using Microsoft.ML.OnnxRuntime.Tensors; using OnnxStack.Core.Config; using System; +using System.Buffers; using System.Collections.Concurrent; using System.Collections.Generic; using System.Linq; using System.Numerics; +using System.Runtime.InteropServices; namespace OnnxStack.Core { @@ -205,26 +207,166 @@ public static T GetBufferLength(this ReadOnlySpan array) where T : INumber } + /// + /// Converts to long. + /// + /// The array. + /// public static long[] ToLong(this ReadOnlySpan array) { return Array.ConvertAll(array.ToArray(), Convert.ToInt64); } - + + + /// + /// Converts the string representation of a number to an integer. + /// + /// The array. + /// public static int[] ToInt(this long[] array) { return Array.ConvertAll(array, Convert.ToInt32); } + + /// + /// Converts to long. + /// + /// The array. + /// public static long[] ToLong(this int[] array) { return Array.ConvertAll(array, Convert.ToInt64); } - public static OrtValue ToOrtValue(this DenseTensor tensor) where T : unmanaged + /// + /// Creates and OrtValue form the DenseTensor and NodeMetaData provided + /// + /// The tensor. + /// The node metadata. 
+ /// + public static OrtValue ToOrtValue(this DenseTensor tensor, NodeMetadata nodeMetadata) { - return OrtValue.CreateTensorValueFromMemory(OrtMemoryInfo.DefaultInstance, tensor.Buffer, tensor.Dimensions.ToLong()); + var dimensions = tensor.Dimensions.ToLong(); + return nodeMetadata.ElementDataType switch + { + TensorElementType.Float16 => OrtValue.CreateTensorValueFromMemory(OrtMemoryInfo.DefaultInstance, tensor.Buffer.ToFloat16(), dimensions), + TensorElementType.BFloat16 => OrtValue.CreateTensorValueFromMemory(OrtMemoryInfo.DefaultInstance, tensor.Buffer.ToBFloat16(), dimensions), + _ => OrtValue.CreateTensorValueFromMemory(OrtMemoryInfo.DefaultInstance, tensor.Buffer, dimensions) + }; } + + /// + /// Creates and allocates output tensors buffer. + /// + /// The node metadata. + /// The dimensions. + /// + public static OrtValue CreateOutputBuffer(this NodeMetadata nodeMetadata, ReadOnlySpan dimensions) + { + return OrtValue.CreateAllocatedTensorValue(OrtAllocator.DefaultInstance, nodeMetadata.ElementDataType, dimensions.ToLong()); + } + + + /// + /// Converts to DenseTensor. + /// + /// The ort value. + /// + public static DenseTensor ToDenseTensor(this OrtValue ortValue) + { + var typeInfo = ortValue.GetTensorTypeAndShape(); + var dimensions = typeInfo.Shape.ToInt(); + return typeInfo.ElementDataType switch + { + TensorElementType.Float16 => new DenseTensor(ortValue.GetTensorDataAsSpan().ToFloat(), dimensions), + TensorElementType.BFloat16 => new DenseTensor(ortValue.GetTensorDataAsSpan().ToFloat(), dimensions), + _ => new DenseTensor(ortValue.GetTensorDataAsSpan().ToArray(), dimensions) + }; + } + + + /// + /// Converts to array. + /// + /// The ort value. 
+ /// + public static float[] ToArray(this OrtValue ortValue) + { + var typeInfo = ortValue.GetTensorTypeAndShape(); + var dimensions = typeInfo.Shape.ToInt(); + return typeInfo.ElementDataType switch + { + TensorElementType.Float16 => ortValue.GetTensorDataAsSpan().ToFloat().ToArray(), + TensorElementType.BFloat16 => ortValue.GetTensorDataAsSpan().ToFloat().ToArray(), + _ => ortValue.GetTensorDataAsSpan().ToArray() + }; + } + + + /// + /// Converts to float16. + /// + /// The input memory. + /// + internal static Memory ToFloat16(this Memory inputMemory) + { + var elementCount = inputMemory.Length; + var floatArray = new Float16[inputMemory.Length]; + for (int i = 0; i < elementCount; i++) + floatArray[i] = (Float16)inputMemory.Span[i]; + + return floatArray.AsMemory(); + } + + + /// + /// Converts to BFloat16. + /// + /// The input memory. + /// + internal static Memory ToBFloat16(this Memory inputMemory) + { + var elementCount = inputMemory.Length; + var floatArray = new BFloat16[inputMemory.Length]; + for (int i = 0; i < elementCount; i++) + floatArray[i] = (BFloat16)inputMemory.Span[i]; + + return floatArray.AsMemory(); + } + + + /// + /// Converts to float. + /// + /// The input memory. + /// + internal static Memory ToFloat(this ReadOnlySpan inputMemory) + { + var elementCount = inputMemory.Length; + var floatArray = new float[elementCount]; + for (int i = 0; i < elementCount; i++) + floatArray[i] = (float)inputMemory[i]; + + return floatArray.AsMemory(); + } + + + /// + /// Converts to float. + /// + /// The input memory. 
+ /// + internal static Memory ToFloat(this ReadOnlySpan inputMemory) + { + var elementCount = inputMemory.Length; + var floatArray = new float[elementCount]; + for (int i = 0; i < elementCount; i++) + floatArray[i] = (float)inputMemory[i]; + + return floatArray.AsMemory(); + } } } diff --git a/OnnxStack.StableDiffusion/Diffusers/DiffuserBase.cs b/OnnxStack.StableDiffusion/Diffusers/DiffuserBase.cs index d56e4502..2565aad7 100644 --- a/OnnxStack.StableDiffusion/Diffusers/DiffuserBase.cs +++ b/OnnxStack.StableDiffusion/Diffusers/DiffuserBase.cs @@ -212,11 +212,12 @@ protected virtual async Task> DecodeLatentsAsync(IModelOption var inputNames = _onnxModelService.GetInputNames(model, OnnxModelType.VaeDecoder); var outputNames = _onnxModelService.GetOutputNames(model, OnnxModelType.VaeDecoder); + var outputMetaData = _onnxModelService.GetOutputMetadata(model, OnnxModelType.VaeDecoder); + var outputTensorMetaData = outputMetaData[outputNames[0]]; var outputDim = new[] { 1, 3, options.Height, options.Width }; - var outputBuffer = new DenseTensor(outputDim); - using (var inputTensorValue = latents.ToOrtValue()) - using (var outputTensorValue = outputBuffer.ToOrtValue()) + using (var inputTensorValue = latents.ToOrtValue(outputTensorMetaData)) + using (var outputTensorValue = outputTensorMetaData.CreateOutputBuffer(outputDim)) { var inputs = new Dictionary { { inputNames[0], inputTensorValue } }; var outputs = new Dictionary { { outputNames[0], outputTensorValue } }; @@ -224,7 +225,7 @@ protected virtual async Task> DecodeLatentsAsync(IModelOption using (var imageResult = results.First()) { _logger?.LogEnd("End", timestamp); - return outputBuffer; + return imageResult.ToDenseTensor(); } } } @@ -237,13 +238,16 @@ protected virtual async Task> DecodeLatentsAsync(IModelOption /// Name of the timestep input. /// The timestep. 
/// - protected static OrtValue CreateTimestepNamedOrtValue(IReadOnlyDictionary nodeMetadata, string timestepInputName, int timestep) + protected static OrtValue CreateTimestepNamedOrtValue(NodeMetadata timestepMetaData, int timestep) { - // Some models support Long or Float, could be more but fornow just support these 2 - var timestepMetaData = nodeMetadata[timestepInputName]; - return timestepMetaData.ElementDataType == TensorElementType.Int64 - ? OrtValue.CreateTensorValueFromMemory(new long[] { timestep }, new long[] { 1 }) - : OrtValue.CreateTensorValueFromMemory(new float[] { timestep }, new long[] { 1 }); + var dimension = new long[] { 1 }; + return timestepMetaData.ElementDataType switch + { + TensorElementType.Int64 => OrtValue.CreateTensorValueFromMemory(new long[] { timestep }, dimension), + TensorElementType.Float16 => OrtValue.CreateTensorValueFromMemory(new Float16[] { (Float16)timestep }, dimension), + TensorElementType.BFloat16 => OrtValue.CreateTensorValueFromMemory(new BFloat16[] { (BFloat16)timestep }, dimension), + _ => OrtValue.CreateTensorValueFromMemory(new float[] { timestep }, dimension) // TODO: Default to Float32 for now + }; + } diff --git a/OnnxStack.StableDiffusion/Diffusers/LatentConsistency/ImageDiffuser.cs b/OnnxStack.StableDiffusion/Diffusers/LatentConsistency/ImageDiffuser.cs index 7a17e797..135b1f3c 100644 --- a/OnnxStack.StableDiffusion/Diffusers/LatentConsistency/ImageDiffuser.cs +++ b/OnnxStack.StableDiffusion/Diffusers/LatentConsistency/ImageDiffuser.cs @@ -1,6 +1,7 @@ using Microsoft.Extensions.Logging; using Microsoft.ML.OnnxRuntime; using Microsoft.ML.OnnxRuntime.Tensors; +using OnnxStack.Core; using OnnxStack.Core.Config; using OnnxStack.Core.Services; using OnnxStack.StableDiffusion.Common; @@ -12,7 +13,6 @@ using System.Collections.Generic; using System.Linq; using System.Threading.Tasks; -using OnnxStack.Core; namespace OnnxStack.StableDiffusion.Diffusers.LatentConsistency { @@ -61,19 +61,22 @@ protected override async
Task> PrepareLatentsAsync(IModelOpti var imageTensor = prompt.InputImage.ToDenseTensor(new[] { 1, 3, options.Height, options.Width }); var inputNames = _onnxModelService.GetInputNames(model, OnnxModelType.VaeEncoder); var outputNames = _onnxModelService.GetOutputNames(model, OnnxModelType.VaeEncoder); + var outputMetaData = _onnxModelService.GetOutputMetadata(model, OnnxModelType.VaeEncoder); + var outputTensorMetaData = outputMetaData[outputNames[0]]; //TODO: Model Config, Channels - var outputBuffer = new DenseTensor(options.GetScaledDimension()); - using (var inputTensorValue = imageTensor.ToOrtValue()) - using (var outputTensorValue = outputBuffer.ToOrtValue()) + var outputDimension = options.GetScaledDimension(); + using (var inputTensorValue = imageTensor.ToOrtValue(outputTensorMetaData)) + using (var outputTensorValue = outputTensorMetaData.CreateOutputBuffer(outputDimension)) { var inputs = new Dictionary { { inputNames[0], inputTensorValue } }; var outputs = new Dictionary { { outputNames[0], outputTensorValue } }; var results = await _onnxModelService.RunInferenceAsync(model, OnnxModelType.VaeEncoder, inputs, outputs); using (var result = results.First()) { - var scaledSample = outputBuffer - .Add(scheduler.CreateRandomSample(outputBuffer.Dimensions, options.InitialNoiseLevel)) + var outputResult = outputTensorValue.ToDenseTensor(); + var scaledSample = outputResult + .Add(scheduler.CreateRandomSample(outputDimension, options.InitialNoiseLevel)) .MultiplyBy(model.ScaleFactor); return scheduler.AddNoise(scaledSample, scheduler.CreateRandomSample(scaledSample.Dimensions), timesteps); diff --git a/OnnxStack.StableDiffusion/Diffusers/LatentConsistency/LatentConsistencyDiffuser.cs b/OnnxStack.StableDiffusion/Diffusers/LatentConsistency/LatentConsistencyDiffuser.cs index 499809b9..99c83f45 100644 --- a/OnnxStack.StableDiffusion/Diffusers/LatentConsistency/LatentConsistencyDiffuser.cs +++ 
b/OnnxStack.StableDiffusion/Diffusers/LatentConsistency/LatentConsistencyDiffuser.cs @@ -107,6 +107,9 @@ protected override async Task> SchedulerStepAsync(IModelOptio var inputNames = _onnxModelService.GetInputNames(modelOptions, OnnxModelType.Unet); var outputNames = _onnxModelService.GetOutputNames(modelOptions, OnnxModelType.Unet); var inputMetaData = _onnxModelService.GetInputMetadata(modelOptions, OnnxModelType.Unet); + var outputMetaData = _onnxModelService.GetOutputMetadata(modelOptions, OnnxModelType.Unet); + var timestepMetaData = inputMetaData[inputNames[1]]; + var outputTensorMetaData = outputMetaData[outputNames[0]]; // Loop though the timesteps var step = 0; @@ -120,17 +123,17 @@ protected override async Task> SchedulerStepAsync(IModelOptio var inputTensor = scheduler.ScaleInput(latents, timestep); var outputChannels = 1; - var outputBuffer = new DenseTensor(schedulerOptions.GetScaledDimension(outputChannels)); - using (var outputTensorValue = outputBuffer.ToOrtValue()) - using (var inputTensorValue = inputTensor.ToOrtValue()) - using (var timestepOrtValue = CreateTimestepNamedOrtValue(inputMetaData, inputNames[1], timestep)) - using (var promptTensorValue = promptEmbeddings.ToOrtValue()) - using (var guidanceTensorValue = guidanceEmbeddings.ToOrtValue()) + var outputDimension = schedulerOptions.GetScaledDimension(outputChannels); + using (var outputTensorValue = outputTensorMetaData.CreateOutputBuffer(outputDimension)) + using (var inputTensorValue = inputTensor.ToOrtValue(outputTensorMetaData)) + using (var promptTensorValue = promptEmbeddings.ToOrtValue(outputTensorMetaData)) + using (var guidanceTensorValue = guidanceEmbeddings.ToOrtValue(outputTensorMetaData)) + using (var timestepTensorValue = CreateTimestepNamedOrtValue(timestepMetaData, timestep)) { var inputs = new Dictionary { { inputNames[0], inputTensorValue }, - { inputNames[1], timestepOrtValue }, + { inputNames[1], timestepTensorValue }, { inputNames[2], promptTensorValue }, { 
inputNames[3], guidanceTensorValue } }; @@ -139,7 +142,7 @@ protected override async Task> SchedulerStepAsync(IModelOptio var results = await _onnxModelService.RunInferenceAsync(modelOptions, OnnxModelType.Unet, inputs, outputs); using (var result = results.First()) { - var noisePred = outputBuffer; + var noisePred = outputTensorValue.ToDenseTensor(); // Scheduler Step var schedulerResult = scheduler.Step(noisePred, timestep, latents); diff --git a/OnnxStack.StableDiffusion/Diffusers/StableDiffusion/ImageDiffuser.cs b/OnnxStack.StableDiffusion/Diffusers/StableDiffusion/ImageDiffuser.cs index 472caae4..02ca867f 100644 --- a/OnnxStack.StableDiffusion/Diffusers/StableDiffusion/ImageDiffuser.cs +++ b/OnnxStack.StableDiffusion/Diffusers/StableDiffusion/ImageDiffuser.cs @@ -1,6 +1,7 @@ using Microsoft.Extensions.Logging; using Microsoft.ML.OnnxRuntime; using Microsoft.ML.OnnxRuntime.Tensors; +using OnnxStack.Core; using OnnxStack.Core.Config; using OnnxStack.Core.Services; using OnnxStack.StableDiffusion.Common; @@ -12,7 +13,6 @@ using System.Collections.Generic; using System.Linq; using System.Threading.Tasks; -using OnnxStack.Core; namespace OnnxStack.StableDiffusion.Diffusers.StableDiffusion { @@ -63,19 +63,22 @@ protected override async Task> PrepareLatentsAsync(IModelOpti var imageTensor = prompt.InputImage.ToDenseTensor(new[] { 1, 3, options.Height, options.Width }); var inputNames = _onnxModelService.GetInputNames(model, OnnxModelType.VaeEncoder); var outputNames = _onnxModelService.GetOutputNames(model, OnnxModelType.VaeEncoder); + var outputMetaData = _onnxModelService.GetOutputMetadata(model, OnnxModelType.VaeEncoder); + var outputTensorMetaData = outputMetaData[outputNames[0]]; //TODO: Model Config, Channels - var outputBuffer = new DenseTensor(options.GetScaledDimension()); - using (var inputTensorValue = imageTensor.ToOrtValue()) - using (var outputTensorValue = outputBuffer.ToOrtValue()) + var outputDimension = options.GetScaledDimension(); + using (var 
inputTensorValue = imageTensor.ToOrtValue(outputTensorMetaData)) + using (var outputTensorValue = outputTensorMetaData.CreateOutputBuffer(outputDimension)) { var inputs = new Dictionary { { inputNames[0], inputTensorValue } }; var outputs = new Dictionary { { outputNames[0], outputTensorValue } }; var results = await _onnxModelService.RunInferenceAsync(model, OnnxModelType.VaeEncoder, inputs, outputs); using (var result = results.First()) { - var scaledSample = outputBuffer - .Add(scheduler.CreateRandomSample(outputBuffer.Dimensions, options.InitialNoiseLevel)) + var outputResult = outputTensorValue.ToDenseTensor(); + var scaledSample = outputResult + .Add(scheduler.CreateRandomSample(outputDimension, options.InitialNoiseLevel)) .MultiplyBy(model.ScaleFactor); return scheduler.AddNoise(scaledSample, scheduler.CreateRandomSample(scaledSample.Dimensions), timesteps); diff --git a/OnnxStack.StableDiffusion/Diffusers/StableDiffusion/InpaintDiffuser.cs b/OnnxStack.StableDiffusion/Diffusers/StableDiffusion/InpaintDiffuser.cs index 0fb7853c..d67aa4a0 100644 --- a/OnnxStack.StableDiffusion/Diffusers/StableDiffusion/InpaintDiffuser.cs +++ b/OnnxStack.StableDiffusion/Diffusers/StableDiffusion/InpaintDiffuser.cs @@ -72,6 +72,9 @@ protected override async Task> SchedulerStepAsync(IModelOptio var inputNames = _onnxModelService.GetInputNames(modelOptions, OnnxModelType.Unet); var outputNames = _onnxModelService.GetOutputNames(modelOptions, OnnxModelType.Unet); var inputMetaData = _onnxModelService.GetInputMetadata(modelOptions, OnnxModelType.Unet); + var outputMetaData = _onnxModelService.GetOutputMetadata(modelOptions, OnnxModelType.Unet); + var outputTensorMetaData = outputMetaData[outputNames[0]]; + var timestepMetaData = inputMetaData[inputNames[1]]; // Loop though the timesteps var step = 0; @@ -87,16 +90,16 @@ protected override async Task> SchedulerStepAsync(IModelOptio inputTensor = ConcatenateLatents(inputTensor, maskedImage, maskImage); var outputChannels = 
performGuidance ? 2 : 1; - var outputBuffer = new DenseTensor(schedulerOptions.GetScaledDimension(outputChannels)); - using (var outputTensorValue = outputBuffer.ToOrtValue()) - using (var inputTensorValue = inputTensor.ToOrtValue()) - using (var timestepOrtValue = CreateTimestepNamedOrtValue(inputMetaData, inputNames[1], timestep)) - using (var promptTensorValue = promptEmbeddings.ToOrtValue()) + var outputDimension = schedulerOptions.GetScaledDimension(outputChannels); + using (var outputTensorValue = outputTensorMetaData.CreateOutputBuffer(outputDimension)) + using (var inputTensorValue = inputTensor.ToOrtValue(outputTensorMetaData)) + using (var promptTensorValue = promptEmbeddings.ToOrtValue(outputTensorMetaData)) + using (var timestepTensorValue = CreateTimestepNamedOrtValue(timestepMetaData, timestep)) { var inputs = new Dictionary { { inputNames[0], inputTensorValue }, - { inputNames[1], timestepOrtValue }, + { inputNames[1], timestepTensorValue }, { inputNames[2], promptTensorValue } }; @@ -104,7 +107,7 @@ protected override async Task> SchedulerStepAsync(IModelOptio var results = await _onnxModelService.RunInferenceAsync(modelOptions, OnnxModelType.Unet, inputs, outputs); using (var result = results.First()) { - var noisePred = outputBuffer; + var noisePred = outputTensorValue.ToDenseTensor(); // Perform guidance if (performGuidance) diff --git a/OnnxStack.StableDiffusion/Diffusers/StableDiffusion/InpaintLegacyDiffuser.cs b/OnnxStack.StableDiffusion/Diffusers/StableDiffusion/InpaintLegacyDiffuser.cs index dcee3b02..60aad028 100644 --- a/OnnxStack.StableDiffusion/Diffusers/StableDiffusion/InpaintLegacyDiffuser.cs +++ b/OnnxStack.StableDiffusion/Diffusers/StableDiffusion/InpaintLegacyDiffuser.cs @@ -73,6 +73,9 @@ protected override async Task> SchedulerStepAsync(IModelOptio var inputNames = _onnxModelService.GetInputNames(modelOptions, OnnxModelType.Unet); var outputNames = _onnxModelService.GetOutputNames(modelOptions, OnnxModelType.Unet); var 
inputMetaData = _onnxModelService.GetInputMetadata(modelOptions, OnnxModelType.Unet); + var outputMetaData = _onnxModelService.GetOutputMetadata(modelOptions, OnnxModelType.Unet); + var outputTensorMetaData = outputMetaData[outputNames[0]]; + var timestepMetaData = inputMetaData[inputNames[1]]; // Loop though the timesteps var step = 0; @@ -87,16 +90,16 @@ protected override async Task> SchedulerStepAsync(IModelOptio var inputTensor = scheduler.ScaleInput(inputLatent, timestep); var outputChannels = performGuidance ? 2 : 1; - var outputBuffer = new DenseTensor(schedulerOptions.GetScaledDimension(outputChannels)); - using (var outputTensorValue = outputBuffer.ToOrtValue()) - using (var inputTensorValue = inputTensor.ToOrtValue()) - using (var timestepOrtValue = CreateTimestepNamedOrtValue(inputMetaData, inputNames[1], timestep)) - using (var promptTensorValue = promptEmbeddings.ToOrtValue()) + var outputDimension = schedulerOptions.GetScaledDimension(outputChannels); + using (var outputTensorValue = outputTensorMetaData.CreateOutputBuffer(outputDimension)) + using (var inputTensorValue = inputTensor.ToOrtValue(outputTensorMetaData)) + using (var promptTensorValue = promptEmbeddings.ToOrtValue(outputTensorMetaData)) + using (var timestepTensorValue = CreateTimestepNamedOrtValue(timestepMetaData, timestep)) { var inputs = new Dictionary { { inputNames[0], inputTensorValue }, - { inputNames[1], timestepOrtValue }, + { inputNames[1], timestepTensorValue }, { inputNames[2], promptTensorValue } }; @@ -104,7 +107,7 @@ protected override async Task> SchedulerStepAsync(IModelOptio var results = await _onnxModelService.RunInferenceAsync(modelOptions, OnnxModelType.Unet, inputs, outputs); using (var result = results.First()) { - var noisePred = outputBuffer; + var noisePred = outputTensorValue.ToDenseTensor(); // Perform guidance if (performGuidance) @@ -158,19 +161,22 @@ protected override async Task> PrepareLatentsAsync(IModelOpti var imageTensor = 
prompt.InputImage.ToDenseTensor(new[] { 1, 3, options.Height, options.Width }); var inputNames = _onnxModelService.GetInputNames(model, OnnxModelType.VaeEncoder); var outputNames = _onnxModelService.GetOutputNames(model, OnnxModelType.VaeEncoder); + var outputMetaData = _onnxModelService.GetOutputMetadata(model, OnnxModelType.VaeEncoder); + var outputBufferMetaData = outputMetaData[outputNames[0]]; //TODO: Model Config, Channels - var outputBuffer = new DenseTensor(options.GetScaledDimension()); - using (var inputTensorValue = imageTensor.ToOrtValue()) - using (var outputTensorValue = outputBuffer.ToOrtValue()) + var outputDimensions = options.GetScaledDimension(); + using (var inputTensorValue = imageTensor.ToOrtValue(outputBufferMetaData)) + using (var outputTensorValue = outputBufferMetaData.CreateOutputBuffer(outputDimensions)) { var inputs = new Dictionary { { inputNames[0], inputTensorValue } }; var outputs = new Dictionary { { outputNames[0], outputTensorValue } }; var results = await _onnxModelService.RunInferenceAsync(model, OnnxModelType.VaeEncoder, inputs, outputs); using (var result = results.First()) { - var scaledSample = outputBuffer - .Add(scheduler.CreateRandomSample(outputBuffer.Dimensions, options.InitialNoiseLevel)) + var outputResult = outputTensorValue.ToDenseTensor(); + var scaledSample = outputResult + .Add(scheduler.CreateRandomSample(outputDimensions, options.InitialNoiseLevel)) .MultiplyBy(model.ScaleFactor); return scaledSample; diff --git a/OnnxStack.StableDiffusion/Diffusers/StableDiffusion/StableDiffusionDiffuser.cs b/OnnxStack.StableDiffusion/Diffusers/StableDiffusion/StableDiffusionDiffuser.cs index c4789a27..114369e9 100644 --- a/OnnxStack.StableDiffusion/Diffusers/StableDiffusion/StableDiffusionDiffuser.cs +++ b/OnnxStack.StableDiffusion/Diffusers/StableDiffusion/StableDiffusionDiffuser.cs @@ -61,6 +61,9 @@ protected override async Task> SchedulerStepAsync(IModelOptio var inputNames = _onnxModelService.GetInputNames(modelOptions, 
OnnxModelType.Unet); var outputNames = _onnxModelService.GetOutputNames(modelOptions, OnnxModelType.Unet); var inputMetaData = _onnxModelService.GetInputMetadata(modelOptions, OnnxModelType.Unet); + var outputMetaData = _onnxModelService.GetOutputMetadata(modelOptions, OnnxModelType.Unet); + var outputTensorMetaData = outputMetaData[outputNames[0]]; + var timestepMetaData = inputMetaData[inputNames[1]]; // Loop though the timesteps var step = 0; @@ -75,16 +78,16 @@ protected override async Task> SchedulerStepAsync(IModelOptio var inputTensor = scheduler.ScaleInput(inputLatent, timestep); var outputChannels = performGuidance ? 2 : 1; - var outputBuffer = new DenseTensor(schedulerOptions.GetScaledDimension(outputChannels)); - using (var outputTensorValue = outputBuffer.ToOrtValue()) - using (var inputTensorValue = inputTensor.ToOrtValue()) - using (var timestepOrtValue = CreateTimestepNamedOrtValue(inputMetaData, inputNames[1], timestep)) - using (var promptTensorValue = promptEmbeddings.ToOrtValue()) + var outputDimension = schedulerOptions.GetScaledDimension(outputChannels); + using (var outputTensorValue = outputTensorMetaData.CreateOutputBuffer(outputDimension)) + using (var inputTensorValue = inputTensor.ToOrtValue(outputTensorMetaData)) + using (var promptTensorValue = promptEmbeddings.ToOrtValue(outputTensorMetaData)) + using (var timestepTensorValue = CreateTimestepNamedOrtValue(timestepMetaData, timestep)) { var inputs = new Dictionary { { inputNames[0], inputTensorValue }, - { inputNames[1], timestepOrtValue }, + { inputNames[1], timestepTensorValue }, { inputNames[2], promptTensorValue } }; @@ -92,7 +95,7 @@ protected override async Task> SchedulerStepAsync(IModelOptio var results = await _onnxModelService.RunInferenceAsync(modelOptions, OnnxModelType.Unet, inputs, outputs); using (var result = results.First()) { - var noisePred = outputBuffer; + var noisePred = result.ToDenseTensor(); // Perform guidance if (performGuidance) diff --git 
a/OnnxStack.StableDiffusion/Services/PromptService.cs b/OnnxStack.StableDiffusion/Services/PromptService.cs index 0e460f47..355b7d64 100644 --- a/OnnxStack.StableDiffusion/Services/PromptService.cs +++ b/OnnxStack.StableDiffusion/Services/PromptService.cs @@ -89,22 +89,22 @@ public Task DecodeTextAsync(IModelOptions model, string inputText) /// public async Task EncodeTokensAsync(IModelOptions model, int[] tokenizedInput) { - // Create input tensor. var inputNames = _onnxModelService.GetInputNames(model, OnnxModelType.TextEncoder); var outputNames = _onnxModelService.GetOutputNames(model, OnnxModelType.TextEncoder); + var outputMetaData = _onnxModelService.GetOutputMetadata(model, OnnxModelType.TextEncoder); + var outputTensorMetaData = outputMetaData.Values.First(); var inputDim = new[] { 1L, tokenizedInput.Length }; var outputDim = new[] { 1L, tokenizedInput.Length, model.EmbeddingsLength }; - var outputBuffer = new float[outputDim.GetBufferLength()]; + using (var outputTensorValue = outputTensorMetaData.CreateOutputBuffer(outputDim.ToInt())) using (var inputTensorValue = OrtValue.CreateTensorValueFromMemory(tokenizedInput, inputDim)) - using (var outputTensorValue = OrtValue.CreateTensorValueFromMemory(outputBuffer, outputDim)) { var inputs = new Dictionary { { inputNames[0], inputTensorValue } }; var outputs = new Dictionary { { outputNames[0], outputTensorValue } }; var results = await _onnxModelService.RunInferenceAsync(model, OnnxModelType.TextEncoder, inputs, outputs); using (var result = results.First()) { - return outputBuffer; + return outputTensorValue.ToArray(); } } }