Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion include/ai/openai.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ constexpr const char* kChatGpt4oLatest = "chatgpt-4o-latest";

/// Default model used when none is specified
constexpr const char* kDefaultModel = kGpt4o;

} // namespace models

/// Create an OpenAI client with default configuration
Expand Down Expand Up @@ -82,4 +83,4 @@ Client create_client(const std::string& api_key,
std::optional<Client> try_create_client();

} // namespace openai
} // namespace ai
} // namespace ai
7 changes: 7 additions & 0 deletions include/ai/types/client.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include "generate_options.h"
#include "stream_options.h"
#include "stream_result.h"
#include "embeddding_options.h"

#include <memory>
#include <string>
Expand Down Expand Up @@ -31,6 +32,12 @@ class Client {
return GenerateResult("Client not initialized");
}

/// Request embeddings via the underlying provider implementation.
/// Returns an error result when the client was never initialized.
virtual EmbeddingResult embeddings(const EmbeddingOptions& options) {
  // Guard clause: without a backing implementation there is nothing to call.
  if (!pimpl_)
    return EmbeddingResult("Client not initialized");
  return pimpl_->embeddings(options);
}

virtual StreamResult stream_text(const StreamOptions& options) {
if (pimpl_)
return pimpl_->stream_text(options);
Expand Down
84 changes: 84 additions & 0 deletions include/ai/types/embeddding_options.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
#pragma once

#include "enums.h"
#include "message.h"
#include "model.h"
#include "tool.h"
#include "usage.h"

#include <functional>
#include <optional>
#include <string>
#include <vector>

namespace ai {

/// Options for an embeddings request: the model name, the input to embed,
/// and optional provider knobs.
struct EmbeddingOptions {
  std::string model;
  nlohmann::json input;  // string or array of strings, forwarded verbatim
  std::optional<int> dimensions;
  std::optional<std::string> encoding_format;  // e.g. "float" or "base64"
  // NOTE(review): the fields below are completion-style parameters; no code
  // visible in this change reads them for embeddings — confirm they belong here.
  std::optional<int> max_tokens;
  std::optional<double> temperature;
  std::optional<double> top_p;
  std::optional<int> seed;
  std::optional<double> frequency_penalty;
  std::optional<double> presence_penalty;

  /// Single constructor replacing the previous three overloads; all existing
  /// call shapes — (model, input), (model, input, dimensions) and
  /// (model, input, dimensions, encoding_format) — still compile unchanged.
  EmbeddingOptions(std::string model_name, nlohmann::json input_,
                   std::optional<int> dimensions_ = std::nullopt,
                   std::optional<std::string> encoding_format_ = std::nullopt)
      : model(std::move(model_name)),
        input(std::move(input_)),
        dimensions(dimensions_),
        encoding_format(std::move(encoding_format_)) {}

  EmbeddingOptions() = default;

  /// True when both a model name and a non-empty input are present.
  bool is_valid() const { return !model.empty() && !input.empty(); }

  /// True when any input was supplied.
  bool has_input() const { return !input.empty(); }
};

/// Result of an embeddings request: the provider's "data" payload plus
/// token usage, optional metadata, and error state.
struct EmbeddingResult {
  nlohmann::json data;  // provider "data" array (one embedding per input)
  Usage usage;

  /// Additional metadata (like TypeScript SDK)
  std::optional<std::string> model;

  /// Error handling: a null `error` means success.
  std::optional<std::string> error;
  std::optional<bool> is_retryable;

  /// Provider-specific metadata (full raw response as a JSON string)
  std::optional<std::string> provider_metadata;

  EmbeddingResult() = default;

  /// Construct an error result. (The dead commented-out data/usage
  /// constructor was removed; populate fields directly instead.)
  explicit EmbeddingResult(std::optional<std::string> error_message)
      : error(std::move(error_message)) {}

  /// True when no error was recorded.
  bool is_success() const { return !error.has_value(); }

  explicit operator bool() const { return is_success(); }

  /// Error text, or the empty string on success.
  std::string error_message() const { return error.value_or(""); }
};

} // namespace ai
10 changes: 10 additions & 0 deletions include/ai/types/generate_options.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ struct GenerateOptions {
std::string system;
std::string prompt;
Messages messages;
std::optional<nlohmann::json> response_format {};
std::optional<int> max_tokens;
std::optional<double> temperature;
std::optional<double> top_p;
Expand Down Expand Up @@ -46,6 +47,15 @@ struct GenerateOptions {
system(std::move(system_prompt)),
prompt(std::move(user_prompt)) {}

GenerateOptions(std::string model_name,
std::string system_prompt,
std::string user_prompt,
std::optional<nlohmann::json> response_format_)
: model(std::move(model_name)),
system(std::move(system_prompt)),
prompt(std::move(user_prompt)),
response_format(std::move(response_format_)) {}

GenerateOptions(std::string model_name, Messages conversation)
: model(std::move(model_name)), messages(std::move(conversation)) {}

Expand Down
9 changes: 5 additions & 4 deletions src/providers/anthropic/anthropic_client.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ AnthropicClient::AnthropicClient(const std::string& api_key,
providers::ProviderConfig{
.api_key = api_key,
.base_url = base_url,
.endpoint_path = "/v1/messages",
.completions_endpoint_path = "/v1/messages",
.embeddings_endpoint_path = "/v1/embeddings",
.auth_header_name = "x-api-key",
.auth_header_prefix = "",
.extra_headers = {{"anthropic-version", "2023-06-01"}}},
Expand All @@ -44,8 +45,8 @@ StreamResult AnthropicClient::stream_text(const StreamOptions& options) {

// Create stream implementation
auto impl = std::make_unique<AnthropicStreamImpl>();
impl->start_stream(config_.base_url + config_.endpoint_path, headers,
request_json);
impl->start_stream(config_.base_url + config_.completions_endpoint_path,
headers, request_json);

ai::logger::log_info("Text streaming started - model: {}", options.model);

Expand Down Expand Up @@ -77,4 +78,4 @@ std::string AnthropicClient::default_model() const {
}

} // namespace anthropic
} // namespace ai
} // namespace ai
2 changes: 1 addition & 1 deletion src/providers/anthropic/anthropic_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,4 +29,4 @@ class AnthropicClient : public providers::BaseProviderClient {
};

} // namespace anthropic
} // namespace ai
} // namespace ai
8 changes: 7 additions & 1 deletion src/providers/anthropic/anthropic_request_builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,12 @@ nlohmann::json AnthropicRequestBuilder::build_request_json(
return request;
}

/// Build the JSON body for an embeddings request.
/// Fix: the optional fields carried by EmbeddingOptions (dimensions,
/// encoding_format) were silently dropped; include them when set.
/// Behavior is unchanged for callers that never set them.
nlohmann::json AnthropicRequestBuilder::build_request_json(const EmbeddingOptions& options) {
  nlohmann::json request{{"model", options.model},
                         {"input", options.input}};
  if (options.dimensions.has_value()) {
    request["dimensions"] = *options.dimensions;
  }
  if (options.encoding_format.has_value()) {
    request["encoding_format"] = *options.encoding_format;
  }
  return request;
}

httplib::Headers AnthropicRequestBuilder::build_headers(
const providers::ProviderConfig& config) {
httplib::Headers headers = {
Expand All @@ -172,4 +178,4 @@ httplib::Headers AnthropicRequestBuilder::build_headers(
}

} // namespace anthropic
} // namespace ai
} // namespace ai
1 change: 1 addition & 0 deletions src/providers/anthropic/anthropic_request_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ namespace anthropic {
/// Builds Anthropic-specific HTTP request payloads and headers.
class AnthropicRequestBuilder : public providers::RequestBuilder {
public:
/// Build the JSON body for a text-generation (messages) request.
nlohmann::json build_request_json(const GenerateOptions& options) override;
/// Build the JSON body for an embeddings request.
nlohmann::json build_request_json(const EmbeddingOptions& options) override;
/// Build the HTTP headers (API key, version header, extras) from the config.
httplib::Headers build_headers(
const providers::ProviderConfig& config) override;
};
Expand Down
41 changes: 39 additions & 2 deletions src/providers/anthropic/anthropic_response_parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
namespace ai {
namespace anthropic {

GenerateResult AnthropicResponseParser::parse_success_response(
GenerateResult AnthropicResponseParser::parse_success_completion_response(
const nlohmann::json& response) {
ai::logger::log_debug("Parsing Anthropic messages response");

Expand Down Expand Up @@ -86,12 +86,49 @@ GenerateResult AnthropicResponseParser::parse_success_response(
return result;
}

GenerateResult AnthropicResponseParser::parse_error_response(
GenerateResult AnthropicResponseParser::parse_error_completion_response(
int status_code,
const std::string& body) {
return utils::parse_standard_error_response("Anthropic", status_code, body);
}

/// Parse a successful embeddings response body into an EmbeddingResult.
EmbeddingResult AnthropicResponseParser::parse_success_embedding_response(const nlohmann::json& response) {
  ai::logger::log_debug("Parsing Anthropic embeddings response");

  EmbeddingResult result;

  // Extract basic fields (empty string when the provider omits "model")
  result.model = response.value("model", "");

  // Extract the embeddings payload ("data" array).
  // Fix: `response` is const, so the previous std::move was a silent copy
  // (clang-tidy performance-move-const-arg); plain copy-assign states the
  // actual behavior. The stale "Extract choices" comment was copy-pasted
  // from the completions parser.
  if (response.contains("data") && !response["data"].empty()) {
    result.data = response["data"];
  }

  // Extract usage (fields default to 0 when absent)
  if (response.contains("usage")) {
    const auto& usage = response["usage"];
    result.usage.prompt_tokens = usage.value("prompt_tokens", 0);
    result.usage.completion_tokens = usage.value("completion_tokens", 0);
    result.usage.total_tokens = usage.value("total_tokens", 0);
    ai::logger::log_debug("Token usage - prompt: {}, completion: {}, total: {}",
                          result.usage.prompt_tokens,
                          result.usage.completion_tokens,
                          result.usage.total_tokens);
  }

  // Store the full raw response for provider-specific inspection
  result.provider_metadata = response.dump();

  return result;
}

/// Parse an error response body into an EmbeddingResult.
/// Delegates to the shared error parser and carries its error text over.
EmbeddingResult AnthropicResponseParser::parse_error_embedding_response(int status_code, const std::string& body) {
  auto parsed = utils::parse_standard_error_response("Anthropic", status_code, body);
  return EmbeddingResult(std::move(parsed.error));
}


FinishReason AnthropicResponseParser::parse_stop_reason(
const std::string& reason) {
if (reason == "end_turn") {
Expand Down
8 changes: 6 additions & 2 deletions src/providers/anthropic/anthropic_response_parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,13 @@ namespace anthropic {

class AnthropicResponseParser : public providers::ResponseParser {
public:
GenerateResult parse_success_response(
GenerateResult parse_success_completion_response(
const nlohmann::json& response) override;
GenerateResult parse_error_response(int status_code,
GenerateResult parse_error_completion_response(int status_code,
const std::string& body) override;
EmbeddingResult parse_success_embedding_response(
const nlohmann::json& response) override;
EmbeddingResult parse_error_embedding_response(int status_code,
const std::string& body) override;

private:
Expand Down
62 changes: 57 additions & 5 deletions src/providers/base_provider_client.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,10 @@ BaseProviderClient::BaseProviderClient(
http_handler_ = std::make_unique<http::HttpRequestHandler>(http_config);

ai::logger::log_debug(
"BaseProviderClient initialized - base_url: {}, endpoint: {}",
config.base_url, config.endpoint_path);
R"(BaseProviderClient initialized - base_url: {},
completions_endpoint: {}, embeddings_endpoint: {})",
config.base_url, config.completions_endpoint_path,
config.embeddings_endpoint_path);
}

GenerateResult BaseProviderClient::generate_text(
Expand Down Expand Up @@ -65,13 +67,13 @@ GenerateResult BaseProviderClient::generate_text_single_step(

// Make the request
auto result =
http_handler_->post(config_.endpoint_path, headers, json_body);
http_handler_->post(config_.completions_endpoint_path, headers, json_body);

if (!result.is_success()) {
// Parse error response using provider-specific parser
if (result.provider_metadata.has_value()) {
int status_code = std::stoi(result.provider_metadata.value());
return response_parser_->parse_error_response(
return response_parser_->parse_error_completion_response(
status_code, result.error.value_or(""));
}
return result;
Expand All @@ -94,7 +96,7 @@ GenerateResult BaseProviderClient::generate_text_single_step(

// Parse using provider-specific parser
auto parsed_result =
response_parser_->parse_success_response(json_response);
response_parser_->parse_success_completion_response(json_response);

if (parsed_result.has_tool_calls()) {
ai::logger::log_debug("Model made {} tool calls",
Expand Down Expand Up @@ -144,5 +146,55 @@ StreamResult BaseProviderClient::stream_text(const StreamOptions& options) {
return StreamResult();
}

/// Execute an embeddings request end-to-end: build the provider-specific
/// JSON body and headers, POST to the embeddings endpoint, and parse the
/// response with the provider-specific parser.
/// All failures are reported through EmbeddingResult::error; this function
/// does not throw.
EmbeddingResult BaseProviderClient::embeddings(const EmbeddingOptions& options) {
  try {
    // Build request JSON using the provider-specific builder
    auto request_json = request_builder_->build_request_json(options);
    std::string json_body = request_json.dump();
    ai::logger::log_debug("Request JSON built: {}", json_body);

    // Build headers
    auto headers = request_builder_->build_headers(config_);

    // Make the request (typo fixed: was "requests")
    auto result =
        http_handler_->post(config_.embeddings_endpoint_path, headers, json_body);

    if (!result.is_success()) {
      // provider_metadata carries the HTTP status code as a string
      // (same convention as generate_text_single_step)
      if (result.provider_metadata.has_value()) {
        int status_code = std::stoi(result.provider_metadata.value());
        return response_parser_->parse_error_embedding_response(
            status_code, result.error.value_or(""));
      }
      return EmbeddingResult(result.error);
    }

    // Parse the response JSON from result.text
    nlohmann::json json_response;
    try {
      json_response = nlohmann::json::parse(result.text);
    } catch (const nlohmann::json::exception& e) {
      ai::logger::log_error("Failed to parse response JSON: {}", e.what());
      ai::logger::log_debug("Raw response text: {}", result.text);
      return EmbeddingResult("Failed to parse response: " +
                             std::string(e.what()));
    }

    ai::logger::log_info(
        "Embeddings successful - model: {}, response_id: {}",
        options.model, json_response.value("id", "unknown"));

    // Parse using provider-specific parser
    return response_parser_->parse_success_embedding_response(json_response);

  } catch (const std::exception& e) {
    // Catch-all (including std::stoi failures on malformed metadata) so the
    // caller always receives a result rather than an exception.
    ai::logger::log_error("Exception during embeddings: {}", e.what());
    return EmbeddingResult(std::string("Exception: ") + e.what());
  }
}

} // namespace providers
} // namespace ai
Loading
Loading