diff --git a/packages/firebase_vertexai/firebase_vertexai/lib/src/api.dart b/packages/firebase_vertexai/firebase_vertexai/lib/src/api.dart
index 6697c972dc05..4c8990d39844 100644
--- a/packages/firebase_vertexai/firebase_vertexai/lib/src/api.dart
+++ b/packages/firebase_vertexai/firebase_vertexai/lib/src/api.dart
@@ -14,6 +14,7 @@
 import 'content.dart';
 import 'error.dart';
+import 'function_calling.dart' show Tool, ToolConfig;
 import 'schema.dart';

 /// Response for Count Tokens
@@ -155,6 +156,23 @@ final class UsageMetadata {
   final List? candidatesTokensDetails;
 }

+/// Constructs a UsageMetadata with all its fields.
+///
+/// Exposes access to the private constructor for use within the package.
+UsageMetadata createUsageMetadata({
+  required int? promptTokenCount,
+  required int? candidatesTokenCount,
+  required int? totalTokenCount,
+  required List? promptTokensDetails,
+  required List? candidatesTokensDetails,
+}) =>
+    UsageMetadata._(
+        promptTokenCount: promptTokenCount,
+        candidatesTokenCount: candidatesTokenCount,
+        totalTokenCount: totalTokenCount,
+        promptTokensDetails: promptTokensDetails,
+        candidatesTokensDetails: candidatesTokensDetails);
+
 /// Response candidate generated from a [GenerativeModel].
 final class Candidate {
   // TODO: token count?
@@ -842,53 +860,124 @@ enum TaskType {
   Object toJson() => _jsonString;
 }

-/// Parse the json to [GenerateContentResponse]
-GenerateContentResponse parseGenerateContentResponse(Object jsonObject) {
-  if (jsonObject case {'error': final Object error}) throw parseError(error);
-  final candidates = switch (jsonObject) {
-    {'candidates': final List candidates} =>
-      candidates.map(_parseCandidate).toList(),
-    _ => []
-  };
-  final promptFeedback = switch (jsonObject) {
-    {'promptFeedback': final promptFeedback?} =>
-      _parsePromptFeedback(promptFeedback),
-    _ => null,
-  };
-  final usageMedata = switch (jsonObject) {
-    {'usageMetadata': final usageMetadata?} =>
-      _parseUsageMetadata(usageMetadata),
-    _ => null,
-  };
-  return GenerateContentResponse(candidates, promptFeedback,
-      usageMetadata: usageMedata);
+// ignore: public_member_api_docs
+abstract interface class SerializationStrategy {
+  // ignore: public_member_api_docs
+  GenerateContentResponse parseGenerateContentResponse(Object jsonObject);
+  // ignore: public_member_api_docs
+  CountTokensResponse parseCountTokensResponse(Object jsonObject);
+  // ignore: public_member_api_docs
+  Map generateContentRequest(
+    Iterable contents,
+    ({String prefix, String name}) model,
+    List safetySettings,
+    GenerationConfig? generationConfig,
+    List? tools,
+    ToolConfig? toolConfig,
+    Content? systemInstruction,
+  );

+  // ignore: public_member_api_docs
+  Map countTokensRequest(
+    Iterable contents,
+    ({String prefix, String name}) model,
+    List safetySettings,
+    GenerationConfig? generationConfig,
+    List? tools,
+    ToolConfig?
toolConfig, + ); } -/// Parse the json to [CountTokensResponse] -CountTokensResponse parseCountTokensResponse(Object jsonObject) { - if (jsonObject case {'error': final Object error}) throw parseError(error); +// ignore: public_member_api_docs +final class VertexSerialization implements SerializationStrategy { + /// Parse the json to [GenerateContentResponse] + @override + GenerateContentResponse parseGenerateContentResponse(Object jsonObject) { + if (jsonObject case {'error': final Object error}) throw parseError(error); + final candidates = switch (jsonObject) { + {'candidates': final List candidates} => + candidates.map(_parseCandidate).toList(), + _ => [] + }; + final promptFeedback = switch (jsonObject) { + {'promptFeedback': final promptFeedback?} => + _parsePromptFeedback(promptFeedback), + _ => null, + }; + final usageMedata = switch (jsonObject) { + {'usageMetadata': final usageMetadata?} => + _parseUsageMetadata(usageMetadata), + {'totalTokens': final int totalTokens} => + UsageMetadata._(totalTokenCount: totalTokens), + _ => null, + }; + return GenerateContentResponse(candidates, promptFeedback, + usageMetadata: usageMedata); + } - if (jsonObject is! Map) { - throw unhandledFormat('CountTokensResponse', jsonObject); + /// Parse the json to [CountTokensResponse] + @override + CountTokensResponse parseCountTokensResponse(Object jsonObject) { + if (jsonObject case {'error': final Object error}) throw parseError(error); + + if (jsonObject is! Map) { + throw unhandledFormat('CountTokensResponse', jsonObject); + } + + final totalTokens = jsonObject['totalTokens'] as int; + final totalBillableCharacters = switch (jsonObject) { + {'totalBillableCharacters': final int totalBillableCharacters} => + totalBillableCharacters, + _ => null, + }; + final promptTokensDetails = switch (jsonObject) { + {'promptTokensDetails': final List promptTokensDetails} => + promptTokensDetails.map(_parseModalityTokenCount).toList(), + _ => null, + }; + + return CountTokensResponse( + totalTokens, + totalBillableCharacters: totalBillableCharacters, + promptTokensDetails: promptTokensDetails, + ); } - final totalTokens = jsonObject['totalTokens'] as int; - final totalBillableCharacters = switch (jsonObject) { - {'totalBillableCharacters': final int totalBillableCharacters} => - totalBillableCharacters, - _ => null, - }; - final promptTokensDetails = switch (jsonObject) { - {'promptTokensDetails': final List promptTokensDetails} => - promptTokensDetails.map(_parseModalityTokenCount).toList(), - _ => null, - }; + @override + Map generateContentRequest( + Iterable contents, + ({String prefix, String name}) model, + List safetySettings, + GenerationConfig? generationConfig, + List? tools, + ToolConfig? toolConfig, + Content? 
systemInstruction, + ) { + return { + 'model': '${model.prefix}/${model.name}', + 'contents': contents.map((c) => c.toJson()).toList(), + if (safetySettings.isNotEmpty) + 'safetySettings': safetySettings.map((s) => s.toJson()).toList(), + if (generationConfig != null) + 'generationConfig': generationConfig.toJson(), + if (tools != null) 'tools': tools.map((t) => t.toJson()).toList(), + if (toolConfig != null) 'toolConfig': toolConfig.toJson(), + if (systemInstruction != null) + 'systemInstruction': systemInstruction.toJson(), + }; + } - return CountTokensResponse( - totalTokens, - totalBillableCharacters: totalBillableCharacters, - promptTokensDetails: promptTokensDetails, - ); + @override + Map countTokensRequest( + Iterable contents, + ({String prefix, String name}) model, + List safetySettings, + GenerationConfig? generationConfig, + List? tools, + ToolConfig? toolConfig, + ) => + // Everything except contents is ignored. + {'contents': contents.map((c) => c.toJson()).toList()}; } Candidate _parseCandidate(Object? jsonObject) { diff --git a/packages/firebase_vertexai/firebase_vertexai/lib/src/base_model.dart b/packages/firebase_vertexai/firebase_vertexai/lib/src/base_model.dart index 4edcdb73da12..fea23c3f93ad 100644 --- a/packages/firebase_vertexai/firebase_vertexai/lib/src/base_model.dart +++ b/packages/firebase_vertexai/firebase_vertexai/lib/src/base_model.dart @@ -26,6 +26,7 @@ import 'package:web_socket_channel/io.dart'; import 'api.dart'; import 'client.dart'; import 'content.dart'; +import 'developer/api.dart'; import 'function_calling.dart'; import 'imagen_api.dart'; import 'imagen_content.dart'; @@ -52,33 +53,28 @@ enum Task { predict, } -/// Base class for models. -/// -/// Do not instantiate directly. -abstract class BaseModel { - // ignore: public_member_api_docs - BaseModel( +abstract interface class _ModelUri { + String get baseAuthority; + Uri taskUri(Task task); + ({String prefix, String name}) get model; +} + +final class _VertexUri implements _ModelUri { + _VertexUri( {required String model, required String location, required FirebaseApp app}) - : _model = normalizeModelName(model), + : model = _normalizeModelName(model), _projectUri = _vertexUri(app, location); - static const _baseUrl = 'firebasevertexai.googleapis.com'; + static const _baseAuthority = 'firebasevertexai.googleapis.com'; static const _apiVersion = 'v1beta'; - final ({String prefix, String name}) _model; - - final Uri _projectUri; - - /// The normalized model name. - ({String prefix, String name}) get model => _model; - /// Returns the model code for a user friendly model name. /// /// If the model name is already a model code (contains a `/`), use the parts /// directly. Otherwise, return a `models/` model code. 
- static ({String prefix, String name}) normalizeModelName(String modelName) { + static ({String prefix, String name}) _normalizeModelName(String modelName) { if (!modelName.contains('/')) return (prefix: 'models', name: modelName); final parts = modelName.split('/'); return (prefix: parts.first, name: parts.skip(1).join('/')); @@ -87,11 +83,79 @@ abstract class BaseModel { static Uri _vertexUri(FirebaseApp app, String location) { var projectId = app.options.projectId; return Uri.https( - _baseUrl, + _baseAuthority, '/$_apiVersion/projects/$projectId/locations/$location/publishers/google', ); } + final Uri _projectUri; + @override + final ({String prefix, String name}) model; + + @override + String get baseAuthority => _baseAuthority; + + @override + Uri taskUri(Task task) { + return _projectUri.replace( + pathSegments: _projectUri.pathSegments + .followedBy([model.prefix, '${model.name}:${task.name}'])); + } +} + +final class _GoogleAIUri implements _ModelUri { + _GoogleAIUri({ + required String model, + required FirebaseApp app, + }) : model = _normalizeModelName(model), + _baseUri = _googleAIBaseUri(app: app); + + /// Returns the model code for a user friendly model name. + /// + /// If the model name is already a model code (contains a `/`), use the parts + /// directly. Otherwise, return a `models/` model code. + static ({String prefix, String name}) _normalizeModelName(String modelName) { + if (!modelName.contains('/')) return (prefix: 'models', name: modelName); + final parts = modelName.split('/'); + return (prefix: parts.first, name: parts.skip(1).join('/')); + } + + static const _apiVersion = 'v1beta'; + static const _baseAuthority = 'firebasevertexai.googleapis.com'; + static Uri _googleAIBaseUri( + {String apiVersion = _apiVersion, required FirebaseApp app}) => + Uri.https( + _baseAuthority, '$apiVersion/projects/${app.options.projectId}'); + final Uri _baseUri; + + @override + final ({String prefix, String name}) model; + + @override + String get baseAuthority => _baseAuthority; + + @override + Uri taskUri(Task task) => _baseUri.replace( + pathSegments: _baseUri.pathSegments + .followedBy([model.prefix, '${model.name}:${task.name}'])); +} + +/// Base class for models. +/// +/// Do not instantiate directly. +abstract class BaseModel { + BaseModel._( + {required SerializationStrategy serializationStrategy, + required _ModelUri modelUri}) + : _serializationStrategy = serializationStrategy, + _modelUri = modelUri; + + final SerializationStrategy _serializationStrategy; + final _ModelUri _modelUri; + + /// The normalized model name. + ({String prefix, String name}) get model => _modelUri.model; + /// Returns a function that generates Firebase auth tokens. static FutureOr> Function() firebaseTokens( FirebaseAppCheck? appCheck, FirebaseAuth? auth, FirebaseApp? app) { @@ -120,9 +184,7 @@ abstract class BaseModel { } /// Returns a URI for the given [task]. - Uri taskUri(Task task) => _projectUri.replace( - pathSegments: _projectUri.pathSegments - .followedBy([_model.prefix, '${_model.name}:${task.name}'])); + Uri taskUri(Task task) => _modelUri.taskUri(task); } /// An abstract base class for models that interact with an API using an [ApiClient]. 
@@ -136,11 +198,11 @@ abstract class BaseModel { abstract class BaseApiClientModel extends BaseModel { // ignore: public_member_api_docs BaseApiClientModel({ - required super.model, - required super.location, - required super.app, + required super.serializationStrategy, + required super.modelUri, required ApiClient client, - }) : _client = client; + }) : _client = client, + super._(); final ApiClient _client; diff --git a/packages/firebase_vertexai/firebase_vertexai/lib/src/developer/api.dart b/packages/firebase_vertexai/firebase_vertexai/lib/src/developer/api.dart new file mode 100644 index 000000000000..8bc1117fdbfa --- /dev/null +++ b/packages/firebase_vertexai/firebase_vertexai/lib/src/developer/api.dart @@ -0,0 +1,323 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import '../api.dart' + show + BlockReason, + Candidate, + Citation, + CitationMetadata, + CountTokensResponse, + FinishReason, + GenerateContentResponse, + GenerationConfig, + HarmBlockThreshold, + HarmCategory, + HarmProbability, + PromptFeedback, + SafetyRating, + SafetySetting, + SerializationStrategy, + UsageMetadata, + createUsageMetadata; +import '../content.dart' show Content, FunctionCall, Part, TextPart; +import '../error.dart'; +import '../function_calling.dart' show Tool, ToolConfig; + +HarmProbability _parseHarmProbability(Object jsonObject) => + switch (jsonObject) { + 'UNSPECIFIED' => HarmProbability.unknown, + 'NEGLIGIBLE' => HarmProbability.negligible, + 'LOW' => HarmProbability.low, + 'MEDIUM' => HarmProbability.medium, + 'HIGH' => HarmProbability.high, + _ => throw unhandledFormat('HarmProbability', jsonObject), + }; +HarmCategory _parseHarmCategory(Object jsonObject) => switch (jsonObject) { + 'HARM_CATEGORY_UNSPECIFIED' => HarmCategory.unknown, + 'HARM_CATEGORY_HARASSMENT' => HarmCategory.harassment, + 'HARM_CATEGORY_HATE_SPEECH' => HarmCategory.hateSpeech, + 'HARM_CATEGORY_SEXUALLY_EXPLICIT' => HarmCategory.sexuallyExplicit, + 'HARM_CATEGORY_DANGEROUS_CONTENT' => HarmCategory.dangerousContent, + _ => throw unhandledFormat('HarmCategory', jsonObject), + }; + +FinishReason _parseFinishReason(Object jsonObject) => switch (jsonObject) { + 'UNSPECIFIED' => FinishReason.unknown, + 'STOP' => FinishReason.stop, + 'MAX_TOKENS' => FinishReason.maxTokens, + 'SAFETY' => FinishReason.safety, + 'RECITATION' => FinishReason.recitation, + 'OTHER' => FinishReason.other, + _ => throw unhandledFormat('FinishReason', jsonObject), + }; +BlockReason _parseBlockReason(String jsonObject) => switch (jsonObject) { + 'BLOCK_REASON_UNSPECIFIED' => BlockReason.unknown, + 'SAFETY' => BlockReason.safety, + 'OTHER' => BlockReason.other, + _ => throw unhandledFormat('BlockReason', jsonObject), + }; +String _harmBlockThresholdtoJson(HarmBlockThreshold? 
threshold) => + switch (threshold) { + null => 'HARM_BLOCK_THRESHOLD_UNSPECIFIED', + HarmBlockThreshold.low => 'BLOCK_LOW_AND_ABOVE', + HarmBlockThreshold.medium => 'BLOCK_MEDIUM_AND_ABOVE', + HarmBlockThreshold.high => 'BLOCK_ONLY_HIGH', + HarmBlockThreshold.none => 'BLOCK_NONE' + }; +String _harmCategoryToJson(HarmCategory harmCategory) => switch (harmCategory) { + HarmCategory.unknown => 'HARM_CATEGORY_UNSPECIFIED', + HarmCategory.harassment => 'HARM_CATEGORY_HARASSMENT', + HarmCategory.hateSpeech => 'HARM_CATEGORY_HATE_SPEECH', + HarmCategory.sexuallyExplicit => 'HARM_CATEGORY_SEXUALLY_EXPLICIT', + HarmCategory.dangerousContent => 'HARM_CATEGORY_DANGEROUS_CONTENT' + }; + +Object _safetySettingToJson(SafetySetting safetySetting) { + if (safetySetting.method != null) { + throw ArgumentError( + 'HarmBlockMethod is not supported by google AI and must be left null.'); + } + return { + 'category': _harmCategoryToJson(safetySetting.category), + 'threshold': _harmBlockThresholdtoJson(safetySetting.threshold) + }; +} + +// ignore: public_member_api_docs +final class DeveloperSerialization implements SerializationStrategy { + @override + GenerateContentResponse parseGenerateContentResponse(Object jsonObject) { + if (jsonObject case {'error': final Object error}) throw parseError(error); + final candidates = switch (jsonObject) { + {'candidates': final List candidates} => + candidates.map(_parseCandidate).toList(), + _ => [] + }; + final promptFeedback = switch (jsonObject) { + {'promptFeedback': final promptFeedback?} => + _parsePromptFeedback(promptFeedback), + _ => null, + }; + final usageMedata = switch (jsonObject) { + {'usageMetadata': final usageMetadata?} => + _parseUsageMetadata(usageMetadata), + _ => null, + }; + return GenerateContentResponse(candidates, promptFeedback, + usageMetadata: usageMedata); + } + + @override + CountTokensResponse parseCountTokensResponse(Object jsonObject) { + if (jsonObject case {'error': final Object error}) throw parseError(error); + if (jsonObject case {'totalTokens': final int totalTokens}) { + return CountTokensResponse(totalTokens); + } + throw unhandledFormat('CountTokensResponse', jsonObject); + } + + @override + Map generateContentRequest( + Iterable contents, + ({String prefix, String name}) model, + List safetySettings, + GenerationConfig? generationConfig, + List? tools, + ToolConfig? toolConfig, + Content? systemInstruction, + ) { + return { + 'model': '${model.prefix}/${model.name}', + 'contents': contents.map((c) => c.toJson()).toList(), + if (safetySettings.isNotEmpty) + 'safetySettings': safetySettings.map(_safetySettingToJson).toList(), + if (generationConfig != null) + 'generationConfig': generationConfig.toJson(), + if (tools != null) 'tools': tools.map((t) => t.toJson()).toList(), + if (toolConfig != null) 'toolConfig': toolConfig.toJson(), + if (systemInstruction != null) + 'systemInstruction': systemInstruction.toJson(), + }; + } + + @override + Map countTokensRequest( + Iterable contents, + ({String prefix, String name}) model, + List safetySettings, + GenerationConfig? generationConfig, + List? tools, + ToolConfig? toolConfig, + ) => + { + 'generateContentRequest': generateContentRequest( + contents, + model, + safetySettings, + generationConfig, + tools, + toolConfig, + null, + ) + }; +} + +Candidate _parseCandidate(Object? jsonObject) { + if (jsonObject is! Map) { + throw unhandledFormat('Candidate', jsonObject); + } + + return Candidate( + jsonObject.containsKey('content') + ? 
_parseGoogleAIContent(jsonObject['content'] as Object) + : Content(null, []), + switch (jsonObject) { + {'safetyRatings': final List safetyRatings} => + safetyRatings.map(_parseSafetyRating).toList(), + _ => null + }, + switch (jsonObject) { + {'citationMetadata': final Object citationMetadata} => + _parseCitationMetadata(citationMetadata), + _ => null + }, + switch (jsonObject) { + {'finishReason': final Object finishReason} => + _parseFinishReason(finishReason), + _ => null + }, + switch (jsonObject) { + {'finishMessage': final String finishMessage} => finishMessage, + _ => null + }, + ); +} + +PromptFeedback _parsePromptFeedback(Object jsonObject) { + return switch (jsonObject) { + { + 'safetyRatings': final List safetyRatings, + } => + PromptFeedback( + switch (jsonObject) { + {'blockReason': final String blockReason} => + _parseBlockReason(blockReason), + _ => null, + }, + switch (jsonObject) { + {'blockReasonMessage': final String blockReasonMessage} => + blockReasonMessage, + _ => null, + }, + safetyRatings.map(_parseSafetyRating).toList()), + _ => throw unhandledFormat('PromptFeedback', jsonObject), + }; +} + +UsageMetadata _parseUsageMetadata(Object jsonObject) { + if (jsonObject is! Map) { + throw unhandledFormat('UsageMetadata', jsonObject); + } + final promptTokenCount = switch (jsonObject) { + {'promptTokenCount': final int promptTokenCount} => promptTokenCount, + _ => null, + }; + final candidatesTokenCount = switch (jsonObject) { + {'candidatesTokenCount': final int candidatesTokenCount} => + candidatesTokenCount, + _ => null, + }; + final totalTokenCount = switch (jsonObject) { + {'totalTokenCount': final int totalTokenCount} => totalTokenCount, + _ => null, + }; + return createUsageMetadata( + promptTokenCount: promptTokenCount, + candidatesTokenCount: candidatesTokenCount, + totalTokenCount: totalTokenCount, + promptTokensDetails: null, + candidatesTokensDetails: null, + ); +} + +SafetyRating _parseSafetyRating(Object? jsonObject) { + return switch (jsonObject) { + { + 'category': final Object category, + 'probability': final Object probability + } => + SafetyRating( + _parseHarmCategory(category), _parseHarmProbability(probability)), + _ => throw unhandledFormat('SafetyRating', jsonObject), + }; +} + +CitationMetadata _parseCitationMetadata(Object? jsonObject) { + return switch (jsonObject) { + {'citationSources': final List citationSources} => + CitationMetadata(citationSources.map(_parseCitationSource).toList()), + // Vertex SDK format uses `citations` + {'citations': final List citationSources} => + CitationMetadata(citationSources.map(_parseCitationSource).toList()), + _ => throw unhandledFormat('CitationMetadata', jsonObject), + }; +} + +Citation _parseCitationSource(Object? jsonObject) { + if (jsonObject is! Map) { + throw unhandledFormat('CitationSource', jsonObject); + } + + final uriString = jsonObject['uri'] as String?; + + return Citation( + jsonObject['startIndex'] as int?, + jsonObject['endIndex'] as int?, + uriString != null ? Uri.parse(uriString) : null, + jsonObject['license'] as String?, + ); +} + +Content _parseGoogleAIContent(Object jsonObject) { + return switch (jsonObject) { + {'parts': final List parts} => Content( + switch (jsonObject) { + {'role': final String role} => role, + _ => null, + }, + parts.map(_parsePart).toList()), + _ => throw unhandledFormat('Content', jsonObject), + }; +} + +Part _parsePart(Object? 
jsonObject) {
+  return switch (jsonObject) {
+    {'text': final String text} => TextPart(text),
+    {
+      'functionCall': {
+        'name': final String name,
+        'args': final Map args
+      }
+    } =>
+      FunctionCall(name, args),
+    {
+      'functionResponse': {'name': String _, 'response': Map _}
+    } =>
+      throw UnimplementedError('FunctionResponse part not yet supported'),
+    {'inlineData': {'mimeType': String _, 'data': String _}} =>
+      throw UnimplementedError('inlineData content part not yet supported'),
+    _ => throw unhandledFormat('Part', jsonObject),
+  };
+}
diff --git a/packages/firebase_vertexai/firebase_vertexai/lib/src/firebase_vertexai.dart b/packages/firebase_vertexai/firebase_vertexai/lib/src/firebase_vertexai.dart
index 80f2c68a026c..fb5cb3b18e1d 100644
--- a/packages/firebase_vertexai/firebase_vertexai/lib/src/firebase_vertexai.dart
+++ b/packages/firebase_vertexai/firebase_vertexai/lib/src/firebase_vertexai.dart
@@ -27,8 +27,13 @@ const _defaultLocation = 'us-central1';
 /// The entrypoint for [FirebaseVertexAI].
 class FirebaseVertexAI extends FirebasePluginPlatform {
   FirebaseVertexAI._(
-      {required this.app, required this.location, this.appCheck, this.auth})
-      : super(app.name, 'plugins.flutter.io/firebase_vertexai');
+      {required this.app,
+      required this.location,
+      required bool useVertexBackend,
+      this.appCheck,
+      this.auth})
+      : _useVertexBackend = useVertexBackend,
+        super(app.name, 'plugins.flutter.io/firebase_vertexai');

   /// The [FirebaseApp] for this current [FirebaseVertexAI] instance.
   FirebaseApp app;
@@ -43,6 +48,8 @@ class FirebaseVertexAI extends FirebasePluginPlatform {
   /// The service location for this [FirebaseVertexAI] instance.
   String location;

+  final bool _useVertexBackend;
+
   static final Map _cachedInstances = {};

   /// Returns an instance using the default [FirebaseApp].
@@ -61,6 +68,18 @@ class FirebaseVertexAI extends FirebasePluginPlatform {
     FirebaseAppCheck? appCheck,
     FirebaseAuth? auth,
     String? location,
+  }) =>
+      vertexAI(app: app, appCheck: appCheck, auth: auth, location: location);
+
+  /// Returns an instance using a specified [FirebaseApp].
+  ///
+  /// If [app] is not provided, the default Firebase app will be used.
+  /// If [appCheck] is provided, the request session will be protected from abuse.
+  static FirebaseVertexAI vertexAI({
+    FirebaseApp? app,
+    FirebaseAppCheck? appCheck,
+    FirebaseAuth? auth,
+    String? location,
   }) {
     app ??= Firebase.app();

@@ -71,7 +90,39 @@ class FirebaseVertexAI extends FirebasePluginPlatform {
     location ??= _defaultLocation;

     FirebaseVertexAI newInstance = FirebaseVertexAI._(
-        app: app, location: location, appCheck: appCheck, auth: auth);
+      app: app,
+      location: location,
+      appCheck: appCheck,
+      auth: auth,
+      useVertexBackend: true,
+    );
+    _cachedInstances[app.name] = newInstance;
+
+    return newInstance;
+  }
+
+  /// Returns an instance using a specified [FirebaseApp].
+  ///
+  /// If [app] is not provided, the default Firebase app will be used.
+  /// If [appCheck] is provided, the request session will be protected from abuse.
+  static FirebaseVertexAI googleAI({
+    FirebaseApp? app,
+    FirebaseAppCheck? appCheck,
+    FirebaseAuth?
auth, + }) { + app ??= Firebase.app(); + + if (_cachedInstances.containsKey(app.name)) { + return _cachedInstances[app.name]!; + } + + FirebaseVertexAI newInstance = FirebaseVertexAI._( + app: app, + location: _defaultLocation, + appCheck: appCheck, + auth: auth, + useVertexBackend: false, + ); _cachedInstances[app.name] = newInstance; return newInstance; @@ -100,6 +151,7 @@ class FirebaseVertexAI extends FirebasePluginPlatform { model: model, app: app, appCheck: appCheck, + useVertexBackend: _useVertexBackend, auth: auth, location: location, safetySettings: safetySettings, diff --git a/packages/firebase_vertexai/firebase_vertexai/lib/src/generative_model.dart b/packages/firebase_vertexai/firebase_vertexai/lib/src/generative_model.dart index 31841f8834c1..8533e3519287 100644 --- a/packages/firebase_vertexai/firebase_vertexai/lib/src/generative_model.dart +++ b/packages/firebase_vertexai/firebase_vertexai/lib/src/generative_model.dart @@ -17,7 +17,7 @@ part of vertexai_model; /// A multimodel generative model (like Gemini). /// -/// Allows generating content, creating embeddings, and counting the number of +/// Allows generating content and counting the number of /// tokens in a piece of content. final class GenerativeModel extends BaseApiClientModel { /// Create a [GenerativeModel] backed by the generative model named [model]. @@ -36,6 +36,7 @@ final class GenerativeModel extends BaseApiClientModel { required String model, required String location, required FirebaseApp app, + required bool useVertexBackend, FirebaseAppCheck? appCheck, FirebaseAuth? auth, List? safetySettings, @@ -50,9 +51,12 @@ final class GenerativeModel extends BaseApiClientModel { _toolConfig = toolConfig, _systemInstruction = systemInstruction, super( - model: model, - app: app, - location: location, + serializationStrategy: useVertexBackend + ? VertexSerialization() + : DeveloperSerialization(), + modelUri: useVertexBackend + ? _VertexUri(app: app, model: model, location: location) + : _GoogleAIUri(app: app, model: model), client: HttpApiClient( apiKey: app.options.apiKey, httpClient: httpClient, @@ -62,6 +66,7 @@ final class GenerativeModel extends BaseApiClientModel { required String model, required String location, required FirebaseApp app, + required useVertexBackend, FirebaseAppCheck? appCheck, FirebaseAuth? auth, List? safetySettings, @@ -76,9 +81,12 @@ final class GenerativeModel extends BaseApiClientModel { _toolConfig = toolConfig, _systemInstruction = systemInstruction, super( - model: model, - app: app, - location: location, + serializationStrategy: useVertexBackend + ? VertexSerialization() + : DeveloperSerialization(), + modelUri: useVertexBackend + ? _VertexUri(app: app, model: model, location: location) + : _GoogleAIUri(app: app, model: model), client: apiClient ?? HttpApiClient( apiKey: app.options.apiKey, @@ -92,31 +100,6 @@ final class GenerativeModel extends BaseApiClientModel { final ToolConfig? _toolConfig; final Content? _systemInstruction; - Map _generateContentRequest( - Iterable contents, { - List? safetySettings, - GenerationConfig? generationConfig, - List? tools, - ToolConfig? 
toolConfig, - }) { - safetySettings ??= _safetySettings; - generationConfig ??= _generationConfig; - tools ??= _tools; - toolConfig ??= _toolConfig; - return { - 'model': '${model.prefix}/${model.name}', - 'contents': contents.map((c) => c.toJson()).toList(), - if (safetySettings.isNotEmpty) - 'safetySettings': safetySettings.map((s) => s.toJson()).toList(), - if (generationConfig != null) - 'generationConfig': generationConfig.toJson(), - if (tools != null) 'tools': tools.map((t) => t.toJson()).toList(), - if (toolConfig != null) 'toolConfig': toolConfig.toJson(), - if (_systemInstruction case final systemInstruction?) - 'systemInstruction': systemInstruction.toJson(), - }; - } - /// Generates content responding to [prompt]. /// /// Sends a "generateContent" API request for the configured model, @@ -134,14 +117,16 @@ final class GenerativeModel extends BaseApiClientModel { ToolConfig? toolConfig}) => makeRequest( Task.generateContent, - _generateContentRequest( + _serializationStrategy.generateContentRequest( prompt, - safetySettings: safetySettings, - generationConfig: generationConfig, - tools: tools, - toolConfig: toolConfig, + model, + safetySettings ?? _safetySettings, + generationConfig ?? _generationConfig, + tools ?? _tools, + toolConfig ?? _toolConfig, + _systemInstruction, ), - parseGenerateContentResponse); + _serializationStrategy.parseGenerateContentResponse); /// Generates a stream of content responding to [prompt]. /// @@ -163,14 +148,16 @@ final class GenerativeModel extends BaseApiClientModel { ToolConfig? toolConfig}) { final response = client.streamRequest( taskUri(Task.streamGenerateContent), - _generateContentRequest( + _serializationStrategy.generateContentRequest( prompt, - safetySettings: safetySettings, - generationConfig: generationConfig, - tools: tools, - toolConfig: toolConfig, + model, + safetySettings ?? _safetySettings, + generationConfig ?? _generationConfig, + tools ?? _tools, + toolConfig ?? _toolConfig, + _systemInstruction, )); - return response.map(parseGenerateContentResponse); + return response.map(_serializationStrategy.parseGenerateContentResponse); } /// Counts the total number of tokens in [contents]. @@ -193,10 +180,16 @@ final class GenerativeModel extends BaseApiClientModel { Future countTokens( Iterable contents, ) async { - final parameters = { - 'contents': contents.map((c) => c.toJson()).toList() - }; - return makeRequest(Task.countTokens, parameters, parseCountTokensResponse); + final parameters = _serializationStrategy.countTokensRequest( + contents, + model, + _safetySettings, + _generationConfig, + _tools, + _toolConfig, + ); + return makeRequest(Task.countTokens, parameters, + _serializationStrategy.parseCountTokensResponse); } } @@ -205,6 +198,7 @@ GenerativeModel createGenerativeModel({ required FirebaseApp app, required String location, required String model, + required bool useVertexBackend, FirebaseAppCheck? appCheck, FirebaseAuth? auth, GenerationConfig? generationConfig, @@ -217,6 +211,7 @@ GenerativeModel createGenerativeModel({ model: model, app: app, appCheck: appCheck, + useVertexBackend: useVertexBackend, auth: auth, location: location, safetySettings: safetySettings, @@ -234,6 +229,7 @@ GenerativeModel createModelWithClient({ required String location, required String model, required ApiClient client, + required bool useVertexBackend, Content? systemInstruction, FirebaseAppCheck? appCheck, FirebaseAuth? 
auth, @@ -246,6 +242,7 @@ GenerativeModel createModelWithClient({ model: model, app: app, appCheck: appCheck, + useVertexBackend: useVertexBackend, auth: auth, location: location, safetySettings: safetySettings, diff --git a/packages/firebase_vertexai/firebase_vertexai/lib/src/imagen_model.dart b/packages/firebase_vertexai/firebase_vertexai/lib/src/imagen_model.dart index 68adf1f67e7e..5be9b3ea1538 100644 --- a/packages/firebase_vertexai/firebase_vertexai/lib/src/imagen_model.dart +++ b/packages/firebase_vertexai/firebase_vertexai/lib/src/imagen_model.dart @@ -36,9 +36,12 @@ final class ImagenModel extends BaseApiClientModel { : _generationConfig = generationConfig, _safetySettings = safetySettings, super( - model: model, - app: app, - location: location, + serializationStrategy: VertexSerialization(), + modelUri: _VertexUri( + model: model, + app: app, + location: location, + ), client: HttpApiClient( apiKey: app.options.apiKey, requestHeaders: BaseModel.firebaseTokens(appCheck, auth, app))); diff --git a/packages/firebase_vertexai/firebase_vertexai/lib/src/live_model.dart b/packages/firebase_vertexai/firebase_vertexai/lib/src/live_model.dart index 57d844d566ae..06e443341e98 100644 --- a/packages/firebase_vertexai/firebase_vertexai/lib/src/live_model.dart +++ b/packages/firebase_vertexai/firebase_vertexai/lib/src/live_model.dart @@ -43,11 +43,15 @@ final class LiveGenerativeModel extends BaseModel { _liveGenerationConfig = liveGenerationConfig, _tools = tools, _systemInstruction = systemInstruction, - super( - model: model, - app: app, - location: location, + super._( + serializationStrategy: VertexSerialization(), + modelUri: _VertexUri( + model: model, + app: app, + location: location, + ), ); + static const _apiVersion = 'v1beta'; final FirebaseApp _app; final String _location; @@ -65,10 +69,11 @@ final class LiveGenerativeModel extends BaseModel { /// Returns a [Future] that resolves to an [LiveSession] object upon successful /// connection. 
Future connect() async { - final uri = - 'wss://${BaseModel._baseUrl}/$_apiUrl.${BaseModel._apiVersion}.$_apiUrlSuffix/$_location?key=${_app.options.apiKey}'; - final modelString = - 'projects/${_app.options.projectId}/locations/$_location/publishers/google/models/${model.name}'; + final uri = 'wss://${_modelUri.baseAuthority}/' + '$_apiUrl.$_apiVersion.$_apiUrlSuffix/' + '$_location?key=${_app.options.apiKey}'; + final modelString = 'projects/${_app.options.projectId}/' + 'locations/$_location/publishers/google/models/${model.name}'; final setupJson = { 'setup': { diff --git a/packages/firebase_vertexai/firebase_vertexai/test/base_model_test.dart b/packages/firebase_vertexai/firebase_vertexai/test/base_model_test.dart index 6ca5b0e251ff..0a886841ed83 100644 --- a/packages/firebase_vertexai/firebase_vertexai/test/base_model_test.dart +++ b/packages/firebase_vertexai/firebase_vertexai/test/base_model_test.dart @@ -69,59 +69,8 @@ class MockApiClient extends Mock implements ApiClient { } } -// A concrete subclass of BaseModel for testing purposes -class TestBaseModel extends BaseModel { - TestBaseModel({ - required String model, - required String location, - required FirebaseApp app, - }) : super(model: model, location: location, app: app); -} - -class TestApiClientModel extends BaseApiClientModel { - TestApiClientModel({ - required super.model, - required super.location, - required super.app, - required ApiClient client, - }) : super(client: client); -} - void main() { group('BaseModel', () { - test('normalizeModelName returns correct prefix and name for model code', - () { - final result = BaseModel.normalizeModelName('models/my-model'); - expect(result.prefix, 'models'); - expect(result.name, 'my-model'); - }); - - test( - 'normalizeModelName returns correct prefix and name for user-friendly name', - () { - final result = BaseModel.normalizeModelName('my-model'); - expect(result.prefix, 'models'); - expect(result.name, 'my-model'); - }); - - test('taskUri constructs the correct URI for a task', () { - final mockApp = MockFirebaseApp(); - final model = TestBaseModel( - model: 'my-model', location: 'us-central1', app: mockApp); - final taskUri = model.taskUri(Task.generateContent); - expect(taskUri.toString(), - 'https://firebasevertexai.googleapis.com/v1beta/projects/test-project/locations/us-central1/publishers/google/models/my-model:generateContent'); - }); - - test('taskUri constructs the correct URI for a task with model code', () { - final mockApp = MockFirebaseApp(); - final model = TestBaseModel( - model: 'models/my-model', location: 'us-central1', app: mockApp); - final taskUri = model.taskUri(Task.countTokens); - expect(taskUri.toString(), - 'https://firebasevertexai.googleapis.com/v1beta/projects/test-project/locations/us-central1/publishers/google/models/my-model:countTokens'); - }); - test('firebaseTokens returns a function that generates headers', () async { final tokenFunction = BaseModel.firebaseTokens(null, null, null); final headers = await tokenFunction(); @@ -189,33 +138,4 @@ void main() { expect(headers.length, 4); }); }); - - group('BaseApiClientModel', () { - test('makeRequest returns the parsed response', () async { - final mockApp = MockFirebaseApp(); - final mockClient = MockApiClient(); - final model = TestApiClientModel( - model: 'test-model', - location: 'us-central1', - app: mockApp, - client: mockClient); - final params = {'input': 'test'}; - const task = Task.generateContent; - - final response = await model.makeRequest( - task, params, (data) => 
data['mockResponse']! as String); - expect(response, 'success'); - }); - - test('client getter returns the injected ApiClient', () { - final mockApp = MockFirebaseApp(); - final mockClient = MockApiClient(); - final model = TestApiClientModel( - model: 'test-model', - location: 'us-central1', - app: mockApp, - client: mockClient); - expect(model.client, mockClient); - }); - }); } diff --git a/packages/firebase_vertexai/firebase_vertexai/test/chat_test.dart b/packages/firebase_vertexai/firebase_vertexai/test/chat_test.dart index 51afe2a9b495..e0655978e898 100644 --- a/packages/firebase_vertexai/firebase_vertexai/test/chat_test.dart +++ b/packages/firebase_vertexai/firebase_vertexai/test/chat_test.dart @@ -38,6 +38,7 @@ void main() { final client = ClientController(); final model = createModelWithClient( app: app, + useVertexBackend: true, model: modelName, client: client.client, location: 'us-central1'); diff --git a/packages/firebase_vertexai/firebase_vertexai/test/google_ai_generative_model_test.dart b/packages/firebase_vertexai/firebase_vertexai/test/google_ai_generative_model_test.dart new file mode 100644 index 000000000000..5440175503f1 --- /dev/null +++ b/packages/firebase_vertexai/firebase_vertexai/test/google_ai_generative_model_test.dart @@ -0,0 +1,731 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import 'package:firebase_core/firebase_core.dart'; +import 'package:firebase_vertexai/firebase_vertexai.dart'; +import 'package:firebase_vertexai/src/base_model.dart'; +import 'package:flutter_test/flutter_test.dart'; + +import 'mock.dart'; +import 'utils/matchers.dart'; +import 'utils/stub_client.dart'; + +void main() { + setupFirebaseVertexAIMocks(); + late FirebaseApp app; + setUpAll(() async { + // Initialize Firebase + app = await Firebase.initializeApp(); + }); + group('GenerativeModel', () { + const defaultModelName = 'some-model'; + + (ClientController, GenerativeModel) createModel({ + String modelName = defaultModelName, + List? tools, + ToolConfig? toolConfig, + Content? 
systemInstruction, + }) { + final client = ClientController(); + final model = createModelWithClient( + useVertexBackend: false, + app: app, + model: modelName, + client: client.client, + tools: tools, + toolConfig: toolConfig, + systemInstruction: systemInstruction, + location: 'us-central1'); + return (client, model); + } + + test('strips leading "models/" from model name', () async { + final (client, model) = createModel( + modelName: 'models/$defaultModelName', + ); + const prompt = 'Some prompt'; + await client.checkRequest( + () => model.generateContent([Content.text(prompt)]), + response: arbitraryGenerateContentResponse, + verifyRequest: (uri, _) { + expect(uri.path, endsWith('/models/some-model:generateContent')); + }, + ); + }); + + test('allows specifying a tuned model', () async { + final (client, model) = createModel( + modelName: 'tunedModels/$defaultModelName', + ); + const prompt = 'Some prompt'; + await client.checkRequest( + () => model.generateContent([Content.text(prompt)]), + response: arbitraryGenerateContentResponse, + verifyRequest: (uri, _) { + expect(uri.path, endsWith('/tunedModels/some-model:generateContent')); + }, + ); + }); + + test('allows specifying an API version', () async { + final (client, model) = createModel( + // requestOptions: RequestOptions(apiVersion: 'override_version'), + ); + const prompt = 'Some prompt'; + await client.checkRequest( + () => model.generateContent([Content.text(prompt)]), + response: arbitraryGenerateContentResponse, + verifyRequest: (uri, _) { + expect(uri.path, startsWith('/override_version/')); + }, + ); + }, skip: 'No support for overriding API version'); + + group('generate unary content', () { + test('can make successful request', () async { + final (client, model) = createModel(); + const prompt = 'Some prompt'; + const result = 'Some response'; + final response = await client.checkRequest( + () => model.generateContent([Content.text(prompt)]), + verifyRequest: (uri, request) { + expect( + uri, + Uri.parse( + 'https://firebasevertexai.googleapis.com/v1beta/' + 'projects/123/' + 'models/some-model:generateContent', + ), + ); + expect(request, { + 'model': 'models/$defaultModelName', + 'contents': [ + { + 'role': 'user', + 'parts': [ + {'text': prompt}, + ], + }, + ], + }); + }, + response: { + 'candidates': [ + { + 'content': { + 'role': 'model', + 'parts': [ + {'text': result}, + ], + }, + }, + ], + }, + ); + expect( + response, + matchesGenerateContentResponse( + GenerateContentResponse([ + Candidate( + Content('model', [TextPart(result)]), + null, + null, + null, + null, + ), + ], null), + ), + ); + }); + + test('can override safety settings', () async { + final (client, model) = createModel(); + const prompt = 'Some prompt'; + await client.checkRequest( + () => model.generateContent( + [Content.text(prompt)], + safetySettings: [ + SafetySetting( + HarmCategory.dangerousContent, + HarmBlockThreshold.high, + null, + ), + ], + ), + response: arbitraryGenerateContentResponse, + verifyRequest: (_, request) { + expect(request['safetySettings'], [ + { + 'category': 'HARM_CATEGORY_DANGEROUS_CONTENT', + 'threshold': 'BLOCK_ONLY_HIGH', + }, + ]); + }, + ); + }); + + test('can override generation config', () async { + final (client, model) = createModel(); + const prompt = 'Some prompt'; + await client.checkRequest( + () => model.generateContent([ + Content.text(prompt), + ], generationConfig: GenerationConfig(stopSequences: ['a'])), + verifyRequest: (_, request) { + expect(request['generationConfig'], { + 'stopSequences': 
['a'], + }); + }, + response: arbitraryGenerateContentResponse, + ); + }); + + test('can pass system instructions', () async { + const instructions = 'Do a good job'; + final (client, model) = createModel( + systemInstruction: Content.system(instructions), + ); + const prompt = 'Some prompt'; + await client.checkRequest( + () => model.generateContent([Content.text(prompt)]), + verifyRequest: (_, request) { + expect(request['systemInstruction'], { + 'role': 'system', + 'parts': [ + {'text': instructions}, + ], + }); + }, + response: arbitraryGenerateContentResponse, + ); + }); + + test('can pass tools and function calling config', () async { + final (client, model) = createModel( + tools: [ + Tool.functionDeclarations([ + FunctionDeclaration( + 'someFunction', + 'Some cool function.', + parameters: { + 'schema1': Schema.string(description: 'Some parameter.'), + }, + ), + ]), + ], + toolConfig: ToolConfig( + functionCallingConfig: FunctionCallingConfig.any( + {'someFunction'}, + ), + ), + ); + const prompt = 'Some prompt'; + await client.checkRequest( + () => model.generateContent([Content.text(prompt)]), + verifyRequest: (_, request) { + expect(request['tools'], [ + { + 'functionDeclarations': [ + { + 'name': 'someFunction', + 'description': 'Some cool function.', + 'parameters': { + 'type': 'OBJECT', + 'properties': { + 'schema1': { + 'type': 'STRING', + 'description': 'Some parameter.' + } + }, + 'required': ['schema1'] + } + }, + ], + }, + ]); + expect(request['toolConfig'], { + 'functionCallingConfig': { + 'mode': 'ANY', + 'allowedFunctionNames': ['someFunction'], + }, + }); + }, + response: arbitraryGenerateContentResponse, + ); + }); + + test('can override tools and function calling config', () async { + final (client, model) = createModel(); + const prompt = 'Some prompt'; + await client.checkRequest( + () => model.generateContent( + [Content.text(prompt)], + tools: [ + Tool.functionDeclarations([ + FunctionDeclaration( + 'someFunction', + 'Some cool function.', + parameters: { + 'schema1': Schema.string(description: 'Some parameter.'), + }, + ), + ]), + ], + toolConfig: ToolConfig( + functionCallingConfig: FunctionCallingConfig.any( + {'someFunction'}, + ), + ), + ), + verifyRequest: (_, request) { + expect(request['tools'], [ + { + 'functionDeclarations': [ + { + 'name': 'someFunction', + 'description': 'Some cool function.', + 'parameters': { + 'type': 'OBJECT', + 'properties': { + 'schema1': { + 'type': 'STRING', + 'description': 'Some parameter.' 
+ } + }, + 'required': ['schema1'] + } + }, + ], + }, + ]); + expect(request['toolConfig'], { + 'functionCallingConfig': { + 'mode': 'ANY', + 'allowedFunctionNames': ['someFunction'], + }, + }); + }, + response: arbitraryGenerateContentResponse, + ); + }); + + test('can enable code execution', () async { + final (client, model) = createModel(tools: [ + // Tool(codeExecution: CodeExecution()), + ]); + const prompt = 'Some prompt'; + await client.checkRequest( + () => model.generateContent([Content.text(prompt)]), + verifyRequest: (_, request) { + expect(request['tools'], [ + {'codeExecution': {}} + ]); + }, + response: arbitraryGenerateContentResponse, + ); + }, skip: 'No support for code executation'); + + test('can override code execution', () async { + final (client, model) = createModel(); + const prompt = 'Some prompt'; + await client.checkRequest( + () => model.generateContent([ + Content.text(prompt) + ], tools: [ + // Tool(codeExecution: CodeExecution()), + ]), + verifyRequest: (_, request) { + expect(request['tools'], [ + {'codeExecution': {}} + ]); + }, + response: arbitraryGenerateContentResponse, + ); + }, skip: 'No support for code execution'); + }); + + group('generate content stream', () { + test('can make successful request', () async { + final (client, model) = createModel(); + const prompt = 'Some prompt'; + final results = {'First response', 'Second Response'}; + final response = await client.checkStreamRequest( + () async => model.generateContentStream([Content.text(prompt)]), + verifyRequest: (uri, request) { + expect( + uri, + Uri.parse( + 'https://firebasevertexai.googleapis.com/v1beta/' + 'projects/123/' + 'models/some-model:streamGenerateContent', + ), + ); + expect(request, { + 'model': 'models/$defaultModelName', + 'contents': [ + { + 'role': 'user', + 'parts': [ + {'text': prompt}, + ], + }, + ], + }); + }, + responses: [ + for (final result in results) + { + 'candidates': [ + { + 'content': { + 'role': 'model', + 'parts': [ + {'text': result}, + ], + }, + }, + ], + }, + ], + ); + expect( + response, + emitsInOrder([ + for (final result in results) + matchesGenerateContentResponse( + GenerateContentResponse([ + Candidate( + Content('model', [TextPart(result)]), + null, + null, + null, + null, + ), + ], null), + ), + ]), + ); + }); + + test('can override safety settings', () async { + final (client, model) = createModel(); + const prompt = 'Some prompt'; + final responses = await client.checkStreamRequest( + () async => model.generateContentStream( + [Content.text(prompt)], + safetySettings: [ + SafetySetting( + HarmCategory.dangerousContent, + HarmBlockThreshold.high, + null, + ), + ], + ), + verifyRequest: (_, request) { + expect(request['safetySettings'], [ + { + 'category': 'HARM_CATEGORY_DANGEROUS_CONTENT', + 'threshold': 'BLOCK_ONLY_HIGH', + }, + ]); + }, + responses: [arbitraryGenerateContentResponse], + ); + await responses.drain(); + }); + + test('can override generation config', () async { + final (client, model) = createModel(); + const prompt = 'Some prompt'; + final responses = await client.checkStreamRequest( + () async => model.generateContentStream([ + Content.text(prompt), + ], generationConfig: GenerationConfig(stopSequences: ['a'])), + verifyRequest: (_, request) { + expect(request['generationConfig'], { + 'stopSequences': ['a'], + }); + }, + responses: [arbitraryGenerateContentResponse], + ); + await responses.drain(); + }); + }); + + group('count tokens', () { + test('can make successful request', () async { + final (client, model) = 
createModel(); + const prompt = 'Some prompt'; + final response = await client.checkRequest( + () => model.countTokens([Content.text(prompt)]), + verifyRequest: (uri, request) { + expect( + uri, + Uri.parse( + 'https://firebasevertexai.googleapis.com/v1beta/' + 'projects/123/' + 'models/some-model:countTokens', + ), + ); + expect(request, { + 'generateContentRequest': { + 'model': 'models/$defaultModelName', + 'contents': [ + { + 'role': 'user', + 'parts': [ + {'text': prompt}, + ], + }, + ], + } + }); + }, + response: {'totalTokens': 2}, + ); + expect(response, matchesCountTokensResponse(CountTokensResponse(2))); + }); + + test('can override GenerateContentRequest fields', () async { + final (client, model) = createModel(); + const prompt = 'Some prompt'; + await client.checkRequest( + response: {'totalTokens': 100}, + () => model.countTokens( + [Content.text(prompt)], + // safetySettings: [ + // SafetySetting( + // HarmCategory.dangerousContent, + // HarmBlockThreshold.high, + // null, + // ), + // ], + // generationConfig: GenerationConfig(stopSequences: ['a']), + // tools: [ + // Tool(functionDeclarations: [ + // FunctionDeclaration( + // 'someFunction', + // 'Some cool function.', + // Schema(SchemaType.string, description: 'Some parameter.'), + // ), + // ]), + // ], + // toolConfig: ToolConfig( + // functionCallingConfig: FunctionCallingConfig( + // mode: FunctionCallingMode.any, + // allowedFunctionNames: {'someFunction'}, + // ), + // ), + ), + verifyRequest: (_, countTokensRequest) { + expect(countTokensRequest, isNotNull); + final request = countTokensRequest['generateContentRequest']! + as Map; + expect(request['safetySettings'], [ + { + 'category': 'HARM_CATEGORY_DANGEROUS_CONTENT', + 'threshold': 'BLOCK_ONLY_HIGH', + }, + ]); + expect(request['generationConfig'], { + 'stopSequences': ['a'], + }); + expect(request['tools'], [ + { + 'functionDeclarations': [ + { + 'name': 'someFunction', + 'description': 'Some cool function.', + 'parameters': { + 'type': 'STRING', + 'description': 'Some parameter.', + }, + }, + ], + }, + ]); + expect(request['toolConfig'], { + 'functionCallingConfig': { + 'mode': 'ANY', + 'allowedFunctionNames': ['someFunction'], + }, + }); + }, + ); + }, skip: 'Only content argument supported for countTokens'); + }); + + group('embed content', () { + test('can make successful request', () async { + final (client, model) = createModel(); + const prompt = 'Some prompt'; + final response = await client.checkRequest( + () async { + // await model.embedContent(Content.text(prompt)); + }, + verifyRequest: (uri, request) { + expect( + uri, + Uri.parse( + 'https://firebasevertexai.googleapis.com/v1beta/' + 'projects/123/' + 'models/some-model:embedContent', + ), + ); + expect(request, { + 'content': { + 'role': 'user', + 'parts': [ + {'text': prompt}, + ], + }, + }); + }, + response: { + 'embedding': { + 'values': [0.1, 0.2, 0.3], + }, + }, + ); + expect( + response, + // matchesEmbedContentResponse( + // EmbedContentResponse(ContentEmbedding([0.1, 0.2, 0.3])), + // ), + isNotNull, + ); + }); + + test('embed content with reduced output dimensionality', () async { + final (client, model) = createModel(); + const content = 'Some content'; + const outputDimensionality = 1; + final embeddingValues = [0.1]; + + await client.checkRequest(() async { + Content.text(content); + // await model.embedContent( + // Content.text(content), + // outputDimensionality: outputDimensionality, + // ); + }, verifyRequest: (_, request) { + expect(request, + 
containsPair('outputDimensionality', outputDimensionality)); + }, response: { + 'embedding': {'values': embeddingValues}, + }); + }); + }, skip: 'No support for embedding content'); + + group('batch embed contents', () { + test('can make successful request', () async { + final (client, model) = createModel(); + const prompt1 = 'Some prompt'; + const prompt2 = 'Another prompt'; + final embedding1 = [0.1, 0.2, 0.3]; + final embedding2 = [0.4, 0.5, 1.6]; + final response = await client.checkRequest( + () async { + // await model.batchEmbedContents([ + // EmbedContentRequest(Content.text(prompt1)), + // EmbedContentRequest(Content.text(prompt2)), + // ]); + }, + verifyRequest: (uri, request) { + expect( + uri, + Uri.parse( + 'https://firebasevertexai.googleapis.com/v1beta/' + 'projects/123/' + 'models/some-model:batchEmbedContents', + ), + ); + expect(request, { + 'requests': [ + { + 'content': { + 'role': 'user', + 'parts': [ + {'text': prompt1}, + ], + }, + 'model': 'models/$defaultModelName', + }, + { + 'content': { + 'role': 'user', + 'parts': [ + {'text': prompt2}, + ], + }, + 'model': 'models/$defaultModelName', + }, + ], + }); + }, + response: { + 'embeddings': [ + {'values': embedding1}, + {'values': embedding2}, + ], + }, + ); + expect( + response, + isNotNull, + // matchesBatchEmbedContentsResponse( + // BatchEmbedContentsResponse([ + // ContentEmbedding(embedding1), + // ContentEmbedding(embedding2), + // ]), + // ), + ); + }); + + test('batch embed contents with reduced output dimensionality', () async { + final (client, model) = createModel(); + const content1 = 'Some content 1'; + const content2 = 'Some content 2'; + const outputDimensionality = 1; + final embeddingValues1 = [0.1]; + final embeddingValues2 = [0.4]; + + await client.checkRequest(() async { + Content.text(content1); + Content.text(content2); + // await model.batchEmbedContents([ + // EmbedContentRequest( + // Content.text(content1), + // outputDimensionality: outputDimensionality, + // ), + // EmbedContentRequest( + // Content.text(content2), + // outputDimensionality: outputDimensionality, + // ), + // ]); + }, verifyRequest: (_, request) { + expect(request['requests'], [ + containsPair('outputDimensionality', outputDimensionality), + containsPair('outputDimensionality', outputDimensionality), + ]); + }, response: { + 'embeddings': [ + {'values': embeddingValues1}, + {'values': embeddingValues2}, + ], + }); + }); + }, skip: 'No support for embed content'); + }); +} diff --git a/packages/firebase_vertexai/firebase_vertexai/test/google_ai_response_parsing_test.dart b/packages/firebase_vertexai/firebase_vertexai/test/google_ai_response_parsing_test.dart new file mode 100644 index 000000000000..1005e6dd0555 --- /dev/null +++ b/packages/firebase_vertexai/firebase_vertexai/test/google_ai_response_parsing_test.dart @@ -0,0 +1,770 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +import 'dart:convert'; + +import 'package:firebase_vertexai/firebase_vertexai.dart'; +import 'package:firebase_vertexai/src/developer/api.dart'; +import 'package:flutter_test/flutter_test.dart'; + +import 'utils/matchers.dart'; + +void main() { + group('throws errors for invalid GenerateContentResponse', () { + test('with empty content', () { + const response = ''' +{ + "candidates": [ + { + "content": {}, + "index": 0 + } + ], + "promptFeedback": { + "safetyRatings": [ + { + "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", + "probability": "NEGLIGIBLE" + }, + { + "category": "HARM_CATEGORY_HATE_SPEECH", + "probability": "NEGLIGIBLE" + }, + { + "category": "HARM_CATEGORY_HARASSMENT", + "probability": "NEGLIGIBLE" + }, + { + "category": "HARM_CATEGORY_DANGEROUS_CONTENT", + "probability": "NEGLIGIBLE" + } + ] + } +} +'''; + final decoded = jsonDecode(response) as Object; + expect( + () => DeveloperSerialization().parseGenerateContentResponse(decoded), + throwsA( + isA().having( + (e) => e.message, + 'message', + startsWith('Unhandled format for Content:'), + ), + ), + ); + }); + + test('with a blocked prompt', () { + const response = ''' +{ + "promptFeedback": { + "blockReason": "SAFETY", + "safetyRatings": [ + { + "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", + "probability": "NEGLIGIBLE" + }, + { + "category": "HARM_CATEGORY_HATE_SPEECH", + "probability": "HIGH" + }, + { + "category": "HARM_CATEGORY_HARASSMENT", + "probability": "NEGLIGIBLE" + }, + { + "category": "HARM_CATEGORY_DANGEROUS_CONTENT", + "probability": "NEGLIGIBLE" + } + ] + } +} +'''; + final decoded = jsonDecode(response) as Object; + final generateContentResponse = + DeveloperSerialization().parseGenerateContentResponse(decoded); + expect( + generateContentResponse, + matchesGenerateContentResponse( + GenerateContentResponse( + [], + PromptFeedback(BlockReason.safety, null, [ + SafetyRating( + HarmCategory.sexuallyExplicit, + HarmProbability.negligible, + ), + SafetyRating(HarmCategory.hateSpeech, HarmProbability.high), + SafetyRating(HarmCategory.harassment, HarmProbability.negligible), + SafetyRating( + HarmCategory.dangerousContent, + HarmProbability.negligible, + ), + ]), + ), + ), + ); + expect( + () => generateContentResponse.text, + throwsA( + isA().having( + (e) => e.message, + 'message', + startsWith('Response was blocked due to safety'), + ), + ), + ); + }); + }); + + group('parses successful GenerateContentResponse', () { + test('with a basic reply', () async { + const response = ''' +{ + "candidates": [ + { + "content": { + "parts": [ + { + "text": "Mountain View, California, United States" + } + ], + "role": "model" + }, + "finishReason": "STOP", + "index": 0, + "safetyRatings": [ + { + "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", + "probability": "NEGLIGIBLE" + }, + { + "category": "HARM_CATEGORY_HATE_SPEECH", + "probability": "NEGLIGIBLE" + }, + { + "category": "HARM_CATEGORY_HARASSMENT", + "probability": "NEGLIGIBLE" + }, + { + "category": "HARM_CATEGORY_DANGEROUS_CONTENT", + "probability": "NEGLIGIBLE" + } + ] + } + ], + "promptFeedback": { + "safetyRatings": [ + { + "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", + "probability": "NEGLIGIBLE" + }, + { + "category": "HARM_CATEGORY_HATE_SPEECH", + "probability": "NEGLIGIBLE" + }, + { + "category": "HARM_CATEGORY_HARASSMENT", + "probability": "NEGLIGIBLE" + }, + { + "category": "HARM_CATEGORY_DANGEROUS_CONTENT", + "probability": "NEGLIGIBLE" + } + ] + } +} +'''; + final decoded = jsonDecode(response) as Object; + final generateContentResponse = + 
DeveloperSerialization().parseGenerateContentResponse(decoded); + expect( + generateContentResponse, + matchesGenerateContentResponse( + GenerateContentResponse( + [ + Candidate( + Content.model([ + TextPart('Mountain View, California, United States'), + ]), + [ + SafetyRating( + HarmCategory.sexuallyExplicit, + HarmProbability.negligible, + ), + SafetyRating( + HarmCategory.hateSpeech, + HarmProbability.negligible, + ), + SafetyRating( + HarmCategory.harassment, + HarmProbability.negligible, + ), + SafetyRating( + HarmCategory.dangerousContent, + HarmProbability.negligible, + ), + ], + null, + FinishReason.stop, + null, + ), + ], + PromptFeedback(null, null, [ + SafetyRating( + HarmCategory.sexuallyExplicit, + HarmProbability.negligible, + ), + SafetyRating(HarmCategory.hateSpeech, HarmProbability.negligible), + SafetyRating(HarmCategory.harassment, HarmProbability.negligible), + SafetyRating( + HarmCategory.dangerousContent, + HarmProbability.negligible, + ), + ]), + ), + ), + ); + }); + + test('with a citation', () async { + const response = ''' +{ + "candidates": [ + { + "content": { + "parts": [ + { + "text": "placeholder" + } + ], + "role": "model" + }, + "finishReason": "STOP", + "index": 0, + "safetyRatings": [ + { + "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", + "probability": "NEGLIGIBLE" + }, + { + "category": "HARM_CATEGORY_HATE_SPEECH", + "probability": "NEGLIGIBLE" + }, + { + "category": "HARM_CATEGORY_HARASSMENT", + "probability": "NEGLIGIBLE" + }, + { + "category": "HARM_CATEGORY_DANGEROUS_CONTENT", + "probability": "NEGLIGIBLE" + } + ], + "citationMetadata": { + "citationSources": [ + { + "startIndex": 574, + "endIndex": 705, + "uri": "https://example.com/", + "license": "" + }, + { + "startIndex": 899, + "endIndex": 1026, + "uri": "https://example.com/", + "license": "" + }, + { + "startIndex": 899, + "endIndex": 1026 + }, + { + "uri": "https://example.com/", + "license": "" + }, + {} + ] + } + } + ], + "promptFeedback": { + "safetyRatings": [ + { + "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", + "probability": "NEGLIGIBLE" + }, + { + "category": "HARM_CATEGORY_HATE_SPEECH", + "probability": "NEGLIGIBLE" + }, + { + "category": "HARM_CATEGORY_HARASSMENT", + "probability": "NEGLIGIBLE" + }, + { + "category": "HARM_CATEGORY_DANGEROUS_CONTENT", + "probability": "NEGLIGIBLE" + } + ] + } +} +'''; + final decoded = jsonDecode(response) as Object; + final generateContentResponse = + DeveloperSerialization().parseGenerateContentResponse(decoded); + expect( + generateContentResponse, + matchesGenerateContentResponse( + GenerateContentResponse( + [ + Candidate( + Content.model([TextPart('placeholder')]), + [ + SafetyRating( + HarmCategory.sexuallyExplicit, + HarmProbability.negligible, + ), + SafetyRating( + HarmCategory.hateSpeech, + HarmProbability.negligible, + ), + SafetyRating( + HarmCategory.harassment, + HarmProbability.negligible, + ), + SafetyRating( + HarmCategory.dangerousContent, + HarmProbability.negligible, + ), + ], + CitationMetadata([ + Citation(574, 705, Uri.https('example.com'), ''), + Citation(899, 1026, Uri.https('example.com'), ''), + ]), + FinishReason.stop, + null, + ), + ], + PromptFeedback(null, null, [ + SafetyRating( + HarmCategory.sexuallyExplicit, + HarmProbability.negligible, + ), + SafetyRating(HarmCategory.hateSpeech, HarmProbability.negligible), + SafetyRating(HarmCategory.harassment, HarmProbability.negligible), + SafetyRating( + HarmCategory.dangerousContent, + HarmProbability.negligible, + ), + ]), + ), + ), + ); + }); + + test('with a vertex 
formatted citation', () async { + const response = ''' +{ + "candidates": [ + { + "content": { + "parts": [ + { + "text": "placeholder" + } + ], + "role": "model" + }, + "finishReason": "STOP", + "index": 0, + "safetyRatings": [ + { + "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", + "probability": "NEGLIGIBLE" + }, + { + "category": "HARM_CATEGORY_HATE_SPEECH", + "probability": "NEGLIGIBLE" + }, + { + "category": "HARM_CATEGORY_HARASSMENT", + "probability": "NEGLIGIBLE" + }, + { + "category": "HARM_CATEGORY_DANGEROUS_CONTENT", + "probability": "NEGLIGIBLE" + } + ], + "citationMetadata": { + "citations": [ + { + "startIndex": 574, + "endIndex": 705, + "uri": "https://example.com/", + "license": "" + }, + { + "startIndex": 899, + "endIndex": 1026, + "uri": "https://example.com/", + "license": "" + }, + { + "startIndex": 899, + "endIndex": 1026 + }, + { + "uri": "https://example.com/", + "license": "" + }, + {} + ] + } + } + ], + "promptFeedback": { + "safetyRatings": [ + { + "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", + "probability": "NEGLIGIBLE" + }, + { + "category": "HARM_CATEGORY_HATE_SPEECH", + "probability": "NEGLIGIBLE" + }, + { + "category": "HARM_CATEGORY_HARASSMENT", + "probability": "NEGLIGIBLE" + }, + { + "category": "HARM_CATEGORY_DANGEROUS_CONTENT", + "probability": "NEGLIGIBLE" + } + ] + } +} +'''; + final decoded = jsonDecode(response) as Object; + final generateContentResponse = + DeveloperSerialization().parseGenerateContentResponse(decoded); + expect( + generateContentResponse, + matchesGenerateContentResponse( + GenerateContentResponse( + [ + Candidate( + Content.model([TextPart('placeholder')]), + [ + SafetyRating( + HarmCategory.sexuallyExplicit, + HarmProbability.negligible, + ), + SafetyRating( + HarmCategory.hateSpeech, + HarmProbability.negligible, + ), + SafetyRating( + HarmCategory.harassment, + HarmProbability.negligible, + ), + SafetyRating( + HarmCategory.dangerousContent, + HarmProbability.negligible, + ), + ], + CitationMetadata([ + Citation(574, 705, Uri.https('example.com'), ''), + Citation(899, 1026, Uri.https('example.com'), ''), + ]), + FinishReason.stop, + null, + ), + ], + PromptFeedback(null, null, [ + SafetyRating( + HarmCategory.sexuallyExplicit, + HarmProbability.negligible, + ), + SafetyRating(HarmCategory.hateSpeech, HarmProbability.negligible), + SafetyRating(HarmCategory.harassment, HarmProbability.negligible), + SafetyRating( + HarmCategory.dangerousContent, + HarmProbability.negligible, + ), + ]), + ), + ), + ); + }); + + test('with code execution', () async { + const response = ''' +{ + "candidates": [ + { + "content": { + "parts": [ + { + "executableCode": { + "language": "PYTHON", + "code": "print('hello world')" + } + }, + { + "codeExecutionResult": { + "outcome": "OUTCOME_OK", + "output": "hello world" + } + }, + { + "text": "hello world" + } + ], + "role": "model" + }, + "finishReason": "STOP", + "index": 0 + } + ] +} +'''; + final decoded = jsonDecode(response) as Object; + final generateContentResponse = + DeveloperSerialization().parseGenerateContentResponse(decoded); + expect( + generateContentResponse, + matchesGenerateContentResponse( + GenerateContentResponse( + [ + Candidate( + Content.model([ + // ExecutableCode(Language.python, 'print(\'hello world\')'), + // CodeExecutionResult(Outcome.ok, 'hello world'), + TextPart('hello world') + ]), + [], + null, + FinishReason.stop, + null, + ), + ], + null, + ), + ), + ); + }, skip: 'Code Execution Unsupported'); + + test('allows missing content', () async { + const response = ''' 
+{ + "candidates": [ + { + "finishReason": "SAFETY", + "index": 0, + "safetyRatings": [ + { + "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", + "probability": "NEGLIGIBLE" + }, + { + "category": "HARM_CATEGORY_HATE_SPEECH", + "probability": "LOW" + }, + { + "category": "HARM_CATEGORY_HARASSMENT", + "probability": "MEDIUM" + }, + { + "category": "HARM_CATEGORY_DANGEROUS_CONTENT", + "probability": "NEGLIGIBLE" + } + ] + } + ] +} +'''; + final decoded = jsonDecode(response) as Object; + final generateContentResponse = + DeveloperSerialization().parseGenerateContentResponse(decoded); + expect( + generateContentResponse, + matchesGenerateContentResponse( + GenerateContentResponse([ + Candidate( + Content(null, []), + [ + SafetyRating( + HarmCategory.sexuallyExplicit, + HarmProbability.negligible, + ), + SafetyRating( + HarmCategory.hateSpeech, HarmProbability.negligible), + SafetyRating( + HarmCategory.harassment, HarmProbability.negligible), + SafetyRating( + HarmCategory.dangerousContent, + HarmProbability.negligible, + ), + ], + CitationMetadata([]), + FinishReason.safety, + null), + ], null), + ), + ); + }); + + test('text getter joins content', () async { + const response = ''' +{ + "candidates": [ + { + "content": { + "parts": [ + { + "text": "Initial text" + }, + { + "functionCall": {"name": "someFunction", "args": {}} + }, + { + "text": " And more text" + } + ], + "role": "model" + }, + "finishReason": "STOP", + "index": 0 + } + ] +} +'''; + final decoded = jsonDecode(response) as Object; + final generateContentResponse = + DeveloperSerialization().parseGenerateContentResponse(decoded); + expect(generateContentResponse.text, 'Initial text And more text'); + expect(generateContentResponse.candidates.single.text, + 'Initial text And more text'); + }); + }); + + group('parses and throws error responses', () { + test('for invalid API key', () async { + const response = ''' +{ + "error": { + "code": 400, + "message": "API key not valid. Please pass a valid API key.", + "status": "INVALID_ARGUMENT", + "details": [ + { + "@type": "type.googleapis.com/google.rpc.ErrorInfo", + "reason": "API_KEY_INVALID", + "domain": "googleapis.com", + "metadata": { + "service": "generativelanguage.googleapis.com" + } + }, + { + "@type": "type.googleapis.com/google.rpc.DebugInfo", + "detail": "Invalid API key: AIzv00G7VmUCUeC-5OglO3hcXM" + } + ] + } +} +'''; + final decoded = jsonDecode(response) as Object; + final expectedThrow = throwsA( + isA().having( + (e) => e.message, + 'message', + 'API key not valid. Please pass a valid API key.', + ), + ); + expect( + () => DeveloperSerialization().parseGenerateContentResponse(decoded), + expectedThrow); + expect(() => DeveloperSerialization().parseCountTokensResponse(decoded), + expectedThrow); + // expect(() => parseEmbedContentResponse(decoded), expectedThrow); + }); + + test('for unsupported user location', () async { + const response = r''' +{ + "error": { + "code": 400, + "message": "User location is not supported for the API use.", + "status": "FAILED_PRECONDITION", + "details": [ + { + "@type": "type.googleapis.com/google.rpc.DebugInfo", + "detail": "[ORIGINAL ERROR] generic::failed_precondition: User location is not supported for the API use. 
[google.rpc.error_details_ext] { message: \"User location is not supported for the API use.\" }" + } + ] + } +} +'''; + final decoded = jsonDecode(response) as Object; + final expectedThrow = throwsA( + isA().having( + (e) => e.message, + 'message', + 'User location is not supported for the API use.', + ), + ); + expect( + () => DeveloperSerialization().parseGenerateContentResponse(decoded), + expectedThrow); + expect(() => DeveloperSerialization().parseCountTokensResponse(decoded), + expectedThrow); + // expect(() => parseEmbedContentResponse(decoded), expectedThrow); + }); + + test('for general server errors', () async { + const response = r''' +{ + "error": { + "code": 404, + "message": "models/unknown is not found for API version v1, or is not supported for GenerateContent. Call ListModels to see the list of available models and their supported methods.", + "status": "NOT_FOUND", + "details": [ + { + "@type": "type.googleapis.com/google.rpc.DebugInfo", + "detail": "[ORIGINAL ERROR] generic::not_found: models/unknown is not found for API version v1, or is not supported for GenerateContent. Call ListModels to see the list of available models and their supported methods. [google.rpc.error_details_ext] { message: \"models/unknown is not found for API version v1, or is not supported for GenerateContent. Call ListModels to see the list of available models and their supported methods.\" }" + } + ] + } +} +'''; + final decoded = jsonDecode(response) as Object; + final expectedThrow = throwsA( + isA().having( + (e) => e.message, + 'message', + startsWith( + 'models/unknown is not found for API version v1, ' + 'or is not supported for GenerateContent.', + ), + ), + ); + expect( + () => DeveloperSerialization().parseGenerateContentResponse(decoded), + expectedThrow); + expect(() => DeveloperSerialization().parseCountTokensResponse(decoded), + expectedThrow); + // expect(() => parseEmbedContentResponse(decoded), expectedThrow); + }); + }); +} diff --git a/packages/firebase_vertexai/firebase_vertexai/test/model_test.dart b/packages/firebase_vertexai/firebase_vertexai/test/model_test.dart index 57c3fbaaa5ed..86e20ff99421 100644 --- a/packages/firebase_vertexai/firebase_vertexai/test/model_test.dart +++ b/packages/firebase_vertexai/firebase_vertexai/test/model_test.dart @@ -39,6 +39,7 @@ void main() { }) { final client = ClientController(); final model = createModelWithClient( + useVertexBackend: true, app: app, model: modelName, client: client.client, diff --git a/packages/firebase_vertexai/firebase_vertexai/test/response_parsing_test.dart b/packages/firebase_vertexai/firebase_vertexai/test/response_parsing_test.dart index d5c8d64f16d5..505031344b00 100644 --- a/packages/firebase_vertexai/firebase_vertexai/test/response_parsing_test.dart +++ b/packages/firebase_vertexai/firebase_vertexai/test/response_parsing_test.dart @@ -56,7 +56,7 @@ void main() { '''; final decoded = jsonDecode(response) as Object; expect( - () => parseGenerateContentResponse(decoded), + () => VertexSerialization().parseGenerateContentResponse(decoded), throwsA( isA().having( (e) => e.message, @@ -94,7 +94,8 @@ void main() { } '''; final decoded = jsonDecode(response) as Object; - final generateContentResponse = parseGenerateContentResponse(decoded); + final generateContentResponse = + VertexSerialization().parseGenerateContentResponse(decoded); expect( generateContentResponse, matchesGenerateContentResponse( @@ -158,7 +159,7 @@ void main() { '''; final decoded = jsonDecode(response) as Object; expect( - () => 
parseGenerateContentResponse(decoded), + () => VertexSerialization().parseGenerateContentResponse(decoded), throwsA( isA().having( (e) => e.message, @@ -206,7 +207,7 @@ void main() { '''; final decoded = jsonDecode(response) as Object; expect( - () => parseGenerateContentResponse(decoded), + () => VertexSerialization().parseGenerateContentResponse(decoded), throwsA( isA().having( (e) => e.message, @@ -277,7 +278,8 @@ void main() { } '''; final decoded = jsonDecode(response) as Object; - final generateContentResponse = parseGenerateContentResponse(decoded); + final generateContentResponse = + VertexSerialization().parseGenerateContentResponse(decoded); expect( generateContentResponse, matchesGenerateContentResponse( @@ -410,7 +412,8 @@ void main() { } '''; final decoded = jsonDecode(response) as Object; - final generateContentResponse = parseGenerateContentResponse(decoded); + final generateContentResponse = + VertexSerialization().parseGenerateContentResponse(decoded); expect( generateContentResponse, matchesGenerateContentResponse( @@ -544,7 +547,8 @@ void main() { } '''; final decoded = jsonDecode(response) as Object; - final generateContentResponse = parseGenerateContentResponse(decoded); + final generateContentResponse = + VertexSerialization().parseGenerateContentResponse(decoded); expect( generateContentResponse, matchesGenerateContentResponse( @@ -625,7 +629,8 @@ void main() { } '''; final decoded = jsonDecode(response) as Object; - final generateContentResponse = parseGenerateContentResponse(decoded); + final generateContentResponse = + VertexSerialization().parseGenerateContentResponse(decoded); expect( generateContentResponse, matchesGenerateContentResponse( @@ -685,7 +690,8 @@ void main() { } '''; final decoded = jsonDecode(response) as Object; - final generateContentResponse = parseGenerateContentResponse(decoded); + final generateContentResponse = + VertexSerialization().parseGenerateContentResponse(decoded); expect( generateContentResponse.text, 'Here is a description of the image:'); expect(generateContentResponse.usageMetadata?.totalTokenCount, 1913); @@ -722,7 +728,8 @@ void main() { } '''; final decoded = jsonDecode(response) as Object; - final countTokensResponse = parseCountTokensResponse(decoded); + final countTokensResponse = + VertexSerialization().parseCountTokensResponse(decoded); expect(countTokensResponse.totalTokens, 1837); expect(countTokensResponse.promptTokensDetails?.first.modality, ContentModality.image); @@ -755,7 +762,8 @@ void main() { } '''; final decoded = jsonDecode(response) as Object; - final generateContentResponse = parseGenerateContentResponse(decoded); + final generateContentResponse = + VertexSerialization().parseGenerateContentResponse(decoded); expect(generateContentResponse.text, 'Initial text And more text'); expect(generateContentResponse.candidates.single.text, 'Initial text And more text'); @@ -795,8 +803,10 @@ void main() { 'API key not valid. 
Please pass a valid API key.', ), ); - expect(() => parseGenerateContentResponse(decoded), expectedThrow); - expect(() => parseCountTokensResponse(decoded), expectedThrow); + expect(() => VertexSerialization().parseGenerateContentResponse(decoded), + expectedThrow); + expect(() => VertexSerialization().parseCountTokensResponse(decoded), + expectedThrow); }); test('for unsupported user location', () async { @@ -823,8 +833,10 @@ void main() { 'User location is not supported for the API use.', ), ); - expect(() => parseGenerateContentResponse(decoded), expectedThrow); - expect(() => parseCountTokensResponse(decoded), expectedThrow); + expect(() => VertexSerialization().parseGenerateContentResponse(decoded), + expectedThrow); + expect(() => VertexSerialization().parseCountTokensResponse(decoded), + expectedThrow); }); test('for general server errors', () async { @@ -854,8 +866,10 @@ void main() { ), ), ); - expect(() => parseGenerateContentResponse(decoded), expectedThrow); - expect(() => parseCountTokensResponse(decoded), expectedThrow); + expect(() => VertexSerialization().parseGenerateContentResponse(decoded), + expectedThrow); + expect(() => VertexSerialization().parseCountTokensResponse(decoded), + expectedThrow); }); }); }
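For quick reference, the updated tests above call the parsers through the new serialization-strategy classes rather than the removed top-level parseGenerateContentResponse/parseCountTokensResponse functions. Below is a minimal sketch of that call pattern, assuming the package-internal import paths where this diff declares the classes (lib/src/api.dart for VertexSerialization, lib/src/developer/api.dart for DeveloperSerialization); it is illustrative only, not part of the public API surface.

// Sketch only: mirrors the call pattern exercised by the tests above.
// The import paths are assumptions based on where this diff declares
// VertexSerialization (lib/src/api.dart) and DeveloperSerialization
// (lib/src/developer/api.dart).
import 'dart:convert';

import 'package:firebase_vertexai/src/api.dart';
import 'package:firebase_vertexai/src/developer/api.dart';

void main() {
  // A minimal, well-formed response body of the shape the parsers accept.
  const body = '''
{"candidates": [{"content": {"role": "model",
  "parts": [{"text": "Hello"}]}, "finishReason": "STOP", "index": 0}]}
''';
  final decoded = jsonDecode(body) as Object;

  // Vertex AI backend responses are parsed with VertexSerialization.
  final vertexResponse =
      VertexSerialization().parseGenerateContentResponse(decoded);
  print(vertexResponse.text); // Hello

  // Google AI (developer API) backend responses use DeveloperSerialization.
  final developerResponse =
      DeveloperSerialization().parseGenerateContentResponse(decoded);
  print(developerResponse.text); // Hello
}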