
[mediapipe_task_genai] On Android llm engine failed to validate graph for gemma-2b-it-gpu-int8.bin model #56

Open
fengwang opened this issue May 24, 2024 · 24 comments

Comments

@fengwang

fengwang commented May 24, 2024

Bug report

Describe the bug
LLM Engine failed in ValidatedGraphConfig Initialization step.

Steps to reproduce

Steps to reproduce the behavior:

  1. Download the gemma-2b-it-gpu-int8.bin weights from Kaggle.
  2. Push the weights to the Android device: adb push gemma-2b-it-gpu-int8.bin /data/local/tmp
  3. Clone the repo, go to the folder flutter-mediapipe/packages/mediapipe-task-genai/example, and run flutter run -d 9TAUH6MRNZJ7KN6H --dart-define=GEMMA_8B_GPU_URI=/data/local/tmp/gemma-2b-it-gpu-int8.bin, in which 9TAUH6MRNZJ7KN6H is my device ID.
  4. On the Android screen, tap Gemma 8b GPU, enter Hello, world!, and tap the send button.
  5. The app crashes; the log is attached as crash_log.txt

Expected behavior

One or two conversation turns complete successfully.

Additional context

Device Info:

  • Device Name: Redmi K60 Ultra
  • Hardware: Mediatek mt6985
  • API Level: 33
  • Kernel: 5.15.78
  • GPU: Mali-G715
  • Memory: 24 GB
  • Storage: 1 TB

I downloaded the model twice and tested twice, so the model file should be fine.


Flutter doctor

Run flutter doctor and paste the output below:

Doctor summary (to see all details, run flutter doctor -v):
[✓] Flutter (Channel master, 3.23.0-7.0.pre.3, on Arch Linux 6.6.31-2-lts, locale en_US.UTF-8)
[✓] Android toolchain - develop for Android devices (Android SDK version 34.0.0)
[✓] Chrome - develop for the web
[✓] Linux toolchain - develop for Linux desktop
[!] Android Studio (not installed)
[✓] Connected device (3 available)
[✓] Network resources

! Doctor found issues in 1 category.


Flutter dependencies

Run flutter pub deps -- --style=compact and paste the output below:

Dart SDK 3.5.0-189.0.dev
Flutter SDK 3.23.0-7.0.pre.3
example 1.0.0+1

dependencies:
- adaptive_dialog 2.1.0 [animations collection dynamic_color flutter intersperse macos_ui meta]
- chat_bubbles 1.6.0 [flutter intl]
- cupertino_icons 1.0.8
- flutter 0.0.0 [characters collection material_color_utilities meta vector_math sky_engine]
- flutter_bloc 8.1.5 [bloc flutter provider]
- flutter_localizations 0.0.0 [flutter intl characters clock collection material_color_utilities meta path vector_math]
- freezed_annotation 2.4.1 [collection json_annotation meta]
- getwidget 4.0.0 [flutter]
- http 1.2.1 [async http_parser meta web]
- intl 0.19.0 [clock meta path]
- logging 1.2.0
- mediapipe_core 0.0.1 [equatable ffi logging meta]
- mediapipe_genai 0.0.1 [async equatable ffi http logging mediapipe_core native_assets_cli native_toolchain_c path]
- path 1.9.0
- path_provider 2.1.3 [flutter path_provider_android path_provider_foundation path_provider_linux path_provider_platform_interface path_provider_windows]
- provider 6.1.2 [collection flutter nested]
- shimmer 3.0.0 [flutter]
- uuid 4.4.0 [crypto sprintf meta fixnum]

dev dependencies:
- build_runner 2.4.10 [analyzer args async build build_config build_daemon build_resolvers build_runner_core code_builder collection crypto dart_style frontend_server_client glob graphs http_multi_server io js logging meta mime package_config path pool pub_semver pubspec_parse shelf shelf_web_socket stack_trace stream_transform timing watcher web_socket_channel yaml]
- flutter_lints 3.0.2 [lints]
- flutter_test 0.0.0 [flutter test_api matcher path fake_async clock stack_trace vector_math leak_tracker_flutter_testing async boolean_selector characters collection leak_tracker leak_tracker_testing material_color_utilities meta source_span stream_channel string_scanner term_glyph vm_service]
- freezed 2.5.2 [analyzer build build_config collection meta source_gen freezed_annotation json_annotation]

transitive dependencies:
- _fe_analyzer_shared 67.0.0 [meta]
- analyzer 6.4.1 [_fe_analyzer_shared collection convert crypto glob meta package_config path pub_semver source_span watcher yaml]
- animations 2.0.11 [flutter]
- appkit_ui_element_colors 1.0.0 [equatable flutter plugin_platform_interface]
- args 2.5.0
- async 2.11.0 [collection meta]
- bloc 8.1.4 [meta]
- boolean_selector 2.1.1 [source_span string_scanner]
- build 2.4.1 [analyzer async convert crypto glob logging meta package_config path]
- build_config 1.1.1 [checked_yaml json_annotation path pubspec_parse yaml]
- build_daemon 4.0.2 [built_collection built_value crypto http_multi_server logging path pool shelf shelf_web_socket stream_transform watcher web_socket_channel]
- build_resolvers 2.4.2 [analyzer async build collection convert crypto graphs logging package_config path pool pub_semver stream_transform yaml]
- build_runner_core 7.3.0 [async build build_config build_resolvers collection convert crypto glob graphs json_annotation logging meta package_config path pool timing watcher yaml]
- built_collection 5.1.1
- built_value 8.9.2 [built_collection collection fixnum meta]
- characters 1.3.0
- checked_yaml 2.0.3 [json_annotation source_span yaml]
- cli_config 0.1.2 [args yaml]
- clock 1.1.1
- code_builder 4.10.0 [built_collection built_value collection matcher meta]
- collection 1.18.0
- convert 3.1.1 [typed_data]
- crypto 3.0.3 [typed_data]
- dart_style 2.3.6 [analyzer args collection path pub_semver source_span]
- dynamic_color 1.7.0 [flutter flutter_test material_color_utilities]
- equatable 2.0.5 [collection meta]
- fake_async 1.3.1 [clock collection]
- ffi 2.1.2
- file 7.0.0 [meta path]
- fixnum 1.1.0
- frontend_server_client 4.0.0 [async path]
- glob 2.1.2 [async collection file path string_scanner]
- gradient_borders 1.0.0 [flutter]
- graphs 2.3.1 [collection]
- http_multi_server 3.2.1 [async]
- http_parser 4.0.2 [collection source_span string_scanner typed_data]
- intersperse 2.0.0
- io 1.0.4 [meta path string_scanner]
- js 0.7.1
- json_annotation 4.9.0 [meta]
- leak_tracker 10.0.5 [clock collection meta path vm_service]
- leak_tracker_flutter_testing 3.0.5 [flutter leak_tracker leak_tracker_testing matcher meta]
- leak_tracker_testing 3.0.1 [leak_tracker matcher meta]
- lints 3.0.0
- macos_ui 2.0.7 [flutter macos_window_utils gradient_borders appkit_ui_element_colors equatable]
- macos_window_utils 1.5.0 [flutter]
- matcher 0.12.16+1 [async meta stack_trace term_glyph test_api]
- material_color_utilities 0.11.1 [collection]
- meta 1.14.0
- mime 1.0.5
- native_assets_cli 0.3.2 [cli_config collection crypto pub_semver yaml yaml_edit]
- native_toolchain_c 0.3.3 [cli_config glob logging meta native_assets_cli pub_semver]
- nested 1.0.0 [flutter]
- package_config 2.1.0 [path]
- path_provider_android 2.2.4 [flutter path_provider_platform_interface]
- path_provider_foundation 2.4.0 [flutter path_provider_platform_interface]
- path_provider_linux 2.2.1 [ffi flutter path path_provider_platform_interface xdg_directories]
- path_provider_platform_interface 2.1.2 [flutter platform plugin_platform_interface]
- path_provider_windows 2.2.1 [ffi flutter path path_provider_platform_interface win32]
- platform 3.1.4
- plugin_platform_interface 2.1.8 [meta]
- pool 1.5.1 [async stack_trace]
- pub_semver 2.1.4 [collection meta]
- pubspec_parse 1.2.3 [checked_yaml collection json_annotation pub_semver yaml]
- shelf 1.4.1 [async collection http_parser path stack_trace stream_channel]
- shelf_web_socket 2.0.0 [shelf stream_channel web_socket_channel]
- sky_engine 0.0.99
- source_gen 1.5.0 [analyzer async build dart_style glob path source_span yaml]
- source_span 1.10.0 [collection path term_glyph]
- sprintf 7.0.0
- stack_trace 1.11.1 [path]
- stream_channel 2.1.2 [async]
- stream_transform 2.1.0
- string_scanner 1.2.0 [source_span]
- term_glyph 1.2.1
- test_api 0.7.1 [async boolean_selector collection meta source_span stack_trace stream_channel string_scanner term_glyph]
- timing 1.0.1 [json_annotation]
- typed_data 1.3.2 [collection]
- vector_math 2.1.4
- vm_service 14.2.2
- watcher 1.1.0 [async path]
- web 0.5.1
- web_socket 0.1.4 [web]
- web_socket_channel 3.0.0 [async crypto stream_channel web web_socket]
- win32 5.5.1 [ffi]
- xdg_directories 1.0.4 [meta path]
- yaml 3.1.2 [collection source_span string_scanner]
- yaml_edit 2.2.1 [collection meta source_span yaml]


@WingCH

WingCH commented May 25, 2024

same issue here

@Gaurav-822

facing the same issue

@yashpapa6969

same issue

@Gaurav-822

I don't know why, but I guess this issue is specific to the Flutter package. I tried the native Android version and it worked fine, so I'm using a MethodChannel to implement it for now.

@craiglabenz
Collaborator

Does the CPU version work? The CPU version is known to run successfully on a wider range of devices.

@yashpapa6969

I tried both quantized models, CPU and GPU; same issue.

@craiglabenz
Collaborator

Thanks for the report. I've shared this with the MediaPipe team, but both engineers who work on that half of the stack are currently OOO, so we'll all have to be patient for a little bit until they're able to unravel why the MediaPipe SDK is crashing.

@craiglabenz
Collaborator

This is potentially similar to google-ai-edge/mediapipe-samples#335.

@aaronrau

aaronrau commented Jun 7, 2024

Same issue testing on 3 different Android devices (Samsung S8, S10, A25 5G). All crash in the example app, no matter which model gets loaded.

@yashpapa6969

yashpapa6969 commented Jun 7, 2024

The model works in native Android, though; check this out:
https://github.com/google-ai-edge/mediapipe-samples/tree/main/examples/llm_inference/android
I tried running it on an S24 Ultra and it worked well.

@craiglabenz
Collaborator

craiglabenz commented Jun 7, 2024

That very likely means there are issues with the version of the SDK compiled specifically for Flutter. I have filed this as an internal bug on the MediaPipe project (now called google-ai-edge), but I anticipate we may have to be patient to see a fix, as several relevant engineers are currently on leave.

For Googlers: b/349870091

@tempo-riz

Hey, any update on this @craiglabenz? Do you have any time estimate?

@craiglabenz
Collaborator

Sadly, I don't have any updates yet. The relevant MediaPipe engineers are still on leave.

@jawad111

jawad111 commented Jul 25, 2024

Got the following error log on a Samsung A31 when launching the mediapipe_task_genai example with GEMMA_4B_CPU_URI.

I/native (32156): I0000 00:00:1721895466.174139 1280 llm_inference_engine.cc:186] Session config: backend: 1, threads: 4, num_output_candidates: 1, topk: 1, temperature: 1, max_tokens: 512, disable_kv_cache: 0, activation_data_type: 0, fake_weights_mode: 0, benchmark_info.input_token_limit: 0, benchmark_info.wait_for_input_processing: 0, Model type: 6
E/native (32156): E0000 00:00:1721895466.207241 1280 llm_engine.cc:68] NOT_FOUND: ValidatedGraphConfig Initialization failed.
E/native (32156): No registered object with name: TokenizerCalculator; Unable to find Calculator "TokenizerCalculator"
E/native (32156): No registered object with name: DetokenizerCalculator; Unable to find Calculator "DetokenizerCalculator"
E/native (32156): No registered object with name: LlmXnnCalculator; Unable to find Calculator "LlmXnnCalculator"
E/native (32156): No registered object with name: TokenCostCalculator; Unable to find Calculator "TokenCostCalculator"
E/native (32156): No registered object with name: ModelDataCalculator; Unable to find Calculator "ModelDataCalculator"
F/native (32156): F0000 00:00:1721895466.207621 1280 llm_engine.cc:98] Check failed: graph_->ObserveOutputStream( "result", [&, init_start_time, first_token_observed = false]( const mediapipe::Packet& p) mutable { const auto& tokens = p.Get<std::vectorstd::string>(); if (!tokens.empty() && !first_token_observed) { first_token_observed = true; benchmark_result_.time_to_first_token = absl::Now() - init_start_time; } timestamp_++; for (int i = 0; i < tokens.size(); ++i) { outputs_[i].response_tokens.push_back(tokens[i]); } if (async_callback_) { auto step_outputs = std::vector(tokens.size()); for (int i = 0; i < tokens.size(); ++i) { step_outputs[i].response_tokens.push_back(tokens[i]); } async_callback_(step_outputs); } return absl::OkStatus(); }) is OK (INTERNAL: )
F/native (32156): terminating.
F/native (32156): F0000 00:00:1721895466.207621 1280 llm_engine.cc:98] Check failed: graph_->ObserveOutputStream( "result", [&, init_start_time, first_token_observed = false]( const mediapipe::Packet& p) mutable { const auto& tokens = p.Get<std::vectorstd::string>(); if (!tokens.empty() && !first_token_observed) { first_token_observed = true; benchmark_result_.time_to_first_token = absl::Now() - init_start_time; } timestamp_++; for (int i = 0; i < tokens.size(); ++i) { outputs_[i].response_tokens.push_back(tokens[i]); } if (async_callback_) { auto step_outputs = std::vector(tokens.size()); for (int i = 0; i < tokens.size(); ++i) { step_outputs[i].response_tokens.push_back(tokens[i]); } async_callback_(step_outputs); } return absl::OkStatus(); }) is OK (INTERNAL: )
F/native (32156): terminating.
F/libc (32156): Fatal signal 6 (SIGABRT), code -1 (SI_QUEUE) in tid 1280 (DartWorker), pid 32156 (com.example)


Build fingerprint: 'samsung/a31xx/a31:12/SP1A.210812.016/A315FXXS5DXB1:user/release-keys'
Revision: '3'
ABI: 'arm64'
Processor: '6'
Timestamp: 2024-07-25 12:17:46.496010026+0400
Process uptime: 85s
Cmdline: com.example
pid: 32156, tid: 1280, name: DartWorker >>> com.example <<<
uid: 10996
signal 6 (SIGABRT), code -1 (SI_QUEUE), fault addr --------
Abort message: 'F0000 00:00:1721895466.207621 1280 llm_engine.cc:98] Check failed: graph_->ObserveOutputStream( "result", [&, init_start_time, first_token_observed = false]( const mediapipe::Packet& p) mutable { const auto& tokens = p.Get<std::vectorstd::string>(); if (!tokens.empty() && !first_token_observed) { first_token_observed = true; benchmark_result_.time_to_first_token = absl::Now() - init_start_time; } timestamp_++; for (int i = 0; i < tokens.size(); ++i) { outputs_[i].response_tokens.push_back(tokens[i]); } if (async_callback_) { auto step_outputs = std::vector(tokens.size()); for (int i = 0; i < tokens.size(); ++i) { step_outputs[i].response_tokens.push_back(tokens[i]); } async_callback_(step_outputs); } return absl::OkStatus(); }) is OK (INTERNAL: )
'
x0 0000000000000000 x1 0000000000000500 x2 0000000000000006 x3 0000007411503600
x4 0000000100000006 x5 0000000100000006 x6 0000000100000006 x7 0000000100000006
x8 00000000000000f0 x9 00000075320b90d8 x10 ffffff00fffffbdf x11 0000000000000001
x12 0000000000000001 x13 00000000ffffffff x14 0000000000000001 x15 0000000000000001
x16 000000753218ed20 x17 0000007532169400 x18 000000740a6b6000 x19 00000000000000ac
x20 0000000000007d9c x21 00000000000000b2 x22 0000000000000500 x23 00000000ffffffff
x24 0000007300008081 x25 0000007411423000 x26 b40000742192bc00 x27 0000007308f43440
x28 0000000800000073 x29 0000007411503680
lr 000000753211a9d4 sp 00000074115035e0 pc 000000753211aa04 pst 0000000000000000
backtrace:
#00 pc 0000000000089a04 /apex/com.android.runtime/lib64/bionic/libc.so (abort+180) (BuildId: 9f4160992bb8de5917fe05c216139614)
#1 pc 0000000000f636c8 /data/app/~~j6hqN7g2qKf_33ambb9y7g==/com.example-2xzzAbTeOukFvMDe_Cg7bQ==/lib/arm64/liblibllm_inference_engine.so (absl::log_internal::LogMessage::FailWithoutStackTrace()+100)
#2 pc 0000000000f64028 /data/app/~~j6hqN7g2qKf_33ambb9y7g==/com.example-2xzzAbTeOukFvMDe_Cg7bQ==/lib/arm64/liblibllm_inference_engine.so (absl::log_internal::LogMessage::Die()+60)
#3 pc 0000000000f639a4 /data/app/~~j6hqN7g2qKf_33ambb9y7g==/com.example-2xzzAbTeOukFvMDe_Cg7bQ==/lib/arm64/liblibllm_inference_engine.so (absl::log_internal::LogMessage::SendToLog()+176)
#4 pc 0000000000f62edc /data/app/~~j6hqN7g2qKf_33ambb9y7g==/com.example-2xzzAbTeOukFvMDe_Cg7bQ==/lib/arm64/liblibllm_inference_engine.so (absl::log_internal::LogMessage::Flush()+540)
#5 pc 0000000000f64174 /data/app/~~j6hqN7g2qKf_33ambb9y7g==/com.example-2xzzAbTeOukFvMDe_Cg7bQ==/lib/arm64/liblibllm_inference_engine.so (absl::log_internal::LogMessageFatal::~LogMessageFatal()+28)
#6 pc 000000000080823c /data/app/~~j6hqN7g2qKf_33ambb9y7g==/com.example-2xzzAbTeOukFvMDe_Cg7bQ==/lib/arm64/liblibllm_inference_engine.so (odml::infra::LlmInferenceEngine::Session::Session(odml::infra::proto::SessionConfig, mediapipe::CalculatorGraphConfig*)+1112)
#7 pc 0000000000814438 /data/app/~~j6hqN7g2qKf_33ambb9y7g==/com.example-2xzzAbTeOukFvMDe_Cg7bQ==/lib/arm64/liblibllm_inference_engine.so
#8 pc 000000000080ad98 /data/app/~~j6hqN7g2qKf_33ambb9y7g==/com.example-2xzzAbTeOukFvMDe_Cg7bQ==/lib/arm64/liblibllm_inference_engine.so (odml::infra::LlmInferenceEngine::CreateSession(odml::infra::proto::SessionConfig const&, std::__ndk1::optionalodml::infra::proto::LlmModelType)+348)
#9 pc 00000000007f51ec /data/app/~~j6hqN7g2qKf_33ambb9y7g==/com.example-2xzzAbTeOukFvMDe_Cg7bQ==/lib/arm64/liblibllm_inference_engine.so (LlmInferenceEngine_CreateSession+3824)
#10 pc 00000000000081d4 [anon:dart-code]
Lost connection to device.

Exited.

Note: The same model works fine when used with com.google.mediapipe.tasks.genai.llminference.LlmInference using Android Method Channel

@tempo-riz

Hey @jawad111 would you be okay to share a code sample/repo of your implementation using platform channel and android native ?

I tried on my side but couldn't get it to work :/

@jawad111

jawad111 commented Aug 1, 2024

Hey @jawad111 would you be okay to share a code sample/repo of your implementation using platform channel and android native ?

I tried on my side but couldn't get it to work :/

Hi @tempo-riz. First, I would like to say that this repo is the better and more optimized solution. While the issues are being resolved, you can follow this official video by the Flutter team for a detailed reference: Observable Flutter #43: On-device LLMs with Gemma.

@tempo-riz

Hey @craiglabenz, any update on this?

@craiglabenz
Collaborator

We're attempting to finalize contracts with a great developer to finish work on this library, which should ultimately get and keep it in great shape. Unfortunately, my employer can be slow with paperwork. Still, I'm cautiously optimistic that renewed engineering work on this library will begin in Q1 2025 🤞

@yashpapa6969

@craiglabenz is it possible to get an internship for this project? I am quite interested in it.

@SumeshBharathi

Faced the same issue. :(

@keyfayaz

keyfayaz commented Jan 5, 2025

Same here.

@ndnam198

ndnam198 commented Jan 5, 2025

Same here too

@ASahu16

ASahu16 commented Jan 6, 2025

+1
Does anyone have a workaround?

@jawad111

jawad111 commented Jan 21, 2025

Sorry for the late response. Please find below a working example of using com.google.mediapipe.tasks.genai.llminference.LlmInference with an Android MethodChannel. This should work as a temporary workaround.

Example Project Link: https://github.com/jawad111/mediapipe_llminference_example.git

Tested on Devices:

  • Samsung A31
  • Samsung A15

Considerations for project setup

  • In AndroidManifest.xml, add android:largeHeap="true" to the <application> tag.
  • In AndroidManifest.xml, add <uses-native-library android:name="libOpenCL.so" android:required="false"/> inside the <application> tag.

Model Setup Instructions

  • Download the model from: https://www.kaggle.com/models/google/gemma/tfLite/

  • Rename the model to model.bin for consistency with the app's configuration.

  • Push the model to the device path /data/local/tmp/llm/model.bin:

    adb push model.bin /data/local/tmp/llm/model.bin


Project Overview

Widget Layer

 1. Initialize Model Button
    onPressed: calls LlmService.initializeModel()

 2. Generate Response Button
    onPressed: calls LlmService.generateResponse()

Service Layer: LlmService

 1. initializeModel() -> calls MethodChannel.invokeMethod('initialize', {...})
 2. generateResponse() -> calls MethodChannel.invokeMethod('generateResponse', {...})

Native Code: Kotlin

MethodChannel handler (MainActivity.kt)

 1. 'initialize' -> runs the Kotlin logic that loads the model
 2. 'generateResponse' -> runs the logic that generates a response

Code Snippets for Each Step

1. Widget Layer

Initialize Model Button:
ElevatedButton(
  // onPressed expects a callback, so wrap the async call in a closure.
  onPressed: () => LlmService().initializeModel(),
  child: const Text('Initialize Model'),
);


Generate Response Button:
ElevatedButton(
  onPressed: () => LlmService().generateResponse(_promptController.text),
  child: const Text('Sync Response'),
);

2. Service Layer (LlmService)

import 'package:flutter/services.dart';

class LlmService {
  static const MethodChannel _channel =
      MethodChannel('com.example.mediapipe_llminference_example/inference');

  Future<String> initializeModel() async {
    try {
      final result = await _channel.invokeMethod('initialize', {
        'modelPath': '/data/local/tmp/llm/model.bin',
        'maxTokens': 50,
        'temperature': 0.7,
        'randomSeed': 42,
        'topK': 40,
      });
      return 'Initialization result: $result';
    } catch (e) {
      return 'Error initializing model: $e';
    }
  }

  Future<String> generateResponse(String prompt) async {
    try {
      final result = await _channel.invokeMethod('generateResponse', {
        'prompt': prompt,
      });
      return result ?? 'No response received';
    } catch (e) {
      return 'Error generating response: $e';
    }
  }
}
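
For context, here is a minimal sketch of how the two service calls might be wired together from a widget, assuming the LlmService class above. The ChatScreen widget and the _output field are hypothetical and not part of the example repo; the point is the call order (initialize the model once, then send prompts) and that awaiting the futures with setState is what makes the results show up in the UI.

import 'package:flutter/material.dart';

// Hypothetical screen that wires the buttons to LlmService (defined above).
class ChatScreen extends StatefulWidget {
  const ChatScreen({super.key});

  @override
  State<ChatScreen> createState() => _ChatScreenState();
}

class _ChatScreenState extends State<ChatScreen> {
  final _promptController = TextEditingController();
  final _llm = LlmService();
  String _output = '';

  // Initialize once before generating responses.
  Future<void> _initialize() async {
    final status = await _llm.initializeModel();
    setState(() => _output = status);
  }

  // Send the prompt and display the reply (or the error string).
  Future<void> _send() async {
    final reply = await _llm.generateResponse(_promptController.text);
    setState(() => _output = reply);
  }

  @override
  Widget build(BuildContext context) {
    return Column(
      children: [
        TextField(controller: _promptController),
        ElevatedButton(
            onPressed: _initialize, child: const Text('Initialize Model')),
        ElevatedButton(onPressed: _send, child: const Text('Sync Response')),
        Text(_output),
      ],
    );
  }
}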

3. Native Code (Kotlin)

import android.os.Bundle
import com.google.mediapipe.tasks.genai.llminference.LlmInference
import io.flutter.embedding.android.FlutterActivity
import io.flutter.embedding.engine.FlutterEngine
import io.flutter.plugin.common.MethodChannel
import java.io.File
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.channels.BufferOverflow
import kotlinx.coroutines.flow.MutableSharedFlow
import kotlinx.coroutines.flow.collectLatest
import kotlinx.coroutines.launch

class MainActivity : FlutterActivity() {

    private val CHANNEL = "com.example.mediapipe_llminference_example/inference"
    private var llmInference: LlmInference? = null
    private val partialResultsFlow = MutableSharedFlow<Pair<String, Boolean>>(
        extraBufferCapacity = 1,
        onBufferOverflow = BufferOverflow.DROP_OLDEST
    )

    override fun configureFlutterEngine(flutterEngine: FlutterEngine) {
        super.configureFlutterEngine(flutterEngine)

        MethodChannel(flutterEngine.dartExecutor.binaryMessenger, CHANNEL).setMethodCallHandler { call, result ->
            when (call.method) {
                "initialize" -> {
                    val modelPath = call.argument<String>("modelPath") ?: ""
                    val maxTokens = call.argument<Int>("maxTokens") ?: 50
                    val temperature: Float = (call.argument<Double>("temperature")?.toFloat() ?: 0.7f)
                    val randomSeed = call.argument<Int>("randomSeed") ?: 42
                    val topK = call.argument<Int>("topK") ?: 40

                    if (!File(modelPath).exists()) {
                        result.error("INIT_ERROR", "Model not found at path: $modelPath", null)
                        return@setMethodCallHandler
                    }

                    try {
                        val options = LlmInference.LlmInferenceOptions.builder()
                            .setModelPath(modelPath)
                            .setMaxTokens(maxTokens)
                            .setTemperature(temperature)
                            .setRandomSeed(randomSeed)
                            .setTopK(topK)
                            .setResultListener { partialResult, done ->
                                partialResultsFlow.tryEmit(partialResult to done)
                            }
                            .build()

                        llmInference = LlmInference.createFromOptions(this, options)
                        result.success("Model initialized successfully")
                    } catch (e: Exception) {
                        result.error("INIT_ERROR", "Failed to initialize: ${e.message}", null)
                    }
                }
                "generateResponse" -> {
                    val prompt = call.argument<String>("prompt") ?: ""
                    try {
                        val response = llmInference?.generateResponse(prompt)
                        if (response != null) {
                            result.success(response)
                        } else {
                            result.error("GEN_ERROR", "Failed to generate response", null)
                        }
                    } catch (e: Exception) {
                        result.error("GEN_ERROR", "Error generating response: ${e.message}", null)
                    }
                }
                else -> result.notImplemented()
            }
        }
    }

    override fun onCreate(savedInstanceState: Bundle?) {
        super.onCreate(savedInstanceState)

        // Listen for partial results if necessary
        CoroutineScope(Dispatchers.Main).launch {
            partialResultsFlow.collectLatest { (partialResult, done) ->
                // Optional: Use EventChannel or logs to send partial results back to Flutter.
                println("Partial result: $partialResult, Done: $done")
            }
        }
    }
}
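
As the comment in onCreate notes, partial results are only printed on the native side in this example. If you wanted to stream them back to Flutter, an EventChannel would be the usual mechanism; a rough Dart-side sketch is below. The channel name is an assumption, and the corresponding Kotlin EventChannel registration (forwarding the MutableSharedFlow to the channel's EventSink instead of println) is not implemented in the example repo.

import 'package:flutter/services.dart';

// Sketch only: assumes the native side registers an EventChannel with this
// name and emits each partial result string as it is produced.
const EventChannel _partialResults = EventChannel(
    'com.example.mediapipe_llminference_example/partial_results');

// Exposes partial results as a broadcast stream of strings.
Stream<String> partialResponseStream() {
  return _partialResults
      .receiveBroadcastStream()
      .map((event) => event as String);
}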

