Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support Dynamic ModelType #123

Merged
merged 2 commits into from
Sep 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -92,3 +92,4 @@ iOSInjectionProject/
# OS
.DS_Store

.idea
89 changes: 48 additions & 41 deletions Libraries/LLM/Configuration.swift
Original file line number Diff line number Diff line change
Expand Up @@ -26,57 +26,64 @@ public enum StringOrNumber: Codable, Equatable, Sendable {
}
}

public enum ModelType: String, Codable, Sendable {
case mistral
case llama
case phi
case phi3
case gemma
case gemma2
case qwen2
case starcoder2
case cohere
case openelm
/// Identifies an LLM architecture and knows how to build a model from its
/// JSON configuration file.
///
/// `RawRepresentable` (rather than a closed `enum`) so that clients can add
/// support for new architectures at runtime via ``registerModelType(_:creator:)``.
public struct ModelType: RawRepresentable, Codable, Sendable {

    /// The model-type identifier as it appears in `config.json`, e.g. `"llama"`.
    public let rawValue: String

    public init(rawValue: String) {
        self.rawValue = rawValue
    }

    /// Decodes a Llama-style configuration — shared by `"mistral"` and
    /// `"llama"`, which use the same configuration schema.
    private static func createLlamaModel(url: URL) throws -> LLMModel {
        let configuration = try JSONDecoder().decode(
            LlamaConfiguration.self, from: Data(contentsOf: url))
        return LlamaModel(configuration)
    }

    /// Guards `creators`: `registerModelType(_:creator:)` is public API and may
    /// be called from any thread while `createModel(configuration:)` reads the
    /// registry, so all access goes through this lock.
    private static let creatorsLock = NSLock()

    /// Registry mapping a model-type identifier to a factory closure that
    /// decodes the configuration at the given URL and builds the model.
    /// Access only while holding `creatorsLock`.
    private static var creators: [String: (URL) throws -> LLMModel] = [
        "mistral": createLlamaModel,
        "llama": createLlamaModel,
        "phi": { url in
            let configuration = try JSONDecoder().decode(
                PhiConfiguration.self, from: Data(contentsOf: url))
            return PhiModel(configuration)
        },
        "phi3": { url in
            let configuration = try JSONDecoder().decode(
                Phi3Configuration.self, from: Data(contentsOf: url))
            return Phi3Model(configuration)
        },
        "gemma": { url in
            let configuration = try JSONDecoder().decode(
                GemmaConfiguration.self, from: Data(contentsOf: url))
            return GemmaModel(configuration)
        },
        "gemma2": { url in
            let configuration = try JSONDecoder().decode(
                Gemma2Configuration.self, from: Data(contentsOf: url))
            return Gemma2Model(configuration)
        },
        "qwen2": { url in
            let configuration = try JSONDecoder().decode(
                Qwen2Configuration.self, from: Data(contentsOf: url))
            return Qwen2Model(configuration)
        },
        "starcoder2": { url in
            let configuration = try JSONDecoder().decode(
                Starcoder2Configuration.self, from: Data(contentsOf: url))
            return Starcoder2Model(configuration)
        },
        "cohere": { url in
            let configuration = try JSONDecoder().decode(
                CohereConfiguration.self, from: Data(contentsOf: url))
            return CohereModel(configuration)
        },
        "openelm": { url in
            let configuration = try JSONDecoder().decode(
                OpenElmConfiguration.self, from: Data(contentsOf: url))
            return OpenELMModel(configuration)
        },
    ]

    /// Registers (or replaces) the factory for a model-type identifier,
    /// allowing clients to support architectures not built into the library.
    /// - Parameters:
    ///   - type: The identifier as it appears in `config.json`.
    ///   - creator: Builds an `LLMModel` from the configuration file URL.
    public static func registerModelType(
        _ type: String, creator: @escaping (URL) throws -> LLMModel
    ) {
        creatorsLock.lock()
        defer { creatorsLock.unlock() }
        creators[type] = creator
    }

    /// Builds the model described by the configuration file at `configuration`.
    /// - Parameter configuration: URL of the model's JSON configuration.
    /// - Throws: `LLMError` if no creator is registered for this type, or any
    ///   error from reading/decoding the configuration.
    public func createModel(configuration: URL) throws -> LLMModel {
        ModelType.creatorsLock.lock()
        let creator = ModelType.creators[rawValue]
        ModelType.creatorsLock.unlock()
        guard let creator else {
            throw LLMError(message: "Unsupported model type.")
        }
        return try creator(configuration)
    }
}

Expand Down
2 changes: 1 addition & 1 deletion Libraries/LLM/Tokenizer.swift
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ private func updateTokenizerConfig(_ tokenizerConfig: Config) -> Config {
}

/// Overrides for TokenizerModel/knownTokenizers: maps tokenizer class names
/// that have no direct implementation to a compatible replacement.
/// Public and mutable so clients can register additional replacements for
/// dynamically added model types.
/// NOTE(review): unsynchronized global mutable state — mutate before any
/// tokenizer is loaded, or confirm single-threaded access.
public var replacementTokenizers: [String: String] = [
    "Qwen2Tokenizer": "PreTrainedTokenizer",
    "CohereTokenizer": "PreTrainedTokenizer",
]
Expand Down