Skip to content

Commit

Permalink
feat: add support for embeddings and reranking models
Browse files Browse the repository at this point in the history
  • Loading branch information
bernard-ng committed Nov 21, 2024
1 parent 1cf0ba5 commit 01e2fb4
Show file tree
Hide file tree
Showing 32 changed files with 561 additions and 115 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

declare(strict_types=1);

namespace Devscast\Lugha\Model\Chat\Completion;
namespace Devscast\Lugha\Model\Completion\Chat;

/**
* Class History.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

declare(strict_types=1);

namespace Devscast\Lugha\Model\Chat\Completion;
namespace Devscast\Lugha\Model\Completion\Chat;

/**
* Class Message.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

declare(strict_types=1);

namespace Devscast\Lugha\Model\Chat\Completion;
namespace Devscast\Lugha\Model\Completion\Chat;

/**
* Class Role.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@

declare(strict_types=1);

namespace Devscast\Lugha\Model\Chat;
namespace Devscast\Lugha\Model\Completion;

use Devscast\Lugha\Model\Chat\Completion\History;
use Devscast\Lugha\Model\Chat\Completion\Message;
use Devscast\Lugha\Model\Chat\Prompt\PromptTemplate;
use Devscast\Lugha\Model\Completion\Chat\History;
use Devscast\Lugha\Model\Completion\Chat\Message;
use Devscast\Lugha\Model\Completion\Prompt\PromptTemplate;

/**
* Interface ChatInterface.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,23 +11,25 @@

declare(strict_types=1);

namespace Devscast\Lugha\Model\Chat;
namespace Devscast\Lugha\Model\Completion;

use Webmozart\Assert\Assert;

/**
* Class ChatConfig.
* Class CompletionConfig.
*
* @see https://platform.openai.com/docs/api-reference/chat/object
* @see https://ai.google.dev/gemini-api/docs/text-generation?lang=rest#configure
* @see https://docs.mistral.ai/api/#tag/chat
* @see https://docs.anthropic.com/en/api/messages
* @see https://github.com/ollama/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values
*
* @author bernard-ng <bernard@devscast.tech>
*/
final readonly class ChatConfig
final readonly class CompletionConfig
{
/**
* @param string $model The model to use for generating the text.
* @param float|null $temperature The value used to control the randomness of the generated text.
* @param int|null $maxTokens The maximum number of tokens to generate.
* @param float|null $topP The cumulative probability of the top tokens to keep.
Expand All @@ -37,6 +39,7 @@
* @param array|null $stopSequences A list of sequences where the model should stop generating the text.
*/
public function __construct(
public string $model,
public ?float $temperature = null,
public ?int $maxTokens = null,
public ?float $topP = null,
Expand All @@ -45,6 +48,7 @@ public function __construct(
public ?float $presencePenalty = null,
public ?array $stopSequences = null,
) {
Assert::notEmpty($this->model);
Assert::nullOrRange($this->temperature, 0, 2);
Assert::nullOrPositiveInteger($this->maxTokens);
Assert::nullOrRange($this->topP, 0, 1);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

declare(strict_types=1);

namespace Devscast\Lugha\Model\Chat\Parser;
namespace Devscast\Lugha\Model\Completion\Parser;

use League\CommonMark\Environment\Environment;
use League\CommonMark\Exception\CommonMarkException;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

declare(strict_types=1);

namespace Devscast\Lugha\Model\Chat\Parser;
namespace Devscast\Lugha\Model\Completion\Parser;

/**
* Interface OutputParserInterface.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

declare(strict_types=1);

namespace Devscast\Lugha\Model\Chat\Prompt;
namespace Devscast\Lugha\Model\Completion\Prompt;

use Webmozart\Assert\Assert;

Expand Down
3 changes: 3 additions & 0 deletions src/Model/Embedding/EmbeddingConfig.php
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,16 @@
final readonly class EmbeddingConfig
{
/**
* @param string $model The model to use for generating the embeddings.
* @param int|null $dimensions The dimensionality of the embeddings to generate.
* @param string $encodingFormat The encoding format to use for the embeddings.
*/
public function __construct(
public string $model,
public ?int $dimensions = null,
public string $encodingFormat = 'float',
) {
Assert::notEmpty($this->model);
Assert::nullOrPositiveInteger($this->dimensions);
Assert::oneOf($this->encodingFormat, ['float', 'base64']);
}
Expand Down
19 changes: 19 additions & 0 deletions src/Model/Reranking/RankedDocument.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
<?php

declare(strict_types=1);

namespace Devscast\Lugha\Model\Reranking;

/**
* Class RankedDocument.
*
* @author bernard-ng <bernard@devscast.tech>
*/
final class RankedDocument
{
public function __construct(
public string $content,
public float $score
) {
}
}
36 changes: 36 additions & 0 deletions src/Model/Reranking/RerankingConfig.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
<?php

/*
* This file is part of the Lugha package.
*
* (c) Bernard Ngandu <bernard@devscast.tech>
*
* For the full copyright and license information, please view the LICENSE
* file that was distributed with this source code.
*/

declare(strict_types=1);

namespace Devscast\Lugha\Model\Reranking;

use Webmozart\Assert\Assert;

/**
* Class RerankingConfig.
*
* @author bernard-ng <bernard@devscast.tech>
*/
final readonly class RerankingConfig
{
/**
* @param string $model The model to use for generating the embeddings.
* @param int $topK The number of top results to return.
*/
public function __construct(
public string $model,
public int $topK
) {
Assert::notEmpty($this->model);
Assert::positiveInteger($this->topK);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

declare(strict_types=1);

namespace Devscast\Lugha\Provider\Config;
namespace Devscast\Lugha\Provider;

use Webmozart\Assert\Assert;

Expand All @@ -31,6 +31,7 @@
* @param bool $providerResponse Whether to return the full provider's response along with the result.
*/
public function __construct(
#[\SensitiveParameter]
public string $apiKey,
public ?string $baseUri = null,
public ?int $maxRetries = 2,
Expand Down
20 changes: 20 additions & 0 deletions src/Provider/Response/CompletionResponse.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
<?php

declare(strict_types=1);

namespace Devscast\Lugha\Provider\Response;

/**
* Class CompletionResponse.
*
* @author bernard-ng <bernard@devscast.tech>
*/
final readonly class CompletionResponse
{
public function __construct(
public string $model,
public string $completion,
public array $providerResponse = []
) {
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,19 @@

declare(strict_types=1);

namespace Devscast\Lugha\Provider\Service\Ollama;

use Devscast\Lugha\Provider\Service\AbstractClient;
namespace Devscast\Lugha\Provider\Response;

/**
* Class Client.
* Class EmbeddingResponse.
*
* @author bernard-ng <bernard@devscast.tech>
*/
final class Client extends AbstractClient
final readonly class EmbeddingResponse
{
protected const string BASE_URI = 'http://localhost:11434/api/';
public function __construct(
public string $model,
public array $embedding,
public array $providerResponse = [],
) {
}
}
25 changes: 25 additions & 0 deletions src/Provider/Response/RerankingResponse.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
<?php

declare(strict_types=1);

namespace Devscast\Lugha\Provider\Response;

use Devscast\Lugha\Model\Reranking\RankedDocument;

/**
* Class RerankingResponse.
*
* @author bernard-ng <bernard@devscast.tech>
*/
final readonly class RerankingResponse
{
/**
* @param array<RankedDocument> $documents
*/
public function __construct(
public string $model,
public array $documents,
public array $providerResponse = []
) {
}
}
2 changes: 1 addition & 1 deletion src/Provider/Service/AbstractClient.php
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

namespace Devscast\Lugha\Provider\Service;

use Devscast\Lugha\Provider\Config\ProviderConfig;
use Devscast\Lugha\Provider\ProviderConfig;
use Symfony\Component\HttpClient\HttpClient;
use Symfony\Component\HttpClient\Retry\GenericRetryStrategy;
use Symfony\Component\HttpClient\RetryableHttpClient;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,18 @@

declare(strict_types=1);

namespace Devscast\Lugha\Provider\Service\Anthropic;
namespace Devscast\Lugha\Provider\Service\Client;

use Devscast\Lugha\Provider\Service\AbstractClient;

/**
* Class Client.
* Class OllamaClient.
*
* @see https://docs.anthropic.com/en/api/getting-started
*
* @author bernard-ng <bernard@devscast.tech>
*/
final class Client extends AbstractClient
final class AnthropicClient extends AbstractClient
{
protected const string BASE_URI = 'https://api.anthropic.com/v1/';
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,18 @@

declare(strict_types=1);

namespace Devscast\Lugha\Provider\Service\Github;
namespace Devscast\Lugha\Provider\Service\Client;

use Devscast\Lugha\Provider\Service\AbstractClient;

/**
* Class Client.
* Class OllamaClient.
*
* @see https://github.com/marketplace/models/
*
* @author bernard-ng <bernard@devscast.tech>
*/
final class Client extends AbstractClient
final class GithubClient extends AbstractClient
{
protected const string BASE_URI = 'https://models.inference.ai.azure.com/';
}
63 changes: 63 additions & 0 deletions src/Provider/Service/Client/GoogleClient.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
<?php

/*
* This file is part of the Lugha package.
*
* (c) Bernard Ngandu <bernard@devscast.tech>
*
* For the full copyright and license information, please view the LICENSE
* file that was distributed with this source code.
*/

declare(strict_types=1);

namespace Devscast\Lugha\Provider\Service\Client;

use Devscast\Lugha\Model\Embedding\EmbeddingConfig;
use Devscast\Lugha\Provider\Response\EmbeddingResponse;
use Devscast\Lugha\Provider\Service\AbstractClient;
use Devscast\Lugha\Provider\Service\HasEmbeddingSupport;
use Devscast\Lugha\Provider\Service\IntegrationException;
use Webmozart\Assert\Assert;

/**
* Class OllamaClient.
*
* @see https://ai.google.dev/api
* @see https://ai.google.dev/gemini-api/docs/embeddings#curl
*
* @author bernard-ng <bernard@devscast.tech>
*/
final class GoogleClient extends AbstractClient implements HasEmbeddingSupport
{
protected const string BASE_URI = 'https://generativelanguage.googleapis.com/v1beta/';

#[\Override]
public function embeddings(string $prompt, EmbeddingConfig $config): EmbeddingResponse
{
Assert::notEmpty($prompt);

try {
$response = $this->http->request('POST', "models/{$config->model}:embedContent?key={$this->config->apiKey}", [
'json' => [
'model' => "models/{$config->model}",
'content' => [
'parts' => [
[
'text' => $prompt,
],
],
],
],
])->toArray();

return new EmbeddingResponse(
model: $config->model,
embedding: $response['embedding']['values'],
providerResponse: [] // no special information to pass
);
} catch (\Throwable $e) {
throw new IntegrationException('Unable to generate embeddings.', previous: $e);
}
}
}
Loading

0 comments on commit 01e2fb4

Please sign in to comment.