Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions Package.swift
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,15 @@ let package = Package(
.library(name: "Transformers", targets: ["Tokenizers", "Generation", "Models"])
],
dependencies: [
.package(url: "https://github.com/johnmai-dev/Jinja", .upToNextMinor(from: "1.3.0"))
.package(url: "https://github.com/huggingface/swift-jinja.git", from: "2.0.0")
],
targets: [
.target(name: "Generation", dependencies: ["Tokenizers"]),
.target(name: "Hub", resources: [.process("Resources")], swiftSettings: swiftSettings),
.target(name: "Hub", dependencies: [.product(name: "Jinja", package: "swift-jinja")], resources: [.process("Resources")], swiftSettings: swiftSettings),
.target(name: "Models", dependencies: ["Tokenizers", "Generation"]),
.target(name: "Tokenizers", dependencies: ["Hub", .product(name: "Jinja", package: "Jinja")]),
.target(name: "Tokenizers", dependencies: ["Hub", .product(name: "Jinja", package: "swift-jinja")]),
.testTarget(name: "GenerationTests", dependencies: ["Generation"]),
.testTarget(name: "HubTests", dependencies: ["Hub", .product(name: "Jinja", package: "Jinja")], swiftSettings: swiftSettings),
.testTarget(name: "HubTests", dependencies: ["Hub", .product(name: "Jinja", package: "swift-jinja")], swiftSettings: swiftSettings),
.testTarget(name: "ModelsTests", dependencies: ["Models", "Hub"], resources: [.process("Resources")]),
.testTarget(name: "TokenizersTests", dependencies: ["Tokenizers", "Models", "Hub"], resources: [.process("Resources")]),
]
Expand Down
23 changes: 12 additions & 11 deletions Sources/Hub/Config.swift
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
// Created by Piotr Kowalczuk on 06.03.25.

import Foundation
import Jinja

// MARK: - Configuration files with dynamic lookup

Expand Down Expand Up @@ -433,28 +434,28 @@ public struct Config: Hashable, Sendable,
self.dictionary(or: or)
}

public func toJinjaCompatible() -> Any? {
public func jinjaValue() -> Jinja.Value {
switch self.value {
case let .array(val):
return val.map { $0.toJinjaCompatible() }
return .array(val.map { $0.jinjaValue() })
case let .dictionary(val):
var result: [String: Any?] = [:]
var result: [String: Jinja.Value] = [:]
for (key, config) in val {
result[key.string] = config.toJinjaCompatible()
result[key.string] = config.jinjaValue()
}
return result
return .object(.init(uniqueKeysWithValues: result))
case let .boolean(val):
return val
return .boolean(val)
case let .floating(val):
return val
return .double(Double(String(val)) ?? Double(val))
case let .integer(val):
return val
return .int(val)
case let .string(val):
return val.string
return .string(val.string)
case let .token(val):
return [String(val.0): val.1.string] as [String: String]
return [String(val.0): .string(val.1.string)]
case .null:
return nil
return .null
}
}

Expand Down
20 changes: 11 additions & 9 deletions Sources/Tokenizers/Tokenizer.swift
Original file line number Diff line number Diff line change
Expand Up @@ -769,32 +769,34 @@ public class PreTrainedTokenizer: Tokenizer {
}

let template = try compiledTemplate(for: selectedChatTemplate)
var context: [String: Any] = [
"messages": messages,
"add_generation_prompt": addGenerationPrompt,
var context: [String: Jinja.Value] = try [
"messages": .array(messages.map { try Value(any: $0) }),
"add_generation_prompt": .boolean(addGenerationPrompt),
]
if let tools {
context["tools"] = tools
context["tools"] = try .array(tools.map { try Value(any: $0) })
}
if let additionalContext {
// Additional keys and values to be added to the context provided to the prompt templating engine.
// For example, the app could set "tools_in_user_message" to false for Llama 3.1 and 3.2 if a system message is provided.
// The default value is true in the Llama 3.1 and 3.2 chat templates, but these models will perform better if the tools are included in a system message.
for (key, value) in additionalContext {
context[key] = value
context[key] = try Value(any: value)
}
}

for (key, value) in tokenizerConfig.dictionary(or: [:]) {
if specialTokenAttributes.contains(key.string), !value.isNull() {
if let stringValue = value.string() {
context[key.string] = stringValue
context[key.string] = .string(stringValue)
} else if let dictionary = value.dictionary() {
context[key.string] = addedTokenAsString(Config(dictionary))
if let addedTokenString = addedTokenAsString(Config(dictionary)) {
context[key.string] = .string(addedTokenString)
}
} else if let array: [String] = value.get() {
context[key.string] = array
context[key.string] = .array(array.map { .string($0) })
} else {
context[key.string] = value
context[key.string] = try Value(any: value)
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion Tests/HubTests/ConfigTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -435,7 +435,7 @@ struct ConfigTests {
"""

let got = try Template(template).render([
"config": cfg.toJinjaCompatible()
"config": cfg.jinjaValue()
])

#expect(got == exp)
Expand Down
1 change: 1 addition & 0 deletions Tests/TokenizersTests/ChatTemplateTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,7 @@ struct ChatTemplateTests {
What is the weather in Paris today?<|im_end|>
<|im_start|>assistant


"""

#expect(
Expand Down