2 changes: 1 addition & 1 deletion .ci/scripts/test_llava.sh
@@ -149,7 +149,7 @@ run_and_verify() {

# verify result.txt
RESULT=$(cat result.txt)
EXPECTED_PREFIX="ASSISTANT: image captures a basketball game in progress, with"
EXPECTED_PREFIX="ASSISTANT: The image captures a basketball game in progress, with"

if [[ "${RESULT}" == *"${EXPECTED_PREFIX}"* ]]; then
echo "Expected result prefix: ${EXPECTED_PREFIX}"
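Note: the expected prefix changes only by the leading "The"; the update tracks the output of the migrated MultimodalRunner path introduced below for the same image and prompt.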
60 changes: 53 additions & 7 deletions examples/models/llava/main.cpp
@@ -6,8 +6,11 @@
* LICENSE file in the root directory of this source tree.
*/

-#include <executorch/examples/models/llava/runner/llava_runner.h>
+#include <executorch/extension/llm/runner/image.h>
+#include <executorch/extension/llm/runner/multimodal_input.h>
+#include <executorch/extension/llm/runner/multimodal_runner.h>
#include <gflags/gflags.h>
+#include <pytorch/tokenizers/llama2c_tokenizer.h>
#define STB_IMAGE_IMPLEMENTATION
#include <stb_image.h>
#define STB_IMAGE_RESIZE_IMPLEMENTATION
@@ -44,7 +47,10 @@ DEFINE_int32(
-1,
"Number of CPU threads for inference. Defaults to -1, which implies we'll use a heuristic to derive the # of performant cores for a specific device.");

-using executorch::extension::llm::Image;
+using ::executorch::extension::llm::Image;
+using ::executorch::extension::llm::make_image_input;
+using ::executorch::extension::llm::make_text_input;
+using ::executorch::extension::llm::MultimodalInput;

void load_image(const std::string& image_path, Image& image) {
int width, height, channels;
@@ -127,14 +133,54 @@ int32_t main(int32_t argc, char** argv) {
->_unsafe_reset_threadpool(num_performant_cores);
}
#endif
-// create llama runner
-example::LlavaRunner runner(model_path, tokenizer_path, temperature);
+// Load tokenizer
+std::unique_ptr<::tokenizers::Tokenizer> tokenizer =
+std::make_unique<tokenizers::Llama2cTokenizer>();
+// Check the load() result directly: std::make_unique never returns nullptr, so a null test cannot detect a bad tokenizer file.
+if (tokenizer->load(tokenizer_path) != ::tokenizers::Error::Ok) {
+ET_LOG(Error, "Failed to load tokenizer from: %s", tokenizer_path);
+return 1;
+}

+// Create multimodal runner
+std::unique_ptr<::executorch::extension::llm::MultimodalRunner> runner =
+::executorch::extension::llm::create_multimodal_runner(
+model_path, std::move(tokenizer));
+if (runner == nullptr) {
+ET_LOG(Error, "Failed to create multimodal runner");
+return 1;
+}

+// Load runner
+auto load_error = runner->load();
+if (load_error != ::executorch::runtime::Error::Ok) {
+ET_LOG(Error, "Failed to load multimodal runner");
+return 1;
+}

+// Prepare inputs
+static const char* kPresetPrompt =
+"A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER: ";
Image image;
load_image(image_path, image);
-std::vector<Image> images = {image};
+std::vector<MultimodalInput> inputs = {
+make_text_input(std::string(kPresetPrompt)),
+make_image_input(image),
+make_text_input(std::string(prompt)),
+};

+::executorch::extension::llm::GenerationConfig config;
+config.temperature = temperature;
+config.echo = true;

+// Generate
+ET_LOG(Info, "Starting generation...");
+auto error = runner->generate(inputs, config);
+if (error != ::executorch::runtime::Error::Ok) {
+ET_LOG(Error, "Failed to generate with multimodal runner");
+return 1;
+}

-// generate
-runner.generate(std::move(images), prompt, seq_len);
printf("\n");
return 0;
}
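
For reference, the migrated flow in main.cpp reduces to: load a tokenizer, inject it into a generic runner, assemble an ordered list of text/image inputs, configure, generate. Below is a condensed sketch of that call sequence using only the APIs exercised by this diff; the function name `run_llava` and its error handling are illustrative additions, not part of the PR.

// Condensed sketch of the migrated flow (APIs as used in the diff above;
// run_llava and its early-return error handling are illustrative).
#include <executorch/extension/llm/runner/image.h>
#include <executorch/extension/llm/runner/multimodal_input.h>
#include <executorch/extension/llm/runner/multimodal_runner.h>
#include <pytorch/tokenizers/llama2c_tokenizer.h>

#include <memory>
#include <string>
#include <utility>
#include <vector>

namespace llm = ::executorch::extension::llm;

int run_llava(
    const char* model_path,
    const char* tokenizer_path,
    llm::Image image,
    const std::string& user_prompt,
    float temperature) {
  // The tokenizer is now constructed by the caller and handed to the runner,
  // instead of being built internally by the retired LlavaRunner.
  auto tokenizer = std::make_unique<::tokenizers::Llama2cTokenizer>();
  if (tokenizer->load(tokenizer_path) != ::tokenizers::Error::Ok) {
    return 1;
  }

  // One generic multimodal runner replaces the model-specific LlavaRunner.
  auto runner =
      llm::create_multimodal_runner(model_path, std::move(tokenizer));
  if (runner == nullptr ||
      runner->load() != ::executorch::runtime::Error::Ok) {
    return 1;
  }

  // Inputs are an ordered list of segments: preset text, image, user text.
  std::vector<llm::MultimodalInput> inputs = {
      llm::make_text_input("USER: "),
      llm::make_image_input(std::move(image)),
      llm::make_text_input(user_prompt),
  };

  // Sampling options travel in a GenerationConfig rather than constructor
  // arguments; echo streams the trailing text segment back to the caller.
  llm::GenerationConfig config;
  config.temperature = temperature;
  config.echo = true;

  return runner->generate(inputs, config) == ::executorch::runtime::Error::Ok
      ? 0
      : 1;
}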
6 changes: 5 additions & 1 deletion extension/llm/runner/multimodal_runner.cpp
@@ -104,7 +104,11 @@ Error MultimodalRunner::generate(

uint64_t prefill_next_token = 0;
// Process multimodal inputs in order
for (const MultimodalInput& input : inputs) {
for (size_t i = 0; i < inputs.size(); ++i) {
const MultimodalInput& input = inputs[i];
if (config.echo && i == inputs.size() - 1 && input.is_text()) {
wrapped_callback(input.get_text());
}
prefill_next_token = ET_UNWRAP(multimodal_prefiller_->prefill(input, pos_));
}
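
The loop rewrite above is the behavioral core of this runner-side change: with `config.echo` enabled, the final input — when it is text, i.e. the user's prompt — is passed to the token callback before prefill, so callers see their prompt ahead of the generated reply, while preset and mid-sequence segments are never echoed. A self-contained toy model of that rule (plain std types, no ExecuTorch dependency; every name here is illustrative):

#include <functional>
#include <iostream>
#include <string>
#include <vector>

// Toy stand-in for MultimodalInput: just enough to model text vs. image.
struct ToyInput {
  bool is_text;
  std::string text;  // empty for image segments
};

// Mirrors the echo rule in MultimodalRunner::generate: only the last input
// is echoed, and only if it is text and echo is enabled.
void prefill_all(
    const std::vector<ToyInput>& inputs,
    bool echo,
    const std::function<void(const std::string&)>& token_callback) {
  for (size_t i = 0; i < inputs.size(); ++i) {
    if (echo && i == inputs.size() - 1 && inputs[i].is_text) {
      token_callback(inputs[i].text);  // prompt reaches the user first
    }
    // ... real code prefills inputs[i] into the model here ...
  }
}

int main() {
  std::vector<ToyInput> inputs = {
      {true, "A chat between a human and an assistant. USER: "},
      {false, ""},  // image slot
      {true, "What is in the picture? ASSISTANT: "},
  };
  // Prints only the final text segment; earlier segments are never echoed.
  prefill_all(inputs, /*echo=*/true,
              [](const std::string& t) { std::cout << t; });
  std::cout << "\n";
  return 0;
}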
