1452: Added chat support with context
cableman committed May 28, 2024
1 parent b6dfe0c commit c58cce2
Showing 11 changed files with 358 additions and 50 deletions.
7 changes: 7 additions & 0 deletions drush.services.yml
@@ -19,3 +19,10 @@ services:
- '@plugin.manager.llm_services'
tags:
- { name: console.command }

llm.service.chat.command:
class: Drupal\llm_services\Commands\ModelChatCommand
arguments:
- '@plugin.manager.llm_services'
tags:
- { name: console.command }
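
With this registration in place, the chat command becomes available as a Symfony Console command, picked up through the console.command tag like the existing completion command above. Assuming a standard Drush setup, an invocation would look roughly like this (the option values shown are the command's defaults):

    drush llm:model:chat ollama llama3 --temperature=0.8 --top-k=40 --top-p=0.9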
66 changes: 62 additions & 4 deletions src/Client/Ollama.php
@@ -28,7 +28,7 @@ class Ollama {
* @param string $url
* The URL of the Ollama server.
   * @param int $port
-  *   The port that Ollama is listening on.
+  *   The port that Ollama is listening at.
*/
public function __construct(
private readonly string $url,
@@ -42,15 +42,13 @@ public function __construct(
* @return array<string, array<string, string>>
* Basic information about the models.
   *
-  * @throws \GuzzleHttp\Exception\GuzzleException
-  * @throws \JsonException
+  * @throws \Drupal\llm_services\Exceptions\CommunicationException
*/
public function listLocalModels(): array {
$response = $this->call(method: 'get', uri: '/api/tags');
$data = $response->getBody()->getContents();
$data = json_decode($data, TRUE);

-    // @todo Change to value objects.
$models = [];
foreach ($data['models'] as $item) {
$models[$item['model']] = [
@@ -135,6 +133,66 @@ public function completion(Payload $payload): \Generator {
}
}

/**
* Chat with a model.
*
   * @param \Drupal\llm_services\Model\Payload $payload
   *   The question to ask the model and the chat history.
*
   * @return \Generator
   *   The response from the model as it is generated.
*
* @throws \Drupal\llm_services\Exceptions\CommunicationException
* @throws \JsonException
*/
public function chat(Payload $payload): \Generator {
$response = $this->call(method: 'post', uri: '/api/chat', options: [
'json' => [
'model' => $payload->model,
'messages' => $this->chatMessagesAsArray($payload),
'stream' => TRUE,
],
'headers' => [
'Content-Type' => 'application/json',
],
RequestOptions::CONNECT_TIMEOUT => 10,
RequestOptions::TIMEOUT => 300,
RequestOptions::STREAM => TRUE,
]);

$body = $response->getBody();
while (!$body->eof()) {
$data = $body->read(1024);
yield from $this->parse($data);
}
}

/**
   * Convert all payload messages into an array.
*
   * This array of messages is used to give the model some chat context, making
   * the interaction appear more like a real chat with a person.
*
* @param \Drupal\llm_services\Model\Payload $payload
* The payload sent to the chat function.
*
* @return array
* Array of messages to send to Ollama.
*
* @see https://github.com/ollama/ollama/blob/main/docs/api.md#chat-request-with-history
*/
private function chatMessagesAsArray(Payload $payload): array {
$messages = [];
foreach ($payload->messages as $message) {
$messages[] = [
'content' => $message->content,
'role' => $message->role->value,
];
}

return $messages;
}

/**
* Parse LLM stream.
*
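For orientation, here is a minimal sketch of how a caller might consume the streaming generator returned by chat(). The connection values are placeholders, and the assumption that each yielded chunk exposes the ChatResponseInterface getters (see OllamaChatResponse below) is inferred from this diff rather than confirmed by it:

    use Drupal\llm_services\Client\Ollama;
    use Drupal\llm_services\Model\Message;
    use Drupal\llm_services\Model\MessageRoles;
    use Drupal\llm_services\Model\Payload;

    // Placeholder host/port for a local Ollama instance.
    $client = new Ollama(url: 'http://localhost', port: 11434);

    // Build a payload with a single user message.
    $payload = new Payload();
    $payload->model = 'llama3';

    $msg = new Message();
    $msg->role = MessageRoles::User;
    $msg->content = 'Why is the sky blue?';
    $payload->messages[] = $msg;

    // chat() yields partial responses while the model streams tokens.
    foreach ($client->chat($payload) as $chunk) {
      echo $chunk->getContent();
      if ($chunk->getStatus()) {
        // done === TRUE: the model has finished the answer.
        break;
      }
    }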
71 changes: 71 additions & 0 deletions src/Client/OllamaChatResponse.php
@@ -0,0 +1,71 @@
<?php

namespace Drupal\llm_services\Client;

use Drupal\llm_services\Model\ChatResponseInterface;
use Drupal\llm_services\Model\MessageRoles;

/**
 * This class represents a chat response in the Ollama provider.
*/
readonly class OllamaChatResponse implements ChatResponseInterface {

/**
* Default constructor.
*
* @param string $model
* Name of the model.
* @param string $content
* The content of the message from the model.
* @param \Drupal\llm_services\Model\MessageRoles $role
* The role of the message.
* @param array $images
* Base64 encoded array of images.
   * @param bool $done
   *   The model completion state.
*/
public function __construct(
private string $model,
private string $content,
private MessageRoles $role,
private array $images,
private bool $done,
) {
}

/**
* {@inheritdoc}
*/
public function getModel(): string {
return $this->model;
}

/**
* {@inheritdoc}
*/
public function getStatus(): bool {
return $this->done;
}

/**
* {@inheritdoc}
*/
public function getContent(): string {
return $this->content;
}

/**
* {@inheritdoc}
*/
public function getRole(): MessageRoles {
return $this->role;
}

/**
* {@inheritdoc}
*/
public function getImages(): array {
return $this->images;
}

}
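
As a rough illustration of where this value object fits, one parsed line of the /api/chat stream could be mapped onto it as follows. The $data field names follow Ollama's documented chat response format, and MessageRoles::from() assumes a string-backed enum; both are assumptions, since the module's parse() implementation is not part of this diff:

    // $json is a single newline-delimited JSON object from the stream.
    $data = json_decode($json, TRUE, 512, JSON_THROW_ON_ERROR);

    $response = new OllamaChatResponse(
      model: $data['model'],
      content: $data['message']['content'] ?? '',
      role: MessageRoles::from($data['message']['role'] ?? 'assistant'),
      images: $data['message']['images'] ?? [],
      done: $data['done'] ?? FALSE,
    );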
134 changes: 134 additions & 0 deletions src/Commands/ModelChatCommand.php
@@ -0,0 +1,134 @@
<?php

declare(strict_types=1);

namespace Drupal\llm_services\Commands;

use Drupal\llm_services\Model\Message;
use Drupal\llm_services\Model\MessageRoles;
use Drupal\llm_services\Model\Payload;
use Drupal\llm_services\Plugin\LLModelProviderManager;
use Symfony\Component\Console\Command\Command;
use Symfony\Component\Console\Input\InputArgument;
use Symfony\Component\Console\Input\InputInterface;
use Symfony\Component\Console\Input\InputOption;
use Symfony\Component\Console\Output\OutputInterface;
use Symfony\Component\Console\Question\Question;

/**
* Chat with a model through a provider.
*/
class ModelChatCommand extends Command {

/**
* Default constructor.
*
* @param \Drupal\llm_services\Plugin\LLModelProviderManager $providerManager
* The provider manager.
*/
public function __construct(
private readonly LLModelProviderManager $providerManager,
) {
parent::__construct();
}

/**
* {@inheritDoc}
*/
protected function configure(): void {
$this
->setName('llm:model:chat')
->setDescription('Chat with model (use ctrl+c to stop chatting)')
->addUsage('llm:model:chat ollama llama3')
->addArgument(
name: 'provider',
mode: InputArgument::REQUIRED,
description: 'Name of the provider (plugin).'
)
->addArgument(
name: 'name',
mode: InputArgument::REQUIRED,
description: 'Name of the model to use.'
)
->addOption(
name: 'system-prompt',
mode: InputOption::VALUE_REQUIRED,
description: 'System message to instruct the LLM how to behave.',
default: 'Use the following pieces of context to answer the user\'s question. If you don\'t know the answer, just say that you don\'t know, don\'t try to make up an answer.'
)
->addOption(
name: 'temperature',
mode: InputOption::VALUE_REQUIRED,
description: 'The temperature of the model. Increasing the temperature will make the model answer more creatively.',
default: '0.8'
)
->addOption(
name: 'top-k',
mode: InputOption::VALUE_REQUIRED,
description: 'Reduces the probability of generating nonsense. A higher value (e.g. 100) will give more diverse answers.',
default: '40'
)
->addOption(
name: 'top-p',
mode: InputOption::VALUE_REQUIRED,
description: 'A higher value (e.g., 0.95) will lead to more diverse text, while a lower value (e.g., 0.5) will generate more focused and conservative text.',
default: '0.9'
);
}

/**
* {@inheritDoc}
*/
protected function execute(InputInterface $input, OutputInterface $output): int {
$providerName = $input->getArgument('provider');
$name = $input->getArgument('name');

$systemPrompt = $input->getOption('system-prompt');
$temperature = $input->getOption('temperature');
$topK = $input->getOption('top-k');
$topP = $input->getOption('top-p');

$provider = $this->providerManager->createInstance($providerName);

// Build configuration.
$payLoad = new Payload();
$payLoad->model = $name;
$payLoad->options = [
'temperature' => $temperature,
'top_k' => $topK,
'top_p' => $topP,
];
$msg = new Message();
$msg->role = MessageRoles::System;
$msg->content = $systemPrompt;
$payLoad->messages[] = $msg;

$helper = $this->getHelper('question');
$question = new Question('Message: ', '');

// Keep chatting with the user. Not optimal, but okay for now.
while (TRUE) {
// Prompt the user for the next question.
$output->write("\n");
$msg = new Message();
$msg->role = MessageRoles::User;
$msg->content = $helper->ask($input, $output, $question);
$payLoad->messages[] = $msg;
$output->write("\n");

$answer = '';
foreach ($provider->chat($payLoad) as $res) {
$output->write($res->getContent());
$answer .= $res->getContent();
}
$output->write("\n");

// Add answer as context to the next question.
$msg = new Message();
$msg->role = MessageRoles::Assistant;
$msg->content = $answer;
$payLoad->messages[] = $msg;
}
}

}
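
Note how the loop keeps appending to $payLoad->messages: every turn adds the user's question and the model's answer, so chatMessagesAsArray() in Ollama.php serialises the whole conversation on each request, which is what gives the model its context. After one question and answer, that serialised history would look roughly like this (contents and role string values are illustrative and assume a string-backed MessageRoles enum):

    [
      ['content' => 'Use the following pieces of context ...', 'role' => 'system'],
      ['content' => 'Why is the sky blue?', 'role' => 'user'],
      ['content' => 'Rayleigh scattering of sunlight ...', 'role' => 'assistant'],
    ]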
8 changes: 2 additions & 6 deletions src/Commands/ModelCompletionCommand.php
@@ -14,11 +14,7 @@
use Symfony\Component\Console\Output\OutputInterface;

/**
- * This is a literal copy of the example Symfony Console command
- * from the documentation.
- *
- * See:
- * http://symfony.com/doc/2.7/components/console/introduction.html#creating-a-basic-command
+ * Make a completion request against a provider model.
*/
class ModelCompletionCommand extends Command {

@@ -41,7 +37,7 @@ protected function configure(): void {
$this
->setName('llm:model:completion')
->setDescription('Make a completion request to a model')
-    ->addUsage('llm:model:completion ollama llama2 "Why is the sky blue?')
+    ->addUsage('llm:model:completion ollama llama3 "Why is the sky blue?"')
->addArgument(
name: 'provider',
mode: InputArgument::REQUIRED,
5 changes: 1 addition & 4 deletions src/Commands/ProviderInstallCommand.php
@@ -4,13 +4,10 @@

namespace Drupal\llm_services\Commands;

- use Drupal\llm_services\Model\Message;
- use Drupal\llm_services\Model\Payload;
use Drupal\llm_services\Plugin\LLModelProviderManager;
use Symfony\Component\Console\Command\Command;
use Symfony\Component\Console\Input\InputArgument;
use Symfony\Component\Console\Input\InputInterface;
- use Symfony\Component\Console\Input\InputOption;
use Symfony\Component\Console\Output\OutputInterface;

/**
@@ -37,7 +34,7 @@ protected function configure(): void {
$this
->setName('llm:provider:install')
->setDescription('Install model in provider')
-    ->addUsage('llm:install:model ollama llama2')
+    ->addUsage('llm:provider:install ollama llama3')
->addArgument(
name: 'provider',
mode: InputArgument::REQUIRED,
5 changes: 1 addition & 4 deletions src/Commands/ProviderListCommand.php
@@ -4,13 +4,10 @@

namespace Drupal\llm_services\Commands;

- use Drupal\llm_services\Model\Message;
- use Drupal\llm_services\Model\Payload;
use Drupal\llm_services\Plugin\LLModelProviderManager;
use Symfony\Component\Console\Command\Command;
use Symfony\Component\Console\Input\InputArgument;
use Symfony\Component\Console\Input\InputInterface;
- use Symfony\Component\Console\Input\InputOption;
use Symfony\Component\Console\Output\OutputInterface;

/**
@@ -37,7 +34,7 @@ protected function configure(): void {
$this
->setName('llm:provider:list')
->setDescription('List models in provider')
-    ->addUsage('llm:install:model ollama llama2')
+    ->addUsage('llm:provider:list ollama')
->addArgument(
name: 'provider',
mode: InputArgument::REQUIRED,
Expand Down
