diff --git a/assets/huggingface_models.json b/assets/huggingface_models.json index 184cd88b..bbfa342b 100644 --- a/assets/huggingface_models.json +++ b/assets/huggingface_models.json @@ -1,6 +1,5 @@ [ { - "id": "llama_2_7b_chat", "name": "Llama 2 7B Chat", "family": "llama2", "template": "llama2", diff --git a/lib/classes/huggingface_model.dart b/lib/classes/huggingface_model.dart new file mode 100644 index 00000000..03942978 --- /dev/null +++ b/lib/classes/huggingface_model.dart @@ -0,0 +1,44 @@ +import 'dart:convert'; + +import 'package:flutter/services.dart'; +import 'package:maid/classes/static/logger.dart'; + +class HuggingfaceModel { + final String name; + final String family; + final String template; + final String repo; + final Map tags; + + HuggingfaceModel({ + required this.name, + required this.family, + required this.template, + required this.repo, + required this.tags, + }); + + factory HuggingfaceModel.fromJson(Map json) { + return HuggingfaceModel( + name: json['name'], + family: json['family'], + template: json['template'], + repo: json['repo'], + tags: Map.from(json['tags']), + ); + } + + static Future> getAll() async { + try { + final jsonString = await rootBundle.loadString('assets/huggingface_models.json'); + + final List jsonList = json.decode(jsonString); + + return jsonList.map((e) => HuggingfaceModel.fromJson(e)).toList(); + } + catch (e) { + Logger.log('Error loading models: $e'); + return []; + } + } +} \ No newline at end of file diff --git a/lib/ui/desktop/buttons/huggingface_button.dart b/lib/ui/desktop/buttons/huggingface_button.dart index a09778ad..dff21a3a 100644 --- a/lib/ui/desktop/buttons/huggingface_button.dart +++ b/lib/ui/desktop/buttons/huggingface_button.dart @@ -1,5 +1,6 @@ import 'package:flutter/material.dart'; import 'package:flutter_svg/svg.dart'; +import 'package:maid/ui/desktop/dropdowns/huggingface_model_dropdown.dart'; class HuggingfaceButton extends StatefulWidget { const HuggingfaceButton({super.key}); @@ -33,7 +34,7 @@ class _HuggingfaceButtonState extends State { 'Select HuggingFace Model', textAlign: TextAlign.center ), - content: const Text('This is a HuggingFace button.'), + content: const HuggingfaceModelDropdown(), actions: [ FilledButton( onPressed: () => Navigator.of(context).pop(), diff --git a/lib/ui/desktop/dropdowns/huggingface_model_dropdown.dart b/lib/ui/desktop/dropdowns/huggingface_model_dropdown.dart new file mode 100644 index 00000000..193a9256 --- /dev/null +++ b/lib/ui/desktop/dropdowns/huggingface_model_dropdown.dart @@ -0,0 +1,156 @@ +import 'package:flutter/material.dart'; +import 'package:maid/classes/huggingface_model.dart'; + +class HuggingfaceModelDropdown extends StatefulWidget { + const HuggingfaceModelDropdown({super.key}); + + @override + State createState() => _HuggingfaceModelDropdownState(); +} + +class _HuggingfaceModelDropdownState extends State { + List options = []; + HuggingfaceModel? selectedModel; + String? selectedTag; + bool modelOpen = false; + bool tagOpen = false; + + @override + Widget build(BuildContext context) { + if (options.isEmpty) { + return buildFuture(); + } + + return buildRow(context); + } + + Widget buildFuture() { + return FutureBuilder>( + future: HuggingfaceModel.getAll(), + builder: (context, snapshot) { + if (snapshot.connectionState == ConnectionState.done) { + options = snapshot.data!; + return buildRow(context); + } + else { + return buildLoading(); + } + } + ); + } + + Widget buildRow(BuildContext context) { + return Row( + mainAxisAlignment: MainAxisAlignment.spaceBetween, + children: [ + buildModelRow(context), + buildTagRow(context) + ] + ); + } + + Widget buildModelRow(BuildContext context) { + return Row( + mainAxisSize: MainAxisSize.min, + children: [ + Text( + selectedModel?.name ?? 'Select Model', + style: TextStyle( + color: Theme.of(context).colorScheme.onSurface, + fontSize: 16, + fontWeight: FontWeight.bold + ) + ), + PopupMenuButton( + tooltip: 'Select HuggingFace Model', + icon: Icon( + modelOpen ? Icons.arrow_drop_up : Icons.arrow_drop_down, + color: Theme.of(context).colorScheme.onSurface, + size: 24, + ), + offset: const Offset(0, 40), + itemBuilder: modelItemBuilder, + onOpened: () => setState(() => modelOpen = true), + onCanceled: () => setState(() => modelOpen = false), + onSelected: (value) => selectModel(value) + ) + ] + ); + } + + void selectModel(HuggingfaceModel model) { + setState(() { + modelOpen = false; + selectedModel = model; + selectedTag = model.tags.keys.first; + }); + } + + Widget buildTagRow(BuildContext context) { + return Row( + mainAxisSize: MainAxisSize.min, + children: [ + Text( + selectedTag ?? 'Select Tag', + style: TextStyle( + color: Theme.of(context).colorScheme.onSurface, + fontSize: 16, + fontWeight: FontWeight.bold + ) + ), + PopupMenuButton( + tooltip: 'Select HuggingFace Model Tag', + icon: Icon( + tagOpen ? Icons.arrow_drop_up : Icons.arrow_drop_down, + color: Theme.of(context).colorScheme.onSurface, + size: 24, + ), + offset: const Offset(0, 40), + itemBuilder: tagItemBuilder, + onOpened: () => setState(() => tagOpen = true), + onCanceled: () => setState(() => tagOpen = false), + onSelected: (value) => selectTag(value) + ) + ] + ); + } + + void selectTag(String tag) { + setState(() { + tagOpen = false; + selectedTag = tag; + }); + } + + List> modelItemBuilder(BuildContext context) { + return options.map((HuggingfaceModel model) { + return PopupMenuItem( + value: model, + child: Text(model.name), + onTap: () => selectModel(model), + ); + }).toList(); + } + + List> tagItemBuilder(BuildContext context) { + return selectedModel?.tags.keys.toList().map((String tag) { + return PopupMenuItem( + value: tag, + child: Text(tag), + onTap: () => selectTag(tag), + ); + }).toList() ?? []; + } + + Widget buildLoading() { + return const SizedBox( + width: 24, + height: 24, + child: Center( + child: CircularProgressIndicator( + strokeWidth: 3.0 + ), + ), + ); + } +} \ No newline at end of file