Skip to content

Add sd3 pipeline and models #96

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Jun 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 43 additions & 5 deletions Diffusion-macOS/ControlsView.swift
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,15 @@ struct ControlsView: View {
return
}

if !model.supportsNeuralEngine && generation.computeUnits == .cpuAndNeuralEngine {
// Reset compute units to GPU if Neural Engine is not supported
Settings.shared.userSelectedComputeUnits = .cpuAndGPU
resetComputeUnitsState()
print("Neural Engine not supported for model \(model), switching to GPU")
} else {
resetComputeUnitsState()
}

Settings.shared.currentModel = model

pipelineLoader?.cancel()
Expand Down Expand Up @@ -155,9 +164,15 @@ struct ControlsView: View {
VStack {
Spacer()
PromptTextField(text: $generation.positivePrompt, isPositivePrompt: true, model: $model)
.onChange(of: generation.positivePrompt) { prompt in
Settings.shared.prompt = prompt
}
.padding(.top, 5)
Spacer()
PromptTextField(text: $generation.negativePrompt, isPositivePrompt: false, model: $model)
.onChange(of: generation.negativePrompt) { negativePrompt in
Settings.shared.negativePrompt = negativePrompt
}
.padding(.bottom, 5)
Spacer()
}
Expand Down Expand Up @@ -242,7 +257,11 @@ struct ControlsView: View {
Text("Guidance Scale")
Spacer()
Text(guidanceScaleValue)
}.padding(.leading, 10)
}
.onChange(of: generation.guidanceScale) { guidanceScale in
Settings.shared.guidanceScale = guidanceScale
}
.padding(.leading, 10)
} label: {
HStack {
Label("Guidance Scale", systemImage: "scalemass").foregroundColor(.secondary)
Expand All @@ -269,7 +288,11 @@ struct ControlsView: View {
Text("Steps")
Spacer()
Text("\(Int(generation.steps))")
}.padding(.leading, 10)
}
.onChange(of: generation.steps) { steps in
Settings.shared.stepCount = steps
}
.padding(.leading, 10)
} label: {
HStack {
Label("Step count", systemImage: "square.3.layers.3d.down.left").foregroundColor(.secondary)
Expand All @@ -295,7 +318,11 @@ struct ControlsView: View {
Text("Previews")
Spacer()
Text("\(Int(generation.previews))")
}.padding(.leading, 10)
}
.onChange(of: generation.previews) { previews in
Settings.shared.previewCount = previews
}
.padding(.leading, 10)
} label: {
HStack {
Label("Preview count", systemImage: "eye.square").foregroundColor(.secondary)
Expand Down Expand Up @@ -334,25 +361,32 @@ struct ControlsView: View {
seedHelp($showSeedHelp)
}
} else {
Text("\(Int(generation.seed))")
Text(generation.seed.formatted(.number.grouping(.never)))
}
}
.foregroundColor(.secondary)
}

if Capabilities.hasANE {
Divider()
let isNeuralEngineDisabled = !(ModelInfo.from(modelVersion: model)?.supportsNeuralEngine ?? true)
DisclosureGroup(isExpanded: $disclosedAdvanced) {
HStack {
Picker(selection: $generation.computeUnits, label: Text("Use")) {
Text("GPU").tag(ComputeUnits.cpuAndGPU)
Text("Neural Engine").tag(ComputeUnits.cpuAndNeuralEngine)
Text("Neural Engine\(isNeuralEngineDisabled ? " (unavailable)" : "")")
.foregroundColor(isNeuralEngineDisabled ? .secondary : .primary)
.tag(ComputeUnits.cpuAndNeuralEngine)
Text("GPU and Neural Engine").tag(ComputeUnits.all)
}.pickerStyle(.radioGroup).padding(.leading)
Spacer()
}
.onChange(of: generation.computeUnits) { units in
guard let currentModel = ModelInfo.from(modelVersion: model) else { return }
if isNeuralEngineDisabled && units == .cpuAndNeuralEngine {
resetComputeUnitsState()
return
}
let variantDownloaded = isModelDownloaded(currentModel, computeUnits: units)
if variantDownloaded {
updateComputeUnitsState()
Expand Down Expand Up @@ -430,8 +464,10 @@ struct ControlsView: View {
set: { newValue in
if let seed = UInt32(newValue) {
generation.seed = seed
Settings.shared.seed = seed
} else {
generation.seed = 0
Settings.shared.seed = 0
}
}
)
Expand All @@ -442,8 +478,10 @@ struct ControlsView: View {
.onChange(of: seedBinding.wrappedValue, perform: { newValue in
if let seed = UInt32(newValue) {
generation.seed = seed
Settings.shared.seed = seed
} else {
generation.seed = 0
Settings.shared.seed = 0
}
})
.onReceive(Just(seedBinding.wrappedValue)) { newValue in
Expand Down
32 changes: 16 additions & 16 deletions Diffusion.xcodeproj/project.pbxproj
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
objects = {

/* Begin PBXBuildFile section */
16AFDD4F2C1B7D6200536A62 /* StableDiffusion in Frameworks */ = {isa = PBXBuildFile; productRef = 16AFDD4E2C1B7D6200536A62 /* StableDiffusion */; };
16AFDD512C1B7D6700536A62 /* StableDiffusion in Frameworks */ = {isa = PBXBuildFile; productRef = 16AFDD502C1B7D6700536A62 /* StableDiffusion */; };
8C4B32042A770C1D0090EF17 /* DiffusionImage+macOS.swift in Sources */ = {isa = PBXBuildFile; fileRef = 8C4B32032A770C1D0090EF17 /* DiffusionImage+macOS.swift */; };
8C4B32062A770C300090EF17 /* DiffusionImage+iOS.swift in Sources */ = {isa = PBXBuildFile; fileRef = 8C4B32052A770C300090EF17 /* DiffusionImage+iOS.swift */; };
8C4B32082A77F90C0090EF17 /* Utils_iOS.swift in Sources */ = {isa = PBXBuildFile; fileRef = 8C4B32072A77F90C0090EF17 /* Utils_iOS.swift */; };
Expand All @@ -16,8 +18,6 @@
8CEEB7D92A54C88C00C23829 /* DiffusionImage.swift in Sources */ = {isa = PBXBuildFile; fileRef = 8CEEB7D82A54C88C00C23829 /* DiffusionImage.swift */; };
8CEEB7DA2A54C88C00C23829 /* DiffusionImage.swift in Sources */ = {isa = PBXBuildFile; fileRef = 8CEEB7D82A54C88C00C23829 /* DiffusionImage.swift */; };
EB067F872992E561004D1AD9 /* HelpContent.swift in Sources */ = {isa = PBXBuildFile; fileRef = EB067F862992E561004D1AD9 /* HelpContent.swift */; };
EB25B3D62A3A2DC4000E25A1 /* StableDiffusion in Frameworks */ = {isa = PBXBuildFile; productRef = EB25B3D52A3A2DC4000E25A1 /* StableDiffusion */; };
EB25B3D82A3A2DD5000E25A1 /* StableDiffusion in Frameworks */ = {isa = PBXBuildFile; productRef = EB25B3D72A3A2DD5000E25A1 /* StableDiffusion */; };
EB560F0429A3C20800C0F8B8 /* Capabilities.swift in Sources */ = {isa = PBXBuildFile; fileRef = EB560F0329A3C20800C0F8B8 /* Capabilities.swift */; };
EBB5BA5329425BEE003A2A5B /* PipelineLoader.swift in Sources */ = {isa = PBXBuildFile; fileRef = EBB5BA5229425BEE003A2A5B /* PipelineLoader.swift */; };
EBB5BA5A29426E06003A2A5B /* Downloader.swift in Sources */ = {isa = PBXBuildFile; fileRef = EBB5BA5929426E06003A2A5B /* Downloader.swift */; };
Expand Down Expand Up @@ -116,8 +116,8 @@
isa = PBXFrameworksBuildPhase;
buildActionMask = 2147483647;
files = (
EB25B3D62A3A2DC4000E25A1 /* StableDiffusion in Frameworks */,
EBB5BA5D294504DE003A2A5B /* ZIPFoundation in Frameworks */,
16AFDD512C1B7D6700536A62 /* StableDiffusion in Frameworks */,
);
runOnlyForDeploymentPostprocessing = 0;
};
Expand All @@ -140,7 +140,7 @@
buildActionMask = 2147483647;
files = (
F155203C297118E700DC009B /* CompactSlider in Frameworks */,
EB25B3D82A3A2DD5000E25A1 /* StableDiffusion in Frameworks */,
16AFDD4F2C1B7D6200536A62 /* StableDiffusion in Frameworks */,
EBDD7DAF29731FB300C1C4B2 /* ZIPFoundation in Frameworks */,
);
runOnlyForDeploymentPostprocessing = 0;
Expand Down Expand Up @@ -318,7 +318,7 @@
name = Diffusion;
packageProductDependencies = (
EBB5BA5C294504DE003A2A5B /* ZIPFoundation */,
EB25B3D52A3A2DC4000E25A1 /* StableDiffusion */,
16AFDD502C1B7D6700536A62 /* StableDiffusion */,
);
productName = Diffusion;
productReference = EBE755C5293E37DD00806B32 /* Diffusion.app */;
Expand Down Expand Up @@ -378,7 +378,7 @@
packageProductDependencies = (
F155203B297118E700DC009B /* CompactSlider */,
EBDD7DAE29731FB300C1C4B2 /* ZIPFoundation */,
EB25B3D72A3A2DD5000E25A1 /* StableDiffusion */,
16AFDD4E2C1B7D6200536A62 /* StableDiffusion */,
);
productName = "Diffusion-macOS";
productReference = F15520212971093300DC009B /* Diffusers.app */;
Expand Down Expand Up @@ -422,7 +422,7 @@
packageReferences = (
EBB5BA5B294504DE003A2A5B /* XCRemoteSwiftPackageReference "ZIPFoundation" */,
F155203A297118E600DC009B /* XCRemoteSwiftPackageReference "CompactSlider" */,
EB25B3D42A3A2DC4000E25A1 /* XCRemoteSwiftPackageReference "ml-stable-diffusion" */,
16AFDD4D2C1B7D4800536A62 /* XCRemoteSwiftPackageReference "ml-stable-diffusion" */,
);
productRefGroup = EBE755C6293E37DD00806B32 /* Products */;
projectDirPath = "";
Expand Down Expand Up @@ -876,7 +876,7 @@
"$(inherited)",
"@executable_path/../Frameworks",
);
MACOSX_DEPLOYMENT_TARGET = 13.1;
MACOSX_DEPLOYMENT_TARGET = 14.0;
SDKROOT = macosx;
SWIFT_EMIT_LOC_STRINGS = YES;
SWIFT_VERSION = 5.0;
Expand Down Expand Up @@ -904,7 +904,7 @@
"$(inherited)",
"@executable_path/../Frameworks",
);
MACOSX_DEPLOYMENT_TARGET = 13.1;
MACOSX_DEPLOYMENT_TARGET = 14.0;
PRODUCT_BUNDLE_IDENTIFIER = com.huggingface.Diffusers;
SDKROOT = macosx;
SWIFT_EMIT_LOC_STRINGS = YES;
Expand Down Expand Up @@ -963,9 +963,9 @@
/* End XCConfigurationList section */

/* Begin XCRemoteSwiftPackageReference section */
EB25B3D42A3A2DC4000E25A1 /* XCRemoteSwiftPackageReference "ml-stable-diffusion" */ = {
16AFDD4D2C1B7D4800536A62 /* XCRemoteSwiftPackageReference "ml-stable-diffusion" */ = {
isa = XCRemoteSwiftPackageReference;
repositoryURL = "https://github.com/apple/ml-stable-diffusion";
repositoryURL = "https://github.com/argmaxinc/ml-stable-diffusion.git";
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Apple appears to still be reviewing but I'll be happy to update when that is merged - this is just for testing and review purposes here.

requirement = {
branch = main;
kind = branch;
Expand All @@ -990,18 +990,18 @@
/* End XCRemoteSwiftPackageReference section */

/* Begin XCSwiftPackageProductDependency section */
EB0199482A31FEAF00B133E2 /* StableDiffusion */ = {
16AFDD4E2C1B7D6200536A62 /* StableDiffusion */ = {
isa = XCSwiftPackageProductDependency;
package = 16AFDD4D2C1B7D4800536A62 /* XCRemoteSwiftPackageReference "ml-stable-diffusion" */;
productName = StableDiffusion;
};
EB25B3D52A3A2DC4000E25A1 /* StableDiffusion */ = {
16AFDD502C1B7D6700536A62 /* StableDiffusion */ = {
isa = XCSwiftPackageProductDependency;
package = EB25B3D42A3A2DC4000E25A1 /* XCRemoteSwiftPackageReference "ml-stable-diffusion" */;
package = 16AFDD4D2C1B7D4800536A62 /* XCRemoteSwiftPackageReference "ml-stable-diffusion" */;
productName = StableDiffusion;
};
EB25B3D72A3A2DD5000E25A1 /* StableDiffusion */ = {
EB0199482A31FEAF00B133E2 /* StableDiffusion */ = {
isa = XCSwiftPackageProductDependency;
package = EB25B3D42A3A2DC4000E25A1 /* XCRemoteSwiftPackageReference "ml-stable-diffusion" */;
productName = StableDiffusion;
};
EBB5BA5C294504DE003A2A5B /* ZIPFoundation */ = {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
{
"originHash" : "e97aab54879429ea40e58df49ffe4eef5228d95a28a7cf4d5dca9204c33564e1",
"pins" : [
{
"identity" : "compactslider",
Expand All @@ -12,19 +13,19 @@
{
"identity" : "ml-stable-diffusion",
"kind" : "remoteSourceControl",
"location" : "https://github.com/apple/ml-stable-diffusion",
"location" : "https://github.com/argmaxinc/ml-stable-diffusion.git",
"state" : {
"branch" : "main",
"revision" : "d456a972cd7d84cab2ec353a29896d59b8602248"
"revision" : "d1f0604fab5345011e0b9f5b87ee0c155612565f"
}
},
{
"identity" : "swift-argument-parser",
"kind" : "remoteSourceControl",
"location" : "https://github.com/apple/swift-argument-parser.git",
"state" : {
"revision" : "fddd1c00396eed152c45a46bea9f47b98e59301d",
"version" : "1.2.0"
"revision" : "0fbc8848e389af3bb55c182bc19ca9d5dc2f255b",
"version" : "1.4.0"
}
},
{
Expand All @@ -37,5 +38,5 @@
}
}
],
"version" : 2
"version" : 3
}
Binary file modified Diffusion/Assets.xcassets/placeholder.imageset/labrador.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
12 changes: 9 additions & 3 deletions Diffusion/Common/Downloader.swift
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,14 @@ class Downloader: NSObject, ObservableObject {
self.destination = destination
super.init()

var config = URLSessionConfiguration.default
#if !os(macOS)
// .background allows downloads to proceed in the background
let config = URLSessionConfiguration.background(withIdentifier: "net.pcuenca.diffusion.download")
// helpful for devices that may not keep the app in the foreground for the download duration
config = URLSessionConfiguration.background(withIdentifier: "net.pcuenca.diffusion.download")
config.isDiscretionary = false
config.sessionSendsLaunchEvents = true
#endif
urlSession = URLSession(configuration: config, delegate: self, delegateQueue: OperationQueue())
downloadState.value = .downloading(0)
urlSession?.getAllTasks { tasks in
Expand Down Expand Up @@ -75,8 +81,8 @@ class Downloader: NSObject, ObservableObject {
}

extension Downloader: URLSessionDelegate, URLSessionDownloadDelegate {
func urlSession(_: URLSession, downloadTask: URLSessionDownloadTask, didWriteData _: Int64, totalBytesWritten _: Int64, totalBytesExpectedToWrite _: Int64) {
downloadState.value = .downloading(downloadTask.progress.fractionCompleted)
func urlSession(_: URLSession, downloadTask: URLSessionDownloadTask, didWriteData _: Int64, totalBytesWritten: Int64, totalBytesExpectedToWrite: Int64) {
downloadState.value = .downloading(Double(totalBytesWritten) / Double(totalBytesExpectedToWrite))
}

func urlSession(_: URLSession, downloadTask _: URLSessionDownloadTask, didFinishDownloadingTo location: URL) {
Expand Down
39 changes: 35 additions & 4 deletions Diffusion/Common/ModelInfo.swift
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,10 @@ struct ModelInfo {

/// Suffix of the archive containing the SPLIT_EINSUM_V2 attention variant. Usually something like "split_einsum_v2_compiled"
let splitAttentionV2Suffix: String


/// Whether the archive contains ANE optimized models
let supportsNeuralEngine: Bool

/// Whether the archive contains the VAE Encoder (for image to image tasks). Not yet in use.
let supportsEncoder: Bool

Expand All @@ -46,25 +49,33 @@ struct ModelInfo {
/// Whether this is a Stable Diffusion XL model
// TODO: retrieve from remote config
let isXL: Bool


/// Whether this is a Stable Diffusion 3 model
// TODO: retrieve from remote config
let isSD3: Bool

//TODO: refactor all these properties
init(modelId: String, modelVersion: String,
originalAttentionSuffix: String = "original_compiled",
splitAttentionSuffix: String = "split_einsum_compiled",
splitAttentionV2Suffix: String = "split_einsum_v2_compiled",
supportsNeuralEngine: Bool = true,
supportsEncoder: Bool = false,
supportsAttentionV2: Bool = false,
quantized: Bool = false,
isXL: Bool = false) {
isXL: Bool = false,
isSD3: Bool = false) {
self.modelId = modelId
self.modelVersion = modelVersion
self.originalAttentionSuffix = originalAttentionSuffix
self.splitAttentionSuffix = splitAttentionSuffix
self.splitAttentionV2Suffix = splitAttentionV2Suffix
self.supportsNeuralEngine = supportsNeuralEngine
self.supportsEncoder = supportsEncoder
self.supportsAttentionV2 = supportsAttentionV2
self.quantized = quantized
self.isXL = isXL
self.isSD3 = isSD3
}
}

Expand Down Expand Up @@ -202,6 +213,24 @@ extension ModelInfo {
isXL: true
)

static let sd3 = ModelInfo(
modelId: "argmaxinc/coreml-stable-diffusion-3-medium",
modelVersion: "SD3 medium (512, macOS)",
supportsNeuralEngine: false, // TODO: support SD3 on ANE
supportsEncoder: false,
quantized: false,
isSD3: true
)

static let sd3highres = ModelInfo(
modelId: "argmaxinc/coreml-stable-diffusion-3-medium-1024-t5",
modelVersion: "SD3 medium (1024, T5, macOS)",
supportsNeuralEngine: false, // TODO: support SD3 on ANE
supportsEncoder: false,
quantized: false,
isSD3: true
)

static let MODELS: [ModelInfo] = {
if deviceSupportsQuantization {
var models = [
Expand All @@ -218,7 +247,9 @@ extension ModelInfo {
models.append(contentsOf: [
ModelInfo.xl,
ModelInfo.xlWithRefiner,
ModelInfo.xlmbp
ModelInfo.xlmbp,
ModelInfo.sd3,
ModelInfo.sd3highres,
])
} else {
models.append(ModelInfo.xlmbpChunked)
Expand Down
Loading