Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: download llm files #5723

Merged
merged 15 commits into from
Jul 15, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,10 @@ import 'util.dart';

extension AppFlowyAuthTest on WidgetTester {
Future<void> tapGoogleLoginInButton() async {
await tapButton(find.byKey(const Key('signInWithGoogleButton')));
await tapButton(
find.byKey(const Key('signInWithGoogleButton')),
milliseconds: 3000,
);
}

/// Requires being on the SettingsPage.account of the SettingsDialog
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import 'package:appflowy_backend/dispatch/dispatch.dart';
import 'package:appflowy_backend/log.dart';
import 'package:appflowy_backend/protobuf/flowy-chat/entities.pb.dart';
import 'package:appflowy_backend/protobuf/flowy-error/code.pb.dart';
import 'package:appflowy_backend/protobuf/flowy-error/errors.pb.dart';
import 'package:appflowy_backend/protobuf/flowy-folder/protobuf.dart';
import 'package:appflowy_backend/protobuf/flowy-folder/view.pb.dart';
import 'package:appflowy_backend/protobuf/flowy-user/user_profile.pb.dart';
Expand Down Expand Up @@ -480,7 +481,7 @@ class ChatState with _$ChatState {
@freezed
class LoadingState with _$LoadingState {
const factory LoadingState.loading() = _Loading;
const factory LoadingState.finish() = _Finish;
const factory LoadingState.finish({FlowyError? error}) = _Finish;
}

enum OnetimeShotType {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import 'package:appflowy_backend/dispatch/dispatch.dart';
import 'package:appflowy_backend/protobuf/flowy-chat/entities.pb.dart';
import 'package:flutter_bloc/flutter_bloc.dart';
import 'package:freezed_annotation/freezed_annotation.dart';

part 'chat_file_bloc.freezed.dart';

class ChatFileBloc extends Bloc<ChatFileEvent, ChatFileState> {
ChatFileBloc({
required String chatId,
dynamic message,
}) : super(ChatFileState.initial(message)) {
on<ChatFileEvent>(
(event, emit) async {
await event.when(
initial: () async {},
newFile: (String filePath) {
final payload = ChatFilePB(filePath: filePath, chatId: chatId);
ChatEventChatWithFile(payload).send();
},
);
},
);
}
}

@freezed
class ChatFileEvent with _$ChatFileEvent {
const factory ChatFileEvent.initial() = Initial;
const factory ChatFileEvent.newFile(String filePath) = _NewFile;
}

@freezed
class ChatFileState with _$ChatFileState {
const factory ChatFileState({
required String text,
}) = _ChatFileState;

factory ChatFileState.initial(dynamic text) {
return ChatFileState(
text: text is String ? text : "",
);
}
}
51 changes: 48 additions & 3 deletions frontend/appflowy_flutter/lib/plugins/ai_chat/chat_page.dart
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import 'package:appflowy/plugins/ai_chat/application/chat_file_bloc.dart';
import 'package:desktop_drop/desktop_drop.dart';
import 'package:flutter/material.dart';
import 'package:flutter/services.dart';

Expand Down Expand Up @@ -50,7 +52,7 @@ class AIChatUILayout {
}
}

class AIChatPage extends StatefulWidget {
class AIChatPage extends StatelessWidget {
const AIChatPage({
super.key,
required this.view,
Expand All @@ -63,10 +65,53 @@ class AIChatPage extends StatefulWidget {
final UserProfilePB userProfile;

@override
State<AIChatPage> createState() => _AIChatPageState();
Widget build(BuildContext context) {
if (userProfile.authenticator == AuthenticatorPB.AppFlowyCloud) {
return BlocProvider(
create: (context) => ChatFileBloc(chatId: view.id.toString()),
child: BlocBuilder<ChatFileBloc, ChatFileState>(
builder: (context, state) {
return DropTarget(
onDragDone: (DropDoneDetails detail) async {
for (final file in detail.files) {
context
.read<ChatFileBloc>()
.add(ChatFileEvent.newFile(file.path));
}
},
child: _ChatContentPage(
view: view,
userProfile: userProfile,
),
);
},
),
);
}

return Center(
child: FlowyText(
LocaleKeys.chat_unsupportedCloudPrompt.tr(),
fontSize: 20,
),
);
}
}

class _ChatContentPage extends StatefulWidget {
const _ChatContentPage({
required this.view,
required this.userProfile,
});

final UserProfilePB userProfile;
final ViewPB view;

@override
State<_ChatContentPage> createState() => _ChatContentPageState();
}

class _AIChatPageState extends State<AIChatPage> {
class _ChatContentPageState extends State<_ChatContentPage> {
late types.User _user;

@override
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
import 'dart:async';
import 'dart:ffi';
import 'dart:isolate';

import 'package:appflowy/plugins/ai_chat/application/chat_bloc.dart';
import 'package:appflowy_backend/dispatch/dispatch.dart';
import 'package:appflowy_backend/protobuf/flowy-chat/entities.pb.dart';
import 'package:bloc/bloc.dart';
import 'package:freezed_annotation/freezed_annotation.dart';
import 'package:fixnum/fixnum.dart';
part 'download_model_bloc.freezed.dart';

class DownloadModelBloc extends Bloc<DownloadModelEvent, DownloadModelState> {
DownloadModelBloc(LLMModelPB model)
: super(DownloadModelState(model: model)) {
on<DownloadModelEvent>(_handleEvent);
}

Future<void> _handleEvent(
DownloadModelEvent event,
Emitter<DownloadModelState> emit,
) async {
await event.when(
started: () async {
final downloadStream = DownloadingStream();
downloadStream.listen(
onModelPercentage: (name, percent) {
if (!isClosed) {
add(
DownloadModelEvent.updatePercent(name, percent),
);
}
},
onPluginPercentage: (percent) {
if (!isClosed) {
add(DownloadModelEvent.updatePercent("AppFlowy Plugin", percent));
}
},
onFinish: () {
add(const DownloadModelEvent.downloadFinish());
},
onError: (err) {
// emit(state.copyWith(downloadError: err));
},
);

final payload =
DownloadLLMPB(progressStream: Int64(downloadStream.nativePort));
final result = await ChatEventDownloadLLMResource(payload).send();
result.fold((_) {
emit(
state.copyWith(
downloadStream: downloadStream,
loadingState: const LoadingState.finish(),
downloadError: null,
),
);
}, (err) {
emit(state.copyWith(loadingState: LoadingState.finish(error: err)));
});
},
updatePercent: (String object, double percent) {
emit(state.copyWith(object: object, percent: percent));
},
downloadFinish: () {
emit(state.copyWith(isFinish: true));
},
);
}
}

@freezed
class DownloadModelEvent with _$DownloadModelEvent {
const factory DownloadModelEvent.started() = _Started;
const factory DownloadModelEvent.updatePercent(
String object,
double percent,
) = _UpdatePercent;
const factory DownloadModelEvent.downloadFinish() = _DownloadFinish;
}

@freezed
class DownloadModelState with _$DownloadModelState {
const factory DownloadModelState({
required LLMModelPB model,
DownloadingStream? downloadStream,
String? downloadError,
@Default("") String object,
@Default(0) double percent,
@Default(false) bool isFinish,
@Default(LoadingState.loading()) LoadingState loadingState,
}) = _DownloadModelState;
}

class DownloadingStream {
DownloadingStream() {
_port.handler = _controller.add;
}

final RawReceivePort _port = RawReceivePort();
StreamSubscription<String>? _sub;
final StreamController<String> _controller = StreamController.broadcast();
int get nativePort => _port.sendPort.nativePort;

Future<void> dispose() async {
await _sub?.cancel();
await _controller.close();
_port.close();
}

void listen({
void Function(String modelName, double percent)? onModelPercentage,
void Function(double percent)? onPluginPercentage,
void Function(String data)? onError,
void Function()? onFinish,
}) {
_sub = _controller.stream.listen((text) {
if (text.contains(':progress:')) {
final progressIndex = text.indexOf(':progress:');
final modelName = text.substring(0, progressIndex);
final progressValue = text
.substring(progressIndex + 10); // 10 is the length of ":progress:"
final percent = double.tryParse(progressValue);
if (percent != null) {
onModelPercentage?.call(modelName, percent);
}
} else if (text.startsWith('plugin:progress:')) {
final percent = double.tryParse(text.substring(16));
if (percent != null) {
onPluginPercentage?.call(percent);
}
} else if (text.startsWith('finish')) {
onFinish?.call();
} else if (text.startsWith('error:')) {
// substring 6 to remove "error:"
onError?.call(text.substring(6));
}
});
}
}
Loading
Loading