From 37e234e1ee4f2ef9c21c3d5cd592b5ec63d879af Mon Sep 17 00:00:00 2001
From: Arda Atahan Ibis
Date: Thu, 12 Dec 2024 21:59:52 +0300
Subject: [PATCH] Add get_hf_file_metadata Functionality (#142)

* add getHfFileMetadata function to HubApi

* only allow huggingface endpoints in getHfFileMetadata

* add test case for getHfFileMetadata

* remove hardcoded string from location check in test case

* rename getHfFileMetadata to getFileMetadata and refactor

* add blob search for file metadata

* Update Tests/HubTests/HubApiTests.swift

Co-authored-by: Pedro Cuenca

---------

Co-authored-by: Pedro Cuenca
---
 Sources/Hub/HubApi.swift         | 98 ++++++++++++++++++++++++++++++++
 Tests/HubTests/HubApiTests.swift | 57 +++++++++++++++++++
 2 files changed, 155 insertions(+)

diff --git a/Sources/Hub/HubApi.swift b/Sources/Hub/HubApi.swift
index 355f5e2..46da89a 100644
--- a/Sources/Hub/HubApi.swift
+++ b/Sources/Hub/HubApi.swift
@@ -60,6 +60,25 @@ public extension HubApi {
         return (data, response)
     }
 
+    func httpHead(for url: URL) async throws -> (Data, HTTPURLResponse) {
+        var request = URLRequest(url: url)
+        request.httpMethod = "HEAD"
+        if let hfToken = hfToken {
+            request.setValue("Bearer \(hfToken)", forHTTPHeaderField: "Authorization")
+        }
+        request.setValue("identity", forHTTPHeaderField: "Accept-Encoding")
+        let (data, response) = try await URLSession.shared.data(for: request)
+        guard let response = response as? HTTPURLResponse else { throw Hub.HubClientError.unexpectedError }
+
+        switch response.statusCode {
+        case 200..<300: break
+        case 400..<500: throw Hub.HubClientError.authorizationRequired
+        default: throw Hub.HubClientError.httpStatusCode(response.statusCode)
+        }
+
+        return (data, response)
+    }
+
     func getFilenames(from repo: Repo, matching globs: [String] = []) async throws -> [String] {
         // Read repo info and only parse "siblings"
         let url = URL(string: "\(endpoint)/api/\(repo.type)/\(repo.id)")!
@@ -222,6 +241,65 @@ public extension HubApi {
     }
 }
 
+/// Metadata
+public extension HubApi {
+    /// A structure representing metadata for a remote file
+    struct FileMetadata {
+        /// The file's Git commit hash
+        public let commitHash: String?
+
+        /// Server-provided ETag for caching
+        public let etag: String?
+
+        /// Stringified URL location of the file
+        public let location: String
+
+        /// The file's size in bytes
+        public let size: Int?
+    }
+
+    private func normalizeEtag(_ etag: String?) -> String? {
+        guard let etag = etag else { return nil }
+        return etag.trimmingPrefix("W/").trimmingCharacters(in: CharacterSet(charactersIn: "\""))
+    }
+
+    func getFileMetadata(url: URL) async throws -> FileMetadata {
+        let (_, response) = try await httpHead(for: url)
+
+        return FileMetadata(
+            commitHash: response.value(forHTTPHeaderField: "X-Repo-Commit"),
+            etag: normalizeEtag(
+                (response.value(forHTTPHeaderField: "X-Linked-Etag")) ?? (response.value(forHTTPHeaderField: "Etag"))
+            ),
+            location: (response.value(forHTTPHeaderField: "Location")) ?? url.absoluteString,
+            size: Int(response.value(forHTTPHeaderField: "X-Linked-Size") ?? response.value(forHTTPHeaderField: "Content-Length") ?? "")
+        )
+    }
+
+    func getFileMetadata(from repo: Repo, matching globs: [String] = []) async throws -> [FileMetadata] {
+        let files = try await getFilenames(from: repo, matching: globs)
+        let url = URL(string: "\(endpoint)/\(repo.id)/resolve/main")! // TODO: revisions
+        var selectedMetadata: [FileMetadata] = []
+        for file in files {
+            let fileURL = url.appending(path: file)
+            selectedMetadata.append(try await getFileMetadata(url: fileURL))
+        }
+        return selectedMetadata
+    }
+
+    func getFileMetadata(from repoId: String, matching globs: [String] = []) async throws -> [FileMetadata] {
+        return try await getFileMetadata(from: Repo(id: repoId), matching: globs)
+    }
+
+    func getFileMetadata(from repo: Repo, matching glob: String) async throws -> [FileMetadata] {
+        return try await getFileMetadata(from: repo, matching: [glob])
+    }
+
+    func getFileMetadata(from repoId: String, matching glob: String) async throws -> [FileMetadata] {
+        return try await getFileMetadata(from: Repo(id: repoId), matching: [glob])
+    }
+}
+
 /// Stateless wrappers that use `HubApi` instances
 public extension Hub {
     static func getFilenames(from repo: Hub.Repo, matching globs: [String] = []) async throws -> [String] {
@@ -259,6 +337,26 @@ public extension Hub {
     static func whoami(token: String) async throws -> Config {
         return try await HubApi(hfToken: token).whoami()
     }
+
+    static func getFileMetadata(fileURL: URL) async throws -> HubApi.FileMetadata {
+        return try await HubApi.shared.getFileMetadata(url: fileURL)
+    }
+
+    static func getFileMetadata(from repo: Repo, matching globs: [String] = []) async throws -> [HubApi.FileMetadata] {
+        return try await HubApi.shared.getFileMetadata(from: repo, matching: globs)
+    }
+
+    static func getFileMetadata(from repoId: String, matching globs: [String] = []) async throws -> [HubApi.FileMetadata] {
+        return try await HubApi.shared.getFileMetadata(from: Repo(id: repoId), matching: globs)
+    }
+
+    static func getFileMetadata(from repo: Repo, matching glob: String) async throws -> [HubApi.FileMetadata] {
+        return try await HubApi.shared.getFileMetadata(from: repo, matching: [glob])
+    }
+
+    static func getFileMetadata(from repoId: String, matching glob: String) async throws -> [HubApi.FileMetadata] {
+        return try await HubApi.shared.getFileMetadata(from: Repo(id: repoId), matching: [glob])
+    }
 }
 
 public extension [String] {
diff --git a/Tests/HubTests/HubApiTests.swift b/Tests/HubTests/HubApiTests.swift
index 4ab5d04..49eb930 100644
--- a/Tests/HubTests/HubApiTests.swift
+++ b/Tests/HubTests/HubApiTests.swift
@@ -87,6 +87,63 @@ class HubApiTests: XCTestCase {
             XCTFail("\(error)")
         }
     }
+
+    func testGetFileMetadata() async throws {
+        do {
+            let url = URL(string: "https://huggingface.co/coreml-projects/Llama-2-7b-chat-coreml/resolve/main/config.json")
+            let metadata = try await Hub.getFileMetadata(fileURL: url!)
+
+            XCTAssertNotNil(metadata.commitHash)
+            XCTAssertNotNil(metadata.etag)
+            XCTAssertEqual(metadata.location, url?.absoluteString)
+            XCTAssertEqual(metadata.size, 163)
+        } catch {
+            XCTFail("\(error)")
+        }
+    }
+
+    func testGetFileMetadataBlobPath() async throws {
+        do {
+            let url = URL(string: "https://huggingface.co/coreml-projects/Llama-2-7b-chat-coreml/blob/main/config.json")
+            let metadata = try await Hub.getFileMetadata(fileURL: url!)
+
+            XCTAssertEqual(metadata.commitHash, nil)
+            XCTAssertTrue(metadata.etag != nil && metadata.etag!.hasPrefix("10841-"))
+            XCTAssertEqual(metadata.location, url?.absoluteString)
+            XCTAssertEqual(metadata.size, 67649)
+        } catch {
+            XCTFail("\(error)")
+        }
+    }
+
+    func testGetFileMetadataWithRevision() async throws {
+        do {
+            let revision = "f2c752cfc5c0ab6f4bdec59acea69eefbee381c2"
+            let url = URL(string: "https://huggingface.co/julien-c/dummy-unknown/resolve/\(revision)/config.json")
+            let metadata = try await Hub.getFileMetadata(fileURL: url!)
+
+            XCTAssertEqual(metadata.commitHash, revision)
+            XCTAssertNotNil(metadata.etag)
+            XCTAssertGreaterThan(metadata.etag!.count, 0)
+            XCTAssertEqual(metadata.location, url?.absoluteString)
+            XCTAssertEqual(metadata.size, 851)
+        } catch {
+            XCTFail("\(error)")
+        }
+    }
+
+    func testGetFileMetadataWithBlobSearch() async throws {
+        let repo = "coreml-projects/Llama-2-7b-chat-coreml"
+        let metadataFromBlob = try await Hub.getFileMetadata(from: repo, matching: "*.json").sorted { $0.location < $1.location }
+        let files = try await Hub.getFilenames(from: repo, matching: "*.json").sorted()
+        for (metadata, file) in zip(metadataFromBlob, files) {
+            XCTAssertNotNil(metadata.commitHash)
+            XCTAssertNotNil(metadata.etag)
+            XCTAssertGreaterThan(metadata.etag!.count, 0)
+            XCTAssertTrue(metadata.location.contains(file))
+            XCTAssertGreaterThan(metadata.size!, 0)
+        }
+    }
 }
 
 class SnapshotDownloadTests: XCTestCase {