Merge pull request jmont-dev#16 from jmont-dev/tool-support
Update embeddings to new endpoint; change tests to not rely on exact comparisons.
jmont-dev authored Aug 12, 2024
2 parents f4909c6 + 0aaa3bd commit fdd114f
Showing 3 changed files with 46 additions and 50 deletions.
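In practical terms, the embedding request body changes shape: ollama::request::from_embedding now fills in "input" and "truncate" rather than "prompt", and the library posts the request to /api/embed instead of /api/embeddings. The following is a minimal sketch of building such a request with the updated helper; the model name is simply the one the tests use, and the snippet is illustrative rather than part of the diffs below.

#include "ollama.hpp"
#include <iostream>

int main() {
    // Build an embedding request with the updated helper. The defaulted
    // arguments are options=nullptr, truncate=true, keep_alive="5m".
    ollama::request request =
        ollama::request::from_embedding("llama3:8b", "Why is the sky blue?");

    // Prints a JSON body along the lines of:
    // {"input":"Why is the sky blue?","keep_alive":"5m","model":"llama3:8b","truncate":true}
    // which is then posted to the /api/embed endpoint.
    std::cout << request.dump() << std::endl;
}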
19 changes: 10 additions & 9 deletions include/ollama.hpp
@@ -263,13 +263,14 @@ namespace ollama
 request(): json() {}
 ~request(){};

-static ollama::request from_embedding(const std::string& name, const std::string& prompt, const json& options=nullptr, const std::string& keep_alive_duration="5m")
+static ollama::request from_embedding(const std::string& model, const std::string& input, const json& options=nullptr, bool truncate=true, const std::string& keep_alive_duration="5m")
 {
 ollama::request request(message_type::embedding);

-request["model"] = name;
-request["prompt"] = prompt;
+request["model"] = model;
+request["input"] = input;
 if (options!=nullptr) request["options"] = options["options"];
+request["truncate"] = truncate;
 request["keep_alive"] = keep_alive_duration;

 return request;
@@ -295,7 +296,7 @@ namespace ollama

 if (type==message_type::generation && json_data.contains("response")) simple_string=json_data["response"].get<std::string>();
 else
-if (type==message_type::embedding && json_data.contains("embedding")) simple_string=json_data["embedding"].get<std::string>();
+if (type==message_type::embedding && json_data.contains("embeddings")) simple_string=json_data["embeddings"].get<std::string>();
 else
 if (type==message_type::chat && json_data.contains("message")) simple_string=json_data["message"]["content"].get<std::string>();

@@ -715,15 +716,15 @@ class Ollama
 return false;
 }

-ollama::response generate_embeddings(const std::string& model, const std::string& prompt, const json& options=nullptr, const std::string& keep_alive_duration="5m")
+ollama::response generate_embeddings(const std::string& model, const std::string& input, const json& options=nullptr, bool truncate = true, const std::string& keep_alive_duration="5m")
 {
-ollama::request request = ollama::request::from_embedding(model, prompt, options, keep_alive_duration);
+ollama::request request = ollama::request::from_embedding(model, input, options, truncate, keep_alive_duration);
 ollama::response response;

 std::string request_string = request.dump();
 if (ollama::log_requests) std::cout << request_string << std::endl;

-if (auto res = cli->Post("/api/embeddings", request_string, "application/json"))
+if (auto res = cli->Post("/api/embed", request_string, "application/json"))
 {
 if (ollama::log_replies) std::cout << res->body << std::endl;

@@ -885,9 +886,9 @@ namespace ollama
 return ollama.push_model(model, allow_insecure);
 }

-inline ollama::response generate_embeddings(const std::string& model, const std::string& prompt, const json& options=nullptr, const std::string& keep_alive_duration="5m")
+inline ollama::response generate_embeddings(const std::string& model, const std::string& input, const json& options=nullptr, bool truncate = true, const std::string& keep_alive_duration="5m")
 {
-return ollama.generate_embeddings(model, prompt, options, keep_alive_duration);
+return ollama.generate_embeddings(model, input, options, truncate, keep_alive_duration);
 }

 inline void setReadTimeout(const int& seconds)
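Taken together, the include/ollama.hpp changes mean an embedding call now targets /api/embed and the reply carries an "embeddings" field. A rough end-to-end usage sketch follows, assuming a local Ollama server on the default port with the llama3:8b model available; it is illustrative only and not part of the diff.

#include "ollama.hpp"
#include <iostream>

int main() {
    // With the updated API the call takes (model, input, options, truncate,
    // keep_alive); here the defaults (no options, truncate=true, "5m") apply.
    ollama::response response =
        ollama::generate_embeddings("llama3:8b", "Why is the sky blue?");

    // The /api/embed reply carries an "embeddings" array rather than the old
    // singular "embedding" field.
    if (response.as_json().contains("embeddings"))
        std::cout << response.as_json()["embeddings"].size()
                  << " embedding vector(s) returned" << std::endl;
}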
19 changes: 10 additions & 9 deletions singleheader/ollama.hpp
@@ -35053,13 +35053,14 @@ namespace ollama
 request(): json() {}
 ~request(){};

-static ollama::request from_embedding(const std::string& name, const std::string& prompt, const json& options=nullptr, const std::string& keep_alive_duration="5m")
+static ollama::request from_embedding(const std::string& model, const std::string& input, const json& options=nullptr, bool truncate=true, const std::string& keep_alive_duration="5m")
 {
 ollama::request request(message_type::embedding);

-request["model"] = name;
-request["prompt"] = prompt;
+request["model"] = model;
+request["input"] = input;
 if (options!=nullptr) request["options"] = options["options"];
+request["truncate"] = truncate;
 request["keep_alive"] = keep_alive_duration;

 return request;
@@ -35085,7 +35086,7 @@ namespace ollama

 if (type==message_type::generation && json_data.contains("response")) simple_string=json_data["response"].get<std::string>();
 else
-if (type==message_type::embedding && json_data.contains("embedding")) simple_string=json_data["embedding"].get<std::string>();
+if (type==message_type::embedding && json_data.contains("embeddings")) simple_string=json_data["embeddings"].get<std::string>();
 else
 if (type==message_type::chat && json_data.contains("message")) simple_string=json_data["message"]["content"].get<std::string>();

@@ -35505,15 +35506,15 @@ class Ollama
 return false;
 }

-ollama::response generate_embeddings(const std::string& model, const std::string& prompt, const json& options=nullptr, const std::string& keep_alive_duration="5m")
+ollama::response generate_embeddings(const std::string& model, const std::string& input, const json& options=nullptr, bool truncate = true, const std::string& keep_alive_duration="5m")
 {
-ollama::request request = ollama::request::from_embedding(model, prompt, options, keep_alive_duration);
+ollama::request request = ollama::request::from_embedding(model, input, options, truncate, keep_alive_duration);
 ollama::response response;

 std::string request_string = request.dump();
 if (ollama::log_requests) std::cout << request_string << std::endl;

-if (auto res = cli->Post("/api/embeddings", request_string, "application/json"))
+if (auto res = cli->Post("/api/embed", request_string, "application/json"))
 {
 if (ollama::log_replies) std::cout << res->body << std::endl;

@@ -35675,9 +35676,9 @@ namespace ollama
 return ollama.push_model(model, allow_insecure);
 }

-inline ollama::response generate_embeddings(const std::string& model, const std::string& prompt, const json& options=nullptr, const std::string& keep_alive_duration="5m")
+inline ollama::response generate_embeddings(const std::string& model, const std::string& input, const json& options=nullptr, bool truncate = true, const std::string& keep_alive_duration="5m")
 {
-return ollama.generate_embeddings(model, prompt, options, keep_alive_duration);
+return ollama.generate_embeddings(model, input, options, truncate, keep_alive_duration);
 }

 inline void setReadTimeout(const int& seconds)
58 changes: 26 additions & 32 deletions test/test.cpp
@@ -12,6 +12,8 @@
 // Note that this is static. We will use these options for other generations.
 static ollama::options options;

+static std::string test_model = "llama3:8b", image_test_model = "llava";
+
 TEST_SUITE("Ollama Tests") {

 TEST_CASE("Initialize Options") {
@@ -52,19 +54,19 @@ TEST_SUITE("Ollama Tests") {

 TEST_CASE("Load Model") {

-CHECK( ollama::load_model("llama3:8b") );
+CHECK( ollama::load_model(test_model) );
 }

 TEST_CASE("Pull, Copy, and Delete Models") {

 // Pull a model by specifying a model name.
-CHECK( ollama::pull_model("llama3:8b") == true );
+CHECK( ollama::pull_model(test_model) == true );

 // Copy a model by specifying a source model and destination model name.
-CHECK( ollama::copy_model("llama3:8b", "llama3_copy") ==true );
+CHECK( ollama::copy_model(test_model, test_model+"_copy") ==true );

 // Delete a model by specifying a model name.
-CHECK( ollama::delete_model("llama3_copy") == true );
+CHECK( ollama::delete_model(test_model+"_copy") == true );
 }

 TEST_CASE("Model Info") {
@@ -81,7 +83,7 @@ TEST_SUITE("Ollama Tests") {
 // List the models available locally in the ollama server
 std::vector<std::string> models = ollama::list_models();

-bool contains_model = (std::find(models.begin(), models.end(), "llama3:8b") != models.end() );
+bool contains_model = (std::find(models.begin(), models.end(), test_model) != models.end() );

 CHECK( contains_model );
 }
@@ -101,12 +103,9 @@ TEST_SUITE("Ollama Tests") {

 TEST_CASE("Basic Generation") {

-ollama::response response = ollama::generate("llama3:8b", "Why is the sky blue?", options);
-//std::cout << response << std::endl;
-
-std::string expected_response = "What a great question!\n\nThe sky appears blue because of a phenomenon called Rayleigh scattering,";
+ollama::response response = ollama::generate(test_model, "Why is the sky blue?", options);

-CHECK(response.as_simple_string() == expected_response);
+CHECK( response.as_json().contains("response") == true );
 }


@@ -124,35 +123,34 @@
 TEST_CASE("Streaming Generation") {

 std::function<void(const ollama::response&)> response_callback = on_receive_response;
-ollama::generate("llama3:8b", "Why is the sky blue?", response_callback, options);
+ollama::generate(test_model, "Why is the sky blue?", response_callback, options);

 std::string expected_response = "What a great question!\n\nThe sky appears blue because of a phenomenon called Rayleigh scattering,";

-CHECK( streamed_response == expected_response );
+CHECK( streamed_response != "" );
 }

 TEST_CASE("Non-Singleton Generation") {

 Ollama my_ollama_server("http://localhost:11434");

 // You can use all of the same functions from this instanced version of the class.
-ollama::response response = my_ollama_server.generate("llama3:8b", "Why is the sky blue?", options);
-//std::cout << response << std::endl;
+ollama::response response = my_ollama_server.generate(test_model, "Why is the sky blue?", options);

 std::string expected_response = "What a great question!\n\nThe sky appears blue because of a phenomenon called Rayleigh scattering,";

-CHECK(response.as_simple_string() == expected_response);
+CHECK(response.as_json().contains("response") == true);
 }

 TEST_CASE("Single-Message Chat") {

 ollama::message message("user", "Why is the sky blue?");

-ollama::response response = ollama::chat("llama3:8b", message, options);
+ollama::response response = ollama::chat(test_model, message, options);

 std::string expected_response = "What a great question!\n\nThe sky appears blue because of a phenomenon called Rayleigh scattering,";

-CHECK(response.as_simple_string()!="");
+CHECK(response.as_json().contains("message") == true);
 }

 TEST_CASE("Multi-Message Chat") {
@@ -163,11 +161,11 @@ TEST_SUITE("Ollama Tests") {

 ollama::messages messages = {message1, message2, message3};

-ollama::response response = ollama::chat("llama3:8b", messages, options);
+ollama::response response = ollama::chat(test_model, messages, options);

 std::string expected_response = "";

-CHECK(response.as_simple_string()!="");
+CHECK(response.as_json().contains("message") == true);
 }

 TEST_CASE("Chat with Streaming Response") {
@@ -182,7 +180,7 @@ TEST_SUITE("Ollama Tests") {

 ollama::message message("user", "Why is the sky blue?");

-ollama::chat("llama3:8b", message, response_callback, options);
+ollama::chat(test_model, message, response_callback, options);

 CHECK(streamed_response!="");
 }
@@ -195,12 +193,9 @@ TEST_SUITE("Ollama Tests") {

 ollama::image image = ollama::image::from_file("llama.jpg");

-//ollama::images images={image};
-
-ollama::response response = ollama::generate("llava", "What do you see in this image?", options, image);
-std::string expected_response = " The image features a large, fluffy white llama";
+ollama::response response = ollama::generate(image_test_model, "What do you see in this image?", options, image);

-CHECK(response.as_simple_string() == expected_response);
+CHECK( response.as_json().contains("response") == true );
 }

 TEST_CASE("Generation with Multiple Images") {
@@ -214,10 +209,10 @@ TEST_SUITE("Ollama Tests") {

 ollama::images images={image, base64_image};

-ollama::response response = ollama::generate("llava", "What do you see in this image?", options, images);
+ollama::response response = ollama::generate(image_test_model, "What do you see in this image?", options, images);
 std::string expected_response = " The image features a large, fluffy white and gray llama";

-CHECK(response.as_simple_string() == expected_response);
+CHECK(response.as_json().contains("response") == true);
 }

 TEST_CASE("Chat with Image") {
@@ -230,21 +225,20 @@ TEST_SUITE("Ollama Tests") {

 // We can optionally include images with each message. Vision-enabled models will be able to utilize these.
 ollama::message message_with_image("user", "What do you see in this image?", image);
-ollama::response response = ollama::chat("llava", message_with_image, options);
+ollama::response response = ollama::chat(image_test_model, message_with_image, options);

 std::string expected_response = " The image features a large, fluffy white llama";

-CHECK(response.as_simple_string()!="");
+CHECK(response.as_json().contains("message") == true);
 }

 TEST_CASE("Embedding Generation") {

 options["num_predict"] = 18;

-ollama::response response = ollama::generate_embeddings("llama3:8b", "Why is the sky blue?");
-//std::cout << response << std::endl;
+ollama::response response = ollama::generate_embeddings(test_model, "Why is the sky blue?");

-CHECK(response.as_json().contains("embedding") == true);
+CHECK(response.as_json().contains("embeddings") == true);
 }

 TEST_CASE("Enable Debug Logging") {
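The test changes replace exact string comparisons with structural checks on the response JSON, so the suite no longer breaks when model output wording shifts. Below is a hypothetical additional test in the same style, reusing the test_model variable and the TEST_CASE/CHECK macros from the file above; it would sit inside the existing TEST_SUITE and is a sketch rather than part of this commit.

// Hypothetical extra test: assert on the shape of the embedding response
// instead of its exact contents.
TEST_CASE("Embedding Response Shape") {

    ollama::response response = ollama::generate_embeddings(test_model, "Why is the sky blue?");

    // Non-exact checks: the field exists, is an array, and is not empty.
    CHECK(response.as_json().contains("embeddings") == true);
    CHECK(response.as_json()["embeddings"].is_array() == true);
    CHECK(response.as_json()["embeddings"].empty() == false);
}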
