Skip to content

Commit

Permalink
feat: ai writer block
Browse files Browse the repository at this point in the history
  • Loading branch information
richardshiue committed Feb 5, 2025
1 parent 43e64d8 commit 148643a
Show file tree
Hide file tree
Showing 55 changed files with 2,532 additions and 2,746 deletions.

This file was deleted.

4 changes: 4 additions & 0 deletions frontend/appflowy_flutter/lib/ai/ai.dart
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
export 'service/ai_client.dart';
export 'service/ai_entities.dart';
export 'service/ai_prompt_input_bloc.dart';
export 'service/appflowy_ai_service.dart';
export 'widgets/loading_indicator.dart';
export 'widgets/prompt_input/action_buttons.dart';
export 'widgets/prompt_input/desktop_input_text_field.dart';
Expand Down
21 changes: 2 additions & 19 deletions frontend/appflowy_flutter/lib/ai/service/ai_client.dart
Original file line number Diff line number Diff line change
@@ -1,34 +1,17 @@
import 'package:appflowy_backend/protobuf/flowy-ai/protobuf.dart';
import 'package:appflowy_result/appflowy_result.dart';

import 'ai_entities.dart';
import 'error.dart';
import 'text_completion.dart';

abstract class AIRepository {
Future<void> getStreamedCompletions({
required String prompt,
required Future<void> Function() onStart,
required Future<void> Function(TextCompletionResponse response) onProcess,
required Future<void> Function() onEnd,
required void Function(AIError error) onError,
String? suffix,
int maxTokens = 2048,
double temperature = 0.3,
bool useAction = false,
});

Future<void> streamCompletion({
String? objectId,
required String text,
PredefinedFormat? format,
required CompletionTypePB completionType,
required Future<void> Function() onStart,
required Future<void> Function(String text) onProcess,
required Future<void> Function() onEnd,
required void Function(AIError error) onError,
});

Future<FlowyResult<List<String>, AIError>> generateImage({
required String prompt,
int n = 1,
});
}
89 changes: 89 additions & 0 deletions frontend/appflowy_flutter/lib/ai/service/ai_entities.dart
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
import 'package:appflowy/generated/flowy_svgs.g.dart';
import 'package:appflowy/generated/locale_keys.g.dart';
import 'package:appflowy_backend/protobuf/flowy-ai/protobuf.dart';
import 'package:easy_localization/easy_localization.dart';
import 'package:equatable/equatable.dart';

class PredefinedFormat extends Equatable {
const PredefinedFormat({
required this.imageFormat,
required this.textFormat,
});

const PredefinedFormat.auto()
: imageFormat = ImageFormat.text,
textFormat = TextFormat.auto;

final ImageFormat imageFormat;
final TextFormat? textFormat;

PredefinedFormatPB toPB() {
return PredefinedFormatPB(
imageFormat: switch (imageFormat) {
ImageFormat.text => ResponseImageFormatPB.TextOnly,
ImageFormat.image => ResponseImageFormatPB.ImageOnly,
ImageFormat.textAndImage => ResponseImageFormatPB.TextAndImage,
},
textFormat: switch (textFormat) {
TextFormat.auto => ResponseTextFormatPB.Paragraph,
TextFormat.bulletList => ResponseTextFormatPB.BulletedList,
TextFormat.numberedList => ResponseTextFormatPB.NumberedList,
TextFormat.table => ResponseTextFormatPB.Table,
_ => null,
},
);
}

@override
List<Object?> get props => [imageFormat, textFormat];
}

enum ImageFormat {
text,
image,
textAndImage;

bool get hasText => this == text || this == textAndImage;

FlowySvgData get icon {
return switch (this) {
ImageFormat.text => FlowySvgs.ai_text_s,
ImageFormat.image => FlowySvgs.ai_image_s,
ImageFormat.textAndImage => FlowySvgs.ai_text_image_s,
};
}

String get i18n {
return switch (this) {
ImageFormat.text => LocaleKeys.chat_changeFormat_textOnly.tr(),
ImageFormat.image => LocaleKeys.chat_changeFormat_imageOnly.tr(),
ImageFormat.textAndImage =>
LocaleKeys.chat_changeFormat_textAndImage.tr(),
};
}
}

enum TextFormat {
auto,
bulletList,
numberedList,
table;

FlowySvgData get icon {
return switch (this) {
TextFormat.auto => FlowySvgs.ai_paragraph_s,
TextFormat.bulletList => FlowySvgs.ai_list_s,
TextFormat.numberedList => FlowySvgs.ai_number_list_s,
TextFormat.table => FlowySvgs.ai_table_s,
};
}

String get i18n {
return switch (this) {
TextFormat.auto => LocaleKeys.chat_changeFormat_text.tr(),
TextFormat.bulletList => LocaleKeys.chat_changeFormat_bullet.tr(),
TextFormat.numberedList => LocaleKeys.chat_changeFormat_number.tr(),
TextFormat.table => LocaleKeys.chat_changeFormat_table.tr(),
};
}
}
108 changes: 19 additions & 89 deletions frontend/appflowy_flutter/lib/ai/service/appflowy_ai_service.dart
Original file line number Diff line number Diff line change
Expand Up @@ -4,119 +4,47 @@ import 'dart:isolate';

import 'package:appflowy/generated/locale_keys.g.dart';
import 'package:appflowy_backend/dispatch/dispatch.dart';
import 'package:appflowy_backend/log.dart';
import 'package:appflowy_backend/protobuf/flowy-ai/entities.pb.dart';
import 'package:appflowy_result/appflowy_result.dart';
import 'package:easy_localization/easy_localization.dart';
import 'package:fixnum/fixnum.dart' as fixnum;

import 'ai_client.dart';
import 'ai_entities.dart';
import 'error.dart';
import 'text_completion.dart';

enum AskAIAction {
summarize,
fixSpelling,
improveWriting,
makeItLonger;

String get toInstruction => switch (this) {
summarize => 'Tl;dr',
fixSpelling => 'Correct this to standard English:',
improveWriting => 'Rewrite this in your own words:',
makeItLonger => 'Make this text longer:',
};

String prompt(String input) => switch (this) {
summarize => '$input\n\n$toInstruction',
_ => "$toInstruction\n\n$input",
};

static AskAIAction from(int index) => switch (index) {
0 => summarize,
1 => fixSpelling,
2 => improveWriting,
3 => makeItLonger,
_ => fixSpelling
};

String get name => switch (this) {
summarize => LocaleKeys.document_plugins_smartEditSummarize.tr(),
fixSpelling => LocaleKeys.document_plugins_smartEditFixSpelling.tr(),
improveWriting =>
LocaleKeys.document_plugins_smartEditImproveWriting.tr(),
makeItLonger => LocaleKeys.document_plugins_smartEditMakeLonger.tr(),
};
}

class AppFlowyAIService implements AIRepository {
@override
Future<FlowyResult<List<String>, AIError>> generateImage({
required String prompt,
int n = 1,
}) {
throw UnimplementedError();
}

@override
Future<void> getStreamedCompletions({
required String prompt,
required Future<void> Function() onStart,
required Future<void> Function(TextCompletionResponse response) onProcess,
required Future<void> Function() onEnd,
required void Function(AIError error) onError,
String? suffix,
int maxTokens = 2048,
double temperature = 0.3,
bool useAction = false,
}) {
throw UnimplementedError();
}

@override
Future<CompletionStream> streamCompletion({
Future<(String, CompletionStream)?> streamCompletion({
String? objectId,
required String text,
PredefinedFormat? format,
required CompletionTypePB completionType,
required Future<void> Function() onStart,
required Future<void> Function(String text) onProcess,
required Future<void> Function() onEnd,
required void Function(AIError error) onError,
}) async {
final stream = CompletionStream(
onStart,
onProcess,
onEnd,
onError,
);
final List<String> ragIds = [];
if (objectId != null) {
ragIds.add(objectId);
}
final stream = CompletionStream(onStart, onProcess, onEnd, onError);

final payload = CompleteTextPB(
text: text,
completionType: completionType,
streamPort: fixnum.Int64(stream.nativePort),
objectId: objectId ?? "",
ragIds: ragIds,
objectId: objectId ?? '',
ragIds: [
if (objectId != null) objectId,
],
);

// ignore: unawaited_futures
AIEventCompleteText(payload).send();
return stream;
}
}

CompletionTypePB completionTypeFromInt(AskAIAction action) {
switch (action) {
case AskAIAction.summarize:
return CompletionTypePB.MakeShorter;
case AskAIAction.fixSpelling:
return CompletionTypePB.SpellingAndGrammar;
case AskAIAction.improveWriting:
return CompletionTypePB.ImproveWriting;
case AskAIAction.makeItLonger:
return CompletionTypePB.MakeLonger;
return AIEventCompleteText(payload).send().fold(
(completeTextTask) => (completeTextTask.taskId, stream),
(error) {
Log.error(error);
return null;
},
);
}
}

Expand Down Expand Up @@ -170,7 +98,9 @@ class CompletionStream {
}

if (event.startsWith("error:")) {
onError(AIError(message: event.substring(6)));
onError(
AIError(message: event.substring(6), code: AIErrorCode.other),
);
}
},
);
Expand Down
2 changes: 1 addition & 1 deletion frontend/appflowy_flutter/lib/ai/service/error.dart
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ part 'error.g.dart';
class AIError with _$AIError {
const factory AIError({
required String message,
@Default(AIErrorCode.other) AIErrorCode code,
required AIErrorCode code,
}) = _AIError;

factory AIError.fromJson(Map<String, Object?> json) =>
Expand Down
Loading

0 comments on commit 148643a

Please sign in to comment.