diff --git a/packages/firebase_vertexai/firebase_vertexai/example/.gitignore b/packages/firebase_vertexai/firebase_vertexai/example/.gitignore index 0498b592dfa0..53bed76d8faa 100644 --- a/packages/firebase_vertexai/firebase_vertexai/example/.gitignore +++ b/packages/firebase_vertexai/firebase_vertexai/example/.gitignore @@ -48,3 +48,4 @@ app.*.map.json firebase_options.dart google-services.json GoogleService-Info.plist +firebase.json diff --git a/packages/firebase_vertexai/firebase_vertexai/example/android/settings.gradle b/packages/firebase_vertexai/firebase_vertexai/example/android/settings.gradle index 9151bc043341..40cbd22bb13b 100644 --- a/packages/firebase_vertexai/firebase_vertexai/example/android/settings.gradle +++ b/packages/firebase_vertexai/firebase_vertexai/example/android/settings.gradle @@ -19,6 +19,9 @@ pluginManagement { plugins { id "dev.flutter.flutter-plugin-loader" version "1.0.0" id "com.android.application" version "7.3.0" apply false + // START: FlutterFire Configuration + id "com.google.gms.google-services" version "4.3.15" apply false + // END: FlutterFire Configuration id "org.jetbrains.kotlin.android" version "1.9.22" apply false } diff --git a/packages/firebase_vertexai/firebase_vertexai/example/lib/main.dart b/packages/firebase_vertexai/firebase_vertexai/example/lib/main.dart index e6d880e15427..d4246d4211e1 100644 --- a/packages/firebase_vertexai/firebase_vertexai/example/lib/main.dart +++ b/packages/firebase_vertexai/firebase_vertexai/example/lib/main.dart @@ -38,23 +38,78 @@ void main() async { // await Firebase.initializeApp(options: DefaultFirebaseOptions.currentPlatform); await Firebase.initializeApp(); await FirebaseAuth.instance.signInAnonymously(); + runApp(const GenerativeAISample()); +} - var vertexInstance = - FirebaseVertexAI.instanceFor(auth: FirebaseAuth.instance); - final model = vertexInstance.generativeModel(model: 'gemini-1.5-flash'); +class GenerativeAISample extends StatefulWidget { + const GenerativeAISample({super.key}); - runApp(GenerativeAISample(model: model)); + @override + State createState() => _GenerativeAISampleState(); } -class GenerativeAISample extends StatelessWidget { - final GenerativeModel model; +class _GenerativeAISampleState extends State { + bool _useVertexBackend = false; + late GenerativeModel _currentModel; + late ImagenModel _currentImagenModel; + int _currentBottomNavIndex = 0; + + @override + void initState() { + super.initState(); + + _initializeModel(_useVertexBackend); + } - const GenerativeAISample({super.key, required this.model}); + void _initializeModel(bool useVertexBackend) { + if (useVertexBackend) { + final vertexInstance = + FirebaseVertexAI.instanceFor(auth: FirebaseAuth.instance); + _currentModel = vertexInstance.generativeModel(model: 'gemini-1.5-flash'); + _currentImagenModel = _initializeImagenModel(vertexInstance); + } else { + final googleAI = FirebaseVertexAI.googleAI(auth: FirebaseAuth.instance); + _currentModel = googleAI.generativeModel(model: 'gemini-2.0-flash'); + _currentImagenModel = _initializeImagenModel(googleAI); + } + } + + ImagenModel _initializeImagenModel(FirebaseVertexAI instance) { + var generationConfig = ImagenGenerationConfig( + negativePrompt: 'frog', + numberOfImages: 1, + aspectRatio: ImagenAspectRatio.square1x1, + imageFormat: ImagenFormat.jpeg(compressionQuality: 75), + ); + return instance.imagenModel( + model: 'imagen-3.0-generate-001', + generationConfig: generationConfig, + safetySettings: ImagenSafetySettings( + ImagenSafetyFilterLevel.blockLowAndAbove, + ImagenPersonFilterLevel.allowAdult, + ), + ); + } + + void _toggleBackend(bool value) { + setState(() { + _useVertexBackend = value; + }); + _initializeModel(_useVertexBackend); + } + + void _onBottomNavTapped(int index) { + setState(() { + _currentBottomNavIndex = index; + }); + } @override Widget build(BuildContext context) { return MaterialApp( - title: 'Flutter + Vertex AI', + title: 'Flutter + ${_useVertexBackend ? 'Vertex AI' : 'Google AI'}', + debugShowCheckedModeBanner: false, + themeMode: ThemeMode.dark, theme: ThemeData( colorScheme: ColorScheme.fromSeed( brightness: Brightness.dark, @@ -62,137 +117,204 @@ class GenerativeAISample extends StatelessWidget { ), useMaterial3: true, ), - home: HomeScreen(model: model), + home: HomeScreen( + key: ValueKey( + '${_useVertexBackend}_${_currentModel.hashCode}', + ), + model: _currentModel, + imagenModel: _currentImagenModel, + useVertexBackend: _useVertexBackend, + onBackendChanged: _toggleBackend, + selectedIndex: _currentBottomNavIndex, + onSelectedIndexChanged: _onBottomNavTapped, + ), ); } } class HomeScreen extends StatefulWidget { final GenerativeModel model; - const HomeScreen({super.key, required this.model}); + final ImagenModel imagenModel; + final bool useVertexBackend; + final ValueChanged onBackendChanged; + final int selectedIndex; + final ValueChanged onSelectedIndexChanged; + + const HomeScreen({ + super.key, + required this.model, + required this.imagenModel, + required this.useVertexBackend, + required this.onBackendChanged, + required this.selectedIndex, + required this.onSelectedIndexChanged, + }); @override State createState() => _HomeScreenState(); } class _HomeScreenState extends State { - int _selectedIndex = 0; - - List get _pages => [ - // Build _pages dynamically - ChatPage(title: 'Chat', model: widget.model), - AudioPage(title: 'Audio', model: widget.model), - TokenCountPage(title: 'Token Count', model: widget.model), - const FunctionCallingPage( - title: 'Function Calling', - ), // function calling will initial its own model - ImagePromptPage(title: 'Image Prompt', model: widget.model), - ImagenPage(title: 'Imagen Model', model: widget.model), - SchemaPromptPage(title: 'Schema Prompt', model: widget.model), - DocumentPage(title: 'Document Prompt', model: widget.model), - VideoPage(title: 'Video Prompt', model: widget.model), - BidiPage(title: 'Bidi Stream', model: widget.model), - ]; - void _onItemTapped(int index) { - setState(() { - _selectedIndex = index; - }); + widget.onSelectedIndexChanged(index); + } + +// Method to build the selected page on demand + Widget _buildSelectedPage( + int index, + GenerativeModel currentModel, + ImagenModel currentImagenModel, + bool useVertexBackend, + ) { + switch (index) { + case 0: + return ChatPage(title: 'Chat', model: currentModel); + case 1: + return AudioPage(title: 'Audio', model: currentModel); + case 2: + return TokenCountPage(title: 'Token Count', model: currentModel); + case 3: + // FunctionCallingPage initializes its own model as per original design + return FunctionCallingPage( + title: 'Function Calling', + useVertexBackend: useVertexBackend, + ); + case 4: + return ImagePromptPage(title: 'Image Prompt', model: currentModel); + case 5: + return ImagenPage(title: 'Imagen Model', model: currentImagenModel); + case 6: + return SchemaPromptPage(title: 'Schema Prompt', model: currentModel); + case 7: + return DocumentPage(title: 'Document Prompt', model: currentModel); + case 8: + return VideoPage(title: 'Video Prompt', model: currentModel); + case 9: + return BidiPage(title: 'Bidi Stream', model: currentModel); + default: + // Fallback to the first page in case of an unexpected index + return ChatPage(title: 'Chat', model: currentModel); + } } @override Widget build(BuildContext context) { return Scaffold( appBar: AppBar( - title: const Text('Flutter + Vertex AI'), + title: Text( + 'Flutter + ${widget.useVertexBackend ? 'Vertex AI' : 'Google AI'}', + ), + actions: [ + Padding( + padding: const EdgeInsets.symmetric(horizontal: 16), + child: Row( + mainAxisSize: MainAxisSize.min, + children: [ + Text( + 'Google AI', + style: TextStyle( + fontSize: 12, + color: widget.useVertexBackend + ? Theme.of(context) + .colorScheme + .onSurface + .withValues(alpha: 0.7) + : Theme.of(context).colorScheme.primary, + ), + ), + Switch( + value: widget.useVertexBackend, + onChanged: widget.onBackendChanged, + activeTrackColor: Colors.green.withValues(alpha: 0.5), + inactiveTrackColor: Colors.blueGrey.withValues(alpha: 0.5), + activeColor: Colors.green, + inactiveThumbColor: Colors.blueGrey, + ), + Text( + 'Vertex AI', + style: TextStyle( + fontSize: 12, + color: widget.useVertexBackend + ? Theme.of(context).colorScheme.primary + : Theme.of(context) + .colorScheme + .onSurface + .withValues(alpha: 0.7), + ), + ), + ], + ), + ), + ], ), body: Center( - child: _pages.elementAt(_selectedIndex), + child: _buildSelectedPage( + widget.selectedIndex, + widget.model, + widget.imagenModel, + widget.useVertexBackend, + ), ), bottomNavigationBar: BottomNavigationBar( - items: [ + type: BottomNavigationBarType.fixed, + selectedFontSize: 10, + unselectedFontSize: 9, + selectedItemColor: Theme.of(context).colorScheme.primary, + unselectedItemColor: + Theme.of(context).colorScheme.onSurface.withValues(alpha: 0.7), + items: const [ BottomNavigationBarItem( - icon: Icon( - Icons.chat, - color: Theme.of(context).colorScheme.primary, - ), + icon: Icon(Icons.chat), label: 'Chat', tooltip: 'Chat', ), BottomNavigationBarItem( - icon: Icon( - Icons.mic, - color: Theme.of(context).colorScheme.primary, - ), - label: 'Audio Prompt', + icon: Icon(Icons.mic), + label: 'Audio', tooltip: 'Audio Prompt', ), BottomNavigationBarItem( - icon: Icon( - Icons.numbers, - color: Theme.of(context).colorScheme.primary, - ), - label: 'Token Count', + icon: Icon(Icons.numbers), + label: 'Tokens', tooltip: 'Token Count', ), BottomNavigationBarItem( - icon: Icon( - Icons.functions, - color: Theme.of(context).colorScheme.primary, - ), - label: 'Function Calling', + icon: Icon(Icons.functions), + label: 'Functions', tooltip: 'Function Calling', ), BottomNavigationBarItem( - icon: Icon( - Icons.image, - color: Theme.of(context).colorScheme.primary, - ), - label: 'Image Prompt', + icon: Icon(Icons.image), + label: 'Image', tooltip: 'Image Prompt', ), BottomNavigationBarItem( - icon: Icon( - Icons.image_search, - color: Theme.of(context).colorScheme.primary, - ), - label: 'Imagen Model', + icon: Icon(Icons.image_search), + label: 'Imagen', tooltip: 'Imagen Model', ), BottomNavigationBarItem( - icon: Icon( - Icons.schema, - color: Theme.of(context).colorScheme.primary, - ), - label: 'Schema Prompt', + icon: Icon(Icons.schema), + label: 'Schema', tooltip: 'Schema Prompt', ), BottomNavigationBarItem( - icon: Icon( - Icons.edit_document, - color: Theme.of(context).colorScheme.primary, - ), - label: 'Document Prompt', + icon: Icon(Icons.edit_document), + label: 'Document', tooltip: 'Document Prompt', ), BottomNavigationBarItem( - icon: Icon( - Icons.video_collection, - color: Theme.of(context).colorScheme.primary, - ), - label: 'Video Prompt', + icon: Icon(Icons.video_collection), + label: 'Video', tooltip: 'Video Prompt', ), BottomNavigationBarItem( - icon: Icon( - Icons.stream, - color: Theme.of(context).colorScheme.primary, - ), - label: 'Bidi Stream', + icon: Icon(Icons.stream), + label: 'Bidi', tooltip: 'Bidi Stream', ), ], - currentIndex: _selectedIndex, + currentIndex: widget.selectedIndex, onTap: _onItemTapped, ), ); diff --git a/packages/firebase_vertexai/firebase_vertexai/example/lib/pages/function_calling_page.dart b/packages/firebase_vertexai/firebase_vertexai/example/lib/pages/function_calling_page.dart index 130afff5ce92..dce322c1822d 100644 --- a/packages/firebase_vertexai/firebase_vertexai/example/lib/pages/function_calling_page.dart +++ b/packages/firebase_vertexai/firebase_vertexai/example/lib/pages/function_calling_page.dart @@ -18,9 +18,14 @@ import 'package:firebase_auth/firebase_auth.dart'; import '../widgets/message_widget.dart'; class FunctionCallingPage extends StatefulWidget { - const FunctionCallingPage({super.key, required this.title}); + const FunctionCallingPage({ + super.key, + required this.title, + required this.useVertexBackend, + }); final String title; + final bool useVertexBackend; @override State createState() => _FunctionCallingPageState(); @@ -41,14 +46,23 @@ class _FunctionCallingPageState extends State { @override void initState() { super.initState(); - var vertex_instance = - FirebaseVertexAI.instanceFor(auth: FirebaseAuth.instance); - _functionCallModel = vertex_instance.generativeModel( - model: 'gemini-1.5-flash', - tools: [ - Tool.functionDeclarations([fetchWeatherTool]), - ], - ); + if (widget.useVertexBackend) { + var vertexAI = FirebaseVertexAI.instanceFor(auth: FirebaseAuth.instance); + _functionCallModel = vertexAI.generativeModel( + model: 'gemini-2.0-flash', + tools: [ + Tool.functionDeclarations([fetchWeatherTool]), + ], + ); + } else { + var googleAI = FirebaseVertexAI.googleAI(auth: FirebaseAuth.instance); + _functionCallModel = googleAI.generativeModel( + model: 'gemini-2.0-flash', + tools: [ + Tool.functionDeclarations([fetchWeatherTool]), + ], + ); + } } // This is a hypothetical API to return a fake weather data collection for @@ -146,7 +160,7 @@ class _FunctionCallingPageState extends State { _loading = true; }); final functionCallChat = _functionCallModel.startChat(); - const prompt = 'What is the weather like in Boston on 10/02 this year?'; + const prompt = 'What is the weather like in Boston on 10/02 in year 2024?'; // Send the message to the generative model. var response = await functionCallChat.sendMessage( diff --git a/packages/firebase_vertexai/firebase_vertexai/example/lib/pages/imagen_page.dart b/packages/firebase_vertexai/firebase_vertexai/example/lib/pages/imagen_page.dart index bb08a4b5533a..b9ac03669667 100644 --- a/packages/firebase_vertexai/firebase_vertexai/example/lib/pages/imagen_page.dart +++ b/packages/firebase_vertexai/firebase_vertexai/example/lib/pages/imagen_page.dart @@ -25,7 +25,7 @@ class ImagenPage extends StatefulWidget { }); final String title; - final GenerativeModel model; + final ImagenModel model; @override State createState() => _ImagenPageState(); @@ -37,26 +37,6 @@ class _ImagenPageState extends State { final FocusNode _textFieldFocus = FocusNode(); final List _generatedContent = []; bool _loading = false; - late final ImagenModel _imagenModel; - - @override - void initState() { - super.initState(); - var generationConfig = ImagenGenerationConfig( - negativePrompt: 'frog', - numberOfImages: 1, - aspectRatio: ImagenAspectRatio.square1x1, - imageFormat: ImagenFormat.jpeg(compressionQuality: 75), - ); - _imagenModel = FirebaseVertexAI.instance.imagenModel( - model: 'imagen-3.0-generate-001', - generationConfig: generationConfig, - safetySettings: ImagenSafetySettings( - ImagenSafetyFilterLevel.blockLowAndAbove, - ImagenPersonFilterLevel.allowAdult, - ), - ); - } void _scrollDown() { WidgetsBinding.instance.addPostFrameCallback( @@ -153,7 +133,7 @@ class _ImagenPageState extends State { _loading = true; }); - var response = await _imagenModel.generateImages(prompt); + var response = await widget.model.generateImages(prompt); if (response.images.isNotEmpty) { var imagenImage = response.images[0]; @@ -181,7 +161,7 @@ class _ImagenPageState extends State { // }); // var gcsUrl = 'gs://vertex-ai-example-ef5a2.appspot.com/imagen'; - // var response = await _imagenModel.generateImagesGCS(prompt, gcsUrl); + // var response = await widget.model.generateImagesGCS(prompt, gcsUrl); // if (response.images.isNotEmpty) { // var imagenImage = response.images[0]; diff --git a/packages/firebase_vertexai/firebase_vertexai/lib/src/firebase_vertexai.dart b/packages/firebase_vertexai/firebase_vertexai/lib/src/firebase_vertexai.dart index fb5cb3b18e1d..1c3f1e21c713 100644 --- a/packages/firebase_vertexai/firebase_vertexai/lib/src/firebase_vertexai.dart +++ b/packages/firebase_vertexai/firebase_vertexai/lib/src/firebase_vertexai.dart @@ -82,9 +82,10 @@ class FirebaseVertexAI extends FirebasePluginPlatform { String? location, }) { app ??= Firebase.app(); + var instanceKey = '${app.name}::vertexai'; - if (_cachedInstances.containsKey(app.name)) { - return _cachedInstances[app.name]!; + if (_cachedInstances.containsKey(instanceKey)) { + return _cachedInstances[instanceKey]!; } location ??= _defaultLocation; @@ -96,7 +97,7 @@ class FirebaseVertexAI extends FirebasePluginPlatform { auth: auth, useVertexBackend: true, ); - _cachedInstances[app.name] = newInstance; + _cachedInstances[instanceKey] = newInstance; return newInstance; } @@ -111,9 +112,10 @@ class FirebaseVertexAI extends FirebasePluginPlatform { FirebaseAuth? auth, }) { app ??= Firebase.app(); + var instanceKey = '${app.name}::googleai'; - if (_cachedInstances.containsKey(app.name)) { - return _cachedInstances[app.name]!; + if (_cachedInstances.containsKey(instanceKey)) { + return _cachedInstances[instanceKey]!; } FirebaseVertexAI newInstance = FirebaseVertexAI._( @@ -123,7 +125,7 @@ class FirebaseVertexAI extends FirebasePluginPlatform { auth: auth, useVertexBackend: false, ); - _cachedInstances[app.name] = newInstance; + _cachedInstances[instanceKey] = newInstance; return newInstance; }