diff --git a/FirebaseAI/Sources/GenerationConfig.swift b/FirebaseAI/Sources/GenerationConfig.swift index 27c4310f12d..fe2b6963e22 100644 --- a/FirebaseAI/Sources/GenerationConfig.swift +++ b/FirebaseAI/Sources/GenerationConfig.swift @@ -48,6 +48,11 @@ public struct GenerationConfig: Sendable { /// Output schema of the generated candidate text. let responseSchema: Schema? + /// Output schema of the generated response in [JSON Schema](https://json-schema.org/) format. + /// + /// If set, `responseSchema` must be omitted and `responseMIMEType` is required. + let responseJSONSchema: JSONObject? + /// Supported modalities of the response. let responseModalities: [ResponseModality]? @@ -175,6 +180,26 @@ public struct GenerationConfig: Sendable { self.stopSequences = stopSequences self.responseMIMEType = responseMIMEType self.responseSchema = responseSchema + responseJSONSchema = nil + self.responseModalities = responseModalities + self.thinkingConfig = thinkingConfig + } + + init(temperature: Float? = nil, topP: Float? = nil, topK: Int? = nil, candidateCount: Int? = nil, + maxOutputTokens: Int? = nil, presencePenalty: Float? = nil, frequencyPenalty: Float? = nil, + stopSequences: [String]? = nil, responseMIMEType: String, responseJSONSchema: JSONObject, + responseModalities: [ResponseModality]? = nil, thinkingConfig: ThinkingConfig? = nil) { + self.temperature = temperature + self.topP = topP + self.topK = topK + self.candidateCount = candidateCount + self.maxOutputTokens = maxOutputTokens + self.presencePenalty = presencePenalty + self.frequencyPenalty = frequencyPenalty + self.stopSequences = stopSequences + self.responseMIMEType = responseMIMEType + responseSchema = nil + self.responseJSONSchema = responseJSONSchema self.responseModalities = responseModalities self.thinkingConfig = thinkingConfig } @@ -195,6 +220,7 @@ extension GenerationConfig: Encodable { case stopSequences case responseMIMEType = "responseMimeType" case responseSchema + case responseJSONSchema = "responseJsonSchema" case responseModalities case thinkingConfig } diff --git a/FirebaseAI/Tests/TestApp/Tests/Integration/SchemaTests.swift b/FirebaseAI/Tests/TestApp/Tests/Integration/SchemaTests.swift index 4f4dd1e3dc8..a9e49818364 100644 --- a/FirebaseAI/Tests/TestApp/Tests/Integration/SchemaTests.swift +++ b/FirebaseAI/Tests/TestApp/Tests/Integration/SchemaTests.swift @@ -28,8 +28,6 @@ import Testing /// Test the schema fields. @Suite(.serialized) struct SchemaTests { - // Set temperature, topP and topK to lowest allowed values to make responses more deterministic. - let generationConfig = GenerationConfig(temperature: 0.0, topP: 0.0, topK: 1) let safetySettings = [ SafetySetting(harmCategory: .harassment, threshold: .blockLowAndAbove), SafetySetting(harmCategory: .hateSpeech, threshold: .blockLowAndAbove), @@ -37,31 +35,32 @@ struct SchemaTests { SafetySetting(harmCategory: .dangerousContent, threshold: .blockLowAndAbove), SafetySetting(harmCategory: .civicIntegrity, threshold: .blockLowAndAbove), ] - // Candidates and total token counts may differ slightly between runs due to whitespace tokens. - let tokenCountAccuracy = 1 - let storage: Storage - let userID1: String - - init() async throws { - userID1 = try await TestHelpers.getUserID() - storage = Storage.storage() - } - - @Test(arguments: InstanceConfig.allConfigs) - func generateContentSchemaItems(_ config: InstanceConfig) async throws { - let model = FirebaseAI.componentInstance(config).generativeModel( - modelName: ModelNames.gemini2FlashLite, - generationConfig: GenerationConfig( - responseMIMEType: "application/json", - responseSchema: - .array( - items: .string(description: "The name of the city"), - description: "A list of city names", - minItems: 3, - maxItems: 5 - ) + @Test( + arguments: testConfigs( + instanceConfigs: InstanceConfig.allConfigs, + openAPISchema: .array( + items: .string(description: "The name of the city"), + description: "A list of city names", + minItems: 3, + maxItems: 5 ), + jsonSchema: [ + "type": .string("array"), + "description": .string("A list of city names"), + "items": .object([ + "type": .string("string"), + "description": .string("The name of the city"), + ]), + "minItems": .number(3), + "maxItems": .number(5), + ] + ) + ) + func generateContentItemsSchema(_ config: InstanceConfig, _ schema: SchemaType) async throws { + let model = FirebaseAI.componentInstance(config).generativeModel( + modelName: ModelNames.gemini2_5_FlashLite, + generationConfig: SchemaTests.generationConfig(schema: schema), safetySettings: safetySettings ) let prompt = "What are the biggest cities in Canada?" @@ -73,18 +72,25 @@ struct SchemaTests { #expect(decodedJSON.count <= 5, "Expected at most 5 cities, but got \(decodedJSON.count)") } - @Test(arguments: InstanceConfig.allConfigs) - func generateContentSchemaNumberRange(_ config: InstanceConfig) async throws { + @Test(arguments: testConfigs( + instanceConfigs: InstanceConfig.allConfigs, + openAPISchema: .integer( + description: "A number", + minimum: 110, + maximum: 120 + ), + jsonSchema: [ + "type": .string("integer"), + "description": .string("A number"), + "minimum": .number(110), + "maximum": .number(120), + ] + )) + func generateContentSchemaNumberRange(_ config: InstanceConfig, + _ schema: SchemaType) async throws { let model = FirebaseAI.componentInstance(config).generativeModel( - modelName: ModelNames.gemini2FlashLite, - generationConfig: GenerationConfig( - responseMIMEType: "application/json", - responseSchema: .integer( - description: "A number", - minimum: 110, - maximum: 120 - ) - ), + modelName: ModelNames.gemini2_5_FlashLite, + generationConfig: SchemaTests.generationConfig(schema: schema), safetySettings: safetySettings ) let prompt = "Give me a number" @@ -96,41 +102,83 @@ struct SchemaTests { #expect(decodedNumber <= 120.0, "Expected a number <= 120, but got \(decodedNumber)") } - @Test(arguments: InstanceConfig.allConfigs) - func generateContentSchemaNumberRangeMultiType(_ config: InstanceConfig) async throws { + @Test(arguments: testConfigs( + instanceConfigs: InstanceConfig.allConfigs, + openAPISchema: .object( + properties: [ + "productName": .string(description: "The name of the product"), + "price": .double( + description: "A price", + minimum: 10.00, + maximum: 120.00 + ), + "salePrice": .float( + description: "A sale price", + minimum: 5.00, + maximum: 90.00 + ), + "rating": .integer( + description: "A rating", + minimum: 1, + maximum: 5 + ), + ], + propertyOrdering: ["salePrice", "rating", "price", "productName"], + title: "ProductInfo" + ), + jsonSchema: [ + "type": .string("object"), + "title": .string("ProductInfo"), + "properties": .object([ + "productName": .object([ + "type": .string("string"), + "description": .string("The name of the product"), + ]), + "price": .object([ + "type": .string("number"), + "description": .string("A price"), + "minimum": .number(10.00), + "maximum": .number(120.00), + ]), + "salePrice": .object([ + "type": .string("number"), + "description": .string("A sale price"), + "minimum": .number(5.00), + "maximum": .number(90.00), + ]), + "rating": .object([ + "type": .string("integer"), + "description": .string("A rating"), + "minimum": .number(1), + "maximum": .number(5), + ]), + ]), + "required": .array([ + .string("productName"), + .string("price"), + .string("salePrice"), + .string("rating"), + ]), + "propertyOrdering": .array([ + .string("salePrice"), + .string("rating"), + .string("price"), + .string("productName"), + ]), + ] + )) + func generateContentSchemaNumberRangeMultiType(_ config: InstanceConfig, + _ schema: SchemaType) async throws { struct ProductInfo: Codable { let productName: String - let rating: Int // Will correspond to .integer in schema - let price: Double // Will correspond to .double in schema - let salePrice: Float // Will correspond to .float in schema + let rating: Int + let price: Double + let salePrice: Float } + let model = FirebaseAI.componentInstance(config).generativeModel( - modelName: ModelNames.gemini2FlashLite, - generationConfig: GenerationConfig( - responseMIMEType: "application/json", - responseSchema: .object( - properties: [ - "productName": .string(description: "The name of the product"), - "price": .double( - description: "A price", - minimum: 10.00, - maximum: 120.00 - ), - "salePrice": .float( - description: "A sale price", - minimum: 5.00, - maximum: 90.00 - ), - "rating": .integer( - description: "A rating", - minimum: 1, - maximum: 5 - ), - ], - propertyOrdering: ["salePrice", "rating", "price", "productName"], - title: "ProductInfo" - ), - ), + modelName: ModelNames.gemini2_5_FlashLite, + generationConfig: SchemaTests.generationConfig(schema: schema), safetySettings: safetySettings ) let prompt = "Describe a premium wireless headphone, including a user rating and price." @@ -149,63 +197,127 @@ struct SchemaTests { #expect(rating <= 5, "Expected a rating <= 5, but got \(rating)") } - @Test(arguments: InstanceConfig.allConfigs) - func generateContentAnyOfSchema(_ config: InstanceConfig) async throws { - struct MailingAddress: Decodable { - let streetAddress: String - let city: String - - // Canadian-specific - let province: String? - let postalCode: String? - - // U.S.-specific - let state: String? - let zipCode: String? - - var isCanadian: Bool { - return province != nil && postalCode != nil && state == nil && zipCode == nil + fileprivate struct MailingAddress { + enum PostalInfo { + struct Canada: Decodable { + let province: String + let postalCode: String } - var isAmerican: Bool { - return province == nil && postalCode == nil && state != nil && zipCode != nil + struct UnitedStates: Decodable { + let state: String + let zipCode: String } + + case canada(province: String, postalCode: String) + case unitedStates(state: String, zipCode: String) } + let streetAddress: String + let city: String + let postalInfo: PostalInfo + } + + private static let generateContentAnyOfOpenAPISchema = { let streetSchema = Schema.string(description: "The civic number and street name, for example, '123 Main Street'.") let citySchema = Schema.string(description: "The name of the city.") - let canadianAddressSchema = Schema.object( + let canadaPostalInfoSchema = Schema.object( properties: [ - "streetAddress": streetSchema, - "city": citySchema, "province": .string(description: "The 2-letter province or territory code, for example, 'ON', 'QC', or 'NU'."), "postalCode": .string(description: "The postal code, for example, 'A1A 1A1'."), - ], - description: "A Canadian mailing address" + ] ) - let americanAddressSchema = Schema.object( + let unitedStatesPostalInfoSchema = Schema.object( properties: [ - "streetAddress": streetSchema, - "city": citySchema, "state": .string(description: "The 2-letter U.S. state or territory code, for example, 'CA', 'NY', or 'TX'."), "zipCode": .string(description: "The 5-digit ZIP code, for example, '12345'."), - ], - description: "A U.S. mailing address" + ] ) + let mailingAddressSchema = Schema.object(properties: [ + "streetAddress": streetSchema, + "city": citySchema, + "postalInfo": .anyOf(schemas: [canadaPostalInfoSchema, unitedStatesPostalInfoSchema]), + ]) + return Schema.array(items: mailingAddressSchema) + }() + + private static let generateContentAnyOfJSONSchema = { + let streetSchema: JSONValue = .object([ + "type": .string("string"), + "description": .string("The civic number and street name, for example, '123 Main Street'."), + ]) + let citySchema: JSONValue = .object([ + "type": .string("string"), + "description": .string("The name of the city."), + ]) + let postalInfoSchema: JSONValue = .object([ + "anyOf": .array([ + .object([ + "type": .string("object"), + "properties": .object([ + "province": .object([ + "type": .string("string"), + "description": .string( + "The 2-letter Canadian province or territory code, for example, 'ON', 'QC', or 'NU'." + ), + ]), + "postalCode": .object([ + "type": .string("string"), + "description": .string("The Canadian postal code, for example, 'A1A 1A1'."), + ]), + ]), + "required": .array([.string("province"), .string("postalCode")]), + ]), + .object([ + "type": .string("object"), + "properties": .object([ + "state": .object([ + "type": .string("string"), + "description": .string( + "The 2-letter U.S. state or territory code, for example, 'CA', 'NY', or 'TX'." + ), + ]), + "zipCode": .object([ + "type": .string("string"), + "description": .string("The 5-digit U.S. ZIP code, for example, '12345'."), + ]), + ]), + "required": .array([.string("state"), .string("zipCode")]), + ]), + ]), + ]) + let mailingAddressSchema: JSONObject = [ + "type": .string("object"), + "description": .string("A mailing address"), + "properties": .object([ + "streetAddress": streetSchema, + "city": citySchema, + "postalInfo": postalInfoSchema, + ]), + "required": .array([ + .string("streetAddress"), + .string("city"), + .string("postalInfo"), + ]), + ] + return [ + "type": .string("array"), + "items": .object(mailingAddressSchema), + ] as JSONObject + }() + + @Test(arguments: testConfigs( + instanceConfigs: InstanceConfig.allConfigs, + openAPISchema: generateContentAnyOfOpenAPISchema, + jsonSchema: generateContentAnyOfJSONSchema + )) + func generateContentAnyOfSchema(_ config: InstanceConfig, _ schema: SchemaType) async throws { let model = FirebaseAI.componentInstance(config).generativeModel( - modelName: ModelNames.gemini2Flash, - generationConfig: GenerationConfig( - temperature: 0.0, - topP: 0.0, - topK: 1, - responseMIMEType: "application/json", - responseSchema: .array(items: .anyOf( - schemas: [canadianAddressSchema, americanAddressSchema] - )) - ), + modelName: ModelNames.gemini2_5_Flash, + generationConfig: SchemaTests.generationConfig(schema: schema), safetySettings: safetySettings ) let prompt = """ @@ -217,19 +329,102 @@ struct SchemaTests { let decodedAddresses = try JSONDecoder().decode([MailingAddress].self, from: jsonData) try #require(decodedAddresses.count == 3, "Expected 3 JSON addresses, got \(text).") let waterlooAddress = decodedAddresses[0] - #expect( - waterlooAddress.isCanadian, - "Expected Canadian University of Waterloo address, got \(waterlooAddress)." - ) + #expect(waterlooAddress.city == "Waterloo") + if case let .canada(province, postalCode) = waterlooAddress.postalInfo { + #expect(province == "ON") + #expect(postalCode == "N2L 3G1") + } else { + Issue.record("Expected Canadian University of Waterloo address, got \(waterlooAddress).") + } let berkeleyAddress = decodedAddresses[1] - #expect( - berkeleyAddress.isAmerican, - "Expected American UC Berkeley address, got \(berkeleyAddress)." - ) + #expect(berkeleyAddress.city == "Berkeley") + if case let .unitedStates(state, zipCode) = berkeleyAddress.postalInfo { + #expect(state == "CA") + #expect(zipCode == "94720") + } else { + Issue.record("Expected American UC Berkeley address, got \(berkeleyAddress).") + } let queensAddress = decodedAddresses[2] - #expect( - queensAddress.isCanadian, - "Expected Canadian Queen's University address, got \(queensAddress)." + #expect(queensAddress.city == "Kingston") + if case let .canada(province, postalCode) = queensAddress.postalInfo { + #expect(province == "ON") + #expect(postalCode == "K7L 3N6") + } else { + Issue.record("Expected Canadian Queen's University address, got \(queensAddress).") + } + } + + enum SchemaType: CustomTestStringConvertible { + case openAPI(Schema) + case json(JSONObject) + + var testDescription: String { + switch self { + case .openAPI: + return "OpenAPI Schema" + case .json: + return "JSON Schema" + } + } + } + + private static func generationConfig(schema: SchemaType) -> GenerationConfig { + let mimeType = "application/json" + switch schema { + case let .openAPI(openAPISchema): + return GenerationConfig(temperature: 0.0, topP: 0.0, topK: 1, responseMIMEType: mimeType, + responseSchema: openAPISchema) + case let .json(jsonSchema): + return GenerationConfig(temperature: 0.0, topP: 0.0, topK: 1, responseMIMEType: mimeType, + responseJSONSchema: jsonSchema) + } + } + + private static func testConfigs(instanceConfigs: [InstanceConfig], openAPISchema: Schema, + jsonSchema: JSONObject) -> [(InstanceConfig, SchemaType)] { + return instanceConfigs.flatMap { [($0, .openAPI(openAPISchema)), ($0, .json(jsonSchema))] } + } +} + +extension SchemaTests.MailingAddress: Decodable { + enum CodingKeys: CodingKey { + case streetAddress + case city + case postalInfo + } + + init(from decoder: any Decoder) throws { + let container = try decoder.container(keyedBy: CodingKeys.self) + streetAddress = try container.decode(String.self, forKey: .streetAddress) + city = try container.decode(String.self, forKey: .city) + let canadaPostalInfo = try? container.decode(PostalInfo.Canada.self, forKey: .postalInfo) + let unitedStatesPostalInfo = try? container.decode( + PostalInfo.UnitedStates.self, forKey: .postalInfo ) + + if canadaPostalInfo != nil, unitedStatesPostalInfo != nil { + throw DecodingError.dataCorruptedError( + forKey: .postalInfo, + in: container, + debugDescription: "Ambiguous postal info: matches both Canadian and U.S. formats." + ) + } + + if let canadaPostalInfo { + postalInfo = .canada( + province: canadaPostalInfo.province, postalCode: canadaPostalInfo.postalCode + ) + } else if let unitedStatesPostalInfo { + postalInfo = .unitedStates( + state: unitedStatesPostalInfo.state, zipCode: unitedStatesPostalInfo.zipCode + ) + } else { + throw DecodingError.typeMismatch( + PostalInfo.self, .init( + codingPath: container.codingPath, + debugDescription: "Expected Canadian or U.S. postal info." + ) + ) + } } } diff --git a/FirebaseAI/Tests/Unit/GenerationConfigTests.swift b/FirebaseAI/Tests/Unit/GenerationConfigTests.swift index 2b38d1898d4..edbde87fc7d 100644 --- a/FirebaseAI/Tests/Unit/GenerationConfigTests.swift +++ b/FirebaseAI/Tests/Unit/GenerationConfigTests.swift @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -import FirebaseAILogic +@testable import FirebaseAILogic import Foundation import XCTest @@ -153,4 +153,85 @@ final class GenerationConfigTests: XCTestCase { } """) } + + func testEncodeGenerationConfig_responseJSONSchema() throws { + let mimeType = "application/json" + let responseJSONSchema: JSONObject = [ + "type": .string("object"), + "title": .string("Person"), + "properties": .object([ + "firstName": .object(["type": .string("string")]), + "middleNames": .object([ + "type": .string("array"), + "items": .object(["type": .string("string")]), + "minItems": .number(0), + "maxItems": .number(3), + ]), + "lastName": .object(["type": .string("string")]), + "age": .object(["type": .string("integer")]), + ]), + "required": .array([ + .string("firstName"), + .string("middleNames"), + .string("lastName"), + .string("age"), + ]), + "propertyOrdering": .array([ + .string("firstName"), + .string("middleNames"), + .string("lastName"), + .string("age"), + ]), + "additionalProperties": .bool(false), + ] + let generationConfig = GenerationConfig( + responseMIMEType: mimeType, + responseJSONSchema: responseJSONSchema + ) + + let jsonData = try encoder.encode(generationConfig) + + let json = try XCTUnwrap(String(data: jsonData, encoding: .utf8)) + XCTAssertEqual(json, """ + { + "responseJsonSchema" : { + "additionalProperties" : false, + "properties" : { + "age" : { + "type" : "integer" + }, + "firstName" : { + "type" : "string" + }, + "lastName" : { + "type" : "string" + }, + "middleNames" : { + "items" : { + "type" : "string" + }, + "maxItems" : 3, + "minItems" : 0, + "type" : "array" + } + }, + "propertyOrdering" : [ + "firstName", + "middleNames", + "lastName", + "age" + ], + "required" : [ + "firstName", + "middleNames", + "lastName", + "age" + ], + "title" : "Person", + "type" : "object" + }, + "responseMimeType" : "\(mimeType)" + } + """) + } }