Skip to content

Commit

Permalink
huggingface
Browse files Browse the repository at this point in the history
  • Loading branch information
danemadsen committed Jun 25, 2024
1 parent e4d4865 commit 96bb712
Show file tree
Hide file tree
Showing 4 changed files with 202 additions and 2 deletions.
1 change: 0 additions & 1 deletion assets/huggingface_models.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
[
{
"id": "llama_2_7b_chat",
"name": "Llama 2 7B Chat",
"family": "llama2",
"template": "llama2",
Expand Down
44 changes: 44 additions & 0 deletions lib/classes/huggingface_model.dart
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import 'dart:convert';

import 'package:flutter/services.dart';
import 'package:maid/classes/static/logger.dart';

class HuggingfaceModel {
final String name;
final String family;
final String template;
final String repo;
final Map<String, String> tags;

HuggingfaceModel({
required this.name,
required this.family,
required this.template,
required this.repo,
required this.tags,
});

factory HuggingfaceModel.fromJson(Map<String, dynamic> json) {
return HuggingfaceModel(
name: json['name'],
family: json['family'],
template: json['template'],
repo: json['repo'],
tags: Map<String, String>.from(json['tags']),
);
}

static Future<List<HuggingfaceModel>> getAll() async {
try {
final jsonString = await rootBundle.loadString('assets/huggingface_models.json');

final List<dynamic> jsonList = json.decode(jsonString);

return jsonList.map((e) => HuggingfaceModel.fromJson(e)).toList();
}
catch (e) {
Logger.log('Error loading models: $e');
return [];
}
}
}
3 changes: 2 additions & 1 deletion lib/ui/desktop/buttons/huggingface_button.dart
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import 'package:flutter/material.dart';
import 'package:flutter_svg/svg.dart';
import 'package:maid/ui/desktop/dropdowns/huggingface_model_dropdown.dart';

class HuggingfaceButton extends StatefulWidget {
const HuggingfaceButton({super.key});
Expand Down Expand Up @@ -33,7 +34,7 @@ class _HuggingfaceButtonState extends State<HuggingfaceButton> {
'Select HuggingFace Model',
textAlign: TextAlign.center
),
content: const Text('This is a HuggingFace button.'),
content: const HuggingfaceModelDropdown(),
actions: [
FilledButton(
onPressed: () => Navigator.of(context).pop(),
Expand Down
156 changes: 156 additions & 0 deletions lib/ui/desktop/dropdowns/huggingface_model_dropdown.dart
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
import 'package:flutter/material.dart';
import 'package:maid/classes/huggingface_model.dart';

class HuggingfaceModelDropdown extends StatefulWidget {
const HuggingfaceModelDropdown({super.key});

@override
State<HuggingfaceModelDropdown> createState() => _HuggingfaceModelDropdownState();
}

class _HuggingfaceModelDropdownState extends State<HuggingfaceModelDropdown> {
List<HuggingfaceModel> options = [];
HuggingfaceModel? selectedModel;
String? selectedTag;
bool modelOpen = false;
bool tagOpen = false;

@override
Widget build(BuildContext context) {
if (options.isEmpty) {
return buildFuture();
}

return buildRow(context);
}

Widget buildFuture() {
return FutureBuilder<List<HuggingfaceModel>>(
future: HuggingfaceModel.getAll(),
builder: (context, snapshot) {
if (snapshot.connectionState == ConnectionState.done) {
options = snapshot.data!;
return buildRow(context);
}
else {
return buildLoading();
}
}
);
}

Widget buildRow(BuildContext context) {
return Row(
mainAxisAlignment: MainAxisAlignment.spaceBetween,
children: [
buildModelRow(context),
buildTagRow(context)
]
);
}

Widget buildModelRow(BuildContext context) {
return Row(
mainAxisSize: MainAxisSize.min,
children: [
Text(
selectedModel?.name ?? 'Select Model',
style: TextStyle(
color: Theme.of(context).colorScheme.onSurface,
fontSize: 16,
fontWeight: FontWeight.bold
)
),
PopupMenuButton(
tooltip: 'Select HuggingFace Model',
icon: Icon(
modelOpen ? Icons.arrow_drop_up : Icons.arrow_drop_down,
color: Theme.of(context).colorScheme.onSurface,
size: 24,
),
offset: const Offset(0, 40),
itemBuilder: modelItemBuilder,
onOpened: () => setState(() => modelOpen = true),
onCanceled: () => setState(() => modelOpen = false),
onSelected: (value) => selectModel(value)
)
]
);
}

void selectModel(HuggingfaceModel model) {
setState(() {
modelOpen = false;
selectedModel = model;
selectedTag = model.tags.keys.first;
});
}

Widget buildTagRow(BuildContext context) {
return Row(
mainAxisSize: MainAxisSize.min,
children: [
Text(
selectedTag ?? 'Select Tag',
style: TextStyle(
color: Theme.of(context).colorScheme.onSurface,
fontSize: 16,
fontWeight: FontWeight.bold
)
),
PopupMenuButton(
tooltip: 'Select HuggingFace Model Tag',
icon: Icon(
tagOpen ? Icons.arrow_drop_up : Icons.arrow_drop_down,
color: Theme.of(context).colorScheme.onSurface,
size: 24,
),
offset: const Offset(0, 40),
itemBuilder: tagItemBuilder,
onOpened: () => setState(() => tagOpen = true),
onCanceled: () => setState(() => tagOpen = false),
onSelected: (value) => selectTag(value)
)
]
);
}

void selectTag(String tag) {
setState(() {
tagOpen = false;
selectedTag = tag;
});
}

List<PopupMenuEntry<HuggingfaceModel>> modelItemBuilder(BuildContext context) {
return options.map((HuggingfaceModel model) {
return PopupMenuItem<HuggingfaceModel>(
value: model,
child: Text(model.name),
onTap: () => selectModel(model),
);
}).toList();
}

List<PopupMenuEntry<String>> tagItemBuilder(BuildContext context) {
return selectedModel?.tags.keys.toList().map((String tag) {
return PopupMenuItem<String>(
value: tag,
child: Text(tag),
onTap: () => selectTag(tag),
);
}).toList() ?? [];
}

Widget buildLoading() {
return const SizedBox(
width: 24,
height: 24,
child: Center(
child: CircularProgressIndicator(
strokeWidth: 3.0
),
),
);
}
}

0 comments on commit 96bb712

Please sign in to comment.