Skip to content
This repository was archived by the owner on Nov 27, 2024. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion OnnxStack.Console/appsettings.json
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
"InterOpNumThreads": 0,
"IntraOpNumThreads": 0,
"ExecutionMode": "ORT_SEQUENTIAL",
"ExecutionProvider": "Cuda",
"ExecutionProvider": "DirectML",
"ModelConfigurations": [
{
"Type": "Tokenizer",
Expand Down
1 change: 0 additions & 1 deletion OnnxStack.IntegrationTests/StableDiffusionTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,6 @@ public async Task GivenTextToImage_WhenInference_ThenImageGenerated(string model
{
Prompt = "an astronaut riding a horse in space",
NegativePrompt = "blurry,ugly,cartoon",
BatchCount = 1,
DiffuserType = DiffuserType.TextToImage
};

Expand Down
2 changes: 0 additions & 2 deletions OnnxStack.StableDiffusion/Config/PromptOptions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@ public class PromptOptions
[StringLength(512)]
public string NegativePrompt { get; set; }

public int BatchCount { get; set; } = 1;

public InputImage InputImage { get; set; }

public InputImage InputImageMask { get; set; }
Expand Down
40 changes: 14 additions & 26 deletions OnnxStack.StableDiffusion/Diffusers/DiffuserBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ public virtual async Task<DenseTensor<float>> DiffuseAsync(IModelOptions modelOp

return schedulerResult;
}


/// <summary>
/// Runs the stable diffusion batch loop
Expand Down Expand Up @@ -210,35 +210,23 @@ protected virtual async Task<DenseTensor<float>> DecodeLatents(IModelOptions mod
// Scale and decode the image latents with vae.
latents = latents.MultiplyBy(1.0f / model.ScaleFactor);

var images = prompt.BatchCount > 1
? latents.Split(prompt.BatchCount)
: new[] { latents };
var imageTensors = new List<DenseTensor<float>>();
foreach (var image in images)
{
var inputNames = _onnxModelService.GetInputNames(model, OnnxModelType.VaeDecoder);
var outputNames = _onnxModelService.GetOutputNames(model, OnnxModelType.VaeDecoder);
var inputNames = _onnxModelService.GetInputNames(model, OnnxModelType.VaeDecoder);
var outputNames = _onnxModelService.GetOutputNames(model, OnnxModelType.VaeDecoder);

var outputDim = new[] { 1, 3, options.Height, options.Width };
var outputBuffer = new DenseTensor<float>(outputDim);
using (var inputTensorValue = image.ToOrtValue())
using (var outputTensorValue = outputBuffer.ToOrtValue())
var outputDim = new[] { 1, 3, options.Height, options.Width };
var outputBuffer = new DenseTensor<float>(outputDim);
using (var inputTensorValue = latents.ToOrtValue())
using (var outputTensorValue = outputBuffer.ToOrtValue())
{
var inputs = new Dictionary<string, OrtValue> { { inputNames[0], inputTensorValue } };
var outputs = new Dictionary<string, OrtValue> { { outputNames[0], outputTensorValue } };
var results = await _onnxModelService.RunInferenceAsync(model, OnnxModelType.VaeDecoder, inputs, outputs);
using (var imageResult = results.First())
{
var inputs = new Dictionary<string, OrtValue> { { inputNames[0], inputTensorValue } };
var outputs = new Dictionary<string, OrtValue> { { outputNames[0], outputTensorValue } };
var results = await _onnxModelService.RunInferenceAsync(model, OnnxModelType.VaeDecoder, inputs, outputs);
using (var imageResult = results.First())
{
imageTensors.Add(outputBuffer);
}
_logger?.LogEnd("End", timestamp);
return outputBuffer;
}
}

var result = prompt.BatchCount > 1
? imageTensors.Join()
: imageTensors.FirstOrDefault();
_logger?.LogEnd("End", timestamp);
return result;
}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,11 +76,7 @@ protected override async Task<DenseTensor<float>> PrepareLatents(IModelOptions m
.Add(scheduler.CreateRandomSample(outputBuffer.Dimensions, options.InitialNoiseLevel))
.MultiplyBy(model.ScaleFactor);

var noisySample = scheduler.AddNoise(scaledSample, scheduler.CreateRandomSample(scaledSample.Dimensions), timesteps);
if (prompt.BatchCount > 1)
return noisySample.Repeat(prompt.BatchCount);

return noisySample;
return scheduler.AddNoise(scaledSample, scheduler.CreateRandomSample(scaledSample.Dimensions), timesteps);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ protected override IReadOnlyList<int> GetTimesteps(SchedulerOptions options, ISc
/// <returns></returns>
protected override Task<DenseTensor<float>> PrepareLatents(IModelOptions model, PromptOptions prompt, SchedulerOptions options, IScheduler scheduler, IReadOnlyList<int> timesteps)
{
return Task.FromResult(scheduler.CreateRandomSample(options.GetScaledDimension(prompt.BatchCount), scheduler.InitNoiseSigma));
return Task.FromResult(scheduler.CreateRandomSample(options.GetScaledDimension(), scheduler.InitNoiseSigma));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ protected override IReadOnlyList<int> GetTimesteps(SchedulerOptions options, ISc
/// <returns></returns>
protected override Task<DenseTensor<float>> PrepareLatents(IModelOptions model, PromptOptions prompt, SchedulerOptions options, IScheduler scheduler, IReadOnlyList<int> timesteps)
{
return Task.FromResult(scheduler.CreateRandomSample(options.GetScaledDimension(prompt.BatchCount), scheduler.InitNoiseSigma));
return Task.FromResult(scheduler.CreateRandomSample(options.GetScaledDimension(), scheduler.InitNoiseSigma));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -78,11 +78,7 @@ protected override async Task<DenseTensor<float>> PrepareLatents(IModelOptions m
.Add(scheduler.CreateRandomSample(outputBuffer.Dimensions, options.InitialNoiseLevel))
.MultiplyBy(model.ScaleFactor);

var noisySample = scheduler.AddNoise(scaledSample, scheduler.CreateRandomSample(scaledSample.Dimensions), timesteps);
if (prompt.BatchCount > 1)
return noisySample.Repeat(prompt.BatchCount);

return noisySample;
return scheduler.AddNoise(scaledSample, scheduler.CreateRandomSample(scaledSample.Dimensions), timesteps);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -161,9 +161,6 @@ private DenseTensor<float> PrepareMask(IModelOptions modelOptions, PromptOptions
});

imageTensor = imageTensor.MultiplyBy(modelOptions.ScaleFactor);
if (promptOptions.BatchCount > 1)
imageTensor = imageTensor.Repeat(promptOptions.BatchCount);

if (schedulerOptions.GuidanceScale > 1f)
imageTensor = imageTensor.Repeat(2);

Expand Down Expand Up @@ -232,9 +229,6 @@ private DenseTensor<float> PrepareImageMask(IModelOptions modelOptions, PromptOp
{
var sample = inferResult.FirstElementAs<DenseTensor<float>>();
var scaledSample = sample.MultiplyBy(modelOptions.ScaleFactor);
if (promptOptions.BatchCount > 1)
scaledSample = scaledSample.Repeat(promptOptions.BatchCount);

if (schedulerOptions.GuidanceScale > 1f)
scaledSample = scaledSample.Repeat(2);

Expand Down Expand Up @@ -267,7 +261,7 @@ protected override IReadOnlyList<int> GetTimesteps(SchedulerOptions options, ISc
/// <returns></returns>
protected override Task<DenseTensor<float>> PrepareLatents(IModelOptions model, PromptOptions prompt, SchedulerOptions options, IScheduler scheduler, IReadOnlyList<int> timesteps)
{
return Task.FromResult(scheduler.CreateRandomSample(options.GetScaledDimension(prompt.BatchCount), scheduler.InitNoiseSigma));
return Task.FromResult(scheduler.CreateRandomSample(options.GetScaledDimension(), scheduler.InitNoiseSigma));
}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -173,9 +173,6 @@ protected override async Task<DenseTensor<float>> PrepareLatents(IModelOptions m
.Add(scheduler.CreateRandomSample(outputBuffer.Dimensions, options.InitialNoiseLevel))
.MultiplyBy(model.ScaleFactor);

if (prompt.BatchCount > 1)
return scaledSample.Repeat(prompt.BatchCount);

return scaledSample;
}
}
Expand Down Expand Up @@ -214,9 +211,6 @@ private DenseTensor<float> PrepareMask(IModelOptions modelOptions, PromptOptions
}
});

if (promptOptions.BatchCount > 1)
return maskTensor.Repeat(promptOptions.BatchCount);

return maskTensor;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ protected override IReadOnlyList<int> GetTimesteps(SchedulerOptions options, ISc
/// <returns></returns>
protected override Task<DenseTensor<float>> PrepareLatents(IModelOptions model, PromptOptions prompt, SchedulerOptions options, IScheduler scheduler, IReadOnlyList<int> timesteps)
{
return Task.FromResult(scheduler.CreateRandomSample(options.GetScaledDimension(prompt.BatchCount), scheduler.InitNoiseSigma));
return Task.FromResult(scheduler.CreateRandomSample(options.GetScaledDimension(), scheduler.InitNoiseSigma));
}
}
}
8 changes: 0 additions & 8 deletions OnnxStack.StableDiffusion/Services/PromptService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,6 @@ public async Task<DenseTensor<float>> CreatePromptAsync(IModelOptions model, Pro
var promptEmbeddings = await GenerateEmbedsAsync(model, promptTokens, maxPromptTokenCount);
var negativePromptEmbeddings = await GenerateEmbedsAsync(model, negativePromptTokens, maxPromptTokenCount);

// If we have a batch, repeat the prompt embeddings
if (promptOptions.BatchCount > 1)
{
promptEmbeddings = promptEmbeddings.Repeat(promptOptions.BatchCount);
negativePromptEmbeddings = negativePromptEmbeddings.Repeat(promptOptions.BatchCount);
}

// If we are doing guided diffusion, concatenate the negative prompt embeddings
// If not, we ignore the negative prompt embeddings
if (isGuidanceEnabled)
Expand Down Expand Up @@ -166,6 +159,5 @@ private static IReadOnlyCollection<NamedOnnxValue> CreateInputParameters(params
{
return parameters.ToList().AsReadOnly();
}

}
}