Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

v1.7.1 #96

Merged
merged 4 commits into from
Nov 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions modules/ocha_ai_tag/src/Services/OchaAiTagTagger.php
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
use Drupal\Core\Cache\Cache;
use Drupal\Core\Cache\CacheBackendInterface;
use Drupal\Core\Config\ConfigFactoryInterface;
use Drupal\Core\Config\ImmutableConfig;
use Drupal\Core\Database\Connection;
use Drupal\Core\Logger\LoggerChannelFactoryInterface;
use Drupal\Core\Session\AccountProxyInterface;
Expand Down Expand Up @@ -37,6 +38,13 @@ class OchaAiTagTagger extends OchaAiChat {
*/
protected CacheBackendInterface $cacheBackend;

/**
* The AI tagger configuration.
*
* Immutable (read-only) configuration object; values cannot be changed
* through this property.
*
* @var \Drupal\Core\Config\ImmutableConfig
*/
protected ImmutableConfig $config;

/**
* Vocabulary mapping.
*
Expand Down
18 changes: 18 additions & 0 deletions src/Plugin/CompletionPluginInterface.php
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,24 @@ interface CompletionPluginInterface {
*/
public function answer(string $question, string $context): string;

/**
* Perform a completion query against the underlying model.
*
* @param string $prompt
* The user prompt to send to the model.
* @param string $system_prompt
* Optional system prompt used to steer the model's behavior. Defaults to
* an empty string (no system prompt).
* @param array $parameters
* Optional generation parameters for the request payload. Supported keys:
* max_tokens, temperature, top_p. Missing keys fall back to the plugin's
* own settings or defaults.
* @param bool $raw
* TRUE (default) to return the model output text as is; FALSE to let the
* plugin post-process the output, if it implements any processing.
*
* @return string|null
* The model output text, or NULL in case of error when querying the
* model.
*/
public function query(string $prompt, string $system_prompt = '', array $parameters = [], bool $raw = TRUE): ?string;

/**
* Get the prompt template.
*
Expand Down
40 changes: 29 additions & 11 deletions src/Plugin/ocha_ai/Completion/AwsBedrock.php
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,20 @@ public function answer(string $question, string $context): string {
return '';
}

return $this->query($prompt, raw: FALSE) ?? '';
}

/**
* {@inheritdoc}
*/
public function query(string $prompt, string $system_prompt = '', array $parameters = [], bool $raw = TRUE): ?string {
if (empty($prompt)) {
return '';
}

$payload = [
'accept' => 'application/json',
'body' => json_encode($this->generateRequestBody($prompt)),
'body' => json_encode($this->generateRequestBody($prompt, $parameters)),
'contentType' => 'application/json',
'modelId' => $this->getPluginSetting('model'),
];
Expand All @@ -67,20 +78,24 @@ public function answer(string $question, string $context): string {
return '';
}

return $this->parseResponseBody($data);
return $this->parseResponseBody($data, $raw);
}

/**
* Generate the request body for the completion.
*
* @param string $prompt
* Prompt.
* @param array $parameters
* Parameters for the payload: max_tokens, temperature, top_p.
*
* @return array
* Request body.
*/
protected function generateRequestBody(string $prompt): array {
$max_tokens = (int) $this->getPluginSetting('max_tokens', 512);
protected function generateRequestBody(string $prompt, array $parameters = []): array {
$max_tokens = (int) ($parameters['max_tokens'] ?? $this->getPluginSetting('max_tokens', 512));
$temperature = (float) ($parameters['temperature'] ?? 0.0);
$top_p = (float) ($parameters['top_p'] ?? 0.9);

switch ($this->getPluginSetting('model')) {
case 'amazon.titan-text-express-v1':
Expand All @@ -90,16 +105,16 @@ protected function generateRequestBody(string $prompt): array {
'maxTokenCount' => $max_tokens,
// @todo adjust based on the prompt?
'stopSequences' => [],
'temperature' => 0.0,
'topP' => 0.9,
'temperature' => $temperature,
'topP' => $top_p,
],
];

case 'anthropic.claude-instant-v1':
return [
'prompt' => "\n\nHuman:$prompt\n\nAssistant:",
'temperature' => 0.0,
'top_p' => 0.9,
'temperature' => $temperature,
'top_p' => $top_p,
'top_k' => 0,
'max_tokens_to_sample' => $max_tokens,
'stop_sequences' => ["\n\nHuman:"],
Expand All @@ -109,8 +124,8 @@ protected function generateRequestBody(string $prompt): array {
case 'cohere.command-light-text-v14':
return [
'prompt' => $prompt,
'temperature' => 0.0,
'p' => 0.9,
'temperature' => $temperature,
'p' => $top_p,
'k' => 0.0,
'max_tokens' => $max_tokens,
'stop_sequences' => [],
Expand All @@ -129,11 +144,14 @@ protected function generateRequestBody(string $prompt): array {
*
* @param array $data
* Decoded response.
* @param bool $raw
* Whether to return the raw output text or let the plugin do some
* processing if any.
*
* @return string
* The generated text.
*/
protected function parseResponseBody(array $data): string {
protected function parseResponseBody(array $data, bool $raw = TRUE): string {
switch ($this->getPluginSetting('model')) {
case 'amazon.titan-text-express-v1':
return trim($data['results'][0]['outputText'] ?? '');
Expand Down
16 changes: 11 additions & 5 deletions src/Plugin/ocha_ai/Completion/AwsBedrockTitanTextPremierV1.php
Original file line number Diff line number Diff line change
Expand Up @@ -94,30 +94,36 @@ public function getPromptTemplate(): string {
/**
 * {@inheritdoc}
 *
 * Builds the Bedrock invocation body for the Titan Text Premier model.
 *
 * Generation parameters may be overridden per call via $parameters
 * (max_tokens, temperature, top_p); otherwise max_tokens falls back to the
 * plugin setting (default 512), temperature to 0.0 (deterministic output)
 * and top_p to 0.9.
 */
protected function generateRequestBody(string $prompt, array $parameters = []): array {
  // Caller-supplied overrides win; plugin settings / defaults otherwise.
  $max_tokens = (int) ($parameters['max_tokens'] ?? $this->getPluginSetting('max_tokens', 512));
  $temperature = (float) ($parameters['temperature'] ?? 0.0);
  $top_p = (float) ($parameters['top_p'] ?? 0.9);

  return [
    'inputText' => $prompt,
    'textGenerationConfig' => [
      'maxTokenCount' => $max_tokens,
      // @todo adjust based on the prompt?
      'stopSequences' => [],
      'temperature' => $temperature,
      'topP' => $top_p,
    ],
  ];
}

/**
* {@inheritdoc}
*/
protected function parseResponseBody(array $data): string {
protected function parseResponseBody(array $data, bool $raw = TRUE): string {
$response = trim($data['results'][0]['outputText'] ?? '');
if ($response === '') {
return '';
}

if ($raw) {
return $response;
}

// Extract the answer.
$start = mb_strpos($response, '<answer>');
$end = mb_strpos($response, '</answer>');
Expand Down
29 changes: 22 additions & 7 deletions src/Plugin/ocha_ai/Completion/AzureOpenAi.php
Original file line number Diff line number Diff line change
Expand Up @@ -43,23 +43,38 @@ public function answer(string $question, string $context): string {
return '';
}

return $this->query($question, $prompt) ?? '';
}

/**
* {@inheritdoc}
*/
public function query(string $prompt, string $system_prompt = '', array $parameters = [], bool $process = FALSE): ?string {
if (empty($prompt)) {
return '';
}

$max_tokens = (int) ($parameters['max_tokens'] ?? $this->getPluginSetting('max_tokens', 512));
$temperature = (float) ($parameters['temperature'] ?? 0.0);
$top_p = (float) ($parameters['top_p'] ?? 0.9);

$messages = [
[
'role' => 'system',
'content' => $prompt,
'content' => $system_prompt ?: 'You are a helpful assistant.',
],
[
'role' => 'user',
'content' => $question,
'content' => $prompt,
],
];

$payload = [
'model' => $this->getPluginSetting('model'),
'messages' => $messages,
'temperature' => 0.0,
'top_p' => 0.9,
'max_tokens' => (int) $this->getPluginSetting('max_tokens', 512),
'temperature' => $temperature,
'top_p' => $top_p,
'max_tokens' => $max_tokens,
];

try {
Expand All @@ -70,7 +85,7 @@ public function answer(string $question, string $context): string {
$this->getLogger()->error(strtr('Completion request failed with: @error.', [
'@error' => $exception->getMessage(),
]));
return '';
return NULL;
}

try {
Expand All @@ -80,7 +95,7 @@ public function answer(string $question, string $context): string {
$this->getLogger()->error(strtr('Unable to retrieve completion result data: @error.', [
'@error' => $exception->getMessage(),
]));
return '';
return NULL;
}

return trim($data['choices'][0]['message']['content'] ?? '');
Expand Down
Loading