Skip to content

Commit

Permalink
Support loading tokenizer from local folder #76 (#81)
Browse files Browse the repository at this point in the history
* Support loading tokenizer from local folder #76

* PR feedback & minor optimisations
  • Loading branch information
vinu-vanjari authored Mar 25, 2024
1 parent 508c540 commit 18b62e5
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 6 deletions.
24 changes: 20 additions & 4 deletions Sources/Hub/Hub.swift
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,15 @@ public class LanguageModelConfigurationFromHub {
return try await self.loadConfig(modelName: modelName, hubApi: hubApi)
}
}

/// Creates a configuration loader that reads its files from a local folder.
///
/// Loading is kicked off immediately in an unstructured task; the task handle
/// is stored in `configPromise`.
///
/// - Parameters:
///   - modelFolder: Local folder handed to `loadConfig(modelFolder:hubApi:)`.
///   - hubApi: Hub client forwarded to the loader; defaults to `.shared`.
public init(modelFolder: URL, hubApi: HubApi = .shared) {
    configPromise = Task {
        try await self.loadConfig(modelFolder: modelFolder, hubApi: hubApi)
    }
}

public var modelConfig: Config {
get async throws {
Expand Down Expand Up @@ -170,12 +179,19 @@ public class LanguageModelConfigurationFromHub {
) async throws -> Configurations {
let filesToDownload = ["config.json", "tokenizer_config.json", "tokenizer.json"]
let repo = Hub.Repo(id: modelName)
try await hubApi.snapshot(from: repo, matching: filesToDownload)
let downloadedModelFolder = try await hubApi.snapshot(from: repo, matching: filesToDownload)

return try await loadConfig(modelFolder: downloadedModelFolder, hubApi: hubApi)
}

func loadConfig(
modelFolder: URL,
hubApi: HubApi = .shared
) async throws -> Configurations {
// Note: tokenizerConfig may be nil (tokenizer_config.json does not exist in all models)
let modelConfig = try hubApi.configuration(from: "config.json", in: repo)
let tokenizerConfig = try? hubApi.configuration(from: "tokenizer_config.json", in: repo)
let tokenizerVocab = try hubApi.configuration(from: "tokenizer.json", in: repo)
let modelConfig = try hubApi.configuration(fileURL: modelFolder.appending(path: "config.json"))
let tokenizerConfig = try? hubApi.configuration(fileURL: modelFolder.appending(path: "tokenizer_config.json"))
let tokenizerVocab = try hubApi.configuration(fileURL: modelFolder.appending(path: "tokenizer.json"))

let configs = Configurations(
modelConfig: modelConfig,
Expand Down
10 changes: 8 additions & 2 deletions Sources/Hub/HubApi.swift
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,14 @@ public extension HubApi {
/// Loads the configuration stored in `filename` inside the local copy of `repo`.
///
/// Assumes the file has already been downloaded.
///
/// - Parameters:
///   - filename: File path relative to the download base.
///   - repo: Repository whose local location (`localRepoLocation(repo)`) serves as the base.
/// - Returns: The configuration produced by `configuration(fileURL:)`.
/// - Throws: Any error thrown by `configuration(fileURL:)`.
func configuration(from filename: String, in repo: Repo) throws -> Config {
    // Only resolve the URL here: `configuration(fileURL:)` performs the single
    // read of the file, so no separate `Data(contentsOf:)` call is needed.
    let fileURL = localRepoLocation(repo).appending(path: filename)
    return try configuration(fileURL: fileURL)
}

/// Assumes the file is already present at the given local URL.
/// `fileURL` is the complete local path of a configuration file for the given model.
func configuration(fileURL: URL) throws -> Config {
let data = try Data(contentsOf: fileURL)
let parsed = try JSONSerialization.jsonObject(with: data, options: [])
guard let dictionary = parsed as? [String: Any] else { throw Hub.HubClientError.parse }
return Config(dictionary)
Expand Down
11 changes: 11 additions & 0 deletions Sources/Tokenizers/Tokenizer.swift
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,17 @@ extension AutoTokenizer {

return try PreTrainedTokenizer(tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData)
}

/// Builds a tokenizer from configuration files stored in a local model folder.
///
/// - Parameters:
///   - modelFolder: Folder passed to `LanguageModelConfigurationFromHub(modelFolder:hubApi:)`.
///   - hubApi: Hub client forwarded to the configuration loader; defaults to `.shared`.
/// - Returns: A `PreTrainedTokenizer` built from the loaded tokenizer configuration and data.
/// - Throws: `TokenizerError.missingConfig` when no tokenizer configuration is available,
///   plus any error raised while loading the configuration or constructing the tokenizer.
public static func from(modelFolder: URL, hubApi: HubApi = .shared) async throws -> Tokenizer {
    let configuration = LanguageModelConfigurationFromHub(modelFolder: modelFolder, hubApi: hubApi)

    // The tokenizer configuration is optional; without it no tokenizer can be built.
    guard let loadedConfig = try await configuration.tokenizerConfig else {
        throw TokenizerError.missingConfig
    }
    let loadedData = try await configuration.tokenizerData

    return try PreTrainedTokenizer(tokenizerConfig: loadedConfig, tokenizerData: loadedData)
}
}

// MARK: - Tokenizer model classes
Expand Down
20 changes: 20 additions & 0 deletions Tests/TokenizersTests/FactoryTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -42,4 +42,24 @@ class FactoryTests: TestWithCustomHubDownloadLocation {
let inputIds = tokenizer("Today she took a train to the West")
XCTAssertEqual(inputIds, [50258, 50363, 27676, 750, 1890, 257, 3847, 281, 264, 4055, 50257])
}

/// Downloads the tokenizer files for the Llama 2 chat repo, then checks that a
/// tokenizer built from the resulting local folder produces the expected ids.
func testFromModelFolder() async throws {
    // Fetch a snapshot first so there is a local folder to load from.
    let modelFolder = try await hubApi.snapshot(
        from: Hub.Repo(id: "coreml-projects/Llama-2-7b-chat-coreml"),
        matching: ["config.json", "tokenizer_config.json", "tokenizer.json"]
    )

    let tokenizer = try await AutoTokenizer.from(modelFolder: modelFolder, hubApi: hubApi)

    let expectedIds = [1, 20628, 1183, 3614, 263, 7945, 304, 278, 3122]
    XCTAssertEqual(tokenizer("Today she took a train to the West"), expectedIds)
}

/// Downloads the tokenizer files for Whisper large v2, then checks that a
/// tokenizer built from the resulting local folder produces the expected ids.
func testWhisperFromModelFolder() async throws {
    // Fetch a snapshot first so there is a local folder to load from.
    let modelFolder = try await hubApi.snapshot(
        from: Hub.Repo(id: "openai/whisper-large-v2"),
        matching: ["config.json", "tokenizer_config.json", "tokenizer.json"]
    )

    let tokenizer = try await AutoTokenizer.from(modelFolder: modelFolder, hubApi: hubApi)

    let expectedIds = [50258, 50363, 27676, 750, 1890, 257, 3847, 281, 264, 4055, 50257]
    XCTAssertEqual(tokenizer("Today she took a train to the West"), expectedIds)
}
}

0 comments on commit 18b62e5

Please sign in to comment.