diff --git a/drush.services.yml b/drush.services.yml index 93be7ad..21e70be 100644 --- a/drush.services.yml +++ b/drush.services.yml @@ -19,3 +19,10 @@ services: - '@plugin.manager.llm_services' tags: - { name: console.command } + + llm.service.chat.command: + class: Drupal\llm_services\Commands\ModelChatCommand + arguments: + - '@plugin.manager.llm_services' + tags: + - { name: console.command } diff --git a/src/Client/Ollama.php b/src/Client/Ollama.php index caa2d58..504ad22 100644 --- a/src/Client/Ollama.php +++ b/src/Client/Ollama.php @@ -28,7 +28,7 @@ class Ollama { * @param string $url * The URL of the Ollama server. * @param int $port - * The port that Ollama is listening on. + * The port that Ollama is listening at. */ public function __construct( private readonly string $url, @@ -42,15 +42,13 @@ public function __construct( * @return array> * Basic information about the models. * - * @throws \GuzzleHttp\Exception\GuzzleException - * @throws \JsonException + * @throws \Drupal\llm_services\Exceptions\CommunicationException */ public function listLocalModels(): array { $response = $this->call(method: 'get', uri: '/api/tags'); $data = $response->getBody()->getContents(); $data = json_decode($data, TRUE); - // @todo Change to value objects. $models = []; foreach ($data['models'] as $item) { $models[$item['model']] = [ @@ -135,6 +133,66 @@ public function completion(Payload $payload): \Generator { } } + /** + * Chat with a model. + * + * @param \Drupal\llm_services\Model\Payload $payload + * The question to ask the model and the chat history. + * + * @return \Generator + * The response from the model as it is generated. + * + * @throws \Drupal\llm_services\Exceptions\CommunicationException + * @throws \JsonException + */ + public function chat(Payload $payload): \Generator { + $response = $this->call(method: 'post', uri: '/api/chat', options: [ + 'json' => [ + 'model' => $payload->model, + 'messages' => $this->chatMessagesAsArray($payload), + 'stream' => TRUE, + ], + 'headers' => [ + 'Content-Type' => 'application/json', + ], + RequestOptions::CONNECT_TIMEOUT => 10, + RequestOptions::TIMEOUT => 300, + RequestOptions::STREAM => TRUE, + ]); + + $body = $response->getBody(); + while (!$body->eof()) { + $data = $body->read(1024); + yield from $this->parse($data); + } + } + + /** + * Convert all payload messages into an array. + * + * This array of messages is used to give the model some chat context to make + * the interaction appear more like a real chat with a person. + * + * @param \Drupal\llm_services\Model\Payload $payload + * The payload sent to the chat function. + * + * @return array + * Array of messages to send to Ollama. + * + * @see https://github.com/ollama/ollama/blob/main/docs/api.md#chat-request-with-history + */ + private function chatMessagesAsArray(Payload $payload): array { + $messages = []; + foreach ($payload->messages as $message) { + $messages[] = [ + 'content' => $message->content, + 'role' => $message->role->value, + ]; + } + + return $messages; + } + /** * Parse LLM stream.
* diff --git a/src/Client/OllamaChatResponse.php b/src/Client/OllamaChatResponse.php new file mode 100644 index 0000000..d4a688d --- /dev/null +++ b/src/Client/OllamaChatResponse.php @@ -0,0 +1,71 @@ +model; + } + + /** + * {@inheritdoc} + */ + public function getStatus(): bool { + return $this->done; + } + + /** + * {@inheritdoc} + */ + public function getContent(): string { + return $this->content; + } + + /** + * {@inheritdoc} + */ + public function getRole(): MessageRoles { + return $this->role; + } + + /** + * {@inheritdoc} + */ + public function getImages(): array { + return $this->images; + } + +} diff --git a/src/Commands/ModelChatCommand.php b/src/Commands/ModelChatCommand.php new file mode 100644 index 0000000..90a7528 --- /dev/null +++ b/src/Commands/ModelChatCommand.php @@ -0,0 +1,134 @@ +setName('llm:model:chat') + ->setDescription('Chat with a model (use ctrl+c to stop chatting)') + ->addUsage('llm:model:chat ollama llama3') + ->addArgument( + name: 'provider', + mode: InputArgument::REQUIRED, + description: 'Name of the provider (plugin).' + ) + ->addArgument( + name: 'name', + mode: InputArgument::REQUIRED, + description: 'Name of the model to use.' + ) + ->addOption( + name: 'system-prompt', + mode: InputOption::VALUE_REQUIRED, + description: 'System message to instruct the LLM how to behave.', + default: 'Use the following pieces of context to answer the user\'s question. If you don\'t know the answer, just say that you don\'t know, don\'t try to make up an answer.' + ) + ->addOption( + name: 'temperature', + mode: InputOption::VALUE_REQUIRED, + description: 'The temperature of the model. Increasing the temperature will make the model answer more creatively.', + default: '0.8' + ) + ->addOption( + name: 'top-k', + mode: InputOption::VALUE_REQUIRED, + description: 'Reduces the probability of generating nonsense. A higher value (e.g. 100) will give more diverse answers.', + default: '40' + ) + ->addOption( + name: 'top-p', + mode: InputOption::VALUE_REQUIRED, + description: 'A higher value (e.g., 0.95) will lead to more diverse text, while a lower value (e.g., 0.5) will generate more focused and conservative text.', + default: '0.9' + ); + } + + /** + * {@inheritDoc} + */ + protected function execute(InputInterface $input, OutputInterface $output): int { + $providerName = $input->getArgument('provider'); + $name = $input->getArgument('name'); + + $systemPrompt = $input->getOption('system-prompt'); + $temperature = $input->getOption('temperature'); + $topK = $input->getOption('top-k'); + $topP = $input->getOption('top-p'); + + $provider = $this->providerManager->createInstance($providerName); + + // Build configuration. + $payLoad = new Payload(); + $payLoad->model = $name; + $payLoad->options = [ + 'temperature' => $temperature, + 'top_k' => $topK, + 'top_p' => $topP, + ]; + $msg = new Message(); + $msg->role = MessageRoles::System; + $msg->content = $systemPrompt; + $payLoad->messages[] = $msg; + + $helper = $this->getHelper('question'); + $question = new Question('Message: ', ''); + + // Keep chatting with the user. Not optimal, but okay for now. + while (TRUE) { + // Ask the user for the next question.
+ $output->write("\n"); + $msg = new Message(); + $msg->role = MessageRoles::User; + $msg->content = $helper->ask($input, $output, $question); + $payLoad->messages[] = $msg; + $output->write("\n"); + + $answer = ''; + foreach ($provider->chat($payLoad) as $res) { + $output->write($res->getContent()); + $answer .= $res->getContent(); + } + $output->write("\n"); + + // Add answer as context to the next question. + $msg = new Message(); + $msg->role = MessageRoles::Assistant; + $msg->content = $answer; + $payLoad->messages[] = $msg; + } + } + +} diff --git a/src/Commands/ModelCompletionCommand.php b/src/Commands/ModelCompletionCommand.php index 552a3ca..43d53e3 100644 --- a/src/Commands/ModelCompletionCommand.php +++ b/src/Commands/ModelCompletionCommand.php @@ -14,11 +14,7 @@ use Symfony\Component\Console\Output\OutputInterface; /** - * This is a literal copy of the example Symfony Console command - * from the documentation. - * - * See: - * http://symfony.com/doc/2.7/components/console/introduction.html#creating-a-basic-command + * Make a completion request against a provider model. */ class ModelCompletionCommand extends Command { @@ -41,7 +37,7 @@ protected function configure(): void { $this ->setName('llm:model:completion') ->setDescription('Make a completion request to a model') - ->addUsage('llm:model:completion ollama llama2 "Why is the sky blue?') + ->addUsage('llm:model:completion ollama llama3 "Why is the sky blue?"') ->addArgument( name: 'provider', mode: InputArgument::REQUIRED, diff --git a/src/Commands/ProviderInstallCommand.php b/src/Commands/ProviderInstallCommand.php index 17240cf..f0c336a 100644 --- a/src/Commands/ProviderInstallCommand.php +++ b/src/Commands/ProviderInstallCommand.php @@ -4,13 +4,10 @@ namespace Drupal\llm_services\Commands; -use Drupal\llm_services\Model\Message; -use Drupal\llm_services\Model\Payload; use Drupal\llm_services\Plugin\LLModelProviderManager; use Symfony\Component\Console\Command\Command; use Symfony\Component\Console\Input\InputArgument; use Symfony\Component\Console\Input\InputInterface; -use Symfony\Component\Console\Input\InputOption; use Symfony\Component\Console\Output\OutputInterface; /** @@ -37,7 +34,7 @@ protected function configure(): void { $this ->setName('llm:provider:install') ->setDescription('Install model in provider') - ->addUsage('llm:install:model ollama llama2') + ->addUsage('llm:provider:install ollama llama3') ->addArgument( name: 'provider', mode: InputArgument::REQUIRED, diff --git a/src/Commands/ProviderListCommand.php b/src/Commands/ProviderListCommand.php index 3482d4d..d1f1518 100644 --- a/src/Commands/ProviderListCommand.php +++ b/src/Commands/ProviderListCommand.php @@ -4,13 +4,10 @@ namespace Drupal\llm_services\Commands; -use Drupal\llm_services\Model\Message; -use Drupal\llm_services\Model\Payload; use Drupal\llm_services\Plugin\LLModelProviderManager; use Symfony\Component\Console\Command\Command; use Symfony\Component\Console\Input\InputArgument; use Symfony\Component\Console\Input\InputInterface; -use Symfony\Component\Console\Input\InputOption; use Symfony\Component\Console\Output\OutputInterface; /** @@ -37,7 +34,7 @@ protected function configure(): void { $this ->setName('llm:provider:list') ->setDescription('Install model in provider') - ->addUsage('llm:install:model ollama llama2') + ->addUsage('llm:provider:list ollama') ->addArgument( name: 'provider', mode: InputArgument::REQUIRED, diff --git a/src/Model/ChatResponseInterface.php b/src/Model/ChatResponseInterface.php new file mode 100644 index
0000000..503deb7 --- /dev/null +++ b/src/Model/ChatResponseInterface.php @@ -0,0 +1,55 @@ + + * String of base64 encoded images. + */ + public function getImages(): array; + + /** + * The completion status. + * + * @return bool + * If false, the model has more to say. + */ + public function getStatus(): bool; + +} diff --git a/src/Model/CompletionResponseInterface.php b/src/Model/CompletionResponseInterface.php index afc3b5f..f2aa649 100644 --- a/src/Model/CompletionResponseInterface.php +++ b/src/Model/CompletionResponseInterface.php @@ -16,10 +16,10 @@ interface CompletionResponseInterface { public function getModel(): string; /** - * The response from the module. + * The response from the model. * * @return string - * The text generated by the modul. + * The text generated by the model. */ public function getResponse(): string; diff --git a/src/Plugin/LLModelsProviders/LLMProviderInterface.php b/src/Plugin/LLModelsProviders/LLMProviderInterface.php index 97a186a..12f2d08 100644 --- a/src/Plugin/LLModelsProviders/LLMProviderInterface.php +++ b/src/Plugin/LLModelsProviders/LLMProviderInterface.php @@ -53,8 +53,8 @@ public function completion(Payload $payload): \Generator; * @param \Drupal\llm_services\Model\Payload $payload * The body of the chat request. * - * @return \Generator - * The result of the chat initiation. + * @return \Generator<\Drupal\llm_services\Model\ChatResponseInterface> + * The result of the chat. * * @throws \Drupal\llm_services\Exceptions\CommunicationException */ diff --git a/src/Plugin/LLModelsProviders/Ollama.php b/src/Plugin/LLModelsProviders/Ollama.php index 6a93c23..6ee0c97 100644 --- a/src/Plugin/LLModelsProviders/Ollama.php +++ b/src/Plugin/LLModelsProviders/Ollama.php @@ -7,11 +7,10 @@ use Drupal\Core\Plugin\PluginBase; use Drupal\Core\Plugin\PluginFormInterface; use Drupal\llm_services\Client\Ollama as ClientOllama; +use Drupal\llm_services\Client\OllamaChatResponse; use Drupal\llm_services\Client\OllamaCompletionResponse; -use Drupal\llm_services\Exceptions\CommunicationException; -use Drupal\llm_services\Exceptions\NotSupportedException; +use Drupal\llm_services\Model\MessageRoles; use Drupal\llm_services\Model\Payload; -use GuzzleHttp\Exception\GuzzleException; /** * Ollama integration provider. 
@@ -36,36 +35,21 @@ public function __construct(array $configuration, $plugin_id, $plugin_definition * {@inheritdoc} */ public function listModels(): array { - try { - return $this->getClient()->listLocalModels(); - } - catch (GuzzleException | \JsonException $exception) { - throw new CommunicationException( - message: 'Error in communicating with LLM services', - previous: $exception, - ); - } + return $this->getClient()->listLocalModels(); } /** * {@inheritdoc} + * + * @throws \JsonException */ public function installModel(string $modelName): \Generator|string { - try { - return $this->getClient()->install($modelName); - } - catch (GuzzleException $exception) { - throw new CommunicationException( - message: 'Error in communicating with LLM services', - previous: $exception, - ); - } + return $this->getClient()->install($modelName); } /** * {@inheritdoc} * - * @throws \Drupal\llm_services\Exceptions\CommunicationException * @throws \JsonException * * @see https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-completion @@ -73,10 +57,10 @@ public function installModel(string $modelName): \Generator|string { public function completion(Payload $payload): \Generator { foreach ($this->getClient()->completion($payload) as $chunk) { yield new OllamaCompletionResponse( - $chunk['model'], - $chunk['response'], - $chunk['done'], - $chunk['context'] ?? [], + model: $chunk['model'], + response: $chunk['response'], + done: $chunk['done'], + context: $chunk['context'] ?? [], ); } } @@ -84,11 +68,20 @@ public function completion(Payload $payload): \Generator { /** * {@inheritdoc} * + * @throws \JsonException + * * @see https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-chat-completion */ public function chat(Payload $payload): \Generator { - // @todo Implement chatCompletions() method. - throw new NotSupportedException(); + foreach ($this->getClient()->chat($payload) as $chunk) { + yield new OllamaChatResponse( + model: $chunk['model'], + content: $chunk['message']['content'] ?? '', + role: $chunk['message']['role'] ? MessageRoles::from($chunk['message']['role']) : MessageRoles::Assistant, + images: $chunk['message']['images'] ?? [], + done: $chunk['done'], + ); + } } /**
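For reference, below is a sketch of the message list that chatMessagesAsArray() builds from a Payload and that the client sends as the 'messages' member of the JSON body posted to /api/chat. The conversation text is made up; the key names and role strings follow the code in this diff and the chat-with-history example linked from the docblock.

<?php

// Hypothetical conversation. Only the 'content'/'role' keys and the role
// strings (system, user, assistant) are taken from the code above; the
// message text itself is illustrative.
$messages = [
  ['content' => 'You are a helpful assistant.', 'role' => 'system'],
  ['content' => 'Why is the sky blue?', 'role' => 'user'],
  ['content' => 'Mostly because of Rayleigh scattering.', 'role' => 'assistant'],
  ['content' => 'Does the same apply at sunset?', 'role' => 'user'],
];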
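A second sketch shows one decoded chunk of the streamed /api/chat response, the $chunk array that the plugin's chat() method wraps in OllamaChatResponse. The values are invented; the keys are the ones read in the code above, and per the Ollama API documentation 'done' flips to TRUE on the final chunk.

<?php

// One illustrative stream chunk. The plugin reads 'model', the nested
// 'message' (content, role, optional images) and 'done' from it.
$chunk = [
  'model' => 'llama3',
  'created_at' => '2024-05-21T12:00:00.000000Z',
  'message' => [
    'role' => 'assistant',
    'content' => 'The sky appears blue because ',
  ],
  'done' => FALSE,
];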
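Finally, a minimal sketch of driving the new chat() generator from module code instead of the console command, assuming the plugin manager service registered in drush.services.yml and the 'ollama' plugin ID used in the command usage examples; variable names and the question are illustrative.

<?php

use Drupal\llm_services\Model\Message;
use Drupal\llm_services\Model\MessageRoles;
use Drupal\llm_services\Model\Payload;

// Load the Ollama provider plugin through the module's plugin manager.
$provider = \Drupal::service('plugin.manager.llm_services')->createInstance('ollama');

// Build a payload holding a single user message.
$payload = new Payload();
$payload->model = 'llama3';

$message = new Message();
$message->role = MessageRoles::User;
$message->content = 'Why is the sky blue?';
$payload->messages[] = $message;

// chat() yields ChatResponseInterface objects while the model streams its
// answer; getStatus() returns TRUE on the final chunk.
$answer = '';
foreach ($provider->chat($payload) as $chunk) {
  $answer .= $chunk->getContent();
}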