From 800353c1bc1c393f51bfba0d1eba6ce5f636bb65 Mon Sep 17 00:00:00 2001 From: Vinayak Vanjari Date: Mon, 25 Mar 2024 21:38:50 +0400 Subject: [PATCH] PR feedback & minor optimisations --- Sources/Hub/Hub.swift | 20 +++++--------------- Sources/Hub/HubApi.swift | 12 ++++++------ 2 files changed, 11 insertions(+), 21 deletions(-) diff --git a/Sources/Hub/Hub.swift b/Sources/Hub/Hub.swift index 8a668ea..8deeee8 100644 --- a/Sources/Hub/Hub.swift +++ b/Sources/Hub/Hub.swift @@ -179,19 +179,9 @@ public class LanguageModelConfigurationFromHub { ) async throws -> Configurations { let filesToDownload = ["config.json", "tokenizer_config.json", "tokenizer.json"] let repo = Hub.Repo(id: modelName) - try await hubApi.snapshot(from: repo, matching: filesToDownload) + let downloadedModelFolder = try await hubApi.snapshot(from: repo, matching: filesToDownload) - // Note tokenizerConfig may be nil (does not exist in all models) - let modelConfig = try hubApi.configuration(from: "config.json", in: repo) - let tokenizerConfig = try? hubApi.configuration(from: "tokenizer_config.json", in: repo) - let tokenizerVocab = try hubApi.configuration(from: "tokenizer.json", in: repo) - - let configs = Configurations( - modelConfig: modelConfig, - tokenizerConfig: tokenizerConfig, - tokenizerData: tokenizerVocab - ) - return configs + return try await loadConfig(modelFolder: downloadedModelFolder, hubApi: hubApi) } func loadConfig( @@ -199,9 +189,9 @@ public class LanguageModelConfigurationFromHub { hubApi: HubApi = .shared ) async throws -> Configurations { // Note tokenizerConfig may be nil (does not exist in all models) - let modelConfig = try hubApi.configuration(url: modelFolder.appending(path: "config.json")) - let tokenizerConfig = try? hubApi.configuration(url: modelFolder.appending(path: "tokenizer_config.json")) - let tokenizerVocab = try hubApi.configuration(url: modelFolder.appending(path: "tokenizer.json")) + let modelConfig = try hubApi.configuration(fileURL: modelFolder.appending(path: "config.json")) + let tokenizerConfig = try? hubApi.configuration(fileURL: modelFolder.appending(path: "tokenizer_config.json")) + let tokenizerVocab = try hubApi.configuration(fileURL: modelFolder.appending(path: "tokenizer.json")) let configs = Configurations( modelConfig: modelConfig, diff --git a/Sources/Hub/HubApi.swift b/Sources/Hub/HubApi.swift index 2516ef0..0d789da 100644 --- a/Sources/Hub/HubApi.swift +++ b/Sources/Hub/HubApi.swift @@ -93,14 +93,14 @@ public extension HubApi { /// Assumes the file has already been downloaded. /// `filename` is relative to the download base. func configuration(from filename: String, in repo: Repo) throws -> Config { - let url = localRepoLocation(repo).appending(path: filename) - return try configuration(url: url) + let fileURL = localRepoLocation(repo).appending(path: filename) + return try configuration(fileURL: fileURL) } - /// Assumes the file has already present at local url. - /// `url` is complete local file path for given model - func configuration(url: URL) throws -> Config { - let data = try Data(contentsOf: url) + /// Assumes the file is already present at local url. + /// `fileURL` is a complete local file path for the given model + func configuration(fileURL: URL) throws -> Config { + let data = try Data(contentsOf: fileURL) let parsed = try JSONSerialization.jsonObject(with: data, options: []) guard let dictionary = parsed as? [String: Any] else { throw Hub.HubClientError.parse } return Config(dictionary)