Skip to content

Commit

Permalink
[Serving] Image support in JSONFFIEngine (mlc-ai#2208)
Browse files Browse the repository at this point in the history
Using new Result interface

Co-authored-by: Animesh Bohara <abohara@cs.cmu.edu>
  • Loading branch information
anibohara2000 and Animesh Bohara authored May 6, 2024
1 parent d31941f commit 5ae393a
Show file tree
Hide file tree
Showing 12 changed files with 481 additions and 100 deletions.
3 changes: 3 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,6 @@
[submodule "3rdparty/tvm"]
path = 3rdparty/tvm
url = https://github.com/mlc-ai/relax.git
[submodule "3rdparty/stb"]
path = 3rdparty/stb
url = https://github.com/nothings/stb.git
1 change: 1 addition & 0 deletions 3rdparty/stb
Submodule stb added at ae721c
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ target_include_directories(mlc_llm_objs PRIVATE ${MLC_LLM_INCLUDES})
target_compile_definitions(mlc_llm_objs PRIVATE ${MLC_LLM_COMPILE_DEFS})
target_include_directories(mlc_llm_objs PRIVATE ${TOKENZIER_CPP_PATH}/include)
target_compile_definitions(mlc_llm_objs PRIVATE -DMLC_LLM_EXPORTS)
target_include_directories(mlc_llm_objs PRIVATE 3rdparty/stb)

add_library(mlc_llm SHARED $<TARGET_OBJECTS:mlc_llm_objs>)
add_library(mlc_llm_static STATIC $<TARGET_OBJECTS:mlc_llm_objs>)
Expand Down
144 changes: 143 additions & 1 deletion cpp/json_ffi/conv_template.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,132 @@
#include <tvm/runtime/registry.h>

#include "../support/json_parser.h"
#include "image_utils.h"

namespace mlc {
namespace llm {
namespace json_ffi {

using namespace mlc::llm;

/****************** Model vision config ******************/

ModelVisionConfig ModelVisionConfig::FromJSON(const picojson::object& json_obj) {
  ModelVisionConfig config;

  // Every field is optional: when a key is missing (or the lookup fails) the
  // config keeps its current value, exactly like the original per-field
  // lookup-then-assign stanzas. The lambdas only factor out that pattern.
  auto read_int = [&json_obj](const char* key, int& field) {
    Result<int64_t> res = json::LookupWithResultReturn<int64_t>(json_obj, key);
    if (res.IsOk()) {
      field = res.Unwrap();
    }
  };
  auto read_double = [&json_obj](const char* key, double& field) {
    Result<double> res = json::LookupWithResultReturn<double>(json_obj, key);
    if (res.IsOk()) {
      field = res.Unwrap();
    }
  };
  auto read_string = [&json_obj](const char* key, std::string& field) {
    Result<std::string> res = json::LookupWithResultReturn<std::string>(json_obj, key);
    if (res.IsOk()) {
      field = res.Unwrap();
    }
  };

  // Same lookup order as before (lookups are side-effect free).
  read_int("hidden_size", config.hidden_size);
  read_int("image_size", config.image_size);
  read_int("intermediate_size", config.intermediate_size);
  read_int("num_attention_heads", config.num_attention_heads);
  read_int("num_hidden_layers", config.num_hidden_layers);
  read_int("patch_size", config.patch_size);
  read_int("projection_dim", config.projection_dim);
  read_int("vocab_size", config.vocab_size);
  read_string("dtype", config.dtype);
  read_int("num_channels", config.num_channels);
  read_double("layer_norm_eps", config.layer_norm_eps);

  return config;
}

/****************** Model config ******************/

/*!
 * \brief Populate a ModelConfig from the "model_config" JSON object.
 *
 * Every scalar field is optional: a missing key (or a failed lookup) leaves
 * the corresponding member at its current value.
 *
 * \param json_obj Parsed JSON object of the model config.
 * \return The populated config.
 */
ModelConfig ModelConfig::FromJSON(const picojson::object& json_obj) {
  ModelConfig config;

  Result<int64_t> vocab_size_res = json::LookupWithResultReturn<int64_t>(json_obj, "vocab_size");
  if (vocab_size_res.IsOk()) {
    config.vocab_size = vocab_size_res.Unwrap();
  }

  Result<int64_t> context_window_size_res =
      json::LookupWithResultReturn<int64_t>(json_obj, "context_window_size");
  if (context_window_size_res.IsOk()) {
    config.context_window_size = context_window_size_res.Unwrap();
  }

  Result<int64_t> sliding_window_size_res =
      json::LookupWithResultReturn<int64_t>(json_obj, "sliding_window_size");
  if (sliding_window_size_res.IsOk()) {
    config.sliding_window_size = sliding_window_size_res.Unwrap();
  }

  Result<int64_t> prefill_chunk_size_res =
      json::LookupWithResultReturn<int64_t>(json_obj, "prefill_chunk_size");
  if (prefill_chunk_size_res.IsOk()) {
    config.prefill_chunk_size = prefill_chunk_size_res.Unwrap();
  }

  Result<int64_t> tensor_parallel_shards_res =
      json::LookupWithResultReturn<int64_t>(json_obj, "tensor_parallel_shards");
  if (tensor_parallel_shards_res.IsOk()) {
    config.tensor_parallel_shards = tensor_parallel_shards_res.Unwrap();
  }

  Result<int64_t> max_batch_size_res =
      json::LookupWithResultReturn<int64_t>(json_obj, "max_batch_size");
  if (max_batch_size_res.IsOk()) {
    config.max_batch_size = max_batch_size_res.Unwrap();
  }

  // "vision_config" is only present for vision-language models. Check the
  // value's type as well as its presence: `at(...).get<picojson::object>()`
  // on a non-object value would abort instead of being skipped like every
  // other malformed optional field above.
  auto vision_it = json_obj.find("vision_config");
  if (vision_it != json_obj.end() && vision_it->second.is<picojson::object>()) {
    config.vision_config =
        ModelVisionConfig::FromJSON(vision_it->second.get<picojson::object>());
  }

  return config;
}

/****************** Conversation template ******************/

std::map<MessagePlaceholders, std::string> PLACEHOLDERS = {
Expand All @@ -34,7 +153,7 @@ Conversation::Conversation()
{"assistant", PLACEHOLDERS[MessagePlaceholders::ASSISTANT]},
{"tool", PLACEHOLDERS[MessagePlaceholders::TOOL]}}) {}

Result<std::vector<Data>> Conversation::AsPrompt() {
Result<std::vector<Data>> Conversation::AsPrompt(ModelConfig config, DLDevice device) {
using TResult = Result<std::vector<Data>>;
// Get the system message
std::string system_msg = system_template;
Expand Down Expand Up @@ -116,6 +235,29 @@ Result<std::vector<Data>> Conversation::AsPrompt() {
}
}
message += role_text;
} else if (it_type->second == "image_url") {
if (item.find("image_url") == item.end()) {
return TResult::Error("Content should have an image_url field");
}
std::string image_url =
item.at("image_url"); // TODO(mlc-team): According to OpenAI API reference this
// should be a map, with a "url" key containing the URL, but
// we are just assuming this as the URL for now
std::string base64_image = image_url.substr(image_url.find(",") + 1);
Result<NDArray> image_data_res = LoadImageFromBase64(base64_image);
if (image_data_res.IsErr()) {
return TResult::Error(image_data_res.UnwrapErr());
}
if (!config.vision_config.has_value()) {
return TResult::Error("Vision config is required for image input");
}
int image_size = config.vision_config.value().image_size;
int patch_size = config.vision_config.value().patch_size;

int embed_size = (image_size * image_size) / (patch_size * patch_size);

auto image_ndarray = ClipPreprocessor(image_data_res.Unwrap(), image_size, device);
message_list.push_back(ImageData(image_ndarray, embed_size));
} else {
return TResult::Error("Unsupported content type: " + it_type->second);
}
Expand Down
39 changes: 38 additions & 1 deletion cpp/json_ffi/conv_template.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,43 @@ namespace mlc {
namespace llm {
namespace json_ffi {

/****************** Model vision config ******************/

/*!
 * \brief Defines the Vision config of the model (if present).
 *
 * Members carry explicit zero defaults: FromJSON only assigns fields whose
 * keys are present in the JSON, so without initializers a partially-populated
 * config would expose indeterminate values (undefined behavior on read).
 */
class ModelVisionConfig {
 public:
  int hidden_size = 0;
  int image_size = 0;
  int intermediate_size = 0;
  int num_attention_heads = 0;
  int num_hidden_layers = 0;
  int patch_size = 0;
  int projection_dim = 0;
  int vocab_size = 0;
  std::string dtype;  // empty until set from JSON
  int num_channels = 0;
  double layer_norm_eps = 0.0;

  /*! \brief Build a config from JSON; absent keys keep the defaults above. */
  static ModelVisionConfig FromJSON(const picojson::object& json_obj);
};

/****************** Model config ******************/

/*! \brief Defines the config of the model.
    Populated from "model_config" field in mlc-chat-config.json.

    Members carry explicit zero defaults: FromJSON only assigns fields whose
    keys are present in the JSON, so without initializers a partially-populated
    config would expose indeterminate values (undefined behavior on read). */
class ModelConfig {
 public:
  int vocab_size = 0;
  int context_window_size = 0;
  int sliding_window_size = 0;
  int prefill_chunk_size = 0;
  int tensor_parallel_shards = 0;
  int max_batch_size = 0;
  // Present only for vision-language models ("vision_config" in JSON).
  std::optional<ModelVisionConfig> vision_config = std::nullopt;

  /*! \brief Build a config from JSON; absent keys keep the defaults above. */
  static ModelConfig FromJSON(const picojson::object& json_obj);
};

/****************** Conversation template ******************/

enum class MessagePlaceholders { SYSTEM, USER, ASSISTANT, TOOL, FUNCTION };
Expand Down Expand Up @@ -92,7 +129,7 @@ struct Conversation {
Conversation();

/*! \brief Create the list of prompts from the messages based on the conversation template. */
Result<std::vector<Data>> AsPrompt();
Result<std::vector<Data>> AsPrompt(ModelConfig config, DLDevice device);

/*! \brief Create a Conversation instance from the given JSON object. */
static Result<Conversation> FromJSON(const picojson::object& json);
Expand Down
156 changes: 156 additions & 0 deletions cpp/json_ffi/image_utils.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
#include "image_utils.h"

#include <dmlc/io.h>

#include "../../3rdparty/tvm/src/support/base64.h"
#define STB_IMAGE_IMPLEMENTATION
#include "stb_image.h"

namespace mlc {
namespace llm {
namespace json_ffi {

using namespace tvm::runtime;

/*!
 * \brief Read-only dmlc::Stream adapter over a caller-owned byte buffer.
 *
 * The buffer is borrowed, not copied; it must outlive the stream. Used to
 * feed an in-memory base64 string into tvm::support::Base64InStream.
 */
class MemoryBufferStream : public dmlc::Stream {
 public:
  MemoryBufferStream(const char* data, size_t size) : data_(data), size_(size), pos_(0) {}

  size_t Read(void* ptr, size_t size) override {
    // Clamp the request to what is left in the buffer.
    const size_t available = size_ - pos_;
    const size_t n_read = size > available ? available : size;
    if (n_read != 0) {
      std::memcpy(ptr, data_ + pos_, n_read);
      pos_ += n_read;
    }
    return n_read;
  }

  void Write(const void* ptr, size_t size) override {
    LOG(FATAL) << "MemoryBufferStream does not support write";
  }

 private:
  const char* data_;  // borrowed buffer, not owned
  size_t size_;       // total bytes in the buffer
  size_t pos_;        // current read offset
};

/*!
 * \brief Compute the decoded byte length of a base64 string.
 * \param base64_str The base64-encoded payload.
 * \return 3 * len / 4 minus the number of trailing '=' padding chars;
 *         0 for strings shorter than 2 characters.
 *
 * The length guard fixes out-of-bounds reads of base64_str[len - 1] and
 * base64_str[len - 2] on empty or single-character input (undefined behavior
 * in the original).
 */
size_t Base64DecodedSize(const std::string& base64_str) {
  size_t len = base64_str.size();
  if (len < 2) {
    return 0;  // nothing decodable; also avoids indexing len - 1 / len - 2
  }
  size_t padding = 0;
  if (base64_str[len - 1] == '=') {
    padding++;
  }
  if (base64_str[len - 2] == '=') {
    padding++;
  }
  return 3 * len / 4 - padding;
}

/*!
 * \brief Decode a base64-encoded image into a CPU NDArray.
 * \param base64_str The base64 payload (already stripped of any data-URL
 *        prefix by the caller).
 * \return On success, a {height, width, 3} uint8 NDArray on CPU (stb is asked
 *         for exactly 3 channels); on failure, the stb failure reason string.
 */
Result<NDArray> LoadImageFromBase64(const std::string& base64_str) {
  using TResult = Result<NDArray>;
  // Wrap the string so tvm's Base64InStream can pull raw bytes from it.
  MemoryBufferStream stream(base64_str.c_str(), base64_str.size());
  tvm::support::Base64InStream base64_stream(&stream);
  size_t decoded_size = Base64DecodedSize(base64_str);
  std::vector<unsigned char> decoded(decoded_size);
  base64_stream.InitPosition();
  // NOTE(review): the number of bytes actually read is not checked here —
  // a truncated/invalid payload is only caught by stb below. TODO confirm.
  base64_stream.Read((void*)decoded.data(), decoded_size);
  int width, height, num_channels;
  // Last argument 3 forces RGB output regardless of the source channel count.
  unsigned char* image_data =
      stbi_load_from_memory(decoded.data(), decoded_size, &width, &height, &num_channels, 3);
  if (!image_data) {
    return TResult::Error(stbi_failure_reason());
  }
  // Copy the pixels into a HWC uint8 NDArray, then release the stb buffer.
  auto image_ndarray = NDArray::Empty({height, width, 3}, {kDLUInt, 8, 1}, {kDLCPU, 0});
  image_ndarray.CopyFromBytes((void*)image_data, width * height * 3);
  stbi_image_free(image_data);
  return TResult::Ok(image_ndarray);
}

/*!
 * \brief CLIP-style image preprocessing pipeline.
 *
 * Steps: (1) bilinear resize so the shorter side equals target_size while
 * preserving aspect ratio, (2) center-crop to target_size x target_size,
 * (3) rescale pixel values to [0, 1], (4) normalize with the CLIP per-channel
 * mean/std constants, (5) reorder HWC -> CHW.
 *
 * \param image_data Input image; indexed here as a {height, width, 3} uint8
 *        array (the layout produced by LoadImageFromBase64).
 * \param target_size Side length of the square output.
 * \param device Device on which the output NDArray is allocated.
 * \return A {1, 3, target_size, target_size} float32 NDArray on `device`.
 */
NDArray ClipPreprocessor(NDArray image_data, int target_size, DLDevice device) {
  int height = image_data->shape[0];
  int width = image_data->shape[1];
  // Resize: shorter side -> target_size; longer side scaled by the same ratio.
  // When width == height both ternaries pick new_long_side, which equals
  // target_size in that case.
  const int short_side = width < height ? width : height;
  const int long_side = width > height ? width : height;
  const int new_short_side = target_size;
  const int new_long_side = (int)(new_short_side * (long_side / (float)short_side));
  const int new_width = width < height ? new_short_side : new_long_side;
  const int new_height = width > height ? new_short_side : new_long_side;

  std::vector<float> processed_image_data(new_width * new_height * 3);

  // Bilinear Interpolation
  for (int y = 0; y < new_height; y++) {
    for (int x = 0; x < new_width; x++) {
      // Map the destination pixel to a fractional source coordinate. Using
      // (width - 1) / new_width keeps x1 <= width - 2, so x2 = x1 + 1 stays
      // in bounds (likewise for y).
      const float x_ratio = float(width - 1) / new_width;
      const float y_ratio = float(height - 1) / new_height;
      const int x1 = int(x_ratio * x);
      const int y1 = int(y_ratio * y);
      const int x2 = x1 + 1;
      const int y2 = y1 + 1;
      const float x_diff = x_ratio * x - x1;
      const float y_diff = y_ratio * y - y1;
      // Weighted average of the 2x2 neighborhood, per channel.
      for (int c = 0; c < 3; c++) {
        const uint8_t top_left = ((uint8_t*)image_data->data)[(y1 * width + x1) * 3 + c];
        const uint8_t top_right = ((uint8_t*)image_data->data)[(y1 * width + x2) * 3 + c];
        const uint8_t bottom_left = ((uint8_t*)image_data->data)[(y2 * width + x1) * 3 + c];
        const uint8_t bottom_right = ((uint8_t*)image_data->data)[(y2 * width + x2) * 3 + c];
        processed_image_data[(y * new_width + x) * 3 + c] =
            (float)(int(top_left * (1 - x_diff) * (1 - y_diff) + top_right * x_diff * (1 - y_diff) +
                        bottom_left * y_diff * (1 - x_diff) + bottom_right * x_diff * y_diff));
      }
    }
  }

  // Center crop: take the middle target_size x target_size window of the
  // resized image (one of crop_x / crop_y is 0 by construction of the resize).
  const int crop_x = (new_width - target_size) / 2;
  const int crop_y = (new_height - target_size) / 2;
  std::vector<float> cropped_image_data(target_size * target_size * 3);
  for (int y = 0; y < target_size; y++) {
    for (int x = 0; x < target_size; x++) {
      for (int c = 0; c < 3; c++) {
        cropped_image_data[(y * target_size + x) * 3 + c] =
            processed_image_data[((y + crop_y) * new_width + x + crop_x) * 3 + c];
      }
    }
  }

  // Rescale: uint8 pixel range [0, 255] -> [0, 1].
  for (int i = 0; i < target_size * target_size * 3; i++) {
    cropped_image_data[i] = cropped_image_data[i] / 255.0f;
  }

  // Normalize with CLIP's per-channel mean and std (RGB order).
  const float IMAGE_MEAN[] = {0.48145466f, 0.4578275f, 0.40821073f};
  const float IMAGE_STD[] = {0.26862954f, 0.26130258f, 0.27577711f};
  for (int i = 0; i < target_size * target_size * 3; i++) {
    const int c = i % 3;  // data is still interleaved HWC here
    cropped_image_data[i] = (cropped_image_data[i] - IMAGE_MEAN[c]) / IMAGE_STD[c];
  }

  // Transpose HWC -> CHW for the model input layout.
  std::vector<float> image_data_channel_first(target_size * target_size * 3);
  for (int y = 0; y < target_size; y++) {
    for (int x = 0; x < target_size; x++) {
      for (int c = 0; c < 3; c++) {
        image_data_channel_first[c * target_size * target_size + y * target_size + x] =
            cropped_image_data[(y * target_size + x) * 3 + c];
      }
    }
  }

  // Create NDArray with a leading batch dimension of 1 and copy the pixels in.
  auto image_ndarray = NDArray::Empty({1, 3, target_size, target_size}, {kDLFloat, 32, 1}, device);
  image_ndarray.CopyFromBytes((void*)image_data_channel_first.data(),
                              target_size * target_size * 3 * sizeof(float));

  return image_ndarray;
}

} // namespace json_ffi
} // namespace llm
} // namespace mlc
Loading

0 comments on commit 5ae393a

Please sign in to comment.