Skip to content

Commit

Permalink
refactor: download LoraWeights in FluxConfiguration
Browse files Browse the repository at this point in the history
  • Loading branch information
mzbac committed Oct 14, 2024
1 parent e37ecae commit e78b164
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 11 deletions.
9 changes: 4 additions & 5 deletions Sources/FLUX.swift
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,14 @@ import Tokenizers

open class FLUX {

internal func loadLoraWeights(hub: HubApi, loraPath: String, dType: DType) async throws
internal func loadLoraWeights(hub: HubApi, loraPath: String, dType: DType) throws
-> [String: MLXArray]
{
let loraDirectory: URL
if FileManager.default.fileExists(atPath: loraPath) {
loraDirectory = URL(fileURLWithPath: loraPath)
} else {
let repo = Hub.Repo(id: loraPath)
try await hub.snapshot(from: repo, matching: ["*.safetensors"])
loraDirectory = hub.localRepoLocation(repo)
}

Expand Down Expand Up @@ -144,7 +143,7 @@ open class FLUX {
}
}

public class Flux1Schnell: FLUX, TextToImageGenerator {
public class Flux1Schnell: FLUX, TextToImageGenerator, @unchecked Sendable {
let clipTokenizer: CLIPTokenizer
let t5Tokenizer: any Tokenizer
let transformer: MultiModalDiffusionTransformer
Expand Down Expand Up @@ -240,7 +239,7 @@ public class Flux1Schnell: FLUX, TextToImageGenerator {
}
}

public class Flux1Dev: FLUX, TextToImageGenerator {
public class Flux1Dev: FLUX, TextToImageGenerator, @unchecked Sendable {
let clipTokenizer: CLIPTokenizer
let t5Tokenizer: any Tokenizer
let transformer: MultiModalDiffusionTransformer
Expand Down Expand Up @@ -349,7 +348,7 @@ public protocol ImageGenerator {
func decode(xt: MLXArray) -> MLXArray
}

public protocol TextToImageGenerator: ImageGenerator {
public protocol TextToImageGenerator: ImageGenerator, Sendable {
func generateLatents(parameters: EvaluateParameters) -> DenoiseIterator
}

Expand Down
38 changes: 32 additions & 6 deletions Sources/FluxConfiguration.swift
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ func fuseLoraWeights(
if let loraA = loraWeight[loraAKey], let loraB = loraWeight[loraBKey],
let transformerWeight = fusedWeights[weightKey]
{
let loraScale: Float = 1.0
let loraScale: Float = 1.0
let loraFused = matmul(loraB, loraA)
fusedWeights[weightKey] = transformerWeight + loraScale * loraFused
}
Expand All @@ -116,7 +116,7 @@ public struct FluxConfiguration: Sendable {
let files: [FileKey: String]
public let defaultParameters: @Sendable () -> EvaluateParameters
let factory:
@Sendable (HubApi, FluxConfiguration, LoadConfiguration) async throws ->
@Sendable (HubApi, FluxConfiguration, LoadConfiguration) throws ->
FLUX

public func download(
Expand All @@ -127,10 +127,23 @@ public struct FluxConfiguration: Sendable {
from: repo, matching: Array(files.values), progressHandler: progressHandler)
}

public func downloadLoraWeights(
hub: HubApi = HubApi(), loadConfiguration: LoadConfiguration,
progressHandler: @escaping (Progress) -> Void = { _ in }
) async throws {
guard let loraPath = loadConfiguration.loraPath else {
throw FluxConfigurationError.missingLoraPath
}

let repo = Hub.Repo(id: loraPath)
try await hub.snapshot(
from: repo, matching: ["*.safetensors"], progressHandler: progressHandler)
}

public func textToImageGenerator(hub: HubApi = HubApi(), configuration: LoadConfiguration)
async throws -> TextToImageGenerator?
throws -> TextToImageGenerator?
{
try await factory(hub, self, configuration) as? TextToImageGenerator
try factory(hub, self, configuration) as? TextToImageGenerator
}

public static let flux1Schnell = FluxConfiguration(
Expand All @@ -149,7 +162,7 @@ public struct FluxConfiguration: Sendable {
hub: hub, configuration: fluxConfiguration, dType: loadConfiguration.dType)

if let loraPath = loadConfiguration.loraPath {
let loraWeight = try await flux.loadLoraWeights(
let loraWeight = try flux.loadLoraWeights(
hub: hub, loraPath: loraPath, dType: loadConfiguration.dType)

let weights = fuseLoraWeights(
Expand Down Expand Up @@ -191,7 +204,7 @@ public struct FluxConfiguration: Sendable {
hub: hub, configuration: fluxConfiguration, dType: loadConfiguration.dType)

if let loraPath = loadConfiguration.loraPath {
let loraWeight = try await flux.loadLoraWeights(
let loraWeight = try flux.loadLoraWeights(
hub: hub, loraPath: loraPath, dType: loadConfiguration.dType)

let weights = fuseLoraWeights(
Expand All @@ -217,3 +230,16 @@ public struct FluxConfiguration: Sendable {
}
)
}

enum FluxConfigurationError: Error {
case missingLoraPath
}

extension FluxConfigurationError: LocalizedError {
var errorDescription: String? {
switch self {
case .missingLoraPath:
return "LoRA path is missing. Please provide a valid LoRA path in the LoadConfiguration."
}
}
}

0 comments on commit e78b164

Please sign in to comment.