Skip to content

Commit

Permalink
Fix #508 (#518)
Browse files Browse the repository at this point in the history
* add requirements dialog

* make storage op dialog futurebuilder
  • Loading branch information
danemadsen authored Apr 26, 2024
1 parent 823ff52 commit b5a06fc
Show file tree
Hide file tree
Showing 8 changed files with 172 additions and 42 deletions.
15 changes: 15 additions & 0 deletions lib/classes/google_gemini_model.dart
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,21 @@ class GoogleGeminiModel extends LargeLanguageModel {
fromMap(json);
}

@override
List<String> get missingRequirements {
List<String> missing = [];

if (name.isEmpty) {
missing.add('- A model option is required for prompting.\n');
}

if (token.isEmpty) {
missing.add('- An authentication token is required for prompting.\n');
}

return missing;
}

@override
Stream<String> prompt(List<ChatNode> messages) async* {
try {
Expand Down
18 changes: 18 additions & 0 deletions lib/classes/large_language_model.dart
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,24 @@ class LargeLanguageModel extends ChangeNotifier {
};
}

List<String> get missingRequirements {
List<String> missing = [];

if (_name.isEmpty) {
missing.add('- A model option is required for prompting.\n');
}

if (_uri.isEmpty) {
missing.add('- A compatible URL is required for prompting.\n');
}

if (_token.isEmpty) {
missing.add('- An authentication token is required for prompting.\n');
}

return missing;
}

Stream<String> prompt(List<ChatNode> messages) {
throw UnimplementedError();
}
Expand Down
17 changes: 16 additions & 1 deletion lib/classes/llama_cpp_model.dart
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,21 @@ class LlamaCppModel extends LargeLanguageModel {
_promptFormat = value;
notifyListeners();
}

@override
List<String> get missingRequirements {
List<String> missing = [];

if (uri.isEmpty) {
missing.add('- A path to the model file is required.\n');
}

if (!File(uri).existsSync()) {
missing.add('- The file provided does not exist.\n');
}

return missing;
}

LlamaCppModel({
super.listener,
Expand Down Expand Up @@ -145,7 +160,7 @@ class LlamaCppModel extends LargeLanguageModel {
uri = file.path;
notifyListeners();
} catch (e) {
return "Error: $e";
return e.toString();
}

return "Model Successfully Loaded";
Expand Down
15 changes: 15 additions & 0 deletions lib/classes/ollama_model.dart
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,21 @@ class OllamaModel extends LargeLanguageModel {

String _ip = '';

@override
List<String> get missingRequirements {
List<String> missing = [];

if (name.isEmpty) {
missing.add('- A model option is required for prompting.\n');
}

if (uri.isEmpty) {
missing.add('- A compatible URL is required for prompting.\n');
}

return missing;
}

OllamaModel({
super.listener,
super.name,
Expand Down
8 changes: 5 additions & 3 deletions lib/ui/mobile/pages/platforms/llama_cpp_page.dart
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,11 @@ class _LlamaCppPageState extends State<LlamaCppPage> {
mainAxisAlignment: MainAxisAlignment.center,
children: [
FilledButton(
onPressed: () async {
await storageOperationDialog(
context, (session.model as LlamaCppModel).loadModel);
onPressed: () {
storageOperationDialog(
context,
(session.model as LlamaCppModel).loadModel
);
session.notify();
},
child: Text(
Expand Down
6 changes: 5 additions & 1 deletion lib/ui/mobile/widgets/chat_widgets/chat_field.dart
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import 'dart:io';
import 'dart:async';

import 'package:flutter/material.dart';
import 'package:maid/ui/mobile/widgets/dialogs.dart';
import 'package:maid_llm/src/chat_node.dart';
import 'package:maid/classes/large_language_model.dart';
import 'package:maid/providers/session.dart';
Expand Down Expand Up @@ -109,7 +110,10 @@ class _ChatFieldState extends State<ChatField> {
),
IconButton(
onPressed: () {
if (session.chat.tail.finalised ) {
if (session.model.missingRequirements.isNotEmpty) {
showMissingRequirementsDialog(context);
}
else if (session.chat.tail.finalised) {
send();
}
},
Expand Down
36 changes: 27 additions & 9 deletions lib/ui/mobile/widgets/chat_widgets/chat_message.dart
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import 'package:flutter/material.dart';
import 'package:maid/ui/mobile/widgets/dialogs.dart';
import 'package:maid_llm/maid_llm.dart';
import 'package:maid/providers/character.dart';
import 'package:maid/providers/session.dart';
Expand Down Expand Up @@ -145,9 +146,15 @@ class _ChatMessageState extends State<ChatMessage> with SingleTickerProviderStat
IconButton(
onPressed: () {
if (!context.read<Session>().chat.tail.finalised) return;
setState(() {
editing = true;
});

if (context.read<Session>().model.missingRequirements.isNotEmpty) {
showMissingRequirementsDialog(context);
}
else {
setState(() {
editing = true;
});
}
},
icon: const Icon(Icons.edit),
),
Expand All @@ -159,8 +166,13 @@ class _ChatMessageState extends State<ChatMessage> with SingleTickerProviderStat
IconButton(
onPressed: () {
if (!context.read<Session>().chat.tail.finalised) return;
context.read<Session>().regenerate(node.key, context);
setState(() {});

if (context.read<Session>().model.missingRequirements.isNotEmpty) {
showMissingRequirementsDialog(context);
} else {
context.read<Session>().regenerate(node.key, context);
setState(() {});
}
},
icon: const Icon(Icons.refresh),
),
Expand Down Expand Up @@ -189,10 +201,16 @@ class _ChatMessageState extends State<ChatMessage> with SingleTickerProviderStat
padding: const EdgeInsets.all(0),
onPressed: () {
if (!context.read<Session>().chat.tail.finalised) return;
setState(() {
editing = false;
});
context.read<Session>().edit(node.key, messageController.text, context);

if (context.read<Session>().model.missingRequirements.isNotEmpty) {
showMissingRequirementsDialog(context);
}
else {
setState(() {
editing = false;
});
context.read<Session>().edit(node.key, messageController.text, context);
}
},
icon: const Icon(Icons.done)),
IconButton(
Expand Down
99 changes: 71 additions & 28 deletions lib/ui/mobile/widgets/dialogs.dart
Original file line number Diff line number Diff line change
@@ -1,34 +1,77 @@

import 'package:flutter/material.dart';
import 'package:maid/providers/session.dart';
import 'package:provider/provider.dart';

Future<void> storageOperationDialog(BuildContext context, Future<String> Function(BuildContext context) storageFunction) async {
String ret = await storageFunction(context);
// Ensure that the context is still valid before attempting to show the dialog.
if (context.mounted) {
showDialog(
context: context,
builder: (BuildContext context) {
return AlertDialog(
title: Text(ret),
alignment: Alignment.center,
actionsAlignment: MainAxisAlignment.center,
backgroundColor: Theme.of(context).colorScheme.background,
shape: const RoundedRectangleBorder(
borderRadius: BorderRadius.all(Radius.circular(20.0)),
),
actions: [
FilledButton(
onPressed: () {
Navigator.of(context).pop();
},
child: Text(
"Close",
style: Theme.of(context).textTheme.labelLarge,
void storageOperationDialog(BuildContext context, Future<String> Function(BuildContext context) storageFunction) {
showDialog(
context: context,
builder: (BuildContext context) {
return FutureBuilder<String>(
future: storageFunction(context),
builder: (BuildContext context, AsyncSnapshot<String> snapshot) {
if (snapshot.connectionState == ConnectionState.done) {
return AlertDialog(
title: Text(
snapshot.data!,
textAlign: TextAlign.center
),
actionsAlignment: MainAxisAlignment.center,
actions: [
FilledButton(
onPressed: () {
Navigator.of(context).pop();
},
child: Text(
"Close",
style: Theme.of(context).textTheme.labelLarge,
),
),
],
);
} else {
return const AlertDialog(
title: Text(
"Storage Operation Pending",
textAlign: TextAlign.center
),
content: Center(
heightFactor: 1.0,
child: CircularProgressIndicator(),
)
);
}
},
);
},
);
}

void showMissingRequirementsDialog(BuildContext context) {
showDialog(
context: context,
builder: (context) {
final requirement = context.read<Session>().model.missingRequirements;

return AlertDialog(
title: const Text(
"Missing Requirements",
textAlign: TextAlign.center
),
actionsAlignment: MainAxisAlignment.center,
content: Text(
requirement.join()
),
actions: [
FilledButton(
onPressed: () => Navigator.of(context).pop(),
child: Text(
"OK",
style: Theme.of(context).textTheme.labelLarge
),
],
);
},
);
}
),
],
);
},
);
}

0 comments on commit b5a06fc

Please sign in to comment.